├── .gitignore
├── LICENSE
├── README.md
├── data
    ├── Zheng68K.h5ad.gz
    └── panglao_10000.h5ad.gz
├── scBERT
    ├── collate_final_results_finetune.ipynb
    ├── dist_finetune.py
    ├── dist_finetune_fewshot.py
    ├── dist_finetune_nopretraining.py
    ├── dist_pretrain.py
    ├── nog2v_explore.ipynb
    ├── performer_pytorch
    │   ├── .ipynb_checkpoints
    │   │   └── performer_pytorch-checkpoint.py
    │   ├── __init__.py
    │   ├── performer_pytorch.py
    │   └── reversible.py
    ├── preprocess.py
    ├── scbert_baselines_LR-MacParland.ipynb
    ├── scbert_baselines_LR.ipynb
    ├── scbert_environment.yml
    └── utils.py
└── scGPT
    ├── create_figures_and_tables.ipynb
    ├── scGPT_baselines_LR.py
    ├── scGPT_run_all_celltypeannot_fewshot.py
    ├── scGPT_run_all_celltypeannot_nopretrain_freeze.py
    ├── scGPT_run_all_celltypeannot_nopretrain_nofreeze.py
    └── scgpt_environment.yml

/.gitignore: -------------------------------------------------------------------------------- 1 | scBERT/performer_pytorch/__pycache__/* 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024, Rebecca Boiarsky 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 
18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sc-foundation-eval 2 | Code for evaluating single cell foundation models scBERT and scGPT. This code was used for the analysis presented in A Deep Dive into Single-Cell RNA Sequencing Foundation Models, bioRxiv https://doi.org/10.1101/2023.10.19.563100. 3 | 4 | The repo is organized by model. Below are descriptions of the scripts and analysis code included for each: 5 | 6 | ## scBERT 7 | * [performer_pytorch/](scBERT/performer_pytorch) contains the code for the scBERT model 8 | * preprocess.py is a script provided by the scBERT authors, used to preprocess a dataset for fine-tuning 9 | * dist_pretrain.py: used to pre-train scBERT from scratch 10 | * dist_finetune.py: used to run fine-tuning (cell type annotation) for scBERT (Table 1). For our "no gene2vec" ablation (Table 2), do not pass the argument `--pos_embed_g2v` when calling this script. 
11 | * An example command line call to run fine-tuning: `python dist_finetune.py --model_name finetune_seed2021 --data_path --model_path --world_size=1 --seed=2021 --epochs=10 --grad_acc=1 --batch_size=32 --pos_embed_g2v` 12 | * dist_finetune_nopretraining.py: run our "no pre-training" ablation on scBERT (Table 2) 13 | * Similar command line call as above, but you do not need to supply a model_path, since this script does not load a pre-trained model (if you do supply one, it will be ignored and the ablation will still run properly) 14 | * dist_finetune_fewshot.py: run scBERT fine-tuning on 10, 25, 50, 75, and 100\% of the training data 15 | * scbert_baselines_LR.ipynb shows example code for running the logistic regression baseline for annotating cell types in the Zheng68K PBMC dataset, including the few-shot setting 16 | * nog2v_explore.ipynb: an exploration of pre-training performance for our "no gene2vec" ablation, including the results shown in Table 3 17 | * collate_final_results_finetune.ipynb: collate results of fine-tuning scBERT (full and few-shot settings), logistic regression (full and few-shot settings), and ablation studies to create Tables 1 & 2 and Figure 2 18 | 19 | ## scGPT 20 | * scGPT_baselines_LR.py: runs the logistic regression baseline for annotating cell types in the myeloid, multiple sclerosis, and pancreas datasets, including the few-shot settings 21 | * scGPT_run_all_celltypeannot_fewshot.py: runs scGPT fine-tuning for annotating cell types in the myeloid, multiple sclerosis, and pancreas datasets, including the few-shot settings. Based on the [annotation tutorial](tutorials/Tutorial_Annotation.ipynb) provided in scGPT's GitHub repo. 22 | * scGPT_run_all_celltypeannot_nopretrain{_freeze}.py: run our "no pre-training" ablation on scGPT, with or without freezing pre-decoder weights (Supp. Figure 6, Supp. Table 5) 23 | * create_figures_and_tables.ipynb: take the output of the previous scripts to create Figure 3, Supp. Figure 6, and Supp. 
Table 5 24 | 25 | ## Data Availability 26 | 27 | ### scBERT datasets 28 | * The Zheng68K PBMC data used for finetuning scBERT can be downloaded from our [data/](data) directory. It has been processed using the scBERT/preprocess.py script. 29 | * preprocess.py requires panglao_10000.h5ad, a subsampled version of the panglao dataset on which scBERT was pre-trained, also available in [data/](data). 30 | * The full panglao dataset used for pretraining is too large to host on GitHub, but can be downloaded as per the [instructions](https://github.com/TencentAILabHealthcare/scBERT#data) from the scBERT authors. 31 | 32 | ### scGPT datasets 33 | As provided by the scGPT authors: 34 | - Multiple Sclerosis (M.S.) dataset: [link](https://drive.google.com/drive/folders/1Qd42YNabzyr2pWt9xoY4cVMTAxsNBt4v?usp=sharing) 35 | 36 | - Myeloid (Mye.) dataset: [link](https://drive.google.com/drive/folders/1VbpApQufZq8efFGakW3y8QDDpY9MBoDS?usp=drive_link) 37 | 38 | - hPancreas dataset: [link](https://drive.google.com/drive/folders/1s9XjcSiPC-FYV3VeHrEa7SeZetrthQVV?usp=drive_link) 39 | -------------------------------------------------------------------------------- /data/Zheng68K.h5ad.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clinicalml/sc-foundation-eval/1332b3476e5e07c7143f494b178a51ad1c20baf0/data/Zheng68K.h5ad.gz -------------------------------------------------------------------------------- /data/panglao_10000.h5ad.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clinicalml/sc-foundation-eval/1332b3476e5e07c7143f494b178a51ad1c20baf0/data/panglao_10000.h5ad.gz -------------------------------------------------------------------------------- /scBERT/dist_finetune.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import gc 4 | import argparse 5 | import json 6 | 
import random 7 | import math 8 | import random 9 | from functools import reduce 10 | import numpy as np 11 | import pandas as pd 12 | from scipy import sparse 13 | from sklearn.model_selection import train_test_split, ShuffleSplit, StratifiedShuffleSplit, StratifiedKFold 14 | from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_recall_fscore_support, classification_report 15 | import torch 16 | from torch import nn 17 | from torch.optim import Adam, SGD, AdamW 18 | from torch.nn import functional as F 19 | from torch.optim.lr_scheduler import StepLR, CosineAnnealingWarmRestarts, CyclicLR 20 | from torch.utils.data import DataLoader, Dataset 21 | from torch.utils.data.distributed import DistributedSampler 22 | from torch.nn.parallel import DistributedDataParallel as DDP 23 | import torch.distributed as dist 24 | from tqdm import tqdm 25 | 26 | from performer_pytorch import PerformerLM 27 | import scanpy as sc 28 | import anndata as ad 29 | from utils import * 30 | from datetime import datetime 31 | from time import time 32 | import torch.multiprocessing as mp 33 | from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present 34 | from torch.utils.tensorboard import SummaryWriter 35 | 36 | 37 | def main(): 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument("--master_addr", type=str, default="127.0.0.1", help='Master addr for dist finetune.') 40 | parser.add_argument("--master_port", type=str, default="8500", help='Master port for dist finetune.') 41 | parser.add_argument("--world_size", type=int, default=1, help='Number of GPUs.') 42 | parser.add_argument("--bin_num", type=int, default=5, help='Number of bins.') 43 | parser.add_argument("--gene_num", type=int, default=None, help='Number of genes.') # 16906, if not supplied, will take the number of genes in the supplied training data 44 | parser.add_argument("--epochs", type=int, default=10, help='Number of epochs.') 45 | parser.add_argument("--seed", type=int, 
default=2021, help='Random seed.') 46 | parser.add_argument("--batch_size", type=int, default=32, help='Batch size.') 47 | parser.add_argument("--learning_rate", type=float, default=1e-4, help='Learning rate.') 48 | parser.add_argument("--grad_acc", type=int, default=1, help='Number of gradient accumulation steps.') 49 | parser.add_argument("--valid_every", type=int, default=1, help='Number of training epochs between validation runs.') 50 | parser.add_argument("--pos_embed_g2v", action='store_true', help='Using Gene2vec encoding or not (default no unless this arg is supplied).') 51 | parser.add_argument("--g2v_file", type=str, default='/data/rna_rep_learning/scBERT/gene2vec_16906.npy', help='File containing Gene2vec embeddings') 52 | parser.add_argument("--sin_emb_wavelength", type=float, default = None, help='Wavelength of sinusoidal expression encodings. Defaults to bin_num.') 53 | parser.add_argument("--data_path", type=str, default='/data/rna_rep_learning/scBERT/Zheng68K.h5ad', help='Path of data for finetune.') 54 | parser.add_argument("--model_path", type=str, default='ckpts/panglao_full_with_g2v/2022-May-11-17:38:47/panglao_full_with_g2v_epoch_17.pth', help='Path of pretrained checkpoint to load.') 55 | parser.add_argument("--ft_ckpt", action="store_true", help="Add this flag if continuing to train an already finetuned model.") 56 | parser.add_argument("--ckpt_dir", type=str, default='./ckpts/', help='Directory for saving checkpoints.') 57 | parser.add_argument("--use_continuous", action="store_true", help='If this arg is provided, embed continuous expression values and predict continuous expression values during masking, instead of bucketed.') 58 | parser.add_argument("--model_name", type=str, default='finetune', help='Finetuned model name.') 59 | parser.add_argument("--debug", action="store_true", help="Debug setting: saves to new dir.") 60 | parser.add_argument("--small_geneset", action='store_true', help='Train a smaller model. 
Currently implemented as including genes present in at least 5% of cells.') 61 | args = parser.parse_args() 62 | 63 | model_name = args.model_name 64 | timestamp = datetime.now().strftime("%Y-%b-%d-%H:%M:%S") 65 | ckpt_dir = os.path.join(args.ckpt_dir, model_name, timestamp) 66 | 67 | # Create checkpoint dir if doesn't exist 68 | # NOTE: Done before distributing to avoid process collision 69 | if not (os.path.exists(ckpt_dir)): 70 | os.makedirs(ckpt_dir) 71 | 72 | print("Checkpoint dir: ", ckpt_dir) 73 | 74 | mp.spawn( 75 | distributed_finetune, 76 | args=(args, ckpt_dir, model_name), 77 | nprocs=args.world_size, 78 | join=True, 79 | ) 80 | 81 | 82 | def distributed_finetune(rank, args, ckpt_dir, model_name): 83 | 84 | SEED = args.seed 85 | EPOCHS = args.epochs 86 | BATCH_SIZE = args.batch_size 87 | GRADIENT_ACCUMULATION = args.grad_acc 88 | LEARNING_RATE = args.learning_rate 89 | VALIDATE_EVERY = args.valid_every 90 | CLASS = args.bin_num + 2 91 | POS_EMBED_USING = args.pos_embed_g2v 92 | PATIENCE = 10 93 | UNASSIGN_THRES = 0.0 94 | USE_CONTINUOUS = args.use_continuous 95 | if args.sin_emb_wavelength: 96 | SIN_EMB_WAVELENGTH = args.sin_emb_wavelength 97 | else: 98 | SIN_EMB_WAVELENGTH = args.bin_num 99 | 100 | is_master = rank == 0 101 | master_addr = args.master_addr 102 | master_port = args.master_port 103 | world_size = args.world_size 104 | 105 | # Control sources of randomness 106 | torch.manual_seed(SEED) 107 | random.seed(SEED) 108 | np.random.seed(SEED) 109 | 110 | ### CLASSES FROM ORIGINAL CODE ### 111 | 112 | class SCDataset(Dataset): 113 | def __init__(self, data, label, use_continuous): 114 | super().__init__() 115 | self.data = data 116 | self.label = label 117 | self.use_continuous = use_continuous 118 | 119 | def __getitem__(self, index): 120 | #rand_start = random.randint(0, self.data.shape[0]-1) 121 | full_seq = self.data[index].toarray()[0] 122 | full_seq[full_seq > (CLASS - 2)] = CLASS - 2 123 | full_seq = torch.from_numpy(full_seq).long() 
#long() converts to int64 124 | if(not self.use_continuous): 125 | full_seq = full_seq.long() #long() converts to int64 126 | full_seq = torch.cat((full_seq, torch.tensor([0]))).to(device) #this is the CLS token ? 127 | seq_label = self.label[index] 128 | return full_seq, seq_label 129 | 130 | def __len__(self): 131 | return self.data.shape[0] 132 | 133 | class Identity(torch.nn.Module): 134 | def __init__(self, dropout = 0., h_dim = 100, out_dim = 10): 135 | super(Identity, self).__init__() 136 | self.conv1 = nn.Conv2d(1, 1, (1, 200)) 137 | self.act = nn.ReLU() 138 | self.fc1 = nn.Linear(in_features=SEQ_LEN, out_features=512, bias=True) 139 | self.act1 = nn.ReLU() 140 | self.dropout1 = nn.Dropout(dropout) 141 | self.fc2 = nn.Linear(in_features=512, out_features=h_dim, bias=True) 142 | self.act2 = nn.ReLU() 143 | self.dropout2 = nn.Dropout(dropout) 144 | self.fc3 = nn.Linear(in_features=h_dim, out_features=out_dim, bias=True) 145 | 146 | def forward(self, x): 147 | x = x[:,None,:,:] 148 | x = self.conv1(x) 149 | x = self.act(x) 150 | x = x.view(x.shape[0],-1) 151 | x = self.fc1(x) 152 | x = self.act1(x) 153 | x = self.dropout1(x) 154 | x = self.fc2(x) 155 | x = self.act2(x) 156 | x = self.dropout2(x) 157 | x = self.fc3(x) 158 | return x 159 | 160 | def preprocess_data_smallgeneset(data_path, ref_data_path = '/data/rna_rep_learning/scBERT/panglao_human.h5ad'): 161 | panglao = sc.read_h5ad(ref_data_path) 162 | sc.pp.filter_genes(panglao, min_cells=0.05*len(panglao)) 163 | data = sc.read_h5ad(data_path) 164 | counts = sparse.lil_matrix((data.X.shape[0],panglao.X.shape[1]),dtype=np.float32) 165 | ref = panglao.var_names.tolist() 166 | obj = data.var_names.tolist() 167 | 168 | for i in range(len(ref)): 169 | if ref[i] in obj: 170 | loc = obj.index(ref[i]) 171 | counts[:,i] = data.X[:,loc] 172 | 173 | counts = counts.tocsr() 174 | new = ad.AnnData(X=counts) 175 | new.var_names = ref 176 | new.obs_names = data.obs_names 177 | new.obs = data.obs 178 | new.uns = panglao.uns 
179 | 180 | #sc.pp.filter_cells(new, min_genes=200) 181 | #sc.pp.normalize_total(new, target_sum=1e4) 182 | #sc.pp.log1p(new, base=2) 183 | return(new) 184 | 185 | setup_process(rank, master_addr, master_port, world_size) 186 | device = torch.device("cuda:{}".format(rank)) 187 | 188 | print("Set up distributed processes...") 189 | 190 | data = sc.read_h5ad(args.data_path) 191 | 192 | if args.small_geneset: 193 | data = preprocess_data_smallgeneset(args.data_path) 194 | print("Filtered data to include {} genes present in at least 5% of cells".format(data.shape[1])) 195 | else: 196 | data = sc.read_h5ad(args.data_path) 197 | if args.debug: 198 | debug_seq_len = 5000 199 | data = data[:1000,:debug_seq_len] 200 | GRADIENT_ACCUMULATION = 1 201 | 202 | 203 | label_dict, label = np.unique(np.array(data.obs['celltype']), return_inverse=True) # Convert string cell-type labels to integer codes; label_dict[label] restores the original strings 204 | class_num = np.unique(label, return_counts=True)[1].tolist() 205 | #class_weight = torch.tensor([(1 - (x / sum(class_num))) ** 2 for x in class_num]) #doesn't get used anywhere 206 | class_weight = torch.tensor([1/x for x in class_num]) #use this simpler weighting 207 | label = torch.from_numpy(label) 208 | data = data.X 209 | if args.gene_num is not None: 210 | SEQ_LEN = args.gene_num + 1 211 | else: 212 | SEQ_LEN = data.shape[1] + 1 # num_genes + 1 213 | 214 | sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=2022) #update to hardcode to reduce sources of randomness/isolate performance differences between MODELS 215 | for index_train, index_val in sss.split(data, label): 216 | data_train, label_train = data[index_train], label[index_train] 217 | data_val, label_val = data[index_val], label[index_val] 218 | train_dataset = SCDataset(data_train, label_train, USE_CONTINUOUS) 219 | val_dataset = SCDataset(data_val, label_val, USE_CONTINUOUS) 220 | print("size of training data: {}".format(len(train_dataset))) 221 | 
train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=SEED) 222 | val_sampler = DistributedSampler(val_dataset, shuffle=True, seed=SEED) 223 | 224 | train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler) 225 | val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, sampler=val_sampler) 226 | 227 | print("Loaded data...") 228 | 229 | model = PerformerLM( 230 | num_tokens = CLASS, 231 | dim = 200, 232 | depth = 6, 233 | max_seq_len = SEQ_LEN, 234 | heads = 10, 235 | local_attn_heads = 0, 236 | g2v_position_emb = POS_EMBED_USING, 237 | g2v_file = args.g2v_file, 238 | pred_continuous = USE_CONTINUOUS, 239 | sin_emb_wavelength = SIN_EMB_WAVELENGTH, 240 | ) 241 | model = model.to(device) 242 | 243 | # Load checkpoint onto correct rank 244 | checkpoint = torch.load(args.model_path, map_location=device) 245 | consume_prefix_in_state_dict_if_present(checkpoint['model_state_dict'], "module.") 246 | if args.ft_ckpt: 247 | print("Loaded finetuned ckpt...") 248 | model.to_out = Identity(dropout=0., h_dim=128, out_dim=label_dict.shape[0]) 249 | model.load_state_dict(checkpoint['model_state_dict']) 250 | model = model.to(device) 251 | cur_epoch = checkpoint['epoch'] 252 | # Load optimizer 253 | #optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 254 | # Load scheduler 255 | #scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 256 | 257 | #calculate val accuracy of saved model, so knows whether next epoch is worth saving 258 | model.eval() 259 | dist.barrier() 260 | predictions = [] 261 | truths = [] 262 | with torch.no_grad(): 263 | for index, (data_v, labels_v) in enumerate(val_loader): 264 | index += 1 265 | data_v, labels_v = data_v.to(device), labels_v.to(device) 266 | logits = model(data_v) 267 | softmax = nn.Softmax(dim=-1) 268 | final_prob = softmax(logits) 269 | final = final_prob.argmax(dim=-1) 270 | final[np.amax(np.array(final_prob.cpu()), axis=-1) < UNASSIGN_THRES] = -1 271 | 
predictions.append(final) 272 | truths.append(labels_v) 273 | del data_v, labels_v, logits, final_prob, final 274 | # gather 275 | predictions = distributed_concat(torch.cat(predictions, dim=0), len(val_sampler.dataset), world_size) 276 | truths = distributed_concat(torch.cat(truths, dim=0), len(val_sampler.dataset), world_size) 277 | no_drop = predictions != -1 278 | predictions = np.array((predictions[no_drop]).cpu()) 279 | truths = np.array((truths[no_drop]).cpu()) 280 | max_acc = accuracy_score(truths, predictions) 281 | 282 | else: 283 | print("Loaded pretrained model...") 284 | model.load_state_dict(checkpoint['model_state_dict']) 285 | model.to_out = Identity(dropout=0., h_dim=128, out_dim=label_dict.shape[0]).to(device) 286 | cur_epoch = 0 287 | 288 | for name, param in model.named_parameters(): 289 | param.requires_grad = False 290 | for name, param in model.norm.named_parameters(): 291 | param.requires_grad = True 292 | for name, param in model.performer.net.layers[-1].named_parameters(): #make last layers of performer trainable during fine tuning 293 | param.requires_grad = True 294 | for name, param in model.to_out.named_parameters(): 295 | param.requires_grad = True 296 | 297 | try: 298 | model = DDP(model, device_ids=[device], output_device=device) 299 | 300 | # optimizer 301 | optimizer = Adam(model.parameters(), lr=LEARNING_RATE) 302 | scheduler = CosineAnnealingWarmupRestarts( 303 | optimizer, 304 | first_cycle_steps=15, 305 | cycle_mult=2, 306 | max_lr=LEARNING_RATE, 307 | min_lr=1e-6, 308 | warmup_steps=5, 309 | gamma=0.9 310 | ) 311 | 312 | #implement class weights in loss to handle class imbalance 313 | loss_fn = nn.CrossEntropyLoss(weight=class_weight).to(device) 314 | 315 | dist.barrier() 316 | trigger_times = 0 317 | if not args.ft_ckpt: max_acc = 0.0 #when resuming a finetuned ckpt, keep the validation accuracy computed above instead of resetting it 318 | writer = SummaryWriter(os.path.join(ckpt_dir, 'tensorboard')) 319 | for i in range(cur_epoch+1, EPOCHS+1): 320 | print("{} iterations in train dataloader per epoch".format(len(train_loader))) 321 | 
train_loader.sampler.set_epoch(i) 322 | model.train() 323 | dist.barrier() 324 | running_loss = 0.0 325 | cum_acc = 0.0 326 | for index, (data, labels) in tqdm(enumerate(train_loader)): 327 | index += 1 328 | data, labels = data.to(device), labels.to(device) 329 | if index % GRADIENT_ACCUMULATION != 0: 330 | with model.no_sync(): 331 | logits = model(data) 332 | loss = loss_fn(logits, labels) 333 | loss.backward() 334 | if index % GRADIENT_ACCUMULATION == 0: 335 | logits = model(data) 336 | loss = loss_fn(logits, labels) 337 | loss.backward() 338 | torch.nn.utils.clip_grad_norm_(model.parameters(), int(1e6)) 339 | optimizer.step() 340 | optimizer.zero_grad() 341 | running_loss += loss.item() 342 | softmax = nn.Softmax(dim=-1) 343 | final = softmax(logits) 344 | final = final.argmax(dim=-1) 345 | pred_num = labels.size(0) 346 | correct_num = torch.eq(final, labels).sum(dim=-1) 347 | cum_acc += torch.true_divide(correct_num, pred_num).mean().item() 348 | epoch_loss = running_loss / index 349 | epoch_acc = 100 * cum_acc / index 350 | epoch_loss = get_reduced(epoch_loss, device, 0, world_size) 351 | epoch_acc = get_reduced(epoch_acc, device, 0, world_size) 352 | if is_master: 353 | print(f' == Epoch: {i} | Training Loss: {epoch_loss:.6f} | Accuracy: {epoch_acc:6.4f}% ==') 354 | dist.barrier() 355 | scheduler.step() 356 | 357 | if i % VALIDATE_EVERY == 0: 358 | model.eval() 359 | dist.barrier() 360 | running_loss = 0.0 361 | predictions = [] 362 | truths = [] 363 | with torch.no_grad(): 364 | for index, (data_v, labels_v) in enumerate(val_loader): 365 | index += 1 366 | data_v, labels_v = data_v.to(device), labels_v.to(device) 367 | logits = model(data_v) 368 | loss = loss_fn(logits, labels_v) 369 | running_loss += loss.item() 370 | softmax = nn.Softmax(dim=-1) 371 | final_prob = softmax(logits) 372 | final = final_prob.argmax(dim=-1) 373 | final[np.amax(np.array(final_prob.cpu()), axis=-1) < UNASSIGN_THRES] = -1 374 | predictions.append(final) 375 | 
truths.append(labels_v) 376 | del data_v, labels_v, logits, final_prob, final 377 | # gather 378 | dist.barrier() 379 | predictions = distributed_concat(torch.cat(predictions, dim=0), len(val_sampler.dataset), world_size) 380 | truths = distributed_concat(torch.cat(truths, dim=0), len(val_sampler.dataset), world_size) 381 | no_drop = predictions != -1 382 | predictions = np.array((predictions[no_drop]).cpu()) 383 | truths = np.array((truths[no_drop]).cpu()) 384 | cur_acc = accuracy_score(truths, predictions) 385 | f1 = f1_score(truths, predictions, average='macro') 386 | val_loss = running_loss / index 387 | val_loss = get_reduced(val_loss, device, 0, world_size) 388 | dist.barrier() #hopefully this helps the last epoch get written to tensorboard when training on multiple gpus 389 | if is_master: 390 | print(f' == Epoch: {i} | Validation Loss: {val_loss:.6f} | F1 Score: {f1:.6f} ==') 391 | print(confusion_matrix(truths, predictions)) 392 | print(classification_report(truths, predictions, labels=np.arange(len(label_dict)), target_names=label_dict.tolist(), digits=4)) 393 | 394 | writer.add_scalar('Loss/train', epoch_loss, i) 395 | writer.add_scalar('Accuracy/train', epoch_acc, i) 396 | writer.add_scalar('Loss/val', val_loss, i) 397 | writer.add_scalar('Accuracy/val', cur_acc, i) 398 | writer.add_scalar('F1/val', f1, i) 399 | if cur_acc > max_acc: 400 | max_acc = cur_acc 401 | trigger_times = 0 402 | save_best_ckpt(i, model, optimizer, scheduler, val_loss, model_name, ckpt_dir) 403 | else: 404 | trigger_times += 1 405 | if trigger_times > PATIENCE: 406 | break 407 | del predictions, truths 408 | except Exception as e: 409 | print(e) 410 | pass #so that cleanup() occurs with or without error 411 | cleanup() 412 | 413 | 414 | def setup_process(rank, master_addr, master_port, world_size, backend="nccl"): 415 | print(f"Setting up process: rank={rank} world_size={world_size} backend={backend}.") 416 | print(f"master_addr={master_addr} master_port={master_port}") 417 | 
os.environ["MASTER_ADDR"] = master_addr 418 | os.environ["MASTER_PORT"] = master_port 419 | dist.init_process_group(backend=backend, rank=rank, world_size=world_size) 420 | 421 | 422 | def cleanup(): 423 | dist.destroy_process_group() 424 | 425 | 426 | if __name__=="__main__": 427 | main() 428 | -------------------------------------------------------------------------------- /scBERT/dist_finetune_fewshot.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import gc 4 | import argparse 5 | import json 6 | import random 7 | import math 8 | import random 9 | from functools import reduce 10 | import numpy as np 11 | import pandas as pd 12 | from scipy import sparse 13 | from sklearn.model_selection import train_test_split, ShuffleSplit, StratifiedShuffleSplit, StratifiedKFold 14 | from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_recall_fscore_support, classification_report 15 | import torch 16 | from torch import nn 17 | from torch.optim import Adam, SGD, AdamW 18 | from torch.nn import functional as F 19 | from torch.optim.lr_scheduler import StepLR, CosineAnnealingWarmRestarts, CyclicLR 20 | from torch.utils.data import DataLoader, Dataset 21 | from torch.utils.data.distributed import DistributedSampler 22 | from torch.nn.parallel import DistributedDataParallel as DDP 23 | import torch.distributed as dist 24 | from tqdm import tqdm 25 | 26 | from performer_pytorch import PerformerLM 27 | import scanpy as sc 28 | import anndata as ad 29 | from utils import * 30 | from datetime import datetime 31 | from time import time 32 | import torch.multiprocessing as mp 33 | from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present 34 | from torch.utils.tensorboard import SummaryWriter 35 | 36 | 37 | def main(): 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument("--master_addr", type=str, default="127.0.0.1", help='Master addr for dist 
finetune.') 40 | parser.add_argument("--master_port", type=str, default="8500", help='Master port for dist finetune.') 41 | parser.add_argument("--world_size", type=int, default=1, help='Number of GPUs.') 42 | parser.add_argument("--bin_num", type=int, default=5, help='Number of bins.') 43 | parser.add_argument("--gene_num", type=int, default=None, help='Number of genes.') # 16906, if not supplied, will take the number of genes in the supplied training data 44 | parser.add_argument("--epochs", type=int, default=10, help='Number of epochs.') 45 | parser.add_argument("--seed", type=int, default=2021, help='Random seed.') 46 | parser.add_argument("--batch_size", type=int, default=32, help='Batch size.') 47 | parser.add_argument("--learning_rate", type=float, default=1e-4, help='Learning rate.') 48 | parser.add_argument("--grad_acc", type=int, default=1, help='Number of gradient accumulation steps.') 49 | parser.add_argument("--valid_every", type=int, default=1, help='Number of training epochs between validation runs.') 50 | parser.add_argument("--pos_embed_g2v", action='store_true', help='Using Gene2vec encoding or not (default no unless this arg is supplied).') 51 | parser.add_argument("--g2v_file", type=str, default='/data/rna_rep_learning/scBERT/gene2vec_16906.npy', help='File containing Gene2vec embeddings') 52 | parser.add_argument("--data_path", type=str, default='/data/rna_rep_learning/scBERT/Zheng68K.h5ad', help='Path of data for finetune.') 53 | parser.add_argument("--model_path", type=str, default='ckpts/panglao_full_with_g2v/2022-May-11-17:38:47/panglao_full_with_g2v_epoch_17.pth', help='Path of pretrained checkpoint to load.') 54 | parser.add_argument("--ft_ckpt", action="store_true", help="Add this flag if continuing to train an already finetuned model.") 55 | parser.add_argument("--ckpt_dir", type=str, default='./ckpts/', help='Directory for saving checkpoints.') 56 | parser.add_argument("--nreps", type=int, default=3, help='Number of replicates for 
each data split experiment.') 57 | #parser.add_argument("--sampling_fracs", type=list, default=[1.0, 0.75, 0.5, 0.25, 0.1], help='List of fractions of training data to sample for sample efficiency experiments.') #passing a list doesn't actually work 58 | parser.add_argument("--debug", action="store_true", help="Debug setting: saves to new dir.") 59 | 60 | args = parser.parse_args() 61 | 62 | timestamp = datetime.now().strftime("%Y-%b-%d-%H:%M:%S") 63 | 64 | mp.spawn( 65 | distributed_finetune, 66 | args=(args, timestamp), 67 | nprocs=args.world_size, 68 | join=True, 69 | ) 70 | 71 | 72 | def distributed_finetune(rank, args, timestamp): 73 | 74 | SEED = args.seed 75 | EPOCHS = args.epochs 76 | BATCH_SIZE = args.batch_size 77 | GRADIENT_ACCUMULATION = args.grad_acc 78 | LEARNING_RATE = args.learning_rate 79 | VALIDATE_EVERY = args.valid_every 80 | CLASS = args.bin_num + 2 81 | POS_EMBED_USING = args.pos_embed_g2v 82 | PATIENCE = 10 83 | UNASSIGN_THRES = 0.0 84 | NREPS = args.nreps 85 | SAMPLING_FRACS = [1.0, 0.75, 0.5, 0.25, 0.1] #arg doesn't work currently 86 | 87 | is_master = rank == 0 88 | master_addr = args.master_addr 89 | master_port = args.master_port 90 | world_size = args.world_size 91 | 92 | ### CLASSES FROM ORIGINAL CODE ### 93 | 94 | class SCDataset(Dataset): 95 | def __init__(self, data, label): 96 | super().__init__() 97 | self.data = data 98 | self.label = label 99 | 100 | def __getitem__(self, index): 101 | #rand_start = random.randint(0, self.data.shape[0]-1) 102 | full_seq = self.data[index].toarray()[0] 103 | full_seq[full_seq > (CLASS - 2)] = CLASS - 2 104 | full_seq = torch.from_numpy(full_seq).long() #long() converts to int64 105 | full_seq = torch.cat((full_seq, torch.tensor([0]))).to(device) #this is the CLS token ? 
106 | seq_label = self.label[index] 107 | return full_seq, seq_label 108 | 109 | def __len__(self): 110 | return self.data.shape[0] 111 | 112 | class Identity(torch.nn.Module): 113 | def __init__(self, dropout = 0., h_dim = 100, out_dim = 10): 114 | super(Identity, self).__init__() 115 | self.conv1 = nn.Conv2d(1, 1, (1, 200)) 116 | self.act = nn.ReLU() 117 | self.fc1 = nn.Linear(in_features=SEQ_LEN, out_features=512, bias=True) 118 | self.act1 = nn.ReLU() 119 | self.dropout1 = nn.Dropout(dropout) 120 | self.fc2 = nn.Linear(in_features=512, out_features=h_dim, bias=True) 121 | self.act2 = nn.ReLU() 122 | self.dropout2 = nn.Dropout(dropout) 123 | self.fc3 = nn.Linear(in_features=h_dim, out_features=out_dim, bias=True) 124 | 125 | def forward(self, x): 126 | x = x[:,None,:,:] 127 | x = self.conv1(x) 128 | x = self.act(x) 129 | x = x.view(x.shape[0],-1) 130 | x = self.fc1(x) 131 | x = self.act1(x) 132 | x = self.dropout1(x) 133 | x = self.fc2(x) 134 | x = self.act2(x) 135 | x = self.dropout2(x) 136 | x = self.fc3(x) 137 | return x 138 | 139 | cur_time = time() 140 | setup_process(rank, master_addr, master_port, world_size) 141 | device = torch.device("cuda:{}".format(rank)) 142 | 143 | print("Set up distributed processes...") 144 | 145 | data = sc.read_h5ad(args.data_path) 146 | if args.debug: 147 | debug_seq_len = 5000 148 | data = data[:1000,:debug_seq_len] 149 | GRADIENT_ACCUMULATION = 1 150 | label_dict, label = np.unique(np.array(data.obs['celltype']), return_inverse=True) # Convert string cell-type labels to integer codes; label_dict[label] restores the original strings 151 | class_num = np.unique(label, return_counts=True)[1].tolist() 152 | #class_weight = torch.tensor([(1 - (x / sum(class_num))) ** 2 for x in class_num]) #doesn't get used anywhere 153 | class_weight = torch.tensor([1/x for x in class_num]) #use this simpler weighting 154 | label = torch.from_numpy(label) 155 | data = data.X 156 | if args.gene_num is not None: 157 | SEQ_LEN = args.gene_num + 1 158 | 
else: 159 | SEQ_LEN = data.shape[1] + 1 # num_genes + 1 160 | 161 | 162 | def instantiate_new_model(): 163 | #create new model 164 | model = PerformerLM( 165 | num_tokens = CLASS, 166 | dim = 200, 167 | depth = 6, 168 | max_seq_len = SEQ_LEN, 169 | heads = 10, 170 | local_attn_heads = 0, 171 | g2v_position_emb = POS_EMBED_USING, 172 | g2v_file = args.g2v_file 173 | ) 174 | model = model.to(device) 175 | 176 | # Load checkpoint onto correct rank 177 | checkpoint = torch.load(args.model_path, map_location=device) 178 | consume_prefix_in_state_dict_if_present(checkpoint['model_state_dict'], "module.") 179 | if args.ft_ckpt: 180 | print("Loaded finetuned ckpt...") 181 | model.to_out = Identity(dropout=0., h_dim=128, out_dim=label_dict.shape[0]) 182 | model.load_state_dict(checkpoint['model_state_dict']) 183 | model = model.to(device) 184 | cur_epoch = checkpoint['epoch'] 185 | # Load optimizer 186 | #optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 187 | # Load scheduler 188 | #scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 189 | else: 190 | print("Loaded pretrained model...") 191 | model.load_state_dict(checkpoint['model_state_dict']) 192 | model.to_out = Identity(dropout=0., h_dim=128, out_dim=label_dict.shape[0]).to(device) 193 | cur_epoch = 0 194 | 195 | for name, param in model.named_parameters(): 196 | param.requires_grad = False 197 | for name, param in model.norm.named_parameters(): 198 | param.requires_grad = True 199 | for name, param in model.performer.net.layers[-1].named_parameters(): #make last layers of performer trainable during fine tuning 200 | param.requires_grad = True 201 | for name, param in model.to_out.named_parameters(): 202 | param.requires_grad = True 203 | 204 | # optimizer 205 | optimizer = Adam(model.parameters(), lr=LEARNING_RATE) 206 | scheduler = CosineAnnealingWarmupRestarts( 207 | optimizer, 208 | first_cycle_steps=15, 209 | cycle_mult=2, 210 | max_lr=LEARNING_RATE, 211 | min_lr=1e-6, 212 | warmup_steps=5, 
213 | gamma=0.9 214 | ) 215 | 216 | return(model, optimizer, scheduler, cur_epoch) 217 | 218 | try: 219 | for k in np.arange(NREPS): 220 | # Control sources of randomness - for each run k, different seed is used 221 | # this affects parameter initialization & the subsampling of the [fixed] training set 222 | torch.manual_seed(SEED*k) 223 | random.seed(SEED*k) 224 | np.random.seed(SEED*k) 225 | 226 | accs = [] 227 | f1s = [] 228 | for frac in SAMPLING_FRACS: 229 | #create new model 230 | model, optimizer, scheduler, cur_epoch = instantiate_new_model() 231 | model = DDP(model, device_ids=[device], output_device=device) 232 | 233 | #ckpt dir setup, only need one process to create directory so use dist.barrier() 234 | model_name = "finetune_sampleeff_{}_{}".format(frac, k) 235 | ckpt_dir = os.path.join("ckpts/", "finetune-sampleeff-"+timestamp, model_name) 236 | if is_master: 237 | print("Checkpoint dir: ", ckpt_dir) 238 | if not (os.path.exists(ckpt_dir)): 239 | os.makedirs(ckpt_dir) 240 | dist.barrier() 241 | 242 | 243 | #implement class weights in loss to handle class imbalance 244 | loss_fn = nn.CrossEntropyLoss(weight=class_weight).to(device) 245 | 246 | dist.barrier() 247 | trigger_times = 0 248 | max_acc = 0.0 249 | writer = SummaryWriter(os.path.join(ckpt_dir, 'tensorboard')) 250 | 251 | # attempt to seed dataloader - this is required for true reproducibility 252 | def seed_worker(worker_id): 253 | worker_seed = torch.initial_seed() % 2**32 254 | np.random.seed(worker_seed) 255 | random.seed(worker_seed) 256 | 257 | g = torch.Generator() 258 | g.manual_seed(0) 259 | 260 | #downsample training set 261 | sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=2022) #same val set across all runs 262 | for index_train, index_val in sss.split(data, label): 263 | index_train_small = np.random.choice(index_train, round(index_train.shape[0]*frac), replace=False) # different random subset will be chosen with each k 264 | data_train, label_train =
data[index_train_small], label[index_train_small] 265 | train_dataset = SCDataset(data_train, label_train) 266 | data_val, label_val = data[index_val], label[index_val] 267 | val_dataset = SCDataset(data_val, label_val) 268 | train_sampler = DistributedSampler(train_dataset, shuffle=True) 269 | train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler, worker_init_fn=seed_worker, generator=g) 270 | val_sampler = DistributedSampler(val_dataset, shuffle=True) 271 | val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, sampler=val_sampler, worker_init_fn=seed_worker, generator=g) 272 | 273 | print("Loaded data...") 274 | 275 | for i in range(cur_epoch+1, EPOCHS+1): 276 | print("{} iterations in train dataloader per epoch".format(len(train_loader))) 277 | train_loader.sampler.set_epoch(i) 278 | model.train() 279 | dist.barrier() 280 | running_loss = 0.0 281 | cum_acc = 0.0 282 | for index, (data_t, labels_t) in tqdm(enumerate(train_loader)): 283 | index += 1 284 | data_t, labels_t = data_t.to(device), labels_t.to(device) 285 | if index % GRADIENT_ACCUMULATION != 0: 286 | with model.no_sync(): 287 | logits = model(data_t) 288 | loss = loss_fn(logits, labels_t) 289 | loss.backward() 290 | if index % GRADIENT_ACCUMULATION == 0: 291 | logits = model(data_t) 292 | loss = loss_fn(logits, labels_t) 293 | loss.backward() 294 | torch.nn.utils.clip_grad_norm_(model.parameters(), int(1e6)) 295 | optimizer.step() 296 | optimizer.zero_grad() 297 | running_loss += loss.item() 298 | softmax = nn.Softmax(dim=-1) 299 | final = softmax(logits) 300 | final = final.argmax(dim=-1) 301 | pred_num = labels_t.size(0) 302 | correct_num = torch.eq(final, labels_t).sum(dim=-1) 303 | cum_acc += torch.true_divide(correct_num, pred_num).mean().item() 304 | epoch_loss = running_loss / index 305 | epoch_acc = 100 * cum_acc / index 306 | epoch_loss = get_reduced(epoch_loss, device, 0, world_size) 307 | epoch_acc = get_reduced(epoch_acc, device, 0, world_size) 308 | if 
is_master: 309 | print(f' == Epoch: {i} | Training Loss: {epoch_loss:.6f} | Accuracy: {epoch_acc:6.4f}% ==') 310 | dist.barrier() 311 | scheduler.step() 312 | 313 | if i % VALIDATE_EVERY == 0: 314 | model.eval() 315 | dist.barrier() 316 | running_loss = 0.0 317 | predictions = [] 318 | truths = [] 319 | with torch.no_grad(): 320 | for index, (data_v, labels_v) in enumerate(val_loader): 321 | index += 1 322 | data_v, labels_v = data_v.to(device), labels_v.to(device) 323 | logits = model(data_v) 324 | loss = loss_fn(logits, labels_v) 325 | running_loss += loss.item() 326 | softmax = nn.Softmax(dim=-1) 327 | final_prob = softmax(logits) 328 | final = final_prob.argmax(dim=-1) 329 | final[np.amax(np.array(final_prob.cpu()), axis=-1) < UNASSIGN_THRES] = -1 330 | predictions.append(final) 331 | truths.append(labels_v) 332 | del data_v, labels_v, logits, final_prob, final 333 | # gather 334 | predictions = distributed_concat(torch.cat(predictions, dim=0), len(val_sampler.dataset), world_size) 335 | truths = distributed_concat(torch.cat(truths, dim=0), len(val_sampler.dataset), world_size) 336 | no_drop = predictions != -1 337 | predictions = np.array((predictions[no_drop]).cpu()) 338 | truths = np.array((truths[no_drop]).cpu()) 339 | cur_acc = accuracy_score(truths, predictions) 340 | f1 = f1_score(truths, predictions, average='macro') 341 | val_loss = running_loss / index 342 | val_loss = get_reduced(val_loss, device, 0, world_size) 343 | if is_master: 344 | print(f' == Epoch: {i} | Validation Loss: {val_loss:.6f} | F1 Score: {f1:.6f} | Accuracy: {cur_acc:.3f} ==') 345 | print(confusion_matrix(truths, predictions)) 346 | print(classification_report(truths, predictions, labels=np.arange(len(label_dict)), target_names=label_dict.tolist(), digits=4)) 347 | 348 | duration = time() - cur_time 349 | cur_time = time() 350 | 351 | writer.add_scalar('Loss/train', epoch_loss, i) 352 | writer.add_scalar('Accuracy/train', epoch_acc, i) 353 | writer.add_scalar('Loss/val', val_loss, 
i) 354 | writer.add_scalar('Accuracy/val', cur_acc, i) 355 | writer.add_scalar('F1/val', f1, i) 356 | if cur_acc > max_acc: 357 | max_acc = cur_acc 358 | trigger_times = 0 359 | save_best_ckpt(i, model, optimizer, scheduler, val_loss, model_name, ckpt_dir) 360 | else: 361 | trigger_times += 1 362 | if trigger_times > PATIENCE: 363 | break 364 | del predictions, truths 365 | accs.append(cur_acc) 366 | f1s.append(f1) 367 | if is_master: 368 | print("fraction of training set:") 369 | print(SAMPLING_FRACS) 370 | print("effective fraction of full dataset:") 371 | print([np.round(s*0.8,2) for s in SAMPLING_FRACS]) #size of training set as fraction of overall dataset size 372 | print(accs) 373 | print(f1s) 374 | with open('logs/finetune_sampleeff_{}_{}.txt'.format(k, timestamp), 'a') as fd: 375 | fd.write(','.join([str(a) for a in SAMPLING_FRACS])+'\n') 376 | fd.write(','.join([str(np.round(s*0.8,2)) for s in SAMPLING_FRACS])+'\n') 377 | fd.write(','.join(map(lambda x: str(x), accs))+'\n') 378 | fd.write(','.join(map(lambda x: str(x), f1s))+'\n') 379 | except Exception as e: 380 | print(e) 381 | pass #so that cleanup() occurs with or without error 382 | cleanup() 383 | 384 | 385 | def setup_process(rank, master_addr, master_port, world_size, backend="nccl"): 386 | print(f"Setting up process: rank={rank} world_size={world_size} backend={backend}.") 387 | print(f"master_addr={master_addr} master_port={master_port}") 388 | os.environ["MASTER_ADDR"] = master_addr 389 | os.environ["MASTER_PORT"] = master_port 390 | dist.init_process_group(backend=backend, rank=rank, world_size=world_size) 391 | 392 | 393 | def cleanup(): 394 | dist.destroy_process_group() 395 | 396 | 397 | if __name__=="__main__": 398 | main() 399 | -------------------------------------------------------------------------------- /scBERT/dist_finetune_nopretraining.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import gc 4 | 
import argparse 5 | import json 6 | import random 7 | import math 8 | import random 9 | from functools import reduce 10 | import numpy as np 11 | import pandas as pd 12 | from scipy import sparse 13 | from sklearn.model_selection import train_test_split, ShuffleSplit, StratifiedShuffleSplit, StratifiedKFold 14 | from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_recall_fscore_support, classification_report 15 | import torch 16 | from torch import nn 17 | from torch.optim import Adam, SGD, AdamW 18 | from torch.nn import functional as F 19 | from torch.optim.lr_scheduler import StepLR, CosineAnnealingWarmRestarts, CyclicLR 20 | from torch.utils.data import DataLoader, Dataset 21 | from torch.utils.data.distributed import DistributedSampler 22 | from torch.nn.parallel import DistributedDataParallel as DDP 23 | import torch.distributed as dist 24 | from tqdm import tqdm 25 | 26 | from performer_pytorch import PerformerLM 27 | import scanpy as sc 28 | import anndata as ad 29 | from utils import * 30 | from datetime import datetime 31 | from time import time 32 | import torch.multiprocessing as mp 33 | from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present 34 | from torch.utils.tensorboard import SummaryWriter 35 | 36 | 37 | def main(): 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument("--master_addr", type=str, default="127.0.0.1", help='Master addr for dist finetune.') 40 | parser.add_argument("--master_port", type=str, default="8500", help='Master port for dist finetune.') 41 | parser.add_argument("--world_size", type=int, default=1, help='Number of GPUs.') 42 | parser.add_argument("--bin_num", type=int, default=5, help='Number of bins.') 43 | parser.add_argument("--gene_num", type=int, default=None, help='Number of genes.') # 16906, if not supplied, will take the number of genes in the supplied training data 44 | parser.add_argument("--epochs", type=int, default=10, help='Number of epochs.') 45 | 
parser.add_argument("--seed", type=int, default=2021, help='Random seed.') 46 | parser.add_argument("--batch_size", type=int, default=32, help='Batch size.') 47 | parser.add_argument("--learning_rate", type=float, default=1e-4, help='Learning rate.') 48 | parser.add_argument("--grad_acc", type=int, default=1, help='Number of gradient accumulation steps.') 49 | parser.add_argument("--valid_every", type=int, default=1, help='Number of training epochs between validations.') 50 | parser.add_argument("--pos_embed_g2v", action='store_true', help='Using Gene2vec encoding or not (default no unless this arg is supplied).') 51 | parser.add_argument("--g2v_file", type=str, default='/data/rna_rep_learning/scBERT/gene2vec_16906.npy', help='File containing Gene2vec embeddings') 52 | parser.add_argument("--sin_emb_wavelength", type=float, default = None, help='Wavelength of sinusoidal expression encodings. Defaults to bin_num.') 53 | parser.add_argument("--data_path", type=str, default='/data/rna_rep_learning/scBERT/Zheng68K.h5ad', help='Path of data for finetune.') 54 | parser.add_argument("--model_path", type=str, default='ckpts/panglao_full_with_g2v/2022-May-11-17:38:47/panglao_full_with_g2v_epoch_17.pth', help='Path of pretrained checkpoint to load.') 55 | parser.add_argument("--ft_ckpt", action="store_true", help="Add this flag if continuing to train an already finetuned model.") 56 | parser.add_argument("--ckpt_dir", type=str, default='./ckpts/', help='Directory for saving checkpoints.') 57 | parser.add_argument("--use_continuous", action="store_true", help='If this arg is provided, embed continuous expression values and predict continuous expression values during masking, instead of bucketed.') 58 | parser.add_argument("--model_name", type=str, default='finetune', help='Finetuned model name.') 59 | parser.add_argument("--debug", action="store_true", help="Debug setting: saves to new dir.") 60 | parser.add_argument("--small_geneset", action='store_true',
help='Train a smaller model. Currently implemented as including genes present in at least 5%% of cells.') 61 | args = parser.parse_args() 62 | 63 | model_name = args.model_name 64 | timestamp = datetime.now().strftime("%Y-%b-%d-%H:%M:%S") 65 | ckpt_dir = os.path.join(args.ckpt_dir, model_name, timestamp) 66 | 67 | # Create checkpoint dir if it doesn't exist 68 | # NOTE: Done before distributing to avoid process collision 69 | if not (os.path.exists(ckpt_dir)): 70 | os.makedirs(ckpt_dir) 71 | 72 | print("Checkpoint dir: ", ckpt_dir) 73 | 74 | mp.spawn( 75 | distributed_finetune, 76 | args=(args, ckpt_dir, model_name), 77 | nprocs=args.world_size, 78 | join=True, 79 | ) 80 | 81 | 82 | def distributed_finetune(rank, args, ckpt_dir, model_name): 83 | 84 | SEED = args.seed 85 | EPOCHS = args.epochs 86 | BATCH_SIZE = args.batch_size 87 | GRADIENT_ACCUMULATION = args.grad_acc 88 | LEARNING_RATE = args.learning_rate 89 | VALIDATE_EVERY = args.valid_every 90 | CLASS = args.bin_num + 2 91 | POS_EMBED_USING = args.pos_embed_g2v 92 | PATIENCE = 10 93 | UNASSIGN_THRES = 0.0 94 | USE_CONTINUOUS = args.use_continuous 95 | if args.sin_emb_wavelength: 96 | SIN_EMB_WAVELENGTH = args.sin_emb_wavelength 97 | else: 98 | SIN_EMB_WAVELENGTH = args.bin_num 99 | 100 | is_master = rank == 0 101 | master_addr = args.master_addr 102 | master_port = args.master_port 103 | world_size = args.world_size 104 | 105 | # Control sources of randomness 106 | torch.manual_seed(SEED) 107 | random.seed(SEED) 108 | np.random.seed(SEED) 109 | 110 | ### CLASSES FROM ORIGINAL CODE ### 111 | 112 | class SCDataset(Dataset): 113 | def __init__(self, data, label, use_continuous): 114 | super().__init__() 115 | self.data = data 116 | self.label = label 117 | self.use_continuous = use_continuous 118 | 119 | def __getitem__(self, index): 120 | #rand_start = random.randint(0, self.data.shape[0]-1) 121 | full_seq = self.data[index].toarray()[0] 122 | full_seq[full_seq > (CLASS - 2)] = CLASS - 2 123 | full_seq =
torch.from_numpy(full_seq).long() #long() converts to int64 124 | if(not self.use_continuous): 125 | full_seq = full_seq.long() #long() converts to int64 126 | full_seq = torch.cat((full_seq, torch.tensor([0]))).to(device) #this is the CLS token ? 127 | seq_label = self.label[index] 128 | return full_seq, seq_label 129 | 130 | def __len__(self): 131 | return self.data.shape[0] 132 | 133 | class Identity(torch.nn.Module): 134 | def __init__(self, dropout = 0., h_dim = 100, out_dim = 10): 135 | super(Identity, self).__init__() 136 | self.conv1 = nn.Conv2d(1, 1, (1, 200)) 137 | self.act = nn.ReLU() 138 | self.fc1 = nn.Linear(in_features=SEQ_LEN, out_features=512, bias=True) 139 | self.act1 = nn.ReLU() 140 | self.dropout1 = nn.Dropout(dropout) 141 | self.fc2 = nn.Linear(in_features=512, out_features=h_dim, bias=True) 142 | self.act2 = nn.ReLU() 143 | self.dropout2 = nn.Dropout(dropout) 144 | self.fc3 = nn.Linear(in_features=h_dim, out_features=out_dim, bias=True) 145 | 146 | def forward(self, x): 147 | x = x[:,None,:,:] 148 | x = self.conv1(x) 149 | x = self.act(x) 150 | x = x.view(x.shape[0],-1) 151 | x = self.fc1(x) 152 | x = self.act1(x) 153 | x = self.dropout1(x) 154 | x = self.fc2(x) 155 | x = self.act2(x) 156 | x = self.dropout2(x) 157 | x = self.fc3(x) 158 | return x 159 | 160 | def preprocess_data_smallgeneset(data_path, ref_data_path = '/data/rna_rep_learning/scBERT/panglao_human.h5ad'): 161 | panglao = sc.read_h5ad(ref_data_path) 162 | sc.pp.filter_genes(panglao, min_cells=0.05*len(panglao)) 163 | data = sc.read_h5ad(data_path) 164 | counts = sparse.lil_matrix((data.X.shape[0],panglao.X.shape[1]),dtype=np.float32) 165 | ref = panglao.var_names.tolist() 166 | obj = data.var_names.tolist() 167 | 168 | for i in range(len(ref)): 169 | if ref[i] in obj: 170 | loc = obj.index(ref[i]) 171 | counts[:,i] = data.X[:,loc] 172 | 173 | counts = counts.tocsr() 174 | new = ad.AnnData(X=counts) 175 | new.var_names = ref 176 | new.obs_names = data.obs_names 177 | new.obs = 
data.obs 178 | new.uns = panglao.uns 179 | 180 | #sc.pp.filter_cells(new, min_genes=200) 181 | #sc.pp.normalize_total(new, target_sum=1e4) 182 | #sc.pp.log1p(new, base=2) 183 | return(new) 184 | 185 | setup_process(rank, master_addr, master_port, world_size) 186 | device = torch.device("cuda:{}".format(rank)) 187 | 188 | print("Set up distributed processes...") 189 | 190 | # data is loaded below, depending on --small_geneset 191 | 192 | if args.small_geneset: 193 | data = preprocess_data_smallgeneset(args.data_path) 194 | print("Filtered data to include {} genes present in at least 5% of cells".format(data.shape[1])) 195 | else: 196 | data = sc.read_h5ad(args.data_path) 197 | if args.debug: 198 | debug_seq_len = 5000 199 | data = data[:1000,:debug_seq_len] 200 | GRADIENT_ACCUMULATION = 1 201 | 202 | 203 | label_dict, label = np.unique(np.array(data.obs['celltype']), return_inverse=True) # Convert string cell-type labels to integer codes; label_dict[label] restores the original strings 204 | class_num = np.unique(label, return_counts=True)[1].tolist() 205 | #class_weight = torch.tensor([(1 - (x / sum(class_num))) ** 2 for x in class_num]) #doesn't get used anywhere 206 | class_weight = torch.tensor([1/x for x in class_num]) #use this simpler weighting 207 | label = torch.from_numpy(label) 208 | data = data.X 209 | if args.gene_num is not None: 210 | SEQ_LEN = args.gene_num + 1 211 | else: 212 | SEQ_LEN = data.shape[1] + 1 # num_genes + 1 213 | 214 | sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=2022) #update to hardcode to reduce sources of randomness/isolate performance differences between MODELS 215 | for index_train, index_val in sss.split(data, label): 216 | data_train, label_train = data[index_train], label[index_train] 217 | data_val, label_val = data[index_val], label[index_val] 218 | train_dataset = SCDataset(data_train, label_train, USE_CONTINUOUS) 219 | val_dataset = SCDataset(data_val, label_val, USE_CONTINUOUS) 220 | print("size of training data: 
{}".format(len(train_dataset))) 221 | train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=SEED) 222 | val_sampler = DistributedSampler(val_dataset, shuffle=True, seed=SEED) 223 | 224 | train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler) 225 | val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, sampler=val_sampler) 226 | 227 | print("Loaded data...") 228 | 229 | model = PerformerLM( 230 | num_tokens = CLASS, 231 | dim = 200, 232 | depth = 6, 233 | max_seq_len = SEQ_LEN, 234 | heads = 10, 235 | local_attn_heads = 0, 236 | g2v_position_emb = POS_EMBED_USING, 237 | g2v_file = args.g2v_file, 238 | pred_continuous = USE_CONTINUOUS, 239 | sin_emb_wavelength = SIN_EMB_WAVELENGTH, 240 | ) 241 | model = model.to(device) 242 | 243 | ### BLOCK BELOW IS COMMENTED OUT FOR "NO PRETRAINING" EXPERIMENT ### 244 | """ 245 | # Load checkpoint onto correct rank 246 | checkpoint = torch.load(args.model_path, map_location=device) 247 | consume_prefix_in_state_dict_if_present(checkpoint['model_state_dict'], "module.") 248 | if args.ft_ckpt: 249 | print("Loaded finetuned ckpt...") 250 | model.to_out = Identity(dropout=0., h_dim=128, out_dim=label_dict.shape[0]) 251 | model.load_state_dict(checkpoint['model_state_dict']) 252 | model = model.to(device) 253 | cur_epoch = checkpoint['epoch'] 254 | # Load optimizer 255 | #optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 256 | # Load scheduler 257 | #scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 258 | else: 259 | print("Loaded pretrained model...") 260 | model.load_state_dict(checkpoint['model_state_dict']) 261 | """ 262 | model.to_out = Identity(dropout=0., h_dim=128, out_dim=label_dict.shape[0]).to(device) 263 | cur_epoch = 0 264 | 265 | for name, param in model.named_parameters(): 266 | param.requires_grad = False 267 | for name, param in model.norm.named_parameters(): 268 | param.requires_grad = True 269 | for name, param in 
model.performer.net.layers[-1].named_parameters(): #make last layers of performer trainable during fine tuning 270 | param.requires_grad = True 271 | for name, param in model.to_out.named_parameters(): 272 | param.requires_grad = True 273 | 274 | try: 275 | model = DDP(model, device_ids=[device], output_device=device) 276 | 277 | # optimizer 278 | optimizer = Adam(model.parameters(), lr=LEARNING_RATE) 279 | scheduler = CosineAnnealingWarmupRestarts( 280 | optimizer, 281 | first_cycle_steps=15, 282 | cycle_mult=2, 283 | max_lr=LEARNING_RATE, 284 | min_lr=1e-6, 285 | warmup_steps=5, 286 | gamma=0.9 287 | ) 288 | 289 | #implement class weights in loss to handle class imbalance 290 | loss_fn = nn.CrossEntropyLoss(weight=class_weight).to(device) 291 | 292 | dist.barrier() 293 | trigger_times = 0 294 | max_acc = 0.0 295 | writer = SummaryWriter(os.path.join(ckpt_dir, 'tensorboard')) 296 | for i in range(cur_epoch+1, EPOCHS+1): 297 | print("{} iterations in train dataloader per epoch".format(len(train_loader))) 298 | train_loader.sampler.set_epoch(i) 299 | model.train() 300 | dist.barrier() 301 | running_loss = 0.0 302 | cum_acc = 0.0 303 | for index, (data, labels) in tqdm(enumerate(train_loader)): 304 | index += 1 305 | data, labels = data.to(device), labels.to(device) 306 | if index % GRADIENT_ACCUMULATION != 0: 307 | with model.no_sync(): 308 | logits = model(data) 309 | loss = loss_fn(logits, labels) 310 | loss.backward() 311 | if index % GRADIENT_ACCUMULATION == 0: 312 | logits = model(data) 313 | loss = loss_fn(logits, labels) 314 | loss.backward() 315 | torch.nn.utils.clip_grad_norm_(model.parameters(), int(1e6)) 316 | optimizer.step() 317 | optimizer.zero_grad() 318 | running_loss += loss.item() 319 | softmax = nn.Softmax(dim=-1) 320 | final = softmax(logits) 321 | final = final.argmax(dim=-1) 322 | pred_num = labels.size(0) 323 | correct_num = torch.eq(final, labels).sum(dim=-1) 324 | cum_acc += torch.true_divide(correct_num, pred_num).mean().item() 325 | 
epoch_loss = running_loss / index 326 | epoch_acc = 100 * cum_acc / index 327 | epoch_loss = get_reduced(epoch_loss, device, 0, world_size) 328 | epoch_acc = get_reduced(epoch_acc, device, 0, world_size) 329 | if is_master: 330 | print(f' == Epoch: {i} | Training Loss: {epoch_loss:.6f} | Accuracy: {epoch_acc:6.4f}% ==') 331 | dist.barrier() 332 | scheduler.step() 333 | 334 | if i % VALIDATE_EVERY == 0: 335 | model.eval() 336 | dist.barrier() 337 | running_loss = 0.0 338 | predictions = [] 339 | truths = [] 340 | with torch.no_grad(): 341 | for index, (data_v, labels_v) in enumerate(val_loader): 342 | index += 1 343 | data_v, labels_v = data_v.to(device), labels_v.to(device) 344 | logits = model(data_v) 345 | loss = loss_fn(logits, labels_v) 346 | running_loss += loss.item() 347 | softmax = nn.Softmax(dim=-1) 348 | final_prob = softmax(logits) 349 | final = final_prob.argmax(dim=-1) 350 | final[np.amax(np.array(final_prob.cpu()), axis=-1) < UNASSIGN_THRES] = -1 351 | predictions.append(final) 352 | truths.append(labels_v) 353 | del data_v, labels_v, logits, final_prob, final 354 | # gather 355 | dist.barrier() 356 | predictions = distributed_concat(torch.cat(predictions, dim=0), len(val_sampler.dataset), world_size) 357 | truths = distributed_concat(torch.cat(truths, dim=0), len(val_sampler.dataset), world_size) 358 | no_drop = predictions != -1 359 | predictions = np.array((predictions[no_drop]).cpu()) 360 | truths = np.array((truths[no_drop]).cpu()) 361 | cur_acc = accuracy_score(truths, predictions) 362 | f1 = f1_score(truths, predictions, average='macro') 363 | val_loss = running_loss / index 364 | val_loss = get_reduced(val_loss, device, 0, world_size) 365 | dist.barrier() 366 | if is_master: 367 | print(f' == Epoch: {i} | Validation Loss: {val_loss:.6f} | F1 Score: {f1:.6f} ==') 368 | print(confusion_matrix(truths, predictions)) 369 | print(classification_report(truths, predictions, labels=np.arange(len(label_dict)), target_names=label_dict.tolist(), 
digits=4)) 370 | 371 | writer.add_scalar('Loss/train', epoch_loss, i) 372 | writer.add_scalar('Accuracy/train', epoch_acc, i) 373 | writer.add_scalar('Loss/val', val_loss, i) 374 | writer.add_scalar('Accuracy/val', cur_acc, i) 375 | writer.add_scalar('F1/val', f1, i) 376 | if cur_acc > max_acc: 377 | max_acc = cur_acc 378 | trigger_times = 0 379 | save_best_ckpt(i, model, optimizer, scheduler, val_loss, model_name, ckpt_dir) 380 | else: 381 | trigger_times += 1 382 | if trigger_times > PATIENCE: 383 | break 384 | del predictions, truths 385 | except Exception as e: 386 | print(e) 387 | pass #so that cleanup() occurs with or without error 388 | cleanup() 389 | 390 | 391 | def setup_process(rank, master_addr, master_port, world_size, backend="nccl"): 392 | print(f"Setting up process: rank={rank} world_size={world_size} backend={backend}.") 393 | print(f"master_addr={master_addr} master_port={master_port}") 394 | os.environ["MASTER_ADDR"] = master_addr 395 | os.environ["MASTER_PORT"] = master_port 396 | dist.init_process_group(backend=backend, rank=rank, world_size=world_size) 397 | 398 | 399 | def cleanup(): 400 | dist.destroy_process_group() 401 | 402 | 403 | if __name__=="__main__": 404 | main() 405 | -------------------------------------------------------------------------------- /scBERT/dist_pretrain.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import gc 4 | import argparse 5 | import json 6 | import random 7 | import math 8 | import random 9 | from functools import reduce 10 | import numpy as np 11 | import pandas as pd 12 | from scipy import sparse 13 | from sklearn.model_selection import train_test_split 14 | import torch 15 | from torch import nn 16 | from torch.optim import Adam 17 | from torch.nn import functional as F 18 | from torch.utils.data import DataLoader, Dataset 19 | from torch.utils.data.distributed import DistributedSampler 20 | from
torch.utils.tensorboard import SummaryWriter 21 | from torch.nn.parallel import DistributedDataParallel as DDP 22 | from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present 23 | import torch.distributed as dist 24 | import torch.multiprocessing as mp 25 | from performer_pytorch import PerformerLM 26 | import scanpy as sc 27 | import anndata as ad 28 | from utils import * 29 | from tqdm import tqdm 30 | from datetime import datetime 31 | from time import time 32 | from collections import OrderedDict 33 | 34 | 35 | def main(): 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--master_addr", type=str, default="127.0.0.1", help='Master addr for dist training.') 38 | parser.add_argument("--master_port", type=str, default="8500", help='Master port for dist training.') 39 | parser.add_argument("--world_size", type=int, default=2, help='Number of GPUs.') 40 | parser.add_argument("--bin_num", type=int, default=5, help='Number of bins.') 41 | parser.add_argument("--gene_num", type=int, default=None, help='Number of genes.') # 16906, if not supplied, will take the number of genes in the supplied training data 42 | parser.add_argument("--epochs", type=int, default=100, help='Number of epochs.') 43 | parser.add_argument("--seed", type=int, default=2021, help='Random seed.') 44 | parser.add_argument("--batch_size", type=int, default=3, help='Batch size.') 45 | parser.add_argument("--learning_rate", type=float, default=1e-4, help='Learning rate.') 46 | parser.add_argument("--grad_acc", type=int, default=60, help='Number of gradient accumulation steps.') 47 | parser.add_argument("--valid_every", type=int, default=1, help='Number of training epochs between validations.') 48 | parser.add_argument("--mask_prob", type=float, default=0.15, help='Probability of masking.') 49 | parser.add_argument("--replace_prob", type=float, default=0.9, help='Probability of replacing with [MASK] token for masking.') 50 | parser.add_argument("--pos_embed_g2v",
action='store_true', help='Using Gene2vec encoding or not (default no unless this arg is supplied).') 51 | parser.add_argument("--sin_emb_wavelength", type=float, default = None, help='Wavelength of sinusoidal expression encodings. Defaults to bin_num.') 52 | parser.add_argument("--small_geneset", action='store_true', help='Train a smaller model. Currently implemented as including genes present in at least 5%% of cells.') 53 | parser.add_argument("--g2v_file", type=str, default='/data/rna_rep_learning/scBERT/gene2vec_16906.npy', help='File containing Gene2vec embeddings') 54 | parser.add_argument("--data_path", type=str, default='/data/rna_rep_learning/scBERT/panglao_human.h5ad', help='Path of data for pretraining.') 55 | parser.add_argument("--ckpt_dir", type=str, default='./ckpts/', help='Directory of checkpoint to save.') 56 | parser.add_argument("--model_name", type=str, default='panglao_pretrain', help='Model name used for saving model.') 57 | parser.add_argument("--pretrained_ckpt", type=str, default=None, help='Pretrained checkpoint path.') 58 | parser.add_argument("--pred_continuous", action="store_true", help='If this arg is provided, embed continuous expression values and predict continuous expression values during masking, instead of bucketed.') 59 | parser.add_argument("--debug", action="store_true", help="Debug setting: saves to new dir.") 60 | args = parser.parse_args() 61 | 62 | model_name = args.model_name 63 | 64 | # Control sources of randomness 65 | torch.manual_seed(args.seed) 66 | random.seed(args.seed) 67 | np.random.seed(args.seed) 68 | 69 | # If continuing training from checkpoint 70 | if args.pretrained_ckpt and not args.debug: 71 | ckpt_dir = os.path.dirname(args.pretrained_ckpt) 72 | else: 73 | timestamp = datetime.now().strftime("%Y-%b-%d-%H:%M:%S") 74 | ckpt_dir = os.path.join(args.ckpt_dir, model_name, timestamp) 75 | 76 | # Create checkpoint dir if it doesn't exist 77 | # NOTE: Done before distributing to avoid process collision 78 | if
not (os.path.exists(ckpt_dir)): 79 | os.makedirs(ckpt_dir) 80 | 81 | print("Checkpoint dir: ", ckpt_dir) 82 | 83 | mp.spawn( 84 | distributed_pretrain, 85 | args=(args, ckpt_dir, model_name), 86 | nprocs=args.world_size, 87 | join=True, 88 | ) 89 | 90 | 91 | def distributed_pretrain(rank, args, ckpt_dir, model_name): 92 | 93 | SEED = args.seed 94 | EPOCHS = args.epochs 95 | BATCH_SIZE = args.batch_size 96 | GRADIENT_ACCUMULATION = args.grad_acc 97 | LEARNING_RATE = args.learning_rate 98 | VALIDATE_EVERY = args.valid_every 99 | CLASS = args.bin_num + 2 100 | POS_EMBED_USING = args.pos_embed_g2v 101 | if args.sin_emb_wavelength: 102 | SIN_EMB_WAVELENGTH = args.sin_emb_wavelength 103 | else: 104 | SIN_EMB_WAVELENGTH = args.bin_num 105 | MASK_PROB = args.mask_prob 106 | REPLACE_PROB = args.replace_prob 107 | PRED_CONTINUOUS = args.pred_continuous 108 | RANDOM_TOKEN_PROB = 0. 109 | MASK_TOKEN_ID = CLASS - 1 110 | PAD_TOKEN_ID = CLASS - 1 111 | MASK_IGNORE_TOKEN_IDS = [0] 112 | 113 | is_master = rank == 0 114 | master_addr = args.master_addr 115 | master_port = args.master_port 116 | world_size = args.world_size 117 | 118 | ### HELPER FUNCTIONS AND DATASET CLASS FROM ORIGINAL CODE ### 119 | 120 | # get the random prob matrix and True means smaller than prob threshold 121 | def prob_mask_like(t, prob): 122 | return torch.zeros_like(t).float().uniform_(0, 1) < prob 123 | 124 | # get the mask matrix which cannot be masked 125 | def mask_with_tokens(t, token_ids): 126 | init_no_mask = torch.full_like(t, False, dtype=torch.bool) 127 | mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask) 128 | return mask 129 | 130 | def get_mask_subset_with_prob(mask, prob): 131 | batch, seq_len, device = *mask.shape, mask.device 132 | max_masked = math.ceil(prob * seq_len) # num of mask of a single sequence in average 133 | num_tokens = mask.sum(dim=-1, keepdim=True) # num of pure tokens of each sequence except special tokens 134 | mask_excess = torch.cat((torch.zeros(0), 
torch.arange(mask.size(-1)).repeat(mask.size(0)))).reshape(mask.size(0),mask.size(-1)).to(device) 135 | mask_excess = (mask_excess >= (num_tokens * prob).ceil()) # only 15% of pure tokens can be masked 136 | mask_excess = mask_excess[:, :max_masked] # get difference between 15% of pure tokens and 15% of all tokens 137 | rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9) # rand (0-1) as prob, special token use -1e9 138 | _, sampled_indices = rand.topk(max_masked, dim=-1) # get index of topk prob to mask 139 | sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0) # delete difference of mask not pure 140 | new_mask = torch.zeros((batch, seq_len + 1), device=device) # get (batch, seq_len) shape zero matrix 141 | new_mask.scatter_(-1, sampled_indices, 1) # set masks in zero matrix as 1 142 | return new_mask[:, 1:].bool() # the final mask, True is mask 143 | 144 | def data_mask( 145 | data, 146 | mask_prob = MASK_PROB, 147 | replace_prob = REPLACE_PROB, 148 | num_tokens = None, 149 | random_token_prob = RANDOM_TOKEN_PROB, 150 | mask_token_id = MASK_TOKEN_ID, 151 | pad_token_id = PAD_TOKEN_ID, 152 | mask_ignore_token_ids = MASK_IGNORE_TOKEN_IDS 153 | ): 154 | mask_ignore_token_ids = set([*mask_ignore_token_ids, pad_token_id]) 155 | # do not mask [pad] tokens, or any other tokens in the tokens designated to be excluded ([cls], [sep]) 156 | # also do not include these special tokens in the tokens chosen at random 157 | no_mask = mask_with_tokens(data, mask_ignore_token_ids) # ignore_token as True, will not be masked later 158 | mask = get_mask_subset_with_prob(~no_mask, mask_prob) # get the True/False mask matrix 159 | # get mask indices 160 | ## mask_indices = torch.nonzero(mask, as_tuple=True) # get the index of mask(nonzero value of mask matrix) 161 | # mask input with mask tokens with probability of `replace_prob` (keep tokens the same with probability 1 - replace_prob) 162 | masked_input = data.clone().detach() 163 | # if random 
token probability > 0 for mlm 164 | if random_token_prob > 0: 165 | assert num_tokens is not None, 'num_tokens keyword must be supplied when instantiating MLM if using random token replacement' 166 | random_token_prob = prob_mask_like(data, random_token_prob) # get the mask matrix of random token replace 167 | random_tokens = torch.randint(0, num_tokens, data.shape, device=data.device) # generate random token matrix with the same shape as input 168 | random_no_mask = mask_with_tokens(random_tokens, mask_ignore_token_ids) # not masked matrix for the random token matrix 169 | random_token_prob &= ~random_no_mask # get the pure mask matrix of random token replace 170 | random_indices = torch.nonzero(random_token_prob, as_tuple=True) # index of random token replace 171 | masked_input[random_indices] = random_tokens[random_indices] # replace some tokens by random token 172 | # [mask] input 173 | replace_prob = prob_mask_like(data, replace_prob) # get the mask matrix of token being masked 174 | masked_input = masked_input.masked_fill(mask * replace_prob, mask_token_id) # get the data has been masked by mask_token 175 | # mask out any tokens to padding tokens that were not originally going to be masked 176 | labels = data.masked_fill(~mask, pad_token_id) # the label of masked tokens; will have "pad_token_id" everywhere that was not masked (eg. of pad_token_id having overloaded uses...) 177 | return masked_input, labels 178 | 179 | def MSEloss(preds, target, reduction = 'mean', ignore_index = MASK_TOKEN_ID): 180 | """ 181 | Created our own function to allow for an "ignore_index" argument 182 | """ 183 | if not (target.size() == preds.size()): 184 | print( 185 | "Using a target size ({}) that is different to the input size ({}). " 186 | #"This will likely lead to incorrect results due to broadcasting. 
" 187 | "Please ensure they have the same size.".format(target.size(), preds.size()) 188 | ) 189 | if reduction != "mean": 190 | print("WARNING: mean MSEloss is automatically calculated, even though you specified a different reduction") 191 | #expanded_preds, expanded_target = torch.broadcast_tensors(preds, target) 192 | diff = (preds-target)*(target!=ignore_index) #dont count loss from values that were not masked 193 | return torch.mean(diff**2) 194 | 195 | class SCDataset(Dataset): 196 | def __init__(self, data, use_continuous=False): 197 | super().__init__() 198 | self.data = data 199 | self.use_continuous = use_continuous 200 | 201 | def __getitem__(self, index): 202 | rand_start = random.randint(0, self.data.shape[0]-1) 203 | full_seq = self.data[rand_start].toarray()[0] 204 | full_seq[full_seq > (CLASS - 2)] = CLASS - 2 205 | full_seq = torch.from_numpy(full_seq) 206 | if(not self.use_continuous): 207 | full_seq = full_seq.long() #long() converts to int64 208 | full_seq = torch.cat((full_seq, torch.tensor([0]))).to(device) 209 | return full_seq 210 | 211 | def __len__(self): 212 | return self.data.shape[0] 213 | 214 | cur_time = time() 215 | setup_process(rank, master_addr, master_port, world_size) 216 | device = torch.device("cuda:{}".format(rank)) 217 | 218 | print("Set up distributed processes...") 219 | 220 | data = sc.read_h5ad(args.data_path) 221 | if args.debug: 222 | debug_seq_len = 5000 223 | data = data[:50,:debug_seq_len] 224 | GRADIENT_ACCUMULATION = 1 225 | elif args.small_geneset: 226 | sc.pp.filter_genes(data, min_cells=0.05*len(data)) 227 | print("Filtered data to include {} genes present in at least 5% of cells".format(data.shape[1])) 228 | data = data.X 229 | if args.gene_num is not None: 230 | SEQ_LEN = args.gene_num + 1 231 | else: 232 | SEQ_LEN = data.shape[1] + 1 # num_genes + 1 233 | 234 | data_train, data_val = train_test_split(data, test_size=0.05, random_state=SEED) 235 | 236 | train_dataset = SCDataset(data_train, PRED_CONTINUOUS) 
237 | val_dataset = SCDataset(data_val, PRED_CONTINUOUS) 238 | 239 | train_sampler = DistributedSampler(train_dataset) 240 | val_sampler = SequentialDistributedSampler(val_dataset, batch_size=BATCH_SIZE, world_size=world_size) 241 | 242 | train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler) 243 | val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, sampler=val_sampler) 244 | 245 | print("Loaded data...") 246 | 247 | # model 248 | model = PerformerLM( 249 | num_tokens = CLASS, 250 | dim = 200, 251 | depth = 6, 252 | max_seq_len = SEQ_LEN, 253 | heads = 10, 254 | local_attn_heads = 0, 255 | g2v_position_emb = POS_EMBED_USING, 256 | g2v_file = args.g2v_file, 257 | pred_continuous = PRED_CONTINUOUS, 258 | sin_emb_wavelength = SIN_EMB_WAVELENGTH, 259 | ) 260 | # optimizer 261 | optimizer = Adam(model.parameters(), lr=LEARNING_RATE) 262 | # learning rate scheduler 263 | scheduler = CosineAnnealingWarmupRestarts( 264 | optimizer, 265 | first_cycle_steps=15, 266 | cycle_mult=2, 267 | max_lr=LEARNING_RATE, 268 | min_lr=1e-6, 269 | warmup_steps=5, 270 | gamma=0.9 271 | ) 272 | model.to(device) 273 | 274 | # If continuing training from checkpoint 275 | if args.pretrained_ckpt: 276 | # Load checkpoint onto correct rank 277 | checkpoint = torch.load(args.pretrained_ckpt, map_location=device) 278 | consume_prefix_in_state_dict_if_present(checkpoint['model_state_dict'], "module.") 279 | model.load_state_dict(checkpoint['model_state_dict']) 280 | # Load optimizer 281 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 282 | # Load scheduler 283 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 284 | cur_epoch = checkpoint['epoch'] 285 | else: 286 | cur_epoch = 0 287 | model = DDP(model, device_ids=[device], output_device=device) 288 | 289 | print("Loaded model...") 290 | if PRED_CONTINUOUS: 291 | loss_fn = MSEloss 292 | else: 293 | loss_fn = nn.CrossEntropyLoss(ignore_index = PAD_TOKEN_ID, 
reduction='mean').to(device) 294 | softmax = nn.Softmax(dim=-1) 295 | 296 | dist.barrier() 297 | writer = SummaryWriter(os.path.join(ckpt_dir, 'tensorboard')) 298 | for i in range(cur_epoch + 1, EPOCHS + 1): 299 | train_loader.sampler.set_epoch(i) 300 | model.train() 301 | dist.barrier() 302 | running_loss = 0.0 303 | cum_acc = 0.0 304 | cum_impute_error = 0.0 305 | for index, data in tqdm(enumerate(train_loader)): 306 | index += 1 307 | data = data.to(device) 308 | data, labels = data_mask(data) 309 | if index % GRADIENT_ACCUMULATION != 0: 310 | with model.no_sync(): 311 | logits = model(data) #should be size batch_size x seq_len x num_bins (if PRED_CONTINUOUS: batch_size x seq_len x 1) 312 | loss = loss_fn(logits.transpose(1, 2).squeeze(dim=1), labels) / GRADIENT_ACCUMULATION #squeeze needed for MSEloss, shouldn't affect x-ent loss 313 | loss.backward() 314 | if index % GRADIENT_ACCUMULATION == 0: 315 | logits = model(data) 316 | loss = loss_fn(logits.transpose(1, 2).squeeze(dim=1), labels) / GRADIENT_ACCUMULATION 317 | loss.backward() 318 | torch.nn.utils.clip_grad_norm_(model.parameters(), int(1e2)) 319 | optimizer.step() 320 | optimizer.zero_grad() 321 | running_loss += loss.item() 322 | if PRED_CONTINUOUS: 323 | final = logits.squeeze() 324 | impute_error = ((labels != PAD_TOKEN_ID) * torch.abs(final - labels)).sum(dim=-1) 325 | cum_impute_error += impute_error.median().item() #keep track of median imputation error per batch; report avg across batches in epoch 326 | else: #calculating 0-1 accuracy only applies with categorical preds 327 | final = softmax(logits)[..., 1:-1] 328 | final = final.argmax(dim=-1) + 1 329 | pred_num = (labels != PAD_TOKEN_ID).sum(dim=-1) 330 | correct_num = ((labels != PAD_TOKEN_ID) * (final == labels)).sum(dim=-1) 331 | cum_acc += torch.true_divide(correct_num, pred_num).mean().item() 332 | if PRED_CONTINUOUS: 333 | epoch_impute_error = cum_impute_error / index 334 | epoch_impute_error = get_reduced(epoch_impute_error, device, 0, 
world_size) 335 | epoch_acc =-1 336 | else: 337 | epoch_acc = 100 * cum_acc / index 338 | epoch_acc = get_reduced(epoch_acc, device, 0, world_size) 339 | epoch_abs_error = -1 340 | epoch_loss = running_loss / index 341 | epoch_loss = get_reduced(epoch_loss, device, 0, world_size) 342 | if is_master: 343 | if PRED_CONTINUOUS: 344 | print(f' == Epoch: {i} | Training Loss: {epoch_loss:.6f} | Accuracy: {epoch_acc:6.4f}% | Median Imputation Error : {epoch_impute_error:.4f} ==') 345 | else: 346 | print(f' == Epoch: {i} | Training Loss: {epoch_loss:.6f} | Accuracy: {epoch_acc:6.4f}% ==') 347 | dist.barrier() 348 | scheduler.step() 349 | 350 | if i % VALIDATE_EVERY == 0: 351 | model.eval() 352 | dist.barrier() 353 | running_loss = 0.0 354 | running_error = 0.0 355 | predictions = [] 356 | truths = [] 357 | with torch.no_grad(): 358 | for index, data in enumerate(val_loader): 359 | index += 1 360 | data = data.to(device) 361 | data, labels = data_mask(data) 362 | logits = model(data) 363 | loss = loss_fn(logits.transpose(1, 2).squeeze(dim=1), labels) 364 | running_loss += loss.item() 365 | softmax = nn.Softmax(dim=-1) 366 | if PRED_CONTINUOUS: 367 | final = logits.squeeze(dim=2) #(include 'dim' in case at the end of loop, batchsize=1) 368 | else: 369 | final = softmax(logits)[..., 1:-1] 370 | final = final.argmax(dim=-1) + 1 371 | predictions.append(final) 372 | truths.append(labels) 373 | del data, labels, logits, final 374 | # gather 375 | predictions = distributed_concat(torch.cat(predictions, dim=0), len(val_sampler.dataset), world_size) 376 | truths = distributed_concat(torch.cat(truths, dim=0), len(val_sampler.dataset), world_size) 377 | val_num = (truths != PAD_TOKEN_ID).sum(dim=-1) 378 | 379 | # Epoch loss 380 | val_loss = running_loss / index 381 | val_loss = get_reduced(val_loss, device, 0, world_size) 382 | 383 | # accuracy (categorical output) or absolute error (continuous output) 384 | if PRED_CONTINUOUS: 385 | val_impute_error = ((truths != PAD_TOKEN_ID) * 
torch.abs(predictions - truths)).sum(dim=-1).median().item() 386 | val_acc = -1 387 | else: 388 | correct_num = ((truths != PAD_TOKEN_ID) * (predictions == truths)).sum(dim=-1) 389 | val_acc = 100 * (correct_num / val_num).mean().item() 390 | val_impute_error = -1 391 | 392 | if is_master: 393 | if PRED_CONTINUOUS: 394 | print(f' == Epoch: {i} | Validation Loss: {val_loss:.6f} | Accuracy: {val_acc:6.4f}% | Median Imputation Error: {val_impute_error:.4f} ==') 395 | else: 396 | print(f' == Epoch: {i} | Validation Loss: {val_loss:.6f} | Accuracy: {val_acc:6.4f}% ==') 397 | 398 | duration = time() - cur_time 399 | cur_time = time() 400 | 401 | writer.add_scalar('Epoch duration', duration, i) 402 | writer.add_scalar('Loss/train', epoch_loss, i) 403 | writer.add_scalar('Loss/val', val_loss, i) 404 | if PRED_CONTINUOUS: 405 | writer.add_scalar('Median imputation error/train', epoch_impute_error, i) 406 | writer.add_scalar('Median imputation error/val', val_impute_error, i) 407 | else: 408 | writer.add_scalar('Accuracy/train', epoch_acc, i) 409 | writer.add_scalar('Accuracy/val', val_acc, i) 410 | 411 | del predictions, truths 412 | if is_master: 413 | save_ckpt(i, model, optimizer, scheduler, epoch_loss, model_name, ckpt_dir) 414 | 415 | cleanup() 416 | 417 | 418 | def setup_process(rank, master_addr, master_port, world_size, backend="nccl"): 419 | print(f"Setting up process: rank={rank} world_size={world_size} backend={backend}.") 420 | print(f"master_addr={master_addr} master_port={master_port}") 421 | os.environ["MASTER_ADDR"] = master_addr 422 | os.environ["MASTER_PORT"] = master_port 423 | dist.init_process_group(backend=backend, rank=rank, world_size=world_size) 424 | 425 | 426 | def cleanup(): 427 | dist.destroy_process_group() 428 | 429 | 430 | if __name__=="__main__": 431 | main() 432 | -------------------------------------------------------------------------------- /scBERT/performer_pytorch/.ipynb_checkpoints/performer_pytorch-checkpoint.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | from torch.cuda.amp import autocast 7 | from einops import rearrange, repeat 8 | 9 | from functools import partial 10 | from contextlib import contextmanager 11 | 12 | from local_attention import LocalAttention 13 | from performer_pytorch.reversible import ReversibleSequence, SequentialSequence 14 | 15 | try: 16 | from apex import amp 17 | APEX_AVAILABLE = True 18 | except: 19 | APEX_AVAILABLE = False 20 | 21 | # helpers 22 | 23 | def exists(val): 24 | return val is not None 25 | 26 | def empty(tensor): 27 | return tensor.numel() == 0 28 | 29 | def default(val, d): 30 | return val if exists(val) else d 31 | 32 | @contextmanager 33 | def null_context(): 34 | yield 35 | 36 | def cast_tuple(val): 37 | return (val,) if not isinstance(val, tuple) else val 38 | 39 | # def get_module_device(module): 40 | # return next(module.parameters).device 41 | 42 | def get_module_device(module): 43 | try: 44 | return next(module.parameters()).device 45 | except StopIteration: 46 | # For nn.DataParallel compatibility in PyTorch 1.5 47 | def find_tensor_attributes(module): 48 | tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] 49 | return tuples 50 | gen = module._named_members(get_members_fn=find_tensor_attributes) 51 | first_tuple = next(gen) 52 | return first_tuple[1].device 53 | 54 | def find_modules(nn_module, type): 55 | return [module for module in nn_module.modules() if isinstance(module, type)] 56 | 57 | class Always(nn.Module): 58 | def __init__(self, val): 59 | super().__init__() 60 | self.val = val 61 | def forward(self, *args, **kwargs): 62 | return self.val 63 | 64 | # kernel functions 65 | 66 | # transcribed from jax to pytorch from 67 | # https://github.com/google-research/google-research/blob/master/performer/fast_attention/jax/fast_attention.py 
68 | 69 | def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None): 70 | b, h, *_ = data.shape 71 | 72 | data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1. 73 | 74 | ratio = (projection_matrix.shape[0] ** -0.5) 75 | 76 | projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h) 77 | projection = projection.type_as(data) 78 | 79 | data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection.clone()) 80 | 81 | diag_data = data ** 2 82 | diag_data = torch.sum(diag_data, dim=-1) 83 | diag_data = (diag_data / 2.0) * (data_normalizer ** 2) 84 | diag_data = diag_data.unsqueeze(dim=-1) 85 | 86 | if is_query: 87 | data_dash = ratio * ( 88 | torch.exp(data_dash - diag_data - 89 | torch.max(data_dash, dim=-1, keepdim=True).values) + eps) 90 | else: 91 | data_dash = ratio * ( 92 | torch.exp(data_dash - diag_data - torch.max(data_dash)) + eps) 93 | 94 | return data_dash.type_as(data) 95 | 96 | def generalized_kernel(data, *, projection_matrix, kernel_fn = nn.ReLU(), kernel_epsilon = 0.001, normalize_data = True, device = None): 97 | b, h, *_ = data.shape 98 | 99 | data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1. 
100 | 101 | if projection_matrix is None: 102 | return kernel_fn(data_normalizer * data) + kernel_epsilon 103 | 104 | projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h) 105 | projection = projection.type_as(data) 106 | 107 | data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection) 108 | 109 | data_prime = kernel_fn(data_dash) + kernel_epsilon 110 | return data_prime.type_as(data) 111 | 112 | def orthogonal_matrix_chunk(cols, device = None): 113 | unstructured_block = torch.randn((cols, cols), device = device) 114 | q, r = torch.linalg.qr(unstructured_block.cpu(), mode = 'reduced') 115 | q, r = map(lambda t: t.to(device), (q, r)) 116 | return q.t() 117 | 118 | def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, device = None): 119 | nb_full_blocks = int(nb_rows / nb_columns) 120 | 121 | block_list = [] 122 | 123 | for _ in range(nb_full_blocks): 124 | q = orthogonal_matrix_chunk(nb_columns, device = device) 125 | block_list.append(q) 126 | 127 | remaining_rows = nb_rows - nb_full_blocks * nb_columns 128 | if remaining_rows > 0: 129 | q = orthogonal_matrix_chunk(nb_columns, device = device) 130 | block_list.append(q[:remaining_rows]) 131 | 132 | final_matrix = torch.cat(block_list) 133 | 134 | if scaling == 0: 135 | multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1) 136 | elif scaling == 1: 137 | multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device = device) 138 | else: 139 | raise ValueError(f'Invalid scaling {scaling}') 140 | 141 | return torch.diag(multiplier) @ final_matrix 142 | 143 | # linear attention classes with softmax kernel 144 | 145 | # non-causal linear attention 146 | def linear_attention(q, k, v): 147 | k_cumsum = k.sum(dim = -2) 148 | D_inv = 1. 
/ torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q)) 149 | context = torch.einsum('...nd,...ne->...de', k, v) 150 | out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv) 151 | return out 152 | 153 | # efficient causal linear attention, created by EPFL 154 | # TODO: rewrite EPFL's CUDA kernel to do mixed precision and remove half to float conversion and back 155 | def causal_linear_attention(q, k, v, eps = 1e-6): 156 | from fast_transformers.causal_product import CausalDotProduct 157 | autocast_enabled = torch.is_autocast_enabled() 158 | is_half = isinstance(q, torch.cuda.HalfTensor) 159 | assert not is_half or APEX_AVAILABLE, 'half tensors can only be used if nvidia apex is available' 160 | cuda_context = null_context if not autocast_enabled else partial(autocast, enabled = False) 161 | 162 | causal_dot_product_fn = amp.float_function(CausalDotProduct.apply) if is_half else CausalDotProduct.apply 163 | 164 | k_cumsum = k.cumsum(dim=-2) + eps 165 | D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q)) 166 | 167 | with cuda_context(): 168 | if autocast_enabled: 169 | q, k, v = map(lambda t: t.float(), (q, k, v)) 170 | 171 | out = causal_dot_product_fn(q, k, v) 172 | 173 | out = torch.einsum('...nd,...n->...nd', out, D_inv) 174 | return out 175 | 176 | # inefficient causal linear attention, without cuda code, for reader's reference 177 | # not being used 178 | def causal_linear_attention_noncuda(q, k, v, chunk_size = 128): 179 | last_k_cumsum = 0 180 | last_context_cumsum = 0 181 | outs = [] 182 | 183 | for q, k, v in zip(*map(lambda t: t.chunk(chunk_size, dim = -2), (q, k, v))): 184 | k_cumsum = last_k_cumsum + k.cumsum(dim=-2) 185 | 186 | D_inv = 1. 
/ torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q)) 187 | context = torch.einsum('...nd,...ne->...nde', k, v) 188 | context_cumsum = last_context_cumsum + context.cumsum(dim=-3) 189 | out = torch.einsum('...nde,...nd,...n->...ne', context_cumsum, q, D_inv) 190 | 191 | last_k_cumsum = k_cumsum[:, :, -1:] 192 | last_context_cumsum = context_cumsum[:, :, -1:] 193 | outs.append(out) 194 | 195 | return torch.cat(outs, dim = -2) 196 | 197 | def norm_tensor(tensor, dim=-1): 198 | return tensor / tensor.sum(dim=dim).unsqueeze(dim) 199 | 200 | class FastAttention(nn.Module): 201 | def __init__(self, dim_heads, nb_features = None, ortho_scaling = 0, causal = False, generalized_attention = False, kernel_fn = nn.ReLU(), no_projection = False): 202 | super().__init__() 203 | nb_features = default(nb_features, int(dim_heads * math.log(dim_heads))) 204 | 205 | self.dim_heads = dim_heads 206 | self.nb_features = nb_features 207 | self.ortho_scaling = ortho_scaling 208 | 209 | self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows = self.nb_features, nb_columns = dim_heads, scaling = ortho_scaling) 210 | projection_matrix = self.create_projection() 211 | self.register_buffer('projection_matrix', projection_matrix) 212 | 213 | self.generalized_attention = generalized_attention 214 | self.kernel_fn = kernel_fn 215 | 216 | # if this is turned on, no projection will be used 217 | # queries and keys will be softmax-ed as in the original efficient attention paper 218 | self.no_projection = no_projection 219 | 220 | self.causal = causal 221 | if causal: 222 | try: 223 | import fast_transformers.causal_product.causal_product_cuda 224 | self.causal_linear_fn = partial(causal_linear_attention) 225 | except ImportError: 226 | print('unable to import cuda code for auto-regressive Performer. 
will default to the memory inefficient non-cuda version') 227 | self.causal_linear_fn = causal_linear_attention_noncuda 228 | 229 | @torch.no_grad() 230 | def redraw_projection_matrix(self, device): 231 | projections = self.create_projection(device = device) 232 | self.projection_matrix.copy_(projections) 233 | del projections 234 | 235 | def forward(self, q, k, v, output_attentions = False): 236 | device = q.device 237 | # inds = [8060, 8064, 6243, 8575, 10342, 10913, 9366, 993, 7796, 5210, 5212, 5504, 6851, 6559, 5508, 13107, 13820] 238 | if self.no_projection: 239 | q = q.softmax(dim = -1) 240 | k = torch.exp(k) if self.causal else k.softmax(dim = -2) 241 | 242 | elif self.generalized_attention: 243 | create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device) 244 | q, k = map(create_kernel, (q, k)) 245 | 246 | else: 247 | create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device) 248 | q = create_kernel(q, is_query = True) 249 | k = create_kernel(k, is_query = False) 250 | 251 | attn_fn = linear_attention if not self.causal else self.causal_linear_fn 252 | out = attn_fn(q, k, v) 253 | if output_attentions: 254 | v_diag = torch.eye(v.shape[-2]).to(device) 255 | v_diag = v_diag.unsqueeze(0).unsqueeze(0).repeat(v.shape[0],v.shape[1],1,1) 256 | # attn_weights = torch.zeros(1, 1, len(inds), len(inds)).to(device).to(torch.float16) 257 | # attn_weights = torch.zeros(1, q.shape[1], len(inds), len(inds)).to(device).to(torch.float16) 258 | attn_weights = torch.zeros(1, 1, q.shape[2], q.shape[2]).to(device).to(torch.float16) 259 | for head_dim in range(q.shape[1]): 260 | # attn_weights[0, head_dim] = torch.abs(attn_fn(q[:,head_dim].to(torch.float16), k[:,head_dim].to(torch.float16), v_diag[:,head_dim].to(torch.float16)))[0, inds][:, inds] 261 | attn_weights += torch.abs(attn_fn(q[:,head_dim].to(torch.float16), k[:,head_dim].to(torch.float16), 
v_diag[:,head_dim].to(torch.float16))) 262 | # attn_weights += norm_tensor(torch.abs(attn_fn(q[:,head_dim].to(torch.float16), k[:,head_dim].to(torch.float16), v_diag[:,head_dim].to(torch.float16))), dim=-1) 263 | attn_weights /= q.shape[1] 264 | return out, attn_weights 265 | else: 266 | return out 267 | 268 | # classes 269 | 270 | class ReZero(nn.Module): 271 | def __init__(self, fn): 272 | super().__init__() 273 | self.g = nn.Parameter(torch.tensor(1e-3)) 274 | self.fn = fn 275 | 276 | def forward(self, x, **kwargs): 277 | return self.fn(x, **kwargs) * self.g 278 | 279 | class PreScaleNorm(nn.Module): 280 | def __init__(self, dim, fn, eps=1e-5): 281 | super().__init__() 282 | self.fn = fn 283 | self.g = nn.Parameter(torch.ones(1)) 284 | self.eps = eps 285 | 286 | def forward(self, x, **kwargs): 287 | n = torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps) 288 | x = x / n * self.g 289 | return self.fn(x, **kwargs) 290 | 291 | class PreLayerNorm(nn.Module): 292 | def __init__(self, dim, fn): 293 | super().__init__() 294 | self.norm = nn.LayerNorm(dim) 295 | self.fn = fn 296 | def forward(self, x, **kwargs): 297 | return self.fn(self.norm(x), **kwargs) 298 | 299 | class Chunk(nn.Module): 300 | def __init__(self, chunks, fn, along_dim = -1): 301 | super().__init__() 302 | self.dim = along_dim 303 | self.chunks = chunks 304 | self.fn = fn 305 | 306 | def forward(self, x, **kwargs): 307 | if self.chunks == 1: 308 | return self.fn(x, **kwargs) 309 | chunks = x.chunk(self.chunks, dim = self.dim) 310 | return torch.cat([self.fn(c, **kwargs) for c in chunks], dim = self.dim) 311 | 312 | class FeedForward(nn.Module): 313 | def __init__(self, dim, mult = 4, dropout = 0., activation = None, glu = False): 314 | super().__init__() 315 | activation = default(activation, nn.GELU) 316 | 317 | self.glu = glu 318 | self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1)) 319 | self.act = activation() 320 | self.dropout = nn.Dropout(dropout) 321 | self.w2 = nn.Linear(dim * mult, 
dim) 322 | 323 | def forward(self, x, **kwargs): 324 | if not self.glu: 325 | x = self.w1(x) 326 | x = self.act(x) 327 | else: 328 | x, v = self.w1(x).chunk(2, dim=-1) 329 | x = self.act(x) * v 330 | 331 | x = self.dropout(x) 332 | x = self.w2(x) 333 | return x 334 | 335 | class SelfAttention(nn.Module): 336 | def __init__( 337 | self, 338 | dim, 339 | causal = False, 340 | heads = 8, 341 | dim_head = 64, 342 | local_heads = 0, 343 | local_window_size = 256, 344 | nb_features = None, 345 | feature_redraw_interval = 1000, 346 | generalized_attention = False, 347 | kernel_fn = nn.ReLU(), 348 | dropout = 0., 349 | no_projection = False, 350 | qkv_bias = False 351 | ): 352 | super().__init__() 353 | assert dim % heads == 0, 'dimension must be divisible by number of heads' 354 | dim_head = default(dim_head, dim // heads) 355 | inner_dim = dim_head * heads 356 | self.fast_attention = FastAttention(dim_head, nb_features, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, no_projection = no_projection) 357 | 358 | self.heads = heads 359 | self.global_heads = heads - local_heads 360 | self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None 361 | 362 | self.to_q = nn.Linear(dim, inner_dim, bias = qkv_bias) 363 | self.to_k = nn.Linear(dim, inner_dim, bias = qkv_bias) 364 | self.to_v = nn.Linear(dim, inner_dim, bias = qkv_bias) 365 | self.to_out = nn.Linear(inner_dim, dim) 366 | self.dropout = nn.Dropout(dropout) 367 | 368 | def forward(self, x, pos_emb = None, context = None, mask = None, context_mask = None, output_attentions = False, **kwargs): 369 | b, n, _, h, gh = *x.shape, self.heads, self.global_heads 370 | 371 | cross_attend = exists(context) 372 | 373 | context = default(context, x) 374 | context_mask = default(context_mask, mask) if not cross_attend else 
context_mask 375 | 376 | q, k, v = self.to_q(x), self.to_k(context), self.to_v(context) 377 | 378 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 379 | (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v)) 380 | 381 | attn_outs = [] 382 | 383 | if not empty(q): 384 | if exists(context_mask): 385 | global_mask = context_mask[:, None, :, None] 386 | v.masked_fill_(~global_mask, 0.) 387 | 388 | if exists(pos_emb) and not cross_attend: 389 | q, k, = apply_rotary_pos_emb(q, k, pos_emb) 390 | 391 | if output_attentions: 392 | out, attn_weights = self.fast_attention(q, k, v, output_attentions) 393 | else: 394 | out = self.fast_attention(q, k, v) 395 | attn_outs.append(out) 396 | 397 | if not empty(lq): 398 | assert not cross_attend, 'local attention is not compatible with cross attention' 399 | out = self.local_attn(lq, lk, lv, input_mask = mask) 400 | attn_outs.append(out) 401 | 402 | out = torch.cat(attn_outs, dim = 1) # combine attn_out and cross_attn_out, here we have only attn_out, that means this line does nothing 403 | out = rearrange(out, 'b h n d -> b n (h d)') 404 | out = self.to_out(out) 405 | if output_attentions: 406 | return self.dropout(out), attn_weights 407 | else: 408 | return self.dropout(out) 409 | 410 | # positional embeddings 411 | 412 | class AbsolutePositionalEmbedding(nn.Module): 413 | def __init__(self, dim, max_seq_len): 414 | super().__init__() 415 | self.emb = nn.Embedding(max_seq_len, dim) 416 | 417 | def forward(self, x): 418 | t = torch.arange(x.shape[1], device=x.device) 419 | return self.emb(t) 420 | 421 | # rotary positional embedding helpers 422 | 423 | def rotate_every_two(x): 424 | x = rearrange(x, '... (d j) -> ... d j', j = 2) 425 | x1, x2 = x.unbind(dim = -1) 426 | x = torch.stack((-x2, x1), dim = -1) 427 | return rearrange(x, '... d j -> ... 
(d j)') 428 | 429 | def apply_rotary_pos_emb(q, k, sinu_pos): 430 | sinu_pos = rearrange(sinu_pos, '() n (j d) -> n j d', j = 2) 431 | sin, cos = sinu_pos.unbind(dim = -2) 432 | sin, cos = map(lambda t: repeat(t, 'b n -> b (n j)', j = 2), (sin, cos)) 433 | q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k)) 434 | return q, k 435 | 436 | # sinusoidal positional embeddings 437 | class SinExpressionEmbedding(nn.Module): 438 | def __init__(self, d_model: int, dropout: float = 0.1, wavelength: float = 12.0): 439 | super().__init__() 440 | self.d_model = d_model 441 | self.dropout = nn.Dropout(p=dropout) 442 | 443 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(wavelength) / d_model)) 444 | # Add dimensions for easy multiplication with input tensors 445 | div_term = div_term[None, None, :] 446 | 447 | self.register_buffer('div_term', div_term) 448 | 449 | def forward(self, x): 450 | """ 451 | Args: 452 | x: Tensor, shape [batch_size, seq_len] 453 | """ 454 | pe = torch.zeros(x.size(0),x.size(1),self.d_model,device=x.device) 455 | pe[:,:,0::2] = torch.sin(x[:,:,None] * self.div_term) 456 | pe[:,:,1::2] = torch.cos(x[:,:,None] * self.div_term) 457 | return self.dropout(pe) 458 | 459 | class Gene2VecPositionalEmbedding(nn.Module): 460 | def __init__(self, dim, max_seq_len, g2v_file): 461 | super().__init__() 462 | gene2vec_weight = np.load(g2v_file) 463 | gene2vec_weight = np.concatenate((gene2vec_weight, np.zeros((1, gene2vec_weight.shape[1]))), axis=0) 464 | gene2vec_weight = torch.from_numpy(gene2vec_weight) 465 | self.emb = nn.Embedding.from_pretrained(gene2vec_weight) 466 | 467 | def forward(self, x): 468 | t = torch.arange(x.shape[1], device=x.device) 469 | return self.emb(t) 470 | 471 | # performer 472 | 473 | class Performer(nn.Module): 474 | def __init__( 475 | self, 476 | dim, # dimension 477 | depth, # layers 478 | heads, # heads 479 | dim_head, # dim of head 480 | local_attn_heads = 0, # num of local attention heads, (heads - 
local_attn_heads) is num of global performers 481 | local_window_size = 256, # window size of local attention 482 | causal = False, # autoregressive or not 483 | ff_mult = 4, # dim of intermediate features after attention / dim of input features 484 | nb_features = None, # number of random features (rows of the random projection matrix used to approximate attention); if not set, defaults to int(d * log(d)), where d is the dimension of each head 485 | feature_redraw_interval = 1000, # how frequently to redraw the projection matrix, the more frequent, the slower the training 486 | reversible = False, # reversible layers, from Reformer (save memory) 487 | ff_chunks = 1, # chunk feedforward layer, from Reformer 488 | generalized_attention = False, # False: approximate softmax attention with the softmax kernel; True: apply kernel_fn to the randomly projected queries and keys instead 489 | kernel_fn = nn.ReLU(), # the kernel function to be used if generalized attention is turned on, defaults to ReLU 490 | use_scalenorm = False, # use scale norm, from 'Transformers without Tears' paper, a substitute for LayerNorm, priority: scalenorm.rezero.layernorm 491 | use_rezero = False, # use Rezero or not, from 'Rezero is all you need' paper, a substitute for LayerNorm, priority: scalenorm.rezero.layernorm 492 | ff_glu = False, # use GLU (Gated Linear Units) variant for feedforward 493 | ff_dropout = 0., # feedforward dropout 494 | attn_dropout = 0., # post-attention dropout 495 | cross_attend = False, # if True, add a cross-attention block (attending to `context`) after each self-attention block 496 | no_projection = False, # if True, skip the random-feature projection and softmax queries and keys directly 497 | auto_check_redraw = True, # if True, check on every forward pass whether the projection matrices are due for a redraw 498 | qkv_bias = True, # whether the q, k, v linear projections use a bias term 
499 | ): 500 | super().__init__() 501 | layers = nn.ModuleList([]) 502 | local_attn_heads = cast_tuple(local_attn_heads) 503 | local_attn_heads = local_attn_heads * depth if len(local_attn_heads) == 1 else local_attn_heads 504 | assert len(local_attn_heads) == depth, 'tuple specifying number of local attention heads per depth must be equal to the total depth' 505 | assert all(map(lambda n: n >= 0 and n <= heads, local_attn_heads)), 'local attention head value must be less than the total number of heads' 506 | 507 | if use_scalenorm: 508 | wrapper_fn = partial(PreScaleNorm, dim) 509 | elif use_rezero: 510 | wrapper_fn = ReZero 511 | else: 512 | wrapper_fn = partial(PreLayerNorm, dim) 513 | 514 | for _, local_heads in zip(range(depth), local_attn_heads): 515 | layers.append(nn.ModuleList([ 516 | wrapper_fn(SelfAttention(dim, causal = causal, heads = heads, dim_head = dim_head, local_heads = local_heads, local_window_size = local_window_size, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, dropout = attn_dropout, no_projection = no_projection, qkv_bias = qkv_bias)), 517 | wrapper_fn(Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1)) 518 | ])) 519 | # if no need cross_attend(decoder), begin next cycle 520 | if not cross_attend: 521 | continue 522 | layers.append(nn.ModuleList([ 523 | wrapper_fn(SelfAttention(dim, heads = heads, dim_head = dim_head, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, dropout = attn_dropout, no_projection = no_projection)), 524 | wrapper_fn(Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1)) 525 | ])) 526 | 527 | execute_type = ReversibleSequence if reversible else SequentialSequence 528 | 529 | route_attn = ((True, False),) * depth * (2 if cross_attend else 1) # ((True, False), (True, False), (True, False), (True, False), (True, False), (True, 
False)) 530 | route_context = ((False, False), (True, False)) * depth 531 | attn_route_map = {'mask': route_attn, 'pos_emb': route_attn} 532 | context_route_map = {'context': route_context, 'context_mask': route_context} if cross_attend else {} 533 | self.net = execute_type(layers, args_route = {**attn_route_map, **context_route_map}) 534 | 535 | # keeping track of when to redraw projections for all attention layers 536 | self.auto_check_redraw = auto_check_redraw 537 | self.feature_redraw_interval = feature_redraw_interval 538 | self.register_buffer('calls_since_last_redraw', torch.tensor(0)) 539 | 540 | def fix_projection_matrices_(self): 541 | self.feature_redraw_interval = None 542 | 543 | def check_redraw_projections(self): 544 | if not self.training: 545 | return 546 | 547 | if exists(self.feature_redraw_interval) and self.calls_since_last_redraw >= self.feature_redraw_interval: 548 | device = get_module_device(self) 549 | 550 | fast_attentions = find_modules(self, FastAttention) 551 | for fast_attention in fast_attentions: 552 | fast_attention.redraw_projection_matrix(device) 553 | 554 | self.calls_since_last_redraw.zero_() 555 | return 556 | 557 | self.calls_since_last_redraw += 1 558 | 559 | def forward(self, x, output_attentions = False, **kwargs): 560 | if self.auto_check_redraw: 561 | self.check_redraw_projections() 562 | return self.net(x, output_attentions = output_attentions, **kwargs) 563 | 564 | class PerformerLM(nn.Module): 565 | def __init__( 566 | self, 567 | *, 568 | num_tokens, # num of tokens 569 | max_seq_len, # max length of sequence 570 | dim, # dim of tokens 571 | depth, # layers 572 | heads, # num of heads 573 | dim_head = 64, # dim of heads 574 | local_attn_heads = 0, 575 | local_window_size = 256, 576 | causal = False, 577 | ff_mult = 4, 578 | nb_features = None, 579 | feature_redraw_interval = 1000, 580 | reversible = False, 581 | ff_chunks = 1, 582 | ff_glu = False, 583 | emb_dropout = 0., 584 | ff_dropout = 0., 585 | attn_dropout = 
0., 586 | generalized_attention = False, 587 | kernel_fn = nn.ReLU(), 588 | use_scalenorm = False, 589 | use_rezero = False, 590 | cross_attend = False, 591 | no_projection = False, 592 | tie_embed = False, # False: output is num of tokens, True: output is dim of tokens //multiply final embeddings with token weights for logits, like gpt decoder// 593 | g2v_position_emb = True, # priority: gene2vec, no embedding 594 | g2v_file = None, 595 | sin_emb_wavelength = 12.0, 596 | auto_check_redraw = True, 597 | qkv_bias = False, 598 | pred_continuous = False # False: predict the categorical bucketed expression, True: predict the continuous expression value 599 | ): 600 | super().__init__() 601 | local_attn_heads = cast_tuple(local_attn_heads) 602 | 603 | self.max_seq_len = max_seq_len 604 | if(pred_continuous): 605 | self.token_emb = SinExpressionEmbedding(dim, wavelength = sin_emb_wavelength) 606 | else: 607 | self.token_emb = nn.Embedding(num_tokens, dim) 608 | 609 | if g2v_position_emb: 610 | self.pos_emb = Gene2VecPositionalEmbedding(dim, max_seq_len, g2v_file) 611 | self.layer_pos_emb = Always(None) 612 | else: 613 | self.pos_emb = torch.zeros_like 614 | self.layer_pos_emb = Always(None) 615 | 616 | self.dropout = nn.Dropout(emb_dropout) 617 | 618 | self.performer = Performer(dim, depth, heads, dim_head, local_attn_heads, local_window_size, causal, ff_mult, nb_features, feature_redraw_interval, reversible, ff_chunks, generalized_attention, kernel_fn, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend, no_projection, auto_check_redraw, qkv_bias) 619 | self.norm = nn.LayerNorm(dim) 620 | if pred_continuous: 621 | self.to_out = nn.Linear(dim, 1) 622 | else: 623 | self.to_out = nn.Linear(dim, num_tokens) if not tie_embed else None 624 | 625 | def check_redraw_projections(self): 626 | self.performer.check_redraw_projections() 627 | 628 | def fix_projection_matrices_(self): 629 | self.performer.fix_projection_matrices_() 630 | 631 | def forward(self, 
x, return_encodings = False, output_attentions = False, **kwargs): 632 | b, n, device = *x.shape, x.device 633 | assert n <= self.max_seq_len, f'sequence length {n} must be less than the max sequence length {self.max_seq_len}' 634 | 635 | # token and positional embedding 636 | x = self.token_emb(x) # expression level 637 | if output_attentions: 638 | x.requires_grad_() # used for attn_map output 639 | x += self.pos_emb(x) # add gene embedding (gene2vec or zeroes) 640 | x = self.dropout(x) 641 | 642 | # performer layers 643 | layer_pos_emb = self.layer_pos_emb(x) #this returns None (Rebecca) 644 | 645 | if output_attentions: 646 | x, attn_weights = self.performer(x, pos_emb = layer_pos_emb, output_attentions = output_attentions, **kwargs) 647 | # norm and to logits 648 | x = self.norm(x) 649 | if return_encodings: 650 | return self.to_out(x), x, attn_weights 651 | 652 | if exists(self.to_out): 653 | return self.to_out(x), attn_weights 654 | 655 | return (x @ self.token_emb.weight.t()), attn_weights 656 | else: 657 | x = self.performer(x, pos_emb = layer_pos_emb, output_attentions = output_attentions, **kwargs) 658 | # norm and to logits 659 | x = self.norm(x) 660 | if return_encodings: 661 | return self.to_out(x), x 662 | 663 | if exists(self.to_out): 664 | x = self.to_out(x) 665 | return x #batch size x seq len x hidden (logits) 666 | 667 | return x @ self.token_emb.weight.t() #batch size x seq len x vocab size 668 | -------------------------------------------------------------------------------- /scBERT/performer_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from performer_pytorch.performer_pytorch import PerformerLM, Performer, FastAttention, SelfAttention 2 | -------------------------------------------------------------------------------- /scBERT/performer_pytorch/performer_pytorch.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 
3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | from torch.cuda.amp import autocast 7 | from einops import rearrange, repeat 8 | 9 | from functools import partial 10 | from contextlib import contextmanager 11 | 12 | from local_attention import LocalAttention 13 | from performer_pytorch.reversible import ReversibleSequence, SequentialSequence 14 | 15 | try: 16 | from apex import amp 17 | APEX_AVAILABLE = True 18 | except: 19 | APEX_AVAILABLE = False 20 | 21 | # helpers 22 | 23 | def exists(val): 24 | return val is not None 25 | 26 | def empty(tensor): 27 | return tensor.numel() == 0 28 | 29 | def default(val, d): 30 | return val if exists(val) else d 31 | 32 | @contextmanager 33 | def null_context(): 34 | yield 35 | 36 | def cast_tuple(val): 37 | return (val,) if not isinstance(val, tuple) else val 38 | 39 | # def get_module_device(module): 40 | # return next(module.parameters).device 41 | 42 | def get_module_device(module): 43 | try: 44 | return next(module.parameters()).device 45 | except StopIteration: 46 | # For nn.DataParallel compatibility in PyTorch 1.5 47 | def find_tensor_attributes(module): 48 | tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] 49 | return tuples 50 | gen = module._named_members(get_members_fn=find_tensor_attributes) 51 | first_tuple = next(gen) 52 | return first_tuple[1].device 53 | 54 | def find_modules(nn_module, type): 55 | return [module for module in nn_module.modules() if isinstance(module, type)] 56 | 57 | class Always(nn.Module): 58 | def __init__(self, val): 59 | super().__init__() 60 | self.val = val 61 | def forward(self, *args, **kwargs): 62 | return self.val 63 | 64 | # kernel functions 65 | 66 | # transcribed from jax to pytorch from 67 | # https://github.com/google-research/google-research/blob/master/performer/fast_attention/jax/fast_attention.py 68 | 69 | def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None): 70 | 
b, h, *_ = data.shape 71 | 72 | data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1. 73 | 74 | ratio = (projection_matrix.shape[0] ** -0.5) 75 | 76 | projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h) 77 | projection = projection.type_as(data) 78 | 79 | data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection.clone()) 80 | 81 | diag_data = data ** 2 82 | diag_data = torch.sum(diag_data, dim=-1) 83 | diag_data = (diag_data / 2.0) * (data_normalizer ** 2) 84 | diag_data = diag_data.unsqueeze(dim=-1) 85 | 86 | if is_query: 87 | data_dash = ratio * ( 88 | torch.exp(data_dash - diag_data - 89 | torch.max(data_dash, dim=-1, keepdim=True).values) + eps) 90 | else: 91 | data_dash = ratio * ( 92 | torch.exp(data_dash - diag_data - torch.max(data_dash)) + eps) 93 | 94 | return data_dash.type_as(data) 95 | 96 | def generalized_kernel(data, *, projection_matrix, kernel_fn = nn.ReLU(), kernel_epsilon = 0.001, normalize_data = True, device = None): 97 | b, h, *_ = data.shape 98 | 99 | data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1. 
100 | 101 | if projection_matrix is None: 102 | return kernel_fn(data_normalizer * data) + kernel_epsilon 103 | 104 | projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h) 105 | projection = projection.type_as(data) 106 | 107 | data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection) 108 | 109 | data_prime = kernel_fn(data_dash) + kernel_epsilon 110 | return data_prime.type_as(data) 111 | 112 | def orthogonal_matrix_chunk(cols, device = None): 113 | unstructured_block = torch.randn((cols, cols), device = device) 114 | q, r = torch.linalg.qr(unstructured_block.cpu(), mode = 'reduced') 115 | q, r = map(lambda t: t.to(device), (q, r)) 116 | return q.t() 117 | 118 | def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, device = None): 119 | nb_full_blocks = int(nb_rows / nb_columns) 120 | 121 | block_list = [] 122 | 123 | for _ in range(nb_full_blocks): 124 | q = orthogonal_matrix_chunk(nb_columns, device = device) 125 | block_list.append(q) 126 | 127 | remaining_rows = nb_rows - nb_full_blocks * nb_columns 128 | if remaining_rows > 0: 129 | q = orthogonal_matrix_chunk(nb_columns, device = device) 130 | block_list.append(q[:remaining_rows]) 131 | 132 | final_matrix = torch.cat(block_list) 133 | 134 | if scaling == 0: 135 | multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1) 136 | elif scaling == 1: 137 | multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device = device) 138 | else: 139 | raise ValueError(f'Invalid scaling {scaling}') 140 | 141 | return torch.diag(multiplier) @ final_matrix 142 | 143 | # linear attention classes with softmax kernel 144 | 145 | # non-causal linear attention 146 | def linear_attention(q, k, v): 147 | k_cumsum = k.sum(dim = -2) 148 | D_inv = 1. 
/ torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q)) 149 | context = torch.einsum('...nd,...ne->...de', k, v) 150 | out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv) 151 | return out 152 | 153 | # efficient causal linear attention, created by EPFL 154 | # TODO: rewrite EPFL's CUDA kernel to do mixed precision and remove half to float conversion and back 155 | def causal_linear_attention(q, k, v, eps = 1e-6): 156 | from fast_transformers.causal_product import CausalDotProduct 157 | autocast_enabled = torch.is_autocast_enabled() 158 | is_half = isinstance(q, torch.cuda.HalfTensor) 159 | assert not is_half or APEX_AVAILABLE, 'half tensors can only be used if nvidia apex is available' 160 | cuda_context = null_context if not autocast_enabled else partial(autocast, enabled = False) 161 | 162 | causal_dot_product_fn = amp.float_function(CausalDotProduct.apply) if is_half else CausalDotProduct.apply 163 | 164 | k_cumsum = k.cumsum(dim=-2) + eps 165 | D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q)) 166 | 167 | with cuda_context(): 168 | if autocast_enabled: 169 | q, k, v = map(lambda t: t.float(), (q, k, v)) 170 | 171 | out = causal_dot_product_fn(q, k, v) 172 | 173 | out = torch.einsum('...nd,...n->...nd', out, D_inv) 174 | return out 175 | 176 | # inefficient causal linear attention, without cuda code, for reader's reference 177 | # not being used 178 | def causal_linear_attention_noncuda(q, k, v, chunk_size = 128): 179 | last_k_cumsum = 0 180 | last_context_cumsum = 0 181 | outs = [] 182 | 183 | for q, k, v in zip(*map(lambda t: t.chunk(chunk_size, dim = -2), (q, k, v))): 184 | k_cumsum = last_k_cumsum + k.cumsum(dim=-2) 185 | 186 | D_inv = 1. 
/ torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q)) 187 | context = torch.einsum('...nd,...ne->...nde', k, v) 188 | context_cumsum = last_context_cumsum + context.cumsum(dim=-3) 189 | out = torch.einsum('...nde,...nd,...n->...ne', context_cumsum, q, D_inv) 190 | 191 | last_k_cumsum = k_cumsum[:, :, -1:] 192 | last_context_cumsum = context_cumsum[:, :, -1:] 193 | outs.append(out) 194 | 195 | return torch.cat(outs, dim = -2) 196 | 197 | def norm_tensor(tensor, dim=-1): 198 | return tensor / tensor.sum(dim=dim).unsqueeze(dim) 199 | 200 | class FastAttention(nn.Module): 201 | def __init__(self, dim_heads, nb_features = None, ortho_scaling = 0, causal = False, generalized_attention = False, kernel_fn = nn.ReLU(), no_projection = False): 202 | super().__init__() 203 | nb_features = default(nb_features, int(dim_heads * math.log(dim_heads))) 204 | 205 | self.dim_heads = dim_heads 206 | self.nb_features = nb_features 207 | self.ortho_scaling = ortho_scaling 208 | 209 | self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows = self.nb_features, nb_columns = dim_heads, scaling = ortho_scaling) 210 | projection_matrix = self.create_projection() 211 | self.register_buffer('projection_matrix', projection_matrix) 212 | 213 | self.generalized_attention = generalized_attention 214 | self.kernel_fn = kernel_fn 215 | 216 | # if this is turned on, no projection will be used 217 | # queries and keys will be softmax-ed as in the original efficient attention paper 218 | self.no_projection = no_projection 219 | 220 | self.causal = causal 221 | if causal: 222 | try: 223 | import fast_transformers.causal_product.causal_product_cuda 224 | self.causal_linear_fn = partial(causal_linear_attention) 225 | except ImportError: 226 | print('unable to import cuda code for auto-regressive Performer. 
will default to the memory inefficient non-cuda version') 227 | self.causal_linear_fn = causal_linear_attention_noncuda 228 | 229 | @torch.no_grad() 230 | def redraw_projection_matrix(self, device): 231 | projections = self.create_projection(device = device) 232 | self.projection_matrix.copy_(projections) 233 | del projections 234 | 235 | def forward(self, q, k, v, output_attentions = False): 236 | device = q.device 237 | # inds = [8060, 8064, 6243, 8575, 10342, 10913, 9366, 993, 7796, 5210, 5212, 5504, 6851, 6559, 5508, 13107, 13820] 238 | if self.no_projection: 239 | q = q.softmax(dim = -1) 240 | k = torch.exp(k) if self.causal else k.softmax(dim = -2) 241 | 242 | elif self.generalized_attention: 243 | create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device) 244 | q, k = map(create_kernel, (q, k)) 245 | 246 | else: 247 | create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device) 248 | q = create_kernel(q, is_query = True) 249 | k = create_kernel(k, is_query = False) 250 | 251 | attn_fn = linear_attention if not self.causal else self.causal_linear_fn 252 | out = attn_fn(q, k, v) 253 | if output_attentions: 254 | v_diag = torch.eye(v.shape[-2]).to(device) 255 | v_diag = v_diag.unsqueeze(0).unsqueeze(0).repeat(v.shape[0],v.shape[1],1,1) 256 | # attn_weights = torch.zeros(1, 1, len(inds), len(inds)).to(device).to(torch.float16) 257 | # attn_weights = torch.zeros(1, q.shape[1], len(inds), len(inds)).to(device).to(torch.float16) 258 | attn_weights = torch.zeros(1, 1, q.shape[2], q.shape[2]).to(device).to(torch.float16) 259 | for head_dim in range(q.shape[1]): 260 | # attn_weights[0, head_dim] = torch.abs(attn_fn(q[:,head_dim].to(torch.float16), k[:,head_dim].to(torch.float16), v_diag[:,head_dim].to(torch.float16)))[0, inds][:, inds] 261 | attn_weights += torch.abs(attn_fn(q[:,head_dim].to(torch.float16), k[:,head_dim].to(torch.float16), 
v_diag[:,head_dim].to(torch.float16))) 262 | # attn_weights += norm_tensor(torch.abs(attn_fn(q[:,head_dim].to(torch.float16), k[:,head_dim].to(torch.float16), v_diag[:,head_dim].to(torch.float16))), dim=-1) 263 | attn_weights /= q.shape[1] 264 | return out, attn_weights 265 | else: 266 | return out 267 | 268 | # classes 269 | 270 | class ReZero(nn.Module): 271 | def __init__(self, fn): 272 | super().__init__() 273 | self.g = nn.Parameter(torch.tensor(1e-3)) 274 | self.fn = fn 275 | 276 | def forward(self, x, **kwargs): 277 | return self.fn(x, **kwargs) * self.g 278 | 279 | class PreScaleNorm(nn.Module): 280 | def __init__(self, dim, fn, eps=1e-5): 281 | super().__init__() 282 | self.fn = fn 283 | self.g = nn.Parameter(torch.ones(1)) 284 | self.eps = eps 285 | 286 | def forward(self, x, **kwargs): 287 | n = torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps) 288 | x = x / n * self.g 289 | return self.fn(x, **kwargs) 290 | 291 | class PreLayerNorm(nn.Module): 292 | def __init__(self, dim, fn): 293 | super().__init__() 294 | self.norm = nn.LayerNorm(dim) 295 | self.fn = fn 296 | def forward(self, x, **kwargs): 297 | return self.fn(self.norm(x), **kwargs) 298 | 299 | class Chunk(nn.Module): 300 | def __init__(self, chunks, fn, along_dim = -1): 301 | super().__init__() 302 | self.dim = along_dim 303 | self.chunks = chunks 304 | self.fn = fn 305 | 306 | def forward(self, x, **kwargs): 307 | if self.chunks == 1: 308 | return self.fn(x, **kwargs) 309 | chunks = x.chunk(self.chunks, dim = self.dim) 310 | return torch.cat([self.fn(c, **kwargs) for c in chunks], dim = self.dim) 311 | 312 | class FeedForward(nn.Module): 313 | def __init__(self, dim, mult = 4, dropout = 0., activation = None, glu = False): 314 | super().__init__() 315 | activation = default(activation, nn.GELU) 316 | 317 | self.glu = glu 318 | self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1)) 319 | self.act = activation() 320 | self.dropout = nn.Dropout(dropout) 321 | self.w2 = nn.Linear(dim * mult, 
dim) 322 | 323 | def forward(self, x, **kwargs): 324 | if not self.glu: 325 | x = self.w1(x) 326 | x = self.act(x) 327 | else: 328 | x, v = self.w1(x).chunk(2, dim=-1) 329 | x = self.act(x) * v 330 | 331 | x = self.dropout(x) 332 | x = self.w2(x) 333 | return x 334 | 335 | class SelfAttention(nn.Module): 336 | def __init__( 337 | self, 338 | dim, 339 | causal = False, 340 | heads = 8, 341 | dim_head = 64, 342 | local_heads = 0, 343 | local_window_size = 256, 344 | nb_features = None, 345 | feature_redraw_interval = 1000, 346 | generalized_attention = False, 347 | kernel_fn = nn.ReLU(), 348 | dropout = 0., 349 | no_projection = False, 350 | qkv_bias = False 351 | ): 352 | super().__init__() 353 | assert dim % heads == 0, 'dimension must be divisible by number of heads' 354 | dim_head = default(dim_head, dim // heads) 355 | inner_dim = dim_head * heads 356 | self.fast_attention = FastAttention(dim_head, nb_features, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, no_projection = no_projection) 357 | 358 | self.heads = heads 359 | self.global_heads = heads - local_heads 360 | self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None 361 | 362 | self.to_q = nn.Linear(dim, inner_dim, bias = qkv_bias) 363 | self.to_k = nn.Linear(dim, inner_dim, bias = qkv_bias) 364 | self.to_v = nn.Linear(dim, inner_dim, bias = qkv_bias) 365 | self.to_out = nn.Linear(inner_dim, dim) 366 | self.dropout = nn.Dropout(dropout) 367 | 368 | def forward(self, x, pos_emb = None, context = None, mask = None, context_mask = None, output_attentions = False, **kwargs): 369 | b, n, _, h, gh = *x.shape, self.heads, self.global_heads 370 | 371 | cross_attend = exists(context) 372 | 373 | context = default(context, x) 374 | context_mask = default(context_mask, mask) if not cross_attend else 
context_mask 375 | 376 | q, k, v = self.to_q(x), self.to_k(context), self.to_v(context) 377 | 378 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 379 | (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v)) 380 | 381 | attn_outs = [] 382 | 383 | if not empty(q): 384 | if exists(context_mask): 385 | global_mask = context_mask[:, None, :, None] 386 | v.masked_fill_(~global_mask, 0.) 387 | 388 | if exists(pos_emb) and not cross_attend: 389 | q, k, = apply_rotary_pos_emb(q, k, pos_emb) 390 | 391 | if output_attentions: 392 | out, attn_weights = self.fast_attention(q, k, v, output_attentions) 393 | else: 394 | out = self.fast_attention(q, k, v) 395 | attn_outs.append(out) 396 | 397 | if not empty(lq): 398 | assert not cross_attend, 'local attention is not compatible with cross attention' 399 | out = self.local_attn(lq, lk, lv, input_mask = mask) 400 | attn_outs.append(out) 401 | 402 | out = torch.cat(attn_outs, dim = 1) # concatenate global (fast-attention) and local-attention head outputs along the head dim; with local_heads = 0 only the global output exists, so the cat is a no-op 403 | out = rearrange(out, 'b h n d -> b n (h d)') 404 | out = self.to_out(out) 405 | if output_attentions: 406 | return self.dropout(out), attn_weights 407 | else: 408 | return self.dropout(out) 409 | 410 | # positional embeddings 411 | 412 | class AbsolutePositionalEmbedding(nn.Module): 413 | def __init__(self, dim, max_seq_len): 414 | super().__init__() 415 | self.emb = nn.Embedding(max_seq_len, dim) 416 | 417 | def forward(self, x): 418 | t = torch.arange(x.shape[1], device=x.device) 419 | return self.emb(t) 420 | 421 | # rotary positional embedding helpers 422 | 423 | def rotate_every_two(x): 424 | x = rearrange(x, '... (d j) -> ... d j', j = 2) 425 | x1, x2 = x.unbind(dim = -1) 426 | x = torch.stack((-x2, x1), dim = -1) 427 | return rearrange(x, '... d j -> ... 
(d j)') 428 | 429 | def apply_rotary_pos_emb(q, k, sinu_pos): 430 | sinu_pos = rearrange(sinu_pos, '() n (j d) -> n j d', j = 2) 431 | sin, cos = sinu_pos.unbind(dim = -2) 432 | sin, cos = map(lambda t: repeat(t, 'b n -> b (n j)', j = 2), (sin, cos)) 433 | q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k)) 434 | return q, k 435 | 436 | # sinusoidal positional embeddings 437 | class SinExpressionEmbedding(nn.Module): 438 | def __init__(self, d_model: int, dropout: float = 0.1, wavelength: float = 12.0): 439 | super().__init__() 440 | self.d_model = d_model 441 | self.dropout = nn.Dropout(p=dropout) 442 | 443 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(wavelength) / d_model)) 444 | # Add dimensions for easy multiplication with input tensors 445 | div_term = div_term[None, None, :] 446 | 447 | self.register_buffer('div_term', div_term) 448 | 449 | def forward(self, x): 450 | """ 451 | Args: 452 | x: Tensor, shape [batch_size, seq_len] 453 | """ 454 | pe = torch.zeros(x.size(0),x.size(1),self.d_model,device=x.device) 455 | pe[:,:,0::2] = torch.sin(x[:,:,None] * self.div_term) 456 | pe[:,:,1::2] = torch.cos(x[:,:,None] * self.div_term) 457 | return self.dropout(pe) 458 | 459 | class Gene2VecPositionalEmbedding(nn.Module): 460 | def __init__(self, dim, max_seq_len, g2v_file): 461 | super().__init__() 462 | gene2vec_weight = np.load(g2v_file) 463 | gene2vec_weight = np.concatenate((gene2vec_weight, np.zeros((1, gene2vec_weight.shape[1]))), axis=0) 464 | gene2vec_weight = torch.from_numpy(gene2vec_weight) 465 | self.emb = nn.Embedding.from_pretrained(gene2vec_weight) 466 | 467 | def forward(self, x): 468 | t = torch.arange(x.shape[1], device=x.device) 469 | return self.emb(t) 470 | 471 | # performer 472 | 473 | class Performer(nn.Module): 474 | def __init__( 475 | self, 476 | dim, # dimension 477 | depth, # layers 478 | heads, # heads 479 | dim_head, # dim of head 480 | local_attn_heads = 0, # num of local attention heads, (heads - 
local_attn_heads) is num of global performers 481 | local_window_size = 256, # window size of local attention 482 | causal = False, # autoregressive or not 483 | ff_mult = 4, # dim of intermediate features after attention / dim of input features 484 | nb_features = None, # number of random features (rows of the random projection matrix used to approximate attention); if not set, defaults to int(d * log(d)), where d is the dimension of each head 485 | feature_redraw_interval = 1000, # how frequently to redraw the projection matrix, the more frequent, the slower the training 486 | reversible = False, # reversible layers, from Reformer (save memory) 487 | ff_chunks = 1, # chunk feedforward layer, from Reformer 488 | generalized_attention = False, # False: approximate softmax attention with the softmax kernel; True: apply kernel_fn to the randomly projected queries and keys instead 489 | kernel_fn = nn.ReLU(), # the kernel function to be used if generalized attention is turned on, defaults to ReLU 490 | use_scalenorm = False, # use scale norm, from 'Transformers without Tears' paper, a substitute for LayerNorm, priority: scalenorm.rezero.layernorm 491 | use_rezero = False, # use Rezero or not, from 'Rezero is all you need' paper, a substitute for LayerNorm, priority: scalenorm.rezero.layernorm 492 | ff_glu = False, # use GLU (Gated Linear Units) variant for feedforward 493 | ff_dropout = 0., # feedforward dropout 494 | attn_dropout = 0., # post-attention dropout 495 | cross_attend = False, # if True, add a cross-attention block (attending to `context`) after each self-attention block 496 | no_projection = False, # if True, skip the random-feature projection and softmax queries and keys directly 497 | auto_check_redraw = True, # if True, check on every forward pass whether the projection matrices are due for a redraw 498 | qkv_bias = True, # whether the q, k, v linear projections use a bias term 
499 | ): 500 | super().__init__() 501 | layers = nn.ModuleList([]) 502 | local_attn_heads = cast_tuple(local_attn_heads) 503 | local_attn_heads = local_attn_heads * depth if len(local_attn_heads) == 1 else local_attn_heads 504 | assert len(local_attn_heads) == depth, 'tuple specifying number of local attention heads per depth must be equal to the total depth' 505 | assert all(map(lambda n: n >= 0 and n <= heads, local_attn_heads)), 'local attention head value must be less than the total number of heads' 506 | 507 | if use_scalenorm: 508 | wrapper_fn = partial(PreScaleNorm, dim) 509 | elif use_rezero: 510 | wrapper_fn = ReZero 511 | else: 512 | wrapper_fn = partial(PreLayerNorm, dim) 513 | 514 | for _, local_heads in zip(range(depth), local_attn_heads): 515 | layers.append(nn.ModuleList([ 516 | wrapper_fn(SelfAttention(dim, causal = causal, heads = heads, dim_head = dim_head, local_heads = local_heads, local_window_size = local_window_size, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, dropout = attn_dropout, no_projection = no_projection, qkv_bias = qkv_bias)), 517 | wrapper_fn(Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1)) 518 | ])) 519 | # if no need cross_attend(decoder), begin next cycle 520 | if not cross_attend: 521 | continue 522 | layers.append(nn.ModuleList([ 523 | wrapper_fn(SelfAttention(dim, heads = heads, dim_head = dim_head, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, dropout = attn_dropout, no_projection = no_projection)), 524 | wrapper_fn(Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1)) 525 | ])) 526 | 527 | execute_type = ReversibleSequence if reversible else SequentialSequence 528 | 529 | route_attn = ((True, False),) * depth * (2 if cross_attend else 1) # ((True, False), (True, False), (True, False), (True, False), (True, False), (True, 
False)) 530 | route_context = ((False, False), (True, False)) * depth 531 | attn_route_map = {'mask': route_attn, 'pos_emb': route_attn} 532 | context_route_map = {'context': route_context, 'context_mask': route_context} if cross_attend else {} 533 | self.net = execute_type(layers, args_route = {**attn_route_map, **context_route_map}) 534 | 535 | # keeping track of when to redraw projections for all attention layers 536 | self.auto_check_redraw = auto_check_redraw 537 | self.feature_redraw_interval = feature_redraw_interval 538 | self.register_buffer('calls_since_last_redraw', torch.tensor(0)) 539 | 540 | def fix_projection_matrices_(self): 541 | self.feature_redraw_interval = None 542 | 543 | def check_redraw_projections(self): 544 | if not self.training: 545 | return 546 | 547 | if exists(self.feature_redraw_interval) and self.calls_since_last_redraw >= self.feature_redraw_interval: 548 | device = get_module_device(self) 549 | 550 | fast_attentions = find_modules(self, FastAttention) 551 | for fast_attention in fast_attentions: 552 | fast_attention.redraw_projection_matrix(device) 553 | 554 | self.calls_since_last_redraw.zero_() 555 | return 556 | 557 | self.calls_since_last_redraw += 1 558 | 559 | def forward(self, x, output_attentions = False, **kwargs): 560 | if self.auto_check_redraw: 561 | self.check_redraw_projections() 562 | return self.net(x, output_attentions = output_attentions, **kwargs) 563 | 564 | class PerformerLM(nn.Module): 565 | def __init__( 566 | self, 567 | *, 568 | num_tokens, # num of tokens 569 | max_seq_len, # max length of sequence 570 | dim, # dim of tokens 571 | depth, # layers 572 | heads, # num of heads 573 | dim_head = 64, # dim of heads 574 | local_attn_heads = 0, 575 | local_window_size = 256, 576 | causal = False, 577 | ff_mult = 4, 578 | nb_features = None, 579 | feature_redraw_interval = 1000, 580 | reversible = False, 581 | ff_chunks = 1, 582 | ff_glu = False, 583 | emb_dropout = 0., 584 | ff_dropout = 0., 585 | attn_dropout = 
0., 586 | generalized_attention = False, 587 | kernel_fn = nn.ReLU(), 588 | use_scalenorm = False, 589 | use_rezero = False, 590 | cross_attend = False, 591 | no_projection = False, 592 | tie_embed = False, # False: output is num of tokens, True: output is dim of tokens //multiply final embeddings with token weights for logits, like gpt decoder// 593 | g2v_position_emb = True, # priority: gene2vec, no embedding 594 | g2v_file = None, 595 | sin_emb_wavelength = 12.0, 596 | auto_check_redraw = True, 597 | qkv_bias = False, 598 | pred_continuous = False # False: predict the categorical bucketed expression, True: predict the continuous expression value 599 | ): 600 | super().__init__() 601 | local_attn_heads = cast_tuple(local_attn_heads) 602 | 603 | self.max_seq_len = max_seq_len 604 | if(pred_continuous): 605 | self.token_emb = SinExpressionEmbedding(dim, wavelength = sin_emb_wavelength) 606 | else: 607 | self.token_emb = nn.Embedding(num_tokens, dim) 608 | 609 | if g2v_position_emb: 610 | self.pos_emb = Gene2VecPositionalEmbedding(dim, max_seq_len, g2v_file) 611 | self.layer_pos_emb = Always(None) 612 | else: 613 | self.pos_emb = torch.zeros_like 614 | self.layer_pos_emb = Always(None) 615 | 616 | self.dropout = nn.Dropout(emb_dropout) 617 | 618 | self.performer = Performer(dim, depth, heads, dim_head, local_attn_heads, local_window_size, causal, ff_mult, nb_features, feature_redraw_interval, reversible, ff_chunks, generalized_attention, kernel_fn, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend, no_projection, auto_check_redraw, qkv_bias) 619 | self.norm = nn.LayerNorm(dim) 620 | if pred_continuous: 621 | self.to_out = nn.Linear(dim, 1) 622 | else: 623 | self.to_out = nn.Linear(dim, num_tokens) if not tie_embed else None 624 | 625 | def check_redraw_projections(self): 626 | self.performer.check_redraw_projections() 627 | 628 | def fix_projection_matrices_(self): 629 | self.performer.fix_projection_matrices_() 630 | 631 | def forward(self, 
x, return_encodings = False, output_attentions = False, **kwargs): 632 | b, n, device = *x.shape, x.device 633 | assert n <= self.max_seq_len, f'sequence length {n} must not exceed the max sequence length {self.max_seq_len}' 634 | # token and positional embedding 635 | x = self.token_emb(x) # expression level 636 | if output_attentions: 637 | x.requires_grad_() # used for attn_map output 638 | x += self.pos_emb(x) # add gene embedding (gene2vec or zeroes) 639 | x = self.dropout(x) 640 | 641 | # performer layers 642 | layer_pos_emb = self.layer_pos_emb(x) #this returns None (Rebecca) 643 | if output_attentions: 644 | x, attn_weights = self.performer(x, pos_emb = layer_pos_emb, output_attentions = output_attentions, **kwargs) 645 | # norm and to logits 646 | x = self.norm(x) 647 | if return_encodings: 648 | return self.to_out(x), x, attn_weights 649 | 650 | if exists(self.to_out): 651 | return self.to_out(x), attn_weights 652 | 653 | return (x @ self.token_emb.weight.t()), attn_weights 654 | else: 655 | x = self.performer(x, pos_emb = layer_pos_emb, output_attentions = output_attentions, **kwargs) 656 | # norm and to logits 657 | x = self.norm(x) 658 | if return_encodings: 659 | return self.to_out(x), x 660 | 661 | if exists(self.to_out): 662 | x = self.to_out(x) 663 | return x #batch size x seq len x num_tokens (logits), or x 1 if pred_continuous 664 | 665 | return x @ self.token_emb.weight.t() #batch size x seq len x vocab size 666 | -------------------------------------------------------------------------------- /scBERT/performer_pytorch/reversible.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from operator import itemgetter 4 | from torch.autograd.function import Function 5 | from torch.utils.checkpoint import get_device_states, set_device_states 6 | 7 | # for routing arguments into the functions of the reversible layer 8 | def route_args(router, args, depth): 9 | routed_args = [(dict(), dict()) for _ in 
range(depth)] 10 | matched_keys = [key for key in args.keys() if key in router] 11 | 12 | for key in matched_keys: 13 | val = args[key] 14 | for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])): 15 | new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes) 16 | routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) 17 | return routed_args 18 | 19 | # following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html 20 | class Deterministic(nn.Module): 21 | def __init__(self, net): 22 | super().__init__() 23 | self.net = net 24 | self.cpu_state = None 25 | self.cuda_in_fwd = None 26 | self.gpu_devices = None 27 | self.gpu_states = None 28 | 29 | def record_rng(self, *args): 30 | self.cpu_state = torch.get_rng_state() 31 | if torch.cuda._initialized: 32 | self.cuda_in_fwd = True 33 | self.gpu_devices, self.gpu_states = get_device_states(*args) 34 | 35 | def forward(self, *args, record_rng = False, set_rng = False, **kwargs): 36 | if record_rng: 37 | self.record_rng(*args) 38 | 39 | if not set_rng: 40 | return self.net(*args, **kwargs) 41 | 42 | rng_devices = [] 43 | if self.cuda_in_fwd: 44 | rng_devices = self.gpu_devices 45 | 46 | with torch.random.fork_rng(devices=rng_devices, enabled=True): 47 | torch.set_rng_state(self.cpu_state) 48 | if self.cuda_in_fwd: 49 | set_device_states(self.gpu_devices, self.gpu_states) 50 | return self.net(*args, **kwargs) 51 | 52 | # heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py 53 | # once multi-GPU is confirmed working, refactor and send PR back to source 54 | class ReversibleBlock(nn.Module): 55 | def __init__(self, f, g): 56 | super().__init__() 57 | self.f = Deterministic(f) 58 | self.g = Deterministic(g) 59 | 60 | def forward(self, x, f_args = {}, g_args = {}): 61 | x1, x2 = torch.chunk(x, 2, dim=2) 62 | y1, y2 = None, None 63 | 64 | with torch.no_grad(): 65 
| y1 = x1 + self.f(x2, record_rng=self.training, **f_args) 66 | y2 = x2 + self.g(y1, record_rng=self.training, **g_args) 67 | 68 | return torch.cat([y1, y2], dim=2) 69 | 70 | def backward_pass(self, y, dy, f_args = {}, g_args = {}): 71 | y1, y2 = torch.chunk(y, 2, dim=2) 72 | del y 73 | 74 | dy1, dy2 = torch.chunk(dy, 2, dim=2) 75 | del dy 76 | 77 | with torch.enable_grad(): 78 | y1.requires_grad = True 79 | gy1 = self.g(y1, set_rng=True, **g_args) 80 | torch.autograd.backward(gy1, dy2) 81 | 82 | with torch.no_grad(): 83 | x2 = y2 - gy1 84 | del y2, gy1 85 | 86 | dx1 = dy1 + y1.grad 87 | del dy1 88 | y1.grad = None 89 | 90 | with torch.enable_grad(): 91 | x2.requires_grad = True 92 | fx2 = self.f(x2, set_rng=True, **f_args) 93 | torch.autograd.backward(fx2, dx1, retain_graph=True) 94 | 95 | with torch.no_grad(): 96 | x1 = y1 - fx2 97 | del y1, fx2 98 | 99 | dx2 = dy2 + x2.grad 100 | del dy2 101 | x2.grad = None 102 | 103 | x = torch.cat([x1, x2.detach()], dim=2) 104 | dx = torch.cat([dx1, dx2], dim=2) 105 | 106 | return x, dx 107 | 108 | class _ReversibleFunction(Function): 109 | @staticmethod 110 | def forward(ctx, x, blocks, args): 111 | ctx.args = args 112 | for block, kwarg in zip(blocks, args): 113 | x = block(x, **kwarg) 114 | ctx.y = x.detach() 115 | ctx.blocks = blocks 116 | return x 117 | 118 | @staticmethod 119 | def backward(ctx, dy): 120 | y = ctx.y 121 | args = ctx.args 122 | for block, kwargs in zip(ctx.blocks[::-1], args[::-1]): 123 | y, dy = block.backward_pass(y, dy, **kwargs) 124 | return dy, None, None 125 | 126 | class SequentialSequence(nn.Module): 127 | def __init__(self, layers, args_route = {}): 128 | super().__init__() 129 | assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers' 130 | self.layers = layers 131 | self.args_route = args_route 132 | 133 | def forward(self, x, output_attentions = False, **kwargs): 134 | args = 
route_args(self.args_route, kwargs, len(self.layers)) 135 | layers_and_args = list(zip(self.layers, args)) 136 | 137 | if output_attentions: 138 | attn_weights = [] 139 | for (f, g), (f_args, g_args) in layers_and_args: 140 | if output_attentions: 141 | x = x + f(x, output_attentions = output_attentions, **f_args)[0] 142 | attn_weights.append(f(x, output_attentions = output_attentions, **f_args)[1].unsqueeze(0)) 143 | else: 144 | x = x + f(x, **f_args) 145 | x = x + g(x, **g_args) 146 | if output_attentions: 147 | attn_weights = torch.transpose(torch.cat(attn_weights, dim=0), 0, 1) # the final dim is (batch, layer, head, len, len) 148 | attn_weights = torch.mean(attn_weights, dim=1) # the dim is (batch, head, len, len) 149 | return x, attn_weights 150 | else: 151 | return x 152 | 153 | class ReversibleSequence(nn.Module): 154 | def __init__(self, blocks, args_route = {}): 155 | super().__init__() 156 | self.args_route = args_route 157 | self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks]) 158 | 159 | def forward(self, x, **kwargs): 160 | x = torch.cat([x, x], dim=-1) 161 | 162 | blocks = self.blocks 163 | args = route_args(self.args_route, kwargs, len(blocks)) 164 | args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args)) 165 | 166 | out = _ReversibleFunction.apply(x, blocks, args) 167 | return torch.stack(out.chunk(2, dim=-1)).sum(dim=0) 168 | -------------------------------------------------------------------------------- /scBERT/preprocess.py: -------------------------------------------------------------------------------- 1 | import scanpy as sc, numpy as np, pandas as pd, anndata as ad 2 | from scipy import sparse 3 | 4 | panglao = sc.read_h5ad('./data/panglao_10000.h5ad') 5 | data = sc.read_h5ad('./data/raw_data.h5ad') 6 | counts = sparse.lil_matrix((data.X.shape[0],panglao.X.shape[1]),dtype=np.float32) 7 | ref = panglao.var_names.tolist() 8 | obj = data.var_names.tolist() 9 | 10 | for i in range(len(ref)): 11 | if ref[i] in 
obj: 12 | loc = obj.index(ref[i]) 13 | counts[:,i] = data.X[:,loc] 14 | 15 | counts = counts.tocsr() 16 | new = ad.AnnData(X=counts) 17 | new.var_names = ref 18 | new.obs_names = data.obs_names 19 | new.obs = data.obs 20 | new.uns = panglao.uns 21 | 22 | sc.pp.filter_cells(new, min_genes=200) 23 | sc.pp.normalize_total(new, target_sum=1e4) 24 | sc.pp.log1p(new, base=2) 25 | new.write('./data/preprocessed_data.h5ad') 26 | 27 | 28 | -------------------------------------------------------------------------------- /scBERT/scbert_environment.yml: -------------------------------------------------------------------------------- 1 | name: rna-pretrain 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=4.5=1_gnu 9 | - blas=1.0=mkl 10 | - brotlipy=0.7.0=py39h27cfd23_1003 11 | - bzip2=1.0.8=h7b6447c_0 12 | - ca-certificates=2022.3.29=h06a4308_1 13 | - certifi=2021.10.8=py39h06a4308_2 14 | - cffi=1.15.0=py39hd667e15_1 15 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 16 | - cryptography=36.0.0=py39h9ce1e76_0 17 | - cudatoolkit=11.3.1=h2bc3f7f_2 18 | - ffmpeg=4.3=hf484d3e_0 19 | - freetype=2.11.0=h70c0345_0 20 | - giflib=5.2.1=h7b6447c_0 21 | - gmp=6.2.1=h2531618_2 22 | - gnutls=3.6.15=he1e5248_0 23 | - idna=3.3=pyhd3eb1b0_0 24 | - intel-openmp=2021.4.0=h06a4308_3561 25 | - jpeg=9d=h7f8727e_0 26 | - lame=3.100=h7b6447c_0 27 | - lcms2=2.12=h3be6417_0 28 | - ld_impl_linux-64=2.35.1=h7274673_9 29 | - libffi=3.3=he6710b0_2 30 | - libgcc-ng=9.3.0=h5101ec6_17 31 | - libgomp=9.3.0=h5101ec6_17 32 | - libiconv=1.16=h7f8727e_2 33 | - libidn2=2.3.2=h7f8727e_0 34 | - libpng=1.6.37=hbc83047_0 35 | - libstdcxx-ng=9.3.0=hd4cf53a_17 36 | - libtasn1=4.16.0=h27cfd23_0 37 | - libtiff=4.2.0=h85742a9_0 38 | - libunistring=0.9.10=h27cfd23_0 39 | - libuv=1.40.0=h7b6447c_0 40 | - libwebp=1.2.2=h55f646e_0 41 | - libwebp-base=1.2.2=h7f8727e_0 42 | - lz4-c=1.9.3=h295c915_1 43 | - mkl=2021.4.0=h06a4308_640 44 | - 
mkl-service=2.4.0=py39h7f8727e_0 45 | - mkl_fft=1.3.1=py39hd3c417c_0 46 | - mkl_random=1.2.2=py39h51133e4_0 47 | - ncurses=6.3=h7f8727e_2 48 | - nettle=3.7.3=hbbd107a_1 49 | - numpy=1.21.5=py39he7a7128_1 50 | - numpy-base=1.21.5=py39hf524024_1 51 | - openh264=2.1.1=h4ff587b_0 52 | - openssl=1.1.1n=h7f8727e_0 53 | - pillow=9.0.1=py39h22f2fdc_0 54 | - pip=21.2.4=py39h06a4308_0 55 | - pycparser=2.21=pyhd3eb1b0_0 56 | - pyopenssl=22.0.0=pyhd3eb1b0_0 57 | - pysocks=1.7.1=py39h06a4308_0 58 | - python=3.9.12=h12debd9_0 59 | - pytorch=1.11.0=py3.9_cuda11.3_cudnn8.2.0_0 60 | - pytorch-mutex=1.0=cuda 61 | - readline=8.1.2=h7f8727e_1 62 | - requests=2.27.1=pyhd3eb1b0_0 63 | - setuptools=61.2.0=py39h06a4308_0 64 | - six=1.16.0=pyhd3eb1b0_1 65 | - sqlite=3.38.2=hc218d9a_0 66 | - tk=8.6.11=h1ccaba5_0 67 | - torchaudio=0.11.0=py39_cu113 68 | - torchvision=0.12.0=py39_cu113 69 | - typing_extensions=4.1.1=pyh06a4308_0 70 | - tzdata=2022a=hda174b7_0 71 | - urllib3=1.26.9=py39h06a4308_0 72 | - wheel=0.37.1=pyhd3eb1b0_0 73 | - xz=5.2.5=h7b6447c_0 74 | - zlib=1.2.12=h7f8727e_2 75 | - zstd=1.4.9=haebb681_0 76 | - pip: 77 | - absl-py==1.0.0 78 | - aiohttp==3.8.3 79 | - aiosignal==1.2.0 80 | - altair==4.2.0 81 | - anndata==0.8.0 82 | - asttokens==2.0.5 83 | - async-timeout==4.0.2 84 | - attrs==22.1.0 85 | - axial-positional-embedding==0.2.1 86 | - backcall==0.2.0 87 | - cachetools==5.0.0 88 | - chex==0.1.5 89 | - click==8.1.2 90 | - colorama==0.4.5 91 | - commonmark==0.9.1 92 | - cycler==0.11.0 93 | - debugpy==1.6.2 94 | - decorator==5.1.1 95 | - deepspeed==0.7.3 96 | - dm-tree==0.1.7 97 | - docrep==0.3.2 98 | - einops==0.4.1 99 | - entrypoints==0.4 100 | - et-xmlfile==1.1.0 101 | - etils==0.8.0 102 | - executing==0.8.3 103 | - filelock==3.6.0 104 | - flax==0.6.0 105 | - fonttools==4.33.3 106 | - frozenlist==1.3.1 107 | - fsspec==2022.8.2 108 | - future==0.18.2 109 | - google-auth==2.6.6 110 | - google-auth-oauthlib==0.4.6 111 | - grpcio==1.44.0 112 | - h5py==3.6.0 113 | - hjson==3.1.0 
114 | - huggingface-hub==0.5.1 115 | - hyperopt==0.1.2 116 | - importlib-metadata==4.11.3 117 | - importlib-resources==5.9.0 118 | - ipdb==0.13.9 119 | - ipykernel==6.15.1 120 | - ipython==8.4.0 121 | - ipywidgets==8.0.2 122 | - jax==0.3.21 123 | - jaxlib==0.3.20 124 | - jedi==0.18.1 125 | - jinja2==3.1.2 126 | - joblib==1.1.0 127 | - jsonschema==4.16.0 128 | - jupyter-client==7.3.4 129 | - jupyter-core==4.11.1 130 | - jupyterlab-widgets==3.0.3 131 | - kiwisolver==1.4.2 132 | - llvmlite==0.38.0 133 | - local-attention==1.4.3 134 | - markdown==3.3.6 135 | - markupsafe==2.1.1 136 | - matplotlib==3.5.1 137 | - matplotlib-inline==0.1.3 138 | - msgpack==1.0.4 139 | - mudata==0.2.0 140 | - multidict==6.0.2 141 | - multipledispatch==0.6.0 142 | - natsort==8.1.0 143 | - nest-asyncio==1.5.5 144 | - networkx==2.8 145 | - ninja==1.10.2.3 146 | - numba==0.55.1 147 | - numpyro==0.10.1 148 | - oauthlib==3.2.0 149 | - openpyxl==3.0.10 150 | - opt-einsum==3.3.0 151 | - optax==0.1.3 152 | - packaging==21.3 153 | - pandas==1.4.2 154 | - parso==0.8.3 155 | - patsy==0.5.2 156 | - performer-pytorch==1.1.4 157 | - pexpect==4.8.0 158 | - pickleshare==0.7.5 159 | - prompt-toolkit==3.0.30 160 | - protobuf==3.19.6 161 | - psutil==5.9.1 162 | - ptyprocess==0.7.0 163 | - pure-eval==0.2.2 164 | - py-cpuinfo==8.0.0 165 | - pyasn1==0.4.8 166 | - pyasn1-modules==0.2.8 167 | - pydantic==1.10.2 168 | - pydeprecate==0.3.2 169 | - pygments==2.12.0 170 | - pymongo==4.2.0 171 | - pynndescent==0.5.6 172 | - pyparsing==3.0.8 173 | - pyro-api==0.1.2 174 | - pyro-ppl==1.8.2 175 | - pyrsistent==0.18.1 176 | - python-dateutil==2.8.2 177 | - pytorch-lightning==1.7.7 178 | - pytz==2022.1 179 | - pyyaml==6.0 180 | - pyzmq==23.2.0 181 | - regex==2022.4.24 182 | - requests-oauthlib==1.3.1 183 | - rich==11.2.0 184 | - rsa==4.8 185 | - sacremoses==0.0.49 186 | - scanpy==1.9.1 187 | - scikit-learn==1.0.2 188 | - scikit-misc==0.1.4 189 | - scipy==1.8.0 190 | - scvi-tools==0.17.4 191 | - seaborn==0.11.2 192 | - 
session-info==1.0.0 193 | - sklearn==0.0 194 | - stack-data==0.3.0 195 | - statsmodels==0.13.2 196 | - stdlib-list==0.8.0 197 | - tensorboard==2.10.1 198 | - tensorboard-data-server==0.6.1 199 | - tensorboard-plugin-wit==1.8.1 200 | - threadpoolctl==3.1.0 201 | - tokenizers==0.12.1 202 | - toml==0.10.2 203 | - toolz==0.12.0 204 | - torchmetrics==0.9.3 205 | - tornado==6.2 206 | - tqdm==4.64.0 207 | - traitlets==5.3.0 208 | - transformers==4.18.0 209 | - umap-learn==0.5.3 210 | - wcwidth==0.2.5 211 | - werkzeug==2.1.1 212 | - widgetsnbextension==4.0.3 213 | - xlrd==2.0.1 214 | - yarl==1.8.1 215 | - zipp==3.8.0 216 | prefix: /opt/conda/rpeyser/envs/rna-pretrain 217 | -------------------------------------------------------------------------------- /scBERT/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import print_function 4 | import json 5 | import os 6 | import struct 7 | import sys 8 | import platform 9 | import re 10 | import time 11 | import traceback 12 | import requests 13 | import socket 14 | import random 15 | import math 16 | import numpy as np 17 | import torch 18 | import logging 19 | import datetime 20 | from torch.optim.lr_scheduler import _LRScheduler 21 | from torch import nn 22 | import torch.nn.functional as F 23 | from torch.nn.modules.loss import _WeightedLoss 24 | 25 | 26 | 27 | def seed_all(seed_value, cuda_deterministic=False): 28 | """ 29 | Set all random seeds 30 | """ 31 | random.seed(seed_value) 32 | os.environ['PYTHONHASHSEED'] = str(seed_value) 33 | np.random.seed(seed_value) 34 | torch.manual_seed(seed_value) 35 | if torch.cuda.is_available(): 36 | torch.cuda.manual_seed(seed_value) 37 | torch.cuda.manual_seed_all(seed_value) 38 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html 39 | if cuda_deterministic: # slower, more reproducible 40 | torch.backends.cudnn.deterministic = True 41 | torch.backends.cudnn.benchmark = 
False 42 | else: # faster, less reproducible 43 | torch.backends.cudnn.deterministic = False 44 | torch.backends.cudnn.benchmark = True 45 | 46 | 47 | def set_log(logfileName, rank=-1): 48 | """ 49 | The master node saves all logs; other nodes save only warnings and errors 50 | """ 51 | log_file_folder = os.path.dirname(logfileName) 52 | time_now = datetime.datetime.now() 53 | logfileName = f'{logfileName}_{time_now.year}_{time_now.month}_{time_now.day}_{time_now.hour}_{time_now.minute}.log' 54 | if not os.path.exists(log_file_folder): 55 | os.makedirs(log_file_folder) 56 | else: 57 | pass 58 | 59 | logging.basicConfig(level=logging.INFO if rank in [-1, 0] else logging.WARN, 60 | format='[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s', 61 | datefmt='[%X]', 62 | handlers=[logging.FileHandler(logfileName), logging.StreamHandler()] 63 | ) 64 | logger = logging.getLogger() 65 | return logger 66 | 67 | 68 | def save_ckpt(epoch, model, optimizer, scheduler, losses, model_name, ckpt_folder): 69 | """ 70 | Save a model checkpoint 71 | """ 72 | if not os.path.exists(ckpt_folder): 73 | os.makedirs(ckpt_folder) 74 | torch.save( 75 | { 76 | 'epoch': epoch, 77 | 'model_state_dict': model.state_dict(), 78 | 'optimizer_state_dict': optimizer.state_dict(), 79 | 'scheduler_state_dict': scheduler.state_dict(), 80 | 'losses': losses, 81 | }, 82 | f'{ckpt_folder}/{model_name}_epoch_{epoch}.pth' 83 | ) 84 | 85 | def save_simple_ckpt(model, model_name, ckpt_folder): 86 | """ 87 | Save a model checkpoint 88 | """ 89 | if not os.path.exists(ckpt_folder): 90 | os.makedirs(ckpt_folder) 91 | torch.save( 92 | { 93 | 'model_state_dict': model.module.state_dict() 94 | }, 95 | f'{ckpt_folder}{model_name}.pth' 96 | ) 97 | 98 | def save_best_ckpt(epoch, model, optimizer, scheduler, losses, model_name, ckpt_folder): 99 | """ 100 | Save a model checkpoint 101 | """ 102 | if not os.path.exists(ckpt_folder): 103 | os.makedirs(ckpt_folder) 104 | torch.save( 105 | { 106 | 'epoch': epoch, 107 | 'model_state_dict': 
model.module.state_dict(), 108 | 'optimizer_state_dict': optimizer.state_dict(), 109 | 'scheduler_state_dict': scheduler.state_dict(), 110 | 'losses': losses, 111 | }, 112 | f'{ckpt_folder}/{model_name}_best.pth' 113 | ) 114 | 115 | def get_reduced(tensor, current_device, dest_device, world_size): 116 | """ 117 | Gather a variable or tensor from the different GPUs onto the main GPU and return its mean 118 | """ 119 | tensor = tensor.clone().detach() if torch.is_tensor(tensor) else torch.tensor(tensor) 120 | tensor = tensor.to(current_device) 121 | torch.distributed.reduce(tensor, dst=dest_device) 122 | tensor_mean = tensor.item() / world_size 123 | return tensor_mean 124 | 125 | def get_ndtensor_reduced(tensor, current_device, dest_device, world_size): 126 | """ 127 | Gather a variable or tensor from the different GPUs onto the main GPU and return its mean; expects a 2-D tensor (1-D is also handled) 128 | """ 129 | tensor = tensor.clone().detach() if torch.is_tensor(tensor) else torch.tensor(tensor) 130 | tensor = tensor.to(current_device) 131 | torch.distributed.reduce(tensor, dst=dest_device) 132 | tensor_mean = torch.zeros(tensor.shape) 133 | if len(tensor.shape) == 2: 134 | for i in range(tensor.shape[0]): 135 | for j in range(tensor.shape[1]): 136 | tensor_mean[i,j] = tensor[i,j].item() / world_size 137 | elif len(tensor.shape) == 1: 138 | for i in range(tensor.shape[0]): 139 | tensor_mean[i] = tensor[i].item() / world_size 140 | return tensor_mean 141 | 142 | def numel(m: torch.nn.Module, only_trainable: bool = False): 143 | """ 144 | returns the total number of parameters used by `m` (only counting 145 | shared parameters once); if `only_trainable` is True, then only 146 | includes parameters with `requires_grad = True` 147 | """ 148 | parameters = m.parameters() 149 | if only_trainable: 150 | parameters = list(p for p in parameters if p.requires_grad) 151 | unique = dict((p.data_ptr(), p) for p in parameters).values() 152 | return sum(p.numel() for p in unique) 153 | 154 | 155 | def label_smooth(y, K, epsilon=0.1): 156 | """ 157 | Label smoothing for multiclass labels 158 | One hot encode labels `y` over 
`K` classes. `y` should be of the form [1, 6, 3, etc.] 159 | """ 160 | m = len(y) 161 | out = np.ones((m, K)) * epsilon / K 162 | for index in range(m): 163 | out[index][y[index] - 1] += 1 - epsilon 164 | return torch.tensor(out) 165 | 166 | 167 | class SequentialDistributedSampler(torch.utils.data.sampler.Sampler): 168 | """ 169 | Distributed Sampler that subsamples indices sequentially, 170 | making it easier to collate all results at the end. 171 | Even though we only use this sampler for eval and predict (no training), 172 | which means that the model params won't have to be synced (i.e. will not hang 173 | for synchronization even if varied number of forward passes), we still add extra 174 | samples to the sampler to make it evenly divisible (like in `DistributedSampler`) 175 | to make it easy to `gather` or `reduce` resulting tensors at the end of the loop. 176 | """ 177 | 178 | def __init__(self, dataset, batch_size, world_size, rank=None, num_replicas=None): 179 | if num_replicas is None: 180 | if not torch.distributed.is_available(): 181 | raise RuntimeError("Requires distributed package to be available") 182 | num_replicas = world_size 183 | if rank is None: 184 | if not torch.distributed.is_available(): 185 | raise RuntimeError("Requires distributed package to be available") 186 | rank = torch.distributed.get_rank() 187 | self.dataset = dataset 188 | self.num_replicas = num_replicas 189 | self.rank = rank 190 | self.batch_size = batch_size 191 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.batch_size / self.num_replicas)) * self.batch_size 192 | self.total_size = self.num_samples * self.num_replicas 193 | 194 | def __iter__(self): 195 | indices = list(range(len(self.dataset))) 196 | # add extra samples to make it evenly divisible 197 | indices += [indices[-1]] * (self.total_size - len(indices)) 198 | # subsample 199 | indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples] 200 | return iter(indices) 201 | 
202 | def __len__(self): 203 | return self.num_samples 204 | 205 | 206 | def distributed_concat(tensor, num_total_examples, world_size): 207 | """ 208 | Merge the inference results from the different processes 209 | """ 210 | output_tensors = [tensor.clone() for _ in range(world_size)] 211 | torch.distributed.all_gather(output_tensors, tensor) 212 | concat = torch.cat(output_tensors, dim=0) 213 | # truncate the dummy elements added by SequentialDistributedSampler 214 | return concat[:num_total_examples] 215 | 216 | 217 | class CosineAnnealingWarmupRestarts(_LRScheduler): 218 | """ 219 | optimizer (Optimizer): Wrapped optimizer. 220 | first_cycle_steps (int): First cycle step size. 221 | cycle_mult(float): Cycle steps magnification. Default: 1. 222 | max_lr(float): First cycle's max learning rate. Default: 0.1. 223 | min_lr(float): Min learning rate. Default: 0.001. 224 | warmup_steps(int): Linear warmup step size. Default: 0. 225 | gamma(float): Decrease rate of max learning rate by cycle. Default: 1. 226 | last_epoch (int): The index of last epoch. Default: -1. 
227 | """ 228 | 229 | def __init__(self, 230 | optimizer : torch.optim.Optimizer, 231 | first_cycle_steps : int, 232 | cycle_mult : float = 1., 233 | max_lr : float = 0.1, 234 | min_lr : float = 0.001, 235 | warmup_steps : int = 0, 236 | gamma : float = 1., 237 | last_epoch : int = -1 238 | ): 239 | assert warmup_steps < first_cycle_steps 240 | 241 | self.first_cycle_steps = first_cycle_steps # first cycle step size 242 | self.cycle_mult = cycle_mult # cycle steps magnification 243 | self.base_max_lr = max_lr # first max learning rate 244 | self.max_lr = max_lr # max learning rate in the current cycle 245 | self.min_lr = min_lr # min learning rate 246 | self.warmup_steps = warmup_steps # warmup step size 247 | self.gamma = gamma # decrease rate of max learning rate by cycle 248 | 249 | self.cur_cycle_steps = first_cycle_steps # first cycle step size 250 | self.cycle = 0 # cycle count 251 | self.step_in_cycle = last_epoch # step size of the current cycle 252 | 253 | super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch) 254 | 255 | # set learning rate min_lr 256 | self.init_lr() 257 | 258 | def init_lr(self): 259 | self.base_lrs = [] 260 | for param_group in self.optimizer.param_groups: 261 | param_group['lr'] = self.min_lr 262 | self.base_lrs.append(self.min_lr) 263 | 264 | def get_lr(self): 265 | if self.step_in_cycle == -1: 266 | return self.base_lrs 267 | elif self.step_in_cycle < self.warmup_steps: 268 | return [(self.max_lr - base_lr)*self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs] 269 | else: 270 | return [base_lr + (self.max_lr - base_lr) \ 271 | * (1 + math.cos(math.pi * (self.step_in_cycle-self.warmup_steps) \ 272 | / (self.cur_cycle_steps - self.warmup_steps))) / 2 273 | for base_lr in self.base_lrs] 274 | 275 | def step(self, epoch=None): 276 | if epoch is None: 277 | epoch = self.last_epoch + 1 278 | self.step_in_cycle = self.step_in_cycle + 1 279 | if self.step_in_cycle >= self.cur_cycle_steps: 280 | 
self.cycle += 1 281 | self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps 282 | self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps 283 | else: 284 | if epoch >= self.first_cycle_steps: 285 | if self.cycle_mult == 1.: 286 | self.step_in_cycle = epoch % self.first_cycle_steps 287 | self.cycle = epoch // self.first_cycle_steps 288 | else: 289 | n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult)) 290 | self.cycle = n 291 | self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1)) 292 | self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n) 293 | else: 294 | self.cur_cycle_steps = self.first_cycle_steps 295 | self.step_in_cycle = epoch 296 | 297 | self.max_lr = self.base_max_lr * (self.gamma**self.cycle) 298 | self.last_epoch = math.floor(epoch) 299 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 300 | param_group['lr'] = lr 301 | 302 | 303 | class DistanceLoss(_WeightedLoss): 304 | """ 305 | CrossEntropyLoss with Distance Weighted 306 | """ 307 | def __init__(self, weight=None, reduction='mean', ignore_index = None): 308 | super().__init__(weight=weight, reduction=reduction) 309 | self.weight = weight 310 | self.reduction = reduction 311 | self.ignore_index = ignore_index 312 | def forward(self, inputs, targets): 313 | if len(inputs.shape) > 2: 314 | inputs = inputs.reshape(-1, inputs.size(-1)) 315 | if len(targets.shape) > 1: 316 | targets = targets.reshape(-1) 317 | if self.ignore_index is not None: 318 | keep_index = (targets != self.ignore_index).nonzero(as_tuple=True)[0] 319 | targets = torch.index_select(targets, 0, keep_index) #targets[targets != self.ignore_index] 320 | inputs = torch.index_select(inputs, 0, keep_index) 321 | lsm = F.log_softmax(inputs, -1) 322 | targets = torch.empty(size=(targets.size(0), inputs.size(-1)), 
device=targets.device).fill_(0).scatter_(1, targets.data.unsqueeze(1), 1) 323 | if self.weight is not None: 324 | lsm = lsm * self.weight.unsqueeze(0) 325 | loss = -(targets * lsm).sum(-1) 326 | inputs = nn.Softmax(dim=-1)(inputs)[..., 1:-1].argmax(dim=-1) + 1 327 | # print('inputs', inputs.device, inputs.shape) 328 | targets = nn.Softmax(dim=-1)(targets)[..., 1:-1].argmax(dim=-1) + 1 329 | # print('targets', targets.device, targets.shape) 330 | distance = abs(inputs - targets) + 1e-2 331 | # print('loss.shape', loss.shape) 332 | # print('distance.shape', distance.shape) 333 | loss = loss * distance 334 | if self.reduction == 'sum': 335 | loss = loss.sum() 336 | elif self.reduction == 'mean': 337 | loss = loss.mean() 338 | return loss 339 | 340 | 341 | class LabelSmoothCrossEntropyLoss(_WeightedLoss): 342 | """ 343 | CrossEntropyLoss with Label Smoothing 344 | """ 345 | def __init__(self, weight=None, reduction='mean', smoothing=0.0): 346 | super().__init__(weight=weight, reduction=reduction) 347 | self.smoothing = smoothing 348 | self.weight = weight 349 | self.reduction = reduction 350 | 351 | @staticmethod 352 | def _smooth_one_hot(targets: torch.Tensor, n_classes: int, smoothing=0.0): 353 | assert 0 <= smoothing < 1 354 | with torch.no_grad(): 355 | targets = torch.empty(size=(targets.size(0), n_classes), 356 | device=targets.device) \ 357 | .fill_(smoothing / (n_classes - 1)) \ 358 | .scatter_(1, targets.data.unsqueeze(1), 1. 
- smoothing) 359 | return targets 360 | 361 | def forward(self, inputs, targets): 362 | targets = LabelSmoothCrossEntropyLoss._smooth_one_hot(targets, inputs.size(-1), 363 | self.smoothing) 364 | lsm = F.log_softmax(inputs, -1) 365 | 366 | if self.weight is not None: 367 | lsm = lsm * self.weight.unsqueeze(0) 368 | 369 | loss = -(targets * lsm).sum(-1) 370 | 371 | if self.reduction == 'sum': 372 | loss = loss.sum() 373 | elif self.reduction == 'mean': 374 | loss = loss.mean() 375 | 376 | return loss 377 | -------------------------------------------------------------------------------- /scGPT/scGPT_baselines_LR.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import argparse 4 | import json 5 | import random 6 | import math 7 | import random 8 | from functools import reduce 9 | import numpy as np 10 | import pandas as pd 11 | from scipy import sparse 12 | from sklearn.model_selection import train_test_split, cross_validate 13 | from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, recall_score, precision_score 14 | from tqdm import tqdm 15 | 16 | import scanpy as sc 17 | import anndata as ad 18 | from datetime import datetime 19 | from time import time 20 | import matplotlib.pyplot as plt 21 | from sklearn.linear_model import LogisticRegression 22 | import seaborn as sns 23 | from pathlib import Path 24 | 25 | # Define the file path where you want to store the results 26 | res_file_path = 'LR_sampleeff_results.txt' 27 | with open(res_file_path, 'a') as file: 28 | file.write("dataset_name\tfraction\tseed\tc\taccuracy\tprecision\trecall\tmacro_f1\n") 29 | 30 | now = datetime.now() 31 | 32 | for DATASET_NAME in ['ms', 'pancreas', 'myeloid']: 33 | print(DATASET_NAME) 34 | ## Step 2: Load and pre-process data 35 | 36 | if DATASET_NAME == "ms": 37 | data_dir = Path("/localdata/rna_rep_learning/scGPT/ms") #RB 38 | adata = sc.read(data_dir / "c_data.h5ad") 39 | adata_test = 
sc.read(data_dir / "filtered_ms_adata.h5ad") 40 | adata.obs["celltype"] = adata.obs["Factor Value[inferred cell type - authors labels]"].astype("category") 41 | adata_test.obs["celltype"] = adata_test.obs["Factor Value[inferred cell type - authors labels]"].astype("category") 42 | adata.obs["batch_id"] = adata.obs["str_batch"] = "0" 43 | adata_test.obs["batch_id"] = adata_test.obs["str_batch"] = "1" 44 | adata.var.set_index(adata.var["gene_name"], inplace=True) 45 | adata_test.var.set_index(adata.var["gene_name"], inplace=True) 46 | data_is_raw = False 47 | filter_gene_by_counts = False 48 | adata_test_raw = adata_test.copy() 49 | adata = adata.concatenate(adata_test, batch_key="str_batch") 50 | 51 | if DATASET_NAME == "pancreas": 52 | data_dir = Path("/localdata/rna_rep_learning/scGPT/pancreas") 53 | adata = sc.read(data_dir / "demo_train.h5ad") 54 | adata_test = sc.read(data_dir / "demo_test.h5ad") 55 | adata.obs["celltype"] = adata.obs["Celltype"].astype("category") 56 | adata_test.obs["celltype"] = adata_test.obs["Celltype"].astype("category") 57 | adata.obs["batch_id"] = adata.obs["str_batch"] = "0" 58 | adata_test.obs["batch_id"] = adata_test.obs["str_batch"] = "1" 59 | data_is_raw = False 60 | filter_gene_by_counts = False 61 | adata_test_raw = adata_test.copy() 62 | adata = adata.concatenate(adata_test, batch_key="str_batch") 63 | 64 | if DATASET_NAME == "myeloid": 65 | data_dir = Path("/localdata/rna_rep_learning/scGPT/myeloid/") 66 | adata = sc.read(data_dir / "reference_adata.h5ad") 67 | adata_test = sc.read(data_dir / "query_adata.h5ad") 68 | adata.obs["celltype"] = adata.obs["cell_type"].astype("category") 69 | adata_test.obs["celltype"] = adata_test.obs["cell_type"].astype("category") 70 | adata.obs["batch_id"] = adata.obs["str_batch"] = "0" 71 | adata_test.obs["batch_id"] = adata_test.obs["str_batch"] = "1" 72 | adata_test_raw = adata_test.copy() 73 | data_is_raw = False 74 | filter_gene_by_counts = False 75 | adata = adata.concatenate(adata_test, 
batch_key="str_batch") 76 | 77 | # make the batch category column 78 | batch_id_labels = adata.obs["str_batch"].astype("category").cat.codes.values 79 | adata.obs["batch_id"] = batch_id_labels 80 | celltype_id_labels = adata.obs["celltype"].astype("category").cat.codes.values 81 | celltypes = adata.obs["celltype"].unique() 82 | num_types = len(np.unique(celltype_id_labels)) 83 | id2type = dict(enumerate(adata.obs["celltype"].astype("category").cat.categories)) 84 | adata.obs["celltype_id"] = celltype_id_labels 85 | adata.var["gene_name"] = adata.var.index.tolist() 86 | 87 | adata_test = adata[adata.obs["str_batch"] == "1"] 88 | adata = adata[adata.obs["str_batch"] == "0"] 89 | 90 | all_counts_full = ( 91 | adata.X.A 92 | if sparse.issparse(adata.X) 93 | else adata.X 94 | ) 95 | all_counts_test = ( 96 | adata_test.X.A 97 | if sparse.issparse(adata_test.X) 98 | else adata_test.X 99 | ) 100 | genes = adata.var["gene_name"].tolist() 101 | 102 | celltypes_labels_full = adata.obs["celltype_id"].tolist() # make sure count from 0 103 | celltypes_labels_full = np.array(celltypes_labels_full) 104 | celltypes_labels_test = adata_test.obs["celltype_id"].tolist() # make sure count from 0 105 | celltypes_labels_test = np.array(celltypes_labels_test) 106 | 107 | #batch_ids = adata.obs["batch_id"].tolist() 108 | #num_batch_types = len(set(batch_ids)) 109 | #batch_ids = np.array(batch_ids) 110 | 111 | for FRAC in [0.1, 0.25, 0.5, 0.75, 1]: 112 | for SEED in [1,2,3,4,5,6,7,8,9,10]: 113 | print("fraction: ", FRAC) 114 | print("SEED: ", SEED) 115 | save_dir = Path(f"./save/logisticregression/lr-{DATASET_NAME}-frac{FRAC}-seed{SEED}-{now}/") 116 | save_dir.mkdir(parents=True, exist_ok=True) 117 | print(f"save to {save_dir}") 118 | 119 | ## optionally subset train data for few shot experiments - RB 120 | if FRAC != 1: 121 | print("subsetting to {}% training data".format(FRAC*100)) 122 | all_counts, _, celltypes_labels, _, = train_test_split(all_counts_full, celltypes_labels_full, 
train_size=FRAC, random_state=SEED, shuffle=True, stratify=celltypes_labels_full) 123 | else: 124 | all_counts = all_counts_full.copy() 125 | celltypes_labels = celltypes_labels_full.copy() 126 | 127 | ## choose c using k-fold cross-validation (cross_validate defaults to 5 folds) 128 | #if SEED == 1: # can do this just once per fraction (share c across seeds) 129 | print("running cross validation to choose c...") 130 | cv_results = {} 131 | for c in [1e-3, 1e-2, 1e-1, 1, 10, 100, 1000, 10000]: 132 | lr = LogisticRegression(random_state=0, penalty="l1", C=c, solver="liblinear") 133 | res = cross_validate(lr, all_counts, celltypes_labels, scoring=['accuracy']) 134 | cv_results[c] = np.mean(res['test_accuracy']) 135 | 136 | # choose the c with the highest mean cross-validation accuracy 137 | best_ind = np.argmax(list(cv_results.values())) 138 | c = list(cv_results.keys())[best_ind] 139 | print(f'for {FRAC*100}% of {DATASET_NAME} data with seed {SEED}, best c is {c}') 140 | 141 | ## run LR 142 | lr = LogisticRegression(penalty="l1", C=c, solver="liblinear", random_state=SEED) 143 | lr.fit(all_counts, celltypes_labels) 144 | 145 | test_acc = lr.score(all_counts_test, celltypes_labels_test) 146 | print("test set accuracy: " + str(np.around(test_acc, 4))) 147 | 148 | test_recall = recall_score(celltypes_labels_test, lr.predict(all_counts_test), average="macro") # Based on their GitHub code, they appear to have used macro averaging (the paper does not state this, but it is consistent with their results) 149 | print("test set recall: " + str(np.around(test_recall, 4))) 150 | 151 | test_precision = precision_score(celltypes_labels_test, lr.predict(all_counts_test), average="macro") # Based on their GitHub code, they appear to have used macro averaging (the paper does not state this, but it is consistent with their results) 152 | print("test set precision: " + str(np.around(test_precision, 4))) 153 | 154 | test_macro_f1 = f1_score(celltypes_labels_test, lr.predict(all_counts_test), average="macro") 155 | print("test set macro F1: " + str(np.around(test_macro_f1, 4))) 156 | 157 | ## plot confusion matrix 158 | 
predictions = lr.predict(all_counts_test) 159 | # label ids seen in the test labels or the predictions, sorted (the order confusion_matrix uses) 160 | cm_ids = np.union1d(celltypes_labels_test, predictions) 161 | # name rows/columns via id2type so they follow category-code order, not order of appearance 162 | cm_names = [id2type[i] for i in cm_ids] 163 | cm = confusion_matrix(celltypes_labels_test, predictions, labels=cm_ids) 164 | # row-normalize so each row holds fractions of the true class 165 | cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] 166 | cm = pd.DataFrame(cm, index=cm_names, columns=cm_names) 167 | plt.figure(figsize=(10, 10)) 168 | sns.heatmap(cm, annot=True, fmt=".1f", cmap="Blues") 169 | plt.savefig(save_dir / "confusion_matrix.png", dpi=300, bbox_inches="tight") 170 | plt.close() # free the figure; one is created per dataset/fraction/seed 171 | 172 | ## write results to file 173 | with open(res_file_path, 'a') as file: 174 | file.write(f"{DATASET_NAME}\t{FRAC}\t{SEED}\t{c}\t{test_acc}\t{test_precision}\t{test_recall}\t{test_macro_f1}\n") 175 | 176 | print("\n") 177 | -------------------------------------------------------------------------------- /scGPT/scgpt_environment.yml: -------------------------------------------------------------------------------- 1 | name: scgpt 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=conda_forge 7 | - _openmp_mutex=4.5=2_gnu 8 | - _r-mutex=1.0.1=anacondar_1 9 | - binutils_impl_linux-64=2.40=hf600244_0 10 | - bwidget=1.9.14=ha770c72_1 11 | - bzip2=1.0.8=h7f98852_4 12 | - c-ares=1.19.1=hd590300_0 13 | - ca-certificates=2023.7.22=hbcca054_0 14 | - cairo=1.16.0=h0c91306_1017 15 | - cuda-version=11.7=h67201e3_2 16 | - cudatoolkit=11.7.1=h4bc3d14_12 17 | - cudatoolkit-dev=11.7.0=h1de0b5d_6 18 | - cudnn=8.8.0.121=h838ba91_3 19 | - curl=8.3.0=hca28451_0 20 | - expat=2.5.0=hcb278e6_1 21 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 22 | - font-ttf-inconsolata=3.000=h77eed37_0 23 | - font-ttf-source-code-pro=2.038=h77eed37_0 24 | - font-ttf-ubuntu=0.83=hab24e00_0 25 | - fontconfig=2.14.2=h14ed4e7_0 26 | - fonts-conda-ecosystem=1=0 27 | - fonts-conda-forge=1=0 28 | - 
freetype=2.12.1=h267a509_2 29 | - fribidi=1.0.10=h36c2ea0_0 30 | - gcc=11.4.0=h7baecda_2 31 | - gcc_impl_linux-64=11.4.0=h7aa1c59_2 32 | - gettext=0.21.1=h27087fc_0 33 | - gfortran_impl_linux-64=11.4.0=h86428dc_2 34 | - gmp=6.2.1=h58526e2_0 35 | - graphite2=1.3.13=h58526e2_1001 36 | - gxx=11.4.0=h7baecda_2 37 | - gxx_impl_linux-64=11.4.0=h7aa1c59_2 38 | - harfbuzz=8.2.1=h3d44ed6_0 39 | - icu=73.2=h59595ed_0 40 | - kernel-headers_linux-64=2.6.32=he073ed8_16 41 | - keyutils=1.6.1=h166bdaf_0 42 | - krb5=1.21.2=h659d440_0 43 | - ld_impl_linux-64=2.40=h41732ed_0 44 | - lerc=4.0.0=h27087fc_0 45 | - libblas=3.9.0=18_linux64_openblas 46 | - libcurl=8.3.0=hca28451_0 47 | - libdeflate=1.19=hd590300_0 48 | - libedit=3.1.20191231=he28a2e2_2 49 | - libev=4.33=h516909a_1 50 | - libexpat=2.5.0=hcb278e6_1 51 | - libffi=3.4.2=h7f98852_5 52 | - libgcc-devel_linux-64=11.4.0=h922705a_2 53 | - libgcc-ng=13.2.0=h807b86a_2 54 | - libgfortran-ng=13.2.0=h69a702a_2 55 | - libgfortran5=13.2.0=ha4646dd_2 56 | - libgit2=1.7.1=hca3a8ce_0 57 | - libglib=2.78.0=hebfc3b9_0 58 | - libgomp=13.2.0=h807b86a_2 59 | - libiconv=1.17=h166bdaf_0 60 | - libjpeg-turbo=2.1.5.1=hd590300_1 61 | - liblapack=3.9.0=18_linux64_openblas 62 | - libnghttp2=1.52.0=h61bc06f_0 63 | - libnsl=2.0.0=h7f98852_0 64 | - libopenblas=0.3.24=pthreads_h413a1c8_0 65 | - libpng=1.6.39=h753d276_0 66 | - libsanitizer=11.4.0=h4dcbe23_2 67 | - libsqlite=3.43.0=h2797004_0 68 | - libssh2=1.11.0=h0841786_0 69 | - libstdcxx-devel_linux-64=11.4.0=h922705a_2 70 | - libstdcxx-ng=13.2.0=h7e041cc_2 71 | - libtiff=4.6.0=h29866fb_1 72 | - libuuid=2.38.1=h0b41bf4_0 73 | - libwebp-base=1.3.2=hd590300_0 74 | - libxcb=1.15=h0b41bf4_0 75 | - libxml2=2.11.5=h232c23b_1 76 | - libzlib=1.2.13=hd590300_5 77 | - make=4.3=hd18ef5c_1 78 | - ncurses=6.4=hcb278e6_0 79 | - openssl=3.1.3=hd590300_0 80 | - pandoc=3.1.3=h32600fe_0 81 | - pango=1.50.14=ha41ecd1_2 82 | - pcre2=10.40=hc3806b6_0 83 | - pip=23.2.1=pyhd8ed1ab_0 84 | - pixman=0.40.0=h36c2ea0_0 85 | - 
pthread-stubs=0.4=h36c2ea0_1001 86 | - python=3.10.11=he550d4f_0_cpython 87 | - r-askpass=1.2.0=r43h57805ef_0 88 | - r-assertthat=0.2.1=r43hc72bb7e_4 89 | - r-base=4.3.1=h639d9d3_5 90 | - r-base64enc=0.1_3=r43h57805ef_1006 91 | - r-brew=1.0_8=r43hc72bb7e_2 92 | - r-brio=1.1.3=r43h57805ef_2 93 | - r-bslib=0.5.1=r43hc72bb7e_0 94 | - r-cachem=1.0.8=r43h57805ef_1 95 | - r-callr=3.7.3=r43hc72bb7e_1 96 | - r-cli=3.6.1=r43ha503ecb_1 97 | - r-clipr=0.8.0=r43hc72bb7e_2 98 | - r-commonmark=1.9.0=r43h57805ef_1 99 | - r-cpp11=0.4.6=r43hc72bb7e_0 100 | - r-crayon=1.5.2=r43hc72bb7e_2 101 | - r-credentials=2.0.1=r43hc72bb7e_0 102 | - r-curl=5.0.2=r43hf9611b0_0 103 | - r-desc=1.4.2=r43hc72bb7e_2 104 | - r-devtools=2.4.5=r43hc72bb7e_2 105 | - r-diffobj=0.3.5=r43h57805ef_2 106 | - r-digest=0.6.33=r43ha503ecb_0 107 | - r-downlit=0.4.3=r43hc72bb7e_0 108 | - r-ellipsis=0.3.2=r43h57805ef_2 109 | - r-evaluate=0.21=r43hc72bb7e_1 110 | - r-fansi=1.0.4=r43h57805ef_1 111 | - r-fastmap=1.1.1=r43ha503ecb_1 112 | - r-fontawesome=0.5.2=r43hc72bb7e_0 113 | - r-fs=1.6.3=r43ha503ecb_0 114 | - r-gert=1.9.3=r43hc25a090_1 115 | - r-gh=1.4.0=r43hc72bb7e_1 116 | - r-gitcreds=0.1.2=r43hc72bb7e_2 117 | - r-glue=1.6.2=r43h57805ef_2 118 | - r-highr=0.10=r43hc72bb7e_1 119 | - r-htmltools=0.5.6=r43ha503ecb_0 120 | - r-htmlwidgets=1.6.2=r43hc72bb7e_1 121 | - r-httpuv=1.6.11=r43ha503ecb_1 122 | - r-httr=1.4.7=r43hc72bb7e_0 123 | - r-httr2=0.2.3=r43hc72bb7e_1 124 | - r-ini=0.3.1=r43hc72bb7e_1005 125 | - r-jquerylib=0.1.4=r43hc72bb7e_2 126 | - r-jsonlite=1.8.7=r43h57805ef_0 127 | - r-knitr=1.44=r43hc72bb7e_0 128 | - r-later=1.3.1=r43ha503ecb_1 129 | - r-lifecycle=1.0.3=r43hc72bb7e_2 130 | - r-magrittr=2.0.3=r43h57805ef_2 131 | - r-memoise=2.0.1=r43hc72bb7e_2 132 | - r-mime=0.12=r43h57805ef_2 133 | - r-miniui=0.1.1.1=r43hc72bb7e_1004 134 | - r-openssl=2.1.1=r43hb353fa6_0 135 | - r-pillar=1.9.0=r43hc72bb7e_1 136 | - r-pkgbuild=1.4.2=r43hc72bb7e_0 137 | - r-pkgconfig=2.0.3=r43hc72bb7e_3 138 | - 
r-pkgdown=2.0.7=r43hc72bb7e_1 139 | - r-pkgload=1.3.3=r43hc72bb7e_0 140 | - r-praise=1.0.0=r43hc72bb7e_1007 141 | - r-prettyunits=1.2.0=r43hc72bb7e_0 142 | - r-processx=3.8.2=r43h57805ef_0 143 | - r-profvis=0.3.8=r43h57805ef_3 144 | - r-promises=1.2.1=r43ha503ecb_0 145 | - r-ps=1.7.5=r43h57805ef_1 146 | - r-purrr=1.0.2=r43h57805ef_0 147 | - r-r6=2.5.1=r43hc72bb7e_2 148 | - r-ragg=1.2.5=r43hd759249_3 149 | - r-rappdirs=0.3.3=r43h57805ef_2 150 | - r-rcmdcheck=1.4.0=r43h785f33e_2 151 | - r-rcpp=1.0.11=r43h7df8631_0 152 | - r-rematch2=2.1.2=r43hc72bb7e_3 153 | - r-remotes=2.4.2.1=r43hc72bb7e_0 154 | - r-rlang=1.1.1=r43ha503ecb_1 155 | - r-rmarkdown=2.25=r43hc72bb7e_0 156 | - r-roxygen2=7.2.3=r43ha503ecb_1 157 | - r-rprojroot=2.0.3=r43hc72bb7e_0 158 | - r-rstudioapi=0.15.0=r43hc72bb7e_0 159 | - r-rversions=2.1.2=r43hc72bb7e_2 160 | - r-sass=0.4.7=r43ha503ecb_0 161 | - r-sessioninfo=1.2.2=r43hc72bb7e_2 162 | - r-shiny=1.7.5=r43h785f33e_0 163 | - r-sourcetools=0.1.7_1=r43ha503ecb_1 164 | - r-stringi=1.7.12=r43h9facbd6_3 165 | - r-stringr=1.5.0=r43h785f33e_1 166 | - r-sys=3.4.2=r43h57805ef_1 167 | - r-systemfonts=1.0.4=r43haf97adc_2 168 | - r-testthat=3.1.10=r43ha503ecb_0 169 | - r-textshaping=0.3.6=r43hd87b9d6_7 170 | - r-tibble=3.2.1=r43h57805ef_2 171 | - r-tinytex=0.46=r43hc72bb7e_0 172 | - r-urlchecker=1.0.1=r43hc72bb7e_2 173 | - r-usethis=2.2.2=r43hc72bb7e_0 174 | - r-utf8=1.2.3=r43h57805ef_1 175 | - r-vctrs=0.6.3=r43ha503ecb_0 176 | - r-waldo=0.5.1=r43hc72bb7e_1 177 | - r-whisker=0.4.1=r43hc72bb7e_1 178 | - r-withr=2.5.0=r43hc72bb7e_2 179 | - r-xfun=0.40=r43ha503ecb_0 180 | - r-xml2=1.3.5=r43h1ad5fc0_0 181 | - r-xopen=1.0.0=r43hc72bb7e_1005 182 | - r-xtable=1.8_4=r43hc72bb7e_5 183 | - r-yaml=2.3.7=r43h57805ef_1 184 | - r-zip=2.3.0=r43h57805ef_1 185 | - readline=8.2=h8228510_1 186 | - sed=4.8=he412f7d_0 187 | - setuptools=68.2.2=pyhd8ed1ab_0 188 | - sysroot_linux-64=2.12=he073ed8_16 189 | - tk=8.6.13=h2797004_0 190 | - tktable=2.10=h0c5db8f_5 191 | - 
wheel=0.41.2=pyhd8ed1ab_0 192 | - xorg-kbproto=1.0.7=h7f98852_1002 193 | - xorg-libice=1.1.1=hd590300_0 194 | - xorg-libsm=1.2.4=h7391055_0 195 | - xorg-libx11=1.8.6=h8ee46fc_0 196 | - xorg-libxau=1.0.11=hd590300_0 197 | - xorg-libxdmcp=1.1.3=h7f98852_0 198 | - xorg-libxext=1.3.4=h0b41bf4_2 199 | - xorg-libxrender=0.9.11=hd590300_0 200 | - xorg-libxt=1.3.0=hd590300_1 201 | - xorg-renderproto=0.11.1=h7f98852_1002 202 | - xorg-xextproto=7.3.0=h0b41bf4_1003 203 | - xorg-xproto=7.0.31=h7f98852_1007 204 | - xz=5.2.6=h166bdaf_0 205 | - zlib=1.2.13=hd590300_5 206 | - zstd=1.5.5=hfc55251_0 207 | - pip: 208 | - absl-py==2.0.0 209 | - aiohttp==3.8.5 210 | - aiosignal==1.3.1 211 | - anndata==0.9.2 212 | - annotated-types==0.5.0 213 | - anyio==3.7.1 214 | - appdirs==1.4.4 215 | - argon2-cffi==23.1.0 216 | - argon2-cffi-bindings==21.2.0 217 | - arrow==1.2.3 218 | - asttokens==2.4.0 219 | - async-lru==2.0.4 220 | - async-timeout==4.0.3 221 | - attrs==23.1.0 222 | - babel==2.13.0 223 | - backcall==0.2.0 224 | - backoff==2.2.1 225 | - beautifulsoup4==4.12.2 226 | - bleach==6.1.0 227 | - blessed==1.20.0 228 | - certifi==2023.7.22 229 | - cffi==1.16.0 230 | - charset-normalizer==3.2.0 231 | - chex==0.1.7 232 | - click==8.1.7 233 | - cmake==3.27.5 234 | - comm==0.1.4 235 | - contextlib2==21.6.0 236 | - contourpy==1.1.1 237 | - croniter==1.4.1 238 | - cycler==0.11.0 239 | - datasets==2.14.5 240 | - dateutils==0.6.12 241 | - debugpy==1.8.0 242 | - decorator==5.1.1 243 | - deepdiff==6.5.0 244 | - defusedxml==0.7.1 245 | - deprecated==1.2.14 246 | - dill==0.3.5.1 247 | - dm-tree==0.1.8 248 | - docker-pycreds==0.4.0 249 | - docrep==0.3.2 250 | - einops==0.6.1 251 | - etils==1.5.0 252 | - exceptiongroup==1.1.3 253 | - executing==1.2.0 254 | - fastapi==0.103.1 255 | - fastjsonschema==2.18.1 256 | - filelock==3.12.4 257 | - flash-attn==1.0.4 258 | - flax==0.7.4 259 | - fonttools==4.42.1 260 | - fqdn==1.5.1 261 | - frozenlist==1.4.0 262 | - fsspec==2023.6.0 263 | - gitdb==4.0.10 264 | - 
gitpython==3.1.37 265 | - h11==0.14.0 266 | - h5py==3.9.0 267 | - huggingface-hub==0.17.3 268 | - idna==3.4 269 | - igraph==0.10.8 270 | - importlib-resources==6.1.0 271 | - inquirer==3.1.3 272 | - ipykernel==6.25.2 273 | - ipython==8.15.0 274 | - isoduration==20.11.0 275 | - itsdangerous==2.1.2 276 | - jax==0.4.16 277 | - jaxlib==0.4.16 278 | - jedi==0.19.0 279 | - jinja2==3.1.2 280 | - joblib==1.3.2 281 | - json5==0.9.14 282 | - jsonpointer==2.4 283 | - jsonschema==4.19.1 284 | - jsonschema-specifications==2023.7.1 285 | - jupyter-client==8.3.1 286 | - jupyter-core==5.3.1 287 | - jupyter-events==0.8.0 288 | - jupyter-lsp==2.2.0 289 | - jupyter-server==2.8.0 290 | - jupyter-server-terminals==0.4.4 291 | - jupyterlab==4.0.7 292 | - jupyterlab-pygments==0.2.2 293 | - jupyterlab-server==2.25.0 294 | - kiwisolver==1.4.5 295 | - leidenalg==0.10.1 296 | - lightning==2.0.9 297 | - lightning-cloud==0.5.38 298 | - lightning-utilities==0.9.0 299 | - lit==17.0.1 300 | - llvmlite==0.41.0 301 | - markdown-it-py==3.0.0 302 | - markupsafe==2.1.3 303 | - matplotlib==3.7.3 304 | - matplotlib-inline==0.1.6 305 | - mdurl==0.1.2 306 | - mistune==3.0.2 307 | - ml-collections==0.1.1 308 | - ml-dtypes==0.3.1 309 | - mpmath==1.3.0 310 | - msgpack==1.0.6 311 | - mudata==0.2.3 312 | - multidict==6.0.4 313 | - multipledispatch==1.0.0 314 | - multiprocess==0.70.13 315 | - natsort==8.4.0 316 | - nbclient==0.8.0 317 | - nbconvert==7.9.2 318 | - nbformat==5.9.2 319 | - nest-asyncio==1.5.8 320 | - networkx==3.1 321 | - ninja==1.11.1 322 | - notebook-shim==0.2.3 323 | - numba==0.58.0 324 | - numpy==1.24.4 325 | - numpyro==0.13.2 326 | - nvidia-cublas-cu11==11.10.3.66 327 | - nvidia-cuda-cupti-cu11==11.7.101 328 | - nvidia-cuda-nvrtc-cu11==11.7.99 329 | - nvidia-cuda-runtime-cu11==11.7.99 330 | - nvidia-cudnn-cu11==8.5.0.96 331 | - nvidia-cufft-cu11==10.9.0.58 332 | - nvidia-curand-cu11==10.2.10.91 333 | - nvidia-cusolver-cu11==11.4.0.1 334 | - nvidia-cusparse-cu11==11.7.4.91 335 | - 
nvidia-nccl-cu11==2.14.3 336 | - nvidia-nvtx-cu11==11.7.91 337 | - opt-einsum==3.3.0 338 | - optax==0.1.7 339 | - orbax-checkpoint==0.4.0 340 | - ordered-set==4.1.0 341 | - overrides==7.4.0 342 | - packaging==23.1 343 | - pandas==1.5.3 344 | - pandocfilters==1.5.0 345 | - parso==0.8.3 346 | - pathtools==0.1.2 347 | - patsy==0.5.3 348 | - pexpect==4.8.0 349 | - pickleshare==0.7.5 350 | - pillow==10.0.1 351 | - platformdirs==3.10.0 352 | - prometheus-client==0.17.1 353 | - prompt-toolkit==3.0.39 354 | - protobuf==4.24.3 355 | - psutil==5.9.5 356 | - ptyprocess==0.7.0 357 | - pure-eval==0.2.2 358 | - pyarrow==13.0.0 359 | - pycparser==2.21 360 | - pydantic==2.1.1 361 | - pydantic-core==2.4.0 362 | - pydot==1.4.2 363 | - pygments==2.16.1 364 | - pyjwt==2.8.0 365 | - pynndescent==0.5.10 366 | - pyparsing==3.1.1 367 | - pyro-api==0.1.2 368 | - pyro-ppl==1.8.6 369 | - python-dateutil==2.8.2 370 | - python-editor==1.0.4 371 | - python-json-logger==2.0.7 372 | - python-multipart==0.0.6 373 | - pytorch-lightning==2.0.9 374 | - pytz==2023.3.post1 375 | - pyyaml==6.0.1 376 | - pyzmq==25.1.1 377 | - readchar==4.0.5 378 | - referencing==0.30.2 379 | - regex==2023.8.8 380 | - requests==2.31.0 381 | - responses==0.18.0 382 | - rfc3339-validator==0.1.4 383 | - rfc3986-validator==0.1.1 384 | - rich==13.5.3 385 | - rpds-py==0.10.6 386 | - safetensors==0.3.3 387 | - scanpy==1.9.5 388 | - scgpt==0.1.7 389 | - scib==1.1.4 390 | - scikit-learn==1.3.1 391 | - scikit-misc==0.3.0 392 | - scipy==1.11.2 393 | - scvi-tools==1.0.3 394 | - seaborn==0.12.2 395 | - send2trash==1.8.2 396 | - sentry-sdk==1.31.0 397 | - session-info==1.0.0 398 | - setproctitle==1.3.2 399 | - six==1.16.0 400 | - smmap==5.0.1 401 | - sniffio==1.3.0 402 | - soupsieve==2.5 403 | - sparse==0.14.0 404 | - stack-data==0.6.2 405 | - starlette==0.27.0 406 | - starsessions==1.3.0 407 | - statsmodels==0.14.0 408 | - stdlib-list==0.9.0 409 | - sympy==1.12 410 | - tbb==2021.10.0 411 | - tensorstore==0.1.44 412 | - 
terminado==0.17.1 413 | - texttable==1.6.7 414 | - threadpoolctl==3.2.0 415 | - tinycss2==1.2.1 416 | - tokenizers==0.13.3 417 | - tomli==2.0.1 418 | - toolz==0.12.0 419 | - torch==2.0.1 420 | - torchaudio==2.0.2 421 | - torchdata==0.6.1 422 | - torchmetrics==1.2.0 423 | - torchtext==0.15.2 424 | - torchvision==0.15.2 425 | - tornado==6.3.3 426 | - tqdm==4.66.1 427 | - traitlets==5.10.1 428 | - transformers==4.33.2 429 | - triton==2.0.0 430 | - typing-extensions==4.8.0 431 | - tzdata==2023.3 432 | - umap-learn==0.5.4 433 | - uri-template==1.3.0 434 | - urllib3==2.0.5 435 | - uvicorn==0.23.2 436 | - wandb==0.15.11 437 | - wcwidth==0.2.6 438 | - webcolors==1.13 439 | - webencodings==0.5.1 440 | - websocket-client==1.6.3 441 | - websockets==11.0.3 442 | - wrapt==1.15.0 443 | - xarray==2023.9.0 444 | - xxhash==3.3.0 445 | - yarl==1.9.2 446 | - zipp==3.17.0 447 | prefix: /opt/conda/rpeyser/envs/scgpt_2 448 | --------------------------------------------------------------------------------