├── biobert ├── biobert_pretrain │ └── README.md └── requirement.txt ├── .gitmodules ├── src ├── bioseq.h ├── bioseq.cpp ├── omp.cpp ├── tokenize.cpp ├── fxstats.cpp ├── kseq.h ├── poa.cpp ├── alphabet.h └── tokenize.h ├── .idea ├── vcs.xml ├── .gitignore ├── inspectionProfiles │ └── profiles_settings.xml ├── modules.xml ├── misc.xml └── bioseq.iml ├── requirements.txt ├── GAT ├── __init__.py ├── utils.py ├── secondary_structure.py ├── solvent_accessibility.py ├── gat_module.py ├── embedding_module.py ├── rna_gat_model.py └── training.py ├── graphseq ├── visualization.py ├── squence_encoders │ ├── lstm.py │ ├── bilstm.py │ ├── attlstm.py │ ├── bert.py │ └── xlstm.py ├── data_preparation.py ├── graph_encoders │ ├── graphsage.py │ ├── gcn.py │ └── gat.py ├── distillation.py ├── evaluation.py └── train.py ├── scripts ├── makeflatfile └── flatten_swiss ├── bioseq ├── tax.py ├── softmax.py ├── annotations.py ├── lem.py ├── poa_util.py ├── blosum.py ├── loaders.py ├── cnnencoder.py ├── __init__.py └── hattn.py ├── POA_README.md ├── setup.py ├── README.md └── training ├── trainh.py ├── rcompute.py ├── compute.py └── cnnpretrain.py /biobert/biobert_pretrain/README.md: -------------------------------------------------------------------------------- 1 | Files I used to construct pretrain maodel for biobert 2 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "spoa"] 2 | path = spoa 3 | url = https://github.com/rvaser/spoa/ 4 | -------------------------------------------------------------------------------- /src/bioseq.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include "pybind11/numpy.h" 4 | namespace py = pybind11; 5 | void init_tokenize(py::module &m); 6 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops 2 | numpy 3 | torch 4 | x-transformers>=0.19 5 | entmax 6 | local-attention 7 | numpy 8 | pybind11==2.7.1 9 | feedback-transformer-pytorch 10 | memcnn 11 | e2cnn 12 | pysam 13 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /src/bioseq.cpp: -------------------------------------------------------------------------------- 1 | #include "bioseq.h" 2 | 3 | void init_omp_helpers(py::module &m); 4 | 
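// Each init_* function declared here is implemented in its own translation unit (src/tokenize.cpp, src/omp.cpp, src/fxstats.cpp, src/poa.cpp) and registers its bindings on the cbioseq module defined below.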
void init_fxstats(py::module &m); 5 | void init_poa(py::module &m); 6 | PYBIND11_MODULE(cbioseq, m) { 7 | init_tokenize(m); 8 | init_omp_helpers(m); 9 | init_fxstats(m); 10 | init_poa(m); 11 | } 12 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /GAT/__init__.py: -------------------------------------------------------------------------------- 1 | from .embedding_module import EmbeddingModule 2 | from .gat_module import GraphAttentionTransformer 3 | from .rna_gat_model import RNA_GAT_Model 4 | from .training import train_model, evaluate_model, fine_tune_model 5 | from .secondary_structure import SecondaryStructurePredictor 6 | from .solvent_accessibility import SolventAccessibilityPredictor 7 | from .utils import visualize_predictions 8 | -------------------------------------------------------------------------------- /graphseq/visualization.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import matplotlib.pyplot as plt 3 | 4 | def visualize_graph(graph, title): 5 | pos = nx.spring_layout(graph) 6 | labels = nx.get_node_attributes(graph, 'nucleotide') 7 | plt.figure(figsize=(8, 6)) 8 | nx.draw(graph, pos, with_labels=True, labels=labels, node_color='skyblue', node_size=1500, edge_color='black', linewidths=1, font_size=15) 9 | plt.title(title) 10 | plt.show() 11 | -------------------------------------------------------------------------------- /scripts/makeflatfile: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from multiprocessing import Pool, cpu_count 3 | import sys 4 | import os 5 | 6 | if __name__ == "__main__": 7 | from bioseq import FlatFile 8 | from argparse import ArgumentParser as AP 9 | ap = AP() 10 | ap.add_argument("--threads", "-p", default=1, type=int) 11 | ap.add_argument("files", nargs="+") 12 | args = ap.parse_args() 13 | mapper = Pool(args.threads).map if args.threads > 1 else map 14 | FFs = list(mapper(FlatFile, args.files)) 15 | -------------------------------------------------------------------------------- /.idea/bioseq.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 14 | -------------------------------------------------------------------------------- /GAT/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | def visualize_predictions(model, data_loader): 4 | model.eval() 5 | with torch.no_grad(): 6 | for data in data_loader: 7 | x, edge_index, labels = data.x, data.edge_index, data.y 8 | out = model(x, edge_index) 9 | predictions = out.argmax(dim=1) 10 | plt.figure(figsize=(10, 5)) 11 | plt.plot(predictions.cpu().numpy(), label='Predictions') 12 | plt.plot(labels.cpu().numpy(), label='Ground Truth') 13 | plt.legend() 14 | plt.show() 15 | -------------------------------------------------------------------------------- /biobert/requirement.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.5.4 2 | async-timeout==3.0.1 3 | attrs==18.2.0 4 | boto3==1.9.105 5 | botocore==1.12.105 6 | certifi==2018.11.29 7 | chardet==3.0.4 8 | Click==7.0 9 | docutils==0.14 10 | h11==0.8.1 11 | httptools==0.0.13 12 | 
idna==2.8 13 | idna-ssl==1.1.0 14 | jmespath==0.9.4 15 | multidict==4.5.2 16 | numpy==1.16.2 17 | python-dateutil==2.8.0 18 | pytorch-pretrained-bert==0.6.1 19 | regex==2019.2.21 20 | requests==2.21.0 21 | s3transfer==0.2.0 22 | six==1.12.0 23 | starlette==0.11.3 24 | torch==1.0.1.post2 25 | tqdm==4.31.1 26 | typing-extensions==3.7.2 27 | urllib3==1.24.1 28 | uvicorn==0.4.6 29 | uvloop==0.12.1 30 | websockets==7.0 31 | yarl==1.3.0 32 | -------------------------------------------------------------------------------- /graphseq/squence_encoders/lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class LSTM(nn.Module): 5 | def __init__(self, input_size, hidden_size, num_layers): 6 | super(LSTM, self).__init__() 7 | self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True) 8 | 9 | def forward(self, x): 10 | out, _ = self.lstm(x) 11 | return out 12 | 13 | def save_lstm_model(model, filepath): 14 | torch.save(model.state_dict(), filepath) 15 | 16 | def load_lstm_model(filepath, input_size, hidden_size, num_layers): 17 | model = LSTM(input_size, hidden_size, num_layers) 18 | model.load_state_dict(torch.load(filepath)) 19 | return model -------------------------------------------------------------------------------- /graphseq/squence_encoders/bilstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class BiLSTM(nn.Module): 5 | def __init__(self, input_size, hidden_size, num_layers): 6 | super(BiLSTM, self).__init__() 7 | self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True) 8 | 9 | def forward(self, x): 10 | out, _ = self.lstm(x) 11 | return out 12 | 13 | def save_bilstm_model(model, filepath): 14 | torch.save(model.state_dict(), filepath) 15 | 16 | def load_bilstm_model(filepath, input_size, hidden_size, num_layers): 17 | model = BiLSTM(input_size, hidden_size, num_layers) 18 | model.load_state_dict(torch.load(filepath)) 19 | return model 20 | -------------------------------------------------------------------------------- /bioseq/tax.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | # Simple taxonomic utilities 4 | 5 | def skipgt(x): 6 | return x[x.startswith(">"):] 7 | 8 | 9 | def get_qstr(path): 10 | import gzip 11 | with gzip.open(path, "rt") as gfp: 12 | return skipgt(next(gfp).split(" ")[0]) 13 | 14 | 15 | def get_taxids(fns, gbac2id): 16 | import numpy as np 17 | return np.array(list(map(get_taxid, fns))) 18 | 19 | def get_taxid(fn, isid=False): 20 | if not isid: 21 | fn = get_qstr(fn) 22 | from subprocess import check_output 23 | cmd = f"esearch -db nucleotide -query \"{fn}\"|esummary|xtract -pattern TaxId -element TaxId" 24 | print(cmd, file=sys.stderr, flush=True) 25 | try: 26 | return int(check_output(cmd, shell=True).decode().strip()) 27 | except: 28 | return -1 29 | 30 | 31 | __all__ = ["get_taxid", "get_taxids"] 32 | -------------------------------------------------------------------------------- /GAT/secondary_structure.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class SecondaryStructurePredictor(nn.Module): 5 | def __init__(self, in_channels): 6 | super(SecondaryStructurePredictor, self).__init__() 7 | self.conv1 = nn.Conv2d(in_channels, 128, kernel_size=3, padding=1) 8 | self.conv2 = nn.Conv2d(128, 64, 
kernel_size=3, padding=1) 9 | self.fc1 = nn.Linear(64 * in_channels, 512) 10 | self.fc2 = nn.Linear(512, 1) 11 | 12 | def forward(self, x): 13 | x = self.conv1(x) 14 | x = nn.functional.relu(x) 15 | x = self.conv2(x) 16 | x = nn.functional.relu(x) 17 | x = x.view(x.size(0), -1) 18 | x = self.fc1(x) 19 | x = nn.functional.relu(x) 20 | x = self.fc2(x) 21 | return x 22 | 23 | secondary_structure_model = SecondaryStructurePredictor(768) 24 | -------------------------------------------------------------------------------- /GAT/solvent_accessibility.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class SolventAccessibilityPredictor(nn.Module): 5 | def __init__(self, in_channels): 6 | super(SolventAccessibilityPredictor, self).__init__() 7 | self.conv1 = nn.Conv1d(in_channels, 128, kernel_size=3, padding=1) 8 | self.conv2 = nn.Conv1d(128, 64, kernel_size=3, padding=1) 9 | self.fc1 = nn.Linear(64, 512) 10 | self.fc2 = nn.Linear(512, 1) 11 | 12 | def forward(self, x): 13 | x = self.conv1(x) 14 | x = nn.functional.relu(x) 15 | x = self.conv2(x) 16 | x = nn.functional.relu(x) 17 | x = x.view(x.size(0), -1) 18 | x = self.fc1(x) 19 | x = nn.functional.relu(x) 20 | x = self.fc2(x) 21 | return x 22 | 23 | solvent_accessibility_model = SolventAccessibilityPredictor(768) 24 | -------------------------------------------------------------------------------- /GAT/gat_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_geometric.nn import GATConv 4 | 5 | class GraphAttentionTransformer(nn.Module): 6 | def __init__(self, in_channels, out_channels, heads=8, dropout=0.6, num_layers=10): 7 | super(GraphAttentionTransformer, self).__init__() 8 | self.layers = nn.ModuleList() 9 | self.layers.append(GATConv(in_channels, out_channels, heads=heads, dropout=dropout)) 10 | for _ in range(num_layers - 1): 11 | self.layers.append(GATConv(out_channels * heads, out_channels, heads=heads, dropout=dropout)) 12 | self.fc = nn.Linear(out_channels * heads, out_channels) 13 | 14 | def forward(self, x, edge_index): 15 | for layer in self.layers: 16 | x = layer(x, edge_index) 17 | x = nn.functional.elu(x) 18 | x = self.fc(x) 19 | return x 20 | -------------------------------------------------------------------------------- /graphseq/squence_encoders/attlstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class AttLSTM(nn.Module): 5 | def __init__(self, input_size, hidden_size, num_layers): 6 | super(AttLSTM, self).__init__() 7 | self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True) 8 | self.attention = nn.Linear(hidden_size, 1) 9 | 10 | def forward(self, x): 11 | out, _ = self.lstm(x) 12 | attn_weights = F.softmax(self.attention(out), dim=1) 13 | out = torch.bmm(attn_weights.transpose(1, 2), out) 14 | return out 15 | 16 | def save_attlstm_model(model, filepath): 17 | torch.save(model.state_dict(), filepath) 18 | 19 | def load_attlstm_model(filepath, input_size, hidden_size, num_layers): 20 | model = AttLSTM(input_size, hidden_size, num_layers) 21 | model.load_state_dict(torch.load(filepath)) 22 | return model -------------------------------------------------------------------------------- /graphseq/squence_encoders/bert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers 
import BertTokenizer, BertModel 3 | 4 | class BERTSequenceEncoder(nn.Module): 5 | def __init__(self): 6 | super(BERTSequenceEncoder, self).__init__() 7 | self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 8 | self.bert = BertModel.from_pretrained('bert-base-uncased') 9 | 10 | def forward(self, sequences): 11 | inputs = self.tokenizer(sequences, return_tensors='pt', padding=True, truncation=True) 12 | outputs = self.bert(**inputs) 13 | return outputs.last_hidden_state.mean(dim=1) 14 | 15 | def save_bert_model(model, filepath): 16 | torch.save(model.state_dict(), filepath) 17 | 18 | def load_bert_model(filepath, pretrained_model_name='bert-base-uncased'): 19 | model = BERTSequenceEncoder() 20 | model.load_state_dict(torch.load(filepath)) 21 | return model -------------------------------------------------------------------------------- /src/omp.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #ifdef _OPENMP 4 | # include 5 | #endif 6 | namespace py = pybind11; 7 | 8 | py::ssize_t threadgetter() { 9 | py::ssize_t ret = 1; 10 | #ifdef _OPENMP 11 | #pragma omp parallel 12 | { 13 | ret = omp_get_num_threads(); 14 | } 15 | #endif 16 | return ret; 17 | } 18 | void threadsetter(py::ssize_t x) { 19 | if(x > 0) omp_set_num_threads(x); 20 | } 21 | struct OMPThreadNumManager { 22 | OMPThreadNumManager(int nthreads=-1) {set(nthreads);} 23 | void set(py::ssize_t nthreads) const {threadsetter(nthreads);} 24 | py::ssize_t get() const {return threadgetter();} 25 | }; 26 | 27 | void init_omp_helpers(py::module &m) { 28 | m.def("set_num_threads", threadsetter); 29 | m.def("get_num_threads", threadgetter); 30 | py::class_(m, "Threading").def(py::init<>()).def(py::init()) 31 | .def_property("nthreads", &OMPThreadNumManager::get, &OMPThreadNumManager::set) 32 | .def_property("p", &OMPThreadNumManager::get, &OMPThreadNumManager::set); 33 | } 34 | 35 | -------------------------------------------------------------------------------- /GAT/embedding_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | class EmbeddingModule(nn.Module): 6 | def __init__(self, vocab_size, embedding_dim): 7 | super(EmbeddingModule, self).__init__() 8 | self.embedding = nn.Embedding(vocab_size, embedding_dim) 9 | self.position_embedding = nn.Embedding(5000, embedding_dim) # Assuming max length of 5000 10 | 11 | def forward(self, x): 12 | seq_length = x.size(1) 13 | positions = torch.arange(0, seq_length, device=x.device).unsqueeze(0).expand_as(x) 14 | return self.embedding(x) + self.position_embedding(positions) 15 | 16 | def mask_input(self, x, mask_token_id, mask_prob=0.15): 17 | mask = np.random.rand(*x.shape) < mask_prob 18 | x_masked = x.clone() 19 | x_masked[mask] = mask_token_id 20 | return x_masked, mask 21 | 22 | embedding_dim = 768 23 | vocab_size = len("AGCUX-") 24 | mask_token_id = vocab_size # Assuming the last token in the vocabulary is used as the mask token 25 | embedding_module = EmbeddingModule(vocab_size + 1, embedding_dim) # +1 for the mask token 26 | -------------------------------------------------------------------------------- /graphseq/data_preparation.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | from Bio import SeqIO 3 | import RNA 4 | 5 | def read_fasta(file_path): 6 | sequences = [] 7 | for record in SeqIO.parse(file_path, "fasta"): 8 | 
sequences.append(str(record.seq)) 9 | return sequences 10 | 11 | def predict_secondary_structure(sequence): 12 | structure, _ = RNA.fold(sequence) 13 | return structure 14 | 15 | def construct_knowledge_graph(sequence): 16 | G = nx.DiGraph() 17 | for i, nucleotide in enumerate(sequence): 18 | G.add_node(i, nucleotide=nucleotide, x=i) 19 | for i in range(len(sequence) - 1): 20 | G.add_edge(i, i + 1, weight=1) 21 | return G 22 | 23 | def construct_secondary_structure_graph(sequence): 24 | structure = predict_secondary_structure(sequence) 25 | G = nx.DiGraph() 26 | for i, nucleotide in enumerate(sequence): 27 | G.add_node(i, nucleotide=nucleotide, x=i) 28 | stack = [] 29 | for i, char in enumerate(structure): 30 | if char == '(': 31 | stack.append(i) 32 | elif char == ')': 33 | j = stack.pop() 34 | G.add_edge(j, i, weight=1) 35 | return G 36 | -------------------------------------------------------------------------------- /GAT/rna_gat_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from embedding_module import EmbeddingModule, embedding_dim, vocab_size, mask_token_id 4 | from gat_module import GraphAttentionTransformer 5 | 6 | 7 | class RNA_GAT_Model(nn.Module): 8 | def __init__(self, vocab_size, embedding_dim, gat_out_channels, gat_heads, gat_layers, dropout): 9 | super(RNA_GAT_Model, self).__init__() 10 | self.embedding_module = EmbeddingModule(vocab_size, embedding_dim) 11 | self.gat_module = GraphAttentionTransformer(embedding_dim, gat_out_channels, heads=gat_heads, dropout=dropout, 12 | num_layers=gat_layers) 13 | 14 | def forward(self, x, edge_index): 15 | x = self.embedding_module(x) 16 | x = x.view(-1, x.size(2)) # Flatten the sequence length dimension 17 | x = self.gat_module(x, edge_index) 18 | return x 19 | 20 | 21 | gat_out_channels = 768 22 | gat_heads = 8 23 | gat_layers = 10 24 | dropout = 0.6 25 | 26 | rna_gat_model = RNA_GAT_Model(vocab_size + 1, embedding_dim, gat_out_channels, gat_heads, gat_layers, 27 | dropout) # +1 for the mask token 28 | -------------------------------------------------------------------------------- /bioseq/softmax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class SparseSoftmax(nn.Module): 5 | """ 6 | SparseSoftmax is a wrapper around entmax's entmax_bisect, which allows us to learn 7 | the parameter alpha in entmax, which determines how sparse a given layer's output should be 8 | """ 9 | def __init__(self, alpha_init=1.5, n_iter=24, dtype=torch.float32, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), 10 | reduction='sum', requires_grad=True): 11 | super().__init__() 12 | from entmax import EntmaxBisectLoss 13 | self.alpha = torch.tensor([alpha_init], dtype=dtype, device=device, requires_grad=requires_grad) 14 | self.loss = EntmaxBisectLoss( 15 | self.alpha, n_iter=n_iter, reduction=reduction) 16 | 17 | def forward(self, *args, **kwargs): 18 | if "target" in kwargs: 19 | args = args + (kwargs["target"],) 20 | del kwargs["target"] 21 | if len(args) == 1: 22 | from entmax import entmax_bisect 23 | return entmax_bisect(args[0], self.alpha) 24 | return self.loss(*args, **kwargs) 25 | 26 | Softmax = torch.nn.Softmax 27 | 28 | __all__ = ["SparseSoftmax", "Softmax"] 29 | -------------------------------------------------------------------------------- /graphseq/graph_encoders/graphsage.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from dgl.nn.pytorch.conv import SAGEConv 4 | 5 | class GraphSAGE(nn.Module): 6 | def __init__(self, in_feats, hidden_feats, out_feats, num_layers, activation, dropout, aggregator_type='mean'): 7 | super(GraphSAGE, self).__init__() 8 | self.layers = nn.ModuleList() 9 | self.activation = activation 10 | self.dropout = nn.Dropout(dropout) 11 | 12 | # input layer 13 | self.layers.append(SAGEConv(in_feats, hidden_feats, aggregator_type)) 14 | # hidden layers 15 | for _ in range(num_layers - 2): 16 | self.layers.append(SAGEConv(hidden_feats, hidden_feats, aggregator_type)) 17 | # output layer 18 | self.layers.append(SAGEConv(hidden_feats, out_feats, aggregator_type)) 19 | 20 | def forward(self, g, features): 21 | x = features 22 | for layer in self.layers: 23 | x = self.dropout(x) 24 | x = layer(g, x) 25 | return x 26 | 27 | def save_graphsage_model(model, filepath): 28 | torch.save(model.state_dict(), filepath) 29 | 30 | def load_graphsage_model(filepath, in_feats, hidden_feats, out_feats, num_layers, activation, dropout, aggregator_type='mean'): 31 | model = GraphSAGE(in_feats, hidden_feats, out_feats, num_layers, activation, dropout, aggregator_type) 32 | model.load_state_dict(torch.load(filepath)) 33 | return model -------------------------------------------------------------------------------- /graphseq/graph_encoders/gcn.py: -------------------------------------------------------------------------------- 1 | # scripts/graph_encoders/gcn.py 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class GCNLayer(nn.Module): 8 | def __init__(self, in_features, out_features): 9 | super(GCNLayer, self).__init__() 10 | self.linear = nn.Linear(in_features, out_features) 11 | 12 | def forward(self, x, edge_index): 13 | row, col = edge_index 14 | deg = torch.bincount(row) 15 | deg_inv_sqrt = deg.pow(-0.5) 16 | norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] 17 | 18 | out = torch.matmul(edge_index, x) 19 | out = norm.view(-1, 1) * out 20 | return self.linear(out) 21 | 22 | 23 | class GCN(nn.Module): 24 | def __init__(self, in_features, hidden_features, out_features): 25 | super(GCN, self).__init__() 26 | self.layer1 = GCNLayer(in_features, hidden_features) 27 | self.layer2 = GCNLayer(hidden_features, out_features) 28 | 29 | def forward(self, x, edge_index): 30 | x = F.relu(self.layer1(x, edge_index)) 31 | x = self.layer2(x, edge_index) 32 | return x 33 | 34 | 35 | def save_gcn_model(model, filepath): 36 | torch.save(model.state_dict(), filepath) 37 | 38 | 39 | def load_gcn_model(filepath, in_features, hidden_features, out_features): 40 | model = GCN(in_features, hidden_features, out_features) 41 | model.load_state_dict(torch.load(filepath)) 42 | return model 43 | -------------------------------------------------------------------------------- /graphseq/graph_encoders/gat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from dgl.nn.pytorch import GATConv 4 | 5 | class GAT(nn.Module): 6 | def __init__(self, in_feats, hidden_feats, out_feats, num_layers, num_heads, activation, dropout): 7 | super(GAT, self).__init__() 8 | self.layers = nn.ModuleList() 9 | self.activation = activation 10 | self.dropout = nn.Dropout(dropout) 11 | 12 | # input layer 13 | self.layers.append(GATConv(in_feats, hidden_feats, num_heads)) 14 | # hidden layers 15 | for _ in range(num_layers 
- 2): 16 | self.layers.append(GATConv(hidden_feats * num_heads, hidden_feats, num_heads)) 17 | # output layer 18 | self.layers.append(GATConv(hidden_feats * num_heads, out_feats, 1)) 19 | 20 | def forward(self, g, features): 21 | x = features 22 | for layer in self.layers[:-1]: 23 | x = self.dropout(x) 24 | x = layer(g, x).flatten(1) 25 | x = self.activation(x) 26 | x = self.layers[-1](g, x).mean(1) 27 | return x 28 | 29 | def save_gat_model(model, filepath): 30 | torch.save(model.state_dict(), filepath) 31 | 32 | def load_gat_model(filepath, in_feats, hidden_feats, out_feats, num_layers, num_heads, activation, dropout): 33 | model = GAT(in_feats, hidden_feats, out_feats, num_layers, num_heads, activation, dropout) 34 | model.load_state_dict(torch.load(filepath)) 35 | return model -------------------------------------------------------------------------------- /bioseq/annotations.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | CDDtypes = {'J': 'Translation, ribosomal structure and biogenesis', 4 | 'A': 'RNA processing and modification', 5 | 'K': 'Transcription', 6 | 'L': 'Replication, recombination and repair', 7 | 'B': 'Chromatin structure and dynamics', 8 | 'D': 'Cell cycle control, cell division, chromosome partitioning', 9 | 'Y': 'Nuclear structure', 10 | 'V': 'Defense mechanisms', 11 | 'T': 'Signal transduction mechanisms', 12 | 'M': 'Cell wall/membrane/envelope biogenesis', 13 | 'N': 'Cell motility', 14 | 'Z': 'Cytoskeleton', 15 | 'W': 'Extracellular structures', 16 | 'U': 'Intracellular trafficking, secretion, and vesicular transport', 17 | 'O': 'Posttranslational modification, protein turnover, chaperones', 18 | 'C': 'Energy production and conversion', 19 | 'G': 'Carbohydrate transport and metabolism', 20 | 'E': 'Amino acid transport and metabolism', 21 | 'F': 'Nucleotide transport and metabolism', 22 | 'H': 'Coenzyme transport and metabolism', 23 | 'I': 'Lipid transport and metabolism', 24 | 'P': 'Inorganic ion transport and metabolism', 25 | 'Q': 'Secondary metabolites biosynthesis, transport and catabolism', 26 | 'R': 'General function prediction only', 27 | 'S': 'Function unknown'} 28 | -------------------------------------------------------------------------------- /graphseq/distillation.py: -------------------------------------------------------------------------------- 1 | # scripts/distillation.py 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def similarity_preserving_loss(teacher_activations, student_activations): 7 | def compute_similarity_matrix(activations): 8 | norm_activations = F.normalize(activations, p=2, dim=1) 9 | similarity_matrix = torch.mm(norm_activations, norm_activations.t()) 10 | return similarity_matrix 11 | 12 | teacher_similarity = compute_similarity_matrix(teacher_activations) 13 | student_similarity = compute_similarity_matrix(student_activations) 14 | loss = F.mse_loss(student_similarity, teacher_similarity) 15 | return loss 16 | 17 | 18 | def generate_rna_family_loss(output, target): 19 | return F.cross_entropy(output, target) 20 | 21 | 22 | def train_with_distillation(teacher_model, student_model, data_loader, criterion, optimizer, device): 23 | teacher_model.train() 24 | student_model.train() 25 | 26 | for data in data_loader: 27 | inputs, labels = data 28 | inputs, labels = inputs.to(device), labels.to(device) 29 | 30 | optimizer.zero_grad() 31 | 32 | # Forward pass through teacher model 33 | with torch.no_grad(): 34 | teacher_outputs = teacher_model(inputs) 35 | 36 | # Forward 
pass through student model 37 | student_outputs = student_model(inputs) 38 | 39 | # Compute losses 40 | distillation_loss = similarity_preserving_loss(teacher_outputs, student_outputs) 41 | generation_loss = generate_rna_family_loss(student_outputs, labels) 42 | loss = generation_loss + distillation_loss 43 | 44 | loss.backward() 45 | optimizer.step() 46 | 47 | return student_model 48 | -------------------------------------------------------------------------------- /POA_README.md: -------------------------------------------------------------------------------- 1 | ### Embedding POA graph 2 | 3 | The current method that extracts the POA graph and embeds it for PyTorch Geometric is in the poa-embed branch. 4 | 5 | You give the input sequences as a list of strings: 6 | 7 | ``` 8 | graph = bioseq.SequenceGraph(seqs) 9 | graph.build() 10 | mat = graph.matrix() 11 | # ext = ExtractedPOAGraph(mat) - optional, but easier to reason about. 12 | ``` 13 | 14 | 15 | Then, to make an input for PyTorch Geometric/GAT, use bioseq.POAEmbedder. 16 | 17 | ``` 18 | tok = bioseq.DNATokenizer 19 | embedder = bioseq.POAEmbedder(tok, embed_dim=64) 20 | ``` 21 | 22 | ``` 23 | x, data = embedder.to_x_data(mat) 24 | # Or for pytorch_geometric 25 | x_data = embedder.embed_graph(mat) 26 | 27 | gat = # ... (create gat) 28 | gat_output = gat(x, data) 29 | ``` 30 | 31 | 32 | ### Project direction 33 | 34 | 1. Choose input sequences. 35 | RNAFam is good; we can use sets of sequences from RNA-MSM or rinalmo, and optionally other databases. 36 | RNAFam is small. 37 | 2. Gather downstream tasks. 38 | For (2), RNA-MSM (https://drive.google.com/drive/folders/1jYqk7rAp9ysJCBXOa5Yx4Z9es89h-f2h) and RINALMO both provide tasks. 39 | 40 | To evaluate after pretraining, we need to map the graph representation to outputs. 41 | 42 | 43 | 2a. How to predict 44 | The simplest thing I can think of for downstream tasks is to take the graph embeddings for each position and predict the task (e.g., structure) directly. 45 | 46 | The ExtractedPOAGraph `seq_node_support` field tells us which nodes each sequence aligned to. We can just use a small network from that embedding to the task (see the sketch at the end of this file). 47 | 48 | I think we can start with this. If the graph approach works well, I think this should do decently. 49 | 50 | 2b. Instead, we can concatenate the embeddings at those positions and use an LSTM layer to predict the secondary task from the graph-informed embeddings. 51 | 52 | This can hopefully improve our results if 2a goes well.
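A minimal sketch of the 2a idea, assuming the `mat` and `gat_output` from the snippets above. `ExtractedPOAGraph` and its `seq_node_support` field come from this repo; the linear head, the two-class structure target, and the assumption that the GAT's output width matches `embed_dim` (64) are placeholders for illustration.

```
import torch
ext = ExtractedPOAGraph(mat)            # seq_node_support[i] = node indices that sequence i aligned to
head = torch.nn.Linear(64, 2)           # e.g., paired/unpaired per position; 64 assumes the GAT keeps embed_dim
node_idx = torch.as_tensor(ext.seq_node_support[0], dtype=torch.long)
per_position = gat_output[node_idx]     # (seq_len, 64): graph-informed embedding for each position of sequence 0
logits = head(per_position)             # (seq_len, 2): one prediction per position
```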
53 | -------------------------------------------------------------------------------- /bioseq/lem.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | # From https://github.com/tk-rusch/LEM/ 6 | # https://arxiv.org/abs/2110.04744 7 | 8 | class LEMCell(nn.Module): 9 | def __init__(self, ninp, nhid, dt): 10 | super(LEMCell, self).__init__() 11 | self.ninp = ninp 12 | self.nhid = nhid 13 | self.dt = dt 14 | self.inp2hid = nn.Linear(ninp, 4 * nhid) 15 | self.hid2hid = nn.Linear(nhid, 3 * nhid) 16 | self.transform_z = nn.Linear(nhid, nhid) 17 | self.reset_parameters() 18 | 19 | def reset_parameters(self): 20 | std = 1.0 / math.sqrt(self.nhid) 21 | for w in self.parameters(): 22 | w.data.uniform_(-std, std) 23 | 24 | def forward(self, x, y, z): 25 | transformed_inp = self.inp2hid(x) 26 | transformed_hid = self.hid2hid(y) 27 | i_dt1, i_dt2, i_z, i_y = transformed_inp.chunk(4, 1) 28 | h_dt1, h_dt2, h_y = transformed_hid.chunk(3, 1) 29 | 30 | ms_dt_bar = self.dt * torch.sigmoid(i_dt1 + h_dt1) 31 | ms_dt = self.dt * torch.sigmoid(i_dt2 + h_dt2) 32 | 33 | z = (1.-ms_dt) * z + ms_dt * torch.tanh(i_y + h_y) 34 | y = (1.-ms_dt_bar)* y + ms_dt_bar * torch.tanh(self.transform_z(z)+i_z) 35 | 36 | return y, z 37 | 38 | class LEM(nn.Module): 39 | def __init__(self, ninp, nhid, nout, dt=1.): 40 | super(LEM, self).__init__() 41 | self.nhid = nhid 42 | self.cell = LEMCell(ninp,nhid,dt) 43 | self.classifier = nn.Linear(nhid, nout) 44 | self.init_weights() 45 | 46 | def init_weights(self): 47 | for name, param in self.named_parameters(): 48 | if 'classifier' in name and 'weight' in name: 49 | nn.init.kaiming_normal_(param.data) 50 | 51 | def forward(self, input): 52 | ## initialize hidden states 53 | y = input.data.new(input.size(1), self.nhid).zero_() 54 | z = input.data.new(input.size(1), self.nhid).zero_() 55 | for x in input: 56 | y, z = self.cell(x,y,z) 57 | out = self.classifier(y) 58 | return out 59 | -------------------------------------------------------------------------------- /GAT/training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Adam 3 | from embedding_module import embedding_module, mask_token_id 4 | import numpy as np 5 | 6 | 7 | def train_model(model, data_loader, epochs=300, lr=0.0003, weight_decay=0.0003): 8 | optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay) 9 | criterion = nn.CrossEntropyLoss() 10 | 11 | for epoch in range(epochs): 12 | model.train() 13 | total_loss = 0 14 | for data in data_loader: 15 | optimizer.zero_grad() 16 | x, edge_index, labels = data.x, data.edge_index, data.y 17 | 18 | # Mask input 19 | x_masked, mask = embedding_module.mask_input(x, mask_token_id) 20 | 21 | # Forward pass 22 | out = model(x_masked, edge_index) 23 | 24 | # Compute loss only on masked positions 25 | loss = criterion(out[mask], x[mask]) 26 | loss.backward() 27 | optimizer.step() 28 | total_loss += loss.item() 29 | 30 | print(f'Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(data_loader)}') 31 | 32 | 33 | def evaluate_model(model, data_loader): 34 | model.eval() 35 | total_correct = 0 36 | total_examples = 0 37 | with torch.no_grad(): 38 | for data in data_loader: 39 | x, edge_index, labels = data.x, data.edge_index, data.y 40 | 41 | # Mask input 42 | x_masked, mask = embedding_module.mask_input(x, mask_token_id) 43 | 44 | out = model(x_masked, edge_index) 45 | predictions = out.argmax(dim=1) 
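# Note: unlike the training loss, which scores only the masked positions against the original tokens, evaluation compares the argmax prediction at every node to the labels in data.y.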
46 | total_correct += (predictions == labels).sum().item() 47 | total_examples += labels.size(0) 48 | return total_correct / total_examples 49 | 50 | 51 | def fine_tune_model(model, train_loader, val_loader, epochs=50, lr=0.0003, weight_decay=0.0003): 52 | optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay) 53 | criterion = nn.CrossEntropyLoss() 54 | 55 | best_val_acc = 0 56 | for epoch in range(epochs): 57 | model.train() 58 | total_loss = 0 59 | for data in train_loader: 60 | optimizer.zero_grad() 61 | x, edge_index, labels = data.x, data.edge_index, data.y 62 | 63 | # Mask input 64 | x_masked, mask = embedding_module.mask_input(x, mask_token_id) 65 | 66 | out = model(x_masked, edge_index) 67 | loss = criterion(out[mask], x[mask]) 68 | loss.backward() 69 | optimizer.step() 70 | total_loss += loss.item() 71 | 72 | val_acc = evaluate_model(model, val_loader) 73 | if val_acc > best_val_acc: 74 | best_val_acc = val_acc 75 | torch.save(model.state_dict(), 'best_model.pt') 76 | 77 | print(f'Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(train_loader)}, Validation Accuracy: {val_acc}') 78 | -------------------------------------------------------------------------------- /graphseq/squence_encoders/xlstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class xLSTMCell(nn.Module): 5 | def __init__(self, input_size, hidden_size): 6 | super(xLSTMCell, self).__init__() 7 | self.input_size = input_size 8 | self.hidden_size = hidden_size 9 | self.Wi = nn.Linear(input_size, hidden_size) 10 | self.Wf = nn.Linear(input_size, hidden_size) 11 | self.Wo = nn.Linear(input_size, hidden_size) 12 | self.Wc = nn.Linear(input_size, hidden_size) 13 | self.Ui = nn.Linear(hidden_size, hidden_size) 14 | self.Uf = nn.Linear(hidden_size, hidden_size) 15 | self.Uo = nn.Linear(hidden_size, hidden_size) 16 | self.Uc = nn.Linear(hidden_size, hidden_size) 17 | self.init_weights() 18 | 19 | def init_weights(self): 20 | for m in self.modules(): 21 | if isinstance(m, nn.Linear): 22 | nn.init.xavier_uniform_(m.weight) 23 | nn.init.constant_(m.bias, 0) 24 | 25 | def forward(self, x, hidden): 26 | h, c = hidden 27 | i = torch.sigmoid(self.Wi(x) + self.Ui(h)) 28 | f = torch.sigmoid(self.Wf(x) + self.Uf(h)) 29 | o = torch.sigmoid(self.Wo(x) + self.Uo(h)) 30 | c_hat = torch.tanh(self.Wc(x) + self.Uc(h)) 31 | c = f * c + i * c_hat 32 | h = o * torch.tanh(c) 33 | return h, (h, c) 34 | 35 | class xLSTM(nn.Module): 36 | def __init__(self, input_size, hidden_size, num_layers): 37 | super(xLSTM, self).__init__() 38 | self.hidden_size = hidden_size 39 | self.num_layers = num_layers 40 | self.layers = nn.ModuleList([xLSTMCell(input_size if i == 0 else hidden_size, hidden_size) for i in range(num_layers)]) 41 | 42 | def forward(self, x, hidden=None): 43 | seq_len, batch_size, _ = x.size() 44 | if hidden is None: 45 | h = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(x.device) 46 | c = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(x.device) 47 | else: 48 | h, c = hidden 49 | 50 | hiddens = [] 51 | for i in range(self.num_layers): 52 | h_i, c_i = h[i], c[i] 53 | layer_output = [] 54 | for t in range(seq_len): 55 | h_i, (h_i, c_i) = self.layers[i](x[t], (h_i, c_i)) 56 | layer_output.append(h_i.unsqueeze(0)) 57 | x = torch.cat(layer_output, dim=0) 58 | hiddens.append((h_i, c_i)) 59 | 60 | h, c = zip(*hiddens) 61 | h = torch.stack(h) 62 | c = torch.stack(c) 63 | return x, (h, c) 64 | 65 | def 
save_xlstm_model(model, filepath): 66 | torch.save(model.state_dict(), filepath) 67 | 68 | def load_xlstm_model(filepath, input_size, hidden_size, num_layers): 69 | model = xLSTM(input_size, hidden_size, num_layers) 70 | model.load_state_dict(torch.load(filepath)) 71 | return model 72 | -------------------------------------------------------------------------------- /bioseq/poa_util.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import torch 3 | 4 | class FastxSeq: 5 | trans = str.maketrans("U", "T") 6 | def __init__(self, x, standardize_nuc=False): 7 | self.seq = x.sequence 8 | self.name = x.name 9 | self.comment = x.comment 10 | self.qual = x.quality 11 | if standardize_nuc: 12 | self.standardize() 13 | 14 | def __str__(self): 15 | comment = "" if not self.comment else " " + self.comment 16 | if self.qual is not None: 17 | return f"@{self.name}{comment}\n{self.seq}\n+\n{seq.qual}" 18 | else: 19 | return f">{self.name}{comment}\n{self.seq}" 20 | 21 | def standardize(self): 22 | self.seq = str.translate(self.seq, self.trans) 23 | 24 | 25 | class ExtractedPOAGraph: 26 | def __init__(self, mat): 27 | self.node_feats = list(mat['bases']) 28 | 29 | edge_ip = mat['edge_indptr'] 30 | edge_supporting_seqs = mat['edge_nodes'] 31 | self.edge_seq_support = [edge_supporting_seqs[edge_ip[idx]:edge_ip[idx + 1]] for idx in range(len(edge_ip) - 1)] 32 | 33 | seq_ip = mat['seq_indptr'] 34 | seq_supporting_nodes = mat['seq_nodes'] 35 | self.seq_node_support = [seq_supporting_nodes[seq_ip[idx]:seq_ip[idx + 1]] for idx in range(len(seq_ip) - 1)] 36 | 37 | self.edge_coo = mat['matrix_coo'][:,:2] 38 | self.mat = mat 39 | self.ranks = mat['ranks'] 40 | graph = nx.DiGraph() 41 | node_names = [f"{self.node_feats[x]}-{x}" for x in range(len(self.node_feats))] 42 | node_handles = list(map(graph.add_node, node_names)) 43 | for (x, y) in self.edge_coo: 44 | x = node_names[x] 45 | y = node_names[y] 46 | graph.add_edge(y, x) 47 | self.graph = graph 48 | 49 | 50 | def __str__(self): 51 | return f"feats: {self.node_feats}. Ranks: {self.ranks}. Edges: {self.edge_coo}. Graph:{self.graph}" 52 | 53 | 54 | class POAEmbedder: 55 | # TODO: add EOS/BOS support, future enhancement. 56 | # Requires adding edges/nodes to nodes with no incident or excident edges, respectively. 
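# Each POA node's base is run through the tokenizer and mapped to a learned embedding vector; the edge list is taken directly from the graph matrix's COO entries (see to_x_data below).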
57 | def __init__(self, tok, emb_dim=128): 58 | self.tok = tok 59 | self.emb = torch.nn.Embedding(tok.alphabet_size(), emb_dim) 60 | 61 | # Takes the output of bioseq.SequenceGraph().matrix() and creates the data for GAT 62 | def embed_graph(self, mat): 63 | try: 64 | from torch_geometric.data import Data 65 | except ImportError: 66 | print("Cannot import torch_geometric") 67 | raise 68 | x, data = self.to_x_data(mat) 69 | return Data(x, data) 70 | 71 | # Takes the output of bioseq.SequenceGraph().matrix() and creates the data for GAT 72 | def to_x_data(self, mat): 73 | import numpy as np 74 | embedded = self.emb(torch.from_numpy(self.tok.batch_tokenize([mat['bases']], padlen=len(mat['bases'])).astype(np.int32))) 75 | x = embedded.view(-1, embedded.size(2)) 76 | data = torch.from_numpy(mat['matrix_coo'][:,:2].astype(np.int32)) # COO 77 | return (x, data) 78 | 79 | 80 | __all__ = ["ExtractedPOAGraph", "FastxSeq"] 81 | -------------------------------------------------------------------------------- /scripts/flatten_swiss: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from Bio.SeqIO import parse as SeqParse 3 | 4 | def fmtpos(x): 5 | try: 6 | return int(x) 7 | except: 8 | import Bio 9 | if isinstance(x, Bio.SeqFeature.UnknownPosition): 10 | return "?" 11 | raise 12 | 13 | 14 | def feat2str(x): 15 | loc = x.location 16 | try: 17 | start, stop = map(fmtpos, (loc.start, loc.end)) 18 | except Exception as e: 19 | print(e, x) 20 | raise 21 | featid = x.id if x.id is not None else "NoID" 22 | feattype = x.type if x.type is not None else "NoType" 23 | qfstr = ";".join(k + ":" + v for k, v in x.qualifiers.items()) 24 | return f"{start}-{stop}-{featid}-{feattype}-{qfstr}" 25 | 26 | 27 | def flatten(inputfile, outputfile): 28 | import gzip 29 | import lzma 30 | if inputfile.endswith(".gz"): 31 | import gzip 32 | ifp = gzip.open(inputfile, "rb") 33 | elif inputfile.endswith(".xz"): 34 | import lzma 35 | ifp = lzma.open(inputfile, "rb") 36 | else: 37 | ifp = open(inputfile, "rb") 38 | featnames = ["name", "id", "desc", "moltype", "dbrefs", "taxid", "genename", "organism", "comment", "features", "hostorgs", "keywords", "refs"] 39 | CCHAR = '==NEWLINE==' 40 | outopen = gzip.open if outputfile.endswith('.gz') else (lzma.open if outputfile.endswith(".xz") else open) 41 | outmode = ("wt" if outputfile.endswith(".xz") else "w") 42 | with (gzip.open(inputfile, "rb") if inputfile.endswith(".gz") else (lzma.open(inputfile, "rb") if inputfile.endswith("xz") else open(inputfile, "rb"))) as ifp: 43 | with outopen(outputfile, outmode) as ofp: 44 | print("#" + "\t".join(featnames), file=ofp) 45 | for seq in SeqParse(ifp, "swiss"): 46 | # first few things 47 | gt = seq.annotations.get 48 | features = ";".join(map(feat2str, seq.features)) 49 | host_orgs = ";".join(seq.annotations.get("host_ncbi_taxid", [])) 50 | comment = seq.annotations.get('comment', "") 51 | if CCHAR in comment: 52 | raise Exception() 53 | outitems = [seq.name, seq.id, seq.description, seq.annotations['molecule_type'], "-".join(seq.dbxrefs), seq.annotations.get('ncbi_taxid', [-1])[0], gt('gene_name'), gt('organism'), comment.replace("\n", CCHAR), 54 | features, host_orgs, ";".join(seq.annotations.get("keywords", [])), ";".join(map(str, (x.pubmed_id for x in seq.annotations.get("references", []))))] 55 | outitems = [x if x is not None else "" for x in outitems] 56 | assert len(featnames) == len(outitems) 57 | outstr = "\t".join(map(str, outitems)) 58 | for ct, fname in zip(outitems, featnames): 59 | if '\t' in ct:
60 | print(ct, fname) 61 | raise Exception() 62 | print(outstr, file=ofp) 63 | 64 | 65 | if __name__ == "__main__": 66 | import sys 67 | from argparse import ArgumentParser as AP 68 | ap = AP() 69 | aa = ap.add_argument 70 | aa("Input", help="Input SwissProt annotated sequence file") 71 | aa("Output", nargs="?", const="") 72 | args = ap.parse_args() 73 | infile = args.Input 74 | outfile = args.Output if args.Output else infile + ".flattened.tsv.gz" 75 | flatten(infile, outfile) 76 | -------------------------------------------------------------------------------- /graphseq/evaluation.py: -------------------------------------------------------------------------------- 1 | # scripts/evaluate.py 2 | import networkx as nx 3 | import numpy as np 4 | from sklearn.metrics import jaccard_score, accuracy_score, precision_score, recall_score 5 | import RNA 6 | import difflib 7 | import torch 8 | from load_data import load_graphs, load_sequences 9 | 10 | # Function to calculate Graph Edit Distance 11 | def calculate_ged(graph1, graph2): 12 | return nx.graph_edit_distance(graph1, graph2) 13 | 14 | # Function to calculate Jaccard Similarity 15 | def calculate_jaccard_similarity(graph1, graph2): 16 | nodes1 = set(graph1.nodes()) 17 | nodes2 = set(graph2.nodes()) 18 | return jaccard_score(list(nodes1), list(nodes2)) 19 | 20 | # Function to calculate Sequence Similarity 21 | def calculate_sequence_similarity(seq1, seq2): 22 | sm = difflib.SequenceMatcher(None, seq1, seq2) 23 | return sm.ratio() 24 | 25 | # Function to predict RNA secondary structure using RNAfold 26 | def predict_secondary_structure(sequence): 27 | (structure, mfe) = RNA.fold(sequence) 28 | return structure 29 | 30 | # Function to calculate base pair distance 31 | def calculate_base_pair_distance(structure1, structure2): 32 | return RNA.bp_distance(structure1, structure2) 33 | 34 | # Evaluate Structural Similarity 35 | def evaluate_structural_similarity(generated_graphs, original_graphs): 36 | ged_scores = [] 37 | jaccard_scores = [] 38 | for g_graph, o_graph in zip(generated_graphs, original_graphs): 39 | ged_scores.append(calculate_ged(g_graph, o_graph)) 40 | jaccard_scores.append(calculate_jaccard_similarity(g_graph, o_graph)) 41 | return np.mean(ged_scores), np.mean(jaccard_scores) 42 | 43 | # Evaluate Sequence Similarity 44 | def evaluate_sequence_similarity(generated_sequences, original_sequences): 45 | seq_similarities = [] 46 | for g_seq, o_seq in zip(generated_sequences, original_sequences): 47 | seq_similarities.append(calculate_sequence_similarity(g_seq, o_seq)) 48 | return np.mean(seq_similarities) 49 | 50 | # Evaluate Functional Similarity 51 | def evaluate_functional_similarity(generated_sequences, original_sequences): 52 | bp_distances = [] 53 | for g_seq, o_seq in zip(generated_sequences, original_sequences): 54 | g_structure = predict_secondary_structure(g_seq) 55 | o_structure = predict_secondary_structure(o_seq) 56 | bp_distances.append(calculate_base_pair_distance(g_structure, o_structure)) 57 | return np.mean(bp_distances) 58 | 59 | # Evaluate Performance Metrics 60 | def evaluate_performance_metrics(predictions, targets): 61 | accuracy = accuracy_score(targets, predictions) 62 | precision = precision_score(targets, predictions, average='weighted') 63 | recall = recall_score(targets, predictions, average='weighted') 64 | return accuracy, precision, recall 65 | 66 | # Main evaluation function 67 | def evaluate_all(): 68 | original_graphs = load_graphs('../data/original_graphs') 69 | generated_graphs = 
load_graphs('../output/generated_graphs') 70 | original_sequences = load_sequences('../data/original_sequences') 71 | generated_sequences = load_sequences('../output/generated_sequences') 72 | 73 | # Structural Similarity 74 | ged_score, jaccard_score = evaluate_structural_similarity(generated_graphs, original_graphs) 75 | print(f"Graph Edit Distance: {ged_score}, Jaccard Similarity: {jaccard_score}") 76 | 77 | # Sequence Similarity 78 | seq_similarity = evaluate_sequence_similarity(generated_sequences, original_sequences) 79 | print(f"Sequence Similarity: {seq_similarity}") 80 | 81 | # Functional Similarity 82 | bp_distance = evaluate_functional_similarity(generated_sequences, original_sequences) 83 | print(f"Base Pair Distance: {bp_distance}") 84 | 85 | # Performance Metrics (Dummy example with predictions and targets) 86 | predictions = [1, 0, 1, 1] 87 | targets = [1, 0, 1, 0] 88 | accuracy, precision, recall = evaluate_performance_metrics(predictions, targets) 89 | print(f"Accuracy: {accuracy}, Precision: {precision}, Recall: {recall}") 90 | 91 | if __name__ == "__main__": 92 | evaluate_all() 93 | -------------------------------------------------------------------------------- /bioseq/blosum.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | from collections import Counter 4 | 5 | from numpy.random import default_rng 6 | rng = default_rng(int(10000. / 137)) # Fine-structure constant of the universe 7 | 8 | 9 | BLOSUM_TEXT = '''A R N D C Q E G H I L K M F P S T W Y V B Z X * 10 | A 4 -1 -2 -2 0 -1 -1 0 -2 -1 -1 -1 -1 -2 -1 1 0 -3 -2 0 -2 -1 0 -4 11 | R -1 5 0 -2 -3 1 0 -2 0 -3 -2 2 -1 -3 -2 -1 -1 -3 -2 -3 -1 0 -1 -4 12 | N -2 0 6 1 -3 0 0 0 1 -3 -3 0 -2 -3 -2 1 0 -4 -2 -3 3 0 -1 -4 13 | D -2 -2 1 6 -3 0 2 -1 -1 -3 -4 -1 -3 -3 -1 0 -1 -4 -3 -3 4 1 -1 -4 14 | C 0 -3 -3 -3 9 -3 -4 -3 -3 -1 -1 -3 -1 -2 -3 -1 -1 -2 -2 -1 -3 -3 -2 -4 15 | Q -1 1 0 0 -3 5 2 -2 0 -3 -2 1 0 -3 -1 0 -1 -2 -1 -2 0 3 -1 -4 16 | E -1 0 0 2 -4 2 5 -2 0 -3 -3 1 -2 -3 -1 0 -1 -3 -2 -2 1 4 -1 -4 17 | G 0 -2 0 -1 -3 -2 -2 6 -2 -4 -4 -2 -3 -3 -2 0 -2 -2 -3 -3 -1 -2 -1 -4 18 | H -2 0 1 -1 -3 0 0 -2 8 -3 -3 -1 -2 -1 -2 -1 -2 -2 2 -3 0 0 -1 -4 19 | I -1 -3 -3 -3 -1 -3 -3 -4 -3 4 2 -3 1 0 -3 -2 -1 -3 -1 3 -3 -3 -1 -4 20 | L -1 -2 -3 -4 -1 -2 -3 -4 -3 2 4 -2 2 0 -3 -2 -1 -2 -1 1 -4 -3 -1 -4 21 | K -1 2 0 -1 -3 1 1 -2 -1 -3 -2 5 -1 -3 -1 0 -1 -3 -2 -2 0 1 -1 -4 22 | M -1 -1 -2 -3 -1 0 -2 -3 -2 1 2 -1 5 0 -2 -1 -1 -1 -1 1 -3 -1 -1 -4 23 | F -2 -3 -3 -3 -2 -3 -3 -3 -1 0 0 -3 0 6 -4 -2 -2 1 3 -1 -3 -3 -1 -4 24 | P -1 -2 -2 -1 -3 -1 -1 -2 -2 -3 -3 -1 -2 -4 7 -1 -1 -4 -3 -2 -2 -1 -2 -4 25 | S 1 -1 1 0 -1 0 0 0 -1 -2 -2 0 -1 -2 -1 4 1 -3 -2 -2 0 0 0 -4 26 | T 0 -1 0 -1 -1 -1 -1 -2 -2 -1 -1 -1 -1 -2 -1 1 5 -2 -2 0 -1 -1 0 -4 27 | W -3 -3 -4 -4 -2 -2 -3 -2 -2 -3 -2 -3 -1 1 -4 -3 -2 11 2 -3 -4 -3 -2 -4 28 | Y -2 -2 -2 -3 -2 -1 -2 -3 2 -1 -1 -2 -1 3 -3 -2 -2 2 7 -1 -3 -2 -1 -4 29 | V 0 -3 -3 -3 -1 -2 -2 -3 -3 3 1 -2 1 -1 -2 -2 0 -3 -1 4 -3 -2 -1 -4 30 | B -2 -1 3 4 -3 0 1 -1 0 -3 -4 0 -3 -3 -2 0 -1 -4 -3 -3 4 1 -1 -4 31 | Z -1 0 0 1 -3 3 4 -2 0 -3 -3 1 -1 -3 -1 0 -1 -3 -2 -2 1 4 -1 -4 32 | X 0 -1 -1 -1 -2 -1 -1 -1 -1 -1 -1 -1 -1 -1 -2 0 0 -2 -1 -1 -1 -1 -1 -4 33 | * -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 1''' 34 | 35 | 36 | true_aas = 'ARNDCQEGHILKMFPSTWYVX' 37 | blosum_data = np.array([list(map(int, x.strip().split()[1:])) for x in BLOSUM_TEXT.split('\n')[1:]]) 38 | amine_chrs = "".join(x.split()[0] for x in 
BLOSUM_TEXT.split('\n')[1:]) 39 | true_idx = [i for i, x in enumerate(amine_chrs) if x in true_aas] 40 | blosum_specific = blosum_data[np.ix_(true_idx, true_idx[:-1])] 41 | blosum_odds = np.exp2(blosum_specific) 42 | rowsums = np.sum(blosum_odds, axis=1) 43 | normrows = blosum_odds / rowsums[:,np.newaxis] 44 | ca = np.array(list(true_aas))[:-1] 45 | 46 | aa_array = ca 47 | 48 | probdict = {k: normrows[idx].copy() for idx, k in enumerate(true_aas)} 49 | default_transitions = probdict['X'] 50 | substituters = {} 51 | def substitute(inchar, size=1): 52 | ''' 53 | Inputs: inchar [str] - input character to be replaced 54 | KWarg: size=1 - number of samples 55 | Outputs: samples of length [size] 56 | 57 | Used for generating new sequences for augmentation. 58 | Transitions are based on BLOSUM62 scores. 59 | ''' 60 | return rng.choice(ca, p=probdict.get(inchar, default_transitions), size=size, replace=True) 61 | 62 | 63 | def augment_seq(inseq, chain_len=1): 64 | ''' 65 | Takes an input sequence, mutates it `chain_len` times, and then returns the final sequence. 66 | 67 | Inputs: 68 | inseq - Str 69 | Must be comprised of valid AAs 70 | chain_len=1 - Int 71 | Number of mutations to cause 72 | Returns: 73 | outseq - Str 74 | Final sequence after mutations 75 | ''' 76 | # ba = bytearray(inseq, 'utf-8') 77 | ls = len(inseq) 78 | for _ in range(chain_len): 79 | outchar, inchar = (0, 0) 80 | while inchar == outchar: 81 | idx = rng.choice(ls) 82 | outchar = inseq[idx] 83 | inchar = substitute(outchar)[0] 84 | ba = bytearray(inseq, 'utf-8') 85 | ba[idx] = ord(inchar) 86 | inseq = ba.decode() 87 | return inseq 88 | 89 | # substituters = {k: lambda size=1: substitute(k, size=size) for k in true_aas} 90 | hc = Counter(aa_array[rng.choice(20, size=10000, p=probdict['H'])]) 91 | assert hc.most_common()[0][0] == 'H', str(hc) 92 | kc = Counter(aa_array[rng.choice(20, size=10000, p=probdict['K'])]) 93 | assert kc.most_common()[0][0] == 'K', str(hc) 94 | hc = Counter(substitute('H', size=10000)) 95 | assert hc.most_common()[0][0] == 'H', str(hc) + ", but through substituters" 96 | 97 | __all__ = ["BLOSUM_TEXT", "aa_array", "substitute", "normrows", "probdict", "augment_seq"] 98 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, Extension, find_packages, distutils 2 | from sys import platform 3 | from setuptools.command.build_ext import build_ext 4 | from glob import glob 5 | import multiprocessing 6 | import subprocess 7 | 8 | 9 | class get_pybind_include(object): 10 | """Helper class to determine the pybind11 include path 11 | The purpose of this class is to postpone importing pybind11 12 | until it is actually installed, so that the ``get_include()`` 13 | method can be invoked. """ 14 | 15 | def __init__(self, user=False): 16 | self.user = user 17 | 18 | def __str__(self): 19 | import pybind11 20 | return pybind11.get_include(self.user) 21 | 22 | 23 | def has_flag(compiler, flagname): 24 | """Return a boolean indicating whether a flag name is supported on 25 | the specified compiler. 
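A throwaway source file is compiled with the flag appended; a CompileError means the flag is unsupported.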
26 | """ 27 | import tempfile 28 | with tempfile.NamedTemporaryFile('w', suffix='.cpp') as f: 29 | f.write('int main (int argc, char **argv) { return 0; }') 30 | try: 31 | compiler.compile([f.name], extra_postargs=[flagname]) 32 | except distutils.errors.CompileError: 33 | return False 34 | return True 35 | 36 | 37 | def cpp_flag(compiler): 38 | """Return the -std=c++[11/14/17] compiler flag. 39 | The newer version is prefered over c++11 (when it is available). 40 | """ 41 | 42 | for flag in ("-std=c++%s" % x for x in ("2a", 17, 14, 11, "03")): 43 | if has_flag(compiler, flag): return flag 44 | 45 | raise RuntimeError('Unsupported compiler -- at least C++11 support ' 46 | 'is needed!') 47 | 48 | full_gomp_path = subprocess.check_output("realpath `$CXX --print-file-name=libgomp.a`", shell=True).decode('utf-8').strip() 49 | 50 | extra_compile_args = ['-march=native', 51 | '-Wno-char-subscripts', '-Wno-unused-function', '-Wno-ignored-qualifiers', 52 | '-Wno-strict-aliasing', '-Wno-ignored-attributes', '-fno-wrapv', 53 | '-Wall', '-Wextra', '-Wformat', 54 | '-lz', '-fopenmp', 55 | "-pipe", '-O0', '-DNDEBUG'] 56 | 57 | extra_link_opts = ["-fopenmp", "-lz"] 58 | 59 | 60 | class BuildExt(build_ext): 61 | """A custom build extension for adding compiler-specific options.""" 62 | c_opts = { 63 | 'msvc': ['/EHsc'], 64 | 'unix': [], 65 | } 66 | l_opts = { 67 | 'msvc': [], 68 | 'unix': [], 69 | } 70 | 71 | if platform == 'darwin': 72 | darwin_opts = ['-mmacosx-version-min=10.7']# , '-libstd=libc++'] 73 | # darwin_opts = [] 74 | c_opts['unix'] += darwin_opts 75 | l_opts['unix'] += darwin_opts 76 | 77 | 78 | def build_extensions(self): 79 | ct = self.compiler.compiler_type 80 | opts = self.c_opts.get(ct, []) 81 | link_opts = self.l_opts.get(ct, []) 82 | if ct == 'unix': 83 | opts.append('-DVERSION_INFO="%s"' % self.distribution.get_version()) 84 | opts.append(cpp_flag(self.compiler)) 85 | if has_flag(self.compiler, '-fvisibility=hidden'): 86 | opts.append('-fvisibility=hidden') 87 | elif ct == 'msvc': 88 | opts.append('/DVERSION_INFO=\\"%s\\"' % self.distribution.get_version()) 89 | for ext in self.extensions: 90 | ext.extra_compile_args = opts 91 | ext.extra_compile_args += extra_compile_args 92 | ext.extra_link_args = link_opts + extra_link_opts 93 | ext.extra_objects = [full_gomp_path, "spoa/lib/libspoa.a"] 94 | build_ext.build_extensions(self) 95 | 96 | 97 | def build_spoa(): 98 | import subprocess 99 | import os 100 | os.makedirs("spoa/build", exist_ok=True) 101 | subprocess.check_call("cd spoa/build && cmake .. && cd .. 
&& make", shell=True) 102 | 103 | if __name__ == "__main__": 104 | __version__ = "0.1.8" 105 | build_spoa() 106 | include_dirs = [get_pybind_include(), get_pybind_include(True), "./", "spoa/include"] 107 | ext_modules = [Extension("cbioseq", ["src/bioseq.cpp", "src/poa.cpp", "src/tokenize.cpp", "src/omp.cpp", 'src/fxstats.cpp'], include_dirs=include_dirs, language='c++')] 108 | setup( 109 | name='bioseq', 110 | version=__version__, 111 | author='Daniel Baker', 112 | author_email='dnb@cs.jhu.edu', 113 | url='https://github.com/dnbaker/bioseq', 114 | description='A python module for tokenizing biological sequences', 115 | long_description='', 116 | ext_modules=ext_modules, 117 | install_requires=['pybind11', 'numpy>=0.19', 'einops', 'torch', 'fast_transformer_pytorch', 'x-transformers'], 118 | setup_requires=['pybind11'], 119 | cmdclass={'build_ext': BuildExt}, 120 | zip_safe=False, 121 | packages=find_packages(), 122 | scripts=['scripts/flatten_swiss', 'scripts/makeflatfile'] 123 | ) 124 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # bioseq 2 | 3 | A C++/Python package performing fast one-hot encoding for DNA or Protein sequences with C++ code, optionally converting to pytorch and moving to device. 4 | 5 | Offers 4-letter DNA, 20-letter amino acid, and a variety of other compressed protein and DNA alphabets, and optionally is parallelized. 6 | 7 | ## Tokenizing 8 | 9 | bioseq.Tokenizer does the tokenizing, and there are pre-made tokenizers for all alphabets, as well as combinations of EOS, BOS, and whether padding gets a unique character, or is simply masked. 10 | 11 | `bos_tokenizers` is a dictionary from alphabets to Tokenizers with a BOS tag prepended. 12 | `eos_tokenizers` is a dictionary from alphabets to Tokenizers with an EOS tag appended. 13 | `pos_tokenizers` is adictionary from alphabets to Tokenizers with a padding character used. 14 | `beos_tokenizers` adds both BOS and EOS 15 | `pbeos_tokenizers` adds BOS, EOS, and padding characters. 16 | 17 | Tokenizer can tokenize (`batch_tokenize`), which creates array of tokens, (uint8 by default), 18 | or it can one-hot encode (`batch_onehot_encode`), which takes the tokens one-step further into one-hot encoding. 19 | Both of these `Tokenizer::batch_*` functions can be parallelized by providing `nthreads={int}`. 20 | 21 | tokenizing uses seq-first ordering by default as well, but this can be changed with `batch_first=True`. 22 | one-hot encoding uses seq-first ordering (not batch-first). It does not support `batch_first`. 23 | 24 | Both of these are ~30x as fast as using bytes.translate + np.frombuffer + np.vstack + `torch.from_numpy`, 25 | and ~500x as fast as transformers.tokenizer.batch\_encode\_plus. 26 | 27 | 1. To train Transformers, you need to use `batch_first=True`, followed by torch.nn.Embedding. 28 | 2. To train CNNs, tokenize with `batch_first=True`, embed with torch.nn.Embedding, and then apply `lambda x: einops.rearrange(x, 'batch seq emb -> batch emb seq')`. 29 | This is because CNNs expect (Batch, C, L) 30 | 3. To train LSTMs, use `batch_first=False` to tokenize, and embed with torch.nn.Embedding 31 | 32 | Basically, you only want `batch_first=False` for LSTM training, and using CNNs will require a rearrange call due to the varying expectation of dimension ordering. 33 | 34 | ## Decoding 35 | 36 | You can decode a sequence with a tokenizer. 
34 | ## Decoding 35 | 36 | You can decode a sequence with a tokenizer. 37 | 38 | ```python 39 | import bioseq 40 | tok = bioseq.pbeos_tokenizers['DNA'] # To add BOS, EOS, and PAD characters separately. 41 | tokens = tok.batch_tokenize(["ACGT", "GGGG"], padlen=7, batch_first=True) 42 | decoded = tok.decode_tokens(tokens) 43 | # decoded == ['ACGT', 'GGGG'] 44 | ``` 45 | 46 | It accepts 1D and 2D arrays. Be careful: if you didn't tokenize with `batch_first` set, you may get the wrong outputs; you can fix this by swapping dimensions. 47 | 48 | And if you have a one-hot encoded array (or logits), just take an argmax over the alphabet dimension to convert the batch back to tokens for decoding. 49 | 50 | *Warning* (sharp edges): 51 | 52 | 1. If you're using a reduced amino acid alphabet, each token represents several amino acids. We simply pick the lexicographically smallest as a representative. 53 | 54 | To get the full set of characters for ambiguous tokens, use `tokenizer.token_decoder()`. `token_decoder()` returns a dictionary mapping integers to all possible characters. 55 | 56 | 2. Consider ensuring padding gets its own character. `pbeos_tokenizers`, for instance, adds padding tokens as well as beginning/end of sequence tokens. 57 | 58 | Since sequences have different lengths, we have to pad to equal length for a batch. If `padding=True` on the `Tokenizer`, then we add padding tokens at the ends. 59 | One-hot encoding simply leaves them as 0s by default, but for tokens it is particularly important: in DNA, for instance, a padding position encoded as 0 would otherwise be indistinguishable from A. You pay slightly more (and use more tokens), but models learn the patterns of padding tokens at the end rather quickly, and you can avoid making mistakes. 60 | 61 | ## DataLoading 62 | We use the bioseq.FlatFile class, which provides random access to the sequences in a FAST{Q,A} file. 63 | This is then used by bioseq.FlatFileDataset for use with torch.utils.data.DataLoader. 64 | 65 | For an example, see training/trainh.py and training/compute.py. 66 | 67 | ## Sequence augmentation 68 | 69 | We also support augmentation by random mutations sampled according to BLOSUM62 transition probabilities. 70 | This is only valid for tokenizers using the full 20-character amino acid alphabet ("PROTEIN" or "AMINO20"). We may modify this in the future to support other alphabets. 71 | 72 | bioseq.AmineTokenizer is a pre-built tokenizer without BOS, EOS, or padding which is valid for this. 73 | 74 | 75 | ## Dependencies 76 | 77 | pybind11 v2.7 is required, in order to support bytearray. 78 | numpy is required. 79 | pytorch (as torch) is also required. 80 | 81 | Besides these, there are some python-only dependencies which setup.py should download for you. 82 | 83 | All of these can be manually installed via `python3 -m pip install -r requirements.txt`. 84 | 85 | ## Version history 86 | 87 | v0.1.3: Bug fix - previous versions mapped Proline ("P") to Lysine ("K"), instead of mapping Pyrrolysine ("O") to "K".
88 | 89 | v0.1.2: Dependencies made optional, token decoding added 90 | 91 | v0.1.1: Initial version 92 | -------------------------------------------------------------------------------- /training/trainh.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from time import time 4 | 5 | 6 | import bioseq 7 | 8 | from bioseq.decoders import SeqEncoder, HTransformer1D, XDecoder, XAutoregressiveWrapper, FastEncoder, FAutoregressiveWrapper 9 | from bioseq.hattn import AutoregressiveWrapper as HAutoregressor 10 | from argparse import ArgumentParser as AP 11 | import numpy as np 12 | import torch 13 | from torch.utils.data import DataLoader 14 | 15 | import tqdm 16 | 17 | 18 | ap = AP() 19 | aa = ap.add_argument 20 | aa("--bos", action="store_true", help="Prepend a BOS tag for sequence") 21 | aa("--eos", action="store_true", help="Append an EOS tag to sequence") 22 | aa("--padchar", action="store_true", help="Treat padding characters are unique identifier. (Default: no embeddings)") 23 | aa("--alphabet", default="PROTEIN") 24 | aa("sequencefile", help="Input sequences; Must be in Fasta or Fastq format. All quality scores are ignored.") 25 | aa("--nbatches", type=int, default=1) 26 | aa("--batchsize", type=int, default=8) 27 | aa("--embdim", type=int, default=64) 28 | aa("--headdim", type=int, default=64) 29 | aa("--nheads", type=int, default=8) 30 | aa("--depth", "--nlayers", type=int, default=6) 31 | aa("--sparseemb", action='store_true', help="Use sparse embeddings.") 32 | aa("--learning-rate", "-R", type=float, default=2e-4) 33 | aa("--accumfreq", type=int, default=4) 34 | aa("--bidir-loss", type=float, const=1., nargs='?') 35 | aa("--clip-grad-norm", "--clip", type=float, default=.25) 36 | aa("--transformer-type", "-T", choices=("Fast", "Hier", "X"), help="Type of transformer to use. Default: HTransformer1D (Hier)", default="X") 37 | aa("--sparse-softmax", action='store_true', help="Whether to use differentiably sparse top-k") 38 | aa("--nthreads", "-p", type=int, default=1) 39 | aa("--gate-residual", action='store_true') 40 | aa("--augment", type=int, default=0, help="Number of mutations to introduce while augmenting data. Default: 0.") 41 | aa("--augment-frac", type=float, default=.5, help="Fraction of sequences to augment. 
Default: 0.5, but only used if --augment is set.") 42 | args = ap.parse_args() 43 | LEARNING_RATE = args.learning_rate 44 | GRADIENT_ACCUMULATE_EVERY = args.accumfreq 45 | torch.set_num_threads(args.nthreads) 46 | if args.sparseemb: 47 | raise Exception("Cannot use sparse embeddings rn") 48 | 49 | def roundup(x): 50 | x = x + 1 51 | for shift in (1, 2, 4, 8, 16, 32): 52 | x |= x >> shift 53 | return x + 1 54 | 55 | NUM_BATCHES = args.nbatches 56 | BATCH_SIZE = args.batchsize 57 | argtup = (args.bos, args.eos, args.padchar, args.alphabet) 58 | tokd = bioseq.get_tokenizer_dict(args.bos, args.eos, args.padchar) 59 | try: 60 | tokenizer = tokd[args.alphabet.upper()] 61 | except KeyError: 62 | print(tokd.keys()) 63 | raise 64 | 65 | 66 | def cycle(x): 67 | while 1: 68 | yield from x 69 | 70 | 71 | 72 | embeddings = bioseq.make_embedding(tokenizer, args.embdim, norm_type=2.0, sparse=args.sparseemb) 73 | 74 | ffp = args.sequencefile + ".ff" 75 | 76 | if os.path.isfile(ffp): 77 | print("Found existing flatfile", file=sys.stderr) 78 | ff = bioseq.FlatFile(ffp) 79 | else: 80 | print("Making flatfile", file=sys.stderr) 81 | ff = bioseq.FlatFile(args.sequencefile, ffp) 82 | ffl = bioseq.loaders.FlatFileDataset(ff, tokenizer) 83 | 84 | train_loader = cycle(DataLoader(ffl, batch_size=args.batchsize)) 85 | 86 | msl = ffl.max_seq_len 87 | if args.transformer_type == "Hier": 88 | nmsl = roundup(msl) 89 | print(f"Padding msl to next power of to: {msl}->{nmsl}",file=sys.stderr) 90 | msl = nmsl 91 | del nmsl 92 | # print("msl: %d. roundedup: %d\n" % (msl, roundup(msl))) 93 | # msl = roundup(msl) 94 | 95 | tokl = bioseq.decoders.TokenizerLayer(tokenizer, padlen=msl) 96 | 97 | argdict = {} 98 | 99 | baseargs = {"num_tokens": tokenizer.alphabet_size(), "heads": args.nheads, "depth": args.depth, "dim": args.embdim, "max_seq_len": msl} 100 | if args.transformer_type == "Fast": 101 | TxType = FastEncoder 102 | baseargs.update({"query_sparse_softmax": args.sparse_softmax, "key_sparse_softmax": args.sparse_softmax}) 103 | elif args.transformer_type == "Hier": 104 | TxType = HTransformer1D 105 | baseargs.update({"causal": True, "reversible": True}) 106 | else: 107 | assert args.transformer_type == "X" 108 | TxType = XDecoder 109 | baseargs.update({"gate_residual": args.gate_residual, 'rotary_pos_emb': True, "reversible": True}) 110 | seq_encoder = SeqEncoder(tokl, embeddings, TxType, **baseargs) 111 | encoder = seq_encoder.encoder 112 | model = seq_encoder 113 | if torch.cuda.is_available(): 114 | print("Using CUDA") 115 | model = model.cuda() 116 | else: 117 | print("Using CPU with %d threads" % torch.get_num_threads()) 118 | if args.transformer_type == "Hier": 119 | model = HAutoregressor(model) 120 | elif args.transformer_type == "Fast": 121 | model = XAutoregressiveWrapper(model) 122 | else: 123 | model = XAutoregressiveWrapper(model) 124 | optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) 125 | 126 | tstart = time() 127 | num = 0 128 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'): 129 | model.train() 130 | 131 | for __ in range(GRADIENT_ACCUMULATE_EVERY): 132 | gstart = time() 133 | nextbatch = next(train_loader).to(torch.long) 134 | loss = model(nextbatch) 135 | if args.bidir_loss: 136 | loss += args.bidir_loss * model(torch.flip(nextbatch, (1,))) 137 | loss.backward() 138 | 139 | print(f'training loss: {loss.item()} after {time() - tstart}s', file=sys.stderr, flush=True) 140 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm) 141 | optim.step() 142 
| optim.zero_grad() 143 | print(f"Average time per item: {(time() - tstart) / (GRADIENT_ACCUMULATE_EVERY * args.batchsize * NUM_BATCHES)}") 144 | 145 | from datetime import datetime 146 | dstr = str(datetime.now()).replace(" ", "_").replace(":", "-") 147 | torch.save(model, f"hmodel.{dstr}.pt") 148 | 149 | print(f"Total time: {time() - tstart}") 150 | -------------------------------------------------------------------------------- /src/tokenize.cpp: -------------------------------------------------------------------------------- 1 | #include "bioseq.h" 2 | #include "tokenize.h" 3 | #include 4 | #include 5 | #include 6 | 7 | inline __attribute__((always_inline)) 8 | py::object tokenize(const Tokenizer &tok, const char *s, const py::ssize_t size, const py::ssize_t padlen, const std::string dt) { 9 | py::object ret = py::none(); 10 | switch(dt[0] & 223) { // remove case from character by removing bit 0b100000 == 32 11 | case 'B': ret = tok.tokenize(s, size, padlen); break; 12 | case 'H': ret = tok.tokenize(s, size, padlen); break; 13 | case 'I': ret = tok.tokenize(s, size, padlen); break; 14 | case 'F': ret = tok.tokenize(s, size, padlen); break; 15 | case 'D': ret = tok.tokenize(s, size, padlen); break; 16 | default: ; // Else, return None 17 | } 18 | return ret; 19 | } 20 | 21 | void init_tokenize(py::module &m) { 22 | py::class_(m, "Tokenizer") 23 | .def(py::init(), py::arg("key"), py::arg("eos") = false, py::arg("bos") = false, py::arg("padchar") = false) 24 | .def("onehot_encode", [](const Tokenizer &tok, py::str s, py::ssize_t padlen, std::string dt) -> py::object { 25 | py::ssize_t size; 26 | const char *ptr = PyUnicode_AsUTF8AndSize(s.ptr(), &size); 27 | py::object ret = tokenize(tok, ptr, size, padlen, dt); 28 | if(ret.is_none()) 29 | throw std::invalid_argument(std::string("Unsupported dtype: ") + dt); 30 | return ret; 31 | }, py::arg("str"), py::arg("padlen") = 0, py::arg("destchar") = "f") 32 | .def("onehot_encode", [](const Tokenizer &tok, py::bytearray s, py::ssize_t padlen, std::string dt) -> py::object { 33 | const py::ssize_t size = PyByteArray_GET_SIZE(s.ptr()); 34 | const char *ptr = PyByteArray_AS_STRING(s.ptr()); 35 | py::object ret = tokenize(tok, ptr, size, padlen, dt); 36 | if(ret.is_none()) 37 | throw std::invalid_argument(std::string("Unsupported dtype: ") + dt); 38 | return ret; 39 | }, py::arg("bytearray"), py::arg("padlen") = 0, py::arg("destchar") = "f") 40 | .def("onehot_encode", [](const Tokenizer &tok, py::bytes bs, py::ssize_t padlen, std::string dt) -> py::object { 41 | py::ssize_t size; 42 | char *ptr; 43 | PyBytes_AsStringAndSize(bs.ptr(), &ptr, &size); 44 | py::object ret = tokenize(tok, ptr, size, padlen, dt); 45 | if(ret.is_none()) 46 | throw std::invalid_argument(std::string("Unsupported dtype: ") + dt); 47 | return ret; 48 | }, py::arg("str"), py::arg("padlen") = 0, py::arg("destchar") = "B") 49 | .def("decode_tokens", [](const Tokenizer& tok, py::array array) { 50 | return tok.decode_tokens_to_string(array); 51 | }, py::arg("tokenizer")) 52 | .def("lut", [](const Tokenizer& tok) { 53 | return tok.lookup; 54 | }) 55 | .def("token_map", [](const Tokenizer& tok) { 56 | return tok.token_map(); 57 | }) 58 | //std::vector token_set() const noexcept {return tokenset_vector;} 59 | //std::unordered_map token_set_map() const noexcept {return tokensets;} 60 | .def("token_decoder", [](const Tokenizer& tok) {return tok.token_set_map();}) 61 | .def("nchars", [](const Tokenizer& tok) { 62 | return tok.nchars(); 63 | }) 64 | // batched one-hot encoding 65 | 
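// The destchar argument's first character is a numpy-style typecode (b/B, h/H, i/I, l/L, q/Q, f, d)
// selecting the dtype of the returned array; it is lower-cased before the switch dispatches on it.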
.def("batch_onehot_encode", [](const Tokenizer &tok, py::sequence seq, py::ssize_t padlen, std::string dt, int nthreads, py::object mask) -> py::object { 66 | switch(std::tolower(dt[0])) { 67 | #define C(x, t) case x: return tok.tokenize(seq, padlen, false, nthreads, mask) 68 | C('b', int8_t); 69 | C('B', uint8_t); 70 | C('h', int16_t); 71 | C('H', uint16_t); 72 | C('I', uint32_t); 73 | C('i', int32_t); 74 | case 'l': C('q', uint64_t); 75 | case 'L': C('Q', int64_t); 76 | C('f', float); 77 | C('d', double); 78 | #undef C 79 | } 80 | throw std::invalid_argument(std::string("Unsupported dtype: ") + dt); 81 | }, py::arg("batch"), py::arg("padlen") = -1, py::arg("destchar") = "B", py::arg("nthreads") = 1, py::arg("mask") = py::none()) 82 | .def("batch_tokenize", [](const Tokenizer &tok, py::sequence seq, py::ssize_t padlen, std::string dt, bool batch_first, int nthreads) -> py::object { 83 | switch(std::tolower(dt[0])) { 84 | #define C(x, t) case x: return tok.transencode(seq, padlen, batch_first, nthreads) 85 | C('b', int8_t); 86 | C('B', uint8_t); 87 | C('h', int16_t); 88 | C('H', uint16_t); 89 | C('I', uint32_t); 90 | C('i', int32_t); 91 | case 'l': C('q', uint64_t); 92 | case 'L': C('Q', int64_t); 93 | C('f', float); 94 | C('d', double); 95 | #undef C 96 | } 97 | throw std::invalid_argument(std::string("Unsupported dtype: ") + dt); 98 | }, py::arg("batch"), py::arg("padlen") = -1, py::arg("destchar") = "B", py::arg("batch_first")=false, py::arg("nthreads") = 1) 99 | .def("alphabet_size", &Tokenizer::full_alphabet_size) 100 | .def("bos", &Tokenizer::bos) 101 | .def("eos", &Tokenizer::eos) 102 | .def("pad", &Tokenizer::pad) 103 | .def_property_readonly("key", [](const Tokenizer &tok) {return tok.key;}) 104 | .def("is_padded", &Tokenizer::is_padded) 105 | .def("includes_bos", &Tokenizer::includes_bos) 106 | .def("includes_eos", &Tokenizer::includes_eos) 107 | .def(py::pickle( 108 | [](const Tokenizer &tok) -> py::tuple {return py::make_tuple(tok.key, tok.include_eos_, tok.include_bos_, tok.zero_onehot_pad_);}, 109 | [](py::tuple t) { 110 | return Tokenizer(t[0].cast(), t[1].cast(), t[2].cast(), t[3].cast()); 111 | } 112 | )); 113 | } 114 | -------------------------------------------------------------------------------- /bioseq/loaders.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import bioseq 5 | import numpy as np 6 | import torch 7 | from bioseq.blosum import augment_seq 8 | import einops 9 | 10 | 11 | def FF2NP(x, tokenizer, destfile, *, batch_size=8192): 12 | import numpy as np 13 | assert isinstance(x, bioseq.FlatFile) 14 | assert isinstance(tokenizer, bioseq.Tokenizer) 15 | msl = x.maxseqlen 16 | total_msl = msl + tokenizer.includes_bos() + tokenizer.includes_eos() 17 | nseqs = x.nseqs() 18 | retmat = np.memmap(destfile, mode='w+', dtype=np.uint8, shape=(nseqs, total_msl)) 19 | nbatches = (nseqs + batch_size - 1) // batch_size 20 | for i in range(nbatches): 21 | start = i * batch_size 22 | stop = min(start + batch_size, nseqs) 23 | seqs = x.access(i * batch_size, stop) 24 | retmat[start:stop] = tokenizer.batch_tokenize( 25 | seqs, padlen=msl, batch_first=True, destchar='B') 26 | return (retmat, destfile) 27 | 28 | 29 | class FlatFileDataset(torch.utils.data.Dataset): 30 | """ 31 | Creates a FlatFileDataset from a Tokenizer and a FlatFile. 32 | if keyword augment is provided, then sequences will be mutated times 33 | before tokenizing using BLOSUM62 substitution rates. 
34 | """ 35 | def __init__(self, ff, tokenizer, *, augment=0, augment_frac=0.5, cnn=False, device=None, maskfrac=0.15): 36 | self.maskfrac = maskfrac 37 | if device is None: 38 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 39 | self.device = device 40 | super(FlatFileDataset).__init__() 41 | assert isinstance(ff, bioseq.FlatFile) 42 | assert isinstance(tokenizer, bioseq.Tokenizer) 43 | self.ff = ff 44 | self.tokenizer = tokenizer 45 | self.max_seq_len = ff.maxseqlen + tokenizer.includes_bos() + tokenizer.includes_eos() 46 | self.maxseqlen = self.max_seq_len 47 | self.augment = augment 48 | self.augment_frac = augment_frac 49 | self.cnn = cnn 50 | from numpy.random import default_rng 51 | self.rng = default_rng(13) 52 | 53 | ''' 54 | def access_masked(self, index): 55 | item = self.ff[index] 56 | for i, x in enumerate(item): 57 | if self.augment and self.augment_frac >= 1. or self.rng.uniform() < self.augment_frac: 58 | item[i] = augment_seq(x.decode(), self.augment) 59 | masked_positions = [np.where(self.rng.uniform(size=len(x)))[0] for x in item] 60 | original_sequence = item 61 | import copy 62 | new_items = copy.deepcopy(item) 63 | for mp, item in zip(masked_positions, new_items): 64 | ''' 65 | def fetch(self, index, return_items=False): 66 | import numpy as np 67 | from torch import from_numpy as frnp 68 | if self.cnn: 69 | item = self.ff[index] 70 | if isinstance(item, list): 71 | if self.augment and self.augment_frac >= 1. or self.rng.uniform() < self.augment_frac: 72 | item = list(map(lambda x: augment_seq(x.decode(), self.augment) if (self.augment and (self.augment_frac >= 1. or self.rng.uniform() < self.augment_frac)) else x, item)) 73 | oh = self.tokenizer.batch_onehot_encode(item, padlen=self.ff.maxseqlen + self.tokenizer.includes_eos() + self.tokenizer.includes_bos()) 74 | ret = einops.rearrange(frnp(oh).to(self.device), "length batch emb -> batch emb length").float() 75 | else: 76 | oh = self.tokenizer.onehot_encode(item, padlen=self.ff.maxseqlen + self.tokenizer.includes_eos() + self.tokenizer.includes_bos()) 77 | ret = frnp(oh).to(self.device).float() 78 | if return_items: 79 | return ret, item 80 | else: 81 | seq = self.ff.access(index) 82 | if self.augment and (self.augment_frac >= 1. or self.rng.uniform() < self.augment_frac): 83 | seq = augment_seq(seq.decode(), self.augment) 84 | return frnp(self.tokenizer.batch_tokenize([seq], padlen=self.max_seq_len, batch_first=True, destchar='B')).to(torch.long).squeeze() 85 | 86 | def __getitem__(self, index): 87 | import numpy as np 88 | from torch import from_numpy as frnp 89 | if self.cnn: 90 | item = self.ff[index] 91 | if isinstance(item, list): 92 | if self.augment and self.augment_frac >= 1. or self.rng.uniform() < self.augment_frac: 93 | item = list(map(lambda x: augment_seq(x.decode(), self.augment) if (self.augment and (self.augment_frac >= 1. or self.rng.uniform() < self.augment_frac)) else x, item)) 94 | oh = self.tokenizer.batch_onehot_encode(item, padlen=self.ff.maxseqlen + self.tokenizer.includes_eos() + self.tokenizer.includes_bos()) 95 | return einops.rearrange(frnp(oh).to(self.device), "length batch emb -> batch emb length").float() 96 | else: 97 | oh = self.tokenizer.onehot_encode(item, padlen=self.ff.maxseqlen + self.tokenizer.includes_eos() + self.tokenizer.includes_bos()) 98 | return frnp(oh).to(self.device).float() 99 | else: 100 | seq = self.ff.access(index) 101 | if self.augment and (self.augment_frac >= 1. 
or self.rng.uniform() < self.augment_frac): 102 | seq = augment_seq(seq.decode(), self.augment) 103 | return frnp(self.tokenizer.batch_tokenize([seq], padlen=self.max_seq_len, batch_first=True, destchar='B')).to(torch.long).squeeze() 104 | 105 | def access(self, slc, stop=None, step=None): 106 | if isinstance(slc, int): 107 | slc = slice(slc, stop, step) 108 | from torch import from_numpy as frnp 109 | seqs = self.ff.access(slc.start, slc.stop, slc.step) 110 | toks = self.tokenizer.batch_tokenize(seqs, padlen=self.max_seq_len, batch_first=True, destchar='B') 111 | return frnp(toks).to(torch.long) 112 | def __len__(self): 113 | return self.ff.nseqs() 114 | def cleanup(self): 115 | pass 116 | 117 | class AugmentedSeqDataset(FlatFileDataset): 118 | def __init__(self, ff, tokenizer, augment=1, augment_frac=.5): 119 | super().__init__(ff, tokenizer, augment=augment, augment_frac=augment_frac) 120 | -------------------------------------------------------------------------------- /training/rcompute.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from time import time 4 | from datetime import datetime 5 | 6 | import bioseq 7 | from bioseq.decoders import SeqEncoder, HTransformer1D, XDecoder, XAutoregressiveWrapper, FastEncoder, FAutoregressiveWrapper, RecurrentTransformerWrapper, RecurrentAutoregressiveWrapper 8 | from bioseq.hattn import AutoregressiveWrapper as HAutoregressor 9 | from x_transformers import TransformerWrapper, Decoder 10 | from argparse import ArgumentParser as AP 11 | import numpy as np 12 | import torch 13 | from torch.utils.data import DataLoader 14 | 15 | import tqdm 16 | 17 | 18 | ap = AP() 19 | aa = ap.add_argument 20 | aa("--bos", action="store_true", help="Prepend a BOS tag to each sequence") 21 | aa("--eos", action="store_true", help="Append an EOS tag to each sequence") 22 | aa("--padchar", action="store_true", help="Treat padding characters as a unique identifier. (Default: padding is not given its own token)") 23 | aa("--alphabet", default="PROTEIN") 24 | aa("sequencefile", help="Input sequences; must be in FASTA or FASTQ format. All quality scores are ignored.") 25 | aa("--nepochs", type=float, default=1) 26 | aa("--batchsize", type=int, default=8) 27 | aa("--embdim", type=int, default=64) 28 | aa("--headdim", type=int, default=64) 29 | aa("--nheads", type=int, default=8) 30 | aa("--depth", "--nlayers", type=int, default=6) 31 | aa("--sparseemb", action='store_true', help="Use sparse embeddings.") 32 | aa("--learning-rate", "-R", type=float, default=2e-4) 33 | aa("--accumfreq", type=int, default=4) 34 | aa("--bidir-loss", type=float, const=1., nargs='?') 35 | aa("--clip-grad-norm", "--clip", type=float, default=.5) 36 | aa("--sparse-softmax", action='store_true', help="Whether to use differentiably sparse top-k") 37 | aa("--nthreads", "-p", type=int, default=1) 38 | aa("--gate-residual", action='store_true') 39 | aa("--window-length", "--window_length", type=int, default=128) 40 | aa("--max-mem-len", "--max-mem-length", type=int, default=-1) 41 | aa("--shift-mem-down", default=0, type=int) 42 | aa("--augment", type=int, default=0, help="Number of mutations to introduce while augmenting data. Default: 0.") 43 | aa("--augment-frac", type=float, default=.5, help="Fraction of sequences to augment.
Default: 0.5, but only used if --augment is set.") 44 | args = ap.parse_args() 45 | print("#Parameters: %s" % args, file=sys.stderr) 46 | LEARNING_RATE = args.learning_rate 47 | GRADIENT_ACCUMULATE_EVERY = args.accumfreq 48 | torch.set_num_threads(args.nthreads) 49 | if args.sparseemb: 50 | raise Exception("Cannot use sparse embeddings rn") 51 | 52 | def roundup(x): 53 | x = x + 1 54 | for shift in (1, 2, 4, 8, 16, 32): 55 | x |= x >> shift 56 | return x + 1 57 | 58 | if args.max_mem_len <= 0: 59 | args.max_mem_len = args.window_length 60 | print(f"max_mem_len unset; defaulting to window-length {args.window_length}", file=sys.stderr) 61 | 62 | BATCH_SIZE = args.batchsize 63 | argtup = (args.bos, args.eos, args.padchar, args.alphabet) 64 | tokd = bioseq.get_tokenizer_dict(args.bos, args.eos, args.padchar) 65 | try: 66 | tokenizer = tokd[args.alphabet.upper()] 67 | except KeyError: 68 | print(tokd.keys()) 69 | raise 70 | 71 | 72 | def cycle(x): 73 | while 1: 74 | yield from x 75 | 76 | 77 | dstr = str(datetime.now()).replace(" ", "_").replace(":", "-") 78 | ebpos = f"{'eos'if args.eos else 'noeos'}" + f".{'bos'if args.bos else 'nobos'}" 79 | if args.padchar: 80 | ebpos += ".padded" 81 | sequencefile = args.sequencefile 82 | 83 | ffp = args.sequencefile + ".ff" 84 | 85 | if os.path.isfile(ffp): 86 | print("Found existing flatfile", file=sys.stderr) 87 | ff = bioseq.FlatFile(ffp) 88 | else: 89 | print("Making flatfile", file=sys.stderr) 90 | ff = bioseq.FlatFile(args.sequencefile, ffp) 91 | ffl = bioseq.loaders.FlatFileDataset(ff, tokenizer, augment=args.augment, augment_frac=args.augment_frac) 92 | 93 | train_loader = cycle(DataLoader(ffl, batch_size=args.batchsize)) 94 | 95 | msl = ffl.max_seq_len 96 | nchunks = (args.window_length - 1 + msl) // args.window_length 97 | unique_name = f"{sequencefile}.{dstr}.{args.window_length}.{args.alphabet}.heads{args.nheads}.depth{args.depth}.dim{args.embdim}.maxseqlen{msl}.{ebpos}" 98 | # print("msl: %d. roundedup: %d\n" % (msl, roundup(msl))) 99 | # msl = roundup(msl) 100 | 101 | argdict = {} 102 | 103 | # First, make transformerwrapper, which tokenizes, trims, pads, etc. 
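# (Tokenization itself already happened in FlatFileDataset above; TransformerWrapper embeds the
#  integer tokens and attends over windows of max_seq_len=args.window_length, carrying up to
#  max_mem_len cached memories between windows.)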
104 | model = TransformerWrapper(num_tokens=tokenizer.alphabet_size(), max_seq_len=args.window_length, max_mem_len=args.max_mem_len, shift_mem_down=args.shift_mem_down, 105 | attn_layers = Decoder(dim=args.embdim, depth=args.depth, heads=args.nheads, rotary_pos_emb=True, rel_pos_bias=True, reversible=True, gate_residual=args.gate_residual)) 106 | # Then, make recurrenttransformerwrapper which makes this run as if it were a giant transformer model 107 | model = RecurrentTransformerWrapper(model, max_seq_len=msl) 108 | # Finally, apply the autoregressivewrapper 109 | model = RecurrentAutoregressiveWrapper(model) 110 | if torch.cuda.is_available(): 111 | print("Using CUDA") 112 | model = model.cuda() 113 | else: 114 | print("Using CPU with %d threads" % torch.get_num_threads()) 115 | 116 | optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) 117 | 118 | NUM_BATCHES = int((args.nepochs * len(ffl) + GRADIENT_ACCUMULATE_EVERY * args.batchsize - 1) / (GRADIENT_ACCUMULATE_EVERY * args.batchsize)) 119 | print("Num batches: ", NUM_BATCHES) 120 | print(f"Using seqfile {args.sequencefile}") 121 | 122 | tstart = time() 123 | num = 0 124 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'): 125 | model.train() 126 | 127 | for __ in range(GRADIENT_ACCUMULATE_EVERY): 128 | nextbatch = next(train_loader).to(torch.long) 129 | loss = model(nextbatch) 130 | loss.backward() 131 | 132 | print(f'training loss: {loss.item()} after {time() - tstart}s') 133 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm) 134 | optim.step() 135 | optim.zero_grad() 136 | from time import time 137 | print(f"Average time per item: {(time() - tstart) / (GRADIENT_ACCUMULATE_EVERY * args.batchsize * NUM_BATCHES)}") 138 | model.eval() 139 | costs = np.memmap(f"costs.{unique_name}.{time()}.f32.bin", mode="w+", shape=(len(ffl),), dtype=np.float32) 140 | for i in range(len(ffl)): 141 | costs[i] = model(ffl[i].to(torch.long).unsqueeze(0)) 142 | print(f"Total cost of dataset: {np.sum(costs)}") 143 | 144 | torch.save(model, f"hmodel.{unique_name}.pt") 145 | 146 | print(f"Total time: {time() - tstart}") 147 | -------------------------------------------------------------------------------- /graphseq/train.py: -------------------------------------------------------------------------------- 1 | # scripts/integrate_and_execute.py 2 | import os 3 | import torch 4 | import dgl 5 | import numpy as np 6 | from data_preparation import read_fasta, construct_knowledge_graph, construct_secondary_structure_graph 7 | from visualization import visualize_graph 8 | from sequence_encoders.xlstm import xLSTM, save_xlstm_model, load_xlstm_model 9 | from sequence_encoders.lstm import LSTM, save_lstm_model, load_lstm_model 10 | from sequence_encoders.bilstm import BiLSTM, save_bilstm_model, load_bilstm_model 11 | from sequence_encoders.attlstm import AttLSTM, save_attlstm_model, load_attlstm_model 12 | from sequence_encoders.bert import BERTSequenceEncoder, save_bert_model, load_bert_model 13 | from graph_encoders.gcn import GCN, save_gcn_model, load_gcn_model 14 | from graph_encoders.graphsage import GraphSAGE, save_graphsage_model, load_graphsage_model 15 | from graph_encoders.gat import GAT, save_gat_model, load_gat_model 16 | from distillation import train_with_distillation 17 | 18 | # Set device 19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 20 | 21 | # Define model parameters 22 | input_size = 4 # Assuming one-hot encoding of RNA sequences (A, U, C, G) 23 | hidden_size = 128 24 
| num_layers = 2 25 | dropout = 0.5 26 | 27 | # Initialize sequence models 28 | sequence_models = { 29 | "xLSTM": xLSTM(input_size, hidden_size, num_layers).to(device), 30 | "LSTM": LSTM(input_size, hidden_size, num_layers).to(device), 31 | "BiLSTM": BiLSTM(input_size, hidden_size, num_layers).to(device), 32 | "AttLSTM": AttLSTM(input_size, hidden_size, num_layers).to(device), 33 | "BERT": BERTSequenceEncoder().to(device) 34 | } 35 | 36 | # Define graph model parameters 37 | in_feats = 4 # Input feature size (e.g., one-hot encoded nucleotides) 38 | hidden_feats = 128 # Hidden feature size 39 | out_feats = 128 # Output feature size (latent space size) 40 | num_layers_gcn = 3 # Number of GCN layers 41 | activation = torch.nn.functional.relu # Activation function 42 | dropout = 0.5 # Dropout rate 43 | 44 | # Initialize graph models 45 | graph_models = { 46 | "GCN": GCN(in_feats, hidden_feats, out_feats).to(device), 47 | "GraphSAGE": GraphSAGE(in_feats, hidden_feats, out_feats, num_layers_gcn, activation, dropout).to(device), 48 | "GAT": GAT(in_feats, hidden_feats, out_feats, num_layers_gcn, num_heads=4, activation=activation, 49 | dropout=dropout).to(device) 50 | } 51 | 52 | 53 | # Function to train and save models with distillation 54 | def train_and_save_models(data_loader, teacher_model, student_model, model_name, graph_model, graph_model_name): 55 | # Define loss criterion and optimizer 56 | criterion = torch.nn.CrossEntropyLoss() 57 | optimizer = torch.optim.Adam(student_model.parameters(), lr=0.001) 58 | 59 | # Train student model with distillation 60 | student_model = train_with_distillation(teacher_model, student_model, data_loader, criterion, optimizer, device) 61 | 62 | # Save student model 63 | if model_name == "xLSTM": 64 | save_xlstm_model(student_model, f'../output/distilled_{model_name}_model.pth') 65 | elif model_name == "LSTM": 66 | save_lstm_model(student_model, f'../output/distilled_{model_name}_model.pth') 67 | elif model_name == "BiLSTM": 68 | save_bilstm_model(student_model, f'../output/distilled_{model_name}_model.pth') 69 | elif model_name == "AttLSTM": 70 | save_attlstm_model(student_model, f'../output/distilled_{model_name}_model.pth') 71 | elif model_name == "BERT": 72 | save_bert_model(student_model, f'../output/distilled_{model_name}_model.pth') 73 | 74 | 75 | # Process each .fa file in the samples directory 76 | data_dir = '../data/samples' 77 | output_dir = '../output/results' 78 | os.makedirs(output_dir, exist_ok=True) 79 | 80 | for filename in os.listdir(data_dir): 81 | if filename.endswith('.fa'): 82 | file_path = os.path.join(data_dir, filename) 83 | sequences = read_fasta(file_path) 84 | 85 | for i, sequence in enumerate(sequences): 86 | # Construct knowledge and secondary structure graphs 87 | knowledge_graph = construct_knowledge_graph(sequence) 88 | secondary_structure_graph = construct_secondary_structure_graph(sequence) 89 | 90 | # Visualize graphs 91 | visualize_graph(knowledge_graph, f"Knowledge Graph - {filename} - Sequence {i + 1}") 92 | visualize_graph(secondary_structure_graph, f"Secondary Structure Graph - {filename} - Sequence {i + 1}") 93 | 94 | # Encode sequences 95 | encoded_sequence = encode_one_hot(sequence) 96 | 97 | for model_name, model in sequence_models.items(): 98 | z_combined_sequence = encode_sequences(model, encoded_sequence, device) 99 | 100 | # Create DGL graph for GCN 101 | g = dgl.DGLGraph() 102 | g.add_nodes(len(sequence)) 103 | for j in range(len(sequence) - 1): 104 | g.add_edge(j, j + 1) 105 | g = dgl.add_self_loop(g) 106 | 
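# Node features for the chain graph built above: the (assumed) one-hot encoded sequence, one row per nucleotide.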
features = torch.tensor(encoded_sequence, dtype=torch.float32).to(device) 107 | 108 | for graph_model_name, graph_model in graph_models.items(): 109 | # Forward pass to get graph latent features 110 | graph_latent_features = graph_model(g, features) 111 | 112 | # Combine graph and sequence latent features 113 | combined_latent_features = torch.cat([graph_latent_features, z_combined_sequence.unsqueeze(0)], 114 | dim=1) 115 | 116 | # Save combined latent features to file 117 | output_file = os.path.join(output_dir, 118 | f"{filename}_seq{i + 1}_{model_name}_{graph_model_name}_latent_features.npy") 119 | np.save(output_file, combined_latent_features.cpu().numpy()) 120 | print( 121 | f"Saved combined latent features for {filename} sequence {i + 1} with {model_name} and {graph_model_name} to {output_file}") 122 | 123 | # Initialize student model 124 | student_model = model.__class__(*model.args).to(device) # Assuming the model has an args attribute 125 | 126 | # Train and save models with self-distillation 127 | train_and_save_models(data_loader, model, student_model, model_name, graph_model, graph_model_name) 128 | -------------------------------------------------------------------------------- /training/compute.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from time import time 4 | 5 | 6 | import bioseq 7 | 8 | from bioseq.decoders import SeqEncoder, HTransformer1D, XDecoder, XAutoregressiveWrapper, FastEncoder, FAutoregressiveWrapper 9 | from bioseq.hattn import AutoregressiveWrapper as HAutoregressor 10 | from argparse import ArgumentParser as AP 11 | import numpy as np 12 | import torch 13 | from torch.utils.data import DataLoader 14 | 15 | import tqdm 16 | 17 | 18 | ap = AP() 19 | aa = ap.add_argument 20 | aa("--bos", action="store_true", help="Prepend a BOS tag for sequence") 21 | aa("--eos", action="store_true", help="Append an EOS tag to sequence") 22 | aa("--padchar", action="store_true", help="Treat padding characters are unique identifier. (Default: no embeddings)") 23 | aa("--alphabet", default="PROTEIN") 24 | aa("sequencefile", help="Input sequences; Must be in Fasta or Fastq format. All quality scores are ignored.") 25 | aa("--nepochs", type=float, default=1) 26 | aa("--batchsize", type=int, default=8) 27 | aa("--embdim", type=int, default=64) 28 | aa("--headdim", type=int, default=64) 29 | aa("--nheads", type=int, default=8) 30 | aa("--depth", "--nlayers", type=int, default=6) 31 | aa("--sparseemb", action='store_true', help="Use sparse embeddings.") 32 | aa("--learning-rate", "-R", type=float, default=2e-4) 33 | aa("--accumfreq", type=int, default=4) 34 | aa("--bidir-loss", type=float, const=1., nargs='?') 35 | aa("--clip-grad-norm", "--clip", type=float, default=.5) 36 | aa("--transformer-type", "-T", choices=("Fast", "Hier", "X"), help="Type of transformer to use. Default: HTransformer1D (Hier)", default="X") 37 | aa("--sparse-softmax", action='store_true', help="Whether to use differentiably sparse top-k") 38 | aa("--nthreads", "-p", type=int, default=1) 39 | aa("--gate-residual", action='store_true') 40 | aa("--augment", type=int, default=0, help="Number of mutations to introduce while augmenting data. Default: 0.") 41 | aa("--augment-frac", type=float, default=.5, help="Fraction of sequences to augment. 
Default: 0.5, but only used if --augment is set.") 42 | args = ap.parse_args() 43 | print("#Parameters: %s" % args, file=sys.stderr) 44 | LEARNING_RATE = args.learning_rate 45 | GRADIENT_ACCUMULATE_EVERY = args.accumfreq 46 | torch.set_num_threads(args.nthreads) 47 | if args.sparseemb: 48 | raise Exception("Cannot use sparse embeddings rn") 49 | 50 | def roundup(x): 51 | x = x + 1 52 | for shift in (1, 2, 4, 8, 16, 32): 53 | x |= x >> shift 54 | return x + 1 55 | 56 | BATCH_SIZE = args.batchsize 57 | argtup = (args.bos, args.eos, args.padchar, args.alphabet) 58 | tokd = bioseq.get_tokenizer_dict(args.bos, args.eos, args.padchar) 59 | try: 60 | tokenizer = tokd[args.alphabet.upper()] 61 | except KeyError: 62 | print(tokd.keys()) 63 | raise 64 | 65 | 66 | def cycle(x): 67 | while 1: 68 | yield from x 69 | 70 | 71 | from datetime import datetime 72 | dstr = str(datetime.now()).replace(" ", "_").replace(":", "-") 73 | ebpos = f"{'eos'if args.eos else 'noeos'}" + f".{'bos'if args.bos else 'nobos'}" 74 | if args.padchar: 75 | ebpos += ".padded" 76 | sequencefile = args.sequencefile 77 | 78 | embeddings = bioseq.make_embedding(tokenizer, args.embdim, norm_type=2.0, sparse=args.sparseemb) 79 | 80 | ffp = args.sequencefile + ".ff" 81 | 82 | if os.path.isfile(ffp): 83 | print("Found existing flatfile", file=sys.stderr) 84 | ff = bioseq.FlatFile(ffp) 85 | else: 86 | print("Making flatfile", file=sys.stderr) 87 | ff = bioseq.FlatFile(args.sequencefile, ffp) 88 | ffl = bioseq.loaders.FlatFileDataset(ff, tokenizer, augment=args.augment, augment_frac=args.augment_frac) 89 | 90 | train_loader = cycle(DataLoader(ffl, batch_size=args.batchsize)) 91 | 92 | msl = ffl.max_seq_len 93 | if args.transformer_type == "Hier": 94 | nmsl = roundup(msl) 95 | print(f"Padding msl to next power of to: {msl}->{nmsl}",file=sys.stderr) 96 | msl = nmsl 97 | del nmsl 98 | unique_name = f"{sequencefile}.{dstr}.{args.transformer_type}.{args.alphabet}.heads{args.nheads}.depth{args.depth}.dim{args.embdim}.maxseqlen{msl}.{ebpos}" 99 | # print("msl: %d. 
roundedup: %d\n" % (msl, roundup(msl))) 100 | # msl = roundup(msl) 101 | 102 | tokl = bioseq.decoders.TokenizerLayer(tokenizer, padlen=msl) 103 | 104 | argdict = {} 105 | 106 | baseargs = {"num_tokens": tokenizer.alphabet_size(), "heads": args.nheads, "depth": args.depth, "dim": args.embdim, "max_seq_len": msl} 107 | if args.transformer_type == "Fast": 108 | TxType = FastEncoder 109 | baseargs.update({"query_sparse_softmax": args.sparse_softmax, "key_sparse_softmax": args.sparse_softmax}) 110 | elif args.transformer_type == "Hier": 111 | TxType = HTransformer1D 112 | baseargs.update({"causal": True, "reversible": True}) 113 | else: 114 | assert args.transformer_type == "X" 115 | TxType = XDecoder 116 | baseargs.update({"gate_residual": args.gate_residual, 'rotary_pos_emb': True, "reversible": True}) 117 | seq_encoder = SeqEncoder(tokl, embeddings, TxType, **baseargs) 118 | encoder = seq_encoder.encoder 119 | model = seq_encoder 120 | if torch.cuda.is_available(): 121 | print("Using CUDA") 122 | model = model.cuda() 123 | else: 124 | print("Using CPU with %d threads" % torch.get_num_threads()) 125 | if args.transformer_type == "Hier": 126 | model = HAutoregressor(model) 127 | elif args.transformer_type == "Fast": 128 | model = XAutoregressiveWrapper(model) 129 | else: 130 | model = XAutoregressiveWrapper(model) 131 | optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) 132 | 133 | NUM_BATCHES = int((args.nepochs * len(ffl) + GRADIENT_ACCUMULATE_EVERY * args.batchsize - 1) / (GRADIENT_ACCUMULATE_EVERY * args.batchsize)) 134 | print("Num batches: ", NUM_BATCHES) 135 | print(f"Using seqfile {args.sequencefile}") 136 | 137 | tstart = time() 138 | num = 0 139 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'): 140 | model.train() 141 | 142 | for __ in range(GRADIENT_ACCUMULATE_EVERY): 143 | gstart = time() 144 | nextbatch = next(train_loader).to(torch.long) 145 | loss = model(nextbatch) 146 | if args.bidir_loss: 147 | loss += args.bidir_loss * model(torch.flip(nextbatch, (1,))) 148 | loss.backward() 149 | 150 | print(f'training loss: {loss.item()} after {time() - tstart}s') 151 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm) 152 | optim.step() 153 | optim.zero_grad() 154 | from time import time 155 | print(f"Average time per item: {(time() - tstart) / (GRADIENT_ACCUMULATE_EVERY * args.batchsize * NUM_BATCHES)}") 156 | model.eval() 157 | costs = np.memmap(f"costs.{unique_name}.{time()}.f32.bin", mode="w+", shape=(len(ffl),), dtype=np.float32) 158 | for i in range(len(ffl)): 159 | costs[i] = model(ffl[i].to(torch.long).unsqueeze(0)) 160 | print(f"Total cost of dataset: {np.sum(costs)}") 161 | 162 | torch.save(model, f"hmodel.{unique_name}.pt") 163 | 164 | print(f"Total time: {time() - tstart}") 165 | -------------------------------------------------------------------------------- /training/cnnpretrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from argparse import ArgumentParser as AP 4 | from timeit import default_timer as time 5 | from functools import reduce 6 | 7 | import torch 8 | import random 9 | import numpy as np 10 | import einops 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.utils.data import DataLoader 14 | 15 | import tqdm 16 | 17 | 18 | import bioseq 19 | import bioseq.cnnencoder as cnn 20 | from bioseq.cnnencoder import RevConvInfiller 21 | from bioseq.blosum import augment_seq 22 | 23 | 24 | ap = AP() 25 | aa = 
ap.add_argument 26 | aa("inputfile", help="Path to input fasta") 27 | 28 | aa("--alphabet", default="PROTEIN") 29 | aa("--bos", action="store_true", help="Prepend a BOS tag for sequence") 30 | aa("--eos", action="store_true", help="Append an EOS tag to sequence") 31 | aa("--padchar", action="store_true", help="Treat padding characters are unique identifier. (Default: no embeddings)") 32 | aa("--batchsize", "--batch-size", type=int, default=64) 33 | aa("--emb-dim", "--embedding", type=int, default=64) 34 | aa("--revdepth", default=3, type=int, help="Depth of reversible CNN block (number of squeeze-excite layers)") 35 | aa("--totaldepth", default=3, type=int, help="Number of RevNet + Squeeze/Excite Block Pairs. Total number of reversible layers is revdepth * totaldepth.") 36 | aa("--noactivation", action='store_true', help="Whether or not to perform activation at the start of bottleneck layer") 37 | aa("--nthreads", type=int, default=1, help="Number of threads. Set to < 0 to use all threads") 38 | aa("--learning-rate", "-R", type=float, default=2e-4) 39 | aa("--accumfreq", type=int, default=4) 40 | aa("--augment", type=int, default=0, help="Number of mutations to introduce while augmenting data. Default: 0.") 41 | aa("--augment-frac", type=float, default=.5, help="Fraction of sequences to augment. Default: 0.5, but only used if --augment is set.") 42 | aa("--kernel-size", type=int, default=9) 43 | aa("--nepochs", type=float, default=1) 44 | aa("--maskfrac", type=float, default=0.15) 45 | aa("--seed", type=int, default=0) 46 | 47 | usecuda = torch.cuda.is_available() 48 | device = torch.device("cuda:0") if usecuda else torch.device("cpu") 49 | 50 | ap = ap.parse_args() 51 | args = ap 52 | nt = ap.nthreads 53 | if nt < 0: 54 | from multiprocessing import cpu_count as CC 55 | nt = CC() 56 | torch.set_num_threads(nt) 57 | LEARNING_RATE = ap.learning_rate 58 | torch.manual_seed(ap.seed) 59 | np.random.seed(ap.seed) 60 | random.seed(ap.seed) 61 | 62 | 63 | ff = None 64 | ffp = ap.inputfile + ".ff" 65 | if os.path.isfile(ffp): 66 | ff = bioseq.FlatFile(ffp) 67 | else: 68 | ff = bioseq.FlatFile(ap.inputfile, ffp) 69 | 70 | argtup = (ap.bos, ap.eos, ap.padchar, ap.alphabet) 71 | tokd = bioseq.get_tokenizer_dict(ap.bos, ap.eos, ap.padchar) 72 | try: 73 | tokenizer = tokd[ap.alphabet.upper()] 74 | except KeyError: 75 | print(tokd.keys()) 76 | raise 77 | 78 | # print("tokenizer eos: ", tokenizer.eos(), "bos", tokenizer.bos(), "padchar", tokenizer.pad(), "is padded", tokenizer.is_padded()) 79 | 80 | pl = padlen = ff.maxseqlen + tokenizer.includes_eos() + tokenizer.includes_bos() 81 | inchannels = tokenizer.alphabet_size() 82 | # print("Alphabet size: ", inchannels) 83 | model = cnn.RevConvNetwork1D(inchannels, channels=ap.emb_dim, kernel_size=ap.kernel_size, revdepth=ap.revdepth, totaldepth=ap.totaldepth, noactivation=ap.noactivation) 84 | model = RevConvInfiller(model, tokenizer, ap.emb_dim).to(device) 85 | if usecuda: 86 | model = nn.DataParallel(model) 87 | optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) 88 | ffl = bioseq.loaders.FlatFileDataset(ff, tokenizer, augment=ap.augment, augment_frac=ap.augment_frac, cnn=True, device=device) 89 | NUM_BATCHES = int((ap.nepochs * len(ffl) + ap.accumfreq * ap.batchsize - 1) / (ap.accumfreq * ap.batchsize)) 90 | 91 | def load_next(batch): 92 | oh = tokenizer.batch_onehot_encode(batch, padlen=pl) 93 | laidout = einops.rearrange(oh, "length batch emb -> batch emb length") 94 | return torch.from_numpy(laidout).device(device).float() 95 | 96 | def 
cycle(x): 97 | while 1: 98 | yield from x 99 | 100 | random_name = "".join(random.choice("abcdefghijklmn") for x in range(10)) + hex(reduce(lambda x, y: x ^ hash(y), sys.argv, 0)) 101 | 102 | # train_loader = cycle(DataLoader(ffl, batch_size=ap.batchsize)) 103 | 104 | assert 1. > ap.maskfrac > 0. 105 | 106 | tstart = time() 107 | num = 0 108 | global bstart 109 | bstart = 0 110 | PL = ff.maxseqlen + tokenizer.includes_eos() + tokenizer.includes_bos() 111 | def getbatch(): 112 | global bstart 113 | seqs = ff[bstart:bstart + ap.batchsize] 114 | LS = len(seqs) 115 | augmented_seq_indexes = np.where(np.random.rand(LS) < ap.augment_frac)[0] if ap.augment else [] 116 | for idx in augmented_seq_indexes: 117 | seqs[idx] = augment_seq(seqs[idx].decode(), ap.augment) 118 | 119 | mask = torch.rand(LS, ff.maxseqlen, device=device) > ap.maskfrac 120 | for row, seq in enumerate(seqs): 121 | mask[row,len(seq):] = 1 122 | # mask = [np.hstack([np.random.rand(len(seq)) > ap.maskfrac, np.zeros(ff.maxseqlen - len(seq), dtype=np.bool_)]) for seq in seqs] 123 | ohdata = tokenizer.batch_onehot_encode(seqs, padlen=PL) 124 | maskeddata = tokenizer.batch_onehot_encode(seqs, padlen=PL, mask=mask) 125 | ohdata, maskeddata = map(lambda x: einops.rearrange(torch.from_numpy(x), "length batch emb -> batch emb length").to(device).float().contiguous(), (ohdata, maskeddata)) 126 | bstart += ap.batchsize 127 | if bstart > len(ffl): 128 | bstart = 0 129 | return ohdata, maskeddata, seqs, mask 130 | 131 | losses = [] 132 | finished_seqs = 0 133 | saved_loss_id = 0 134 | 135 | startpos = int(tokenizer.includes_bos()) 136 | tstop = PL - int(tokenizer.includes_eos()) 137 | 138 | for bn in range(NUM_BATCHES): 139 | for __ in range(ap.accumfreq): 140 | gstart = time() 141 | oh, moh, seqs, masks = getbatch() 142 | stackmask = torch.logical_not(masks) 143 | assert moh.device == device 144 | emb, bo = model(moh) 145 | tokens = torch.from_numpy(tokenizer.batch_tokenize(seqs, padlen=PL)).to(device).long() 146 | # print("tokens shape", tokens.shape, stackmask.shape) 147 | ''' 148 | if 0: 149 | where = torch.where(stackmask) 150 | seltoks = tokens[startpos:tstop,:][stackmask.T].to(device).long() 151 | bot = bo[:,startpos:tstop,:][stackmask] 152 | loss = F.cross_entropy(bot, seltoks) 153 | else: 154 | ''' 155 | loss = F.cross_entropy(bo.transpose(1, 2), tokens.T) 156 | losses.append(float(loss.item())) 157 | # sys.exit(1) 158 | # Now, loss function for masked items 159 | # This will simply be the MASS objective. 
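# (Note: as written, the cross_entropy above is computed over every position, masked or not;
#  the commented-out block earlier shows the masked-positions-only variant.)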
160 | # Coming back around, it would benefit from turning it to the autoregressive loss 161 | loss.backward() 162 | finished_seqs += len(seqs) 163 | 164 | if not (bn & 127): 165 | print(f'[Batch {bn}] training loss: {loss.item()} after {time() - tstart}s after {finished_seqs} sequences; mean of last 10 {np.mean(losses[-10:])}', flush=True) 166 | if finished_seqs >= len(seqs): 167 | torch.save(model, f"model.{random_name}.{saved_loss_id}.pt") 168 | saved_loss_id += 1 169 | optim.step() 170 | optim.zero_grad() 171 | 172 | tend = time() 173 | np.array(losses).astype(np.float32).tofile(f"model.{random_name}.final.losses.f32") 174 | torch.save(model, f"model.{random_name}.final.pt") 175 | print("Training took %gs" % (tend - tstart)) 176 | -------------------------------------------------------------------------------- /bioseq/cnnencoder.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | from fast_transformer_pytorch.fast_transformer_pytorch import PreNorm 6 | 7 | def default(x, y): 8 | return x if x is not None else y 9 | 10 | class ConvBlock1D(nn.Module): 11 | def __init__(self, channels, outchannels=None, *, kernel_size=3, stride=None, padding=None, groups=1, dilation=1): 12 | outchannels = default(outchannels, channels) 13 | super().__init__() 14 | stride = default(stride, max(1, (kernel_size // 2) - 1)) 15 | padding = default(padding, max(1, (kernel_size // 2))) 16 | kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) 17 | self.seq = nn.Sequential( 18 | nn.Conv1d(in_channels=channels, out_channels=outchannels, 19 | kernel_size=kernel_size, padding=padding), 20 | nn.BatchNorm1d(num_features=outchannels), 21 | nn.ReLU(inplace=True) 22 | ) 23 | 24 | def forward(self, x): 25 | return self.seq(x) 26 | 27 | 28 | def batch_norm(x): 29 | """match Tensorflow batch norm settings""" 30 | return nn.BatchNorm1d(x, momentum=0.99, eps=0.001) 31 | 32 | 33 | class RevBottleneck(nn.Module): 34 | expansion = 1 35 | ''' 36 | Adapted from MemCNN's Resnet https://github.com/silvandeleemput/memcnn/blob/afd65198fb41e7339882ec55d35ad041edba13d7/memcnn/models/resnet.py 37 | ''' 38 | def __init__(self, inchannels, channels=None, stride=1, downsample=None, noactivation=False, expansion=4): 39 | channels = default(channels, inchannels) 40 | super(RevBottleneck, self).__init__() 41 | makelayer = lambda: BottleneckSub(inchannels // 2, channels // 2, stride, noactivation, expansion=expansion) 42 | from memcnn import create_coupling, InvertibleModuleWrapper 43 | coupling = create_coupling(Fm=makelayer(), Gm=makelayer(), coupling='additive') 44 | self.revblock = InvertibleModuleWrapper(fn=coupling, keep_input=False) 45 | self.downsample = downsample 46 | self.stride = stride 47 | self.expansion = expansion 48 | 49 | def forward(self, x): 50 | if self.downsample is not None: 51 | out = self.bottleneck_sub(x) 52 | residual = self.downsample(x) 53 | out += residual 54 | else: 55 | out = self.revblock(x) 56 | return out 57 | 58 | 59 | class BottleneckSub(nn.Module): 60 | ''' 61 | Adapted from MemCNN's Resnet https://github.com/silvandeleemput/memcnn/blob/afd65198fb41e7339882ec55d35ad041edba13d7/memcnn/models/resnet.py 62 | Unlike their class, this increases then decreases the embedding dimension, returning dat of the same shape. 
63 | ''' 64 | def __init__(self, inchannels, channels=None, kernel_size=3, stride=1, noactivation=False, expansion=4): 65 | if channels is None: 66 | channels = inchannels 67 | super(BottleneckSub, self).__init__() 68 | self.noactivation = noactivation 69 | if not self.noactivation: 70 | self.bn1 = batch_norm(inchannels) 71 | self.conv1 = nn.Conv1d(inchannels, channels, kernel_size=1, bias=False) 72 | self.bn2 = batch_norm(channels) 73 | self.conv2 = nn.Conv1d(channels, channels, kernel_size=kernel_size, stride=stride, 74 | bias=False, padding='same') 75 | self.bn3 = batch_norm(channels) 76 | expanded_channels = channels * expansion 77 | self.conv3 = nn.Conv1d(channels, expanded_channels, kernel_size=1, bias=False) 78 | self.relu = nn.ReLU(inplace=True) 79 | self.expansion = expansion 80 | self.bn4 = batch_norm(expanded_channels) 81 | self.conv4 = nn.Conv1d(expanded_channels, channels, kernel_size=kernel_size, bias=False, padding='same') 82 | 83 | def forward(self, x): 84 | if not self.noactivation: 85 | x = self.bn1(x) 86 | x = self.relu(x) 87 | x = self.conv1(x) 88 | x = self.bn2(x) 89 | x = self.relu(x) 90 | x = self.conv2(x) 91 | x = self.bn3(x) 92 | x = self.relu(x) 93 | x = self.conv3(x) 94 | x = self.bn4(x) 95 | x = self.conv4(x) 96 | return x 97 | 98 | 99 | class RevConvBlock1D(nn.Module): 100 | def __init__(self, channels, additive=True, padding=None, kernel_size=3, dilation=1, groups=1, stride=1, depth=2): 101 | """ 102 | Reversible Conv1DBlock 103 | Uses additive coupling by default; affine coupling could in principle be selected via additive=False, but is not currently supported. 104 | Has padding, kernel_size, dilation, groups, and stride arguments which are passed to ConvBlock1D 105 | """ 106 | super().__init__() 107 | if not additive: 108 | raise ValueError("Only additive is supported currently.
MemCNN seems to have a problem with affine.") 109 | from memcnn import AffineCoupling, AdditiveCoupling, InvertibleModuleWrapper 110 | if channels & 1: 111 | raise RuntimeError("Channels must be divisible by 2 in order to perform a reversible conv block") 112 | self.channels = channels 113 | halfchan = channels >> 1 114 | Coupling = [AffineCoupling, AdditiveCoupling][int(additive)] 115 | # Set coupling function, then create the invertible block 116 | makeblock = lambda : ConvBlock1D(channels=halfchan, kernel_size=kernel_size, padding=padding, dilation=dilation, stride=stride, groups=groups) 117 | self.invmods = [Coupling(Fm=makeblock(), Gm=makeblock()) for i in range(depth)] 118 | self.wrappers = [InvertibleModuleWrapper(fn=x, keep_input=True, keep_input_inverse=True) for x in self.invmods] 119 | self.invmodw = nn.Sequential(*self.wrappers) 120 | 121 | def forward(self, x): 122 | return self.invmodw(x) 123 | 124 | 125 | class RevConvNetwork1D(nn.Module): 126 | def __init__(self, inchannels, channels=None, padding=None, kernel_size=3, revdepth=3, totaldepth=3, noactivation=False): 127 | super().__init__() 128 | import itertools 129 | channels = default(channels, inchannels) 130 | layers = [ConvBlock1D(inchannels, channels)] 131 | for _ in range(totaldepth): 132 | layers.append(RevConvBlock1D(channels, padding=padding, kernel_size=kernel_size, depth=revdepth)) 133 | layers.append(BottleneckSub(channels, kernel_size=kernel_size, noactivation=noactivation)) 134 | self.seq = nn.Sequential(*layers) 135 | 136 | def forward(self, x): 137 | return self.seq(x) 138 | 139 | 140 | class RevConvClassifier(nn.Module): 141 | def __init__(self, inchannels, num_classes, *, channels=None, padding=None, kernel_size=3, revdepth=3, totaldepth=3, noactivation=False, softmax=None): 142 | super().__init__() 143 | channels = default(channels, inchannels) 144 | self.net = RevConvNetwork1D(inchannels=inchannels, channels=channels, padding=padding, kernel_size=kernel_size, revdepth=revdepth, totaldepth=totaldepth, noactivation=noactivation) 145 | self.pool = nn.AdaptiveAvgPool1d(1) 146 | self.fc = nn.Linear(channels, num_classes) 147 | self.softmax = default(softmax, nn.Softmax(-1)) 148 | 149 | def logits(self, x): 150 | ''' 151 | Returns logits 152 | ''' 153 | embeddings = self.net(x) 154 | pooled = self.pool(embeddings).squeeze(-1) # Average across data 155 | return self.fc(pooled) 156 | 157 | def forward(self, x): 158 | return self.softmax(self.logits(x)) # class probabilities; use .logits(x) for raw scores 159 | 160 | 161 | 162 | class ResConvBlock1D(nn.Module): 163 | def __init__(self, channels, additive=True, padding=None, kernel_size=3, dilation=1, groups=1, stride=1, depth=2, downsample=None): 164 | """ 165 | Residual (not reversible) wrapper around RevConvBlock1D: adds a skip connection around the reversible block. 166 | Uses additive coupling by default; the additive flag is passed through to RevConvBlock1D. 167 | Has padding, kernel_size, dilation, groups, stride, and depth arguments which are passed to RevConvBlock1D, plus an optional downsample module applied on the residual path. 168 | """ 169 | super().__init__() 170 | self.block = RevConvBlock1D(channels, additive=additive, padding=padding, kernel_size=kernel_size, dilation=dilation, groups=groups, stride=stride, depth=depth) 171 | self.downsample = downsample 172 | self.expansion = 1 173 | 174 | def forward(self, x): 175 | res = x 176 | out = self.block(x) 177 | if self.downsample is not None: 178 | res = self.downsample(x) 179 | out += res 180 | return out 181 | 182 | class RevConvInfiller(nn.Module): 183 | def __init__(self, net, tokenizer, embdim): 184 | super().__init__() 185 | self.net = net 186 | self.tokenizer = tokenizer 187 | self.fc = nn.Linear(embdim,
tokenizer.alphabet_size()) 188 | def forward(self, x): 189 | emb = self.net(x) 190 | return emb, self.fc(emb.transpose(2, 1)) 191 | -------------------------------------------------------------------------------- /src/fxstats.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "mio.hpp" 6 | #include "span.hpp" 7 | #include "kseq.h" 8 | namespace py = pybind11; 9 | 10 | KSEQ_INIT(gzFile, gzread) 11 | 12 | std::vector getlens(const std::string &path) { 13 | gzFile fp = gzopen(path.data(), "r"); 14 | if(fp == nullptr) throw std::runtime_error(path + " failed to open"); 15 | kseq_t *kseq = kseq_init(fp); 16 | std::vector lens; 17 | while(kseq_read(kseq) >= 0) 18 | lens.push_back(kseq->seq.l); 19 | kseq_destroy(kseq); 20 | gzclose(fp); 21 | return lens; 22 | } 23 | 24 | 25 | struct FlatFileIterator; 26 | struct FlatFile { 27 | const std::string path_; 28 | mio::mmap_source data_; 29 | const size_t nseqs_; 30 | nonstd::span offsets_; 31 | const size_t seq_offset_; 32 | uint32_t max_seq_len_; 33 | static FlatFile make(std::string inpath, std::string outpath) { 34 | if(outpath.empty()) { 35 | outpath = inpath + ".ff"; 36 | } 37 | std::vector offsets{0}; 38 | uint32_t max_seq_len = 0; 39 | gzFile fp = gzopen(inpath.data(), "r"); 40 | if(fp == nullptr) throw std::runtime_error(inpath + " failed to open"); 41 | kseq_t *ks = kseq_init(fp); 42 | std::string cseq; 43 | std::vector seqs; 44 | while(kseq_read(ks) >= 0) { 45 | if(ks->seq.l > 0xFFFFFFFFu) throw std::invalid_argument("Cannot handle sequences longer than 2^32 - 1"); 46 | max_seq_len = std::max(max_seq_len, uint32_t(ks->seq.l)); 47 | offsets.push_back(offsets.back() + ks->seq.l); 48 | seqs.emplace_back(ks->seq.s, ks->seq.l); 49 | } 50 | std::FILE *ofp = std::fopen(outpath.data(), "w"); 51 | if(!ofp) throw std::runtime_error(outpath + " could not be opened for writing"); 52 | uint64_t nseqs = seqs.size(); 53 | // 8 bytes: number of sequences 54 | // 8 * nseqs: offsets to start of sequences 55 | std::fwrite(&nseqs, sizeof(nseqs), 1, ofp); 56 | std::fwrite(offsets.data(), sizeof(uint64_t), offsets.size(), ofp); 57 | for(const auto &s: seqs) { 58 | std::fwrite(s.data(), 1, s.size(), ofp); 59 | } 60 | kseq_destroy(ks); 61 | gzclose(fp); 62 | std::fclose(ofp); 63 | return FlatFile(outpath, max_seq_len); 64 | } 65 | FlatFile(std::string inpath, std::string outpath): FlatFile(FlatFile::make(inpath, outpath)) {} 66 | FlatFile(std::string path, py::ssize_t mslen=-1): path_(path), data_(path_), nseqs_(*(uint64_t *)(data_.data())), 67 | offsets_((size_t *)(data_.data()) + 1, nseqs_ + 1), seq_offset_((nseqs_ + 2) * 8), max_seq_len_(mslen) 68 | { 69 | if(mslen < 0) { 70 | max_seq_len_ = 0; 71 | for(size_t i = 0; i < nseqs(); ++i) { 72 | max_seq_len_ = std::max(uint32_t(length(i)), max_seq_len_); 73 | } 74 | } 75 | } 76 | FlatFile(FlatFile &&o): path_(o.path_), nseqs_(o.nseqs_), seq_offset_(o.seq_offset_), max_seq_len_(o.max_seq_len_) { 77 | std::swap_ranges((uint8_t *)&data_,(uint8_t *)&data_ + sizeof(data_), (uint8_t *)&o.data_); 78 | } 79 | FlatFile(const FlatFile &o) = default; 80 | auto &offsets() {return offsets_;} 81 | const auto &offsets() const {return offsets_;} 82 | uint32_t max_seq_len() const {return max_seq_len_;} 83 | size_t nseqs() const { 84 | return nseqs_; 85 | } 86 | size_t seq_offset() const { 87 | return seq_offset_; 88 | } 89 | size_t length(size_t idx) const { 90 | return offsets_[idx + 1] - offsets_[idx]; 91 | } 92 | const 
char*offset(size_t idx) const { 93 | return data_.data() + offsets_[idx] + seq_offset_; 94 | } 95 | py::list range_access(py::array idx) const { 96 | py::array_t ids(idx); 97 | py::list ret; 98 | auto bi = ids.request(); 99 | const uint64_t *ptr = (const uint64_t *)bi.ptr; 100 | for(py::ssize_t i = 0; i < bi.size; ++i) { 101 | py::ssize_t ind = ptr[i]; 102 | if(ind < 0) ind = nseqs_ - ind; 103 | ret.append(access(ptr[ind])); 104 | } 105 | return ret; 106 | } 107 | py::list range_access(py::slice slc) const { 108 | size_t start = 0, stop = 0, step = 0, slicelength = 0; 109 | if(!slc.compute(this->nseqs_, &start, &stop, &step, &slicelength)) 110 | throw py::error_already_set(); 111 | #ifndef NDEBUG 112 | std::fprintf(stderr, "Start, stop, step %zd %zd %zd\n", start, stop, step); 113 | #endif 114 | return range_access(start, stop, step); 115 | } 116 | py::array indptr() const { 117 | py::array ret({{nseqs() + 1}}, offsets_.data()); 118 | uint64_t *ptr = (uint64_t *)ret.request().ptr; 119 | std::copy(offsets().begin(), offsets().end(), ptr); 120 | return ret; 121 | } 122 | py::list range_access(py::ssize_t i, py::ssize_t j, py::ssize_t step) const { 123 | py::list ret; 124 | if(step == 0) throw std::invalid_argument("step must be nonzero"); 125 | for(py::ssize_t idx = i; step > 0 ? idx < j: idx > j; ret.append(access(idx)), idx += step); 126 | return ret; 127 | } 128 | py::bytearray access(size_t i) const { 129 | if(i >= nseqs_) throw std::out_of_range("Accessing sequence out of range"); 130 | //std::fprintf(stderr, "First 4 chars: %s\n", std::string(offset(i), 4).data()); 131 | //std::fprintf(stderr, "Accessing string %s\n", std::string(offset(i), length(i)).data()); 132 | return py::bytearray(offset(i), length(i)); 133 | } 134 | }; 135 | 136 | struct FlatFileIterator { 137 | const FlatFile *ptr_; 138 | size_t start_; 139 | size_t stop_; 140 | FlatFileIterator(const FlatFile &src): ptr_(&src), start_(0), stop_(src.nseqs()) {} 141 | //FlatFileIterator(const FlatFileIterator &) = default; 142 | FlatFileIterator(const FlatFile *ptr, size_t start, size_t stop): ptr_(ptr), start_(start), stop_(stop) {} 143 | FlatFileIterator &next() { 144 | if(__builtin_expect(++start_== stop_, 0)) 145 | throw py::stop_iteration("End of iterator"); 146 | return *this; 147 | } 148 | py::bytearray sequence() const { 149 | return ptr_->access(start_); 150 | } 151 | }; 152 | //#undef bytearray 153 | 154 | void init_fxstats(py::module &m) { 155 | py::class_(m, "FlatFileIterator") 156 | .def(py::init()) 157 | .def("__iter__", [](const FlatFileIterator &x) {return x;}) 158 | .def("__next__", [](FlatFileIterator &x) {return x.next();}) 159 | .def_property_readonly("sequence", &FlatFileIterator::sequence) 160 | .def_property_readonly("seq", &FlatFileIterator::sequence); 161 | 162 | py::class_(m, "FlatFile") 163 | .def(py::init(), py::arg("inputfile"), py::arg("maxseqlen") = -1) 164 | .def(py::init()) 165 | .def_readonly("path", &FlatFile::path_) 166 | .def("access", &FlatFile::access) 167 | .def("access", [](const FlatFile &x, py::slice slc) {return x.range_access(slc);}) 168 | .def("access", [](const FlatFile &x, size_t i, size_t j, size_t step) {return x.range_access(i, j, step);}, 169 | py::arg("start"), py::arg("stop"), py::arg("step") = 1) 170 | .def("__len__", &FlatFile::nseqs) 171 | .def("nseqs", &FlatFile::nseqs) 172 | .def("size", &FlatFile::nseqs) 173 | .def("seq_offset", &FlatFile::seq_offset) 174 | .def("indptr", &FlatFile::indptr) 175 | .def_property_readonly("maxseqlen", &FlatFile::max_seq_len) 176 | 
.def_property_readonly("max_seq_len", &FlatFile::max_seq_len) 177 | .def("__iter__", [](const FlatFile &x) {return FlatFileIterator(x);}, py::keep_alive<0, 1>()) 178 | .def("__getitem__", [](const FlatFile &x, py::ssize_t idx) -> py::bytearray { 179 | py::ssize_t ai; 180 | if(idx >= 0) { 181 | ai = idx; 182 | } else { 183 | if(idx < -py::ssize_t(x.nseqs())) 184 | throw std::out_of_range("For a negative index, idx must be >= -len(x)"); 185 | ai = x.nseqs() + idx; 186 | } 187 | return x.access(ai); 188 | }) 189 | .def("__getitem__", [](const FlatFile &x, py::slice slc) -> py::list { 190 | return x.range_access(slc); 191 | }) 192 | .def("__getitem__", [](const FlatFile &x, py::array arr) -> py::list { 193 | return x.range_access(arr); 194 | }); 195 | 196 | 197 | #if 0 198 | m.def("makeflat", [](std::string inpath, std::string outpath) { 199 | return FlatFile::make(inpath, outpath); 200 | }, py::arg("input"), py::arg("output") = "", py::return_value_policy::move); 201 | #endif 202 | m.def("getstats", [](py::sequence items) { 203 | std::vector paths; 204 | for(const auto item: items) 205 | paths.emplace_back(item.cast()); 206 | py::list alist = py::list(); 207 | while(py::len(alist) < paths.size()) { 208 | alist.append(py::none()); 209 | } 210 | for(size_t i = 0; i < paths.size(); ++i) { 211 | const auto vals = getlens(paths[i]); 212 | const py::ssize_t sz = vals.size(); 213 | py::array_t ret(std::vector{sz}); 214 | py::buffer_info bi = ret.request(); 215 | std::copy(vals.data(), vals.data() + sz, (size_t *)bi.ptr); 216 | alist[i] = ret; 217 | } 218 | return alist; 219 | }); 220 | } 221 | -------------------------------------------------------------------------------- /src/kseq.h: -------------------------------------------------------------------------------- 1 | /* The MIT License 2 | 3 | Copyright (c) 2008, 2009, 2011 Attractive Chaos 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining 6 | a copy of this software and associated documentation files (the 7 | "Software"), to deal in the Software without restriction, including 8 | without limitation the rights to use, copy, modify, merge, publish, 9 | distribute, sublicense, and/or sell copies of the Software, and to 10 | permit persons to whom the Software is furnished to do so, subject to 11 | the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be 14 | included in all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 19 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS 20 | BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN 21 | ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 22 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | */ 25 | 26 | /* Last Modified: 05MAR2012 */ 27 | 28 | #ifndef AC_KSEQ_H 29 | #define AC_KSEQ_H 30 | 31 | #include 32 | #include 33 | #include 34 | 35 | #define KS_SEP_SPACE 0 // isspace(): \t, \n, \v, \f, \r 36 | #define KS_SEP_TAB 1 // isspace() && !' 
' 37 | #define KS_SEP_LINE 2 // line separator: "\n" (Unix) or "\r\n" (Windows) 38 | #define KS_SEP_MAX 2 39 | 40 | #define __KS_TYPE(type_t) \ 41 | typedef struct __kstream_t { \ 42 | unsigned char *buf; \ 43 | int begin, end, is_eof; \ 44 | type_t f; \ 45 | } kstream_t; 46 | 47 | #define ks_err(ks) ((ks)->end == -1) 48 | #define ks_eof(ks) ((ks)->is_eof && (ks)->begin >= (ks)->end) 49 | #define ks_rewind(ks) ((ks)->is_eof = (ks)->begin = (ks)->end = 0) 50 | 51 | #define __KS_BASIC(type_t, __bufsize) \ 52 | static inline kstream_t *ks_init(type_t f) \ 53 | { \ 54 | kstream_t *ks = (kstream_t*)calloc(1, sizeof(kstream_t)); \ 55 | ks->f = f; \ 56 | ks->buf = (unsigned char*)malloc(__bufsize); \ 57 | return ks; \ 58 | } \ 59 | static inline void ks_destroy(kstream_t *ks) \ 60 | { \ 61 | if (ks) { \ 62 | free(ks->buf); \ 63 | free(ks); \ 64 | } \ 65 | } 66 | 67 | #define __KS_GETC(__read, __bufsize) \ 68 | static inline int ks_getc(kstream_t *ks) \ 69 | { \ 70 | if (ks_err(ks)) return -3; \ 71 | if (ks->is_eof && ks->begin >= ks->end) return -1; \ 72 | if (ks->begin >= ks->end) { \ 73 | ks->begin = 0; \ 74 | ks->end = __read(ks->f, ks->buf, __bufsize); \ 75 | if (ks->end == 0) { ks->is_eof = 1; return -1;} \ 76 | if (ks->end == -1) { ks->is_eof = 1; return -3;}\ 77 | } \ 78 | return (int)ks->buf[ks->begin++]; \ 79 | } 80 | 81 | #ifndef KSTRING_T 82 | #define KSTRING_T kstring_t 83 | typedef struct __kstring_t { 84 | size_t l, m; 85 | char *s; 86 | } kstring_t; 87 | #endif 88 | 89 | #ifndef kroundup32 90 | #define kroundup32(x) (--(x), (x)|=(x)>>1, (x)|=(x)>>2, (x)|=(x)>>4, (x)|=(x)>>8, (x)|=(x)>>16, ++(x)) 91 | #endif 92 | 93 | #define __KS_GETUNTIL(__read, __bufsize) \ 94 | static int ks_getuntil2(kstream_t *ks, int delimiter, kstring_t *str, int *dret, int append) \ 95 | { \ 96 | int gotany = 0; \ 97 | if (dret) *dret = 0; \ 98 | str->l = append? str->l : 0; \ 99 | for (;;) { \ 100 | int i; \ 101 | if (ks_err(ks)) return -3; \ 102 | if (ks->begin >= ks->end) { \ 103 | if (!ks->is_eof) { \ 104 | ks->begin = 0; \ 105 | ks->end = __read(ks->f, ks->buf, __bufsize); \ 106 | if (ks->end == 0) { ks->is_eof = 1; break; } \ 107 | if (ks->end == -1) { ks->is_eof = 1; return -3; } \ 108 | } else break; \ 109 | } \ 110 | if (delimiter == KS_SEP_LINE) { \ 111 | for (i = ks->begin; i < ks->end; ++i) \ 112 | if (ks->buf[i] == '\n') break; \ 113 | } else if (delimiter > KS_SEP_MAX) { \ 114 | for (i = ks->begin; i < ks->end; ++i) \ 115 | if (ks->buf[i] == delimiter) break; \ 116 | } else if (delimiter == KS_SEP_SPACE) { \ 117 | for (i = ks->begin; i < ks->end; ++i) \ 118 | if (isspace(ks->buf[i])) break; \ 119 | } else if (delimiter == KS_SEP_TAB) { \ 120 | for (i = ks->begin; i < ks->end; ++i) \ 121 | if (isspace(ks->buf[i]) && ks->buf[i] != ' ') break; \ 122 | } else i = 0; /* never come to here! 
*/ \ 123 | if (str->m - str->l < (size_t)(i - ks->begin + 1)) { \ 124 | str->m = str->l + (i - ks->begin) + 1; \ 125 | kroundup32(str->m); \ 126 | str->s = (char*)realloc(str->s, str->m); \ 127 | } \ 128 | gotany = 1; \ 129 | memcpy(str->s + str->l, ks->buf + ks->begin, i - ks->begin); \ 130 | str->l = str->l + (i - ks->begin); \ 131 | ks->begin = i + 1; \ 132 | if (i < ks->end) { \ 133 | if (dret) *dret = ks->buf[i]; \ 134 | break; \ 135 | } \ 136 | } \ 137 | if (!gotany && ks_eof(ks)) return -1; \ 138 | if (str->s == 0) { \ 139 | str->m = 1; \ 140 | str->s = (char*)calloc(1, 1); \ 141 | } else if (delimiter == KS_SEP_LINE && str->l > 1 && str->s[str->l-1] == '\r') --str->l; \ 142 | str->s[str->l] = '\0'; \ 143 | return str->l; \ 144 | } \ 145 | static inline int ks_getuntil(kstream_t *ks, int delimiter, kstring_t *str, int *dret) \ 146 | { return ks_getuntil2(ks, delimiter, str, dret, 0); } 147 | 148 | #define KSTREAM_INIT(type_t, __read, __bufsize) \ 149 | __KS_TYPE(type_t) \ 150 | __KS_BASIC(type_t, __bufsize) \ 151 | __KS_GETC(__read, __bufsize) \ 152 | __KS_GETUNTIL(__read, __bufsize) 153 | 154 | #define kseq_rewind(ks) ((ks)->last_char = (ks)->f->is_eof = (ks)->f->begin = (ks)->f->end = 0) 155 | 156 | #define __KSEQ_BASIC(SCOPE, type_t) \ 157 | SCOPE kseq_t *kseq_init(type_t fd) \ 158 | { \ 159 | kseq_t *s = (kseq_t*)calloc(1, sizeof(kseq_t)); \ 160 | s->f = ks_init(fd); \ 161 | return s; \ 162 | } \ 163 | SCOPE void kseq_destroy(kseq_t *ks) \ 164 | { \ 165 | if (!ks) return; \ 166 | free(ks->name.s); free(ks->comment.s); free(ks->seq.s); free(ks->qual.s); \ 167 | ks_destroy(ks->f); \ 168 | free(ks); \ 169 | } 170 | 171 | /* Return value: 172 | >=0 length of the sequence (normal) 173 | -1 end-of-file 174 | -2 truncated quality string 175 | -3 error reading stream 176 | */ 177 | #define __KSEQ_READ(SCOPE) \ 178 | SCOPE int kseq_read(kseq_t *seq) \ 179 | { \ 180 | int c,r; \ 181 | kstream_t *ks = seq->f; \ 182 | if (seq->last_char == 0) { /* then jump to the next header line */ \ 183 | while ((c = ks_getc(ks)) >= 0 && c != '>' && c != '@'); \ 184 | if (c < 0) return c; /* end of file or error*/ \ 185 | seq->last_char = c; \ 186 | } /* else: the first header char has been read in the previous call */ \ 187 | seq->comment.l = seq->seq.l = seq->qual.l = 0; /* reset all members */ \ 188 | if ((r=ks_getuntil(ks, 0, &seq->name, &c)) < 0) return r; /* normal exit: EOF or error */ \ 189 | if (c != '\n') ks_getuntil(ks, KS_SEP_LINE, &seq->comment, 0); /* read FASTA/Q comment */ \ 190 | if (seq->seq.s == 0) { /* we can do this in the loop below, but that is slower */ \ 191 | seq->seq.m = 256; \ 192 | seq->seq.s = (char*)malloc(seq->seq.m); \ 193 | } \ 194 | while ((c = ks_getc(ks)) >= 0 && c != '>' && c != '+' && c != '@') { \ 195 | if (c == '\n') continue; /* skip empty lines */ \ 196 | seq->seq.s[seq->seq.l++] = c; /* this is safe: we always have enough space for 1 char */ \ 197 | ks_getuntil2(ks, KS_SEP_LINE, &seq->seq, 0, 1); /* read the rest of the line */ \ 198 | } \ 199 | if (c == '>' || c == '@') seq->last_char = c; /* the first header char has been read */ \ 200 | if (seq->seq.l + 1 >= seq->seq.m) { /* seq->seq.s[seq->seq.l] below may be out of boundary */ \ 201 | seq->seq.m = seq->seq.l + 2; \ 202 | kroundup32(seq->seq.m); /* rounded to the next closest 2^k */ \ 203 | seq->seq.s = (char*)realloc(seq->seq.s, seq->seq.m); \ 204 | } \ 205 | seq->seq.s[seq->seq.l] = 0; /* null terminated string */ \ 206 | if (c != '+') return seq->seq.l; /* FASTA */ \ 207 | if (seq->qual.m < seq->seq.m) { 
/* allocate memory for qual in case insufficient */ \ 208 | seq->qual.m = seq->seq.m; \ 209 | seq->qual.s = (char*)realloc(seq->qual.s, seq->qual.m); \ 210 | } \ 211 | while ((c = ks_getc(ks)) >= 0 && c != '\n'); /* skip the rest of '+' line */ \ 212 | if (c == -1) return -2; /* error: no quality string */ \ 213 | while ((c = ks_getuntil2(ks, KS_SEP_LINE, &seq->qual, 0, 1) >= 0 && seq->qual.l < seq->seq.l)); \ 214 | if (c == -3) return -3; /* stream error */ \ 215 | seq->last_char = 0; /* we have not come to the next header line */ \ 216 | if (seq->seq.l != seq->qual.l) return -2; /* error: qual string is of a different length */ \ 217 | return seq->seq.l; \ 218 | } 219 | 220 | #define __KSEQ_TYPE(type_t) \ 221 | typedef struct { \ 222 | kstring_t name, comment, seq, qual; \ 223 | int last_char; \ 224 | kstream_t *f; \ 225 | } kseq_t; 226 | 227 | #define KSEQ_INIT2(SCOPE, type_t, __read) \ 228 | KSTREAM_INIT(type_t, __read, 16384) \ 229 | __KSEQ_TYPE(type_t) \ 230 | __KSEQ_BASIC(SCOPE, type_t) \ 231 | __KSEQ_READ(SCOPE) 232 | 233 | #define KSEQ_INIT(type_t, __read) KSEQ_INIT2(static, type_t, __read) 234 | 235 | #define KSEQ_DECLARE(type_t) \ 236 | __KS_TYPE(type_t) \ 237 | __KSEQ_TYPE(type_t) \ 238 | extern kseq_t *kseq_init(type_t fd); \ 239 | void kseq_destroy(kseq_t *ks); \ 240 | int kseq_read(kseq_t *seq); 241 | 242 | #endif 243 | -------------------------------------------------------------------------------- /src/poa.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "spoa/spoa.hpp" 9 | #include "spoa/graph.hpp" 10 | namespace py = pybind11; 11 | 12 | using namespace spoa; 13 | using Edge = spoa::Graph::Edge; 14 | using Node = spoa::Graph::Node; 15 | 16 | 17 | static constexpr int32_t GLOBAL = 1; 18 | 19 | std::unique_ptr make_engine() { 20 | return spoa::AlignmentEngine::Create( static_cast(GLOBAL), 5, -4, -8, -6, -10, -4); 21 | } 22 | 23 | struct GraphRepr { 24 | std::vector bases; 25 | std::vector strings; 26 | std::vector> outEdges; 27 | std::vector> inEdges; 28 | std::vector> alignedNodes; 29 | std::vector alignments; 30 | std::string consensus; 31 | }; 32 | 33 | struct SequenceGroup { 34 | py::list sequences; 35 | std::vector scores; 36 | std::string consensus; 37 | std::vector inputs; 38 | std::vector alignments; 39 | std::unique_ptr graph; 40 | bool isBuilt{false}; 41 | SequenceGroup(py::list sequences): sequences{sequences} {} 42 | void build(int min_coverage=-1, spoa::AlignmentEngine* engine=nullptr) { 43 | if(min_coverage <= 0) { 44 | min_coverage = std::max(py::size_t(0), (sequences.size() - 1) / 2); 45 | } 46 | std::unique_ptr localEngine(engine ? 
std::unique_ptr(): make_engine()); 47 | if(localEngine) engine = localEngine.get(); 48 | graph = std::make_unique(); 49 | /* 50 | auto getQual = [&] () -> std::optional> { 51 | auto it = py::cast(qualities); 52 | if(!it.is_none()) { 53 | } 54 | return std::nullopt; 55 | }; 56 | */ 57 | for(auto seq: sequences) { 58 | int32_t score{0}; 59 | py::ssize_t size; 60 | py::str str = py::cast(seq); 61 | const char *ptr = PyUnicode_AsUTF8AndSize(str.ptr(), &size); 62 | scores.push_back(score); 63 | inputs.emplace_back(ptr); 64 | const auto alignment = engine->Align(inputs.back(), *graph, &score); 65 | graph->AddAlignment(alignment, ptr); 66 | alignments.push_back(alignment); 67 | } 68 | consensus = graph->GenerateConsensus(min_coverage); 69 | isBuilt = true; 70 | } 71 | py::dict GraphToPython() { 72 | if(!isBuilt) { 73 | build(); 74 | } 75 | using namespace pybind11::literals; 76 | std::string bases; 77 | std::unordered_map edgeIdMap; 78 | std::unordered_map nodeIdMap; 79 | std::unordered_map nodeRankMap; 80 | std::unordered_map> seqIdToNodes; 81 | std::unordered_map> seqIdToEdges; 82 | std::vector edgeLabels; 83 | std::vector edgeIndptr; 84 | edgeIndptr.push_back(0); 85 | 86 | std::unordered_map edges; // Map from (from, to): edge_id 87 | for(const auto& edge: graph->edges()) { 88 | const int32_t id = edgeIdMap.size(); 89 | edgeIdMap.emplace(edge.get(), id); 90 | } 91 | const auto& rankToNode = graph->rank_to_node(); 92 | int32_t nodeId{0}; 93 | //const auto edgeToId = [&edgeIdMap](Edge* const edge) {return edgeIdMap.at(edge);}; 94 | for(const auto& node: rankToNode) { 95 | bases.push_back(node->code); 96 | nodeRankMap.emplace(node, nodeId++); 97 | } 98 | for(const auto& node: graph->nodes()) { 99 | const int32_t id = nodeIdMap.size(); 100 | nodeIdMap.emplace(node.get(), id); 101 | } 102 | const auto nodeToId = [&nodeIdMap](Node* const node) {return nodeIdMap.at(node);}; 103 | auto updateEdges = [&edges](const int32_t from, const int32_t to) { 104 | const int32_t edgeId = edges.size(); 105 | edges.emplace((uint64_t(from) << 32) | to, edgeId); 106 | }; 107 | for(const auto& edge: graph->edges()) { 108 | updateEdges(nodeToId(edge->head), nodeToId(edge->tail)); 109 | edgeIndptr.push_back(edgeIndptr.back() + edge->labels.size()); 110 | edgeLabels.insert(edgeLabels.end(), edge->labels.begin(), edge->labels.end()); 111 | } 112 | // 1. Get the nodes out in topological order. 113 | // 2. Get all edges out. 114 | // 3. Annotate all edges with input support. 115 | // 4. Annotate all nodes with input support. 116 | // 5. Generate all supported paths through the graph. 117 | // 6. Outside of this - bring in other data from the reads. 
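// For orientation, the py::dict assembled below exposes the graph as flat arrays
// (field meanings inferred from the construction code that follows):
//   bases                    - one decoded base per node, in topological (rank) order
//   ranks                    - topological rank of each node, indexed by node id
//   seq_nodes / seq_indptr   - packed (CSR-like) lists of the node ids each input sequence touches
//   edge_nodes / edge_indptr - packed lists of the input-sequence ids supporting each edge
//   matrix_coo               - one (head node id, tail node id, edge id) row per edge
//   consensus / input_sequences - the consensus string and the original Python sequence list
// A rough sketch of rebuilding an adjacency matrix on the Python side, assuming a built
// SequenceGraph `g` and that numpy/scipy are available (illustrative only, not part of this module):
//   d = g.matrix()
//   coo, n = d["matrix_coo"], len(d["bases"])
//   adj = scipy.sparse.coo_matrix((numpy.ones(len(coo)), (coo[:, 0], coo[:, 1])), shape=(n, n))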
118 | for(const auto& edge: graph->edges()) { 119 | Node* const source = edge->head; 120 | Node* const sink = edge->tail; 121 | for(const int32_t id: edge->labels) { 122 | seqIdToNodes[id].insert(nodeIdMap.at(source)); 123 | seqIdToNodes[id].insert(nodeIdMap.at(sink)); 124 | seqIdToEdges[id].insert(edgeIdMap.at(edge.get())); 125 | } 126 | } 127 | std::vector nodeRanks(nodeRankMap.size()); 128 | for(const auto& [node, rank]: nodeRankMap) { 129 | nodeRanks[nodeIdMap.at(node)] = rank; 130 | } 131 | py::array_t nodeRanksPy({std::ssize(nodeRanks)}); 132 | int32_t* data = reinterpret_cast(nodeRanksPy.request().ptr); 133 | std::copy(nodeRanks.begin(), nodeRanks.end(), data); 134 | 135 | std::vector seqAlignments; // Packed sparse matrix 136 | 137 | std::vector seqIndptr; 138 | seqIndptr.push_back(0); 139 | 140 | for(const auto& [seqId, nodes]: seqIdToNodes) { 141 | const int64_t numNodes = nodes.size(); 142 | seqIndptr.push_back(numNodes + seqIndptr.back()); 143 | std::copy(nodes.begin(), nodes.end(), std::back_inserter(seqAlignments)); 144 | } 145 | 146 | py::array_t seqAlignmentsPy({std::ssize(seqAlignments)}); 147 | std::copy(seqAlignments.begin(), seqAlignments.end(), reinterpret_cast(seqAlignmentsPy.request().ptr)); 148 | 149 | py::array_t seqIndptrPy({std::ssize(seqIndptr)}); 150 | std::copy(seqIndptr.begin(), seqIndptr.end(), reinterpret_cast(seqIndptrPy.request().ptr)); 151 | 152 | py::array_t edgeLabelsPy({std::ssize(edgeLabels)}); 153 | std::copy(edgeLabels.begin(), edgeLabels.end(), reinterpret_cast(edgeLabelsPy.request().ptr)); 154 | 155 | py::array_t edgeIndptrPy({std::ssize(edgeIndptr)}); 156 | std::copy(edgeIndptr.begin(), edgeIndptr.end(), reinterpret_cast(edgeIndptrPy.request().ptr)); 157 | 158 | py::array_t matrixCOOPy({std::ssize(edges) * 3}); 159 | int32_t* destPtr = reinterpret_cast(matrixCOOPy.request().ptr); 160 | for(const auto& edge: edges) { 161 | *destPtr++ = edge.first >> 32; 162 | *destPtr++ = (edge.first << 32) >> 32; 163 | *destPtr++ = edge.second; 164 | } 165 | matrixCOOPy.resize(std::vector{{static_cast(edges.size()), 3}}); 166 | 167 | std::transform(std::cbegin(bases), std::cend(bases), std::begin(bases), [&](const char base) -> char {return graph->decoder(base);}); 168 | 169 | return py::dict("bases"_a=bases, "ranks"_a=nodeRanksPy, "seq_nodes"_a=seqAlignmentsPy, "seq_indptr"_a=seqIndptrPy, 170 | "edge_nodes"_a=edgeLabelsPy, "edge_indptr"_a=edgeIndptrPy, "matrix_coo"_a=matrixCOOPy, "consensus"_a=consensus, "input_sequences"_a=sequences); 171 | } 172 | GraphRepr GenerateGraph() { 173 | GraphRepr ret; 174 | std::vector bases; 175 | std::unordered_map edgeIdMap; 176 | std::unordered_map nodeIdMap; 177 | for(const auto& edge: graph->edges()) { 178 | const int32_t id = edgeIdMap.size(); 179 | edgeIdMap.emplace(edge.get(), id); 180 | } 181 | const auto& rankToNode = graph->rank_to_node(); 182 | int32_t nodeId{0}; 183 | const auto edgeToId = [&edgeIdMap](Edge* const edge) {return edgeIdMap.at(edge);}; 184 | for(const auto& node: rankToNode) { 185 | bases.push_back(node->code); 186 | nodeIdMap.emplace(node, nodeId); 187 | std::vector outEdges; 188 | std::transform(std::cbegin(node->outedges), std::cend(node->outedges), std::back_inserter(outEdges), edgeToId); 189 | ret.outEdges.emplace_back(outEdges); 190 | 191 | std::vector inEdges; 192 | std::transform(std::cbegin(node->inedges), std::cend(node->inedges), std::back_inserter(inEdges), edgeToId); 193 | ret.inEdges.emplace_back(inEdges); 194 | 195 | std::vector alignedNodes; 196 | 
std::transform(std::cbegin(node->aligned_nodes), std::cend(node->aligned_nodes), std::back_inserter(alignedNodes), [&nodeIdMap](const auto node) {return nodeIdMap.at(node);}); 197 | ret.alignedNodes.emplace_back(alignedNodes); 198 | ++nodeId; 199 | } 200 | ret.bases = std::move(bases); 201 | // copy the rest 202 | ret.consensus = consensus; 203 | ret.strings = inputs; 204 | ret.alignments = alignments; 205 | return ret; 206 | } 207 | }; 208 | 209 | void init_poa(py::module &m) { 210 | py::class_(m, "SequenceGraph") 211 | .def(py::init()) 212 | .def("build", [](SequenceGroup& group, int minCov) {group.build(minCov);}, py::arg("mincov") = -1) 213 | .def("matrix", &SequenceGroup::GraphToPython) 214 | .def_property_readonly("sequence", [] (const SequenceGroup& group) -> std::string {return group.consensus;}); 215 | } 216 | -------------------------------------------------------------------------------- /bioseq/__init__.py: -------------------------------------------------------------------------------- 1 | import cbioseq 2 | from cbioseq import * 3 | import bioseq.poa_util as poa_util 4 | import bioseq.softmax as softmax 5 | import bioseq.loaders as loaders 6 | import bioseq.annotations as annotations 7 | import bioseq.blosum as blosum 8 | import bioseq.lem as lem 9 | import bioseq.poa_util as poa_util 10 | from bioseq.poa_util import FastxSeq, ExtractedPOAGraph 11 | 12 | """ 13 | bioseq provides tokenizers and utilities for generating embeddings 14 | See the list of existing tokenizers: 15 | DNATokenizer 16 | AmineTokenizer 17 | Reduced6Tokenizer 18 | Reduced8Tokenizer 19 | Reduced10Tokenizer 20 | Reduced14Tokenizer 21 | DayhoffTokenizer 22 | LIATokenizer 23 | LIBTokenizer 24 | These are pre-made, with no bos, eos, or padding characters. 25 | Use the keys from this set: 26 | ("SEB6", "SEB8", "SEB10", "SEV10", "MURPHY", "LIA10", "LIB10", "SEB6", "DAYHOFF", "DNA4", "DNA", "DNA5", "KETO", "PURPYR", "BYTES") 27 | to index specific tokenizers into the premade tokenizer dictionaries. 28 | 29 | There are 'bos_tokenizers', where BOS is an additional character, but there is no padding or EOS char. 30 | There are 'eos_tokenizers', where EOS is an additional character, but there is no padding or BOS char. 31 | There are 'beos_tokenizers', where EOS and BOS are an additional character, but there is no padding. 32 | There are 'pbeos_tokenizers', where EOS, BOS, and padchar are all characters. This means there are 3 additional alphabet characters. 33 | """ 34 | 35 | 36 | def onehot_encode(tokenizer, seqbatch, padlen=-1, destchar='B', batch_first=False, to_pytorch=False, device=None): 37 | """ 38 | Args: 39 | tokenizer: 40 | cbioseq.Tokenizer 41 | This is the type doing the encoding 42 | seqbatch: Union[Iterable[Union[Str,Bytes,numpy.ndarray]],Str,Bytes,numpy.ndarray] 43 | Can be a set of sequences or a single sequence 44 | if a set of sequences, padlen required, and all sequences must be short enough to fit 45 | in that length, including EOS, BOS, and padding characters, if relevant 46 | padlen: Size to which to pad the sequence. Optional 47 | destchar: One of BbIiUuLlQq - specifies the data type of the encoded sequence 48 | to_pytorch: 49 | False, by default. Set to true to convert from numpy to pytorch. 50 | device: 51 | None by default. Set to a pytorch device or a string representing it to 52 | cause this function to copy to device after encoding. 
53 | """ 54 | if isinstance(seqbatch, str) or isinstance(seqbatch, bytes): 55 | res = tokenizer.onehot_encode(seqbatch, padlen, destchar) 56 | else: 57 | res = tokenizer.batch_onehot_encode(seqbatch, padlen, destchar) 58 | if batch_first: 59 | from einops import rearrange 60 | res = rearrange(res, 'seq batch base -> batch seq base') 61 | if to_pytorch: 62 | from torch import from_numpy 63 | res = from_numpy(res) 64 | if device is not None and res.device: 65 | res = res.to(device) 66 | return res 67 | 68 | 69 | def f_encode(seqbatch, key="DNA", bos=False, eos=False, padchar=False, padlen=-1, destchar='B', batch_first=False, to_pytorch=False, device=None): 70 | """ 71 | Functional encoding of sequence batch. 72 | Creates a tokenizer and then uses it 73 | Args: 74 | seqbatch: Union[Iterable[Union[Str,Bytes,numpy.ndarray]],Str,Bytes,numpy.ndarray] 75 | Kwargs: 76 | key: 77 | Alphabet with which to encode. Choose from 78 | "DNA" is the default, with 4 character encodings. 79 | 80 | 81 | Other options: 82 | DNA: 83 | DNA - 4 characters - ACGT 84 | DNA4 - alias for DNA 85 | DNA5 - DNA plus character for N, separate from unexpected wildcards 86 | Reduced DNA: 87 | KETO - split into keto/amine -- AC/GT 88 | PURPYR - split into purines/pyrimidines (AG/CT) 89 | 90 | For amino acid sequences: 91 | Full alphabet: 92 | AMINO20,AMINO,PROTEIN 93 | -- 20 amino acids 94 | Reduced alphabets: 95 | SEB6, SEB8, SEB10, SEB14 -- 96 | All reduced protein alphabets, used for long-range homology detection 97 | LIA10, LIB10 98 | DAYHOFF 99 | KETO 100 | PURPYR 101 | 102 | ,SEB8,SEB10,SEB14,SEV10,MURPHY,LIA10,LIB10,SEB6,DAYHOFF,KETO,PURPYR,DNA4,DNA,DNA5 103 | 104 | bos: To include bos as its own symbol [False] 105 | eos: To include eos as its own symbol [False] 106 | padchar: To include padchar as its own symbol [False] 107 | padlen: Size to which to pad the sequence. Optional 108 | destchar: One of BbIiUuLlQq - specifies the data type of the encoded sequence 109 | device: 110 | None by default. Set to a pytorch device or a string representing it to 111 | cause this function to copy to device after encoding. 
112 | """ 113 | from cbioseq import Tokenizer 114 | tokenizer = Tokenizer(key, bos=bos, eos=eos, padchar=padchar) 115 | return onehot_encode(tokenizer, seqbatch, padlen=padlen, destchar=destchar, 116 | batch_first=batch_first, to_pytorch=to_pytorch, device=device) 117 | 118 | 119 | keys = ("SEB6", "SEB8", "SEB10", "SEV10", "MURPHY", "LIA10", "LIB10", "SEB6", "DAYHOFF", "DNA4", "DNA", "DNA5", "KETO", "PURPYR", "BYTES", "AMINO20", "PROTEIN") 120 | bkeys = keys + tuple(map(str.lower, keys)) 121 | 122 | 123 | DNATokenizer = cbioseq.Tokenizer("DNA") 124 | AmineTokenizer = cbioseq.Tokenizer("AMINO20") 125 | Reduced6Tokenizer = cbioseq.Tokenizer("SEB6") 126 | Reduced8Tokenizer = cbioseq.Tokenizer("SEB8") 127 | Reduced10Tokenizer = cbioseq.Tokenizer("SEB10") 128 | Reduced14Tokenizer = cbioseq.Tokenizer("SEB14") 129 | DayhoffTokenizer = cbioseq.Tokenizer("DAYHOFF") 130 | LIATokenizer = cbioseq.Tokenizer("LIA10") 131 | LIBTokenizer = cbioseq.Tokenizer("LIB10") 132 | default_tokenizers = {"DNA": DNATokenizer, 133 | "AMINO20": AmineTokenizer, 134 | "AMINE": AmineTokenizer, 135 | "PROTEIN": AmineTokenizer, 136 | "SEB6": Reduced6Tokenizer, 137 | "SEB8": Reduced8Tokenizer, 138 | "SEB10": Reduced10Tokenizer, 139 | "SEB14": Reduced14Tokenizer, 140 | "LIA10": LIATokenizer, 141 | "LIA": LIATokenizer, 142 | "LIB10": LIBTokenizer, 143 | "LIB": LIBTokenizer} 144 | pbeos_tokenizers = {k: cbioseq.Tokenizer(k, bos=True, eos=True, padchar=True) for k in bkeys} 145 | beos_tokenizers = {k: cbioseq.Tokenizer(k, bos=True, eos=True, padchar=False) for k in bkeys} 146 | pbos_tokenizers = {k: cbioseq.Tokenizer(k, bos=True, eos=False, padchar=True) for k in bkeys} 147 | bos_tokenizers = {k: cbioseq.Tokenizer(k, bos=True, eos=False, padchar=False) for k in bkeys} 148 | peos_tokenizers = {k: cbioseq.Tokenizer(k, bos=False, eos=True, padchar=True) for k in bkeys} 149 | eos_tokenizers = {k: cbioseq.Tokenizer(k, bos=False, eos=True, padchar=False) for k in bkeys} 150 | pos_tokenizers = {k: cbioseq.Tokenizer(k, bos=False, eos=False, padchar=True) for k in bkeys} 151 | total_tokenizer_dict = {} 152 | for bos in [0, 1]: 153 | for eos in [0, 1]: 154 | for padchar in [0, 1]: 155 | for k in bkeys: 156 | total_tokenizer_dict[(bos, eos, padchar, k)] = cbioseq.Tokenizer(k.upper(), bos=bos, eos=eos, padchar=padchar) 157 | 158 | 159 | def get_tokenizer_dict(bos, eos, padchar): 160 | if bos: 161 | if eos: 162 | return pbeos_tokenizers if padchar else beos_tokenizers 163 | else: 164 | return pbos_tokenizers if padchar else bos_tokenizers 165 | elif eos: 166 | return peos_tokenizers if padchar else eos_tokenizers 167 | else: 168 | return pos_tokenizers if padchar else default_tokenizers 169 | 170 | 171 | def make_embedding(tok, embdim, maxnorm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None): 172 | """ 173 | Args: 174 | tok: bioseq.Tokenizer 175 | embdim: Int, dimension for embeddings 176 | KWArgs: 177 | maxnorm = None: maximum norm for embeddings. If not None, embeddings will be scaled by the p-norm corresponding to norm_type 178 | norm_type = 2.: Sets p for the Lp norm. Must be See torch.nn.Embedding. 179 | scale_grad_by_freq = False: Whether to scale gradient by the count frequencies. See torch.nn.Embedding. 180 | sparse = False: Whether to use sparse embeddings. False by default. 181 | _weight = None: If providing tensors from pre-trained, set them here. Must match the tokenizer's number of tokens and the embedding dimension. 
182 | """ 183 | assert norm_type >= 1., f"{norm_type} is not >= 1., so it is not a norm." 184 | import torch.nn as nn 185 | return nn.Embedding(tok.alphabet_size(), 186 | embdim, 187 | padding_idx=tok.pad() if tok.is_padded() else None, 188 | scale_grad_by_freq=scale_grad_by_freq, sparse=sparse, _weight=_weight) 189 | 190 | 191 | def torchify(arr): 192 | '''Simply wrapper for torch.from_numpy, converts numpy array to pytorch. 193 | ''' 194 | from torch import from_numpy 195 | return from_numpy(arr) 196 | 197 | 198 | class PyViewFF: 199 | ''' 200 | PyViewFF provides a pure-python view into the C++ FlatFile database. 201 | ''' 202 | def __init__(self, path): 203 | fp = np.memmap(path, mode='r', dtype=np.uint8) 204 | self.nseqs = int(fp[:8].view(np.uint64)[0]) 205 | self.offsets = fp[8:8 * (2 + self.nseqs)].view(np.uint64) 206 | self.seqs = fp[8 * (2 + self.nseqs):] 207 | self.fp = fp 208 | def access(self, idx): 209 | res = self.seqs[self.offsets[idx]:self.offsets[idx + 1]] 210 | return bytes(res) 211 | def __getitem__(self, idx): 212 | if isinstance(idx, int): 213 | return self.access(idx) 214 | elif isinstance(idx, slice): 215 | return [self.access(x) for x in range(idx.start, idx.stop, idx.step)] 216 | else: 217 | raise InvalidArgument("PyViewFF can only support slices and integers.") 218 | def __len__(self): 219 | return self.nseqs 220 | 221 | 222 | __all__ = ["onehot_encode", "cbioseq", "f_encode", "Tokenizer", "tax", 223 | "make_embedding", 224 | "bos_tokenizers", "eos_tokenizers", "beos_tokenizers", "pbeos_tokenizers", "peos_tokenizers", "pbos_tokenizers", "pos_tokenizers", "default_tokenizers", "get_tokenizer_dict", 225 | "DNATokenizer", "AmineTokenizer", "Reduced6Tokenizer", "Reduced8Tokenizer", "Reduced10Tokenizer", "Reduced14Tokenizer", "DayhoffTokenizer", "LIATokenizer", "LIBTokenizer", 226 | "decoders", 'softmax', 'hattn', 'loaders', 'torchify'] 227 | -------------------------------------------------------------------------------- /src/alphabet.h: -------------------------------------------------------------------------------- 1 | #ifndef BIOSEQ_ALPHBET_H__ 2 | #define BIOSEQ_ALPHBET_H__ 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | namespace alph { 15 | using std::size_t; 16 | 17 | template 18 | struct TAlphabet { 19 | static_assert(NCHAR > 1, "Nchar must be positive"); 20 | static_assert(std::is_integral::value, "VT must be integral"); 21 | const char *name; 22 | const char *setstr; 23 | private: 24 | size_t nc; 25 | public: 26 | const bool padding; 27 | size_t nchars() const {return nc + 1;} // One for padding 28 | size_t num_commas() const noexcept {return nc;} 29 | using LUType = std::array; 30 | LUType lut; 31 | bool has_padding() const noexcept {return padding;} 32 | static constexpr LUType make_lut(const char *s, const size_t nc, bool padding=false, const char *aliases=0) { 33 | LUType arr{-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}; 34 | int id = padding; 35 | size_t ci = 0; 36 | for(size_t i = 0; i < nc; ++i, ++id, ++ci) { 37 | while(s[ci] && s[ci] != ',') { 38 | const auto v = s[ci++]; 39 | arr[v | 32] = arr[v & static_cast(0xdf)] = id; // lower-case and upper-case 40 | } 41 | } 42 | while(s[ci]) { 43 | const auto v = s[ci++]; 44 | arr[v | 32] = arr[v & static_cast(0xdf)] = id; // lower-case and upper-case 45 | } 46 | if(aliases) { 47 | const char *p = aliases; 48 | while(*p && *p != ':') ++p; 49 | if(*p) { 50 | const size_t offset = p - aliases; 51 | for(size_t i = 0; i < offset; ++i) { 52 | const auto destchar = arr[p[i + 1]]; 53 | if(arr[aliases[i] & 0xdf] == VT(-1)) 54 | arr[aliases[i] & 0xdf] = arr[destchar]; 55 | if(arr[aliases[i] | 32] == VT(-1)) 56 | arr[aliases[i] | 32] = arr[destchar]; 57 | } 58 | } 59 | } 60 | return arr; 61 | } 62 | std::vector> to_sparse() const { 63 | std::vector> ret; 64 | for(size_t i = 0; i < NCHAR; ++i) { 65 | if(lut[i] != VT(-1)) 66 | ret.push_back({VT(i), lut[i]}); 67 | } 68 | return ret; 69 | } 70 | void display() const { 71 | for(size_t i = 0; i < NCHAR; ++i) { 72 | if(lut[i] != VT(-1)) { 73 | std::fprintf(stderr, "Mapping %d/%c to %d/%c\n", int(i), char(i), int(lut[i]), lut[i]); 74 | } 75 | } 76 | } 77 | using SignedVT = typename std::make_signed::type; 78 | constexpr VT translate(VT x) const {return lut[static_cast(x)];} 79 | constexpr VT *data() {return lut.data();} 80 | constexpr const VT *data() const {return lut.data();} 81 | static constexpr size_t size() {return NCHAR;} 82 | static constexpr size_t ncommas(const char *s) { 83 | size_t ret = 0, i = 0; 84 | while(s[i]) ret += (s[i] == ','), ++i; 85 | return ret; 86 | } 87 | constexpr TAlphabet(const TAlphabet &) = default; 88 | constexpr TAlphabet(TAlphabet &&) = default; 89 | constexpr TAlphabet(const char *name, const char *s, bool padding=false, const char *aliases=0): name(name), setstr(s), nc(ncommas(s)), padding(padding), lut(make_lut(s, nc, padding, aliases)) { 90 | } 91 | static constexpr LUType emptylut(bool padding) { 92 | LUType tlut{-1}; 93 | for(size_t i = 0; i < NCHAR; ++i) tlut[i] = i + padding; 94 | return tlut; 95 | } 96 | constexpr TAlphabet(bool padding=false): name("Bytes"), setstr(""), nc(NCHAR - 1 + padding), padding(padding), lut(emptylut(padding)) { 97 | } 98 | }; 99 | struct Alphabet: public TAlphabet { 100 | template constexpr Alphabet(Args &&...args): TAlphabet(std::forward(args)...) 
{} 101 | }; 102 | 103 | // Protein Alphabets 104 | // All protein alphabets handle pyrrolysine and selenocysteine if unhandled 105 | // Handle Pyrrolysine (O) by mapping it to Lysine (K) if unhandled 106 | // Handle SelenoCysteine (U) by mapping it to Cysteine if unhandled 107 | // (OU:KC) means O maps to K and U maps to C 108 | static constexpr const Alphabet BYTES; 109 | static constexpr const Alphabet AMINO20("Standard20", "A,C,D,E,F,G,H,I,K,L,M,N,P,Q,R,S,T,V,W,Y", false, "OU:KC"); 110 | 111 | static constexpr const Alphabet SEB14("SE-B(14)", "A,C,D,EQ,FY,G,H,IV,KR,LM,N,P,ST,W", false, "OU:KC"); 112 | 113 | static constexpr const Alphabet SEB10("SE-B(10)", "AST,C,DN,EQ,FY,G,HW,ILMV,KR,P", false, "OU:KC"); 114 | static constexpr const Alphabet SEV10("SE-V(10)", "AST,C,DEN,FY,G,H,ILMV,KQR,P,W", false, "OU:KC"); 115 | static constexpr const Alphabet SOLISD("Solis-D", "AM,C,DNS,EKQR,F,GP,HT,IV,LY,W", false, "OU:KC"); 116 | static constexpr const Alphabet SOLISG("Solis-G", "AEFIKLMQRVW,C,D,G,H,N,P,S,T,Y", false, "OU:KC"); 117 | static constexpr const Alphabet MURPHY("Murphy", "A,C,DENQ,FWY,G,H,ILMV,KR,P,ST", false, "OU:KC"); 118 | static constexpr const Alphabet LIA10("Li-A(10)", "AC,DE,FWY,G,HN,IV,KQR,LM,P,ST", false, "OU:KC"); 119 | static constexpr const Alphabet LIB10("Li-B(10)", "AST,C,DEQ,FWY,G,HN,IV,KR,LM,P", false, "OU:KC"); 120 | 121 | static constexpr const Alphabet SEB8("SE-B(8)","AST,C,DHN,EKQR,FWY,G,ILMV,P", false, "OU:KC"); 122 | static constexpr const Alphabet SEB6("SE-B(6)","AST,CP,DHNEKQR,FWY,G,ILMV", false, "OU:KC"); 123 | 124 | static constexpr const Alphabet DAYHOFF("Dayhoff","AGPST,C,DENQ,FWY,HKR,ILMV", false, "OU:KC"); 125 | 126 | namespace amine_traits { 127 | static constexpr const std::string_view alcoholic_bases("oST"); 128 | static constexpr const std::string_view hydrophobic_bases("hACFGHIKLMRTVWY"); 129 | static constexpr const std::string_view polar_bases("pCDEHKNQRST"); 130 | static constexpr const std::string_view charged_bases("cDEHKR"); 131 | static constexpr const std::string_view positive_bases("+HKR"); 132 | static constexpr const std::string_view negative_bases("-DE"); 133 | static constexpr const std::string_view small_bases("sAGSCDNPTV"); 134 | static constexpr const std::string_view tiny_bases("uAGS"); 135 | static constexpr const std::string_view aromatic_bases("aFHWY"); 136 | static constexpr const std::string_view turnlike_bases("tACDEGHKNQRST"); 137 | 138 | static constexpr bool is_alcoholic(char c) { 139 | switch(c) case 'o': case 'S': case 'T': return true; 140 | return false; 141 | } 142 | static constexpr bool is_hydrophobic(char c) { 143 | switch(c) { 144 | case 'h': case 'A': case 'C': case 'F': case 'G': case 'H': case 'I': case 'K': case 'L': case 'M': case 'R': case 'T': case 'V': case 'W': case 'Y': 145 | return true; 146 | } 147 | return false; 148 | } 149 | static constexpr bool is_polar(char c) { 150 | switch(c) case 'p': case 'C': case 'D': case 'E': case 'H': case 'K': case 'N': case 'Q': case 'R': case 'S': case 'T': return true; 151 | return false; 152 | } 153 | static constexpr bool is_negative(char c) { 154 | switch(c) case '-': case 'D': case 'E': return true; 155 | return false; 156 | } 157 | static constexpr bool is_positive(char c) { 158 | switch(c) case '+': case 'H': case 'K': case 'R': return true; 159 | return false; 160 | } 161 | static constexpr bool is_charged(char c) { 162 | switch(c) case 'c': case 'D': case 'E': case 'H': case 'K': case 'R': return true; 163 | return false; 164 | } 165 | static constexpr bool 
is_small(char c) { 166 | switch(c) case 's': case 'A': case 'G': case 'S': case 'C': case 'D': case 'N': case 'P': case 'T': case 'V': return true; 167 | return false; 168 | } 169 | static constexpr bool is_tiny(char c) { 170 | switch(c) case 'u': case 'A': case 'G': case 'S': return true; 171 | return false; 172 | } 173 | static constexpr bool is_aromatic(char c) { 174 | switch(c) case 'a': case 'F': case 'H': case 'W': case 'Y': return true; 175 | return false; 176 | } 177 | static constexpr bool is_turnlike(char c) { 178 | switch(c) case 't': case 'A': case 'C': case 'D': case 'E': case 'G': case 'H': case 'K': case 'N': case 'Q': case 'R': case 'S': case 'T': 179 | return true; 180 | return false; 181 | } 182 | 183 | } // namespace amine_traits 184 | 185 | 186 | // DNA alphabets 187 | // We also map U to T in order to support RNA sequences 188 | 189 | static constexpr const Alphabet DNA4("DNA4", "A,C,G,T", false, "U:T"); 190 | static constexpr const Alphabet DNA5("DNA5", "A,C,G,T,NMRWSYKVHDB", false, "U:T"); 191 | 192 | static constexpr const Alphabet DNA2KETAMINE("DNA2", "ACM,KGT", false, "U:T"); // Amino/Ketones 193 | static constexpr const Alphabet DNA2PYRPUR("DNA2", "AGR,YCT", false, "U:T"); // Purines/Pyrimidines 194 | static constexpr const Alphabet DNA2METHYL("DNAMETH", "C,AGT", false, "U:T"); // Purines/Pyrimidines 195 | 196 | // Source: Reference 197 | // Edgar, RC (2004) Local homology recognition and distance measures in linear time using compressed amino acid alphabets, NAR 32(1), 380-385. doi: 10.1093/nar/gkh180 198 | static const std::map CAMAP { 199 | {"BYTES", &BYTES}, 200 | {"AMINO20", &AMINO20}, 201 | {"AMINO", &AMINO20}, 202 | {"PROTEIN", &AMINO20}, 203 | {"SEB8", &SEB8}, 204 | {"SEB10", &SEB10}, 205 | {"SEB14", &SEB14}, 206 | {"SEV10", &SEV10}, 207 | {"MURPHY", &MURPHY}, 208 | {"LIA10", &LIA10}, 209 | {"LIB10", &LIB10}, 210 | {"SEB6", &SEB6}, 211 | {"DAYHOFF", &DAYHOFF}, 212 | 213 | {"DNAMETH", &DNA2METHYL}, 214 | {"C", &DNA2METHYL}, 215 | {"KETO", &DNA2KETAMINE}, 216 | {"PURPYR", &DNA2PYRPUR}, 217 | 218 | {"DNA4", &DNA4}, 219 | {"DNA", &DNA4}, 220 | 221 | {"DNA5", &DNA5} 222 | }; 223 | 224 | } // alph 225 | 226 | using alph::Alphabet; 227 | using alph::BYTES; 228 | using alph::AMINO20; 229 | using alph::SEB14; 230 | using alph::SEB10; 231 | using alph::SEB6; 232 | using alph::SEB8; 233 | using alph::SEV10; 234 | using alph::SOLISD; 235 | using alph::SOLISG; 236 | using alph::MURPHY; 237 | using alph::LIA10; 238 | using alph::LIB10; 239 | using alph::DAYHOFF; 240 | using alph::DNA5; 241 | using alph::DNA4; 242 | using alph::DNA2KETAMINE; 243 | using alph::DNA2PYRPUR; 244 | 245 | 246 | #endif /* BIOSEQ_ALPHBET_H__ */ 247 | -------------------------------------------------------------------------------- /bioseq/hattn.py: -------------------------------------------------------------------------------- 1 | from math import log2, ceil 2 | import sys 3 | import torch 4 | from torch import nn, einsum, diagonal 5 | import torch.nn.functional as F 6 | 7 | from h_transformer_1d.reversible import ReversibleSequence, SequentialSequence 8 | import bioseq 9 | from rotary_embedding_torch import apply_rotary_emb, RotaryEmbedding 10 | import einops 11 | 12 | # helpers 13 | 14 | 15 | def exists(val): 16 | return val is not None 17 | 18 | 19 | def masked_aggregate(tensor, mask=None, dim=-1, average=True): 20 | if not exists(mask): 21 | fn = torch.sum if not average else torch.mean 22 | return fn(tensor, dim=dim) 23 | 24 | diff_len = len(tensor.shape) - len(mask.shape) 25 | mask = 
mask[(..., *((None,) * diff_len))] 26 | tensor = tensor.masked_fill(~mask, 0.) 27 | 28 | total_el = mask.sum(dim=dim) 29 | agg = tensor.sum(dim=dim) 30 | 31 | if average: 32 | agg = agg / total_el.clamp(min=1.) 33 | 34 | agg.masked_fill_(total_el == 0, 0.) 35 | return agg 36 | 37 | 38 | def shift(t, amount, mask=None): 39 | if amount == 0: 40 | return t 41 | 42 | if exists(mask): 43 | t = t.masked_fill(~mask[..., None], 0.) 44 | 45 | return F.pad(t, (0, 0, amount, -amount), value=0.) 46 | 47 | # helper classes 48 | 49 | 50 | class PreNorm(nn.Module): 51 | def __init__(self, dim, fn): 52 | super().__init__() 53 | self.fn = fn 54 | self.norm = nn.LayerNorm(dim) 55 | 56 | def forward(self, x, **kwargs): 57 | x = self.norm(x) 58 | return self.fn(x, **kwargs) 59 | 60 | 61 | class FeedForward(nn.Module): 62 | def __init__( 63 | self, 64 | dim, 65 | *, 66 | mult=4 67 | ): 68 | super().__init__() 69 | self.net = nn.Sequential( 70 | nn.Linear(dim, dim * mult), 71 | nn.GELU(), 72 | nn.Linear(dim * mult, dim) 73 | ) 74 | 75 | def forward(self, x): 76 | return self.net(x) 77 | 78 | # token shifting 79 | 80 | 81 | class PreShiftTokens(nn.Module): 82 | def __init__(self, shifts, fn): 83 | super().__init__() 84 | self.fn = fn 85 | self.shifts = tuple(shifts) 86 | 87 | def forward(self, x, **kwargs): 88 | mask = kwargs.get('mask', None) 89 | shifts = self.shifts 90 | segments = len(shifts) 91 | feats_per_shift = x.shape[-1] // segments 92 | splitted = x.split(feats_per_shift, dim=-1) 93 | segments_to_shift, rest = splitted[:segments], splitted[segments:] 94 | segments_to_shift = list(map(lambda args: shift( 95 | *args, mask=mask), zip(segments_to_shift, shifts))) 96 | x = torch.cat((*segments_to_shift, *rest), dim=-1) 97 | return self.fn(x, **kwargs) 98 | 99 | # hierarchical attention helper functions 100 | 101 | 102 | def flip_every_two(t): 103 | t = einops.rearrange(t, 'b (n r) ... -> b n r ...', r=2) 104 | # so we pay attention to the off-diagonal blocks in the attention matrix 105 | t = torch.flip(t, dims=(2,)) 106 | t = einops.rearrange(t, 'b n r ... -> b (n r) ...') 107 | return t 108 | 109 | # attention 110 | 111 | 112 | class HAttention1D(nn.Module): 113 | def __init__( 114 | self, 115 | dim, 116 | *, 117 | heads=8, 118 | dim_head=64, 119 | block_size=16, 120 | pos_emb=None, 121 | eps=1e-8, 122 | **kwargs 123 | ): 124 | super().__init__() 125 | self.eps = eps 126 | self.heads = heads 127 | self.scale = dim_head ** -0.5 128 | self.block_size = block_size 129 | inner_dim = heads * dim_head 130 | 131 | self.pos_emb = pos_emb 132 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) 133 | self.to_out = nn.Linear(inner_dim, dim) 134 | 135 | def forward(self, x, mask=None): 136 | b, n, h, device, bsz, eps = * \ 137 | x.shape[:2], self.heads, x.device, self.block_size, self.eps 138 | 139 | # pad sequence length to power of 2 140 | 141 | pad_to_len = 2 ** ceil(log2(n)) 142 | padding = pad_to_len - n 143 | 144 | if padding != 0: 145 | x = F.pad(x, (0, 0, 0, padding), value=0.) 
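# Illustrative sizes for this padding step (example numbers, not from the original source):
# with n = 1000 and the default block_size = 16, pad_to_len = 2 ** ceil(log2(1000)) = 1024,
# so padding = 24 zero rows are appended here, and num_levels computed below is
# int(log2(1024 // 16)) - 2 = 4, i.e. queries/keys/values are coarsened (halved in length)
# four times before the block-diagonal attention.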
146 | if exists(mask): 147 | mask = F.pad(mask, (0, padding), value=False) 148 | 149 | # derive queries, keys, values 150 | 151 | q, k, v = self.to_qkv(x).chunk(3, dim=-1) 152 | 153 | # split out heads, and also divide sequence into blocks 154 | 155 | q, k, v = map(lambda t: einops.rearrange( 156 | t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 157 | 158 | if exists(mask): 159 | mask = einops.repeat(mask, 'b n -> (b h) n', h=h) 160 | 161 | # scale 162 | 163 | q = q * self.scale 164 | 165 | # rotary pos emb 166 | 167 | if exists(self.pos_emb): 168 | freqs = self.pos_emb(torch.arange( 169 | pad_to_len, device=device), cache_key=pad_to_len) 170 | freqs = einops.rearrange(freqs, 'n d -> () n d') 171 | q, k, v = map(lambda t: apply_rotary_emb(freqs, t), (q, k, v)) 172 | 173 | # calculate number of levels until 2 x 2 174 | 175 | num_levels = int(log2(pad_to_len // bsz)) - 2 176 | assert num_levels >= 0, 'number of levels must be at least greater than 0' 177 | 178 | # coarsening 179 | 180 | qkvs = [(q, k, v, mask)] 181 | 182 | for level in range(num_levels): 183 | q, k, v = map(lambda t: einops.rearrange( 184 | t, 'b (n r) d -> b n r d', r=2), (q, k, v)) 185 | 186 | if exists(mask): 187 | mask = einops.repeat(mask, 'b (n r) -> b n r', r=2) 188 | 189 | # masked mean for queries and keys, but not values 190 | 191 | q = masked_aggregate(q, mask, dim=2) 192 | k = masked_aggregate(k, mask, dim=2) 193 | v = masked_aggregate(v, mask, dim=2, average=False) 194 | 195 | if exists(mask): 196 | mask = torch.any(mask, dim=2) 197 | 198 | coarsened_qkvs = (q, k, v, mask) 199 | qkvs.append(coarsened_qkvs) 200 | 201 | # duplicate the finest resolution an extra time, for the base diagonal 202 | qkvs = [qkvs[0], *qkvs] 203 | 204 | # half-attention function 205 | 206 | def calculate_Y_and_A(q, k, v, mask=None): 207 | S = einsum('... i d, ... j d -> ... i j', q, k) 208 | 209 | if exists(mask): 210 | mask_value = -torch.finfo(S.dtype).max 211 | S = S.masked_fill(~mask, mask_value) 212 | 213 | S = S - torch.max(S, dim=-1, keepdim=True).values 214 | A = S.exp() 215 | 216 | y = einsum('... i j, ... j d -> ... i d', A, v) 217 | 218 | A = A.sum(dim=-1) 219 | 220 | y = einops.rearrange(y, 'b ... n d -> b (... n) d') 221 | A = einops.rearrange(A, 'b ... i -> b (... i)') 222 | return y, A 223 | 224 | def to_blocks(t): return einops.rearrange( 225 | t, 'b (n z) ... -> b n z ...', z=bsz) 226 | 227 | # calculate Ys, as in the paper 228 | 229 | Ys = [] 230 | 231 | for ind, (q, k, v, mask) in enumerate(reversed(qkvs)): 232 | is_last = ind == (len(qkvs) - 1) 233 | 234 | q, k, v = map(to_blocks, (q, k, v)) 235 | 236 | # generate the mask for S 237 | 238 | S_mask = None 239 | if exists(mask): 240 | mask = to_blocks(mask) 241 | q_mask = mask 242 | k_mask = flip_every_two(mask) if not is_last else mask 243 | S_mask = einops.rearrange( 244 | q_mask, '... n -> ... n ()') * einops.rearrange(k_mask, '... n -> ... 
() n') 245 | 246 | # flip keys and values to capture the off-diagonals 247 | 248 | if not is_last: 249 | k, v = map(flip_every_two, (k, v)) 250 | 251 | Y_level = calculate_Y_and_A(q, k, v, mask=S_mask) 252 | Ys.append(Y_level) 253 | 254 | # interpolate 255 | 256 | Y = 0 257 | A = 0 258 | 259 | for ind, (Y_level, A_level) in enumerate(Ys): 260 | is_last = ind == (len(Ys) - 1) 261 | 262 | if not is_last and torch.is_tensor(Y): 263 | Y = einops.repeat(Y, 'b n d -> b (n r) d', r=2) 264 | 265 | if not is_last and torch.is_tensor(A): 266 | A = einops.repeat(A, 'b n -> b (n r)', r=2) 267 | 268 | Y = Y_level + Y 269 | A = A_level + A 270 | 271 | out = Y / einops.rearrange(A + eps, 'b n -> b n ()') 272 | 273 | # merge heads 274 | 275 | out = einops.rearrange(out, '(b h) n d -> b n (h d)', h=h) 276 | 277 | # combine out 278 | 279 | return self.to_out(out[:, :n]) 280 | 281 | # causal attention 282 | 283 | 284 | class CausalHAttention1D(nn.Module): 285 | def __init__( 286 | self, 287 | dim, 288 | *, 289 | max_seq_len, 290 | heads=8, 291 | dim_head=64, 292 | block_size=16, 293 | eps=1e-8, 294 | pos_emb=None 295 | ): 296 | super().__init__() 297 | self.eps = eps 298 | self.heads = heads 299 | self.scale = dim_head ** -0.5 300 | self.block_size = block_size 301 | inner_dim = heads * dim_head 302 | 303 | self.pos_emb = pos_emb 304 | 305 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) 306 | self.to_out = nn.Linear(inner_dim, dim) 307 | 308 | # derive mask 309 | 310 | num_levels = int(log2(max_seq_len // block_size)) - 1 311 | root_seq = torch.arange(max_seq_len) 312 | seqs = [root_seq] 313 | seq = root_seq 314 | 315 | for ind in range(num_levels): 316 | seq = einops.rearrange(seq, '(n r) -> n r', r=2) 317 | seq = seq.max(dim=-1).values 318 | expanded_mask_seq = einops.repeat( 319 | seq, 'n -> (n r)', r=(2 ** (ind + 1))) 320 | seqs.append(expanded_mask_seq) 321 | 322 | seq_keys = torch.stack(seqs, dim=0) 323 | mask = seq_keys > einops.rearrange(root_seq, 'n -> () n') 324 | self.register_buffer('mask', mask) 325 | 326 | def forward(self, x, **kwargs): 327 | b, n, h, device, bsz, eps = * \ 328 | x.shape[:2], self.heads, x.device, self.block_size, self.eps 329 | 330 | # pad sequence length to power of 2 331 | 332 | pad_to_len = 2 ** ceil(log2(n)) 333 | padding = pad_to_len - n 334 | 335 | if padding != 0: 336 | x = F.pad(x, (0, 0, 0, padding), value=0.) 
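# The causal path pads the same way, but below uses num_levels = int(log2(pad_to_len // bsz)) - 1,
# one more coarsening step than the bidirectional class above; with the illustrative sizes from
# earlier (pad_to_len = 1024, block_size = 16) that is 5 levels, and the precomputed self.mask is
# later sliced to num_levels + 1 = 6 rows, one per attention resolution.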
337 | 338 | # derive queries, keys, values 339 | 340 | q, k, v = self.to_qkv(x).chunk(3, dim=-1) 341 | 342 | # split out heads, and also divide sequence into blocks 343 | 344 | q, k, v = map(lambda t: einops.rearrange( 345 | t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 346 | 347 | # scale 348 | 349 | q = q * self.scale 350 | 351 | # rotary embedding 352 | 353 | if exists(self.pos_emb): 354 | freqs = self.pos_emb(torch.arange( 355 | pad_to_len, device=device), cache_key=pad_to_len) 356 | freqs = einops.rearrange(freqs, 'n d -> () n d') 357 | q, k, v = map(lambda t: apply_rotary_emb(freqs, t), (q, k, v)) 358 | 359 | # calculate number of levels until 2 x 2 360 | 361 | num_levels = int(log2(pad_to_len // bsz)) - 1 362 | 363 | # coarsening 364 | 365 | qkvs = [(q, k, v)] 366 | 367 | for level in range(num_levels): 368 | q, k, v = map(lambda t: einops.rearrange( 369 | t, 'b (n r) d -> b n r d', r=2), (q, k, v)) 370 | 371 | # masked mean for queries and keys, but not values 372 | 373 | q = q.mean(dim=2) 374 | k = k.mean(dim=2) 375 | v = v.sum(dim=2) 376 | 377 | coarsened_qkvs = (q, k, v) 378 | qkvs.append(coarsened_qkvs) 379 | 380 | # half-attention function 381 | 382 | def calculate_Y_and_A(q, k, v, mask_right_off_diagonals=False, causal_mask_diagonal=False): 383 | if mask_right_off_diagonals: 384 | q, k, v = map(lambda t: einops.rearrange( 385 | t, 'b (n r) ... -> b n r ...', r=2), (q, k, v)) 386 | q, k, v = map(lambda t: t[:, :, 1], (q, k, v)) 387 | 388 | S = einsum('... i d, ... j d -> ... i j', q, k) 389 | 390 | if causal_mask_diagonal: 391 | causal_mask = torch.ones( 392 | *S.shape[-2:], device=S.device).triu(1).bool() 393 | mask_value = -torch.finfo(S.dtype).max 394 | causal_mask = einops.rearrange(causal_mask, 'i j -> () () i j') 395 | S = S.masked_fill(causal_mask, mask_value) 396 | 397 | S = S - torch.amax(S, dim=-1, keepdim=True) 398 | A = S.exp() 399 | 400 | y = einsum('... i j, ... j d -> ... i d', A, v) 401 | 402 | A = A.sum(dim=-1) 403 | 404 | if mask_right_off_diagonals: 405 | y, A = map(lambda t: einops.rearrange( 406 | t, 'b n ... -> b n () ...'), (y, A)) 407 | y = F.pad(y, (0, 0, 0, 0, 1, 0), value=0.) 408 | A = F.pad(A, (0, 0, 1, 0), value=0.) 409 | 410 | y = einops.rearrange(y, 'b ... d -> b (...) d') 411 | A = einops.rearrange(A, 'b ... -> b (...)') 412 | return y, A 413 | 414 | def to_blocks(t): return einops.rearrange( 415 | t, 'b (n z) ... -> b n z ...', z=bsz) 416 | 417 | # calculate Ys, as in the paper 418 | 419 | Ys = [] 420 | 421 | for ind, (q, k, v) in enumerate(reversed(qkvs)): 422 | is_last = ind == (len(qkvs) - 1) 423 | 424 | q, k, v = map(to_blocks, (q, k, v)) 425 | 426 | # flip keys and values to capture the off-diagonals 427 | 428 | if not is_last: 429 | k, v = map(flip_every_two, (k, v)) 430 | 431 | Y_level = calculate_Y_and_A( 432 | q, k, v, mask_right_off_diagonals=not is_last, causal_mask_diagonal=is_last) 433 | Ys.append(Y_level) 434 | 435 | # interpolate 436 | 437 | def safe_cat(acc, el, dim=0): 438 | if not exists(acc): 439 | return el 440 | return torch.cat((el, acc), dim=dim) 441 | 442 | Y = None 443 | A = None 444 | 445 | for Y_level, A_level in Ys: 446 | Y_level, A_level = map(lambda t: einops.rearrange( 447 | t, '... -> () ...'), (Y_level, A_level)) 448 | 449 | if torch.is_tensor(Y): 450 | Y = einops.repeat(Y, '... n d -> ... (n r) d', r=2) 451 | 452 | if torch.is_tensor(A): 453 | A = einops.repeat(A, '... n -> ... 
(n r)', r=2) 454 | 455 | Y = safe_cat(Y, Y_level) 456 | A = safe_cat(A, A_level) 457 | 458 | # create causal mask for Y and A 459 | 460 | causal_mask = self.mask[:(num_levels + 1), :pad_to_len] 461 | 462 | # mask and sum 463 | 464 | Y_causal_mask = einops.rearrange(causal_mask, 'h n -> h () n ()') 465 | A_causal_mask = einops.rearrange(causal_mask, 'h n -> h () n') 466 | 467 | Y = Y.masked_fill(Y_causal_mask, 0.) 468 | A = A.masked_fill(A_causal_mask, 0.) 469 | 470 | Y = Y.sum(dim=0) 471 | A = A.sum(dim=0) 472 | 473 | # normalize 474 | 475 | out = Y / einops.rearrange(A + eps, 'b n -> b n ()') 476 | 477 | # merge heads 478 | 479 | out = einops.rearrange(out, '(b h) n d -> b n (h d)', h=h) 480 | 481 | # combine out 482 | 483 | return self.to_out(out[:, :n]) 484 | 485 | # main class 486 | 487 | 488 | class HTransformer1D(nn.Module): 489 | def __init__( 490 | self, 491 | *, 492 | num_tokens, 493 | dim, 494 | depth, 495 | max_seq_len, 496 | causal=False, 497 | heads=8, 498 | dim_head=64, 499 | ff_mult=4, 500 | # this is the Nr in the paper - Nb = (max_seq_len / tokens_per_block) 501 | block_size=128, 502 | pos_emb=None, 503 | reversible=False, 504 | shift_tokens=False 505 | ): 506 | super().__init__() 507 | assert (max_seq_len % 508 | block_size) == 0, 'maximum sequence length must be divisible by the block size' 509 | num_blocks = max_seq_len // block_size 510 | assert log2(max_seq_len // block_size).is_integer( 511 | ), f'number of blocks {num_blocks} must be a power of 2' 512 | 513 | # self.token_emb = nn.Embedding(num_tokens, dim) 514 | self.pos_emb = RotaryEmbedding(dim=dim_head) if pos_emb else None 515 | self.max_seq_len = max_seq_len 516 | 517 | layers = nn.ModuleList([]) 518 | 519 | attn_class = CausalHAttention1D if causal else HAttention1D 520 | attn_kwargs = dict(max_seq_len=max_seq_len) if causal else dict() 521 | 522 | shift_token_ranges = (0, 1) if shift_tokens else (-1, 0, 1) 523 | 524 | for ind in range(depth): 525 | attn = attn_class(dim, dim_head=dim_head, heads=heads, 526 | block_size=block_size, pos_emb=self.pos_emb, **attn_kwargs) 527 | ff = FeedForward(dim, mult=ff_mult) 528 | 529 | if shift_tokens: 530 | attn, ff = map(lambda t: PreShiftTokens( 531 | shift_token_ranges, t), (attn, ff)) 532 | 533 | attn, ff = map(lambda t: PreNorm(dim, t), (attn, ff)) 534 | layers.append(nn.ModuleList([attn, ff])) 535 | 536 | execute_type = ReversibleSequence if reversible else SequentialSequence 537 | route_attn = ((True, False),) * depth 538 | attn_route_map = {'mask': route_attn} 539 | 540 | self.layers = execute_type(layers, args_route={**attn_route_map}) 541 | 542 | self.to_logits = nn.Sequential( 543 | nn.LayerNorm(dim), 544 | nn.Linear(dim, num_tokens) 545 | ) 546 | 547 | def forward(self, x, mask=None, return_embeddings=False): 548 | # b, n, device = *x.shape, x.device 549 | # assert n <= self.max_seq_len, 'sequence length must be less than the maximum sequence length' 550 | x = self.layers(x, mask=mask) 551 | if not return_embeddings: 552 | x = self.to_logits(x) 553 | return x 554 | 555 | 556 | def eval_decorator(fn): 557 | def inner(model, *args, **kwargs): 558 | was_training = model.training 559 | model.eval() 560 | out = fn(model, *args, **kwargs) 561 | model.train(was_training) 562 | return out 563 | return inner 564 | 565 | # top k filtering 566 | 567 | 568 | def top_k(logits, thres=0.9): 569 | k = int((1 - thres) * logits.shape[-1]) 570 | val, ind = torch.topk(logits, k) 571 | probs = torch.full_like(logits, float('-inf')) 572 | probs.scatter_(1, ind, val) 573 | return 
probs
574 | 
575 | 
576 | class AutoregressiveWrapper(nn.Module):
577 |     def __init__(self, net, ignore_index=-100, pad_value=0):
578 |         super().__init__()
579 |         self.pad_value = pad_value
580 |         self.ignore_index = ignore_index
581 | 
582 |         self.net = net
583 |         self.max_seq_len = net.max_seq_len
584 | 
585 |     @torch.no_grad()
586 |     @eval_decorator
587 |     def generate(self, start_tokens, seq_len, eos_token=None, temperature=1., filter_logits_fn=top_k, filter_thres=0.9, **kwargs):
588 |         if isinstance(self.net, bioseq.encoders.SeqEncoder):
589 |             if eos_token is None:
590 |                 eos = self.net.tokenizer.eos()
591 |                 if eos >= 0:
592 |                     eos_token = eos
593 |         device = start_tokens.device
594 |         num_dims = len(start_tokens.shape)
595 | 
596 |         if num_dims == 1:
597 |             start_tokens = start_tokens[None, :]
598 | 
599 |         b, t = start_tokens.shape
600 | 
601 |         out = start_tokens
602 | 
603 |         for _ in range(seq_len):
604 |             x = out[:, -self.max_seq_len:]
605 | 
606 |             logits = self.net(x, **kwargs)[:, -1, :]
607 | 
608 |             filtered_logits = filter_logits_fn(logits, thres=filter_thres)
609 |             probs = F.softmax(filtered_logits / temperature, dim=-1)
610 | 
611 |             sample = torch.multinomial(probs, 1)
612 | 
613 |             out = torch.cat((out, sample), dim=-1)
614 | 
615 |             if exists(eos_token):
616 |                 is_eos_token = (out == eos_token)
617 | 
618 |                 if is_eos_token.any(dim=-1).all():
619 |                     # mask out everything after the eos tokens
620 |                     shifted_is_eos_tokens = F.pad(is_eos_token, (1, -1))
621 |                     mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1
622 |                     out = out.masked_fill(mask, self.pad_value)
623 |                     break
624 | 
625 |         out = out[:, t:]
626 | 
627 |         if num_dims == 1:
628 |             out = out.squeeze(0)
629 | 
630 |         return out
631 | 
632 |     def forward(self, x, **kwargs):
633 |         # if not isinstance(x, torch.Tensor) and isinstance(self.net, bioseq.encoders.SeqEncoder):
634 |         device = kwargs.get("device", torch.device(
635 |             "cuda") if torch.cuda.is_available() else torch.device("cpu"))
636 |         x = self.net.tokenize(x, device=device)
637 |         xi = x[:, :-1]
638 |         xo = x[:, 1:]
639 | 
640 |         out = self.net(xi, tokenize=False, **kwargs)
641 |         if xo.dtype is not torch.long:
642 |             xo = xo.to(torch.long)
643 |         loss = F.cross_entropy(out.transpose(1, 2), xo,
644 |                                ignore_index=self.ignore_index)
645 |         return loss
646 | 
--------------------------------------------------------------------------------
/src/tokenize.h:
--------------------------------------------------------------------------------
1 | #include "pybind11/pybind11.h"
2 | #include "alphabet.h"
3 | #include "pybind11/numpy.h"
4 | #include 
5 | #include 
6 | 
7 | using namespace alph;
8 | 
9 | struct Tokenizer {
10 |     const Alphabet *ca_;
11 |     const bool include_eos_;
12 |     const bool include_bos_;
13 |     const bool zero_onehot_pad_;
14 |     std::string key;
15 |     std::unordered_map lookup;
16 |     std::unordered_map tokensets;
17 |     std::string token_map_str;
18 |     // Whether to pad with all 0s (instead of one-hot encoding with a wholly new character for 'padding')
19 |     // If true, then trailing sections are left as all 0s
20 |     // if false, then the pad() character is one-hot encoded for padding sections
21 | 
22 |     size_t full_alphabet_size() const {return ca_->nchars() + include_eos_ + include_bos_ + zero_onehot_pad_;}
23 |     int bos() const {
24 |         if(!include_bos_) return -1;
25 |         return ca_->nchars();
26 |     }
27 |     int eos() const {
28 |         if(!include_eos_) return -1;
29 |         return ca_->nchars() + include_bos_;
30 |     }
31 |     int pad() const {
32 |         return ca_->nchars() + include_bos_ + include_eos_;
33 |     }
34 | 
35 |     bool is_padded() const {return 
zero_onehot_pad_;} 36 | bool includes_bos() const {return include_bos_;} 37 | bool includes_eos() const {return include_eos_;} 38 | int nchars() const noexcept {return ca_->nchars();} 39 | // Always included: padding 40 | Tokenizer(const Alphabet &ca, bool eos=false, bool bos=false, bool zero_onehot_pad=false): ca_(&ca), include_eos_(eos), include_bos_(bos), zero_onehot_pad_(zero_onehot_pad) { 41 | for(int32_t i = 0, e = ca.lut.size(); i < e; ++i) { 42 | const char value = ca.lut[i]; 43 | if(!lookup.contains(value)) { 44 | lookup[value] = std::string(1, static_cast(i)); 45 | } 46 | tokensets[value] += static_cast(i); 47 | } 48 | if(includes_bos()) { 49 | lookup[this->bos()] = ""; 50 | } 51 | if(includes_eos()) { 52 | lookup[this->eos()] = ""; 53 | } 54 | if(is_padded()) { 55 | lookup[pad()] = ""; 56 | } 57 | for(const auto& pair: lookup) { 58 | token_map_str += std::to_string(pair.first) + ':' + pair.second; 59 | token_map_str += ';'; 60 | } 61 | if(!token_map_str.empty()) 62 | token_map_str.pop_back(); 63 | } 64 | std::string token_map() const noexcept {return token_map_str;} 65 | std::unordered_map token_set_map() const { 66 | std::unordered_map ret; 67 | for(const auto& [k, v]: tokensets) { 68 | ret[k] = py::bytes(v); 69 | } 70 | return ret; 71 | } 72 | Tokenizer(std::string key_, bool include_eos, bool include_bos, bool zohpad): include_eos_(include_eos), include_bos_(include_bos), zero_onehot_pad_(zohpad), key(key_) { 73 | std::transform(key.begin(), key.end(), key.begin(),[](auto x){return std::toupper(x);}); 74 | auto it = CAMAP.find(key); 75 | if(it == CAMAP.end()) { 76 | std::string options; 77 | for(const auto &pair: CAMAP) options += pair.first, options += ';'; 78 | throw std::runtime_error(std::string("Invalid tokenizer type; select one from") + options); 79 | } 80 | ca_ = it->second; 81 | const auto& ca = *ca_; 82 | int32_t maxv{0}; 83 | for(int32_t i = 0, e = ca.lut.size(); i < e; ++i) { 84 | const char value = ca.lut[i]; 85 | if(!lookup.contains(value)) { 86 | lookup[value] = std::string(1, static_cast(i)); 87 | } 88 | tokensets[value] += static_cast(i); 89 | maxv = std::max(int(value), maxv); 90 | } 91 | if(includes_bos()) { 92 | lookup[this->bos()] = ""; 93 | } 94 | if(includes_eos()) { 95 | lookup[this->eos()] = ""; 96 | } 97 | if(is_padded()) { 98 | lookup[pad()] = ""; 99 | } 100 | for(const auto& pair: lookup) { 101 | token_map_str += std::to_string(pair.first) + ':' + pair.second; 102 | token_map_str += ';'; 103 | } 104 | if(!token_map_str.empty()) 105 | token_map_str.pop_back(); 106 | } 107 | static uint64_t load_value(const uint8_t* const data, const int64_t bytes) { 108 | switch(bytes) { 109 | case 1: { 110 | return *data; 111 | } 112 | case 2: { 113 | return *static_cast(static_cast(data)); 114 | } 115 | case 4: { 116 | return *static_cast(static_cast(data)); 117 | } 118 | case 8: { 119 | return *static_cast(static_cast(data)); 120 | } 121 | default: ; 122 | } 123 | throw std::runtime_error(std::string("Unexpected itemsize: expected 1, 2, 4, or 8. 
Found ") + std::to_string(bytes)); 124 | } 125 | static void trim_to_eos(std::string& x) { 126 | const auto pos = x.find(""); 127 | if(pos != std::string::npos) { 128 | x.resize(pos + 5); 129 | } 130 | } 131 | py::object decode_tokens(const py::buffer_info& info, const bool trim=false) const { 132 | if(info.ptr == nullptr) { 133 | throw std::invalid_argument("Empty array cannot yield a decoded string"); 134 | } 135 | const int32_t ndim = info.ndim; 136 | if((ndim > 2) || (ndim == 0)) { 137 | throw std::invalid_argument("Currently supported: 1 or 2 dimensions for decoding tokens."); 138 | } 139 | if(ndim == 1) { 140 | std::ostringstream oss; 141 | const uint8_t* data_ptr = static_cast(info.ptr); 142 | const int64_t stride = info.strides[0]; 143 | const uint8_t* const end_ptr = static_cast(info.ptr) + (stride * info.size); 144 | for(;data_ptr < end_ptr; data_ptr += stride) { 145 | const uint32_t value = load_value(data_ptr, info.itemsize); 146 | const auto it = lookup.find(value); 147 | if(it == lookup.end()) { 148 | throw std::runtime_error(std::string("Unexpected/invalid token ") + std::to_string(value)); 149 | } 150 | oss << it->second; 151 | } 152 | std::string ret = oss.str(); 153 | if(trim) trim_to_eos(ret); 154 | return py::str(oss.str()); 155 | } 156 | // ndim == 2 157 | py::list ret; 158 | 159 | const int64_t nrows = info.shape[0]; 160 | const int64_t ncols = info.shape[1]; 161 | const int64_t rowStride = info.strides[0]; 162 | const int64_t colStride = info.strides[1]; 163 | for(int64_t row = 0; row < nrows; ++row) { 164 | std::ostringstream oss; 165 | const uint8_t* data_ptr = static_cast(info.ptr) + rowStride * row; 166 | for(int64_t col = 0; col < ncols; ++col) { 167 | const uint32_t value = load_value(data_ptr + col * colStride, info.itemsize); 168 | const auto it = lookup.find(value); 169 | if(it == lookup.end()) { 170 | throw std::runtime_error(std::string("Unexpected/invalid token ") + std::to_string(value)); 171 | } 172 | oss << it->second; 173 | } 174 | std::string next = oss.str(); 175 | if(trim) {trim_to_eos(next);} 176 | ret.append(py::str(next)); 177 | } 178 | return ret; 179 | } 180 | py::object decode_tokens_to_string(py::array array) const { 181 | py::buffer_info bi = array.request(); 182 | return decode_tokens(bi); 183 | } 184 | template 185 | py::array_t tokenize(const std::string &seq, py::ssize_t padlen=0) const { 186 | return tokenize(seq.data(), seq.size(), padlen); 187 | } 188 | template 189 | py::array_t tokenize(const char *seq, const py::ssize_t seqsz, py::ssize_t padlen=0) const { 190 | if(padlen > 0) { 191 | if(seqsz > padlen) throw std::runtime_error("padlen is too short to accommodate sequence\n"); 192 | } 193 | py::array_t ret; 194 | const py::ssize_t nc = full_alphabet_size(); 195 | ret.resize({std::max(py::ssize_t(seqsz), padlen) + include_bos_ + include_eos_, nc}); 196 | py::buffer_info bi = ret.request(); 197 | T *ptr = (T *)bi.ptr, *offp = ptr; 198 | if(include_bos_) 199 | ptr[bos()] = 1, offp += nc; 200 | for(py::ssize_t i = 0; i < seqsz; ++i) { 201 | //, translated to %d\n", i, seq[i], seq[i], ca_->translate(seq[i])); 202 | assert(std::strlen(seq) > i); 203 | auto offset = ca_->translate(seq[i]); 204 | assert(offset < full_alphabet_size()); 205 | assert(offset >= 0u); 206 | offp[offset] = 1, offp += nc; 207 | } 208 | if(include_eos_) 209 | offp[eos()] = 1, offp += nc; 210 | if(zero_onehot_pad_) { 211 | for(auto pos = (offp - ptr) / nc;pos < padlen; ++pos) { 212 | ptr[pos * nc + pad()] = 1; 213 | } 214 | } 215 | return ret; 216 | } 217 | 
template 218 | py::array_t tokenize(const std::vector &seqs, py::ssize_t padlen=0, bool batch_first = false) const { 219 | if(seqs.empty()) { 220 | throw std::invalid_argument(std::string("Cannot tokenize an empty set of sequences; len: ") + std::to_string(seqs.size())); 221 | } 222 | const size_t mseqlen = std::accumulate(seqs.begin(), seqs.end(), size_t(0), [](auto x, const auto &y) {return std::max(x, y.size());}); 223 | if(padlen > 0) { 224 | if(mseqlen > padlen) throw std::runtime_error("padlen is too short to accommodate sequence batch\n"); 225 | } else { 226 | padlen = mseqlen; 227 | } 228 | py::array_t ret; 229 | const py::ssize_t batchsize = seqs.size(); 230 | const py::ssize_t nc = full_alphabet_size(); 231 | py::ssize_t nr = padlen; // + include_bos_ + include_eos_; 232 | if(batch_first) { 233 | ret.resize({batchsize, nr, nc}); 234 | } else { 235 | ret.resize({nr, batchsize, nc}); 236 | } 237 | py::buffer_info bi = ret.request(); 238 | T *ptr = (T *)bi.ptr, *offp = ptr; 239 | if(0) { 240 | #if 0 241 | #ifdef _OPENMP 242 | #pragma omp parallel for 243 | #endif 244 | const auto mul = nc * nr; 245 | for(size_t i = 0; i < seqs.size(); ++i) { 246 | const auto &seq(seqs[i]); 247 | auto tp = &offp[i * mul]; 248 | if(include_bos_) 249 | tp[bos()] = 1, tp += nc; 250 | for(size_t j = 0; j < seq.size(); ++j) 251 | tp[ca_->translate(seq[j])] = 1, tp += nc; 252 | if(include_eos_) 253 | tp[eos()] = 1, tp += nc; 254 | if(zero_onehot_pad_) { 255 | for(auto ep = &offp[(i + 1) * mul];tp < ep;tp += nc) 256 | tp[pad()] = 1; 257 | } 258 | } 259 | #endif 260 | } else { 261 | const size_t total_nregs = nr * batchsize * nc; 262 | for(size_t i = 0; i < seqs.size(); ++i) { 263 | const auto &seq(seqs[i]); 264 | if(include_bos_) 265 | ptr[i * nc + bos()] = 1; 266 | for(size_t j = 0; j < seq.size(); ++j) { 267 | auto tr = ca_->translate(seq[i]); 268 | assert(tr >= 0); 269 | assert(tr < full_alphabet_size()); 270 | assert(ptr + (include_bos_ + j) * nr * nc + i * nc + tr < ptr + total_nregs); 271 | ptr[(include_bos_ + j) * nr * nc + i * nc + tr] = 1; 272 | } 273 | if(include_eos_) { 274 | ptr[(include_bos_ + seq.size()) * nr * nc + i * nc + eos()] = 1; 275 | } 276 | if(zero_onehot_pad_) { 277 | for(py::ssize_t myi = seq.size() + include_bos_ + include_eos_; myi < padlen; ++myi) 278 | ptr[myi * nr * nc + i * nc + pad()] = 1; 279 | } 280 | } 281 | } 282 | } 283 | template 284 | py::array_t tokenize(py::sequence items, py::ssize_t padlen=-1, bool batch_first = false, py::ssize_t nthreads = 1, py::object mask=py::none()) const { 285 | if(padlen <= 0) throw std::invalid_argument("batch tokenize requires padlen is provded."); 286 | if(nthreads <= 0) nthreads = 1; 287 | const py::ssize_t nc = full_alphabet_size(); 288 | py::ssize_t nr = padlen; // + include_bos_ + include_eos_; 289 | std::vector> strs; 290 | std::vector maskptrs; 291 | py::ssize_t nitems = 0; 292 | for(auto item: items) { 293 | const uint8_t *maskptr = 0; 294 | if(py::isinstance(mask)) { 295 | maskptr = getmaskptr(py::cast(mask)[nitems]); 296 | } 297 | maskptrs.push_back(maskptr); 298 | py::ssize_t size; 299 | if(py::isinstance(item)) { 300 | const char *s = PyUnicode_AsUTF8AndSize(item.ptr(), &size); 301 | strs.push_back({s, size}); 302 | } else if(py::isinstance(item)) { 303 | char *s; 304 | if(PyBytes_AsStringAndSize(item.ptr(), &s, &size)) 305 | throw std::invalid_argument("item is not a bytes object; this should never happen."); 306 | strs.push_back({s, size}); 307 | } else if(py::isinstance(item)) { 308 | 
strs.push_back({PyByteArray_AS_STRING(item.ptr()), PyByteArray_GET_SIZE(item.ptr())}); 309 | } else if(py::isinstance(item)) { 310 | auto inf = py::cast(item).request(); 311 | switch(inf.format.front()) { 312 | case 'b': case 'B': { 313 | strs.push_back({(const char *)inf.ptr, size_t(inf.size)}); 314 | } 315 | default: goto invalid; 316 | } 317 | } else { 318 | invalid: 319 | throw std::invalid_argument("item was none of string, bytes, or numpy array of 8-bit integers. "); 320 | } 321 | ++nitems; 322 | } 323 | if(batch_first) { 324 | throw std::invalid_argument("Batch first is disabled. Instead, using Einops' rearrange to correct the shape."); 325 | } 326 | py::array_t ret(std::vector({nr, nitems, nc})); // T B C 327 | const auto nrc = nitems * nc; 328 | #define __access(seqind, batchind, charind) \ 329 | assert((seqind) * nrc + (batchind) * nc + (charind) < nr * nitems * nc || !std::fprintf(stderr, "seqlen %zu, batchind %zu, charind %zu %s", seqind, batchind, charind, seq.first));\ 330 | ptr[(seqind) * nrc + (batchind) * nc + (charind)] 331 | py::buffer_info bi = ret.request(); 332 | std::memset(bi.ptr, 0, sizeof(T) * nitems * nr * nc); 333 | T *ptr = (T *)bi.ptr; 334 | #ifndef NDEBUG 335 | for(size_t i = 0; i < strs.size(); ++i) { 336 | assert(strs[i].second + include_bos_ + include_eos_ <= nr); 337 | } 338 | #endif 339 | #ifdef _OPENMP 340 | #pragma omp parallel for num_threads(nthreads) 341 | #endif 342 | for(size_t i = 0; i < strs.size(); ++i) { 343 | const auto maskptr = maskptrs[i]; 344 | const auto &seq(strs[i]); 345 | if(include_bos_) { 346 | __access(0, i, bos()) = 1; 347 | } 348 | for(size_t j = 0; j < seq.second; ++j) { 349 | if(!maskptr || maskptr[j]) { 350 | const auto tr = ca_->translate(seq.first[j]); 351 | if(tr >= 0) { 352 | __access((include_bos_ + j), i, tr) = 1; 353 | } 354 | } 355 | } 356 | if(include_eos_) { 357 | __access((include_bos_ + seq.second), i, eos()) = 1; 358 | } 359 | if(static_cast(seq.second + include_bos_ + include_eos_) > padlen) { 360 | auto tl = seq.second + include_bos_ + include_eos_; 361 | throw std::invalid_argument(std::string("seq len + bos + eos > padlen: ") + std::to_string(tl) + ", vs padlen " + std::to_string(padlen)); 362 | } 363 | if(zero_onehot_pad_) { 364 | for(py::ssize_t k = seq.second + include_bos_ + include_eos_; k < padlen;) 365 | { 366 | __access(k++, i, pad()) = 1; 367 | } 368 | } 369 | } 370 | return ret; 371 | } 372 | static const uint8_t *getmaskptr(py::object mask) { 373 | const uint8_t *maskptr = 0; 374 | if(py::isinstance(mask)) { 375 | py::array_t arr(mask); 376 | auto inf = arr.request(); 377 | maskptr = (const uint8_t *)inf.ptr; 378 | } 379 | return maskptr; 380 | } 381 | template 382 | py::object transencode(py::sequence items, py::ssize_t padlen=-1, bool batch_first = false, py::ssize_t nthreads = 1, py::object mask=py::none()) const { 383 | if(padlen <= 0) throw std::invalid_argument("batch tokenize requires padlen is provded."); 384 | if(nthreads <= 0) nthreads = 1; 385 | py::ssize_t nr = padlen; // + include_bos_ + include_eos_; 386 | std::vector> strs; 387 | std::vector maskptrs; 388 | py::ssize_t nitems = 0; 389 | for(auto item: items) { 390 | const uint8_t *maskptr = 0; 391 | py::ssize_t size; 392 | if(py::isinstance(mask)) { 393 | maskptr = getmaskptr(py::cast(mask)[nitems]); 394 | } 395 | maskptrs.push_back(maskptr); 396 | if(py::isinstance(item)) { 397 | const char *s = PyUnicode_AsUTF8AndSize(item.ptr(), &size); 398 | strs.push_back({s, size}); 399 | } else if(py::isinstance(item)) { 400 | char *s; 401 | 
if(PyBytes_AsStringAndSize(item.ptr(), &s, &size)) 402 | throw std::invalid_argument("item is not a bytes object; this should never happen."); 403 | strs.push_back({s, size}); 404 | } else if(py::isinstance(item)) { 405 | strs.push_back({PyByteArray_AS_STRING(item.ptr()), PyByteArray_GET_SIZE(item.ptr())}); 406 | } else if(py::isinstance(item)) { 407 | auto inf = py::cast(item).request(); 408 | switch(inf.format.front()) { 409 | case 'b': case 'B': { 410 | strs.push_back({(const char *)inf.ptr, size_t(inf.size)}); 411 | } 412 | default: goto invalid; 413 | } 414 | } else { 415 | invalid: 416 | throw std::invalid_argument("item was none of string, bytes, or numpy array of 8-bit integers. "); 417 | } 418 | ++nitems; 419 | } 420 | py::object ret = py::none(); 421 | if(batch_first) { 422 | ret = py::array_t(std::vector({nitems, nr})); // B T 423 | } else { 424 | ret = py::array_t(std::vector({nr, nitems})); // T B 425 | } 426 | py::buffer_info bi = ret.cast>().request(); 427 | std::memset(bi.ptr, 0, sizeof(T) * nitems * nr); 428 | const size_t padl = nr; 429 | T *ptr = (T *)bi.ptr; 430 | #define __assign_bf(seqind, batchind, charind) \ 431 | do {\ 432 | assert(seqind + batchind * padl < nitems * padl);\ 433 | ptr[batchind * padl + seqind] = charind;\ 434 | } while(0) 435 | #define __assign_tf(seqind, batchind, charind) \ 436 | do {\ 437 | assert(seqind * nitems + batchind < nitems * padl);\ 438 | ptr[seqind * nitems + batchind] = charind;\ 439 | } while(0) 440 | #define __assign(seqind, batchind, charind) \ 441 | do {\ 442 | if(charind >= 0) {\ 443 | if(batch_first) {\ 444 | __assign_bf(seqind, batchind, charind);\ 445 | } else {\ 446 | __assign_tf(seqind, batchind, charind);\ 447 | }\ 448 | }\ 449 | } while(0) 450 | 451 | #ifdef _OPENMP 452 | #pragma omp parallel for num_threads(nthreads) 453 | #endif 454 | for(size_t i = 0; i < strs.size(); ++i) { 455 | const auto &seq(strs[i]); 456 | if(__builtin_expect(static_cast(seq.second + include_bos_ + include_eos_) > padlen, 0)) { 457 | auto tl = seq.second + include_bos_ + include_eos_; 458 | throw std::runtime_error(std::string("seq len + bos + eos > padlen: ") + std::to_string(tl) + ", vs padlen " + std::to_string(padlen)); 459 | } 460 | if(include_bos_) { 461 | __assign(0, i, bos()); 462 | } 463 | const auto maskptr = maskptrs[i]; 464 | for(size_t j = 0; j < seq.second; ++j) { 465 | auto tr = ca_->translate(seq.first[j]); 466 | if(!maskptr || maskptr[j]) { 467 | __assign((include_bos_ + j), i, tr); 468 | } 469 | } 470 | if(include_eos_) { 471 | __assign((include_bos_ + seq.second), i, eos()); 472 | } 473 | if(zero_onehot_pad_) { 474 | for(py::ssize_t k = seq.second + include_bos_ + include_eos_; k < padlen;) 475 | { 476 | __assign(k++, i, pad()); 477 | } 478 | } 479 | } 480 | return ret; 481 | #undef __assign_tf 482 | #undef __assign_bf 483 | #undef __assign 484 | #undef __access 485 | } 486 | }; 487 | --------------------------------------------------------------------------------
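The HTransformer1D class in bioseq/hattn.py above stacks PreNorm'd HAttention1D (or CausalHAttention1D) and FeedForward blocks and projects to token logits. The snippet below is a minimal smoke-test sketch, not part of the repository: it assumes the class is importable as bioseq.hattn.HTransformer1D (the module path is a guess from the file layout) and that the supporting pieces defined earlier in that file (RotaryEmbedding, SequentialSequence, etc.) are in scope. Because the token-embedding layer is commented out in __init__, the model consumes pre-embedded inputs of shape (batch, seq_len, dim) rather than raw token ids.

import torch
from bioseq.hattn import HTransformer1D  # assumed import path

model = HTransformer1D(
    num_tokens=25,      # e.g. an amino-acid alphabet plus special tokens
    dim=128,
    depth=2,
    max_seq_len=1024,   # must be divisible by block_size, with a power-of-two quotient
    block_size=128,     # Nr in the paper; here 1024 / 128 = 8 blocks
    causal=True,
    heads=8,
    dim_head=64,
)

x = torch.randn(2, 1024, 128)           # already-embedded sequence: (batch, seq, dim)
logits = model(x)                       # (2, 1024, 25) token logits
emb = model(x, return_embeddings=True)  # (2, 1024, 128); the logits head is skipped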
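The Tokenizer in src/tokenize.h one-hot encodes sequences with optional BOS/EOS/pad channels appended after the base alphabet (see bos(), eos(), and pad() above). The NumPy sketch below only mirrors that layout for a single sequence so the output shape is easy to picture; the 4-letter DNA alphabet is a placeholder, the trailing-row padding is simplified relative to the header, and the real encoding is done by the pybind11 extension, not this snippet.

import numpy as np

alphabet = {"A": 0, "C": 1, "G": 2, "T": 3}
nchars, include_bos, include_eos, onehot_pad = len(alphabet), True, True, True
nc = nchars + include_bos + include_eos + onehot_pad   # full_alphabet_size()
bos, eos, pad = nchars, nchars + 1, nchars + 2         # bos(), eos(), pad()

def onehot(seq, padlen):
    rows = padlen + include_bos + include_eos
    out = np.zeros((rows, nc), dtype=np.uint8)
    pos = 0
    if include_bos:
        out[pos, bos] = 1; pos += 1
    for ch in seq:
        out[pos, alphabet[ch]] = 1; pos += 1   # one-hot at the character's channel
    if include_eos:
        out[pos, eos] = 1; pos += 1
    if onehot_pad:
        out[pos:, pad] = 1                     # trailing rows marked as padding
    return out

print(onehot("ACCGT", padlen=8).shape)   # (10, 7): 8 positions + BOS + EOS, 7 channels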