# Repository snapshot (directory layout as listed in the original dump):
#
# .
# ├── CAST
# │   ├── CAST_Mark.py
# │   ├── CAST_Projection.py
# │   ├── CAST_Stack.py
# │   ├── __init__.py
# │   ├── main.py
# │   ├── models
# │   │   ├── __init__.py
# │   │   ├── aug.py
# │   │   └── model_GCNII.py
# │   ├── utils.py
# │   └── visualize.py
# ├── LICENSE
# ├── README.md
# ├── README_CAST_diagram.png
# ├── demo
# │   ├── demo1_CAST_Mark
# │   │   └── demo1_CAST_mark.ipynb
# │   ├── demo2_CAST_Stack_Align_S4_to_S1
# │   │   └── demo2_CAST_Stack_Align_S4_to_S1.ipynb
# │   └── demo3_CAST_Projection
# │       └── demo3_CAST_project.ipynb
# └── setup.py

# -----------------------------------------------------------------------------
# /CAST/CAST_Mark.py
# -----------------------------------------------------------------------------
import torch, dgl
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from .models.aug import random_aug
from .utils import coords2adjacentmat
from timeit import default_timer as timer
from collections import OrderedDict
from tqdm import trange


def train_seq(graphs, args, dump_epoch_list, out_prefix, model):
    """The CAST MARK training function.

    Args:
        graphs (List[Tuple(str, dgl.Graph, torch.Tensor)]): List of 3-member tuples, each tuple
            represents one tissue sample, containing sample name, a DGL graph object, and a
            feature matrix in the torch.Tensor format.
        args (model_GCNII.Args): the Args object contains training parameters
            (reads ``device``, ``lr1``, ``wd1``, ``epochs``, ``dfr``, ``der``, ``lambd``).
        dump_epoch_list (List): epoch ids at which training snapshots are dumped (debug use).
        out_prefix (str): file name prefix for the snapshot files.
        model (model_GCNII.CCA_SSG): the GNN model.

    Returns:
        Tuple(Dict, List, CCA_SSG): a dictionary with the graph embedding of each sample,
        the list of logged loss values (one per epoch), and the trained model object.
    """
    model = model.to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr1, weight_decay=args.wd1)

    loss_log = []
    time_now = timer()

    t = trange(args.epochs, desc='', leave=True)
    for epoch in t:
        # Optional debugging snapshot, taken before this epoch's parameter updates.
        with torch.no_grad():
            if epoch in dump_epoch_list:
                model.eval()
                dump_embedding = OrderedDict()
                for name, graph, feat in graphs:
                    dump_embedding[name] = model.get_embedding(graph, feat)
                torch.save(dump_embedding, f'{out_prefix}_embed_dict_epoch{epoch}.pt')
                torch.save(loss_log, f'{out_prefix}_loss_log_epoch{epoch}.pt')
                print(f"Successfully dumped epoch {epoch}")

        model.train()
        # NOTE(review): gradients are zeroed once per epoch while the optimizer steps after
        # every graph, so later steps in the same epoch also re-apply the gradients of the
        # earlier graphs. Kept as-is to preserve the existing training behaviour; confirm intent.
        optimizer.zero_grad()

        for _name, graph_, feat_ in graphs:
            # Two independently augmented views of the same graph; augmentation is not trained.
            with torch.no_grad():
                N = graph_.number_of_nodes()
                graph1, feat1 = random_aug(graph_, feat_, args.dfr, args.der)
                graph2, feat2 = random_aug(graph_, feat_, args.dfr, args.der)
                graph1 = graph1.add_self_loop()
                graph2 = graph2.add_self_loop()

            z1, z2 = model(graph1, feat1, graph2, feat2)

            # Feature-by-feature products between the two views (c) and of each view with
            # itself (c1, c2), divided by the node count N.
            c = torch.mm(z1.T, z2) / N
            c1 = torch.mm(z1.T, z1) / N
            c2 = torch.mm(z2.T, z2) / N

            # Invariance term (reward large diagonal of c) plus decorrelation terms that push
            # c1 and c2 towards the identity.
            loss_inv = -torch.diagonal(c).sum()
            iden = torch.eye(c.size(0), device=args.device)
            loss_dec1 = (iden - c1).pow(2).sum()
            loss_dec2 = (iden - c2).pow(2).sum()
            loss = loss_inv + args.lambd * (loss_dec1 + loss_dec2)
            loss.backward()
            optimizer.step()

        # Only the loss of the last graph processed in this epoch is logged.
        loss_log.append(loss.item())
        time_step = timer() - time_now
        time_now += time_step
        t.set_description(f'Loss: {loss.item():.3f} step time={time_step:.3f}s')
        t.refresh()

    model.eval()
    with torch.no_grad():
        dump_embedding = OrderedDict()
        for name, graph, feat in graphs:
            dump_embedding[name] = model.get_embedding(graph, feat)
    return dump_embedding, loss_log, model


# graph construction tools
def delaunay_dgl(sample_name, df, output_path, if_plot=True, strategy_t='convex'):
    """Build a DGL graph from the first two columns of ``df`` used as (x, y) coordinates.

    The graph itself comes from ``coords2adjacentmat(..., output_mode='raw')``, whose
    result is handled here as a networkx graph. When ``if_plot`` is true the graph is
    drawn and saved to ``{output_path}/delaunay_{sample_name}.png``; the figure is left
    open (as before) so it can still be displayed inline by notebooks.
    """
    coords = np.column_stack((np.array(df)[:, 0], np.array(df)[:, 1]))
    delaunay_graph = coords2adjacentmat(coords, output_mode='raw', strategy_t=strategy_t)
    if if_plot:
        positions = dict(zip(delaunay_graph.nodes, coords[delaunay_graph.nodes, :]))
        _, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10))
        nx.draw(
            delaunay_graph,
            positions,
            ax=ax,
            node_size=1,
            node_color="#000000",
            edge_color="#5A98AF",
            alpha=0.6,
        )
        plt.axis('equal')
        plt.savefig(f'{output_path}/delaunay_{sample_name}.png')
    return dgl.from_networkx(delaunay_graph)


# -----------------------------------------------------------------------------
# /CAST/CAST_Projection.py
# -----------------------------------------------------------------------------
import torch,random
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.sparse import csr_matrix
from sklearn.metrics import pairwise_distances,pairwise_distances_chunked,confusion_matrix
import scanpy as sc
from scipy.sparse import csr_matrix as csr
from .utils import coords2adjacentmat


def space_project(
        sdata_inte,
        idx_source,
        idx_target,
        raw_layer,
        source_sample,
        target_sample,
        coords_source,
        coords_target,
        output_path,
        source_sample_ctype_col,
        target_cell_pc_feature=None,
        source_cell_pc_feature=None,
        k2=1,
        ifplot=True,
        umap_feature='X_umap',
        ave_dist_fold=2,
        batch_t='',
        alignment_shift_adjustment=50,
        color_dict=None,
        adjust_shift=False,
        metric_t='cosine',
        working_memory_t=1000):
    """Project raw expression of source cells onto target cells.

    Each target cell is matched to source cell(s) that are physically close, meaning
    within ``ave_dist_fold * <average graph edge length> + alignment_shift_adjustment``.
    Among those, the source cells closest in the low-dimensional feature space
    (``metric_t`` distance) are chosen. Their raw counts are averaged with
    inverse-distance weights.

    Args:
        sdata_inte: AnnData-like object holding both samples.
        idx_source, idx_target: selectors for source / target cells in ``sdata_inte``.
            They are presumably boolean masks, since ``np.sum(idx_target)`` is used as
            the target cell count.
        raw_layer: name of the raw-count layer (``.toarray()`` is called on it).
        source_sample, target_sample: sample names used for the output layer names.
        coords_source, coords_target: spatial coordinates of source / target cells.
        output_path: directory for the intermediate result and evaluation plots.
        source_sample_ctype_col: obs column with cell-type labels, or None. When given,
            the distance threshold (and optional shift) is estimated separately for each
            label found among the target cells. Candidate source cells are not
            restricted by label.
        target_cell_pc_feature, source_cell_pc_feature: low-dimensional features used
            for the feature distance.
        k2: width of the index/weight arrays. NOTE(review): the projection itself is
            always run with ``k2=1``.
        ifplot: when true, run ``evaluation_project`` to produce plots.
        umap_feature: obsm key of the embedding saved alongside the results.
        ave_dist_fold, alignment_shift_adjustment: parameters of the distance threshold.
        batch_t: suffix appended to output file names.
        color_dict: optional label -> colour mapping for the link plot.
        adjust_shift: when true (cell-type mode only), shift source coordinates per
            label using ``group_shift``.
        metric_t: feature distance metric, used in both modes.
        working_memory_t: ``working_memory`` passed to the chunked distance routines.

    Returns:
        (sdata_ref, [project_ind, project_weight, cdists, physical_dist]). ``sdata_ref``
        is a copy of the target cells with the layers ``{source_sample}_raw``,
        ``{target_sample}_norm1e4`` and ``{source_sample}_norm1e4`` added.
    """
    sdata_ref = sdata_inte[idx_target, :].copy()
    source_feat = sdata_inte[idx_source, :].layers[raw_layer].toarray()

    n_target = np.sum(idx_target)
    project_ind = np.zeros([n_target, k2]).astype(int)
    project_weight = np.zeros_like(project_ind).astype(float)
    cdists = np.zeros_like(project_ind).astype(float)
    physical_dist = np.zeros_like(project_ind).astype(float)
    all_avg_feat = np.zeros([n_target, source_feat.shape[1]]).astype(float)

    if source_sample_ctype_col is not None:
        target_labels = sdata_inte[idx_target].obs[source_sample_ctype_col]
        for ctype_t in np.unique(target_labels):
            print(f'Start to project {ctype_t} cells:')
            idx_ctype_t = np.isin(target_labels, ctype_t)
            ave_dist_t, _, _, _ = average_dist(coords_target[idx_ctype_t, :].copy(),
                                               working_memory_t=working_memory_t)
            dist_thres = ave_dist_fold * ave_dist_t + alignment_shift_adjustment
            if adjust_shift:
                coords_shift = group_shift(target_cell_pc_feature[idx_ctype_t, :],
                                           source_cell_pc_feature,
                                           coords_target[idx_ctype_t, :],
                                           coords_source,
                                           working_memory_t=working_memory_t,
                                           metric_t=metric_t)
                coords_source_t = coords_source + coords_shift
                print(coords_shift)
            else:
                coords_source_t = coords_source.copy()
            (project_ind[idx_ctype_t, :],
             project_weight[idx_ctype_t, :],
             cdists[idx_ctype_t, :],
             physical_dist[idx_ctype_t, :],
             all_avg_feat[idx_ctype_t, :]) = physical_dist_priority_project(
                feat_target=target_cell_pc_feature[idx_ctype_t, :],
                feat_source=source_cell_pc_feature,
                coords_target=coords_target[idx_ctype_t, :],
                coords_source=coords_source_t,
                source_feat=source_feat,
                k2=1,
                pdist_thres=dist_thres,
                metric_t=metric_t,
                working_memory_t=working_memory_t)
    else:
        ave_dist_t, _, _, _ = average_dist(coords_target.copy(),
                                           working_memory_t=working_memory_t,
                                           strategy_t='delaunay')
        dist_thres = ave_dist_fold * ave_dist_t + alignment_shift_adjustment
        project_ind, project_weight, cdists, physical_dist, all_avg_feat = physical_dist_priority_project(
            feat_target=target_cell_pc_feature,
            feat_source=source_cell_pc_feature,
            coords_target=coords_target,
            coords_source=coords_source,
            source_feat=source_feat,
            k2=1,
            pdist_thres=dist_thres,
            # Previously omitted, which silently forced the 'cosine' default here.
            metric_t=metric_t,
            working_memory_t=working_memory_t)

    umap_target = sdata_inte[idx_target, :].obsm[umap_feature]
    umap_source = sdata_inte[idx_source, :].obsm[umap_feature]

    sdata_ref.layers[f'{source_sample}_raw'] = csr(all_avg_feat)
    sdata_ref.layers[f'{target_sample}_norm1e4'] = csr(sc.pp.normalize_total(
        sdata_ref, target_sum=1e4, layer=f'{raw_layer}', inplace=False)['X'])
    sdata_ref.layers[f'{source_sample}_norm1e4'] = csr(sc.pp.normalize_total(
        sdata_ref, target_sum=1e4, layer=f'{source_sample}_raw', inplace=False)['X'])

    if source_sample_ctype_col is not None:
        y_true_t = np.array(sdata_inte[idx_target].obs[source_sample_ctype_col].values)
        y_source = np.array(sdata_inte[idx_source].obs[source_sample_ctype_col].values)
        # Predicted label = label of the best-matching source cell.
        y_pred_t = y_source[project_ind[:, 0]]
    else:
        y_true_t = y_source = y_pred_t = None

    torch.save([physical_dist, project_ind, coords_target, coords_source, y_true_t, y_pred_t,
                y_source, output_path, source_sample_ctype_col, umap_target, umap_source,
                source_sample, target_sample, cdists, k2],
               f'{output_path}/mid_result{batch_t}.pt')
    if ifplot:
        evaluation_project(
            physical_dist=physical_dist,
            project_ind=project_ind,
            coords_target=coords_target,
            coords_source=coords_source,
            y_true_t=y_true_t,
            y_pred_t=y_pred_t,
            y_source=y_source,
            output_path=output_path,
            source_sample_ctype_col=source_sample_ctype_col,
            umap_target=umap_target,
            umap_source=umap_source,
            source_sample=source_sample,
            target_sample=target_sample,
            cdists=cdists,
            batch_t=batch_t,
            color_dict=color_dict)
    return sdata_ref, [project_ind, project_weight, cdists, physical_dist]


def average_dist(coords, quantile_t=0.99, working_memory_t=1000, strategy_t='convex'):
    """Estimate the typical spacing of a point cloud from its neighbour graph.

    Duplicate coordinate rows are dropped first. With more than 5 unique points, a graph
    is built with ``coords2adjacentmat(..., strategy_t=strategy_t)``. Edges longer than
    the ``quantile_t`` quantile of edge lengths are removed from that graph.

    Args:
        coords: (n_points, n_dims) coordinates.
        quantile_t: quantile of edge lengths above which edges are discarded.
        working_memory_t: kept for backward compatibility only. Edge lengths are now
            computed directly from the edge list, so no chunked distance matrix is needed.
        strategy_t: graph construction strategy forwarded to ``coords2adjacentmat``.

    Returns:
        With more than 5 unique points: (mean length of the kept edges, length threshold,
        list of all edge lengths aligned with ``graph.edges()``, the pruned graph).
        Otherwise: (mean of all pairwise distances including the zero self-distances,
        '', '', '').
    """
    coords = np.array(pd.DataFrame(coords).drop_duplicates())
    if coords.shape[0] <= 5:
        dists = pairwise_distances(coords, coords, metric='euclidean', n_jobs=-1)
        return np.mean(dists.flatten()), '', '', ''

    delaunay_graph_t = coords2adjacentmat(coords, output_mode='raw', strategy_t=strategy_t)
    edges = np.array(delaunay_graph_t.edges())
    # Length of each edge computed row-by-row from `edges`. This keeps edge_dist aligned
    # with `edges` regardless of edge order; the previous chunked construction only lined
    # up when edges happened to be sorted by their first node.
    edge_dist = np.linalg.norm(coords[edges[:, 0]] - coords[edges[:, 1]], axis=1)
    filter_thres = np.quantile(edge_dist, quantile_t)
    for i, j in edges[edge_dist > filter_thres, :]:
        delaunay_graph_t.remove_edge(i, j)
    result_t = np.mean(edge_dist[edge_dist <= filter_thres])
    return result_t, filter_thres, list(edge_dist), delaunay_graph_t


def group_shift(feat_target, feat_source, coords_target_t, coords_source_t, working_memory_t=1000, pencentile_t=0.8, metric_t='cosine'):
    """Estimate a coordinate shift aligning a group of target cells to their source anchors.

    Anchors are the source cells nearest (in ``metric_t`` feature distance) to the target
    cells. The columns are taken nearest-first: all targets' 1st-nearest sources, then the
    2nd-nearest, and so on, until at least ``num_anchor`` distinct anchors are collected.
    The shift is median(target coords) - median(anchor coords).

    Args:
        pencentile_t: fraction of the number of target cells used as ``num_anchor``.
            The result is clamped to [1, number of source cells].

    Returns:
        The coordinate shift to add to the source coordinates.
    """
    print(f'Using {metric_t} distance to calculate group shift:')
    feat_dist = np.vstack(list(pairwise_distances_chunked(
        feat_target, feat_source, metric=metric_t, n_jobs=-1, working_memory=working_memory_t)))
    n_source = feat_dist.shape[1]
    # Clamp so argpartition's kth index is always valid.
    num_anchor = min(max(int(feat_dist.shape[0] * pencentile_t), 1), n_source)
    nearest = np.argpartition(feat_dist, num_anchor - 1, axis=-1)[:, :num_anchor]
    # argpartition leaves the selected columns in arbitrary order. Sort them so column i
    # really is each target's i-th nearest source cell.
    order = np.argsort(np.take_along_axis(feat_dist, nearest, axis=-1), axis=-1, kind='stable')
    anchor_rank = np.take_along_axis(nearest, order, axis=-1)

    anchors = set()
    for i in range(num_anchor):
        anchors.update(anchor_rank[:, i].tolist())
        if len(anchors) >= num_anchor:
            break
    anchor_idx = np.array(sorted(anchors), dtype=int)
    coords_shift = np.median(coords_target_t, axis=0) - np.median(coords_source_t[anchor_idx, :], axis=0)
    return coords_shift


