├── README.md
├── cell_type_annotation_for_hyp3d.py
├── cell_type_annotation_for_merfish.py
├── cell_type_annotation_for_nanostring.py
├── cell_type_annotation_for_slideseq.py
├── cell_type_annotation_model.pyc
├── dnn_model
│   ├── checkpoint_Hyp-3D_b.t7
│   ├── checkpoint_MERFISH_s.t7
│   ├── checkpoint_NanoString.t7
│   ├── checkpoint_Slide-seq_DM1.z01
│   ├── checkpoint_Slide-seq_DM1.z02
│   ├── checkpoint_Slide-seq_DM1.z03
│   ├── checkpoint_Slide-seq_DM1.z04
│   ├── checkpoint_Slide-seq_DM1.z05
│   ├── checkpoint_Slide-seq_DM1.z06
│   └── checkpoint_Slide-seq_DM1.zip
├── spatialID_example_m1s1.png
└── spatialID_overview.png
/README.md:
--------------------------------------------------------------------------------
1 | # spatial-ID
2 |
3 | [Python](https://www.python.org/)
4 |
5 | ### Spatial-ID: a cell typing method for spatially resolved transcriptomics via transfer learning and spatial embedding
6 | Spatially resolved transcriptomics (SRT) provides the opportunity to investigate the gene expression profiles and the spatial context of cells in their native state. Cell type annotation is a crucial task in the spatial transcriptome analysis of cell and tissue biology. In this study, we propose Spatial-ID, a supervision-based cell typing method for high-throughput cell-level SRT datasets that integrates transfer learning and spatial embedding. Spatial-ID effectively incorporates the existing knowledge of reference scRNA-seq datasets and the spatial information of SRT datasets.
7 |
8 |
9 |
10 | # Dependencies
11 |
12 | - [NumPy](https://github.com/numpy/numpy)
13 | - [pandas](https://github.com/pandas-dev/pandas)
14 | - [Scanpy](https://github.com/theislab/scanpy)
15 | - [PyTorch](https://github.com/pytorch/pytorch)
16 | - [PyTorch Geometric](https://github.com/pyg-team/pytorch_geometric/)
17 |
18 | # Datasets
19 |
20 | - MERFISH: 280,186 cells × 254 genes, 12 samples. [https://doi.brainimagelibrary.org/doi/10.35077/g.21](https://doi.brainimagelibrary.org/doi/10.35077/g.21)
21 | - MERFISH-3D: 213,192 cells × 155 genes, 3 samples. [https://datadryad.org/stash/dataset/doi:10.5061/dryad.8t8s248](https://datadryad.org/stash/dataset/doi:10.5061/dryad.8t8s248)
22 | - Slide-seq: 207,335 cells × 27,181 genes, 6 samples. [https://www.dropbox.com/s/ygzpj0d0oh67br0/Testis_Slideseq_Data.zip?dl=0](https://www.dropbox.com/s/ygzpj0d0oh67br0/Testis_Slideseq_Data.zip?dl=0)
23 | - NanoString: 83,621 cells × 980 genes, 20 samples. [https://nanostring.com/resources/smi-ffpe-dataset-lung9-rep1-data/](https://nanostring.com/resources/smi-ffpe-dataset-lung9-rep1-data/)
24 |
25 | # Usage
26 |
27 | - Run cell\_type\_annotation\_for\_merfish.py to annotate cells in MERFISH dataset.
28 | - Run cell\_type\_annotation\_for\_hyp3d.py to annotate cells in MERFISH-3D dataset.
29 | - Run cell\_type\_annotation\_for\_slideseq.py to annotate cells in Slide-seq dataset.
30 | - Run cell\_type\_annotation\_for\_nanostring.py to annotate cells in NanoString dataset.
31 |
32 | p.s. You may need to unzip the split archive dnn\_model/checkpoint\_Slide-seq\_DM1.zip (parts .z01 to .z06) to obtain dnn\_model/checkpoint\_Slide-seq\_DM1.t7 before running cell\_type\_annotation\_for\_slideseq.py.
33 |
34 | [!!!] Note: An AttributeError saying that 'GELU' object has no attribute 'approximate' may occur if your PyTorch version is higher than 1.10.0. You can simply downgrade PyTorch to 1.8.1 or temporarily modify the PyTorch source code.
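If downgrading is not an option, a workaround that avoids editing the PyTorch source is to add the missing attribute to the GELU modules after the checkpoint is loaded. The sketch below is not part of this repository; the helper name `patch_legacy_gelu` is hypothetical, and it assumes the checkpoint stores a whole pickled model (as `checkpoint['model']` in the scripts suggests) and that the exact, erf-based GELU (`'none'`) is what the older PyTorch used when the checkpoint was saved.

```python
import torch
import torch.nn as nn


def patch_legacy_gelu(model: nn.Module) -> nn.Module:
    """Give GELU modules unpickled from older PyTorch versions the
    `approximate` attribute that newer versions read in forward()."""
    for module in model.modules():
        if isinstance(module, nn.GELU) and not hasattr(module, "approximate"):
            module.approximate = "none"  # exact (erf-based) GELU
    return model


if __name__ == "__main__":
    # Simulate a legacy module by removing the attribute, then repair it.
    legacy = nn.Sequential(nn.Linear(4, 4), nn.GELU())
    legacy[1].__dict__.pop("approximate", None)
    patch_legacy_gelu(legacy)
    print(legacy(torch.zeros(2, 4)).shape)
```

With such a helper, the loading line in Part 3 of each script could become `dnn_model = patch_legacy_gelu(checkpoint['model']).to(device)`.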
35 |
36 | # Example
37 |
38 | 1. Put the downloaded MERFISH data (e.g. mouse1_sample1.h5ad) in "dataset/MERFISH/" (as in Line 30 of cell\_type\_annotation\_for\_merfish.py).
39 | 2. Run `python cell_type_annotation_for_merfish.py --data_name mouse1_sample1` to annotate the cells of mouse1_sample1.
40 | 3. Four files are written to "result/MERFISH/spatialID-mouse1\_sample1/" (the save path in Line 31 of cell\_type\_annotation\_for\_merfish.py, plus a subfolder named after the model):
41 |    - spatialID-mouse1\_sample1.t7: Checkpoint of the self-supervised model in Stage 2.
42 |    - spatialID-mouse1\_sample1.h5ad: Updated H5AD file with the annotation results stored.
43 |    - spatialID-mouse1\_sample1.csv: Annotation results, with column "cell" holding cell IDs and "celltype_pred" holding the annotated cell types.
44 |    - spatialID-mouse1\_sample1.pdf: Visualization of the annotation results, as shown below.
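The CSV can be summarized without loading the H5AD file. Below is a minimal sketch using only the Python standard library, assuming the two columns described above ("cell" and "celltype_pred"). The function name `summarize_predictions` is hypothetical, and the example path follows the scripts, which join the configured save path with the model name, i.e. `result/MERFISH/spatialID-mouse1_sample1/`.

```python
import csv
from collections import Counter
from pathlib import Path


def summarize_predictions(csv_path):
    """Return (cell type, number of cells) pairs, most frequent first."""
    with open(csv_path, newline="") as handle:
        reader = csv.DictReader(handle)
        counts = Counter(row["celltype_pred"] for row in reader)
    return counts.most_common()


if __name__ == "__main__":
    # Example path; adjust to the sample you annotated.
    path = Path("result/MERFISH/spatialID-mouse1_sample1/spatialID-mouse1_sample1.csv")
    if path.exists():
        for celltype, n_cells in summarize_predictions(path):
            print(f"{celltype}\t{n_cells}")
```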
45 |
46 |
47 |
48 | # Disclaimer
49 |
50 | This tool is for research purposes only and is not approved for clinical use.
51 |
52 | This is not an official Tencent product.
53 |
54 | # Copyright
55 |
56 | This tool is developed in Tencent AI Lab.
57 |
58 | The copyright holder for this project is Tencent AI Lab.
59 |
60 | All rights reserved.
61 |
--------------------------------------------------------------------------------
/cell_type_annotation_for_hyp3d.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import os
5 | import time
6 | import random
7 | import argparse
8 | import anndata
9 | import numpy as np
10 | import pandas as pd
11 | import scanpy as sc
12 | import matplotlib.pyplot as plt
13 | import torch
14 | import torch_geometric
15 |
16 | from cell_type_annotation_model import DNNModel, SpatialModelTrainer
17 |
18 |
19 | random.seed(0)
20 | np.random.seed(0)
21 | torch.manual_seed(0)
22 | torch.cuda.manual_seed(0)
23 | torch.cuda.manual_seed_all(0)
24 | torch.backends.cudnn.benchmark = False
25 | torch.backends.cudnn.deterministic = True
26 |
27 |
28 | config = {
29 |     'data': {
30 |         'data_dir': 'dataset/Hyp_3D/',
31 |         'save_dir': 'result/Hyp_3D/',
32 |         'dataset': 'Hyp_3D',
33 |     },
34 |     'preprocess': {
35 |         'filter_mt': True,
36 |         'cell_min_counts': 300,
37 |         'gene_min_cells': 10,
38 |         'cell_max_counts_percent': 98.0,
39 |         'drop_rate': 0,
40 |     },
41 |     'transfer': {
42 |         'dnn_model': 'dnn_model/checkpoint_Hyp-3D_b.t7',
43 |         'gpu': '0',
44 |         'batch_size': 4096,
45 |     },
46 |     'train': {
47 |         'pca_dim': 200,  # for Stereoseq only
48 |         'k_graph': 30,
49 |         'edge_weight': True,
50 |         'kd_T': 1,
51 |         'feat_dim': 64,
52 |         'w_dae': 1.0,
53 |         'w_gae': 1.0,
54 |         'w_cls': 10.0,
55 |         'epochs': 200,
56 |     }
57 | }
58 |
59 |
60 | def spatial_classification_tool(config, data_name):
61 |     ''' Spatial classification workflow.
62 |
63 |     # Arguments
64 |         config (dict): Configuration parameters.
65 |         data_name (str): Data name.
66 |     '''
67 |     ######################################
68 |     # Part 1: Load data #
69 |     ######################################
70 |
71 |     # Set path and load data.
72 |     print('\n==> Loading data...')
73 |     dataset = config['data']['dataset']
74 |     data_dir, save_dir = config['data']['data_dir'], config['data']['save_dir']
75 |     print(f' Data name: {data_name} ({dataset})')
76 |     print(f' Data path: {data_dir}')
77 |     print(f' Save path: {save_dir}')
78 |     adata = sc.read_h5ad(os.path.join(data_dir, f'{data_name}.h5ad'))
79 |
80 |     # Initialize save path.
81 |     model_name = f'spatialID-{data_name}'
82 |     save_dir = os.path.join(save_dir, model_name)
83 |     if not os.path.exists(save_dir):
84 |         os.makedirs(save_dir)
85 |
86 |
87 |     ######################################
88 |     # Part 2: Preprocess #
89 |     ######################################
90 |
91 |     print('\n==> Preprocessing...')
92 |     strings = [f'{k}={v}' for k, v in config['preprocess'].items()]
93 |     print(' Parameters(%s)' % (', '.join(strings)))
94 |
95 |     # Preprocess data (QC filtering is only applied to Stereo-seq data).
96 |     if dataset == 'Stereoseq':
97 |         params = config['preprocess']
98 |         if params['filter_mt']:
99 |             adata.var['mt'] = adata.var_names.str.startswith(('Mt-', 'mt-'))
100 |             sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], inplace=True)
101 |             adata = adata[adata.obs['pct_counts_mt'] < 10].copy()
102 |         if params['cell_min_counts'] > 0:
103 |             sc.pp.filter_cells(adata, min_counts=params['cell_min_counts'])
104 |         if params['gene_min_cells'] > 0:
105 |             sc.pp.filter_genes(adata, min_cells=params['gene_min_cells'])
106 |         if params['cell_max_counts_percent'] < 100:
107 |             max_counts = np.percentile(adata.X.sum(1), params['cell_max_counts_percent'])
108 |             sc.pp.filter_cells(adata, max_counts=max_counts)
109 |     if type(adata.X) != np.ndarray:
110 |         adata_X_sparse_backup = adata.X.copy()
111 |         adata.X = adata.X.toarray()
112 |     print(' %s: %d cells × %d genes.' % (data_name, adata.shape[0], adata.shape[1]))
113 |
114 |     # Please be aware:
115 |     # The DNN model takes the raw gene expression matrix as input and applies its own normalization.
116 |     # Any other normalization (e.g. with scanpy) should only be applied after DNN inference is complete.
117 |
118 |     # Simulate dropout noise: zero out each entry with probability drop_rate (0 disables it).
119 |     if dataset != 'Stereoseq':
120 |         drop_factor = (np.random.random(adata.shape) > config['preprocess']['drop_rate']) * 1.
121 |         adata.X = adata.X * drop_factor
122 |
123 |
124 |     ######################################
125 |     # Part 3: Transfer from sc-dataset #
126 |     ######################################
127 |
128 |     print('\n==> Transferring from sc-dataset...')
129 |     strings = [f'{k}={v}' for k, v in config['transfer'].items()]
130 |     print(' Parameters(%s)' % (', '.join(strings)))
131 |
132 |     # Set device.
133 |     os.environ['CUDA_VISIBLE_DEVICES'] = config['transfer']['gpu']
134 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
135 |
136 |     # Load the DNN model trained on the sc-dataset (map_location lets GPU checkpoints load on CPU).
137 |     checkpoint = torch.load(config['transfer']['dnn_model'], map_location=device)
138 |     dnn_model = checkpoint['model'].to(device)
139 |     dnn_model.eval()
140 |
141 |     # Initialize DNN input (marker genes missing from the data get index -1, i.e. the zero column added by np.pad).
142 |     marker_genes = checkpoint['marker_genes']
143 |     gene_indices = adata.var_names.get_indexer(marker_genes)
144 |     adata_X = np.pad(adata.X, ((0,0),(0,1)))[:, gene_indices]
145 |     norm_factor = np.linalg.norm(adata_X, axis=1, keepdims=True)
146 |     norm_factor[norm_factor == 0] = 1
147 |     dnn_inputs = torch.Tensor(adata_X / norm_factor).split(config['transfer']['batch_size'])
148 |
149 |     # Inference with DNN model.
150 |     dnn_predictions = []
151 |     with torch.no_grad():
152 |         for batch_idx, inputs in enumerate(dnn_inputs):
153 |             inputs = inputs.to(device)
154 |             outputs = dnn_model(inputs)
155 |             dnn_predictions.append(outputs.detach().cpu().numpy())
156 |     label_names = checkpoint['label_names']
157 |     adata.obsm['pseudo_label'] = np.concatenate(dnn_predictions)
158 |     adata.obs['pseudo_class'] = pd.Categorical([label_names[i] for i in adata.obsm['pseudo_label'].argmax(1)])
159 |     adata.uns['pseudo_classes'] = label_names
160 |
161 |     # Compute accuracy (only for Hyp_3D).
162 |     if dataset == 'Hyp_3D':
163 |         indices = np.where(~adata.obs['Cell_class'].isin(['Ambiguous']))[0]
164 |         adjusted_pr = adata.obs['pseudo_class'].iloc[indices].to_numpy()
165 |         adjusted_gt = adata.obs['Cell_class'].iloc[indices].replace(
166 |             ['Endothelial 1', 'Endothelial 2', 'Endothelial 3',
167 |              'OD Immature 1', 'OD Immature 2',
168 |              'OD Mature 1', 'OD Mature 2', 'OD Mature 3', 'OD Mature 4',
169 |              'Astrocyte', 'Pericytes'],
170 |             ['Endothelial', 'Endothelial', 'Endothelial',
171 |              'Immature oligodendrocyte', 'Immature oligodendrocyte',
172 |              'Mature oligodendrocyte', 'Mature oligodendrocyte', 'Mature oligodendrocyte', 'Mature oligodendrocyte',
173 |              'Astrocytes', 'Mural']).to_numpy()
174 |         acc = (adjusted_pr == adjusted_gt).sum() / len(indices) * 100.0
175 |         print(' %s Acc (transfer only): %.2f%%' % (data_name, acc))
176 |
177 |
178 |     ######################################
179 |     # Part 4: Train GDAE model #
180 |     ######################################
181 |
182 |     print('\n==> Model training...')
183 |     strings = [f'{k}={v}' for k, v in config['train'].items()]
184 |     print(' Parameters(%s)' % (', '.join(strings)))
185 |
186 |     # Normalize gene expression.
187 |     sc.pp.normalize_total(adata, target_sum=1e4)
188 |     sc.pp.log1p(adata)
189 |     adata_X = (adata.X - adata.X.mean(0)) / (adata.X.std(0) + 1e-10)
190 |
191 |     # Construct spatial graph.
192 |     gene_mat = torch.Tensor(adata_X)
193 |     if dataset == 'Stereoseq':  # PCA
194 |         u, s, v = torch.pca_lowrank(gene_mat, config['train']['pca_dim'])
195 |         gene_mat = torch.matmul(gene_mat, v)
196 |     if dataset == 'Hyp_3D':  # stack slices: shift each to its own origin, use Bregma * 1000 as z
197 |         z_axis = adata.obs['Bregma'].to_numpy() * 1000
198 |         for z in set(z_axis):
199 |             adata.obsm['spatial'][z_axis == z] -= adata.obsm['spatial'][z_axis == z].min(0)
200 |         adata.obsm['spatial'] = np.hstack([adata.obsm['spatial'], z_axis.reshape(-1, 1)])
201 |     cell_coo = torch.Tensor(adata.obsm['spatial'])
202 |     data = torch_geometric.data.Data(x=gene_mat, pos=cell_coo)
203 |     data = torch_geometric.transforms.KNNGraph(k=config['train']['k_graph'], loop=True)(data)
204 |     data.y = torch.Tensor(adata.obsm['pseudo_label'])
205 |
206 |     # Convert distances into edge weights (Distance() scales them to [0, 1] by default).
207 |     if config['train']['edge_weight']:
208 |         data = torch_geometric.transforms.Distance()(data)
209 |         data.edge_weight = 1 - data.edge_attr[:,0]
210 |     else:
211 |         data.edge_weight = torch.ones(data.edge_index.size(1))
212 |
213 |     # Train self-supervised model.
214 |     input_dim = data.num_features
215 |     num_classes = len(adata.uns['pseudo_classes'])
216 |     trainer = SpatialModelTrainer(input_dim, num_classes, device, config['train'])
217 |     trainer.train(data, config['train'])
218 |     trainer.save_checkpoint(os.path.join(save_dir, f'{model_name}.t7'))
219 |
220 |     # Inference.
221 |     print('\n==> Inferencing...')
222 |     predictions = trainer.valid(data)
223 |     celltype_pred = pd.Categorical([adata.uns['pseudo_classes'][i] for i in predictions.argmax(1)])
224 |     if dataset == 'Hyp_3D':
225 |         indices = np.where(~adata.obs['Cell_class'].isin(['Ambiguous']))[0]
226 |         adjusted_pr = celltype_pred[indices].to_numpy()
227 |         adjusted_gt = adata.obs['Cell_class'].iloc[indices].replace(
228 |             ['Endothelial 1', 'Endothelial 2', 'Endothelial 3',
229 |              'OD Immature 1', 'OD Immature 2',
230 |              'OD Mature 1', 'OD Mature 2', 'OD Mature 3', 'OD Mature 4',
231 |              'Astrocyte', 'Pericytes'],
232 |             ['Endothelial', 'Endothelial', 'Endothelial',
233 |              'Immature oligodendrocyte', 'Immature oligodendrocyte',
234 |              'Mature oligodendrocyte', 'Mature oligodendrocyte', 'Mature oligodendrocyte', 'Mature oligodendrocyte',
235 |              'Astrocytes', 'Mural']).to_numpy()
236 |         acc = (adjusted_pr == adjusted_gt).sum() / len(indices) * 100.0
237 |         print(' %s Acc (transfer+GDAE): %.2f%%' % (data_name, acc))
238 |
239 |     # Save results.
240 |     result = pd.DataFrame({'cell': adata.obs_names.tolist(), 'celltype_pred': celltype_pred})
241 |     result.to_csv(os.path.join(save_dir, f'{model_name}.csv'), index=False)
242 |     adata.obsm['celltype_prob'] = predictions
243 |     adata.obs['celltype_pred'] = pd.Categorical(celltype_pred)
244 |     if 'adata_X_sparse_backup' in locals():
245 |         adata.X = adata_X_sparse_backup
246 |     adata.write(os.path.join(save_dir, f'{model_name}.h5ad'))
247 |
248 |     # Save visualization.
249 |     spot_size = (30 if dataset == 'Stereoseq' else 20)
250 |     if dataset == 'Stereoseq':
251 |         pseudo_top100 = adata.obs['pseudo_class'].to_numpy()
252 |         other_classes = list(pd.value_counts(adata.obs['pseudo_class'])[100:].index)
253 |         pseudo_top100[adata.obs['pseudo_class'].isin(other_classes)] = '_Others'
254 |         adata.obs['pseudo_class'] = pd.Categorical(pseudo_top100)
255 |     # sc.pl.spatial(adata, img_key=None, color=['pseudo_class'], spot_size=spot_size, show=False)
256 |     # plt.savefig(os.path.join(save_dir, f'pseudo-{data_name}.pdf'), bbox_inches='tight', dpi=150)
257 |     sc.pl.spatial(adata, img_key=None, color=['celltype_pred'], spot_size=spot_size, show=False)
258 |     plt.savefig(os.path.join(save_dir, f'{model_name}.pdf'), bbox_inches='tight', dpi=150)
259 |     print(' Predictions are saved in', os.path.join(save_dir, f'{model_name}.csv/pdf'))
260 |
261 |
262 | if __name__ == '__main__':
263 |     data_list = ['sample1', 'sample2', 'sample3']
264 |     parser = argparse.ArgumentParser()
265 |     parser.add_argument('--data_name', choices=data_list, required=True)
266 |     args = parser.parse_args()
267 |     spatial_classification_tool(config, args.data_name)
268 |
--------------------------------------------------------------------------------
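For the MERFISH-3D (Hyp_3D) data, the script above stacks the slices into one 3D point cloud before building the kNN graph: each slice's x/y coordinates are shifted so that their minimum is 0, and Bregma × 1000 becomes the z coordinate. Edge weights are then 1 minus the edge length divided by the longest edge. The NumPy sketch below reproduces these steps on toy data without torch_geometric. It is an illustration, not a drop-in replacement: the function names are hypothetical, and it assumes that `KNNGraph(loop=True)` counts each cell among its own k neighbors and that `Distance()` normalizes by the maximum edge length (its default).

```python
import numpy as np


def stack_slices(xy, bregma):
    """Shift every slice to its own origin and append Bregma * 1000 as z."""
    xy = np.asarray(xy, dtype=float).copy()
    z = np.asarray(bregma, dtype=float) * 1000
    for value in np.unique(z):
        mask = z == value
        xy[mask] -= xy[mask].min(axis=0)
    return np.hstack([xy, z.reshape(-1, 1)])


def knn_edges(pos, k):
    """Brute-force kNN graph; each node's k nearest nodes include itself."""
    dists = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)
    neighbors = np.argsort(dists, axis=1, kind="stable")[:, :k]
    targets = np.repeat(np.arange(len(pos)), k)
    sources = neighbors.reshape(-1)
    return np.stack([sources, targets])  # edges run neighbor -> center


def edge_weights(pos, edge_index):
    """1 - (edge length / longest edge), so self-loops get weight 1."""
    lengths = np.linalg.norm(pos[edge_index[0]] - pos[edge_index[1]], axis=1)
    max_len = lengths.max()
    return 1.0 - lengths / max_len if max_len > 0 else np.ones_like(lengths)


if __name__ == "__main__":
    xy = np.array([[10.0, 10.0], [11.0, 10.0], [50.0, 50.0], [52.0, 50.0]])
    bregma = np.array([0.26, 0.26, -0.29, -0.29])  # two toy slices
    pos = stack_slices(xy, bregma)
    edges = knn_edges(pos, k=2)
    print(pos)
    print(edge_weights(pos, edges).round(3))
```

The brute-force distance matrix is O(n²) in memory, so this only suits toy inputs; the scripts rely on torch_geometric's KNNGraph for real datasets.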
/cell_type_annotation_for_merfish.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import os
5 | import time
6 | import random
7 | import argparse
8 | import anndata
9 | import numpy as np
10 | import pandas as pd
11 | import scanpy as sc
12 | import matplotlib.pyplot as plt
13 | import torch
14 | import torch_geometric
15 |
16 | from cell_type_annotation_model import DNNModel, SpatialModelTrainer
17 |
18 |
19 | random.seed(0)
20 | np.random.seed(0)
21 | torch.manual_seed(0)
22 | torch.cuda.manual_seed(0)
23 | torch.cuda.manual_seed_all(0)
24 | torch.backends.cudnn.benchmark = False
25 | torch.backends.cudnn.deterministic = True
26 |
27 |
28 | config = {
29 |     'data': {
30 |         'data_dir': 'dataset/MERFISH/',
31 |         'save_dir': 'result/MERFISH/',
32 |         'dataset': 'MERFISH',
33 |     },
34 |     'preprocess': {
35 |         'filter_mt': True,
36 |         'cell_min_counts': 300,
37 |         'gene_min_cells': 10,
38 |         'cell_max_counts_percent': 98.0,
39 |         'drop_rate': 0,
40 |     },
41 |     'transfer': {
42 |         'dnn_model': 'dnn_model/checkpoint_MERFISH_s.t7',
43 |         'gpu': '0',
44 |         'batch_size': 4096,
45 |     },
46 |     'train': {
47 |         'pca_dim': 200,  # for Stereoseq only
48 |         'k_graph': 30,
49 |         'edge_weight': True,
50 |         'kd_T': 1,
51 |         'feat_dim': 64,
52 |         'w_dae': 1.0,
53 |         'w_gae': 1.0,
54 |         'w_cls': 10.0,
55 |         'epochs': 200,
56 |     }
57 | }
58 |
59 |
60 | def spatial_classification_tool(config, data_name):
61 |     ''' Spatial classification workflow.
62 |
63 |     # Arguments
64 |         config (dict): Configuration parameters.
65 |         data_name (str): Data name.
66 |     '''
67 |     ######################################
68 |     # Part 1: Load data #
69 |     ######################################
70 |
71 |     # Set path and load data.
72 |     print('\n==> Loading data...')
73 |     dataset = config['data']['dataset']
74 |     data_dir, save_dir = config['data']['data_dir'], config['data']['save_dir']
75 |     print(f' Data name: {data_name} ({dataset})')
76 |     print(f' Data path: {data_dir}')
77 |     print(f' Save path: {save_dir}')
78 |     adata = sc.read_h5ad(os.path.join(data_dir, f'{data_name}.h5ad'))
79 |
80 |     # Initialize save path.
81 |     model_name = f'spatialID-{data_name}'
82 |     save_dir = os.path.join(save_dir, model_name)
83 |     if not os.path.exists(save_dir):
84 |         os.makedirs(save_dir)
85 |
86 |
87 |     ######################################
88 |     # Part 2: Preprocess #
89 |     ######################################
90 |
91 |     print('\n==> Preprocessing...')
92 |     strings = [f'{k}={v}' for k, v in config['preprocess'].items()]
93 |     print(' Parameters(%s)' % (', '.join(strings)))
94 |
95 |     # Preprocess data (QC filtering is only applied to Stereo-seq data).
96 |     if dataset == 'Stereoseq':
97 |         params = config['preprocess']
98 |         if params['filter_mt']:
99 |             adata.var['mt'] = adata.var_names.str.startswith(('Mt-', 'mt-'))
100 |             sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], inplace=True)
101 |             adata = adata[adata.obs['pct_counts_mt'] < 10].copy()
102 |         if params['cell_min_counts'] > 0:
103 |             sc.pp.filter_cells(adata, min_counts=params['cell_min_counts'])
104 |         if params['gene_min_cells'] > 0:
105 |             sc.pp.filter_genes(adata, min_cells=params['gene_min_cells'])
106 |         if params['cell_max_counts_percent'] < 100:
107 |             max_counts = np.percentile(adata.X.sum(1), params['cell_max_counts_percent'])
108 |             sc.pp.filter_cells(adata, max_counts=max_counts)
109 |     if type(adata.X) != np.ndarray:
110 |         adata_X_sparse_backup = adata.X.copy()
111 |         adata.X = adata.X.toarray()
112 |     print(' %s: %d cells × %d genes.' % (data_name, adata.shape[0], adata.shape[1]))
113 |
114 |     # Please be aware:
115 |     # The DNN model takes the raw gene expression matrix as input and applies its own normalization.
116 |     # Any other normalization (e.g. with scanpy) should only be applied after DNN inference is complete.
117 |
118 |     # Simulate dropout noise: zero out each entry with probability drop_rate (0 disables it).
119 |     if dataset != 'Stereoseq':
120 |         drop_factor = (np.random.random(adata.shape) > config['preprocess']['drop_rate']) * 1.
121 |         adata.X = adata.X * drop_factor
122 |
123 |
124 |     ######################################
125 |     # Part 3: Transfer from sc-dataset #
126 |     ######################################
127 |
128 |     print('\n==> Transferring from sc-dataset...')
129 |     strings = [f'{k}={v}' for k, v in config['transfer'].items()]
130 |     print(' Parameters(%s)' % (', '.join(strings)))
131 |
132 |     # Set device.
133 |     os.environ['CUDA_VISIBLE_DEVICES'] = config['transfer']['gpu']
134 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
135 |
136 |     # Load the DNN model trained on the sc-dataset (map_location lets GPU checkpoints load on CPU).
137 |     checkpoint = torch.load(config['transfer']['dnn_model'], map_location=device)
138 |     dnn_model = checkpoint['model'].to(device)
139 |     dnn_model.eval()
140 |
141 |     # Initialize DNN input (marker genes missing from the data get index -1, i.e. the zero column added by np.pad).
142 |     marker_genes = checkpoint['marker_genes']
143 |     gene_indices = adata.var_names.get_indexer(marker_genes)
144 |     adata_X = np.pad(adata.X, ((0,0),(0,1)))[:, gene_indices].copy()
145 |     norm_factor = np.linalg.norm(adata_X, axis=1, keepdims=True)
146 |     norm_factor[norm_factor == 0] = 1
147 |     dnn_inputs = torch.Tensor(adata_X / norm_factor).split(config['transfer']['batch_size'])
148 |
149 |     # Inference with DNN model.
150 |     dnn_predictions = []
151 |     with torch.no_grad():
152 |         for batch_idx, inputs in enumerate(dnn_inputs):
153 |             inputs = inputs.to(device)
154 |             outputs = dnn_model(inputs)
155 |             dnn_predictions.append(outputs.detach().cpu().numpy())
156 |     label_names = checkpoint['label_names']
157 |     adata.obsm['pseudo_label'] = np.concatenate(dnn_predictions)
158 |     adata.obs['pseudo_class'] = pd.Categorical([label_names[i] for i in adata.obsm['pseudo_label'].argmax(1)])
159 |     adata.uns['pseudo_classes'] = label_names
160 |
161 |     # Compute accuracy (only for MERFISH).
162 |     if dataset == 'MERFISH':
163 |         indices = np.where(~adata.obs['subclass'].isin(['L4/5 IT', 'L6 IT Car3', 'PVM', 'other']))[0]
164 |         adjusted_pr = adata.obs['pseudo_class'].iloc[indices].to_numpy()
165 |         adjusted_gt = adata.obs['subclass'].iloc[indices].replace(['Micro'], ['Macrophage']).to_numpy()
166 |         acc = (adjusted_pr == adjusted_gt).sum() / len(indices) * 100.0
167 |         print(' %s Acc (transfer only): %.2f%%' % (data_name, acc))
168 |
169 |
170 |     ######################################
171 |     # Part 4: Train GDAE model #
172 |     ######################################
173 |
174 |     print('\n==> Model training...')
175 |     strings = [f'{k}={v}' for k, v in config['train'].items()]
176 |     print(' Parameters(%s)' % (', '.join(strings)))
177 |
178 |     # Normalize gene expression.
179 |     sc.pp.normalize_total(adata, target_sum=1e4)
180 |     sc.pp.log1p(adata)
181 |     adata_X = (adata.X - adata.X.mean(0)) / (adata.X.std(0) + 1e-10)
182 |
183 |     # Construct spatial graph.
184 |     gene_mat = torch.Tensor(adata_X)
185 |     if dataset == 'Stereoseq':  # PCA
186 |         u, s, v = torch.pca_lowrank(gene_mat, config['train']['pca_dim'])
187 |         gene_mat = torch.matmul(gene_mat, v)
188 |     cell_coo = torch.Tensor(adata.obsm['spatial'])
189 |     data = torch_geometric.data.Data(x=gene_mat, pos=cell_coo)
190 |     data = torch_geometric.transforms.KNNGraph(k=config['train']['k_graph'], loop=True)(data)
191 |     data.y = torch.Tensor(adata.obsm['pseudo_label'])
192 |
193 |     # Convert distances into edge weights (Distance() scales them to [0, 1] by default).
194 |     if config['train']['edge_weight']:
195 |         data = torch_geometric.transforms.Distance()(data)
196 |         data.edge_weight = 1 - data.edge_attr[:,0]
197 |     else:
198 |         data.edge_weight = torch.ones(data.edge_index.size(1))
199 |
200 |     # Train self-supervised model.
201 |     input_dim = data.num_features
202 |     num_classes = len(adata.uns['pseudo_classes'])
203 |     trainer = SpatialModelTrainer(input_dim, num_classes, device, config['train'])
204 |     trainer.train(data, config['train'])
205 |     trainer.save_checkpoint(os.path.join(save_dir, f'{model_name}.t7'))
206 |
207 |     # Inference.
208 |     print('\n==> Inferencing...')
209 |     predictions = trainer.valid(data)
210 |     celltype_pred = pd.Categorical([adata.uns['pseudo_classes'][i] for i in predictions.argmax(1)])
211 |     if dataset == 'MERFISH':
212 |         indices = np.where(~adata.obs['subclass'].isin(['L4/5 IT', 'L6 IT Car3', 'PVM', 'other']))[0]
213 |         adjusted_pr = celltype_pred[indices].to_numpy()
214 |         adjusted_gt = adata.obs['subclass'].iloc[indices].replace(['Micro'], ['Macrophage']).to_numpy()
215 |         acc = (adjusted_pr == adjusted_gt).sum() / len(indices) * 100.0
216 |         print(' %s Acc (transfer+GDAE): %.2f%%' % (data_name, acc))
217 |
218 |     # Save results.
219 |     result = pd.DataFrame({'cell': adata.obs_names.tolist(), 'celltype_pred': celltype_pred})
220 |     result.to_csv(os.path.join(save_dir, f'{model_name}.csv'), index=False)
221 |     adata.obsm['celltype_prob'] = predictions
222 |     adata.obs['celltype_pred'] = pd.Categorical(celltype_pred)
223 |     if 'adata_X_sparse_backup' in locals():
224 |         adata.X = adata_X_sparse_backup
225 |     adata.write(os.path.join(save_dir, f'{model_name}.h5ad'))
226 |
227 |     # Save visualization.
228 |     spot_size = (30 if dataset == 'Stereoseq' else 20)
229 |     if dataset == 'Stereoseq':
230 |         pseudo_top100 = adata.obs['pseudo_class'].to_numpy()
231 |         other_classes = list(pd.value_counts(adata.obs['pseudo_class'])[100:].index)
232 |         pseudo_top100[adata.obs['pseudo_class'].isin(other_classes)] = '_Others'
233 |         adata.obs['pseudo_class'] = pd.Categorical(pseudo_top100)
234 |     # sc.pl.spatial(adata, img_key=None, color=['pseudo_class'], spot_size=spot_size, show=False)
235 |     # plt.savefig(os.path.join(save_dir, f'pseudo-{data_name}.pdf'), bbox_inches='tight', dpi=150)
236 |     sc.pl.spatial(adata, img_key=None, color=['celltype_pred'], spot_size=spot_size, show=False)
237 |     plt.savefig(os.path.join(save_dir, f'{model_name}.pdf'), bbox_inches='tight', dpi=150)
238 |     print(' Predictions are saved in', os.path.join(save_dir, f'{model_name}.csv/pdf'))
239 |
240 |
241 | if __name__ == '__main__':
242 |     data_list = ['mouse1_sample1', 'mouse1_sample2', 'mouse1_sample3', 'mouse1_sample4', 'mouse1_sample5', 'mouse1_sample6',
243 |                  'mouse2_sample1', 'mouse2_sample2', 'mouse2_sample3', 'mouse2_sample4', 'mouse2_sample5', 'mouse2_sample6']
244 |     parser = argparse.ArgumentParser()
245 |     parser.add_argument('--data_name', choices=data_list, required=True)
246 |     args = parser.parse_args()
247 |     spatial_classification_tool(config, args.data_name)
248 |
--------------------------------------------------------------------------------
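In Part 3 of the scripts, `adata.var_names.get_indexer(marker_genes)` returns -1 for any marker gene that is missing from the spatial data, and `np.pad(adata.X, ((0,0),(0,1)))` appends an all-zero column on the right, so index -1 selects that zero column. Each cell's vector is then L2-normalized, with all-zero rows left as zeros. The sketch below reproduces these steps with plain NumPy, using a dictionary lookup in place of `get_indexer`; the function name and the toy gene names are illustrative only.

```python
import numpy as np


def align_to_markers(X, var_names, marker_genes):
    """Reorder the columns of X to match marker_genes; absent genes become zeros.

    Index -1 (gene not found) picks the extra all-zero column appended on
    the right, mirroring the pad + get_indexer trick in the scripts.
    """
    position = {gene: i for i, gene in enumerate(var_names)}
    indices = np.array([position.get(g, -1) for g in marker_genes])
    padded = np.pad(np.asarray(X, dtype=float), ((0, 0), (0, 1)))
    aligned = padded[:, indices]
    # L2-normalize each cell; rows with no counts get a norm of 1 and stay zero.
    norm = np.linalg.norm(aligned, axis=1, keepdims=True)
    norm[norm == 0] = 1
    return aligned / norm


if __name__ == "__main__":
    X = np.array([[3.0, 0.0, 4.0], [0.0, 0.0, 0.0]])
    genes = ["Gad1", "Slc17a6", "Mbp"]
    markers = ["Mbp", "Olig2", "Gad1"]  # "Olig2" is absent from genes
    print(align_to_markers(X, genes, markers))
```

Padding one zero column is cheaper than building a separate zero-filled matrix and keeps the marker order fixed, which the pretrained DNN needs because it was trained on the reference marker list.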
/cell_type_annotation_for_nanostring.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import os
5 | import time
6 | import random
7 | import argparse
8 | import anndata
9 | import numpy as np
10 | import pandas as pd
11 | import scanpy as sc
12 | import matplotlib.pyplot as plt
13 | import torch
14 | import torch_geometric
15 |
16 | from cell_type_annotation_model import DNNModel, SpatialModelTrainer
17 |
18 |
19 | random.seed(0)
20 | np.random.seed(0)
21 | torch.manual_seed(0)
22 | torch.cuda.manual_seed(0)
23 | torch.cuda.manual_seed_all(0)
24 | torch.backends.cudnn.benchmark = False
25 | torch.backends.cudnn.deterministic = True
26 |
27 |
28 | config = {
29 |     'data': {
30 |         'data_dir': 'dataset/NanoString/',
31 |         'save_dir': 'result/NanoString/',
32 |         'dataset': 'NanoString',
33 |     },
34 |     'preprocess': {
35 |         'filter_mt': True,
36 |         'cell_min_counts': 300,
37 |         'gene_min_cells': 10,
38 |         'cell_max_counts_percent': 98.0,
39 |         'drop_rate': 0,
40 |     },
41 |     'transfer': {
42 |         'dnn_model': 'dnn_model/checkpoint_NanoString.t7',
43 |         'gpu': '0',
44 |         'batch_size': 4096,
45 |     },
46 |     'train': {
47 |         'pca_dim': 200,  # for Stereoseq only
48 |         'k_graph': 30,
49 |         'edge_weight': True,
50 |         'kd_T': 1,
51 |         'feat_dim': 64,
52 |         'w_dae': 1.0,
53 |         'w_gae': 1.0,
54 |         'w_cls': 10.0,
55 |         'epochs': 200,
56 |     }
57 | }
58 |
59 |
60 | def spatial_classification_tool(config, data_name):
61 |     ''' Spatial classification workflow.
62 |
63 |     # Arguments
64 |         config (dict): Configuration parameters.
65 |         data_name (str): Data name.
66 |     '''
67 |     ######################################
68 |     # Part 1: Load data #
69 |     ######################################
70 |
71 |     # Set path and load data.
72 |     print('\n==> Loading data...')
73 |     dataset = config['data']['dataset']
74 |     data_dir, save_dir = config['data']['data_dir'], config['data']['save_dir']
75 |     print(f' Data name: {data_name} ({dataset})')
76 |     print(f' Data path: {data_dir}')
77 |     print(f' Save path: {save_dir}')
78 |     adata = sc.read_h5ad(os.path.join(data_dir, f'{data_name}.h5ad'))
79 |
80 |     # Initialize save path.
81 |     model_name = f'spatialID-{data_name}'
82 |     save_dir = os.path.join(save_dir, model_name)
83 |     if not os.path.exists(save_dir):
84 |         os.makedirs(save_dir)
85 |
86 |
87 |     ######################################
88 |     # Part 2: Preprocess #
89 |     ######################################
90 |
91 |     print('\n==> Preprocessing...')
92 |     strings = [f'{k}={v}' for k, v in config['preprocess'].items()]
93 |     print(' Parameters(%s)' % (', '.join(strings)))
94 |
95 |     # Preprocess data (QC filtering is only applied to Stereo-seq data).
96 |     if dataset == 'Stereoseq':
97 |         params = config['preprocess']
98 |         if params['filter_mt']:
99 |             adata.var['mt'] = adata.var_names.str.startswith(('Mt-', 'mt-'))
100 |             sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], inplace=True)
101 |             adata = adata[adata.obs['pct_counts_mt'] < 10].copy()
102 |         if params['cell_min_counts'] > 0:
103 |             sc.pp.filter_cells(adata, min_counts=params['cell_min_counts'])
104 |         if params['gene_min_cells'] > 0:
105 |             sc.pp.filter_genes(adata, min_cells=params['gene_min_cells'])
106 |         if params['cell_max_counts_percent'] < 100:
107 |             max_counts = np.percentile(adata.X.sum(1), params['cell_max_counts_percent'])
108 |             sc.pp.filter_cells(adata, max_counts=max_counts)
109 |     if type(adata.X) != np.ndarray:
110 |         adata_X_sparse_backup = adata.X.copy()
111 |         adata.X = adata.X.toarray()
112 |     print(' %s: %d cells × %d genes.' % (data_name, adata.shape[0], adata.shape[1]))
113 |
114 |     # Please be aware:
115 |     # The DNN model takes the raw gene expression matrix as input and applies its own normalization.
116 |     # Any other normalization (e.g. with scanpy) should only be applied after DNN inference is complete.
117 |
118 |     # Simulate dropout noise: zero out each entry with probability drop_rate (0 disables it).
119 |     if dataset != 'Stereoseq':
120 |         drop_factor = (np.random.random(adata.shape) > config['preprocess']['drop_rate']) * 1.
121 |         adata.X = adata.X * drop_factor
122 |
123 |
124 |     ######################################
125 |     # Part 3: Transfer from sc-dataset #
126 |     ######################################
127 |
128 |     print('\n==> Transferring from sc-dataset...')
129 |     strings = [f'{k}={v}' for k, v in config['transfer'].items()]
130 |     print(' Parameters(%s)' % (', '.join(strings)))
131 |
132 |     # Set device.
133 |     os.environ['CUDA_VISIBLE_DEVICES'] = config['transfer']['gpu']
134 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
135 |
136 |     # Load the DNN model trained on the sc-dataset (map_location lets GPU checkpoints load on CPU).
137 |     checkpoint = torch.load(config['transfer']['dnn_model'], map_location=device)
138 |     dnn_model = checkpoint['model'].to(device)
139 |     dnn_model.eval()
140 |
141 |     # Initialize DNN input (marker genes missing from the data get index -1, i.e. the zero column added by np.pad).
142 |     marker_genes = checkpoint['marker_genes']
143 |     gene_indices = adata.var_names.get_indexer(marker_genes)
144 |     adata_X = np.pad(adata.X, ((0,0),(0,1)))[:, gene_indices]
145 |     norm_factor = np.linalg.norm(adata_X, axis=1, keepdims=True)
146 |     norm_factor[norm_factor == 0] = 1
147 |     dnn_inputs = torch.Tensor(adata_X / norm_factor).split(config['transfer']['batch_size'])
148 |
149 |     # Inference with DNN model.
150 |     dnn_predictions = []
151 |     with torch.no_grad():
152 |         for batch_idx, inputs in enumerate(dnn_inputs):
153 |             inputs = inputs.to(device)
154 |             outputs = dnn_model(inputs)
155 |             dnn_predictions.append(outputs.detach().cpu().numpy())
156 |     label_names = checkpoint['label_names']
157 |     adata.obsm['pseudo_label'] = np.concatenate(dnn_predictions)
158 |     adata.obs['pseudo_class'] = pd.Categorical([label_names[i] for i in adata.obsm['pseudo_label'].argmax(1)])
159 |     adata.uns['pseudo_classes'] = label_names
160 |
161 |     # Compute accuracy (only for NanoString).
162 |     if dataset == 'NanoString':
163 |         indices = np.where(adata.obs['celltype_refined'].isin(label_names))[0]
164 |         adjusted_pr = adata.obs['pseudo_class'].iloc[indices].to_numpy()
165 |         adjusted_gt = adata.obs['celltype_refined'].iloc[indices].to_numpy()
166 |         acc = (adjusted_pr == adjusted_gt).sum() / len(indices) * 100.0
167 |         print(' %s Acc (transfer only): %.2f%%' % (data_name, acc))
168 |
169 |
170 | ######################################
171 | # Part 4: Train GDAE model #
172 | ######################################
173 |
174 | print('\n==> Model training...')
175 | strings = [f'{k}={v}' for k, v in config['train'].items()]
176 | print(' Parameters(%s)' % (', '.join(strings)))
177 |
178 | # Normalize gene expression: library-size normalization, log1p, then per-gene z-scoring.
179 | sc.pp.normalize_total(adata, target_sum=1e4)
180 | sc.pp.log1p(adata)
181 | adata_X = (adata.X - adata.X.mean(0)) / (adata.X.std(0) + 1e-10)
182 |
183 | # Construct spatial graph.
184 | gene_mat = torch.Tensor(adata_X)
185 | if dataset == 'Stereoseq': # PCA
186 | u, s, v = torch.pca_lowrank(gene_mat, config['train']['pca_dim'])
187 | gene_mat = torch.matmul(gene_mat, v)
188 | cell_coo = torch.Tensor(adata.obsm['spatial'])
189 | data = torch_geometric.data.Data(x=gene_mat, pos=cell_coo)
190 | data = torch_geometric.transforms.KNNGraph(k=config['train']['k_graph'], loop=True)(data)
191 | data.y = torch.Tensor(adata.obsm['pseudo_label'])
192 |
193 | # Use distances as edge weights: Distance() scales them to [0, 1], so closer neighbors get larger weights.
194 | if config['train']['edge_weight']:
195 | data = torch_geometric.transforms.Distance()(data)
196 | data.edge_weight = 1 - data.edge_attr[:,0]
197 | else:
198 | data.edge_weight = torch.ones(data.edge_index.size(1))
199 |
200 | # Train the self-supervised GDAE model.
201 | input_dim = data.num_features
202 | num_classes = len(adata.uns['pseudo_classes'])
203 | trainer = SpatialModelTrainer(input_dim, num_classes, device, config['train'])
204 | trainer.train(data, config['train'])
205 | trainer.save_checkpoint(os.path.join(save_dir, f'{model_name}.t7'))
206 |
207 | # Inference.
208 | print('\n==> Running inference...')
209 | predictions = trainer.valid(data)
210 | celltype_pred = pd.Categorical([adata.uns['pseudo_classes'][i] for i in predictions.argmax(1)])
211 | if dataset == 'NanoString':
212 | indices = np.where(adata.obs['celltype_refined'].isin(label_names))[0]
213 | adjusted_pr = celltype_pred[indices].to_numpy()
214 | adjusted_gt = adata.obs['celltype_refined'][indices].to_numpy()
215 | acc = (adjusted_pr == adjusted_gt).sum() / len(indices) * 100.0
216 | print(' %s Acc (transfer+GDAE): %.2f%%' % (data_name, acc))
217 |
218 | # Save results.
219 | result = pd.DataFrame({'cell': adata.obs_names.tolist(), 'celltype_pred': celltype_pred})
220 | result.to_csv(os.path.join(save_dir, f'{model_name}.csv'), index=False)
221 | adata.obsm['celltype_prob'] = predictions
222 | adata.obs['celltype_pred'] = pd.Categorical(celltype_pred)
223 | if 'adata_X_sparse_backup' in locals():
224 | adata.X = adata_X_sparse_backup
225 | adata.write(os.path.join(save_dir, f'{model_name}.h5ad'))
226 |
227 | # Save visualization.
228 | spot_size = (30 if dataset == 'Stereoseq' else 20)
229 | if dataset == 'Stereoseq':
230 | pseudo_top100 = adata.obs['pseudo_class'].to_numpy()
231 | other_classes = list(adata.obs['pseudo_class'].value_counts()[100:].index)
232 | pseudo_top100[adata.obs['pseudo_class'].isin(other_classes)] = '_Others'
233 | adata.obs['pseudo_class'] = pd.Categorical(pseudo_top100)
234 | # sc.pl.spatial(adata, img_key=None, color=['pseudo_class'], spot_size=spot_size, show=False)
235 | # plt.savefig(os.path.join(save_dir, f'pseudo-{data_name}.pdf'), bbox_inches='tight', dpi=150)
236 | sc.pl.spatial(adata, img_key=None, color=['celltype_pred'], spot_size=spot_size, show=False)
237 | plt.savefig(os.path.join(save_dir, f'{model_name}.pdf'), bbox_inches='tight', dpi=150)
238 | print(' Predictions are saved in', os.path.join(save_dir, f'{model_name}.csv/pdf'))
239 |
240 |
241 | if __name__ == '__main__':
242 | data_list = [f'nanostring_fov{i+1}_sampledata' for i in range(20)]
243 | parser = argparse.ArgumentParser()
244 | parser.add_argument('--data_name', choices=data_list)
245 | args = parser.parse_args()
246 | spatial_classification_tool(config, args.data_name)
247 |
--------------------------------------------------------------------------------
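The marker-gene alignment in Part 3 of the scripts above relies on a small indexing trick: `get_indexer` returns -1 for marker genes that are not measured in the spatial dataset, and `np.pad` appends one all-zero column, so index -1 reads zeros. The sketch below reproduces that step in isolation; `build_dnn_input` and the gene names are hypothetical and not part of the repository.

```python
import numpy as np
import pandas as pd


def build_dnn_input(X, var_names, marker_genes):
    """Align a cells x genes matrix to the marker-gene order and L2-normalize rows.

    Hypothetical helper. Genes absent from `var_names` get index -1 from
    get_indexer, which selects the all-zero column appended by np.pad.
    """
    gene_indices = pd.Index(var_names).get_indexer(marker_genes)
    # Append one zero column so that index -1 (missing gene) reads zeros.
    X_aligned = np.pad(X, ((0, 0), (0, 1)))[:, gene_indices]
    # Per-cell L2 norm; all-zero cells keep norm 1 to avoid division by zero.
    norm = np.linalg.norm(X_aligned, axis=1, keepdims=True)
    norm[norm == 0] = 1
    return X_aligned / norm


X = np.array([[3.0, 0.0, 4.0],
              [0.0, 0.0, 0.0]])
var_names = ['GeneA', 'GeneB', 'GeneC']       # toy gene names
marker_genes = ['GeneC', 'GeneX', 'GeneA']    # 'GeneX' is not measured
print(build_dnn_input(X, var_names, marker_genes))
```

The first cell's aligned row is `[4, 0, 3]` (norm 5), and the unmeasured `GeneX` column is filled with zeros rather than raising an error.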
/cell_type_annotation_for_slideseq.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import os
5 | import time
6 | import random
7 | import argparse
8 | import anndata
9 | import numpy as np
10 | import pandas as pd
11 | import scanpy as sc
12 | import matplotlib.pyplot as plt
13 | import torch
14 | import torch_geometric
15 |
16 | from cell_type_annotation_model import DNNModel, SpatialModelTrainer
17 |
18 |
19 | random.seed(0)
20 | np.random.seed(0)
21 | torch.manual_seed(0)
22 | torch.cuda.manual_seed(0)
23 | torch.cuda.manual_seed_all(0)
24 | torch.backends.cudnn.benchmark = False
25 | torch.backends.cudnn.deterministic = True
26 |
27 |
28 | config = {
29 | 'data': {
30 | 'data_dir': 'dataset/Slide_seq/',
31 | 'save_dir': 'result/Slide_seq/',
32 | 'dataset': 'Slide_seq',
33 | 'radius_ratio': 1.0,
34 | },
35 | 'preprocess': {
36 | 'filter_mt': True,
37 | 'cell_min_counts': 300,
38 | 'gene_min_cells': 10,
39 | 'cell_max_counts_percent': 98.0,
40 | 'drop_rate': 0,
41 | },
42 | 'transfer': {
43 | 'dnn_model': 'dnn_model/checkpoint_Slide-seq_DM1.t7', # unpack from the split archive dnn_model/checkpoint_Slide-seq_DM1.zip (.z01 to .z06) first
44 | 'gpu': '0',
45 | 'batch_size': 4096,
46 | },
47 | 'train': {
48 | 'pca_dim': 200, # for Stereoseq and Slide_seq only
49 | 'k_graph': 30,
50 | 'edge_weight': True,
51 | 'kd_T': 1,
52 | 'feat_dim': 64,
53 | 'w_dae': 1.0,
54 | 'w_gae': 10.0,
55 | 'w_cls': 10.0,
56 | 'epochs': 200,
57 | }
58 | }
59 |
60 |
61 | def spatial_classification_tool(config, data_name):
62 | ''' Spatial classification workflow.
63 |
64 | # Arguments
65 | config (Config): Configuration parameters.
66 | data_name (str): Data name.
67 | '''
68 | ######################################
69 | # Part 1: Load data #
70 | ######################################
71 |
72 | # Set path and load data.
73 | print('\n==> Loading data...')
74 | dataset = config['data']['dataset']
75 | data_dir, save_dir = config['data']['data_dir'], config['data']['save_dir']
76 | print(f' Data name: {data_name} ({dataset})')
77 | print(f' Data path: {data_dir}')
78 | print(f' Save path: {save_dir}')
79 | adata = sc.read_h5ad(os.path.join(data_dir, f'{data_name}.h5ad'))
80 |
81 | # Optionally keep only a central disc of cells (radius_ratio < 1), used for timing tests.
82 | if config['data']['radius_ratio'] < 1:
83 | adata.obsm['spatial'] = adata.obsm['spatial'] - adata.obsm['spatial'].mean(0)
84 | radius = adata.obsm['spatial'].max(0).mean() * config['data']['radius_ratio']
85 | adata = adata[(adata.obsm['spatial']**2).sum(1) < radius**2].copy()
86 |
87 | # Initialize save path.
88 | model_name = f'spatialID-{data_name}'
89 | save_dir = os.path.join(save_dir, model_name)
90 | if not os.path.exists(save_dir):
91 | os.makedirs(save_dir)
92 |
93 |
94 | ######################################
95 | # Part 2: Preprocess #
96 | ######################################
97 |
98 | print('\n==> Preprocessing...')
99 | strings = [f'{k}={v}' for k, v in config['preprocess'].items()]
100 | print(' Parameters(%s)' % (', '.join(strings)))
101 |
102 | # Preprocess data.
103 | if dataset == 'Stereoseq':
104 | params = config['preprocess']
105 | if params['filter_mt']:
106 | adata.var['mt'] = adata.var_names.str.startswith(('Mt-', 'mt-'))
107 | sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], inplace=True)
108 | adata = adata[adata.obs['pct_counts_mt'] < 10].copy()
109 | if params['cell_min_counts'] > 0:
110 | sc.pp.filter_cells(adata, min_counts=params['cell_min_counts'])
111 | if params['gene_min_cells'] > 0:
112 | sc.pp.filter_genes(adata, min_cells=params['gene_min_cells'])
113 | if params['cell_max_counts_percent'] < 100:
114 | max_counts = np.percentile(adata.X.sum(1), params['cell_max_counts_percent'])
115 | sc.pp.filter_cells(adata, max_counts=max_counts)
116 | if not isinstance(adata.X, np.ndarray):
117 | adata_X_sparse_backup = adata.X.copy()
118 | adata.X = adata.X.toarray()
119 | print(' %s: %d cells × %d genes.' % (data_name, adata.shape[0], adata.shape[1]))
120 |
121 | # Please note:
122 | # The DNN model takes the raw gene expression matrix as input and applies its own normalization.
123 | # Any other normalization (e.g. scanpy) should be applied only after DNN inference is complete.
124 |
125 | # Simulate dropout noise: zero each entry with probability drop_rate (effectively a no-op when drop_rate is 0).
126 | if dataset != 'Stereoseq':
127 | drop_factor = (np.random.random(adata.shape) > config['preprocess']['drop_rate']) * 1.
128 | adata.X = adata.X * drop_factor
129 |
130 |
131 | ######################################
132 | # Part 3: Transfer from sc-dataset #
133 | ######################################
134 |
135 | print('\n==> Transferring from sc-dataset...')
136 | strings = [f'{k}={v}' for k, v in config['transfer'].items()]
137 | print(' Parameters(%s)' % (', '.join(strings)))
138 | time1 = time.time()
139 |
140 | # Set device.
141 | os.environ['CUDA_VISIBLE_DEVICES'] = config['transfer']['gpu']
142 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
143 |
144 | # Load the DNN model pre-trained on the reference scRNA-seq dataset.
145 | checkpoint = torch.load(config['transfer']['dnn_model'], map_location=device, weights_only=False)
146 | dnn_model = checkpoint['model'].to(device)
147 | dnn_model.eval()
148 |
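149 | # Initialize DNN input: align columns to the checkpoint's marker genes (missing genes get index -1, i.e. the zero column appended by np.pad), then L2-normalize each cell.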
150 | marker_genes = checkpoint['marker_genes']
151 | gene_indices = adata.var_names.get_indexer(marker_genes)
152 | adata_X = np.pad(adata.X, ((0,0),(0,1)))[:, gene_indices]
153 | if dataset == 'Slide_seq':
154 | adata_X = np.log1p(adata_X)
155 | norm_factor = np.linalg.norm(adata_X, axis=1, keepdims=True)
156 | norm_factor[norm_factor == 0] = 1
157 | dnn_inputs = torch.Tensor(adata_X / norm_factor).split(config['transfer']['batch_size'])
158 |
159 | # Inference with DNN model.
160 | dnn_predictions = []
161 | with torch.no_grad():
162 | for batch_idx, inputs in enumerate(dnn_inputs):
163 | inputs = inputs.to(device)
164 | outputs = dnn_model(inputs)
165 | dnn_predictions.append(outputs.detach().cpu().numpy())
166 | label_names = checkpoint['label_names']
167 | adata.obsm['pseudo_label'] = np.concatenate(dnn_predictions)
168 | adata.obs['pseudo_class'] = pd.Categorical([label_names[i] for i in adata.obsm['pseudo_label'].argmax(1)])
169 | adata.uns['pseudo_classes'] = label_names
170 |
171 | # Compute accuracy (Slide-seq only), after mapping the ground-truth names in 'ct_name' to the reference label names.
172 | if dataset == 'Slide_seq':
173 | adjusted_pr = adata.obs['pseudo_class'].to_numpy()
174 | adjusted_gt = adata.obs['ct_name'].replace(
175 | ['ES', 'Differentiating SPG', 'RS', 'SPC'], ['Elongating', 'SPG', 'STids', 'Scytes']).to_numpy()
176 | acc = (adjusted_pr == adjusted_gt).sum() / len(adjusted_gt) * 100.0
177 | print(' %s Acc (transfer only): %.2f%%' % (data_name, acc))
178 |
179 |
180 | ######################################
181 | # Part 4: Train GDAE model #
182 | ######################################
183 |
184 | print('\n==> Model training...')
185 | strings = [f'{k}={v}' for k, v in config['train'].items()]
186 | print(' Parameters(%s)' % (', '.join(strings)))
187 | time2 = time.time()
188 |
189 | # Normalize gene expression: library-size normalization, log1p, then per-gene z-scoring.
190 | sc.pp.normalize_total(adata, target_sum=1e4)
191 | sc.pp.log1p(adata)
192 | adata_X = adata.X - adata.X.mean(0)
193 | adata_X = adata_X / (adata.X.std(0) + 1e-10)
194 |
195 | # Construct spatial graph.
196 | gene_mat = torch.Tensor(adata_X)
197 | if dataset in ('Stereoseq', 'Slide_seq'): # PCA
198 | u, s, v = torch.pca_lowrank(gene_mat, config['train']['pca_dim'])
199 | gene_mat = torch.matmul(gene_mat, v)
200 | cell_coo = torch.Tensor(adata.obsm['spatial'])
201 | data = torch_geometric.data.Data(x=gene_mat, pos=cell_coo)
202 | data = torch_geometric.transforms.KNNGraph(k=config['train']['k_graph'], loop=True)(data)
203 | data.y = torch.Tensor(adata.obsm['pseudo_label'])
204 |
205 | # Use distances as edge weights: Distance() scales them to [0, 1], so closer neighbors get larger weights.
206 | if config['train']['edge_weight']:
207 | data = torch_geometric.transforms.Distance()(data)
208 | data.edge_weight = 1 - data.edge_attr[:,0]
209 | else:
210 | data.edge_weight = torch.ones(data.edge_index.size(1))
211 |
212 | # Train the self-supervised GDAE model.
213 | input_dim = data.num_features
214 | num_classes = len(adata.uns['pseudo_classes'])
215 | trainer = SpatialModelTrainer(input_dim, num_classes, device, config['train'])
216 | trainer.train(data, config['train'])
217 | trainer.save_checkpoint(os.path.join(save_dir, f'{model_name}.t7'))
218 | time3 = time.time()
219 |
220 | # Inference.
221 | print('\n==> Running inference...')
222 | predictions = trainer.valid(data)
223 | celltype_pred = pd.Categorical([adata.uns['pseudo_classes'][i] for i in predictions.argmax(1)])
224 | if dataset == 'Slide_seq':
225 | adjusted_pr = celltype_pred.to_numpy()
226 | adjusted_gt = adata.obs['ct_name'].replace(
227 | ['ES', 'Differentiating SPG', 'RS', 'SPC'], ['Elongating', 'SPG', 'STids', 'Scytes']).to_numpy()
228 | acc = (adjusted_pr == adjusted_gt).sum() / len(adjusted_gt) * 100.0
229 | print(' %s Acc (transfer+GDAE): %.2f%%' % (data_name, acc))
230 |
231 | # Compute time cost.
232 | time4 = time.time()
233 | print(' S1 infer time: %.2fs' % (time2 - time1))
234 | print(' S2 train time: %.2fs' % (time3 - time2))
235 | print(' S2 infer time: %.2fs' % (time4 - time3))
236 |
237 | # Save results.
238 | result = pd.DataFrame({'cell': adata.obs_names.tolist(), 'celltype_pred': celltype_pred})
239 | result.to_csv(os.path.join(save_dir, f'{model_name}.csv'), index=False)
240 | adata.obsm['celltype_prob'] = predictions
241 | adata.obs['celltype_pred'] = pd.Categorical(celltype_pred)
242 | if 'adata_X_sparse_backup' in locals():
243 | adata.X = adata_X_sparse_backup
244 | adata.write(os.path.join(save_dir, f'{model_name}.h5ad'))
245 |
246 | # Save visualization.
247 | spot_size = (30 if dataset == 'Stereoseq' else 20)
248 | if dataset == 'Stereoseq':
249 | pseudo_top100 = adata.obs['pseudo_class'].to_numpy()
250 | other_classes = list(adata.obs['pseudo_class'].value_counts()[100:].index)
251 | pseudo_top100[adata.obs['pseudo_class'].isin(other_classes)] = '_Others'
252 | adata.obs['pseudo_class'] = pd.Categorical(pseudo_top100)
253 | # sc.pl.spatial(adata, img_key=None, color=['pseudo_class'], spot_size=spot_size, show=False)
254 | # plt.savefig(os.path.join(save_dir, f'pseudo-{data_name}.pdf'), bbox_inches='tight', dpi=150)
255 | sc.pl.spatial(adata, img_key=None, color=['celltype_pred'], spot_size=spot_size, show=False)
256 | plt.savefig(os.path.join(save_dir, f'{model_name}.pdf'), bbox_inches='tight', dpi=150)
257 | print(' Predictions are saved in', os.path.join(save_dir, f'{model_name}.csv/pdf'))
258 |
259 |
260 | if __name__ == '__main__':
261 | data_list = ['DM1', 'DM2', 'DM3', 'WT1', 'WT2', 'WT3']
262 | parser = argparse.ArgumentParser()
263 | parser.add_argument('--data_name', choices=data_list)
264 | args = parser.parse_args()
265 | spatial_classification_tool(config, args.data_name)
266 |
--------------------------------------------------------------------------------
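In Part 4 of both scripts, the spatial graph is built with `torch_geometric.transforms.KNNGraph(k, loop=True)` and weighted by `1 - edge_attr[:, 0]` after `Distance()`. As a rough, dependency-light illustration, assuming `Distance()` normalizes by the longest edge (its default) and that `loop=True` counts each cell among its own k neighbors, a brute-force NumPy version might look like this. `knn_edge_weights` is a hypothetical helper, and its edge ordering may differ from PyG's.

```python
import numpy as np


def knn_edge_weights(pos, k):
    """Brute-force k-NN graph (each cell counted among its own k neighbors).

    Hypothetical helper approximating KNNGraph(k, loop=True) + Distance()
    + `1 - edge_attr[:, 0]`. Returns (src, dst, weight), one entry per edge.
    """
    pos = np.asarray(pos, dtype=float)
    # Pairwise Euclidean distances, shape (n, n): O(n^2) memory, toy sizes only.
    diff = pos[:, None, :] - pos[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    # For each center cell, its k closest cells; the cell itself is at distance 0.
    nbrs = np.argsort(dist, axis=1, kind='stable')[:, :k]
    dst = np.repeat(np.arange(len(pos)), k)  # center cell of each edge
    src = nbrs.ravel()                       # neighbor cell of each edge
    d = dist[src, dst]
    # Scale by the longest edge so weights lie in [0, 1]; self-loops get weight 1.
    d_max = d.max()
    if d_max == 0:
        d_max = 1.0
    return src, dst, 1.0 - d / d_max


src, dst, weight = knn_edge_weights([[0, 0], [1, 0], [3, 0]], k=2)
print(list(zip(src.tolist(), dst.tolist(), weight.tolist())))
```

The longest edge in the graph gets weight 0 and every self-loop gets weight 1. The pairwise matrix makes this suitable only for building intuition on toy inputs; the PyG transform used in the scripts is the appropriate tool for real datasets.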
/cell_type_annotation_model.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/cell_type_annotation_model.pyc
--------------------------------------------------------------------------------
/dnn_model/checkpoint_Hyp-3D_b.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/dnn_model/checkpoint_Hyp-3D_b.t7
--------------------------------------------------------------------------------
/dnn_model/checkpoint_MERFISH_s.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/dnn_model/checkpoint_MERFISH_s.t7
--------------------------------------------------------------------------------
/dnn_model/checkpoint_NanoString.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/dnn_model/checkpoint_NanoString.t7
--------------------------------------------------------------------------------
/dnn_model/checkpoint_Slide-seq_DM1.z01:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/dnn_model/checkpoint_Slide-seq_DM1.z01
--------------------------------------------------------------------------------
/dnn_model/checkpoint_Slide-seq_DM1.z02:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/dnn_model/checkpoint_Slide-seq_DM1.z02
--------------------------------------------------------------------------------
/dnn_model/checkpoint_Slide-seq_DM1.z03:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/dnn_model/checkpoint_Slide-seq_DM1.z03
--------------------------------------------------------------------------------
/dnn_model/checkpoint_Slide-seq_DM1.z04:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/dnn_model/checkpoint_Slide-seq_DM1.z04
--------------------------------------------------------------------------------
/dnn_model/checkpoint_Slide-seq_DM1.z05:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/dnn_model/checkpoint_Slide-seq_DM1.z05
--------------------------------------------------------------------------------
/dnn_model/checkpoint_Slide-seq_DM1.z06:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/dnn_model/checkpoint_Slide-seq_DM1.z06
--------------------------------------------------------------------------------
/dnn_model/checkpoint_Slide-seq_DM1.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/dnn_model/checkpoint_Slide-seq_DM1.zip
--------------------------------------------------------------------------------
/spatialID_example_m1s1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/spatialID_example_m1s1.png
--------------------------------------------------------------------------------
/spatialID_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAILabHealthcare/spatialID/397417f56b2b0fc8bf388b5c504d88c18de34eac/spatialID_overview.png
--------------------------------------------------------------------------------
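The `radius_ratio < 1` branch in Part 1 of cell_type_annotation_for_slideseq.py keeps only the cells inside a disc around the centroid of the spatial coordinates. The disc's radius is the mean of the per-axis maximum centered coordinate, scaled by `radius_ratio`. A minimal NumPy sketch of that selection rule follows; `central_disc_mask` and the toy coordinates are hypothetical.

```python
import numpy as np


def central_disc_mask(coords, radius_ratio):
    """Boolean mask of cells inside a disc centred on the coordinate centroid.

    Hypothetical helper mirroring the script's rule: centre the coordinates,
    take radius = mean(per-axis max of centred coords) * radius_ratio, and
    keep cells whose squared distance from the centroid is below radius**2.
    """
    coords = np.asarray(coords, dtype=float)
    centered = coords - coords.mean(axis=0)
    radius = centered.max(axis=0).mean() * radius_ratio
    return (centered ** 2).sum(axis=1) < radius ** 2


# Four corners of a 10 x 10 square plus three cells near its centre (5, 5).
coords = np.array([[0, 0], [10, 0], [0, 10], [10, 10], [5, 5], [4, 5], [6, 5]])
mask = central_disc_mask(coords, 0.5)
print(mask.tolist())
```

Under this rule, even `radius_ratio = 1.0` would drop corner cells of a square section, which is presumably why the script applies the filter only when the ratio is below 1.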