├── README.md
├── checkpoint
│   └── download_link.txt
├── code
│   ├── downstream_application
│   │   ├── attention_matrix.py
│   │   └── pseudo_knockout_gene.py
│   ├── main_scripts
│   │   ├── stage3_fine-tune.py
│   │   └── stage3_inference_without_finetune.py
│   ├── model
│   │   ├── ID_dic
│   │   │   ├── EntrezID_to_myID.pkl
│   │   │   ├── hgs_to_EntrezID.pkl
│   │   │   ├── mouse_gene_ID_to_human_gene_symbol.pkl
│   │   │   └── mouse_gene_symbol_to_human_gene_symbol.pkl
│   │   ├── data_preprocessing_ID_convert.py
│   │   ├── performer_enc_dec.py
│   │   ├── performer_pytorch.py
│   │   ├── reversible.py
│   │   └── utils.py
│   └── visualization
│       ├── fig2.ipynb
│       ├── fig3.r
│       ├── fig4.ipynb
│       ├── fig5.ipynb
│       └── fig6.ipynb
├── dataset
│   └── download_link.txt
├── requirements.txt
├── result
│   └── download_link.txt
└── scTranslator.jpg


/README.md:
--------------------------------------------------------------------------------
# scTranslator: A pre-trained large generative model for translating single-cell transcriptome to proteome

Despite recent advances in single-cell proteomics technology, it is still limited in throughput, proteome depth, and batch effects, and its cost remains high. Inspired by natural language translation and the central dogma of molecular biology, we propose a pre-trained large generative model named scTranslator (single-cell translator), which is alignment-free and generates the missing single-cell proteome by inferring it from the transcriptome.



# Dataset
The data can be downloaded from this link. If you have any questions, please contact elainelliu@tencent.com.

https://drive.google.com/drive/folders/1XmdKikkG3g0yl1vKmY9lhru_78Nc0NMk?usp=sharing

# Pre-trained model
The pre-trained models can be downloaded from the links in the table below. If you have any questions, please contact elainelliu@tencent.com.

| Model name | Description | Download |
| :------------------------ | :------------------------------------------------------ | :------------------------------------------------------------------------------------------- |
| scTranslator 2M | Pretrained on over 2 million human cells and 18,000 bulk samples. | [link](https://drive.google.com/file/d/11FR3nebhJKAt_QIng35H-O1s_-ZrKyCG/view?usp=sharing) |
| scTranslator 160k | Pretrained on over 160K human PBMCs and 18,000 bulk samples. | [link](https://drive.google.com/file/d/1nmYIsctfMD60DxOKKQc9-AQj2Wla24m8/view?usp=sharing) |
| scTranslator 10k | Pretrained on over 18,000 bulk samples. | [link](https://drive.google.com/file/d/14D6hFCcMrrkpo7zW90WmH3FFOR-iyIZ-/view?usp=sharing) |



# Results
The results used by the Jupyter analysis demos can be downloaded from this link. If you have any questions, please contact elainelliu@tencent.com.

https://drive.google.com/drive/folders/1R4JEJjwP27yLqYMlOulmvGiocnlJZT3Z?usp=sharing

# Installation
[![python >3.8.13](https://img.shields.io/badge/python-3.8.13-brightgreen)](https://www.python.org/)
[![scipy-1.6.2](https://img.shields.io/badge/scipy-1.6.2-yellowgreen)](https://github.com/scipy/scipy) [![pytorch-1.12.1](https://img.shields.io/badge/pytorch-1.12.1-orange)](https://github.com/pytorch/pytorch) [![numpy-1.21.5](https://img.shields.io/badge/numpy-1.21.5-red)](https://github.com/numpy/numpy) [![pandas-1.2.4](https://img.shields.io/badge/pandas-1.2.4-lightgrey)](https://github.com/pandas-dev/pandas) [![scanpy-1.9.1](https://img.shields.io/badge/scanpy-1.9.1-blue)](https://github.com/theislab/scanpy) [![scikit--learn-1.1.1](https://img.shields.io/badge/scikit--learn-1.1.1-green)](https://github.com/scikit-learn/scikit-learn)
[![local--attention-1.4.3](https://img.shields.io/badge/local--attention-1.4.3-red)](https://fast-transformers.github.io/)

## 1. Environment preparation
The environment for scTranslator can be obtained from the Docker Hub registry or by installing the dependencies listed in requirements.txt.
### Option 1: Download the Docker image from Docker Hub.
```bash
$ docker pull linjingliu/sctranslator:latest
```
Start a container from the image and activate the environment.
```bash
$ docker run --name sctranslator --gpus all -it --rm linjingliu/sctranslator:latest /bin/bash
```
### Option 2: Use conda to create and activate an environment.
```bash
$ conda create -n performer python=3.8.13
$ conda activate performer
```
Install the necessary dependencies (requirements.txt is in the root of the cloned repository).
```bash
$ pip install -r requirements.txt
```
## 2. Install by git clone
This usually takes about 5 seconds on a normal desktop computer.

```bash
$ git clone git@github.com:TencentAILabHealthcare/scTranslator.git
```
Download the datasets and checkpoints from the links provided above and place them in the corresponding folders of scTranslator.

# Step by step tutorial
1. Activate the environment and switch to the scTranslator folder.
```bash
$ conda activate performer
$ cd scTranslator
```
1. Input file format.

scTranslator accepts single-cell data in .h5ad format as input. To fine-tune the model, provide scRNA.h5ad paired with scProtein.h5ad. To perform inference directly, you only need to list your query protein names in scProtein.h5ad and set their values to 0.

1. ID conversion from HUGO symbol/Entrez ID/mouse gene symbol/mouse gene ID to scTranslator ID.

To use scTranslator with your own data, you need to convert gene symbols or IDs to scTranslator IDs. Here, we provide an example. The data after gene mapping are stored in the same directory as the original data and are distinguished by the suffix '_mapped'.

```bash
$ python code/model/data_preprocessing_ID_convert.py \
    --origin_gene_type='mouse_gene_symbol' \
    --origin_gene_column='index' \
    --data_path='dataset/test/cite-seq_mouse/spleen_lymph_111.h5ad'
```

Parameter |Description | Default
------------------------|----------------------------| ----------------------------------------------
origin_gene_column |Where the gene information is stored: use 'index' if it is in the index of anndata.var; otherwise give the column name, such as 'gene_name'. |'index'
origin_gene_type |Original gene type before mapping|choices=['mouse_gene_ID', 'mouse_gene_symbol', 'human_gene_symbol', 'EntrezID']
data_path |Dataset path |'dataset/test/cite-seq_mouse/spleen_lymph_111.h5ad'


1. Demo for protein abundance prediction with or without fine-tuning. The results, comprising both protein abundance and performance metrics, are stored in the 'scTranslator/result/test' directory.
```bash
# Inference without fine-tuning
$ python code/main_scripts/stage3_inference_without_finetune.py \
    --pretrain_checkpoint='checkpoint/scTranslator_2M.pt' \
    --RNA_path='dataset/test/dataset1/GSM5008737_RNA_finetune_withcelltype.h5ad' \
    --Pro_path='dataset/test/dataset1/GSM5008738_protein_finetune_withcelltype.h5ad'

# Inference with fine-tuning
$ python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node 1 --master_port 23333 \
    code/main_scripts/stage3_fine-tune.py --epoch=100 --frac_finetune_test=0.1 --fix_set \
    --pretrain_checkpoint='checkpoint/scTranslator_2M.pt' \
    --RNA_path='dataset/test/dataset1/GSM5008737_RNA_finetune_withcelltype.h5ad' \
    --Pro_path='dataset/test/dataset1/GSM5008738_protein_finetune_withcelltype.h5ad'
```
1. Demo for obtaining the attention matrix. The results are stored in the 'scTranslator/result/fig5/a' directory.
```bash
$ python code/downstream_application/attention_matrix.py \
    --pretrain_checkpoint='checkpoint/Dataset1_fine-tuned_scTranslator.pt' \
    --RNA_path='dataset/test/dataset1/GSM5008737_RNA_finetune_withcelltype.h5ad' \
    --Pro_path='dataset/test/dataset1/GSM5008738_protein_finetune_withcelltype.h5ad'
```
1. Demo for gene pseudo-knockout. The results are stored in the 'scTranslator/result/fig5/e' directory.
```bash
# Compute the original protein abundance
$ python code/downstream_application/pseudo_knockout_gene.py --gene='org'
# Compute protein abundance after pseudo-knockout of a gene (TP53 here)
$ python code/downstream_application/pseudo_knockout_gene.py --gene='TP53'
```
# Hyperparameters

Hyperparameter |Description | Default
-------------------------|-----------------------------------| -----------
batch_size |Batch size |8
epoch |Training epochs |100


# Results analysis
The [scripts](./code/visualization) for results analysis and visualization are provided.

# Time cost
The expected runtime for inferring 1000 proteins in 100 cells is approximately 20 seconds on a 16 GB GPU and 110 seconds on a CPU.

# Disclaimer
This tool is for research purposes only and is not approved for clinical use.

This is not an official Tencent product.

# Copyright
This tool was developed at Tencent AI Lab.

The copyright holder for this project is Tencent AI Lab.

All rights reserved.
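The ID-conversion step of the tutorial maps each gene name or ID to a scTranslator ID. The sketch below illustrates the general idea only and is not the code in `data_preprocessing_ID_convert.py`. It assumes that the pickles in `code/model/ID_dic` each hold a plain Python `dict` and that they can be chained (for example mouse gene symbol → human gene symbol → Entrez ID → scTranslator ID, as the file names suggest), and it drops any gene that fails a lookup. The helper name `map_gene_ids` and the toy tables are hypothetical.

```python
from typing import Dict, Hashable, List, Sequence, Tuple


def map_gene_ids(
    genes: Sequence[str],
    lookup_chain: Sequence[Dict[Hashable, Hashable]],
) -> Tuple[List[str], List[int]]:
    """Hypothetical helper: follow each gene through a chain of lookup tables.

    Genes missing from any table are dropped. Returns the kept input names and
    their final integer IDs, both in input order.
    """
    kept_names: List[str] = []
    kept_ids: List[int] = []
    for gene in genes:
        key: Hashable = gene
        for table in lookup_chain:
            if key not in table:
                break  # unmapped at this step: skip the gene
            key = table[key]
        else:  # runs only if no lookup failed
            kept_names.append(gene)
            kept_ids.append(int(key))
    return kept_names, kept_ids


# Toy tables with placeholder values (NOT the real contents of ID_dic).
# Real tables would presumably be read from the .pkl files with pickle.load.
mouse_symbol_to_human_symbol = {"Cd4": "CD4", "Cd19": "CD19", "Gm0001": "GENE_X"}
human_symbol_to_entrez = {"CD4": 11, "CD19": 12}
entrez_to_sctranslator_id = {11: 101, 12: 102}

names, ids = map_gene_ids(
    ["Cd4", "Cd19", "Gm0001", "Unknown"],
    [mouse_symbol_to_human_symbol, human_symbol_to_entrez, entrez_to_sctranslator_id],
)
print(names, ids)  # ['Cd4', 'Cd19'] [101, 102]
```

In an AnnData workflow, the kept names would be used to subset `adata.var` and the IDs stored alongside them; the downstream scripts in this repository read that column as `var['my_Id']`, but exactly how `data_preprocessing_ID_convert.py` fills it is not shown here.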

--------------------------------------------------------------------------------
/checkpoint/download_link.txt:
--------------------------------------------------------------------------------
https://drive.google.com/drive/folders/16psbNA0fUsY9Crac1CHKYF4n-efbwTrY?usp=sharing
--------------------------------------------------------------------------------
/code/downstream_application/attention_matrix.py:
--------------------------------------------------------------------------------
import os
import argparse
import torch
import torch.nn as nn
import scanpy as sc
import numpy as np
import pandas as pd

import sys
sys.path.append('code/model')
from performer_enc_dec import *
from utils import *

#################################################
#------------ Train & Test Function ------------#
#################################################

def route_args(router, args, depth):
    routed_args = [(dict(), dict()) for _ in range(depth)]
    matched_keys = [key for key in args.keys() if key in router]

    for key in matched_keys:
        val = args[key]
        for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
            new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
            routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
    return routed_args


class SequentialSequence(nn.Module):
    def __init__(self, layers, args_route = {}):
        super().__init__()
        assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
        self.layers = layers
        self.args_route = args_route

    def forward(self, x, output_attentions = False, **kwargs):
        args = route_args(self.args_route, kwargs, len(self.layers))
        layers_and_args = list(zip(self.layers, args))

        if output_attentions:
            attn_weights = []
        for (f, g), (f_args, g_args) in layers_and_args:
            if output_attentions:
                # run the attention block once and keep the attention map of its input
                out, attn = f(x, output_attentions = output_attentions, **f_args)
                x = x + out
                attn_weights.append(attn.unsqueeze(0))
            else:
                x = x + f(x, **f_args)
            x = x + g(x, **g_args)
        if output_attentions:
            attn_weights = torch.transpose(torch.cat(attn_weights, dim=0), 0, 1) # the final dim is (batch, layer, head, len, len)
            attn_weights = torch.mean(attn_weights, dim=1) # the dim is (batch, head, len, len)
            return x, attn_weights
        else:
            return x


class FastAttention(nn.Module):
    def __init__(self, dim_heads, nb_features = None, ortho_scaling = 0, causal = False, generalized_attention = False, kernel_fn = nn.ReLU(), no_projection = False):
        super().__init__()
        nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))

        self.dim_heads = dim_heads
        self.nb_features = nb_features
        self.ortho_scaling = ortho_scaling

        self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows = self.nb_features, nb_columns = dim_heads, scaling = ortho_scaling)
        projection_matrix = self.create_projection()
        self.register_buffer('projection_matrix', projection_matrix)

        self.generalized_attention = generalized_attention
        self.kernel_fn = kernel_fn

        # if this is turned on, no projection will be used
        # queries and keys will be softmax-ed as in the original efficient attention paper
        self.no_projection = no_projection

        self.causal = causal
        if causal:
            try:
                import fast_transformers.causal_product.causal_product_cuda
                self.causal_linear_fn = partial(causal_linear_attention)
            except ImportError:
                print('unable to import cuda code for auto-regressive Performer. will default to the memory inefficient non-cuda version')
                self.causal_linear_fn = causal_linear_attention_noncuda

    @torch.no_grad()
    def redraw_projection_matrix(self, device):
        projections = self.create_projection(device = device)
        self.projection_matrix.copy_(projections)
        del projections

    def forward(self, q, k, v, output_attentions = False):
        device = q.device
        if self.no_projection:
            q = q.softmax(dim = -1)
            k = torch.exp(k) if self.causal else k.softmax(dim = -2)

        elif self.generalized_attention:
            create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device)
            q, k = map(create_kernel, (q, k))

        else:
            create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device)
            q = create_kernel(q, is_query = True)
            k = create_kernel(k, is_query = False)

        attn_fn = linear_attention if not self.causal else self.causal_linear_fn
        out = attn_fn(q, k, v)
        if output_attentions:
            # Passing an identity matrix as V makes linear attention return the
            # (otherwise implicit) normalized attention matrix; it is averaged over heads.
            # The (1, 1, len, len) accumulator assumes a batch size of 1, the default in main().
            v_diag = torch.eye(v.shape[-2]).to(device)
            v_diag = v_diag.unsqueeze(0).unsqueeze(0).repeat(v.shape[0], v.shape[1], 1, 1)
            attn_weights = torch.zeros(1, 1, q.shape[2], q.shape[2]).to(device)
            for head_dim in range(q.shape[1]):
                attn_weights += attn_fn(q[:,head_dim], k[:,head_dim], v_diag[:,head_dim])
            attn_weights /= q.shape[1]
            return out, attn_weights
        else:
            return out



class SelfAttention(nn.Module):
    def __init__(
        self,
        dim,
        causal = False,
        heads = 8,
        dim_head = 64,
        local_heads = 0,
        local_window_size = 256,
        nb_features = None,
        feature_redraw_interval = 1000,
        generalized_attention = False,
        kernel_fn = nn.ReLU(),
        dropout = 0.,
        no_projection = False,
        qkv_bias = False
    ):
        super().__init__()
        assert dim % heads == 0, 'dimension must be divisible by number of heads'
        dim_head = default(dim_head, dim // heads)
        inner_dim = dim_head * heads
        self.fast_attention = FastAttention(dim_head, nb_features, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, no_projection = no_projection)

        self.heads = heads
        self.global_heads = heads - local_heads
        self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None

        self.to_q = nn.Linear(dim, inner_dim, bias = qkv_bias)
        self.to_k = nn.Linear(dim, inner_dim, bias = qkv_bias)
        self.to_v = nn.Linear(dim, inner_dim, bias = qkv_bias)
        self.to_out = nn.Linear(inner_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, pos_emb = None, context = None, mask = None, context_mask = None, output_attentions = False, **kwargs):
        b, n, _, h, gh = *x.shape, self.heads, self.global_heads

        cross_attend = exists(context)

        context = default(context, x)
        context_mask = default(context_mask, mask) if not cross_attend else context_mask

        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))

        attn_outs = []

        if not empty(q):
            if exists(context_mask):
                global_mask = context_mask[:, None, :, None]
                v.masked_fill_(~global_mask, 0.)

            if exists(pos_emb) and not cross_attend:
                q, k = apply_rotary_pos_emb(q, k, pos_emb)

            if output_attentions:
                out, attn_weights = self.fast_attention(q, k, v, output_attentions)
            else:
                out = self.fast_attention(q, k, v)
            attn_outs.append(out)

        if not empty(lq):
            assert not cross_attend, 'local attention is not compatible with cross attention'
            out = self.local_attn(lq, lk, lv, input_mask = mask)
            attn_outs.append(out)

        out = torch.cat(attn_outs, dim = 1)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        if output_attentions:
            return self.dropout(out), attn_weights
        else:
            return self.dropout(out)


class Performer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        heads,
        dim_head,
        local_attn_heads = 0,
        local_window_size = 256,
        causal = False,
        ff_mult = 4,
        nb_features = None,  # 64
        feature_redraw_interval = 1000,
        reversible = False,
        ff_chunks = 1,
        generalized_attention = False,
        kernel_fn = nn.ReLU(),
        use_scalenorm = False,
        use_rezero = False,
        ff_glu = False,
        ff_dropout = 0.,
        attn_dropout = 0.,
        cross_attend = False,
        no_projection = False,
        auto_check_redraw = True,
        qkv_bias = True,
        attn_out_bias = True,
        shift_tokens = False
    ):
        super().__init__()
        layers = nn.ModuleList([])
        local_attn_heads = cast_tuple(local_attn_heads)
        local_attn_heads = local_attn_heads * depth if len(local_attn_heads) == 1 else local_attn_heads
        assert len(local_attn_heads) == depth, 'tuple specifying number of local attention heads per depth must be equal to the total depth'
        assert all(map(lambda n: n >= 0 and n <= heads, local_attn_heads)), 'local attention head value must be less than the total number of heads'

        if use_scalenorm:
            wrapper_fn = partial(PreScaleNorm, dim)
        elif use_rezero:
            wrapper_fn = ReZero
        else:
            wrapper_fn = partial(PreLayerNorm, dim)

        for _, local_heads in zip(range(depth), local_attn_heads):

            attn = SelfAttention(dim, causal = causal, heads = heads, dim_head = dim_head, local_heads = local_heads, local_window_size = local_window_size, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, dropout = attn_dropout, no_projection = no_projection, qkv_bias = qkv_bias)
            ff = Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1)

            if shift_tokens:
                shift = (0, 1) if causal else (-1, 0, 1)
                attn, ff = map(lambda t: PreShiftTokens(shift, t), (attn, ff))

            attn, ff = map(wrapper_fn, (attn, ff))
            layers.append(nn.ModuleList([attn, ff]))

            if not cross_attend:
                continue

            layers.append(nn.ModuleList([
                wrapper_fn(CrossAttention(dim, heads = heads, dim_head = dim_head, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, dropout = attn_dropout, no_projection = no_projection, qkv_bias = qkv_bias, attn_out_bias = attn_out_bias)),
                wrapper_fn(Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1))
            ]))

        execute_type = ReversibleSequence if reversible else SequentialSequence

        route_attn = ((True, False),) * depth * (2 if cross_attend else 1)
        route_context = ((False, False), (True, False)) * depth
        attn_route_map = {'mask': route_attn, 'pos_emb': route_attn}
        context_route_map = {'context': route_context, 'context_mask': route_context} if cross_attend else {}
        self.net = execute_type(layers, args_route = {**attn_route_map, **context_route_map})

        # keeping track of when to redraw projections for all attention layers
        self.auto_check_redraw = auto_check_redraw
        self.proj_updater = ProjectionUpdater(self.net, feature_redraw_interval)

    def fix_projection_matrices_(self):
        self.proj_updater.feature_redraw_interval = None

    def forward(self, x, output_attentions = True, **kwargs):
        if self.auto_check_redraw:
            self.proj_updater.redraw_projections()
        return self.net(x, output_attentions = output_attentions, **kwargs)


class scPerformerLM(nn.Module):
    def __init__(
        self,
        *,
        max_seq_len,
        dim,
        depth,
        heads,
        num_tokens=1,
        dim_head = 64,
        local_attn_heads = 0,
        local_window_size = 256,
        causal = False,
        ff_mult = 4,
        nb_features = None,
        feature_redraw_interval = 1000,
        reversible = False,
        ff_chunks = 1,
        ff_glu = False,
        emb_dropout = 0.,
        ff_dropout = 0.,
        attn_dropout = 0.,
        generalized_attention = False,
        kernel_fn = nn.ReLU(),
        use_scalenorm = False,
        use_rezero = False,
        cross_attend = False,
        no_projection = False,
        tie_embed = False,
        rotary_position_emb = True,
        axial_position_emb = False,
        axial_position_shape = None,
        auto_check_redraw = True,
        qkv_bias = False,
        attn_out_bias = False,
        shift_tokens = False
    ):
        super().__init__()
        local_attn_heads = cast_tuple(local_attn_heads)

        self.max_seq_len = max_seq_len
        self.to_vector = nn.Linear(1, dim)
        self.pos_emb = nn.Embedding(85500, dim, padding_idx=0)  # There were 75500 NCBI Gene IDs as of 19 July 2022
        self.layer_pos_emb = Always(None)
        self.dropout = nn.Dropout(emb_dropout)
        self.performer = Performer(dim, depth, heads, dim_head)
        self.to_out = nn.Linear(dim, num_tokens) if not tie_embed else None

    def check_redraw_projections(self):
        self.performer.check_redraw_projections()

    def fix_projection_matrices_(self):
        self.performer.fix_projection_matrices_()

    def forward(self, x, geneID, return_encodings = False, output_attentions = True, **kwargs):
        b, n = x.shape[0], x.shape[1]
        assert n <= self.max_seq_len, f'sequence length {n} must be less than the max sequence length {self.max_seq_len}'

        # token and positional embeddings
        if len(x.shape) < 3:
            x = torch.unsqueeze(x, dim=2)
        x = self.to_vector(x)

        x += self.pos_emb(geneID)
        x = self.dropout(x)
        # performer layers
        layer_pos_emb = self.layer_pos_emb(x)
        x, attn_weights = self.performer(x, pos_emb = layer_pos_emb, output_attentions = output_attentions, **kwargs)

        if return_encodings:
            return x, attn_weights

        return torch.squeeze(self.to_out(x)), attn_weights


class scPerformerEncDec(nn.Module):
    def __init__(
        self,
        dim,
        translator_depth,
        initial_dropout,
        ignore_index = 0,
        pad_value = 0,
        tie_token_embeds = False,
        no_projection = False,
        **kwargs
    ):
        super().__init__()
        enc_kwargs, dec_kwargs, _ = extract_enc_dec_kwargs(kwargs)

        assert 'dim' not in dec_kwargs and 'dim' not in enc_kwargs, 'you must set the dim for both encoder and decoder'

        enc_kwargs['dim'] = dec_kwargs['dim'] = dim
        enc_kwargs['no_projection'] = dec_kwargs['no_projection'] = no_projection

        enc = scPerformerLM(**enc_kwargs)
        dec = scPerformerLM(**dec_kwargs)

        self.enc = enc
        self.translator = MLPTranslator(enc_kwargs['max_seq_len'], dec_kwargs['max_seq_len'], translator_depth, initial_dropout)
        self.dec = dec

    def forward(self, seq_in, seq_inID, seq_outID, **kwargs):
        enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
        encodings, enc_weights = self.enc(seq_in, seq_inID, return_encodings = True, **enc_kwargs)  # batch_size, input_seq_length, dim
        seq_out = self.translator(encodings.transpose(1, 2).contiguous()).transpose(1, 2).contiguous()  # batch_size, output_seq_length, dim
        _, dec_weights = self.dec(seq_out, seq_outID, **dec_kwargs)
        enc2dec_weights = torch.einsum('...ik,...kj->...ij', self.translator(enc_weights.type_as(encodings)), dec_weights.type_as(encodings))
        return torch.squeeze(enc_weights), torch.squeeze(dec_weights), torch.squeeze(enc2dec_weights)


#################################################
#---------------- Main Function ----------------#
#################################################
def main():
    #--- Settings ---#
    parser = argparse.ArgumentParser(description='PyTorch Example')
    parser.add_argument('--test_batch_size', type=int, default=1,
                        help='input batch size for testing (default: 1)')
    parser.add_argument('--seed', type=int, default=1105,
                        help='random seed (default: 1105)')
    parser.add_argument('--dim', type=int, default=128,
                        help='latent dimension of each token')
    parser.add_argument('--enc_max_seq_len', type=int, default=20000,
                        help='sequence length of encoder')
    parser.add_argument('--dec_max_seq_len', type=int, default=1000,
                        help='sequence length of decoder')
    parser.add_argument('--translator_depth', type=int, default=2,
                        help='translator depth')
    parser.add_argument('--initial_dropout', type=float, default=0.1,
                        help='initial dropout rate of the MLP translator')
    parser.add_argument('--enc_depth', type=int, default=2,
                        help='number of encoder layers')
    parser.add_argument('--enc_heads', type=int, default=8,
                        help='number of encoder attention heads')
    parser.add_argument('--dec_depth', type=int, default=2,
                        help='number of decoder layers')
    parser.add_argument('--dec_heads', type=int, default=8,
                        help='number of decoder attention heads')
    parser.add_argument('--pretrain_checkpoint', default='checkpoint/Dataset1_fine-tuned_scTranslator.pt',
                        help='path for loading the checkpoint')
    parser.add_argument('--RNA_path', default='dataset/test/dataset1/GSM5008737_RNA_finetune_withcelltype.h5ad',
                        help='path for loading the RNA data')
    parser.add_argument('--Pro_path', default='dataset/test/dataset1/GSM5008738_protein_finetune_withcelltype.h5ad',
                        help='path for loading the protein data')
    args = parser.parse_args()

    print('seed', args.seed)

    #--- Prepare The Model ---#
    model = scPerformerEncDec(
        dim=args.dim,
        translator_depth=args.translator_depth,
        initial_dropout=args.initial_dropout,
        enc_depth=args.enc_depth,
        enc_heads=args.enc_heads,
        enc_max_seq_len=args.enc_max_seq_len,
        dec_depth=args.dec_depth,
        dec_heads=args.dec_heads,
        dec_max_seq_len=args.dec_max_seq_len
    )

    # the checkpoint stores the whole model object; copy its weights into the fresh model
    model.load_state_dict(torch.load(args.pretrain_checkpoint, map_location=torch.device('cpu')).state_dict())
    device = ('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()  # disable dropout while extracting attention
    #----- Load Single Cell Data -----#
    scRNA_adata = sc.read_h5ad(args.RNA_path)[:10]
    scP_adata = sc.read_h5ad(args.Pro_path)[:10]
    print('Total number of origin RNA genes: ', scRNA_adata.n_vars)
    print('Total number of origin proteins: ', scP_adata.n_vars)
    print('Total number of origin cells: ', scRNA_adata.n_obs)
    print('# of NAN in RNA X', np.isnan(scRNA_adata.X).sum())
    print('# of NAN in protein X', np.isnan(scP_adata.X).sum())

    #--- Prepare the Test Set ---#
    setup_seed(args.seed)
    att_index = scRNA_adata.obs.index
    my_testset = fix_SCDataset(scRNA_adata[att_index], scP_adata[att_index], args.enc_max_seq_len, args.dec_max_seq_len)
    test_loader = torch.utils.data.DataLoader(my_testset, batch_size=args.test_batch_size)
    print("load data ended!")

    # accumulators for attention summed over cells
    enc_weights = np.zeros((args.enc_max_seq_len, args.enc_max_seq_len))
    dec_weights = np.zeros((args.dec_max_seq_len, args.dec_max_seq_len))
    enc2dec_weights = np.zeros((args.enc_max_seq_len, args.dec_max_seq_len))

    torch.cuda.empty_cache()
    with torch.no_grad():
        i = 0
        for x, y in test_loader:
            #--- Extract Feature ---#
            RNA_geneID = torch.tensor(x[:,1].tolist()).long().to(device)
            Protein_geneID = torch.tensor(y[:,1].tolist()).long().to(device)
            rna_mask = torch.tensor(x[:,2].tolist()).bool().to(device)
            pro_mask = torch.tensor(y[:,2].tolist()).bool().to(device)
            x = torch.tensor(x[:,0].tolist(), dtype=torch.float32).to(device)
            y = torch.tensor(y[:,0].tolist(), dtype=torch.float32).to(device)

            #--- Prediction ---#
            enc_weight, dec_weight, enc2dec_weight = model(x, RNA_geneID, Protein_geneID, enc_mask=rna_mask, dec_mask=pro_mask)
            enc_weights += enc_weight.detach().cpu().numpy()
            dec_weights += dec_weight.detach().cpu().numpy()
            enc2dec_weights += enc2dec_weight.detach().cpu().numpy()
            i += 1
            print('attention for cell', i)
    args.enc_max_seq_len = min(args.enc_max_seq_len, len(scRNA_adata.var.index))
    args.dec_max_seq_len = min(args.dec_max_seq_len, len(scP_adata.var.index))
    enc_weights = pd.DataFrame(enc_weights[:args.enc_max_seq_len, :args.enc_max_seq_len], columns=scRNA_adata.var.index[:args.enc_max_seq_len].tolist(), index=scRNA_adata.var.index[:args.enc_max_seq_len].tolist())
    dec_weights = pd.DataFrame(dec_weights[:args.dec_max_seq_len, :args.dec_max_seq_len], columns=scP_adata.var.index[:args.dec_max_seq_len].tolist(), index=scP_adata.var.index[:args.dec_max_seq_len].tolist())
    enc2dec_weights = pd.DataFrame(enc2dec_weights[:args.enc_max_seq_len, :args.dec_max_seq_len], columns=scP_adata.var.index[:args.dec_max_seq_len].tolist(), index=scRNA_adata.var.index[:args.enc_max_seq_len].tolist())
    enc_weights = attention_normalize(enc_weights)
    dec_weights = attention_normalize(dec_weights)
    enc2dec_weights = attention_normalize(enc2dec_weights)
    file_path = 'result/fig5/a'

    if not os.path.exists(file_path):
        os.makedirs(file_path)
    enc_weights.to_csv(file_path + '/encoder_attention_score.csv')
    dec_weights.to_csv(file_path + '/decoder_attention_score.csv')
    enc2dec_weights.to_csv(file_path + '/encoder2decoder_attention_score.csv')

    print('completed')


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/code/downstream_application/pseudo_knockout_gene.py:
--------------------------------------------------------------------------------
import os
import argparse
import torch
import scanpy as sc
import numpy as np
import pandas as pd
import anndata as ad
from torch.utils.data import Dataset

import sys
sys.path.append('code/model')
from performer_enc_dec import *
from utils import *

#################################################
#------------ Train & Test Function ------------#
#################################################
def test(model, test_loader, device):
    model.eval()
    y_hat_all = []
    with torch.no_grad():
        for x, y in test_loader:
            #--- Extract Feature ---#
            RNA_geneID = torch.tensor(x[:,1].tolist()).long().to(device)
            Protein_geneID = torch.tensor(y[:,1].tolist()).long().to(device)
            rna_mask = torch.tensor(x[:,2].tolist()).bool().to(device)
            pro_mask = torch.tensor(y[:,2].tolist()).bool().to(device)
            x = torch.tensor(x[:,0].tolist(), dtype=torch.float32).to(device)

            #--- Prediction ---#
            _, y_hat = model(x, RNA_geneID, Protein_geneID, enc_mask=rna_mask, dec_mask=pro_mask)
            y_hat = torch.squeeze(y_hat)
            if device == 'cpu':
                y_hat_all.extend(y_hat.numpy().tolist())
            else:
y_hat_all.extend(y_hat.detach().cpu().numpy().tolist()) 37 | 38 | return np.array(y_hat_all) 39 | 40 | ################################################# 41 | #---------- scData Preprocess Function ---------# 42 | ################################################# 43 | def pro_fix_sc_truncate_padding(x, length): 44 | ''' 45 | x: AnnData view of a single cell, shape (1, num_gene); returns values, gene IDs and mask truncated or zero-padded to `length` 46 | 47 | ''' 48 | len_x = len(x.X[0]) 49 | tmp = [i for i in x.X[0]] 50 | if len_x >= length: # truncate 51 | x_value = tmp[:length] 52 | gene = x.var.iloc[:length]['my_Id'].astype(int).values.tolist() 53 | mask = np.full(length, True).tolist() 54 | else: # padding 55 | x_value = list(tmp) # tmp is already a Python list, so copy it instead of calling .tolist() 56 | x_value.extend([0 for i in range(length-len_x)]) 57 | gene = x.var['my_Id'].astype(int).values.tolist() 58 | gene.extend([0 for i in range(length-len_x)]) 59 | mask = np.concatenate((np.full(len_x,True), np.full(length-len_x,False))) 60 | return x_value, gene, mask 61 | 62 | class fix_SCDataset(Dataset): 63 | def __init__(self, scRNA_adata, scP_adata, len_rna, len_protein): 64 | super().__init__() 65 | self.scRNA_adata = scRNA_adata 66 | self.scP_adata = scP_adata 67 | self.len_rna = len_rna 68 | self.len_protein = len_protein 69 | 70 | def __getitem__(self, index): 71 | k = self.scRNA_adata.obs.index[index] 72 | rna_value, rna_gene, rna_mask = fix_sc_normalize_truncate_padding(self.scRNA_adata[k], self.len_rna) 73 | pro_value, pro_gene, pro_mask = pro_fix_sc_truncate_padding(self.scP_adata[k], self.len_protein) 74 | return np.array([rna_value, rna_gene, rna_mask]), np.array([pro_value, pro_gene, pro_mask]) 75 | 76 | def __len__(self): 77 | return self.scRNA_adata.n_obs 78 | 79 | ################################################# 80 | #---------------- Main Function ----------------# 81 | ################################################# 82 | def main(): 83 | parser = argparse.ArgumentParser(description='PyTorch Example') 84 | parser.add_argument('--enc_max_seq_len', type=int, default=20000, 85 | help='sequence length of encoder') 86 |
parser.add_argument('--dec_max_seq_len', type=int, default=1000, 87 | help='sequence length of decoder') 88 | parser.add_argument('--test_batch_size', type=int, default=2, 89 | help='input batch size for testing (default: 2)') 90 | parser.add_argument('--seed', type=int, default=1105, 91 | help='random seed (default: 1105)') 92 | parser.add_argument('--pretrain_checkpoint', default='checkpoint/Dataset1_fine-tuned_scTranslator.pt', 93 | help='path for loading the checkpoint') 94 | parser.add_argument('--RNA_path', default='dataset/test/dataset1/GSM5008737_RNA_finetune_withcelltype.h5ad', 95 | help='path for loading the rna') 96 | parser.add_argument('--Pro_path', default='dataset/test/query_protein_ID.csv', 97 | help='path for loading the query protein list (CSV)') 98 | parser.add_argument('--gene', default='org', help="gene to knock out, e.g. TRIM39; 'org' (default) keeps all genes to measure baseline predictability") 99 | args = parser.parse_args() 100 | 101 | print('seed', args.seed) 102 | #----- Load model -----# 103 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 104 | print('device',device) 105 | model = torch.load(args.pretrain_checkpoint, map_location=torch.device(device)) 106 | 107 | #----- Load single-cell data -----# 108 | scRNA_adata = sc.read_h5ad(args.RNA_path) 109 | obs = pd.DataFrame(scRNA_adata.obs.values.tolist(), index=scRNA_adata.obs.index) 110 | protein = pd.read_csv(args.Pro_path) 111 | X = np.zeros((scRNA_adata.n_obs, protein.shape[0])) # placeholder protein matrix; only the protein IDs are used 112 | scP_adata = ad.AnnData(X, obs=obs, var=protein) 113 | print('Total number of origin RNA genes: ', scRNA_adata.n_vars) 114 | print('Total number of origin proteins: ', scP_adata.n_vars) 115 | print('Total number of origin cells: ', scRNA_adata.n_obs) 116 | print('# of NaN in RNA X', np.isnan(scRNA_adata.X).sum()) 117 | print('# of NaN in protein X', np.isnan(scP_adata.X).sum()) 118 | 119 | #--- Pseudo knock-out: remove the gene from the RNA input ---# 120 | gene = args.gene 121 | if gene != 'org': 122 | scRNA_adata = scRNA_adata[:, scRNA_adata.var.drop(index=gene).index] 123 | 124 | #--- Inference on ~14,000 (1.4W) query proteins ---#
 125 | # setup_seed(1105+10) 126 | test_index = scRNA_adata.obs.index 127 | for i in range(int(scP_adata.n_vars/args.dec_max_seq_len)): # proteins beyond the last full chunk are skipped 128 | my_testset = fix_SCDataset(scRNA_adata[test_index], scP_adata[test_index, args.dec_max_seq_len*i:args.dec_max_seq_len*(i+1)], args.enc_max_seq_len, args.dec_max_seq_len) 129 | test_loader = torch.utils.data.DataLoader(my_testset, batch_size=args.test_batch_size, drop_last=True) 130 | y_hat = test(model, test_loader, device) 131 | if i == 0: 132 | y_all = y_hat #(num_cell, num_protein) 133 | else: 134 | y_all = np.concatenate((y_all, y_hat), axis=1) 135 | print(y_hat.shape) 136 | #--- Save results ---# 137 | y_all = pd.DataFrame(y_all, columns=scP_adata.var['Hugo_Symbol'].tolist()[:y_all.shape[1]]) 138 | file_path = 'result/fig5/e' 139 | if not os.path.exists(file_path): 140 | os.makedirs(file_path) 141 | y_all.to_pickle(file_path+'/knock_out_'+gene+'.pkl') 142 | 143 | 144 | print('completed') 145 | 146 | 147 | if __name__ == '__main__': 148 | main() -------------------------------------------------------------------------------- /code/main_scripts/stage3_fine-tune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import datetime 4 | import argparse 5 | import warnings 6 | 7 | 8 | import torch.optim as optim 9 | from torch.optim.lr_scheduler import StepLR 10 | 11 | 12 | 13 | import scanpy as sc 14 | import numpy as np 15 | import pandas as pd 16 | from sklearn.model_selection import ShuffleSplit 17 | 18 | import sys 19 | sys.path.append('code/model') 20 | from performer_enc_dec import * 21 | from utils import * 22 | 23 | def main(): 24 | parser = argparse.ArgumentParser(description='PyTorch Example') 25 | parser.add_argument('--batch_size', type=int, default=1, metavar='N', 26 | help='input batch size for each GPU training (default: 1)') 27 | parser.add_argument('--test_batch_size', type=int, default=4, 28 | help='input batch size for testing (default: 4)') 29 |
parser.add_argument('--epochs', type=int, default=100, metavar='N', 30 | help='number of epochs to train (default: 100)') 31 | parser.add_argument('--lr', type=float, default=2*1e-4, metavar='LR', 32 | help='learning rate (default: 2e-4)') 33 | parser.add_argument('--gamma', type=float, default=1, metavar='M', 34 | help='StepLR decay factor (default: 1, i.e. no decay)') 35 | parser.add_argument('--gamma_step', type=float, default=2000, 36 | help='StepLR step size in epochs (default: 2000, i.e. never reached with the default 100 epochs)') 37 | parser.add_argument('--no-cuda', action='store_true', default=False, 38 | help='disables CUDA training') 39 | parser.add_argument('--seed', type=int, default=1105, 40 | help='random seed (default: 1105)') 41 | parser.add_argument('--repeat', type=int, default=1, 42 | help='for repeating experiments to change seed (default: 1)') 43 | parser.add_argument('--local_rank', default=0, type=int, 44 | help='local process rank for distributed training') 45 | parser.add_argument('--frac_finetune_test', type=float, default=0.1, 46 | help='test set ratio') 47 | parser.add_argument('--dim', type=int, default=128, 48 | help='latent dimension of each token') 49 | parser.add_argument('--enc_max_seq_len', type=int, default=20000, 50 | help='sequence length of encoder') 51 | parser.add_argument('--dec_max_seq_len', type=int, default=1000, 52 | help='sequence length of decoder') 53 | parser.add_argument('--translator_depth', type=int, default=2, 54 | help='translator depth') 55 | parser.add_argument('--initial_dropout', type=float, default=0.1, 56 | help='initial dropout rate of the translator') 57 | parser.add_argument('--enc_depth', type=int, default=2, 58 | help='number of encoder layers') 59 | parser.add_argument('--enc_heads', type=int, default=8, 60 | help='number of encoder attention heads') 61 | parser.add_argument('--dec_depth', type=int, default=2, 62 | help='number of decoder layers') 63 | parser.add_argument('--dec_heads', type=int, default=8, 64 | help='number of decoder attention heads') 65 |
parser.add_argument('--fix_set', action='store_false', 66 | help='by default use the fixed (aligned) dataset; pass this flag to use the disordered (un-aligned) dataset') 67 | parser.add_argument('--pretrain_checkpoint', default='checkpoint/stage2_single-cell_scTranslator.pt', 68 | help='path for loading the pretrain checkpoint') 69 | parser.add_argument('--resume', action='store_true', help='resume training from the checkpoint given by --path_checkpoint') 70 | parser.add_argument('--path_checkpoint', default='checkpoint/stage2_single-cell_scTranslator.pt', 71 | help='path for loading the resume checkpoint (must be specified when using --resume)') 72 | parser.add_argument('--RNA_path', default='dataset/test/dataset1/GSM5008737_RNA_finetune_withcelltype.h5ad', 73 | help='path for loading the rna') 74 | parser.add_argument('--Pro_path', default='dataset/test/dataset1/GSM5008738_protein_finetune_withcelltype.h5ad', 75 | help='path for loading the protein') 76 | args = parser.parse_args() 77 | warnings.filterwarnings('ignore') 78 | ######################### 79 | #--- Prepare for DDP ---# 80 | ######################### 81 | use_cuda = not args.no_cuda and torch.cuda.is_available() 82 | print("use_cuda: %s" % use_cuda) 83 | ngpus_per_node = torch.cuda.device_count() 84 | print("ngpus_per_node: %s" % ngpus_per_node) 85 | is_distributed = ngpus_per_node > 1 86 | print('seed', args.seed) 87 | setup_seed(args.seed) 88 | print(torch.__version__) 89 | # Initialize the distributed process group (NCCL backend) for inter-process communication 90 | torch.distributed.init_process_group(backend='nccl', timeout=datetime.timedelta(seconds=15400)) 91 | # Each process sets the GPU it should use based on its local rank 92 | print("local_rank: %s" % args.local_rank) 93 | device = torch.device("cuda", args.local_rank) 94 | print(device) 95 | torch.cuda.set_device(args.local_rank) 96 | rank = int(os.environ['RANK']) 97 | print('rank', rank) 98 | 99 | ########################### 100 | #--- Prepare The Model ---# 101 | ########################### 102 | model = scPerformerEncDec( 103 | dim=args.dim, 104 |
translator_depth=args.translator_depth, 105 | initial_dropout=args.initial_dropout, 106 | enc_depth=args.enc_depth, 107 | enc_heads=args.enc_heads, 108 | enc_max_seq_len=args.enc_max_seq_len, 109 | dec_depth=args.dec_depth, 110 | dec_heads=args.dec_heads, 111 | dec_max_seq_len=args.dec_max_seq_len 112 | ) 113 | model = torch.load(args.pretrain_checkpoint) 114 | # Resume training from breakpoints 115 | if args.resume == True: 116 | checkpoint = torch.load(args.path_checkpoint) 117 | model = checkpoint['net'] 118 | model = model.to(device) 119 | if is_distributed: 120 | print("start init process group") 121 | # device_ids will include all GPU devices by default 122 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True) 123 | print("end init process group") 124 | #--- Prepare Optimizer ---# 125 | optimizer = optim.Adam(model.parameters(), lr=args.lr, amsgrad=True) 126 | optimizer.load_state_dict(checkpoint['optimizer']) 127 | #--- Prepare Scheduler ---# 128 | scheduler = StepLR(optimizer, step_size=args.gamma_step, gamma=args.gamma) 129 | scheduler.load_state_dict(checkpoint['scheduler']) 130 | start_epoch = checkpoint['epoch'] 131 | else: 132 | start_epoch = 0 133 | model = model.to(device) 134 | if is_distributed: 135 | print("start init process group") 136 | # device_ids will include all GPU devices by default 137 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True) 138 | print("end init process group") 139 | #--- Prepare Optimizer ---# 140 | optimizer = optim.Adam(model.parameters(), lr=args.lr, amsgrad=True) 141 | #--- Prepare Scheduler ---# 142 | scheduler = StepLR(optimizer, step_size=args.gamma_step, gamma=args.gamma) 143 | ########################## 144 | #--- Prepare The Data ---# 145 | ########################## 146 | 147 | #--- Load Single Cell Data ---# 148 | scRNA_adata = sc.read_h5ad(args.RNA_path)[:100] 149 | scP_adata = 
sc.read_h5ad(args.Pro_path)[:100] # NOTE: both RNA and protein are limited to the first 100 cells 150 | print('Total number of origin RNA genes: ', scRNA_adata.n_vars) 151 | print('Total number of origin proteins: ', scP_adata.n_vars) 152 | print('Total number of origin cells: ', scRNA_adata.n_obs) 153 | print('# of NaN in RNA X', np.isnan(scRNA_adata.X).sum()) 154 | print('# of NaN in protein X', np.isnan(scP_adata.X).sum()) 155 | 156 | #--- Separate Training and Testing set ---# 157 | setup_seed(args.seed+args.repeat) 158 | train_index, test_index = next(ShuffleSplit(n_splits=1,test_size=args.frac_finetune_test).split(scRNA_adata.obs.index)) 159 | # --- RNA ---# 160 | train_rna = scRNA_adata[train_index] 161 | test_rna = scRNA_adata[test_index] 162 | # --- Protein ---# 163 | train_protein = scP_adata[train_index] 164 | test_protein = scP_adata[test_index] 165 | #--- Construct Dataloader ---# 166 | train_kwargs = {'batch_size': args.batch_size} 167 | test_kwargs = {'batch_size': args.test_batch_size} 168 | if use_cuda: 169 | cuda_kwargs = {'num_workers': 32, 170 | 'shuffle': False, 171 | 'prefetch_factor': 2, 172 | 'pin_memory': True} 173 | train_kwargs.update(cuda_kwargs) 174 | test_kwargs.update(cuda_kwargs) 175 | if args.fix_set == True: 176 | my_trainset = fix_SCDataset(train_rna, train_protein, args.enc_max_seq_len, args.dec_max_seq_len) 177 | my_testset = fix_SCDataset(test_rna, test_protein, args.enc_max_seq_len, args.dec_max_seq_len) 178 | else: 179 | my_trainset = SCDataset(train_rna, train_protein, args.enc_max_seq_len, args.dec_max_seq_len) 180 | my_testset = SCDataset(test_rna, test_protein, args.enc_max_seq_len, args.dec_max_seq_len) 181 | 182 | train_sampler = torch.utils.data.distributed.DistributedSampler(my_trainset) 183 | test_sampler = torch.utils.data.distributed.DistributedSampler(my_testset) 184 | 185 | train_loader = torch.utils.data.DataLoader(my_trainset, **train_kwargs, drop_last=True, sampler=train_sampler) 186 | test_loader = torch.utils.data.DataLoader(my_testset, **test_kwargs, drop_last=True,
sampler=test_sampler) 187 | print("end distributed data") 188 | 189 | ############################### 190 | #--- Training and Testing ---# 191 | ############################### 192 | start_time = time.time() 193 | for epoch in range(start_epoch+1, args.epochs + 1): 194 | train_sampler.set_epoch(epoch) 195 | test_sampler.set_epoch(epoch) 196 | torch.cuda.empty_cache() 197 | 198 | train_loss, train_ccc = train(args, model, device, train_loader, optimizer, epoch) 199 | scheduler.step() 200 | 201 | test_loss, test_ccc, y_hat, y = test(model, device, test_loader) 202 | y_pred = pd.DataFrame(y_hat, columns=test_protein.var.index.tolist()) 203 | y_truth = pd.DataFrame(y, columns=test_protein.var.index.tolist()) 204 | 205 | ############################## 206 | #--- Prepare for Storage ---# 207 | ############################## 208 | 209 | # save results in the first rank 210 | if args.RNA_path == 'dataset/test/dataset1/GSM5008737_RNA_finetune_withcelltype.h5ad': 211 | dataset_flag = '/seuratv4_16W' 212 | else: 213 | dataset_flag = '/new_data' 214 | file_path = 'result/test'+dataset_flag 215 | if not os.path.exists(file_path): 216 | os.makedirs(file_path) 217 | # save args 218 | if rank == 0: 219 | dict = vars(args) 220 | filename = open(file_path+'/args'+str(args.repeat)+'.txt','w') 221 | for k,v in dict.items(): 222 | filename.write(k+':'+str(v)) 223 | filename.write('\n') 224 | filename.close() 225 | 226 | #--- Save the Final Results ---# 227 | log_path = file_path+'/'+str(rank)+'_rank_log.csv' 228 | log_all = pd.DataFrame(columns=['train_loss', 'train_ccc', 'test_loss', 'test_ccc']) 229 | log_all.loc[args.repeat] = np.array([train_loss, train_ccc, test_loss, test_ccc]) 230 | log_all.to_csv(log_path) 231 | y_pred.to_csv(file_path+'/y_pred.csv') 232 | y_truth.to_csv(file_path+'/y_truth.csv') 233 | print('-'*40) 234 | print('single cell '+str(args.enc_max_seq_len)+' RNA To '+str(args.dec_max_seq_len)+' Protein on dataset'+dataset_flag) 235 | print('Overall performance on 
rank_%d in repeat_%d costTime: %.4fs' % (rank, args.repeat, time.time() - start_time)) 236 | print('Training Set: AVG mse %.4f, AVG ccc %.4f' % (np.mean(log_all['train_loss'][:args.repeat]), np.mean(log_all['train_ccc'][:args.repeat]))) 237 | print('Test Set: AVG mse %.4f, AVG ccc %.4f' % (np.mean(log_all['test_loss'][:args.repeat]), np.mean(log_all['test_ccc'][:args.repeat]))) 238 | 239 | if __name__ == '__main__': 240 | main() -------------------------------------------------------------------------------- /code/main_scripts/stage3_inference_without_finetune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import warnings 5 | 6 | import scanpy as sc 7 | import numpy as np 8 | import pandas as pd 9 | 10 | 11 | import sys 12 | sys.path.append('code/model') 13 | from performer_enc_dec import * 14 | from utils import * 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser(description='PyTorch Example') 18 | parser.add_argument('--repeat', type=int, default=1, 19 | help='for repeating experiments to change seed (default: 1)') 20 | parser.add_argument('--test_batch_size', type=int, default=4, 21 | help='input batch size for testing (default: 4)') 22 | parser.add_argument('--no-cuda', action='store_true', default=False, 23 | help='disables CUDA training') 24 | parser.add_argument('--seed', type=int, default=1105, 25 | help='random seed (default: 1105)') 26 | parser.add_argument('--enc_max_seq_len', type=int, default=20000, 27 | help='sequence length of encoder') 28 | parser.add_argument('--dec_max_seq_len', type=int, default=1000, 29 | help='sequence length of decoder') 30 | parser.add_argument('--fix_set', action='store_false', 31 | help='fix (aligned) or disordering (un-aligned) dataset') 32 | parser.add_argument('--pretrain_checkpoint', default='checkpoint/stage2_single-cell_scTranslator.pt', 33 | help='path for loading the pretrain checkpoint') 34 | 
parser.add_argument('--RNA_path', default='dataset/test/dataset1/GSM5008737_RNA_finetune_withcelltype.h5ad', 35 | help='path for loading the rna') 36 | parser.add_argument('--Pro_path', default='dataset/test/dataset1/GSM5008738_protein_finetune_withcelltype.h5ad', 37 | help='path for loading the protein') 38 | args = parser.parse_args() 39 | warnings.filterwarnings('ignore') 40 | 41 | ########################### 42 | #--- Prepare The Model ---# 43 | ########################### 44 | device = 'cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu' 45 | print('device',device) 46 | model = torch.load(args.pretrain_checkpoint, map_location=torch.device(device)) # map_location already places the model on `device` 47 | 48 | 49 | ########################## 50 | #--- Prepare The Data ---# 51 | ########################## 52 | 53 | #--- Load Single Cell Data ---# 54 | scRNA_adata = sc.read_h5ad(args.RNA_path)[:100] # NOTE: only the first 100 cells are used 55 | scP_adata = sc.read_h5ad(args.Pro_path)[:100] 56 | print('Total number of origin RNA genes: ', scRNA_adata.n_vars) 57 | print('Total number of origin proteins: ', scP_adata.n_vars) 58 | print('Total number of origin cells: ', scRNA_adata.n_obs) 59 | print('# of NaN in RNA X', np.isnan(scRNA_adata.X).sum()) 60 | print('# of NaN in protein X', np.isnan(scP_adata.X).sum()) 61 | 62 | #--- Use all loaded cells as the test set (no fine-tuning) ---# 63 | test_rna = scRNA_adata 64 | # --- Protein ---# 65 | test_protein = scP_adata[test_rna.obs.index] 66 | #--- Construct Dataloader ---# 67 | if args.fix_set == True: 68 | my_testset = fix_SCDataset(test_rna, test_protein, args.enc_max_seq_len, args.dec_max_seq_len) 69 | else: 70 | my_testset = SCDataset(test_rna, test_protein, args.enc_max_seq_len, args.dec_max_seq_len) 71 | 72 | test_loader = torch.utils.data.DataLoader(my_testset, batch_size=args.test_batch_size, drop_last=True) 73 | print("load data ended") 74 | 75 | ################## 76 | #--- Testing ---# 77 | ################## 78 | start_time = time.time() 79 | test_loss, test_ccc, y_hat, y = test(model, device, test_loader)
80 | y_pred = pd.DataFrame(y_hat, columns=test_protein.var.index.tolist()) 81 | y_truth = pd.DataFrame(y, columns=test_protein.var.index.tolist()) 82 | ############################## 83 | #--- Prepare for Storage ---# 84 | ############################## 85 | 86 | 87 | if args.RNA_path == 'dataset/test/dataset1/GSM5008737_RNA_finetune_withcelltype.h5ad': 88 | dataset_flag = '/seuratv4_16W-without_fine-tune' 89 | else: 90 | dataset_flag = '/new_data-without_fine-tune' 91 | file_path = 'result/test'+dataset_flag 92 | if not os.path.exists(file_path): 93 | os.makedirs(file_path) 94 | 95 | dict = vars(args) 96 | filename = open(file_path+'/args'+str(args.repeat)+'.txt','w') 97 | for k,v in dict.items(): 98 | filename.write(k+':'+str(v)) 99 | filename.write('\n') 100 | filename.close() 101 | 102 | #--- Save the Final Results ---# 103 | log_path = file_path+'/performance_log.csv' 104 | log_all = pd.DataFrame(columns=['test_loss', 'test_ccc']) 105 | log_all.loc[args.repeat] = np.array([test_loss, test_ccc]) 106 | log_all.to_csv(log_path) 107 | y_pred.to_csv(file_path+'/y_pred.csv') 108 | y_truth.to_csv(file_path+'/y_truth.csv') 109 | 110 | print('-'*40) 111 | print('single cell '+str(args.enc_max_seq_len)+' RNA To '+str(args.dec_max_seq_len)+' Protein on dataset'+dataset_flag) 112 | print('Overall performance in repeat_%d costTime: %.4fs' % ( args.repeat, time.time() - start_time)) 113 | print('Test Set: AVG mse %.4f, AVG ccc %.4f' % (np.mean(log_all['test_loss'][:args.repeat]), np.mean(log_all['test_ccc'][:args.repeat]))) 114 | if __name__ == '__main__': 115 | main() -------------------------------------------------------------------------------- /code/model/ID_dic/EntrezID_to_myID.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentAILabHealthcare/scTranslator/5f3c2de37c013bce8f7a14ebb135b7bdc141a6fe/code/model/ID_dic/EntrezID_to_myID.pkl 
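The ID-conversion helpers in `code/model/data_preprocessing_ID_convert.py` map each input gene to the model's internal `my_Id` by chaining dictionary lookups (for example, mouse gene symbol → human gene symbol → Entrez ID → `my_Id`), and genes whose chain breaks are dropped before the `_mapped.h5ad` file is written. The minimal sketch below illustrates that chaining in isolation. It is not the repository's code: the three tables are small hypothetical stand-ins for the pickled dictionaries in `code/model/ID_dic/`, the `my_Id` values are made up, and `chain_lookup` is an illustrative helper name.

```python
from typing import Dict, Hashable, Optional

# Hypothetical toy tables standing in for the pickled dictionaries under
# code/model/ID_dic/ (the real tables cover far more genes).
MOUSE_SYMBOL_TO_HGS: Dict[str, str] = {"Cd4": "CD4", "Cd8a": "CD8A"}
HGS_TO_ENTREZ: Dict[str, int] = {"CD4": 920, "CD8A": 925}
ENTREZ_TO_MYID: Dict[int, int] = {920: 101, 925: 102}  # made-up internal IDs


def chain_lookup(key: Hashable, *tables: Dict) -> Optional[int]:
    """Follow `key` through each mapping in order; return None at the first miss."""
    value = key
    for table in tables:
        value = table.get(value)
        if value is None:
            return None
    return value


def mouse_name_to_my_id(mouse_name: str) -> Optional[int]:
    """Mouse gene symbol -> human gene symbol -> Entrez ID -> my_Id."""
    return chain_lookup(mouse_name, MOUSE_SYMBOL_TO_HGS, HGS_TO_ENTREZ, ENTREZ_TO_MYID)


if __name__ == "__main__":
    genes = ["Cd4", "Cd8a", "Gm12345"]
    mapped = {g: mouse_name_to_my_id(g) for g in genes}
    print(mapped)  # {'Cd4': 101, 'Cd8a': 102, 'Gm12345': None}
    # Keep only genes that mapped, mirroring the filtering step of the script.
    kept = [g for g, my_id in mapped.items() if my_id is not None]
    print(kept)  # ['Cd4', 'Cd8a']
```

Unlike the repository helpers, which test `if myID:` (so an ID of 0 would also count as unmapped), this sketch checks `is None` explicitly.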
-------------------------------------------------------------------------------- /code/model/ID_dic/hgs_to_EntrezID.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentAILabHealthcare/scTranslator/5f3c2de37c013bce8f7a14ebb135b7bdc141a6fe/code/model/ID_dic/hgs_to_EntrezID.pkl -------------------------------------------------------------------------------- /code/model/ID_dic/mouse_gene_ID_to_human_gene_symbol.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentAILabHealthcare/scTranslator/5f3c2de37c013bce8f7a14ebb135b7bdc141a6fe/code/model/ID_dic/mouse_gene_ID_to_human_gene_symbol.pkl -------------------------------------------------------------------------------- /code/model/ID_dic/mouse_gene_symbol_to_human_gene_symbol.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentAILabHealthcare/scTranslator/5f3c2de37c013bce8f7a14ebb135b7bdc141a6fe/code/model/ID_dic/mouse_gene_symbol_to_human_gene_symbol.pkl -------------------------------------------------------------------------------- /code/model/data_preprocessing_ID_convert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import os 4 | import scanpy as sc 5 | import anndata 6 | 7 | def EntrezID_to_myID(EntrezID): 8 | myID = EntrezID_to_myID_dict.get(EntrezID) 9 | if myID: 10 | return myID 11 | else: 12 | return None 13 | 14 | def hugo_symbol_to_myID(hugo_symbol): 15 | EntrezID = hgs_to_EntrezID_dict.get(hugo_symbol) 16 | myID = EntrezID_to_myID_dict.get(EntrezID) 17 | if myID: 18 | return myID 19 | else: 20 | return None 21 | 22 | def mouse_id_to_myID(mouse_id): 23 | hugo_symbol = mouseID_to_hgs_dict.get(mouse_id) 24 | EntrezID = hgs_to_EntrezID_dict.get(hugo_symbol) 25 | myID = EntrezID_to_myID_dict.get(EntrezID) 26 | if 
myID: 27 | return myID 28 | else: 29 | return None 30 | 31 | def mouse_name_to_myID(mouse_name): 32 | hugo_symbol = mouse_gene_symbol_to_hgs_dict.get(mouse_name) 33 | EntrezID = hgs_to_EntrezID_dict.get(hugo_symbol) 34 | myID = EntrezID_to_myID_dict.get(EntrezID) 35 | if myID: 36 | return myID 37 | else: 38 | return None 39 | 40 | ################################################# 41 | #---------------- Main Function ----------------# 42 | ################################################# 43 | def main(): 44 | global hgs_to_EntrezID_dict 45 | global EntrezID_to_myID_dict 46 | global mouseID_to_hgs_dict 47 | global mouse_gene_symbol_to_hgs_dict 48 | 49 | #--- Settings ---# 50 | parser = argparse.ArgumentParser(description='PyTorch Example') 51 | parser.add_argument('--origin_gene_type', type=str, 52 | choices=['mouse_gene_ID', 'mouse_gene_symbol', 'human_gene_symbol', 'EntrezID'], 53 | default='mouse_gene_symbol', 54 | help='original gene type (must be one of: mouse_gene_ID, mouse_gene_symbol, human_gene_symbol, EntrezID)') 55 | parser.add_argument('--origin_gene_column', type=str, default='index', 56 | help='Column name holding the original gene identifiers, e.g.
 index, feature, gene, protein') 57 | parser.add_argument('--data_path', default='dataset/test/cite-seq_mouse/spleen_lymph_111.h5ad', 58 | help='path for loading the anndata') 59 | args = parser.parse_args() 60 | feature_column = args.origin_gene_column 61 | dic_path = 'code/model/ID_dic/' 62 | file_path = os.path.join(dic_path, 'hgs_to_EntrezID.pkl') 63 | with open(file_path, 'rb') as f: 64 | hgs_to_EntrezID_dict = pickle.load(f) 65 | 66 | file_path = os.path.join(dic_path, 'EntrezID_to_myID.pkl') 67 | with open(file_path, 'rb') as f: 68 | EntrezID_to_myID_dict = pickle.load(f) 69 | 70 | file_path = os.path.join(dic_path, 'mouse_gene_ID_to_human_gene_symbol.pkl') 71 | with open(file_path, 'rb') as f: 72 | mouseID_to_hgs_dict = pickle.load(f) 73 | 74 | file_path = os.path.join(dic_path, 'mouse_gene_symbol_to_human_gene_symbol.pkl') 75 | with open(file_path, 'rb') as f: 76 | mouse_gene_symbol_to_hgs_dict = pickle.load(f) 77 | 78 | data_file_path = args.data_path 79 | if data_file_path.endswith('.h5ad'): # strip the extension; it is re-added when reading and writing 80 | data_file_path = data_file_path[:-len('.h5ad')] 81 | 82 | try: 83 | adata = sc.read_h5ad(data_file_path + '.h5ad') 84 | except FileNotFoundError: 85 | print(f"cannot find file: {data_file_path}.h5ad") 86 | return 87 | print('# of genes before mapping:' , adata.n_vars) 88 | 89 | if args.origin_gene_type == 'mouse_gene_ID': 90 | orgID_to_myID = mouse_id_to_myID 91 | elif
= orgID_to_myID(org_id) 105 | 106 | # delete no mapping gene 107 | flag = adata.var.index[~adata.var['my_Id'].isna()] 108 | new_var = adata.var.loc[flag, :] 109 | # delete the expression value of no mapping gene 110 | new_X = adata[:, flag].X 111 | # create new AnnData object 112 | filtered_adata = anndata.AnnData(X=new_X, var=new_var, obs=adata.obs) 113 | 114 | filtered_adata.write(data_file_path + '_mapped.h5ad') 115 | print('# of genes after mapping:' , filtered_adata.n_vars) 116 | print('Gene mapping completed!') 117 | 118 | if __name__ == '__main__': 119 | main() 120 | -------------------------------------------------------------------------------- /code/model/performer_enc_dec.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | from torch import nn 4 | from performer_pytorch import * 5 | from math import ceil 6 | 7 | ENC_PREFIX = 'enc_' 8 | DEC_PREFIX = 'dec_' 9 | 10 | def group_dict_by_key(cond, d): 11 | return_val = [dict(),dict()] 12 | for key in d.keys(): 13 | match = bool(cond(key)) 14 | ind = int(not match) 15 | return_val[ind][key] = d[key] 16 | return (*return_val,) 17 | 18 | def string_begins_with(prefix, str): 19 | return bool(re.match(f'^{prefix}', str)) 20 | 21 | def group_by_key_prefix(prefix, d): 22 | return group_dict_by_key(lambda x: string_begins_with(prefix, x), d) 23 | 24 | def group_by_key_prefix_and_remove_prefix(prefix, d): 25 | kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: string_begins_with(prefix, x), d) 26 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) 27 | return kwargs_without_prefix, kwargs 28 | 29 | def extract_enc_dec_kwargs(kwargs): 30 | enc_kwargs, kwargs = group_by_key_prefix_and_remove_prefix(ENC_PREFIX, kwargs) 31 | dec_kwargs, kwargs = group_by_key_prefix_and_remove_prefix(DEC_PREFIX, kwargs) 32 | return enc_kwargs, dec_kwargs, kwargs 33 | 34 | def 
extract_and_set_enc_dec_kwargs(kwargs): 35 | enc_kwargs, dec_kwargs, kwargs = extract_enc_dec_kwargs(kwargs) 36 | if 'mask' in enc_kwargs: 37 | dec_kwargs.setdefault('context_mask', enc_kwargs['mask']) 38 | return enc_kwargs, dec_kwargs, kwargs 39 | 40 | ################################################# 41 | #-------------------- Model --------------------# 42 | ################################################# 43 | 44 | class scPerformerLM(nn.Module): 45 | def __init__( 46 | self, 47 | *, 48 | 49 | max_seq_len, 50 | dim,depth, 51 | heads, 52 | num_tokens=1, 53 | dim_head = 64, 54 | local_attn_heads = 0, 55 | local_window_size = 256, 56 | causal = False, 57 | ff_mult = 4, 58 | nb_features = None, 59 | feature_redraw_interval = 1000, 60 | reversible = False, 61 | ff_chunks = 1, 62 | ff_glu = False, 63 | emb_dropout = 0., 64 | ff_dropout = 0., 65 | attn_dropout = 0., 66 | generalized_attention = False, 67 | kernel_fn = nn.ReLU(), 68 | use_scalenorm = False, 69 | use_rezero = False, 70 | cross_attend = False, 71 | no_projection = False, 72 | tie_embed = False, 73 | rotary_position_emb = True, 74 | axial_position_emb = False, 75 | axial_position_shape = None, 76 | auto_check_redraw = True, 77 | qkv_bias = False, 78 | attn_out_bias = False, 79 | shift_tokens = False 80 | ): 81 | super().__init__() 82 | local_attn_heads = cast_tuple(local_attn_heads) 83 | 84 | self.max_seq_len = max_seq_len 85 | self.to_vector = nn.Linear(1,dim) 86 | self.pos_emb = nn.Embedding(85500,dim,padding_idx=0)# There are 75500 NCBI Gene ID obtained on 19th July, 2022 87 | self.layer_pos_emb = Always(None) 88 | self.dropout = nn.Dropout(emb_dropout) 89 | self.performer = Performer(dim, depth, heads, dim_head) 90 | # self.norm = nn.LayerNorm(dim) 91 | self.to_out = nn.Linear(dim, num_tokens) if not tie_embed else None 92 | 93 | def check_redraw_projections(self): 94 | self.performer.check_redraw_projections() 95 | 96 | def fix_projection_matrices_(self): 97 | 
self.performer.fix_projection_matrices_() 98 | 99 | def forward(self, x, geneID, return_encodings = False, **kwargs): 100 | b, n = x.shape[0], x.shape[1] 101 | assert n <= self.max_seq_len, f'sequence length {n} must not exceed the max sequence length {self.max_seq_len}' 102 | 103 | # token and positional embeddings 104 | if len(x.shape)<3: 105 | x = torch.unsqueeze(x,dim=2) 106 | x = self.to_vector(x) 107 | 108 | x += self.pos_emb(geneID) 109 | x = self.dropout(x) 110 | # performer layers 111 | layer_pos_emb = self.layer_pos_emb(x) 112 | x = self.performer(x, pos_emb = layer_pos_emb, **kwargs) 113 | 114 | if return_encodings: 115 | return x 116 | 117 | return torch.squeeze(self.to_out(x)) 118 | 119 | 120 | 121 | class MLPTranslator(nn.Module): 122 | """ 123 | Translator from RNA to protein: a stack of fully connected layers with an adjustable number of layers and per-layer dropout. 124 | Layer widths scale geometrically from num_fc_input to num_output_nodes by the factor fc_d. 125 | 126 | """ 127 | #----- Define all layers -----# 128 | def __init__(self, num_fc_input, num_output_nodes, num_fc_layers, initial_dropout, act = nn.ReLU(), **kwargs): 129 | super(MLPTranslator, self).__init__(**kwargs) 130 | fc_d = pow(num_fc_input/num_output_nodes,1/num_fc_layers) # geometric reduction factor between consecutive layer widths 131 | #--- Fully connected layers ---# 132 | self.num_fc_layers = num_fc_layers 133 | if num_fc_layers == 1: 134 | self.fc0 = nn.Linear(num_fc_input, num_output_nodes) 135 | else: 136 | # the first fc layer 137 | self.fc0 = nn.Linear(num_fc_input, int(ceil(num_fc_input/fc_d))) 138 | self.dropout0 = nn.Dropout(initial_dropout) 139 | if num_fc_layers == 2: 140 | # the last fc layer when num_fc_layers == 2 141 | self.fc1 = nn.Linear(int(ceil(num_fc_input/fc_d)), num_output_nodes) 142 | else: 143 | # the middle fc layer 144 | for i in range(1,num_fc_layers-1): 145 | tmp_input = int(ceil(num_fc_input/fc_d**i)) 146 | tmp_output = int(ceil(num_fc_input/fc_d**(i+1))) 147 | exec('self.fc{} = nn.Linear(tmp_input, tmp_output)'.format(i)) 148 | if i <
ceil(num_fc_layers/2) and 1.1**(i+1)*initial_dropout < 1: 149 | exec('self.dropout{} = nn.Dropout(1.1**(i+1)*initial_dropout)'.format(i)) 150 | elif i >= ceil(num_fc_layers/2) and 1.1**(num_fc_layers-1-i)*initial_dropout < 1: 151 | exec('self.dropout{} = nn.Dropout(1.1**(num_fc_layers-1-i)*initial_dropout)'.format(i)) 152 | else: 153 | exec('self.dropout{} = nn.Dropout(initial_dropout)'.format(i)) 154 | # the last fc layer 155 | exec('self.fc{} = nn.Linear(tmp_output, num_output_nodes)'.format(i+1)) 156 | 157 | #--- Activation function ---# 158 | self.act = act 159 | 160 | #----- Forward -----# 161 | def forward(self, x): 162 | # x size: [batch size, feature_dim] 163 | 164 | if self.num_fc_layers == 1: 165 | outputs = self.fc0(x) 166 | else: 167 | # the first fc layer 168 | outputs = self.act(self.dropout0(self.fc0(x))) 169 | if self.num_fc_layers == 2: 170 | # the last fc layer when num_fc_layers == 2 171 | outputs = self.fc1(outputs) 172 | else: 173 | # the middle fc layer 174 | for i in range(1,self.num_fc_layers-1): 175 | outputs = eval('self.act(self.dropout{}(self.fc{}(outputs)))'.format(i,i)) 176 | # the last fc layer 177 | outputs = eval('self.fc{}(outputs)'.format(i+1)) 178 | 179 | return outputs 180 | 181 | class scPerformerEncDec(nn.Module): 182 | def __init__( 183 | self, 184 | dim, 185 | translator_depth, 186 | initial_dropout, 187 | ignore_index = 0, 188 | pad_value = 0, 189 | tie_token_embeds = False, 190 | no_projection = False, 191 | **kwargs 192 | ): 193 | super().__init__() 194 | enc_kwargs, dec_kwargs, _ = extract_enc_dec_kwargs(kwargs) 195 | 196 | assert 'dim' not in dec_kwargs and 'dim' not in enc_kwargs, 'you must set the dim for both encoder and decoder' 197 | 198 | enc_kwargs['dim'] = dec_kwargs['dim'] = dim 199 | enc_kwargs['no_projection'] = dec_kwargs['no_projection'] = no_projection 200 | enc = scPerformerLM(**enc_kwargs) 201 | dec = scPerformerLM(**dec_kwargs) 202 | 203 | 204 | self.enc = enc 205 | self.translator = 
MLPTranslator(enc_kwargs['max_seq_len'], dec_kwargs['max_seq_len'], translator_depth, initial_dropout) 206 | self.dec = dec 207 | 208 | def forward(self, seq_in, seq_inID, seq_outID, **kwargs): 209 | enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs) 210 | encodings = self.enc(seq_in, seq_inID, return_encodings = True, **enc_kwargs)# batch_size, input_seq_length, dim 211 | seq_out = self.translator(encodings.transpose(1,2).contiguous()).transpose(1,2).contiguous() # batch_size, out_seq_length, dim 212 | return encodings, self.dec(seq_out, seq_outID, **dec_kwargs) 213 | -------------------------------------------------------------------------------- /code/model/performer_pytorch.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from torch.cuda.amp import autocast 6 | from einops import rearrange, repeat 7 | from functools import partial 8 | from contextlib import contextmanager 9 | from local_attention import LocalAttention 10 | from reversible import ReversibleSequence, SequentialSequence 11 | from distutils.version import LooseVersion 12 | 13 | TORCH_GE_1_8_0 = LooseVersion(torch.__version__) >= LooseVersion('1.8.0') 14 | 15 | try: 16 | from apex import amp 17 | APEX_AVAILABLE = True 18 | except ImportError: 19 | APEX_AVAILABLE = False 20 | 21 | # helpers 22 | 23 | def exists(val): 24 | return val is not None 25 | 26 | def empty(tensor): 27 | return tensor.numel() == 0 28 | 29 | def default(val, d): 30 | return val if exists(val) else d 31 | 32 | @contextmanager 33 | def null_context(): 34 | yield 35 | 36 | def cast_tuple(val): 37 | return (val,) if not isinstance(val, tuple) else val 38 | 39 | def get_module_device(module): 40 | return next(module.parameters()).device 41 | 42 | def find_modules(nn_module, type): 43 | return [module for module in nn_module.modules() if isinstance(module, type)] 44 | 45 | class 
Always(nn.Module): 46 | def __init__(self, val): 47 | super().__init__() 48 | self.val = val 49 | 50 | def forward(self, *args, **kwargs): 51 | return self.val 52 | 53 | # token shifting helper and classes 54 | 55 | def shift(t, amount, mask = None): 56 | if amount == 0: 57 | return t 58 | 59 | if exists(mask): 60 | t = t.masked_fill(~mask[..., None], 0.) 61 | 62 | return F.pad(t, (0, 0, amount, -amount), value = 0.) 63 | 64 | class PreShiftTokens(nn.Module): 65 | def __init__(self, shifts, fn): 66 | super().__init__() 67 | self.fn = fn 68 | self.shifts = tuple(shifts) 69 | 70 | def forward(self, x, **kwargs): 71 | mask = kwargs.get('mask', None) 72 | shifts = self.shifts 73 | segments = len(shifts) 74 | feats_per_shift = x.shape[-1] // segments 75 | splitted = x.split(feats_per_shift, dim = -1) 76 | segments_to_shift, rest = splitted[:segments], splitted[segments:] 77 | segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts))) 78 | x = torch.cat((*segments_to_shift, *rest), dim = -1) 79 | return self.fn(x, **kwargs) 80 | 81 | # kernel functions 82 | 83 | # transcribed from jax to pytorch from 84 | # https://github.com/google-research/google-research/blob/master/performer/fast_attention/jax/fast_attention.py 85 | 86 | def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None): 87 | b, h, *_ = data.shape 88 | 89 | data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1. 
90 | 91 | ratio = (projection_matrix.shape[0] ** -0.5) 92 | 93 | projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h) 94 | projection = projection.type_as(data) 95 | 96 | data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection) 97 | 98 | diag_data = data ** 2 99 | diag_data = torch.sum(diag_data, dim=-1) 100 | diag_data = (diag_data / 2.0) * (data_normalizer ** 2) 101 | diag_data = diag_data.unsqueeze(dim=-1) 102 | 103 | if is_query: 104 | data_dash = ratio * ( 105 | torch.exp(data_dash - diag_data - 106 | torch.amax(data_dash, dim=-1, keepdim=True).detach()) + eps) 107 | else: 108 | data_dash = ratio * ( 109 | torch.exp(data_dash - diag_data - torch.amax(data_dash, dim=(-1, -2), keepdim=True).detach()) + eps) 110 | 111 | return data_dash.type_as(data) 112 | 113 | def generalized_kernel(data, *, projection_matrix, kernel_fn = nn.ReLU(), kernel_epsilon = 0.001, normalize_data = True, device = None): 114 | b, h, *_ = data.shape 115 | 116 | data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1. 
117 | 118 | if projection_matrix is None: 119 | return kernel_fn(data_normalizer * data) + kernel_epsilon 120 | 121 | projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h) 122 | projection = projection.type_as(data) 123 | 124 | data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection) 125 | 126 | data_prime = kernel_fn(data_dash) + kernel_epsilon 127 | return data_prime.type_as(data) 128 | 129 | def orthogonal_matrix_chunk(cols, device = None): 130 | unstructured_block = torch.randn((cols, cols), device = device) 131 | if TORCH_GE_1_8_0: 132 | q, r = torch.linalg.qr(unstructured_block.cpu(), mode = 'reduced') 133 | else: 134 | q, r = torch.qr(unstructured_block.cpu(), some = True) 135 | q, r = map(lambda t: t.to(device), (q, r)) 136 | return q.t() 137 | 138 | def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, device = None): 139 | nb_full_blocks = int(nb_rows / nb_columns) 140 | 141 | block_list = [] 142 | 143 | for _ in range(nb_full_blocks): 144 | q = orthogonal_matrix_chunk(nb_columns, device = device) 145 | block_list.append(q) 146 | 147 | remaining_rows = nb_rows - nb_full_blocks * nb_columns 148 | if remaining_rows > 0: 149 | q = orthogonal_matrix_chunk(nb_columns, device = device) 150 | block_list.append(q[:remaining_rows]) 151 | 152 | final_matrix = torch.cat(block_list) 153 | 154 | if scaling == 0: 155 | multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1) 156 | elif scaling == 1: 157 | multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device = device) 158 | else: 159 | raise ValueError(f'Invalid scaling {scaling}') 160 | 161 | return torch.diag(multiplier) @ final_matrix 162 | 163 | # linear attention classes with softmax kernel 164 | 165 | # non-causal linear attention 166 | def linear_attention(q, k, v): 167 | k_cumsum = k.sum(dim = -2) 168 | D_inv = 1. 
/ torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q)) 169 | context = torch.einsum('...nd,...ne->...de', k, v) 170 | out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv) 171 | return out 172 | 173 | # efficient causal linear attention, created by EPFL 174 | # TODO: rewrite EPFL's CUDA kernel to do mixed precision and remove half to float conversion and back 175 | def causal_linear_attention(q, k, v, eps = 1e-6): 176 | from fast_transformers.causal_product import CausalDotProduct 177 | autocast_enabled = torch.is_autocast_enabled() 178 | is_half = isinstance(q, torch.cuda.HalfTensor) 179 | assert not is_half or APEX_AVAILABLE, 'half tensors can only be used if nvidia apex is available' 180 | cuda_context = null_context if not autocast_enabled else partial(autocast, enabled = False) 181 | 182 | causal_dot_product_fn = amp.float_function(CausalDotProduct.apply) if is_half else CausalDotProduct.apply 183 | 184 | k_cumsum = k.cumsum(dim=-2) + eps 185 | D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q)) 186 | 187 | with cuda_context(): 188 | if autocast_enabled: 189 | q, k, v = map(lambda t: t.float(), (q, k, v)) 190 | 191 | out = causal_dot_product_fn(q, k, v) 192 | 193 | out = torch.einsum('...nd,...n->...nd', out, D_inv) 194 | return out 195 | 196 | # inefficient causal linear attention, without cuda code, for reader's reference 197 | # not being used 198 | def causal_linear_attention_noncuda(q, k, v, chunk_size = 128, eps = 1e-6): 199 | last_k_cumsum = 0 200 | last_context_cumsum = 0 201 | outs = [] 202 | 203 | for q, k, v in zip(*map(lambda t: t.chunk(chunk_size, dim = -2), (q, k, v))): 204 | k_cumsum = last_k_cumsum + k.cumsum(dim=-2) 205 | 206 | D_inv = 1. 
/ torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q) + eps) 207 | context = torch.einsum('...nd,...ne->...nde', k, v) 208 | context_cumsum = last_context_cumsum + context.cumsum(dim=-3) 209 | out = torch.einsum('...nde,...nd,...n->...ne', context_cumsum, q, D_inv) 210 | 211 | last_k_cumsum = k_cumsum[:, :, -1:] 212 | last_context_cumsum = context_cumsum[:, :, -1:] 213 | outs.append(out) 214 | 215 | return torch.cat(outs, dim = -2) 216 | 217 | class FastAttention(nn.Module): 218 | def __init__(self, dim_heads, nb_features = None, ortho_scaling = 0, causal = False, generalized_attention = False, kernel_fn = nn.ReLU(), no_projection = False): 219 | super().__init__() 220 | nb_features = default(nb_features, int(dim_heads * math.log(dim_heads))) 221 | 222 | self.dim_heads = dim_heads 223 | self.nb_features = nb_features 224 | self.ortho_scaling = ortho_scaling 225 | 226 | self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows = self.nb_features, nb_columns = dim_heads, scaling = ortho_scaling) 227 | projection_matrix = self.create_projection() 228 | self.register_buffer('projection_matrix', projection_matrix) 229 | 230 | self.generalized_attention = generalized_attention 231 | self.kernel_fn = kernel_fn 232 | 233 | # if this is turned on, no projection will be used 234 | # queries and keys will be softmax-ed as in the original efficient attention paper 235 | self.no_projection = no_projection 236 | 237 | self.causal = causal 238 | if causal: 239 | try: 240 | import fast_transformers.causal_product.causal_product_cuda 241 | self.causal_linear_fn = partial(causal_linear_attention) 242 | except ImportError: 243 | print('unable to import cuda code for auto-regressive Performer. 
will default to the memory inefficient non-cuda version') 244 | self.causal_linear_fn = causal_linear_attention_noncuda 245 | 246 | @torch.no_grad() 247 | def redraw_projection_matrix(self, device): 248 | projections = self.create_projection(device = device) 249 | self.projection_matrix.copy_(projections) 250 | del projections 251 | 252 | def forward(self, q, k, v): 253 | device = q.device 254 | 255 | if self.no_projection: 256 | q = q.softmax(dim = -1) 257 | k = torch.exp(k) if self.causal else k.softmax(dim = -2) 258 | 259 | elif self.generalized_attention: 260 | create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device) 261 | q, k = map(create_kernel, (q, k)) 262 | 263 | else: 264 | create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device) 265 | q = create_kernel(q, is_query = True) 266 | k = create_kernel(k, is_query = False) 267 | 268 | attn_fn = linear_attention if not self.causal else self.causal_linear_fn 269 | out = attn_fn(q, k, v) 270 | return out 271 | 272 | # a module for keeping track of when to update the projections 273 | 274 | class ProjectionUpdater(nn.Module): 275 | def __init__(self, instance, feature_redraw_interval): 276 | super().__init__() 277 | self.instance = instance 278 | self.feature_redraw_interval = feature_redraw_interval 279 | self.register_buffer('calls_since_last_redraw', torch.tensor(0)) 280 | 281 | def fix_projections_(self): 282 | self.feature_redraw_interval = None 283 | 284 | def redraw_projections(self): 285 | model = self.instance 286 | 287 | if not self.training: 288 | return 289 | 290 | if exists(self.feature_redraw_interval) and self.calls_since_last_redraw >= self.feature_redraw_interval: 291 | device = get_module_device(model) 292 | 293 | fast_attentions = find_modules(model, FastAttention) 294 | for fast_attention in fast_attentions: 295 | fast_attention.redraw_projection_matrix(device) 296 | 297 | 
self.calls_since_last_redraw.zero_() 298 | return 299 | 300 | self.calls_since_last_redraw += 1 301 | 302 | def forward(self, x): 303 | raise NotImplementedError 304 | 305 | # classes 306 | 307 | class ReZero(nn.Module): 308 | def __init__(self, fn): 309 | super().__init__() 310 | self.g = nn.Parameter(torch.tensor(1e-3)) 311 | self.fn = fn 312 | 313 | def forward(self, x, **kwargs): 314 | return self.fn(x, **kwargs) * self.g 315 | 316 | class PreScaleNorm(nn.Module): 317 | def __init__(self, dim, fn, eps=1e-5): 318 | super().__init__() 319 | self.fn = fn 320 | self.g = nn.Parameter(torch.ones(1)) 321 | self.eps = eps 322 | 323 | def forward(self, x, **kwargs): 324 | n = torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps) 325 | x = x / n * self.g 326 | return self.fn(x, **kwargs) 327 | 328 | class PreLayerNorm(nn.Module): 329 | def __init__(self, dim, fn): 330 | super().__init__() 331 | self.norm = nn.LayerNorm(dim) 332 | self.fn = fn 333 | def forward(self, x, **kwargs): 334 | return self.fn(self.norm(x), **kwargs) 335 | 336 | class Chunk(nn.Module): 337 | def __init__(self, chunks, fn, along_dim = -1): 338 | super().__init__() 339 | self.dim = along_dim 340 | self.chunks = chunks 341 | self.fn = fn 342 | 343 | def forward(self, x, **kwargs): 344 | if self.chunks == 1: 345 | return self.fn(x, **kwargs) 346 | chunks = x.chunk(self.chunks, dim = self.dim) 347 | return torch.cat([self.fn(c, **kwargs) for c in chunks], dim = self.dim) 348 | 349 | class FeedForward(nn.Module): 350 | def __init__(self, dim, mult = 4, dropout = 0., activation = None, glu = False): 351 | super().__init__() 352 | activation = default(activation, nn.GELU) 353 | 354 | self.glu = glu 355 | self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1)) 356 | self.act = activation() 357 | self.dropout = nn.Dropout(dropout) 358 | self.w2 = nn.Linear(dim * mult, dim) 359 | 360 | def forward(self, x, **kwargs): 361 | if not self.glu: 362 | x = self.w1(x) 363 | x = self.act(x) 364 | else: 365 | x, v = 
self.w1(x).chunk(2, dim=-1) 366 | x = self.act(x) * v 367 | 368 | x = self.dropout(x) 369 | x = self.w2(x) 370 | return x 371 | 372 | class Attention(nn.Module): 373 | def __init__( 374 | self, 375 | dim, 376 | causal = False, 377 | heads = 8, 378 | dim_head = 64, 379 | local_heads = 0, 380 | local_window_size = 256, 381 | nb_features = None, 382 | feature_redraw_interval = 1000, 383 | generalized_attention = False, 384 | kernel_fn = nn.ReLU(), 385 | dropout = 0., 386 | no_projection = False, 387 | qkv_bias = False, 388 | attn_out_bias = True 389 | ): 390 | super().__init__() 391 | # assert dim % heads == 0, 'dimension must be divisible by number of heads' 392 | dim_head = default(dim_head, dim // heads) 393 | inner_dim = dim_head * heads 394 | self.fast_attention = FastAttention(dim_head, nb_features, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, no_projection = no_projection) 395 | 396 | self.heads = heads 397 | self.global_heads = heads - local_heads 398 | self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None 399 | 400 | self.to_q = nn.Linear(dim, inner_dim, bias = qkv_bias) 401 | self.to_k = nn.Linear(dim, inner_dim, bias = qkv_bias) 402 | self.to_v = nn.Linear(dim, inner_dim, bias = qkv_bias) 403 | self.to_out = nn.Linear(inner_dim, dim, bias = attn_out_bias) 404 | self.dropout = nn.Dropout(dropout) 405 | 406 | def forward(self, x, pos_emb = None, context = None, mask = None, context_mask = None, **kwargs): 407 | b, n, _, h, gh = *x.shape, self.heads, self.global_heads 408 | 409 | cross_attend = exists(context) 410 | 411 | context = default(context, x) 412 | context_mask = default(context_mask, mask) if not cross_attend else context_mask 413 | 414 | q, k, v = self.to_q(x), self.to_k(context), self.to_v(context) 415 | 416 | q, k, v = map(lambda t: 
rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) # (batch size, num_heads, seq_len, dim_head) 417 | (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v)) 418 | 419 | attn_outs = [] 420 | 421 | if not empty(q): 422 | if exists(context_mask): 423 | global_mask = context_mask[:, None, :, None] 424 | v.masked_fill_(~global_mask, 0.) 425 | 426 | if exists(pos_emb) and not cross_attend: 427 | q, k = apply_rotary_pos_emb(q, k, pos_emb) 428 | 429 | out = self.fast_attention(q, k, v) 430 | attn_outs.append(out) 431 | 432 | if not empty(lq): 433 | assert not cross_attend, 'local attention is not compatible with cross attention' 434 | out = self.local_attn(lq, lk, lv, input_mask = mask) 435 | attn_outs.append(out) 436 | 437 | out = torch.cat(attn_outs, dim = 1) 438 | out = rearrange(out, 'b h n d -> b n (h d)') 439 | out = self.to_out(out) 440 | return self.dropout(out) 441 | 442 | class SelfAttention(Attention): 443 | def forward(self, *args, context = None, **kwargs): 444 | assert not exists(context), 'self attention should not receive context' 445 | return super().forward(*args, **kwargs) 446 | 447 | class CrossAttention(Attention): 448 | def forward(self, *args, context = None, **kwargs): 449 | assert exists(context), 'cross attention should receive context' 450 | return super().forward(*args, context = context, **kwargs) 451 | 452 | # positional embeddings 453 | 454 | class AbsolutePositionalEmbedding(nn.Module): 455 | def __init__(self, dim, max_seq_len): 456 | super().__init__() 457 | self.emb = nn.Embedding(max_seq_len, dim) 458 | 459 | def forward(self, x): 460 | t = torch.arange(x.shape[1], device=x.device) 461 | return self.emb(t) 462 | 463 | # rotary positional embedding helpers 464 | 465 | def rotate_every_two(x): 466 | x = rearrange(x, '... (d j) -> ... d j', j = 2) 467 | x1, x2 = x.unbind(dim = -1) 468 | x = torch.stack((-x2, x1), dim = -1) 469 | return rearrange(x, '... d j -> ... 
(d j)') 470 | 471 | def apply_rotary_pos_emb(q, k, sinu_pos): 472 | sinu_pos = rearrange(sinu_pos, '() n (j d) -> n j d', j = 2) 473 | sin, cos = sinu_pos.unbind(dim = -2) 474 | sin, cos = map(lambda t: repeat(t, 'b n -> b (n j)', j = 2), (sin, cos)) 475 | q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k)) 476 | return q, k 477 | 478 | # sinusoidal positional embeddings 479 | 480 | class FixedPositionalEmbedding(nn.Module): 481 | def __init__(self, dim, max_seq_len): 482 | super().__init__() 483 | inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 484 | position = torch.arange(0, max_seq_len, dtype=torch.float) 485 | sinusoid_inp = torch.einsum("i,j->ij", position, inv_freq) 486 | emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) 487 | self.register_buffer('emb', emb) 488 | 489 | def forward(self, x): 490 | return self.emb[None, :x.shape[1], :].to(x) 491 | 492 | # performer 493 | 494 | class Performer(nn.Module): 495 | def __init__( 496 | self, 497 | dim, 498 | depth, 499 | heads, 500 | dim_head, 501 | local_attn_heads = 0, 502 | local_window_size = 256, 503 | causal = False, 504 | ff_mult = 4, 505 | nb_features = None,#64,# 506 | feature_redraw_interval = 1000, 507 | reversible = False, 508 | ff_chunks = 1, 509 | generalized_attention = False, 510 | kernel_fn = nn.ReLU(), 511 | use_scalenorm = False, 512 | use_rezero = False, 513 | ff_glu = False, 514 | ff_dropout = 0., 515 | attn_dropout = 0., 516 | cross_attend = False, 517 | no_projection = False, 518 | auto_check_redraw = True, 519 | qkv_bias = True, 520 | attn_out_bias = True, 521 | shift_tokens = False 522 | ): 523 | super().__init__() 524 | layers = nn.ModuleList([]) 525 | local_attn_heads = cast_tuple(local_attn_heads) 526 | local_attn_heads = local_attn_heads * depth if len(local_attn_heads) == 1 else local_attn_heads 527 | assert len(local_attn_heads) == depth, 'tuple specifying number of local attention heads per depth must be equal to the total depth' 
528 | assert all(map(lambda n: n >= 0 and n <= heads, local_attn_heads)), 'local attention head value must be between 0 and the total number of heads (inclusive)' 529 | 530 | if use_scalenorm: 531 | wrapper_fn = partial(PreScaleNorm, dim) 532 | elif use_rezero: 533 | wrapper_fn = ReZero 534 | else: 535 | wrapper_fn = partial(PreLayerNorm, dim) 536 | 537 | for _, local_heads in zip(range(depth), local_attn_heads): 538 | 539 | attn = SelfAttention(dim, causal = causal, heads = heads, dim_head = dim_head, local_heads = local_heads, local_window_size = local_window_size, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, dropout = attn_dropout, no_projection = no_projection, qkv_bias = qkv_bias, attn_out_bias = attn_out_bias) 540 | ff = Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1) 541 | 542 | if shift_tokens: 543 | shift = (0, 1) if causal else (-1, 0, 1) 544 | attn, ff = map(lambda t: PreShiftTokens(shift, t), (attn, ff)) 545 | 546 | attn, ff = map(wrapper_fn, (attn, ff)) 547 | layers.append(nn.ModuleList([attn, ff])) 548 | 549 | if not cross_attend: 550 | continue 551 | 552 | layers.append(nn.ModuleList([ 553 | wrapper_fn(CrossAttention(dim, heads = heads, dim_head = dim_head, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, dropout = attn_dropout, no_projection = no_projection, qkv_bias = qkv_bias, attn_out_bias = attn_out_bias)), 554 | wrapper_fn(Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1)) 555 | ])) 556 | 557 | execute_type = ReversibleSequence if reversible else SequentialSequence 558 | 559 | route_attn = ((True, False),) * depth * (2 if cross_attend else 1) 560 | route_context = ((False, False), (True, False)) * depth 561 | attn_route_map = {'mask': route_attn, 'pos_emb': route_attn} 562 | context_route_map = {'context': route_context, 'context_mask': route_context} if 
cross_attend else {} 563 | self.net = execute_type(layers, args_route = {**attn_route_map, **context_route_map}) 564 | 565 | # keeping track of when to redraw projections for all attention layers 566 | self.auto_check_redraw = auto_check_redraw 567 | self.proj_updater = ProjectionUpdater(self.net, feature_redraw_interval) 568 | 569 | def fix_projection_matrices_(self): 570 | self.proj_updater.feature_redraw_interval = None 571 | 572 | def forward(self, x, **kwargs): 573 | if self.auto_check_redraw: 574 | self.proj_updater.redraw_projections() 575 | return self.net(x, **kwargs) -------------------------------------------------------------------------------- /code/model/reversible.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | # from operator import itemgetter 4 | from torch.autograd.function import Function 5 | from torch.utils.checkpoint import get_device_states, set_device_states 6 | 7 | # for routing arguments into the functions of the reversible layer 8 | def route_args(router, args, depth): 9 | routed_args = [(dict(), dict()) for _ in range(depth)] 10 | matched_keys = [key for key in args.keys() if key in router] 11 | 12 | for key in matched_keys: 13 | val = args[key] 14 | for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])): 15 | new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes) 16 | routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) 17 | return routed_args 18 | 19 | # following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html 20 | class Deterministic(nn.Module): 21 | def __init__(self, net): 22 | super().__init__() 23 | self.net = net 24 | self.cpu_state = None 25 | self.cuda_in_fwd = None 26 | self.gpu_devices = None 27 | self.gpu_states = None 28 | 29 | def record_rng(self, *args): 30 | self.cpu_state = torch.get_rng_state() 31 | if 
torch.cuda._initialized: 32 | self.cuda_in_fwd = True 33 | self.gpu_devices, self.gpu_states = get_device_states(*args) 34 | 35 | def forward(self, *args, record_rng = False, set_rng = False, **kwargs): 36 | if record_rng: 37 | self.record_rng(*args) 38 | 39 | if not set_rng: 40 | return self.net(*args, **kwargs) 41 | 42 | rng_devices = [] 43 | if self.cuda_in_fwd: 44 | rng_devices = self.gpu_devices 45 | 46 | with torch.random.fork_rng(devices=rng_devices, enabled=True): 47 | torch.set_rng_state(self.cpu_state) 48 | if self.cuda_in_fwd: 49 | set_device_states(self.gpu_devices, self.gpu_states) 50 | return self.net(*args, **kwargs) 51 | 52 | # heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py 53 | # once multi-GPU is confirmed working, refactor and send PR back to source 54 | class ReversibleBlock(nn.Module): 55 | def __init__(self, f, g): 56 | super().__init__() 57 | self.f = Deterministic(f) 58 | self.g = Deterministic(g) 59 | 60 | def forward(self, x, f_args = {}, g_args = {}): 61 | x1, x2 = torch.chunk(x, 2, dim=2) 62 | y1, y2 = None, None 63 | 64 | with torch.no_grad(): 65 | y1 = x1 + self.f(x2, record_rng=self.training, **f_args) 66 | y2 = x2 + self.g(y1, record_rng=self.training, **g_args) 67 | 68 | return torch.cat([y1, y2], dim=2) 69 | 70 | def backward_pass(self, y, dy, f_args = {}, g_args = {}): 71 | y1, y2 = torch.chunk(y, 2, dim=2) 72 | del y 73 | 74 | dy1, dy2 = torch.chunk(dy, 2, dim=2) 75 | del dy 76 | 77 | with torch.enable_grad(): 78 | y1.requires_grad = True 79 | gy1 = self.g(y1, set_rng=True, **g_args) 80 | torch.autograd.backward(gy1, dy2) 81 | 82 | with torch.no_grad(): 83 | x2 = y2 - gy1 84 | del y2, gy1 85 | 86 | dx1 = dy1 + y1.grad 87 | del dy1 88 | y1.grad = None 89 | 90 | with torch.enable_grad(): 91 | x2.requires_grad = True 92 | fx2 = self.f(x2, set_rng=True, **f_args) 93 | torch.autograd.backward(fx2, dx1, retain_graph=True) 94 | 95 | with torch.no_grad(): 96 | x1 = y1 - fx2 97 | del y1, fx2 
98 | 99 | dx2 = dy2 + x2.grad 100 | del dy2 101 | x2.grad = None 102 | 103 | x = torch.cat([x1, x2.detach()], dim=2) 104 | dx = torch.cat([dx1, dx2], dim=2) 105 | 106 | return x, dx 107 | 108 | class _ReversibleFunction(Function): 109 | @staticmethod 110 | def forward(ctx, x, blocks, args): 111 | ctx.args = args 112 | for block, kwarg in zip(blocks, args): 113 | x = block(x, **kwarg) 114 | ctx.y = x.detach() 115 | ctx.blocks = blocks 116 | return x 117 | 118 | @staticmethod 119 | def backward(ctx, dy): 120 | y = ctx.y 121 | args = ctx.args 122 | for block, kwargs in zip(ctx.blocks[::-1], args[::-1]): 123 | y, dy = block.backward_pass(y, dy, **kwargs) 124 | return dy, None, None 125 | 126 | class SequentialSequence(nn.Module): 127 | def __init__(self, layers, args_route = {}): 128 | super().__init__() 129 | assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers' 130 | self.layers = layers 131 | self.args_route = args_route 132 | 133 | def forward(self, x, **kwargs): 134 | args = route_args(self.args_route, kwargs, len(self.layers)) 135 | layers_and_args = list(zip(self.layers, args)) 136 | 137 | for (f, g), (f_args, g_args) in layers_and_args: 138 | x = x + f(x, **f_args) 139 | x = x + g(x, **g_args) 140 | return x 141 | 142 | class ReversibleSequence(nn.Module): 143 | def __init__(self, blocks, args_route = {}): 144 | super().__init__() 145 | self.args_route = args_route 146 | self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks]) 147 | 148 | def forward(self, x, **kwargs): 149 | x = torch.cat([x, x], dim=-1) 150 | 151 | blocks = self.blocks 152 | args = route_args(self.args_route, kwargs, len(blocks)) 153 | args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args)) 154 | 155 | out = _ReversibleFunction.apply(x, blocks, args) 156 | return torch.stack(out.chunk(2, dim=-1)).sum(dim=0) 157 | 
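The memory saving in `reversible.py` comes from the additive coupling in `ReversibleBlock`: `forward` computes `y1 = x1 + f(x2)` and `y2 = x2 + g(y1)`, and `backward_pass` recovers the inputs as `x2 = y2 - g(y1)` and `x1 = y1 - f(x2)`, so activations do not need to be stored. The NumPy sketch below is not part of the repository. It only checks that this inversion holds. `f` and `g` are hypothetical stand-ins (random `tanh` projections) for the attention and feed-forward sublayers, and autograd is left out entirely. Note that `scPerformerLM` constructs `Performer(dim, depth, heads, dim_head)` only, so `reversible` keeps its default `False` and `SequentialSequence` is the path actually executed there.

```python
import numpy as np

def reversible_forward(x1, x2, f, g):
    # Additive coupling as in ReversibleBlock.forward:
    #   y1 = x1 + f(x2),  y2 = x2 + g(y1)
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2

def reversible_inverse(y1, y2, f, g):
    # Recover the inputs from the outputs, as backward_pass does:
    #   x2 = y2 - g(y1),  x1 = y1 - f(x2)
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return x1, x2

rng = np.random.default_rng(0)
W_f = rng.standard_normal((4, 4))
W_g = rng.standard_normal((4, 4))
f = lambda t: np.tanh(t @ W_f)  # hypothetical stand-in for the attention sublayer
g = lambda t: np.tanh(t @ W_g)  # hypothetical stand-in for the feed-forward sublayer

# Two input halves of shape (seq_len=3, dim=4), like torch.chunk(x, 2, dim=2) per sample.
x1, x2 = rng.standard_normal((2, 3, 4))
y1, y2 = reversible_forward(x1, x2, f, g)
r1, r2 = reversible_inverse(y1, y2, f, g)
print(np.allclose(r1, x1), np.allclose(r2, x2))  # prints: True True
```

The inversion works for any `f` and `g`, because each half is only ever updated by adding a function of the other half. This is why the backward pass can trade recomputation for stored activations.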
-------------------------------------------------------------------------------- /code/model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import random 6 | import numpy as np 7 | from torch.utils.data import Dataset 8 | 9 | ################################################# 10 | #------------ Train & Test Function ------------# 11 | ################################################# 12 | def setup_seed(seed): 13 | #--- Fix random seed ---# 14 | torch.manual_seed(seed) 15 | torch.cuda.manual_seed_all(seed) 16 | np.random.seed(seed) 17 | random.seed(seed) 18 | torch.backends.cudnn.deterministic = True 19 | 20 | def train(args, model, device, train_loader, optimizer, epoch): 21 | model.train() 22 | loss2 = nn.CosineSimilarity(dim=0, eps=1e-8) 23 | train_loss = 0 24 | train_ccc = 0 25 | for batch_idx, (x, y) in enumerate(train_loader): 26 | #--- Extract Feature ---# 27 | RNA_geneID = torch.tensor(x[:,1].tolist()).long().to(device) 28 | Protein_geneID = torch.tensor(y[:,1].tolist()).long().to(device) 29 | rna_mask = torch.tensor(x[:,2].tolist()).bool().to(device) 30 | pro_mask = torch.tensor(y[:,2].tolist()).bool().to(device) 31 | x = torch.tensor(x[:,0].tolist(), dtype=torch.float32).to(device) 32 | y = torch.tensor(y[:,0].tolist(), dtype=torch.float32).to(device) 33 | 34 | #--- Prediction ---# 35 | optimizer.zero_grad() 36 | _, y_hat = model(x, RNA_geneID, Protein_geneID, enc_mask=rna_mask, dec_mask=pro_mask) 37 | 38 | #--- Compute Performance Metric ---# 39 | y_hat = torch.squeeze(y_hat) 40 | y_hat = torch.where(torch.isnan(y), torch.full_like(y_hat, 0), y_hat) 41 | y = torch.where(torch.isnan(y), torch.full_like(y, 0), y) 42 | 43 | loss = F.mse_loss(y_hat[pro_mask], y[pro_mask]) 44 | train_loss += loss.item() 45 | 46 | train_ccc += loss2(y_hat[pro_mask], y[pro_mask]).item() 47 | loss.backward() 48 | optimizer.step() 49 | 50 | train_loss 
/= len(train_loader) 51 | train_ccc /= len(train_loader) 52 | print('-'*15) 53 | print('--- Epoch {} ---'.format(epoch), flush=True) 54 | print('-'*15) 55 | print('Training set: Average loss: {:.4f}, Average ccc: {:.4f}'.format(train_loss, train_ccc), flush=True) 56 | return train_loss, train_ccc 57 | 58 | def test(model, device, test_loader): 59 | model.eval() 60 | loss2 = nn.CosineSimilarity(dim=0, eps=1e-8) 61 | test_loss = 0 62 | test_ccc = 0 63 | y_hat_all = [] 64 | y_all = [] 65 | with torch.no_grad(): 66 | for x, y in test_loader: 67 | #--- Extract Feature ---# 68 | RNA_geneID = torch.tensor(x[:,1].tolist()).long().to(device) 69 | Protein_geneID = torch.tensor(y[:,1].tolist()).long().to(device) 70 | rna_mask = torch.tensor(x[:,2].tolist()).bool().to(device) 71 | pro_mask = torch.tensor(y[:,2].tolist()).bool().to(device) 72 | x = torch.tensor(x[:,0].tolist(), dtype=torch.float32).to(device) 73 | y = torch.tensor(y[:,0].tolist(), dtype=torch.float32).to(device) 74 | 75 | #--- Prediction ---# 76 | _, y_hat = model(x, RNA_geneID, Protein_geneID, enc_mask=rna_mask, dec_mask=pro_mask) 77 | 78 | #--- Compute Performance Metric ---# 79 | y_hat = torch.squeeze(y_hat) 80 | y_hat = torch.where(torch.isnan(y), torch.full_like(y_hat, 0), y_hat) 81 | y = torch.where(torch.isnan(y), torch.full_like(y, 0), y) 82 | test_loss += F.mse_loss(y_hat[pro_mask], y[pro_mask]).item() 83 | test_ccc += loss2(y_hat[pro_mask], y[pro_mask]).item() 84 | 85 | if device == 'cpu': 86 | y_hat_all.extend(y_hat[pro_mask].view(y_hat.shape[0], -1).numpy().tolist()) 87 | y_all.extend(y[pro_mask].view(y_hat.shape[0], -1).numpy().tolist()) 88 | else: 89 | y_hat_all.extend(y_hat[pro_mask].view(y_hat.shape[0], -1).detach().cpu().numpy().tolist()) 90 | y_all.extend(y[pro_mask].view(y_hat.shape[0], -1).detach().cpu().numpy().tolist()) 91 | 92 | 93 | test_loss /= len(test_loader) 94 | test_ccc /= len(test_loader) 95 | return test_loss, test_ccc, np.array(y_hat_all), np.array(y_all) 96 | 97 | 
#################################################
#---------- Dataset Preprocess Function ---------#
#################################################
def normalization(x, low=1e-8, high=1):
    # Min-max scaling to [low, high]; a constant vector (MAX == MIN) divides by zero and yields NaN
    x = np.asarray(x)
    MIN = x.min()
    MAX = x.max()
    x = low + (x-MIN)/(MAX-MIN)*(high-low) # zoom to (low, high)
    return x

def fix_sc_normalize_truncate_padding(x, length):
    '''
    x: AnnData view of a single cell (1 x num_gene) whose .var has a 'my_Id' column.
    Keeps the first `length` genes (truncation) or zero-pads to `length`,
    and returns (values, gene IDs, mask), where mask is True for real entries.
    '''
    len_x = len(x.X[0])
    tmp = [i for i in x.X[0]]
    tmp = normalization(tmp)
    if len_x >= length: # truncate
        x_value = tmp[:length]
        gene = x.var.iloc[:length]['my_Id'].astype(int).values.tolist()
        mask = np.full(length, True).tolist()
    else: # padding
        x_value = tmp.tolist()
        x_value.extend([0 for i in range(length-len_x)])
        gene = x.var['my_Id'].astype(int).values.tolist()
        gene.extend([0 for i in range(length-len_x)])
        mask = np.concatenate((np.full(len_x,True), np.full(length-len_x,False)))
    return x_value, gene, mask

class fix_SCDataset(Dataset):
    def __init__(self, scRNA_adata, scP_adata, len_rna, len_protein):
        super().__init__()
        self.scRNA_adata = scRNA_adata
        self.scP_adata = scP_adata
        self.len_rna = len_rna
        self.len_protein = len_protein

    def __getitem__(self, index):
        k = self.scRNA_adata.obs.index[index]
        rna_value, rna_gene, rna_mask = fix_sc_normalize_truncate_padding(self.scRNA_adata[k], self.len_rna)
        pro_value, pro_gene, pro_mask = fix_sc_normalize_truncate_padding(self.scP_adata[k], self.len_protein)
        return np.array([rna_value, rna_gene, rna_mask]), np.array([pro_value, pro_gene, pro_mask])

    def __len__(self):
        return self.scRNA_adata.n_obs

def sc_normalize_truncate_padding(x, length):
    '''
    x: AnnData view of a single cell (1 x num_gene) whose .var has a 'my_Id' column.
    Same as fix_sc_normalize_truncate_padding, except that truncation
    keeps a random subset of `length` genes instead of the first `length`.
    '''
    len_x = len(x.X[0])
    tmp = [i for i in x.X[0]]
    tmp = normalization(tmp)
    if len_x >= length: # truncate
        gene = random.sample(range(len_x), length)
        x_value = [i for i in tmp[gene]]
        gene = x.var.iloc[gene]['my_Id'].astype(int).values.tolist()
        mask = np.full(length, True).tolist()
    else: # padding
        x_value = tmp.tolist()
        x_value.extend([0 for i in range(length-len_x)])
        gene = x.var['my_Id'].astype(int).values.tolist()
        gene.extend([0 for i in range(length-len_x)])
        mask = np.concatenate((np.full(len_x,True), np.full(length-len_x,False)))
    return x_value, gene, mask

class SCDataset(Dataset):
    def __init__(self, scRNA_adata, scP_adata, len_rna, len_protein):
        super().__init__()
        self.scRNA_adata = scRNA_adata
        self.scP_adata = scP_adata
        self.len_rna = len_rna
        self.len_protein = len_protein

    def __getitem__(self, index):
        k = self.scRNA_adata.obs.index[index]
        rna_value, rna_gene, rna_mask = sc_normalize_truncate_padding(self.scRNA_adata[k], self.len_rna)
        pro_value, pro_gene, pro_mask = sc_normalize_truncate_padding(self.scP_adata[k], self.len_protein)
        return np.array([rna_value, rna_gene, rna_mask]), np.array([pro_value, pro_gene, pro_mask])

    def __len__(self):
        return self.scRNA_adata.n_obs

def attention_normalize(weights):
    # Min-max scale each column of the DataFrame, then each row
    for i in weights.columns:
        W_min = weights[i].min()
        W_max = weights[i].max()
        weights[i] = (weights[i]-W_min)/(W_max-W_min)
    for i in range(weights.shape[0]):
        W_min = weights.iloc[i].min()
        W_max = weights.iloc[i].max()
        weights.iloc[i] = (weights.iloc[i]-W_min)/(W_max-W_min)
    return weights
--------------------------------------------------------------------------------
/code/visualization/fig3.r:
--------------------------------------------------------------------------------
library(dplyr)
library(ggpubr)
library(data.table)
library(gridExtra)
library(extrafont)

# Font loading (device = "win") and windows() below assume a Windows session
font_import(pattern = "Arial", prompt = FALSE)
loadfonts(device = "win")
cwd <- getwd()
base_path <- normalizePath(file.path(cwd))
mode <- 'fewshot'
path <- file.path(base_path, 'result', 'fig3', sprintf('Systematic_benchmark_results_%s.csv', mode))
perform_all <- fread(path)
current_names <- names(perform_all)
current_names[current_names == "cosine_similarity"] <- "Cosine similarity"
current_names[current_names == "mse"] <- "Mean squared error"
current_names[current_names == "mae"] <- "Mean absolute error"
current_names[current_names == "pearson"] <- "Pearson correlation coefficient"

names(perform_all) <- current_names
perform_all <- perform_all %>% mutate(index = row_number())

# The metric columns were renamed above, so they must be referenced by their new names
perform_summary <- perform_all %>%
  group_by(methods, dataset) %>%
  summarise(across(c(`Mean squared error`, `Cosine similarity`), ~ mean(.x, na.rm = TRUE)), .groups = 'drop')

plot_fig3 <- function(perform_all, mode, x_value, y_value) {
  model_labels <- c('scMM', 'CMAE', 'MultiVI', 'BABEL', 'Seurat', 'scMoGNN', 'cTP-net', 'sciPENN', 'scTranslator-scratch', 'scTranslator')
  num_model <- length(model_labels)
  colors <- scales::hue_pal()(num_model)
  names(colors) <- model_labels

  for (metric in c("Cosine similarity", "Pearson correlation coefficient")) { # c("Mean squared error", "Mean absolute error")
    perform_all$methods <- factor(perform_all$methods, levels = model_labels)
    p <- ggboxplot(perform_all, x = "methods", y = metric, color = "methods", palette = colors) +
      geom_jitter(aes(color = methods), width = 0.2, height = 0, alpha = 0.6) +
      geom_boxplot(aes(color = methods)) +
      facet_wrap(~dataset, nrow = 1) +
      labs(y = metric, x = NULL) +
      theme(axis.text.x = element_blank(),
            axis.ticks.x = element_blank(),
            panel.grid.major.y = element_line(size = 0.5, linetype = 'dashed', color = 'gray80'),
            panel.grid.major.x = element_blank(),
            panel.border = element_rect(color = "black", fill = NA, size = 0.5),
            strip.background = element_rect(color = "black", fill = "gray90", size = 0.5),
            strip.text = element_text(size = 10),
            text = element_text(size = 11),
            legend.position = "none",
            panel.spacing = unit(0.5, "lines"),
            axis.line.y = element_blank(),
            axis.line.x = element_blank()) +
      # Note: scale limits drop values outside [0, 0.98] before the box statistics are computed
      scale_y_continuous(limits = c(0, 0.98), breaks = seq(0, 1, by = 0.1))
    print(p)
  }
}

x_value = 12
y_value = 3
windows(width = x_value, height = y_value)
plot_fig3(perform_all, mode, x_value, y_value)
--------------------------------------------------------------------------------
/dataset/download_link.txt:
--------------------------------------------------------------------------------
https://drive.google.com/drive/folders/1XmdKikkG3g0yl1vKmY9lhru_78Nc0NMk?usp=sharing
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
local-attention=1.4.3
numpy=1.21.5
pandas=1.2.4
python=3.8.13
pytorch=1.12.1
scanpy=1.9.1
scikit-learn=1.1.1
scipy=1.6.2
--------------------------------------------------------------------------------
/result/download_link.txt:
--------------------------------------------------------------------------------
https://drive.google.com/drive/folders/1R4JEJjwP27yLqYMlOulmvGiocnlJZT3Z?usp=sharing
--------------------------------------------------------------------------------
/scTranslator.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/scTranslator/5f3c2de37c013bce8f7a14ebb135b7bdc141a6fe/scTranslator.jpg
--------------------------------------------------------------------------------
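The fixed-length encoding used by `fix_sc_normalize_truncate_padding` and the masked metrics computed in `train`/`test` (both in `/code/model/utils.py`) can be illustrated without AnnData or PyTorch. The snippet below is a minimal NumPy-only sketch, not the repository's code. The function names `min_max_scale`, `truncate_or_pad` and `masked_mse_and_cosine` are hypothetical. Inputs are plain arrays rather than single-cell AnnData slices. The guard for constant vectors is an added assumption, since the original divides by zero in that case. The random gene subsampling performed by `sc_normalize_truncate_padding` is not reproduced.

```python
import numpy as np


def min_max_scale(values, low=1e-8, high=1.0):
    """Rescale a 1-D vector to [low, high], like `normalization` in utils.py.

    Assumption (not in the original): a constant vector maps to `high`
    instead of producing NaN through a division by zero.
    """
    values = np.asarray(values, dtype=np.float64)
    vmin, vmax = values.min(), values.max()
    if vmax == vmin:
        return np.full_like(values, high)
    return low + (values - vmin) / (vmax - vmin) * (high - low)


def truncate_or_pad(values, gene_ids, length):
    """Hypothetical stand-in for fix_sc_normalize_truncate_padding.

    Scales the whole vector first, then keeps the first `length` entries
    (truncation) or appends zeros (padding). The mask is True only for
    real, non-padded positions.
    """
    values = min_max_scale(values)
    gene_ids = np.asarray(gene_ids, dtype=np.int64)
    n = len(values)
    if n >= length:
        return values[:length], gene_ids[:length], np.ones(length, dtype=bool)
    pad = length - n
    padded_values = np.concatenate([values, np.zeros(pad)])
    padded_genes = np.concatenate([gene_ids, np.zeros(pad, dtype=np.int64)])
    mask = np.concatenate([np.ones(n, dtype=bool), np.zeros(pad, dtype=bool)])
    return padded_values, padded_genes, mask


def masked_mse_and_cosine(y_hat, y, mask, eps=1e-8):
    """Masked MSE and cosine similarity over all unpadded positions.

    NaN targets are zeroed in both arrays first, mirroring train/test;
    the cosine term approximates nn.CosineSimilarity(dim=0, eps=1e-8).
    """
    y_hat = np.where(np.isnan(y), 0.0, y_hat)
    y = np.where(np.isnan(y), 0.0, y)
    a, b = y_hat[mask], y[mask]
    mse = float(np.mean((a - b) ** 2))
    denom = max(np.linalg.norm(a), eps) * max(np.linalg.norm(b), eps)
    return mse, float(a @ b / denom)


vals, genes, mask = truncate_or_pad([2.0, 4.0, 6.0], [11, 12, 13], length=5)
print(genes.tolist(), mask.tolist())  # [11, 12, 13, 0, 0] [True, True, True, False, False]
```

Keeping the mask alongside the zero-padded values is what allows `train` and `test` to score only real proteins. Without the mask, the padded positions would contaminate both the MSE and the cosine similarity.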