├── .gitignore ├── tokenizer ├── __init__.py └── gene_tokenizer.py ├── assets ├── framework.png └── batch_correlation.png ├── data ├── __init__.py ├── csv2h5ad.py ├── preprocess.py └── dataloader.py ├── model ├── __init__.py ├── flashDiff.py ├── flashMHA.py ├── reversible.py ├── transformer.py └── EpiFoundation.py ├── loss └── loss.py ├── configs ├── pretrain │ └── atac_cross_binary.yml └── eval │ └── mini_atlas.yml ├── README.md ├── prepare_data.py ├── env.yml ├── finetune.py ├── pretrain_ddp.py ├── utils.py ├── pretrain_fsdp.py └── eval.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | **/.DS_Store -------------------------------------------------------------------------------- /tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .gene_tokenizer import * 2 | -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/EpiFoundation/HEAD/assets/framework.png -------------------------------------------------------------------------------- /assets/batch_correlation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/EpiFoundation/HEAD/assets/batch_correlation.png -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from data.csv2h5ad import convert_dense_csv_to_sparse_h5ad 2 | from data.preprocess import Preprocessor 3 | from data.dataloader import prepare_scDataset, prepare_data -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from EpiFoundation.model.EpiFoundation import EpiFoundation 2 | import logging 3 | import sys 4 | 5 | logger = logging.getLogger("scMultiomics") 6 | # check if logger has been initialized 7 | if not logger.hasHandlers() or len(logger.handlers) == 0: 8 | logger.propagate = False 9 | logger.setLevel(logging.INFO) 10 | handler = logging.StreamHandler(sys.stdout) 11 | handler.setLevel(logging.INFO) 12 | formatter = logging.Formatter( 13 | "%(name)s - %(levelname)s - %(message)s", datefmt="%H:%M:%S" 14 | ) 15 | handler.setFormatter(formatter) 16 | logger.addHandler(handler) 17 | -------------------------------------------------------------------------------- /loss/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | 7 | class MaskedMSELoss(nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def masked_mse_loss( 12 | self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor 13 | ) -> torch.Tensor: 14 | """ 15 | Compute the masked MSE loss between input and target. 
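        Only positions where ``mask`` is nonzero contribute to the loss; the summed
        squared error over those positions is divided by ``mask.sum()``. For example:

            >>> loss = MaskedMSELoss()
            >>> pred = torch.tensor([1.0, 2.0, 3.0])
            >>> target = torch.tensor([1.0, 0.0, 5.0])
            >>> mask = torch.tensor([True, False, True])
            >>> loss.masked_mse_loss(pred, target, mask)  # ((3 - 5) ** 2) / 2
            tensor(2.)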
16 | """ 17 | mask = mask.float() 18 | loss = F.mse_loss(input * mask, target * mask, reduction="sum") 19 | return loss / mask.sum() 20 | 21 | def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: 22 | return self.masked_mse_loss(input, target, mask) -------------------------------------------------------------------------------- /data/csv2h5ad.py: -------------------------------------------------------------------------------- 1 | import scanpy as sc 2 | import scipy 3 | import os 4 | 5 | def convert_dense_csv_to_sparse_h5ad(input_file, output_floder=None): 6 | 7 | file_name = input_file.split('/')[-1].split('.')[0] 8 | cell_type = file_name.split(':')[1] 9 | batch = file_name.split(':')[0] 10 | output_file = f'{output_floder}/{file_name}.h5ad' 11 | 12 | adata = sc.read_csv(input_file).T 13 | adata.X = scipy.sparse.csr_matrix(adata.X) 14 | # add a obs column, traverse all the cells, and add the cell name to the obs column 15 | adata.obs['cell_id'] = adata.obs_names 16 | # traverse all cells, get the biggest value of each cell, and add it to the obs column 17 | for i in range(adata.shape[0]): 18 | # max expression value of each cell, used for binning 19 | adata.obs['max_exp'] = adata.X[i].max() 20 | adata.obs['non_zero'] = adata.X[i].count_nonzero() 21 | adata.obs['cell_type'] = cell_type 22 | # add the batch information to the obs column 23 | # currently, I treat name of (such asENCSR008NTI) as batch information 24 | # Cells from different batches may have a different meaning of same expression value 25 | adata.obs['batch'] = batch 26 | 27 | adata.var['gene_id'] = adata.var_names 28 | adata.uns['binned'] = False 29 | adata.uns['HVG'] = False 30 | adata.uns['log1p'] = False 31 | 32 | adata.write(output_file) 33 | return adata 34 | 35 | -------------------------------------------------------------------------------- /configs/pretrain/atac_cross_binary.yml: -------------------------------------------------------------------------------- 1 | task_name: fsdp_debug 2 | 3 | train: 4 | # for ditributed training 5 | local_rank: 0 6 | # random seed 7 | seed: 2002 8 | 9 | # training hyperparameters 10 | batch_size: 8 11 | lr: 1e-4 12 | epochs: 150 13 | gradient_accumulation_steps: 20 14 | amp: True 15 | save_ckpt_freq: 20 16 | resume: False 17 | 18 | model: 19 | encoder: transformer 20 | # pretrained: /home/jwu418/workspace/scMultiomics/experiment/atac_cross_pretrain/ckpts/Epoch_32_Step_26144_atac_cross_pretrain.pth # set to None if not using pretrained model 21 | pretrained: null 22 | embedding_method: id_only 23 | atac_max_len: 8000 24 | rna_max_len: 8000 25 | embedding_dim: 512 26 | num_layers: 6 27 | head_num: 8 28 | head_dim: 1024 29 | dropout: 0.2 30 | additional_config_path: /path/to/additional_config.json 31 | cell_emb_style: cls 32 | mvc_arch_style: concat query 33 | use_batch_labels: False 34 | 35 | task_weight: 36 | cell_type: 0.0 37 | mvc: 1.0 38 | 39 | valid: 40 | freq: 2 41 | 42 | data: 43 | bin_num: &bn 2 44 | append_cls: True 45 | train: 46 | atac_path: /home/jwu418/workspace/data/ours/test/bmmc_kidney_paired_atac.h5ad 47 | atac_key: X 48 | rna_path: /home/jwu418/workspace/data/ours/test/bmmc_kidney_rna_binned_binning_2.h5ad 49 | rna_key: X_binned 50 | test: 51 | atac_path: /home/jwu418/workspace/data/ours/valid/bmmc_kidney_paired_atac.h5ad 52 | atac_key: X 53 | rna_path: /home/jwu418/workspace/data/ours/valid/bmmc_kidney_rna_binned_binning_2.h5ad 54 | rna_key: X_binned 55 | 56 | vocab: 57 | rna_path: 
/home/jwu418/workspace/data/ours/vocab/bmmc_rna_vocab.json 58 | atac_path: /home/jwu418/workspace/data/ours/vocab/bmmc_atac_vocab.json 59 | cell_type_path: /home/jwu418/workspace/data/ours/vocab/bmmc_kidney_cell_vocab.json 60 | batch_path: /home/jwu418/workspace/data/ours/vocab/bmmc_kidney_batch_vocab.json 61 | special_tokens: 62 | pad: {token: , value: 2} 63 | # value of the mask is 1 plus bin_num 64 | mask: {token: , value: 3} 65 | cls: {token: , value: 0} 66 | 67 | -------------------------------------------------------------------------------- /configs/eval/mini_atlas.yml: -------------------------------------------------------------------------------- 1 | task_name: mini_atlas_eval 2 | train: 3 | # for ditributed training 4 | local_rank: 0 5 | # random seed 6 | seed: 2002 7 | 8 | # training hyperparameters 9 | batch_size: 8 10 | lr: 1e-4 11 | epochs: 150 12 | gradient_accumulation_steps: 20 13 | amp: True 14 | save_ckpt_freq: 20 15 | resume: True 16 | 17 | model: 18 | encoder: transformer 19 | pretrained: /home/jwu418/workspace/scMultiomics/result/mini_atlas/ckpts/Epoch_100_Step_187900_mini_atlas.pth # set to None if not using pretrained model 20 | # pretrained: null 21 | embedding_method: id_only 22 | atac_max_len: 12000 23 | rna_max_len: 8000 24 | embedding_dim: 512 25 | num_layers: 6 26 | head_num: 8 27 | head_dim: 1024 28 | dropout: 0.15 29 | additional_config_path: /path/to/additional_config.json 30 | cell_emb_style: cls 31 | mvc_arch_style: concat query 32 | use_batch_labels: False 33 | use_chr_labels: True 34 | 35 | cell_type_epochs: 0 36 | metric: True 37 | 38 | valid: 39 | freq: 2 40 | 41 | data: 42 | bin_num: &bn 2 43 | append_cls: True 44 | train: 45 | atac_path: /home/jwu418/workspace/data/ours/train/mini_atlas_atac_paired.h5ad 46 | atac_key: X 47 | rna_path: /home/jwu418/workspace/data/ours/train/mini_atlas_rna_binned_binning_2_reduced.h5ad 48 | rna_key: X 49 | test: 50 | atac_path: /home/jwu418/workspace/data/ours/valid/mini_atlas_atac_paired.h5ad 51 | atac_key: X 52 | rna_path: /home/jwu418/workspace/data/ours/valid/mini_atlas_rna_binned_binning_2_reduced.h5ad 53 | rna_key: X 54 | 55 | vocab: 56 | rna_path: /home/jwu418/workspace/data/ours/vocab/rna_vocab.json 57 | atac_path: /home/jwu418/workspace/data/ours/vocab/atac_vocab.json 58 | cell_type_path: /home/jwu418/workspace/data/ours/vocab/mini_atlas_cell_vocab.json 59 | batch_path: /home/jwu418/workspace/data/ours/vocab/mini_atlas_batch_vocab.json 60 | chr_path: /home/jwu418/workspace/data/ours/vocab/chr_vocab.json 61 | gene2chr_path: /home/jwu418/workspace/data/ours/vocab/gene2chr.json 62 | special_tokens: 63 | pad: {token: , value: 2} 64 | # value of the mask is 1 plus bin_num 65 | mask: {token: , value: 3} 66 | cls: {token: , value: 0} 67 | 68 | -------------------------------------------------------------------------------- /model/flashDiff.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | # from .kernel.rotary import apply_rotary_emb 7 | from flash_attn import flash_attn_func 8 | 9 | def init_method(tensor, **kwargs): 10 | nn.init.kaiming_uniform_(tensor, a=math.sqrt(5)) 11 | 12 | def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: 13 | """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" 14 | bs, n_kv_heads, slen, head_dim = x.shape 15 | if n_rep == 1: 16 | return x 17 | return ( 18 | x[:, :, None, :, :] 19 | .expand(bs, n_kv_heads, n_rep, slen, head_dim) 20 | 
.reshape(bs, n_kv_heads * n_rep, slen, head_dim) 21 | ) 22 | 23 | def lambda_init_fn(depth): 24 | return 0.8 - 0.6 * math.exp(-0.3 * depth) 25 | 26 | class RMSNorm(nn.Module): 27 | def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False): 28 | super().__init__() 29 | self.dim = dim 30 | self.eps = eps 31 | self.elementwise_affine = elementwise_affine 32 | if self.elementwise_affine: 33 | self.weight = nn.Parameter(torch.ones(dim)) 34 | else: 35 | self.register_parameter('weight', None) 36 | 37 | def _norm(self, x): 38 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 39 | 40 | def forward(self, x): 41 | output = self._norm(x.float()).type_as(x) 42 | if self.weight is not None: 43 | output = output * self.weight 44 | return output 45 | 46 | def extra_repr(self) -> str: 47 | return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}' 48 | 49 | 50 | class MultiheadFlashDiff(nn.Module): 51 | """ 52 | (Recommended) 53 | DiffAttn implemented with FlashAttention, for packages that support different qk/v dimensions 54 | e.g., our customized-flash-attention (https://aka.ms/flash-diff) and xformers (https://github.com/facebookresearch/xformers) 55 | """ 56 | def __init__( 57 | self, 58 | embed_dim, 59 | num_heads, 60 | attention_dropout=0.0, 61 | lambda_init=0.8, 62 | ): 63 | super().__init__() 64 | self.embed_dim = embed_dim 65 | # num_heads set to half of Transformer's #heads 66 | # self.num_heads = num_heads // args.model_parallel_size 67 | self.num_heads = num_heads 68 | self.num_kv_heads = num_heads 69 | # self.num_kv_heads = args.decoder_kv_attention_heads // args.model_parallel_size if args.decoder_kv_attention_heads is not None else num_heads // args.model_parallel_size 70 | self.n_rep = self.num_heads // self.num_kv_heads 71 | self.dropout = attention_dropout 72 | 73 | self.head_dim = embed_dim // num_heads // 2 74 | self.scaling = self.head_dim ** -0.5 75 | 76 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) 77 | self.k_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False) 78 | self.v_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False) 79 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False) 80 | 81 | self.lambda_init = lambda_init 82 | self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) 83 | self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) 84 | self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) 85 | self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) 86 | 87 | self.subln = RMSNorm(2 * self.head_dim, eps=1e-5, elementwise_affine=True) 88 | 89 | def forward( 90 | self, 91 | x, 92 | ): 93 | bsz, tgt_len, embed_dim = x.size() 94 | src_len = tgt_len 95 | 96 | q = self.q_proj(x) 97 | k = self.k_proj(x) 98 | v = self.v_proj(x) 99 | 100 | q = q.view(bsz, tgt_len, 2 * self.num_heads, self.head_dim) 101 | k = k.view(bsz, src_len, 2 * self.num_kv_heads, self.head_dim) 102 | v = v.view(bsz, src_len, self.num_kv_heads, 2, self.head_dim) 103 | 104 | 105 | offset = src_len - tgt_len 106 | q = q.reshape(bsz, tgt_len, self.num_heads, 2, self.head_dim) 107 | k = k.reshape(bsz, src_len, self.num_kv_heads, 2, self.head_dim) 108 | q1, q2 = q[:, :, :, 0], q[:, :, :, 1] 109 | k1, k2 = k[:, :, :, 0], k[:, :, :, 1] 110 | v1, v2 = v[:, :, :, 0], v[:, :, :, 1] 111 | 112 | attn11 = 
flash_attn_func(q1, k1, v1, dropout_p=self.dropout) 113 | attn12 = flash_attn_func(q1, k1, v2, dropout_p=self.dropout) 114 | attn1 = torch.cat([attn11, attn12], dim=-1) 115 | 116 | attn21 = flash_attn_func(q2, k2, v1, dropout_p=self.dropout) 117 | attn22 = flash_attn_func(q2, k2, v2, dropout_p=self.dropout) 118 | attn2 = torch.cat([attn21, attn22], dim=-1) 119 | 120 | lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q) 121 | lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q) 122 | lambda_full = lambda_1 - lambda_2 + self.lambda_init 123 | attn = attn1 - lambda_full * attn2 124 | 125 | attn = self.subln(attn) 126 | attn = attn * (1 - self.lambda_init) 127 | attn = attn.reshape(bsz, tgt_len, self.num_heads * 2 * self.head_dim) 128 | 129 | attn = self.out_proj(attn) 130 | return attn 131 | -------------------------------------------------------------------------------- /model/flashMHA.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from einops import rearrange 7 | 8 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func 9 | from flash_attn.bert_padding import unpad_input, pad_input 10 | 11 | 12 | class FlashAttention(nn.Module): 13 | """Implement the scaled dot product attention with softmax. 14 | Arguments 15 | --------- 16 | softmax_scale: The temperature to use for the softmax attention. 17 | (default: 1/sqrt(d_keys) where d_keys is computed at 18 | runtime) 19 | attention_dropout: The dropout rate to apply to the attention 20 | (default: 0.0) 21 | """ 22 | def __init__(self, softmax_scale=None, attention_dropout=0.0): 23 | super().__init__() 24 | self.softmax_scale = softmax_scale 25 | self.dropout_p = attention_dropout 26 | 27 | def get_attention_weights(self, qkv): 28 | batch_size, seqlen, _, nheads, head_dim = qkv.shape 29 | q = qkv[:, :, 0, :, :] 30 | k = qkv[:, :, 1, :, :] 31 | 32 | q = rearrange(q, 'b s h d -> b h s d') 33 | k = rearrange(k, 'b s h d -> b h s d') 34 | 35 | scale = head_dim ** -0.5 36 | attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale 37 | attn_weights = F.softmax(attn_weights, dim=-1) 38 | return attn_weights 39 | 40 | def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None, 41 | max_s=None, need_weights=False): 42 | """Implements the multihead softmax attention. 43 | Arguments 44 | --------- 45 | qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None 46 | if unpadded: (nnz, 3, h, d) 47 | key_padding_mask: a bool tensor of shape (B, S) 48 | """ 49 | # assert not need_weights 50 | assert qkv.dtype in [torch.float16, torch.bfloat16] 51 | assert qkv.is_cuda 52 | if need_weights: 53 | attn_weights = self.get_attention_weights(qkv) 54 | # get the attention weights for the first token 55 | attn_weights = attn_weights[:, :, 0, :] 56 | else: 57 | attn_weights = None 58 | 59 | if cu_seqlens is None: 60 | batch_size = qkv.shape[0] 61 | seqlen = qkv.shape[1] 62 | if key_padding_mask is None: 63 | qkv = rearrange(qkv, 'b s ... 
-> (b s) ...') 64 | max_s = seqlen 65 | cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, 66 | device=qkv.device) 67 | output = flash_attn_varlen_qkvpacked_func( 68 | qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, 69 | softmax_scale=self.softmax_scale, causal=causal 70 | ) 71 | output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) 72 | else: 73 | nheads = qkv.shape[-2] 74 | x = rearrange(qkv, 'b s three h d -> b s (three h d)') 75 | x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask) 76 | x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) 77 | output_unpad = flash_attn_varlen_qkvpacked_func( 78 | x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, 79 | softmax_scale=self.softmax_scale, causal=causal 80 | ) 81 | output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), 82 | indices, batch_size, seqlen), 83 | 'b s (h d) -> b s h d', h=nheads) 84 | else: 85 | assert max_s is not None 86 | output = flash_attn_varlen_qkvpacked_func( 87 | qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, 88 | softmax_scale=self.softmax_scale, causal=causal 89 | ) 90 | return output, attn_weights 91 | 92 | 93 | class FlashMHA(nn.Module): 94 | 95 | def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0, 96 | causal=False, device=None, dtype=None) -> None: 97 | assert batch_first 98 | factory_kwargs = {'device': device, 'dtype': dtype} 99 | super().__init__() 100 | self.embed_dim = embed_dim 101 | self.causal = causal 102 | 103 | self.num_heads = num_heads 104 | assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads" 105 | self.head_dim = self.embed_dim // num_heads 106 | assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8" 107 | 108 | self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) 109 | self.inner_attn = FlashAttention(attention_dropout=attention_dropout) 110 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) 111 | 112 | def forward(self, x, key_padding_mask=None, need_weights=False): 113 | """x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) 114 | key_padding_mask: bool tensor of shape (batch, seqlen) 115 | """ 116 | qkv = self.Wqkv(x) 117 | qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads) 118 | context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask, 119 | need_weights=need_weights, causal=self.causal) 120 | return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EpiFoundation: A Foundation Model for Single-Cell ATAC-seq via Peak-to-Gene Alignment 2 | 3 | This repo contains official Pytorch implementation of **EpiFoundation** in our paper: [EpiFoundation: A Foundation Model for Single-Cell ATAC-seq via Peak-to-Gene Alignment]() 4 | 5 | ![image-20250204132627694](./assets/framework.png) 6 | 7 | ## Introduction 8 | 9 | Foundation models exhibit strong capabilities for downstream tasks by learning generalized representations through self-supervised pre-training on large datasets. 
While several foundation models have been developed for single-cell RNA-seq (scRNA-seq) data, there is still a lack of models specifically tailored for single-cell ATAC-seq (scATAC-seq), which measures epigenetic information in individual cells. The principal challenge in developing such a model lies in the vast number of scATAC peaks and the significant sparsity of the data, which complicates the formulation of peak-to-peak correlations. To address this challenge, we introduce **EpiFoundation**, a foundation model for learning cell representations from the high-dimensional and sparse space of peaks. EpiFoundation relies on an innovative cross-modality pre-training procedure with two key technical innovations. First, EpiFoundation exclusively processes the non-zero peak set, thereby enhancing the density of cell-specific information within the input data. Second, EpiFoundation utilizes dense gene expression information to supervise the pre-training process, aligning peak-to-gene correlations. EpiFoundation can handle various types of downstream tasks, including cell-type annotation, batch correction, and gene expression prediction. To train and validate EpiFoundation, we curated **MiniAtlas**, a dataset of 100,000+ single cells with paired scRNA-seq and scATAC-seq data, along with diverse test sets spanning various tissues and cell types for robust evaluation. EpiFoundation demonstrates **state-of-the-art performance across multiple tissues and diverse downstream tasks**. 10 | 11 | ![image-20250204133126719](./assets/batch_correlation.png) 12 | 13 | 14 | 15 | ## Data and Model Access 16 | 17 | - Our paired scRNA-seq and scATAC-seq data for pre-training and fine-tuning can be found at [MiniAtlas](https://huggingface.co/datasets/UCSC-VLAA/MiniAtlas). 18 | - The pre-trained model weights can be found at [EpiFoundation](https://huggingface.co/UCSC-VLAA/EpiFoundation). 19 | 20 | ## Quick Start 21 | 22 | 0. Prepare the conda environment from `env.yml` 23 | 24 | 1. Preprocess your data using `prepare_data.py` 25 | 26 | - call `preprocess()` to process the scRNA-seq data, divide it into train, valid, and test sets, and perform binning. 27 | - call `get_pair_data()` to get the paired scATAC-seq data 28 | - generate the vocabularies with `generate_vocab()` and `generate_cell_type_vocab()` 29 | - Optional: call `reduce_data()` to convert the data to sparse format and reduce its size. 30 | 31 | 32 | 2. 
Prepare batch data: 33 | 34 | get the cell and gene vocab: 35 | 36 | ```python 37 | vocab = GeneVocab.from_file(vocab_config['path']) 38 | cell_vocab = GeneVocab.from_file(vocab_config['cell_type_path']) 39 | ``` 40 | 41 | Init the paired dataset: 42 | 43 | ```python 44 | train_set = PairedSCDataset( 45 | rna_file = data_config['train']['rna_path'], 46 | atac_file= data_config['train']['atac_path'], 47 | rna_key = data_config['train']['rna_key'], 48 | atac_key = data_config['train']['atac_key'], 49 | rna_vocab = rna_vocab, 50 | atac_vocab = atac_vocab, 51 | cell_vocab = cell_vocab, 52 | batch_vocab= batch_vocab, 53 | chr_vocab = chr_vocab, 54 | gene2chr_file= vocab_config['gene2chr_path'], 55 | rna_max_len = train_config['model']['rna_max_len'], 56 | atac_max_len = train_config['model']['atac_max_len'], 57 | pad_token = pad['token'], 58 | rna_pad_value = pad['value'], 59 | cls_token = cls['token'], 60 | # reg_token= reg['token'], 61 | logger = logger, 62 | ) 63 | ``` 64 | 65 | Get the data loader 66 | 67 | ```python 68 | train_sampler = DistributedSampler(train_set) 69 | train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, sampler=train_sampler, num_workers=4) 70 | ``` 71 | 72 | Detailed code can be find in `pretrain_ddp.py`, `pretrain_fsdp.py` and `data/dataloader.py` 73 | 74 | 3. Construct the model using config from ./configs/test/rna_transformer.yml 75 | 76 | ```python 77 | model = EpiFoundation( 78 | num_class_cell = len(cell_vocab), 79 | num_rnas = len(rna_vocab), 80 | num_atacs = len(atac_vocab), 81 | num_values= data_config['bin_num'], 82 | num_chrs= len(chr_vocab), 83 | embed_dim = train_config['model']['embedding_dim'], 84 | depth = train_config['model']['num_layers'], 85 | heads = train_config['model']['head_num'], 86 | head_dim = train_config['model']['head_dim'], 87 | encoder = model_name, 88 | dropout = train_config['model']['dropout'], 89 | pad_token_idx_rna = rna_vocab[pad['token']], 90 | pad_token_idx_atac = atac_vocab[pad['token']], 91 | cell_emb_style = train_config['model']['cell_emb_style'], 92 | mvc_arch_style = train_config['model']['mvc_arch_style'], 93 | use_batch_labels = train_config['model']['use_batch_labels'], 94 | batch_label_num= len(batch_vocab), 95 | use_chr_labels= train_config['model']['use_chr_labels'], 96 | ).to(device) 97 | ``` 98 | 99 | 100 | ## Acknowledgment 101 | 102 | We would like to thank the TPU Research Cloud (TRC) program and the Google Cloud Research Credits program for Research program for supporting our computing needs. W.H. and Z.J. are supported by the National Institute Of General Medical Sciences of the National Institutes of Health (NIH), under Award Number R35GM150887 and R35GM154865 respectively. 103 | 104 | ## Citation 105 | 106 | ``` 107 | @article {Wu2025.02.05.636688, 108 | author = {Wu, Juncheng and Wan, Changxin and Ji, Zhicheng and Zhou, Yuyin and Hou, Wenpin}, 109 | title = {EpiFoundation: A Foundation Model for Single-Cell ATAC-seq via Peak-to-Gene Alignment}, 110 | elocation-id = {2025.02.05.636688}, 111 | year = {2025}, 112 | doi = {10.1101/2025.02.05.636688}, 113 | URL = {https://www.biorxiv.org/content/early/2025/02/08/2025.02.05.636688}, 114 | eprint = {https://www.biorxiv.org/content/early/2025/02/08/2025.02.05.636688.full.pdf}, 115 | journal = {bioRxiv} 116 | } 117 | ``` 118 | 119 | ## Contact 120 | 121 | If you have any questions, please feel free to raise an issue or contact us directly: Juncheng Wu (jwu418@ucsc.edu), Changxin Wan (changxin.wan@duke.edu). 
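## Forward Pass Sketch

As a companion to steps 2 and 3 of the Quick Start, the sketch below shows how a batch from the paired data loader is fed to the model during training. It follows the tensor names and output keys used in `finetune.py` (`value_pred` for predicted expression values, `cell_pred` for cell-type logits), and assumes `model`, `train_loader`, `atac_vocab`, the `pad` special-token config, and `device` have already been set up as in the snippets above.

```python
# Minimal forward-pass sketch following finetune.py; assumes model, train_loader,
# atac_vocab, pad, and device are prepared as in the Quick Start snippets.
batch = next(iter(train_loader))
atac_ids = batch['atac_ids'].to(device)
rna_ids = batch['rna_ids'].to(device)
padding_positions = atac_ids.eq(atac_vocab[pad['token']])  # mask out padded peaks

output = model(
    atac = atac_ids,
    rna = rna_ids,
    src_key_padding_mask = padding_positions,
    batch_id = batch['batch_ids'].to(device),
    rna_chrs = batch['rna_chrs'].to(device),
    atac_chrs = batch['atac_chrs'].to(device),
)
cell_logits = output['cell_pred']    # cell-type prediction logits
value_pred = output['value_pred']    # predicted (binned) gene expression values
```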
122 | -------------------------------------------------------------------------------- /model/reversible.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from operator import itemgetter 4 | from torch.autograd.function import Function 5 | from torch.utils.checkpoint import get_device_states, set_device_states 6 | 7 | # for routing arguments into the functions of the reversible layer 8 | def route_args(router, args, depth): 9 | routed_args = [(dict(), dict()) for _ in range(depth)] 10 | matched_keys = [key for key in args.keys() if key in router] 11 | 12 | for key in matched_keys: 13 | val = args[key] 14 | for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])): 15 | new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes) 16 | routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) 17 | return routed_args 18 | 19 | # following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html 20 | class Deterministic(nn.Module): 21 | def __init__(self, net): 22 | super().__init__() 23 | self.net = net 24 | self.cpu_state = None 25 | self.cuda_in_fwd = None 26 | self.gpu_devices = None 27 | self.gpu_states = None 28 | 29 | def record_rng(self, *args): 30 | self.cpu_state = torch.get_rng_state() 31 | if torch.cuda._initialized: 32 | self.cuda_in_fwd = True 33 | self.gpu_devices, self.gpu_states = get_device_states(*args) 34 | 35 | def forward(self, *args, record_rng = False, set_rng = False, **kwargs): 36 | if record_rng: 37 | self.record_rng(*args) 38 | 39 | if not set_rng: 40 | return self.net(*args, **kwargs) 41 | 42 | rng_devices = [] 43 | if self.cuda_in_fwd: 44 | rng_devices = self.gpu_devices 45 | 46 | with torch.random.fork_rng(devices=rng_devices, enabled=True): 47 | torch.set_rng_state(self.cpu_state) 48 | if self.cuda_in_fwd: 49 | set_device_states(self.gpu_devices, self.gpu_states) 50 | return self.net(*args, **kwargs) 51 | 52 | # heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py 53 | # once multi-GPU is confirmed working, refactor and send PR back to source 54 | class ReversibleBlock(nn.Module): 55 | def __init__(self, f, g): 56 | super().__init__() 57 | self.f = Deterministic(f) 58 | self.g = Deterministic(g) 59 | 60 | def forward(self, x, f_args = {}, g_args = {}): 61 | x1, x2 = torch.chunk(x, 2, dim=2) 62 | y1, y2 = None, None 63 | 64 | with torch.no_grad(): 65 | y1 = x1 + self.f(x2, record_rng=self.training, **f_args) 66 | y2 = x2 + self.g(y1, record_rng=self.training, **g_args) 67 | 68 | return torch.cat([y1, y2], dim=2) 69 | 70 | def backward_pass(self, y, dy, f_args = {}, g_args = {}): 71 | y1, y2 = torch.chunk(y, 2, dim=2) 72 | del y 73 | 74 | dy1, dy2 = torch.chunk(dy, 2, dim=2) 75 | del dy 76 | 77 | with torch.enable_grad(): 78 | y1.requires_grad = True 79 | gy1 = self.g(y1, set_rng=True, **g_args) 80 | torch.autograd.backward(gy1, dy2) 81 | 82 | with torch.no_grad(): 83 | x2 = y2 - gy1 84 | del y2, gy1 85 | 86 | dx1 = dy1 + y1.grad 87 | del dy1 88 | y1.grad = None 89 | 90 | with torch.enable_grad(): 91 | x2.requires_grad = True 92 | fx2 = self.f(x2, set_rng=True, **f_args) 93 | torch.autograd.backward(fx2, dx1, retain_graph=True) 94 | 95 | with torch.no_grad(): 96 | x1 = y1 - fx2 97 | del y1, fx2 98 | 99 | dx2 = dy2 + x2.grad 100 | del dy2 101 | x2.grad = None 102 | 103 | x = torch.cat([x1, x2.detach()], dim=2) 104 | dx = 
torch.cat([dx1, dx2], dim=2) 105 | 106 | return x, dx 107 | 108 | class _ReversibleFunction(Function): 109 | @staticmethod 110 | def forward(ctx, x, blocks, args): 111 | ctx.args = args 112 | for block, kwarg in zip(blocks, args): 113 | x = block(x, **kwarg) 114 | ctx.y = x.detach() 115 | ctx.blocks = blocks 116 | return x 117 | 118 | @staticmethod 119 | def backward(ctx, dy): 120 | y = ctx.y 121 | args = ctx.args 122 | for block, kwargs in zip(ctx.blocks[::-1], args[::-1]): 123 | y, dy = block.backward_pass(y, dy, **kwargs) 124 | return dy, None, None 125 | 126 | class SequentialSequence(nn.Module): 127 | def __init__(self, layers, args_route = {}): 128 | super().__init__() 129 | assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers' 130 | self.layers = layers 131 | self.args_route = args_route 132 | 133 | def forward(self, x, output_attentions = False, **kwargs): 134 | args = route_args(self.args_route, kwargs, len(self.layers)) 135 | layers_and_args = list(zip(self.layers, args)) 136 | 137 | if output_attentions: 138 | attn_weights = [] 139 | for (f, g), (f_args, g_args) in layers_and_args: 140 | if output_attentions: 141 | x = x + f(x, output_attentions = output_attentions, **f_args)[0] 142 | attn_weights.append(f(x, output_attentions = output_attentions, **f_args)[1].unsqueeze(0)) 143 | else: 144 | x = x + f(x, **f_args) 145 | x = x + g(x, **g_args) 146 | if output_attentions: 147 | attn_weights = torch.transpose(torch.cat(attn_weights, dim=0), 0, 1) # the final dim is (batch, layer, head, len, len) 148 | attn_weights = torch.mean(attn_weights, dim=1) # the dim is (batch, head, len, len) 149 | return x, attn_weights 150 | else: 151 | return x 152 | 153 | class ReversibleSequence(nn.Module): 154 | def __init__(self, blocks, args_route = {}): 155 | super().__init__() 156 | self.args_route = args_route 157 | self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks]) 158 | 159 | def forward(self, x, **kwargs): 160 | x = torch.cat([x, x], dim=-1) 161 | 162 | blocks = self.blocks 163 | args = route_args(self.args_route, kwargs, len(blocks)) 164 | args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args)) 165 | 166 | out = _ReversibleFunction.apply(x, blocks, args) 167 | return torch.stack(out.chunk(2, dim=-1)).sum(dim=0) 168 | -------------------------------------------------------------------------------- /tokenizer/gene_tokenizer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | from pathlib import Path 4 | from collections import Counter, OrderedDict 5 | from typing import Dict, Iterable, List, Optional, Tuple, Union 6 | from typing_extensions import Self 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | import torchtext.vocab as torch_vocab 12 | from torchtext.vocab import Vocab 13 | 14 | class GeneVocab(Vocab): 15 | """ 16 | Vocabulary for genes. 17 | """ 18 | 19 | def __init__( 20 | self, 21 | gene_list_or_vocab: Union[List[str], Vocab], 22 | specials: Optional[List[str]] = None, 23 | special_first: bool = True, 24 | default_token: Optional[str] = "", 25 | ) -> None: 26 | """ 27 | Initialize the vocabulary. 28 | Note: add specials only works when init from a gene list. 29 | 30 | Args: 31 | gene_list_or_vocab (List[str] or Vocab): List of gene names or a 32 | Vocab object. 33 | specials (List[str]): List of special tokens. 
34 | special_first (bool): Whether to add special tokens to the beginning 35 | of the vocabulary. 36 | default_token (str): Default token, by default will set to "", 37 | if "" is in the vocabulary. 38 | """ 39 | if isinstance(gene_list_or_vocab, Vocab): 40 | _vocab = gene_list_or_vocab 41 | if specials is not None: 42 | raise ValueError( 43 | "receive non-empty specials when init from a Vocab object." 44 | ) 45 | elif isinstance(gene_list_or_vocab, list): 46 | _vocab = self._build_vocab_from_iterator( 47 | gene_list_or_vocab, 48 | specials=specials, 49 | special_first=special_first, 50 | ) 51 | else: 52 | raise ValueError( 53 | "gene_list_or_vocab must be a list of gene names or a Vocab object." 54 | ) 55 | super().__init__(_vocab.vocab) 56 | if default_token is not None and default_token in self: 57 | self.set_default_token(default_token) 58 | 59 | @classmethod 60 | def from_file(cls, file_path: Union[Path, str]) -> Self: 61 | """ 62 | Load the vocabulary from a file. The file should be either a pickle or a 63 | json file of token to index mapping. 64 | """ 65 | if isinstance(file_path, str): 66 | file_path = Path(file_path) 67 | if file_path.suffix == ".pkl": 68 | with file_path.open("rb") as f: 69 | vocab = pickle.load(f) 70 | return cls(vocab) 71 | elif file_path.suffix == ".json": 72 | with file_path.open("r") as f: 73 | token2idx = json.load(f) 74 | return cls.from_dict(token2idx) 75 | else: 76 | raise ValueError( 77 | f"{file_path} is not a valid file type. " 78 | "Only .pkl and .json are supported." 79 | ) 80 | 81 | @classmethod 82 | def from_dict( 83 | cls, 84 | token2idx: Dict[str, int], 85 | default_token: Optional[str] = "", 86 | ) -> Self: 87 | """ 88 | Load the vocabulary from a dictionary. 89 | 90 | Args: 91 | token2idx (Dict[str, int]): Dictionary mapping tokens to indices. 92 | """ 93 | # initiate an empty vocabulary first 94 | _vocab = cls([]) 95 | 96 | # add the tokens to the vocabulary, GeneVocab requires consecutive indices 97 | for t, i in sorted(token2idx.items(), key=lambda x: x[1]): 98 | _vocab.insert_token(t, i) 99 | 100 | if default_token is not None and default_token in _vocab: 101 | _vocab.set_default_token(default_token) 102 | 103 | return _vocab 104 | 105 | def _build_vocab_from_iterator( 106 | self, 107 | iterator: Iterable, 108 | min_freq: int = 1, 109 | specials: Optional[List[str]] = None, 110 | special_first: bool = True, 111 | ) -> Vocab: 112 | """ 113 | Build a Vocab from an iterator. This function is modified from 114 | torchtext.vocab.build_vocab_from_iterator. The original function always 115 | splits tokens into characters, which is not what we want. 116 | 117 | Args: 118 | iterator (Iterable): Iterator used to build Vocab. Must yield list 119 | or iterator of tokens. 120 | min_freq (int): The minimum frequency needed to include a token in 121 | the vocabulary. 122 | specials (List[str]): Special symbols to add. The order of supplied 123 | tokens will be preserved. 
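                Special tokens are removed from the frequency counter and
                re-inserted with frequency ``min_freq``, so they are always kept
                regardless of how often they appear in ``iterator``.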
124 | special_first (bool): Whether to add special tokens to the beginning 125 | 126 | Returns: 127 | torchtext.vocab.Vocab: A `Vocab` object 128 | """ 129 | 130 | counter = Counter() 131 | counter.update(iterator) 132 | 133 | if specials is not None: 134 | for tok in specials: 135 | del counter[tok] 136 | 137 | sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[0]) 138 | sorted_by_freq_tuples.sort(key=lambda x: x[1], reverse=True) 139 | ordered_dict = OrderedDict(sorted_by_freq_tuples) 140 | 141 | if specials is not None: 142 | if special_first: 143 | specials = specials[::-1] 144 | for symbol in specials: 145 | ordered_dict.update({symbol: min_freq}) 146 | ordered_dict.move_to_end(symbol, last=not special_first) 147 | 148 | word_vocab = torch_vocab.vocab(ordered_dict, min_freq=min_freq) 149 | return word_vocab 150 | 151 | @property 152 | def pad_token(self) -> Optional[str]: 153 | """ 154 | Get the pad token. 155 | """ 156 | if getattr(self, "_pad_token", None) is None: 157 | self._pad_token = None 158 | return self._pad_token 159 | 160 | @pad_token.setter 161 | def pad_token(self, pad_token: str) -> None: 162 | """ 163 | Set the pad token. Will not add the pad token to the vocabulary. 164 | 165 | Args: 166 | pad_token (str): Pad token, should be in the vocabulary. 167 | """ 168 | if pad_token not in self: 169 | raise ValueError(f"{pad_token} is not in the vocabulary.") 170 | self._pad_token = pad_token 171 | 172 | def save_json(self, file_path: Union[Path, str]) -> None: 173 | """ 174 | Save the vocabulary to a json file. 175 | """ 176 | if isinstance(file_path, str): 177 | file_path = Path(file_path) 178 | with file_path.open("w") as f: 179 | json.dump(self.get_stoi(), f, indent=2) 180 | 181 | def set_default_token(self, default_token: str) -> None: 182 | """ 183 | Set the default token. 184 | 185 | Args: 186 | default_token (str): Default token. 187 | """ 188 | if default_token not in self: 189 | raise ValueError(f"{default_token} is not in the vocabulary.") 190 | self.set_default_index(self[default_token]) -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | import scanpy as sc 2 | from scanpy import AnnData 3 | import scipy 4 | import os 5 | from model import logger 6 | import numpy as np 7 | from data.preprocess import Preprocessor 8 | from tokenizer import GeneVocab 9 | from data.dataloader import * 10 | import yaml 11 | from tqdm import tqdm 12 | 13 | 14 | def divide_data(data: AnnData, radio: dict = {'train': 0.9, 'test': 0.05, 'valid': 0.05}): 15 | ''' 16 | Divide the data into train, test, valid. 17 | return 3 AnnData object. 
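    Cells are shuffled before splitting; the split fractions are taken from
    ``radio`` (train / test / valid) and must sum to 1.0.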
18 | ''' 19 | assert sum(radio.values()) == 1.0, 'The sum of radio should be 1.0' 20 | data = data.copy() 21 | # get the number of cells 22 | n_cells = data.shape[0] 23 | # get the index of cells 24 | idx = np.arange(n_cells) 25 | np.random.shuffle(idx) 26 | # get the number of cells for each part 27 | n_train = int(n_cells * radio['train']) 28 | n_test = int(n_cells * radio['test']) 29 | n_valid = int(n_cells * radio['valid']) 30 | # divide the data 31 | train_data = data[idx[:n_train]] 32 | test_data = data[idx[n_train:n_train+n_test]] 33 | valid_data = data[idx[n_train+n_test:]] 34 | 35 | return train_data, test_data, valid_data 36 | 37 | def preprocess(): 38 | 39 | preprocess_config = { 40 | 'path': '/home/jwu418/workspace/data/ours/', 41 | 'raw_data': 'pbmc_rna_s1.h5ad', 42 | 'use_key': 'X', 43 | 'filter_gene_by_counts': False, 44 | 'filter_cell_by_counts': False, 45 | 'normalize_total': False, 46 | 'result_normed_key': 'X_normed', 47 | 'log1p': False, 48 | 'result_log1p_key': 'X_log1p', 49 | 'subset_hvg': False, 50 | 'hvg_use_key': None, 51 | 'hvg_flavor': 'seurat_v3', 52 | 'binning': [2], 53 | 'result_binned_key': 'X_binned', 54 | 'batch_key': 'batch', 55 | 'output_name': 'pbmc_rna_s1', 56 | } 57 | file = '{}/raw/{}'.format(preprocess_config['path'], preprocess_config['raw_data']) 58 | adata = sc.read_h5ad(file) 59 | # devide data into train, test, valid. with 0.8,0.1,0.1 60 | # adata._raw._var.rename(columns={'_index': 'genes'}, inplace=True) 61 | print(adata) 62 | 63 | train_data, test_data, valid_data = divide_data(adata) 64 | for binning in preprocess_config['binning']: 65 | logger.info('Binning: {}'.format(binning)) 66 | processor = Preprocessor(use_key=preprocess_config['use_key'], 67 | filter_gene_by_counts=preprocess_config['filter_gene_by_counts'], 68 | filter_cell_by_counts=preprocess_config['filter_cell_by_counts'], 69 | normalize_total=preprocess_config['normalize_total'], 70 | result_normed_key=preprocess_config['result_normed_key'], 71 | log1p=preprocess_config['log1p'], 72 | result_log1p_key=preprocess_config['result_log1p_key'], 73 | subset_hvg=preprocess_config['subset_hvg'], 74 | hvg_use_key=preprocess_config['hvg_use_key'], 75 | hvg_flavor=preprocess_config['hvg_flavor'], 76 | binning=binning, 77 | result_binned_key=preprocess_config['result_binned_key']) 78 | 79 | 80 | 81 | output_name = f'{preprocess_config["output_name"]}_binning_{binning}' 82 | 83 | logger.info('Preprocessing Train Data') 84 | processor(train_data, batch_key= preprocess_config['batch_key']) 85 | print(train_data) 86 | train_data.write('{}/train/{}.h5ad'.format(preprocess_config['path'], output_name)) 87 | 88 | logger.info('Preprocessing test Data') 89 | processor(test_data, batch_key= preprocess_config['batch_key']) 90 | print(test_data) 91 | test_data.write('{}/test/{}.h5ad'.format(preprocess_config['path'], output_name)) 92 | 93 | logger.info('Preprocessing valid Data') 94 | processor(valid_data, batch_key= preprocess_config['batch_key']) 95 | print(valid_data) 96 | valid_data.write('{}/valid/{}.h5ad'.format(preprocess_config['path'], output_name)) 97 | 98 | 99 | # save preprocess config as a yml file 100 | with open('/home/jwu418/workspace/data/ours/configs/{}.yml'.format(output_name), 'w') as file: 101 | yaml.dump(preprocess_config, file) 102 | 103 | 104 | def reduce_data(): 105 | path = '/home/jwu418/workspace/data/ours' 106 | rna_file = 'pbmc_rna_s1_binning_2.h5ad' 107 | 108 | stage = ['test', 'valid', 'train'] 109 | for s in stage: 110 | adata = sc.read_h5ad('{}/{}/{}'.format(path, s, 
rna_file)) 111 | print('Before:', adata) 112 | # remove the adata.raw 113 | adata.raw = None 114 | # save the X_binned as X 115 | adata.X = adata.layers['X_binned'] 116 | # save adata.X as sparse matrix 117 | adata.X = scipy.sparse.csr_matrix(adata.X) 118 | # remove the X_binned layer 119 | adata.layers.pop('X_binned') 120 | # save the data as a new file 121 | adata.write('{}/{}/{}'.format(path, s, 'pbmc_rna_s1_binning_2_reduced.h5ad')) 122 | 123 | def get_pair_data(): 124 | path = '/home/jwu418/workspace/data/ours' 125 | rna_file = 'pbmc_rna_s1_binning_2.h5ad' 126 | atac_file = 'raw/pbmc_atac_s1.h5ad' 127 | 128 | output_name = 'pbmc_rna_s1_atac_paired.h5ad' 129 | 130 | stage = ['test', 'valid', 'train'] 131 | 132 | atac = sc.read_h5ad('{}/{}'.format(path, atac_file)) 133 | # breakpoint() 134 | for s in stage: 135 | rna = sc.read_h5ad('{}/{}/{}'.format(path, s, rna_file), backed='r') 136 | print('rna:', rna) 137 | # get the cell name of rna data 138 | 139 | rna_cell_name = rna.obs_names.tolist() 140 | # find the corresponding cell in atac data 141 | atac_cell = atac[rna_cell_name] 142 | # atac_cell._raw._var.rename(columns={'_index': 'peaks'}, inplace=True) 143 | # save the atac data as a new file 144 | atac_cell.write('{}/{}/{}'.format(path, s, output_name)) 145 | print('atac cell:', atac_cell) 146 | 147 | def generate_chr_vocab(): 148 | file = '/home/jwu418/workspace/data/ours/meta/genes.csv' 149 | import pandas as pd 150 | # read the 'seqnames' column 151 | df = pd.read_csv(file) 152 | chr_names = df['seqnames'].tolist() 153 | chr_names = list(set(chr_names)) 154 | vocab = GeneVocab(gene_list_or_vocab=chr_names, 155 | special_first= True, 156 | specials=['', '','', '']) 157 | vocab.save_json('/home/jwu418/workspace/data/ours/chr_vocab.json') 158 | 159 | 160 | def generate_vocab(): 161 | file = '/home/jwu418/workspace/data/ours/raw/mini_atlas_atac.h5ad' 162 | adata = sc.read_h5ad(file) 163 | # adata._raw._var.rename(columns={'_index': 'features'}, inplace=True) 164 | 165 | # get the gene names 166 | gene_names = adata.var_names.tolist() 167 | vocab = GeneVocab(gene_list_or_vocab=gene_names, 168 | special_first= True, 169 | specials=['', '','', '']) 170 | vocab.save_json('/home/jwu418/workspace/data/ours/atac_vocab.json') 171 | 172 | def generate_cell_type_vocab(): 173 | file = '/home/jwu418/workspace/data/ours/raw/pbmc_rna_s1.h5ad' 174 | adata = sc.read_h5ad(file) 175 | print(adata) 176 | # get the gene names 177 | gene_names = adata.obs['batch'].tolist() 178 | 179 | # remove duplicates 180 | gene_names = list(set(gene_names)) 181 | print(gene_names) 182 | # get number of cell types 183 | print('Number of cell types:', len(gene_names)) 184 | vocab = GeneVocab(gene_list_or_vocab=gene_names) 185 | vocab.save_json('/home/jwu418/workspace/data/ours/vocab/pbmc_s1_batch_vocab.json') 186 | 187 | 188 | 189 | # if __name__ == '__main__': 190 | # preprocess() 191 | # reduce_data() 192 | # get_pair_data() 193 | # generate_vocab() 194 | # generate_cell_type_vocab() 195 | 196 | ''' 197 | Proceesing the data: 198 | 1. preprocess() 199 | 2. get_pair_data() 200 | 3. generate_vocab() 201 | 4. generate_cell_type_vocab() 202 | 5. 
reduce_data() 203 | ''' -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: epiFoundation 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | - conda-forge 7 | - https://repo.anaconda.com/pkgs/main 8 | - https://repo.anaconda.com/pkgs/r 9 | dependencies: 10 | - _libgcc_mutex=0.1=main 11 | - _openmp_mutex=5.1=1_gnu 12 | - absl-py=2.1.0=pyhd8ed1ab_0 13 | - aom=3.6.0=h6a678d5_0 14 | - blas=1.0=mkl 15 | - bleach=4.1.0=pyhd3eb1b0_0 16 | - blosc=1.21.3=h6a678d5_0 17 | - bokeh=3.1.1=py38h2f386ee_0 18 | - bottleneck=1.3.7=py38ha9d4c09_0 19 | - brotli=1.0.9=h5eee18b_8 20 | - brotli-bin=1.0.9=h5eee18b_8 21 | - brotli-python=1.0.9=py38h6a678d5_8 22 | - brunsli=0.1=h2531618_0 23 | - bzip2=1.0.8=h5eee18b_6 24 | - c-ares=1.19.1=h5eee18b_0 25 | - ca-certificates=2024.9.24=h06a4308_0 26 | - certifi=2024.8.30=py38h06a4308_0 27 | - cfitsio=3.470=h5893167_7 28 | - charls=2.2.0=h2531618_0 29 | - charset-normalizer=3.3.2=pyhd3eb1b0_0 30 | - click=8.1.7=py38h06a4308_0 31 | - cloudpickle=3.0.0=py38h06a4308_0 32 | - colorama=0.4.6=pyhd8ed1ab_0 33 | - colorcet=3.1.0=py38h06a4308_0 34 | - cuda-cudart=12.1.105=0 35 | - cuda-cupti=12.1.105=0 36 | - cuda-libraries=12.1.0=0 37 | - cuda-nvrtc=12.1.105=0 38 | - cuda-nvtx=12.1.105=0 39 | - cuda-opencl=12.6.68=0 40 | - cuda-runtime=12.1.0=0 41 | - cuda-version=12.6=3 42 | - cycler=0.12.1=pyhd8ed1ab_0 43 | - cyrus-sasl=2.1.28=h52b45da_1 44 | - cytoolz=0.12.2=py38h5eee18b_0 45 | - dask-core=2023.4.1=py38h06a4308_0 46 | - datashader=0.15.2=py38h06a4308_0 47 | - datashape=0.5.4=py38h06a4308_1 48 | - dav1d=1.2.1=h5eee18b_0 49 | - dbus=1.13.18=hb2f20db_0 50 | - expat=2.6.3=h6a678d5_0 51 | - ffmpeg=4.3=hf484d3e_0 52 | - fontconfig=2.14.1=h55d465d_3 53 | - freetype=2.12.1=h4a9f257_0 54 | - giflib=5.2.1=h5eee18b_3 55 | - glib=2.78.4=h6a678d5_0 56 | - glib-tools=2.78.4=h6a678d5_0 57 | - gmp=6.2.1=h295c915_3 58 | - gmpy2=2.1.2=py38heeb90bb_0 59 | - gnutls=3.6.15=he1e5248_0 60 | - grpcio=1.62.2=py38h6a678d5_0 61 | - gst-plugins-base=1.14.1=h6a678d5_1 62 | - gstreamer=1.14.1=h5eee18b_1 63 | - holoviews=1.17.1=py38h06a4308_0 64 | - icu=73.1=h6a678d5_0 65 | - imagecodecs=2023.1.23=py38hc4b7b5f_0 66 | - imageio=2.33.1=py38h06a4308_0 67 | - importlib_metadata=8.5.0=hd8ed1ab_0 68 | - importlib_resources=6.4.0=py38h06a4308_0 69 | - intel-openmp=2023.1.0=hdb19cb5_46306 70 | - jinja2=3.1.4=py38h06a4308_0 71 | - joblib=1.4.2=py38h06a4308_0 72 | - jpeg=9e=h5eee18b_3 73 | - jxrlib=1.1=h7b6447c_2 74 | - krb5=1.20.1=h143b758_1 75 | - lame=3.100=h7b6447c_0 76 | - lazy_loader=0.4=py38h06a4308_0 77 | - lcms2=2.12=h3be6417_0 78 | - ld_impl_linux-64=2.38=h1181459_1 79 | - lerc=3.0=h295c915_0 80 | - libabseil=20240116.2=cxx17_h6a678d5_0 81 | - libaec=1.0.4=he6710b0_1 82 | - libavif=0.11.1=h5eee18b_0 83 | - libblas=3.9.0=1_h86c2bf4_netlib 84 | - libbrotlicommon=1.0.9=h5eee18b_8 85 | - libbrotlidec=1.0.9=h5eee18b_8 86 | - libbrotlienc=1.0.9=h5eee18b_8 87 | - libcblas=3.9.0=5_h92ddd45_netlib 88 | - libclang=14.0.6=default_hc6dbbc7_1 89 | - libclang13=14.0.6=default_he11475f_1 90 | - libcublas=12.1.0.26=0 91 | - libcufft=11.0.2.4=0 92 | - libcufile=1.11.1.6=0 93 | - libcups=2.4.2=h2d74bed_1 94 | - libcurand=10.3.7.68=0 95 | - libcurl=8.9.1=h251f7ec_0 96 | - libcusolver=11.4.4.55=0 97 | - libcusparse=12.0.2.55=0 98 | - libdeflate=1.17=h5eee18b_1 99 | - libedit=3.1.20230828=h5eee18b_0 100 | - libev=4.33=h7f8727e_1 101 | - libffi=3.4.4=h6a678d5_1 102 | - 
libgcc-ng=11.2.0=h1234567_1 103 | - libgfortran-ng=13.2.0=h69a702a_0 104 | - libgfortran5=13.2.0=ha4646dd_0 105 | - libglib=2.78.4=hdc74915_0 106 | - libgomp=11.2.0=h1234567_1 107 | - libgrpc=1.62.2=h2d74bed_0 108 | - libiconv=1.16=h5eee18b_3 109 | - libidn2=2.3.4=h5eee18b_0 110 | - libjpeg-turbo=2.0.0=h9bf148f_0 111 | - liblapack=3.9.0=5_h92ddd45_netlib 112 | - libllvm14=14.0.6=hecde1de_4 113 | - libnghttp2=1.57.0=h2d74bed_0 114 | - libnpp=12.0.2.50=0 115 | - libnvjitlink=12.1.105=0 116 | - libnvjpeg=12.1.1.14=0 117 | - libpng=1.6.39=h5eee18b_0 118 | - libpq=12.17=hdbd6064_0 119 | - libprotobuf=4.25.3=he621ea3_0 120 | - libssh2=1.11.0=h251f7ec_0 121 | - libstdcxx-ng=11.2.0=h1234567_1 122 | - libtasn1=4.19.0=h5eee18b_0 123 | - libtiff=4.5.1=h6a678d5_0 124 | - libunistring=0.9.10=h27cfd23_0 125 | - libuuid=1.41.5=h5eee18b_0 126 | - libwebp-base=1.3.2=h5eee18b_0 127 | - libxcb=1.15=h7f8727e_0 128 | - libxkbcommon=1.0.1=h097e994_2 129 | - libxml2=2.13.1=hfdd30dd_2 130 | - libzopfli=1.0.3=he6710b0_0 131 | - linkify-it-py=2.0.0=py38h06a4308_0 132 | - llvm-openmp=14.0.6=h9e868ea_0 133 | - locket=1.0.0=py38h06a4308_0 134 | - lz4-c=1.9.4=h6a678d5_1 135 | - markdown=3.6=pyhd8ed1ab_0 136 | - markdown-it-py=2.2.0=py38h06a4308_1 137 | - markupsafe=2.1.3=py38h5eee18b_0 138 | - mdit-py-plugins=0.3.0=py38h06a4308_0 139 | - mdurl=0.1.0=py38h06a4308_0 140 | - mkl=2023.1.0=h213fc3f_46344 141 | - mkl-service=2.4.0=py38h5eee18b_1 142 | - mkl_fft=1.3.8=py38h5eee18b_0 143 | - mkl_random=1.2.4=py38hdb19cb5_0 144 | - mpc=1.1.0=h10f8cd9_1 145 | - mpfr=4.0.2=hb69a4c5_1 146 | - mpmath=1.3.0=py38h06a4308_0 147 | - multipledispatch=0.6.0=py38_0 148 | - munkres=1.1.4=pyh9f0ad1d_0 149 | - mysql=5.7.24=h721c034_2 150 | - ncurses=6.4=h6a678d5_0 151 | - nettle=3.7.3=hbbd107a_1 152 | - networkx=3.1=py38h06a4308_0 153 | - numba=0.58.1=py38h6a678d5_0 154 | - numpy=1.24.3=py38hf6e8229_1 155 | - numpy-base=1.24.3=py38h060ed82_1 156 | - openh264=2.1.1=h4ff587b_0 157 | - openjpeg=2.5.2=he7f1fd0_0 158 | - openssl=3.0.15=h5eee18b_0 159 | - packaging=24.1=pyhd8ed1ab_0 160 | - pandas=2.0.3=py38h1128e8f_0 161 | - panel=1.2.3=py38h06a4308_0 162 | - param=1.13.0=py38h06a4308_0 163 | - partd=1.4.1=py38h06a4308_0 164 | - pcre2=10.42=hebb0a14_1 165 | - pillow=10.4.0=py38h5eee18b_0 166 | - pip=24.2=py38h06a4308_0 167 | - ply=3.11=py38_0 168 | - pyct=0.5.0=py38h06a4308_0 169 | - pynndescent=0.5.13=pyhff2d567_0 170 | - pyqt=5.15.10=py38h6a678d5_0 171 | - pyqt5-sip=12.13.0=py38h5eee18b_0 172 | - pysocks=1.7.1=py38h06a4308_0 173 | - python=3.8.19=h955ad1f_0 174 | - python-dateutil=2.9.0post0=py38h06a4308_2 175 | - python_abi=3.8=2_cp38 176 | - pytorch=2.1.1=py3.8_cuda12.1_cudnn8.9.2_0 177 | - pytorch-cuda=12.1=ha16c6d3_5 178 | - pytorch-mutex=1.0=cuda 179 | - pytz=2024.1=py38h06a4308_0 180 | - pyviz_comms=3.0.2=py38h06a4308_0 181 | - pywavelets=1.4.1=py38h5eee18b_0 182 | - pyyaml=6.0.1=py38h5eee18b_0 183 | - qt-main=5.15.2=h53bd1ea_10 184 | - re2=2022.04.01=h27087fc_0 185 | - readline=8.2=h5eee18b_0 186 | - requests=2.32.3=py38h06a4308_0 187 | - scikit-image=0.20.0=py38h6a678d5_0 188 | - scipy=1.8.1=py38h1ee437e_0 189 | - setuptools=72.1.0=py38h06a4308_0 190 | - sip=6.7.12=py38h6a678d5_0 191 | - six=1.16.0=pyh6c4a22f_0 192 | - snappy=1.2.1=h6a678d5_0 193 | - sqlite=3.45.3=h5eee18b_0 194 | - sympy=1.13.2=py38h06a4308_0 195 | - tbb=2021.8.0=hdb19cb5_0 196 | - tensorboard=2.17.1=pyhd8ed1ab_0 197 | - tensorboard-data-server=0.7.0=py38h52d8a92_1 198 | - threadpoolctl=3.5.0=pyhc1e730c_0 199 | - tifffile=2023.4.12=py38h06a4308_0 200 | - 
tk=8.6.14=h39e8969_0 201 | - tomli=2.0.1=py38h06a4308_0 202 | - toolz=0.12.0=py38h06a4308_0 203 | - torchaudio=2.1.1=py38_cu121 204 | - torchtriton=2.1.0=py38 205 | - torchvision=0.16.1=py38_cu121 206 | - tornado=6.4.1=py38h5eee18b_0 207 | - tqdm=4.66.5=pyhd8ed1ab_0 208 | - uc-micro-py=1.0.1=py38h06a4308_0 209 | - umap-learn=0.5.6=py38h578d9bd_1 210 | - unicodedata2=15.1.0=py38h5eee18b_0 211 | - urllib3=2.2.2=py38h06a4308_0 212 | - webencodings=0.5.1=py38_1 213 | - werkzeug=3.0.4=pyhd8ed1ab_0 214 | - wheel=0.43.0=py38h06a4308_0 215 | - xarray=2022.11.0=py38h06a4308_0 216 | - xyzservices=2022.9.0=py38h06a4308_1 217 | - xz=5.4.6=h5eee18b_1 218 | - yaml=0.2.5=h7b6447c_0 219 | - zfp=1.0.0=h6a678d5_0 220 | - zlib=1.2.13=h5eee18b_1 221 | - zstd=1.5.5=hc292b87_2 222 | - pip: 223 | - aiohappyeyeballs==2.4.0 224 | - aiohttp==3.10.5 225 | - aiosignal==1.3.1 226 | - anndata==0.9.2 227 | - async-timeout==4.0.3 228 | - attrs==24.2.0 229 | - beautifulsoup4==4.12.3 230 | - blessed==1.20.0 231 | - blosc2==2.0.0 232 | - contourpy==1.1.1 233 | - cython==3.0.11 234 | - deprecated==1.2.14 235 | - einops==0.8.0 236 | - filelock==3.15.4 237 | - flash-attn==2.6.3 238 | - fonttools==4.53.1 239 | - frozenlist==1.4.1 240 | - fsspec==2024.9.0 241 | - gdown==5.2.0 242 | - get-annotations==0.1.2 243 | - gpustat==1.1.1 244 | - h5py==3.11.0 245 | - huggingface-hub==0.0.8 246 | - idna==3.8 247 | - igraph==0.11.8 248 | - importlib-metadata==8.4.0 249 | - kiwisolver==1.4.7 250 | - legacy-api-wrap==1.4 251 | - leidenalg==0.10.2 252 | - llvmlite==0.41.1 253 | - local-attention==1.9.14 254 | - louvain==0.8.2 255 | - matplotlib==3.6.3 256 | - memory-profiler==0.61.0 257 | - msgpack==1.0.8 258 | - multidict==6.0.5 259 | - natsort==8.4.0 260 | - ninja==1.11.1.1 261 | - numexpr==2.8.6 262 | - nvidia-ml-py==12.560.30 263 | - patsy==0.5.6 264 | - protobuf==5.28.2 265 | - psutil==6.0.0 266 | - py-cpuinfo==9.0.0 267 | - pydot==3.0.2 268 | - pyparsing==3.1.4 269 | - regex==2024.7.24 270 | - sacremoses==0.1.1 271 | - scanpy==1.9.8 272 | - scib==1.1.5 273 | - scikit-learn==0.24.2 274 | - scikit-misc==0.2.0 275 | - seaborn==0.13.2 276 | - session-info==1.0.0 277 | - sinfo==0.3.4 278 | - soupsieve==2.6 279 | - statsmodels==0.14.1 280 | - stdlib-list==0.10.0 281 | - tables==3.8.0 282 | - tensorboardx==2.6.2.2 283 | - texttable==1.7.0 284 | - tokenizers==0.10.3 285 | - torchdata==0.7.1 286 | - torchtext==0.16.1 287 | - transformers==4.6.1 288 | - typing-extensions==4.12.2 289 | - tzdata==2024.1 290 | - wcwidth==0.2.13 291 | - wrapt==1.16.0 292 | - yarl==1.9.10 293 | - zipp==3.20.1 294 | prefix: /your/conda/envs/epiFoundation 295 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import gc 5 | import argparse 6 | import json 7 | import random 8 | import math 9 | import random 10 | from functools import reduce 11 | import numpy as np 12 | import pandas as pd 13 | from scipy import sparse 14 | from sklearn.model_selection import train_test_split 15 | import torch 16 | from torch import nn 17 | from torch.optim import Adam 18 | from torch.nn import functional as F 19 | from tensorboardX import SummaryWriter 20 | from torch.utils.data import DataLoader, Dataset 21 | from torch.utils.data.distributed import DistributedSampler 22 | from torch.nn.parallel import DistributedDataParallel as DDP 23 | import torch.distributed as dist 24 | from model import EpiFoundation 25 | 
from loss.loss import MaskedMSELoss 26 | from data.dataloader import * 27 | from tokenizer import GeneVocab 28 | import scanpy as sc 29 | import anndata as ad 30 | from utils import * 31 | from memory_profiler import profile 32 | 33 | import yaml 34 | 35 | torch.autograd.set_detect_anomaly(True) 36 | 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument("--config", type=str, default='./configs/pretrain/atac_cross_debug.yml', help='Config file.') 39 | args = parser.parse_args() 40 | 41 | def main(): 42 | # read and parse config file 43 | local_rank = int(os.environ["LOCAL_RANK"]) 44 | with open(args.config, 'r') as f: 45 | config = yaml.load(f, Loader=yaml.FullLoader) 46 | 47 | 48 | train_config = config['train'] 49 | valid_config = config['valid'] 50 | data_config = config['data'] 51 | vocab_config = config['vocab'] 52 | task_name = config['task_name'] 53 | task_floder = './experiment/{}'.format(task_name) 54 | ckpt_dir = os.path.join(task_floder, 'ckpts') 55 | 56 | 57 | 58 | random_seed = train_config['seed'] 59 | EPOCHS = train_config['epochs'] 60 | BATCH_SIZE = train_config['batch_size'] 61 | GRADIENT_ACCUMULATION = train_config['gradient_accumulation_steps'] 62 | LEARNING_RATE = float(train_config['lr']) 63 | 64 | model_name = train_config['model']['encoder'] 65 | 66 | save_ckpt_freq = train_config['save_ckpt_freq'] if 'save_ckpt_freq' in train_config else 5 67 | resume = train_config['resume'] if 'resume' in train_config else False 68 | 69 | # special tokens 70 | pad = vocab_config['special_tokens']['pad'] 71 | mask = vocab_config['special_tokens']['mask'] 72 | cls = vocab_config['special_tokens']['cls'] 73 | 74 | # distibuted setting 75 | dist.init_process_group(backend='nccl') 76 | torch.cuda.set_device(local_rank) 77 | device = torch.device("cuda", local_rank) 78 | world_size = torch.distributed.get_world_size() 79 | seed_all(random_seed + torch.distributed.get_rank()) 80 | is_master = (local_rank == 0) 81 | 82 | # init loggers 83 | logger = set_log(log_dir= os.path.join(task_floder, 'logs')) 84 | tb_logger = SummaryWriter(os.path.join(task_floder, 'tb_logs')) 85 | if is_master: 86 | logger.info(dict2str(config)) 87 | 88 | 89 | rna_vocab = GeneVocab.from_file(vocab_config['rna_path']) 90 | atac_vocab = GeneVocab.from_file(vocab_config['atac_path']) 91 | cell_vocab = GeneVocab.from_file(vocab_config['cell_type_path']) 92 | batch_vocab = GeneVocab.from_file(vocab_config['batch_path']) 93 | chr_vocab = GeneVocab.from_file(vocab_config['chr_path']) 94 | 95 | if is_master: 96 | logger.info(f'Rna vocab size: {len(rna_vocab)}') 97 | logger.info(f'Atac vocab size: {len(atac_vocab)}') 98 | 99 | if is_master: 100 | logger.info('loading training data') 101 | 102 | train_set = PairedSCDataset( 103 | rna_file = data_config['train']['rna_path'], 104 | atac_file= data_config['train']['atac_path'], 105 | rna_key = data_config['train']['rna_key'], 106 | atac_key = data_config['train']['atac_key'], 107 | rna_vocab = rna_vocab, 108 | atac_vocab = atac_vocab, 109 | cell_vocab = cell_vocab, 110 | batch_vocab= batch_vocab, 111 | chr_vocab = chr_vocab, 112 | gene2chr_file= vocab_config['gene2chr_path'], 113 | rna_max_len = train_config['model']['rna_max_len'], 114 | atac_max_len = train_config['model']['atac_max_len'], 115 | pad_token = pad['token'], 116 | rna_pad_value = pad['value'], 117 | cls_token = cls['token'], 118 | logger = logger, 119 | ) 120 | 121 | gc.collect() 122 | train_sampler = DistributedSampler(train_set) 123 | train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, 
sampler=train_sampler, prefetch_factor=4, num_workers=4) 124 | 125 | 126 | if is_master: 127 | logger.info('loading validation data') 128 | val_set = PairedSCDataset( 129 | rna_file = data_config['test']['rna_path'], 130 | atac_file= data_config['test']['atac_path'], 131 | rna_key = data_config['test']['rna_key'], 132 | atac_key = data_config['test']['atac_key'], 133 | rna_vocab = rna_vocab, 134 | atac_vocab = atac_vocab, 135 | cell_vocab = cell_vocab, 136 | batch_vocab= batch_vocab, 137 | chr_vocab = chr_vocab, 138 | gene2chr_file= vocab_config['gene2chr_path'], 139 | rna_max_len = train_config['model']['rna_max_len'], 140 | atac_max_len = train_config['model']['atac_max_len'], 141 | pad_token = pad['token'], 142 | rna_pad_value = pad['value'], 143 | cls_token = cls['token'], 144 | logger = logger, 145 | ) 146 | gc.collect() 147 | 148 | val_sampler = SequentialDistributedSampler(val_set, batch_size=BATCH_SIZE, world_size=world_size) 149 | val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, sampler=val_sampler, prefetch_factor=4, num_workers=4) 150 | 151 | if is_master: 152 | logger.info('Creating model') 153 | 154 | model = EpiFoundation( 155 | num_class_cell = len(cell_vocab), 156 | num_rnas = len(rna_vocab), 157 | num_atacs = len(atac_vocab), 158 | num_values= data_config['bin_num'], 159 | num_chrs= len(chr_vocab), 160 | embed_dim = train_config['model']['embedding_dim'], 161 | depth = train_config['model']['num_layers'], 162 | heads = train_config['model']['head_num'], 163 | head_dim = train_config['model']['head_dim'], 164 | encoder = model_name, 165 | dropout = train_config['model']['dropout'], 166 | pad_token_idx_rna = rna_vocab[pad['token']], 167 | pad_token_idx_atac = atac_vocab[pad['token']], 168 | cell_emb_style = train_config['model']['cell_emb_style'], 169 | mvc_arch_style = train_config['model']['mvc_arch_style'], 170 | use_batch_labels = train_config['model']['use_batch_labels'], 171 | batch_label_num= len(batch_vocab), 172 | use_chr_labels= train_config['model']['use_chr_labels'], 173 | stage= 'value_finetune', 174 | ).to(device) 175 | 176 | # optimizer 177 | optimizer = Adam(model.parameters(), lr=LEARNING_RATE) 178 | 179 | # learning rate scheduler 180 | scheduler = CosineAnnealingWarmupRestarts( 181 | optimizer, 182 | first_cycle_steps=15, 183 | cycle_mult=2, 184 | max_lr=LEARNING_RATE, 185 | min_lr=1e-6, 186 | warmup_steps=5, 187 | gamma=0.9 188 | ) 189 | 190 | start_epoch = 1 191 | model = DDP(model, device_ids=[local_rank], output_device=local_rank) 192 | 193 | # scaler = torch.amp.GradScaler(enabled=train_config['amp'].amp) 194 | scaler = torch.cuda.amp.GradScaler(enabled=train_config['amp']) 195 | 196 | # masked_mse_loss = MaskedMSELoss().to(local_rank) 197 | cross_entropy_loss = nn.CrossEntropyLoss(reduction='mean').to(local_rank) 198 | mvc_loss_fn = MaskedMSELoss().to(local_rank) 199 | mvc_weight = train_config['task_weight']['mvc'] 200 | cell_type_weight = train_config['task_weight']['cell_type'] 201 | 202 | softmax = nn.Softmax(dim=-1) 203 | 204 | steps = 0 205 | if train_config['model']['pretrained'] is not None: 206 | if is_master: 207 | logger.info('Loading pretrained model from: {}'.format(train_config['model']['pretrained'])) 208 | checkpoint = torch.load(train_config['model']['pretrained'], map_location=device) 209 | 210 | # # do not load value_decoder parameters 211 | pretrained_dict = {k: v for k, v in checkpoint['model'].items() if 'value_decoder' not in k and 'mvc_decoder' not in k and 'batch_emb' not in k and 'cls_decoder' not in k} 212 | 
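        # Annotation: the filtered dict above deliberately drops the value_decoder,
        # mvc_decoder, batch_emb and cls_decoder weights, so only the shared backbone
        # and embedding parameters are restored from the pretraining checkpoint; the
        # task-specific heads keep their fresh initialization for the 'value_finetune'
        # stage. The update()/load_state_dict() pattern below fills every other key
        # from the current model's own state dict.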
model_dict = model.module.state_dict() 213 | model_dict.update(pretrained_dict) 214 | model.module.load_state_dict(model_dict) 215 | 216 | if resume: 217 | start_epoch = checkpoint['epoch'] + 1 218 | steps = checkpoint['steps'] 219 | del checkpoint 220 | del pretrained_dict 221 | gc.collect() 222 | 223 | dist.barrier() 224 | if is_master: 225 | logger.info('Start finetuning from epoch: {}, steps: {}'.format(start_epoch, steps)) 226 | for i in range(start_epoch, start_epoch + EPOCHS): 227 | train_loader.sampler.set_epoch(i) 228 | 229 | if is_master: 230 | logger.info('Training with {} samples, steps: {}'.format(len(train_loader.dataset), len(train_loader))) 231 | model.train() 232 | dist.barrier() 233 | running_loss = {'mvc': 0.0, 'cell': 0.0, 'total': 0.0} 234 | cum_acc_cell = 0.0 235 | for index, batch in enumerate(train_loader): 236 | index += 1 237 | steps += 1 238 | rna_values = batch['rna_values'].to(device) 239 | rna_ids = batch['rna_ids'].to(device) 240 | atac_ids = batch['atac_ids'].to(device) 241 | cell_ids = batch['cell_ids'].to(device) 242 | batch_ids = batch['batch_ids'].to(device) 243 | rna_chrs = batch['rna_chrs'].to(device) 244 | atac_chrs = batch['atac_chrs'].to(device) 245 | 246 | padding_positions = atac_ids.eq(atac_vocab[pad['token']]) 247 | rna_non_pad = rna_ids.ne(rna_vocab[pad['token']]) 248 | 249 | if index % GRADIENT_ACCUMULATION != 0 and index != len(train_loader): 250 | with model.no_sync(): 251 | with torch.cuda.amp.autocast(enabled=train_config['amp'], dtype= torch.bfloat16): 252 | # finetue using all expression values, do not mask 253 | output = model(atac = atac_ids, rna = rna_ids, src_key_padding_mask = padding_positions, batch_id = batch_ids, rna_chrs = rna_chrs, atac_chrs = atac_chrs) 254 | 255 | mvc_loss = mvc_loss_fn(output['value_pred'], rna_values.float(), mask = rna_non_pad) * mvc_weight 256 | cell_loss = cross_entropy_loss(output['cell_pred'], cell_ids) * cell_type_weight 257 | loss = mvc_loss + cell_loss 258 | 259 | running_loss['mvc'] += mvc_loss.item() 260 | running_loss['cell'] += cell_loss.item() 261 | running_loss['total'] += loss.item() 262 | 263 | loss = loss / GRADIENT_ACCUMULATION 264 | scaler.scale(loss).backward() 265 | else: 266 | with torch.cuda.amp.autocast(enabled=train_config['amp'], dtype= torch.bfloat16): 267 | output = model(atac = atac_ids, rna = rna_ids, src_key_padding_mask = padding_positions, batch_id = batch_ids, rna_chrs = rna_chrs, atac_chrs = atac_chrs) 268 | 269 | mvc_loss = mvc_loss_fn(output['value_pred'], rna_values.float(), mask = rna_non_pad) * mvc_weight 270 | cell_loss = cross_entropy_loss(output['cell_pred'], cell_ids) * cell_type_weight 271 | loss = mvc_loss + cell_loss 272 | 273 | running_loss['mvc'] += mvc_loss.item() 274 | running_loss['cell'] += cell_loss.item() 275 | running_loss['total'] += loss.item() 276 | if is_master: 277 | tb_logger.add_scalar('train/mvc_loss', mvc_loss.item(), steps) 278 | tb_logger.add_scalar('train/cell_loss', cell_loss.item(), steps) 279 | tb_logger.add_scalar('train/total_loss', loss.item(), steps) 280 | logger.info(f'Epoch: {i} | Step: {index} | MVC Loss: {mvc_loss:.4f} | Cell Type Loss: {cell_loss:.4f} | Total Loss: {loss:.4f}') 281 | loss = loss / GRADIENT_ACCUMULATION 282 | scaler.scale(loss).backward() 283 | scaler.unscale_(optimizer) 284 | torch.nn.utils.clip_grad_norm_(model.parameters(), int(1e2)) 285 | scaler.step(optimizer) 286 | scaler.update() 287 | optimizer.zero_grad() 288 | # cell type accuracy 289 | type_pred = softmax(output['cell_pred']) 290 | type_pred = 
type_pred.argmax(dim=-1) 291 | cum_acc_cell += (type_pred.eq(cell_ids)).sum().item() / len(cell_ids) 292 | 293 | cum_acc_cell = 100 * cum_acc_cell / index 294 | cum_acc_cell = get_reduced(cum_acc_cell, local_rank, 0, world_size) 295 | 296 | for key in running_loss: 297 | running_loss[key] = running_loss[key] / index 298 | running_loss[key] = get_reduced(running_loss[key], local_rank, 0, world_size) 299 | if is_master: 300 | logger.info(f'Epoch: {i} | MVC Loss: {running_loss["mvc"]:.4f} | Cell Type Loss: {running_loss["cell"]:.4f} | Total Loss: {running_loss["total"]:.4f} | Cell Type Accuracy: {cum_acc_cell:.2f}') 301 | dist.barrier() 302 | scheduler.step() 303 | # del train_set, train_sampler, train_loader 304 | 305 | if i % valid_config['freq'] == 0: 306 | if is_master: 307 | logger.info('#### Validation ####') 308 | model.eval() 309 | dist.barrier() 310 | running_loss = {'mvc': 0.0, 'cell': 0.0, 'total': 0.0} 311 | 312 | cum_acc_cell = 0.0 313 | 314 | with torch.no_grad(): 315 | for index, batch in enumerate(val_loader): 316 | index += 1 317 | 318 | rna_values = batch['rna_values'].to(device) 319 | rna_ids = batch['rna_ids'].to(device) 320 | atac_ids = batch['atac_ids'].to(device) 321 | cell_ids = batch['cell_ids'].to(device) 322 | batch_ids = batch['batch_ids'].to(device) 323 | rna_chrs = batch['rna_chrs'].to(device) 324 | atac_chrs = batch['atac_chrs'].to(device) 325 | 326 | padding_positions = atac_ids.eq(atac_vocab[pad['token']]) 327 | rna_non_pad = rna_ids.ne(rna_vocab[pad['token']]) 328 | with torch.cuda.amp.autocast(enabled=train_config['amp'], dtype= torch.bfloat16): 329 | output = model(atac = atac_ids, rna = rna_ids, src_key_padding_mask = padding_positions, batch_id = batch_ids, rna_chrs = rna_chrs, atac_chrs = atac_chrs) 330 | 331 | mvc_loss = mvc_loss_fn(output['value_pred'], rna_values.float(), mask = rna_non_pad) * mvc_weight 332 | cell_loss = cross_entropy_loss(output['cell_pred'], cell_ids) * cell_type_weight 333 | loss = mvc_loss + cell_loss 334 | 335 | running_loss['mvc'] += mvc_loss.item() 336 | running_loss['cell'] += cell_loss.item() 337 | running_loss['total'] += loss.item() 338 | 339 | type_pred = softmax(output['cell_pred']) 340 | type_pred = type_pred.argmax(dim=-1) 341 | cum_acc_cell += (type_pred.eq(cell_ids)).sum().item() / len(cell_ids) 342 | 343 | # break 344 | for key in running_loss: 345 | running_loss[key] = running_loss[key] / index 346 | running_loss[key] = get_reduced(running_loss[key], local_rank, 0, world_size) 347 | cum_acc_cell = 100 * cum_acc_cell / index 348 | cum_acc_cell = get_reduced(cum_acc_cell, local_rank, 0, world_size) 349 | 350 | # del val_set, val_sampler, val_loader 351 | if is_master: 352 | logger.info(f'MVC Loss: {running_loss["mvc"]:.4f} | Cell Type Loss: {running_loss["cell"]:.4f} | Total Loss: {running_loss["total"]:.4f} | Cell Type Accuracy: {cum_acc_cell:.2f}') 353 | 354 | if is_master and i % save_ckpt_freq == 0: 355 | save_ckpt(i, steps, model, optimizer, scheduler, scaler, running_loss["total"], task_name, ckpt_dir) 356 | 357 | 358 | if __name__ == '__main__': 359 | main() 360 | -------------------------------------------------------------------------------- /data/preprocess.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Union 2 | 3 | import numpy as np 4 | import torch 5 | from scipy.sparse import issparse 6 | import scanpy as sc 7 | from scanpy.get import _get_obs_rep, _set_obs_rep 8 | from anndata import AnnData 9 | from tqdm import tqdm 10 | 
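# Illustrative usage sketch for the Preprocessor defined below (the file path and
# parameter values are assumptions, not project defaults):
#
#   adata = sc.read_h5ad("rna_counts.h5ad")     # hypothetical raw-count AnnData
#   pre = Preprocessor(
#       use_key="X",                  # start from raw counts in adata.X
#       normalize_total=1e4,          # library-size normalize each cell to 10k counts
#       log1p=True,                   # store log1p counts under result_log1p_key
#       binning=2,                    # quantile-bin values (first bin reserved for zeros)
#       result_binned_key="X_binned",
#   )
#   pre(adata, batch_key="batch")     # batch-aware HVG selection if subset_hvg is set
#
# A binned layer such as "X_binned" can then be passed as rna_key to the datasets in
# data/dataloader.py.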
from model import logger 11 | 12 | 13 | class Preprocessor: 14 | """ 15 | Prepare data into training, valid and test split. Normalize raw expression 16 | values, binning or using other transform into the preset model input format. 17 | """ 18 | 19 | def __init__( 20 | self, 21 | use_key: Optional[str] = None, 22 | filter_gene_by_counts: Union[int, bool] = False, 23 | filter_cell_by_counts: Union[int, bool] = False, 24 | normalize_total: Union[float, bool] = 1e4, 25 | result_normed_key: Optional[str] = "X_normed", 26 | log1p: bool = False, 27 | result_log1p_key: str = "X_log1p", 28 | subset_hvg: Union[int, bool] = False, 29 | hvg_use_key: Optional[str] = None, 30 | hvg_flavor: str = "seurat_v3", 31 | binning: Optional[int] = 6, 32 | result_binned_key: str = "X_binned", 33 | ): 34 | r""" 35 | Set up the preprocessor, use the args to config the workflow steps. 36 | 37 | Args: 38 | 39 | use_key (:class:`str`, optional): 40 | The key of :class:`~anndata.AnnData` to use for preprocessing. 41 | filter_gene_by_counts (:class:`int` or :class:`bool`, default: ``False``): 42 | Whther to filter genes by counts, if :class:`int`, filter genes with counts 43 | filter_cell_by_counts (:class:`int` or :class:`bool`, default: ``False``): 44 | Whther to filter cells by counts, if :class:`int`, filter cells with counts 45 | normalize_total (:class:`float` or :class:`bool`, default: ``1e4``): 46 | Whether to normalize the total counts of each cell to a specific value. 47 | result_normed_key (:class:`str`, default: ``"X_normed"``): 48 | The key of :class:`~anndata.AnnData` to store the normalized data. If 49 | :class:`None`, will use normed data to replce the :attr:`use_key`. 50 | log1p (:class:`bool`, default: ``True``): 51 | Whether to apply log1p transform to the normalized data. 52 | result_log1p_key (:class:`str`, default: ``"X_log1p"``): 53 | The key of :class:`~anndata.AnnData` to store the log1p transformed data. 54 | subset_hvg (:class:`int` or :class:`bool`, default: ``False``): 55 | Whether to subset highly variable genes. 56 | hvg_use_key (:class:`str`, optional): 57 | The key of :class:`~anndata.AnnData` to use for calculating highly variable 58 | genes. If :class:`None`, will use :attr:`adata.X`. 59 | hvg_flavor (:class:`str`, default: ``"seurat_v3"``): 60 | The flavor of highly variable genes selection. See 61 | :func:`scanpy.pp.highly_variable_genes` for more details. 62 | binning (:class:`int`, optional): 63 | Whether to bin the data into discrete values of number of bins provided. 64 | result_binned_key (:class:`str`, default: ``"X_binned"``): 65 | The key of :class:`~anndata.AnnData` to store the binned data. 66 | """ 67 | self.use_key = use_key 68 | self.filter_gene_by_counts = filter_gene_by_counts 69 | self.filter_cell_by_counts = filter_cell_by_counts 70 | self.normalize_total = normalize_total 71 | self.result_normed_key = result_normed_key 72 | self.log1p = log1p 73 | self.result_log1p_key = result_log1p_key 74 | self.subset_hvg = subset_hvg 75 | self.hvg_use_key = hvg_use_key 76 | self.hvg_flavor = hvg_flavor 77 | self.binning = binning 78 | self.result_binned_key = result_binned_key 79 | 80 | def __call__(self, adata: AnnData, batch_key: Optional[str] = None) -> Dict: 81 | """ 82 | format controls the different input value wrapping, including categorical 83 | binned style, fixed-sum normalized counts, log1p fixed-sum normalized counts, etc. 84 | 85 | Args: 86 | 87 | adata (:class:`AnnData`): 88 | The :class:`AnnData` object to preprocess. 
89 | batch_key (:class:`str`, optional): 90 | The key of :class:`AnnData.obs` to use for batch information. This arg 91 | is used in the highly variable gene selection step. 92 | """ 93 | key_to_process = self.use_key 94 | # preliminary checks, will use later 95 | if key_to_process == "X": 96 | key_to_process = None # the following scanpy apis use arg None to use X 97 | is_logged = self.check_logged(adata, obs_key=key_to_process) 98 | 99 | # step 1: filter genes 100 | if self.filter_gene_by_counts: 101 | logger.info("Filtering genes by counts ...") 102 | sc.pp.filter_genes( 103 | adata, 104 | min_counts=self.filter_gene_by_counts 105 | if isinstance(self.filter_gene_by_counts, int) 106 | else None, 107 | ) 108 | 109 | # step 2: filter cells 110 | if ( 111 | isinstance(self.filter_cell_by_counts, int) 112 | and self.filter_cell_by_counts > 0 113 | ): 114 | logger.info("Filtering cells by counts ...") 115 | sc.pp.filter_cells( 116 | adata, 117 | min_counts=self.filter_cell_by_counts 118 | if isinstance(self.filter_cell_by_counts, int) 119 | else None, 120 | ) 121 | 122 | # step 3: normalize total 123 | if self.normalize_total: 124 | logger.info("Normalizing total counts ...") 125 | normed_ = sc.pp.normalize_total( 126 | adata, 127 | target_sum=self.normalize_total 128 | if isinstance(self.normalize_total, float) 129 | else None, 130 | layer=key_to_process, 131 | inplace=False, 132 | )["X"] 133 | key_to_process = self.result_normed_key or key_to_process 134 | _set_obs_rep(adata, normed_, layer=key_to_process) 135 | 136 | # step 4: log1p 137 | if self.log1p: 138 | logger.info("Log1p transforming ...") 139 | if is_logged: 140 | logger.warning( 141 | "The input data seems to be already log1p transformed. " 142 | "Set `log1p=False` to avoid double log1p transform." 143 | ) 144 | if self.result_log1p_key: 145 | _set_obs_rep( 146 | adata, 147 | _get_obs_rep(adata, layer=key_to_process), 148 | layer=self.result_log1p_key, 149 | ) 150 | key_to_process = self.result_log1p_key 151 | sc.pp.log1p(adata, layer=key_to_process) 152 | 153 | 154 | # Select HVG in each batch if batch_key is provided 155 | # step 5: subset hvg 156 | if self.subset_hvg: 157 | logger.info("Subsetting highly variable genes ...") 158 | if batch_key is None: 159 | logger.warning( 160 | "No batch_key is provided, will use all cells for HVG selection." 161 | ) 162 | sc.pp.highly_variable_genes( 163 | adata, 164 | layer=self.hvg_use_key, 165 | n_top_genes=self.subset_hvg 166 | if isinstance(self.subset_hvg, int) 167 | else None, 168 | batch_key=batch_key, 169 | flavor=self.hvg_flavor, 170 | subset=True, 171 | ) 172 | 173 | # step 6: binning 174 | if self.binning: 175 | logger.info("Binning data with {} bins ...".format(self.binning)) 176 | if isinstance(self.binning, int): 177 | # raise ValueError( 178 | # "Binning arg must be an integer, but got {}.".format(self.binning) 179 | # ) 180 | n_bins = self.binning # NOTE: the first bin is always a spectial for zero 181 | binned_rows = [] 182 | bin_edges = [] 183 | # layer to be preprocessed refer to unbinned data, such as X, X_pca, X_umap, etc. 184 | layer_data = _get_obs_rep(adata, layer=key_to_process) 185 | layer_data = layer_data.A if issparse(layer_data) else layer_data 186 | if layer_data.min() < 0: 187 | raise ValueError( 188 | f"Assuming non-negative data, but got min value {layer_data.min()}." 189 | ) 190 | for row in tqdm(layer_data): 191 | if row.max() == 0: 192 | logger.warning( 193 | "The input data contains all zero rows. Please make sure " 194 | "this is expected. 
You can use the `filter_cell_by_counts` " 195 | "arg to filter out all zero rows." 196 | ) 197 | binned_rows.append(np.zeros_like(row, dtype=np.int64)) 198 | bin_edges.append(np.array([0] * n_bins)) 199 | continue 200 | non_zero_ids = row.nonzero() 201 | non_zero_row = row[non_zero_ids] 202 | # generate bins, represent a series of bin edges 203 | # np.linspace(0, 1, n_bins - 1) is the cumulative distribution function 204 | # return the quantile of the data, which is the inverse of the cumulative 205 | bins = np.quantile(non_zero_row, np.linspace(0, 1, n_bins - 1)) 206 | # bins = np.sort(np.unique(bins)) 207 | # NOTE: comment this line for now, since this will make the each category 208 | # has different relative meaning across datasets 209 | non_zero_digits = _digitize(non_zero_row, bins, side='one') 210 | assert non_zero_digits.min() >= 1 211 | assert non_zero_digits.max() <= n_bins - 1 212 | binned_row = np.zeros_like(row, dtype=np.int64) 213 | binned_row[non_zero_ids] = non_zero_digits 214 | binned_rows.append(binned_row) 215 | bin_edges.append(np.concatenate([[0], bins])) 216 | 217 | # Layer refer to the data type, such as X, X_pca, X_umap, etc. 218 | adata.layers[self.result_binned_key] = np.stack(binned_rows) 219 | adata.obsm["bin_edges"] = np.stack(bin_edges) 220 | elif isinstance(self.binning, list): 221 | # layer to be preprocessed refer to unbinned data, such as X, X_pca, X_umap, etc. 222 | layer_data = _get_obs_rep(adata, layer=key_to_process) 223 | layer_data = layer_data.A if issparse(layer_data) else layer_data 224 | if layer_data.min() < 0: 225 | raise ValueError( 226 | f"Assuming non-negative data, but got min value {layer_data.min()}." 227 | ) 228 | for bin_number in self.binning: 229 | n_bins = bin_number 230 | binned_rows = [] 231 | bin_edges = [] 232 | logger.info(f"Processing binning with {n_bins} bins") 233 | for row in tqdm(layer_data): 234 | if row.max() == 0: 235 | logger.warning( 236 | "The input data contains all zero rows. Please make sure " 237 | "this is expected. You can use the `filter_cell_by_counts` " 238 | "arg to filter out all zero rows." 239 | ) 240 | binned_rows.append(np.zeros_like(row, dtype=np.int64)) 241 | bin_edges.append(np.array([0] * n_bins)) 242 | continue 243 | non_zero_ids = row.nonzero() 244 | non_zero_row = row[non_zero_ids] 245 | # generate bins, represent a series of bin edges 246 | # np.linspace(0, 1, n_bins - 1) is the cumulative distribution function 247 | # return the quantile of the data, which is the inverse of the cumulative 248 | bins = np.quantile(non_zero_row, np.linspace(0, 1, n_bins - 1)) 249 | # bins = np.sort(np.unique(bins)) 250 | # NOTE: comment this line for now, since this will make the each category 251 | # has different relative meaning across datasets 252 | non_zero_digits = _digitize(non_zero_row, bins, side='one') 253 | assert non_zero_digits.min() >= 1 254 | assert non_zero_digits.max() <= n_bins - 1 255 | binned_row = np.zeros_like(row, dtype=np.int64) 256 | binned_row[non_zero_ids] = non_zero_digits 257 | binned_rows.append(binned_row) 258 | bin_edges.append(np.concatenate([[0], bins])) 259 | result_key = self.result_binned_key 260 | bin_edges_key = n_bins 261 | logger.info(f"Saving binned data with {n_bins} bins, result_key: {result_key}") 262 | adata.layers[result_key] = np.stack(binned_rows) 263 | adata.obsm[bin_edges_key] = np.stack(bin_edges) 264 | 265 | def check_logged(self, adata: AnnData, obs_key: Optional[str] = None) -> bool: 266 | """ 267 | Check if the data is already log1p transformed. 
268 | 269 | Args: 270 | 271 | adata (:class:`AnnData`): 272 | The :class:`AnnData` object to preprocess. 273 | obs_key (:class:`str`, optional): 274 | The key of :class:`AnnData.obs` to use for batch information. This arg 275 | is used in the highly variable gene selection step. 276 | """ 277 | data = _get_obs_rep(adata, layer=obs_key) 278 | max_, min_ = data.max(), data.min() 279 | if max_ > 30: 280 | return False 281 | if min_ < 0: 282 | return False 283 | 284 | non_zero_min = data[data > 0].min() 285 | if non_zero_min >= 1: 286 | return False 287 | 288 | return True 289 | 290 | 291 | def _digitize(x: np.ndarray, bins: np.ndarray, side="both") -> np.ndarray: 292 | """ 293 | Digitize the data into bins. This method spreads data uniformly when bins 294 | have same values. 295 | 296 | Args: 297 | 298 | x (:class:`np.ndarray`): 299 | The data to digitize. 300 | bins (:class:`np.ndarray`): 301 | The bins to use for digitization, in increasing order. 302 | side (:class:`str`, optional): 303 | The side to use for digitization. If "one", the left side is used. If 304 | "both", the left and right side are used. Default to "one". 305 | 306 | Returns: 307 | 308 | :class:`np.ndarray`: 309 | The digitized data. 310 | """ 311 | assert x.ndim == 1 and bins.ndim == 1 312 | 313 | left_digits = np.digitize(x, bins) 314 | if side == "one": 315 | return left_digits 316 | 317 | right_difits = np.digitize(x, bins, right=True) 318 | 319 | rands = np.random.rand(len(x)) # uniform random numbers 320 | 321 | digits = rands * (right_difits - left_digits) + left_digits 322 | digits = np.ceil(digits).astype(np.int64) 323 | return digits 324 | 325 | 326 | def binning( 327 | row: Union[np.ndarray, torch.Tensor], n_bins: int 328 | ) -> Union[np.ndarray, torch.Tensor]: 329 | """Binning the row into n_bins.""" 330 | dtype = row.dtype 331 | return_np = False if isinstance(row, torch.Tensor) else True 332 | row = row.cpu().numpy() if isinstance(row, torch.Tensor) else row 333 | # TODO: use torch.quantile and torch.bucketize 334 | 335 | if row.max() == 0: 336 | logger.warning( 337 | "The input data contains row of zeros. Please make sure this is expected." 
338 | ) 339 | return ( 340 | np.zeros_like(row, dtype=dtype) 341 | if return_np 342 | else torch.zeros_like(row, dtype=dtype) 343 | ) 344 | 345 | if row.min() <= 0: 346 | non_zero_ids = row.nonzero() 347 | non_zero_row = row[non_zero_ids] 348 | bins = np.quantile(non_zero_row, np.linspace(0, 1, n_bins - 1)) 349 | non_zero_digits = _digitize(non_zero_row, bins) 350 | binned_row = np.zeros_like(row, dtype=np.int64) 351 | binned_row[non_zero_ids] = non_zero_digits 352 | else: 353 | bins = np.quantile(row, np.linspace(0, 1, n_bins - 1)) 354 | binned_row = _digitize(row, bins) 355 | return torch.from_numpy(binned_row) if not return_np else binned_row.astype(dtype) 356 | -------------------------------------------------------------------------------- /model/transformer.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import math 3 | from typing import Dict, Mapping, Optional, Tuple, Any, Union 4 | 5 | import torch 6 | import numpy as np 7 | from torch import nn, Tensor 8 | import torch.distributed as dist 9 | import torch.nn.functional as F 10 | from torch.nn import TransformerEncoder, TransformerEncoderLayer 11 | from torch.distributions import Bernoulli 12 | from tqdm import trange 13 | from model.flashDiff import MultiheadFlashDiff 14 | 15 | try: 16 | from model.flashMHA import FlashMHA 17 | except ImportError: 18 | import warnings 19 | 20 | warnings.warn("flash_attn is not installed") 21 | flash_attn_available = False 22 | 23 | 24 | 25 | class TransformerModel(nn.Module): 26 | def __init__( 27 | self, 28 | d_model: int, 29 | nhead: int, 30 | d_hid: int, 31 | nlayers: int, 32 | dropout: float = 0.5, 33 | use_fast_transformer: bool = True, 34 | fast_transformer_backend: str = "flash", 35 | ): 36 | super().__init__() 37 | 38 | 39 | if use_fast_transformer: 40 | if fast_transformer_backend == "linear": 41 | self.transformer_encoder = FastTransformerEncoderWrapper( 42 | d_model, nhead, d_hid, nlayers, dropout 43 | ) 44 | elif fast_transformer_backend == "flash": 45 | encoder_layers = FlashTransformerEncoderLayer( 46 | d_model=d_model, 47 | nhead=nhead, 48 | dim_feedforward=d_hid, 49 | dropout=dropout, 50 | batch_first=True 51 | ) 52 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers, enable_nested_tensor= False) 53 | elif fast_transformer_backend == "diff": 54 | encoder_layers = DiffTransformerEncoderLayer( 55 | d_model=d_model, 56 | nhead=nhead, 57 | dim_feedforward=d_hid, 58 | dropout=dropout, 59 | batch_first=True 60 | ) 61 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers, enable_nested_tensor= False) 62 | else: 63 | encoder_layers = TransformerEncoderLayer( 64 | d_model, nhead, d_hid, dropout, batch_first=True 65 | ) 66 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) 67 | 68 | def forward( 69 | self, 70 | embs: Tensor, 71 | src_key_padding_mask: Optional[Tensor] = None, 72 | need_weights: Optional[bool] = False, 73 | ) -> Tensor: 74 | if need_weights: 75 | output, layer_weights = self.transformer_encoder(embs, src_key_padding_mask=src_key_padding_mask, need_weights=need_weights) 76 | return output, layer_weights 77 | else: 78 | output = self.transformer_encoder(embs, src_key_padding_mask=src_key_padding_mask) 79 | return output # (batch, seq_len, embsize) 80 | 81 | 82 | 83 | class FastTransformerEncoderWrapper(nn.Module): 84 | def __init__( 85 | self, 86 | d_model: int, 87 | nhead: int, 88 | d_hid: int, 89 | nlayers: int, 90 | dropout: float = 0.5, 91 | ): 92 | 
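        # Annotation: this wrapper backs the "linear" fast_transformer_backend option in
        # TransformerModel above. It lazily imports the optional fast_transformers
        # package inside build_fast_transformer_encoder, and converts the boolean
        # src_key_padding_mask into the LengthMask representation that library expects;
        # build_length_mask raises if padding tokens appear anywhere other than the end
        # of a sequence.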
super().__init__() 93 | self.fast_transformer_encoder = self.build_fast_transformer_encoder( 94 | d_model, nhead, d_hid, nlayers, dropout 95 | ) 96 | 97 | @staticmethod 98 | def build_fast_transformer_encoder( 99 | d_model: int, nhead: int, d_hid: int, nlayers: int, dropout: float 100 | ) -> nn.Module: 101 | from fast_transformers.builders import TransformerEncoderBuilder 102 | 103 | if d_model % nhead != 0: 104 | raise ValueError( 105 | f"d_model must be divisible by nhead, " 106 | f"got d_model={d_model} and nhead={nhead}" 107 | ) 108 | builder = TransformerEncoderBuilder.from_kwargs( 109 | n_layers=nlayers, 110 | n_heads=nhead, 111 | query_dimensions=d_model // nhead, 112 | value_dimensions=d_model // nhead, 113 | feed_forward_dimensions=d_hid, 114 | attention_type="linear", 115 | attention_dropout=dropout, 116 | dropout=dropout, 117 | activation="gelu", 118 | ) 119 | assert builder.attention_type == "linear" 120 | return builder.get() 121 | 122 | @staticmethod 123 | def build_length_mask( 124 | src: Tensor, 125 | src_key_padding_mask: torch.BoolTensor, 126 | ) -> "LengthMask": 127 | from fast_transformers.masking import LengthMask 128 | 129 | seq_len = src.shape[1] 130 | num_paddings = src_key_padding_mask.sum(dim=1) 131 | actual_seq_len = seq_len - num_paddings # (N,) 132 | length_mask = LengthMask(actual_seq_len, max_len=seq_len, device=src.device) 133 | 134 | if src_key_padding_mask[length_mask.bool_matrix].sum() != 0: 135 | raise ValueError( 136 | "Found padding tokens in the middle of the sequence. " 137 | "src_key_padding_mask and length_mask are not compatible." 138 | ) 139 | return length_mask 140 | 141 | def forward( 142 | self, 143 | src: Tensor, 144 | src_key_padding_mask: torch.BoolTensor, 145 | ) -> Tensor: 146 | """ 147 | Args: 148 | src: Tensor, shape [N, seq_len, embsize] 149 | src_key_padding_mask: Tensor, shape [N, seq_len] 150 | 151 | Returns: 152 | output Tensor of shape [N, seq_len, embsize] 153 | """ 154 | if src_key_padding_mask.shape != src.shape[:2]: 155 | raise ValueError( 156 | f"src_key_padding_mask shape {src_key_padding_mask.shape} " 157 | f"does not match first two dims of src shape {src.shape[:2]}" 158 | ) 159 | 160 | if src_key_padding_mask.dtype != torch.bool: 161 | raise ValueError( 162 | f"src_key_padding_mask needs to be of type torch.bool, " 163 | f"got {src_key_padding_mask.dtype}" 164 | ) 165 | 166 | length_mask = self.build_length_mask(src, src_key_padding_mask) 167 | output = self.fast_transformer_encoder(src, length_mask=length_mask) 168 | return output 169 | 170 | 171 | class FlashTransformerEncoderLayer(nn.Module): 172 | r"""TransformerEncoderLayer is made up of self-attn and feedforward network. 173 | The class is modified from torch.nn.TransformerEncoderLayer to support the 174 | FlashAttention. 175 | 176 | Args: 177 | d_model: the number of expected features in the input (required). 178 | nhead: the number of heads in the multiheadattention models (required). 179 | dim_feedforward: the dimension of the feedforward network model (default=2048). 180 | dropout: the dropout value (default=0.1). 181 | activation: the activation function of intermediate layer, relu or gelu (default=relu). 182 | layer_norm_eps: the eps value in layer normalization components (default=1e-5). 183 | batch_first: If ``True``, then the input and output tensors are provided 184 | as (batch, seq, feature). Default: ``False``. 
185 | 186 | Examples:: 187 | >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) 188 | >>> src = torch.rand(10, 32, 512) 189 | >>> out = encoder_layer(src) 190 | 191 | Alternatively, when ``batch_first`` is ``True``: 192 | >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True) 193 | >>> src = torch.rand(32, 10, 512) 194 | >>> out = encoder_layer(src) 195 | """ 196 | __constants__ = ["batch_first"] 197 | 198 | def __init__( 199 | self, 200 | d_model, 201 | nhead, 202 | dim_feedforward=2048, 203 | dropout=0.1, 204 | activation="relu", 205 | layer_norm_eps=1e-5, 206 | batch_first=True, 207 | device=None, 208 | dtype=None, 209 | norm_scheme="post", # "pre" or "post" 210 | ) -> None: 211 | factory_kwargs = {"device": device, "dtype": dtype} 212 | super().__init__() 213 | self.self_attn = FlashMHA( 214 | embed_dim=d_model, 215 | num_heads=nhead, 216 | batch_first=batch_first, 217 | attention_dropout=dropout, 218 | **factory_kwargs, 219 | ) 220 | # self.self_attn = FlashMHA( 221 | # embed_dim=d_model, 222 | # num_heads=nhead, 223 | # dropout=dropout, 224 | # use_flash_attn= True, 225 | # **factory_kwargs, 226 | # ) 227 | # Version compatibility workaround 228 | if not hasattr(self.self_attn, "batch_first"): 229 | self.self_attn.batch_first = batch_first 230 | # Implementation of Feedforward model 231 | self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) 232 | self.dropout = nn.Dropout(dropout) 233 | self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) 234 | 235 | self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) 236 | self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) 237 | self.dropout1 = nn.Dropout(dropout) 238 | self.dropout2 = nn.Dropout(dropout) 239 | 240 | self.activation = self._get_activation_fn(activation) 241 | self.norm_scheme = norm_scheme 242 | if self.norm_scheme not in ["pre", "post"]: 243 | raise ValueError(f"norm_scheme should be pre or post, not {norm_scheme}") 244 | 245 | @staticmethod 246 | def _get_activation_fn(activation): 247 | if activation == "relu": 248 | return F.relu 249 | elif activation == "gelu": 250 | return F.gelu 251 | 252 | raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) 253 | 254 | def __setstate__(self, state): 255 | if "activation" not in state: 256 | state["activation"] = F.relu 257 | super().__setstate__(state) 258 | 259 | def forward( 260 | self, 261 | src: Tensor, 262 | src_mask: Optional[Tensor] = None, 263 | src_key_padding_mask: Optional[Tensor] = None, 264 | need_weights: Optional[bool] = False, 265 | **kwargs, 266 | ) -> Tensor: 267 | r"""Pass the input through the encoder layer. 268 | 269 | Args: 270 | src: the sequence to the encoder layer (required). 271 | src_mask: the mask for the src sequence (optional). 272 | src_key_padding_mask: the mask for the src keys per batch (optional). 273 | 274 | Shape: 275 | see the docs in Transformer class. 
276 | """ 277 | if src_mask is not None: 278 | raise ValueError("FlashTransformerEncoderLayer does not support src_mask") 279 | 280 | if src_key_padding_mask is not None: 281 | if not src_key_padding_mask.any().item(): 282 | # no padding tokens in src 283 | src_key_padding_mask_ = None 284 | else: 285 | if src_key_padding_mask.dtype != torch.bool: 286 | src_key_padding_mask = src_key_padding_mask.bool() 287 | # NOTE: the FlashMHA uses mask 0 for padding tokens, which is the opposite 288 | src_key_padding_mask_ = ~src_key_padding_mask 289 | else: 290 | src_key_padding_mask_ = None 291 | 292 | if self.norm_scheme == "pre": 293 | src = self.norm1(src) 294 | src2, atten_weight = self.self_attn(src, key_padding_mask=src_key_padding_mask_, need_weights = need_weights) 295 | 296 | src = src + self.dropout1(src2) 297 | src = self.norm2(src) 298 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 299 | src = src + self.dropout2(src2) 300 | else: 301 | src2, atten_weight = self.self_attn(src, key_padding_mask=src_key_padding_mask_, need_weights = need_weights) 302 | src = src + self.dropout1(src2) 303 | src = self.norm1(src) 304 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 305 | src = src + self.dropout2(src2) 306 | src = self.norm2(src) 307 | 308 | if need_weights: 309 | return src, atten_weight 310 | else: 311 | return src 312 | 313 | 314 | class DiffTransformerEncoderLayer(nn.Module): 315 | __constants__ = ["batch_first"] 316 | 317 | def __init__( 318 | self, 319 | d_model, 320 | nhead, 321 | dim_feedforward=2048, 322 | dropout=0.1, 323 | lambda_init = 0.8, 324 | activation="relu", 325 | layer_norm_eps=1e-5, 326 | batch_first=True, 327 | device=None, 328 | dtype=None, 329 | norm_scheme="post", # "pre" or "post" 330 | ) -> None: 331 | factory_kwargs = {"device": device, "dtype": dtype} 332 | super().__init__() 333 | self.self_attn = MultiheadFlashDiff( 334 | embed_dim=d_model, 335 | num_heads=nhead, 336 | attention_dropout=dropout, 337 | lambda_init = lambda_init, 338 | ) 339 | # self.self_attn = FlashMHA( 340 | # embed_dim=d_model, 341 | # num_heads=nhead, 342 | # dropout=dropout, 343 | # use_flash_attn= True, 344 | # **factory_kwargs, 345 | # ) 346 | # Version compatibility workaround 347 | if not hasattr(self.self_attn, "batch_first"): 348 | self.self_attn.batch_first = batch_first 349 | # Implementation of Feedforward model 350 | self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) 351 | self.dropout = nn.Dropout(dropout) 352 | self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) 353 | 354 | self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) 355 | self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) 356 | self.dropout1 = nn.Dropout(dropout) 357 | self.dropout2 = nn.Dropout(dropout) 358 | 359 | self.activation = self._get_activation_fn(activation) 360 | self.norm_scheme = norm_scheme 361 | if self.norm_scheme not in ["pre", "post"]: 362 | raise ValueError(f"norm_scheme should be pre or post, not {norm_scheme}") 363 | 364 | @staticmethod 365 | def _get_activation_fn(activation): 366 | if activation == "relu": 367 | return F.relu 368 | elif activation == "gelu": 369 | return F.gelu 370 | 371 | raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) 372 | 373 | def __setstate__(self, state): 374 | if "activation" not in state: 375 | state["activation"] = F.relu 376 | super().__setstate__(state) 377 | 378 | def forward( 379 | self, 380 | src: Tensor, 381 | 
src_mask: Optional[Tensor] = None, 382 | src_key_padding_mask: Optional[Tensor] = None, 383 | need_weights: Optional[bool] = False, 384 | **kwargs, 385 | ) -> Tensor: 386 | r"""Pass the input through the encoder layer. 387 | 388 | Args: 389 | src: the sequence to the encoder layer (required). 390 | src_mask: the mask for the src sequence (optional). 391 | src_key_padding_mask: the mask for the src keys per batch (optional). 392 | 393 | Shape: 394 | see the docs in Transformer class. 395 | """ 396 | if src_mask is not None: 397 | raise ValueError("FlashTransformerEncoderLayer does not support src_mask") 398 | 399 | if src_key_padding_mask is not None: 400 | if not src_key_padding_mask.any().item(): 401 | # no padding tokens in src 402 | src_key_padding_mask_ = None 403 | else: 404 | if src_key_padding_mask.dtype != torch.bool: 405 | src_key_padding_mask = src_key_padding_mask.bool() 406 | # NOTE: the FlashMHA uses mask 0 for padding tokens, which is the opposite 407 | src_key_padding_mask_ = ~src_key_padding_mask 408 | else: 409 | src_key_padding_mask_ = None 410 | 411 | if self.norm_scheme == "pre": 412 | src = self.norm1(src) 413 | src2 = self.self_attn(src) 414 | 415 | src = src + self.dropout1(src2) 416 | src = self.norm2(src) 417 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 418 | src = src + self.dropout2(src2) 419 | else: 420 | src2 = self.self_attn(src) 421 | src = src + self.dropout1(src2) 422 | src = self.norm1(src) 423 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 424 | src = src + self.dropout2(src2) 425 | src = self.norm2(src) 426 | 427 | if need_weights: 428 | return src, atten_weight 429 | else: 430 | return src -------------------------------------------------------------------------------- /data/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from torch.utils.data import Dataset, DataLoader 4 | from torchtext.vocab import Vocab 5 | from scanpy import AnnData 6 | from typing import Dict, List, Tuple, Union 7 | import numpy as np 8 | from memory_profiler import profile 9 | from tokenizer.gene_tokenizer import GeneVocab 10 | from tqdm import tqdm 11 | import psutil 12 | import scanpy as sc 13 | import json 14 | 15 | 16 | class ATACDataset(Dataset): 17 | def __init__(self, 18 | atac_file, 19 | atac_key: str, 20 | atac_vocab: GeneVocab, 21 | cell_vocab: GeneVocab, 22 | batch_vocab: GeneVocab, 23 | chr_vocab: GeneVocab, 24 | atac_max_len: int, 25 | pad_token: str, 26 | cls_token: str, 27 | reg_token: str = None, 28 | logger = None, 29 | ): 30 | 31 | self.atac_key = atac_key if atac_key != "X" else None 32 | self.atac_max_len = atac_max_len 33 | self.logger = logger 34 | self.pad_token = pad_token 35 | self.cls_token = cls_token 36 | 37 | # read raw data 38 | self.rna_raw = {} 39 | self.atac_raw = {} 40 | 41 | 42 | self.log(f"load atac data from {atac_file}") 43 | atac_raw_data = sc.read_h5ad(atac_file, backed='r') 44 | self.atac_raw['values'] = atac_raw_data.layers[self.atac_key] if self.atac_key is not None else atac_raw_data.X 45 | try: 46 | gene_names = atac_raw_data.var['features'].tolist() 47 | except: 48 | gene_names = atac_raw_data.var_names.tolist() 49 | # get atac chr, by spliting the gene name using '-', and get the first element 50 | gene_chr = [gene.split('-')[0] for gene in gene_names] 51 | 52 | self.atac_raw['chr_ids'] = np.array(chr_vocab(gene_chr)) 53 | self.atac_raw['gene_ids'] = np.array(atac_vocab(gene_names)) 54 | 
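        # Annotation: peak identifiers are assumed to follow a "<chrom>-<start>-<end>"
        # naming style (e.g. "chr1-100000-100500" as a hypothetical example), so taking
        # the first '-'-separated field recovers the chromosome for chr_vocab. Peak names
        # and derived chromosomes are mapped to integer ids once here, so prepare_row_atac
        # only has to index these arrays with each cell's non-zero peak positions.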
self.atac_raw['pad_id'] = atac_vocab[pad_token] 55 | self.log(f"atac_raw: {self.atac_raw['values'].shape}") 56 | 57 | self.atac_cls_id= atac_vocab[cls_token] 58 | if reg_token is not None: 59 | self.atac_reg_id = atac_vocab[reg_token] 60 | self.chr_reg_id = chr_vocab[reg_token] 61 | else: 62 | self.atac_reg_id = None 63 | self.chr_reg_id = None 64 | self.chr_cls_id = chr_vocab[cls_token] 65 | self.chr_pad_id = chr_vocab[pad_token] 66 | 67 | self.cell_ids = cell_vocab(atac_raw_data.obs['annot'].tolist()) 68 | self.batch_ids = batch_vocab(atac_raw_data.obs['batch'].tolist()) 69 | return 70 | 71 | 72 | def __len__(self): 73 | return len(self.cell_ids) 74 | 75 | def log(self, msg): 76 | if int(os.environ["LOCAL_RANK"]) == 0: 77 | self.logger.info(msg) 78 | 79 | 80 | def prepare_row_atac(self, row): 81 | non_zero_idx = row.indices 82 | peaks = self.atac_raw['gene_ids'][non_zero_idx] 83 | chrs = self.atac_raw['chr_ids'][non_zero_idx] 84 | 85 | if self.atac_reg_id is not None: 86 | peaks = np.insert(peaks, 0, self.atac_reg_id) 87 | peaks = np.insert(peaks, 0, self.atac_cls_id) 88 | peaks = torch.from_numpy(peaks).long() 89 | 90 | if self.chr_reg_id is not None: 91 | chrs = np.insert(chrs, 0, self.chr_reg_id) 92 | chrs = np.insert(chrs, 0, self.chr_cls_id) 93 | chrs = torch.from_numpy(chrs).long() 94 | 95 | num_special_tokens = 1 if self.atac_reg_id is None else 2 96 | 97 | if len(peaks) > self.atac_max_len: 98 | idx = np.random.choice(len(peaks) - num_special_tokens, self.atac_max_len - num_special_tokens, replace=False) 99 | idx = idx + num_special_tokens 100 | for i in range(num_special_tokens): 101 | idx = np.insert(idx, i, i) 102 | peaks = peaks[idx] 103 | chrs = chrs[idx] 104 | elif len(peaks) <= self.atac_max_len: 105 | peaks = torch.cat( 106 | [ 107 | peaks, 108 | torch.full( 109 | (self.atac_max_len - len(peaks),), self.atac_raw['pad_id'], dtype=peaks.dtype 110 | ), 111 | ] 112 | ) 113 | chrs = torch.cat( 114 | [ 115 | chrs, 116 | torch.full( 117 | (self.atac_max_len - len(chrs),), self.chr_pad_id, dtype=chrs.dtype 118 | ), 119 | ] 120 | ) 121 | return peaks, chrs 122 | 123 | def __getitem__(self, idx): 124 | atac_row = self.atac_raw['values'][idx] 125 | atac_ids, atac_chrs = self.prepare_row_atac(atac_row) 126 | 127 | cell_id = torch.tensor(self.cell_ids[idx]).long() 128 | batch_id = torch.tensor(self.batch_ids[idx]).long() 129 | 130 | return { 131 | "atac_ids": atac_ids, 132 | "atac_chrs": atac_chrs, 133 | "cell_ids": cell_id, 134 | "batch_ids": batch_id, 135 | } 136 | 137 | 138 | class PairedSCDataset(Dataset): 139 | def __init__(self, 140 | rna_file, 141 | atac_file, 142 | rna_key: str, 143 | atac_key: str, 144 | rna_vocab: GeneVocab, 145 | atac_vocab: GeneVocab, 146 | cell_vocab: GeneVocab, 147 | batch_vocab: GeneVocab, 148 | chr_vocab: GeneVocab, 149 | gene2chr_file: str, 150 | rna_max_len: int, 151 | atac_max_len: int, 152 | pad_token: str, 153 | rna_pad_value: int, 154 | cls_token: str, 155 | reg_token: str = None, 156 | logger = None, 157 | get_full_genes: bool = False, 158 | ): 159 | self.rna_key = rna_key if rna_key != "X" else None 160 | self.atac_key = atac_key if atac_key != "X" else None 161 | self.rna_max_len = rna_max_len 162 | self.atac_max_len = atac_max_len 163 | self.logger = logger 164 | self.pad_token = pad_token 165 | self.rna_pad_value = rna_pad_value 166 | self.cls_token = cls_token 167 | 168 | # read raw data 169 | self.rna_raw = {} 170 | self.atac_raw = {} 171 | 172 | self.get_full_genes = get_full_genes 173 | 174 | 175 | 176 | self.log(f"load rna data from 
{rna_file}") 177 | rna_raw_data = sc.read_h5ad(rna_file, backed='r') 178 | self.rna_raw['values'] = rna_raw_data.layers[self.rna_key] if self.rna_key is not None else rna_raw_data.X 179 | try: 180 | gene_names = rna_raw_data.var['features'].tolist() 181 | except: 182 | gene_names = rna_raw_data.var_names.tolist() 183 | 184 | # read gene2chr file as a dict 185 | with open(gene2chr_file, 'r') as file: 186 | gene2chr = json.load(file) 187 | gene_chr = [gene2chr[gene] for gene in gene_names] 188 | 189 | self.rna_raw['chr_ids'] = np.array(chr_vocab(gene_chr)) 190 | self.rna_raw['gene_ids'] = np.array(rna_vocab(gene_names)) 191 | self.rna_raw['pad_id'] = rna_vocab[pad_token] 192 | self.log(f"rna_raw: {self.rna_raw['values'].shape}") 193 | 194 | 195 | 196 | self.log(f"load atac data from {atac_file}") 197 | atac_raw_data = sc.read_h5ad(atac_file, backed='r') 198 | self.atac_raw['values'] = atac_raw_data.layers[self.atac_key] if self.atac_key is not None else atac_raw_data.X 199 | try: 200 | gene_names = atac_raw_data.var['features'].tolist() 201 | except: 202 | gene_names = atac_raw_data.var_names.tolist() 203 | # get atac chr, by spliting the gene name using '-', and get the first element 204 | gene_chr = [gene.split('-')[0] for gene in gene_names] 205 | 206 | self.atac_raw['chr_ids'] = np.array(chr_vocab(gene_chr)) 207 | self.atac_raw['gene_ids'] = np.array(atac_vocab(gene_names)) 208 | self.atac_raw['pad_id'] = atac_vocab[pad_token] 209 | self.log(f"atac_raw: {self.atac_raw['values'].shape}") 210 | 211 | self.atac_cls_id= atac_vocab[cls_token] 212 | if reg_token is not None: 213 | self.atac_reg_id = atac_vocab[reg_token] 214 | self.chr_reg_id = chr_vocab[reg_token] 215 | else: 216 | self.atac_reg_id = None 217 | self.chr_reg_id = None 218 | self.chr_cls_id = chr_vocab[cls_token] 219 | self.chr_pad_id = chr_vocab[pad_token] 220 | 221 | rna_cell_idx = rna_raw_data.obs_names.tolist() 222 | atac_cell_idx = atac_raw_data.obs_names.tolist() 223 | 224 | assert (rna_cell_idx == atac_cell_idx) 225 | self.cell_ids = cell_vocab(atac_raw_data.obs['annot'].tolist()) 226 | self.batch_ids = batch_vocab(atac_raw_data.obs['batch'].tolist()) 227 | 228 | return 229 | 230 | 231 | def __len__(self): 232 | return len(self.cell_ids) 233 | 234 | def log(self, msg): 235 | if int(os.environ["LOCAL_RANK"]) == 0: 236 | self.logger.info(msg) 237 | 238 | def prepare_row_rna_full(self, row): 239 | if not isinstance(row, np.ndarray): 240 | row = row.toarray().flatten() 241 | genes = self.rna_raw['gene_ids'] 242 | values = row 243 | chrs = self.rna_raw['chr_ids'] 244 | 245 | genes = torch.from_numpy(genes).long() 246 | values = torch.from_numpy(values).long() 247 | chrs = torch.from_numpy(chrs).long() 248 | 249 | return genes, values, chrs 250 | 251 | 252 | def prepare_row_rna(self, row, non_zero_prob=0.5): 253 | # if row is not a numpy array, convert it to numpy array 254 | if not isinstance(row, np.ndarray): 255 | row = row.toarray().flatten() 256 | non_zero_idx = np.nonzero(row)[0] 257 | zero_idx = np.nonzero(row == 0)[0] 258 | 259 | total_number = min(self.rna_max_len, int(len(non_zero_idx) / non_zero_prob)) 260 | non_zero_number = int(total_number * non_zero_prob) 261 | zero_number = total_number - non_zero_number 262 | 263 | non_zero_idx = np.random.choice(non_zero_idx, non_zero_number, replace=False) 264 | zero_idx = np.random.choice(zero_idx, zero_number, replace=False) 265 | 266 | non_zero_genes = self.rna_raw['gene_ids'][non_zero_idx] 267 | zero_genes = self.rna_raw['gene_ids'][zero_idx] 268 | 269 | chr_non_zero = 
self.rna_raw['chr_ids'][non_zero_idx] 270 | chr_zero = self.rna_raw['chr_ids'][zero_idx] 271 | 272 | genes = np.concatenate([non_zero_genes, zero_genes]) 273 | values = np.concatenate([row[non_zero_idx], row[zero_idx]]) 274 | chrs = np.concatenate([chr_non_zero, chr_zero]) 275 | 276 | # shuffle the genes 277 | idx = np.random.permutation(len(genes)) 278 | genes = genes[idx] 279 | values = values[idx] 280 | chrs = chrs[idx] 281 | 282 | genes = torch.from_numpy(genes).long() 283 | values = torch.from_numpy(values).long() 284 | chrs = torch.from_numpy(chrs).long() 285 | 286 | if len(genes) < self.rna_max_len: 287 | genes = torch.cat( 288 | [ 289 | genes, 290 | torch.full( 291 | (self.rna_max_len - len(genes),), self.rna_raw['pad_id'], dtype=genes.dtype 292 | ), 293 | ] 294 | ) 295 | values = torch.cat( 296 | [ 297 | values, 298 | torch.full((self.rna_max_len - len(values),), self.rna_pad_value, dtype=values.dtype), 299 | ] 300 | ) 301 | chrs = torch.cat( 302 | [ 303 | chrs, 304 | torch.full((self.rna_max_len - len(chrs),), self.chr_pad_id, dtype=chrs.dtype), 305 | ] 306 | ) 307 | return genes, values, chrs 308 | 309 | def prepare_row_rna_binary(self, row): 310 | non_zero_idx = np.nonzero(row)[0] 311 | zero_idx = np.nonzero(row == 0)[0] 312 | 313 | if len(non_zero_idx) > self.rna_max_len // 2: 314 | non_zero_number = self.rna_max_len // 2 315 | zero_number = self.rna_max_len - non_zero_number 316 | else: 317 | non_zero_number = len(non_zero_idx) 318 | zero_number = non_zero_number 319 | 320 | non_zero_idx = np.random.choice(non_zero_idx, non_zero_number, replace=False) 321 | zero_idx = np.random.choice(zero_idx, zero_number, replace=False) 322 | 323 | non_zero_genes = self.rna_raw['gene_ids'][non_zero_idx] 324 | zero_genes = self.rna_raw['gene_ids'][zero_idx] 325 | 326 | chr_non_zero = self.rna_raw['chr_ids'][non_zero_idx] 327 | chr_zero = self.rna_raw['chr_ids'][zero_idx] 328 | 329 | genes = np.concatenate([non_zero_genes, zero_genes]) 330 | values = np.concatenate([np.ones(non_zero_number), np.zeros(zero_number)]) 331 | chrs = np.concatenate([chr_non_zero, chr_zero]) 332 | 333 | # shuffle the genes 334 | idx = np.random.permutation(len(genes)) 335 | genes = genes[idx] 336 | values = values[idx] 337 | chrs = chrs[idx] 338 | 339 | genes = torch.from_numpy(genes).long() 340 | values = torch.from_numpy(values).long() 341 | chrs = torch.from_numpy(chrs).long() 342 | 343 | if len(genes) < self.rna_max_len: 344 | genes = torch.cat( 345 | [ 346 | genes, 347 | torch.full( 348 | (self.rna_max_len - len(genes),), self.rna_raw['pad_id'], dtype=genes.dtype 349 | ), 350 | ] 351 | ) 352 | values = torch.cat( 353 | [ 354 | values, 355 | torch.full((self.rna_max_len - len(values),), self.rna_pad_value, dtype=values.dtype), 356 | ] 357 | ) 358 | chrs = torch.cat( 359 | [ 360 | chrs, 361 | torch.full((self.rna_max_len - len(chrs),), self.chr_pad_id, dtype=chrs.dtype), 362 | ] 363 | ) 364 | return genes, values, chrs 365 | 366 | def prepare_row_atac(self, row): 367 | non_zero_idx = row.indices 368 | peaks = self.atac_raw['gene_ids'][non_zero_idx] 369 | chrs = self.atac_raw['chr_ids'][non_zero_idx] 370 | 371 | if self.atac_reg_id is not None: 372 | peaks = np.insert(peaks, 0, self.atac_reg_id) 373 | peaks = np.insert(peaks, 0, self.atac_cls_id) 374 | peaks = torch.from_numpy(peaks).long() 375 | 376 | if self.chr_reg_id is not None: 377 | chrs = np.insert(chrs, 0, self.chr_reg_id) 378 | chrs = np.insert(chrs, 0, self.chr_cls_id) 379 | chrs = torch.from_numpy(chrs).long() 380 | 381 | num_special_tokens = 1 if 
self.atac_reg_id is None else 2 382 | 383 | if len(peaks) > self.atac_max_len: 384 | idx = np.random.choice(len(peaks) - num_special_tokens, self.atac_max_len - num_special_tokens, replace=False) 385 | idx = idx + num_special_tokens 386 | for i in range(num_special_tokens): 387 | idx = np.insert(idx, i, i) 388 | peaks = peaks[idx] 389 | chrs = chrs[idx] 390 | elif len(peaks) <= self.atac_max_len: 391 | peaks = torch.cat( 392 | [ 393 | peaks, 394 | torch.full( 395 | (self.atac_max_len - len(peaks),), self.atac_raw['pad_id'], dtype=peaks.dtype 396 | ), 397 | ] 398 | ) 399 | chrs = torch.cat( 400 | [ 401 | chrs, 402 | torch.full( 403 | (self.atac_max_len - len(chrs),), self.chr_pad_id, dtype=chrs.dtype 404 | ), 405 | ] 406 | ) 407 | return peaks, chrs 408 | 409 | def __getitem__(self, idx): 410 | 411 | rna_row = self.rna_raw['values'][idx] 412 | if self.get_full_genes == True: 413 | rna_ids, rna_values, rna_chrs = self.prepare_row_rna_full(rna_row) 414 | else: 415 | rna_ids, rna_values, rna_chrs = self.prepare_row_rna(rna_row) 416 | atac_row = self.atac_raw['values'][idx] 417 | atac_ids, atac_chrs = self.prepare_row_atac(atac_row) 418 | 419 | cell_id = torch.tensor(self.cell_ids[idx]).long() 420 | batch_id = torch.tensor(self.batch_ids[idx]).long() 421 | 422 | return { 423 | "rna_ids": rna_ids, 424 | "rna_values": rna_values, 425 | "rna_chrs": rna_chrs, 426 | "atac_ids": atac_ids, 427 | "atac_chrs": atac_chrs, 428 | "cell_ids": cell_id, 429 | "batch_ids": batch_id, 430 | } 431 | -------------------------------------------------------------------------------- /pretrain_ddp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import gc 5 | import argparse 6 | import json 7 | import random 8 | import math 9 | import random 10 | from functools import reduce 11 | import numpy as np 12 | import pandas as pd 13 | from scipy import sparse 14 | from sklearn.model_selection import train_test_split 15 | import torch 16 | from torch import nn 17 | from torch.optim import Adam 18 | from torch.nn import functional as F 19 | from tensorboardX import SummaryWriter 20 | from torch.utils.data import DataLoader, Dataset 21 | from torch.utils.data.distributed import DistributedSampler 22 | from torch.nn.parallel import DistributedDataParallel as DDP 23 | import torch.distributed as dist 24 | from model import EpiFoundation 25 | from loss.loss import MaskedMSELoss 26 | from data.dataloader import * 27 | from tokenizer import GeneVocab 28 | import scanpy as sc 29 | import anndata as ad 30 | from utils import * 31 | from memory_profiler import profile 32 | 33 | import yaml 34 | 35 | torch.autograd.set_detect_anomaly(True) 36 | 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument("--config", type=str, default='./configs/pretrain/atac_cross_debug.yml', help='Config file.') 39 | args = parser.parse_args() 40 | 41 | def main(): 42 | # read and parse config file 43 | local_rank = int(os.environ["LOCAL_RANK"]) 44 | with open(args.config, 'r') as f: 45 | config = yaml.load(f, Loader=yaml.FullLoader) 46 | 47 | 48 | train_config = config['train'] 49 | valid_config = config['valid'] 50 | data_config = config['data'] 51 | vocab_config = config['vocab'] 52 | task_name = config['task_name'] 53 | task_floder = './experiment/{}'.format(task_name) 54 | ckpt_dir = os.path.join(task_floder, 'ckpts') 55 | 56 | 57 | 58 | random_seed = train_config['seed'] 59 | EPOCHS = train_config['epochs'] 60 | BATCH_SIZE = train_config['batch_size'] 61 | 
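    # Annotation: LOCAL_RANK (read above) is expected to be set by the distributed
    # launcher, so the script is typically started with torchrun; an illustrative
    # command (GPU count and config path are examples, adjust to your setup):
    #   torchrun --nproc_per_node=4 pretrain_ddp.py --config ./configs/pretrain/atac_cross_binary.yml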
GRADIENT_ACCUMULATION = train_config['gradient_accumulation_steps'] 62 | LEARNING_RATE = float(train_config['lr']) 63 | 64 | model_name = train_config['model']['encoder'] 65 | 66 | save_ckpt_freq = train_config['save_ckpt_freq'] if 'save_ckpt_freq' in train_config else 5 67 | resume = train_config['resume'] if 'resume' in train_config else False 68 | 69 | # special tokens 70 | pad = vocab_config['special_tokens']['pad'] 71 | mask = vocab_config['special_tokens']['mask'] 72 | cls = vocab_config['special_tokens']['cls'] 73 | # reg = vocab_config['special_tokens']['reg'] 74 | 75 | # distibuted setting 76 | dist.init_process_group(backend='nccl') 77 | torch.cuda.set_device(local_rank) 78 | device = torch.device("cuda", local_rank) 79 | world_size = torch.distributed.get_world_size() 80 | seed_all(random_seed + torch.distributed.get_rank()) 81 | is_master = (local_rank == 0) 82 | 83 | # init loggers 84 | logger = set_log(log_dir= os.path.join(task_floder, 'logs')) 85 | tb_logger = SummaryWriter(os.path.join(task_floder, 'tb_logs')) 86 | if is_master: 87 | logger.info(dict2str(config)) 88 | 89 | 90 | rna_vocab = GeneVocab.from_file(vocab_config['rna_path']) 91 | atac_vocab = GeneVocab.from_file(vocab_config['atac_path']) 92 | cell_vocab = GeneVocab.from_file(vocab_config['cell_type_path']) 93 | batch_vocab = GeneVocab.from_file(vocab_config['batch_path']) 94 | chr_vocab = GeneVocab.from_file(vocab_config['chr_path']) 95 | 96 | if is_master: 97 | logger.info(f'Rna vocab size: {len(rna_vocab)}') 98 | logger.info(f'Atac vocab size: {len(atac_vocab)}') 99 | 100 | if is_master: 101 | logger.info('loading training data') 102 | 103 | train_set = PairedSCDataset( 104 | rna_file = data_config['train']['rna_path'], 105 | atac_file= data_config['train']['atac_path'], 106 | rna_key = data_config['train']['rna_key'], 107 | atac_key = data_config['train']['atac_key'], 108 | rna_vocab = rna_vocab, 109 | atac_vocab = atac_vocab, 110 | cell_vocab = cell_vocab, 111 | batch_vocab= batch_vocab, 112 | chr_vocab = chr_vocab, 113 | gene2chr_file= vocab_config['gene2chr_path'], 114 | rna_max_len = train_config['model']['rna_max_len'], 115 | atac_max_len = train_config['model']['atac_max_len'], 116 | pad_token = pad['token'], 117 | rna_pad_value = pad['value'], 118 | cls_token = cls['token'], 119 | # reg_token= reg['token'], 120 | logger = logger, 121 | ) 122 | 123 | gc.collect() 124 | train_sampler = DistributedSampler(train_set) 125 | train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, sampler=train_sampler, prefetch_factor=4, num_workers=4) 126 | 127 | 128 | if is_master: 129 | logger.info('loading validation data') 130 | val_set = PairedSCDataset( 131 | rna_file = data_config['test']['rna_path'], 132 | atac_file= data_config['test']['atac_path'], 133 | rna_key = data_config['test']['rna_key'], 134 | atac_key = data_config['test']['atac_key'], 135 | rna_vocab = rna_vocab, 136 | atac_vocab = atac_vocab, 137 | cell_vocab = cell_vocab, 138 | batch_vocab= batch_vocab, 139 | chr_vocab = chr_vocab, 140 | gene2chr_file= vocab_config['gene2chr_path'], 141 | rna_max_len = train_config['model']['rna_max_len'], 142 | atac_max_len = train_config['model']['atac_max_len'], 143 | pad_token = pad['token'], 144 | rna_pad_value = pad['value'], 145 | cls_token = cls['token'], 146 | # reg_token= reg['token'], 147 | logger = logger, 148 | ) 149 | gc.collect() 150 | 151 | val_sampler = SequentialDistributedSampler(val_set, batch_size=BATCH_SIZE, world_size=world_size) 152 | val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, 
sampler=val_sampler, prefetch_factor=4, num_workers=4) 153 | 154 | if is_master: 155 | logger.info('Creating model') 156 | 157 | model = EpiFoundation( 158 | num_class_cell = len(cell_vocab), 159 | num_rnas = len(rna_vocab), 160 | num_atacs = len(atac_vocab), 161 | num_values= data_config['bin_num'], 162 | num_chrs= len(chr_vocab), 163 | embed_dim = train_config['model']['embedding_dim'], 164 | depth = train_config['model']['num_layers'], 165 | heads = train_config['model']['head_num'], 166 | head_dim = train_config['model']['head_dim'], 167 | encoder = model_name, 168 | dropout = train_config['model']['dropout'], 169 | pad_token_idx_rna = rna_vocab[pad['token']], 170 | pad_token_idx_atac = atac_vocab[pad['token']], 171 | cell_emb_style = train_config['model']['cell_emb_style'], 172 | mvc_arch_style = train_config['model']['mvc_arch_style'], 173 | use_batch_labels = train_config['model']['use_batch_labels'], 174 | batch_label_num= len(batch_vocab), 175 | use_chr_labels= train_config['model']['use_chr_labels'], 176 | ).to(device) 177 | 178 | # optimizer 179 | optimizer = Adam(model.parameters(), lr=LEARNING_RATE) 180 | 181 | # learning rate scheduler 182 | scheduler = CosineAnnealingWarmupRestarts( 183 | optimizer, 184 | first_cycle_steps=15, 185 | cycle_mult=2, 186 | max_lr=LEARNING_RATE, 187 | min_lr=1e-6, 188 | warmup_steps=5, 189 | gamma=0.9 190 | ) 191 | 192 | start_epoch = 1 193 | model = DDP(model, device_ids=[local_rank], output_device=local_rank) 194 | 195 | # scaler = torch.amp.GradScaler(enabled=train_config['amp'].amp) 196 | scaler = torch.cuda.amp.GradScaler(enabled=train_config['amp']) 197 | 198 | # masked_mse_loss = MaskedMSELoss().to(local_rank) 199 | cross_entropy_loss = nn.CrossEntropyLoss(reduction='mean').to(local_rank) 200 | if data_config['bin_num'] > 1: 201 | mvc_loss_fn = nn.CrossEntropyLoss(reduction='mean', ignore_index = pad['value']).to(device) 202 | else: 203 | mvc_loss_fn = nn.MSELoss(reduction='sum').to(device) 204 | mvc_weight = train_config['task_weight']['mvc'] 205 | cell_type_weight = train_config['task_weight']['cell_type'] 206 | 207 | softmax = nn.Softmax(dim=-1) 208 | 209 | steps = 0 210 | if train_config['model']['pretrained'] is not None: 211 | if is_master: 212 | logger.info('Loading pretrained model from: {}'.format(train_config['model']['pretrained'])) 213 | checkpoint = torch.load(train_config['model']['pretrained'], map_location=device) 214 | model.module.load_state_dict(checkpoint['model']) 215 | optimizer.load_state_dict(checkpoint['optimizer']) 216 | scheduler.load_state_dict(checkpoint['scheduler']) 217 | scaler.load_state_dict(checkpoint['scaler']) 218 | if resume: 219 | start_epoch = checkpoint['epoch'] + 1 220 | steps = checkpoint['steps'] 221 | del checkpoint 222 | gc.collect() 223 | 224 | dist.barrier() 225 | if is_master: 226 | logger.info('Start finetuning from epoch: {}, steps: {}'.format(start_epoch, steps)) 227 | for i in range(start_epoch, start_epoch + EPOCHS): 228 | train_loader.sampler.set_epoch(i) 229 | 230 | if is_master: 231 | logger.info('Training with {} samples, steps: {}'.format(len(train_loader.dataset), len(train_loader))) 232 | model.train() 233 | dist.barrier() 234 | running_loss = {'mvc': 0.0, 'cell': 0.0, 'total': 0.0} 235 | cum_acc_cell = 0.0 236 | cum_acc_value = 0.0 237 | for index, batch in enumerate(train_loader): 238 | index += 1 239 | steps += 1 240 | rna_values = batch['rna_values'].to(device) 241 | rna_ids = batch['rna_ids'].to(device) 242 | atac_ids = batch['atac_ids'].to(device) 243 | cell_ids = 
batch['cell_ids'].to(device) 244 | batch_ids = batch['batch_ids'].to(device) 245 | rna_chrs = batch['rna_chrs'].to(device) 246 | atac_chrs = batch['atac_chrs'].to(device) 247 | padding_positions = atac_ids.eq(atac_vocab[pad['token']]) 248 | if index % GRADIENT_ACCUMULATION != 0 and index != len(train_loader): 249 | with model.no_sync(): 250 | with torch.cuda.amp.autocast(enabled=train_config['amp'], dtype= torch.bfloat16): 251 | # finetue using all expression values, do not mask 252 | output = model(atac = atac_ids, rna = rna_ids, src_key_padding_mask = padding_positions, batch_id = batch_ids, rna_chrs = rna_chrs, atac_chrs = atac_chrs) 253 | 254 | mvc_loss = mvc_loss_fn(output['mvc_pred'].transpose(1, 2), rna_values) * mvc_weight 255 | cell_loss = cross_entropy_loss(output['cell_pred'], cell_ids) * cell_type_weight 256 | loss = mvc_loss + cell_loss 257 | 258 | running_loss['mvc'] += mvc_loss.item() 259 | running_loss['cell'] += cell_loss.item() 260 | running_loss['total'] += loss.item() 261 | 262 | loss = loss / GRADIENT_ACCUMULATION 263 | scaler.scale(loss).backward() 264 | else: 265 | with torch.cuda.amp.autocast(enabled=train_config['amp'], dtype= torch.bfloat16): 266 | output = model(atac = atac_ids, rna = rna_ids, src_key_padding_mask = padding_positions, batch_id = batch_ids, rna_chrs = rna_chrs, atac_chrs = atac_chrs) 267 | 268 | mvc_loss = mvc_loss_fn(output['mvc_pred'].transpose(1, 2), rna_values) * mvc_weight 269 | cell_loss = cross_entropy_loss(output['cell_pred'], cell_ids) * cell_type_weight 270 | loss = mvc_loss + cell_loss 271 | 272 | running_loss['mvc'] += mvc_loss.item() 273 | running_loss['cell'] += cell_loss.item() 274 | running_loss['total'] += loss.item() 275 | if is_master: 276 | tb_logger.add_scalar('train/mvc_loss', mvc_loss.item(), steps) 277 | tb_logger.add_scalar('train/cell_loss', cell_loss.item(), steps) 278 | tb_logger.add_scalar('train/total_loss', loss.item(), steps) 279 | logger.info(f'Epoch: {i} | Step: {index} | MVC Loss: {mvc_loss:.4f} | Cell Type Loss: {cell_loss:.4f} | Total Loss: {loss:.4f}') 280 | loss = loss / GRADIENT_ACCUMULATION 281 | scaler.scale(loss).backward() 282 | scaler.unscale_(optimizer) 283 | torch.nn.utils.clip_grad_norm_(model.parameters(), int(1e2)) 284 | scaler.step(optimizer) 285 | scaler.update() 286 | optimizer.zero_grad() 287 | # cell type accuracy 288 | type_pred = softmax(output['cell_pred']) 289 | type_pred = type_pred.argmax(dim=-1) 290 | cum_acc_cell += (type_pred.eq(cell_ids)).sum().item() / len(cell_ids) 291 | 292 | value_pred = softmax(output['mvc_pred']).argmax(dim=-1) 293 | # expression value accuracy 294 | non_pad_idx = rna_values.ne(pad['value']) 295 | non_pad_pred = value_pred[non_pad_idx] 296 | non_pad_label = rna_values[non_pad_idx] 297 | cum_acc_value += (non_pad_pred.eq(non_pad_label).sum().item()) / non_pad_label.size(0) 298 | 299 | cum_acc_cell = 100 * cum_acc_cell / index 300 | cum_acc_cell = get_reduced(cum_acc_cell, local_rank, 0, world_size) 301 | 302 | cum_acc_value = 100 * cum_acc_value / index 303 | cum_acc_value = get_reduced(cum_acc_value, local_rank, 0, world_size) 304 | for key in running_loss: 305 | running_loss[key] = running_loss[key] / index 306 | running_loss[key] = get_reduced(running_loss[key], local_rank, 0, world_size) 307 | if is_master: 308 | logger.info(f'Epoch: {i} | MVC Loss: {running_loss["mvc"]:.4f} | Cell Type Loss: {running_loss["cell"]:.4f} | Total Loss: {running_loss["total"]:.4f} | Cell Type Accuracy: {cum_acc_cell:.2f} | Expression Value Accuracy: {cum_acc_value:.2f}') 309 | 
dist.barrier() 310 | scheduler.step() 311 | # del train_set, train_sampler, train_loader 312 | 313 | if i % valid_config['freq'] == 0: 314 | if is_master: 315 | logger.info('#### Validation ####') 316 | model.eval() 317 | dist.barrier() 318 | running_loss = {'mvc': 0.0, 'cell': 0.0, 'total': 0.0} 319 | 320 | cum_acc_cell = 0.0 321 | cum_acc_value = 0.0 322 | 323 | with torch.no_grad(): 324 | for index, batch in enumerate(val_loader): 325 | index += 1 326 | 327 | rna_values = batch['rna_values'].to(device) 328 | rna_ids = batch['rna_ids'].to(device) 329 | atac_ids = batch['atac_ids'].to(device) 330 | cell_ids = batch['cell_ids'].to(device) 331 | batch_ids = batch['batch_ids'].to(device) 332 | rna_chrs = batch['rna_chrs'].to(device) 333 | atac_chrs = batch['atac_chrs'].to(device) 334 | 335 | padding_positions = atac_ids.eq(atac_vocab[pad['token']]) 336 | with torch.cuda.amp.autocast(enabled=train_config['amp'], dtype= torch.bfloat16): 337 | output = model(atac = atac_ids, rna = rna_ids, src_key_padding_mask = padding_positions, batch_id = batch_ids, rna_chrs = rna_chrs, atac_chrs = atac_chrs) 338 | 339 | mvc_loss = mvc_loss_fn(output['mvc_pred'].transpose(1, 2), rna_values) * mvc_weight 340 | cell_loss = cross_entropy_loss(output['cell_pred'], cell_ids) * cell_type_weight 341 | loss = mvc_loss + cell_loss 342 | 343 | running_loss['mvc'] += mvc_loss.item() 344 | running_loss['cell'] += cell_loss.item() 345 | running_loss['total'] += loss.item() 346 | 347 | type_pred = softmax(output['cell_pred']) 348 | type_pred = type_pred.argmax(dim=-1) 349 | cum_acc_cell += (type_pred.eq(cell_ids)).sum().item() / len(cell_ids) 350 | 351 | value_pred = softmax(output['mvc_pred']).argmax(dim=-1) 352 | # expression value accuracy 353 | non_pad_idx = rna_values.ne(pad['value']) 354 | non_pad_pred = value_pred[non_pad_idx] 355 | non_pad_label = rna_values[non_pad_idx] 356 | cum_acc_value += (non_pad_pred.eq(non_pad_label).sum().item()) / non_pad_label.size(0) 357 | # break 358 | for key in running_loss: 359 | running_loss[key] = running_loss[key] / index 360 | running_loss[key] = get_reduced(running_loss[key], local_rank, 0, world_size) 361 | cum_acc_cell = 100 * cum_acc_cell / index 362 | cum_acc_cell = get_reduced(cum_acc_cell, local_rank, 0, world_size) 363 | 364 | cum_acc_value = 100 * cum_acc_value / index 365 | cum_acc_value = get_reduced(cum_acc_value, local_rank, 0, world_size) 366 | 367 | # del val_set, val_sampler, val_loader 368 | if is_master: 369 | logger.info(f'MVC Loss: {running_loss["mvc"]:.4f} | Cell Type Loss: {running_loss["cell"]:.4f} | Total Loss: {running_loss["total"]:.4f} | Cell Type Accuracy: {cum_acc_cell:.2f} | Expression Value Accuracy: {cum_acc_value:.2f}') 370 | 371 | if is_master and i % save_ckpt_freq == 0: 372 | save_ckpt(i, steps, model, optimizer, scheduler, scaler, running_loss["total"], task_name, ckpt_dir) 373 | 374 | 375 | if __name__ == '__main__': 376 | main() 377 | -------------------------------------------------------------------------------- /model/EpiFoundation.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, Tensor 6 | from torch.cuda.amp import autocast 7 | from einops import rearrange, repeat 8 | 9 | from functools import partial 10 | from contextlib import contextmanager 11 | 12 | from local_attention import LocalAttention 13 | from typing import Dict, Mapping, Optional, Tuple, Any, Union 14 | from 
model.performer import Performer, cast_tuple 15 | from model.transformer import TransformerModel 16 | 17 | 18 | class ClsDecoder(nn.Module): 19 | """ 20 | Decoder for classification task. 21 | """ 22 | def __init__( 23 | self, 24 | d_model: int, 25 | n_cls: int, 26 | nlayers: int = 3, 27 | activation: callable = nn.ReLU, 28 | ): 29 | super().__init__() 30 | # module list 31 | self._decoder = nn.ModuleList() 32 | for i in range(nlayers - 1): 33 | self._decoder.append(nn.Linear(d_model, d_model)) 34 | self._decoder.append(activation()) 35 | self._decoder.append(nn.LayerNorm(d_model)) 36 | self.out_layer = nn.Linear(d_model, n_cls) 37 | 38 | def forward(self, x: Tensor) -> Tensor: 39 | """ 40 | Args: 41 | x: Tensor, shape [batch_size, embsize] 42 | """ 43 | for layer in self._decoder: 44 | x = layer(x) 45 | return self.out_layer(x) 46 | 47 | class GeneEncoder(nn.Module): 48 | def __init__( 49 | self, 50 | num_embeddings: int, 51 | embedding_dim: int, 52 | padding_idx: Optional[int] = None, 53 | ): 54 | super().__init__() 55 | self.embedding = nn.Embedding( 56 | num_embeddings, embedding_dim, padding_idx=padding_idx 57 | ) 58 | self.enc_norm = nn.LayerNorm(embedding_dim) 59 | 60 | def forward(self, x: Tensor) -> Tensor: 61 | x = self.embedding(x) # (batch, seq_len, embsize) 62 | x = self.enc_norm(x) 63 | return x 64 | 65 | # positional embeddings 66 | class AbsolutePositionalEmbedding(nn.Module): 67 | def __init__(self, max_seq_len, embed_dim): 68 | super().__init__() 69 | self.emb = nn.Embedding(max_seq_len, embed_dim) 70 | 71 | def forward(self, x): 72 | # x is of shape (batch, seq_len) 73 | # torch.arange is for generating a range of values from 0 to max_seq_len 74 | t = torch.arange(x.shape[1], device=x.device) 75 | return self.emb(t) 76 | 77 | 78 | class Always(nn.Module): 79 | def __init__(self, val): 80 | super().__init__() 81 | self.val = val 82 | 83 | def forward(self, *args, **kwargs): 84 | return self.val 85 | 86 | 87 | class ExprDecoder(nn.Module): 88 | def __init__( 89 | self, 90 | d_model: int, 91 | explicit_zero_prob: bool = False, 92 | use_batch_labels: bool = False, 93 | # catagory_num: Optional[int] = None, 94 | ): 95 | super().__init__() 96 | d_in = d_model * 2 if use_batch_labels else d_model 97 | self.fc = nn.Sequential( 98 | nn.Linear(d_in, d_model), 99 | nn.LeakyReLU(), 100 | nn.Linear(d_model, d_model), 101 | nn.LeakyReLU(), 102 | nn.Linear(d_model, 1), 103 | ) 104 | self.explicit_zero_prob = explicit_zero_prob 105 | if explicit_zero_prob: 106 | self.zero_logit = nn.Sequential( 107 | nn.Linear(d_in, d_model), 108 | nn.LeakyReLU(), 109 | nn.Linear(d_model, d_model), 110 | nn.LeakyReLU(), 111 | nn.Linear(d_model, 1), 112 | ) 113 | 114 | def forward(self, x: Tensor) -> Dict[str, Tensor]: 115 | """x is the output of the transformer, (batch, seq_len, d_model)""" 116 | pred_value = self.fc(x).squeeze(-1) # (batch, seq_len) 117 | 118 | # if not self.explicit_zero_prob: 119 | # return dict(pred=pred_value) 120 | # zero_logits = self.zero_logit(x).squeeze(-1) # (batch, seq_len) 121 | # zero_probs = torch.sigmoid(zero_logits) 122 | return pred_value 123 | # TODO: note that the return currently is only for training. Since decoder 124 | # is not used in the test setting for the integration task, the eval/inference 125 | # logic is not implemented yet. However, remember to implement it when 126 | # the decoder is used in any test setting. The inference logic will need 127 | # to sample from the bernoulli distribution with the zero_probs. 
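# --- Illustrative shape-check helper (added for documentation only; never called by the
# training scripts). It sketches how the PretrainDecoder defined below is used for the
# MVC objective: with arch_style="concat query" and catagory_num equal to the config's
# bin_num (2 in configs/pretrain/atac_cross_binary.yml), a pooled cell embedding plus
# per-gene query embeddings yield per-gene bin logits, which pretrain_ddp.py /
# pretrain_fsdp.py feed to CrossEntropyLoss after transpose(1, 2). The concrete sizes
# (batch 4, 16 genes, d_model 512) are assumptions chosen for the example.
def _mvc_decoder_shape_sketch():
    decoder = PretrainDecoder(d_model=512, arch_style="concat query", catagory_num=2)
    cell_emb = torch.randn(4, 512)          # (batch, d_model) pooled cell embedding
    gene_embs = torch.randn(4, 16, 512)     # (batch, seq_len, d_model) RNA token embeddings
    logits = decoder(cell_emb, gene_embs)   # (batch, seq_len, catagory_num) per-gene bin logits
    assert logits.shape == (4, 16, 2)
    return logits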
128 | 129 | 130 | class PretrainDecoder(nn.Module): 131 | """ 132 | Decoder for the masked value prediction for cell embeddings. 133 | """ 134 | 135 | def __init__( 136 | self, 137 | d_model: int, 138 | arch_style: str = "inner product", 139 | query_activation: nn.Module = nn.Sigmoid, 140 | hidden_activation: nn.Module = nn.PReLU, 141 | explicit_zero_prob: bool = False, 142 | use_batch_labels: bool = False, 143 | catagory_num: Optional[int] = 2, 144 | ) -> None: 145 | """ 146 | Args: 147 | d_model (:obj:`int`): dimension of the gene embedding. 148 | arch_style (:obj:`str`): architecture style of the decoder, choice from 149 | 1. "inner product" or 2. "concat query" or 3. "sum query". 150 | query_activation (:obj:`nn.Module`): activation function for the query 151 | vectors. 152 | hidden_activation (:obj:`nn.Module`): activation function for the hidden 153 | layers. 154 | """ 155 | super().__init__() 156 | d_in = d_model * 2 if use_batch_labels else d_model 157 | if arch_style in ["inner product", "inner product, detach"]: 158 | self.gene2query = nn.Linear(d_model, d_model) 159 | self.query_activation = query_activation() 160 | self.W = nn.Linear(d_model, d_in, bias=False) 161 | if explicit_zero_prob: # by default, gene-wise prob rate 162 | self.W_zero_logit = nn.Linear(d_model, d_in) 163 | elif arch_style == "concat query": 164 | self.gene2query = nn.Linear(d_model, 128) 165 | self.query_activation = query_activation() 166 | self.fc1 = nn.Linear(d_in + 128, 128) 167 | self.hidden_activation = hidden_activation() 168 | # self.fc2 = nn.Linear(64, 1) 169 | # for rna value prediction 170 | self.fc2 = nn.Linear(128, catagory_num) 171 | elif arch_style == "sum query": 172 | self.gene2query = nn.Linear(d_model, d_model) 173 | self.query_activation = query_activation() 174 | self.fc1 = nn.Linear(d_in, 128) 175 | self.hidden_activation = hidden_activation() 176 | self.fc2 = nn.Linear(128, catagory_num) 177 | else: 178 | raise ValueError(f"Unknown arch_style: {arch_style}") 179 | 180 | self.arch_style = arch_style 181 | self.do_detach = arch_style.endswith("detach") 182 | self.explicit_zero_prob = explicit_zero_prob 183 | 184 | def forward( 185 | self, cell_emb: Tensor, gene_embs: Tensor 186 | ) -> Union[Tensor, Dict[str, Tensor]]: 187 | """ 188 | Args: 189 | cell_emb: Tensor, shape (batch, embsize=d_model) 190 | gene_embs: Tensor, shape (batch, seq_len, embsize=d_model) 191 | """ 192 | gene_embs = gene_embs.detach() if self.do_detach else gene_embs 193 | if self.arch_style in ["inner product", "inner product, detach"]: 194 | query_vecs = self.query_activation(self.gene2query(gene_embs)) 195 | cell_emb = cell_emb.unsqueeze(2) # (batch, embsize, 1) 196 | # the pred gene expr values, # (batch, seq_len) 197 | pred_value = torch.bmm(self.W(query_vecs), cell_emb).squeeze(2) 198 | if not self.explicit_zero_prob: 199 | return pred_value 200 | # zero logits need to based on the cell_emb, because of input exprs 201 | zero_logits = torch.bmm(self.W_zero_logit(query_vecs), cell_emb).squeeze(2) 202 | zero_probs = torch.sigmoid(zero_logits) 203 | return pred_value 204 | elif self.arch_style == "concat query": 205 | query_vecs = self.query_activation(self.gene2query(gene_embs)) 206 | # expand cell_emb to (batch, seq_len, embsize) 207 | cell_emb = cell_emb.unsqueeze(1).expand(-1, gene_embs.shape[1], -1) 208 | 209 | h = self.hidden_activation( 210 | self.fc1(torch.cat([cell_emb, query_vecs], dim=2)) 211 | ) 212 | if self.explicit_zero_prob: 213 | raise NotImplementedError 214 | return self.fc2(h).squeeze(2) # 
(batch, seq_len) 215 | elif self.arch_style == "sum query": 216 | query_vecs = self.query_activation(self.gene2query(gene_embs)) 217 | cell_emb = cell_emb.unsqueeze(1) 218 | 219 | h = self.hidden_activation(self.fc1(cell_emb + query_vecs)) 220 | if self.explicit_zero_prob: 221 | raise NotImplementedError 222 | return self.fc2(h).squeeze(2) # (batch, seq_len) 223 | 224 | class CategoryValueEncoder(nn.Module): 225 | def __init__( 226 | self, 227 | num_embeddings: int, 228 | embedding_dim: int, 229 | padding_idx: Optional[int] = None, 230 | ): 231 | super().__init__() 232 | self.embedding = nn.Embedding( 233 | num_embeddings, embedding_dim, padding_idx=padding_idx 234 | ) 235 | self.enc_norm = nn.LayerNorm(embedding_dim) 236 | 237 | def forward(self, x: Tensor) -> Tensor: 238 | x = x.long() 239 | x = self.embedding(x) # (batch, seq_len, embsize) 240 | x = self.enc_norm(x) 241 | return x 242 | 243 | # sinusoidal positional embeddings 244 | 245 | class Gene2VecPositionalEmbedding(nn.Module): 246 | def __init__(self, max_seq_len, embed_dim): 247 | super().__init__() 248 | gene2vec_weight = np.load('../data/gene2vec_16906.npy') 249 | gene2vec_weight = np.concatenate((gene2vec_weight, np.zeros((1, gene2vec_weight.shape[1]))), axis=0) 250 | gene2vec_weight = torch.from_numpy(gene2vec_weight) 251 | self.emb = nn.Embedding.from_pretrained(gene2vec_weight) 252 | 253 | def forward(self, x): 254 | t = torch.arange(x.shape[1], device=x.device) 255 | return self.emb(t) 256 | 257 | class BatchLabelEncoder(nn.Module): 258 | def __init__( 259 | self, 260 | num_embeddings: int, 261 | embedding_dim: int, 262 | padding_idx: Optional[int] = None, 263 | ): 264 | super().__init__() 265 | self.embedding = nn.Embedding( 266 | num_embeddings, embedding_dim, padding_idx=padding_idx 267 | ) 268 | self.enc_norm = nn.LayerNorm(embedding_dim) 269 | 270 | def forward(self, x: Tensor) -> Tensor: 271 | x = self.embedding(x) # (batch, embsize) 272 | x = self.enc_norm(x) 273 | return x 274 | 275 | class EpiFoundation(nn.Module): 276 | def __init__( 277 | self, 278 | num_class_cell, # num of cell categories 279 | num_rnas, # num of genes (or atac peaks) 280 | num_atacs, # num of genes (or atac peaks) 281 | num_values, # num of values 282 | num_chrs, # num of chromosomes 283 | embed_dim, # embed_dim of tokens 284 | depth, # layers 285 | heads, # num of heads 286 | head_dim = 64, # embed_dim of heads 287 | encoder:str = 'transformer', # encoder type, performer or transformer 288 | dropout = 0.2, 289 | pad_token_idx_atac = 0, # padding token index , shoule be vocab[pad_token], set to 0 for debugging 290 | pad_token_idx_rna = 0, # padding token index , shoule be vocab[pad_token], set to 0 for debugging 291 | cell_emb_style = "cls", # cell embedding style 292 | mvc_arch_style = "inner product", # mvc decoder architecture style 293 | use_batch_labels = False, # whether to use batch labels 294 | batch_label_num = 13, # num of batch labels 295 | use_chr_labels = False, # whether to use chr labels 296 | transformer_backend = 'flash', # backend of transformer, pytorch or einsum 297 | stage = 'pretrain', # stage of the model, pretrain or finetune 298 | ): 299 | super().__init__() 300 | 301 | self.stage = stage 302 | # self.express_emb = nn.Embedding(num_tokens, embed_dim) 303 | self.encoder_type = encoder 304 | self.cell_emb_style = cell_emb_style 305 | self.embed_dim = embed_dim 306 | 307 | # determine positional embedding 308 | self.rna_emb = GeneEncoder(num_rnas, embed_dim, padding_idx=pad_token_idx_rna) 309 | self.atac_emb = 
GeneEncoder(num_atacs, embed_dim, padding_idx=pad_token_idx_atac) 310 | 311 | if use_batch_labels: 312 | self.batch_emb = BatchLabelEncoder(batch_label_num, embed_dim) 313 | else: 314 | self.batch_emb = None 315 | 316 | if use_chr_labels: 317 | self.chr_emb = GeneEncoder(num_chrs, embed_dim) 318 | else: 319 | self.chr_emb = None 320 | 321 | self.dropout_rna = nn.Dropout(dropout) 322 | self.dropout_atac = nn.Dropout(dropout) 323 | 324 | if encoder == 'performer': 325 | self.encoder = Performer(embed_dim, depth, heads, head_dim) 326 | elif encoder == 'transformer': 327 | self.encoder = TransformerModel(d_model=embed_dim, nhead=heads, nlayers=depth, d_hid= head_dim, dropout=dropout, fast_transformer_backend=transformer_backend) 328 | # self.encoder = Performer(embed_dim, depth, heads, head_dim, local_attn_heads, local_window_size, causal, ff_mult, nb_features, feature_redraw_interval, reversible, ff_chunks, generalized_attention, kernel_fn, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend, no_projection, auto_check_redraw, qkv_bias) 329 | self.norm = nn.LayerNorm(embed_dim) 330 | # self.to_out = nn.Linear(embed_dim, 1) 331 | self.cls_decoder = ClsDecoder(embed_dim, num_class_cell) 332 | self.mvc_decoder = PretrainDecoder(embed_dim, arch_style = mvc_arch_style, use_batch_labels=use_batch_labels, catagory_num = num_values) 333 | self.bn_atac = nn.BatchNorm1d(embed_dim, eps=6.1e-5) 334 | self.bn_rna = nn.BatchNorm1d(embed_dim, eps=6.1e-5) 335 | if stage == 'value_finetune': 336 | for param in self.cls_decoder.parameters(): 337 | param.requires_grad = False 338 | for param in self.mvc_decoder.parameters(): 339 | param.requires_grad = False 340 | self.value_decoder = PretrainDecoder(embed_dim, arch_style = mvc_arch_style, use_batch_labels=use_batch_labels, catagory_num = 1) 341 | 342 | def _get_cell_emb_from_layer( 343 | self, layer_output: Tensor, weights: Tensor = None 344 | ) -> Tensor: 345 | """ 346 | Args: 347 | layer_output(:obj:`Tensor`): shape (batch, seq_len, embsize) 348 | weights(:obj:`Tensor`): shape (batch, seq_len), optional and only used 349 | when :attr:`self.cell_emb_style` is "w-pool". 
350 | 351 | Returns: 352 | :obj:`Tensor`: shape (batch, embsize) 353 | """ 354 | if self.cell_emb_style == "cls": 355 | cell_emb = layer_output[:, 0, :] # (batch, embsize) 356 | elif self.cell_emb_style == "avg-pool": 357 | cell_emb = torch.mean(layer_output, dim=1) 358 | elif self.cell_emb_style == "w-pool": 359 | if weights is None: 360 | raise ValueError("weights is required when cell_emb_style is w-pool") 361 | if weights.dim() != 2: 362 | raise ValueError("weights should be 2D") 363 | cell_emb = torch.sum(layer_output * weights.unsqueeze(2), dim=1) 364 | cell_emb = F.normalize(cell_emb, p=2, dim=1) # (batch, embsize) 365 | return cell_emb 366 | 367 | def forward(self, atac, rna, src_key_padding_mask: Optional[Tensor] = None, **kwargs): 368 | atac_emb = self.atac_emb(atac) 369 | if self.chr_emb is not None: 370 | chr_emb = self.chr_emb(kwargs['atac_chrs']) 371 | atac_emb = atac_emb + chr_emb 372 | atac_emb = self.dropout_atac(atac_emb) 373 | atac_emb = self.bn_atac(atac_emb.permute(0, 2, 1)).permute(0, 2, 1) 374 | 375 | x = self.encoder(atac_emb, src_key_padding_mask = src_key_padding_mask) 376 | transformer_output = self.norm(x) # (batch, seq_len, embsize) 377 | 378 | rna_emb = self.rna_emb(rna) 379 | if self.chr_emb is not None: 380 | chr_emb = self.chr_emb(kwargs['rna_chrs']) 381 | rna_emb = rna_emb + chr_emb 382 | rna_emb = self.dropout_rna(rna_emb) 383 | rna_emb = self.bn_rna(rna_emb.permute(0, 2, 1)).permute(0, 2, 1) 384 | 385 | output = {} 386 | cell_emb = self._get_cell_emb_from_layer(transformer_output) 387 | 388 | if self.batch_emb is not None: 389 | batch_emb = self.batch_emb(kwargs['batch_id']) 390 | cell_emb_w_batch = torch.cat((cell_emb, batch_emb), dim = 1) 391 | output["mvc_pred"] = self.mvc_decoder(cell_emb_w_batch, rna_emb) 392 | if self.stage == 'value_finetune': 393 | output["value_pred"] = self.value_decoder(cell_emb_w_batch, rna_emb) 394 | else: 395 | output["mvc_pred"] = self.mvc_decoder(cell_emb, rna_emb) 396 | if self.stage == 'value_finetune': 397 | output["value_pred"] = self.value_decoder(cell_emb, rna_emb) 398 | 399 | output["cell_emb"] = cell_emb 400 | output["cell_pred"] = self.cls_decoder(cell_emb) # (batch, n_cls) 401 | 402 | return output -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import print_function 4 | import json 5 | import os 6 | import struct 7 | import sys 8 | import platform 9 | import re 10 | import time 11 | import traceback 12 | import requests 13 | import socket 14 | import random 15 | import math 16 | import numpy as np 17 | import torch 18 | import logging 19 | import datetime 20 | from torch.optim.lr_scheduler import _LRScheduler 21 | from torch import nn 22 | import torch.nn.functional as F 23 | from torch.nn.modules.loss import _WeightedLoss 24 | from typing import Dict, Optional 25 | 26 | import torch.distributed as dist 27 | from torch.nn.parallel import DistributedDataParallel as DDP 28 | # DDP imports 29 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 30 | # FSDP imports 31 | from torch.distributed.fsdp.fully_sharded_data_parallel import ( 32 | CPUOffload, MixedPrecision, BackwardPrefetch, ShardingStrategy 33 | ) 34 | from torch.distributed.fsdp.wrap import ( 35 | size_based_auto_wrap_policy, 36 | transformer_auto_wrap_policy, 37 | enable_wrap, 38 | wrap 39 | ) 40 | from torch.distributed.fsdp import ( 41 | FullStateDictConfig, 42 
| StateDictType 43 | ) 44 | 45 | 46 | def seed_all(seed_value, cuda_deterministic=False): 47 | """ 48 | set all random seeds 49 | """ 50 | random.seed(seed_value) 51 | os.environ['PYTHONHASHSEED'] = str(seed_value) 52 | # np.random.seed(seed_value) 53 | torch.manual_seed(seed_value) 54 | if torch.cuda.is_available(): 55 | torch.cuda.manual_seed(seed_value) 56 | torch.cuda.manual_seed_all(seed_value) 57 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html 58 | if cuda_deterministic: # slower, more reproducible 59 | torch.backends.cudnn.deterministic = True 60 | torch.backends.cudnn.benchmark = False 61 | else: # faster, less reproducible 62 | torch.backends.cudnn.deterministic = False 63 | torch.backends.cudnn.benchmark = True 64 | 65 | 66 | def set_log(log_dir, rank = -1): 67 | """ 68 | save log 69 | """ 70 | time_now = datetime.datetime.now() 71 | log_file = os.path.join(log_dir, f'{time_now.year}_{time_now.month}_{time_now.day}_{time_now.hour}_{time_now.minute}.log') 72 | if not os.path.exists(log_dir): 73 | os.makedirs(log_dir) 74 | else: 75 | pass 76 | 77 | logging.basicConfig(level=logging.INFO if rank in [-1, 0] else logging.WARN, 78 | format='[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s', 79 | datefmt='[%X]', 80 | handlers=[logging.FileHandler(log_file), logging.StreamHandler()] 81 | ) 82 | logger = logging.getLogger() 83 | return logger 84 | 85 | def dict2str(opt, indent_level=1): 86 | """dict to string for printing options. 87 | 88 | Args: 89 | opt (dict): Option dict. 90 | indent_level (int): Indent level. Default: 1. 91 | 92 | Return: 93 | (str): Option string for printing. 94 | """ 95 | msg = '\n' 96 | for k, v in opt.items(): 97 | if isinstance(v, dict): 98 | msg += ' ' * (indent_level * 2) + k + ':[' 99 | msg += dict2str(v, indent_level + 1) 100 | msg += ' ' * (indent_level * 2) + ']\n' 101 | else: 102 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 103 | return msg 104 | 105 | 106 | def save_ckpt_fsdp(epoch, steps, model, optimizer, scheduler, scaler, loss, task_name, ckpt_dir, rank): 107 | # Ensure all processes are synchronized 108 | dist.barrier() 109 | # Switch to FULL_STATE_DICT context to gather full state_dict 110 | with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT): 111 | full_state_dict = model.state_dict() 112 | # Only save on the master process 113 | if rank == 0: 114 | if not os.path.exists(ckpt_dir): 115 | os.makedirs(ckpt_dir) 116 | ckpt = { 117 | 'epoch': epoch, 118 | 'steps': steps, 119 | 'model': full_state_dict, 120 | 'optimizer': optimizer.state_dict(), 121 | 'scheduler': scheduler.state_dict(), 122 | 'scaler': scaler.state_dict(), 123 | 'loss': loss, 124 | } 125 | ckpt_path = os.path.join(ckpt_dir, f'checkpoint_{epoch}.pth') 126 | torch.save(ckpt, ckpt_path) 127 | 128 | def save_ckpt(epoch, steps, model, optimizer, scheduler, scaler, losses, model_name, ckpt_folder): 129 | """ 130 | save checkpoint 131 | """ 132 | if not os.path.exists(ckpt_folder): 133 | os.makedirs(ckpt_folder) 134 | torch.save( 135 | { 136 | 'epoch': epoch, 137 | 'steps': steps, 138 | 'model': model.module.state_dict(), 139 | 'optimizer': optimizer.state_dict(), 140 | 'scheduler': scheduler.state_dict(), 141 | 'scaler': scaler.state_dict(), 142 | 'losses': losses, 143 | }, 144 | f'{ckpt_folder}/Epoch_{epoch}_Step_{steps}_{model_name}.pth', 145 | ) 146 | 147 | 148 | 149 | 150 | def get_reduced(tensor, current_device, dest_device, world_size): 151 | """ 152 | garther tensor from different 
GPUs to the main GPU 153 | """ 154 | tensor = torch.tensor(tensor, device ='cuda') 155 | torch.distributed.reduce(tensor, dst=dest_device, op = dist.ReduceOp.SUM) 156 | tensor_mean = tensor.item() / world_size 157 | return tensor_mean 158 | 159 | def get_ndtensor_reduced(tensor, current_device, dest_device, world_size): 160 | """ 161 | garther tensor from different GPUs to the main GPU 162 | """ 163 | tensor = tensor.clone().detach() if torch.is_tensor(tensor) else torch.tensor(tensor) 164 | tensor = tensor.to(current_device) 165 | torch.distributed.reduce(tensor, dst=dest_device) 166 | tensor_mean = torch.zeros(tensor.shape) 167 | if len(tensor.shape) == 2: 168 | for i in range(tensor.shape[0]): 169 | for j in range(tensor.shape[1]): 170 | tensor_mean[i,j] = tensor[i,j].item() / world_size 171 | elif len(tensor.shape) == 1: 172 | for i in range(tensor.shape[0]): 173 | tensor_mean[i] = tensor[i].item() / world_size 174 | return tensor_mean 175 | 176 | def numel(m: torch.nn.Module, only_trainable: bool = False): 177 | """ 178 | returns the total number of parameters used by `m` (only counting 179 | shared parameters once); if `only_trainable` is True, then only 180 | includes parameters with `requires_grad = True` 181 | """ 182 | parameters = m.parameters() 183 | if only_trainable: 184 | parameters = list(p for p in parameters if p.requires_grad) 185 | unique = dict((p.data_ptr(), p) for p in parameters).values() 186 | return sum(p.numel() for p in unique) 187 | 188 | 189 | def label_smooth(y, K, epsilon=0.1): 190 | """ 191 | Label smoothing for multiclass labels 192 | One hot encode labels `y` over `K` classes. `y` should be of the form [1, 6, 3, etc.] 193 | """ 194 | m = len(y) 195 | out = np.ones((m, K)) * epsilon / K 196 | for index in range(m): 197 | out[index][y[index] - 1] += 1 - epsilon 198 | return torch.tensor(out) 199 | 200 | 201 | class SequentialDistributedSampler(torch.utils.data.sampler.Sampler): 202 | """ 203 | Distributed Sampler that subsamples indicies sequentially, 204 | making it easier to collate all results at the end. 205 | Even though we only use this sampler for eval and predict (no training), 206 | which means that the model params won't have to be synced (i.e. will not hang 207 | for synchronization even if varied number of forward passes), we still add extra 208 | samples to the sampler to make it evenly divisible (like in `DistributedSampler`) 209 | to make it easy to `gather` or `reduce` resulting tensors at the end of the loop. 
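    Example (illustrative sizes, added for clarity; not taken from the real datasets):
        with len(dataset) = 10, batch_size = 4 and world_size = 2, num_samples per rank
        is ceil(10 / 4 / 2) * 4 = 8 and total_size = 16, so the last index is repeated
        6 times to pad the index list; distributed_concat(tensor, len(dataset), world_size)
        later truncates these padded duplicates so only the first 10 results are kept.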
210 | """ 211 | 212 | def __init__(self, dataset, batch_size, world_size, rank=None, num_replicas=None): 213 | if num_replicas is None: 214 | if not torch.distributed.is_available(): 215 | raise RuntimeError("Requires distributed package to be available") 216 | num_replicas = world_size 217 | if rank is None: 218 | if not torch.distributed.is_available(): 219 | raise RuntimeError("Requires distributed package to be available") 220 | rank = torch.distributed.get_rank() 221 | self.dataset = dataset 222 | self.num_replicas = num_replicas 223 | self.rank = rank 224 | self.batch_size = batch_size 225 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.batch_size / self.num_replicas)) * self.batch_size 226 | self.total_size = self.num_samples * self.num_replicas 227 | 228 | def __iter__(self): 229 | indices = list(range(len(self.dataset))) 230 | # add extra samples to make it evenly divisible 231 | indices += [indices[-1]] * (self.total_size - len(indices)) 232 | # subsample 233 | indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples] 234 | return iter(indices) 235 | 236 | def __len__(self): 237 | return self.num_samples 238 | 239 | 240 | def distributed_concat(tensor, num_total_examples, world_size): 241 | """ 242 | merge the inference results of different processes 243 | """ 244 | output_tensors = [tensor.clone() for _ in range(world_size)] 245 | torch.distributed.all_gather(output_tensors, tensor) 246 | concat = torch.cat(output_tensors, dim=0) 247 | # truncate the dummy elements added by SequentialDistributedSampler 248 | return concat[:num_total_examples] 249 | 250 | 251 | class CosineAnnealingWarmupRestarts(_LRScheduler): 252 | """ 253 | optimizer (Optimizer): Wrapped optimizer. 254 | first_cycle_steps (int): First cycle step size. 255 | cycle_mult(float): Cycle steps magnification. Default: -1. 256 | max_lr(float): First cycle's max learning rate. Default: 0.1. 257 | min_lr(float): Min learning rate. Default: 0.001. 258 | warmup_steps(int): Linear warmup step size. Default: 0. 259 | gamma(float): Decrease rate of max learning rate by cycle. Default: 1. 260 | last_epoch (int): The index of last epoch. Default: -1. 
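    Example (illustrative; mirrors the values used in pretrain_ddp.py and
    pretrain_fsdp.py, where scheduler.step() is called once per epoch):
        scheduler = CosineAnnealingWarmupRestarts(
            optimizer, first_cycle_steps=15, cycle_mult=2,
            max_lr=1e-4, min_lr=1e-6, warmup_steps=5, gamma=0.9,
        )
        # roughly: the first 5 steps ramp the LR linearly from min_lr to max_lr,
        # the next 10 follow a cosine decay back towards min_lr, and the second
        # cycle then lasts (15 - 5) * 2 + 5 = 25 steps with max_lr scaled by gamma (0.9).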
261 | """ 262 | 263 | def __init__(self, 264 | optimizer : torch.optim.Optimizer, 265 | first_cycle_steps : int, 266 | cycle_mult : float = 1., 267 | max_lr : float = 0.1, 268 | min_lr : float = 0.001, 269 | warmup_steps : int = 0, 270 | gamma : float = 1., 271 | last_epoch : int = -1 272 | ): 273 | assert warmup_steps < first_cycle_steps 274 | 275 | self.first_cycle_steps = first_cycle_steps # first cycle step size 276 | self.cycle_mult = cycle_mult # cycle steps magnification 277 | self.base_max_lr = max_lr # first max learning rate 278 | self.max_lr = max_lr # max learning rate in the current cycle 279 | self.min_lr = min_lr # min learning rate 280 | self.warmup_steps = warmup_steps # warmup step size 281 | self.gamma = gamma # decrease rate of max learning rate by cycle 282 | 283 | self.cur_cycle_steps = first_cycle_steps # first cycle step size 284 | self.cycle = 0 # cycle count 285 | self.step_in_cycle = last_epoch # step size of the current cycle 286 | 287 | super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch) 288 | 289 | # set learning rate min_lr 290 | self.init_lr() 291 | 292 | def init_lr(self): 293 | self.base_lrs = [] 294 | for param_group in self.optimizer.param_groups: 295 | param_group['lr'] = self.min_lr 296 | self.base_lrs.append(self.min_lr) 297 | 298 | def get_lr(self): 299 | if self.step_in_cycle == -1: 300 | return self.base_lrs 301 | elif self.step_in_cycle < self.warmup_steps: 302 | return [(self.max_lr - base_lr)*self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs] 303 | else: 304 | return [base_lr + (self.max_lr - base_lr) \ 305 | * (1 + math.cos(math.pi * (self.step_in_cycle-self.warmup_steps) \ 306 | / (self.cur_cycle_steps - self.warmup_steps))) / 2 307 | for base_lr in self.base_lrs] 308 | 309 | def step(self, epoch=None): 310 | if epoch is None: 311 | epoch = self.last_epoch + 1 312 | self.step_in_cycle = self.step_in_cycle + 1 313 | if self.step_in_cycle >= self.cur_cycle_steps: 314 | self.cycle += 1 315 | self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps 316 | self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps 317 | else: 318 | if epoch >= self.first_cycle_steps: 319 | if self.cycle_mult == 1.: 320 | self.step_in_cycle = epoch % self.first_cycle_steps 321 | self.cycle = epoch // self.first_cycle_steps 322 | else: 323 | n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult)) 324 | self.cycle = n 325 | self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1)) 326 | self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n) 327 | else: 328 | self.cur_cycle_steps = self.first_cycle_steps 329 | self.step_in_cycle = epoch 330 | 331 | self.max_lr = self.base_max_lr * (self.gamma**self.cycle) 332 | self.last_epoch = math.floor(epoch) 333 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 334 | param_group['lr'] = lr 335 | 336 | 337 | class DistanceLoss(_WeightedLoss): 338 | """ 339 | CrossEntropyLoss with Distance Weighted 340 | """ 341 | def __init__(self, weight=None, reduction='mean', ignore_index = None): 342 | super().__init__(weight=weight, reduction=reduction) 343 | self.weight = weight 344 | self.reduction = reduction 345 | self.ignore_index = ignore_index 346 | def forward(self, inputs, targets): 347 | if len(inputs.shape) > 2: 348 | inputs = inputs.reshape(-1, inputs.size(-1)) 349 | if len(targets.shape) > 1: 350 
| targets = targets.reshape(-1) 351 | if self.ignore_index is not None: 352 | keep_index = (targets != self.ignore_index).nonzero(as_tuple=True)[0] 353 | targets = torch.index_select(targets, 0, keep_index) #targets[targets != self.ignore_index] 354 | inputs = torch.index_select(inputs, 0, keep_index) 355 | lsm = F.log_softmax(inputs, -1) 356 | targets = torch.empty(size=(targets.size(0), inputs.size(-1)), device=targets.device).fill_(0).scatter_(1, targets.data.unsqueeze(1), 1) 357 | if self.weight is not None: 358 | lsm = lsm * self.weight.unsqueeze(0) 359 | loss = -(targets * lsm).sum(-1) 360 | inputs = nn.Softmax(dim=-1)(inputs)[..., 1:-1].argmax(dim=-1) + 1 361 | # print('inputs', inputs.device, inputs.shape) 362 | targets = nn.Softmax(dim=-1)(targets)[..., 1:-1].argmax(dim=-1) + 1 363 | # print('targets', targets.device, targets.shape) 364 | distance = abs(inputs - targets) + 1e-2 365 | # print('loss.shape', loss.shape) 366 | # print('distance.shape', distance.shape) 367 | loss = loss * distance 368 | if self.reduction == 'sum': 369 | loss = loss.sum() 370 | elif self.reduction == 'mean': 371 | loss = loss.mean() 372 | return loss 373 | 374 | 375 | class LabelSmoothCrossEntropyLoss(_WeightedLoss): 376 | """ 377 | CrossEntropyLoss with Label Somoothing 378 | """ 379 | def __init__(self, weight=None, reduction='mean', smoothing=0.0): 380 | super().__init__(weight=weight, reduction=reduction) 381 | self.smoothing = smoothing 382 | self.weight = weight 383 | self.reduction = reduction 384 | 385 | @staticmethod 386 | def _smooth_one_hot(targets: torch.Tensor, n_classes: int, smoothing=0.0): 387 | assert 0 <= smoothing < 1 388 | with torch.no_grad(): 389 | targets = torch.empty(size=(targets.size(0), n_classes), 390 | device=targets.device) \ 391 | .fill_(smoothing / (n_classes - 1)) \ 392 | .scatter_(1, targets.data.unsqueeze(1), 1. 
- smoothing) 393 | return targets 394 | 395 | def forward(self, inputs, targets): 396 | targets = LabelSmoothCrossEntropyLoss._smooth_one_hot(targets, inputs.size(-1), 397 | self.smoothing) 398 | lsm = F.log_softmax(inputs, -1) 399 | 400 | if self.weight is not None: 401 | lsm = lsm * self.weight.unsqueeze(0) 402 | 403 | loss = -(targets * lsm).sum(-1) 404 | 405 | if self.reduction == 'sum': 406 | loss = loss.sum() 407 | elif self.reduction == 'mean': 408 | loss = loss.mean() 409 | 410 | return loss 411 | 412 | 413 | def eval_scib_metrics( 414 | adata, 415 | logger, 416 | batch_key: str = "str_batch", 417 | label_key: str = "celltype", 418 | notes: Optional[str] = None 419 | ): 420 | import scib 421 | 422 | logger.info("Calculating metrics...") 423 | results = scib.metrics.metrics( 424 | adata, 425 | adata_int=adata, 426 | batch_key=batch_key, 427 | label_key=label_key, 428 | embed="embedding", 429 | isolated_labels_asw_=False, 430 | silhouette_=True, 431 | hvg_score_=False, 432 | graph_conn_=True, 433 | pcr_=True, 434 | isolated_labels_f1_=False, 435 | trajectory_=False, 436 | nmi_=True, # use the clustering, bias to the best matching 437 | ari_=True, # use the clustering, bias to the best matching 438 | cell_cycle_=False, 439 | kBET_=False, # kBET return nan sometimes, need to examine 440 | ilisi_=False, 441 | clisi_=False, 442 | ) 443 | 444 | if notes is not None: 445 | logger.info(f"{notes}") 446 | 447 | logger.info(f"{results}") 448 | 449 | result_dict = results[0].to_dict() 450 | logger.info( 451 | "Biological Conservation Metrics: \n" 452 | f"ASW (cell-type): {result_dict['ASW_label']:.4f}, graph cLISI: {result_dict['cLISI']:.4f}, " 453 | f"isolated label silhouette: {result_dict['isolated_label_silhouette']:.4f}, \n" 454 | "Batch Effect Removal Metrics: \n" 455 | f"PCR_batch: {result_dict['PCR_batch']:.4f}, ASW (batch): {result_dict['ASW_label/batch']:.4f}, " 456 | f"graph connectivity: {result_dict['graph_conn']:.4f}, graph iLISI: {result_dict['iLISI']:.4f}" 457 | ) 458 | 459 | result_dict["avg_bio"] = np.mean( 460 | [ 461 | result_dict["NMI_cluster/label"], 462 | result_dict["ARI_cluster/label"], 463 | result_dict["ASW_label"], 464 | ] 465 | ) 466 | 467 | # remove nan value in result_dict 468 | result_dict = {k: v for k, v in result_dict.items() if not np.isnan(v)} 469 | 470 | return result_dict -------------------------------------------------------------------------------- /pretrain_fsdp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import gc 5 | import argparse 6 | import json 7 | import random 8 | import math 9 | import random 10 | from functools import reduce 11 | import numpy as np 12 | import pandas as pd 13 | from scipy import sparse 14 | from sklearn.model_selection import train_test_split 15 | import torch 16 | from torch import nn 17 | from torch.optim import Adam 18 | from torch.nn import functional as F 19 | from tensorboardX import SummaryWriter 20 | from torch.utils.data import DataLoader, Dataset 21 | from torch.utils.data.distributed import DistributedSampler 22 | import torch.distributed as dist 23 | 24 | # DDP imports 25 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 26 | # FSDP imports 27 | from torch.distributed.fsdp.fully_sharded_data_parallel import ( 28 | CPUOffload, MixedPrecision, BackwardPrefetch, ShardingStrategy 29 | ) 30 | from torch.distributed.fsdp.wrap import ( 31 | size_based_auto_wrap_policy, 32 | transformer_auto_wrap_policy, 33 | 
enable_wrap, 34 | wrap 35 | ) 36 | from torch.distributed.fsdp import ( 37 | FullStateDictConfig, 38 | StateDictType, 39 | ) 40 | 41 | 42 | from model import EpiFoundation 43 | from loss.loss import MaskedMSELoss 44 | from data.dataloader import * 45 | from tokenizer import GeneVocab 46 | import scanpy as sc 47 | import anndata as ad 48 | from utils import * 49 | from memory_profiler import profile 50 | 51 | import yaml 52 | 53 | torch.autograd.set_detect_anomaly(True) 54 | 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument("--config", type=str, default='./configs/pretrain/atac_cross_debug.yml', help='Config file.') 57 | args = parser.parse_args() 58 | 59 | def main(): 60 | dist.init_process_group(backend='nccl') 61 | # read and parse config file 62 | local_rank = int(os.environ["LOCAL_RANK"]) 63 | rank = int(os.environ["RANK"]) 64 | 65 | with open(args.config, 'r') as f: 66 | config = yaml.load(f, Loader=yaml.FullLoader) 67 | 68 | 69 | train_config = config['train'] 70 | valid_config = config['valid'] 71 | data_config = config['data'] 72 | vocab_config = config['vocab'] 73 | task_name = config['task_name'] 74 | task_floder = './experiment/{}'.format(task_name) 75 | ckpt_dir = os.path.join(task_floder, 'ckpts') 76 | 77 | 78 | 79 | random_seed = train_config['seed'] 80 | EPOCHS = train_config['epochs'] 81 | BATCH_SIZE = train_config['batch_size'] 82 | GRADIENT_ACCUMULATION = train_config['gradient_accumulation_steps'] 83 | LEARNING_RATE = float(train_config['lr']) 84 | 85 | model_name = train_config['model']['encoder'] 86 | 87 | save_ckpt_freq = train_config['save_ckpt_freq'] if 'save_ckpt_freq' in train_config else 5 88 | resume = train_config['resume'] if 'resume' in train_config else False 89 | 90 | # special tokens 91 | pad = vocab_config['special_tokens']['pad'] 92 | mask = vocab_config['special_tokens']['mask'] 93 | cls = vocab_config['special_tokens']['cls'] 94 | 95 | # distibuted setting, use local_rank as device id 96 | torch.cuda.set_device(local_rank) 97 | device = torch.device("cuda", local_rank) 98 | 99 | world_size = dist.get_world_size() 100 | seed_all(random_seed + rank) 101 | # use rank to determine if it is master node, instead of local_rank 102 | is_master = (rank == 0) 103 | 104 | # init loggers 105 | logger = set_log(log_dir= os.path.join(task_floder, 'logs')) 106 | tb_logger = SummaryWriter(os.path.join(task_floder, 'tb_logs')) 107 | if is_master: 108 | logger.info(dict2str(config)) 109 | 110 | 111 | rna_vocab = GeneVocab.from_file(vocab_config['rna_path']) 112 | atac_vocab = GeneVocab.from_file(vocab_config['atac_path']) 113 | cell_vocab = GeneVocab.from_file(vocab_config['cell_type_path']) 114 | batch_vocab = GeneVocab.from_file(vocab_config['batch_path']) 115 | chr_vocab = GeneVocab.from_file(vocab_config['chr_path']) 116 | 117 | if is_master: 118 | logger.info(f'Rna vocab size: {len(rna_vocab)}') 119 | logger.info(f'Atac vocab size: {len(atac_vocab)}') 120 | 121 | if is_master: 122 | logger.info('loading training data') 123 | 124 | train_set = PairedSCDataset( 125 | rna_file = data_config['train']['rna_path'], 126 | atac_file= data_config['train']['atac_path'], 127 | rna_key = data_config['train']['rna_key'], 128 | atac_key = data_config['train']['atac_key'], 129 | rna_vocab = rna_vocab, 130 | atac_vocab = atac_vocab, 131 | cell_vocab = cell_vocab, 132 | batch_vocab= batch_vocab, 133 | chr_vocab = chr_vocab, 134 | gene2chr_file= vocab_config['gene2chr_path'], 135 | rna_max_len = train_config['model']['rna_max_len'], 136 | atac_max_len = 
train_config['model']['atac_max_len'], 137 | pad_token = pad['token'], 138 | rna_pad_value = pad['value'], 139 | cls_token = cls['token'], 140 | logger = logger, 141 | ) 142 | 143 | gc.collect() 144 | train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True) 145 | train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, sampler=train_sampler, prefetch_factor=4, num_workers=4) 146 | 147 | 148 | if is_master: 149 | logger.info('loading validation data') 150 | val_set = PairedSCDataset( 151 | rna_file = data_config['test']['rna_path'], 152 | atac_file= data_config['test']['atac_path'], 153 | rna_key = data_config['test']['rna_key'], 154 | atac_key = data_config['test']['atac_key'], 155 | rna_vocab = rna_vocab, 156 | atac_vocab = atac_vocab, 157 | cell_vocab = cell_vocab, 158 | batch_vocab= batch_vocab, 159 | chr_vocab = chr_vocab, 160 | gene2chr_file= vocab_config['gene2chr_path'], 161 | rna_max_len = train_config['model']['rna_max_len'], 162 | atac_max_len = train_config['model']['atac_max_len'], 163 | pad_token = pad['token'], 164 | rna_pad_value = pad['value'], 165 | cls_token = cls['token'], 166 | logger = logger, 167 | ) 168 | gc.collect() 169 | 170 | val_sampler = SequentialDistributedSampler(val_set, batch_size=BATCH_SIZE, world_size=world_size, rank=rank, num_replicas=world_size) 171 | val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, sampler=val_sampler, prefetch_factor=4, num_workers=4) 172 | 173 | if is_master: 174 | logger.info('Creating model') 175 | # init model on CPU 176 | model = EpiFoundation( 177 | num_class_cell = len(cell_vocab), 178 | num_rnas = len(rna_vocab), 179 | num_atacs = len(atac_vocab), 180 | num_values= data_config['bin_num'], 181 | num_chrs= len(chr_vocab), 182 | embed_dim = train_config['model']['embedding_dim'], 183 | depth = train_config['model']['num_layers'], 184 | heads = train_config['model']['head_num'], 185 | head_dim = train_config['model']['head_dim'], 186 | encoder = model_name, 187 | dropout = train_config['model']['dropout'], 188 | pad_token_idx_rna = rna_vocab[pad['token']], 189 | pad_token_idx_atac = atac_vocab[pad['token']], 190 | cell_emb_style = train_config['model']['cell_emb_style'], 191 | mvc_arch_style = train_config['model']['mvc_arch_style'], 192 | use_batch_labels = train_config['model']['use_batch_labels'], 193 | batch_label_num= len(batch_vocab), 194 | use_chr_labels= train_config['model']['use_chr_labels'], 195 | transformer_backend='flash', 196 | ) 197 | 198 | mixed_precision_policy = MixedPrecision( 199 | param_dtype=torch.bfloat16, 200 | reduce_dtype=torch.bfloat16, 201 | buffer_dtype=torch.bfloat16, 202 | ) 203 | 204 | sharding_strategy = ShardingStrategy.FULL_SHARD 205 | 206 | model = FSDP( 207 | module=model, 208 | mixed_precision=mixed_precision_policy, 209 | sharding_strategy=sharding_strategy, 210 | device_id=torch.cuda.current_device(), 211 | ) 212 | 213 | # optimizer 214 | optimizer = Adam(model.parameters(), lr=LEARNING_RATE) 215 | 216 | # learning rate scheduler 217 | scheduler = CosineAnnealingWarmupRestarts( 218 | optimizer, 219 | first_cycle_steps=15, 220 | cycle_mult=2, 221 | max_lr=LEARNING_RATE, 222 | min_lr=1e-6, 223 | warmup_steps=5, 224 | gamma=0.9 225 | ) 226 | 227 | start_epoch = 1 228 | 229 | # model = DDP(model, device_ids=[local_rank], output_device=local_rank) 230 | # scaler = torch.amp.GradScaler(enabled=train_config['amp'].amp) 231 | scaler = torch.cuda.amp.GradScaler(enabled=train_config['amp']) 232 | 233 | # masked_mse_loss = 
MaskedMSELoss().to(local_rank) 234 | cross_entropy_loss = nn.CrossEntropyLoss(reduction='mean').to(device) 235 | 236 | if data_config['bin_num'] > 1: 237 | mvc_loss_fn = nn.CrossEntropyLoss(reduction='mean', ignore_index = pad['value']).to(device) 238 | else: 239 | mvc_loss_fn = nn.MSELoss(reduction='sum').to(device) 240 | 241 | mvc_weight = train_config['task_weight']['mvc'] 242 | cell_type_weight = train_config['task_weight']['cell_type'] 243 | 244 | softmax = nn.Softmax(dim=-1) 245 | 246 | steps = 0 247 | if train_config['model']['pretrained'] is not None: 248 | if is_master: 249 | logger.info('Loading pretrained model from: {}'.format(train_config['model']['pretrained'])) 250 | checkpoint = torch.load(train_config['model']['pretrained'], map_location='cpu') 251 | 252 | with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT): 253 | model.load_state_dict(checkpoint['model']) 254 | 255 | optimizer.load_state_dict(checkpoint['optimizer']) 256 | scheduler.load_state_dict(checkpoint['scheduler']) 257 | scaler.load_state_dict(checkpoint['scaler']) 258 | if resume: 259 | start_epoch = checkpoint['epoch'] + 1 260 | steps = checkpoint['steps'] 261 | del checkpoint 262 | gc.collect() 263 | 264 | dist.barrier() 265 | if is_master: 266 | logger.info('Start finetuning from epoch: {}, steps: {}'.format(start_epoch, steps)) 267 | for i in range(start_epoch, start_epoch + EPOCHS): 268 | train_loader.sampler.set_epoch(i) 269 | 270 | if is_master: 271 | logger.info('Training with {} samples, steps: {}'.format(len(train_loader.dataset), len(train_loader))) 272 | model.train() 273 | dist.barrier() 274 | running_loss = {'mvc': 0.0, 'cell': 0.0, 'total': 0.0} 275 | cum_acc_cell = 0.0 276 | cum_acc_value = 0.0 277 | for index, batch in enumerate(train_loader): 278 | index += 1 279 | steps += 1 280 | rna_values = batch['rna_values'].to(device) 281 | rna_ids = batch['rna_ids'].to(device) 282 | atac_ids = batch['atac_ids'].to(device) 283 | cell_ids = batch['cell_ids'].to(device) 284 | batch_ids = batch['batch_ids'].to(device) 285 | rna_chrs = batch['rna_chrs'].to(device) 286 | atac_chrs = batch['atac_chrs'].to(device) 287 | 288 | padding_positions = atac_ids.eq(atac_vocab[pad['token']]) 289 | if index % GRADIENT_ACCUMULATION != 0 and index != len(train_loader): 290 | with model.no_sync(): 291 | # with torch.cuda.amp.autocast(enabled=train_config['amp'], dtype= torch.bfloat16): 292 | # # finetue using all expression values, do not mask 293 | # output = model(atac = atac_ids, rna = rna_ids, src_key_padding_mask = padding_positions, batch_id = batch_ids, rna_chrs = rna_chrs, atac_chrs = atac_chrs) 294 | 295 | # mvc_loss = mvc_loss_fn(output['mvc_pred'].transpose(1, 2), rna_values) * mvc_weight 296 | # cell_loss = cross_entropy_loss(output['cell_pred'], cell_ids) * cell_type_weight 297 | # loss = mvc_loss + cell_loss 298 | 299 | # running_loss['mvc'] += mvc_loss.item() 300 | # running_loss['cell'] += cell_loss.item() 301 | # running_loss['total'] += loss.item() 302 | 303 | # loss = loss / GRADIENT_ACCUMULATION 304 | 305 | output = model(atac = atac_ids, rna = rna_ids, src_key_padding_mask = padding_positions, batch_id = batch_ids, rna_chrs = rna_chrs, atac_chrs = atac_chrs) 306 | 307 | mvc_loss = mvc_loss_fn(output['mvc_pred'].transpose(1, 2), rna_values) * mvc_weight 308 | cell_loss = cross_entropy_loss(output['cell_pred'], cell_ids) * cell_type_weight 309 | loss = mvc_loss + cell_loss 310 | 311 | running_loss['mvc'] += mvc_loss.item() 312 | running_loss['cell'] += cell_loss.item() 313 | 
running_loss['total'] += loss.item() 314 | 315 | loss = loss / GRADIENT_ACCUMULATION 316 | 317 | ###### 318 | 319 | scaler.scale(loss).backward() 320 | else: 321 | # with torch.cuda.amp.autocast(enabled=train_config['amp'], dtype= torch.bfloat16): 322 | # output = model(atac = atac_ids, rna = rna_ids, src_key_padding_mask = padding_positions, batch_id = batch_ids, rna_chrs = rna_chrs, atac_chrs = atac_chrs) 323 | 324 | # mvc_loss = mvc_loss_fn(output['mvc_pred'].transpose(1, 2), rna_values) * mvc_weight 325 | # cell_loss = cross_entropy_loss(output['cell_pred'], cell_ids) * cell_type_weight 326 | # loss = mvc_loss + cell_loss 327 | 328 | # running_loss['mvc'] += mvc_loss.item() 329 | # running_loss['cell'] += cell_loss.item() 330 | # running_loss['total'] += loss.item() 331 | # if is_master: 332 | # tb_logger.add_scalar('train/mvc_loss', mvc_loss.item(), steps) 333 | # tb_logger.add_scalar('train/cell_loss', cell_loss.item(), steps) 334 | # tb_logger.add_scalar('train/total_loss', loss.item(), steps) 335 | # logger.info(f'Epoch: {i} | Step: {index} | MVC Loss: {mvc_loss:.4f} | Cell Type Loss: {cell_loss:.4f} | Total Loss: {loss:.4f}') 336 | # loss = loss / GRADIENT_ACCUMULATION 337 | 338 | output = model(atac = atac_ids, rna = rna_ids, src_key_padding_mask = padding_positions, batch_id = batch_ids, rna_chrs = rna_chrs, atac_chrs = atac_chrs) 339 | 340 | mvc_loss = mvc_loss_fn(output['mvc_pred'].transpose(1, 2), rna_values) * mvc_weight 341 | cell_loss = cross_entropy_loss(output['cell_pred'], cell_ids) * cell_type_weight 342 | loss = mvc_loss + cell_loss 343 | 344 | running_loss['mvc'] += mvc_loss.item() 345 | running_loss['cell'] += cell_loss.item() 346 | running_loss['total'] += loss.item() 347 | if is_master: 348 | tb_logger.add_scalar('train/mvc_loss', mvc_loss.item(), steps) 349 | tb_logger.add_scalar('train/cell_loss', cell_loss.item(), steps) 350 | tb_logger.add_scalar('train/total_loss', loss.item(), steps) 351 | logger.info(f'Epoch: {i} | Step: {index} | MVC Loss: {mvc_loss:.4f} | Cell Type Loss: {cell_loss:.4f} | Total Loss: {loss:.4f}') 352 | loss = loss / GRADIENT_ACCUMULATION 353 | ####### 354 | 355 | scaler.scale(loss).backward() 356 | scaler.unscale_(optimizer) 357 | torch.nn.utils.clip_grad_norm_(model.parameters(), int(1e2)) 358 | scaler.step(optimizer) 359 | scaler.update() 360 | optimizer.zero_grad() 361 | # cell type accuracy 362 | type_pred = softmax(output['cell_pred']) 363 | type_pred = type_pred.argmax(dim=-1) 364 | cum_acc_cell += (type_pred.eq(cell_ids)).sum().item() / len(cell_ids) 365 | 366 | value_pred = softmax(output['mvc_pred']).argmax(dim=-1) 367 | # expression value accuracy 368 | non_pad_idx = rna_values.ne(pad['value']) 369 | non_pad_pred = value_pred[non_pad_idx] 370 | non_pad_label = rna_values[non_pad_idx] 371 | cum_acc_value += (non_pad_pred.eq(non_pad_label).sum().item()) / non_pad_label.size(0) 372 | 373 | cum_acc_cell = 100 * cum_acc_cell / index 374 | cum_acc_cell = get_reduced(cum_acc_cell, rank, 0, world_size) 375 | 376 | cum_acc_value = 100 * cum_acc_value / index 377 | cum_acc_value = get_reduced(cum_acc_value, rank, 0, world_size) 378 | for key in running_loss: 379 | running_loss[key] = running_loss[key] / index 380 | running_loss[key] = get_reduced(running_loss[key], rank, 0, world_size) 381 | if is_master: 382 | logger.info(f'Epoch: {i} | MVC Loss: {running_loss["mvc"]:.4f} | Cell Type Loss: {running_loss["cell"]:.4f} | Total Loss: {running_loss["total"]:.4f} | Cell Type Accuracy: {cum_acc_cell:.2f} | Expression Value Accuracy: 
{cum_acc_value:.2f}') 383 | dist.barrier() 384 | scheduler.step() 385 | # del train_set, train_sampler, train_loader 386 | 387 | if i % valid_config['freq'] == 0: 388 | if is_master: 389 | logger.info('#### Validation ####') 390 | model.eval() 391 | dist.barrier() 392 | running_loss = {'mvc': 0.0, 'cell': 0.0, 'total': 0.0} 393 | 394 | cum_acc_cell = 0.0 395 | cum_acc_value = 0.0 396 | 397 | with torch.no_grad(): 398 | for index, batch in enumerate(val_loader): 399 | index += 1 400 | 401 | rna_values = batch['rna_values'].to(device) 402 | rna_ids = batch['rna_ids'].to(device) 403 | atac_ids = batch['atac_ids'].to(device) 404 | cell_ids = batch['cell_ids'].to(device) 405 | batch_ids = batch['batch_ids'].to(device) 406 | rna_chrs = batch['rna_chrs'].to(device) 407 | atac_chrs = batch['atac_chrs'].to(device) 408 | 409 | padding_positions = atac_ids.eq(atac_vocab[pad['token']]) 410 | # with torch.cuda.amp.autocast(enabled=train_config['amp'], dtype= torch.bfloat16): 411 | # output = model(atac = atac_ids, rna = rna_ids, src_key_padding_mask = padding_positions, batch_id = batch_ids, rna_chrs = rna_chrs, atac_chrs = atac_chrs) 412 | 413 | # mvc_loss = mvc_loss_fn(output['mvc_pred'].transpose(1, 2), rna_values) * mvc_weight 414 | # cell_loss = cross_entropy_loss(output['cell_pred'], cell_ids) * cell_type_weight 415 | # loss = mvc_loss + cell_loss 416 | 417 | output = model(atac = atac_ids, rna = rna_ids, src_key_padding_mask = padding_positions, batch_id = batch_ids, rna_chrs = rna_chrs, atac_chrs = atac_chrs) 418 | 419 | mvc_loss = mvc_loss_fn(output['mvc_pred'].transpose(1, 2), rna_values) * mvc_weight 420 | cell_loss = cross_entropy_loss(output['cell_pred'], cell_ids) * cell_type_weight 421 | loss = mvc_loss + cell_loss 422 | 423 | running_loss['mvc'] += mvc_loss.item() 424 | running_loss['cell'] += cell_loss.item() 425 | running_loss['total'] += loss.item() 426 | 427 | type_pred = softmax(output['cell_pred']) 428 | type_pred = type_pred.argmax(dim=-1) 429 | cum_acc_cell += (type_pred.eq(cell_ids)).sum().item() / len(cell_ids) 430 | 431 | value_pred = softmax(output['mvc_pred']).argmax(dim=-1) 432 | # expression value accuracy 433 | non_pad_idx = rna_values.ne(pad['value']) 434 | non_pad_pred = value_pred[non_pad_idx] 435 | non_pad_label = rna_values[non_pad_idx] 436 | cum_acc_value += (non_pad_pred.eq(non_pad_label).sum().item()) / non_pad_label.size(0) 437 | for key in running_loss: 438 | running_loss[key] = running_loss[key] / index 439 | running_loss[key] = get_reduced(running_loss[key], rank, 0, world_size) 440 | cum_acc_cell = 100 * cum_acc_cell / index 441 | cum_acc_cell = get_reduced(cum_acc_cell, rank, 0, world_size) 442 | 443 | cum_acc_value = 100 * cum_acc_value / index 444 | cum_acc_value = get_reduced(cum_acc_value, rank, 0, world_size) 445 | 446 | # del val_set, val_sampler, val_loader 447 | if is_master: 448 | logger.info(f'MVC Loss: {running_loss["mvc"]:.4f} | Cell Type Loss: {running_loss["cell"]:.4f} | Total Loss: {running_loss["total"]:.4f} | Cell Type Accuracy: {cum_acc_cell:.2f} | Expression Value Accuracy: {cum_acc_value:.2f}') 449 | 450 | if i % save_ckpt_freq == 0 or i == EPOCHS: 451 | save_ckpt_fsdp(i, steps, model, optimizer, scheduler, scaler, running_loss["total"], task_name, ckpt_dir, rank=rank) 452 | if is_master: 453 | logger.info('Model saved at epoch: {}'.format(i)) 454 | 455 | 456 | if __name__ == '__main__': 457 | main() 458 | -------------------------------------------------------------------------------- /eval.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import gc 5 | import argparse 6 | import json 7 | import random 8 | import math 9 | import random 10 | from functools import reduce 11 | import numpy as np 12 | import pandas as pd 13 | from scipy import sparse 14 | from sklearn.model_selection import train_test_split 15 | import torch 16 | from torch import nn 17 | from torch.optim import Adam 18 | from torch.nn import functional as F 19 | from tensorboardX import SummaryWriter 20 | from torch.utils.data import DataLoader, Dataset 21 | from torch.utils.data.distributed import DistributedSampler 22 | from torch.nn.parallel import DistributedDataParallel as DDP 23 | import torch.distributed as dist 24 | from model import EpiFoundation 25 | from loss.loss import MaskedMSELoss 26 | from data.dataloader import * 27 | from tokenizer import GeneVocab 28 | import scanpy as sc 29 | import anndata as ad 30 | from utils import * 31 | from memory_profiler import profile 32 | 33 | import yaml 34 | 35 | torch.autograd.set_detect_anomaly(True) 36 | 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument("--config", type=str, default='./configs/eval/baseline.yml', help='Config file.') 39 | parser.add_argument("--backend", type=str, default='flash', help='Fast Transformer backend.') 40 | args = parser.parse_args() 41 | 42 | # @profile(precision=4, stream=open("memory_profiler.log", "w+")) 43 | def main(): 44 | # read and parse config file 45 | local_rank = int(os.environ["LOCAL_RANK"]) 46 | with open(args.config, 'r') as f: 47 | config = yaml.load(f, Loader=yaml.FullLoader) 48 | 49 | 50 | train_config = config['train'] 51 | valid_config = config['valid'] 52 | data_config = config['data'] 53 | vocab_config = config['vocab'] 54 | task_name = config['task_name'] 55 | task_folder = './result/{}'.format(task_name) 56 | ckpt_dir = os.path.join(task_folder, 'ckpts') 57 | 58 | 59 | 60 | random_seed = train_config['seed'] 61 | EPOCHS = train_config['epochs'] 62 | BATCH_SIZE = train_config['batch_size'] 63 | GRADIENT_ACCUMULATION = train_config['gradient_accumulation_steps'] 64 | LEARNING_RATE = float(train_config['lr']) 65 | 66 | model_name = train_config['model']['encoder'] 67 | 68 | save_ckpt_freq = train_config['save_ckpt_freq'] if 'save_ckpt_freq' in train_config else 5 69 | resume = train_config['resume'] if 'resume' in train_config else False 70 | 71 | # special tokens 72 | pad = vocab_config['special_tokens']['pad'] 73 | mask = vocab_config['special_tokens']['mask'] 74 | cls = vocab_config['special_tokens']['cls'] 75 | 76 | # distributed setting 77 | dist.init_process_group(backend='nccl') 78 | torch.cuda.set_device(local_rank) 79 | device = torch.device("cuda", local_rank) 80 | world_size = torch.distributed.get_world_size() 81 | seed_all(random_seed + torch.distributed.get_rank()) 82 | is_master = (local_rank == 0) 83 | 84 | # init loggers 85 | logger = set_log(log_dir= os.path.join(task_folder, 'logs')) 86 | tb_logger = SummaryWriter(os.path.join(task_folder, 'tb_logs')) 87 | if is_master: 88 | logger.info(dict2str(config)) 89 | 90 | 91 | rna_vocab = GeneVocab.from_file(vocab_config['rna_path']) 92 | atac_vocab = GeneVocab.from_file(vocab_config['atac_path']) 93 | cell_vocab = GeneVocab.from_file(vocab_config['cell_type_path']) 94 | batch_vocab = GeneVocab.from_file(vocab_config['batch_path']) 95 | chr_vocab = GeneVocab.from_file(vocab_config['chr_path']) 96 | if is_master: 97 | logger.info(f'Rna vocab size: 
{len(rna_vocab)}') 98 | logger.info(f'Atac vocab size: {len(atac_vocab)}') 99 | 100 | if is_master: 101 | logger.info('loading training data') 102 | 103 | 104 | 105 | if is_master: 106 | logger.info('loading validation data') 107 | val_set = PairedSCDataset( 108 | rna_file = data_config['test']['rna_path'], 109 | atac_file= data_config['test']['atac_path'], 110 | rna_key = data_config['test']['rna_key'], 111 | atac_key = data_config['test']['atac_key'], 112 | rna_vocab = rna_vocab, 113 | atac_vocab = atac_vocab, 114 | cell_vocab = cell_vocab, 115 | batch_vocab= batch_vocab, 116 | chr_vocab = chr_vocab, 117 | gene2chr_file= vocab_config['gene2chr_path'], 118 | rna_max_len = train_config['model']['rna_max_len'], 119 | atac_max_len = train_config['model']['atac_max_len'], 120 | pad_token = pad['token'], 121 | rna_pad_value = pad['value'], 122 | cls_token = cls['token'], 123 | logger = logger, 124 | ) 125 | gc.collect() 126 | 127 | val_sampler = SequentialDistributedSampler(val_set, batch_size=BATCH_SIZE, world_size=world_size) 128 | val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, sampler=val_sampler, prefetch_factor=4, num_workers=4) 129 | # create non distributed loader for evaluation 130 | if is_master: 131 | # for evaluation, batch size should be 1 132 | val_non_dist_loader = DataLoader(val_set, batch_size= 1, shuffle=False, num_workers=4) 133 | 134 | if is_master: 135 | logger.info('Creating model') 136 | 137 | model = EpiFoundation( 138 | num_class_cell = len(cell_vocab), 139 | num_rnas = len(rna_vocab), 140 | num_atacs = len(atac_vocab), 141 | num_values= data_config['bin_num'], 142 | num_chrs= len(chr_vocab), 143 | embed_dim = train_config['model']['embedding_dim'], 144 | depth = train_config['model']['num_layers'], 145 | heads = train_config['model']['head_num'], 146 | head_dim = train_config['model']['head_dim'], 147 | encoder = model_name, 148 | dropout = train_config['model']['dropout'], 149 | pad_token_idx_rna = rna_vocab[pad['token']], 150 | pad_token_idx_atac = atac_vocab[pad['token']], 151 | cell_emb_style = train_config['model']['cell_emb_style'], 152 | mvc_arch_style = train_config['model']['mvc_arch_style'], 153 | use_batch_labels = train_config['model']['use_batch_labels'], 154 | batch_label_num= len(batch_vocab), 155 | use_chr_labels= train_config['model']['use_chr_labels'], 156 | transformer_backend = args.backend, 157 | ).to(device) 158 | 159 | # optimizer 160 | optimizer = Adam(model.parameters(), lr=LEARNING_RATE) 161 | 162 | # learning rate scheduler 163 | scheduler = CosineAnnealingWarmupRestarts( 164 | optimizer, 165 | first_cycle_steps=15, 166 | cycle_mult=2, 167 | max_lr=LEARNING_RATE, 168 | min_lr=1e-6, 169 | warmup_steps=5, 170 | gamma=0.9 171 | ) 172 | 173 | if is_master and train_config['metric'] == True: 174 | non_dist_model = model 175 | model = DDP(model, device_ids=[local_rank], output_device=local_rank) 176 | 177 | # scaler = torch.amp.GradScaler(enabled=train_config['amp'].amp) 178 | scaler = torch.cuda.amp.GradScaler(enabled=train_config['amp']) 179 | 180 | # masked_mse_loss = MaskedMSELoss().to(local_rank) 181 | cross_entropy_loss = nn.CrossEntropyLoss(reduction='mean').to(local_rank) 182 | atac_cross_entropy_loss = nn.CrossEntropyLoss(reduction='mean', ignore_index = pad['value']).to(local_rank) 183 | 184 | softmax = nn.Softmax(dim=-1) 185 | 186 | steps = 0 187 | if train_config['model']['pretrained'] is not None: 188 | if is_master: 189 | logger.info('Loading pretrained model from: {}'.format(train_config['model']['pretrained'])) 190 | 
checkpoint = torch.load(train_config['model']['pretrained'], map_location=device) 191 | 192 | # # do not load batch_emb and cls_decoder parameters (when finetuning on different dataset) 193 | pretrained_dict = {k: v for k, v in checkpoint['model'].items() if 'batch_emb' not in k and 'cls_decoder' not in k and 'mvc_decoder' not in k} 194 | # pretrained_dict = {k: v for k, v in checkpoint['model'].items() if 'batch_emb' not in k } 195 | model_dict = model.module.state_dict() 196 | model_dict.update(pretrained_dict) 197 | model.module.load_state_dict(model_dict) 198 | if is_master and train_config['metric'] == True: 199 | non_dist_model.load_state_dict(checkpoint['model']) 200 | # optimizer.load_state_dict(checkpoint['optimizer']) 201 | scheduler.load_state_dict(checkpoint['scheduler']) 202 | scaler.load_state_dict(checkpoint['scaler']) 203 | del checkpoint 204 | gc.collect() 205 | 206 | dist.barrier() 207 | 208 | if train_config['metric'] == True: 209 | if is_master: 210 | non_dist_model.eval() 211 | model.eval() 212 | logger.info('Start evaluation with scib metrices') 213 | test_adata = sc.read_h5ad(data_config['test']['atac_path']) 214 | 215 | batch_labels = [] 216 | embeddings = [] 217 | cell_labels = [] 218 | cell_pred = [] 219 | 220 | cell_acc = 0.0 221 | tbar = tqdm(val_non_dist_loader, desc='Eval') 222 | for index, batch in enumerate(val_non_dist_loader): 223 | tbar.update(1) 224 | index += 1 225 | # if index > 10: 226 | # break 227 | 228 | rna_values = batch['rna_values'].to(device) 229 | rna_ids = batch['rna_ids'].to(device) 230 | atac_ids = batch['atac_ids'].to(device) 231 | cell_ids = batch['cell_ids'].to(device) 232 | batch_ids = batch['batch_ids'].to(device) 233 | rna_chrs = batch['rna_chrs'].to(device) 234 | atac_chrs = batch['atac_chrs'].to(device) 235 | 236 | padding_positions = atac_ids.eq(atac_vocab[pad['token']]) 237 | with torch.cuda.amp.autocast(enabled=train_config['amp'], dtype= torch.bfloat16): 238 | output = non_dist_model(atac = atac_ids, rna = rna_ids, src_key_padding_mask = padding_positions, batch_id = batch_ids, rna_chrs = rna_chrs, atac_chrs = atac_chrs) 239 | pred = softmax(output['cell_pred']).argmax(dim=-1) 240 | cell_pred.append(pred.cpu().numpy()) 241 | 242 | cell_labels.append(cell_ids.cpu().numpy()) 243 | 244 | # print("GT and Pred: ", cell_ids.cpu().numpy(), pred.cpu().numpy()) 245 | 246 | embeddings.append(output['cell_emb'].detach().cpu().numpy()) 247 | batch_labels.append(batch_ids.cpu().numpy()) 248 | tbar.close() 249 | # concatenate the results and transform to numpy array 250 | cell_labels = np.concatenate(cell_labels) 251 | embeddings = np.concatenate(embeddings) 252 | batch_labels = np.concatenate(batch_labels) 253 | 254 | 255 | embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True) 256 | 257 | data_cell_ids = np.array(cell_vocab(test_adata.obs['annot'].tolist())) 258 | data_batch_ids = test_adata.obs['batch'].tolist() 259 | data_batch_ids = np.array(batch_vocab(data_batch_ids)) 260 | 261 | cell_pred = np.concatenate(cell_pred) 262 | cell_acc = (cell_pred == cell_labels).sum() / len(cell_labels) 263 | logger.info(f'Cell type accuracy: {cell_acc:.4f}') 264 | 265 | 266 | # data_cell_ids must be same as cell_labels, ensure the order 267 | assert np.all(data_cell_ids == cell_labels) 268 | assert np.all(data_batch_ids == batch_labels) 269 | # now cell_labels have shape (len(data_cell_ids), 1), convert to (len(data_cell_ids),) 270 | 271 | 272 | test_adata.obsm['embedding'] = embeddings 273 | test_adata.obs['celltype'] = 
test_adata.obs['annot'].astype('category') 274 | test_adata.obs['str_batch'] = test_adata.obs['batch'].astype(str) 275 | eval_scib_metrics(test_adata, logger, batch_key='str_batch', label_key='celltype') 276 | 277 | del non_dist_model, val_non_dist_loader, test_adata 278 | gc.collect() 279 | 280 | dist.barrier() 281 | # if eval cell type acc, first finetune the model with cell type 282 | if train_config['cell_type_epochs'] > 0: 283 | if is_master: 284 | logger.info("Init Training set, train model with cell type") 285 | train_set = PairedSCDataset( 286 | rna_file = data_config['train']['rna_path'], 287 | atac_file= data_config['train']['atac_path'], 288 | rna_key = data_config['train']['rna_key'], 289 | atac_key = data_config['train']['atac_key'], 290 | rna_vocab = rna_vocab, 291 | atac_vocab = atac_vocab, 292 | cell_vocab = cell_vocab, 293 | batch_vocab= batch_vocab, 294 | chr_vocab = chr_vocab, 295 | gene2chr_file= vocab_config['gene2chr_path'], 296 | rna_max_len = train_config['model']['rna_max_len'], 297 | atac_max_len = train_config['model']['atac_max_len'], 298 | pad_token = pad['token'], 299 | rna_pad_value = pad['value'], 300 | cls_token = cls['token'], 301 | logger = logger, 302 | ) 303 | 304 | gc.collect() 305 | train_sampler = DistributedSampler(train_set) 306 | train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, sampler=train_sampler, prefetch_factor=4, num_workers=4) 307 | 308 | for i in range(1, train_config['cell_type_epochs'] + 1): 309 | train_loader.sampler.set_epoch(i) 310 | if is_master: 311 | logger.info('Training with {} samples, steps: {}'.format(len(train_loader.dataset), len(train_loader))) 312 | model.train() 313 | dist.barrier() 314 | running_loss = {'mvc': 0.0, 'cell': 0.0, 'total': 0.0} 315 | cum_acc_cell = 0.0 316 | cum_acc_value = 0.0 317 | for index, batch in enumerate(train_loader): 318 | index += 1 319 | steps += 1 320 | 321 | rna_values = batch['rna_values'].to(device) 322 | rna_ids = batch['rna_ids'].to(device) 323 | atac_ids = batch['atac_ids'].to(device) 324 | cell_ids = batch['cell_ids'].to(device) 325 | batch_ids = batch['batch_ids'].to(device) 326 | rna_chrs = batch['rna_chrs'].to(device) 327 | atac_chrs = batch['atac_chrs'].to(device) 328 | 329 | padding_positions = atac_ids.eq(atac_vocab[pad['token']]) 330 | 331 | if index % GRADIENT_ACCUMULATION != 0 and index != len(train_loader): 332 | with model.no_sync(): 333 | with torch.cuda.amp.autocast(enabled=train_config['amp'], dtype= torch.bfloat16): 334 | # finetue using all expression values, do not mask 335 | output = model(atac = atac_ids, rna = rna_ids, src_key_padding_mask = padding_positions, batch_id = batch_ids, rna_chrs = rna_chrs, atac_chrs = atac_chrs) 336 | 337 | mvc_loss = atac_cross_entropy_loss(output['mvc_pred'].transpose(1, 2), rna_values) 338 | cell_loss = cross_entropy_loss(output['cell_pred'], cell_ids) 339 | # only train the cell loss 340 | loss = cell_loss + mvc_loss * 0.0 341 | 342 | running_loss['mvc'] += mvc_loss.item() 343 | running_loss['cell'] += cell_loss.item() 344 | running_loss['total'] += loss.item() 345 | 346 | loss = loss / GRADIENT_ACCUMULATION 347 | scaler.scale(loss).backward() 348 | else: 349 | with torch.cuda.amp.autocast(enabled=train_config['amp'], dtype= torch.bfloat16): 350 | output = model(atac = atac_ids, rna = rna_ids, src_key_padding_mask = padding_positions, batch_id = batch_ids, rna_chrs = rna_chrs, atac_chrs = atac_chrs) 351 | 352 | mvc_loss = atac_cross_entropy_loss(output['mvc_pred'].transpose(1, 2), rna_values) 353 | 354 | cell_loss = 
cross_entropy_loss(output['cell_pred'], cell_ids) 355 | loss = cell_loss + mvc_loss * 0.0 356 | 357 | running_loss['mvc'] += mvc_loss.item() 358 | running_loss['cell'] += cell_loss.item() 359 | running_loss['total'] += loss.item() 360 | if is_master: 361 | tb_logger.add_scalar('train/mvc_loss', mvc_loss.item(), steps) 362 | tb_logger.add_scalar('train/cell_loss', cell_loss.item(), steps) 363 | tb_logger.add_scalar('train/total_loss', loss.item(), steps) 364 | logger.info(f'Epoch: {i} | Step: {index} | MVC Loss: {mvc_loss:.4f} | Cell Type Loss: {cell_loss:.4f} | Total Loss: {loss:.4f}') 365 | loss = loss / GRADIENT_ACCUMULATION 366 | scaler.scale(loss).backward() 367 | scaler.unscale_(optimizer) 368 | torch.nn.utils.clip_grad_norm_(model.parameters(), int(1e2)) 369 | scaler.step(optimizer) 370 | scaler.update() 371 | optimizer.zero_grad() 372 | # cell type accuracy 373 | type_pred = softmax(output['cell_pred']) 374 | type_pred = type_pred.argmax(dim=-1) 375 | cum_acc_cell += (type_pred.eq(cell_ids)).sum().item() / len(cell_ids) 376 | 377 | value_pred = softmax(output['mvc_pred']).argmax(dim=-1) 378 | # expression value accuracy 379 | non_pad_idx = rna_values.ne(pad['value']) 380 | non_pad_pred = value_pred[non_pad_idx] 381 | non_pad_label = rna_values[non_pad_idx] 382 | cum_acc_value += (non_pad_pred.eq(non_pad_label).sum().item()) / non_pad_label.size(0) 383 | 384 | cum_acc_cell = 100 * cum_acc_cell / index 385 | cum_acc_cell = get_reduced(cum_acc_cell, local_rank, 0, world_size) 386 | 387 | cum_acc_value = 100 * cum_acc_value / index 388 | cum_acc_value = get_reduced(cum_acc_value, local_rank, 0, world_size) 389 | for key in running_loss: 390 | running_loss[key] = running_loss[key] / index 391 | running_loss[key] = get_reduced(running_loss[key], local_rank, 0, world_size) 392 | if is_master: 393 | logger.info(f'Epoch: {i} | MVC Loss: {running_loss["mvc"]:.4f} | Cell Type Loss: {running_loss["cell"]:.4f} | Total Loss: {running_loss["total"]:.4f} | Cell Type Accuracy: {cum_acc_cell:.2f} | Expression Value Accuracy: {cum_acc_value:.2f}') 394 | dist.barrier() 395 | scheduler.step() 396 | 397 | # del train_set, train_sampler, train_loader 398 | if is_master and i % save_ckpt_freq == 0: 399 | logger.info("Saving the finetuned model") 400 | save_ckpt(i, steps, model, optimizer, scheduler, scaler, running_loss["total"], task_name, ckpt_dir) 401 | 402 | if i % valid_config['freq'] == 0: 403 | model.eval() 404 | dist.barrier() 405 | running_loss = {'mvc': 0.0, 'cell': 0.0, 'total': 0.0} 406 | 407 | cum_acc_cell = 0.0 408 | cum_acc_value = 0.0 409 | if is_master: 410 | logger.info('Start validation') 411 | with torch.no_grad(): 412 | for index, batch in enumerate(val_loader): 413 | index += 1 414 | 415 | rna_values = batch['rna_values'].to(device) 416 | rna_ids = batch['rna_ids'].to(device) 417 | atac_ids = batch['atac_ids'].to(device) 418 | cell_ids = batch['cell_ids'].to(device) 419 | batch_ids = batch['batch_ids'].to(device) 420 | rna_chrs = batch['rna_chrs'].to(device) 421 | atac_chrs = batch['atac_chrs'].to(device) 422 | 423 | padding_positions = atac_ids.eq(atac_vocab[pad['token']]) 424 | with torch.cuda.amp.autocast(enabled=train_config['amp'], dtype= torch.bfloat16): 425 | output = model(atac = atac_ids, rna = rna_ids, src_key_padding_mask = padding_positions, batch_id = batch_ids, rna_chrs = rna_chrs, atac_chrs = atac_chrs) 426 | 427 | mvc_loss = atac_cross_entropy_loss(output['mvc_pred'].transpose(1, 2), rna_values) 428 | cell_loss = cross_entropy_loss(output['cell_pred'], cell_ids) 429 | 
loss = cell_loss + mvc_loss * 0.0 430 | 431 | running_loss['mvc'] += mvc_loss.item() 432 | running_loss['cell'] += cell_loss.item() 433 | running_loss['total'] += loss.item() 434 | 435 | type_pred = softmax(output['cell_pred']) 436 | type_pred = type_pred.argmax(dim=-1) 437 | cum_acc_cell += (type_pred.eq(cell_ids)).sum().item() / len(cell_ids) 438 | 439 | value_pred = softmax(output['mvc_pred']).argmax(dim=-1) 440 | # expression value accuracy 441 | non_pad_idx = rna_values.ne(pad['value']) 442 | non_pad_pred = value_pred[non_pad_idx] 443 | non_pad_label = rna_values[non_pad_idx] 444 | cum_acc_value += (non_pad_pred.eq(non_pad_label).sum().item()) / non_pad_label.size(0) 445 | # break 446 | for key in running_loss: 447 | running_loss[key] = running_loss[key] / index 448 | running_loss[key] = get_reduced(running_loss[key], local_rank, 0, world_size) 449 | cum_acc_cell = 100 * cum_acc_cell / index 450 | cum_acc_cell = get_reduced(cum_acc_cell, local_rank, 0, world_size) 451 | 452 | cum_acc_value = 100 * cum_acc_value / index 453 | cum_acc_value = get_reduced(cum_acc_value, local_rank, 0, world_size) 454 | # del val_set, val_sampler, val_loader 455 | if is_master: 456 | logger.info(f'MVC Loss: {running_loss["mvc"]:.4f} | Cell Type Loss: {running_loss["cell"]:.4f} | Total Loss: {running_loss["total"]:.4f} | Cell Type Accuracy: {cum_acc_cell:.2f} | Expression Value Accuracy: {cum_acc_value:.2f}') 457 | 458 | if __name__ == '__main__': 459 | main() 460 | --------------------------------------------------------------------------------
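Both the finetuning loop above and the cell-type finetuning stage in eval.py follow the same distributed update idiom: LOCAL_RANK is read from the environment (so a launcher such as torchrun is assumed), gradients are accumulated over GRADIENT_ACCUMULATION micro-batches with the all-reduce suppressed via model.no_sync() on intermediate steps, and torch.cuda.amp.GradScaler scales the loss before the gradients are unscaled, clipped, and applied. The snippet below is a minimal, self-contained sketch of that idiom only; the linear model, random batches, and hyperparameters are illustrative placeholders, not code or settings from this repository.

# Minimal sketch of the gradient-accumulation / no_sync / GradScaler idiom used in the
# loops above. Placeholder model and data; launch with e.g. torchrun --nproc_per_node=1 sketch.py
import os
from contextlib import nullcontext

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    local_rank = int(os.environ["LOCAL_RANK"])
    use_cuda = torch.cuda.is_available()
    dist.init_process_group(backend="nccl" if use_cuda else "gloo")
    device = torch.device("cuda", local_rank) if use_cuda else torch.device("cpu")
    if use_cuda:
        torch.cuda.set_device(local_rank)

    # Placeholder network standing in for the real model.
    model = DDP(nn.Linear(32, 4).to(device), device_ids=[local_rank] if use_cuda else None)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)  # no-op when CUDA is unavailable
    loss_fn = nn.CrossEntropyLoss()
    accumulation = 4

    # Random stand-in batches; the real scripts iterate a DataLoader with a DistributedSampler.
    batches = [(torch.randn(8, 32), torch.randint(0, 4, (8,))) for _ in range(8)]
    for step, (x, y) in enumerate(batches, start=1):
        x, y = x.to(device), y.to(device)
        is_update_step = (step % accumulation == 0) or (step == len(batches))
        # Skip the DDP all-reduce on non-update steps so gradients just accumulate locally.
        ctx = nullcontext() if is_update_step else model.no_sync()
        with ctx:
            # The scripts above additionally pass dtype=torch.bfloat16 to autocast.
            with torch.cuda.amp.autocast(enabled=use_cuda):
                loss = loss_fn(model(x), y) / accumulation
            scaler.scale(loss).backward()
        if is_update_step:
            scaler.unscale_(optimizer)  # unscale before gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 100.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Dividing each micro-batch loss by the accumulation factor, as the loops above also do, keeps the accumulated gradient equivalent to that of a single large batch.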