def physical_dist_priority_project(feat_target, feat_source, coords_target, coords_source, source_feat = None, k2 = 1, k_extend = 20, pdist_thres = 200, working_memory_t = 1000, metric_t = 'cosine'):
    """Match each target cell to ``k2`` source cells, preferring physically close ones.

    For every target cell, the candidate source cells are those with Euclidean
    coordinate distance below ``pdist_thres``. If fewer than ``k2`` such cells exist,
    the ``k_extend`` physically nearest source cells are used instead; ``k_extend`` is
    clamped to the number of source cells. Among the candidates, the ``k2`` cells with
    the smallest ``metric_t`` feature distance are kept. Their weights are normalised
    inverse feature distances (``cosine_IDW``).

    Returns:
        (indices, weights, feature distances, physical distances), each of shape
        (n_target, k2). When ``source_feat`` is given, the weighted average of the
        matched rows of ``source_feat`` is appended as a fifth element.
    """
    def reduce_func_cdist_priority(chunk_cdist, start):
        # chunk_cdist: feature distances for target rows [start, start + len(chunk_cdist)).
        n_rows = chunk_cdist.shape[0]
        chunk_pdist = pairwise_distances(coords_target[start:(n_rows + start), :], coords_source,
                                         metric='euclidean', n_jobs=-1)
        within_thres = chunk_pdist < pdist_thres
        has_k2_candidates = within_thres.sum(1) >= k2
        # Fallback pool size, clamped so argpartition's kth stays in range.
        n_extend = min(k_extend, chunk_pdist.shape[1])

        knn_ind = np.zeros([n_rows, k2]).astype(int)
        knn_weight = np.zeros_like(knn_ind).astype(float)
        knn_cdist = np.zeros_like(knn_ind).astype(float)
        knn_physical_dist = np.zeros_like(knn_ind).astype(float)

        for i in range(n_rows):
            if has_k2_candidates[i]:
                candidates = np.where(within_thres[i, :])[0]
            else:
                candidates = np.argpartition(chunk_pdist[i, :], n_extend - 1, axis=-1)[:n_extend]
            knn_ind_t = candidates[np.argpartition(chunk_cdist[i, candidates], k2 - 1, axis=-1)[:k2]]
            _, weight_cell, cdist_cosine = cosine_IDW(chunk_cdist[i, knn_ind_t], k2=k2, need_filter=False)
            knn_ind[i, :] = knn_ind_t
            knn_weight[i, :] = weight_cell
            knn_cdist[i, :] = cdist_cosine
            knn_physical_dist[i, :] = chunk_pdist[i, knn_ind_t]
        return knn_ind, knn_weight, knn_cdist, knn_physical_dist

    print(f'Using {metric_t} distance to calculate cell low dimensional distance:')
    dists = pairwise_distances_chunked(feat_target, feat_source, metric=metric_t, n_jobs=-1,
                                       working_memory=working_memory_t,
                                       reduce_func=reduce_func_cdist_priority)
    cosine_knn_inds = []
    cosine_k2nn_weights = []
    cosine_k2nn_cdists = []
    cosine_k2nn_physical_dists = []
    for output in tqdm(dists):
        cosine_knn_inds.append(output[0])
        cosine_k2nn_weights.append(output[1])
        cosine_k2nn_cdists.append(output[2])
        cosine_k2nn_physical_dists.append(output[3])

    all_cosine_knn_inds = np.concatenate(cosine_knn_inds)
    all_cosine_k2nn_weights = np.concatenate(cosine_k2nn_weights)
    all_cosine_k2nn_cdists = np.concatenate(cosine_k2nn_cdists)
    all_cosine_k2nn_physical_dists = np.concatenate(cosine_k2nn_physical_dists)

    if source_feat is not None:
        mask_idw = sparse_mask(all_cosine_k2nn_weights, all_cosine_knn_inds, source_feat.shape[0])
        all_avg_feat = mask_idw.dot(source_feat)
        return all_cosine_knn_inds, all_cosine_k2nn_weights, all_cosine_k2nn_cdists, all_cosine_k2nn_physical_dists, all_avg_feat
    return all_cosine_knn_inds, all_cosine_k2nn_weights, all_cosine_k2nn_cdists, all_cosine_k2nn_physical_dists


def sparse_mask(idw_t, ind : np.ndarray, n_cols : int, dtype=np.float64):
    """Build a CSR matrix with ``idw_t[r, k]`` stored at ``(r, ind[r, k])``.

    ``ind`` has shape (num data points, k), e.g. the output of ``numpy.argpartition``;
    ``idw_t`` has the same shape and holds the values to store.
    """
    rows = np.repeat(np.arange(ind.shape[0]), ind.shape[1])  # e.g. [0,0,0,1,1,1,2,2,2]
    cols = ind.flatten()  # column index of each stored value
    data = idw_t.flatten()  # the weights themselves (not 1s)
    return csr_matrix((data, (rows, cols)), shape=(ind.shape[0], n_cols), dtype=dtype)


def cosine_IDW(cosine_dist_t,k2=5,eps = 1e-6,need_filter = True,ifavg = False):
    """Select (optionally) the ``k2`` smallest distances and compute their weights.

    Returns (selected positions, weights, selected distances). When ``need_filter`` is
    false, all distances are used and the positions value is the placeholder 0. Weights
    are uniform ``1/k2`` when ``ifavg`` is true, otherwise normalised inverse distances.
    """
    if need_filter:
        idx_cosdist_t = np.argpartition(cosine_dist_t, k2 - 1, axis=-1)[:k2]
        cdist_cosine_t = cosine_dist_t[idx_cosdist_t]
    else:
        idx_cosdist_t = 0
        cdist_cosine_t = cosine_dist_t
    if ifavg:
        weight_cell_t = np.full(k2, 1 / k2)
    else:
        weight_cell_t = IDW(cdist_cosine_t, eps)
    return idx_cosdist_t, weight_cell_t, cdist_cosine_t


def IDW(df_value,eps = 1e-6):
    """Inverse-distance weights 1/(d + eps), normalised to sum to 1 along the last axis."""
    weights = 1.0 / (df_value + eps).T
    weights /= weights.sum(axis=0)
    return weights.T


def evaluation_project(
        physical_dist,
        project_ind,
        coords_target,
        coords_source,
        y_true_t,
        y_pred_t,
        y_source,
        output_path,
        source_sample_ctype_col,
        umap_target=None,
        umap_source=None,
        source_sample=None,
        target_sample=None,
        cdists=None,
        batch_t='',
        exclude_group='Other',
        color_dict=None,
        umap_examples=False):
    """Save diagnostic plots of a projection to ``output_path``.

    Always saves histograms of physical and feature distances and a 3D link plot. It
    also saves confusion matrices when ``source_sample_ctype_col`` is given, with target
    cells labelled ``exclude_group`` dropped, and UMAP examples when ``umap_examples``
    is true. ``cdists`` must be provided despite its None default.
    """
    print('Generate evaluation plots:')
    plt.rcParams.update({'pdf.fonttype': 42, 'font.size': 15})
    plt.rcParams['axes.grid'] = False

    ### histogram ###
    cdist_hist(physical_dist.flatten(), range_t=[0, 2000])
    plt.savefig(f'{output_path}/physical_dist_hist{batch_t}.pdf')
    cdist_hist(cdists.flatten(), range_t=[0, 2])
    plt.savefig(f'{output_path}/cdist_hist{batch_t}.pdf')

    ### confusion matrix ###
    if source_sample_ctype_col is not None:
        if exclude_group is not None:
            keep = y_true_t != exclude_group
            y_true_t_use = y_true_t[keep]
            y_pred_t_use = y_pred_t[keep]
        else:
            y_true_t_use = y_true_t
            y_pred_t_use = y_pred_t
        confusion_mat_plot(y_true_t_use, y_pred_t_use)
        plt.savefig(f'{output_path}/confusion_mat_raw_with_label_{source_sample_ctype_col}{batch_t}.pdf')
        confusion_mat_plot(y_true_t_use, y_pred_t_use, withlabel=False)
        plt.savefig(f'{output_path}/confusion_mat_raw_without_label_{source_sample_ctype_col}{batch_t}.pdf')

    ### link plot 3d ###
    if color_dict is not None and source_sample_ctype_col is not None:
        color_target = [color_dict[x] for x in y_true_t]
        color_source = [color_dict[x] for x in y_source]
    else:
        color_target = "#9295CA"
        color_source = '#E66665'
    link_plot_3d(project_ind, coords_target, coords_source, k=1, figsize_t=[10, 10],
                 sample_n=200, link_color_mask=None,
                 color_target=color_target, color_source=color_source,
                 color_true="#222222")
    plt.savefig(f'{output_path}/link_plot{batch_t}.pdf', dpi=300)

    ### Umap ###
    if umap_examples:
        cdist_check(cdists.copy(), project_ind.copy(), umap_target, umap_source,
                    labels_t=[target_sample, source_sample], random_seed_t=0, figsize_t=[40, 32])
        plt.savefig(f'{output_path}/umap_examples{batch_t}.pdf', dpi=300)

#################### Visualization ####################

def cdist_hist(data_t,range_t = None,step = None):
    """Draw a histogram of ``data_t`` in a new 5x5 figure.

    If ``range_t`` is given, the x axis is limited to it; if ``step`` is also given,
    x ticks are placed every ``step``.
    """
    plt.figure(figsize=[5, 5])
    plt.hist(data_t, bins='auto', alpha=0.5, color='#1073BC')
    plt.yticks(fontsize=20)
    plt.xticks(fontsize=20)
    if range_t is not None:
        if step is not None:
            plt.xticks(np.arange(range_t[0], range_t[1] + 0.001, step), fontsize=20)
        plt.xlim(range_t[0], range_t[1])
    plt.tight_layout()


def confusion_mat_plot(y_true_t, y_pred_t, filter_thres = None, withlabel = True, fig_x = 60, fig_y = 20):
    """Draw count / sensitivity / precision confusion matrices side by side in a new figure.

    Only labels present in ``y_true_t`` are shown. When ``filter_thres`` is given, only
    labels with at least that many true occurrences are shown. With ``withlabel``, the
    diagonal values are annotated.
    """
    plt.rcParams.update({'axes.labelsize': 30, 'pdf.fonttype': 42, 'axes.titlesize': 30,
                         'font.size': 15, 'legend.markerscale': 3})
    plt.rcParams['axes.grid'] = False
    TPrate = np.round(np.sum(y_pred_t == y_true_t) / len(y_true_t), 2)
    labels_all, label_counts = np.unique(y_true_t, return_counts=True)
    labels_t = labels_all if filter_thres is None else labels_all[label_counts >= filter_thres]

    panels = [
        ('count', None, 'Counts (TP%%: %.2f)' % TPrate),
        ('true', 'true', 'Sensitivity'),
        ('pred', 'pred', 'Precision'),
    ]
    plt.figure(figsize=[fig_x, fig_y])
    for idx_t, (kind, normalize_t, title_t) in enumerate(panels):
        plt.subplot(1, 3, idx_t + 1)
        confusion_mat = confusion_matrix(y_true_t, y_pred_t, labels=labels_t, normalize=normalize_t)
        vmax_t = np.max(confusion_mat) if kind == 'count' else 1
        confusion_mat = pd.DataFrame(confusion_mat, columns=labels_t, index=labels_t)
        if withlabel:
            # Annotate only the diagonal; zero entries are shown blank.
            annot = np.round(np.diag(np.diag(confusion_mat.values.copy(), 0), 0), 2).astype('str')
            annot[annot == '0.0'] = ''
            annot[annot == '0'] = ''
            sns.heatmap(confusion_mat, cmap='RdBu', center=0, annot=annot, fmt='', square=True, vmax=vmax_t)
        else:
            sns.heatmap(confusion_mat, cmap='RdBu', center=0, square=True, vmax=vmax_t)
        plt.title(title_t)
        plt.xlabel('Predicted label')
        plt.ylabel('True label')
    plt.tight_layout()


def cdist_check(cdist_t,cdist_idx,umap_coords0,umap_coords1, labels_t = ['query','ref'],random_seed_t = 2,figsize_t = [40,32],output_path_t = None):
    """Plot up to 20 randomly chosen query cells with their matched reference cell on the embedding.

    Sampling is clamped to the number of rows, so fewer than 20 rows no longer raises.
    The figure is saved and all figures are closed only when ``output_path_t`` is given.
    """
    plt.rcParams.update({'xtick.labelsize': 20, 'ytick.labelsize': 20, 'axes.labelsize': 30,
                         'axes.titlesize': 40, 'axes.grid': False})
    random.seed(random_seed_t)
    n_examples = min(20, cdist_idx.shape[0])
    sampled_points = np.sort(random.sample(list(range(0, cdist_idx.shape[0])), n_examples))
    fig, axs = plt.subplots(nrows=4, ncols=5, figsize=figsize_t)
    axs = axs.flatten()
    for ax, idx_check in zip(axs, sampled_points):
        match = cdist_idx[idx_check, 0]
        ax.scatter(umap_coords0[:, 0], umap_coords0[:, 1], s=0.5, c='#1f77b4', rasterized=True)
        ax.scatter(umap_coords1[:, 0], umap_coords1[:, 1], s=0.5, c='#E47E8B', rasterized=True)
        ax.scatter(umap_coords0[idx_check, 0], umap_coords0[idx_check, 1], s=220, linewidth=4,
                   c='#1f77b4', edgecolors='#000000', label=labels_t[0], rasterized=False)
        ax.scatter(umap_coords1[match, 0], umap_coords1[match, 1], s=220, linewidth=4,
                   c='#E47E8B', edgecolors='#000000', label=labels_t[1], rasterized=False)
        ax.legend(scatterpoints=1, markerscale=2, fontsize=30)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title('cdist = ' + str(format(cdist_t[idx_check, 0], '.2f')))
    if output_path_t is not None:
        plt.savefig(f'{output_path_t}/umap_examples.pdf', dpi=300)
        plt.close('all')


def link_plot_3d(assign_mat, coords_target, coords_source, k, figsize_t = [15,20], sample_n=1000, link_color_mask=None, color_target="#9295CA", color_source='#E66665', color_true = "#999999", color_false = "#999999", remove_background = True):
    """3D plot with target cells at z=0, source cells at z=1 and dashed links for matches.

    ``assign_mat[:, 0]`` gives the matched source index of each target cell (only
    ``k == 1`` is supported). ``sample_n`` target cells are drawn with replacement for
    linking. When ``link_color_mask`` is given, links where it is false use
    ``color_false``.
    """
    from mpl_toolkits.mplot3d.art3d import Line3DCollection
    assert k == 1
    ax = plt.figure(figsize=figsize_t).add_subplot(projection='3d')
    xylim = max(coords_source.max(), coords_target.max())
    ax.set_xlim(0, xylim)
    ax.set_ylim(0, xylim)
    ax.set_zlim(-0.1, 1.1)
    ax.set_box_aspect([1, 1, 0.6])
    ax.view_init(elev=25)

    coords_transfer_source_link = coords_source[assign_mat[:, 0], :]
    # np.vstack is the non-deprecated equivalent of np.row_stack (deprecated in NumPy 2.0).
    t1 = np.vstack((coords_transfer_source_link[:, 0], coords_transfer_source_link[:, 1]))  # source
    t2 = np.vstack((coords_target[:, 0], coords_target[:, 1]))  # target

    downsample_indices = np.random.choice(range(coords_target.shape[0]), sample_n)

    if link_color_mask is not None:
        final_true_indices = np.intersect1d(downsample_indices, np.where(link_color_mask)[0])
        final_false_indices = np.intersect1d(downsample_indices, np.where(~link_color_mask)[0])
        segs = [[(*t2[:, i], 0), (*t1[:, i], 1)] for i in final_false_indices]
        line_collection = Line3DCollection(segs, colors=color_false, lw=0.5, linestyles='dashed')
        line_collection.set_rasterized(True)
        ax.add_collection(line_collection)
    else:
        final_true_indices = downsample_indices

    segs = [[(*t2[:, i], 0), (*t1[:, i], 1)] for i in final_true_indices]
    line_collection = Line3DCollection(segs, colors=color_true, lw=0.5, linestyles='dashed')
    line_collection.set_rasterized(True)
    ax.add_collection(line_collection)

    ### target - z = 0
    ax.scatter(xs=coords_target[:, 0], ys=coords_target[:, 1], zs=0, s=2, c=color_target,
               alpha=0.8, ec='none', rasterized=True, depthshade=False)
    ### source - z = 1
    ax.scatter(xs=coords_source[:, 0], ys=coords_source[:, 1], zs=1, s=2, c=color_source,
               alpha=0.8, ec='none', rasterized=True, depthshade=False)
    if remove_background:
        # Remove axis
        ax.axis('off')
        # Remove background
        ax.xaxis.pane.fill = False
        ax.yaxis.pane.fill = False
        ax.zaxis.pane.fill = False


# -----------------------------------------------------------------------------
# /CAST/CAST_Stack.py
# -----------------------------------------------------------------------------
import torch,copy,os,random
from tqdm import trange
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
import seaborn as sns
from dataclasses import dataclass, field
from .visualize import add_scale_bar

