├── CellPLM ├── __init__.py ├── decoder │ ├── __init__.py │ ├── mlp.py │ └── zinb.py ├── embedder │ ├── __init__.py │ └── omics.py ├── encoder │ ├── __init__.py │ ├── mlp.py │ └── transformer.py ├── head │ ├── __init__.py │ └── downstream.py ├── latent │ ├── __init__.py │ ├── adversarial.py │ ├── autoencoders.py │ └── contrastive.py ├── layer │ ├── __init__.py │ ├── cosformer.py │ ├── flowformer.py │ ├── performer.py │ └── transformer.py ├── model │ ├── __init__.py │ └── cellformer.py ├── objective │ ├── __init__.py │ ├── autoencoder.py │ └── zinb.py ├── pipeline │ ├── __init__.py │ ├── cell_embedding.py │ ├── cell_type_annotation.py │ ├── experimental.py │ └── imputation.py └── utils │ ├── __init__.py │ ├── data.py │ ├── eval.py │ ├── mask.py │ ├── pe.py │ └── sparse.py ├── LICENSE ├── README.md ├── ckpt ├── 20230926_85M.config.json └── README.md ├── data └── README.md ├── pyproject.toml ├── requirements.txt └── tutorials ├── README.md ├── cell_embedding.ipynb ├── cell_type_annotation.ipynb └── spatial_imputation.ipynb /CellPLM/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OmicsML/CellPLM/b59ee7688b1bd2a745856a610248b46f15019a22/CellPLM/__init__.py -------------------------------------------------------------------------------- /CellPLM/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .mlp import MLPDecoder, ResMLPDecoder 2 | from .zinb import NBMLPDecoder 3 | from torch import nn 4 | 5 | def setup_decoder(model_type, in_dim, hidden_dim, out_dim, num_layers, dropout, norm, batch_num=0, dataset_num=0, platform_num=0) -> nn.Module: 6 | if model_type == 'nbmlp': 7 | mod = NBMLPDecoder( 8 | in_dim=in_dim, 9 | hidden_dim=hidden_dim, 10 | out_dim=out_dim, 11 | num_layers=num_layers, 12 | dropout=dropout, 13 | norm=norm, 14 | batch_num=batch_num, 15 | dataset_num=dataset_num, 16 | platform_num=platform_num, 17 | ) 18 | elif model_type == 'mlp': 19 | mod = MLPDecoder( 20 | in_dim=in_dim, 21 | hidden_dim=hidden_dim, 22 | out_dim=out_dim, 23 | num_layers=num_layers, 24 | dropout=dropout, 25 | norm=norm, 26 | batch_num=batch_num, 27 | dataset_num=dataset_num, 28 | platform_num=platform_num, 29 | ) 30 | elif model_type == "resmlp": 31 | mod = ResMLPDecoder( 32 | in_dim=in_dim, 33 | hidden_dim=hidden_dim, 34 | out_dim=out_dim, 35 | num_layers=num_layers, 36 | dropout=dropout, 37 | norm=norm, 38 | batch_num=batch_num, 39 | ) 40 | else: 41 | raise NotImplementedError(f'Unsupported model type: {model_type}') 42 | return mod -------------------------------------------------------------------------------- /CellPLM/decoder/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ..utils import create_norm 4 | import torch.nn.functional as F 5 | 6 | class ResMLPDecoder(nn.Module): 7 | def __init__(self, in_dim, hidden_dim, out_dim, num_layers, dropout, norm, batch_num): 8 | super().__init__() 9 | self.layers = nn.ModuleList() 10 | assert num_layers > 1, 'At least two layer for MLPs.' 11 | for i in range(num_layers - 1): 12 | dim = hidden_dim if i>0 else in_dim 13 | self.layers.append(nn.Sequential( 14 | nn.Linear(dim, hidden_dim), 15 | nn.PReLU(), 16 | nn.Dropout(dropout), 17 | create_norm(norm, hidden_dim) 18 | )) 19 | self.out_layer = nn.Sequential( 20 | nn.Linear(hidden_dim * (num_layers - 1), out_dim), 21 | nn.PReLU(), 22 | ) 23 | self.batch_emb = nn.Embedding(batch_num, hidden_dim) 24 | self.layer_norm = nn.LayerNorm(hidden_dim) 25 | 26 | def forward(self, x_dict): 27 | hist = [] 28 | batch_labels = x_dict['batch'] 29 | x = x_dict['h'] 30 | for layer in self.layers: 31 | x = layer(x) 32 | x = x + self.layer_norm(self.batch_emb(batch_labels)) 33 | hist.append(x) 34 | return {'recon': self.out_layer(torch.cat(hist, 1)), 'latent': x_dict['h']} 35 | 36 | class MLPDecoder(nn.Module): 37 | def __init__(self, in_dim, hidden_dim, out_dim, num_layers, dropout, norm, batch_num=0, dataset_num=0, platform_num=0, out_act=nn.ReLU()): 38 | super().__init__() 39 | self.layers = nn.ModuleList() 40 | # assert num_layers > 1, 'At least two layer for MLPs.' 41 | covariate_num = batch_num + dataset_num + platform_num 42 | for i in range(num_layers - 1): 43 | dim = hidden_dim if i > 0 else in_dim 44 | self.layers.append(nn.Sequential( 45 | nn.Linear(dim + covariate_num, hidden_dim), 46 | nn.PReLU(), 47 | nn.Dropout(dropout), 48 | create_norm(norm, hidden_dim), 49 | )) 50 | 51 | self.out_layer = [nn.Linear(hidden_dim, out_dim)] 52 | if out_act is not None: 53 | self.out_layer.append(out_act) 54 | self.out_layer = nn.Sequential(*self.out_layer) 55 | # self.batch_emb = nn.Embedding(covariate_num, in_dim) 56 | self.layer_norm = nn.LayerNorm(in_dim) 57 | self.batch_num = batch_num 58 | self.dataset_num = dataset_num 59 | self.platform_num = platform_num 60 | 61 | 62 | def reset_batch_emb(self): 63 | self.batch_emb.reset_parameters() 64 | 65 | def forward(self, x_dict): 66 | covariates = [] 67 | if self.batch_num > 0: 68 | covariates.append(F.one_hot(x_dict['batch'], num_classes=self.batch_num)) 69 | if self.dataset_num > 0: 70 | covariates.append(F.one_hot(x_dict['dataset'], num_classes=self.dataset_num)) 71 | if self.platform_num > 0: 72 | covariates.append(F.one_hot(x_dict['platform'], num_classes=self.platform_num)) 73 | x = x_dict['h'] 74 | # x = x + self.batch_emb(batch_labels)#self.layer_norm(self.batch_emb(batch_labels)) 75 | for i, layer in enumerate(self.layers): 76 | x = torch.cat([x]+covariates, 1) 77 | x = layer(x) 78 | # if i == 0: 79 | # x += self.lib_emb(x_dict['lib_size']) 80 | return {'recon': self.out_layer(x), 'latent': x_dict['h']} -------------------------------------------------------------------------------- /CellPLM/decoder/zinb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from ..utils import create_activation, create_norm 5 | 6 | class MeanAct(nn.Module): 7 | """Mean activation class.""" 8 | 9 | def __init__(self, softmax): 10 | super().__init__() 11 | self.softmax = softmax 12 | 13 | def forward(self, x): 14 | if not self.softmax: 15 | return torch.clamp(torch.exp(x), min=1e-5, max=1e6) 16 | else: 17 | return torch.softmax(x, 1) 18 | 19 | class DispAct(nn.Module): 20 | """Dispersion activation class.""" 21 | 22 | def __init__(self): 23 | super().__init__() 24 | 25 | def forward(self, x): 26 | return torch.clamp(torch.exp(x), min=1e-4, max=1e4) 27 | 28 | class ZINB(nn.Module): 29 | """ZINB Decoder. 30 | Parameters 31 | ---------- 32 | input_dim : int 33 | dimension of input feature. 34 | n_z : int 35 | dimension of latent embedding. 36 | n_dec_1 : int optional 37 | number of nodes of decoder layer 1. 38 | """ 39 | 40 | def __init__(self, hidden_dim, out_dim, n_dec_1=128, softmax=True, disp='gene'): 41 | super().__init__() 42 | self.dec_1 = nn.Linear(hidden_dim, n_dec_1) 43 | self.dec_mean = nn.Sequential(nn.Linear(n_dec_1, out_dim), MeanAct(softmax)) 44 | self.dec_pi = nn.Sequential(nn.Linear(n_dec_1, out_dim), nn.Sigmoid()) 45 | self.disp = disp 46 | if disp == 'gene': 47 | self.dec_disp = nn.Parameter(torch.ones(out_dim)) 48 | else: 49 | self.dec_disp = nn.Sequential(nn.Linear(n_dec_1, out_dim), DispAct()) 50 | 51 | def forward(self, z): 52 | """Forward propagation. 53 | Parameters 54 | ---------- 55 | z : 56 | embedding. 57 | Returns 58 | ------- 59 | _mean : 60 | data mean from ZINB. 61 | _disp : 62 | data dispersion from ZINB. 63 | _pi : 64 | data dropout probability from ZINB4 65 | """ 66 | 67 | h = F.relu(self.dec_1(z)) 68 | _mean = self.dec_mean(h) 69 | if self.disp == 'gene': 70 | _disp = self.dec_disp.repeat(z.shape[0], 1) 71 | else: 72 | _disp = self.dec_disp(h) 73 | _pi = self.dec_pi(h) 74 | return _mean, _disp, _pi 75 | 76 | class NB(nn.Module): 77 | """NB Decoder. 78 | Parameters 79 | ---------- 80 | input_dim : int 81 | dimension of input feature. 82 | n_z : int 83 | dimension of latent embedding. 84 | n_dec_1 : int optional 85 | number of nodes of decoder layer 1. 86 | """ 87 | 88 | def __init__(self, hidden_dim, out_dim, n_dec_1=128, softmax=False, disp='gene'): 89 | super().__init__() 90 | self.dec_1 = nn.Linear(hidden_dim, n_dec_1) 91 | self.dec_mean = nn.Sequential(nn.Linear(n_dec_1, out_dim), MeanAct(softmax)) 92 | self.disp = disp 93 | if disp == 'gene': 94 | self.dec_disp = nn.Parameter(torch.randn(out_dim)) 95 | self.dec_disp_act = DispAct() 96 | else: 97 | self.dec_disp = nn.Sequential(nn.Linear(n_dec_1, out_dim), DispAct()) 98 | 99 | def forward(self, z): 100 | """Forward propagation. 101 | Parameters 102 | ---------- 103 | z : 104 | embedding. 105 | Returns 106 | ------- 107 | _mean : 108 | data mean from NB. 109 | _disp : 110 | data dispersion from NB. 111 | """ 112 | 113 | h = F.relu(self.dec_1(z)) 114 | _mean = self.dec_mean(h) 115 | if self.disp == 'gene': 116 | _disp = self.dec_disp_act(self.dec_disp.repeat(z.shape[0], 1)) 117 | else: 118 | _disp = self.dec_disp(h) 119 | return _mean, _disp 120 | 121 | 122 | class NBMLPDecoder(nn.Module): 123 | def __init__(self, in_dim, hidden_dim, out_dim, num_layers, dropout, norm, batch_num=0, dataset_num=0, platform_num=0): 124 | super().__init__() 125 | self.hidden_dim = hidden_dim 126 | self.norm = norm 127 | self.layers = nn.ModuleList() 128 | self.covariate_layers = nn.ModuleList() 129 | self.covariate_num = { 130 | 'batch': batch_num, 131 | 'dataset': dataset_num, 132 | 'platform': platform_num, 133 | } 134 | for i in range(num_layers-1): 135 | dim = hidden_dim if i > 0 else in_dim 136 | self.layers.append(nn.Sequential( 137 | nn.Linear(dim, hidden_dim), 138 | nn.PReLU(), 139 | nn.Dropout(dropout), 140 | create_norm(norm, hidden_dim), 141 | )) 142 | if sum(self.covariate_num.values()): # Covariates exist 143 | self.covariate_layers.append(nn.ModuleDict()) 144 | for cov in self.covariate_num.keys(): 145 | if self.covariate_num[cov] > 0: 146 | self.covariate_layers[-1][cov] = nn.Sequential( 147 | nn.Embedding(self.covariate_num[cov], hidden_dim), 148 | nn.PReLU(), 149 | create_norm(norm, hidden_dim), 150 | ) 151 | self.out_layer = NB( 152 | hidden_dim, out_dim, 153 | ) 154 | 155 | 156 | def forward(self, x_dict): 157 | x = x_dict['h'] 158 | for i, layer in enumerate(self.layers): 159 | if sum(self.covariate_num.values()): # Covarites (batch/dataset/platform) exist 160 | x = layer(x) 161 | for cov in self.covariate_num.keys(): # Iterate over each type of covariate (batch/dataset/platform) 162 | if self.covariate_num[cov] > 0: # if a certain type of covariate exist 163 | if cov in x_dict: # Whether the covaraite label is input 164 | x += self.covariate_layers[i][cov](x_dict[cov]) 165 | else: # If not input, take average over all of them 166 | convariate_layer = self.covariate_layers[i][cov] 167 | x += convariate_layer[2](convariate_layer[1](convariate_layer[0].weight.detach().sum(0).unsqueeze(0))) 168 | else: 169 | x = layer(x) 170 | mean, disp = self.out_layer(x) 171 | return {'mean': mean, 'disp': disp, 'recon': mean, 'latent': x_dict['h']} 172 | -------------------------------------------------------------------------------- /CellPLM/embedder/__init__.py: -------------------------------------------------------------------------------- 1 | from .omics import OmicsEmbedder, OmicsEmbeddingLayer 2 | 3 | -------------------------------------------------------------------------------- /CellPLM/embedder/omics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from ..utils.pe import select_pe_encoder 5 | from ..utils import create_norm, create_activation 6 | import numpy as np 7 | from ..utils.sparse import sparse_normalize, sparse_tpm 8 | 9 | class OmicsEmbedder(nn.Module): 10 | def __init__(self, pretrained_gene_list, num_hid, gene_emb=None, fix_embedding=False): 11 | super().__init__() 12 | self.pretrained_gene_list = pretrained_gene_list 13 | self.gene_index = dict(zip(pretrained_gene_list, list(range(len(pretrained_gene_list))))) 14 | self.num_hid = num_hid 15 | 16 | if gene_emb is not None: 17 | self.emb = nn.Parameter(gene_emb, requires_grad=not fix_embedding) 18 | else: 19 | self.emb = nn.Parameter(torch.randn([len(pretrained_gene_list), num_hid], dtype=torch.float32)*0.005) 20 | if fix_embedding: 21 | self.emb.requires_grad = False 22 | 23 | def forward(self, x_dict, input_gene_list=None): 24 | if 'masked_x_seq' in x_dict: 25 | x = x_dict['masked_x_seq'] 26 | else: 27 | x = x_dict['x_seq'] 28 | 29 | if 'dropout' in x_dict: 30 | indices = x._indices().t() 31 | values = x._values() 32 | temp = values.sum() 33 | values = values.float() 34 | values = torch.distributions.binomial.Binomial(values, x_dict['dropout']).sample() 35 | x = torch.sparse.FloatTensor(indices.t(), values, x.shape) 36 | 37 | x = torch.log1p(x) 38 | # x = sparse_tpm(x) 39 | if input_gene_list is not None: 40 | gene_idx = torch.tensor([self.gene_index[o] for o in input_gene_list if o in self.gene_index]).long() 41 | x_dict['input_gene_mask'] = gene_idx 42 | else: 43 | if x.shape[1] != len(self.pretrained_gene_list): 44 | raise ValueError('The input gene size is not the same as the pretrained gene list. Please provide the input gene list.') 45 | gene_idx = torch.arange(x.shape[1]).long() 46 | gene_idx = gene_idx.to(x.device) 47 | feat = F.embedding(gene_idx, self.emb) 48 | feat = torch.sparse.mm(x, feat) 49 | return feat 50 | 51 | class OmicsEmbeddingLayer(nn.Module): 52 | def __init__(self, gene_list, num_hidden, norm, activation='gelu', dropout=0.3, pe_type=None, cat_pe=True, gene_emb=None, 53 | inject_covariate=False, batch_num=None): 54 | super().__init__() 55 | 56 | self.pe_type = pe_type 57 | self.cat_pe = cat_pe 58 | self.act = nn.ReLU()#create_activation(activation) 59 | self.norm0 = create_norm(norm, num_hidden) 60 | self.dropout = nn.Dropout(dropout) 61 | self.extra_linear = nn.Sequential( 62 | nn.Linear(num_hidden, num_hidden), 63 | nn.ReLU(), 64 | nn.Dropout(dropout), 65 | create_norm(norm, num_hidden), 66 | ) 67 | if pe_type is not None: 68 | if cat_pe: 69 | num_emb = num_hidden // 2 70 | else: 71 | num_emb = num_hidden 72 | self.pe_enc = select_pe_encoder(pe_type)(num_emb) 73 | else: 74 | self.pe_enc = None 75 | num_emb = num_hidden 76 | 77 | if gene_emb is None: 78 | self.feat_enc = OmicsEmbedder(gene_list, num_emb) 79 | else: 80 | self.feat_enc = OmicsEmbedder(gene_list, num_emb, gene_emb) 81 | 82 | if inject_covariate: 83 | self.cov_enc = nn.Embedding(batch_num, num_emb) 84 | self.inject_covariate = True 85 | else: 86 | self.inject_covariate = False 87 | 88 | def forward(self, x_dict, input_gene_list=None): 89 | x = self.feat_enc(x_dict, input_gene_list)#self.act(self.feat_enc(x_dict, input_gene_list)) 90 | if self.pe_enc is not None: 91 | pe_input = x_dict[self.pe_enc.pe_key] 92 | pe = 0.#self.pe_enc(pe_input) 93 | if self.inject_covariate: 94 | pe = pe + self.cov_enc(x_dict['batch']) 95 | if self.cat_pe: 96 | x = torch.cat([x, pe], 1) 97 | else: 98 | x = x + pe 99 | x = self.extra_linear(x) 100 | # x = self.norm0(self.dropout(x)) 101 | return x 102 | -------------------------------------------------------------------------------- /CellPLM/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import TransformerEncoder 2 | from .mlp import ResMLPEncoder, MLPEncoder 3 | from torch import nn 4 | 5 | def setup_encoder(model_type, num_hidden, num_layers, dropout, activation, norm, nhead, covariates_dim=0) -> nn.Module: 6 | if model_type in ["performer", "cosformer", "transformer", "flowformer"]: 7 | mod = TransformerEncoder( 8 | num_hidden = num_hidden, 9 | nhead = nhead, 10 | num_layers = num_layers, 11 | dropout = dropout, 12 | activation = activation, 13 | # norm = norm, 14 | model_type = model_type, 15 | covariates_dim = covariates_dim, 16 | ) 17 | elif model_type == 'mlp': 18 | mod = MLPEncoder( 19 | num_hidden=num_hidden, 20 | num_layers=num_layers, 21 | dropout=dropout, 22 | norm=norm, 23 | covariates_dim=covariates_dim, 24 | ) 25 | elif model_type == "resmlp": 26 | mod = ResMLPEncoder( 27 | num_hidden=num_hidden, 28 | num_layers=num_layers, 29 | dropout=dropout, 30 | norm=norm, 31 | covariates_dim=covariates_dim, 32 | ) 33 | elif model_type == 'none': 34 | mod = NullEncoder() 35 | else: 36 | raise NotImplementedError(f'Unsupported model type: {model_type}') 37 | return mod 38 | 39 | class NullEncoder(nn.Module): 40 | def __init__(self, **kwargs): 41 | super().__init__() 42 | 43 | def forward(self, x_dict): 44 | x = x_dict['h'] 45 | return {'hidden': x} -------------------------------------------------------------------------------- /CellPLM/encoder/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ..utils import create_norm 4 | 5 | class ResMLPEncoder(nn.Module): 6 | def __init__(self, num_hidden, num_layers, dropout, norm, covariates_dim=0): 7 | super().__init__() 8 | self.layers = nn.ModuleList() 9 | assert num_layers > 1, 'At least two layer for MLPs.' 10 | for i in range(num_layers - 1): 11 | self.layers.append(nn.Sequential( 12 | nn.Linear(num_hidden, num_hidden), 13 | nn.PReLU(), 14 | nn.Dropout(dropout), 15 | create_norm(norm, num_hidden) 16 | )) 17 | self.out_layer = nn.Sequential( 18 | nn.Linear(num_hidden * (num_layers - 1), num_hidden), 19 | nn.PReLU(), 20 | nn.Dropout(dropout), 21 | create_norm(norm, num_hidden) 22 | ) 23 | 24 | def forward(self, x_dict): 25 | hist = [] 26 | x = x_dict['h'] 27 | for layer in self.layers: 28 | x = layer(x) 29 | hist.append(x) 30 | return self.out_layer(torch.cat(hist, 1)) 31 | 32 | class MLPEncoder(nn.Module): 33 | def __init__(self, num_hidden, num_layers, dropout, norm, covariates_dim=0): 34 | super().__init__() 35 | self.layers = nn.ModuleList() 36 | # assert num_layers > 1, 'At least two layer for MLPs.' 37 | for i in range(num_layers): 38 | self.layers.append(nn.Sequential( 39 | nn.Linear(num_hidden, num_hidden), 40 | nn.PReLU(), 41 | nn.Dropout(dropout), 42 | # nn.BatchNorm1d(num_hidden), 43 | create_norm(norm, num_hidden) 44 | )) 45 | 46 | def forward(self, x_dict): 47 | x = x_dict['h'] 48 | for layer in self.layers: 49 | x = x + layer(x) 50 | return {'hidden': x} -------------------------------------------------------------------------------- /CellPLM/encoder/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ..utils import create_activation 4 | from ..layer import CosformerLayer, PerformerLayer, VanillaTransformerLayer, FlowformerLayer 5 | from ..utils.pe import select_pe_encoder 6 | 7 | class TransformerEncoder(nn.Module): 8 | def __init__(self, 9 | num_hidden, 10 | nhead, 11 | num_layers, 12 | dropout, 13 | activation, 14 | norm=None, 15 | model_type='performer', 16 | covariates_dim=0, 17 | ): 18 | super(TransformerEncoder, self).__init__() 19 | self.num_layers = num_layers 20 | 21 | self.layers = nn.ModuleList() 22 | if model_type == 'cosformer': 23 | TransformerLayer = CosformerLayer 24 | elif model_type == 'performer': 25 | TransformerLayer = PerformerLayer 26 | elif model_type == 'transformer': 27 | TransformerLayer = VanillaTransformerLayer 28 | elif model_type == 'flowformer': 29 | TransformerLayer = FlowformerLayer 30 | else: 31 | raise NotImplementedError(f'Not implemented transformer type: {model_type}') 32 | 33 | for i in range(num_layers): 34 | self.layers.append( 35 | TransformerLayer( 36 | embed_dim=num_hidden, num_heads=nhead, 37 | dropout=dropout) 38 | ) 39 | 40 | def forward(self, x_dict, output_attentions=False): 41 | h = x_dict['h'] 42 | att_list = [] 43 | for l in range(self.num_layers): 44 | 45 | if l == 0: 46 | x_dict['base0'] = h.detach() 47 | if output_attentions: 48 | h, att = self.layers[l](h, output_attentions=True) 49 | att_list.append(att) 50 | else: 51 | h = self.layers[l](h) 52 | if l == 0: 53 | x_dict['base1'] = h.detach() 54 | 55 | if output_attentions: 56 | return {'hidden': h, 'attn': att_list} 57 | else: 58 | return {'hidden': h} 59 | -------------------------------------------------------------------------------- /CellPLM/head/__init__.py: -------------------------------------------------------------------------------- 1 | from .downstream import AnnotationHead, DenoisingHead, PerturbationPredictionHead, PatientClassificationHead, EmbedderHead, ImputationHead 2 | # from .spatial import 3 | from torch import nn 4 | 5 | def setup_head(head_type, in_dim, hidden_dim, out_dim, num_layers, dropout, norm, batch_num) -> nn.Module: 6 | if head_type == 'annotation': 7 | mod = AnnotationHead( 8 | in_dim=in_dim, 9 | hidden_dim=hidden_dim, 10 | num_classes=out_dim, 11 | num_layers=num_layers, 12 | dropout=dropout, 13 | norm=norm, 14 | batch_num=batch_num, 15 | ) 16 | elif head_type == 'denoising': 17 | mod = DenoisingHead( 18 | in_dim=in_dim, 19 | hidden_dim=hidden_dim, 20 | out_dim=out_dim, 21 | num_layers=num_layers, 22 | dropout=dropout, 23 | norm=norm, 24 | batch_num=batch_num, 25 | ) 26 | elif head_type == 'perturbation_prediction': 27 | mod = PerturbationPredictionHead( 28 | in_dim=in_dim, 29 | hidden_dim=hidden_dim, 30 | out_dim=out_dim, 31 | num_layers=num_layers, 32 | dropout=dropout, 33 | norm=norm, 34 | batch_num=batch_num, 35 | ) 36 | elif head_type == 'patient_classification': 37 | mod = PatientClassificationHead( 38 | in_dim=in_dim, 39 | hidden_dim=hidden_dim, 40 | num_classes=out_dim, 41 | num_layers=num_layers, 42 | dropout=dropout, 43 | norm=norm, 44 | batch_num=batch_num, 45 | ) 46 | elif head_type == 'imputation': 47 | mod = ImputationHead( 48 | in_dim=in_dim, 49 | hidden_dim=hidden_dim, 50 | out_dim=out_dim, 51 | num_layers=num_layers, 52 | dropout=dropout, 53 | norm=norm, 54 | batch_num=batch_num, 55 | ) 56 | elif head_type == 'embedder': 57 | mod = EmbedderHead( 58 | in_dim=in_dim, 59 | hidden_dim=hidden_dim, 60 | out_dim=out_dim, 61 | num_layers=num_layers, 62 | dropout=dropout, 63 | norm=norm, 64 | batch_num=batch_num, 65 | ) 66 | else: 67 | raise NotImplementedError(f'Unsupported model type: {head_type}') 68 | return mod -------------------------------------------------------------------------------- /CellPLM/head/downstream.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from tqdm import tqdm 6 | from ..decoder import MLPDecoder, ResMLPDecoder 7 | from ..latent import GMVAELatentLayer 8 | from ..objective import ReconstructionLoss 9 | from ..utils.data import XDict 10 | from ..encoder.transformer import TransformerEncoder 11 | 12 | def buildNetwork(layers, dropouts, activation=nn.ReLU()): 13 | net = [] 14 | for i in range(1, len(layers)): 15 | if dropouts[i-1] > 0: 16 | net.append(nn.Dropout(dropouts[i-1])) 17 | net.append(nn.Linear(layers[i - 1], layers[i])) 18 | if i < len(layers) - 1: 19 | net.append(activation) 20 | net = nn.Sequential(*net) 21 | return net 22 | 23 | class AnnotationHead(nn.Module): 24 | def __init__(self, in_dim, hidden_dim, num_classes, num_layers, dropout, norm, batch_num, **kwargs): 25 | super().__init__() 26 | self.ce_loss = nn.CrossEntropyLoss() 27 | layers = [in_dim] + [hidden_dim] * (num_layers - 1) + [num_classes] 28 | dropouts = [dropout] * len(layers) 29 | self.mlp = buildNetwork(layers, dropouts) 30 | 31 | def forward(self, x_dict): 32 | logits = self.mlp(x_dict['h'][x_dict['loss_mask']]) 33 | pred = logits.argmax(1) 34 | if 'label' in x_dict: 35 | y = x_dict['label'][x_dict['loss_mask']].long() 36 | loss = self.ce_loss(logits, y) 37 | return {'pred': pred, 'latent': x_dict['h'], 'label': y}, loss 38 | else: 39 | return {'pred': pred, 'latent': x_dict['h']}, torch.tensor(float('nan')) 40 | 41 | class PatientClassificationHead(nn.Module): 42 | def __init__(self, in_dim, hidden_dim, num_classes, num_layers, dropout, norm=None, batch_num=None, **kwargs): 43 | super().__init__() 44 | 45 | self.ce_loss = nn.CrossEntropyLoss() 46 | self.cls = nn.Parameter(torch.randn((1, in_dim)) * 0.01) 47 | self.output_layer = nn.Linear(in_dim, num_classes) 48 | 49 | def classify(self, x_dict): 50 | 51 | return self.output_layer(torch.mean(x_dict['h'], 0, keepdim=True)) 52 | 53 | def forward(self, x_dict): 54 | prob = self.classify(x_dict) 55 | pred = prob.argmax(1) 56 | y = x_dict['label'].long() 57 | loss = self.ce_loss(prob, y) 58 | return {'pred': pred, 'latent': x_dict['h']}, loss 59 | 60 | class DenoisingHead(nn.Module): 61 | def __init__(self, in_dim, hidden_dim, out_dim, num_layers, dropout, norm, batch_num, lib_size=1e4, 62 | log_norm=True, **kwargs): 63 | super().__init__() 64 | self.mse_loss = nn.MSELoss() 65 | self.lib_size = lib_size 66 | self.log_norm = log_norm 67 | layers = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim] 68 | dropouts = [dropout] * len(layers) 69 | self.mlp = buildNetwork(layers, dropouts) 70 | 71 | def forward(self, x_dict): 72 | pred = self.mlp(x_dict['h']) #* x_dict['input_mask'] 73 | if self.training: 74 | y = x_dict['x_seq'].to_dense() 75 | if self.lib_size is not None: 76 | y = y/y.sum(1)[:, None] * self.lib_size 77 | if self.log_norm: 78 | y = torch.log(y+1) 79 | loss = self.mse_loss(pred * x_dict['input_mask'], y * x_dict['input_mask']) + self.mse_loss(pred, y) 80 | else: 81 | loss = torch.zeros(1) 82 | return {'pred': pred, 'latent': x_dict['h']}, loss 83 | 84 | 85 | class EmbedderHead(nn.Module): 86 | def __init__(self, in_dim=None, hidden_dim=None, out_dim=None, num_layers=None, 87 | dropout=None, norm=None, batch_num=None, lib_size=None, 88 | log_norm=False, **kwargs): 89 | super().__init__() 90 | 91 | def forward(self, x_dict): 92 | pred = x_dict['h'] 93 | return {'pred': pred, 'latent': x_dict['h']}, torch.tensor(0.).to(x_dict['h'].device) 94 | 95 | class ImputationHead(nn.Module): 96 | def __init__(self, in_dim, hidden_dim, out_dim, num_layers, dropout, norm, batch_num, **kwargs): 97 | super().__init__() 98 | self.mse_loss = nn.MSELoss() 99 | # self.mse_loss = lambda x, y: torch.mean(((x-y) * (y/5+1))**2) 100 | layers = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim] 101 | dropouts = [dropout] * len(layers) 102 | self.mlp = buildNetwork(layers, dropouts) 103 | 104 | def forward(self, x_dict): 105 | pred = self.mlp(x_dict['h'])[:, x_dict['gene_mask']] 106 | y = x_dict['label'][:, x_dict['gene_mask']] 107 | loss = self.mse_loss(pred, y) 108 | return {'pred': pred, 'latent': x_dict['h']}, loss 109 | 110 | class PerturbationPredictionHead(nn.Module): 111 | def __init__(self, in_dim, hidden_dim, out_dim, num_layers, dropout, norm, batch_num, lib_size=None, 112 | log_norm=False, **kwargs): 113 | super().__init__() 114 | self.mse_loss = nn.MSELoss() 115 | self.lib_size = lib_size 116 | self.log_norm = log_norm 117 | layers = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim] 118 | dropouts = [dropout] * len(layers) 119 | self.mlp = buildNetwork(layers, dropouts) 120 | 121 | def forward(self, x_dict): 122 | pred = self.mlp(x_dict['h']) * x_dict['input_mask'] 123 | y = x_dict['label'].to_dense() 124 | if self.lib_size is not None: 125 | y = y/y.sum(1)[:, None] * self.lib_size 126 | if self.log_norm: 127 | y = torch.log1p(y) 128 | loss = self.mse_loss(pred, y * x_dict['input_mask']) 129 | return {'pred': pred, 'latent': x_dict['h']}, loss 130 | 131 | def get_normalized_expression(model, seq_list, batch_list, coord_list=None, device='cuda', 132 | transform_batch=None, library_size=None, n_samples=1, return_mean=True): 133 | transform_batch = range(len(seq_list)) if transform_batch is None else transform_batch 134 | exprs = [] 135 | for i in tqdm(range(len(seq_list))): 136 | input_dict = {'x_seq': seq_list[i].to(device)} 137 | if coord_list is not None: 138 | input_dict['coord'] = coord_list[i].to(device) 139 | x_dict = XDict(input_dict) 140 | per_batch_exprs = [] 141 | for batch in transform_batch: 142 | per_sample_exprs = [] 143 | input_dict['batch'] = batch * torch.ones(len(seq_list[i])).float().to(device) 144 | for sample in range(len(n_samples)): 145 | out_dict, _ = model(x_dict) 146 | output = out_dict['pred'] 147 | if library_size is not None: 148 | output = output / torch.sum(output, 1, keepdim=True) * library_size 149 | output = output.cpu().numpy() 150 | per_sample_exprs.append(output) 151 | per_batch_exprs.append(np.stack(per_sample_exprs)) 152 | per_batch_exprs = np.stack(per_batch_exprs, axis=1) 153 | exprs.append(per_batch_exprs.mean(1)) 154 | 155 | if n_samples > 1: 156 | # The -2 axis correspond to cells. 157 | exprs = np.concatenate(exprs, axis=-2) 158 | else: 159 | exprs = np.concatenate(exprs, axis=0) 160 | if n_samples > 1 and return_mean: 161 | exprs = exprs.mean(0) 162 | 163 | return exprs 164 | -------------------------------------------------------------------------------- /CellPLM/latent/__init__.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from .autoencoders import VAELatentLayer, SplitLatentLayer, MergeLatentLayer, GMVAELatentLayer 3 | from .adversarial import AdversarialLatentLayer 4 | from .contrastive import ECSLatentLayer 5 | import torch 6 | import logging 7 | from abc import ABC 8 | from abc import abstractmethod 9 | from ..utils import DSBNNorm 10 | 11 | # class AbstractLatentLayer(ABC):s 12 | # def __init__(self): 13 | # self.is_adverserial = True 14 | # 15 | # @abstractmethod 16 | # def forward(self, h, g): 17 | # pass 18 | 19 | def create_latent_layer(**config) -> nn.Module: 20 | if config['type'] == 'adversarial': 21 | return AdversarialLatentLayer(**config) 22 | elif config['type'] == 'vae': 23 | return VAELatentLayer(**config) 24 | elif config['type'] == 'split': 25 | return SplitLatentLayer(**config) 26 | elif config['type'] == 'merge': 27 | return MergeLatentLayer(**config) 28 | elif config['type'] == 'gmvae': 29 | return GMVAELatentLayer(**config) 30 | elif config['type'] == 'vqvae': 31 | return VQVAELatentLayer(**config) 32 | elif config['type'] == 'ecs': 33 | return ECSLatentLayer(**config) 34 | else: 35 | raise ValueError(f"Unrecognized latent model name: {config['type']}") 36 | 37 | class PlaceholderLayer(nn.Module): 38 | def __init__(self, **kwargs): 39 | super().__init__() 40 | self.is_adversarial = False 41 | 42 | def forward(self, x_dict): 43 | return x_dict['h'], torch.tensor(0.).to(x_dict['h'].device) 44 | 45 | class LatentModel(nn.Module): 46 | def __init__(self, configs=None): 47 | super().__init__() 48 | self.layers = nn.ModuleList([PlaceholderLayer()]) 49 | self.alias_dict = {} 50 | if configs is not None: 51 | for c in configs: 52 | self.layers.append(create_latent_layer(**c)) 53 | 54 | def forward(self, x_dict): 55 | total_loss = 0 56 | for layer in self.layers: 57 | x_dict['h'], loss = layer(x_dict) 58 | total_loss += loss 59 | return x_dict['h'], total_loss 60 | 61 | def add_layer(self, **config): 62 | if 'alias' in config: 63 | self.alias_dict[config['alias']] = len(self.layers) 64 | else: 65 | self.alias_dict[config['type']] = len(self.layers) 66 | self.layers.append(create_latent_layer(**config)) 67 | 68 | def get_layer(self, alias): 69 | return self.layers[self.alias_dict[alias]] 70 | 71 | def d_train(self, x_dict): 72 | loss = 0 73 | for layer in self.layers: 74 | if layer.is_adversarial: 75 | loss += layer.d_iter(x_dict) 76 | return loss 77 | 78 | class PreLatentNorm(nn.Module): 79 | def __init__(self, type='none', enc_hid=None, dataset_num=None): 80 | super().__init__() 81 | self.type = type 82 | if type not in ['none', 'dsbn', 'ln']: 83 | raise NotImplementedError(f'"{type}" type of pre latent norm is not implemented.') 84 | if type == 'dsbn': 85 | self.norm = DSBNNorm(enc_hid, dataset_num) 86 | elif type == 'ln': 87 | self.norm = nn.LayerNorm(enc_hid) 88 | 89 | def forward(self, xdict): 90 | if self.type == 'dsbn': 91 | return self.norm(xdict) 92 | elif self.type == 'ln': 93 | return self.norm(xdict['h']) 94 | else: 95 | return xdict['h'] 96 | -------------------------------------------------------------------------------- /CellPLM/latent/adversarial.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | class BatchDiscriminator(nn.Module): 6 | def __init__(self, input_dim, hidden_dim, hidden_layers, dropout, target_classes): 7 | super(BatchDiscriminator, self).__init__() 8 | self.layers = nn.ModuleList() 9 | for i in range(hidden_layers - 1): 10 | self.layers.append(nn.Sequential( 11 | nn.Linear(input_dim, hidden_dim), 12 | nn.PReLU(), 13 | nn.Dropout(dropout), 14 | )) 15 | input_dim = hidden_dim 16 | self.layers.append(nn.Linear(input_dim, target_classes)) 17 | self.layers.cuda() 18 | 19 | def forward(self, h): 20 | rp = h.shape[0] 21 | h = h.mean(dim=0) 22 | for layer in self.layers: 23 | h = layer(h) 24 | return h.repeat(rp, 1) 25 | 26 | class Discriminator(nn.Module): 27 | def __init__(self, input_dim, hidden_dim, hidden_layers, dropout, target_classes): 28 | super(Discriminator, self).__init__() 29 | self.layers = nn.ModuleList() 30 | for i in range(hidden_layers - 1): 31 | self.layers.append(nn.Sequential( 32 | nn.Linear(input_dim, hidden_dim), 33 | nn.PReLU(), 34 | nn.Dropout(dropout), 35 | nn.LayerNorm(hidden_dim), 36 | )) 37 | input_dim = hidden_dim 38 | self.layers.append(nn.Linear(input_dim, target_classes)) 39 | 40 | def forward(self, h): 41 | for layer in self.layers: 42 | h = layer(h) 43 | return h 44 | 45 | class AdversarialLatentLayer(nn.Module): 46 | """Adversarial latent layer for CellBert 47 | 48 | Parameters 49 | ---------- 50 | input_dims : Iterable[int] 51 | List of input dimensions 52 | label_key : str 53 | Key of the label in the input dictionary 54 | batch_wise : bool 55 | Whether to use batch-wise discriminator 56 | discriminator_hidden : int 57 | Hidden dimension of the discriminator 58 | discriminator_layers : int 59 | Number of layers in the discriminator 60 | discriminator_dropout : float 61 | Dropout rate of the discriminator 62 | target_classes : int 63 | Number of classes in the discriminator 64 | disc_lr : float 65 | Learning rate of the discriminator 66 | disc_wd : float 67 | Weight decay of the discriminator 68 | """ 69 | def __init__(self, input_dims, label_key, batch_wise=False, discriminator_hidden=128, discriminator_layers=2, discriminator_dropout=0.1, 70 | target_classes=2, disc_lr=1e-3, disc_wd=1e-6, **kwargs): 71 | super().__init__() 72 | self.source_dims = input_dims 73 | self.label_keys = label_key 74 | num_src_dim = len(input_dims) 75 | if batch_wise: 76 | self.discriminator = BatchDiscriminator(num_src_dim, discriminator_hidden, discriminator_layers, 77 | discriminator_dropout, target_classes) 78 | else: 79 | self.discriminator = Discriminator(num_src_dim, discriminator_hidden, discriminator_layers, 80 | discriminator_dropout, target_classes) 81 | self.is_adversarial = True 82 | self.set_d_optimizer(disc_lr, disc_wd) 83 | self.d_loss = 0 84 | self.trained = False 85 | 86 | def forward(self, x_dict): 87 | if self.training and self.trained: 88 | h = x_dict['h'] 89 | y = self.discriminator(h[:, self.source_dims]) 90 | return h, -F.cross_entropy(y, x_dict[self.label_keys]) 91 | else: 92 | return x_dict['h'], 0 93 | 94 | def set_d_optimizer(self, lr=1e-3, wd=1e-6): 95 | self.d_optimizer = torch.optim.SGD(self.discriminator.parameters(), lr=lr, weight_decay=wd) 96 | 97 | def d_step(self): 98 | self.d_optimizer.step() 99 | self.d_optimizer.zero_grad() 100 | 101 | def d_iter(self, x_dict): 102 | y_nograd = self.discriminator(x_dict['h'][:, self.source_dims].detach()) 103 | d_loss = F.cross_entropy(y_nograd, x_dict[self.label_keys]) 104 | 105 | self.d_optimizer.zero_grad() 106 | d_loss.backward() 107 | self.d_optimizer.step() 108 | self.d_optimizer.zero_grad() 109 | self.trained = True 110 | return d_loss.item() 111 | -------------------------------------------------------------------------------- /CellPLM/latent/autoencoders.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from sklearn import mixture 7 | from ..decoder import MLPDecoder 8 | 9 | class SplitLatentLayer(nn.Module): 10 | def __init__(self, enc_hid, latent_dim=None, conti_dim=None, cat_dim=None, cont_l2_reg=0.01, cont_l1_reg=0.01, **kwargs): 11 | super().__init__() 12 | if conti_dim is None and cat_dim is None: 13 | assert latent_dim is not None, 'Latent dimension not specified!' 14 | self.hid_2lat = nn.Sequential( 15 | nn.Linear(enc_hid, latent_dim), 16 | nn.GELU(), 17 | ) 18 | else: 19 | if conti_dim is not None and cat_dim is not None: 20 | if latent_dim is None and conti_dim + cat_dim != latent_dim: 21 | logging.warning("latent_dim is ignored, since conti_dim and cat_dim are given.") 22 | elif cat_dim is None: 23 | conti_dim = latent_dim - cat_dim 24 | else: 25 | cat_dim = latent_dim - conti_dim 26 | 27 | latent_dim = None 28 | self.hid_2cont = nn.Sequential( 29 | nn.Linear(enc_hid, conti_dim), 30 | nn.GELU(), 31 | ) 32 | self.hid_2cat = nn.Sequential( 33 | nn.Linear(enc_hid, cat_dim), 34 | nn.Softmax(1), 35 | ) 36 | 37 | self.latent_dim = latent_dim 38 | self.conti_dim = conti_dim 39 | self.cat_dim = cat_dim 40 | self.is_adversarial = False 41 | self.cont_l1_reg = cont_l1_reg 42 | self.cont_l2_reg = cont_l2_reg 43 | 44 | def forward(self, x_dict=None): 45 | h = x_dict['h'] 46 | if self.latent_dim is not None: 47 | h = self.hid_2lat(h) 48 | loss = 0 49 | else: 50 | h = torch.cat([self.hid_2cont(h), self.hid_2cat(h)], 1) 51 | params = torch.cat([x.view(-1) for x in self.hid_2cont.parameters()]) 52 | loss = self.cont_l1_reg * torch.norm(params, 1) + self.cont_l2_reg * torch.norm(params, 2) 53 | return h, loss 54 | 55 | class MergeLatentLayer(nn.Module): 56 | """ 57 | Merge discrete and continuous dimensions to a new continious latent space 58 | """ 59 | def __init__(self, conti_dim, cat_dim, post_latent_dim, **kwargs): 60 | super().__init__() 61 | 62 | self.lat_2lat = nn.Sequential( 63 | nn.Linear(conti_dim + cat_dim, post_latent_dim), 64 | # nn.ReLU(), 65 | ) 66 | self.post_latent_dim = post_latent_dim 67 | self.conti_dim = conti_dim 68 | self.cat_dim = cat_dim 69 | self.is_adversarial = False 70 | 71 | def forward(self, x_dict): 72 | h = x_dict['h'] 73 | return self.lat_2lat(h), 0 74 | 75 | class VAELatentLayer(nn.Module): 76 | def __init__(self, enc_hid, latent_dim, kl_weight=1., warmup_step=10000, lamda=1.0, **kwargs):#400*160 77 | super().__init__() 78 | self.hid_2mu = nn.Linear(enc_hid, latent_dim)#, bias=False) 79 | self.hid_2sigma = nn.Linear(enc_hid, latent_dim)#, bias=False) 80 | self.kl_weight = 0#kl_weight 81 | self.max_kl_weight = kl_weight 82 | self.step_count = 0 83 | self.warmup_step = warmup_step 84 | self.is_adversarial = False 85 | self.lamda = lamda 86 | 87 | def kl_schedule_step(self): 88 | self.step_count += 1 89 | if self.step_count < self.warmup_step: 90 | self.kl_weight = self.kl_weight + self.max_kl_weight / self.warmup_step 91 | elif self.step_count == self.warmup_step: 92 | pass 93 | 94 | def forward(self, x_dict, var_eps=True): 95 | h = x_dict['h'] 96 | mu = self.hid_2mu(h) 97 | log_var = torch.clamp(self.hid_2sigma(h), -5, 5) #+ 1e-4 98 | if var_eps: 99 | sigma = (torch.exp(log_var) + 1e-4).sqrt() 100 | log_var = 2 * torch.log(sigma) 101 | else: 102 | sigma = torch.exp(0.5 * log_var) 103 | eps = torch.randn_like(sigma) 104 | 105 | if self.training: 106 | z = mu + sigma * eps 107 | kl_loss = -0.5 * (1 + log_var - mu ** 2 - log_var.exp()).sum(1).mean() * self.kl_weight 108 | if kl_loss < self.lamda: 109 | kl_loss = 0 110 | self.kl_schedule_step() 111 | else: 112 | z = mu 113 | kl_loss = 0 114 | return z, kl_loss 115 | 116 | ### Reference: https://github.com/jariasf/GMVAE/blob/master/pytorch/networks/Networks.py ### 117 | 118 | class GumbelSoftmax(nn.Module): 119 | 120 | def __init__(self, f_dim, c_dim): 121 | super(GumbelSoftmax, self).__init__() 122 | self.logits = nn.Linear(f_dim, c_dim) 123 | self.f_dim = f_dim 124 | self.c_dim = c_dim 125 | 126 | def sample_gumbel(self, logits, eps=1e-20): 127 | U = torch.rand_like(logits) 128 | return -torch.log(-torch.log(U + eps) + eps) 129 | 130 | def gumbel_softmax_sample(self, logits, temperature): 131 | y = logits + self.sample_gumbel(logits) 132 | return F.softmax(y / temperature, dim=-1) 133 | 134 | def gumbel_softmax(self, logits, temperature, hard=False): 135 | """ 136 | ST-gumple-softmax 137 | input: [*, n_class] 138 | return: flatten --> [*, n_class] an one-hot vector 139 | """ 140 | # categorical_dim = 10 141 | y = self.gumbel_softmax_sample(logits, temperature) 142 | 143 | if not hard: 144 | return y 145 | 146 | shape = y.size() 147 | _, ind = y.max(dim=-1) 148 | y_hard = torch.zeros_like(y).view(-1, shape[-1]) 149 | y_hard.scatter_(1, ind.view(-1, 1), 1) 150 | y_hard = y_hard.view(*shape) 151 | # Set gradients w.r.t. y_hard gradients w.r.t. y 152 | y_hard = (y_hard - y).detach() + y 153 | return y_hard 154 | 155 | def forward(self, x, temperature=1.0, hard=False): 156 | logits = self.logits(x).view(-1, self.c_dim) 157 | prob = F.softmax(logits, dim=-1) 158 | y = self.gumbel_softmax(logits, temperature, hard) 159 | return logits, prob, y 160 | 161 | 162 | # Sample from a Gaussian distribution 163 | class Gaussian(nn.Module): 164 | def __init__(self, in_dim, z_dim): 165 | super(Gaussian, self).__init__() 166 | self.mu = nn.Linear(in_dim, z_dim) 167 | self.var = nn.Linear(in_dim, z_dim) 168 | 169 | def reparameterize(self, mu, var): 170 | std = torch.sqrt(var + 1e-10) 171 | noise = torch.randn_like(std) 172 | z = mu + noise * std 173 | return z 174 | 175 | def forward(self, x): 176 | mu = self.mu(x) 177 | var = F.softplus(self.var(x)) 178 | z = self.reparameterize(mu, var) 179 | return mu, var, z 180 | 181 | class InferenceNet(nn.Module): 182 | def __init__(self, x_dim, z_dim, y_dim): 183 | super(InferenceNet, self).__init__() 184 | 185 | # q(y|x) 186 | self.inference_qyx = torch.nn.ModuleList([ 187 | GumbelSoftmax(x_dim, y_dim) 188 | ]) 189 | 190 | # q(z|y,x) 191 | self.inference_qzyx = torch.nn.ModuleList([ 192 | nn.Linear(x_dim + y_dim, 512), 193 | nn.ReLU(), 194 | Gaussian(512, z_dim) 195 | ]) 196 | 197 | self.y_mu = nn.Linear(y_dim, z_dim) 198 | self.y_var = nn.Linear(y_dim, z_dim) 199 | 200 | # q(y|x) 201 | def qyx(self, x, temperature, hard): 202 | num_layers = len(self.inference_qyx) 203 | for i, layer in enumerate(self.inference_qyx): 204 | if i == num_layers - 1: 205 | # last layer is gumbel softmax 206 | x = layer(x, temperature, hard) 207 | else: 208 | x = layer(x) 209 | return x 210 | 211 | # q(z|x,y) 212 | def qzxy(self, x, y): 213 | concat = torch.cat((x, y), dim=1) 214 | for layer in self.inference_qzyx: 215 | concat = layer(concat) 216 | return concat 217 | 218 | def pzy(self, y): 219 | y_mu = self.y_mu(y) 220 | y_var = F.softplus(self.y_var(y)) 221 | return y_mu, y_var 222 | 223 | def forward(self, x, temperature=1.0, hard=0): 224 | # x = Flatten(x) 225 | 226 | # q(y|x) 227 | logits, prob, y = self.qyx(x, temperature, hard) 228 | 229 | # q(z|x,y) 230 | mu, var, z = self.qzxy(x, y) 231 | 232 | y_mu, y_var = self.pzy(y) 233 | 234 | output = {'mean': mu, 'var': var, 'gaussian': z, 235 | 'logits': logits, 'prob_cat': prob, 'categorical': y, 236 | 'y_mean': y_mu, 'y_var': y_var} 237 | return output 238 | 239 | class GMVAELatentLayer(nn.Module): 240 | def __init__(self, enc_hid, latent_dim, num_clusters, hard=False, 241 | w_li=1., w_en=1., lamda=0.5, **kwargs): 242 | super(GMVAELatentLayer, self).__init__() 243 | 244 | self.hard = hard 245 | self.inference = InferenceNet(enc_hid, latent_dim, num_clusters) 246 | self.w_li = w_li 247 | self.w_en = w_en 248 | self.lamda = lamda 249 | self.eps = 1e-8 250 | self.num_clusters = num_clusters 251 | self.is_adversarial = False 252 | 253 | def forward(self, x_dict, temperature=1.0): 254 | if self.training: 255 | out_dict = self.inference(x_dict['h'], temperature, self.hard) 256 | z = out_dict['gaussian'] 257 | loss = self.unlabeled_loss(out_dict) 258 | return z, loss 259 | else: 260 | out_dict = self.inference(x_dict['h'], temperature, True) 261 | z = out_dict['mean'] 262 | return z, self.unlabeled_loss(out_dict) 263 | 264 | def log_normal(self, x, mu, var): 265 | """Logarithm of normal distribution with mean=mu and variance=var 266 | log(x|μ, σ^2) = loss = -0.5 * Σ log(2π) + log(σ^2) + ((x - μ)/σ)^2 267 | 268 | Args: 269 | x: (array) corresponding array containing the input 270 | mu: (array) corresponding array containing the mean 271 | var: (array) corresponding array containing the variance 272 | 273 | Returns: 274 | output: (array/float) depending on average parameters the result will be the mean 275 | of all the sample losses or an array with the losses per sample 276 | """ 277 | if self.eps > 0.0: 278 | var = var + self.eps 279 | return -0.5 * torch.sum( 280 | np.log(2.0 * np.pi) + torch.log(var) + torch.pow(x - mu, 2) / var, dim=-1) 281 | 282 | def gaussian_loss(self, z, z_mu, z_var, z_mu_prior, z_var_prior): 283 | """Variational loss when using labeled data without considering reconstruction loss 284 | loss = log q(z|x,y) - log p(z) - log p(y) 285 | 286 | Args: 287 | z: (array) array containing the gaussian latent variable 288 | z_mu: (array) array containing the mean of the inference model 289 | z_var: (array) array containing the variance of the inference model 290 | z_mu_prior: (array) array containing the prior mean of the generative model 291 | z_var_prior: (array) array containing the prior variance of the generative mode 292 | 293 | Returns: 294 | output: (array/float) depending on average parameters the result will be the mean 295 | of all the sample losses or an array with the losses per sample 296 | """ 297 | loss = self.log_normal(z, z_mu, z_var) - self.log_normal(z, z_mu_prior, z_var_prior) 298 | return loss.mean() 299 | 300 | def entropy(self, logits, targets): 301 | """Entropy loss 302 | loss = (1/n) * -Σ targets*log(predicted) 303 | 304 | Args: 305 | logits: (array) corresponding array containing the logits of the categorical variable 306 | real: (array) corresponding array containing the true labels 307 | 308 | Returns: 309 | output: (array/float) depending on average parameters the result will be the mean 310 | of all the sample losses or an array with the losses per sample 311 | """ 312 | log_q = F.log_softmax(logits, dim=-1) 313 | return -torch.mean(torch.sum(targets * log_q, dim=-1)) 314 | 315 | def unlabeled_loss(self, out_net): 316 | """Method defining the loss functions derived from the variational lower bound 317 | Args: 318 | data: (array) corresponding array containing the input data 319 | out_net: (dict) contains the graph operations or nodes of the network output 320 | 321 | Returns: 322 | loss_dic: (dict) contains the values of each loss function and predictions 323 | """ 324 | # obtain network variables 325 | z = out_net['gaussian'] 326 | logits, prob_cat = out_net['logits'], out_net['prob_cat'] 327 | y_mu, y_var = out_net['y_mean'], out_net['y_var'] 328 | mu, var = out_net['mean'], out_net['var'] 329 | 330 | # gaussian loss 331 | loss_gauss = max(self.lamda, self.gaussian_loss(z, mu, var, y_mu, y_var)) 332 | 333 | # categorical loss 334 | loss_cat = max(self.lamda, -self.entropy(logits, prob_cat) - np.log(1/self.num_clusters)) 335 | 336 | # total loss 337 | loss_total = self.w_li * loss_gauss + self.w_en * loss_cat 338 | 339 | return loss_total -------------------------------------------------------------------------------- /CellPLM/latent/contrastive.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch.nn as nn 3 | import torch 4 | 5 | class ECSLatentLayer(nn.Module): 6 | def __init__(self, ecs_threshold=0.6, **kwargs): 7 | super().__init__() 8 | self.ecs_threshold = ecs_threshold 9 | self.is_adversarial = False 10 | 11 | def forward(self, x_dict): 12 | if self.training and 'ecs' in x_dict: 13 | cell_emb = x_dict['h'] 14 | cell_emb_normed = F.normalize(cell_emb, p=2, dim=1) 15 | cos_sim = torch.mm(cell_emb_normed, cell_emb_normed.t()) # (batch, batch) 16 | 17 | # mask out diagnal elements 18 | mask = torch.eye(cos_sim.size(0)).bool().to(cos_sim.device) 19 | cos_sim = cos_sim.masked_fill(mask, 0.0) 20 | # only optimize positive similarities 21 | cos_sim = F.relu(cos_sim) 22 | return cell_emb, torch.mean(1 - (cos_sim - self.ecs_threshold) ** 2) 23 | else: 24 | return x_dict['h'], 0 25 | 26 | -------------------------------------------------------------------------------- /CellPLM/layer/__init__.py: -------------------------------------------------------------------------------- 1 | from .cosformer import CosformerLayer 2 | from .performer import PerformerLayer 3 | from .transformer import VanillaTransformerLayer 4 | from .flowformer import FlowformerLayer -------------------------------------------------------------------------------- /CellPLM/layer/cosformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | from torch import Tensor 6 | from typing import Optional 7 | from torch import nn 8 | from ..utils import create_norm 9 | from .transformer import AbstractTrasnformerLayer 10 | 11 | class CosformerAttention(nn.Module): 12 | """ 13 | cosformer attention in "cosFormer: Rethinking Softmax In Attention" 14 | https://arxiv.org/abs/2202.08791 15 | """ 16 | 17 | def __init__( 18 | self, 19 | embed_dim, 20 | num_heads, 21 | kdim=None, 22 | vdim=None, 23 | dropout_rate=0.0, 24 | causal=False, 25 | has_outproj=True, 26 | act_fun="gelu", 27 | ): 28 | super().__init__() 29 | self.embed_dim = embed_dim 30 | self.kdim = kdim if kdim is not None else embed_dim 31 | self.vdim = vdim if kdim is not None else embed_dim 32 | self.num_heads = num_heads 33 | self.has_outproj = has_outproj 34 | self.act_fun = self.get_act_fun(act_fun) 35 | # q, k, v projection 36 | self.k_proj = nn.Linear(self.kdim, embed_dim) 37 | self.v_proj = nn.Linear(self.vdim, embed_dim) 38 | self.q_proj = nn.Linear(embed_dim, embed_dim) 39 | # outprojection 40 | self.out_proj = nn.Linear(embed_dim, embed_dim) 41 | # dropout rate 42 | self.dropout_rate = dropout_rate 43 | self.attn_dropout = nn.Dropout(dropout_rate) 44 | # causal 45 | self.causal = causal 46 | 47 | assert (self.embed_dim % self.num_heads == 0), "embed_dim must be divisible by num_heads" 48 | 49 | def get_index(self, seq_len): 50 | index = np.pi / 2 * torch.arange(1, seq_len + 1).reshape(1, -1, 1) 51 | 52 | return nn.Parameter(index, requires_grad=False) 53 | 54 | def get_act_fun(self, act_fun): 55 | if act_fun == "relu": 56 | return F.relu 57 | elif act_fun == "elu": 58 | return 1 + F.elu 59 | elif act_fun == "gelu": 60 | return F.gelu 61 | else: 62 | raise ValueError(f"Unrecognized activation function: {act_fun}.") 63 | 64 | def forward( 65 | self, 66 | query: Tensor, 67 | key: Optional[Tensor] = None, 68 | value: Optional[Tensor] = None, 69 | attn_mask: Optional[Tensor] = None, 70 | eps: Optional[float] = 1e-6, 71 | ): 72 | """Input shape: Sequence x Batch x Embedding 73 | Args: 74 | query (Tensor): `(L, N, E)` where L is the target sequence length, N is the batch size, 75 | E is the embedding dimension. 76 | key (Tensor): `(S, N, E)` where S is the source sequence length, N is the batch size, 77 | E is the embedding dimension. 78 | value (Tensor): `(S, N, E)` where S is the source sequence length, N is the batch size, 79 | E is the embedding dimension. 80 | attn_mask (Optional[Tensor], optional): typically used to implement causal attention, 81 | where the mask prevents the attention from looking forward in time (default: None). 82 | """ 83 | if key == None: 84 | key = query 85 | if value == None: 86 | value = query 87 | 88 | num_heads = self.num_heads 89 | tgt_len, bsz, embed_dim = query.size() 90 | src_len = key.size(0) 91 | head_dim = embed_dim // num_heads 92 | 93 | # get q, k, v 94 | # (L, N, E) 95 | q = self.q_proj(query) 96 | # (S, N, E) 97 | k = self.k_proj(key) 98 | # (S, N, E) 99 | v = self.v_proj(value) 100 | 101 | # activation 102 | q = self.act_fun(q) 103 | k = self.act_fun(k) 104 | 105 | # multihead reshape 106 | # (N * h, L, d) 107 | q = q.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 108 | # (N * h, S, d) 109 | k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 110 | # (N * h, S, d) 111 | v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 112 | 113 | # cos transform 114 | m = max(src_len, tgt_len) 115 | # get index and send to cuda 116 | weight_index = self.get_index(m).to(q) 117 | # (N * h, L, 2 * d) 118 | q_ = torch.cat( 119 | [q * torch.sin(weight_index[:, :tgt_len, :] / m), q * torch.cos(weight_index[:, :tgt_len, :] / m)], dim=-1) 120 | # (N * h, S, 2 * d) 121 | k_ = torch.cat( 122 | [k * torch.sin(weight_index[:, :src_len, :] / m), k * torch.cos(weight_index[:, :src_len, :] / m)], dim=-1) 123 | 124 | if self.causal: 125 | ## Need to improve speed! 126 | # (N * h, L, 2 * d) (N * h, L, d) -> (N * h, L, h, 2 * d, d) 127 | kv_ = torch.einsum("nld,nlm->nldm", k_, v) 128 | kv_ = self.attn_dropout(kv_) 129 | # (N * h, L, 2 * d, d) -> (N * h, L, 2 * d, d) 130 | kv_cum = torch.cumsum(kv_, dim=1) 131 | # (N * h, L, 2 * d) (N * h, L, 2 * d, d) -> (N * h, L, d) 132 | qkv = torch.einsum("nld,nldm->nlm", q_, kv_cum) 133 | # (N * h, L, 2 * d) -> (N * h, L, 2 * d) 134 | k_cum = torch.cumsum(k_, dim=1) 135 | # (N * h, L, 2 * d) (N * h, L, 2 * d) -> (N * h, L) 136 | denom = torch.clamp_min(torch.einsum("nlm,nlm->nl", q_, k_cum), eps) 137 | # (N * h, L, d) (N * h, L, 1) -> (N * h, L, d) 138 | attn_output = qkv / denom.unsqueeze(-1) 139 | # (N * h, L, d) -> (L, N * h, d) -> (L, N, E) 140 | attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, -1) 141 | else: 142 | # (N * h, L, 2 * d) (N * h, L, d) -> (N * h, 2 * d, d) 143 | kv_ = torch.einsum('nld,nlm->ndm', k_, v) 144 | kv_ = self.attn_dropout(kv_) 145 | # (N * h, L, 2 * d) (N * h, 2 * d) -> (N * h, L) 146 | z_ = 1 / torch.clamp_min(torch.einsum('nld,nd->nl', q_, torch.sum(k_, axis=1)), eps) 147 | # (N * h, L, 2 * d) (N * h, d, 2 * d) (N * h, L) -> (N * h, L, d) 148 | attn_output = torch.einsum('nld,ndm,nl->nlm', q_, kv_, z_) 149 | # (N * h, L, d) -> (L, N * h, d) -> (L, N, E) 150 | attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, -1) 151 | # L, N, E 152 | if self.has_outproj: 153 | attn_output = self.out_proj(attn_output) 154 | 155 | return attn_output 156 | 157 | def left_product( 158 | self, 159 | query: Tensor, 160 | key: Optional[Tensor] = None, 161 | value: Optional[Tensor] = None, 162 | attn_mask: Optional[Tensor] = None, 163 | eps: Optional[float] = 1e-6, 164 | ): 165 | """Input shape: Sequence x Batch x Embedding 166 | Args: 167 | query (Tensor): `(L, N, E)` where L is the target sequence length, N is the batch size, 168 | E is the embedding dimension. 169 | key (Tensor): `(S, N, E)` where S is the source sequence length, N is the batch size, 170 | E is the embedding dimension. 171 | value (Tensor): `(S, N, E)` where S is the source sequence length, N is the batch size, 172 | E is the embedding dimension. 173 | attn_mask (Optional[Tensor], optional): typically used to implement causal attention, 174 | where the mask prevents the attention from looking forward in time (default: None). 175 | """ 176 | # test for the correctness of the program 177 | if key == None: 178 | key = query 179 | if value == None: 180 | value = query 181 | 182 | num_heads = self.num_heads 183 | tgt_len, bsz, embed_dim = query.size() 184 | src_len = key.size(0) 185 | head_dim = embed_dim // num_heads 186 | 187 | # get q, k, v 188 | # (L, N, E) 189 | q = self.q_proj(query) 190 | # (S, N, E) 191 | k = self.k_proj(key) 192 | # (S, N, E) 193 | v = self.v_proj(value) 194 | 195 | # activation 196 | q = self.act_fun(q) 197 | k = self.act_fun(k) 198 | 199 | # multihead reshape 200 | # (N * h, L, d) 201 | q = q.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 202 | # (N * h, S, d) 203 | k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 204 | # (N * h, S, d) 205 | v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 206 | 207 | # cos transform 208 | m = max(src_len, tgt_len) 209 | # get index and send to cuda 210 | weight_index = self.get_index(m).to(q) 211 | # (N * h, L, 2 * d) 212 | q_ = torch.cat( 213 | [q * torch.sin(weight_index[:, :tgt_len, :] / m), q * torch.cos(weight_index[:, :tgt_len, :] / m)], dim=-1) 214 | # (N * h, S, 2 * d) 215 | k_ = torch.cat( 216 | [k * torch.sin(weight_index[:, :src_len, :] / m), k * torch.cos(weight_index[:, :src_len, :] / m)], dim=-1) 217 | 218 | # (N * h, L, d) (N * h, d, S) -> (N * h, L, S) 219 | weights = torch.bmm(q_, k_.transpose(1, 2)) 220 | # mask 221 | if self.causal: 222 | weights = weights.masked_fill(attn_mask == float("-inf"), 0) 223 | # (N * h, L, S) -> (N * h, L, S) 224 | denom = torch.clamp_min(weights.sum(dim=-1, keepdim=True), eps) 225 | # (N * h, L, S) (N * h, L, S) -> (N * h, L, S) 226 | attn_weights = weights / denom 227 | # (N * h, L, S) (N * h, S, d) -> (N * h, L, d) 228 | attn_output = torch.bmm(attn_weights, v) 229 | # (N * h, L, d) -> (L, N * h, d) -> (L, N, E) 230 | attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, -1) 231 | # L, N, E 232 | if self.has_outproj: 233 | attn_output = self.out_proj(attn_output) 234 | 235 | return attn_output 236 | 237 | class CosformerLayer(nn.Module, AbstractTrasnformerLayer): 238 | def __init__( 239 | self, 240 | embed_dim, 241 | num_heads, 242 | dropout = 0.0, 243 | norm = 'layernorm', 244 | norm_first: bool = True, 245 | causal=False, 246 | ): 247 | super().__init__() 248 | self.self_attn = CosformerAttention(embed_dim=embed_dim, num_heads=num_heads) 249 | self._ff_block = nn.Sequential(nn.Linear(embed_dim, embed_dim*2), 250 | nn.GELU(), 251 | nn.Dropout(dropout), 252 | nn.Linear(embed_dim*2, embed_dim), 253 | nn.Dropout(dropout), 254 | ) 255 | self.dropout1 = nn.Dropout(dropout) 256 | self.norm1 = create_norm(norm, embed_dim) 257 | self.norm2 = create_norm(norm, embed_dim) 258 | self.norm_first = norm_first 259 | self.support_output_attentions = False 260 | 261 | def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor]): 262 | x = x.unsqueeze(1) 263 | x = self.self_attn(x, attn_mask=attn_mask) 264 | return self.dropout1(x)[:, 0, :] 265 | 266 | def forward(self, x, attn_mask=None, output_attentions=False): 267 | assert output_attentions == False, 'output_attentions not implemented for Cosformer' 268 | if self.norm_first: 269 | x = x + self._sa_block(self.norm1(x),attn_mask) 270 | x = x + self._ff_block(self.norm2(x)) 271 | else: 272 | x = self.norm1(x + self._sa_block(x, attn_mask)) 273 | x = self.norm2(x + self._ff_block(x)) 274 | return x 275 | 276 | -------------------------------------------------------------------------------- /CellPLM/layer/flowformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .transformer import AbstractTrasnformerLayer 4 | from ..utils import create_norm 5 | 6 | ## Core code for Flow-Attention, Please refer to each folder for corresponding experiments 7 | 8 | class Flow_Attention(nn.Module): 9 | # flow attention in normal version 10 | def __init__(self, d_model, n_heads, drop_out=0.01, d_input=None, d_output=None, eps=1e-6): 11 | super(Flow_Attention, self).__init__() 12 | self.n_heads = n_heads 13 | if d_input is None: 14 | d_input = d_model 15 | if d_output is None: 16 | d_output = d_model 17 | self.query_projection = nn.Linear(d_input, d_model) 18 | self.key_projection = nn.Linear(d_input, d_model) 19 | self.value_projection = nn.Linear(d_input, d_model) 20 | self.out_projection = nn.Linear(d_model, d_output) 21 | self.dropout = nn.Dropout(drop_out) 22 | self.eps = eps 23 | 24 | def kernel_method(self, x): 25 | return torch.sigmoid(x) 26 | 27 | def dot_product(self, q, k, v): 28 | kv = torch.einsum("nhld,nhlm->nhdm", k, v) 29 | qkv = torch.einsum("nhld,nhdm->nhlm", q, kv) 30 | return qkv 31 | 32 | def forward(self, x): 33 | ## input: B (L or S) D; output: B L D 34 | ## Note: queries, keys, values are not projected yet 35 | ## 1. Linear projection 36 | queries = keys = values = x 37 | B, L, _ = queries.shape 38 | _, S, _ = keys.shape 39 | queries = self.query_projection(queries).view(B, L, self.n_heads, -1) 40 | keys = self.key_projection(keys).view(B, S, self.n_heads, -1) 41 | values = self.value_projection(values).view(B, S, self.n_heads, -1) 42 | queries = queries.transpose(1, 2) 43 | keys = keys.transpose(1, 2) 44 | values = values.transpose(1, 2) 45 | # 2. Non-negative projection 46 | queries = self.kernel_method(queries) 47 | keys = self.kernel_method(keys) 48 | ## 3. Flow-Attention 49 | # (1) Calculate incoming and outgoing flow 50 | sink_incoming = 1.0 / (torch.einsum("nhld,nhd->nhl", queries + self.eps, keys.sum(dim=2) + self.eps)) 51 | source_outgoing = 1.0 / (torch.einsum("nhld,nhd->nhl", keys + self.eps, queries.sum(dim=2) + self.eps)) 52 | # (2) conservation refine for source and sink 53 | conserved_sink = torch.einsum("nhld,nhd->nhl", queries + self.eps, 54 | (keys * source_outgoing[:, :, :, None]).sum(dim=2) + self.eps) 55 | conserved_source = torch.einsum("nhld,nhd->nhl", keys + self.eps, 56 | (queries * sink_incoming[:, :, :, None]).sum(dim=2) + self.eps) 57 | conserved_source = torch.clamp(conserved_source, min=-1.0, max=1.0) # for stability 58 | # (3) Competition & Allocation 59 | sink_allocation = torch.sigmoid(conserved_sink * (float(queries.shape[2]) / float(keys.shape[2]))) 60 | source_competition = torch.softmax(conserved_source, dim=-1) * float(keys.shape[2]) 61 | # (4) dot product 62 | x = (self.dot_product(queries * sink_incoming[:, :, :, None], # for value normalization 63 | keys, 64 | values * source_competition[:, :, :, None]) # competition 65 | * sink_allocation[:, :, :, None]).transpose(1, 2) # allocation 66 | ## (5) Final projection 67 | x = x.reshape(B, L, -1) 68 | x = self.out_projection(x) 69 | x = self.dropout(x) 70 | return x 71 | 72 | 73 | class Flow_Attention_Causal(nn.Module): 74 | # flow attention in causal version 75 | def __init__(self, d_model, n_heads, drop_out=0.05, d_input=None, d_output=None, eps=1e-6): 76 | super(Flow_Attention_Causal, self).__init__() 77 | self.n_heads = n_heads 78 | if d_input is None: 79 | d_input = d_model 80 | if d_output is None: 81 | d_output = d_model 82 | self.query_projection = nn.Linear(d_input, d_model) 83 | self.key_projection = nn.Linear(d_input, d_model) 84 | self.value_projection = nn.Linear(d_input, d_model) 85 | self.out_projection = nn.Linear(d_model, d_output) 86 | self.dropout = nn.Dropout(drop_out) 87 | self.eps = eps 88 | 89 | def kernel_method(self, x): 90 | return torch.sigmoid(x) 91 | 92 | def causal_dot_product(self, q, k, v): 93 | kv = torch.einsum("nhld,nhlm->nhldm", k, v) 94 | kv = torch.cumsum(kv, dim=2) 95 | qkv = torch.einsum("nhld,nhldm->nhlm", q, kv) 96 | return qkv 97 | 98 | def forward(self, x): 99 | ## input: B (L or S) D; output: B L D 100 | ## Note: queries, keys, values are not projected yet 101 | queries = keys= values = x 102 | ## 1. Linear projection 103 | B, L, _ = queries.shape 104 | _, S, _ = keys.shape 105 | queries = self.query_projection(queries).view(B, L, self.n_heads, -1) 106 | keys = self.key_projection(keys).view(B, S, self.n_heads, -1) 107 | values = self.value_projection(values).view(B, S, self.n_heads, -1) 108 | queries = queries.transpose(1, 2) 109 | keys = keys.transpose(1, 2) 110 | values = values.transpose(1, 2) 111 | # 2. Non-negative projection 112 | queries = self.kernel_method(queries) 113 | keys = self.kernel_method(keys) 114 | ## 3. Causal Flow-Attention 115 | # (1) Calculate incoming and outgoing flow 116 | sink_incoming = 1.0 / (torch.einsum("nhld,nhld->nhl", queries + self.eps, keys.cumsum(dim=2) + self.eps)) 117 | source_outgoing = 1.0 / (torch.einsum("nhld,nhld->nhl", keys + self.eps, queries.cumsum(dim=2) + self.eps)) 118 | # approximate normal conservation col and row by multiplying corresponding element number 119 | normal = (((torch.arange(queries.shape[2])).float() + 1.0)).to(queries.device)[None, None, :] 120 | sink_incoming = sink_incoming * normal 121 | source_outgoing = source_outgoing * normal 122 | # (2) conservation refine for source and sink 123 | conserved_sink = torch.einsum("nhld,nhld->nhl", queries + self.eps, 124 | (keys * source_outgoing[:, :, :, None]).cumsum(dim=2) + self.eps) / normal 125 | conserved_source = torch.einsum("nhld,nhld->nhl", keys + self.eps, 126 | (queries * sink_incoming[:, :, :, None]).cumsum( 127 | dim=2) + self.eps) / normal 128 | conserved_source = torch.clamp(conserved_source, min=-1.0, max=1.0) # for stability 129 | # (3) Competition & Allocation 130 | sink_allocation = torch.sigmoid(conserved_sink) 131 | conserved_source = torch.exp(conserved_source) 132 | source_competition = (conserved_source / conserved_source.cumsum(dim=-1)) * normal 133 | # (4) Causal dot product 134 | x = (self.causal_dot_product(queries * (sink_incoming[:, :, :, None] / normal[:, :, :, None]), # for value normalization 135 | keys, 136 | values * source_competition[:, :, :, None]) # competition 137 | * sink_allocation[:, :, :, None]).transpose(1, 2) # allocation 138 | ## (5) Final projection 139 | x = x.reshape(B, L, -1) 140 | x = self.out_projection(x) 141 | x = self.dropout(x) 142 | return x 143 | 144 | class FlowformerLayer(nn.Module, AbstractTrasnformerLayer): 145 | def __init__( 146 | self, 147 | embed_dim, 148 | num_heads, 149 | dropout = 0.0, 150 | norm = 'layernorm', 151 | norm_first=True, 152 | causal=False, 153 | ): 154 | super(FlowformerLayer, self).__init__() 155 | if not causal: 156 | self.self_attn = Flow_Attention(embed_dim, num_heads) 157 | else: 158 | self.self_attn = Flow_Attention_Causal(embed_dim, num_heads) 159 | self._ff_block = nn.Sequential(nn.Linear(embed_dim, embed_dim * 2), 160 | nn.GELU(), 161 | nn.Dropout(dropout), 162 | nn.Linear(embed_dim * 2, embed_dim), 163 | nn.Dropout(dropout), 164 | ) 165 | self.dropout1 = nn.Dropout(dropout) 166 | self.norm1 = create_norm(norm, embed_dim) 167 | self.norm2 = create_norm(norm, embed_dim) 168 | self.norm_first = norm_first 169 | self.support_output_attentions = False 170 | 171 | def _sa_block(self, x): 172 | x = x.unsqueeze(0) 173 | x = self.self_attn(x) 174 | return self.dropout1(x)[0, :, :] 175 | 176 | def forward(self, x, attn_mask=None, output_attentions=False): 177 | assert output_attentions == False, 'output_attentions not implemented for Cosformer' 178 | if self.norm_first: 179 | x = x + self._sa_block(self.norm1(x)) 180 | x = x + self._ff_block(self.norm2(x)) 181 | else: 182 | x = self.norm1(x + self._sa_block(x)) 183 | x = self.norm2(x + self._ff_block(x)) 184 | return x -------------------------------------------------------------------------------- /CellPLM/layer/performer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from functools import partial 4 | import numpy as np 5 | from .transformer import AbstractTrasnformerLayer 6 | 7 | from torch import Tensor 8 | from typing import Optional 9 | from torch import nn, einsum 10 | import math 11 | from ..utils import create_norm 12 | from einops import rearrange, repeat, pack, unpack 13 | 14 | TOKEN_SELF_ATTN_VALUE = -5e4 15 | 16 | def exists(val): 17 | return val is not None 18 | 19 | def empty(tensor): 20 | return tensor.numel() == 0 21 | 22 | def default(val, d): 23 | return val if exists(val) else d 24 | 25 | def to(t): 26 | return {'device': t.device, 'dtype': t.dtype} 27 | 28 | def max_neg_value(tensor): 29 | return -torch.finfo(tensor.dtype).max 30 | 31 | def l2norm(tensor): 32 | dtype = tensor.dtype 33 | normed = F.normalize(tensor, dim = -1) 34 | return normed.type(dtype) 35 | 36 | def pad_to_multiple(tensor, multiple, dim=-1, value=0): 37 | seqlen = tensor.shape[dim] 38 | m = seqlen / multiple 39 | if m.is_integer(): 40 | return False, tensor 41 | remainder = math.ceil(m) * multiple - seqlen 42 | pad_offset = (0,) * (-1 - dim) * 2 43 | return True, F.pad(tensor, (*pad_offset, 0, remainder), value = value) 44 | 45 | def look_around(x, backward = 1, forward = 0, pad_value = -1, dim = 2): 46 | t = x.shape[1] 47 | dims = (len(x.shape) - dim) * (0, 0) 48 | padded_x = F.pad(x, (*dims, backward, forward), value = pad_value) 49 | tensors = [padded_x[:, ind:(ind + t), ...] for ind in range(forward + backward + 1)] 50 | return torch.cat(tensors, dim = dim) 51 | 52 | class SinusoidalEmbeddings(nn.Module): 53 | def __init__(self, dim): 54 | super().__init__() 55 | inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 56 | self.register_buffer('inv_freq', inv_freq) 57 | 58 | def forward(self, x): 59 | n = x.shape[-2] 60 | t = torch.arange(n, device = x.device).type_as(self.inv_freq) 61 | freqs = torch.einsum('i , j -> i j', t, self.inv_freq) 62 | return torch.cat((freqs, freqs), dim=-1) 63 | 64 | def rotate_half(x): 65 | x = rearrange(x, 'b ... (r d) -> b (...) r d', r = 2) 66 | x1, x2 = x.unbind(dim = -2) 67 | return torch.cat((-x2, x1), dim = -1) 68 | 69 | def apply_rotary_pos_emb(q, k, freqs): 70 | q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k)) 71 | return q, k 72 | 73 | def orthogonal_matrix_chunk(cols, device = None): 74 | unstructured_block = torch.randn((cols, cols), device = device) 75 | q, r = torch.qr(unstructured_block.cpu(), some = True) 76 | q, r = map(lambda t: t.to(device), (q, r)) 77 | return q.t() 78 | 79 | def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None): 80 | b, h, *_ = data.shape 81 | 82 | data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1. 83 | 84 | ratio = (projection_matrix.shape[0] ** -0.5) 85 | 86 | projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h) 87 | projection = projection.type_as(data) 88 | 89 | data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection) 90 | 91 | diag_data = data ** 2 92 | diag_data = torch.sum(diag_data, dim=-1) 93 | diag_data = (diag_data / 2.0) * (data_normalizer ** 2) 94 | diag_data = diag_data.unsqueeze(dim=-1) 95 | 96 | if is_query: 97 | data_dash = ratio * ( 98 | torch.exp(data_dash - diag_data - 99 | torch.max(data_dash, dim=-1, keepdim=True).values) + eps) 100 | else: 101 | data_dash = ratio * ( 102 | torch.exp(data_dash - diag_data - torch.max(data_dash)) + eps) 103 | 104 | return data_dash.type_as(data) 105 | 106 | def generalized_kernel(data, *, projection_matrix, kernel_fn = nn.ReLU(), kernel_epsilon = 0.001, normalize_data = True, device = None): 107 | b, h, *_ = data.shape 108 | 109 | data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1. 110 | 111 | if projection_matrix is None: 112 | return kernel_fn(data_normalizer * data) + kernel_epsilon 113 | 114 | projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h) 115 | projection = projection.type_as(data) 116 | 117 | data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection) 118 | 119 | data_prime = kernel_fn(data_dash) + kernel_epsilon 120 | return data_prime.type_as(data) 121 | 122 | # non-causal linear attention 123 | def linear_attention(q, k, v): 124 | k_cumsum = k.sum(dim = -2) 125 | D_inv = 1. / torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q)) 126 | context = torch.einsum('...nd,...ne->...de', k, v) 127 | out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv) 128 | return out 129 | 130 | # efficient causal linear attention, created by EPFL 131 | def causal_linear_attention(q, k, v, eps = 1e-6): 132 | raise NotImplementedError('Please refer to performer-pytorch repo!') 133 | # https://github.com/lucidrains/performer-pytorch/blob/main/performer_pytorch/performer_pytorch.py 134 | 135 | def causal_linear_attention_noncuda(q, k, v, chunk_size = 128, eps = 1e-6): 136 | raise NotImplementedError('Please refer to performer-pytorch repo!') 137 | # https://github.com/lucidrains/performer-pytorch/blob/main/performer_pytorch/performer_pytorch.py 138 | 139 | def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, device = None): 140 | nb_full_blocks = int(nb_rows / nb_columns) 141 | 142 | block_list = [] 143 | 144 | for _ in range(nb_full_blocks): 145 | q = orthogonal_matrix_chunk(nb_columns, device = device) 146 | block_list.append(q) 147 | 148 | remaining_rows = nb_rows - nb_full_blocks * nb_columns 149 | if remaining_rows > 0: 150 | q = orthogonal_matrix_chunk(nb_columns, device = device) 151 | block_list.append(q[:remaining_rows]) 152 | 153 | final_matrix = torch.cat(block_list) 154 | 155 | if scaling == 0: 156 | multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1) 157 | elif scaling == 1: 158 | multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device = device) 159 | else: 160 | raise ValueError(f'Invalid scaling {scaling}') 161 | 162 | return torch.diag(multiplier) @ final_matrix 163 | 164 | class LocalAttention(nn.Module): 165 | def __init__( 166 | self, 167 | window_size, 168 | causal = False, 169 | look_backward = 1, 170 | look_forward = None, 171 | dropout = 0., 172 | shared_qk = False, 173 | rel_pos_emb_config = None, 174 | dim = None, 175 | autopad = False, 176 | exact_windowsize = False, 177 | scale = None 178 | ): 179 | super().__init__() 180 | look_forward = default(look_forward, 0 if causal else 1) 181 | assert not (causal and look_forward > 0), 'you cannot look forward if causal' 182 | 183 | self.scale = scale 184 | 185 | self.window_size = window_size 186 | self.autopad = autopad 187 | self.exact_windowsize = exact_windowsize 188 | 189 | self.causal = causal 190 | 191 | self.look_backward = look_backward 192 | self.look_forward = look_forward 193 | 194 | self.dropout = nn.Dropout(dropout) 195 | 196 | self.shared_qk = shared_qk 197 | 198 | # relative positions 199 | 200 | self.rel_pos = None 201 | if exists(rel_pos_emb_config) or exists(dim): # backwards compatible with old `rel_pos_emb_config` deprecated argument 202 | if exists(rel_pos_emb_config): 203 | dim = rel_pos_emb_config[0] 204 | self.rel_pos = SinusoidalEmbeddings(dim) 205 | 206 | def forward(self, q, k, v, mask = None, input_mask = None): 207 | mask = default(mask, input_mask) 208 | 209 | shape, autopad, pad_value, window_size, causal, look_backward, look_forward, shared_qk = q.shape, self.autopad, -1, self.window_size, self.causal, self.look_backward, self.look_forward, self.shared_qk 210 | 211 | # https://github.com/arogozhnikov/einops/blob/master/docs/4-pack-and-unpack.ipynb 212 | (q, packed_shape), (k, _), (v, _) = map(lambda t: pack([t], '* n d'), (q, k, v)) 213 | 214 | # rotary embeddings 215 | 216 | if exists(self.rel_pos): 217 | pos_emb = self.rel_pos(q) 218 | q, k = apply_rotary_pos_emb(q, k, pos_emb) 219 | 220 | # auto padding 221 | 222 | if autopad: 223 | orig_seq_len = q.shape[1] 224 | (needed_pad, q), (_, k), (_, v) = map(lambda t: pad_to_multiple(t, self.window_size, dim = -2), (q, k, v)) 225 | 226 | b, n, dim_head, device, dtype = *q.shape, q.device, q.dtype 227 | 228 | scale = default(self.scale, dim_head ** -0.5) 229 | 230 | assert (n % window_size) == 0, f'sequence length {n} must be divisible by window size {window_size} for local attention' 231 | 232 | windows = n // window_size 233 | 234 | if shared_qk: 235 | k = l2norm(k) 236 | 237 | seq = torch.arange(n, device = device) 238 | b_t = rearrange(seq, '(w n) -> 1 w n', w = windows, n = window_size) 239 | 240 | bq, bk, bv = map(lambda t: rearrange(t, 'b (w n) d -> b w n d', w = windows), (q, k, v)) 241 | 242 | look_around_kwargs = dict( 243 | backward = look_backward, 244 | forward = look_forward, 245 | pad_value = pad_value 246 | ) 247 | 248 | bk = look_around(bk, **look_around_kwargs) 249 | bv = look_around(bv, **look_around_kwargs) 250 | 251 | bq_t = b_t 252 | bq_k = look_around(b_t, **look_around_kwargs) 253 | 254 | bq_t = rearrange(bq_t, '... i -> ... i 1') 255 | bq_k = rearrange(bq_k, '... j -> ... 1 j') 256 | 257 | sim = einsum('b h i e, b h j e -> b h i j', bq, bk) * scale 258 | 259 | mask_value = max_neg_value(sim) 260 | 261 | if shared_qk: 262 | self_mask = bq_t == bq_k 263 | sim = sim.masked_fill(self_mask, TOKEN_SELF_ATTN_VALUE) 264 | del self_mask 265 | 266 | if causal: 267 | causal_mask = bq_t < bq_k 268 | 269 | if self.exact_windowsize: 270 | max_causal_window_size = (self.window_size * self.look_backward) 271 | causal_mask = causal_mask | (bq_t > (bq_k + max_causal_window_size)) 272 | 273 | sim = sim.masked_fill(causal_mask, mask_value) 274 | del causal_mask 275 | 276 | # mask out padding value 277 | 278 | if autopad and needed_pad: 279 | pad_mask = bq_k == pad_value 280 | sim = sim.masked_fill(pad_mask, mask_value) 281 | del pad_mask 282 | 283 | if exists(mask): 284 | batch = mask.shape[0] 285 | assert (b % batch) == 0 286 | 287 | h = b // mask.shape[0] 288 | 289 | if autopad: 290 | _, mask = pad_to_multiple(mask, window_size, dim = -1, value = False) 291 | 292 | mask = rearrange(mask, '... (w n) -> (...) w n', w = windows, n = window_size) 293 | mask = look_around(mask, **{**look_around_kwargs, 'pad_value': False}) 294 | mask = rearrange(mask, '... j -> ... 1 j') 295 | mask = repeat(mask, 'b ... -> (b h) ...', h = h) 296 | sim = sim.masked_fill(~mask, mask_value) 297 | del mask 298 | 299 | # attention 300 | 301 | attn = sim.softmax(dim = -1) 302 | attn = self.dropout(attn) 303 | 304 | # aggregation 305 | 306 | out = einsum('b h i j, b h j e -> b h i e', attn, bv) 307 | out = rearrange(out, 'b w n d -> b (w n) d') 308 | 309 | if autopad: 310 | out = out[:, :orig_seq_len, :] 311 | 312 | out, *_ = unpack(out, packed_shape, '* n d') 313 | return out 314 | 315 | class FastAttention(nn.Module): 316 | def __init__(self, dim_heads, nb_features = None, ortho_scaling = 0, causal = False, generalized_attention = False, kernel_fn = nn.ReLU(), no_projection = False): 317 | super().__init__() 318 | nb_features = default(nb_features, int(dim_heads * math.log(dim_heads))) 319 | 320 | self.dim_heads = dim_heads 321 | self.nb_features = nb_features 322 | self.ortho_scaling = ortho_scaling 323 | 324 | self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows = self.nb_features, nb_columns = dim_heads, scaling = ortho_scaling) 325 | projection_matrix = self.create_projection() 326 | self.register_buffer('projection_matrix', projection_matrix) 327 | 328 | self.generalized_attention = generalized_attention 329 | self.kernel_fn = kernel_fn 330 | 331 | # if this is turned on, no projection will be used 332 | # queries and keys will be softmax-ed as in the original efficient attention paper 333 | self.no_projection = no_projection 334 | 335 | self.causal = causal 336 | if causal: 337 | try: 338 | import fast_transformers.causal_product.causal_product_cuda 339 | self.causal_linear_fn = partial(causal_linear_attention) 340 | except ImportError: 341 | print('unable to import cuda code for auto-regressive Performer. will default to the memory inefficient non-cuda version') 342 | self.causal_linear_fn = causal_linear_attention_noncuda 343 | 344 | @torch.no_grad() 345 | def redraw_projection_matrix(self, device): 346 | projections = self.create_projection(device = device) 347 | self.projection_matrix.copy_(projections) 348 | del projections 349 | 350 | def forward(self, q, k, v, output_attentions = False): 351 | device = q.device 352 | # inds = [8060, 8064, 6243, 8575, 10342, 10913, 9366, 993, 7796, 5210, 5212, 5504, 6851, 6559, 5508, 13107, 13820] 353 | if self.no_projection: 354 | q = q.softmax(dim = -1) 355 | k = torch.exp(k) if self.causal else k.softmax(dim = -2) 356 | 357 | elif self.generalized_attention: 358 | create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device) 359 | q, k = map(create_kernel, (q, k)) 360 | 361 | else: 362 | create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device) 363 | q = create_kernel(q, is_query = True) 364 | k = create_kernel(k, is_query = False) 365 | 366 | attn_fn = linear_attention if not self.causal else self.causal_linear_fn 367 | out = attn_fn(q, k, v) 368 | if output_attentions: 369 | v_diag = torch.eye(v.shape[-2]).to(device) 370 | v_diag = v_diag.unsqueeze(0).unsqueeze(0).repeat(v.shape[0],v.shape[1],1,1) 371 | # attn_weights = torch.zeros(1, 1, len(inds), len(inds)).to(device).to(torch.float16) 372 | # attn_weights = torch.zeros(1, q.shape[1], len(inds), len(inds)).to(device).to(torch.float16) 373 | attn_weights = torch.zeros(1, 1, q.shape[2], q.shape[2]).to(device).to(torch.float16) 374 | for head_dim in range(q.shape[1]): 375 | # attn_weights[0, head_dim] = torch.abs(attn_fn(q[:,head_dim].to(torch.float16), k[:,head_dim].to(torch.float16), v_diag[:,head_dim].to(torch.float16)))[0, inds][:, inds] 376 | attn_weights += torch.abs(attn_fn(q[:,head_dim].to(torch.float16), k[:,head_dim].to(torch.float16), v_diag[:,head_dim].to(torch.float16))) 377 | # attn_weights += norm_tensor(torch.abs(attn_fn(q[:,head_dim].to(torch.float16), k[:,head_dim].to(torch.float16), v_diag[:,head_dim].to(torch.float16))), dim=-1) 378 | attn_weights /= q.shape[1] 379 | return out, attn_weights 380 | else: 381 | return out 382 | 383 | 384 | class PerformerAttention(nn.Module): 385 | def __init__( 386 | self, 387 | dim, 388 | causal = False, 389 | heads = 8, 390 | dim_head = 64, 391 | local_heads = 0, 392 | local_window_size = 256, 393 | nb_features = None, 394 | feature_redraw_interval = 1000, 395 | generalized_attention = False, 396 | kernel_fn = nn.ReLU(), 397 | dropout = 0., 398 | no_projection = False, 399 | qkv_bias = False 400 | ): 401 | super().__init__() 402 | assert dim % heads == 0, 'dimension must be divisible by number of heads' 403 | dim_head = default(dim_head, dim // heads) 404 | inner_dim = dim_head * heads 405 | self.fast_attention = FastAttention(dim_head, nb_features, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, no_projection = no_projection) 406 | 407 | self.heads = heads 408 | self.global_heads = heads - local_heads 409 | self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None 410 | 411 | self.to_q = nn.Linear(dim, inner_dim, bias = qkv_bias) 412 | self.to_k = nn.Linear(dim, inner_dim, bias = qkv_bias) 413 | self.to_v = nn.Linear(dim, inner_dim, bias = qkv_bias) 414 | self.to_out = nn.Linear(inner_dim, dim) 415 | self.dropout = nn.Dropout(dropout) 416 | 417 | def forward(self, x, pos_emb = None, context = None, mask = None, context_mask = None, output_attentions = False, **kwargs): 418 | b, n, _, h, gh = *x.shape, self.heads, self.global_heads 419 | 420 | cross_attend = exists(context) 421 | 422 | context = default(context, x) 423 | context_mask = default(context_mask, mask) if not cross_attend else context_mask 424 | 425 | q, k, v = self.to_q(x), self.to_k(context), self.to_v(context) 426 | 427 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 428 | (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v)) 429 | 430 | attn_outs = [] 431 | 432 | if not empty(q): 433 | if exists(context_mask): 434 | global_mask = context_mask[:, None, :, None] 435 | v.masked_fill_(~global_mask, 0.) 436 | 437 | if exists(pos_emb) and not cross_attend: 438 | q, k, = apply_rotary_pos_emb(q, k, pos_emb) 439 | 440 | if output_attentions: 441 | out, attn_weights = self.fast_attention(q, k, v, output_attentions) 442 | else: 443 | out = self.fast_attention(q, k, v) 444 | attn_outs.append(out) 445 | 446 | if not empty(lq): 447 | assert not cross_attend, 'local attention is not compatible with cross attention' 448 | out = self.local_attn(lq, lk, lv, input_mask = mask) 449 | attn_outs.append(out) 450 | 451 | out = torch.cat(attn_outs, dim = 1) # combine attn_out and cross_attn_out, here we have only attn_out, that means this line does nothing 452 | out = rearrange(out, 'b h n d -> b n (h d)') 453 | out = self.to_out(out) 454 | if output_attentions: 455 | return self.dropout(out), attn_weights 456 | else: 457 | return self.dropout(out), None 458 | 459 | 460 | class PerformerLayer(nn.Module, AbstractTrasnformerLayer): 461 | def __init__( 462 | self, 463 | embed_dim, 464 | num_heads, 465 | dropout = 0.0, 466 | norm = 'rmsnorm', 467 | norm_first: bool = True, 468 | causal=False, 469 | ): 470 | super().__init__() 471 | self.self_attn = PerformerAttention(dim=embed_dim, heads=num_heads, 472 | dropout=dropout) 473 | self._ff_block = nn.Sequential(nn.Linear(embed_dim, embed_dim*2), 474 | nn.GELU(), 475 | nn.Dropout(dropout), 476 | nn.Linear(embed_dim*2, embed_dim), 477 | nn.Dropout(dropout), 478 | ) 479 | self.dropout1 = nn.Dropout(dropout) 480 | self.norm1 = create_norm(norm, embed_dim) 481 | self.norm2 = create_norm(norm, embed_dim) 482 | self.norm_first = norm_first 483 | self.support_output_attentions = True 484 | 485 | def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor]): 486 | x = x.unsqueeze(0) 487 | out, attn_weights = self.self_attn(x, attn_mask=attn_mask) 488 | return out[0], attn_weights 489 | 490 | def forward(self, x, attn_mask=None, output_attentions=False): 491 | if self.norm_first: 492 | x_prime, attn = self._sa_block(self.norm1(x), attn_mask) 493 | x = x + x_prime 494 | x = x + self._ff_block(self.norm2(x)) 495 | else: 496 | x_prime, attn = self._sa_block(x, attn_mask) 497 | x = self.norm1(x + x_prime) 498 | x = self.norm2(x + self._ff_block(x)) 499 | if output_attentions: 500 | return x, attn 501 | else: 502 | return x -------------------------------------------------------------------------------- /CellPLM/layer/transformer.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from abc import ABC, abstractmethod 3 | from torch import Tensor 4 | from torch.nn import functional as F 5 | from torch.nn.modules import Module 6 | from torch.nn.modules.activation import MultiheadAttention 7 | from torch.nn.modules.container import ModuleList 8 | from torch.nn.init import xavier_uniform_ 9 | from torch.nn.modules.dropout import Dropout 10 | from torch.nn.modules.normalization import LayerNorm 11 | import torch 12 | import copy 13 | from typing import Optional, Any, Union, Callable 14 | from ..utils import RMSNorm 15 | 16 | 17 | class Linear(nn.Linear): 18 | def __init__(self, in_features, out_features, bias=True, device='cpu', dtype=None): 19 | super().__init__(in_features, out_features, bias=bias) 20 | self.register_buffer('u', nn.functional.normalize(torch.randn(in_features), dim=0)) 21 | with torch.no_grad(): 22 | sigma = self.get_sigma() 23 | self.register_buffer('spectral_norm', sigma) 24 | 25 | self.sigma = nn.Parameter(torch.ones(1)) 26 | self.to(device) 27 | 28 | def get_sigma(self): 29 | with torch.no_grad(): 30 | u = self.u 31 | v = self.weight.mv(u) 32 | v = nn.functional.normalize(v, dim=0) 33 | u = self.weight.T.mv(v) 34 | u = nn.functional.normalize(u, dim=0) 35 | self.u.data.copy_(u) 36 | return torch.einsum('c,cd,d->', v, self.weight, u) 37 | 38 | def get_weight(self): 39 | sigma = self.get_sigma() 40 | if self.training: 41 | self.spectral_norm.data.copy_(sigma) 42 | weight = (self.sigma / sigma) * self.weight 43 | return weight 44 | 45 | def forward(self, x): 46 | return nn.functional.linear(x, self.get_weight(), self.bias) 47 | 48 | def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: 49 | if activation == "relu": 50 | return F.relu 51 | elif activation == "gelu": 52 | return F.gelu 53 | 54 | raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) 55 | 56 | class AbstractTrasnformerLayer(ABC): 57 | @abstractmethod 58 | def __init__(self, 59 | embed_dim, 60 | num_heads, 61 | dropout, 62 | norm, 63 | norm_first: bool, 64 | causal: bool, 65 | ): 66 | pass 67 | 68 | @abstractmethod 69 | def forward(self, x, attn_mask, output_attentions): 70 | pass 71 | 72 | class TransformerEncoderLayer(Module): 73 | __constants__ = ['batch_first', 'norm_first'] 74 | 75 | def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, 76 | activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, 77 | layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, 78 | device=None, dtype=None) -> None: 79 | factory_kwargs = {'device': device, 'dtype': dtype} 80 | super(TransformerEncoderLayer, self).__init__() 81 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, 82 | **factory_kwargs) 83 | # Implementation of Feedforward model 84 | self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs) 85 | self.dropout = Dropout(dropout) 86 | self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs) 87 | 88 | self.norm_first = norm_first 89 | self.norm1 = RMSNorm(d_model, eps=layer_norm_eps) 90 | self.norm2 = RMSNorm(d_model, eps=layer_norm_eps) 91 | self.dropout1 = Dropout(dropout) 92 | self.dropout2 = Dropout(dropout) 93 | 94 | # Legacy string support for activation function. 95 | if isinstance(activation, str): 96 | activation = _get_activation_fn(activation) 97 | 98 | # We can't test self.activation in forward() in TorchScript, 99 | # so stash some information about it instead. 100 | if activation is F.relu: 101 | self.activation_relu_or_gelu = 1 102 | elif activation is F.gelu: 103 | self.activation_relu_or_gelu = 2 104 | else: 105 | self.activation_relu_or_gelu = 0 106 | self.activation = activation 107 | 108 | def __setstate__(self, state): 109 | super(TransformerEncoderLayer, self).__setstate__(state) 110 | if not hasattr(self, 'activation'): 111 | self.activation = F.relu 112 | 113 | 114 | def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, 115 | src_key_padding_mask: Optional[Tensor] = None) -> Tensor: 116 | r"""Pass the input through the encoder layer. 117 | 118 | Args: 119 | src: the sequence to the encoder layer (required). 120 | src_mask: the mask for the src sequence (optional). 121 | src_key_padding_mask: the mask for the src keys per batch (optional). 122 | 123 | Shape: 124 | see the docs in Transformer class. 125 | """ 126 | 127 | # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf 128 | 129 | if (src.dim() == 3 and not self.norm_first and not self.training and 130 | self.self_attn.batch_first and 131 | self.self_attn._qkv_same_embed_dim and self.activation_relu_or_gelu and 132 | self.norm1.eps == self.norm2.eps and 133 | src_mask is None and 134 | not (src.is_nested and src_key_padding_mask is not None)): 135 | tensor_args = ( 136 | src, 137 | self.self_attn.in_proj_weight, 138 | self.self_attn.in_proj_bias, 139 | self.self_attn.out_proj.weight, 140 | self.self_attn.out_proj.bias, 141 | self.norm1.weight, 142 | self.norm2.weight, 143 | self.linear1.weight, 144 | self.linear1.bias, 145 | self.linear2.weight, 146 | self.linear2.bias, 147 | ) 148 | if (not torch.overrides.has_torch_function(tensor_args) and 149 | # We have to use a list comprehension here because TorchScript 150 | # doesn't support generator expressions. 151 | all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]) and 152 | (not torch.is_grad_enabled() or all([not x.requires_grad for x in tensor_args]))): 153 | return torch._transformer_encoder_layer_fwd( 154 | src, 155 | self.self_attn.embed_dim, 156 | self.self_attn.num_heads, 157 | self.self_attn.in_proj_weight, 158 | self.self_attn.in_proj_bias, 159 | self.self_attn.out_proj.weight, 160 | self.self_attn.out_proj.bias, 161 | self.activation_relu_or_gelu == 2, 162 | False, # norm_first, currently not supported 163 | self.norm1.eps, 164 | self.norm1.weight, 165 | self.norm2.weight, 166 | self.linear1.weight, 167 | self.linear1.bias, 168 | self.linear2.weight, 169 | self.linear2.bias, 170 | src_mask if src_mask is not None else src_key_padding_mask, # TODO: split into two args 171 | ) 172 | x = src 173 | if self.norm_first: 174 | x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask) 175 | x = x + self._ff_block(self.norm2(x)) 176 | else: 177 | x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask)) 178 | x = self.norm2(x + self._ff_block(x)) 179 | 180 | return x 181 | 182 | # self-attention block 183 | def _sa_block(self, x: Tensor, 184 | attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor: 185 | x = self.self_attn(x, x, x, 186 | attn_mask=attn_mask, 187 | key_padding_mask=key_padding_mask, 188 | need_weights=True) 189 | # print(x[1]) 190 | x = x[0] 191 | return self.dropout1(x) 192 | 193 | # feed forward block 194 | def _ff_block(self, x: Tensor) -> Tensor: 195 | x = self.linear2(self.dropout(self.activation(self.linear1(x)))) 196 | return self.dropout2(x) 197 | 198 | class VanillaTransformerLayer(nn.Module, AbstractTrasnformerLayer): 199 | def __init__( 200 | self, 201 | embed_dim, 202 | num_heads, 203 | dropout = 0.0, 204 | norm = 'layernorm', 205 | norm_first=True, 206 | causal=False, 207 | ): 208 | super().__init__() 209 | assert norm=='layernorm', 'Vanilla transformer only supports layernorm.' 210 | assert causal==False, 'Vanilla transformer does not supports causal inference.' 211 | self.layer = TransformerEncoderLayer(embed_dim, num_heads, embed_dim*2, 212 | dropout, activation='gelu', norm_first=norm_first) 213 | self.support_output_attentions = False 214 | 215 | def forward(self, x, attn_mask=None, output_attentions=False): 216 | assert output_attentions == False, 'output_attentions not implemented for VanillaTransformer' 217 | # self.train() 218 | x = x.unsqueeze(1) 219 | return self.layer(x, attn_mask)[:, 0, :] -------------------------------------------------------------------------------- /CellPLM/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .cellformer import OmicsFormer -------------------------------------------------------------------------------- /CellPLM/model/cellformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from ..embedder import OmicsEmbeddingLayer 6 | from ..utils.mask import MaskBuilder, NullMaskBuilder, HiddenMaskBuilder 7 | from ..encoder import setup_encoder 8 | from ..decoder import setup_decoder 9 | from ..latent import LatentModel, PreLatentNorm 10 | from ..latent.adversarial import AdversarialLatentLayer 11 | from ..objective import Objectives 12 | from ..head import setup_head 13 | 14 | class OmicsFormer(nn.Module): 15 | def __init__(self, gene_list, enc_mod, enc_hid, enc_layers, post_latent_dim, dec_mod, dec_hid, dec_layers, 16 | out_dim, batch_num=0, dataset_num=0, platform_num=0, mask_type='input', model_dropout=0.1, 17 | activation='gelu', norm='layernorm', enc_head=8, mask_node_rate=0.5, 18 | mask_feature_rate=0.8, drop_node_rate=0., max_batch_size=2000, cat_dim=None, conti_dim=None, 19 | pe_type='sin', cat_pe=True, 20 | gene_emb=None, latent_mod='vae', w_li=1., w_en=1., w_ce=1., 21 | head_type=None, dsbn=False, ecs=False, dar=False, input_covariate=False, 22 | num_clusters=16, dae=True, lamda=0.5, mask_beta=False, **kwargs): 23 | super(OmicsFormer, self).__init__() 24 | 25 | self.embedder = OmicsEmbeddingLayer(gene_list, enc_hid, norm, activation, model_dropout, 26 | pe_type, cat_pe, gene_emb, inject_covariate=input_covariate, batch_num=batch_num) 27 | self.gene_set = set(gene_list) 28 | self.mask_type = mask_type 29 | if mask_node_rate > 0 and mask_feature_rate > 0: 30 | if mask_type == 'input': 31 | self.mask_model = MaskBuilder(mask_node_rate, mask_feature_rate, drop_node_rate, max_batch_size, mask_beta) 32 | elif mask_type == 'hidden': 33 | self.mask_model = HiddenMaskBuilder(mask_node_rate, mask_feature_rate, drop_node_rate, max_batch_size) 34 | else: 35 | raise NotImplementedError(f"Only support mask_type in ['input', 'hidden'], but got {mask_type}") 36 | else: 37 | self.mask_model = NullMaskBuilder(drop_node_rate, max_batch_size) 38 | self.encoder = setup_encoder(enc_mod, enc_hid, enc_layers, model_dropout, activation, norm, enc_head) 39 | 40 | self.latent = LatentModel() 41 | self.latent_mod = latent_mod 42 | if latent_mod=='vae': 43 | self.latent.add_layer(type='vae', enc_hid=enc_hid, latent_dim=post_latent_dim) 44 | elif latent_mod=='ae': 45 | self.latent.add_layer(type='merge', conti_dim=enc_hid, cat_dim=0, post_latent_dim=post_latent_dim) 46 | elif latent_mod=='gmvae': 47 | self.latent.add_layer(type='gmvae', enc_hid=enc_hid, latent_dim=post_latent_dim, batch_num=batch_num, 48 | w_li=w_li, w_en=w_en, w_ce=w_ce, dropout=model_dropout, num_layers=dec_layers, 49 | num_clusters=num_clusters, lamda=lamda) 50 | elif latent_mod=='split': 51 | self.latent.add_layer(type='split', enc_hid=enc_hid, latent_dim=None, conti_dim=conti_dim, cat_dim=cat_dim) 52 | self.latent.add_layer(type='merge', conti_dim=conti_dim, cat_dim=cat_dim, post_latent_dim=post_latent_dim) 53 | elif latent_mod == 'none': 54 | post_latent_dim = enc_hid 55 | else: 56 | raise NotImplementedError(f'Latent mod "{latent_mod}" is not implemented.') 57 | if latent_mod is not None: 58 | if dar: 59 | self.latent.add_layer(type='adversarial', input_dims=np.arange(post_latent_dim), label_key='batch', 60 | discriminator_hidden=64, disc_lr=1e-3, 61 | target_classes=batch_num) 62 | if ecs: 63 | self.latent.add_layer(type='ecs') 64 | 65 | self.head_type = head_type 66 | if head_type is not None: 67 | self.head = setup_head(head_type, post_latent_dim, dec_hid, out_dim, dec_layers, 68 | model_dropout, norm, batch_num=batch_num) 69 | else: 70 | self.decoder = setup_decoder(dec_mod, post_latent_dim, dec_hid, out_dim, dec_layers, 71 | model_dropout, norm, batch_num=batch_num, dataset_num=dataset_num, platform_num=platform_num) 72 | if 'objective' in kwargs: 73 | self.objective = Objectives([{'type': kwargs['objective']}]) 74 | else: 75 | if 'nb' in dec_mod: 76 | self.objective = Objectives([{'type': 'nb', 'dae': dae}]) 77 | else: 78 | self.objective = Objectives([{'type': 'recon'}]) 79 | 80 | if dsbn: 81 | self.pre_latent_norm = PreLatentNorm('dsbn', enc_hid, dataset_num) 82 | else: 83 | self.pre_latent_norm = PreLatentNorm('ln', enc_hid) 84 | # self.post_latent_norm = nn.LayerNorm(post_latent_dim, dataset_num) 85 | 86 | def forward(self, x_dict, input_gene_list=None, d_iter=False): 87 | if self.mask_type == 'input': 88 | x_dict = self.mask_model.apply_mask(x_dict) 89 | x_dict['h'] = self.embedder(x_dict, input_gene_list) 90 | if self.mask_type == 'hidden': 91 | x_dict = self.mask_model.apply_mask(x_dict) 92 | x_dict['h'] = self.encoder(x_dict)['hidden'] 93 | x_dict['h'] = self.pre_latent_norm(x_dict) 94 | x_dict['h'], latent_loss = self.latent(x_dict) 95 | 96 | # x_dict['h'] = self.post_latent_norm(x_dict['h']) 97 | # if 'ecs' in x_dict: 98 | # x_dict['h'] = self.latent_norm(x_dict['h']) 99 | 100 | if d_iter: 101 | return self.latent.d_train(x_dict) 102 | else: 103 | if self.head_type is not None: 104 | out_dict, loss = self.head(x_dict) 105 | out_dict['latent_loss'] = latent_loss.item() if torch.is_tensor(latent_loss) else latent_loss 106 | out_dict['target_loss'] = loss.item() 107 | else: 108 | out_dict = self.decoder(x_dict) 109 | loss = latent_loss + self.objective(out_dict, x_dict) #/ 1e4 110 | out_dict['latent_loss'] = latent_loss.item() if torch.is_tensor(latent_loss) else latent_loss 111 | out_dict['target_loss'] = loss.item() - out_dict['latent_loss'] 112 | return out_dict, loss 113 | 114 | def nondisc_parameters(self): 115 | other_params = [] 116 | for pname, p in self.named_parameters(): 117 | if 'discriminator' not in pname: 118 | other_params += [p] 119 | else: 120 | print(pname) 121 | return other_params 122 | -------------------------------------------------------------------------------- /CellPLM/objective/__init__.py: -------------------------------------------------------------------------------- 1 | from .zinb import ZINBReconstructionLoss, NBReconstructionLoss, NBDenoisingLoss, NBImputationLoss 2 | from .autoencoder import ReconstructionLoss 3 | from torch import nn 4 | 5 | def create_objective(**config) -> nn.Module: 6 | if config['type'] == 'recon': 7 | return ReconstructionLoss(**config) 8 | elif config['type'] == 'zinb': 9 | return ZINBReconstructionLoss(**config) 10 | elif config['type'] == 'nb': 11 | return NBReconstructionLoss(**config) 12 | elif config['type'] == 'denoise': 13 | return NBDenoisingLoss(**config) 14 | elif config['type'] == 'imputation': 15 | return NBImputationLoss(**config) 16 | else: 17 | raise ValueError(f"Unrecognized latent model name: {config['type']}") 18 | 19 | class Objectives(nn.Module): 20 | def __init__(self, configs=None): 21 | super().__init__() 22 | self.layers = nn.ModuleList() 23 | if configs is not None: 24 | for c in configs: 25 | self.layers.append(create_objective(**c)) 26 | 27 | def forward(self, out_dict, x_dict): 28 | if len(self.layers) == 0: 29 | raise RuntimeError("No objectives added to model.") 30 | total_loss = 0 31 | for layer in self.layers: 32 | loss = layer(out_dict, x_dict) 33 | total_loss += loss 34 | return total_loss 35 | 36 | def add_layer(self, **config): 37 | self.layers.append(create_objective(**config)) -------------------------------------------------------------------------------- /CellPLM/objective/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | class ReconstructionLoss(nn.Module): 6 | def __init__(self, lib_size=None, log_norm=False, **kwargs): 7 | super().__init__() 8 | self.reconstruction_loss = nn.MSELoss() 9 | self.lib_size = lib_size 10 | self.log_norm = log_norm 11 | self.downstream = None 12 | 13 | def forward(self, out_dict, x_dict): 14 | y = x_dict['x_seq'].to_dense() 15 | if self.lib_size is not None: 16 | y = y/y.sum(1)[:, None] * self.lib_size 17 | if self.log_norm: 18 | y = torch.log(y+1) 19 | size_factor = y.sum(1, keepdim=True) 20 | pred = (size_factor * out_dict['recon'] * x_dict['input_mask'])[:, x_dict['gene_mask']] 21 | truth = (y * x_dict['input_mask'])[:, x_dict['gene_mask']] 22 | pred = pred[x_dict['input_mask'].sum(1)>0] 23 | truth = truth[x_dict['input_mask'].sum(1)>0] 24 | out_dict['pred'] = pred 25 | 26 | return self.reconstruction_loss(pred, truth) -------------------------------------------------------------------------------- /CellPLM/objective/zinb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import random 5 | import torch.nn.functional as F 6 | 7 | class ZINBReconstructionLoss(nn.Module): 8 | """ZINB loss class.""" 9 | 10 | def __init__(self, **kwargs): 11 | super().__init__() 12 | 13 | def forward(self, out_dict, x_dict, ridge_lambda = 0.0): 14 | """Forward propagation. 15 | Parameters 16 | ---------- 17 | x : 18 | input features. 19 | mean : 20 | data mean. 21 | disp : 22 | data dispersion. 23 | pi : 24 | data dropout probability. 25 | scale_factor : torch.Tensor 26 | scale factor of mean. 27 | ridge_lambda : float optional 28 | ridge parameter. 29 | Returns 30 | ------- 31 | result : float 32 | ZINB loss. 33 | """ 34 | eps = 1e-10 35 | x = x_dict['x_seq'].to_dense()[x_dict['input_mask']] 36 | # x = x_dict['x_seq'].index_select(0, x_dict['input_mask']).to_dense() 37 | mean = out_dict['mean'][x_dict['input_mask']] 38 | disp = out_dict['disp'][x_dict['input_mask']] 39 | pi = out_dict['pi'][x_dict['input_mask']] 40 | scale_factor = x_dict['scale_factor'][x_dict['input_mask']] 41 | scale_factor = scale_factor.unsqueeze(-1) 42 | mean = mean * scale_factor 43 | 44 | t1 = torch.lgamma(disp + eps) + torch.lgamma(x + 1.0) - torch.lgamma(x + disp + eps) 45 | t2 = (disp + x) * torch.log(1.0 + (mean / (disp + eps))) + (x * (torch.log(disp + eps) - torch.log(mean + eps))) 46 | nb_final = t1 + t2 47 | 48 | nb_case = nb_final - torch.log(1.0 - pi + eps) 49 | zero_nb = torch.pow(disp / (disp + mean + eps), disp) 50 | zero_case = -torch.log(pi + ((1.0 - pi) * zero_nb) + eps) 51 | result = torch.where(torch.le(x, 1e-8), zero_case, nb_case) 52 | 53 | if ridge_lambda > 0: 54 | ridge = ridge_lambda * torch.square(pi) 55 | result += ridge 56 | result = torch.mean(result) 57 | return result 58 | 59 | 60 | class NBImputationLoss(nn.Module): 61 | """NB loss class.""" 62 | 63 | def __init__(self, dae=True, **kwargs): 64 | super().__init__() 65 | self.dae = dae 66 | self.downstream = None 67 | 68 | def forward(self, out_dict, x_dict): 69 | eps = 1e-10 70 | if 'gene_mask' not in x_dict: 71 | x_dict['gene_mask'] = torch.arange(x_dict['x_seq'].shape[1]).to(x_dict['x_seq'].device) 72 | mean = out_dict['mean'] 73 | disp = out_dict['disp'] 74 | if 'input_gene_mask' in x_dict: 75 | mean = mean[:, x_dict['input_gene_mask']] 76 | disp = disp[:, x_dict['input_gene_mask']] 77 | out_dict['pred'] = mean 78 | mean = mean[:, x_dict['gene_mask']] 79 | disp = disp[:, x_dict['gene_mask']] 80 | truth = x_dict['x_seq'].to_dense()[:, x_dict['gene_mask']] 81 | size_factor = truth.sum(1, keepdim=True) / mean.sum(1, keepdim=True) 82 | mean *= size_factor 83 | out_dict['pred'] *= size_factor 84 | 85 | if False: 86 | return F.mse_loss(torch.log1p(out_dict['pred']), torch.log1p(truth)) 87 | else: 88 | t1 = torch.lgamma(disp + eps) + torch.lgamma(truth + 1.0) - torch.lgamma(truth + disp + eps) 89 | t2 = (disp + truth) * torch.log(1.0 + (mean / (disp + eps))) + ( 90 | truth * (torch.log(disp + eps) - torch.log(mean + eps))) 91 | nb_final = t1 + t2 92 | return nb_final.sum(-1).mean() 93 | 94 | class NBDenoisingLoss(nn.Module): 95 | """NB loss class.""" 96 | 97 | def __init__(self, dae=True, **kwargs): 98 | super().__init__() 99 | self.dae = dae 100 | self.downstream = None 101 | 102 | def forward(self, out_dict, x_dict): 103 | eps = 1e-10 104 | 105 | truth = x_dict['label'] 106 | mean = out_dict['mean'][:, x_dict['gene_mask']] 107 | disp = out_dict['disp'][:, x_dict['gene_mask']] 108 | mean = mean / mean.sum(1, keepdim=True) * truth.sum(1, keepdim=True) 109 | out_dict['pred'] = mean 110 | 111 | if False: 112 | return F.mse_loss(torch.log1p(out_dict['pred']), torch.log1p(truth)) 113 | else: 114 | t1 = torch.lgamma(disp + eps) + torch.lgamma(truth + 1.0) - torch.lgamma(truth + disp + eps) 115 | t2 = (disp + truth) * torch.log(1.0 + (mean / (disp + eps))) + ( 116 | truth * (torch.log(disp + eps) - torch.log(mean + eps))) 117 | nb_final = t1 + t2 118 | return nb_final.sum(-1).mean() 119 | 120 | class NBReconstructionLoss(nn.Module): 121 | """NB loss class.""" 122 | 123 | def __init__(self, dae=True, **kwargs): 124 | super().__init__() 125 | self.dae = dae 126 | 127 | def forward(self, out_dict, x_dict): 128 | eps = 1e-10 129 | 130 | y = x_dict['x_seq'].to_dense() 131 | truth = y[:, x_dict['gene_mask']] 132 | mean = out_dict['mean'][:, x_dict['gene_mask']] 133 | disp = out_dict['disp'][:, x_dict['gene_mask']] 134 | masked_nodes = x_dict['input_mask'].sum(1)>0 135 | 136 | if self.dae and self.training: 137 | truth_masked = (truth * x_dict['input_mask'])[masked_nodes] #/ (x_dict['input_mask'][masked_nodes].mean()) 138 | mean_masked = (out_dict['mean'] * x_dict['input_mask'])[masked_nodes] 139 | disp_masked = (out_dict['disp'] * x_dict['input_mask'])[masked_nodes] 140 | mean_masked = mean_masked / mean_masked.sum(1, keepdim=True) * truth_masked.sum(1, keepdim=True) 141 | t1 = torch.lgamma(disp_masked + eps) + torch.lgamma(truth_masked + 1.0) - torch.lgamma(truth_masked + disp_masked + eps) 142 | t2 = (disp_masked + truth_masked) * torch.log(1.0 + (mean_masked / (disp_masked + eps))) + ( 143 | truth_masked * (torch.log(disp_masked + eps) - torch.log(mean_masked + eps))) 144 | nb_final_masked = t1 + t2 145 | else: 146 | nb_final_masked = 0. 147 | 148 | truth = truth[masked_nodes] 149 | mean = mean[masked_nodes] 150 | disp = disp[masked_nodes] 151 | mean = mean / mean.sum(1, keepdim=True) * truth.sum(1, keepdim=True) 152 | 153 | t1 = torch.lgamma(disp + eps) + torch.lgamma(truth + 1.0) - torch.lgamma(truth + disp + eps) 154 | t2 = (disp + truth) * torch.log(1.0 + (mean / (disp + eps))) + (truth * (torch.log(disp + eps) - torch.log(mean + eps))) 155 | nb_final = t1 + t2 + nb_final_masked 156 | 157 | return nb_final.sum(-1).mean() -------------------------------------------------------------------------------- /CellPLM/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import anndata as ad 4 | from ..model import OmicsFormer 5 | from abc import ABC, abstractmethod 6 | from typing import List, Union 7 | from .experimental import symbol_to_ensembl 8 | import json 9 | import warnings 10 | import scanpy as sc 11 | 12 | def load_pretrain( 13 | pretrain_prefix: str, 14 | overwrite_config: dict = None, 15 | pretrain_directory: str = './ckpt'): 16 | config_path = os.path.join(pretrain_directory, f'{pretrain_prefix}.config.json') 17 | ckpt_path = os.path.join(pretrain_directory, f'{pretrain_prefix}.best.ckpt') 18 | with open(config_path, "r") as openfile: 19 | config = json.load(openfile) 20 | config.update(overwrite_config) 21 | model = OmicsFormer(**config) 22 | pretrained_model_dict = torch.load(ckpt_path)['model_state_dict'] 23 | model_dict = model.state_dict() 24 | pretrained_dict = { 25 | k: v 26 | for k, v in pretrained_model_dict.items() 27 | if k in model_dict and v.shape == model_dict[k].shape 28 | } 29 | model_dict.update(pretrained_dict) 30 | model.load_state_dict(model_dict) 31 | return model 32 | 33 | 34 | class Pipeline(ABC): 35 | def __init__(self, 36 | pretrain_prefix: str, 37 | overwrite_config: dict = None, 38 | pretrain_directory: str = './ckpt', 39 | ): 40 | # Load pretrain model 41 | self.model = load_pretrain(pretrain_prefix, overwrite_config, pretrain_directory) 42 | self.gene_list = None 43 | self.fitted = False 44 | self.eval_dict = {} 45 | 46 | def common_preprocess(self, adata, hvg, covariate_fields, ensembl_auto_conversion): 47 | if covariate_fields: 48 | for i in covariate_fields: 49 | assert i in ['batch', 'dataset', 50 | 'platform'], 'Currently does not support customized covariate other than "batch", "dataset" and "platform"' 51 | adata = adata.copy() 52 | if not adata.var.index.isin(self.model.gene_set).any(): 53 | if ensembl_auto_conversion: 54 | print('Automatically converting gene symbols to ensembl ids...') 55 | adata.var.index = symbol_to_ensembl(adata.var.index.tolist()) 56 | if (adata.var.index == '0').all(): 57 | raise ValueError( 58 | 'None of AnnData.var.index found in pre-trained gene set.') 59 | adata.var_names_make_unique() 60 | else: 61 | raise ValueError( 62 | 'None of AnnData.var.index found in pre-trained gene set. In case the input gene names are gene symbols, please enable `ensembl_auto_conversion`, or manually convert gene symbols to ensembl ids in the input dataset.') 63 | if self.fitted: 64 | return adata[:, adata.var.index.isin(self.gene_list)] 65 | else: 66 | if hvg > 0: 67 | if hvg < adata.shape[1]: 68 | sc.pp.highly_variable_genes(adata, n_top_genes=hvg, subset=True, flavor='seurat_v3') 69 | else: 70 | warnings.warn('HVG number is larger than number of valid genes.') 71 | adata = adata[:, [x for x in adata.var.index.tolist() if x in self.model.gene_set]] 72 | self.gene_list = adata.var.index.tolist() 73 | return adata 74 | 75 | @abstractmethod 76 | def fit(self, adata: ad.AnnData, 77 | train_config: dict = None, 78 | split_field: str = None, # A field in adata.obs for representing train-test split 79 | train_split: str = None, # A specific split where labels can be utilized for training 80 | valid_split: str = None, # A specific split where labels can be utilized for validation 81 | covariate_fields: List[str] = None, # A list of fields in adata.obs that contain cellular covariates 82 | label_fields: List[str] = None, # A list of fields in adata.obs that contain cell labels 83 | batch_gene_list: dict = None, # A dictionary that contains batch and gene list pairs 84 | ensembl_auto_conversion: bool = True, # A bool value indicating whether the function automativally convert symbols to ensembl id 85 | device: Union[str, torch.device] = 'cpu' 86 | ): 87 | # Fine-tune the model on an anndata object 88 | pass 89 | 90 | @abstractmethod 91 | def predict(self, adata: ad.AnnData, 92 | inference_config: dict = None, 93 | covariate_fields: List[str] = None, 94 | batch_gene_list: dict = None, 95 | ensembl_auto_conversion: bool = True, 96 | device: Union[str, torch.device] = 'cpu' 97 | ): 98 | # Inference on an anndata object 99 | pass 100 | 101 | @abstractmethod 102 | def score(self, adata: ad.AnnData, 103 | evaluation_config: dict = None, 104 | split_field: str = None, 105 | target_split: str = 'test', 106 | covariate_fields: List[str] = None, 107 | label_fields: List[str] = None, 108 | batch_gene_list: dict = None, 109 | ensembl_auto_conversion: bool = True, 110 | device: Union[str, torch.device] = 'cpu' 111 | ): 112 | # Inference on an anndata object and automatically evaluate 113 | pass 114 | -------------------------------------------------------------------------------- /CellPLM/pipeline/cell_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import scanpy as sc 5 | import anndata as ad 6 | from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau 7 | from tqdm import tqdm 8 | from copy import deepcopy 9 | from ..utils.eval import downstream_eval, aggregate_eval_results 10 | from ..utils.data import XDict, TranscriptomicDataset 11 | from typing import List, Literal, Union 12 | from .experimental import symbol_to_ensembl 13 | from torch.utils.data import DataLoader 14 | import warnings 15 | from . import Pipeline, load_pretrain 16 | from sklearn.metrics.cluster import adjusted_rand_score, normalized_mutual_info_score 17 | 18 | class CellEmbeddingPipeline(Pipeline): 19 | def __init__(self, 20 | pretrain_prefix: str, 21 | pretrain_directory: str = './ckpt', 22 | ): 23 | super().__init__(pretrain_prefix, {'head_type': 'embedder'}, pretrain_directory) 24 | self.label_encoders = None 25 | 26 | def fit(self, adata: ad.AnnData, 27 | train_config: dict = None, 28 | split_field: str = None, # A field in adata.obs for representing train-test split 29 | train_split: str = None, # A specific split where labels can be utilized for training 30 | valid_split: str = None, # A specific split where labels can be utilized for validation 31 | covariate_fields: List[str] = None, # A list of fields in adata.obs that contain cellular covariates 32 | label_fields: List[str] = None, # A list of fields in adata.obs that contain cell labels 33 | batch_gene_list: dict = None, # A dictionary that contains batch and gene list pairs 34 | ensembl_auto_conversion: bool = True, 35 | # A bool value indicating whether the function automativally convert symbols to ensembl id 36 | device: Union[str, torch.device] = 'cpu' 37 | ): 38 | raise NotImplementedError('Currently CellPLM only supports zero shot embedding instead of fine-tuning.') 39 | 40 | def predict(self, adata: ad.AnnData, 41 | inference_config: dict = None, 42 | covariate_fields: List[str] = None, 43 | batch_gene_list: dict = None, 44 | ensembl_auto_conversion: bool = True, 45 | device: Union[str, torch.device] = 'cpu' 46 | ): 47 | if inference_config and 'batch_size' in inference_config: 48 | batch_size = inference_config['batch_size'] 49 | else: 50 | batch_size = 0 51 | if covariate_fields: 52 | warnings.warn('`covariate_fields` argument is ignored in CellEmbeddingPipeline.') 53 | if batch_gene_list: 54 | warnings.warn('`batch_gene_list` argument is ignored in CellEmbeddingPipeline.') 55 | return self._inference(adata, batch_size, device, ensembl_auto_conversion) 56 | 57 | def _inference(self, adata: ad.AnnData, 58 | batch_size: int = 0, 59 | device: Union[str, torch.device] = 'cpu', 60 | ensembl_auto_conversion: bool = True): 61 | self.model.to(device) 62 | adata = self.common_preprocess(adata, 0, covariate_fields=None, ensembl_auto_conversion=ensembl_auto_conversion) 63 | print(f'After filtering, {adata.shape[1]} genes remain.') 64 | dataset = TranscriptomicDataset(adata, order_required=True) 65 | dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=0) 66 | order_list = [] 67 | if batch_size <= 0: 68 | batch_size = adata.shape[0] 69 | 70 | with torch.no_grad(): 71 | self.model.eval() 72 | pred = [] 73 | for i, data_dict in enumerate(dataloader): 74 | idx = torch.arange(data_dict['x_seq'].shape[0]) 75 | for j in range(0, len(idx), batch_size): 76 | if len(idx) - j < batch_size: 77 | cur = idx[j:] 78 | else: 79 | cur = idx[j:j + batch_size] 80 | input_dict = {} 81 | for k in data_dict: 82 | if k == 'x_seq': 83 | input_dict[k] = data_dict[k].index_select(0, cur).to(device) 84 | elif k not in ['gene_list', 'split']: 85 | input_dict[k] = data_dict[k][cur].to(device) 86 | x_dict = XDict(input_dict) 87 | out_dict, _ = self.model(x_dict, data_dict['gene_list']) 88 | order_list.append(input_dict['order_list']) 89 | pred.append(out_dict['pred'])#[input_dict['order_list']]) 90 | order = torch.cat(order_list) 91 | order.scatter_(0, order.clone(), torch.arange(order.shape[0]).to(order.device)) 92 | pred = torch.cat(pred) 93 | pred = pred[order] 94 | return pred 95 | 96 | def score(self, adata: ad.AnnData, 97 | evaluation_config: dict = None, 98 | split_field: str = None, 99 | target_split: str = 'test', 100 | covariate_fields: List[str] = None, 101 | label_fields: List[str] = None, 102 | batch_gene_list: dict = None, 103 | ensembl_auto_conversion: bool = True, 104 | device: Union[str, torch.device] = 'cpu' 105 | ): 106 | if evaluation_config and 'batch_size' in evaluation_config: 107 | batch_size = evaluation_config['batch_size'] 108 | else: 109 | batch_size = 0 110 | if len(label_fields) != 1: 111 | raise NotImplementedError( 112 | f'`label_fields` containing multiple labels (f{len(label_fields)}) is not implemented for evaluation of cell embedding pipeline. Please raise an issue on Github for further support.') 113 | if split_field: 114 | warnings.warn('`split_field` argument is ignored in CellEmbeddingPipeline.') 115 | if target_split: 116 | warnings.warn('`target_split` argument is ignored in CellEmbeddingPipeline.') 117 | if covariate_fields: 118 | warnings.warn('`covariate_fields` argument is ignored in CellEmbeddingPipeline.') 119 | if batch_gene_list: 120 | warnings.warn('`batch_gene_list` argument is ignored in CellEmbeddingPipeline.') 121 | 122 | adata = adata.copy() 123 | pred = self._inference(adata, batch_size, device) 124 | adata.obsm['emb'] = pred.cpu().numpy() 125 | if 'method' in evaluation_config and evaluation_config['method'] == 'rapids': 126 | sc.pp.neighbors(adata, use_rep='emb', method='rapids') 127 | else: 128 | sc.pp.neighbors(adata, use_rep='emb') 129 | best_ari = -1 130 | best_nmi = -1 131 | for res in range(1, 15, 1): 132 | res = res / 10 133 | if 'method' in evaluation_config and evaluation_config['method'] == 'rapids': 134 | import rapids_singlecell as rsc 135 | rsc.tl.leiden(adata, resolution=res, key_added='leiden') 136 | else: 137 | sc.tl.leiden(adata, resolution=res, key_added='leiden') 138 | ari_score = adjusted_rand_score(adata.obs['leiden'].to_numpy(), adata.obs[label_fields[0]].to_numpy()) 139 | if ari_score > best_ari: 140 | best_ari = ari_score 141 | nmi_score = normalized_mutual_info_score(adata.obs['leiden'].to_numpy(), adata.obs[label_fields[0]].to_numpy()) 142 | if nmi_score > best_nmi: 143 | best_nmi = nmi_score 144 | return {'ari': best_ari, 'nmi': best_nmi} 145 | 146 | 147 | 148 | 149 | -------------------------------------------------------------------------------- /CellPLM/pipeline/cell_type_annotation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import scanpy as sc 5 | import anndata as ad 6 | from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau 7 | from tqdm import tqdm 8 | from copy import deepcopy 9 | from ..utils.eval import downstream_eval, aggregate_eval_results 10 | from ..utils.data import XDict, TranscriptomicDataset 11 | from typing import List, Union 12 | from .experimental import symbol_to_ensembl 13 | from torch.utils.data import DataLoader 14 | import warnings 15 | from . import Pipeline, load_pretrain 16 | 17 | CellTypeAnnotationDefaultModelConfig = { 18 | 'drop_node_rate': 0.3, 19 | 'dec_layers': 1, 20 | 'model_dropout': 0.5, 21 | 'mask_node_rate': 0.75, 22 | 'mask_feature_rate': 0.25, 23 | 'dec_mod': 'mlp', 24 | 'latent_mod': 'ae', 25 | 'head_type': 'annotation', 26 | 'max_batch_size': 70000, 27 | } 28 | 29 | CellTypeAnnotationDefaultPipelineConfig = { 30 | 'es': 200, 31 | 'lr': 5e-3, 32 | 'wd': 1e-7, 33 | 'scheduler': 'plat', 34 | 'epochs': 2000, 35 | 'max_eval_batch_size': 100000, 36 | 'hvg': 3000, 37 | 'patience': 25, 38 | 'workers': 0, 39 | } 40 | def inference(model, dataloader, split, device, batch_size, eval_dict, label_fields=None, order_required=False): 41 | if order_required and split: 42 | warnings.warn('When cell order required to be preserved, dataset split will be ignored.') 43 | 44 | with torch.no_grad(): 45 | model.eval() 46 | epoch_loss = [] 47 | order_list = [] 48 | pred = [] 49 | label = [] 50 | for i, data_dict in enumerate(dataloader): 51 | if not order_required and split and np.sum(data_dict['split'] == split) == 0: 52 | continue 53 | 54 | idx = torch.arange(data_dict['x_seq'].shape[0]) 55 | if split: 56 | data_dict['loss_mask'] = torch.from_numpy((data_dict['split'] == split).values).bool() 57 | else: 58 | data_dict['loss_mask'] = torch.ones(data_dict['x_seq'].shape[0]).bool() 59 | if label_fields: 60 | data_dict['label'] = data_dict[label_fields[0]] 61 | for j in range(0, len(idx), batch_size): 62 | if len(idx) - j < batch_size: 63 | cur = idx[j:] 64 | else: 65 | cur = idx[j:j + batch_size] 66 | input_dict = {} 67 | for k in data_dict: 68 | if k =='x_seq': 69 | input_dict[k] = data_dict[k].index_select(0, cur).to(device) 70 | elif k not in ['gene_list', 'split']: 71 | input_dict[k] = data_dict[k][cur].to(device) 72 | x_dict = XDict(input_dict) 73 | out_dict, loss = model(x_dict, data_dict['gene_list']) 74 | if 'label' in input_dict: 75 | epoch_loss.append(loss.item()) 76 | label.append(out_dict['label']) 77 | if order_required: 78 | order_list.append(input_dict['order_list']) 79 | pred.append(out_dict['pred']) 80 | 81 | pred = torch.cat(pred) 82 | if order_required: 83 | order = torch.cat(order_list) 84 | order.scatter_(0, order.clone(), torch.arange(order.shape[0]).to(order.device)) 85 | pred = pred[order] 86 | 87 | if len(epoch_loss) == 0: 88 | return {'pred': pred} 89 | else: 90 | scores = downstream_eval('annotation', pred, torch.cat(label), 91 | **eval_dict) 92 | return {'pred': pred, 93 | 'loss': sum(epoch_loss)/len(epoch_loss), 94 | 'metrics': scores} 95 | 96 | class CellTypeAnnotationPipeline(Pipeline): 97 | def __init__(self, 98 | pretrain_prefix: str, 99 | overwrite_config: dict = CellTypeAnnotationDefaultModelConfig, 100 | pretrain_directory: str = './ckpt', 101 | ): 102 | assert 'out_dim' in overwrite_config, '`out_dim` must be provided in `overwrite_config` for initializing a cell type annotation pipeline. ' 103 | super().__init__(pretrain_prefix, overwrite_config, pretrain_directory) 104 | self.eval_dict = {'num_classes': overwrite_config['out_dim']} 105 | self.label_encoders = None 106 | 107 | def fit(self, adata: ad.AnnData, 108 | train_config: dict = None, 109 | split_field: str = None, 110 | train_split: str = 'train', 111 | valid_split: str = 'valid', 112 | covariate_fields: List[str] = None, 113 | label_fields: List[str] = None, 114 | batch_gene_list: dict = None, 115 | ensembl_auto_conversion: bool = True, 116 | device: Union[str, torch.device] = 'cpu', 117 | ): 118 | config = CellTypeAnnotationDefaultPipelineConfig.copy() 119 | if train_config: 120 | config.update(train_config) 121 | self.model.to(device) 122 | assert not self.fitted, 'Current pipeline is already fitted and does not support continual training. Please initialize a new pipeline.' 123 | if batch_gene_list is not None: 124 | raise NotImplementedError('Batch specific gene set is not implemented for cell type annotation pipeline. Please raise an issue on Github for further support.') 125 | if len(label_fields) != 1: 126 | raise NotImplementedError(f'`label_fields` containing multiple labels (f{len(label_fields)}) is not implemented for cell type annotation pipeline. Please raise an issue on Github for further support.') 127 | assert (split_field and train_split and valid_split), '`train_split` and `valid_split` must be specified.' 128 | adata = self.common_preprocess(adata, config['hvg'], covariate_fields, ensembl_auto_conversion) 129 | print(f'After filtering, {adata.shape[1]} genes remain.') 130 | dataset = TranscriptomicDataset(adata, split_field, covariate_fields, label_fields) 131 | self.label_encoders = dataset.label_encoders 132 | dataloader = DataLoader(dataset, batch_size=None, shuffle=True, num_workers=config['workers']) 133 | optim = torch.optim.AdamW([ 134 | {'params': list(self.model.embedder.parameters()), 'lr': config['lr'] * 0.1, 135 | 'weight_decay': 1e-10}, 136 | {'params': list(self.model.encoder.parameters()) + list(self.model.head.parameters()) + list( 137 | self.model.latent.parameters()), 'lr': config['lr'], 138 | 'weight_decay': config['wd']}, 139 | ]) 140 | if config['scheduler'] == 'plat': 141 | scheduler = ReduceLROnPlateau(optim, 'min', patience=config['patience'], factor=0.95) 142 | else: 143 | scheduler = None 144 | 145 | train_loss = [] 146 | valid_loss = [] 147 | valid_metric = [] 148 | final_epoch = -1 149 | best_dict = None 150 | 151 | for epoch in tqdm(range(config['epochs'])): 152 | self.model.train() 153 | epoch_loss = [] 154 | train_scores = [] 155 | 156 | if epoch < 30: 157 | for param_group in optim.param_groups[1:]: 158 | param_group['lr'] = config['lr'] * (epoch + 1) / 30 159 | 160 | for i, data_dict in enumerate(dataloader): 161 | input_dict = data_dict.copy() 162 | del input_dict['gene_list'], input_dict['split'] 163 | input_dict['loss_mask'] = torch.from_numpy((data_dict['split'] == train_split).values).bool() 164 | input_dict['label'] = input_dict[label_fields[0]] # Currently only support annotating one label 165 | for k in input_dict: 166 | input_dict[k] = input_dict[k].to(device) 167 | x_dict = XDict(input_dict) 168 | out_dict, loss = self.model(x_dict, data_dict['gene_list']) 169 | with torch.no_grad(): 170 | train_scores.append( 171 | downstream_eval('annotation', out_dict['pred'], out_dict['label'], **self.eval_dict)) 172 | 173 | optim.zero_grad() 174 | loss.backward() 175 | nn.utils.clip_grad_norm_(self.model.parameters(), 2.0) 176 | optim.step() 177 | epoch_loss.append(loss.item()) 178 | 179 | if config['scheduler'] == 'plat': 180 | scheduler.step(loss.item()) 181 | 182 | train_loss.append(sum(epoch_loss) / len(epoch_loss)) 183 | train_scores = aggregate_eval_results(train_scores) 184 | result_dict = inference(self.model, dataloader, valid_split, device, 185 | config['max_eval_batch_size'], self.eval_dict, label_fields) 186 | valid_scores = result_dict['metrics'] 187 | valid_loss.append(result_dict['loss']) 188 | valid_metric.append(valid_scores['f1_score']) 189 | 190 | print(f'Epoch {epoch} | Train loss: {train_loss[-1]:.4f} | Valid loss: {valid_loss[-1]:.4f}') 191 | print( 192 | f'Train ACC: {train_scores["acc"]:.4f} | Valid ACC: {valid_scores["acc"]:.4f} | ' 193 | f'Train f1: {train_scores["f1_score"]:.4f} | Valid f1: {valid_scores["f1_score"]:.4f} | ' 194 | f'Train pre: {train_scores["precision"]:.4f} | Valid pre: {valid_scores["precision"]:.4f}') 195 | 196 | if max(valid_metric) == valid_metric[-1]: 197 | best_dict = deepcopy(self.model.state_dict()) 198 | final_epoch = epoch 199 | 200 | if max(valid_metric) != max(valid_metric[-config['es']:]): 201 | print(f'Early stopped. Best validation performance achieved at epoch {final_epoch}.') 202 | break 203 | 204 | assert best_dict, 'Best state dict was not stored. Please report this issue on Github.' 205 | self.model.load_state_dict(best_dict) 206 | self.fitted = True 207 | return self 208 | 209 | def predict(self, adata: ad.AnnData, 210 | inference_config: dict = None, 211 | covariate_fields: List[str] = None, 212 | batch_gene_list: dict = None, 213 | ensembl_auto_conversion: bool = True, 214 | device: Union[str, torch.device] = 'cpu', 215 | ): 216 | config = CellTypeAnnotationDefaultPipelineConfig.copy() 217 | if inference_config: 218 | config.update(inference_config) 219 | self.model.to(device) 220 | assert self.fitted, 'Cell type annotation pipeline does not support zero shot setting. Please fine-tune the model on downstream datasets before inference.' 221 | if batch_gene_list is not None: 222 | raise NotImplementedError('Batch specific gene set is not implemented for cell type annotation pipeline. Please raise an issue on Github for further support.') 223 | adata = self.common_preprocess(adata, config['hvg'], covariate_fields, ensembl_auto_conversion) 224 | print(f'After filtering, {adata.shape[1]} genes remain.') 225 | dataset = TranscriptomicDataset(adata, None, covariate_fields, order_required=True) 226 | dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=config['workers']) 227 | return inference(self.model, dataloader, None, device, 228 | config['max_eval_batch_size'], self.eval_dict, order_required=True)['pred'] 229 | 230 | def score(self, adata: ad.AnnData, 231 | evaluation_config: dict = None, 232 | split_field: str = None, 233 | target_split: str = 'test', 234 | covariate_fields: List[str] = None, 235 | label_fields: List[str] = None, 236 | batch_gene_list: dict = None, 237 | ensembl_auto_conversion: bool = True, 238 | device: Union[str, torch.device] = 'cpu', 239 | ): 240 | config = CellTypeAnnotationDefaultPipelineConfig.copy() 241 | if evaluation_config: 242 | config.update(evaluation_config) 243 | self.model.to(device) 244 | assert self.fitted, 'Cell type annotation pipeline does not support zero shot setting. Please fine-tune the model on downstream datasets before inference.' 245 | if batch_gene_list is not None: 246 | raise NotImplementedError('Batch specific gene set is not implemented for cell type annotation pipeline. Please raise an issue on Github for further support.') 247 | if len(label_fields) != 1: 248 | raise NotImplementedError( 249 | f'`label_fields` containing multiple labels (f{len(label_fields)}) is not implemented for cell type annotation pipeline. Please raise an issue on Github for further support.') 250 | if target_split: 251 | assert split_field, '`split_filed` must be provided when `target_split` is specified.' 252 | adata = self.common_preprocess(adata, config['hvg'], covariate_fields, ensembl_auto_conversion, ) 253 | print(f'After filtering, {adata.shape[1]} genes remain.') 254 | dataset = TranscriptomicDataset(adata, split_field, covariate_fields, label_fields, label_encoders=self.label_encoders) 255 | dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=config['workers']) 256 | return inference(self.model, dataloader, target_split, device, 257 | config['max_eval_batch_size'], self.eval_dict, label_fields)['metrics'] -------------------------------------------------------------------------------- /CellPLM/pipeline/experimental.py: -------------------------------------------------------------------------------- 1 | def ensembl_to_symbol(gene_list): 2 | import mygene 3 | mg = mygene.MyGeneInfo() 4 | return mg.querymany(gene_list, scopes='ensembl.gene', fields='symbol', as_dataframe=True, 5 | species='human').reset_index().drop_duplicates(subset='query')['symbol'].fillna('0').tolist() 6 | 7 | def symbol_to_ensembl(gene_list): 8 | import mygene 9 | mg = mygene.MyGeneInfo() 10 | return mg.querymany(gene_list, scopes='symbol', fields='ensembl.gene', as_dataframe=True, 11 | species='human').reset_index().drop_duplicates(subset='query')['ensembl.gene'].fillna('0').tolist() -------------------------------------------------------------------------------- /CellPLM/pipeline/imputation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import scanpy as sc 5 | import anndata as ad 6 | from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau 7 | from tqdm import tqdm 8 | from copy import deepcopy 9 | from ..utils.eval import downstream_eval, aggregate_eval_results, imputation_eval 10 | from ..utils.data import XDict, TranscriptomicDataset 11 | from typing import List, Literal, Union 12 | from .experimental import symbol_to_ensembl 13 | from torch.utils.data import DataLoader 14 | import warnings 15 | from . import Pipeline, load_pretrain 16 | from sklearn.metrics.cluster import adjusted_rand_score, normalized_mutual_info_score 17 | import scipy.sparse 18 | 19 | ImputationDefaultModelConfig = { 20 | 'objective': 'imputation', 21 | 'mask_node_rate': 0.95, 22 | 'mask_feature_rate': 0.25, 23 | 'max_batch_size': 70000, 24 | } 25 | 26 | ImputationDefaultPipelineConfig = { 27 | 'lr': 5e-4, 28 | 'wd': 1e-6, 29 | 'scheduler': 'plat', 30 | 'epochs': 100, 31 | 'max_eval_batch_size': 100000, 32 | 'patience': 5, 33 | 'workers': 0, 34 | } 35 | 36 | def inference(model, dataloader, split, device, batch_size, order_required=False): 37 | if order_required and split: 38 | warnings.warn('When cell order required to be preserved, dataset split will be ignored.') 39 | 40 | with torch.no_grad(): 41 | model.eval() 42 | epoch_loss = [] 43 | order_list = [] 44 | pred = [] 45 | for i, data_dict in enumerate(dataloader): 46 | if not order_required and split and np.sum(data_dict['split'] == split) == 0: 47 | continue 48 | 49 | idx = torch.arange(data_dict['x_seq'].shape[0]) 50 | for j in range(0, len(idx), batch_size): 51 | if len(idx) - j < batch_size: 52 | cur = idx[j:] 53 | else: 54 | cur = idx[j:j + batch_size] 55 | input_dict = {} 56 | for k in data_dict: 57 | if k == 'x_seq': 58 | input_dict[k] = data_dict[k].index_select(0, cur).to(device) 59 | elif k == 'gene_mask': 60 | input_dict[k] = data_dict[k].to(device) 61 | elif k not in ['gene_list', 'split']: 62 | input_dict[k] = data_dict[k][cur].to(device) 63 | x_dict = XDict(input_dict) 64 | out_dict, loss = model(x_dict, data_dict['gene_list']) 65 | epoch_loss.append(loss.item()) 66 | pred.append(out_dict['pred']) 67 | if order_required: 68 | order_list.append(input_dict['order_list']) 69 | pred = torch.cat(pred) 70 | if order_required: 71 | order = torch.cat(order_list) 72 | order.scatter_(0, order.clone(), torch.arange(order.shape[0]).to(order.device)) 73 | pred = pred[order] 74 | 75 | return {'pred': pred, 76 | 'loss': sum(epoch_loss) / len(epoch_loss)} 77 | 78 | class ImputationPipeline(Pipeline): 79 | def __init__(self, 80 | pretrain_prefix: str, 81 | overwrite_config: dict = ImputationDefaultModelConfig, 82 | pretrain_directory: str = './ckpt', 83 | ): 84 | super().__init__(pretrain_prefix, overwrite_config, pretrain_directory) 85 | self.label_encoders = None 86 | 87 | def fit(self, adata: ad.AnnData, 88 | train_config: dict = None, 89 | split_field: str = None, 90 | train_split: str = 'train', 91 | valid_split: str = 'valid', 92 | covariate_fields: List[str] = None, 93 | label_fields: List[str] = None, 94 | batch_gene_list: dict = None, 95 | ensembl_auto_conversion: bool = True, 96 | device: Union[str, torch.device] = 'cpu', 97 | ): 98 | config = ImputationDefaultPipelineConfig.copy() 99 | if train_config: 100 | config.update(train_config) 101 | self.model.to(device) 102 | assert not self.fitted, 'Current pipeline is already fitted and does not support continual training. Please initialize a new pipeline.' 103 | if label_fields: 104 | warnings.warn('`label_fields` argument is ignored in ImputationPipeline.') 105 | # adata = ad.concat([query_data, reference_data], join='outer', label='ref', keys=[False, True]) 106 | adata = self.common_preprocess(adata, 0, covariate_fields, ensembl_auto_conversion=False) 107 | print(f'After filtering, {adata.shape[1]} genes remain.') 108 | dataset = TranscriptomicDataset(adata, split_field, covariate_fields, label_fields, batch_gene_list) 109 | dataloader = DataLoader(dataset, batch_size=None, shuffle=True, num_workers=config['workers']) 110 | optim = torch.optim.AdamW(self.model.parameters(), lr=config['lr'], weight_decay=config['wd']) 111 | 112 | if config['scheduler'] == 'plat': 113 | scheduler = ReduceLROnPlateau(optim, 'min', patience=config['patience'], factor=0.9) 114 | else: 115 | scheduler = None 116 | 117 | train_loss = [] 118 | valid_loss = [] 119 | final_epoch = -1 120 | best_dict = None 121 | 122 | for epoch in tqdm(range(config['epochs'])): 123 | self.model.train() 124 | epoch_loss = [] 125 | 126 | if epoch < 5: 127 | for param_group in optim.param_groups: 128 | param_group['lr'] = config['lr'] * (epoch + 1) / 5 129 | 130 | for i, data_dict in enumerate(dataloader): 131 | if split_field and np.sum(data_dict['split'] == train_split) == 0: 132 | continue 133 | input_dict = data_dict.copy() 134 | del input_dict['gene_list'], input_dict['split'] 135 | for k in input_dict: 136 | input_dict[k] = input_dict[k].to(device) 137 | x_dict = XDict(input_dict) 138 | out_dict, loss = self.model(x_dict, data_dict['gene_list']) 139 | optim.zero_grad() 140 | loss.backward() 141 | nn.utils.clip_grad_norm_(self.model.parameters(), 2.0) 142 | optim.step() 143 | epoch_loss.append(loss.item()) 144 | 145 | train_loss.append(sum(epoch_loss) / len(epoch_loss)) 146 | if config['scheduler'] == 'plat': 147 | scheduler.step(train_loss[-1]) 148 | result_dict = inference(self.model, dataloader, valid_split, device, 149 | config['max_eval_batch_size']) 150 | valid_loss.append(result_dict['loss']) 151 | 152 | print(f'Epoch {epoch} | Train loss: {train_loss[-1]:.4f} | Valid loss: {valid_loss[-1]:.4f}') 153 | 154 | if min(valid_loss) == valid_loss[-1]: 155 | best_dict = deepcopy(self.model.state_dict()) 156 | # final_epoch = epoch 157 | 158 | # if min(valid_loss) != min(valid_loss[-config['es']:]): 159 | # print(f'Early stopped. Best validation performance achieved at epoch {final_epoch}.') 160 | # break 161 | 162 | assert best_dict, 'Best state dict was not stored. Please report this issue on Github.' 163 | self.model.load_state_dict(best_dict) 164 | self.fitted = True 165 | return self 166 | 167 | def predict(self, adata: ad.AnnData, 168 | inference_config: dict = None, 169 | covariate_fields: List[str] = None, 170 | batch_gene_list: dict = None, 171 | ensembl_auto_conversion: bool = True, 172 | device: Union[str, torch.device] = 'cpu',): 173 | 174 | self.model.to(device) 175 | config = ImputationDefaultPipelineConfig.copy() 176 | if inference_config: 177 | config.update(inference_config) 178 | adata = self.common_preprocess(adata, 0, covariate_fields, ensembl_auto_conversion) 179 | print(f'After filtering, {adata.shape[1]} genes remain.') 180 | dataset = TranscriptomicDataset(adata, None, order_required=True) 181 | dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=0) 182 | pred = inference(self.model, dataloader, None, device, 183 | config['max_eval_batch_size'], order_required=True)['pred'] 184 | if 'target_genes' in inference_config: 185 | target_mask = torch.tensor( 186 | [adata.var.index.get_loc(g) for g in inference_config['target_genes']]).long().to(pred.device) 187 | pred = pred[:, target_mask] 188 | return pred 189 | 190 | def score(self, adata: ad.AnnData, 191 | evaluation_config: dict = None, 192 | split_field: str = None, 193 | target_split: str = None, 194 | covariate_fields: List[str] = None, 195 | label_fields: List[str] = None, 196 | batch_gene_list: dict = None, 197 | ensembl_auto_conversion: bool = True, 198 | device: Union[str, torch.device] = 'cpu', 199 | ): 200 | self.model.to(device) 201 | config = ImputationDefaultPipelineConfig.copy() 202 | if evaluation_config: 203 | config.update(evaluation_config) 204 | adata = self.common_preprocess(adata, 0, covariate_fields, ensembl_auto_conversion) 205 | print(f'After filtering, {adata.shape[1]} genes remain.') 206 | dataset = TranscriptomicDataset(adata, None, order_required=True) 207 | dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=0) 208 | pred = inference(self.model, dataloader, None, device, 209 | config['max_eval_batch_size'], order_required=True)['pred'] 210 | if 'target_genes' in evaluation_config: 211 | target_mask = torch.tensor( 212 | [adata.var.index.get_loc(g) for g in evaluation_config['target_genes']]).long().to(pred.device) 213 | pred = pred[:, target_mask] 214 | if len(label_fields) != 1: 215 | raise NotImplementedError( 216 | f'`label_fields` containing multiple labels (f{len(label_fields)}) is not implemented for imputation pipeline. Please raise an issue on Github for further support.') 217 | if scipy.sparse.issparse(adata.obsm[label_fields[0]]): 218 | labels = torch.from_numpy(adata.obsm[label_fields[0]].toarray()).to(pred.device) 219 | else: 220 | labels = torch.from_numpy(adata.obsm[label_fields[0]]).to(pred.device) 221 | assert labels.shape[1] == pred.shape[ 222 | 1], f'Inconsistent number of genes between prediction ({pred.shape[1]}) and labels ({labels.shape[1]}). Please check: (1) Correct target gene list is provided in evaluation_config["target_genes"]. (2) Correct ground-truth gene expressions are provided in .obsm[label_fields[0]].' 223 | if split_field and target_split: 224 | labels = labels[adata.obs[split_field]==target_split] 225 | pred = pred[adata.obs[split_field]==target_split] 226 | # size_factor = 1e4 / (labels.sum(1, keepdims=True) + torch.from_numpy(adata.X.sum(1)).to(pred.device)) 227 | # size_factor[size_factor.isinf()] == 0 228 | # labels = size_factor * labels 229 | # pred = size_factor * pred 230 | return imputation_eval(torch.log1p(pred), torch.log1p(labels)) 231 | 232 | 233 | 234 | 235 | 236 | 237 | -------------------------------------------------------------------------------- /CellPLM/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | import random 5 | import os 6 | import logging 7 | 8 | def set_seed(rndseed, cuda: bool = True, extreme_mode: bool = False): 9 | os.environ["PYTHONHASHSEED"] = str(rndseed) 10 | random.seed(rndseed) 11 | np.random.seed(rndseed) 12 | torch.manual_seed(rndseed) 13 | if cuda: 14 | torch.cuda.manual_seed(rndseed) 15 | torch.cuda.manual_seed_all(rndseed) 16 | if extreme_mode: 17 | torch.backends.cudnn.benchmark = False 18 | torch.backends.cudnn.deterministic = True 19 | # dgl.seed(rndseed) 20 | # dgl.random.seed(rndseed) 21 | logging.info(f"Setting global random seed to {rndseed}") 22 | 23 | class RMSNorm(nn.Module): 24 | def __init__(self, dim: int, eps: float = 1e-6): 25 | super().__init__() 26 | self.eps = eps 27 | self.weight = nn.Parameter(torch.ones(dim)) 28 | 29 | def _norm(self, x): 30 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 31 | 32 | def forward(self, x): 33 | output = self._norm(x.float()).type_as(x) 34 | return output * self.weight 35 | 36 | class DSBNNorm(nn.Module): 37 | def __init__(self, dim: int, domain_num: int, domain_label: str = 'dataset', eps: float = 1e-6, flip_rate=0.3): 38 | super().__init__() 39 | self.eps = eps 40 | self.domain_label = domain_label 41 | self.bns = nn.ModuleList([nn.BatchNorm1d(dim) for _ in range(domain_num+1)]) 42 | self.flip_rate = flip_rate 43 | 44 | def forward(self, xdict): 45 | h = xdict['h'] 46 | if self.training and random.random() 0, 'No available validation set.' 198 | return self._partition(self.val_idx) 199 | 200 | def to_ddp(self, n_partitions, max_batch_size=2000, val_num=0, val_idx=None): 201 | assert not self.isddp, 'Dataset is already ddp dataset.' 202 | 203 | if val_num > 0: 204 | if not val_idx: 205 | ids = np.random.permutation(len(self.batch_metadata['batch_id'])) 206 | self.val_idx = ids[:val_num] 207 | self.train_idx = ids[val_num:] 208 | else: 209 | self.train_idx = np.array( 210 | [i for i in range(len(self.batch_metadata['batch_id'])) if i not in set(val_idx)]) 211 | self.val_idx = np.array(val_idx) 212 | self.partitions = balanced_partition(np.array(self.batch_metadata['batch_size'])[self.train_idx], 213 | n_partitions, 214 | max_batch_size) 215 | new_partitions = [[] for _ in range(n_partitions)] 216 | for i, p in enumerate(self.partitions): 217 | for j in p: 218 | new_partitions[i].append(self.train_idx[j]) 219 | self.partitions = new_partitions 220 | 221 | else: 222 | self.train_idx = np.arange(len(self.batch_metadata['batch_id'])) 223 | self.val_idx = np.array([]) 224 | self.partitions = balanced_partition(self.batch_metadata['batch_size'], n_partitions, max_batch_size) 225 | self.isddp = True 226 | 227 | 228 | class SCPartitionDataset(Dataset): 229 | def __init__(self, batch_metadata, tensor_dir, idx, gene_set=None): 230 | self.batch_metadata = {} 231 | for k in batch_metadata: 232 | self.batch_metadata[k] = [batch_metadata[k][i] for i in idx] 233 | self.tensor_dir = tensor_dir 234 | with open(f'{tensor_dir}/dataset_metadata.json') as f: 235 | self.dataset_metadata = json.load(f) 236 | 237 | self.bid2did = dict(zip(self.batch_metadata['batch_id'], self.batch_metadata['dataset_id'])) 238 | self.did2gene = dict(zip(self.dataset_metadata['id'], self.dataset_metadata['gene_list'])) 239 | 240 | if gene_set: 241 | gene_mask = [] 242 | for i in self.dataset_metadata['gene_list']: 243 | i = set(i) 244 | gene_mask.append(torch.tensor([j in i for j in gene_set]).bool()) 245 | self.did2mask = dict(zip(self.dataset_metadata['id'], gene_mask)) 246 | else: 247 | self.did2mask = None 248 | 249 | def __len__(self): 250 | return len(self.batch_metadata['batch_id']) # //10 251 | 252 | def __getitem__(self, idx): 253 | tensor_path = os.path.join(self.tensor_dir, str(self.batch_metadata['batch_id'][idx]) + '.pt') 254 | seq = torch.load(tensor_path).coalesce() 255 | if self.batch_metadata['platform'][idx] in SPATIAL_PLATFORM_LIST: 256 | coord = torch.load(os.path.join(self.tensor_dir, str(self.batch_metadata['batch_id'][idx]) + '.coord.pt')) 257 | else: 258 | coord = torch.zeros([seq.shape[0], 2]).float() - 1 259 | if seq.shape[0] > 2000: 260 | randid = torch.randperm(seq.shape[0]) 261 | coord = coord[randid[:2000]] 262 | seq = seq.index_select(0, randid[:2000]).coalesce() 263 | batch_id = torch.zeros([seq.shape[0]]).long() + int(self.batch_metadata['batch_id'][idx]) 264 | dataset_id = torch.zeros([seq.shape[0]]).long() + int(self.batch_metadata['dataset_id'][idx]) 265 | gene_mask = self.get_gene_mask(self.batch_metadata['dataset_id'][idx]) if self.did2mask else torch.ones( 266 | [seq.shape[1]]).bool() 267 | seq = [seq.indices(), seq.values(), torch.tensor(seq.shape)] 268 | return seq, coord, batch_id, dataset_id, gene_mask 269 | 270 | def get_gene_list(self, dataset_id): 271 | return self.did2gene[dataset_id] 272 | 273 | def get_gene_mask(self, dataset_id): 274 | assert self.did2mask, 'gene_set was not passed when created dataset.' 275 | return self.did2mask[dataset_id] 276 | 277 | 278 | class XDict(dict): 279 | def __init__(self, *args, **kwargs): 280 | super().__init__(*args, **kwargs) 281 | self._num = self[list(self.keys())[0]].shape[0] 282 | 283 | # No longer required 284 | # def check(self): 285 | # for k, v in self.items(): 286 | # assert isinstance(v, torch.Tensor), f'{k} is not a torch.Tensor' 287 | # assert v.shape[0] == self._num, f'{k} contains {v.shape[0]} samples. Expected: f{self._num}' 288 | 289 | def size(self): 290 | warnings.warn("Deprecated function: Xdict.size().", DeprecationWarning) 291 | return self._num 292 | 293 | # Not usable for sparse data 294 | # def drop(self, ratio): 295 | # drop_num = int(self._num * ratio) 296 | # keep_idx = np.random.permutation(self._num)[drop_num:] 297 | # for k, v in self.items(): 298 | # self[k] = v[keep_idx] 299 | # return self 300 | 301 | 302 | def clean_batches(data): 303 | # Remove batch with less than 1000 cells 304 | sc.pp.filter_cells(data, min_counts=5) 305 | remove_list = [] 306 | for b in data.obs['batch'].value_counts().reset_index().iterrows(): 307 | if b[1]['batch'] < 500: 308 | remove_list.append(b[1]['index']) 309 | data = data[~data.obs['batch'].isin(set(remove_list))] 310 | return data 311 | 312 | 313 | def balanced_partition(data, n_partitions, max_batch_size=2000): 314 | # Sort batches 315 | if torch.is_tensor(data[0]): 316 | batch_sizes = [(i, len(batch)) for i, batch in enumerate(data)] 317 | else: 318 | batch_sizes = [(i, batch) for i, batch in enumerate(data)] 319 | batch_sizes.sort(key=lambda x: x[1], reverse=True) 320 | 321 | # inialize partitions 322 | partitions = [[] for _ in range(n_partitions)] 323 | 324 | # Fill partitions 325 | j = 0 326 | for (i, _) in batch_sizes: 327 | partitions[j].append(i) 328 | j = (j + 1) % n_partitions 329 | return partitions 330 | 331 | 332 | def stratified_sample_genes_by_sparsity(data, boundaries=None, seed=10): 333 | df = data.to_df() 334 | zero_rates = 1 - df.astype(bool).sum(axis=0) / df.shape[0] 335 | if boundaries is None: 336 | # boundaries = [0, zero_rates.mean() - zero_rates.std(), zero_rates.mean(), 337 | # min(zero_rates.mean() + zero_rates.std(), 1)] 338 | boundaries = [0, 0.75, 0.9, 0.95, 1] 339 | gene_group = pd.cut(zero_rates, boundaries, labels=False) 340 | # gene_df = pd.DataFrame({'zero_rates': zero_rates, 'gene_group': gene_group}) 341 | zero_rates = zero_rates.groupby(gene_group, group_keys=False) 342 | samples = zero_rates.apply(lambda x: x.sample(min(len(x), 25), random_state=seed)) 343 | return list(samples.index) 344 | 345 | 346 | def data_setup(adata, return_sparse=True, device='cpu'): 347 | warnings.warn("`Data_setup` function is deprecated. Use `CellPLM.pipeline` instead.", DeprecationWarning) 348 | # Data Setup 349 | order = torch.arange(adata.shape[0], device=device) 350 | lb = LabelEncoder().fit(adata.obs['batch']) 351 | batch_labels = lb.transform(adata.obs['batch']) 352 | # print(lb.classes_) 353 | seq_list = [[], [], [], []] if return_sparse else [] 354 | batch_list = [] 355 | order_list = [] 356 | dataset_list = [] 357 | coord_list = [] 358 | if adata.obs['cell_type'].dtype != int: 359 | labels = LabelEncoder().fit_transform(adata.obs['cell_type']) 360 | else: 361 | labels = adata.obs['cell_type'].values 362 | print(labels.mean()) 363 | label_list = [] 364 | dataset_label = LabelEncoder().fit_transform(adata.obs['Dataset']) 365 | for batch in range(batch_labels.max() + 1): 366 | if return_sparse: 367 | x = (adata.X[batch_labels == batch]).astype(float) 368 | x = list(map(torch.from_numpy, [x.indptr, x.indices, x.data])) + [torch.tensor(x.shape)] 369 | for i in range(4): 370 | seq_list[i].append(x[i].to(device)) 371 | else: 372 | x = torch.from_numpy(adata.X[batch_labels == batch].todense()).float() 373 | seq_list.append(x.to(device)) 374 | # x = torch.sparse_csr_tensor(x.indptr, x.indices, x.data, (x.shape[0], x.shape[1])).to_sparse().float() 375 | # seq_list.append(x) 376 | order_list.append(order[batch_labels == batch]) 377 | dataset_list.append(torch.from_numpy(dataset_label[batch_labels == batch]).long().to(device)) 378 | batch_list.append(torch.from_numpy(batch_labels[batch_labels == batch]).to(device)) 379 | if adata.obs['platform'][batch_labels == batch][0] in SPATIAL_PLATFORM_LIST: 380 | coord_list.append( 381 | torch.from_numpy(adata.obs[['x_FOV_px', 'y_FOV_px']][batch_labels == batch].values).to(device)) 382 | else: 383 | coord_list.append(torch.zeros(order_list[-1].shape[0], 2).to(device) - 1) 384 | label_list.append(torch.from_numpy(labels[batch_labels == batch].astype(int)).to(device)) 385 | del order 386 | return seq_list, batch_list, batch_labels, order_list, dataset_list, coord_list, label_list 387 | -------------------------------------------------------------------------------- /CellPLM/utils/eval.py: -------------------------------------------------------------------------------- 1 | import scanpy as sc 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from typing import List 6 | from torchmetrics.functional.classification import multiclass_f1_score, multiclass_accuracy, multiclass_precision, multiclass_recall 7 | 8 | def aggregate_eval_results(scores: List[dict]): 9 | scores_new = {} 10 | for k in scores[0].keys(): 11 | scores_new[k] = [] 12 | for t in scores: 13 | scores_new[k].append(t[k]) 14 | scores_new[k] = sum(scores_new[k]) / len(scores_new[k]) 15 | scores = scores_new 16 | return scores 17 | 18 | def downstream_eval(task, pred_labels, true_labels, num_classes=None, eval_mask=None, dim=1, 19 | normalize=True, top_de_dict=None, batch_labels=None, control_level=None, 20 | topk=20, **kwargs): 21 | if task == 'annotation': 22 | return annotation_eval(pred_labels, true_labels, num_classes) 23 | elif task == 'denoising': 24 | return denoising_eval(pred_labels, true_labels, eval_mask, normalize) 25 | elif task == 'imputation': 26 | return imputation_eval(pred_labels, true_labels, dim) 27 | elif task == 'perturbation_prediction': 28 | raise NotImplementedError("For simplicity, the perturbation evaluation is removed from the current release.") 29 | else: 30 | raise NotImplementedError(f"{task} should be chosen from ['annotation', 'denoising', 'imputation', 'perturbation_prediction']") 31 | 32 | def CountCorr(y_true, y_pred): 33 | y_true = torch.log1p(y_true) 34 | y_pred = torch.log1p(y_pred) 35 | y_true_c = y_true - torch.mean(y_true, 1)[:, None] 36 | y_pred_c = y_pred - torch.mean(y_pred, 1)[:, None] 37 | pearson = torch.mean(torch.sum(y_true_c * y_pred_c, 1) / torch.sqrt(torch.sum(y_true_c * y_true_c, 1)) / torch.sqrt( 38 | torch.sum(y_pred_c * y_pred_c, 1))) 39 | return pearson 40 | 41 | def PearsonCorr(y_true, y_pred): 42 | assert len(y_true.shape) == 2 43 | y_true_c = y_true - torch.mean(y_true, 1)[:, None] 44 | y_pred_c = y_pred - torch.mean(y_pred, 1)[:, None] 45 | pearson = torch.mean(torch.sum(y_true_c * y_pred_c, 1) / torch.sqrt(torch.sum(y_true_c * y_true_c, 1)) 46 | / torch.sqrt(torch.sum(y_pred_c * y_pred_c, 1))) 47 | return pearson 48 | 49 | def PearsonCorr1d(y_true, y_pred): 50 | assert len(y_true.shape) == 1 51 | y_true_c = y_true - torch.mean(y_true) 52 | y_pred_c = y_pred - torch.mean(y_pred) 53 | pearson = torch.mean(torch.sum(y_true_c * y_pred_c) / torch.sqrt(torch.sum(y_true_c * y_true_c)) 54 | / torch.sqrt(torch.sum(y_pred_c * y_pred_c))) 55 | return pearson 56 | 57 | 58 | def clustering_eval(adata, cluster_key='leiden', label_key='cell_type'): 59 | raise NotImplementedError("For simplicity, rapids_singlecell was removed from the dependency. Therefore currently the clustering evaluation is not available.") 60 | import rapids_singlecell as rsc 61 | from scib.metrics.ari import ari 62 | from scib.metrics.nmi import nmi 63 | print('Start building knn.') 64 | sc.pp.neighbors(adata, use_rep='X_cellbert', method='rapids') 65 | best_ari = -1 66 | best_nmi = -1 67 | for res in range(1, 15, 1): 68 | res = res / 10 69 | rsc.tl.leiden(adata, resolution=res, key_added=cluster_key) 70 | ari_score = ari(adata, cluster_key=cluster_key, label_key=label_key) 71 | if ari_score > best_ari: 72 | best_ari = ari_score 73 | nmi_score = nmi(adata, cluster_key=cluster_key, label_key=label_key) 74 | if nmi_score > best_nmi: 75 | best_nmi = nmi_score 76 | return {'ari': best_ari, 'nmi':best_nmi} 77 | 78 | def minimum_eval(adata): 79 | raise NotImplementedError("For simplicity, scib was removed from the dependency. Therefore currently the scib evaluation is not available.") 80 | import scib 81 | print('Start building knn.') 82 | sc.pp.neighbors(adata, use_rep='X_cellbert', method='rapids') 83 | return scib.metrics.metrics(adata, adata, "batch", "cell_type", embed='X_cellbert', cluster_key="cluster", 84 | #organism='human', ari_=True, nmi_=True, pcr_=True, graph_conn_=True) 85 | organism = 'human', graph_conn_ = True) 86 | 87 | def annotation_eval(pred_labels, true_labels, num_classes=None): 88 | num_classes = len(true_labels.unique()) if num_classes is None else num_classes 89 | acc = multiclass_accuracy(pred_labels, true_labels, num_classes).cpu().item() 90 | f1_score = multiclass_f1_score(pred_labels, true_labels, num_classes).cpu().item() 91 | precision = multiclass_precision(pred_labels, true_labels, num_classes).cpu().item() 92 | recall = multiclass_recall(pred_labels, true_labels, num_classes).cpu().item() 93 | return {'acc': acc, 'f1_score': f1_score, 'precision': precision, 'recall': recall} 94 | 95 | def normalize_counts(counts): 96 | counts = F.relu(counts / counts.sum(1, keepdim=True) * 1e4) 97 | return torch.log1p(counts) 98 | 99 | def denoising_eval(pred_labels, true_labels, eval_mask=None, normalize=True): 100 | if normalize: 101 | true_labels = normalize_counts(true_labels) 102 | pred_labels = normalize_counts(pred_labels) 103 | if eval_mask is not None: 104 | true_labels = true_labels[eval_mask] 105 | pred_labels = pred_labels[eval_mask] 106 | nz_idx = torch.nonzero(true_labels, as_tuple=True) 107 | true_labels = true_labels#[nz_idx] 108 | pred_labels = pred_labels#[nz_idx] 109 | corr = PearsonCorr1d(pred_labels, true_labels).item() 110 | cos = F.cosine_similarity(pred_labels, true_labels, dim=0).item() 111 | else: 112 | corr = PearsonCorr(pred_labels, true_labels).item() 113 | cos = F.cosine_similarity(pred_labels, true_labels, dim=1).mean().item() 114 | mse = F.mse_loss(pred_labels, true_labels).item() 115 | rmse = np.sqrt(mse) 116 | mae = F.l1_loss(pred_labels, true_labels).item() 117 | return {'mse': mse, 'rmse':rmse, 'mae':mae, 'corr':corr, 'cos': cos} 118 | 119 | def imputation_eval(pred_labels, true_labels, dim=1): 120 | mse = [] 121 | rmse = [] 122 | rmsle = [] 123 | mae = [] 124 | corr = [] 125 | cos = [] 126 | # if ((true_labels - true_labels.int().float())<1e-6).all(): 127 | # print('Lognorm') 128 | # true_labels = torch.log1p(true_labels) 129 | # pred_labels = torch.log1p(pred_labels) 130 | for i in range(true_labels.shape[dim]): 131 | true_vec = true_labels[i] if dim == 0 else true_labels[:, i] 132 | pred_vec = F.relu(pred_labels[i]) if dim == 0 else F.relu(pred_labels[:, i]) 133 | (nz_idx,) = torch.nonzero(true_vec, as_tuple=True) 134 | true_nz = true_vec#[nz_idx] 135 | pred_nz = pred_vec#[nz_idx] 136 | mse.append(F.mse_loss(pred_nz, true_nz).item()) 137 | rmse.append(np.sqrt(mse)) 138 | # rmsle.append(np.sqrt(F.mse_loss(torch.log(pred_nz + 1), torch.log(true_nz + 1)).item())) 139 | mae.append(F.l1_loss(pred_nz, true_nz).item()) 140 | corr.append(PearsonCorr1d(pred_nz, true_nz).item()) 141 | cos.append(F.cosine_similarity(pred_nz, true_nz, dim=0).item()) 142 | rmse = np.concatenate(rmse) 143 | return { 144 | 'mse': sum(mse) / len(mse), 145 | 'rmse': sum(rmse) / len(rmse), 146 | # 'rmsle': sum(rmsle) / len(rmsle), 147 | 'mae': sum(mae) / len(mae), 148 | 'corr': sum(corr) / len(corr), 149 | 'cos': sum(cos) / len(cos), 150 | } 151 | -------------------------------------------------------------------------------- /CellPLM/utils/mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import torch.nn as nn 5 | from scipy.stats import expon 6 | from scipy.sparse import csr_matrix 7 | from .sparse import simple_mask 8 | import torch.distributions as td 9 | import math 10 | 11 | 12 | def drop_nodes(x_dict, drop_node_rate=0., max_batch_size=2000, inplace=True): 13 | if inplace == False: 14 | raise NotImplementedError('Only support inplace drop nodes') 15 | 16 | if drop_node_rate > 0: # cell_idx is the index of the nodes that are not dropped 17 | cell_idx = torch.randperm(x_dict['x_seq'].shape[0], device=x_dict['x_seq'].device)[ 18 | :min(max_batch_size, int(x_dict['x_seq'].shape[0] * (1 - drop_node_rate)))] 19 | x_dict['x_seq'] = x_dict['x_seq'].index_select(0, cell_idx) 20 | if 'batch' in x_dict: 21 | x_dict['batch'] = x_dict['batch'][cell_idx] 22 | if 'h' in x_dict: 23 | x_dict['h'] = x_dict['h'][cell_idx] 24 | if 'g' in x_dict: 25 | x_dict['g'] = x_dict['g'][cell_idx][:, cell_idx] 26 | if 'coord' in x_dict: 27 | x_dict['coord'] = x_dict['coord'][cell_idx] 28 | if 'label' in x_dict: 29 | x_dict['label'] = x_dict['label'][cell_idx] 30 | if 'lib_size' in x_dict: 31 | x_dict['lib_size'] = x_dict['lib_size'][cell_idx] 32 | if 'x_masked_seq' in x_dict: 33 | x_dict['x_masked_seq'] = x_dict['x_masked_seq'].index_select(0, cell_idx) 34 | if 'dataset' in x_dict: 35 | x_dict['dataset'] = x_dict['dataset'][cell_idx] 36 | if 'loss_mask' in x_dict: 37 | x_dict['loss_mask'] = x_dict['loss_mask'][cell_idx] 38 | 39 | class NullMaskBuilder(nn.Module): 40 | def __init__(self, drop_node_rate, max_batch_size=2000): 41 | super().__init__() 42 | self._drop_node_rate = drop_node_rate 43 | self._max_batch_size = max_batch_size 44 | 45 | def apply_mask(self, x_dict): 46 | if self._drop_node_rate > 0 and self.training: 47 | drop_nodes(x_dict, self._drop_node_rate, self._max_batch_size) 48 | # x_dict['mask'] = torch.arange(x_dict['h'].shape[0], device=x_dict['h'].device) 49 | x_dict['input_mask'] = torch.ones(*x_dict['x_seq'].shape, device=x_dict['x_seq'].device).int() 50 | return x_dict 51 | 52 | class MaskBuilder(nn.Module): 53 | def __init__(self, mask_node_rate, mask_feature_rate, drop_node_rate=0, max_batch_size=2000, edge_mask=None, mask_beta=False): 54 | super().__init__() 55 | self._mask_node_rate = mask_node_rate 56 | self._mask_feature_rate = mask_feature_rate 57 | self._edge_mask = edge_mask 58 | self._drop_node_rate = drop_node_rate 59 | self._max_batch_size = max_batch_size 60 | if self._mask_node_rate > 0 and self._mask_feature_rate and mask_beta: 61 | alpha = 5 62 | beta = 4 / self._mask_feature_rate + 2 - alpha 63 | self.beta_dist = td.Beta(alpha, beta) 64 | 65 | self.mask_beta = mask_beta 66 | 67 | def update_mask_ratio(self, mask_node_rate, mask_feature_rate): 68 | self._mask_node_rate = mask_node_rate 69 | self._mask_feature_rate = mask_feature_rate 70 | 71 | # This function mask parts of the nodes, and only the masked nodes will be used in the loss function 72 | def apply_mask(self, x_dict): 73 | if self.training and self._drop_node_rate > 0: 74 | drop_nodes(x_dict, self._drop_node_rate, self._max_batch_size) 75 | if self.training and self._mask_node_rate > 0: 76 | if 'x_masked_seq' in x_dict: 77 | x = x_dict['x_masked_seq'] 78 | else: 79 | x = x_dict['x_seq'] 80 | 81 | if self.mask_beta: 82 | mask_ratio = self.beta_dist.sample((x.shape[0],)).to(x.device) 83 | mask_ratio[mask_ratio > 0.9] = 0.9 84 | num_nodes = x.shape[0] 85 | perm = np.random.permutation(num_nodes) 86 | num_mask_nodes = int(self._mask_node_rate * num_nodes) 87 | keep_nodes = perm[num_mask_nodes:] 88 | mask = torch.rand(*x.shape, device=x.device) <= mask_ratio.unsqueeze(-1) 89 | mask[keep_nodes] = False 90 | else: 91 | num_nodes = x.shape[0] 92 | perm = np.random.permutation(num_nodes) 93 | num_mask_nodes = int(self._mask_node_rate * num_nodes) 94 | keep_nodes = perm[num_mask_nodes:] # keep_nodes is the index of the nodes that are not masked 95 | mask = torch.rand(*x.shape, device=x.device) <= self._mask_feature_rate 96 | mask[keep_nodes] = False 97 | 98 | x = x.coalesce() 99 | masked_x_seq = simple_mask(x, mask) 100 | x_dict['masked_x_seq'] = masked_x_seq 101 | x_dict['input_mask'] = mask.int() 102 | else: 103 | x_dict['input_mask'] = torch.ones(*x_dict['x_seq'].shape, device=x_dict['x_seq'].device).int() 104 | return x_dict 105 | 106 | class HiddenMaskBuilder(nn.Module): 107 | def __init__(self, mask_node_rate, mask_countsure_rate, drop_node_rate=0, max_batch_size=2000, edge_mask=None): 108 | super().__init__() 109 | self._mask_node_rate = mask_node_rate 110 | self._mask_countsure_rate = mask_countsure_rate 111 | self._edge_mask = edge_mask 112 | self._drop_node_rate = drop_node_rate 113 | self._max_batch_size = max_batch_size 114 | 115 | def update_mask_ratio(self, mask_node_rate, mask_feature_rate): 116 | self._mask_node_rate = mask_node_rate 117 | self._mask_feature_rate = mask_feature_rate 118 | 119 | # This function mask parts of the nodes, and only the masked nodes will be used in the loss function 120 | def apply_mask(self, x_dict): 121 | if self._drop_node_rate > 0 and self.training: 122 | drop_nodes(x_dict, self._drop_node_rate, self._max_batch_size) 123 | if self._mask_node_rate > 0 and self.training: 124 | num_nodes = x_dict['h'].shape[0] 125 | perm = np.random.permutation(num_nodes) 126 | num_mask_nodes = int(self._mask_node_rate * num_nodes) 127 | keep_nodes = perm[num_mask_nodes:] # keep_nodes is the index of the nodes that are not masked 128 | 129 | out_x = F.dropout(x_dict['h'], p=self._mask_countsure_rate) # mask the countsures of all nodes 130 | out_x[keep_nodes] = x_dict['h'][keep_nodes] # keep the countsures of the nodes that are not masked 131 | # x_dict['h'] = out_x 132 | x_dict['input_mask'] = torch.zeros(x_dict['h'].shape[0], device=x_dict['h'].device).unsqueeze(-1) 133 | x_dict['input_mask'][perm[: num_mask_nodes]] = 1. 134 | else: 135 | x_dict['input_mask'] = torch.ones(x_dict['h'].shape[0], device=x_dict['h'].device).unsqueeze(-1) 136 | return x_dict 137 | 138 | 139 | class InputDropoutMaskBuilder(nn.Module): 140 | def __init__(self, input_drop_type="mar", valid_drop_rate=0.1, test_drop_rate=0.1, seed=10, 141 | min_gene_counts=5): 142 | super().__init__() 143 | assert 0 <= valid_drop_rate < 1, "valid_drop_rate should be in [0, 1)" 144 | assert 0 < test_drop_rate < 1, "test_drop_rate should be in (0, 1)" 145 | assert 0 < valid_drop_rate + test_drop_rate < 1, "Total masking rate should be in (0, 1)" 146 | self._input_drop_type = input_drop_type 147 | self._valid_drop_rate = valid_drop_rate 148 | self._test_drop_rate = test_drop_rate 149 | self._min_gene_counts = min_gene_counts 150 | self._seed = seed 151 | if input_drop_type == "mcar": 152 | self.distr = "uniform" 153 | elif input_drop_type == "mar": 154 | self.distr = "exp" 155 | else: 156 | raise NotImplementedError(f"Expect mask_type in ['mar', 'mcar'], but found {self.mask_type}") 157 | 158 | def _get_probs(self, vec): 159 | return { 160 | "exp": expon.pdf(vec, 0, 20), 161 | "uniform": np.tile([1. / len(vec)], len(vec)), 162 | }.get(self.distr) 163 | 164 | def apply_mask(self, x_seq): 165 | counts = x_seq.to_dense() 166 | train_mask = np.ones(counts.shape, dtype=bool) 167 | valid_mask = np.zeros(counts.shape, dtype=bool) 168 | test_mask = np.zeros(counts.shape, dtype=bool) 169 | rng = np.random.default_rng(self._seed) 170 | 171 | for c in range(counts.shape[0]): 172 | # Retrieve indices of positive values 173 | ind_pos = torch.nonzero(counts[c], as_tuple=True)[0] 174 | cells_c_pos = counts[c, ind_pos] 175 | 176 | # Get masking probability of each value 177 | if len(cells_c_pos) > self._min_gene_counts: 178 | mask_prob = self._get_probs(cells_c_pos) 179 | mask_prob = mask_prob / sum(mask_prob) 180 | n_test = int(np.floor(len(cells_c_pos) * self._test_drop_rate)) 181 | n_valid = int(np.floor(len(cells_c_pos) * self._valid_drop_rate)) 182 | if n_test + n_valid >= len(cells_c_pos): 183 | print(f"Too many genes masked for cell {c} ({n_test + n_valid}/{len(cells_c_pos)})") 184 | n_test -= 1 185 | n_valid -= 1 186 | 187 | idx_mask = np.ones(len(ind_pos), dtype=bool) 188 | test_idx = rng.choice(np.arange(len(ind_pos)), n_test, p=mask_prob, replace=False) 189 | train_mask[c, ind_pos[test_idx]] = False 190 | test_mask[c, ind_pos[test_idx]] = True 191 | if self._valid_drop_rate > 0: 192 | idx_mask[test_idx] = False 193 | masked_mask_prob = mask_prob[idx_mask] / sum(mask_prob[idx_mask]) 194 | valid_idx = rng.choice(np.arange(len(ind_pos))[idx_mask], n_valid, p=masked_mask_prob, replace=False) 195 | train_mask[c, ind_pos[valid_idx]] = False 196 | valid_mask[c, ind_pos[valid_idx]] = True 197 | 198 | return train_mask, valid_mask, test_mask 199 | -------------------------------------------------------------------------------- /CellPLM/utils/pe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from ..utils import create_norm 4 | import math 5 | 6 | def select_pe_encoder(pe): 7 | if pe in ['sin', 'sinu', 'sinusoidal']: 8 | return Sinusoidal2dPE 9 | elif pe in ['learnable', 'bin']: 10 | return Learnable2dPE 11 | elif pe in ['naive', 'mlp']: 12 | return NaivePE 13 | elif pe in ['lap', 'graphlap', 'lappe']: 14 | return GraphLapPE 15 | else: 16 | raise NotImplementedError(f'Unsupported positional encoding type: {pe}') 17 | 18 | class Sinusoidal2dPE(nn.Module): 19 | def __init__(self, d_model, height=100, width=100): 20 | """ 21 | :param d_model: dimension of the model 22 | :param height: height of the positions 23 | :param width: width of the positions 24 | """ 25 | super().__init__() 26 | if d_model % 4 != 0: 27 | raise ValueError("Cannot use sin/cos positional encoding with " 28 | "odd dimension (got dim={:d})".format(d_model)) 29 | self.d_model = d_model 30 | self.height = height 31 | self.width = width 32 | self.pe_key = 'coord' 33 | self.missing_pe = nn.Parameter(torch.randn(d_model) * 1e-2) 34 | 35 | pe = torch.zeros(d_model, height, width) 36 | # Each dimension use half of d_model 37 | d_model = int(d_model / 2) 38 | div_term = torch.exp(torch.arange(0., d_model, 2) * 39 | -(math.log(10000.0) / d_model)) 40 | pos_w = torch.arange(0., width).unsqueeze(1) 41 | pos_h = torch.arange(0., height).unsqueeze(1) 42 | pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) 43 | pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) 44 | pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) 45 | pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) 46 | self.pe_enc = nn.Embedding.from_pretrained(pe.flatten(1).T) 47 | 48 | def forward(self, coordinates): 49 | if coordinates[0][0] == -1: 50 | return self.missing_pe.unsqueeze(0).expand(coordinates.shape[0], -1) 51 | x = coordinates[:, 0] 52 | y = coordinates[:, 1] 53 | x = ((x*1.02-0.01) * self.width).long() 54 | y = ((y*1.02-0.01) * self.height).long() 55 | x[x >= self.width] = self.width - 1 56 | y[y >= self.height] = self.height - 1 57 | x[x < 0] = 0 58 | y[y < 0] = 0 59 | pe_input = x * self.width + y 60 | return self.pe_enc(pe_input) 61 | 62 | class Learnable2dPE(nn.Module): 63 | def __init__(self, d_model, height=100, width=100): 64 | """ 65 | :param d_model: dimension of the model 66 | :param height: height of the positions 67 | :param width: width of the positions 68 | """ 69 | super().__init__() 70 | self.pe_enc = nn.Embedding(height * width, d_model) 71 | self.missing_pe = nn.Parameter(torch.randn(d_model) * 1e-2) 72 | self.pe_key = 'coord' 73 | 74 | def forward(self, coordinates): 75 | if coordinates[0][0] == -1: 76 | return self.missing_pe.unsqueeze(0).expand(coordinates.shape[0], -1) 77 | x = coordinates[:, 0] 78 | y = coordinates[:, 1] 79 | x = ((x*1.02-0.01) * self.width).long() 80 | y = ((y*1.02-0.01) * self.height).long() 81 | x[x >= self.width] = self.width 82 | y[y >= self.height] = self.height 83 | x[x < 0] = 0 84 | y[y < 0] = 0 85 | pe_input = x * self.width + y 86 | return self.pe_enc(pe_input) 87 | 88 | class NaivePE(nn.Module): 89 | def __init__(self, d_model, coord_dim = 2, height=None, width=None): 90 | """ 91 | :param d_model: dimension of the model 92 | :param coord_dim: dimension of coordinates 93 | :param height: placeholder 94 | :param width: placeholder 95 | """ 96 | super().__init__() 97 | self.pe_enc = nn.Sequential( 98 | nn.Linear(coord_dim, d_model), 99 | nn.PReLU(), 100 | ) 101 | self.missing_pe = nn.Parameter(torch.randn(d_model) * 1e-2) 102 | self.pe_key = 'coord' 103 | 104 | def forward(self, coordinates): 105 | if coordinates[0][0] == -1: 106 | return self.missing_pe.unsqueeze(0).expand(coordinates.shape[0], -1) 107 | return self.pe_enc(coordinates) 108 | 109 | class GraphLapPE(nn.Module): 110 | def __init__(self, d_model, k = 10, height=None, width=None): 111 | """ 112 | :param d_model: dimension of the model 113 | :param k: top k 114 | :param height: placeholder 115 | :param width: placeholder 116 | """ 117 | super().__init__() 118 | self.k = k 119 | self.pe_enc = nn.Sequential( 120 | nn.Linear(k, d_model), 121 | nn.PReLU(), 122 | ) 123 | self.missing_pe = nn.Parameter(torch.randn(d_model) * 1e-2) 124 | self.pe_key = 'eigvec' 125 | 126 | def forward(self, eigvec): 127 | if eigvec[0][0] == -1: 128 | return self.missing_pe.unsqueeze(0).expand(eigvec.shape[0], -1) 129 | eigvec = eigvec * (torch.randint(0, 2, (self.k, ), dtype=torch.float, device=eigvec.device)[None, :]*2-1) 130 | return self.pe_enc(eigvec) 131 | 132 | -------------------------------------------------------------------------------- /CellPLM/utils/sparse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | def sparse_diag(x): 5 | indices = torch.arange(len(x), device=x.device).unsqueeze(0).repeat(2, 1) 6 | values = x 7 | return torch.sparse_coo_tensor(indices, values, (len(x), len(x)), device=x.device) 8 | 9 | def sparse_normalize(x): 10 | size_factor = sparse_diag(1. / (torch.sparse.sum(x, dim=1).to_dense() + 1e-8)) 11 | res = torch.sparse.mm(size_factor, x) 12 | return res 13 | 14 | def sparse_tpm(x): 15 | x = sparse_normalize(x) * 1e4 16 | x = torch.log1p(x) 17 | return x 18 | 19 | def create_sparse_tensor(x, i): 20 | # x is a list of 4 tensors 21 | return torch.sparse_csr_tensor(x[0][i], x[1][i], 22 | x[2][i], 23 | x[3][i].tolist()).to_sparse().float().coalesce() 24 | 25 | def mask_with_renormalize(x, mask, keep_nodes, mask_feature_rate): 26 | masked_x_seq = torch.sparse.FloatTensor(x.indices(), 27 | torch.where(mask[x.indices()[0], 28 | x.indices()[1]], 29 | torch.where(torch.isin(x.indices()[0], 30 | torch.from_numpy(keep_nodes).to(x.device)), 31 | x.values(), 32 | x.values() + math.log(1) - math.log(1-mask_feature_rate))), 33 | x.shape) 34 | return masked_x_seq 35 | 36 | def simple_mask(x, mask): 37 | masked_x_seq = torch.sparse.FloatTensor(x.indices(), 38 | torch.where(mask[x.indices()[0], 39 | x.indices()[1]], 40 | 0., 41 | x.values()), 42 | x.shape) 43 | return masked_x_seq -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2023, OmicsML 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CellPLM 2 | This is the official codebase for [CellPLM: Pre-training of Cell Language Model Beyond Single Cells](https://openreview.net/forum?id=BKXvPDekud). **The paper has been accepted by ICLR 2024 conference.** 3 | 4 | ![Paper](https://img.shields.io/badge/Paper-ICLR24-brightgreen?link=https%3A%2F%2Fopenreview.net%2Fforum%3Fid%3DBKXvPDekud) 5 | [![License](https://img.shields.io/badge/License-BSD_2--Clause-orange.svg)](https://opensource.org/licenses/BSD-2-Clause) 6 | 7 | ***CellPLM*** is the first single-***Cell*** ***P***re-trained ***L***anguage ***M***odel that encodes cell-cell relations and it consistently outperforms existing pre-trained and non-pre-trained models in diverse downstream tasks, with 100x higher inference speed compared to existing pre-trained models. You can also find a brilliant blog about the idea of CellPLM [here](https://portal.valencelabs.com/blogs/post/cellplm-pre-training-of-cell-language-model-beyond-single-cells-wKScCQHIyicpXbx). 8 | 9 | ## Installation 10 | We recommend PyPI for quick installation. We recommend using `python 3.9` and `cuda>=11.7` but they are adjustable. 11 | 12 | ### Quick Installation with PyPI 13 | Make sure gpu version of pytorch (>=1.13.0) has been installed before installing CellPLM. 14 | ``` 15 | pip install cellplm 16 | ``` 17 | 18 | ### Full Installation (recommended for HPC users and developers) 19 | ``` 20 | conda create -n cellplm python=3.9 -y && conda activate cellplm 21 | conda install cudatoolkit=11.7 -c pytorch -c nvidia 22 | pip install -r requirements.txt 23 | ``` 24 | The full installation will install the same environment as we used during development. This includes `rapids` used to accelerate evaluation. 25 | 26 | ## Tutorials 27 | We offer several [notebooks](https://github.com/OmicsML/CellPLM/tree/main/tutorials) for various downstream tasks as introductory tutorials. _Our latest studies demonstrate CellPLM is competitive on cell-type annotation tasks compared to other SOTA methods and pretrained models. The result table is shown below:_ 28 | 29 | | Method | PBMC12K | Pancreas | HLCA | Immune | Brain | Liver | 30 | | --- | --- | --- | --- | --- | --- | --- | 31 | | SingleCellNet | 0.845+-0.0064 | 0.644+-0.0006 | 0.811+-0.0046 | 0.775+-0.0009 | 0.877+-0.0033 | 0.872+-0.0023 | 32 | | ACTINN | 0.614+-0.0709 | 0.528+-0.0926 | 0.218+-0.0440 | 0.236+-0.0300 | 0.695+-0.0624 | 0.614+-0.0349 | 33 | | scANVI | 0.930+-0.0148 | 0.963+-0.0083 | 0.708+-0.0183 | 0.851+-0.0133 | 0.933+-0.0010 | **0.908+-0.0144** | 34 | | CellTypist | 0.883+-0.0055 | 0.882+-0.0011 | 0.776+-0.0079 | 0.822+-0.0020 | 0.901+-0.0031 | 0.764+-0.0132 | 35 | | scDiff | 0.967+-0.0042 | **0.968+-0.0143** | **0.893+-0.0070** | 0.844+-0.0076 | 0.947+-0.0074 | 0.844+-0.0042 | 36 | | scGPT | 0.963 | 0.954 | 0.863 | ***0.907*** | **0.950** | 0.864 | 37 | | Geneformer | ***0.979*** | - | 0.833 | 0.856 | 0.934 | 0.871 | 38 | | CellPLM | **0.975** | ***0.983*** | ***0.929*** | **0.902** | ***0.967*** | ***0.913*** | 39 | 40 | _(The evaluation follows the setting in [scDiff](https://www.biorxiv.org/content/10.1101/2023.10.13.562243v1.abstract) paper)_ 41 | 42 | ## Pretrained CellPLM Model Checkpoints 43 | The checkpoint can be acquired from our [dropbox](https://www.dropbox.com/scl/fo/i5rmxgtqzg7iykt2e9uqm/h?rlkey=o8hi0xads9ol07o48jdityzv1&dl=0). We might update our checkpoints from time to time. 44 | 45 | [10/10/2023] The latest version is `20230926_85M`. 46 | 47 | ## Citation 48 | ``` 49 | @article{wen2023cellplm, 50 | title={CellPLM: Pre-training of Cell Language Model Beyond Single Cells}, 51 | author={Wen, Hongzhi and Tang, Wenzhuo and Dai, Xinnan and Ding, Jiayuan and Jin, Wei and Xie, Yuying and Tang, Jiliang}, 52 | journal={bioRxiv}, 53 | pages={2023--10}, 54 | year={2023}, 55 | publisher={Cold Spring Harbor Laboratory} 56 | } 57 | ``` 58 | -------------------------------------------------------------------------------- /ckpt/README.md: -------------------------------------------------------------------------------- 1 | ## Model Checkpoints 2 | 3 | Please download and place pre-trained model checkpoints here. A model checkpoint consists of two files: a model config (`.config.json` file) and a torch ckpt (`.best.ckpt` file), which share common prefix. The checkpoints can be found on our [dropbox](https://www.dropbox.com/scl/fo/i5rmxgtqzg7iykt2e9uqm/h?rlkey=o8hi0xads9ol07o48jdityzv1&dl=0). 4 | 5 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | ## Download 2 | 3 | Please download datasets from our [dropbox](https://www.dropbox.com/scl/fo/i5rmxgtqzg7iykt2e9uqm/h?rlkey=o8hi0xads9ol07o48jdityzv1&dl=0) and place it here. 4 | 5 | ## References 6 | 7 | Datasets in our tutorials are collected from different previous publications and we provide the references here: 8 | 9 | 1. `demo_train.h5ad` from hPancreas dataset. Chen, Jiawei, et al. "Transformer for one stop interpretable cell type annotation." Nature Communications 14.1 (2023): 223. 10 | 1. `demo_test.h5ad` from hPancreas dataset. Chen, Jiawei, et al. "Transformer for one stop interpretable cell type annotation." Nature Communications 14.1 (2023): 223. 11 | 1. `c_data.h5ad` from MS dataset. Cui, Haotian, et al. "scgpt: Towards building a foundation model for single-cell multi-omics using generative ai." bioRxiv (2023): 2023-04. 12 | 1. `filtered_ms_adata.h5ad` from MS dataset. Cui, Haotian, et al. "scgpt: Towards building a foundation model for single-cell multi-omics using generative ai." bioRxiv (2023): 2023-04. 13 | 1. `GSE131907_Lung_ensg.h5ad` from GSE131907 dataset. Kim, Nayoung, et al. "Single-cell RNA sequencing demonstrates the molecular and cellular reprogramming of metastatic lung adenocarcinoma." Nature communications 11.1 (2020): 2285. 14 | 1. `GSE151530_Liver_ensg.h5ad` from GSE151530 dataset. Ma, Lichun, et al. "Single-cell atlas of tumor cell evolution in response to therapy in hepatocellular carcinoma and intrahepatic cholangiocarcinoma." Journal of hepatology 75.6 (2021): 1397-1408. 15 | 1. `HumanLiverCancerPatient2_filtered_ensg.h5ad` from Liver cancer 2 dataset. Vizgen MERFISH FFPE Human Immuno-oncology Data Set, May 2022. [link](https://info.vizgen.com/ffpe-showcase?submissionGuid=88ba0a44-26e2-47a2-8ee4-9118b9811fbf) 16 | 2. `HumanLungCancerPatient2_filtered_ensg.h5ad` from Lung cancer 2 dataset. Vizgen MERFISH FFPE Human Immuno-oncology Data Set, May 2022. [link](https://info.vizgen.com/ffpe-showcase?submissionGuid=88ba0a44-26e2-47a2-8ee4-9118b9811fbf) 17 | 3. `gse155468.h5ad` from GSE155468 dataset. Li, Yanming, et al. "Single-cell transcriptome analysis reveals dynamic cell populations and differential gene expression patterns in control and aneurysmal human aortic tissue." Circulation 142.14 (2020): 1374-1388. 18 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "cellplm" 7 | version = "0.1.0" 8 | authors = [ 9 | { name="Hongzhi Wen", email="wenhongz@msu.edu" }, 10 | ] 11 | description = "Package of CellPLM: A pretrain-ed cell language model beyond single cells. Paper link: https://www.biorxiv.org/content/10.1101/2023.10.03.560734" 12 | readme = "README.md" 13 | requires-python = ">=3.9" 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "License :: OSI Approved :: BSD License", 17 | "Operating System :: OS Independent", 18 | "Intended Audience :: Science/Research", 19 | ] 20 | dependencies = [ 21 | 'pydance', 22 | 'einops', 23 | 'torchmetrics', 24 | 'hdf5plugin', 25 | 'mygene', 26 | ] 27 | 28 | [project.optional-dependencies] 29 | rapids = [ 30 | 'rapids-singlecell', 31 | ] 32 | 33 | [project.urls] 34 | Homepage = "https://github.com/OmicsML/CellPLM" 35 | Issues = "https://github.com/OmicsML/CellPLM/issues" 36 | 37 | [tool.setuptools] 38 | packages = ['CellPLM'] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu117 2 | --extra-index-url https://pypi.nvidia.com 3 | torch==1.13.0+cu117 4 | cudf-cu11==23.4.1 5 | dask-cudf-cu11==23.4.1 6 | cuml-cu11==23.4.1 7 | cugraph-cu11==23.4.1 8 | cucim==23.4.1 9 | rapids-singlecell 10 | cellplm 11 | -------------------------------------------------------------------------------- /tutorials/README.md: -------------------------------------------------------------------------------- 1 | ## Tutorial Notebooks 2 | 3 | We plan to gradually release tutorials for different downstream tasks, including: 4 | 5 | - [x] Cell-type Annotation 6 | - [x] Spatial Imputation 7 | - [x] Zero-shot Cell Embedding 8 | - [ ] Gene Perturbation Prediction 9 | 10 | ~We also plan to build a more user-friendly interface for downstream tasks, therefore, the tutorials might be updated from time to time.~ 11 | The cell-type annotation tutorial is updated on 12/05/2023. A unified `pipeline` module is now added for downstream analyses. 12 | 13 | ## Dataset Preparation 14 | 15 | Before running the tutorial, please download datasets from our [dropbox](https://www.dropbox.com/scl/fo/i5rmxgtqzg7iykt2e9uqm/h?rlkey=o8hi0xads9ol07o48jdityzv1&dl=0) and placed h5ad datasets in the `../data` folder. 16 | 17 | All datasets used in our tutorials are collected from previous publications and we provide the references in `../data/README.md`. 18 | 19 | ## Customized Dataset 20 | 21 | The customized dataset can now be easily processed with `CellPLM.pipeline` module. Please refer to the tutorial of each downstream task. 22 | --------------------------------------------------------------------------------