├── .gitignore ├── .readthedocs.yaml ├── LICENSE ├── MANIFEST.in ├── README.md ├── article ├── A1_synthetic_data │ ├── array_size_benchmark.ipynb │ ├── bm_functions.py │ ├── chrysalis_example.ipynb │ ├── contamination_benchmark.ipynb │ ├── data_generator │ │ ├── generate_synthetic_datasets.py │ │ ├── generate_truncated_samples.py │ │ ├── tissue_generator.py │ │ └── tools.py │ ├── main_synthetic_benchmark.ipynb │ └── method_scripts │ │ ├── array_size_benchmark │ │ ├── chrysalis.py │ │ ├── graphst.py │ │ ├── mefisto.py │ │ ├── nsf.py │ │ └── stagate.py │ │ ├── contamination_benchmark │ │ ├── chrysalis.py │ │ ├── graphst.py │ │ ├── mefisto.py │ │ ├── nsf.py │ │ └── stagate.py │ │ └── main_synthetic_benchmark │ │ ├── chrysalis.py │ │ ├── graphst.py │ │ ├── mefisto.py │ │ ├── nsf.py │ │ └── stagate.py ├── A2_human_lymph_node │ ├── SVG_detection_methods │ │ ├── 1_bsp_spatialde_sepal.ipynb │ │ ├── 2_spark.R │ │ └── 3_method_comparison.ipynb │ ├── benchmarking │ │ ├── graphst.py │ │ ├── mefisto.py │ │ ├── nsf.py │ │ ├── spatialpca.R │ │ └── stagate.py │ ├── chrysalis_analysis_and_validation.ipynb │ └── morans_i.ipynb ├── A3_human_breast_cancer │ ├── benchmarking.ipynb │ ├── benchmarking │ │ ├── graphst.py │ │ ├── mefisto.py │ │ ├── nsf.py │ │ ├── spatialpca.R │ │ └── stagate.py │ ├── chrysalis_analysis_and_validation.ipynb │ └── morphology_integration │ │ ├── 1_extract_image_tiles.ipynb │ │ ├── 2_autoencoder_training.py │ │ └── 3_integrate_morphology.ipynb ├── A4_mouse_brain │ ├── ff │ │ ├── mouse_brain_ff.ipynb │ │ └── mouse_brain_integration.ipynb │ └── ffpe │ │ ├── benchmark │ │ ├── graphst.py │ │ ├── mefisto.py │ │ ├── nsf.py │ │ └── stagate.py │ │ ├── map_annotations.py │ │ ├── mouse_brain_ffpe.ipynb │ │ └── mouse_brain_ffpe_benchmark.ipynb ├── A5_visium_hd │ └── visium_hd_analysis.ipynb ├── A6_slide_seqv2 │ └── slide_seqv2_analysis.ipynb ├── A7_stereo_seq │ └── stereo_seq_analysis.ipynb └── readme.md ├── chrysalis ├── __init__.py ├── core.py ├── functions.py ├── plots.py ├── test │ └── test.py └── utils.py ├── docs ├── Makefile ├── _static │ ├── css │ │ └── custom.css │ ├── images │ │ ├── after.png │ │ └── before.png │ └── js │ │ └── custom.js ├── api.md ├── conf.py ├── docs_logo.svg ├── docs_logo_prev.svg ├── favicon.svg ├── favicon_prev.svg ├── index.md ├── make.bat ├── overview │ ├── getting-started.md │ └── installation.md ├── requirements.txt └── tutorials │ ├── advanced_integration_tutorial.ipynb │ ├── lymph_node_tutorial.ipynb │ └── mouse_brain_integration_tutorial.ipynb ├── gallery └── readme.md ├── misc ├── banner.png ├── demo.png ├── demo.svg ├── deprecated_functions.py ├── human_lymph_node.jpg ├── logo.svg ├── panel_1.png └── panel_2.png ├── plots ├── V1_Human_Lymph_Node.png ├── V1_Human_Lymph_Node.svg ├── V1_Mouse_Brain_Sagittal_Anterior.png ├── V1_Mouse_Brain_Sagittal_Anterior.svg ├── V1_Mouse_Brain_Sagittal_Posterior.png ├── V1_Mouse_Brain_Sagittal_Posterior.svg ├── V1_Mouse_Kidney.png ├── V1_Mouse_Kidney.svg └── gallery │ ├── Parent_Visium_Human_BreastCancer.png │ ├── Parent_Visium_Human_BreastCancer.svg │ ├── Parent_Visium_Human_Cerebellum.png │ ├── Parent_Visium_Human_Cerebellum.svg │ ├── Parent_Visium_Human_ColorectalCancer.png │ ├── Parent_Visium_Human_ColorectalCancer.svg │ ├── Parent_Visium_Human_Glioblastoma.png │ ├── Parent_Visium_Human_Glioblastoma.svg │ ├── Parent_Visium_Human_OvarianCancer.png │ ├── Parent_Visium_Human_OvarianCancer.svg │ ├── V1_Adult_Mouse_Brain.png │ ├── V1_Adult_Mouse_Brain.svg │ ├── V1_Adult_Mouse_Brain_Coronal_Section_1.png │ ├── 
V1_Adult_Mouse_Brain_Coronal_Section_1.svg │ ├── V1_Adult_Mouse_Brain_Coronal_Section_2.png │ ├── V1_Adult_Mouse_Brain_Coronal_Section_2.svg │ ├── V1_Breast_Cancer_Block_A_Section_1.png │ ├── V1_Breast_Cancer_Block_A_Section_1.svg │ ├── V1_Breast_Cancer_Block_A_Section_2.png │ ├── V1_Breast_Cancer_Block_A_Section_2.svg │ ├── V1_Human_Heart.png │ ├── V1_Human_Heart.svg │ ├── V1_Human_Lymph_Node.png │ ├── V1_Human_Lymph_Node.svg │ ├── V1_Mouse_Brain_Sagittal_Anterior.png │ ├── V1_Mouse_Brain_Sagittal_Anterior.svg │ ├── V1_Mouse_Brain_Sagittal_Anterior_Section_2.png │ ├── V1_Mouse_Brain_Sagittal_Anterior_Section_2.svg │ ├── V1_Mouse_Brain_Sagittal_Posterior.png │ ├── V1_Mouse_Brain_Sagittal_Posterior.svg │ ├── V1_Mouse_Brain_Sagittal_Posterior_Section_2.png │ ├── V1_Mouse_Brain_Sagittal_Posterior_Section_2.svg │ ├── V1_Mouse_Kidney.png │ └── V1_Mouse_Kidney.svg ├── pyproject.toml └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # dev and build folders 2 | data/ 3 | temp/ 4 | .pytest_cache/ 5 | /dist 6 | /dist_archive 7 | /dev 8 | 9 | # IDE 10 | /.idea 11 | .ipynb_checkpoints/ 12 | 13 | # temp 14 | functions.py 15 | 16 | # misc 17 | .paquo.toml 18 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | sphinx: 4 | configuration: docs/conf.py 5 | builder: html 6 | 7 | build: 8 | os: "ubuntu-22.04" 9 | tools: 10 | python: "3.8" 11 | apt_packages: 12 | - cmake 13 | - g++ 14 | 15 | python: 16 | install: 17 | - requirements: docs/requirements.txt 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, Demeter Túrós 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include requirements.txt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 | [![PyPI](https://img.shields.io/pypi/v/chrysalis-st?logo=PyPI)](https://pypi.org/project/chrysalis-st) 6 | [![Downloads](https://static.pepy.tech/badge/chrysalis-st)](https://pepy.tech/project/chrysalis-st) 7 | 8 | **chrysalis** is a powerful and lightweight method designed to identify and visualize tissue compartments in spatial 9 | transcriptomics datasets, all without the need for external references. 10 | **chrysalis** achieves this by identifying spatially variable genes (SVGs) through spatial autocorrelation. 11 | It then employs dimensionality reduction and archetypal analysis to locate extremal points in the low-dimensional 12 | feature space, which represent pure tissue compartments. 13 | Each observation (i.e. capture spot) in the gene expression matrix is subsequently represented as a proportion of these 14 | distinct compartments. 15 | **chrysalis** features a unique approach based on maximum intensity projection, allowing the simultaneous visualization 16 | of diverse tissue compartments. 17 | Moreover, it seamlessly integrates into `scanpy`-based pipelines. 18 | 19 | If you like **chrysalis**, consider citing our [publication](https://github.com/rockdeme/chrysalis/#reference). 20 | 
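To picture the archetypal-analysis step described above, here is a minimal NumPy sketch (illustrative only, not the **chrysalis** implementation): each spot embedding is modeled as a convex combination of archetypes, so the mixture weights can be read directly as compartment proportions.

```python
import numpy as np

rng = np.random.default_rng(0)
n_spots, n_pcs, n_archetypes = 100, 20, 8

# archetypes: extremal points of the low-dimensional feature space
Z = rng.normal(size=(n_archetypes, n_pcs))
# per-spot weights on the simplex: non-negative, summing to one
A = rng.dirichlet(np.ones(n_archetypes), size=n_spots)

X_hat = A @ Z                           # reconstructed spot embeddings
assert np.allclose(A.sum(axis=1), 1.0)  # weights act as proportions
```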

22 | 23 |

24 | 25 | **chrysalis** can define distinct tissue compartments and cellular niches, each with a characteristic gene expression signature 26 | and highlighted in a distinct color. For instance, in the `V1_Human_Lymph_Node` dataset, **chrysalis** identifies 27 | various regions, such as germinal centers (yellow), B cell follicles (dark orange), and T cell compartments 28 | (lime). You can find more examples in the [gallery](https://github.com/rockdeme/chrysalis/tree/master/gallery#readme) section. 29 | 
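The per-spot compartment scores behind these visualizations can also be inspected directly. Below is a short sketch; the `chr_aa` key is the one used by the benchmarking scripts in this repository, and `adata` is assumed to have been processed as in the Usage section below.

```python
import pandas as pd

# spots x compartments matrix of compartment (archetype) scores
scores = pd.DataFrame(adata.obsm['chr_aa'], index=adata.obs_names)

# dominant compartment per capture spot
print(scores.idxmax(axis=1).value_counts())
```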

31 | 32 |

32 | 33 | ## Package 34 | 35 | **chrysalis** can be used with any pre-existing `anndata` snapshot of 10X Visium, Slide-seqV2, and Stereo-seq datasets 36 | generated with `scanpy`, and on new samples without the need for preprocessing. It is designed to be as lightweight as 37 | possible; however, it currently relies on `libpysal` for its fast implementation of Moran's I. 38 | 39 | **chrysalis** requires the following packages: 40 | - numpy 41 | - pandas 42 | - matplotlib 43 | - scanpy 44 | - pysal 45 | - archetypes 46 | - scikit_learn 47 | - scipy 48 | - tqdm 49 | - seaborn 50 | 51 | To install **chrysalis**: 52 | ```terminal 53 | pip install chrysalis-st 54 | ``` 55 | 56 | ## Documentation, Tutorials and API details 57 | 58 | User documentation is available at: https://chrysalis.readthedocs.io/ 59 | 60 | Basic tutorials covering the main functionality of **chrysalis** are available on the documentation site: 61 | - first step-by-step tutorial: https://chrysalis.readthedocs.io/en/latest/tutorials/lymph_node_tutorial.html 62 | 63 | 64 | ## Usage 65 | 66 | ```python 67 | import chrysalis as ch 68 | import scanpy as sc 69 | import matplotlib.pyplot as plt 70 | 71 | adata = sc.datasets.visium_sge(sample_id='V1_Human_Lymph_Node') 72 | 73 | sc.pp.calculate_qc_metrics(adata, inplace=True) 74 | sc.pp.filter_cells(adata, min_counts=6000) 75 | sc.pp.filter_genes(adata, min_cells=10) 76 | 77 | ch.detect_svgs(adata) 78 | 79 | sc.pp.normalize_total(adata, inplace=True) 80 | sc.pp.log1p(adata) 81 | 82 | ch.pca(adata) 83 | 84 | ch.aa(adata, n_pcs=20, n_archetypes=8) 85 | 86 | ch.plot(adata) 87 | plt.show() 88 | ``` 89 | 90 | ## Reference 91 | 92 | **Chrysalis: decoding tissue compartments in spatial transcriptomics with archetypal analysis** 93 | *Demeter Túrós, Jelica Vasiljevic, Kerstin Hahn, Sven Rottenberg, Alberto Valdeolivas.* Communications Biology 7.1 (2024): 1520.
94 | https://www.nature.com/articles/s42003-024-07165-7 95 | -------------------------------------------------------------------------------- /article/A1_synthetic_data/bm_functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import scanpy as sc 4 | from scipy.stats import pearsonr 5 | 6 | 7 | def get_correlation_df(xdf, ydf): 8 | corr_matrix = np.empty((len(xdf.columns), len(ydf.columns))) 9 | for i, col1 in enumerate(xdf.columns): 10 | for j, col2 in enumerate(ydf.columns): 11 | corr, _ = pearsonr(xdf[col1], ydf[col2]) 12 | corr_matrix[i, j] = corr 13 | 14 | corr_df = pd.DataFrame(data=corr_matrix, 15 | index=xdf.columns, 16 | columns=ydf.columns).T 17 | return corr_df 18 | 19 | 20 | def collect_correlation_results(list_csv_path, enforce_symmetry=False): 21 | corr_dict = {} 22 | for idx, fp in enumerate(list_csv_path): 23 | corr_df = pd.read_csv(fp, index_col=0) 24 | corr_df = corr_df[[c for c in corr_df.columns if 'uniform' not in c]] 25 | if enforce_symmetry: 26 | col_num = corr_df.shape[1] 27 | corr_df = corr_df.iloc[:col_num, :] 28 | # get the max correlations and add it to the results_df 29 | max_corr = corr_df.max(axis=0) 30 | corr_dict[idx] = list(max_corr.values) 31 | df = pd.DataFrame(dict([(k, pd.Series(v)) for k, v in corr_dict.items()])) 32 | return df 33 | 34 | 35 | def collect_metadata(list_adata_path): 36 | meta_dict = {} 37 | for idx, fp in enumerate(list_adata_path): 38 | print(fp) 39 | adata = sc.read_h5ad(fp) 40 | adata.var["mt"] = adata.var_names.str.startswith("MT-") 41 | sc.pp.calculate_qc_metrics(adata, inplace=True, qc_vars=['mt']) 42 | 43 | # add count metrics 44 | total_counts = np.sum(adata.obs['total_counts']) 45 | med_counts = np.median(adata.obs['total_counts']) 46 | n_genes = np.median(adata.obs['n_genes_by_counts']) 47 | 48 | # add sample hash - can cause errors if we change the location of the hash in the sample name 49 | adata.uns['parameters']['hash'] = fp.split('/')[-1].split('_')[0] 50 | meta_dict[idx] = adata.uns['parameters'] 51 | 52 | meta_dict[idx]['total_counts'] = total_counts 53 | meta_dict[idx]['median_counts'] = med_counts 54 | meta_dict[idx]['median_n_genes'] = n_genes 55 | 56 | meta_df = pd.DataFrame(meta_dict).T 57 | # meta_df = meta_df.drop(columns=['annot_col']) 58 | meta_df = meta_df.convert_dtypes() 59 | return meta_df 60 | -------------------------------------------------------------------------------- /article/A1_synthetic_data/data_generator/generate_synthetic_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import hashlib 4 | import numpy as np 5 | import pandas as pd 6 | import scanpy as sc 7 | from tqdm import tqdm 8 | from itertools import product 9 | from tissue_generator import generate_synthetic_data 10 | 11 | 12 | #%% 13 | # tabula sapiens immune 14 | adata = sc.read_h5ad(f'data/tabula_sapiens_immune_subsampled_26k.h5ad') 15 | adata.X = adata.raw.X 16 | 17 | # do some QC 18 | adata.var["mt"] = adata.var['feature_name'].str.startswith("MT-") 19 | sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], percent_top=None, log1p=False, inplace=True) 20 | adata = adata[adata.obs.n_genes_by_counts < 2500, :] 21 | adata = adata[adata.obs.pct_counts_mt < 20, :] 22 | 23 | # remove cell types with less than 100 cells 24 | cell_num = pd.value_counts(adata.obs['cell_type']) 25 | low_cell_abundance = cell_num[cell_num < 100] 26 | adata = 
adata[~adata.obs['cell_type'].isin(low_cell_abundance.index)].copy() 27 | 28 | # remove sparse genes 29 | sc.pp.filter_genes(adata, min_cells=10) 30 | 31 | cell_types = np.unique(adata.obs['cell_type']) 32 | print(f'Number of cell types: {len(cell_types)}') 33 | print(cell_types) 34 | 35 | #%% 36 | param_dict = { 37 | 'annot_col': ['cell_type'], # annotation column containing the cell type categories 38 | # 'seed': [*range(1)], # generate different samples with seed 39 | 'seed': [42], 40 | 41 | # tissue zone generation 42 | 'n_tissue_zones': [6, 10, 14], # number of tissue zones 43 | 'uniform': [0.0, 0.5, 0.2, 0.8], # fraction of uniform zones from the total 44 | 'gp_mean': [8, 4, 10], # granularity 4-8-10 looks good 45 | 'gp_eta_true': [1.5], # controls gradient steepness we can leave this at 1.5. 1.0-2.5 can be 46 | # fine if we want to tune this 47 | 'gp_eta_true_uniform': [0.5], # same as above just with lower values 48 | 49 | # cell type abundance assignment 50 | 'cell_type_p': [0.02, 0.04], # tune number of cell types per non-uniform tissue zone 51 | 'mu_low_density': [3], # tune average abundance - 3 is around 6 cells/spot/cell type, 5 like 10-15. 52 | # Plot the histo with 53 | 'mu_high_density': [5], # the distribution for more info 54 | 'p_high_density': [0.5], # how many cell types are high density vs low density 55 | 56 | # confounder parameters 57 | 'mu_detection': [5], # detection rate shape to multiply the counts with 58 | 'mu_contamination': [0.03], # contamination shape for adding random counts to each gene per location 59 | 'mu_depth': [1], # sequencing depth shape to multiply the counts with per location 60 | 'depth_loc_mean_var_ratio': [25], # sequencing depth shape 61 | 'mu_depth_exp': [1], # sample wise sequencing depth shape - single value drawn to multiply everything with 62 | 'depth_mean_var_ratio_exp': [5], # sample wise sequencing depth scale 63 | } 64 | 65 | #%% 66 | filepath = 'data/tabula_sapiens_immune' 67 | 68 | # get the total number of combinations quickly 69 | param_combinations = product(*param_dict.values()) 70 | num_combinations = len([x for x in param_combinations]) 71 | 72 | # reset the iterator 73 | param_combinations = product(*param_dict.values()) 74 | print(f'Number of samples generated: {num_combinations}') 75 | 76 | #%% 77 | for parameters in tqdm(param_combinations, total=num_combinations): 78 | settings = {k:v for k, v in zip(param_dict.keys(), parameters)} 79 | print(settings) 80 | 81 | # check if folder already exists 82 | sample_hash = hashlib.sha256(json.dumps(settings).encode('utf-8')).hexdigest()[:12] 83 | folder = filepath + '/' + sample_hash + '/' 84 | 85 | if not os.path.isdir(folder): 86 | try: 87 | generate_synthetic_data(adata, *parameters, parameter_dict=settings, save=True, root_folder=folder) 88 | except AssertionError as e: 89 | print(f'AssertionError: {e}') 90 | else: 91 | pass 92 | 93 | # %% 94 | # varying contamination and sequencing depth - CONTAMINATION AND DEPTH OVERRIDES 95 | 96 | param_dict = { 97 | 'annot_col': ['cell_type'], # annotation column containing the cell type categories 98 | 'seed': [37, 42, 69], 99 | 100 | # spatial pattern type 101 | 'pattern_type': ['gp'], 102 | 103 | # tissue zone generation 104 | 'n_tissue_zones': [6], # number of tissue zones 105 | 'uniform': [0.0], # fraction of uniform zones from the total 106 | 'gp_mean': [8], # granularity 4-8-10 looks good 107 | 'gp_eta_true': [1.5], 108 | # controls gradient steepness we can leave this at 1.5. 
1.0-2.5 can be fine if we want to tune this 109 | 'gp_eta_true_uniform': [0.5], # same as above just with lower values 110 | # r_d 111 | 'r_d_mean': [5], 112 | 'r_d_mean_var_ratio': [0.3], 113 | 114 | # cell type abundance assignment 115 | 'cell_type_p': [0.02], # tune number of cell types per non-uniform tissue zone 116 | 'mu_low_density': [3], 117 | # tune average abundance - 3 is around 6 cells/spot/cell type, 5 like 10-15. Plot the histo with 118 | 'mu_high_density': [5], # the distribution for more info 119 | 'p_high_density': [0.5], # how many cell types are high density vs low density 120 | 121 | # confounder parameters 122 | 'mu_detection': [5], # detection rate shape to multiply the counts with 123 | 'mu_contamination': [0.03], # contamination shape for adding random counts to each gene per location 124 | 'mu_depth': [1], # sequencing depth shape to multiply the counts with per location 125 | 'depth_loc_mean_var_ratio': [25], # sequencing depth shape 126 | 'mu_depth_exp': [1], # sample wise sequencing depth shape - single value drawn to multiply everything with 127 | 'depth_mean_var_ratio_exp': [5], # sample wise sequencing depth scale 128 | 129 | # direct overrides for contamination and depth to make them more consistent 130 | 'contamination_override': [0.03, 0.09, 0.27, 0.81], # 0.03 * (3)^n 131 | 'depth_override': [1, 0.2, 0.04, 0.008, 0.0016, 0.00032], # 1.0 * (0.2)^n 132 | } 133 | 134 | #%% 135 | filepath = 'data/tabula_sapiens_immune_contamination' 136 | 137 | # get the total number of combinations quickly 138 | param_combinations = product(*param_dict.values()) 139 | num_combinations = len([x for x in param_combinations]) 140 | 141 | # reset the iterator 142 | param_combinations = product(*param_dict.values()) 143 | print(f'Number of samples generated: {num_combinations}') 144 | 145 | #%% 146 | for parameters in tqdm(param_combinations, total=num_combinations): 147 | # for parameters in param_combinations: 148 | settings = {k:v for k, v in zip(param_dict.keys(), parameters)} 149 | print(settings) 150 | 151 | # check if folder already exists 152 | sample_hash = hashlib.sha256(json.dumps(settings).encode('utf-8')).hexdigest()[:12] 153 | folder = filepath + '/' + sample_hash + '/' 154 | 155 | if not os.path.isdir(folder): 156 | try: 157 | generate_synthetic_data(adata, *parameters, parameter_dict=settings, save=True, root_folder=folder, 158 | confounder_version='v2') 159 | except AssertionError as e: 160 | print(f'AssertionError: {e}') 161 | except Exception as e: 162 | print(f'An unexpected error occurred: {e}') 163 | else: 164 | pass 165 | -------------------------------------------------------------------------------- /article/A1_synthetic_data/data_generator/generate_truncated_samples.py: -------------------------------------------------------------------------------- 1 | import os 2 | import scanpy as sc 3 | from tqdm import tqdm 4 | import matplotlib.pyplot as plt 5 | from glob import glob 6 | 7 | 8 | filepath = 'data/tabula_sapiens_immune_contamination' 9 | savepath = 'data/tabula_sapiens_immune_size' 10 | adatas = glob(filepath + '/*/*.h5ad') 11 | 12 | for idx, adp in tqdm(enumerate(adatas), total=len(adatas)): 13 | print(adp) 14 | sample_folder = '/'.join(adp.split('/')[:-1]) + '/' 15 | adata = sc.read_h5ad(adp) 16 | 17 | sample_id = sample_folder.split('/')[-2] 18 | 19 | if (adata.uns['parameters']['mu_contamination'] == 0.03) & (adata.uns['parameters']['mu_depth_exp'] == 1): 20 | 21 | # 50 x 50 22 | os.makedirs(f'{savepath}/{sample_id}-5050/') 23 | 
adata.write_h5ad(f'{savepath}/{sample_id}-5050/{sample_id}-5050.h5ad') 24 | tissue_zones = adata.obsm['tissue_zones'] 25 | n1, n2 = (50, 50) 26 | fig, ax = plt.subplots(1, 1, figsize=(5, 5)) 27 | img_data = tissue_zones['tissue_zone_0'].values.reshape(n1, n2).T 28 | im = ax.imshow(img_data, cmap='mako_r', vmin=0, vmax=1) 29 | plt.tight_layout() 30 | plt.show() 31 | 32 | # 25 x 50 33 | adata = adata[adata.obsm['spatial'][:, 0] < 50] 34 | os.makedirs(f'{savepath}/{sample_id}-2550/') 35 | adata.write_h5ad(f'{savepath}/{sample_id}-2550/{sample_id}-2550.h5ad') 36 | tissue_zones = adata.obsm['tissue_zones'] 37 | n1, n2 = (25, 50) 38 | fig, ax = plt.subplots(1, 1, figsize=(5, 5)) 39 | img_data = tissue_zones['tissue_zone_0'].values.reshape(n1, n2).T 40 | im = ax.imshow(img_data, cmap='mako_r', vmin=0, vmax=1) 41 | plt.tight_layout() 42 | plt.show() 43 | 44 | # 25 x 25 45 | adata = adata[adata.obsm['spatial'][:, 1] < 50] 46 | os.makedirs(f'{savepath}/{sample_id}-2525/') 47 | adata.write_h5ad(f'{savepath}/{sample_id}-2525/{sample_id}-2525.h5ad') 48 | tissue_zones = adata.obsm['tissue_zones'] 49 | n1, n2 = (25, 25) 50 | fig, ax = plt.subplots(1, 1, figsize=(5, 5)) 51 | img_data = tissue_zones['tissue_zone_0'].values.reshape(n1, n2).T 52 | im = ax.imshow(img_data, cmap='mako_r', vmin=0, vmax=1) 53 | plt.tight_layout() 54 | plt.show() 55 | -------------------------------------------------------------------------------- /article/A1_synthetic_data/data_generator/tissue_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import json 4 | import hashlib 5 | import scanpy as sc 6 | import seaborn as sns 7 | import matplotlib.pyplot as plt 8 | from functions import plot_spatial, return_cell_types 9 | from tools import (generate_tissue_zones, assign_cell_types, assign_cell_type_abundance, 10 | construct_cell_abundance_matrix, generate_synthetic_counts, construct_adata, 11 | add_confounders, add_confounders_v2) 12 | 13 | 14 | def generate_synthetic_data(adata, annot_col, seed, pattern_type, n_tissue_zones, uniform, gp_mean, gp_eta_true, 15 | gp_eta_true_uniform, r_d_mean, r_d_mean_var_ratio, cell_type_p, mu_low_density, 16 | mu_high_density, p_high_density, mu_detection, mu_contamination, mu_depth, 17 | depth_loc_mean_var_ratio, mu_depth_exp, depth_mean_var_ratio_exp, contamination_override, 18 | depth_override, parameter_dict, save=False, root_folder=None, confounder_version='v1'): 19 | 20 | # some checks before running the generation 21 | cell_types, n_cell_types = return_cell_types(adata, annot_col) 22 | n_uniform = int(n_tissue_zones * uniform) 23 | n_sparse_ct = n_cell_types - n_uniform 24 | n_sparse_tz = n_tissue_zones - n_uniform 25 | # assert n_sparse_ct >= n_sparse_tz, (f"Sparse tissue zones = {n_sparse_tz} ({n_tissue_zones}-{n_uniform}) which is" 26 | # f"higher than the number of cell types ({n_sparse_ct}) ({n_cell_types}-" 27 | # f"{n_uniform}). 
Either decrease the " 28 | # f"overall number of tissue zones or the fraction of uniform zones.") 29 | assert n_uniform < n_cell_types 30 | 31 | if type(parameter_dict) != dict: 32 | raise Warning("parameter_dict is not a dictionary!") 33 | sample_hash = hashlib.sha256(json.dumps(parameter_dict).encode('utf-8')).hexdigest()[:12] 34 | if save: 35 | assert root_folder is not None 36 | os.makedirs(root_folder, exist_ok=True) 37 | assert os.path.isdir(root_folder) == True 38 | 39 | # generate tissue zones with gaussian processes or reaction-diffusion 40 | abundance_df, locs = generate_tissue_zones(grid_size=(50, 50), 41 | n_tissue_zones=n_tissue_zones, 42 | unfiform_fraction=uniform, 43 | pattern_type=pattern_type, gp_mean=gp_mean, gp_eta_true=gp_eta_true, 44 | gp_eta_true_uniform=gp_eta_true_uniform, 45 | r_d_mean=r_d_mean, seed=seed, r_d_mean_var_ratio=r_d_mean_var_ratio) 46 | 47 | # look at the tissue zones 48 | 49 | nrows = math.ceil((n_tissue_zones + 1) / 5) 50 | 51 | plt.figure(figsize=(3 * 5, 3 * nrows)) 52 | plot_spatial(abundance_df.values, n=(50, 50), nrows=nrows, names=abundance_df.columns, vmax=None) 53 | plt.tight_layout() 54 | if save: 55 | plt.savefig(root_folder + 'tissue_zones.png') 56 | plt.show() 57 | 58 | # assign cell types to tissue zones 59 | cell_types_df = assign_cell_types(adata, n_tissue_zones, annot_col=annot_col, unfiform_fraction=uniform, 60 | cell_type_p=cell_type_p, seed=seed) 61 | # add average abundances 62 | # mu controls cell density per spot 63 | cell_types_df = assign_cell_type_abundance(cell_types_df, p_high_density=p_high_density, 64 | mu_low_density=mu_low_density, mu_high_density=mu_high_density, 65 | seed=seed) 66 | 67 | sns.heatmap(cell_types_df, square=True) 68 | plt.tight_layout() 69 | if save: 70 | plt.savefig(root_folder + 'heatmap.png') 71 | plt.show() 72 | 73 | # multiply the abundances with the average and take the integer values (difference between the two is the 74 | # capture eff.) 
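# (illustrative sketch, not part of the original pipeline) numerically, for a
# spot with tissue-zone abundance a and a cell type with mean abundance mu:
#   cell_count  = int(a * mu)             e.g. int(0.7 * 5) = 3
#   capture_eff = a * mu - cell_count     e.g. 0.7 * 5 - 3  = 0.5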
75 | cell_count_df, capture_eff_df = construct_cell_abundance_matrix(abundance_df, cell_types_df, seed=seed) 76 | 77 | n_cell_types = len(cell_types_df.T.columns) 78 | nrows = math.ceil((n_cell_types + 1) / 5) 79 | 80 | plt.figure(figsize=(3 * 5, 3 * nrows)) 81 | plot_spatial(cell_count_df.values, n=(50, 50), nrows=nrows, names=cell_types_df.T.columns, vmax=None) 82 | plt.tight_layout() 83 | if save: 84 | plt.savefig(root_folder + 'cell_types.png') 85 | plt.show() 86 | 87 | synthetic_counts = generate_synthetic_counts(adata, cell_count_df, capture_eff_df, annot_col=annot_col, seed=seed) 88 | 89 | sth_adata = construct_adata(synthetic_counts, adata, abundance_df, locs, cell_count_df, capture_eff_df) 90 | 91 | if confounder_version == 'v1': 92 | sth_adata = add_confounders(sth_adata, 93 | mu_detection=mu_detection, mu_contamination=mu_contamination, mu_depth=mu_depth, 94 | depth_loc_mean_var_ratio=depth_loc_mean_var_ratio, mu_depth_exp=mu_depth_exp, 95 | depth_mean_var_ratio_exp=depth_mean_var_ratio_exp, 96 | contamination_override=contamination_override, depth_override=depth_override, 97 | seed=seed) 98 | elif confounder_version == 'v2': 99 | sth_adata = add_confounders_v2(sth_adata, 100 | mu_detection=mu_detection, mu_contamination=mu_contamination, mu_depth=mu_depth, 101 | depth_loc_mean_var_ratio=depth_loc_mean_var_ratio, mu_depth_exp=mu_depth_exp, 102 | depth_mean_var_ratio_exp=depth_mean_var_ratio_exp, 103 | contamination_override=contamination_override, depth_override=depth_override, 104 | seed=seed) 105 | else: 106 | raise KeyError("Invalid confounder_version. It should be either 'v1' or 'v2'.") 107 | 108 | sth_adata.uns['parameters'] = parameter_dict 109 | 110 | sth_adata.var['ensembl_id'] = sth_adata.var_names 111 | sth_adata.var['gene_symbol'] = adata.var['feature_name'] 112 | sth_adata.var_names = sth_adata.var['gene_symbol'].astype(str) 113 | 114 | sth_adata.var["mt"] = sth_adata.var_names.str.startswith("MT-") 115 | sc.pp.calculate_qc_metrics(sth_adata, qc_vars=['mt'], inplace=True) 116 | 117 | if save: 118 | sth_adata.write_h5ad(root_folder + f'{sample_hash}_sth_adata.h5ad') 119 | 120 | # sc.pl.highest_expr_genes(sth_adata, n_top=20) 121 | # sc.pl.spatial(sth_adata, color=["log1p_total_counts", 'MALAT1'], spot_size=2.7) 122 | -------------------------------------------------------------------------------- /article/A1_synthetic_data/method_scripts/array_size_benchmark/chrysalis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import scanpy as sc 4 | from tqdm import tqdm 5 | import seaborn as sns 6 | from glob import glob 7 | import chrysalis as ch 8 | import matplotlib.pyplot as plt 9 | from article.A1_synthetic_data.bm_functions import get_correlation_df 10 | 11 | 12 | filepath = 'data/tabula_sapiens_immune_size' 13 | adatas = glob(filepath + '/*/*.h5ad') 14 | 15 | results_df = pd.DataFrame() 16 | 17 | for idx, adp in tqdm(enumerate(adatas), total=len(adatas)): 18 | print(adp) 19 | sample_folder = '/'.join(adp.split('/')[:-1]) + '/' 20 | adata = sc.read_h5ad(adp) 21 | 22 | # number of non-uniform compartments 23 | uniform = int(adata.uns['parameters']['n_tissue_zones'] * adata.uns['parameters']['uniform']) 24 | tissue_zones = adata.uns['parameters']['n_tissue_zones'] - uniform 25 | tissue_zones = int(tissue_zones) 26 | 27 | # chrysalis pipeline 28 | ch.detect_svgs(adata, neighbors=8, top_svg=1000, min_morans=0.01) 29 | ch.plot_svgs(adata) 30 | plt.savefig(sample_folder + 'ch_svgs.png') 31 | 
plt.close() 32 | 33 | sc.pp.normalize_total(adata, inplace=True) 34 | sc.pp.log1p(adata) 35 | 36 | ch.pca(adata, n_pcs=40) 37 | 38 | ch.plot_explained_variance(adata) 39 | plt.savefig(sample_folder + 'ch_expl_variance.png') 40 | plt.close() 41 | 42 | ch.aa(adata, n_pcs=20, n_archetypes=tissue_zones) 43 | 44 | ch.plot(adata, dim=tissue_zones, marker='s') 45 | plt.savefig(sample_folder + 'ch_plot.png') 46 | plt.close() 47 | 48 | col_num = int(np.sqrt(tissue_zones)) 49 | ch.plot_compartments(adata, marker='s', ncols=col_num) 50 | plt.savefig(sample_folder + 'ch_comps.png') 51 | plt.close() 52 | 53 | # correlation with tissue zones 54 | compartments = adata.obsm['chr_aa'] 55 | compartment_df = pd.DataFrame(data=compartments, index=adata.obs.index) 56 | tissue_zone_df = adata.obsm['tissue_zones'] 57 | # tissue_zone_df = tissue_zone_df[[c for c in tissue_zone_df.columns if 'uniform' not in c]] 58 | 59 | corr_df = get_correlation_df(tissue_zone_df, compartment_df) 60 | corr_df.to_csv(sample_folder + 'pearson.csv') 61 | 62 | sns.heatmap(corr_df, square=True, center=0) 63 | plt.tight_layout() 64 | plt.savefig(sample_folder + 'corr_heatmap.png') 65 | plt.close() 66 | -------------------------------------------------------------------------------- /article/A1_synthetic_data/method_scripts/array_size_benchmark/graphst.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | import scanpy as sc 4 | from tqdm import tqdm 5 | from glob import glob 6 | import seaborn as sns 7 | from GraphST import GraphST 8 | import matplotlib.pyplot as plt 9 | from sklearn.decomposition import PCA 10 | from article.A1_synthetic_data.bm_functions import get_correlation_df 11 | 12 | 13 | filepath = '/storage/homefs/pt22a065/chr_data/tabula_sapiens_immune_size' 14 | adatas = glob(filepath + '/*/*.h5ad') 15 | 16 | results_df = pd.DataFrame() 17 | 18 | for idx, adp in tqdm(enumerate(adatas), total=len(adatas)): 19 | print(adp) 20 | sample_folder = '/'.join(adp.split('/')[:-1]) + '/' 21 | device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu') 22 | adata = sc.read_h5ad(adp) 23 | adata.var_names_make_unique() 24 | 25 | # define model 26 | model = GraphST.GraphST(adata, device=torch.device('cpu')) 27 | 28 | # train model 29 | adata = model.train() 30 | 31 | graphst_df = pd.DataFrame(adata.obsm['emb']) 32 | graphst_df.index = adata.obs.index 33 | 34 | pca = PCA(n_components=20, svd_solver='arpack', random_state=42) 35 | graphst_pcs = pca.fit_transform(graphst_df) 36 | graphst_pcs_df = pd.DataFrame(data=graphst_pcs, index=graphst_df.index) 37 | 38 | graphst_pcs_df.to_csv(sample_folder + 'graphst_comps.csv') 39 | 40 | 41 | tissue_zone_df = adata.obsm['tissue_zones'] 42 | # tissue_zone_df = tissue_zone_df[[c for c in tissue_zone_df.columns if 'uniform' not in c]] 43 | 44 | corr_df = get_correlation_df(tissue_zone_df, graphst_pcs_df) 45 | corr_df.to_csv(sample_folder + 'graphst_pearson.csv') 46 | 47 | fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8)) 48 | sns.heatmap(corr_df, square=True, center=0, ax=ax) 49 | plt.tight_layout() 50 | plt.savefig(sample_folder + 'graphst_corr_heatmap.png') 51 | plt.close() 52 | -------------------------------------------------------------------------------- /article/A1_synthetic_data/method_scripts/array_size_benchmark/mefisto.py: -------------------------------------------------------------------------------- 1 | import mofax 2 | import pandas as pd 3 | import scanpy as sc 4 | from tqdm import tqdm 5 | from glob 
import glob 6 | import seaborn as sns 7 | import matplotlib.pyplot as plt 8 | from mofapy2.run.entry_point import entry_point 9 | from article.A1_synthetic_data.bm_functions import get_correlation_df 10 | 11 | 12 | filepath = '/storage/homefs/pt22a065/chr_data/tabula_sapiens_immune_size' 13 | adatas = glob(filepath + '/*/*.h5ad') 14 | 15 | results_df = pd.DataFrame() 16 | 17 | for idx, adp in tqdm(enumerate(adatas), total=len(adatas)): 18 | print(adp) 19 | sample_folder = '/'.join(adp.split('/')[:-1]) + '/' 20 | adata = sc.read_h5ad(adp) 21 | 22 | # number of non-uniform compartments 23 | uniform = int(adata.uns['parameters']['n_tissue_zones'] * adata.uns['parameters']['uniform']) 24 | tissue_zones = adata.uns['parameters']['n_tissue_zones'] - uniform 25 | tissue_zones = int(tissue_zones) 26 | 27 | sc.pp.filter_genes(adata, min_cells=len(adata) * 0.05) 28 | 29 | sc.pp.normalize_total(adata, inplace=True) 30 | sc.pp.log1p(adata) 31 | sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=2000) 32 | 33 | adata.obs = pd.concat([adata.obs, 34 | pd.DataFrame(adata.obsm["spatial"], columns=["imagerow", "imagecol"], 35 | index=adata.obs_names), 36 | ], axis=1) 37 | 38 | ent = entry_point() 39 | ent.set_data_options(use_float32=True) 40 | ent.set_data_from_anndata(adata, features_subset="highly_variable") 41 | 42 | ent.set_model_options(factors=tissue_zones) 43 | ent.set_train_options(save_interrupted=True) 44 | ent.set_train_options(seed=2021) 45 | 46 | # We use 1000 inducing points to learn spatial covariance patterns 47 | n_inducing = 500 # 500 for size tests 48 | 49 | ent.set_covariates([adata.obsm["spatial"]], covariates_names=["imagerow", "imagecol"]) 50 | ent.set_smooth_options(sparseGP=True, frac_inducing=n_inducing / adata.n_obs, 51 | start_opt=10, opt_freq=10) 52 | 53 | ent.build() 54 | ent.run() 55 | ent.save(sample_folder + "mefisto_temp.hdf5") 56 | m = mofax.mofa_model(sample_folder + "mefisto_temp.hdf5") 57 | factor_df = m.get_factors(df=True) 58 | 59 | # factor_df = ent.model.getFactors(df=True) 60 | 61 | factor_df.to_csv(sample_folder + 'mefisto_comps.csv') 62 | 63 | # correlation with tissue zones 64 | tissue_zone_df = adata.obsm['tissue_zones'] 65 | # tissue_zone_df = tissue_zone_df[[c for c in tissue_zone_df.columns if 'uniform' not in c]] 66 | 67 | corr_df = get_correlation_df(tissue_zone_df, factor_df) 68 | corr_df.to_csv(sample_folder + 'mefisto_pearson.csv') 69 | 70 | fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8)) 71 | sns.heatmap(corr_df, square=True, center=0, ax=ax) 72 | plt.tight_layout() 73 | plt.savefig(sample_folder + 'mefisto_corr_heatmap.png') 74 | plt.close() 75 | -------------------------------------------------------------------------------- /article/A1_synthetic_data/method_scripts/array_size_benchmark/nsf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | import numpy as np 5 | import pandas as pd 6 | import scanpy as sc 7 | from os import path 8 | from tqdm import tqdm 9 | from glob import glob 10 | import seaborn as sns 11 | from nsf.models import sf 12 | import matplotlib.pyplot as plt 13 | from numpy.linalg import LinAlgError 14 | from tensorflow_probability import math as tm 15 | from article.A1_synthetic_data.bm_functions import get_correlation_df 16 | from nsf.utils import preprocess, training, visualize, postprocess, misc 17 | 18 | 19 | filepath = '/storage/homefs/pt22a065/chr_data/tabula_sapiens_immune_size' 20 | adatas = glob(filepath + '/*/*.h5ad') 21 | 
tfk = tm.psd_kernels 22 | 23 | results_df = pd.DataFrame() 24 | 25 | for idx, adp in tqdm(enumerate(adatas), total=len(adatas)): 26 | try: 27 | print(adp) 28 | sample_folder = '/'.join(adp.split('/')[:-1]) + '/' 29 | 30 | # Check if all necessary output files already exist in the sample_folder 31 | if (os.path.exists(sample_folder + 'nsf_comps.csv') and 32 | os.path.exists(sample_folder + 'nsf_pearson.csv') and 33 | os.path.exists(sample_folder + 'nsf_corr_heatmap.png')): 34 | print(f"Skipping {sample_folder} as output files already exist.") 35 | continue 36 | 37 | adata = sc.read_h5ad(adp) 38 | 39 | # number of non-uniform compartments 40 | uniform = int(adata.uns['parameters']['n_tissue_zones'] * adata.uns['parameters']['uniform']) 41 | tissue_zones = adata.uns['parameters']['n_tissue_zones'] - uniform 42 | tissue_zones = int(tissue_zones) 43 | 44 | sc.pp.calculate_qc_metrics(adata, inplace=True) 45 | # sc.pp.filter_genes(adata, min_cells=len(adata) * 0.05) 46 | sc.pp.filter_genes(adata, min_cells=1, inplace=True) 47 | 48 | try: 49 | # first attempt without any filtering 50 | sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000) 51 | except: 52 | try: 53 | # second attempt with min_counts=10 54 | sc.pp.filter_genes(adata, min_counts=10, inplace=True) 55 | # sc.pp.filter_cells(adata, min_counts=100, inplace=True) 56 | sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000) 57 | except: 58 | # final attempt with min_counts=100 59 | sc.pp.filter_genes(adata, min_counts=100, inplace=True) 60 | # sc.pp.filter_cells(adata, min_counts=100, inplace=True) 61 | sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000) 62 | adata.layers = {"counts": adata.X.copy()} # store raw counts before normalization changes ad.X 63 | sc.pp.normalize_total(adata, inplace=True, layers=None, key_added="sizefactor") 64 | sc.pp.log1p(adata) 65 | 66 | adata.var['deviance_poisson'] = preprocess.deviancePoisson(adata.layers["counts"]) 67 | o = np.argsort(-adata.var['deviance_poisson']) 68 | idx = list(range(adata.shape[0])) 69 | random.shuffle(idx) 70 | adata = adata[idx, o] 71 | adata = adata[:, :2000] 72 | 73 | Dtr, Dval = preprocess.anndata_to_train_val(adata, layer="counts", 74 | sz="constant") # size factor set to constant from scanpy for 75 | # the samples with y nan error 76 | Dtr_n, Dval_n = preprocess.anndata_to_train_val(adata) # normalized data 77 | fmeans, Dtr_c, Dval_c = preprocess.center_data(Dtr_n, Dval_n) # centered features 78 | Xtr = Dtr["X"] # note this should be identical to Dtr_n["X"] 79 | Ntr = Xtr.shape[0] 80 | Dtf = preprocess.prepare_datasets_tf(Dtr, Dval=Dval, shuffle=False) 81 | Dtf_n = preprocess.prepare_datasets_tf(Dtr_n, Dval=Dval_n, shuffle=False) 82 | Dtf_c = preprocess.prepare_datasets_tf(Dtr_c, Dval=Dval_c, shuffle=False) 83 | visualize.heatmap(Xtr, Dtr["Y"][:, 0], marker="D", s=15) 84 | plt.close() 85 | 86 | # Visualize raw data 87 | plt.imshow(np.log1p(Dtr["Y"])[:50, :100], cmap="Blues") 88 | plt.close() 89 | 90 | # Visualize inducing points 91 | Z = misc.kmeans_inducing_pts(Xtr, 500) 92 | fig, ax = plt.subplots(figsize=(12, 10)) 93 | ax.scatter(Xtr[:, 0], Xtr[:, 1], marker="D", s=50, ) 94 | ax.scatter(Z[:, 0], Z[:, 1], c="red", s=30) 95 | plt.close() 96 | 97 | # initialize inducing points and tuning parameters 98 | Z = misc.kmeans_inducing_pts(Xtr, 500) # 2363 99 | M = Z.shape[0] 100 | ker = tfk.MaternThreeHalves 101 | S = 3 # samples for elbo approximation 102 | # NSF: Spatial only with non-negative factors 103 | L = tissue_zones 
# number of latent factors, ideally divisible by 2 104 | J = adata.shape[1] # 2000 105 | 106 | mpth = path.join("/storage/homefs/pt22a065/chr_benchmarks/nsf/models/V5/") 107 | 108 | fit = sf.SpatialFactorization(J, L, Z, psd_kernel=ker, nonneg=True, lik="poi") 109 | fit.elbo_avg(Xtr, Dtr["Y"], sz=Dtr["sz"]) 110 | fit.init_loadings(Dtr["Y"], X=Xtr, sz=Dtr["sz"]) 111 | fit.elbo_avg(Xtr, Dtr["Y"], sz=Dtr["sz"]) 112 | pp = fit.generate_pickle_path("scanpy", base=mpth) 113 | tro = training.ModelTrainer(fit, pickle_path=pp) 114 | tro.train_model(*Dtf, ckpt_freq=10000) 115 | 116 | ttl = "NSF: spatial, non-negative factors, Poisson likelihood" 117 | visualize.plot_loss(tro.loss, title=ttl) # ,ss=range(2000,4000)) 118 | plt.savefig(sample_folder + 'nsf_loss.png') 119 | plt.close() 120 | 121 | hmkw = {"figsize": (4, 4), "s": 0.3, "marker": "D", "subplot_space": 0, 122 | "spinecolor": "white"} 123 | insf = postprocess.interpret_nsf(fit, Xtr, S=10, lda_mode=False) 124 | tgnames = [str(i) for i in range(1, L + 1)] 125 | 126 | # fig, axes = visualize.multiheatmap(Xtr, np.sqrt(insf["factors"]), (4, 3), **hmkw) 127 | # visualize.set_titles(fig, tgnames, x=0.05, y=.85, fontsize="medium", c="white", 128 | # ha="left", va="top") 129 | # plt.savefig(sample_folder + 'nsf_comps.png') 130 | # plt.close() 131 | 132 | data = {'factors': insf, 'positions': Xtr} 133 | 134 | 135 | def transform_coords(X): 136 | # code from nsf github 137 | X[:, 1] = -X[:, 1] 138 | xmin = X.min(axis=0) 139 | X -= xmin 140 | x_gmean = np.exp(np.mean(np.log(X.max(axis=0)))) 141 | X *= 4 / x_gmean 142 | return X - X.mean(axis=0) 143 | 144 | 145 | X = adata.obsm["spatial"].copy().astype('float32') 146 | tcoords = transform_coords(X) 147 | 148 | pair_idx = [] 149 | for xy in data['positions']: 150 | distances = [math.dist([xy[0], xy[1]], [idx[0], idx[1]]) for idx in tcoords] 151 | pair_idx.append(np.argmin(distances)) 152 | 153 | nsf_df = pd.DataFrame(data=np.zeros([len(adata), data['factors']['factors'].shape[1]])) 154 | for idx, i in enumerate(pair_idx): 155 | nsf_df.iloc[i, :] = data['factors']['factors'][idx, :] 156 | nsf_df.index = adata.obs.index 157 | 158 | nsf_df.to_csv(sample_folder + 'nsf_comps.csv') 159 | nsf_df = nsf_df[~(nsf_df.eq(0).all(axis=1))] 160 | 161 | tissue_zone_df = adata.obsm['tissue_zones'].loc[list(nsf_df.index)] 162 | 163 | corr_df = get_correlation_df(tissue_zone_df, nsf_df) 164 | corr_df.to_csv(sample_folder + 'nsf_pearson.csv') 165 | 166 | fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8)) 167 | sns.heatmap(corr_df, square=True, center=0, ax=ax) 168 | plt.tight_layout() 169 | plt.savefig(sample_folder + 'nsf_corr_heatmap.png') 170 | plt.close() 171 | 172 | except (LinAlgError, ValueError) as e: 173 | print(f"LinAlgError encountered for matrix. 
Continuing...{e}") 174 | continue 175 | -------------------------------------------------------------------------------- /article/A1_synthetic_data/method_scripts/array_size_benchmark/stagate.py: -------------------------------------------------------------------------------- 1 | import STAGATE 2 | import pandas as pd 3 | import scanpy as sc 4 | from tqdm import tqdm 5 | from glob import glob 6 | import seaborn as sns 7 | import sklearn.neighbors 8 | import matplotlib.pyplot as plt 9 | from article.A1_synthetic_data.bm_functions import get_correlation_df 10 | 11 | 12 | filepath = 'data/tabula_sapiens_immune_size' 13 | adatas = glob(filepath + '/*/*.h5ad') 14 | 15 | results_df = pd.DataFrame() 16 | 17 | for idx, adp in tqdm(enumerate(adatas), total=len(adatas)): 18 | print(adp) 19 | sample_folder = '/'.join(adp.split('/')[:-1]) + '/' 20 | adata = sc.read_h5ad(adp) 21 | 22 | sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000) 23 | sc.pp.normalize_total(adata, target_sum=1e4) 24 | sc.pp.log1p(adata) 25 | 26 | # include 8 neighbours with the cutoff similarly to 6 for visium 27 | STAGATE.Cal_Spatial_Net(adata, rad_cutoff=3.3) 28 | STAGATE.Stats_Spatial_Net(adata) 29 | 30 | coor = pd.DataFrame(adata.obsm['spatial']) 31 | coor.index = adata.obs.index 32 | coor.columns = ['imagerow', 'imagecol'] 33 | 34 | nbrs = sklearn.neighbors.NearestNeighbors(radius=3.3).fit(coor) 35 | distances, indices = nbrs.radius_neighbors(coor, return_distance=True) 36 | KNN_list = [] 37 | for it in range(indices.shape[0]): 38 | KNN_list.append(pd.DataFrame(zip([it]*indices[it].shape[0], indices[it], distances[it]))) 39 | KNN_df = pd.concat(KNN_list) 40 | KNN_df.columns = ['Cell1', 'Cell2', 'Distance'] 41 | 42 | adata = STAGATE.train_STAGATE(adata, alpha=0) 43 | 44 | for i in range(30): 45 | adata.obs[f'stagate_{i}'] = adata.obsm['STAGATE'][:, i] 46 | 47 | # doesn't work without the proper anndata structure - missing some spatial info 48 | # with mpl.rc_context({'figure.figsize': [4.5, 5]}): 49 | # sc.pl.spatial(adata, color=[f'S{i}' for i in range(30)], size=2, ncols=4, show=False) 50 | # plt.savefig(sample_folder + 'stagate_comps.png') 51 | # plt.close() 52 | 53 | # correlation with tissue zones 54 | stagate_df = adata.obs[[f'stagate_{i}' for i in range(30)]] 55 | stagate_df.to_csv(sample_folder + 'stagate_comps.csv') 56 | 57 | tissue_zone_df = adata.obsm['tissue_zones'] 58 | # tissue_zone_df = tissue_zone_df[[c for c in tissue_zone_df.columns if 'uniform' not in c]] 59 | 60 | corr_df = get_correlation_df(tissue_zone_df, stagate_df) 61 | corr_df.to_csv(sample_folder + 'stagate_pearson.csv') 62 | 63 | fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8)) 64 | sns.heatmap(corr_df, square=True, center=0, ax=ax) 65 | plt.tight_layout() 66 | plt.savefig(sample_folder + 'stagate_corr_heatmap.png') 67 | plt.close() 68 | -------------------------------------------------------------------------------- /article/A1_synthetic_data/method_scripts/contamination_benchmark/chrysalis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import scanpy as sc 5 | from tqdm import tqdm 6 | import seaborn as sns 7 | from glob import glob 8 | import chrysalis as ch 9 | import matplotlib.pyplot as plt 10 | from article.A1_synthetic_data.bm_functions import get_correlation_df 11 | 12 | 13 | filepath = 'data/tabula_sapiens_immune_contamination' 14 | adatas = glob(filepath + '/*/*.h5ad') 15 | 16 | results_df = pd.DataFrame() 17 | 18 | for 
idx, adp in tqdm(enumerate(adatas), total=len(adatas)): 19 | print(adp) 20 | sample_folder = '/'.join(adp.split('/')[:-1]) + '/' 21 | 22 | # Check if all necessary output files already exist in the sample_folder 23 | if (os.path.exists(sample_folder + 'ch_svgs.png') and 24 | os.path.exists(sample_folder + 'ch_expl_variance.png') and 25 | os.path.exists(sample_folder + 'ch_plot.png') and 26 | os.path.exists(sample_folder + 'ch_comps.png') and 27 | os.path.exists(sample_folder + 'pearson.csv') and 28 | os.path.exists(sample_folder + 'corr_heatmap.png')): 29 | print(f"Skipping {sample_folder} as output files already exist.") 30 | continue 31 | 32 | adata = sc.read_h5ad(adp) 33 | 34 | # number of non-uniform compartments 35 | uniform = int(adata.uns['parameters']['n_tissue_zones'] * adata.uns['parameters']['uniform']) 36 | tissue_zones = adata.uns['parameters']['n_tissue_zones'] - uniform 37 | tissue_zones = int(tissue_zones) 38 | 39 | # chrysalis pipeline 40 | ch.detect_svgs(adata, neighbors=8, top_svg=1000, min_morans=-1) 41 | ch.plot_svgs(adata) 42 | plt.savefig(sample_folder + 'ch_svgs.png') 43 | plt.close() 44 | 45 | sc.pp.normalize_total(adata, inplace=True) 46 | sc.pp.log1p(adata) 47 | svg_num = len(adata.var[adata.var['spatially_variable'] == True]) 48 | if svg_num < 40: 49 | ch.pca(adata, n_pcs=svg_num - 1) 50 | else: 51 | ch.pca(adata, n_pcs=40) 52 | 53 | ch.plot_explained_variance(adata) 54 | plt.savefig(sample_folder + 'ch_expl_variance.png') 55 | plt.close() 56 | if svg_num < 20: 57 | ch.aa(adata, n_pcs=svg_num, n_archetypes=tissue_zones) 58 | else: 59 | ch.aa(adata, n_pcs=20, n_archetypes=tissue_zones) 60 | 61 | ch.plot(adata, dim=tissue_zones, marker='s') 62 | plt.savefig(sample_folder + 'ch_plot.png') 63 | plt.close() 64 | 65 | col_num = int(np.sqrt(tissue_zones)) 66 | ch.plot_compartments(adata, marker='s', ncols=col_num) 67 | plt.savefig(sample_folder + 'ch_comps.png') 68 | plt.close() 69 | 70 | # correlation with tissue zones 71 | compartments = adata.obsm['chr_aa'] 72 | compartment_df = pd.DataFrame(data=compartments, index=adata.obs.index) 73 | tissue_zone_df = adata.obsm['tissue_zones'] 74 | # tissue_zone_df = tissue_zone_df[[c for c in tissue_zone_df.columns if 'uniform' not in c]] 75 | 76 | corr_df = get_correlation_df(tissue_zone_df, compartment_df) 77 | corr_df.to_csv(sample_folder + 'pearson.csv') 78 | 79 | sns.heatmap(corr_df, square=True, center=0) 80 | plt.tight_layout() 81 | plt.savefig(sample_folder + 'corr_heatmap.png') 82 | plt.close() 83 | -------------------------------------------------------------------------------- /article/A1_synthetic_data/method_scripts/contamination_benchmark/graphst.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pandas as pd 4 | import scanpy as sc 5 | from tqdm import tqdm 6 | from glob import glob 7 | import seaborn as sns 8 | from GraphST import GraphST 9 | import matplotlib.pyplot as plt 10 | from sklearn.decomposition import PCA 11 | from article.A1_synthetic_data.bm_functions import get_correlation_df 12 | 13 | 14 | filepath = '/storage/homefs/pt22a065/chr_data/tabula_sapiens_immune_contamination' 15 | adatas = glob(filepath + '/*/*.h5ad') 16 | 17 | results_df = pd.DataFrame() 18 | 19 | for idx, adp in tqdm(enumerate(adatas), total=len(adatas)): 20 | print(adp) 21 | sample_folder = '/'.join(adp.split('/')[:-1]) + '/' 22 | device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu') 23 | 24 | # Check if all necessary output files already 
exist in the sample_folder 25 | if (os.path.exists(sample_folder + 'graphst_comps.csv') and 26 | os.path.exists(sample_folder + 'graphst_pearson.csv') and 27 | os.path.exists(sample_folder + 'graphst_corr_heatmap.png')): 28 | print(f"Skipping {sample_folder} as output files already exist.") 29 | continue 30 | 31 | adata = sc.read_h5ad(adp) 32 | adata.var_names_make_unique() 33 | 34 | try: 35 | # first attempt without any filtering 36 | sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000) 37 | except: 38 | try: 39 | # second attempt with min_counts=10 40 | sc.pp.filter_genes(adata, min_counts=10, inplace=True) 41 | # sc.pp.filter_cells(adata, min_counts=100, inplace=True) 42 | sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000) 43 | except: 44 | # final attempt with min_counts=100 45 | sc.pp.filter_genes(adata, min_counts=100, inplace=True) 46 | # sc.pp.filter_cells(adata, min_counts=100, inplace=True) 47 | sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000) 48 | 49 | sc.pp.normalize_total(adata, target_sum=1e4) 50 | sc.pp.log1p(adata) 51 | sc.pp.scale(adata, zero_center=False, max_value=10) 52 | 53 | # define model 54 | model = GraphST.GraphST(adata, device=torch.device('cpu')) 55 | 56 | # train model 57 | adata = model.train() 58 | 59 | graphst_df = pd.DataFrame(adata.obsm['emb']) 60 | graphst_df.index = adata.obs.index 61 | 62 | pca = PCA(n_components=20, svd_solver='arpack', random_state=42) 63 | graphst_pcs = pca.fit_transform(graphst_df) 64 | graphst_pcs_df = pd.DataFrame(data=graphst_pcs, index=graphst_df.index) 65 | 66 | graphst_pcs_df.to_csv(sample_folder + 'graphst_comps.csv') 67 | 68 | tissue_zone_df = adata.obsm['tissue_zones'] 69 | # tissue_zone_df = tissue_zone_df[[c for c in tissue_zone_df.columns if 'uniform' not in c]] 70 | 71 | corr_df = get_correlation_df(tissue_zone_df, graphst_pcs_df) 72 | corr_df.to_csv(sample_folder + 'graphst_pearson.csv') 73 | 74 | fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8)) 75 | sns.heatmap(corr_df, square=True, center=0, ax=ax) 76 | plt.tight_layout() 77 | plt.savefig(sample_folder + 'graphst_corr_heatmap.png') 78 | plt.close() 79 | -------------------------------------------------------------------------------- /article/A1_synthetic_data/method_scripts/contamination_benchmark/mefisto.py: -------------------------------------------------------------------------------- 1 | import os 2 | import mofax 3 | import pandas as pd 4 | import scanpy as sc 5 | from tqdm import tqdm 6 | from glob import glob 7 | import seaborn as sns 8 | import matplotlib.pyplot as plt 9 | from mofapy2.run.entry_point import entry_point 10 | from article.A1_synthetic_data.bm_functions import get_correlation_df 11 | 12 | 13 | filepath = '/storage/homefs/pt22a065/chr_data/tabula_sapiens_immune_contamination' 14 | adatas = glob(filepath + '/*/*.h5ad') 15 | 16 | results_df = pd.DataFrame() 17 | 18 | for idx, adp in tqdm(enumerate(adatas), total=len(adatas)): 19 | print(adp) 20 | sample_folder = '/'.join(adp.split('/')[:-1]) + '/' 21 | 22 | # Check if all necessary output files already exist in the sample_folder 23 | if (os.path.exists(sample_folder + 'mefisto_comps.csv') and 24 | os.path.exists(sample_folder + 'mefisto_pearson.csv') and 25 | os.path.exists(sample_folder + 'mefisto_corr_heatmap.png')): 26 | print(f"Skipping {sample_folder} as output files already exist.") 27 | continue 28 | 29 | adata = sc.read_h5ad(adp) 30 | 31 | # number of non-uniform compartments 32 | uniform = 
int(adata.uns['parameters']['n_tissue_zones'] * adata.uns['parameters']['uniform']) 33 | tissue_zones = adata.uns['parameters']['n_tissue_zones'] - uniform 34 | tissue_zones = int(tissue_zones) 35 | 36 | sc.pp.filter_genes(adata, min_cells=len(adata) * 0.05) 37 | 38 | sc.pp.normalize_total(adata, inplace=True) 39 | sc.pp.log1p(adata) 40 | sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=2000) 41 | 42 | adata.obs = pd.concat([adata.obs, 43 | pd.DataFrame(adata.obsm["spatial"], columns=["imagerow", "imagecol"], 44 | index=adata.obs_names), 45 | ], axis=1) 46 | 47 | ent = entry_point() 48 | ent.set_data_options(use_float32=True) 49 | ent.set_data_from_anndata(adata, features_subset="highly_variable") 50 | 51 | ent.set_model_options(factors=tissue_zones) 52 | ent.set_train_options(save_interrupted=True) 53 | ent.set_train_options(seed=2021) 54 | 55 | # We use 1000 inducing points to learn spatial covariance patterns 56 | n_inducing = 1000 # 500 for size tests 57 | 58 | ent.set_covariates([adata.obsm["spatial"]], covariates_names=["imagerow", "imagecol"]) 59 | ent.set_smooth_options(sparseGP=True, frac_inducing=n_inducing / adata.n_obs, 60 | start_opt=10, opt_freq=10) 61 | 62 | ent.build() 63 | ent.run() 64 | ent.save(sample_folder + "mefisto_temp.hdf5") 65 | m = mofax.mofa_model(sample_folder + "mefisto_temp.hdf5") 66 | factor_df = m.get_factors(df=True) 67 | 68 | # factor_df = ent.model.getFactors(df=True) 69 | 70 | factor_df.to_csv(sample_folder + 'mefisto_comps.csv') 71 | 72 | # correlation with tissue zones 73 | tissue_zone_df = adata.obsm['tissue_zones'] 74 | # tissue_zone_df = tissue_zone_df[[c for c in tissue_zone_df.columns if 'uniform' not in c]] 75 | 76 | corr_df = get_correlation_df(tissue_zone_df, factor_df) 77 | corr_df.to_csv(sample_folder + 'mefisto_pearson.csv') 78 | 79 | fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8)) 80 | sns.heatmap(corr_df, square=True, center=0, ax=ax) 81 | plt.tight_layout() 82 | plt.savefig(sample_folder + 'mefisto_corr_heatmap.png') 83 | plt.close() 84 | -------------------------------------------------------------------------------- /article/A1_synthetic_data/method_scripts/contamination_benchmark/nsf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | import numpy as np 5 | import pandas as pd 6 | import scanpy as sc 7 | from os import path 8 | from tqdm import tqdm 9 | from glob import glob 10 | import seaborn as sns 11 | from nsf.models import sf 12 | import matplotlib.pyplot as plt 13 | from numpy.linalg import LinAlgError 14 | from tensorflow_probability import math as tm 15 | from article.A1_synthetic_data.bm_functions import get_correlation_df 16 | from nsf.utils import preprocess, training, visualize, postprocess, misc 17 | 18 | 19 | filepath = '/storage/homefs/pt22a065/chr_data/tabula_sapiens_immune_contamination' 20 | adatas = glob(filepath + '/*/*.h5ad') 21 | tfk = tm.psd_kernels 22 | 23 | results_df = pd.DataFrame() 24 | 25 | for idx, adp in tqdm(enumerate(adatas), total=len(adatas)): 26 | try: 27 | print(adp) 28 | sample_folder = '/'.join(adp.split('/')[:-1]) + '/' 29 | 30 | # Check if all necessary output files already exist in the sample_folder 31 | if (os.path.exists(sample_folder + 'nsf_comps.csv') and 32 | os.path.exists(sample_folder + 'nsf_pearson.csv') and 33 | os.path.exists(sample_folder + 'nsf_corr_heatmap.png')): 34 | print(f"Skipping {sample_folder} as output files already exist.") 35 | continue 36 | 37 | adata = 
sc.read_h5ad(adp) 38 | 39 | # number of non-uniform compartments 40 | uniform = int(adata.uns['parameters']['n_tissue_zones'] * adata.uns['parameters']['uniform']) 41 | tissue_zones = adata.uns['parameters']['n_tissue_zones'] - uniform 42 | tissue_zones = int(tissue_zones) 43 | 44 | sc.pp.calculate_qc_metrics(adata, inplace=True) 45 | # sc.pp.filter_genes(adata, min_cells=len(adata) * 0.05) 46 | sc.pp.filter_genes(adata, min_counts=100, inplace=True) 47 | 48 | sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000) 49 | adata.layers = {"counts": adata.X.copy()} # store raw counts before normalization changes ad.X 50 | sc.pp.normalize_total(adata, inplace=True, layers=None, key_added="sizefactor") 51 | sc.pp.log1p(adata) 52 | 53 | adata.var['deviance_poisson'] = preprocess.deviancePoisson(adata.layers["counts"]) 54 | o = np.argsort(-adata.var['deviance_poisson']) 55 | idx = list(range(adata.shape[0])) 56 | random.shuffle(idx) 57 | adata = adata[idx, o] 58 | adata = adata[:, :2000] 59 | 60 | Dtr, Dval = preprocess.anndata_to_train_val(adata, layer="counts", 61 | sz="constant") # size factor set to constant from scanpy for 62 | # the samples with y nan error 63 | Dtr_n, Dval_n = preprocess.anndata_to_train_val(adata) # normalized data 64 | fmeans, Dtr_c, Dval_c = preprocess.center_data(Dtr_n, Dval_n) # centered features 65 | Xtr = Dtr["X"] # note this should be identical to Dtr_n["X"] 66 | Ntr = Xtr.shape[0] 67 | Dtf = preprocess.prepare_datasets_tf(Dtr, Dval=Dval, shuffle=False) 68 | Dtf_n = preprocess.prepare_datasets_tf(Dtr_n, Dval=Dval_n, shuffle=False) 69 | Dtf_c = preprocess.prepare_datasets_tf(Dtr_c, Dval=Dval_c, shuffle=False) 70 | visualize.heatmap(Xtr, Dtr["Y"][:, 0], marker="D", s=15) 71 | plt.close() 72 | 73 | # Visualize raw data 74 | plt.imshow(np.log1p(Dtr["Y"])[:50, :100], cmap="Blues") 75 | plt.close() 76 | 77 | # Visualize inducing points 78 | Z = misc.kmeans_inducing_pts(Xtr, 500) 79 | fig, ax = plt.subplots(figsize=(12, 10)) 80 | ax.scatter(Xtr[:, 0], Xtr[:, 1], marker="D", s=50, ) 81 | ax.scatter(Z[:, 0], Z[:, 1], c="red", s=30) 82 | plt.close() 83 | 84 | # initialize inducing points and tuning parameters 85 | Z = misc.kmeans_inducing_pts(Xtr, 1000) # 2363 86 | M = Z.shape[0] 87 | ker = tfk.MaternThreeHalves 88 | S = 3 # samples for elbo approximation 89 | # NSF: Spatial only with non-negative factors 90 | L = tissue_zones # number of latent factors, ideally divisible by 2 91 | J = adata.shape[1] # 2000 92 | 93 | mpth = path.join("/storage/homefs/pt22a065/chr_benchmarks/nsf/models/V5/") 94 | 95 | fit = sf.SpatialFactorization(J, L, Z, psd_kernel=ker, nonneg=True, lik="poi") 96 | fit.elbo_avg(Xtr, Dtr["Y"], sz=Dtr["sz"]) 97 | fit.init_loadings(Dtr["Y"], X=Xtr, sz=Dtr["sz"]) 98 | fit.elbo_avg(Xtr, Dtr["Y"], sz=Dtr["sz"]) 99 | pp = fit.generate_pickle_path("scanpy", base=mpth) 100 | tro = training.ModelTrainer(fit, pickle_path=pp) 101 | tro.train_model(*Dtf, ckpt_freq=10000) 102 | 103 | ttl = "NSF: spatial, non-negative factors, Poisson likelihood" 104 | visualize.plot_loss(tro.loss, title=ttl) # ,ss=range(2000,4000)) 105 | plt.savefig(sample_folder + 'nsf_loss.png') 106 | plt.close() 107 | 108 | hmkw = {"figsize": (4, 4), "s": 0.3, "marker": "D", "subplot_space": 0, 109 | "spinecolor": "white"} 110 | insf = postprocess.interpret_nsf(fit, Xtr, S=10, lda_mode=False) 111 | tgnames = [str(i) for i in range(1, L + 1)] 112 | 113 | # fig, axes = visualize.multiheatmap(Xtr, np.sqrt(insf["factors"]), (4, 3), **hmkw) 114 | # visualize.set_titles(fig, tgnames, x=0.05, 
y=.85, fontsize="medium", c="white", 115 |         #                      ha="left", va="top") 116 |         # plt.savefig(sample_folder + 'nsf_comps.png') 117 |         # plt.close() 118 | 119 |         data = {'factors': insf, 'positions': Xtr} 120 | 121 | 122 |         def transform_coords(X): 123 |             # code from nsf github 124 |             X[:, 1] = -X[:, 1] 125 |             xmin = X.min(axis=0) 126 |             X -= xmin 127 |             x_gmean = np.exp(np.mean(np.log(X.max(axis=0)))) 128 |             X *= 4 / x_gmean 129 |             return X - X.mean(axis=0) 130 | 131 | 132 |         X = adata.obsm["spatial"].copy().astype('float32') 133 |         tcoords = transform_coords(X) 134 | 135 |         pair_idx = [] 136 |         for xy in data['positions']: 137 |             distances = [math.dist([xy[0], xy[1]], [idx[0], idx[1]]) for idx in tcoords] 138 |             pair_idx.append(np.argmin(distances)) 139 | 140 |         nsf_df = pd.DataFrame(data=np.zeros([len(adata), data['factors']['factors'].shape[1]])) 141 |         for idx, i in enumerate(pair_idx): 142 |             nsf_df.iloc[i, :] = data['factors']['factors'][idx, :] 143 |         nsf_df.index = adata.obs.index 144 | 145 |         nsf_df.to_csv(sample_folder + 'nsf_comps.csv') 146 |         nsf_df = nsf_df[~(nsf_df.eq(0).all(axis=1))] 147 | 148 |         tissue_zone_df = adata.obsm['tissue_zones'].loc[list(nsf_df.index)] 149 | 150 |         corr_df = get_correlation_df(tissue_zone_df, nsf_df) 151 |         corr_df.to_csv(sample_folder + 'nsf_pearson.csv') 152 | 153 |         fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8)) 154 |         sns.heatmap(corr_df, square=True, center=0, ax=ax) 155 |         plt.tight_layout() 156 |         plt.savefig(sample_folder + 'nsf_corr_heatmap.png') 157 |         plt.close() 158 | 159 |     except (LinAlgError, ValueError) as e: 160 |         print(f"{type(e).__name__} encountered for {adp}. Continuing... {e}") 161 |         continue 162 | -------------------------------------------------------------------------------- /article/A1_synthetic_data/method_scripts/contamination_benchmark/stagate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import STAGATE 3 | import pandas as pd 4 | import scanpy as sc 5 | from tqdm import tqdm 6 | from glob import glob 7 | import seaborn as sns 8 | import sklearn.neighbors 9 | import matplotlib.pyplot as plt 10 | from article.A1_synthetic_data.bm_functions import get_correlation_df 11 | 12 | 13 | filepath = 'data/tabula_sapiens_immune_contamination' 14 | adatas = glob(filepath + '/*/*.h5ad') 15 | 16 | results_df = pd.DataFrame() 17 | 18 | for idx, adp in tqdm(enumerate(adatas), total=len(adatas)): 19 |     print(adp) 20 |     sample_folder = '/'.join(adp.split('/')[:-1]) + '/' 21 | 22 |     # Check if all necessary output files already exist in the sample_folder 23 |     if (os.path.exists(sample_folder + 'stagate_comps.csv') and 24 |             os.path.exists(sample_folder + 'stagate_pearson.csv') and 25 |             os.path.exists(sample_folder + 'stagate_corr_heatmap.png')): 26 |         print(f"Skipping {sample_folder} as output files already exist.") 27 |         continue 28 | 29 |     adata = sc.read_h5ad(adp) 30 | 31 |     try: 32 |         # first attempt without any filtering 33 |         sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000) 34 |     except Exception: 35 |         try: 36 |             # second attempt with min_counts=10 37 |             sc.pp.filter_genes(adata, min_counts=10, inplace=True) 38 |             # sc.pp.filter_cells(adata, min_counts=100, inplace=True) 39 |             sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000) 40 |         except Exception: 41 |             # final attempt with min_counts=100 42 |             sc.pp.filter_genes(adata, min_counts=100, inplace=True) 43 |             # sc.pp.filter_cells(adata, min_counts=100, inplace=True) 44 |             sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000) 45 | 46 |
sc.pp.normalize_total(adata, target_sum=1e4) 47 | sc.pp.log1p(adata) 48 | 49 | # include 8 neighbours with the cutoff similarly to 6 for visium 50 | STAGATE.Cal_Spatial_Net(adata, rad_cutoff=3.3) 51 | STAGATE.Stats_Spatial_Net(adata) 52 | 53 | coor = pd.DataFrame(adata.obsm['spatial']) 54 | coor.index = adata.obs.index 55 | coor.columns = ['imagerow', 'imagecol'] 56 | 57 | nbrs = sklearn.neighbors.NearestNeighbors(radius=3.3).fit(coor) 58 | distances, indices = nbrs.radius_neighbors(coor, return_distance=True) 59 | KNN_list = [] 60 | for it in range(indices.shape[0]): 61 | KNN_list.append(pd.DataFrame(zip([it]*indices[it].shape[0], indices[it], distances[it]))) 62 | KNN_df = pd.concat(KNN_list) 63 | KNN_df.columns = ['Cell1', 'Cell2', 'Distance'] 64 | 65 | adata = STAGATE.train_STAGATE(adata, alpha=0) 66 | 67 | for i in range(30): 68 | adata.obs[f'stagate_{i}'] = adata.obsm['STAGATE'][:, i] 69 | 70 | # doesn't work without the proper anndata structure - missing some spatial info 71 | # with mpl.rc_context({'figure.figsize': [4.5, 5]}): 72 | # sc.pl.spatial(adata, color=[f'S{i}' for i in range(30)], size=2, ncols=4, show=False) 73 | # plt.savefig(sample_folder + 'stagate_comps.png') 74 | # plt.close() 75 | 76 | # correlation with tissue zones 77 | stagate_df = adata.obs[[f'stagate_{i}' for i in range(30)]] 78 | stagate_df.to_csv(sample_folder + 'stagate_comps.csv') 79 | 80 | tissue_zone_df = adata.obsm['tissue_zones'] 81 | # tissue_zone_df = tissue_zone_df[[c for c in tissue_zone_df.columns if 'uniform' not in c]] 82 | 83 | corr_df = get_correlation_df(tissue_zone_df, stagate_df) 84 | corr_df.to_csv(sample_folder + 'stagate_pearson.csv') 85 | 86 | fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8)) 87 | sns.heatmap(corr_df, square=True, center=0, ax=ax) 88 | plt.tight_layout() 89 | plt.savefig(sample_folder + 'stagate_corr_heatmap.png') 90 | plt.close() 91 | -------------------------------------------------------------------------------- /article/A1_synthetic_data/method_scripts/main_synthetic_benchmark/chrysalis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import scanpy as sc 4 | from tqdm import tqdm 5 | import seaborn as sns 6 | from glob import glob 7 | import chrysalis as ch 8 | import matplotlib.pyplot as plt 9 | from article.A1_synthetic_data.bm_functions import get_correlation_df 10 | 11 | 12 | filepath = 'data/tabula_sapiens_immune' 13 | adatas = glob(filepath + '/*/*.h5ad') 14 | 15 | results_df = pd.DataFrame() 16 | 17 | for idx, adp in tqdm(enumerate(adatas), total=len(adatas)): 18 | print(adp) 19 | sample_folder = '/'.join(adp.split('/')[:-1]) + '/' 20 | adata = sc.read_h5ad(adp) 21 | 22 | # number of non-uniform compartments 23 | uniform = int(adata.uns['parameters']['n_tissue_zones'] * adata.uns['parameters']['uniform']) 24 | tissue_zones = adata.uns['parameters']['n_tissue_zones'] - uniform 25 | tissue_zones = int(tissue_zones) 26 | 27 | # chrysalis pipeline 28 | ch.detect_svgs(adata, neighbors=8, top_svg=1000, min_morans=0.01) 29 | ch.plot_svgs(adata) 30 | plt.savefig(sample_folder + 'ch_svgs.png') 31 | plt.close() 32 | 33 | sc.pp.normalize_total(adata, inplace=True) 34 | sc.pp.log1p(adata) 35 | 36 | ch.pca(adata, n_pcs=40) 37 | 38 | ch.plot_explained_variance(adata) 39 | plt.savefig(sample_folder + 'ch_expl_variance.png') 40 | plt.close() 41 | 42 | ch.aa(adata, n_pcs=20, n_archetypes=tissue_zones) 43 | 44 | ch.plot(adata, dim=tissue_zones, marker='s') 45 | plt.savefig(sample_folder + 
'ch_plot.png') 46 | plt.close() 47 | 48 | col_num = int(np.sqrt(tissue_zones)) 49 | ch.plot_compartments(adata, marker='s', ncols=col_num) 50 | plt.savefig(sample_folder + 'ch_comps.png') 51 | plt.close() 52 | 53 | # correlation with tissue zones 54 | compartments = adata.obsm['chr_aa'] 55 | compartment_df = pd.DataFrame(data=compartments, index=adata.obs.index) 56 | tissue_zone_df = adata.obsm['tissue_zones'] 57 | # tissue_zone_df = tissue_zone_df[[c for c in tissue_zone_df.columns if 'uniform' not in c]] 58 | 59 | corr_df = get_correlation_df(tissue_zone_df, compartment_df) 60 | corr_df.to_csv(sample_folder + 'pearson.csv') 61 | 62 | sns.heatmap(corr_df, square=True, center=0) 63 | plt.tight_layout() 64 | plt.savefig(sample_folder + 'corr_heatmap.png') 65 | plt.close() 66 | -------------------------------------------------------------------------------- /article/A1_synthetic_data/method_scripts/main_synthetic_benchmark/graphst.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | import scanpy as sc 4 | from tqdm import tqdm 5 | from glob import glob 6 | import seaborn as sns 7 | from GraphST import GraphST 8 | import matplotlib.pyplot as plt 9 | from sklearn.decomposition import PCA 10 | from article.A1_synthetic_data.bm_functions import get_correlation_df 11 | 12 | 13 | filepath = 'data/tabula_sapiens_immune' 14 | adatas = glob(filepath + '/*/*.h5ad') 15 | 16 | results_df = pd.DataFrame() 17 | 18 | for idx, adp in tqdm(enumerate(adatas), total=len(adatas)): 19 | print(adp) 20 | sample_folder = '/'.join(adp.split('/')[:-1]) + '/' 21 | device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu') 22 | adata = sc.read_h5ad(adp) 23 | adata.var_names_make_unique() 24 | 25 | # define model 26 | model = GraphST.GraphST(adata, device=torch.device('cpu')) 27 | 28 | # train model 29 | adata = model.train() 30 | 31 | graphst_df = pd.DataFrame(adata.obsm['emb']) 32 | graphst_df.index = adata.obs.index 33 | 34 | pca = PCA(n_components=20, svd_solver='arpack', random_state=42) 35 | graphst_pcs = pca.fit_transform(graphst_df) 36 | graphst_pcs_df = pd.DataFrame(data=graphst_pcs, index=graphst_df.index) 37 | 38 | graphst_pcs_df.to_csv(sample_folder + 'graphst_comps.csv') 39 | 40 | 41 | tissue_zone_df = adata.obsm['tissue_zones'] 42 | # tissue_zone_df = tissue_zone_df[[c for c in tissue_zone_df.columns if 'uniform' not in c]] 43 | 44 | corr_df = get_correlation_df(tissue_zone_df, graphst_pcs_df) 45 | corr_df.to_csv(sample_folder + 'graphst_pearson.csv') 46 | 47 | fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8)) 48 | sns.heatmap(corr_df, square=True, center=0, ax=ax) 49 | plt.tight_layout() 50 | plt.savefig(sample_folder + 'graphst_corr_heatmap.png') 51 | plt.close() 52 | -------------------------------------------------------------------------------- /article/A1_synthetic_data/method_scripts/main_synthetic_benchmark/mefisto.py: -------------------------------------------------------------------------------- 1 | import mofax 2 | import pandas as pd 3 | import scanpy as sc 4 | from tqdm import tqdm 5 | from glob import glob 6 | import seaborn as sns 7 | import matplotlib.pyplot as plt 8 | from mofapy2.run.entry_point import entry_point 9 | from article.A1_synthetic_data.bm_functions import get_correlation_df 10 | 11 | 12 | filepath = '/storage/homefs/pt22a065/chr_data/tabula_sapiens_immune' 13 | adatas = glob(filepath + '/*/*.h5ad') 14 | 15 | results_df = pd.DataFrame() 16 | 17 | for idx, adp in tqdm(enumerate(adatas), 
total=len(adatas)): 18 | print(adp) 19 | sample_folder = '/'.join(adp.split('/')[:-1]) + '/' 20 | adata = sc.read_h5ad(adp) 21 | 22 | # number of non-uniform compartments 23 | uniform = int(adata.uns['parameters']['n_tissue_zones'] * adata.uns['parameters']['uniform']) 24 | tissue_zones = adata.uns['parameters']['n_tissue_zones'] - uniform 25 | tissue_zones = int(tissue_zones) 26 | 27 | sc.pp.filter_genes(adata, min_cells=len(adata) * 0.05) 28 | 29 | sc.pp.normalize_total(adata, inplace=True) 30 | sc.pp.log1p(adata) 31 | sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=2000) 32 | 33 | adata.obs = pd.concat([adata.obs, 34 | pd.DataFrame(adata.obsm["spatial"], columns=["imagerow", "imagecol"], 35 | index=adata.obs_names), 36 | ], axis=1) 37 | 38 | ent = entry_point() 39 | ent.set_data_options(use_float32=True) 40 | ent.set_data_from_anndata(adata, features_subset="highly_variable") 41 | 42 | ent.set_model_options(factors=tissue_zones) 43 | ent.set_train_options(save_interrupted=True) 44 | ent.set_train_options(seed=2021) 45 | 46 | # We use 1000 inducing points to learn spatial covariance patterns 47 | n_inducing = 500 # 500 for size tests 48 | 49 | ent.set_covariates([adata.obsm["spatial"]], covariates_names=["imagerow", "imagecol"]) 50 | ent.set_smooth_options(sparseGP=True, frac_inducing=n_inducing / adata.n_obs, 51 | start_opt=10, opt_freq=10) 52 | 53 | ent.build() 54 | ent.run() 55 | ent.save(sample_folder + "mefisto_temp.hdf5") 56 | m = mofax.mofa_model(sample_folder + "mefisto_temp.hdf5") 57 | factor_df = m.get_factors(df=True) 58 | 59 | # factor_df = ent.model.getFactors(df=True) 60 | 61 | factor_df.to_csv(sample_folder + 'mefisto_comps.csv') 62 | 63 | # correlation with tissue zones 64 | tissue_zone_df = adata.obsm['tissue_zones'] 65 | # tissue_zone_df = tissue_zone_df[[c for c in tissue_zone_df.columns if 'uniform' not in c]] 66 | 67 | corr_df = get_correlation_df(tissue_zone_df, factor_df) 68 | corr_df.to_csv(sample_folder + 'mefisto_pearson.csv') 69 | 70 | fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8)) 71 | sns.heatmap(corr_df, square=True, center=0, ax=ax) 72 | plt.tight_layout() 73 | plt.savefig(sample_folder + 'mefisto_corr_heatmap.png') 74 | plt.close() 75 | -------------------------------------------------------------------------------- /article/A1_synthetic_data/method_scripts/main_synthetic_benchmark/nsf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | import numpy as np 5 | import pandas as pd 6 | import scanpy as sc 7 | from os import path 8 | from tqdm import tqdm 9 | from glob import glob 10 | import seaborn as sns 11 | from nsf.models import sf 12 | import matplotlib.pyplot as plt 13 | from numpy.linalg import LinAlgError 14 | from tensorflow_probability import math as tm 15 | from article.A1_synthetic_data.bm_functions import get_correlation_df 16 | from nsf.utils import preprocess, training, visualize, postprocess, misc 17 | 18 | 19 | filepath = 'data/tabula_sapiens_immune' 20 | adatas = glob(filepath + '/*/*.h5ad') 21 | tfk = tm.psd_kernels 22 | 23 | results_df = pd.DataFrame() 24 | 25 | for idx, adp in tqdm(enumerate(adatas), total=len(adatas)): 26 | try: 27 | print(adp) 28 | sample_folder = '/'.join(adp.split('/')[:-1]) + '/' 29 | 30 | # Check if all necessary output files already exist in the sample_folder 31 | if (os.path.exists(sample_folder + 'nsf_comps.csv') and 32 | os.path.exists(sample_folder + 'nsf_pearson.csv') and 33 | os.path.exists(sample_folder + 
'nsf_corr_heatmap.png')): 34 | print(f"Skipping {sample_folder} as output files already exist.") 35 | continue 36 | 37 | adata = sc.read_h5ad(adp) 38 | 39 | # number of non-uniform compartments 40 | uniform = int(adata.uns['parameters']['n_tissue_zones'] * adata.uns['parameters']['uniform']) 41 | tissue_zones = adata.uns['parameters']['n_tissue_zones'] - uniform 42 | tissue_zones = int(tissue_zones) 43 | 44 | sc.pp.calculate_qc_metrics(adata, inplace=True) 45 | # sc.pp.filter_genes(adata, min_cells=len(adata) * 0.05) 46 | sc.pp.filter_genes(adata, min_counts=100, inplace=True) 47 | 48 | sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000) 49 | adata.layers = {"counts": adata.X.copy()} # store raw counts before normalization changes ad.X 50 | sc.pp.normalize_total(adata, inplace=True, layers=None, key_added="sizefactor") 51 | sc.pp.log1p(adata) 52 | 53 | adata.var['deviance_poisson'] = preprocess.deviancePoisson(adata.layers["counts"]) 54 | o = np.argsort(-adata.var['deviance_poisson']) 55 | idx = list(range(adata.shape[0])) 56 | random.shuffle(idx) 57 | adata = adata[idx, o] 58 | adata = adata[:, :2000] 59 | 60 | Dtr, Dval = preprocess.anndata_to_train_val(adata, layer="counts", 61 | sz="constant") # size factor set to constant from scanpy for 62 | # the samples with y nan error 63 | Dtr_n, Dval_n = preprocess.anndata_to_train_val(adata) # normalized data 64 | fmeans, Dtr_c, Dval_c = preprocess.center_data(Dtr_n, Dval_n) # centered features 65 | Xtr = Dtr["X"] # note this should be identical to Dtr_n["X"] 66 | Ntr = Xtr.shape[0] 67 | Dtf = preprocess.prepare_datasets_tf(Dtr, Dval=Dval, shuffle=False) 68 | Dtf_n = preprocess.prepare_datasets_tf(Dtr_n, Dval=Dval_n, shuffle=False) 69 | Dtf_c = preprocess.prepare_datasets_tf(Dtr_c, Dval=Dval_c, shuffle=False) 70 | visualize.heatmap(Xtr, Dtr["Y"][:, 0], marker="D", s=15) 71 | plt.close() 72 | 73 | # Visualize raw data 74 | plt.imshow(np.log1p(Dtr["Y"])[:50, :100], cmap="Blues") 75 | plt.close() 76 | 77 | # Visualize inducing points 78 | Z = misc.kmeans_inducing_pts(Xtr, 500) 79 | fig, ax = plt.subplots(figsize=(12, 10)) 80 | ax.scatter(Xtr[:, 0], Xtr[:, 1], marker="D", s=50, ) 81 | ax.scatter(Z[:, 0], Z[:, 1], c="red", s=30) 82 | plt.close() 83 | 84 | # initialize inducing points and tuning parameters 85 | Z = misc.kmeans_inducing_pts(Xtr, 2363) # 2363 86 | M = Z.shape[0] 87 | ker = tfk.MaternThreeHalves 88 | S = 3 # samples for elbo approximation 89 | # NSF: Spatial only with non-negative factors 90 | L = tissue_zones # number of latent factors, ideally divisible by 2 91 | J = adata.shape[1] # 2000 92 | 93 | mpth = path.join("/storage/homefs/pt22a065/chr_benchmarks/nsf/models/V5/") 94 | 95 | fit = sf.SpatialFactorization(J, L, Z, psd_kernel=ker, nonneg=True, lik="poi") 96 | fit.elbo_avg(Xtr, Dtr["Y"], sz=Dtr["sz"]) 97 | fit.init_loadings(Dtr["Y"], X=Xtr, sz=Dtr["sz"]) 98 | fit.elbo_avg(Xtr, Dtr["Y"], sz=Dtr["sz"]) 99 | pp = fit.generate_pickle_path("scanpy", base=mpth) 100 | tro = training.ModelTrainer(fit, pickle_path=pp) 101 | tro.train_model(*Dtf, ckpt_freq=10000) 102 | 103 | ttl = "NSF: spatial, non-negative factors, Poisson likelihood" 104 | visualize.plot_loss(tro.loss, title=ttl) # ,ss=range(2000,4000)) 105 | plt.savefig(sample_folder + 'nsf_loss.png') 106 | plt.close() 107 | 108 | hmkw = {"figsize": (4, 4), "s": 0.3, "marker": "D", "subplot_space": 0, 109 | "spinecolor": "white"} 110 | insf = postprocess.interpret_nsf(fit, Xtr, S=10, lda_mode=False) 111 | tgnames = [str(i) for i in range(1, L + 1)] 112 | 113 | # 
fig, axes = visualize.multiheatmap(Xtr, np.sqrt(insf["factors"]), (4, 3), **hmkw) 114 |         # visualize.set_titles(fig, tgnames, x=0.05, y=.85, fontsize="medium", c="white", 115 |         #                      ha="left", va="top") 116 |         # plt.savefig(sample_folder + 'nsf_comps.png') 117 |         # plt.close() 118 | 119 |         data = {'factors': insf, 'positions': Xtr} 120 | 121 | 122 |         def transform_coords(X): 123 |             # code from nsf github 124 |             X[:, 1] = -X[:, 1] 125 |             xmin = X.min(axis=0) 126 |             X -= xmin 127 |             x_gmean = np.exp(np.mean(np.log(X.max(axis=0)))) 128 |             X *= 4 / x_gmean 129 |             return X - X.mean(axis=0) 130 | 131 | 132 |         X = adata.obsm["spatial"].copy().astype('float32') 133 |         tcoords = transform_coords(X) 134 | 135 |         pair_idx = [] 136 |         for xy in data['positions']: 137 |             distances = [math.dist([xy[0], xy[1]], [idx[0], idx[1]]) for idx in tcoords] 138 |             pair_idx.append(np.argmin(distances)) 139 | 140 |         nsf_df = pd.DataFrame(data=np.zeros([len(adata), data['factors']['factors'].shape[1]])) 141 |         for idx, i in enumerate(pair_idx): 142 |             nsf_df.iloc[i, :] = data['factors']['factors'][idx, :] 143 |         nsf_df.index = adata.obs.index 144 | 145 |         nsf_df.to_csv(sample_folder + 'nsf_comps.csv') 146 |         nsf_df = nsf_df[~(nsf_df.eq(0).all(axis=1))] 147 | 148 |         tissue_zone_df = adata.obsm['tissue_zones'].loc[list(nsf_df.index)] 149 | 150 |         corr_df = get_correlation_df(tissue_zone_df, nsf_df) 151 |         corr_df.to_csv(sample_folder + 'nsf_pearson.csv') 152 | 153 |         fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8)) 154 |         sns.heatmap(corr_df, square=True, center=0, ax=ax) 155 |         plt.tight_layout() 156 |         plt.savefig(sample_folder + 'nsf_corr_heatmap.png') 157 |         plt.close() 158 | 159 |     except (LinAlgError, ValueError) as e: 160 |         print(f"{type(e).__name__} encountered for {adp}. Continuing... {e}") 161 |         continue 162 | -------------------------------------------------------------------------------- /article/A1_synthetic_data/method_scripts/main_synthetic_benchmark/stagate.py: -------------------------------------------------------------------------------- 1 | import STAGATE 2 | import pandas as pd 3 | import scanpy as sc 4 | from tqdm import tqdm 5 | from glob import glob 6 | import seaborn as sns 7 | import sklearn.neighbors 8 | import matplotlib.pyplot as plt 9 | from article.A1_synthetic_data.bm_functions import get_correlation_df 10 | 11 | 12 | filepath = 'data/tabula_sapiens_immune' 13 | adatas = glob(filepath + '/*/*.h5ad') 14 | 15 | results_df = pd.DataFrame() 16 | 17 | for idx, adp in tqdm(enumerate(adatas), total=len(adatas)): 18 |     print(adp) 19 |     sample_folder = '/'.join(adp.split('/')[:-1]) + '/' 20 |     adata = sc.read_h5ad(adp) 21 | 22 |     sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000) 23 |     sc.pp.normalize_total(adata, target_sum=1e4) 24 |     sc.pp.log1p(adata) 25 | 26 |     # include 8 neighbours with the cutoff similarly to 6 for visium 27 |     STAGATE.Cal_Spatial_Net(adata, rad_cutoff=3.3) 28 |     STAGATE.Stats_Spatial_Net(adata) 29 | 30 |     coor = pd.DataFrame(adata.obsm['spatial']) 31 |     coor.index = adata.obs.index 32 |     coor.columns = ['imagerow', 'imagecol'] 33 | 34 |     nbrs = sklearn.neighbors.NearestNeighbors(radius=3.3).fit(coor) 35 |     distances, indices = nbrs.radius_neighbors(coor, return_distance=True) 36 |     KNN_list = [] 37 |     for it in range(indices.shape[0]): 38 |         KNN_list.append(pd.DataFrame(zip([it]*indices[it].shape[0], indices[it], distances[it]))) 39 |     KNN_df = pd.concat(KNN_list) 40 |     KNN_df.columns = ['Cell1', 'Cell2', 'Distance'] 41 | 42 |     adata = STAGATE.train_STAGATE(adata, alpha=0) 43 | 44 |     for i in range(30): 45 |
adata.obs[f'stagate_{i}'] = adata.obsm['STAGATE'][:, i] 46 | 47 | # doesn't work without the proper anndata structure - missing some spatial info 48 | # with mpl.rc_context({'figure.figsize': [4.5, 5]}): 49 | # sc.pl.spatial(adata, color=[f'S{i}' for i in range(30)], size=2, ncols=4, show=False) 50 | # plt.savefig(sample_folder + 'stagate_comps.png') 51 | # plt.close() 52 | 53 | # correlation with tissue zones 54 | stagate_df = adata.obs[[f'stagate_{i}' for i in range(30)]] 55 | stagate_df.to_csv(sample_folder + 'stagate_comps.csv') 56 | 57 | tissue_zone_df = adata.obsm['tissue_zones'] 58 | # tissue_zone_df = tissue_zone_df[[c for c in tissue_zone_df.columns if 'uniform' not in c]] 59 | 60 | corr_df = get_correlation_df(tissue_zone_df, stagate_df) 61 | corr_df.to_csv(sample_folder + 'stagate_pearson.csv') 62 | 63 | fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8)) 64 | sns.heatmap(corr_df, square=True, center=0, ax=ax) 65 | plt.tight_layout() 66 | plt.savefig(sample_folder + 'stagate_corr_heatmap.png') 67 | plt.close() 68 | -------------------------------------------------------------------------------- /article/A2_human_lymph_node/SVG_detection_methods/1_bsp_spatialde_sepal.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "outputs": [], 7 | "source": [ 8 | "from SpatialDE import test\n", 9 | "import pickle\n", 10 | "import pandas as pd\n", 11 | "import scanpy as sc\n", 12 | "import squidpy as sq\n", 13 | "\n", 14 | "data_path = '/mnt/c/Users/demeter_turos/PycharmProjects/chrysalis/data/cell2loc_human_lymph_node/'" 15 | ], 16 | "metadata": { 17 | "collapsed": false 18 | } 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "outputs": [], 24 | "source": [ 25 | "# BSP\n", 26 | "\n", 27 | "adata = sc.datasets.visium_sge(sample_id='V1_Human_Lymph_Node')\n", 28 | "\n", 29 | "adata.var_names_make_unique()\n", 30 | "\n", 31 | "sc.pp.calculate_qc_metrics(adata, inplace=True)\n", 32 | "sc.pp.filter_cells(adata, min_counts=6000)\n", 33 | "sc.pp.filter_genes(adata, min_cells=10)\n", 34 | "\n", 35 | "data = adata.to_df().astype(int)\n", 36 | "locs = adata.obsm['spatial']\n", 37 | "locs_df = pd.DataFrame(locs, columns=['x', 'y'])\n", 38 | "\n", 39 | "data.to_csv(data_path + 'lymph_node/counts.csv')\n", 40 | "locs_df.to_csv(data_path + 'lymph_node/locs.csv', index=False)\n", 41 | "\n", 42 | "# BSP was run via CLI using counts.csv and locs.csv\n", 43 | "# python BSP.py --datasetName lymph_node --spaLocFilename locs.csv --expFilename counts.csv" 44 | ], 45 | "metadata": { 46 | "collapsed": false 47 | } 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "id": "577377be", 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "# spatialDE\n", 57 | "\n", 58 | "adata = sc.datasets.visium_sge(sample_id='V1_Human_Lymph_Node')\n", 59 | "\n", 60 | "adata.var_names_make_unique()\n", 61 | "\n", 62 | "sc.pp.calculate_qc_metrics(adata, inplace=True)\n", 63 | "sc.pp.filter_cells(adata, min_counts=6000)\n", 64 | "sc.pp.filter_genes(adata, min_cells=10)\n", 65 | "\n", 66 | "# sc.pp.normalize_total(adata, inplace=True)\n", 67 | "# sc.pp.log1p(adata)\n", 68 | "\n", 69 | "results_t = test(adata)\n", 70 | "\n", 71 | "with open(data_path + 'spatialde.pickle', 'wb') as handle:\n", 72 | " pickle.dump(results_t, handle, protocol=pickle.HIGHEST_PROTOCOL)\n", 73 | "\n", 74 | "with open(data_path + 'spatialde.pickle', 'rb') as f:\n", 75 | " test = 
pickle.load(f)\n", 76 |     "\n", 77 |     "test[0].to_csv(data_path + 'spatialde.csv')" 78 |    ] 79 |   }, 80 |   { 81 |    "cell_type": "code", 82 |    "execution_count": null, 83 |    "id": "0ee072f2", 84 |    "metadata": {}, 85 |    "outputs": [], 86 |    "source": [ 87 |     "# Sepal\n", 88 |     "\n", 89 |     "adata = sc.datasets.visium_sge(sample_id='V1_Human_Lymph_Node')\n", 90 |     "\n", 91 |     "adata.var_names_make_unique()\n", 92 |     "\n", 93 |     "sc.pp.calculate_qc_metrics(adata, inplace=True)\n", 94 |     "sc.pp.filter_cells(adata, min_counts=6000)\n", 95 |     "sc.pp.filter_genes(adata, min_cells=10)\n", 96 |     "\n", 97 |     "sq.gr.spatial_neighbors(adata)\n", 98 |     "genes = list(adata.var_names)\n", 99 |     "sq.gr.sepal(adata, max_neighs=6, genes=genes, n_jobs=1)\n", 100 |     "adata.uns[\"sepal_score\"].head(10)\n", 101 |     "sepal_df = adata.uns[\"sepal_score\"]\n", 102 |     "\n", 103 |     "sepal_df.to_csv(data_path + 'sepal.csv')" 104 |    ] 105 |   } 106 |  ], 107 |  "metadata": { 108 |   "kernelspec": { 109 |    "display_name": "Python 3 (ipykernel)", 110 |    "language": "python", 111 |    "name": "python3" 112 |   }, 113 |   "language_info": { 114 |    "codemirror_mode": { 115 |     "name": "ipython", 116 |     "version": 3 117 |    }, 118 |    "file_extension": ".py", 119 |    "mimetype": "text/x-python", 120 |    "name": "python", 121 |    "nbconvert_exporter": "python", 122 |    "pygments_lexer": "ipython3", 123 |    "version": "3.8.15" 124 |   } 125 |  }, 126 |  "nbformat": 4, 127 |  "nbformat_minor": 5 128 | } 129 | -------------------------------------------------------------------------------- /article/A2_human_lymph_node/SVG_detection_methods/2_spark.R: -------------------------------------------------------------------------------- 1 | library(Seurat) 2 | library(SPARK) 3 | 4 | # Read sparse matrix from h5 file into Seurat object 5 | adata <- Load10X_Spatial("C:/Users/demeter_turos/PycharmProjects/chrysalis/data/V1_Human_Lymph_Node/") 6 | 7 | adata <- PercentageFeatureSet(adata, "^mt-", col.name = "percent_mito") 8 | 9 | adata <- subset(adata, subset = nCount_Spatial > 6000) 10 | gene_counts <- rowSums(GetAssayData(adata, slot = "counts") > 0) 11 | keep_genes <- names(gene_counts[gene_counts >= 10]) 12 | adata <- subset(adata, features = keep_genes) 13 | 14 | expression_data <- as.matrix(adata@assays$Spatial@data) 15 | dim(expression_data) 16 | locs <- GetTissueCoordinates(adata) 17 | 18 | sparkX <- sparkx(expression_data,locs,numCores=1,option="mixture") 19 | 20 | head(sparkX$res_mtest) 21 | write.csv(sparkX$res_mtest, paste0("C:/Users/demeter_turos/PycharmProjects/chrysalis/dev/benchmarks/", 22 |                                    "fig_1_lymph_node_cell2loc/spark.csv")) 23 | -------------------------------------------------------------------------------- /article/A2_human_lymph_node/benchmarking/graphst.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import pandas as pd 4 | import scanpy as sc 5 | from GraphST import GraphST 6 | 7 | 8 | start_time = time.time() 9 | 10 | data_path = '/mnt/c/Users/demeter_turos/PycharmProjects/chrysalis/data/cell2loc_human_lymph_node/' 11 | device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu') 12 | 13 | adata = sc.read_h5ad(data_path + 'chr.h5ad') 14 | adata.var_names_make_unique() 15 | 16 | # define model 17 | model = GraphST.GraphST(adata, device=torch.device('cpu')) 18 | 19 | # train model 20 | adata = model.train() 21 | 22 | pd.DataFrame(adata.obsm['emb']).to_csv(data_path + 'graphst_lymph_node.csv') 23 | 24 | end_time = time.time() 25 | elapsed_time = end_time - start_time 26 | print(elapsed_time) 27 |
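# NOTE: `get_correlation_df` is imported from bm_functions throughout the benchmark scripts in this section; 28 | # its definition lives in article/A1_synthetic_data/bm_functions.py, which is not included in this listing. 29 | # Below is a minimal, hypothetical sketch of what such a helper plausibly computes -- a tissue-zone x factor 30 | # matrix of pairwise Pearson correlations -- assuming both arguments are spot-indexed DataFrames aligned on 31 | # the same index; the actual implementation may differ in detail. 32 | from scipy.stats import pearsonr 33 | 34 | 35 | def get_correlation_df_sketch(zone_df, factor_df): 36 |     # one Pearson r per (ground-truth tissue zone, inferred factor) pair 37 |     corr = pd.DataFrame(index=zone_df.columns, columns=factor_df.columns, dtype=float) 38 |     for z in zone_df.columns: 39 |         for f in factor_df.columns: 40 |             corr.loc[z, f] = pearsonr(zone_df[z], factor_df[f])[0] 41 |     return corr 42 | 43 | # the synthetic benchmarks above call this as, e.g.: corr_df = get_correlation_df(tissue_zone_df, graphst_pcs_df) 44 |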
-------------------------------------------------------------------------------- /article/A2_human_lymph_node/benchmarking/mefisto.py: -------------------------------------------------------------------------------- 1 | import time 2 | import mofax 3 | import pandas as pd 4 | import scanpy as sc 5 | from mofapy2.run.entry_point import entry_point 6 | 7 | 8 | start_time = time.time() 9 | 10 | datadir = "/mnt/c/Users/demeter_turos/PycharmProjects/chrysalis/data/cell2loc_human_lymph_node/mefisto/" 11 | 12 | data_path = '/mnt/c/Users/demeter_turos/PycharmProjects/chrysalis/data/cell2loc_human_lymph_node/' 13 | adata = sc.read_h5ad(data_path + 'chr.h5ad') 14 | 15 | sc.pp.normalize_total(adata, inplace=True) 16 | sc.pp.log1p(adata) 17 | sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=2000) 18 | 19 | adata.obs = pd.concat([adata.obs, 20 | pd.DataFrame(adata.obsm["spatial"], columns=["imagerow", "imagecol"], index=adata.obs_names), 21 | ], axis=1) 22 | 23 | ent = entry_point() 24 | ent.set_data_options(use_float32=True) 25 | ent.set_data_from_anndata(adata, features_subset="highly_variable") 26 | 27 | ent.set_model_options(factors=8) 28 | ent.set_train_options() 29 | ent.set_train_options(seed=2021) 30 | 31 | # We use 1000 inducing points to learn spatial covariance patterns 32 | n_inducing = 1000 33 | 34 | ent.set_covariates([adata.obsm["spatial"]], covariates_names=["imagerow", "imagecol"]) 35 | ent.set_smooth_options(sparseGP=True, frac_inducing=n_inducing/adata.n_obs, 36 | start_opt=10, opt_freq=10) 37 | 38 | ent.build() 39 | ent.run() 40 | ent.save(datadir + "ST_model.hdf5") 41 | 42 | m = mofax.mofa_model(datadir + "ST_model.hdf5") 43 | factor_df = m.get_factors(df=True) 44 | factor_df.to_csv(datadir + 'factors.csv') 45 | 46 | end_time = time.time() 47 | elapsed_time = end_time - start_time 48 | print(elapsed_time) 49 | -------------------------------------------------------------------------------- /article/A2_human_lymph_node/benchmarking/nsf.py: -------------------------------------------------------------------------------- 1 | import time 2 | import pickle 3 | import random 4 | import numpy as np 5 | import scanpy as sc 6 | from os import path 7 | import matplotlib.pyplot as plt 8 | from tensorflow_probability import math as tm 9 | from nsf.models import cf, sf, sfh 10 | from nsf.utils import preprocess, training, misc, visualize, postprocess 11 | 12 | 13 | tfk = tm.psd_kernels 14 | start_time = time.time() 15 | 16 | ad = sc.datasets.visium_sge(sample_id='V1_Human_Lymph_Node') 17 | 18 | ad.var_names_make_unique() 19 | sc.pp.calculate_qc_metrics(ad, inplace=True) 20 | sc.pp.filter_cells(ad, min_counts=6000) 21 | sc.pp.filter_genes(ad, min_cells=10) 22 | ad.layers = {"counts":ad.X.copy()} #store raw counts before normalization changes ad.X 23 | sc.pp.normalize_total(ad, inplace=True, layers=None, key_added="sizefactor") 24 | sc.pp.log1p(ad) 25 | 26 | ad.var['deviance_poisson'] = preprocess.deviancePoisson(ad.layers["counts"]) 27 | o = np.argsort(-ad.var['deviance_poisson']) 28 | idx = list(range(ad.shape[0])) 29 | random.shuffle(idx) 30 | ad = ad[idx,o] 31 | ad = ad[:,:2000] 32 | 33 | Dtr,Dval = preprocess.anndata_to_train_val(ad,layer="counts",sz="scanpy") 34 | Dtr_n,Dval_n = preprocess.anndata_to_train_val(ad) #normalized data 35 | fmeans,Dtr_c,Dval_c = preprocess.center_data(Dtr_n,Dval_n) #centered features 36 | Xtr = Dtr["X"] #note this should be identical to Dtr_n["X"] 37 | Ntr = Xtr.shape[0] 38 | Dtf = preprocess.prepare_datasets_tf(Dtr,Dval=Dval,shuffle=False) 39 | 
Dtf_n = preprocess.prepare_datasets_tf(Dtr_n,Dval=Dval_n,shuffle=False) 40 | Dtf_c = preprocess.prepare_datasets_tf(Dtr_c,Dval=Dval_c,shuffle=False) 41 | visualize.heatmap(Xtr,Dtr["Y"][:,0],marker="D",s=15) 42 | plt.show() 43 | 44 | # Visualize raw data 45 | plt.imshow(np.log1p(Dtr["Y"])[:50,:100],cmap="Blues") 46 | plt.show() 47 | 48 | # Visualize inducing points 49 | Z = misc.kmeans_inducing_pts(Xtr,500) 50 | fig,ax=plt.subplots(figsize=(12,10)) 51 | ax.scatter(Xtr[:,0],Xtr[:,1],marker="D",s=50,) 52 | ax.scatter(Z[:,0],Z[:,1],c="red",s=30) 53 | plt.show() 54 | 55 | # initialize inducing points and tuning parameters 56 | Z = misc.kmeans_inducing_pts(Xtr, 2363) 57 | M = Z.shape[0] 58 | ker = tfk.MaternThreeHalves 59 | S = 3 #samples for elbo approximation 60 | # NSF: Spatial only with non-negative factors 61 | L = 8 #number of latent factors, ideally divisible by 2 62 | J = 2000 63 | 64 | mpth = path.join("/mnt/c/Users/demeter_turos/PycharmProjects/deep_learning/nsf/models/V5") 65 | 66 | fit = sf.SpatialFactorization(J,L,Z,psd_kernel=ker,nonneg=True,lik="poi") 67 | fit.elbo_avg(Xtr,Dtr["Y"],sz=Dtr["sz"]) 68 | fit.init_loadings(Dtr["Y"],X=Xtr,sz=Dtr["sz"]) 69 | fit.elbo_avg(Xtr,Dtr["Y"],sz=Dtr["sz"]) 70 | pp = fit.generate_pickle_path("scanpy",base=mpth) 71 | tro = training.ModelTrainer(fit,pickle_path=pp) 72 | tro.train_model(*Dtf) 73 | 74 | ttl = "NSF: spatial, non-negative factors, Poisson likelihood" 75 | visualize.plot_loss(tro.loss,title=ttl)#,ss=range(2000,4000)) 76 | plt.show() 77 | 78 | hmkw = {"figsize":(4,4), "s":0.3, "marker":"D", "subplot_space":0, 79 | "spinecolor":"white"} 80 | insf = postprocess.interpret_nsf(fit,Xtr,S=10,lda_mode=False) 81 | tgnames = [str(i) for i in range(1,L+1)] 82 | fig,axes=visualize.multiheatmap(Xtr, np.sqrt(insf["factors"]), (4,3), **hmkw) 83 | visualize.set_titles(fig, tgnames, x=0.05, y=.85, fontsize="medium", c="white", 84 | ha="left", va="top") 85 | plt.show() 86 | 87 | end_time = time.time() 88 | elapsed_time = end_time - start_time 89 | print(elapsed_time) 90 | 91 | file = open(mpth + '/human_lymph_node_nsf.pkl', 'wb') 92 | pickle.dump({'factors': insf, 'positions': Xtr}, file) 93 | file.close() 94 | 95 | file = open(mpth + '/human_lymph_node_nsf.pkl', 'rb') 96 | data = pickle.load(file) 97 | file.close() -------------------------------------------------------------------------------- /article/A2_human_lymph_node/benchmarking/spatialpca.R: -------------------------------------------------------------------------------- 1 | library(SpatialPCA) 2 | library(ggplot2) 3 | library(Matrix) 4 | library(Seurat) 5 | 6 | 7 | # HUMAN LYMPH NODE 8 | # Read sparse matrix from h5 file into Seurat object 9 | 10 | start_time <- Sys.time() 11 | 12 | adata <- Load10X_Spatial("C:/Users/demeter_turos/PycharmProjects/chrysalis/data/V1_Human_Lymph_Node/",) 13 | 14 | adata <- PercentageFeatureSet(adata, "^mt-", col.name = "percent_mito") 15 | 16 | adata <- subset(adata, subset = nCount_Spatial > 6000) 17 | gene_counts <- rowSums(GetAssayData(adata, slot = "counts") > 0) 18 | keep_genes <- names(gene_counts[gene_counts >= 10]) 19 | adata <- subset(adata, features = keep_genes) 20 | 21 | xy_coords <- adata@images$slice1@coordinates 22 | xy_coords <- xy_coords[c('imagerow', 'imagecol')] 23 | colnames(xy_coords) <- c('x_coord', 'y_coord') 24 | 25 | count_sub <- adata@assays$Spatial@data 26 | print(dim(count_sub)) # The count matrix 27 | xy_coords <- as.matrix(xy_coords) 28 | rownames(xy_coords) <- colnames(count_sub) # the rownames of location should match with the 
colnames of count matrix 29 | LIBD <- CreateSpatialPCAObject(counts=count_sub, location=xy_coords, project="SpatialPCA", gene.type="spatial", 30 | sparkversion="spark", numCores_spark=5, gene.number=3000, customGenelist=NULL, 31 | min.loctions=20, min.features=20) 32 | 33 | LIBD <- SpatialPCA_buildKernel(LIBD, kerneltype="gaussian", bandwidthtype="SJ", bandwidth.set.by.user=NULL) 34 | LIBD <- SpatialPCA_EstimateLoading(LIBD,fast=FALSE,SpatialPCnum=20) 35 | LIBD <- SpatialPCA_SpatialPCs(LIBD, fast=FALSE) 36 | 37 | 38 | saveRDS(LIBD, file = "C:/Users/demeter_turos/PycharmProjects/chrysalis/data/cell2loc_human_lymph_node/spatialpca/libd.rds") 39 | 40 | LIBD <- readRDS(file = "C:/Users/demeter_turos/PycharmProjects/chrysalis/data/cell2loc_human_lymph_node/spatialpca/libd.rds") 41 | 42 | write.csv(LIBD@SpatialPCs, "C:/Users/demeter_turos/PycharmProjects/chrysalis/data/cell2loc_human_lymph_node/spatialpca/spatial_pcs.csv") 43 | 44 | clusterlabel <- walktrap_clustering(clusternum=8, latent_dat=LIBD@SpatialPCs, knearest=70) 45 | # here for all 12 samples in LIBD, we set the same k nearest number in walktrap_clustering to be 70. 46 | # for other Visium or ST data, the user can also set k nearest number as 47 | # round(sqrt(dim(SpatialPCAobject@SpatialPCs)[2])) by default. 48 | clusterlabel_refine <- refine_cluster_10x(clusterlabels=clusterlabel, location=LIBD@location, shape="hexagon") 49 | 50 | end_time <- Sys.time() 51 | 52 | elapsed_time <- end_time - start_time 53 | print(elapsed_time) 54 | 55 | cbp<-c('#db5f57', '#dbc257', '#91db57', '#57db80', '#57d3db', '#5770db', '#a157db', '#db57b2') 56 | plot_cluster(location=xy_coords,clusterlabel=clusterlabel_refine, pointsize=1.5, 57 | title_in=paste0("SpatialPCA"), color_in=cbp) 58 | -------------------------------------------------------------------------------- /article/A2_human_lymph_node/benchmarking/stagate.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import scanpy as sc 3 | import matplotlib as mpl 4 | import STAGATE 5 | import time 6 | 7 | 8 | start_time = time.time() 9 | 10 | adata = sc.datasets.visium_sge(sample_id='V1_Human_Lymph_Node') 11 | sc.pp.calculate_qc_metrics(adata, inplace=True) 12 | 13 | sc.pp.filter_cells(adata, min_counts=6000) 14 | sc.pp.filter_genes(adata, min_cells=10) 15 | 16 | sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000) 17 | sc.pp.normalize_total(adata, target_sum=1e4) 18 | sc.pp.log1p(adata) 19 | 20 | STAGATE.Cal_Spatial_Net(adata, rad_cutoff=150) 21 | STAGATE.Stats_Spatial_Net(adata) 22 | 23 | coor = pd.DataFrame(adata.obsm['spatial']) 24 | coor.index = adata.obs.index 25 | coor.columns = ['imagerow', 'imagecol'] 26 | import sklearn.neighbors 27 | 28 | nbrs = sklearn.neighbors.NearestNeighbors(radius=150).fit(coor) 29 | distances, indices = nbrs.radius_neighbors(coor, return_distance=True) 30 | KNN_list = [] 31 | for it in range(indices.shape[0]): 32 | KNN_list.append(pd.DataFrame(zip([it]*indices[it].shape[0], indices[it], distances[it]))) 33 | KNN_df = pd.concat(KNN_list) 34 | KNN_df.columns = ['Cell1', 'Cell2', 'Distance'] 35 | 36 | adata = STAGATE.train_STAGATE(adata, alpha=0) 37 | 38 | for i in range(30): 39 | adata.obs[f'S{i}'] = adata.obsm['STAGATE'][:, i] 40 | 41 | with mpl.rc_context({'figure.figsize': [4.5, 5]}): 42 | sc.pl.spatial(adata, color=[f'S{i}' for i in range(30)], size=2, ncols=4) 43 | 44 | stagate_df = adata.obs[[f'S{i}' for i in range(30)]] 45 | stagate_df.to_csv('stagate_lymph_node.csv') 46 | 47 | 
end_time = time.time() 48 | elapsed_time = end_time - start_time 49 | print(elapsed_time) -------------------------------------------------------------------------------- /article/A3_human_breast_cancer/benchmarking/graphst.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | import scanpy as sc 4 | from GraphST import GraphST 5 | 6 | 7 | data_path = '/mnt/c/Users/demeter_turos/PycharmProjects/chrysalis/data/xenium_human_breast_cancer/' 8 | 9 | device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu') 10 | 11 | adata = sc.read_h5ad(data_path + 'visium_sample.h5ad') 12 | adata.var_names_make_unique() 13 | 14 | # define model 15 | model = GraphST.GraphST(adata, device=torch.device('cpu')) 16 | # train model 17 | adata = model.train() 18 | 19 | pd.DataFrame(adata.obsm['emb']).to_csv(data_path + 'graphst_breast_cancer.csv') 20 | -------------------------------------------------------------------------------- /article/A3_human_breast_cancer/benchmarking/mefisto.py: -------------------------------------------------------------------------------- 1 | import time 2 | import mofax 3 | import pandas as pd 4 | import scanpy as sc 5 | from mofapy2.run.entry_point import entry_point 6 | 7 | 8 | datadir = "/mnt/c/Users/demeter_turos/PycharmProjects/chrysalis/data/xenium_human_breast_cancer/mefisto/" 9 | 10 | data_path = '/mnt/c/Users/demeter_turos/PycharmProjects/chrysalis/data/xenium_human_breast_cancer/' 11 | adata = sc.read_h5ad(data_path + 'visium_sample.h5ad') 12 | 13 | sc.pp.normalize_total(adata, inplace=True) 14 | sc.pp.log1p(adata) 15 | sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=2000) 16 | 17 | adata.obs = pd.concat([adata.obs, 18 | pd.DataFrame(adata.obsm["spatial"], columns=["imagerow", "imagecol"], index=adata.obs_names), 19 | ], axis=1) 20 | 21 | 22 | ent = entry_point() 23 | ent.set_data_options(use_float32=True) 24 | ent.set_data_from_anndata(adata, features_subset="highly_variable") 25 | 26 | ent.set_model_options(factors=8) 27 | ent.set_train_options() 28 | ent.set_train_options(seed=2021) 29 | 30 | # We use 1000 inducing points to learn spatial covariance patterns 31 | n_inducing = 1000 32 | 33 | ent.set_covariates([adata.obsm["spatial"]], covariates_names=["imagerow", "imagecol"]) 34 | ent.set_smooth_options(sparseGP=True, frac_inducing=n_inducing/adata.n_obs, 35 | start_opt=10, opt_freq=10) 36 | 37 | 38 | ent.build() 39 | ent.run() 40 | ent.save(datadir + "ST_model.hdf5") 41 | 42 | m = mofax.mofa_model(datadir + "ST_model.hdf5") 43 | factor_df = m.get_factors(df=True) 44 | factor_df.to_csv(datadir + 'factors.csv') -------------------------------------------------------------------------------- /article/A3_human_breast_cancer/benchmarking/nsf.py: -------------------------------------------------------------------------------- 1 | import time 2 | import pickle 3 | import random 4 | import numpy as np 5 | import scanpy as sc 6 | from os import path 7 | import matplotlib.pyplot as plt 8 | from tensorflow_probability import math as tm 9 | from nsf.models import cf, sf, sfh 10 | from nsf.utils import preprocess, training, misc, visualize, postprocess 11 | 12 | 13 | tfk = tm.psd_kernels 14 | 15 | data_path = '/mnt/c/Users/demeter_turos/PycharmProjects/chrysalis/data/xenium_human_breast_cancer/' 16 | 17 | ad = sc.read_h5ad(data_path + 'visium_sample.h5ad') 18 | ad.var_names_make_unique() 19 | 20 | sc.pp.calculate_qc_metrics(ad, inplace=True) 21 | sc.pp.filter_cells(ad, 
min_counts=1000) 22 | sc.pp.filter_genes(ad, min_cells=10) 23 | 24 | ad.layers = {"counts":ad.X.copy()} #store raw counts before normalization changes ad.X 25 | sc.pp.normalize_total(ad, inplace=True, layers=None, key_added="sizefactor") 26 | sc.pp.log1p(ad) 27 | 28 | # normalization, feature selection and train/test split 29 | ad.var['deviance_poisson'] = preprocess.deviancePoisson(ad.layers["counts"]) 30 | o = np.argsort(-ad.var['deviance_poisson']) 31 | idx = list(range(ad.shape[0])) 32 | random.shuffle(idx) 33 | ad = ad[idx,o] 34 | ad = ad[:,:2000] 35 | 36 | Dtr,Dval = preprocess.anndata_to_train_val(ad,layer="counts",sz="scanpy") 37 | Dtr_n,Dval_n = preprocess.anndata_to_train_val(ad) #normalized data 38 | fmeans,Dtr_c,Dval_c = preprocess.center_data(Dtr_n,Dval_n) #centered features 39 | Xtr = Dtr["X"] #note this should be identical to Dtr_n["X"] 40 | Ntr = Xtr.shape[0] 41 | Dtf = preprocess.prepare_datasets_tf(Dtr,Dval=Dval,shuffle=False) 42 | Dtf_n = preprocess.prepare_datasets_tf(Dtr_n,Dval=Dval_n,shuffle=False) 43 | Dtf_c = preprocess.prepare_datasets_tf(Dtr_c,Dval=Dval_c,shuffle=False) 44 | visualize.heatmap(Xtr,Dtr["Y"][:,0],marker="D",s=15) 45 | plt.show() 46 | 47 | # Visualize raw data 48 | plt.imshow(np.log1p(Dtr["Y"])[:50,:100],cmap="Blues") 49 | plt.show() 50 | 51 | # Visualize inducing points 52 | Z = misc.kmeans_inducing_pts(Xtr,500) 53 | fig,ax=plt.subplots(figsize=(12,10)) 54 | ax.scatter(Xtr[:,0],Xtr[:,1],marker="D",s=50,) 55 | ax.scatter(Z[:,0],Z[:,1],c="red",s=30) 56 | plt.show() 57 | 58 | # initialize inducing points and tuning parameters 59 | Z = misc.kmeans_inducing_pts(Xtr, 2363) 60 | M = Z.shape[0] 61 | ker = tfk.MaternThreeHalves 62 | S = 3 #samples for elbo approximation 63 | # NSF: Spatial only with non-negative factors 64 | L = 8 #number of latent factors, ideally divisible by 2 65 | J = 2000 66 | 67 | mpth = path.join("/mnt/c/Users/demeter_turos/PycharmProjects/deep_learning/nsf/models/V6") 68 | 69 | fit = sf.SpatialFactorization(J,L,Z,psd_kernel=ker,nonneg=True,lik="poi") 70 | fit.elbo_avg(Xtr,Dtr["Y"],sz=Dtr["sz"]) 71 | fit.init_loadings(Dtr["Y"],X=Xtr,sz=Dtr["sz"]) 72 | fit.elbo_avg(Xtr,Dtr["Y"],sz=Dtr["sz"]) 73 | pp = fit.generate_pickle_path("scanpy",base=mpth) 74 | tro = training.ModelTrainer(fit,pickle_path=pp) 75 | tro.train_model(*Dtf) 76 | ttl = "NSF: spatial, non-negative factors, Poisson likelihood" 77 | visualize.plot_loss(tro.loss,title=ttl)#,ss=range(2000,4000)) 78 | plt.show() 79 | 80 | # Postprocessing 81 | hmkw = {"figsize":(4,4), "s":0.3, "marker":"D", "subplot_space":0, 82 | "spinecolor":"white"} 83 | insf = postprocess.interpret_nsf(fit,Xtr,S=10,lda_mode=False) 84 | tgnames = [str(i) for i in range(1,L+1)] 85 | fig,axes=visualize.multiheatmap(Xtr, np.sqrt(insf["factors"]), (4,3), **hmkw) 86 | visualize.set_titles(fig, tgnames, x=0.05, y=.85, fontsize="medium", c="white", 87 | ha="left", va="top") 88 | plt.show() 89 | 90 | file = open(mpth + '/human_breast_cancer_nsf.pkl', 'wb') 91 | pickle.dump({'factors': insf, 'positions': Xtr}, file) 92 | file.close() 93 | 94 | file = open(mpth + '/human_breast_cancer_nsf.pkl', 'rb') 95 | data = pickle.load(file) 96 | file.close() 97 | -------------------------------------------------------------------------------- /article/A3_human_breast_cancer/benchmarking/spatialpca.R: -------------------------------------------------------------------------------- 1 | library(SpatialPCA) 2 | library(ggplot2) 3 | library(Matrix) 4 | library(Seurat) 5 | 6 | 7 | # HUMAN BREAST CANCER 8 | # Read sparse matrix from 
h5 file into Seurat object 9 | adata <- Load10X_Spatial("C:/Users/demeter_turos/PycharmProjects/chrysalis/data/xenium_human_breast_cancer/visium",) 10 | csv_data <- read.csv('C:/Users/demeter_turos/PycharmProjects/chrysalis/data/xenium_human_breast_cancer/spots_filtered.csv') 11 | first_column <- csv_data[, 1] 12 | adata <- adata[, first_column] 13 | 14 | adata <- PercentageFeatureSet(adata, "^mt-", col.name = "percent_mito") 15 | 16 | gene_counts <- rowSums(GetAssayData(adata, slot = "counts") > 0) 17 | keep_genes <- names(gene_counts[gene_counts >= 10]) 18 | adata <- subset(adata, features = keep_genes) 19 | 20 | xy_coords <- adata@images$slice1@coordinates 21 | xy_coords <- xy_coords[c('imagerow', 'imagecol')] 22 | colnames(xy_coords) <- c('x_coord', 'y_coord') 23 | 24 | count_sub <- adata@assays$Spatial@data 25 | print(dim(count_sub)) # The count matrix 26 | xy_coords <- as.matrix(xy_coords) 27 | rownames(xy_coords) <- colnames(count_sub) # the rownames of location should match with the colnames of count matrix 28 | LIBD <- CreateSpatialPCAObject(counts=count_sub, location=xy_coords, project="SpatialPCA", gene.type="spatial", 29 | sparkversion="spark", numCores_spark=5, gene.number=3000, customGenelist=NULL, 30 | min.loctions=20, min.features=20) 31 | 32 | LIBD <- SpatialPCA_buildKernel(LIBD, kerneltype="gaussian", bandwidthtype="SJ", bandwidth.set.by.user=NULL) 33 | LIBD <- SpatialPCA_EstimateLoading(LIBD,fast=FALSE,SpatialPCnum=20) 34 | LIBD <- SpatialPCA_SpatialPCs(LIBD, fast=FALSE) 35 | 36 | saveRDS(LIBD, file = "C:/Users/demeter_turos/PycharmProjects/chrysalis/data/xenium_human_breast_cancer/spatialpca/libd.rds") 37 | 38 | LIBD <- readRDS(file = "C:/Users/demeter_turos/PycharmProjects/chrysalis/data/xenium_human_breast_cancer/spatialpca/libd.rds") 39 | 40 | write.csv(LIBD@SpatialPCs, "C:/Users/demeter_turos/PycharmProjects/chrysalis/data/xenium_human_breast_cancer/spatialpca/spatial_pcs.csv") 41 | 42 | clusterlabel <- walktrap_clustering(clusternum=8, latent_dat=LIBD@SpatialPCs, knearest=70) 43 | # here for all 12 samples in LIBD, we set the same k nearest number in walktrap_clustering to be 70. 44 | # for other Visium or ST data, the user can also set k nearest number as round(sqrt(dim(SpatialPCAobject@SpatialPCs)[2])) by default. 
45 | clusterlabel_refine <- refine_cluster_10x(clusterlabels=clusterlabel, location=LIBD@location, shape="hexagon") 46 | 47 | cbp<-c('#db5f57', '#dbc257', '#91db57', '#57db80', '#57d3db', '#5770db', '#a157db', '#db57b2') 48 | plot_cluster(location=xy_coords,clusterlabel=clusterlabel_refine, pointsize=1.5, 49 | title_in=paste0("SpatialPCA"), color_in=cbp) -------------------------------------------------------------------------------- /article/A3_human_breast_cancer/benchmarking/stagate.py: -------------------------------------------------------------------------------- 1 | import STAGATE 2 | import pandas as pd 3 | import scanpy as sc 4 | import matplotlib as mpl 5 | import sklearn.neighbors 6 | 7 | 8 | data_path = '/mnt/c/Users/demeter_turos/PycharmProjects/chrysalis/data/xenium_human_breast_cancer/' 9 | 10 | adata = sc.read_h5ad(data_path + 'visium_sample.h5ad') 11 | 12 | sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000) 13 | sc.pp.normalize_total(adata, target_sum=1e4) 14 | sc.pp.log1p(adata) 15 | 16 | STAGATE.Cal_Spatial_Net(adata, rad_cutoff=300) 17 | STAGATE.Stats_Spatial_Net(adata) 18 | 19 | coor = pd.DataFrame(adata.obsm['spatial']) 20 | coor.index = adata.obs.index 21 | coor.columns = ['imagerow', 'imagecol'] 22 | 23 | 24 | nbrs = sklearn.neighbors.NearestNeighbors(radius=1000).fit(coor) 25 | distances, indices = nbrs.radius_neighbors(coor, return_distance=True) 26 | KNN_list = [] 27 | for it in range(indices.shape[0]): 28 | KNN_list.append(pd.DataFrame(zip([it]*indices[it].shape[0], indices[it], distances[it]))) 29 | 30 | KNN_df = pd.concat(KNN_list) 31 | KNN_df.columns = ['Cell1', 'Cell2', 'Distance'] 32 | 33 | sc.pl.spatial(adata, color='CD74') 34 | 35 | adata = STAGATE.train_STAGATE(adata, alpha=0) 36 | 37 | for i in range(30): 38 | adata.obs[f'S{i}'] = adata.obsm['STAGATE'][:, i] 39 | 40 | 41 | with mpl.rc_context({'figure.figsize': [4.5, 5]}): 42 | sc.pl.spatial(adata, color=[f'S{i}' for i in range(30)], size=2, ncols=4) 43 | 44 | stagate_df = adata.obs[[f'S{i}' for i in range(30)]] 45 | stagate_df.to_csv('stagate_breast_cancer.csv') 46 | -------------------------------------------------------------------------------- /article/A3_human_breast_cancer/morphology_integration/2_autoencoder_training.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | import torch.utils.data as data 7 | import pytorch_lightning as L 8 | import matplotlib.pyplot as plt 9 | from itertools import islice 10 | from torchinfo import summary 11 | from torchvision import transforms 12 | from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint 13 | 14 | 15 | # setting the seed 16 | L.seed_everything(42) 17 | 18 | # ensure that all operations are deterministic on GPU for reproducibility 19 | torch.backends.cudnn.deterministic = True 20 | torch.backends.cudnn.benchmark = False 21 | 22 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 23 | print("Device:", device) 24 | 25 | data_path = "/mnt/c/Users/demeter_turos/PycharmProjects/chrysalis/data/xenium_human_breast_cancer/tiles_299.npy" 26 | 27 | class NpDataset(data.Dataset): 28 | def __init__(self, data_file=None, transform=None): 29 | self.data = np.load(data_file) 30 | if self.data.shape[-1] == 3: 31 | self.data = self.data.transpose(0, 3, 1, 2) 32 | self.data = self.data / 255.0 33 | self.transform 
= transform 34 | 35 | def __len__(self): 36 | return len(self.data) 37 | 38 | def __getitem__(self, idx): 39 | x = torch.from_numpy(self.data[idx]).float() 40 | if self.transform: 41 | x = self.transform(x) 42 | return x, x 43 | 44 | transform = transforms.Compose([transforms.ToPILImage(), 45 | transforms.Resize((256, 256)), 46 | transforms.ToTensor(), 47 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 48 | ]) 49 | ds = NpDataset(data_file=data_path, transform=transform) 50 | 51 | train_set, val_set = torch.utils.data.random_split(ds, [3510, 396]) 52 | 53 | train_loader = data.DataLoader(ds, batch_size=1, shuffle=True, drop_last=False, pin_memory=True, num_workers=1) 54 | 55 | 56 | def show_image(data_loader): 57 | dataiter = iter(data_loader) 58 | images = dataiter.next() 59 | img = images[0][0].clone() # clone to avoid changing original tensor 60 | img = img * 0.5 + 0.5 61 | img_arr = img.numpy().transpose(1, 2, 0) 62 | plt.imshow(img_arr) 63 | plt.show() 64 | 65 | 66 | show_image(train_loader) 67 | 68 | class Encoder(nn.Module): 69 | def __init__(self, num_input_channels: int, base_channel_size: int, latent_dim: int, act_fn: object = nn.GELU): 70 | super().__init__() 71 | c_hid = base_channel_size 72 | self.net = nn.Sequential( 73 | # 256x256 => 128x128 74 | nn.Conv2d(num_input_channels, c_hid, kernel_size=3, padding=1, stride=2), 75 | act_fn(), 76 | # 128x128 => 64x64 77 | nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1, stride=2), 78 | act_fn(), 79 | # 64x64 => 32x32 80 | nn.Conv2d(c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2), 81 | act_fn(), 82 | nn.Flatten(), # image grid to single feature vector 83 | nn.Linear(32 * 32 * 2 * c_hid, latent_dim), 84 | ) 85 | 86 | def forward(self, x): 87 | return self.net(x) 88 | 89 | 90 | class Decoder(nn.Module): 91 | def __init__(self, num_input_channels: int, base_channel_size: int, latent_dim: int, act_fn: object = nn.GELU): 92 | super().__init__() 93 | c_hid = base_channel_size 94 | self.linear = nn.Sequential(nn.Linear(latent_dim, 2 * c_hid * 32 * 32), act_fn()) 95 | self.net = nn.Sequential( 96 | # 32x32=>64x64 97 | nn.ConvTranspose2d(2 * c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2), 98 | act_fn(), 99 | # 64x64 => 128x128 100 | nn.ConvTranspose2d(c_hid, num_input_channels, kernel_size=3, output_padding=1, padding=1, stride=2), 101 | act_fn(), 102 | # 128x128 => 256x256 103 | nn.ConvTranspose2d(num_input_channels, num_input_channels, kernel_size=3, output_padding=1, padding=1, 104 | stride=2), 105 | nn.Tanh(), 106 | ) 107 | 108 | def forward(self, x): 109 | x = self.linear(x) 110 | x = x.reshape(x.shape[0], 2 * 32, 32, 32) 111 | x = self.net(x) 112 | return x 113 | 114 | 115 | class Autoencoder(L.LightningModule): 116 | def __init__( 117 | self, 118 | base_channel_size: int, 119 | latent_dim: int, 120 | encoder_class: object = Encoder, 121 | decoder_class: object = Decoder, 122 | num_input_channels: int = 3, 123 | width: int = 32, 124 | height: int = 32, 125 | ): 126 | super().__init__() 127 | # saving hyperparameters of autoencoder 128 | self.save_hyperparameters() 129 | # creating encoder and decoder 130 | self.encoder = encoder_class(num_input_channels, base_channel_size, latent_dim) 131 | self.decoder = decoder_class(num_input_channels, base_channel_size, latent_dim) 132 | # example input array needed for visualizing the graph of the network 133 | self.example_input_array = torch.zeros(2, num_input_channels, width, height) 134 | 135 | def forward(self, x): 136 | """the forward function takes in an 
        """The forward function takes in an image and returns the reconstructed image."""
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

    def _get_reconstruction_loss(self, batch):
        """Given a batch of images, this function returns the reconstruction loss (MSE in our case)."""
        x, _ = batch  # we do not need the labels
        x_hat = self.forward(x)
        loss = F.mse_loss(x, x_hat, reduction="none")
        loss = loss.sum(dim=[1, 2, 3]).mean(dim=[0])
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, patience=20, min_lr=5e-5)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "train_loss"}

    def training_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        self.log("val_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        self.log("test_loss", loss)


model = Autoencoder(base_channel_size=32, latent_dim=512, width=256, height=256)

summary(model, (1, 3, 256, 256))

latent_dim = 512

trainer = L.Trainer(default_root_dir="/mnt/c/Users/demeter_turos/PycharmProjects/deep_learning/"
                                     "autoencoder_he_image/models/",
                    accelerator="auto",
                    devices=1,
                    max_epochs=100,
                    callbacks=[ModelCheckpoint(save_weights_only=True),
                               # GenerateCallback(get_train_images(8), every_n_epochs=10),
                               LearningRateMonitor("epoch")],
                    limit_val_batches=0,
                    num_sanity_val_steps=0)

trainer.fit(model, train_loader)

model = Autoencoder.load_from_checkpoint("/mnt/c/Users/demeter_turos/PycharmProjects/deep_learning/"
                                         "autoencoder_he_image/models/lightning_logs/version_8/checkpoints/"
                                         "epoch=99-step=390600.ckpt")

data_iter = iter(train_loader)  # create an iterator for the dataloader


def show_images(model, data_iter):
    model.to(device)
    model.eval()
    num_elements = 9
    elements = [x[0] for x in islice(data_iter, num_elements)]  # select only the first tensor of each pair

    # now stack the list of tensors to a single tensor
    elements = torch.cat(elements)
    elements = elements.to(device)
    outputs = model(elements)

    outputs = outputs.cpu().detach().numpy()

    # plot originals in the top half of each column pair, reconstructions in the bottom half
    fig, axs = plt.subplots(3, 6, figsize=(12, 6))
    axs = axs.flatten()
    for a in axs:
        a.axis('off')

    for idx, i in enumerate([3, 4, 5, 9, 10, 11, 15, 16, 17]):
        out_img = outputs[idx]
        out_img = out_img * 0.5 + 0.5
        out_img_arr = out_img.transpose(1, 2, 0)
        axs[i].imshow(out_img_arr)

    for idx, i in enumerate([0, 1, 2, 6, 7, 8, 12, 13, 14]):
        img = elements[idx].clone()
        img = img * 0.5 + 0.5
        img_arr = img.cpu().numpy().transpose(1, 2, 0)
        axs[i].imshow(img_arr)


show_images(model, data_iter)
plt.show()

model.eval()

all_encoded_features = []

eval_loader = data.DataLoader(ds, batch_size=1, shuffle=False, drop_last=False, pin_memory=True, num_workers=1)

with torch.no_grad():
    for batch in eval_loader:
        x, _ = batch
        x = x.to(device)
        encoded_features = model.encoder(x)
        encoded_features = encoded_features.cpu().detach().numpy()
        all_encoded_features.append(encoded_features)

all_encoded_features = np.concatenate(all_encoded_features)

np.save("/mnt/c/Users/demeter_turos/PycharmProjects/chrysalis/data/xenium_human_breast_cancer/aenc_features_512_v2.npy",
        all_encoded_features)
-------------------------------------------------------------------------------- /article/A4_mouse_brain/ffpe/benchmark/graphst.py: --------------------------------------------------------------------------------
import torch
import pandas as pd
import scanpy as sc
from tqdm import tqdm
from glob import glob
from GraphST import GraphST
import seaborn as sns
import matplotlib.pyplot as plt
from benchmarks.bm_functions import get_correlation_df, collect_correlation_results, collect_metadata
from sklearn.decomposition import PCA

# the original script referenced `adp` without defining it; the sample glob and loop
# below are reconstructed from the sibling benchmark scripts in this folder (mefisto.py, nsf.py)
filepath = '/storage/homefs/pt22a065/chr_data/mouse_brain_anterior'
adatas = glob(filepath + '/*/*.h5ad')

for idx, adp in tqdm(enumerate(adatas), total=len(adatas)):
    print(adp)
    sample_folder = '/'.join(adp.split('/')[:-1]) + '/'
    device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
    adata = sc.read_h5ad(adp)
    adata.var_names_make_unique()

    # define model
    model = GraphST.GraphST(adata, device=torch.device('cpu'))

    # train model
    adata = model.train()

    graphst_df = pd.DataFrame(adata.obsm['emb'])
    graphst_df.index = adata.obs.index

    pca = PCA(n_components=20, svd_solver='arpack', random_state=42)
    graphst_pcs = pca.fit_transform(graphst_df)
    graphst_pcs_df = pd.DataFrame(data=graphst_pcs, index=graphst_df.index)

    graphst_pcs_df.to_csv(sample_folder + 'graphst_comps.csv')

    tissue_zone_df = adata.obsm['tissue_zones']
    # tissue_zone_df = tissue_zone_df[[c for c in tissue_zone_df.columns if 'uniform' not in c]]

    corr_df = get_correlation_df(tissue_zone_df, graphst_pcs_df)
    corr_df.to_csv(sample_folder + 'graphst_pearson.csv')

    fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8))
    sns.heatmap(corr_df, square=True, center=0, ax=ax)
    plt.tight_layout()
    plt.savefig(sample_folder + 'graphst_corr_heatmap.png')
    plt.close()
-------------------------------------------------------------------------------- /article/A4_mouse_brain/ffpe/benchmark/mefisto.py: --------------------------------------------------------------------------------
import mofax
import pandas as pd
import scanpy as sc
from mofapy2.run.entry_point import entry_point
from tqdm import tqdm
from glob import glob


filepath = '/storage/homefs/pt22a065/chr_data/mouse_brain_anterior'
adatas = glob(filepath + '/*/*.h5ad')

results_df = pd.DataFrame()

for idx, adp in tqdm(enumerate(adatas), total=len(adatas)):
    print(adp)
    sample_folder = '/'.join(adp.split('/')[:-1]) + '/'
    adata = sc.read_h5ad(adp)

    sc.pp.filter_genes(adata, min_cells=len(adata) * 0.05)

    sc.pp.normalize_total(adata, inplace=True)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=2000)

    adata.obs = pd.concat([adata.obs,
                           pd.DataFrame(adata.obsm["spatial"], columns=["imagerow", "imagecol"],
                                        index=adata.obs_names),
                           ], axis=1)

    ent = entry_point()
    ent.set_data_options(use_float32=True)
    ent.set_data_from_anndata(adata, features_subset="highly_variable")

    ent.set_model_options(factors=28)
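    # factors=28 mirrors the latent dimensionality used by the other scripts in this benchmark (e.g. L = 28 in nsf.py)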
ent.set_train_options(save_interrupted=True) 36 | ent.set_train_options(seed=2021) 37 | 38 | # We use 1000 inducing points to learn spatial covariance patterns 39 | n_inducing = 1000 # 500 for size tests 40 | 41 | ent.set_covariates([adata.obsm["spatial"]], covariates_names=["imagerow", "imagecol"]) 42 | ent.set_smooth_options(sparseGP=True, frac_inducing=n_inducing / adata.n_obs, 43 | start_opt=10, opt_freq=10) 44 | 45 | ent.build() 46 | ent.run() 47 | ent.save(sample_folder + "mefisto_temp.hdf5") 48 | m = mofax.mofa_model(sample_folder + "mefisto_temp.hdf5") 49 | factor_df = m.get_factors(df=True) 50 | 51 | # factor_df = ent.model.getFactors(df=True) 52 | 53 | factor_df.to_csv(sample_folder + 'mefisto_comps.csv') 54 | -------------------------------------------------------------------------------- /article/A4_mouse_brain/ffpe/benchmark/nsf.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import pandas as pd 4 | import scanpy as sc 5 | from os import path 6 | import os 7 | from tqdm import tqdm 8 | from glob import glob 9 | import seaborn as sns 10 | import matplotlib.pyplot as plt 11 | from tensorflow_probability import math as tm 12 | from nsf.models import sf 13 | from nsf.utils import preprocess, training, visualize, postprocess, misc 14 | import pickle 15 | import math 16 | 17 | from bm_functions import get_correlation_df, collect_correlation_results, collect_metadata 18 | 19 | filepath = '/storage/homefs/pt22a065/chr_data/mouse_brain_anterior' 20 | adatas = glob(filepath + '/*/*.h5ad') 21 | tfk = tm.psd_kernels 22 | 23 | results_df = pd.DataFrame() 24 | 25 | for idx, adp in tqdm(enumerate(adatas), total=len(adatas)): 26 | print(adp) 27 | sample_folder = '/'.join(adp.split('/')[:-1]) + '/' 28 | 29 | # Check if all necessary output files already exist in the sample_folder 30 | if (os.path.exists(sample_folder + 'nsf_comps.csv') and 31 | os.path.exists(sample_folder + 'nsf_pearson.csv') and 32 | os.path.exists(sample_folder + 'nsf_corr_heatmap.png')): 33 | print(f"Skipping {sample_folder} as output files already exist.") 34 | continue 35 | 36 | adata = sc.read_h5ad(adp) 37 | 38 | sc.pp.calculate_qc_metrics(adata, inplace=True) 39 | sc.pp.filter_genes(adata, min_cells=len(adata) * 0.05) 40 | 41 | sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000) 42 | adata.layers = {"counts": adata.X.copy()} # store raw counts before normalization changes ad.X 43 | sc.pp.normalize_total(adata, inplace=True, layers=None, key_added="sizefactor") 44 | sc.pp.log1p(adata) 45 | 46 | adata.var['deviance_poisson'] = preprocess.deviancePoisson(adata.layers["counts"]) 47 | o = np.argsort(-adata.var['deviance_poisson']) 48 | idx = list(range(adata.shape[0])) 49 | random.shuffle(idx) 50 | adata = adata[idx, o] 51 | adata = adata[:, :2000] 52 | 53 | Dtr, Dval = preprocess.anndata_to_train_val(adata, layer="counts", sz="scanpy") 54 | Dtr_n, Dval_n = preprocess.anndata_to_train_val(adata) # normalized data 55 | fmeans, Dtr_c, Dval_c = preprocess.center_data(Dtr_n, Dval_n) # centered features 56 | Xtr = Dtr["X"] # note this should be identical to Dtr_n["X"] 57 | Ntr = Xtr.shape[0] 58 | Dtf = preprocess.prepare_datasets_tf(Dtr, Dval=Dval, shuffle=False) 59 | Dtf_n = preprocess.prepare_datasets_tf(Dtr_n, Dval=Dval_n, shuffle=False) 60 | Dtf_c = preprocess.prepare_datasets_tf(Dtr_c, Dval=Dval_c, shuffle=False) 61 | visualize.heatmap(Xtr, Dtr["Y"][:, 0], marker="D", s=15) 62 | plt.close() 63 | 64 | # Visualize raw data 
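    # (log1p counts for the first 50 spots x 100 genes; the figure is closed right away, so this is only useful interactively)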
65 | plt.imshow(np.log1p(Dtr["Y"])[:50, :100], cmap="Blues") 66 | plt.close() 67 | 68 | # Visualize inducing points 69 | Z = misc.kmeans_inducing_pts(Xtr, 500) 70 | fig, ax = plt.subplots(figsize=(12, 10)) 71 | ax.scatter(Xtr[:, 0], Xtr[:, 1], marker="D", s=50, ) 72 | ax.scatter(Z[:, 0], Z[:, 1], c="red", s=30) 73 | plt.close() 74 | 75 | # initialize inducing points and tuning parameters 76 | Z = misc.kmeans_inducing_pts(Xtr, 2363) # 2363 77 | M = Z.shape[0] 78 | ker = tfk.MaternThreeHalves 79 | S = 3 # samples for elbo approximation 80 | # NSF: Spatial only with non-negative factors 81 | L = 28 # number of latent factors, ideally divisible by 2 82 | J = 2000 83 | 84 | mpth = path.join("/storage/homefs/pt22a065/chr_benchmarks/nsf/models/V5/") 85 | 86 | fit = sf.SpatialFactorization(J, L, Z, psd_kernel=ker, nonneg=True, lik="poi") 87 | fit.elbo_avg(Xtr, Dtr["Y"], sz=Dtr["sz"]) 88 | fit.init_loadings(Dtr["Y"], X=Xtr, sz=Dtr["sz"]) 89 | fit.elbo_avg(Xtr, Dtr["Y"], sz=Dtr["sz"]) 90 | pp = fit.generate_pickle_path("scanpy", base=mpth) 91 | tro = training.ModelTrainer(fit, pickle_path=pp) 92 | tro.train_model(*Dtf, ckpt_freq=10000) 93 | 94 | ttl = "NSF: spatial, non-negative factors, Poisson likelihood" 95 | visualize.plot_loss(tro.loss, title=ttl) # ,ss=range(2000,4000)) 96 | plt.savefig(sample_folder + 'nsf_loss.png') 97 | plt.close() 98 | 99 | hmkw = {"figsize": (4, 4), "s": 0.3, "marker": "D", "subplot_space": 0, 100 | "spinecolor": "white"} 101 | insf = postprocess.interpret_nsf(fit, Xtr, S=10, lda_mode=False) 102 | tgnames = [str(i) for i in range(1, L + 1)] 103 | 104 | # fig, axes = visualize.multiheatmap(Xtr, np.sqrt(insf["factors"]), (4, 3), **hmkw) 105 | # visualize.set_titles(fig, tgnames, x=0.05, y=.85, fontsize="medium", c="white", 106 | # ha="left", va="top") 107 | # plt.savefig(sample_folder + 'nsf_comps.png') 108 | # plt.close() 109 | 110 | data = {'factors': insf, 'positions': Xtr} 111 | 112 | 113 | def transform_coords(X): 114 | # code from nsf github 115 | X[:, 1] = -X[:, 1] 116 | xmin = X.min(axis=0) 117 | X -= xmin 118 | x_gmean = np.exp(np.mean(np.log(X.max(axis=0)))) 119 | X *= 4 / x_gmean 120 | return X - X.mean(axis=0) 121 | 122 | 123 | X = adata.obsm["spatial"].copy().astype('float32') 124 | tcoords = transform_coords(X) 125 | 126 | pair_idx = [] 127 | for xy in data['positions']: 128 | distances = [math.dist([xy[0], xy[1]], [idx[0], idx[1]]) for idx in tcoords] 129 | pair_idx.append(np.argmin(distances)) 130 | 131 | nsf_df = pd.DataFrame(data=np.zeros([len(adata), data['factors']['factors'].shape[1]])) 132 | for idx, i in enumerate(pair_idx): 133 | nsf_df.iloc[i, :] = data['factors']['factors'][idx, :] 134 | nsf_df.index = adata.obs.index 135 | 136 | nsf_df.to_csv(sample_folder + 'nsf_comps.csv') -------------------------------------------------------------------------------- /article/A4_mouse_brain/ffpe/benchmark/stagate.py: -------------------------------------------------------------------------------- 1 | import STAGATE 2 | import pandas as pd 3 | import scanpy as sc 4 | from tqdm import tqdm 5 | from glob import glob 6 | import seaborn as sns 7 | import matplotlib as mpl 8 | import sklearn.neighbors 9 | import matplotlib.pyplot as plt 10 | from benchmarks.bm_functions import get_correlation_df, collect_correlation_results, collect_metadata 11 | import numpy as np 12 | 13 | 14 | filepath = 'data/mouse_brain_anterior/ffpe_cranial/' 15 | adata = sc.read_h5ad(filepath + 'chr.h5ad') 16 | 17 | sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000) 18 | 
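# seurat_v3 highly variable gene selection expects raw counts, hence it is run before normalisation and log-transformation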
sc.pp.normalize_total(adata, target_sum=1e4) 19 | sc.pp.log1p(adata) 20 | 21 | # include 8 neighbours with the cutoff similarly to 6 for visium 22 | STAGATE.Cal_Spatial_Net(adata, rad_cutoff=150) 23 | STAGATE.Stats_Spatial_Net(adata) 24 | 25 | coor = pd.DataFrame(adata.obsm['spatial']) 26 | coor.index = adata.obs.index 27 | coor.columns = ['imagerow', 'imagecol'] 28 | 29 | nbrs = sklearn.neighbors.NearestNeighbors(radius=150).fit(coor) 30 | distances, indices = nbrs.radius_neighbors(coor, return_distance=True) 31 | KNN_list = [] 32 | for it in range(indices.shape[0]): 33 | KNN_list.append(pd.DataFrame(zip([it] * indices[it].shape[0], indices[it], distances[it]))) 34 | KNN_df = pd.concat(KNN_list) 35 | KNN_df.columns = ['Cell1', 'Cell2', 'Distance'] 36 | 37 | adata = STAGATE.train_STAGATE(adata, alpha=0) 38 | 39 | for i in range(30): 40 | adata.obs[f'stagate_{i}'] = adata.obsm['STAGATE'][:, i] 41 | 42 | with mpl.rc_context({'figure.figsize': [4.5, 5]}): 43 | sc.pl.spatial(adata, color=[f'stagate_{i}' for i in range(30)], size=2, ncols=4, show=False) 44 | plt.savefig(filepath + 'stagate_comps.png') 45 | plt.close() 46 | 47 | # correlation with tissue zones 48 | stagate_df = adata.obs[[f'stagate_{i}' for i in range(30)]] 49 | stagate_df.to_csv(filepath + 'stagate_comps.csv') 50 | -------------------------------------------------------------------------------- /article/A4_mouse_brain/ffpe/map_annotations.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import pandas as pd 3 | import scanpy as sc 4 | from tqdm import tqdm 5 | import matplotlib.pyplot as plt 6 | from shapely.geometry import Point 7 | from shapely.ops import unary_union 8 | from paquo.projects import QuPathProject 9 | from shapely.errors import ShapelyDeprecationWarning 10 | 11 | 12 | def get_annotation_polygons(qupath_project, show=True): 13 | # Filter out ShapelyDeprecationWarning 14 | warnings.filterwarnings("ignore", category=ShapelyDeprecationWarning) 15 | # load the qupath project 16 | slides = QuPathProject(qupath_project, mode='r+') 17 | img_dict = {} 18 | for img in slides.images: 19 | 20 | # get annotations for slide image 21 | annotations = img.hierarchy.annotations 22 | 23 | # collect polys 24 | polys = {} 25 | error_count = 0 26 | erroneous_annot = {} 27 | for annotation in annotations: 28 | try: 29 | id = annotation.path_class.id 30 | if id in polys.keys(): 31 | if annotation.roi.type != 'LineString': 32 | polys[id].append(annotation.roi) 33 | else: 34 | if annotation.roi.type != 'LineString': 35 | polys[id] = [annotation.roi] 36 | except Exception: 37 | erroneous_annot[error_count] = annotation 38 | error_count += 1 39 | print(f"Reading slide {img.image_name}") 40 | print(f"Erroneous poly found {error_count} times from {len(annotations)} polygons.") 41 | 42 | if show: 43 | # merge polys with the same annotation 44 | polym = {} 45 | for key in polys.keys(): 46 | polym[key] = unary_union(polys[key]) 47 | # look at them 48 | for key in polym.keys(): 49 | if polym[key].type != 'Polygon': 50 | for geom in polym[key].geoms: 51 | plt.plot(*geom.exterior.xy) 52 | else: 53 | plt.plot(*polym[key].exterior.xy) 54 | plt.show() 55 | 56 | img_dict[img.image_name] = polys 57 | return img_dict 58 | 59 | 60 | def map_annotations(adata, polygon_dict, default_annot='Tumor'): 61 | df_dict = {i: default_annot for i in list(adata.obs.index)} 62 | tissue_type = pd.DataFrame(df_dict.values(), index=df_dict.keys()) 63 | spot_df = pd.DataFrame(adata.obsm['spatial']) 64 | 65 | 
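    # point-in-polygon mapping: spots whose centroids fall inside any polygon of an
    # annotation class are relabelled with that class; all other spots keep the default label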
    spot_annots = {}
    for key in tqdm(polygon_dict.keys(), desc='Mapping annotations...'):
        x, y = spot_df.iloc[:, 0], spot_df.iloc[:, 1]
        points = [Point(x, y) for x, y in zip(x, y)]

        contains_sum = [False for x in range(len(spot_df))]
        for iv in polygon_dict[key]:
            contains = [iv.contains(p) for p in points]
            contains_sum = [cs or c for cs, c in zip(contains_sum, contains)]
            # plt.scatter(x, y, s=1)
            # plt.plot(*iv.exterior.xy)
            # plt.show()

        spot_annots[key] = contains_sum
        replace = adata.obs.index[contains_sum]

        # use .loc instead of chained indexing, which is deprecated in recent pandas
        tissue_type.loc[replace, 0] = key
    return tissue_type


data_path = 'data/Visium_FFPE_Mouse_Brain/'
adata = sc.read_h5ad(data_path + 'chr_28.h5ad')

qupath_project = 'data/Visium_FFPE_Mouse_Brain/smooth_brain/'

polygon_dict = get_annotation_polygons(qupath_project, show=True)

img_key = list(polygon_dict.keys())[0]
polygon_dict = polygon_dict[img_key]
annots = map_annotations(adata, polygon_dict, default_annot='Rest')
annots[0] = ['Corpus callosum' if x == 'Corpus_callosum/capsula_externa' else x for x in annots[0]]
adata.obs['annotation'] = annots

sc.pl.spatial(adata, color='annotation')
plt.show()

adata.write_h5ad(data_path + 'chr_28_annotated.h5ad')
-------------------------------------------------------------------------------- /article/readme.md: --------------------------------------------------------------------------------
# Chrysalis Article Readme


This folder contains the notebooks and scripts used for our research article, showcasing the data analysis we conducted.
To recreate the analysis, you'll need to take some extra steps, such as obtaining the raw data, adjusting directory paths,
and downloading supplementary files from Zenodo (https://doi.org/10.5281/zenodo.8247780).
```
.
9 | ├── A1_synthetic_data 10 | │   ├── array_size_benchmark.ipynb 11 | │   ├── bm_functions.py 12 | │   ├── chrysalis_example.ipynb 13 | │   ├── contamination_benchmark.ipynb 14 | │   ├── data_generator 15 | │   │   ├── functions.py 16 | │   │   ├── generate_synthetic_datasets.py 17 | │   │   ├── generate_truncated_samples.py 18 | │   │   ├── tissue_generator.py 19 | │   │   └── tools.py 20 | │   ├── main_synthetic_benchmark.ipynb 21 | │   └── method_scripts 22 | │   ├── array_size_benchmark 23 | │   │   ├── chrysalis.py 24 | │   │   ├── graphst.py 25 | │   │   ├── mefisto.py 26 | │   │   ├── nsf.py 27 | │   │   └── stagate.py 28 | │   ├── contamination_benchmark 29 | │   │   ├── chrysalis.py 30 | │   │   ├── graphst.py 31 | │   │   ├── mefisto.py 32 | │   │   ├── nsf.py 33 | │   │   └── stagate.py 34 | │   └── main_synthetic_benchmark 35 | │   ├── chrysalis.py 36 | │   ├── graphst.py 37 | │   ├── mefisto.py 38 | │   ├── nsf.py 39 | │   └── stagate.py 40 | ├── A2_human_lymph_node 41 | │   ├── SVG_detection_methods 42 | │   │   ├── 1_bsp_spatialde_sepal.ipynb 43 | │   │   ├── 2_spark.R 44 | │   │   └── 3_method_comparison.ipynb 45 | │   ├── benchmarking 46 | │   │   ├── graphst.py 47 | │   │   ├── mefisto.py 48 | │   │   ├── nsf.py 49 | │   │   ├── spatialpca.R 50 | │   │   └── stagate.py 51 | │   ├── chrysalis_analysis_and_validation.ipynb 52 | │   └── morans_i.ipynb 53 | ├── A3_human_breast_cancer 54 | │   ├── benchmarking 55 | │   │   ├── graphst.py 56 | │   │   ├── mefisto.py 57 | │   │   ├── nsf.py 58 | │   │   ├── spatialpca.R 59 | │   │   └── stagate.py 60 | │   ├── benchmarking.ipynb 61 | │   ├── chrysalis_analysis_and_validation.ipynb 62 | │   └── morphology_integration 63 | │   ├── 1_extract_image_tiles.ipynb 64 | │   ├── 2_autoencoder_training.py 65 | │   └── 3_integrate_morphology.ipynb 66 | ├── A4_mouse_brain 67 | │   ├── ff 68 | │   │   ├── mouse_brain_ff.ipynb 69 | │   │   └── mouse_brain_integration.ipynb 70 | │   └── ffpe 71 | │   ├── benchmark 72 | │   │   ├── graphst.py 73 | │   │   ├── mefisto.py 74 | │   │   ├── nsf.py 75 | │   │   └── stagate.py 76 | │   ├── map_annotations.py 77 | │   ├── mouse_brain_ffpe.ipynb 78 | │   └── mouse_brain_ffpe_benchmark.ipynb 79 | ├── A5_visium_hd 80 | │   └── visium_hd_analysis.ipynb 81 | ├── A6_slide_seqv2 82 | │   └── slide_seqv2_analysis.ipynb 83 | ├── A7_stereo_seq 84 | │   └── stereo_seq_analysis.ipynb 85 | └── readme.md 86 | ``` 87 | 88 | ## Chrysalis: decoding tissue compartments in spatial transcriptomics with archetypal analysis 89 | 90 | **Authors**: Demeter Túrós, Jelica Vasiljevic, Kerstin Hahn, Sven Rottenberg, and Alberto Valdeolivas 91 | 92 | **Abstract**: Dissecting tissue compartments in spatial transcriptomics (ST) remains challenging due to 93 | limited spatial resolution and dependence on single-cell reference data. We present Chrysalis, a novel 94 | computational method that rapidly uncovers tissue compartments through spatially variable gene (SVG) 95 | detection and archetypal analysis without requiring external reference data. Additionally, it offers a 96 | unique visualisation approach for swift tissue characterization and provides access to gene expression 97 | signatures, enabling the identification of spatially and functionally distinct cellular niches. 
Chrysalis 98 | was evaluated through various benchmarks and validated against deconvolution, independently obtained cell 99 | type abundance data, and histopathological annotations, demonstrating superior performance compared to 100 | other algorithms on both in silico and real-world test examples. Furthermore, we underscored its versatility 101 | across different technologies, such as Visium, Visium HD, Slide-seq, and Stereo-seq. 102 | -------------------------------------------------------------------------------- /chrysalis/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from numba.core.errors import NumbaDeprecationWarning, NumbaPendingDeprecationWarning 3 | 4 | # Filter out specific warning categories 5 | warnings.filterwarnings("ignore", category=NumbaDeprecationWarning) 6 | warnings.filterwarnings("ignore", category=UserWarning) 7 | warnings.filterwarnings("ignore", category=FutureWarning) 8 | 9 | # core functions 10 | from .core import detect_svgs 11 | from .core import pca 12 | from .core import aa 13 | 14 | # plotting functions 15 | from .plots import plot 16 | from .plots import plot_compartment 17 | from .plots import plot_compartments 18 | from .plots import plot_explained_variance 19 | from .plots import plot_svgs 20 | from .plots import plot_rss 21 | from .plots import plot_heatmap 22 | from .plots import plot_weights 23 | from .plots import plot_svg_matrix 24 | from .plots import plot_samples 25 | 26 | # utility functions 27 | from .utils import get_compartment_df 28 | from .utils import integrate_adatas 29 | from .utils import harmony_integration 30 | -------------------------------------------------------------------------------- /chrysalis/core.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import scanpy as sc 4 | from tqdm import tqdm 5 | import archetypes as arch 6 | from anndata import AnnData 7 | from pysal.lib import weights 8 | from pysal.explore import esda 9 | from sklearn.decomposition import PCA 10 | 11 | 12 | def detect_svgs(adata: AnnData, min_spots: float=0.05, top_svg: int=1000, min_morans: float=0.20, neighbors: int=6, 13 | geary: bool=False): 14 | """ 15 | Calculate spatial autocorrelation (Moran's I) to define spatially variable genes. 16 | 17 | By default we only calculate autocorrelation for genes expressed in at least 5% of capture spots 18 | defined with `min_spots`. 19 | 20 | :param adata: 21 | The AnnData data matrix of shape `n_obs` × `n_vars`. Rows correspond to cells and columns to genes. 22 | Spatial data needs to be stored in `.obsm['spatial']` as X and Y coordinate columns. 23 | :param min_spots: Run calculation only for genes expressed in equal or higher fraction of the total capture spots. 24 | :param top_svg: Cutoff for top ranked spatially variable genes. 25 | :param min_morans: 26 | Cutoff using Moran's I. Specifying this parameter does not disable the cutoff set in `top_svg`. The 27 | 'min_morans' cutoff only activates when the cutoff value retains less genes than specified in `top_svg`. 28 | :param neighbors: Number of nearest neighbours used for calculating Moran's I. 29 | :param geary: 30 | Calculate Geary's C in addition to Moran's I. Selected SVGs are not affect by this, stored 31 | in `.var["Geary's C"]`. 32 | :return: 33 | Updates `.var` with the following fields: 34 | 35 | - **.var["Moran's I"]** – Moran's I value for all genes. 
        - **.var["spatially_variable"]** – Boolean labels of the examined genes based on the defined cutoffs.

    Example usage:

    >>> import chrysalis as ch
    >>> import scanpy as sc
    >>> adata = sc.datasets.visium_sge(sample_id='V1_Human_Lymph_Node')
    >>> sc.pp.calculate_qc_metrics(adata, inplace=True)
    >>> sc.pp.filter_cells(adata, min_counts=6000)
    >>> sc.pp.filter_genes(adata, min_cells=10)
    >>> ch.detect_svgs(adata)

    """

    assert 0 < min_spots < 1

    sc.settings.verbosity = 0
    ad = sc.pp.filter_genes(adata, min_cells=int(len(adata) * min_spots), copy=True)
    ad.var_names_make_unique()  # Moran's I calculation fails on duplicated gene names
    if "log1p" not in adata.uns_keys():
        sc.pp.normalize_total(ad, inplace=True)
        sc.pp.log1p(ad)

    gene_matrix = ad.to_df()

    points = adata.obsm['spatial'].copy()
    points[:, 1] = points[:, 1] * -1

    w = weights.KNN.from_array(points, k=neighbors)
    w.transform = 'R'

    moran_dict = {}
    if geary:
        geary_dict = {}

    for c in tqdm(ad.var_names, desc='Calculating SVGs'):
        moran = esda.moran.Moran(gene_matrix[c], w, permutations=0)
        moran_dict[c] = moran.I
        if geary:
            # use a distinct name so the `geary` flag is not shadowed inside the loop
            geary_c = esda.geary.Geary(gene_matrix[c], w, permutations=0)
            geary_dict[c] = geary_c.C

    moran_df = pd.DataFrame(data=moran_dict.values(), index=moran_dict.keys(), columns=["Moran's I"])
    moran_df = moran_df.sort_values(ascending=False, by="Moran's I")
    adata.var["Moran's I"] = moran_df["Moran's I"]

    if geary:
        geary_df = pd.DataFrame(data=geary_dict.values(), index=geary_dict.keys(), columns=["Geary's C"])
        geary_df = geary_df.sort_values(ascending=False, by="Geary's C")
        adata.var["Geary's C"] = geary_df["Geary's C"]

    # select thresholds
    if len(moran_df[:top_svg]) < len(moran_df[moran_df["Moran's I"] > min_morans]):
        adata.var['spatially_variable'] = [True if x in moran_df[:top_svg].index else False for x in adata.var_names]
    else:
        moran_df = moran_df[moran_df["Moran's I"] > min_morans]
        adata.var['spatially_variable'] = [True if x in moran_df.index else False for x in adata.var_names]


def pca(adata: AnnData, n_pcs: int=50):
    """
    Perform PCA (Principal Component Analysis) to calculate PCA coordinates, loadings, and variance decomposition.

    Spatially variable genes need to be defined in `.var['spatially_variable']` using `chrysalis.detect_svgs`.

    :param adata: The AnnData data matrix of shape `n_obs` × `n_vars`. Rows correspond to cells and columns to genes.
    :param n_pcs: Number of principal components to be calculated.
    :return:
        Adds PCs to `.obsm['chr_X_pca']` and updates `.uns` with the following fields:

        - **.uns['chr_pca']['variance_ratio']** – Explained variance ratio.
        - **.uns['chr_pca']['loadings']** – Spatially variable gene loadings.
        - **.uns['chr_pca']['features']** – Spatially variable gene names.
    Example usage:

    >>> import chrysalis as ch
    >>> import scanpy as sc
    >>> adata = sc.datasets.visium_sge(sample_id='V1_Human_Lymph_Node')
    >>> sc.pp.calculate_qc_metrics(adata, inplace=True)
    >>> sc.pp.filter_cells(adata, min_counts=6000)
    >>> sc.pp.filter_genes(adata, min_cells=10)
    >>> ch.detect_svgs(adata)
    >>> sc.pp.normalize_total(adata, inplace=True)
    >>> sc.pp.log1p(adata)
    >>> ch.pca(adata)

    """

    # todo: this only works with a sparse (CSR) matrix; add a check for dense matrices
    pcs = np.asarray(adata[:, adata.var['spatially_variable'] == True].X.todense())
    pca = PCA(n_components=n_pcs, svd_solver='arpack', random_state=42)
    adata.obsm['chr_X_pca'] = pca.fit_transform(pcs)

    if 'chr_pca' not in adata.uns.keys():
        adata.uns['chr_pca'] = {'variance_ratio': pca.explained_variance_ratio_,
                                'loadings': pca.components_,
                                'features': list(adata[:, adata.var['spatially_variable'] == True].var_names)}
    else:
        adata.uns['chr_pca']['variance_ratio'] = pca.explained_variance_ratio_
        adata.uns['chr_pca']['loadings'] = pca.components_
        adata.uns['chr_pca']['features'] = list(adata[:, adata.var['spatially_variable'] == True].var_names)


def aa(adata: AnnData, n_archetypes: int=8, pca_key: str=None, n_pcs: int=None, max_iter: int=200):
    """
    Run archetypal analysis on the low-dimensional embedding.

    Calculates archetypes, alphas, loadings, and RSS (Residual Sum of Squares). Requires input calculated with
    `chrysalis.pca`.

    :param adata: The AnnData data matrix of shape `n_obs` × `n_vars`. Rows correspond to cells and columns to genes.
    :param n_archetypes: Number of archetypes (tissue compartments) to be identified.
    :param pca_key: Define an alternative PCA input key from `.obsm`; otherwise `chr_X_pca` is used.
    :param n_pcs: Number of PCs (Principal Components) to be used.
    :param max_iter: Maximum number of iterations.
    :return:
        Updates `.uns` with the following fields:

        - **.uns['chr_aa']['archetypes']** – Archetypes.
        - **.uns['chr_aa']['alphas']** – Alphas.
        - **.uns['chr_aa']['loadings']** – Gene loadings.
        - **.uns['chr_aa']['RSS']** – RSS reconstruction error.
159 | 160 | Example usage: 161 | 162 | >>> import chrysalis as ch 163 | >>> import scanpy as sc 164 | >>> adata = sc.datasets.visium_sge(sample_id='V1_Human_Lymph_Node') 165 | >>> sc.pp.calculate_qc_metrics(adata, inplace=True) 166 | >>> sc.pp.filter_cells(adata, min_counts=6000) 167 | >>> sc.pp.filter_genes(adata, min_cells=10) 168 | >>> ch.detect_svgs(adata) 169 | >>> sc.pp.normalize_total(adata, inplace=True) 170 | >>> sc.pp.log1p(adata) 171 | >>> ch.pca(adata) 172 | >>> ch.aa(adata, n_pcs=20, n_archetypes=8) 173 | 174 | """ 175 | 176 | if not isinstance(n_archetypes, int): 177 | raise TypeError 178 | if n_archetypes < 2: 179 | raise ValueError(f"n_archetypes cannot be less than 2.") 180 | 181 | if n_pcs is None: 182 | pcs = n_archetypes-1 183 | else: 184 | pcs = n_pcs 185 | 186 | model = arch.AA(n_archetypes=n_archetypes, n_init=3, max_iter=max_iter, tol=0.001, random_state=42) 187 | 188 | if pca_key is None: 189 | model.fit(adata.obsm['chr_X_pca'][:, :pcs]) 190 | else: 191 | model.fit(adata.obsm[pca_key][:, :pcs]) 192 | 193 | adata.obsm['chr_aa'] = model.alphas_ 194 | 195 | # get the mean of the original feature matrix and add it to the multiplied archetypes with the PCA loading matrix 196 | # aa_loadings = np.mean(pcs, axis=0) + np.dot(model.archetypes_.T, pca.components_[:n_archetypes, :]) 197 | aa_loadings = np.dot(model.archetypes_, adata.uns['chr_pca']['loadings'][:pcs, :]) 198 | 199 | if 'chr_aa' not in adata.uns.keys(): 200 | adata.uns['chr_aa'] = {'archetypes': model.archetypes_, 201 | 'alphas': model.alphas_, 202 | 'loadings': aa_loadings, 203 | 'RSS': model.rss_} 204 | else: 205 | adata.uns['chr_aa']['archetypes'] = model.archetypes_ 206 | adata.uns['chr_aa']['alphas'] = model.alphas_ 207 | adata.uns['chr_aa']['loadings'] = aa_loadings 208 | adata.uns['chr_aa']['RSS'] = model.rss_ 209 | -------------------------------------------------------------------------------- /chrysalis/functions.py: -------------------------------------------------------------------------------- 1 | import colorsys 2 | import numpy as np 3 | import pandas as pd 4 | import scanpy as sc 5 | from tqdm import tqdm 6 | import archetypes as arch 7 | from pysal.lib import weights 8 | from pysal.explore import esda 9 | import matplotlib.pyplot as plt 10 | import matplotlib.colors as mcolors 11 | from sklearn.decomposition import PCA 12 | from scipy.spatial.distance import cdist 13 | 14 | 15 | def get_moransI(w_orig, y): 16 | # REF: https://github.com/yatshunlee/spatial_autocorrelation/blob/main/spatial_autocorrelation/moransI.py modified 17 | # wth some ChatGPT magic to remove the for loops 18 | 19 | if not isinstance(y, np.ndarray): 20 | raise TypeError("Passed array (feature) should be in numpy array (ndim = 1)") 21 | if y.shape[0] != w_orig.shape[0]: 22 | raise ValueError("Feature array is not the same shape of weight") 23 | if w_orig.shape[0] != w_orig.shape[1]: 24 | raise ValueError("Weight array should be in square shape") 25 | 26 | w = w_orig.copy() 27 | y_hat = np.mean(y) 28 | D = y - y_hat 29 | D_sq = (y - y_hat) ** 2 30 | N = y.shape[0] 31 | sum_W = np.sum(w) 32 | w *= D.reshape(-1, 1) * D.reshape(1, -1) * (w != 0) 33 | moransI = (np.sum(w) / sum(D_sq)) * (N / sum_W) 34 | 35 | return round(moransI, 8) 36 | 37 | 38 | def black_to_color(color): 39 | # define the colors in the colormap 40 | colors = ["black", color] 41 | # create a colormap object using the defined colors 42 | cmap = mcolors.LinearSegmentedColormap.from_list("", colors) 43 | return cmap 44 | 45 | 46 | def hls_to_hex(h, l, s): 47 | # 
convert the HLS values to RGB values
    r, g, b = colorsys.hls_to_rgb(h, l, s)
    # convert the RGB values to a hex color code
    hex_code = "#{:02X}{:02X}{:02X}".format(int(r * 255), int(g * 255), int(b * 255))
    return hex_code


def generate_random_colors(num_colors, hue_range=(0, 1), saturation=0.5, lightness=0.5, min_distance=0.2):
    colors = []
    hue_list = []

    while len(colors) < num_colors:
        # Generate a random hue value within the specified range
        hue = np.random.uniform(hue_range[0], hue_range[1])

        # Check if the hue is far enough away from the previous hue
        if len(hue_list) == 0 or all(abs(hue - h) > min_distance for h in hue_list):
            hue_list.append(hue)
            saturation = saturation
            lightness = lightness
            rgb = colorsys.hls_to_rgb(hue, lightness, saturation)
            hex_code = '#{:02x}{:02x}{:02x}'.format(int(rgb[0] * 255), int(rgb[1] * 255), int(rgb[2] * 255))
            colors.append(hex_code)

    return colors

def get_rgb_from_colormap(cmap, vmin, vmax, value):
    # normalize the value within the range [0, 1]
    norm = plt.Normalize(vmin=vmin, vmax=vmax)
    value_normalized = norm(value)

    # get the RGBA value from the colormap
    rgba = plt.get_cmap(cmap)(value_normalized)
    # convert the RGBA value to RGB
    # color = tuple(np.array(rgba[:3]) * 255)
    color = np.array(rgba[:, :3])

    return color


def blend_colors(colors_1, colors_2, weight=0.5):
    # ensure weight is between 0 and 1
    weight = max(0, min(1, weight))

    # blend the colors using linear interpolation
    blended_colors = []
    for i in range(len(colors_1)):
        r = (1 - weight) * colors_1[i][0] + weight * colors_2[i][0]
        g = (1 - weight) * colors_1[i][1] + weight * colors_2[i][1]
        b = (1 - weight) * colors_1[i][2] + weight * colors_2[i][2]
        blended_colors.append((r, g, b))
    return blended_colors


def mip_colors(colors_1, colors_2):
    # combine the colors by taking the channel-wise maximum (maximum intensity projection)
    mip_color = []
    for i in range(len(colors_1)):
        r = max(colors_1[i][0], colors_2[i][0])
        g = max(colors_1[i][1], colors_2[i][1])
        b = max(colors_1[i][2], colors_2[i][2])
        mip_color.append((r, g, b))
    return mip_color


def chrysalis_calculate(adata, min_spots=1000, top_svg=1000, n_archetypes=8):
    """
    Calculates spatially variable genes and embeddings for visualization.

    :param adata: 10X Visium anndata matrix created with scanpy.
    :param min_spots: Discard genes expressed in fewer capture spots than this threshold. Speeds up spatially variable
        gene computation, but can be set lower if the sample area is small.
    :param top_svg: Number of spatially variable genes to be considered for PCA.
    :param n_archetypes: Number of inferred archetypes; best left at 8, as there is no significant gain from trying to
        visualize more.
    :return: Directly annotates the data matrix: adata.obsm['chr_X_pca'] and adata.obsm['chr_aa'].
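
    Example usage (a hypothetical minimal run on a 10X Visium dataset):

    >>> import scanpy as sc
    >>> adata = sc.datasets.visium_sge(sample_id='V1_Human_Lymph_Node')
    >>> chrysalis_calculate(adata)
    >>> chrysalis_plot(adata)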
123 | """ 124 | sc.settings.verbosity = 0 125 | ad = sc.pp.filter_genes(adata, min_cells=min_spots, copy=True) 126 | ad.var_names_make_unique() # moran dies so need some check later 127 | if "log1p" not in adata.uns_keys(): 128 | sc.pp.normalize_total(ad, inplace=True) 129 | sc.pp.log1p(ad) 130 | 131 | gene_matrix = ad.to_df() 132 | 133 | points = adata.obsm['spatial'].copy() 134 | points[:, 1] = points[:, 1] * -1 135 | 136 | w = weights.KNN.from_array(points, k=6) 137 | w.transform = 'R' 138 | moran_dict = {} 139 | 140 | for c in tqdm(ad.var_names): 141 | moran = esda.moran.Moran(gene_matrix[c], w, permutations=0) 142 | moran_dict[c] = moran.I 143 | 144 | moran_df = pd.DataFrame(data=moran_dict.values(), index=moran_dict.keys(), columns=["Moran's I"]) 145 | moran_df = moran_df.sort_values(ascending=False, by="Moran's I") 146 | adata.var['spatially_variable'] = [True if x in moran_df[:top_svg].index else False for x in adata.var_names] 147 | ad.var['spatially_variable'] = [True if x in moran_df[:top_svg].index else False for x in ad.var_names] 148 | adata.var["Moran's I"] = moran_df["Moran's I"] 149 | 150 | pcs = np.asarray(ad[:, ad.var['spatially_variable'] == True].X.todense()) 151 | pca = PCA(n_components=50, svd_solver='arpack', random_state=0) 152 | adata.obsm['chr_X_pca'] = pca.fit_transform(pcs) 153 | if 'chr_pca' not in adata.uns.keys(): 154 | adata.uns['chr_pca'] = {'variance_ratio': pca.explained_variance_ratio_} 155 | else: 156 | adata.uns['chr_pca']['variance_ratio'] = pca.explained_variance_ratio_ 157 | 158 | model = arch.AA(n_archetypes=n_archetypes, n_init=3, max_iter=200, tol=0.001, random_state=42) 159 | model.fit(adata.obsm['chr_X_pca'][:, :n_archetypes-1]) 160 | adata.obsm[f'chr_aa'] = model.alphas_ 161 | 162 | 163 | def chrysalis_plot(adata, dim=8, hexcodes=None, seed=None, mode='aa'): 164 | """ 165 | Visualizes embeddings calculated with chrysalis.calculate. 166 | :param adata: 10X Visium anndata matrix created with scanpy. 167 | :param dim: Number of components to visualize. 168 | :param hexcodes: List of hexadecimal colors to replace the default colormap. 169 | :param seed: Random seed, used for mixing colors. 
170 | :param mode: Components to visualize: 'aa' - archetype analysis, 'pca' - PCA 171 | :return: 172 | """ 173 | 174 | # define PC colors 175 | if hexcodes is None: 176 | hexcodes = ['#db5f57', '#dbc257', '#91db57', '#57db80', '#57d3db', '#5770db', '#a157db', '#db57b2'] 177 | if seed is None: 178 | np.random.seed(len(adata)) 179 | else: 180 | np.random.seed(seed) 181 | np.random.shuffle(hexcodes) 182 | else: 183 | assert len(hexcodes) >= dim 184 | # define colormaps 185 | cmaps = [] 186 | 187 | if mode == 'aa': 188 | for d in range(dim): 189 | pc_cmap = black_to_color(hexcodes[d]) 190 | pc_rgb = get_rgb_from_colormap(pc_cmap, 191 | vmin=min(adata.obsm['chr_aa'][:, d]), 192 | vmax=max(adata.obsm['chr_aa'][:, d]), 193 | value=adata.obsm['chr_aa'][:, d]) 194 | cmaps.append(pc_rgb) 195 | 196 | elif mode == 'pca': 197 | for d in range(dim): 198 | pc_cmap = black_to_color(hexcodes[d]) 199 | pc_rgb = get_rgb_from_colormap(pc_cmap, 200 | vmin=min(adata.obsm['chr_X_pca'][:, d]), 201 | vmax=max(adata.obsm['chr_X_pca'][:, d]), 202 | value=adata.obsm['chr_X_pca'][:, d]) 203 | cmaps.append(pc_rgb) 204 | else: 205 | raise Exception 206 | 207 | # mip colormaps 208 | cblend = mip_colors(cmaps[0], cmaps[1],) 209 | if len(cmaps) > 2: 210 | i = 2 211 | for cmap in cmaps[2:]: 212 | cblend = mip_colors(cblend, cmap,) 213 | i += 1 214 | 215 | # plot 216 | fig, ax = plt.subplots(1, 1, figsize=(5, 5)) 217 | ax.axis('off') 218 | row = adata.obsm['spatial'][:, 0] 219 | col = adata.obsm['spatial'][:, 1] * -1 220 | ax.set_xlim((np.min(row) * 0.9, np.max(row) * 1.1)) 221 | ax.set_ylim((np.min(col) * 1.1, np.max(col) * 0.9)) 222 | ax.set_aspect('equal') 223 | 224 | distances = cdist(np.column_stack((row, col)), np.column_stack((row, col))) 225 | np.fill_diagonal(distances, np.inf) 226 | min_distance = np.min(distances) 227 | 228 | # get the physical length of the x and y axes 229 | ax_len = np.diff(np.array(ax.get_position())[:, 0]) * fig.get_size_inches()[0] 230 | size_const = ax_len / np.diff(ax.get_xlim())[0] * min_distance * 72 231 | size = size_const ** 2 * 0.95 232 | plt.scatter(row, col, s=size, marker="h", c=cblend) 233 | -------------------------------------------------------------------------------- /chrysalis/test/test.py: -------------------------------------------------------------------------------- 1 | import scanpy as sc 2 | import chrysalis as ch 3 | import matplotlib.pyplot as plt 4 | import scanorama 5 | import os 6 | from glob import glob 7 | 8 | 9 | def preprocess_sample(): 10 | # preprocessing samples 11 | print('Preprocessing...') 12 | 13 | if not os.path.isdir('temp/'): 14 | os.makedirs('temp/', exist_ok=True) 15 | 16 | try: 17 | adata = sc.read_h5ad('temp/V1_Human_Lymph_Node_ss.h5ad') 18 | except FileNotFoundError: 19 | adata = sc.datasets.visium_sge(sample_id='V1_Human_Lymph_Node') 20 | sc.pp.calculate_qc_metrics(adata, inplace=True) 21 | sc.pp.filter_cells(adata, min_counts=6000) 22 | sc.pp.filter_genes(adata, min_cells=10) 23 | 24 | ch.detect_svgs(adata) 25 | 26 | moran_df = adata.var[adata.var["Moran's I"] > 0.08] 27 | adata.var['spatially_variable'] = [True if x in moran_df.index else False for x in adata.var_names] 28 | 29 | adata.write_h5ad(f'temp/V1_Human_Lymph_Node_ss.h5ad') 30 | 31 | return adata 32 | 33 | 34 | def preprocess_multisample(): 35 | # preprocessing samples 36 | print('Preprocessing...') 37 | samples = ['V1_Mouse_Brain_Sagittal_Anterior_Section_2', 'V1_Mouse_Brain_Sagittal_Posterior_Section_2'] 38 | adatas = [] 39 | 40 | if not os.path.isdir('temp/'): 41 | 
os.makedirs('temp/', exist_ok=True) 42 | 43 | files = glob('temp/*_ms.h5ad') 44 | adatas = [sc.read_h5ad(x) for x in files] 45 | 46 | if len(adatas) == 0: 47 | for sample in samples: 48 | ad = sc.datasets.visium_sge(sample_id=sample) 49 | ad.var_names_make_unique() 50 | sc.pp.calculate_qc_metrics(ad, inplace=True) 51 | sc.pp.filter_cells(ad, min_counts=1000) 52 | sc.pp.filter_genes(ad, min_cells=10) 53 | sc.pp.normalize_total(ad, inplace=True) 54 | sc.pp.log1p(ad) 55 | 56 | ch.detect_svgs(ad, min_morans=0.05, min_spots=0.05) 57 | ad.write_h5ad(f'temp/{sample}_ms.h5ad') 58 | 59 | adatas.append(ad) 60 | 61 | return adatas 62 | 63 | 64 | def save_plot(plot_save, name=None): 65 | if plot_save: 66 | # plt.show() 67 | if isinstance(name, str): 68 | if not os.path.isdir('temp/plots/'): 69 | os.makedirs('temp/plots/', exist_ok=True) 70 | plt.savefig(f'temp/plots/{name}.png') 71 | else: 72 | raise ValueError('No plot name specified.') 73 | else: 74 | plt.clf() 75 | 76 | 77 | def test_single_sample(save=True): 78 | 79 | adata = preprocess_sample() 80 | 81 | # normalization 82 | sc.pp.normalize_total(adata, inplace=True) 83 | sc.pp.log1p(adata) 84 | 85 | ch.pca(adata) 86 | 87 | ch.plot_svgs(adata) 88 | print(os.getcwd()) 89 | save_plot(save, name='singleplot_svg') 90 | 91 | ch.plot_explained_variance(adata) 92 | save_plot(save, name='singleplot_evr') 93 | 94 | ch.aa(adata, n_pcs=20, n_archetypes=8) 95 | 96 | ch.plot(adata) 97 | save_plot(save, name='singleplot_plot') 98 | 99 | ch.plot_compartments(adata) 100 | save_plot(save, name='singleplot_comps') 101 | 102 | ch.plot_heatmap(adata) 103 | save_plot(save, name='singleplot_heatmap') 104 | 105 | ch.plot_weights(adata) 106 | save_plot(save, name='singleplot_weights') 107 | 108 | 109 | def test_multi_sample_harmony(save=True): 110 | 111 | adatas = preprocess_multisample() 112 | 113 | # concatenate samples 114 | adata = ch.integrate_adatas(adatas, sample_col='sample') 115 | # replace ENSEMBL IDs with the gene symbols and make them unique 116 | adata.var_names = adata.var['gene_symbols'] 117 | adata.var_names_make_unique() 118 | # harmony 119 | ch.pca(adata, n_pcs=50) 120 | ch.harmony_integration(adata, 'sample', random_state=42, block_size=0.05) 121 | 122 | ch.aa(adata, n_pcs=20, n_archetypes=10) 123 | 124 | ch.plot_samples(adata, 1, 2, dim=10, suptitle='test') 125 | save_plot(save, name='multiplot_mip_harmony') 126 | 127 | ch.plot_samples(adata, 1, 2, dim=10, suptitle='test', selected_comp=0) 128 | save_plot(save, name='multiplot_single_harmony') 129 | 130 | 131 | def test_multi_sample_scanorama(save=True): 132 | 133 | adatas = preprocess_multisample() 134 | 135 | # scanorama 136 | adatas_cor = scanorama.correct_scanpy(adatas, return_dimred=True) 137 | # concatenate samples 138 | adata = ch.integrate_adatas(adatas_cor, sample_col='sample') 139 | # replace ENSEMBL IDs with the gene symbols and make them unique 140 | adata.var_names = adata.var['gene_symbols'] 141 | adata.var_names_make_unique() 142 | 143 | ch.pca(adata, n_pcs=50) 144 | ch.aa(adata, n_pcs=20, n_archetypes=10) 145 | 146 | ch.plot_samples(adata, 1, 2, dim=10, suptitle='test') 147 | save_plot(save, name='multiplot_mip_scanorama') 148 | 149 | ch.plot_samples(adata, 1, 2, dim=10, suptitle='test', selected_comp=0) 150 | save_plot(save, name='multiplot_single_scanorama') 151 | 152 | def test_multi_sample_plots(save=True): 153 | adatas = preprocess_multisample() 154 | 155 | # concatenate samples 156 | adata = ch.integrate_adatas(adatas, sample_col='sample') 157 | # replace ENSEMBL IDs with the 
gene symbols and make them unique 158 | adata.var_names = adata.var['gene_symbols'] 159 | adata.var_names_make_unique() 160 | 161 | ch.pca(adata, n_pcs=50) 162 | ch.aa(adata, n_pcs=20, n_archetypes=10) 163 | 164 | ch.plot_svg_matrix(adatas, figsize=(8, 7), obs_name='sample', cluster=True) 165 | save_plot(save, name='multiplot_svg_matrix') 166 | 167 | ch.plot_samples(adata, 1, 2, dim=10, suptitle='test') 168 | save_plot(save, name='multiplot_mip') 169 | 170 | ch.plot_samples(adata, 1, 2, dim=10, suptitle='test', selected_comp=0) 171 | save_plot(save, name='multiplot_single') 172 | 173 | 174 | if __name__ == '__main__': 175 | save=True 176 | 177 | script_directory = os.path.dirname(os.path.abspath(__file__)) 178 | os.chdir(script_directory) 179 | 180 | print(f'Temporary files and plots are saved to the following directory: {script_directory}') 181 | 182 | print('Running test_single_sample...') 183 | test_single_sample(save=True) 184 | print('Test completed!') 185 | print('Running test_multi_sample_harmony...') 186 | test_multi_sample_harmony(save=True) 187 | print('Test completed!') 188 | print('Running test_multi_sample_scanorama...') 189 | test_multi_sample_scanorama(save=True) 190 | print('Test completed!') 191 | print('Running test_multi_sample_plots...') 192 | test_multi_sample_plots(save=True) 193 | print('Test completed!') 194 | print('------------------------------') 195 | print('All tests have been completed!') 196 | -------------------------------------------------------------------------------- /chrysalis/utils.py: -------------------------------------------------------------------------------- 1 | import anndata 2 | import colorsys 3 | import numpy as np 4 | import pandas as pd 5 | from tqdm import tqdm 6 | import archetypes as arch 7 | from scipy.stats import entropy 8 | import matplotlib.pyplot as plt 9 | import matplotlib.colors as mcolors 10 | from .core import detect_svgs 11 | from typing import List 12 | from anndata import AnnData 13 | 14 | 15 | def black_to_color(color): 16 | # define the colors in the colormap 17 | colors = ["black", color] 18 | # create a colormap object using the defined colors 19 | cmap = mcolors.LinearSegmentedColormap.from_list("", colors) 20 | return cmap 21 | 22 | 23 | def color_to_color(first, last): 24 | # define the colors in the colormap 25 | colors = [first, last] 26 | # create a colormap object using the defined colors 27 | cmap = mcolors.LinearSegmentedColormap.from_list("", colors) 28 | return cmap 29 | 30 | 31 | def hls_to_hex(h, l, s): 32 | # convert the HLS values to RGB values 33 | r, g, b = colorsys.hls_to_rgb(h, l, s) 34 | # convert the RGB values to a hex color code 35 | hex_code = "#{:02X}{:02X}{:02X}".format(int(r * 255), int(g * 255), 76) 36 | return hex_code 37 | 38 | 39 | def generate_random_colors(num_colors, hue_range=(0, 1), saturation=0.5, lightness=0.5, min_distance=0.05, seed=None): 40 | colors = [] 41 | hue_list = [] 42 | if seed: 43 | np.random.seed(seed) 44 | else: 45 | np.random.seed(42) 46 | while len(colors) < num_colors: 47 | # generate a random hue value within the specified range 48 | hue = np.random.uniform(hue_range[0], hue_range[1]) 49 | 50 | # check if the hue is far enough away from the previous hue 51 | if len(hue_list) == 0 or all(abs(hue - h) > min_distance for h in hue_list): 52 | hue_list.append(hue) 53 | saturation = saturation 54 | lightness = lightness 55 | rgb = colorsys.hls_to_rgb(hue, lightness, saturation) 56 | hex_code = '#{:02x}{:02x}{:02x}'.format(int(rgb[0] * 255), int(rgb[1] * 255), 
int(rgb[2] * 255)) 57 | colors.append(hex_code) 58 | 59 | return colors 60 | 61 | 62 | def get_rgb_from_colormap(cmap, vmin, vmax, value): 63 | # normalize the value within the range [0, 1] 64 | norm = plt.Normalize(vmin=vmin, vmax=vmax) 65 | value_normalized = norm(value) 66 | 67 | # get the RGBA value from the colormap 68 | rgba = plt.get_cmap(cmap)(value_normalized) 69 | # convert the RGBA value to RGB 70 | # color = tuple(np.array(rgba[:3]) * 255) 71 | color = np.array(rgba[:, :3]) 72 | 73 | return color 74 | 75 | 76 | def blend_colors(colors_1, colors_2, weight=0.5): 77 | # ensure weight is between 0 and 1 78 | weight = max(0, min(1, weight)) 79 | 80 | # blend the colors using linear interpolation 81 | blended_colors = [] 82 | for i in range(len(colors_1)): 83 | r = (1 - weight) * colors_1[i][0] + weight * colors_2[i][0] 84 | g = (1 - weight) * colors_1[i][1] + weight * colors_2[i][1] 85 | b = (1 - weight) * colors_1[i][2] + weight * colors_2[i][2] 86 | blended_colors.append((r, g, b)) 87 | return blended_colors 88 | 89 | 90 | def mip_colors(colors_1, colors_2): 91 | # blend the colors using linear interpolation 92 | mip_color = [] 93 | for i in range(len(colors_1)): 94 | r = max(colors_1[i][0], colors_2[i][0]) 95 | g = max(colors_1[i][1], colors_2[i][1]) 96 | b = max(colors_1[i][2], colors_2[i][2]) 97 | mip_color.append((r, g, b)) 98 | return mip_color 99 | 100 | 101 | def get_colors(adata, dim=8, seed=42): 102 | if dim > 8: 103 | hexcodes = generate_random_colors(num_colors=dim, min_distance=1 / dim * 0.5, seed=seed, 104 | saturation=0.65, lightness=0.60) 105 | else: 106 | hexcodes = ['#db5f57', '#dbc257', '#91db57', '#57db80', '#57d3db', '#5770db', '#a157db', '#db57b2'] 107 | if seed is None: 108 | np.random.seed(len(adata)) 109 | else: 110 | np.random.seed(seed) 111 | np.random.shuffle(hexcodes) 112 | return hexcodes 113 | 114 | 115 | def estimate_compartments(adata, n_pcs=20, range_archetypes=(3, 50), max_iter=10): 116 | 117 | if 'chr_X_pca' not in adata.obsm.keys(): 118 | raise ValueError(".obsm['chr_X_pca'] cannot be found, run chrysalis_pca first.") 119 | 120 | entropy_arr = np.zeros((len(range(range_archetypes[0], range_archetypes[1])), len(adata))) 121 | rss_dict = {} 122 | i = 0 123 | for a in tqdm(range(range_archetypes[0], range_archetypes[1]), desc='Fitting models'): 124 | model = arch.AA(n_archetypes=a, n_init=3, max_iter=max_iter, tol=0.001, random_state=42) 125 | model.fit(adata.obsm['chr_X_pca'][:, :n_pcs]) 126 | rss_dict[a] = model.rss_ 127 | 128 | entropy_arr[i, :] = entropy(model.alphas_.T) 129 | i += 1 130 | 131 | adata.obsm['entropy'] = entropy_arr.T 132 | 133 | if 'chr_aa' not in adata.uns.keys(): 134 | adata.uns['chr_aa'] = {'RSSs': rss_dict} 135 | else: 136 | adata.uns['chr_aa']['RSSs'] = rss_dict 137 | 138 | 139 | def get_compartment_df(adata: AnnData, weights: bool=True): 140 | """ 141 | Get spatially variable gene weights/expression values as a pandas DataFrame. 142 | 143 | :param adata: The AnnData data matrix of shape `n_obs` × `n_vars`. Rows correspond to cells and columns to genes.: 144 | :param weights: If False, return expression values instead of weights. 145 | :return: Pandas DataFrame. 
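
    Example usage (assuming `chrysalis.pca` and `chrysalis.aa` have already been run on `adata`):

    >>> import chrysalis as ch
    >>> df = ch.get_compartment_df(adata)
    >>> df['compartment_0'].sort_values(ascending=False).head(10)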
146 | 147 | """ 148 | 149 | # SVG expression for each compartment 150 | exp_array = np.asarray(adata[:, adata.var['spatially_variable'] == True].X.todense()) 151 | exp_array = np.mean(exp_array, axis=0) 152 | exp_aa = adata.uns['chr_aa']['loadings'] 153 | if not weights: 154 | exp_aa += exp_array 155 | 156 | df = pd.DataFrame(data=exp_aa, columns=adata.uns['chr_pca']['features'], 157 | index=[f'compartment_{x}' for x in (range(len(exp_aa)))]).T 158 | return df 159 | 160 | 161 | def integrate_adatas(adatas: List[AnnData], sample_names: List[str]=None, calculate_svgs: bool=False, 162 | sample_col: str='sample', **kwargs): 163 | """ 164 | Integrate multiple samples stored in AnnData objects. 165 | 166 | If ENSEMBL IDs are present in the`.var['gene_ids']` column, that will be used instead of gene symbols. 167 | `.var['spatially_variable']` will be outer joined. 168 | 169 | :param adatas: List of AnnData objects. 170 | :param sample_names: List of sample names. If not defined, a list of integers [0, 1, ...] will be used instead. 171 | :param calculate_svgs: If True, the function also runs `chrysalis.detect_svgs` for every sample. 172 | :param sample_col: `.obs` column name to store the sample labels. 173 | :param kwargs: Keyword arguments for `chrysalis.detect_svgs`. 174 | :return: 175 | Integrated AnnData object. Sample IDs are stored in `.obs[sample_col]. `.var['spatially_variable']` contains 176 | the union of`.var['spatially_variable']` from the input AnnData objects. Sample-wise SVG data is stored in 177 | `.varm['spatially_variable']` and Moran's I is stored in `.varm["Moran's I"]`. 178 | 179 | """ 180 | 181 | if sample_names is None: 182 | sample_names = np.arange(len(adatas)) 183 | assert len(adatas) == len(sample_names) 184 | 185 | adatas_dict = {} 186 | gene_symbol_dict = {} 187 | for ad, name in zip(adatas, sample_names): 188 | 189 | # replace .uns['spatial'] with the specified sample name 190 | if 'spatial' in ad.uns.keys(): 191 | assert len(ad.uns['spatial'].keys()) == 1 192 | curr_key = list(ad.uns['spatial'].keys())[0] 193 | ad.uns['spatial'][name] = ad.uns['spatial'][curr_key] 194 | if name != curr_key: 195 | del ad.uns['spatial'][curr_key] 196 | 197 | # check if column is already used 198 | if sample_col not in ad.obs.columns: 199 | ad.obs[sample_col] = name 200 | else: 201 | raise Exception('sample_id_col is already present in adata.obs, specify another column.') 202 | 203 | if 'gene_symbols' not in ad.var.columns: 204 | ad.var['gene_symbols'] = ad.var_names 205 | 206 | if 'gene_ids' in ad.var.columns: 207 | ad.var_names = ad.var['gene_ids'] 208 | 209 | # check if SVGs are already present 210 | if 'spatially_variable' not in ad.var.columns: 211 | if calculate_svgs: 212 | detect_svgs(ad, **kwargs) 213 | else: 214 | raise Exception('spatially_variable column is not found in adata.var. 
Run `chrysalis.detect_svgs` ' 215 | 'first or set the calculate_svgs argument to True.') 216 | 217 | ad.var[f'spatially_variable_{name}'] = ad.var['spatially_variable'] 218 | ad.var[f"Moran's I_{name}"] = ad.var["Moran's I"] 219 | 220 | adatas_dict[name] = ad 221 | 222 | # concat samples 223 | adata = anndata.concat(adatas_dict, index_unique='-', uns_merge='unique', merge='first') 224 | adata.obs[sample_col] = adata.obs[sample_col].astype('category') 225 | # get SVGs for all samples 226 | svg_columns = [c for c in adata.var.columns if 'spatially_variable' in c] 227 | svg_list = [list(adata.var[c][adata.var[c] == True].index) for c in svg_columns] 228 | 229 | # union of SVGs 230 | spatially_variable = list(set().union(*svg_list)) 231 | adata.var['spatially_variable'] = [True if x in spatially_variable else False for x in adata.var_names] 232 | 233 | # save sample-wise spatially_variable and Moran's I columns 234 | sv_cols = [x for x in adata.var.columns if 'spatially_variable_' in x] 235 | adata.varm['spatially_variable'] = adata.var[sv_cols].copy() 236 | if 'gene_ids' in adata.var.columns: 237 | adata.varm['spatially_variable']['gene_ids'] = adata.var['gene_ids'] 238 | if 'gene_symbols' in adata.var.columns: 239 | adata.varm['spatially_variable']['gene_symbols'] = adata.var['gene_symbols'] 240 | adata.var = adata.var.drop(columns=sv_cols) 241 | 242 | mi_cols = [x for x in adata.var.columns if "Moran's I_" in x] 243 | adata.varm["Moran's I"] = adata.var[mi_cols].copy() 244 | if 'gene_ids' in adata.var.columns: 245 | adata.varm["Moran's I"]['gene_ids'] = adata.var['gene_ids'] 246 | if 'gene_symbols' in adata.var.columns: 247 | adata.varm["Moran's I"]['gene_symbols'] = adata.var['gene_symbols'] 248 | adata.var = adata.var.drop(columns=mi_cols) 249 | 250 | return adata 251 | 252 | 253 | def harmony_integration(adata, covariates, input_matrix='chr_X_pca', corrected_matrix=None, 254 | random_state=42, **harmony_kw): 255 | """ 256 | Integrate data using `harmonypy`, the Python implementation of the R package Harmony. 257 | 258 | Harmony integration is done on the PCA matrix, therefore `chrysalis.pca` must be run before this function. 259 | 260 | :param adata: The AnnData data matrix of shape `n_obs` × `n_vars`. Rows correspond to cells and columns to genes. 261 | :param covariates: String or list of strings containing the covariate columns to integrate over. 262 | :param input_matrix: Input PCA matrix, by default 'chr_X_pca' is used in `.obsm`. 263 | :param corrected_matrix: If `corrected_matrix` is defined, a new `.obsm` matrix will be created for the integrated 264 | results instead of overwriting the `input_matrix`. 265 | :param random_state: Random seed passed to `harmonypy.run_harmony()`. 266 | :param harmony_kw: `harmonypy.run_harmony()` keyword arguments. 267 | :return: 268 | Replaces `.obsm[input_matrix]` with the corrected one, or saves the new matrix as a new `.obsm` matrix specified with `corrected_matrix`. 
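
    Example (a minimal sketch; assumes the samples were combined with `chrysalis.integrate_adatas`, so a
    'sample' column exists in `.obs`, and that `chrysalis.pca` was run on the integrated object):

    >>> import chrysalis as ch
    >>> ch.harmony_integration(adata, 'sample', input_matrix='chr_X_pca')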
269 | 270 | """ 271 | 272 | try: 273 | import harmonypy as hm 274 | except ImportError: 275 | raise ImportError("Please install harmonypy: `pip install harmonypy`.") 276 | 277 | data_matrix = adata.obsm[input_matrix] 278 | metadata = adata.obs 279 | 280 | ho = hm.run_harmony(data_matrix, metadata, covariates, random_state=random_state, **harmony_kw) 281 | 282 | adjusted_matrix = np.transpose(ho.Z_corr) 283 | 284 | if corrected_matrix is None: 285 | adata.obsm[input_matrix] = adjusted_matrix 286 | else: 287 | adata.obsm[corrected_matrix] = adjusted_matrix 288 | 289 | 290 | def get_color_vector(adata: AnnData, dim: int=8, hexcodes: List[str]=None, seed: int=None, 291 | selected_comp='all'): 292 | # define compartment colors 293 | # default colormap with 8 colors 294 | hexcodes = get_hexcodes(hexcodes, dim, seed, len(adata)) 295 | 296 | if selected_comp == 'all': 297 | # define colormaps 298 | cmaps = [] 299 | for d in range(dim): 300 | pc_cmap = black_to_color(hexcodes[d]) 301 | pc_rgb = get_rgb_from_colormap(pc_cmap, 302 | vmin=min(adata.obsm['chr_aa'][:, d]), 303 | vmax=max(adata.obsm['chr_aa'][:, d]), 304 | value=adata.obsm['chr_aa'][:, d]) 305 | cmaps.append(pc_rgb) 306 | 307 | # mip colormaps 308 | cblend = mip_colors(cmaps[0], cmaps[1],) 309 | if len(cmaps) > 2: 310 | i = 2 311 | for cmap in cmaps[2:]: 312 | cblend = mip_colors(cblend, cmap,) 313 | i += 1 314 | # specific compartment 315 | else: 316 | color_first = '#2e2e2e' 317 | pc_cmap = color_to_color(color_first, hexcodes[selected_comp]) 318 | pc_rgb = get_rgb_from_colormap(pc_cmap, 319 | vmin=min(adata.obsm['chr_aa'][:, selected_comp]), 320 | vmax=max(adata.obsm['chr_aa'][:, selected_comp]), 321 | value=adata.obsm['chr_aa'][:, selected_comp]) 322 | cblend = pc_rgb 323 | return cblend 324 | 325 | 326 | def get_hexcodes(hexcodes, dim, seed, adata_len): 327 | if hexcodes is None: 328 | if dim > 8: 329 | hexcodes = generate_random_colors(num_colors=dim, min_distance=1 / dim * 0.5, seed=seed, 330 | saturation=0.65, lightness=0.60) 331 | else: 332 | hexcodes = ['#db5f57', '#dbc257', '#91db57', '#57db80', '#57d3db', '#5770db', '#a157db', '#db57b2'] 333 | if seed is None: 334 | np.random.seed(adata_len) 335 | else: 336 | np.random.seed(seed) 337 | np.random.shuffle(hexcodes) 338 | else: 339 | assert len(hexcodes) >= dim 340 | return hexcodes 341 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
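# For example, `make html` is routed to `sphinx-build -M html . _build`,
# writing the rendered documentation to _build/html.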
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | /* 2 | `width:auto` was rendering 0px wide for .svg files 3 | https://stackoverflow.com/questions/59215996/how-to-add-a-logo-to-my-readthedocs-logo-rendering-at-0px-wide 4 | */ 5 | .wy-side-nav-search .wy-dropdown > a img.logo, .wy-side-nav-search > a img.logo { 6 | width: 200px; 7 | } 8 | 9 | /* ReadTheDocs theme colors */ 10 | 11 | .wy-nav-top { background-color: #f07e44 } 12 | 13 | .wy-side-nav-search { background-color: transparent} 14 | .wy-side-nav-search input[type="text"] { border-width: 0 } 15 | 16 | .wy-nav-content { 17 | max-width: 1200px; 18 | background: #ffffff; 19 | } 20 | 21 | .wy-body-for-nav { 22 | background: #ffffff; 23 | } 24 | 25 | /* toctree menu caption colors */ 26 | .wy-menu-vertical header, .wy-menu-vertical p.caption { 27 | color: #8655ea; 28 | } 29 | 30 | /* toctree menu background colors */ 31 | .wy-nav-side { 32 | background: #f8f8f8; 33 | } 34 | 35 | /* toctree menu item colors */ 36 | .wy-menu-vertical a { 37 | color: #231F20; 38 | } 39 | 40 | .wy-menu-vertical a:hover { 41 | background-color: #d0beff; 42 | } 43 | 44 | .wy-menu-vertical a:hover button.toctree-expand { 45 | color: #d9d9d9 46 | } 47 | 48 | .rst-content code.literal, .rst-content tt.literal { 49 | color: #8655ea; 50 | } 51 | 52 | html.writer-html4 .rst-content dl:not(.docutils)>dt,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.citation):not(.glossary):not(.simple)>dt { 53 | display: table; 54 | margin: 6px 0; 55 | font-size: 90%; 56 | line-height: normal; 57 | background: rgba(240, 231, 250, 0.73); 58 | color: #8655ea; 59 | border-top: 3px solid #6726ec; 60 | padding: 6px; 61 | position: relative 62 | } 63 | 64 | .wy-nav-content-wrap { 65 | background: rgb(255 255 255 / 5%); 66 | } 67 | 68 | /* before-after slider */ 69 | 70 | html { 71 | box-sizing: border-box; 72 | } 73 | *, *:before, *:after { 74 | box-sizing: inherit; 75 | } 76 | body { 77 | margin: 0; 78 | height: 100vh; 79 | display: flex; 80 | justify-content: center; 81 | align-items: center; 82 | } 83 | .befaft_container { 84 | position: relative; 85 | width: 500px; 86 | height: 436px; 87 | border: 2px solid white; 88 | } 89 | .befaft_container .img { 90 | position: absolute; 91 | top: 0; 92 | left: 0; 93 | width: 100%; 94 | height: 100%; 95 | background-size: 500px 100%; 96 | } 97 | .befaft_container .background-img { 98 | background-image: url('../images/before.png'); 99 | } 100 | .befaft_container .foreground-img { 101 | background-image: url('../images/after.png'); 102 | width: 50%; 103 | } 104 | .befaft_container .slider { 105 | position: absolute; 106 | -webkit-appearance: none; 107 | appearance: none; 108 | width: 100%; 109 | height: 100%; 110 | background: rgba(255, 255, 255, 0); 111 | outline: none; 112 | margin: 0; 113 | transition: all 0.2s; 114 | display: flex; 115 | justify-content: center; 116 | align-items: center; 117 | } 118 | .befaft_container .slider:hover { 119 | background: rgba(242, 242, 242, .1); 120 | } 121 | .befaft_container .slider::-webkit-slider-thumb { 122 | -webkit-appearance: none; 123 | appearance: none; 124 | width: 6px; 125 | height: 436px; 126 | background: #ffffff; 127 | cursor: pointer; 128 | } 129 | .befaft_container .slider::-moz-range-thumb { 130 | width: 
6px; 131 | height: 436px; 132 | background: white; 133 | cursor: pointer; 134 | } 135 | .befaft_container .slider-button { 136 | pointer-events: none; 137 | position: absolute; 138 | width: 30px; 139 | height: 30px; 140 | border-radius: 50%; 141 | background-color: #ede8ff; 142 | left: calc(50% - 17.5px); 143 | top: calc(50% - 18px); 144 | display: flex; 145 | justify-content: center; 146 | align-items: center; 147 | } 148 | .befaft_container .slider-button:after { 149 | content: ''; 150 | padding: 3px; 151 | display: inline-block; 152 | border: solid #5d5d5d; 153 | border-width: 0 2px 2px 0; 154 | transform: rotate(-45deg); 155 | } 156 | .befaft_container .slider-button:before { 157 | content: ''; 158 | padding: 3px; 159 | display: inline-block; 160 | border: solid #5d5d5d; 161 | border-width: 0 2px 2px 0; 162 | transform: rotate(135deg); 163 | } 164 | -------------------------------------------------------------------------------- /docs/_static/images/after.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/docs/_static/images/after.png -------------------------------------------------------------------------------- /docs/_static/images/before.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/docs/_static/images/before.png -------------------------------------------------------------------------------- /docs/_static/js/custom.js: -------------------------------------------------------------------------------- 1 | $(document).ready(function() { 2 | $("#slider").on("input change", (e)=>{ 3 | const sliderPos = e.target.value; 4 | // Update the width of the foreground image 5 | $('.foreground-img').css('width', `${sliderPos}%`) 6 | // Update the position of the slider button 7 | $('.slider-button').css('left', `calc(${sliderPos}% - 18px)`) 8 | }); 9 | }); -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | ```{eval-rst} 2 | .. automodule:: chrysalis 3 | ``` 4 | 5 | # API overview 6 | 7 | Import chrysalis as: 8 | 9 | ``` 10 | import chrysalis as ch 11 | ``` 12 | 13 | ## Core functions 14 | 15 | Identifying spatially variable genes, dimensionality reduction, archetypal analysis. 16 | 17 | Main functions required to identify tissue compartments. 18 | 19 | ```{eval-rst} 20 | .. autosummary:: 21 | :toctree: generated/ 22 | 23 | detect_svgs 24 | pca 25 | aa 26 | ``` 27 | 28 | ## Plotting 29 | 30 | Visualization module. 31 | 32 | ### Tissue compartments 33 | 34 | Visualizations to examine the identified compartments in the tissue space. 35 | 36 | ```{eval-rst} 37 | .. autosummary:: 38 | :toctree: generated/ 39 | 40 | plot 41 | plot_samples 42 | plot_compartment 43 | plot_compartments 44 | ``` 45 | 46 | ### Quality control 47 | 48 | Plot quality control metrics to determine the correct number of spatially variable genes or PCs (Principal Components). 49 | 50 | ```{eval-rst} 51 | .. autosummary:: 52 | :toctree: generated/ 53 | 54 | plot_explained_variance 55 | plot_svgs 56 | ``` 57 | 58 | ### Compartment-associated genes 59 | 60 | Generate a visualization of the top-contributing genes for each tissue compartment. 61 | 62 | ```{eval-rst} 63 | .. 
autosummary:: 64 | :toctree: generated/ 65 | 66 | plot_heatmap 67 | plot_weights 68 | ``` 69 | 70 | ## Utility functions 71 | 72 | Sample integration, spatially variable gene contributions. 73 | 74 | ```{eval-rst} 75 | .. autosummary:: 76 | :toctree: generated/ 77 | 78 | integrate_adatas 79 | harmony_integration 80 | get_compartment_df 81 | ``` 82 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(0, os.path.abspath("..")) 5 | sys.path.insert(0, os.path.abspath("../..")) 6 | 7 | 8 | # Configuration file for the Sphinx documentation builder. 9 | # 10 | # For the full list of built-in configuration values, see the documentation: 11 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 12 | 13 | # -- Project information ----------------------------------------------------- 14 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 15 | 16 | project = 'chrysalis' 17 | copyright = '2024, Demeter Túrós' 18 | author = 'Demeter Túrós' 19 | release = '2023' 20 | 21 | # -- General configuration --------------------------------------------------- 22 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 23 | 24 | extensions = ['myst_parser', 25 | 'sphinx.ext.autodoc', 26 | 'sphinx.ext.autosummary', 27 | 'sphinx.ext.napoleon', 28 | 'nbsphinx', 29 | ] 30 | 31 | templates_path = ['_templates'] 32 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 33 | 34 | # -- Options for HTML output ------------------------------------------------- 35 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 36 | 37 | html_theme = 'sphinx_rtd_theme' 38 | html_static_path = ['_static'] 39 | 40 | html_logo = "docs_logo.svg" 41 | html_favicon = "favicon.svg" 42 | 43 | html_theme_options = {"logo_only": True} 44 | 45 | html_css_files = ['css/custom.css'] 46 | 47 | html_js_files = ['js/custom.js'] 48 | 49 | html_title = 'chrysalis 0.2.0' 50 | 51 | autodoc_exclude_members = { 52 | 'chrysalis.core': ['detect_svgs', 'pca', 'aa'], 53 | } 54 | 55 | def setup(app): 56 | app.add_css_file("css/custom.css") -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Welcome to chrysalis! 2 | [![PyPI](https://img.shields.io/pypi/v/chrysalis-st?logo=PyPI)](https://pypi.org/project/chrysalis-st) 3 | [![Downloads](https://static.pepy.tech/badge/chrysalis-st)](https://pepy.tech/project/chrysalis-st) 4 | [![Stars](https://img.shields.io/github/stars/rockdeme/chrysalis?logo=GitHub&color=yellow)](https://github.com/rockdeme/chrysalis/stargazers) 5 | ```{include} ../README.md 6 | :start-line: 7 7 | :end-line: 18 8 | ``` 9 | 10 | * Discuss **chrysalis** on [GitHub]. 11 | * Get started by reading the {doc}`basic tutorial <tutorials/lymph_node_tutorial>`. 12 | * You can also browse the {doc}`API <api>`. 13 | * Consider citing our [bioRxiv preprint]. 14 | 15 | ## Visual demonstration 16 | ### Human lung cancer (FFPE) 17 | 18 | [Squamous Cell Carcinoma](https://www.10xgenomics.com/resources/datasets/human-lung-cancer-ffpe-2-standard) sample by 10X Genomics. 19 | 20 | Move the slider to reveal tissue compartments calculated by **chrysalis** or the associated tissue morphology. 21 | 22 | 23 | <div class="befaft_container">
24 | <div class="img background-img"></div>
25 | <div class="img foreground-img"></div>
26 | <input type="range" min="1" max="100" value="50" class="slider" name="slider" id="slider">
27 | <div class="slider-button"></div>
28 | </div>
29 | 30 | ```{toctree} 31 | :hidden: true 32 | :maxdepth: 1 33 | :caption: chrysalis 34 | 35 | overview/getting-started 36 | overview/installation 37 | ``` 38 | 39 | ```{toctree} 40 | :hidden: true 41 | :maxdepth: 2 42 | :caption: Tutorials 43 | 44 | tutorials/lymph_node_tutorial 45 | tutorials/mouse_brain_integration_tutorial 46 | tutorials/advanced_integration_tutorial 47 | ``` 48 | 49 | ```{toctree} 50 | :hidden: true 51 | :maxdepth: 2 52 | :caption: API 53 | 54 | api 55 | ``` 56 | 57 | [GitHub]: https://github.com/rockdeme/chrysalis 58 | [bioRxiv preprint]: https://doi.org/10.1101/2023.08.17.553606 59 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/overview/getting-started.md: -------------------------------------------------------------------------------- 1 | # Getting started 2 | 3 | Quickstart guide to install and run **chrysalis**. 4 | 5 | ## Install 6 | 7 | ```shell 8 | pip install chrysalis-st 9 | ``` 10 | 11 | ## Run 12 | 13 | ```python 14 | import chrysalis as ch 15 | import scanpy as sc 16 | import matplotlib.pyplot as plt 17 | 18 | 19 | adata = sc.datasets.visium_sge(sample_id='V1_Human_Lymph_Node') 20 | 21 | sc.pp.calculate_qc_metrics(adata, inplace=True) 22 | sc.pp.filter_cells(adata, min_counts=6000) 23 | sc.pp.filter_genes(adata, min_cells=10) 24 | 25 | ch.detect_svgs(adata, min_morans=0.08, min_spots=0.05) 26 | 27 | sc.pp.normalize_total(adata, inplace=True) 28 | sc.pp.log1p(adata) 29 | 30 | ch.pca(adata) 31 | 32 | ch.aa(adata, n_pcs=20, n_archetypes=8) 33 | 34 | ch.plot(adata) 35 | plt.show() 36 | ``` 37 | -------------------------------------------------------------------------------- /docs/overview/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | Create a new conda environment if required: 4 | 5 | ```shell 6 | conda create -n chrysalis-env python=3.8 7 | ``` 8 | 9 | ```shell 10 | conda activate chrysalis-env 11 | ``` 12 | 13 | You can install **chrysalis** from PyPI using pip: 14 | 15 | ```shell 16 | pip install chrysalis-st 17 | ``` 18 | 19 | This will install **chrysalis** and all dependencies including `scanpy`. 
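To quickly verify the installation (a minimal smoke test; it only assumes the package imports cleanly):

```python
import chrysalis as ch  # import alias used throughout the documentation

print(ch.__name__)  # prints 'chrysalis' when the package resolved correctly
```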
20 | 21 | ## Troubleshooting 22 | 23 | If `rvlib` fails to install, you can try installing it with conda: 24 | ```shell 25 | conda install -c conda-forge rvlib 26 | ``` 27 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | scipy 2 | numpy 3 | pandas 4 | matplotlib==3.7.2 5 | sphinx==5.3.0 6 | sphinx_rtd_theme==1.2.2 7 | myst_parser==0.18.1 8 | nbsphinx==0.9.2 9 | chrysalis-st -------------------------------------------------------------------------------- /gallery/readme.md: -------------------------------------------------------------------------------- 1 | # Gallery 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /misc/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/misc/banner.png -------------------------------------------------------------------------------- /misc/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/misc/demo.png -------------------------------------------------------------------------------- /misc/deprecated_functions.py: -------------------------------------------------------------------------------- 1 | import colorsys 2 | import numpy as np 3 | import pandas as pd 4 | import scanpy as sc 5 | from tqdm import tqdm 6 | import geopandas as gpd 7 | import archetypes as arch 8 | from pysal.lib import weights 9 | from pysal.explore import esda 10 | import matplotlib.pyplot as plt 11 | from shapely.geometry import Point 12 | import matplotlib.colors as mcolors 13 | from sklearn.decomposition import PCA 14 | from scipy.spatial.distance import cdist 15 | 16 | 17 | def get_moransI(w_orig, y): 18 | # w = spatial weight (topological or actual distance) 19 | # y = actual value 20 | # y_hat = mean value 21 | # REF: https://github.com/yatshunlee/spatial_autocorrelation/blob/main/spatial_autocorrelation/moransI.py modified 22 | # with some ChatGPT magic to remove the for loops 23 | 24 | if not isinstance(y, np.ndarray): 25 | raise TypeError("Passed array (feature) should be in numpy array (ndim = 1)") 26 | if y.shape[0] != w_orig.shape[0]: 27 | raise ValueError("Feature array is not the same shape as the weight matrix") 28 | if w_orig.shape[0] != w_orig.shape[1]: 29 | raise ValueError("Weight array should be in square shape") 30 | 31 | w = w_orig.copy() 32 | y_hat = np.mean(y) 33 | D = y - y_hat 34 | D_sq = (y - y_hat) ** 2 35 | N = y.shape[0] 36 | sum_W = np.sum(w) 37 | w *= D.reshape(-1, 1) * D.reshape(1, -1) * (w != 0) 38 | moransI = (np.sum(w) / sum(D_sq)) * (N / sum_W) 39 | return round(moransI, 8) 40 | 41 | 42 | def black_to_color(color): 43 | # define the colors in the colormap 44 | colors = ["black", color] 45 | 46 | # create a colormap object using the defined colors 47 | cmap = mcolors.LinearSegmentedColormap.from_list("", colors) 48 | 49 | return cmap 50 | 51 | 52 | def hls_to_hex(h, l, s): 53 | # convert the HLS values to RGB values 54 | r, g, b = colorsys.hls_to_rgb(h, l, s) 55 | # convert the RGB values to a hex color code 56 | hex_code = "#{:02X}{:02X}{:02X}".format(int(r * 255), int(g * 255), int(b * 255)) 57 | return hex_code 58 | 59 | 60 | def 
generate_random_colors(num_colors, hue_range=(0, 1), saturation=0.5, lightness=0.5, min_distance=0.2): 61 | colors = [] 62 | hue_list = [] 63 | 64 | while len(colors) < num_colors: 65 | # Generate a random hue value within the specified range 66 | hue = np.random.uniform(hue_range[0], hue_range[1]) 67 | 68 | # Check if the hue is far enough away from the previous hue 69 | if len(hue_list) == 0 or all(abs(hue - h) > min_distance for h in hue_list): 70 | hue_list.append(hue) 71 | saturation = saturation 72 | lightness = lightness 73 | rgb = colorsys.hls_to_rgb(hue, lightness, saturation) 74 | hex_code = '#{:02x}{:02x}{:02x}'.format(int(rgb[0] * 255), int(rgb[1] * 255), int(rgb[2] * 255)) 75 | colors.append(hex_code) 76 | 77 | return colors 78 | 79 | def get_rgb_from_colormap(cmap, vmin, vmax, value): 80 | # normalize the value within the range [0, 1] 81 | norm = plt.Normalize(vmin=vmin, vmax=vmax) 82 | value_normalized = norm(value) 83 | 84 | # get the RGBA value from the colormap 85 | rgba = plt.get_cmap(cmap)(value_normalized) 86 | 87 | # convert the RGBA value to RGB 88 | # color = tuple(np.array(rgba[:3]) * 255) 89 | color = np.array(rgba[:, :3]) 90 | 91 | return color 92 | 93 | 94 | def blend_colors(colors_1, colors_2, weight=0.5): 95 | # ensure weight is between 0 and 1 96 | weight = max(0, min(1, weight)) 97 | 98 | # blend the colors using linear interpolation 99 | blended_colors = [] 100 | for i in range(len(colors_1)): 101 | r = (1 - weight) * colors_1[i][0] + weight * colors_2[i][0] 102 | g = (1 - weight) * colors_1[i][1] + weight * colors_2[i][1] 103 | b = (1 - weight) * colors_1[i][2] + weight * colors_2[i][2] 104 | blended_colors.append((r, g, b)) 105 | return blended_colors 106 | 107 | 108 | def mip_colors(colors_1, colors_2): 109 | # combine the colors by taking the channel-wise maximum (maximum intensity projection) 110 | mip_color = [] 111 | for i in range(len(colors_1)): 112 | r = max(colors_1[i][0], colors_2[i][0]) 113 | g = max(colors_1[i][1], colors_2[i][1]) 114 | b = max(colors_1[i][2], colors_2[i][2]) 115 | mip_color.append((r, g, b)) 116 | return mip_color 117 | 118 | def chrysalis_plot_old(adata, pcs=8, hexcodes=None, seed=None, vis='mip_colors'): 119 | 120 | def norm_weight(a, b): 121 | # for weighting PCs if we want to use blend_colors 122 | return (b - a) / b 123 | 124 | # define PC colors 125 | if hexcodes is None: 126 | hexcodes = ['#db5f57', '#dbc257', '#91db57', '#57db80', '#57d3db', '#5770db', '#a157db', '#db57b2'] 127 | if seed is None: 128 | np.random.seed(len(adata)) 129 | else: 130 | np.random.seed(seed) 131 | np.random.shuffle(hexcodes) 132 | else: 133 | assert len(hexcodes) >= pcs 134 | 135 | 136 | # define colormaps 137 | cmaps = [] 138 | for pc in range(pcs): 139 | pc_cmap = black_to_color(hexcodes[pc]) 140 | pc_rgb = get_rgb_from_colormap(pc_cmap, 141 | vmin=min(adata.obs[f'pca_{pc}']), 142 | vmax=max(adata.obs[f'pca_{pc}']), 143 | value=adata.obs[f'pca_{pc}']) 144 | cmaps.append(pc_rgb) 145 | 146 | # blend colormaps 147 | if vis == 'mip_colors': 148 | cblend = mip_colors(cmaps[0], cmaps[1],) 149 | if len(cmaps) > 2: 150 | i = 2 151 | for cmap in cmaps[2:]: 152 | cblend = mip_colors(cblend, cmap,) 153 | i += 1 154 | elif vis == 'blend_colors': 155 | var_r = np.cumsum(adata.uns['pca']['variance_ratio'][:pcs]) # get variance ratios to normalize 156 | cblend = blend_colors(cmaps[0], cmaps[1], weight=norm_weight(var_r[0], var_r[1])) 157 | if len(cmaps) > 2: 158 | i = 2 159 | for cmap in cmaps[2:]: 160 | cblend = blend_colors(cblend, cmap, weight=norm_weight(var_r[i - 1], var_r[i])) 161 | i += 1
162 | else: 163 | raise Exception('vis should be either mip_colors or blend_colors') 164 | 165 | # plot 166 | fig, ax = plt.subplots(1, 1, figsize=(6, 6)) 167 | ax.axis('off') 168 | # ax[idx].set_xlim((0, 8500)) 169 | # ax[idx].set_ylim((-8500, 0)) 170 | row = adata.obsm['spatial'][:, 0] 171 | col = adata.obsm['spatial'][:, 1] * -1 172 | plt.scatter(row, col, s=25, marker="h", c=cblend) 173 | ax.set_aspect('equal') 174 | 175 | 176 | def chrysalis_calculate_old(adata): 177 | sc.pp.filter_genes(adata, min_cells=1000) 178 | adata.var_names_make_unique() # moran dies so need some check later 179 | sc.pp.normalize_total(adata, inplace=True) 180 | sc.pp.log1p(adata) 181 | 182 | gene_matrix = adata.to_df() 183 | gene_list = list(gene_matrix.columns) 184 | gdf = gpd.GeoDataFrame(gene_matrix) 185 | gdf['spots'] = [Point(x, y) for x, y in zip(adata.obsm['spatial'][:, 0], adata.obsm['spatial'][:, 1] * -1)] 186 | gdf.geometry = gdf['spots'] 187 | 188 | w = weights.KNN.from_dataframe(gdf, k=6) 189 | w.transform = 'R' 190 | moran_dict = {} 191 | # moran.by_col(gdf,gene_list, w=w, permutations=0) this doesn't seem to be faster 192 | for c in tqdm(gene_list): 193 | moran = esda.moran.Moran(gdf[c], w, permutations=0) 194 | moran_dict[c] = moran.I 195 | 196 | moran_df = pd.DataFrame(data=moran_dict.values(), index=moran_dict.keys(), columns=["Moran's I"]) 197 | moran_df = moran_df.sort_values(ascending=False, by="Moran's I") 198 | adata.var['highly_variable'] = [True if x in moran_df[:1000].index else False for x in adata.var_names] 199 | adata.var["Moran's I"] = moran_df["Moran's I"] 200 | 201 | sc.pp.pca(adata) 202 | 203 | for i in range(20): 204 | adata.obs[f'pca_{i}'] = adata.obsm['X_pca'][:, i] 205 | 206 | # archetype analysis 207 | model = arch.AA(n_archetypes=8, n_init=3, max_iter=200, tol=0.001, random_state=42) 208 | model.fit(adata.obsm['X_pca'][:, :7]) 209 | 210 | for i in range(model.alphas_.shape[1]): 211 | adata.obs[f'aa_{i}'] = model.alphas_[:, i] 212 | 213 | 214 | def plot_loadings(adata): 215 | hexcodes = ['#db5f57', '#dbc257', '#91db57', '#57db80', '#57d3db', '#5770db', '#a157db', '#db57b2'] 216 | 217 | np.random.seed(len(adata)) 218 | np.random.shuffle(hexcodes) 219 | 220 | loadings = pd.DataFrame(adata.varm['PCs'][:, :20], index=adata.var_names) 221 | sl = loadings[[0]].sort_values(ascending=False, by=0)[:10] 222 | 223 | fig, ax = plt.subplots(2, 4, figsize=(3 * 4, 4 * 2)) 224 | ax = ax.flatten() 225 | for i in range(8): 226 | sl = loadings[[i]].sort_values(ascending=False, by=i)[:10] 227 | ax[i].axis('on') 228 | ax[i].grid(axis='x') 229 | ax[i].set_axisbelow(True) 230 | ax[i].barh(list(sl.index)[::-1], list(sl[i].values)[::-1], color=hexcodes[i]) 231 | ax[i].set_xlabel('Loading') 232 | ax[i].set_title(f'PC {i}') 233 | plt.tight_layout() 234 | plt.show() 235 | 236 | 237 | def chrysalis_plot_aa(adata, pcs=8, hexcodes=None, seed=None, vis='mip_colors'): 238 | 239 | def norm_weight(a, b): 240 | # for weighting PCs if we want to use blend_colors 241 | return (b - a) / b 242 | 243 | # define PC colors 244 | if hexcodes is None: 245 | hexcodes = ['#db5f57', '#dbc257', '#91db57', '#57db80', '#57d3db', '#5770db', '#a157db', '#db57b2'] 246 | 247 | if pcs > 8: 248 | if seed is None: 249 | np.random.seed(len(adata)) 250 | else: 251 | np.random.seed(seed) 252 | hexcodes = generate_random_colors(pcs, hue_range=(0.0, 1.0), min_distance=0.05) 253 | 254 | if seed is None: 255 | np.random.seed(len(adata)) 256 | else: 257 | np.random.seed(seed) 258 | np.random.shuffle(hexcodes) 259 | else: 260 | assert 
len(hexcodes) >= pcs 261 | 262 | 263 | # define colormaps 264 | cmaps = [] 265 | for pc in range(pcs): 266 | pc_cmap = black_to_color(hexcodes[pc]) 267 | pc_rgb = get_rgb_from_colormap(pc_cmap, 268 | vmin=min(adata.obs[f'aa_{pc}']), 269 | vmax=max(adata.obs[f'aa_{pc}']), 270 | value=adata.obs[f'aa_{pc}']) 271 | cmaps.append(pc_rgb) 272 | 273 | # blend colormaps 274 | if vis == 'mip_colors': 275 | cblend = mip_colors(cmaps[0], cmaps[1],) 276 | if len(cmaps) > 2: 277 | i = 2 278 | for cmap in cmaps[2:]: 279 | cblend = mip_colors(cblend, cmap,) 280 | i += 1 281 | elif vis == 'blend_colors': 282 | var_r = np.cumsum(adata.uns['pca']['variance_ratio'][:pcs]) # get variance ratios to normalize 283 | cblend = blend_colors(cmaps[0], cmaps[1], weight=norm_weight(var_r[0], var_r[1])) 284 | if len(cmaps) > 2: 285 | i = 2 286 | for cmap in cmaps[2:]: 287 | cblend = blend_colors(cblend, cmap, weight=norm_weight(var_r[i - 1], var_r[i])) 288 | i += 1 289 | else: 290 | raise Exception("vis should be either 'mip_colors' or 'blend_colors'") 291 | 292 | # plot 293 | fig, ax = plt.subplots(1, 1, figsize=(6, 6)) 294 | ax.axis('off') 295 | row = adata.obsm['spatial'][:, 0] 296 | col = adata.obsm['spatial'][:, 1] * -1 297 | ax.set_xlim((np.min(row) * 0.9, np.max(row) * 1.1)) 298 | ax.set_ylim((np.min(col) * 1.1, np.max(col) * 0.9)) 299 | ax.set_aspect('equal') 300 | 301 | # get the physical length of the x and y axes 302 | x_length = np.diff(ax.get_xlim())[0] * fig.dpi * fig.get_size_inches()[0] 303 | y_length = np.diff(ax.get_ylim())[0] * fig.dpi * fig.get_size_inches()[1] 304 | 305 | size = np.sqrt(x_length * y_length) * 0.000005 306 | 307 | plt.scatter(row, col, s=size, marker="h", c=cblend) -------------------------------------------------------------------------------- /misc/human_lymph_node.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/misc/human_lymph_node.jpg -------------------------------------------------------------------------------- /misc/panel_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/misc/panel_1.png -------------------------------------------------------------------------------- /misc/panel_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/misc/panel_2.png -------------------------------------------------------------------------------- /plots/V1_Human_Lymph_Node.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/plots/V1_Human_Lymph_Node.png -------------------------------------------------------------------------------- /plots/V1_Mouse_Brain_Sagittal_Anterior.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/plots/V1_Mouse_Brain_Sagittal_Anterior.png -------------------------------------------------------------------------------- /plots/V1_Mouse_Brain_Sagittal_Posterior.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/plots/V1_Mouse_Brain_Sagittal_Posterior.png -------------------------------------------------------------------------------- /plots/V1_Mouse_Kidney.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/plots/V1_Mouse_Kidney.png -------------------------------------------------------------------------------- /plots/gallery/Parent_Visium_Human_BreastCancer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/plots/gallery/Parent_Visium_Human_BreastCancer.png -------------------------------------------------------------------------------- /plots/gallery/Parent_Visium_Human_Cerebellum.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/plots/gallery/Parent_Visium_Human_Cerebellum.png -------------------------------------------------------------------------------- /plots/gallery/Parent_Visium_Human_ColorectalCancer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/plots/gallery/Parent_Visium_Human_ColorectalCancer.png -------------------------------------------------------------------------------- /plots/gallery/Parent_Visium_Human_Glioblastoma.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/plots/gallery/Parent_Visium_Human_Glioblastoma.png -------------------------------------------------------------------------------- /plots/gallery/Parent_Visium_Human_OvarianCancer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/plots/gallery/Parent_Visium_Human_OvarianCancer.png -------------------------------------------------------------------------------- /plots/gallery/V1_Adult_Mouse_Brain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/plots/gallery/V1_Adult_Mouse_Brain.png -------------------------------------------------------------------------------- /plots/gallery/V1_Adult_Mouse_Brain_Coronal_Section_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/plots/gallery/V1_Adult_Mouse_Brain_Coronal_Section_1.png -------------------------------------------------------------------------------- /plots/gallery/V1_Adult_Mouse_Brain_Coronal_Section_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/plots/gallery/V1_Adult_Mouse_Brain_Coronal_Section_2.png -------------------------------------------------------------------------------- /plots/gallery/V1_Breast_Cancer_Block_A_Section_1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/plots/gallery/V1_Breast_Cancer_Block_A_Section_1.png -------------------------------------------------------------------------------- /plots/gallery/V1_Breast_Cancer_Block_A_Section_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/plots/gallery/V1_Breast_Cancer_Block_A_Section_2.png -------------------------------------------------------------------------------- /plots/gallery/V1_Human_Heart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/plots/gallery/V1_Human_Heart.png -------------------------------------------------------------------------------- /plots/gallery/V1_Human_Lymph_Node.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/plots/gallery/V1_Human_Lymph_Node.png -------------------------------------------------------------------------------- /plots/gallery/V1_Mouse_Brain_Sagittal_Anterior.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/plots/gallery/V1_Mouse_Brain_Sagittal_Anterior.png -------------------------------------------------------------------------------- /plots/gallery/V1_Mouse_Brain_Sagittal_Anterior_Section_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/plots/gallery/V1_Mouse_Brain_Sagittal_Anterior_Section_2.png -------------------------------------------------------------------------------- /plots/gallery/V1_Mouse_Brain_Sagittal_Posterior.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/plots/gallery/V1_Mouse_Brain_Sagittal_Posterior.png -------------------------------------------------------------------------------- /plots/gallery/V1_Mouse_Brain_Sagittal_Posterior_Section_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/plots/gallery/V1_Mouse_Brain_Sagittal_Posterior_Section_2.png -------------------------------------------------------------------------------- /plots/gallery/V1_Mouse_Kidney.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockdeme/chrysalis/6c18ff8c3e2b1b64fdf8fe7217a96f98ed6749be/plots/gallery/V1_Mouse_Kidney.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | # for testing use pip install --index-url=https://test.pypi.org/simple/ 6 | # --extra-index-url=https://pypi.org/simple/ chrysalis-st==x.x.x 7 | 8 | [project] 9 | name = "chrysalis-st" 10 | version = "0.2.0" 11 | authors = [ 12 | {name="Demeter Túrós"}, 13 | ] 14 | description = "Powerful and 
lightweight package to identify tissue compartments in spatial transcriptomics datasets." 15 | readme = "README.md" 16 | requires-python = ">=3.7" 17 | classifiers = [ 18 | "Programming Language :: Python :: 3", 19 | "License :: OSI Approved :: BSD License", 20 | "Operating System :: OS Independent", 21 | ] 22 | 23 | dependencies = [ 24 | "archetypes==0.4.2", # last version before PyTorch integration, chrysalis is not relying on that yet 25 | "matplotlib", 26 | "numpy", 27 | "pandas", 28 | "pysal", 29 | "scanpy", 30 | "scikit-learn", 31 | "scipy", 32 | "tqdm", 33 | "seaborn", 34 | ] 35 | 36 | [project.urls] 37 | "Homepage" = "https://github.com/rockdeme/chrysalis" 38 | "Documentation" = "https://chrysalis.readthedocs.io/" 39 | 40 | [tool.hatch.build.targets.wheel] 41 | packages = ["chrysalis"] 42 | 43 | [tool.hatch.build.targets.sdist] 44 | exclude = [ 45 | "/.github", 46 | "/docs", 47 | "/misc", 48 | "/plots", 49 | "/article", 50 | "/gallery", 51 | ".readthedocs.yaml", 52 | ] 53 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | archetypes 2 | matplotlib 3 | numpy 4 | pandas 5 | pysal 6 | scanpy 7 | scikit_learn 8 | scipy 9 | tqdm 10 | seaborn 11 | anndata 12 | --------------------------------------------------------------------------------