#################### Registration #################### 11 | # Parameters class 12 | 13 | @dataclass 14 | class reg_params: 15 | dataname : str 16 | ### affine 17 | theta_r1 : float = 0 18 | theta_r2 : float = 0 19 | d_list : list[float] = field(default_factory=list) 20 | translation_params : list[float] = None 21 | mirror_t : list[float] = None 22 | alpha_basis : list[float] = field(default_factory=list) 23 | iterations : int = 500 24 | dist_penalty1 : float = 0 25 | attention_params: list[float] = field(default_factory=list) 26 | 27 | ### BS 28 | mesh_trans_list : list[float] = field(default_factory=list) 29 | attention_region : list[float] = field(default_factory=list) 30 | attention_params_bs : list[float] = field(default_factory=list) 31 | mesh_weight : list[float] = field(default_factory=list) 32 | iterations_bs : list[float] = field(default_factory=list) 33 | alpha_basis_bs : list[float] = field(default_factory=list) 34 | meshsize : list[float] = field(default_factory=list) 35 | img_size_bs : list[float] = field(default_factory=list) # max_xy 36 | dist_penalty2 : list[float] = field(default_factory=list) 37 | PaddingRate_bs : float = 0 38 | 39 | ### common 40 | bleeding : float = 500 41 | diff_step : float = 5 42 | min_qr2 : float = 0 43 | mean_q : float = 0 44 | mean_r : float = 0 45 | gpu: int = 0 46 | device : str = field(init=False) 47 | ifrigid : bool = False 48 | 49 | def __post_init__(self): 50 | if self.gpu != -1 and torch.cuda.is_available(): 51 | self.device = 'cuda:{}'.format(self.gpu) 52 | else: 53 | self.device = 'cpu' 54 | 55 | def get_range(sp_coords): 56 | yrng = max(sp_coords, key=lambda x:x[1])[1] - min(sp_coords, key=lambda x:x[1])[1] 57 | xrng = max(sp_coords, key=lambda x:x[0])[0] - min(sp_coords, key=lambda x:x[0])[0] 58 | return xrng, yrng 59 | 60 | def prelocate(coords_q,coords_r,cov_anchor_it,bleeding,output_path,d_list=[1,2,3],prefix = 'test',ifplot = True,index_list = None,translation_params = None,mirror_t = None): 61 | idx_q = 
np.ones(coords_q.shape[0],dtype=bool) if index_list is None else index_list[0] 62 | idx_r = np.ones(coords_r.shape[0],dtype=bool) if index_list is None else index_list[1] 63 | mirror_t = [1,-1] if mirror_t is None else mirror_t 64 | theta_t = [] 65 | J_t = [] 66 | if translation_params is None: 67 | translation_x = [0] 68 | translation_y = [0] 69 | else: 70 | xrng, yrng = get_range(coords_r.detach().cpu()) 71 | dx_ratio_max, dy_ratio_max, xy_steps = translation_params 72 | dx_max = dx_ratio_max * xrng 73 | dy_max = dy_ratio_max * yrng 74 | translation_x = np.linspace(-dx_max, dx_max, num=int(xy_steps)) # dx 75 | translation_y = np.linspace(-dy_max, dy_max, num=int(xy_steps)) # dy 76 | for mirror in mirror_t: 77 | for dx in translation_x: 78 | for dy in translation_y: 79 | for d in d_list: 80 | for phi in [0,90,180,270]: 81 | a = d 82 | d = d * mirror 83 | theta = torch.Tensor([a,d,phi,dx,dy]).reshape(5,1).to(coords_q.device) 84 | coords_query_it = affine_trans_t(theta,coords_q) 85 | try: 86 | J_t.append(J_cal(coords_query_it[idx_q],coords_r[idx_r],cov_anchor_it,bleeding).sum().item()) 87 | except: 88 | continue 89 | theta_t.append(theta) 90 | if ifplot: 91 | prelocate_loss_plot(J_t,output_path,prefix) 92 | return(theta_t[np.argmin(J_t)]) 93 | 94 | def Affine_GD(coords_query_it_raw,coords_ref_it,cov_anchor_it,output_path,bleeding=500, dist_penalty = 0,diff_step = 50,alpha_basis = np.reshape(np.array([0,0,1/5,2,2]),[5,1]),iterations = 50,prefix='test',attention_params = [None,3,1,0],scale_t = 1,coords_log = False,index_list = None, mid_visual = False,early_stop_thres = 1, ifrigid = False): 95 | idx_q = np.ones(coords_query_it_raw.shape[0],dtype=bool) if index_list is None else index_list[0] 96 | idx_r = np.ones(coords_ref_it.shape[0],dtype=bool) if index_list is None else index_list[1] 97 | dev = coords_query_it_raw.device 98 | theta = torch.Tensor([1,1,0,0,0]).reshape(5,1).to(dev) # initial theta, [a,d,phi,t1,t2] 99 | coords_query_it = coords_query_it_raw.clone() 
100 | plot_mid(coords_query_it.cpu() * scale_t,coords_ref_it.cpu() * scale_t,output_path,prefix + '_init',scale_bar_t=None) if mid_visual else None 101 | similarity_score = [J_cal(coords_query_it[idx_q],coords_ref_it[idx_r],cov_anchor_it,bleeding,dist_penalty,attention_params).sum().cpu().item()] 102 | it_J = [] 103 | it_theta = [] 104 | coords_q_log = [] 105 | delta_similarity_score = [np.inf] * 5 106 | t = trange(iterations, desc='', leave=True) 107 | for it in t: 108 | alpha = alpha_init(alpha_basis,it,dev) 109 | ## de_sscore 110 | dJ_dxy_mat = dJ_dt_cal(coords_query_it[idx_q], 111 | coords_ref_it[idx_r], 112 | diff_step, 113 | dev, 114 | cov_anchor_it, 115 | bleeding, 116 | dist_penalty, 117 | attention_params) 118 | 119 | dJ_dtheta = dJ_dtheta_cal(coords_query_it[idx_q,0], 120 | coords_query_it[idx_q,1], 121 | dJ_dxy_mat,theta,dev,ifrigid = ifrigid) 122 | theta = theta_renew(theta,dJ_dtheta,alpha,ifrigid = ifrigid) 123 | 124 | coords_query_it = affine_trans_t(theta,coords_query_it_raw) 125 | it_J.append(dJ_dtheta) 126 | it_theta.append(theta) 127 | if coords_log: 128 | coords_q_log.append(coords_query_it.detach().cpu().numpy()) 129 | 130 | sscore_t = J_cal(coords_query_it[idx_q],coords_ref_it[idx_r],cov_anchor_it,bleeding,dist_penalty,attention_params).sum().cpu().item() 131 | # print(f'Loss: {sscore_t}') 132 | t.set_description(f'Loss: {sscore_t:.3f}') 133 | t.refresh() 134 | similarity_score.append(sscore_t) 135 | if mid_visual: 136 | if (it % 20 == 0) | (it == 0): 137 | plot_mid(coords_query_it.cpu() * scale_t,coords_ref_it.cpu() * scale_t,output_path,prefix + str(int(it/10 + 0.5)),scale_bar_t=None) 138 | if early_stop_thres is not None and it > 200: 139 | delta_similarity_score.append(similarity_score[-2] - similarity_score[-1]) 140 | if np.all(np.array(delta_similarity_score[-5:]) < early_stop_thres): 141 | print(f'Early stop at {it}th iteration.') 142 | break 143 | return([similarity_score,it_J,it_theta,coords_q_log]) 144 | 145 | def 
BSpline_GD(coords_q,coords_r,cov_anchor_it,iterations,output_path,bleeding, dist_penalty = 0, alpha_basis = 1000,diff_step = 50,mesh_size = 5,prefix = 'test',mesh_weight = None,attention_params = [None,3,1,0],scale_t = 1,coords_log = False, index_list = None, mid_visual = False,max_xy = None,renew_mesh_trans = True,restriction_t = 0.5): 146 | idx_q = np.ones(coords_q.shape[0],dtype=bool) if index_list is None else index_list[0] 147 | idx_r = np.ones(coords_r.shape[0],dtype=bool) if index_list is None else index_list[1] 148 | dev = coords_q.device 149 | plot_mid(coords_q.cpu() * scale_t,coords_r.cpu()* scale_t,output_path,prefix + '_FFD_initial_' + str(iterations),scale_bar_t=None) if mid_visual else None 150 | 151 | max_xy = coords_q.max(0)[0].cpu() if max_xy is None else max_xy 152 | mesh,mesh_weight,kls,dxy_ffd_all,delta = BSpline_GD_preparation(max_xy,mesh_size,dev,mesh_weight) 153 | coords_query_it = coords_q.clone() 154 | 155 | similarity_score = [J_cal(coords_query_it[idx_q],coords_r[idx_r],cov_anchor_it,bleeding,dist_penalty,attention_params).sum().cpu().item()] 156 | mesh_trans_list = [] 157 | coords_q_log = [] 158 | mesh_trans = mesh.clone() 159 | max_movement = (max_xy / (mesh_size - 1.) 
* restriction_t).to(mesh.device).unsqueeze(-1).unsqueeze(-1) 160 | t = trange(iterations, desc='', leave=True) 161 | for it in t: 162 | dJ_dxy_mat = dJ_dt_cal(coords_query_it[idx_q], 163 | coords_r[idx_r], 164 | diff_step, 165 | dev, 166 | cov_anchor_it, 167 | bleeding, 168 | dist_penalty, 169 | attention_params) 170 | if renew_mesh_trans or it == 0: 171 | uv_raw, ij_raw = BSpline_GD_uv_ij_calculate(coords_query_it,delta,dev) 172 | uv = uv_raw[:,idx_q] # 2 * N[idx] 173 | ij = ij_raw[:,idx_q] # 2 * N[idx] 174 | 175 | result_B_t = B_matrix(uv,kls) ## 16 * N[idx] 176 | dxy_ffd = get_dxy_ffd(ij,result_B_t,mesh,dJ_dxy_mat,mesh_weight,alpha_basis) 177 | 178 | if renew_mesh_trans: 179 | mesh_trans = mesh + dxy_ffd 180 | else: 181 | mesh_trans = mesh + torch.clamp(mesh_trans + dxy_ffd - mesh, min=-max_movement, max=max_movement) 182 | mesh_trans_list.append(mesh_trans) 183 | coords_query_it = BSpline_renew_coords(uv_raw,kls,ij_raw,mesh_trans) 184 | if coords_log: 185 | coords_q_log.append(coords_query_it.detach().cpu().numpy()) 186 | sscore_t = J_cal(coords_query_it[idx_q],coords_r[idx_r],cov_anchor_it,bleeding,dist_penalty,attention_params).sum().cpu().item() 187 | # print(f'Loss: {sscore_t}') 188 | t.set_description(f'Loss: {sscore_t:.3f}') 189 | t.refresh() 190 | 191 | similarity_score.append(sscore_t) 192 | if mid_visual: 193 | if (it % 20 == 0) | (it == 0): 194 | plot_mid(coords_query_it.cpu() * scale_t,coords_r.cpu() * scale_t,output_path,prefix + '_FFD_it_' + str(it),scale_bar_t=None) 195 | mesh_plot(mesh.cpu(),coords_q_t=coords_query_it.cpu(),mesh_trans_t=mesh_trans.cpu()) 196 | plt.savefig(f'{output_path}/{prefix}_mesh_plot_it_{it}.pdf') 197 | plt.clf() 198 | ### visualization 199 | plt.figure(figsize=[20,10]) 200 | plt.subplot(1,2,1) 201 | plt.scatter(np.array(coords_q.cpu()[:,0].tolist()) * scale_t, 202 | np.array(coords_q.cpu()[:,1].tolist()) * scale_t, s=2,edgecolors='none', alpha = 0.5,rasterized=True, 203 | c='blue',label = 'Before') 204 | 
plt.scatter(np.array(coords_query_it.cpu()[:,0].tolist()) * scale_t, 205 | np.array(coords_query_it.cpu()[:,1].tolist()) * scale_t, s=2,edgecolors='none', alpha = 0.7,rasterized=True, 206 | c='#ef233c',label = 'After') 207 | plt.xticks(fontsize=20) 208 | plt.yticks(fontsize=20) 209 | plt.legend(fontsize=15) 210 | plt.axis('equal') 211 | plt.subplot(1,2,2) 212 | titles = 'loss = ' + format(similarity_score[-1],'.1f') 213 | plt.scatter(list(range(0,len(similarity_score))),similarity_score,s = 5) 214 | plt.title(titles,fontsize=20) 215 | plt.savefig(os.path.join(output_path,prefix + '_after_Bspline_' + str(iterations) + '.pdf')) 216 | return([coords_query_it,mesh_trans_list,dxy_ffd_all,similarity_score,coords_q_log]) 217 | 218 | def J_cal(coords_q,coords_r,cov_mat,bleeding = 10, dist_penalty = 0,attention_params = [None,3,1,0]): 219 | attention_region,double_penalty,penalty_inc_all,penalty_inc_both = attention_params 220 | bleeding_x = coords_q[:, 0].min() - bleeding, coords_q[:, 0].max() + bleeding 221 | bleeding_y = coords_q[:, 1].min() - bleeding, coords_q[:, 1].max() + bleeding 222 | 223 | sub_ind = ((coords_r[:, 0] > bleeding_x[0]) & (coords_r[:, 0] < bleeding_x[1]) & 224 | (coords_r[:, 1] > bleeding_y[0]) & (coords_r[:, 1] < bleeding_y[1])) 225 | 226 | cov_mat_t = cov_mat[:,sub_ind] 227 | dist = torch.cdist(coords_q,coords_r[sub_ind,:]) 228 | min_dist_values, close_idx = torch.min(dist, dim=1) 229 | 230 | tmp1 = torch.stack((torch.arange(coords_q.shape[0], device=coords_q.device), close_idx)).T 231 | s_score_mat = cov_mat_t[tmp1[:, 0], tmp1[:, 1]] 232 | 233 | if(dist_penalty != 0): 234 | penalty_tres = torch.sqrt((coords_r[:,0].max() - coords_r[:,0].min()) * (coords_r[:,1].max() - coords_r[:,1].min()) / coords_r.shape[0]) 235 | dist_d = min_dist_values / penalty_tres 236 | if(type(attention_region) is np.ndarray): 237 | attention_region = torch.tensor(attention_region, device=coords_q.device) 238 | dist_d[attention_region] = min_dist_values[attention_region] / 
(penalty_tres/double_penalty) 239 | dist_d[dist_d < 1] = 1 240 | dist_d[dist_d > 1] *= dist_penalty 241 | dist_d[attention_region] *= penalty_inc_all 242 | dist_d[(dist_d > 1) & attention_region] *= (penalty_inc_both/dist_penalty + 1) 243 | else: 244 | dist_d[dist_d < 1] = 1 245 | dist_d[dist_d > 1] *= dist_penalty 246 | return s_score_mat * dist_d 247 | return s_score_mat 248 | 249 | def alpha_init(alpha_basis,it,dev): 250 | return 5/torch.pow(torch.Tensor([it/40 + 1]).to(dev),0.6) * alpha_basis 251 | 252 | def dJ_dt_cal(coords_q,coords_r,diff_step,dev,cov_anchor_it,bleeding,dist_penalty,attention_params): 253 | dJ_dy = (J_cal(coords_q + torch.tensor([0,diff_step], device=dev), 254 | coords_r, 255 | cov_anchor_it, 256 | bleeding, 257 | dist_penalty, 258 | attention_params) - 259 | J_cal(coords_q + torch.tensor([0,-diff_step], device=dev), 260 | coords_r, 261 | cov_anchor_it, 262 | bleeding, 263 | dist_penalty, 264 | attention_params)) / (2 * diff_step) 265 | dJ_dx = (J_cal(coords_q + torch.tensor([diff_step,0], device=dev), 266 | coords_r, 267 | cov_anchor_it, 268 | bleeding, 269 | dist_penalty, 270 | attention_params) - 271 | J_cal(coords_q + torch.tensor([-diff_step,0], device=dev), 272 | coords_r, 273 | cov_anchor_it, 274 | bleeding, 275 | dist_penalty, 276 | attention_params)) / (2 * diff_step) 277 | dJ_dxy_mat = torch.vstack((dJ_dx,dJ_dy)) # [dJ_{i}/dx_{i},dJ_{i}/dy_{i}] (2 * N) 278 | return dJ_dxy_mat 279 | 280 | def dJ_dtheta_cal(xi,yi,dJ_dxy_mat,theta,dev,ifrigid = False): 281 | ''' 282 | #dxy_da: 283 | #{x * cos(rad_phi), x * sin(rad_phi)} 284 | #dxy_dd: 285 | #{-y * sin(rad_phi), y * cos(rad_phi)} 286 | #dxy_dphi: 287 | #{-d * y * cos(rad_phi) - a * x * sin(rad_phi), a * x * cos(rad_phi) - d * y * sin(rad_phi)} 288 | #dxy_dt1: 289 | #{1, 0} 290 | #dxy_dt2: 291 | #{0, 1} 292 | 293 | # when we set d = a (rigid): 294 | #dxy_da 295 | #{x * cos(rad_phi) - y * sin(rad_phi), y * cos(rad_phi) + x * sin(rad_phi)} 296 | #dxy_dd - set as the same value as dxy_da 
297 | #{x * cos(rad_phi) - y * sin(rad_phi), y * cos(rad_phi) + x * sin(rad_phi)} 298 | #dxy_dphi 299 | #{-a * y * cos(rad_phi) - a * x * sin(rad_phi), a * x * cos(rad_phi) - a * y * sin(rad_phi)} 300 | ''' 301 | N = xi.shape[0] 302 | rad_phi = theta[2,0].deg2rad() 303 | cos_rad_phi = rad_phi.cos() 304 | sin_rad_phi = rad_phi.sin() 305 | ones = torch.ones(N, device=dev) 306 | zeros = torch.zeros(N, device=dev) 307 | if ifrigid: 308 | #### let d = a, only allow scaling, rotation and translation (Similarity transformation) 309 | #### If we want to use pure rigid transformation, just set `alpha_basis` as `[0,0,x,x,x]`, then the theta[0] will be always 1. 310 | dxy_dtheta = torch.stack([ 311 | torch.stack([ 312 | xi * cos_rad_phi - yi * sin_rad_phi, #dxy_da (rigid) 313 | xi * cos_rad_phi - yi * sin_rad_phi, #dxy_dd - won't use (rigid) 314 | -theta[0] * cos_rad_phi * yi - theta[0] * xi * sin_rad_phi, #dxy_dphi 315 | ones, #dxy_dt1 316 | zeros]), #dxy_dt2 317 | torch.stack([ 318 | yi * cos_rad_phi + xi * sin_rad_phi, #dxy_da (rigid) 319 | yi * cos_rad_phi + xi * sin_rad_phi, #dxy_dd - won't use (rigid) 320 | theta[0] * xi * cos_rad_phi - theta[0] * yi * sin_rad_phi, #dxy_dphi 321 | zeros, #dxy_dt1 322 | ones])]) #dxy_dt2 323 | else: 324 | dxy_dtheta = torch.stack([ 325 | torch.stack([ 326 | xi * cos_rad_phi, #dxy_da 327 | -yi * sin_rad_phi, #dxy_dd 328 | -theta[1] * cos_rad_phi * yi - theta[0] * xi * sin_rad_phi, #dxy_dphi 329 | ones, #dxy_dt1 330 | zeros]), #dxy_dt2 331 | torch.stack([ 332 | xi * sin_rad_phi, #dxy_da 333 | yi * cos_rad_phi, #dxy_dd 334 | theta[0] * xi * cos_rad_phi - theta[1] * yi * sin_rad_phi, #dxy_dphi 335 | zeros, #dxy_dt1 336 | ones])]) #dxy_dt2 337 | 338 | dJ_dtheta = torch.bmm(dxy_dtheta.permute(2, 1, 0), ### [N,5,2] 339 | dJ_dxy_mat.transpose(0, 1).unsqueeze(-1) ### [N,2,1] 340 | ).squeeze(2) # [dJ_{i}/dtheta_{k}] (N * 5) 341 | dJ_dtheta = dJ_dtheta.sum(0) 342 | 343 | return dJ_dtheta 344 | 345 | def theta_renew(theta,dJ_dtheta,alpha,ifrigid = 
False): 346 | alpha_dJ = alpha * dJ_dtheta.reshape(5,1) 347 | alpha_dJ[0:3] = alpha_dJ[0:3] / 1000 # avoid dtheta_{abcd} change a lot of x and y 348 | if ifrigid & (theta[0] == -theta[1]): 349 | # only when the rigid transformation is allowed, we should check the value of d and a if they are mirrored. 350 | # if d and a are mirrored (setting in the prelocate `d = d * mirror``), we should set alpha_dJ[1] as the `-alpha_dJ[1]`. 351 | alpha_dJ[1] = -alpha_dJ[1] 352 | theta_new = theta - alpha_dJ 353 | return theta_new 354 | 355 | def affine_trans_t(theta,coords_t): 356 | rad_phi = theta[2,0].deg2rad() 357 | cos_rad_phi = rad_phi.cos() 358 | sin_rad_phi = rad_phi.sin() 359 | A = torch.Tensor([[theta[0,0] * cos_rad_phi, -theta[1,0] * sin_rad_phi],[theta[0,0] * sin_rad_phi, theta[1,0] * cos_rad_phi]]).to(theta.device) 360 | t_vec = theta[3:5,:] 361 | coords_t1 = torch.mm(A,coords_t.T) + t_vec 362 | coords_t1 = coords_t1.T 363 | return coords_t1 364 | 365 | def torch_Bspline(uv, kl): 366 | return ( 367 | torch.where(kl == 0, (1 - uv) ** 3 / 6, 368 | torch.where(kl == 1, uv ** 3 / 2 - uv ** 2 + 2 / 3, 369 | torch.where(kl == 2, (-3 * uv ** 3 + 3 * uv ** 2 + 3 * uv + 1) / 6, 370 | torch.where(kl == 3, uv ** 3 / 6, torch.zeros_like(uv))))) 371 | ) 372 | 373 | def BSpline_GD_preparation(max_xy,mesh_size,dev,mesh_weight): 374 | delta = max_xy / (mesh_size - 1.) 
375 | mesh = np.ones((2, mesh_size + 3, mesh_size + 3)) ## 2 * (mesh_size + 3) * (mesh_size + 3) 376 | for i in range(mesh_size + 3): 377 | for j in range(mesh_size + 3): 378 | mesh[:, i, j] = [(i - 1) * delta[0], (j - 1) * delta[1]] ## 0 - -delta, 1 - 0, 2 - delta, ..., 6 - delta * 5, 7 - delta * 6 (last row) 379 | mesh = torch.tensor(mesh).to(dev) 380 | mesh_weight = torch.tensor(mesh_weight).to(dev) if type(mesh_weight) is np.ndarray else 1 381 | kls = torch.stack(torch.meshgrid(torch.arange(4), torch.arange(4))).flatten(1).to(dev) ## 2 * 16 382 | dxy_ffd_all = torch.zeros(mesh.shape, device=dev) ## 2 * (mesh_size + 3) * (mesh_size + 3) 383 | return mesh,mesh_weight,kls,dxy_ffd_all,delta 384 | 385 | def BSpline_GD_uv_ij_calculate(coords_query_it,delta,dev): 386 | pos_reg = coords_query_it.T / delta.reshape(2,1).to(dev) # 2 * N 387 | pos_floor = pos_reg.floor().long() # 2 * N 388 | uv_raw = pos_reg - pos_floor # 2 * N 389 | ij_raw = pos_floor - 1 # 2 * N 390 | return uv_raw, ij_raw 391 | 392 | def B_matrix(uv_t, kls_t): 393 | result_B_list = [] 394 | for kl in kls_t.T: 395 | B = torch_Bspline(uv_t, kl.view(2, 1)) # 2 * N[idx] 396 | result_B_list.append(B.prod(0, keepdim=True)) # 1 * N[idx] ; .prod() - product of all elements in the tensor along a given dimension (0 - reduce along rows, 1 - reduce along columns) 397 | return torch.cat(result_B_list,0) # 16 * N[idx] 398 | 399 | def get_dxy_ffd(ij,result_B_t,mesh,dJ_dxy_mat,mesh_weight,alpha_basis): 400 | dxy_ffd_t = torch.zeros(mesh.shape, device=result_B_t.device) 401 | ij_0 = ij[0] + 1 402 | ij_1 = ij[1] + 1 403 | for k in range(dxy_ffd_t.shape[1]): 404 | for l in range(dxy_ffd_t.shape[2]): 405 | mask = (ij_0 <= k) & (k <= ij_0 + 3) & (ij_1 <= l) & (l <= ij_1 + 3) 406 | if mask.any(): # check if there is any True in the mask 407 | idx_kl = mask.nonzero().flatten() 408 | ij_t = torch.tensor([k, l], device=ij.device) - (ij[:, idx_kl].T + 1) 409 | keys = ij_t[:, 0] * 4 + ij_t[:, 1] 410 | t33 = result_B_t[keys, 
idx_kl] 411 | dxy_ffd_t[:,k,l] -= torch.matmul(dJ_dxy_mat[:,idx_kl],t33.unsqueeze(1).float()).squeeze(1) 412 | dxy_ffd_t *= mesh_weight 413 | dxy_ffd_t = dxy_ffd_t * alpha_basis 414 | return dxy_ffd_t 415 | 416 | def BSpline_renew_coords(uv_t,kls_t,ij_t,mesh_trans): 417 | result_tt = torch.zeros_like(uv_t, dtype=torch.float32) 418 | for kl in kls_t.T: 419 | B = torch_Bspline(uv_t, kl.view(2, 1)) 420 | pivots = (ij_t + 1 + kl.view(2, 1)).clamp(0, mesh_trans.size(-1) - 1) 421 | mesh_t = mesh_trans[:, pivots[0], pivots[1]] 422 | result_tt += B.prod(0, keepdim=True) * mesh_t 423 | return result_tt.T 424 | 425 | def reg_total_t(coords_q,coords_r,params_dist): 426 | dev = params_dist.device 427 | mean_q = coords_q.mean(0) 428 | mean_r = coords_r.mean(0) 429 | coords_q_t = torch.tensor(np.array(coords_q) - mean_q).float().to(dev) ## Initial location 430 | coords_q_r1 = affine_trans_t(params_dist.theta_r1,coords_q_t) ## Prelocation 1st Affine 431 | coords_q_r2 = affine_trans_t(params_dist.theta_r2,coords_q_r1) ## Affine transformation 2st Affine 432 | if params_dist.mesh_trans_list != [] and params_dist.mesh_trans_list != [[]]: 433 | coords_q_r3 = coords_q_r2.clone() 434 | for round_t in range(len(params_dist.mesh_trans_list)): 435 | coords_q_r3 = coords_q_r3.clone() - params_dist.min_qr2[round_t] 436 | coords_q_r3 = FFD_Bspline_apply_t(coords_q_r3.clone(),params_dist,round_t) 437 | coords_q_r3 = coords_q_r3.clone() + params_dist.min_qr2[round_t] 438 | coords_q_f = coords_q_r3.clone() 439 | else: 440 | coords_q_f = coords_q_r2 441 | coords_q_reconstruct = coords_q_f + torch.tensor(mean_r).to(dev) 442 | coords_q_reconstruct = coords_q_reconstruct.float() 443 | return coords_q_f,coords_q_reconstruct 444 | 445 | def FFD_Bspline_apply_t(coords_q,params_dist,round_t = 0): 446 | mesh_trans_list = params_dist.mesh_trans_list[round_t] 447 | dev = coords_q.device 448 | img_size = params_dist.img_size_bs[round_t] 449 | mesh_size = mesh_trans_list[0].shape[2] - 3 450 | delta = 
img_size / (mesh_size - 1.) 451 | coords_query_it = copy.deepcopy(coords_q) 452 | 453 | for it in trange(len(mesh_trans_list), desc='', leave=True): 454 | mesh_trans = mesh_trans_list[it] 455 | pos_reg = coords_query_it.T / delta.reshape(2,1).to(dev) 456 | pos_floor = pos_reg.floor().long() 457 | uv = pos_reg - pos_floor 458 | ij = pos_floor - 1 459 | kls = torch.stack(torch.meshgrid(torch.arange(4), torch.arange(4))).flatten(1).to(dev) 460 | result_tt = torch.zeros_like(uv).float() 461 | for kl in kls.T: 462 | B = torch_Bspline(uv, kl.view(2, 1)) 463 | pivots = (ij + 1 + kl.view(2, 1)).clamp(0, mesh_trans.size(-1) - 1) 464 | mesh_t = mesh_trans[:, pivots[0], pivots[1]] 465 | result_tt += B.prod(0, keepdim=True) * mesh_t 466 | coords_query_it = result_tt.T 467 | return coords_query_it 468 | 469 | def rescale_coords(coords_raw,graph_list,rescale = False): 470 | rescale_factor = 1 471 | if rescale: 472 | coords_raw = coords_raw.copy() 473 | for sample_t in graph_list: 474 | rescale_factor_t = 22340 / np.abs(coords_raw[sample_t]).max() 475 | coords_raw[sample_t] = coords_raw[sample_t].copy() * rescale_factor_t 476 | if sample_t == graph_list[1]: 477 | rescale_factor = rescale_factor_t 478 | return coords_raw,rescale_factor 479 | 480 | #################### Visualization #################### 481 | 482 | def mesh_plot(mesh_t,coords_q_t,mesh_trans_t = None): 483 | mesh_no_last_row = mesh_t[:, :, :].numpy() 484 | plt.figure(figsize=[10,10]) 485 | plt.plot(mesh_no_last_row[0], mesh_no_last_row[1], 'blue') 486 | plt.plot(mesh_no_last_row.T[..., 0], mesh_no_last_row.T[..., 1], 'blue') 487 | if(type(mesh_trans_t) is not type(None)): 488 | mesh_trans_no_last_row = mesh_trans_t[:, :, :].numpy() 489 | plt.plot(mesh_trans_no_last_row[0], mesh_trans_no_last_row[1], 'orange') 490 | plt.plot(mesh_trans_no_last_row.T[..., 0], mesh_trans_no_last_row.T[..., 1], 'orange') 491 | plt.scatter(coords_q_t.T[0,:],coords_q_t.T[1,:],c='blue',s = 0.5,alpha=0.5, rasterized=True) 492 | 493 | def 
plot_mid(coords_q,coords_r,output_path='',filename = None,title_t = ['ref','query'],s_t = 8,scale_bar_t = None): 494 | plt.rcParams.update({'font.size' : 30,'axes.titlesize' : 30,'pdf.fonttype':42,'legend.markerscale' : 5}) 495 | plt.figure(figsize=[10,12]) 496 | plt.scatter(np.array(coords_r)[:,0].tolist(), 497 | np.array(coords_r)[:,1].tolist(), s=s_t,edgecolors='none', alpha = 0.5,rasterized=True, 498 | c='#9295CA',label = title_t[0]) 499 | plt.scatter(np.array(coords_q)[:,0].tolist(), 500 | np.array(coords_q)[:,1].tolist(), s=s_t,edgecolors='none', alpha = 0.5,rasterized=True, 501 | c='#E66665',label = title_t[1]) 502 | plt.legend(fontsize=15) 503 | plt.axis('equal') 504 | if (type(scale_bar_t) != type(None)): 505 | add_scale_bar(scale_bar_t[0],scale_bar_t[1]) 506 | if (filename != None): 507 | plt.savefig(os.path.join(output_path,filename + '.pdf'),dpi = 100) 508 | 509 | def corr_heat(coords_q,coords_r,corr,output_path,title_t = ['Corr in ref','Anchor in query'],filename=None,scale_bar_t = None): 510 | plt.rcParams.update({'font.size' : 20,'axes.titlesize' : 20,'pdf.fonttype':42}) 511 | random.seed(2) 512 | sampled_points = np.sort(random.sample(list(range(0,coords_q.shape[0])),20)) 513 | plt.figure(figsize=((40,25))) 514 | for t in range(0,len(sampled_points)): 515 | 516 | plt_ind = t * 2 517 | ins_cell_idx = sampled_points[t] 518 | col_value = corr[ins_cell_idx,:] 519 | col_value_bg = [0] * coords_q.shape[0] 520 | col_value_bg[ins_cell_idx] = 1 521 | size_value_bg = [5] * coords_q.shape[0] 522 | size_value_bg[ins_cell_idx] = 30 523 | plt.subplot(5,8,plt_ind + 1) 524 | plt.scatter(np.array(coords_r[:,0]), np.array(coords_r[:,1]), s=5,edgecolors='none', 525 | c=col_value,cmap = 'vlag',vmin = -1,vmax= 1,rasterized=True) 526 | 527 | plt.title(title_t[0]) 528 | plt.axis('equal') 529 | if (type(scale_bar_t) != type(None)): 530 | add_scale_bar(scale_bar_t[0],scale_bar_t[1]) 531 | plt.subplot(5,8,plt_ind + 2) 532 | plt.scatter(np.array(coords_q[:,0]), 
np.array(coords_q[:,1]), s=size_value_bg,edgecolors='none', 533 | c=col_value_bg,cmap = 'vlag',vmin = -1,vmax= 1,rasterized=True) 534 | plt.scatter(np.array(coords_q[ins_cell_idx,0]), np.array(coords_q[ins_cell_idx,1]), s=size_value_bg[ins_cell_idx],edgecolors='none', 535 | c=col_value_bg[ins_cell_idx],cmap = 'vlag',vmin = -1,vmax= 1,rasterized=True) 536 | plt.title(title_t[1]) 537 | plt.axis('equal') 538 | if (type(scale_bar_t) != type(None)): 539 | add_scale_bar(scale_bar_t[0],scale_bar_t[1]) 540 | plt.tight_layout() 541 | plt.colorbar() 542 | if (filename != None): 543 | plt.savefig(os.path.join(output_path,filename + '.pdf'),dpi=100,transparent=True) 544 | 545 | def prelocate_loss_plot(J_t,output_path,prefix = 'test'): 546 | plt.rcParams.update({'font.size' : 15}) 547 | plt.figure(figsize=[5,5]) 548 | plt.scatter(x=list(range(0,len(J_t))),y=J_t) 549 | plt.savefig(f'{output_path}/{prefix}_prelocate_loss.pdf') 550 | 551 | def register_result(coords_q,coords_r,cov_anchor_t,bleeding,embed_stack,output_path,k=8,prefix='test',scale_t = 1,index_list = None): 552 | idx_q = np.ones(coords_q.shape[0],dtype=bool) if index_list is None else index_list[0] 553 | idx_r = np.ones(coords_r.shape[0],dtype=bool) if index_list is None else index_list[1] 554 | coords_q = coords_q * scale_t 555 | coords_r = coords_r * scale_t 556 | kmeans = KMeans(n_clusters=k,random_state=0).fit(embed_stack) 557 | cell_label = kmeans.labels_ 558 | cluster_pl = sns.color_palette('tab20',len(np.unique(cell_label))) 559 | ### panel 1 ### 560 | plot_mid(coords_q[idx_q],coords_r[idx_r],output_path,f'{prefix}_Results_1', scale_bar_t = None) 561 | ### panel 2 ### 562 | plt.figure(figsize=[10,12]) 563 | plt.rcParams.update({'font.size' : 10,'axes.titlesize' : 20,'pdf.fonttype':42}) 564 | col=coords_q[idx_q,0] 565 | row=coords_q[idx_q,1] 566 | cell_type_t = cell_label[0:coords_q[idx_q].shape[0]] 567 | for i in set(cell_type_t): 568 | plt.scatter(np.array(col)[cell_type_t == i], 569 | 
np.array(row)[cell_type_t == i], s=12,edgecolors='none',alpha = 0.5,rasterized=True, 570 | c=np.array(cluster_pl)[cell_type_t[cell_type_t == i]],label = str(i)) 571 | col=coords_r[idx_r,0] 572 | row=coords_r[idx_r,1] 573 | cell_type_t = cell_label[coords_q[idx_q].shape[0]:] 574 | for i in set(cell_type_t): 575 | plt.scatter(np.array(col)[cell_type_t == i], 576 | np.array(row)[cell_type_t == i], s=12,edgecolors='none',alpha = 0.5,rasterized=True, 577 | c=np.array(cluster_pl)[cell_type_t[cell_type_t == i]],label = str(i)) 578 | plt.axis('equal') 579 | plt.xticks(fontsize=20) 580 | plt.yticks(fontsize=20) 581 | plt.title('K means (k = ' + str(k) + ')',fontsize=30) 582 | add_scale_bar(200,'200 µm') 583 | plt.savefig(f'{output_path}/{prefix}_Results_2.pdf',dpi = 300) 584 | ### panel 3 ### 585 | plt.figure(figsize=[20,12]) 586 | plt.subplot(1,2,1) 587 | t_score = J_cal(torch.from_numpy(coords_q[idx_q]),torch.from_numpy(coords_r[idx_r]),cov_anchor_t,bleeding) 588 | plt.scatter(coords_q[idx_q,0],coords_q[idx_q,1],c=1 - t_score,cmap = 'vlag',vmin = -1,vmax = 1,s = 15,edgecolors='none',alpha=0.5,rasterized=True) 589 | add_scale_bar(200,'200 µm') 590 | plt.subplot(1,2,2) 591 | plt.scatter(coords_q[0,0],coords_q[0,1],c=[0],cmap = 'vlag',vmin = -1,vmax = 1,s = 15,alpha=0.5) 592 | plt.axis('off') 593 | plt.colorbar() 594 | plt.savefig(f'{output_path}/{prefix}_Results_3.pdf',dpi = 300) 595 | 596 | def affine_reg_params(it_theta,similarity_score,iterations,output_path,prefix='test'): 597 | plt.rcParams.update({'font.size' : 15,'axes.titlesize' : 15,'pdf.fonttype':42}) 598 | similarity_score_t = copy.deepcopy(similarity_score) 599 | titles = ['a','d','φ','t1','t2','loss = ' + format(similarity_score[-1],'.1f')] 600 | plt.figure(figsize=[15,8]) 601 | for i in range(0,6): 602 | plt.subplot(2,4,i+1) 603 | if i == 5: 604 | plt.scatter(list(range(0,len(similarity_score_t))),similarity_score_t,s = 5) 605 | else: 606 | # plt.scatter(x = range(0,iterations),y=np.array(it_theta)[:,i,0],s = 
5) 607 | plt.scatter(x = range(1,len(similarity_score_t)),y=np.array(it_theta)[:,i,0],s = 5) 608 | plt.title(titles[i],fontsize=20) 609 | plt.savefig(os.path.join(output_path,prefix + '_params_Affine_GD_' + str(iterations) + 'its.pdf')) 610 | 611 | def CAST_STACK_rough(coords_raw_list, ifsquare=True, if_max_xy=True, percentile = None): 612 | ''' 613 | coords_raw_list: list of numpy arrays, each array is the coordinates of a layer 614 | ifsquare: if True, the coordinates will be scaled to a square 615 | if_max_xy: if True, the coordinates will be scaled to the max value of the `max_range_x` and `max_range_y`, respectively (if ifsquare is False), or the max value of [max_range_x,max_range_y] (if ifsquare is True) 616 | percentile: if not None, the min and max will be calculated based on the percentile of the coordinates for each slice. 617 | ''' 618 | # Convert list of arrays to a single numpy array for easier processing 619 | all_coords = np.concatenate(coords_raw_list) 620 | # Finding the global min and max for both x and y 621 | if percentile is None: 622 | min_x, min_y = np.min(all_coords, axis=0) 623 | max_x, max_y = np.max(all_coords, axis=0) 624 | else: 625 | min_x_list, min_y_list, max_x_list, max_y_list = [], [], [], [] 626 | for coords_t in coords_raw_list: 627 | min_x_list.append(np.percentile(coords_t[:,0],percentile)) 628 | min_y_list.append(np.percentile(coords_t[:,1],percentile)) 629 | max_x_list.append(np.percentile(coords_t[:,0],100-percentile)) 630 | max_y_list.append(np.percentile(coords_t[:,1],100-percentile)) 631 | min_x, min_y = np.min(min_x_list), np.min(min_y_list) 632 | max_x, max_y = np.max(max_x_list), np.max(max_y_list) 633 | max_xy = np.array([max_x - min_x, max_y - min_y]) 634 | scaled_coords_list = [] 635 | for coords_t in coords_raw_list: 636 | coords_t2 = (coords_t - coords_t.min(axis=0)) / np.ptp(coords_t, axis=0) 637 | if if_max_xy: 638 | max_xy_scale = max_xy 639 | else: 640 | max_xy_scale = max_xy / np.max(max_xy) 641 | 
scaled_coords = coords_t2 * np.max(max_xy_scale) if ifsquare else coords_t2 * max_xy_scale 642 | scaled_coords_list.append(scaled_coords) 643 | return scaled_coords_list 644 | 645 | #################### Calculation #################### 646 | def coords_minus_mean(coord_t): 647 | return np.array(coord_t) - np.mean(np.array(coord_t),axis = 0) 648 | 649 | def coords_minus_min(coord_t): 650 | return np.array(coord_t) - np.min(np.array(coord_t),axis = 0) 651 | 652 | def max_minus_value(corr): 653 | return np.max(corr) - corr 654 | 655 | def coords_minus_min_t(coord_t): 656 | return coord_t - coord_t.min(0)[0] 657 | 658 | def max_minus_value_t(corr): 659 | return corr.max() - corr 660 | 661 | def corr_dist(query_np, ref_np, nan_as = 'min'): 662 | from sklearn.metrics import pairwise_distances_chunked 663 | def chunked_callback(dist_matrix,start): 664 | return 1 - dist_matrix 665 | chunks = pairwise_distances_chunked(query_np, ref_np, metric='correlation', n_jobs=-1, working_memory=1024, reduce_func=chunked_callback) 666 | corr_q_r = np.vstack(list(chunks)) 667 | if nan_as == 'min': 668 | corr_q_r[np.isnan(corr_q_r)] = np.nanmin(corr_q_r) 669 | return corr_q_r 670 | 671 | def region_detect(embed_dict_t,coords0,k = 20): 672 | plot_row = int(np.floor((k+1)/4) + 1) 673 | kmeans = KMeans(n_clusters=k,random_state=0).fit(embed_dict_t) 674 | cell_label = kmeans.labels_ 675 | cluster_pl = sns.color_palette('tab20',len(np.unique(cell_label))) 676 | plt.figure(figsize=((20,5 * plot_row))) 677 | plt.subplot(plot_row,4,1) 678 | cell_label_idx = 0 679 | col=coords0[:,0].tolist() 680 | row=coords0[:,1].tolist() 681 | cell_type_t = cell_label[cell_label_idx:(cell_label_idx + coords0.shape[0])] 682 | cell_label_idx += coords0.shape[0] 683 | for i in set(cell_type_t): 684 | plt.scatter(np.array(col)[cell_type_t == i], 685 | np.array(row)[cell_type_t == i], s=5,edgecolors='none', 686 | c=np.array(cluster_pl)[cell_type_t[cell_type_t == i]],label = str(i)) 687 | plt.title(' (KMeans, k = ' + 
str(k) + ')',fontsize=20) 688 | plt.xticks(fontsize=20) 689 | plt.yticks(fontsize=20) 690 | plt.axis('equal') 691 | for j,i in enumerate(set(cell_type_t)): 692 | plt.subplot(plot_row,4,j+2) 693 | plt.scatter(np.array(col),np.array(row),s=3,c = '#DDDDDD') 694 | plt.scatter(np.array(col)[cell_type_t == i], 695 | np.array(row)[cell_type_t == i], s=5,edgecolors='none', 696 | c=np.array(cluster_pl)[cell_type_t[cell_type_t == i]],label = str(i)) 697 | plt.title(str(i),fontsize=20) 698 | plt.axis('equal') 699 | return cell_label 700 | -------------------------------------------------------------------------------- /CAST/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import CAST_STACK, CAST_MARK, CAST_PROJECT 2 | from .CAST_Stack import reg_params, region_detect, corr_dist, CAST_STACK_rough 3 | from .visualize import kmeans_plot_multiple, plot_mid, dsplot, plot_mid_v2 4 | from .utils import * 5 | -------------------------------------------------------------------------------- /CAST/main.py: -------------------------------------------------------------------------------- 1 | from .CAST_Mark import * 2 | from .CAST_Stack import * 3 | from .CAST_Projection import * 4 | from .utils import * 5 | from .visualize import * 6 | from .models.model_GCNII import Args, CCA_SSG 7 | 8 | def CAST_MARK(coords_raw_t,exp_dict_t,output_path_t,task_name_t = None,gpu_t = None,args = None,epoch_t = None, if_plot = True, graph_strategy = 'convex'): 9 | ### setting 10 | gpu_t = 0 if torch.cuda.is_available() and gpu_t is None else -1 11 | device = 'cuda:0' if gpu_t == 0 else 'cpu' 12 | samples = list(exp_dict_t.keys()) 13 | task_name_t = task_name_t if task_name_t is not None else 'task1' 14 | inputs = [] 15 | 16 | ### construct delaunay graphs and input data 17 | print(f'Constructing delaunay graphs for {len(samples)} samples...') 18 | for sample_t in samples: 19 | graph_dgl_t = 
delaunay_dgl(sample_t,coords_raw_t[sample_t],output_path_t,if_plot=if_plot,strategy_t = graph_strategy).to(device) 20 | feat_torch_t = torch.tensor(exp_dict_t[sample_t], dtype=torch.float32, device=device) 21 | inputs.append((sample_t, graph_dgl_t, feat_torch_t)) 22 | 23 | ### parameters setting 24 | if args is None: 25 | args = Args( 26 | dataname=task_name_t, # name of the dataset, used to save the log file 27 | gpu = gpu_t, # gpu id, set to zero for single-GPU nodes 28 | epochs=400, # number of epochs for training 29 | lr1= 1e-3, # learning rate 30 | wd1= 0, # weight decay 31 | lambd= 1e-3, # lambda in the loss function, refer to online methods 32 | n_layers=9, # number of GCNII layers, more layers mean a deeper model, larger reception field, at a cost of VRAM usage and computation time 33 | der=0.5, # edge dropout rate in CCA-SSG 34 | dfr=0.3, # feature dropout rate in CCA-SSG 35 | use_encoder=True, # perform a single-layer dimension reduction before the GNNs, helps save VRAM and computation time if the gene panel is large 36 | encoder_dim=512, # encoder dimension, ignore if `use_encoder` set to `False` 37 | ) 38 | args.epochs = epoch_t if epoch_t is not None else args.epochs 39 | 40 | ### Initialize the model 41 | in_dim = inputs[0][-1].size(-1) 42 | model = CCA_SSG(in_dim=in_dim, encoder_dim=args.encoder_dim, n_layers=args.n_layers, use_encoder=args.use_encoder).to(args.device) 43 | 44 | ### Training 45 | print(f'Training on {args.device}...') 46 | embed_dict, loss_log, model = train_seq(graphs=inputs, args=args, dump_epoch_list=[], out_prefix=f'{output_path_t}/{task_name_t}_seq_train', model=model) 47 | 48 | ### Saving the results 49 | torch.save(embed_dict, f'{output_path_t}/demo_embed_dict.pt') 50 | torch.save(loss_log, f'{output_path_t}/demo_loss_log.pt') 51 | torch.save(model, f'{output_path_t}/demo_model_trained.pt') 52 | print(f'Finished.') 53 | print(f'The embedding, log, model files were saved to {output_path_t}') 54 | return embed_dict 55 | 56 | def 
CAST_STACK(coords_raw,embed_dict,output_path,graph_list,params_dist= None,tmp1_f1_idx = None, mid_visual = False, sub_node_idxs = None, rescale = False, corr_q_r = None, if_embed_sub = False, early_stop_thres = None, renew_mesh_trans = True): 57 | ### setting parameters 58 | query_sample = graph_list[0] 59 | ref_sample = graph_list[1] 60 | prefix_t = f'{query_sample}_align_to_{ref_sample}' 61 | result_log = dict() 62 | coords_raw, result_log['ref_rescale_factor'] = rescale_coords(coords_raw,graph_list,rescale = rescale) 63 | 64 | if sub_node_idxs is None: 65 | sub_node_idxs = { 66 | query_sample: np.ones(coords_raw[query_sample].shape[0],dtype=bool), 67 | ref_sample: np.ones(coords_raw[ref_sample].shape[0],dtype=bool) 68 | } 69 | 70 | if params_dist is None: 71 | params_dist = reg_params(dataname = query_sample, 72 | gpu = 0, 73 | #### Affine parameters 74 | iterations=500, 75 | dist_penalty1=0, 76 | bleeding=500, 77 | d_list = [3,2,1,1/2,1/3], 78 | attention_params = [None,3,1,0], 79 | #### FFD parameters 80 | dist_penalty2 = [0], 81 | alpha_basis_bs = [500], 82 | meshsize = [8], 83 | iterations_bs = [400], 84 | attention_params_bs = [[tmp1_f1_idx,3,1,0]], 85 | mesh_weight = [None]) 86 | if params_dist.alpha_basis == []: 87 | params_dist.alpha_basis = torch.Tensor([1/3000,1/3000,1/100,5,5]).reshape(5,1).to(params_dist.device) 88 | round_t = 0 89 | plt.rcParams.update({'pdf.fonttype':42}) 90 | plt.rcParams['axes.grid'] = False 91 | 92 | ### Generate correlation matrix of the graph embedding 93 | if corr_q_r is None: 94 | if if_embed_sub: 95 | corr_q_r = corr_dist(embed_dict[query_sample].cpu()[sub_node_idxs[query_sample]], embed_dict[ref_sample].cpu()[sub_node_idxs[ref_sample]]) 96 | else: 97 | corr_q_r = corr_dist(embed_dict[query_sample].cpu(), embed_dict[ref_sample].cpu()) 98 | else: 99 | corr_q_r = corr_q_r 100 | 101 | # Plot initial coordinates 102 | kmeans_plot_multiple(embed_dict,graph_list,coords_raw,prefix_t,output_path,k=15,dot_size = 10) if mid_visual 
else None 103 | corr_heat(coords_raw[query_sample][sub_node_idxs[query_sample]],coords_raw[ref_sample][sub_node_idxs[ref_sample]],corr_q_r,output_path,filename=prefix_t+'_corr') if mid_visual else None 104 | plot_mid(coords_raw[query_sample],coords_raw[ref_sample],output_path,f'{prefix_t}_raw') 105 | 106 | ### Initialize the coordinates and tensor 107 | corr_q_r = torch.Tensor(corr_q_r).to(params_dist.device) 108 | params_dist.mean_q = coords_raw[query_sample].mean(0) 109 | params_dist.mean_r = coords_raw[ref_sample].mean(0) 110 | coords_query = torch.Tensor(coords_minus_mean(coords_raw[query_sample])).to(params_dist.device) 111 | coords_ref = torch.Tensor(coords_minus_mean(coords_raw[ref_sample])).to(params_dist.device) 112 | 113 | ### Pre-location 114 | theta_r1_t = prelocate(coords_query,coords_ref,max_minus_value_t(corr_q_r),params_dist.bleeding,output_path,d_list=params_dist.d_list,prefix = prefix_t,index_list=[sub_node_idxs[k_t] for k_t in graph_list],translation_params = params_dist.translation_params,mirror_t=params_dist.mirror_t) 115 | params_dist.theta_r1 = theta_r1_t 116 | coords_query_r1 = affine_trans_t(params_dist.theta_r1,coords_query) 117 | plot_mid(coords_query_r1.cpu(),coords_ref.cpu(),output_path,prefix_t + '_prelocation') if mid_visual else None ### consistent scale with ref coords 118 | 119 | ### Affine 120 | output_list = Affine_GD(coords_query_r1, 121 | coords_ref, 122 | max_minus_value_t(corr_q_r), 123 | output_path, 124 | params_dist.bleeding, 125 | params_dist.dist_penalty1, 126 | alpha_basis = params_dist.alpha_basis, 127 | iterations = params_dist.iterations, 128 | prefix=prefix_t, 129 | attention_params = params_dist.attention_params, 130 | coords_log = True, 131 | index_list=[sub_node_idxs[k_t] for k_t in graph_list], 132 | mid_visual = mid_visual, 133 | early_stop_thres = early_stop_thres, 134 | ifrigid=params_dist.ifrigid) 135 | 136 | similarity_score,it_J,it_theta,coords_log = output_list 137 | params_dist.theta_r2 = it_theta[-1] 
138 | result_log['affine_J'] = similarity_score 139 | result_log['affine_it_theta'] = it_theta 140 | result_log['affine_coords_log'] = coords_log 141 | result_log['coords_ref'] = coords_ref 142 | 143 | # Affine results 144 | affine_reg_params([i.cpu().numpy() for i in it_theta],similarity_score,params_dist.iterations,output_path,prefix=prefix_t)# if mid_visual else None 145 | if if_embed_sub: 146 | embed_stack_t = np.row_stack((embed_dict[query_sample].cpu().detach().numpy()[sub_node_idxs[query_sample]],embed_dict[ref_sample].cpu().detach().numpy()[sub_node_idxs[ref_sample]])) 147 | else: 148 | embed_stack_t = np.row_stack((embed_dict[query_sample].cpu().detach().numpy(),embed_dict[ref_sample].cpu().detach().numpy())) 149 | coords_query_r2 = affine_trans_t(params_dist.theta_r2,coords_query_r1) 150 | register_result(coords_query_r2.cpu().detach().numpy(), 151 | coords_ref.cpu().detach().numpy(), 152 | max_minus_value_t(corr_q_r).cpu(), 153 | params_dist.bleeding, 154 | embed_stack_t, 155 | output_path, 156 | k=20, 157 | prefix=prefix_t, 158 | scale_t=1, 159 | index_list=[sub_node_idxs[k_t] for k_t in graph_list])# if mid_visual else None 160 | 161 | if params_dist.iterations_bs[round_t] != 0: 162 | ### B-Spline free-form deformation 163 | padding_rate = params_dist.PaddingRate_bs # by default, 0 164 | coords_query_r2_min = coords_query_r2.min(0)[0] # The x and y min of the query coords 165 | coords_query_r2_tmp = coords_minus_min_t(coords_query_r2) # min of the x and y is 0 166 | max_xy_tmp = coords_query_r2_tmp.max(0)[0] # max_xy withouth padding 167 | adj_min_qr2 = coords_query_r2_min - max_xy_tmp * padding_rate # adjust the min_qr2 168 | setattr(params_dist,'img_size_bs',[(max_xy_tmp * (1+padding_rate * 2)).cpu()]) # max_xy 169 | params_dist.min_qr2 = [adj_min_qr2] 170 | t1 = BSpline_GD(coords_query_r2 - params_dist.min_qr2[round_t], 171 | coords_ref - params_dist.min_qr2[round_t], 172 | max_minus_value_t(corr_q_r), 173 | params_dist.iterations_bs[round_t], 174 | 
output_path, 175 | params_dist.bleeding, 176 | params_dist.dist_penalty2[round_t], 177 | params_dist.alpha_basis_bs[round_t], 178 | params_dist.diff_step, 179 | params_dist.meshsize[round_t], 180 | prefix_t + '_' + str(round_t), 181 | params_dist.mesh_weight[round_t], 182 | params_dist.attention_params_bs[round_t], 183 | coords_log = True, 184 | index_list=[sub_node_idxs[k_t] for k_t in graph_list], 185 | mid_visual = mid_visual, 186 | max_xy = params_dist.img_size_bs[round_t], 187 | renew_mesh_trans = renew_mesh_trans) 188 | 189 | # B-Spline FFD results 190 | register_result(t1[0].cpu().numpy(),(coords_ref - params_dist.min_qr2[round_t]).cpu().numpy(),max_minus_value_t(corr_q_r).cpu(),params_dist.bleeding,embed_stack_t,output_path,k=20,prefix=prefix_t+ '_' + str(round_t) +'_BSpine_' + str(params_dist.iterations_bs[round_t]),index_list=[sub_node_idxs[k_t] for k_t in graph_list])# if mid_visual else None 191 | # register_result(t1[0].cpu().numpy(),(coords_ref - coords_query_r2.min(0)[0]).cpu().numpy(),max_minus_value_t(corr_q_r).cpu(),params_dist.bleeding,embed_stack_t,output_path,k=20,prefix=prefix_t+ '_' + str(round_t) +'_BSpine_' + str(params_dist.iterations_bs[round_t]),index_list=[sub_node_idxs[k_t] for k_t in graph_list])# if mid_visual else None 192 | result_log['BS_coords_log1'] = t1[4] 193 | result_log['BS_J1'] = t1[3] 194 | if renew_mesh_trans: 195 | setattr(params_dist,'mesh_trans_list',[t1[1]]) 196 | else: 197 | setattr(params_dist,'mesh_trans_list',[[t1[1][-1]]]) 198 | 199 | ### Save results 200 | torch.save(params_dist,os.path.join(output_path,f'{prefix_t}_params.data')) 201 | torch.save(result_log,os.path.join(output_path,f'{prefix_t}_result_log.data')) 202 | coords_final = dict() 203 | _, coords_q_final = reg_total_t(coords_raw[query_sample],coords_raw[ref_sample],params_dist) 204 | coords_final[query_sample] = coords_q_final.cpu() / result_log['ref_rescale_factor'] ### rescale back to the original scale 205 | coords_final[ref_sample] = 
coords_raw[ref_sample] / result_log['ref_rescale_factor'] ### rescale back to the original scale 206 | plot_mid(coords_final[query_sample],coords_final[ref_sample],output_path,f'{prefix_t}_align') 207 | torch.save(coords_final,os.path.join(output_path,f'{prefix_t}_coords_final.data')) 208 | return coords_final 209 | 210 | def CAST_PROJECT( 211 | sdata_inte, # the integrated dataset 212 | source_sample, # the source sample name 213 | target_sample, # the target sample name 214 | coords_source, # the coordinates of the source sample 215 | coords_target, # the coordinates of the target sample 216 | scaled_layer = 'log2_norm1e4_scaled', # the scaled layer name in `adata.layers`, which is used to be integrated 217 | raw_layer = 'raw', # the raw layer name in `adata.layers`, which is used to be projected into target sample 218 | batch_key = 'protocol', # the column name of the samples in `obs` 219 | use_highly_variable_t = True, # if use highly variable genes 220 | ifplot = True, # if plot the result 221 | n_components = 50, # the `n_components` parameter in `sc.pp.pca` 222 | umap_n_neighbors = 50, # the `n_neighbors` parameter in `sc.pp.neighbors` 223 | umap_n_pcs = 30, # the `n_pcs` parameter in `sc.pp.neighbors` 224 | min_dist = 0.01, # the `min_dist` parameter in `sc.tl.umap` 225 | spread_t = 5, # the `spread` parameter in `sc.tl.umap` 226 | k2 = 1, # select k2 cells to do the projection for each cell 227 | source_sample_ctype_col = 'level_2', # the column name of the cell type in `obs` 228 | output_path = '', # the output path 229 | umap_feature = 'X_umap', # the feature used for umap 230 | pc_feature = 'X_pca_harmony', # the feature used for the projection 231 | integration_strategy = 'Harmony', # 'Harmony' or None (use existing integrated features) 232 | ave_dist_fold = 3, # the `ave_dist_fold` is used to set the distance threshold (average_distance * `ave_dist_fold`) 233 | save_result = True, # if save the results 234 | ifcombat = True, # if use combat when using 
the Harmony integration 235 | alignment_shift_adjustment = 50, # to adjust the small alignment shift for the distance threshold) 236 | color_dict = None, # the color dict for the cell type 237 | adjust_shift = False, # if adjust the alignment shift by group 238 | metric_t = 'cosine', 239 | working_memory_t = 1000 # the working memory for the pairwise distance calculation 240 | ): 241 | 242 | #### integration 243 | if integration_strategy == 'Harmony': 244 | sdata_inte = Harmony_integration( 245 | sdata_inte = sdata_inte, 246 | scaled_layer = scaled_layer, 247 | use_highly_variable_t = use_highly_variable_t, 248 | batch_key = batch_key, 249 | umap_n_neighbors = umap_n_neighbors, 250 | umap_n_pcs = umap_n_pcs, 251 | min_dist = min_dist, 252 | spread_t = spread_t, 253 | source_sample_ctype_col = source_sample_ctype_col, 254 | output_path = output_path, 255 | n_components = n_components, 256 | ifplot = True, 257 | ifcombat = ifcombat) 258 | elif integration_strategy is None: 259 | print(f'Using the pre-integrated data {pc_feature} and the UMAP {umap_feature}') 260 | 261 | #### Projection 262 | idx_source = sdata_inte.obs[batch_key] == source_sample 263 | idx_target = sdata_inte.obs[batch_key] == target_sample 264 | source_cell_pc_feature = sdata_inte[idx_source, :].obsm[pc_feature] 265 | target_cell_pc_feature = sdata_inte[idx_target, :].obsm[pc_feature] 266 | sdata_ref,output_list = space_project( 267 | sdata_inte = sdata_inte, 268 | idx_source = idx_source, 269 | idx_target = idx_target, 270 | raw_layer = raw_layer, 271 | source_sample = source_sample, 272 | target_sample = target_sample, 273 | coords_source = coords_source, 274 | coords_target = coords_target, 275 | output_path = output_path, 276 | source_sample_ctype_col = source_sample_ctype_col, 277 | target_cell_pc_feature = target_cell_pc_feature, 278 | source_cell_pc_feature = source_cell_pc_feature, 279 | k2 = k2, 280 | ifplot = ifplot, 281 | umap_feature = umap_feature, 282 | ave_dist_fold = ave_dist_fold, 
283 | alignment_shift_adjustment = alignment_shift_adjustment, 284 | color_dict = color_dict, 285 | metric_t = metric_t, 286 | adjust_shift = adjust_shift, 287 | working_memory_t = working_memory_t 288 | ) 289 | 290 | ### Save the results 291 | if save_result == True: 292 | sdata_ref.write_h5ad(f'{output_path}/sdata_ref.h5ad') 293 | torch.save(output_list,f'{output_path}/projection_data.pt') 294 | return sdata_ref,output_list -------------------------------------------------------------------------------- /CAST/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanglab-broad/CAST/3ebcf4fbe1e43dbbb9ae92ecd07e93562e156160/CAST/models/__init__.py -------------------------------------------------------------------------------- /CAST/models/aug.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | import dgl 4 | # __func is original, func is GPU optimized 5 | def __random_aug(graph, x, feat_drop_rate, edge_mask_rate): 6 | n_node = graph.number_of_nodes() 7 | 8 | edge_mask = mask_edge(graph, edge_mask_rate) 9 | feat = drop_feature(x, feat_drop_rate) 10 | 11 | ng = dgl.graph([]) 12 | ng.add_nodes(n_node) 13 | src = graph.edges()[0] 14 | dst = graph.edges()[1] 15 | 16 | nsrc = src[edge_mask] 17 | ndst = dst[edge_mask] 18 | ng.add_edges(nsrc, ndst) 19 | 20 | return ng, feat 21 | 22 | def __drop_feature(x, drop_prob): 23 | drop_mask = th.empty( 24 | (x.size(1),), 25 | dtype=th.float32, 26 | device=x.device).uniform_(0, 1) < drop_prob 27 | x = x.clone() 28 | x[:, drop_mask] = 0 29 | 30 | return x 31 | 32 | def __mask_edge(graph, mask_prob): 33 | E = graph.number_of_edges() 34 | 35 | mask_rates = th.FloatTensor(np.ones(E) * mask_prob) 36 | masks = th.bernoulli(1 - mask_rates) 37 | mask_idx = masks.nonzero().squeeze(1) 38 | return mask_idx 39 | 40 | def random_aug(graph, x, feat_drop_rate, edge_mask_rate): 41 | 
n_node = graph.number_of_nodes() 42 | 43 | edge_mask = mask_edge(graph, edge_mask_rate) 44 | feat = x.clone() 45 | feat = drop_feature(feat, feat_drop_rate) 46 | 47 | ng = dgl.graph([], device=graph.device) 48 | ng.add_nodes(n_node) 49 | src = graph.edges()[0] 50 | dst = graph.edges()[1] 51 | 52 | nsrc = src[edge_mask] 53 | ndst = dst[edge_mask] 54 | ng.add_edges(nsrc, ndst) 55 | 56 | return ng, feat 57 | 58 | def drop_feature(x, drop_prob): 59 | drop_mask = th.empty( 60 | (x.size(1),), 61 | dtype=th.float32, 62 | device=x.device).uniform_(0, 1) < drop_prob 63 | # x = x.clone() 64 | x[:, drop_mask] = 0 65 | 66 | return x 67 | 68 | def mask_edge(graph, mask_prob): 69 | E = graph.number_of_edges() 70 | mask_rates = th.ones(E, device=graph.device) * mask_prob 71 | masks = th.bernoulli(1 - mask_rates) 72 | mask_idx = masks.nonzero().squeeze(1) 73 | return mask_idx -------------------------------------------------------------------------------- /CAST/models/model_GCNII.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import torch 3 | import torch.nn as nn 4 | # import torch.nn.functional as F 5 | from dataclasses import dataclass, field 6 | from dgl.nn import GCN2Conv, GraphConv 7 | 8 | @dataclass 9 | class Args: 10 | dataname : str 11 | gpu : int = 0 12 | epochs : int = 1000 13 | lr1 : float = 1e-3 14 | wd1 : float = 0.0 15 | lambd : float = 1e-3 16 | n_layers : int = 9 17 | der : float = 0.2 18 | dfr : float = 0.2 19 | device : str = field(init=False) 20 | encoder_dim : int = 256 21 | use_encoder : bool = False 22 | 23 | def __post_init__(self): 24 | if self.gpu != -1 and torch.cuda.is_available(): 25 | self.device = 'cuda:{}'.format(self.gpu) 26 | else: 27 | self.device = 'cpu' 28 | 29 | 30 | 31 | # fix the div zero standard deviation bug, Shuchen Luo (20220217) 32 | def standardize(x, eps = 1e-12): 33 | return (x - x.mean(0)) / x.std(0).clamp(eps) 34 | 35 | class Encoder(nn.Module): 36 | 
def __init__(self, in_dim : int, encoder_dim : int): 37 | super().__init__() 38 | self.layer = nn.Linear(in_dim, encoder_dim, bias=True) 39 | self.relu = nn.ReLU() 40 | def forward(self, x): 41 | return self.relu(self.layer(x)) 42 | 43 | 44 | # GCN2Conv(in_feats, layer, alpha=0.1, lambda_=1, project_initial_features=True, allow_zero_in_degree=False, bias=True, activation=None) 45 | class GCNII(nn.Module): 46 | def __init__(self, in_dim : int, encoder_dim: int, n_layers : int, alpha=None, lambda_=None, use_encoder=False): 47 | super().__init__() 48 | 49 | self.n_layers = n_layers 50 | self.use_encoder = use_encoder 51 | if alpha is None: 52 | self.alpha = [0.1] * self.n_layers 53 | else: 54 | self.alpha = alpha 55 | if lambda_ is None: 56 | self.lambda_ = [1.] * self.n_layers 57 | else: 58 | self.lambda_ = lambda_ 59 | if self.use_encoder: 60 | self.encoder = Encoder(in_dim, encoder_dim) 61 | self.hid_dim = encoder_dim 62 | else: self.hid_dim = in_dim 63 | self.relu = nn.ReLU() 64 | self.convs = nn.ModuleList() 65 | 66 | for i in range(n_layers): 67 | self.convs.append(GCN2Conv(self.hid_dim, i + 1, alpha=self.alpha[i], lambda_=self.lambda_[i], activation=None)) 68 | 69 | def forward(self, graph, x): 70 | if self.use_encoder: 71 | x = self.encoder(x) 72 | # print('GCNII forward: after encoder', torch.any(torch.isnan(x))) 73 | feat0 = x 74 | for i in range(self.n_layers): 75 | x = self.relu(self.convs[i](graph, x, feat0)) 76 | # print('GCNII layer', i + 1, 'is_nan', torch.any(torch.isnan(x))) 77 | return x 78 | 79 | 80 | 81 | class GCN(nn.Module): 82 | def __init__(self, in_dim : int, encoder_dim: int, n_layers : int, use_encoder=False): 83 | super().__init__() 84 | 85 | self.n_layers = n_layers 86 | self.use_encoder = use_encoder 87 | 88 | if self.use_encoder: 89 | self.encoder = Encoder(in_dim, encoder_dim) 90 | self.hid_dim = encoder_dim 91 | else: self.hid_dim = in_dim 92 | self.relu = nn.ReLU() 93 | self.convs = nn.ModuleList() 94 | 95 | for i in range(n_layers): 
96 | self.convs.append(GraphConv(self.hid_dim, self.hid_dim, activation=None)) 97 | 98 | def forward(self, graph, x): 99 | if self.use_encoder: 100 | x = self.encoder(x) 101 | # print('GCN forward: after encoder', torch.any(torch.isnan(x))) 102 | for i in range(self.n_layers): 103 | x = self.relu(self.convs[i](graph, x)) 104 | # print('GCN layer', i + 1, 'is_nan', torch.any(torch.isnan(x))) 105 | return x 106 | 107 | 108 | 109 | class CCA_SSG(nn.Module): 110 | def __init__(self, in_dim, encoder_dim, n_layers, backbone='GCNII', alpha=None, lambda_=None, use_encoder=False): 111 | super().__init__() 112 | if backbone == 'GCNII': 113 | self.backbone = GCNII(in_dim, encoder_dim, n_layers, alpha, lambda_, use_encoder) 114 | elif backbone == 'GCN': 115 | self.backbone = GCN(in_dim, encoder_dim, n_layers, use_encoder) 116 | 117 | def get_embedding(self, graph, feat): 118 | out = self.backbone(graph, feat) 119 | return out.detach() 120 | 121 | def forward(self, graph1, feat1, graph2, feat2): 122 | h1 = self.backbone(graph1, feat1) 123 | h2 = self.backbone(graph2, feat2) 124 | # print('CCASSG forward: h1 is', torch.any(torch.isnan(h1))) 125 | # print('CCASSG forward: h2 is', torch.any(torch.isnan(h2))) 126 | z1 = standardize(h1) 127 | z2 = standardize(h2) 128 | # print('h1.std', h1.std(0)) 129 | # print('h1-h1.mean(0)', h1 - h1.mean(0)) 130 | # print('CCASSG forward: z1 is', torch.any(torch.isnan(z1))) 131 | # print('CCASSG forward: z2 is', torch.any(torch.isnan(z2))) 132 | 133 | return z1, z2 134 | 135 | -------------------------------------------------------------------------------- /CAST/utils.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import numpy as np 3 | import pandas as pd 4 | import scipy, random 5 | import scanpy as sc 6 | from tqdm import tqdm 7 | import matplotlib.pyplot as plt 8 | from sklearn.metrics import pairwise_distances_chunked, pairwise_distances 9 | from .visualize import link_plot 
10 | 11 | def coords2adjacentmat(coords,output_mode = 'adjacent',strategy_t = 'convex'): 12 | if strategy_t == 'convex': ### slow but may generate more reasonable delaunay graph 13 | from libpysal.cg import voronoi_frames 14 | from libpysal import weights 15 | cells, _ = voronoi_frames(coords, clip="convex hull") 16 | delaunay_graph = weights.Rook.from_dataframe(cells).to_networkx() 17 | elif strategy_t == 'delaunay': ### fast but may generate long distance edges 18 | from scipy.spatial import Delaunay 19 | from collections import defaultdict 20 | tri = Delaunay(coords) 21 | delaunay_graph = nx.Graph() 22 | coords_dict = defaultdict(list) 23 | for i, coord in enumerate(coords): 24 | coords_dict[tuple(coord)].append(i) 25 | for simplex in tri.simplices: 26 | for i in range(3): 27 | for node1 in coords_dict[tuple(coords[simplex[i]])]: 28 | for node2 in coords_dict[tuple(coords[simplex[(i+1)%3]])]: 29 | if not delaunay_graph.has_edge(node1, node2): 30 | delaunay_graph.add_edge(node1, node2) 31 | if output_mode == 'adjacent': 32 | return nx.to_scipy_sparse_array(delaunay_graph).todense() 33 | elif output_mode == 'raw': 34 | return delaunay_graph 35 | elif output_mode == 'adjacent_sparse': 36 | return nx.to_scipy_sparse_array(delaunay_graph) 37 | 38 | def hv_cutoff(max_col,threshold=2000): 39 | for thres_t in range(0,int(np.max(max_col))): 40 | if np.sum(max_col > thres_t) < threshold: 41 | return thres_t -1 42 | 43 | def detect_highly_variable_genes(sdata,batch_key = 'batch',n_top_genes = 4000,count_layer = 'count'): 44 | samples = np.unique(sdata.obs[batch_key]) 45 | thres_list = [] 46 | max_count_list = [] 47 | bool_list = [] 48 | for list_t, sample_t in enumerate(samples): 49 | idx_t = sdata.obs[batch_key] == sample_t 50 | if count_layer == '.X': 51 | max_t = sdata[idx_t,:].X.max(0).toarray() if scipy.sparse.issparse(sdata.X) else sdata[idx_t,:].X.max(0) 52 | else: 53 | max_t = sdata[idx_t,:].layers[count_layer].max(0).toarray() if 
scipy.sparse.issparse(sdata.layers[count_layer]) else sdata[idx_t,:].layers[count_layer].max(0) 54 | max_count_list.append(max_t) 55 | thres_list.append(hv_cutoff(max_count_list[list_t],threshold=n_top_genes)) 56 | bool_list.append(max_count_list[list_t] > thres_list[list_t]) 57 | stack = np.stack(bool_list) 58 | return np.all(stack, axis=0)[0] 59 | 60 | def extract_coords_exp(sdata, batch_key = 'batch', cols = 'spatial', count_layer = 'count', data_format = 'norm1e4',ifcombat = False, if_inte = False): 61 | coords_raw = {} 62 | exps = {} 63 | samples = np.unique(sdata.obs[batch_key]) 64 | if count_layer == '.X': 65 | sdata.layers['raw'] = sdata.X.copy() 66 | sdata = preprocess_fast(sdata, mode = 'customized') 67 | if if_inte: 68 | scaled_layer = 'log2_norm1e4_scaled' 69 | pc_feature = 'X_pca_harmony' 70 | sdata = Harmony_integration(sdata, 71 | scaled_layer = scaled_layer, 72 | use_highly_variable_t = True, 73 | batch_key = batch_key, 74 | umap_n_neighbors = 50, 75 | umap_n_pcs = 30, 76 | min_dist = 0.01, 77 | spread_t = 5, 78 | source_sample_ctype_col = None, 79 | output_path = None, 80 | n_components = 50, 81 | ifplot = False, 82 | ifcombat = ifcombat) 83 | for sample_t in samples: 84 | idx_t = sdata.obs[batch_key] == sample_t 85 | coords_raw[sample_t] = sdata.obsm['spatial'][idx_t] if type(cols) is not list else np.array(sdata.obs[cols][idx_t]) 86 | exps[sample_t] = sdata[idx_t].obsm[pc_feature].copy() 87 | else: 88 | sdata.X = sdata.layers[data_format].copy() 89 | if ifcombat == True: 90 | sc.pp.combat(sdata, key=batch_key) 91 | for sample_t in samples: 92 | idx_t = sdata.obs[batch_key] == sample_t 93 | coords_raw[sample_t] = sdata.obsm['spatial'][idx_t] if type(cols) is not list else np.array(sdata.obs[cols][idx_t]) 94 | exps[sample_t] = sdata[idx_t].X.copy() 95 | if scipy.sparse.issparse(exps[sample_t]): 96 | exps[sample_t] = exps[sample_t].toarray() 97 | return coords_raw,exps 98 | 99 | def Harmony_integration( 100 | sdata_inte, 101 | scaled_layer, 102 | 
use_highly_variable_t, 103 | batch_key, 104 | umap_n_neighbors, 105 | umap_n_pcs, 106 | min_dist, 107 | spread_t, 108 | source_sample_ctype_col, 109 | output_path, 110 | n_components = 50, 111 | ifplot = True, 112 | ifcombat = False): 113 | #### integration based on the Harmony 114 | sdata_inte.X = sdata_inte.layers[scaled_layer].copy() 115 | if ifcombat == True: 116 | sc.pp.combat(sdata_inte, key=batch_key) 117 | print(f'Running PCA based on the layer {scaled_layer}:') 118 | sc.tl.pca(sdata_inte, use_highly_variable=use_highly_variable_t, svd_solver = 'full', n_comps= n_components) 119 | print(f'Running Harmony integration:') 120 | sc.external.pp.harmony_integrate(sdata_inte, batch_key) 121 | print(f'Compute a neighborhood graph based on the {umap_n_neighbors} `n_neighbors`, {umap_n_pcs} `n_pcs`:') 122 | sc.pp.neighbors(sdata_inte, n_neighbors=umap_n_neighbors, n_pcs=umap_n_pcs, use_rep='X_pca_harmony') 123 | print(f'Generate the UMAP based on the {min_dist} `min_dist`, {spread_t} `spread`:') 124 | sc.tl.umap(sdata_inte,min_dist=min_dist, spread = spread_t) 125 | sdata_inte.obsm['har_X_umap'] = sdata_inte.obsm['X_umap'].copy() 126 | if ifplot == True: 127 | plt.rcParams.update({'pdf.fonttype':42}) 128 | sc.settings.figdir = output_path 129 | sc.set_figure_params(figsize=(10, 10),facecolor='white',vector_friendly=True, dpi_save=300,fontsize = 25) 130 | sc.pl.umap(sdata_inte,color=[batch_key],size=10,save=f'_har_{umap_n_pcs}pcs_batch.pdf') 131 | sc.pl.umap(sdata_inte,color=[source_sample_ctype_col],size=10,save=f'_har_{umap_n_pcs}pcs_ctype.pdf') if source_sample_ctype_col is not None else None 132 | return sdata_inte 133 | 134 | def random_sample(coords_t, nodenum, seed_t = 2): 135 | random.seed(seed_t) 136 | sub_node_idx = np.sort(random.sample(range(coords_t.shape[0]),nodenum)) 137 | return sub_node_idx 138 | 139 | def sub_node_sum(coords_t,exp_t,nodenum=1000,vis = True,seed_t = 2): 140 | from scipy.sparse import csr_matrix as csr 141 | # random.seed(seed_t) 142 | 
if nodenum > coords_t.shape[0]: 143 | print('The number of nodes is larger than the total number of nodes. Return the original data.') 144 | sub_node_idx = np.arange(coords_t.shape[0]) 145 | if scipy.sparse.issparse(exp_t): 146 | return exp_t,sub_node_idx 147 | else: 148 | return csr(exp_t),sub_node_idx 149 | # sub_node_idx = np.sort(random.sample(range(coords_t.shape[0]),nodenum)) 150 | sub_node_idx = random_sample(coords_t, nodenum, seed_t = seed_t) 151 | coords_t_sub = coords_t[sub_node_idx,:].copy() 152 | close_idx = nearest_neighbors_idx(coords_t_sub,coords_t) 153 | A = np.zeros([coords_t_sub.shape[0],coords_t.shape[0]]) 154 | for ind,i in enumerate(close_idx.tolist()): 155 | A[i,ind] = 1 156 | csr_A = csr(A) 157 | if scipy.sparse.issparse(exp_t): 158 | exp_t_sub = csr_A.dot(exp_t) 159 | else: 160 | exp_t_sub = csr_A.dot(csr(exp_t)) 161 | if(vis == True): 162 | link_plot(close_idx,coords_t,coords_t_sub,k = 1) 163 | return exp_t_sub,sub_node_idx 164 | 165 | def nearest_neighbors_idx(coord1,coord2,mode_t = 'knn'): ### coord1 is the reference, coord2 is the target 166 | if mode_t == 'knn': 167 | from sklearn.neighbors import KNeighborsClassifier 168 | knn_classifier = KNeighborsClassifier(n_neighbors=1,metric='euclidean') 169 | knn_classifier.fit(coord1, np.zeros(coord1.shape[0])) # Use dummy labels, since we only care about distances 170 | # Find nearest neighbors 171 | _, close_idx = knn_classifier.kneighbors(coord2) 172 | return close_idx 173 | else: 174 | result = [] 175 | dists = pairwise_distances_chunked(coord2,coord1,working_memory = 100, metric='euclidean', n_jobs=-1) 176 | for chunk in tqdm(dists): # for each chunk (minibatch) 177 | knn_ind = np.argpartition(chunk, 0, axis=-1)[:, 0] # introsort to get indices of top k neighbors according to the distance matrix [n_query, k] 178 | result.append(knn_ind) 179 | close_idx = np.concatenate(result) 180 | return np.expand_dims(close_idx, axis=1) 181 | 182 | def non_zero_center_scale(sdata_t_X): 183 | 
std_nocenter = np.sqrt(np.square(sdata_t_X).sum(0)/(sdata_t_X.shape[0]-1)) 184 | return(sdata_t_X/std_nocenter) 185 | 186 | def sub_data_extract(sample_list,coords_raw, exps, nodenum_t = 20000, if_non_zero_center_scale = True): 187 | coords_sub = dict() 188 | exp_sub = dict() 189 | sub_node_idxs = dict() 190 | for sample_t in sample_list: 191 | exp_t,sub_node_idxs[sample_t] = sub_node_sum(coords_raw[sample_t],exps[sample_t],nodenum=nodenum_t,vis = False) 192 | exp_sub[sample_t] = non_zero_center_scale(exp_t.toarray()) if if_non_zero_center_scale else exp_t.toarray() 193 | coords_sub[sample_t] = coords_raw[sample_t][sub_node_idxs[sample_t],:] 194 | return coords_sub,exp_sub,sub_node_idxs 195 | 196 | def preprocess_fast(sdata1, mode = 'customized',target_sum=1e4,base = 2,zero_center = True,regressout = False): 197 | print('Preprocessing...') 198 | from scipy.sparse import csr_matrix as csr 199 | if 'raw' in sdata1.layers: 200 | if type(sdata1.layers['raw']) != scipy.sparse._csr.csr_matrix: 201 | sdata1.layers['raw'] = csr(sdata1.layers['raw'].copy()) 202 | sdata1.X = sdata1.layers['raw'].copy() 203 | else: 204 | if type(sdata1.X) != scipy.sparse._csr.csr_matrix: 205 | sdata1.X = csr(sdata1.X.copy()) 206 | sdata1.layers['raw'] = sdata1.X.copy() 207 | if mode == 'default': 208 | sc.pp.normalize_total(sdata1) 209 | sdata1.layers['norm'] = csr(sdata1.X.copy()) 210 | sc.pp.log1p(sdata1) 211 | sdata1.layers['log1p_norm'] = csr(sdata1.X.copy()) 212 | sc.pp.scale(sdata1,zero_center = zero_center) 213 | if scipy.sparse.issparse(sdata1.X): #### automatically change to non csr matrix (zero_center == True, the .X would be sparce) 214 | sdata1.X = sdata1.X.toarray().copy() 215 | sdata1.layers['log1p_norm_scaled'] = sdata1.X.copy() 216 | if regressout: 217 | sdata1.obs['total_counts'] = sdata1.layers['raw'].toarray().sum(axis=1) 218 | sc.pp.regress_out(sdata1, ['total_counts']) 219 | sdata1.layers['log1p_norm_scaled'] = sdata1.X.copy() 220 | return sdata1 #### sdata1.X is 
sdata1.layers['log1p_norm_scaled'] 221 | elif mode == 'customized': 222 | if target_sum == 1e4: 223 | target_sum_str = '1e4' 224 | else: 225 | target_sum_str = str(target_sum) 226 | sc.pp.normalize_total(sdata1,target_sum=target_sum) 227 | sdata1.layers[f'norm{target_sum_str}'] = csr(sdata1.X.copy()) 228 | sc.pp.log1p(sdata1,base = base) 229 | sdata1.layers[f'log{str(base)}_norm{target_sum_str}'] = csr(sdata1.X.copy()) 230 | sc.pp.scale(sdata1,zero_center = zero_center) 231 | if scipy.sparse.issparse(sdata1.X): #### automatically change to non csr matrix (zero_center == True, the .X would be sparce) 232 | sdata1.X = sdata1.X.toarray().copy() 233 | sdata1.layers[f'log{str(base)}_norm{target_sum_str}_scaled'] = sdata1.X.copy() 234 | if regressout: 235 | sdata1.obs['total_counts'] = sdata1.layers['raw'].toarray().sum(axis=1) 236 | sc.pp.regress_out(sdata1, ['total_counts']) 237 | sdata1.layers[f'log{str(base)}_norm{target_sum_str}_scaled'] = sdata1.X.copy() 238 | return sdata1 #### sdata1.X is sdata1.layers[f'log{str(base)}_norm{target_sum_str}_scaled'] 239 | else: 240 | print('Please set the `mode` as one of the {"default", "customized"}.') 241 | 242 | def cell_select(coords_t, s=0.5, c=None, output_path_t=None): 243 | ''' 244 | Select cells by drawing a polygon on the plot. 245 | Click the "Finish Polygon" button to finish drawing the polygon. 246 | Click the "Clear Polygon" button to clear the polygon. 
247 | ''' 248 | import matplotlib.pyplot as plt 249 | from matplotlib.patches import Polygon 250 | from shapely.geometry import Point, Polygon as ShapelyPolygon 251 | import ipywidgets as widgets 252 | import numpy as np 253 | 254 | indices = np.arange(coords_t.shape[0]).reshape(-1, 1) 255 | coords = np.hstack((coords_t, indices)) 256 | 257 | global poly_coords, polygon_patch, selected_cell_ids 258 | poly_coords = [] 259 | polygon_patch = None 260 | selected_cell_ids = [] 261 | 262 | def process_selected_ids(selected_ids): 263 | # Inner function to process the selected cell IDs 264 | print("Selected Cell IDs:", selected_ids) 265 | # Additional processing can be done here 266 | 267 | def on_click(event): 268 | global poly_coords, polygon_patch 269 | if event.inaxes is not ax or event.button != 1: 270 | return 271 | poly_coords.append((event.xdata, event.ydata)) 272 | if polygon_patch: 273 | polygon_patch.remove() 274 | polygon_patch = Polygon(poly_coords, closed=False, color='blue', alpha=0.3) 275 | ax.add_patch(polygon_patch) 276 | fig.canvas.draw() 277 | 278 | def finish_polygon(b): 279 | global poly_coords, polygon_patch, selected_cell_ids 280 | if polygon_patch: 281 | polygon_patch.set_closed(True) 282 | fig.canvas.draw() 283 | shapely_poly = ShapelyPolygon(poly_coords) 284 | selected_cell_ids = [int(id) for x, y, id in coords if shapely_poly.contains(Point(x, y))] 285 | process_selected_ids(selected_cell_ids) 286 | if output_path_t is not None: 287 | fig.savefig(output_path_t) 288 | poly_coords.clear() 289 | 290 | def clear_polygon(b): 291 | global poly_coords, polygon_patch 292 | if polygon_patch: 293 | polygon_patch.remove() 294 | polygon_patch = None 295 | fig.canvas.draw() 296 | poly_coords.clear() 297 | 298 | fig, ax = plt.subplots(figsize=[10, 10]) 299 | x, y, ids = zip(*coords) 300 | ax.scatter(x, y, s=s, c=c) 301 | ax.set_aspect('equal', adjustable='box') 302 | 303 | finish_button = widgets.Button(description="Finish Polygon") 304 | 
finish_button.on_click(finish_polygon) 305 | 306 | clear_button = widgets.Button(description="Clear Polygon") 307 | clear_button.on_click(clear_polygon) 308 | 309 | display(widgets.HBox([finish_button, clear_button])) 310 | 311 | fig.canvas.mpl_connect('button_press_event', on_click) 312 | 313 | #### Delta analysis 314 | 315 | 316 | def get_neighborhood_rad(coords_centroids, coords_candidate, radius_px, dist=None): 317 | if dist is None: 318 | dist = pairwise_distances(coords_centroids, coords_candidate, metric='euclidean', n_jobs=-1) 319 | rad_mask = dist < radius_px 320 | return rad_mask 321 | 322 | def delta_cell_cal(coords_tgt,coords_ref,ctype_tgt,ctype_ref,radius_px): 323 | ''' 324 | coords_tgt: coordinates of niche centroids (target cells). 325 | coords_ref: coordinates of reference cells. 326 | ctype_tgt: cell type of niche centroids. 327 | ctype_ref: cell type of reference cells. 328 | radius_px: radius of neighborhood. 329 | 330 | Output: 331 | return: delta_cell_tgt, delta_cell_ref, delta_cell. 332 | 333 | e.g. 334 | coords_tgt = coords_final['injured'] 335 | coords_ref = coords_final['normal'] 336 | ctype_tgt = sdata.obs['Annotation'][right_idx] 337 | ctype_ref = sdata.obs['Annotation'][left_idx] 338 | radius_px = 1000 339 | df_delta_cell_tgt,df_delta_cell_ref,df_delta_cell = delta_cell(coords_tgt,coords_ref,ctype_tgt,ctype_ref,radius_px) 340 | ''' 341 | ##### 1. generate nbhd_mask_tgt and nbhd_mask_ref. 342 | # nbhd_mask_tgt: coords_tgt vs coords_tgt itself. 343 | # nbhd_mask_ref: coords_tgt vs coords_ref. 344 | nbhd_mask_tgt = get_neighborhood_rad(coords_tgt, coords_tgt, radius_px) 345 | nbhd_mask_ref = get_neighborhood_rad(coords_tgt, coords_ref, radius_px) 346 | 347 | ##### 2. generate ctype_one_hot_array 348 | # ctype_one_hot_array: one-hot encoding of cell types. 349 | # To make the order of columns consistent, we stack the two labels together. 
350 | ctype_all = np.hstack([ctype_tgt,ctype_ref]) 351 | idx_ctype_tgt = np.arange(len(ctype_tgt)) 352 | idx_ctype_ref = np.arange(len(ctype_tgt),len(ctype_tgt)+len(ctype_ref)) 353 | ctype_one_hot = pd.get_dummies(ctype_all) 354 | ctype_one_hot_cols = ctype_one_hot.columns 355 | ctype_one_hot_tgt = ctype_one_hot.values[idx_ctype_tgt] 356 | ctype_one_hot_ref = ctype_one_hot.values[idx_ctype_ref] 357 | 358 | ##### 3. generate delta_cell_tgt, delta_cell_ref and delta_cell 359 | # delta_cell_tgt: raw cell type counts of target cells given niche centroids. 360 | # delta_cell_ref: raw cell type counts of reference cells given niche centroids. 361 | # delta_cell: delta_cell_tgt - delta_cell_ref. 362 | d_cell_tgt = nbhd_mask_tgt.astype(int).dot(ctype_one_hot_tgt.astype(int)) 363 | d_cell_ref = nbhd_mask_ref.astype(int).dot(ctype_one_hot_ref.astype(int)) 364 | d_cell = d_cell_tgt - d_cell_ref 365 | 366 | return pd.DataFrame(d_cell_tgt,columns = ctype_one_hot_cols), pd.DataFrame(d_cell_ref,columns = ctype_one_hot_cols), pd.DataFrame(d_cell,columns = ctype_one_hot_cols) 367 | 368 | def delta_exp_cal(coords_tgt,coords_ref,exp_tgt,exp_ref,radius_px,valid_tgt_idx=None,valid_ref_idx=None): 369 | ''' 370 | coords_tgt: coordinates of niche centroids (target cells). 371 | coords_ref: coordinates of reference cells. 372 | exp_tgt: gene expression of target cells. 373 | exp_ref: gene expression of reference cells. 374 | radius_px: radius of neighborhood. 375 | 376 | Output: 377 | return: delta_exp_tgt, delta_exp_ref, delta_exp. 378 | 379 | e.g. 380 | ''' 381 | ##### 0. generate valid_tgt_idx and valid_ref_idx. For the cell type specific analysis, we only consider the cells in the given cell type. 382 | valid_tgt_idx = np.arange(len(coords_tgt)) if valid_tgt_idx is None else valid_tgt_idx 383 | valid_ref_idx = np.arange(len(coords_ref)) if valid_ref_idx is None else valid_ref_idx 384 | 385 | ##### 1. generate nbhd_mask_tgt and nbhd_mask_ref. 
386 | # nbhd_mask_tgt: coords_tgt vs coords_tgt itself. 387 | # nbhd_mask_ref: coords_tgt vs coords_ref. 388 | nbhd_mask_tgt = get_neighborhood_rad(coords_tgt, coords_tgt[valid_tgt_idx], radius_px) 389 | nbhd_mask_ref = get_neighborhood_rad(coords_tgt, coords_ref[valid_ref_idx], radius_px) 390 | 391 | ##### 2. generate delta_cell_tgt, delta_cell_ref and delta_cell 392 | # delta_exp_tgt: Average gene expression of target cells given niche centroids. 393 | # delta_exp_ref: Average gene expression of reference cells given niche centroids. 394 | # delta_exp: delta_exp_tgt - delta_exp_ref. 395 | d_exp_tgt = nbhd_mask_tgt.dot(exp_tgt[valid_tgt_idx]).astype(float) / nbhd_mask_tgt.sum(axis=1)[:,None] 396 | d_exp_ref = nbhd_mask_ref.dot(exp_ref[valid_ref_idx]).astype(float) / nbhd_mask_ref.sum(axis=1)[:,None] 397 | 398 | ### if the nbhd_mask_tgt.sum(axis=1)[:,None] is 0, then the d_exp_tgt, d_exp_ref will be nan. We set it to 0. 399 | d_exp_tgt[np.isnan(d_exp_tgt)] = 0 400 | d_exp_ref[np.isnan(d_exp_ref)] = 0 401 | d_exp = d_exp_tgt - d_exp_ref 402 | 403 | return d_exp_tgt, d_exp_ref, d_exp 404 | 405 | def delta_exp_sigplot(p_values,avg_differences,abs_10logp_cutoff = None, abs_avg_diff_cutoff = None, sig = True): 406 | y_t = np.array(-np.log10(p_values)) 407 | x_t = np.array(avg_differences) 408 | abs_10logp_cutoff = np.quantile(np.abs(y_t),0.95) if abs_10logp_cutoff is None else abs_10logp_cutoff 409 | abs_avg_diff_cutoff = np.quantile(np.abs(x_t),0.95) if abs_avg_diff_cutoff is None else abs_avg_diff_cutoff 410 | idx_sig = (np.abs(y_t) > abs_10logp_cutoff) & (np.abs(x_t) > abs_avg_diff_cutoff) if sig else np.zeros(len(y_t),dtype = bool) 411 | idx_sig_up = (y_t > abs_10logp_cutoff) & (x_t > abs_avg_diff_cutoff) if sig else np.zeros(len(y_t),dtype = bool) 412 | idx_sig_down = (y_t > abs_10logp_cutoff) & (x_t < -abs_avg_diff_cutoff) if sig else np.zeros(len(y_t),dtype = bool) 413 | plt.figure(figsize = (10,10)) 414 | plt.scatter(x_t,y_t, s = 2, c = 'black', rasterized = 
True) 415 | plt.scatter(x_t[idx_sig],y_t[idx_sig], s = 5, c = 'red', rasterized = True) 416 | plt.xlabel('Average difference') 417 | plt.ylabel('-log10(p)') 418 | return idx_sig, idx_sig_up, idx_sig_down 419 | 420 | def delta_exp_statistics(delta_exp_tgt, delta_exp_ref): 421 | from scipy.stats import ranksums 422 | from tqdm import tqdm 423 | p_values = [] 424 | avg_differences = [] 425 | for i in tqdm(range(delta_exp_tgt.shape[1])): 426 | # Calculate the rank-sum p-value 427 | p_value = ranksums(delta_exp_tgt[:, i], delta_exp_ref[:, i]).pvalue 428 | p_values.append(p_value) 429 | # Calculate the average of the differences 430 | avg_difference = np.mean(delta_exp_tgt[:, i] - delta_exp_ref[:, i]) 431 | avg_differences.append(avg_difference) 432 | return p_values, avg_differences -------------------------------------------------------------------------------- /CAST/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import seaborn as sns 3 | import numpy as np 4 | from sklearn.cluster import KMeans, MiniBatchKMeans 5 | import os 6 | 7 | def kmeans_plot_multiple(embed_dict_t,graph_list,coords,taskname_t,output_path_t,k=20,dot_size = 10,scale_bar_t = None,minibatch = True,plot_strategy = 'sep',axis_off = False): 8 | num_plot = len(graph_list) 9 | plot_row = int(np.floor(num_plot/2) + 1) 10 | embed_stack = embed_dict_t[graph_list[0]].cpu().detach().numpy() 11 | for i in range(1,num_plot): 12 | embed_stack = np.row_stack((embed_stack,embed_dict_t[graph_list[i]].cpu().detach().numpy())) 13 | print(f'Perform KMeans clustering on {embed_stack.shape[0]} cells...') 14 | kmeans = KMeans(n_clusters=k,random_state=0).fit(embed_stack) if minibatch == False else MiniBatchKMeans(n_clusters=k,random_state=0).fit(embed_stack) 15 | cell_label = kmeans.labels_ 16 | cluster_pl = sns.color_palette('tab20',len(np.unique(cell_label))) 17 | print(f'Plotting the KMeans clustering results...') 18 | 
cell_label_idx = 0 19 | if plot_strategy == 'sep': 20 | plt.figure(figsize=((20,10 * plot_row))) 21 | for j in range(num_plot): 22 | plt.subplot(plot_row,2,j+1) 23 | coords0 = coords[graph_list[j]] 24 | col=coords0[:,0].tolist() 25 | row=coords0[:,1].tolist() 26 | cell_type_t = cell_label[cell_label_idx:(cell_label_idx + coords0.shape[0])] 27 | cell_label_idx += coords0.shape[0] 28 | for i in set(cell_type_t): 29 | plt.scatter(np.array(col)[cell_type_t == i], 30 | np.array(row)[cell_type_t == i], s=dot_size,edgecolors='none', 31 | c=np.array(cluster_pl)[cell_type_t[cell_type_t == i]],label = str(i), rasterized=True) 32 | plt.title(graph_list[j] + ' (KMeans, k = ' + str(k) + ')',fontsize=20) 33 | plt.xticks(fontsize=20) 34 | plt.yticks(fontsize=20) 35 | plt.axis('equal') 36 | if axis_off: 37 | plt.xticks([]) 38 | plt.yticks([]) 39 | if (type(scale_bar_t) != type(None)): 40 | add_scale_bar(scale_bar_t[0],scale_bar_t[1]) 41 | else: 42 | plt.figure(figsize=[10,12]) 43 | plt.rcParams.update({'font.size' : 10,'axes.titlesize' : 20,'pdf.fonttype':42}) 44 | for j in range(num_plot): 45 | coords0 = coords[graph_list[j]] 46 | col=coords0[:,0].tolist() 47 | row=coords0[:,1].tolist() 48 | cell_type_t = cell_label[cell_label_idx:(cell_label_idx + coords0.shape[0])] 49 | cell_label_idx += coords0.shape[0] 50 | for i in set(cell_type_t): 51 | plt.scatter(np.array(col)[cell_type_t == i], 52 | np.array(row)[cell_type_t == i], s=dot_size,edgecolors='none',alpha = 0.5, 53 | c=np.array(cluster_pl)[cell_type_t[cell_type_t == i]],label = str(i), rasterized=True) 54 | plt.xticks(fontsize=20) 55 | plt.yticks(fontsize=20) 56 | plt.axis('equal') 57 | if axis_off: 58 | plt.xticks([]) 59 | plt.yticks([]) 60 | plt.title('K means (k = ' + str(k) + ')',fontsize=30) 61 | if (type(scale_bar_t) != type(None)): 62 | add_scale_bar(scale_bar_t[0],scale_bar_t[1]) 63 | plt.savefig(f'{output_path_t}/{taskname_t}_trained_k{str(k)}.pdf',dpi = 100) 64 | return cell_label 65 | 66 | def 
add_scale_bar(length_t,label_t): 67 | import matplotlib.font_manager as fm 68 | from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar 69 | fontprops = fm.FontProperties(size=20, family='Arial') 70 | bar = AnchoredSizeBar(plt.gca().transData, length_t, label_t, 4, pad=0.1, 71 | sep=5, borderpad=0.5, frameon=False, 72 | size_vertical=0.1, color='black',fontproperties = fontprops) 73 | plt.gca().add_artist(bar) 74 | 75 | def plot_mid_v2(coords_q,coords_r = None,output_path='',filename = None,title_t = ['ref','query'],s_t = 8,scale_bar_t = None): 76 | plt.rcParams.update({'font.size' : 30,'axes.titlesize' : 30,'pdf.fonttype':42,'legend.markerscale' : 5}) 77 | plt.figure(figsize=[10,12]) 78 | if coords_r is not None: 79 | plt.scatter(np.array(coords_r)[:,0].tolist(), 80 | np.array(coords_r)[:,1].tolist(), s=s_t,edgecolors='none', alpha = 0.5,rasterized=True, 81 | c='#9295CA',label = title_t[0]) 82 | plt.scatter(np.array(coords_q)[:,0].tolist(), 83 | np.array(coords_q)[:,1].tolist(), s=s_t,edgecolors='none', alpha = 0.5,rasterized=True, 84 | c='#E66665',label = title_t[1]) 85 | plt.legend(fontsize=15) 86 | plt.axis('equal') 87 | if (type(scale_bar_t) != type(None)): 88 | add_scale_bar(scale_bar_t[0],scale_bar_t[1]) 89 | if (filename != None): 90 | plt.savefig(os.path.join(output_path,filename + '.pdf'),dpi = 300) 91 | 92 | def plot_mid(coords_q,coords_r,output_path='',filename = None,title_t = ['ref','query'],s_t = 8,scale_bar_t = None,axis_off = False): 93 | plt.rcParams.update({'font.size' : 30,'axes.titlesize' : 30,'pdf.fonttype':42,'legend.markerscale' : 5}) 94 | plt.figure(figsize=[10,12]) 95 | plt.scatter(np.array(coords_r)[:,0].tolist(), 96 | np.array(coords_r)[:,1].tolist(), s=s_t,edgecolors='none', alpha = 0.5,rasterized=True, 97 | c='#9295CA',label = title_t[0]) 98 | plt.scatter(np.array(coords_q)[:,0].tolist(), 99 | np.array(coords_q)[:,1].tolist(), s=s_t,edgecolors='none', alpha = 0.5,rasterized=True, 100 | c='#E66665',label = title_t[1]) 101 | 
plt.legend(fontsize=15) 102 | plt.axis('equal') 103 | if axis_off: 104 | plt.xticks([]) 105 | plt.yticks([]) 106 | if (type(scale_bar_t) != type(None)): 107 | add_scale_bar(scale_bar_t[0],scale_bar_t[1]) 108 | if (filename != None): 109 | plt.savefig(os.path.join(output_path,filename + '.pdf'),dpi = 300) 110 | 111 | def link_plot(all_cosine_knn_inds_t,coords_q,coords_r,k,figsize_t = [15,20],scale_bar_t = None): 112 | assign_mat = all_cosine_knn_inds_t 113 | plt.figure(figsize=figsize_t) 114 | coords_transfer_r = coords_r[np.unique(assign_mat),:] 115 | coords_transfer_q = coords_q 116 | plt.scatter(x = coords_transfer_q[:,0],y = coords_transfer_q[:,1],s = 2,rasterized=True) 117 | i = list(range(coords_transfer_q.shape[0])) 118 | j = i.copy() 119 | for i_t in range(k): 120 | idx_transfer_r_link = assign_mat[:,i_t] 121 | coords_transfer_r_link = coords_r[idx_transfer_r_link,:] 122 | t1 = np.row_stack((coords_transfer_r_link[:,0],coords_transfer_r_link[:,1])) 123 | t2 = np.row_stack((coords_transfer_q[:,0],coords_transfer_q[:,1])) 124 | plt.plot([t1[0,i],t2[0,j]],[t1[1,i],t2[1,j]],'g',lw = 0.3,rasterized=True) 125 | plt.scatter(x = coords_transfer_r[:,0],y = coords_transfer_r[:,1],s = 4,c = 'red',rasterized=True) 126 | 127 | plt.axis('equal') 128 | if (type(scale_bar_t) != type(None)): 129 | add_scale_bar(scale_bar_t[0],scale_bar_t[1]) 130 | used_dots_num = np.unique(assign_mat).shape[0] 131 | all_dots_num = np.sum(coords_r.shape[0]) 132 | return [all_dots_num,used_dots_num,format(used_dots_num/all_dots_num,'.2f')] 133 | 134 | def dsplot(coords0,coords_plaque_t,s_cell=10,s_plaque=40,col_cell='#999999',col_plaque='red',cmap_t = 'vlag',alpha = 1,vmax_t = None, title=None, scale_bar_200 = None, output_path_t = None, coords0_mask = None): 135 | if coords0_mask is not None: 136 | coords_other = coords0[~coords0_mask].copy() 137 | coords0 = coords0[coords0_mask].copy() 138 | else: 139 | coords_other = None 140 | if type(vmax_t) == type(None): 141 | vmax_t = 
np.abs(col_cell).max() 142 | plt.figure(figsize=(13,13)) 143 | if title is not None: 144 | plt.title(title, fontsize=30) 145 | plt.xticks(fontsize=20) 146 | plt.yticks(fontsize=20) 147 | plt.axis('equal') 148 | if type(col_cell) != str: 149 | col_cell_i = np.array(col_cell)[coords0_mask] if coords0_mask is not None else col_cell 150 | else: 151 | col_cell_i = col_cell 152 | if coords_other is not None: 153 | plt.scatter(coords_other[:,0], coords_other[:,1], s=s_cell, edgecolors='none',alpha = 0.2, 154 | c='#aaaaaa',cmap = cmap_t,rasterized=True) 155 | col=coords0[:,0].tolist() 156 | row=coords0[:,1].tolist() 157 | plt.scatter(np.array(col), np.array(row), s=s_cell, edgecolors='none',alpha = alpha, vmax=vmax_t,vmin = -vmax_t, 158 | c=col_cell_i,cmap = cmap_t,rasterized=True) 159 | plt.colorbar(ticks=[-vmax_t,0, vmax_t]) 160 | if type(coords_plaque_t) != type(None): 161 | coords1 = coords_plaque_t 162 | if type(s_plaque) != int: 163 | s_plaque_i = np.array(s_plaque) 164 | else: 165 | s_plaque_i = s_plaque 166 | if type(col_plaque) != str: 167 | col_plaque_i = np.array(col_plaque) 168 | else: 169 | col_plaque_i = col_plaque 170 | col=coords1[:,0].tolist() 171 | row=coords1[:,1].tolist() 172 | plt.scatter(np.array(col),np.array(row), s=s_plaque_i, edgecolors='none', c=col_plaque_i,rasterized=True) 173 | if scale_bar_200 is not None: 174 | add_scale_bar(scale_bar_200,'200 µm') 175 | if output_path_t is not None: 176 | plt.savefig(output_path_t,dpi = 300) 177 | plt.close('all') 178 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Zefang Tang, Shuchen Luo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, 
modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | **CAST** is a Python library for physically aligning different spatial transcriptomics samples regardless of technology, magnification, individual variation, and experimental batch effects. 3 | CAST is composed of three modules: CAST Mark, CAST Stack, and CAST Projection. 4 | ![CAST Diagram](./README_CAST_diagram.png) 5 | 6 | # Installation Guide 7 | 8 | On Windows, we suggest using the `Anaconda powershell` to run the following code. 9 | 10 | ## Requirements 11 | 12 | - **[Optional]** It is recommended to create a conda environment for CAST. 13 | 14 | For example, create a new environment named `cast_demo` and activate it. 15 | ``` 16 | conda create -y -n cast_demo python=3.9 17 | conda activate cast_demo 18 | ``` 19 | - **[Optional]** CAST requires `pytorch` and `dgl` 20 | 21 | [Install Pytorch](https://pytorch.org/get-started/locally/) 22 | 23 | [Install DGL](https://www.dgl.ai/pages/start.html) 24 | 25 | Users can run `nvcc --version` to check the CUDA version before installation.
26 | 27 | Here we provide an example installation for CUDA `11.3`. 28 | ``` 29 | #### If CPU only #### 30 | conda install pytorch==1.11.0 cpuonly -c pytorch 31 | conda install -c dglteam dgl 32 | 33 | #### If GPU available #### 34 | conda install -y -c pytorch pytorch==1.11.0 cudatoolkit=11.3 35 | conda install -y -c dglteam dgl-cuda11.3==0.9.1 36 | ``` 37 | 38 | ## Installation 39 | If `git` is available: 40 | ``` 41 | pip install git+https://github.com/wanglab-broad/CAST.git 42 | ``` 43 | If `git` is unavailable: 44 | 45 | 1. Download the package and unpack it. 46 | 47 | 2. Run the following commands: 48 | ``` 49 | cd $package 50 | pip install -e . 51 | ``` 52 | 53 | # Demo 54 | We provide several demos that demonstrate the functions of the CAST package. 55 | Due to file size restrictions on GitHub, we cannot upload the sample datasets associated with the demos. The data are available in our Zenodo archive for the paper: https://zenodo.org/doi/10.5281/zenodo.12215314. This Zenodo repository also contains code and data to reproduce the results in the paper. 56 | 57 | Use the following commands to open the `Jupyter notebook` (we recommend using `Chrome` to open the Jupyter notebook). 58 | ``` 59 | cd $demo_path 60 | jupyter notebook 61 | 62 | #### If remote kernel #### 63 | jupyter notebook --ip=0.0.0.0 --port=8800 64 | 65 | #### If dead kernel #### 66 | jupyter notebook --NotebookApp.max_buffer_size=21474836480 67 | ``` 68 | ## Demo 1 CAST_Mark 69 | In this demo, CAST_Mark captures the common spatial features across multiple samples. 70 | 71 | Users should first replace `$demo_path` with the CAST demo path in the first cell. 72 | 73 | ## Demo 2 CAST_Stack_Align_S4_to_S1 74 | In this demo, Stack_Align aligns two samples. 75 | 76 | Users should first replace `$demo_path` with the CAST demo path in the first cell. 77 | 78 | ## Demo 3 CAST_project 79 | In this demo, CAST_Projection projects one sample onto another.
80 | 81 | Users should first replace `$demo_path` with the CAST demo path in the first cell. 82 | 83 | ## More demos 84 | More demos are available in the [tutorial pages](https://cast-tutorial.readthedocs.io/en/latest/). 85 | 86 | -------------------------------------------------------------------------------- /README_CAST_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanglab-broad/CAST/3ebcf4fbe1e43dbbb9ae92ecd07e93562e156160/README_CAST_diagram.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="CAST", 5 | version="0.4", 6 | packages=find_packages(), 7 | install_requires=[ 8 | 'torch', 9 | 'matplotlib', 10 | 'seaborn', 11 | 'scikit-learn', 12 | 'h5py', 13 | 'statsmodels', 14 | 'tqdm', 15 | 'geopandas', 16 | 'Rtree', 17 | 'scanpy', 18 | 'libpysal', 19 | 'ipython', 20 | 'jupyterlab', 21 | 'jupyter', 22 | 'numpy', 23 | 'pandas' 24 | ] 25 | ) 26 | --------------------------------------------------------------------------------