├── mogwai
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── base_wrapper_dataset.py
│   │   ├── max_steps_dataset.py
│   │   ├── repeat_dataset.py
│   │   ├── pseudolikelihood_dataset.py
│   │   ├── ms_dataset.py
│   │   ├── trrosetta_dataset.py
│   │   ├── msa_dataset.py
│   │   ├── trrosetta_ms_dataset.py
│   │   └── maskedlm_dataset.py
│   ├── optim
│   │   ├── __init__.py
│   │   └── gremlin_adam.py
│   ├── metrics
│   │   ├── __init__.py
│   │   └── metrics.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── tensor.py
│   │   ├── init.py
│   │   ├── common.py
│   │   └── functional.py
│   ├── data_loading
│   │   ├── __init__.py
│   │   ├── ms_datamodule.py
│   │   └── msa_datamodule.py
│   ├── plotting
│   │   ├── __init__.py
│   │   ├── precision_length.py
│   │   └── colored_preds.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── gremlin.py
│   │   ├── multilayer_attention.py
│   │   ├── attention.py
│   │   └── factored_attention.py
│   ├── vocab.py
│   ├── train.py
│   ├── lr_schedulers.py
│   ├── parsing.py
│   └── alignment.py
├── .gitattributes
├── requirements.txt
├── scripts
│   └── download_example.sh
├── tests
│   ├── test_gremlin_pl_performance.py
│   ├── test_data_modules.py
│   ├── test_gremlin_pl.py
│   ├── test_metrics.py
│   └── test_parsing.py
├── README.md
├── setup.py
├── LICENSE
└── .gitignore

--------------------------------------------------------------------------------
/mogwai/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/mogwai/data/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb -linguist-detectable
--------------------------------------------------------------------------------
/mogwai/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from .gremlin_adam import GremlinAdam
--------------------------------------------------------------------------------
/mogwai/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .metrics import contact_auc, precision_at_cutoff, precisions_in_range
2 | 
--------------------------------------------------------------------------------
/mogwai/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .functional import *
2 | from .tensor import collate_tensors  # noqa: F401
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | biopython
2 | matplotlib
3 | numpy
4 | pytorch-lightning>=1.4
5 | scipy
6 | torch
7 | biotite
--------------------------------------------------------------------------------
/mogwai/data_loading/__init__.py:
--------------------------------------------------------------------------------
1 | from .msa_datamodule import MSADataModule
2 | from .ms_datamodule import MSDataModule
--------------------------------------------------------------------------------
/mogwai/plotting/__init__.py:
--------------------------------------------------------------------------------
1 | from .colored_preds import plot_colored_preds_on_trues
2 | from .precision_length import plot_precision_vs_length
--------------------------------------------------------------------------------
/scripts/download_example.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | mkdir -p data
3 | 
4 | wget -P data/test/4rb6Y -q -nc https://files.ipd.uw.edu/krypton/4rb6Y.i90c75.a3m
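# -P sets the download directory, -q runs quietly, and -nc ("no clobber")
# skips files that already exist, so re-running this script is safe.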
5 | wget -P data/test/4rb6Y -q -nc https://files.ipd.uw.edu/krypton/4rb6Y.pdb
6 | wget -P data/test/4rb6Y -q -nc https://files.ipd.uw.edu/krypton/4rb6Y.cf
7 | 
8 | wget -P data/test/ http://s3.amazonaws.com/songlabdata/proteindata/mogwai/3er7_1_A.npz
9 | 
--------------------------------------------------------------------------------
/mogwai/models/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Type, Dict
2 | from .gremlin import Gremlin
3 | from .attention import Attention
4 | from .factored_attention import FactoredAttention
5 | from .base_model import BaseModel
6 | 
7 | 
8 | MODELS: Dict[str, Type[BaseModel]] = {
9 |     "gremlin": Gremlin,
10 |     "attention": Attention,
11 |     "factored_attention": FactoredAttention,
12 | }
13 | 
14 | 
15 | def get(name: str) -> Type[BaseModel]:
16 |     return MODELS[name.lower()]
17 | 
--------------------------------------------------------------------------------
/mogwai/utils/tensor.py:
--------------------------------------------------------------------------------
1 | """Contains common tensor operations."""
2 | 
3 | from typing import Sequence, Union
4 | import torch
5 | import numpy as np
6 | 
7 | 
8 | def collate_tensors(
9 |     tensors: Sequence[torch.Tensor], pad_value: Union[int, float, bool] = 0
10 | ):
11 |     dtype = tensors[0].dtype
12 |     device = tensors[0].device
13 |     batch_size = len(tensors)
14 |     shape = (batch_size,) + tuple(np.max([tensor.size() for tensor in tensors], 0))
15 | 
16 |     padded = torch.full(shape, pad_value, dtype=dtype, device=device)
17 |     for position, tensor in zip(padded, tensors):
18 |         tensorslice = tuple(slice(dim) for dim in tensor.shape)
19 |         position[tensorslice] = tensor
20 | 
21 |     return padded
22 | 
--------------------------------------------------------------------------------
/mogwai/data/base_wrapper_dataset.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import torch
3 | from typing import List, Any
4 | 
5 | 
6 | class BaseWrapperDataset(torch.utils.data.Dataset):
7 |     """BaseWrapperDataset. Wraps an existing dataset.
8 | 
9 |     Args:
10 |         dataset (torch.utils.data.Dataset): Dataset to wrap.
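
    A minimal usage sketch (`base` is any existing map-style dataset):

        >>> wrapped = BaseWrapperDataset(base)
        >>> len(wrapped) == len(base)
        True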
11 |     """
12 | 
13 |     def __init__(self, dataset: torch.utils.data.Dataset):
14 |         super().__init__()
15 |         self.dataset = dataset
16 | 
17 |     def __getitem__(self, idx):
18 |         return self.dataset[idx]
19 | 
20 |     def __len__(self):
21 |         return len(self.dataset)
22 | 
23 |     @staticmethod
24 |     def add_args(parser: ArgumentParser) -> ArgumentParser:
25 |         return parser
26 | 
27 |     def collater(self, batch: List[Any]) -> Any:
28 |         return self.dataset.collater(batch)
29 | 
--------------------------------------------------------------------------------
/tests/test_gremlin_pl_performance.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | 
4 | import pytorch_lightning as pl
5 | 
6 | from mogwai.data_loading import MSADataModule
7 | from mogwai.parsing import contacts_from_cf
8 | from mogwai.models import Gremlin
9 | 
10 | 
11 | class TestGremlinPLPerformance(unittest.TestCase):
12 |     def setUp(self):
13 |         a3m_path = "data/test/4rb6Y/4rb6Y.i90c75.a3m"
14 |         self.dm = MSADataModule(a3m_path, batch_size=4096)
15 |         self.dm.setup()
16 | 
17 |         true_contacts = contacts_from_cf("data/test/4rb6Y/4rb6Y.cf")
18 | 
19 |         n, l, msa_counts = self.dm.get_stats()
20 |         self.model = Gremlin(
21 |             n, l, msa_counts, true_contacts=torch.tensor(true_contacts)
22 |         )
23 |         self.trainer = pl.Trainer(
24 |             min_steps=50,
25 |             max_steps=50,
26 |             gpus=1,
27 |         )
28 | 
29 |     def test_training_performance(self):
30 |         self.trainer.fit(self.model, self.dm)
31 |         final_auc_apc = self.model.get_auc(do_apc=True)
32 |         self.assertGreaterEqual(final_auc_apc, 0.91)
33 | 
34 | 
35 | if __name__ == "__main__":
36 |     unittest.main()
37 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # mogwai: Probabilistic Models of Protein Families
2 | 
3 | A library of tested models, metrics, and data loading for protein families. Implemented in PyTorch and PyTorch Lightning.
4 | 
5 | Under active development; feedback welcome.
6 | 
7 | ## Getting Started
8 | 
9 | For now, we support cloning and installing in developer mode:
10 | 
11 | ```bash
12 | pip install -e .
13 | ```
14 | 
15 | You will also need to install `apex` (needed to use DDP and FusedLamb training):
16 | 
17 | ```
18 | source venv/bin/activate
19 | git clone git@github.com:NVIDIA/apex.git
20 | cd apex
21 | ```
22 | 
23 | Next, modify apex's `setup.py` to remove the CUDA minor-version check.
24 | 
25 | Find:
26 | ```
27 | if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
28 | ```
29 | and replace it with:
30 | ```
31 | if (bare_metal_major != torch_binary_major):
32 | ```
33 | This removes the minor-version check and allows apex to build against your installed PyTorch.
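
If you would rather script this edit, the same patch can be applied with a one-liner (a sketch assuming GNU `sed` and that the line appears verbatim as above):

```bash
sed -i 's/ or (bare_metal_minor != torch_binary_minor)//' setup.py
```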
34 | 
35 | Finally, install apex:
36 | ```
37 | pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
38 | ```
39 | ## Examples
40 | 
41 | * [Potts Model with Pseudolikelihood](https://github.com/nickbhat/mogwai/blob/main/examples/gremlin_train.ipynb)
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | 
3 | with open("requirements.txt", "r") as reqs_f:
4 |     requirements = reqs_f.read().split()
5 | 
6 | with open("LICENSE", "r") as lf:
7 |     LICENSE = lf.read()
8 | 
9 | setuptools.setup(
10 |     name="mogwai-protein",
11 |     version="0.0.1",
12 |     author="Nick Bhattacharya",
13 |     author_email="nick_bhat@berkeley.edu",
14 |     description="A package for training and evaluating probabilistic models of protein families.",
15 |     url="https://github.com/nickbhat/mogwai",
16 |     packages=setuptools.find_packages(),
17 |     license=LICENSE,
18 |     install_requires=requirements,
19 |     scripts=["scripts/download_example.sh"],
20 |     entry_points={
21 |         "console_scripts": [
22 |             "mogwai-train=mogwai.train:train",
23 |             "mogwai-align=mogwai.alignment:make_a3m_cli",
24 |         ]
25 |     },
26 |     classifiers=[
27 |         "Programming Language :: Python :: 3.6",
28 |         "Operating System :: POSIX :: Linux",
29 |         "Intended Audience :: Science/Research",
30 |         "Topic :: Scientific/Engineering :: Artificial Intelligence",
31 |         "Topic :: Scientific/Engineering :: Bio-Informatics",
32 |     ],
33 | )
34 | 
--------------------------------------------------------------------------------
/mogwai/data/max_steps_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base_wrapper_dataset import BaseWrapperDataset
3 | 
4 | 
5 | class MaxStepsDataset(BaseWrapperDataset):
6 |     """MaxStepsDataset repeats the wrapped dataset so that its length is exactly
7 |     max_steps * batch_size, i.e. one pass through it lasts exactly max_steps
8 |     training steps. This can help when working with relatively small datasets,
9 |     since PyTorch dataloading operations reset at the end of each epoch. It will
10 |     also provide accurate timing estimates for pytorch lightning.
11 | 
12 |     Args:
13 |         dataset (torch.utils.data.Dataset): Dataset to wrap.
14 |         max_steps (int): Total number of training steps.
15 |         batch_size (int): Number of examples consumed per training step.
16 |     """
17 | 
18 |     def __init__(self, dataset: torch.utils.data.Dataset, max_steps: int, batch_size: int):
19 |         super().__init__(dataset)
20 |         self._max_steps = max_steps
21 |         self._batch_size = batch_size
22 | 
23 |     def __getitem__(self, idx: int):
24 |         if idx >= len(self):
25 |             raise IndexError(
26 |                 f"index {idx} out of bounds for dataset of size {len(self)}"
27 |             )
28 |         return self.dataset[idx % len(self.dataset)]
29 | 
30 |     def __len__(self):
31 |         return self._max_steps * self._batch_size
32 | 
--------------------------------------------------------------------------------
/mogwai/data/repeat_dataset.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import torch
3 | from .base_wrapper_dataset import BaseWrapperDataset
4 | 
5 | 
6 | class RepeatDataset(BaseWrapperDataset):
7 |     """RepeatDataset repeats the same dataset multiple times. This can help when
8 |     working with relatively small datasets, since PyTorch dataloading operations
9 |     reset at the end of each epoch.
10 | 
11 |     Args:
12 |         dataset (torch.utils.data.Dataset): Dataset to wrap.
13 |         n (int): Number of times to repeat the dataset.
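
    Illustrative semantics: with n=3 and a base dataset of length 5, the wrapped
    dataset has length 15, and index 7 maps back to dataset[7 % 5] == dataset[2].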
14 | """ 15 | 16 | def __init__(self, dataset: torch.utils.data.Dataset, n: int): 17 | super().__init__(dataset) 18 | self._n = n 19 | 20 | def __getitem__(self, idx: int): 21 | if idx >= len(self): 22 | raise IndexError( 23 | f"index {idx} out of bounds for dataset of size {len(self)}" 24 | ) 25 | return self.dataset[idx % len(self.dataset)] 26 | 27 | def __len__(self): 28 | return self._n * len(self.dataset) 29 | 30 | @staticmethod 31 | def add_args(parser: ArgumentParser) -> ArgumentParser: 32 | parser.add_argument( 33 | "--num_repeats", 34 | type=int, 35 | default=10000, 36 | help="Number of times to repeat the input dataset.", 37 | ) 38 | return parser 39 | -------------------------------------------------------------------------------- /tests/test_data_modules.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from mogwai.data_loading import MSADataModule 4 | 5 | 6 | class TestA3MDataModule(unittest.TestCase): 7 | def setUp(self): 8 | a3m_path = "data/test/4rb6Y/4rb6Y.i90c75.a3m" 9 | self.dm = MSADataModule(a3m_path, batch_size=64) 10 | self.dm.setup() 11 | 12 | def test_datamodule_stats(self): 13 | num_seqs, msa_length, msa_counts = self.dm.get_stats() 14 | self.assertEqual(num_seqs, 7569) 15 | self.assertEqual(msa_length, 107) 16 | self.assertTupleEqual(msa_counts.shape, (107, 20)) 17 | 18 | def test_batch_shape(self): 19 | batch = next(iter(self.dm.train_dataloader())) 20 | self.assertTupleEqual(batch['src_tokens'].shape, (64, 107)) 21 | 22 | 23 | class TestNPZDataModule(unittest.TestCase): 24 | def setUp(self): 25 | npz_path = "data/test/3er7_1_A.npz" 26 | self.dm = MSADataModule(npz_path, batch_size=64) 27 | self.dm.setup() 28 | 29 | def test_datamodule_stats(self): 30 | num_seqs, msa_length, msa_counts = self.dm.get_stats() 31 | self.assertEqual(num_seqs, 33672) 32 | self.assertEqual(msa_length, 118) 33 | self.assertTupleEqual(msa_counts.shape, (118, 20)) 34 | 35 | def test_batch_shape(self): 36 | batch = next(iter(self.dm.train_dataloader())) 37 | self.assertTupleEqual(batch['src_tokens'].shape, (64, 118)) 38 | 39 | 40 | if __name__ == "__main__": 41 | unittest.main() 42 | -------------------------------------------------------------------------------- /mogwai/data/pseudolikelihood_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | from argparse import ArgumentParser 3 | import torch 4 | from .base_wrapper_dataset import BaseWrapperDataset 5 | from ..utils import collate_tensors 6 | from ..vocab import FastaVocab 7 | 8 | 9 | class PseudolikelihoodDataset(BaseWrapperDataset): 10 | """PseudolikelihoodDataset implements a mostly-dummy dataset, which simply wraps an 11 | existing token dataset. It is designed to act as a drop-in replacement of the 12 | MaskedLMDataset. 13 | 14 | Args: 15 | dataset (torch.utils.data.dataset): Dataset of tensors to wrap. 
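
    Illustrative item format: ds[i] == {"src_tokens": seq, "targets": seq.clone()},
    i.e. the model is asked to reconstruct every position from the full sequence.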
16 |     """
17 | 
18 |     def __init__(self, dataset: torch.utils.data.Dataset):
19 |         super().__init__(dataset)
20 | 
21 |     def __getitem__(self, idx):
22 |         item = self.dataset[idx]
23 |         if isinstance(item, tuple) and len(item) == 1:
24 |             item = item[0]
25 |         return {"src_tokens": item, "targets": item.clone()}
26 | 
27 |     @staticmethod
28 |     def add_args(parser: ArgumentParser) -> ArgumentParser:
29 |         return parser
30 | 
31 |     def collater(self, batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
32 |         concat = {
33 |             "src_tokens": collate_tensors(
34 |                 [element["src_tokens"] for element in batch], FastaVocab.pad_idx
35 |             ),
36 |             "targets": collate_tensors(
37 |                 [element["targets"] for element in batch], FastaVocab.pad_idx
38 |             ),
39 |         }
40 |         return concat
41 | 
--------------------------------------------------------------------------------
/mogwai/data/ms_dataset.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from typing import Union, NamedTuple, List
3 | from pathlib import Path
4 | import torch
5 | 
6 | from ..parsing import parse_fasta
7 | from ..vocab import FastaVocab
8 | from ..utils import collate_tensors
9 | 
10 | MSStats = NamedTuple(
11 |     "MSStats", [("num_seqs", int), ("reference", torch.Tensor)]
12 | )
13 | 
14 | 
15 | class MSDataset(torch.utils.data.Dataset):
16 |     """MSDataset: Loads unaligned sequences directly from a fasta or a3m file.
17 | 
18 |     Args:
19 |         data (PathLike): Path to fasta or a3m file.
20 |     """
21 | 
22 |     def __init__(self, data: Union[str, Path]):
23 |         super().__init__()
24 |         _, sequences = parse_fasta(data, remove_gaps=True)
25 |         self.data = [
26 |             torch.tensor(FastaVocab.tokenize(seq), dtype=torch.long)
27 |             for seq in sequences
28 |         ]
29 | 
30 |     def __getitem__(self, idx):
31 |         return self.data[idx]
32 | 
33 |     def __len__(self):
34 |         return len(self.data)
35 | 
36 |     @property
37 |     def num_seqs(self) -> int:
38 |         return len(self)
39 | 
40 |     @property
41 |     def reference(self) -> torch.Tensor:
42 |         return self.data[0]
43 | 
44 |     def get_stats(self) -> MSStats:
45 |         return MSStats(self.num_seqs, self.reference)
46 | 
47 |     @staticmethod
48 |     def add_args(parser: ArgumentParser) -> ArgumentParser:
49 |         return parser
50 | 
51 |     def collater(self, sequences: List[torch.Tensor]) -> torch.Tensor:
52 |         return collate_tensors(sequences, FastaVocab.pad_idx)
53 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 | 
3 | Copyright (c) 2020, Nick Bhattacharya
4 | All rights reserved.
5 | 
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 | 
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 |    list of conditions and the following disclaimer.
11 | 
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 |    this list of conditions and the following disclaimer in the documentation
14 |    and/or other materials provided with the distribution.
15 | 
16 | 3. Neither the name of the copyright holder nor the names of its
17 |    contributors may be used to endorse or promote products derived from
18 |    this software without specific prior written permission.
19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /mogwai/data/trrosetta_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Union, NamedTuple, List 2 | 3 | from argparse import ArgumentParser 4 | import numpy as np 5 | from pathlib import Path 6 | import torch 7 | 8 | from ..vocab import FastaVocab 9 | 10 | MSAStats = NamedTuple( 11 | "MSAStats", [("num_seqs", int), ("msa_length", int), ("msa_counts", torch.Tensor)] 12 | ) 13 | 14 | 15 | class TRRosetta_MSADataset(torch.utils.data.TensorDataset): 16 | """TRRosetta Dataset: Loads a multiple sequence alignment directly from a TRRosetta npz file. 17 | 18 | Args: 19 | data (PathLike): Path to npz file. 20 | """ 21 | 22 | def __init__(self, data: Union[str, Path]): 23 | fam_data = np.load(data) 24 | msa = fam_data["msa"] 25 | super().__init__(torch.tensor(msa, dtype=torch.long)) 26 | 27 | @property 28 | def num_seqs(self) -> int: 29 | return self.tensors[0].size(0) 30 | 31 | @property 32 | def msa_length(self) -> int: 33 | return self.tensors[0].size(1) 34 | 35 | @property 36 | def msa_counts(self) -> torch.Tensor: 37 | if not hasattr(self, "_msa_counts"): 38 | self._msa_counts = torch.eye(len(FastaVocab) + 1, len(FastaVocab))[ 39 | self.tensors[0] 40 | ].sum(0) 41 | return self._msa_counts 42 | 43 | def get_stats(self) -> MSAStats: 44 | return MSAStats(self.num_seqs, self.msa_length, self.msa_counts) 45 | 46 | @staticmethod 47 | def add_args(parser: ArgumentParser) -> ArgumentParser: 48 | return parser 49 | 50 | def collater(self, sequences: List[torch.Tensor]) -> torch.Tensor: 51 | return torch.stack(sequences, 0) 52 | -------------------------------------------------------------------------------- /mogwai/data/msa_dataset.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from typing import Union, NamedTuple, List 3 | from pathlib import Path 4 | import torch 5 | 6 | from ..parsing import parse_fasta 7 | from ..vocab import FastaVocab 8 | 9 | MSAStats = NamedTuple( 10 | "MSAStats", [("num_seqs", int), ("msa_length", int), ("msa_counts", torch.Tensor)] 11 | ) 12 | 13 | 14 | class MSADataset(torch.utils.data.TensorDataset): 15 | """MSADataset: Loads a multiple sequence alignment directly from an A3M file. 16 | 17 | Args: 18 | data (PathLike): Path to a3m file. 
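
    A usage sketch (the test family fetched by scripts/download_example.sh):

        >>> ds = MSADataset("data/test/4rb6Y/4rb6Y.i90c75.a3m")
        >>> num_seqs, msa_length, msa_counts = ds.get_stats()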
19 | """ 20 | 21 | def __init__(self, data: Union[str, Path]): 22 | _, sequences = parse_fasta(data, remove_insertions=True) 23 | indices = [FastaVocab.tokenize(seq) for seq in sequences] 24 | super().__init__(torch.tensor(indices, dtype=torch.long)) 25 | 26 | @property 27 | def num_seqs(self) -> int: 28 | return self.tensors[0].size(0) 29 | 30 | @property 31 | def msa_length(self) -> int: 32 | return self.tensors[0].size(1) 33 | 34 | @property 35 | def msa_counts(self) -> torch.Tensor: 36 | if not hasattr(self, "_msa_counts"): 37 | self._msa_counts = torch.eye(len(FastaVocab) + 1, len(FastaVocab))[ 38 | self.tensors[0] 39 | ].sum(0) 40 | return self._msa_counts 41 | 42 | def get_stats(self) -> MSAStats: 43 | return MSAStats(self.num_seqs, self.msa_length, self.msa_counts) 44 | 45 | @staticmethod 46 | def add_args(parser: ArgumentParser) -> ArgumentParser: 47 | return parser 48 | 49 | def collater(self, sequences: List[torch.Tensor]) -> torch.Tensor: 50 | return torch.stack(sequences, 0) 51 | -------------------------------------------------------------------------------- /mogwai/utils/init.py: -------------------------------------------------------------------------------- 1 | """Contains initializers shared by models.""" 2 | 3 | from typing import Tuple 4 | 5 | from math import log 6 | import torch 7 | 8 | from .functional import zero_diag_ 9 | 10 | 11 | def init_potts_bias( 12 | msa_counts: torch.Tensor, l2_coeff: float, num_seqs: int 13 | ) -> torch.Tensor: 14 | """Initialize single-site log-potential as regularized PSSM. 15 | 16 | Args: 17 | msa_counts (tensor): Counts of amino acids per-position of MSA. 18 | l2_coeff (float): L2 regularization weight. 19 | num_seqs (int): Number of sequences in MSA. 20 | """ 21 | bias = (msa_counts + l2_coeff * log(num_seqs)).log() 22 | bias.add_(-bias.mean(-1, keepdims=True)) # type: ignore 23 | return bias 24 | 25 | 26 | def init_potts_weight(msa_length: int, vocab_size: int) -> torch.Tensor: 27 | """Initializes Potts coupling matrices of all zeros. 28 | 29 | Args: 30 | msa_length (int): Length of MSA. 31 | vocab_size (int): Number of characters in MSA alphabet. 32 | """ 33 | weight = torch.zeros(msa_length, vocab_size, msa_length, vocab_size) 34 | return weight 35 | 36 | 37 | def init_pseudolik_mask(msa_length: int) -> torch.Tensor: 38 | """Creates mask for efficient pseudolikelihood calculation. 39 | 40 | Args: 41 | msa_length (int): Length of MSA. 
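
    Equivalently (illustrative): the returned mask equals
    torch.ones(msa_length, msa_length) - torch.eye(msa_length).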
42 | """ 43 | diag_mask = torch.ones(msa_length, msa_length, dtype=torch.float) 44 | diag_mask = zero_diag_(diag_mask) 45 | return diag_mask 46 | 47 | 48 | def gremlin_weight_decay_coeffs( 49 | batch_size: int, msa_length: int, l2_coeff: float, vocab_size: int = 20 50 | ) -> Tuple[float, float]: 51 | weight_coeff = l2_coeff * vocab_size / (2 * batch_size) 52 | bias_coeff = l2_coeff / (batch_size * msa_length) 53 | return weight_coeff, bias_coeff 54 | -------------------------------------------------------------------------------- /mogwai/vocab.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, List 2 | 3 | 4 | class _FastaVocab: 5 | 6 | def __init__(self): 7 | self.ALPHABET = "ARNDCQEGHILKMFPSTWYV-" 8 | self.A2N = {a: n for n, a in enumerate(self.ALPHABET)} 9 | self.A2N["X"] = 20 10 | 11 | self.IUPAC_CODES = { 12 | "Ala": "A", 13 | "Arg": "R", 14 | "Asn": "N", 15 | "Asp": "D", 16 | "Cys": "C", 17 | "Gln": "Q", 18 | "Glu": "E", 19 | "Gly": "G", 20 | "His": "H", 21 | "Ile": "I", 22 | "Leu": "L", 23 | "Lys": "K", 24 | "Met": "M", 25 | "Phe": "F", 26 | "Pro": "P", 27 | "Ser": "S", 28 | "Thr": "T", 29 | "Trp": "W", 30 | "Val": "V", 31 | "Tyr": "Y", 32 | "Asx": "B", 33 | "Sec": "U", 34 | "Xaa": "X", 35 | "Glx": "Z", 36 | } 37 | 38 | self.THREE_LETTER = {aa: name for name, aa in self.IUPAC_CODES.items()} 39 | 40 | def convert_indices_to_tokens(self, indices: Sequence[int]) -> List[str]: 41 | return [self.ALPHABET[i] for i in indices] 42 | 43 | def convert_tokens_to_indices(self, tokens: Sequence[str], skip_unknown: bool = False) -> List[int]: 44 | if skip_unknown: 45 | return [self.A2N[token] for token in tokens if token in self.A2N] 46 | else: 47 | return [self.A2N.get(token, 20) for token in tokens] 48 | 49 | def tokenize(self, sequence: str) -> List[int]: 50 | return self.convert_tokens_to_indices(list(sequence)) 51 | 52 | def __len__(self) -> int: 53 | return 20 54 | 55 | @property 56 | def pad_idx(self) -> int: 57 | return 20 58 | 59 | 60 | FastaVocab = _FastaVocab() 61 | -------------------------------------------------------------------------------- /mogwai/utils/common.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | import functools 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def coerce_numpy(func: Callable) -> Callable: 8 | """ Allows user to pass numpy arguments to a torch function and auto-converts back to 9 | numpy at the end. 
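
    Illustrative:

        >>> @coerce_numpy
        ... def double(x):
        ...     return 2 * x
        >>> double(np.ones(3))  # accepts and returns np.ndarray
        array([2., 2., 2.])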
10 |     """
11 |     @functools.wraps(func)
12 |     def make_torch_args(*args, **kwargs):
13 |         is_numpy = False
14 |         update_args = []
15 |         for arg in args:
16 |             if isinstance(arg, np.ndarray):
17 |                 arg = torch.from_numpy(arg)
18 |                 is_numpy = True
19 |             update_args.append(arg)
20 |         update_kwargs = {}
21 |         for kw, arg in kwargs.items():
22 |             if isinstance(arg, np.ndarray):
23 |                 arg = torch.from_numpy(arg)
24 |                 is_numpy = True
25 |             update_kwargs[kw] = arg
26 | 
27 |         output = func(*update_args, **update_kwargs)
28 | 
29 |         if is_numpy:
30 |             output = recursive_make_numpy(output)
31 | 
32 |         return output
33 | 
34 |     return make_torch_args
35 | 
36 | 
37 | def recursive_make_torch(item):
38 |     if isinstance(item, np.ndarray):
39 |         return torch.from_numpy(item)
40 |     elif isinstance(item, (tuple, list)):
41 |         return type(item)(recursive_make_torch(el) for el in item)
42 |     elif isinstance(item, dict):
43 |         return {kw: recursive_make_torch(arg) for kw, arg in item.items()}
44 |     else:
45 |         return item
46 | 
47 | 
48 | def recursive_make_numpy(item):
49 |     if isinstance(item, torch.Tensor):
50 |         return item.detach().cpu().numpy()
51 |     elif isinstance(item, (tuple, list)):
52 |         return type(item)(recursive_make_numpy(el) for el in item)
53 |     elif isinstance(item, dict):
54 |         return {kw: recursive_make_numpy(arg) for kw, arg in item.items()}
55 |     else:
56 |         return item
57 | 
--------------------------------------------------------------------------------
/tests/test_gremlin_pl.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import numpy as np
3 | import torch
4 | import unittest
5 | 
6 | from mogwai.parsing import one_hot
7 | from mogwai.models import Gremlin
8 | 
9 | 
10 | class TestGremlinPL(unittest.TestCase):
11 |     def setUp(self):
12 |         torch.manual_seed(0)
13 | 
14 |         N = 100
15 |         L = 20
16 |         A = 8
17 |         msa = torch.randint(0, A, [N, L])
18 |         msa = torch.FloatTensor(one_hot(msa.numpy(), cat=A))
19 |         msa_counts = msa.sum(0)
20 | 
21 |         self.msa = msa
22 |         self.model = Gremlin(N, L, msa_counts, vocab_size=A)
23 | 
24 |         # Need nonzero weights but don't want to take a grad for this test
25 |         wt = self.model.weight.data
26 |         self.model.weight.data = torch.randn_like(wt)
27 | 
28 |         # Used for data leakage test.
29 |         self.A = A
30 | 
31 |     def test_parameter_shapes(self):
32 |         self.assertTupleEqual(self.model.weight.shape, (20, 8, 20, 8))
33 |         self.assertTupleEqual(self.model.bias.shape, (20, 8))
34 | 
35 |     def test_forward_shape(self):
36 |         batch = self.msa[:64]
37 |         logits = self.model(batch)[0]
38 |         self.assertTupleEqual(logits.shape, (64, 20, 8))
39 | 
40 |     def onehot_vector(self, idx: int):
41 |         oh = torch.zeros(self.A)
42 |         oh[idx] = 1.0
43 |         return oh
44 | 
45 |     @torch.no_grad()
46 |     def test_data_leakage(self):
47 |         # Confirm that logits for position 0 do not change
48 |         # when sequence at position 0 is exhaustively changed.
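        # A pseudolikelihood Potts model masks the self-coupling W(i, i), so the
        # logits it produces at position i may depend on every position except i.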
49 |         logits_list = []
50 |         example = self.msa[0]
51 | 
52 |         seq_pos = 0
53 |         for i in range(self.A):
54 |             example[seq_pos] = self.onehot_vector(i)
55 |             logits = self.model(example.unsqueeze(0))[0]
56 |             logits_list.append(logits[0, seq_pos])
57 |         all_pairs = itertools.combinations(logits_list, 2)
58 |         for x, y in all_pairs:
59 |             np.testing.assert_array_almost_equal(x.numpy(), y.numpy())
60 | 
61 | 
62 | 
63 | if __name__ == "__main__":
64 |     unittest.main()
65 | 
--------------------------------------------------------------------------------
/mogwai/data/trrosetta_ms_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Union, NamedTuple, List
2 | 
3 | from argparse import ArgumentParser
4 | import numpy as np
5 | from pathlib import Path
6 | import torch
7 | 
8 | from ..vocab import FastaVocab
9 | 
10 | MSAStats = NamedTuple(
11 |     "MSAStats", [("num_seqs", int), ("msa_length", int), ("msa_counts", torch.Tensor)]
12 | )
13 | 
14 | PAD_IDX = FastaVocab.pad_idx
15 | 
16 | 
17 | def construct_unalignment(msa):
18 |     padded_ms = torch.full_like(msa, PAD_IDX)  # pad with PAD_IDX, not 0 ('A')
19 |     ms = [seq[seq != PAD_IDX] for seq in msa]
20 |     for i, m in enumerate(ms):
21 |         padded_ms[i, : len(m)] = m
22 |     return padded_ms
23 | 
24 | 
25 | class TRRosetta_MSDataset(torch.utils.data.TensorDataset):
26 |     """TRRosetta_MSDataset: Loads unaligned sequences from a TRRosetta npz file by
27 |     stripping gap tokens from the stored alignment.
28 | 
29 |     Args:
30 |         data (PathLike): Path to npz file.
31 |     """
32 | 
33 |     def __init__(self, data: Union[str, Path]):
34 |         fam_data = np.load(data)
35 |         msa = fam_data["msa"]
36 |         msa = torch.from_numpy(msa)
37 |         ms = construct_unalignment(msa)
38 |         super().__init__(ms.long())
39 | 
40 |     @property
41 |     def num_seqs(self) -> int:
42 |         return self.tensors[0].size(0)
43 | 
44 |     @property
45 |     def msa_length(self) -> int:
46 |         return self.tensors[0].size(1)
47 | 
48 |     @property
49 |     def reference(self) -> torch.Tensor:
50 |         return self.tensors[0][0]  # first (query) sequence
51 | 
52 |     @property
53 |     def msa_counts(self) -> torch.Tensor:
54 |         if not hasattr(self, "_msa_counts"):
55 |             self._msa_counts = torch.eye(len(FastaVocab) + 1, len(FastaVocab))[
56 |                 self.tensors[0]
57 |             ].sum(0)
58 |         return self._msa_counts
59 | 
60 |     def get_stats(self) -> MSAStats:
61 |         return MSAStats(self.num_seqs, self.msa_length, self.msa_counts)
62 | 
63 |     @staticmethod
64 |     def add_args(parser: ArgumentParser) -> ArgumentParser:
65 |         return parser
66 | 
67 |     def collater(self, sequences: List[torch.Tensor]) -> torch.Tensor:
68 |         return torch.stack(sequences, 0)
69 | 
--------------------------------------------------------------------------------
/mogwai/plotting/precision_length.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import torch
4 | 
5 | 
6 | def plot_precision_vs_length(
7 |     pred: torch.Tensor,
8 |     meas: torch.Tensor,
9 |     thresh: float = 1e-4,
10 |     superdiag: int = 6,
11 | ):
12 |     """Plot precision versus length for various length cutoffs.
13 | 
14 |     Analogous to a precision-recall curve.
15 | 
16 |     Args:
17 |         pred (tensor): Predicted contact scores or probabilities.
18 |         meas (tensor): Binary matrix of true contacts.
19 |         thresh (float, optional): Threshold at which to call a predicted contact.
20 |         superdiag (int, optional): Ignore all true and predicted contacts from diag to superdiag.
21 |     """
22 |     # Ignore nearby contacts
23 |     eval_idx = np.triu_indices_from(meas, superdiag)
24 |     pred_, meas_ = pred[eval_idx], meas[eval_idx]
25 | 
26 |     # Sort by model confidence
27 |     sort_idx = pred_.argsort(descending=True)
28 | 
29 |     # want to separate correct from incorrect indices
30 |     true_pos = list()
31 |     false_pos = list()
32 |     length = meas.shape[0]
33 |     precision = list()
34 |     optimal_precision = list()
35 | 
36 |     num_contacts = int((meas_ >= thresh).sum())  # number of true contacts
37 |     # Only consider top 2L predictions
38 |     for i, idx in enumerate(sort_idx[: (2 * length)]):
39 |         # idx is in the flattened array of upper triang. values
40 |         # recover the position in the matrix
41 | 
42 |         # Update optimal precision based on number of true contacts
43 |         if i < num_contacts:
44 |             optimal_precision.append(1.0)
45 |         else:
46 |             # i is 0-indexed, so i + 1 predictions have been made at this point
47 |             optimal_precision.append(num_contacts / (i + 1))
48 | 
49 |         # Update model precision based on predictions
50 |         xy = (eval_idx[0][idx], eval_idx[1][idx])
51 |         if meas_[idx] >= thresh:
52 |             true_pos.append(xy)
53 |         else:
54 |             false_pos.append(xy)
55 | 
56 |         precision.append(len(true_pos) / (len(true_pos) + len(false_pos)))
57 | 
58 |     plt.plot(precision, color="b")
59 |     plt.plot(optimal_precision, color="k")
--------------------------------------------------------------------------------
/mogwai/optim/gremlin_adam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.optimizer import Optimizer
3 | 
4 | 
5 | class GremlinAdam(torch.optim.Adam):
6 |     """Modified Adam optimizer for convex Potts model structure learning."""
7 | 
8 |     @torch.no_grad()
9 |     def step(self, closure=None):
10 |         loss = None
11 |         if closure is not None:
12 |             with torch.enable_grad():
13 |                 loss = closure()
14 | 
15 |         for group in self.param_groups:
16 |             for p in group["params"]:
17 |                 if p.grad is None:
18 |                     continue
19 |                 grad = p.grad
20 |                 if grad.is_sparse:
21 |                     raise RuntimeError(
22 |                         "Adam does not support sparse gradients, please consider "
23 |                         "SparseAdam instead"
24 |                     )
25 | 
26 |                 state = self.state[p]
27 | 
28 |                 # State initialization
29 |                 if len(state) == 0:
30 |                     state["step"] = 0
31 |                     # Exponential moving average of gradient values
32 |                     state["exp_avg"] = torch.zeros_like(
33 |                         p, memory_format=torch.preserve_format
34 |                     )
35 |                     # Exponential moving average of squared gradient values
36 |                     state["exp_avg_sq"] = torch.zeros_like(
37 |                         p, memory_format=torch.preserve_format
38 |                     )
39 | 
40 |                 exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
41 |                 beta1, beta2 = group["betas"]
42 | 
43 |                 if group["weight_decay"] != 0:
44 |                     grad = grad.add(p, alpha=group["weight_decay"])
45 | 
46 |                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
47 |                 # NOTE: This ties together all squared updates unlike in Adam.
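                # grad.view(1, -1) @ grad.view(-1, 1) is the inner product <g, g>,
                # i.e. the scalar ||g||^2, so every entry of the parameter shares a
                # single global adaptive denominator rather than a per-entry one.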
48 | sq_grad = (grad.view(1, -1) @ grad.view(-1, 1)).squeeze(1).squeeze(0) 49 | exp_avg_sq.mul_(beta2).add_(sq_grad, alpha=1 - beta2) 50 | p.addcdiv_( 51 | exp_avg, (exp_avg_sq.sqrt().add_(group["eps"])), value=-group["lr"] 52 | ) 53 | 54 | return loss 55 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import unittest 3 | 4 | from mogwai.metrics import contact_auc, precision_at_cutoff 5 | 6 | 7 | class TestPrecision(unittest.TestCase): 8 | def setUp(self): 9 | self.pred = torch.FloatTensor( 10 | [ 11 | [1e-3, 1e-2, 0.8], 12 | [1e-2, 1e-4, 0.3], 13 | [0.8, 0.3, 1e-10], 14 | ] 15 | ) 16 | self.meas = torch.IntTensor([[0, 1, 1], [1, 0, 0], [1, 1, 0]]) 17 | 18 | def test_precision_cutoffs(self): 19 | p_at_1 = precision_at_cutoff(self.pred, self.meas, cutoff=1, superdiag=0) 20 | p_at_2 = precision_at_cutoff(self.pred, self.meas, cutoff=2, superdiag=0) 21 | p_at_3 = precision_at_cutoff(self.pred, self.meas, cutoff=3, superdiag=0) 22 | 23 | self.assertEqual(p_at_1, 2.0 / 3) 24 | self.assertEqual(p_at_2, 1.0) 25 | self.assertEqual(p_at_3, 1.0) 26 | 27 | def test_superdiag(self): 28 | superdiag_0 = precision_at_cutoff(self.pred, self.meas, cutoff=1, superdiag=0) 29 | superdiag_1 = precision_at_cutoff(self.pred, self.meas, cutoff=1, superdiag=1) 30 | superdiag_2 = precision_at_cutoff(self.pred, self.meas, cutoff=1, superdiag=2) 31 | superdiag_3 = precision_at_cutoff(self.pred, self.meas, cutoff=1, superdiag=3) 32 | self.assertEqual(superdiag_0, 2.0 / 3) 33 | self.assertEqual(superdiag_1, 2.0 / 3) 34 | self.assertEqual(superdiag_2, 1.0) 35 | self.assertTrue(superdiag_3.isnan()) 36 | 37 | 38 | class TestAUC(unittest.TestCase): 39 | def setUp(self): 40 | self.pred = torch.FloatTensor( 41 | [ 42 | [1e-3, 1e-2, 0.8], 43 | [1e-2, 1e-4, 0.3], 44 | [0.8, 0.3, 1e-10], 45 | ] 46 | ) 47 | self.meas = torch.IntTensor([[0, 1, 1], [1, 0, 0], [1, 1, 0]]) 48 | 49 | def test_range(self): 50 | auc = contact_auc(self.pred, self.meas, superdiag=0, cutoff_range=[1, 2, 3]) 51 | self.assertEqual(auc, 8.0 / 9) 52 | 53 | def test_superdiag_range(self): 54 | auc_superdiag_1 = contact_auc( 55 | self.pred, self.meas, superdiag=1, cutoff_range=[1, 2, 3] 56 | ) 57 | auc_superdiag_2 = contact_auc( 58 | self.pred, self.meas, superdiag=2, cutoff_range=[1, 2, 3] 59 | ) 60 | self.assertEqual(auc_superdiag_1, 8.0 / 9) 61 | self.assertEqual(auc_superdiag_2, 1.0) 62 | 63 | 64 | if __name__ == "__main__": 65 | unittest.main() -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Pytorch-lightning 2 | lightning_logs/ 3 | 4 | # Data 5 | data/ 6 | !mogwai/data 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | -------------------------------------------------------------------------------- /mogwai/plotting/colored_preds.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | 5 | 6 | def plot_colored_preds_on_trues( 7 | pred: torch.Tensor, 8 | meas: torch.Tensor, 9 | thresh: float = 1e-4, 10 | superdiag: int = 6, 11 | cutoff: int = 1, 12 | point_size: int = 1, 13 | ): 14 | """Plot contact map predictions overlayed on true contacts. 15 | 16 | Args: 17 | pred (tensor): Predicted contact scores or probabilities. 18 | meas (tensor): Binary matrix of true contacts. 19 | thresh (float, optional): Threshold at which to call a predicted contact. 20 | superdiag (int, optional): Ignore all true and predicted contacts from diag to superdiag. 21 | cutoff (int, optional): Only compute precision of top L/cutoff predictions. 22 | point_size (int, optional): Size of each colored point in the plot. 23 | """ 24 | # Ignore nearby contacts 25 | eval_idx = np.triu_indices_from(meas, superdiag) 26 | pred_, meas_ = pred[eval_idx], meas[eval_idx] 27 | 28 | # Sort by model confidence 29 | sort_idx = pred_.argsort(descending=True) 30 | 31 | # want to plot the indexes that are right in blue and 32 | # the ones that are wrong in red 33 | # just consider the top L/cutoff contacts 34 | true_pos = list() 35 | false_pos = list() 36 | length = meas.shape[0] 37 | len_cutoff = int(length / cutoff) 38 | 39 | for idx in sort_idx[:len_cutoff]: 40 | # idx is in the flattened array of upper triang. 
values
41 |         # recover the position in the matrix
42 |         xy = (eval_idx[0][idx], eval_idx[1][idx])
43 |         if meas_[idx] >= thresh:
44 |             true_pos.append(xy)
45 |         else:
46 |             false_pos.append(xy)
47 | 
48 |     # there should be exactly len_cutoff plotted contacts in total
49 |     assert len(true_pos) + len(false_pos) == len_cutoff
50 | 
51 |     true_contacts_ij = list()
52 |     for i, j in zip(eval_idx[0], eval_idx[1]):
53 |         if meas[i, j] >= thresh:
54 |             true_contacts_ij.append((i, j))
55 | 
56 |     plt.imshow(meas, cmap="gray_r", alpha=0.1, label="measured contacts")
57 |     if len(true_contacts_ij) > 0:
58 |         x, y = zip(*true_contacts_ij)
59 |         plt.scatter(x, y, c="grey", s=point_size, alpha=0.3)
60 |         # plot symmetric values
61 |         plt.scatter(y, x, c="grey", s=point_size, alpha=0.3)
62 |     if len(true_pos) > 0:
63 |         x, y = zip(*true_pos)
64 |         plt.scatter(x, y, c="b", s=point_size, alpha=0.4, label="true positives")
65 |         plt.scatter(y, x, c="b", s=point_size, alpha=0.4)
66 |     if len(false_pos) > 0:
67 |         x, y = zip(*false_pos)
68 |         plt.scatter(x, y, c="r", s=point_size, alpha=0.4, label="false positives")
69 |         plt.scatter(y, x, c="r", s=point_size, alpha=0.4)
70 | 
--------------------------------------------------------------------------------
/tests/test_parsing.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import numpy as np
3 | import unittest
4 | 
5 | from mogwai.parsing import one_hot, load_a3m_msa, contacts_from_cf
6 | 
7 | 
8 | class TestOneHot(unittest.TestCase):
9 |     def setUp(self):
10 |         self.msa = np.array(
11 |             [
12 |                 [2, 5, 4],
13 |                 [3, 2, 19],
14 |             ]
15 |         )
16 |         self.oh = one_hot(self.msa)
17 | 
18 |     def test_shape(self):
19 |         self.assertTupleEqual(self.oh.shape, (2, 3, 20))
20 | 
21 |     def test_argmax(self):
22 |         idx = np.argmax(self.oh, -1)
23 |         np.testing.assert_array_equal(idx, self.msa)
24 | 
25 |     def test_pad(self):
26 |         padded_msa = np.array([[12, 13, -1, -1], [-1, 9, -1, 8]])
27 |         padded_idx = np.array([[False, False, True, True], [True, False, True, False]])
28 | 
29 |         oh = one_hot(padded_msa)
30 |         test_padded_idx = np.sum(oh, -1) == 0
31 | 
32 |         np.testing.assert_array_equal(padded_idx, test_padded_idx)
33 | 
34 | 
35 | class TestA3MLoading(unittest.TestCase):
36 |     def setUp(self):
37 |         self.path = Path("data/test/4rb6Y") / "4rb6Y.i90c75.a3m"
38 |         if not self.path.exists():
39 |             raise FileNotFoundError(
40 |                 "Please download 4rb6Y using scripts/download_example.sh to run tests."
41 |             )
42 |         self.msa, self.ms, _, self.ref = load_a3m_msa(self.path)
43 | 
44 |     def test_msa_shape(self):
45 |         self.assertTupleEqual(self.msa.shape, (7569, 107, 20))
46 | 
47 |     def test_ms_shape(self):
48 |         self.assertTupleEqual(self.ms.shape, (7569, 162, 20))
49 | 
50 |     def test_msa_indel(self):
51 |         # Test for gaps at start of second seq in msa.
52 |         seq = self.msa[1].argmax(-1)
53 |         self.assertEqual(seq[0], 0)
54 |         self.assertEqual(seq[1], 0)
55 | 
56 |         # Test for no gaps at end
57 |         self.assertEqual(seq[-3], 5)
58 |         self.assertEqual(seq[-4], 6)
59 | 
60 |     def test_ms_indel(self):
61 |         # Test no gaps at start of second seq in ms.
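        # Under FastaVocab's alphabet "ARNDCQEGHILKMFPSTWYV-", index 11 is K
        # and index 16 is T, i.e. real residues rather than gap tokens.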
62 | seq = self.ms[1].argmax(-1) 63 | self.assertEqual(seq[0], 11) 64 | self.assertEqual(seq[1], 16) 65 | 66 | # Test for gaps at end 67 | self.assertEqual(seq[-3], 0) 68 | self.assertEqual(seq[-4], 0) 69 | 70 | def test_reference(self): 71 | reference = "MRVKMHVKKGDTVLVASGKYKGRVGKVKEVLPKKYAVIVEGVNIVKKAVRVSPKYPQGGFIEKEAPLHASKVRPICPACGKPTRVRKKFLENGKKIRVCAKCGGALD" 72 | self.assertEqual(self.ref, reference) 73 | 74 | 75 | class TestCfLoading(unittest.TestCase): 76 | def setUp(self): 77 | self.path = Path("data/test/4rb6Y") / "4rb6Y.cf" 78 | if not self.path.exists(): 79 | raise FileNotFoundError( 80 | "Please download 4rb6Y using scripts/download_example.sh to run tests." 81 | ) 82 | self.contacts = contacts_from_cf(self.path) 83 | 84 | def test_contact_shape(self): 85 | shape = self.contacts.shape 86 | self.assertTupleEqual(shape, (107, 107)) 87 | 88 | def test_zero_contacts(self): 89 | num_zero_contacts = np.sum(self.contacts == 0) 90 | self.assertEqual(num_zero_contacts, 10169) 91 | 92 | def test_contact_mass(self): 93 | contact_mass = np.sum(self.contacts) 94 | self.assertAlmostEqual(contact_mass, 1158.90602) 95 | 96 | 97 | if __name__ == "__main__": 98 | unittest.main() 99 | -------------------------------------------------------------------------------- /mogwai/utils/functional.py: -------------------------------------------------------------------------------- 1 | """Contains transformations shared by models.""" 2 | 3 | import torch 4 | 5 | 6 | def symmetrize_matrix(inp: torch.Tensor) -> torch.Tensor: 7 | """Symmetrize matrix in additive fashion. 8 | 9 | Symmetrizes A with (A + A.T)/2. 10 | 11 | Args: 12 | inp (tensor): Matrix to symmetrize. 13 | """ 14 | return 0.5 * (inp + inp.transpose(-1, -2)) 15 | 16 | 17 | def symmetrize_matrix_(inp: torch.Tensor) -> torch.Tensor: 18 | """Inplace version of symmetrize_matrix. 19 | 20 | Args: 21 | inp (tensor): Matrix to symmetrize. 22 | """ 23 | inp.add_(inp.transpose(-1, -2).clone()) 24 | inp.mul_(0.5) 25 | return inp 26 | 27 | 28 | def symmetrize_potts(weight: torch.Tensor) -> torch.Tensor: 29 | """Symmetrize 4D Potts coupling tensor. 30 | 31 | Enforces the constraint that W(i,j) = W(j,i) for coupling matrices. 32 | 33 | Args: 34 | weight (tensor): 4D tensor of shape (length, vocab, length, vocab) to symmetrize. 35 | """ 36 | return 0.5 * (weight + weight.permute(2, 3, 0, 1)) 37 | 38 | 39 | def symmetrize_potts_(weight: torch.Tensor) -> torch.Tensor: 40 | """Inplace version of symmetrize_potts. 41 | 42 | Args: 43 | weight (tensor): 4D tensor of shape (length, vocab, length, vocab) to symmetrize. 44 | """ 45 | weight.add_(weight.permute(2, 3, 0, 1).clone()) 46 | weight.mul_(0.5) 47 | return weight 48 | 49 | 50 | def zero_diag(inp: torch.Tensor) -> torch.Tensor: 51 | """Zeros all elements of diagonal. 52 | 53 | Computes diagonal along last two axes of input tensor. 54 | 55 | Args: 56 | inp (tensor): Tensor to zero out. 57 | """ 58 | 59 | diag_mask = torch.eye( 60 | inp.size(-2), 61 | inp.size(-1), 62 | dtype=torch.bool, 63 | device=inp.device, 64 | ) 65 | return inp.masked_fill(diag_mask, 0.0) 66 | 67 | 68 | def zero_diag_(inp: torch.Tensor) -> torch.Tensor: 69 | """Inplace version of zero_diag. 70 | 71 | Args: 72 | inp (tensor): Tensor to zero out. 
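
    Illustrative: zero_diag_(torch.ones(3, 3)) returns a matrix of ones with a
    zero main diagonal, modifying the input tensor in place.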
73 |     """
74 |     diag_mask = torch.eye(
75 |         inp.size(-2),
76 |         inp.size(-1),
77 |         dtype=torch.bool,
78 |         device=inp.device,
79 |     )
80 |     return inp.masked_fill_(diag_mask, 0.0)
81 | 
82 | 
83 | def apc(inp: torch.Tensor, remove_diag: bool = False) -> torch.Tensor:
84 |     """Compute Average Product Correction (APC) of tensor.
85 | 
86 |     Applies correction along the last two axes: corrected = inp - rowsum * colsum / totalsum, then the diagonal is zeroed.
87 | 
88 |     Args:
89 |         inp (tensor): Tensor to correct.
90 |         remove_diag (bool, optional): Whether to zero out diagonal before correcting.
91 |     """
92 | 
93 |     if remove_diag:
94 |         inp = zero_diag(inp)
95 | 
96 |     a1 = inp.sum(-1, keepdims=True)  # type: ignore
97 |     a2 = inp.sum(-2, keepdims=True)  # type: ignore
98 |     corr = inp - (a1 * a2) / inp.sum((-1, -2), keepdims=True)  # type: ignore
99 |     corr = zero_diag(corr)
100 |     return corr
101 | 
102 | 
103 | def apc_(inp: torch.Tensor, remove_diag: bool = False) -> torch.Tensor:
104 |     """Inplace version of apc.
105 | 
106 |     Args:
107 |         inp (tensor): Tensor to correct.
108 |         remove_diag (bool, optional): Whether to zero out diagonal before correcting.
109 |     """
110 |     if remove_diag:
111 |         zero_diag_(inp)
112 | 
113 |     a1 = inp.sum(-1, keepdims=True)  # type: ignore
114 |     a2 = inp.sum(-2, keepdims=True)  # type: ignore
115 |     a12 = inp.sum((-1, -2), keepdims=True)  # type: ignore
116 |     corr = a1 * a2
117 |     corr.div_(a12)
118 | 
119 |     inp.sub_(corr)
120 |     zero_diag_(inp)
121 |     return inp
122 | 
--------------------------------------------------------------------------------
/mogwai/data/maskedlm_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict
2 | from argparse import ArgumentParser
3 | import torch
4 | from .base_wrapper_dataset import BaseWrapperDataset
5 | from ..utils import collate_tensors
6 | from ..vocab import FastaVocab
7 | 
8 | 
9 | class MaskedLMDataset(BaseWrapperDataset):
10 |     """MaskedLMDataset implements masking tokens with a specified mask index.
11 | 
12 |     Args:
13 |         dataset (torch.utils.data.Dataset): Dataset of tensors to wrap.
14 |         mask_idx (int): Index of mask token.
15 |         pad_idx (int): Index of pad token, used to mark non-target positions.
16 |         vocab_size (int): Vocabulary size.
17 |         mask_prob (float): Probability of masking a token.
18 |         mask_rnd_prob (float): Probability of replacing token with a random token.
19 |         mask_leave_prob (float): Probability of leaving token as-is.
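
    With mask_prob=0.15, mask_rnd_prob=0.1, and mask_leave_prob=0.1 this
    reproduces the BERT-style 80/10/10 corruption scheme; the defaults here
    (0.0/0.0) replace every selected token with mask_idx.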
20 |     """
21 | 
22 |     def __init__(
23 |         self,
24 |         dataset: torch.utils.data.Dataset,
25 |         mask_idx: int,
26 |         pad_idx: int,
27 |         vocab_size: int,
28 |         mask_prob: float = 0.15,
29 |         mask_rnd_prob: float = 0.0,
30 |         mask_leave_prob: float = 0.0,
31 |     ):
32 |         super().__init__(dataset)
33 |         self.mask_idx = mask_idx
34 |         self.pad_idx = pad_idx
35 |         self.vocab_size = vocab_size
36 |         self.mask_prob = mask_prob
37 |         self.mask_rnd_prob = mask_rnd_prob
38 |         self.mask_leave_prob = mask_leave_prob
39 | 
40 |         assert 0 < mask_prob < 1
41 |         assert 0 <= mask_rnd_prob < 1
42 |         assert 0 <= mask_leave_prob < 1
43 |         assert mask_leave_prob + mask_rnd_prob < 1
44 | 
45 |     def __getitem__(self, idx: int):
46 |         item = self.dataset[idx]
47 |         if isinstance(item, tuple) and len(item) == 1:
48 |             item = item[0]
49 | 
50 |         mask = torch.rand_like(item, dtype=torch.float) < self.mask_prob
51 |         targets = item.masked_fill(~mask, self.pad_idx)
52 | 
53 |         token_type_probs = torch.rand_like(item, dtype=torch.float)
54 |         is_leave = token_type_probs < self.mask_leave_prob
55 |         is_rnd = (
56 |             token_type_probs < (self.mask_leave_prob + self.mask_rnd_prob)
57 |         ) & ~is_leave
58 |         is_mask = ~(is_leave | is_rnd)
59 | 
60 |         # Do not make this in-place
61 |         item = item.masked_fill(mask & is_mask, self.mask_idx)
62 |         item[is_rnd & mask] = torch.randint_like(item, self.vocab_size)[is_rnd & mask]
63 | 
64 |         return {"src_tokens": item, "targets": targets}
65 | 
66 |     @staticmethod
67 |     def add_args(parser: ArgumentParser) -> ArgumentParser:
68 |         parser.add_argument(
69 |             "--mask_prob",
70 |             type=float,
71 |             default=0.15,
72 |             help="Probability of masking tokens.",
73 |         )
74 |         parser.add_argument(
75 |             "--mask_rnd_prob",
76 |             type=float,
77 |             default=0.0,
78 |             help="Probability of replacing a masked token with a random token.",
79 |         )
80 |         parser.add_argument(
81 |             "--mask_leave_prob",
82 |             type=float,
83 |             default=0.0,
84 |             help="Probability of leaving the original token in place at a masked position.",
85 |         )
86 |         return parser
87 | 
88 |     def collater(self, batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
89 |         concat = {
90 |             "src_tokens": collate_tensors(
91 |                 [element["src_tokens"] for element in batch], FastaVocab.pad_idx
92 |             ),
93 |             "targets": collate_tensors(
94 |                 [element["targets"] for element in batch], FastaVocab.pad_idx
95 |             ),
96 |             "src_lengths": torch.tensor(
97 |                 [len(element["src_tokens"]) for element in batch], dtype=torch.long
98 |             ),
99 |         }
100 |         return concat
101 | 
--------------------------------------------------------------------------------
/mogwai/data_loading/ms_datamodule.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser, Namespace
2 | 
3 | from typing import Union, Optional
4 | from pathlib import Path
5 | import torch
6 | import pytorch_lightning as pl
7 | from ..data.ms_dataset import MSDataset, MSStats
8 | from ..data.repeat_dataset import RepeatDataset
9 | from ..data.trrosetta_ms_dataset import TRRosetta_MSDataset
10 | from ..data.pseudolikelihood_dataset import PseudolikelihoodDataset
11 | from ..data.maskedlm_dataset import MaskedLMDataset
12 | from ..vocab import FastaVocab
13 | 
14 | 
15 | class MSDataModule(pl.LightningDataModule):
16 |     """Creates a dataset from an A3M, fasta, or TRRosetta npz file.
17 | 
18 |     Args:
19 |         data (Union[str, Path]): Path to fasta, npz, or a3m file to load sequences.
20 |         batch_size (int, optional): Batch size for DataLoader. Default 128.
21 |         num_repeats (int, optional): Number of times to repeat dataset (can speed up
22 |             training for small datasets). Default 100.
23 |         task (str, optional): Which task to train with.
24 |             Choices: ['pseudolikelihood', 'masked_lm']. Default: 'pseudolikelihood'.
25 |         mask_prob (float, optional): Probability of masking a token when using
26 |             'masked_lm' task. Default: 0.15.
27 |         mask_rnd_prob (float, optional): Probability of using a random token when using
28 |             'masked_lm' task. Default: 0.1.
29 |         mask_leave_prob (float, optional): Probability of leaving original token when
30 |             using 'masked_lm' task. Default: 0.1.
31 |     """
32 | 
33 |     def __init__(
34 |         self,
35 |         data: Union[str, Path],
36 |         batch_size: int = 128,
37 |         num_repeats: int = 100,
38 |         task: str = "pseudolikelihood",
39 |         mask_prob: float = 0.15,
40 |         mask_rnd_prob: float = 0.1,
41 |         mask_leave_prob: float = 0.1,
42 |     ):
43 |         super().__init__()
44 |         self.data = Path(data)
45 |         self.batch_size = batch_size
46 |         self.num_repeats = num_repeats
47 |         self.task = task
48 |         self.mask_prob = mask_prob
49 |         self.mask_rnd_prob = mask_rnd_prob
50 |         self.mask_leave_prob = mask_leave_prob
51 | 
52 |     def setup(self, stage: Optional[str] = None):
53 |         if self.data.suffix == ".npz":
54 |             ms_dataset = TRRosetta_MSDataset(self.data)
55 |         else:
56 |             ms_dataset = MSDataset(self.data)
57 |         dataset = RepeatDataset(ms_dataset, self.num_repeats)
58 |         if self.task == "pseudolikelihood":
59 |             dataset = PseudolikelihoodDataset(dataset)
60 |         elif self.task == "masked_lm":
61 |             dataset = MaskedLMDataset(
62 |                 dataset,
63 |                 FastaVocab.pad_idx,
64 |                 FastaVocab.pad_idx,
65 |                 len(FastaVocab),
66 |                 self.mask_prob,
67 |                 self.mask_rnd_prob,
68 |                 self.mask_leave_prob,
69 |             )
70 |         elif self.task != "none":  # allow none to load raw sequences
71 |             raise ValueError(f"Invalid task {self.task}")
72 |         self.dataset = dataset
73 |         self.ms_dataset = ms_dataset
74 |         self.dims = (ms_dataset.num_seqs, len(ms_dataset.reference))
75 | 
76 |     def get_stats(self) -> MSStats:
77 |         try:
78 |             return self.ms_dataset.get_stats()
79 |         except AttributeError:
80 |             raise RuntimeError(
81 |                 "Trying to get MSA stats before calling setup on module."
82 |             )
83 | 
84 |     def train_dataloader(self):
85 |         return torch.utils.data.DataLoader(
86 |             self.dataset,
87 |             batch_size=self.batch_size,
88 |             shuffle=True,
89 |             pin_memory=True,
90 |             num_workers=8,
91 |             collate_fn=self.dataset.collater,
92 |         )
93 | 
94 |     @classmethod
95 |     def from_args(cls, args: Namespace) -> "MSDataModule":
96 |         return cls(
97 |             args.data,
98 |             args.batch_size,
99 |             args.num_repeats,
100 |             args.task,
101 |             args.mask_prob,
102 |             args.mask_rnd_prob,
103 |             args.mask_leave_prob,
104 |         )
105 | 
106 |     @staticmethod
107 |     def add_args(parser: ArgumentParser) -> ArgumentParser:
108 |         MSDataset.add_args(parser)
109 |         RepeatDataset.add_args(parser)
110 |         PseudolikelihoodDataset.add_args(parser)
111 |         MaskedLMDataset.add_args(parser)
112 |         parser.add_argument("--data", type=str, help="Data file to load from.")
113 |         parser.add_argument(
114 |             "--batch_size", type=int, default=128, help="Batch size for training."
115 |         )
116 |         parser.add_argument(
117 |             "--task",
118 |             choices=["pseudolikelihood", "masked_lm"],
119 |             default="pseudolikelihood",
120 |             help="Whether to use Pseudolikelihood or Masked LM for training",
121 |         )
122 |         return parser
123 | 
--------------------------------------------------------------------------------
/mogwai/train.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import pytorch_lightning as pl
3 | import torch
4 | from pathlib import Path
5 | import numpy as np
6 | 
7 | from mogwai.data_loading import MSADataModule
8 | from mogwai.parsing import read_contacts
9 | from mogwai import models
10 | from mogwai.utils.functional import apc
11 | from mogwai.metrics import contact_auc
12 | from mogwai.vocab import FastaVocab
13 | 
14 | 
15 | def train():
16 |     # Initialize parser
17 |     parser = ArgumentParser()
18 |     parser.add_argument(
19 |         "--model",
20 |         default="gremlin",
21 |         choices=models.MODELS.keys(),
22 |         help="Which model to train.",
23 |     )
24 |     model_name = parser.parse_known_args()[0].model
25 |     parser.add_argument(
26 |         "--structure_file",
27 |         type=str,
28 |         default=None,
29 |         help=(
30 |             "Optional pdb or cf file containing protein structure. "
31 |             "Used for evaluation."
32 |         ),
33 |     )
34 |     parser.add_argument(
35 |         "--output_file",
36 |         type=str,
37 |         default=None,
38 |         help="Optional file to output gremlin weights.",
39 |     )
40 |     parser.add_argument(
41 |         "--contacts_file",
42 |         type=str,
43 |         default=None,
44 |         help="Optional file to output gremlin contacts.",
45 |     )
46 |     parser.add_argument(
47 |         "--wandb_project",
48 |         type=str,
49 |         default=None,
50 |         help="Optional wandb project to log to.",
51 |     )
52 |     parser = MSADataModule.add_args(parser)
53 |     parser = pl.Trainer.add_argparse_args(parser)
54 |     parser.set_defaults(
55 |         gpus=1,
56 |         min_steps=50,
57 |         max_steps=1000,
58 |     )
59 |     model_type = models.get(model_name)
60 |     model_type.add_args(parser)
61 |     args = parser.parse_args()
62 | 
63 |     # Load msa
64 |     msa_dm = MSADataModule.from_args(args)
65 |     msa_dm.setup()
66 | 
67 |     # Load contacts
68 |     true_contacts = (
69 |         torch.from_numpy(read_contacts(args.structure_file))
70 |         if args.structure_file is not None
71 |         else None
72 |     )
73 | 
74 |     # Initialize model
75 |     num_seqs, msa_length, msa_counts = msa_dm.get_stats()
76 |     model = model_type.from_args(
77 |         args,
78 |         num_seqs=num_seqs,
79 |         msa_length=msa_length,
80 |         msa_counts=msa_counts,
81 |         vocab_size=len(FastaVocab),
82 |         pad_idx=FastaVocab.pad_idx,
83 |         true_contacts=true_contacts,
84 |     )
85 | 
86 |     kwargs = {}
87 |     if args.wandb_project:
88 |         try:
89 |             # Requires wandb to be installed
90 |             logger = pl.loggers.WandbLogger(project=args.wandb_project)
91 |             logger.log_hyperparams(args)
92 |             logger.log_hyperparams(
93 |                 {
94 |                     "pdb": Path(args.data).stem,
95 |                     "num_seqs": num_seqs,
96 |                     "msa_length": msa_length,
97 |                 }
98 |             )
99 |             kwargs["logger"] = logger
100 |         except ImportError:
101 |             raise ImportError(
102 |                 "Cannot use W&B logger without wandb installed. Run `pip install wandb` first."
103 |             )
104 | 
105 |     # Initialize Trainer
106 |     trainer = pl.Trainer.from_argparse_args(args, checkpoint_callback=False, **kwargs)
107 | 
108 |     trainer.fit(model, msa_dm)
109 | 
110 |     if true_contacts is not None:
111 |         contacts = model.get_contacts()
112 |         auc = contact_auc(contacts, true_contacts).item()
113 |         contacts_apc = apc(contacts)
114 |         auc_apc = contact_auc(contacts_apc, true_contacts).item()
115 |         print(f"AUC: {auc:0.3f}, AUC_APC: {auc_apc:0.3f}")
116 | 
117 |         if args.wandb_project:
118 |             import matplotlib.pyplot as plt
119 |             import wandb
120 | 
121 |             from mogwai.plotting import (
122 |                 plot_colored_preds_on_trues,
123 |                 plot_precision_vs_length,
124 |             )
125 | 
126 |             filename = "top_L_contacts.png"
127 |             plot_colored_preds_on_trues(contacts, true_contacts, point_size=5)
128 |             logger.log_metrics({filename: wandb.Image(plt)})
129 |             plt.close()
130 | 
131 |             filename = "top_L_contacts_apc.png"
132 |             plot_colored_preds_on_trues(contacts_apc, true_contacts, point_size=5)
133 |             logger.log_metrics({filename: wandb.Image(plt)})
134 |             plt.close()
135 | 
136 |             filename = "precision_vs_L.png"
137 |             plot_precision_vs_length(contacts_apc, true_contacts)
138 |             logger.log_metrics({filename: wandb.Image(plt)})
139 |             plt.close()
140 | 
141 |     if args.output_file is not None:
142 |         torch.save(model.state_dict(), args.output_file)
143 | 
144 |     if args.contacts_file is not None:
145 |         contacts = model.get_contacts()
146 |         contacts = apc(contacts)
147 |         x_ind, y_ind = np.triu_indices_from(contacts, 1)
148 |         contacts = contacts[x_ind, y_ind]
149 |         torch.save(contacts, args.contacts_file)
150 | 
151 | 
152 | 
153 | if __name__ == "__main__":
154 |     train()
155 | 
--------------------------------------------------------------------------------
/mogwai/models/base_model.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser, Namespace
2 | from typing import List, Optional
3 | 
4 | from abc import abstractmethod
5 | 
6 | import torch
7 | import pytorch_lightning as pl
8 | 
9 | from ..utils import apc
10 | from ..metrics import contact_auc, precision_at_cutoff
11 | 
12 | 
13 | class BaseModel(pl.LightningModule):
14 |     """Base model containing shared init and functionality for all single-MSA models.
15 | 
16 |     Args:
17 |         num_seqs (int): Number of sequences in MSA.
18 |         msa_length (int): Length of MSA.
19 |         learning_rate (float): Learning rate for training model.
20 |         vocab_size (int, optional): Alphabet size of MSA.
21 |         true_contacts (tensor, optional): True contacts for family. Used to compute metrics while training.
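
    Example: a minimal, hypothetical subclass sketch (not part of this
    package) showing the contract that the concrete models such as Gremlin
    implement -- ``forward`` returns a tuple whose first element is the loss,
    and ``get_contacts`` returns an L x L matrix of contact scores:

        class ToyModel(BaseModel):
            def __init__(self, num_seqs, msa_length, learning_rate=0.5):
                super().__init__(num_seqs, msa_length, learning_rate)
                # One trainable score per position pair.
                self.scores = torch.nn.Parameter(torch.zeros(msa_length, msa_length))

            def forward(self, src_tokens, targets=None, src_lengths=None):
                return (self.scores.pow(2).sum(),)  # (loss, ...) tuple

            def get_contacts(self):
                return self.scores.detach()

            def configure_optimizers(self):
                return [torch.optim.Adam(self.parameters(), lr=self.learning_rate)]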
22 | """ 23 | 24 | def __init__( 25 | self, 26 | num_seqs: int, 27 | msa_length: int, 28 | learning_rate: float, 29 | vocab_size: int = 20, 30 | true_contacts: Optional[torch.Tensor] = None, 31 | ): 32 | super().__init__() 33 | self.num_seqs = num_seqs 34 | self.msa_length = msa_length 35 | self.vocab_size = vocab_size 36 | self.learning_rate = learning_rate 37 | 38 | if true_contacts is not None: 39 | self.register_buffer("_true_contacts", true_contacts, persistent=False) 40 | self.has_true_contacts = True 41 | else: 42 | self.has_true_contacts = False 43 | 44 | self.register_buffer("_max_auc", torch.tensor(0.0), persistent=False) 45 | 46 | def training_step(self, batch, batch_nb): 47 | if isinstance(batch, tuple): 48 | loss, *_ = self.forward(*batch) 49 | elif isinstance(batch, dict): 50 | loss, *_ = self.forward(**batch) 51 | else: 52 | loss, *_ = self.forward(batch) 53 | 54 | if self.has_true_contacts: 55 | auc = self.get_auc(do_apc=False) 56 | auc_apc = self.get_auc(do_apc=True) 57 | 58 | self._max_auc.masked_fill_(self._max_auc < auc, auc) 59 | 60 | self.log("auc", auc, on_step=True, on_epoch=False, prog_bar=False) 61 | self.log("auc_apc", auc_apc, on_step=True, on_epoch=False, prog_bar=True) 62 | self.log( 63 | "max_auc", self._max_auc, on_step=True, on_epoch=False, prog_bar=True 64 | ) 65 | self.log( 66 | "delta_auc", 67 | self._max_auc - auc, 68 | on_step=True, 69 | on_epoch=False, 70 | prog_bar=False, 71 | ) 72 | p_at_l = self.get_precision(do_apc=False) 73 | p_at_l_5 = self.get_precision(do_apc=False, cutoff=5) 74 | self.log("pr_at_L", p_at_l, on_step=True, on_epoch=False, prog_bar=False) 75 | self.log( 76 | "pr_at_L_5", p_at_l_5, on_step=True, on_epoch=False, prog_bar=False 77 | ) 78 | 79 | p_at_l_apc = self.get_precision(do_apc=True) 80 | p_at_l_5_apc = self.get_precision(do_apc=True, cutoff=5) 81 | self.log( 82 | "pr_at_L_apc", p_at_l_apc, on_step=True, on_epoch=False, prog_bar=True 83 | ) 84 | self.log( 85 | "pr_at_L_5_apc", 86 | p_at_l_5_apc, 87 | on_step=True, 88 | on_epoch=False, 89 | prog_bar=False, 90 | ) 91 | 92 | return { 93 | "loss": loss, 94 | } 95 | 96 | @abstractmethod 97 | def get_contacts(self): 98 | raise NotImplementedError 99 | 100 | @torch.no_grad() 101 | def get_precision( 102 | self, 103 | do_apc: bool = True, 104 | thresh: float = 0.01, 105 | superdiag: int = 6, 106 | cutoff: int = 1, 107 | ): 108 | if not self.has_true_contacts: 109 | raise ValueError( 110 | "Model not provided with ground truth contacts, precision can't be computed." 111 | ) 112 | contacts = self.get_contacts() 113 | if do_apc: 114 | contacts = apc(contacts) 115 | return precision_at_cutoff( 116 | contacts, self._true_contacts, thresh, superdiag, cutoff # type: ignore 117 | ) 118 | 119 | @torch.no_grad() 120 | def get_auc( 121 | self, 122 | do_apc: bool = True, 123 | thresh: float = 0.01, 124 | superdiag: int = 6, 125 | cutoff_range: List[int] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 126 | ): 127 | if not self.has_true_contacts: 128 | raise ValueError( 129 | "Model not provided with ground truth contacts, precision can't be computed." 
130 |             )
131 |         contacts = self.get_contacts()
132 |         if do_apc:
133 |             contacts = apc(contacts)
134 |         return contact_auc(
135 |             contacts, self._true_contacts, thresh, superdiag, cutoff_range  # type: ignore
136 |         )
137 | 
138 |     @classmethod
139 |     @abstractmethod
140 |     def from_args(cls, args: Namespace, *unused, **unusedkw) -> "BaseModel":
141 |         raise NotImplementedError
142 | 
143 |     @staticmethod
144 |     def add_args(parser: ArgumentParser) -> ArgumentParser:
145 |         return parser
146 | 
--------------------------------------------------------------------------------
/mogwai/data_loading/msa_datamodule.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser, Namespace
2 | 
3 | from typing import Union, Optional
4 | from pathlib import Path
5 | import torch
6 | import pytorch_lightning as pl
7 | from ..data.msa_dataset import MSADataset, MSAStats
8 | from ..data.trrosetta_dataset import TRRosetta_MSADataset
9 | from ..data.repeat_dataset import RepeatDataset
10 | from ..data.max_steps_dataset import MaxStepsDataset
11 | from ..data.pseudolikelihood_dataset import PseudolikelihoodDataset
12 | from ..data.maskedlm_dataset import MaskedLMDataset
13 | from ..vocab import FastaVocab
14 | 
15 | 
16 | class MSADataModule(pl.LightningDataModule):
17 |     """Creates dataset from an MSA file.
18 | 
19 |     Args:
20 |         data (Union[str, Path]): Path to a3m or npz file to load MSA.
21 |         batch_size (int, optional): Batch size for DataLoader. Default 128.
22 |         num_repeats (int, optional): Number of times to repeat dataset (can speed up
23 |             training for small datasets). Default 100.
24 |         task (str, optional): Which task to train with.
25 |             Choices: ['pseudolikelihood', 'masked_lm']. Default: 'pseudolikelihood'.
26 |         mask_prob (float, optional): Probability of masking a token when using
27 |             'masked_lm' task. Default: 0.15.
28 |         mask_rnd_prob (float, optional): Probability of using a random token when using
29 |             'masked_lm' task. Default: 0.0.
30 |         mask_leave_prob (float, optional): Probability of leaving original token when
31 |             using 'masked_lm' task. Default: 0.0.
32 |         max_steps (int, optional): If greater than 0, wraps the dataset in
33 |             MaxStepsDataset so one epoch covers exactly this many batches.
34 |             Default: -1 (disabled).
35 |     """
36 | 
37 |     def __init__(
38 |         self,
39 |         data: Union[str, Path],
40 |         batch_size: int = 128,
41 |         num_repeats: int = 100,
42 |         task: str = "pseudolikelihood",
43 |         mask_prob: float = 0.15,
44 |         mask_rnd_prob: float = 0.0,
45 |         mask_leave_prob: float = 0.0,
46 |         max_steps: int = -1,
47 |     ):
48 |         super().__init__()
49 |         self.data = Path(data)
50 |         self.batch_size = batch_size
51 |         self.num_repeats = num_repeats
52 |         self.task = task
53 |         self.mask_prob = mask_prob
54 |         self.mask_rnd_prob = mask_rnd_prob
55 |         self.mask_leave_prob = mask_leave_prob
56 |         self.max_steps = max_steps
57 | 
58 |     def setup(self, stage: Optional[str] = None):
59 |         if self.data.suffix == ".a3m":
60 |             msa_dataset = MSADataset(self.data)
61 |         elif self.data.suffix == ".npz":
62 |             msa_dataset = TRRosetta_MSADataset(self.data)
63 |         else:
64 |             raise ValueError(
65 |                 f"Cannot read file of type {self.data.suffix}, must be one of (.a3m,"
66 |                 " .npz)."
67 |             )
68 |         if self.max_steps > 0:
69 |             dataset = MaxStepsDataset(msa_dataset, self.max_steps, self.batch_size)
70 |         else:
71 |             dataset = RepeatDataset(msa_dataset, self.num_repeats)
72 |         if self.task == "pseudolikelihood":
73 |             dataset = PseudolikelihoodDataset(dataset)
74 |         elif self.task == "masked_lm":
75 |             dataset = MaskedLMDataset(
76 |                 dataset,
77 |                 FastaVocab.pad_idx,
78 |                 FastaVocab.pad_idx,
79 |                 len(FastaVocab),
80 |                 self.mask_prob,
81 |                 self.mask_rnd_prob,
82 |                 self.mask_leave_prob,
83 |             )
84 |         else:
85 |             raise ValueError(f"Invalid task {self.task}")
86 |         self.dataset = dataset
87 |         self.msa_dataset = msa_dataset
88 |         self.dims = (msa_dataset.num_seqs, msa_dataset.msa_length)
89 | 
90 |     def get_stats(self) -> MSAStats:
91 |         try:
92 |             return self.msa_dataset.get_stats()
93 |         except AttributeError:
94 |             raise RuntimeError(
95 |                 "Trying to get MSA stats before calling setup on module."
96 |             )
97 | 
98 |     def train_dataloader(self):
99 |         return torch.utils.data.DataLoader(
100 |             self.dataset,
101 |             batch_size=self.batch_size,
102 |             shuffle=True,
103 |             pin_memory=True,
104 |             num_workers=4,
105 |             collate_fn=self.dataset.collater,
106 |         )
107 | 
108 |     @classmethod
109 |     def from_args(cls, args: Namespace) -> "MSADataModule":
110 |         return cls(
111 |             args.data,
112 |             args.batch_size,
113 |             args.num_repeats,
114 |             args.task,
115 |             args.mask_prob,
116 |             args.mask_rnd_prob,
117 |             args.mask_leave_prob,
118 |             max_steps=getattr(args, "max_steps", -1),
119 |         )
120 | 
121 |     @staticmethod
122 |     def add_args(parser: ArgumentParser) -> ArgumentParser:
123 |         MSADataset.add_args(parser)
124 |         RepeatDataset.add_args(parser)
125 |         PseudolikelihoodDataset.add_args(parser)
126 |         MaskedLMDataset.add_args(parser)
127 |         parser.add_argument("--data", type=str, help="Data file to load from.")
128 |         parser.add_argument(
129 |             "--batch_size", type=int, default=128, help="Batch size for training."
130 |         )
131 |         parser.add_argument(
132 |             "--task",
133 |             choices=["pseudolikelihood", "masked_lm"],
134 |             default="pseudolikelihood",
135 |             help="Whether to use Pseudolikelihood or Masked LM for training",
136 |         )
137 |         return parser
138 | 
--------------------------------------------------------------------------------
/mogwai/metrics/metrics.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 | 
3 | 
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from scipy.stats import entropy
8 | 
9 | 
10 | def precision_at_cutoff(
11 |     pred: torch.Tensor,
12 |     meas: torch.Tensor,
13 |     thresh: float = 0.01,
14 |     superdiag: int = 6,
15 |     cutoff: int = 1,
16 | ):
17 |     """Computes precision for top L/k contacts.
18 | 
19 |     Args:
20 |         pred (tensor): Predicted contact scores or probabilities.
21 |         meas (tensor): Binary matrix of true contacts.
22 |         thresh (float, optional): Threshold at which to call a predicted contact.
23 |         superdiag (int, optional): Ignore all true and predicted contacts from diag to superdiag.
24 |         cutoff (int, optional): Only compute precision of top L/cutoff predictions.
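
    Example (an illustrative sketch with random inputs; any L x L score
    matrix and binary contact map work the same way):

        L = 100
        pred = torch.rand(L, L)                      # predicted contact scores
        meas = (torch.rand(L, L) > 0.9).float()      # binary true-contact map
        p_at_l = precision_at_cutoff(pred, meas, cutoff=1)   # top L predictions
        p_at_l5 = precision_at_cutoff(pred, meas, cutoff=5)  # top L/5 predictions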
25 | """ 26 | 27 | # Subset everything above superdiag 28 | eval_idx = np.triu_indices_from(meas, superdiag) 29 | pred_, meas_ = pred[eval_idx], meas[eval_idx] 30 | 31 | # Sort by model confidence 32 | sort_idx = pred_.argsort(descending=True) 33 | 34 | # Extract top predictions and calculate 35 | len_cutoff = (len(meas) / torch.tensor(cutoff)).int() 36 | preds = meas_[sort_idx][:len_cutoff] 37 | 38 | num_positives = (meas_[sort_idx][:len_cutoff] > thresh).sum().float() 39 | num_preds = len(preds) 40 | precision = num_positives / num_preds 41 | 42 | return precision 43 | 44 | 45 | # https://github.com/rmrao/explore-protein-attentcion/blob/main/metrics.py 46 | def precisions_in_range( 47 | predictions: torch.Tensor, 48 | targets: torch.Tensor, 49 | src_lengths: Optional[torch.Tensor] = None, 50 | minsep: int = 6, 51 | maxsep: Optional[int] = None, 52 | ): 53 | if predictions.dim() == 2: 54 | predictions = predictions.unsqueeze(0) 55 | if targets.dim() == 2: 56 | targets = targets.unsqueeze(0) 57 | 58 | # Check sizes 59 | if predictions.size() != targets.size(): 60 | raise ValueError( 61 | f"Size mismatch. Received predictions of size {predictions.size()}, " 62 | f"targets of size {targets.size()}" 63 | ) 64 | device = predictions.device 65 | 66 | batch_size, seqlen, _ = predictions.size() 67 | seqlen_range = torch.arange(seqlen, device=device) 68 | 69 | sep = seqlen_range.unsqueeze(0) - seqlen_range.unsqueeze(1) 70 | sep = sep.unsqueeze(0) 71 | valid_mask = sep >= minsep 72 | 73 | if maxsep is not None: 74 | valid_mask &= sep < maxsep 75 | 76 | if src_lengths is not None: 77 | valid = seqlen_range.unsqueeze(0) < src_lengths.unsqueeze(1) 78 | valid_mask &= valid.unsqueeze(1) & valid.unsqueeze(2) 79 | else: 80 | src_lengths = torch.full( 81 | [batch_size], seqlen, device=device, dtype=torch.long) 82 | 83 | predictions = predictions.masked_fill(~valid_mask, float("-inf")) 84 | 85 | x_ind, y_ind = np.triu_indices(seqlen, minsep) 86 | predictions_upper = predictions[:, x_ind, y_ind] 87 | targets_upper = targets[:, x_ind, y_ind] 88 | 89 | indices = predictions_upper.argsort(dim=-1, descending=True)[:, :seqlen] 90 | # indices = predictions_upper.topk( 91 | # dim=-1, k=seqlen, sorted=True, largest=True 92 | # ).indices 93 | 94 | topk_targets = targets_upper[torch.arange(batch_size), indices] > 0.01 95 | n_targets = topk_targets.size(1) 96 | if n_targets < seqlen: 97 | topk_targets = F.pad(topk_targets, [0, seqlen - n_targets]) 98 | 99 | cumulative_dist = topk_targets.type_as(predictions).cumsum(-1) 100 | 101 | gather_indices = ( 102 | torch.arange(0.1, 1.1, 0.1, device=device).unsqueeze(0) 103 | * src_lengths.unsqueeze(1) 104 | ).type(torch.long) - 1 105 | 106 | binned_cumulative_dist = cumulative_dist.gather(1, gather_indices) 107 | binned_precisions = binned_cumulative_dist / (gather_indices + 1).type_as( 108 | binned_cumulative_dist 109 | ) 110 | 111 | pl10 = binned_precisions[:, 0] 112 | pl5 = binned_precisions[:, 1] 113 | pl2 = binned_precisions[:, 4] 114 | pl = binned_precisions[:, 9] 115 | auc = binned_precisions.mean(-1) 116 | 117 | return { 118 | "auc": auc, 119 | "pr_at_l": pl, 120 | "pr_at_l_2": pl2, 121 | "pr_at_l_5": pl5, 122 | "pr_at_l_10": pl10 123 | } 124 | 125 | 126 | def contact_auc( 127 | pred: torch.Tensor, 128 | meas: torch.Tensor, 129 | thresh: float = 0.01, 130 | superdiag: int = 6, 131 | cutoff_range: List[int] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 132 | ): 133 | """Compute modified Area Under PR Curve. 
134 | 135 | Args: 136 | pred (tensor): Predicted contact scores or probabilities. 137 | meas (tensor): Binary matrix of true contacts. 138 | thresh (float, optional): Threshold at which to call a predicted contact. 139 | superdiag (int, optional): Ignore all true and predicted contacts from diag to superdiag. 140 | cutoff_range (List[int], optional): Range of precision cutoffs to use for averaging. 141 | """ 142 | 143 | # Nick: This does not agree with normal AUPR, but instead computes 144 | # precision for top L/10, L/9, ... then averages them together. 145 | # Compared to normal precision this puts more weight on a small number 146 | # of high confidence true positives. 147 | 148 | # True aupr could be computed via 149 | # 150 | # from sklearn.metrics import average_precision_score 151 | # 152 | # eval_idx = np.triu_indices_from(meas, superdiag) 153 | # pred_, meas_ = pred[eval_idx], meas[eval_idx] 154 | # aupr = average_precision_score(meas_ > thresh, pred_) 155 | 156 | binned_precisions = [ 157 | precision_at_cutoff(pred, meas, thresh, superdiag, c) for c in cutoff_range 158 | ] 159 | return torch.stack(binned_precisions, 0).mean() 160 | 161 | 162 | def get_len_stdev(msa): 163 | lengths = [] 164 | for seq in msa: 165 | num_gaps = len(np.where(seq == 20)[0]) 166 | len_unaligned = len(seq) - num_gaps 167 | lengths.append(len_unaligned) 168 | return np.std(lengths) 169 | 170 | 171 | def get_len_entropy(msa): 172 | lengths = [] 173 | for seq in msa: 174 | num_gaps = len(np.where(seq == 20)[0]) 175 | len_unaligned = len(seq) - num_gaps 176 | lengths.append(len_unaligned) 177 | return entropy(lengths) 178 | -------------------------------------------------------------------------------- /mogwai/lr_schedulers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Modifications by Roshan Rao 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch optimization for BERT model.""" 17 | 18 | import logging 19 | import math 20 | from typing import Optional, Type 21 | 22 | from torch.optim.lr_scheduler import LambdaLR 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | class ConstantLRSchedule(LambdaLR): 28 | """ Constant learning rate schedule. 29 | """ 30 | def __init__(self, 31 | optimizer, 32 | warmup_steps: Optional[int] = None, 33 | t_total: Optional[int] = None, 34 | last_epoch: int = -1): 35 | super(ConstantLRSchedule, self).__init__( 36 | optimizer, lambda _: 1.0, last_epoch=last_epoch) # type: ignore 37 | 38 | 39 | class WarmupConstantSchedule(LambdaLR): 40 | """ Linear warmup and then constant. 41 | Linearly increases learning rate schedule from 0 to 1 over `warmup_steps` 42 | training steps. Keeps learning rate schedule equal to 1. after warmup_steps. 
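
        Example (a sketch; assumes an ``optimizer`` has already been constructed):

            scheduler = WarmupConstantSchedule(optimizer, warmup_steps=100)
            for batch in loader:          # hypothetical training loop
                ...                       # forward/backward on the batch
                optimizer.step()
                scheduler.step()
            # The LR multiplier ramps linearly from 0 to 1 over the first 100
            # steps, then stays at 1.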
43 | """ 44 | def __init__(self, 45 | optimizer, 46 | warmup_steps: int, 47 | t_total: Optional[int] = None, 48 | last_epoch: int = -1): 49 | self.warmup_steps = warmup_steps 50 | super(WarmupConstantSchedule, self).__init__( 51 | optimizer, self.lr_lambda, last_epoch=last_epoch) # type: ignore 52 | 53 | def lr_lambda(self, step): 54 | if step < self.warmup_steps: 55 | return float(step) / float(max(1.0, self.warmup_steps)) 56 | return 1. 57 | 58 | 59 | class WarmupLinearSchedule(LambdaLR): 60 | """ Linear warmup and then linear decay. 61 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 62 | Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` 63 | steps. 64 | """ 65 | def __init__(self, 66 | optimizer, 67 | warmup_steps: int, 68 | t_total: int, 69 | last_epoch: int = -1): 70 | self.warmup_steps = warmup_steps 71 | self.t_total = t_total 72 | super(WarmupLinearSchedule, self).__init__( 73 | optimizer, self.lr_lambda, last_epoch=last_epoch) # type: ignore 74 | 75 | def lr_lambda(self, step): 76 | if step < self.warmup_steps: 77 | return float(step) / float(max(1, self.warmup_steps)) 78 | return max(0.0, float(self.t_total - step) / float( 79 | max(1.0, self.t_total - self.warmup_steps))) 80 | 81 | 82 | class WarmupCosineSchedule(LambdaLR): 83 | """ Linear warmup and then cosine decay. 84 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 85 | Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps 86 | following a cosine curve. If `cycles` (default=0.5) is different from default, learning 87 | rate follows cosine function after warmup. 88 | """ 89 | def __init__(self, 90 | optimizer, 91 | warmup_steps: int, 92 | t_total: int, 93 | cycles: float = .5, 94 | last_epoch: int = -1): 95 | self.warmup_steps = warmup_steps 96 | self.t_total = t_total 97 | self.cycles = cycles 98 | super(WarmupCosineSchedule, self).__init__( 99 | optimizer, self.lr_lambda, last_epoch=last_epoch) # type: ignore 100 | 101 | def lr_lambda(self, step): 102 | if step < self.warmup_steps: 103 | return float(step) / float(max(1.0, self.warmup_steps)) 104 | # progress after warmup 105 | progress = float(step - self.warmup_steps) / float( 106 | max(1, self.t_total - self.warmup_steps)) 107 | return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) 108 | 109 | 110 | class WarmupCosineWithHardRestartsSchedule(LambdaLR): 111 | """ Linear warmup and then cosine cycles with hard restarts. 112 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 113 | If `cycles` (default=1.) is different from default, learning rate follows `cycles` times 114 | a cosine decaying learning rate (with hard restarts). 115 | """ 116 | def __init__(self, 117 | optimizer, 118 | warmup_steps: int, 119 | t_total: int, 120 | cycles: float = 1., 121 | last_epoch: int = -1): 122 | self.warmup_steps = warmup_steps 123 | self.t_total = t_total 124 | self.cycles = cycles 125 | super(WarmupCosineWithHardRestartsSchedule, self).__init__( 126 | optimizer, self.lr_lambda, last_epoch=last_epoch) # type: ignore 127 | 128 | def lr_lambda(self, step): 129 | if step < self.warmup_steps: 130 | return float(step) / float(max(1, self.warmup_steps)) 131 | # progress after warmup 132 | progress = float(step - self.warmup_steps) / float( 133 | max(1, self.t_total - self.warmup_steps)) 134 | if progress >= 1.0: 135 | return 0.0 136 | return max(0.0, 0.5 * (1. 
+ math.cos( 137 | math.pi * ((float(self.cycles) * progress) % 1.0)))) 138 | 139 | 140 | LR_SCHEDULERS = { 141 | 'constant': ConstantLRSchedule, 142 | 'warmup_constant': WarmupConstantSchedule, 143 | 'warmup_linear': WarmupLinearSchedule, 144 | 'warmup_cosine': WarmupCosineSchedule, 145 | 'warmup_cosine_with_restarts': WarmupCosineWithHardRestartsSchedule} 146 | 147 | 148 | def get(scheduler: str) -> Type[LambdaLR]: 149 | try: 150 | return LR_SCHEDULERS[scheduler] 151 | except KeyError: 152 | raise KeyError(f"Unrecognized lr_scheduler {scheduler}") 153 | -------------------------------------------------------------------------------- /mogwai/parsing.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List, Tuple, Dict, Optional, Any 2 | 3 | from Bio import SeqIO 4 | from biotite.structure.io.pdb import PDBFile 5 | from scipy.spatial.distance import pdist, squareform 6 | from pathlib import Path 7 | import numpy as np 8 | import string 9 | from .vocab import FastaVocab 10 | 11 | PathLike = Union[str, Path] 12 | 13 | 14 | def one_hot(x, cat=None): 15 | """Onehot encodes a sequence of ints.""" 16 | if cat is None: 17 | cat = np.max(x) + 1 18 | oh = np.concatenate((np.eye(cat), np.zeros([1, cat]))) 19 | return oh[x] 20 | 21 | 22 | def parse_fasta( 23 | filename: Union[str, Path], 24 | remove_insertions: bool = False, 25 | remove_gaps: bool = False, 26 | ) -> Tuple[List[str], List[str]]: 27 | 28 | filename = Path(filename) 29 | if filename.suffix == ".sto": 30 | form = "stockholm" 31 | elif filename.suffix in (".fas", ".fasta", ".a3m"): 32 | form = "fasta" 33 | else: 34 | raise ValueError(f"Unknown file format {filename.suffix}") 35 | 36 | translate_dict: Dict[str, Optional[str]] = {} 37 | if remove_insertions: 38 | translate_dict.update(dict.fromkeys(string.ascii_lowercase)) 39 | else: 40 | translate_dict.update(dict(zip(string.ascii_lowercase, string.ascii_uppercase))) 41 | 42 | if remove_gaps: 43 | translate_dict["-"] = None 44 | 45 | translate_dict["."] = None 46 | translate_dict["*"] = None 47 | translation = str.maketrans(translate_dict) 48 | 49 | def process_record(record: SeqIO.SeqRecord): 50 | return record.description, str(record.seq).translate(translation) 51 | records = SeqIO.parse(str(filename), form) 52 | records = map(process_record, records) 53 | records = zip(*records) 54 | headers, sequences = tuple(records) 55 | return headers, sequences 56 | 57 | 58 | def get_seqref(x: str) -> Tuple[List[int], List[int], List[int]]: 59 | # input: string 60 | # output 61 | # -seq: unaligned sequence (remove gaps, lower to uppercase, 62 | # numeric(A->0, R->1...)) 63 | # -ref: reference describing how each sequence aligns to the first 64 | # (reference sequence) 65 | n, seq, ref, aligned_seq = 0, [], [], [] 66 | for aa in x: 67 | if aa != "-": 68 | seq.append(FastaVocab.A2N.get(aa.upper(), -1)) 69 | if aa.islower(): 70 | ref.append(-1) 71 | n -= 1 72 | else: 73 | ref.append(n) 74 | aligned_seq.append(seq[-1]) 75 | else: 76 | aligned_seq.append(-1) 77 | n += 1 78 | return np.array(seq), np.array(ref), np.array(aligned_seq) 79 | 80 | 81 | def load_a3m_msa(filename) -> Tuple[Any, Any, Any, str]: 82 | """ 83 | Given A3M file (from hhblits) 84 | return Tuple of (MSA (aligned), MS (unaligned), ALN (alignment), reference sequence) 85 | """ 86 | names, seqs = parse_fasta(filename) 87 | 88 | reference = seqs[0] 89 | # get the multiple sequence alignment 90 | max_len = 0 91 | ms, aln, msa = [], [], [] 92 | for seq in seqs: 93 | seq_, ref_, 
aligned_seq_ = get_seqref(seq)
94 |         max_len = max(max_len, len(seq_))
95 |         ms.append(seq_)
96 |         msa.append(aligned_seq_)
97 |         aln.append(ref_)
98 | 
99 |     # pad each unaligned-sequence and alignment to same length
100 |     for n in range(len(ms)):
101 |         pad = max_len - len(ms[n])
102 |         ms[n] = np.pad(ms[n], [0, pad], constant_values=-1)
103 |         aln[n] = np.pad(aln[n], [0, pad], constant_values=-1)
104 | 
105 |     # TODO nthomas - figure out why this is causing errors in test_parsing.py
106 |     return one_hot(msa), one_hot(ms), one_hot(aln), reference
107 | 
108 | 
109 | def contacts_from_cf(filename: PathLike, cutoff=0.001, sequence=None) -> np.ndarray:
110 |     # contact Y,1 Y,2 0.006281 MET ARG
111 |     n, cons = 0, []
112 |     with open(filename, "r") as f:
113 |         for line in f:
114 |             line = line.rstrip()
115 |             if line[:7] == "contact":
116 |                 _, _, i, _, j, p, _, _ = line.replace(",", " ").split()
117 |                 i, j, p = int(i), int(j), float(p)
118 |                 if i > n:
119 |                     n = i
120 |                 if j > n:
121 |                     n = j
122 |                 cons.append([i - 1, j - 1, p])
123 |             if line.startswith("SEQUENCE") and sequence is not None:
124 |                 seq = line.split()[1:]
125 |                 seq = "".join(FastaVocab.THREE_LETTER[code] for code in seq)
126 |                 start = seq.index(sequence)
127 |                 end = start + len(sequence)
128 |                 break
129 |         else:
130 |             start = 0
131 |             end = n
132 |     cm = np.zeros([n, n])
133 |     for i, j, p in cons:
134 |         cm[i, j] = p
135 |     contacts = cm + cm.T
136 |     contacts = contacts[start:end, start:end]
137 |     return contacts
138 | 
139 | 
140 | def extend(a, b, c, L, A, D):
141 |     """
142 |     input: 3 coords (a,b,c), (L)ength, (A)ngle, and (D)ihedral
143 |     output: 4th coord
144 |     """
145 | 
146 |     def normalize(x):
147 |         return x / np.linalg.norm(x, ord=2, axis=-1, keepdims=True)
148 | 
149 |     bc = normalize(b - c)
150 |     n = normalize(np.cross(b - a, bc))
151 |     m = [bc, np.cross(n, bc), n]
152 |     d = [L * np.cos(A), L * np.sin(A) * np.cos(D), -L * np.sin(A) * np.sin(D)]
153 |     return c + sum([m * d for m, d in zip(m, d)])
154 | 
155 | 
156 | def contacts_from_pdb(
157 |     filename: PathLike, distance_threshold: float = 8.0
158 | ) -> np.ndarray:
159 |     pdbfile = PDBFile.read(str(filename))
160 |     structure = pdbfile.get_structure()
161 | 
162 |     N = structure.coord[0, structure.atom_name == "N"]
163 |     C = structure.coord[0, structure.atom_name == "C"]
164 |     CA = structure.coord[0, structure.atom_name == "CA"]
165 | 
166 |     Cbeta = extend(C, N, CA, 1.522, 1.927, -2.143)
167 |     distogram = squareform(pdist(Cbeta))
168 |     return distogram < distance_threshold
169 | 
170 | 
171 | def contacts_from_trrosetta(
172 |     filename: PathLike,
173 |     distance_threshold: float = 8.0,
174 | ):
175 |     fam_data = np.load(filename)
176 |     dist = fam_data["dist6d"]
177 |     nat_contacts = dist * ((dist > 0) & (dist < distance_threshold))
178 |     return nat_contacts
179 | 
180 | 
181 | def read_contacts(filename: PathLike, **kwargs) -> np.ndarray:
182 |     filename = Path(filename)
183 |     if filename.suffix == ".cf":
184 |         return contacts_from_cf(filename, **kwargs)
185 |     elif filename.suffix == ".pdb":
186 |         return contacts_from_pdb(filename, **kwargs)
187 |     elif filename.suffix == ".npz":
188 |         return contacts_from_trrosetta(filename, **kwargs)
189 |     else:
190 |         raise ValueError(
191 |             f"Cannot read file of type {filename.suffix}, must be one of (.cf, .pdb, .npz)"
192 |         )
193 | 
--------------------------------------------------------------------------------
/mogwai/models/gremlin.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser, Namespace
2 | import
math 3 | from typing import Optional, Dict 4 | import gzip 5 | import io 6 | 7 | import torch 8 | import torch.nn as nn 9 | import numpy as np 10 | 11 | from .base_model import BaseModel 12 | from ..optim import GremlinAdam 13 | from ..utils import symmetrize_potts_ 14 | from ..utils.init import ( 15 | init_potts_bias, 16 | init_potts_weight, 17 | init_pseudolik_mask, 18 | gremlin_weight_decay_coeffs, 19 | ) 20 | 21 | 22 | class Gremlin(BaseModel): 23 | def __init__( 24 | self, 25 | num_seqs: int, 26 | msa_length: int, 27 | msa_counts: Optional[torch.Tensor] = None, 28 | optimizer: str = "gremlin_adam", 29 | learning_rate: float = 0.5, 30 | vocab_size: int = 20, 31 | true_contacts: Optional[torch.Tensor] = None, 32 | l2_coeff: float = 1e-2, 33 | use_bias: bool = True, 34 | pad_idx: int = 20, 35 | ): 36 | super().__init__(num_seqs, msa_length, learning_rate, vocab_size, true_contacts) 37 | self.l2_coeff = l2_coeff 38 | self.use_bias = use_bias 39 | self.pad_idx = pad_idx 40 | self.optimizer = optimizer 41 | 42 | weight = init_potts_weight(msa_length, vocab_size) 43 | self.weight = nn.Parameter(weight, True) 44 | 45 | mask = init_pseudolik_mask(msa_length) 46 | self.register_buffer("diag_mask", mask, persistent=False) 47 | 48 | if self.use_bias: 49 | if msa_counts is not None: 50 | bias = init_potts_bias(msa_counts, l2_coeff, num_seqs) 51 | else: 52 | bias = torch.zeros(msa_length, vocab_size) 53 | self.bias = nn.Parameter(bias, True) 54 | 55 | self.register_buffer("one_hot", torch.eye(vocab_size + 1, vocab_size), persistent=False) 56 | 57 | @torch.no_grad() 58 | def apply_constraints(self): 59 | # Symmetrize and mask diagonal 60 | self.weight.data = symmetrize_potts_(self.weight.data) 61 | self.weight.data.mul_(self.diag_mask[:, None, :, None]) 62 | 63 | def maybe_onehot_inputs(self, src_tokens): 64 | """Onehots src_tokens if necessary otherwise uses original tokens""" 65 | if src_tokens.dtype == torch.long: 66 | return self.one_hot[src_tokens] 67 | else: 68 | return src_tokens 69 | 70 | def forward(self, src_tokens, targets=None, src_lengths=None): 71 | self.apply_constraints() 72 | inputs = self.maybe_onehot_inputs(src_tokens) 73 | logits = torch.tensordot(inputs, self.weight, 2) 74 | if self.use_bias: 75 | logits = logits + self.bias 76 | 77 | outputs = (logits,) 78 | if targets is not None: 79 | loss = self.loss(logits, targets) 80 | outputs = (loss,) + outputs 81 | 82 | return outputs 83 | 84 | def loss(self, logits, targets): 85 | """Compute GREMLIN loss w/ L2 Regularization""" 86 | loss = nn.CrossEntropyLoss(ignore_index=self.pad_idx, reduction="sum")( 87 | logits.view(-1, self.vocab_size), targets.view(-1) 88 | ) 89 | loss *= self.num_seqs / logits.size(0) 90 | loss += self.compute_regularization(targets) 91 | return loss 92 | 93 | def compute_regularization(self, targets): 94 | """Compute regularization weights based on the number of targets.""" 95 | batch_size = targets.size(0) 96 | 97 | weight_reg_coeff, bias_reg_coeff = gremlin_weight_decay_coeffs( 98 | batch_size, self.msa_length, self.l2_coeff, self.vocab_size 99 | ) 100 | 101 | sample_size = (targets != self.pad_idx).sum() 102 | # After multiplying by sample_size, comes to lambda * L * A / 2 103 | reg = weight_reg_coeff * self.weight.pow(2).sum() 104 | if self.use_bias: 105 | # After multiplying by sample_size, comes to lambda 106 | reg += bias_reg_coeff * self.bias.pow(2).sum() 107 | 108 | return reg * sample_size 109 | 110 | def configure_optimizers(self): 111 | if self.optimizer == "gremlin_adam": 112 | optimizer = 
GremlinAdam( 113 | self.parameters(), lr=self.learning_rate, weight_decay=0.0 114 | ) 115 | elif self.optimizer == "adam": 116 | 117 | self.learning_rate *= math.log(self.num_seqs) / self.msa_length 118 | optimizer = torch.optim.Adam( 119 | self.parameters(), lr=self.learning_rate, weight_decay=0.0 120 | ) 121 | return [optimizer] 122 | 123 | @torch.no_grad() 124 | def get_contacts(self): 125 | """Extracts contacts by taking Frobenius norm of each interaction matrix.""" 126 | self.apply_constraints() 127 | contacts = self.weight.data.norm(p=2, dim=(1, 3)) 128 | return contacts 129 | 130 | @classmethod 131 | def from_args( 132 | cls, 133 | args: Namespace, 134 | num_seqs: int, 135 | msa_length: int, 136 | msa_counts: Optional[torch.Tensor] = None, 137 | vocab_size: int = 20, 138 | pad_idx: int = 20, 139 | true_contacts: Optional[torch.Tensor] = None, 140 | ) -> "Gremlin": 141 | return cls( 142 | num_seqs=num_seqs, 143 | msa_length=msa_length, 144 | msa_counts=msa_counts, 145 | learning_rate=args.learning_rate, 146 | vocab_size=vocab_size, 147 | true_contacts=true_contacts, 148 | l2_coeff=args.l2_coeff, 149 | use_bias=args.use_bias, 150 | pad_idx=pad_idx, 151 | optimizer=args.optimizer, 152 | ) 153 | 154 | @staticmethod 155 | def add_args(parser: ArgumentParser) -> ArgumentParser: 156 | parser.add_argument( 157 | "--learning_rate", 158 | type=float, 159 | default=0.5, 160 | help="Learning rate for training.", 161 | ) 162 | parser.add_argument( 163 | "--l2_coeff", 164 | type=float, 165 | default=1e-2, 166 | help="L2 Regularization Coefficient.", 167 | ) 168 | parser.add_argument( 169 | "--use_bias", action="store_true", help="Use a bias when training GREMLIN." 170 | ) 171 | parser.add_argument( 172 | "--no_bias", 173 | action="store_false", 174 | help="Use a bias when training GREMLIN.", 175 | dest="use_bias", 176 | ) 177 | parser.add_argument( 178 | "--optimizer", 179 | choices=["adam", "gremlin_adam"], 180 | default="gremlin_adam", 181 | help="Which optimizer to use.", 182 | ) 183 | return parser 184 | 185 | def save_compressed_state(self, path): 186 | """ Saves the GREMLIN state dict in a highly compressed manner (50x reduction). 187 | 188 | First, note that GREMLIN parameters are symmetric, and the diagonal is always 189 | zero. Saving only the upper half gets us a 2x reduction in space. Next, instead 190 | of saving weights in full precision, we can save in half precision. This *is* a 191 | lossy conversion, however in practice it is unlikely to matter. Converting to 192 | half precision gains us another 2x reduction in space. Finally, the data 193 | compress well with gzip, netting a ~12.5x reduction in space. 194 | 195 | Note that these transformations must be reversed when loading the data. See 196 | `load_compressed_state`. 197 | """ 198 | state = {key: tensor.half() for key, tensor in self.state_dict().items()} 199 | weight = state["weight"] 200 | x_ind, y_ind = np.triu_indices(weight.size(0), 1) 201 | state["weight"] = weight[x_ind, :, y_ind, :] 202 | buffer = io.BytesIO() 203 | torch.save(state, buffer) 204 | buffer.seek(0) 205 | data = buffer.read() 206 | with gzip.open(path, "wb") as f: 207 | f.write(data) 208 | 209 | @classmethod 210 | def load_compressed_state(cls, path) -> Dict[str, torch.Tensor]: 211 | """ Reverses the transformations in `save_compressed_state`. See for details. 
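
        Example round trip (a sketch; the file name is arbitrary):

            model.save_compressed_state("gremlin_state.pt.gz")
            state = Gremlin.load_compressed_state("gremlin_state.pt.gz")
            model.load_state_dict(state)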
212 | """ 213 | with gzip.open(path, "rb") as f: 214 | state = torch.load(f, map_location="cpu") 215 | weight = state["weight"] 216 | vocab_size = weight.size(1) 217 | 218 | # The actual sequence length is not saved, however we know that the number of 219 | # upper-diag values is N = (L * (L - 1)) / 2. Therefore L - 1 < sqrt(2N) < L. 220 | # So we can find L = ceil(sqrt(2N)) 221 | seqlen = math.ceil(math.sqrt(2 * weight.size(0))) 222 | 223 | full_weight = torch.zeros( 224 | seqlen, vocab_size, seqlen, vocab_size, dtype=weight.dtype 225 | ) 226 | x_ind, y_ind = np.triu_indices(seqlen, 1) 227 | full_weight[x_ind, :, y_ind, :] = weight 228 | full_weight.add_(full_weight.permute(2, 3, 0, 1).clone()) 229 | state["weight"] = full_weight 230 | state = {key: tensor.float() for key, tensor in state.items()} 231 | return state 232 | -------------------------------------------------------------------------------- /mogwai/models/multilayer_attention.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser, Namespace 2 | from typing import Optional 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | from apex.optimizers import FusedLAMB 8 | 9 | from .base_model import BaseModel 10 | from ..utils import symmetrize_matrix_ 11 | from ..utils.init import init_potts_bias 12 | 13 | 14 | class MultiheadAttention(nn.Module): 15 | def __init__( 16 | self, 17 | input_size: int, 18 | num_attention_heads: int, 19 | attention_head_size: int, 20 | output_size: int, 21 | ): 22 | super().__init__() 23 | hidden_size = num_attention_heads * attention_head_size 24 | self.num_attention_heads = num_attention_heads 25 | self.attention_head_size = attention_head_size 26 | self.qkv = nn.Linear(input_size, hidden_size * 3, bias=False) 27 | self.output = nn.Linear(hidden_size, output_size) 28 | 29 | def forward(self, inputs): 30 | batch_size, seqlen, _ = inputs.size() 31 | queries, keys, values = self.qkv(inputs).chunk(dim=-1, chunks=3) 32 | queries = queries.view( 33 | batch_size, seqlen, self.num_attention_heads, self.attention_head_size 34 | ) 35 | keys = keys.view( 36 | batch_size, seqlen, self.num_attention_heads, self.attention_head_size 37 | ) 38 | values = values.view( 39 | batch_size, seqlen, self.num_attention_heads, self.attention_head_size 40 | ) 41 | attention = torch.einsum("nihd,njhd->nhij", queries, keys) 42 | attention = attention / math.sqrt(self.attention_head_size) 43 | attention = attention.softmax(-1) 44 | context = torch.einsum("nhij,njhd->nihd", attention, values) 45 | context = context.reshape( 46 | batch_size, seqlen, self.num_attention_heads * self.attention_head_size 47 | ) 48 | return self.output(context), attention 49 | 50 | 51 | class MultilayerAttention(BaseModel): 52 | """Attention Layer. 53 | 54 | Args: 55 | num_seqs (int): Number of sequences in MSA. 56 | msa_length (int): Length of MSA. 57 | msa_counts (tensor): Counts of each amino acid in each position of MSA. Used 58 | for initialization. 59 | learning_rate (float): Learning rate for training model. 60 | vocab_size (int, optional): Alphabet size of MSA. 61 | true_contacts (tensor, optional): True contacts for family. Used to compute 62 | metrics while training. 63 | l2_coeff (int, optional): Coefficient of L2 regularization for all weights. 64 | use_bias (bool, optional): Whether to include single-site potentials. 
65 | """ 66 | 67 | def __init__( 68 | self, 69 | num_seqs: int, 70 | msa_length: int, 71 | msa_counts: torch.Tensor, 72 | attention_head_size: int = 16, 73 | num_attention_heads: int = 32, 74 | num_layers: int = 1, 75 | optimizer: str = "adam", 76 | learning_rate: float = 1e-3, 77 | vocab_size: int = 20, 78 | true_contacts: Optional[torch.Tensor] = None, 79 | l2_coeff: float = 1e-2, 80 | use_bias: bool = True, 81 | pad_idx: int = 20, 82 | ): 83 | super().__init__(num_seqs, msa_length, learning_rate, vocab_size, true_contacts) 84 | self.l2_coeff = l2_coeff 85 | self.use_bias = use_bias 86 | self.pad_idx = pad_idx 87 | self.num_seqs = num_seqs 88 | self.msa_length = msa_length 89 | self.num_attention_heads = num_attention_heads 90 | self.attention_head_size = attention_head_size 91 | self.optimizer = optimizer 92 | 93 | hidden_size = attention_head_size * num_attention_heads 94 | 95 | layers = nn.ModuleList() 96 | for layer in range(num_layers): 97 | input_size = hidden_size if layer > 0 else msa_length + vocab_size 98 | output_size = hidden_size if layer + 1 < num_layers else vocab_size 99 | layers.append( 100 | MultiheadAttention( 101 | input_size, num_attention_heads, attention_head_size, output_size 102 | ) 103 | ) 104 | self.layers = layers 105 | 106 | if self.use_bias: 107 | bias = init_potts_bias(msa_counts, l2_coeff, num_seqs) 108 | bias = nn.Parameter(bias, True) 109 | self.register_parameter("bias", bias) 110 | 111 | self.register_buffer("posembed", torch.eye(msa_length).unsqueeze(0)) 112 | self.register_buffer("one_hot", torch.eye(vocab_size + 1, vocab_size)) 113 | self.register_buffer("diag_mask", torch.eye(msa_length) * -10000) 114 | 115 | def forward(self, src_tokens, targets=None): 116 | batch_size, seqlen = src_tokens.size() 117 | inputs = self.one_hot[src_tokens] 118 | posembed = self.posembed.repeat(batch_size, 1, 1) 119 | inputs = torch.cat((inputs, posembed), -1) 120 | 121 | for layer in self.layers: 122 | outputs, attention = layer(inputs) 123 | if outputs.size() == inputs.size(): 124 | outputs = outputs + inputs 125 | inputs = outputs 126 | 127 | logits = inputs 128 | if self.use_bias: 129 | logits = logits + self.bias 130 | 131 | outputs = (logits, attention) 132 | if targets is not None: 133 | loss = nn.CrossEntropyLoss(ignore_index=self.pad_idx, reduction="sum")( 134 | logits.view(-1, self.vocab_size), targets.view(-1) 135 | ) 136 | loss = loss / inputs.size(0) 137 | outputs = (loss,) + outputs 138 | return outputs 139 | 140 | def configure_optimizers(self): 141 | if self.optimizer == "adam": 142 | optimizer = torch.optim.AdamW( 143 | self.parameters(), lr=self.learning_rate, weight_decay=self.l2_coeff 144 | ) 145 | elif self.optimizer == "lamb": 146 | optimizer = FusedLAMB( 147 | self.parameters(), 148 | lr=self.learning_rate, 149 | weight_decay=self.l2_coeff, 150 | ) 151 | else: 152 | raise ValueError(f"Unrecognized optimizer {self.optimizer}") 153 | return [optimizer] 154 | 155 | @torch.no_grad() 156 | def get_contacts(self): 157 | """Extracts contacts by getting the attentions.""" 158 | inputs = torch.full( 159 | [1, self.msa_length], 160 | self.pad_idx, 161 | dtype=torch.long, 162 | device=next(self.parameters()).device, 163 | ) 164 | *_, attention = self.forward(inputs) 165 | attention = attention.mean((0, 1)) 166 | attention = symmetrize_matrix_(attention) 167 | return attention 168 | 169 | @classmethod 170 | def from_args( 171 | cls, 172 | args: Namespace, 173 | num_seqs: int, 174 | msa_length: int, 175 | msa_counts: torch.Tensor, 176 | vocab_size: int = 20, 
177 |         pad_idx: int = 20,
178 |         true_contacts: Optional[torch.Tensor] = None,
179 |     ) -> "MultilayerAttention":
180 |         return cls(
181 |             num_seqs=num_seqs,
182 |             msa_length=msa_length,
183 |             msa_counts=msa_counts,
184 |             attention_head_size=args.attention_head_size,
185 |             num_attention_heads=args.num_attention_heads,
186 |             num_layers=args.num_layers,
187 |             optimizer=args.optimizer,
188 |             learning_rate=args.learning_rate,
189 |             vocab_size=vocab_size,
190 |             true_contacts=true_contacts,
191 |             l2_coeff=args.l2_coeff,
192 |             use_bias=args.use_bias,
193 |             pad_idx=pad_idx,
194 |         )
195 | 
196 |     @staticmethod
197 |     def add_args(parser: ArgumentParser) -> ArgumentParser:
198 |         parser.add_argument(
199 |             "--learning_rate",
200 |             type=float,
201 |             default=1e-3,
202 |             help="Learning rate for training.",
203 |         )
204 |         parser.add_argument(
205 |             "--l2_coeff",
206 |             type=float,
207 |             default=1e-2,
208 |             help="L2 Regularization Coefficient.",
209 |         )
210 |         parser.add_argument(
211 |             "--use_bias", action="store_true", help="Use a bias when training GREMLIN."
212 |         )
213 |         parser.add_argument(
214 |             "--no_bias",
215 |             action="store_false",
216 |             help="Use a bias when training GREMLIN.",
217 |             dest="use_bias",
218 |         )
219 |         parser.add_argument(
220 |             "--num_attention_heads",
221 |             type=int,
222 |             default=32,
223 |             help="Number of attention heads.",
224 |         )
225 |         parser.add_argument(
226 |             "--attention_head_size",
227 |             type=int,
228 |             default=16,
229 |             help="Dims in each attention head.",
230 |         )
231 |         parser.add_argument(
232 |             "--num_layers",
233 |             type=int,
234 |             default=1,
235 |             help="Number of attention layers.")
236 |         parser.add_argument(
237 |             "--optimizer",
238 |             choices=["adam", "lamb"],
239 |             default="adam",
240 |             help="Which optimizer to use.",
241 |         )
242 |         return parser
243 | 
--------------------------------------------------------------------------------
/mogwai/models/attention.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser, Namespace
2 | from typing import Optional
3 | import math
4 | 
5 | import torch
6 | import torch.nn as nn
7 | 
8 | from .base_model import BaseModel
9 | from ..utils import symmetrize_matrix_, symmetrize_potts
10 | from ..utils.init import init_potts_bias, gremlin_weight_decay_coeffs
11 | from .. import lr_schedulers
12 | 
13 | 
14 | class Attention(BaseModel):
15 |     """Attention Layer.
16 | 
17 |     Args:
18 |         num_seqs (int): Number of sequences in MSA.
19 |         msa_length (int): Length of MSA.
20 |         msa_counts (tensor, optional): Counts of each amino acid in each position of MSA. Used
21 |             for initialization.
22 |         attention_head_size (int, optional): Dimension of queries and keys for a single head.
23 |         num_attention_heads (int, optional): Number of attention heads.
24 |         optimizer (str, optional): Choice of optimizer from ["adam", "lamb"]
25 |         learning_rate (float, optional): Learning rate for training model.
26 |         vocab_size (int, optional): Alphabet size of MSA.
27 |         true_contacts (tensor, optional): True contacts for family. Used to compute
28 |             metrics while training.
29 |         l2_coeff (int, optional): Coefficient of L2 regularization for all weights.
30 |         use_bias (bool, optional): Whether to include single-site potentials.
31 |         pad_idx (int, optional): Integer for padded positions.
32 |         lr_scheduler (str, optional): Learning schedule to use. Choose from ["constant", "warmup_constant"].
33 |         warmup_steps (int, optional): Number of warmup steps for learning rate schedule.
34 | max_steps (int, optional): Maximum number of training batches before termination. 35 | """ 36 | 37 | def __init__( 38 | self, 39 | num_seqs: int, 40 | msa_length: int, 41 | msa_counts: Optional[torch.Tensor] = None, 42 | attention_head_size: int = 16, 43 | num_attention_heads: int = 32, 44 | optimizer: str = "adam", 45 | learning_rate: float = 1e-3, 46 | vocab_size: int = 20, 47 | true_contacts: Optional[torch.Tensor] = None, 48 | l2_coeff: float = 1e-2, 49 | use_bias: bool = True, 50 | pad_idx: int = 20, 51 | lr_scheduler: str = "warmup_constant", 52 | warmup_steps: int = 0, 53 | max_steps: int = 10000, 54 | ): 55 | super().__init__(num_seqs, msa_length, learning_rate, vocab_size, true_contacts) 56 | self.l2_coeff = l2_coeff 57 | self.use_bias = use_bias 58 | self.pad_idx = pad_idx 59 | self.num_seqs = num_seqs 60 | self.msa_length = msa_length 61 | self.num_attention_heads = num_attention_heads 62 | self.attention_head_size = attention_head_size 63 | self.optimizer = optimizer 64 | self.vocab_size = vocab_size 65 | self.lr_scheduler = lr_scheduler 66 | self.warmup_steps = warmup_steps 67 | self.max_steps = max_steps 68 | 69 | hidden_size = attention_head_size * num_attention_heads 70 | 71 | self.query = nn.Linear(msa_length + vocab_size, hidden_size, bias=False) 72 | self.key = nn.Linear(msa_length + vocab_size, hidden_size, bias=False) 73 | self.value = nn.Linear(msa_length + vocab_size, hidden_size, bias=False) 74 | self.output = nn.Linear(hidden_size, vocab_size, bias=False) 75 | 76 | if self.use_bias: 77 | if msa_counts is not None: 78 | bias = init_potts_bias(msa_counts, l2_coeff, num_seqs) 79 | else: 80 | bias = torch.zeros(msa_length, vocab_size) 81 | self.bias = nn.Parameter(bias, requires_grad=True) 82 | 83 | self.register_buffer("posembed", torch.eye(msa_length).unsqueeze(0)) 84 | self.register_buffer("one_hot", torch.eye(vocab_size + 1, vocab_size)) 85 | self.register_buffer("diag_mask", torch.eye(msa_length) * -10000) 86 | 87 | # self.save_hyperparameters() 88 | 89 | def maybe_onehot_inputs(self, src_tokens): 90 | """Onehots src_tokens if necessary otherwise uses original tokens""" 91 | if src_tokens.dtype == torch.long: 92 | return self.one_hot[src_tokens] 93 | else: 94 | return src_tokens 95 | 96 | def forward(self, src_tokens, targets=None, src_lengths=None): 97 | batch_size, seqlen = src_tokens.size()[:2] 98 | aa_inputs = self.maybe_onehot_inputs(src_tokens) 99 | posembed = self.posembed.repeat(batch_size, 1, 1) 100 | inputs = torch.cat((aa_inputs, posembed), -1) 101 | 102 | queries = self.query(inputs) 103 | keys = self.key(inputs) 104 | values = self.value(inputs) 105 | 106 | queries = queries.view( 107 | batch_size, seqlen, self.num_attention_heads, self.attention_head_size 108 | ) 109 | keys = keys.view( 110 | batch_size, seqlen, self.num_attention_heads, self.attention_head_size 111 | ) 112 | values = values.view( 113 | batch_size, seqlen, self.num_attention_heads, self.attention_head_size 114 | ) 115 | attention = torch.einsum("nihd,njhd->nhij", queries, keys) 116 | attention = attention / math.sqrt(self.attention_head_size) 117 | attention = attention + self.diag_mask 118 | attention = attention.softmax(-1) 119 | 120 | context = torch.einsum("nhij,njhd->nihd", attention, values) 121 | context = context.reshape( 122 | batch_size, seqlen, self.num_attention_heads * self.attention_head_size 123 | ) 124 | logits = self.output(context).contiguous() 125 | 126 | if self.use_bias: 127 | logits = logits + self.bias 128 | 129 | outputs = (logits, attention) 130 | if 
targets is not None:
131 |             mrf_weight = self.compute_mrf_weight(attention)
132 |             loss = self.loss(logits, targets, mrf_weight)
133 |             outputs = (loss,) + outputs
134 | 
135 |         return outputs
136 | 
137 |     def configure_optimizers(self):
138 |         if self.optimizer == "adam":
139 |             optimizer = torch.optim.AdamW(
140 |                 self.parameters(), lr=self.learning_rate, weight_decay=0.0
141 |             )
142 |         elif self.optimizer == "lamb":
143 |             from apex.optimizers import FusedLAMB
144 |             optimizer = FusedLAMB(
145 |                 self.parameters(),
146 |                 lr=self.learning_rate,
147 |                 weight_decay=0.0,
148 |             )
149 |         else:
150 |             raise ValueError(f"Unrecognized optimizer {self.optimizer}")
151 | 
152 |         lr_scheduler = lr_schedulers.get(self.lr_scheduler)(
153 |             optimizer, self.warmup_steps, self.trainer.max_steps
154 |         )
155 |         scheduler_dict = {
156 |             "scheduler": lr_scheduler,
157 |             "interval": "step",
158 |         }
159 |         return [optimizer], [scheduler_dict]
160 | 
161 |     def compute_regularization(self, targets, mrf_weight: torch.Tensor):
162 |         """Compute regularization weights based on the number of targets."""
163 |         batch_size = targets.size(0)
164 | 
165 |         weight_reg_coeff, bias_reg_coeff = gremlin_weight_decay_coeffs(
166 |             batch_size, self.msa_length, self.l2_coeff, self.vocab_size
167 |         )
168 |         sample_size = (targets != self.pad_idx).sum()
169 | 
170 | 
171 |         reg = weight_reg_coeff * mrf_weight.norm()
172 |         if self.use_bias:
173 |             reg += bias_reg_coeff * self.bias.norm()
174 | 
175 |         return reg * sample_size
176 | 
177 |     def loss(self, logits, targets, mrf_weight: torch.Tensor):
178 |         """Compute GREMLIN loss w/ L2 Regularization"""
179 |         loss = nn.CrossEntropyLoss(ignore_index=self.pad_idx, reduction="sum")(
180 |             logits.view(-1, self.vocab_size), targets.view(-1)
181 |         )
182 |         loss *= self.num_seqs / logits.size(0)
183 |         loss += self.compute_regularization(targets, mrf_weight)
184 |         return loss
185 | 
186 |     def compute_mrf_weight(self, attention):
187 |         # Note that attention gives a mapping x -> MRF Weights(x),
188 |         # so it is more general than a simple Pairwise MRF.
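        # Concretely: the value and output projections are folded into a
        # single per-head embedding, which is then contracted with the
        # batch-summed attention maps to form an effective pairwise coupling
        # tensor (the analogue of the Potts weight in gremlin.py).
        # get_contacts() scores contact (i, j) by the Frobenius norm of the
        # corresponding block of this tensor.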
189 | value = self.value.weight.view( 190 | self.vocab_size + self.msa_length, 191 | self.num_attention_heads, 192 | self.attention_head_size, 193 | ) 194 | output = self.output.weight 195 | output = output.view( 196 | self.vocab_size, self.num_attention_heads, self.attention_head_size 197 | ) 198 | embed = torch.einsum("ahd,bhd->hab", value, output) # H x (A + L) x A 199 | mrf_weight = torch.einsum("hij,hab->iajb", attention.sum(0), embed) 200 | return mrf_weight 201 | 202 | @torch.no_grad() 203 | def get_contacts(self): 204 | """Extracts contacts by getting the attentions.""" 205 | inputs = torch.full( 206 | [1, self.msa_length], 207 | self.pad_idx, 208 | dtype=torch.long, 209 | device=next(self.parameters()).device, 210 | ) 211 | *_, attention = self.forward(inputs) 212 | mrf_weight = self.compute_mrf_weight(attention).squeeze() 213 | return mrf_weight.norm(dim=(1, 3)) 214 | 215 | @classmethod 216 | def from_args( 217 | cls, 218 | args: Namespace, 219 | num_seqs: int, 220 | msa_length: int, 221 | msa_counts: Optional[torch.Tensor] = None, 222 | vocab_size: int = 20, 223 | pad_idx: int = 20, 224 | true_contacts: Optional[torch.Tensor] = None, 225 | ) -> "Attention": 226 | return cls( 227 | num_seqs=num_seqs, 228 | msa_length=msa_length, 229 | msa_counts=msa_counts, 230 | attention_head_size=args.attention_head_size, 231 | num_attention_heads=args.num_attention_heads, 232 | optimizer=args.optimizer, 233 | learning_rate=args.learning_rate, 234 | vocab_size=vocab_size, 235 | true_contacts=true_contacts, 236 | l2_coeff=args.l2_coeff, 237 | use_bias=args.use_bias, 238 | pad_idx=pad_idx, 239 | ) 240 | 241 | @staticmethod 242 | def add_args(parser: ArgumentParser) -> ArgumentParser: 243 | parser.add_argument( 244 | "--learning_rate", 245 | type=float, 246 | default=1e-3, 247 | help="Learning rate for training.", 248 | ) 249 | parser.add_argument( 250 | "--l2_coeff", 251 | type=float, 252 | default=1e-2, 253 | help="L2 Regularization Coefficient.", 254 | ) 255 | parser.add_argument( 256 | "--use_bias", action="store_true", help="Use a bias when training GREMLIN." 
257 | ) 258 | parser.add_argument( 259 | "--no_bias", 260 | action="store_false", 261 | help="Disable the bias when training GREMLIN.", 262 | dest="use_bias", 263 | ) 264 | parser.add_argument( 265 | "--num_attention_heads", 266 | type=int, 267 | default=32, 268 | help="Number of attention heads.", 269 | ) 270 | parser.add_argument( 271 | "--attention_head_size", 272 | type=int, 273 | default=16, 274 | help="Dims in each attention head.", 275 | ) 276 | parser.add_argument( 277 | "--optimizer", 278 | choices=["adam", "lamb"], 279 | default="adam", 280 | help="Which optimizer to use.", 281 | ) 282 | parser.add_argument( 283 | "--lr_scheduler", 284 | choices=lr_schedulers.LR_SCHEDULERS.keys(), 285 | default="warmup_constant", 286 | help="Learning rate scheduler to use.", 287 | ) 288 | parser.add_argument( 289 | "--warmup_steps", 290 | type=int, 291 | default=0, 292 | help="How many warmup steps to use when using a warmup schedule.", 293 | ) 294 | return parser 295 | -------------------------------------------------------------------------------- /mogwai/alignment.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | from pathlib import Path 3 | import subprocess 4 | 5 | PathLike = Union[str, Path] 6 | 7 | 8 | class HHBlits: 9 | """ 10 | Args: 11 | database (str): UniClust database 12 | mact (float, optional): (0, 1] posterior prob threshold for MAC realignment 13 | controlling greediness at alignment ends: 0:global >0.1:local (default=0.35) 14 | maxfilt (int, optional): max number of hits allowed to pass 2nd prefilter 15 | (default=20000) 16 | neffmax (float, optional): (1,20] skip further search iterations when diversity 17 | Neff of query MSA becomes larger than neffmax (default=20.0) 18 | cpu (int, optional): number of CPUs to use (for shared memory SMPs) (default=2) 19 | all_seqs (bool, optional): show all sequences in result MSA; do not filter 20 | result MSA (default=False) 21 | realign_max (int, optional): realign max.
hits (default=500) 22 | maxmem (float, optional): [1,inf) limit memory for realignment (in GB) 23 | (default=3.0) 24 | n (int, optional): [1,8] number of iterations (default=2) 25 | evalue (float, optional): E-value cutoff for inclusion in result alignment 26 | (default=0.001) 27 | verbose (bool, optional): whether to print information (default=False) 28 | hhblits_bin (str, optional): hhblits binary (def="hhblits") 29 | """ 30 | 31 | def __init__( 32 | self, 33 | database: str, 34 | mact: float = 0.35, 35 | maxfilt: int = 20000, 36 | neffmax: float = 20.0, 37 | cpu: int = 2, 38 | all_seqs: bool = False, 39 | realign_max: int = 500, 40 | maxmem: float = 3, 41 | n: int = 2, 42 | verbose: bool = False, 43 | evalue: float = 0.001, 44 | diff: int = 1000, 45 | hhblits_bin: str = "hhblits", 46 | ): 47 | command = [ 48 | f"{hhblits_bin}", 49 | f"-d {database}", 50 | f"-mact {mact}", 51 | f"-maxfilt {maxfilt}", 52 | f"-neffmax {neffmax}", 53 | f"-cpu {cpu}", 54 | f"-realign_max {realign_max}", 55 | f"-maxmem {maxmem}", 56 | f"-n {n}", 57 | f"-diff {diff}", 58 | "-o /dev/null", 59 | f"-v {0 if not verbose else 2}", 60 | ] 61 | 62 | if all_seqs: 63 | command.append("-all") 64 | 65 | # The .hhr output is always discarded via "-o /dev/null" above; only the 66 | # a3m written by "-oa3m" in run() is kept. 67 | 68 | self.command = " ".join(command) 69 | self._evalue = evalue 70 | 71 | def run(self, input_file: PathLike, out_prefix: Optional[PathLike] = None) -> None: 72 | if out_prefix is None: 73 | out_prefix = input_file 74 | out_path = Path(out_prefix).with_suffix(".a3m") 75 | extra = f"-i {input_file} -oa3m {out_path} -e {self.evalue}" 76 | command = self.command.split() + extra.split() 77 | print(" ".join(command)) 78 | result = subprocess.run(command) 79 | result.check_returncode() 80 | 81 | @property 82 | def evalue(self) -> float: 83 | return self._evalue 84 | 85 | @evalue.setter 86 | def evalue(self, val): 87 | self._evalue = val 88 | 89 | 90 | class HHFilter: 91 | """ 92 | Args: 93 | verbose (bool, optional): Verbose mode (default=False) 94 | seqid (int, optional): [0,100] maximum pairwise sequence identity (%) 95 | (default=90) 96 | diff (int, optional): [0,inf) filter MSA by selecting most diverse set of 97 | sequences, keeping at least this many seqs in each MSA block of length 50 98 | (default=0) 99 | cov (int, optional): [0,100] minimum coverage with query (%) (default=0) 100 | qid (int, optional): [0,100] minimum sequence identity with query (%) 101 | (default=0) 102 | qsc (float, optional): [0,100] minimum score per column with query 103 | (default=-20.0) 104 | neff (float, optional): [1,inf] target diversity of alignment (default=off) 105 | 106 | M (str, optional): One of <'a2m', 'first', [0, 100]> 107 | * a2m: use A2M/A3M (default): upper case = Match; lower case = Insert; 108 | '-' = Delete; '.'
= gaps aligned to inserts (may be omitted) 109 | * first: use FASTA: columns with residue in 1st sequence are match states 110 | * [0,100] use FASTA: columns with fewer than X% gaps are match states 111 | maxseq (int, optional): max number of input rows (def=65535) 112 | maxres (int, optional): max number of HMM columns (def=20001) 113 | hhfilter_bin(str, optional): hhfilter binary (def="hhfilter") 114 | """ 115 | 116 | def __init__( 117 | self, 118 | verbose: bool = False, 119 | seqid: int = 90, 120 | diff: int = 0, 121 | cov: int = 0, 122 | qid: int = 0, 123 | qsc: float = -20, 124 | neff: float = -1, 125 | M: str = "a2m", 126 | maxseq: int = 65535, 127 | maxres: int = 20001, 128 | hhfilter_bin: str = "hhfilter", 129 | ): 130 | command = [ 131 | f"{hhfilter_bin}", 132 | f"-v {0 if not verbose else 2}", 133 | f"-id {seqid}", 134 | f"-diff {diff}", 135 | f"-cov {cov}", 136 | f"-qid {qid}", 137 | f"-qsc {qsc}", 138 | f"-neff {neff}", 139 | f"-M {M}", 140 | f"-maxseq {maxseq}", 141 | f"-maxres {maxres}", 142 | ] 143 | 144 | self.command = " ".join(command) 145 | 146 | def run( 147 | self, 148 | input_file: PathLike, 149 | output_file: PathLike, 150 | append_file: Optional[PathLike] = None, 151 | ) -> None: 152 | extra = f"-i {input_file} -o {output_file}" 153 | if append_file is not None: 154 | extra = extra + f" -a {append_file}" 155 | command = self.command.split() + extra.split() 156 | result = subprocess.run(command) 157 | result.check_returncode() 158 | 159 | 160 | def count_sequences(fasta_file: PathLike) -> int: 161 | num_seqs = subprocess.check_output(f'grep "^>" -c {fasta_file}', shell=True) 162 | return int(num_seqs) 163 | 164 | 165 | def remove_descriptions(fasta_file: PathLike) -> None: 166 | input_file = Path(fasta_file) 167 | output_file = input_file.with_suffix(".a3m.bk") 168 | command = f"cat {input_file} | " + r"awk '{print $1}' > " + str(output_file) 169 | subprocess.run(command, shell=True) 170 | output_file.rename(input_file) 171 | 172 | 173 | def make_a3m( 174 | input_file: str, 175 | database: str, 176 | metagenomic_database: Optional[str] = None, 177 | keep_intermediates: bool = False, 178 | hhblits_bin: str = "hhblits", 179 | hhfilter_bin: str = "hhfilter", 180 | ) -> None: 181 | hhblits = HHBlits( 182 | database, 183 | mact=0.35, 184 | maxfilt=100000000, 185 | neffmax=20, 186 | cpu=20, 187 | all_seqs=True, 188 | realign_max=10000000, 189 | maxmem=64, 190 | n=4, 191 | verbose=False, 192 | hhblits_bin=hhblits_bin, 193 | ) 194 | 195 | hhfilter_id90cov75 = HHFilter( 196 | seqid=90, cov=75, verbose=False, hhfilter_bin=hhfilter_bin 197 | ) 198 | hhfilter_id90cov50 = HHFilter( 199 | seqid=90, cov=50, verbose=False, hhfilter_bin=hhfilter_bin 200 | ) 201 | 202 | output_file = Path(input_file).with_suffix(".a3m") 203 | if output_file.exists(): 204 | raise FileExistsError(f"{output_file} already exists!") 205 | 206 | prev_a3m = Path(input_file) 207 | intermediates = [] 208 | evalues = [1e-80, 1e-60, 1e-40, 1e-20, 1e-10, 1e-8, 1e-6, 1e-4, 1e-3, 1e-1] 209 | 210 | for evalue in evalues: 211 | # HHblits at particular evalue 212 | hhblits.evalue = evalue 213 | out_path = Path(input_file).with_name(f".{Path(input_file).stem}.{evalue}.a3m") 214 | if not out_path.exists(): 215 | hhblits.run(prev_a3m, out_path) 216 | intermediates.append(out_path) 217 | 218 | # HHFilter id90, cov75 219 | id90cov75_path = Path(input_file).with_name( 220 | f".{Path(input_file).stem}.{evalue}.id90cov75.a3m" 221 | ) 222 | intermediates.append(id90cov75_path) 223 | if not id90cov75_path.exists(): 224 | 
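# Only rerun the filter when its output is missing; the depth checks below
# end the E-value sweep once the filtered MSA is deep enough
# (more than 2000 sequences at cov75, or more than 5000 at cov50).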
hhfilter_id90cov75.run(out_path, id90cov75_path) 225 | if count_sequences(id90cov75_path) > 2000: 226 | id90cov75_path.rename(output_file) 227 | break 228 | 229 | # HHFilter id90, cov50 230 | id90cov50_path = Path(input_file).with_name( 231 | f".{Path(input_file).stem}.{evalue}.id90cov50.a3m" 232 | ) 233 | intermediates.append(id90cov50_path) 234 | if not id90cov50_path.exists(): 235 | hhfilter_id90cov50.run(out_path, id90cov50_path) 236 | if count_sequences(id90cov50_path) > 5000: 237 | id90cov50_path.rename(output_file) 238 | break 239 | 240 | prev_a3m = id90cov50_path 241 | 242 | if not output_file.exists() and metagenomic_database is not None: 243 | hhblits = HHBlits( 244 | metagenomic_database, 245 | mact=0.35, 246 | maxfilt=100000000, 247 | neffmax=20, 248 | cpu=20, 249 | all_seqs=True, 250 | realign_max=10000000, 251 | maxmem=64, 252 | n=4, 253 | verbose=False, 254 | hhblits_bin=hhblits_bin, 255 | ) 256 | for evalue in [1e-80, 1e-60, 1e-40, 1e-20, 1e-10, 1e-8, 1e-6, 1e-4, 1e-3, 1e-1]: 257 | # HHblits at particular evalue 258 | hhblits.evalue = evalue 259 | out_path = Path(input_file).with_name( 260 | f".{Path(input_file).stem}.{evalue}.metagenomic.a3m" 261 | ) 262 | if not out_path.exists(): 263 | hhblits.run(prev_a3m, out_path) 264 | intermediates.append(out_path) 265 | 266 | # HHFilter id90, cov75 267 | id90cov75_path = Path(input_file).with_name( 268 | f".{Path(input_file).stem}.{evalue}.metagenomic.id90cov75.a3m" 269 | ) 270 | intermediates.append(id90cov75_path) 271 | if not id90cov75_path.exists(): 272 | hhfilter_id90cov75.run(out_path, id90cov75_path) 273 | if count_sequences(id90cov75_path) > 2000: 274 | id90cov75_path.rename(output_file) 275 | break 276 | 277 | # HHFilter id90, cov50 278 | id90cov50_path = Path(input_file).with_name( 279 | f".{Path(input_file).stem}.{evalue}.metagenomic.id90cov50.a3m" 280 | ) 281 | intermediates.append(id90cov50_path) 282 | if not id90cov50_path.exists(): 283 | hhfilter_id90cov50.run(out_path, id90cov50_path) 284 | if count_sequences(id90cov50_path) > 5000: 285 | id90cov50_path.rename(output_file) 286 | break 287 | 288 | prev_a3m = id90cov50_path 289 | if not output_file.exists(): 290 | id90cov50_path.rename(output_file) 291 | 292 | remove_descriptions(output_file) 293 | 294 | if not keep_intermediates: 295 | for intermediate in intermediates: 296 | if intermediate.exists(): 297 | intermediate.unlink() 298 | 299 | 300 | def make_a3m_cli(): 301 | import argparse 302 | 303 | parser = argparse.ArgumentParser( 304 | description="Create an alignment from a query fasta file and uniclust database" 305 | ) 306 | parser.add_argument("input_file", type=str, help="Input fasta file.") 307 | parser.add_argument("database", type=str, help="Path to uniclust database.") 308 | parser.add_argument( 309 | "--metagenomic_database", type=str, default=None, help="Path to BFD database." 
310 | ) 311 | parser.add_argument( 312 | "--keep_intermediates", 313 | action="store_true", 314 | help="Don't delete intermediate a3m files.", 315 | ) 316 | args = parser.parse_args() 317 | make_a3m( 318 | args.input_file, 319 | args.database, 320 | args.metagenomic_database, 321 | args.keep_intermediates, 322 | ) 323 | -------------------------------------------------------------------------------- /mogwai/models/factored_attention.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser, Namespace 2 | from typing import Optional 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from ..utils import symmetrize_potts 10 | from ..utils.init import init_potts_bias, gremlin_weight_decay_coeffs 11 | from .. import lr_schedulers 12 | 13 | 14 | class FactoredAttention(BaseModel): 15 | """FactoredAttention Layer. 16 | 17 | Args: 18 | num_seqs (int): Number of sequences in MSA. 19 | msa_length (int): Length of MSA. 20 | msa_counts (tensor, optional): Counts of each amino acid in each position of MSA. Used 21 | for initialization. 22 | attention_head_size (int, optional): Dimension of queries and keys for a single head. 23 | num_attention_heads (int, optional): Number of attention heads. 24 | optimizer (str, optional): Choice of optimizer from ["adam", "lamb", "gremlin"]. "gremlin" 25 | specifies GremlinAdam. 26 | learning_rate (float, optional): Learning rate for training the model. 27 | vocab_size (int, optional): Alphabet size of MSA. 28 | true_contacts (tensor, optional): True contacts for family. Used to compute 29 | metrics while training. 30 | l2_coeff (float, optional): Coefficient of L2 regularization for all weights. 31 | use_bias (bool, optional): Whether to include single-site potentials. 32 | pad_idx (int, optional): Integer for padded positions. 33 | lr_scheduler (str, optional): Learning rate schedule to use. Choose from ["constant", "warmup_constant"]. 34 | warmup_steps (int, optional): Number of warmup steps for learning rate schedule. 35 | max_steps (int, optional): Maximum number of training batches before termination. 36 | factorize_vocab (bool, optional): Factorize the (A, A) interaction terms into a product of 37 | (A, d) and (d, A) matrices. True allows for arbitrary value dimension.
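use_adaptive_lr (bool, optional): If True, scale the learning rate by log(num_seqs) / msa_length.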
38 | """ 39 | 40 | def __init__( 41 | self, 42 | num_seqs: int, 43 | msa_length: int, 44 | msa_counts: Optional[torch.Tensor] = None, 45 | attention_head_size: int = 16, 46 | num_attention_heads: int = 32, 47 | optimizer: str = "adam", 48 | learning_rate: float = 1e-3, 49 | use_adaptive_lr: bool = False, 50 | vocab_size: int = 20, 51 | true_contacts: Optional[torch.Tensor] = None, 52 | l2_coeff: float = 1e-2, 53 | use_bias: bool = True, 54 | pad_idx: int = 20, 55 | lr_scheduler: str = "warmup_constant", 56 | warmup_steps: int = 0, 57 | max_steps: int = 10000, 58 | factorize_vocab: bool = False, 59 | ): 60 | super().__init__(num_seqs, msa_length, learning_rate, vocab_size, true_contacts) 61 | self.l2_coeff = l2_coeff 62 | self.use_bias = use_bias 63 | self.pad_idx = pad_idx 64 | self.num_seqs = num_seqs 65 | self.msa_length = msa_length 66 | self.num_attention_heads = num_attention_heads 67 | self.attention_head_size = attention_head_size 68 | self.optimizer = optimizer 69 | self.vocab_size = vocab_size 70 | self.lr_scheduler = lr_scheduler 71 | self.warmup_steps = warmup_steps 72 | self.max_steps = max_steps 73 | self.factorize_vocab = factorize_vocab 74 | self.use_adaptive_lr = use_adaptive_lr 75 | 76 | if self.use_adaptive_lr: 77 | self.learning_rate *= math.log(self.num_seqs) / self.msa_length 78 | 79 | hidden_size = attention_head_size * num_attention_heads 80 | 81 | query = torch.empty(msa_length, num_attention_heads, attention_head_size) 82 | nn.init.xavier_uniform_(query) 83 | self.query = nn.Parameter(query, requires_grad=True) 84 | 85 | key = torch.empty(msa_length, num_attention_heads, attention_head_size) 86 | nn.init.xavier_uniform_(key) 87 | self.key = nn.Parameter(key, requires_grad=True) 88 | 89 | if self.factorize_vocab: 90 | value = torch.empty(num_attention_heads, vocab_size, attention_head_size) 91 | nn.init.xavier_uniform_(value) 92 | self.value = nn.Parameter(value, requires_grad=True) 93 | 94 | output = torch.empty(num_attention_heads, attention_head_size, vocab_size) 95 | nn.init.xavier_uniform_(output) 96 | self.output = nn.Parameter(output, requires_grad=True) 97 | else: 98 | value = torch.empty(num_attention_heads, vocab_size, vocab_size) 99 | nn.init.xavier_uniform_(value) 100 | self.value = nn.Parameter(value, requires_grad=True) 101 | 102 | if self.use_bias: 103 | if msa_counts is not None: 104 | bias = init_potts_bias(msa_counts, l2_coeff, num_seqs) 105 | else: 106 | bias = torch.zeros(msa_length, vocab_size) 107 | 108 | self.bias = nn.Parameter(bias, True) 109 | 110 | self.register_buffer("diag_mask", torch.eye(msa_length) * -10000) 111 | self.register_buffer("one_hot", torch.eye(vocab_size + 1, vocab_size)) 112 | 113 | # self.save_hyperparameters() 114 | 115 | def maybe_onehot_inputs(self, src_tokens): 116 | """Onehots src_tokens if necessary otherwise uses original tokens""" 117 | if src_tokens.dtype == torch.long: 118 | return self.one_hot[src_tokens] 119 | else: 120 | return src_tokens 121 | 122 | def forward(self, src_tokens, targets=None, src_lengths=None): 123 | inputs = self.maybe_onehot_inputs(src_tokens) 124 | mrf_weight = self.compute_mrf_weight() 125 | logits = torch.tensordot(inputs, mrf_weight, 2) 126 | 127 | if self.use_bias: 128 | logits = logits + self.bias 129 | 130 | outputs = (logits, mrf_weight.norm(dim=(1, 3))) 131 | if targets is not None: 132 | loss = self.loss(logits, targets, mrf_weight) 133 | outputs = (loss,) + outputs 134 | return outputs 135 | 136 | def configure_optimizers(self): 137 | if self.optimizer == "adam": 138 | 
optimizer = torch.optim.AdamW( 139 | self.parameters(), lr=self.learning_rate, weight_decay=0.0 140 | ) 141 | elif self.optimizer == "lamb": 142 | from apex.optimizers import FusedLAMB 143 | optimizer = FusedLAMB( 144 | self.parameters(), 145 | lr=self.learning_rate, 146 | weight_decay=0.0, 147 | ) 148 | elif self.optimizer == "gremlin": 149 | from ..optim import GremlinAdam 150 | 151 | optimizer = GremlinAdam( 152 | [{"params": self.parameters(), "gremlin": True}], 153 | lr=self.learning_rate, 154 | ) 155 | else: 156 | raise ValueError(f"Unrecognized optimizer {self.optimizer}") 157 | 158 | lr_scheduler = lr_schedulers.get(self.lr_scheduler)( 159 | optimizer, self.warmup_steps, self.trainer.max_steps 160 | ) 161 | scheduler_dict = { 162 | "scheduler": lr_scheduler, 163 | "interval": "step", 164 | } 165 | return [optimizer], [scheduler_dict] 166 | 167 | def compute_regularization(self, targets, mrf_weight: torch.Tensor): 168 | """Compute the L2 regularization term, scaled by the number of unpadded targets.""" 169 | batch_size = targets.size(0) 170 | 171 | weight_reg_coeff, bias_reg_coeff = gremlin_weight_decay_coeffs( 172 | batch_size, self.msa_length, self.l2_coeff, self.vocab_size 173 | ) 174 | 175 | sample_size = (targets != self.pad_idx).sum() 176 | # After multiplying by sample_size, comes to lambda * L * A / 2 177 | reg = weight_reg_coeff * mrf_weight.pow(2).sum() 178 | if self.use_bias: 179 | # After multiplying by sample_size, comes to lambda 180 | reg += bias_reg_coeff * self.bias.pow(2).sum() 181 | 182 | return reg * sample_size 183 | 184 | def loss(self, logits, targets, mrf_weight: torch.Tensor): 185 | """Compute GREMLIN loss w/ L2 regularization.""" 186 | loss = nn.CrossEntropyLoss(ignore_index=self.pad_idx, reduction="sum")( 187 | logits.view(-1, self.vocab_size), targets.view(-1) 188 | ) 189 | loss *= self.num_seqs / logits.size(0) 190 | loss += self.compute_regularization(targets, mrf_weight) 191 | return loss 192 | 193 | def compute_mrf_weight(self): 194 | attention = torch.einsum("ihd,jhd->hij", self.query, self.key) 195 | attention = attention / math.sqrt(self.attention_head_size) 196 | attention = attention + self.diag_mask 197 | attention = attention.softmax(-1) # H x L x L 198 | 199 | if self.factorize_vocab: 200 | embed = torch.einsum("had,hdb->hab", self.value, self.output) # H x A x A 201 | else: 202 | embed = self.value 203 | 204 | W = torch.einsum("hij,hab->iajb", attention, embed) # L x A x L x A 205 | W = symmetrize_potts(W) 206 | return W 207 | 208 | @torch.no_grad() 209 | def get_contacts(self, mrf_weight: Optional[torch.Tensor] = None): 210 | """Extracts contacts by taking norms of the MRF coupling blocks.""" 211 | if mrf_weight is None: 212 | mrf_weight = self.compute_mrf_weight() 213 | return mrf_weight.norm(dim=(1, 3)) 214 | 215 | @classmethod 216 | def from_args( 217 | cls, 218 | args: Namespace, 219 | num_seqs: int, 220 | msa_length: int, 221 | msa_counts: Optional[torch.Tensor] = None, 222 | vocab_size: int = 20, 223 | pad_idx: int = 20, 224 | true_contacts: Optional[torch.Tensor] = None, 225 | ) -> "FactoredAttention": 226 | return cls( 227 | num_seqs=num_seqs, 228 | msa_length=msa_length, 229 | msa_counts=msa_counts, 230 | attention_head_size=args.attention_head_size, 231 | num_attention_heads=args.num_attention_heads, 232 | optimizer=args.optimizer, 233 | learning_rate=args.learning_rate, 234 | use_adaptive_lr=args.use_adaptive_lr, 235 | vocab_size=vocab_size, 236 | true_contacts=true_contacts, 237 | l2_coeff=args.l2_coeff, 238 | use_bias=args.use_bias, 239 | pad_idx=pad_idx, 240 |
lr_scheduler=args.lr_scheduler, 241 | warmup_steps=args.warmup_steps, 242 | factorize_vocab=args.factorize_vocab, 243 | ) 244 | 245 | @staticmethod 246 | def add_args(parser: ArgumentParser) -> ArgumentParser: 247 | parser.add_argument( 248 | "--learning_rate", 249 | type=float, 250 | default=1e-3, 251 | help="Learning rate for training.", 252 | ) 253 | parser.add_argument( 254 | "--use_adaptive_lr", 255 | action="store_true", 256 | help="Whether to rescale the learning rate by log(num_seqs) / msa_length.", 257 | ) 258 | parser.add_argument( 259 | "--l2_coeff", 260 | type=float, 261 | default=1e-2, 262 | help="L2 regularization coefficient.", 263 | ) 264 | parser.add_argument( 265 | "--use_bias", action="store_true", help="Use a bias when training GREMLIN." 266 | ) 267 | parser.add_argument( 268 | "--no_bias", 269 | action="store_false", 270 | help="Disable the bias when training GREMLIN.", 271 | dest="use_bias", 272 | ) 273 | parser.add_argument( 274 | "--num_attention_heads", 275 | type=int, 276 | default=32, 277 | help="Number of attention heads.", 278 | ) 279 | parser.add_argument( 280 | "--attention_head_size", 281 | type=int, 282 | default=16, 283 | help="Dims in each attention head.", 284 | ) 285 | parser.add_argument( 286 | "--optimizer", 287 | choices=["adam", "lamb", "gremlin"], 288 | default="adam", 289 | help="Which optimizer to use.", 290 | ) 291 | parser.add_argument( 292 | "--lr_scheduler", 293 | choices=lr_schedulers.LR_SCHEDULERS.keys(), 294 | default="warmup_constant", 295 | help="Learning rate scheduler to use.", 296 | ) 297 | parser.add_argument( 298 | "--warmup_steps", 299 | type=int, 300 | default=0, 301 | help="How many warmup steps to use when using a warmup schedule.", 302 | ) 303 | parser.add_argument( 304 | "--factorize_vocab", 305 | action="store_true", 306 | help="Whether to factorize the vocab embedding.", 307 | ) 308 | return parser 309 | --------------------------------------------------------------------------------
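Taken together, the models expose a small, uniform interface: construct with the MSA dimensions, call the model on a batch of integer-encoded sequences, and read contact predictions from get_contacts. The following is a minimal sketch, not part of the repository itself: it assumes the package is importable and uses toy dimensions with random tokens in place of a real MSA (which would normally come from mogwai.data_loading).

import torch
from mogwai.models import FactoredAttention

# Toy dimensions; a real MSA would supply num_seqs, msa_length, and msa_counts.
num_seqs, msa_length, vocab_size = 64, 10, 20
model = FactoredAttention(num_seqs=num_seqs, msa_length=msa_length, vocab_size=vocab_size)

# A batch of 8 integer-encoded sequences; forward one-hots them internally.
tokens = torch.randint(0, vocab_size, (8, msa_length))
loss, logits, norms = model(tokens, targets=tokens)
print(loss.item(), logits.shape)  # scalar GREMLIN loss, logits of shape (8, L, A)

# Contact scores are Frobenius norms of the (L, A, L, A) coupling blocks.
print(model.get_contacts().shape)  # (L, L)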