├── HIST2ST.py ├── HIST2ST_train.py ├── NB_module.py ├── README.md ├── Workflow.png ├── data ├── download.sh ├── her_hvg_cut_1000.npy └── skin_hvg_cut_1000.npy ├── dataset.py ├── gcn.py ├── graph_construction.py ├── predict.py ├── run_trained_models.ipynb ├── transformer.py ├── tutorial.ipynb └── utils.py /HIST2ST.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pytorch_lightning as pl 4 | import torchvision.transforms as tf 5 | from gcn import * 6 | from NB_module import * 7 | from transformer import * 8 | from scipy.stats import pearsonr 9 | from torch.utils.data import DataLoader 10 | from copy import deepcopy as dcp 11 | from collections import defaultdict as dfd 12 | from sklearn.metrics import adjusted_rand_score as ari_score 13 | from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score 14 | class convmixer_block(nn.Module): 15 | def __init__(self,dim,kernel_size): 16 | super().__init__() 17 | self.dw=nn.Sequential( 18 | nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"), 19 | nn.BatchNorm2d(dim), 20 | nn.GELU(), 21 | nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"), 22 | nn.BatchNorm2d(dim), 23 | nn.GELU(), 24 | ) 25 | self.pw=nn.Sequential( 26 | nn.Conv2d(dim, dim, kernel_size=1), 27 | nn.GELU(), 28 | nn.BatchNorm2d(dim), 29 | ) 30 | def forward(self,x): 31 | x=self.dw(x)+x 32 | x=self.pw(x) 33 | return x 34 | class mixer_transformer(nn.Module): 35 | def __init__(self,channel=32, kernel_size=5, dim=1024, 36 | depth1=2, depth2=8, depth3=4, 37 | heads=8, dim_head=64, mlp_dim=1024, dropout = 0., 38 | policy='mean',gcn=True 39 | ): 40 | super().__init__() 41 | self.layer1=nn.Sequential( 42 | *[convmixer_block(channel,kernel_size) for i in range(depth1)], 43 | ) 44 | self.layer2=nn.Sequential(*[attn_block(dim,heads,dim_head,mlp_dim,dropout) for i in range(depth2)]) 45 | self.layer3=nn.ModuleList([gs_block(dim,dim,policy,gcn) for i in 
range(depth3)]) 46 | self.jknet=nn.Sequential( 47 | nn.LSTM(dim,dim,2), 48 | SelectItem(0), 49 | ) 50 | self.down=nn.Sequential( 51 | nn.Conv2d(channel,channel//8,1,1), 52 | nn.Flatten(), 53 | ) 54 | def forward(self,x,ct,adj): 55 | x=self.down(self.layer1(x)) 56 | g=x.unsqueeze(0) 57 | g=self.layer2(g+ct).squeeze(0) 58 | jk=[] 59 | for layer in self.layer3: 60 | g=layer(g,adj) 61 | jk.append(g.unsqueeze(0)) 62 | g=torch.cat(jk,0) 63 | g=self.jknet(g).mean(0) 64 | return g 65 | class ViT(nn.Module): 66 | def __init__(self, channel=32,kernel_size=5,dim=1024, 67 | depth1=2, depth2=8, depth3=4, 68 | heads=8, mlp_dim=1024, dim_head = 64, dropout = 0., 69 | policy='mean',gcn=True 70 | ): 71 | super().__init__() 72 | self.dropout = nn.Dropout(dropout) 73 | self.transformer = mixer_transformer( 74 | channel, kernel_size, dim, 75 | depth1, depth2, depth3, 76 | heads, dim_head, mlp_dim, dropout, 77 | policy,gcn, 78 | ) 79 | 80 | def forward(self,x,ct,adj): 81 | x = self.dropout(x) 82 | x = self.transformer(x,ct,adj) 83 | return x 84 | 85 | class Hist2ST(pl.LightningModule): 86 | def __init__(self, learning_rate=1e-5, fig_size=112, label=None, 87 | dropout=0.2, n_pos=64, kernel_size=5, patch_size=7, n_genes=785, 88 | depth1=2, depth2=8, depth3=4, heads=16, channel=32, 89 | zinb=0, nb=False, bake=0, lamb=0, policy='mean', 90 | ): 91 | super().__init__() 92 | # self.save_hyperparameters() 93 | dim=(fig_size//patch_size)**2*channel//8 94 | self.learning_rate = learning_rate 95 | 96 | self.nb=nb 97 | self.zinb=zinb 98 | 99 | self.bake=bake 100 | self.lamb=lamb 101 | 102 | self.label=label 103 | self.patch_embedding = nn.Conv2d(3,channel,patch_size,patch_size) 104 | self.x_embed = nn.Embedding(n_pos,dim) 105 | self.y_embed = nn.Embedding(n_pos,dim) 106 | self.vit = ViT( 107 | channel=channel, kernel_size=kernel_size, heads=heads, 108 | dim=dim, depth1=depth1,depth2=depth2, depth3=depth3, 109 | mlp_dim=dim, dropout = dropout, policy=policy, gcn=True, 110 | ) 111 | 
self.channel=channel 112 | self.patch_size=patch_size 113 | self.n_genes=n_genes 114 | if self.zinb>0: 115 | if self.nb: 116 | self.hr=nn.Linear(dim, n_genes) 117 | self.hp=nn.Linear(dim, n_genes) 118 | else: 119 | self.mean = nn.Sequential(nn.Linear(dim, n_genes), MeanAct()) 120 | self.disp = nn.Sequential(nn.Linear(dim, n_genes), DispAct()) 121 | self.pi = nn.Sequential(nn.Linear(dim, n_genes), nn.Sigmoid()) 122 | if self.bake>0: 123 | self.coef=nn.Sequential( 124 | nn.Linear(dim,dim), 125 | nn.ReLU(), 126 | nn.Linear(dim,1), 127 | ) 128 | self.gene_head = nn.Sequential( 129 | nn.LayerNorm(dim), 130 | nn.Linear(dim, n_genes), 131 | ) 132 | self.tf=tf.Compose([ 133 | tf.RandomGrayscale(0.1), 134 | tf.RandomRotation(90), 135 | tf.RandomHorizontalFlip(0.2), 136 | ]) 137 | def forward(self, patches, centers, adj, aug=False): 138 | B,N,C,H,W=patches.shape 139 | patches=patches.reshape(B*N,C,H,W) 140 | patches = self.patch_embedding(patches) 141 | centers_x = self.x_embed(centers[:,:,0]) 142 | centers_y = self.y_embed(centers[:,:,1]) 143 | ct=centers_x + centers_y 144 | h = self.vit(patches,ct,adj) 145 | x = self.gene_head(h) 146 | extra=None 147 | if self.zinb>0: 148 | if self.nb: 149 | r=self.hr(h) 150 | p=self.hp(h) 151 | extra=(r,p) 152 | else: 153 | m = self.mean(h) 154 | d = self.disp(h) 155 | p = self.pi(h) 156 | extra=(m,d,p) 157 | if aug: 158 | h=self.coef(h) 159 | return x,extra,h 160 | def aug(self,patch,center,adj): 161 | bake_x=[] 162 | for i in range(self.bake): 163 | new_patch=self.tf(patch.squeeze(0)).unsqueeze(0) 164 | x,_,h=self(new_patch,center,adj,True) 165 | bake_x.append((x.unsqueeze(0),h.unsqueeze(0))) 166 | return bake_x 167 | def distillation(self,bake_x): 168 | new_x,coef=zip(*bake_x) 169 | coef=torch.cat(coef,0) 170 | new_x=torch.cat(new_x,0) 171 | coef=F.softmax(coef,dim=0) 172 | new_x=(new_x*coef).sum(0) 173 | return new_x 174 | def training_step(self, batch, batch_idx): 175 | patch, center, exp, adj, oris, sfs, *_ = batch 176 | 
adj=adj.squeeze(0) 177 | exp=exp.squeeze(0) 178 | pred,extra,h = self(patch, center, adj) 179 | 180 | mse_loss = F.mse_loss(pred, exp) 181 | self.log('mse_loss', mse_loss,on_epoch=True, prog_bar=True, logger=True) 182 | bake_loss=0 183 | if self.bake>0: 184 | bake_x=self.aug(patch,center,adj) 185 | new_pred=self.distillation(bake_x) 186 | bake_loss+=F.mse_loss(new_pred,pred) 187 | self.log('bake_loss', bake_loss,on_epoch=True, prog_bar=True, logger=True) 188 | zinb_loss=0 189 | if self.zinb>0: 190 | if self.nb: 191 | r,p=extra 192 | zinb_loss = NB_loss(oris.squeeze(0),r,p) 193 | else: 194 | m,d,p=extra 195 | zinb_loss = ZINB_loss(oris.squeeze(0),m,d,p,sfs.squeeze(0)) 196 | self.log('zinb_loss', zinb_loss,on_epoch=True, prog_bar=True, logger=True) 197 | 198 | loss=mse_loss+self.zinb*zinb_loss+self.lamb*bake_loss 199 | return loss 200 | 201 | def validation_step(self, batch, batch_idx): 202 | patch, center, exp, adj, oris, sfs, *_ = batch 203 | def cluster(pred,cls): 204 | sc.pp.pca(pred) 205 | sc.tl.tsne(pred) 206 | kmeans = KMeans(n_clusters=cls, init="k-means++", random_state=0).fit(pred.obsm['X_pca']) 207 | pred.obs['kmeans'] = kmeans.labels_.astype(str) 208 | p=pred.obs['kmeans'].to_numpy() 209 | return p 210 | 211 | pred,extra,h = self(patch, center, adj.squeeze(0)) 212 | if self.label is not None: 213 | adata=ann.AnnData(pred.squeeze().cpu().numpy()) 214 | idx=self.label!='undetermined' 215 | cls=len(set(self.label)) 216 | x=adata[idx] 217 | l=self.label[idx] 218 | predlbl=cluster(x,cls-1) 219 | self.log('nmi',nmi_score(predlbl,l)) 220 | self.log('ari',ari_score(predlbl,l)) 221 | 222 | loss = F.mse_loss(pred.squeeze(0), exp.squeeze(0)) 223 | self.log('valid_loss', loss,on_epoch=True, prog_bar=True, logger=True) 224 | 225 | pred=pred.squeeze(0).cpu().numpy().T 226 | exp=exp.squeeze(0).cpu().numpy().T 227 | r=[] 228 | for g in range(self.n_genes): 229 | r.append(pearsonr(pred[g],exp[g])[0]) 230 | R=torch.Tensor(r).mean() 231 | self.log('R', R, on_epoch=True, 
prog_bar=True, logger=True) 232 | return loss 233 | 234 | def configure_optimizers(self): 235 | # self.hparams available because we called self.save_hyperparameters() 236 | optim=torch.optim.Adam(self.parameters(), lr=self.learning_rate) 237 | StepLR = torch.optim.lr_scheduler.StepLR(optim, step_size=50, gamma=0.9) 238 | optim_dict = {'optimizer': optim, 'lr_scheduler': StepLR} 239 | return optim_dict 240 | -------------------------------------------------------------------------------- /HIST2ST_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import argparse 5 | import pickle as pk 6 | import pytorch_lightning as pl 7 | from utils import * 8 | from HIST2ST import * 9 | from predict import * 10 | from torch.utils.data import DataLoader 11 | from pytorch_lightning.loggers import TensorBoardLogger 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--gpu', type=int, default=2, help='the id of gpu.') 15 | parser.add_argument('--fold', type=int, default=5, help='dataset fold.') 16 | parser.add_argument('--seed', type=int, default=12000, help='random seed.') 17 | parser.add_argument('--epochs', type=int, default=350, help='number of epochs.') 18 | parser.add_argument('--name', type=str, default='hist2ST', help='prefix name.') 19 | parser.add_argument('--data', type=str, default='her2st', help='dataset name:{"her2st","cscc"}.') 20 | parser.add_argument('--logger', type=str, default='../logs/my_logs', help='logger path.') 21 | parser.add_argument('--lr', type=float, default=1e-5, help='learning rate.') 22 | parser.add_argument('--dropout', type=float, default=0.2, help='dropout.') 23 | 24 | parser.add_argument('--bake', type=int, default=5, help='the number of augmented images.') 25 | parser.add_argument('--lamb', type=float, default=0.5, help='the loss coef of self-distillation.') 26 | 27 | 28 | parser.add_argument('--nb', type=str, default='F', help='zinb or 
nb loss.') 29 | parser.add_argument('--zinb', type=float, default=0.25, help='the loss coef of zinb.') 30 | 31 | parser.add_argument('--prune', type=str, default='Grid', help='how to prune the edge:{"Grid","NA"}') 32 | parser.add_argument('--policy', type=str, default='mean', help='the aggregation way in the GNN .') 33 | parser.add_argument('--neighbor', type=int, default=4, help='the number of neighbors in the GNN.') 34 | 35 | parser.add_argument('--tag', type=str, default='5-7-2-8-4-16-32', 36 | help='hyper params: kernel-patch-depth1-depth2-depth3-heads-channel,' 37 | 'depth1-depth2-depth3 are the depth of Convmixer, Multi-head layer in Transformer, and GNN, respectively' 38 | 'patch is the value of kernel_size and stride in the path embedding layer of Convmixer' 39 | 'kernel is the kernel_size in the depthwise of Convmixer module' 40 | 'heads are the number of attention heads in the Multi-head layer' 41 | 'channel is the value of the input and output channel of depthwise and pointwise. 
') 42 | 43 | args = parser.parse_args() 44 | random.seed(args.seed) 45 | np.random.seed(args.seed) 46 | torch.manual_seed(args.seed) 47 | torch.cuda.manual_seed(args.seed) 48 | torch.cuda.manual_seed_all(args.seed) 49 | torch.backends.cudnn.benchmark = False 50 | torch.backends.cudnn.deterministic = True 51 | kernel,patch,depth1,depth2,depth3,heads,channel=map(lambda x:int(x),args.tag.split('-')) 52 | 53 | trainset = pk_load(args.fold,'train',False,args.data,neighs=args.neighbor, prune=args.prune) 54 | train_loader = DataLoader(trainset, batch_size=1, num_workers=0, shuffle=True) 55 | testset = pk_load(args.fold,'test',False,args.data,neighs=args.neighbor, prune=args.prune) 56 | test_loader = DataLoader(testset, batch_size=1, num_workers=0, shuffle=False) 57 | label=None 58 | if args.fold in [5,11,17,23,26,30] and args.data=='her2st': 59 | label=testset.label[testset.names[0]] 60 | 61 | genes=785 62 | if args.data=='cscc': 63 | args.name+='_cscc' 64 | genes=171 65 | 66 | log_name='' 67 | if args.zinb>0: 68 | if args.nb=='T': 69 | args.name+='_nb' 70 | else: 71 | args.name+='_zinb' 72 | log_name+=f'-{args.zinb}' 73 | if args.bake>0: 74 | args.name+='_bake' 75 | log_name+=f'-{args.bake}-{args.lamb}' 76 | log_name=f'{args.fold}-{args.name}-{args.tag}'+log_name+f'-{args.policy}-{args.neighbor}' 77 | logger = TensorBoardLogger( 78 | args.logger, 79 | name=log_name 80 | ) 81 | print(log_name) 82 | 83 | model = Hist2ST( 84 | depth1=depth1, depth2=depth2, depth3=depth3, 85 | n_genes=genes, learning_rate=args.lr, label=label, 86 | kernel_size=kernel, patch_size=patch, 87 | heads=heads, channel=channel, dropout=args.dropout, 88 | zinb=args.zinb, nb=args.nb=='T', 89 | bake=args.bake, lamb=args.lamb, 90 | policy=args.policy, 91 | ) 92 | trainer = pl.Trainer( 93 | gpus=[args.gpu], max_epochs=args.epochs, 94 | logger=logger,check_val_every_n_epoch=2, 95 | ) 96 | 97 | trainer.fit(model, train_loader, test_loader) 98 | 
torch.save(model.state_dict(),f"./model/{args.fold}-Hist2ST{'_cscc' if args.data=='cscc' else ''}.ckpt") 99 | # model.load_state_dict(torch.load(f"./model/{args.fold}-Hist2ST{'_cscc' if args.data=='cscc' else ''}.ckpt"),) 100 | pred, gt = test(model, test_loader,'cuda') 101 | R=get_R(pred,gt)[0] 102 | print('Pearson Correlation:',np.nanmean(R)) 103 | clus,ARI=cluster(pred,label) 104 | print('ARI:',ARI) 105 | -------------------------------------------------------------------------------- /NB_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class MeanAct(nn.Module): 7 | def __init__(self): 8 | super(MeanAct, self).__init__() 9 | def forward(self, x): 10 | return torch.clamp(torch.exp(x), min=1e-5, max=1e6) 11 | 12 | class DispAct(nn.Module): 13 | def __init__(self): 14 | super(DispAct, self).__init__() 15 | def forward(self, x): 16 | return torch.clamp(F.softplus(x), min=1e-4, max=1e4) 17 | 18 | def NB_loss(x, h_r, h_p): 19 | 20 | ll = torch.lgamma(torch.exp(h_r) + x) - torch.lgamma(torch.exp(h_r)) 21 | ll += h_p * x - torch.log(torch.exp(h_p) + 1) * (x + torch.exp(h_r)) 22 | 23 | loss = -torch.mean(torch.sum(ll, axis=-1)) 24 | return loss 25 | 26 | def ZINB_loss(x, mean, disp, pi, scale_factor=1.0, ridge_lambda=0.0): 27 | eps = 1e-10 28 | if isinstance(scale_factor,float): 29 | scale_factor=np.full((len(mean),),scale_factor) 30 | scale_factor = scale_factor[:, None] 31 | mean = mean * scale_factor 32 | 33 | t1 = torch.lgamma(disp+eps) + torch.lgamma(x+1.0) - torch.lgamma(x+disp+eps) 34 | t2 = (disp+x) * torch.log(1.0 + (mean/(disp+eps))) + (x * (torch.log(disp+eps) - torch.log(mean+eps))) 35 | nb_final = t1 + t2 36 | 37 | nb_case = nb_final - torch.log(1.0-pi+eps) 38 | zero_nb = torch.pow(disp/(disp+mean+eps), disp) 39 | zero_case = -torch.log(pi + ((1.0-pi)*zero_nb)+eps) 40 | result = torch.where(torch.le(x, 
1e-8), zero_case, nb_case) 41 | 42 | if ridge_lambda > 0: 43 | ridge = ridge_lambda*torch.square(pi) 44 | result += ridge 45 | result = torch.mean(result) 46 | return result -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Spatial Transcriptomics Prediction from Histology jointly through Transformer and Graph Neural Networks 2 | ### Yuansong Zeng, Zhuoyi Wei, Weijiang Yu, Rui Yin, Bingling Li, Zhonghui Tang, Yutong Lu, Yuedong Yang* 3 | 4 | 5 | Here, we have developed Hist2ST, a deep learning-based model using histology images to predict RNA-seq expression. 6 | At each sequenced spot, the corre-sponding histology image is cropped into an image patch, from which 2D vision 7 | features are learned through convolutional operations. Meanwhile, the spatial relations with the whole image and 8 | neighbored patches are captured through Transformer and graph neural network modules, respectively. These learned 9 | features are then used to predict the gene expression by following the zero-inflated negative binomial (ZINB) distribution. 10 | To alleviate the impact by the small spatial transcriptomics data, a self-distillation mechanism is employed for efficient 11 | learning of the model. Hist2ST was tested on the HER2-positive breast cancer and the cutaneous squamous cell carcinoma datasets, 12 | and shown to outperform existing methods in terms of both gene expression prediction and following spatial region identification. 
13 | 14 | 15 | 16 | ![Hist2ST workflow](Workflow.png) 17 | 18 | 19 | 20 | # Usage 21 | ```python 22 | import torch 23 | from HIST2ST import Hist2ST 24 | 25 | model = Hist2ST( 26 | depth1=2, depth2=8, depth3=4, 27 | n_genes=785, learning_rate=1e-5, 28 | kernel_size=5, patch_size=7, fig_size=112, 29 | heads=16, channel=32, dropout=0.2, 30 | zinb=0.25, nb=False, 31 | bake=5, lamb=0.5, 32 | policy='mean', 33 | ) 34 | 35 | # patches: [1, N, 3, H, W] float tensor with H = W = fig_size 36 | # coordinates: [1, N, 2] LongTensor of spot grid positions (each value < n_pos = 64) 37 | # adjacency: [N, N] 38 | pred_expression, _, _ = model(patches, coordinates, adjacency) # [N, n_genes] 39 | 40 | ``` 41 | 42 | Note: for a detailed description of all parameters, please see [HIST2ST_train](https://github.com/biomed-AI/Hist2ST/blob/main/HIST2ST_train.py) 43 | 44 | 45 | ## System environment 46 | Required packages: 47 | - PyTorch >= 1.10 48 | - pytorch-lightning >= 1.4 49 | - scanpy >= 1.8 50 | - python >=3.7 51 | - tensorboard 52 | 53 | 54 | # Hist2ST pipeline 55 | 56 | See [tutorial.ipynb](tutorial.ipynb) 57 | 58 | 59 | NOTE: Run the following commands before running the notebook tutorial.ipynb: 60 | 61 | 1. Run the script `download.sh` in the folder [data](https://github.com/biomed-AI/Hist2ST/tree/main/data) 62 | 63 | or 64 | 65 | run the command `git clone https://github.com/almaan/her2st.git` in the directory [data](https://github.com/biomed-AI/Hist2ST/tree/main/data) 66 | 67 | 2. Run `gunzip *.gz` in the directory `Hist2ST/data/her2st/data/ST-cnts/` to decompress the .gz files 68 | 69 | 70 | # Datasets 71 | 72 | - human HER2-positive breast tumor ST data https://github.com/almaan/her2st/. 73 | - human cutaneous squamous cell carcinoma 10x Visium data (GSE144240).
74 | - you can also download all datasets from [here](https://www.synapse.org/#!Synapse:syn29738084/files/) 75 | 76 | 77 | # Trained models 78 | All trained models of our method on the HER2+ and cSCC datasets can be found at [synapse](https://www.synapse.org/#!Synapse:syn29738084/files/) 79 | 80 | 81 | # Citation 82 | 83 | Please cite our paper: 84 | 85 | ``` 86 | 87 | @article{zengys, 88 | title={Spatial Transcriptomics Prediction from Histology jointly through Transformer and Graph Neural Networks}, 89 | author={Yuansong Zeng and Zhuoyi Wei and Weijiang Yu and Rui Yin and Bingling Li and Zhonghui Tang and Yutong Lu and Yuedong Yang}, 90 | journal={biorxiv}, 91 | year={2021}, 92 | publisher={Cold Spring Harbor Laboratory} 93 | } 94 | 95 | ``` 96 | -------------------------------------------------------------------------------- /Workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomed-AI/Hist2ST/7480e5d1f7712282e6c96d8430684e7830517769/Workflow.png -------------------------------------------------------------------------------- /data/download.sh: -------------------------------------------------------------------------------- 1 | git clone https://github.com/almaan/her2st.git -------------------------------------------------------------------------------- /data/her_hvg_cut_1000.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomed-AI/Hist2ST/7480e5d1f7712282e6c96d8430684e7830517769/data/her_hvg_cut_1000.npy -------------------------------------------------------------------------------- /data/skin_hvg_cut_1000.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomed-AI/Hist2ST/7480e5d1f7712282e6c96d8430684e7830517769/data/skin_hvg_cut_1000.npy -------------------------------------------------------------------------------- /dataset.py:
-------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import torchvision 5 | import numpy as np 6 | import scanpy as sc 7 | import pandas as pd 8 | import scprep as scp 9 | import anndata as ad 10 | import seaborn as sns 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | import matplotlib.pyplot as plt 14 | import torch.nn.functional as F 15 | import torchvision.transforms as transforms 16 | from PIL import ImageFile, Image 17 | from utils import read_tiff, get_data 18 | from graph_construction import calcADJ 19 | from collections import defaultdict as dfd 20 | ImageFile.LOAD_TRUNCATED_IMAGES = True 21 | Image.MAX_IMAGE_PIXELS = None 22 | 23 | class ViT_HER2ST(torch.utils.data.Dataset): 24 | """Some Information about HER2ST""" 25 | def __init__(self,train=True,fold=0,r=4,flatten=True,ori=False,adj=False,prune='Grid',neighs=4): 26 | super(ViT_HER2ST, self).__init__() 27 | 28 | self.cnt_dir = 'data/her2st/data/ST-cnts' 29 | self.img_dir = 'data/her2st/data/ST-imgs' 30 | self.pos_dir = 'data/her2st/data/ST-spotfiles' 31 | self.lbl_dir = 'data/her2st/data/ST-pat/lbl' 32 | self.r = 224//r 33 | 34 | # gene_list = list(np.load('data/her_hvg.npy',allow_pickle=True)) 35 | gene_list = list(np.load('data/her_hvg_cut_1000.npy',allow_pickle=True)) 36 | self.gene_list = gene_list 37 | names = os.listdir(self.cnt_dir) 38 | names.sort() 39 | names = [i[:2] for i in names] 40 | self.train = train 41 | self.ori = ori 42 | self.adj = adj 43 | # samples = ['A1','B1','C1','D1','E1','F1','G2','H1'] 44 | samples = names[1:33] 45 | 46 | te_names = [samples[fold]] 47 | print(te_names) 48 | tr_names = list(set(samples)-set(te_names)) 49 | 50 | if train: 51 | self.names = tr_names 52 | else: 53 | self.names = te_names 54 | 55 | print('Loading imgs...') 56 | self.img_dict = {i:torch.Tensor(np.array(self.get_img(i))) for i in self.names} 57 | print('Loading metadata...') 58 | self.meta_dict = 
{i:self.get_meta(i) for i in self.names} 59 | self.label={i:None for i in self.names} 60 | self.lbl2id={ 61 | 'invasive cancer':0, 'breast glands':1, 'immune infiltrate':2, 62 | 'cancer in situ':3, 'connective tissue':4, 'adipose tissue':5, 'undetermined':-1 63 | } 64 | if not train and self.names[0] in ['A1','B1','C1','D1','E1','F1','G2','H1','J1']: 65 | self.lbl_dict={i:self.get_lbl(i) for i in self.names} 66 | # self.label={i:m['label'].values for i,m in self.lbl_dict.items()} 67 | idx=self.meta_dict[self.names[0]].index 68 | lbl=self.lbl_dict[self.names[0]] 69 | lbl=lbl.loc[idx,:]['label'].values 70 | # lbl=torch.Tensor(list(map(lambda i:self.lbl2id[i],lbl))) 71 | self.label[self.names[0]]=lbl 72 | elif train: 73 | for i in self.names: 74 | idx=self.meta_dict[i].index 75 | if i in ['A1','B1','C1','D1','E1','F1','G2','H1','J1']: 76 | lbl=self.get_lbl(i) 77 | lbl=lbl.loc[idx,:]['label'].values 78 | lbl=torch.Tensor(list(map(lambda i:self.lbl2id[i],lbl))) 79 | self.label[i]=lbl 80 | else: 81 | self.label[i]=torch.full((len(idx),),-1) 82 | self.gene_set = list(gene_list) 83 | self.exp_dict = { 84 | i:scp.transform.log(scp.normalize.library_size_normalize(m[self.gene_set].values)) 85 | for i,m in self.meta_dict.items() 86 | } 87 | if self.ori: 88 | self.ori_dict = {i:m[self.gene_set].values for i,m in self.meta_dict.items()} 89 | self.counts_dict={} 90 | for i,m in self.ori_dict.items(): 91 | n_counts=m.sum(1) 92 | sf = n_counts / np.median(n_counts) 93 | self.counts_dict[i]=sf 94 | self.center_dict = { 95 | i:np.floor(m[['pixel_x','pixel_y']].values).astype(int) 96 | for i,m in self.meta_dict.items() 97 | } 98 | self.loc_dict = {i:m[['x','y']].values for i,m in self.meta_dict.items()} 99 | self.adj_dict = { 100 | i:calcADJ(m,neighs,pruneTag=prune) 101 | for i,m in self.loc_dict.items() 102 | } 103 | self.patch_dict=dfd(lambda :None) 104 | self.lengths = [len(i) for i in self.meta_dict.values()] 105 | self.cumlen = np.cumsum(self.lengths) 106 | self.id2name = 
dict(enumerate(self.names)) 107 | self.flatten=flatten 108 | def __getitem__(self, index): 109 | ID=self.id2name[index] 110 | im = self.img_dict[ID] 111 | im = im.permute(1,0,2) 112 | # im = torch.Tensor(np.array(self.im)) 113 | exps = self.exp_dict[ID] 114 | if self.ori: 115 | oris = self.ori_dict[ID] 116 | sfs = self.counts_dict[ID] 117 | centers = self.center_dict[ID] 118 | loc = self.loc_dict[ID] 119 | adj = self.adj_dict[ID] 120 | patches = self.patch_dict[ID] 121 | positions = torch.LongTensor(loc) 122 | patch_dim = 3 * self.r * self.r * 4 123 | label=self.label[ID] 124 | exps = torch.Tensor(exps) 125 | if patches is None: 126 | n_patches = len(centers) 127 | if self.flatten: 128 | patches = torch.zeros((n_patches,patch_dim)) 129 | else: 130 | patches = torch.zeros((n_patches,3,2*self.r,2*self.r)) 131 | for i in range(n_patches): 132 | center = centers[i] 133 | x, y = center 134 | patch = im[(x-self.r):(x+self.r),(y-self.r):(y+self.r),:] 135 | if self.flatten: 136 | patches[i] = patch.flatten() 137 | else: 138 | patches[i]=patch.permute(2,0,1) 139 | self.patch_dict[ID]=patches 140 | data=[patches, positions, exps] 141 | if self.adj: 142 | data.append(adj) 143 | if self.ori: 144 | data+=[torch.Tensor(oris),torch.Tensor(sfs)] 145 | data.append(torch.Tensor(centers)) 146 | return data 147 | 148 | def __len__(self): 149 | return len(self.exp_dict) 150 | 151 | def get_img(self,name): 152 | pre = self.img_dir+'/'+name[0]+'/'+name 153 | fig_name = os.listdir(pre)[0] 154 | path = pre+'/'+fig_name 155 | im = Image.open(path) 156 | return im 157 | 158 | def get_cnt(self,name): 159 | path = self.cnt_dir+'/'+name+'.tsv' 160 | df = pd.read_csv(path,sep='\t',index_col=0) 161 | 162 | return df 163 | 164 | def get_pos(self,name): 165 | path = self.pos_dir+'/'+name+'_selection.tsv' 166 | # path = self.pos_dir+'/'+name+'_labeled_coordinates.tsv' 167 | df = pd.read_csv(path,sep='\t') 168 | 169 | x = df['x'].values 170 | y = df['y'].values 171 | x = np.around(x).astype(int) 172 
| y = np.around(y).astype(int) 173 | id = [] 174 | for i in range(len(x)): 175 | id.append(str(x[i])+'x'+str(y[i])) 176 | df['id'] = id 177 | 178 | return df 179 | 180 | def get_meta(self,name,gene_list=None): 181 | cnt = self.get_cnt(name) 182 | pos = self.get_pos(name) 183 | meta = cnt.join((pos.set_index('id'))) 184 | 185 | return meta 186 | 187 | def get_lbl(self,name): 188 | # path = self.pos_dir+'/'+name+'_selection.tsv' 189 | path = self.lbl_dir+'/'+name+'_labeled_coordinates.tsv' 190 | df = pd.read_csv(path,sep='\t') 191 | 192 | x = df['x'].values 193 | y = df['y'].values 194 | x = np.around(x).astype(int) 195 | y = np.around(y).astype(int) 196 | id = [] 197 | for i in range(len(x)): 198 | id.append(str(x[i])+'x'+str(y[i])) 199 | df['id'] = id 200 | df.drop('pixel_x', inplace=True, axis=1) 201 | df.drop('pixel_y', inplace=True, axis=1) 202 | df.drop('x', inplace=True, axis=1) 203 | df.drop('y', inplace=True, axis=1) 204 | df.set_index('id',inplace=True) 205 | return df 206 | 207 | class ViT_SKIN(torch.utils.data.Dataset): 208 | """Some Information about ViT_SKIN""" 209 | def __init__(self,train=True,r=4,norm=False,fold=0,flatten=True,ori=False,adj=False,prune='NA',neighs=4): 210 | super(ViT_SKIN, self).__init__() 211 | 212 | self.dir = './data/GSE144240_RAW/' 213 | self.r = 224//r 214 | 215 | patients = ['P2', 'P5', 'P9', 'P10'] 216 | reps = ['rep1', 'rep2', 'rep3'] 217 | names = [] 218 | for i in patients: 219 | for j in reps: 220 | names.append(i+'_ST_'+j) 221 | gene_list = list(np.load('data/skin_hvg_cut_1000.npy',allow_pickle=True)) 222 | 223 | self.ori = ori 224 | self.adj = adj 225 | self.norm = norm 226 | self.train = train 227 | self.flatten = flatten 228 | self.gene_list = gene_list 229 | samples = names 230 | te_names = [samples[fold]] 231 | tr_names = list(set(samples)-set(te_names)) 232 | 233 | if train: 234 | self.names = tr_names 235 | else: 236 | self.names = te_names 237 | 238 | print(te_names) 239 | print('Loading imgs...') 240 | 
self.img_dict = {i:torch.Tensor(np.array(self.get_img(i))) for i in self.names} 241 | print('Loading metadata...') 242 | self.meta_dict = {i:self.get_meta(i) for i in self.names} 243 | 244 | self.gene_set = list(gene_list) 245 | if self.norm: 246 | self.exp_dict = { 247 | i:sc.pp.scale(scp.transform.log(scp.normalize.library_size_normalize(m[self.gene_set].values))) 248 | for i,m in self.meta_dict.items() 249 | } 250 | else: 251 | self.exp_dict = { 252 | i:scp.transform.log(scp.normalize.library_size_normalize(m[self.gene_set].values)) 253 | for i,m in self.meta_dict.items() 254 | } 255 | if self.ori: 256 | self.ori_dict = {i:m[self.gene_set].values for i,m in self.meta_dict.items()} 257 | self.counts_dict={} 258 | for i,m in self.ori_dict.items(): 259 | n_counts=m.sum(1) 260 | sf = n_counts / np.median(n_counts) 261 | self.counts_dict[i]=sf 262 | self.center_dict = { 263 | i:np.floor(m[['pixel_x','pixel_y']].values).astype(int) 264 | for i,m in self.meta_dict.items() 265 | } 266 | self.loc_dict = {i:m[['x','y']].values for i,m in self.meta_dict.items()} 267 | self.adj_dict = { 268 | i:calcADJ(m,neighs,pruneTag=prune) 269 | for i,m in self.loc_dict.items() 270 | } 271 | self.patch_dict=dfd(lambda :None) 272 | self.lengths = [len(i) for i in self.meta_dict.values()] 273 | self.cumlen = np.cumsum(self.lengths) 274 | self.id2name = dict(enumerate(self.names)) 275 | 276 | 277 | def filter_helper(self): 278 | a = np.zeros(len(self.gene_list)) 279 | n = 0 280 | for i,exp in self.exp_dict.items(): 281 | n += exp.shape[0] 282 | exp[exp>0] = 1 283 | for j in range((len(self.gene_list))): 284 | a[j] += np.sum(exp[:,j]) 285 | 286 | 287 | def __getitem__(self, index): 288 | ID=self.id2name[index] 289 | im = self.img_dict[ID].permute(1,0,2) 290 | 291 | exps = self.exp_dict[ID] 292 | if self.ori: 293 | oris = self.ori_dict[ID] 294 | sfs = self.counts_dict[ID] 295 | adj=self.adj_dict[ID] 296 | centers = self.center_dict[ID] 297 | loc = self.loc_dict[ID] 298 | patches = 
self.patch_dict[ID] 299 | positions = torch.LongTensor(loc) 300 | patch_dim = 3 * self.r * self.r * 4 301 | exps = torch.Tensor(exps) 302 | if patches is None: 303 | n_patches = len(centers) 304 | if self.flatten: 305 | patches = torch.zeros((n_patches,patch_dim)) 306 | else: 307 | patches = torch.zeros((n_patches,3,2*self.r,2*self.r)) 308 | 309 | for i in range(n_patches): 310 | center = centers[i] 311 | x, y = center 312 | patch = im[(x-self.r):(x+self.r),(y-self.r):(y+self.r),:] 313 | if self.flatten: 314 | patches[i] = patch.flatten() 315 | else: 316 | patches[i]=patch.permute(2,0,1) 317 | self.patch_dict[ID]=patches 318 | data=[patches, positions, exps] 319 | if self.adj: 320 | data.append(adj) 321 | if self.ori: 322 | data+=[torch.Tensor(oris),torch.Tensor(sfs)] 323 | data.append(torch.Tensor(centers)) 324 | return data 325 | 326 | def __len__(self): 327 | return len(self.exp_dict) 328 | 329 | def get_img(self,name): 330 | path = glob.glob(self.dir+'*'+name+'.jpg')[0] 331 | im = Image.open(path) 332 | return im 333 | 334 | def get_cnt(self,name): 335 | path = glob.glob(self.dir+'*'+name+'_stdata.tsv')[0] 336 | df = pd.read_csv(path,sep='\t',index_col=0) 337 | return df 338 | 339 | def get_pos(self,name): 340 | path = glob.glob(self.dir+'*spot*'+name+'.tsv')[0] 341 | df = pd.read_csv(path,sep='\t') 342 | 343 | x = df['x'].values 344 | y = df['y'].values 345 | x = np.around(x).astype(int) 346 | y = np.around(y).astype(int) 347 | id = [] 348 | for i in range(len(x)): 349 | id.append(str(x[i])+'x'+str(y[i])) 350 | df['id'] = id 351 | 352 | return df 353 | 354 | def get_meta(self,name,gene_list=None): 355 | cnt = self.get_cnt(name) 356 | pos = self.get_pos(name) 357 | meta = cnt.join(pos.set_index('id'),how='inner') 358 | 359 | return meta 360 | 361 | def get_overlap(self,meta_dict,gene_list): 362 | gene_set = set(gene_list) 363 | for i in meta_dict.values(): 364 | gene_set = gene_set&set(i.columns) 365 | return list(gene_set) 
# --------------------------------------------------------------------------------
# /gcn.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.autograd import Variable

import time
import random
import numpy as np
from collections import defaultdict


class gs_block(nn.Module):
    """One GraphSAGE-style graph layer.

    Neighbour features are aggregated (mean or element-wise max), optionally
    concatenated with each node's own features, projected by a learned weight
    matrix, passed through ReLU and L2-normalised per node.

    Args:
        feature_dim: length of each input node feature vector.
        embed_dim: length of each output node embedding.
        policy: neighbour aggregation, ``'mean'`` or ``'max'``.
        gcn: if True the projection sees only the aggregated neighbour
            features; if False it sees ``[own features, aggregate]`` and the
            identity is subtracted from the adjacency before aggregating.
        num_sample: stored on the instance; not used by this layer.
    """

    def __init__(
        self, feature_dim, embed_dim,
        policy='mean', gcn=False, num_sample=10
    ):
        super().__init__()
        self.gcn = gcn
        self.policy = policy
        self.embed_dim = embed_dim
        self.feat_dim = feature_dim
        self.num_sample = num_sample
        in_dim = self.feat_dim if self.gcn else 2 * self.feat_dim
        # (embed_dim, in_dim); contents are set by the Xavier init below.
        self.weight = nn.Parameter(torch.empty(embed_dim, in_dim, dtype=torch.float32))
        init.xavier_uniform_(self.weight)

    def forward(self, x, Adj):
        """Return one L2-normalised embedding row per node.

        Args:
            x: 2-D node-feature matrix, one row per node.
            Adj: (N, N) adjacency with 1.0 marking an edge (as built by calcADJ).

        Returns:
            Tensor of shape (N, embed_dim).
        """
        neigh_feats = self.aggregate(x, Adj)
        if not self.gcn:
            combined = torch.cat([x, neigh_feats], dim=1)
        else:
            combined = neigh_feats
        combined = F.relu(self.weight.mm(combined.T)).T
        combined = F.normalize(combined, 2, 1)
        return combined

    def aggregate(self, x, Adj):
        """Aggregate each node's neighbour features according to ``self.policy``.

        Nodes without any neighbour aggregate to a zero vector under both
        policies (instead of NaN for 'mean' / an exception for 'max').

        Raises:
            ValueError: if ``self.policy`` is neither 'mean' nor 'max'.
        """
        # Detached view of the adjacency; equivalent to the legacy Variable(...) wrap.
        adj = Adj.detach()
        if not self.gcn:
            # NOTE(review): assumes Adj carries self-loops that must be removed here;
            # calcADJ itself does not add them — confirm with the caller.
            adj = adj - torch.eye(len(adj), device=adj.device)

        if self.policy == 'mean':
            num_neigh = adj.sum(1, keepdim=True)
            # A zero row sum would give 0/0 = NaN; dividing that row by 1 yields zeros.
            num_neigh = num_neigh.masked_fill(num_neigh == 0, 1)
            return adj.div(num_neigh).mm(x)

        if self.policy == 'max':
            rows = []
            for neighbours in adj == 1:
                idx = neighbours.nonzero(as_tuple=True)[0]
                if idx.numel() == 0:
                    # Isolated node: max over an empty set is undefined, use zeros.
                    rows.append(x.new_zeros(1, x[0].numel()))
                else:
                    rows.append(x[idx].max(dim=0).values.view(1, -1))
            return torch.cat(rows, 0)

        raise ValueError(
            f"unknown aggregation policy {self.policy!r}; expected 'mean' or 'max'"
        )
# --------------------------------------------------------------------------------
# /graph_construction.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.spatial import distance_matrix, minkowski_distance, distance 4 | def calcADJ(coord, k=8, distanceType='euclidean', pruneTag='NA'): 5 | r""" 6 | Calculate spatial Matrix directly use X/Y coordinates 7 | """ 8 | spatialMatrix=coord#.cpu().numpy() 9 | nodes=spatialMatrix.shape[0] 10 | Adj=torch.zeros((nodes,nodes)) 11 | for i in np.arange(spatialMatrix.shape[0]): 12 | tmp=spatialMatrix[i,:].reshape(1,-1) 13 | distMat = distance.cdist(tmp,spatialMatrix, distanceType) 14 | if k == 0: 15 | k = spatialMatrix.shape[0]-1 16 | res = distMat.argsort()[:k+1] 17 | tmpdist = distMat[0,res[0][1:k+1]] 18 | boundary = np.mean(tmpdist)+np.std(tmpdist) #optional 19 | for j in np.arange(1,k+1): 20 | # No prune 21 | if pruneTag == 'NA': 22 | Adj[i][res[0][j]]=1.0 23 | elif pruneTag == 'STD': 24 | if distMat[0,res[0][j]]<=boundary: 25 | Adj[i][res[0][j]]=1.0 26 | # Prune: only use nearest neighbor as exact grid: 6 in cityblock, 8 in euclidean 27 | elif pruneTag == 'Grid': 28 | if distMat[0,res[0][j]]<=2.0: 29 | Adj[i][res[0][j]]=1.0 30 | return Adj -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import scanpy as sc 4 | import anndata as ad 5 | from tqdm import tqdm 6 | from dataset import ViT_HER2ST, ViT_SKIN 7 | from scipy.stats import pearsonr,spearmanr 8 | from sklearn.cluster import KMeans 9 | from sklearn.metrics import adjusted_rand_score as ari_score 10 | from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score 11 | def pk_load(fold,mode='train',flatten=False,dataset='her2st',r=4,ori=True,adj=True,prune='Grid',neighs=4): 12 | assert dataset in ['her2st','cscc'] 13 | if dataset=='her2st': 14 | dataset = ViT_HER2ST( 15 | 
train=(mode=='train'),fold=fold,flatten=flatten, 16 | ori=ori,neighs=neighs,adj=adj,prune=prune,r=r 17 | ) 18 | elif dataset=='cscc': 19 | dataset = ViT_SKIN( 20 | train=(mode=='train'),fold=fold,flatten=flatten, 21 | ori=ori,neighs=neighs,adj=adj,prune=prune,r=r 22 | ) 23 | return dataset 24 | def test(model,test,device='cuda'): 25 | model=model.to(device) 26 | model.eval() 27 | preds=None 28 | ct=None 29 | gt=None 30 | loss=0 31 | with torch.no_grad(): 32 | for patch, position, exp, adj, *_, center in tqdm(test): 33 | patch, position, adj = patch.to(device), position.to(device), adj.to(device).squeeze(0) 34 | pred = model(patch, position, adj)[0] 35 | preds = pred.squeeze().cpu().numpy() 36 | ct = center.squeeze().cpu().numpy() 37 | gt = exp.squeeze().cpu().numpy() 38 | adata = ad.AnnData(preds) 39 | adata.obsm['spatial'] = ct 40 | adata_gt = ad.AnnData(gt) 41 | adata_gt.obsm['spatial'] = ct 42 | return adata,adata_gt 43 | def cluster(adata,label): 44 | idx=label!='undetermined' 45 | tmp=adata[idx] 46 | l=label[idx] 47 | sc.pp.pca(tmp) 48 | sc.tl.tsne(tmp) 49 | kmeans = KMeans(n_clusters=len(set(l)), init="k-means++", random_state=0).fit(tmp.obsm['X_pca']) 50 | p=kmeans.labels_.astype(str) 51 | lbl=np.full(len(adata),str(len(set(l)))) 52 | lbl[idx]=p 53 | adata.obs['kmeans']=lbl 54 | return p,round(ari_score(p,l),3) 55 | def get_R(data1,data2,dim=1,func=pearsonr): 56 | adata1=data1.X 57 | adata2=data2.X 58 | r1,p1=[],[] 59 | for g in range(data1.shape[dim]): 60 | if dim==1: 61 | r,pv=func(adata1[:,g],adata2[:,g]) 62 | elif dim==0: 63 | r,pv=func(adata1[g,:],adata2[g,:]) 64 | r1.append(r) 65 | p1.append(pv) 66 | r1=np.array(r1) 67 | p1=np.array(p1) 68 | return r1,p1 69 | -------------------------------------------------------------------------------- /run_trained_models.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": 
"053309c2-cb6d-4623-83b7-7b0d947998e6", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "[easydl] tensorflow not available!\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "import torch\n", 19 | "import numpy as np\n", 20 | "import pytorch_lightning as pl\n", 21 | "import torchvision.transforms as tf\n", 22 | "from tqdm import tqdm\n", 23 | "from predict import *\n", 24 | "from HIST2ST import *\n", 25 | "from dataset import ViT_HER2ST, ViT_SKIN\n", 26 | "from scipy.stats import pearsonr,spearmanr\n", 27 | "from torch.utils.data import DataLoader\n", 28 | "from pytorch_lightning.loggers import TensorBoardLogger\n", 29 | "from copy import deepcopy as dcp\n", 30 | "from collections import defaultdict as dfd\n", 31 | "from sklearn.metrics import adjusted_rand_score as ari_score\n", 32 | "from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "id": "b132bfb4-bd29-44fa-bbca-3790b0848fd4", 38 | "metadata": {}, 39 | "source": [ 40 | "# Data Loading" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 2, 46 | "id": "6281d8f2-e5f7-4c26-88bf-5256197b1418", 47 | "metadata": { 48 | "tags": [] 49 | }, 50 | "outputs": [], 51 | "source": [ 52 | "name=[*[f'A{i}' for i in range(2,7)],*[f'B{i}' for i in range(1,7)],\n", 53 | " *[f'C{i}' for i in range(1,7)],*[f'D{i}' for i in range(1,7)],\n", 54 | " *[f'E{i}' for i in range(1,4)],*[f'F{i}' for i in range(1,4)],*[f'G{i}' for i in range(1,4)]]\n", 55 | "patients = ['P2', 'P5', 'P9', 'P10']\n", 56 | "reps = ['rep1', 'rep2', 'rep3']\n", 57 | "skinname = []\n", 58 | "for i in patients:\n", 59 | " for j in reps:\n", 60 | " skinname.append(i+'_ST_'+j)\n", 61 | "device='cuda'\n", 62 | "tag='5-7-2-8-4-16-32'\n", 63 | "k,p,d1,d2,d3,h,c=map(lambda x:int(x),tag.split('-'))\n", 64 | "dropout=0.2\n", 65 | "random.seed(12000)\n", 66 | "np.random.seed(12000)\n", 67 | 
"torch.manual_seed(12000)\n", 68 | "torch.cuda.manual_seed(12000)\n", 69 | "torch.cuda.manual_seed_all(12000) \n", 70 | "torch.backends.cudnn.benchmark = False\n", 71 | "torch.backends.cudnn.deterministic = True" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "source": [ 77 | "\n", 78 | "# Hist2ST Prediction\n", 79 | "\n", 80 | "### To run the trained model, please select the trained model and replace the value of the variable fold with the number in the name of the selected trained model." 81 | ], 82 | "metadata": { 83 | "collapsed": false 84 | } 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 6, 89 | "id": "0f19d5d8-6510-493b-9287-ba7215155dc1", 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "fold=5\n", 94 | "data='her2st'\n", 95 | "prune='Grid' if data=='her2st' else 'NA'\n", 96 | "genes=171 if data=='cscc' else 785" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 7, 102 | "id": "76cfcfb9-3169-46da-9288-2e2be7942d84", 103 | "metadata": {}, 104 | "outputs": [ 105 | { 106 | "name": "stdout", 107 | "output_type": "stream", 108 | "text": [ 109 | "['B1']\n", 110 | "Loading imgs...\n", 111 | "Loading metadata...\n" 112 | ] 113 | }, 114 | { 115 | "name": "stderr", 116 | "output_type": "stream", 117 | "text": [ 118 | "100%|██████████████████████████████████| 1/1 [00:02<00:00, 2.85s/it]\n" 119 | ] 120 | } 121 | ], 122 | "source": [ 123 | "testset = pk_load(fold,'test',dataset=data,flatten=False,adj=True,ori=True,prune=prune)\n", 124 | "test_loader = DataLoader(testset, batch_size=1, num_workers=0, shuffle=False)\n", 125 | "label=testset.label[testset.names[0]]\n", 126 | "genes=785\n", 127 | "model=Hist2ST(\n", 128 | " depth1=d1, depth2=d2,depth3=d3,n_genes=genes, \n", 129 | " kernel_size=k, patch_size=p,\n", 130 | " heads=h, channel=c, dropout=0.2,\n", 131 | " zinb=0.25, nb=False,\n", 132 | " bake=5, lamb=0.5, \n", 133 | ")\n", 134 | "model.load_state_dict(torch.load(f'./model/{fold}-Hist2ST.ckpt'))\n", 
135 | "pred, gt = test(model, test_loader,'cuda')" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 7, 141 | "id": "4266cce3-6bc7-4f5b-81d5-d03d3ae089c4", 142 | "metadata": { 143 | "pycharm": { 144 | "name": "#%%\n" 145 | } 146 | }, 147 | "outputs": [ 148 | { 149 | "name": "stdout", 150 | "output_type": "stream", 151 | "text": [ 152 | "Pearson Correlation: 0.2887870599966082\n", 153 | "ARI: 0.431\n" 154 | ] 155 | } 156 | ], 157 | "source": [ 158 | "R=get_R(pred,gt)[0]\n", 159 | "print('Pearson Correlation:',np.nanmean(R))\n", 160 | "\n", 161 | "\n", 162 | "# clus,ARI=cluster(pred,label)\n", 163 | "# print('ARI:',ARI)\n" 164 | ] 165 | } 166 | ], 167 | "metadata": { 168 | "kernelspec": { 169 | "display_name": "Python [conda env:task]", 170 | "language": "python", 171 | "name": "conda-env-task-py" 172 | }, 173 | "language_info": { 174 | "codemirror_mode": { 175 | "name": "ipython", 176 | "version": 3 177 | }, 178 | "file_extension": ".py", 179 | "mimetype": "text/x-python", 180 | "name": "python", 181 | "nbconvert_exporter": "python", 182 | "pygments_lexer": "ipython3", 183 | "version": "3.7.0" 184 | }, 185 | "pycharm": { 186 | "stem_cell": { 187 | "cell_type": "raw", 188 | "source": [], 189 | "metadata": { 190 | "collapsed": false 191 | } 192 | } 193 | } 194 | }, 195 | "nbformat": 4, 196 | "nbformat_minor": 5 197 | } -------------------------------------------------------------------------------- /transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pandas as pd 4 | import scanpy as sc 5 | import torch.nn.functional as F 6 | from easydl import * 7 | from anndata import AnnData 8 | from torch import nn, einsum 9 | from scipy.stats import pearsonr 10 | from torch.autograd import Function 11 | from torch.autograd.variable import * 12 | from einops import rearrange, repeat 13 | from einops.layers.torch import Rearrange 14 | class 
SelectItem(nn.Module): 15 | def __init__(self, item_index): 16 | super(SelectItem, self).__init__() 17 | self.item_index = item_index 18 | 19 | def forward(self, inputs): 20 | return inputs[self.item_index] 21 | class PreNorm(nn.Module): 22 | def __init__(self, dim, fn): 23 | super().__init__() 24 | self.norm = nn.LayerNorm(dim) 25 | self.fn = fn 26 | def forward(self, x, **kwargs): 27 | return self.fn(self.norm(x), **kwargs) 28 | 29 | class FeedForward(nn.Module): 30 | def __init__(self, dim, hidden_dim, dropout = 0.): 31 | super().__init__() 32 | self.net = nn.Sequential( 33 | nn.Linear(dim, hidden_dim), 34 | nn.GELU(), 35 | nn.Dropout(dropout), 36 | nn.Linear(hidden_dim, dim), 37 | nn.Dropout(dropout) 38 | ) 39 | def forward(self, x): 40 | return self.net(x) 41 | 42 | class Attention(nn.Module): 43 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 44 | super().__init__() 45 | inner_dim = dim_head * heads 46 | project_out = not (heads == 1 and dim_head == dim) 47 | 48 | self.heads = heads 49 | self.scale = dim_head ** -0.5 50 | 51 | self.attend = nn.Softmax(dim = -1) 52 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 53 | 54 | self.to_out = nn.Sequential( 55 | nn.Linear(inner_dim, dim), 56 | nn.Dropout(dropout) 57 | ) if project_out else nn.Identity() 58 | 59 | # @get_local('attn') 60 | def forward(self, x): 61 | b, n, _, h = *x.shape, self.heads 62 | qkv = self.to_qkv(x).chunk(3, dim = -1) 63 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 64 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 65 | attn = self.attend(dots) 66 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 67 | out = rearrange(out, 'b h n d -> b n (h d)') 68 | return self.to_out(out) 69 | 70 | class attn_block(nn.Module): 71 | def __init__(self, dim, heads, dim_head, mlp_dim, dropout = 0.): 72 | super().__init__() 73 | self.attn=PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)) 74 | 
self.ff=PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 75 | def forward(self, x): 76 | x = self.attn(x) + x 77 | x = self.ff(x) + x 78 | return x 79 | -------------------------------------------------------------------------------- /tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "053309c2-cb6d-4623-83b7-7b0d947998e6", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "[easydl] tensorflow not available!\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "import torch\n", 19 | "import numpy as np\n", 20 | "import pytorch_lightning as pl\n", 21 | "import torchvision.transforms as tf\n", 22 | "from tqdm import tqdm\n", 23 | "from predict import *\n", 24 | "from HIST2ST import *\n", 25 | "from dataset import ViT_HER2ST, ViT_SKIN\n", 26 | "from scipy.stats import pearsonr,spearmanr\n", 27 | "from torch.utils.data import DataLoader\n", 28 | "from pytorch_lightning.loggers import TensorBoardLogger\n", 29 | "from copy import deepcopy as dcp\n", 30 | "from collections import defaultdict as dfd\n", 31 | "from sklearn.metrics import adjusted_rand_score as ari_score\n", 32 | "from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "id": "b132bfb4-bd29-44fa-bbca-3790b0848fd4", 38 | "metadata": {}, 39 | "source": [ 40 | "# Data Loading" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 2, 46 | "id": "6281d8f2-e5f7-4c26-88bf-5256197b1418", 47 | "metadata": { 48 | "tags": [] 49 | }, 50 | "outputs": [], 51 | "source": [ 52 | "name=[*[f'A{i}' for i in range(2,7)],*[f'B{i}' for i in range(1,7)],\n", 53 | " *[f'C{i}' for i in range(1,7)],*[f'D{i}' for i in range(1,7)],\n", 54 | " *[f'E{i}' for i in range(1,4)],*[f'F{i}' for i in range(1,4)],*[f'G{i}' for i in range(1,4)]]\n", 
55 | "patients = ['P2', 'P5', 'P9', 'P10']\n", 56 | "reps = ['rep1', 'rep2', 'rep3']\n", 57 | "skinname = []\n", 58 | "for i in patients:\n", 59 | " for j in reps:\n", 60 | " skinname.append(i+'_ST_'+j)\n", 61 | "device='cuda'\n", 62 | "tag='5-7-2-8-4-16-32'\n", 63 | "k,p,d1,d2,d3,h,c=map(lambda x:int(x),tag.split('-'))\n", 64 | "dropout=0.2\n", 65 | "random.seed(12000)\n", 66 | "np.random.seed(12000)\n", 67 | "torch.manual_seed(12000)\n", 68 | "torch.cuda.manual_seed(12000)\n", 69 | "torch.cuda.manual_seed_all(12000) \n", 70 | "torch.backends.cudnn.benchmark = False\n", 71 | "torch.backends.cudnn.deterministic = True" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 3, 77 | "id": "0030c0ad-e445-42a5-97b1-135a222f5e82", 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "['B1']\n", 85 | "Loading imgs...\n", 86 | "Loading metadata...\n" 87 | ] 88 | } 89 | ], 90 | "source": [ 91 | "fold=5\n", 92 | "data='her2st'\n", 93 | "prune='Grid' if data=='her2st' else 'NA'\n", 94 | "genes=171 if data=='cscc' else 785\n", 95 | "trainset = pk_load(fold,'train',dataset=data,flatten=False,adj=True,ori=True,prune=prune)\n", 96 | "train_loader = DataLoader(trainset, batch_size=1, num_workers=0, shuffle=True)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "id": "2febd49c-f377-4211-9133-39da1ba5909e", 102 | "metadata": {}, 103 | "source": [ 104 | "# Hist2ST training" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 4, 110 | "id": "82bc6cee-8b9a-4c43-80b5-1c68feb3c550", 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "model=Hist2ST(\n", 115 | " depth1=d1, depth2=d2,depth3=d3,n_genes=genes,\n", 116 | " kernel_size=k, patch_size=p,\n", 117 | " heads=h, channel=c, dropout=0.2,\n", 118 | " zinb=0.25, nb=False,\n", 119 | " bake=5, lamb=0.5, \n", 120 | ")" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 6, 126 | 
"id": "3309e1d3-528b-46cd-bcba-2f2cc4020e53", 127 | "metadata": { 128 | "tags": [] 129 | }, 130 | "outputs": [ 131 | { 132 | "name": "stderr", 133 | "output_type": "stream", 134 | "text": [ 135 | "GPU available: True, used: True\n", 136 | "TPU available: False, using: 0 TPU cores\n", 137 | "IPU available: False, using: 0 IPUs\n", 138 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]\n", 139 | "\n", 140 | " | Name | Type | Params\n", 141 | "-----------------------------------------------\n", 142 | "0 | patch_embedding | Conv2d | 4.7 K \n", 143 | "1 | x_embed | Embedding | 65.5 K\n", 144 | "2 | y_embed | Embedding | 65.5 K\n", 145 | "3 | vit | ViT | 71.4 M\n", 146 | "4 | mean | Sequential | 804 K \n", 147 | "5 | disp | Sequential | 804 K \n", 148 | "6 | pi | Sequential | 804 K \n", 149 | "7 | coef | Sequential | 1.1 M \n", 150 | "8 | gene_head | Sequential | 806 K \n", 151 | "-----------------------------------------------\n", 152 | "75.8 M Trainable params\n", 153 | "0 Non-trainable params\n", 154 | "75.8 M Total params\n", 155 | "303.159 Total estimated model params size (MB)\n" 156 | ] 157 | }, 158 | { 159 | "data": { 160 | "application/vnd.jupyter.widget-view+json": { 161 | "model_id": "3a07cb1dd99e45c88839b3c19b0285de", 162 | "version_major": 2, 163 | "version_minor": 0 164 | }, 165 | "text/plain": [ 166 | "Training: 0it [00:00, ?it/s]" 167 | ] 168 | }, 169 | "metadata": {}, 170 | "output_type": "display_data" 171 | } 172 | ], 173 | "source": [ 174 | "logger=None\n", 175 | "trainer = pl.Trainer(\n", 176 | " gpus=[0], max_epochs=350,\n", 177 | " logger=logger,\n", 178 | ")\n", 179 | "trainer.fit(model, train_loader)\n", 180 | "\n", 181 | "import os\n", 182 | "if not os.path.isdir(\"./model/\"):\n", 183 | " os.mkdir(\"./model/\")\n", 184 | "\n", 185 | "torch.save(model.state_dict(),f\"./model/{fold}-Hist2ST.ckpt\")" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "id": "45afd1c5-2495-4240-b3b6-92b415507582", 191 | "metadata": {}, 192 | "source": [ 
193 | "# Hist2ST Prediction" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 6, 199 | "id": "0f19d5d8-6510-493b-9287-ba7215155dc1", 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "fold=5\n", 204 | "data='her2st'\n", 205 | "prune='Grid' if data=='her2st' else 'NA'\n", 206 | "genes=171 if data=='cscc' else 785" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 7, 212 | "id": "76cfcfb9-3169-46da-9288-2e2be7942d84", 213 | "metadata": {}, 214 | "outputs": [ 215 | { 216 | "name": "stdout", 217 | "output_type": "stream", 218 | "text": [ 219 | "['B1']\n", 220 | "Loading imgs...\n", 221 | "Loading metadata...\n" 222 | ] 223 | }, 224 | { 225 | "name": "stderr", 226 | "output_type": "stream", 227 | "text": [ 228 | "100%|██████████████████████████████████| 1/1 [00:02<00:00, 2.85s/it]\n" 229 | ] 230 | } 231 | ], 232 | "source": [ 233 | "testset = pk_load(fold,'test',dataset=data,flatten=False,adj=True,ori=True,prune=prune)\n", 234 | "test_loader = DataLoader(testset, batch_size=1, num_workers=0, shuffle=False)\n", 235 | "label=testset.label[testset.names[0]]\n", 236 | "genes=785\n", 237 | "model=Hist2ST(\n", 238 | " depth1=d1, depth2=d2,depth3=d3,n_genes=genes, \n", 239 | " kernel_size=k, patch_size=p,\n", 240 | " heads=h, channel=c, dropout=0.2,\n", 241 | " zinb=0.25, nb=False,\n", 242 | " bake=5, lamb=0.5, \n", 243 | ")\n", 244 | "model.load_state_dict(torch.load(f'./model/{fold}-Hist2ST.ckpt'))\n", 245 | "pred, gt = test(model, test_loader,'cuda')" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 7, 251 | "id": "4266cce3-6bc7-4f5b-81d5-d03d3ae089c4", 252 | "metadata": {}, 253 | "outputs": [ 254 | { 255 | "name": "stdout", 256 | "output_type": "stream", 257 | "text": [ 258 | "Pearson Correlation: 0.2887870599966082\n", 259 | "ARI: 0.431\n" 260 | ] 261 | } 262 | ], 263 | "source": [ 264 | "R=get_R(pred,gt)[0]\n", 265 | "print('Pearson Correlation:',np.nanmean(R))\n", 266 | 
"clus,ARI=cluster(pred,label)\n", 267 | "print('ARI:',ARI)" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 10, 273 | "id": "6f613201-90c3-4cbf-be06-008b63155c29", 274 | "metadata": {}, 275 | "outputs": [ 276 | { 277 | "name": "stderr", 278 | "output_type": "stream", 279 | "text": [ 280 | "... storing 'kmeans' as categorical\n" 281 | ] 282 | }, 283 | { 284 | "data": { 285 | "image/png": "\n", 286 | "text/plain": [ 287 | "
" 288 | ] 289 | }, 290 | "metadata": {}, 291 | "output_type": "display_data" 292 | } 293 | ], 294 | "source": [ 295 | "sc.pl.spatial(pred, img=None, color='kmeans', spot_size=112)" 296 | ] 297 | } 298 | ], 299 | "metadata": { 300 | "kernelspec": { 301 | "display_name": "Python [conda env:task]", 302 | "language": "python", 303 | "name": "conda-env-task-py" 304 | }, 305 | "language_info": { 306 | "codemirror_mode": { 307 | "name": "ipython", 308 | "version": 3 309 | }, 310 | "file_extension": ".py", 311 | "mimetype": "text/x-python", 312 | "name": "python", 313 | "nbconvert_exporter": "python", 314 | "pygments_lexer": "ipython3", 315 | "version": "3.7.0" 316 | }, 317 | "pycharm": { 318 | "stem_cell": { 319 | "cell_type": "raw", 320 | "source": [], 321 | "metadata": { 322 | "collapsed": false 323 | } 324 | } 325 | } 326 | }, 327 | "nbformat": 4, 328 | "nbformat_minor": 5 329 | } -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import pickle 5 | import numpy as np 6 | import pandas as pd 7 | import scprep as scp 8 | import anndata as ann 9 | import seaborn as sns 10 | import matplotlib.pyplot as plt 11 | import scanpy as sc, anndata as ad 12 | from os import name 13 | from PIL import Image 14 | from sklearn import preprocessing 15 | from sklearn.cluster import KMeans 16 | Image.MAX_IMAGE_PIXELS = 933120000 17 | # from dataset import MARKERS 18 | BCELL = ['CD19', 'CD79A', 'CD79B', 'MS4A1'] 19 | TUMOR = ['FASN'] 20 | CD4T = ['CD4'] 21 | CD8T = ['CD8A', 'CD8B'] 22 | DC = ['CLIC2', 'CLEC10A', 'CD1B', 'CD1A', 'CD1E'] 23 | MDC = ['LAMP3'] 24 | CMM = ['BRAF', 'KRAS'] 25 | IG = {'B_cell':BCELL, 'Tumor':TUMOR, 'CD4+T_cell':CD4T, 'CD8+T_cell':CD8T, 'Dendritic_cells':DC, 26 | 'Mature_dendritic_cells':MDC, 'Cutaneous_Malignant_Melanoma':CMM} 27 | MARKERS = [] 28 | for i in IG.values(): 29 | MARKERS+=i 30 | 
LYM = {'B_cell':BCELL, 'CD4+T_cell':CD4T, 'CD8+T_cell':CD8T}


def read_tiff(path):
    """Open the image at *path* with PIL and return the Image object.

    PIL's decompression-bomb pixel limit is raised first so very large slide
    images can be opened. The image is not decoded here.
    """
    Image.MAX_IMAGE_PIXELS = 933120000
    return Image.open(path)


def preprocess(adata, n_keep=1000, include=LYM, g=True):
    """Normalise and log-transform *adata*, select features, standardise coordinates.

    Feature selection, in order of precedence:
      * ``g``: keep the genes listed in ``data/skin_a.npy``;
      * ``include``: replace X by one column per key of ``include`` holding the
        per-spot mean expression of that key's genes;
      * otherwise: keep the ``n_keep`` most highly variable genes.

    ``obsm['position_norm']`` receives ``obsm['spatial']`` standardised with
    sklearn's StandardScaler. Returns the (possibly new) AnnData object.
    """
    adata.var_names_make_unique()
    sc.pp.normalize_total(adata)
    sc.pp.log1p(adata)
    if g:
        # NOTE(review): 'data/skin_a.npy' is not among the files shipped in data/.
        genes = list(np.load('data/skin_a.npy', allow_pickle=True))
        adata = adata[:, genes]
    elif include:
        exp = np.zeros((adata.X.shape[0], len(include)))
        for col, genes in enumerate(include.values()):
            # asarray + ravel yields a flat per-spot vector for dense and sparse X alike.
            exp[:, col] = np.asarray(np.mean(adata[:, genes].X, 1)).ravel()
        adata = adata[:, :len(include)]
        adata.X = exp
        adata.var_names = list(include.keys())
    else:
        sc.pp.highly_variable_genes(adata, n_top_genes=n_keep, subset=True)
    coords = adata.obsm['spatial']
    scaler = preprocessing.StandardScaler().fit(coords)
    adata.obsm['position_norm'] = scaler.transform(coords)
    return adata


def comp_umap(adata):
    """Run PCA, neighbours, UMAP and Leiden (stored in obs['clusters']) in place; return adata."""
    sc.pp.pca(adata)
    sc.pp.neighbors(adata)
    sc.tl.umap(adata)
    sc.tl.leiden(adata, key_added="clusters")
    return adata


def comp_tsne_km(adata, k=10):
    """Run PCA + t-SNE and k-means (k clusters on X_pca) into obs['kmeans']; return adata."""
    sc.pp.pca(adata)
    sc.tl.tsne(adata)
    kmeans = KMeans(n_clusters=k, init="k-means++", random_state=0).fit(adata.obsm['X_pca'])
    adata.obs['kmeans'] = kmeans.labels_.astype(str)
    return adata


def co_embed(a, b, k=10):
    """Jointly embed ground truth *a* and prediction *b* and k-means cluster them.

    Side effect: sets ``obs['tag']`` to 'Truth' on *a* and 'Pred' on *b*.
    Returns the concatenated AnnData with t-SNE and ``obs['kmeans']``.
    """
    a.obs['tag'] = 'Truth'
    b.obs['tag'] = 'Pred'
    adata = ad.concat([a, b])
    sc.pp.pca(adata)
    sc.tl.tsne(adata)
    kmeans = KMeans(n_clusters=k, init="k-means++", random_state=0).fit(adata.obsm['X_pca'])
    adata.obs['kmeans'] = kmeans.labels_.astype(str)
    return adata


def build_adata(name='H1', gene_list_path='data/her_g_list.npy'):
    """Build a log-normalised AnnData for HER2ST section *name* plus its image.

    The count matrix is joined to the spot-selection file on the
    ``"<x>x<y>"`` spot id, restricted to the genes stored in
    ``gene_list_path``, and floored pixel coordinates go to ``obsm['spatial']``.

    Args:
        name: section id such as ``'H1'``; its first letter names the image sub-folder.
        gene_list_path: ``.npy`` file holding the gene names to keep.
            NOTE(review): the default file is not among those shipped in data/.

    Returns:
        ``(adata, im)`` — the AnnData and the PIL image of the section.
    """
    cnt_dir = 'data/her2st/data/ST-cnts'
    img_dir = 'data/her2st/data/ST-imgs'
    pos_dir = 'data/her2st/data/ST-spotfiles'

    pre = img_dir+'/'+name[0]+'/'+name
    fig_name = os.listdir(pre)[0]
    im = Image.open(pre+'/'+fig_name)

    cnt = pd.read_csv(cnt_dir+'/'+name+'.tsv', sep='\t', index_col=0)
    df = pd.read_csv(pos_dir+'/'+name+'_selection.tsv', sep='\t')

    # Round like the dataset's get_pos so float coordinates still give integer "<x>x<y>" ids.
    xs = np.around(df['x'].values).astype(int)
    ys = np.around(df['y'].values).astype(int)
    df['id'] = [f'{x}x{y}' for x, y in zip(xs, ys)]

    # Inner join (as get_meta does): spots without a position row would otherwise
    # get NaN pixel coordinates that astype(int) below turns into garbage.
    meta = cnt.join(df.set_index('id'), how='inner')

    gene_list = list(np.load(gene_list_path))
    adata = ann.AnnData(scp.transform.log(scp.normalize.library_size_normalize(meta[gene_list].values)))
    adata.var_names = gene_list
    adata.obsm['spatial'] = np.floor(meta[['pixel_x','pixel_y']].values).astype(int)

    return adata, im


def get_data(dataset='bc1', n_keep=1000, include=LYM, g=True):
    """Download a 10x Visium sample, preprocess it and return ``(adata, img_path)``.

    ``'bc1'`` / ``'bc2'`` are shorthands for the two breast-cancer block A
    sections; any other value is used as the Visium sample id directly.
    """
    sample_id = {
        'bc1': 'V1_Breast_Cancer_Block_A_Section_1',
        'bc2': 'V1_Breast_Cancer_Block_A_Section_2',
    }.get(dataset, dataset)
    adata = sc.datasets.visium_sge(sample_id=sample_id, include_hires_tiff=True)
    adata = preprocess(adata, n_keep, include, g)
    img_path = adata.uns["spatial"][sample_id]["metadata"]["source_image_path"]
    return adata, img_path


def find_resolution(adata_, n_clusters, random=666):
    """Bisect the Louvain resolution in [0, 1000] until it yields *n_clusters* clusters.

    Stops after at most 50 Louvain runs and returns the last resolution tried.
    *adata_* must already have a neighbour graph; it is not modified (copy=True).
    """
    obtained_clusters = -1
    iteration = 0
    resolutions = [0., 1000.]
    while obtained_clusters != n_clusters and iteration < 50:
        current_res = sum(resolutions) / 2
        adata = sc.tl.louvain(adata_, resolution=current_res, random_state=random, copy=True)
        labels = adata.obs['louvain']
        obtained_clusters = len(np.unique(labels))

        # Too few clusters -> raise the lower bound; otherwise lower the upper bound.
        if obtained_clusters < n_clusters:
            resolutions[0] = current_res
        else:
            resolutions[1] = current_res

        iteration = iteration + 1

    return current_res


def get_center_labels(features, resolution=0.1):
    """Louvain-cluster *features* and return ``(centroids, labels)``.

    ``labels`` is an int array with one cluster id per row; ``centroids`` holds
    the per-cluster mean feature vector, ordered by cluster id.
    """
    adata0 = ad.AnnData(features)
    sc.pp.neighbors(adata0, n_neighbors=15, use_rep="X")
    adata0 = sc.tl.louvain(adata0, resolution=resolution, random_state=0, copy=True)
    y_pred = np.asarray(adata0.obs['louvain'], dtype=int)

    index = np.arange(0, adata0.shape[0])
    features = pd.DataFrame(adata0.X, index=index)
    Group = pd.Series(y_pred, index=index, name="Group")
    Mergefeature = pd.concat([features, Group], axis=1)

    init_centroid = np.asarray(Mergefeature.groupby("Group").mean())
    return init_centroid, y_pred


def lvcluster(adata1, label):
    """Louvain-cluster a copy of *adata1* into as many clusters as *label* has distinct values."""
    adata = adata1.copy()
    n_clusters = len(set(label))
    sc.pp.neighbors(adata, n_neighbors=45, use_rep="X")
    resolution = find_resolution(adata, n_clusters)
    _, cluster_labels_cpu = get_center_labels(adata.X, resolution=resolution)
    return cluster_labels_cpu


def normalize(
    adata, filter_min_counts=False, size_factors=True,
    normalize_input=True, logtrans_input=True, hvg=True
):
    """Standard count preprocessing; returns the (possibly subset) AnnData.

    Optionally filters empty genes/cells, stores the pre-normalisation data in
    ``adata.raw``, records per-cell size factors (library size over its median,
    or 1.0), log1p-transforms, scales, and keeps the top 2000 highly variable genes.
    """
    if filter_min_counts:
        sc.pp.filter_genes(adata, min_counts=1)
        sc.pp.filter_cells(adata, min_counts=1)

    # Snapshot the un-normalised data before any in-place transform below.
    if size_factors or normalize_input or logtrans_input:
        adata.raw = adata.copy()
    else:
        adata.raw = adata

    if size_factors:
        sc.pp.normalize_per_cell(adata)
        adata.obs['size_factors'] = adata.obs.n_counts / np.median(adata.obs.n_counts)
    else:
        adata.obs['size_factors'] = 1.0

    if logtrans_input:
        sc.pp.log1p(adata)

    if normalize_input:
        sc.pp.scale(adata)

    if hvg:
        sc.pp.highly_variable_genes(adata, n_top_genes=2000)  # min_mean=0.0125, max_mean=3, min_disp=0.5
        adata = adata[:, adata.var.highly_variable]

    return adata


if __name__ == '__main__':

    adata, img_path = get_data()
    print(adata.X.toarray())
# --------------------------------------------------------------------------------