├── .DS_Store ├── .gitignore ├── LICENSE ├── PYPI.rst ├── README.rst ├── setup.py ├── stereosite ├── __init__.py ├── cn │ ├── __init__.py │ ├── cellneighbor.py │ └── deconvolution.py ├── datasets │ ├── LR_database │ │ ├── CellChatDB.human.csv │ │ ├── CellChatDB.mouse.csv │ │ ├── cellphoneDB_interactions.csv │ │ └── interactions.csv │ ├── STRINGdb │ │ ├── Hsa │ │ │ ├── 9606.protein.aliases.v11.5.txt.gz │ │ │ └── 9606.protein.info.v11.5.txt.gz │ │ ├── Mmu │ │ │ ├── 10090.protein.aliases.v11.5.txt.gz │ │ │ └── 10090.protein.info.v11.5.txt.gz │ │ └── score_info.txt.gz │ ├── TF_database │ │ ├── dorothea_TF-target_regulons_hs.csv │ │ └── dorothea_TF-target_regulons_mm.csv │ └── biomart │ │ ├── human_to_mouse_biomart_export.csv │ │ └── mouse_to_human_biomart_export.csv ├── degene.py ├── plot │ ├── __init__.py │ ├── cellneighbor.py │ ├── intensity.py │ ├── mask.py │ ├── net.py │ ├── sankey.py │ ├── scii.py │ ├── scii_circos.py │ ├── scii_net.py │ └── scii_tensor.py ├── ppi │ ├── __init__.py │ ├── ppi_analysis.py │ ├── query.py │ └── run_mcl.sh ├── read │ ├── __init__.py │ └── gem.py ├── scii.py ├── scii_tensor.py └── tf_infer.py └── tutorial ├── .gitattributes ├── PPI ├── hub_net_cluster0.pdf ├── out.query_net_edgelist.txt.I40 └── query_net_edgelist.txt ├── StereoSiTE.ipynb ├── TF ├── TF_net.pdf └── tf_infer.pdf ├── data ├── .DS_Store ├── CellChatDB.mouse.csv ├── Nfkpb_pathway_mmu04064.txt ├── SS200000681TL_A1.tissue_10000.gem.gz ├── degene_test.h5ad ├── inf_aver_noCAF.csv └── test_adata.h5ad ├── scii_tensor ├── cells_circos.pdf ├── cells_lr_circos.pdf ├── core_3d_heatmap.png ├── factor_cc_heatmap.pdf ├── factor_lr_heatmap.pdf ├── igrap_network.pdf └── sankey_3d.png ├── scii_tensor_plot.ipynb └── stereosite_run.ipynb /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STOmics/StereoSiTE/58d1daf7bec2db10cc4d443154eeee5618377fee/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .pypi 3 | 4 | #data 5 | tutorial/data/ 6 | tutorial/test.* 7 | 8 | # Distribution / packaging 9 | .Python 10 | build/ 11 | develop-eggs/ 12 | dist/ 13 | downloads/ 14 | eggs/ 15 | .eggs/ 16 | lib/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | pip-wheel-metadata/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST -------------------------------------------------------------------------------- /PYPI.rst: -------------------------------------------------------------------------------- 1 | StereoSiTE - Spatial Transcriptome Analysis in Python 2 | ====================================================== 3 | 4 | **StereoSiTE** is a package for the analysis and visualization of spatial transcriptome data. 5 | It builds on top of `anndata`_, `scanpy`_ and `squidpy`_, from which it inherits modularity and scalability. 6 | It provides analysis tools to dissect cellular neighborhood based on cell composition and quantitatively define cell-cell communication in spatial. 
7 | 8 | StereoSiTE's key applications 9 | ------------------------------ 10 | 11 | - Cellular Neighborhood (CN) clustering based on cell composition of each bin 12 | - Spatial Cell Interaction Intensity (SCII) analysis 13 | 14 | Citation 15 | --------- 16 | 17 | If you use `stereosite`_ in your work, please cite the publication as follows: 18 | 19 | **StereoSiTE: A framework to spatially and quantitatively profile the cellular neighborhood organized iTME** 20 | 21 | Xing Liu, Chi Qu, Chuandong Liu, Na Zhu, Huaqiang Huang, Fei Teng, Caili Huang, Bingying Luo, Xuanzhu Liu, Yisong Xu, Min Xie, Feng Xi, Mei Li, Liang Wu, Yuxiang Li, Ao Chen, Xun Xu, Sha Liao, Jiajun Zhang 22 | 23 | bioRxiv 2022.12.31.522366; doi: https://doi.org/10.1101/2022.12.31.522366 24 | 25 | .. _scanpy: https://scanpy.readthedocs.io/en/stable/ 26 | .. _anndata: https://anndata.readthedocs.io/en/stable/ 27 | .. _squidpy: https://squidpy.readthedocs.io/en/stable/ 28 | .. _stereosite: https://github.com/STOmics/stereosite -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | |PyPI| |Downloads| |stars| |Discourse| |Zulip| 2 | 3 | StereoSiTE - Spatial Transcriptome Analysis in Python 4 | ====================================================== 5 | 6 | **StereoSiTE** is a package for the analysis and visualization of spatial transcriptome data. 7 | It builds on top of `anndata`_, `scanpy`_ and `squidpy`_, from which it inherits modularity and scalability. 8 | It provides analysis tools to dissect cellular neighborhood based on cell composition and quantitatively define cell-cell communication in spatial. 9 | 10 | StereoSiTE's key applications 11 | ------------------------------ 12 | 13 | - Cellular Neighborhood (CN) clustering based on cellular composition of each bin 14 | - Spatial Cell Interaction Intensity (SCII) analysis 15 | - Tissue domain clustering based on intercellular interactions within each bin (SCIITensor) 16 | 17 | Demo data 18 | ---------- 19 | `demo`_ (Fetch code: nFks) 20 | 21 | Citation 22 | --------- 23 | 24 | If you use `stereosite`_ in your work, please cite the publication as follows: 25 | 26 | **StereoSiTE: A framework to spatially and quantitatively profile the cellular neighborhood organized iTME** 27 | 28 | Xing Liu, Chi Qu, Chuandong Liu, Na Zhu, Huaqiang Huang, Fei Teng, Caili Huang, Bingying Luo, Xuanzhu Liu, Min Xie, Feng Xi, Mei Li, Liang Wu, Yuxiang Li, 29 | Ao Chen, Xun Xu, Sha Liao, Jiajun Zhang, StereoSiTE: a framework to spatially and quantitatively profile the cellular neighborhood organized iTME, 30 | GigaScience, Volume 13, 2024, giae078, https://doi.org/10.1093/gigascience/giae078 31 | 32 | **SCIITensor: A tensor decomposition based algorithm to construct actionable TME modules with spatially resolved intercellular communications** 33 | 34 | Huaqiang Huang, Chuandong Liu, Xin Liu, Jingyi Tian, Feng Xi, Mei Li, Guibo Li, Ao Chen, Xun Xu, Sha Liao, Jiajun Zhang, Xing Liu 35 | bioRxiv 2024.05.21.595103; doi: https://doi.org/10.1101/2024.05.21.595103 36 | 37 | 38 | Installation 39 | ------------- 40 | 41 | Install StereoSiTE via PyPi by running: 42 | 43 | >>> conda create -y -n stereosite python=3.9.12 44 | >>> conda activate stereosite 45 | >>> pip install stereosite 46 | 47 | or via raw code by running: 48 | 49 | >>> git clone https://github.com/STOmics/StereoSiTE.git 50 | >>> cd StereoSiTE 51 | >>> python setup.py install 52 | 53 | Run examples 54 | ------------- 55 
| 56 | **Transfer spatial gene expression matrix (gem|gef) to anndata** 57 | :: 58 | 59 | from stereosite.read.gem import Gem_Reader 60 | gem_file = "to/path/sn.gem.gz" 61 | gem_reader = Gem_Reader(gem_file) 62 | # without cell mask, bin_size must be specified. Here bin_size=200 63 | adata_bin200 = gem_reader.gem2anndata(200) 64 | adata_bin200_file = "to/path/sn_bin200.h5ad" 65 | adata_bin200.write(adata_bin200_file) 66 | # or with cell mask, gem will be transfered to anndata in single-cell resolution. 67 | mask_file = "to/path/sn_mask.tif" 68 | adata = gem_reader.gem_with_cellmask_2anndata(mask_file) 69 | 70 | **Cellular Neighborhood (CN)** 71 | :: 72 | 73 | from stereosite.cn.deconvolution import Cell2location 74 | # Get the cellular composition of each square bin by deconvolution 75 | ref_file = "to/path/scCell_reference.csv" 76 | adata_file = "to/path/sn_bin200.h5ad" 77 | out_dir = "to/out/deconvolution/bin200" 78 | cell2loc = Cell2location(ref_file, adata_file, out_dir = out_dir, bin_size = 200, gpu = 0) 79 | cell2loc.run_deconvolution() 80 | # Analyze the Cellular Neighborhood based on deconvolution result 81 | from stereosite.cn.cellneighbor import cn_deconvolve 82 | import anndata 83 | adata_anno_file = "to/out/deconvolution/bin200/cell2location_map/sp.h5ad" 84 | adata = anndata.read(adata_anno_file) 85 | # use_rep specify matrix used to calculate cell composition of every bin 86 | cn_deconvolve(adata, use_rep='q05_cell_abundance_w_sf') 87 | # Or use the annotated cell bin data to dissect the CNs 88 | from stereosite.cn.cellneighbor import cn_cellbin 89 | import anndata 90 | adata = anndata.read(adata_anno_file) 91 | cn_cellbin(adata, 400, n_neighbors = 20, resolution = 0.4, min_dist = 0.1) 92 | # CN result visualization 93 | from stereosite.plot.cellneighbor import umap, heatmap, spatial 94 | spatial(adata, spot_size=20) 95 | umap(adata) 96 | heatmap(adata) 97 | 98 | **Spatial Cell Interaction Intensity (SCII)** 99 | :: 100 | 101 | from stereosite.scii import intensities_count 102 | # The annotated cellbin or square bin at single-cell resolution data is required. 103 | # Choose LR database based on the sample type. CellChatDB provide mouse and human database separatly. 104 | # CellphoneDB provide only human database, but also can be used to analyse mouse data by transfer mouse gene into homologous 105 | # human gene, which will be automaticaly done by the software. 106 | interactiondb_file = "./datasets/LR_database/CellChatDB.mouse.csv" 107 | scii_dict = intensities_count(adata, interactiondb_file, 108 | distance_threshold = 50, 109 | anno = 'cell_type') 110 | # Or we can specify different distance_threshold for individual LR types 111 | scii_dict = intensities_count(adata, interactiondb_file, 112 | distance_threshold = {'Secreted Signaling': 200, 'ECM-Receptor': 200, 'Cell-Cell Contact': 30}, 113 | anno = 'cell_type') 114 | # Or specify the distance_coefficient parameter to consider distance when caculating interaction intensity. 115 | # distance_coefficient=0 means distance would not influence the interaction intensity. 116 | scii_dict = intensities_count(adata, interactiondb_file, 117 | distance_threshold = {'Secreted Signaling': 200, 'ECM-Receptor': 200, 'Cell-Cell Contact': 30}, 118 | distance_coefficient = {'Secreted Signaling': 1, 'ECM-Receptor': 0.1, 'Cell-Cell Contact': 0}, 119 | anno = 'cell_type') 120 | # The interaction result can be writen into a pickle file, and can be re-loaded when you want to re-analyze it. 
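# os is needed for the os.makedirs call below
import os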
121 | import pickle 122 | os.makedirs("./out/scii", exist_ok=True) 123 | interaction_file = "./out/scii/interactions.pkl" 124 | with open(interaction_file, 'wb') as writer: 125 | pickle.dump(scii_dict, writer) 126 | with open(interaction_file, 'rb') as reader: 127 | scii_dict = pickle.load(reader) 128 | 129 | # SCII result visualization 130 | #filter scii result based on intensities value or pathway 131 | cell_pairs = [('cell1', 'cell2'), ('cell1', 'cell3'), ...] # cell pairs that will be remained 132 | gene_list = ['gene1', 'gene2', 'gene3', ...] # list of genes that will be filtered out 133 | filter_interaction = interaction_select(scii_dict, cell_pairs=cell_pairs, filter_genes=gene_list, intensities_range=(6000, 8000)) 134 | pathway_names = ['EGF'] # Interactions of these pathways will be remained. 135 | pathway_interaction = interaction_pathway_select(scii_dict, pathway_name=pathway_names, interactiondb_file=interactiondb_file) 136 | # Users can combine the intensities filter with the pathway selection 137 | pathway_interaction = interaction_pathway_select(filter_interaction, pathway_name=pathway_names, interactiondb_file=interactiondb_file) 138 | 139 | # Visualize the interaction result by bubble plot 140 | from stereosite.plot.scii import ligrec 141 | import numpy as np 142 | ligrec(scii_dict, 143 | intensities_range=(50, np.inf), 144 | pvalue_threshold=0.05, 145 | alpha=1e-4, 146 | swap_axes=False, 147 | source_groups=["Non-immune cells", "M2-like", 'DC', 'Teff'], 148 | target_groups = ["M1-like", "M2-like", "Monocytes", "Teff", "CD8+ Tcells"], 149 | title=" ", 150 | ) 151 | # Or visualize the selected interactions 152 | ligrec(pathway_interaction, 153 | pvalue_threshold=0.05, 154 | alpha=1e-4, 155 | swap_axes=False, 156 | ) 157 | # Show spatial distribution of interaction intensity between specific cell pair meidated by specific LR pair. 158 | from stereosite.plot.intensity import intensity_insitu 159 | cells = ['Non-immune cells', 'M2-like'] 160 | genes = ['Ptprc', 'Mrc1'] 161 | intensity_insitu(adata, cells, genes, radius = 50, distance_coefficient=0.01, spot_size=5) 162 | 163 | # Visualize the interaction result by circle plot and graph 164 | from stereosite.plot import scii_circos, scii_net 165 | anno='cell_type' 166 | cell_colors = dict(zip(adata.obs[anno].cat.categories, adata.uns[f'{anno}_colors'])) # Define the color of sectors representing cells 167 | filter_matrix = filter_interaction['intensities'].fillna(0) 168 | scii_circos.cells_lr_circos(filter_matrix, cell_colors=cell_colors) 169 | pathway_matrix = pathway_interaction['intensities'].fillna(0) 170 | scii_circos.cells_lr_circos(pathway_matrix, cell_colors=cell_colors) 171 | scii_circos.cells_circos(filter_matrix) 172 | scii_circos.cells_circos(pathway_matrix) 173 | 174 | #Draw the network diagram based on the Graph generated previously. 
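# (each *_graph_generate call below builds the graph object from a filtered intensity matrix; the matching *_plot call renders it)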
175 | g1 = scii_net.lr_link_graph_generate(filter_matrix, cell_colors=cell_colors, reducer=6) 176 | scii_net.cell_lr_grap_plot(g1, figsize=10, vertex_label_size=6) 177 | g2 = scii_net.cell_graph_generate(filter_matrix, reducer=30, cell_colors=cell_colors) 178 | scii_net.cell_graph_plot(g2, vertex_label_size=8, figsize=5, edge_width=[0.5, 3]) 179 | g3 = scii_net.lr_link_graph_generate(pathway_matrix, cell_colors=cell_colors, reducer=6) 180 | scii_net.cell_lr_grap_plot(g3, figsize=8, edge_width=[0.5, 2]) 181 | g4 = scii_net.cell_graph_generate(pathway_matrix, reducer=15, cell_colors=cell_colors) 182 | scii_net.cell_graph_plot(g4, vertex_label_size=8, figsize=5, edge_width=[0.5, 3]) 183 | 184 | 185 | **SCIITensor -- single sample analysis** 186 | :: 187 | 188 | from stereosite import scii_tensor 189 | import anndata 190 | import pandas as pd 191 | import seaborn as sns 192 | import matplotlib as mpl 193 | import matplotlib.pyplot as plt 194 | import pickle 195 | import numpy as np 196 | import scanpy as sc 197 | # Generate interactiontensor object and evaluate the optimal combination of ranks 198 | adata = anndata.read(adata_anno_file) 199 | interactionDB = "./datasets/LR_database/CellChatDB.mouse.csv" 200 | sct = scii_tensor.InteractionTensor(adata, interactionDB=interactionDB) 201 | radius = {'Secreted Signaling': 100, 'ECM-Receptor': 100, 'Cell-Cell Contact': 30} 202 | scii_tensor.build_SCII(sct, radius=radius, window_size=200, anno_col='cell2loc_anno') 203 | scii_tensor.process_SCII(sct, zero_remove=True, log_data=True) 204 | reconstruction_errors = scii_tensor.evaluate_ranks(sct, use_gpu=True, device='cuda:1') 205 | # Visualize the reconstruction errors using line plot 206 | from stereosite.plot.scii_tensor import reconstruction_error_line 207 | reconstruction_error_line(reconstruction_errors, figsize=(4, 4)) 208 | # Decompose the interaction tensor with optimal combination of ranks 209 | scii_tensor.SCII_Tensor(sct, rank=[15, 15, 15], device='cuda:0') 210 | with open("out/scii_tensor_res.pkl", "wb") as f: 211 | pickle.dump(sct, f) 212 | # spatial distribution of each TME module 213 | import scanpy as sc 214 | sc.pl.spatial(sct.adata, color='TME_module', img_key=None, spot_size=20) 215 | 216 | ## visualization of core matrix 217 | from stereosite.plot import sankey, scii_circos, scii_net 218 | #normalize the core matrix 219 | norm_core = scii_tensor.core_normalization(sct.core, feature_range=(0, 100)) 220 | #process core matrix to generate dataFrame that will be used to draw sankey plot. 
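# (sct.core is the 3-way core of the tensor decomposition -- TME modules x cell-pair modules x LR-pair modules; core_process splits it into the left/right dataframes consumed by sankey_3d)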
221 | left_df, right_df = sankey.core_process(norm_core) 222 | sankey.sankey_3d(left_df, right_df, link_alpha=0.5, interval=0.005) 223 | 224 | from stereosite.plot.scii_tensor import tme_core_heatmap, core_heatmap 225 | core_heatmap(norm_core) # 3D heatmap plot showing the core matrix 226 | tme_core_heatmap(sct.core, tme_number=1, figsize=(4, 4)) # 2D heatmap plot showing the result of one TME 227 | ## cell-cell factor heatmap 228 | from stereosite.scii_tensor import top_pair 229 | import seaborn as sns 230 | import matplotlib.pyplot as plt 231 | top_cc_pair = top_pair(sct, pair='cc', top_n=20) 232 | fig = sns.clustermap(top_cc_pair.T, cmap="Purples", standard_scale=0, metric='euclidean', method='ward', 233 | row_cluster=False, dendrogram_ratio=0.05, cbar_pos=(1.02, 0.6, 0.01, 0.3), 234 | figsize=(4, 6), 235 | ) 236 | ## ligand-receptor factor heatmap 237 | top_lr_pair = top_pair(sct, pair='lr', top_n=20) 238 | fig = sns.clustermap(top_lr_pair.T, cmap="Purples", standard_scale=0, metric='euclidean', method='ward', 239 | row_cluster=False, dendrogram_ratio=0.05, cbar_pos=(1.02, 0.6, 0.01, 0.3), 240 | figsize=(4, 6), 241 | ) 242 | 243 | ## visualize selected interactions using heatmap 244 | from stereosite.plot.scii_tensor import interaction_heatmap 245 | interactions = scii_tensor.interaction_select(sct, 246 | tme_module=1, 247 | cellpair_module=1, 248 | lrpair_module=11, n_lr=15, n_cc=15) 249 | interaction_heatmap(interactions, figsize=(5, 3), vmax=50) 250 | ## visualize selected interactions using circle plot 251 | from stereosite.plot.scii_circos import cells_lr_circos, cells_circos, lr_circos 252 | cells = adata.obs['cell2loc_anno'].unique() 253 | cell_colors = dict(zip(adata.obs[anno].cat.categories, adata.uns[f'{anno}_colors'])) # Define the color of sectors representing cells 254 | scii_circos.cells_lr_circos(interaction_matrix, cells=cells, cell_colors=cell_colors, scii_tensor=True) 255 | #Draw the circos which only contains cell types and the links between them. 256 | scii_circos.cells_circos(interaction_matrix, cells, cell_colors=cell_colors, label_orientation='vertical', scii_tensor=True) 257 | #Draw circos which only contains ligand-receptor genes 258 | scii_circos.lr_circos(interaction_matrix, cells=cells, scii_tensor=True) 259 | 260 | #Draw the network diagram based on the Graph generated previously. 
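# ('interaction_matrix' above and below is the interaction matrix selected from the decomposition; presumably the 'interactions' object returned by scii_tensor.interaction_select can be passed here)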
261 | from stereosite.plot.scii_net import lr_link_graph_generate 262 | g1 = lr_link_graph_generate(interaction_matrix, cells = cells, separator="_", cell_colors=cell_colors, scii_tensor=True) 263 | scii_net.cell_lr_grap_plot(g1, figsize=10, vertex_label_size=6) 264 | 265 | **SCIITensor -- multiple sample analysis** 266 | :: 267 | 268 | adata_1 = anndata.read(adata_anno_file_1) 269 | ## decompose data from another sample 270 | ## evaluate the optimal combination of ranks 271 | interactionDB = "./datasets/LR_database/CellChatDB.mouse.csv" 272 | sct_1 = scii_tensor.InteractionTensor(adata_1, interactionDB=interactionDB) 273 | radius = {'Secreted Signaling': 100, 'ECM-Receptor': 100, 'Cell-Cell Contact': 30} 274 | scii_tensor.build_SCII(sct_1, radius=radius, window_size=200, anno_col='cell2loc_anno') 275 | scii_tensor.process_SCII(sct_1, zero_remove=True, log_data=True) 276 | reconstruction_errors = scii_tensor.evaluate_ranks(sct_1, use_gpu=True, device='cuda:1') 277 | ## visualize the reconstruction errors using a line plot 278 | from stereosite.plot.scii_tensor import reconstruction_error_line 279 | reconstruction_error_line(reconstruction_errors, figsize=(4, 4)) 280 | scii_tensor.SCII_Tensor(sct_1, rank=(20, 20, 13), device='cuda:0') 281 | with open("out/scii_tensor_res_1.pkl", "wb") as f: 282 | pickle.dump(sct_1, f) 283 | ## merge the decomposed matrices 284 | sct_merge = scii_tensor.merge_data([sct, sct_1], patient_id=['p1', 'p2']) 285 | ## visualize the reconstruction errors 286 | from stereosite.plot.scii_tensor import reconstruction_error_line 287 | reconstruction_error_line(reconstruction_errors, figsize=(4, 4)) 288 | scii_tensor.SCII_Tensor_multiple(sct_merge, rank=[15, 15, 10], device='cuda:1') 289 | ## spatial distribution of each meta-module 290 | sc.pl.spatial(sct_merge.adata[0], color=['TME_module', 'TME_meta_module'], img_key=None, spot_size=20) 291 | sc.pl.spatial(sct_merge.adata[1], color=['TME_module', 'TME_meta_module'], img_key=None, spot_size=20) 292 | #normalize the core matrix 293 | norm_core = scii_tensor.core_normalization(sct.core, feature_range=(0, 100)) 294 | #process the core matrix into the dataFrames used to draw the sankey plot. 295 | left_df, right_df = sankey.core_process(norm_core) 296 | sankey.sankey_3d(left_df, right_df, link_alpha=0.5, interval=0.005) 297 | from stereosite.plot.scii_tensor import tme_core_heatmap, core_heatmap 298 | core_heatmap(norm_core) # 3D heatmap plot showing the core matrix 299 | tme_core_heatmap(sct.core, tme_number=1, figsize=(4, 4)) # 2D heatmap plot showing the result of one TME 300 | 301 | ## visualize selected interactions using heatmap 302 | from stereosite.plot.scii_tensor import interaction_heatmap 303 | interactions = scii_tensor.interaction_select_multiple(sct_merge, 304 | tme_module=0, sample='p2', 305 | cellpair_module=0, 306 | lrpair_module=1, n_lr=15, n_cc=15) 307 | interaction_heatmap(interactions, figsize=(5, 3), vmax=10) 308 | ## visualize selected interactions using circle plot 309 | from stereosite.plot.scii_circos import cells_lr_circos, cells_circos, lr_circos 310 | cells = adata.obs['cell2loc_anno'].unique() 311 | cell_colors = dict(zip(adata.obs[anno].cat.categories, adata.uns[f'{anno}_colors'])) # Define the color of sectors representing cells 312 | scii_circos.cells_lr_circos(interaction_matrix, cells=cells, cell_colors=cell_colors, scii_tensor=True) 313 | #Draw the circos which only contains cell types and the links between them.
314 | scii_circos.cells_circos(interaction_matrix, cells, cell_colors=cell_colors, label_orientation='vertical', scii_tensor=True) 315 | #Draw circos which only contains ligand-receptor genes 316 | scii_circos.lr_circos(interaction_matrix, cells=cells, scii_tensor=True) 317 | 318 | .. |stars| image:: https://img.shields.io/github/stars/STOmics/StereoSiTE?logo=GitHub&color=yellow 319 | :target: https://github.com/STOmics/StereoSiTE/stargazers 320 | 321 | .. |PyPI| image:: https://img.shields.io/pypi/v/stereosite.svg 322 | :target: https://pypi.org/project/stereosite/ 323 | :alt: PyPI 324 | 325 | .. |Downloads| image:: https://static.pepy.tech/badge/stereosite 326 | :target: https://pepy.tech/project/stereosite 327 | :alt: Downloads 328 | 329 | .. |Discourse| image:: https://img.shields.io/discourse/posts?color=yellow&logo=discourse&server=https%3A%2F%2Fdiscourse.scverse.org 330 | :target: https://discourse.scverse.org/ 331 | :alt: Discourse 332 | 333 | .. |Zulip| image:: https://img.shields.io/badge/zulip-join_chat-%2367b08f.svg 334 | :target: https://scverse.zulipchat.com 335 | :alt: Zulip 336 | 337 | .. _scanpy: https://scanpy.readthedocs.io/en/stable/ 338 | .. _anndata: https://anndata.readthedocs.io/en/stable/ 339 | .. _squidpy: https://squidpy.readthedocs.io/en/stable/ 340 | .. _stereosite: https://github.com/STOmics/stereosite 341 | .. _demo: https://bgipan.genomics.cn/#/link/ilOA8JTgy7jKrNX4ZrOc 342 | 343 | 344 | 345 | 346 | 347 | 348 | 349 | 350 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("PYPI.rst", "r") as f: 4 | long_description = f.read() 5 | 6 | setup( 7 | name='stereosite', 8 | version='2.2.3', 9 | author='LiuXing', 10 | author_email='liuxing2@genomics.cn', 11 | description=('Analysis spatial transcriptomics data'), 12 | long_description=long_description, 13 | license='GPL-3 License', 14 | keywords='spatial cell interaction intensity', 15 | url="https://github.com/STOmics/StereoSiTE", 16 | 17 | packages=find_packages(), #['stereosite'], #需要打包的目录列表 18 | 19 | include_package_data=True, 20 | platforms='any', 21 | #需要安装的依赖包 22 | install_requires = [ 23 | 'anndata>=0.8.0', 24 | 'scanpy>=1.9.1', 25 | 'squidpy>=1.1.2', 26 | 'decoupler>=1.4.0', 27 | 'pydeseq2>=0.3.6', 28 | 'networkx>=3.1', 29 | 'tensorly>=0.8.1', 30 | 'scikit-learn>=1.2.1', 31 | 'torch>=1.11.0', 32 | 'igraph>=0.10.4', 33 | 'pycirclize>=1.1.0', 34 | 'cell2location==0.1.3' 35 | ], 36 | ) 37 | -------------------------------------------------------------------------------- /stereosite/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STOmics/StereoSiTE/58d1daf7bec2db10cc4d443154eeee5618377fee/stereosite/__init__.py -------------------------------------------------------------------------------- /stereosite/cn/__init__.py: -------------------------------------------------------------------------------- 1 | """The deconvolution and annotation module""" -------------------------------------------------------------------------------- /stereosite/cn/cellneighbor.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python3 2 | # _*_ coding: utf-8 _*_ 3 | 4 | #@Author: LiuXing liuxing2@genomics.cn 5 | #@Date: 2023-07-12 15:24:52 6 | #@Last Modified by: LiuXing 7 | #@Last Modified time: 2023-07-12 15:24:52 8 | 9 | 10 | 11 | import os, sys 12 | from optparse import OptionParser 13 | import numpy as np 14 | import pandas as pd 15 | import scanpy as sc 16 | import anndata 17 | import seaborn as sns 18 | import matplotlib.pyplot as plt 19 | from scipy import sparse as sp 20 | import logging as logg 21 | 22 | def cn_deconvolve(adata: anndata, 23 | use_rep: str = 'q05_cell_abundance_w_sf', 24 | n_neighbors: int = 20, 25 | resolution: float = 0.4, 26 | min_dist: float = 0.2, 27 | random_stat: int = 100, 28 | key_added: str = None): 29 | """ 30 | Cluster bins with similar cellular composition, which was calculated from cell type deconvolution result. 31 | 32 | Parameters 33 | ---------- 34 | adata 35 | anndata 36 | use_rep 37 | Key value of anndata.obsm, can be used to obtain deconvolution matrix of each bin 38 | n_neighbors 39 | Number of neighbors that will be connected with the center point when constructing k-nearest neighbor graph 40 | resolution 41 | A parameter value controlling the coarseness of the clustering. Higher values lead to more clusters. Used for leiden function. 42 | min_dist 43 | The effective minimum distance between embedded points. Smaller values will result in a more clustered/clumped embedding where 44 | nearby points on the manifold are drawn closer together, while larger values will result on a more even dispersal of points. 45 | The value should be set relative to the spread value, which determines the scale at which embedded points will be spread out. 46 | The default of in the umap-learn package is 0.1. 47 | key_added 48 | Column name of adata.obs, which store the result of leiden clustering. 49 | """ 50 | if key_added == None: 51 | key_added = 'cell_neighbor' 52 | sc.pp.neighbors(adata, use_rep=use_rep, n_neighbors=n_neighbors, key_added='CN') 53 | sc.tl.leiden(adata, resolution=resolution, key_added=key_added, neighbors_key='CN') 54 | sc.tl.umap(adata, min_dist=min_dist, spread=1, random_state=random_stat, neighbors_key='CN') 55 | 56 | # analysis each Cellular Neighborhood's celltype composition according to bin level cell2location annotation matrix 57 | celltypes= [str[23:] for str in list(adata.obsm[use_rep])] 58 | # sum celltypes in each Cellular Neighborhood 59 | cn_pct = pd.pivot_table(adata.obs,columns=key_added,values=celltypes,aggfunc=np.sum).T 60 | # calculate percentage 61 | cn_pct = cn_pct.apply(lambda x: x/x.sum(), axis=1) 62 | adata.uns['CN']['cell_composition'] = cn_pct 63 | 64 | def cn_cellbin(adata: anndata, 65 | bin_size: int = 200, 66 | anno: str = 'cell2loc_anno', 67 | n_neighbors: int = 20, 68 | resolution: float = 0.4, 69 | min_dist: float = 0.1, 70 | key_added: str = None, 71 | random_stat:int = 100, 72 | ): 73 | """ 74 | Bin the annotated ST data in single-resolution, and cluster bins with similar cellular composition. 75 | 76 | Parameters 77 | ---------- 78 | adata 79 | anndata 80 | bin_size 81 | Determines the size of bin. Take 200 as an example, calculate the cellular composition in every sparately arranged 200x200 square bin. 82 | anno 83 | Column name of anndata.obs, which contains the cell type information. 84 | n_neighbors 85 | Number of neighbors that will be connected with the center point when constructing k-nearest neighbor graph 86 | resolution 87 | A parameter value controlling the coarseness of the clustering. 
Higher values lead to more clusters. Used for leiden function. 88 | min_dist 89 | The effective minimum distance between embedded points. Smaller values will result in a more clustered/clumped embedding where 90 | nearby points on the manifold are drawn closer together, while larger values will result on a more even dispersal of points. 91 | The value should be set relative to the spread value, which determines the scale at which embedded points will be spread out. 92 | The default of in the umap-learn package is 0.1. 93 | key_added 94 | Column name of adata.obs, which store the result of leiden clustering. 95 | """ 96 | if key_added == None: 97 | key_added = 'cell_neighbor' 98 | bin_cor =[str(x[0]) + "-" + str(x[1]) for x in ((adata.obsm['spatial']//bin_size)*bin_size+(bin_size/2)).astype(int)] 99 | adata.obs['bin_cor'] = bin_cor 100 | groups = adata.obs[anno].groupby(adata.obs['bin_cor']) 101 | cellbin_count = pd.DataFrame(index=list(set(bin_cor)), columns = adata.obs[anno].unique()) 102 | for group in groups: 103 | for cell in cellbin_count.columns: 104 | cell_count = group[1][group[1]==cell].shape[0] 105 | cellbin_count.loc[group[0], cell] = cell_count 106 | 107 | from sklearn.neighbors import NearestNeighbors 108 | #construct nearest neighbor graph 109 | nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='ball_tree').fit(cellbin_count) 110 | distances, indices = nbrs.kneighbors(cellbin_count) 111 | X = sp.coo_matrix(([], ([], [])), shape=(indices.shape[0], 1)) 112 | adjacency = _get_sparse_matrix_from_indices_distances_umap(indices, distances, distances.shape[0], n_neighbors) 113 | g = _get_igraph_from_adjacency(adjacency, directed=False) 114 | 115 | #cluster bins by leiden 116 | import leidenalg 117 | partition_type = leidenalg.RBConfigurationVertexPartition 118 | partition_kwargs = dict() 119 | partition_kwargs['weights'] = np.array(g.es['weight']).astype(np.float64) 120 | partition_kwargs['n_iterations'] = -1 121 | partition_kwargs['seed'] = 1 122 | partition_kwargs['resolution_parameter'] = resolution 123 | part = leidenalg.find_partition(g, partition_type, **partition_kwargs) 124 | groups = np.array(part.membership) 125 | cn_cluster = dict(zip(cellbin_count.index.values, groups)) 126 | adata.obs[key_added] = adata.obs['bin_cor'].map(cn_cluster).astype('category') 127 | adata.uns['CN'] = dict() 128 | adata.uns['CN']['params'] = {'n_neighbors': n_neighbors, 129 | 'random_state': random_stat, 130 | } 131 | adata.uns['CN']['leiden_cluster'] = groups 132 | values = cellbin_count.columns 133 | cellbin_count[key_added] = groups 134 | cn_pct = pd.pivot_table(cellbin_count, columns = key_added, values = values, aggfunc=np.sum).T 135 | cn_pct = cn_pct.apply(lambda x: x/x.sum()*100, axis=1) 136 | adata.uns['CN']['cell_composition'] = cn_pct 137 | 138 | #umap 139 | import umap 140 | umap_model = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, n_components=2, random_state=random_stat) 141 | umap_result = umap_model.fit_transform(cellbin_count.values) 142 | 143 | adata.uns['CN']['umap'] = umap_result 144 | 145 | 146 | def _get_sparse_matrix_from_indices_distances_umap( 147 | knn_indices, knn_dists, n_obs, n_neighbors 148 | ): 149 | rows = np.zeros((n_obs * n_neighbors), dtype=np.int64) 150 | cols = np.zeros((n_obs * n_neighbors), dtype=np.int64) 151 | vals = np.zeros((n_obs * n_neighbors), dtype=np.float64) 152 | 153 | for i in range(knn_indices.shape[0]): 154 | for j in range(n_neighbors): 155 | if knn_indices[i, j] == -1: 156 | continue # We didn't get the full knn for i 157 | 
if knn_indices[i, j] == i: 158 | val = 0.0 159 | else: 160 | val = knn_dists[i, j] 161 | 162 | rows[i * n_neighbors + j] = i 163 | cols[i * n_neighbors + j] = knn_indices[i, j] 164 | vals[i * n_neighbors + j] = val 165 | 166 | result = sp.coo_matrix((vals, (rows, cols)), shape=(n_obs, n_obs)) 167 | result.eliminate_zeros() 168 | return result.tocsr() 169 | 170 | def _get_igraph_from_adjacency(adjacency, directed=None): 171 | """Get igraph graph from adjacency matrix.""" 172 | import igraph as ig 173 | 174 | sources, targets = adjacency.nonzero() 175 | weights = adjacency[sources, targets] 176 | if isinstance(weights, np.matrix): 177 | weights = weights.A1 178 | g = ig.Graph(directed=directed) 179 | g.add_vertices(adjacency.shape[0]) # this adds adjacency.shape[0] vertices 180 | g.add_edges(list(zip(sources, targets))) 181 | try: 182 | g.es['weight'] = weights 183 | except KeyError: 184 | pass 185 | if g.vcount() != adjacency.shape[0]: 186 | logg.warning( 187 | f'The constructed graph has only {g.vcount()} nodes. ' 188 | 'Your adjacency matrix contained redundant nodes.' 189 | ) 190 | return g 191 | 192 | -------------------------------------------------------------------------------- /stereosite/cn/deconvolution.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # _*_ coding: utf-8 _*_ 3 | 4 | # @Author: LiuXing liuxing2@genomics.cn 5 | # @Date: 2022-06-15 17:27:47 6 | # @Last Modified by: LiuXing 7 | # @Last Modified time: 2022-06-15 17:27:47 8 | 9 | from email.policy import default 10 | import sys, os 11 | 12 | import scanpy as sc 13 | import anndata 14 | import pandas as pd 15 | import numpy as np 16 | import matplotlib.pyplot as plt 17 | import matplotlib as mpl 18 | 19 | import cell2location 20 | import scvi 21 | from scvi import REGISTRY_KEYS 22 | import matplotlib 23 | from matplotlib import rcParams 24 | rcParams['pdf.fonttype'] = 42 #enables correct plotting of text 25 | import seaborn as sns 26 | from optparse import OptionParser 27 | 28 | class Cell2location(): 29 | """ 30 | Run cell type deconvolution or annotation by cell2location, including reference and ST(Spatial Transcriptomics) data procession. 31 | 32 | Parameters 33 | ---------- 34 | ref_file 35 | reference file path, which can be csv or anndata format. 36 | adata_vis_file 37 | ST file path, which will be deconvoluted or annotated 38 | out_dir 39 | output directory path that will be used to save result. Three subdirectory will be created: 40 | reference_signatures: save processed single-cell reference in csv format if the ref_file was given in anndata format 41 | cell2location_map: save deconvolution model and result in anndata format 42 | figures: save figures generated in the deconvolution or annotation process. 43 | bin_size 44 | the bin size of ST data in h5ad file, default = 50. If the given data was in single-cell resolution, please specify bin_size=1. 45 | the N_cells_per_location will be caculated based on this parameter. N_cell_per_location = (bin_size*500/10000)^2 if bins_size > 1 else 1 46 | gpu 47 | Load model on default GPU if available (if None or True), or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False). 
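        For example, with the default bin_size = 50 the expected cell count per location is int((50*500/10000)**2) = int(2.5**2) = 6,
        and with bin_size = 200 it is int((200*500/10000)**2) = 100 (assuming the coordinate unit is 500 nm and a typical
        cell diameter of roughly 10 um, which is what the formula implies).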
48 | """ 49 | def __init__(self, ref_file: str, 50 | adata_vis_file: str, 51 | out_dir: str = os.getcwd(), 52 | bin_size: int = 50, 53 | gpu = True): 54 | self.ref_file = ref_file 55 | self.adata_vis_file = adata_vis_file 56 | self.results_folder = out_dir 57 | self.N_cells_per_location = int((bin_size*500/10000)**2) if bin_size > 1 else 1 #calculate the cell number per bin based on bin size 58 | if str(gpu).upper() == 'TRUE': 59 | self.gpu = True 60 | elif str(gpu).upper() == 'FALSE': 61 | self.gpu = False 62 | else: 63 | self.gpu = f"cuda:{gpu}" 64 | self.ref_run_name = f'{self.results_folder}/reference_signatures' 65 | self.run_name = f'{self.results_folder}/cell2location_map' 66 | self.figures = f'{self.results_folder}/figures' 67 | os.makedirs(self.ref_run_name, exist_ok=True) 68 | os.makedirs(self.run_name, exist_ok=True) 69 | os.makedirs(self.figures, exist_ok=True) 70 | sc.settings.figdir = self.figures 71 | 72 | def run_deconvolution(self): 73 | """ 74 | run both reference processing and deconvolution with default parameters 75 | """ 76 | if self.ref_file.endswith(".h5ad"): 77 | inf_aver = self.process_ref() 78 | elif self.ref_file.endswith(".csv"): 79 | inf_aver = pd.read_csv(self.ref_file, index_col=0) 80 | adata_vis = self.process_vis(inf_aver) 81 | return adata_vis 82 | 83 | def process_ref(self, 84 | batch_key: str = 'sample', 85 | labels_key: str = 'cell_type', 86 | max_epochs: int = 1500) -> pd.DataFrame: 87 | """ 88 | process single-cell sequence reference data 89 | 90 | Parameters 91 | ---------- 92 | batch_key 93 | specify the key for obtaining batch infomation. For example, if the reference data collected from different sample, reaction or exprement batch, 94 | data source should be markered with batch key, and the program will revise the batch effect. 95 | default = sample 96 | labels_key 97 | key name for getting cell type annotation information, default = cell_type. 98 | max_epochs 99 | maximal epochs for model training, default=1500. 
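        Note: the reference anndata loaded from ref_file is expected to carry obs columns named by batch_key and labels_key,
        since both are passed to cell2location.models.RegressionModel.setup_anndata below.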
100 | 101 | Returns 102 | ---------- 103 | infered average gene expression vector of every cell type 104 | """ 105 | sc.settings.figdir = self.figures 106 | inf_aver_file = f"{self.ref_run_name}/inf_aver.csv" 107 | if (os.path.exists(inf_aver_file)): 108 | inf_aver = pd.read_csv(inf_aver_file, index_col=0) 109 | return inf_aver 110 | elif (os.path.exists(f"{self.ref_run_name}/model.pt")): 111 | adata_ref_file = f"{self.ref_run_name}/sc.h5ad" 112 | adata_ref = sc.read_h5ad(adata_ref_file) 113 | mod = cell2location.models.RegressionModel.load(self.ref_run_name, adata_ref, use_gpu=self.gpu) 114 | else: 115 | adata_ref = sc.read_h5ad(self.ref_file) 116 | adata_ref.obs_names_make_unique() 117 | sc.pp.filter_cells(adata_ref, min_genes = 0) 118 | sc.pp.filter_genes(adata_ref, min_counts = 10) 119 | adata_ref.var['mt'] = adata_ref.var_names.str.startswith(('mt-', 'MT-')) 120 | sc.pp.calculate_qc_metrics(adata_ref, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True) 121 | sc.pl.violin(adata_ref, ['n_genes', 'total_counts', 'pct_counts_mt'], jitter=0.4, multi_panel=True, save='ref_violin.png') 122 | sc.pl.scatter(adata_ref, x='total_counts', y='n_genes', save = 'ref_scatter.png') 123 | 124 | from cell2location.utils.filtering import filter_genes 125 | selected = filter_genes(adata_ref, cell_count_cutoff=5, cell_percentage_cutoff2=0.03, nonz_mean_cutoff=1.12) 126 | plt.savefig(f"{self.figures}/ref_filter_genes.png") 127 | plt.clf() 128 | #filter the object 129 | adata_ref = adata_ref[:, selected].copy() 130 | #prepare anndata for the regression model 131 | cell2location.models.RegressionModel.setup_anndata(adata=adata_ref, 132 | # reaction / sample / batch 133 | batch_key = batch_key, 134 | #batch_key = 'Sample', 135 | # cell type, covariate used fir contructing signatures 136 | labels_key = labels_key, 137 | # multiplicative technical effects (platform, 3' vs 5', donor effect) 138 | #categorical_covariate_keys = ['Experiment'] 139 | #categorical_covariate_keys = ['Method'] 140 | ) 141 | # create the regression model 142 | from cell2location.models import RegressionModel 143 | mod = RegressionModel(adata_ref) 144 | #vies anndata_setup as a sanity check 145 | mod.view_anndata_setup() 146 | mod.train(max_epochs=max_epochs, use_gpu=self.gpu) 147 | 148 | # In this section, we export the estimated cell abundance (summary of the posterior distribution). 
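        # (for the reference regression model this stores the per-cluster expression signatures,
        #  e.g. means_per_cluster_mu_fg in adata_ref.varm or .var, which are extracted into inf_aver further down)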
149 | adata_ref = mod.export_posterior( 150 | adata_ref, sample_kwargs={'num_samples':1000, 'batch_size': 2500, 'use_gpu': self.gpu} 151 | ) 152 | 153 | #Save model 154 | mod.save(f"{self.ref_run_name}", overwrite=True) 155 | 156 | #Save anndata object with results 157 | adata_file = f"{self.ref_run_name}/sc.h5ad" 158 | adata_ref.write(adata_file) 159 | mod.plot_history(20) 160 | plt.savefig(f"{self.figures}/ref_train_history.png") 161 | plt.clf() 162 | inf_aver = mod.samples[f"post_sample_means"]["per_cluster_mu_fg"].T 163 | if "detection_y_c" in list(mod.samples[f"post_sample_means"].keys()): 164 | inf_aver = inf_aver * mod.samples[f"post_sample_means"]["detection_y_c"].mean() 165 | aver = mod._compute_cluster_averages(key=REGISTRY_KEYS.LABELS_KEY) 166 | aver = aver[mod.factor_names_] 167 | plt.hist2d(np.log10(aver.values.flatten() + 1), 168 | np.log10(inf_aver.flatten() + 1), 169 | bins = 50, 170 | norm = matplotlib.colors.LogNorm(), 171 | ) 172 | plt.xlabel("Mean expression for every gene in every cluster") 173 | plt.ylabel("Estimated expression for every gene in every cluster") 174 | plt.savefig(f"{self.figures}/ref_train_QC.png") 175 | plt.clf() 176 | #export estimated expression in each cluster 177 | if 'means_per_cluster_mu_fg' in adata_ref.varm.keys(): 178 | inf_aver = adata_ref.varm['means_per_cluster_mu_fg'][[f'means_per_cluster_mu_fg_{i}' 179 | for i in adata_ref.uns['mod']['factor_names']]].copy() 180 | else: 181 | inf_aver = adata_ref.var[[f'means_per_cluster_mu_fg_{i}' 182 | for i in adata_ref.uns['mod']['factor_names']]].copy() 183 | inf_aver.columns = adata_ref.uns['mod']['factor_names'] 184 | inf_aver.to_csv(inf_aver_file) 185 | return inf_aver 186 | 187 | def process_vis(self, 188 | inf_aver: pd.DataFrame, 189 | max_epochs: int = 5000, 190 | batch_size: int = 90000, 191 | anno: str = 'cell2loc_anno', 192 | spot_size: int = 70) -> anndata: 193 | """ 194 | deconvolute ST data based on the processed SC reference 195 | 196 | Parameters 197 | ---------- 198 | inf_aver 199 | infered average gene expression vector of every cell type 200 | max_epochs 201 | maximal epochs for model training, default=5000. 202 | batch_size 203 | batch size that determines the amount of data loaded into the gpu memory, default=90000. 204 | If the running report outOfMemory error, reduce the batch_size can help to resolve. 205 | anno 206 | label key that will be used to store annotation result, default=cell2loc_anno. 207 | spot_size 208 | specify spot size when draw bins or cells in space. 209 | Returns 210 | ----------- 211 | anndata with deconvolution and annotation result. 
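        Example
        -----------
        A minimal sketch mirroring run_deconvolution; the file paths are placeholders:

        >>> cell2loc = Cell2location("to/path/inf_aver.csv", "to/path/sn_bin200.h5ad", out_dir="to/out", bin_size=200, gpu=0)
        >>> inf_aver = pd.read_csv("to/path/inf_aver.csv", index_col=0)
        >>> adata_vis = cell2loc.process_vis(inf_aver)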
212 | """ 213 | sc.settings.figdir = self.figures 214 | # find shared genes and subset both anndata and reference signatures 215 | adata_vis = anndata.read(self.adata_vis_file) 216 | if adata_vis.raw != None: 217 | adata_vis = adata_vis.raw.to_adata() 218 | adata_vis.raw = adata_vis 219 | #find mitochondria-encoded (MT) genes 220 | adata_vis.var['MT_gene'] = adata_vis.var_names.str.startswith(('mt-', 'MT-')) 221 | 222 | # remove MT genes for spatial mapping (keeping their counts in the object) 223 | adata_vis.obsm['MT'] = adata_vis[:, adata_vis.var['MT_gene'].values].X.toarray() 224 | adata_vis = adata_vis[:, ~adata_vis.var['MT_gene'].values] 225 | intersect = np.intersect1d(adata_vis.var_names, inf_aver.index) 226 | adata_vis = adata_vis[:, intersect].copy() 227 | inf_aver = inf_aver.loc[intersect, :].copy() 228 | 229 | #prepare anndata for cell2location model 230 | cell2location.models.Cell2location.setup_anndata(adata=adata_vis) 231 | 232 | # create and train the model 233 | mod = cell2location.models.Cell2location( 234 | adata_vis, cell_state_df=inf_aver, 235 | # the expected average cell abundance: tissue-dependent 236 | # hyper-prior which can be estimated from paired histology: 237 | N_cells_per_location=self.N_cells_per_location, 238 | # hyperparameter controlling normalisation of 239 | # within-experiment variation in RNA detection: 240 | detection_alpha=20 241 | ) 242 | mod.view_anndata_setup() 243 | 244 | if mod.adata.n_obs < batch_size: 245 | batch_size = None 246 | mod.train(max_epochs=max_epochs, 247 | # train using full data (batch_size=None) 248 | batch_size=batch_size, 249 | # use all data points in training because 250 | # we need to estimate cell abundance at all locations 251 | train_size=1, 252 | use_gpu=self.gpu) 253 | 254 | # In this section, we export the estimated cell abundance (summary of the posterior distribution). 
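        # (the exported posterior adds the cell abundance summaries, e.g. adata_vis.obsm['q05_cell_abundance_w_sf'],
        #  the 5% quantile that is used below for plotting and per-spot annotation)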
255 | adata_vis = mod.export_posterior( 256 | adata_vis, sample_kwargs={'num_samples': 1000, 'batch_size': 10000, 'use_gpu': self.gpu} 257 | ) 258 | 259 | # Save model 260 | mod.save(f"{self.run_name}", overwrite=True) 261 | mod.plot_history(20) 262 | plt.legend(labels=['full data training']) 263 | plt.savefig(f"{self.figures}/vis_train_history.png") 264 | plt.clf() 265 | 266 | use_n_obs = 1000 267 | ind_x = np.random.choice(mod.adata_manager.adata.n_obs, np.min((use_n_obs, mod.adata.n_obs)), replace=False) 268 | mod.expected_nb_param = mod.module.model.compute_expected( 269 | mod.samples[f"post_sample_means"], mod.adata_manager, ind_x=ind_x 270 | ) 271 | x_data = mod.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)[ind_x, :] 272 | x_data = np.asarray(x_data.toarray()) 273 | mod.plot_posterior_mu_vs_data(mod.expected_nb_param["mu"], x_data) 274 | plt.savefig(f"{self.figures}/vis_QC.png") 275 | plt.clf() 276 | 277 | # Save anndata object with results 278 | adata_file = f"{self.run_name}/sp.h5ad" 279 | #adata_vis.write(adata_file) 280 | # add 5% quantile, representing confident cell abundance, 'at least this amount is present', 281 | # to adata.obs with nice names for plotting 282 | adata_vis.obs[adata_vis.uns['mod']['factor_names']] = adata_vis.obsm['q05_cell_abundance_w_sf'] 283 | cellList=list(set(inf_aver.columns) & set(adata_vis.obs.columns)) 284 | with mpl.rc_context({'axes.facecolor': 'black', 'figure.figsize': [5, 5]}): 285 | sc.pl.spatial(adata_vis, img_key="hires", color=cellList, spot_size=spot_size, vmin=0, vmax='p99.2', 286 | cmap='magma', save="cell_abundance.png") 287 | 288 | adata_vis.obs[anno] = adata_vis.obs[cellList].idxmax(axis=1) 289 | 290 | # Save anndata object with results 291 | self._plotCells(anno, adata_vis, spot_size=spot_size) 292 | #self._plotMarkerGenes(anno, adata_vis) 293 | adata_vis.write(adata_file) 294 | return adata_vis 295 | 296 | def _plotCells(self, 297 | anno: str, 298 | adata_vis: anndata, 299 | spot_size: int = 70): 300 | """ 301 | draw the annotation result 302 | """ 303 | import math 304 | from tqdm import tqdm 305 | rcParams['figure.figsize'] = 5, 5 306 | sc.pl.spatial(adata_vis, img_key="hires", color=anno, spot_size=spot_size, save=f"{anno}.png") 307 | 308 | cellsCount = adata_vis.obs[anno].value_counts() 309 | cellsRate = cellsCount/cellsCount.sum()*100 310 | cellsdf = pd.concat([cellsCount, cellsRate], axis=1) 311 | cellsdf.columns = ['count', 'rate'] 312 | colors_dict = dict(zip(adata_vis.obs[anno].cat.categories, adata_vis.uns[f'{anno}_colors'])) 313 | cellsdf['color'] = cellsdf.index.map(colors_dict) 314 | cellsdf.to_csv(f"{self.run_name}/{anno}_cell_count.tsv", sep="\t") 315 | plt.figure(figsize=(8, 8)) 316 | cellsdf['count'].plot(kind = 'barh', color=cellsdf['color']) 317 | i = 0 318 | sum = cellsdf['count'].sum() 319 | for _, v in cellsdf['rate'].items(): 320 | plt.text(sum*(v+2)/100, i, '%.2f' % v, ha='center', va='bottom', fontsize=11) 321 | i+=1 322 | plt.savefig(f"{self.figures}/{anno}_cell_count.png", bbox_inches='tight') 323 | plt.clf() 324 | plotcol = 4 325 | cells = adata_vis.obs[anno].unique() 326 | plotrow = math.ceil(len(cells)/plotcol) 327 | figSize = (16, plotrow*4) 328 | fig = plt.figure(figsize=figSize,dpi=100) 329 | for j in tqdm(range(len(cells))): 330 | cell = cells[j] 331 | i = j+1 332 | row = int(i/plotcol) 333 | col = i - row*plotcol 334 | ax = plt.subplot(plotrow, plotcol, i) 335 | sc.pl.spatial(adata_vis, img_key="hires", color=anno, groups = [cell], spot_size=spot_size, show=False, ax = ax, title ="{0} 
({1})".format(cell, cellsCount.loc[cell]), legend_loc=None) 336 | ax.set_xlabel("") 337 | ax.set_ylabel("") 338 | plt.savefig(f"{self.figures}/{anno}_split.png", bbox_inches='tight') 339 | plt.clf() 340 | 341 | def main(): 342 | """ 343 | This program can be used to process scRNAseq reference and annotate ST(spatial transcriptomics) data 344 | %prog [options] 345 | """ 346 | parser = OptionParser(main.__doc__) 347 | parser.add_option("-r", "--reference", action = "store", type = "str", dest = "reference", help = "reference file path, can be h5ad or csv format file.") 348 | parser.add_option("-i", "--vis", action = "store", type = "str", dest = "vis", help = "input stereo-seq data in h5ad format.") 349 | parser.add_option("-o", "--outDir", action = "store", type = "str", dest = "outDir", help = "output directory path.") 350 | parser.add_option("-g", "--gpu", action = "store", type = "int", default = 1, dest = "gpu", help = "give cuda gpu name that will be used to run this program. default=1") 351 | parser.add_option("--bin_size", action="store", type = int, default = 50, dest = "bin_size", help = "bin size of the given ST data. default=50") 352 | 353 | opts, args = parser.parse_args() 354 | 355 | if (opts.reference == None or opts.vis == None or opts.outDir == None): 356 | sys.exit(not parser.print_help()) 357 | 358 | cell2location = Cell2location(opts.reference, opts.vis, opts.outDir, bin_size=opts.bin_size, gpu = opts.gpu) 359 | cell2location.run_deconvolution() 360 | 361 | if __name__ == "__main__": 362 | main() 363 | -------------------------------------------------------------------------------- /stereosite/datasets/STRINGdb/Hsa/9606.protein.aliases.v11.5.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STOmics/StereoSiTE/58d1daf7bec2db10cc4d443154eeee5618377fee/stereosite/datasets/STRINGdb/Hsa/9606.protein.aliases.v11.5.txt.gz -------------------------------------------------------------------------------- /stereosite/datasets/STRINGdb/Hsa/9606.protein.info.v11.5.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STOmics/StereoSiTE/58d1daf7bec2db10cc4d443154eeee5618377fee/stereosite/datasets/STRINGdb/Hsa/9606.protein.info.v11.5.txt.gz -------------------------------------------------------------------------------- /stereosite/datasets/STRINGdb/Mmu/10090.protein.aliases.v11.5.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STOmics/StereoSiTE/58d1daf7bec2db10cc4d443154eeee5618377fee/stereosite/datasets/STRINGdb/Mmu/10090.protein.aliases.v11.5.txt.gz -------------------------------------------------------------------------------- /stereosite/datasets/STRINGdb/Mmu/10090.protein.info.v11.5.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STOmics/StereoSiTE/58d1daf7bec2db10cc4d443154eeee5618377fee/stereosite/datasets/STRINGdb/Mmu/10090.protein.info.v11.5.txt.gz -------------------------------------------------------------------------------- /stereosite/datasets/STRINGdb/score_info.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STOmics/StereoSiTE/58d1daf7bec2db10cc4d443154eeee5618377fee/stereosite/datasets/STRINGdb/score_info.txt.gz -------------------------------------------------------------------------------- 
/stereosite/degene.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | 4 | """ 5 | @File : de_gene.py 6 | @Description : de_gene.py 7 | @Author : Liuchuandong liuchuandong@genomics.cn 8 | @Date : 2023/08/04 17:08:24 9 | """ 10 | import decoupler as dc 11 | import scanpy as sc 12 | import anndata 13 | import sys, os 14 | import numpy as np 15 | import pandas as pd 16 | from pydeseq2.dds import DeseqDataSet 17 | from pydeseq2.ds import DeseqStats 18 | 19 | # Only needed for visualization: 20 | import matplotlib.pyplot as plt 21 | import seaborn as sns 22 | 23 | def deseq2(adata,sample_col='sample',groups_col='cell_neighbor', 24 | contrast=['5','Others'],batch=None,save=None): 25 | ''' 26 | Perform differential gene expression analysis in pseudo-bulk manner by using 'decoupler' and 'pydeseq2' modules. 27 | Here just give a one line method in StereoSite workflow. 28 | 29 | Parameters 30 | ---------- 31 | adata 32 | Anndata objects, make sure .X is raw counts. 33 | sample_col 34 | Column of obs contains samples names. 35 | groups_col 36 | Column of obs contains groups names. 37 | contrast 38 | list contains group names you want to compare,['5','Others'] means CN5 vs Others_CN, while ['5','1'] means CN5 vs CN1 39 | batch 40 | Column of obs contains batch names 41 | save 42 | Path to where to save the volcano plot. Infer the filetype if ending on {`.pdf`, `.png`, `.svg`}. 43 | ---------- 44 | 45 | Returns 46 | dataframe of deseq2 result 47 | 48 | ''' 49 | # get pseudobulk 50 | pdata = dc.get_pseudobulk( 51 | adata, 52 | sample_col=sample_col, 53 | groups_col=groups_col, 54 | mode='sum', 55 | min_cells=0, 56 | min_counts=0) 57 | # Filter genes 58 | genes = dc.filter_by_expr(pdata, group=groups_col, min_count=0, min_total_count=10) 59 | pdata = pdata[:, genes].copy() 60 | if 'Others' in contrast: 61 | pdata.obs['deGroup'] = [contrast[0] if i==contrast[0] else 'Others' for i in pdata.obs[groups_col]] 62 | else: 63 | pdata.obs['deGroup'] = pdata.obs[groups_col] 64 | # Build DESeq2 object 65 | if batch is None: 66 | design_factors = ['deGroup'] 67 | else: 68 | design_factors=['deGroup',batch] 69 | dds = DeseqDataSet( 70 | adata=pdata, 71 | design_factors=design_factors, 72 | ref_level=['deGroup', contrast[1]], 73 | refit_cooks=True, 74 | n_cpus=8, 75 | ) 76 | # Compute LFCs 77 | dds.deseq2() 78 | # Extract contrast between CN5 vs Others 79 | stat_res = DeseqStats(dds, contrast=["deGroup", contrast[0], contrast[1]], n_cpus=8) 80 | # Compute Wald test 81 | stat_res.summary() 82 | # Shrink LFCs 83 | stat_res.lfc_shrink(coeff='deGroup_'+contrast[0]+'_vs_'+contrast[1]) 84 | # Extract results 85 | results_df = stat_res.results_df 86 | dc.plot_volcano_df(results_df, x='log2FoldChange', y='padj', top=20,save=save) 87 | return(results_df) 88 | -------------------------------------------------------------------------------- /stereosite/plot/__init__.py: -------------------------------------------------------------------------------- 1 | """ The ploting and mask coloring module """ -------------------------------------------------------------------------------- /stereosite/plot/cellneighbor.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python3 2 | # _*_ coding: utf-8 _*_ 3 | 4 | #@Author: LiuXing liuxing2@genomics.cn 5 | #@Date: 2023-07-13 14:30:56 6 | #@Last Modified by: LiuXing 7 | #@Last Modified time: 2023-07-13 14:30:56 8 | 9 | import os, sys 10 | import numpy as np 11 | import pandas as pd 12 | import scanpy as sc 13 | import anndata 14 | import seaborn as sns 15 | import matplotlib.pyplot as plt 16 | 17 | def umap(adata: anndata, 18 | size: int = 10, 19 | color: str = 'cell_neighbor', 20 | legend_loc: str = 'on data', 21 | legend_fontsize: int = 20, 22 | figsize: tuple = (6, 6) 23 | ): 24 | if 'umap' not in adata.uns['CN']: 25 | with plt.rc_context({'axes.facecolor':'white','figure.figsize': figsize}): 26 | sc.pl.umap(adata, 27 | color=[color], 28 | size=size, 29 | ncols = 2, 30 | legend_loc=legend_loc, 31 | legend_fontsize=legend_fontsize) 32 | else: 33 | if f"{color}_colors" in adata.uns.keys(): 34 | palette = adata.uns[f'{color}_colors'] 35 | else: 36 | palette = 'tab20' 37 | umap_result = adata.uns['CN']['umap'] 38 | leiden_cluster = adata.uns['CN']['leiden_cluster'] 39 | fig, ax = plt.subplots(figsize=figsize) 40 | sns.scatterplot(x = umap_result[:, 0], y = umap_result[:, 1], hue = leiden_cluster, palette = palette, s = 10, ax=ax) 41 | ax.set_xlabel('UMAP1') 42 | ax.set_ylabel('UMAP2') 43 | ax.set_title('Cellular_Neighborhood') 44 | plt.show() 45 | 46 | def spatial(adata: anndata, 47 | spot_size: int = 100, 48 | figsize: list = [6, 6] 49 | ): 50 | with plt.rc_context({'axes.facecolor':'white','figure.figsize': figsize}): 51 | sc.pl.spatial(adata, 52 | color=['cell_neighbor'], 53 | size=1.3, 54 | img_key='hires', 55 | alpha=1, 56 | spot_size=spot_size) 57 | 58 | def heatmap(adata: anndata, 59 | cmap: str = 'RdYlBu_r', 60 | figsize: tuple = (5, 3.5), 61 | row_cluster: bool = False, 62 | col_cluster: bool = False, 63 | z_score = None, 64 | ): 65 | cn_pct = adata.uns['CN']['cell_composition'] 66 | sns.clustermap(cn_pct, 67 | cmap=cmap, 68 | figsize = figsize, 69 | row_cluster=row_cluster, 70 | col_cluster=col_cluster, 71 | z_score=z_score) -------------------------------------------------------------------------------- /stereosite/plot/intensity.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # _*_ coding: utf-8 _*_ 3 | 4 | #@Author: LiuXing liuxing2@genomics.cn 5 | #@Date: 2023-07-12 15:29:29 6 | #@Last Modified by: LiuXing 7 | #@Last Modified time: 2023-07-12 15:29:29 8 | 9 | import os, sys 10 | import anndata 11 | import scanpy as sc 12 | import squidpy as sq 13 | import math 14 | import matplotlib.pyplot as plt 15 | from scipy import sparse 16 | import numpy as np 17 | import pandas as pd 18 | from tqdm.notebook import tqdm 19 | 20 | def _intensity_show(LRadata: anndata, 21 | cells: list, 22 | genes: list, 23 | l: np.ndarray, 24 | r: np.ndarray, 25 | key: str = 'expression_spatial_distances', 26 | anno = "cell2loc_anno", 27 | alpha_g:float = 0.5, 28 | alpha_i:float = 0.4, 29 | spot_size:float = 2, 30 | figsize:tuple = (4, 5), 31 | save: str = None): 32 | """ 33 | draw the spatial cell interaction intensity in space. 
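    l and r are the per-spot ligand/receptor expression vectors already masked by the sender/receiver cell types,
    and key names the adata.obsp distances matrix holding the computed intensities (see intensity_insitu, which prepares these inputs).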
34 | """ 35 | df = pd.DataFrame() 36 | df['x'] = LRadata.obsm["spatial"][:,0].astype(int) 37 | df['y'] = LRadata.obsm["spatial"][:,1].astype(int) 38 | df['intensity'] = LRadata.obsp[key].sum(axis=1).A1*(LRadata.obs[anno].isin(cells).any()) 39 | df[genes[0]] = l 40 | df[genes[1]] = r 41 | colors = ['tab:blue', 'tab:orange', 'tab:red'] 42 | fig, ax = plt.subplots(figsize=figsize, dpi=200) 43 | i=0 44 | scatters = [] 45 | labels = [] 46 | for gene in genes: 47 | df1 = df[df[gene]>0] 48 | scatter = ax.scatter(x = df1['x'], y = df1['y'], c = colors[i], edgecolors='none', alpha = alpha_g, s = spot_size) 49 | scatters.append(scatter) 50 | labels.append(gene) 51 | i+=1 52 | if (df.shape[0] > 0): 53 | scatter = ax.scatter(x = df['x'], y = df['y'], c = colors[2], s = df['intensity']*spot_size, edgecolors='none', alpha = alpha_i) 54 | scatters.append(scatter) 55 | labels.append('intensity') 56 | handles, labels1 = scatter.legend_elements(prop="sizes", alpha=0.6) 57 | legend1 = ax.legend(handles, labels1, bbox_to_anchor=[1.25, 0.7], fontsize=3, title = 'intensity') 58 | ax.add_artist(legend1) 59 | else: 60 | return 0 61 | ax.set_ylim(bottom = df['y'].max(), top = 0) 62 | ax.legend(scatters, labels, prop = {'size': 4}, bbox_to_anchor=[1.25, 0.8], fontsize=4) 63 | plt.gca().set_aspect(1) 64 | plt.axis('off') 65 | plt.title(f"{cells[0]} | {cells[1]}\n({genes[0]} | {genes[1]})") 66 | if save!=None: 67 | plt.savefig(save) 68 | 69 | def intensity_insitu(adata: anndata, 70 | cells: list, 71 | genes: list, 72 | anno: str= 'cell2loc_anno', 73 | radius: float = 0.0, 74 | distance_coefficient:float = 0.0, 75 | connectivities_key: str = "spatial_distances", 76 | complex_process_model: str = 'mean', 77 | alpha_g:float = 0.5, 78 | alpha_i:float = 0.4, 79 | spot_size:float = 2, 80 | figsize:tuple = (4, 5), 81 | save: str = None 82 | ) -> int: 83 | """ 84 | Calculate the spatial cell interaction intensity between specified cells and ligand receptor genes. 85 | 86 | Parameters 87 | ---------- 88 | adata 89 | anndata 90 | cells 91 | list contains sender and receiver cell type. [sender_cell, receiver_cell] 92 | genes 93 | list contains ligand and receptor genes. [ligand, receptor] 94 | anno 95 | annotation key, default=cell2loc_anno. 96 | radius 97 | radius threshold when constructing nearest neighbor graph. Only be used when the neighbor graph doesn't exist. 98 | distance_coefficient 99 | Consider the distance as one of the factor that influence the interaction intensity using the exponential decay formular: C=C0*e^(-k*d). 100 | The parameter defines the k value in the formular. Default=0, means distance would not influence the interaction intensity. 101 | connectivities_key 102 | obtain the constructed nearest neighbor graph in adata.obsp[connectivities_key]. If this doesn't exist, construct 103 | the neighbor graph with specified radius 104 | complex_process_model 105 | determine how to deal with the complexed ligand and receptor which contain multiple subunits. There are two options: mean, min. 
106 | mean: calculate the mean expression of all subunits to represent the complex 107 | min: pick the minimal expression of all subunits to represent the complex 108 | 109 | Returns 110 | ---------- 111 | intensity value of the entire slide 112 | """ 113 | exp_key = "expression" 114 | if "_" in genes[0]: 115 | if complex_process_model == 'min': 116 | l = (adata.obs[anno]==cells[0])*(adata[:,genes[0].split("_")].X.min(axis=1).toarray()[:,0]) 117 | elif complex_process_model == 'mean': 118 | l = (adata.obs[anno]==cells[0])*(adata[:,genes[0].split("_")].X.mean(axis=1).A1) 119 | else: 120 | raise Exception(f"complex_process_model should be mean or min, but got {complex_process_model}.") 121 | else: 122 | l = (adata.obs[anno]==cells[0])*(adata[:,genes[0]].X.sum(axis=1).A1) 123 | if "_" in genes[1]: 124 | if complex_process_model == 'min': 125 | r = (adata.obs[anno]==cells[1])*(adata[:,genes[1].split("_")].X.min(axis=1).toarray()[:,0]) 126 | elif complex_process_model == 'mean': 127 | r = (adata.obs['cell2loc_anno']==cells[1])*(adata[:,genes[1].split("_")].X.mean(axis=1).A1) 128 | else: 129 | raise Exception(f"complex_process_model should be mean or min, but got {complex_process_model}.") 130 | else: 131 | r = (adata.obs[anno]==cells[1])*(adata[:,genes[1]].X.sum(axis=1).A1) 132 | 133 | if connectivities_key in adata.obsp.keys() and radius == 0: 134 | connect_matrix = adata.obsp[connectivities_key] 135 | elif radius > 10: 136 | key_added = f"{radius}um" 137 | if f"{key_added}_distances" in adata.obsp.keys(): 138 | connect_matrix = adata.obsp[f"{key_added}_distances"] 139 | else: 140 | sq.gr.spatial_neighbors(adata, radius=radius*2, coord_type="generic", key_added=key_added) 141 | connect_matrix = adata.obsp[f"{key_added}_distances"] 142 | connectivities_key = f"{key_added}_distances" 143 | else: 144 | raise Exception(f"The distances_key ({connectivities_key}) dosn't exist in adata.obsp, and radius has not be specified with a value >= 10") 145 | 146 | l = l.values 147 | r = r.values 148 | l_rows = np.where(l > 0)[0] 149 | r_cols = np.where(r > 0)[0] 150 | sub_connect_matrix=connect_matrix[l_rows,:][:,r_cols].todense() 151 | dst = np.where(sub_connect_matrix>0) 152 | distances = sub_connect_matrix[dst] 153 | exps = l[l_rows[dst[0]]]*[math.exp(-distance_coefficient*d) for d in distances.A1] + r[r_cols[dst[1]]] 154 | 155 | spatial_exp = sparse.csr_matrix((exps, (l_rows[dst[0]], r_cols[dst[1]])), shape=connect_matrix.shape, dtype=int) 156 | exp_connectivities_key = f"{exp_key}_{connectivities_key}" 157 | adata.obsp[exp_connectivities_key] = spatial_exp 158 | neighbors_key = exp_connectivities_key.replace("distances", "neighbors") 159 | params = adata.uns[connectivities_key.replace("distances", "neighbors")] 160 | params['weight'] = 'expression' 161 | adata.uns[neighbors_key] = {'connectivities_key': connectivities_key.replace("distances", "connectivities"), 162 | 'distances_key': exp_connectivities_key, 163 | 'params': params, 164 | } 165 | 166 | _intensity_show(adata, cells, genes, l, r, key=exp_connectivities_key, anno=anno, alpha_g=alpha_g, alpha_i=alpha_i, spot_size=spot_size, figsize=figsize, save=save) 167 | return spatial_exp.sum() 168 | 169 | def intensities_with_radius(adata, pairs = None): 170 | """ 171 | draw the line plot with radius as x and intensity per area as y 172 | 173 | Parameters 174 | --------------- 175 | adata 176 | anndata 177 | pairs 178 | pair list that will be drawed in the plot 179 | for example Teff|Microphage(APP|CD74) 180 | """ 181 | try: 182 | plot_df = 
adata.uns['intensities_with_radius'] 183 | except Exception: 184 | print("there must be intensities_with_radius in adata.uns, please run intensity.intensities_with_radius before draw this plot.") 185 | if pairs == None: 186 | columns = plot_df.columns.values.tolist() 187 | columns.remove('radius') 188 | else: 189 | columns = pairs 190 | 191 | import matplotlib.colors as mcolors 192 | colors = list(mcolors.TABLEAU_COLORS.keys()) 193 | i = 0 194 | for column in columns: 195 | x = plot_df['radius'] 196 | y = plot_df[column]/(3.14*(plot_df['radius']**2)) 197 | plt.plot(x, y, 198 | linestyle="-", linewidth=2, color=colors[i], 199 | marker='o', markersize=4, markeredgecolor='black', markerfacecolor=colors[i], 200 | label=column, 201 | ) 202 | i+=1 203 | plt.xlabel('radius (µm)') 204 | plt.ylabel('intensity / area') 205 | plt.legend() 206 | 207 | -------------------------------------------------------------------------------- /stereosite/plot/mask.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # _*_ coding: utf-8 _*_ 3 | 4 | #@Author: LiuXing liuxing2@genomics.cn 5 | #@Date: 2023-07-11 16:04:39 6 | #@Last Modified by: LiuXing 7 | #@Last Modified time: 2023-07-11 16:04:39 8 | 9 | import numpy as np 10 | import tifffile 11 | import cv2 12 | import os, sys 13 | import anndata 14 | 15 | def _color_change(value): 16 | digit = list(map(str, range(10))) + list("abcdef") 17 | a1 = digit.index(value[1]) * 16 + digit.index(value[2]) 18 | a2 = digit.index(value[3]) * 16 + digit.index(value[4]) 19 | a3 = digit.index(value[5]) * 16 + digit.index(value[6]) 20 | return (a3, a2, a1) 21 | 22 | def _get_color_list(my_color): 23 | if (my_color == None): 24 | my_color = ['#0343df', '#f97306', '#15b01a', '#e50000', '#7e1e9c', '#653700', '#ff81c0', '#929591', 25 | '#6e750e', '#00ffff', '#ff796c', '#06c2ac', '#75bbfd', '#01ff07', '#cb416b', '#bf77f6', '#ceb301', 26 | '#137e6d', '#516572', '#dbb40c', '#d0fefe', '#9ffeb0', '#fdaa48', '#ffcfdc', '#ffffc2', '#ac9362', 27 | '#7a9703', '#96ae8d', '#b66a50', '#411900'] 28 | rgb_mycolor = [] 29 | for color in my_color: 30 | rgb_mycolor.append(list(_color_change(color))) 31 | return rgb_mycolor 32 | 33 | def mask_coloring(adata: anndata, mask_file: str, 34 | anno: str='cell2loc_anno', 35 | save: str=None, save_legend: str=None): 36 | """ 37 | Paint each cell in the mask with color corresponding to the palette of cell type in anndata 38 | 39 | Parameters 40 | ----------- 41 | adata 42 | anndata file with cell type annotation 43 | mask_file 44 | cell mask file path 45 | anno 46 | annotation key in adata. default=cell2loc_anno 47 | save 48 | if save was specified with a file path, the colored mask will be writen to it. The file suffix should be .png|.jpg 49 | save_legend 50 | if save_legend was specified with a file path, the corresponding legend of colored mask will be writen to it. 
The file suffix should be .png|.jpg 51 | 52 | Returns 53 | ---------- 54 | Tuple 55 | colored cell mask: ndarray 56 | legend: ndarray 57 | """ 58 | maskImg = tifffile.imread(mask_file) 59 | if (maskImg.max() == 1): 60 | _, labels = cv2.connectedComponents(maskImg) 61 | else: 62 | labels = maskImg 63 | dst = np.nonzero(labels) 64 | paletteKey = f"{anno}_colors" 65 | my_color = list(adata.uns[paletteKey]) if paletteKey in adata.uns.keys() else None 66 | rgb_color = _get_color_list(my_color) 67 | clusterDict = dict(zip(adata.obs.index.astype(int), adata.obs[anno].cat.codes)) 68 | img_colors = [rgb_color[int(clusterDict[x])] if x in clusterDict.keys() else [255, 255, 255] for x in labels[dst]] 69 | new_img = np.zeros((maskImg.shape[0], maskImg.shape[1], 3), dtype = np.int8) 70 | new_img[dst] = img_colors 71 | 72 | #get legend 73 | cells = adata.obs[anno].cat.categories 74 | width = 10000 75 | high = 1000 76 | legend = np.zeros([high*len(cells), width,3]) 77 | #legend.fill(0) 78 | font = cv2.FONT_HERSHEY_SIMPLEX 79 | fontScale = 20 80 | # Line thickness of 40 px 81 | thickness = 50 82 | pointSize = 100 83 | for i in range(len(cells)): 84 | color = rgb_color[i] 85 | cell = cells[i] 86 | point = (int(width/8), int((i+0.4)*high)) 87 | cv2.circle(legend, point, pointSize, color, pointSize*2) 88 | org = (int(width/5), int((i+0.6)*high)) 89 | legend = cv2.putText(legend, cell, org, font, fontScale, (255, 255, 255), thickness, cv2.LINE_AA) 90 | 91 | if save != None: 92 | out_dir = os.path.dirname(save) 93 | if not os.path.exists(out_dir): 94 | os.makedirs(out_dir, exist_ok=True) 95 | cv2.imwrite(save, new_img) 96 | if save_legend != None: 97 | out_dir = os.path.dirname(save_legend) 98 | if not os.path.exists(out_dir): 99 | os.makedirs(out_dir) 100 | cv2.imwrite(save_legend, legend) 101 | return new_img, legend -------------------------------------------------------------------------------- /stereosite/plot/net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | 4 | """ 5 | @File : get_hub_net.py 6 | @Description : get_hub_net.py 7 | @Author : Liuchuandong liuchuandong@genomics.cn 8 | @Date : 2023/08/01 11:42:43 9 | """ 10 | 11 | 12 | 13 | import os, sys 14 | from optparse import OptionParser 15 | import numpy as np 16 | import pandas as pd 17 | import networkx as nx 18 | import matplotlib.pyplot as plt 19 | import matplotlib 20 | from matplotlib.colors import ListedColormap, LinearSegmentedColormap 21 | 22 | POS_SEED=99 23 | HUB_NODE_SIZE = 366 24 | BASIC_NODE_SIZE = 100 25 | PPI_outdir = 'PPI' 26 | os.makedirs(PPI_outdir, exist_ok=True) 27 | TF_outdir = 'TF' 28 | os.makedirs(TF_outdir, exist_ok=True) 29 | 30 | def _edge_colors(graph,cmap='binary',min_weight=0.4,max_weight=1,unit=0.001): 31 | ''' 32 | set edges color by STRING combined_score 33 | ''' 34 | norm=matplotlib.colors.Normalize(vmin=min_weight, vmax=max_weight) 35 | scale_map = matplotlib.cm.ScalarMappable(cmap=cmap, norm=norm) 36 | colors = [] 37 | for i in graph.edges(): 38 | weight = graph.edges[i]['combined_score'] 39 | colors.append( matplotlib.colors.rgb2hex(scale_map.to_rgba(weight*unit)) ) 40 | return({'color_list':colors,'scale_colormap':scale_map}) 41 | 42 | def _node_colors(graph,logFC_df,name,cmap='Reds',min_weight=None,max_weight=None,unit=0.1): 43 | ''' 44 | set node colors by logFC 45 | logFC_df 46 | Dataframe contains gene's logFoldChange column,and node names as row index 47 | name 48 | Column name in logFC_df with gene's 
logFoldChange 49 | ''' 50 | logFC_df = logFC_df.loc[list(graph.nodes()),:] 51 | for i in graph.nodes(): 52 | graph.nodes[i]['logFC']=logFC_df.loc[i,name] 53 | if min_weight==None: 54 | min_weight= round(min(logFC_df[name]),1) 55 | if max_weight==None: 56 | max_weight= round(max(logFC_df[name]),1) 57 | 58 | norm=matplotlib.colors.Normalize(vmin=min_weight, vmax=max_weight) 59 | scale_map = matplotlib.cm.ScalarMappable(cmap=cmap, norm=norm) 60 | colors = [] 61 | for i in graph.nodes(): 62 | fc = round(graph.nodes[i]['logFC'],1) 63 | colors.append( matplotlib.colors.rgb2hex(scale_map.to_rgba(fc)) ) 64 | return({'color_list':colors,'scale_colormap':scale_map}) 65 | 66 | 67 | def ppi_hub_net(hub_net,hub_gene,logFC_df,name,save=PPI_outdir+'hub_net_cluster1.pdf',figsize=(6,6.8),hspace = 0.4, 68 | edge_min_weight=0.8,edge_max_weight=1.0,edge_unit=0.001,edge_cmap='binary', 69 | node_min_weight=0,node_max_weight=None,node_unit=0.1,node_cmap='Reds'): 70 | ''' 71 | plot ppi network 72 | 73 | Parameters 74 | ---------- 75 | hub_net 76 | Networkx object, the net need to plot. 77 | hub_gene 78 | The hub genes of ppi net. 79 | logFC_df 80 | Differential analysis result contains gene's logFoldChange column. 81 | name 82 | Column name in logFC_df with gene's logFoldChange 83 | ---------- 84 | ''' 85 | fig, ax=plt.subplots(3,1,figsize=figsize,height_ratios=[28,1,1]) 86 | edge_colors = _edge_colors(hub_net,cmap=edge_cmap,min_weight=edge_min_weight,max_weight=edge_max_weight,unit=edge_unit) 87 | node_colors = _node_colors(hub_net,logFC_df=logFC_df,name=name,cmap=node_cmap,min_weight=node_min_weight,max_weight=node_max_weight,unit=node_unit) 88 | pos = nx.spring_layout(hub_net, seed=POS_SEED) # Seed layout for reproducibility 89 | nx.draw(hub_net,pos,edge_color=edge_colors['color_list'],node_color=node_colors['color_list'],ax=ax[0], 90 | with_labels=True, font_weight='bold',font_size=9,node_size=[HUB_NODE_SIZE if v in hub_gene else BASIC_NODE_SIZE for v in hub_net.nodes()]) 91 | fig.colorbar(edge_colors['scale_colormap'], cax=ax[1], orientation='horizontal', label='STRING combined_score') 92 | fig.colorbar(node_colors['scale_colormap'], cax=ax[2], orientation='horizontal', label='logFC') 93 | plt.subplots_adjust(hspace = hspace) 94 | plt.savefig(save) 95 | 96 | def _select_grn(tfs,logFC_df,grn,source='source',target='target',pathway_genes=None): 97 | ''' 98 | Select nodes from input tfs and input pathway 99 | 100 | Parameters 101 | ---------- 102 | tfs 103 | Transcription factor names in GRN. 104 | grn 105 | Dataframe of gene regulatory network (GRN). 106 | source 107 | Column names in grn dataframe with source nodes. 108 | target 109 | Column names in grn dataframe with target nodes. 110 | logFC_df 111 | Dataframe of deseq2 result. 112 | pathway_genes 113 | Genes list of pathway. 
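Optional. If None, target genes are only required to be present in logFC_df; otherwise targets are additionally restricted to pathway genes that also occur in logFC_df.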
114 | ---------- 115 | 116 | Returns 117 | Dataframe of selected gene regulatory network (GRN) 118 | ''' 119 | if pathway_genes==None: 120 | select_grn=grn.loc[[i for i in grn.index if grn.loc[i,source] in tfs],:].reset_index(drop=True) 121 | select_grn=select_grn.loc[[i for i in select_grn.index if select_grn.loc[i,target] in logFC_df.index],:].reset_index(drop=True) 122 | else: 123 | select_grn=grn.loc[[i for i in grn.index if grn.loc[i,source] in tfs],:].reset_index(drop=True) 124 | pathway_genes = [i for i in pathway_genes if i in logFC_df.index] 125 | select_grn=select_grn.loc[[i for i in select_grn.index if select_grn.loc[i,target] in pathway_genes],:].reset_index(drop=True) 126 | return(select_grn) 127 | 128 | def tf_net(tfs,grn,logFC_df,name,source='source',target='target',pathway_genes=None,save=TF_outdir+'/TF_net.pdf', 129 | figsize=(6.8,6),wspace = 0.04,node_cmap='Reds',node_min_weight=None,node_max_weight=None,node_unit=0.1): 130 | ''' 131 | Plot gene regulatory network (GRN) of input transcription factors. Node color represent logFC, larger Nodes represent input TFs. 132 | 133 | Parameters 134 | ---------- 135 | tfs 136 | Transcription factor names in GRN. 137 | grn 138 | Gene regulatory network (GRN). 139 | logFC_df 140 | dataframe of deseq2 result. 141 | name 142 | Column name in logFC_df with gene's logFoldChange. 143 | source 144 | Column names in grn dataframe with source nodes. 145 | target 146 | Column names in grn dataframe with target nodes. 147 | pathway_genes 148 | Genes list of pathway. Optional. 149 | save 150 | Path to where to save the plot. Infer the filetype if ending on {`.pdf`, `.png`, `.svg`}. 151 | ---------- 152 | 153 | ''' 154 | select_grn = _select_grn(tfs=tfs,logFC_df=logFC_df,grn=grn,source=source,target=target,pathway_genes=pathway_genes) 155 | tf_net = nx.from_pandas_edgelist(df=select_grn,source=source,target=target,create_using=nx.DiGraph()) 156 | node_colors = _node_colors(graph=tf_net,logFC_df=logFC_df,name=name,cmap=node_cmap,min_weight=node_min_weight,max_weight=node_max_weight,unit=node_unit) 157 | fig, ax=plt.subplots(1,2,figsize=figsize,width_ratios=[28,1]) 158 | pos = nx.circular_layout(tf_net) 159 | nx.draw(tf_net,pos,ax=ax[0],connectionstyle="arc3,rad=0.1",with_labels=True, edge_color='grey',width=0.5,arrowsize=8, 160 | #edgecolors=['red' if results_df.loc[v,'padj']<0.05 else node_colors['color_list'][list(tf_net.nodes()).index(v)] for v in tf_net.nodes()], 161 | font_size=9,font_weight='normal',verticalalignment='baseline', 162 | node_color=node_colors['color_list'], 163 | node_size=[HUB_NODE_SIZE if v in tfs else BASIC_NODE_SIZE for v in tf_net.nodes()]) 164 | 165 | fig.colorbar(node_colors['scale_colormap'], cax=ax[1], orientation='vertical', label='logFC')#horizontal 166 | plt.subplots_adjust(wspace = wspace) 167 | plt.savefig(save) 168 | 169 | -------------------------------------------------------------------------------- /stereosite/plot/sankey.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib as mpl 5 | import matplotlib.pyplot as plt 6 | import seaborn as sns 7 | from sklearn import datasets 8 | from sklearn.preprocessing import minmax_scale 9 | from collections import defaultdict 10 | 11 | class PySankeyException(Exception): 12 | pass 13 | 14 | 15 | class NullsInFrame(PySankeyException): 16 | pass 17 | 18 | 19 | class LabelMismatch(PySankeyException): 20 | pass 21 | 22 | def 
_check_data_matches_labels(labels, data): 23 | if isinstance(data, list): 24 | data = set(data) 25 | if isinstance(data, pd.Series): 26 | data = set(data.unique().tolist()) 27 | if isinstance(labels, list): 28 | labels = set(labels) 29 | if labels != data: 30 | raise LabelMismatch('Labels and data do not match. {0}'.format(','.join(labels))) 31 | 32 | def core_process(core:np.array, model_names:list = ['CC_Module', 'LR_Module', 'TME']): 33 | # preprocess core matrix with three dimensionality: ligand-receptor model, cell-cell model, TME model 34 | 35 | left = [] 36 | mid_l = [] 37 | mid_r = [] 38 | right = [] 39 | left_weight = [] 40 | mid_l_weight = [] 41 | mid_r_weight = [] 42 | right_weight = [] 43 | lr_tme_link = core.sum(axis=0) 44 | cc_tme_link = core.sum(axis=1) 45 | lrs, ccs, tmes = core.shape 46 | for tme in range(tmes): 47 | for lr in range(lrs): 48 | v = cc_tme_link[lr, tme] 49 | if v <= 0: 50 | continue 51 | left.append(f"{model_names[0]} {lr}") 52 | mid_l.append(f"{model_names[2]} {tme}") 53 | left_weight.append(v) 54 | mid_l_weight.append(v) 55 | for cc in range(ccs): 56 | v = lr_tme_link[cc, tme] 57 | if v <= 0: 58 | continue 59 | right.append(f"{model_names[1]} {cc}") 60 | mid_r.append(f"{model_names[2]} {tme}") 61 | right_weight.append(v) 62 | mid_r_weight.append(v) 63 | #mid_l_weight = tmes*ccs*[cc_tme_link.sum()/(tmes*ccs)] 64 | #mid_r_weight = tmes*lrs*[lr_tme_link.sum()/(tmes*lrs)] 65 | 66 | #dataFrame = pd.DataFrame({'left': left, 'right': right, 'mid_l': mid_l, 'mid_r': mid_r, 67 | # 'left_weight': left_weight, 'right_weight': right_weight, 68 | # 'mid_l_weight': mid_l_weight, 'mid_r_weight': mid_r_weight}, index=range(len(left))) 69 | data_left = pd.DataFrame({'left': left, 'mid_l': mid_l, 70 | 'left_weight': left_weight, 71 | 'mid_l_weight': mid_l_weight, 72 | }, index=range(len(left))) 73 | data_right = pd.DataFrame({'right': right, 'mid_r': mid_r, 74 | 'right_weight': right_weight, 75 | 'mid_r_weight': mid_r_weight 76 | }, index=range(len(right))) 77 | return data_left, data_right 78 | 79 | def factor_process(factors:np.array, names:list, model_type:str='CC_Module'): 80 | left = [] 81 | right = [] 82 | leftWeight = [] 83 | rightWeight = [] 84 | for index, element in np.ndenumerate(factors): 85 | left.append(f'{model_type} {index[1]}') 86 | right.append(names[index[0]]) 87 | leftWeight.append(element) 88 | rightWeight.append(element) 89 | dataFrame = pd.DataFrame({'left': left, 'right': right, 'left_weight': leftWeight, 'right_weight': rightWeight}, index=range(len(left))) 90 | return dataFrame 91 | 92 | def lr_pathway_dict_generate(interactiondb_file:str, 93 | lr_label:str='interaction_name', 94 | pathway_label:str='pathway_name', 95 | ): 96 | interactiondb =pd.read_csv(interactiondb_file) 97 | lr_pathway_dict = dict(zip(interactiondb[lr_label].values, interactiondb[pathway_label].values)) 98 | return lr_pathway_dict 99 | 100 | def interaction_matrix_decomposition(interaction_matrix, interactiondb_file, 101 | components:int=5, 102 | W_filter:float=0.5, H_filter:float=0.5, 103 | cell_level:int=0): 104 | filter_matrix = interaction_matrix.copy() 105 | if cell_level==1: 106 | filter_matrix.columns = filter_matrix.columns.swaplevel(i=0, j=1) 107 | cellchatDB = pd.read_csv(interactiondb_file) 108 | lr_pathway_dict = lr_pathway_dict_generate(interactiondb_file) 109 | sender_df = pd.DataFrame(index = [lr_pathway_dict[f'{x[0]}_{x[1]}'] if f'{x[0]}_{x[1]}' in lr_pathway_dict.keys() \ 110 | else cellchatDB[(cellchatDB['ligand'] == x[0]) & (cellchatDB['receptor'] == 
x[1])]['pathway_name'].values[0] \ 111 | for x in filter_matrix.index]) 112 | cells = filter_matrix.columns.get_level_values(cell_level).unique() 113 | for cell in cells: 114 | sender_df[cell] = filter_matrix[cell].sum(axis=1).values 115 | sender_df = sender_df.groupby(sender_df.index).sum() 116 | from sklearn.decomposition import NMF 117 | model = NMF(n_components = components, init='random', random_state=0) 118 | W = model.fit_transform(sender_df.T.values) 119 | H = model.components_ 120 | 121 | normalized_W = minmax_scale(W, axis=1) 122 | normalized_W[normalized_W < W_filter] = 0 123 | 124 | df_W = pd.DataFrame(normalized_W, index = sender_df.T.index) 125 | left = [] 126 | right = [] 127 | left_weight = [] 128 | right_weight = [] 129 | for cell, row in df_W.iterrows(): 130 | for i, v in row.items(): 131 | if v == 0: 132 | continue 133 | left.append(cell) 134 | right.append(f"Pattern {i+1}") 135 | left_weight.append(v) 136 | right_weight.append(v) 137 | sankey_df_W = pd.DataFrame({'left': left, 'right': right, 'left_weight': left_weight, 'right_weight': right_weight}, index=range(len(left))) 138 | 139 | normalized_H = minmax_scale(H, axis=0) 140 | normalized_H[normalized_H < H_filter] = 0 141 | df_H = pd.DataFrame(normalized_H, columns = sender_df.T.columns) 142 | left = [] 143 | right = [] 144 | left_weight = [] 145 | right_weight = [] 146 | average_v = normalized_H.sum()/normalized_H.shape[0] 147 | for i, row in df_H.iterrows(): 148 | for pathway, v in row.items(): 149 | if v == 0: 150 | continue 151 | left.append(f"Pattern {i+1}") 152 | right.append(pathway) 153 | left_weight.append(v) 154 | right_weight.append(v) 155 | sankey_df_H = pd.DataFrame({'left': left, 'right': right, 'left_weight': left_weight, 'right_weight': right_weight}, index=range(len(left))) 156 | return sankey_df_W, sankey_df_H 157 | 158 | def sankey_3d(data_l:pd.DataFrame, data_r:pd.DataFrame, 159 | cmap='tab20', left_labels = None, right_labels = None, mid_labels = None, 160 | aspect=3, fontsize=5, save=None, close_plot=False, patch_alpha:float=0.99, link_alpha:float=0.4, 161 | interval:float=0.005, module_color=(0.4980392156862745, 0.4980392156862745, 0.4980392156862745), 162 | ): 163 | ''' 164 | Make Sankey Diagram showing flow: ligand-receptor model <--- TME ---> cell-cell model 165 | 166 | Inputs: 167 | data_l: pandas.dataFrame. 168 | Contains columns left, mid_l, left_weight, mid_l_weight 169 | data_r: pandas.dataFrame. 170 | Contains columns right, mid_r, right_weight, mid_r_weight 171 | cmap: str|dict. 172 | Define colors of each patch. User can set matplotlib's colormap (e.g. viridis, jet, tab10) or label_name -> color dict (e.g. dict(A="red", B="blue", C="green", ...)). 173 | left_labels: list[str]|array[str]. 174 | Order of the left labels in the diagram 175 | right_labels: list[str]|array[str]. 176 | Order of the right labels in the diagram 177 | mid_labels: list[str]|array[str]. 178 | Order of the middle labels in the diagram 179 | aspect: float. 180 | Vertical extent of the diagram in units of horizontal extent 181 | fontsize: float. 182 | Fontsize of patch label text 183 | save: str. 184 | If the figure file name was given, the sankey figure will be stored in it. 185 | interval: float. 186 | Distance between two adjacent patchs = interval * vertical length of all patchs. 187 | module_color: tuple[float, float, float]. 188 | Define the color of left and right patchs representing modules. default=(0.4980392156862745, 0.4980392156862745, 0.4980392156862745). 
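Example (a minimal sketch; `core` is assumed to be the 3-D core tensor from an SCIITensor decomposition, with shape (LR modules, CC modules, TME modules) as expected by `core_process`; the output file name is illustrative):
>>> data_l, data_r = core_process(core, model_names=['CC_Module', 'LR_Module', 'TME'])
>>> sankey_3d(data_l, data_r, cmap='tab20', fontsize=5, save='sankey_3d.png')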
189 | 190 | Output: 191 | None 192 | ''' 193 | plt.figure(dpi=140) 194 | plt.rc('text', usetex=False) 195 | plt.rc('font', family='serif') 196 | 197 | if len(data_l[(data_l.left.isnull()) | (data_r.right.isnull()) | (data_l.mid_l.isnull()) | (data_r.mid_r.isnull())]): 198 | raise NullsInFrame('Sankey graph dose not support null values.') 199 | 200 | #Identify all labels that appear 'left' or 'right' 201 | allLabels = pd.Series(np.r_[data_r.mid_r.unique(), data_l.left.unique(), data_r.right.unique(), data_l.mid_l.unique()]).unique() 202 | 203 | #Identify labels 204 | if left_labels == None: 205 | left_labels = data_l.left.unique()[::-1] 206 | else: 207 | _check_data_matches_labels(left_labels, data_l['left']) 208 | if mid_labels == None: 209 | mid_labels = data_r.mid_r.unique()[::-1] 210 | else: 211 | _check_data_matches_labels(mid_labels, data_l['mid_l']) 212 | if right_labels == None: 213 | right_labels = data_r.right.unique()[::-1] 214 | else: 215 | _check_data_matches_labels(right_labels, data_r['right']) 216 | 217 | if isinstance(cmap, str): 218 | color_dict = {} 219 | colorPalette = sns.color_palette(cmap, len(mid_labels)) 220 | for i, label in enumerate(mid_labels): 221 | color_dict[label] = colorPalette[i] 222 | for label in left_labels: 223 | color_dict[label] = module_color 224 | for label in right_labels: 225 | color_dict[label] = module_color 226 | elif isinstance(cmap, dict): 227 | color_dict = cmap 228 | missing = [label for label in allLabels if label not in color_dict.keys()] 229 | if missing: 230 | msg = "The cmap parameter is missing values for the following labels: {}".format(', '.join(missing)) 231 | raise ValueError(msg) 232 | else: 233 | raise ValueError("cmap must be string representing the matplotlib's colormap or dict") 234 | 235 | #Determine widths of individual strips 236 | 237 | ns_l = defaultdict() 238 | ns_m_l = defaultdict() 239 | ns_m_r = defaultdict() 240 | ns_r = defaultdict() 241 | for midLabel in mid_labels: 242 | leftDict = {} 243 | midLDict = {} 244 | midRDict = {} 245 | rightDict = {} 246 | for leftLabel in left_labels: 247 | midLDict[leftLabel] = data_l[(data_l.mid_l == midLabel) & (data_l.left == leftLabel)].mid_l_weight.sum() 248 | leftDict[leftLabel] = data_l[(data_l.mid_l == midLabel) & (data_l.left == leftLabel)].left_weight.sum() 249 | ns_m_l[midLabel] = midLDict 250 | ns_l[midLabel] = leftDict 251 | for rightLabel in right_labels: 252 | midRDict[rightLabel] = data_r[(data_r.mid_r == midLabel) & (data_r.right == rightLabel)].mid_r_weight.sum() 253 | rightDict[rightLabel] = data_r[(data_r.mid_r == midLabel) & (data_r.right == rightLabel)].right_weight.sum() 254 | ns_m_r[midLabel] = midRDict 255 | ns_r[midLabel] = rightDict 256 | 257 | midLWidths = defaultdict() 258 | for i, midLabel in enumerate(mid_labels): 259 | myD = {} 260 | myD['mid'] = data_l[data_l.mid_l == midLabel].mid_l_weight.sum() 261 | if i == 0: 262 | myD['bottom'] = 0 263 | myD['top'] = myD['mid'] 264 | else: 265 | myD['bottom'] = midLWidths[mid_labels[i - 1]]['top'] + interval*data_l.mid_l_weight.sum() 266 | myD['top'] = myD['bottom'] + myD['mid'] 267 | topEdge = myD['top'] 268 | midLWidths[midLabel] = myD 269 | midRWidths = defaultdict() 270 | for i, midLabel in enumerate(mid_labels): 271 | myD = {} 272 | myD['mid'] = data_r[data_r.mid_r == midLabel].mid_r_weight.sum() 273 | if i == 0: 274 | myD['bottom'] = 0 275 | myD['top'] = myD['mid'] 276 | else: 277 | myD['bottom'] = midRWidths[mid_labels[i - 1]]['top'] + interval*data_r.mid_r_weight.sum() 278 | myD['top'] = myD['bottom'] + 
myD['mid'] 279 | topEdge = myD['top'] 280 | midRWidths[midLabel] = myD 281 | 282 | # Determine positions of left label patches and total widths 283 | leftWidths = defaultdict() 284 | for i, leftLabel in enumerate(left_labels): 285 | myD = {} 286 | myD['left'] = data_l[data_l.left == leftLabel].left_weight.sum() 287 | if i == 0: 288 | myD['bottom'] = 0 289 | myD['top'] = myD['left'] 290 | else: 291 | myD['bottom'] = leftWidths[left_labels[i - 1]]['top'] + interval*data_l.left_weight.sum() 292 | myD['top'] = myD['bottom'] + myD['left'] 293 | topEdge = myD['top'] 294 | leftWidths[leftLabel] = myD 295 | 296 | # Determine positions of right label patches and total widths 297 | rightWidths = defaultdict() 298 | for i, rightLabel in enumerate(right_labels): 299 | myD = {} 300 | myD['right'] = data_r[data_r.right == rightLabel].right_weight.sum() 301 | if i == 0: 302 | myD['bottom'] = 0 303 | myD['top'] = myD['right'] 304 | else: 305 | myD['bottom'] = rightWidths[right_labels[i-1]]['top'] + interval*data_r.right_weight.sum() 306 | myD['top'] = myD['bottom'] + myD['right'] 307 | topEdge = myD['top'] 308 | rightWidths[rightLabel] = myD 309 | 310 | #Total vertical extent of diagram 311 | 312 | xMax = topEdge/aspect 313 | width = 0.35*xMax 314 | pad = 0.9995 315 | 316 | #Draw vertical bars on left and right of each label's section & print label 317 | for midLabel in mid_labels: 318 | plt.fill_between( 319 | [-width/2, width/2], 320 | 2 * [midLWidths[midLabel]['bottom'] *pad], 321 | 2 * [(midLWidths[midLabel]['bottom'] + midLWidths[midLabel]['mid']) * pad], 322 | color = color_dict[midLabel], 323 | alpha=patch_alpha 324 | ) 325 | plt.text( 326 | 0, 327 | midLWidths[midLabel]['bottom'] + 0.5*midLWidths[midLabel]['mid'], 328 | midLabel, 329 | {'ha': 'center', 'va': 'center'}, 330 | fontsize = fontsize 331 | ) 332 | 333 | for leftLabel in left_labels: 334 | plt.fill_between( 335 | [-xMax-width, -xMax], 336 | 2 * [leftWidths[leftLabel]['bottom'] * pad], 337 | 2 * [(leftWidths[leftLabel]['bottom'] + leftWidths[leftLabel]['left']) * pad], 338 | color = color_dict[leftLabel], 339 | alpha=patch_alpha, 340 | edgecolor='k', 341 | linewidth=0.3, 342 | ) 343 | plt.text( 344 | -xMax - width/2, 345 | leftWidths[leftLabel]['bottom'] + 0.5*leftWidths[leftLabel]['left'], 346 | leftLabel, 347 | {'ha': 'center', 'va': 'center'}, 348 | fontsize=fontsize 349 | ) 350 | 351 | for rightLabel in right_labels: 352 | plt.fill_between( 353 | [xMax, xMax + width], 354 | 2 * [rightWidths[rightLabel]['bottom'] * pad], 355 | 2 * [(rightWidths[rightLabel]['bottom'] + rightWidths[rightLabel]['right']) * pad], 356 | color = color_dict[rightLabel], 357 | alpha=patch_alpha, 358 | edgecolor='k', 359 | linewidth=0.3, 360 | ) 361 | plt.text( 362 | xMax + width/2, 363 | rightWidths[rightLabel]['bottom'] + 0.5*rightWidths[rightLabel]['right'], 364 | rightLabel, 365 | {'ha': 'center', 'va': 'center'}, 366 | fontsize = fontsize 367 | ) 368 | 369 | # Plot strips 370 | for midLabel in mid_labels: 371 | for leftLabel in left_labels: 372 | labelColor = midLabel 373 | if len(data_l[(data_l.mid_l == midLabel)& (data_l.left == leftLabel)]) > 0: 374 | # Create array of y values for each strip, half at middle value, half at left 375 | ys_d = np.array(50 * [leftWidths[leftLabel]['bottom']] + 50*[midLWidths[midLabel]['bottom']]) 376 | ys_d = np.convolve(ys_d, 0.05*np.ones(20), mode='valid') 377 | ys_d = np.convolve(ys_d, 0.05*np.ones(20), mode='valid') 378 | ys_u = np.array(50*[leftWidths[leftLabel]['bottom'] + ns_l[midLabel][leftLabel]] + 
50*[midLWidths[midLabel]['bottom'] + ns_m_l[midLabel][leftLabel]]) 379 | ys_u = np.convolve(ys_u, 0.05 * np.ones(20), mode='valid') 380 | ys_u = np.convolve(ys_u, 0.05 * np.ones(20), mode='valid') 381 | 382 | # Update bottom edges at each label so next strip starts at the right place 383 | midLWidths[midLabel]['bottom'] += ns_m_l[midLabel][leftLabel] 384 | leftWidths[leftLabel]['bottom'] += ns_l[midLabel][leftLabel] 385 | plt.fill_between( 386 | np.linspace(-xMax, -width/2, len(ys_d)), ys_d, ys_u, alpha=link_alpha, 387 | color = color_dict[labelColor] 388 | ) 389 | 390 | for rightLabel in right_labels: 391 | labelColor = midLabel 392 | if len(data_r[(data_r.mid_r == midLabel) & (data_r.right == rightLabel)]) > 0 : 393 | # Create array of y values for each strip, half at let value, 394 | # half at right 395 | ys_d = np.array(50 * [midRWidths[midLabel]['bottom']] + 50 * [rightWidths[rightLabel]['bottom']]) 396 | ys_d = np.convolve(ys_d, 0.05*np.ones(20), mode='valid') 397 | ys_d = np.convolve(ys_d, 0.05*np.ones(20), mode='valid') 398 | ys_u = np.array(50 * [midRWidths[midLabel]['bottom'] + ns_m_r[midLabel][rightLabel]] + 50 * [rightWidths[rightLabel]['bottom'] + ns_r[midLabel][rightLabel]]) 399 | ys_u = np.convolve(ys_u, 0.05 * np.ones(20), mode='valid') 400 | ys_u = np.convolve(ys_u, 0.05 * np.ones(20), mode='valid') 401 | 402 | # Update bottom edges at each label so next strip starts at the right place 403 | midRWidths[midLabel]['bottom'] += ns_m_r[midLabel][rightLabel] 404 | rightWidths[rightLabel]['bottom'] += ns_r[midLabel][rightLabel] 405 | plt.fill_between( 406 | np.linspace(width/2, xMax, len(ys_d)), ys_d, ys_u, alpha=link_alpha, 407 | color = color_dict[labelColor] 408 | ) 409 | plt.gca().axis('off') 410 | plt.gcf().set_size_inches(6, 6) 411 | if save != None: 412 | plt.savefig(save, bbox_inches='tight', dpi=300) 413 | if close_plot: 414 | plt.close() 415 | 416 | def sankey_2d(data:pd.DataFrame, cmap="tab20", left_labels:list=None, right_labels:list=None, 417 | aspect:float=4, fontsize=4, save=None, close_plot=False, patch_alpha:float=0.99, link_alpha:float=0.65, 418 | interval:float=0.0, strip_color='left', ax:mpl.axes.Axes=None, dpi:float=300 419 | ): 420 | ''' 421 | Make Sankey Diagram 422 | 423 | Inputs: 424 | data: pandas.DataFrame. 425 | contains columns left, right, mid_l, mid_r, left_weight, right_weight, mid_l_weight, mid_r_weight 426 | cmap: str|dict. 427 | Define colors of each patch. User can set matplotlib's colormap (e.g. viridis, jet, tab10) or label_name -> color dict (e.g. dict(A="red", B="blue", C="green", ...)). 428 | left_labels: list[str] | array[str]. 429 | order of the left labels in the diagram 430 | right_labels: list[str] | array[str]. 431 | order of the right labels in the diagram 432 | aspect: float. 
433 | vertical extent of the diagram in units of horizontal extent 434 | 435 | Output: 436 | None 437 | ''' 438 | if ax==None: 439 | fig = plt.figure(dpi=dpi) 440 | ax = fig.add_subplot(1, 1, 1) 441 | 442 | if len(data[(data.left.isnull()) | (data.right.isnull())]): 443 | raise NullsInFrame('Sankey graph dose not support null values.') 444 | 445 | #Identify all labels that appear 'left' or 'right' 446 | allLabels = pd.Series(np.r_[data.left.unique(), data.right.unique()]).unique() 447 | 448 | #Identify left labels 449 | if left_labels == None: 450 | left_labels = data.left.unique()[::-1] 451 | else: 452 | _check_data_matches_labels(left_labels, data['left']) 453 | if right_labels == None: 454 | right_labels = data.right.unique()[::-1] 455 | else: 456 | _check_data_matches_labels(right_labels, data['right']) 457 | 458 | if isinstance(cmap, str): 459 | color_dict = {} 460 | colorPalette = sns.color_palette(cmap, len(allLabels)) 461 | for i, label in enumerate(allLabels): 462 | color_dict[label] = colorPalette[i] 463 | elif isinstance(cmap, dict): 464 | color_dict = cmap 465 | else: 466 | raise Exception("cmap must be string representing the matplotlib's colormap or dict") 467 | 468 | #Determine widths of individual strips 469 | from collections import defaultdict 470 | ns_l = defaultdict() 471 | ns_r = defaultdict() 472 | for leftLabel in left_labels: 473 | leftDict = {} 474 | rightDict = {} 475 | for rightLabel in right_labels: 476 | leftDict[rightLabel] = data[(data.left == leftLabel) & (data.right == rightLabel)].left_weight.sum() 477 | rightDict[rightLabel] = data[(data.left == leftLabel) & (data.right == rightLabel)].right_weight.sum() 478 | ns_l[leftLabel] = leftDict 479 | ns_r[leftLabel] = rightDict 480 | 481 | # Determine positions of left label patches and total widths 482 | leftWidths = defaultdict() 483 | for i, leftLabel in enumerate(left_labels): 484 | myD = {} 485 | myD['left'] = data[data.left == leftLabel].left_weight.sum() 486 | if i == 0: 487 | myD['bottom'] = 0 488 | myD['top'] = myD['left'] 489 | else: 490 | myD['bottom'] = leftWidths[left_labels[i - 1]]['top'] + interval*data.left_weight.sum() 491 | myD['top'] = myD['bottom'] + myD['left'] 492 | topEdge = myD['top'] 493 | leftWidths[leftLabel] = myD 494 | 495 | # Determine positions of right label patches and total widths 496 | rightWidths = defaultdict() 497 | for i, rightLabel in enumerate(right_labels): 498 | myD = {} 499 | myD['right'] = data[data.right == rightLabel].right_weight.sum() 500 | if i == 0: 501 | myD['bottom'] = 0 502 | myD['top'] = myD['right'] 503 | else: 504 | myD['bottom'] = rightWidths[right_labels[i-1]]['top'] + interval*data.right_weight.sum() 505 | myD['top'] = myD['bottom'] + myD['right'] 506 | topEdge = myD['top'] 507 | rightWidths[rightLabel] = myD 508 | 509 | #Total vertical extent of diagram 510 | l_width = -0.3 511 | r_width = 1.3 512 | xMax = topEdge/aspect 513 | 514 | #Draw vertical bars on left and right of each label's section & print label 515 | for leftLabel in left_labels: 516 | ax.fill_between( 517 | [l_width*xMax, 0], 518 | 2 * [leftWidths[leftLabel]['bottom']], 519 | 2 * [leftWidths[leftLabel]['bottom'] + leftWidths[leftLabel]['left']], 520 | color = color_dict[leftLabel], 521 | alpha=patch_alpha, 522 | edgecolor='k', 523 | linewidth=0.3, 524 | ) 525 | ax.text( 526 | l_width/2 * xMax, 527 | leftWidths[leftLabel]['bottom'] + 0.5*leftWidths[leftLabel]['left'], 528 | leftLabel, 529 | {'ha': 'center', 'va': 'center'}, 530 | fontsize=fontsize 531 | ) 532 | for rightLabel in 
right_labels: 533 | ax.fill_between( 534 | [xMax, r_width*xMax], 2 * [rightWidths[rightLabel]['bottom']], 535 | 2 * [rightWidths[rightLabel]['bottom'] + rightWidths[rightLabel]['right']], 536 | color = color_dict[rightLabel], 537 | alpha=patch_alpha, 538 | edgecolor='k', 539 | linewidth=0.3, 540 | ) 541 | ax.text( 542 | (r_width+l_width/2) * xMax, 543 | rightWidths[rightLabel]['bottom'] + 0.5*rightWidths[rightLabel]['right'], 544 | rightLabel, 545 | {'ha': 'center', 'va': 'center'}, 546 | fontsize = fontsize 547 | ) 548 | 549 | # Plot strips 550 | for leftLabel in left_labels: 551 | for rightLabel in right_labels: 552 | if strip_color == 'left': 553 | labelColor = leftLabel 554 | else: 555 | labelColor = rightLabel 556 | if len(data[(data.left == leftLabel) & (data.right == rightLabel)]) > 0 : 557 | # Create array of y values for each strip, half at let value, 558 | # half at right 559 | ys_d = np.array(50 * [leftWidths[leftLabel]['bottom']] + 50 * [rightWidths[rightLabel]['bottom']]) 560 | ys_d = np.convolve(ys_d, 0.05*np.ones(20), mode='valid') 561 | ys_d = np.convolve(ys_d, 0.05*np.ones(20), mode='valid') 562 | ys_u = np.array(50 * [leftWidths[leftLabel]['bottom'] + ns_l[leftLabel][rightLabel]] + 50 * [rightWidths[rightLabel]['bottom'] + ns_r[leftLabel][rightLabel]]) 563 | ys_u = np.convolve(ys_u, 0.05 * np.ones(20), mode='valid') 564 | ys_u = np.convolve(ys_u, 0.05 * np.ones(20), mode='valid') 565 | 566 | # Update bottom edges at each label so next strip starts at the right place 567 | leftWidths[leftLabel]['bottom'] += ns_l[leftLabel][rightLabel] 568 | rightWidths[rightLabel]['bottom'] += ns_r[leftLabel][rightLabel] 569 | ax.fill_between( 570 | np.linspace(0, xMax, len(ys_d)), ys_d, ys_u, alpha=link_alpha, 571 | color = color_dict[labelColor] 572 | ) 573 | ax.axis('off') 574 | if save != None: 575 | plt.savefig(save, bbox_inches='tight', dpi=dpi) 576 | if close_plot: 577 | plt.close() 578 | 579 | def sankey_pathway_decomposition(W:pd.DataFrame, H:pd.DataFrame, 580 | left_labels:list=None, right_labels:list=None, mid_labels:list=None, 581 | aspect:float=4, patch_alpha:float=0.99, link_alpha:float=0.65, 582 | interval:float=0.0, 583 | figsize:tuple=(12, 6), fontsize:float=9, 584 | cmap='tab20', 585 | dpi:int=300, 586 | save:str=None, 587 | ): 588 | ''' 589 | Input: 590 | W: pd.DataFrame. 591 | The DataFrame generated from W matrix of decomposition result with shape (K, R). K represents cells while R represents patterns. 592 | H: pd.DataFrame. 593 | The DataFrame generated from H matrix of decomposition result with shape (R, N). R represents 594 | cmap: str|dict. 595 | Define colors of each patch. User can set matplotlib's colormap (e.g. viridis, jet, tab10) or label_name -> color dict (e.g. dict(A="red", B="blue", C="green", ...)). 
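Example (a minimal sketch; `interaction_matrix` is assumed to be an SCII matrix with a (ligand, receptor) MultiIndex on the rows and cell pairs on the columns, and `interactiondb_file` a CellChatDB-style CSV providing 'interaction_name' and 'pathway_name' columns; the output file name is illustrative):
>>> W, H = interaction_matrix_decomposition(interaction_matrix, interactiondb_file, components=5)
>>> sankey_pathway_decomposition(W, H, figsize=(12, 6), fontsize=9, save='pathway_patterns_sankey.pdf')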
596 | Output: 597 | None 598 | ''' 599 | 600 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize, layout="constrained") 601 | sankey_2d(W, left_labels=left_labels, right_labels=mid_labels, 602 | aspect=aspect, patch_alpha=patch_alpha, link_alpha=link_alpha, 603 | interval=interval, 604 | cmap=cmap, fontsize=fontsize, strip_color='right', 605 | dpi=dpi, ax=ax1) 606 | sankey_2d(H, left_labels=mid_labels, right_labels=right_labels, 607 | aspect=aspect, patch_alpha=patch_alpha, link_alpha=link_alpha, 608 | interval=interval, 609 | cmap=cmap, fontsize=fontsize, strip_color='left', 610 | dpi=dpi, ax=ax2) 611 | if save != None: 612 | fig.savefig(save, bbox_inches='tight', dpi=dpi) -------------------------------------------------------------------------------- /stereosite/plot/scii.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import pandas as pd 4 | import scanpy as sc 5 | from typing import Any, Mapping, Sequence, TYPE_CHECKING 6 | from matplotlib.axes import Axes 7 | from matplotlib.colorbar import ColorbarBase 8 | import matplotlib.pyplot as plt 9 | from anndata import AnnData 10 | from scanpy import logging as logg, settings 11 | from scipy.sparse import issparse, spmatrix 12 | from scipy.cluster import hierarchy as sch 13 | from pathlib import Path 14 | from typing import Any, Callable, Hashable, Iterable, Sequence, Generator, TYPE_CHECKING 15 | from inspect import signature 16 | 17 | class CustomDotplot(sc.pl.DotPlot): 18 | 19 | BASE = 10 20 | 21 | DEFAULT_LARGEST_DOT = 50.0 22 | DEFAULT_NUM_COLORBAR_TICKS = 5 23 | DEFAULT_NUM_LEGEND_DOTS = 5 24 | 25 | def __init__(self, minn: float, delta: float, alpha: float, *args: Any, **kwargs: Any): 26 | super().__init__(*args, **kwargs) 27 | self._delta = delta 28 | self._minn = minn 29 | self._alpha = alpha 30 | 31 | def _plot_size_legend(self, size_legend_ax: Axes) -> None: 32 | y = self.BASE ** -((self.dot_max * self._delta) + self._minn) 33 | x = self.BASE ** -((self.dot_min * self._delta) + self._minn) 34 | size_range = -(np.logspace(x, y, self.DEFAULT_NUM_LEGEND_DOTS + 1, base=10).astype(np.float64)) 35 | size_range = (size_range - np.min(size_range)) / (np.max(size_range) - np.min(size_range)) 36 | # no point in showing dot of size 0 37 | size_range = size_range[1:] 38 | 39 | size = size_range**self.size_exponent 40 | mult = (self.largest_dot - self.smallest_dot) + self.smallest_dot 41 | size = size * mult 42 | 43 | # plot size bar 44 | size_legend_ax.scatter( 45 | np.arange(len(size)) + 0.5, 46 | np.repeat(1, len(size)), 47 | s=size, 48 | color="black", 49 | edgecolor="black", 50 | linewidth=self.dot_edge_lw, 51 | zorder=100, 52 | ) 53 | size_legend_ax.set_xticks(np.arange(len(size)) + 0.5) 54 | labels = [f"{(x * self._delta) + self._minn:.1f}" for x in size_range] 55 | size_legend_ax.set_xticklabels(labels, fontsize="small") 56 | 57 | # remove y ticks and labels 58 | size_legend_ax.tick_params(axis="y", left=False, labelleft=False, labelright=False) 59 | # remove surrounding lines 60 | for direction in ["right", "top", "left", "bottom"]: 61 | size_legend_ax.spines[direction].set_visible(False) 62 | 63 | ymax = size_legend_ax.get_ylim()[1] 64 | size_legend_ax.set_ylim(-1.05 - self.largest_dot * 0.003, 4) 65 | size_legend_ax.set_title(self.size_title, y=ymax + 0.25, size="small") 66 | 67 | xmin, xmax = size_legend_ax.get_xlim() 68 | size_legend_ax.set_xlim(xmin - 0.15, xmax + 0.5) 69 | 70 | if self._alpha is not None: 71 | ax = self.fig.add_subplot() 72 | 
ax.scatter( 73 | [0.35, 0.65], 74 | [0, 0], 75 | s=size[-1], 76 | color="black", 77 | edgecolor="black", 78 | linewidth=self.dot_edge_lw, 79 | zorder=100, 80 | ) 81 | ax.scatter( 82 | [0.65], [0], s=0.33 * mult, color="white", edgecolor="black", linewidth=self.dot_edge_lw, zorder=100 83 | ) 84 | ax.set_xlim([0, 1]) 85 | ax.set_xticks([0.35, 0.65]) 86 | ax.set_xticklabels(["false", "true"]) 87 | ax.set_yticks([]) 88 | ax.set_title(f"significant\n$p={self._alpha}$", y=ymax + 0.25, size="small") 89 | ax.set(frame_on=False) 90 | 91 | l, b, w, h = size_legend_ax.get_position().bounds 92 | ax.set_position([l + w, b, w, h]) 93 | 94 | def _plot_colorbar(self, color_legend_ax: Axes, normalize: bool) -> None: 95 | cmap = plt.get_cmap(self.cmap) 96 | 97 | ColorbarBase( 98 | color_legend_ax, 99 | orientation="horizontal", 100 | cmap=cmap, 101 | norm=normalize, 102 | ticks=np.linspace( 103 | np.nanmin(self.dot_color_df.values), 104 | np.nanmax(self.dot_color_df.values), 105 | self.DEFAULT_NUM_COLORBAR_TICKS, 106 | ), 107 | format="%.2f", 108 | ) 109 | 110 | color_legend_ax.set_title(self.color_legend_title, fontsize="small") 111 | color_legend_ax.xaxis.set_tick_params(labelsize="small") 112 | 113 | def _unique_order_preserving(iterable) -> tuple: 114 | """Remove items from an iterable while preserving the order.""" 115 | seen: set[Hashable] = set() 116 | seen_add = seen.add 117 | return [i for i in iterable if not (i in seen or seen_add(i))], seen 118 | 119 | def verbosity(level: int) -> Generator[None, None, None]: 120 | """ 121 | Temporarily set the verbosity level of :mod:`scanpy`. 122 | Parameters 123 | ---------- 124 | level 125 | The new verbosity level. 126 | Returns 127 | ------- 128 | Nothing. 129 | """ 130 | import scanpy as sc 131 | 132 | verbosity = sc.settings.verbosity 133 | sc.settings.verbosity = level 134 | try: 135 | yield 136 | finally: 137 | sc.settings.verbosity = verbosity 138 | 139 | def _filter_kwargs(func, kwargs) -> dict: 140 | style_args = {k for k in signature(func).parameters.keys()} # noqa: C416 141 | return {k: v for k, v in kwargs.items() if k in style_args} 142 | 143 | def save_fig(fig, path, make_dir= True, ext="png", **kwargs) -> None: 144 | """ 145 | Save a figure. 146 | Parameters 147 | ---------- 148 | fig 149 | Figure to save. 150 | path 151 | Path where to save the figure. If path is relative, save it under :attr:`scanpy.settings.figdir`. 152 | make_dir 153 | Whether to try making the directory if it does not exist. 154 | ext 155 | Extension to use if none is provided. 156 | kwargs 157 | Keyword arguments for :meth:`matplotlib.figure.Figure.savefig`. 158 | Returns 159 | ------- 160 | None 161 | Just saves the plot. 162 | """ 163 | if os.path.splitext(path)[1] == "": 164 | path = f"{path}.{ext}" 165 | 166 | path = Path(path) 167 | 168 | if not path.is_absolute(): 169 | path = Path(settings.figdir) / path 170 | 171 | if make_dir: 172 | try: 173 | path.parent.mkdir(parents=True, exist_ok=True) 174 | except OSError as e: 175 | logg.debug(f"Unable to create directory `{path.parent}`. 
Reason: `{e}`") 176 | 177 | logg.debug(f"Saving figure to `{path!r}`") 178 | 179 | kwargs.setdefault("bbox_inches", "tight") 180 | kwargs.setdefault("transparent", True) 181 | 182 | fig.savefig(path, **kwargs) 183 | 184 | def _dendrogram(data, method, **kwargs): 185 | link_kwargs = _filter_kwargs(sch.linkage, kwargs) 186 | dendro_kwargs = _filter_kwargs(sch.dendrogram, kwargs) 187 | 188 | # Row-cluster 189 | row_link = sch.linkage(data, method=method, **link_kwargs) 190 | row_dendro = sch.dendrogram(row_link, no_plot=True, **dendro_kwargs) 191 | row_order = row_dendro["leaves"] 192 | 193 | # Column-cluster 194 | col_link = sch.linkage(data.T, method=method, **link_kwargs) 195 | col_dendro = sch.dendrogram(col_link, no_plot=True, **dendro_kwargs) 196 | col_order = col_dendro["leaves"] 197 | 198 | return row_order, col_order, row_link, col_link 199 | 200 | 201 | from squidpy._constants._utils import ModeEnum 202 | class DendrogramAxis(ModeEnum): 203 | INTERACTING_MOLS = "interacting_molecules" 204 | INTERACTING_CLUSTERS = "interacting_clusters" 205 | BOTH = "both" 206 | 207 | _SEP = " | " 208 | 209 | def ligrec( 210 | adata, 211 | cluster_key = None, 212 | source_groups = None, 213 | target_groups = None, 214 | intensities_range = (-np.inf, np.inf), 215 | pvalue_threshold = 0.05, 216 | remove_empty_interactions = True, 217 | remove_nonsig_interactions = False, 218 | dendrogram = None, 219 | alpha = 0.001, 220 | swap_axes = False, 221 | title = None, 222 | figsize = None, 223 | dpi = None, 224 | save = None, 225 | **kwargs, 226 | ) -> None: 227 | """ 228 | Plot the result of a receptor-ligand permutation test. 229 | The result was computed by :func:`squidpy.gr.ligrec`. 230 | :math:`molecule_1` belongs to the source clusters displayed on the top (or on the right, if ``swap_axes = True``, 231 | whereas :math:`molecule_2` belongs to the target clusters. 232 | Parameters 233 | ---------- 234 | %(adata)s 235 | It can also be a :class:`dict`, as returned by :func:`squidpy.gr.ligrec`. 236 | %(cluster_key)s 237 | Only used when ``adata`` is of type :class:`AnnData`. 238 | source_groups 239 | Source interaction clusters. If `None`, select all clusters. 240 | target_groups 241 | Target interaction clusters. If `None`, select all clusters. 242 | intensity_range 243 | Only show interactions whose intensity are within this **closed** interval. 244 | pvalue_threshold 245 | Only show interactions with p-value <= ``pvalue_threshold``. 246 | remove_empty_interactions 247 | Remove rows and columns that only contain interactions with `NaN` values. 248 | remove_nonsig_interactions 249 | Remove rows and columns that only contain interactions that are larger than ``alpha``. 250 | dendrogram 251 | How to cluster based on the p-values. Valid options are: 252 | - `None` - do not perform clustering. 253 | - `'interacting_molecules'` - cluster the interacting molecules. 254 | - `'interacting_clusters'` - cluster the interacting clusters. 255 | - `'both'` - cluster both rows and columns. Note that in this case, the dendrogram is not shown. 256 | alpha 257 | Significance threshold. All elements with p-values <= ``alpha`` will be marked by tori instead of dots. 258 | swap_axes 259 | Whether to show the cluster combinations as rows and the interacting pairs as columns. 260 | title 261 | Title of the plot. 262 | %(plotting_save)s 263 | kwargs 264 | Keyword arguments for :meth:`scanpy.pl.DotPlot.style` or :meth:`scanpy.pl.DotPlot.legend`. 
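Example (a minimal sketch; `scii_result` is assumed to be a dict holding 'pvalues' and 'intensities' DataFrames indexed by ligand-receptor pair, with a (source cluster, target cluster) MultiIndex on the columns; the cluster names and file name are illustrative):
>>> ligrec(scii_result, source_groups='T cell', target_groups=['Macrophage', 'DC'], pvalue_threshold=0.05, alpha=0.001, dendrogram='interacting_molecules', save='scii_dotplot.pdf')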
265 | Returns 266 | ------- 267 | %(plotting_returns)s 268 | """ 269 | 270 | def filter_values( 271 | pvals: pd.DataFrame, intensities: pd.DataFrame, *, mask: pd.DataFrame, kind: str 272 | ): 273 | mask_rows = mask.any(axis=1) 274 | pvals = pvals.loc[mask_rows] 275 | intensities = intensities.loc[mask_rows] 276 | 277 | if pvals.empty: 278 | raise ValueError(f"After removing rows with only {kind} interactions, none remain.") 279 | 280 | mask_cols = mask.any(axis=0) 281 | pvals = pvals.loc[:, mask_cols] 282 | intensities = intensities.loc[:, mask_cols] 283 | 284 | if pvals.empty: 285 | raise ValueError(f"After removing columns with only {kind} interactions, none remain.") 286 | 287 | return pvals, intensities 288 | 289 | def get_dendrogram(adata: AnnData, linkage: str = "complete") -> Mapping[str, Any]: 290 | z_var = sch.linkage( 291 | adata.X, 292 | metric="correlation", 293 | method=linkage, 294 | optimal_ordering=adata.n_obs <= 1500, # matplotlib will most likely give up first 295 | ) 296 | dendro_info = sch.dendrogram(z_var, labels=adata.obs_names.values, no_plot=True) 297 | # this is what the DotPlot requires 298 | return { 299 | "linkage": z_var, 300 | "groupby": ["groups"], 301 | "cor_method": "pearson", 302 | "use_rep": None, 303 | "linkage_method": linkage, 304 | "categories_ordered": dendro_info["ivl"], 305 | "categories_idx_ordered": dendro_info["leaves"], 306 | "dendrogram_info": dendro_info, 307 | } 308 | 309 | if dendrogram is not None: 310 | dendrogram = DendrogramAxis(dendrogram) # type: ignore[assignment] 311 | if TYPE_CHECKING: 312 | assert isinstance(dendrogram, DendrogramAxis) 313 | 314 | if isinstance(adata, AnnData): 315 | if cluster_key is None: 316 | raise ValueError("Please provide `cluster_key` when supplying an `AnnData` object.") 317 | 318 | cluster_key = adata.uns.ligrec(cluster_key) 319 | if cluster_key not in adata.uns_keys(): 320 | raise KeyError(f"Key `{cluster_key}` not found in `adata.uns`.") 321 | adata = adata.uns[cluster_key] 322 | 323 | if not isinstance(adata, dict): 324 | raise TypeError( 325 | f"Expected `adata` to be either of type `anndata.AnnData` or `dict`, found `{type(adata).__name__}`." 
326 | ) 327 | 328 | if len(intensities_range) != 2: 329 | raise ValueError(f"Expected `intensities_range` to be a sequence of size `2`, found `{len(intensities_range)}`.") 330 | intensities_range = tuple(sorted(intensities_range)) # type: ignore[assignment] 331 | 332 | if alpha is not None and not (0 <= alpha <= 1): 333 | raise ValueError(f"Expected `alpha` to be in range `[0, 1]`, found `{alpha}`.") 334 | 335 | if source_groups is None: 336 | source_groups = adata["pvalues"].columns.get_level_values(0) 337 | elif isinstance(source_groups, str): 338 | source_groups = (source_groups,) 339 | 340 | if target_groups is None: 341 | target_groups = adata["pvalues"].columns.get_level_values(1) 342 | if isinstance(target_groups, str): 343 | target_groups = (target_groups,) 344 | if title is None: 345 | title = "Receptor-ligand test" 346 | 347 | source_groups, _ = _unique_order_preserving(source_groups) # type: ignore[assignment] 348 | target_groups, _ = _unique_order_preserving(target_groups) # type: ignore[assignment] 349 | 350 | pvals: pd.DataFrame = adata["pvalues"].loc[:, (source_groups, target_groups)] 351 | intensities: pd.DataFrame = adata["intensities"].loc[:, (source_groups, target_groups)] 352 | 353 | if pvals.empty: 354 | raise ValueError("No valid clusters have been selected.") 355 | 356 | intensities = intensities[(intensities >= intensities_range[0]) & (intensities <= intensities_range[1])] 357 | pvals = pvals[pvals <= pvalue_threshold] 358 | 359 | if remove_empty_interactions: 360 | pvals, intensities = filter_values(pvals, intensities, mask=~(pd.isnull(intensities) | pd.isnull(pvals)), kind="NaN") 361 | if remove_nonsig_interactions and alpha is not None: 362 | pvals, intensities = filter_values(pvals, intensities, mask=pvals <= alpha, kind="non-significant") 363 | 364 | start, label_ranges = 0, {} 365 | 366 | if dendrogram == DendrogramAxis.INTERACTING_CLUSTERS: 367 | # rows are now cluster combinations, not interacting pairs 368 | pvals = pvals.T 369 | intensities = intensities.T 370 | 371 | for cls, size in (pvals.groupby(level=0, axis=1)).size().to_dict().items(): 372 | label_ranges[cls] = (start, start + size - 1) 373 | start += size 374 | label_ranges = {k: label_ranges[k] for k in sorted(label_ranges.keys())} 375 | 376 | pvals = pvals[label_ranges.keys()] 377 | pvals = -np.log10(pvals + min(1e-3, alpha if alpha is not None else 1e-3)).fillna(0) 378 | 379 | pvals.columns = map(_SEP.join, pvals.columns.to_flat_index()) 380 | pvals.index = map(_SEP.join, pvals.index.to_flat_index()) 381 | 382 | intensities = intensities[label_ranges.keys()].fillna(0) 383 | intensities.columns = map(_SEP.join, intensities.columns.to_flat_index()) 384 | intensities.index = map(_SEP.join, intensities.index.to_flat_index()) 385 | intensities = np.log10(intensities + 1) 386 | 387 | var = pd.DataFrame(pvals.columns) 388 | var = var.set_index(var.columns[0]) 389 | 390 | adata = AnnData(pvals.values, obs={"groups": pd.Categorical(pvals.index)}, var=var, dtype=pvals.values.dtype) 391 | adata.obs_names = pvals.index 392 | minn = np.nanmin(adata.X) 393 | delta = np.nanmax(adata.X) - minn 394 | adata.X = (adata.X - minn) / delta 395 | 396 | try: 397 | if dendrogram == DendrogramAxis.BOTH: 398 | row_order, col_order, _, _ = _dendrogram( 399 | adata.X, method="complete", metric="correlation", optimal_ordering=adata.n_obs <= 1500 400 | ) 401 | adata = adata[row_order, :][:, col_order] 402 | pvals = pvals.iloc[row_order, :].iloc[:, col_order] 403 | intensities = intensities.iloc[row_order, :].iloc[:, 
col_order] 404 | elif dendrogram is not None: 405 | adata.uns["dendrogram"] = get_dendrogram(adata) 406 | except IndexError: 407 | # just in case pandas indexing fails 408 | raise 409 | except Exception as e: 410 | logg.warning(f"Unable to create a dendrogram. Reason: `{e}`") 411 | dendrogram = None 412 | 413 | kwargs["dot_edge_lw"] = 0 414 | kwargs.setdefault("cmap", "viridis") 415 | kwargs.setdefault("grid", True) 416 | kwargs.pop("color_on", None) # interferes with tori 417 | 418 | dp = ( 419 | CustomDotplot( 420 | delta=delta, 421 | minn=minn, 422 | alpha=alpha, 423 | adata=adata, 424 | var_names=adata.var_names, 425 | groupby="groups", 426 | dot_color_df=intensities, 427 | dot_size_df=pvals, 428 | title=title, 429 | var_group_labels=None if dendrogram == DendrogramAxis.BOTH else list(label_ranges.keys()), 430 | var_group_positions=None if dendrogram == DendrogramAxis.BOTH else list(label_ranges.values()), 431 | standard_scale=None, 432 | figsize=figsize, 433 | ) 434 | .style( 435 | **_filter_kwargs(sc.pl.DotPlot.style, kwargs), 436 | ) 437 | .legend( 438 | size_title=r"$-\log_{10} ~ P$", 439 | colorbar_title=r"$log_{10}(intensity + 1)$", 440 | **_filter_kwargs(sc.pl.DotPlot.legend, kwargs), 441 | ) 442 | ) 443 | if dendrogram in (DendrogramAxis.INTERACTING_MOLS, DendrogramAxis.INTERACTING_CLUSTERS): 444 | # ignore the warning about mismatching groups 445 | with verbosity(0): 446 | dp.add_dendrogram(size=1.6, dendrogram_key="dendrogram") 447 | if swap_axes: 448 | dp.swap_axes() 449 | 450 | dp.make_figure() 451 | 452 | if dendrogram != DendrogramAxis.BOTH: 453 | # remove the target part in: source | target 454 | labs = dp.ax_dict["mainplot_ax"].get_yticklabels() if swap_axes else dp.ax_dict["mainplot_ax"].get_xticklabels() 455 | for text in labs: 456 | text.set_text(text.get_text().split(_SEP)[1]) 457 | if swap_axes: 458 | dp.ax_dict["mainplot_ax"].set_yticklabels(labs) 459 | else: 460 | dp.ax_dict["mainplot_ax"].set_xticklabels(labs) 461 | 462 | if alpha is not None: 463 | yy, xx = np.where((pvals.values + alpha) >= -np.log10(alpha)) 464 | if len(xx) and len(yy): 465 | # for dendrogram='both', they are already re-ordered 466 | mapper = ( 467 | np.argsort(adata.uns["dendrogram"]["categories_idx_ordered"]) 468 | if "dendrogram" in adata.uns 469 | else np.arange(len(pvals)) 470 | ) 471 | logg.info(f"Found `{len(yy)}` significant interactions at level `{alpha}`") 472 | ss = 0.33 * (adata.X[yy, xx] * (dp.largest_dot - dp.smallest_dot) + dp.smallest_dot) 473 | 474 | # must be after ss = ..., cc = ... 
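# re-map the row indices of the significant p-values onto the dendrogram ordering computed above (identity mapping when no dendrogram is stored in adata.uns)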
475 | yy = np.array([mapper[y] for y in yy]) 476 | if swap_axes: 477 | xx, yy = yy, xx 478 | dp.ax_dict["mainplot_ax"].scatter(xx + 0.5, yy + 0.5, color="white", s=ss, lw=0) 479 | 480 | if dpi is not None: 481 | dp.fig.set_dpi(dpi) 482 | 483 | if save is not None: 484 | save_fig(dp.fig, save) -------------------------------------------------------------------------------- /stereosite/plot/scii_circos.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import seaborn as sns 6 | from sklearn import datasets 7 | from collections import defaultdict 8 | from pycirclize import Circos 9 | from pycirclize.parser import Matrix 10 | from pycirclize.utils import calc_group_spaces, ColorCycler 11 | from matplotlib.patches import Patch 12 | from matplotlib.lines import Line2D 13 | from matplotlib.colors import ListedColormap 14 | 15 | 16 | def _sort(s:pd.Series, num:int=100): 17 | sorted_s = s.sort_values(ascending=False).index[:num] 18 | sorted_s.index = range(num) 19 | return sorted_s 20 | 21 | def scii_tensor_select(scii_tensor_list:list, 22 | factor_cc:pd.DataFrame, 23 | factor_lr:pd.DataFrame, 24 | factor_tme:pd.DataFrame, 25 | interest_TME:str, #|int 26 | interest_cc_module:str, #|int 27 | interest_LR_module:str, #|int 28 | lr_number:int=20, 29 | cc_number:int=10, 30 | 31 | ) -> pd.DataFrame: 32 | ''' 33 | Input: 34 | scii_tensor_file: Name of the file that contains scii_tensor result of all windows/bins. 35 | factor_cc: Factor matrix of cell cell pair module, the index indicates interacitng cell pair 36 | while the column indicates cell-cell module. 37 | factor_lr: Factor matrix of ligand receptor pair module, the index indicates ligand receptor pair that induce cell-cell interaction 38 | while the column indicates ligand-receptor module. 39 | factor_tme: Factor matrix of Tumor MicroEnvironment(TME) module, the index indicates window/bin 40 | while the column indicates TME module. 41 | interest_TME: Name of the interested TME module that will be calculated. 42 | interest_cc_module: Name of the interested cell-cell pair module. 43 | interest_lr_module: Name of the interested ligand-receptor pair module. 44 | lr_number: The number of ligand-receptor pairs on top that will remain. 45 | cc_number: The number of cell-cell pair on top that will remain. 46 | 47 | Return: 48 | pandas.DataFrame: Matrix contains interesting interaction in the TME region of interest, index represents cell-cell pairs 49 | while column represents ligand-receptor pairs. 
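        Example:
            A minimal illustrative sketch; the variable names and module labels below (scii_tensor_list, factor_cc, factor_lr, factor_tme, "TME_2", "CC_1", "LR_3") are assumptions and should be replaced by the outputs of your own SCIITensor run:

            >>> mean_mt = scii_tensor_select(scii_tensor_list, factor_cc, factor_lr, factor_tme,
            ...                              interest_TME="TME_2", interest_cc_module="CC_1",
            ...                              interest_LR_module="LR_3", lr_number=20, cc_number=10)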
50 | 51 | ''' 52 | top_lrpair = factor_lr.apply(lambda x: _sort(x)) 53 | tme_cluster = factor_tme.idxmax(axis=1) 54 | sub_scii_matrix = [k for k,v in zip(scii_tensor_list, pd.Series(tme_cluster).isin([interest_TME])) if v] 55 | sub_mt = np.dstack(sub_scii_matrix) 56 | mean_sub_mt = np.mean(sub_mt, axis=2) 57 | mean_df = pd.DataFrame(mean_sub_mt, index=scii_tensor_list[0].index, columns=scii_tensor_list[0].columns) 58 | mean_mt = mean_df[top_lrpair[interest_LR_module][0:lr_number]].loc[_sort(factor_cc[interest_cc_module], cc_number).tolist()] 59 | return mean_mt 60 | 61 | def _cell_pairs_generate(cells:list, separator:str="_"): 62 | cell_pairs = {} 63 | for c1 in cells: 64 | for c2 in cells: 65 | cell_pairs[f"{c1}{separator}{c2}"] = (c1, c2) 66 | return cell_pairs 67 | 68 | def scii_interaction_matrix_process(interaction_matrix:pd.DataFrame, 69 | cell_lr_separator:str="|", 70 | ) -> tuple: 71 | ''' 72 | Input 73 | interaction_matrix: DataFrame. Matrix contains interesting interaction, index represents cell cell pairs 74 | while column represents ligand receptor pairs 75 | cells: List of cell type names, which will be used to separate cells of the cell pair. 76 | separator: Separator used to combine ligand with receptor genes into LR pairs. 77 | cell_lr_separator: Separator used to combine cells with LR genes. 78 | Return 79 | (sectors:dict, links:list, genes:set) 80 | sectors: The dictionary containing all vertices and their weight. {sender-ligand: value, receiver-receptor: value, ...} 81 | links: List containing all links between sectors. [[sender-ligand, receiver-receptor, value], ...] 82 | genes: Set contains names of all genes. 83 | ''' 84 | 85 | #Normalize the value into 0~100 86 | interaction_matrix = interaction_matrix.fillna(0).T 87 | 88 | scii_min, scii_max = interaction_matrix.min().min(), interaction_matrix.max().max() 89 | norm_mt = ((interaction_matrix - scii_min)/(scii_max - scii_min)*100).apply(round).astype(int) 90 | 91 | #Generate sectors and cell groups 92 | sectors1 = defaultdict() 93 | sectors2 = defaultdict() 94 | genes = set() 95 | cells = set() 96 | links = [] 97 | for index, row in norm_mt.iterrows(): 98 | sender, receiver = index 99 | cells.add(sender) 100 | cells.add(receiver) 101 | for LR, value in row.items(): 102 | if value == 0: 103 | continue 104 | ligand, receptor = LR 105 | v1 = f"{sender}{cell_lr_separator}{ligand}" 106 | v2 = f"{receiver}{cell_lr_separator}{receptor}" 107 | if sender not in sectors1.keys(): 108 | sectors1[sender] = defaultdict(int) 109 | if receiver not in sectors2.keys(): 110 | sectors2[receiver] = defaultdict(int) 111 | sectors1[sender][ligand] += value 112 | sectors2[receiver][receptor] += value 113 | genes.add(ligand) 114 | genes.add(receptor) 115 | links.append([v1, v2, value]) 116 | 117 | sectors = defaultdict(int) 118 | for cell in cells: 119 | if cell in sectors1.keys(): 120 | for ligand, value in sectors1[cell].items(): 121 | sectors[f"{cell}{cell_lr_separator}{ligand}"] += value 122 | if cell in sectors2.keys(): 123 | for receptor, value in sectors2[cell].items(): 124 | sectors[f"{cell}{cell_lr_separator}{receptor}"] += value 125 | return sectors, links, genes, list(cells) 126 | 127 | def scii_tensor_interaction_matrix_process(interaction_matrix:pd.DataFrame, 128 | cells:list, 129 | separator:str="-", 130 | cell_lr_separator:str="|", 131 | ) -> tuple: 132 | ''' 133 | Input 134 | interaction_matrix: DataFrame. 
Matrix contains interesting interaction in the TME region of interest, index represents cell-cell pairs 135 | while column represents ligand-receptor pairs 136 | cells: list. 137 | List of cell type names, which will be used to separate cells of the cell pair. 138 | separator: str. 139 | Separator used to combine ligand with receptor genes into LR pairs. 140 | cell_lr_separator: str. 141 | Separator used to combine cells with LR genes. 142 | Return 143 | (sectors:dict, links:list, genes:set) 144 | sectors: dict. 145 | The dictionary containing all vertices and their weight. {sender-ligand: value, receiver-receptor: value, ...} 146 | links: list. 147 | List containing all links between sectors. [[sender-ligand, receiver-receptor, value], ...] 148 | genes: set. 149 | Set contains names of all genes. 150 | ''' 151 | #Normalize the value into 0~100 152 | scii_min, scii_max = interaction_matrix.min().min(), interaction_matrix.max().max() 153 | norm_mt = ((interaction_matrix - scii_min)/(scii_max - scii_min)*100).apply(round).astype(int) 154 | 155 | #Generate a directory that contains names of cell types 156 | cell_pairs = _cell_pairs_generate(cells) 157 | 158 | #Generate sectors and cell groups 159 | sectors1 = defaultdict() 160 | sectors2 = defaultdict() 161 | genes = set() 162 | links = [] 163 | for index, row in norm_mt.iterrows(): 164 | sender, receiver = cell_pairs[index] 165 | for LR, value in row.items(): 166 | if value == 0: 167 | continue 168 | ligand, receptor = LR.split(separator, 1) 169 | v1 = f"{sender}{cell_lr_separator}{ligand}" 170 | v2 = f"{receiver}{cell_lr_separator}{receptor}" 171 | if sender not in sectors1.keys(): 172 | sectors1[sender] = defaultdict(int) 173 | if receiver not in sectors2.keys(): 174 | sectors2[receiver] = defaultdict(int) 175 | sectors1[sender][ligand] += value 176 | sectors2[receiver][receptor] += value 177 | genes.add(ligand) 178 | genes.add(receptor) 179 | links.append([v1, v2, value]) 180 | 181 | sectors = defaultdict(int) 182 | for cell in cells: 183 | if cell in sectors1.keys(): 184 | for ligand, value in sectors1[cell].items(): 185 | sectors[f"{cell}{cell_lr_separator}{ligand}"] += value 186 | if cell in sectors2.keys(): 187 | for receptor, value in sectors2[cell].items(): 188 | sectors[f"{cell}{cell_lr_separator}{receptor}"] += value 189 | return sectors, links, genes 190 | 191 | def cells_lr_circos(interaction_matrix:pd.DataFrame, 192 | cells:list=None, 193 | cell_colors='Set3', 194 | gene_colors='tab20', 195 | link_color_palette:str='tab20', 196 | separater:str="-", 197 | cell_lr_separator:str="|", 198 | label_size:float=8, 199 | dpi:float=300, 200 | scii_tensor:bool=False, 201 | save:str=None, 202 | ): 203 | ''' 204 | Input: 205 | interaction_matrix: pandas.DataFrame. 206 | The dataframe that contains the spatial cell interaction intensity(SCII) values of each interaction. 207 | The index represents cell_cell pairs and the column represents ligand_receptor pairs. 208 | cells: list. 209 | The list contains the names of all cell types. This will be used to separate the sender and receiver cells that were 210 | combined to create index names for the interaction_matrix. 211 | cell_colors: str|dict. 212 | Define colors of difference cell type. User can set matplotlib's colormap (e.g. viridis, jet, tab10) or label_name -> color dict (e.g. dict(A="red", B="blue", C="green", ...)). 213 | gene_colors: str|dict. 214 | Define colors of difference gene. User can set matplotlib's colormap (e.g. viridis, jet, tab10) or label_name -> color dict (e.g. 
dict(A="red", B="blue", C="green", ...)). 215 | link_color_palette: str|list. 216 | Define colors of difference link. User can set matplotlib's colormap (e.g. viridis, jet, tab10) or list contain colors. 217 | separator: str. 218 | Separator used to combine ligand with receptor genes into LR pairs. 219 | cell_lr_separator: str. 220 | Separator used to combine cells with LR genes. 221 | scii_tensor: bool. 222 | If the interaction_matrix was generated by scii, set the scii_tensor=True. 223 | save: str. 224 | File name of the figure that will be saved. 225 | Return: 226 | None 227 | ''' 228 | 229 | if scii_tensor: 230 | sectors, links, genes = scii_tensor_interaction_matrix_process(interaction_matrix, cells=cells, separator=separater, cell_lr_separator=cell_lr_separator) 231 | else: 232 | sectors, links, genes, cells = scii_interaction_matrix_process(interaction_matrix, cell_lr_separator=cell_lr_separator) 233 | 234 | 235 | cell_groups = defaultdict(list) 236 | for key in sectors.keys(): 237 | cell, gene = key.split(cell_lr_separator, 1) 238 | cell_groups[cell].append(key) 239 | group_sizes = [len(value) for key, value in cell_groups.items()] 240 | 241 | #Generate the color palette for cells, genes and links 242 | #cmap1 = plt.get_cmap('tab20b') 243 | #cmap2 = plt.get_cmap('tab20c') 244 | 245 | if isinstance(cell_colors, str): 246 | cell_color_palette = plt.get_cmap(cell_colors, len(cells)).colors 247 | cell_colors = dict(zip(cells, cell_color_palette[0:len(cells)])) 248 | elif isinstance(cell_colors, dict): 249 | cell_colors = cell_colors 250 | else: 251 | raise Exception("cell_colors must be string representing the matplotlib's colormap or dict") 252 | 253 | if isinstance(gene_colors, str): 254 | gene_color_palette = plt.get_cmap(gene_colors, len(genes)).colors 255 | gene_colors = dict(zip(sorted(list(genes)), gene_color_palette[0:len(genes)])) 256 | elif isinstance(gene_colors, dict): 257 | gene_colors = gene_colors 258 | else: 259 | raise Exception("gene_colors must be string representing the matplotlib's colormap or dict") 260 | 261 | 262 | lr_links = sorted(list(set([(x[0].split(cell_lr_separator, 1)[1], x[1].split(cell_lr_separator, 1)[1]) for x in links]))) 263 | if isinstance(link_color_palette, str): 264 | link_color_palette = plt.get_cmap(link_color_palette, len(lr_links)).colors 265 | elif isinstance(link_color_palette, list): 266 | if len(link_color_palette) < len(lr_links): 267 | raise Exception("the length of link_color_palette less than then number of lr_links: {0} < {1}".format(len(link_color_palette), len(lr_links))) 268 | link_colors = dict(zip(lr_links, link_color_palette[0:len(lr_links)])) 269 | 270 | spaces = calc_group_spaces(group_sizes, space_bw_group=10, space_in_group=1) 271 | circos = Circos(sectors, space=spaces) 272 | 273 | # Plot sector track 274 | #ColorCycler.set_cmap("Set3") 275 | for sector in circos.sectors: 276 | track = sector.add_track(r_lim=(90, 95)) 277 | track.axis(fc=gene_colors[sector.name.split(cell_lr_separator, 1)[1]]) 278 | #track.text(sector.name.split("-", 1)[1], fontsize=5, r=92, orientation="vertical") 279 | 280 | #ColorCycler.set_cmap("tab10") 281 | for cell, group in cell_groups.items(): 282 | group_deg_lim = circos.get_group_sectors_deg_lim(group) 283 | circos.rect(r_lim=(100, 103), deg_lim=group_deg_lim, fc=cell_colors[cell], ec="black", lw=0.5) 284 | group_center_deg = sum(group_deg_lim)/2 285 | circos.text(cell, r=106, deg=group_center_deg, adjust_rotation=True, fontsize=label_size) 286 | 287 | #Plot links 288 | for sender, receiver, 
value in links: 289 | ligand, receptor = sender.split(cell_lr_separator, 1)[1], receiver.split(cell_lr_separator, 1)[1] 290 | circos.link((sender, sectors[sender]-value, sectors[sender]), (receiver, sectors[receiver]-value, sectors[receiver]), color=link_colors[(ligand, receptor)], direction=1) 291 | sectors[sender] -= value 292 | sectors[receiver] -= value 293 | 294 | fig = circos.plotfig() 295 | #Plot legend 296 | rect_handles = [] 297 | for link, color in link_colors.items(): 298 | rect_handles.append(Patch(color=color, label=f"{link[0]}-{link[1]}")) 299 | rect_legend = circos.ax.legend( 300 | handles = rect_handles, 301 | bbox_to_anchor=(1.1, 1.0), 302 | fontsize=6, 303 | title="Ligand-Receptor", 304 | ) 305 | circos.ax.add_artist(rect_legend) 306 | 307 | scatter_handles = [] 308 | for gene, color in gene_colors.items(): 309 | scatter_handles.append(Line2D([], [], color=color, marker="o", label=gene, ms=6, ls="None")) 310 | scatter_legend = circos.ax.legend( 311 | handles=scatter_handles, 312 | bbox_to_anchor=(1.5, 1.0), 313 | fontsize=6, 314 | title="Gene", 315 | handlelength=2, 316 | ) 317 | if save != None: 318 | fig.savefig(save, dpi=dpi) 319 | 320 | def cells_circos(interaction_matrix:pd.DataFrame, 321 | cells:list=None, 322 | cell_colors='tab20', 323 | label_orientation:str="horizontal", 324 | label_size:float=10, 325 | dpi:float=300, 326 | save:str=None, 327 | scii_tensor:bool=False, 328 | ): 329 | ''' 330 | Input: 331 | interaction_matrix: pandas.DataFrame. 332 | The dataframe that contains the spatial cell interaction intensity(SCII) values of each interaction. 333 | The index represents cell_cell pairs and the column represents ligand_receptor pairs. 334 | cells: list. 335 | The list contains the names of all cell types. This will be used to separate the sender and receiver cells that were 336 | combined to create index names for the interaction_matrix. 337 | cell_colors : str | dict[str, str], optional 338 | Colormap assigned to each outer track and link. User can set matplotlib's colormap (e.g. viridis, jet, tab10) or label_name -> color dict (e.g. dict(A="red", B="blue", C="green", ...)) 339 | save: str. 340 | File name of the figure that will be saved. 341 | scii_tensor: bool. 342 | If the interaction_matrix was generated by scii, set the scii_tensor=True. 
343 | Return:
344 | None
345 | '''
346 | #Normalize the value into 0~100
347 | scii_min, scii_max = interaction_matrix.min().min(), interaction_matrix.max().max()
348 | norm_interaction_matrix = ((interaction_matrix - scii_min)/(scii_max - scii_min)*100).apply(round).astype(int)
349 |
350 | #Generate the matrix that will be used to draw circos
351 | if scii_tensor:
352 | if cells is None:
353 | raise Exception("When scii_tensor is True, the cells parameter must be given")
354 | if not isinstance(cells, list):
355 | cells = [cell for cell in cells]
356 | cci_df = norm_interaction_matrix.sum(axis=1).to_frame()
357 | cell_pairs = _cell_pairs_generate(cells)
358 | cci_df[['sender', 'receiver']] = [cell_pairs[x] for x in cci_df.index]
359 |
360 | else:
361 | cci_df = norm_interaction_matrix.T.sum(axis=1).to_frame()
362 | cci_df[['sender', 'receiver']] = [list(x) for x in cci_df.index]
363 |
364 | cells = list(set(cci_df['sender'].unique()) | set(cci_df['receiver'].unique()))
365 | cells_dict = dict(zip(cells, range(len(cells))))
366 | cci_df['sender_index'] = cci_df['sender'].map(cells_dict)
367 | cci_df['receiver_index'] = cci_df['receiver'].map(cells_dict)
368 |
369 | fromto_table_df = cci_df[['sender', 'receiver', 0]].rename(columns = {'sender': 'from', 'receiver': 'to', 0: 'value'}).reset_index(drop=True)
370 | matrix = Matrix.parse_fromto_table(fromto_table_df)
371 |
372 | #Draw circos
373 | circos = Circos.initialize_from_matrix(
374 | matrix,
375 | space=3,
376 | cmap=cell_colors,
377 | #ticks_interval=5,
378 | label_kws=dict(size=label_size, r=110, orientation=label_orientation),
379 | link_kws=dict(direction=1, ec='black', lw=0.5),
380 | )
381 | fig = circos.plotfig()
382 | if save != None:
383 | fig.savefig(save, dpi=dpi)
384 |
385 | def lr_circos(interaction_matrix:pd.DataFrame,
386 | cells:list,
387 | cmap='Set3',
388 | separator:str="-",
389 | label_orientation:str="vertical",
390 | label_size:float=6,
391 | dpi:float=300,
392 | save:str=None,
393 | scii_tensor=False,
394 | ):
395 | '''
396 | Input:
397 | interaction_matrix: The dataframe that contains the spatial cell interaction intensity(SCII) values of each interaction.
398 | The index represents cell_cell pairs and the column represents ligand_receptor pairs.
399 | cmap : str | dict[str, str], optional
400 | Colormap assigned to each outer track and link. User can set matplotlib's colormap (e.g. viridis, jet, tab10) or label_name -> color dict (e.g. dict(A="red", B="blue", C="green", ...))
401 | separator: str.
402 | Separator used to combine ligand with receptor genes into LR pairs.
403 | save: File name of the figure that will be saved.
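        Example:
            An illustrative sketch; `interaction_matrix` and `cell_types` are assumed inputs prepared as described above, not fixed package objects:

            >>> lr_circos(interaction_matrix, cells=cell_types, separator="-", save="lr_circos.pdf")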
404 | ''' 405 | if scii_tensor: 406 | _, links, _ = scii_tensor_interaction_matrix_process(interaction_matrix, cells, separator=separator) 407 | else: 408 | _, links, _ = scii_interaction_matrix_process(interaction_matrix) 409 | links_df = pd.DataFrame(links, columns = ['from', 'to', 'value']).sort_values(by = ['from', 'to']) 410 | matrix = Matrix.parse_fromto_table(links_df) 411 | circos = Circos.initialize_from_matrix( 412 | matrix, 413 | space=2, 414 | cmap=cmap, 415 | #ticks_interval=5, 416 | label_kws=dict(size=label_size, r=110, orientation=label_orientation), 417 | link_kws=dict(direction=1, ec='black', lw=0.5, alpha=0.5), 418 | ) 419 | fig = circos.plotfig() 420 | if save != None: 421 | fig.savefig(save, dpi=dpi) 422 | 423 | 424 | 425 | -------------------------------------------------------------------------------- /stereosite/plot/scii_net.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import pandas as pd 4 | import matplotlib.pyplot as plt 5 | import igraph as ig 6 | from igraph import Graph 7 | from .scii_circos import scii_interaction_matrix_process, scii_tensor_interaction_matrix_process, _cell_pairs_generate 8 | 9 | def lr_link_graph_generate(interaction_matrix:pd.DataFrame, 10 | cells:list=None, 11 | separator:str="-", 12 | cell_lr_separator:str="|", 13 | reducer:int=3, 14 | cell_colors="Set3", 15 | lr_color_palette:str="tab20", 16 | scii_tensor:bool=False, 17 | ) -> Graph: 18 | ''' 19 | Input: 20 | interaction_matrix: DataFrame. 21 | Matrix contains interesting interaction in the TME region of interest, index represents cell-cell pairs 22 | while column represents ligand-receptor pairs 23 | cells: List. 24 | cell type names, which will be used to separate cells of the cell pair. 25 | separator: str. 26 | Separator used to combine ligand with receptor genes into LR pairs. 27 | cell_lr_separator: str. 28 | Separator used to combine cells with LR genes. 29 | reducer: int. 30 | The size of a vertex = weight of the vertex/reducer. 31 | cell_colors: str|dict. 32 | Define colors of difference cell type. User can set matplotlib's colormap (e.g. viridis, jet, tab10) or label_name -> color dict (e.g. dict(A="red", B="blue", C="green", ...)). 33 | lr_palette 34 | lr_color_palette: str|list. 35 | Define cmap of links between ligand with receptor. 36 | scii_tensor: bool. 37 | If the interaction_matrix was generated by scii, set the scii_tensor=True. 38 | Return: 39 | g: Graph containing all vertices(ligand or receptor) and their links information. 
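        Example:
            An illustrative sketch; `mean_mt` is assumed to be the matrix returned by scii_tensor_select and `cell_types` a user-supplied list of cell type names:

            >>> g = lr_link_graph_generate(mean_mt, cells=cell_types, scii_tensor=True)
            >>> cell_lr_graph_plot(g, save="lr_net.pdf")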
40 |
41 | '''
42 |
43 | #transfer matrix into igraph
44 | if scii_tensor:
45 | sectors, links, genes = scii_tensor_interaction_matrix_process(interaction_matrix, cells, separator=separator, cell_lr_separator=cell_lr_separator)
46 | else:
47 | sectors, links, genes, cells = scii_interaction_matrix_process(interaction_matrix, cell_lr_separator=cell_lr_separator)
48 |
49 | #generate color palette for cells and links
50 | if isinstance(cell_colors, str):
51 | cell_colors = dict(zip(cells, [tuple(x[0:3]) for x in plt.get_cmap(cell_colors, len(cells)).colors]))
52 | elif isinstance(cell_colors, dict):
53 | cell_colors = cell_colors
54 | else:
55 | raise Exception("cell_colors must be a string representing a matplotlib colormap or a dict")
56 |
57 | vertices = list(sectors.keys())
58 | vertices_color = [cell_colors[x.split(cell_lr_separator)[0]] for x in vertices]
59 | vertices_dict = dict(zip(vertices, range(len(vertices))))
60 | edges_index = [(vertices_dict[x[0]], vertices_dict[x[1]]) for x in links]
61 | lr_links = sorted(set([(x[0].split(cell_lr_separator)[1], x[1].split(cell_lr_separator)[1]) for x in links]))
62 | if isinstance(lr_color_palette, str):
63 | link_color_palette = [tuple(x[0:3]) for x in plt.get_cmap(lr_color_palette, len(lr_links)).colors]
64 | elif isinstance(lr_color_palette, list):
65 | if len(lr_color_palette) < len(lr_links):
66 | raise Exception("the length of lr_color_palette is less than the number of lr_links: {0} < {1}".format(len(lr_color_palette), len(lr_links)))
67 | link_color_palette = lr_color_palette[0:len(lr_links)]
68 | link_colors = dict(zip(lr_links, link_color_palette))
69 | edges_color = [link_colors[(x[0].split(cell_lr_separator)[1], x[1].split(cell_lr_separator)[1])] for x in links]
70 |
71 | g = Graph(n=len(vertices), edges=edges_index, directed=True)
72 | g.vs['name'] = vertices
73 | g.vs['label'] = [x.split(cell_lr_separator, 1)[1] for x in vertices]
74 | g.vs['color'] = vertices_color
75 | g.vs['weight'] = [x/reducer if x>reducer*5 else 5 for x in sectors.values()]
76 | g.es['weight'] = [x[2] for x in links]
77 | g.es['color'] = edges_color
78 | return g
79 |
80 | def cell_graph_generate(interaction_matrix:pd.DataFrame,
81 | cells:list=None,
82 | reducer:int=10,
83 | cell_colors="Set3",
84 | scii_tensor:bool=False,
85 | ) -> Graph:
86 | '''
87 | Input:
88 | interaction_matrix: DataFrame.
89 | Matrix contains interesting interaction in the TME region of interest, index represents cell-cell pairs
90 | while column represents ligand-receptor pairs
91 | cells: List.
92 | cell type names, which will be used to separate cells of the cell pair.
97 | reducer: int.
98 | The size of a vertex = weight of the vertex/reducer.
99 | cell_colors: str|dict.
100 | Define colors of different cell types. User can set matplotlib's colormap (e.g. viridis, jet, tab10) or label_name -> color dict (e.g. dict(A="red", B="blue", C="green", ...)).
102 | scii_tensor: bool.
103 | If the interaction_matrix was generated by SCIITensor, set scii_tensor=True.
104 | Return:
105 | g: Graph whose vertices are cell types and whose edges describe the interactions between them.
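        Example:
            An illustrative sketch with the same assumed inputs as above (`mean_mt` from scii_tensor_select, `cell_types` supplied by the user):

            >>> g = cell_graph_generate(mean_mt, cells=cell_types, scii_tensor=True)
            >>> cell_graph_plot(g, save="cell_net.pdf")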
106 |
107 | '''
108 |
109 | #Normalize the value into 0~100
110 | scii_min, scii_max = interaction_matrix.min().min(), interaction_matrix.max().max()
111 | norm_interaction_matrix = ((interaction_matrix - scii_min)/(scii_max - scii_min)*100).apply(round).astype(int)
112 |
113 | if scii_tensor:
114 | if cells is None:
115 | raise Exception("When scii_tensor is True, the cells parameter must be given")
116 | if not isinstance(cells, list):
117 | cells = [cell for cell in cells]
118 | cci_df = norm_interaction_matrix.sum(axis=1).to_frame()
119 | cell_pairs = _cell_pairs_generate(cells)
120 | cci_df[['sender', 'receiver']] = [cell_pairs[x] for x in cci_df.index]
121 |
122 | else:
123 | cci_df = norm_interaction_matrix.T.sum(axis=1).to_frame()
124 | cci_df[['sender', 'receiver']] = [list(x) for x in cci_df.index]
125 |
126 | cells = list(set(cci_df['sender'].unique()) | set(cci_df['receiver'].unique()))
127 | cells_dict = dict(zip(cells, range(len(cells))))
128 | cci_df['sender_index'] = cci_df['sender'].map(cells_dict)
129 | cci_df['receiver_index'] = cci_df['receiver'].map(cells_dict)
130 | cci_df = cci_df.rename(columns = {'sender': 'from', 'receiver': 'to', 0: 'value'}).reset_index(drop=True)
131 |
132 | #generate color palette for cells and links
133 | if isinstance(cell_colors, str):
134 | cell_colors = dict(zip(cells, [tuple(x[0:3]) for x in plt.get_cmap(cell_colors, len(cells)).colors]))
135 | elif isinstance(cell_colors, dict):
136 | cell_colors = cell_colors
137 | else:
138 | raise Exception("cell_colors must be a string representing a matplotlib colormap or a dict")
139 |
140 | vertices = cells
141 | vertices_color = [cell_colors[x] for x in cells]
142 | vertices_weight = cci_df['value'].groupby(cci_df['from']).sum()
143 | edges_index = [tuple(x) for x in cci_df[['sender_index', 'receiver_index']].values]
144 | edges_color = [cell_colors[cell] for cell in cci_df['from'].values]
145 | g = Graph(n=len(vertices), edges = edges_index, directed=True)
146 | g.vs['name'] = vertices
147 | g.vs['label'] = vertices
148 | g.vs['color'] = vertices_color
149 | g.vs['weight'] = [vertices_weight[x]/reducer if vertices_weight[x]>reducer*5 else 5 for x in vertices]
150 | g.es['weight'] = cci_df['value'].values
151 | g.es['color'] = edges_color
152 |
153 | return g
154 |
155 | def cell_lr_graph_plot(g:Graph,
156 | separator:str='-',
157 | cell_lr_separator:str="|",
158 | layout_type:str='kk',
159 | save:str=None,
160 | vertex_label_angle=90,
161 | vertex_label_dist = 0,
162 | vertex_label_size = 8,
163 | edge_width=[0.5, 4],
164 | edge_curved=0.2,
165 | edge_arrow_size=10,
166 | edge_arrow_width=5,
167 | figsize:int=15,
168 | dpi:float=300,
169 | g_kwargs:dict={},
170 | ):
171 | '''
172 | Input:
173 | g: Graph.
174 | Graph object contains information of cell-cell communication.
175 | separator: str.
176 | Separator used to combine ligand with receptor genes into LR pairs.
177 | cell_lr_separator: str.
178 | Separator used to combine cells with LR genes.
179 | layout_type: str.
180 | Layout style used to draw the graph. Same as the layout parameter of igraph.
181 | save: str.
182 | File name of the figure that will be saved.
183 | g_kwargs: dict.
184 | Dictionary containing parameters of the igraph plot setting.
185 | Return: 186 | None 187 | ''' 188 | 189 | cell_colors = dict(zip([x.split(cell_lr_separator)[0] for x in g.vs['name']], g.vs['color'])) 190 | link_colors = dict(zip([f"{x.source_vertex['label']}{separator}{x.target_vertex['label']}" for x in g.es], g.es['color'])) 191 | 192 | fig, ax = plt.subplots() 193 | layout = g.layout(layout_type) 194 | layout.rotate(0) 195 | #draw graph 196 | ig.plot(g, 197 | layout = layout, 198 | vertex_size=g.vs['weight'], 199 | vertex_label_angle=vertex_label_angle, 200 | vertex_label_dist = vertex_label_dist, 201 | vertex_label_size = vertex_label_size, 202 | edge_width=edge_width, 203 | edge_curved=edge_curved, 204 | edge_arrow_size=edge_arrow_size, 205 | edge_arrow_width=edge_arrow_width, 206 | target=ax, 207 | **g_kwargs 208 | ) 209 | #generate legend 210 | cell_legend_handles = [] 211 | for cell, color in cell_colors.items(): 212 | handle = ax.scatter( 213 | [], [], 214 | s=100, 215 | facecolor=color, 216 | label=cell, 217 | ) 218 | cell_legend_handles.append(handle) 219 | l1 = ax.legend( 220 | handles=cell_legend_handles, 221 | title='Cell Type', 222 | bbox_to_anchor=(1.0, 1.0), 223 | bbox_transform=ax.transAxes, 224 | ) 225 | lr_legend_handles = [] 226 | for lr, color in link_colors.items(): 227 | handle = ax.scatter( 228 | [], [], 229 | s=100, 230 | facecolor=color, 231 | label=lr, 232 | marker = 's', 233 | ) 234 | lr_legend_handles.append(handle) 235 | ax.legend( 236 | handles=lr_legend_handles, 237 | title='Ligand-Receptor', 238 | bbox_to_anchor=(1.0, 0.6), 239 | bbox_transform=ax.transAxes, 240 | ) 241 | fig.gca().add_artist(l1) 242 | fig.set_size_inches(figsize, figsize) 243 | if save != None: 244 | fig.savefig(save, dpi=dpi) 245 | 246 | 247 | def cell_graph_plot(g:Graph, 248 | layout_type:str='kk', 249 | save:str=None, 250 | vertex_label_angle=90, 251 | vertex_label_dist = 0, 252 | vertex_label_size = 8, 253 | edge_width=[0.5, 4], 254 | edge_curved=0.2, 255 | edge_arrow_size=10, 256 | edge_arrow_width=5, 257 | figsize:float=15, 258 | dpi:float=300, 259 | g_kwargs:dict={}, 260 | ): 261 | ''' 262 | Input: 263 | g: Graph. 264 | Graph object contians information of cell cell communication. 265 | layout_type: Layout style used to draw the graph. Same to the layout of igraph 266 | layout_type: str. 267 | Layout style used to draw the graph. Same to the layout of igraph 268 | save: str. 269 | File name of the figure that will be saved. 270 | g_kwargs: dict. 271 | Dictionary containing parameters of igraph plot setting. 
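        Example:
            An illustrative sketch; `g` is assumed to come from cell_graph_generate, and extra igraph drawing options can be forwarded through g_kwargs:

            >>> cell_graph_plot(g, layout_type="circle", figsize=10, save="cell_net_circle.pdf")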
272 | ''' 273 | 274 | fig, ax = plt.subplots() 275 | layout = g.layout(layout_type) 276 | layout.rotate(0) 277 | #draw graph 278 | ig.plot(g, 279 | layout = layout, 280 | vertex_size=g.vs['weight'], 281 | vertex_label_angle=vertex_label_angle, 282 | vertex_label_dist = vertex_label_dist, 283 | vertex_label_size = vertex_label_size, 284 | edge_width=edge_width, 285 | edge_curved=edge_curved, 286 | edge_arrow_size=edge_arrow_size, 287 | edge_arrow_width=edge_arrow_width, 288 | target=ax, 289 | **g_kwargs 290 | ) 291 | fig.set_size_inches(figsize, figsize) 292 | if save != None: 293 | fig.savefig(save, dpi=dpi) -------------------------------------------------------------------------------- /stereosite/plot/scii_tensor.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import pandas as pd 5 | import seaborn as sns 6 | from mpl_toolkits.mplot3d import Axes3D 7 | import matplotlib as mpl 8 | from mpl_toolkits.axes_grid1.inset_locator import inset_axes 9 | from matplotlib import cm 10 | 11 | def core_heatmap(core, 12 | save:str=None, 13 | figsize:tuple=(10, 6), 14 | dpi=300, 15 | cmap:str='rocket', 16 | ): 17 | fig = plt.figure(111, figsize=figsize, facecolor='white') 18 | ax1 = fig.add_subplot(1, 2, 1, projection='3d') 19 | ax11 = fig.add_subplot(1, 2, 2) #color bar 20 | 21 | # Set the azimuth 22 | ax1.azim = -40 23 | ax1.dist = 9 24 | ax1.elev = 25 25 | 26 | # Draw grid 27 | x = np.arange(core.shape[0]) 28 | y = np.arange(core.shape[1]) 29 | X, Y = np.meshgrid(x, y) 30 | 31 | # Draw the 2D heatmap 32 | for module in range(core.shape[2]): 33 | paper_white = cm.ScalarMappable(cmap = cmap).to_rgba(core[:,:,module]) 34 | surf = ax1.plot_surface(X, np.zeros(shape=X.shape)+module, Y, rstride=1, cstride=1, facecolors=paper_white, linewidth=0, antialiased=True, alpha=1) 35 | 36 | # Draw main coordinates and set the ticks 37 | ax1.tick_params(axis='x', colors='k', labelsize=6) 38 | ax1.tick_params(axis='y', colors='k', labelsize=6) 39 | ax1.tick_params(axis='z', colors='k', labelsize=6) 40 | ax1.set_xlabel("LR Modules") 41 | ax1.set_ylabel("TME Modules") 42 | ax1.set_zlabel("CC Modules") 43 | ax1.set_xticks(range(core.shape[0])) 44 | ax1.set_yticks(range(core.shape[2])) 45 | ax1.set_zticks(range(core.shape[1])) 46 | 47 | # Set the color bar 48 | ax11.set_visible(False) 49 | axin11 = inset_axes(ax11, width="2%", height="75%", loc='center left', borderpad=0) 50 | axin11.tick_params(axis='y', labelsize=12) 51 | norm_cot = mpl.colors.Normalize(vmin=core.min(), vmax=core.max()) 52 | fig.colorbar(cm.ScalarMappable(norm=norm_cot,cmap=cmap), shrink=1.0, aspect=5, ax=ax11, cax=axin11) 53 | if save != None: 54 | plt.savefig(save, dpi=dpi) 55 | else: 56 | plt.show() 57 | plt.close() 58 | 59 | def tme_core_heatmap(core, 60 | tme_number:int, 61 | vmin:int=0, 62 | vmax:int=None, 63 | save:str=None, 64 | figsize:tuple=(5, 4), 65 | dpi:float=300, 66 | cmap:str='rocket', 67 | ): 68 | tme_df = pd.DataFrame(core[:, :, tme_number]) 69 | tme_df.columns = tme_df.columns.map(lambda x: f"LR_Module {x}") 70 | tme_df.index = tme_df.index.map(lambda x: f"CC_Module {x}") 71 | if vmax==None: 72 | vmax=tme_df.max().max() 73 | plt.figure(figsize=figsize) 74 | h = sns.heatmap(tme_df, #cmap='Purples', 75 | linewidths=0.005, linecolor='black', 76 | annot=False, 77 | cbar=False, 78 | vmin=vmin, 79 | vmax=vmax, 80 | cmap=cmap, 81 | ) 82 | cbar = h.figure.colorbar(h.collections[0]) 83 | ticks = [tme_df.min().min(), vmax] # 
tme7_df.max().max()] 84 | labels = ['Low', 'High'] 85 | cbar.set_ticks(ticks) 86 | cbar.set_ticklabels(labels) 87 | plt.title(f"TME Module {tme_number}") 88 | if save != None: 89 | plt.savefig(save, dpi=dpi) 90 | else: 91 | plt.show() 92 | plt.close() 93 | 94 | def interaction_heatmap(interaction_matrix:pd.DataFrame, 95 | linewidths:float=0.005, 96 | linecolor:str='black', 97 | vmax:int=None, 98 | save:str=None, 99 | figsize:tuple=(6, 1.5), 100 | dpi:float=300, 101 | cmap:str='rocket', 102 | ): 103 | if vmax==None: 104 | vmax=interaction_matrix.max().max() 105 | plt.figure(figsize=figsize) 106 | h = sns.heatmap(interaction_matrix, 107 | linewidths=linewidths, 108 | linecolor=linecolor, 109 | vmax=vmax, 110 | cbar=False, 111 | cmap=cmap, 112 | ) 113 | cbar = h.figure.colorbar(h.collections[0]) 114 | ticks = [0, vmax] 115 | labels = ['Low', 'High'] 116 | cbar.set_ticks(ticks) 117 | cbar.set_ticklabels(labels) 118 | if save != None: 119 | plt.savefig(save, dpi=dpi) 120 | else: 121 | plt.show() 122 | plt.close() 123 | 124 | def reconstruction_error_line(re_mat:np.ndarray, 125 | figsize:tuple=(4, 4), 126 | save:str=None, 127 | palette:str='tab20', 128 | dpi=300, 129 | ): 130 | 131 | num_TME_modules = re_mat.shape[0] 132 | colors = plt.get_cmap(palette, num_TME_modules-2).colors 133 | 134 | plt.figure(figsize=figsize) 135 | for j in range(3,num_TME_modules): 136 | plt.plot(np.arange(3, len(re_mat[j])),re_mat[j][3:],label = 'rank: TME={}'.format(j), color=colors[j-2]) 137 | plt.xlabel('ranks: CC, LR') 138 | plt.ylabel('reconstruction error') 139 | plt.legend(bbox_to_anchor=(1.04, 0.5), loc="center left") 140 | if save!=None: 141 | plt.savefig(save, bbox_inches='tight', dpi=dpi) 142 | else: 143 | plt.show() 144 | plt.close() -------------------------------------------------------------------------------- /stereosite/ppi/__init__.py: -------------------------------------------------------------------------------- 1 | """ The protein and protein interaction module """ -------------------------------------------------------------------------------- /stereosite/ppi/ppi_analysis.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | 4 | """ 5 | @File : ppi_analysis.py 6 | @Description : ppi_analysis.py 7 | @Author : Liuchuandong liuchuandong@genomics.cn 8 | @Date : 2023/08/01 14:35:03 9 | """ 10 | 11 | from gc import DEBUG_COLLECTABLE 12 | import os, sys 13 | from optparse import OptionParser 14 | import numpy as np 15 | import pandas as pd 16 | import networkx as nx 17 | import matplotlib 18 | import matplotlib.pyplot as plt 19 | from matplotlib.colors import ListedColormap, LinearSegmentedColormap 20 | import subprocess 21 | import logging 22 | 23 | PPI_outdir = 'PPI' 24 | os.makedirs(PPI_outdir, exist_ok=True) 25 | log = logging.getLogger(__name__) 26 | log.setLevel(logging.DEBUG) 27 | 28 | def mcl(query_net,inflation: float = 4.0): 29 | ''' 30 | Run Markov CLustering in query net 31 | 32 | Parameters 33 | ---------- 34 | query_net 35 | the ppi net queried from STRING database according to user input 36 | inflation 37 | the Markov CLuster algorithm's parparmeter, vary this parameter to obtain clusterings at different levels of granularity. A good set of starting values is 1.4, 2, 4, and 6. 
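    Example
        An illustrative sketch; `query_net` is assumed to be a networkx graph built with get_PPInet, and the `mcl` binary is assumed to be available on PATH:

        >>> mcl_result = mcl(query_net, inflation=4.0)
        >>> cluster_net = get_cluster_net(query_net, mcl_result, mcl_id=0)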
38 | ----------
39 | '''
40 | logging.critical("Make sure the Markov Cluster algorithm (MCL, https://github.com/micans/mcl) is installed properly.")
41 |
42 | nx.to_pandas_edgelist(query_net).to_csv(PPI_outdir+"/query_net_edgelist.txt",sep='\t',header=0,index=0)
43 | mcl_result_filepath = PPI_outdir+'/out.query_net_edgelist.txt.I'+str( int(inflation*10) )
44 |
45 | subprocess.run(['mcl', PPI_outdir+'/query_net_edgelist.txt','--abc','-I', str(inflation),'-o',mcl_result_filepath], capture_output=True)
46 | logging.info('MCL result is stored in: '+mcl_result_filepath )
47 | mcl_result = pd.read_csv(mcl_result_filepath, sep='\t', header=None, index_col=False)
48 |
49 | logging.info('Got {} clusters by Markov clustering'.format( len(mcl_result) ))
50 | logging.info('The largest cluster has {} nodes'.format( len(mcl_result.loc[0,].dropna()) ))
51 |
52 | return(mcl_result)
53 |
54 |
55 |
56 | def get_cluster_net(query_net,mcl_result,mcl_id=0):
57 | '''
58 | Subset a specific net from the query net according to the given MCL cluster id; the default is the largest cluster
59 | '''
60 | cluster_net = nx.subgraph(query_net,[i for i in list(mcl_result.loc[mcl_id,]) if str(i)!='nan'])
61 | return(cluster_net)
62 |
63 |
64 |
65 | def get_MCC_hub_genes(cluster_net):
66 | '''
67 | Find hub genes in a specific MCL cluster
68 | '''
69 | # 1. get genes by sorting MCC score
70 | node_MCC = nx.node_clique_number(cluster_net)
71 | rank_values = set(node_MCC.values())
72 | rank_values = sorted(rank_values,reverse=True)
73 | max_values = rank_values[0:10]
74 | print('Top10 MCC scores are:',max_values)
75 | max_MCC_genes = []
76 | # get genes with the top MCC scores; output 10 genes by default, or more if additional genes share the maximum score
77 | for i in max_values:
78 | for k,v in node_MCC.items():
79 | if v ==i: max_MCC_genes.append(k)
80 | if len(max_MCC_genes)>10:
81 | break
82 | # 2. get genes by sorting nodes degree
83 | top_degree_genes=[]
84 | rank_degree = sorted(cluster_net.degree, key=lambda x: x[1], reverse=True)
85 | degree = list(set([i[1] for i in rank_degree]))
86 | degree.sort(reverse=True)
87 | if len(degree)<10:
88 | degree_cutoff=degree[len(degree)-1]
89 | else:
90 | degree_cutoff=degree[9]
91 | for i in rank_degree:
92 | if i[1]>=degree_cutoff:
93 | top_degree_genes.append(i[0])
94 | print('Top10 degree are:',degree[0:10])
95 | # 3. get overlapped genes according to MCC and degree
96 | hub_genes = [i for i in max_MCC_genes if i in top_degree_genes]
97 | hub_genes_df = pd.DataFrame({'hub_genes':hub_genes})
98 | rank_degree = dict(rank_degree)
99 | hub_genes_df['degree'] = [rank_degree[i] for i in hub_genes]
100 | hub_genes_df['MCC_score'] = [node_MCC[i] for i in hub_genes]
101 | hub_genes_df = hub_genes_df.sort_values(by='degree',ascending=False,ignore_index=True)
102 | return(hub_genes_df)
103 |
104 |
105 |
106 | def get_hub_net(hub_genes,cluster_net,cutoff: float = 0.8):
107 | '''
108 | Subset the hub net of hub genes from a specific MCL cluster
109 |
110 | Parameters
111 | hub_genes
112 | The hub genes from a specific MCL cluster, obtained with the get_MCC_hub_genes function.
113 | cluster_net
114 | The specific MCL cluster.
115 | cutoff
116 | the edge's STRING combined-score cutoff.
117 | '''
118 | # 1.
remove edges which score1: 48 | dup_id = set(dup['#string_protein_id']) 49 | for i in dup_id: 50 | tmp=dup[dup['#string_protein_id']==i] 51 | # keep term which have preferred_name 52 | if len(tmp['preferred_name'].dropna())==1: 53 | rm_index = set(tmp.index).difference( set(tmp['preferred_name'].dropna().index) ) 54 | rm_index = list(rm_index) 55 | # if no preferred_name, keep first in all duplicated 56 | else: 57 | rm_index=tmp.index[list( range(1,len(tmp)) )] 58 | print("remove terms with duplicated protein id: ",flush=True) 59 | print(query_proteins.loc[rm_index,:],flush=True) 60 | query_proteins=query_proteins.drop(index=rm_index,axis=0) 61 | return query_proteins 62 | 63 | 64 | 65 | ################################################### 66 | ## 2. get STRING network according query term ## 67 | ################################################### 68 | def get_PPInet(query_term, protein_info, protein_alias,full_net,score_cutoff): 69 | ''' 70 | Get PPI network of query term from STRING database. 71 | 72 | Parameters 73 | ---------- 74 | query_term 75 | List of input genes. 76 | protein_info 77 | Protein information of STRING database. 78 | protein_alias 79 | Protein alias of STRING database. 80 | full_net 81 | Whole ppi network from STRING database. 82 | score_cutoff 83 | STRING combined score of protein-protein interaction, represent confidence level. 84 | ---------- 85 | 86 | Returns 87 | PPI network of input. 88 | ''' 89 | query_term = _query_input(query_term, protein_info, protein_alias) 90 | query_proteins = _get_all_matched_terms(query_term) 91 | 92 | query_net = full_net[full_net['combined_score']>=score_cutoff*1000] 93 | query_net = query_net[query_net['protein1'].isin(query_proteins['#string_protein_id'])] 94 | query_net = query_net[query_net['protein2'].isin(query_proteins['#string_protein_id'])] 95 | query_net = nx.from_pandas_edgelist(query_net,'protein1','protein2',edge_attr = 'combined_score') 96 | # relabel net 97 | proteins_id_to_query_dict = query_proteins 98 | proteins_id_to_query_dict.index = proteins_id_to_query_dict['#string_protein_id'] 99 | proteins_id_to_query_dict = proteins_id_to_query_dict.loc[:,0] 100 | proteins_id_to_query_dict = proteins_id_to_query_dict.to_dict() 101 | query_net = nx.relabel_nodes(query_net, proteins_id_to_query_dict, copy=False) 102 | return query_net 103 | 104 | 105 | -------------------------------------------------------------------------------- /stereosite/ppi/run_mcl.sh: -------------------------------------------------------------------------------- 1 | #export PATH=path/to/mcl/bin:$PATH 2 | mcl $1 --abc -I $2 3 | -------------------------------------------------------------------------------- /stereosite/read/__init__.py: -------------------------------------------------------------------------------- 1 | """ The gene expression matrix read module """ -------------------------------------------------------------------------------- /stereosite/read/gem.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python3
2 | # _*_ coding: utf-8 _*_
3 |
4 | #@Author: LiuXing liuxing2@genomics.cn
5 | #@Date: 2023-07-12 11:01:58
6 | #@Last Modified by: LiuXing
7 | #@Last Modified time: 2023-07-12 11:01:58
8 |
9 |
10 |
11 | import sys, os
12 |
13 | import scanpy as sc
14 | import anndata
15 | import pandas as pd
16 | import numpy as np
17 | import matplotlib.pyplot as plt
18 | import matplotlib as mpl
19 | import csv
20 | from scipy import sparse, stats
21 | import gzip
22 | import tifffile
23 | import cv2
25 | import h5py
26 |
27 | class Gem_Reader():
28 | """
29 | Class for transferring a gene expression matrix into anndata
30 |
31 | Parameters
32 | ----------
33 | gem_file
34 | gene expression matrix (gem or gef) file path.
35 | Gem file should include at least 4 columns: geneID, x, y, MIDCount. MIDCount indicates gene expression count, and can be specified by parameter count_key
36 | If cell segmentation has been done, the gem file consists of 5 columns: geneID, x, y, MIDCount, label. Transcripts with the same label belong to the same single cell.
37 | tissue_mask
38 | tissue mask file path, should be tiff format, used to extract gene expression data under the tissue covered region
39 | count_key
40 | column name of gene expression count in gem file
41 | cell_label_key
42 | column name of cell labels in gem file with single-cell segmentation.
43 | gene_name_key
44 | column name of gene name or gene ID in the gem file
45 | """
46 | def __init__(self,
47 | gem_file: str,
48 | tissue_mask: str = None,
49 | count_key: str = 'MIDCount',
50 | cell_label_key: str = 'label',
51 | gene_name_key: str = 'geneID',
52 | ):
53 | self.gem_file = gem_file
54 | self.tissue_mask = tissue_mask
55 | self.count_key = count_key
56 | self.cell_label_key = cell_label_key
57 | self.gene_name_key = gene_name_key
58 | if self.gem_file.endswith(("gem.gz", "gem")):
59 | self._read_gem()
60 | elif self.gem_file.endswith("gef"):
61 | self._read_gef()
62 | else:
63 | print ("The file format of the input gene expression file is incorrect. Only gem, gem.gz and gef files are supported.")
64 | sys.exit()
65 |
66 | def _read_gem(self) -> pd.DataFrame:
67 | """
68 | Read the gene expression matrix from the given gem file, and extract data under the tissue covered region if a tissue mask file was given.
69 | """
70 |
71 | columntypes = {self.gene_name_key: 'category',
72 | "x": int,
73 | 'y': int,
74 | self.count_key: int,
75 | }
76 |
77 | gem = pd.read_csv(self.gem_file, sep="\t", quoting=csv.QUOTE_NONE, comment="#", dtype=columntypes)
78 |
79 | #gem.rename(columns={self.count_key : 'counts',
80 | # self.cell_label_key: 'label'})
81 |
82 | #extract gene expression under tissue covered region
83 | if self.tissue_mask != None:
84 | tissue_mask = tifffile.imread(self.tissue_mask)
85 | maxY, maxX = tissue_mask.shape
86 | if gem['x'].max() >= maxX or gem['y'].max() >= maxY:
87 | print ("WARNING: gem coordinates exceed the mask bounds and will be clipped")
88 | gem = gem.loc[(gem['x'] < maxX)&(gem['y'] < maxY)]
89 | gem = gem.loc[tissue_mask[gem['y'], gem['x']] > 0]
90 | self.gem = gem
91 |
92 | def _read_gef(self) -> pd.DataFrame:
93 | """
94 | Read the gene expression matrix from the given gef file, and extract data under the tissue covered region if a tissue mask file was given.
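The gef file is assumed to expose the HDF5 datasets geneExp/bin1/expression and geneExp/bin1/gene (with gene name and offset fields), which are read below.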
95 | """ 96 | gef = h5py.File(self.gem_file, 'r') 97 | total_len = gef['geneExp']['bin1']['expression'].len() 98 | geneNames = gef['geneExp']['bin1']['gene'][self.gene_name_key].astype('str') 99 | offsets = gef['geneExp']['bin1']['gene']['offset'] 100 | expand_geneNames = [] 101 | for i in range(len(offsets)): 102 | if i == len(offsets)-1: 103 | extends = total_len - offsets[i] 104 | else: 105 | extends = offsets[i+1] - offsets[i] 106 | expand_geneNames.extend([geneNames[i]]*extends) 107 | gem = pd.DataFrame(gef['geneExp']['bin1']['expression'][:]) 108 | gem[self.gene_name_key] = pd.Series(expand_geneNames, dtype='category') 109 | if self.tissue_mask !=None: 110 | tissue_mask = tifffile.imread(self.tissue_mask) 111 | maxY, maxX = tissue_mask.shape 112 | if maxX > gem['x'].max() or maxY > gem['y'].max(): 113 | print ("WARMING: mask is out of bounds") 114 | gem = gem.loc[(gem['x']0] 116 | gef.close() 117 | self.gem = gem 118 | 119 | def gem2anndata(self, bin_size=50) -> anndata: 120 | """ 121 | transfer gem to anndata 122 | 123 | Parameters 124 | ---------- 125 | bin_size 126 | Specify bin size with this parameter. For example, bin_sie = 50 means bining spots in the same 50X50 square. 127 | 128 | Returns 129 | ---------- 130 | anndata 131 | contains gene expression vector and spatial coordinate of each bin. 132 | """ 133 | 134 | half_bin_size = int(bin_size/2) 135 | 136 | self.gem['x'] = (self.gem['x']//bin_size)*bin_size + half_bin_size 137 | self.gem['y'] = (self.gem['y']//bin_size)*bin_size + half_bin_size 138 | self.gem['cell'] = self.gem['x'].astype(str) + "-" + self.gem['y'].astype(str) 139 | 140 | cells = self.gem['cell'].unique() 141 | genes = self.gem[self.gene_name_key].unique() 142 | 143 | cells_dict = dict(zip(cells, range(0, len(cells)))) 144 | genes_dict = dict(zip(genes, range(0, len(genes)))) 145 | rows = self.gem['cell'].map(cells_dict) 146 | cols = self.gem[self.gene_name_key].map(genes_dict) 147 | expMtx = sparse.csr_matrix((self.gem[self.count_key].values, (rows, cols)), shape=(cells.shape[0], genes.shape[0]), dtype=np.int32) 148 | 149 | obs = pd.DataFrame(index = cells) 150 | var = pd.DataFrame(index = genes) 151 | adata = anndata.AnnData(X = expMtx, obs = obs, var = var) 152 | positions = np.array(list(map(lambda x: [int(v) for v in x.strip().split("-")], adata.obs.index))) 153 | adata.obs['x'] = positions[:,0] 154 | adata.obs['y'] = positions[:,1] 155 | adata.obsm['spatial'] = adata.obs[['x', 'y']].values 156 | return adata 157 | 158 | def cellbin2anndata(self) -> anndata: 159 | """ 160 | transfer gem with single-cell segmentation to anndata 161 | 162 | Returns 163 | ----------- 164 | anndata 165 | contains gene expression vector and sptial coordinate of each cell 166 | """ 167 | 168 | gem = self.gem[self.gem[self.cell_label_key]!=0] 169 | 170 | cells = gem[self.cell_label_key].unique() 171 | genes = gem[self.gene_name_key].unique() 172 | cells_dict = dict(zip(cells, range(0, len(cells)))) 173 | genes_dict = dict(zip(genes, range(0, len(genes)))) 174 | rows = gem[self.cell_label_key].map(cells_dict) 175 | cols = gem[self.gene_name_key].map(genes_dict) 176 | 177 | expMtx = sparse.csr_matrix((gem[self.count_key].values, (rows, cols)), shape=(cells.shape[0], genes.shape[0]), dtype=np.int32) 178 | 179 | obs = pd.DataFrame(index = cells) 180 | var = pd.DataFrame(index = genes) 181 | adata = anndata.AnnData(X = expMtx, obs = obs, var = var) 182 | spatialgroup = gem[['x', 'y']].groupby(gem[self.cell_label_key]) 183 | spatialdf = spatialgroup.agg(lambda x: (x.max()+x.min())/2) 
184 | spatialdf = spatialdf.reset_index() 185 | spatialdf[self.cell_label_key] = spatialdf[self.cell_label_key].astype('category').cat.reorder_categories(cells, ordered=True) 186 | spatialdf.sort_values(self.cell_label_key, inplace=True) 187 | adata.obsm['spatial'] = spatialdf[['x', 'y']].values 188 | return adata 189 | 190 | def gem_with_cellmask_2anndata(self, cell_mask: str) -> anndata: 191 | """ 192 | extract single-cell gem based on the cell mask and transfer it to anndata 193 | 194 | Parameters 195 | ----------- 196 | cell_mask 197 | cell mask file path, should be tiff format, used to extract gene expression of each single-cell 198 | 199 | Returns 200 | ----------- 201 | anndata 202 | contains gene expression vector and sptial coordinate of each cell 203 | """ 204 | 205 | mask = cv2.imread(cell_mask, -1) 206 | if (mask.max() == 1): 207 | _, labels = cv2.connectedComponents(mask) 208 | else: 209 | labels = mask 210 | tissuedf = pd.DataFrame() 211 | dst = np.nonzero(labels) 212 | 213 | tissuedf['x'] = dst[1] 214 | tissuedf['y'] = dst[0] 215 | tissuedf[self.cell_label_key] = labels[dst] 216 | 217 | res = pd.merge(self.gem, tissuedf, on=['x', 'y'], how='inner') 218 | self.gem = res 219 | adata = self.cellbin2anndata() 220 | return adata 221 | -------------------------------------------------------------------------------- /stereosite/scii.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # _*_ coding: utf-8 _*_ 3 | 4 | #@Author: LiuXing liuxing2@genomics.cn 5 | #@Date: 2023-07-12 15:29:08 6 | #@Last Modified by: LiuXing 7 | #@Last Modified time: 2023-07-12 15:29:08 8 | 9 | import os, sys 10 | import scanpy as sc 11 | import squidpy as sq 12 | import anndata 13 | import matplotlib.pyplot as plt 14 | import pandas as pd 15 | import numpy as np 16 | from scipy import sparse 17 | from itertools import product 18 | import math 19 | from tqdm.notebook import tqdm 20 | from multiprocessing import Pool 21 | from optparse import OptionParser 22 | import logging 23 | 24 | LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s" 25 | logging.basicConfig(level=logging.INFO, format=LOG_FORMAT) 26 | 27 | CODE_NUMBER=10000 28 | 29 | def _m2h_homologene(adata: anndata): 30 | """ 31 | homologous transition gene name from mouse to human 32 | 33 | Parameters 34 | ---------- 35 | adata 36 | anndata 37 | """ 38 | # Biomart tables 39 | biomart_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "datasets/biomart") 40 | #h2m_tab = pd.read_csv(os.path.join(biomart_path, "human_to_mouse_biomart_export.csv")).set_index('Gene name') 41 | m2h_tab = pd.read_csv(os.path.join(biomart_path, "mouse_to_human_biomart_export.csv")).set_index('Gene name') 42 | 43 | hdict = m2h_tab[~m2h_tab['Human gene name'].isnull()]['Human gene name'].to_dict() 44 | adata.var['original_gene_symbol'] = adata.var_names 45 | adata.var.index = [hdict[x] if x in hdict.keys() else x.upper() for x in adata.var_names] 46 | adata.var_names_make_unique() 47 | return adata 48 | 49 | def _mh_translate(adata:anndata): 50 | if any([x.startswith("Gm") for x in adata.var.index]): 51 | adata = _m2h_homologene(adata) 52 | #self.adata.var_names_make_unique() 53 | return adata 54 | 55 | def _hm_translate(adata, LRpairs): 56 | if "original_gene_symbol" in adata.var.keys(): 57 | #tanslate LRpairs from human genes to mouse genes 58 | hmdict = adata.var['original_gene_symbol'].to_dict() 59 | mouse_LRpairs = [] 60 | for LRpair in LRpairs: 61 | for gene in LRpair: 62 | if "_" in gene: 
63 | mouse_LRpairs.append("_".join([hmdict[z] if z in hmdict.keys() else z for z in gene.split("_")])) 64 | else: 65 | mouse_LRpairs.append(hmdict[gene] if gene in hmdict.keys() else gene) 66 | return mouse_LRpairs 67 | else: 68 | return LRpairs 69 | 70 | def _generate_LRpairs(interactionDB:str, 71 | adata:anndata, 72 | LR_anno:str='annotation'): 73 | """ 74 | process L-R database. 75 | """ 76 | file_name = os.path.basename(interactionDB) 77 | SOURCE = 'source' 78 | TARGET = 'target' 79 | if not "MOUSE" in file_name.upper(): 80 | #translate mouse gene symbol to human gene symbol 81 | adata = _mh_translate(adata) 82 | logging.info(f"gene homology translation finished.") 83 | 84 | if "CELLCHATDB" in file_name.upper(): 85 | interactions = pd.read_csv(interactionDB) 86 | interactions[SOURCE] = interactions['ligand'] 87 | interactions[TARGET] = interactions['interaction_name_2'].str.extract('- (.+)', expand=False).map(lambda x: x.strip().strip("(").strip(")").replace("+", "_")) 88 | else: 89 | interactions = pd.read_csv(interactionDB) 90 | if SOURCE in interactions.columns: 91 | interactions.pop(SOURCE) 92 | if TARGET in interactions.columns: 93 | interactions.pop(TARGET) 94 | interactions.rename( 95 | columns={"genesymbol_intercell_source": SOURCE, "genesymbol_intercell_target": TARGET}, inplace=True 96 | ) 97 | interactions[SOURCE] = interactions[SOURCE].str.replace("^COMPLEX:", "", regex=True) 98 | interactions[TARGET] = interactions[TARGET].str.replace("^COMPLEX:", "", regex=True) 99 | 100 | LRpairs_dict = dict() 101 | if LR_anno in interactions.columns: 102 | LR_types = interactions[LR_anno].unique() 103 | for LR_type in LR_types: 104 | LRpairs_dict[LR_type] = interactions[interactions[LR_anno] == LR_type][[SOURCE, TARGET]].drop_duplicates().values 105 | else: 106 | LRpairs_dict['All'] = interactions[[SOURCE, TARGET]].drop_duplicates().values 107 | LRpairs = interactions[[SOURCE, TARGET]].drop_duplicates().values 108 | LRlist = [] 109 | filter_LRpairs = dict() 110 | for LR_type, LRpairs in LRpairs_dict.items(): 111 | filter_LRpairs[LR_type] = list() 112 | for LRpair in LRpairs: 113 | ligand, receptor = LRpair 114 | ligand_subs = ligand.split("_") 115 | receptor_subs = receptor.split("_") 116 | genes = ligand_subs + receptor_subs 117 | if all(g in adata.var_names for g in genes): 118 | filter_LRpairs[LR_type].append(LRpair) 119 | LRlist.extend(genes) 120 | adata = adata[:, adata.var_names.isin(LRlist)] 121 | return filter_LRpairs, adata 122 | 123 | def _generate_neighbors_graph(adata:anndata, 124 | radius = 400, #int or dict 125 | anno:str = "cell2loc_anno"): 126 | if isinstance(radius, int): 127 | sq.gr.spatial_neighbors(adata, radius=radius, coord_type="generic", key_added=str(radius)) 128 | elif isinstance(radius, dict): 129 | for LR_type, r in radius.items(): 130 | sq.gr.spatial_neighbors(adata, radius=r, coord_type='generic', key_added=str(r)) 131 | else: 132 | raise Exception("radius type must be int or dictionary, but get: {0}".format(type(radius))) 133 | #sq.gr.nhood_enrichment(adata, cluster_key= anno) 134 | #np.nan_to_num(adata.uns[f"{anno}_nhood_enrichment"]['zscore'], copy=False) 135 | return adata 136 | 137 | def preprocess_adata(adata:anndata, 138 | interactionDB:str, 139 | sample_number:int = 1000000, 140 | seed:int = 101, 141 | radius = 200, #int or dict 142 | scale:float = 0.5, 143 | LR_anno:str = "annotation", 144 | anno:str = "cell2loc_anno", 145 | use_raw:bool = True): 146 | 147 | #adata = anndata.read(adata_file) 148 | if adata.raw and use_raw: 149 | adata = 
136 | 
137 | def preprocess_adata(adata:anndata,
138 |                      interactionDB:str,
139 |                      sample_number:int = 1000000,
140 |                      seed:int = 101,
141 |                      radius = 200, #int or dict
142 |                      scale:float = 0.5,
143 |                      LR_anno:str = "annotation",
144 |                      anno:str = "cell2loc_anno",
145 |                      use_raw:bool = True):
146 | 
147 |     #adata = anndata.read(adata_file)
148 |     if adata.raw and use_raw:
149 |         adata = adata.raw.to_adata()
150 |     #subsample the data down to at most sample_number cells
151 |     sample_rate = 1
152 |     if adata.shape[0] > sample_number:
153 |         n_obs = adata.shape[0]
154 |         sample_rate = float(sample_number)/float(n_obs)
155 |         obs_index = adata.obs.index.values.copy()
156 |         np.random.seed(seed) #set the random seed (default 101) for reproducible subsampling
157 |         np.random.shuffle(obs_index)
158 |         adata = adata[obs_index[0:sample_number],]
159 |         #sc.pp.subsample(self.adata, n_obs=sample_number)
160 |         logging.info("subsampled the original data; the subsample shape is {0}".format(adata.shape))
161 | 
162 |     #generate the LRpairs list from the interaction database
163 |     LRpairs, adata = _generate_LRpairs(interactionDB, adata, LR_anno=LR_anno)
164 |     LRpair_number = sum([len(x) for x in LRpairs.values()])
165 |     logging.info(f"LR pair generation finished; got {LRpair_number} LR pairs")
166 |     #change the radius unit from um to DNB and check whether the LR types given by the radius parameter contain all LR types in the interaction database
167 |     if isinstance(radius, int):
168 |         radius = int(radius/scale)
169 |     elif isinstance(radius, dict):
170 |         for LRtype, r in radius.items():
171 |             radius[LRtype] = int(r/scale)
172 |         if not set(LRpairs.keys()).issubset(set(radius.keys())):
173 |             raise Exception("LR types given by parameter radius: {0} don't contain all LR types in the interaction database: {1}".format(radius.keys(), LRpairs.keys()))
174 |     else:
175 |         raise Exception("The type of the radius threshold parameter must be int or dict, but we got: {0}".format(type(radius)))
176 |     adata = _generate_neighbors_graph(adata, radius=radius, anno = anno)
177 |     return adata, LRpairs, sample_rate
178 | 
179 | def _get_LR_connect_matrix(adata:anndata,
180 |                            LRpair:list,
181 |                            connect_matrix:sparse.csr_matrix,
182 |                            complex_process:str = 'mean',
183 |                            distance_coefficient = 0,
184 |                            ) -> sparse.coo_matrix:
185 |     ligand, receptor = LRpair
186 |     if "_" in ligand:
187 |         ligands = ligand.split("_")
188 |         if complex_process.upper() == 'MEAN':
189 |             exp_l = adata[:, ligands].X.mean(axis=1).A1
190 |         elif complex_process.upper() == 'MIN':
191 |             exp_l = adata[:, ligands].X.min(axis=1).toarray()[:,0]
192 |         else:
193 |             raise Exception("complex process model must be mean or min, but got: {0}".format(complex_process))
194 |     else:
195 |         exp_l = adata[:, ligand].X.toarray()[:,0]
196 |     if "_" in receptor:
197 |         receptors = receptor.split("_")
198 |         if complex_process.upper() == 'MEAN':
199 |             exp_r = adata[:, receptors].X.mean(axis=1).A1
200 |         elif complex_process.upper() == 'MIN':
201 |             exp_r = adata[:, receptors].X.min(axis=1).toarray()[:,0]
202 |         else:
203 |             raise Exception("complex process model must be mean or min, but got: {0}".format(complex_process))
204 |     else:
205 |         exp_r = adata[:, receptor].X.toarray()[:,0]
206 |     l_rows = np.where(exp_l > 0)[0]
207 |     r_cols = np.where(exp_r > 0)[0]
208 |     sub_connect_matrix = connect_matrix[l_rows,:][:,r_cols].todense()
209 |     dst = np.where(sub_connect_matrix > 0)
210 |     distances = sub_connect_matrix[dst]
211 |     connect_exp_lr = exp_l[l_rows[dst[0]]]*[math.exp(-distance_coefficient*d) for d in distances.A1] + exp_r[r_cols[dst[1]]]
212 |     exp_connect_matrix = sparse.coo_matrix((connect_exp_lr, (l_rows[dst[0]], r_cols[dst[1]])), shape=connect_matrix.shape)
213 |     return exp_connect_matrix
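# A worked toy example of the distance-decay weighting used in
# _get_LR_connect_matrix above: a connected sender/receiver pair contributes
# exp_l * e^(-k*d) + exp_r to the interaction intensity. The expression values
# and the distance below are arbitrary assumptions.
def _example_distance_decay():
    exp_l, exp_r, d = 2.0, 3.0, 100.0
    for k in (0, 0.01, 0.1):
        weight = exp_l * math.exp(-k * d) + exp_r
        print(f"k={k}: contribution = {weight:.3f}")
    # k=0 keeps the ligand term untouched (2 + 3 = 5); a larger k shrinks the
    # ligand term towards 0 as the sender-receiver distance grows.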
214 | 
215 | def _get_LR_intensity(exp_connect_matrix:sparse.coo_matrix,
216 |                       cellTypeIndex:np.array,
217 |                       cellTypeNumber:int
218 |                       ) -> np.matrix:
219 |     senders = cellTypeIndex[exp_connect_matrix.row]
220 |     receivers = cellTypeIndex[exp_connect_matrix.col]
221 |     interaction_matrix = sparse.csr_matrix((exp_connect_matrix.data, (senders, receivers)), shape=(cellTypeNumber, cellTypeNumber))
222 | 
223 |     return interaction_matrix.todense()
224 | 
225 | def _permutation_test(interaction_matrix:np.matrix,
226 |                       exp_connect_matrix:sparse.coo_matrix,
227 |                       cellTypeIndex:np.array,
228 |                       cellTypeNumber:int,
229 |                       n_perms = 1000,
230 |                       seed = 101
231 |                       ) -> np.array:
232 |     pvalues = np.zeros((cellTypeNumber, cellTypeNumber), dtype=np.int32) #permutation counts; converted to frequencies below
233 |     for i in range(n_perms):
234 |         cellTypeIndex_tmp = cellTypeIndex.copy()
235 |         rs = np.random.RandomState(None if seed is None else i + seed)
236 |         rs.shuffle(cellTypeIndex_tmp)
237 |         interaction_matrix_tmp = _get_LR_intensity(exp_connect_matrix, cellTypeIndex_tmp, cellTypeNumber)
238 |         pvalues += np.where(interaction_matrix_tmp > interaction_matrix, 1, 0)
239 |     pvalues = pvalues/float(n_perms)
240 |     pvalues[interaction_matrix == 0] = np.nan
241 |     return pvalues
242 | 
243 | def _LRpair_process(adata:anndata,
244 |                     LRpair:np.array,
245 |                     connect_matrix:sparse.coo_matrix,
246 |                     cellTypeIndex:np.array,
247 |                     cellTypeNumber:int,
248 |                     seed:int,
249 |                     n_perms:int,
250 |                     complex_process:str = 'mean',
251 |                     distance_coefficient = 0,
252 |                     ) -> tuple: #(intensity matrix, p values)
253 |     exp_connect_matrix = _get_LR_connect_matrix(adata, LRpair, connect_matrix, complex_process = complex_process, distance_coefficient=distance_coefficient)
254 |     interaction_matrix = _get_LR_intensity(exp_connect_matrix, cellTypeIndex, cellTypeNumber)
255 |     pvalues = _permutation_test(interaction_matrix, exp_connect_matrix, cellTypeIndex, cellTypeNumber, seed = seed, n_perms=n_perms)
256 |     return interaction_matrix, pvalues
257 | 
258 | def _result_combined(result_list:list,
259 |                      LRpairs:np.array,
260 |                      cell_types:np.array
261 |                      ) -> pd.DataFrame:
262 | 
263 |     columns = list(product(cell_types, repeat=2))
264 |     my_columns = pd.MultiIndex.from_tuples(columns, names=['cluster1', 'cluster2'])
265 |     my_index = pd.MultiIndex.from_tuples(LRpairs, names=["source", "target"])
266 |     values = [np.ravel(x) for x in result_list]
267 |     result_df = pd.DataFrame(np.vstack(values), index = my_index, columns = my_columns)
268 |     return result_df
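# A minimal sketch of the label-permutation test implemented in _permutation_test:
# cell-type labels are shuffled, the cell type x cell type intensity matrix is
# recomputed, and the p-value is the fraction of permutations whose intensity
# exceeds the observed one. The tiny connection matrix and labels are toy
# assumptions, not real data.
def _example_permutation_test():
    exp_connect = sparse.coo_matrix(np.array([[0.0, 2.5, 0.0],
                                              [0.0, 0.0, 1.5],
                                              [0.0, 0.0, 0.0]]))
    labels = np.array([0, 1, 1])  # three cells belonging to two cell types
    observed = _get_LR_intensity(exp_connect, labels, 2)
    pvalues = _permutation_test(observed, exp_connect, labels, 2, n_perms=100, seed=101)
    return observed, pvalues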
269 | 
270 | def intensities_count(adata:anndata,
271 |                       interactionDB:str,
272 |                       distance_threshold = 200, #int or dict
273 |                       distance_coefficient = 0, #float or dict
274 |                       scale:float = 0.5,
275 |                       LR_anno:str = "annotation",
276 |                       anno:str = "cell_type",
277 |                       seed:int = 101,
278 |                       n_perms:int = 1000,
279 |                       use_raw:bool = True,
280 |                       jobs:int = 1,
281 |                       complex_process_model:str = 'mean',
282 |                       ) -> dict:
283 |     """
284 |     Calculate the intensities of interactions between all cell type pairs and ligand-receptor pairs.
285 | 
286 |     Parameters
287 |     ----------
288 |     adata
289 |         anndata
290 |     interactionDB
291 |         file that stores the ligand-receptor pairs
292 |     distance_threshold
293 |         only cell pairs with a distance shorter than distance_threshold will be connected when constructing the nearest-neighbor graph.
294 |         default=200. The unit is µm.
295 |         If the ligand-receptor pairs have been clustered into different types, distance_threshold can receive a dictionary with
296 |         LR types and corresponding distance thresholds. e.g: {'Secreted Signaling': 200, 'ECM-Receptor': 200, 'Cell-Cell Contact': 30}
297 |     distance_coefficient
298 |         Consider the distance as one of the factors that influence the interaction intensity, using the exponential decay formula: C=C0*e^(-k*d).
299 |         The parameter defines the k value in the formula. Default=0, which means distance does not influence the interaction intensity.
300 |         If the ligand-receptor pairs have been clustered into different types, distance_coefficient can receive a dictionary with
301 |         LR types and corresponding coefficients. e.g: {'Secreted Signaling': 1, 'ECM-Receptor': 0.1, 'Cell-Cell Contact': 0}
302 |     scale
303 |         The distance between adjacent spots, in µm. For Stereo-seq chips, scale=0.5. default=0.5
304 |     LR_anno
305 |         The name of the column that contains the LR type annotation, default=annotation.
306 |     anno
307 |         cell type annotation key, default=cell_type
308 |     seed
309 |         seed used when randomly shuffling cell type labels in the permutation test, which generates a null distribution
310 |         to calculate the p value of each interaction. default=101
311 |     n_perms
312 |         number of permutations. default=1000
313 |     use_raw
314 |         bool value that determines whether to use the raw data stored in anndata. default=True
315 |     jobs
316 |         when jobs > 1, the program will use multiple processes to analyze the data. jobs=1 is fast enough for most tasks. default=1
317 |     complex_process_model
318 |         determine how to handle complexed ligands and receptors that contain multiple subunits. There are two options: mean, min.
319 |         mean: calculate the mean expression of all subunits to represent the complex
320 |         min: pick the minimal expression of all subunits to represent the complex
321 | 
322 |     Returns
323 |     ----------
324 |     dictionary containing the intensity and p-value results
325 |         {'intensities': intensities in DataFrame,
326 |          'pvalues': pvalues in DataFrame
327 |         }
328 |     """
329 | 
330 |     adata, LRpairs, sample_rate = preprocess_adata(adata, interactionDB, radius=distance_threshold, scale=scale, seed = seed, LR_anno=LR_anno, anno=anno, use_raw=use_raw)
331 |     #connect_matrix = adata.obsp['spatial_connectivities']
332 |     cell_types = adata.obs[anno].unique()
333 |     cell_type_dict = dict(zip(cell_types, range(0, len(cell_types))))
334 |     cellTypeIndex = adata.obs[anno].map(cell_type_dict).astype(int).values
335 |     cellTypeNumber = len(cell_types)
336 | 
337 |     logging.info("interaction intensity count begins.")
338 |     results = []
339 |     intensities_list = []
340 |     pvalues_list = []
341 |     LRpairs_lists = []
342 |     if (jobs == 1):
343 |         for LRtype, LRpair_list in LRpairs.items():
344 |             logging.info("compute the interaction intensity of LR pairs of type {0}".format(LRtype))
345 |             if isinstance(distance_threshold, (int, float)):
346 |                 connect_matrix = adata.obsp[f"{int(distance_threshold/scale)}_distances"]
347 |             elif isinstance(distance_threshold, dict):
348 |                 connect_matrix = adata.obsp[f"{int(distance_threshold[LRtype])}_distances"]
349 |             else:
350 |                 raise Exception("the type of distance_threshold must be int, float or dict, but got: {0}".format(type(distance_threshold)))
351 |             if isinstance(distance_coefficient, (int, float)):
352 |                 k = distance_coefficient
353 |             elif isinstance(distance_coefficient, dict):
354 |                 k = distance_coefficient[LRtype]
355 |             else:
356 |                 raise Exception("the type of distance_coefficient must be int, float or dict, but got: {0}".format(type(distance_coefficient)))
357 |             for LRpair in tqdm(LRpair_list):
358 |                 LRpairs_lists.append(LRpair)
359 |                 results.append(_LRpair_process(adata, LRpair, connect_matrix, cellTypeIndex, cellTypeNumber, seed, n_perms, complex_process=complex_process_model, distance_coefficient=k))
360 |     else:
361 |         pool = Pool(jobs)
362 |         for LRtype, LRpair_list in LRpairs.items():
363 |             logging.info("compute the interaction intensity of LR pairs of type {0}".format(LRtype))
364 |             if isinstance(distance_threshold, (int, float)):
365 |                 connect_matrix = adata.obsp[f"{int(distance_threshold/scale)}_distances"]
366 |             elif isinstance(distance_threshold, dict):
367 |                 connect_matrix = adata.obsp[f"{int(distance_threshold[LRtype])}_distances"]
368 |             else:
369 |                 raise Exception("the type of distance_threshold must be int, float or dict, but got: {0}".format(type(distance_threshold)))
370 |             if isinstance(distance_coefficient, (int, float)):
371 |                 k = distance_coefficient
372 |             elif isinstance(distance_coefficient, dict):
373 |                 k = distance_coefficient[LRtype]
374 |             else:
375 |                 raise Exception("the type of distance_coefficient must be int, float or dict, but got: {0}".format(type(distance_coefficient)))
376 |             #LR pairs are collected inside the loop below; extending the list here as well would duplicate the result index
377 |             for LRpair in tqdm(LRpair_list):
378 |                 LRpairs_lists.append(LRpair)
379 |                 results.append(pool.apply_async(_LRpair_process, (adata, LRpair, connect_matrix, cellTypeIndex, cellTypeNumber, seed, n_perms, complex_process_model, k,)))
380 |         pool.close()
381 |         pool.join()
382 |     logging.info("interaction intensity count finished.")
383 |     LRpairs_lists = _hm_translate(adata, LRpairs_lists)
384 |     logging.info("begin combining results")
385 |     for result in results:
386 |         if (jobs == 1):
387 |             intensities, pvalues = result
388 |         else:
389 |             intensities, pvalues = result.get()
390 |         intensities_list.append(intensities)
391 |         pvalues_list.append(pvalues)
392 |     intensity_df = _result_combined(intensities_list, LRpairs_lists, cell_types)
393 |     pvalues_df = _result_combined(pvalues_list, LRpairs_lists, cell_types)
394 |     intensity_df = intensity_df/sample_rate
395 |     logging.info("result combining finished.")
396 | 
397 |     plot_data = {'intensities': intensity_df,
398 |                  'pvalues': pvalues_df}
399 |     return plot_data
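# A minimal usage sketch of intensities_count. The paths refer to the demo data
# shipped in this repository (tutorial/data/), and the assumption that the cell
# type annotation lives in obs['cell_type'] may need to be adapted to your data.
def _example_intensities_count():
    adata = anndata.read_h5ad("tutorial/data/test_adata.h5ad")
    interactions = intensities_count(
        adata,
        "tutorial/data/CellChatDB.mouse.csv",
        distance_threshold={'Secreted Signaling': 200, 'ECM-Receptor': 200, 'Cell-Cell Contact': 30},
        distance_coefficient=0,
        scale=0.5,
        anno="cell_type",
    )
    return interactions  # {'intensities': DataFrame, 'pvalues': DataFrame}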
{0}".format(LRtype)) 364 | if isinstance(distance_threshold, int): 365 | connect_matrix = adata.obsp[f"{int(distance_threshold/scale)}_distances"] 366 | elif isinstance(distance_threshold, dict): 367 | connect_matrix = adata.obsp[f"{int(distance_threshold[LRtype])}_distances"] 368 | else: 369 | raise Exception("the type of distance_threshold must be int or dict, but get:{0}".format(type(distance_threshold))) 370 | if isinstance(distance_coefficient, (int, float)): 371 | k = distance_coefficient 372 | elif isinstance(distance_coefficient, dict): 373 | k = distance_coefficient[LRtype] 374 | else: 375 | raise Exception("the type of distance_threshold must be int, float or dict, but get:{0}".format(type(distance_coefficient))) 376 | LRpairs_lists.extend(LRpair_list) 377 | for LRpair in tqdm(LRpair_list): 378 | LRpairs_lists.append(LRpair) 379 | results.append(pool.apply_async(_LRpair_process, (adata, LRpair, connect_matrix, cellTypeIndex, cellTypeNumber, seed, n_perms, complex_process_model, k,))) 380 | pool.close() 381 | pool.join() 382 | logging.info("interaction intensity count finished.") 383 | LRpairs_lists = _hm_translate(adata, LRpairs_lists) 384 | logging.info("begin to combine results") 385 | for result in results: 386 | if (jobs == 1): 387 | intensities, pvalues = result 388 | else: 389 | intensities, pvalues = result.get() 390 | intensities_list.append(intensities) 391 | pvalues_list.append(pvalues) 392 | intensity_df = _result_combined(intensities_list, LRpairs_lists, cell_types) 393 | pvalues_df = _result_combined(pvalues_list, LRpairs_lists, cell_types) 394 | intensity_df = intensity_df/sample_rate 395 | logging.info(f"result combining finished.") 396 | 397 | plot_data = {'intensities': intensity_df, 398 | 'pvalues': pvalues_df} 399 | return plot_data 400 | 401 | def intensities_write(plot_data, out_dir): 402 | 403 | if not os.path.exists(out_dir): 404 | os.makedirs(out_dir, exist_ok=True) 405 | 406 | intensity_df = plot_data['intensities'] 407 | pvalues_df = plot_data['pvalues'] 408 | intensity_df_pickle = f"{out_dir}/intensities.pkl" 409 | pvalues_df_pickle = f"{out_dir}/pvalues.pkl" 410 | intensity_df.to_pickle(intensity_df_pickle) 411 | pvalues_df.to_pickle(pvalues_df_pickle) 412 | logging.info("finished to write result to pickle file.") 413 | 414 | intensity_df_csv = f"{out_dir}/intensities.csv" 415 | pvalues_df_csv = f"{out_dir}/pvalues.csv" 416 | intensity_df.columns = ["|".join(x) for x in intensity_df.columns] 417 | pvalues_df.columns = ["|".join(x) for x in pvalues_df.columns] 418 | intensity_df.to_csv(intensity_df_csv) 419 | pvalues_df.to_csv(pvalues_df_csv) 420 | logging.info("finished to write result to csv file.") 421 | 422 | def intensities_with_radius(adata: anndata, 423 | plot_pairs_file: str, 424 | radius_list: list = list(range(10, 400, 10)), 425 | anno: str = "cell_type", 426 | copy=False): 427 | """ 428 | calculate the intensities with different radius threshold 429 | 430 | Parameters 431 | ---------- 432 | adata 433 | anndata 434 | plot_pairs_file 435 | path of csv file which stores the interactions that will be processed. 436 | Contains 5 columns: sender, receiver, ligand, receptor. Columns was seperated by table 437 | radius_list 438 | list contains radius thresholds that will be set to calculate intensity 439 | anno 440 | cell type annotation key 441 | copy 442 | If true, result will be return in DataFrame format. If false, result will be stored in anndata.uns['intensities_with_radius']. 
443 | """ 444 | radius_list = [x*2 for x in radius_list] 445 | cell_types = adata.obs[anno].unique() 446 | cell_type_dict = dict(zip(cell_types, range(0, len(cell_types)))) 447 | cellTypeIndex = adata.obs[anno].map(cell_type_dict).astype(int).values 448 | cellTypeNumber = len(cell_types) 449 | 450 | plot_pairs = pd.read_csv(plot_pairs_file, sep="\t") 451 | genes_list = plot_pairs[['ligand', 'receptor']].drop_duplicates().values.tolist() 452 | results = {} 453 | for radius in tqdm(radius_list): 454 | connectivities_key = f"{radius}_connectivities" 455 | if not connectivities_key in adata.obsp.keys(): 456 | connect_matrix = sq.gr.spatial_neighbors(adata, radius=radius, coord_type="generic", copy=True) 457 | else: 458 | connect_matrix = adata.obsp[connectivities_key] 459 | intensities_list = [] 460 | for genes in genes_list: 461 | exp_connect_matrix = _get_LR_connect_matrix(adata, genes, connect_matrix) 462 | interaction_matrix = _get_LR_intensity(exp_connect_matrix, cellTypeIndex, cellTypeNumber) 463 | intensities_list.append(interaction_matrix) 464 | intensity_df = _result_combined(intensities_list, genes_list, cell_types) 465 | results[radius] = intensity_df 466 | plot_df = pd.DataFrame() 467 | plot_df['radius'] = [x/2 for x in radius_list] 468 | for index, pair in plot_pairs.iterrows(): 469 | intensities = [] 470 | for radius in radius_list: 471 | intensities.append(results[radius].loc[tuple(pair[['ligand', 'receptor']].values)][tuple(pair[['sender', 'receiver']].values)]) 472 | column = f"{pair['sender']}|{pair['receiver']} ({pair['ligand']}|{pair['receptor']})" 473 | plot_df[column] = intensities 474 | if copy: 475 | return plot_df 476 | else: 477 | adata.uns['intensities_with_radius'] = plot_df 478 | 479 | def interaction_select(interaction:dict, 480 | cell_pairs:list=[], 481 | filter_genes:list=[], 482 | pvalue_threshold:float=0.05, 483 | intensities_range=(0, np.inf), 484 | ): 485 | ''' 486 | Input: 487 | interaction: dict{'pvalues': pd.DataFrame, 'intensities': pd.DataFrame}. Result of SCII. 488 | cell_pairs: [(celltype1, celltype2), (celltype1, celltype3), ...]. Interested cell pairs you want to visualize. 489 | filter_genes: [gene1, gene2, ...]. List of genes that you want to discard from the scii result. 490 | pvalue_threshold: Threshold of pvalue to filter the result. default=0.05 491 | intensities_range: Only retain the interaction result with intensity value between the intensities range. 492 | ''' 493 | filter_interaction = { 494 | 'intensities': interaction['intensities'].loc[:,cell_pairs], 495 | 'pvalues': interaction['pvalues'].loc[:, cell_pairs], 496 | } 497 | filter_interaction['intensities'] = filter_interaction['intensities'][[all([all([x not in filter_genes for x in y.split("_")]) for y in z]) for z in filter_interaction['intensities'].index]] 498 | filter_interaction['pvalues'] = filter_interaction['pvalues'][[all([all([x not in filter_genes for x in y.split("_")]) for y in z]) for z in filter_interaction['pvalues'].index]] 499 | filter_mask = (filter_interaction['pvalues'] < pvalue_threshold) & (filter_interaction['intensities']>intensities_range[0]) & (filter_interaction['intensities']