├── LICENSE ├── README.md ├── SECURITY.md ├── azure-pipelines.yml ├── examples ├── gb1_a60fb_unrelaxed_rank_1_model_5.pdb ├── gb1_a60fb_unrelaxed_rank_1_model_5.pdb.gz ├── gb1s.csv └── some_proteins.fasta ├── requirements.txt ├── scripts ├── extract.py └── extract_mif.py ├── sequence_models ├── __init__.py ├── aaindex.py ├── collaters.py ├── constants.py ├── convolutional.py ├── datasets.py ├── esm.py ├── flip_utils.py ├── gnn.py ├── gvp.py ├── layers.py ├── losses.py ├── metrics.py ├── mixup.py ├── pdb_utils.py ├── pretrained.py ├── samplers.py ├── structure.py ├── trRosetta.py ├── trRosetta_utils.py ├── utils.py └── vae.py ├── setup.py └── tests ├── conv_test.py ├── data_test.py ├── graphmodel_test ├── T1001.a3m ├── T1001.npz ├── T1001_loader.py └── graphmodel_decoder_test ├── loss_test.py ├── pdb_utils_test.py └── vae_test.py /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022 2 | Microsoft [All rights reserved]. 3 | 4 | 1. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | THIS SOFTWARE IS PROVIDED BY [Name of Organization] “AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL [Name of Organisation] BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | PyTorch modules and utilities for modeling biological sequence data. 2 | 3 | Here we demonstrate several tools that we hope will help with modeling biological sequences. 4 | 5 | ## Installation 6 | 7 | ``` 8 | $ pip install sequence-models 9 | $ pip install git+https://github.com/microsoft/protein-sequence-models.git # bleeding edge, current repo main branch 10 | 11 | ``` 12 | 13 | ## Loading pretrained models 14 | 15 | Models require PyTorch. We tested on `v1.9.0`, `v1.11.0`, and `v1.12`. If you installed into a clean conda environment, you may also need to install pandas, scipy, and wget. 16 | 17 | To load a model: 18 | 19 | ``` 20 | from sequence_models.pretrained import load_model_and_alphabet 21 | 22 | model, collater = load_model_and_alphabet('carp_640M') 23 | ``` 24 | 25 | Available models are 26 | 27 | - `carp_600k` 28 | - `carp_38M` 29 | - `carp_76M` 30 | - `carp_640M` 31 | - `mif` 32 | - `mifst` 33 | - `bigcarp_esm1bfinetune` 34 | - `bigcarp_esm1bfrozen` 35 | - `bigcarp_random` 36 | 37 | 38 | ## Convolutional autoencoding representations of proteins (CARP) 39 | 40 | We make available pretrained CNN protein sequence masked language models of various sizes. 
All of these have a ByteNet encoder architecture and are pretrained on the March 2020 release of UniRef50 using the same masked language modeling task as in BERT and ESM-1b. 41 | 42 | CARP is described in this [preprint](https://doi.org/10.1101/2022.05.19.492714). 43 | 44 | You can also download the weights manually from [Zenodo](https://doi.org/10.5281/zenodo.6368483). 45 | 46 | To encode a batch of sequences: 47 | 48 | ``` 49 | seqs = [['MDREQ'], ['MGTRRLLP']] 50 | x = collater(seqs)[0] # (n, max_len) 51 | rep = model(x) # (n, max_len, d_model) 52 | ``` 53 | 54 | CARP also supports computing representations from arbitrary layers and the final logits. 55 | 56 | ``` 57 | rep = model(x, repr_layers=[0, 2, 32], logits=True) 58 | ``` 59 | 60 | ### Compute embeddings in bulk from FASTA 61 | 62 | We provide a script that efficiently extracts embeddings in bulk from a FASTA file. A CUDA device is optional and will be auto-detected. The following command extracts embeddings from layers 0, 32, and 33, plus the logits, for a FASTA file using the `carp_640M` model: 63 | 64 | ``` 65 | $ python scripts/extract.py carp_640M examples/some_proteins.fasta \ 66 | examples/results/some_proteins_emb_carp_640M/ \ 67 | --repr_layers 0 32 33 logits --include mean per_tok 68 | ``` 69 | Directory `examples/results/some_proteins_emb_carp_640M/` now contains one `.pt` file per extracted embedding; use `torch.load()` to load them. `scripts/extract.py` has flags that determine which `.pt` files are written: 70 | 71 | `--repr_layers` (default: final only) selects which layers to include embeddings from. `0` is the input embedding. `logits` is the per-token logits. 72 | 73 | `--include` specifies what embeddings to save. You can use the following: 74 | 75 | - `per_tok` includes the full sequence, with an embedding per amino acid (seq_len x hidden_dim). 76 | - `mean` includes the embeddings averaged over the full sequence, per layer (only valid for representations). 
77 | - `logp` computes the average log probability per sequence and stores it in a CSV file (only valid for logits). 78 | 79 | `scripts/extract.py` also has `--batchsize` and `--device` flags. For example, to use GPU 2 on a multi-GPU machine, pass `--device cuda:2`. The default is a batch size of 1, running on `cuda:0` if CUDA is detected and on `cpu` otherwise. 80 | 81 | ## Masked Inverse Folding (MIF) and Masked Inverse Folding with Sequence Transfer (MIF-ST) 82 | 83 | We make available pretrained masked inverse folding models with and without sequence pretraining transfer from CARP-640M. 84 | 85 | MIF and MIF-ST are described in this [preprint](https://doi.org/10.1101/2022.05.25.493516). 86 | 87 | You can also download the weights manually from [Zenodo](https://zenodo.org/record/6573779#.YqjXT-zMI-Q). 88 | 89 | To encode a sequence with its structure: 90 | 91 | ``` 92 | from sequence_models.pdb_utils import parse_PDB, process_coords 93 | coords, wt, _ = parse_PDB('examples/gb1_a60fb_unrelaxed_rank_1_model_5.pdb') 94 | coords = { 95 | 'N': coords[:, 0], 96 | 'CA': coords[:, 1], 97 | 'C': coords[:, 2] 98 | } 99 | dist, omega, theta, phi = process_coords(coords) 100 | batch = [[wt, torch.tensor(dist, dtype=torch.float), 101 | torch.tensor(omega, dtype=torch.float), 102 | torch.tensor(theta, dtype=torch.float), torch.tensor(phi, dtype=torch.float)]] 103 | src, nodes, edges, connections, edge_mask = collater(batch) 104 | # can use result='repr' or result='logits'. Default is 'repr'. 105 | rep = model(src, nodes, edges, connections, edge_mask) 106 | ``` 107 | 108 | ### Compute embeddings in bulk from csv 109 | 110 | We provide a script that efficiently extracts embeddings in bulk from a CSV file. A CUDA device is optional and will be auto-detected. 
The following command extracts the final-layer embedding for each entry of a CSV file using the `mifst` model: 111 | 112 | ``` 113 | $ python scripts/extract_mif.py mifst examples/gb1s.csv \ 114 | examples/ \ 115 | examples/results/some_proteins_mifst/ \ 116 | repr --include mean per_tok 117 | ``` 118 | Directory `examples/results/some_proteins_mifst/` now contains one `.pt` file per extracted embedding; use `torch.load()` to load them. `scripts/extract_mif.py` has arguments that determine which `.pt` files are written: 119 | 120 | The syntax is: 121 | ``` 122 | $ python scripts/extract_mif.py <model> <csv_fpath> <pdb_dir> <out_dir> <result> --include <include> 123 | ``` 124 | 125 | The input CSV should have columns for `name`, `sequence`, and `pdb`. The script looks in `pdb_dir` for the filenames in the `pdb` column. 126 | 127 | The options for `result` are `repr` or `logits`. 128 | 129 | `--include` specifies what embeddings to save. You can use the following: 130 | 131 | - `per_tok` includes the full sequence, with an embedding per amino acid (seq_len x hidden_dim). 132 | - `mean` includes the embeddings averaged over the full sequence, per layer (only valid for representations). 133 | - `logp` computes the average log probability per sequence and stores it in a CSV file (only valid for logits). 134 | 135 | 136 | `scripts/extract_mif.py` also has a `--device` flag. For example, to use GPU 2 on a multi-GPU machine, pass `--device cuda:2`. The default is `cuda:0` if CUDA is detected and `cpu` otherwise. 137 | 138 | 139 | ## Biosynthetic gene cluster CARP (BiGCARP) 140 | 141 | We make available pretrained CNN Pfam domain masked language models of BGCs. All of these have a ByteNet encoder architecture and are pretrained on antiSMASH using the same masked language modeling task as in BERT and ESM-1b. 142 | 143 | BiGCARP is described in this [preprint](https://doi.org/10.1101/2022.07.22.500861). Training code is available [here](https://github.com/microsoft/protein-sequence-models). 
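As a rough illustration of the BERT-style masked language modeling task mentioned above, the sketch below corrupts a semicolon-separated Pfam domain string. It assumes the standard BERT recipe (select about 15% of positions; of those, replace 80% with a mask token, 10% with a random token, and leave 10% unchanged). The `<mask>` token name, the `mask_tokens` helper, and splitting on `;` are illustrative assumptions, not this library's API or its exact training procedure.

```python
import random

MASK = "<mask>"  # hypothetical mask token name, not taken from the library


def mask_tokens(tokens, vocab, rate=0.15, seed=0):
    """Return (corrupted, targets) following the BERT 80/10/10 corruption recipe.

    targets[i] holds the original token at positions selected for prediction
    and None everywhere else.
    """
    rng = random.Random(seed)
    corrupted = list(tokens)
    targets = [None] * len(tokens)
    for i, tok in enumerate(tokens):
        if rng.random() >= rate:
            continue  # position not selected for prediction
        targets[i] = tok
        roll = rng.random()
        if roll < 0.8:
            corrupted[i] = MASK  # replace with the mask token
        elif roll < 0.9:
            corrupted[i] = rng.choice(vocab)  # replace with a random token
        # otherwise keep the original token in place
    return corrupted, targets


# Assumed tokenization: one token per semicolon-separated field.
bgc = "#;PF07690;PF06609;PF00083;PF00975;PF12697;PF00550;PF14765"
tokens = bgc.split(";")
vocab = sorted(set(tokens))
# A high rate is used here only so that a short example selects some positions.
corrupted, targets = mask_tokens(tokens, vocab, rate=0.5, seed=1)
```

A model trained on this task would be scored only at positions where `targets` is not `None`.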
144 | 145 | You can also download the weights and datasets manually from [Zenodo](https://doi.org/10.5281/zenodo.6857704). 146 | 147 | To encode a batch of sequences: 148 | 149 | ``` 150 | bgc = [['#;PF07690;PF06609;PF00083;PF00975;PF12697;PF00550;PF14765'], 151 | ['t3pks;PF07690;PF06609;PF00083;PF00975;PF12697;PF00550;PF14765;PF00698']] 152 | model, collater = load_model_and_alphabet('bigcarp_esm1bfinetune') 153 | x = collater(bgc)[0] 154 | rep = model(x) 155 | ``` 156 | 157 | 158 | ### Sequence Datasets and Dataloaders 159 | In ```samplers.py```, you will find two PyTorch sampler classes: ```SortishSampler```, a sampler that sorts 160 | sequences of similar length into length-defined buckets; and ```ApproxBatchSampler```, a batch sampler that draws sequences 161 | from those buckets until the batch reaches the configured approximate maximum number of tokens or maximum number of tokens squared. 162 | 163 | ``` 164 | import numpy as np 165 | from torch.utils.data import DataLoader 166 | from sequence_models.samplers import SortishSampler, ApproxBatchSampler 167 | ds = dataset # your sequence dataset 168 | 169 | # build dataloaders 170 | len_ds = np.array([len(i[0]) for i in ds]) # lengths of the sequences in the dataset (in order) 171 | bucket_size = 1000 # number of length-defined buckets 172 | max_tokens = 8000 # max number of tokens per batch 173 | max_batch_size = 100 # max number of samples per batch 174 | sortish_sampler = SortishSampler(len_ds, bucket_size) 175 | batch_sampler = ApproxBatchSampler(sortish_sampler, max_tokens, max_batch_size, len_ds) 176 | collater = collater # your collater function 177 | dl = DataLoader(ds, collate_fn=collater, batch_sampler=batch_sampler, num_workers=16) 178 | ``` 179 | 180 | ### Pre-implemented Models 181 | * Struct2SeqDecoder (GNN) 182 | 183 | The ```Struct2SeqDecoder``` model was adapted from 184 | [Ingraham et al.](https://papers.nips.cc/paper/2019/file/f3a4ff4839c56a5f460c88cce3666a2b-Paper.pdf). 
This model uses protein structural information 185 | encoded as a graph whose nodes and edges represent the structural information of each amino acid residue and their 186 | relations to each other, respectively. 187 | 188 | If you already have node features, edge features, connections between nodes, encoded sequences (src), 189 | and an edge mask (edge_mask), you can use the ```Struct2SeqDecoder``` directly, as demonstrated below: 190 | 191 | ``` 192 | from sequence_models.constants import trR_ALPHABET 193 | from sequence_models.gnn import Struct2SeqDecoder 194 | 195 | num_letters = len(trR_ALPHABET) # length of your amino acid alphabet 196 | node_features = 10 # number of node features 197 | edge_features = 11 # number of edge features 198 | hidden_dim = 128 # your choice of hidden layer dimension 199 | num_decoder_layers = 3 # your choice of number of decoder layers to use 200 | dropout = 0.1 # dropout used by decoder layer 201 | use_mpnn = False # if True, use MPNN layer, else use Transformer layer for decoder 202 | direction = 'bidirectional' # direction of information flow/masking: forward, backward or bidirectional 203 | 204 | model = Struct2SeqDecoder(num_letters, node_features, edge_features, hidden_dim, 205 | num_decoder_layers, dropout, use_mpnn, direction) 206 | out = model(nodes, edges, connections, src, edge_mask) 207 | ``` 208 | 209 | If you do not have prepared inputs, but have 2d maps representing the distances between residues (dist) and the dihedral 210 | angles between residues (omega, theta, and phi), you can use our preprocessing functions to generate nodes, edges, and 211 | connections as demonstrated below: 212 | 213 | ``` 214 | from sequence_models.gnn import get_node_features, get_k_neighbors, get_edge_features, \ 215 | get_mask, replace_nan 216 | 217 | # process features 218 | node = get_node_features(omega, theta, phi) # generate nodes 219 | dist = dist.fill_diagonal_(np.nan) # if the diagonal of dist tensor is not already filled with nans, 
it should be, 220 | # to prevent selecting self when getting k nearest residues in the next step 221 | connections = get_k_neighbors(dist, n_connections) # get connections 222 | edge = get_edge_features(dist, omega, theta, phi, connections) # generate edge 223 | edge_mask = get_mask(edge) # get edge mask (for cases where edge features are missing between neighbors) 224 | edge = replace_nan(edge) # replace nans with 0s 225 | node = replace_nan(node) 226 | ``` 227 | 228 | Alternatively, we have also prepared ```StructureCollater```, a collater function 229 | found in ```collaters.py``` that also performs this task: 230 | 231 | ``` 232 | from torch.utils.data import DataLoader 233 | from sequence_models.collaters import StructureCollater 234 | n_connections = 20 # number of connections per amino acid residue 235 | collater = StructureCollater(n_connections=n_connections) 236 | ds = dataset # Dataset must return sequences, dists, omegas, thetas, phis 237 | dl = DataLoader(ds, collate_fn=collater) 238 | ``` 239 | 240 | * ByteNet 241 | 242 | The ```ByteNet``` model was adapted from [Kalchbrenner et al.](https://arxiv.org/abs/1610.10099). ByteNet uses stacked 243 | convolutional encoder and decoder layers to preserve temporal resolution of 244 | sequential data. 245 | 246 | ``` 247 | import torch 248 | from sequence_models.convolutional import ByteNet 249 | from sequence_models.constants import trR_ALPHABET 250 | n_tokens = len(trR_ALPHABET) # number of tokens in token dictionary 251 | d_embedding = 128 # dimension of embedding 252 | d_model = 128 # dimension to use within ByteNet model, //2 every layer 253 | n_layers = 3 # number of layers of ByteNet block 254 | kernel_size = 3 # the kernel width 255 | r = ??? 
# used to calculate dilation factor 256 | padding_idx = trR_ALPHABET.index('-') # location of padding token in ordered alphabet 257 | causal = True # if True, chooses MaskedCausalConv1d() over MaskedConv1d() 258 | dropout = 0.1 259 | 260 | x = torch.randint(0, n_tokens, (32, 128)) # input token indices (n samples, len of seqs) 261 | input_mask = torch.ones(32, 128, 1) # mask (n samples, len of seqs, 1) 262 | model = ByteNet(n_tokens, d_embedding, d_model, n_layers, kernel_size, r, 263 | padding_idx=padding_idx, causal=causal, dropout=dropout) 264 | out = model(x, input_mask) 265 | ``` 266 | 267 | We have also implemented versions of ```ByteNet``` that accept 2d inputs (```ByteNet2d```) 268 | and that act as a language model (```ByteNetLM```): 269 | 270 | ``` 271 | from sequence_models.convolutional import ByteNet2d, ByteNetLM 272 | d_in = 64 # feature dimension of the 2d input 273 | x = torch.randn(32, 128, 128, 64) # input (n samples, len of seqs, len of seqs, feature dimension) 274 | input_mask = torch.ones(32, 128, 128, 1) # (n samples, len of seqs, len of seqs, 1), optional 275 | model = ByteNet2d(d_in, d_model, n_layers, kernel_size, r, dropout=0.0) 276 | out = model(x, input_mask) 277 | 278 | x = torch.randint(0, n_tokens, (32, 128)) # input token indices (n samples, len of seqs) 279 | input_mask = torch.ones(32, 128, 1) # mask (n samples, len of seqs, 1) 280 | model = ByteNetLM(n_tokens, d_embedding, d_model, n_layers, kernel_size, r, 281 | padding_idx=None, causal=False, dropout=0.0) 282 | out = model(x, input_mask) 283 | ``` 284 | 285 | * trRosetta 286 | The ```trRosetta``` model was implemented according to [Yang et al.](https://www.pnas.org/content/117/3/1496). In this model, multiple sequence 287 | alignments (MSAs) are used to predict distances between amino acid residues as well as their dihedral 288 | angles (omega, theta, phi). Predictions are given as bins. The omega, theta, and phi angles are binned into 24, 24, and 12 bins, respectively, 289 | in 15-degree segments, each with one additional no-contact bin. 
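The bin counts above follow from the 15-degree bin width. Here is a minimal sketch of that arithmetic, assuming omega and theta each span 360 degrees and phi spans 180 degrees (these ranges are not stated in this README); the helper names are hypothetical and are not part of `sequence_models`.

```python
BIN_WIDTH_DEG = 15  # bin width stated above


def n_angle_bins(span_deg):
    """Number of 15-degree bins covering an angular range (excluding the no-contact bin)."""
    return span_deg // BIN_WIDTH_DEG


def angle_to_bin(angle_deg, low_deg):
    """Index of the 15-degree bin containing angle_deg, counted from low_deg."""
    return int((angle_deg - low_deg) // BIN_WIDTH_DEG)


omega_bins = n_angle_bins(360)  # assumed range: -180 to 180 degrees
theta_bins = n_angle_bins(360)  # assumed range: -180 to 180 degrees
phi_bins = n_angle_bins(180)    # assumed range: 0 to 180 degrees
```

How the extra no-contact bin is indexed relative to these angle bins depends on the implementation and is not shown here.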
[Yang et al.](https://www.pnas.org/content/117/3/1496) has pretrained five models (model ids: 'a', 'b', 'c', 'd', 'e'). To run 290 | a single model: 291 | 292 | ``` 293 | from sequence_models.trRosetta_utils import trRosettaPreprocessing, parse_a3m 294 | from sequence_models.trRosetta import trRosetta 295 | from sequence_models.constants import trR_ALPHABET 296 | 297 | msas = parse_a3m(path_to_msa) # load in msas in a3m format 298 | alphabet = trR_ALPHABET # load your alphabet order 299 | tr_preprocessing = trRosettaPreprocessing(alphabet) # setup preprocessor for msa 300 | msas_processed = tr_preprocessing.process(msas) 301 | 302 | n2d_layers = 61 # keep at 61 if you want to use pretrained version 303 | model_id = 'a' # choose your pretrained model id 304 | decoder = True # if True, return 2d structure maps, else returns hidden layer 305 | p_dropout = 0.0 306 | model = trRosetta(n2d_layers, model_id, decoder, p_dropout) 307 | out = model(msas_processed) # returns dist_probs, theta_probs, phi_probs, omega_probs 308 | ``` 309 | 310 | To run an ensemble of models: 311 | ``` 312 | from sequence_models.trRosetta_utils import trRosettaPreprocessing, parse_a3m 313 | from sequence_models.trRosetta import trRosetta, trRosettaEnsemble 314 | from sequence_models.constants import trR_ALPHABET 315 | 316 | msas = parse_a3m(path_to_msa) # load in msas in a3m format 317 | alphabet = trR_ALPHABET # load your alphabet order 318 | tr_preprocessing = trRosettaPreprocessing(alphabet) # setup preprocessor for msa 319 | msas_processed = tr_preprocessing.process(msas) 320 | 321 | n2d_layers = 61 # keep at 61 if you want to use pretrained version 322 | model_ids = 'abcde' # choose your pretrained model id 323 | decoder = True # if True, return 2d structure maps, else returns hidden layer 324 | p_dropout = 0.0 325 | base_model = trRosetta 326 | model = trRosettaEnsemble(base_model, n2d_layers, model_ids) 327 | out = model(msas_processed) 328 | ``` 329 | 330 | If you would like to convert 
bin predictions into actual values, use ```probs2value```. 331 | Here is an example of converting distance bin predictions into values: 332 | 333 | ``` 334 | from sequence_models.trRosetta_utils import probs2value 335 | 336 | dist_probs, theta_probs, phi_probs, omega_probs = model(x) # structure predictions (batch, # of bins, len of seq, len of seq) 337 | property = 'dist' # choose between 'dist', 'theta', 'phi', or 'omega' 338 | mask = mask # your 2d mask (batch, len of seq, len of seq) 339 | dist_values = probs2value(dist_probs, property, mask) 340 | 341 | ``` 342 | 343 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). 
If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 
40 | 41 | 42 | -------------------------------------------------------------------------------- /azure-pipelines.yml: -------------------------------------------------------------------------------- 1 | trigger: 2 | - master 3 | pool: 4 | vmImage: 'windows-latest' 5 | 6 | steps: 7 | - task: CredScan@2 8 | inputs: 9 | toolMajorVersion: 'V2' 10 | 11 | - task: Semmle@1 12 | env: 13 | SYSTEM_ACCESSTOKEN: $(System.AccessToken) 14 | inputs: 15 | sourceCodeDirectory: '$(Build.SourcesDirectory)' 16 | language: 'python' 17 | querySuite: 'Recommended' 18 | timeout: '1800' 19 | ram: '16384' 20 | addProjectDirToScanningExclusionList: true 21 | 22 | - task: ComponentGovernanceComponentDetection@0 23 | inputs: 24 | scanType: 'Register' 25 | verbosity: 'Verbose' 26 | alertWarningLevel: 'High' 27 | -------------------------------------------------------------------------------- /examples/gb1_a60fb_unrelaxed_rank_1_model_5.pdb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/protein-sequence-models/af695772c4a1c056d930c95ec7e6428aa042f5cd/examples/gb1_a60fb_unrelaxed_rank_1_model_5.pdb.gz -------------------------------------------------------------------------------- /examples/gb1s.csv: -------------------------------------------------------------------------------- 1 | name,sequence,pdb 2 | wt,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE,gb1_a60fb_unrelaxed_rank_1_model_5.pdb 3 | Q1P,MPYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE,gb1_a60fb_unrelaxed_rank_1_model_5.pdb 4 | Y2R,MQRKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE,gb1_a60fb_unrelaxed_rank_1_model_5.pdb 5 | -------------------------------------------------------------------------------- /examples/some_proteins.fasta: -------------------------------------------------------------------------------- 1 | >UniRef50_A0A1E3NP16 2 | 
AVIYYRFRSQKPDHIATIKFDGTGLTVFELKRDIILANNLLHSTDVDIVLYSTEDIQDTKSWGYQNGGSSSAGERELDDDNEVVPRSTTVLVRRTMTPKKNKGNVQRYVAGKPRLQVSGTNSVNKSISLGNNVGGTMNFGDAATNGDEDDMIKKMFSVQDEQWSQQQDVMATATRVDNFRTNVNEPVPEYYICYKCGEKGKHHIKNCPKNNDPNWEGVRVRKTTGIPKSHLKAIENPEDTIRDSNSSGNTTYMVNDEGKYVVAVADTKAWEKYQKTKKGESGGYLNGDVDVDDGELKDPETGKLWKSPVRIPCCNKIFSRKIIEDKLIDSDFTCPSCGKEQIYLDTLVADEELQAKVDEYVKNLSENKNNDGNSPKRRQVNPAGATANTSQLPQIPMMPMPPINMQMPPMNIGMPPFMPFMPMPGMNP 3 | >UniRef50_UPI000836A30F 4 | MTLRTLLALSILALAAAATVQARPGAPPCSPLGLKYQPGACEKWKREHPDVNPDGVVQTVTIVNNSSSVLGGYVTFWHANNEHTDVDLPGVKPGETWTANGSWTVGQAPYYLLYSSFQDSYGDPVTFYAVPISKYPALAKKLPPEPDNCQSNHFRMVFGDGPQYVYEQHSGAVVTGGQTDKNTLCQILGCPTGGSTAGGLNTQNRTTSRSAPQPVYFRSCDP 5 | >UniRef50_A0A2G8L8Y3 6 | MAQKRYEENLSPYAELFRSKLIEDFHLLESFDEHGKSDAPKYYSKDFEDPARQDKMMLENPHGLVKFQVYSRKEPGEHFMLVLILSNSIALGLQAEVSESDDPKFAGLKLALDIFDYCSLFLFMVEIILKWIDNFWSFWSDNWNIFDFAVTVGSFVPEIINFFAGDIGGSMVRVIVRNLRVFRILRSLKMVSRFRQVRLIALAIGKAFSAITFIMLLLFTFLYIFAITGIIFFDTYTRSERQDLKYKDSFRSLPRAMITLFQLFTLDQWYKLLNDMWKVMDSMIPLGYIILWICIGSFIFRNVFVGIMVNNFQSIRNDLFEEVKEQEAARQIIQDTEKFNEELSRQEKKLNANRRGTLYQSPTVQPPKPNQPQPSQLAGLDNSETDEQSVSQEDESNTDGQTSLSGTDSYDLLGESSDSLFRRSSDGMIDKDKLSTNWEKTVHDNLTLLTSTPSETLWPRDTLFRYFQLMESLMENLQERQDLQDLAYHSLLQIFDSFDTSA 7 | >UniRef50_A0A1D5ZRM3 8 | MPCVAHECHPRLPAANHCRSLSCLGTPAAGWSSGDDDREEDELDTKQVILNEMRNREMRKRSSRCSVDSPTLSGAFAWSFTPLHPRSSIEKVSCTEEEKEAASDSDNESEAFFSVKSFFTRSTSRAATVASSTDMDPPATWEGLRGCEGWPFGLCP 9 | >UniRef50_UPI0003108055 10 | MPADAREYLESKHATRRFDRPAEVAGVVAFLLSDDTSFVIGAGYLVDGGYTALRAARGRLGPRRQAAPAVKLLPNTDSPR 11 | >UniRef50_A0A223SCH7 12 | MGPPRWWKGITGLAAVVHRADPEDKADLYAKMGLYLEYHPETRIVEARIKPRLHDVCESKVSEGGLEPPCP 13 | >UniRef50_A0A090SUK6 14 | MKTPEDRVFRATDEYSDFVMACRYKGNEREFVIASHDKLNEAQVETLTSYLSGEWFKKTYITGIMNDSDGVLSQHEEYGDEVFCQPLDELRVDRYIMMV 15 | >UniRef50_V4AGU2 16 | 
MHLGSIYLMVVLLIYFAYTDDRKERENVDVINPEENLVQDDEQYVGDSTENIEKSGSEEEEEDKEAIEEEEDEEEELNYRYIPEPAQDIVDNGKKYIQVHCSFKQDESLIRHPSNCSRYFVCSYGVVEEMPVCDDGEVFSIQVSECVKKGSENDDCDKLPFDSPPEITRGTQPSLIWHPRHKSPRSQFRQPTTLQMKVPHIELESFTCSATGKVLSHHSENCAWYYNCSAHPDAVMQTFYSGFIMECPYPQLFSTETKQCEDFEDVKCGDRYEPKSPCDYRANHCHETSHCIPCWVRYASCLELPDGLNPWSELEWKPFFVECYKERTVFQGVCDKSAVFSPLTRACETPYSIPRQHGGWRPVCDGRRDGIYADEYGRCDIYYVCKGYIFTGFFRCEKGEMFNPVISICQKPEAVPYPCGDLEMPNICESSLNGYHLDMFGRCTHYFECKDQQLEGISMCPSGIFNPELQICESSRDQPKPCGNLTNLCTHKNDGFHSDENDCTKVFQCERGLTMTSYDCSGSVRTECDVCNTPTECNDKPNGLYPNLKEGVGYYYDCVRSQIQNHYKCDKEKGGPIFNPVKQRCFYPEDLCKEVFSLKIAW 17 | >UniRef50_R5PKX9 18 | MQGGEQDVFEHRQVREKVIALKNHADAAAQGAAEFERLSFEQDVAALDGFETDQAAQKRGFAAARRPENHRDFFVVKREVDAVENHSVAELLYETARFQNNVIGHFYAFHFFSRALAASDTGQHARK 19 | >UniRef50_UPI0003674933 20 | MKKKEDILNLEHTLLPWMGRTMKVLDYFIGDFLNLKGIELTKVQWILLKKLNEQNGQPQQNLAFLTNRDKASLARLITTMEKKKLVERIPSKIDGRINHIFITKHGCEILQKSAPVIEKVVGCIQEGISPEEIETVIKVMQKVNNNINRASN 21 | >UniRef50_E9K9Y4 22 | MADNVLMAYHIVHDPDERAKHVLNTKKLYKWRITEKTKGTPVVGNVALVQTQFAKRTPVMIYATKEVANDLSDLQPVKVFTNNRDQETVNQTFDDLMR 23 | >UniRef50_A0A2C9LWN7 24 | 
MGKITGLLFFLFTSVRVTPSMKNTVNDIYRVRKDVLTKYRNTEEYDLGQRKGPLKQMDTRETRTPYGHFNSDTAFISPKNILKATRPLAFHLGNLKNQTSPCSTWNLTVDCSYKSLDHIESSWFPSNTTVLLLNNNKLVTLHNETFAQLTNLTRLDLSSNDIRRIDAGAFQGLHNLQELNLHMHCCNFTDHYSLESVFAPLRNLRILNAMHNSDVGVLTYSYTFLTRLPLLQSLSIDFDLDTLYCGPEFNDLKNLTFLQFSGQVMYIDDRSFQNVAQLKNLSMDHLSNINNISHNAFKPLSNLKVLTMYHVLLYVQEILSLLEPFQGRNMTEITLDTTTRTLTQVNPTRNGILTNHDTKYLMNICLESFTLIDNRIFYIKPDAVQNIYTWKKCLMHLYIASNPIQGNNFALIRLFTLDNLKSFTFINMFRACHEFQPFPQSSPPATRNVASSSISSQNNHYQKDTTSNQQQMIRHSPFSPYLDMIDENPYQLNIPNYIFISPSLQYMNFQRLVMSQSFEYHFILVGAQNVTSLDISDSGFYRFNGLMEGVSAIKTLIISGNDVSVLSVSFFDTFVSLENFAISSCKLDRDFISLNSRRIFQNHTRLQELDISSNSLNYLSQNTFSYNNRLMWLNMSGNQFKDIPFDLTNTPELQFLDIRFNSLTTIDETTAQQMDHLVTKSGKLEILLEGNVLSCSCSDLSFMRWMRMTLVTFDQNGNFTCMNTDGERKYTLDYSNLDSLWRECWGSFFLYFALIMLCLYCIGVFAVFITMRNKNFIVSFFLQLFGGFKLHSRRDYPVGVYIGYSDKDYQFPCKELRSFIESSLKLKTFLIDRDLIASVDKASGIIEALNASWRILLVCSKSFLKEDDWSMFTMRSAIYTQTPANPARVVVLVHKDCLPLLPPALLSSVNDENICAVSEWAMNYEMMQMLTTRLH 25 | >UniRef50_Q9REE6 26 | MSLRGRELLTSEERLELVRIPEDISEQELGRNFTLSNFDLELIKNRRRDYNRLGFAVQLCVLRFPGWSLNDAEPIPKKVLQHLARQLHVDPDCFSLYSSREA 27 | >UniRef50_A0A226D4M8 28 | MQKVINFPIWRRYYFSECGNCNDFRPDGVKKVHPPQEKSRTMDDEILAAPEVTIPFEPSDPSEVVVNLISSEEEDDDDVIQIVEEKSVDKAERQRRRQKKKDLWAARKLQRNKGQNVPLQTAWQRGPRPQETSSFPTPPQQSGGAQQKPLSPILISTAGSPNTSGAPTPANVSQQPTPTPSFVHTTASTSTHPEISLNINSDLALLIRLGPDGRPILTRVENVEQNNTSTSTPTTTRKLPPAPPPPKISFDTQTGESLLNGELITRPIIDITTDSPPSITHAATVSPQTSRTSGPPTLSPISPPPRTTQSNPAHPPPRYEPRKSRHPPRDPLASSSSSSSSSPSPPPTSRARHTSSANIPPPLEPLFLNTTQLIHLIKTCRKCEKSFPTRCDGVIHQKKEHNRKHCPVCFLTLSRHGNTYKDHLNMYHALEGDKEMVVCPFCAVEHSFDGLYNHIGRSHLIPVESKGESEHEVIFSVAPSPQSNGHQTRSVTQGNTPPKNDTNSPPNLEKRRPGPASKTRKTTNDPVPSTSRTGLICHKYVPDVETPPSKAGRNFESRAGRNFESRNVYPPATNRLNPPRNKSPPRNKNLPRNKSPPRNKSLPRNKTPPPSSSRSSSSRSVSNNLRRKNPTPPPPTQPPPKKVAKPDEAGINEKIQAAIKAVNARVHIERSVHHPNNRSDREIPSTSRTVTSRHKVPTSKTSGNTSVRKDPSPPPPLQTPPKTTNLELSKVQKARIVEDVRSIVRDVRITRVLDDGEVPSTSRAVEEEKKKEEKKNETRARLPRSSRVGERSSGYFQMAEGIDFSPENPTPRSKDLKAIHISKLNLIYQQLKCYPDTSVIANVAKECGVEIPVVAKWFTKKHMEYCQKTQQRKRKRKPPELR 29 | 
>UniRef50_UPI000B82D8F0 30 | MGGLHLIELRNVNIEFDKKLIEDGTIKIYDGKITAIIGESGSGKTSLLYLLGLISSNHRYLYSFDDVTLDLSNDFEMSRIRKQKIGYIFQDNNLVENLTIFENIRLSATIAGINITDKEIKSYLEFVELGYIDSNHYPRKLSGGERQRVAIACALAKQPELILADEPTSALDTVNSEIIMGIFKKIAHKDKKKIVIATHNDRIYNEADVIYEIKNNKIQLVKGESSNESSKKEPEDYDNNVKLTPRFYFDYAIKTSRKGRFAKNLMIVLSAIAIAFSSVMYNFGDTFVKEQEKLMDAISDKEIFVVNMTAPLNTILDIDENLSIQDKDAELLRNISYVDTIYPYFEFRSIGYPLINETEASEGYIVVSKGQKEEKYTFAESKDNPYDKYVIIPYYPEQNLERRLKEKLSDESSDKVYISSQLAQLLGIENLKESVSLRVLTYVPIAQHETQMTVRPEGIVYEIDIDLSKVVELDLKIEGILDESVRNRYSNSGNNAIYVPYHKMQMILTQTQNSATIDTNLEYIEWRPSAFVVFAKSYNDVGLVIERVSSINPNFRAVSEYQDIESMNAIVKNTREIGLVIVIVILIIIFLLMSIIHMNHILDRKYEISLLKANGLTKIELTKLVSVESLRHVFLVSLISSVISLVVTKVMNLLFEEIA -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/protein-sequence-models/af695772c4a1c056d930c95ec7e6428aa042f5cd/requirements.txt -------------------------------------------------------------------------------- /scripts/extract.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch.cuda 4 | from tqdm import tqdm 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from sequence_models.utils import parse_fasta 9 | from sequence_models.pretrained import load_model_and_alphabet 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('model') 13 | parser.add_argument('in_fpath') 14 | parser.add_argument('out_dir') 15 | parser.add_argument('--repr_layers', nargs='*', default=[-1]) 16 | parser.add_argument('--include', nargs='*', default=['mean']) 17 | parser.add_argument('--device', default=None) 18 | parser.add_argument('--batchsize', default=1, type=int) 19 | args = parser.parse_args() 20 | 21 | # load model 22 | print('Loading model...') 23 | model, collater = load_model_and_alphabet(args.model) 24 | # detect device and move model 
to it 25 | if args.device is None: 26 | if torch.cuda.is_available(): 27 | device = torch.device('cuda:0') 28 | else: 29 | device = torch.device('cpu') 30 | else: 31 | device = torch.device(args.device) 32 | model = model.to(device) 33 | # load data 34 | print('Loading data...') 35 | seqs, names = parse_fasta(args.in_fpath, return_names=True) 36 | ells = [len(s) for s in seqs] 37 | seqs = [[s] for s in seqs] 38 | n_total = len(seqs) 39 | repr_layers, logits = [], False # logits stays False unless 'logits' is requested 40 | for r in args.repr_layers: 41 | if r == 'logits': 42 | logits = True 43 | else: 44 | repr_layers.append(int(r)) 45 | if 'logp' in args.include: 46 | logps = np.empty(len(seqs)) 47 | with tqdm(total=n_total) as pbar: 48 | for i in range(0, n_total, args.batchsize): 49 | start = i 50 | end = start + args.batchsize 51 | bs = seqs[start:end] 52 | bn = names[start:end] 53 | bl = ells[start:end] 54 | # tokenize 55 | x = collater(bs)[0].to(device) 56 | # pass through the model 57 | results = model(x, repr_layers=repr_layers, logits=logits) 58 | if 'representations' in results: 59 | for layer, rep in results['representations'].items(): 60 | for r, ell, name in zip(rep, bl, bn): 61 | r = r[:ell] 62 | if 'mean' in args.include: 63 | torch.save(r.mean(dim=0).detach().cpu(), 64 | args.out_dir + '_'.join([name, args.model, str(layer), 'mean']) + '.pt') 65 | if 'per_tok' in args.include: 66 | torch.save(r.detach().cpu(), 67 | args.out_dir + '_'.join([name, args.model, str(layer), 'per_tok']) + '.pt') 68 | if logits: 69 | rep = results['logits'] 70 | for j, (r, ell, name, src) in enumerate(zip(rep, bl, bn, x)): # j indexes within the batch 71 | if 'per_tok' in args.include: 72 | r = r[:ell] 73 | torch.save(r.detach().cpu(), args.out_dir + '_'.join([name, args.model, 'logits']) + '.pt') 74 | if 'logp' in args.include: 75 | r = r.log_softmax(dim=-1)[:ell] 76 | logps[i + j] = r[torch.arange(ell, device=src.device), src[:ell]].mean().detach().cpu().numpy() # skip padding 77 | pbar.update(len(bs)) 78 | if 'logp' in args.include: 79 | df = pd.DataFrame() 80 | df['name'] = names 81 | df['sequence'] = [s[0] for s 
in seqs] 82 | df['logp'] = logps 83 | out_fpath = args.out_dir + args.model + '_logp.csv' 84 | print('Writing results to ' + out_fpath) 85 | df = df.to_csv(out_fpath, index=False) 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /scripts/extract_mif.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch.cuda 4 | from tqdm import tqdm 5 | import pandas as pd 6 | import numpy as np 7 | 8 | from sequence_models.utils import parse_fasta 9 | from sequence_models.pretrained import load_model_and_alphabet 10 | from sequence_models.pdb_utils import parse_PDB, process_coords 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('model') 15 | parser.add_argument('csv_fpath') 16 | parser.add_argument('pdb_dir') 17 | parser.add_argument('out_dir') 18 | parser.add_argument('result') 19 | parser.add_argument('--include', nargs='*', default=['mean']) 20 | parser.add_argument('--device', default=None) 21 | args = parser.parse_args() 22 | 23 | # Check inputs 24 | if args.model not in ['mif', 'mifst']: 25 | raise ValueError("Valid models ars 'mif' and 'mifst'.") 26 | if args.result == 'logits': 27 | for inc in args.include: 28 | if inc not in ['per_tok', 'logp']: 29 | raise ValueError("logits can be included as 'per_tok' or as 'logp'.") 30 | elif args.result == 'repr': 31 | for inc in args.include: 32 | if inc not in ['per_tok', 'mean']: 33 | raise ValueError("repr can be included as 'per_tok' or as 'mean'.") 34 | else: 35 | raise ValueError("Valid results ars 'repr' and 'logits'.") 36 | 37 | # load model 38 | print('Loading model...') 39 | model, collater = load_model_and_alphabet(args.model) 40 | # detect device and move model to it 41 | if args.device is None: 42 | if torch.cuda.is_available(): 43 | device = torch.device('cuda:0') 44 | else: 45 | device = torch.device('cpu') 46 | else: 47 | device = torch.device(args.device) 48 | model = 
model.to(device) 49 | # load data 50 | print('Loading data...') 51 | df = pd.read_csv(args.csv_fpath).reset_index() 52 | if 'logp' in args.include: 53 | logps = np.empty(len(df)) 54 | with tqdm(total=len(df)) as pbar: 55 | for i, row in df.iterrows(): 56 | seq = row['sequence'] 57 | pdb = row['pdb'] 58 | name = row['name'] 59 | coords, wt, _ = parse_PDB(args.pdb_dir + pdb) 60 | coords = { 61 | 'N': coords[:, 0], 62 | 'CA': coords[:, 1], 63 | 'C': coords[:, 2] 64 | } 65 | dist, omega, theta, phi = process_coords(coords) 66 | batch = [[seq, torch.tensor(dist, dtype=torch.float), 67 | torch.tensor(omega, dtype=torch.float), 68 | torch.tensor(theta, dtype=torch.float), torch.tensor(phi, dtype=torch.float)]] 69 | src, nodes, edges, connections, edge_mask = collater(batch) 70 | src = src.to(device) 71 | nodes = nodes.to(device) 72 | edges = edges.to(device) 73 | connections = connections.to(device) 74 | edge_mask = edge_mask.to(device) 75 | rep = model(src, nodes, edges, connections, edge_mask, result=args.result)[0] 76 | if args.result == 'repr': 77 | if 'mean' in args.include: 78 | torch.save(rep.mean(dim=0).detach().cpu(), 79 | args.out_dir + '_'.join([name, args.model, 'mean']) + '.pt') 80 | if 'per_tok' in args.include: 81 | torch.save(rep.detach().cpu(), 82 | args.out_dir + '_'.join([name, args.model, 'per_tok']) + '.pt') 83 | else: 84 | if 'logp' in args.include: 85 | rep = rep.log_softmax(dim=-1) 86 | logps[i] = rep[torch.arange(len(src[0])), src].mean().detach().cpu().numpy() 87 | if 'per_tok' in args.include: 88 | torch.save(rep.detach().cpu(), args.out_dir + '_'.join([name, args.model, 'logits']) + '.pt') 89 | pbar.update(1) 90 | if 'logp' in args.include: 91 | df['logp'] = logps 92 | out_fpath = args.out_dir + args.model + '_logp.csv' 93 | print('Writing results to ' + out_fpath) 94 | df = df.to_csv(out_fpath, index=False) 95 | 96 | 97 | -------------------------------------------------------------------------------- /sequence_models/__init__.py: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/protein-sequence-models/af695772c4a1c056d930c95ec7e6428aa042f5cd/sequence_models/__init__.py


--------------------------------------------------------------------------------
/sequence_models/aaindex.py:
--------------------------------------------------------------------------------
 1 | import json
 2 | import os
 3 | import numpy as np
 4 | import pandas as pd
 5 | import wget
 6 | from sklearn.decomposition import PCA
 7 | from sklearn.preprocessing import StandardScaler
 8 | class AAIndexTokenizer(object):
 9 |     """Convert between strings and their AAIndex representations."""
10 |     def __init__(self, dpath: str, n_comp: int = 20):
11 |         """
12 |         Args:
13 |             dpath: directory to save raw and reduced representations
14 |             n_comp: number of components in PCA
15 |         """
16 |         alphabet = AAINDEX_ALPHABET
17 |         if not os.path.exists(dpath):
18 |             os.mkdir(dpath)
19 |         if not os.path.exists(dpath + '/aaindex1'):
20 |             file = wget.download('ftp://ftp.genome.jp/pub/db/community/aaindex/aaindex1',
21 |                                  out=dpath + '/' + 'aaindex1')
22 |         if not os.path.exists(dpath + '/raw_aaindex.json'):
23 |             raw_dict = {i: [] for i in alphabet}
24 |             with open(dpath + '/aaindex1', 'r') as f:
25 |                 for line in f:
26 |                     if line[0] == 'I':
27 |                         set1 = next(f).strip().split()
28 |                         set2 = next(f).strip().split()
29 |                         set = set1 + set2
30 |                         for i in range(len(alphabet)):
31 |                             val = set[i]
32 |                             if val == 'NA':
33 |                                 val = None
34 |                             else:
35 |                                 val = float(val)
36 |                             raw_dict[alphabet[i]].append(val)
37 |             with open(dpath + '/raw_aaindex.json', 'w') as f:
38 |                 json.dump(raw_dict, f)
39 |         if not os.path.exists(dpath + '/red_aaindex.json'):
40 |             with open(dpath + '/raw_aaindex.json') as f:
41 |                 raw_dict = json.load(f)
42 |             # preprocessing : drop embeddings with missing data (drop 13)
43 |             embed_df = pd.DataFrame(raw_dict).dropna(axis=0)
44 |             embed = embed_df.values.T  # (len(alphabet), 553)
45 |             # scale to 0 mean and unit variance
46 |             scaler = StandardScaler()
47 |             embed = scaler.fit_transform(embed)
48 |             # PCA
49 |             pca = PCA(n_components=n_comp, svd_solver='auto')
50 |             embed_red = pca.fit_transform(embed)
51 |             print('VARIANCE EXPLAINED: ', pca.explained_variance_ratio_.sum())
52 |             red_dict = {alphabet[i]: list(embed_red[i, :]) for i in range(len(alphabet))}
53 |             with open(dpath + '/red_aaindex.json', 'w') as f:
54 |                 json.dump(red_dict, f)
55 |         # save reduced representation
56 |         with open(dpath + '/red_aaindex.json') as f:
57 |             self.red_dict = json.load(f)
58 |     def tokenize(self, seq: str) -> np.ndarray:
59 |         """
60 |         Args:
61 |             seq: str
62 |                 amino acid sequence
63 |         Returns:
64 |             encoded: np.array
65 |                 encoded amino acid sequence based on reduced AAIndex representation, (L*n_comp,)
66 |         """
67 |         encoded = np.concatenate([self.red_dict[a] for a in seq])
68 |         return encoded


--------------------------------------------------------------------------------
/sequence_models/collaters.py:
--------------------------------------------------------------------------------
  1 | from typing import List, Any, Iterable
  2 | import random
  3 | 
  4 | import numpy as np
  5 | import torch
  6 | import torch.nn.functional as F
  7 | 
  8 | from sequence_models.utils import Tokenizer
  9 | from sequence_models.constants import PAD, GAP, START, STOP, MASK, MSA_PAD, PROTEIN_ALPHABET
 10 | from sequence_models.constants import ALL_AAS
 11 | from sequence_models.gnn import get_node_features, get_edge_features, get_mask, get_k_neighbors, replace_nan
 12 | from sequence_models.trRosetta_utils import trRosettaPreprocessing
 13 | 
 14 | 
 15 | def _pad(tokenized: List[torch.Tensor], value: int) -> torch.Tensor:
 16 |     """Utility function that pads batches to the same length."""
 17 |     batch_size = len(tokenized)
 18 |     max_len = max(len(t) for t in tokenized)
 19 |     output = torch.zeros((batch_size, max_len), dtype=tokenized[0].dtype) + value
 20 |     for row, t in enumerate(tokenized):
 21 |         output[row, :len(t)] = t
 22 |     return output
 23 | 
 24 | 
 25 | class BGCCollater(object):
 26 |     """A collater for BiGCARP models."""
 27 | 
 28 |     def __init__(self, tokens, pfam_to_domain):
 29 |         self.tokens = tokens
 30 |         self.pfam_to_domain = pfam_to_domain
 31 | 
 32 |     def __call__(self, domains):
 33 |         data = tuple(zip(*domains))
 34 |         sequences = data[0]
 35 |         t = []
 36 |         for sequence in sequences:
 37 |             tok = []
 38 |             for pfam in sequence.split(';'):
 39 |                 if pfam in self.tokens['specials']:
 40 |                     tok.append(self.tokens['specials'][pfam])
 41 |                     continue
 42 |                 if pfam in self.pfam_to_domain:
 43 |                     domain = self.pfam_to_domain[pfam]
 44 |                 else:
 45 |                     domain = 'UNK'
 46 |                 tok.append(self.tokens['domains'][domain])
 47 |             t.append(torch.tensor(tok))
 48 |         t = _pad(t, self.tokens['specials'][PAD])
 49 |         return (t, )
 50 | 
 51 | 
 52 | class TokenCollater(object):
 53 |     """A collater that pads batches of tokens."""
 54 | 
 55 |     def __init__(self, pad_idx):
 56 |         self.pad_idx = pad_idx
 57 | 
 58 |     def __call__(self, batch: List[torch.Tensor]) -> List[torch.tensor]:
 59 |         data = tuple(zip(*batch))
 60 |         sequences = data[0]
 61 |         sequences = _pad(sequences, self.pad_idx)
 62 |         return [sequences]
 63 | 
 64 | 
 65 | class SimpleCollater(object):
 66 |     """A collater that pads and possibly reverses batches of sequences.
 67 | 
 68 |     Parameters:
 69 |         alphabet (str)
 70 |         pad (Boolean)
 71 |         backwards (Boolean)
 72 | 
 73 |     If sequences are reversed, the padding is still on the right!
 74 | 
 75 |     Input (list): a batch of sequences as strings
 76 |     Output (torch.LongTensor): tokenized batch of sequences
 77 |     """
 78 | 
 79 |     def __init__(self, alphabet: str, pad=False, backwards=False, pad_token=PAD, start=False, stop=False):
 80 |         self.pad = pad
 81 |         self.tokenizer = Tokenizer(alphabet)
 82 |         self.backwards = backwards
 83 |         self.pad_idx = self.tokenizer.alphabet.index(pad_token)
 84 |         self.start = start
 85 |         self.stop = stop
 86 | 
 87 |     def __call__(self, batch: List[Any], ) -> List[torch.Tensor]:
 88 |         data = tuple(zip(*batch))
 89 |         sequences = data[0]
 90 |         prepped = self._prep(sequences)
 91 |         return prepped
 92 | 
 93 |     def _prep(self, sequences):
 94 |         if self.start:
 95 |             sequences = [START + s for s in sequences]
 96 |         if self.stop:
 97 |             sequences = [s + STOP for s in sequences]
 98 |         if self.backwards:
 99 |             sequences = [s[::-1] for s in sequences]
100 |         sequences = [torch.LongTensor(self.tokenizer.tokenize(s)) for s in sequences]
101 |         if self.pad:
102 |             sequences = _pad(sequences, self.pad_idx)
103 |         else:
104 |             sequences = torch.stack(sequences)
105 |         return (sequences,)
106 | 
107 | 
108 | class TAPECollater(SimpleCollater):
109 |     """Collater for TAPE datasets.
110 | 
111 |     For single-dimensional outputs, this pads the sequences and masks and batches everything.
112 | 
113 |     For ss, this also pads the output with -100.
114 | 
115 |     For contacts, this pads the contacts on the bottom and right.
116 |     """
117 | 
118 |     def __init__(self, alphabet: str, pad=True, start=False, stop=False):
119 |         super().__init__(alphabet, pad=pad, start=start, stop=stop)
120 | 
121 |     def __call__(self, batch: List[Any], ) -> List[torch.Tensor]:
122 |         data = tuple(zip(*batch))
123 |         sequences = data[0]
124 |         prepped = self._prep(sequences)
125 |         y = data[1]
126 |         mask = data[2]
127 |         if isinstance(y[0], float):
128 |             y = torch.tensor(y).unsqueeze(-1)
129 | 
130 |         elif isinstance(y[0], int):
131 |             y = torch.tensor(y)
132 | 
133 |         elif len(y[0].shape) == 1:  # secondary structure
134 |             y = _pad(y, -100).long()
135 | 
136 |         elif len(y[0].shape) == 2:  # contact
137 |             max_len = max(len(i) for i in y)
138 |             mask = [F.pad(mask_i,
139 |                           (0, max_len - len(mask_i), 0, max_len - len(mask_i)), value=False) for mask_i in mask]
140 |             mask = torch.stack(mask, dim=0)
141 |             y = [F.pad(yi, (0, max_len - len(yi), 0, max_len - len(yi))) for yi in y]
142 |             y = torch.stack(y, dim=0).float()
143 | 
144 |         return prepped[0], y, mask
145 | 
146 | 
147 | class LMCollater(SimpleCollater):
148 |     """Collater for autoregressive sequence models.
149 | 
150 |     Parameters:
151 |         alphabet (str)
152 |         pad (Boolean)
153 |         backwards (Boolean)
154 | 
155 |     If sequences are reversed, the padding is still on the right!
156 | 
157 |     Input (list): a batch of sequences as strings
158 |     Output:
159 |         src (torch.LongTensor): START + input + padding
160 |         tgt (torch.LongTensor): input + STOP + padding
161 |         mask (torch.LongTensor): 1 where tgt is not padding
162 |     """
163 | 
164 |     def __init__(self, alphabet: str, pad=False, backwards=False, pad_token=PAD):
165 |         super().__init__(alphabet, pad=pad)
166 |         self.backwards = backwards
167 |         self.pad_idx = self.tokenizer.alphabet.index(pad_token)
168 | 
169 | 
170 |     def _prep(self, sequences):
171 |         return self._tokenize_and_mask(*self._split(sequences))
172 | 
173 |     def _split(self, sequences):
174 |         if not self.backwards:
175 |             src = [START + s for s in sequences]
176 |             tgt = [s + STOP for s in sequences]
177 |         else:
178 |             src = [STOP + s[::-1] for s in sequences]
179 |             tgt = [s[::-1] + START for s in sequences]
180 |         return src, tgt
181 | 
182 |     def _tokenize_and_mask(self, src, tgt):
183 |         src = [torch.LongTensor(self.tokenizer.tokenize(s)) for s in src]
184 |         tgt = [torch.LongTensor(self.tokenizer.tokenize(s)) for s in tgt]
185 |         mask = [torch.ones_like(t) for t in tgt]
186 |         src = _pad(src, self.pad_idx)
187 |         tgt = _pad(tgt, self.pad_idx)
188 |         mask = _pad(mask, 0)
189 |         return src, tgt, mask
190 | 
191 | 
192 | class AncestorCollater(LMCollater):
193 |     """Collater for autoregressive sequence models with ancestors.
194 | 
195 |     Parameters:
196 |         alphabet (str)
197 |         pad (Boolean)
198 |         backwards (Boolean)
199 | 
200 |     If sequences are reversed, the padding is still on the right!
201 | 
202 |     Input (list): a batch of sequences as strings
203 |     Output:
204 |         src (torch.LongTensor): START + input + STOP + ancestor + padding
205 |         tgt (torch.LongTensor): input + STOP + ancestor + STOP + padding
206 |         mask (torch.LongTensor): 1 where tgt is not padding
207 |     """
208 | 
209 |     def __call__(self, batch):
210 |         data = tuple(zip(*batch))
211 |         sequences, ancestors = data[:2]
212 |         prepped = self._prep(sequences, ancestors)
213 |         return prepped
214 | 
215 |     def _prep(self, sequences, ancestors):
216 |         if self.backwards:
217 |             sequences = [s[::-1] for s in sequences]
218 |             ancestors = [a[::-1] for a in ancestors]
219 |         src = [START + s + STOP + a for s, a in zip(sequences, ancestors)]
220 |         tgt = [s + STOP + a + STOP for s, a in zip(sequences, ancestors)]
221 |         return self._tokenize_and_mask(src, tgt)
222 | 
223 | 
224 | class MLMCollater(SimpleCollater):
225 |     """Collater for masked language sequence models.
226 | 
227 |     Parameters:
228 |         alphabet (str)
229 |         pad (Boolean)
230 | 
231 |     Input (list): a batch of sequences as strings
232 |     Output:
233 |         src (torch.LongTensor): corrupted input + padding
234 |         tgt (torch.LongTensor): input + padding
235 |         mask (torch.LongTensor): 1 where loss should be calculated for tgt
236 |     """
237 | 
238 |     def __init__(self, alphabet: str, pad=False, backwards=False, pad_token=PAD, mut_alphabet=ALL_AAS, startstop=False):
239 |         super().__init__(alphabet, pad=pad, backwards=backwards, pad_token=pad_token)
240 |         self.mut_alphabet = mut_alphabet
241 |         self.startstop = startstop
242 | 
243 |     def _prep(self, sequences):
244 |         tgt = list(sequences[:])
245 |         src = []
246 |         mask = []
247 |         for seq in sequences:
248 |             if len(seq) == 0:
249 |                 tgt.remove(seq)
250 |                 continue
251 |             mod_idx = random.sample(list(range(len(seq))), int(len(seq) * 0.15))
252 |             if len(mod_idx) == 0:
253 |                 mod_idx = [np.random.choice(len(seq))]  # make sure at least one aa is chosen
254 |             seq_mod = list(seq)
255 |             for idx in mod_idx:
256 |                 p = np.random.uniform()
257 |                 if p <= 0.10:  # do nothing
258 |                     mod = seq[idx]
259 |                 elif 0.10 < p <= 0.20:  # replace with random amino acid
260 |                     mod = np.random.choice([i for i in self.mut_alphabet if i != seq[idx]])
261 |                 else:  # mask
262 |                     mod = MASK
263 |                 seq_mod[idx] = mod
264 |             src.append(''.join(seq_mod))
265 |             m = torch.zeros(len(seq_mod))
266 |             m[mod_idx] = 1
267 |             mask.append(m)
268 |         if self.startstop:
269 |             src = [START + s + STOP for s in src]
270 |             tgt = [START + s + STOP for s in tgt]
271 |             mask = [torch.cat([torch.zeros(1), m, torch.zeros(1)]) for m in mask]
272 |         src = [torch.LongTensor(self.tokenizer.tokenize(s)) for s in src]
273 |         tgt = [torch.LongTensor(self.tokenizer.tokenize(s)) for s in tgt]
274 |         pad_idx = self.tokenizer.alphabet.index(PAD)
275 |         src = _pad(src, pad_idx)
276 |         tgt = _pad(tgt, pad_idx)
277 |         mask = _pad(mask, 0)
278 |         return src, tgt, mask
279 | 
280 | 
281 | class StructureCollater(object):
282 |     """Collater for combined seq/str GNNs.
283 | 
284 |     Parameters:
285 |         sequence collater (SimpleCollater)
286 |         n_connections (int)
287 |         n_node_features (int)
288 |         n_edge_features (int)
289 |         startstop (boolean): if true, expect the sequence collater to add starts/stops, and adds an
290 |             extra zeroed node at the left of the graph.
291 | 
292 |     Input (list): a batch of sequences as strings
293 |     Output:
294 |         sequences from sequence_collater
295 |         nodes, edges, connections, edge_mask for GNN
296 |     """
297 | 
298 |     def __init__(self, sequence_collater: SimpleCollater, n_connections=20,
299 |                  n_node_features=10, n_edge_features=11):
300 |         self.sequence_collater = sequence_collater
301 |         self.n_connections = n_connections
302 |         self.n_node_features = n_node_features
303 |         self.n_edge_features = n_edge_features
304 | 
305 |     def __call__(self, batch: List[Any], ) -> Iterable[torch.Tensor]:
306 |         sequences, dists, omegas, thetas, phis = tuple(zip(*batch))
307 |         collated_seqs = self.sequence_collater._prep(sequences)
308 |         ells = [len(s) for s in sequences]
309 |         max_ell = max(ells)
310 |         n = len(sequences)
311 |         nodes = torch.zeros(n, max_ell, self.n_node_features)
312 |         edges = torch.zeros(n, max_ell, self.n_connections, self.n_edge_features)
313 |         connections = torch.zeros(n, max_ell, self.n_connections, dtype=torch.long)
314 |         edge_mask = torch.zeros(n, max_ell, self.n_connections, 1)
315 |         for i, (ell, dist, omega, theta, phi) in enumerate(zip(ells, dists, omegas, thetas, phis)):
316 |             # process features
317 |             V = get_node_features(omega, theta, phi)
318 |             dist.fill_diagonal_(np.nan)
319 |             E_idx = get_k_neighbors(dist, self.n_connections)
320 |             E = get_edge_features(dist, omega, theta, phi, E_idx)
321 |             str_mask = get_mask(E)
322 |             E = replace_nan(E)
323 |             V = replace_nan(V)
324 |             # reshape
325 |             nc = min(ell - 1, self.n_connections)
326 |             nodes[i, :ell] = V
327 |             edges[i, :ell, :nc] = E
328 |             connections[i, :ell, :nc] = E_idx
329 |             str_mask = str_mask.view(1, ell, -1)
330 |             edge_mask[i, :ell, :nc, 0] = str_mask
331 |         return (*collated_seqs, nodes, edges, connections, edge_mask)
332 | 
333 | 
334 | class StructureOutputCollater(object):
335 |     """Collater that batches sequences and ell x ell structure targets.
336 | 
337 |     Currently cannot deal with starts/stops!
338 |     """
339 | 
340 |     def __init__(self, sequence_collater: SimpleCollater, exp=True, dist_only=False):
341 |         self.exp = exp
342 |         self.sequence_collater = sequence_collater
343 |         self.dist_only = dist_only
344 | 
345 |     def _pad(self, squares, ells, value=0.0):
346 |         max_len = max(ells)
347 |         squares = [F.pad(d, [0, max_len - ell, 0, max_len - ell], value=value)
348 |                    for d, ell in zip(squares, ells)]
349 |         squares = torch.stack(squares, dim=0)
350 |         return squares
351 | 
352 |     def __call__(self, batch: List[Any], ) -> Iterable[torch.Tensor]:
353 |         sequences, dists, omegas, thetas, phis = tuple(zip(*batch))
354 |         ells = [len(s) for s in sequences]
355 |         seqs = self.sequence_collater._prep(sequences)[0]
356 |         if self.exp:
357 |             dists = [torch.exp(-d ** 2 / 64) for d in dists]
358 |             masks = [~torch.isnan(dist) for dist in dists]
359 |         else:
360 |             masks = [torch.ones_like(dist).bool() for dist in dists]
361 |         masks = [~torch.isnan(omega) & m for omega, m in zip(omegas, masks)]
362 |         masks = [~torch.isnan(theta) & m for theta, m in zip(thetas, masks)]
363 |         masks = [~torch.isnan(phi) & m for phi, m in zip(phis, masks)]
364 |         masks = self._pad(masks, ells, value=False)
365 |         dists = self._pad(dists, ells)
366 |         if self.dist_only:
367 |             return seqs, dists, masks
368 |         omegas = self._pad(omegas, ells)
369 |         thetas = self._pad(thetas, ells)
370 |         phis = self._pad(phis, ells)
371 |         return seqs, dists, omegas, thetas, phis, masks
372 | 
373 | 
374 | class TAPE2trRosettaCollater(SimpleCollater):
375 |     """Does trRosetta preprocessing for TAPE datasets. """
376 | 
377 |     def __init__(self, alphabet: str, pad=True):
378 |         super().__init__(alphabet, pad=pad)
379 |         self.featurization = trRosettaPreprocessing(alphabet)
380 | 
381 |     def __call__(self, batch: List[Any], ) -> List[torch.Tensor]:
382 |         data = tuple(zip(*batch))
383 |         if len(data) == 0:
384 |             return data
385 |         sequences = data[0]
386 |         sequences = [i.replace('X', '-') for i in sequences]  # get rid of X found in secondary_structure data
387 |         lens = [len(i) for i in sequences]
388 |         max_len = max(lens)
389 |         prepped = self._prep(sequences)[0]
390 |         prepped = torch.stack([self.featurization.process(i.view(1,-1)).squeeze(0) for i in prepped])
391 |         y = data[1]
392 |         tgt_mask = data[2]
393 |         src_mask = [torch.ones(i, i).bool() for i in lens]
394 |         src_mask = [F.pad(mask_i,
395 |                           (0, max_len - len(mask_i), 0, max_len - len(mask_i)), value=False) for mask_i in src_mask]
396 |         src_mask = torch.stack(src_mask, dim=0).unsqueeze(1)
397 | 
398 |         if isinstance(y[0], float):  # stability or fluorescence
399 |             y = torch.tensor(y).unsqueeze(-1)
400 |             tgt_mask = torch.ones_like(y)
401 | 
402 |         elif isinstance(y[0], int):  # remote homology
403 |             y = torch.tensor(y).long()
404 |             tgt_mask = torch.ones_like(y)
405 | 
406 |         elif len(y[0].shape) == 1:  # secondary structure
407 |             tgt_mask = [torch.ones(i) for i in lens]
408 |             y = _pad(y, 0).long()
409 |             tgt_mask = _pad(tgt_mask, 0).long()
410 | 
411 |         elif len(y[0].shape) == 2:  # contact
412 |             max_len = max(len(i) for i in y)
413 |             tgt_mask = [F.pad(mask_i,
414 |                               (0, max_len - len(mask_i), 0, max_len - len(mask_i)), value=False) for mask_i in tgt_mask]
415 |             tgt_mask = torch.stack(tgt_mask, dim=0)
416 |             y = [F.pad(yi, (0, max_len - len(yi), 0, max_len - len(yi)), value=-1) for yi in y]
417 |             y = torch.stack(y, dim=0).long()
418 |         return prepped.float(), y, tgt_mask, src_mask
419 | 
420 | 
421 | class MSAStructureCollater(StructureOutputCollater):
422 |     """Collater that batches msas and ell x ell structure targets.
423 | 
424 |     Currently cannot deal with starts/stops!
425 | 
426 |     MSAs should be pre-tokenized.
427 |     """
428 | 
429 |     def __init__(self, pad_idx):
430 |         self.pad_idx = pad_idx
431 | 
432 |     def __call__(self, batch: List[Any], ) -> Iterable[torch.Tensor]:
433 |         msas, dists, omegas, thetas, phis = tuple(zip(*batch))
434 |         ells = [s.shape[1] for s in msas]
435 |         max_ell = max(ells)
436 |         msas = [F.pad(msa, [0, max_ell - ell], value=self.pad_idx).long() for msa, ell in zip(msas, ells)]
437 |         masks = [torch.ones_like(dist).bool() for dist in dists]
438 |         masks = self._pad(masks, ells, value=False)
439 |         dists = self._pad(dists, ells)
440 |         omegas = self._pad(omegas, ells)
441 |         thetas = self._pad(thetas, ells)
442 |         phis = self._pad(phis, ells)
443 |         return msas, dists, omegas, thetas, phis, masks
444 | 
445 | 
446 | class MSAGapCollater(object):
447 | 
448 |     def __init__(self, sequence_collater, n_connections=30, direction='bidirectional', task='gap-prob'):
449 |         """Collater for gap probability prediction with a GNN. y is (p_gap, 1 - p_gap).
450 | 
451 |         Uses MASK to pad to distinguish between GAP and padding.
452 |         For bidirectional:
453 |             src:
454 |             tgt:

455 |         
456 |         For forward: 
457 |             src: 
458 |             tgt:        
459 |         
460 |         For backward:
461 |             src: 
462 |             tgt: 
463 | 
464 |             
465 |         Args:
466 |             sequence_collater: should only return src
467 |             direction (str)
468 |             n_connections (int)
469 |             task (str): gap-prob or ar
470 | 
471 |         Returns:
472 |             seqs, anchor_seq, nodes, edges, connections, edge_mask, y, mask_y
473 |         """
474 |         # collaters
475 |         self.sequence_collater = sequence_collater
476 |         self.structure_collater = StructureCollater(self.sequence_collater,
477 |                                                     n_connections=n_connections)
478 |         self.direction = direction
479 |         self.pad_idx = sequence_collater.tokenizer.alphabet.index(MASK)
480 |         if direction != 'bidirectional':
481 |             self.start_idx = sequence_collater.tokenizer.alphabet.index(START)
482 |         self.task = task
483 |         self.gap_idx = sequence_collater.tokenizer.alphabet.index(GAP)
484 | 
485 |     def __call__(self, batch: List[Any], ) -> Iterable[torch.Tensor]:
486 |         seq, anchor_seq, dist, omega, theta, phi, y, y_mask = tuple(zip(*batch))
487 |         anchor_seq = _pad(anchor_seq, self.pad_idx)
488 |         seq = [self.sequence_collater.tokenizer.untokenize(i.numpy()) for i in seq]
489 |         rebatch = [(seq[i], dist[i], omega[i], theta[i], phi[i]) for i in range(len(seq))]
490 |         seqs, nodes, edges, connections, edge_mask = self.structure_collater.__call__(rebatch)
491 | 
492 |         # Unidirectional models get one extra START position: on the right for backward, on the left for forward
493 |         if self.direction != 'bidirectional':
494 |             if self.direction == 'backward':
495 |                 d1_pad = [0, 1]
496 |                 node_pad = [0, 0, 0, 1]
497 |                 edge_pad = [0, 0] + node_pad
498 |             if self.direction == 'forward':
499 |                 d1_pad = [1, 0]
500 |                 node_pad = [0, 0, 1, 0]
501 |                 edge_pad = [0, 0] + node_pad
502 |                 connections = connections + 1
503 |             seqs = F.pad(seqs, d1_pad, value=self.start_idx)
504 |             anchor_seq = F.pad(anchor_seq, d1_pad, value=self.start_idx)
505 |             nodes = F.pad(nodes, node_pad, value=0.0)
506 |             edges = F.pad(edges, edge_pad, value=0.0)
507 |             connections = F.pad(connections, node_pad, value=0)
508 |             edge_mask = F.pad(edge_mask, edge_pad, value=0.0)
509 | 
510 |         X = (seqs, anchor_seq, nodes, edges, connections, edge_mask)
511 | 
512 |         if self.task == 'gap-prob':
513 |             mask_y = [torch.ones_like(i).bool() for i in y]  # build before padding so padded positions stay False
514 |             y = _pad(y, 0)
515 |             mask_y = _pad(mask_y, False)
516 |             if self.direction != 'bidirectional':
517 |                 y = F.pad(y, [0, 1, 0, 0], value=0)
518 |                 mask_y = F.pad(mask_y, [0, 1, 0, 0], value=False)
519 |             # adjust y format to fit kldivloss
520 |             y = y.unsqueeze(-1)
521 |             y = torch.cat((y, torch.ones_like(y) - y), -1)
522 |         else:
523 |             y = (seqs[:, 1:] == self.gap_idx).long()
524 |             y = F.pad(y, d1_pad)
525 |             mask_y = (seqs[:, 1:] != self.pad_idx).float()
526 |             mask_y = F.pad(mask_y, d1_pad)
527 | 
528 | 
529 |         return X + (y, mask_y)
530 | 
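The `gap-prob` target layout described in the docstring, `(p_gap, 1 - p_gap)` per position with padded positions masked out, can be illustrated without PyTorch. This is a minimal sketch under the assumption that plain Python lists stand in for tensors; `gap_prob_targets` is a hypothetical helper and not part of `sequence_models`.

```python
def gap_prob_targets(ys, pad_value=0.0):
    """Pad per-sequence gap probabilities and expand each to (p_gap, 1 - p_gap).

    Hypothetical illustration only: lists stand in for the tensors
    that the collater works with.
    """
    max_len = max(len(y) for y in ys)
    targets, mask = [], []
    for y in ys:
        n_pad = max_len - len(y)
        padded = list(y) + [pad_value] * n_pad
        # Each position becomes a two-way distribution, as expected by a KL-divergence loss.
        targets.append([(p, 1.0 - p) for p in padded])
        # The mask is True only for real (unpadded) positions.
        mask.append([True] * len(y) + [False] * n_pad)
    return targets, mask
```

Calling `gap_prob_targets([[0.25, 0.5], [1.0]])` pads the second sequence by one position, whose target is `(0.0, 1.0)` and whose mask entry is `False`.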
531 | 
532 | class Seq2PropertyCollater(SimpleCollater):
533 |     """A collater that batches sequences and a 1d target. """
534 | 
535 |     def __init__(self, alphabet: str, pad=True, scatter=False, return_mask=False, start=False, stop=False):
536 |         super().__init__(alphabet, pad=pad, start=start, stop=stop)
537 |         self.scatter = scatter
538 |         self.mask = return_mask
539 | 
540 |     def __call__(self, batch):
541 |         data = tuple(zip(*batch))
542 |         sequences = data[0]
543 |         prepped = self._prep(sequences)[0]
544 |         if self.mask:
545 |             mask = prepped != self.tokenizer.alphabet.index(PAD)
546 |         if self.scatter:
547 |             prepped = F.one_hot(prepped, len(self.tokenizer.alphabet))
548 | 
549 |         y = data[1]
550 |         y = torch.tensor(y).unsqueeze(-1).float()
551 |         if not self.mask:
552 |             return prepped, y
553 |         else:
554 |             return prepped, y, mask
555 | 
556 | 
557 | def _pad_msa(tokenized: List, num_seq: int, max_len: int, value: int) -> torch.Tensor:
558 |     """Pads a batch of tokenized MSAs to (batch, deepest MSA, max_len); the num_seq argument is overridden below."""
559 |     batch_size = len(tokenized)
560 |     num_seq = max([len(m) for m in tokenized])
561 |     output = torch.zeros((batch_size, num_seq, max_len), dtype=torch.long) + value
562 |     for i in range(batch_size):
563 |         tokenized[i] = torch.LongTensor(np.array(tokenized[i]))
564 |         output[i, :len(tokenized[i]), :len(tokenized[i][0])] = tokenized[i]
565 |     return output
566 | 
567 | 
568 | class MSAAbsorbingCollater(object):
569 |     """Collater for MSA Absorbing Diffusion model.
570 |     Based on implementation described by Hoogeboom et al. in "Autoregressive Diffusion Models"
571 |     https://doi.org/10.48550/arXiv.2110.02037
572 | 
573 |     Parameters:
574 |         alphabet: str,
575 |             protein alphabet to use
576 |         pad_token: str,
577 |             pad_token to use to pad MSAs, default is PAD token from sequence_models.constants
578 |         num_seqs: int,
579 |             number of sequences to include in each MSA
580 | 
581 |     Input (list): a batch of Multiple Sequence Alignments (MSAs), each MSA contains 64 sequences
582 |     Output:
583 |         src (torch.LongTensor): corrupted input + padding
584 |         tgt (torch.LongTensor): input + padding
585 |         mask (torch.BoolTensor): True where src was masked (bert=False) or selected for corruption (bert=True)
586 |     """
587 | 
588 |     def __init__(self, alphabet: str, pad_token=MSA_PAD, num_seqs=64, bert=False):
589 |         self.tokenizer = Tokenizer(alphabet)
590 |         self.pad_idx = self.tokenizer.alphabet.index(pad_token)
591 |         self.num_seqs = num_seqs
592 |         self.bert = bert
593 |         if bert:
594 |             self.choices = [self.tokenizer.alphabet.index(a) for a in PROTEIN_ALPHABET + GAP]
595 | 
596 |     def __call__(self, batch_msa):
597 |         tgt = list(batch_msa)
598 |         src = tgt.copy()
599 | 
600 |         longest_msa = 0
601 |         batch_size = len(batch_msa)
602 |         mask_ix = []
603 |         mask_iy = []
604 |         for i in range(batch_size):
605 |             # Tokenize MSA
606 |             tgt[i] = [self.tokenizer.tokenize(s) for s in tgt[i]]
607 |             src[i] = [self.tokenizer.tokenize(s) for s in src[i]]
608 | 
609 |             curr_msa = src[i]
610 | 
611 |             curr_msa = np.asarray(curr_msa)
612 |             length, depth = curr_msa.shape  # length = number of sequences in MSA, depth = number of alignment columns
613 | 
614 |             curr_msa = curr_msa.flatten()  # Flatten MSA to 1D to mask tokens
615 |             d = len(curr_msa)  # number of residues in MSA
616 |             if not self.bert:
617 |                 t = np.random.choice(d)  # Pick timestep t
618 |                 t += 1  # ensure t cannot be 0
619 |                 num_masked_tokens = d - t + 1
620 |                 mask_idx = np.random.choice(d, num_masked_tokens, replace=False)
621 |             else:
622 |                 num_corr_tokens = int(np.round(0.15 * d))
623 |                 corr_idx = np.random.choice(d, num_corr_tokens, replace=False)
624 |                 num_masked_tokens = int(np.round(0.8 * num_corr_tokens))
625 |                 num_mut_tokens = int(np.round(0.1 * num_corr_tokens))
626 |                 mask_idx = corr_idx[:num_masked_tokens]
627 |                 muta_idx = corr_idx[num_corr_tokens - num_mut_tokens:]  # empty when num_mut_tokens == 0
628 |                 for idx in muta_idx:
629 |                     choices = list(set(self.choices) - set(curr_msa[[idx]]))
630 |                     curr_msa[idx] = np.random.choice(choices)
631 |                 mask_ix.append(corr_idx // depth)
632 |                 mask_iy.append(corr_idx % depth)
633 | 
634 |             curr_msa[mask_idx] = self.tokenizer.mask_id
635 |             curr_msa = curr_msa.reshape(length, depth)
636 | 
637 |             src[i] = list(curr_msa)
638 | 
639 |             longest_msa = max(depth, longest_msa)  # Keep track of the longest MSA for padding
640 | 
641 |         # Pad sequences
642 |         src = _pad_msa(src, self.num_seqs, longest_msa, self.pad_idx)
643 |         tgt = _pad_msa(tgt, self.num_seqs, longest_msa, self.pad_idx)
644 |         if self.bert:
645 |             mask = torch.zeros_like(src)
646 |             for i in range(len(mask_ix)):
647 |                 mask[i, mask_ix[i], mask_iy[i]] = 1
648 |             mask = mask.bool()
649 |         else:
650 |             mask = (src == self.tokenizer.mask_id)
651 | 
652 |         return src, tgt, mask
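
The BERT-style branch above picks 15% of the flattened MSA positions, masks roughly 80% of them, mutates roughly 10%, and leaves the remainder unchanged while still marking them in the returned mask. The bookkeeping can be sketched with the standard library alone. This is a minimal illustration, not the collater itself: `bert_corruption_plan` and its argument names are hypothetical, and it uses `random` instead of NumPy.

```python
import random


def bert_corruption_plan(n_rows, n_cols, rng, corrupt_frac=0.15,
                         mask_frac=0.8, mutate_frac=0.1):
    """Split flat MSA positions into mask / mutate / keep groups (illustrative only)."""
    d = n_rows * n_cols
    n_corrupt = int(round(corrupt_frac * d))
    corrupt = rng.sample(range(d), n_corrupt)  # sampled without replacement
    n_mask = int(round(mask_frac * n_corrupt))
    n_mutate = int(round(mutate_frac * n_corrupt))
    mask = corrupt[:n_mask]
    # Slice from an explicit start so that n_mutate == 0 yields an empty list.
    mutate = corrupt[n_corrupt - n_mutate:]
    keep = corrupt[n_mask:n_corrupt - n_mutate]

    def to_row_col(flat_index):
        # Row-major layout: row = index // n_cols, column = index % n_cols.
        return flat_index // n_cols, flat_index % n_cols

    return {
        "mask": [to_row_col(i) for i in mask],
        "mutate": [to_row_col(i) for i in mutate],
        "keep": [to_row_col(i) for i in keep],
    }


plan = bert_corruption_plan(n_rows=4, n_cols=50, rng=random.Random(0))
print({name: len(cells) for name, cells in plan.items()})
```

With 4 x 50 = 200 positions, 30 are selected: 24 masked, 3 mutated, and 3 kept as-is. The row and column lookup mirrors the `// depth` and `% depth` arithmetic used to build the boolean mask.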


--------------------------------------------------------------------------------
/sequence_models/constants.py:
--------------------------------------------------------------------------------
 1 | import pathlib
 2 | import os
 3 | 
 4 | import numpy as np
 5 | 
 6 | home = os.getenv('PT_DATA_DIR')
 7 | if home is None:
 8 |     home = str(pathlib.Path.home())
 9 | WEIGHTS_DIR = home + '/sm_weights/'
10 | 
11 | #  It's helpful to separate out the twenty canonical amino acids from the rest
12 | CAN_AAS = 'ACDEFGHIKLMNPQRSTVWY'
13 | AMB_AAS = 'BZX'
14 | OTHER_AAS = 'JOU'
15 | ALL_AAS = CAN_AAS + AMB_AAS + OTHER_AAS
16 | 
17 | DNA = 'GATC'
18 | EXTENDED_NA = 'RYWSMKHBVDN'
19 | RNA = 'GAUC'
20 | IUPAC_AMB_DNA = DNA + EXTENDED_NA
21 | IUPAC_AMB_RNA = RNA + EXTENDED_NA
22 | NAS = 'GATUC' + EXTENDED_NA
23 | 
24 | STOP = '*'
25 | GAP = '-'
26 | PAD = GAP
27 | MSA_PAD = '!'
28 | MASK = '#'  # Useful for masked language model training
29 | START = '@'
30 | SEP = '/'
31 | 
32 | SPECIALS = STOP + GAP + MASK + START
33 | PROTEIN_ALPHABET = ALL_AAS + SPECIALS
34 | MSA_AAS = ALL_AAS + GAP
35 | MSA_ALPHABET = ALL_AAS + GAP + STOP + MASK + START + MSA_PAD
36 | RNA_ALPHABET = IUPAC_AMB_RNA + SPECIALS
37 | ENHANCER_ALPHABET = DNA + 'N' + GAP + MSA_PAD + MASK + STOP + START
38 | MSA_ALPHABET_PLUS = MSA_ALPHABET + SEP
39 | 
40 | trR_ALPHABET = "ARNDCQEGHILKMFPSTWYV-"
41 | 
42 | AAINDEX_ALPHABET = 'ARNDCQEGHILKMFPSTWYV'
43 | INSERT = '.'
44 | ESM2_ALPHABET =  START + MSA_PAD + STOP + MSA_PAD + 'LAGVSERTIDPKQNFYMHWCXBUZO' + INSERT + GAP + MSA_PAD + MASK
45 | 
46 | IUPAC_SS = 'HSTC'
47 | DSSP = 'GHITEBSC'
48 | SS8 = DSSP
49 | SS3 = 'HSL'  # H: GHI; S: EB; L: STC
50 | 
51 | # Bins from TrRosetta paper
52 | DIST_BINS = np.concatenate([np.array([np.nan]), np.linspace(2, 20, 37)])
53 | THETA_BINS = np.concatenate([np.array([np.nan]), np.linspace(0, 2 * np.pi, 25)])
54 | PHI_BINS = np.concatenate([np.array([np.nan]), np.linspace(0, np.pi, 13)])
55 | OMEGA_BINS = np.concatenate([np.array([np.nan]), np.linspace(0, 2 * np.pi, 25)])
56 | 
57 | IUPAC_CODES = {
58 |     "Ala": "A",
59 |     "Arg": "R",
60 |     "Asn": "N",
61 |     "Asp": "D",
62 |     "Cys": "C",
63 |     "Gln": "Q",
64 |     "Glu": "E",
65 |     "Gly": "G",
66 |     "His": "H",
67 |     "Ile": "I",
68 |     "Leu": "L",
69 |     "Lys": "K",
70 |     "Met": "M",
71 |     "Phe": "F",
72 |     "Pro": "P",
73 |     "Ser": "S",
74 |     "Thr": "T",
75 |     "Trp": "W",
76 |     "Val": "V",
77 |     "Tyr": "Y",
78 |     "Asx": "B",
79 |     "Sec": "U",
80 |     "Xaa": "X",
81 |     "Glx": "Z",
82 | }
83 | 
84 | 
85 | 
86 | 
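
The alphabets above are plain strings, so a token's integer id is simply its position in the string (this is how `PROTEIN_ALPHABET.index(PAD)` is used elsewhere in the package). The sketch below copies a few of the constants and shows the resulting ids. `token_to_id` and `encode` are hypothetical helpers for illustration; the package's own tokenizer is not shown in this file.

```python
# Values copied from the constants above.
CAN_AAS = 'ACDEFGHIKLMNPQRSTVWY'
AMB_AAS = 'BZX'
OTHER_AAS = 'JOU'
ALL_AAS = CAN_AAS + AMB_AAS + OTHER_AAS
STOP, GAP, MASK, START = '*', '-', '#', '@'
PAD = GAP
PROTEIN_ALPHABET = ALL_AAS + STOP + GAP + MASK + START

# Hypothetical lookup table: id = position in the alphabet string.
token_to_id = {tok: i for i, tok in enumerate(PROTEIN_ALPHABET)}


def encode(sequence):
    """Map each character to its alphabet index (raises KeyError on unknown tokens)."""
    return [token_to_id[ch] for ch in sequence]


# Distance bin edges: 2 to 20 in 0.5 steps, the same values as np.linspace(2, 20, 37).
dist_edges = [2 + 0.5 * i for i in range(37)]

print(len(PROTEIN_ALPHABET), token_to_id[PAD], token_to_id[MASK])
print(encode('MKV#'))
print(dist_edges[0], dist_edges[-1], len(dist_edges))
```

Because `PAD = GAP`, padding and alignment gaps share one id in `PROTEIN_ALPHABET`; the MSA alphabets add the separate `MSA_PAD` character so the two can be told apart.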


--------------------------------------------------------------------------------
/sequence_models/convolutional.py:
--------------------------------------------------------------------------------
  1 | import torch.nn as nn
  2 | import torch
  3 | import torch.nn.functional as F
  4 | from torch.utils.checkpoint import checkpoint
  5 | import numpy as np
  6 | 
  7 | 
  8 | from sequence_models.layers import PositionFeedForward, PositionFeedForward2d, DoubleEmbedding
  9 | 
 10 | 
 11 | class MaskedConv1d(nn.Conv1d):
 12 |     """ A masked 1-dimensional convolution layer.
 13 | 
 14 |     Takes the same arguments as torch.nn.Conv1d, except that the padding is set automatically.
 15 | 
 16 |          Shape:
 17 |             Input: (N, L, in_channels)
 18 |             input_mask: (N, L, 1), optional
 19 |             Output: (N, L, out_channels)
 20 |     """
 21 | 
 22 |     def __init__(self, in_channels: int, out_channels: int,
 23 |                  kernel_size: int, stride: int=1, dilation: int=1, groups: int=1,
 24 |                  bias: bool=True):
 25 |         """
 26 |         :param in_channels: input channels
 27 |         :param out_channels: output channels
 28 |         :param kernel_size: the kernel width
 29 |         :param stride: filter shift
 30 |         :param dilation: dilation factor
 31 |         :param groups: perform depth-wise convolutions
 32 |         :param bias: adds learnable bias to output
 33 |         """
 34 |         padding = dilation * (kernel_size - 1) // 2
 35 |         super().__init__(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation,
 36 |                                            groups=groups, bias=bias, padding=padding)
 37 | 
 38 |     def forward(self, x, input_mask=None):
 39 |         if input_mask is not None:
 40 |             x = x * input_mask
 41 |         return super().forward(x.transpose(1, 2)).transpose(1, 2)
 42 | 
 43 | 
 44 | class MaskedConv2d(nn.Conv2d):
 45 |     """ A masked 2-dimensional convolution layer.
 46 | 
 47 |     Takes the same arguments as torch.nn.Conv2d, except that the padding is set automatically.
 48 | 
 49 |          Shape:
 50 |             Input: (N, L, L, in_channels)
 51 |             input_mask: (N, L, L, 1), optional
 52 |             Output: (N, L, L, out_channels)
 53 |     """
 54 | 
 55 |     def __init__(self, in_channels: int, out_channels: int,
 56 |                  kernel_size: int, stride: int=1, dilation: int=1, groups: int=1,
 57 |                  bias: bool=True):
 58 |         """
 59 |         :param in_channels: input channels
 60 |         :param out_channels: output channels
 61 |         :param kernel_size: the kernel width
 62 |         :param stride: filter shift
 63 |         :param dilation: dilation factor
 64 |         :param groups: perform depth-wise convolutions
 65 |         :param bias: adds learnable bias to output
 66 |         """
 67 |         padding = dilation * (kernel_size - 1) // 2
 68 |         super().__init__(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation,
 69 |                                            groups=groups, bias=bias, padding=padding)
 70 | 
 71 |     def forward(self, x, input_mask=None):
 72 |         if input_mask is not None:
 73 |             x = x * input_mask
 74 |         return super().forward(x.permute(0, 3, 1, 2).contiguous()).permute(0, 2, 3, 1).contiguous()
 75 | 
 76 | 
 77 | class MaskedCausalConv1d(nn.Module):
 78 |     """Masked Causal 1D convolution based on https://github.com/Popgun-Labs/PopGen/. 
 79 |          
 80 |          Shape:
 81 |             Input: (N, L, in_channels)
 82 |             input_mask: (N, L, 1), optional
 83 |             Output: (N, L, out_channels)
 84 |     """
 85 | 
 86 |     def __init__(self, in_channels, out_channels, kernel_size=1, dilation=1, groups=1, init=None):
 87 |         """
 88 |         Causal 1d convolutions with caching mechanism for O(L) generation,
 89 |         as described in the ByteNet paper (Kalchbrenner et al, 2016) and "Fast Wavenet" (Paine, 2016)
 90 |         Usage:
 91 |             At train time, the API is the same as a regular convolution: `conv = MaskedCausalConv1d(...)`
 92 |             At inference time, set `conv.sequential = True` to enable activation caching, and feed
 93 |             sequence through step by step. Recurrent state is managed internally.
 94 |         References:
 95 |             - Neural Machine Translation in Linear Time: https://arxiv.org/abs/1610.10099
 96 |             - Fast Wavenet: https://arxiv.org/abs/1611.09482
 97 |         :param in_channels: input channels
 98 |         :param out_channels: output channels
 99 |         :param kernel_size: the kernel width
100 |         :param dilation: dilation factor
101 |         :param groups: perform depth-wise convolutions
102 |         :param init: optional initialisation function for nn.Conv1d module (e.g xavier)
103 |         """
104 |         super().__init__()
105 | 
106 |         self.in_channels = in_channels
107 |         self.out_channels = out_channels
108 |         self.kernel_size = kernel_size
109 |         self.dilation = dilation
110 |         self.groups = groups
111 | 
112 |         # if `true` enables fast generation
113 |         self.sequential = False
114 | 
115 |         # compute required amount of padding to preserve the length
116 |         self.zeros = (kernel_size - 1) * dilation
117 |         self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, groups=groups)
118 | 
119 |         # use supplied initialization function
120 |         if init:
121 |             init(self.conv)
122 | 
123 |     def forward(self, x, input_mask=None):
124 |         """
125 |         :param x: (batch, length, in_channels)
126 |         :param input_mask: (batch, length, 1)
127 |         :return: (batch, length, out_channels)
128 |         """
129 |         if input_mask is not None:
130 |             x = x * input_mask
131 |         # training mode
132 |         x = torch.transpose(x, 1, 2)
133 |         if not self.sequential:
134 |             # no padding for kw=1
135 |             if self.kernel_size == 1:
136 |                 return self.conv(x).transpose(1, 2)
137 | 
138 |             # left-pad + conv.
139 |             out = self._pad(x)
140 |             return self._unpad(self.conv(out)).transpose(1, 2)
141 | 
142 |         # sampling mode
143 |         else:
144 |             # note: x refers to a single timestep (batch, features, 1)
145 |             if not hasattr(self, 'recurrent_state'):
146 |                 batch_size = x.size(0)
147 |                 self._init_recurrent_state(batch_size)
148 | 
149 |             return self._generate(x).transpose(1, 2)
150 | 
151 |     def _pad(self, x):
152 |         return F.pad(x, [self.zeros, 0])
153 | 
154 |     def _unpad(self, x):
155 |         return x
156 | 
157 |     def clear_cache(self):
158 |         """
159 |         Delete the recurrent state. Note: this should be called between runs, to prevent
160 |         leftover state bleeding into future samples. Note that we delete state (instead of zeroing) to support
161 |         changes in the inference time batch size.
162 |         """
163 |         if hasattr(self, 'recurrent_state'):
164 |             del self.recurrent_state
165 | 
166 |     def _init_recurrent_state(self, batch):
167 |         """
168 |         Initialize the recurrent state for fast generation.
169 |         :param batch: the batch size to generate
170 |         """
171 | 
172 |         # extract weights and biases from nn.Conv1d module
173 |         state = self.conv.state_dict()
174 |         self.weight = state['weight']
175 |         self.bias = state['bias']
176 | 
177 |         # initialize the recurrent states to zeros
178 |         self.recurrent_state = torch.zeros(batch, self.in_channels, self.zeros, device=self.bias.device)
179 | 
180 |     def _generate(self, x_i):
181 |         """
182 |         Generate a single output activation from the input activation
183 |         and the cached recurrent state activations from previous steps.
184 |         :param x_i: features of a single timestep (batch, in_channels, 1)
185 |         :return: the next output value in the series (batch, out_channels, 1)
186 |         """
187 | 
188 |         # if the kernel_size is greater than 1, use recurrent state.
189 |         if self.kernel_size > 1:
190 |             # extract the recurrent state and concat with input column
191 |             recurrent_activations = self.recurrent_state[:, :, :self.zeros]
192 |             f = torch.cat([recurrent_activations, x_i], 2)
193 | 
194 |             # update the cache for this layer
195 |             self.recurrent_state = torch.cat(
196 |                 [self.recurrent_state[:, :, 1:], x_i], 2)
197 |         else:
198 |             f = x_i
199 | 
200 |         # perform convolution
201 |         activations = F.conv1d(f, self.weight, self.bias,
202 |                                dilation=self.dilation, groups=self.groups)
203 | 
204 |         return activations
205 | 
206 | 
207 | class ByteNetBlock(nn.Module):
208 |     """Residual block from ByteNet paper (https://arxiv.org/abs/1610.10099).
209 |          
210 |          Shape:
211 |             Input: (N, L, d_in)
212 |             input_mask: (N, L, 1), optional
213 |             Output: (N, L, d_out)
214 | 
215 |     """
216 | 
217 |     def __init__(self, d_in, d_h, d_out, kernel_size, dilation=1, groups=1, causal=False, activation='relu', rank=None):
218 |         super().__init__()
219 |         if causal:
220 |             self.conv = MaskedCausalConv1d(d_h, d_h, kernel_size=kernel_size, dilation=dilation, groups=groups)
221 |         else:
222 |             self.conv = MaskedConv1d(d_h, d_h, kernel_size=kernel_size, dilation=dilation, groups=groups)
223 |         if activation == 'relu':
224 |             act = nn.ReLU
225 |         elif activation == 'gelu':
226 |             act = nn.GELU
227 |         layers1 = [
228 |             nn.LayerNorm(d_in),
229 |             act(),
230 |             PositionFeedForward(d_in, d_h, rank=rank),
231 |             nn.LayerNorm(d_h),
232 |             act()
233 |         ]
234 |         layers2 = [
235 |             nn.LayerNorm(d_h),
236 |             act(),
237 |             PositionFeedForward(d_h, d_out, rank=rank),
238 |         ]
239 |         self.sequence1 = nn.Sequential(*layers1)
240 |         self.sequence2 = nn.Sequential(*layers2)
241 | 
242 |     def forward(self, x, input_mask=None):
243 |         """
244 |         :param x: (batch, length, in_channels)
245 |         :param input_mask: (batch, length, 1)
246 |         :return: (batch, length, out_channels)
247 |         """
248 |         return x + self.sequence2(
249 |             self.conv(self.sequence1(x), input_mask=input_mask)
250 |         )
251 | 
252 | 
253 | class ByteNet(nn.Module):
254 | 
255 |     """Stacked residual blocks from ByteNet paper defined by n_layers
256 |          
257 |          Shape:
258 |             Input: (N, L,)
259 |             input_mask: (N, L, 1), optional
260 |             Output: (N, L, d_model)
261 | 
262 |     """
263 | 
264 |     def __init__(self, n_tokens, d_embedding, d_model, n_layers, kernel_size, r, rank=None, n_frozen_embs=None,
265 |                  padding_idx=None, causal=False, dropout=0.0, slim=True, activation='relu', down_embed=True):
266 |         """
267 |         :param n_tokens: number of tokens in token dictionary
268 |         :param d_embedding: dimension of embedding
269 |         :param d_model: dimension to use within ByteNet model (each block's hidden size is d_model // 2 if slim)
270 |         :param n_layers: number of layers of ByteNet block
271 |         :param kernel_size: the kernel width
272 |         :param r: used to calculate dilation factor
273 |         :param padding_idx: location of padding token in ordered alphabet
274 |         :param causal: if True, chooses MaskedCausalConv1d() over MaskedConv1d()
275 |         :param rank: rank of compressed weight matrices
276 |         :param n_frozen_embs: number of frozen embeddings
277 |         :param slim: if True, use half as many dimensions inside each block as in the residual stream (d_h = d_model // 2)
278 |         :param activation: 'relu' or 'gelu'
279 |         :param down_embed: if True, have lower dimension for initial embedding than in CNN layers
280 |         """
281 |         super().__init__()
282 |         if n_tokens is not None:
283 |             if n_frozen_embs is None:
284 |                 self.embedder = nn.Embedding(n_tokens, d_embedding, padding_idx=padding_idx)
285 |             else:
286 |                 self.embedder = DoubleEmbedding(n_tokens - n_frozen_embs, n_frozen_embs,
287 |                                                 d_embedding, padding_idx=padding_idx)
288 |         else:
289 |             self.embedder = nn.Identity()
290 |         if down_embed:
291 |             self.up_embedder = PositionFeedForward(d_embedding, d_model)
292 |         else:
293 |             self.up_embedder = nn.Identity()
294 |             assert n_tokens == d_embedding
295 |         log2 = int(np.log2(r)) + 1
296 |         dilations = [2 ** (n % log2) for n in range(n_layers)]
297 |         d_h = d_model
298 |         if slim:
299 |             d_h = d_h // 2
300 |         layers = [
301 |             ByteNetBlock(d_model, d_h, d_model, kernel_size, dilation=d, causal=causal, rank=rank,
302 |                          activation=activation)
303 |             for d in dilations
304 |         ]
305 |         self.layers = nn.ModuleList(modules=layers)
306 |         self.dropout = dropout
307 | 
308 |     def forward(self, x, input_mask=None):
309 |         """
310 |         :param x: (batch, length)
311 |         :param input_mask: (batch, length, 1)
312 |         :return: (batch, length, d_model)
313 |         """
314 |         e = self._embed(x)
315 |         return self._convolve(e, input_mask=input_mask)
316 | 
317 |     def _embed(self, x):
318 |         e = self.embedder(x)
319 |         e = self.up_embedder(e)
320 |         return e
321 | 
322 |     def _convolve(self, e, input_mask=None):
323 |         for layer in self.layers:
324 |             e = layer(e, input_mask=input_mask)
325 |             if self.dropout > 0.0:
326 |                 e = F.dropout(e, self.dropout, training=self.training)  # F.dropout defaults to training=True, so pass the module's mode
327 |         return e
328 | 
329 | 
330 | class ByteNetLM(nn.Module):
331 | 
332 |     def __init__(self, n_tokens, d_embedding, d_model, n_layers, kernel_size, r, rank=None, n_frozen_embs=None,
333 |                  padding_idx=None, causal=False, dropout=0.0, final_ln=False, slim=True, activation='relu',
334 |                  tie_weights=False, down_embed=True):
335 |         super().__init__()
336 |         self.embedder = ByteNet(n_tokens, d_embedding, d_model, n_layers, kernel_size, r,
337 |                                 padding_idx=padding_idx, causal=causal, dropout=dropout, down_embed=down_embed,
338 |                                 slim=slim, activation=activation, rank=rank, n_frozen_embs=n_frozen_embs)
339 |         if tie_weights:
340 |             self.decoder = nn.Linear(d_model, n_tokens, bias=False)
341 |             self.decoder.weight = self.embedder.embedder.weight
342 |         else:
343 |             self.decoder = PositionFeedForward(d_model, n_tokens)
344 |         if final_ln:
345 |             self.last_norm = nn.LayerNorm(d_model)
346 |         else:
347 |             self.last_norm = nn.Identity()
348 | 
349 |     def forward(self, x, input_mask=None):
350 |         e = self.embedder(x, input_mask=input_mask)
351 |         e = self.last_norm(e)
352 |         return self.decoder(e)
353 | 
354 | 
355 | class ConditionedByteNetLM(nn.Module):
356 | 
357 |     def __init__(self, n_tokens, d_embedding, d_conditioning, d_model, n_layers, kernel_size, r,
358 |                  padding_idx=None, causal=False):
359 |         super().__init__()
360 |         self.embedder = ConditionedByteNetDecoder(n_tokens, d_embedding, d_conditioning,
361 |                                                   d_model, n_layers, kernel_size, r,
362 |                                                   padding_idx=padding_idx, causal=causal)
363 |         self.decoder = PositionFeedForward(d_model, n_tokens)
364 | 
365 |     def forward(self, x, input_mask=None):
366 |         e = self.embedder(x, input_mask=input_mask)
367 |         return self.decoder(e)
368 | 
369 | 
370 | class ConditionedByteNetDecoder(ByteNet):
371 |     """ A conditioned, ByteNet decoder.
372 |     Inputs:
373 |         x: (n, ell)
374 |         c: (n, d_conditioning) or (n, ell, d_conditioning)
375 | 
376 |     """
377 | 
378 |     def __init__(self, n_tokens, d_embedding, d_conditioning, d_model, n_layers, kernel_size, r,
379 |                  padding_idx=None, causal=False):
380 |         """
381 |         :param n_tokens: number of tokens in token dictionary
382 |         :param d_embedding: dimension of embedding
383 |         :param d_conditioning: dimension of the conditioning; the up-embedding outputs d_model - d_conditioning
384 |         :param d_model: dimension to use within ByteNet model (each block's hidden size is d_model // 2)
385 |         :param n_layers: number of layers of ByteNet block
386 |         :param kernel_size: the kernel width
387 |         :param r: used to calculate dilation factor
388 |         """
389 |         super().__init__(n_tokens, d_embedding, d_model, n_layers, kernel_size, r,
390 |                          padding_idx=padding_idx, causal=causal)
391 |         self.up_embedder = PositionFeedForward(d_embedding, d_model - d_conditioning)
392 | 
393 |     def _embed(self, inputs):
394 |         x, c = inputs
395 |         e = self.embedder(x)
396 |         e = self.up_embedder(e)  # (n, ell, d_model - d_conditioning)
397 |         # Concatenate the conditioning
398 |         _, ell = x.shape
399 |         if len(c.shape) == 2:
400 |             c = c.unsqueeze(1)
401 |             c_ = torch.repeat_interleave(c, ell, dim=1)  # (n, ell, d_conditioning)
402 |         else:
403 |             c_ = c
404 |         e = torch.cat([e, c_], dim=2)  # (n, ell, d_model)
405 |         return e
406 | 
407 | 
408 | class ByteNetBlock2d(nn.Module):
409 |     """Residual block from ByteNet paper (https://arxiv.org/abs/1610.10099).
410 | 
411 |          Shape:
412 |             Input: (N, L, L, d_in)
413 |             input_mask: (N, L, L, 1), optional
414 |             Output: (N, L, L, d_out)
415 | 
416 |     """
417 | 
418 |     def __init__(self, d_in, d_h, d_out, kernel_size, dilation=1, groups=1):
419 |         super().__init__()
420 |         self.conv = MaskedConv2d(d_h, d_h, kernel_size=kernel_size, dilation=dilation, groups=groups)
421 |         layers1 = [
422 |             nn.LayerNorm(d_in),
423 |             nn.GELU(),
424 |             PositionFeedForward2d(d_in, d_h),
425 |             nn.LayerNorm(d_h),
426 |             nn.GELU()
427 |         ]
428 |         layers2 = [
429 |             nn.LayerNorm(d_h),
430 |             nn.GELU(),
431 |             PositionFeedForward2d(d_h, d_out),
432 |         ]
433 |         self.sequence1 = nn.Sequential(*layers1)
434 |         self.sequence2 = nn.Sequential(*layers2)
435 | 
436 |     def forward(self, x, input_mask=None):
437 |         """
438 |         :param x: (batch, length, length, in_channels)
439 |         :param input_mask: (batch, length, length, 1)
440 |         :return: (batch, length, length, out_channels)
441 |         """
442 |         return x + self.sequence2(
443 |             self.conv(self.sequence1(x), input_mask=input_mask)
444 |         )
445 | 
446 | 
447 | class ByteNet2d(nn.Module):
448 |     """Stacked residual blocks from ByteNet paper defined by n_layers
449 | 
450 |          Shape:
451 |             Input: (N, L, L, d_in)
452 |             input_mask: (N, L, L, 1), optional
453 |             Output: (N, L, L, d_model)
454 | 
455 |     """
456 | 
457 |     def __init__(self, d_in, d_model, d_hidden, n_layers, kernel_size, r, dropout=0.0, tokens=True, padding_idx=None):
458 |         """
459 |         :param d_in: number of input dimensions
460 |         :param d_model: dimension to use within ByteNet model; d_hidden is the hidden size inside each block
461 |         :param n_layers: number of layers of ByteNet block
462 |         :param kernel_size: the kernel width
463 |         :param r: used to calculate dilation factor
464 |         """
465 |         super().__init__()
466 |         if tokens:
467 |             self.up_embedder = nn.Embedding(d_in, d_model, padding_idx=padding_idx)
468 |         else:
469 |             self.up_embedder = PositionFeedForward2d(d_in, d_model)
470 |         log2 = int(np.log2(r)) + 1
471 |         dilations = [2 ** (n % log2) for n in range(n_layers)]
472 |         layers = [
473 |             ByteNetBlock2d(d_model, d_hidden, d_model, kernel_size, dilation=d)
474 |             for d in dilations
475 |         ]
476 |         self.layers = nn.ModuleList(modules=layers)
477 |         self.dropout = dropout
478 | 
479 |     def forward(self, x, input_mask=None):
480 |         e = self._embed(x)
481 |         return self._convolve(e, input_mask=input_mask)
482 | 
483 |     def _embed(self, x):
484 |         e = self.up_embedder(x)
485 |         return e
486 | 
487 |     def _convolve(self, e, input_mask=None):
488 |         for layer in self.layers:
489 |             e = checkpoint(layer, e, input_mask)
490 | 
491 |             # e = layer(e, input_mask=input_mask)
492 |             if self.dropout > 0.0:
493 |                 e = F.dropout(e, self.dropout, training=self.training)  # F.dropout defaults to training=True, so pass the module's mode
494 |         return e
495 | 
496 | 
497 | class ByteNetLM2d(nn.Module):
498 | 
499 |     def __init__(self, n_tokens, d_model, d_hidden, n_layers, kernel_size, r,
500 |                  padding_idx=None, dropout=0.0):
501 |         super().__init__()
502 |         self.embedder = ByteNet2d(n_tokens, d_model, d_hidden, n_layers, kernel_size, r,
503 |                                 padding_idx=padding_idx, dropout=dropout, tokens=True)
504 | 
505 |         self.decoder = PositionFeedForward2d(d_model, n_tokens)
506 |         self.last_norm = nn.LayerNorm(d_model)
507 | 
508 |     def forward(self, x, input_mask=None):
509 |         e = self.embedder(x, input_mask=input_mask)
510 |         e = self.last_norm(e)
511 |         return self.decoder(e)
512 | 
513 | 
514 | 
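
`ByteNet` and `ByteNet2d` derive one dilation per layer from `r`, and the masked (non-causal) convolutions pad by `dilation * (kernel_size - 1) // 2`. The sketch below reproduces that arithmetic with the standard library so the schedule and the resulting receptive field are easy to inspect. The function names are hypothetical, and `r=128`, `kernel_size=5` are illustrative values, not the settings of any released model.

```python
import math


def bytenet_dilations(r, n_layers):
    """Dilation per layer: powers of two cycling from 1 up to 2 ** floor(log2(r))."""
    period = int(math.log2(r)) + 1
    return [2 ** (n % period) for n in range(n_layers)]


def same_padding(kernel_size, dilation):
    """Symmetric padding used by the masked (non-causal) convolutions."""
    return dilation * (kernel_size - 1) // 2


def receptive_field(kernel_size, dilations):
    """Receptive field of stride-1 dilated convolutions applied one after another."""
    return 1 + sum((kernel_size - 1) * d for d in dilations)


dilations = bytenet_dilations(r=128, n_layers=10)
print(dilations)
print(receptive_field(5, dilations[:8]))
print(same_padding(5, 4))
```

One full cycle of eight layers with kernel width 5 covers 1 + 4 * 255 = 1021 positions. The padding formula preserves the sequence length whenever `dilation * (kernel_size - 1)` is even, which always holds for odd kernel widths.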

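
`MaskedCausalConv1d` left-pads the input by `(kernel_size - 1) * dilation` zeros during training and, in sequential mode, keeps the same number of past inputs in a cache so each new timestep costs one kernel application. The following pure-Python sketch checks that the two paths agree. It is a simplification with a single channel and no bias, and both function names are hypothetical.

```python
def causal_conv1d(x, weights, dilation=1):
    """Full-sequence causal convolution: left-pad, then slide the kernel."""
    pad = (len(weights) - 1) * dilation
    padded = [0.0] * pad + list(x)
    return [
        sum(w * padded[t + j * dilation] for j, w in enumerate(weights))
        for t in range(len(x))
    ]


def causal_conv1d_stepwise(x, weights, dilation=1):
    """Same result one timestep at a time, using a cache of past inputs."""
    pad = (len(weights) - 1) * dilation
    state = [0.0] * pad  # plays the role of the recurrent state
    outputs = []
    for x_t in x:
        window = state + [x_t]
        outputs.append(sum(w * window[j * dilation] for j, w in enumerate(weights)))
        state = window[1:]  # drop the oldest input, keep the newest
    return outputs


x = [1, 2, 3, 4, 5]
weights = [1, 10]
full = causal_conv1d(x, weights, dilation=2)
step = causal_conv1d_stepwise(x, weights, dilation=2)
print(full)
print(full == step)
```

The last kernel tap always lands on the current input, so output `t` never depends on inputs after `t`. As in `clear_cache`, the cache has to be reset between sequences, otherwise earlier inputs leak into the next sample.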

--------------------------------------------------------------------------------
/sequence_models/esm.py:
--------------------------------------------------------------------------------
  1 | import torch.nn as nn
  2 | import torch
  3 | import torch.nn.functional as F
  4 | from torch.utils.checkpoint import checkpoint
  5 | 
  6 | from esm.modules import TransformerLayer, LearnedPositionalEmbedding, ESM1bLayerNorm, AxialTransformerLayer
  7 | from sequence_models.constants import PROTEIN_ALPHABET, PAD, MASK
  8 | 
  9 | 
 10 | class RobertaLMHead(nn.Module):
 11 |     """Head for masked language modeling."""
 12 | 
 13 |     def __init__(self, embed_dim, output_dim, weight):
 14 |         super().__init__()
 15 |         self.dense = nn.Linear(embed_dim, embed_dim)
 16 |         self.layer_norm = ESM1bLayerNorm(embed_dim)
 17 |         self.weight = weight
 18 |         self.bias = nn.Parameter(torch.zeros(output_dim))
 19 | 
 20 |     def forward(self, features):
 21 |         x = self.dense(features)
 22 |         x = F.gelu(x)
 23 |         x = self.layer_norm(x)
 24 |         # project back to size of vocabulary with bias
 25 |         x = F.linear(x, self.weight) + self.bias
 26 |         return x
 27 | 
 28 | class ESM1b(nn.Module):
 29 |     """
 30 |     Args:
 31 |         d_model: int,
 32 |             embedding dimension of model
 33 |         d_hidden: int,
 34 |             embedding dimension of feed forward network
 35 |         n_layers: int,
 36 |             number of layers
 37 |         n_heads: int,
 38 |             number of attention heads
 39 |     """
 40 | 
 41 |     def __init__(self, d_model, d_hidden, n_layers, n_heads, n_tokens=len(PROTEIN_ALPHABET),
 42 |                  padding_idx=PROTEIN_ALPHABET.index(PAD), mask_idx=PROTEIN_ALPHABET.index(MASK),
 43 |                  max_positions=1024, tie_weights=True):
 44 |         super(ESM1b, self).__init__()
 45 |         self.embed_tokens = nn.Embedding(
 46 |             n_tokens, d_model, padding_idx=mask_idx
 47 |         )
 48 |         self.layers = nn.ModuleList(
 49 |             [
 50 |                 TransformerLayer(
 51 |                     d_model, d_hidden, n_heads,
 52 |                     add_bias_kv=False,
 53 |                     use_esm1b_layer_norm=True,
 54 |                 )
 55 |                 for _ in range(n_layers)
 56 |             ]
 57 |         )
 58 |         self.padding_idx = padding_idx
 59 | 
 60 |         self.embed_positions = LearnedPositionalEmbedding(max_positions, d_model, padding_idx)
 61 |         self.emb_layer_norm_before = ESM1bLayerNorm(d_model)
 62 |         self.emb_layer_norm_after = ESM1bLayerNorm(d_model)
 63 |         if tie_weights:
 64 |             self.lm_head = RobertaLMHead(
 65 |                 embed_dim=d_model,
 66 |                 output_dim=n_tokens,
 67 |                 weight=self.embed_tokens.weight
 68 |             )
 69 |         else:
 70 |             self.lm_head = RobertaLMHead(
 71 |                 embed_dim=d_model,
 72 |                 output_dim=n_tokens,
 73 |                 weight=nn.Linear(d_model, n_tokens).weight
 74 |             )
 75 | 
 76 |     def forward(self, tokens):
 77 | 
 78 |         assert tokens.ndim == 2
 79 |         padding_mask = tokens.eq(self.padding_idx)  # B, T
 80 | 
 81 |         x = self.embed_tokens(tokens)
 82 |         x = x + self.embed_positions(tokens)
 83 | 
 84 |         x = self.emb_layer_norm_before(x)
 85 |         x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
 86 | 
 87 |         # (B, T, E) => (T, B, E)
 88 |         x = x.transpose(0, 1)
 89 | 
 90 |         if not padding_mask.any():
 91 |             padding_mask = None
 92 | 
 93 |         for layer_idx, layer in enumerate(self.layers):
 94 |             x, attn = layer(x, self_attn_padding_mask=padding_mask, need_head_weights=False)
 95 | 
 96 |         x = self.emb_layer_norm_after(x)
 97 |         x = x.transpose(0, 1)  # (T, B, E) => (B, T, E)
 98 |         x = self.lm_head(x)
 99 |         return x
100 | 
101 | 
102 | class MSATransformer(nn.Module):
103 |     """
104 |     Based on implementation described by Rao et al. in "MSA Transformer"
105 |     https://doi.org/10.1101/2021.02.12.430858
106 | 
107 |     Args:
108 |         d_model: int,
109 |             embedding dimension of model
110 |         d_hidden: int,
111 |             embedding dimension of feed forward network
112 |         n_layers: int,
113 |             number of layers
114 |         n_heads: int,
115 |             number of attention heads
116 |     """
117 | 
118 |     def __init__(self, d_model, d_hidden, n_layers, n_heads, use_ckpt=False, n_tokens=len(PROTEIN_ALPHABET),
119 |                  padding_idx=PROTEIN_ALPHABET.index(PAD), mask_idx=PROTEIN_ALPHABET.index(MASK),
120 |                  max_positions=1024, tie_weights=True):
121 |         super(MSATransformer, self).__init__()
122 |         self.embed_tokens = nn.Embedding(
123 |             n_tokens, d_model, padding_idx=mask_idx
124 |         )
125 |         self.layers = nn.ModuleList(
126 |             [
127 |                 AxialTransformerLayer(
128 |                     d_model, d_hidden, n_heads
129 |                 )
130 |                 for _ in range(n_layers)
131 |             ]
132 |         )
133 |         self.padding_idx = padding_idx
134 | 
135 |         # self.contact_head = ContactPredictionHead()
136 |         self.embed_positions = LearnedPositionalEmbedding(max_positions, d_model, padding_idx)
137 |         self.emb_layer_norm_before = nn.LayerNorm(d_model)
138 |         self.emb_layer_norm_after = nn.LayerNorm(d_model)
139 |         if tie_weights:
140 |             self.lm_head = RobertaLMHead(
141 |                 embed_dim=d_model,
142 |                 output_dim=n_tokens,
143 |                 weight=self.embed_tokens.weight
144 |             )
145 |         else:
146 |             self.lm_head = RobertaLMHead(
147 |                 embed_dim=d_model,
148 |                 output_dim=n_tokens,
149 |                 weight=nn.Linear(d_model, n_tokens).weight
150 |             )
151 | 
152 |         self.use_ckpt = use_ckpt
153 | 
154 |     def forward(self, tokens):
155 |         assert tokens.ndim == 3
156 |         batch_size, num_alignments, seqlen = tokens.size()
157 |         padding_mask = tokens.eq(self.padding_idx)  # B, R, C
158 | 
159 |         x = self.embed_tokens(tokens)
160 |         x = x + self.embed_positions(tokens.view(batch_size * num_alignments, seqlen)).view(x.size())
161 | 
162 |         x = self.emb_layer_norm_before(x)
163 |         x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
164 | 
165 |         # B x R x C x D -> R x C x B x D
166 |         x = x.permute(1, 2, 0, 3)
167 | 
168 |         for layer_idx, layer in enumerate(self.layers):
169 |             x = checkpoint(layer, x, None, padding_mask, False, use_reentrant=True)
170 | 
171 |         x = self.emb_layer_norm_after(x)
172 |         x = x.permute(2, 0, 1, 3)  # R x C x B x D -> B x R x C x D
173 |         x = self.lm_head(x)
174 |         return x


--------------------------------------------------------------------------------
/sequence_models/flip_utils.py:
--------------------------------------------------------------------------------
 1 | import re
 2 | 
 3 | import numpy as np
 4 | import pandas as pd
 5 | 
 6 | from torch.utils.data import Dataset
 7 | 
 8 | 
 9 | def load_flip_data(data_fpath, dataset, split, max_len=2048):
10 |     """Returns FlipDataset objects for the train, validation, and test splits."""
11 |     datadir = data_fpath + dataset + '/splits/'
12 |     path = datadir + split + '.csv'
13 |     df = pd.read_csv(path)
14 |     df['sequence'] = df.sequence.apply(lambda s: re.sub(r'[^A-Z]', '', s.upper()))  # remove special characters
15 |     targets = []
16 |     if dataset == 'scl':
17 |         locations = {
18 |             'Cytoplasm': 0,
19 |             'Nucleus': 1,
20 |             'Cell membrane': 2,
21 |             'Mitochondrion': 3,
22 |             'Endoplasmic reticulum': 4,
23 |             'Lysosome/Vacuole': 5,
24 |             'Golgi apparatus': 6,
25 |             'Peroxisome': 7,
26 |             'Extracellular': 8,
27 |             'Plastid': 9
28 |         }
29 |         for i, row in df.iterrows():
30 |             targets.append(locations[row['target']])
31 |         df['target'] = targets
32 |     test = df[df.set == 'test']
33 |     train = df[(df.set == 'train') & (df.validation.isnull())]
34 |     valid = df[~df.validation.isnull()]
35 |     return FlipDataset(train, max_len=max_len), FlipDataset(valid, max_len=max_len), FlipDataset(test, max_len=max_len)
36 | 
37 | 
38 | class FlipDataset(Dataset):
39 | 
40 |     def __init__(self, data, max_len=2048):
41 |         self.sequences = data['sequence'].values
42 |         self.targets = data['target'].values
43 |         self.max_len = max_len
44 | 
45 |     def __getitem__(self, idx):
46 |         s = self.sequences[idx]
47 |         if len(s) > self.max_len:
48 |             start = np.random.choice(len(s) - self.max_len)
49 |             s = s[start: start + self.max_len]
50 |         return s, self.targets[idx]
51 | 
52 |     def __len__(self):
53 |         return len(self.sequences)
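
The cropping step in `FlipDataset.__getitem__` can be illustrated without pandas or PyTorch. The following is a minimal sketch, not the repository's code: `random_crop` is a hypothetical helper, and it uses the standard-library `random` module instead of `np.random.choice`. Unlike the original, it also allows the last valid start position (`len(seq) - max_len`).

```python
import random


def random_crop(seq: str, max_len: int, rng: random.Random) -> str:
    """Return seq unchanged if it fits, else a random window of max_len characters."""
    if len(seq) <= max_len:
        return seq
    # Valid start positions run from 0 to len(seq) - max_len, inclusive.
    start = rng.randrange(len(seq) - max_len + 1)
    return seq[start:start + max_len]


rng = random.Random(0)  # seeded only to make the sketch reproducible
full = "ACDEFGHIKLMNPQRSTVWY"
cropped = random_crop(full, 8, rng)
unchanged = random_crop("ACDE", 8, rng)
```

Because the window is resampled on every call, a long sequence contributes a different crop each epoch, which is presumably why the crop lives in `__getitem__` rather than in `load_flip_data`.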


--------------------------------------------------------------------------------
/sequence_models/gvp.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | from gvp import GVP, GVPConvLayer, LayerNorm
 4 | 
 5 | 
 6 | class GVPEncoder(nn.Module):
 7 |     '''
 8 |         GVP-GNN encoder.
 9 | 
10 |         Takes in protein structure graphs of type `torch_geometric.data.Data`
11 |         or `torch_geometric.data.Batch` and returns an embedding.
12 | 
13 |         Should be used with `gvp.data.ProteinGraphDataset`, or with generators
14 |         of `torch_geometric.data.Batch` objects with the same attributes.
15 | 
16 |         :param node_in_dim: node dimensions in input graph, should be
17 |                             (6, 3) if using original features
18 |         :param node_h_dim: node dimensions to use in GVP-GNN layers
19 |         :param edge_in_dim: edge dimensions in input graph, should be
20 |                             (32, 1) if using original features
21 |         :param edge_h_dim: edge dimensions to embed to before use
22 |                            in GVP-GNN layers
23 |         :param vocab_size: number of sequence tokens; each token is
24 |                            embedded to a `vocab_size`-dimensional vector
25 |                            and concatenated to the scalar node features
26 |         :param num_layers: number of GVP-GNN layers
27 |         :param drop_rate: rate to use in all dropout layers
28 |         '''
29 |     def __init__(self, node_in_dim, node_h_dim,
30 |                  edge_in_dim, edge_h_dim, num_layers=3, drop_rate=0, vocab_size=30):
31 |         super(GVPEncoder, self).__init__()
32 |         self.W_s = nn.Embedding(vocab_size, vocab_size)
33 |         node_in_dim = (node_in_dim[0] + vocab_size, node_in_dim[1])
34 | 
35 |         self.W_v = nn.Sequential(
36 |             LayerNorm(node_in_dim),
37 |             GVP(node_in_dim, node_h_dim, activations=(None, None))
38 |         )
39 |         self.W_e = nn.Sequential(
40 |             LayerNorm(edge_in_dim),
41 |             GVP(edge_in_dim, edge_h_dim, activations=(None, None))
42 |         )
43 | 
44 |         self.layers = nn.ModuleList(
45 |             GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate)
46 |             for _ in range(num_layers))
47 | 
48 |         ns, _ = node_h_dim
49 |         self.W_out = nn.Sequential(
50 |             LayerNorm(node_h_dim),
51 |             GVP(node_h_dim, (ns, 0)))
52 | 
53 |     def forward(self, h_V, edge_index, h_E, seq):
54 |         '''
55 |         :param h_V: tuple (s, V) of node embeddings
56 |         :param edge_index: `torch.Tensor` of shape [2, num_edges]
57 |         :param h_E: tuple (s, V) of edge embeddings
58 |         :param seq: int `torch.Tensor` of shape [num_nodes]
59 |                     to be embedded and appended to `h_V`
60 |         '''
61 |         seq = self.W_s(seq)
62 |         h_V = (torch.cat([h_V[0], seq], dim=-1), h_V[1])
63 |         h_V = self.W_v(h_V)
64 |         h_E = self.W_e(h_E)
65 |         for layer in self.layers:
66 |             h_V = layer(h_V, edge_index, h_E)
67 |         out = self.W_out(h_V)
68 |         return out
69 | 
70 | 
71 | class GVPMLM(nn.Module):
72 |     def __init__(self, node_in_dim, node_h_dim,
73 |                  edge_in_dim, edge_h_dim, num_layers=3, drop_rate=0, vocab_size=30):
74 |         super(GVPMLM, self).__init__()
75 |         self.encoder = GVPEncoder(node_in_dim, node_h_dim, edge_in_dim, edge_h_dim,
76 |                                   num_layers=num_layers, drop_rate=drop_rate, vocab_size=vocab_size)
77 |         self.decoder = nn.Linear(node_h_dim[0], vocab_size)
78 | 
79 |     def forward(self, h_V, edge_index, h_E, seq):
80 |         '''
81 |         :param h_V: tuple (s, V) of node embeddings
82 |         :param edge_index: `torch.Tensor` of shape [2, num_edges]
83 |         :param h_E: tuple (s, V) of edge embeddings
84 |         :param seq: int `torch.Tensor` of shape [num_nodes]
85 |                     to be embedded and appended to `h_V`
86 |         '''
87 |         e = self.encoder(h_V, edge_index, h_E, seq)
88 |         out = self.decoder(e)
89 |         return out
90 | 
91 | 
92 | 


--------------------------------------------------------------------------------
/sequence_models/layers.py:
--------------------------------------------------------------------------------
  1 | from typing import List
  2 | 
  3 | import math
  4 | import torch.nn as nn
  5 | import torch.nn.functional as F
  6 | import torch
  7 | import numpy as np
  8 | 
  9 | 
 10 | class DoubleEmbedding(nn.Module):
 11 | 
 12 |     """Embedding layer that allows some frozen and some trainable embeddings.
 13 | 
 14 |     An embedding layer where the first n_trainable embeddings are trainable and the
 15 |     remaining n_frozen embeddings are frozen.
 16 |     """
 17 | 
 18 |     def __init__(self, n_trainable, n_frozen, embedding_dim, padding_idx=None):
 19 |         super().__init__()
 20 |         if padding_idx is None:
 21 |             train_padding_idx = None
 22 |             freeze_padding_idx = None
 23 |         elif padding_idx < n_trainable:
 24 |             train_padding_idx = padding_idx
 25 |             freeze_padding_idx = None
 26 |         else:
 27 |             train_padding_idx = None
 28 |             freeze_padding_idx = padding_idx - n_trainable
 29 |         self.n_trainable = n_trainable
 30 |         self.embedding_dim = embedding_dim
 31 |         self.trainable = nn.Embedding(n_trainable, embedding_dim, padding_idx=train_padding_idx)
 32 |         self.frozen = nn.Embedding(n_frozen, embedding_dim, padding_idx=freeze_padding_idx)
 33 |         self.frozen.weight.requires_grad = False
 34 | 
 35 |     def forward(self, idx):
 36 |         i = torch.where(idx < self.n_trainable)
 37 |         j = torch.where(idx >= self.n_trainable)
 38 |         b, ell = idx.shape
 39 |         e = torch.empty(b, ell, self.embedding_dim, device=idx.device, dtype=self.trainable.weight.dtype)
 40 |         e[i] = self.trainable(idx[i])
 41 |         e[j] = self.frozen(idx[j] - self.n_trainable)
 42 |         return e
 43 | 
 44 | 
 45 | class FactorizedLinear(nn.Module):
 46 | 
 47 |     def __init__(self, d_in, d_out, rank):
 48 |         super().__init__()
 49 |         layer = nn.Linear(d_in, d_out)
 50 |         w = layer.weight.data
 51 |         self.bias = layer.bias
 52 |         u, s, v = torch.svd(w)
 53 |         s = torch.diag(s[:rank].sqrt())
 54 |         u = u[:, :rank]
 55 |         v = v.t()[:rank]
 56 |         self.u = nn.Parameter((u @ s).t())
 57 |         self.v = nn.Parameter((s @ v).t())
 58 | 
 59 |     def forward(self, x):
 60 |         return x @ self.v @ self.u + self.bias
 61 | 
 62 | 
 63 | class PositionFeedForward(nn.Module):
 64 | 
 65 |     def __init__(self, d_in, d_out, rank=None):
 66 |         super().__init__()
 67 |         if rank is None:
 68 |             self.conv = nn.Conv1d(d_in, d_out, 1)
 69 |             self.factorized = False
 70 |         else:
 71 |             layer = nn.Linear(d_in, d_out)
 72 |             w = layer.weight.data
 73 |             self.bias = layer.bias
 74 |             u, s, v = torch.svd(w)
 75 |             s = torch.diag(s[:rank].sqrt())
 76 |             u = u[:, :rank]
 77 |             v = v.t()[:rank]
 78 |             self.u = nn.Parameter(u @ s)
 79 |             self.v = nn.Parameter(s @ v)
 80 |             self.factorized = True
 81 | 
 82 |     def forward(self, x):
 83 |         if self.factorized:
 84 |             w = self.u @ self.v
 85 |             return x @ w.t() + self.bias
 86 |         else:
 87 |             return self.conv(x.transpose(1, 2)).transpose(1, 2)
 88 | 
 89 | 
 90 | class PositionFeedForward2d(nn.Module):
 91 | 
 92 |     def __init__(self, d_in, d_out):
 93 |         super().__init__()
 94 |         self.dense = nn.Linear(d_in, d_out)
 95 | 
 96 |     def forward(self, x):
 97 |         return self.dense(x)
 98 | 
 99 | 
100 | class MaskedInstanceNorm2d(nn.InstanceNorm2d):
101 |     ### Expects square inputs before and after masking!!
102 | 
103 |     def __init__(self, n_dims, affine=True, eps=1e-6):
104 |         super().__init__(n_dims, affine=affine, eps=eps)
105 | 
106 | 
107 |     def forward(self, x, input_mask=None):
108 |         if input_mask is None:
109 |             return super().forward(x)
110 |         input_mask = input_mask.bool()
111 |         normed = []
112 |         _, _, max_len, _ = x.shape
113 |         for input, mask in zip(x, input_mask):
114 |             input = torch.masked_select(input, mask)
115 |             el = int(np.sqrt(input.shape[0] // self.num_features))
116 |             input = input.reshape(1, self.num_features, el, el)
117 |             n = max_len - el
118 |             normed.append(F.pad(super().forward(input), (0, n, 0, n), value=0))
119 |         return torch.cat(normed, dim=0)
120 | 
121 | 
122 | class FCStack(nn.Sequential):
123 |     """A stack of fully-connected layers.
124 | 
125 |      Every nn.Linear is followed by an optional normalization layer,
126 |      an optional dropout layer, and then a ReLU (including the last layer).
127 | 
128 |      Args:
129 |          sizes (List of ints): all layer dimensions, from input to output
130 |          norm (str): type of norm. 'bn' for batchnorm, 'ln' for layer norm, anything else for no norm. Default 'bn'
131 |          p (float): dropout probability
132 | 
133 |      Input (N, sizes[0])
134 |      Output (N, sizes[-1])
135 |      """
136 | 
137 |     def __init__(self, sizes: List[int], norm='bn', p=0.0):
138 |         layers = []
139 |         for d0, d1 in zip(sizes, sizes[1:]):
140 |             layers.append(nn.Linear(d0, d1))
141 |             if norm == 'ln':
142 |                 layers.append(nn.LayerNorm(d1))
143 |             elif norm == 'bn':
144 |                 layers.append(nn.BatchNorm1d(d1))
145 |             if p != 0:
146 |                 layers.append(nn.Dropout(p))
147 |             layers.append(nn.ReLU(inplace=True))
148 |         super().__init__(*layers)
149 | 
150 | 
151 | class PositionalEncoding(nn.Module):
152 | 
153 |     def __init__(self, d_model, dropout=0.0, max_len=5000):
154 |         super(PositionalEncoding, self).__init__()
155 |         self.dropout = nn.Dropout(p=dropout)
156 | 
157 |         pe = torch.zeros(max_len, d_model)
158 |         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
159 |         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
160 |         pe[:, 0::2] = torch.sin(position * div_term)
161 |         pe[:, 1::2] = torch.cos(position * div_term)
162 |         pe = pe.unsqueeze(0)
163 |         self.register_buffer('pe', pe)
164 | 
165 |     def forward(self, x):
166 |         x = x + self.pe[:, :x.size(1), :]  # x is (N, L, d_model); add the first L encodings
167 |         return self.dropout(x)
168 | 
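
The table that `PositionalEncoding.__init__` builds can be reproduced in plain Python to see what each entry holds. This is a sketch under the assumption that only the formula matters here; `sinusoidal_encoding` is a hypothetical name and the nested-list layout stands in for the `(max_len, d_model)` tensor.

```python
import math


def sinusoidal_encoding(max_len: int, d_model: int) -> list:
    """Build a max_len x d_model table: even columns hold sines, odd columns cosines."""
    table = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(0, d_model, 2):
            # Same frequency term as div_term: exp(i * -log(10000) / d_model)
            freq = math.exp(i * (-math.log(10000.0) / d_model))
            table[pos][i] = math.sin(pos * freq)
            if i + 1 < d_model:
                table[pos][i + 1] = math.cos(pos * freq)
    return table


pe = sinusoidal_encoding(max_len=4, d_model=4)
```

Row `pos` of this table is the vector added to the token at position `pos`, so a batch-first input of shape `(N, L, d_model)` would use rows `0` through `L - 1`.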


--------------------------------------------------------------------------------
/sequence_models/losses.py:
--------------------------------------------------------------------------------
  1 | import torch.nn as nn
  2 | import torch
  3 | import torch.nn.functional as F
  4 | import numpy as np
  5 | 
  6 | 
  7 | class MaskedCosineLoss(nn.Module):
  8 |     """Masked cosine loss between angles."""
  9 | 
 10 |     def __init__(self):
 11 |         super().__init__()
 12 | 
 13 |     def forward(self, pred, tgt, mask):
 14 |         mask = mask.bool()
 15 |         p = torch.masked_select(pred, mask)
 16 |         t = torch.masked_select(tgt, mask)
 17 |         diff = p - t
 18 |         return torch.cos(diff).mean()
 19 | 
 20 | 
 21 | class MaskedMSELoss(nn.MSELoss):
 22 |     """Masked mean square error loss.
 23 | 
 24 |     Evaluates the MSE at specified locations.
 25 | 
 26 |     Shape:
 27 |         Inputs:
 28 |             - pred: (N, *)
 29 |             - tgt: (N, *)
 30 |             - mask: (N, *) boolean
 31 |     """
 32 | 
 33 |     def __init__(self, reduction='mean'):
 34 |         super().__init__(reduction=reduction)
 35 | 
 36 |     def forward(self, pred, tgt, mask):
 37 |         # Make sure mask is boolean
 38 |         mask = mask.bool()
 39 |         # Select
 40 |         p = torch.masked_select(pred, mask)
 41 |         t = torch.masked_select(tgt, mask)
 42 |         if len(p) == 0:
 43 |             return pred.sum() * 0
 44 |         return super().forward(p, t)
 45 | 
 46 | 
 47 | class SequenceCrossEntropyLoss(nn.Module):
 48 |     """Cross-entropy loss for sequences. """
 49 | 
 50 |     def __init__(self, weight=None, ignore_index=-100):
 51 |         super(SequenceCrossEntropyLoss, self).__init__()
 52 |         self.class_weights = weight  # These are class weights
 53 |         self.ignore_index = ignore_index
 54 | 
 55 |     def forward(self, prediction, tgt, reduction='mean'):
 56 |         # Transpose because pytorch expects (N, C, ...) where C is number of classes
 57 |         return F.cross_entropy(prediction.transpose(1, 2), tgt, weight=self.class_weights, reduction=reduction,
 58 |                                ignore_index=self.ignore_index)
 59 | 
 60 | 
 61 | class MaskedCrossEntropyLoss(nn.CrossEntropyLoss):
 62 |     """Masked cross-entropy loss for sequences.
 63 | 
 64 |     Evaluates the cross-entropy loss at specified locations in a sequence.
 65 | 
 66 |     Shape:
 67 |         Inputs:
 68 |             - pred: (N, L, n_tokens)
 69 |             - tgt: (N, L)
 70 |             - mask: (N, L) boolean
 71 |             - weight: (C, ): class weights for nn.CrossEntropyLoss
 72 |     """
 73 | 
 74 |     def __init__(self, weight=None, reduction='mean'):
 75 |         super().__init__(weight=weight, reduction=reduction)
 76 | 
 77 |     def forward(self, pred, tgt, mask):
 78 |         # Make sure we have that empty last dimension
 79 |         if len(mask.shape) == len(pred.shape) - 1:
 80 |             mask = mask.unsqueeze(-1)
 81 |         # Make sure mask is boolean
 82 |         mask = mask.bool()
 83 |         # Number of locations to calculate loss
 84 |         n = mask.sum()
 85 |         # Select
 86 |         p = torch.masked_select(pred, mask).view(n, -1)
 87 |         t = torch.masked_select(tgt, mask.squeeze())
 88 |         return super().forward(p, t)
 89 | 
 90 | 
 91 | class VAELoss(nn.Module):
 92 |     """A simple VAE loss.
 93 | 
 94 |     This is the sum of a reconstruction loss (calculated on predictions and ground truths)
 95 |     and the KL divergence between z (mu, log_var) and a standard normal.
 96 | 
 97 |     Args:
 98 |         class_weights (): Optional per-class weights for the reconstruction (cross-entropy) loss
 99 | 
100 |     Inputs:
101 |         pre: The predictions (N, *)
102 |         tgt: The ground truths (N, *)
103 |         mu: Predicted means for the latent space
104 |         log_var: Predicted log variance for the latent space
105 |         beta: Ratio between the reconstruction and KLD losses. Optional: default is 1.0.
106 |         sample_weights: Weight to place on each sample. Default None. Size (N, 1 x *)
107 |         reduction (str): 'mean' or 'none'
108 | 
109 |     Outputs:
110 |         loss (, ): Tensor containing the VAE loss
111 |     """
112 | 
113 |     def __init__(self, class_weights=None):
114 |         super(VAELoss, self).__init__()
115 |         self.recon_loss = SequenceCrossEntropyLoss(weight=class_weights)
116 | 
117 |     def forward(self, pre, tgt, mu, log_var, beta=1.0, sample_weights=None, reduction='mean'):
118 |         kld = -0.5 * (1 + log_var - mu ** 2 - log_var.exp())
119 |         r_loss = self.recon_loss(pre, tgt, reduction='none')
120 |         if sample_weights is None:
121 |             kld = kld.sum(dim=1)
122 |             r_loss = r_loss.sum(dim=1)
123 |         else:
124 |             kld = (kld * sample_weights).mean(dim=1)
125 |             r_loss *= sample_weights
126 |             if self.recon_loss.class_weights is not None:
127 |                 r_loss = r_loss.sum(dim=1) / self.recon_loss.class_weights[tgt].sum()
128 |             else:
129 |                 r_loss = r_loss.mean(dim=1)
130 |         if reduction == 'mean':
131 |             kld = kld.mean()
132 |             r_loss = r_loss.mean()
133 |         return r_loss + beta * kld, r_loss, kld
134 | 
135 | 
136 | class MaskedCrossEntropyLossMSA(nn.CrossEntropyLoss):
137 |     """Masked cross-entropy loss for MSAs.
138 |     Evaluates the cross-entropy loss at specified locations in an MSA.
139 |     Shape:
140 |         Inputs:
141 |             - pred: (BS, N, L, n_tokens)
142 |             - tgt: (BS, N, L): label, with uncorrupted tokens
143 |             - mask: (BS, N, L) boolean
144 |     """
145 | 
146 |     def __init__(self, ignore_index, reweight=True):
147 |         super().__init__(ignore_index=ignore_index, reduction='none')
148 |         self.reweight = reweight
149 | 
150 |     def forward(self, pred, tgt, mask, nonpad_mask):
151 |         # Make sure we have that empty last dimension
152 |         if len(mask.shape) == len(pred.shape) - 1:
153 |             mask = mask.unsqueeze(-1)
154 |             nonpad_mask = nonpad_mask.unsqueeze(-1)
155 | 
156 |         # Make sure mask is boolean
157 |         mask = mask.bool()
158 |         nonpad_mask = nonpad_mask.bool()
159 | 
160 |         batch_size = pred.shape[0]
161 |         # print(batch_size)
162 | 
163 |         # Create re-weighting array
164 |         num_masked_tokens = mask.sum(axis=(1, 2))  # D-t+1 masked tokens per MSA in each batch
165 |         num_nonpad_tokens = nonpad_mask.sum(axis=(1, 2))
166 | 
167 |         n = mask.sum()
168 |         p = torch.masked_select(pred, mask).view(n, -1)
169 |         t = torch.masked_select(tgt, mask.squeeze())
170 |         loss = super().forward(p, t)
171 |         # loss[torch.isnan(loss)] = 0.
172 |         total_loss = loss.sum()
173 | 
174 |         if self.reweight:
175 |             num_masked_tokens_msa = torch.squeeze(num_masked_tokens)
176 |             val_batch = 1 / num_masked_tokens_msa
177 |             rwt = val_batch.repeat_interleave(num_masked_tokens_msa)
178 |             num_nonpad_tokens_msa = torch.squeeze(num_nonpad_tokens)
179 |             d_term = num_nonpad_tokens_msa.repeat_interleave(num_masked_tokens_msa)
180 |             rwt = rwt.type(loss.dtype)
181 | 
182 |             rwt_loss = (d_term * rwt * loss).sum()
183 |         else:
184 |             rwt_loss = total_loss
185 | 
186 |         return rwt_loss, total_loss
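
The reweighting in `MaskedCrossEntropyLossMSA.forward` is easier to follow on plain lists. In this sketch (a hypothetical helper, not part of the package), each masked token's loss in MSA `b` is scaled by `n_nonpad[b] / n_masked[b]`, which is what the `repeat_interleave` calls achieve on tensors. It assumes the per-token losses are laid out MSA by MSA, as `masked_select` produces them.

```python
def reweighted_msa_loss(token_losses, n_masked, n_nonpad):
    """Return (reweighted_loss, total_loss) from flat per-token losses.

    token_losses: per-token losses, ordered MSA by MSA
    n_masked[b]:  number of masked tokens in MSA b
    n_nonpad[b]:  number of non-padding tokens in MSA b
    """
    total = 0.0
    reweighted = 0.0
    offset = 0
    for m, d in zip(n_masked, n_nonpad):
        if m == 0:
            continue  # an MSA with no masked tokens contributes nothing
        chunk_sum = sum(token_losses[offset:offset + m])
        total += chunk_sum
        reweighted += (d / m) * chunk_sum
        offset += m
    return reweighted, total


rwt_loss, total_loss = reweighted_msa_loss([1.0, 2.0, 3.0], n_masked=[2, 1], n_nonpad=[10, 4])
```

Here the first MSA contributes `(10 / 2) * (1.0 + 2.0)` and the second `(4 / 1) * 3.0`, so MSAs with few masked tokens relative to their size are weighted up.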


--------------------------------------------------------------------------------
/sequence_models/metrics.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import numpy as np
  3 | 
  4 | class SequenceAccuracy(object):
  5 |     """Computes accuracy between two sequences.
  6 | 
  7 |     Inputs:
  8 |         pred (N, L, C)
  9 |         tgt (N, L)
 10 |         ignore_index (int): index of token to mask out
 11 |     """
 12 |     def __init__(self, ignore_index=-100):
 13 |         self.ignore_index = ignore_index
 14 | 
 15 |     def __call__(self, pred, tgt):
 16 |         n = tgt.shape[0]
 17 |         pred = pred.argmax(dim=-1).view(n,-1)
 18 |         tgt = tgt.view(n, -1)
 19 |         mask = tgt != self.ignore_index
 20 |         tgt = tgt[mask]
 21 |         pred = pred[mask]
 22 |         return (pred == tgt).float().mean()
 23 | 
 24 | 
 25 | class MaskedAccuracy(object):
 26 |     """Masked accuracy.
 27 | 
 28 |     Inputs:
 29 |         pred (N, L, C)
 30 |         tgt (N, L)
 31 |         mask (N, L)
 32 |     """
 33 | 
 34 |     def __call__(self, pred, tgt, mask):
 35 |         _, p = torch.max(pred, -1)
 36 |         masked_tgt = torch.masked_select(tgt, mask.bool())
 37 |         p = torch.masked_select(p, mask.bool())
 38 |         return torch.mean((p == masked_tgt).float())
 39 | 
 40 | 
 41 | class MaskedTopkAccuracy(object):
 42 |     """Masked top k accuracy.
 43 | 
 44 |     Inputs:
 45 |         pred (N, L, C)
 46 |         tgt (N, L)
 47 |         mask (N, L)
 48 |         k (int)
 49 |     """
 50 | 
 51 |     def __call__(self, pred, tgt, mask, k):
 52 |         _, p = torch.topk(pred, k, -1)
 53 |         masked_tgt = torch.masked_select(tgt, mask.bool())
 54 |         p = torch.masked_select(p, mask.bool().unsqueeze(-1)).view(-1, k)
 55 |         masked_tgt = masked_tgt.repeat(k).view(k, -1).t()
 56 |         return (p == masked_tgt).float().sum(dim=1).mean()
 57 | 
 58 | 
 59 | class UngappedAccuracy(MaskedAccuracy):
 60 | 
 61 |     def __init__(self, gap_index):
 62 |         self.gap_index = gap_index
 63 | 
 64 |     def __call__(self, pred, tgt):
 65 |         mask = tgt != self.gap_index
 66 |         return super().__call__(pred, tgt, mask)
 67 | 
 68 | 
 69 | class LPrecision(object):
 70 |     """
 71 |     Calculates top L // k precision where L is length
 72 |     * params acquired from https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4894841/#FN1
 73 | 
 74 |     """
 75 |     def __init__(self, k=5, contact_range='medium-long'):
 76 |         """
 77 |         Args:
 78 |             k: L // k number of contacts to check
 79 |             contact_range: short, medium or long contacts
 80 |         """
 81 |         if contact_range == 'short':
 82 |             self.res_range = [6, 12]
 83 |         elif contact_range == 'medium':
 84 |             self.res_range = [12, 24]
 85 |         elif contact_range == 'long':
 86 |             self.res_range = [24, np.inf]
 87 |         elif contact_range == 'medium-long':
 88 |             self.res_range = [12, np.inf]
 89 |         else:
 90 |             raise ValueError("contact_range must be one of 'short', 'medium', 'long', or 'medium-long'.")
 91 |         # contact if d < 8 angstroms; on the exp(-d ** 2 / 8 ** 2) scale, a value > exp(-8 ** 2 / 8 ** 2)
 92 |         self.contact_threshold = np.exp(-1)
 93 |         self.k = k
 94 | 
 95 |     def __call__(self, prediction, tgt, mask, ells):
 96 |         """
 97 |         Args:
 98 |             prediction: torch.tensor (N, L, L)
 99 |             tgt: torch.tensor (N, L, L)
100 |             mask: torch.tensor (N, L, L)
101 |             ells: torch.tensor (N,)
102 |                 lengths of protein sequences
103 |         """
104 | 
105 |         n, el, _ = tgt.shape
106 | 
107 |         # update the mask
108 |         # get distance based on primary structure
109 |         pri_dist = torch.abs(torch.arange(el, device=tgt.device)[None, :].repeat(el, 1) - torch.arange(el, device=tgt.device).view(-1, 1)).float()
110 |         # repeat for each sample in batch size
111 |         pri_dist = pri_dist.view(1, el, el).repeat(n, 1, 1)
112 |         dist_mask = (pri_dist > self.res_range[0]) & (pri_dist < self.res_range[1])
113 |         mask = dist_mask & mask
114 | 
115 |         # pull the top_k most likely contacts from each prediction
116 |         prediction = prediction.masked_fill(~mask, -1)
117 |         tgt = tgt.masked_fill(~mask, -1)
118 |         # Get just the upper triangular
119 |         idx = torch.triu_indices(el, el, offset=1)
120 |         prediction = torch.stack([p[idx[0], idx[1]] for p in prediction])  # N x n_triu
121 |         tgt = torch.stack([t[idx[0], idx[1]] for t in tgt])  # N x n_triu
122 |         tgt = tgt > self.contact_threshold
123 |         idx = torch.argsort(prediction, dim=1, descending=True)  # N x tri_u
124 | 
125 |         # see how many are tp or fp
126 |         # how many contacts to look at
127 |         top_k = ells // self.k
128 |         n_valid = mask.sum(dim=-1).sum(dim=1)
129 |         n_valid = torch.minimum(n_valid, top_k).long()  # (N, )
130 |         n_predicted = n_valid.sum().item()
131 |         if n_predicted == 0:
132 |             return 0, 0
133 |         # n_predicted = (prediction > self.contact_threshold).sum(dim=1)
134 |         # n_valid = np.minimum(n_valid, n_predicted).long()
135 |         n_contacts = 0
136 |         for ids, t, n in zip(idx, tgt, n_valid):
137 |             n_contacts += t[ids[:n]].sum().item()
138 |         precision = n_contacts / n_predicted
139 |         return precision, n_predicted
140 | 
141 | 
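
For a single protein, the top L // k precision computed by `LPrecision` reduces to a few lines. The sketch below is a simplified illustration with assumed inputs (a dict of scores for pairs `i < j` and a set of true contacts); `top_l_over_k_precision` is a hypothetical name, and unlike the batched class it counts only upper-triangle pairs when capping the number of predictions.

```python
def top_l_over_k_precision(scores, contacts, length, k=5, min_sep=12):
    """Precision of the top (length // k) predicted contacts.

    scores:   dict mapping (i, j) with i < j to a predicted contact score
    contacts: set of (i, j) pairs that are true contacts
    min_sep:  pairs must satisfy j - i > min_sep (strict, as in the class above)
    """
    candidates = [(s, pair) for pair, s in scores.items() if pair[1] - pair[0] > min_sep]
    candidates.sort(key=lambda item: item[0], reverse=True)
    n_top = min(length // k, len(candidates))
    if n_top == 0:
        return 0.0, 0
    hits = sum(1 for _, pair in candidates[:n_top] if pair in contacts)
    return hits / n_top, n_top


scores = {(0, 3): 0.9, (0, 4): 0.8, (1, 5): 0.1, (0, 1): 0.99}
contacts = {(0, 3), (1, 5)}
precision, n_predicted = top_l_over_k_precision(scores, contacts, length=10, k=5, min_sep=2)
```

The pair `(0, 1)` has the highest score but is excluded by the separation filter, so the two evaluated predictions are `(0, 3)` and `(0, 4)`, of which one is a true contact.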


--------------------------------------------------------------------------------
/sequence_models/mixup.py:
--------------------------------------------------------------------------------
 1 | import torch.nn as nn
 2 | 
 3 | 
 4 | class Mixup(object):
 5 | 
 6 |     def __init__(self, alpha_sampler):
 7 |         self.sampler = alpha_sampler
 8 | 
 9 |     def __call__(self, x1, x2, y1, y2):
10 |         alpha = self.sampler.rsample().to(x1.device)
11 |         x = alpha * x1 + (1 - alpha) * x2
12 |         y = alpha * y1 + (1 - alpha) * y2
13 |         return x, y
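
The interpolation in `Mixup.__call__` can be sketched on plain Python lists. This is an illustration only: the `mixup` function and the Beta(0.4, 0.4) sampler are assumptions, standing in for whatever distribution object with `rsample()` is passed as `alpha_sampler`.

```python
import random


def mixup(x1, x2, y1, y2, alpha):
    """Convex combination of two inputs and their targets with weight alpha."""
    x = [alpha * a + (1 - alpha) * b for a, b in zip(x1, x2)]
    y = alpha * y1 + (1 - alpha) * y2
    return x, y


# Hypothetical sampler: alpha drawn from Beta(0.4, 0.4), seeded for reproducibility.
alpha = random.Random(0).betavariate(0.4, 0.4)
x, y = mixup([1.0, 0.0], [0.0, 1.0], 1.0, 0.0, alpha)
```

The same `alpha` is applied to inputs and targets, so the mixed target stays consistent with the mixed input.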


--------------------------------------------------------------------------------
/sequence_models/pdb_utils.py:
--------------------------------------------------------------------------------
  1 | import gzip
  2 | import numpy as np
  3 | import scipy
  4 | from scipy.spatial.distance import squareform, pdist
  5 | 
  6 | from sequence_models.constants import IUPAC_CODES
  7 | 
  8 | 
  9 | def get_dihedrals(a, b, c, d):
 10 |     b0 = -1.0 * (b - a)
 11 |     b1 = c - b
 12 |     b2 = d - c
 13 | 
 14 |     b1 /= np.linalg.norm(b1, axis=-1)[:, None]
 15 | 
 16 |     v = b0 - np.sum(b0 * b1, axis=-1)[:, None] * b1
 17 |     w = b2 - np.sum(b2 * b1, axis=-1)[:, None] * b1
 18 | 
 19 |     x = np.sum(v * w, axis=-1)
 20 |     y = np.sum(np.cross(b1, v) * w, axis=-1)
 21 | 
 22 |     return np.arctan2(y, x)
 23 | 
 24 | 
 25 | def get_angles(a, b, c):
 26 |     v = a - b
 27 |     v /= np.linalg.norm(v, axis=-1)[:, None]
 28 | 
 29 |     w = c - b
 30 |     w /= np.linalg.norm(w, axis=-1)[:, None]
 31 | 
 32 |     x = np.sum(v * w, axis=1)
 33 | 
 34 |     return np.arccos(x)
 35 | 
 36 | 
 37 | def parse_PDB(x, atoms=["N", "CA", "C"], chain=None):
 38 |     """
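 39 |     input:  x = PDB filename (plain or gzipped)
 40 |             atoms = atoms to extract (optional); chain = chain ID to keep (optional, default all chains)
 41 |     output: coords of shape (length, atoms, 3), sequence, sorted 0-based residue numbers found in the file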
 42 |     """
 43 |     xyz, seq, min_resn, max_resn = {}, {}, np.inf, -np.inf
 44 |     open_func = gzip.open if x.endswith('.gz') else open
 45 |     for line in open_func(x, "rb"):
 46 |         line = line.decode("utf-8", "ignore").rstrip()
 47 | 
 48 |         if line[:6] == "HETATM" and line[17 : 17 + 3] == "MSE":
 49 |             line = line.replace("HETATM", "ATOM  ")
 50 |             line = line.replace("MSE", "MET")
 51 | 
 52 |         if line[:4] == "ATOM":
 53 |             ch = line[21:22]
 54 |             if ch == chain or chain is None:
 55 |                 atom = line[12 : 12 + 4].strip()
 56 |                 resi = line[17 : 17 + 3]
 57 |                 resn = line[22 : 22 + 5].strip()
 58 |                 x, y, z = [float(line[i : (i + 8)]) for i in [30, 38, 46]]
 59 | 
 60 |                 if resn[-1].isalpha():
 61 |                     resa, resn = resn[-1], int(resn[:-1]) - 1
 62 |                 else:
 63 |                     resa, resn = "", int(resn) - 1
 64 |                 if resn < min_resn:
 65 |                     min_resn = resn
 66 |                 if resn > max_resn:
 67 |                     max_resn = resn
 68 |                 if resn not in xyz:
 69 |                     xyz[resn] = {}
 70 |                 if resa not in xyz[resn]:
 71 |                     xyz[resn][resa] = {}
 72 |                 if resn not in seq:
 73 |                     seq[resn] = {}
 74 |                 if resa not in seq[resn]:
 75 |                     seq[resn][resa] = resi
 76 | 
 77 |                 if atom not in xyz[resn][resa]:
 78 |                     xyz[resn][resa][atom] = np.array([x, y, z])
 79 | 
 80 |     # convert to numpy arrays, fill in missing values
 81 |     seq_, xyz_ = [], []
 82 |     for resn in range(min_resn, max_resn + 1):
 83 |         if resn in seq:
 84 |             for k in sorted(seq[resn]):
 85 |                 seq_.append(IUPAC_CODES.get(seq[resn][k].capitalize(), "X"))
 86 |         else:
 87 |             seq_.append("X")
 88 |         if resn in xyz:
 89 |             for k in sorted(xyz[resn]):
 90 |                 for atom in atoms:
 91 |                     if atom in xyz[resn][k]:
 92 |                         xyz_.append(xyz[resn][k][atom])
 93 |                     else:
 94 |                         xyz_.append(np.full(3, np.nan))
 95 |         else:
 96 |             for atom in atoms:
 97 |                 xyz_.append(np.full(3, np.nan))
 98 | 
 99 |     valid_resn = np.array(sorted(xyz.keys()))
100 |     return np.array(xyz_).reshape(-1, len(atoms), 3), "".join(seq_), valid_resn
101 | 
102 | 
103 | def process_coords(coords):
104 |     N = np.array(coords['N'])
105 |     Ca = np.array(coords['CA'])
106 |     C = np.array(coords['C'])
107 | 
108 |     # recreate Cb given N,Ca,C
109 |     nres = len(N)
110 |     b = Ca - N
111 |     c = C - Ca
112 |     a = np.cross(b, c)
113 |     Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca
114 | 
115 |     # Cb-Cb distance matrix
116 |     dist = squareform(pdist(Cb))
117 |     np.fill_diagonal(dist, np.nan)
118 |     indices = [[i for i in range(nres) if i != j] for j in range(nres)]
119 |     idx = np.array([[i, j] for i in range(len(indices)) for j in indices[i]]).T
120 |     idx0 = idx[0]
121 |     idx1 = idx[1]
122 |     # matrix of Ca-Cb-Cb-Ca dihedrals
123 |     omega = np.zeros((nres, nres)) + np.nan
124 |     omega[idx0, idx1] = get_dihedrals(Ca[idx0], Cb[idx0], Cb[idx1], Ca[idx1])
125 | 
126 |     # matrix of polar coord theta
127 |     theta = np.zeros((nres, nres)) + np.nan
128 |     theta[idx0, idx1] = get_dihedrals(N[idx0], Ca[idx0], Cb[idx0], Cb[idx1])
129 | 
130 |     # matrix of polar coord phi
131 |     phi = np.zeros((nres, nres)) + np.nan
132 |     phi[idx0, idx1] = get_angles(Ca[idx0], Cb[idx0], Cb[idx1])
133 |     return dist, omega, theta, phi
134 | 
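
The virtual C-beta placement used in `process_coords` can be checked in isolation. The sketch below re-implements only that formula in plain Python (no NumPy) for a single idealized residue. The backbone geometry (bond lengths of 1.458 and 1.525 angstroms, an N-CA-C angle of 111 degrees) and the helper names `cross` and `virtual_cb` are assumptions made for illustration; they are not part of the repository.

```python
import math


def cross(u, v):
    """Cross product of two 3-vectors given as tuples."""
    return (u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0])


def virtual_cb(n, ca, c):
    """Place a virtual C-beta from N, CA, C using the constants in process_coords."""
    b = tuple(ca[i] - n[i] for i in range(3))    # N -> CA
    cc = tuple(c[i] - ca[i] for i in range(3))   # CA -> C
    a = cross(b, cc)                             # normal to the N-CA-C plane
    return tuple(-0.58273431 * a[i] + 0.56802827 * b[i] - 0.54067466 * cc[i] + ca[i]
                 for i in range(3))


# Idealized residue (assumed geometry, CA at the origin)
angle = math.radians(111.0)
n = (-1.458, 0.0, 0.0)
ca = (0.0, 0.0, 0.0)
c = (-1.525 * math.cos(angle), 1.525 * math.sin(angle), 0.0)

cb = virtual_cb(n, ca, c)
ca_cb = math.dist(ca, cb)  # CA-CB distance in angstroms
print(round(ca_cb, 2))
```

For this geometry the CA-CB distance comes out close to a typical carbon-carbon single bond, which is a quick sanity check that the three coefficients are applied to the right vectors.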


--------------------------------------------------------------------------------
/sequence_models/pretrained.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import torch.nn as nn
  3 | 
  4 | from sequence_models.constants import PROTEIN_ALPHABET, PAD, MASK
  5 | from sequence_models.convolutional import ByteNetLM
  6 | from sequence_models.gnn import BidirectionalStruct2SeqDecoder
  7 | from sequence_models.collaters import SimpleCollater, StructureCollater, BGCCollater
  8 | 
  9 | 
 10 | CARP_URL = 'https://zenodo.org/record/6564798/files/'
 11 | MIF_URL = 'https://zenodo.org/record/6573779/files/'
 12 | BIG_URL = 'https://zenodo.org/record/6857704/files/'
 13 | n_tokens = len(PROTEIN_ALPHABET)
 14 | 
 15 | 
 16 | def load_carp(model_data):
 17 |     d_embedding = model_data['d_embed']
 18 |     d_model = model_data['d_model']
 19 |     n_layers = model_data['n_layers']
 20 |     kernel_size = model_data['kernel_size']
 21 |     activation = model_data['activation']
 22 |     slim = model_data['slim']
 23 |     r = model_data['r']
 24 |     bgc = 'bigcarp' in model_data['model']
 25 |     if not bgc:
 26 |         n_tokens = len(PROTEIN_ALPHABET)
 27 |         mask_idx = PROTEIN_ALPHABET.index(MASK)
 28 |         pad_idx = PROTEIN_ALPHABET.index(PAD)
 29 |         n_frozen = None
 30 |     else:
 31 |         n_tokens = model_data['tokens']['size']
 32 |         mask_idx = model_data['tokens']['specials'][MASK]
 33 |         pad_idx = model_data['tokens']['specials'][PAD]
 34 |         if 'frozen' in model_data['model']:
 35 |             n_frozen = 19450
 36 |         else:
 37 |             n_frozen = None
 38 |     model = ByteNetLM(n_tokens, d_embedding, d_model, n_layers, kernel_size, r, dropout=0.0,
 39 |                       activation=activation, causal=False, padding_idx=mask_idx,
 40 |                       final_ln=True, slim=slim, n_frozen_embs=n_frozen)
 41 |     sd = model_data['model_state_dict']
 42 |     model.load_state_dict(sd)
 43 |     model = CARP(model.eval(), pad_idx=pad_idx)
 44 |     return model
 45 | 
 46 | def load_gnn(model_data):
 47 |     one_hot_src = model_data['model'] == 'mif'
 48 |     gnn = BidirectionalStruct2SeqDecoder(n_tokens, 10, 11,
 49 |                                          256, num_decoder_layers=4,
 50 |                                          dropout=0.05, use_mpnn=True,
 51 |                                          pe=False, one_hot_src=one_hot_src)
 52 |     sd = model_data['model_state_dict']
 53 |     gnn.load_state_dict(sd)
 54 |     return gnn.eval()
 55 | 
 56 | def load_model_and_alphabet(model_name):
 57 |     if not model_name.endswith(".pt"):  # not a filepath: treat as a pretrained model name
 58 |         if 'big' in model_name:
 59 |             url = BIG_URL + '%s.pt?download=1' % model_name
 60 |         elif 'carp' in model_name:
 61 |             url = CARP_URL + '%s.pt?download=1' % model_name
 62 |         elif 'mif' in model_name:
 63 |             url = MIF_URL + '%s.pt?download=1' % model_name
 64 |         model_data = torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu")
 65 |     else:  # treat as filepath
 66 |         model_data = torch.load(model_name, map_location="cpu")
 67 |     if 'big' in model_data['model']:
 68 |         pfam_to_domain = model_data['pfam_to_domain']
 69 |         tokens = model_data['tokens']
 70 |         collater = BGCCollater(tokens, pfam_to_domain)
 71 |     else:
 72 |         collater = SimpleCollater(PROTEIN_ALPHABET, pad=True)
 73 |     if 'carp' in model_data['model']:
 74 |         model = load_carp(model_data)
 75 |     elif model_data['model'] in ['mif', 'mif-st']:
 76 |         gnn = load_gnn(model_data)
 77 |         cnn = None
 78 |         if model_data['model'] == 'mif-st':
 79 |             url = CARP_URL + '%s.pt?download=1' % 'carp_640M'
 80 |             cnn_data = torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu")
 81 |             cnn = load_carp(cnn_data)
 82 |         collater = StructureCollater(collater, n_connections=30)
 83 |         model = MIF(gnn, cnn=cnn)
 84 |     return model, collater
 85 | 
 86 | 
 87 | class CARP(nn.Module):
 88 |     """Wrapper that takes care of input masking."""
 89 | 
 90 |     def __init__(self, model: ByteNetLM, pad_idx=PROTEIN_ALPHABET.index(PAD)):
 91 |         super().__init__()
 92 |         self.model = model
 93 |         self.pad_idx = pad_idx
 94 | 
 95 |     def forward(self, x, repr_layers=[-1], logits=False):
 96 |         padding_mask = (x != self.pad_idx)
 97 |         padding_mask = padding_mask.unsqueeze(-1)
 98 |         if len(repr_layers) == 1 and repr_layers[0] == -1:
 99 |             repr_layers = [len(self.model.embedder.layers)]
100 |         result = {}
101 |         if len(repr_layers) > 0:
102 |             result['representations'] = {}
103 |         x = self.model.embedder._embed(x)
104 |         if 0 in repr_layers:
105 |             result['representations'][0] = x
106 |         i = 1
107 |         for layer in self.model.embedder.layers:
108 |             x = layer(x, input_mask=padding_mask)
109 |             if i in repr_layers:
110 |                 result['representations'][i] = x
111 |             i += 1
112 |         if logits:
113 |             result['logits'] = self.model.decoder(self.model.last_norm(x))
114 |         return result
115 | 
116 | class MIF(nn.Module):
117 |     """Wrapper that optionally feeds CARP logits to the structure GNN as its sequence input."""
118 | 
119 |     def __init__(self, gnn: BidirectionalStruct2SeqDecoder, cnn=None):
120 |         super().__init__()
121 |         self.gnn = gnn
122 |         self.cnn = cnn
123 | 
124 |     def forward(self, src, nodes, edges, connections, edge_mask, result='repr'):
125 |         if result == 'logits':
126 |             decoder = True
127 |         elif result == 'repr':
128 |             decoder = False
129 |         else:
130 |             raise ValueError("Result must be either 'repr' or 'logits'")
131 |         if self.cnn is not None:
132 |             src = self.cnn(src, repr_layers=[], logits=True)['logits']
133 |         return self.gnn(nodes, edges, connections, src, edge_mask, decoder=decoder)
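
The name-to-URL dispatch in `load_model_and_alphabet` leaves `url` unassigned when a name matches none of the three substrings. A minimal sketch of the same dispatch with an explicit error is shown below. `resolve_checkpoint_url` is a hypothetical helper, not a function in `sequence_models`; the base URLs are copied from the module above.

```python
CARP_URL = 'https://zenodo.org/record/6564798/files/'
MIF_URL = 'https://zenodo.org/record/6573779/files/'
BIG_URL = 'https://zenodo.org/record/6857704/files/'


def resolve_checkpoint_url(model_name: str) -> str:
    """Hypothetical helper: map a pretrained model name to its download URL."""
    # 'big' is tested first because names such as 'bigcarp_random'
    # also contain the substring 'carp'.
    if 'big' in model_name:
        base = BIG_URL
    elif 'carp' in model_name:
        base = CARP_URL
    elif 'mif' in model_name:
        base = MIF_URL
    else:
        raise ValueError(f"Unknown pretrained model name: {model_name!r}")
    return f"{base}{model_name}.pt?download=1"


print(resolve_checkpoint_url('carp_640M'))
```

The order of the checks matters: testing `'carp'` before `'big'` would send the BiGCARP names to the CARP record.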


--------------------------------------------------------------------------------
/sequence_models/samplers.py:
--------------------------------------------------------------------------------
  1 | from typing import Iterable
  2 | import math
  3 | 
  4 | import numpy as np
  5 | from torch.utils.data import Sampler, BatchSampler
  6 | 
  7 | 
  8 | class SortishSampler(Sampler):
  9 |     """Returns indices such that inputs with similar lengths are close together."""
 10 | 
 11 |     def __init__(self, sequence_lengths: Iterable, bucket_size: int, num_replicas: int = 1, rank: int = 0):
 12 |         self.data = np.argsort(sequence_lengths)
 13 |         self.num_replicas = num_replicas
 14 |         self.num_samples = int(math.ceil(len(self.data) * 1.0 / self.num_replicas))
 15 |         self.bucket_size = bucket_size
 16 |         n_buckets = int(np.ceil(len(self.data) / self.bucket_size))
 17 |         self.data = [self.data[i * bucket_size: i * bucket_size + bucket_size] for i in range(n_buckets)]
 18 |         self.rank = rank
 19 |         self.epoch = 0
 20 |         self.total_size = self.num_samples * self.num_replicas
 21 | 
 22 |     def __iter__(self):
 23 |         for bucket in self.data:
 24 |             np.random.shuffle(bucket)
 25 |         np.random.shuffle(self.data)
 26 |         indices = [item for sublist in self.data for item in sublist]
 27 |         indices += indices[:(self.total_size - len(indices))]
 28 |         assert len(indices) == self.total_size
 29 |         # subsample
 30 |         start = self.rank * self.num_samples
 31 |         end = start + self.num_samples
 32 |         indices = indices[start:end]
 33 |         assert len(indices) == self.num_samples
 34 |         return iter(indices)
 35 | 
 36 |     def __len__(self):
 37 |         return self.num_samples
 38 | 
 39 |     def set_epoch(self, epoch):
 40 |         self.epoch = epoch
 41 |         np.random.seed(self.epoch)
 42 | 
 43 | 
 44 | class ClusteredSortishSampler(SortishSampler):
 45 |     """Samples from clusters, then yields indices such that inputs with similar lengths are close together."""
 46 | 
 47 |     def __init__(self, sequence_lengths: Iterable, clusters: Iterable,
 48 |                  bucket_size: int, num_replicas: int = 1, rank: int = 0):
 49 |         self.num_replicas = num_replicas
 50 |         self.clusters = clusters
 51 |         self.cluster_sizes = np.array([len(c) for c in self.clusters])
 52 |         self.num_samples = int(math.ceil(len(self.clusters) * 1.0 / self.num_replicas))
 53 |         self.bucket_size = bucket_size
 54 |         self.n_buckets = int(np.ceil(len(self.clusters) / self.bucket_size))
 55 |         self.lengths = sequence_lengths
 56 |         self.rank = rank
 57 |         self.total_size = self.num_samples * self.num_replicas
 58 |         self.all_data = np.argsort(sequence_lengths)
 59 | 
 60 |     def set_epoch(self, epoch):
 61 |         self.epoch = epoch
 62 |         np.random.seed(self.epoch)
 63 |         selected = np.random.randint(self.cluster_sizes)
 64 |         selected_indices = [c[s] for c, s in zip(self.clusters, selected)]
 65 |         self.data = self.all_data[np.isin(self.all_data, selected_indices, assume_unique=True)]
 66 |         self.data = [self.data[i * self.bucket_size: i * self.bucket_size + self.bucket_size] for i in
 67 |                      range(self.n_buckets)]
 68 | 
 69 | 
 70 | 
 71 | class ApproxBatchSampler(BatchSampler):
 72 |     """
 73 | 	Parameters:
 74 | 	-----------
 75 | 	sampler : Pytorch Sampler
 76 | 		Choose base sampler class to use for bucketing
 77 | 
 78 | 	max_tokens : int
 79 | 		Maximum number of tokens per batch
 80 | 
 81 | 	max_batch: int
 82 | 		Maximum batch size
 83 | 
 84 | 	sample_lengths : array-like
 85 | 		List of lengths of sequences in the order of the dataset
 86 | 	"""
 87 | 
 88 |     def __init__(self, sampler, max_tokens, max_batch, sample_lengths, max_square_tokens=np.inf,
 89 |                  msa_depth=None, batch_mult=1):
 90 |         self.longest_token = 0
 91 |         self.max_tokens = max_tokens
 92 |         self.max_batch = max_batch
 93 |         self.sampler = sampler
 94 |         self.sample_lengths = sample_lengths
 95 |         self.max_square_tokens = max_square_tokens
 96 |         self.msa_depth = msa_depth
 97 |         self.batch_mult = batch_mult
 98 | 
 99 |     def __iter__(self):
100 |         batch = []
101 |         length = 0
102 |         ell_sq = 0
103 |         for idx in self.sampler:
104 |             this_length = self.sample_lengths[idx]
105 |             if self.msa_depth is None:
106 |                 linear = (len(batch) + 1) * max(length, this_length)
107 |             else:
108 |                 max_len = max(length, this_length)
109 |                 linear = (len(batch) + 1) * (max_len * self.msa_depth ** 2 + max_len ** 2 * self.msa_depth)
110 |             quadratic = (len(batch) + 1) * max(ell_sq, this_length ** 2)
111 |             if linear <= self.max_tokens and quadratic < self.max_square_tokens:
112 |                 batch.append(idx)
113 |                 length = max(length, this_length)
114 |                 ell_sq = max(ell_sq, this_length ** 2)
115 |                 if len(batch) == self.max_batch:
116 |                     yield batch
117 |                     batch = []
118 |                     length, ell_sq = 0, 0
119 |             else:
120 |                 rounded_n = (len(batch) // self.batch_mult) * self.batch_mult
121 |                 rounded_n = max(1, rounded_n)
122 |                 yield batch[:rounded_n]
123 |                 batch = batch[rounded_n:] + [idx]
124 |                 length = max([self.sample_lengths[i] for i in batch])
125 |                 ell_sq = length ** 2
126 |         if len(batch) > 0:
127 |             yield batch
128 | 
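
The core idea of `ApproxBatchSampler` is to cap the padded size of a batch, that is, the number of sequences times the longest sequence, rather than the raw sequence count. The sketch below is a simplified variant covering only the linear token budget and `max_batch`; it omits `max_square_tokens`, `msa_depth`, and `batch_mult`. The function name `approx_batches` and the example lengths are assumptions for illustration.

```python
from typing import Iterable, Iterator, List, Sequence


def approx_batches(indices: Iterable[int], lengths: Sequence[int],
                   max_tokens: int, max_batch: int) -> Iterator[List[int]]:
    """Group indices so that len(batch) * longest_sequence stays within max_tokens."""
    batch: List[int] = []
    longest = 0
    for idx in indices:
        this_len = lengths[idx]
        # Padded size if idx were added to the current batch
        padded = (len(batch) + 1) * max(longest, this_len)
        if batch and (padded > max_tokens or len(batch) == max_batch):
            yield batch
            batch, longest = [], 0
        batch.append(idx)
        longest = max(longest, this_len)
    if batch:
        yield batch


lengths = [10, 12, 50, 55, 200]
batches = list(approx_batches(range(len(lengths)), lengths, max_tokens=120, max_batch=4))
print(batches)
```

As in the sampler above, a single sequence longer than `max_tokens` still forms a batch of its own instead of being dropped. Feeding the indices in length-sorted order, as `SortishSampler` approximately does, keeps the padding overhead low.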


--------------------------------------------------------------------------------
/sequence_models/structure.py:
--------------------------------------------------------------------------------
 1 | import torch.nn as nn
 2 | import torch.nn.functional as F
 3 | from torch.utils.checkpoint import checkpoint
 4 | 
 5 | from sequence_models.convolutional import MaskedConv2d, ByteNet2d, ConditionedByteNetDecoder, MaskedConv1d, ByteNet
 6 | from sequence_models.layers import PositionFeedForward
 7 | 
 8 | 
 9 | class ByteNetStructureModel(nn.Module):
10 |     """Takes a Bytenet embedding and converts it to a 2D structural output.
11 | 
12 |     Inputs:
13 |         x (n, ell)
14 |         input_mask (n, ell), optional
15 | 
16 |     Outputs:
17 |         structure (n, ell, ell)
18 |     """
19 | 
20 |     def __init__(self, bytenet, d_model, d_out):
21 |         super().__init__()
22 |         self.embedder = bytenet
23 |         self.d_model = d_model
24 |         self.p = MaskedConv1d(d_model, 256, 1)
25 |         self.q = MaskedConv1d(d_model, 256, 1)
26 |         self.relu = nn.ReLU()
27 |         self.linear = MaskedConv2d(16, 1, 1)
28 | 
29 |     def forward(self, x, input_mask=None):
30 |         e = self.embedder(x, input_mask=input_mask)
31 |         p = checkpoint(self.p, e)
32 |         q = checkpoint(self.q, e)
33 |         return p @ q.transpose(1, 2) / 256
34 | 
35 | 
36 | class Attention2d(nn.Module):
37 | 
38 |     def __init__(self, in_dim):
39 |         super().__init__()
40 |         self.layer = MaskedConv2d(in_dim, 1, 1)
41 | 
42 |     def forward(self, x, input_mask=None):
43 |         n, ell, _, _ = x.shape
44 |         attn = self.layer(x)
45 |         attn = attn.view(n, -1)
46 |         if input_mask is not None:
47 |             attn = attn.masked_fill_(~input_mask.view(n, -1).bool(), float('-inf'))
48 |         attn = F.softmax(attn, dim=-1).view(n, -1, 1)
49 |         out = (attn * x.view(n, ell * ell, -1)).sum(dim=1)
50 |         return out
51 | 
52 | 
53 | class Attention1d(nn.Module):
54 |     
55 |     def __init__(self, in_dim):
56 |         super().__init__()
57 |         self.layer = MaskedConv1d(in_dim, 1, 1)
58 | 
59 |     def forward(self, x, input_mask=None):
60 |         n, ell, _ = x.shape
61 |         attn = self.layer(x)
62 |         attn = attn.view(n, -1)
63 |         if input_mask is not None:
64 |             attn = attn.masked_fill_(~input_mask.view(n, -1).bool(), float('-inf'))
65 |         attn = F.softmax(attn, dim=-1).view(n, -1, 1)
66 |         out = (attn * x).sum(dim=1)
67 |         return out
68 |     
69 |     
70 | class StructureConditioner(nn.Module):
71 | 
72 |     def __init__(self, d_in, d_model, n_layers, kernel_size, r, dropout=0.0):
73 |         super().__init__()
74 |         self.embedder = ByteNet2d(d_in, d_model, n_layers, kernel_size, r, dropout=dropout)
75 |         self.attention = Attention2d(d_model)
76 | 
77 |     def forward(self, x, input_mask=None):
78 |         return self.attention(self.embedder(x, input_mask=input_mask), input_mask=input_mask)
79 | 
80 | 
81 | class StructureConditionedBytenet(nn.Module):
82 | 
83 |     def __init__(self, n_tokens, d_embedding, d_conditioning, d_model, n_layers, k_b, r_b,
84 |                  d_structure, n_c_layers, k_c, r_c):
85 |         super().__init__()
86 |         self.conditioner = StructureConditioner(d_structure, d_conditioning, n_c_layers, k_c, r_c)
87 |         self.bytenet = ConditionedByteNetDecoder(n_tokens, d_embedding, d_conditioning, d_model, n_layers, k_b, r_b)
88 |         self.decoder = PositionFeedForward(d_model, n_tokens)
89 | 
90 |     def forward(self, src, struc, src_mask, str_mask):
91 |         c = self.conditioner(struc, input_mask=str_mask)
92 |         out = self.bytenet((src, c), input_mask=src_mask)
93 |         out = self.decoder(out)
94 |         return out


--------------------------------------------------------------------------------
/sequence_models/trRosetta.py:
--------------------------------------------------------------------------------
  1 | import os, sys
  2 | import torch
  3 | import torch.nn as nn
  4 | import torch.nn.functional as F
  5 | 
  6 | from sequence_models.trRosetta_utils import *
  7 | from sequence_models.constants import WEIGHTS_DIR
  8 | from sequence_models.layers import MaskedInstanceNorm2d
  9 | 
 10 | 
 11 | def pad_size(d, k, s):
 12 |     return int(((139 * s) - 140 + k + ((k - 1) * (d - 1))) / 2)
 13 | 
 14 | 
 15 | class trRosettaBlock(nn.Module):
 16 |         
 17 |     def __init__(self, dilation, p_dropout=0.0):
 18 |         
 19 |         """Simple convolution block
 20 |         
 21 |         Parameters:
 22 |         -----------
 23 |         dilation : int
 24 |             dilation for conv
 25 |         """
 26 | 
 27 |         super(trRosettaBlock, self).__init__()
 28 |         self.conv1 = nn.Conv2d(64, 64, kernel_size=3, stride=1, dilation=dilation, padding=pad_size(dilation, 3, 1))
 29 |         self.instnorm1 = MaskedInstanceNorm2d(64, eps=1e-06, affine=True)
 30 |         self.instnorm2 = MaskedInstanceNorm2d(64, eps=1e-06, affine=True)
 31 |         self.dropout1 = nn.Dropout2d(p_dropout)
 32 |         self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, dilation=dilation, padding=pad_size(dilation, 3, 1))
 33 | 
 34 |     def forward(self, x, input_mask=None, last_elu=True):
 35 |         """
 36 |         Parameters:
 37 |         -----------
 38 |         x : torch.Tensor
 39 |             input tensor, (batch, 64, ell, ell)
 40 | 
 41 |         input_mask : torch.Tensor, optional
 42 |             mask multiplied into the activations before each convolution
 43 | 
 44 |         last_elu : bool
 45 |             whether to apply ELU after the residual sum
 46 | 
 47 |         Returns:
 48 |         --------
 49 |         h : torch.Tensor
 50 |             output of block, same shape as x
 51 |         """
 52 |         if input_mask is not None:
 53 |             x = x * input_mask
 54 |         h = F.elu(self.instnorm1(self.conv1(x), input_mask=input_mask))
 55 |         h = self.dropout1(h)
 56 |         if input_mask is not None:
 57 |             h = h * input_mask
 58 |         h = self.instnorm2(self.conv2(h), input_mask=input_mask) + x
 59 |         if last_elu:
 60 |             h = F.elu(h)
 61 |         return h
 62 | 
 63 | 
 64 | class trRosetta(nn.Module):
 65 |     
 66 |     """trRosetta for single model"""
 67 | 
 68 |     def __init__(self, d_init=526, n2d_layers=61, model_id='a', decoder=True, p_dropout=0.0):
 69 |         """
 70 |         Parameters:
 71 |         -----------
 72 |         model_id : str
 73 |             pretrained models a, b, c, d and/or e.
 74 |     
 75 |         decoder : bool
 76 |             whether to run the last layers to produce distance 
 77 |             and angle outputs
 78 | 
 79 |         """
 80 |         super(trRosetta, self).__init__()
 81 | 
 82 |         self.conv0 = nn.Conv2d(d_init, 64, kernel_size=1, stride=1, padding=pad_size(1, 1, 1))
 83 |         self.instnorm0 = MaskedInstanceNorm2d(64, eps=1e-06, affine=True)
 84 | 
 85 |         dilation = 1
 86 |         layers = []
 87 |         for _ in range(n2d_layers):
 88 |             layers.append(trRosettaBlock(dilation, p_dropout=p_dropout))
 89 |             dilation *= 2
 90 |             if dilation > 16:
 91 |                 dilation = 1
 92 | 
 93 |         self.layers = nn.ModuleList(modules=layers)
 94 |         self.decoder = decoder
 95 |         if decoder:
 96 |             self.softmax = nn.Softmax(dim=1)
 97 |             self.conv_theta = nn.Conv2d(64, 25, kernel_size=1, stride=1, padding=pad_size(1, 1, 1))
 98 |             self.conv_phi = nn.Conv2d(64, 13, kernel_size=1, stride=1, padding=pad_size(1, 1, 1))
 99 |             self.conv_dist = nn.Conv2d(64, 37, kernel_size=1, stride=1, padding=pad_size(1, 1, 1))
100 |             self.conv_bb = nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=pad_size(1, 1, 1))
101 |             self.conv_omega = nn.Conv2d(64, 25, kernel_size=1, stride=1, padding=pad_size(1, 1, 1))
102 |         if model_id is not None:
103 |             self.load_weights(model_id)
104 | 
105 |     def forward(self, x, input_mask=None, softmax=True):
106 |         """
107 |         Parameters:
108 |         -----------
109 |         x : torch.Tensor, (batch, 526, len(sequence), len(sequence))
110 |             inputs after trRosettaPreprocessing
111 |     
112 |         Returns:
113 |         --------
114 |         dist_probs : torch.Tensor
115 |             distance map probabilities
116 |             
117 |         theta_probs : torch.Tensor
118 |             theta angle map probabilities
119 |             
120 |         phi_probs : torch.Tensor
121 |             phi angle map probabilities
122 |         
123 |         omega_probs : torch.Tensor
124 |             omega angle map probabilities
125 |         
126 |         x : torch.Tensor
127 |             outputs before calculating final layers
128 |         """
129 |         if input_mask is not None:
130 |             x = x * input_mask
131 |         h = self.conv0(x)
132 |         h = F.elu(self.instnorm0(h, input_mask=input_mask))
133 |         for i, layer in enumerate(self.layers):
134 |             if not self.decoder:
135 |                 last_elu = False
136 |             elif i != len(self.layers) - 1:
137 |                 last_elu = True
138 |             else:
139 |                 last_elu = False
140 |             h = layer(h, input_mask=input_mask, last_elu=last_elu)
141 |             if input_mask is not None:
142 |                 h = h * input_mask
143 |         if self.decoder:
144 |             logits_theta = self.conv_theta(h)
145 | 
146 |             logits_phi = self.conv_phi(h)
147 | 
148 |             # symmetrize
149 |             h = 0.5 * (h + torch.transpose(h, 2, 3))
150 | 
151 |             logits_dist = self.conv_dist(h)
152 | 
153 |             logits_omega = self.conv_omega(h)
154 |             if not softmax:
155 |                 return logits_dist, logits_theta, logits_phi, logits_omega
156 |             else:
157 |                 theta_probs = self.softmax(logits_theta)
158 |                 phi_probs = self.softmax(logits_phi)
159 |                 dist_probs = self.softmax(logits_dist)
160 |                 omega_probs = self.softmax(logits_omega)
161 |                 return dist_probs, theta_probs, phi_probs, omega_probs
162 |         else:
163 |             return h
164 | 
165 |     def load_weights(self, model_id):
166 |         
167 |         """
168 |         Parameters:
169 |         -----------
170 |         model_id : str
171 |             pretrained models a, b, c, d and/or e.
172 |         """
173 | 
174 |         path = WEIGHTS_DIR + 'trrosetta_pytorch_weights/' + model_id + '.pt'
175 | 
176 |         # check to see if pytorch weights exist, if not -> generate
177 |         if not os.path.exists(path):
178 |             tf_to_pytorch_weights(self.named_parameters(), model_id)
179 |         self.load_state_dict(torch.load(path), strict=False)
180 | 
181 | 
182 | class trRosettaRegressor(trRosetta):
183 | 
184 |     def __init__(self, model_id='a', p_dropout=0.0):
185 |         super(trRosettaRegressor, self).__init__(n2d_layers=61, model_id=model_id, decoder=False, p_dropout=p_dropout)
186 |         self.dist_layer = nn.Conv2d(64, 1, kernel_size=1, stride=1, padding=pad_size(1, 1, 1))
187 |         self.theta_layer = nn.Conv2d(64, 2, kernel_size=1, stride=1, padding=pad_size(1, 1, 1))
188 |         self.phi_layer =  nn.Conv2d(64, 2, kernel_size=1, stride=1, padding=pad_size(1, 1, 1))
189 |         self.omega_layer = nn.Conv2d(64, 2, kernel_size=1, stride=1, padding=pad_size(1, 1, 1))
190 |         self.relu = nn.ReLU()
191 |         self.tanh = nn.Hardtanh()
192 | 
193 |     def forward(self, x, input_mask=None, softmax=False):
194 |         h = super(trRosettaRegressor, self).forward(x, input_mask=input_mask)
195 |         h = F.elu(h)
196 |         sc_theta = self.tanh(self.theta_layer(h))
197 |         sc_phi = self.tanh(self.phi_layer(h))
198 | 
199 |         # symmetrize
200 |         h = 0.5 * (h + torch.transpose(h, 2, 3))
201 |         dist = self.relu(self.dist_layer(h))
202 |         sc_omega = self.tanh(self.omega_layer(h))
203 |         return dist, sc_theta, sc_phi, sc_omega
204 | 
205 | 
206 | class trRosettaEnsemble(nn.Module):
207 |     """trRosetta ensemble"""
208 |     def __init__(self, model, n2d_layers=61, model_ids='abcde', decoder=True):
209 |         """
210 |         Parameters:
211 |         -----------
212 |         model : class 
213 |             base model to use in ensemble
214 |         
215 |         n2d_layers : int 
216 |             number of layers of the conv block to use for each base model
217 |         
218 |         model_ids: str
219 |             pretrained models to use in the ensemble a, b, c, d and/or e. 
220 |             
221 |         decoder : bool
222 |             if True, return dist, omega, phi, theta; else return layer prior decoder
223 |         
224 |         """
225 | 
226 |         super(trRosettaEnsemble, self).__init__()
227 |         self.model_list = nn.ModuleList()
228 |         for i in list(model_ids):
229 |             params = {'model_id': i, 'n2d_layers': n2d_layers, 'decoder': decoder}
230 |             self.model_list.append(model(**params))
231 | 
232 |     def forward(self, x, input_mask=None, softmax=True):
233 |         """
234 |         Parameters:
235 |         -----------
236 |         x : torch.Tensor, (1, 526, len(sequence), len(sequence))
237 |             inputs after trRosettaPreprocessing
238 |         """
239 |         return [mod(x, input_mask=input_mask, softmax=softmax) for mod in self.model_list]
240 | 
241 | 
242 | class trRosettaDist(nn.Module):
243 |     """trRosetta for distance only, does not use pretrained weights"""
244 |     def __init__(self, n2d_layers=61, hdim=128, decoder=True, d_out=1):
245 |         """
246 |         Args:
247 |             n2d_layers: int
248 |                 number of layers of the conv block to use for each base model
249 |             hdim: int
250 |                 input 1d hidden dimension
251 |             decoder: bool
252 |                  if True, return dist; else return layer prior decoder
253 |         """
254 |         super(trRosettaDist, self).__init__()
255 | 
256 |         self.conv0 = nn.Conv2d(hdim * 2, 64, kernel_size=1, stride=1, padding=pad_size(1, 1, 1))
257 |         self.instnorm0 = nn.InstanceNorm2d(64, eps=1e-06, affine=True)
258 | 
259 |         dilation = 1
260 |         layers = []
261 |         for _ in range(n2d_layers):
262 |             layers.append(trRosettaBlock(dilation))
263 |             dilation *= 2
264 |             if dilation > 16:
265 |                 dilation = 1
266 | 
267 |         self.layers = nn.ModuleList(modules=layers)
268 |         self.decoder = decoder
269 | 
270 |         if decoder:
271 |             self.conv_dist = nn.Conv2d(64, d_out, kernel_size=1, stride=1, padding=pad_size(1, 1, 1))
272 | 
273 |     def forward(self, x, ):
274 |         """
275 |         Args:
276 |             x: torch.tensor (N, L, hdim)
277 |         Returns:
278 |             dist: torch.tensor(), (N, L, L)
279 |             x: torch.tensor(), (N, 64, L, L)
280 | 
281 |         """
282 |         n, el, _ = x.shape
283 | 
284 |         # convert to 2d
285 |         left = x.unsqueeze(2).repeat(1, 1, el, 1)
286 |         right = x.unsqueeze(1).repeat(1, el, 1, 1)
287 |         x = torch.cat((left, right), -1)
288 |         x = x.permute(0, 3, 1, 2)
289 | 
290 |         x = F.elu(self.instnorm0(self.conv0(x)))
291 |         # trRosettaBlock.forward takes (x, input_mask, last_elu) and returns one tensor
292 |         for layer in self.layers:
293 |             x = layer(x)
294 | 
295 |         if self.decoder:
296 |             # symmetrize
297 |             ## TODO: Some things need to be symmetrical and others don't
298 |             x = 0.5 * (x + torch.transpose(x, 2, 3))
299 |             dist = self.conv_dist(x).squeeze(1)
300 |             return dist
301 |         else:
302 |             return x
303 | 
304 | # EXAMPLE
305 | # filename = 'example/T1001.a3m' 
306 | # seqs = parse_a3m(filename) # grab seqs
307 | # tokenizer = Tokenizer(PROTEIN_ALPHABET) 
308 | # seqs = [tokenizer.tokenize(i) for i in seqs] # ohe into our order
309 | 
310 | # base_model = trRosetta
311 | # input_token_order = PROTEIN_ALPHABET
312 | # ensemble = trRosettaEnsemble(base_model, n2d_layers=61,model_ids='abcde')
313 | # preprocess = trRosettaPreprocessing(input_token_order=PROTEIN_ALPHABET, wmin=0.8)
314 | # x = preprocess.process(seqs)
315 | # with torch.no_grad():
316 | #     ensemble.eval()
317 | #     outputs = ensemble(x.double())
318 | 
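
Two details of the residual tower above are easy to verify without PyTorch: `pad_size` reduces to `dilation * (k - 1) / 2` for stride 1, so each dilated 3x3 convolution preserves the spatial size, and the dilation cycles through 1, 2, 4, 8, 16 before resetting. The sketch below copies the `pad_size` formula from this file; `dilation_schedule` is a hypothetical helper that mirrors the loop in `trRosetta.__init__`.

```python
def pad_size(d, k, s):
    """Same formula as pad_size in trRosetta.py."""
    return int(((139 * s) - 140 + k + ((k - 1) * (d - 1))) / 2)


def dilation_schedule(n_layers, max_dilation=16):
    """Hypothetical helper reproducing the dilation loop in trRosetta.__init__."""
    dilations, d = [], 1
    for _ in range(n_layers):
        dilations.append(d)
        d *= 2
        if d > max_dilation:
            d = 1
    return dilations


schedule = dilation_schedule(7)
paddings = [pad_size(d, 3, 1) for d in schedule]
print(schedule)
print(paddings)
```

With kernel size 3 and stride 1 the padding equals the dilation, and with kernel size 1 it is zero, which is why every `nn.Conv2d` in the model keeps the `(ell, ell)` map size unchanged.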


--------------------------------------------------------------------------------
/sequence_models/trRosetta_utils.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import torch.nn.functional as F
  3 | import numpy as np
  4 | import os
  5 | import tarfile
  6 | import string
  7 | 
  8 | from sequence_models.constants import WEIGHTS_DIR, trR_ALPHABET, DIST_BINS, PHI_BINS, OMEGA_BINS, THETA_BINS
  9 | 
 10 | 
 11 | def probs2value(array, property, mask2d):
 12 |     # input shape: batch, n_bins, ell, ell
 13 |     # output shape: batch, ell, ell
 14 |     if property == 'dist':
 15 |         bins = DIST_BINS
 16 |     elif property == 'phi':
 17 |         bins = PHI_BINS
 18 |     elif property == 'omega':
 19 |         bins = OMEGA_BINS
 20 |     elif property == 'theta':
 21 |         bins = THETA_BINS
 22 |     if property == 'dist':
 23 |         bins = torch.tensor(np.nan_to_num(bins), device=array.device, dtype=array.dtype)
 24 |         b = (bins[:-1] + bins[1:]) / 2
 25 |         diff = b[-1] - b[-2]
 26 |         b[0] = b[-1] + diff
 27 |     else:
 28 |         b = torch.tensor((bins[1:-1] + bins[2:]) / 2, device=array.device, dtype=array.dtype)
 29 |     b = b.view(1, -1, 1, 1)
 30 |     if property != 'dist':
 31 |         probs = array[:, 1:, :, :]
 32 |         den = torch.sum(probs, dim=1, keepdim=True)
 33 |         j = torch.where(den < 1e-9)
 34 |         den[j] = 1e-9
 35 |         probs = probs / den
 36 |     else:
 37 |         probs = array
 38 |     if property in ['dist', 'phi']:
 39 |         values = b * probs
 40 |         values = values.sum(dim=1)
 41 |     else:
 42 |         s = (torch.sin(b) * probs).sum(dim=1)
 43 |         c = (torch.cos(b) * probs).sum(dim=1)
 44 |         j = torch.where(s.abs() < 1e-9)
 45 |         s[j] = 1e-9
 46 |         j = torch.where(c.abs() < 1e-9)
 47 |         c[j] = 1e-9
 48 |         values = torch.atan2(s, c)
 49 | 
 50 |     values = values.masked_fill(~mask2d.bool().squeeze(), np.nan)
 51 |     ii, jj = np.diag_indices(values.shape[1])
 52 |     for i in range(len(values)):
 53 |         values[i, ii, jj] = values[i, ii, jj] + np.nan
 54 |     return values
 55 | 
 56 | 
 57 | # probably move this into a collate_fn 
 58 | class trRosettaPreprocessing():
 59 |     """Preprocessing a3m files to torch tensors for trRosetta"""
 60 | 
 61 |     def __init__(self, input_token_order, wmin=0.8):
 62 |         """
 63 |         Parameters:
 64 |         -----------
 65 |         input_token_order : str
 66 |             order of your amino acid alphabet
 67 | 
 68 |         wmin : float
 69 |             sequence identity value cutoff
 70 |         """
 71 |         if input_token_order == trR_ALPHABET:
 72 |             self.ohe_dict = None
 73 |         else:
 74 |             self.ohe_dict = self._build_ohe_dict(input_token_order)
 75 |         self.wmin = wmin
 76 |         self.seqlen = 0
 77 | 
 78 |     def _build_ohe_dict(self, input_order):
 79 |         """Convert your alphabet order to the one trRosetta uses
 80 | 
 81 |         Parameters:
 82 |         -----------
 83 |         input_order : str
 84 |             order of your amino acid alphabet
 85 | 
 86 |         Returns:
 87 |         --------
 88 |         ohe_dict : dict
 89 |             map between your alphabet order and trRosetta order
 90 |         """
 91 |         trR_order = trR_ALPHABET
 92 |         ohe_dict = {}
 93 |         for i in input_order:
 94 |             if i in trR_order:
 95 |                 ohe_dict[input_order.index(i)] = trR_order.index(i)
 96 |             else:
 97 |                 ohe_dict[input_order.index(i)] = trR_order.index('-')
 98 |         return ohe_dict
 99 | 
100 |     def _convert_ohe(self, seqs):
101 |         """Remap token indices from your alphabet order to trRosetta's
102 | 
103 |         Uses `self.ohe_dict`, the map built by `_build_ohe_dict`.
104 | 
105 |         Parameters:
106 |         -----------
107 |         seqs : list
108 |             list of tokenized sequences from an MSA; each element must
109 |             support `.item()` (e.g. entries of a torch.Tensor)
110 | 
111 |         Returns:
112 |         --------
113 |         * : torch.LongTensor
114 |             trRosetta token indices, (num_of_seqs, len(seq))
115 |         """
116 | 
117 |         processed_seqs = []
118 |         for seq in seqs:
119 |             processed_seqs.append([self.ohe_dict[i.item()] for i in seq])
120 |         return torch.Tensor(np.array(processed_seqs)).long()
121 | 
122 |     def _reweight_py(self, msa1hot, cutoff, eps=1e-9):
123 |         """Compute per-sequence weights from pairwise sequence identity
124 | 
125 |         Parameters:
126 |         -----------
127 |         msa1hot : torch.Tensor
128 |             one hot encoded MSA seqs
129 | 
130 |         cutoff : float
131 |             sequence identity value cutoff
132 | 
133 |         eps : float
134 |             margin to prevent divide by 0
135 | 
136 |         Returns:
137 |         --------
138 |         * : torch.Tensor
139 |             weights for sequence, (1, num_of_seq)
140 |         """
141 |         self.seqlen = msa1hot.size(2)
142 |         id_min = self.seqlen * cutoff
143 |         id_mtx = torch.stack([torch.tensordot(el, el, [[1, 2], [1, 2]]) for el in msa1hot], 0)
144 |         id_mask = id_mtx > id_min
145 |         weights = 1.0 / (id_mask.type_as(msa1hot).sum(-1) + eps)
146 |         return weights
147 | 
148 |     def _extract_features_1d(self, msa1hot, weights):
149 |         """Get 1d features
150 | 
151 |         Parameters:
152 |         -----------
153 |         msa1hot : torch.Tensor
154 |             one hot encoded MSA seqs
155 | 
156 |         weights : torch.Tensor
157 |             weights for sequences
158 | 
159 |         Returns:
160 |         --------
161 |         f1d : torch.Tensor
162 |             1d features (1, len(seq), 42)
163 |         """
164 |         # 1D Features
165 |         f1d_seq = msa1hot[:, 0, :, :20]
166 |         batch_size = msa1hot.size(0)
167 | 
168 |         # msa2pssm
169 |         beff = weights.sum()
170 |         f_i = (weights[:, :, None, None] * msa1hot).sum(1) / beff + 1e-9
171 |         h_i = (-f_i * f_i.log()).sum(2, keepdims=True)
172 |         f1d_pssm = torch.cat((f_i, h_i), dim=2)
173 |         f1d = torch.cat((f1d_seq, f1d_pssm), dim=2)
174 |         f1d = f1d.view(batch_size, self.seqlen, 42)
175 |         return f1d
176 | 
177 |     def _extract_features_2d(self, msa1hot, weights, penalty=4.5):
178 |         """Get 2d features
179 | 
180 |         Parameters:
181 |         -----------
182 |         msa1hot : torch.Tensor
183 |             one hot encoded MSA seqs
184 | 
185 |         weights : torch.Tensor
186 |             weights for sequences
187 | 
188 |         penalty : float
189 |             penalty for inv. covariance
190 | 
191 |         Returns:
192 |         --------
193 |         f2d_dca : torch.Tensor
194 |             2d features (1, len(seq), len(seq), 442)
195 |         """
196 |         # 2D Features
197 |         batch_size = msa1hot.size(0)
198 |         num_alignments = msa1hot.size(1)
199 |         num_symbols = 21
200 | 
201 |         if num_alignments == 1:
202 |             # No alignments, predict from sequence alone
203 |             f2d_dca = torch.zeros(
204 |                 batch_size, self.seqlen, self.seqlen, 442,
205 |                 dtype=torch.float,
206 |                 device=msa1hot.device)
207 |             return f2d_dca
208 | 
209 |         # compute fast_dca
210 |         # covariance
211 |         x = msa1hot.view(batch_size, num_alignments, self.seqlen * num_symbols)
212 |         num_points = weights.sum(1) - weights.mean(1).sqrt()
213 |         mean = (x * weights.unsqueeze(2)).sum(1, keepdims=True) / num_points[:, None, None]
214 |         x = (x - mean) * weights[:, :, None].sqrt()
215 |         cov = torch.matmul(x.transpose(-1, -2), x) / num_points[:, None, None]
216 | 
217 |         # inverse covariance
218 |         reg = torch.eye(self.seqlen * num_symbols,
219 |                         device=weights.device,
220 |                         dtype=weights.dtype)[None]
221 |         reg = reg * penalty / weights.sum(1, keepdims=True).sqrt().unsqueeze(2)
222 |         cov_reg = cov + reg
223 |         chol = torch.linalg.cholesky(cov_reg.squeeze())
224 |         inv_cov = torch.cholesky_inverse(chol).unsqueeze(0)
225 |         x1 = inv_cov.view(batch_size, self.seqlen, num_symbols, self.seqlen, num_symbols)
226 |         x2 = x1.permute(0, 1, 3, 2, 4)
227 |         features = x2.reshape(batch_size, self.seqlen, self.seqlen, num_symbols * num_symbols)
228 | 
229 |         x3 = (x1[:, :, :-1, :, :-1] ** 2).sum((2, 4)).sqrt() * (
230 |                 1 - torch.eye(self.seqlen, device=weights.device, dtype=weights.dtype)[None])
231 |         apc = x3.sum(1, keepdims=True) * x3.sum(2, keepdims=True) / x3.sum(
232 |             (1, 2), keepdims=True)
233 |         contacts = (x3 - apc) * (1 - torch.eye(
234 |             self.seqlen, device=x3.device, dtype=x3.dtype).unsqueeze(0))
235 | 
236 |         f2d_dca = torch.cat([features, contacts[:, :, :, None]], dim=3)
237 |         return f2d_dca
238 | 
239 |     def process(self, x):
240 |         """Do all preprocessing steps
241 | 
242 |         Parameters:
243 |         -----------
244 |         x : list
245 |             list of sequences from MSA
246 | 
247 |         Returns:
248 |         --------
249 |         features : torch.Tensor, (1, 526, len(seq), len(seq))
250 |             input for trRosetta
251 |         """
252 |         if self.ohe_dict is not None:
253 |             x = self._convert_ohe(x).reshape(len(x), -1)
254 |         x = F.one_hot(x, len(trR_ALPHABET)).unsqueeze(0).float()
255 |         # x = self._one_hot_embedding(x, len(trR_ALPHABET))
256 |         w = self._reweight_py(x, self.wmin)
257 |         f1d = self._extract_features_1d(x, w)
258 |         f2d = self._extract_features_2d(x, w)
259 | 
260 |         left = f1d.unsqueeze(2).repeat(1, 1, self.seqlen, 1)
261 |         right = f1d.unsqueeze(1).repeat(1, self.seqlen, 1, 1)
262 |         features = torch.cat((left, right, f2d), -1)
263 |         features = features.permute(0, 3, 1, 2)
264 |         return features
265 | 
266 |     def __call__(self, x):
267 |         return self.process(x)
268 | 
269 | 
270 | def tf_to_pytorch_weights(model_params, model_id):
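271 |     """Download trRosetta TensorFlow weights if needed and save them in PyTorch format
272 | 
273 |     Parameters:
274 |     -----------
275 |     model_params : iterable of (name, param), e.g. model.named_parameters()
276 |         names and parameters of the PyTorch model; weights are assigned by position
277 | 
278 |     model_id: str
279 |         pretrained model to convert: one of a, b, c, d or e.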
280 | 
281 |     """
282 |     # check to see if previously downloaded weights, if not -> download
283 |     if not os.path.exists(WEIGHTS_DIR):
284 |         os.mkdir(WEIGHTS_DIR)
285 |     tr_src_dir = WEIGHTS_DIR + 'trrosetta_tf_weights/'
286 |     if not os.path.exists(tr_src_dir):
287 |         os.mkdir(tr_src_dir)
288 |     zip_fpath = tr_src_dir + 'model_weights.tar.bz2'
289 |     tf_fpath = tr_src_dir + 'model2019_07/'
290 |     if len(os.listdir(tr_src_dir)) == 0:
291 |         print('grabbing weights from source...')
292 |         import wget
293 |         wget.download('https://files.ipd.uw.edu/pub/trRosetta/model2019_07.tar.bz2', out=zip_fpath)
294 |         model_file = tarfile.open(zip_fpath, mode='r:bz2')
295 |         model_file.extractall(tr_src_dir)
296 |         model_file.close()
297 | 
298 |     # check to see if converted to pytorch weights yet
299 |     tr_tgt_dir = WEIGHTS_DIR + 'trrosetta_pytorch_weights/'
300 |     if not os.path.exists(tr_tgt_dir):
301 |         os.mkdir(tr_tgt_dir)
302 | 
303 |     model_path = tr_tgt_dir + model_id + '.pt'
304 | 
305 |     if not os.path.exists(model_path):
306 |         print('converting model %s weights from tensorflow to pytorch...' % model_id)
307 | 
308 |         ckpt = tf_fpath + 'model.xa' + model_id
309 |         import tensorflow as tf
310 |         w_vars = tf.train.list_variables(ckpt)  # get weight names
311 | 
312 |         # filter weights
313 |         w_vars_fil = [i for i in w_vars if 'Adam' not in i[0]]
314 |         instnorm_beta_vars = [i[0] for i in w_vars_fil if 'InstanceNorm' in i[0] and 'beta' in i[0]]
315 |         instnorm_gamma_vars = [i[0] for i in w_vars_fil if 'InstanceNorm' in i[0] and 'gamma' in i[0]]
316 |         conv_kernel_vars = [i[0] for i in w_vars_fil if 'conv2d' in i[0] and 'kernel' in i[0]]
317 |         conv_bias_vars = [i[0] for i in w_vars_fil if 'conv2d' in i[0] and 'bias' in i[0]]
318 | 
319 |         # order weights
320 |         w_vars_ord = [conv_kernel_vars[0], conv_bias_vars[0], instnorm_gamma_vars[0], instnorm_beta_vars[0]]
321 |         for i in range(len(conv_kernel_vars)):
322 |             if 'conv2d_' + str(i) + '/kernel' in conv_kernel_vars:
323 |                 w_vars_ord.append('conv2d_' + str(i) + '/kernel')
324 |             if 'conv2d_' + str(i) + '/bias' in conv_bias_vars:
325 |                 w_vars_ord.append('conv2d_' + str(i) + '/bias')
326 |             if 'InstanceNorm_' + str(i) + '/gamma' in instnorm_gamma_vars:
327 |                 w_vars_ord.append('InstanceNorm_' + str(i) + '/gamma')
328 |             if 'InstanceNorm_' + str(i) + '/beta' in instnorm_beta_vars:
329 |                 w_vars_ord.append('InstanceNorm_' + str(i) + '/beta')
330 | 
331 |         #         tf_weight_dict = {name:tf.train.load_variable(ckpt, name) for name in w_vars_ord}
332 |         weights_list = [tf.train.load_variable(ckpt, name) for name in w_vars_ord]
333 | 
334 |         # convert into pytorch format
335 |         torch_weight_dict = {}
336 |         weights_idx = 0
337 |         for name, param in model_params:
338 |             if len(weights_list[weights_idx].shape) == 4:
339 |                 torch_weight_dict[name] = torch.from_numpy(weights_list[weights_idx]).to(torch.float64).permute(3, 2, 0,
340 |                                                                                                                 1)
341 |             else:
342 |                 torch_weight_dict[name] = torch.from_numpy(weights_list[weights_idx]).to(torch.float64)
343 |             weights_idx += 1
344 | 
345 |         torch.save(torch_weight_dict, model_path)
346 | 
347 | 
348 | def parse_a3m(filename):
349 |     """Load a3m file to list of sequences
350 | 
351 |     Parameters:
352 |     -----------
353 |     filename : str
354 |         path to a3m file
355 | 
356 |     Returns:
357 |     --------
358 |     seqs : list
359 |         list of seqs in MSA
360 | 
361 |     """
362 |     seqs = []
363 |     table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
364 | 
365 |     # read file line by line, skipping '>' label lines
366 |     with open(filename, "r") as f:
367 |         for line in f:
368 |             if line[0] != '>':
369 |                 # remove lowercase letters and right whitespaces
370 |                 seqs.append(line.rstrip().translate(table))
371 |     return seqs
372 | 
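
The sequence reweighting done by `trRosettaPreprocessing._reweight_py` can be illustrated on a toy alignment. The following is a minimal sketch, not the library's code: `sequence_weights` and the toy `msa` are hypothetical names introduced here, and the function assumes an unbatched `(num_seqs, seq_len)` tensor of token indices rather than the batched one-hot input used above. Each sequence is weighted by the reciprocal of the number of sequences (itself included) that match it at more than `cutoff * seq_len` positions.

```python
import torch
import torch.nn.functional as F


def sequence_weights(msa_tokens, n_symbols=21, cutoff=0.8, eps=1e-9):
    """Hypothetical helper: identity-based sequence weights for one MSA.

    msa_tokens: (num_seqs, seq_len) LongTensor of token indices.
    Returns a (num_seqs,) float tensor.
    """
    seq_len = msa_tokens.size(1)
    onehot = F.one_hot(msa_tokens, n_symbols).float()  # (num_seqs, seq_len, n_symbols)
    # Entry (i, j) counts the positions at which sequences i and j agree.
    identity = torch.tensordot(onehot, onehot, dims=([1, 2], [1, 2]))
    # Number of sequences whose identity exceeds the cutoff (strict inequality).
    n_neighbors = (identity > seq_len * cutoff).float().sum(-1)
    return 1.0 / (n_neighbors + eps)


# Toy alignment: the first two rows are identical, so they share their weight.
msa = torch.tensor([
    [0, 1, 2, 3, 4],
    [0, 1, 2, 3, 4],
    [0, 1, 2, 3, 5],
    [5, 4, 3, 2, 1],
])
weights = sequence_weights(msa)
```

Because the comparison is strict, the third row (4 of 5 matching positions, exactly at the 0.8 cutoff) does not count as a neighbor of the first two.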


--------------------------------------------------------------------------------
/sequence_models/utils.py:
--------------------------------------------------------------------------------
  1 | from typing import Iterable
  2 | 
  3 | import numpy as np
  4 | import pandas as pd
  5 | from scipy.spatial.distance import squareform, pdist
  6 | 
  7 | from sequence_models.constants import STOP, START, MASK, PAD
  8 | from sequence_models.constants import PROTEIN_ALPHABET
  9 | 
 10 | 
 11 | def warmup(n_warmup_steps):
 12 |     def get_lr(step):
 13 |         return min((step + 1) / n_warmup_steps, 1.0)
 14 |     return get_lr
 15 | 
 16 | 
 17 | def transformer_lr(n_warmup_steps):
 18 |     factor = n_warmup_steps ** 0.5
 19 |     def get_lr(step):
 20 |         step += 1
 21 |         return min(step ** (-0.5), step * n_warmup_steps ** (-1.5)) * factor
 22 |     return get_lr
 23 | 
 24 | 
 25 | def get_metrics(fname, new=False, tokens=False):
 26 |     with open(fname) as f:
 27 |         lines = f.readlines()
 28 |     valid_lines = []
 29 |     train_lines = []
 30 |     all_train_lines = []
 31 |     for i, line in enumerate(lines):
 32 |         if 'Training' in line and 'loss' in line:
 33 |             last_train = line
 34 |             all_train_lines.append(line)
 35 |         if 'Validation complete' in line:
 36 |             valid_lines.append(lines[i - 1])
 37 |             train_lines.append(last_train)
 38 |     metrics = []
 39 |     idx_loss = 13
 40 |     idx_accu = 16
 41 |     idx_step = 6
 42 |     if new:
 43 |         idx_loss += 2
 44 |         idx_accu += 2
 45 |         idx_step += 2
 46 |     if tokens:
 47 |         idx_loss += 2
 48 |         idx_accu += 2
 49 |         idx_tok = 10
 50 |     tok_correction = 0
 51 |     last_raw_toks = 0
 52 |     for t, v in zip(train_lines, valid_lines):
 53 |         step = int(t.split()[idx_step])
 54 |         t_loss = float(t.split()[idx_loss])
 55 |         t_accu = float(t.split()[idx_accu][:6])
 56 |         v_loss = float(v.split()[idx_loss])
 57 |         v_accu = float(v.split()[idx_accu][:6])
 58 |         if tokens:
 59 |             toks = int(t.split()[idx_tok])
 60 |             if toks < last_raw_toks:
 61 |                 tok_correction += last_raw_toks
 62 |                 doubled = int(all_train_lines[-1].split()[idx_tok]) - int(all_train_lines[-999].split()[idx_tok])
 63 |                 tok_correction -= doubled
 64 |             last_raw_toks = toks
 65 |             metrics.append((step, toks + tok_correction, t_loss, t_accu, v_loss, v_accu))
 66 | 
 67 |         else:
 68 |             metrics.append((step, t_loss, t_accu, v_loss, v_accu))
 69 |     if tokens:
 70 |         metrics = pd.DataFrame(metrics, columns=['step', 'tokens', 'train_loss',
 71 |                                                  'train_accu', 'valid_loss', 'valid_accu'])
 72 |     else:
 73 |         metrics = pd.DataFrame(metrics, columns=['step', 'train_loss', 'train_accu', 'valid_loss', 'valid_accu'])
 74 |     return metrics
 75 | 
 76 | 
 77 | def get_weights(seqs):
 78 |     scale = 1.0
 79 |     theta = 0.2
 80 |     seqs = np.array([[PROTEIN_ALPHABET.index(a) for a in s] for s in seqs])
 81 |     weights = scale / (np.sum(squareform(pdist(seqs, metric="hamming")) < theta, axis=0))
 82 |     return weights
 83 | 
 84 | 
 85 | def parse_fasta(fasta_fpath, return_names=False):
 86 |     """ Read in a fasta file and return its sequences (and, optionally, the record names)."""
 87 |     seqs = []
 88 |     with open(fasta_fpath) as f_in:
 89 |         current = ''
 90 |         names = [f_in.readline()[1:].replace('\n', '')]
 91 |         for line in f_in:
 92 |             if line[0] == '>':
 93 |                 seqs.append(current)
 94 |                 current = ''
 95 |                 names.append(line[1:].replace('\n', ''))
 96 |             else:
 97 |                 current += line.replace('\n', '')
 98 |         seqs.append(current)
 99 |     if return_names:
100 |         return seqs, names
101 |     else:
102 |         return seqs
103 | 
104 | 
105 | def read_fasta(fasta_fpath, out_fpath, header='sequence'):
106 |     """ Read in a fasta file and write just the sequences, one per line below a header line, to out_fpath."""
107 |     with open(fasta_fpath) as f_in, open(out_fpath, 'w') as f_out:
108 |         f_out.write(header + '\n')
109 |         current = ''
110 |         _ = f_in.readline()
111 |         for line in f_in:
112 |             if line[0] == '>':
113 |                 f_out.write(current + '\n')
114 |                 current = ''
115 |             else:
116 |                 current += line.rstrip('\n')
117 |         f_out.write(current + '\n')
118 | 
119 | 
120 | class Tokenizer(object):
121 |     """Convert between strings and their integer token representations."""
122 |     def __init__(self, alphabet: str):
123 |         self.alphabet = alphabet
124 |         self.a_to_t = {a:i for i, a in enumerate(self.alphabet)}
125 |         self.t_to_a = {i:a for i, a in enumerate(self.alphabet)}
126 | 
127 |     @property
128 |     def vocab_size(self) -> int:
129 |         return len(self.alphabet)
130 | 
131 |     @property
132 |     def start_id(self) -> int:
133 |         return self.alphabet.index(START)
134 | 
135 |     @property
136 |     def stop_id(self) -> int:
137 |         return self.alphabet.index(STOP)
138 | 
139 |     @property
140 |     def mask_id(self) -> int:
141 |         return self.alphabet.index(MASK)
142 | 
143 |     @property
144 |     def pad_id(self) -> int:
145 |         return self.alphabet.index(PAD)
146 | 
147 |     def tokenize(self, seq: str) -> np.ndarray:
148 |         return np.array([self.a_to_t[a] for a in seq])
149 | 
150 |     def untokenize(self, x: Iterable) -> str:
151 |         return ''.join([self.t_to_a[t] for t in x])
152 | 
153 | 
154 | 
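
`warmup` and `transformer_lr` above each return a function that maps a step index to a learning-rate multiplier. The source does not show how they are consumed; one plausible use, assumed here, is as the `lr_lambda` argument of `torch.optim.lr_scheduler.LambdaLR`. The sketch below copies the two functions so that it is self-contained; the dummy parameter, the optimizer settings, and the variable names are illustrative only.

```python
import torch


def warmup(n_warmup_steps):
    # Copied from the listing above: linear ramp to 1.0, then constant.
    def get_lr(step):
        return min((step + 1) / n_warmup_steps, 1.0)
    return get_lr


def transformer_lr(n_warmup_steps):
    # Copied from the listing above: linear ramp, then inverse-square-root decay.
    factor = n_warmup_steps ** 0.5
    def get_lr(step):
        step += 1
        return min(step ** (-0.5), step * n_warmup_steps ** (-1.5)) * factor
    return get_lr


# Multipliers for a 4-step warmup.
linear_factors = [warmup(4)(s) for s in range(6)]
inv_sqrt_factors = [transformer_lr(4)(s) for s in (0, 3, 15)]

# Assumed usage: scale a base learning rate of 1e-3 once per optimizer step.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-3)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=transformer_lr(4))
lrs = []
for _ in range(5):
    optimizer.step()   # no gradients here, so the parameter is left unchanged
    scheduler.step()
    lrs.append(optimizer.param_groups[0]['lr'])
```

Both schedules reach a multiplier of 1.0 at the end of warmup; `transformer_lr` then decays in proportion to the inverse square root of the step, while `warmup` stays at 1.0.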


--------------------------------------------------------------------------------
/sequence_models/vae.py:
--------------------------------------------------------------------------------
  1 | from typing import List
  2 | 
  3 | import torch.nn as nn
  4 | import torch
  5 | import torch.optim as optim
  6 | from apex import amp
  7 | import mlflow
  8 | from torch import nn as nn
  9 | 
 10 | from sequence_models.losses import VAELoss
 11 | from sequence_models.layers import FCStack
 12 | from sequence_models.metrics import UngappedAccuracy
 13 | 
 14 | 
 15 | class VAETrainer(object):
 16 |     """ Trainer for VAEs."""
 17 |     def __init__(self, vae, device, pad_idx, class_weights=None, lr=1e-4, beta=1.0, opt_level='O2', optim_kwargs={},
 18 |                  early_stopping=True, patience=10, improve_threshold=0.001, save_freq=100, scheduler=None,
 19 |                  scheduler_args=[], scheduler_kwargs={}, scheduler_time='epoch', kl_anneal=-1):
 20 |         self.vae = vae.to(device)
 21 |         self.device = device
 22 |         self.beta = beta
 23 |         self.anneal_epochs = kl_anneal
 24 |         # Store an optimizer
 25 |         self.optimizer = optim.Adam(vae.parameters(), lr=lr, **optim_kwargs)
 26 |         if opt_level != 'O0':
 27 |             self.vae, self.optimizer = amp.initialize(self.vae, self.optimizer, opt_level=opt_level)
 28 |         self.opt_level = opt_level
 29 |         # Store the loss
 30 |         self.loss_func = VAELoss(class_weights=class_weights)
 31 |         self.accu_func = UngappedAccuracy(pad_idx)
 32 |         if scheduler is None:
 33 |             self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 1, gamma=1.0)
 34 |             self.scheduler_time = 'epoch'
 35 |         else:
 36 |             self.scheduler = scheduler(self.optimizer, *scheduler_args, **scheduler_kwargs)
 37 |             self.scheduler_time = scheduler_time
 38 |         self.early_stopping = early_stopping
 39 |         self.patience = patience
 40 |         self.improve_threshold = improve_threshold
 41 |         self.save_freq = save_freq
 42 |         self.current_epoch = 0
 43 | 
 44 |     def step(self, src, tgt, train=True, weights=None):
 45 |         """Do a forward pass. Do a backward pass if train=True. """
 46 |         if train:
 47 |             self.vae = self.vae.train()
 48 |             self.optimizer.zero_grad()
 49 |         else:
 50 |             self.vae = self.vae.eval()
 51 |         loss, r_loss, kl_loss, accu = self._forward(src, tgt, weights=weights)
 52 |         if train:
 53 |             self._backward(loss)
 54 |         return loss.item(), r_loss.item(), kl_loss.item(), accu.item()
 55 | 
 56 |     def _forward(self, src, tgt, weights=None):
 57 |         src = src.to(self.device)
 58 |         tgt = tgt.to(self.device)
 59 |         p, z_mu, z_log_var = self.vae(src)
 60 |         if self.anneal_epochs == -1:
 61 |             beta = self.beta
 62 |         else:
 63 |             beta = self.beta * min(self.current_epoch / self.anneal_epochs, 1.0)
 64 |         loss, r_loss, kl_loss = self.loss_func(p, tgt, z_mu, z_log_var, beta=beta, sample_weights=weights)
 65 |         accu = self.accu_func(p, tgt)
 66 |         return loss, r_loss, kl_loss, accu
 67 | 
 68 |     def _backward(self, loss):
 69 |         if self.opt_level != 'O0':
 70 |             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
 71 |                 scaled_loss.backward()
 72 |         else:
 73 |             loss.backward()
 74 |         self.optimizer.step()
 75 |         if self.scheduler_time == 'batch':
 76 |             self.scheduler.step()
 77 | 
 78 |     def epoch(self, loader, train):
 79 |         losses = 0.0
 80 |         r_losses = 0.0
 81 |         kl_losses = 0.0
 82 |         accus = 0.0
 83 |         for i, batch in enumerate(loader):
 84 |             src = batch[0]
 85 |             if isinstance(self.vae, RecurrentVAE):
 86 |                 tgt = batch[1]
 87 |                 if len(batch) == 3:
 88 |                     weights = batch[2]
 89 |                 else:
 90 |                     weights = None
 91 |             else:
 92 |                 tgt = batch[0]
 93 |                 if len(batch) == 2:
 94 |                     weights = batch[1]
 95 |                 else:
 96 |                     weights = None
 97 |             loss, r_loss, kl_loss, accu = self.step(src, tgt, train=train, weights=weights)
 98 |             losses += loss
 99 |             r_losses += r_loss
100 |             kl_losses += kl_loss
101 |             accus += accu
102 |             mean_loss = losses / (i + 1)
103 |             mean_r = r_losses / (i + 1)
104 |             mean_kl = kl_losses / (i + 1)
105 |             mean_accu = accus / (i + 1)
106 |             if train:
107 |                 print('\rTraining ', end='')
108 |             else:
109 |                 print('\rValidating ', end='')
110 |             print(
111 |                 'Epoch %d of %d Batch %d of %d loss = %.4f r = %.4f kld = %.4f accu = %.4f'
112 |                 % (
113 |                     self.current_epoch + 1,
114 |                     self.total_epochs,
115 |                     i + 1,
116 |                     len(loader),
117 |                     mean_loss,
118 |                     mean_r,
119 |                     mean_kl,
120 |                     mean_accu
121 |                 ),
122 |                   end=''
123 |             )
124 |         print()
125 |         return mean_loss, mean_r, mean_kl, mean_accu
126 | 
127 |     def train(self, train_loader, epochs, valid_loader=None, save_path=None):
128 |         done = False
129 |         stagnant = 0
130 |         best_loss = 1e8
131 |         self.total_epochs = epochs
132 |         for epoch in range(epochs):
133 |             self.current_epoch = epoch
134 |             if epoch > 0 and (epoch % self.save_freq == 0) and save_path is not None:
135 |                 torch.save(self.vae.state_dict(), save_path + 'autosave_epoch_{}.pkl'.format(epoch))
136 |                 torch.save(self.optimizer.state_dict(), save_path + 'optim_autosave_epoch_{}.pkl'.format(epoch))
137 |             if not done:
138 |                 loss, r_loss, kld, accu = self.epoch(train_loader, True)
139 |                 mlflow.log_metrics(
140 |                     {
141 |                         'train_loss': loss,
142 |                         'train_r_loss': r_loss,
143 |                         'train_kld': kld,
144 |                         'train_accu': accu
145 |                     },
146 |                     step=self.current_epoch
147 |                 )
148 | 
149 |                 if valid_loader is not None:
150 |                     with torch.no_grad():
151 |                         loss, r_loss, kld, accu = self.epoch(valid_loader, False)
152 |                     if self.scheduler_time == 'epoch':
153 |                         self.scheduler.step(loss)
154 |                     mlflow.log_metrics(
155 |                         {
156 |                             'valid_loss': loss,
157 |                             'valid_r_loss': r_loss,
158 |                             'valid_kld': kld,
159 |                             'valid_accu': accu
160 |                         },
161 |                         step=self.current_epoch
162 |                     )
163 |                     if self.early_stopping and self.current_epoch > self.anneal_epochs:
164 |                         improve = loss <= (1 - self.improve_threshold) * best_loss
165 |                         if not improve:
166 |                             stagnant += 1
167 |                         else:
168 |                             stagnant = 0
169 |                             best_loss = loss
170 |                         done = stagnant >= self.patience
171 |             else:
172 |                 print('Stopping early at epoch {}'.format(self.current_epoch))
173 |                 break
174 |         return self.vae, self.loss_func, self.optimizer
175 | 
176 | 
177 | class VAE(nn.Module):
178 |     """A Variational Autoencoder.
179 | 
180 |     Args:
181 |         encoder (nn.Module): Should produce outputs mu and log_var, both with dimensions (N, d_z)
182 |         decoder (nn.Module): Takes inputs (N, d_z) and attempts to reconstruct the original input
183 | 
184 |     Inputs:
185 |         x (N, *)
186 | 
187 |     Outputs:
188 |         reconstructed (N, *)
189 |         mu (N, d_z)
190 |         log_var (N, d_z)
191 |     """
192 |     def __init__(self, encoder: nn.Module, decoder: nn.Module):
193 |         super(VAE, self).__init__()
194 |         self.encoder = encoder
195 |         self.decoder = decoder
196 |         if self.encoder.d_z != self.decoder.d_z:
197 |             raise ValueError('d_zs do not match!')
198 |         self.d_z = encoder.d_z
199 | 
200 |     def encode(self, x: torch.Tensor):
201 |         return self.encoder(x)
202 | 
203 |     def decode(self, z: torch.Tensor):
204 |         return self.decoder(z)
205 | 
206 |     def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor):
207 |         std = torch.exp(0.5 * log_var)
208 |         eps = torch.randn_like(std)
209 |         return mu + eps * std
210 | 
211 |     def forward(self, x: torch.Tensor):
212 |         mu, log_var = self.encode(x)
213 |         z = self.reparameterize(mu, log_var)
214 |         return self.decode(z), mu, log_var
215 | 
216 | 
217 | class RecurrentVAE(VAE):
218 | 
219 |     def forward(self, src):
220 |         mu, log_var = self.encode(src)
221 |         z = self.reparameterize(mu, log_var)
222 |         return self.decode(z, src), mu, log_var
223 | 
224 |     def decode(self, z, src):
225 |         return self.decoder(z, src)
226 | 
227 | 
228 | class FCEncoder(nn.Module):
229 |     """ A simple fully-connected encoder for sequences.
230 | 
231 |     Args:
232 |         L (int): Sequence length
233 |         d_in (int): Number of tokens
234 |         d_h (list of ints): the hidden dimensions
235 |         d_z (int): The size of the latent space
236 |         padding_idx (int): Optional: idx for padding to pass to Embedding layer
237 | 
238 |     Input:
239 |         X (N, L): should be torch.LongTensor
240 | 
241 |     Outputs:
242 |         mu (N, d_z)
243 |         log_var (N, d_z)
244 |     """
245 | 
246 |     def __init__(self, L: int, d_in: int, d_h: List[int], d_z: int, padding_idx=None, p=0., norm='bn'):
247 |         super(FCEncoder, self).__init__()
248 |         self.L = L
249 |         self.d_in = d_in
250 |         self.d_z = d_z
251 |         self.embedder = nn.Embedding(d_in, d_h[0], padding_idx=padding_idx)
252 |         sizes = [L * d_h[0]] + d_h[1:]
253 |         self.layers = FCStack(sizes, p=p, norm=norm)
254 |         d1 = sizes[-1]
255 |         self.u_layer = nn.Linear(d1, d_z)  # Calculates the means
256 |         self.s_layer = nn.Linear(d1, d_z)  # Calculates the log variances
257 | 
258 |     def forward(self, X):
259 |         n, _ = X.size()
260 |         e = self.embedder(X).view(n, -1)
261 |         h = self.layers(e)
262 |         return self.u_layer(h), self.s_layer(h)
263 | 
264 | 
265 | class FCDecoder(nn.Module):
266 |     """ A simple fully-connected decoder for sequences.
267 | 
268 |     Args:
269 |         L (int): Sequence length
270 |         d_in (int): Number of tokens
271 |         d_h (list of ints): the hidden dimensions
272 |         d_z (int): The size of the latent space
273 | 
274 |     Input:
275 |         Z (N, d_z)
276 | 
277 |     Outputs:
278 |         X (N, L, d_in)
279 |     """
280 | 
281 |     def __init__(self, L: int, d_in: int, d_h: List[int], d_z: int, p=0., norm='bn'):
282 |         super(FCDecoder, self).__init__()
283 |         self.L = L
284 |         self.d_in = d_in
285 |         self.d_z = d_z
286 |         sizes = [d_z] + d_h + [self.L * self.d_in]
287 |         self.layers = FCStack(sizes, p=p, norm=norm)
288 | 
289 |     def forward(self, z):
290 |         n, _ = z.shape
291 |         x = self.layers(z)
292 |         return x.view(n, self.L, self.d_in)
293 | 
294 | 
295 | class HierarchicalRecurrentDecoder(nn.Module):
296 |     """ A hierarchical recurrent decoder.
297 | 
298 |     Args:
299 |         conductor (nn.Module): outputs conditioning for each subsequence;
300 |             must expose a `d_z` attribute (the size of the latent space)
301 |         decoder (nn.Module): recurrent decoder; called with the tuple (x, c),
302 |             where c is the conductor output
303 | 
304 |     Input:
305 |         z (N, d_z)
306 |         x: source sequences, passed to the recurrent decoder unchanged
307 | 
308 |     Outputs:
309 |         X (N, L, d_in)
310 |     """
311 |     def __init__(self, conductor, decoder):
312 |         super().__init__()
313 |         self.conductor = conductor
314 |         self.decoder = decoder
315 |         self.d_z = self.conductor.d_z
316 | 
317 |     def forward(self, z, x):
318 |         c = self.conductor(z)
319 |         return self.decoder((x, c))
320 | 
321 | 
322 | class Conductor(nn.Module):
323 |     """Basically a 1D DCGAN generator."""
324 | 
325 |     def __init__(self, d_z, n_features: List[int], d_out):
326 |         super().__init__()
327 |         self.d_z = d_z
328 |         n_features = [d_z] + n_features
329 |         layers = []
330 |         for nf0, nf1 in zip(n_features[:-1], n_features[1:]):
331 |             if len(layers) == 0:
332 |                 layers.append(nn.ConvTranspose1d(nf0, nf1, 4, stride=1, bias=False))
333 |             else:
334 |                 layers.append(nn.ConvTranspose1d(nf0, nf1, 4, stride=2, padding=1, bias=False))
335 |             layers.append(nn.BatchNorm1d(nf1))
336 |             layers.append(nn.ReLU())
337 |         layers += [
338 |             nn.ConvTranspose1d(n_features[-1], d_out, 4, stride=2, padding=1, bias=False),
339 |             nn.Tanh()
340 |         ]
341 |         self.layers = nn.Sequential(*layers)
342 | 
343 |     def forward(self, x):
344 |         if len(x.shape) == 2:
345 |             x = x.unsqueeze(-1)
346 |         return self.layers(x).transpose(1, 2)
347 | 
348 | 
349 | # class ConvEncoder(nn.Module):
350 | #
351 | #     def __init__(self, n_tokens, d_z, n_features: List[int]):
352 | #         super().__init__()
353 | #         self.embedding = nn.Embedding(n_tokens, n_features[0])
354 | #         n_features = n_features[1:]
355 | #         layers = [(nn.Conv1d(nf0, nf1, 4, stride=1, bias=False),
356 | #                    nn.BatchNorm1d(nf1),
357 | #                    nn.ReLU())
358 | #                   for nf0, nf1 in zip(n_features[:-1], n_features[1:])]
359 | #         layers = [item for sublist in layers for item in sublist]
360 | #         layers += [
361 | #             nn.ConvTranspose1d(n_features[-1], d_z * 2, 4, stride=1, bias=False),
362 | #         ]
363 | #         self.layers = nn.Sequential(*layers)
364 | #
365 | #     def forward(self, x):
366 | #         e = self.embedding(x).transpose(1, 2)
367 | #         z = self.layers(e).transpose(1, 2)
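
The `Conductor` above turns a latent of shape `(N, d_z)` into a sequence of conditioning vectors by stacking `nn.ConvTranspose1d` layers. A minimal sketch of the resulting sequence length follows, using the standard transposed-convolution length formula with `dilation=1` and `output_padding=0`. The helper names `conv_transpose1d_out_len` and `conductor_out_len` are hypothetical and are not part of `sequence_models`.

```python
def conv_transpose1d_out_len(l_in: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Output length of a 1D transposed convolution (dilation=1, output_padding=0)."""
    return (l_in - 1) * stride - 2 * padding + kernel


def conductor_out_len(n_hidden_layers: int) -> int:
    """Sequence length produced by Conductor for a 2D latent (N, d_z).

    n_hidden_layers stands for len(n_features) as passed to Conductor.__init__.
    """
    length = 1  # forward() unsqueezes z to (N, d_z, 1)
    for i in range(n_hidden_layers):
        if i == 0:
            # First layer: kernel 4, stride 1, no padding.
            length = conv_transpose1d_out_len(length, kernel=4, stride=1)
        else:
            # Later hidden layers: kernel 4, stride 2, padding 1 (doubles the length).
            length = conv_transpose1d_out_len(length, kernel=4, stride=2, padding=1)
    # Final projection to d_out: kernel 4, stride 2, padding 1.
    return conv_transpose1d_out_len(length, kernel=4, stride=2, padding=1)


if __name__ == "__main__":
    for n in (1, 2, 3):
        print(n, conductor_out_len(n))
```

Under these assumptions, each extra entry in `n_features` doubles the number of conditioning vectors, so the length is a power of two and should match the number of subsequences the decoder expects.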


--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
 1 | import setuptools
 2 | 
 3 | with open("README.md", "r") as fh:
 4 |     long_description = fh.read()
 5 | 
 6 | setuptools.setup(
 7 |     name="sequence-models",
 8 |     version="1.8.0",
 9 |     author="Kevin Yang",
10 |     author_email="yang.kevin@microsoft.com",
11 |     description="Machine learning for sequences.",
12 |     long_description=long_description,
13 |     long_description_content_type="text/markdown",
14 |     url="https://github.com/microsoft/protein-sequence-models",
15 |     packages=setuptools.find_packages(),
16 |     classifiers=[
17 |         "Programming Language :: Python :: 3",
18 |         "Operating System :: OS Independent",
19 |     ],
20 |     include_package_data=True,
21 |     python_requires='>=3.6',
22 | )


--------------------------------------------------------------------------------
/tests/conv_test.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import torch
 3 | from sequence_models.convolutional import ByteNetBlock, ByteNet, ConditionedByteNetDecoder, \
 4 |     HierarchicalCausalConv1d, MaskedCausalConv1d
 5 | 
 6 | if torch.cuda.is_available():
 7 |     device = torch.device('cuda')
 8 | else:
 9 |     device = torch.device('cpu')
10 | b = 5
11 | 
12 | d = 8
13 | ell = 7
14 | k = 3
15 | dil = 2
16 | block = ByteNetBlock(2 * d, d, 2 * d, k, dilation=dil, causal=False).to(device)
17 | x = torch.randn(b, ell, 2 * d).to(device)
18 | mask = torch.ones(b, ell, 1, device=device)
19 | mask[:, 4:] = 0.0
20 | out = block(x, input_mask=mask)
21 | assert out.shape == (b, ell, 2 * d)
22 | 
23 | n_tokens = 9
24 | d_e = 2
25 | n_layers = 4
26 | r = 4
27 | ells = [2, 2, 3]
28 | d_c = 3
29 | x = torch.randint(0, n_tokens, (b, ell), device=device)
30 | c = torch.randn(b, 3, d_c, device=device)
31 | net = ConditionedByteNetDecoder(n_tokens, d_e, d_c, d, n_layers, 3, r, ells).to(device)
32 | out = net((x, c))
33 | assert out.shape == (b, sum(ells), d)
34 | 
35 | net = ByteNet(n_tokens, d_e, d, n_layers, 3, r).to(device)
36 | out = net(x)
37 | assert out.shape == (b, ell, d)
38 | 
39 | # Test Causal convolution
40 | din = 8
41 | dout = 9
42 | k = 3
43 | n = 5
44 | ells = [3, 3, 4, 4, 10, 5, 10]
45 | dil = 4
46 | x = torch.randn(n, sum(ells), din).to(device)
47 | layer = MaskedCausalConv1d(din, dout, k, dilation=dil).to(device)
48 | x.requires_grad = True
49 | out = layer(x)
50 | pos = 20
51 | loss = out[0, pos, :].sum()
52 | grad = torch.autograd.grad(loss, x, retain_graph=True)[0][0].sum(dim=1)
53 | start = pos - (k - 1) * dil
54 | for p in range(start, sum(ells)):
55 |     if p > pos:
56 |         assert grad[p].sum() == 0
57 |     elif (p - start) % dil == 0:
58 |         assert grad[p].sum() != 0
59 | 
60 | # Test gradients in hierarchical causal convolution
61 | dil = 4
62 | layer = HierarchicalCausalConv1d(din, dout, ells, k, dilation=dil).to(device)
63 | x = torch.randn(n, sum(ells), din).to(device)
64 | x.requires_grad = True
65 | out = layer(x)
66 | start = pos - (k - 1) * dil
67 | loss = out[0, pos, :].sum()
68 | grad = torch.autograd.grad(loss, x, retain_graph=True)[0][0].sum(dim=1)
69 | blocks = np.zeros(sum(ells))
70 | for end in np.cumsum(ells):
71 |     blocks[end:] += 1
72 | for p in range(start, sum(ells)):
73 |     if p > pos:
74 |         assert grad[p].sum() == 0
75 |     elif (p - start) % dil == 0:
76 |         if blocks[p] == blocks[pos]:
77 |             assert grad[p].sum() != 0
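
The gradient checks in `conv_test.py` rely on which input positions a single causal, dilated convolution can read from. The sketch below enumerates those positions in plain Python. It assumes that the hierarchical variant additionally drops positions outside the subsequence containing `pos`; the test itself only asserts non-zero gradients inside that subsequence, so this is an illustration rather than the library's implementation. The names `causal_taps`, `block_ids`, and `hierarchical_taps` are hypothetical.

```python
from typing import List


def causal_taps(pos: int, k: int, dil: int) -> List[int]:
    """Input positions read by one causal conv (kernel k, dilation dil) at output pos."""
    taps = [pos - j * dil for j in range(k)]
    return sorted(p for p in taps if p >= 0)


def block_ids(ells: List[int]) -> List[int]:
    """Subsequence index of every position, given the subsequence lengths."""
    ids: List[int] = []
    for block, ell in enumerate(ells):
        ids.extend([block] * ell)
    return ids


def hierarchical_taps(pos: int, k: int, dil: int, ells: List[int]) -> List[int]:
    """Causal taps restricted to the subsequence containing pos (assumed behaviour)."""
    ids = block_ids(ells)
    return [p for p in causal_taps(pos, k, dil) if ids[p] == ids[pos]]


if __name__ == "__main__":
    ells = [3, 3, 4, 4, 10, 5, 10]
    print(causal_taps(20, 3, 4))
    print(hierarchical_taps(20, 3, 4, ells))
```

With the test's values (`pos = 20`, `k = 3`, `dil = 4`), the earliest tap is `pos - (k - 1) * dil`, which is the `start` index the test loops from.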


--------------------------------------------------------------------------------
/tests/data_test.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | 
 3 | from sequence_models.samplers import SortishSampler, ApproxBatchSampler
 4 | 
 5 | n_samples = np.random.randint(100, 500)
 6 | lengths = np.random.randint(100, 200, n_samples)
 7 | bucket_size = np.random.randint(10, 20, dtype=int)
 8 | sampler = SortishSampler(lengths, bucket_size)
 9 | assert len(sampler) == n_samples
10 | assert len(sampler.data) == np.ceil(n_samples / bucket_size)
11 | 
12 | max_tokens = 1000
13 | max_batch = 12
14 | batch_sampler = ApproxBatchSampler(sampler, max_tokens, max_batch, lengths)
15 | for batch in batch_sampler:
16 |     assert len(batch) <= max_batch
17 |     assert len(batch) * max(lengths[batch]) <= max_tokens
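
The two assertions in `data_test.py` describe a batching contract: no batch has more than `max_batch` items, and the padded size (batch length times the longest sequence) stays within `max_tokens`. The sketch below is one greedy way to satisfy that contract in plain Python. It is not the `ApproxBatchSampler` implementation, and the function name `approx_batches` is hypothetical.

```python
from typing import Iterable, List, Sequence


def approx_batches(order: Iterable[int], lengths: Sequence[int],
                   max_tokens: int, max_batch: int) -> List[List[int]]:
    """Greedily group indices so that len(batch) * longest <= max_tokens."""
    batches: List[List[int]] = []
    batch: List[int] = []
    longest = 0
    for idx in order:
        new_longest = max(longest, lengths[idx])
        too_many = len(batch) + 1 > max_batch
        too_big = (len(batch) + 1) * new_longest > max_tokens
        if batch and (too_many or too_big):
            # Close the current batch and start a new one with idx.
            batches.append(batch)
            batch = []
            new_longest = lengths[idx]
        # Note: a single sequence longer than max_tokens still gets its own batch.
        batch.append(idx)
        longest = new_longest
    if batch:
        batches.append(batch)
    return batches


if __name__ == "__main__":
    lengths = [100, 150, 120, 199, 101]
    print(approx_batches(range(len(lengths)), lengths, max_tokens=400, max_batch=12))
```

Feeding indices in a length-sorted order (as a sortish sampler would) keeps sequences of similar length together, which reduces padding within each batch.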


--------------------------------------------------------------------------------
/tests/graphmodel_test/T1001.a3m:
--------------------------------------------------------------------------------
  1 | >seq1
  2 | SISTRIGEYRSAQSKEDLIQKYLNQLPGSLCVFFKFLPSVRSFVATHASGIPGSDIQGVGVQLESNDMKELSSQMAIGLLPPRFTEMLVEAFHFSPPKALPLYAHNALEGVFVYSGQLPAEEVARMNEEFTLLSLCYSHF
  3 | >seq2
  4 | AFPMRIADYRSAQSKEDMIQRFLNGISGSRCLFFKFLPSVRSFVATHANGIDAAQIQGVGSQLTSDEMKDLGSLLAMGLIPEKFSSMLVEAFHLNPPKAIPVYANAHLEGLFVYSGGMDKKAIEQLNDEYTLFNLCYSHF
  5 | >seq3
  6 | AISEVINDYRIAESKEDIIRMLFQNLSNLPLLFFKFLPSMNSFVMSHASMPNHQVYEGLGSALNPEETKDLVKQILLNIVPASFNQVISNMFMFQRPSMIPLFDRENLEGVFVFDQESANSLIDQMQDYVSATSLYYSLY
  7 | >seq4
  8 | HVSMRITDYRSAQSKEDLIQKYLNHLPNALCIFFKFLPSVRSFVATHAQGIPASDIQGVGVQLETGDTKDLTTQMAMGLLPARFNDMLVEAFHFNPPKALPLYAHHALEGVFVYSGNITPAEASQLAEEFSLLSLCYSTF
  9 | >seq5
 10 | FISPRIGDYRSAQSKEDMVQKYLNNLPQMVCLFFKFLPTVRSFVATHATGIPASDIQGVGVQLDSADMKELNSQMAMGLLPSRFMDMLTEAFHFRPPKALPLYAHNNLEGVFVYSGEVNKQLSTAMTEEFSLMSLCYSNL
 11 | >seq6
 12 | KVSMRITDYRSAQSKEDLIQKYLNHLPNSLCIFFKFLPSVRSFVATHAQGIPASDIQGVGVQLESSDVKDLTTQMAMGLLPARFNEMLVEAFHLSPPKALPLYAHHALEGVFVYSGSISAAESAQLAEEFSLLSICYSAF
 13 | >seq7
 14 | SIGQKIQLYSAATTKDEVLIHFINQLSC-QAIYFKYLPTVQSFVATHSHGLDIEAIKGVGTRLEATETAHLLDLLKSGEIPPALAELMKEGLRIGqyfpKPVIVQNSTLTGLDGLFIFWGAEGF-HFQQIESDFLIFSLLYQQA
 15 | >seq8
 16 | AISSRIATYKVAESKEDMLHKFLQHIPKeTLCIYFKYLPSVRSFVATHGWGIPNSDIQGVGVQLEAEDMRTLSERVTMGQLPERFSKMLKEAFGFNPPKALPLYAYHSLEGVFVFSGSLDAKYVAEINEEFTLLSLCYSNF
 17 | >seq9
 18 | PFQTRIAEYRAAESKEELVQMFFRQTASQSWAFLKYVSSIQTYISVSSQNMPDEWVEGLSYKI-PSSQSDFNDKILLGEYPESFLQYIKAKWGVDTLKVMPLLLKNEIEGLLITPQDI----SAEVAEDFSLMSLVYNLI
 19 | >seq10
 20 | PFQMRIAEYKMSESKEELLQKFYQQSPKQSWVFLKYIKSIQTYISVSHQNMEPSWVEGLSFKI-PTDVQEFNSRVFVGDFPDSLISYIKTKWDVSNLKILPLTFKDEIEGLLISTQDI----SADVAEDFSLMSLVYQLM
 21 | >seq11
 22 | -VSMRITDYRTAQSKEDLIQKYLNHLPGnTVCVFFKFLPSVRSFVATHALGVPASDIQGVGVQLESSDMKDLSAQMAMGLLPPRFTEMLVEGFHFNPPKALPLYAHNVLEGIFVYSGSLPAEDTAVMNEEFTLMSLCYANF
 23 | >seq12
 24 | -FSYKIAEYKTAQSKEDLLQKFLNNLEHTLCVYFKFLPSVRSFVATHANGIPAASIQGVGCQLESDDMKELGSQLALGLLPERFSSMLVEAFHFNPPKALPLYANNILEGVFVYSGLLGGTAAEMMGEEFSLFSLCYSHF
 25 | >seq13
 26 | -VSMRIADYRSGQSKEDLIQRFLNNAPESLIVYFKFLPSVRSFVATHAKGAEGSSIQGVGCQIEMSEARDLGSQMAMGEIPPTFVSMLKEAFHFNPPKALPIYSHSGLEGVLIYSGDMDKLMAHKLQEEFSIFSLCYSF-
 27 | >seq14
 28 | -VSVRLADYRTVQSKEDLIQRYLNHLPTqVGSVFFKYLPSVRSFVATHGAGIPAGDIQGVGVQLESSDVKDLPSQLALGMLPDRFANMLVEAFQFNPPKVLPLYAHNNLEGVFVYSGTTSPADITFLSEEFALMSLCYVN-
 29 | >seq15
 30 | --SMRIADYRSGQSKEDLIQKFLNNAPESLIIYFKFLPSVRSFVATHGKGIEGSSIQGVGCQIEMQEARDLGGHLSMGQIPASFVSMLKEAFRFNPPKALPVYSHAGIEGVVVYSGDIDKLVLHKLQEEFAVFSLCYSY-
 31 | >seq16
 32 | -FQMRITDYKLAESKEELLQKFFSQTPTQSWVFLKYIKSIRTYIAVAHQNMEENWVEGLSFKI-PGDEEKFNQQIMIGNFPESLTDYLKGKWDIQALKVVPLILKDQIEGLLVTPQDI----SAEVAEDFSLMSLVYQV-
 33 | >seq17
 34 | -FQMRISEYRTAESKEDLLQTFFKQTPTQAWVFLKFIKSIQTYISVLHQNMPESWVEGLSFKI-PIQETQFNQKIMVSDFPSSFLNYIKSKWNVEHVKILPLIIKNELEGLLVSTQDIDG----NVAEDFSLMSLVYA--
 35 | >seq18
 36 | ----LIGDYRLSNSKEDIVQKFISSLQKTSCVYFKFLPSVRSFVATHSNGIDARLLKGVGSQLGKEESKNLTSTLSLGEVPPSMKDLLTEGFHFQSPKILPLFIQNQLDGLIAYDGKIDRNELQDFHERFSLFSLVYSHY
 37 | >seq19
 38 | -VSNRVADYRSAQSKEEVLQKYFERLGKIPAVYLKFLPSVRSFVATHASGFPPSHIQGVGCQLENNDLESLNTQITVGLLPPLMNELLQKVFHFASPRVVPLFVQNQLEGAVVYNGNLSKAESLRVGEEFSLFALCYSY-
 39 | >seq20
 40 | -FQTRIAEYRSADSKEDLLQKFFSQTPQQSWAFLKYVKSFNSYILVSSQNMPEDWIQGVSFKI-PNSEPDFNKKVTVGEFPASFLNYLKRKWEVEIIKVLPLLLKDDVEGLLVTTQDI----SPEVAEDFSLASLMYNL-
 41 | >seq21
 42 | -FQMRISEYRAAETKEELIQLFFKQTPQQSWAFLKYAPSIQTYISILSQSMPDSWVEGLSYKV-PSAMKDFNQQIMLGEYPAVLTEYICAKWGVKTVKMMPLILKDEIEGLLVTPQDI----SAEVAEDFSLMSLVYNL-
 43 | >seq22
 44 | -ISEVLNDFRIAESKEDIIRMLFQNLTNMPLLFFKFLPSMNSFVMSHASMPNHQVHEGLGSALNPEETKDLIKQILLNIVPASFNQVVANMFAFQRPSLIPLFDRDSLEGVFVFDQESSHTLIDQMQDYVSATSLYYSL-
 45 | >seq23
 46 | -FQMRIAEYRVAESKEELIQKFFKQTPAQAWVFLKFVKPIQTYISVSHQNMPEAWVAGLSYKI-PINQPDFNEHVIIGSYSESFLKYIKTKWSVDNVKIFPLIFKNEIEGLFVSPQDI----SAEAAEDFSLMSLVYSL-
 47 | >seq24
 48 | -FQSRIAEYRSADSKEELLQKFFSQTPQQSWAFLKYVKSISSYILVTSQNMPETWIQGLSYKV-PATDAEFNKKIMLGEFPDNLLNYLKRKWEVDIVKILPLLLKDDVEGLLVTTQDI----SSEVAEDFSLASLMYNL-
 49 | >seq25
 50 | -FQMRINEYKSADSKEDLLQIFFRQTPLHSWVFMKYVPSISTYIAVANQNMPQSWIEGLSFKVSALE-TEFNNKVAVGDFPLSLTYYLKSKLEVETVKILPLVIKNDVEGILITTQDI----PAEVAEDFSLMSLNYAL-
 51 | >seq26
 52 | ----------------------------SMCIFFKFLPSVRSFVATHGNGVEGSQIQGVGCQLEADDTKDLNSQLSVGLLPARFNAMLVEAFHFNPPKGLPLYGAQNLEGVFVYSGSLDKNAAAQLNEEFALFSLCYSH-
 53 | >seq27
 54 | ----RITDYKSADSKESLVQKYLQRSESTPMVFFKFLTSVRSFVVTQSVHMDVERIQGLGAQLSPGELKDLNAQLSVGLLPPSLLKMVNEAVHLQEFRVWPLYVHHQLEGVVICQVGE---ETTDLHEEFSLFGVMYSHF
 55 | >seq28
 56 | -----IKDYRSAASKEELLRRFVQVAGKTACVFLKYLPTVRSLVVTNASVFDLDHLQGLGCQLQPNEAKDFGSQVALGIVPPSLHDLLRQAFQFQKSRLLPLFIQDRLEGVVAYSTQIAPSEKMRLDDEFALMSLAYTA-
 57 | >seq29
 58 | --QLRIAEYKSATSKEDLLNVFYNQTPTQSWIYLKFVPSIETFICVSYSQVPEDWVEGLSYKVATKDRDNFMSKLFSGALPPNLGNYLKNKFGTDRIKFLPMIIRDKIEGILISTQEI----SAEVAEDFSLMSLVYT--
 59 | >seq30
 60 | -ISAKIAFYKACVNKDEVVLAFFRRLSC-KAIFFKYLPTVNSFVALSAQGVDVEDIKGVGSRLEPSESKDLPKQLSEGVLPQALVQILKEGLHVSRFEWKPLLVQNWVEGVLVFWGDEKF-QFASIENEFLIFDLIYQKM
 61 | >seq31
 62 | --QIRISEYKSCQSKEELLDVFYKQAETQSWVYLKFIPTIETFISVSNHQVPEYWVEGLSYKV-PASNKGFMDQIFQGVLPESFERYLIQKFNVKQIKFIPLIIKNQVEGLMISTQDI----TADSAEDFSLMSLVYTN-
 63 | >seq32
 64 | --QLRIADYKSASSKEDLLNTFYSQTPAQSWVYLKFVPSIETFLCVSYAGVPEEWVDGLSYKV-NSKDKDFMSQLLMGSLPLSIATYFKSKFGTDHVKYLPMIIRDKIEGILISTQEI----SAEVAEDFSLMSLVYTL-
 65 | >seq33
 66 | --QSRISTYRLAESKEQLLDQFYQATPTQTWIYLKLVPTIQTLICVSSANCPEDWSEGLSYKI-PTREKNFSDLLISGQMPEGLLSYLKVKLAVDKIKFLPLIIKQAVEGVLVSTQDI----SAEVAEDFSLMSLVYSN-
 67 | >seq34
 68 | --QTRIALFKGAESKENLLDLFYQQTPEQSWVYLKFAQTIQTFICVSYANVPENWIEGLSFKV-PIKEKNFLEQISLGALPESLSNYLTQKFGVERVKFLPLMLRDSLDGILISPQDI----SAEVAEDFSLMSLIYTN-
 69 | >seq35
 70 | --QNRIAIYRLAESKEQLLDQFYQATPTQTWMYLKLVPTIQTLICLSSANSPASWTEGLSYKI-SNKETSFLEQILNGLLPENLHSYLKNKFDIDKVKFLPLIIKQNIEGLLISTQDIDA----NVAEDFSLMSLVYTN-
 71 | >seq36
 72 | --------------------------------------SLRSFVATHGNGILSSEIKGVGVQLESEDLKDLASNLAMGLLPARFNEMLVAAFQFNPPKALPLYAHNSLEGVFVYAGNMSATEIKNISEEFSLMSLCYSNF
 73 | >seq37
 74 | --QSRISTYRLAESKEQLLDQFYQATPTQAWFYLKLVPTIQTLICVSSANCPEEWSEGLSYKI-PAREKGFSEQLLSGIIPEGLLSYLKVKLGVDKIKFLPLIIKQSVEGLLISTQDI----SAEVAEDFSLMSLVYSN-
 75 | >seq38
 76 | -------------------------LEKIPAVYFKFLPSVRSFVATHASGFNANQIQGVGCQLENQDLEGINTQVSVGLLPPLMLEMFRKVFHLENPKVLPLFVQNAMEGIVVYSRELSKADSFRVGEEFSLFSLCYS--
 77 | >seq39
 78 | --QTRIADYKLAESKEELLNTFYRNTPDQTWIYLKYIESIQTFMGISSHLAPESWVEGLSFKI-PRGQSEFNQTIARGGLPKDFLNYLTEKFDTARIKTLPILLRDKVDGLLVTTQEIPI----EVSEDFSLFSVIYA--
 79 | >seq40
 80 | --AEKMRVYQGAQSKDDYLAVFLQHLPC-HAIYFKFLLTVNSFVATASQKLEIESIKGVGSRLTADEVKSLVEDLEAGRLPASLKELMNEGLKVPRYYSQPVPVHRGLDGLLVFWGDAEF-QYQDIENDFLIFHMLYQQ-
 81 | >seq41
 82 | -FQTRIAEYRSAESKEDLLQKFFAQTPQQSWAFLKYVKSINSYILVSSQNMPESWIQGVSYKI-QNAEADFNKKVIVGEFSQNFLNYLKRKWEV----------------------------------------------
 83 | >seq42
 84 | --TSHVLKYQQASNQEELLQTFMNSLGAIHSIYFKFLPTVNSFVATLSHGIDIESVKGAGSRLTDAESKDVHEFLASGQIPEALRALMQEGLKIVQFISQPVLLYRSLDGLFIFWSSENF-NFSLIENQFQIFQLVYQN-
 85 | >seq43
 86 | ------------------------QAPTQSWAFIKYVKTIQTYVSMSSQNMPGDWVEGLSFKI-PTDQSDFNDKLIVGSYSDSFLDYIKNKWGVKTVKILPLLHKNEVEGLLVTPQDV----SAEIAEDFSLMSIVYSL-
 87 | >seq44
 88 | --KELVQKYKATDSKEDLLGQYLGKLKF-QALFMKYLPTVQSFVATYSQGIDLDSIRGVGARLNTEESARLNELTQNEYLPPSLREVIEKGLGVRQFAVKSVTGATGLDGVFVFWSDHNF-DFSQHEDEFLIFQLLYQN-
 89 | >seq45
 90 | ----FISDLRAAQSKEEMISTMLRESKETPLVYLRYLPSMASFLVTDSSYANVENFKGLGCRLTAEENKDLAKQLELAIVPPSLAELLMKAFRMNSPRIRSLMSGSVLEGILVGDASS-AESTHYLNERFAIMSLVYSHF
 91 | >seq46
 92 | --SEEFKKYRNATSKEDVLQIFLKEINQtflaknnkLSSLYFKFLPSVQSFVATQSIGVDIDSVKGVGGKFTEQDSKDPLELIKNGHVPSMIQELMKEGFSTEDFIFKPVFFDQMLDGFFIFWSNTQKVYSEEFENYFTLFLLFYE--
 93 | >seq47
 94 | -VREKINAYQKALAKDDFVQTFLEQLPC-NAIYFKWLPSVVSFVATCSKGLDIESLKGVGSRMTLEESRSVNEFLESGKLPEALNELMQAGLNIKNYFSQLIPVYHGFDGLIVFWGDENF-HFEQIENDFLVFRMLYQQ-
 95 | >seq48
 96 | ----FISDLRAAHSKEELIATLLREARGTPLVYLRYLPTMGSFLVTDTSYEPPEEFKNLGCRLAPEENKDLVRQLELALVPPSLGELLSQAFRIQNPRVRPLLNGAELEGILVGDANS-PEITHELNERFAIMSLVYSHF
 97 | >seq49
 98 | -VSHRVADYRSSQSKEEVLQKYFERLDKVQAVYFKYLSSVRSFVATHAAGFTSQQIQGVGCQLEGADLESLNTQIAVGLLPPIVMEMLQKVFNIQT--------------------------------------------
 99 | >seq50
100 | ----------------DLVQKFVSSLQKTSCVYFKYLPSVRSFVATHSSGIDARLLKGVGSQLGAEESKNLNTSLGMGEIPKSLKELLGEGFHFQQPRVLPLFVQSQLDG------------------------------
101 | >seq51
102 | --TEKANAYEDLTTKDEYLSAFLLRIPC-RAIYFKFLPSVNSFVAVSGHGIDIESIKGIGARMSPDESKDTLQFLQKGALPESLREIIVEGLKISQYLVKPVPLYRSLDGLFVFWGDEGF-DYREIENEFVLFHLFLQ--
103 | >seq52
104 | -FQTRIAEYRSADSKEDLLQKFFKQTPSQSWAFLKFVKSINSYILVSSQQMPEEWVQGLSYKI-PNSEKDFNDRMMIGEFSDGFLTYIK---------------------------------------------------
105 | >seq53
106 | -IKDFIMDLQAAESKEELLQVFLRESANLSLVYLKYLPSMSSFVVTHTAHLPLEKTEGLGCRLSPEENKDLLKQFSLGLVAPSLMELLEKAFQIKSPRVRPLFDRQALD-------------------------------
107 | >seq54
108 | -FQMRIADYRVAESKEELIQIFFKQTPTQSWAFLKFVKPIQTYISVLQQNMPEAWVEGLSYKI-PVNQTDFNDNVVIGSYPEAFLKYIKNKWD-----------------------------------------------
109 | >seq55
110 | --QLRIAEYKSCESKEELLDIFYKQTDQQSWVYLKYIPTIETFISVSQHEVPDNWVEGLSFKV-PAQDKDFLSQIFQGNLPPALENYLLGKFKVPQIKFIPL--------------------------------------
111 | >seq56
112 | --QLRISDYKLAESKEELLDLFYKNTPEQSWVYLKFIESIQTFMGISSHLIPDSWMEGLSFKV-SKNQKDFMNLVRQGELPKNFLQYLSQKFDTAHLKFLPVLLRDKVDGV-----------------------------
113 | >seq57
114 | ----LIQKFSRASSKEEFVDVYFKYVSErksgapFSAVYFKYLPSVVSFIATQGHQINLERTKGIGLKLMGDEAGHLAEQLQNQNLPMGFQNLLQEAFQIQDWTVYPLFLKDQVEGVVIFWGLE---LTPGDWEEFLLFQLCYQN-
115 | >seq58
116 | --EEKIQLYRYAVTKDDYLEAFLRQIPG-RAIYFKFLPTVGSFVSTFAQGLNLDDLKDVGVRLTFEESKDVDTFFKEGGTPVALKELLAEGLHVDGYMTKPVFSINTLEGIFVFWNFNAA-Q------------------
117 | >seq59
118 | -----------CQSKDDYVLAFLKQLSC-NAIFFKFLPTVSSFVATAAQGLDIETMKGVGSRMSLDESKDVKAFLSSGQVPQALNELMTDGMKVPQFYSHLVSVPQGPEGLLVFWDGQ-RSGPPAIEDDFLIFQLLYQQ-
119 | >seq60
120 | --------------------------------FLKYVKSINSYISVSSQNMPESWVEGLSYKV-PNNETDFNQNMLVGIYSEHFLNYLKRKWSVDIVKVLPLTLKDQIEGLLVTPQDIKG----EVAEDFSLVSLVYS--
121 | >seq61
122 | -FNDEIKKYRKAANKEEVIADFLKDLNQkflsrnqrLFAIFFKYLPSVYSFVSLQSLGLDVESLKGVGFRLTPEEARQPAELFAEGKMPSQLTQLLQEGLKVPSPLVCPVLVQGKPEGYFCFWTNSGGLSLESIANELSLFQVLFE--
123 | >seq62
124 | --------------------------------FLKFVKSINSYILVSSQNMPEGWIQGVSFKI-PGSEPDFNNKVIIGEFSPSFLNYLKRKWEVDIVKVLPLVLKDEVEGLLVTTQDI----SPEVAEDFSLASLMYNL-
125 | >seq63
126 | -VAERIRNYLSVESKEDLLSRWMLGLGEKPCAYLQYLPSVRSLVVTHGTL-PES--QGVGCQLTPAEAQDFATQVALGVVPPTLDDLLKKAFKFSSVRLLPLFTQGKLEGVA----------------------------
127 | >seq64
128 | ------------------------------CVYFKFLPTVMSFVATNGYHVDLERTKGIGAKMTPDETKTLLATILERKMPPSLHTLMAEAFRVSQSFLYPLLLKNQIEGVFVFWGVAESVFPRSFGAEFSLFQLCYQN-
129 | >seq65
130 | --GREISAYANAKTREEMTDVFLQQMQQkclrrnlnLRALILRFLPTVQSFVATQSLGLDIDKLRGIGARLEKDEAEKLTELARAGQHPKQLLQLLQQGLALDNFQVQPMILRDHVEAYFVFWVESGVLQKTDFENEFTTYSLLYY--
131 | >seq66
132 | -FQEKIRLFDAVESKDETIHTFFRALPC-RAIFFKFLPTVNSFVATMAKGLDIESVKGVGSRLKSDEIEKLQKILQSGALPSALEDLMTN--------------------------------------------------
133 | >seq67
134 | -------------SKEEMIWSFLSHLPC-RGIFFKFLPTVQSFVATQAHGLDIDXXXXXXSRLTQEEARDLDQLLQKKSLPKVLDELMKDGLRVPSYQSHVLRVQRHLEGLFVFWGLE-AEGAAQVENSLLIFQLAYENF
135 | >seq68
136 | -IQDRLNEYRSAESLEILVQRFLGQFQGRLVIYFKYIEAVRSLLALNAQGVEISKLAGLGLQVEVKDLQIFYDQLLLGLAPSKVISEVSELFGSTKLKSFPIYVFEKLQGFFITVDED----LEIFESDLSLVSLAYSH-
137 | >seq69
138 | -----------MQSKEELIRWVFSRIEArysqeskINGIFFKYLPTVSSLVATLTMGLGSDKAKGVGAKLTPEERKTLGSDLSGNLLPESLKNVLTSGFHVSEFSSFPVFVQNQVEGVFIFWPMQ-ISESDKI--ELNIFQLYY---
139 | >seq70
140 | ---------------EGVVQTFLELSSElvekKPVLFLKYLPAHSALVAAQVAQMDPEKIKNVGFSVAQIDPKELNELFQHPEKIQQLTELMKVVFSQNEFSSLPFIYQNQVQGVFVLFGSFKSESDQKvFESYMQLMNVRYDN-
141 | >seq71
142 | --KNEVMVYAKAKSKEDILDLFLHHLEQkclrqnlkLNALVLKFLPTVQSFVATQSLGLDLEKVKGVGARLEKDEAQDLLGFLESGHFPKQLIHLLNQGLGVSNFVGKTLMLHDAIEALVVFWSSNGTLKAAH---------------
143 | >seq72
144 | ----------------------------------KYLPTVKSLVVTHASATSAERLNGVGCQLAAEEARDFASQVALGMVPPSVSKLLSEAFAMNSASLWPLFLHRQLEGVVAFSREVDPRVMAALRDEFSLFSTLYAN-
145 | >seq73
146 | -----------SHTVEEATQDWLNKANLlfsgVRAAFFKYVPGHPYLMLTQCSGMELESVRGVGVNLTGLSPLEQKSFYSHTRFLMALKDLLKGAFNSDGFEYREIISDRGVLGIVVVFKTLEHEREKIfFNDSVEILNIVV---
147 | >seq74
148 | -ISNLLTLYENTRTREDLLQVFFHSLESyclegAKVLYFKYLEPVQLLVATHGCGVPVDDIKGAGIRLQPQEVLDAKNLLLSPRGFGSLNKLLTDIFHVDQCYVKPLIVREDIDGLFVFFGESLESfNSIRFSNRFSLFRVCFER-
149 | >seq75
150 | --GEELAKYGKATSKEEVLSLFFRELEAkfsrleipIKALFFKYLPTVQSFVAMQGLGLDLDSIRGVGGRLVNEESHDPEGFFAKGALPYELKMLLNELVGRDepqmQMKSLYISIRDQVDGLFVLWGSSPQIKWQVIQNEFAFFNLLYE--
151 | >seq76
152 | -----LSSFHNATNSDELLQIMVNFISQSKVIFLKYFEGIQSFVGFQSNFGSAEEIQSIGCQLKPPESNELAKQMSLGIVPVTLKELSQKVFAFSQPQFWPLVIEiGKVEGVMISEGGGDEALKADILNRWSLFSLVYKNF
153 | >seq77
154 | ---DEVKIFEGAHTKSDLMALFLRRLSDlmreqqvgLKAIYFKYMPSVQSFVVMQSLGIDENSIQGLGGKLSPEEVKSFPDLFQQGTIPEQIQLLLEQGLKIQQFQSQVLWLNEILDGFFIFWSDQMRLETQFYSAQFLIFKMI----
155 | >seq78
156 | ---------ELCGSVEEAIEACMFQISNwaesAALLFFKYIPGHPYLVLAQSKGIDSDKYRGVGVSMTGLAFSELGHFSEHGRFLSQLDSVMKGAFKAEKYTFSEVKSEQGILGIVVLLKSLESRAQKRfFEDCISVLNLTSQK-
157 | >seq79
158 | -FKEFISDLRAAQSKEELIATMLRETKETPLVYLRYLPTMASFLVTDTSYTPAENFKGLGCRLTPEENKDLAKQLELAIVPPSLAELL----------------------------------------------------
159 | >seq80
160 | -ISNLLTLYENTRNREDLIQVFFHSMEEycqegAKVLYFKYLEPVQLLVATHGSGIPVDQIKGAGIRLLPEEALEAKNLLISPRGFGSLNKLLLDIFHIEQCYVKTLFVREDIDGLFVFFGDSLEvFNTIKFSNRFSLFRVCFER-
161 | >seq81
162 | -----------------------------QGLYLKYLPTVHNLVATRALGLPIEKLKGVGAKLTPEEVQQLDLTVARHEVPPSVKALMVEGFHVPEFAPRGLLVHRGVDGLFVFWSKTSF-DIDLLDNEFMTFSQAYQM-
163 | >seq82
164 | ---------------DHVVQMLLDAISQviqnKPVLFFKYLPQHSSLITSHASKIPVEKIKNLGINLSQVEATKIPEMLLHPLTIPGLPDLMKEVFQVPSYHAIPFVHQNQ---------------------------------
165 | >seq83
166 | ---------------------------------FKYLPAYSSLVTSHASKIPIEQIKNLGVNLNQFQPLEIPEMLLRAQEMPGLIDLMKEVFKVSTFSALPFVYKNQTFGIAVIFDTLQNQSTKRLVESFL---------
167 | >seq84
168 | -----LKIYDGLISKSELTATFLRRMGEilraqklsLKSIYFKYMPSVQSFVVMQSLGLDEATVQGLGGKLSPEEIKSLKDLFQIKKIPEQIQLLMEQGLKVHVYQTQVLWTQDILDGFFVFWSDRVNLEPNFCQAQFLI--------
169 | >seq85
170 | --------YKTCETKEQVINLFFEQMEGLKVLFFKHLPTVHSLLVTHSSGFSHEDVQGIGCQLQDHEHKDLTSQLTLGVCPGPLALLL----------------------------------------------------
171 | >seq86
172 | -----IKEFRAVATQEELIQKLVVFLPPkSMMLYFKFLPSIQAFVATHCCGLPQEQIQGVGCRLSLEEQSTFSADVILGRLPLSLASLLEK--------------------------------------------------
173 | >seq87
174 | ----------------EINQKFLAKNKKLTSVYFKYLPSVSSFVALQSLGIDIETLKGIGCRLTKEETEDSKTFFAQGGVPAELKVLLNEGLNAPQAIVKPIFVQEQLDGFF----------------------------
175 | >seq88
176 | ----------------GVLQVFIEAVSDitdgKPVIFFKFLPAYSSLVANHAAKIPVEQIKNLGVNLSSFDTKSIPELLMQPVKLQPLLEMMKEVFQVQQFFALPFVYQNQPIGVIATFTPLSSDPVRRLLESFL---------
177 | >seq89
178 | --------------------------------------------------MPDEWVQGLSYKI-PNSDKDFNDRMMLGEFSDNFLNYIKRKWSVDIVKVLPLIIKNEIEGLLVTPQDI----TAEVAEDFSLVSLVYN--
179 | >seq90
180 | ----------------------------------------------HASGFTSQQIQGGGCQLQGADLENLNTQLSVGLLPPIMMEMLTQIFHIKTARVLPLFVQNQLEGVVVYSAEAGKADTLRIGEEFSLFALCYSY-
181 | >seq91
182 | -VRQEIEKYRDSRSKEEVLNVFLREISQkflsrnkkISTIYFKYLPSVHSFVAVQALGIDIETLRGVGGKLTDEEAKDPKSFFLSGGIP-----------------------------------------------------------
183 | >seq92
184 | -LEERIKEYRSAAAKEDLLSKFMQIIAPIPCIYFKYLPTVKTLLATNASIFKPEQLQNVSCELRPEEAKDFANQI-----------------------------------------------------------------
185 | >seq93
186 | ----------KSKNLDEVIQSFLDQTQNligKPLIFLTHMPSYLSFIASHAAGIEKAKIRNLGLNLKSIDSKQYLEKIADPQSLEQLKQLMTDIFKVPDYLALPMEEEGAITAYILVLGSVEDVSLRRLLDSFvNIFK------
187 | >seq94
188 | ------------------------------------------------QNMPEVWVEGLSYKI-PANQTDFNENVLIGNYPESFLSYIKNKWGVDSVKVLPLIMKDEIEGLFISPQDITA----EMAED-----------
189 | >seq95
190 | ----------GTKSLDETVQAFMDELSResqrFPVIYFKYLPAHASLAISQASHLAVDKFRGLGLDLKGLEPKALSSFFQSPEKSEKLEDLLKQAFRTESYTAFTHFNEGEALGLFVIFNRAKGEAHAILDSATH---------
191 | >seq96
192 | ------------TEVDEVLKTFLSHTSKitgdAKVLYFKHLPAYTSLLLYMAEAIDVNDFKGVGVSLKDLYPVKYNEIILSPQKIESLKNMLSEVFKMNSALVIPLISESNVAGVIA-CEAIEDKSIRRVFDNFlQVLSLNYQN-
193 | >seq97
194 | ---------------KEAIDLYLQEVSRflgeKQIVFFKYIPSHQSLVASQSVVADVNKVRGLGFELSKEEENFSLDQLHKPMTLEGLRRLMSSSLGIDEYFAQALVVQGVPEGVFVFVGSENIQAEPYIKSCVKSLG------
195 | >seq98
196 | -------------SKEEVLNIFFRELNQkflsrnlkISAIYFKYLPTVQSFVAMQTLGLDLESVRGVGGKLNAEESKDPDQFFAKGAIPQSIQILL----------------------------------------------------
197 | >seq99
198 | --------------VTQVVDIFLKQSAEiledASILFLKHLPSYRSLSVSQAHNVDPQKVRGAGISFKELSQQDYIQRISSIADAPELAEMLVQVFQKSEYVVFPLSIENQLAGAFVFLEEIHDMQSLRLLESF----------
199 | >seq100
200 | --QLEFKKYFQCKTKDDVLALFLLELQQkflarnkkLNAVYFKYLPSVTSFVALQALGIEIDSLKGVGCRLTDIESADPQAFFSQGGVPTELKELMSDGL------------------------------------------------
201 | >seq101
202 | ---------------EAVIQTLVNSTSDliqkKPALYFKYFSAHATLVASHAAQIPGDHIKNLGLKLSTAEPKQLHAWLTQepGKIP-QLTELMKEVFTCKDYVAVPVLYKNDPQGTLVL--------------------------
203 | >seq102
204 | ---------KKAKSRDELFRlyfKWLSQFSNLNLLYFKHIPSAHALMMQYCIGLDQQKLDGLGCELNPKDQL---------KDPQSLRELIRQALGQASVQVIPLVLQNEIQGYFVYWGQQQNISADIFELYFE--NLVY---
205 | >seq103
206 | ---------RISETKdlDETIQIFMEELSResrqFPVLYFKYLPTHASLAISQASHLAVEKFRGIGLDLKRFDPKQAAEFFKAPSRDSELNDLIKQAFRANEFAAFTHFNDSETLGLFVVFNRPAT-AAQNI--------------
207 | >seq104
208 | ------------------------QLAGlRPVIFLKYLPAHQSLTITASAQIPFNEIRTVGVKLNHYDEKNISFVLGEPEKFKELKELMTELFKCPNFTAVPLRVEGKVVGIFAIFQAMKD--FDElVNKIVYLGALSY---
209 | >seq105
210 | ----------------------------KPVLFLRFLPAYSSLIAGAASKLPLDQIKGLGINLSSFDPKQISELLMHPEKIAPLSEFMREVFKVADFSAFPFVHENHPIGIAL---------------------------
211 | >seq106
212 | ----------------------------------------------HSQNMNSIEVEGFTFQIELRQVKDYMSQFSLGVVPPQMIQSIEEKLKIKLNKALPLFVEDQLEGVFVTPTAI----PEDFAEEFSLFSLTYQL-
213 | >seq107
214 | -------------SSDEIHSILLKEISNefrsSGVLYFRYLPTHKSMVLTSWQNVDESNTKSIGLSFKKLTQEEIQNVLLSPHLCQDLTRLMAIVFKVKEFETAPMLgIDGAINGIVILTQKFESDRQSElLNSMVSVGSLVLQ--
215 | >seq108
216 | ------------KDVSDTIQTFLEHTSQltegTKVLFLRYLPAYYSLLLTHAAQYPMEEGKKIGINLKDIDPKSVMDHLRTPSEMPLLQNLLKQVFSVEQYLAVPVETDDEFISIVIVCRDMSDPALRRVFDSFmQIFKVSY---
217 | >seq109
218 | ---------------------------------------------------------GVGSQLTESETQSVAETLKKGTTPQSLRNLMAEAFAVTDFILKPLFVHKNLDGLFLFWSQTGV-AADEFTNEFLIFSLLYQN-
219 | >seq110
220 | ------TQLSKAKNIDDVINIYLSHVASlvdqKPIIFLTHVSAYLSLVVTHSIQLNKDNLKNVGVSLKDLDSKTYVEQLMNPMSMAGIQGLMKDFFKLDKYLAVPIEDETNIGGIVLVFDAIEDVSTRRLFDSFT---------
221 | >seq111
222 | ----------KAKGVGEVVQIFTEQTSEmlggVPVAYLTYFPSMIAYGVTNCALIDKEKIKSLGINLKDKPKEEYDKILQDPYSSSFLNQFVSQYFKTTAYLAIPLVDQGGLKGLFLVLKDITDAQQRRLLDSF----------
223 | >seq112
224 | ------------------------------------------FVAMAAHGLDIESIKGVGSRMTKEEAKDVIAFLASGKVPQSLDELMSQGLKVPQYFSQAVPIHHGLDGLMIFWGDESF-HFKKIENDFYVFLMLFQQ-
225 | >seq113
226 | ----------------------------------------------HSSGYLQESIQGVGCQLPENENQDLVSQLSLGVCPGSLAVLMKRALGLAQVQVFPVFSQNDLEGVIVADVVPDSEFHRHLSEQFALFGLAYSHF
227 | >seq114
228 | ------TQMSKAKEVDQVIQVYLDQVSElvgkKPVIFLSHISAYLSLAVTHSAQVNKETLKNIGISLKDLDTKVYIETLMNPMSFSDLQSLMTDFFKTSDYFAIPIEEDNTVAGIVVVFDALKDVPTRRLFDSFT---------
229 | >seq115
230 | ---------KLAQTKEldEAVRLFINAFSRlssdTPALYFKYLPNHLSLVFSQASLLPENKFRGIGIDLKKVSDLGPDRFFENPASAKILREFVGEVFKKDRFTVFTHRVDQEVVGLFLALEQFDLKARSDLSLLIEAFDLAYK--
231 | >seq116
232 | --QAQIEKIQAATILSRTVETYLQAAHSrtgKPVCFFKYLPARRSLFLTNASGLNLSDYRGLGIDFVKENPEFQASHLQTPHNLQALQELIRTAFKVEGFYALPVDDGNRVRAIVVFLQ------------------------
233 | >seq117
234 | -VQQEFKKYLECQNKDEVLALFLLELNQkflarnkkVCAVYFKYLPSVSSFVALQTLGIEIDSLKGVGCRLTAKENADPISFFAGGGIPQ----------------------------------------------------------
235 | >seq118
236 | -------------------------IKKQPVIYLKYFSAHATLVAADVANLPPEYIKNLGINLSDMDPKELQRVLmEEPAKIQQLKELMKEVFKIDKYVAFPFT-------------------------------------
237 | >seq119
238 | --------------IEVFLTHIEPQLSDeSGALYFKFNSPRRTLVASYGLGIDPTLIEGIGLNLEKVDPEFKRSDIKQIHERDSFNGLLHEVFGVSEYSIRYLTLNEEVIGVFVHWGLLEKTPRHDfVDRSFTLLSK-----
239 | >seq120
240 | ---------------DTVIQTYLDHVSQlvnKPAIYLSHISAYLSLAVTHSSQINKDSLKNIGVSLKDLDSKVYIEKLMYPMNLEGLQALMKDFFQTAEYFALPVEEEAGIGGVVVVFDPLKDVVTRRLFDSFT---------
241 | >seq121
242 | --------------AQDVVQVFVDFASHalgdLPVVFFKHLPAYFTLSITHSAVVPMAQLRGIGLNLKSEISNDYGSLLKNPAELPSFRQLLFEVFGARDFEAYPIETEEGPSGVVVALQKLEDAATKR---------------
243 | >seq122
244 | ---------------DEVVHVFIHTFSRltsdTPVLYFKFLPSHLSLVFAQASLLPEEKFRGVGIDLRKSHPGRPEEFFENPEEAKALREFVAQVFKKNRFTALTHRVDREAVGLWVALEKFDMK-------------------
245 | >seq123
246 | --------------LSCAAETYLQAVHSrieHPVCFFKYLPARQSLFLAHASGLDLNAYRGLGIDFSKGKAEFQASQLRDPSSIQEFSELIQTAFKVEKFCALPIDDGDGVKAIVVFLQSN----------------------
247 | >seq124
248 | ---------------------------------------IQSFVVLQSIGLDAEAVQGLGGKLSPEELLNIKVLFQSQQIPGQIDLLLRSGLKVEKYLGQALFIDGQLDGFFVFWSDSSEIQESFYAAEFLVFKMVYQN-
249 | >seq125
250 | ----------------------TRSLNDTPVLYFKYFPGHMTLLFSRATLLSNDTFRGLGIDLKKEGSRTLEEHFEHPATLPGMRKMVREVFKCELFSALTHMDGNEVRGVFIILDEVEV--------------------
251 | >seq126
252 | -----------------FLANLHQNLSTSPVLYFRFVNSPLSLVITHASGIQVDKVRGVGFHL--QDTQSIREKLLQPEALTGFYDFLKQAFRLQE--------------------------------------------
253 | >seq127
254 | -----------AATKdlDETLDIFLEALAAetkAPTLYLKYLPTHASLLIAQVAQLPIEKYRGLGLDFRKEGIQNASELLRDPSEVKPLAELIRSLFQKESFVAIPHLNEGEALGVFVVLGDLDVSNSA----------------
255 | >seq128
256 | -------------TLDDAVQVFMNSASNalggCSAVFFKYIANRRVLIAGRGEKLQDHDLRGLGMDL-NRSVDGFRaVQLREPMHLMAFVQMVKDVFGIKDFFAWPVRALGEIQGLICFLKN-----------------------
257 | >seq129
258 | ----------AAHSIDQAVEVYLKAMHQslaQPVCLFKYLPARRSLFLTQVMGLPPEPYRGLGVDFSKQEENFSSQVLRQPEGLSSLRDLVNSAFKTSTFAALPVDDGEGIKSVVVVLNRS----------------------
259 | >seq130
260 | -------------------EAYLQAVHSrtgKPVCFFKYLPARQSLFLVHASGLDLNAYRGLGIDFSKEGSEFQGSHLLDPQSIQKFRELIRKVFKVERFCSFPVDDEDGVKAVVAFL-------------------------
261 | >seq131
262 | ----------------------------------------------------------------PEEMEDLTSQLQAGHLPSALQELMTQGLKIAKFNFKPVLVQRALDGLMIFWGEASA-HQALIENNFVIFQMLYQQ-
263 | >seq132
264 | ---------------DHIIQAMVDEISHtsqnAPALFLRFLPAYQSLTVTASAQIPLESIKTVGIKLDRYDPRNIASVLKDPLKFLELKELMSQVFKCPEFYALPFVIGGNPIGLMISFGTSE---------------------
265 | >seq133
266 | ------KELEMCQNKEDVVNLFLREAVRylgtKEGLYFRYLDLHRTLALAQTFGIERNDIDGVGIDLAALEPGFLPSQLSKPHLLFSLQEFVKKGLLRSGAVILPLNFRHQTIGVFIF--------------------------
267 | >seq134
268 | ----------------------------APAVFLKYLPNRRALVTSAAHRLPAEAWKGLGLNL-SEEPDFRIGDLRHPEKLAGLGEMVQSLARAKEYWVRPMIIRDQVHGLFLVMGPSGDLPVQQMESTISVM-------
269 | >seq135
270 | -------------------EEISRASQKAPALFLRFMPAYQSLAVTASAQIPLESIRTVGIKLDRYDPRNVSSVLRDPLKFPELQELMGQVFKCSEFYALPFIVGGNPIGLVISFGKSEN-LLKIFSSLVIISELAYN--
271 | >seq136
272 | --------------------------GQAPAVFLRYLPNRRCLVTTAAHRIPAEAWKGLGLNLSEEPDFRI-GDLRHPEKLPGLREMAQSLGQTNELWVRPLVIREEVHGLFVVLA------------------------
273 | >seq137
274 | ---------------SDAVALFIDEANVlieNPIVFFRHAPTYSSIVYSQSAGLERGDYKNVGLNFRELNDRVYPNNLKNPDNIEPLRELLAKIFQVTKYETHPLIIDGDVAGFFVCFTEINEPPMLRlFHAAFEILQMK----
275 | >seq138
276 | ----QVEQIQASMTLPNAVEAYLQAVHVrtgKPVCFFKYLPNRRSLFLAYASGLDLSTYRGLGIDFNKEKFEFQKGLLRDPQNIQVFKDMIQTAFKVERFCTFPVEDANGVISIIVFLKA-----------------------
277 | >seq139
278 | -------------------------VNSAQAVYFKYNAVRKAIVASYGLGVDPALIEGIGLNLERVDPSFQRSEIRVLHERESFQNLLRDVFKVKTFHCHYLSVGGEVIGVFVHWGLSENSE------------------
279 | >seq140
280 | ----------HAKDVHEIIQRFMNSANTafsqAPVIFFRHLQTQDQFVLSQLVGINEPTLKGVQLGLSSEQI---------EQDPDALKDSIMALFEKDGFEFRFLKEDQQIIGLFVILKSFNDPYERKcLSHLFDIFDLSYQQ-
281 | >seq141
282 | ---------LAASSIDRAVEVFLKAMHQtlsQPICLFKYLPARRSLFLTQVVGLPAEPYRGLGVDFSKQDENFSSQSLRQPETLPRLRELMQSAFKAST--------------------------------------------
283 | >seq142
284 | ---------KMAQAKDtvEVIDFFLNQVSEltgKPCIYLKHFSSFSSLVVSNSSILEIEKLKKVGLSFKEEDPKTYHAQVLDPSKMTKLNALMAQAFKVEK--------------------------------------------
285 | >seq143
286 | -------------------------LASCGMVYFKYIANRRVLMATQAHKLDIEW-TGLGVNFNETGDTFRTQHLREPGQIVEVQNMMREAFHAGEFFAFPVETLGEIQGVVAFLRPEPDAATMKmIQD------------
287 | >seq144
288 | -FSEEFMRYQTAKTLDSLLAIFFERVQKisrknYSGLFLKAIPSIDSLAVFGGTAFDWQKWKGYGVKVQ-----SFQQNPRRPLVISEVVEFMREVVGVEQFVSFPFAFENHLEGVFLFWCHQDKPQSRDLE-------------
289 | >seq145
290 | --------------------------------FFKYIANRRVLMAGQAHKLDDFDLNGLGINFNEANSDFRSSQLRDPHGLNELQAIMAEVFNASEFVALPVEALGEIQGIVILLRNDPDPAGQQkLQEWVFLLS------
291 | >seq146
292 | ---------------DGVLQTLLGSLSQlidgKPALYFKFLPAYNSLVVGLSEKIPIDGLRTVGISLEKYNLKEVPTLLTTPEKIPQLNDLMGEVFK-----------------------------------------------
293 | >seq147
294 | ----------KAQSLDELGINLVWSLKNivssrRKGIYFKYLPTYCSLVALGGFNFEDKNISGVGLNFASSKDFNAAQHLHKLMYVPAFLKVVERLFAHTNVTVTTLSCENEVRGVLVCEKPPSR--------------------
295 | >seq148
296 | ---------------DDAIQLFLREAIRylqtKEALYFRYREVHHTLVLTKSVGIDLADVDGVGIDLAEHEPGFVPALLQKPHMLFSLQEFVKKGLLRQHAVILPHIFRGQVLGVFVLPCDKGLARVHD---------------
297 | >seq149
298 | ----------RSTTRDDSIDLFLREATRylqiKEGLFFRYREIYQTLVLTHANGLDPKKFGDVGIDLAAHEPGFVPALLKKPHLLYSVQEFVKKGLSRQSAVVLPFFYRDQIYGVFVLP-------------------------
299 | >seq150
300 | -------KISVITNINDVVDCFVQSLFElinRPAIYLKYVPSHTSLIVTHVAGLDINKFKNAGILF-KEEPQSYLEKIKHPQKFQQLQEFMLHIFQVQSYFPIP---------------------------------------
301 | >seq151
302 | ---------------DETIQIFLETLSRewndVPVLYFKYLPSHASLPLALGAGQKIEKFRGFGVDLRKESPDQIVEFFRAPESSEILKRFMKEVFASDNFTPFTHTTEGEALGLFVALTKTP---------------------
303 | >seq152
304 | -----------AEAVQLVVNELHRHMPFSQVVFLKHIRGRSTLVAESSSGIEMQALRSVGVDLKQTEPSFKEALLLRPEKLVGVTDLVRSGFDNRHFAAFPVVVQKEVWGI-----------------------------
305 | >seq153
306 | ----------------------------RKGIYFKYLPTYCSLVALGGFNFDEHSVSGVGLNFSSSKDFNAAQHLHKLIYVPAFLKIVERLFEHTNVNVKTLVCDGEVRGVLVYDKSPSM------S-------------
307 | >seq154
308 | ----------------------------------------------SGHHVSASSLKGVGFQL-EGTIENLIAQVQNGEIPKQLQTLMREGFQAGQYLCLPVWVSGDLDALLVVWSPSEVLRAEIFANHEALFSILYE--
309 | >seq155
310 | -------------SLTDIVTHLCSEIHKensCDVVYFKYIDSQGTLVAAHSEGLAFETIRGIGIDFISSGKKFFRDQLHLPASLMEMKELVQQVFNEKEF-------------------------------------------
311 | >seq156
312 | ----------------------------------------------------------VGAKLTVDEYRALENLLRARTLPESLRKVIQEAFRIEKFTTFPLFLKNQVEGALVFWGLDLN---EGDWGEFLIFQLCYQN-
313 | >seq157
314 | ----------NCDSESEAVKHCLGEIARslgrGQIVFFRFIRGRATLVAEAANGISAEAISNIGVELKKTEPKFNEKLLQRPERLLGLLDLVRNGFMQRQFAAFTVEIDRVPEGVILIL-------------------------
315 | >seq158
316 | --------------------------------------------------LDMEELKGVGCRLVESEANDPVSFFAQGGIPAELVGLAREGLNANEPIFRPIFVLGELDGFMIFWSRQNEIYPEELDNDLSLFQLMYER-
317 | >seq159
318 | ---------------QNAIELYMKEVSRylkgAAVIYFKYIPGYESLVVTQSVGHDLNELSGVGLNLLEEEKNFDQEKLKFPQQLNSIRRLMKE--------------------------------------------------
319 | >seq160
320 | -------------------------------LFFKHLPAYFTLSVTHSAVAPMAQLRGVGLNLKDELNADYLQLLKSPSEMAGLKTLLFELFGAREFIGYPIETDEGISGLVVALHGLEDPASRRLFEAFTrLFEMQY---
321 | >seq161
322 | ---------------------------KRKGVFFKYLPTYCSLVAMGSFFFDSpKKLNGVGLNFSKSVKFKPSEHLQHGLKVPAFKKLCEKIFGHKNLNIRVLSVDHEVQGILVYEKPPANSL------------------
323 | >seq162
324 | ------------SSVDECVQVFLASAAQalgsCPAVFFRYIANRRVLLAAYGEQMEAVDLSGLGLDLNETAPGFRTVQLREPMRIIPFVEMVKVVFAVPEFFAWPIHALNEIQGLACF--------------------------
325 | >seq163
326 | ---------------EDIVNAFLTYLSElvdgKICLFLKFYPAKSALVVRNIKGHDLEKIyteqdisdfKNIGMSLGPASEKDIVSIVARIARHPSLKTLVTKLFNTSKYMAYPLIIRDTPIGVTIVVDEMTLSERDDkiLKQYLNQLEISY---
327 | >seq164
328 | ----------------------------RKGLYFKYLPTYCSLVALDGFNFSNKKFNGVGLNFSSSKDFDANQHLNKLKQVPGFIKVIERVFGHQNINLKTLECDGEAKGVLVYEQAPR---------------------
329 | >seq165
330 | ---------------SDTIQVFLEHTSQlaenSKVLFLRYLPSYYSLLLSHAASYQMEEGKKIGLNLKEIDPKKIMDILRLPQEMDLLKNLLTG--------------------------------------------------
331 | >seq166
332 | ---RLMQAFSKAKDIDAVIQIYLEHTSQiignKPIVFFTHLSSYLSLLVSHVVGYEKEALRNVGVNLKSVESKEYMNLL-----------------------------------------------------------------
333 | >seq167
334 | -------------------VMMSEQTTKAPAVFLRHLPNRRCLVTRAAHRLPAEAWKSLGLHL-NEEPDFCLSDLRHPEKLSGLKEMGQTLVGHDEIWVRPLILRDEVYGLFVVFSALIDLPMNRLESIVK---------
335 | >seq168
336 | ---------------------------KilpekRKGVFFKYLPTYCSLVAMGSFFFENpKKINGLGLNFSKSVKFKPRQHLQYGLKVPAFTKLCEKLFGHKKLNIRVLSADQDVKGILVYEKPPANSL------------------
337 | >seq169
338 | ---------KKAKSLDELGINLVSSLNKivlpgRKGVYFKYLPTYCSLVALGGFNFESKKVTGVGLNFSTSKDFDASKHLQQLMYVPAFLKVVERLFEHTDVTVRTFDCDRESKGVVVYED------------------------
339 | >seq170
340 | ---------------------------TqgRKGIYFKYLPTYCSLVALGGFNFKNTKVNGVGLNFSSSKDFNASQHLQKLLHVPAFIKVVERLFEHTEVVVRTFECDGETKGVAVYEKPPGG------SSDIEVLSLC----
341 | >seq171
342 | ---------------DQCVQLFMESVSRvfsdVPILYFRYVASHMSLLVSQAVWLPIEKIRGIGVDLKNEDPARLPECFRDPSRLEPLKTLVQQVFR-----------------------------------------------
343 | >seq172
344 | -------------------------------IFFKYLPAHLSLVTSHASKIPLEQIKNLGINLSQLSQVDATKISEMLLQPSSlpgLPDLMREVFQLQAYEAIPFVHQN----------------------------------
345 | >seq173
346 | --------------------------------------------------INKETLKNIGVSIKDLDQKDYVEKLANPMELVGLKTLMKDFFQTTEYFAVPVEEDSAIAGIIVVFDPMKDVSVRRLFDSF----------
347 | >seq174
348 | --------------ADDCIQIFLQSCSNmlgsCGVIYMKYIANRRVLMTTLAHRIDAEW-NGIGVNFNETTGDSFR--TAHLREPNNipeVKQMIHEVFHTEEFFAHP---------------------------------------
349 | >seq175
350 | ---EAILNLKKAQSLDELgvnLSWSLNHIVIdeRKGMYFKYLPTYCSLVSIGGFNLKDKKTNGVGLNF--SSSKDFNSLMHLREVLNVpaFENIVKKFFEHTDVQTRLFECDGGVKSLLVYERAPGGS-------------------
351 | >seq176
352 | -----------RNSVNECVQDFLDFGSKllgdCGAIYLKCLPLRK--VLSATHGVALENWKGVGVNLADESHFTWGALQEPQNVP-AIREMVREIFNRSDFQAFTFKVAKEVNGIALFFS------------------------
353 | >seq177
354 | -------------SIDDAMNVFLKNVSSvlgsPPVLYMKYIANRRVLMASQSQNLESFDLNGLGVNFNELNANFRASQLHDPQAIPEVGSLVKEVF------------------------------------------------
355 | >seq178
356 | ------------NSADECLRVFLESASHslggCAAVYFRYIPNRRVLLAGHAINTHGIELKGLGINFNEVAPGFRTAQLRDPMGIPEFSEMIREIFGVSEFMAWPVE-------------------------------------
357 | >seq179
358 | --------LKKCETKKDLIKCFLVEFSRyykgAPVVFYKYVRSYKSLVSSMHLGVKSFKS-GEIIKLTAKEDGQLNNDKLDSIYLESLENI---AFDDEeGVYFLPLSILGELKGLFVFSRQEA---------------------
359 | >seq180
360 | --------------------------SKRKGIFLKYLPTYCSLVATSSFFFDSpRKLNGLGLNFSKAIKFNPKEHLQQGLKIPALAKLCYKTFGHRMLNSRVLTSDKDIQGILVYEKP-----------------------
361 | >seq181
362 | ---------------DEIIETFLSYTSEllggATCVFLKYYPQKTALVARHVQcrsehcpytKEQLEGIKNVGMSLGVAGEKDIVSVISKITNHPSLKTLVYKLFNTSKYIAFPLIIRDTPLGVTLIVdvNSLGSKEDKIVQQ------------
363 | >seq182
364 | ---------------DEIIETFLSYTSSllggSTCVFLKYFPQKTALVVRHIQckgqcsfsKEQMEALKNVGMSLGVAGEKDIVSIVSKISNHPSLKTLVNKLFNTTKYLAYPLIIRDTPLGVTlIVDQNsLDPKEEKIVQQYLN---------
365 | >seq183
366 | -----------CTSFTDSIDCFMSEISRylknTSVMYLKYVPAYRSLVTSRSVHLPDF-TSGESFDL-KTLFKENNLNETNFEKSNEFRDKIYDYTYWDNFSIIALRINGDVKGLFIIKTDFVNE-------------------
367 | >seq184
368 | ------------------------------------------LVLTQSTGVDSNEFDGVGIDLAEHEPGFVPALLQRPHLLFSLQEFVKKGLGRQNAVILPLIFRGHVLGIFVLPEDMDKAMRH----------------
369 | >seq185
370 | -------------------------FNCQSLVYFRYLKSYSSLLVTHSEGLKFSDLRGKGISFSTNQNFVPERDLKRIDSNPLFYELVRKLIPNQAYTSFLFEACGEPKGVFVLANA-----------------------
371 | >seq186
372 | -----------CHSVEDAIQDWLNKINKlyqdTPTVFFRYIPNHSHLVLSQCAGLDLQKVRGIGVPLAGLSLKEQKYFYSHVRFLANLRDLVKGAFNVSEFEFRELVTDKAVLGLCVIFKNLKNESEKRfFNDSIELLNLVA---
373 | >seq187
374 | ------------------------------WIawYFKFMPEVQAFVVTQYRRFNDKQTFPFsSFKPQTVSVTDLFALLHAGDRNVELFSFVAKHFDITNLQVFPISYSNLVDGIFCFIGQSEEK-------------------
375 | >seq188
376 | ----------------------VRYYKTDLALYFRYNPGAGTLVVMRASGLPLEHFQAVGIHLRKNE-PGFTEEIL--HQPNRiriLREFVIAGMNRLDFVAFPHVENKIVRGLFVVPAKKNQ--------------------
377 | >seq189
378 | -------------------------------------------LVSHTACLPIDKFRGIGVQLHSQSAIDLAGQLNEPMRIDGLRQLVAEVFRRDNFFAFTHSSEGEILGVCIV--------------------------
379 | >seq190
380 | --------------VHETIQVFLDHVSQlfddSRVVFLRYLPAYYSLMVSHTAKIavqqpptaNPEEARKVGINLKDIDPKTVLDQLKNPQTFVPLSELMSE--------------------------------------------------
381 | >seq191
382 | ------------TTRPQMIEKWMREALRvhctNEILYMSYLGTKKTLIVTQSLGFAADELDDVGLDLAKEEPGFEESMLQRPEKLWALHQFVTKGLQRREALYYSLMHHGQILGVFILPAKEGE--------------------
383 | >seq192
384 | -----------------IIDTFLSYMSElingKTCIFLKYYPAKTALIIKHMAGKNigtlysAEQIEGfknVGMSLGVAGEKDIVSIVSRIANHPSLKTLVSKLFNTTKYLAYPLIIRDTPIGVtlLVDIDSVNEQDEKIIKQYLNQFEISYD--
385 | >seq193
386 | -----------------------------------------------------------------EESKDPEKLFSEGKAPESILKLTKEGLGAEEPVMKFVSIRGSLDGIMAVWGQGSPVFWDQFDHEFSLFNLLYE--
387 | 
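
In the a3m alignment above, lowercase letters mark insertions relative to the query and `-` marks gaps, so deleting the lowercase letters should leave every row with the same length. Below is a minimal sketch of that clean-up. The helper name `parse_a3m_text` and the toy sequences are illustrative assumptions, not part of the repository or of T1001.a3m.

```python
import io
import string

# Translation table that deletes every lowercase letter (a3m insertion states).
_DELETE_LOWER = str.maketrans("", "", string.ascii_lowercase)


def parse_a3m_text(handle):
    """Return the aligned sequences from an a3m-formatted file-like object."""
    seqs = []
    for line in handle:
        line = line.rstrip()
        # Skip blank lines and '>' label lines.
        if not line or line.startswith(">"):
            continue
        seqs.append(line.translate(_DELETE_LOWER))
    return seqs


# Toy alignment (made up for illustration): two hits carry lowercase insertions.
example = io.StringIO(">query\nMKV-LA\n>hit1\nMKvvV-LA\n>hit2\n-KV-LgA\n")
aligned = parse_a3m_text(example)
```

After the lowercase letters are removed, all three toy rows have the query's length, which is what lets the rows be stacked into a rectangular array.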


--------------------------------------------------------------------------------
/tests/graphmodel_test/T1001.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/protein-sequence-models/af695772c4a1c056d930c95ec7e6428aa042f5cd/tests/graphmodel_test/T1001.npz


--------------------------------------------------------------------------------
/tests/graphmodel_test/T1001_loader.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import numpy as np
 3 | import string
 4 | from typing import Iterable
 5 | 
 6 | def parse_a3m(filename):
 7 |     seqs = []
 8 |     table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
 9 | 
10 |     # read the file line by line; the context manager closes the handle
11 |     with open(filename, "r") as f:
12 |         for line in f:
13 |             # skip labels; strip lowercase letters and trailing whitespace
14 |             if line[0] != '>':
15 |                 seqs.append(line.rstrip().translate(table))
16 |     return seqs
17 | 
18 | def logits2value(logits, bins):
19 |     preds = np.argmax(logits, axis=2)
20 |     retval = np.zeros(preds.shape)
21 |     for i in range(len(preds)):
22 |         for j in range(len(preds)):
23 |             retval[i,j] = bins[preds[i,j]]
24 |             
25 |     return retval
26 | 
27 | def loadT1001(preprocess=True):
28 |     sample = np.load('T1001.npz')
29 |     sample_dist = sample['dist']
30 |     sample_omega = sample['omega']
31 |     sample_theta = sample['theta']
32 |     sample_phi = sample['phi']
33 |     seq = parse_a3m('T1001.a3m')[0]
34 |     
35 |     if not preprocess:
36 |         return sample_dist, sample_omega, sample_theta, sample_phi, seq
37 |     
38 |     else:
39 |         dist = logits2value(sample_dist, [None] + list(np.linspace(2,20,37)))
40 |         omega = logits2value(sample_omega, [None] + list(np.linspace(-180,180, 24))) 
41 |         theta = logits2value(sample_theta, [None] + list(np.linspace(-180,180, 24)))
42 |         phi = logits2value(sample_phi, [None] + list(np.linspace(0,180, 24)))
43 |         
44 |         return dist, omega, theta, phi, seq
45 | 
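
The double loop in `logits2value` can also be written as a single indexing operation. The sketch below is one possible vectorized form, not the repository's code. It assumes logits of shape `(L, L, n_bins)` and uses NaN where the loader puts `None` in bin 0. The name `logits_to_values` and the random toy logits are hypothetical.

```python
import numpy as np


def logits_to_values(logits, bin_values):
    """Map logits of shape (L, L, n_bins) to the value of the most likely bin."""
    preds = np.argmax(logits, axis=2)             # (L, L) integer bin indices
    lookup = np.asarray(bin_values, dtype=float)  # one value per bin
    return lookup[preds]                          # integer-array indexing replaces the loop


# Bin 0 carries no value; NaN stands in for the loader's None (an assumption).
dist_bins = np.concatenate([[np.nan], np.linspace(2, 20, 37)])

# Random toy logits, only to exercise the function.
rng = np.random.default_rng(0)
toy_logits = rng.normal(size=(4, 4, 38))
toy_dist = logits_to_values(toy_logits, dist_bins)
```

Indexing with the `(L, L)` integer array returns an `(L, L)` float array, so the result does not depend on the map being square the way the nested `range(len(preds))` loops do.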


--------------------------------------------------------------------------------
/tests/graphmodel_test/graphmodel_decoder_test:
--------------------------------------------------------------------------------
 1 | from T1001_loader import *
 2 | 
 3 | import json, time, os, sys, glob
 4 | import torch
 5 | import torch.nn as nn
 6 | 
 7 | sys.path.insert(0, '../..')
 8 | from sequence_models.graphmodel_utils import *
 9 | from sequence_models.utils import Tokenizer
10 | 
11 | # load features 
12 | dist, omega, theta, phi, seq = loadT1001()
13 | dist = torch.from_numpy(dist)
14 | omega = torch.from_numpy(omega)
15 | theta = torch.from_numpy(theta)
16 | phi = torch.from_numpy(phi)
17 | 
18 | # process features
19 | V = get_node_features(omega, theta, phi)
20 | E_idx = get_k_neighbors(dist, 10)
21 | E = get_edge_features(dist, omega, theta, phi, E_idx)
22 | mask = get_mask(E)
23 | E = replace_nan(E)
24 | L = len(seq)
25 | S = get_S_enc(seq, tokenizer)  # NOTE: `tokenizer` is never defined in this script; a Tokenizer instance must be created first
26 | 
27 | # reshape 
28 | V = V.view(1,140,10).float()
29 | E = E.view(1,140,10,6).float()
30 | E_idx = E_idx.view(1,140,10)
31 | mask = mask.view(1,140)
32 | S = S.view(1,140).long()
33 | L = [140]
34 | 
35 | decoder = Struct2Seq_decoder(num_letters=20, 
36 |             node_features=10,
37 |             edge_features=6, 
38 |             hidden_dim=128,
39 |             k_neighbors=30,
40 |             protein_features='full',
41 |             dropout=0.10,
42 |             use_mpnn=False)
43 | 
44 | with torch.no_grad():
45 |     decoder.eval()
46 |     output = decoder(V, E, E_idx, S, L,mask)


--------------------------------------------------------------------------------
/tests/loss_test.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | 
 4 | from sequence_models.losses import MaskedCrossEntropyLoss
 5 | 
 6 | 
 7 | if torch.cuda.is_available():
 8 |     device = torch.device('cuda')
 9 | else:
10 |     device = torch.device('cpu')
11 | 
12 | def test_masked_cel():
13 |     n = 5
14 |     ell = 7
15 |     t = 11
16 |     scores = torch.randn(n, ell, t).to(device)
17 |     targets = torch.randint(t, (n, ell)).to(device)
18 |     mask = torch.randint(2, (n, ell), device=device).bool()
19 | 
20 |     weights = None
21 |     mcel = MaskedCrossEntropyLoss(weight=weights, reduction='none')
22 |     loss = mcel(scores, targets, mask)
23 |     assert loss.allclose(mcel(scores, targets, mask.unsqueeze(-1)))
24 |     cel = nn.CrossEntropyLoss(weight=weights, reduction='none')
25 |     full_loss = cel(scores.view(-1, t), targets.view(-1))
26 |     assert loss.allclose(full_loss.masked_select(mask.view(-1)))
27 | 
28 |     mcel = MaskedCrossEntropyLoss(weight=weights, reduction='mean')
29 |     loss = mcel(scores, targets, mask)
30 |     assert loss.allclose(full_loss.masked_select(mask.view(-1)).mean())
31 | 
32 |     weights = torch.rand(t, device=device)
33 |     mcel = MaskedCrossEntropyLoss(weight=weights, reduction='none')
34 |     loss = mcel(scores, targets, mask)
35 |     assert loss.allclose(mcel(scores, targets, mask.unsqueeze(-1)))
36 |     cel = nn.CrossEntropyLoss(weight=weights, reduction='none')
37 |     full_loss = cel(scores.view(-1, t), targets.view(-1))
38 |     assert loss.allclose(full_loss.masked_select(mask.view(-1)))
39 | 
40 |     mcel = MaskedCrossEntropyLoss(weight=weights, reduction='mean')
41 |     loss2 = mcel(scores, targets, mask)
42 |     idx = targets.masked_select(mask).view(-1)
43 |     assert loss2.allclose(full_loss.masked_select(mask.view(-1)).sum() / weights[idx].sum())
44 | 
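
The assertions in `test_masked_cel` rely on two conventions. Only positions where the mask is true contribute to the loss. With class weights and `reduction='mean'`, the sum of weighted losses is divided by the sum of the weights of the selected targets. The NumPy reference below sketches that arithmetic under those assumptions. It is not the implementation of `MaskedCrossEntropyLoss`, and the function name and toy inputs are hypothetical.

```python
import numpy as np


def masked_cross_entropy(scores, targets, mask, weight=None, reduction="mean"):
    """Cross-entropy over the positions where mask is True.

    scores: (n, ell, t) floats; targets: (n, ell) ints; mask: (n, ell) bools.
    """
    t = scores.shape[-1]
    keep = mask.reshape(-1)
    flat_scores = scores.reshape(-1, t)[keep]
    flat_targets = targets.reshape(-1)[keep]
    # Numerically stable log-softmax over the class dimension.
    shifted = flat_scores - flat_scores.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(flat_targets)), flat_targets]
    w = np.ones(t) if weight is None else np.asarray(weight, dtype=float)
    per_position = w[flat_targets] * nll
    if reduction == "none":
        return per_position
    # Weighted mean: normalize by the total weight of the selected targets.
    return per_position.sum() / w[flat_targets].sum()


# With all-zero scores every class has probability 1/t, so each loss is log(t).
scores = np.zeros((2, 3, 4))
targets = np.array([[0, 1, 2], [3, 0, 1]])
mask = np.array([[True, False, True], [False, True, True]])
mean_loss = masked_cross_entropy(scores, targets, mask)
weighted_loss = masked_cross_entropy(scores, targets, mask, weight=[0.5, 1.0, 2.0, 4.0])
per_position = masked_cross_entropy(scores, targets, mask, reduction="none")
```

Because every per-position loss equals `log(4)` in this toy case, the weighted mean is also `log(4)` whatever the class weights are, which makes the normalization easy to check by hand.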


--------------------------------------------------------------------------------
/tests/pdb_utils_test.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import numpy as np
 3 | 
 4 | from sequence_models import pdb_utils
 5 | 
 6 | ex_dir = os.path.join(
 7 |     os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "examples"
 8 | )
 9 | assert os.path.isdir(ex_dir)
10 | 
11 | orig_coords, orig_atoms, orig_valid = pdb_utils.parse_PDB(
12 |     os.path.join(ex_dir, "gb1_a60fb_unrelaxed_rank_1_model_5.pdb")
13 | )
14 | gz_coords, gz_atoms, gz_valid = pdb_utils.parse_PDB(
15 |     os.path.join(ex_dir, "gb1_a60fb_unrelaxed_rank_1_model_5.pdb.gz")
16 | )
17 | 
18 | assert np.all(np.isclose(orig_coords, gz_coords))
19 | assert orig_atoms == gz_atoms
20 | assert np.all(orig_valid == gz_valid)
21 | 
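
The test above checks that a plain `.pdb` file and its `.pdb.gz` copy parse to the same coordinates, atoms, and validity flags. One common way to support both is to choose the opener by file extension. The sketch below uses only the standard library and is an assumption about how such a reader could work, not a description of `pdb_utils.parse_PDB`. The helpers `open_text` and `read_lines` are hypothetical.

```python
import gzip
import os
import tempfile


def open_text(path):
    """Open a plain or gzip-compressed text file for reading (hypothetical helper)."""
    if path.endswith(".gz"):
        return gzip.open(path, "rt")
    return open(path, "r")


def read_lines(path):
    """Return the file's lines without trailing newlines."""
    with open_text(path) as handle:
        return handle.read().splitlines()


# Round-trip check on a throwaway two-line file written in both forms.
content = "ATOM      1  N   MET A   1\nEND\n"
with tempfile.TemporaryDirectory() as tmp:
    plain = os.path.join(tmp, "toy.pdb")
    packed = plain + ".gz"
    with open(plain, "w") as handle:
        handle.write(content)
    with gzip.open(packed, "wt") as handle:
        handle.write(content)
    plain_lines = read_lines(plain)
    packed_lines = read_lines(packed)
```

Opening the archive in text mode (`"rt"`) means the downstream parser sees `str` lines in both cases and needs no separate code path for compressed input.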


--------------------------------------------------------------------------------
/tests/vae_test.py:
--------------------------------------------------------------------------------
  1 | import numpy as np
  2 | import torch.nn as nn
  3 | import torch.nn.functional as F
  4 | import torch
  5 | 
  6 | from sequence_models.losses import VAELoss, SequenceCrossEntropyLoss
  7 | from sequence_models.vae import FCDecoder, FCEncoder, VAE, Conductor
  8 | 
  9 | if torch.cuda.is_available():
 10 |     device = torch.device('cuda')
 11 | else:
 12 |     device = torch.device('cpu')
 13 | 
 14 | N = 2
 15 | L = 5
 16 | D = 3
 17 | 
 18 | src = [
 19 |     [0, 1, 2, 0, 1],
 20 |     [0, 1, 1, 1, 0]
 21 | ]
 22 | src = torch.LongTensor(src)
 23 | 
 24 | n_hidden = np.random.choice(np.arange(3, 10))
 25 | d_h = list(np.random.choice(np.arange(1, 10), size=n_hidden))
 26 | d_z = np.random.choice(np.arange(2, 10))
 27 | encoder = FCEncoder(L, D, d_h, d_z)
 28 | decoder = FCDecoder(L, D, d_h[::-1], d_z)
 29 | 
 30 | 
 31 | def test_encoder():
 32 |     assert encoder.embedder.num_embeddings == D
 33 |     assert encoder.embedder.embedding_dim == d_h[0]
 34 |     mu, logvar = encoder(src)
 35 |     nm, dm = mu.size()
 36 |     nv, dv = logvar.size()
 37 |     assert nm == nv
 38 |     assert nm == N
 39 |     assert dm == d_z
 40 |     assert dv == d_z
 41 | 
 42 | 
 43 | def test_decoder():
 44 |     z = torch.Tensor(np.random.random((N, d_z)))
 45 |     p = decoder(z)
 46 |     n, ell, dp = p.size()
 47 |     assert n == N
 48 |     assert ell == L
 49 |     assert dp == D
 50 | 
 51 | 
 52 | def test_vae():
 53 |     vae = VAE(encoder, decoder)
 54 |     p, mu, logvar = vae(src)
 55 |     mu2, logvar2 = encoder(src)
 56 |     assert torch.allclose(mu2, mu)
 57 |     assert torch.allclose(logvar, logvar2)
 58 |     # Check shape of p
 59 |     n, ell, d = p.size()
 60 |     assert n == N
 61 |     assert ell == L
 62 |     assert d == D
 63 |     # Test encoder
 64 |     m1, s1 = vae.encode(src)
 65 |     m2, s2 = encoder(src)
 66 |     assert torch.allclose(m1, m2)
 67 |     assert torch.allclose(s1, s2)
 68 |     # Test decode
 69 |     z = torch.Tensor(np.random.random((N, d_z)))
 70 |     p1 = decoder(z)
 71 |     p2 = vae.decode(z)
 72 |     assert torch.allclose(p1, p2)
 73 | 
 74 | 
 75 | def test_loss():
 76 |     r_loss_func = SequenceCrossEntropyLoss()
 77 |     vae = VAE(encoder, decoder)
 78 |     p, mu, logvar = vae(src)
 79 |     r_loss = r_loss_func(p, src, reduction='none')
 80 |     kl_loss = -0.5 * (1 + logvar - mu ** 2 - logvar.exp())
 81 |     beta = torch.rand(1)
 82 | 
 83 |     # Without classweights or sample_weights
 84 |     vloss = VAELoss(class_weights=None)
 85 |     # With reduction
 86 |     loss, r, k = vloss(p, src, mu, logvar, beta=beta)
 87 |     assert torch.allclose(r_loss.sum(dim=1).mean(dim=0) + beta * kl_loss.sum(dim=1).mean(dim=0), loss)
 88 |     assert torch.allclose(r, r_loss.sum(dim=1).mean())
 89 |     assert torch.allclose(k, kl_loss.sum(dim=1).mean())
 90 |     # Without reduction
 91 |     loss, r, k = vloss(p, src, mu, logvar, beta=beta, reduction='none')
 92 |     assert torch.allclose(r_loss.sum(dim=1) + beta * kl_loss.sum(dim=1), loss)
 93 |     assert torch.allclose(r, r_loss.sum(dim=1))
 94 |     assert torch.allclose(k, kl_loss.sum(dim=1))
 95 | 
 96 |     # With class_weights and sample_weights
 97 |     cw = torch.rand(3)
 98 |     sw = torch.rand((N, 1))
 99 |     r_loss_func = SequenceCrossEntropyLoss(weight=cw)
100 |     r_loss = r_loss_func(p, src, reduction='none')
101 |     r_loss *= sw
102 |     r_loss = r_loss.sum(dim=1) / r_loss_func.class_weights[src].sum()
103 |     kl_loss *= sw
104 |     vloss = VAELoss(class_weights=cw)
105 |     # With reduction
106 |     loss, r, k = vloss(p, src, mu, logvar, beta=beta, sample_weights=sw)
107 |     assert torch.allclose(r_loss.mean() + beta * kl_loss.sum(dim=1).mean(dim=0), loss)
108 |     assert torch.allclose(r, r_loss.mean())
109 |     assert torch.allclose(k, kl_loss.sum(dim=1).mean())
110 |     # Without reduction
111 |     loss, r, k = vloss(p, src, mu, logvar, beta=beta, sample_weights=sw, reduction='none')
112 |     assert torch.allclose(r_loss + beta * kl_loss.sum(dim=1), loss)
113 |     assert torch.allclose(r, r_loss)
114 |     assert torch.allclose(k, kl_loss.sum(dim=1))
115 | 
116 | 
117 | def test_conductor():
118 |     b = 5
119 |     dz = 8
120 |     d_out = 4
121 |     n_f = [np.random.randint(64, 128) for _ in range(np.random.randint(3, 8))]
122 |     layer = Conductor(dz, n_f, d_out).to(device)
123 |     z = torch.randn(b, dz).to(device)
124 |     out = layer(z)
125 |     assert out.shape == (b, 2 ** (len(n_f) + 2), d_out)
126 |     z = torch.randn(b, dz, 1).to(device)
127 |     out = layer(z)
128 |     assert out.shape == (b, 2 ** (len(n_f) + 2), d_out)
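
The `kl_loss` expression in `test_loss` is the closed-form KL divergence between a diagonal Gaussian with mean `mu` and log-variance `logvar` and a standard normal prior. The unweighted assertions combine it with the reconstruction term as a per-sample sum followed by a batch mean. The NumPy sketch below restates that arithmetic. The function names and toy values are hypothetical, and it is not the implementation of `VAELoss`.

```python
import numpy as np


def gaussian_kl(mu, logvar):
    """Per-dimension KL( N(mu, exp(logvar)) || N(0, 1) )."""
    return -0.5 * (1.0 + logvar - mu ** 2 - np.exp(logvar))


def vae_objective(recon_per_sample, mu, logvar, beta=1.0):
    """Batch mean of reconstruction loss plus beta times the summed KL term."""
    kl_per_sample = gaussian_kl(mu, logvar).sum(axis=1)
    return (recon_per_sample + beta * kl_per_sample).mean()


# A posterior equal to the prior has zero KL; shifting one mean by 1 adds 0.5.
kl_at_prior = gaussian_kl(np.zeros((2, 3)), np.zeros((2, 3)))
kl_shifted = gaussian_kl(np.array([[1.0, 0.0]]), np.zeros((1, 2)))

# Toy objective: reconstruction losses of 2.0 and 4.0 with zero KL give a mean of 3.0.
toy_loss = vae_objective(np.array([2.0, 4.0]), np.zeros((2, 3)), np.zeros((2, 3)), beta=0.5)
```

Since the mean is linear, `(recon + beta * kl).mean()` equals `recon.mean() + beta * kl.mean()`, which is the form the first unweighted assertion in the test uses.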


--------------------------------------------------------------------------------