├── .gitignore
├── CONTRIBUTING.md
├── Dockerfile
├── LICENSE.txt
├── README.md
├── assets
├── LiberationSans-Regular.ttf
├── chroma_logo.svg
├── chroma_logo_outline.svg
├── conditioners.png
├── lattice.png
├── logo.png
├── proteins.png
└── refolding.png
├── chroma
├── __init__.py
├── constants
│   ├── __init__.py
│   ├── geometry.py
│   ├── named_models.py
│   └── sequence.py
├── data
│   ├── __init__.py
│   ├── protein.py
│   ├── system.py
│   └── xcs.py
├── layers
│   ├── __init__.py
│   ├── attention.py
│   ├── basic.py
│   ├── complexity.py
│   ├── conv.py
│   ├── graph.py
│   ├── linalg.py
│   ├── norm.py
│   ├── sde.py
│   └── structure
│   │   ├── __init__.py
│   │   ├── backbone.py
│   │   ├── conditioners.py
│   │   ├── diffusion.py
│   │   ├── geometry.py
│   │   ├── hbonds.py
│   │   ├── mvn.py
│   │   ├── optimal_transport.py
│   │   ├── potts.py
│   │   ├── protein_graph.py
│   │   ├── protein_graph_allatom.py
│   │   ├── rmsd.py
│   │   ├── sidechain.py
│   │   ├── symmetry.py
│   │   └── transforms.py
├── models
│   ├── __init__.py
│   ├── chroma.py
│   ├── graph_backbone.py
│   ├── graph_classifier.py
│   ├── graph_design.py
│   ├── graph_energy.py
│   └── procap.py
└── utility
│   ├── __init__.py
│   ├── api.py
│   ├── chroma.py
│   ├── fetchdb.py
│   ├── model.py
│   ├── ngl.py
│   ├── polyseq.py
│   └── starparser.py
├── notebooks
├── ChromaAPI.ipynb
├── ChromaConditioners.ipynb
├── ChromaDemo.ipynb
└── ChromaTutorial.ipynb
├── requirements.txt
├── setup.py
└── tests
├── __init__.py
├── conftest.py
├── data
├── __init__.py
├── test_protein.py
└── test_system.py
├── layers
├── __init__.py
├── structure
│   ├── __init__.py
│   ├── test_backbone.py
│   ├── test_conditioners.py
│   ├── test_diffusion.py
│   ├── test_geometry.py
│   ├── test_hbonds.py
│   ├── test_mvn.py
│   ├── test_optimal_transport.py
│   ├── test_potts.py
│   ├── test_protein_graph.py
│   ├── test_rmsd.py
│   ├── test_sidechain.py
│   ├── test_symmetry.py
│   └── test_transforms.py
├── test_basic.py
├── test_graph.py
├── test_norm.py
└── test_sde.py
├── models
├── __init__.py
├── conftest.py
├── test_chroma.py
├── test_graph_backbone.py
├── test_graph_classifier.py
├── test_graph_design.py
├── test_graph_energy.py
└── test_procap.py
└── utility
├── __init__.py
└── test_api.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea
3 | .ipynb_checkpoints
4 | .coverage
5 | *.egg-info
6 | 
7 | *.csv
8 | *.tsv
9 | *.pdb
10 | *.cif
11 | *.pk
12 | *.pt
13 | *.fasta
14 | *.pickle
15 | *.pyc
16 | *.mp4
17 | 
18 | *.constraints
19 | *.movemap
20 | *.resfile
21 | test_cmd.sh
22 | 
23 | __mmtf__
24 | __pycache__
25 | public
26 | htmlcov
27 | make.bat
28 | examples
29 | chroma/layers/structure/params/centering_2g3n.params
30 | wandb
31 | config.json
32 | 
33 | # ides
34 | .vscode
35 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Code contributions
2 | 
3 | We welcome contributions to the Chroma code base, including new conditioners, integrators, patches, bug fixes, and other improvements.
4 | 
5 | Note that your contributions will be governed by the Apache 2.0 license, meaning that you will be giving us permission to use your contributed code under the conditions specified in the [Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0) (also available in [LICENSE.txt](LICENSE.txt)).
6 | 
7 | ## How to Contribute
8 | 
9 | Please use GitHub pull requests to contribute code. See
10 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
11 | information on using pull requests. We will try to monitor incoming requests with some regularity, but cannot promise a specific timeframe within which we will review your request.
12 | 
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:11.3.1-devel-ubuntu20.04
2 | ARG DEBIAN_FRONTEND=noninteractive
3 | RUN apt-get update && apt-get install -y --no-install-recommends \
4 |     build-essential \
5 |     cmake \
6 |     git \
7 |     curl \
8 |     ca-certificates \
9 |     libjpeg-dev \
10 |     libpng-dev && \
11 |     rm -rf /var/lib/apt/lists/*
12 | 
13 | WORKDIR /tmp
14 | 
15 | RUN curl -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
16 |     chmod +x ~/miniconda.sh && \
17 |     ~/miniconda.sh -b -p /opt/conda && \
18 |     rm ~/miniconda.sh
19 | RUN /opt/conda/bin/conda create --name chroma python=3.9.7
20 | RUN /opt/conda/envs/chroma/bin/pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
21 | WORKDIR /workspace
22 | COPY . .
23 | RUN /opt/conda/envs/chroma/bin/pip install .
24 | ENV PATH /opt/conda/envs/chroma/bin:$PATH
25 | 
--------------------------------------------------------------------------------
/assets/LiberationSans-Regular.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/generatebio/chroma/929407c605013613941803c6113adefdccaad679/assets/LiberationSans-Regular.ttf
--------------------------------------------------------------------------------
/assets/chroma_logo_outline.svg:
--------------------------------------------------------------------------------
[SVG markup not captured in this dump]
--------------------------------------------------------------------------------
/assets/conditioners.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/generatebio/chroma/929407c605013613941803c6113adefdccaad679/assets/conditioners.png
--------------------------------------------------------------------------------
/assets/lattice.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/generatebio/chroma/929407c605013613941803c6113adefdccaad679/assets/lattice.png
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/generatebio/chroma/929407c605013613941803c6113adefdccaad679/assets/logo.png
--------------------------------------------------------------------------------
/assets/proteins.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/generatebio/chroma/929407c605013613941803c6113adefdccaad679/assets/proteins.png
--------------------------------------------------------------------------------
/assets/refolding.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/generatebio/chroma/929407c605013613941803c6113adefdccaad679/assets/refolding.png -------------------------------------------------------------------------------- /chroma/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Generate Biomedicines, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | __version__ = "1.0.0" 16 | from chroma.data.protein import Protein 17 | from chroma.layers.structure import conditioners 18 | from chroma.models.chroma import Chroma 19 | from chroma.utility import api 20 | -------------------------------------------------------------------------------- /chroma/constants/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Generate Biomedicines, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from chroma.constants.geometry import AA_GEOMETRY 16 | from chroma.constants.sequence import * 17 | -------------------------------------------------------------------------------- /chroma/constants/named_models.py: -------------------------------------------------------------------------------- 1 | # Copyright Generate Biomedicines, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """ Paths for named models in the zoo """ 16 | 17 | GRAPH_BACKBONE_MODELS = { 18 | "public": { 19 | "s3_uri": "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_backbone_v1.0.pt", 20 | "data": "Generate Structure ETL: July 25 2022", 21 | "task": "BLNL backbone model training with EMA, trained July 2023", 22 | }, 23 | } 24 | 25 | GRAPH_CLASSIFIER_MODELS = { 26 | "public": { 27 | "s3_uri": "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_proclass_v1.0.pt", 28 | "data": "Generate Structure ETL: June 2022", 29 | "task": "Backbone classification model training with cross-entropy loss", 30 | }, 31 | } 32 | 33 | GRAPH_DESIGN_MODELS = { 34 | "public": { 35 | "s3_uri": "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_design_v1.0.pt", 36 | "data": "Generate Structure ETL: July 25 2022", 37 | "task": "Autoregressive joint prediction of sequence and chi angles, two-stage", 38 | }, 39 | } 40 | 41 | PROCAP_MODELS = { 42 | "public": { 43 | "s3_uri": "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_procap_v1.0.pt", 44 | "data": "Generate Structure ETL: June 2022", 45 | "task": "Backbone caption model training with cross-entropy loss, using M5 ProClass GNN embeddings", 46 | }, 47 | } 48 | 49 | NAMED_MODELS = { 50 | "GraphBackbone": GRAPH_BACKBONE_MODELS, 51 | "GraphDesign": GRAPH_DESIGN_MODELS, 52 | "GraphClassifier": GRAPH_CLASSIFIER_MODELS, 53 | "ProteinCaption": PROCAP_MODELS, 54 | } 55 | -------------------------------------------------------------------------------- /chroma/constants/sequence.py: -------------------------------------------------------------------------------- 1 | # Copyright Generate Biomedicines, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Constants used across protein representations. 16 | 17 | These constants standardize protein tokenization alphabets, ideal structure 18 | geometries and topologies, etc. 
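
Example (a doctest-style sketch using only constants defined below in this
module, which are also re-exported via `chroma.constants`):

    >>> from chroma.constants import AA20, AA20_1_TO_3, AA20_3_TO_1
    >>> AA20_1_TO_3["A"]
    'ALA'
    >>> AA20_3_TO_1["TRP"]
    'W'
    >>> len(AA20)
    20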
19 | """ 20 | from chroma.constants.geometry import AA_GEOMETRY 21 | 22 | # Standard tokenization for Omniprot and Omniprot-interacting models 23 | OMNIPROT_TOKENS = "ABCDEFGHIKLMNOPQRSTUVWYXZ*-#" 24 | POTTS_EXTENDED_TOKENS = "ACDEFGHIKLMNPQRSTVWY-*#" 25 | PAD = "-" 26 | START = "@" 27 | STOP = "*" 28 | MASK = "#" 29 | DNA_TOKENS = "ACGT" 30 | RNA_TOKENS = "AGCU" 31 | PROTEIN_TOKENS = "ACDEFGHIKLMNPQRSTVWY" 32 | 33 | # Minimal 20-letter alphabet and corresponding triplet codes 34 | AA20 = "ACDEFGHIKLMNPQRSTVWY" 35 | AA20_3_TO_1 = { 36 | "ALA": "A", 37 | "ARG": "R", 38 | "ASN": "N", 39 | "ASP": "D", 40 | "CYS": "C", 41 | "GLN": "Q", 42 | "GLU": "E", 43 | "GLY": "G", 44 | "HIS": "H", 45 | "ILE": "I", 46 | "LEU": "L", 47 | "LYS": "K", 48 | "MET": "M", 49 | "PHE": "F", 50 | "PRO": "P", 51 | "SER": "S", 52 | "THR": "T", 53 | "TRP": "W", 54 | "TYR": "Y", 55 | "VAL": "V", 56 | } 57 | AA20_1_TO_3 = { 58 | "A": "ALA", 59 | "R": "ARG", 60 | "N": "ASN", 61 | "D": "ASP", 62 | "C": "CYS", 63 | "Q": "GLN", 64 | "E": "GLU", 65 | "G": "GLY", 66 | "H": "HIS", 67 | "I": "ILE", 68 | "L": "LEU", 69 | "K": "LYS", 70 | "M": "MET", 71 | "F": "PHE", 72 | "P": "PRO", 73 | "S": "SER", 74 | "T": "THR", 75 | "W": "TRP", 76 | "Y": "TYR", 77 | "V": "VAL", 78 | } 79 | AA20_3 = [AA20_1_TO_3[aa] for aa in AA20] 80 | 81 | # Adding noncanonical amino acids 82 | NONCANON_AA = [ 83 | "HSD", 84 | "HSE", 85 | "HSC", 86 | "HSP", 87 | "MSE", 88 | "CSO", 89 | "SEC", 90 | "CSX", 91 | "HIP", 92 | "SEP", 93 | "TPO", 94 | ] 95 | AA31_3 = AA20_3 + NONCANON_AA 96 | 97 | # Chain alphabet for PDB chain naming 98 | CHAIN_ALPHABET = "_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" 99 | 100 | # Standard atom indexing 101 | ATOMS_BB = ["N", "CA", "C", "O"] 102 | 103 | ATOM_SYMMETRIES = { 104 | "ARG": [("NH1", "NH2")], # Correct handling of NH1 and NH2 is relabeling 105 | "ASP": [("OD1", "OD2")], 106 | "GLU": [("OE1", "OE2")], 107 | "PHE": [("CD1", "CD2"), ("CE1", "CE2")], 108 | "TYR": [("CD1", "CD2"), ("CE1", "CE2")], 109 | } 110 | 111 | AA20_NUM_ATOMS = [4 + len(AA_GEOMETRY[aa]["atoms"]) for aa in AA20_3] 112 | AA20_NUM_CHI = [len(AA_GEOMETRY[aa]["chi_indices"]) for aa in AA20_3] 113 | -------------------------------------------------------------------------------- /chroma/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Generate Biomedicines, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | This package includes io formats and tools for a few common datatypes, 17 | including antibodies, proteins, sequences, and structures. 18 | """ 19 | from chroma.data.protein import Protein 20 | -------------------------------------------------------------------------------- /chroma/data/xcs.py: -------------------------------------------------------------------------------- 1 | # Copyright Generate Biomedicines, Inc. 
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """XCS represents protein structure as a tuple of PyTorch tensors.
16 | 
17 | The tensors in an XCS representation are:
18 | 
19 | `X` (FloatTensor), the Cartesian coordinates representing the protein
20 |     structure with shape `(num_batch, num_residues, num_atoms, 3)`. The
21 |     `num_atoms` dimension can be one of two sizes: `num_atoms=4` for
22 |     backbone-only structures or `num_atoms=14` for all-atom structures
23 |     (excluding hydrogens). The first four atoms will always be
24 |     `N, CA, C, O`, and the meaning of the optional 10 additional atom
25 |     positions will vary based on the residue identity at
26 |     a given position. Atom orders for each amino acid are defined in
27 |     `constants.AA_GEOMETRY[TRIPLET_CODE]["atoms"]`.
28 | 
29 | `C` (LongTensor), the chain map encoding per-residue chain assignments with
30 |     shape `(num_batch, num_residues)`. The chain map codes positions as `0`
31 |     when masked, positive integers for chain indices, and negative integers
32 |     to represent missing residues (of the corresponding positive integers).
33 | 
34 | `S` (LongTensor), the sequence of the protein as alphabet indices with
35 |     shape `(num_batch, num_residues)`. The standard alphabet is
36 |     `ACDEFGHIKLMNPQRSTVWY`, also defined in `constants.AA20`.
37 | """
38 | 
39 | 
40 | from functools import partial, wraps
41 | from inspect import getfullargspec
42 | 
43 | import torch
44 | from torch.nn import functional as F
45 | 
46 | try:
47 |     pass
48 | except ImportError:
49 |     print("MST not installed!")
50 | 
51 | 
52 | def validate_XCS(all_atom=None, sequence=True):
53 |     """Decorator factory that adds XCS validation to any function.
54 | 
55 |     Args:
56 |         all_atom (bool, optional): If True, requires that input structure
57 |             tensors have 14 atoms per residue. If False, reduces structures
58 |             to 4 atoms per residue. If None, applies no transformation on input structures.
59 |         sequence (bool, optional): If True, ensures that if S and O are both
60 |             provided, they match, i.e. that O is a one-hot version of S.
61 |             If only one of S or O is provided, the other is generated, and both
62 |             are passed.
63 |     """
64 | 
65 |     def decorator(func):
66 |         @wraps(func)
67 |         def new_func(*args, **kwargs):
68 |             args = list(args)
69 |             arg_list = getfullargspec(func)[0]
70 |             tensors = {}
71 |             for var in ["X", "C", "S", "O"]:
72 |                 try:
73 |                     if var in kwargs:
74 |                         tensors[var] = kwargs[var]
75 |                     else:
76 |                         tensors[var] = args[arg_list.index(var)]
77 |                 except IndexError:  # empty arg_list
78 |                     tensors[var] = None
79 |                 except ValueError:  # variable not an argument of the function
80 |                     if not sequence and var in ["S", "O"]:
81 |                         pass
82 |                     else:
83 |                         raise Exception(
84 |                             f"Variable {var} is required by validation but not defined!"
85 | ) 86 | if tensors["X"] is not None and tensors["C"] is not None: 87 | if tensors["X"].shape[:2] != tensors["C"].shape[:2]: 88 | raise ValueError( 89 | f"X shape {tensors['X'].shape} does not match C shape" 90 | f" {tensors['C'].shape}" 91 | ) 92 | if all_atom is not None and tensors["X"] is not None: 93 | if all_atom and tensors["X"].shape[2] != 14: 94 | raise ValueError("Side chain atoms missing!") 95 | elif not all_atom: 96 | if "X" in kwargs: 97 | kwargs["X"] = tensors["X"][:, :, :4] 98 | else: 99 | args[arg_list.index("X")] = tensors["X"][:, :, :4] 100 | if sequence and (tensors["S"] is not None or tensors["O"] is not None): 101 | if tensors["O"] is None: 102 | if "O" in kwargs: 103 | kwargs["O"] = F.one_hot(tensors["S"], 20).float() 104 | else: 105 | args[arg_list.index("O")] = F.one_hot(tensors["S"], 20).float() 106 | elif tensors["S"] is None: 107 | if "S" in kwargs: 108 | kwargs["S"] = tensors["O"].argmax(dim=2) 109 | else: 110 | args[arg_list.index("S")] = tensors["O"].argmax(dim=2) 111 | else: 112 | if not torch.allclose(tensors["O"].argmax(dim=2), tensors["S"]): 113 | raise ValueError("S and O are both provided but don't match!") 114 | return func(*args, **kwargs) 115 | 116 | return new_func 117 | 118 | return decorator 119 | 120 | 121 | validate_XC = partial(validate_XCS, sequence=False) 122 | -------------------------------------------------------------------------------- /chroma/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Generate Biomedicines, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | This package contains low-level PyTorch layers, including ``nn.Module`` s and ops. 17 | These layers are often used in :mod:`chroma.models`. 18 | """ 19 | -------------------------------------------------------------------------------- /chroma/layers/complexity.py: -------------------------------------------------------------------------------- 1 | # Copyright Generate Biomedicines, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Layers for computing sequence complexities. 16 | """ 17 | 18 | import numpy as np 19 | import torch 20 | import torch.nn.functional as F 21 | 22 | from chroma.constants import AA20 23 | from chroma.layers.graph import collect_neighbors 24 | 25 | 26 | def compositions(S: torch.Tensor, C: torch.LongTensor, w: int = 30): 27 | """Compute local compositions per residue. 
28 | 
29 |     Args:
30 |         S (torch.Tensor): Sequence tensor with shape `(num_batch, num_residues)`
31 |             (long) or `(num_batch, num_residues, num_alphabet)` (float).
32 |         C (torch.LongTensor): Chain map with shape `(num_batch, num_residues)`.
33 |         w (int, optional): Window size.
34 | 
35 |     Returns:
36 |         P (torch.Tensor): Local compositions with shape
37 |             `(num_batch, num_residues, num_alphabet)`.
38 |         N (torch.Tensor): Local counts with the same shape as `P`.
39 |         edge_idx (torch.LongTensor): Window indices, `(num_batch, num_residues, w)`.
40 |         mask_i (torch.Tensor): Residue mask, `(num_batch, num_residues)`.
41 |         mask_ij (torch.Tensor): Window mask, `(num_batch, num_residues, w)`.
42 |     """
43 |     device = S.device
44 |     Q = len(AA20)
45 |     mask_i = (C > 0).float()
46 |     if len(S.shape) == 2:
47 |         S = F.one_hot(S, Q)
48 | 
49 |     # Build neighborhoods and masks
50 |     S_onehot = mask_i[..., None] * S
51 |     kx = torch.arange(w, device=S.device) - w // 2
52 |     edge_idx = (
53 |         torch.arange(S.shape[1], device=S.device)[None, :, None] + kx[None, None, :]
54 |     )
55 |     mask_ij = (edge_idx > 0) & (edge_idx < S.shape[1])
56 |     edge_idx = edge_idx.clamp(min=0, max=S.shape[1] - 1)
57 |     C_i = C[..., None]
58 |     C_j = collect_neighbors(C_i, edge_idx)[..., 0]
59 |     mask_ij = (mask_ij & C_j.eq(C_i) & (C_i > 0) & (C_j > 0)).float()
60 | 
61 |     # Sum neighborhood composition
62 |     S_j = mask_ij[..., None] * collect_neighbors(S_onehot, edge_idx)
63 |     N = S_j.sum(2)
64 | 
65 |     num_N = N.sum(-1, keepdims=True)
66 |     P = N / (num_N + 1e-5)
67 |     mask_i = ((num_N[..., 0] > 0) & (C > 0)).float()
68 |     mask_ij = mask_i[..., None] * mask_ij
69 |     return P, N, edge_idx, mask_i, mask_ij
70 | 
71 | 
72 | def complexity_lcp(
73 |     S: torch.LongTensor,
74 |     C: torch.LongTensor,
75 |     w: int = 30,
76 |     entropy_min: float = 2.32,
77 |     method: str = "naive",
78 |     differentiable: bool = True,
79 |     eps: float = 1e-5,
80 |     min_coverage: float = 0.9,
81 | ) -> torch.Tensor:
82 |     """Compute the Local Composition Perplexity metric.
83 | 
84 |     Args:
85 |         S (torch.Tensor): Sequence tensor with shape `(num_batch, num_residues)`
86 |             (index tensor) or `(num_batch, num_residues, num_alphabet)` (one-hot).
87 |         C (torch.LongTensor): Chain map with shape `(num_batch, num_residues)`.
88 |         w (int): Window size.
89 |         entropy_min (float): Entropy threshold below which windows are penalized.
90 |         method (str): Entropy estimator passed to `estimate_entropy`.
91 |         differentiable (bool): If True and `S` is one-hot, use a straight-through
92 |             estimate so the penalty carries gradients with respect to `S`.
93 |         eps (float): Small number for numerical stability in division and logarithms.
94 |         min_coverage (float): Minimum observed fraction of a window to score it.
95 | 
96 |     Returns:
97 |         U (torch.Tensor): Complexities with shape `(num_batch)`.
98 |     """
99 |     # Adjust window size based on sequence length
100 |     if S.shape[1] < w:
101 |         w = S.shape[1]
102 | 
103 |     P, N, edge_idx, mask_i, mask_ij = compositions(S, C, w)
104 | 
105 |     # Only count windows with at least `min_coverage` of residues observed
106 |     min_N = int(min_coverage * w)
107 |     mask_coverage = N.sum(-1) > min_N
108 | 
109 |     H = estimate_entropy(N, method=method)
110 |     U = mask_coverage * (torch.exp(H) - np.exp(entropy_min)).clamp(max=0).square()
111 | 
112 |     # Compute entropy as a function of perturbed counts
113 |     if differentiable and len(S.shape) == 3:
114 |         # Compute how a mutation changes entropy for each neighbor
115 |         N_neighbors = collect_neighbors(N, edge_idx)
116 |         mask_coverage_j = collect_neighbors(mask_coverage[..., None], edge_idx)
117 |         N_ij = (N_neighbors - S[:, :, None, :])[..., None, :] + torch.eye(
118 |             N.shape[-1], device=N.device
120 | N_ij = N_ij.clamp(min=0) 121 | H_ij = estimate_entropy(N_ij, method=method) 122 | U_ij = (torch.exp(H_ij) - np.exp(entropy_min)).clamp(max=0).square() 123 | U_ij = mask_ij[..., None] * mask_coverage_j * U_ij 124 | U_differentiable = (U_ij.detach() * S[:, :, None, :]).sum([-1, -2]) 125 | U = U.detach() + U_differentiable - U_differentiable.detach() 126 | 127 | U = (mask_i * U).sum(1) 128 | return U 129 | 130 | 131 | def complexity_scores_lcp_t( 132 | t, 133 | S: torch.LongTensor, 134 | C: torch.LongTensor, 135 | idx: torch.LongTensor, 136 | edge_idx_t: torch.LongTensor, 137 | mask_ij_t: torch.Tensor, 138 | w: int = 30, 139 | entropy_min: float = 2.515, 140 | eps: float = 1e-5, 141 | method: str = "chao-shen", 142 | ) -> torch.Tensor: 143 | """Compute local LCP scores for autoregressive decoding.""" 144 | Q = len(AA20) 145 | O = F.one_hot(S, Q) 146 | O_j = collect_neighbors(O, edge_idx_t) 147 | idx_i = idx[:, t, None] 148 | C_i = C[:, t, None] 149 | idx_j = collect_neighbors(idx[..., None], edge_idx_t)[..., 0] 150 | C_j = collect_neighbors(C[..., None], edge_idx_t)[..., 0] 151 | 152 | # Sum valid neighbor counts 153 | is_near = (idx_i - idx_j).abs() <= w / 2 154 | same_chain = C_i == C_j 155 | valid_ij_t = (is_near * same_chain * (mask_ij_t > 0)).float()[..., None] 156 | N_k = (valid_ij_t * O_j).sum(-2) 157 | 158 | # Compute counts under all possible extensions 159 | N_k = N_k[:, :, None, :] + torch.eye(Q, device=N_k.device)[None, None, ...] 160 | 161 | H = estimate_entropy(N_k, method=method) 162 | U = -(torch.exp(H) - np.exp(entropy_min)).clamp(max=0).square() 163 | return U 164 | 165 | 166 | def estimate_entropy( 167 | N: torch.Tensor, method: str = "chao-shen", eps: float = 1e-11 168 | ) -> torch.Tensor: 169 | """Estimate entropy from counts. 170 | 171 | See Chao, A., & Shen, T. J. (2003) for more details. 172 | 173 | Args: 174 | N (torch.Tensor): Tensor of counts with shape `(..., num_bins)`. 175 | 176 | Returns: 177 | H (torch.Tensor): Estimated entropy with shape `(...)`. 178 | """ 179 | N = N.float() 180 | N_total = N.sum(-1, keepdims=True) 181 | P = N / (N_total + eps) 182 | 183 | if method == "chao-shen": 184 | # Estimate coverage and adjusted frequencies 185 | singletons = N.long().eq(1).sum(-1, keepdims=True).float() 186 | C = 1.0 - singletons / (N_total + eps) 187 | P_adjust = C * P 188 | P_inclusion = (1.0 - (1.0 - P_adjust) ** N_total).clamp(min=eps) 189 | H = -(P_adjust * torch.log(P_adjust.clamp(min=eps)) / P_inclusion).sum(-1) 190 | elif method == "miller-maddow": 191 | bins = (N > 0).float().sum(-1) 192 | bias = (bins - 1) / (2 * N_total[..., 0] + eps) 193 | H = -(P * torch.log(P + eps)).sum(-1) + bias 194 | elif method == "laplace": 195 | N = N.float() + 1 / N.shape[-1] 196 | N_total = N.sum(-1, keepdims=True) 197 | P = N / (N_total + eps) 198 | H = -(P * torch.log(P)).sum(-1) 199 | else: 200 | H = -(P * torch.log(P + eps)).sum(-1) 201 | return H 202 | -------------------------------------------------------------------------------- /chroma/layers/conv.py: -------------------------------------------------------------------------------- 1 | # Copyright Generate Biomedicines, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import platform
16 | 
17 | import torch
18 | import torch.nn.functional as F
19 | 
20 | MACHINE = platform.machine()
21 | 
22 | 
23 | def filter1D_linear_decay(Z, B):
24 |     """Apply a low-pass filter with batch-heterogeneous coefficients.
25 | 
26 |     Computes `x_i = z_i + b * x_{i-1}` where `b` varies per batch member.
27 | 
28 |     Args:
29 |         Z (torch.Tensor): Batch of one-dimensional signals with shape `(N, W)`.
30 |         B (torch.Tensor): Batch of coefficients with shape `(N)`.
31 | 
32 |     Returns:
33 |         X (torch.Tensor): Result of applying linear recurrence with shape `(N, W)`.
34 |     """
35 | 
36 |     # Build filter coefficients as powers of B
37 |     N, W = Z.shape
38 |     k = (W - 1) - torch.arange(W, device=Z.device)
39 |     kernel = B[:, None, None] ** k[None, None, :]
40 | 
41 |     # Pad on left to convolve from backwards in time
42 |     Z_pad = F.pad(Z, (W - 1, 0))[None, ...]
43 | 
44 |     # Group convolution can effectively do one filter per batch
45 |     while True:
46 |         X = F.conv1d(Z_pad, kernel, stride=1, padding=0, groups=N)[0, :, :]
47 |         # on arm64 (M1 Mac) this convolution sometimes erroneously produces NaNs
48 |         if (
49 |             (MACHINE == "arm64")
50 |             and torch.isnan(X).any()
51 |             and (not torch.isnan(Z_pad).any())
52 |             and (not torch.isnan(kernel).any())
53 |         ):
54 |             continue
55 |         break
56 |     return X
57 | 
--------------------------------------------------------------------------------
/chroma/layers/linalg.py:
--------------------------------------------------------------------------------
1 | # Copyright Generate Biomedicines, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Layers for linear algebra.
16 | 
17 | This module contains additional pytorch layers for linear algebra operations,
18 | such as a more parallelization-friendly implementation of eigenvalue estimation.
19 | """
20 | 
21 | import torch
22 | 
23 | 
24 | def eig_power_iteration(A, num_iterations=50, eps=1e-5):
25 |     """Estimate largest magnitude eigenvalue and associated eigenvector.
26 | 
27 |     This uses a simple power iteration algorithm to estimate leading
28 |     eigenvalues, which can often be considerably faster than torch's built-in
29 |     eigenvalue routines. All steps are differentiable and small constants are
30 |     added to any division to preserve the stability of the gradients. For more
31 |     information on power iteration, see
32 |     https://en.wikipedia.org/wiki/Power_iteration.
33 | 
34 |     Args:
35 |         A (tensor): Batch of square matrices with shape
36 |             `(..., num_dims, num_dims)`.
37 |         num_iterations (int, optional): Number of iterations for power
38 |             iteration. Default: 50.
39 |         eps (float, optional): Small number to prevent division by zero.
40 |             Default: 1E-5.
41 | 
42 |     Returns:
43 |         lam (tensor): Batch of estimated highest-magnitude eigenvalues with
44 |             shape `(...)`.
45 |         v (tensor): Associated eigenvector with shape `(..., num_dims)`.
46 |     """
47 |     _safe = lambda x: x + eps
48 | 
49 |     dims = list(A.size())[:-1]
50 |     v = torch.randn(dims, device=A.device).unsqueeze(-1)
51 |     for i in range(num_iterations):
52 |         v_prev = v
53 |         Av = torch.matmul(A, v)
54 |         v = Av / _safe(Av.norm(p=2, dim=-2, keepdim=True))
55 | 
56 |     # Compute eigenvalue
57 |     v_prev = v_prev.transpose(-1, -2)
58 |     lam = torch.matmul(v_prev, Av) / _safe(torch.abs(torch.matmul(v_prev, v)))
59 | 
60 |     # Reshape
61 |     v = v.squeeze(-1)
62 |     lam = lam.view(list(lam.size())[:-2])
63 |     return lam, v
64 | 
65 | 
66 | def eig_leading(A, num_iterations=50):
67 |     """Estimate largest positive eigenvalue and associated eigenvector.
68 | 
69 |     This estimates the *most positive* eigenvalue of each matrix in a batch of
70 |     matrices by using two consecutive power iterations with spectral shifting.
71 | 
72 |     Args:
73 |         A (tensor): Batch of square matrices with shape
74 |             `(..., num_dims, num_dims)`.
75 |         num_iterations (int, optional): Number of iterations for power
76 |             iteration. Default: 50.
77 | 
78 |     Returns:
79 |         lam (tensor): Estimated most positive eigenvalue with shape `(...)`.
80 |         v (tensor): Associated eigenvectors with shape `(..., num_dims)`.
81 |     """
82 |     batch_dims = list(A.size())[:-2]
83 | 
84 |     # First pass gets largest magnitude
85 |     lam_1, vec_1 = eig_power_iteration(A, num_iterations)
86 | 
87 |     # Second pass guaranteed to grab most positive eigenvalue
88 |     lam_1_abs = torch.abs(lam_1)
89 |     lam_I = lam_1_abs.reshape(batch_dims + [1, 1]) * torch.eye(
90 |         A.shape[-1], device=A.device
91 |     ).view([1 for _ in batch_dims] + [A.shape[-1], A.shape[-1]])
92 |     A_shift = A + lam_I
93 |     lam_2, vec = eig_power_iteration(A_shift, num_iterations)
94 | 
95 |     # Shift back to original spectra
96 |     lam = lam_2 - lam_1_abs
97 |     return lam, vec
98 | 
--------------------------------------------------------------------------------
/chroma/layers/sde.py:
--------------------------------------------------------------------------------
1 | # Copyright Generate Biomedicines, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Layers for integrating Stochastic Differential Equations (SDEs).
16 | 
17 | 
18 | """
19 | 
20 | 
21 | from typing import Callable, Tuple
22 | 
23 | import torch
24 | from tqdm.autonotebook import tqdm
25 | 
26 | 
27 | def sde_integrate(
28 |     sde_func: Callable,
29 |     y0: torch.Tensor,
30 |     tspan: Tuple,
31 |     N: int,
32 |     project_func: Callable = None,
33 |     T_grid: torch.Tensor = None,
34 | ) -> list:
35 |     """Integrate an Ito SDE with the Euler-Maruyama method.
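
    Each step applies the Euler-Maruyama update

        y_{k+1} = y_k + f(t_k, y_k) * dt_k + sqrt(|dt_k|) * gZ_k

    where `sde_func(t, y)` returns the drift `f` and the noise-scaled
    diffusion increment `gZ` (the diffusion coefficient already multiplied
    by a standard normal sample), and `dt_k = t_{k+1} - t_k` on the
    integration grid; this matches the update computed in the loop below.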
36 | 37 | args: 38 | sde_func (function): a function that takes in time and y and returns SDE drift and diffusion terms for the evolution of y 39 | y0 (torch.tensor): the initial value of y, e.g. a noised protein structure tensor 40 | tspan (tuple): a tuple (t_i, t_f) with t_i being the initial time and t_f being the final time for integration 41 | N (int): number of integration steps 42 | 43 | returns: 44 | y_trajectory (list): a list of snapshots of the evolution of y as the SDE is integrated 45 | 46 | """ 47 | 48 | with torch.no_grad(): 49 | # Integrate SDE 50 | y_trajectory = [y0] 51 | 52 | if T_grid is None: 53 | T_grid = torch.linspace(tspan[0], tspan[1], N + 1).to(y0.device) 54 | else: 55 | assert T_grid.shape[0] == N + 1 56 | 57 | y = y0 58 | for t0, t1 in tqdm( 59 | zip(T_grid[:-1], T_grid[1:]), total=N, desc="Integrating SDE" 60 | ): 61 | t = t0 62 | dT = t1 - t0 63 | 64 | f, gZ = sde_func(t, y) 65 | y = y + dT * f + dT.abs().sqrt() * gZ 66 | y = y if project_func is None else project_func(t, y) 67 | 68 | y_trajectory.append(y) 69 | return y_trajectory 70 | 71 | 72 | def sde_integrate_heun( 73 | sde_func: Callable, 74 | y0: torch.Tensor, 75 | tspan: Tuple, 76 | N: int, 77 | project_func: Callable = None, 78 | T_grid: torch.Tensor = None, 79 | ) -> list: 80 | """Integrate an Ito SDE with Heun's method. 81 | 82 | args: 83 | sde_func (function): a function that takes in time and y and returns SDE drift and diffusion terms for the evolution of y 84 | y0 (torch.tensor): the initial value of y, e.g. a noised protein structure tensor 85 | tspan (tuple): a tuple (t_i, t_f) with t_i being the initial time and t_f being the final time for integration 86 | N (int): number of integration steps 87 | 88 | returns: 89 | y_trajectory (list): a list of snapshots of the evolution of y as the SDE is integrated 90 | 91 | """ 92 | 93 | with torch.no_grad(): 94 | # Integrate SDE 95 | y_trajectory = [y0] 96 | dT = (tspan[1] - tspan[0]) / N 97 | 98 | if T_grid is None: 99 | T_grid = torch.linspace(tspan[0], tspan[1], N + 1).to(y0.device) 100 | else: 101 | assert T_grid.shape[0] == N + 1 102 | 103 | y = y0 104 | 105 | for t0, t1 in tqdm( 106 | zip(T_grid[:-1], T_grid[1:]), total=N, desc="Integrating SDE" 107 | ): 108 | # for i in tqdm(range(N)): 109 | # t = tspan[0] + i * dT 110 | t = t0 111 | dT = t1 - t0 112 | f, gZ = sde_func(t, y) 113 | y_pred = y + dT * f + dT.abs().sqrt() * gZ 114 | f_pred, gZ_pred = sde_func(t, y_pred) 115 | y_correct = y + dT * f_pred + dT.abs().sqrt() * gZ 116 | y = (y_pred + y_correct) / 2.0 117 | y = y if project_func is None else project_func(t, y) 118 | y_trajectory.append(y) 119 | 120 | return y_trajectory 121 | -------------------------------------------------------------------------------- /chroma/layers/structure/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Generate Biomedicines, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /chroma/layers/structure/hbonds.py: -------------------------------------------------------------------------------- 1 | # Copyright Generate Biomedicines, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Layers for annotating hydrogen bonds in protein structures. 16 | """ 17 | 18 | from typing import Tuple 19 | 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | 24 | from chroma.layers.graph import collect_neighbors 25 | from chroma.layers.structure import protein_graph 26 | from chroma.layers.structure.geometry import normed_vec 27 | 28 | 29 | class BackboneHBonds(nn.Module): 30 | """Compute hydrogen bonds from protein backbones. 31 | 32 | We use the simple electrostatic model for calling hydrogen 33 | bonds of DSSP, which is described at 34 | https://en.wikipedia.org/wiki/DSSP_(algorithm). After 35 | placing virtual hydrogens on all backbone nitrogens, 36 | we consider potential hydrogen bonds with carbonyl groups 37 | on the backbone with residue distance |i-j| > 2. The 38 | picture is: 39 | 40 | -0.20e +0.20e -0.42e +0.42e 41 | [N_i]-----[H_i] ::: [O_j]=====[C_j] 42 | 43 | Args: 44 | cutoff_energy (float, optional): Cutoff energy with 45 | default value -0.5 (DSSP). 46 | cutoff_distance (float, optional): Max distance 47 | between `N_i` and `O_j` with default value 3.6 angstroms. 48 | cutoff_gap (float, optional): Minimum tolerated residue 49 | distance, i.e. `|i-j| >= cutoff_gap`. 50 | Default value of 3. 51 | 52 | Inputs: 53 | X (Tensor): Backbone coordinates with shape 54 | `(num_batch, num_residues, num_atom_types, 3)`. 55 | C (LongTensor): Chain map tensor with shape `(num_batch, num_residues)`. 56 | edge_idx (LongTensor): Edge indices for neighbors with shape 57 | `(num_batch, num_residues, num_neighbors)`. 58 | mask_ij (Tensor): Edge mask with shape 59 | `(num_batch, num_nodes, num_neighbors)`. 60 | 61 | Outputs: 62 | hbonds (Tensor): Binary matrix annotating backbone hydrogen bonds 63 | with shape `(num_batch, num_nodes, num_neighbors)`. 64 | mask_hb_ij (Tensor): Hydrogen bond mask with shape 65 | `(num_batch, num_nodes, num_neighbors)`. 66 | H_i (Tensor): Virtual hydrogen coordinates with shape 67 | `(num_batch, num_nodes, 3)`. 68 | """ 69 | 70 | def __init__( 71 | self, 72 | cutoff_energy: float = -0.5, 73 | cutoff_distance: float = 3.6, 74 | cutoff_gap: float = 3, 75 | distance_eps: float = 1e-3, 76 | ) -> None: 77 | super(BackboneHBonds, self).__init__() 78 | self.cutoff_energy = cutoff_energy 79 | self.cutoff_distance = cutoff_distance 80 | self.cutoff_gap = cutoff_gap 81 | self._coefficient = 0.42 * 0.2 * 332 82 | self._eps = distance_eps 83 | 84 | # Lishan Yao et al. 
JACS 2008, NMR data 85 | self._length_NH = 1.015 86 | return 87 | 88 | def forward( 89 | self, 90 | X: torch.Tensor, 91 | C: torch.LongTensor, 92 | edge_idx: torch.LongTensor, 93 | mask_ij: torch.Tensor, 94 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 95 | num_batch, num_residues, _, _ = X.shape 96 | # Collect coordinates at i and j 97 | X_flat = X.reshape([num_batch, num_residues, -1]) 98 | X_j_flat = collect_neighbors(X_flat, edge_idx) 99 | X_j = X_j_flat.reshape([num_batch, num_residues, -1, 4, 3]) 100 | 101 | # Get amide [N-H] atoms at i by 102 | # by placing virtual H from C_{i-1}-N-Ca neg bisector 103 | X_prev = F.pad(X, [0, 0, 0, 0, 1, 0], mode="replicate")[:, :-1, :, :] 104 | C_prev_i = X_prev[:, :, 2, :] 105 | N_i = X[:, :, 0, :] 106 | Ca_i = X[:, :, 1, :] 107 | u_CprevN_i = normed_vec(N_i - C_prev_i) 108 | u_CaN_i = normed_vec(N_i - Ca_i) 109 | u_NH_i = normed_vec(u_CprevN_i + u_CaN_i) 110 | H_i = N_i + self._length_NH * u_NH_i 111 | # Add broadcasting dimensions 112 | N_i = N_i[:, :, None, :] 113 | H_i = H_i[:, :, None, :] 114 | 115 | # Get carbonyl [C=O] atoms at j 116 | O_j = X_j[:, :, :, 3, :] 117 | C_j = X_j[:, :, :, 2, :] 118 | 119 | _invD = ( 120 | lambda Xi, Xj: (Xi - Xj).square().sum(-1).add(self._eps).sqrt().reciprocal() 121 | ) 122 | U_ij = self._coefficient * ( 123 | _invD(N_i, O_j) - _invD(N_i, C_j) + _invD(H_i, C_j) - _invD(H_i, O_j) 124 | ) 125 | 126 | # Mask any bonds exceeding donor/acceptor cutoff distance 127 | D_nonhydrogen = (N_i - O_j).square().sum(-1).add(self._eps).sqrt() 128 | mask_ij_cutoff_D = (D_nonhydrogen < self.cutoff_distance).float() 129 | 130 | # Mask hbonds on same chain with |i-j| < gap_cutoff 131 | mask_ij_nonlocal = 1.0 - _locality_mask(C, edge_idx, cutoff=self.cutoff_gap) 132 | 133 | # Ignore N terminal hydrogen bonding because of ambiguous hydrogen placement 134 | C_prev = F.pad(C, [1, 0], "constant")[:, 1:] 135 | mask_i = ((C > 0) * (C == C_prev)).float() 136 | mask_j = collect_neighbors(C[..., None], edge_idx)[..., 0] 137 | mask_ij_internal = mask_i[..., None] * (mask_j > 0).float() 138 | 139 | mask_hb_ij = mask_ij * mask_ij_nonlocal * mask_ij_cutoff_D * mask_ij_internal 140 | 141 | # Call hydrogen bonds 142 | hbonds = mask_hb_ij * (U_ij < self.cutoff_energy).float() 143 | return hbonds, mask_hb_ij, H_i 144 | 145 | 146 | class LossBackboneHBonds(nn.Module): 147 | """Score hydrogen bond recovery from protein backbones. 148 | 149 | Args: 150 | See `BackboneHBonds`. 151 | 152 | Inputs: 153 | X (Tensor): Backbone coordinates to score with shape 154 | `(num_batch, num_residues, 4, 3)`. 155 | X_target (Tensor): Reference coordinates to compare to with shape 156 | `(num_batch, num_residues, 4, 3)`. 157 | C (LongTensor): Chain map tensor with shape `(num_batch, num_residues)`. 158 | 159 | Outputs: 160 | recovery_local (Tensor): Local hydrogen bond recovery with shape 161 | `(num_batch)`. 162 | recovery_nonlocal (Tensor): Nonlocal hydrogen bond recovery with shape 163 | `(num_batch)`. 
164 | error_co (Tensor): Absolute error in terms of contact order recovery 165 | """ 166 | 167 | def __init__( 168 | self, 169 | cutoff_local: float = 8, 170 | cutoff_energy: float = -0.5, 171 | cutoff_distance: float = 3.6, 172 | cutoff_gap: float = 3, 173 | distance_eps: float = 1e-3, 174 | num_neighbors: int = 30, 175 | ) -> None: 176 | super(LossBackboneHBonds, self).__init__() 177 | self.cutoff_local = cutoff_local 178 | self.cutoff_energy = cutoff_energy 179 | self.cutoff_distance = cutoff_distance 180 | self.cutoff_gap = cutoff_gap 181 | self._eps = 1e-3 182 | 183 | self.graph_builder = protein_graph.ProteinGraph(num_neighbors=num_neighbors) 184 | self.hbonds = BackboneHBonds( 185 | cutoff_energy=cutoff_energy, 186 | cutoff_distance=cutoff_distance, 187 | cutoff_gap=cutoff_gap, 188 | ) 189 | 190 | def forward( 191 | self, X: torch.Tensor, X_target: torch.Tensor, C: torch.LongTensor, 192 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 193 | # Build Graph 194 | edge_idx, mask_ij = self.graph_builder(X_target, C) 195 | hb_target, mask_hb, H_i = self.hbonds(X_target, C, edge_idx, mask_ij) 196 | hb_current, _, _ = self.hbonds(X, C, edge_idx, mask_ij) 197 | 198 | # Split into local and long range hbonds 199 | mask_local = _locality_mask(C, edge_idx, cutoff=self.cutoff_local) 200 | hb_target_local = mask_local * hb_target 201 | hb_target_nonlocal = (1 - mask_local) * hb_target 202 | 203 | # Compute per complex 204 | recovery_local = (hb_current * hb_target_local).sum([1, 2]) / ( 205 | hb_target_local.sum([1, 2]) + self._eps 206 | ) 207 | recovery_nonlocal = (hb_current * hb_target_nonlocal).sum([1, 2]) / ( 208 | hb_target_nonlocal.sum([1, 2]) + self._eps 209 | ) 210 | 211 | # Compute contact order 212 | co_target = _contact_order(hb_target, C, edge_idx) 213 | co_current = _contact_order(hb_current, C, edge_idx) 214 | 215 | error_co = (co_target - co_current).abs() 216 | return recovery_local, recovery_nonlocal, error_co 217 | 218 | 219 | def _ij_distance( 220 | C: torch.LongTensor, edge_idx: torch.LongTensor, 221 | ) -> Tuple[torch.Tensor, torch.Tensor]: 222 | C_i = C[..., None] 223 | C_j = collect_neighbors(C_i, edge_idx)[..., 0] 224 | ix = torch.arange(C.shape[1], device=C.device)[None, :, None].expand( 225 | C.shape[0], -1, -1 226 | ) 227 | jx = collect_neighbors(ix, edge_idx)[..., 0] 228 | dij = (jx - ix).abs() 229 | mask_same_chain = C_i.eq(C_j).float() 230 | return dij, mask_same_chain 231 | 232 | 233 | def _contact_order( 234 | contacts: torch.Tensor, 235 | C: torch.LongTensor, 236 | edge_idx: torch.LongTensor, 237 | eps: float = 1e-3, 238 | ) -> torch.Tensor: 239 | """Compute contact order""" 240 | dij, mask_same_chain = _ij_distance(C, edge_idx) 241 | mask_ij = mask_same_chain * contacts 242 | CO = (mask_ij * dij).sum([1, 2]) / (mask_ij + eps).sum([1, 2]) 243 | return CO 244 | 245 | 246 | def _locality_mask( 247 | C: torch.LongTensor, edge_idx: torch.LongTensor, cutoff: float, 248 | ) -> torch.Tensor: 249 | dij, mask_same_chain = _ij_distance(C, edge_idx) 250 | mask_ij_local = ((dij < cutoff) * mask_same_chain).float() 251 | return mask_ij_local 252 | -------------------------------------------------------------------------------- /chroma/layers/structure/optimal_transport.py: -------------------------------------------------------------------------------- 1 | # Copyright Generate Biomedicines, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Layers for comparing and mapping point clouds via optimal transport.
16 | 
17 | This module contains minimalist implementations of basic optimal transport
18 | routines which can be used to, for example, measure similarities between
19 | point clouds of different shapes by computing optimal mappings between them.
20 | For more information see the excellent book by Peyre,
21 | https://arxiv.org/pdf/1803.00567.pdf
22 | """
23 | 
24 | import numpy as np
25 | import torch
26 | 
27 | 
28 | def optimize_couplings_sinkhorn(C, scale=1.0, iterations=10):
29 |     """Solve entropy-regularized optimal transport via Sinkhorn iteration.
30 | 
31 |     This method uses the log-domain for numerical stability.
32 | 
33 |     Args:
34 |         C (Tensor): Batch of cost matrices with shape `(B, I, J)`.
35 |         scale (float, optional): Entropy regularization parameter for
36 |             rescaling the cost matrix.
37 |         iterations (int, optional): Number of Sinkhorn iterations.
38 | 
39 |     Returns:
40 |         T (Tensor): Couplings map with shape `(B, I, J)`.
41 |     """
42 |     log_T = -C * scale
43 | 
44 |     # Initialize normalizers
45 |     B, I, J = log_T.shape
46 |     log_u = torch.zeros((B, I), device=log_T.device)
47 |     log_v = torch.zeros((B, J), device=log_T.device)
48 |     log_a = log_u - np.log(I)
49 |     log_b = log_v - np.log(J)
50 | 
51 |     # Iterate normalizers
52 |     for j in range(iterations):
53 |         log_u = log_a - torch.logsumexp(log_T + log_v.unsqueeze(1), 2)
54 |         log_v = log_b - torch.logsumexp(log_T + log_u.unsqueeze(2), 1)
55 |     log_T = log_T + log_v.unsqueeze(1) + log_u.unsqueeze(2)
56 |     T = torch.exp(log_T)
57 |     return T
58 | 
59 | 
60 | def optimize_couplings_gw(
61 |     D_a, D_b, scale=200.0, iterations_outer=30, iterations_inner=10,
62 | ):
63 |     """Gromov-Wasserstein Optimal Transport.
64 |     https://arxiv.org/pdf/1905.07645.pdf
65 | 
66 |     Args:
67 |         D_a (Tensor): Distance matrix describing objects in set `a` with shape `(B, I, I)`.
68 |         D_b (Tensor): Distance matrix describing objects in set `b` with shape `(B, J, J)`.
69 |         scale (float, optional): Entropy regularization parameter for
70 |             rescaling the cost matrix.
71 |         iterations_outer (int, optional): Number of outer GW iterations.
72 |         iterations_inner (int, optional): Number of inner Sinkhorn iterations.
73 | 
74 |     Returns:
75 |         T (Tensor): Couplings map with shape `(B, I, J)`.
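        D_gw (Tensor): Estimated Gromov-Wasserstein distances with shape `(B,)`.

    Example:
        A minimal shape-level sketch with random point clouds; the sizes and
        inputs here are illustrative only, not part of the original module:

            import torch
            from chroma.layers.structure.optimal_transport import optimize_couplings_gw

            X_a, X_b = torch.randn(1, 8, 3), torch.randn(1, 12, 3)
            D_a = torch.cdist(X_a, X_a)  # (1, 8, 8) intra-set distances
            D_b = torch.cdist(X_b, X_b)  # (1, 12, 12)
            T, D_gw = optimize_couplings_gw(D_a, D_b)
            # T: (1, 8, 12); each row of T sums approximately to the uniform
            # marginal 1 / 8, and D_gw: (1,) is the estimated GW distance.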
76 | 77 | """ 78 | 79 | # Gromov-Wasserstein Distance 80 | N_a = D_a.shape[1] 81 | N_b = D_b.shape[1] 82 | p_a = torch.ones_like(D_a[:, :, 0]) / N_a 83 | p_b = torch.ones_like(D_b[:, :, 0]) / N_b 84 | C_ab = ( 85 | torch.einsum("bij,bj->bi", D_a ** 2, p_a)[:, :, None] 86 | + torch.einsum("bij,bj->bi", D_b ** 2, p_b)[:, None, :] 87 | ) 88 | T_gw = torch.einsum("bi,bj->bij", p_a, p_b) 89 | for i in range(iterations_outer): 90 | cost = C_ab - 2.0 * torch.einsum("bik,bkl,blj->bij", D_a, T_gw, D_b) 91 | T_gw = optimize_couplings_sinkhorn(cost, scale, iterations=iterations_inner) 92 | 93 | # Compute cost 94 | cost = C_ab - 2.0 * torch.einsum("bik,bkl,blj->bij", D_a, T_gw, D_b) 95 | D_gw = (T_gw * cost).sum([-1, -2]).abs().sqrt() 96 | return T_gw, D_gw 97 | -------------------------------------------------------------------------------- /chroma/layers/structure/protein_graph_allatom.py: -------------------------------------------------------------------------------- 1 | # Copyright Generate Biomedicines, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Layers for building graph representations of protein structure, all-atom. 16 | 17 | This module contains pytorch layers for representing protein structure as a 18 | graph with node and edge features based on geometric information. The graph 19 | features are differentiable with respect to input coordinates and can be used 20 | for building protein scoring functions and optimizing protein geometries 21 | natively in pytorch. 22 | """ 23 | 24 | 25 | import numpy as np 26 | import torch 27 | import torch.nn as nn 28 | 29 | from chroma.layers import graph 30 | from chroma.layers.structure import geometry, sidechain 31 | 32 | 33 | class NodeChiRBF(nn.Module): 34 | """Layers for featurizing chi angles with a smooth binning 35 | 36 | Args: 37 | num_chi_bins (int): Number of bins for discretizing chi angles. 38 | num_chi (int): Number of chi angles. 39 | dim_out (int): Number of output feature dimensions. 40 | bin_scale (float, optional): Scaling parameter that sets bin smoothing. 41 | 42 | Input: 43 | chi (Tensor): Chi angles with shape `(num_batch, num_residues, num_chi)`. 44 | 45 | Output: 46 | h_chi (Tensor): Chi angle features with shape 47 | `(num_batch, num_residues, num_chi * num_chi_bins)`. 
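
    Example:
        A shape-only sketch (dimensions illustrative; chi angles in radians):

            import torch
            layer = NodeChiRBF(dim_out=128, num_chi=4, num_chi_bins=16)
            chi = 2.0 * 3.14159 * torch.rand(2, 50, 4)  # (batch, residues, chi)
            h_chi = layer(chi)  # (2, 50, 128)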
48 |     """
49 | 
50 |     def __init__(self, dim_out, num_chi, num_chi_bins, bin_scale=2.0):
51 |         super(NodeChiRBF, self).__init__()
52 |         self.dim_out = dim_out
53 |         self.num_chi = num_chi
54 |         self.num_chi_bins = num_chi_bins
55 |         self.bin_scale = bin_scale
56 | 
57 |         self.embed = nn.Linear(self.num_chi * self.num_chi_bins, dim_out)
58 | 
59 |     def _featurize(self, chi, mask_chi=None):
60 |         num_batch, num_residues, _ = chi.shape
61 | 
62 |         chi_bin_center = (
63 |             torch.arange(0, self.num_chi_bins, device=chi.device)
64 |             * 2.0
65 |             * np.pi
66 |             / self.num_chi_bins
67 |         )
68 |         chi_bin_center = chi_bin_center.reshape([1, 1, 1, -1])
69 | 
70 |         # Set smoothing length scale based on ratio between adjacent bin centers
71 |         # bin_i / bin_i+1 = 1 / scale
72 |         delta_adjacent = np.cos(0.0) - np.cos(2.0 * np.pi / self.num_chi_bins)
73 |         cosine = torch.cos(chi.unsqueeze(-1) - chi_bin_center)
74 |         chi_features = torch.exp((cosine - 1.0) * self.bin_scale / delta_adjacent)
75 |         if mask_chi is not None:
76 |             chi_features = mask_chi.unsqueeze(-1) * chi_features
77 |         chi_features = chi_features.reshape(
78 |             [num_batch, num_residues, self.num_chi * self.num_chi_bins]
79 |         )
80 |         return chi_features
81 | 
82 |     def forward(self, chi, mask_chi=None):
83 |         chi_features = self._featurize(chi, mask_chi=mask_chi)
84 |         h_chi = self.embed(chi_features)
85 |         return h_chi
86 | 
87 | 
88 | class EdgeSidechainsDirect(nn.Module):
89 |     """Layers for direct encoding of side chain geometries.
90 | 
91 |     Args:
92 |         dim_out (int): Number of output hidden dimensions.
93 |         length_scale (float, optional): Length scale, in Angstroms, used to
94 |             scale and envelope the edge encodings.
95 | 
96 |     Input:
97 |         X (Tensor): All atom coordinates with shape
98 |             `(num_batch, num_residues, 14, 3)`.
99 |         C (LongTensor): Chain map with shape `(num_batch, num_residues)`.
100 |         S (LongTensor): Sequence tensor with shape
101 |             `(num_batch, num_residues)`.
102 |         edge_idx (Tensor): Graph indices for expansion with shape
103 |             `(num_batch, num_residues_out, num_neighbors)`. The dimension
104 |             of output variables `num_residues_out` must either equal
105 |             `num_residues` or 1, the latter of which can be useful for sequential
106 |             decoding.
107 | 
108 |     Output:
109 |         h (Tensor): Features with shape
110 |             `(num_batch, num_residues_out, num_neighbors, num_hidden)`.
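
    Example:
        A shape-only sketch with random inputs; a single chain of length 50,
        with sequence index 0 corresponding to alanine in `constants.AA20`
        (the sizes here are illustrative, not part of the original module):

            import torch
            layer = EdgeSidechainsDirect(dim_out=128)
            X = torch.randn(2, 50, 14, 3)                 # all-atom coordinates
            C = torch.ones(2, 50, dtype=torch.long)       # chain map, all observed
            S = torch.zeros(2, 50, dtype=torch.long)      # poly-alanine sequence
            edge_idx = torch.randint(0, 50, (2, 50, 16))  # neighbor indices
            h = layer(X, C, S, edge_idx)                  # (2, 50, 16, 128)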
111 | """ 112 | 113 | def __init__( 114 | self, 115 | dim_out, 116 | length_scale=7.5, 117 | distance_eps=0.1, 118 | num_fourier=30, 119 | fourier_order=2, 120 | basis_type="rff", 121 | ): 122 | super(EdgeSidechainsDirect, self).__init__() 123 | self.dim_out = dim_out 124 | self.length_scale = length_scale 125 | self.distance_eps = distance_eps 126 | 127 | # self.embed = nn.Linear(14 * 3 , dim_out) 128 | self.num_fourier = num_fourier 129 | self.rff = torch.nn.Parameter( 130 | 2.0 * np.pi / self.length_scale * torch.randn((3, self.num_fourier)) 131 | ) 132 | self.basis_type = basis_type 133 | if self.basis_type == "rff": 134 | self.embed = nn.Linear(14 * self.num_fourier * 2, dim_out) 135 | elif self.basis_type == "spherical": 136 | self.fourier_order = fourier_order 137 | self.embed = nn.Linear(14 * (self.fourier_order * 2) ** 3, dim_out) 138 | 139 | def _local_coordinates(self, X, C, S, edge_idx): 140 | num_batch, num_residues, num_neighbors = edge_idx.shape 141 | 142 | # Mask and transform into features 143 | mask_atoms = sidechain.atom_mask(C, S) 144 | mask_atoms_j = graph.collect_neighbors(mask_atoms, edge_idx) 145 | mask_i = (C > 0).float().reshape([num_batch, num_residues, 1, 1]) 146 | mask_atoms_ij = mask_i * mask_atoms_j 147 | 148 | # Build conditioning mask 149 | R_i, CA = geometry.frames_from_backbone(X[:, :, :4, :]) 150 | 151 | # Transform neighbor X coordinates into local frames 152 | X_flat = X.reshape([num_batch, num_residues, -1]) 153 | X_j_flat = graph.collect_neighbors(X_flat, edge_idx) 154 | X_j = X_j_flat.reshape([num_batch, num_residues, num_neighbors, 14, 3]) 155 | dX_ij = X_j - CA.reshape([num_batch, num_residues, 1, 1, 3]) 156 | U_ij = torch.einsum("niab,nijma->nijmb", R_i, dX_ij) 157 | return U_ij, mask_atoms_ij 158 | 159 | def _local_coordinates_t(self, t, X, C, S, edge_idx_t): 160 | num_batch, _, num_neighbors = edge_idx_t.shape 161 | num_residues = X.shape[1] 162 | 163 | # Make a mask that 164 | C_i = C[:, t].unsqueeze(1) 165 | # S_i = S[:,t].unsqueeze(1) 166 | # mask_atoms_i = sidechain.atom_mask(C_i, S_i) 167 | C_j = graph.collect_neighbors(C.unsqueeze(-1), edge_idx_t).reshape( 168 | [num_batch, num_neighbors] 169 | ) 170 | S_j = graph.collect_neighbors(S.unsqueeze(-1), edge_idx_t).reshape( 171 | [num_batch, num_neighbors] 172 | ) 173 | mask_atoms_j = sidechain.atom_mask(C_j, S_j).unsqueeze(1) 174 | mask_i = (C_i > 0).float().reshape([num_batch, 1, 1, 1]) 175 | mask_atoms_ij = mask_i * mask_atoms_j 176 | 177 | # Build conditioning mask 178 | X_bb_i = X[:, t, :4, :].unsqueeze(1) 179 | R_i, CA = geometry.frames_from_backbone(X_bb_i) 180 | 181 | # Transform neighbor X coordinates into local frames 182 | X_flat = X.reshape([num_batch, num_residues, -1]) 183 | X_j_flat = graph.collect_neighbors(X_flat, edge_idx_t) 184 | X_j = X_j_flat.reshape([num_batch, 1, num_neighbors, 14, 3]) 185 | dX_ij = X_j - CA.reshape([num_batch, 1, 1, 1, 3]) 186 | U_ij = torch.einsum("niab,nijma->nijmb", R_i, dX_ij) 187 | return U_ij, mask_atoms_ij 188 | 189 | def _fourier_expand(self, h, order): 190 | k = torch.arange(order, device=h.device) 191 | k = k.reshape([1 for i in h.shape] + [-1]) 192 | return torch.cat( 193 | [torch.sin(h.unsqueeze(-1) * (k + 1)), torch.cos(h.unsqueeze(-1) * k)], 194 | dim=-1, 195 | ) 196 | 197 | def _featurize(self, U_ij, mask_atoms_ij): 198 | if self.basis_type == "rff": 199 | # Random fourier features 200 | U_ij = mask_atoms_ij.unsqueeze(-1) * U_ij 201 | U_ff = torch.einsum("nijax,xy->nijay", U_ij, self.rff) 202 | U_ff = torch.concat([torch.cos(U_ff), 
203 | 
204 | # Gaussian RBF envelope
205 | D_ij = torch.sqrt((U_ij ** 2).sum(-1) + self.distance_eps)
206 | magnitude = torch.exp(-D_ij * D_ij / (2 * self.length_scale ** 2))
207 | U_ff = magnitude.unsqueeze(-1) * U_ff
208 | 
209 | U_ff = U_ff.reshape(list(D_ij.shape)[:3] + [-1])
210 | h = mask_atoms_ij[:, :, :, 0].unsqueeze(-1) * self.embed(U_ff)
211 | 
212 | elif self.basis_type == "spherical":
213 | # Convert to spherical coordinates
214 | r_ij = torch.sqrt((U_ij ** 2).sum(-1) + self.distance_eps)
215 | r_ij_scale = r_ij * 2.0 * np.pi / self.length_scale
216 | x, y, z = U_ij.unbind(-1)
217 | theta_ij = torch.acos(z / r_ij)
218 | phi_ij = torch.atan2(y, x)
219 | 
220 | # Build Fourier expansions of each coordinate
221 | r_ff, theta_ff, phi_ff = [
222 | self._fourier_expand(h, self.fourier_order)
223 | for h in [r_ij_scale, theta_ij, phi_ij]
224 | ]
225 | # Radial envelope function
226 | r_envelope = mask_atoms_ij * torch.exp(
227 | -r_ij * r_ij / (2 * self.length_scale ** 2)
228 | )
229 | 
230 | # Tensor outer product
231 | bf_ij = torch.einsum(
232 | "bika,bikar,bikat,bikap->bikartp", r_envelope, r_ff, theta_ff, phi_ff
233 | ).reshape(list(r_ij.shape)[:3] + [-1])
234 | 
235 | h = mask_atoms_ij[:, :, :, 0].unsqueeze(-1) * self.embed(bf_ij)
236 | 
237 | return h
238 | 
239 | def forward(self, X, C, S, edge_idx):
240 | U_ij, mask_atoms_ij = self._local_coordinates(X, C, S, edge_idx)
241 | h = self._featurize(U_ij, mask_atoms_ij)
242 | return h
243 | 
244 | def step(self, t, X, C, S, edge_idx_t):
245 | U_ij, mask_atoms_ij = self._local_coordinates_t(t, X, C, S, edge_idx_t)
246 | h = self._featurize(U_ij, mask_atoms_ij)
247 | return h
248 | 
-------------------------------------------------------------------------------- /chroma/models/__init__.py: --------------------------------------------------------------------------------
1 | # Copyright Generate Biomedicines, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """
16 | This package contains complete PyTorch models, that is, model classes that can
17 | be trained on data. See also :mod:`chroma.layers`.
18 | """
19 | 
20 | from chroma.models.chroma import Chroma
21 | 
-------------------------------------------------------------------------------- /chroma/models/graph_energy.py: --------------------------------------------------------------------------------
1 | # Copyright Generate Biomedicines, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Models for building energy functions for protein sequence and structure.
16 | 
17 | This module contains pytorch models for building energy functions that score
18 | protein sequence and structure and that can be used for partial and full
19 | protein de novo design.
20 | """
21 | 
22 | 
23 | import torch.nn as nn
24 | 
25 | from chroma.layers import graph
26 | 
27 | 
28 | class GraphHarmonicFeatures(nn.Module):
29 | """Layer for quadratic node and edge features.
30 | 
31 | Args:
32 | dim_nodes (int): Hidden dimension of node tensor.
33 | dim_edges (int): Hidden dimension of edge tensor.
34 | node_mlp_layers (int): Number of hidden layers for node update.
35 | node_mlp_dim (int): Node update function, hidden dimension.
36 | edge_mlp_layers (int): Edge update function, number of hidden layers.
37 | edge_mlp_dim (int): Edge update function, hidden dimension.
38 | mlp_activation (str): MLP nonlinearity, either `'relu'` (rectified
39 | linear unit) or `'softplus'` (softplus).
40 | dropout (float, optional): Dropout rate. Default is 0.
41 | 
42 | Inputs:
43 | node_h (Tensor): Node embeddings with shape
44 | `(num_batch, num_nodes, dim_nodes)`.
45 | node_feature (Tensor): Node features with shape
46 | `(num_batch, num_nodes, dim_nodes)`.
47 | edge_h (Tensor): Edge embeddings with shape
48 | `(num_batch, num_nodes, num_neighbors, dim_edges)`.
49 | edge_feature (Tensor): Edge features with shape
50 | `(num_batch, num_nodes, num_neighbors, dim_edges)`.
51 | mask_i (Tensor): Node mask with shape `(num_batch, num_nodes)`.
52 | mask_ij (Tensor): Edge mask with shape
53 | `(num_batch, num_nodes, num_neighbors)`.
54 | 
55 | Outputs:
56 | node_h (Tensor): Updated node embeddings with shape
57 | `(num_batch, num_nodes, dim_nodes)`.
58 | edge_h (Tensor): Updated edge embeddings with shape
59 | `(num_batch, num_nodes, num_neighbors, dim_edges)`.
60 | 
61 | 
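Example (a minimal usage sketch; dimensions are illustrative):
>>> import torch
>>> layer = GraphHarmonicFeatures(dim_nodes=128, dim_edges=64, node_mlp_layers=1, node_mlp_dim=128, edge_mlp_layers=1, edge_mlp_dim=64)
>>> node_h = torch.randn(1, 10, 128)
>>> node_feature = torch.randn(1, 10, 128)
>>> edge_h = torch.randn(1, 10, 8, 64)
>>> edge_feature = torch.randn(1, 10, 8, 64)
>>> mask_i = torch.ones(1, 10)
>>> mask_ij = torch.ones(1, 10, 8)
>>> node_h, edge_h = layer(node_h, node_feature, edge_h, edge_feature, mask_i, mask_ij)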
62 | """ 63 | 64 | def __init__( 65 | self, 66 | dim_nodes, 67 | dim_edges, 68 | node_mlp_layers, 69 | node_mlp_dim, 70 | edge_mlp_layers, 71 | edge_mlp_dim, 72 | mlp_activation="softplus", 73 | dropout=0.0, 74 | ): 75 | super(GraphHarmonicFeatures, self).__init__() 76 | self.dim_nodes = dim_nodes 77 | self.dim_edges = dim_edges 78 | self.node_mlp = graph.MLP( 79 | dim_in=dim_nodes, 80 | dim_out=2 * dim_nodes, 81 | num_layers_hidden=node_mlp_layers, 82 | dim_hidden=node_mlp_dim, 83 | activation=mlp_activation, 84 | dropout=dropout, 85 | ) 86 | self.edge_mlp = graph.MLP( 87 | dim_in=dim_edges, 88 | dim_out=2 * dim_edges, 89 | num_layers_hidden=edge_mlp_layers, 90 | dim_hidden=edge_mlp_dim, 91 | activation=mlp_activation, 92 | dropout=dropout, 93 | ) 94 | self.node_out = nn.Linear(dim_nodes, dim_nodes) 95 | self.edge_out = nn.Linear(dim_edges, dim_edges) 96 | 97 | def forward(self, node_h, node_feature, edge_h, edge_feature, mask_i, mask_ij): 98 | node_h_pred = self.node_mlp(node_h) 99 | node_mu = node_h_pred[:, :, : self.dim_nodes] 100 | node_coeff = node_h_pred[:, :, self.dim_nodes :] 101 | node_errors = node_coeff * (node_feature - node_mu) ** 2 102 | node_h = node_h + self.node_out(node_errors) 103 | node_h = mask_i.unsqueeze(-1) * node_h 104 | 105 | edge_h_pred = self.edge_mlp(edge_h) 106 | edge_mu = edge_h_pred[:, :, :, : self.dim_edges] 107 | edge_coeff = edge_h_pred[:, :, :, self.dim_edges :] 108 | edge_errors = edge_coeff * (edge_feature - edge_mu) ** 2 109 | edge_h = edge_h + self.edge_out(edge_errors) 110 | edge_h = mask_ij.unsqueeze(-1) * edge_h 111 | return node_h, edge_h 112 | -------------------------------------------------------------------------------- /chroma/utility/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Generate Biomedicines, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | This package is home to miscellaneous utilities that don't fit elsewhere. 17 | Modules in this package should aim to minimize their dependencies on other modules. 18 | """ 19 | -------------------------------------------------------------------------------- /chroma/utility/api.py: -------------------------------------------------------------------------------- 1 | # Copyright Generate Biomedicines, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 
15 | import hashlib
16 | import json
17 | import os
18 | import tempfile
19 | 
20 | import requests
21 | 
22 | import chroma
23 | 
24 | ROOT_DIR = os.path.dirname(os.path.dirname(chroma.__file__))
25 | 
26 | 
27 | def register_key(key: str, key_directory=ROOT_DIR) -> None:
28 | """
29 | Registers the provided key by saving it to a JSON file.
30 | 
31 | Args:
32 | key (str): The access token to be registered.
33 | key_directory (str, optional): The directory where the access key is registered.
34 | 
35 | Returns:
36 | None
37 | """
38 | config_path = os.path.join(key_directory, "config.json")
39 | with open(config_path, "w") as f:
40 | json.dump({"access_token": key}, f)
41 | 
42 | 
43 | def read_key(key_directory=ROOT_DIR) -> str:
44 | """
45 | Reads the registered key from the JSON file. If no key has been registered,
46 | it informs the user and raises a FileNotFoundError.
47 | 
48 | Args:
49 | key_directory (str, optional): The directory where the access key is registered.
50 | 
51 | Returns:
52 | str: The registered access token.
53 | 
54 | Raises:
55 | FileNotFoundError: If no key has been registered.
56 | """
57 | config_path = os.path.join(key_directory, "config.json")
58 | 
59 | if not os.path.exists(config_path):
60 | print("No access token has been registered.")
61 | print(
62 | "To obtain an access token, go to https://chroma-weights.generatebiomedicines.com/ and agree to the license."
63 | )
64 | raise FileNotFoundError("No token has been registered.")
65 | 
66 | with open(config_path, "r") as f:
67 | config = json.load(f)
68 | 
69 | return config["access_token"]
70 | 
71 | 
72 | def download_from_generate(
73 | base_url: str,
74 | weights_name: str,
75 | force: bool = False,
76 | exist_ok: bool = False,
77 | key_directory=ROOT_DIR,
78 | ) -> str:
79 | """
80 | Downloads data from the provided URL using the registered access token, with caching behavior based on the force and exist_ok flags.
81 | 
82 | Args:
83 | base_url (str): The base URL from which data should be fetched.
84 | weights_name (str): Name of the weights file to download.
85 | force (bool): If True, always fetches data from the URL regardless of cache existence.
86 | exist_ok (bool): If True and cache exists (and force is False), uses the cached data.
87 | key_directory (str, optional): The directory where the access key is registered.
88 | 
89 | Returns:
90 | str: Path to the downloaded (or cached) file.
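Example (a sketch; the URL path and weights filename are hypothetical
placeholders, and a valid access token must already be registered):
>>> path = download_from_generate(
...     "https://chroma-weights.generatebiomedicines.com/downloads",
...     "chroma_backbone_v1.0.pt",
...     exist_ok=True,
... )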
91 | """ 92 | 93 | # Create a hash of the URL + weight name to determine the path for the cached/temporary file 94 | url_hash = hashlib.md5((base_url + weights_name).encode()).hexdigest() 95 | temp_dir = os.path.join(tempfile.gettempdir(), "chroma_weights", url_hash) 96 | destination = os.path.join(temp_dir, "weights.pt") 97 | 98 | # Ensure the directory exists 99 | os.makedirs(temp_dir, exist_ok=True) 100 | 101 | # Check if cache exists 102 | cache_exists = os.path.exists(destination) 103 | 104 | # Determine if we should use the cache or not 105 | use_cache = cache_exists and exist_ok and not force 106 | 107 | if use_cache: 108 | print(f"Using cached data from {destination}") 109 | return destination 110 | 111 | # If not using cache, proceed with download 112 | 113 | # Define the query parameters 114 | params = {"token": read_key(key_directory), "weights": weights_name} 115 | 116 | # Perform the GET request with the token as a query parameter 117 | response = requests.get(base_url, params=params) 118 | response.raise_for_status() # Raise an error for HTTP errors 119 | 120 | with open(destination, "wb") as file: 121 | file.write(response.content) 122 | 123 | print(f"Data saved to {destination}") 124 | return destination 125 | -------------------------------------------------------------------------------- /chroma/utility/fetchdb.py: -------------------------------------------------------------------------------- 1 | # Copyright Generate Biomedicines, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Functions to retrieve information from external databases via their API; Uniprot and RCSB are the primary databases included here. 16 | 17 | """ 18 | 19 | 20 | import requests 21 | 22 | 23 | def _download_file(url, out_file): 24 | try: 25 | with requests.get(url, stream=True) as r: 26 | r.raise_for_status() 27 | with open(out_file, "wb") as f: 28 | for chunk in r.iter_content(chunk_size=8192): 29 | if chunk: 30 | f.write(chunk) 31 | return True 32 | except requests.HTTPError: 33 | return False 34 | 35 | 36 | def RCSB_file_download(pdb_id, ext, local_filename): 37 | """Downloads a file from the RCSB files section. 38 | 39 | Args: 40 | pdb_id (str) : 4-letter pdb id, case-insensitive 41 | ext (str) : Extension of file. E.g. ".pdb" or ".pdb1" 42 | local_filename (str) : Name for downloaded file. 43 | Returns: 44 | None 45 | """ 46 | url = f"https://files.rcsb.org/view/{pdb_id.upper()}{ext}" 47 | return _download_file(url, local_filename) 48 | -------------------------------------------------------------------------------- /chroma/utility/model.py: -------------------------------------------------------------------------------- 1 | # Copyright Generate Biomedicines, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Utilities to save and load models with metadata. 17 | """ 18 | 19 | import os 20 | import os.path as osp 21 | import tempfile 22 | from pathlib import Path 23 | from urllib.parse import parse_qs, urlparse 24 | from uuid import uuid4 25 | 26 | import torch 27 | 28 | import chroma.utility.api as api 29 | from chroma.constants.named_models import NAMED_MODELS 30 | 31 | 32 | def save_model(model, weight_file, metadata=None): 33 | """Save model, including optional metadata. 34 | 35 | Args: 36 | model (nn.Module): The model to save. Details about the model needed 37 | for initialization, such as layer sizes, should be in model.kwargs. 38 | weight_file (str): The destination path for saving model weights. 39 | metadata (dict): A dictionary of additional metadata to add to the model 40 | weights. For example, when saving models during training it can be 41 | useful to store `args` representing the CLI args, the date and time 42 | of training, etc. 43 | """ 44 | save_dict = {"init_kwargs": model.kwargs, "model_state_dict": model.state_dict()} 45 | if metadata is not None: 46 | save_dict.update(metadata) 47 | local_path = str( 48 | Path(tempfile.gettempdir(), str(uuid4())[:8]) 49 | if weight_file.startswith("s3:") 50 | else weight_file 51 | ) 52 | torch.save(save_dict, local_path) 53 | if weight_file.startswith("s3:"): 54 | raise NotImplementedError("Uploading to an s3 link not supported.") 55 | 56 | 57 | def load_model( 58 | weights, 59 | model_class, 60 | device="cpu", 61 | strict=False, 62 | strict_unexpected=True, 63 | verbose=True, 64 | ): 65 | """Load model saved with save_model. 66 | 67 | Args: 68 | weights (str): The destination path of the model weights to load. 69 | Compatible with files saved by `save_model`. 70 | model_class: Name of model class. 71 | device (str, optional): Pytorch device specification, e.g. `'cuda'` for 72 | GPU. Default is `'cpu'`. 73 | strict (bool): Whether to require that the keys match between the 74 | input file weights and the model created from the parameters stored 75 | in the model kwargs. 76 | strict_unexpected (bool): Whether to require that there are no 77 | unexpected keys when loading model weights, as distinct from the 78 | strict option which doesn't allow for missing keys either. By 79 | default, we use this option rather than strict for ease of 80 | development when adding model features. 81 | verbose (bool, optional): Show outputs from download and loading. Default True. 82 | 83 | Returns: 84 | model (nn.Module): Torch model with loaded weights. 
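Example (a sketch; assumes the "public" named weights exist for this
model class and that an access token has been registered):
>>> from chroma.models.graph_backbone import GraphBackbone
>>> model = load_model("named:public", GraphBackbone, device="cpu")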
85 | """ 86 | 87 | # Process weights path 88 | if str(weights).startswith("named:"): 89 | weights = weights.split("named:")[1] 90 | if weights not in NAMED_MODELS[model_class.__name__]: 91 | raise Exception(f"Unknown {model_class.__name__} model name: {weights},") 92 | weights = NAMED_MODELS[model_class.__name__][weights]["s3_uri"] 93 | 94 | # resolve s3 paths 95 | if str(weights).startswith("s3:"): 96 | raise NotImplementedError("Loading Models from an S3 link not supported.") 97 | 98 | # download public models from generate 99 | if str(weights).startswith("https:"): 100 | # Decompose into arguments 101 | parsed_url = urlparse(weights) 102 | base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path}" 103 | model_name = parse_qs(parsed_url.query).get("weights", [None])[0] 104 | weights = api.download_from_generate( 105 | base_url, model_name, force=False, exist_ok=True 106 | ) 107 | 108 | # load model weights 109 | params = torch.load(weights, map_location="cpu") 110 | model = model_class(**params["init_kwargs"]).to(device) 111 | missing_keys, unexpected_keys = model.load_state_dict( 112 | params["model_state_dict"], strict=strict 113 | ) 114 | if strict_unexpected and len(unexpected_keys) > 0: 115 | raise Exception( 116 | f"Error loading model from checkpoint file: {weights} contains {len(unexpected_keys)} unexpected keys: {unexpected_keys}" 117 | ) 118 | return model 119 | -------------------------------------------------------------------------------- /chroma/utility/ngl.py: -------------------------------------------------------------------------------- 1 | # Copyright Generate Biomedicines, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Utilities for rendering protein structures in Jupyter notebooks. 17 | 18 | This provides convenience functions for rendering our common structure 19 | datatypes, such as `mst.System` and `XCS` tensors, with nglview. 20 | """ 21 | 22 | import tempfile 23 | import uuid 24 | 25 | import nglview as nv 26 | 27 | 28 | class SystemTrajectory(nv.base_adaptor.Trajectory, nv.base_adaptor.Structure): 29 | """MST multi-state System object adaptor, by analogy to other NGLView adaptor 30 | classes (e.g., MDTrajTrajectory or PyTrajTrajectory in nglview.adaptor). 
31 | Example
32 | -------
33 | >>> import nglview as nv
34 | >>> from chroma.data.protein import Protein
35 | >>> protein_trajectory = Protein("multi-state.cif")
36 | >>> t = SystemTrajectory(protein_trajectory)
37 | >>> w = nv.NGLWidget(t)
38 | >>> w
39 | """
40 | 
41 | def __init__(self, protein):
42 | self.protein = protein
43 | self.ext = "pdb"
44 | self.params = {}
45 | self.id = str(uuid.uuid4())
46 | 
47 | def get_coordinates(self, index):
48 | self.protein.sys.swap_model(index)
49 | X, _, _ = self.protein.sys.to_XCS()
50 | self.protein.sys.swap_model(index)
51 | return X.view(-1, 3).numpy()
52 | 
53 | @property
54 | def n_frames(self):
55 | return self.protein.sys.num_models()
56 | 
57 | def get_structure_string(self):
58 | return self.protein.sys.to_PDB_string()
59 | 
60 | 
61 | def view_gsystem(system, **kwargs):
62 | """Return an NGL Viewer Widget for a Generate System.
63 | 
64 | Args:
65 | system (System): Structure to view.
66 | 
67 | Returns:
68 | view: NGL Viewer widget instance. In a Jupyter notebook,
69 | returning this object to the notebook will trigger display
70 | of the widget.
71 | """
72 | temp = tempfile.NamedTemporaryFile(suffix=".pdb")
73 | filename = temp.name
74 | system.to_PDB(filename)
75 | view = nv.show_file(filename)
76 | view.clear_representations()
77 | view.add_representation("cartoon")
78 | view.add_representation("licorice", selection="(sidechain or .CA) and not hydrogen")
79 | view.add_representation("contact")
80 | view.center()
81 | return view
82 | 
-------------------------------------------------------------------------------- /chroma/utility/polyseq.py: --------------------------------------------------------------------------------
1 | # Copyright Generate Biomedicines, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Standard residue names for polymers of different types (e.g., L- or D-amino acid proteins,
16 | mixed-chirality proteins, DNA/RNA, etc.)
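
Example (illustrative; uses the residue tables defined below):
>>> to_single("ALA")  # 'A'
>>> to_triple("A")  # 'ALA'
>>> is_canonical("MSE")  # False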
17 | """ 18 | 19 | from enum import Enum 20 | 21 | 22 | class polymerType(Enum): 23 | LPROT = 0 24 | DPROT = 1 25 | LDPROT = 2 26 | DNA = 3 27 | RNA = 4 28 | 29 | 30 | def polymer_type_name(ptype: polymerType): 31 | if ptype == polymerType.LPROT: 32 | return "polypeptide(L)" 33 | elif ptype == polymerType.DPROT: 34 | return "polypeptide(D)" 35 | elif ptype == polymerType.LDPROT: 36 | return "polypeptide(L,D)" 37 | elif ptype == polymerType.DNA: 38 | return "polydeoxyribonucleotide" 39 | elif ptype == polymerType.RNA: 40 | return "polyribonucleotide" 41 | else: 42 | raise Exception(f"unknown polymer type {ptype}") 43 | 44 | 45 | _res3 = [[] for _ in range(len(polymerType))] 46 | 47 | _res1 = [[] for _ in range(len(polymerType))] 48 | 49 | _res_to_idx = [dict() for _ in range(len(polymerType))] 50 | 51 | _unk_idx = [set() for _ in range(len(polymerType))] 52 | 53 | _gap_idx = [set() for _ in range(len(polymerType))] 54 | 55 | _stp_idx = [set() for _ in range(len(polymerType))] 56 | 57 | 58 | def _add_residue(ptype: polymerType, res3, res1): 59 | if isinstance(ptype, list): 60 | for pt, r3, r1 in zip(ptype, res3, res1): 61 | _add_residue(pt, r3, r1) 62 | else: 63 | _res_to_idx[ptype.value][res3] = len(_res3[ptype.value]) 64 | # single-letter code is ambiguous, so take the first residue when going from single-letter code to index 65 | if res1 not in _res_to_idx[ptype.value]: 66 | _res_to_idx[ptype.value][res1] = _res_to_idx[ptype.value][res3] 67 | _res3[ptype.value].append(res3) 68 | _res1[ptype.value].append(res1) 69 | if res3 == "---": 70 | _gap_idx[ptype.value].add(_res_to_idx[ptype.value][res3]) 71 | elif res3 == "UNK": 72 | _unk_idx[ptype.value].add(_res_to_idx[ptype.value][res3]) 73 | elif res3 == "STP": 74 | _stp_idx[ptype.value].add(_res_to_idx[ptype.value][res3]) 75 | 76 | 77 | def num_tokens(ptype=polymerType.LPROT): 78 | return len(_res3[ptype.value]) 79 | 80 | 81 | def num_known_molecular_tokens(ptype=polymerType.LPROT): 82 | return sum( 83 | [ 84 | not is_punctuation_index(idx) and not is_unknown(idx) 85 | for idx in range(len(_res3[ptype.value])) 86 | ] 87 | ) 88 | 89 | 90 | def res_to_index(res: str, ptype=polymerType.LPROT): 91 | return _res_to_idx[ptype.value].get(res, next(iter(_unk_idx[ptype.value]))) 92 | 93 | 94 | def index_to_single(idx: int, ptype=polymerType.LPROT): 95 | return _res1[ptype.value][idx] 96 | 97 | 98 | def index_to_triple(idx: int, ptype=polymerType.LPROT): 99 | return _res3[ptype.value][idx] 100 | 101 | 102 | def to_single(res: str, ptype=polymerType.LPROT): 103 | return index_to_single(res_to_index(res, ptype)) 104 | 105 | 106 | def to_triple(res: str, ptype=polymerType.LPROT): 107 | return index_to_triple(res_to_index(res, ptype)) 108 | 109 | 110 | def is_gap_index(idx: int, ptype=polymerType.LPROT): 111 | return idx in _gap_idx[ptype.value] 112 | 113 | 114 | def is_stop_index(idx: int, ptype=polymerType.LPROT): 115 | return idx in _stp_idx[ptype.value] 116 | 117 | 118 | def is_unknown(res: str, ptype=polymerType.LPROT): 119 | return is_unknown_index(res_to_index(res, ptype), ptype) 120 | 121 | 122 | def is_unknown_index(idx: int, ptype=polymerType.LPROT): 123 | return idx in _unk_idx[ptype.value] 124 | 125 | 126 | def is_polymer_residue(res: str, ptype: polymerType): 127 | if ptype is None: 128 | # determine if this is a polymer residue for any known polymer 129 | for ptype in polymerType: 130 | if res in _res_to_idx[ptype.value]: 131 | return True 132 | return False 133 | return res in _res_to_idx[ptype.value] 134 | 135 | 136 | def 
137 | return is_gap_index(idx, ptype) or is_stop_index(idx, ptype)
138 | 
139 | 
140 | def is_canonical(res: str, ptype=polymerType.LPROT):
141 | if ptype == polymerType.LPROT or ptype == polymerType.DPROT:
142 | idx = res_to_index(res, ptype)
143 | return (idx < 20) and (idx >= 0)
144 | elif ptype == polymerType.LDPROT:
145 | return is_canonical(res, polymerType.LPROT) or is_canonical(
146 | mirror_amino_acid(res), polymerType.DPROT
147 | )
148 | raise Exception(f"do not know how to deal with polymer type {ptype}")
149 | 
150 | 
151 | def canonical_amino_acids(ptype=polymerType.LPROT):
152 | canonicals = []
153 | for aa in _res3[ptype.value]:
154 | if is_canonical(aa, ptype):
155 | canonicals.append(aa)
156 | return canonicals
157 | 
158 | 
159 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["ALA", "DAL"], ["A", "a"])
160 | 
161 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["CYS", "DCY"], ["C", "c"])
162 | 
163 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["ASP", "DAS"], ["D", "d"])
164 | 
165 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["GLU", "DGL"], ["E", "e"])
166 | 
167 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["PHE", "DPN"], ["F", "f"])
168 | 
169 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["GLY", "GLY"], ["G", "G"])
170 | 
171 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["HIS", "DHI"], ["H", "h"])
172 | 
173 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["ILE", "DIL"], ["I", "i"])
174 | 
175 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["LYS", "DLY"], ["K", "k"])
176 | 
177 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["LEU", "DLE"], ["L", "l"])
178 | 
179 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["MET", "MED"], ["M", "m"])
180 | 
181 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["ASN", "DSG"], ["N", "n"])
182 | 
183 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["PRO", "DPR"], ["P", "p"])
184 | 
185 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["GLN", "DGN"], ["Q", "q"])
186 | 
187 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["ARG", "DAR"], ["R", "r"])
188 | 
189 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["SER", "DSN"], ["S", "s"])
190 | 
191 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["THR", "DTH"], ["T", "t"])
192 | 
193 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["VAL", "DVA"], ["V", "v"])
194 | 
195 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["TRP", "DTR"], ["W", "w"])
196 | 
197 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["TYR", "DTY"], ["Y", "y"])
198 | 
199 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["HSD", "DSD"], ["H", "h"])
200 | 
201 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["HSE", "DSE"], ["H", "h"])
202 | 
203 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["HSC", "DSC"], ["H", "h"])
204 | 
205 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["HSP", "DSP"], ["H", "h"])
206 | 
207 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["MSE", "DMS"], ["M", "m"])
208 | 
209 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["CSO", "DCS"], ["C", "c"])
210 | 
211 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["SEC", "DEC"], ["C", "c"])
212 | 
213 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["CSX", "DCX"], ["C", "c"])
214 | 
215 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["HIP", "DHP"], ["H", "h"])
216 | 
217 | _add_residue([polymerType.LPROT,
polymerType.DPROT], ["SEP", "DEP"], ["S", "s"]) 218 | 219 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["TPO", "DTP"], ["T", "t"]) 220 | 221 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["PTR", "DPT"], ["Y", "y"]) 222 | 223 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["UNK", "UNK"], ["X", "X"]) 224 | 225 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["STP", "STP"], ["*", "*"]) 226 | 227 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["---", "---"], ["-", "-"]) 228 | 229 | _add_residue([polymerType.LPROT, polymerType.DPROT], ["---", "---"], [".", "."]) 230 | 231 | for grp in [1, 2, 3]: 232 | for tp in [polymerType.LPROT, polymerType.DPROT]: 233 | for idx in range(num_tokens(tp)): 234 | if grp == 1: 235 | if not is_punctuation_index(idx, tp) and ( 236 | not is_unknown_index(idx, tp) 237 | ): 238 | if _res3[tp.value][idx] not in _res3[polymerType.LDPROT.value]: 239 | _add_residue( 240 | polymerType.LDPROT, 241 | _res3[tp.value][idx], 242 | _res1[tp.value][idx], 243 | ) 244 | elif grp == 2: 245 | if is_unknown_index(idx, tp): 246 | if _res3[tp.value][idx] not in _res3[polymerType.LDPROT.value]: 247 | _add_residue( 248 | polymerType.LDPROT, 249 | _res3[tp.value][idx], 250 | _res1[tp.value][idx], 251 | ) 252 | elif grp == 3: 253 | if is_punctuation_index(idx, tp): 254 | if _res3[tp.value][idx] not in _res3[polymerType.LDPROT.value]: 255 | _add_residue( 256 | polymerType.LDPROT, 257 | _res3[tp.value][idx], 258 | _res1[tp.value][idx], 259 | ) 260 | 261 | 262 | def mirror_amino_acid(res: str): 263 | idx = mirror_amino_acid_index(res_to_index(res, polymerType.LDPROT)) 264 | if len(res) == 1: 265 | return index_to_single(idx) 266 | return index_to_triple(idx) 267 | 268 | 269 | def mirror_amino_acid_index(idx: int): 270 | N = num_known_molecular_tokens(polymerType.LDPROT) 271 | 272 | # if this is an unknown residue or a punctuation mark, return as is 273 | if idx >= N: 274 | return idx 275 | 276 | # otherwise, flip chirality 277 | return (idx + N // 2) % N 278 | -------------------------------------------------------------------------------- /chroma/utility/starparser.py: -------------------------------------------------------------------------------- 1 | # Copyright Generate Biomedicines, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 
15 | import shlex
16 | from dataclasses import dataclass
17 | 
18 | 
19 | @dataclass
20 | class PeekedLine:
21 | line: str
22 | next_position: int
23 | 
24 | 
25 | def peek_line(f, peeked: PeekedLine, rewind=True):
26 | ret = True
27 | pos = f.tell()
28 | line = f.readline()
29 | if line == "": # at EOF
30 | ret = False
31 | elif line[-1] == "\n":
32 | line = line[:-1]
33 | peeked.line = line
34 | if rewind:
35 | peeked.next_position = f.tell()
36 | f.seek(pos)
37 | else:
38 | peeked.next_position = pos
39 | return ret
40 | 
41 | 
42 | def advance(f, peeked: PeekedLine):
43 | f.seek(peeked.next_position)
44 | 
45 | 
46 | def star_item_parse(line: str):
47 | parts = line.split(".")
48 | if len(parts) < 2:
49 | raise Exception(f"expected at least two parts in the STAR data line {line}")
50 | cat = parts[0]
51 | name_parts = parts[1].split()
52 | name = name_parts[0]
53 | if len(name_parts) >= 2:
54 | val = name_parts[1]
55 | else:
56 | val = ""
57 | return (cat, name, val)
58 | 
59 | 
60 | def star_read_data(f, names: list, in_loop: bool, cols=False, has_blocks=True):
61 | tab = []
62 | 
63 | if cols:
64 | tab = [[] for _ in range(len(names))]
65 | peeked = PeekedLine("", 0)
66 | if in_loop:
67 | heads = []
68 | while peek_line(f, peeked):
69 | if not peeked.line.startswith("_"):
70 | break
71 | parts = peeked.line.split(".")
72 | if len(parts) != 2:
73 | raise Exception(f"expected two parts in the STAR data line {peeked.line}")
74 | heads.append(parts[1].strip())
75 | advance(f, peeked)
76 | 
77 | # figure out which columns we want
78 | indices = [-1] * len(names)
79 | for i, name in enumerate(names):
80 | if name in heads:
81 | indices[i] = heads.index(name)
82 | 
83 | # read each row and get the corresponding columns
84 | row = [None] * len(heads)
85 | ma = max(indices)
86 | while star_read_data_row(f, row, in_loop, has_blocks):
87 | if (ma >= 0) and (len(row) <= ma):
88 | raise Exception(f"loop row has insufficient elements: {row}")
89 | if not cols:
90 | tab.append([""] * len(names))
91 | for i, index in enumerate(indices):
92 | if cols:
93 | tab[i].append(row[index] if index >= 0 else "")
94 | else:
95 | tab[-1][i] = row[index] if index >= 0 else ""
96 | else:
97 | if not cols:
98 | tab = [[""] * len(names)]
99 | category, cat, name = "", "", ""
100 | 
101 | row = ["", ""]
102 | while star_read_data_row(f, row, in_loop, has_blocks, peeked):
103 | cat, name, _ = star_item_parse(row[0])
104 | if category == "":
105 | category = cat
106 | elif category != cat:
107 | advance(f, peeked)
108 | break
109 | 
110 | if name not in names:
111 | continue
112 | idx = names.index(name)
113 | if cols:
114 | tab[idx].append(row[1])
115 | else:
116 | tab[0][idx] = row[1]
117 | 
118 | return tab
119 | 
120 | 
121 | def star_read_data_row(
122 | f, row: list, in_loop: bool, has_blocks: bool, peeked: PeekedLine = None
123 | ):
124 | i = 0
125 | ret = True
126 | if peeked is None:
127 | peeked = PeekedLine("", 0)
128 | while i < len(row):
129 | if not peek_line(f, peeked, rewind=False):
130 | if peeked.line == "" and i == 0:
131 | return False
132 | raise Exception(f"read {i} tokens when {len(row)} were requested: {row}")
133 | if (
134 | peeked.line.startswith("loop_")
135 | or peeked.line.startswith("data_")
136 | or (in_loop and peeked.line.startswith("_"))
137 | ):
138 | if i == 0:
139 | advance(f, peeked)
140 | return False
141 | raise Exception(
142 | f"data block ended while reading requested number of tokens: {len(row)}"
143 | )
144 | 
145 | if peeked.line.startswith(";"):
146 | row[i] = peeked.line[1:]
147 | while peek_line(f, peeked, rewind=False):
148 | if peeked.line.startswith(";"):
149 | break
150 | row[i] += peeked.line
151 | i = i + 1
152 | elif peeked.line.startswith("#"):
153 | pass
154 | else:
155 | elems = (
156 | [part for part in shlex.split(peeked.line.strip())]
157 | if has_blocks
158 | else peeked.line.strip().split()
159 | )
160 | if i + len(elems) > len(row):
161 | raise Exception(
162 | f"too many elements when trying to read {len(row)} tokens; last read: {elems}, row was: {row}, i = {i}"
163 | )
164 | for elem in elems:
165 | row[i] = elem
166 | i = i + 1
167 | 
168 | return ret
169 | 
170 | 
171 | def star_string_escape(text):
172 | # NOTE: has_space designates whether the string really should be quoted, not
173 | # based on having quote characters within it, but just because of some other
174 | # reason (e.g., it has spaces or is empty or starts with underscore, which can
175 | # have special meaning in CIF).
176 | has_space = (" " in text) or (text == "") or ((len(text) > 0) and (text[0] == "_"))
177 | has_single = "'" in text
178 | has_double = '"' in text
179 | 
180 | if not has_single and not has_double:
181 | if not has_space:
182 | return text
183 | else:
184 | return f"'{text}'"
185 | elif not has_single:
186 | return f"'{text}'"
187 | elif not has_double:
188 | return '"' + text + '"'
189 | return "\n;" + text + "\n;"
190 | 
191 | 
192 | def star_loop_header_write(f, category, names):
193 | f.write("loop_\n")
194 | for name in names:
195 | f.write(f"{category}.{name} \n")
196 | 
197 | 
198 | def star_value_defined(val):
199 | return (val != ".") and (val != "?")
200 | 
201 | 
202 | def star_value(val, default):
203 | if star_value_defined(val):
204 | return val
205 | return default
206 | 
207 | 
208 | def atom_site_token(value):
209 | return "."
if value == " " else value 210 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | nglview==3.0.8 2 | numpy 3 | pandas 4 | Pillow 5 | pytest 6 | Requests 7 | scikit_learn==1.1.2 8 | scipy 9 | torch 10 | tqdm 11 | transformers==4.24.0 12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import os 3 | import time 4 | 5 | from setuptools import find_packages, setup 6 | 7 | with open("requirements.txt", "r") as req_file: 8 | requirements = [line.split("#")[0].strip() for line in req_file] 9 | requirements = [line for line in requirements if line] 10 | 11 | 12 | def read(rel_path): 13 | here = os.path.abspath(os.path.dirname(__file__)) 14 | with codecs.open(os.path.join(here, rel_path), "r") as fp: 15 | return fp.read() 16 | 17 | 18 | def get_version(rel_path): 19 | for line in read(rel_path).splitlines(): 20 | if line.startswith("__version__"): 21 | delim = '"' if '"' in line else "'" 22 | return line.split(delim)[1] 23 | else: 24 | raise RuntimeError("Unable to find version string.") 25 | 26 | 27 | version = get_version("chroma/__init__.py") 28 | 29 | # During CICD, append "-dev" and unix timestamp to version 30 | if os.environ.get("CI_COMMIT_BRANCH") == "develop": 31 | version += f".dev{int(time.time())}" 32 | 33 | setup( 34 | name="generate-chroma", 35 | version=version, 36 | url="https://github.com/generatebio/chroma", 37 | packages=find_packages(), 38 | description="Chroma is a generative model for designing proteins programmatically", 39 | include_package_data=True, 40 | author="Generate Biomedicines", 41 | license="Apache 2.0", 42 | install_requires=requirements, 43 | ) 44 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generatebio/chroma/929407c605013613941803c6113adefdccaad679/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | try: 4 | import numpy as np 5 | 6 | HAS_NP = True 7 | except Exception: 8 | HAS_NP = False 9 | 10 | try: 11 | import torch 12 | 13 | HAS_TORCH = True 14 | except Exception: 15 | HAS_TORCH = False 16 | 17 | 18 | def pytest_runtest_setup(item): 19 | # Fix seeds at the beginning of each test. 
20 | seed = 20220714 21 | random.seed(seed) 22 | if HAS_NP: 23 | np.random.seed(seed) 24 | if HAS_TORCH: 25 | torch.manual_seed(seed) 26 | -------------------------------------------------------------------------------- /tests/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generatebio/chroma/929407c605013613941803c6113adefdccaad679/tests/data/__init__.py -------------------------------------------------------------------------------- /tests/data/test_protein.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import tempfile 3 | from pathlib import Path 4 | 5 | import pytest 6 | 7 | import chroma 8 | from chroma.data.protein import Protein 9 | 10 | BASE_PATH = str(Path(chroma.__file__).parent.parent) 11 | PROTEIN_SINGLE_CHAIN = BASE_PATH + "/tests/resources/4kw4.cif" 12 | PROTEIN_COMPLEX = BASE_PATH + "/tests/resources/3hn3.cif" 13 | CIF_TRAJECTORY = BASE_PATH + "/tests/resources/chroma_trajectory.cif" 14 | SEQUENCE = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTL" 15 | PDB_ID = "1B9C" 16 | 17 | TESTS = [PROTEIN_SINGLE_CHAIN, PROTEIN_COMPLEX, SEQUENCE, PDB_ID] 18 | 19 | 20 | @pytest.mark.parametrize("protein_path", TESTS) 21 | def test_Protein(protein_path): 22 | # Loading Smoke Tests 23 | if protein_path.endswith(".pdb"): 24 | protein = Protein.from_PDB(protein_path) 25 | elif protein_path.endswith(".cif"): 26 | protein = Protein.from_CIF(protein_path) 27 | elif len(protein_path) == 4: 28 | protein = Protein.from_PDBID(protein_path) 29 | else: # Protein Sequence Input 30 | protein = Protein.from_sequence(protein_path) 31 | 32 | # Selection Smoke Test 33 | # Select all structured residues 34 | D = protein.get_mask("all").bool() 35 | 36 | # Method Smoke Tests 37 | protein.canonicalize() 38 | protein.sequence() 39 | len(protein) 40 | protein.display() 41 | 42 | # Cycles save / load /validate 43 | X, C, S = protein.to_XCS() 44 | 45 | # XCS 46 | xcs_cycle_protein = Protein.from_XCS(X, C, S) 47 | Xt, Ct, St = xcs_cycle_protein.to_XCS() 48 | assert (Xt == X).all() and (Ct == C).all() and (St == S).all() 49 | 50 | # CIF 51 | with tempfile.NamedTemporaryFile(suffix=".cif", delete=True) as temp_file: 52 | protein.to_CIF(temp_file.name) 53 | Xt, Ct, St = protein.from_CIF(temp_file.name).to_XCS() 54 | assert (Xt == X).all() and (Ct == C).all() and (St == S).all() 55 | 56 | # PDB 57 | with tempfile.NamedTemporaryFile(suffix=".pdb", delete=True) as temp_file: 58 | protein.to_PDB(temp_file.name) 59 | structured_residues = protein.sys.num_structured_residues() 60 | round_trip_protein = protein.from_PDB(temp_file.name) 61 | assert len(round_trip_protein) == structured_residues 62 | 63 | # smoke test copy behavior 64 | copy.copy(protein) 65 | copy.deepcopy(protein) 66 | 67 | 68 | def compare_proteins(A, B): 69 | A, B = A.sys, B.sys 70 | if ( 71 | A.num_chains() != B.num_chains() 72 | or A.num_residues() != B.num_residues() 73 | or A.num_atoms() != B.num_atoms() 74 | or A.num_atom_locations() != B.num_atom_locations() 75 | or A.num_structured_residues() != B.num_structured_residues() 76 | ): 77 | return False 78 | 79 | for cA, cB in zip(A.chains(), B.chains()): 80 | if ( 81 | cA.num_residues() != cB.num_residues() 82 | or cA.cid != cB.cid 83 | or cA.segid != cB.segid 84 | or cA.authid != cB.authid 85 | ): 86 | print(f"chains {cA} and {cB} differ") 87 | return False 88 | for rA, rB in zip(cA.residues(), cB.residues()): 89 | if ( 90 | rA.num_atoms() != 
rB.num_atoms()
91 | or rA.name != rB.name
92 | or rA.num != rB.num
93 | or rA.authid != rB.authid
94 | or rA.icode != rB.icode
95 | ):
96 | print(f"residues {rA} and {rB} differ")
97 | return False
98 | for aA, aB in zip(rA.atoms(), rB.atoms()):
99 | if (
100 | aA.num_locations() != aB.num_locations()
101 | or aA.name != aB.name
102 | or aA.het != aB.het
103 | ):
104 | print(f"atoms {aA} and {aB} differ")
105 | return False
106 | for lA, lB in zip(aA.locations(), aB.locations()):
107 | if (
108 | (abs(lA.coors - lB.coors) > 0.01).any()
109 | or lA.occ != lB.occ
110 | or lA.B != lB.B
111 | or lA.alt != lB.alt
112 | ):
113 | print(f"atoms {lA} and {lB} differ")
114 | return False
115 | return True
116 | 
117 | 
118 | def test_xcs_trajectory():
119 | # Load Trajectory
120 | protein = Protein(CIF_TRAJECTORY)
121 | 
122 | # Save out Trajectory
123 | X_list, C, S = protein.to_XCS_trajectory()
124 | 
125 | # Load back in via XCS
126 | protein_xcs_load = Protein(X_list, C, S)
127 | assert compare_proteins(protein, protein_xcs_load)
128 | 
129 | # Print Trajectory
130 | print(protein)
131 | 
132 | # Display Trajectory
133 | protein.display()
134 | 
135 | 
136 | def test_trajectory_round_trip():
137 | # Load Trajectory
138 | protein = Protein(CIF_TRAJECTORY)
139 | 
140 | # Save out Trajectory
141 | X_list, C, S = protein.to_XCS_trajectory()
142 | 
143 | # Load back in via XCS
144 | protein_xcs_load = Protein.from_XCS_trajectory(X_list, C, S)
145 | assert compare_proteins(protein, protein_xcs_load)
146 | 
147 | # Turn back into XCS
148 | X_list_1, C_1, S_1 = protein_xcs_load.to_XCS_trajectory()
149 | assert len(X_list) == len(X_list_1)
150 | assert all((x1 == x2).all() for x1, x2 in zip(X_list, X_list_1))
151 | assert (C == C_1).all()
152 | assert (S == S_1).all()
153 | 
154 | 
155 | @pytest.mark.parametrize("pdb_id", ["3bdi", "5sv5"])
156 | def test_edge_cases(pdb_id):
157 | Protein(pdb_id, canonicalize=True)
158 | 
-------------------------------------------------------------------------------- /tests/data/test_system.py: --------------------------------------------------------------------------------
1 | import copy
2 | import filecmp
3 | import os
4 | import random
5 | import tempfile
6 | import time
7 | from pathlib import Path
8 | from unittest import TestCase
9 | 
10 | import numpy as np
11 | import pandas
12 | import pytest
13 | import requests
14 | 
15 | import chroma
16 | from chroma.data.system import System
17 | 
18 | BASE_PATH = str(Path(chroma.__file__).parent.parent)
19 | CIF_TRAJECTORY = BASE_PATH + "/tests/resources/chroma_trajectory.cif"
20 | 
21 | 
22 | @pytest.fixture
23 | def cif_file():
24 | file = str(
25 | Path(Path(chroma.__file__).parent.parent, "tests", "resources", "7bz5.cif")
26 | )
27 | return file
28 | 
29 | 
30 | def download_file(url, destination_file):
31 | r = requests.get(url, allow_redirects=True)
32 | open(destination_file, "wb").write(r.content)
33 | 
34 | 
35 | def test_7bz5_selection(cif_file):
36 | """
37 | Using a selector on a System, expect canonical results
38 | """
39 | valid_strings = [
40 | ("chain B", 218, 222),
41 | # very small queries that are easy to visually verify
42 | ("chain B and resid 3 around 2.5", 4, 4), # selects B near any resid 3
43 | (
44 | "(chain B and resid 3) around 2.5",
45 | 3,
46 | 3,
47 | ), # use parentheses for specific search around B, resid 3. - 14 atoms in 3 residues
48 | (
49 | "byres ((chain B and resid 3) around 2.5)",
50 | 3,
51 | 3,
52 | ), # byres selects all atoms in all affected residues - 24 atoms; 3 residues
53 | (
54 | "(chain B and resid 3 and name CA) around 2.5",
55 | 3,
56 | 3,
57 | ), # further refine to radius around the B, resid 3 CA - 7 atoms in those 3 resis
58 | # re versus name enumeration (should yield same results)
59 | (
60 | "(chain B and resid 3-5 and (re C.?)) around 12.5",
61 | 92,
62 | 92,
63 | ), # Atoms: 329, Residues: 92
64 | (
65 | "(chain B and resid 3-5 and (name C or name CA or name CB or name CD or"
66 | " name CG)) around 12.5",
67 | 92,
68 | 92,
69 | ), # Atoms: 329, Residues: 92
70 | # some larger queries and tests
71 | (
72 | "(chain B and resid 3) or (chain C and resid 103) saround 13",
73 | 28,
74 | 28,
75 | ), # right-association will only expand resid range on C
76 | (
77 | "((chain B and resid 3-9) or (chain C and resid 103)) saround 13",
78 | 48,
79 | 49,
80 | ), # expand both ranges.
81 | (
82 | "(authchain H and authresid 1-14) around 12.0",
83 | 238,
84 | 238,
85 | ), # pymol-like use of authchain and auth res ids
86 | (
87 | "not (authchain H and authresid 1-14) around 12.0",
88 | 959,
89 | 1000,
90 | ), # right-association lets this one work ok.
91 | (
92 | "(not ((authchain H and authresid 1-14) around 12.0)) or ((authchain H and"
93 | " authresid 1-14) around 12.0)",
94 | 1148,
95 | 1189,
96 | ), # Venn full coverage
97 | (
98 | "first (not ((authchain H and authresid 1-14) around 12.0)) or ((authchain"
99 | " H and authresid 1-14) around 12.0)",
100 | 0,
101 | 1,
102 | ), # testing first
103 | (
104 | "last (not ((authchain H and authresid 1-14) around 12.0)) or ((authchain H"
105 | " and authresid 1-14) around 12.0)",
106 | 1,
107 | 1,
108 | ), # testing last
109 | ("gti 0-18 or gti 190-234", 28, 64), # GTI access
110 | (
111 | "gti 0-18 or gti 1140-2340",
112 | 53,
113 | 68,
114 | ), # GTI access, including some out-of-range values
115 | ]
116 | 
117 | # compare with MST in right-associativity mode
118 | sys = System.from_CIF(cif_file)
119 | for cmd in valid_strings:
120 | print(f"cmd = {cmd}")
121 | 
122 | # check saved selections (MST happens to be right-associative)
123 | for uns in [False, True]:
124 | print(f"allow unstructured {uns}")
125 | sys.save_selection(cmd[0], left_associativity=False, allow_unstructured=uns)
126 | assert len(sys.get_selected()) == cmd[1 + uns]
127 | 
128 | # smoke test in left-associativity mode and test writing/reading selections
129 | with tempfile.NamedTemporaryFile(suffix=".pdb", delete=True) as file:
130 | file.close()
131 | for cmd in valid_strings:
132 | print(f"cmd = {cmd}")
133 | 
134 | for uns in [False, True]:
135 | print(f"allow unstructured {uns}")
136 | selname = "_some_good_selection"
137 | sys.save_selection(
138 | cmd[0],
139 | selname=selname,
140 | left_associativity=False,
141 | allow_unstructured=uns,
142 | )
143 | assert sys.has_selection(selname)
144 | sys.to_CIF(file.name)
145 | newsys = System.from_CIF(file.name)
146 | assert newsys.has_selection(selname)
147 | assert newsys.get_selected(selname) == sys.get_selected(selname)
148 | 
149 | results = sys.select(cmd[0], left_associativity=False)
150 | 
151 | 
152 | def test_invalid_input(cif_file):
153 | parse_failure_strings = { # strings that should fail but never segfault.
154 | "chain B and resid 3 around 2.5 B", # extra stuff - need parentheses 155 | "chain B and resid 3 around", # can't convert empty token into distance 156 | "byres (chain B and resid 3)) around 2.5", # unmatched close paren 157 | "chain B chain A", # invalid connector 158 | } 159 | sys = System.from_CIF(cif_file) 160 | for cmd in parse_failure_strings: 161 | print(cmd) 162 | try: 163 | result = sys.select(cmd) 164 | worked = True 165 | except Exception as e: 166 | worked = False 167 | assert worked == False, f"expression {cmd} was meant to fail but succeeded" 168 | 169 | 170 | def next_structure_file(num=100, cif=True): 171 | tmp_file = os.path.join(tempfile.gettempdir(), "_pdb_list.txt") 172 | download_file( 173 | "https://files.wwpdb.org/pub/pdb/derived_data/pdb_entry_type.txt", tmp_file 174 | ) 175 | D = pandas.read_csv(tmp_file, sep="\t", header=None) 176 | pdb_ids = list(D[0]) 177 | random.shuffle(pdb_ids) 178 | 179 | if cif: 180 | file = os.path.join(tempfile.gettempdir(), "_pdb_download.cif") 181 | else: 182 | file = os.path.join(tempfile.gettempdir(), "_pdb_download.pdb") 183 | for pdb_id in pdb_ids[:num]: 184 | # download CIF file 185 | if cif: 186 | file = os.path.join(tempfile.gettempdir(), "_pdb_download.cif") 187 | download_file(f"https://files.rcsb.org/download/{pdb_id}.cif", file) 188 | else: 189 | file = os.path.join(tempfile.gettempdir(), "_pdb_download.pdb") 190 | download_file(f"https://files.rcsb.org/download/{pdb_id}.pdb", file) 191 | yield pdb_id, file 192 | 193 | 194 | def test_writing_pdb(cif_file): 195 | # smoke test for PDB writing 196 | with tempfile.NamedTemporaryFile(suffix=".pdb", delete=True) as file: 197 | file.close() 198 | for pdb_id, cif_file in next_structure_file(num=10, cif=True): 199 | sys = System.from_CIF(cif_file) 200 | sys.to_PDB(file.name) 201 | 202 | 203 | def test_reading_cif(cif_file): 204 | # smoke test for CIF reading 205 | for pdb_id, cif_file in next_structure_file(num=100, cif=True): 206 | # load into python and MST in random order, record time 207 | order = random.random() 208 | for i in range(2): 209 | sys = System.from_CIF(cif_file) 210 | 211 | 212 | def test_reading_pdb(cif_file): 213 | # smoke test for PDB reading 214 | for pdb_id, pdb_file in next_structure_file(num=100, cif=False): 215 | # load into python and MST in random order, record time 216 | order = random.random() 217 | for i in range(2): 218 | sys = System.from_PDB(pdb_file) 219 | 220 | 221 | def test_update_with_xcs(cif_file): 222 | sys = System.from_CIF(CIF_TRAJECTORY) 223 | sys_copy = copy.deepcopy(sys) 224 | 225 | sys.swap_model(1) 226 | X, C, S = sys.to_XCS() 227 | sys.swap_model(1) 228 | sys.update_with_XCS(X, C, S) 229 | 230 | sys_copy.swap_model(1) 231 | assert sys_copy.num_atom_locations() == sys.num_atom_locations() 232 | for loc1, loc2 in zip(sys.locations(), sys_copy.locations()): 233 | assert (abs(loc1.coors == loc2.coors) < 0.01).all() 234 | -------------------------------------------------------------------------------- /tests/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generatebio/chroma/929407c605013613941803c6113adefdccaad679/tests/layers/__init__.py -------------------------------------------------------------------------------- /tests/layers/structure/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/generatebio/chroma/929407c605013613941803c6113adefdccaad679/tests/layers/structure/__init__.py -------------------------------------------------------------------------------- /tests/layers/structure/test_backbone.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import pytest 4 | import torch 5 | 6 | from chroma.layers.structure.backbone import ( 7 | BackboneBuilder, 8 | LossBackboneResidueDistance, 9 | ProteinBackbone, 10 | RigidTransform, 11 | RigidTransformer, 12 | ) 13 | 14 | 15 | class TestProteinBackbone(TestCase): 16 | def test_cuda(self): 17 | if torch.cuda.is_available(): 18 | try: 19 | protein_backbone = ProteinBackbone(1).cuda() 20 | except Exception: 21 | protein_backbone = None 22 | 23 | self.assertTrue(protein_backbone is not None) 24 | 25 | def test_sample(self): 26 | protein_backbone = ProteinBackbone(1) 27 | 28 | expected = torch.Tensor( 29 | [ 30 | [ 31 | [ 32 | [0.1331, -1.6303, -0.7377], 33 | [0.0414, -0.1759, -0.8080], 34 | [-0.3710, 0.4114, 0.5376], 35 | [0.1965, 1.3947, 1.0081], 36 | ] 37 | ] 38 | ] 39 | ) 40 | 41 | predicted = protein_backbone() 42 | self.assertEqual((1, 1, 4, 3), predicted.shape) 43 | self.assertTrue(torch.allclose(expected, predicted, rtol=1e-03)) 44 | 45 | def test_random_init_backbone(self): 46 | protein_backbone = ProteinBackbone(1, init_state="") 47 | predicted = protein_backbone() 48 | self.assertEqual((1, 1, 4, 3), predicted.shape) 49 | 50 | def test_sample_cartesian(self): 51 | 52 | protein_backbone = ProteinBackbone(1, use_internal_coords=False) 53 | 54 | expected = torch.Tensor( 55 | [ 56 | [ 57 | [ 58 | [0.1331, -1.6303, -0.7377], 59 | [0.0414, -0.1759, -0.8080], 60 | [-0.3710, 0.4114, 0.5376], 61 | [0.1965, 1.3947, 1.0081], 62 | ] 63 | ] 64 | ] 65 | ) 66 | 67 | predicted = protein_backbone() 68 | self.assertEqual((1, 1, 4, 3), predicted.shape) 69 | self.assertTrue(torch.allclose(expected, predicted, rtol=1e-03)) 70 | 71 | def test_initialized_sample(self): 72 | 73 | torch.manual_seed(7) 74 | input_x = torch.rand(1, 2, 4, 3) 75 | predicted = ProteinBackbone(1, use_internal_coords=False, X_init=input_x)() 76 | 77 | expected = torch.Tensor( 78 | [ 79 | [ 80 | [ 81 | [-5.3644e-07, -2.6469e-01, 1.4716e-01], 82 | [1.2197e-01, -2.3073e-01, -8.6988e-02], 83 | [-3.2784e-01, 1.6625e-01, -1.4673e-01], 84 | [3.1635e-01, 3.9145e-01, 3.8886e-02], 85 | ], 86 | [ 87 | [-2.4808e-01, -2.5717e-01, -6.6959e-02], 88 | [-1.7564e-01, 2.5689e-01, -4.3900e-01], 89 | [4.3500e-01, -3.5570e-01, 3.7083e-01], 90 | [-1.2175e-01, 2.9370e-01, 1.8280e-01], 91 | ], 92 | ] 93 | ] 94 | ) 95 | 96 | self.assertTrue(torch.allclose(predicted, expected, rtol=1e-03)) 97 | 98 | 99 | class TestRigidTransform(TestCase): 100 | def test_sample(self): 101 | # Default behavior should be identity transformation 102 | rigid_transform = RigidTransform() 103 | torch.manual_seed(7) 104 | input_x = torch.rand(1, 1, 4, 3) 105 | predicted = rigid_transform(input_x) 106 | self.assertTrue(torch.allclose(predicted, input_x, rtol=1e-3)) 107 | 108 | 109 | class TestRigidTransformer(TestCase): 110 | def test_sample(self): 111 | rigid_transformer = RigidTransformer(center_rotation=True, keep_centered=True) 112 | 113 | input_x = torch.rand(1, 1, 4, 3) 114 | mean_centered = input_x - torch.mean(input_x.reshape(1, -1, 3), axis=-2) 115 | # Test Identity 116 | no_translation = torch.zeros(1, 3) 117 | identity_q = torch.Tensor([[1.0, 0, 0, 0]]) 118 | 119 | predicted = rigid_transformer(input_x, 
no_translation, identity_q) 120 | self.assertTrue(torch.allclose(predicted, mean_centered, rtol=1e-3)) 121 | 122 | # Test Translation 123 | x_translation = torch.Tensor([[1, 0, 0]]) 124 | expected = mean_centered + x_translation 125 | 126 | predicted = rigid_transformer(input_x, x_translation, identity_q) 127 | self.assertTrue(torch.allclose(predicted, expected, rtol=1e-3)) 128 | 129 | 130 | class TestBackboneBuilder(TestCase): 131 | def test_sample(self): 132 | phi_tensor = torch.Tensor([[-1.0472]]) 133 | psi_tensor = torch.Tensor([[-0.7854]]) 134 | backbone_builder = BackboneBuilder() 135 | 136 | expected = torch.Tensor( 137 | [ 138 | [ 139 | [ 140 | [-1.2286, 0.2223, -1.2286], 141 | [-1.3203, 1.6767, -1.2989], 142 | [-1.7327, 2.2640, 0.0468], 143 | [-1.1652, 3.2473, 0.5172], 144 | ] 145 | ] 146 | ] 147 | ) 148 | 149 | predicted = backbone_builder(phi_tensor, psi_tensor) 150 | 151 | self.assertTrue(torch.allclose(expected, predicted, rtol=1e-3)) 152 | 153 | def test_custom_sample(self): 154 | num_residues = 1 155 | phi_tensor = torch.Tensor([[-1.0472]]) 156 | psi_tensor = torch.Tensor([[-0.7854]]) 157 | backbone_builder = BackboneBuilder() 158 | 159 | expected = torch.Tensor( 160 | [ 161 | [ 162 | [ 163 | [-1.2286, 0.2223, -1.2286], 164 | [-1.3203, 1.6767, -1.2989], 165 | [-1.7327, 2.2640, 0.0468], 166 | [-1.1652, 3.2473, 0.5172], 167 | ] 168 | ] 169 | ] 170 | ) 171 | 172 | predicted = backbone_builder(phi_tensor, psi_tensor) 173 | 174 | lengths = torch.tensor( 175 | [[backbone_builder.lengths[key] for key in ["C_N", "N_CA", "CA_C"]]], 176 | dtype=torch.float32, 177 | ) 178 | lengths = lengths.repeat(1, 1) # (1,3) 179 | 180 | angles = torch.tensor( 181 | [[backbone_builder.angles[key] for key in ["CA_C_N", "C_N_CA", "N_CA_C"]]], 182 | dtype=torch.float32, 183 | ) 184 | angles = angles.repeat(1, 1) # (1,3) 185 | 186 | omega = backbone_builder.angles["omega"] * torch.ones(1, 1) # (1,1) 187 | 188 | predicted = backbone_builder(phi_tensor, psi_tensor, omega, angles, lengths) 189 | self.assertTrue(torch.allclose(expected, predicted, rtol=1e-3)) 190 | 191 | lengths = torch.tensor( 192 | [[backbone_builder.lengths[key] for key in ["C_N", "N_CA", "CA_C"]]], 193 | dtype=torch.float32, 194 | ) 195 | lengths = lengths.repeat(1, num_residues) # (1,3) 196 | 197 | angles = torch.tensor( 198 | [[backbone_builder.angles[key] for key in ["CA_C_N", "C_N_CA", "N_CA_C"]]], 199 | dtype=torch.float32, 200 | ) 201 | angles = angles.repeat(1, num_residues) # (1,3) 202 | 203 | omega = backbone_builder.angles["omega"] * torch.ones(1, num_residues) # (1,1) 204 | 205 | predicted = backbone_builder(phi_tensor, psi_tensor, omega, angles, lengths) 206 | self.assertTrue(torch.allclose(expected, predicted, rtol=1e-3)) 207 | -------------------------------------------------------------------------------- /tests/layers/structure/test_conditioners.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | import chroma 9 | from chroma.data import xcs 10 | from chroma.data.protein import Protein 11 | from chroma.layers.structure import backbone, conditioners, rmsd, symmetry 12 | from chroma.models.graph_backbone import GraphBackbone 13 | from chroma.models.procap import ProteinCaption 14 | 15 | 16 | @pytest.fixture(scope="session") 17 | def XCO(): 18 | repo = Path(chroma.__file__).parent.parent 19 | test_cif = str(Path(repo, "tests", "resources", "6wgl.cif")) 20 | 
protein = Protein.from_CIF(test_cif) 21 | X, C, S = protein.to_XCS() 22 | X.requires_grad = True 23 | O = F.one_hot(S, 20) 24 | return X, C, O 25 | 26 | 27 | @pytest.fixture(scope="session") 28 | def protein(): 29 | pdb_id = "1drf" 30 | protein = Protein.from_PDBID(pdb_id, canonicalize=True) 31 | protein.sys.save_selection(gti=list(range(15)), selname="clamp") 32 | protein.sys.save_selection(gti=list(range(15, 25)), selname="semirigid") 33 | return protein 34 | 35 | 36 | @pytest.fixture 37 | def test_conditioner_pointgroup_conditioner(XCO): 38 | X, C, O = XCO 39 | conditioner = conditioners.SymmetryConditioner( 40 | G=symmetry.get_point_group("I"), num_chain_neighbors=3 41 | ) 42 | X_constrained, _, _, _, _ = conditioner(X, C, O, 0.0, 0.5) 43 | return conditioner, X_constrained, C 44 | 45 | 46 | @pytest.fixture 47 | def test_conditioner_screw_conditioner(XCO): 48 | X, C, O = XCO 49 | conditioner = conditioners.ScrewConditioner(theta=np.pi / 4, tz=5.0, M=10) 50 | X_constrained, _, _, _, _ = conditioner(X, C, O, 0.0, 0.5) 51 | return conditioner, X_constrained, C 52 | 53 | 54 | @pytest.fixture 55 | def test_conditioner_Rg_conditioner(XCO): 56 | X, C, O = XCO 57 | conditioner = conditioners.RgConditioner() 58 | conditioner(X, C, O, 0.0, 0.5) 59 | return conditioner, X, C 60 | 61 | 62 | @pytest.fixture 63 | def test_conditioner_symmetry_and_substructure(protein): 64 | bb_model = GraphBackbone(dim_nodes=16, dim_edges=16) 65 | protein.get_mask("namesel clamp") 66 | sub_conditioner = conditioners.SubstructureConditioner( 67 | protein, bb_model, "namesel clamp" 68 | ) 69 | 70 | sym_conditioner = conditioners.SymmetryConditioner( 71 | G=symmetry.get_point_group("C_3"), num_chain_neighbors=1, freeze_com=True 72 | ) 73 | 74 | composed_conditioner = conditioners.ComposedConditioner( 75 | [sub_conditioner, sym_conditioner] 76 | ) 77 | 78 | X, C, S = protein.to_XCS() 79 | X.requires_grad = True 80 | O = F.one_hot(S, 20) 81 | X_constrained, _, _, _, _ = composed_conditioner(X, C, O, 0.0, torch.tensor([0.0])) 82 | return composed_conditioner, X_constrained, C 83 | 84 | 85 | @pytest.fixture 86 | def test_conditioner_substructure_conditioner(protein): 87 | aligner = rmsd.BackboneRMSD() 88 | bb_model = GraphBackbone(dim_nodes=16, dim_edges=16) 89 | X, C, S = protein.to_XCS() 90 | O = F.one_hot(S, 20) 91 | conditioner = conditioners.SubstructureConditioner( 92 | protein, bb_model, "namesel clamp" 93 | ) 94 | X_conditioned, _, _, _, _ = conditioner( 95 | torch.randn_like(X), C, O, 0.0, torch.tensor([0.0]) 96 | ) 97 | D = protein.get_mask("namesel clamp") 98 | _, rmsd1 = aligner.align(X_conditioned, X, D) 99 | assert rmsd1.isclose(torch.tensor(0.0), atol=1e-1) 100 | 101 | return conditioner, X, C 102 | 103 | 104 | @pytest.fixture 105 | def test_conditioner_procap_conditioner(XCO): 106 | model = ProteinCaption() 107 | X, C, O = XCO 108 | conditioner = conditioners.ProCapConditioner("Test caption", -1, model=model) 109 | conditioner(X, C, O, 0, 0.5) 110 | return conditioner, X, C 111 | 112 | 113 | def collect_conditioners(): 114 | return [v for k, v in globals().items() if k.startswith("test_conditioner_")] 115 | 116 | 117 | @pytest.fixture(params=["globular"]) 118 | def gaussian_noise(request): 119 | from chroma.layers.structure.diffusion import DiffusionChainCov 120 | 121 | covariance_model = request.param 122 | return DiffusionChainCov( 123 | covariance_model=covariance_model, 124 | complex_scaling=False, 125 | noise_schedule="log_snr", 126 | ) 127 | 128 | 129 | @pytest.mark.parametrize("conditioner", 
collect_conditioners()) 130 | def test_sampling(gaussian_noise, conditioner, request): 131 | conditioner_cls, X_native, C = request.getfixturevalue(conditioner.__name__) 132 | 133 | def X0_func(X, C, t): 134 | return X_native 135 | 136 | out = gaussian_noise.sample_sde( 137 | X0_func=X0_func, C=C, X_init=None, N=2, conditioner=conditioner_cls 138 | ) 139 | 140 | 141 | def test_proclass_conditioner(protein): 142 | """Smoke test for secondary structure conditioning""" 143 | SECONDARY_STRUCTURE = "CCEEEEEEEETTTTECTTTTTTTTCCCHHHHHHHHHHHHCCCTTTTEEEEEECHHHHHHCTGGTTTTTTTEEEEETTTTTTTTTTTCEEECTHHHHHHHHHCHGHGGHCCEEEEEECHHHHHHHHHCTCEEEEEEEEETTCCCTTEECCCCTGGGTEEETETTTTTCCEEEETTEEEEEEEEEEEC" 144 | X, C, S = protein.to_XCS() 145 | X.detach() 146 | X.requires_grad = True 147 | O = F.one_hot(S, 20) 148 | conditioner = conditioners.ProClassConditioner( 149 | "secondary_structure", SECONDARY_STRUCTURE, device="cpu" 150 | ) 151 | conditioner(X, C, O, 0, 0.5) 152 | -------------------------------------------------------------------------------- /tests/layers/structure/test_geometry.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from unittest import TestCase 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | import chroma 9 | from chroma.data import Protein 10 | from chroma.layers.structure import geometry 11 | 12 | 13 | class TestDistances(TestCase): 14 | def test_sample(self): 15 | distances = geometry.Distances() 16 | torch.manual_seed(7) 17 | input_x = torch.rand(1, 2, 4, 3) 18 | dim = -2 19 | predicted = distances(input_x, None, dim) 20 | self.assertTrue(predicted.shape == (1, 2, 4, 4)) 21 | expected = torch.tensor( 22 | [ 23 | [ 24 | [ 25 | [0.0316, 0.2681, 0.6169, 0.7371], 26 | [0.2681, 0.0316, 0.6037, 0.6646], 27 | [0.6169, 0.6037, 0.0316, 0.7079], 28 | [0.7371, 0.6646, 0.7079, 0.0316], 29 | ], 30 | [ 31 | [0.0316, 0.6395, 0.8179, 0.6187], 32 | [0.6395, 0.0316, 1.1853, 0.6260], 33 | [0.8179, 1.1853, 0.0316, 0.8764], 34 | [0.6187, 0.6260, 0.8764, 0.0316], 35 | ], 36 | ] 37 | ] 38 | ) 39 | self.assertTrue(torch.allclose(predicted, expected, rtol=1e-3)) 40 | 41 | 42 | class TestRotations(TestCase): 43 | def setUp(self): 44 | self.R = torch.tensor( 45 | [ 46 | [ 47 | [0.9027011, -0.1829866, -0.3894183], 48 | [-0.3146039, 0.3367128, -0.8874959], 49 | [0.2935220, 0.9236560, 0.2463827], 50 | ], 51 | [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], 52 | [ 53 | [-0.6638935, 0.6988353, 0.2662229], 54 | [-0.6322795, -0.3344426, -0.6988353], 55 | [-0.3993345, -0.6322795, 0.6638935], 56 | ], 57 | ] 58 | ) 59 | 60 | self.q = torch.tensor( 61 | [ 62 | [0.7883205, 0.5743704, -0.2165808, -0.0417398], 63 | [1.0, 0.0, 0.0, 0.0], 64 | [0.4079085, 0.0407909, 0.4079085, -0.815817], 65 | ] 66 | ) 67 | 68 | def test_rotations_from_quaternions(self): 69 | R_from_q = geometry.rotations_from_quaternions(self.q) 70 | self.assertTrue(torch.allclose(self.R, R_from_q, atol=1e-3)) 71 | 72 | def test_quaternions_from_rotations(self): 73 | q_from_R = geometry.quaternions_from_rotations(self.R, eps=0.0) 74 | self.assertTrue(torch.allclose(self.q, q_from_R, atol=1e-3)) 75 | 76 | def test_round_trip(self): 77 | R_from_q = geometry.rotations_from_quaternions(self.q) 78 | q_round_trip = geometry.quaternions_from_rotations(R_from_q, eps=0.0) 79 | R_from_round_trip = geometry.rotations_from_quaternions(q_round_trip) 80 | 81 | self.assertTrue(torch.allclose(self.q, q_round_trip, atol=1e-3)) 82 | self.assertTrue(torch.allclose(self.R, 
R_from_round_trip, atol=1e-3)) 83 | 84 | 85 | class TestExtendAtoms(TestCase): 86 | def test_extend_atoms_round_trip(self): 87 | # Test cycle-consistency of geometry measurement and building routines 88 | num_batch, num_residues = 10, 30 89 | X1, X2, X3 = torch.randn([num_batch, num_residues, 3, 3]).unbind(-1) 90 | L = torch.exp(torch.randn([num_batch, num_residues])) + 1.0 91 | A = np.pi * torch.sigmoid(torch.randn([num_batch, num_residues])) 92 | D = np.pi * torch.randn([num_batch, num_residues]) 93 | 94 | X4 = geometry.extend_atoms(X1, X2, X3, L, A, D, distance_eps=1e-6) 95 | 96 | L_recover = geometry.lengths(X3, X4, distance_eps=0.0) 97 | A_recover = geometry.angles(X2, X3, X4, distance_eps=0.0) 98 | D_recover = geometry.dihedrals(X1, X2, X3, X4, distance_eps=0.0) 99 | 100 | _embed = lambda a: torch.stack([torch.cos(a), torch.sin(a)], -1) 101 | 102 | self.assertTrue(torch.allclose(L, L_recover, atol=1e-2)) 103 | self.assertTrue(torch.allclose(A, A_recover, atol=1e-2)) 104 | self.assertTrue(torch.allclose(_embed(D), _embed(D_recover), atol=1e-2)) 105 | return 106 | 107 | 108 | class TestVirtualAtomsCA(TestCase): 109 | def test_atom_placement(self): 110 | # Load test case 111 | file_cif = str( 112 | Path(Path(chroma.__file__).parent.parent, "tests", "resources", "5jg9.cif",) 113 | ) 114 | X, C, S = Protein(file_cif).to_XCS() 115 | 116 | for v_type in ["cbeta", "dicons"]: 117 | # Place atoms 118 | atom_placer = geometry.VirtualAtomsCA(virtual_type=v_type) 119 | X_virtual = atom_placer(X, C) 120 | # DEBUG: Sanity check is useful for testing 121 | # geometry.debug_pymol_virtual_atoms(X, X_virtual, 'test_5jg9.pml') 122 | 123 | # Test that generated angles are correct 124 | X_N, X_CA, X_C, X_O = X.unbind(2) 125 | 126 | bonds = torch.norm(X_virtual - X_CA, dim=-1) 127 | angles = geometry.angles( 128 | X_N, X_CA, X_virtual, distance_eps=1e-6, degrees=True 129 | ) 130 | dihedrals = geometry.dihedrals( 131 | X_C, X_N, X_CA, X_virtual, distance_eps=1e-6, degrees=True 132 | ) 133 | 134 | bond_t, angle_t, dihedral_t = atom_placer.geometry() 135 | mask = (C > 0).type(torch.float32) 136 | bond_error = mask * (bonds - bond_t) 137 | angle_error = mask * (angles - angle_t) 138 | dihedral_error = mask * (dihedrals - dihedral_t) 139 | 140 | self.assertTrue( 141 | torch.allclose(bond_error, torch.zeros_like(bond_error), atol=1e-2) 142 | ) 143 | self.assertTrue( 144 | torch.allclose(angle_error, torch.zeros_like(angle_error), atol=1e-2) 145 | ) 146 | self.assertTrue( 147 | torch.allclose( 148 | dihedral_error, torch.zeros_like(dihedral_error), atol=1e-2 149 | ) 150 | ) 151 | -------------------------------------------------------------------------------- /tests/layers/structure/test_hbonds.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | import chroma 9 | from chroma.data import Protein 10 | from chroma.layers import graph 11 | from chroma.layers.structure import hbonds, protein_graph 12 | 13 | 14 | @pytest.fixture(scope="session") 15 | def XCS(): 16 | repo = Path(chroma.__file__).parent.parent 17 | pdb_id = "6wgl" 18 | test_cif = str(Path(repo, "tests", "resources", "6wgl.cif")) 19 | X, C, S = Protein(test_cif).to_XCS() 20 | return X, C, S, pdb_id 21 | 22 | 23 | def test_backbone_hbonds(XCS, debug_plot=False): 24 | X, C, S, pdb_id = XCS 25 | 26 | bb_hbonds = hbonds.BackboneHBonds() 27 | 28 | # Build Graph 29 | graph_builder = 
protein_graph.ProteinGraph() 30 | edge_idx, mask_ij = graph_builder(X, C) 31 | hb, mask_hb, H_i = bb_hbonds(X, C, edge_idx, mask_ij) 32 | hb_dense = graph.scatter_edges(hb[..., None], edge_idx)[..., 0] 33 | 34 | if debug_plot: 35 | if False: 36 | H = hb_dense[0, :, :].data.numpy() 37 | from matplotlib import pyplot as plt 38 | 39 | plt.matshow(H) 40 | plt.show() 41 | 42 | # Build a PyMOL script that draws the detected hydrogen bonds 43 | rgb = (0.3, 0.7, 0.1) 44 | with open(f"viz_hbonds_{pdb_id}.pml", "w") as f: 45 | f.write( 46 | "delete all\n" 47 | f"fetch {pdb_id}\n" 48 | f"hide everything, {pdb_id}\n" 49 | "show sticks, bb.\n" 50 | "color white, all\n" 51 | "color atomic, (not elem C)\n" 52 | "h_add bb.\n" 53 | "distance hbonds_pymol, don. and bb., acc. and bb., 3.6, mode=2\n" 54 | "hide labels\n" 55 | ) 56 | cgo_list = [protein_graph._cgo_color(rgb)] 57 | for i in range(edge_idx.size(1)): 58 | for j_idx in range(edge_idx.size(2)): 59 | if hb[0, i, j_idx] > 0: 60 | j = edge_idx[0, i, j_idx] 61 | cgo_list.append( 62 | protein_graph._cgo_cylinder( 63 | H_i[0, i, :], X[0, j, 3, :], radius=0.08, rgb=rgb 64 | ) 65 | ) 66 | cgo_list.append( 67 | protein_graph._cgo_sphere(H_i[0, i, :], radius=0.3) 68 | ) 69 | cgo_str = " + ".join(cgo_list) 70 | f.write(f'cmd.load_cgo({cgo_str}, "hbonds_pytorch", 1)\n') 71 | 72 | # These hydrogen bonds were manually spot checked for 6wgl 73 | # in Pymol using the above script. We don't count i-i+2 hydrogen bonds, 74 | # and there appear to be subtle orientation-dependent differences, but 75 | # SS-dependent calls agree well 76 | assert hb_dense.sum().item() == 303 77 | 78 | 79 | def test_loss_hbb(XCS, debug=False): 80 | X, C, S, pdb_id = XCS 81 | loss_hbb = hbonds.LossBackboneHBonds() 82 | 83 | torch.manual_seed(1) 84 | X_noise = X + torch.randn_like(X) 85 | recovery_local, recovery_nonlocal, error_co = loss_hbb(X_noise, X, C) 86 | assert recovery_local.mean().item() < 1.0 87 | assert recovery_nonlocal.mean().item() < 1.0 88 | assert error_co > 0.0 89 | 90 | recovery_local, recovery_nonlocal, error_co = loss_hbb(X, X, C) 91 | assert recovery_local.mean().item() == pytest.approx(1.0, 1e-2) 92 | assert recovery_nonlocal.mean().item() == pytest.approx(1.0, 1e-2) 93 | assert error_co.mean().item() == pytest.approx(0.0, 1e-2) 94 | 95 | if debug: 96 | # Plot H-bond recovery as a function of diffusion time t and alpha 97 | from chroma.layers.structure import diffusion 98 | 99 | noise = diffusion.DiffusionChainCov(complex_scaling=True) 100 | 101 | T = np.linspace(0, 1, 100) 102 | R_local = [] 103 | R_nonlocal = [] 104 | for t in T: 105 | X_noise = noise(X, C, t=t) 106 | recovery_local, recovery_nonlocal, error_co = loss_hbb(X_noise, X, C) 107 | R_local.append(recovery_local.mean().item()) 108 | R_nonlocal.append(recovery_nonlocal.mean().item()) 109 | A = noise.noise_schedule.alpha(T.tolist()).data.numpy().flatten() 110 | 111 | from matplotlib import pyplot as plt 112 | 113 | plt.subplot(1, 2, 1) 114 | plt.plot(T, R_local, label="Local H-Bonds") 115 | plt.plot(T, R_nonlocal, label="Nonlocal H-Bonds") 116 | plt.xlim([0, 1]) 117 | plt.ylim([0, 1]) 118 | plt.xlabel("t") 119 | plt.ylabel("Recovery") 120 | plt.legend() 121 | plt.grid() 122 | plt.subplot(1, 2, 2) 123 | plt.plot(A, R_local, label="Local H-Bonds") 124 | plt.plot(A, R_nonlocal, label="Nonlocal H-Bonds") 125 | plt.xlim([0, 1]) 126 | plt.ylim([0, 1]) 127 | plt.xlabel("alpha") 128 | plt.ylabel("Recovery") 129 | plt.legend() 130 | plt.grid() 131 | plt.show() 132 | return 133 | -------------------------------------------------------------------------------- /tests/layers/structure/test_mvn.py:
-------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from unittest import SkipTest, TestCase 3 | 4 | import numpy as np 5 | import pytest 6 | import torch 7 | 8 | import chroma 9 | from chroma.data import Protein 10 | from chroma.layers.structure.backbone import impute_masked_X 11 | from chroma.layers.structure.mvn import ( 12 | BackboneMVNGlobular, 13 | BackboneMVNResidueGas, 14 | ConditionalBackboneMVNGlobular, 15 | ) 16 | from chroma.layers.structure.rmsd import BackboneRMSD 17 | 18 | 19 | @pytest.fixture(params=["brownian", "globular", "residue_gas"]) 20 | def noise(request): 21 | covariance_model = request.param 22 | if covariance_model in ["brownian", "globular"]: 23 | return BackboneMVNGlobular( 24 | covariance_model=covariance_model, complex_scaling=True, 25 | ) 26 | else: 27 | return BackboneMVNResidueGas( 28 | covariance_model=covariance_model, complex_scaling=True, 29 | ) 30 | 31 | 32 | @pytest.fixture(params=["real", "synthetic"]) 33 | def XCS(request): 34 | xcs_type = request.param 35 | if xcs_type == "real": 36 | repo = Path(chroma.__file__).parent.parent 37 | test_cif = str(Path(repo, "tests", "resources", "6wgl.cif")) 38 | X, C, S = Protein(test_cif).to_XCS() 39 | else: 40 | num_batch, num_residues = 5, 100 41 | X = 10 * torch.randn([num_batch, num_residues * 4, 3]) 42 | C = torch.ones([num_batch, num_residues]) 43 | S = C.clone() 44 | return X, C, S 45 | 46 | 47 | def test_full_covariance_and_sqrt_covariance_computation(): 48 | num_batch, num_residues = 1, 100 49 | X = 10 * torch.randn([num_batch, num_residues, 4, 3]) 50 | C = torch.ones([num_batch, num_residues]) 51 | S = C.clone() 52 | D = torch.randint(low=0, high=2, size=(C.size())) 53 | 54 | # Fill in missing pieces 55 | X = impute_masked_X(X, C) 56 | C = torch.abs(C) 57 | 58 | mvn = BackboneMVNGlobular(covariance_model="globular", complex_scaling=True,) 59 | cmvn = ConditionalBackboneMVNGlobular( 60 | covariance_model="globular", complex_scaling=True, X=X, C=C, D=D 61 | ) 62 | 63 | # Test R 64 | Z = torch.randn_like(X).reshape(X.shape[0], -1, 3) 65 | RZ_mvn_implicit = mvn._multiply_R(Z, C) 66 | RZ_mvn_dense = (cmvn.R @ Z).reshape(RZ_mvn_implicit.shape) 67 | assert torch.allclose(RZ_mvn_implicit, RZ_mvn_dense, atol=1e-2) 68 | 69 | # Test RRt 70 | RRt_Z_implicit = mvn.multiply_covariance(Z, C) 71 | RRt_Z_dense = cmvn.RRt @ Z 72 | assert torch.allclose(RRt_Z_implicit, RRt_Z_dense, atol=1e-2) 73 | 74 | 75 | def test_invertibility_R(noise, XCS): 76 | """Test invertibility of the covariance square root.""" 77 | X, C, S = XCS 78 | X = X.reshape([X.shape[0], -1, 3]) 79 | 80 | Ri_X = noise._multiply_R_inverse(X, C) 81 | R_Ri_X = noise._multiply_R(Ri_X, C) 82 | 83 | Rti_X = noise._multiply_R_inverse_transpose(X, C) 84 | Rt_Rti_X = noise._multiply_R_transpose(Rti_X, C) 85 | X = X.reshape(X.shape[0], C.shape[1], -1, 3) 86 | 87 | Ri_X = Ri_X.reshape(X.shape) 88 | R_Ri_X = R_Ri_X.reshape(X.shape) 89 | Rti_X = Rti_X.reshape(X.shape) 90 | Rt_Rti_X = Rt_Rti_X.reshape(X.shape) 91 | 92 | if False: 93 | from chroma.layers.structure.diffusion import _debug_viz_XZC 94 | 95 | _debug_viz_XZC(X, Ri_X, C) 96 | 97 | assert torch.allclose(X, R_Ri_X, atol=1e-2) 98 | assert torch.allclose(X, Rt_Rti_X, atol=1e-2) 99 | assert not torch.allclose(Ri_X, R_Ri_X, atol=1e-2) 100 | assert not torch.allclose(Rti_X, Rt_Rti_X, atol=1e-2) 101 | 102 | 103 | def test_invertibility_covariance(noise, XCS, debug=False): 104 | """Test invertibility of the covariance matrix. 
105 | 106 | Note: the covariance matrix is poorly conditioned for all but 107 | the smallest systems, so for numerical verification the system needs 108 | to be small with a large tolerance. 109 | """ 110 | X, C, S = XCS 111 | 112 | # Cycle constraint 113 | Ci_X = noise.multiply_inverse_covariance(X, C) 114 | C_Ci_X = noise.multiply_covariance(Ci_X, C) 115 | 116 | if debug and not torch.allclose(X, C_Ci_X, atol=1e-1): 117 | from matplotlib import pyplot as plt 118 | 119 | plt.figure() 120 | plt.subplot(3, 1, 1) 121 | plt.plot((X - C_Ci_X).data.numpy().flatten(), ".") 122 | plt.subplot(3, 1, 2) 123 | plt.plot(X.data.numpy().flatten()) 124 | plt.subplot(3, 1, 3) 125 | plt.plot(C.data.numpy().flatten()) 126 | plt.savefig(f"test_icov.pdf") 127 | 128 | assert torch.allclose(X, C_Ci_X, atol=1e-1) 129 | assert not torch.allclose(Ci_X, C_Ci_X, atol=1e-1) 130 | 131 | 132 | def test_log_determinant(noise): 133 | """Test log determinant of the covariance matrix.""" 134 | X, C, S = Protein("5imm").to_XCS() 135 | 136 | X = X[0:1, ...] 137 | C = C[0:1, ...] 138 | C = torch.abs(C) 139 | X = impute_masked_X(X, C) 140 | 141 | if hasattr(noise, "covariance_model"): 142 | # Use the conditional covariance model to build a dense RRt 143 | cmvn = ConditionalBackboneMVNGlobular( 144 | covariance_model=noise.covariance_model, 145 | complex_scaling=noise.complex_scaling, 146 | X=X, 147 | C=C, 148 | D=C.ne(0).float(), 149 | ) 150 | 151 | R, RRt = cmvn._materialize_RRt(C) 152 | R = R.data.numpy() 153 | logdet_dense = 3.0 * np.linalg.slogdet(R)[1] 154 | logdet = noise.log_determinant(C) 155 | 156 | assert logdet.item() == pytest.approx(logdet_dense.item()) 157 | 158 | 159 | def test_cmvn(noise): 160 | if isinstance(noise, BackboneMVNResidueGas): 161 | pass 162 | else: 163 | aligner = BackboneRMSD() 164 | protein = Protein("1drf") 165 | X, C, S = protein.to_XCS() 166 | protein.sys.save_selection(gti=list(range(14)), selname="clamp") 167 | cmvn = ConditionalBackboneMVNGlobular( 168 | covariance_model=noise.covariance_model, 169 | complex_scaling=noise.complex_scaling, 170 | X=X, 171 | C=C, 172 | D=protein.get_mask("namesel clamp"), 173 | ) 174 | 175 | X_sample = cmvn.sample() 176 | _, rmsd = aligner.align(X, X_sample, protein.get_mask("namesel clamp")) 177 | assert rmsd.item() < 1e-1 178 | -------------------------------------------------------------------------------- /tests/layers/structure/test_optimal_transport.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from chroma.layers.structure.optimal_transport import ( 5 | optimize_couplings_gw, 6 | optimize_couplings_sinkhorn, 7 | ) 8 | 9 | 10 | # test sinkhorn 11 | def test_sinkhorn(): 12 | C = torch.Tensor([[[1, 0, 0], [0, 0, 1], [0, 1, 0]]]) 13 | assert torch.allclose( 14 | optimize_couplings_sinkhorn(C).argmin(-1), torch.LongTensor([[0, 2, 1]]) 15 | ) 16 | 17 | 18 | def test_gw(): 19 | # TODO: need a nontrivial test 20 | seed1 = torch.randn(4).abs() 21 | adj1 = torch.outer(seed1, seed1) 22 | 23 | Da = torch.stack([adj1, adj1]) 24 | Db = torch.stack([adj1, adj1]) 25 | 26 | optimize_couplings_gw(Da, Db, scale=2) 27 | -------------------------------------------------------------------------------- /tests/layers/structure/test_potts.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from itertools import product 3 | 4 | import numpy as np 5 | import pytest 6 | import torch 7 | from torch.nn import functional as F 8 | 9 
| from chroma.layers.structure.potts import ( 10 | GraphPotts, 11 | compute_potts_energy, 12 | fold_symmetry, 13 | sample_potts, 14 | ) 15 | 16 | 17 | def test_graphpotts(): 18 | # Testing symmetry 19 | # Create non-symmetric Potts model and symmetrize using serial or not 20 | potts = GraphPotts(128, 128, 20, symmetric_J=False) 21 | 22 | node_h = torch.rand(1, 3, 128) 23 | edge_h = torch.rand(1, 3, 2, 128) 24 | edge_idx = torch.tensor([[[1, 2], [0, 2], [0, 1]]]) 25 | mask_i = torch.ones(1, 3) 26 | mask_ij = torch.ones(1, 3, 2) 27 | 28 | h, J = potts(node_h, edge_h, edge_idx, mask_i, mask_ij) 29 | 30 | assert ( 31 | potts._symmetrize_J(J, edge_idx, mask_ij) 32 | != potts._symmetrize_J_serial(J, edge_idx, mask_ij) 33 | ).sum().detach().numpy() == 0 34 | 35 | mask_ij = torch.tensor([[[1, 1], [1, 0], [1, 0]]]) 36 | h, J = potts(node_h, edge_h, edge_idx, mask_i, mask_ij) 37 | 38 | assert ( 39 | potts._symmetrize_J(J, edge_idx, mask_ij) 40 | != potts._symmetrize_J_serial(J, edge_idx, mask_ij) 41 | ).sum().detach().numpy() == 0 42 | 43 | 44 | def test_symmetry_folding(): 45 | N, Q = 12, 3 46 | symmetry_order = 3 47 | N_au = N // symmetry_order 48 | 49 | # Testing symmetry 50 | mask_i = torch.ones(1, N) 51 | mask_ij = (1.0 - torch.eye(N))[None, ...] 52 | h = torch.randn([1, N, Q]) 53 | J = torch.randn([1, N, N, Q, Q]) 54 | J = J + J.permute([0, 2, 1, 4, 3]) 55 | # J = torch.eye(Q)[None,None,None,...].expand([1, N, N, Q, Q]) 56 | J = J * mask_ij[..., None, None] 57 | edge_idx = torch.arange(N).long()[None, None, :].expand([1, N, N]) 58 | 59 | h_fold, J_fold, edge_idx_fold, mask_i_fold, mask_ij_fold = fold_symmetry( 60 | symmetry_order, h, J, edge_idx, mask_i, mask_ij, normalize=False 61 | ) 62 | # Validate dimensions 63 | assert tuple(h_fold.shape) == (1, N_au, Q) 64 | assert tuple(J_fold.shape) == (1, N_au, N_au, Q, Q) 65 | assert tuple(edge_idx_fold.shape) == (1, N_au, N_au) 66 | assert tuple(mask_i_fold.shape) == (1, N_au) 67 | assert tuple(mask_ij_fold.shape) == (1, N_au, N_au) 68 | 69 | # Does the folded Potts model return same energies as full? 
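# NOTE (annotation added for clarity, not an original assertion): the block
# below checks the identity that symmetry folding is meant to preserve. For a
# tiled sequence S = (s, s, ..., s) with `symmetry_order` copies, the full
# Potts energy
#     U(S) = sum_i h[i, S_i] + (1/2) * sum_{i != j} J[i, j, S_i, S_j]
# should equal the energy of the asymmetric-unit sequence s under the folded
# parameters (h_fold, J_fold), since fold_symmetry(..., normalize=False)
# accumulates h across symmetric copies and J across all pairs of copies.
# The exact double-counting convention is whatever compute_potts_energy
# implements; the equation above is a sketch of the invariant being tested.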
70 | S_test_fold = torch.randint(high=Q, size=[1, N_au]) 71 | S_test = S_test_fold[:, None, :].expand([1, symmetry_order, N_au]).reshape([1, N]) 72 | 73 | U, U_i = compute_potts_energy(S_test, h, J, edge_idx) 74 | U_fold, U_i_fold = compute_potts_energy(S_test_fold, h_fold, J_fold, edge_idx_fold) 75 | 76 | assert torch.allclose(U, U_fold) 77 | 78 | 79 | @pytest.mark.parametrize("proposal", ["dlmc", "chromatic"]) 80 | def test_potts_mcmc(proposal, debug=False): 81 | """MCMC test for Chromatic Gibbs sampling.""" 82 | # Build a test, fully connected Potts model 83 | if debug: 84 | # Heavy duty sampling with large state space 85 | N = 5 86 | q = 4 87 | num_sweeps = 1000 88 | num_chains = 1000 89 | rtol = 0.05 90 | else: 91 | # Quick and dirty small state space 92 | N = 3 93 | q = 3 94 | num_sweeps = 200 95 | num_chains = 1000 96 | rtol = 0.1 97 | 98 | beta = 0.1 99 | warmup_fraction = 0.1 100 | 101 | torch.manual_seed(1) 102 | mask_i = torch.ones([1, N]).float() 103 | mask_ij = (1 - torch.eye(N))[None, ...].float() 104 | edge_idx = torch.arange(N)[None, None, :].expand([1, N, N]) 105 | 106 | h = beta * torch.randn([1, N, q]) 107 | J = beta * torch.randn([1, N, N, q, q]) 108 | J = mask_ij[..., None, None] * (J + J.permute([0, 2, 1, 4, 3])) / np.sqrt(2) 109 | 110 | # Enumerate all of sequence space 111 | alphabet = "ABCDEFGHIJK"[:q] 112 | sequences = ["".join(x) for x in product(alphabet, repeat=N)] 113 | S_exact = torch.Tensor( 114 | [[alphabet.index(s) for s in seq] for seq in sequences] 115 | ).long() 116 | print(f"Enumerated {len(sequences)} sequences") 117 | 118 | if torch.cuda.is_available(): 119 | device = "cuda" 120 | h = h.to(device) 121 | J = J.to(device) 122 | edge_idx = edge_idx.to(device) 123 | mask_i = mask_i.to(device) 124 | mask_ij = mask_ij.to(device) 125 | S_exact = S_exact.to(device) 126 | 127 | # Compute exact distribution over sequence space 128 | B = S_exact.shape[0] 129 | h_expand = h.expand([B, -1, -1]) 130 | J_expand = J.expand([B, -1, -1, -1, -1]) 131 | edge_idx_expand = edge_idx.expand([B, -1, -1]) 132 | mask_i_expand = mask_i.expand([B, -1]) 133 | mask_ij_expand = mask_ij.expand([B, -1, -1]) 134 | U, _ = compute_potts_energy(S_exact, h_expand, J_expand, edge_idx_expand) 135 | p_exact = F.softmax(-U, -1).tolist() 136 | 137 | # Estimate distribution from sampled sequences 138 | h_expand = h.expand([num_chains, -1, -1]) 139 | J_expand = J.expand([num_chains, -1, -1, -1, -1]) 140 | edge_idx_expand = edge_idx.expand([num_chains, -1, -1]) 141 | mask_i_expand = mask_i.expand([num_chains, -1]) 142 | mask_ij_expand = mask_ij.expand([num_chains, -1, -1]) 143 | 144 | S, U, S_trajectory, U_trajectory = sample_potts( 145 | h_expand, 146 | J_expand, 147 | edge_idx_expand, 148 | mask_i_expand, 149 | mask_ij_expand, 150 | num_sweeps=num_sweeps, 151 | proposal=proposal, 152 | rejection_step=True, 153 | verbose=True, 154 | return_trajectory=True, 155 | ) 156 | if warmup_fraction is not None: 157 | S_trajectory = S_trajectory[int(warmup_fraction * len(S_trajectory)) :] 158 | 159 | S_samples = torch.cat(S_trajectory, 0) 160 | U_trajectory = torch.stack(U_trajectory, 1).cpu().data.numpy() 161 | S_samples = S_samples.cpu().data.numpy() 162 | sample_counts = Counter(["".join([alphabet[c] for c in s]) for s in S_samples]) 163 | p_sample = [sample_counts[seq] / S_samples.shape[0] for seq in sequences] 164 | 165 | if debug: 166 | from matplotlib import pyplot as plt 167 | 168 | plt.figure(figsize=(6, 3)) 169 | plt.subplot(1, 2, 1) 170 | plt.plot(p_exact, p_sample, "k.") 171 | plt.grid() 172 | 
plt.axis("square") 173 | plt.xlabel("Probability, exact enumeration") 174 | plt.ylabel("Sampling frequency (MCMC)") 175 | plt.title(f"Random Potts model over {q}^{N} sequences") 176 | plt.subplot(1, 2, 2) 177 | plt.plot(U_trajectory[0, :]) 178 | plt.xlabel("Iterations") 179 | plt.ylabel("Energy") 180 | plt.tight_layout() 181 | plt.show() 182 | 183 | # The frequencies of states visited via MCMC should reproduce their 184 | # exact probabilities (via enumeration) within rtol percent error 185 | assert np.allclose(p_sample, p_exact, rtol=rtol) 186 | 187 | 188 | def debug_potts_2D(): 189 | """Debug test for Potts model""" 190 | N = 100 191 | q = 4 192 | 193 | num_sites = N * N 194 | mask_i = torch.ones([1, num_sites]).float() 195 | ix = torch.arange(num_sites).long() 196 | 197 | # Build 2D lattice topology 198 | edge_idx = torch.stack([ix + 1, ix - 1, ix + N, ix - N], -1) 199 | mask_ij = torch.ones_like(edge_idx).float()[None, :, :] 200 | edge_idx = torch.remainder(edge_idx, num_sites)[None, :, :].long() 201 | 202 | # Ferromagnetic parameters 203 | h = torch.zeros([1, num_sites, q]) 204 | h[:, :, 0] = h[:, :, 0] 205 | mask_J = mask_ij[:, :, :, None, None] * torch.eye(q)[None, None, None, :, :] 206 | 207 | if torch.cuda.is_available(): 208 | device = "cuda" 209 | h = h.to(device) 210 | edge_idx = edge_idx.to(device) 211 | mask_J = mask_J.to(device) 212 | mask_ij = mask_ij.to(device) 213 | 214 | import numpy as np 215 | from matplotlib import pyplot as plt 216 | from matplotlib.animation import FuncAnimation 217 | 218 | temp_range = (1.2, 0.8) 219 | plt.figure(figsize=(5, 5), dpi=600) 220 | _, _, S_trajectory, U_trajectory = sample_potts( 221 | h, 222 | -mask_J, 223 | edge_idx, 224 | mask_i, 225 | mask_ij, 226 | num_sweeps=10000, 227 | verbose=True, 228 | return_trajectory=True, 229 | S=None, 230 | annealing_fraction=1.0, 231 | temperature_init=1.2, 232 | temperature=0.8, 233 | ) 234 | 235 | # Define a function to update the plot for each frame 236 | num_frames = len(S_trajectory) 237 | temps = np.linspace(temp_range[0], temp_range[1], len(S_trajectory)) 238 | betas = 1.0 / temps 239 | 240 | def update(frame): 241 | plt.clf() # Clear the previous frame 242 | plt.pcolor(S_trajectory[frame].cpu().data.numpy().reshape([N, N]), cmap="tab10") 243 | plt.clim([0, 10]) 244 | plt.axis("square") 245 | plt.axis("off") 246 | plt.title(f"Beta = {betas[frame]:0.2f}") 247 | print(frame) 248 | 249 | # Create a figure and set the number of frames 250 | fig = plt.figure(figsize=(4, 4), dpi=300) 251 | animation = FuncAnimation(fig, update, frames=num_frames, interval=1000 / 60) 252 | animation.save("potts.mp4", writer="ffmpeg") 253 | return 254 | -------------------------------------------------------------------------------- /tests/layers/structure/test_protein_graph.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from chroma.data import Protein 7 | from chroma.layers.structure.backbone import RigidTransformer 8 | from chroma.layers.structure.protein_graph import ProteinFeatureGraph 9 | 10 | 11 | def test_protein_feature_graph(): 12 | torch.manual_seed(10) 13 | 14 | dim_nodes, dim_edges = 128, 64 15 | num_neighbors = 30 16 | feature_graph = ProteinFeatureGraph( 17 | dim_nodes=dim_nodes, 18 | dim_edges=dim_edges, 19 | node_features=(("internal_coords", {"log_lengths": True}),), 20 | edge_features=( 21 | "distances_6mer", 22 | "distances_2mer", 23 | "orientations_2mer", 24 |
"distances_chain", 25 | "orientations_chain", 26 | "position_2mer", 27 | ), 28 | num_neighbors=num_neighbors, 29 | graph_kwargs={"mask_interfaces": False, "cutoff": None}, 30 | ) 31 | 32 | X, C, S = Protein("5imm").to_XCS() 33 | 34 | node_h, edge_h, edge_idx, mask_i, mask_ij = feature_graph(X, C) 35 | num_nodes = X.shape[1] 36 | 37 | # Test shapes 38 | assert node_h.shape == (1, num_nodes, dim_nodes) 39 | assert edge_h.shape == (1, num_nodes, num_neighbors, dim_edges) 40 | assert edge_idx.shape == (1, num_nodes, num_neighbors) 41 | assert mask_i.shape == (1, num_nodes) 42 | assert mask_ij.shape == (1, num_nodes, num_neighbors) 43 | 44 | # Test masks 45 | masked_sum_i = torch.abs((1.0 - mask_i).unsqueeze(-1) * node_h).sum() 46 | masked_sum_ij = torch.abs((1.0 - mask_ij).unsqueeze(-1) * edge_h).sum() 47 | assert masked_sum_i == 0 48 | assert masked_sum_ij == 0 49 | 50 | transformer = RigidTransformer(center_rotation=False) 51 | q_rotate = torch.Tensor([0.0, 0.1, -1.2, 0.5]).unsqueeze(0) 52 | dX_rotate = torch.Tensor([0.0, 1.0, -1.0]).unsqueeze(0) 53 | _rotate = lambda X_input: transformer(X_input, dX_rotate, q_rotate) 54 | 55 | # Test feature invariance to rotation and translation 56 | X_transformed = _rotate(X) 57 | node_h_r, edge_h_r, edge_idx_r, mask_i_r, mask_ij_r = feature_graph( 58 | X_transformed, C 59 | ) 60 | assert not torch.allclose(X, X_transformed, atol=1e-3) 61 | assert torch.allclose(node_h_r, node_h, atol=1e-3) 62 | assert torch.allclose(edge_h_r, edge_h, atol=1e-3) 63 | assert torch.allclose(edge_idx, edge_idx_r, atol=1e-3) 64 | 65 | 66 | def test_masked_interfaces(): 67 | torch.manual_seed(10) 68 | dim_nodes, dim_edges = 128, 64 69 | num_neighbors = 30 70 | feature_graph = ProteinFeatureGraph( 71 | dim_nodes=dim_nodes, 72 | dim_edges=dim_edges, 73 | node_features=(("internal_coords", {"log_lengths": True}),), 74 | edge_features=( 75 | ("distances_6mer", {"require_contiguous": True}), 76 | "distances_2mer", 77 | "orientations_2mer", 78 | "distances_chain", 79 | "orientations_chain", 80 | ), 81 | num_neighbors=num_neighbors, 82 | graph_kwargs={"mask_interfaces": True, "cutoff": None}, 83 | ) 84 | 85 | X, C, S = Protein("5imm").to_XCS() 86 | 87 | node_h, edge_h, edge_idx, mask_i, mask_ij = feature_graph(X, C) 88 | num_nodes = X.shape[1] 89 | 90 | # Test feature invariance to *single-chain* rotation and translation 91 | rigid_transformer = RigidTransformer(center_rotation=False) 92 | 93 | chain_mask = (C == 2).type(torch.float32) 94 | q = torch.Tensor([2.0, 1.0, 1.0, 0.5]).unsqueeze(0) 95 | dX = torch.Tensor([1.0, 2.0, -1.0]).unsqueeze(0) 96 | X_transformed = rigid_transformer(X, dX, q, mask=chain_mask) 97 | 98 | node_h_r, edge_h_r, edge_idx_r, mask_i_r, mask_ij_r = feature_graph( 99 | X_transformed, C 100 | ) 101 | 102 | assert torch.allclose(node_h_r, node_h, atol=1e-3) 103 | assert torch.allclose(edge_h_r, edge_h, atol=1e-3) 104 | assert torch.allclose(edge_idx, edge_idx_r, atol=1e-3) 105 | -------------------------------------------------------------------------------- /tests/layers/structure/test_rmsd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | 5 | from chroma.data import Protein 6 | from chroma.layers.structure.backbone import RigidTransformer 7 | from chroma.layers.structure.rmsd import ( 8 | BackboneRMSD, 9 | CrossRMSD, 10 | LossFragmentPairRMSD, 11 | LossFragmentRMSD, 12 | LossNeighborhoodRMSD, 13 | ) 14 | 15 | 16 | @pytest.fixture 17 | def backbones(): 18 | bb1 = 
torch.tensor( 19 | [ 20 | -5.68175, 21 | -2.183, 22 | 3.27979, 23 | -4.82875, 24 | -3.256, 25 | 2.79379, 26 | -3.34475, 27 | -2.899, 28 | 2.79579, 29 | -2.51375, 30 | -3.697, 31 | 3.21979, 32 | -3.01675, 33 | -1.713, 34 | 2.29979, 35 | -1.62875, 36 | -1.289, 37 | 2.22979, 38 | -0.95775, 39 | -1.094, 40 | 3.58379, 41 | 0.20325, 42 | -1.46, 43 | 3.75679, 44 | -1.69375, 45 | -0.547, 46 | 4.54479, 47 | -1.16675, 48 | -0.358, 49 | 5.88579, 50 | -0.94175, 51 | -1.732, 52 | 6.50679, 53 | 0.03125, 54 | -1.943, 55 | 7.23679, 56 | ] 57 | ).reshape(-1, 12, 3) 58 | 59 | bb2 = torch.tensor( 60 | [ 61 | 3.91725, 62 | 1.271, 63 | -1.22921, 64 | 3.22825, 65 | 0.099, 66 | -1.74321, 67 | 2.09025, 68 | 0.535, 69 | -2.66521, 70 | 1.91025, 71 | -0.018, 72 | -3.74821, 73 | 1.34825, 74 | 1.553, 75 | -2.23921, 76 | 0.24325, 77 | 2.085, 78 | -3.02321, 79 | 0.76425, 80 | 2.518, 81 | -4.38621, 82 | 0.10225, 83 | 2.315, 84 | -5.41221, 85 | 1.96925, 86 | 3.085, 87 | -4.38121, 88 | 2.61525, 89 | 3.562, 90 | -5.59821, 91 | 3.27725, 92 | 2.453, 93 | -6.40421, 94 | 4.07425, 95 | 2.713, 96 | -7.30321, 97 | ] 98 | ).reshape(-1, 12, 3) 99 | return bb1, bb2 100 | 101 | 102 | def test_pairedRMSD(backbones): 103 | bb1, bb2 = backbones 104 | cross_rmsd = CrossRMSD() 105 | predicted_rmsd = cross_rmsd.pairedRMSD(bb1, bb2) 106 | assert torch.isclose(predicted_rmsd, torch.tensor(0.3542), rtol=1e-3) 107 | 108 | 109 | def test_pairedRMSD_symeig(backbones): 110 | bb1, bb2 = backbones 111 | cross_rmsd = CrossRMSD(method="symeig") 112 | predicted_rmsd = cross_rmsd.pairedRMSD(bb1, bb2) 113 | assert torch.isclose(predicted_rmsd, torch.tensor(0.35), rtol=1e1) 114 | 115 | 116 | def test_sample(backbones): 117 | bb1, bb2 = backbones 118 | 119 | cross_rmsd = CrossRMSD() 120 | input_x = torch.cat([bb1, bb1]) 121 | predicted = cross_rmsd(input_x, input_x) 122 | 123 | assert predicted.shape == (input_x.shape[0], input_x.shape[0]) 124 | assert torch.allclose(predicted, torch.zeros_like(predicted), atol=1e-1) 125 | 126 | predicted = cross_rmsd(bb1, bb2) 127 | assert all(torch.isclose(predicted, torch.tensor(0.35), rtol=1e-1)) 128 | 129 | 130 | def test_sample_symeigh(backbones): 131 | bb1, bb2 = backbones 132 | cross_rmsd = CrossRMSD(method="symeig") 133 | input_x = torch.cat([bb1, bb1]) 134 | 135 | predicted = cross_rmsd(input_x, input_x) 136 | assert torch.allclose(predicted, torch.zeros_like(predicted), atol=1e-2) 137 | 138 | 139 | def test_backbone_rmsd(backbones): 140 | bb1, bb2 = backbones 141 | for method in ["symeig", "power"]: 142 | backbone_rmsd = BackboneRMSD(method=method) 143 | 144 | X, C, S = Protein("5imm").to_XCS() 145 | 146 | rigid_transformer = RigidTransformer() 147 | dX = torch.Tensor([[1, 4, 2]]) 148 | q = torch.Tensor([[0.5, 1, 0, 1]]) 149 | X_transform = rigid_transformer(X, dX, q) 150 | X_transform_aligned, rmsd = backbone_rmsd.align( 151 | X_transform, X, C, align_unmasked=True 152 | ) 153 | 154 | assert not torch.allclose(X, X_transform, atol=1e-2) 155 | assert torch.allclose(X, X_transform_aligned, atol=1e-2) 156 | assert rmsd < 1e-2 157 | 158 | 159 | def test_fragment_rmsd(debug=False): 160 | X, C, S = Protein("1SHG").to_XCS() 161 | 162 | loss_frags = LossFragmentRMSD() 163 | 164 | X_noise = X + torch.randn_like(X) 165 | rmsd = loss_frags(X, X, C) 166 | rmsd_noised = loss_frags(X_noise, X, C) 167 | assert rmsd.mean() < 1e-2 168 | assert rmsd_noised.mean() > 1.0 169 | 170 | if debug: 171 | from chroma.layers.structure import diffusion 172 | 173 | noise = diffusion.DiffusionChainCov(complex_scaling=True) 174 | X_noise = 
noise(X, C, t=0.6) 175 | 176 | rmsd, X_frag_target, X_frag_mobile, X_frag_mobile_align = loss_frags( 177 | X_noise, X, C, return_coords=True 178 | ) 179 | print(rmsd) 180 | 181 | def _trajectory(X_frags): 182 | B, I, _, _ = list(X_frags.shape) 183 | X_frags = X_frags.reshape([B * I, -1, 4, 3]) 184 | X_trajectory = [X_t[None, ...] for X_t in X_frags.unbind(0)] 185 | return X_trajectory 186 | 187 | C = torch.ones([1, loss_frags.k]) 188 | X_trajectory_1 = _trajectory(X_frag_target) 189 | X_trajectory_2 = _trajectory(X_frag_mobile_align) 190 | X_trajectory_3 = _trajectory(X_frag_mobile) 191 | # Fight pymol confusion 192 | index = 10 193 | X_trajectory_3 = [X_trajectory_1[index]] + X_trajectory_3 194 | X_trajectory_2 = [X_trajectory_1[index]] + X_trajectory_2 195 | X_trajectory_1 = [X_trajectory_1[index]] + X_trajectory_1 196 | Protein.from_XCS_trajectory(X_trajectory_1, C, 0.0 * C).to_CIF( 197 | "X_frag_target.cif" 198 | ) 199 | Protein.from_XCS_trajectory(X_trajectory_2, C, 0.0 * C).to_CIF( 200 | "X_frag_noise_aligned.cif" 201 | ) 202 | Protein.from_XCS_trajectory(X_trajectory_3, C, 0.0 * C).to_CIF( 203 | "X_frag_noise.cif" 204 | ) 205 | return 206 | 207 | 208 | def test_fragment_pair_rmsd(debug=False): 209 | X, C, S = Protein("1SHG").to_XCS() 210 | 211 | loss_pairs = LossFragmentPairRMSD() 212 | 213 | X_noise = X + torch.randn_like(X) 214 | rmsd, mask_ij = loss_pairs(X, X, C) 215 | rmsd_noised, mask_ij = loss_pairs(X_noise, X, C) 216 | assert rmsd.mean() < 1e-2 217 | assert rmsd_noised.mean() > 1.0 218 | 219 | if debug: 220 | from chroma.layers.structure import diffusion 221 | 222 | noise = diffusion.DiffusionChainCov(complex_scaling=True) 223 | X_noise = noise(X, C, t=0.6) 224 | 225 | rmsd, mask_ij, X_pair_target, X_pair_mobile, X_pair_mobile_align = loss_pairs( 226 | X_noise, X, C, return_coords=True 227 | ) 228 | print(rmsd) 229 | 230 | def _trajectory(X_pairs): 231 | B, I, J, _, _ = list(X_pairs.shape) 232 | X_pairs = X_pairs.reshape([B * I * J, -1, 4, 3]) 233 | X_trajectory = [X_t[None, ...] 
for X_t in X_pairs.unbind(0)] 234 | return X_trajectory 235 | 236 | C = torch.cat( 237 | [torch.ones([1, loss_pairs.k]), 2 * torch.ones([1, loss_pairs.k])], -1 238 | ) 239 | X_trajectory_1 = _trajectory(X_pair_target) 240 | X_trajectory_2 = _trajectory(X_pair_mobile_align) 241 | X_trajectory_3 = _trajectory(X_pair_mobile) 242 | # Fight pymol confusion 243 | index = 1579 244 | X_trajectory_3 = [X_trajectory_1[index]] + X_trajectory_3 245 | X_trajectory_2 = [X_trajectory_1[index]] + X_trajectory_2 246 | X_trajectory_1 = [X_trajectory_1[index]] + X_trajectory_1 247 | Protein.from_XCS_trajectory(X_trajectory_1, C, 0.0 * C).to_CIF( 248 | "X_pair_target.cif" 249 | ) 250 | Protein.from_XCS_trajectory(X_trajectory_2, C, 0.0 * C).to_CIF( 251 | "X_pair_noise_aligned.cif" 252 | ) 253 | Protein.from_XCS_trajectory(X_trajectory_3, C, 0.0 * C).to_CIF( 254 | "X_pair_noise.cif" 255 | ) 256 | return 257 | 258 | 259 | def test_neighborhood_rmsd(debug=False): 260 | X, C, S = Protein("1SHG").to_XCS() 261 | 262 | loss_nb = LossNeighborhoodRMSD() 263 | 264 | X_noise = X + torch.randn_like(X) 265 | rmsd, mask = loss_nb(X, X, C) 266 | rmsd_noised, mask = loss_nb(X_noise, X, C) 267 | assert rmsd.mean() < 1e-2 268 | assert rmsd_noised.mean() > 1.0 269 | 270 | if debug: 271 | from chroma.layers.structure import diffusion 272 | 273 | noise = diffusion.DiffusionChainCov(complex_scaling=True) 274 | X_noise = noise(X, C, t=0.7) 275 | 276 | rmsd, mask, X_nb_target, X_nb_mobile, X_nb_mobile_align = loss_nb( 277 | X_noise, X, C, return_coords=True 278 | ) 279 | print(rmsd, X_nb_target.shape) 280 | 281 | def _trajectory(X_nbs): 282 | B, I, _, _ = list(X_nbs.shape) 283 | X_nbs = X_nbs.reshape([B * I, -1, 4, 3]) 284 | X_trajectory = [X_t[None, ...] for X_t in X_nbs.unbind(0)] 285 | return X_trajectory 286 | 287 | C = torch.ones([1, X_nb_target.shape[2] // 4]) 288 | X_trajectory_1 = _trajectory(X_nb_target) 289 | X_trajectory_2 = _trajectory(X_nb_mobile_align) 290 | X_trajectory_3 = _trajectory(X_nb_mobile) 291 | # Fight pymol confusion 292 | index = 10 293 | X_trajectory_3 = [X_trajectory_1[index]] + X_trajectory_3 294 | X_trajectory_2 = [X_trajectory_1[index]] + X_trajectory_2 295 | X_trajectory_1 = [X_trajectory_1[index]] + X_trajectory_1 296 | Protein.from_XCS_trajectory(X_trajectory_1, C, 0.0 * C).to_CIF( 297 | "X_nb_target.cif" 298 | ) 299 | Protein.from_XCS_trajectory(X_trajectory_2, C, 0.0 * C).to_CIF( 300 | "X_nb_noise_aligned.cif" 301 | ) 302 | Protein.from_XCS_trajectory(X_trajectory_3, C, 0.0 * C).to_CIF("X_nb_noise.cif") 303 | return 304 | -------------------------------------------------------------------------------- /tests/layers/structure/test_sidechain.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | 7 | from chroma import constants 8 | from chroma.data import Protein 9 | from chroma.layers.structure import backbone, sidechain 10 | 11 | 12 | class TestSideChain(TestCase): 13 | def setUp(self): 14 | self.builder = sidechain.SideChainBuilder() 15 | self.chi_angles = sidechain.ChiAngles() 16 | self.rmsd_loss = sidechain.LossSideChainRMSD() 17 | self.clash_loss = sidechain.LossSidechainClashes() 18 | self.frame_loss = sidechain.LossFrameAlignedGraph(distance_eps=1e-9) 19 | self.distance_loss = sidechain.LossAllAtomDistances() 20 | self.frame_builder = sidechain.AllAtomFrameBuilder() 21 | 22 | pdb_id = "1SHG" 23 | self.X, self.C, self.S = Protein(pdb_id).to_XCS(all_atom=True) 24 | 
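# NOTE (annotation added for clarity): the round-trip tests below follow this
# measure/rebuild cycle, sketched with the same calls used in the test body
# (shapes follow the repo's X[batch, residue, atom, xyz] convention):
#
#     chi, mask_chi = self.chi_angles(X, C, S)            # measure chi1..chi4
#     X_rebuilt, mask_X = self.builder(X[:, :, :4, :], C, S, chi)
#
# Rebuilding from measured chi and re-measuring should be a fixed point:
# ideal-geometry side-chain placement is fully determined by the backbone,
# the sequence S, and the chi angles, so coordinates and angles are compared
# up to small numerical tolerances.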
25 | def test_chi_cartesian_round_trip(self): 26 | X, C, S = self.X, self.C, self.S 27 | 28 | X_bb = X[:, :, :4, :] 29 | chi, mask_chi = self.chi_angles(X, C, S) 30 | X_reference, mask_X = self.builder(X_bb, C, S, chi) 31 | 32 | # Test round trip processing 33 | chi_direct, _ = self.chi_angles(X_reference, C, S) 34 | X_cycle, _ = self.builder(X_bb, C, S, chi_direct) 35 | chi_cycle, _ = self.chi_angles(X_cycle, C, S) 36 | 37 | _embed = lambda a: torch.stack([torch.cos(a), torch.sin(a)], -1) 38 | 39 | self.assertTrue(torch.allclose(X_reference, X_cycle, atol=1e-1)) 40 | self.assertTrue(torch.allclose(_embed(chi), _embed(chi_cycle), atol=1e-2)) 41 | 42 | loss = self.rmsd_loss(X, X_cycle, C, S) 43 | loss = self.clash_loss(X_cycle, C, S) 44 | 45 | def test_integration(self): 46 | num_letters = 20 47 | chi = np.pi * torch.rand([1, num_letters, 4]) 48 | 49 | X_bb = backbone.ProteinBackbone(num_letters, init_state="beta")() 50 | S = torch.arange(num_letters).unsqueeze(0) 51 | C = torch.ones_like(S) 52 | X, mask_X = self.builder(X_bb, C, S, chi) 53 | chi, mask_chi = self.chi_angles(X, C, S) 54 | 55 | self.assertTrue( 56 | np.allclose( 57 | mask_X.sum([-1, -2]).data.numpy(), np.asarray(constants.AA20_NUM_ATOMS) 58 | ) 59 | ) 60 | self.assertTrue( 61 | np.allclose( 62 | mask_chi.sum(-1).data.numpy(), np.asarray(constants.AA20_NUM_CHI) 63 | ) 64 | ) 65 | 66 | def test_frame_builder_round_trip(self): 67 | X, C, S = self.X, self.C, self.S 68 | 69 | x, q, chi = self.frame_builder.inverse(X, C, S) 70 | X_cycle, mask_atoms = self.frame_builder(x, q, chi, C, S) 71 | 72 | x = x + torch.randn_like(x) * 10.0 73 | q = q + torch.randn_like(q) * 2.0 74 | X_perturb, mask_atoms = self.frame_builder(x, q, chi, C, S) 75 | 76 | mask = (C > 0).float() 77 | 78 | _loss = lambda loss: (loss * mask).sum() / mask.sum() 79 | loss_cycle_avg = _loss(self.frame_loss(X, X_cycle, C, S)) 80 | loss_perturb_avg = _loss(self.frame_loss(X, X_perturb, C, S)) 81 | print(loss_cycle_avg, loss_perturb_avg) 82 | self.assertTrue(loss_cycle_avg.item() < 1.0) 83 | self.assertTrue(loss_perturb_avg.item() > 1.0) 84 | 85 | loss_cycle_avg = _loss(self.distance_loss(X, X_cycle, C, S)) 86 | loss_perturb_avg = _loss(self.distance_loss(X, X_perturb, C, S)) 87 | print(loss_cycle_avg, loss_perturb_avg) 88 | self.assertTrue(loss_cycle_avg.item() < 1.0) 89 | self.assertTrue(loss_perturb_avg.item() > 1.0) 90 | -------------------------------------------------------------------------------- /tests/layers/structure/test_symmetry.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from chroma.layers.structure import backbone, symmetry 5 | 6 | 7 | class Test_symmetry: 8 | @pytest.mark.parametrize("group", ["C_2", "C_4", "D_2", "D_4", "T", "O", "I"]) 9 | def test_point_groups(self, group): 10 | G = symmetry.get_point_group(group) 11 | 12 | # test if the determinants are ones for all rotation matrices 13 | assert torch.allclose(torch.det(G), torch.ones(G.shape[0])) 14 | 15 | # iterate the group multiplication table and check closure under multiplication 16 | for g1 in G: 17 | for g2 in G: 18 | assert (g1 @ g2) in G 19 | 20 | # check identity exists 21 | assert torch.eye(3) in G 22 | 23 | # check inverse is also in G 24 | for g in G: 25 | assert g.inverse() in G 26 | -------------------------------------------------------------------------------- /tests/layers/structure/test_transforms.py: -------------------------------------------------------------------------------- 1 | import pytest 2 
| import torch 3 | 4 | from chroma.layers.structure.geometry import rotations_from_quaternions 5 | from chroma.layers.structure.transforms import ( 6 | average_transforms, 7 | collect_neighbor_transforms, 8 | compose_inner_transforms, 9 | compose_transforms, 10 | compose_translation, 11 | equilibrate_transforms, 12 | fuse_gaussians_isometric_plus_radial, 13 | ) 14 | 15 | 16 | @pytest.fixture 17 | def vec(): 18 | torch.manual_seed(0) 19 | return torch.rand(3) 20 | 21 | 22 | @pytest.fixture 23 | def rotations(): 24 | torch.manual_seed(0) 25 | q = torch.rand(2, 4) 26 | return rotations_from_quaternions(q, normalize=True).unbind() 27 | 28 | 29 | @pytest.fixture 30 | def translations(): 31 | torch.manual_seed(0) 32 | return torch.rand(2, 3).unbind() 33 | 34 | 35 | def test_compose_transforms(vec, rotations, translations): 36 | R_a, R_b = rotations 37 | t_a, t_b = translations 38 | inter = R_b @ vec + t_b 39 | result = R_a @ inter + t_a 40 | R_composed, t_composed = compose_transforms(R_a, t_a, R_b, t_b) 41 | assert torch.allclose(result, R_composed @ vec + t_composed) 42 | 43 | 44 | def test_compose_translation(vec, rotations, translations): 45 | R_a, _ = rotations 46 | t_a, t_b = translations 47 | inter = vec + t_b 48 | result = R_a @ inter + t_a 49 | t_composed = compose_translation(R_a, t_a, t_b) 50 | assert torch.allclose(result, R_a @ vec + t_composed) 51 | 52 | 53 | def test_compose_inner_transforms(vec, rotations, translations): 54 | R_a, R_b = rotations 55 | t_a, t_b = translations 56 | R_a_inv = torch.inverse(R_a) 57 | inter = R_b @ vec + t_b 58 | result = R_a_inv @ (inter - t_a) 59 | R_composed, t_composed = compose_inner_transforms(R_a, t_a, R_b, t_b) 60 | # bump up tolerance because of matrix inversion 61 | assert torch.allclose(result, R_composed @ vec + t_composed, atol=1e-3, rtol=1e-2) 62 | 63 | 64 | def test_fuse_gaussians_isometric_plus_radial(vec): 65 | p_iso = torch.tensor([0.3, 0.7]) 66 | p_rad = torch.zeros_like(p_iso) 67 | x = torch.stack([vec, 2 * vec]) 68 | direction = torch.zeros_like(x) 69 | x_fused, P_fused = fuse_gaussians_isometric_plus_radial( 70 | x, p_iso, p_rad, direction, 0 71 | ) 72 | assert torch.allclose((p_iso[0] + 2 * p_iso[1]) * vec, P_fused @ x_fused) 73 | 74 | 75 | def test_collect_neighbor_transforms(rotations, translations): 76 | R_i = torch.stack(rotations).unsqueeze(0) 77 | t_i = torch.stack(translations).unsqueeze(0) 78 | edge_idx = torch.LongTensor([[1], [0]]).unsqueeze(0) 79 | R_j, t_j = collect_neighbor_transforms(R_i, t_i, edge_idx) 80 | assert torch.allclose(R_j, torch.flip(R_i, [1]).unsqueeze(2)) 81 | assert torch.allclose(t_j, torch.flip(t_i, [1]).unsqueeze(2)) 82 | 83 | 84 | def test_equilibrate_transforms(rotations, translations): 85 | R_i = torch.stack(rotations).unsqueeze(0) 86 | t_i = torch.stack(translations).unsqueeze(0) 87 | R_ji = torch.eye(3).expand(1, 2, 1, 3, 3) 88 | t_ji = torch.zeros(1, 2, 1, 3) 89 | logit_ij = torch.ones(1, 2, 1, 1) 90 | mask_ij = torch.ones(1, 2, 1) 91 | edge_idx = torch.LongTensor([[1], [0]]).unsqueeze(0) 92 | # two transforms on nodes that are each other's neighbors, so a single 93 | # iteration will just swap the transforms 94 | R_eq, t_eq = equilibrate_transforms( 95 | R_i, t_i, R_ji, t_ji, logit_ij, mask_ij, edge_idx, iterations=1 96 | ) 97 | assert torch.allclose(R_eq, torch.flip(R_i, [1]), atol=1e-3, rtol=1e-2) 98 | assert torch.allclose(t_eq, torch.flip(t_i, [1]), atol=1e-3, rtol=1e-2) 99 | # two iterations moves the transforms back to themselves 100 | R_eq, t_eq = equilibrate_transforms( 101 | R_i, 
t_i, R_ji, t_ji, logit_ij, mask_ij, edge_idx, iterations=2 102 | ) 103 | assert torch.allclose(R_eq, R_i, atol=1e-3, rtol=1e-2) 104 | assert torch.allclose(t_eq, t_i, atol=1e-3, rtol=1e-2) 105 | 106 | 107 | def test_average_transforms(rotations, translations): 108 | R = torch.stack([rotations[0], torch.eye(3)]) 109 | t = torch.stack([translations[0], torch.zeros(3)]) 110 | w = torch.ones(2, 2) 111 | mask = torch.ones(2) 112 | # average of a transform with the identity is "half" the transform 113 | R_avg, t_avg = average_transforms(R, t, w, mask, dim=0, dither=False) 114 | R_total_fromavg, _ = compose_transforms( 115 | R_avg, torch.zeros(3), R_avg, torch.zeros(3) 116 | ) 117 | _, t_total_fromavg = compose_transforms(torch.eye(3), t_avg, torch.eye(3), t_avg) 118 | assert torch.allclose(R_total_fromavg, R[0], atol=1e-3, rtol=1e-2) 119 | assert torch.allclose(t_total_fromavg, t[0], atol=1e-3, rtol=1e-2) 120 | -------------------------------------------------------------------------------- /tests/layers/test_basic.py: -------------------------------------------------------------------------------- 1 | import math 2 | from unittest import TestCase 3 | 4 | import pytest 5 | import torch 6 | import torch.nn as nn 7 | 8 | from chroma.layers.basic import ( 9 | MaybeOnehotEmbedding, 10 | MeanEmbedding, 11 | NodeProduct, 12 | NoOp, 13 | OneHot, 14 | PeriodicPositionalEncoding, 15 | PositionalEncoding, 16 | PositionWiseFeedForward, 17 | Transpose, 18 | TriangleMultiplication, 19 | Unsqueeze, 20 | ) 21 | 22 | 23 | class TestBasicLayers(TestCase): 24 | def setUp(self): 25 | self.noop = NoOp() 26 | self.onehot = OneHot(n_tokens=4) 27 | self.transpose = Transpose(1, 2) 28 | self.unsqueeze = Unsqueeze(1) 29 | self.mean_embedding = MeanEmbedding(nn.Embedding(4, 64), use_softmax=False) 30 | self.periodic = PeriodicPositionalEncoding(64) 31 | self.pwff = PositionWiseFeedForward(64, 64) 32 | 33 | def test_noop(self): 34 | x = torch.randn(4, 2, 2) 35 | self.assertTrue((x == self.noop(x)).all().item()) 36 | 37 | def test_onehot(self): 38 | input = torch.tensor([[0, 1, 2], [3, 0, 1]]) 39 | onehot = self.onehot(input).transpose(1, 2) 40 | target = torch.tensor( 41 | [ 42 | [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0]], 43 | [[0, 1, 0], [0, 0, 1], [0, 0, 0], [1, 0, 0]], 44 | ], 45 | dtype=onehot.dtype, 46 | ) 47 | self.assertTrue((onehot == target).all().item()) 48 | 49 | def test_mean_embedding(self): 50 | input = torch.tensor([[0, 1, 2], [3, 0, 1]]) 51 | onehot = self.onehot(input) 52 | self.assertTrue( 53 | (self.mean_embedding(input) == self.mean_embedding(onehot.float())) 54 | .all() 55 | .item() 56 | ) 57 | 58 | def test_triangle_multiplication(self): 59 | bs = 4 60 | nres = 25 61 | d_model = 12 62 | m = TriangleMultiplication(d_model=d_model) 63 | X = torch.randn(bs, nres, nres, d_model) 64 | mask = torch.ones(bs, nres, nres, 1) 65 | self.assertTrue( 66 | m(X, mask.bool()).size() == torch.Size([bs, nres, nres, d_model]) 67 | ) 68 | 69 | def test_node_product(self): 70 | bs = 4 71 | nres = 25 72 | d_model = 12 73 | m = NodeProduct(d_in=d_model, d_out=d_model) 74 | node_h = torch.randn(bs, nres, d_model) 75 | node_mask = torch.ones(bs, nres).bool() 76 | edge_mask = torch.ones(bs, nres, nres).bool() 77 | self.assertTrue( 78 | m(node_h, node_mask, edge_mask).size() 79 | == torch.Size([bs, nres, nres, d_model]) 80 | ) 81 | 82 | def test_transpose(self): 83 | x = torch.randn(4, 5, 2) 84 | self.assertTrue((x == self.transpose(x).transpose(1, 2)).all().item()) 85 | 86 | def test_periodic(self): 87 | position = 
torch.arange(0.0, 4000).unsqueeze(1) 88 | div_term = torch.exp(torch.arange(0.0, 64, 2) * -(math.log(10000.0) / 64)) 89 | self.assertTrue( 90 | (self.periodic.pe.squeeze()[:, 0::2] == torch.sin(position * div_term)) 91 | .all() 92 | .item() 93 | ) 94 | self.periodic(torch.randn(6, 30, 64)) 95 | 96 | def test_pwff(self): 97 | x = torch.randn(4, 5, 64) 98 | self.assertTrue(self.pwff(x).size() == x.size()) 99 | 100 | 101 | @pytest.mark.parametrize( 102 | "d_model, d_input", [(2, 1), (12, 1), (12, 2), (12, 3), (12, 6)], ids=str 103 | ) 104 | def test_positional_encoding(d_model, d_input): 105 | encoding = PositionalEncoding(d_model, d_input) 106 | 107 | for batch_shape in [(), (4,), (3, 2)]: 108 | inputs = torch.randn(batch_shape + (d_input,), requires_grad=True) 109 | outputs = encoding(inputs) 110 | assert outputs.shape == batch_shape + (d_model,) 111 | assert torch.isfinite(outputs).all() 112 | outputs.sum().backward() # smoke test 113 | 114 | 115 | def test_maybe_onehot_embedding(): 116 | x = torch.empty(10, dtype=torch.long).random_(4) 117 | x_onehot = nn.functional.one_hot(x, 4).float() 118 | 119 | embedding = MaybeOnehotEmbedding(4, 8) 120 | expected = embedding(x) 121 | actual = embedding(x_onehot) 122 | assert torch.allclose(expected, actual) 123 | -------------------------------------------------------------------------------- /tests/layers/test_graph.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | 7 | from chroma.layers.graph import ( 8 | MLP, 9 | GraphLayer, 10 | GraphNN, 11 | collect_edges_transpose, 12 | edge_mask_causal, 13 | permute_graph_embeddings, 14 | ) 15 | 16 | 17 | class Testcollect_edges_transpose(TestCase): 18 | # Simple case: 3 nodes that are all connected to each other 19 | edge_idx = torch.tensor([[[1, 2], [0, 2], [0, 1]]]) 20 | mask_ij = torch.tensor([[[1, 1], [1, 1], [1, 1]]]) 21 | edge_h = torch.tensor([[[[1], [2]], [[3], [4]], [[5], [6]]]]) 22 | 23 | edge_h_transpose, mask_ji = collect_edges_transpose(edge_h, edge_idx, mask_ij) 24 | 25 | # Expected values were verified by manual inspection of the tensor; 26 | # view(-1) flattens it so the expected values are easier to write out 27 | assert ( 28 | torch.tensor([3.0, 5.0, 1.0, 6.0, 2.0, 4.0]) != edge_h_transpose.view(-1) 29 | ).detach().numpy().sum() == 0 30 | # Assert that the shape stays the same as the input 31 | assert edge_h.shape == edge_h_transpose.shape 32 | 33 | # If everything is masked, all edges should be zero 34 | edge_h_transpose, mask_ji = collect_edges_transpose( 35 | edge_h, edge_idx, torch.zeros_like(mask_ij) 36 | ) 37 | assert edge_h_transpose.abs().sum() == 0 38 | 39 | # Mask the connection between nodes 1 and 2 (in both directions) 40 | mask_ij = torch.tensor([[[1, 1], [1, 0], [1, 0]]]) 41 | edge_h_transpose, mask_ji = collect_edges_transpose(edge_h, edge_idx, mask_ij) 42 | print(edge_h_transpose.view(-1)) 43 | assert ( 44 | torch.tensor([3.0, 5.0, 1.0, 0.0, 2.0, 0.0]) != edge_h_transpose.view(-1) 45 | ).detach().numpy().sum() == 0 46 | 47 | # Mask the edge from 0 to 2 but not from 2 to 0; 48 | # the edge from 2 to 0 should then be masked in the transpose 49 | mask_ij = torch.tensor([[[1, 0], [1, 0], [1, 0]]]) 50 | edge_h_transpose, mask_ji = collect_edges_transpose(edge_h, edge_idx, mask_ij) 51 | assert ( 52 | torch.tensor([3.0, 0.0, 1.0, 0.0, 0.0, 0.0]) != edge_h_transpose.view(-1) 53 | ).detach().numpy().sum() == 0 54 | 55 | 56 | class TestGraphNN(TestCase): 57 | def test_sample(self): 58 | dim_nodes = 128 59 | dim_edges = 64 60 | 61 | model = GraphNN(num_layers=6,
dim_nodes=dim_nodes, dim_edges=dim_edges,) 62 | 63 | num_nodes = 10 64 | num_neighbors = 8 65 | node_h_out, edge_h_out = model( 66 | torch.ones(1, num_nodes, dim_nodes), 67 | torch.ones(1, num_nodes, num_neighbors, dim_edges), 68 | torch.ones(1, num_nodes, num_neighbors, dtype=torch.long), 69 | ) 70 | self.assertTrue(node_h_out.shape == (1, num_nodes, dim_nodes)) 71 | self.assertTrue(edge_h_out.shape == (1, num_nodes, num_neighbors, dim_edges)) 72 | 73 | 74 | class TestGraphLayer(TestCase): 75 | def test_sample(self): 76 | 77 | dim_nodes = 128 78 | dim_edges = 64 79 | 80 | graph_layer = GraphLayer( 81 | dim_nodes=dim_nodes, dim_edges=dim_edges, dropout=0, edge_update=True 82 | ) 83 | 84 | num_parameters = sum([np.prod(p.size()) for p in graph_layer.parameters()]) 85 | 86 | # self.assertEqual(num_parameters, 131712) 87 | 88 | num_nodes = 10 89 | num_neighbors = 8 90 | node_h_out, edge_h_out = graph_layer( 91 | torch.ones(1, num_nodes, dim_nodes), 92 | torch.ones(1, num_nodes, num_neighbors, dim_edges), 93 | torch.ones(1, num_nodes, num_neighbors, dtype=torch.long), 94 | ) 95 | self.assertTrue(node_h_out.shape == (1, num_nodes, dim_nodes)) 96 | self.assertTrue(edge_h_out.shape == (1, num_nodes, num_neighbors, dim_edges)) 97 | 98 | 99 | class TestMLP(TestCase): 100 | def test_sample(self): 101 | dim_in = 10 102 | sample_input = torch.rand(dim_in) 103 | prediction = MLP(dim_in)(sample_input) 104 | self.assertTrue(prediction.shape[-1] == dim_in) 105 | 106 | sample_input = torch.rand(dim_in) 107 | dim_out = 8 108 | model = MLP(dim_in, dim_out=dim_out) 109 | prediction = model(sample_input) 110 | self.assertTrue(prediction.shape[-1] == dim_out) 111 | 112 | sample_input = torch.rand(dim_in) 113 | dim_hidden = 5 114 | model = MLP(dim_in, dim_hidden=dim_hidden, dim_out=dim_hidden) 115 | prediction = model(sample_input) 116 | self.assertTrue(prediction.shape[-1] == dim_hidden) 117 | 118 | sample_input = torch.rand(dim_in) 119 | 120 | model = MLP(dim_in, num_layers_hidden=0, dim_out=dim_out) 121 | prediction = model(sample_input) 122 | self.assertTrue(prediction.shape[-1] == dim_out) 123 | 124 | 125 | class TestGraphFunctions(TestCase): 126 | def hello(): 127 | print("hello") 128 | 129 | 130 | def test_graph_permutation(): 131 | B, N, K, H = 2, 7, 4, 3 132 | # Create a random graph embedding 133 | node_h = torch.randn([B, N, H]) 134 | edge_h = torch.randn([B, N, K, H]) 135 | edge_idx = torch.randint(low=0, high=N, size=[B, N, K]) 136 | mask_i = torch.ones([B, N]) 137 | mask_ij = torch.ones([B, N, K]) 138 | 139 | # Create a random permutation of the node indices 140 | permute_idx = torch.argsort(torch.randn([B, N]), dim=-1) 141 | 142 | # Permute 143 | node_h_p, edge_h_p, edge_idx_p, mask_i_p, mask_ij_p = permute_graph_embeddings( 144 | node_h, edge_h, edge_idx, mask_i, mask_ij, permute_idx 145 | ) 146 | 147 | # Inverse permute 148 | permute_idx_inverse = torch.argsort(permute_idx, dim=-1) 149 | node_h_pp, edge_h_pp, edge_idx_pp, mask_i_pp, mask_ij_pp = permute_graph_embeddings( 150 | node_h_p, edge_h_p, edge_idx_p, mask_i_p, mask_ij_p, permute_idx_inverse 151 | ) 152 | 153 | # Test round trip of permutation composed with 
inverse permutation 154 | assert torch.allclose(node_h, node_h_pp) 155 | assert torch.allclose(edge_h, edge_h_pp) 156 | assert torch.allclose(edge_idx, edge_idx_pp) 157 | assert torch.allclose(mask_i, mask_i_pp) 158 | assert torch.allclose(mask_ij, mask_ij_pp) 159 | 160 | # Test permutation equivariance of GNN layers 161 | gnn = GraphNN(num_layers=1, dim_nodes=H, dim_edges=H) 162 | outs = gnn(node_h, edge_h, edge_idx, mask_i, mask_ij) 163 | outs_perm = gnn(node_h_p, edge_h_p, edge_idx_p, mask_i_p, mask_ij_p) 164 | outs_pp = permute_graph_embeddings( 165 | outs_perm[0], outs_perm[1], edge_idx_p, mask_i_p, mask_ij_p, permute_idx_inverse 166 | ) 167 | 168 | assert torch.allclose(outs[0], outs_pp[0]) 169 | assert torch.allclose(outs[1], outs_pp[1]) 170 | return 171 | 172 | 173 | def test_autoregressive_gnn(): 174 | B, N, K, H = 1, 3, 3, 4 175 | 176 | torch.manual_seed(0) 177 | 178 | # Build random GNN input 179 | node_h = torch.randn([B, N, H]) 180 | edge_h = torch.randn([B, N, K, H]) 181 | # edge_idx = torch.randint(low=0, high=N, size=[B, N, K]) 182 | edge_idx = torch.arange(K).reshape([1, 1, K]).expand([B, N, K]).contiguous() 183 | mask_i = torch.ones([B, N]) 184 | mask_ij = torch.ones([B, N, K]) 185 | mask_ij = edge_mask_causal(edge_idx, mask_ij) 186 | 187 | error = lambda x, y: (torch.abs(x - y) / (torch.abs(y) + 1e-3)).mean() 188 | 189 | # Parallel mode computation 190 | for mode in [True, False]: 191 | gnn = GraphNN(num_layers=4, dim_nodes=H, dim_edges=H, attentional=mode) 192 | 193 | node_h_gnn, edge_h_gnn = gnn(node_h, edge_h, edge_idx, mask_i, mask_ij) 194 | 195 | # Step wise computation 196 | node_h_cache, edge_h_cache = gnn.init_steps(node_h, edge_h) 197 | for t in range(N): 198 | node_h_cache, edge_h_cache = gnn.step( 199 | t, node_h_cache, edge_h_cache, edge_idx, mask_i, mask_ij 200 | ) 201 | node_h_sequential = node_h_cache[-1] 202 | edge_h_sequential = edge_h_cache[-1] 203 | 204 | assert torch.allclose(node_h_gnn, node_h_sequential) 205 | assert torch.allclose(edge_h_gnn, edge_h_sequential) 206 | return 207 | -------------------------------------------------------------------------------- /tests/layers/test_norm.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import torch 4 | 5 | from chroma.layers.norm import MaskedBatchNorm1d 6 | 7 | 8 | class TestBatchNorm(TestCase): 9 | def test_norm(self): 10 | device = ( 11 | torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 12 | ) 13 | B, C, L = (3, 5, 7) 14 | x1 = torch.randn(B, C, L).to(device) 15 | mean1 = x1.sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / (B * L) 16 | var1 = ((x1 - mean1) ** 2).sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / ( 17 | B * L 18 | ) 19 | x2 = torch.randn(B, C, L).to(device) 20 | mean2 = x2.sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / (B * L) 21 | var2 = ((x2 - mean2) ** 2).sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / ( 22 | B * L 23 | ) 24 | 25 | mbn = MaskedBatchNorm1d(C) 26 | mbn = mbn.to(device) 27 | 28 | # Test without mask in train 29 | mbn.train() 30 | out = mbn(x1) 31 | self.assertTrue(mean1.allclose(mbn.running_mean)) 32 | self.assertTrue(var1.allclose(mbn.running_var)) 33 | normed = (x1 - mean1) / torch.sqrt(var1 + mbn.eps) * mbn.weight + mbn.bias 34 | self.assertTrue(normed.allclose(out)) 35 | out = mbn(x2) 36 | normed = (x2 - mean2) / torch.sqrt(var2 + mbn.eps) * mbn.weight + mbn.bias 37 | self.assertTrue(normed.allclose(out)) 38 | self.assertTrue( 39 | 
mbn.running_mean.allclose((1 - mbn.momentum) * mean1 + mbn.momentum * mean2) 40 | ) 41 | self.assertTrue( 42 | mbn.running_var.allclose((1 - mbn.momentum) * var1 + mbn.momentum * var2) 43 | ) 44 | 45 | # Without mask in eval 46 | mbn.eval() 47 | out = mbn(x1) 48 | self.assertTrue( 49 | mbn.running_mean.allclose((1 - mbn.momentum) * mean1 + mbn.momentum * mean2) 50 | ) 51 | self.assertTrue( 52 | mbn.running_var.allclose((1 - mbn.momentum) * var1 + mbn.momentum * var2) 53 | ) 54 | normed = (x1 - mbn.running_mean) / torch.sqrt( 55 | mbn.running_var + mbn.eps 56 | ) * mbn.weight + mbn.bias 57 | self.assertTrue(normed.allclose(out)) 58 | 59 | # Check that masking with all ones doesn't change values 60 | mask = x1.new_ones((B, 1, L)) 61 | outm = mbn(x1, input_mask=mask) 62 | self.assertTrue(outm.allclose(out)) 63 | mbn.eval() 64 | out = mbn(x2) 65 | outm = mbn(x2, input_mask=mask) 66 | self.assertTrue(outm.allclose(out)) 67 | 68 | # With mask in train 69 | mask = torch.randn(B, 1, L) 70 | mask = mask > 0.0 71 | mask = mask.to(device) 72 | n = mask.sum() 73 | mean1 = (x1 * mask).sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / n 74 | var1 = (((x1 * mask) - mean1) ** 2).sum(dim=0, keepdim=True).sum( 75 | dim=2, keepdim=True 76 | ) / n 77 | mbn = MaskedBatchNorm1d(C) 78 | mbn = mbn.to(device) 79 | mbn.train() 80 | out = mbn(x1, input_mask=mask) 81 | self.assertTrue(mean1.allclose(mbn.running_mean)) 82 | self.assertTrue(var1.allclose(mbn.running_var)) 83 | normed = (x1 * mask - mean1) / torch.sqrt( 84 | var1 + mbn.eps 85 | ) * mbn.weight + mbn.bias 86 | self.assertTrue(normed.allclose(out)) 87 | # With mask in eval 88 | mbn.eval() 89 | out = mbn(x1, input_mask=mask) 90 | normed = (x1 * mask - mbn.running_mean) / torch.sqrt( 91 | mbn.running_var + mbn.eps 92 | ) * mbn.weight + mbn.bias 93 | self.assertTrue(normed.allclose(out)) 94 | -------------------------------------------------------------------------------- /tests/layers/test_sde.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import pytest 4 | import torch 5 | 6 | from chroma.layers.sde import sde_integrate, sde_integrate_heun 7 | 8 | 9 | @pytest.fixture 10 | def y0(): 11 | # try multiple 1D trajectories, then take mean and variance in testing 12 | return torch.zeros(10000) 13 | 14 | 15 | @pytest.fixture 16 | def tspan(): 17 | return (0.5, 0.3) 18 | 19 | 20 | @pytest.fixture 21 | def N(): 22 | return 200 23 | 24 | 25 | @pytest.fixture 26 | def exp_mean(y0, tspan): 27 | return torch.Tensor(y0 + (tspan[1] - tspan[0]) / 2).mean() 28 | 29 | 30 | @pytest.fixture 31 | def exp_var(tspan): 32 | deltat = tspan[1] - tspan[0] 33 | # variance contributions arising from drift and diffusion, respectively 34 | return torch.Tensor([deltat ** 2 / 12 + abs(deltat) / 6]) 35 | 36 | 37 | def sde_sample_func(t, y): 38 | f = torch.ones_like(y) 39 | gZ = torch.randn(y.shape) 40 | return f, gZ 41 | 42 | 43 | def test_sde_integrate(y0, tspan, N, exp_mean, exp_var): 44 | y_trajectory = torch.stack(sde_integrate(sde_sample_func, y0, tspan, N), dim=-1) 45 | assert torch.allclose(torch.mean(y_trajectory, dim=-1).mean(), exp_mean, rtol=5e-2) 46 | assert torch.allclose(torch.var(y_trajectory, dim=-1).mean(), exp_var, rtol=5e-2) 47 | 48 | 49 | def test_sde_integrate_heun(y0, tspan, N, exp_mean, exp_var): 50 | y_trajectory = torch.stack( 51 | sde_integrate_heun(sde_sample_func, y0, tspan, N), dim=-1 52 | ) 53 | assert torch.allclose(torch.mean(y_trajectory, dim=-1).mean(), exp_mean, rtol=5e-2) 54 
| assert torch.allclose(torch.var(y_trajectory, dim=-1).mean(), exp_var, rtol=5e-2) 55 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generatebio/chroma/929407c605013613941803c6113adefdccaad679/tests/models/__init__.py -------------------------------------------------------------------------------- /tests/models/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from pathlib import Path 4 | 5 | import pytest 6 | 7 | import chroma 8 | from chroma.data import Protein 9 | 10 | 11 | @pytest.fixture(scope="module") 12 | def XCS(): 13 | input_file = str( 14 | Path(Path(chroma.__file__).parent.parent, "tests", "resources", "1n8z.cif") 15 | ) 16 | 17 | protein = Protein(input_file) 18 | length = 100 19 | return [_t[:, :length] for _t in protein.to_XCS(all_atom=True)] 20 | # return protein.to_XCS(all_atom=True) 21 | 22 | 23 | @pytest.fixture(scope="module") 24 | def XCS_backbone(XCS): 25 | X, C, S = XCS 26 | return X[:, :, :4, :], C, S 27 | -------------------------------------------------------------------------------- /tests/models/test_chroma.py: -------------------------------------------------------------------------------- 1 | from math import isclose 2 | from pathlib import Path 3 | 4 | import pytest 5 | import torch 6 | 7 | import chroma 8 | from chroma.data.protein import Protein 9 | from chroma.layers.structure import conditioners 10 | from chroma.models.chroma import Chroma 11 | 12 | BB_MODEL_PATH = "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_backbone_v1.0.pt" #'named:nature_v3' 13 | GD_MODEL_PATH = "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_design_v1.0.pt" #'named:nature_v3' 14 | 15 | BASE_PATH = str(Path(chroma.__file__).parent.parent) 16 | PROTEIN_SAMPLE = BASE_PATH + "/tests/resources/steps200_seed42_len100.cif" 17 | 18 | 19 | @pytest.fixture(scope="session") 20 | def chroma(): 21 | return Chroma(BB_MODEL_PATH, GD_MODEL_PATH, device="cpu") 22 | 23 | 24 | def test_chroma(chroma): 25 | 26 | # Fixed Protein Value 27 | protein = Protein.from_CIF(PROTEIN_SAMPLE) 28 | 29 | # Fixed value test score 30 | torch.manual_seed(42) 31 | scores = chroma.score(protein, num_samples=5) 32 | assert isclose(scores["elbo"].score, 5.890165328979492, abs_tol=1e-3) 33 | 34 | # Test Sampling & Design 35 | # torch.manual_seed(42) 36 | # sample = chroma.sample(steps=200) 37 | 38 | # Xs, _, Ss = sample.to_XCS() 39 | # X , _, S = protein.to_XCS() 40 | # assert torch.allclose(X,Xs) 41 | # assert (S == Ss).all() 42 | 43 | # test postprocessing 44 | from chroma.layers.structure import conditioners 45 | 46 | X, C, S = protein.to_XCS() 47 | c_symmetry = conditioners.SymmetryConditioner(G="C_8", num_chain_neighbors=1) 48 | 49 | X_s, C_s, S_s = ( 50 | torch.cat([X, X], dim=1), 51 | torch.cat([C, C], dim=1), 52 | torch.cat([S, S], dim=1), 53 | ) 54 | protein_sym = Protein(X_s, C_s, S_s) 55 | 56 | chroma._postprocess(c_symmetry, protein_sym, output_dictionary=None) 57 | 58 | 59 | @pytest.mark.parametrize( 60 | "conditioner", 61 | [ 62 | conditioners.Identity(), 63 | conditioners.SymmetryConditioner(G="C_3", num_chain_neighbors=1), 64 | ], 65 | ) 66 | def test_sample(chroma, conditioner): 67 | chroma.sample(steps=3, conditioner=conditioner, design_method=None) 68 | 69 | 70 | @pytest.mark.parametrize( 71 | 
"conditioner", 72 | [ 73 | conditioners.Identity(), 74 | conditioners.SymmetryConditioner(G="C_3", num_chain_neighbors=1), 75 | ], 76 | ) 77 | def test_sample_backbone(chroma, conditioner): 78 | chroma._sample(steps=3, conditioner=conditioner) 79 | 80 | 81 | @pytest.mark.parametrize("design_method", ["autoregressive", "potts",]) 82 | @pytest.mark.parametrize("potts_proposal", ["dlmc", "chromatic"]) 83 | def test_design(chroma, design_method, potts_proposal): 84 | protein = Protein.from_CIF(PROTEIN_SAMPLE) 85 | chroma.design( 86 | protein, 87 | design_method=design_method, 88 | potts_proposal=potts_proposal, 89 | potts_mcmc_depth=20, 90 | ) 91 | -------------------------------------------------------------------------------- /tests/models/test_graph_backbone.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from chroma.data import Protein 5 | from chroma.layers.structure import backbone, protein_graph 6 | from chroma.models.graph_backbone import GraphBackbone 7 | 8 | 9 | def test_denoiser(dim_nodes=32, dim_edges=32): 10 | X, C, S = Protein("1SHG").to_XCS() 11 | model = GraphBackbone(dim_nodes=dim_nodes, dim_edges=dim_edges) 12 | 13 | # check if denoiser is working as expected 14 | model.CA_dist_scaling = False 15 | X0 = model.denoise(X, C, 0.0) 16 | assert X0.shape == X.shape 17 | 18 | # test if prediction_type="scale" is working 19 | model = GraphBackbone(prediction_type="scale") 20 | X0 = model.denoise(X, C, 0.0) 21 | assert torch.allclose(X0, X, rtol=1e-2) 22 | 23 | # check if CA_dist scale is working as expected 24 | model.CA_dist_scaling = False 25 | X0 = model.denoise(0.25 * X, C, 1e-4) 26 | assert X0.shape == X.shape 27 | # assert model._D_backbone_CA(X0, C).min().item() < model.min_CA_bb_distance 28 | 29 | 30 | @pytest.mark.parametrize("t", [0.1, 0.7, 1.0]) 31 | def test_equivariance_denoiser(t, dim_nodes=32, dim_edges=32, seed=10, debug=False): 32 | X = backbone.ProteinBackbone(num_batch=1, num_residues=20, init_state="alpha")() 33 | C = torch.ones(X.shape[:2]) 34 | S = torch.zeros_like(C).long() 35 | 36 | model = GraphBackbone(dim_nodes=dim_nodes, dim_edges=dim_edges).eval() 37 | 38 | # Test rotation equivariance 39 | transformer = backbone.RigidTransformer(center_rotation=False) 40 | q_transform = torch.Tensor([0.0, 0.1, -1.2, 0.5]).unsqueeze(0) 41 | dX_transform = torch.Tensor([-3.0, 30.0, 7.0]).unsqueeze(0) 42 | _transform = lambda X_input: transformer(X_input, dX_transform, q_transform) 43 | 44 | # Add noise 45 | X_noised = model.noise_perturb(X, C, t=t) 46 | X_noised_transform = _transform(X_noised) 47 | 48 | # Synchronize random seeds for random graph generation 49 | torch.manual_seed(seed) 50 | X_denoised = model.denoise(X_noised, C, t=t) 51 | X_denoised_transform = _transform(X_denoised) 52 | torch.manual_seed(seed) 53 | X_transform_denoised = model.denoise(X_noised_transform, C, t=t) 54 | 55 | if debug: 56 | print((X_denoised_transform - X_transform_denoised).abs().max()) 57 | 58 | Protein(X, C, S).to_CIF("X_denoised.cif") 59 | Protein(X_denoised, C, S).to_CIF("X_denoised.cif") 60 | Protein(X_denoised_transform, C, S).to_CIF("X_denoised_transform.cif") 61 | Protein(X_transform_denoised, C, S).to_CIF("X_transform_denoised.cif") 62 | 63 | # The oxygen atom of the final carboxy terminus residue in each chain 64 | # is disambiguated via zero-padding (non-equivariant), so it can be up to \ 65 | # ~1 angstrom off depending on global pose 66 | assert torch.allclose( 67 | X_denoised_transform[:, :-1, :, 
:], 68 | X_transform_denoised[:, :-1, :, :], 69 | atol=1e-1, 70 | ) 71 | # Nevertheless at this adjusted tolerance we are equivariant 72 | assert torch.allclose(X_denoised_transform, X_transform_denoised, atol=3.0) 73 | assert not torch.allclose(X_denoised, X_transform_denoised, atol=1e-1) 74 | 75 | 76 | @pytest.mark.parametrize("num_transform_weights", [1, 2, 3]) 77 | @pytest.mark.parametrize("dim_nodes", [32]) 78 | @pytest.mark.parametrize("dim_edges", [32]) 79 | def test_equivariance_graph_update( 80 | num_transform_weights, dim_nodes, dim_edges, output_structures=False 81 | ): 82 | torch.manual_seed(10.0) 83 | 84 | # Initialize layers 85 | bb_update = backbone.GraphBackboneUpdate( 86 | dim_nodes=dim_nodes, 87 | dim_edges=dim_edges, 88 | method="neighbor_global_affine", 89 | num_transform_weights=num_transform_weights, 90 | ).eval() 91 | pg = protein_graph.ProteinFeatureGraph( 92 | dim_nodes=dim_nodes, dim_edges=dim_edges, num_neighbors=5 93 | ) 94 | 95 | # Test rotation equivariance 96 | transformer = backbone.RigidTransformer(center_rotation=False) 97 | q_rotate = torch.Tensor([0.0, 0.1, -1.2, 0.5]).unsqueeze(0) 98 | dX_rotate = torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0) 99 | _rotate = lambda X_input: transformer(X_input, dX_rotate, q_rotate) 100 | 101 | # Load test structure and canonicalize 102 | X, C, S = Protein("1qys").to_XCS() 103 | R, t, _ = bb_update.frame_builder.inverse(X, C) 104 | X = bb_update.frame_builder.forward(R, t, C) 105 | 106 | # Apply transformation 107 | node_h, edge_h, edge_idx, mask_i, mask_ij = pg(X, C) 108 | X_update, _, _, _ = bb_update(X, C, node_h, edge_h, edge_idx, mask_i, mask_ij) 109 | 110 | # Compute for rotated system 111 | X_rotate = _rotate(X) 112 | X_rotate_update, _, _, _ = bb_update( 113 | X_rotate, C, node_h, edge_h, edge_idx, mask_i, mask_ij 114 | ) 115 | X_update_rotate = _rotate(X_update) 116 | 117 | assert torch.allclose(X_rotate_update, X_update_rotate, atol=1e-2) 118 | 119 | if output_structures: 120 | from chroma.layers.structure.rmsd import BackboneRMSD 121 | 122 | bb_rmsd = BackboneRMSD() 123 | X_aligned, rmsd = bb_rmsd.align(X_rotate_update, X_update_rotate, C) 124 | print(rmsd) 125 | Protein.from_XCS_trajectory( 126 | [X, X_update, X_rotate, X_rotate_update, X_update_rotate], C, S 127 | ).to_CIF("test_equi.cif") 128 | return 129 | -------------------------------------------------------------------------------- /tests/models/test_graph_classifier.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import torch 4 | 5 | from chroma.models.graph_classifier import GraphClassifier 6 | 7 | 8 | class TestGraphClassifier(TestCase): 9 | def test_graph_classifier(self): 10 | class_config = { 11 | "dummy_1": { 12 | "tokens": ["a", "b", "c", "d"], 13 | "loss": "bce", 14 | "level": "chain", 15 | }, 16 | "dummy_2": { 17 | "tokens": ["w", "x", "y", "z"], 18 | "loss": "ce", 19 | "level": "first_order", 20 | }, 21 | } 22 | for k, v in class_config.items(): 23 | v["tokenizer"] = {k: i for i, k in enumerate(v["tokens"])} 24 | 25 | model = GraphClassifier( 26 | dim_nodes=16, 27 | dim_edges=16, 28 | edge_mlp_dim=8, 29 | node_mlp_dim=8, 30 | class_config=class_config, 31 | ) 32 | 33 | bs = 1 34 | sl = 8 35 | 36 | X = torch.randn(bs, sl, 4, 3) 37 | C = torch.ones(bs, sl) 38 | 39 | with torch.no_grad(): 40 | node_h, edge_h = model(X, C) 41 | 42 | self.assertTrue(node_h.size() == torch.Size([bs, sl, 16])) 43 | 44 | grad = model.gradient(X, C, t=0.5, label="dummy_2", value="w") 45 | 
self.assertTrue(grad.size() == torch.Size([bs, sl, 4, 3])) 46 | -------------------------------------------------------------------------------- /tests/models/test_graph_design.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from chroma.models.graph_design import GraphDesign, ProteinTraversalSpatial 5 | 6 | 7 | @pytest.fixture 8 | def model(): 9 | model = GraphDesign(predict_S_marginals=True, predict_S_potts=True) 10 | model.eval() 11 | return model 12 | 13 | 14 | def test_sequential_decoding(model, XCS): 15 | """Test that the sequential and parallelized decoding of GNN agree.""" 16 | 17 | from chroma.data import xcs 18 | 19 | X, C, S = XCS 20 | permute_idx = torch.argsort(torch.randn_like(C.float()), dim=-1) 21 | 22 | # Clamped (teacher-forced) decoding with a fixed permutation 23 | scores_parallel = model(X, C, S, permute_idx=permute_idx) 24 | 25 | _, _, _, scores_sequential = model.sample( 26 | X, C, S, permute_idx=permute_idx, clamped=True, return_scores=True 27 | ) 28 | 29 | assert torch.allclose( 30 | scores_parallel["logp_S"], scores_sequential["logp_S"], atol=1e-3 31 | ) 32 | assert torch.allclose( 33 | scores_parallel["logp_chi"], scores_sequential["logp_chi"], atol=1e-3 34 | ) 35 | 36 | # ============ Unclamped sampling with the same fixed permutation ============ 37 | X_sample, S_sample, _, scores_sequential = model.sample( 38 | X, C, S, permute_idx=permute_idx, clamped=False, return_scores=True 39 | ) 40 | scores_parallel = model(X_sample, C, S_sample, permute_idx=permute_idx) 41 | 42 | assert torch.allclose( 43 | scores_parallel["logp_S"], scores_sequential["logp_S"], atol=1e-3 44 | ) 45 | assert torch.allclose( 46 | scores_parallel["logp_chi"], scores_sequential["logp_chi"], atol=1e-3 47 | ) 48 | return 49 | 50 | 51 | def test_deterministic_traversal(XCS): 52 | """Check deterministic flag on ProteinTraversalSpatial module.""" 53 | traversal = ProteinTraversalSpatial(deterministic=True) 54 | X, C, _ = XCS 55 | permute_idx = traversal(X, C) 56 | permute_idx_2 = traversal(X, C) 57 | assert torch.allclose(permute_idx, permute_idx_2) 58 | return 59 | 60 | 61 | def test_graph_design_outputs(model, XCS): 62 | """Smoke test all GraphDesign outputs.""" 63 | X, C, S = XCS 64 | outputs = model(X, C, S) 65 | for key in ["logp_S", "logp_S_marginals", "logp_S_potts"]: 66 | assert outputs[key].shape == X.shape[:2] 67 | assert torch.allclose(outputs["X_noise"], X) 68 | for key in ["chi", "logp_chi"]: 69 | assert outputs[key].shape[:-1] == X.shape[:2] and outputs[key].shape[-1] == 4 70 | -------------------------------------------------------------------------------- /tests/models/test_graph_energy.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from chroma.models import graph_energy 7 | 8 | 9 | class TestGraphHarmonicFeatures(TestCase): 10 | def test_sample(self): 11 | num_batch = 1 12 | num_nodes = 10 13 | num_neighbors = 8 14 | dim_nodes = 128 15 | dim_edges = 64 16 | 17 | layer = graph_energy.GraphHarmonicFeatures( 18 | dim_nodes=dim_nodes, 19 | dim_edges=dim_edges, 20 | node_mlp_layers=2, 21 | node_mlp_dim=dim_nodes, 22 | edge_mlp_layers=2, 23 | edge_mlp_dim=dim_edges, 24 | ) 25 | 26 | node_h = torch.ones(num_batch, num_nodes, dim_nodes) 27 | node_features = torch.ones(num_batch, num_nodes, dim_nodes) 28 | edge_h = torch.ones(num_batch, num_nodes, num_neighbors, dim_edges) 29 | edge_features = torch.ones(num_batch, num_nodes, num_neighbors, dim_edges) 30 | mask_i = 
torch.ones(num_batch, num_nodes) 31 | mask_ij = torch.ones(num_batch, num_nodes, num_neighbors) 32 | 33 | node_out, edge_out = layer( 34 | node_h, node_features, edge_h, edge_features, mask_i, mask_ij 35 | ) 36 | 37 | self.assertTrue(node_out.shape == (1, num_nodes, dim_nodes)) 38 | self.assertTrue(edge_out.shape == (1, num_nodes, num_neighbors, dim_edges)) 39 | -------------------------------------------------------------------------------- /tests/models/test_procap.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import torch 4 | 5 | from chroma.models.procap import ProteinCaption, load_model, save_model 6 | 7 | 8 | def test_procap(): 9 | model = ProteinCaption( 10 | lm_id="EleutherAI/gpt-neo-125m", 11 | gnn_dim_edges=16, 12 | context_size=8, 13 | context_per_chain=1, 14 | gnn_num_neighbors=4, 15 | gnn_num_layers=1, 16 | ) 17 | 18 | assert sum(p.numel() for p in model.parameters()) == 128839584 19 | X = torch.randn(1, 8, 4, 3) 20 | C = torch.ones(X.shape[:2]) 21 | caption = ["test caption"] 22 | chain_id = torch.tensor([1]) 23 | with torch.no_grad(): 24 | logits = model(X, C, caption, chain_id).logits 25 | assert logits.shape == torch.Size([1, 11, 50260]) 26 | temp = tempfile.NamedTemporaryFile() 27 | save_model(model, temp.name) 28 | del model 29 | model = load_model(temp.name) 30 | assert sum(p.numel() for p in model.parameters()) == 128839584 31 | temp.close() 32 | -------------------------------------------------------------------------------- /tests/utility/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generatebio/chroma/929407c605013613941803c6113adefdccaad679/tests/utility/__init__.py -------------------------------------------------------------------------------- /tests/utility/test_api.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import pytest 5 | 6 | import chroma 7 | import chroma.utility.api as api 8 | from chroma.models.graph_backbone import GraphBackbone 9 | from chroma.utility.model import load_model 10 | 11 | KEY_PATH = os.path.dirname(os.path.dirname(chroma.__file__)) 12 | KEY_PATH = os.path.join(KEY_PATH, "config.json") 13 | 14 | 15 | @pytest.mark.skipif(not os.path.exists(KEY_PATH), reason="requires config.json with a registered API key") 16 | def test_api(): 17 | 18 | # Test Key Registration 19 | with tempfile.TemporaryDirectory() as key_directory: 20 | api.register_key("my_key", key_directory) 21 | 22 | # Test Key Reading 23 | api.read_key() 24 | 25 | # Test Weights Download 26 | api.download_from_generate( 27 | "https://chroma-weights.generatebiomedicines.com/", "chroma_backbone_v1.0.pt" 28 | ) 29 | 30 | # Test Public Loading of the Backbone Model (fetches a specific named model) 31 | model = load_model( 32 | "named:public", GraphBackbone, device="cpu", strict_unexpected=False, 33 | ) 34 | --------------------------------------------------------------------------------
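For orientation, the snippet below is a minimal end-to-end sampling sketch assembled from the calls these tests exercise: the Chroma constructor and weight URLs come from tests/models/test_chroma.py, the SymmetryConditioner usage from its test_sample case, and the CIF export from tests/models/test_graph_backbone.py. It assumes network access to the pinned weight URLs and, per tests/utility/test_api.py, a registered API key in config.json; the seed, step count, and output filename are illustrative choices, not canonical settings.

import torch

from chroma.layers.structure import conditioners
from chroma.models.chroma import Chroma

# Weight URLs as pinned in tests/models/test_chroma.py
BB_MODEL_PATH = "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_backbone_v1.0.pt"
GD_MODEL_PATH = "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_design_v1.0.pt"

torch.manual_seed(42)  # illustrative seed
model = Chroma(BB_MODEL_PATH, GD_MODEL_PATH, device="cpu")

# Sample a C_3-symmetric structure, mirroring the conditioner used in test_sample
conditioner = conditioners.SymmetryConditioner(G="C_3", num_chain_neighbors=1)
protein = model.sample(steps=200, conditioner=conditioner)
protein.to_CIF("sample_C3.cif")  # illustrative output path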