├── LICENSE ├── README.md ├── Workflow.PNG ├── data ├── download.sh └── her_hvg_cut_1000.npy ├── dataset.py ├── model └── download_checkpoint.txt ├── predict.py ├── transformer.py ├── tutorial.ipynb ├── utils.py └── vis_model.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 maxpmx 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Leveraging information in spatial transcriptomics to predict super-resolution gene expression from histology images in tumors 2 | ### Minxing Pang, Kenong Su*, Mingyao Li* 3 | HisToGene is a deep learning method that predicts super-resolution gene expression from histology images in tumors. Trained on a spatial transcriptomics dataset, HisToGene models the spatial dependencies in gene expression and histological features among spots through a modified Vision Transformer. [[bioRxiv]](https://doi.org/10.1101/2021.11.28.470212) 4 | 5 | # Usage 6 | ```python 7 | import torch 8 | from vis_model import HisToGene 9 | 10 | model = HisToGene( 11 | n_genes=1000, 12 | patch_size=112, 13 | n_layers=4, 14 | dim=1024, 15 | learning_rate=1e-5, 16 | dropout=0.1, 17 | n_pos=64 18 | ) 19 | 20 | # flatten_patches: [B, N, 3*W*H] 21 | # coordinates: [B, N, 2], integer grid positions in [0, n_pos) 22 | 23 | pred_expression = model(flatten_patches, coordinates) # [B, N, n_genes] 24 | 25 | ``` 26 | 27 | ## System environment 28 | Required packages: 29 | - PyTorch >= 1.8 30 | - pytorch-lightning >= 1.4 31 | - scanpy >= 1.8 32 | 33 | ## Parameters 34 | - `n_genes`: int. 35 | Number of genes to predict per spot. 36 | - `patch_size`: int, default `112`. 37 | Width/height in pixels of the square patch cropped around each spot. 38 | - `n_layers`: int, default `4`. 39 | Number of Transformer blocks. 40 | - `dim`: int, default `1024`. 41 | Dimension of the embeddings. 42 | - `learning_rate`: float, default `1e-4`. 43 | Learning rate. 44 | - `dropout`: float between `[0, 1]`, default `0.1`. 45 | Dropout rate in the Transformer. 46 | - `n_pos`: int, default `64`. 47 | Size of the positional-embedding table; every spot coordinate must be an integer in `[0, n_pos)`.
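As a quick sanity check, the model can be run end to end on random tensors (a minimal sketch; the batch and spot counts are arbitrary examples):

```python
import torch
from vis_model import HisToGene

model = HisToGene(n_genes=1000, patch_size=112, n_pos=64)

B, N = 1, 300                                       # one slide with 300 spots
flatten_patches = torch.rand(B, N, 3 * 112 * 112)   # flattened RGB patches
coordinates = torch.randint(0, 64, (B, N, 2))       # integer grid indices < n_pos

with torch.no_grad():
    pred_expression = model(flatten_patches, coordinates)
print(pred_expression.shape)  # torch.Size([1, 300, 1000])
```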
48 | 49 | # HisToGene pipeline 50 | See [tutorial.ipynb](tutorial.ipynb) 51 | 52 | # References 53 | https://github.com/almaan/her2st 54 | -------------------------------------------------------------------------------- /Workflow.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpmx/HisToGene/44ff75b9b68949796f27f1b4121d36ea4cba8bf3/Workflow.PNG -------------------------------------------------------------------------------- /data/download.sh: -------------------------------------------------------------------------------- 1 | git clone https://github.com/almaan/her2st.git 2 | -------------------------------------------------------------------------------- /data/her_hvg_cut_1000.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpmx/HisToGene/44ff75b9b68949796f27f1b4121d36ea4cba8bf3/data/her_hvg_cut_1000.npy -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from utils import read_tiff 6 | import numpy as np 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | import scanpy as sc 10 | from utils import get_data 11 | import os 12 | import glob 13 | from PIL import Image 14 | import pandas as pd 15 | import scprep as scp 16 | from PIL import ImageFile 17 | import seaborn as sns 18 | import matplotlib.pyplot as plt 19 | ImageFile.LOAD_TRUNCATED_IMAGES = True 20 | Image.MAX_IMAGE_PIXELS = None 21 | 22 | BCELL = ['CD19', 'CD79A', 'CD79B', 'MS4A1'] 23 | TUMOR = ['FASN'] 24 | CD4T = ['CD4'] 25 | CD8T = ['CD8A', 'CD8B'] 26 | DC = ['CLIC2', 'CLEC10A', 'CD1B', 'CD1A', 'CD1E'] 27 | MDC = ['LAMP3'] 28 | CMM = ['BRAF', 'KRAS'] 29 | IG = {'B_cell':BCELL, 'Tumor':TUMOR, 'CD4+T_cell':CD4T, 'CD8+T_cell':CD8T, 'Dendritic_cells':DC, 30 | 'Mature_dendritic_cells':MDC, 'Cutaneous_Malignant_Melanoma':CMM} 31 | MARKERS = [] 32 | for i in IG.values(): 33 | MARKERS+=i 34 | LYM = {'B_cell':BCELL, 'CD4+T_cell':CD4T, 'CD8+T_cell':CD8T} 35 | 36 | class STDataset(torch.utils.data.Dataset): 37 | """Some Information about STDataset""" 38 | def __init__(self, adata, img_path, diameter=177.5, train=True): 39 | super(STDataset, self).__init__() 40 | 41 | self.exp = adata.X.toarray() 42 | self.im = read_tiff(img_path) 43 | self.r = np.ceil(diameter/2).astype(int) 44 | self.train = train 45 | # self.d_spot = self.d_spot if self.d_spot%2==0 else self.d_spot+1 46 | self.transforms = transforms.Compose([ 47 | transforms.ColorJitter(0.5,0.5,0.5), 48 | transforms.RandomHorizontalFlip(), 49 | transforms.RandomRotation(degrees=180), 50 | transforms.ToTensor() 51 | ]) 52 | self.centers = adata.obsm['spatial'] 53 | self.pos = adata.obsm['position_norm'] 54 | def __getitem__(self, index): 55 | exp = self.exp[index] 56 | center = self.centers[index] 57 | x, y = center 58 | patch = self.im.crop((x-self.r, y-self.r, x+self.r, y+self.r)) 59 | exp = torch.Tensor(exp) 60 | mask = exp!=0 61 | mask = mask.float() 62 | if self.train: 63 | patch = self.transforms(patch) 64 | pos = torch.Tensor(self.pos[index]) 65 | return patch, pos, exp, mask 66 | 67 | def __len__(self): 68 | return len(self.centers) 69 | 70 | 71 | 72 | class HER2ST(torch.utils.data.Dataset): 73 | """Some Information about HER2ST""" 74 | def 
__init__(self,train=True,gene_list=None,ds=None,fold=0): 75 | super(HER2ST, self).__init__() 76 | self.cnt_dir = 'data/her2st/data/ST-cnts' 77 | self.img_dir = 'data/her2st/data/ST-imgs' 78 | self.pos_dir = 'data/her2st/data/ST-spotfiles' 79 | self.lbl_dir = 'data/her2st/data/ST-pat/lbl' 80 | self.r = 224//2 81 | # gene_list = list(np.load('data/her_hvg.npy',allow_pickle=True)) 82 | gene_list = list(np.load('data/her_hvg_cut_1000.npy',allow_pickle=True)) 83 | self.gene_list = gene_list 84 | names = os.listdir(self.cnt_dir) 85 | names.sort() 86 | names = [i[:2] for i in names] 87 | self.train = train 88 | 89 | # samples = ['A1','B1','C1','D1','E1','F1','G2','H1'] 90 | samples = names[1:33] 91 | te_names = [samples[fold]] 92 | tr_names = list(set(samples)-set(te_names)) 93 | if train: 94 | # names = names[1:33] 95 | # names = names[1:33] if self.cls==False else ['A1','B1','C1','D1','E1','F1','G2'] 96 | names = tr_names 97 | else: 98 | # names = [names[33]] 99 | # names = ['A1'] 100 | # names = [ds] if ds else ['H1'] 101 | names = te_names 102 | print('Loading imgs...') 103 | self.img_dict = {i:self.get_img(i) for i in names} 104 | print('Loading metadata...') 105 | self.meta_dict = {i:self.get_meta(i) for i in names} 106 | 107 | # self.gene_set = self.get_overlap(self.meta_dict,gene_list) 108 | # print(len(self.gene_set)) 109 | # np.save('data/her_hvg',self.gene_set) 110 | # quit() 111 | self.gene_set = list(gene_list) 112 | self.exp_dict = {i:scp.transform.log(scp.normalize.library_size_normalize(m[self.gene_set].values)) for i,m in self.meta_dict.items()} 113 | self.center_dict = {i:np.floor(m[['pixel_x','pixel_y']].values).astype(int) for i,m in self.meta_dict.items()} 114 | self.loc_dict = {i:m[['x','y']].values for i,m in self.meta_dict.items()} 115 | 116 | 117 | self.lengths = [len(i) for i in self.meta_dict.values()] 118 | self.cumlen = np.cumsum(self.lengths) 119 | self.id2name = dict(enumerate(names)) 120 | 121 | self.transforms = transforms.Compose([ 122 | transforms.ColorJitter(0.5,0.5,0.5), 123 | transforms.RandomHorizontalFlip(), 124 | transforms.RandomRotation(degrees=180), 125 | transforms.ToTensor() 126 | ]) 127 | def __getitem__(self, index): 128 | i = 0 129 | while index>=self.cumlen[i]: 130 | i += 1 131 | idx = index 132 | if i > 0: 133 | idx = index - self.cumlen[i-1] 134 | 135 | exp = self.exp_dict[self.id2name[i]][idx] 136 | center = self.center_dict[self.id2name[i]][idx] 137 | loc = self.loc_dict[self.id2name[i]][idx] 138 | 139 | # if self.cls or self.train==False: 140 | 141 | exp = torch.Tensor(exp) 142 | loc = torch.Tensor(loc) 143 | 144 | x, y = center 145 | patch = self.img_dict[self.id2name[i]].crop((x-self.r, y-self.r, x+self.r, y+self.r)) 146 | if self.train: 147 | patch = self.transforms(patch) 148 | else: 149 | patch = transforms.ToTensor()(patch) 150 | 151 | if self.train: 152 | return patch, loc, exp 153 | else: 154 | return patch, loc, exp, torch.Tensor(center) 155 | 156 | def __len__(self): 157 | return self.cumlen[-1] 158 | 159 | def get_img(self,name): 160 | pre = self.img_dir+'/'+name[0]+'/'+name 161 | fig_name = os.listdir(pre)[0] 162 | path = pre+'/'+fig_name 163 | im = Image.open(path) 164 | return im 165 | 166 | def get_cnt(self,name): 167 | path = self.cnt_dir+'/'+name+'.tsv' 168 | df = pd.read_csv(path,sep='\t',index_col=0) 169 | return df 170 | 171 | def get_pos(self,name): 172 | path = self.pos_dir+'/'+name+'_selection.tsv' 173 | # path = self.pos_dir+'/'+name+'_labeled_coordinates.tsv' 174 | df = pd.read_csv(path,sep='\t') 175 | 176 | x = 
df['x'].values 177 | y = df['y'].values 178 | x = np.around(x).astype(int) 179 | y = np.around(y).astype(int) 180 | id = [] 181 | for i in range(len(x)): 182 | id.append(str(x[i])+'x'+str(y[i])) 183 | df['id'] = id 184 | 185 | return df 186 | 187 | def get_lbl(self,name): 188 | # path = self.pos_dir+'/'+name+'_selection.tsv' 189 | path = self.lbl_dir+'/'+name+'_labeled_coordinates.tsv' 190 | df = pd.read_csv(path,sep='\t') 191 | 192 | x = df['x'].values 193 | y = df['y'].values 194 | x = np.around(x).astype(int) 195 | y = np.around(y).astype(int) 196 | id = [] 197 | for i in range(len(x)): 198 | id.append(str(x[i])+'x'+str(y[i])) 199 | df['id'] = id 200 | df.drop('pixel_x', inplace=True, axis=1) 201 | df.drop('pixel_y', inplace=True, axis=1) 202 | df.drop('x', inplace=True, axis=1) 203 | df.drop('y', inplace=True, axis=1) 204 | 205 | return df 206 | 207 | def get_meta(self,name,gene_list=None): 208 | cnt = self.get_cnt(name) 209 | pos = self.get_pos(name) 210 | meta = cnt.join((pos.set_index('id'))) 211 | self.max_x = 0 212 | self.max_y = 0 213 | loc = meta[['x','y']].values 214 | self.max_x = max(self.max_x, loc[:,0].max()) 215 | self.max_y = max(self.max_y, loc[:,1].max()) 216 | return meta 217 | 218 | def get_overlap(self,meta_dict,gene_list): 219 | gene_set = set(gene_list) 220 | for i in meta_dict.values(): 221 | gene_set = gene_set&set(i.columns) 222 | return list(gene_set) 223 | 224 | class ViT_HER2ST(torch.utils.data.Dataset): 225 | """Some Information about HER2ST""" 226 | def __init__(self,train=True,gene_list=None,ds=None,sr=False,fold=0): 227 | super(ViT_HER2ST, self).__init__() 228 | 229 | self.cnt_dir = 'data/her2st/data/ST-cnts' 230 | self.img_dir = 'data/her2st/data/ST-imgs' 231 | self.pos_dir = 'data/her2st/data/ST-spotfiles' 232 | self.lbl_dir = 'data/her2st/data/ST-pat/lbl' 233 | self.r = 224//4 234 | 235 | # gene_list = list(np.load('data/her_hvg.npy',allow_pickle=True)) 236 | gene_list = list(np.load('data/her_hvg_cut_1000.npy',allow_pickle=True)) 237 | self.gene_list = gene_list 238 | names = os.listdir(self.cnt_dir) 239 | names.sort() 240 | names = [i[:2] for i in names] 241 | self.train = train 242 | self.sr = sr 243 | 244 | # samples = ['A1','B1','C1','D1','E1','F1','G2','H1'] 245 | samples = names[1:33] 246 | 247 | te_names = [samples[fold]] 248 | tr_names = list(set(samples)-set(te_names)) 249 | 250 | if train: 251 | # names = names[1:33] 252 | # names = names[1:33] if self.cls==False else ['A1','B1','C1','D1','E1','F1','G2'] 253 | names = tr_names 254 | else: 255 | # names = [names[33]] 256 | # names = ['A1'] 257 | # names = [ds] if ds else ['H1'] 258 | names = te_names 259 | 260 | print('Loading imgs...') 261 | self.img_dict = {i:torch.Tensor(np.array(self.get_img(i))) for i in names} 262 | print('Loading metadata...') 263 | self.meta_dict = {i:self.get_meta(i) for i in names} 264 | 265 | 266 | self.gene_set = list(gene_list) 267 | self.exp_dict = {i:scp.transform.log(scp.normalize.library_size_normalize(m[self.gene_set].values)) for i,m in self.meta_dict.items()} 268 | self.center_dict = {i:np.floor(m[['pixel_x','pixel_y']].values).astype(int) for i,m in self.meta_dict.items()} 269 | self.loc_dict = {i:m[['x','y']].values for i,m in self.meta_dict.items()} 270 | 271 | 272 | self.lengths = [len(i) for i in self.meta_dict.values()] 273 | self.cumlen = np.cumsum(self.lengths) 274 | self.id2name = dict(enumerate(names)) 275 | 276 | self.transforms = transforms.Compose([ 277 | transforms.ColorJitter(0.5,0.5,0.5), 278 | transforms.RandomHorizontalFlip(), 279 | 
transforms.RandomRotation(degrees=180), 280 | transforms.ToTensor() 281 | ]) 282 | 283 | def filter_helper(self): 284 | a = np.zeros(len(self.gene_list)) 285 | n = 0 286 | for i,exp in self.exp_dict.items(): 287 | n += exp.shape[0] 288 | exp[exp>0] = 1 289 | for j in range((len(self.gene_list))): 290 | a[j] += np.sum(exp[:,j]) 291 | 292 | 293 | def __getitem__(self, index): 294 | i = index 295 | im = self.img_dict[self.id2name[i]] 296 | im = im.permute(1,0,2) 297 | # im = torch.Tensor(np.array(self.im)) 298 | exps = self.exp_dict[self.id2name[i]] 299 | centers = self.center_dict[self.id2name[i]] 300 | loc = self.loc_dict[self.id2name[i]] 301 | positions = torch.LongTensor(loc) 302 | patch_dim = 3 * self.r * self.r * 4 303 | 304 | if self.sr: 305 | centers = torch.LongTensor(centers) 306 | max_x = centers[:,0].max().item() 307 | max_y = centers[:,1].max().item() 308 | min_x = centers[:,0].min().item() 309 | min_y = centers[:,1].min().item() 310 | r_x = (max_x - min_x)//30 311 | r_y = (max_y - min_y)//30 312 | 313 | centers = torch.LongTensor([min_x,min_y]).view(1,-1) 314 | positions = torch.LongTensor([0,0]).view(1,-1) 315 | x = min_x 316 | y = min_y 317 | 318 | while y < max_y: 319 | x = min_x 320 | while x < max_x: 321 | centers = torch.cat((centers,torch.LongTensor([x,y]).view(1,-1)),dim=0) 322 | positions = torch.cat((positions,torch.LongTensor([x//r_x,y//r_y]).view(1,-1)),dim=0) 323 | x += 56 324 | y += 56 325 | 326 | centers = centers[1:,:] 327 | positions = positions[1:,:] 328 | 329 | n_patches = len(centers) 330 | patches = torch.zeros((n_patches,patch_dim)) 331 | for i in range(n_patches): 332 | center = centers[i] 333 | x, y = center 334 | patch = im[(x-self.r):(x+self.r),(y-self.r):(y+self.r),:] 335 | patches[i] = patch.flatten() 336 | 337 | 338 | return patches, positions, centers 339 | 340 | else: 341 | n_patches = len(centers) 342 | 343 | patches = torch.zeros((n_patches,patch_dim)) 344 | exps = torch.Tensor(exps) 345 | 346 | 347 | for i in range(n_patches): 348 | center = centers[i] 349 | x, y = center 350 | patch = im[(x-self.r):(x+self.r),(y-self.r):(y+self.r),:] 351 | patches[i] = patch.flatten() 352 | 353 | if self.train: 354 | return patches, positions, exps 355 | else: 356 | return patches, positions, exps, torch.Tensor(centers) 357 | 358 | def __len__(self): 359 | return len(self.exp_dict) 360 | 361 | def get_img(self,name): 362 | pre = self.img_dir+'/'+name[0]+'/'+name 363 | fig_name = os.listdir(pre)[0] 364 | path = pre+'/'+fig_name 365 | im = Image.open(path) 366 | return im 367 | 368 | def get_cnt(self,name): 369 | path = self.cnt_dir+'/'+name+'.tsv' 370 | df = pd.read_csv(path,sep='\t',index_col=0) 371 | 372 | return df 373 | 374 | def get_pos(self,name): 375 | path = self.pos_dir+'/'+name+'_selection.tsv' 376 | # path = self.pos_dir+'/'+name+'_labeled_coordinates.tsv' 377 | df = pd.read_csv(path,sep='\t') 378 | 379 | x = df['x'].values 380 | y = df['y'].values 381 | x = np.around(x).astype(int) 382 | y = np.around(y).astype(int) 383 | id = [] 384 | for i in range(len(x)): 385 | id.append(str(x[i])+'x'+str(y[i])) 386 | df['id'] = id 387 | 388 | return df 389 | 390 | def get_lbl(self,name): 391 | # path = self.pos_dir+'/'+name+'_selection.tsv' 392 | path = self.lbl_dir+'/'+name+'_labeled_coordinates.tsv' 393 | df = pd.read_csv(path,sep='\t') 394 | 395 | x = df['x'].values 396 | y = df['y'].values 397 | x = np.around(x).astype(int) 398 | y = np.around(y).astype(int) 399 | id = [] 400 | for i in range(len(x)): 401 | id.append(str(x[i])+'x'+str(y[i])) 402 | df['id'] = 
id 403 | df.drop('pixel_x', inplace=True, axis=1) 404 | df.drop('pixel_y', inplace=True, axis=1) 405 | df.drop('x', inplace=True, axis=1) 406 | df.drop('y', inplace=True, axis=1) 407 | 408 | return df 409 | 410 | def get_meta(self,name,gene_list=None): 411 | cnt = self.get_cnt(name) 412 | pos = self.get_pos(name) 413 | meta = cnt.join((pos.set_index('id'))) 414 | 415 | return meta 416 | 417 | def get_overlap(self,meta_dict,gene_list): 418 | gene_set = set(gene_list) 419 | for i in meta_dict.values(): 420 | gene_set = gene_set&set(i.columns) 421 | return list(gene_set) 422 | 423 | 424 | class ViT_SKIN(torch.utils.data.Dataset): 425 | """Some Information about ViT_SKIN""" 426 | def __init__(self,train=True,gene_list=None,ds=None,sr=False,aug=False,norm=False,fold=0): 427 | super(ViT_SKIN, self).__init__() 428 | 429 | self.dir = '/ibex/scratch/pangm0a/spatial/data/GSE144240_RAW/' 430 | self.r = 224//4 431 | 432 | patients = ['P2', 'P5', 'P9', 'P10'] 433 | reps = ['rep1', 'rep2', 'rep3'] 434 | names = [] 435 | for i in patients: 436 | for j in reps: 437 | names.append(i+'_ST_'+j) 438 | test_names = ['P2_ST_rep2'] 439 | 440 | # gene_list = list(np.load('data/skin_hvg.npy',allow_pickle=True)) 441 | gene_list = list(np.load('data/skin_hvg_cut_1000.npy',allow_pickle=True)) 442 | # gene_list = list(np.load('figures/mse_2000-vit_skin_a.npy',allow_pickle=True)) 443 | 444 | self.gene_list = gene_list 445 | 446 | self.train = train 447 | self.sr = sr 448 | self.aug = aug 449 | self.transforms = transforms.Compose([ 450 | transforms.ColorJitter(0.5,0.5,0.5), 451 | transforms.ToTensor() 452 | ]) 453 | self.norm = norm 454 | 455 | samples = names 456 | te_names = [samples[fold]] 457 | tr_names = list(set(samples)-set(te_names)) 458 | 459 | if train: 460 | # names = names 461 | # names = names[3:] 462 | # names = test_names 463 | names = tr_names 464 | else: 465 | # names = [names[33]] 466 | # names = ['A1'] 467 | # names = test_names 468 | names = te_names 469 | 470 | print('Loading imgs...') 471 | if self.aug: 472 | self.img_dict = {i: self.get_img(i) for i in names} 473 | else: 474 | self.img_dict = {i:torch.Tensor(np.array(self.get_img(i))) for i in names} 475 | print('Loading metadata...') 476 | self.meta_dict = {i:self.get_meta(i) for i in names} 477 | 478 | self.gene_set = list(gene_list) 479 | if self.norm: 480 | self.exp_dict = {i:sc.pp.scale(scp.transform.log(scp.normalize.library_size_normalize(m[self.gene_set].values))) for i,m in self.meta_dict.items()} 481 | else: 482 | self.exp_dict = {i:scp.transform.log(scp.normalize.library_size_normalize(m[self.gene_set].values)) for i,m in self.meta_dict.items()} 483 | self.center_dict = {i:np.floor(m[['pixel_x','pixel_y']].values).astype(int) for i,m in self.meta_dict.items()} 484 | self.loc_dict = {i:m[['x','y']].values for i,m in self.meta_dict.items()} 485 | 486 | self.lengths = [len(i) for i in self.meta_dict.values()] 487 | self.cumlen = np.cumsum(self.lengths) 488 | self.id2name = dict(enumerate(names)) 489 | 490 | 491 | def filter_helper(self): 492 | a = np.zeros(len(self.gene_list)) 493 | n = 0 494 | for i,exp in self.exp_dict.items(): 495 | n += exp.shape[0] 496 | exp[exp>0] = 1 497 | for j in range((len(self.gene_list))): 498 | a[j] += np.sum(exp[:,j]) 499 | 500 | 501 | def __getitem__(self, index): 502 | i = index 503 | im = self.img_dict[self.id2name[i]] 504 | if self.aug: 505 | im = self.transforms(im) 506 | # im = im.permute(1,2,0) 507 | im = im.permute(2,1,0) 508 | else: 509 | im = im.permute(1,0,2) 510 | # im = im 511 | 512 | exps = 
self.exp_dict[self.id2name[i]] 513 | centers = self.center_dict[self.id2name[i]] 514 | loc = self.loc_dict[self.id2name[i]] 515 | positions = torch.LongTensor(loc) 516 | patch_dim = 3 * self.r * self.r * 4 517 | 518 | if self.sr: 519 | centers = torch.LongTensor(centers) 520 | max_x = centers[:,0].max().item() 521 | max_y = centers[:,1].max().item() 522 | min_x = centers[:,0].min().item() 523 | min_y = centers[:,1].min().item() 524 | r_x = (max_x - min_x)//30 525 | r_y = (max_y - min_y)//30 526 | 527 | centers = torch.LongTensor([min_x,min_y]).view(1,-1) 528 | positions = torch.LongTensor([0,0]).view(1,-1) 529 | x = min_x 530 | y = min_y 531 | 532 | while y < max_y: 533 | x = min_x 534 | while x < max_x: 535 | centers = torch.cat((centers,torch.LongTensor([x,y]).view(1,-1)),dim=0) 536 | positions = torch.cat((positions,torch.LongTensor([x//r_x,y//r_y]).view(1,-1)),dim=0) 537 | x += 56 538 | y += 56 539 | 540 | centers = centers[1:,:] 541 | positions = positions[1:,:] 542 | 543 | n_patches = len(centers) 544 | patches = torch.zeros((n_patches,patch_dim)) 545 | for i in range(n_patches): 546 | center = centers[i] 547 | x, y = center 548 | patch = im[(x-self.r):(x+self.r),(y-self.r):(y+self.r),:] 549 | patches[i] = patch.flatten() 550 | 551 | 552 | return patches, positions, centers 553 | 554 | else: 555 | n_patches = len(centers) 556 | 557 | patches = torch.zeros((n_patches,patch_dim)) 558 | exps = torch.Tensor(exps) 559 | 560 | for i in range(n_patches): 561 | center = centers[i] 562 | x, y = center 563 | patch = im[(x-self.r):(x+self.r),(y-self.r):(y+self.r),:] 564 | patches[i] = patch.flatten() 565 | 566 | 567 | if self.train: 568 | return patches, positions, exps 569 | else: 570 | return patches, positions, exps, torch.Tensor(centers) 571 | 572 | def __len__(self): 573 | return len(self.exp_dict) 574 | 575 | def get_img(self,name): 576 | path = glob.glob(self.dir+'*'+name+'.jpg')[0] 577 | im = Image.open(path) 578 | return im 579 | 580 | def get_cnt(self,name): 581 | path = glob.glob(self.dir+'*'+name+'_stdata.tsv')[0] 582 | df = pd.read_csv(path,sep='\t',index_col=0) 583 | return df 584 | 585 | def get_pos(self,name): 586 | path = glob.glob(self.dir+'*spot*'+name+'.tsv')[0] 587 | df = pd.read_csv(path,sep='\t') 588 | 589 | x = df['x'].values 590 | y = df['y'].values 591 | x = np.around(x).astype(int) 592 | y = np.around(y).astype(int) 593 | id = [] 594 | for i in range(len(x)): 595 | id.append(str(x[i])+'x'+str(y[i])) 596 | df['id'] = id 597 | 598 | return df 599 | 600 | def get_meta(self,name,gene_list=None): 601 | cnt = self.get_cnt(name) 602 | pos = self.get_pos(name) 603 | meta = cnt.join(pos.set_index('id'),how='inner') 604 | 605 | return meta 606 | 607 | def get_overlap(self,meta_dict,gene_list): 608 | gene_set = set(gene_list) 609 | for i in meta_dict.values(): 610 | gene_set = gene_set&set(i.columns) 611 | return list(gene_set) 612 | 613 | 614 | class SKIN(torch.utils.data.Dataset): 615 | """Some Information about ViT_SKIN""" 616 | def __init__(self,train=True,gene_list=None,ds=None,sr=False,fold=0): 617 | super(SKIN, self).__init__() 618 | 619 | self.dir = '/ibex/scratch/pangm0a/spatial/data/GSE144240_RAW/' 620 | self.r = 224//2 621 | 622 | patients = ['P2', 'P5', 'P9', 'P10'] 623 | reps = ['rep1', 'rep2', 'rep3'] 624 | names = [] 625 | for i in patients: 626 | for j in reps: 627 | names.append(i+'_ST_'+j) 628 | test_names = ['P2_ST_rep2'] 629 | 630 | gene_list = list(np.load('data/skin_hvg_cut_1000.npy',allow_pickle=True)) 631 | self.gene_list = gene_list 632 | 633 | 
self.train = train 634 | self.sr = sr 635 | 636 | samples = names 637 | te_names = [samples[fold]] 638 | tr_names = list(set(samples)-set(te_names)) 639 | 640 | if train: 641 | # names = names 642 | # names = names[3:] 643 | # names = test_names 644 | names = tr_names 645 | else: 646 | # names = [names[33]] 647 | # names = ['A1'] 648 | # names = test_names 649 | names = te_names 650 | 651 | print('Loading imgs...') 652 | self.img_dict = {i:self.get_img(i) for i in names} 653 | print('Loading metadata...') 654 | self.meta_dict = {i:self.get_meta(i) for i in names} 655 | 656 | self.gene_set = list(gene_list) 657 | self.exp_dict = {i:scp.transform.log(scp.normalize.library_size_normalize(m[self.gene_set].values)) for i,m in self.meta_dict.items()} 658 | self.center_dict = {i:np.floor(m[['pixel_x','pixel_y']].values).astype(int) for i,m in self.meta_dict.items()} 659 | self.loc_dict = {i:m[['x','y']].values for i,m in self.meta_dict.items()} 660 | 661 | self.lengths = [len(i) for i in self.meta_dict.values()] 662 | self.cumlen = np.cumsum(self.lengths) 663 | self.id2name = dict(enumerate(names)) 664 | 665 | self.transforms = transforms.Compose([ 666 | transforms.ColorJitter(0.5,0.5,0.5), 667 | transforms.RandomHorizontalFlip(), 668 | transforms.RandomRotation(degrees=180), 669 | transforms.ToTensor() 670 | ]) 671 | 672 | def __getitem__(self, index): 673 | i = 0 674 | while index>=self.cumlen[i]: 675 | i += 1 676 | idx = index 677 | if i > 0: 678 | idx = index - self.cumlen[i-1] 679 | 680 | exp = self.exp_dict[self.id2name[i]][idx] 681 | center = self.center_dict[self.id2name[i]][idx] 682 | loc = self.loc_dict[self.id2name[i]][idx] 683 | 684 | exp = torch.Tensor(exp) 685 | loc = torch.Tensor(loc) 686 | 687 | x, y = center 688 | patch = self.img_dict[self.id2name[i]].crop((x-self.r, y-self.r, x+self.r, y+self.r)) 689 | if self.train: 690 | patch = self.transforms(patch) 691 | else: 692 | patch = transforms.ToTensor()(patch) 693 | 694 | if self.train: 695 | return patch, loc, exp 696 | else: 697 | return patch, loc, exp, torch.Tensor(center) 698 | 699 | def __len__(self): 700 | return self.cumlen[-1] 701 | 702 | def get_img(self,name): 703 | path = glob.glob(self.dir+'*'+name+'.jpg')[0] 704 | im = Image.open(path) 705 | return im 706 | 707 | def get_cnt(self,name): 708 | path = glob.glob(self.dir+'*'+name+'_stdata.tsv')[0] 709 | df = pd.read_csv(path,sep='\t',index_col=0) 710 | return df 711 | 712 | def get_pos(self,name): 713 | path = glob.glob(self.dir+'*spot*'+name+'.tsv')[0] 714 | df = pd.read_csv(path,sep='\t') 715 | 716 | x = df['x'].values 717 | y = df['y'].values 718 | x = np.around(x).astype(int) 719 | y = np.around(y).astype(int) 720 | id = [] 721 | for i in range(len(x)): 722 | id.append(str(x[i])+'x'+str(y[i])) 723 | df['id'] = id 724 | 725 | return df 726 | 727 | def get_meta(self,name,gene_list=None): 728 | cnt = self.get_cnt(name) 729 | pos = self.get_pos(name) 730 | meta = cnt.join(pos.set_index('id'),how='inner') 731 | 732 | return meta 733 | 734 | def get_overlap(self,meta_dict,gene_list): 735 | gene_set = set(gene_list) 736 | for i in meta_dict.values(): 737 | gene_set = gene_set&set(i.columns) 738 | return list(gene_set) 739 | 740 | 741 | if __name__ == '__main__': 742 | # dataset = VitDataset(diameter=112,sr=True) 743 | dataset = ViT_HER2ST(train=True) 744 | # dataset = ViT_SKIN(train=True,sr=False,aug=False) 745 | 746 | print(len(dataset)) 747 | print(dataset[0][0].shape) 748 | print(dataset[0][1].shape) 749 | print(dataset[0][2].shape) 750 | # print(dataset[0][3].shape) 751 | # print(dataset.max_x) 752 | # print(dataset.max_y) 753 | # print(len(dataset.gene_set)) 754 | # np.save('data/her_g_list.npy',dataset.gene_set) --------------------------------------------------------------------------------
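For HisToGene, the `ViT_*` datasets are the relevant ones: each item is an entire slide, so a `DataLoader` with `batch_size=1` already yields the `[B, N, ...]` tensors the model expects. A minimal sketch, assuming `data/download.sh` has been run so that `data/her2st/` and `data/her_hvg_cut_1000.npy` are in place:

```python
import torch
from torch.utils.data import DataLoader
from dataset import ViT_HER2ST
from vis_model import HisToGene

train_set = ViT_HER2ST(train=True, fold=0)
# batch_size=1 yields patches [1, N, 3*112*112], positions [1, N, 2], exps [1, N, G]
train_loader = DataLoader(train_set, batch_size=1, shuffle=True)

model = HisToGene(n_genes=len(train_set.gene_list), patch_size=112, n_pos=64)
patches, positions, exps = next(iter(train_loader))
pred = model(patches, positions)                 # [1, N, G]
loss = torch.nn.functional.mse_loss(pred, exps)
```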
/model/download_checkpoint.txt: -------------------------------------------------------------------------------- 1 | https://drive.google.com/file/d/15Js-oGI5fScZwN0Xd_5FytL4UJP3i1eN/view?usp=sharing 2 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from utils import * 4 | from vis_model import HisToGene 5 | import warnings 6 | from dataset import ViT_HER2ST, ViT_SKIN 7 | from tqdm import tqdm 8 | warnings.filterwarnings('ignore') 9 | 10 | 11 | MODEL_PATH = '' 12 | 13 | # device = 'cpu' 14 | def model_predict(model, test_loader, adata=None, attention=True, device = torch.device('cpu')): 15 | model.eval() 16 | model = model.to(device) 17 | preds = None 18 | with torch.no_grad(): 19 | for patch, position, exp, center in tqdm(test_loader): 20 | 21 | patch, position = patch.to(device), position.to(device) 22 | 23 | pred = model(patch, position) 24 | 25 | 26 | if preds is None: 27 | preds = pred.squeeze() 28 | ct = center 29 | gt = exp 30 | else: 31 | preds = torch.cat((preds,pred),dim=0) 32 | ct = torch.cat((ct,center),dim=0) 33 | gt = torch.cat((gt,exp),dim=0) 34 | 35 | preds = preds.cpu().squeeze().numpy() 36 | ct = ct.cpu().squeeze().numpy() 37 | gt = gt.cpu().squeeze().numpy() 38 | adata = ann.AnnData(preds) 39 | adata.obsm['spatial'] = ct 40 | 41 | adata_gt = ann.AnnData(gt) 42 | adata_gt.obsm['spatial'] = ct 43 | 44 | return adata, adata_gt 45 | 46 | def sr_predict(model, test_loader, attention=True,device = torch.device('cpu')): 47 | model.eval() 48 | model = model.to(device) 49 | preds = None 50 | with torch.no_grad(): 51 | for patch, position, center in tqdm(test_loader): 52 | 53 | patch, position = patch.to(device), position.to(device) 54 | pred = model(patch, position) 55 | 56 | if preds is None: 57 | preds = pred.squeeze() 58 | ct = center 59 | else: 60 | preds = torch.cat((preds,pred),dim=0) 61 | ct = torch.cat((ct,center),dim=0) 62 | preds = preds.cpu().squeeze().numpy() 63 | ct = ct.cpu().squeeze().numpy() 64 | adata = ann.AnnData(preds) 65 | adata.obsm['spatial'] = ct 66 | 67 | 68 | return adata 69 | 70 | def main(): 71 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 72 | # for fold in [5,11,17,26]: 73 | for fold in range(12): 74 | # fold=30 75 | # tag = '-vit_skin_aug' 76 | # tag = '-cnn_her2st_785_32_cv' 77 | tag = '-vit_her2st_785_32_cv' 78 | # tag = '-cnn_skin_134_cv' 79 | # tag = '-vit_skin_134_cv' 80 | ds = 'HER2' 81 | # ds = 'Skin' 82 | 83 | print('Loading model ...') 84 | # model = STModel.load_from_checkpoint('model/last_train_'+tag+'.ckpt') 85 | # model = VitModel.load_from_checkpoint('model/last_train_'+tag+'.ckpt') 86 | # model = STModel.load_from_checkpoint("model/last_train_"+tag+'_'+str(fold)+".ckpt") 87 | model = HisToGene.load_from_checkpoint("model/last_train_"+tag+'_'+str(fold)+".ckpt") 88 | model = model.to(device) 89 | # model = torch.nn.DataParallel(model) 90 | print('Loading data ...') 91 | 92 | g = list(np.load('data/her_hvg_cut_1000.npy',allow_pickle=True)) 93 | # g = list(np.load('data/skin_hvg_cut_1000.npy',allow_pickle=True)) 94 | 95 | # dataset = SKIN(train=False,ds=ds,fold=fold) 96 | dataset = ViT_HER2ST(train=False,sr=False,fold=fold) 97 | # dataset = ViT_SKIN(train=False,sr=False,fold=fold) 98 | # dataset = VitDataset(diameter=112,sr=True) 99 | 100 | test_loader = DataLoader(dataset, batch_size=16, num_workers=4) 101 | print('Making prediction ...') 102 | 103 | adata_pred, adata = model_predict(model, test_loader, attention=False, device=device) 104 | # adata_pred = sr_predict(model,test_loader,attention=True,device=device) 105 | 106 | adata_pred.var_names = g 107 | print('Saving files ...') 108 | adata_pred = comp_tsne_km(adata_pred,4) 109 | # adata_pred = comp_umap(adata_pred) 110 | print(fold) 111 | print(adata_pred) 112 | 113 | adata_pred.write('processed/test_pred_'+ds+'_'+str(fold)+tag+'.h5ad') 114 | # adata_pred.write('processed/test_pred_sr_'+ds+'_'+str(fold)+tag+'.h5ad') 115 | 116 | # quit() 117 | 118 | if __name__ == '__main__': 119 | main() 120 | 121 | --------------------------------------------------------------------------------
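Condensed, the prediction flow in `main()` above reduces to the sketch below (the checkpoint filename follows the naming convention used in `main()`; see `model/download_checkpoint.txt` for the released weights):

```python
import torch
from torch.utils.data import DataLoader
from dataset import ViT_HER2ST
from vis_model import HisToGene
from predict import model_predict

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
fold = 0
model = HisToGene.load_from_checkpoint(
    'model/last_train_-vit_her2st_785_32_cv_' + str(fold) + '.ckpt')
dataset = ViT_HER2ST(train=False, sr=False, fold=fold)
test_loader = DataLoader(dataset, batch_size=1, num_workers=4)

# Both returned objects are AnnData: predictions and ground truth, with spot
# pixel coordinates stored in .obsm['spatial']
adata_pred, adata_gt = model_predict(model, test_loader, attention=False, device=device)
```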
/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange, repeat 6 | from einops.layers.torch import Rearrange 7 | # from visualizer import get_local 8 | # helpers 9 | 10 | def pair(t): 11 | return t if isinstance(t, tuple) else (t, t) 12 | 13 | class PreNorm(nn.Module): 14 | def __init__(self, dim, fn): 15 | super().__init__() 16 | self.norm = nn.LayerNorm(dim) 17 | self.fn = fn 18 | def forward(self, x, **kwargs): 19 | return self.fn(self.norm(x), **kwargs) 20 | 21 | class FeedForward(nn.Module): 22 | def __init__(self, dim, hidden_dim, dropout = 0.): 23 | super().__init__() 24 | self.net = nn.Sequential( 25 | nn.Linear(dim, hidden_dim), 26 | nn.GELU(), 27 | nn.Dropout(dropout), 28 | nn.Linear(hidden_dim, dim), 29 | nn.Dropout(dropout) 30 | ) 31 | def forward(self, x): 32 | return self.net(x) 33 | 34 | class Attention(nn.Module): 35 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 36 | super().__init__() 37 | inner_dim = dim_head * heads 38 | project_out = not (heads == 1 and dim_head == dim) 39 | 40 | self.heads = heads 41 | self.scale = dim_head ** -0.5 42 | 43 | self.attend = nn.Softmax(dim = -1) 44 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 45 | 46 | self.to_out = nn.Sequential( 47 | nn.Linear(inner_dim, dim), 48 | nn.Dropout(dropout) 49 | ) if project_out else nn.Identity() 50 | 51 | # @get_local('attn') 52 | def forward(self, x): 53 | b, n, _, h = *x.shape, self.heads 54 | qkv = self.to_qkv(x).chunk(3, dim = -1) 55 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 56 | 57 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 58 | 59 | attn = self.attend(dots) 60 | # print(attn.shape) 61 | # quit() 62 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 63 | out = rearrange(out, 'b h n d -> b n (h d)') 64 | return self.to_out(out) 65 | 66 | class Transformer(nn.Module): 67 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 68 | super().__init__() 69 | self.layers = nn.ModuleList([]) 70 | for _ in range(depth): 71 | self.layers.append(nn.ModuleList([ 72 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 73 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 74 | ])) 75 | def forward(self, x): 76 | for attn, ff in self.layers: 77 | x = attn(x) + x 78 | x = ff(x) + x 79 | return x 80 | 81 | class 
ViT(nn.Module): 82 | def __init__(self, *, dim, depth, heads, mlp_dim, dim_head = 64, dropout = 0., emb_dropout = 0.): 83 | super().__init__() 84 | self.dropout = nn.Dropout(emb_dropout) 85 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 86 | self.to_latent = nn.Identity() 87 | 88 | def forward(self, x): 89 | x = self.dropout(x) 90 | x = self.transformer(x) 91 | x = self.to_latent(x) 92 | return x -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from os import name 2 | from PIL import Image 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import seaborn as sns 6 | import scanpy as sc, anndata as ad 7 | from sklearn import preprocessing 8 | from sklearn.cluster import KMeans 9 | Image.MAX_IMAGE_PIXELS = 933120000 10 | import pickle 11 | import pandas as pd 12 | import anndata as ann 13 | import os 14 | import glob 15 | import scprep as scp 16 | # from dataset import MARKERS 17 | BCELL = ['CD19', 'CD79A', 'CD79B', 'MS4A1'] 18 | TUMOR = ['FASN'] 19 | CD4T = ['CD4'] 20 | CD8T = ['CD8A', 'CD8B'] 21 | DC = ['CLIC2', 'CLEC10A', 'CD1B', 'CD1A', 'CD1E'] 22 | MDC = ['LAMP3'] 23 | CMM = ['BRAF', 'KRAS'] 24 | IG = {'B_cell':BCELL, 'Tumor':TUMOR, 'CD4+T_cell':CD4T, 'CD8+T_cell':CD8T, 'Dendritic_cells':DC, 25 | 'Mature_dendritic_cells':MDC, 'Cutaneous_Malignant_Melanoma':CMM} 26 | MARKERS = [] 27 | for i in IG.values(): 28 | MARKERS+=i 29 | LYM = {'B_cell':BCELL, 'CD4+T_cell':CD4T, 'CD8+T_cell':CD8T} 30 | 31 | def read_tiff(path): 32 | Image.MAX_IMAGE_PIXELS = 933120000 33 | im = Image.open(path) 34 | imarray = np.array(im) 35 | # I = plt.imread(path) 36 | return im 37 | 38 | def preprocess(adata, n_keep=1000, include=LYM, g=True): 39 | adata.var_names_make_unique() 40 | sc.pp.normalize_total(adata) 41 | sc.pp.log1p(adata) 42 | if g: 43 | # with open("data/gene_list.txt", "rb") as fp: 44 | # b = pickle.load(fp) 45 | b = list(np.load('data/skin_a.npy',allow_pickle=True)) 46 | adata = adata[:,b] 47 | elif include: 48 | # b = adata.copy() 49 | # sc.pp.highly_variable_genes(b, n_top_genes=n_keep,subset=True) 50 | # hvgs = b.var_names 51 | # n_union = len(hvgs&include) 52 | # n_include = len(include) 53 | # hvgs = list(set(hvgs)-set(include))[n_include-n_union:] 54 | # g = include 55 | # adata = adata[:,g] 56 | exp = np.zeros((adata.X.shape[0],len(include))) 57 | for n,(i,v) in enumerate(include.items()): 58 | tmp = adata[:,v].X 59 | tmp = np.mean(tmp,1).flatten() 60 | exp[:,n] = tmp 61 | adata = adata[:,:len(include)] 62 | adata.X = exp 63 | adata.var_names = list(include.keys()) 64 | 65 | else: 66 | sc.pp.highly_variable_genes(adata, n_top_genes=n_keep,subset=True) 67 | c = adata.obsm['spatial'] 68 | scaler = preprocessing.StandardScaler().fit(c) 69 | c = scaler.transform(c) 70 | adata.obsm['position_norm'] = c 71 | # with open("data/gene_list.txt", "wb") as fp: 72 | # pickle.dump(g, fp) 73 | return adata 74 | 75 | def comp_umap(adata): 76 | sc.pp.pca(adata) 77 | sc.pp.neighbors(adata) 78 | sc.tl.umap(adata) 79 | sc.tl.leiden(adata, key_added="clusters") 80 | return adata 81 | 82 | def comp_tsne_km(adata,k=10): 83 | sc.pp.pca(adata) 84 | sc.tl.tsne(adata) 85 | kmeans = KMeans(n_clusters=k, init="k-means++", random_state=0).fit(adata.obsm['X_pca']) 86 | adata.obs['kmeans'] = kmeans.labels_.astype(str) 87 | return adata 88 | 89 | def co_embed(a,b,k=10): 90 | a.obs['tag'] = 'Truth' 91 | b.obs['tag'] = 'Pred' 92 | adata = ad.concat([a,b]) 
93 | sc.pp.pca(adata) 94 | sc.tl.tsne(adata) 95 | kmeans = KMeans(n_clusters=k, init="k-means++", random_state=0).fit(adata.obsm['X_pca']) 96 | adata.obs['kmeans'] = kmeans.labels_.astype(str) 97 | return adata 98 | 99 | def build_adata(name='H1'): 100 | cnt_dir = 'data/her2st/data/ST-cnts' 101 | img_dir = 'data/her2st/data/ST-imgs' 102 | pos_dir = 'data/her2st/data/ST-spotfiles' 103 | 104 | pre = img_dir+'/'+name[0]+'/'+name 105 | fig_name = os.listdir(pre)[0] 106 | path = pre+'/'+fig_name 107 | im = Image.open(path) 108 | 109 | path = cnt_dir+'/'+name+'.tsv' 110 | cnt = pd.read_csv(path,sep='\t',index_col=0) 111 | 112 | path = pos_dir+'/'+name+'_selection.tsv' 113 | df = pd.read_csv(path,sep='\t') 114 | 115 | x = df['x'].values 116 | y = df['y'].values 117 | id = [] 118 | for i in range(len(x)): 119 | id.append(str(x[i])+'x'+str(y[i])) 120 | df['id'] = id 121 | 122 | meta = cnt.join((df.set_index('id'))) 123 | 124 | gene_list = list(np.load('data/her_g_list.npy')) 125 | adata = ann.AnnData(scp.transform.log(scp.normalize.library_size_normalize(meta[gene_list].values))) 126 | adata.var_names = gene_list 127 | adata.obsm['spatial'] = np.floor(meta[['pixel_x','pixel_y']].values).astype(int) 128 | 129 | return adata, im 130 | 131 | 132 | def get_data(dataset='bc1', n_keep=1000, include=LYM, g=True): 133 | if dataset == 'bc1': 134 | adata = sc.datasets.visium_sge(sample_id='V1_Breast_Cancer_Block_A_Section_1', include_hires_tiff=True) 135 | adata = preprocess(adata, n_keep, include, g) 136 | img_path = adata.uns["spatial"]['V1_Breast_Cancer_Block_A_Section_1']["metadata"]["source_image_path"] 137 | elif dataset == 'bc2': 138 | adata = sc.datasets.visium_sge(sample_id='V1_Breast_Cancer_Block_A_Section_2', include_hires_tiff=True) 139 | adata = preprocess(adata, n_keep, include, g) 140 | img_path = adata.uns["spatial"]['V1_Breast_Cancer_Block_A_Section_2']["metadata"]["source_image_path"] 141 | else: 142 | adata = sc.datasets.visium_sge(sample_id=dataset, include_hires_tiff=True) 143 | adata = preprocess(adata, n_keep, include, g) 144 | img_path = adata.uns["spatial"][dataset]["metadata"]["source_image_path"] 145 | 146 | return adata, img_path 147 | 148 | if __name__ == '__main__': 149 | 150 | adata, img_path = get_data() 151 | print(adata.X.toarray()) 152 | -------------------------------------------------------------------------------- /vis_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision 7 | import pytorch_lightning as pl 8 | from torchmetrics.functional import accuracy 9 | from transformer import ViT 10 | from torch.optim.lr_scheduler import ReduceLROnPlateau 11 | 12 | class FeatureExtractor(nn.Module): 13 | """Some Information about FeatureExtractor""" 14 | def __init__(self, backbone='resnet101'): 15 | super(FeatureExtractor, self).__init__() 16 | backbone = torchvision.models.resnet101(pretrained=True) 17 | layers = list(backbone.children())[:-1] 18 | self.backbone = nn.Sequential(*layers) 19 | # self.backbone = backbone 20 | def forward(self, x): 21 | x = self.backbone(x) 22 | return x 23 | 24 | class ImageClassifier(pl.LightningModule): 25 | def __init__(self, num_classes=4, backbone='resnet50', learning_rate=1e-3): 26 | super().__init__() 27 | self.save_hyperparameters() 28 | backbone = torchvision.models.resnet50(pretrained=True) 29 | num_filters = backbone.fc.in_features 30 | layers = 
list(backbone.children())[:-1] 31 | self.feature_extractor = nn.Sequential(*layers) 32 | num_target_classes = num_classes 33 | self.classifier = nn.Linear(num_filters, num_target_classes) 34 | # self.valid_acc = torchmetrics.Accuracy() 35 | self.learning_rate = learning_rate 36 | 37 | def forward(self, x): 38 | # use forward for inference/predictions 39 | embedding = self.feature_extractor(x) 40 | return embedding 41 | 42 | def training_step(self, batch, batch_idx): 43 | x, y = batch 44 | h = self.feature_extractor(x).flatten(1) 45 | h = self.classifier(h) 46 | logits = F.log_softmax(h, dim=1) 47 | loss = F.nll_loss(logits, y) 48 | self.log('train_loss', loss) 49 | return loss 50 | 51 | def validation_step(self, batch, batch_idx): 52 | x, y = batch 53 | h = self.feature_extractor(x).flatten(1) 54 | h = self.classifier(h) 55 | logits = F.log_softmax(h, dim=1) 56 | loss = F.nll_loss(logits, y) 57 | preds = torch.argmax(logits, dim=1) 58 | acc = accuracy(preds, y) 59 | 60 | self.log('valid_loss', loss) 61 | self.log('valid_acc', acc) 62 | 63 | def test_step(self, batch, batch_idx): 64 | x, y = batch 65 | h = self.feature_extractor(x).flatten(1) 66 | h = self.classifier(h) 67 | logits = F.log_softmax(h, dim=1) 68 | loss = F.nll_loss(logits, y) 69 | preds = torch.argmax(logits, dim=1) 70 | acc = accuracy(preds, y) 71 | 72 | self.log('test_loss', loss) 73 | self.log('test_acc', acc) 74 | 75 | def configure_optimizers(self): 76 | # self.hparams available because we called self.save_hyperparameters() 77 | return torch.optim.Adam(self.parameters(), lr=self.learning_rate) 78 | 79 | @staticmethod 80 | def add_model_specific_args(parent_parser): 81 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 82 | parser.add_argument('--lr', type=float, default=0.0001) 83 | return parser 84 | 85 | 86 | class STModel(pl.LightningModule): 87 | def __init__(self, feature_model=None, n_genes=1000, hidden_dim=2048, learning_rate=1e-5, use_mask=False, use_pos=False, cls=False): 88 | super().__init__() 89 | self.save_hyperparameters() 90 | # self.feature_model = None 91 | if feature_model: 92 | # self.feature_model = ImageClassifier.load_from_checkpoint(feature_model) 93 | # self.feature_model.freeze() 94 | self.feature_extractor = ImageClassifier.load_from_checkpoint(feature_model) 95 | else: 96 | self.feature_extractor = FeatureExtractor() 97 | # self.pos_embed = nn.Linear(2, hidden_dim) 98 | self.pred_head = nn.Linear(hidden_dim, n_genes) 99 | 100 | self.learning_rate = learning_rate 101 | self.n_genes = n_genes 102 | 103 | def forward(self, patch, center): 104 | feature = self.feature_extractor(patch).flatten(1) 105 | h = feature 106 | pred = self.pred_head(F.relu(h)) 107 | return pred 108 | 109 | def training_step(self, batch, batch_idx): 110 | patch, center, exp = batch 111 | pred = self(patch, center) 112 | loss = F.mse_loss(pred, exp) 113 | self.log('train_loss', loss) 114 | return loss 115 | 116 | def validation_step(self, batch, batch_idx): 117 | patch, center, exp = batch 118 | pred = self(patch, center) 119 | loss = F.mse_loss(pred, exp) 120 | self.log('valid_loss', loss) 121 | 122 | def test_step(self, batch, batch_idx): 123 | patch, center, exp, mask, label = batch 124 | if self.hparams.use_mask: 125 | pred, mask_pred = self(patch, center) 126 | else: 127 | pred = self(patch, center) 128 | 129 | loss = F.mse_loss(pred, exp) 130 | self.log('test_loss', loss) 131 | 132 | def configure_optimizers(self): 133 | # self.hparams available because we called self.save_hyperparameters() 134 | optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) 135 | return optimizer 136 | 137 | @staticmethod 138 | def add_model_specific_args(parent_parser): 139 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 140 | parser.add_argument('--learning_rate', type=float, default=0.0001) 141 | return parser 142 | 143 | 
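# NOTE on HisToGene below: the model consumes one whole slide per sample.
#   patches: FloatTensor [B, N, 3*patch_size*patch_size]  (flattened RGB spot patches)
#   centers: LongTensor  [B, N, 2]                        (grid coordinates in [0, n_pos))
# Position is injected additively: the linear patch embedding is summed with
# learned embeddings of the x and y grid indices before entering the ViT.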
144 | class HisToGene(pl.LightningModule): 145 | def __init__(self, patch_size=112, n_layers=4, n_genes=1000, dim=1024, learning_rate=1e-4, dropout=0.1, n_pos=64): 146 | super().__init__() 147 | # self.save_hyperparameters() 148 | self.learning_rate = learning_rate 149 | patch_dim = 3*patch_size*patch_size 150 | self.patch_embedding = nn.Linear(patch_dim, dim) 151 | self.x_embed = nn.Embedding(n_pos,dim) 152 | self.y_embed = nn.Embedding(n_pos,dim) 153 | self.vit = ViT(dim=dim, depth=n_layers, heads=16, mlp_dim=2*dim, dropout = dropout, emb_dropout = dropout) 154 | 155 | self.gene_head = nn.Sequential( 156 | nn.LayerNorm(dim), 157 | nn.Linear(dim, n_genes) 158 | ) 159 | 160 | def forward(self, patches, centers): 161 | patches = self.patch_embedding(patches) 162 | centers_x = self.x_embed(centers[:,:,0]) 163 | centers_y = self.y_embed(centers[:,:,1]) 164 | x = patches + centers_x + centers_y 165 | h = self.vit(x) 166 | x = self.gene_head(h) 167 | return x 168 | 169 | def training_step(self, batch, batch_idx): 170 | patch, center, exp = batch 171 | pred = self(patch, center) 172 | loss = F.mse_loss(pred.view_as(exp), exp) 173 | self.log('train_loss', loss) 174 | return loss 175 | 176 | def validation_step(self, batch, batch_idx): 177 | patch, center, exp = batch 178 | pred = self(patch, center) 179 | loss = F.mse_loss(pred.view_as(exp), exp) 180 | self.log('valid_loss', loss) 181 | return loss 182 | 183 | def test_step(self, batch, batch_idx): 184 | patch, center, exp = batch 185 | pred = self(patch, center) 186 | loss = F.mse_loss(pred.view_as(exp), exp) 187 | self.log('test_loss', loss) 188 | 189 | def configure_optimizers(self): 190 | # note: save_hyperparameters() is commented out in __init__, so self.hparams is not populated and checkpoints rely on the constructor defaults 191 | return torch.optim.Adam(self.parameters(), lr=self.learning_rate) 192 | 193 | @staticmethod 194 | def add_model_specific_args(parent_parser): 195 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 196 | parser.add_argument('--learning_rate', type=float, default=0.0001) 197 | return parser 198 | 199 | def count_parameters(model): 200 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 201 | 202 | if __name__ == '__main__': 203 | a = torch.rand(1,4000,3*112*112) 204 | p = torch.ones(1,4000,2).long() 205 | model = HisToGene() 206 | print(count_parameters(model)) 207 | x = model(a,p) 208 | print(x.shape) 209 | --------------------------------------------------------------------------------
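No training script is included in this listing. A minimal PyTorch Lightning loop consistent with `HisToGene.training_step` and the checkpoint names consumed by `predict.py` might look like this sketch (the epoch count and single-GPU flag are assumptions, not the authors' published settings):

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from dataset import ViT_HER2ST
from vis_model import HisToGene

fold = 0
tag = '-vit_her2st_785_32_cv'

train_set = ViT_HER2ST(train=True, fold=fold)
train_loader = DataLoader(train_set, batch_size=1, shuffle=True)

model = HisToGene(n_genes=len(train_set.gene_list), patch_size=112,
                  n_layers=4, dim=1024, learning_rate=1e-5,
                  dropout=0.1, n_pos=64)
trainer = pl.Trainer(gpus=1, max_epochs=100)  # assumed settings
trainer.fit(model, train_loader)
trainer.save_checkpoint('model/last_train_' + tag + '_' + str(fold) + '.ckpt')
```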