├── README.md
├── cell_type_annotation_for_hyp3d.py
├── cell_type_annotation_for_merfish.py
├── cell_type_annotation_for_nanostring.py
├── cell_type_annotation_for_slideseq.py
├── cell_type_annotation_model.pyc
├── dnn_model
│   ├── checkpoint_Hyp-3D_b.t7
│   ├── checkpoint_MERFISH_s.t7
│   ├── checkpoint_NanoString.t7
│   ├── checkpoint_Slide-seq_DM1.z01
│   ├── checkpoint_Slide-seq_DM1.z02
│   ├── checkpoint_Slide-seq_DM1.z03
│   ├── checkpoint_Slide-seq_DM1.z04
│   ├── checkpoint_Slide-seq_DM1.z05
│   ├── checkpoint_Slide-seq_DM1.z06
│   └── checkpoint_Slide-seq_DM1.zip
├── spatialID_example_m1s1.png
└── spatialID_overview.png


/README.md:
--------------------------------------------------------------------------------
1 | # spatial-ID
2 |
3 | [![python >3.8.8](https://img.shields.io/badge/python-3.8.8-brightgreen)](https://www.python.org/)
4 |
5 | ### Spatial-ID: a cell typing method for spatially resolved transcriptomics via transfer learning and spatial embedding
6 | Spatially resolved transcriptomics (SRT) makes it possible to investigate both the gene expression profiles and the spatial context of cells in their native state. Cell type annotation is a crucial step in the spatial transcriptomic analysis of cell and tissue biology. In this study, we propose Spatial-ID, a supervision-based cell typing method for high-throughput, cell-level SRT datasets that integrates transfer learning and spatial embedding. Spatial-ID effectively incorporates the existing knowledge of reference scRNA-seq datasets and the spatial information of SRT datasets.
7 |
8 |
9 |
10 | # Dependencies
11 |
12 | [![numpy-1.21.3](https://img.shields.io/badge/numpy-1.21.3-red)](https://github.com/numpy/numpy)
13 | [![pandas-1.2.4](https://img.shields.io/badge/pandas-1.2.4-lightgrey)](https://github.com/pandas-dev/pandas)
14 | [![scanpy-1.8.1](https://img.shields.io/badge/scanpy-1.8.1-blue)](https://github.com/theislab/scanpy)
15 | [![torch-1.8.1](https://img.shields.io/badge/torch-1.8.1-orange)](https://github.com/pytorch/pytorch)
16 | [![torch__geometric-1.7.2](https://img.shields.io/badge/torch__geometric-1.7.2-green)](https://github.com/pyg-team/pytorch_geometric/)
17 |
18 | # Datasets
19 |
20 | - MERFISH: 280,186 cells × 254 genes, 12 samples. [https://doi.brainimagelibrary.org/doi/10.35077/g.21](https://doi.brainimagelibrary.org/doi/10.35077/g.21)
21 | - MERFISH-3D: 213,192 cells × 155 genes, 3 samples. [https://datadryad.org/stash/dataset/doi:10.5061/dryad.8t8s248](https://datadryad.org/stash/dataset/doi:10.5061/dryad.8t8s248)
22 | - Slide-seq: 207,335 cells × 27,181 genes, 6 samples. [https://www.dropbox.com/s/ygzpj0d0oh67br0/Testis_Slideseq_Data.zip?dl=0](https://www.dropbox.com/s/ygzpj0d0oh67br0/Testis_Slideseq_Data.zip?dl=0)
23 | - NanoString: 83,621 cells × 980 genes, 20 samples. [https://nanostring.com/resources/smi-ffpe-dataset-lung9-rep1-data/](https://nanostring.com/resources/smi-ffpe-dataset-lung9-rep1-data/)
24 |
25 | # Usage
26 |
27 | - Run cell\_type\_annotation\_for\_merfish.py to annotate cells in the MERFISH dataset.
28 | - Run cell\_type\_annotation\_for\_hyp3d.py to annotate cells in the MERFISH-3D dataset.
29 | - Run cell\_type\_annotation\_for\_slideseq.py to annotate cells in the Slide-seq dataset.
30 | - Run cell\_type\_annotation\_for\_nanostring.py to annotate cells in the NanoString dataset.
31 |
32 | p.s. Before running cell\_type\_annotation\_for\_slideseq.py, you may need to extract dnn\_model/checkpoint\_Slide-seq\_DM1.t7 from the split archive dnn\_model/checkpoint\_Slide-seq\_DM1.zip (parts .z01 to .z06).
33 |
34 | [!!!] Note: An AttributeError saying that 'GELU' object has no attribute 'approximate' may occur if your PyTorch version is higher than 1.10.0. You can simply downgrade PyTorch to 1.8.1 or temporarily modify the PyTorch source code.
35 |
36 | # Example
37 |
38 | 1. Put the downloaded MERFISH data (e.g. mouse1_sample1.h5ad) in "dataset/MERFISH/" (the data directory set in Line 30 of cell\_type\_annotation\_for\_merfish.py).
39 | 2. Run cell\_type\_annotation\_for\_merfish.py with `--data_name mouse1_sample1` to annotate the cells of the mouse1_sample1 data.
40 | 3. Four files can then be found in "result/MERFISH/spatialID-mouse1\_sample1/", a per-sample subfolder of the save directory set in Line 31 of cell\_type\_annotation\_for\_merfish.py:
41 |    - spatialID-mouse1\_sample1.t7: Checkpoint of the self-supervised model in Stage 2.
42 |    - spatialID-mouse1\_sample1.h5ad: Updated H5AD file with the annotation results stored.
43 |    - spatialID-mouse1\_sample1.csv: Annotation results, with column "cell" holding cell IDs and column "celltype_pred" holding the annotated cell types.
44 |    - spatialID-mouse1\_sample1.pdf: Visualization of the annotation results, as shown below.
45 |
46 |
47 |
48 | # Disclaimer
49 |
50 | This tool is for research purposes only and is not approved for clinical use.
51 |
52 | This is not an official Tencent product.
53 |
54 | # Copyright
55 |
56 | This tool was developed at Tencent AI Lab.
57 |
58 | The copyright holder for this project is Tencent AI Lab.
59 |
60 | All rights reserved.
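The README note about the GELU AttributeError suggests downgrading PyTorch or patching its source. A lighter workaround, sketched here under the assumption that the error comes from GELU modules pickled by an older PyTorch inside `checkpoint['model']`, is to add the missing attribute right after loading. The helper `patch_legacy_gelu` is hypothetical (it is not part of this repository or of PyTorch); `"none"` is the value newer PyTorch versions use for the exact, non-approximated GELU.

```python
def patch_legacy_gelu(model):
    """Give GELU modules unpickled from an old PyTorch the `approximate` attribute.

    Hypothetical helper. `model` is assumed to be a torch.nn.Module such as
    checkpoint['model'] in the scripts, i.e. anything with a .modules() iterator.
    Modules are matched by class name so the sketch runs without importing torch;
    with torch available, isinstance(module, torch.nn.GELU) is the stricter test.
    Returns the number of modules that were patched.
    """
    patched = 0
    for module in model.modules():
        if type(module).__name__ == "GELU" and not hasattr(module, "approximate"):
            # "none" selects the exact (erf-based) GELU, i.e. the old default behaviour.
            module.approximate = "none"
            patched += 1
    return patched


# Usage in the scripts (sketch), right after loading the checkpoint:
#     checkpoint = torch.load(config['transfer']['dnn_model'], map_location=device)
#     patch_legacy_gelu(checkpoint['model'])
```

Modules that already carry `approximate` are left untouched, so the patch is harmless on checkpoints saved by newer PyTorch versions.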
61 |

--------------------------------------------------------------------------------
/cell_type_annotation_for_hyp3d.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import os
5 | import time
6 | import random
7 | import argparse
8 | import anndata
9 | import numpy as np
10 | import pandas as pd
11 | import scanpy as sc
12 | import matplotlib.pyplot as plt
13 | import torch
14 | import torch_geometric
15 |
16 | from cell_type_annotation_model import DNNModel, SpatialModelTrainer
17 |
18 |
19 | random.seed(0)
20 | np.random.seed(0)
21 | torch.manual_seed(0)
22 | torch.cuda.manual_seed(0)
23 | torch.cuda.manual_seed_all(0)
24 | torch.backends.cudnn.benchmark = False
25 | torch.backends.cudnn.deterministic = True
26 |
27 |
28 | config = {
29 |     'data': {
30 |         'data_dir': 'dataset/Hyp_3D/',
31 |         'save_dir': 'result/Hyp_3D/',
32 |         'dataset': 'Hyp_3D',
33 |     },
34 |     'preprocess': {
35 |         'filter_mt': True,
36 |         'cell_min_counts': 300,
37 |         'gene_min_cells': 10,
38 |         'cell_max_counts_percent': 98.0,
39 |         'drop_rate': 0,
40 |     },
41 |     'transfer': {
42 |         'dnn_model': 'dnn_model/checkpoint_Hyp-3D_b.t7',
43 |         'gpu': '0',
44 |         'batch_size': 4096,
45 |     },
46 |     'train': {
47 |         'pca_dim': 200,  # for Stereoseq only
48 |         'k_graph': 30,
49 |         'edge_weight': True,
50 |         'kd_T': 1,
51 |         'feat_dim': 64,
52 |         'w_dae': 1.0,
53 |         'w_gae': 1.0,
54 |         'w_cls': 10.0,
55 |         'epochs': 200,
56 |     }
57 | }
58 |
59 |
60 | def spatial_classification_tool(config, data_name):
61 |     ''' Spatial classification workflow.
62 |
63 |     # Arguments
64 |         config (Config): Configuration parameters.
65 |         data_name (str): Data name.
66 |     '''
67 |     ######################################
68 |     #         Part 1: Load data          #
69 |     ######################################
70 |
71 |     # Set path and load data.
72 |     print('\n==> Loading data...')
73 |     dataset = config['data']['dataset']
74 |     data_dir, save_dir = config['data']['data_dir'], config['data']['save_dir']
75 |     print(f'    Data name: {data_name} ({dataset})')
76 |     print(f'    Data path: {data_dir}')
77 |     print(f'    Save path: {save_dir}')
78 |     adata = sc.read_h5ad(os.path.join(data_dir, f'{data_name}.h5ad'))
79 |
80 |     # Initialize save path.
81 |     model_name = f'spatialID-{data_name}'
82 |     save_dir = os.path.join(save_dir, model_name)
83 |     if not os.path.exists(save_dir):
84 |         os.makedirs(save_dir)
85 |
86 |
87 |     ######################################
88 |     #         Part 2: Preprocess         #
89 |     ######################################
90 |
91 |     print('\n==> Preprocessing...')
92 |     strings = [f'{k}={v}' for k, v in config['preprocess'].items()]
93 |     print('    Parameters(%s)' % (', '.join(strings)))
94 |
95 |     # Preprocess data.
96 |     if dataset == 'Stereoseq':
97 |         params = config['preprocess']
98 |         if params['filter_mt']:
99 |             adata.var['mt'] = adata.var_names.str.startswith(('Mt-', 'mt-'))
100 |             sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], inplace=True)
101 |             adata = adata[adata.obs['pct_counts_mt'] < 10].copy()
102 |         if params['cell_min_counts'] > 0:
103 |             sc.pp.filter_cells(adata, min_counts=params['cell_min_counts'])
104 |         if params['gene_min_cells'] > 0:
105 |             sc.pp.filter_genes(adata, min_cells=params['gene_min_cells'])
106 |         if params['cell_max_counts_percent'] < 100:
107 |             max_counts = np.percentile(adata.X.sum(1), params['cell_max_counts_percent'])
108 |             sc.pp.filter_cells(adata, max_counts=max_counts)
109 |     if type(adata.X) != np.ndarray:
110 |         adata_X_sparse_backup = adata.X.copy()
111 |         adata.X = adata.X.toarray()
112 |     print('    %s: %d cells × %d genes.' % (data_name, adata.shape[0], adata.shape[1]))
113 |
114 |     # Please be aware:
115 |     #     The DNN model takes the original gene expression matrix as input and applies its own normalization.
116 |     #     Other normalization (e.g. scanpy) can be added after DNN model inference is completed.
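The "own normalization" mentioned in the comment above is the input construction in lines 142 to 147: columns are reordered to the checkpoint's marker genes, genes missing from the data read from an appended all-zero column (pandas `get_indexer` returns -1 for them, and index -1 picks the padded column), and each cell is scaled to unit L2 norm. A pure-Python sketch of that logic, assuming unique gene names; the scripts use NumPy and pandas, and `build_dnn_inputs` is a hypothetical name:

```python
import math


def build_dnn_inputs(X, var_names, marker_genes):
    """Align a cells x genes matrix to the model's marker genes and L2-normalize rows.

    Sketch of the scripts' DNN input step:
    - a zero column is appended to every row, like np.pad(X, ((0, 0), (0, 1)));
    - a marker gene absent from var_names gets index -1 (as Index.get_indexer
      returns), which therefore selects that appended zero column;
    - each row is divided by its Euclidean norm; all-zero rows use a norm of 1.
    """
    position = {gene: i for i, gene in enumerate(var_names)}
    gene_indices = [position.get(gene, -1) for gene in marker_genes]
    inputs = []
    for row in X:
        padded = list(row) + [0.0]
        aligned = [padded[j] for j in gene_indices]
        norm = math.sqrt(sum(v * v for v in aligned)) or 1.0  # zero norm -> 1
        inputs.append([v / norm for v in aligned])
    return inputs


# Marker "X" is not measured, so its column is filled with 0.
print(build_dnn_inputs([[3.0, 0.0, 4.0]], ["A", "B", "C"], ["C", "X", "A"]))
```

The zero-padding trick keeps the model's input width fixed even when some marker genes are absent from a given SRT panel.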
117 |
118 |     # Optionally add dropout noise (effectively a no-op when drop_rate is 0).
119 |     if dataset != 'Stereoseq':
120 |         drop_factor = (np.random.random(adata.shape) > config['preprocess']['drop_rate']) * 1.
121 |         adata.X = adata.X * drop_factor
122 |
123 |
124 |     ######################################
125 |     #  Part 3: Transfer from sc-dataset  #
126 |     ######################################
127 |
128 |     print('\n==> Transferring from sc-dataset...')
129 |     strings = [f'{k}={v}' for k, v in config['transfer'].items()]
130 |     print('    Parameters(%s)' % (', '.join(strings)))
131 |
132 |     # Set device.
133 |     os.environ['CUDA_VISIBLE_DEVICES'] = config['transfer']['gpu']
134 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
135 |
136 |     # Load the DNN model trained on the sc-dataset.
137 |     checkpoint = torch.load(config['transfer']['dnn_model'], map_location=device)
138 |     dnn_model = checkpoint['model'].to(device)
139 |     dnn_model.eval()
140 |
141 |     # Initialize DNN input: reorder columns to the marker genes (missing genes get index -1, i.e. the padded zero column), then L2-normalize each cell.
142 |     marker_genes = checkpoint['marker_genes']
143 |     gene_indices = adata.var_names.get_indexer(marker_genes)
144 |     adata_X = np.pad(adata.X, ((0,0),(0,1)))[:, gene_indices]
145 |     norm_factor = np.linalg.norm(adata_X, axis=1, keepdims=True)
146 |     norm_factor[norm_factor == 0] = 1
147 |     dnn_inputs = torch.Tensor(adata_X / norm_factor).split(config['transfer']['batch_size'])
148 |
149 |     # Inference with DNN model.
150 |     dnn_predictions = []
151 |     with torch.no_grad():
152 |         for batch_idx, inputs in enumerate(dnn_inputs):
153 |             inputs = inputs.to(device)
154 |             outputs = dnn_model(inputs)
155 |             dnn_predictions.append(outputs.detach().cpu().numpy())
156 |     label_names = checkpoint['label_names']
157 |     adata.obsm['pseudo_label'] = np.concatenate(dnn_predictions)
158 |     adata.obs['pseudo_class'] = pd.Categorical([label_names[i] for i in adata.obsm['pseudo_label'].argmax(1)])
159 |     adata.uns['pseudo_classes'] = label_names
160 |
161 |     # Compute accuracy (only for Hyp_3D).
162 |     if dataset == 'Hyp_3D':
163 |         indices = np.where(~adata.obs['Cell_class'].isin(['Ambiguous']))[0]
164 |         adjusted_pr = adata.obs['pseudo_class'].iloc[indices].to_numpy()
165 |         adjusted_gt = adata.obs['Cell_class'].iloc[indices].replace(
166 |             ['Endothelial 1', 'Endothelial 2', 'Endothelial 3',
167 |              'OD Immature 1', 'OD Immature 2',
168 |              'OD Mature 1', 'OD Mature 2', 'OD Mature 3', 'OD Mature 4',
169 |              'Astrocyte', 'Pericytes'],
170 |             ['Endothelial', 'Endothelial', 'Endothelial',
171 |              'Immature oligodendrocyte', 'Immature oligodendrocyte',
172 |              'Mature oligodendrocyte', 'Mature oligodendrocyte', 'Mature oligodendrocyte', 'Mature oligodendrocyte',
173 |              'Astrocytes', 'Mural']).to_numpy()
174 |         acc = (adjusted_pr == adjusted_gt).sum() / len(indices) * 100.0
175 |         print('    %s Acc (transfer only): %.2f%%' % (data_name, acc))
176 |
177 |
178 |     ######################################
179 |     #      Part 4: Train GDAE model      #
180 |     ######################################
181 |
182 |     print('\n==> Model training...')
183 |     strings = [f'{k}={v}' for k, v in config['train'].items()]
184 |     print('    Parameters(%s)' % (', '.join(strings)))
185 |
186 |     # Normalize gene expression (library size to 1e4, log1p, then per-gene z-score).
187 |     sc.pp.normalize_total(adata, target_sum=1e4)
188 |     sc.pp.log1p(adata)
189 |     adata_X = (adata.X - adata.X.mean(0)) / (adata.X.std(0) + 1e-10)
190 |
191 |     # Construct spatial graph.
192 |     gene_mat = torch.Tensor(adata_X)
193 |     if dataset == 'Stereoseq':  # PCA
194 |         u, s, v = torch.pca_lowrank(gene_mat, config['train']['pca_dim'])
195 |         gene_mat = torch.matmul(gene_mat, v)
196 |     if dataset == 'Hyp_3D':  # shift each section to the origin, then append Bregma (scaled by 1000) as z
197 |         z_axis = adata.obs['Bregma'].to_numpy() * 1000
198 |         for z in set(z_axis):
199 |             adata.obsm['spatial'][z_axis == z] -= adata.obsm['spatial'][z_axis == z].min(0)
200 |         adata.obsm['spatial'] = np.hstack([adata.obsm['spatial'], z_axis.reshape(-1, 1)])
201 |     cell_coo = torch.Tensor(adata.obsm['spatial'])
202 |     data = torch_geometric.data.Data(x=gene_mat, pos=cell_coo)
203 |     data = torch_geometric.transforms.KNNGraph(k=config['train']['k_graph'], loop=True)(data)
204 |     data.y = torch.Tensor(adata.obsm['pseudo_label'])
205 |
206 |     # Use 1 - normalized distance as edge weights (closer neighbors weigh more).
207 |     if config['train']['edge_weight']:
208 |         data = torch_geometric.transforms.Distance()(data)
209 |         data.edge_weight = 1 - data.edge_attr[:,0]
210 |     else:
211 |         data.edge_weight = torch.ones(data.edge_index.size(1))
212 |
213 |     # Train self-supervised model.
214 |     input_dim = data.num_features
215 |     num_classes = len(adata.uns['pseudo_classes'])
216 |     trainer = SpatialModelTrainer(input_dim, num_classes, device, config['train'])
217 |     trainer.train(data, config['train'])
218 |     trainer.save_checkpoint(os.path.join(save_dir, f'{model_name}.t7'))
219 |
220 |     # Inference.
221 |     print('\n==> Inferencing...')
222 |     predictions = trainer.valid(data)
223 |     celltype_pred = pd.Categorical([adata.uns['pseudo_classes'][i] for i in predictions.argmax(1)])
224 |     if dataset == 'Hyp_3D':
225 |         indices = np.where(~adata.obs['Cell_class'].isin(['Ambiguous']))[0]
226 |         adjusted_pr = celltype_pred[indices].to_numpy()
227 |         adjusted_gt = adata.obs['Cell_class'].iloc[indices].replace(
228 |             ['Endothelial 1', 'Endothelial 2', 'Endothelial 3',
229 |              'OD Immature 1', 'OD Immature 2',
230 |              'OD Mature 1', 'OD Mature 2', 'OD Mature 3', 'OD Mature 4',
231 |              'Astrocyte', 'Pericytes'],
232 |             ['Endothelial', 'Endothelial', 'Endothelial',
233 |              'Immature oligodendrocyte', 'Immature oligodendrocyte',
234 |              'Mature oligodendrocyte', 'Mature oligodendrocyte', 'Mature oligodendrocyte', 'Mature oligodendrocyte',
235 |              'Astrocytes', 'Mural']).to_numpy()
236 |         acc = (adjusted_pr == adjusted_gt).sum() / len(indices) * 100.0
237 |         print('    %s Acc (transfer+GDAE): %.2f%%' % (data_name, acc))
238 |
239 |     # Save results.
240 |     result = pd.DataFrame({'cell': adata.obs_names.tolist(), 'celltype_pred': celltype_pred})
241 |     result.to_csv(os.path.join(save_dir, f'{model_name}.csv'), index=False)
242 |     adata.obsm['celltype_prob'] = predictions
243 |     adata.obs['celltype_pred'] = pd.Categorical(celltype_pred)
244 |     if 'adata_X_sparse_backup' in locals():
245 |         adata.X = adata_X_sparse_backup
246 |     adata.write(os.path.join(save_dir, f'{model_name}.h5ad'))
247 |
248 |     # Save visualization.
249 |     spot_size = (30 if dataset == 'Stereoseq' else 20)
250 |     if dataset == 'Stereoseq':
251 |         pseudo_top100 = adata.obs['pseudo_class'].to_numpy()
252 |         other_classes = list(pd.value_counts(adata.obs['pseudo_class'])[100:].index)
253 |         pseudo_top100[adata.obs['pseudo_class'].isin(other_classes)] = '_Others'
254 |         adata.obs['pseudo_class'] = pd.Categorical(pseudo_top100)
255 |     # sc.pl.spatial(adata, img_key=None, color=['pseudo_class'], spot_size=spot_size, show=False)
256 |     # plt.savefig(os.path.join(save_dir, f'pseudo-{data_name}.pdf'), bbox_inches='tight', dpi=150)
257 |     sc.pl.spatial(adata, img_key=None, color=['celltype_pred'], spot_size=spot_size, show=False)
258 |     plt.savefig(os.path.join(save_dir, f'{model_name}.pdf'), bbox_inches='tight', dpi=150)
259 |     print('    Predictions are saved in', os.path.join(save_dir, f'{model_name}.csv/pdf'))
260 |
261 |
262 | if __name__ == '__main__':
263 |     data_list = ['sample1', 'sample2', 'sample3']
264 |     parser = argparse.ArgumentParser()
265 |     parser.add_argument('--data_name', choices=data_list, required=True)
266 |     args = parser.parse_args()
267 |     spatial_classification_tool(config, args.data_name)
268 |

--------------------------------------------------------------------------------
/cell_type_annotation_for_merfish.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import os
5 | import time
6 | import random
7 | import argparse
8 | import anndata
9 | import numpy as np
10 | import pandas as pd
11 | import scanpy as sc
12 | import matplotlib.pyplot as plt
13 | import torch
14 | import torch_geometric
15 |
16 | from cell_type_annotation_model import DNNModel, SpatialModelTrainer
17 |
18 |
19 | random.seed(0)
20 | np.random.seed(0)
21 | torch.manual_seed(0)
22 | torch.cuda.manual_seed(0)
23 | torch.cuda.manual_seed_all(0)
24 | torch.backends.cudnn.benchmark = False
25 | torch.backends.cudnn.deterministic = True
26 |
27 |
28 | config = {
29 |     'data': {
30 |         'data_dir': 'dataset/MERFISH/',
31 |         'save_dir': 'result/MERFISH/',
32 |         'dataset': 'MERFISH',
33 |     },
34 |     'preprocess': {
35 |         'filter_mt': True,
36 |         'cell_min_counts': 300,
37 |         'gene_min_cells': 10,
38 |         'cell_max_counts_percent': 98.0,
39 |         'drop_rate': 0,
40 |     },
41 |     'transfer': {
42 |         'dnn_model': 'dnn_model/checkpoint_MERFISH_s.t7',
43 |         'gpu': '0',
44 |         'batch_size': 4096,
45 |     },
46 |     'train': {
47 |         'pca_dim': 200,  # for Stereoseq only
48 |         'k_graph': 30,
49 |         'edge_weight': True,
50 |         'kd_T': 1,
51 |         'feat_dim': 64,
52 |         'w_dae': 1.0,
53 |         'w_gae': 1.0,
54 |         'w_cls': 10.0,
55 |         'epochs': 200,
56 |     }
57 | }
58 |
59 |
60 | def spatial_classification_tool(config, data_name):
61 |     ''' Spatial classification workflow.
62 |
63 |     # Arguments
64 |         config (Config): Configuration parameters.
65 |         data_name (str): Data name.
66 |     '''
67 |     ######################################
68 |     #         Part 1: Load data          #
69 |     ######################################
70 |
71 |     # Set path and load data.
72 |     print('\n==> Loading data...')
73 |     dataset = config['data']['dataset']
74 |     data_dir, save_dir = config['data']['data_dir'], config['data']['save_dir']
75 |     print(f'    Data name: {data_name} ({dataset})')
76 |     print(f'    Data path: {data_dir}')
77 |     print(f'    Save path: {save_dir}')
78 |     adata = sc.read_h5ad(os.path.join(data_dir, f'{data_name}.h5ad'))
79 |
80 |     # Initialize save path.
81 |     model_name = f'spatialID-{data_name}'
82 |     save_dir = os.path.join(save_dir, model_name)
83 |     if not os.path.exists(save_dir):
84 |         os.makedirs(save_dir)
85 |
86 |
87 |     ######################################
88 |     #         Part 2: Preprocess         #
89 |     ######################################
90 |
91 |     print('\n==> Preprocessing...')
92 |     strings = [f'{k}={v}' for k, v in config['preprocess'].items()]
93 |     print('    Parameters(%s)' % (', '.join(strings)))
94 |
95 |     # Preprocess data.
96 |     if dataset == 'Stereoseq':
97 |         params = config['preprocess']
98 |         if params['filter_mt']:
99 |             adata.var['mt'] = adata.var_names.str.startswith(('Mt-', 'mt-'))
100 |             sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], inplace=True)
101 |             adata = adata[adata.obs['pct_counts_mt'] < 10].copy()
102 |         if params['cell_min_counts'] > 0:
103 |             sc.pp.filter_cells(adata, min_counts=params['cell_min_counts'])
104 |         if params['gene_min_cells'] > 0:
105 |             sc.pp.filter_genes(adata, min_cells=params['gene_min_cells'])
106 |         if params['cell_max_counts_percent'] < 100:
107 |             max_counts = np.percentile(adata.X.sum(1), params['cell_max_counts_percent'])
108 |             sc.pp.filter_cells(adata, max_counts=max_counts)
109 |     if type(adata.X) != np.ndarray:
110 |         adata_X_sparse_backup = adata.X.copy()
111 |         adata.X = adata.X.toarray()
112 |     print('    %s: %d cells × %d genes.' % (data_name, adata.shape[0], adata.shape[1]))
113 |
114 |     # Please be aware:
115 |     #     The DNN model takes the original gene expression matrix as input and applies its own normalization.
116 |     #     Other normalization (e.g. scanpy) can be added after DNN model inference is completed.
117 |
118 |     # Optionally add dropout noise (effectively a no-op when drop_rate is 0).
119 |     if dataset != 'Stereoseq':
120 |         drop_factor = (np.random.random(adata.shape) > config['preprocess']['drop_rate']) * 1.
121 |         adata.X = adata.X * drop_factor
122 |
123 |
124 |     ######################################
125 |     #  Part 3: Transfer from sc-dataset  #
126 |     ######################################
127 |
128 |     print('\n==> Transferring from sc-dataset...')
129 |     strings = [f'{k}={v}' for k, v in config['transfer'].items()]
130 |     print('    Parameters(%s)' % (', '.join(strings)))
131 |
132 |     # Set device.
133 |     os.environ['CUDA_VISIBLE_DEVICES'] = config['transfer']['gpu']
134 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
135 |
136 |     # Load the DNN model trained on the sc-dataset.
137 |     checkpoint = torch.load(config['transfer']['dnn_model'], map_location=device)
138 |     dnn_model = checkpoint['model'].to(device)
139 |     dnn_model.eval()
140 |
141 |     # Initialize DNN input: reorder columns to the marker genes (missing genes get index -1, i.e. the padded zero column), then L2-normalize each cell.
142 |     marker_genes = checkpoint['marker_genes']
143 |     gene_indices = adata.var_names.get_indexer(marker_genes)
144 |     adata_X = np.pad(adata.X, ((0,0),(0,1)))[:, gene_indices].copy()
145 |     norm_factor = np.linalg.norm(adata_X, axis=1, keepdims=True)
146 |     norm_factor[norm_factor == 0] = 1
147 |     dnn_inputs = torch.Tensor(adata_X / norm_factor).split(config['transfer']['batch_size'])
148 |
149 |     # Inference with DNN model.
150 |     dnn_predictions = []
151 |     with torch.no_grad():
152 |         for batch_idx, inputs in enumerate(dnn_inputs):
153 |             inputs = inputs.to(device)
154 |             outputs = dnn_model(inputs)
155 |             dnn_predictions.append(outputs.detach().cpu().numpy())
156 |     label_names = checkpoint['label_names']
157 |     adata.obsm['pseudo_label'] = np.concatenate(dnn_predictions)
158 |     adata.obs['pseudo_class'] = pd.Categorical([label_names[i] for i in adata.obsm['pseudo_label'].argmax(1)])
159 |     adata.uns['pseudo_classes'] = label_names
160 |
161 |     # Compute accuracy (only for MERFISH).
162 |     if dataset == 'MERFISH':
163 |         indices = np.where(~adata.obs['subclass'].isin(['L4/5 IT', 'L6 IT Car3', 'PVM', 'other']))[0]
164 |         adjusted_pr = adata.obs['pseudo_class'].iloc[indices].to_numpy()
165 |         adjusted_gt = adata.obs['subclass'].iloc[indices].replace(['Micro'], ['Macrophage']).to_numpy()
166 |         acc = (adjusted_pr == adjusted_gt).sum() / len(indices) * 100.0
167 |         print('    %s Acc (transfer only): %.2f%%' % (data_name, acc))
168 |
169 |
170 |     ######################################
171 |     #      Part 4: Train GDAE model      #
172 |     ######################################
173 |
174 |     print('\n==> Model training...')
175 |     strings = [f'{k}={v}' for k, v in config['train'].items()]
176 |     print('    Parameters(%s)' % (', '.join(strings)))
177 |
178 |     # Normalize gene expression (library size to 1e4, log1p, then per-gene z-score).
179 |     sc.pp.normalize_total(adata, target_sum=1e4)
180 |     sc.pp.log1p(adata)
181 |     adata_X = (adata.X - adata.X.mean(0)) / (adata.X.std(0) + 1e-10)
182 |
183 |     # Construct spatial graph.
184 |     gene_mat = torch.Tensor(adata_X)
185 |     if dataset == 'Stereoseq':  # PCA
186 |         u, s, v = torch.pca_lowrank(gene_mat, config['train']['pca_dim'])
187 |         gene_mat = torch.matmul(gene_mat, v)
188 |     cell_coo = torch.Tensor(adata.obsm['spatial'])
189 |     data = torch_geometric.data.Data(x=gene_mat, pos=cell_coo)
190 |     data = torch_geometric.transforms.KNNGraph(k=config['train']['k_graph'], loop=True)(data)
191 |     data.y = torch.Tensor(adata.obsm['pseudo_label'])
192 |
193 |     # Use 1 - normalized distance as edge weights (closer neighbors weigh more).
194 |     if config['train']['edge_weight']:
195 |         data = torch_geometric.transforms.Distance()(data)
196 |         data.edge_weight = 1 - data.edge_attr[:,0]
197 |     else:
198 |         data.edge_weight = torch.ones(data.edge_index.size(1))
199 |
200 |     # Train self-supervised model.
201 |     input_dim = data.num_features
202 |     num_classes = len(adata.uns['pseudo_classes'])
203 |     trainer = SpatialModelTrainer(input_dim, num_classes, device, config['train'])
204 |     trainer.train(data, config['train'])
205 |     trainer.save_checkpoint(os.path.join(save_dir, f'{model_name}.t7'))
206 |
207 |     # Inference.
208 |     print('\n==> Inferencing...')
209 |     predictions = trainer.valid(data)
210 |     celltype_pred = pd.Categorical([adata.uns['pseudo_classes'][i] for i in predictions.argmax(1)])
211 |     if dataset == 'MERFISH':
212 |         indices = np.where(~adata.obs['subclass'].isin(['L4/5 IT', 'L6 IT Car3', 'PVM', 'other']))[0]
213 |         adjusted_pr = celltype_pred[indices].to_numpy()
214 |         adjusted_gt = adata.obs['subclass'].iloc[indices].replace(['Micro'], ['Macrophage']).to_numpy()
215 |         acc = (adjusted_pr == adjusted_gt).sum() / len(indices) * 100.0
216 |         print('    %s Acc (transfer+GDAE): %.2f%%' % (data_name, acc))
217 |
218 |     # Save results.
219 |     result = pd.DataFrame({'cell': adata.obs_names.tolist(), 'celltype_pred': celltype_pred})
220 |     result.to_csv(os.path.join(save_dir, f'{model_name}.csv'), index=False)
221 |     adata.obsm['celltype_prob'] = predictions
222 |     adata.obs['celltype_pred'] = pd.Categorical(celltype_pred)
223 |     if 'adata_X_sparse_backup' in locals():
224 |         adata.X = adata_X_sparse_backup
225 |     adata.write(os.path.join(save_dir, f'{model_name}.h5ad'))
226 |
227 |     # Save visualization.
228 |     spot_size = (30 if dataset == 'Stereoseq' else 20)
229 |     if dataset == 'Stereoseq':
230 |         pseudo_top100 = adata.obs['pseudo_class'].to_numpy()
231 |         other_classes = list(pd.value_counts(adata.obs['pseudo_class'])[100:].index)
232 |         pseudo_top100[adata.obs['pseudo_class'].isin(other_classes)] = '_Others'
233 |         adata.obs['pseudo_class'] = pd.Categorical(pseudo_top100)
234 |     # sc.pl.spatial(adata, img_key=None, color=['pseudo_class'], spot_size=spot_size, show=False)
235 |     # plt.savefig(os.path.join(save_dir, f'pseudo-{data_name}.pdf'), bbox_inches='tight', dpi=150)
236 |     sc.pl.spatial(adata, img_key=None, color=['celltype_pred'], spot_size=spot_size, show=False)
237 |     plt.savefig(os.path.join(save_dir, f'{model_name}.pdf'), bbox_inches='tight', dpi=150)
238 |     print('    Predictions are saved in', os.path.join(save_dir, f'{model_name}.csv/pdf'))
239 |
240 |
241 | if __name__ == '__main__':
242 |     data_list = ['mouse1_sample1', 'mouse1_sample2', 'mouse1_sample3', 'mouse1_sample4', 'mouse1_sample5', 'mouse1_sample6',
243 |                  'mouse2_sample1', 'mouse2_sample2', 'mouse2_sample3', 'mouse2_sample4', 'mouse2_sample5', 'mouse2_sample6']
244 |     parser = argparse.ArgumentParser()
245 |     parser.add_argument('--data_name', choices=data_list, required=True)
246 |     args = parser.parse_args()
247 |     spatial_classification_tool(config, args.data_name)
248 |

--------------------------------------------------------------------------------
/cell_type_annotation_for_nanostring.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import os
5 | import time
6 | import random
7 | import argparse
8 | import anndata
9 | import numpy as np
10 | import pandas as pd
11 | import scanpy as sc
12 | import matplotlib.pyplot as plt
13 | import torch
14 | import torch_geometric
15 |
16 | from cell_type_annotation_model import DNNModel, SpatialModelTrainer
17 |
18 |
19 | random.seed(0)
20 | np.random.seed(0)
21 | torch.manual_seed(0)
22 | torch.cuda.manual_seed(0)
23 | torch.cuda.manual_seed_all(0)
24 | torch.backends.cudnn.benchmark = False
25 | torch.backends.cudnn.deterministic = True
26 |
27 |
28 | config = {
29 |     'data': {
30 |         'data_dir': 'dataset/NanoString/',
31 |         'save_dir': 'result/NanoString/',
32 |         'dataset': 'NanoString',
33 |     },
34 |     'preprocess': {
35 |         'filter_mt': True,
36 |         'cell_min_counts': 300,
37 |         'gene_min_cells': 10,
38 |         'cell_max_counts_percent': 98.0,
39 |         'drop_rate': 0,
40 |     },
41 |     'transfer': {
42 |         'dnn_model': 'dnn_model/checkpoint_NanoString.t7',
43 |         'gpu': '0',
44 |         'batch_size': 4096,
45 |     },
46 |     'train': {
47 |         'pca_dim': 200,  # for Stereoseq only
48 |         'k_graph': 30,
49 |         'edge_weight': True,
50 |         'kd_T': 1,
51 |         'feat_dim': 64,
52 |         'w_dae': 1.0,
53 |         'w_gae': 1.0,
54 |         'w_cls': 10.0,
55 |         'epochs': 200,
56 |     }
57 | }
58 |
59 |
60 | def spatial_classification_tool(config, data_name):
61 |     ''' Spatial classification workflow.
62 |
63 |     # Arguments
64 |         config (Config): Configuration parameters.
65 |         data_name (str): Data name.
66 |     '''
67 |     ######################################
68 |     #         Part 1: Load data          #
69 |     ######################################
70 |
71 |     # Set path and load data.
72 |     print('\n==> Loading data...')
73 |     dataset = config['data']['dataset']
74 |     data_dir, save_dir = config['data']['data_dir'], config['data']['save_dir']
75 |     print(f'    Data name: {data_name} ({dataset})')
76 |     print(f'    Data path: {data_dir}')
77 |     print(f'    Save path: {save_dir}')
78 |     adata = sc.read_h5ad(os.path.join(data_dir, f'{data_name}.h5ad'))
79 |
80 |     # Initialize save path.
81 |     model_name = f'spatialID-{data_name}'
82 |     save_dir = os.path.join(save_dir, model_name)
83 |     if not os.path.exists(save_dir):
84 |         os.makedirs(save_dir)
85 |
86 |
87 |     ######################################
88 |     #         Part 2: Preprocess         #
89 |     ######################################
90 |
91 |     print('\n==> Preprocessing...')
92 |     strings = [f'{k}={v}' for k, v in config['preprocess'].items()]
93 |     print('    Parameters(%s)' % (', '.join(strings)))
94 |
95 |     # Preprocess data.
96 |     if dataset == 'Stereoseq':
97 |         params = config['preprocess']
98 |         if params['filter_mt']:
99 |             adata.var['mt'] = adata.var_names.str.startswith(('Mt-', 'mt-'))
100 |             sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], inplace=True)
101 |             adata = adata[adata.obs['pct_counts_mt'] < 10].copy()
102 |         if params['cell_min_counts'] > 0:
103 |             sc.pp.filter_cells(adata, min_counts=params['cell_min_counts'])
104 |         if params['gene_min_cells'] > 0:
105 |             sc.pp.filter_genes(adata, min_cells=params['gene_min_cells'])
106 |         if params['cell_max_counts_percent'] < 100:
107 |             max_counts = np.percentile(adata.X.sum(1), params['cell_max_counts_percent'])
108 |             sc.pp.filter_cells(adata, max_counts=max_counts)
109 |     if type(adata.X) != np.ndarray:
110 |         adata_X_sparse_backup = adata.X.copy()
111 |         adata.X = adata.X.toarray()
112 |     print('    %s: %d cells × %d genes.' % (data_name, adata.shape[0], adata.shape[1]))
113 |
114 |     # Please be aware:
115 |     #     The DNN model takes the original gene expression matrix as input and applies its own normalization.
116 |     #     Other normalization (e.g. scanpy) can be added after DNN model inference is completed.
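The optional noise step right after this comment block (lines 118 to 121) keeps each expression value only when a uniform draw from [0, 1) exceeds `drop_rate`, so roughly a `drop_rate` fraction of entries is zeroed; with the default `drop_rate` of 0 it is effectively a no-op. A pure-Python sketch of the same masking, where `simulate_dropout` is a hypothetical name and the scripts themselves use `np.random.random`:

```python
import random


def simulate_dropout(X, drop_rate, seed=0):
    """Zero out each entry independently with probability of about drop_rate.

    An entry survives when a uniform draw from [0, 1) is greater than
    drop_rate, mirroring (np.random.random(shape) > drop_rate) * 1. in the
    scripts. Returns a new nested list; X itself is not modified.
    """
    rng = random.Random(seed)  # local generator so the sketch is reproducible
    return [[value if rng.random() > drop_rate else 0.0 for value in row] for row in X]


X = [[1.0] * 100 for _ in range(100)]
noisy = simulate_dropout(X, 0.3)
print(sum(v == 0.0 for row in noisy for v in row) / 10000)  # close to 0.3
```

Because the comparison is strict, `drop_rate=1.0` zeroes every entry, while `drop_rate=0` only drops an entry if the draw is exactly 0.0.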
117 |
118 |     # Optionally add dropout noise (effectively a no-op when drop_rate is 0).
119 |     if dataset != 'Stereoseq':
120 |         drop_factor = (np.random.random(adata.shape) > config['preprocess']['drop_rate']) * 1.
121 |         adata.X = adata.X * drop_factor
122 |
123 |
124 |     ######################################
125 |     #  Part 3: Transfer from sc-dataset  #
126 |     ######################################
127 |
128 |     print('\n==> Transferring from sc-dataset...')
129 |     strings = [f'{k}={v}' for k, v in config['transfer'].items()]
130 |     print('    Parameters(%s)' % (', '.join(strings)))
131 |
132 |     # Set device.
133 |     os.environ['CUDA_VISIBLE_DEVICES'] = config['transfer']['gpu']
134 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
135 |
136 |     # Load the DNN model trained on the sc-dataset.
137 |     checkpoint = torch.load(config['transfer']['dnn_model'], map_location=device)
138 |     dnn_model = checkpoint['model'].to(device)
139 |     dnn_model.eval()
140 |
141 |     # Initialize DNN input: reorder columns to the marker genes (missing genes get index -1, i.e. the padded zero column), then L2-normalize each cell.
142 |     marker_genes = checkpoint['marker_genes']
143 |     gene_indices = adata.var_names.get_indexer(marker_genes)
144 |     adata_X = np.pad(adata.X, ((0,0),(0,1)))[:, gene_indices]
145 |     norm_factor = np.linalg.norm(adata_X, axis=1, keepdims=True)
146 |     norm_factor[norm_factor == 0] = 1
147 |     dnn_inputs = torch.Tensor(adata_X / norm_factor).split(config['transfer']['batch_size'])
148 |
149 |     # Inference with DNN model.
150 |     dnn_predictions = []
151 |     with torch.no_grad():
152 |         for batch_idx, inputs in enumerate(dnn_inputs):
153 |             inputs = inputs.to(device)
154 |             outputs = dnn_model(inputs)
155 |             dnn_predictions.append(outputs.detach().cpu().numpy())
156 |     label_names = checkpoint['label_names']
157 |     adata.obsm['pseudo_label'] = np.concatenate(dnn_predictions)
158 |     adata.obs['pseudo_class'] = pd.Categorical([label_names[i] for i in adata.obsm['pseudo_label'].argmax(1)])
159 |     adata.uns['pseudo_classes'] = label_names
160 |
161 |     # Compute accuracy (only for NanoString).
162 |     if dataset == 'NanoString':
163 |         indices = np.where(adata.obs['celltype_refined'].isin(label_names))[0]
164 |         adjusted_pr = adata.obs['pseudo_class'].iloc[indices].to_numpy()
165 |         adjusted_gt = adata.obs['celltype_refined'].iloc[indices].to_numpy()
166 |         acc = (adjusted_pr == adjusted_gt).sum() / len(indices) * 100.0
167 |         print('    %s Acc (transfer only): %.2f%%' % (data_name, acc))
168 |
169 |
170 |     ######################################
171 |     #      Part 4: Train GDAE model      #
172 |     ######################################
173 |
174 |     print('\n==> Model training...')
175 |     strings = [f'{k}={v}' for k, v in config['train'].items()]
176 |     print('    Parameters(%s)' % (', '.join(strings)))
177 |
178 |     # Normalize gene expression (library size to 1e4, log1p, then per-gene z-score).
179 |     sc.pp.normalize_total(adata, target_sum=1e4)
180 |     sc.pp.log1p(adata)
181 |     adata_X = (adata.X - adata.X.mean(0)) / (adata.X.std(0) + 1e-10)
182 |
183 |     # Construct spatial graph.
184 |     gene_mat = torch.Tensor(adata_X)
185 |     if dataset == 'Stereoseq':  # PCA
186 |         u, s, v = torch.pca_lowrank(gene_mat, config['train']['pca_dim'])
187 |         gene_mat = torch.matmul(gene_mat, v)
188 |     cell_coo = torch.Tensor(adata.obsm['spatial'])
189 |     data = torch_geometric.data.Data(x=gene_mat, pos=cell_coo)
190 |     data = torch_geometric.transforms.KNNGraph(k=config['train']['k_graph'], loop=True)(data)
191 |     data.y = torch.Tensor(adata.obsm['pseudo_label'])
192 |
193 |     # Use 1 - normalized distance as edge weights (closer neighbors weigh more).
194 |     if config['train']['edge_weight']:
195 |         data = torch_geometric.transforms.Distance()(data)
196 |         data.edge_weight = 1 - data.edge_attr[:,0]
197 |     else:
198 |         data.edge_weight = torch.ones(data.edge_index.size(1))
199 |
200 |     # Train self-supervised model.
201 |     input_dim = data.num_features
202 |     num_classes = len(adata.uns['pseudo_classes'])
203 |     trainer = SpatialModelTrainer(input_dim, num_classes, device, config['train'])
204 |     trainer.train(data, config['train'])
205 |     trainer.save_checkpoint(os.path.join(save_dir, f'{model_name}.t7'))
206 |
207 |     # Inference.
208 |     print('\n==> Inferencing...')
209 |     predictions = trainer.valid(data)
210 |     celltype_pred = pd.Categorical([adata.uns['pseudo_classes'][i] for i in predictions.argmax(1)])
211 |     if dataset == 'NanoString':
212 |         indices = np.where(adata.obs['celltype_refined'].isin(label_names))[0]
213 |         adjusted_pr = celltype_pred[indices].to_numpy()
214 |         adjusted_gt = adata.obs['celltype_refined'].iloc[indices].to_numpy()
215 |         acc = (adjusted_pr == adjusted_gt).sum() / len(indices) * 100.0
216 |         print(' %s Acc (transfer+GDAE): %.2f%%' % (data_name, acc))
217 |
218 |     # Save results.
219 |     result = pd.DataFrame({'cell': adata.obs_names.tolist(), 'celltype_pred': celltype_pred})
220 |     result.to_csv(os.path.join(save_dir, f'{model_name}.csv'), index=False)
221 |     adata.obsm['celltype_prob'] = predictions
222 |     adata.obs['celltype_pred'] = pd.Categorical(celltype_pred)
223 |     if 'adata_X_sparse_backup' in locals():
224 |         adata.X = adata_X_sparse_backup
225 |     adata.write(os.path.join(save_dir, f'{model_name}.h5ad'))
226 |
227 |     # Save visualization.
228 |     spot_size = (30 if dataset == 'Stereoseq' else 20)
229 |     if dataset == 'Stereoseq':
230 |         pseudo_top100 = adata.obs['pseudo_class'].to_numpy()
231 |         other_classes = list(adata.obs['pseudo_class'].value_counts().iloc[100:].index)
232 |         pseudo_top100[adata.obs['pseudo_class'].isin(other_classes)] = '_Others'
233 |         adata.obs['pseudo_class'] = pd.Categorical(pseudo_top100)
234 |     # sc.pl.spatial(adata, img_key=None, color=['pseudo_class'], spot_size=spot_size, show=False)
235 |     # plt.savefig(os.path.join(save_dir, f'pseudo-{data_name}.pdf'), bbox_inches='tight', dpi=150)
236 |     sc.pl.spatial(adata, img_key=None, color=['celltype_pred'], spot_size=spot_size, show=False)
237 |     plt.savefig(os.path.join(save_dir, f'{model_name}.pdf'), bbox_inches='tight', dpi=150)
238 |     print(' Predictions are saved in', os.path.join(save_dir, f'{model_name}.csv/pdf'))
239 |
240 |
241 | if __name__ == '__main__':
242 |     data_list = [f'nanostring_fov{i+1}_sampledata' for i in range(20)]
243 |     parser = argparse.ArgumentParser()
244 |     parser.add_argument('--data_name', choices=data_list, required=True)
245 |     args = parser.parse_args()
246 |     spatial_classification_tool(config, args.data_name)
247 |
--------------------------------------------------------------------------------
/cell_type_annotation_for_slideseq.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import os
5 | import time
6 | import random
7 | import argparse
8 | import anndata
9 | import numpy as np
10 | import pandas as pd
11 | import scanpy as sc
12 | import matplotlib.pyplot as plt
13 | import torch
14 | import torch_geometric
15 |
16 | from cell_type_annotation_model import DNNModel, SpatialModelTrainer
17 |
18 |
19 | random.seed(0)
20 | np.random.seed(0)
21 | torch.manual_seed(0)
22 | torch.cuda.manual_seed(0)
23 | torch.cuda.manual_seed_all(0)
24 | torch.backends.cudnn.benchmark = False
25 | torch.backends.cudnn.deterministic = True
26 |
27 |
28 | config = {
29 |     'data': {
30 |         'data_dir': 'dataset/Slide_seq/',
31 |         'save_dir': 'result/Slide_seq/',
32 |         'dataset': 'Slide_seq',
33 |         'radius_ratio': 1.0,
34 |     },
35 |     'preprocess': {
36 |         'filter_mt': True,
37 |         'cell_min_counts': 300,
38 |         'gene_min_cells': 10,
39 |         'cell_max_counts_percent': 98.0,
40 |         'drop_rate': 0,
41 |     },
42 |     'transfer': {
43 |         'dnn_model': 'dnn_model/checkpoint_Slide-seq_DM1.t7',
44 |         'gpu': '0',
45 |         'batch_size': 4096,
46 |     },
47 |     'train': {
48 |         'pca_dim': 200,  # used for Stereoseq and Slide_seq only
49 |         'k_graph': 30,
50 |         'edge_weight': True,
51 |         'kd_T': 1,
52 |         'feat_dim': 64,
53 |         'w_dae': 1.0,
54 |         'w_gae': 10.0,
55 |         'w_cls': 10.0,
56 |         'epochs': 200,
57 |     }
58 | }
59 |
60 |
61 | def spatial_classification_tool(config, data_name):
62 |     ''' Spatial classification workflow.
63 |
64 |     # Arguments
65 |         config (Config): Configuration parameters.
66 |         data_name (str): Data name.
67 |     '''
68 |     ######################################
69 |     #         Part 1: Load data          #
70 |     ######################################
71 |
72 |     # Set path and load data.
73 |     print('\n==> Loading data...')
74 |     dataset = config['data']['dataset']
75 |     data_dir, save_dir = config['data']['data_dir'], config['data']['save_dir']
76 |     print(f' Data name: {data_name} ({dataset})')
77 |     print(f' Data path: {data_dir}')
78 |     print(f' Save path: {save_dir}')
79 |     adata = sc.read_h5ad(os.path.join(data_dir, f'{data_name}.h5ad'))
80 |
81 |     # Optionally keep only a central disc of cells (used for runtime tests).
82 |     if config['data']['radius_ratio'] < 1:
83 |         adata.obsm['spatial'] = adata.obsm['spatial'] - adata.obsm['spatial'].mean(0)
84 |         radius = adata.obsm['spatial'].max(0).mean() * config['data']['radius_ratio']
85 |         adata = adata[(adata.obsm['spatial']**2).sum(1) < radius**2].copy()
86 |
87 |     # Initialize save path.
88 |     model_name = f'spatialID-{data_name}'
89 |     save_dir = os.path.join(save_dir, model_name)
90 |     if not os.path.exists(save_dir):
91 |         os.makedirs(save_dir)
92 |
93 |
94 |     ######################################
95 |     #         Part 2: Preprocess         #
96 |     ######################################
97 |
98 |     print('\n==> Preprocessing...')
99 |     strings = [f'{k}={v}' for k, v in config['preprocess'].items()]
100 |     print(' Parameters(%s)' % (', '.join(strings)))
101 |
102 |     # Preprocess data.
103 |     if dataset == 'Stereoseq':
104 |         params = config['preprocess']
105 |         if params['filter_mt']:
106 |             adata.var['mt'] = adata.var_names.str.startswith(('Mt-', 'mt-'))
107 |             sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], inplace=True)
108 |             adata = adata[adata.obs['pct_counts_mt'] < 10].copy()
109 |         if params['cell_min_counts'] > 0:
110 |             sc.pp.filter_cells(adata, min_counts=params['cell_min_counts'])
111 |         if params['gene_min_cells'] > 0:
112 |             sc.pp.filter_genes(adata, min_cells=params['gene_min_cells'])
113 |         if params['cell_max_counts_percent'] < 100:
114 |             max_counts = np.percentile(adata.X.sum(1), params['cell_max_counts_percent'])
115 |             sc.pp.filter_cells(adata, max_counts=max_counts)
116 |     if not isinstance(adata.X, np.ndarray):
117 |         adata_X_sparse_backup = adata.X.copy()
118 |         adata.X = adata.X.toarray()
119 |     print(' %s: %d cells × %d genes.' % (data_name, adata.shape[0], adata.shape[1]))
120 |
121 |     # Please be aware:
122 |     #     The DNN model takes the original (raw) gene expression matrix as input and applies its own normalization.
123 |     #     Other normalization (e.g. scanpy) can be added after DNN model inference is completed.
124 |
125 |     # Optionally add dropout noise: each entry is zeroed with probability drop_rate.
126 |     if dataset != 'Stereoseq':
127 |         drop_factor = (np.random.random(adata.shape) > config['preprocess']['drop_rate']) * 1.
128 |         adata.X = adata.X * drop_factor
129 |
130 |
131 |     ######################################
132 |     #  Part 3: Transfer from sc-dataset  #
133 |     ######################################
134 |
135 |     print('\n==> Transferring from sc-dataset...')
136 |     strings = [f'{k}={v}' for k, v in config['transfer'].items()]
137 |     print(' Parameters(%s)' % (', '.join(strings)))
138 |     time1 = time.time()
139 |
140 |     # Set device.
141 |     os.environ['CUDA_VISIBLE_DEVICES'] = config['transfer']['gpu']
142 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
143 |
144 |     # Load the DNN model trained on the reference scRNA-seq dataset (map_location allows CPU-only runs).
145 |     checkpoint = torch.load(config['transfer']['dnn_model'], map_location=device)
146 |     dnn_model = checkpoint['model'].to(device)
147 |     dnn_model.eval()
148 |
149 |     # Build DNN input: align genes to the model's marker genes, then L2-normalize each cell.
150 |     marker_genes = checkpoint['marker_genes']
151 |     gene_indices = adata.var_names.get_indexer(marker_genes)
152 |     adata_X = np.pad(adata.X, ((0,0),(0,1)))[:, gene_indices]
153 |     if dataset == 'Slide_seq':
154 |         adata_X = np.log1p(adata_X)
155 |     norm_factor = np.linalg.norm(adata_X, axis=1, keepdims=True)
156 |     norm_factor[norm_factor == 0] = 1
157 |     dnn_inputs = torch.Tensor(adata_X / norm_factor).split(config['transfer']['batch_size'])
158 |
159 |     # Inference with DNN model.
160 |     dnn_predictions = []
161 |     with torch.no_grad():
162 |         for batch_idx, inputs in enumerate(dnn_inputs):
163 |             inputs = inputs.to(device)
164 |             outputs = dnn_model(inputs)
165 |             dnn_predictions.append(outputs.detach().cpu().numpy())
166 |     label_names = checkpoint['label_names']
167 |     adata.obsm['pseudo_label'] = np.concatenate(dnn_predictions)
168 |     adata.obs['pseudo_class'] = pd.Categorical([label_names[i] for i in adata.obsm['pseudo_label'].argmax(1)])
169 |     adata.uns['pseudo_classes'] = label_names
170 |
171 |     # Compute accuracy (only for Slide-seq).
172 |     if dataset == 'Slide_seq':
173 |         adjusted_pr = adata.obs['pseudo_class'].to_numpy()
174 |         adjusted_gt = adata.obs['ct_name'].replace(
175 |             ['ES', 'Differentiating SPG', 'RS', 'SPC'], ['Elongating', 'SPG', 'STids', 'Scytes']).to_numpy()
176 |         acc = (adjusted_pr == adjusted_gt).sum() / len(adjusted_gt) * 100.0
177 |         print(' %s Acc (transfer only): %.2f%%' % (data_name, acc))
178 |
179 |
180 |     ######################################
181 |     #      Part 4: Train GDAE model      #
182 |     ######################################
183 |
184 |     print('\n==> Model training...')
185 |     strings = [f'{k}={v}' for k, v in config['train'].items()]
186 |     print(' Parameters(%s)' % (', '.join(strings)))
187 |     time2 = time.time()
188 |
189 |     # Normalize gene expression.
190 |     sc.pp.normalize_total(adata, target_sum=1e4)
191 |     sc.pp.log1p(adata)
192 |     adata_X = adata.X - adata.X.mean(0)
193 |     adata_X = adata_X / (adata.X.std(0) + 1e-10)
194 |
195 |     # Construct spatial graph.
196 |     gene_mat = torch.Tensor(adata_X)
197 |     if dataset in ('Stereoseq', 'Slide_seq'):  # PCA
198 |         u, s, v = torch.pca_lowrank(gene_mat, config['train']['pca_dim'])
199 |         gene_mat = torch.matmul(gene_mat, v)
200 |     cell_coo = torch.Tensor(adata.obsm['spatial'])
201 |     data = torch_geometric.data.Data(x=gene_mat, pos=cell_coo)
202 |     data = torch_geometric.transforms.KNNGraph(k=config['train']['k_graph'], loop=True)(data)
203 |     data.y = torch.Tensor(adata.obsm['pseudo_label'])
204 |
205 |     # Convert normalized distances into edge weights (nearer neighbors weigh more).
206 |     if config['train']['edge_weight']:
207 |         data = torch_geometric.transforms.Distance()(data)
208 |         data.edge_weight = 1 - data.edge_attr[:,0]
209 |     else:
210 |         data.edge_weight = torch.ones(data.edge_index.size(1))
211 |
212 |     # Train self-supervision model.
213 |     input_dim = data.num_features
214 |     num_classes = len(adata.uns['pseudo_classes'])
215 |     trainer = SpatialModelTrainer(input_dim, num_classes, device, config['train'])
216 |     trainer.train(data, config['train'])
217 |     trainer.save_checkpoint(os.path.join(save_dir, f'{model_name}.t7'))
218 |     time3 = time.time()
219 |
220 |     # Inference.
221 |     print('\n==> Inferencing...')
222 |     predictions = trainer.valid(data)
223 |     celltype_pred = pd.Categorical([adata.uns['pseudo_classes'][i] for i in predictions.argmax(1)])
224 |     if dataset == 'Slide_seq':
225 |         adjusted_pr = celltype_pred.to_numpy()
226 |         adjusted_gt = adata.obs['ct_name'].replace(
227 |             ['ES', 'Differentiating SPG', 'RS', 'SPC'], ['Elongating', 'SPG', 'STids', 'Scytes']).to_numpy()
228 |         acc = (adjusted_pr == adjusted_gt).sum() / len(adjusted_gt) * 100.0
229 |         print(' %s Acc (transfer+GDAE): %.2f%%' % (data_name, acc))
230 |
231 |     # Compute time cost.
232 |     time4 = time.time()
233 |     print(' S1 infer time: %.2fs' % (time2 - time1))
234 |     print(' S2 train time: %.2fs' % (time3 - time2))
235 |     print(' S2 infer time: %.2fs' % (time4 - time3))
236 |
237 |     # Save results.
238 |     result = pd.DataFrame({'cell': adata.obs_names.tolist(), 'celltype_pred': celltype_pred})
239 |     result.to_csv(os.path.join(save_dir, f'{model_name}.csv'), index=False)
240 |     adata.obsm['celltype_prob'] = predictions
241 |     adata.obs['celltype_pred'] = pd.Categorical(celltype_pred)
242 |     if 'adata_X_sparse_backup' in locals():
243 |         adata.X = adata_X_sparse_backup
244 |     adata.write(os.path.join(save_dir, f'{model_name}.h5ad'))
245 |
246 |     # Save visualization.
247 |     spot_size = (30 if dataset == 'Stereoseq' else 20)
248 |     if dataset == 'Stereoseq':
249 |         pseudo_top100 = adata.obs['pseudo_class'].to_numpy()
250 |         other_classes = list(adata.obs['pseudo_class'].value_counts().iloc[100:].index)
251 |         pseudo_top100[adata.obs['pseudo_class'].isin(other_classes)] = '_Others'
252 |         adata.obs['pseudo_class'] = pd.Categorical(pseudo_top100)
253 |     # sc.pl.spatial(adata, img_key=None, color=['pseudo_class'], spot_size=spot_size, show=False)
254 |     # plt.savefig(os.path.join(save_dir, f'pseudo-{data_name}.pdf'), bbox_inches='tight', dpi=150)
255 |     sc.pl.spatial(adata, img_key=None, color=['celltype_pred'], spot_size=spot_size, show=False)
256 |     plt.savefig(os.path.join(save_dir, f'{model_name}.pdf'), bbox_inches='tight', dpi=150)
257 |     print(' Predictions are saved in', os.path.join(save_dir, f'{model_name}.csv/pdf'))
258 |
259 |
260 | if __name__ == '__main__':
261 |     data_list = ['DM1', 'DM2', 'DM3', 'WT1', 'WT2', 'WT3']
262 |     parser = argparse.ArgumentParser()
263 |     parser.add_argument('--data_name', choices=data_list, required=True)
264 |     args = parser.parse_args()
265 |     spatial_classification_tool(config, args.data_name)
266 |
--------------------------------------------------------------------------------
/cell_type_annotation_model.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/cell_type_annotation_model.pyc
--------------------------------------------------------------------------------
/dnn_model/checkpoint_Hyp-3D_b.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/dnn_model/checkpoint_Hyp-3D_b.t7
--------------------------------------------------------------------------------
/dnn_model/checkpoint_MERFISH_s.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/dnn_model/checkpoint_MERFISH_s.t7
--------------------------------------------------------------------------------
/dnn_model/checkpoint_NanoString.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/dnn_model/checkpoint_NanoString.t7
--------------------------------------------------------------------------------
/dnn_model/checkpoint_Slide-seq_DM1.z01:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/dnn_model/checkpoint_Slide-seq_DM1.z01
--------------------------------------------------------------------------------
/dnn_model/checkpoint_Slide-seq_DM1.z02:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/dnn_model/checkpoint_Slide-seq_DM1.z02
--------------------------------------------------------------------------------
/dnn_model/checkpoint_Slide-seq_DM1.z03:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/dnn_model/checkpoint_Slide-seq_DM1.z03
--------------------------------------------------------------------------------
/dnn_model/checkpoint_Slide-seq_DM1.z04:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/dnn_model/checkpoint_Slide-seq_DM1.z04
--------------------------------------------------------------------------------
/dnn_model/checkpoint_Slide-seq_DM1.z05:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/dnn_model/checkpoint_Slide-seq_DM1.z05
--------------------------------------------------------------------------------
/dnn_model/checkpoint_Slide-seq_DM1.z06:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/dnn_model/checkpoint_Slide-seq_DM1.z06
--------------------------------------------------------------------------------
/dnn_model/checkpoint_Slide-seq_DM1.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/dnn_model/checkpoint_Slide-seq_DM1.zip
--------------------------------------------------------------------------------
/spatialID_example_m1s1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/spatialID_example_m1s1.png
--------------------------------------------------------------------------------
/spatialID_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/spatialID_overview.png
--------------------------------------------------------------------------------
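The DNN input step in `cell_type_annotation_for_slideseq.py` (lines 149-157) relies on a padding trick: `adata.var_names.get_indexer(marker_genes)` returns -1 for marker genes that are absent from the dataset, and because one all-zero column is appended with `np.pad` first, index -1 selects zeros. Each cell is then L2-normalized, with a guard against all-zero cells. The following is a minimal NumPy sketch of that step, not the repository's code. The function name `build_dnn_inputs` is hypothetical, and a plain dictionary stands in for `pandas.Index.get_indexer` (which likewise returns -1 for missing labels).

```python
import numpy as np


def build_dnn_inputs(X, var_names, marker_genes, log_transform=False):
    """Hypothetical helper mirroring the DNN input construction step (a sketch)."""
    X = np.asarray(X, dtype=float)
    # Stand-in for pandas.Index.get_indexer: column of each marker gene, -1 if absent.
    position = {gene: i for i, gene in enumerate(var_names)}
    gene_indices = np.array([position.get(gene, -1) for gene in marker_genes], dtype=int)
    # Append one all-zero column, so index -1 selects zeros for missing marker genes.
    X_aligned = np.pad(X, ((0, 0), (0, 1)))[:, gene_indices]
    if log_transform:  # the script applies log1p only when dataset == 'Slide_seq'
        X_aligned = np.log1p(X_aligned)
    # L2-normalize each cell; all-zero cells keep norm 1 to avoid dividing by zero.
    norm = np.linalg.norm(X_aligned, axis=1, keepdims=True)
    norm[norm == 0] = 1
    return X_aligned / norm
```

For example, with `var_names = ['GeneA', 'GeneB']` and `marker_genes = ['GeneB', 'GeneX', 'GeneA']`, a cell with counts `[3, 4]` becomes `[0.8, 0.0, 0.6]`, and a cell with no counts stays all zero instead of producing NaNs.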
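When `edge_weight` is enabled, the scripts build a k-nearest-neighbor graph on the spatial coordinates with `KNNGraph(k=..., loop=True)`, compute edge lengths with `torch_geometric.transforms.Distance()`, and use `1 - edge_attr[:, 0]` as the edge weight. With its default `norm=True`, `Distance` divides each length by the largest one in the graph, so weights fall in [0, 1] and self-loops get weight 1. The sketch below reproduces that arithmetic in NumPy on toy coordinates. `knn_edge_weights` is a hypothetical helper; it assumes, as in PyG's `knn_graph`, that with `loop=True` a cell counts as one of its own k neighbors, uses a brute-force O(n²) distance matrix, and does not reproduce PyG's edge ordering or direction conventions.

```python
import numpy as np


def knn_edge_weights(coords, k):
    """Hypothetical helper: kNN graph with weights 1 - d / d_max (self-loops included)."""
    coords = np.asarray(coords, dtype=float)
    # Brute-force pairwise Euclidean distances; O(n^2) memory, fine for a toy example.
    diff = coords[:, None, :] - coords[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    # For each cell keep its k closest cells; the cell itself (distance 0) is one of them.
    nbrs = np.argsort(dist, axis=1, kind="stable")[:, :k]
    src = np.repeat(np.arange(len(coords)), k)
    dst = nbrs.ravel()
    d = dist[src, dst]
    # Scale by the longest edge so lengths lie in [0, 1], then turn length into weight.
    d_max = d.max()
    if d_max == 0:
        d_max = 1.0
    return src, dst, 1.0 - d / d_max
```

For three cells at x = 0, 1 and 3 with `k=2`, the longest edge has length 2, so the edges (0→0, 0→1, 1→1, 1→0, 2→2, 2→1) get weights 1, 0.5, 1, 0.5, 1 and 0: the farthest kept neighbor contributes nothing, which is one consequence of normalizing by the global maximum.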