├── docs ├── source │ ├── Installation.rst │ ├── About STREAM2.rst │ ├── _static │ │ └── img │ │ │ ├── logo.png │ │ │ └── Fig1_V2.1.jpg │ ├── index.rst │ ├── API.rst │ └── conf.py ├── environment.yml ├── Makefile └── make.bat ├── tests ├── data │ └── rnaseq_paul15.h5ad └── test_stream2_rnaseq.py ├── .readthedocs.yml ├── stream2 ├── plotting │ ├── __init__.py │ ├── _utils.py │ └── _palettes.py ├── __init__.py ├── preprocessing │ ├── __init__.py │ ├── check_env.sh │ ├── _general.py │ ├── _variable_genes.py │ ├── _utils.py │ ├── _pca.py │ └── _qc.py ├── tools │ ├── __init__.py │ ├── _dimension_reduction.py │ ├── _pseudotime.py │ ├── _markers.py │ ├── _graph_utils.py │ └── _elpigraph.py ├── readwrite.py ├── _settings.py └── _utils.py ├── requirements.txt ├── setup.py ├── .github └── workflows │ └── CI.yml ├── LICENSE ├── README.md └── .gitignore /docs/source/Installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ -------------------------------------------------------------------------------- /docs/source/About STREAM2.rst: -------------------------------------------------------------------------------- 1 | About STREAM2 2 | ============= -------------------------------------------------------------------------------- /tests/data/rnaseq_paul15.h5ad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pinellolab/STREAM2/HEAD/tests/data/rnaseq_paul15.h5ad -------------------------------------------------------------------------------- /docs/source/_static/img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pinellolab/STREAM2/HEAD/docs/source/_static/img/logo.png -------------------------------------------------------------------------------- /docs/source/_static/img/Fig1_V2.1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pinellolab/STREAM2/HEAD/docs/source/_static/img/Fig1_V2.1.jpg -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | conda: 4 | environment: docs/environment.yml 5 | 6 | build: 7 | os: ubuntu-22.04 8 | tools: 9 | python: "mambaforge-4.10" 10 | 11 | sphinx: 12 | builder: html 13 | configuration: docs/source/conf.py 14 | fail_on_warning: false -------------------------------------------------------------------------------- /stream2/plotting/__init__.py: -------------------------------------------------------------------------------- 1 | """Plotting.""" 2 | 3 | from ._plot import ( 4 | pca_variance_ratio, 5 | pcs_features, 6 | variable_genes, 7 | violin, 8 | hist, 9 | dimension_reduction, 10 | graph, 11 | feature_path, 12 | stream_sc, 13 | stream 14 | ) 15 | -------------------------------------------------------------------------------- /docs/environment.yml: -------------------------------------------------------------------------------- 1 | name: readthedocs 2 | channels: 3 | - conda-forge 4 | - bioconda 5 | - defaults 6 | dependencies: 7 | - pip 8 | - numpy<1.24.0 #avoid errors caused by 1.24 9 | - pandoc>=2.14 10 | - pip: 11 | - sphinx>=3.0 12 | - sphinx-rtd-theme>=0.5 13 | - nbsphinx>=0.8 14 | - git+https://github.com/pinellolab/STREAM2 15 | -------------------------------------------------------------------------------- /stream2/__init__.py: 
-------------------------------------------------------------------------------- 1 | """STREAM2.""" 2 | 3 | from ._settings import settings 4 | from . import preprocessing as pp 5 | from . import tools as tl 6 | from . import plotting as pl 7 | from .readwrite import * 8 | 9 | __version__ = "0.1a" 10 | 11 | import sys 12 | sys.modules.update( 13 | {f'{__name__}.{m}': globals()[m] for m in ['tl', 'pp', 'pl']}) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.17.0 2 | numba>=0.52.0 3 | networkx>=2.5 4 | pandas>=1.0,!=1.1 ##required by Anndata 5 | anndata>=0.7.4 6 | scanpy>=1.6.0 7 | shapely>=2.0.1 8 | statsmodels>=0.12.1 9 | # h5py<3.0.0 ##avoid byte strings 10 | scikit-learn>=1.2 11 | scipy>=1.4 12 | kneed>=0.7 13 | seaborn>=0.11 14 | matplotlib>=3.3 15 | plotly>=4.14.0 16 | scikit-misc>=0.1.3 17 | adjusttext>=0.7.3 18 | umap-learn>=0.4.6 19 | elpigraph-python>=0.3.1 20 | statsmodels>=0.13.2 21 | python-slugify>=5.0.0 -------------------------------------------------------------------------------- /stream2/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | """Preprocessing.""" 2 | 3 | from ._general import ( 4 | log_transform, 5 | normalize, 6 | ) 7 | from ._qc import ( 8 | cal_qc, 9 | cal_qc_rna, 10 | cal_qc_atac, 11 | filter_samples, 12 | filter_cells_rna, 13 | filter_cells_atac, 14 | filter_features, 15 | filter_genes, 16 | filter_peaks, 17 | ) 18 | from ._pca import ( 19 | pca, 20 | select_pcs, 21 | select_pcs_features, 22 | ) 23 | from ._variable_genes import select_variable_genes 24 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. stream2 documentation master file, created by 2 | sphinx-quickstart on Thu Feb 4 21:51:31 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to STREAM2's documentation! 7 | =================================== 8 | 9 | 10 | Contents 11 | ======== 12 | 13 | .. toctree:: 14 | :maxdepth: 2 15 | :caption: Contents: 16 | 17 | About STREAM2 18 | Installation 19 | API 20 | 21 | 22 | .. toctree:: 23 | :maxdepth: 1 24 | :caption: Tutorials 25 | 26 | complex_structure 27 | supervision_ordinal 28 | supervision_categorical 29 | multiomics 30 | stream_plots -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /stream2/preprocessing/check_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ENVS=$(conda env list | awk '{print $1}' ) 3 | if [[ $ENVS = *"$1"* ]]; then 4 | #source /data/pinello/SHARED_SOFTWARE/anaconda_latest/etc/profile.d/conda.sh 5 | #source ~/anaconda3/etc/profile.d/conda.sh 6 | conda activate $1 7 | echo "change conda env to $1" 8 | echo 'Start ChromVAR computation' 9 | parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P ) 10 | R CMD BATCH --no-save --no-restore "--args input='$2' species='$3' genome='$4' feature='$5' n_jobs=$6" $parent_path/stream2_chromVar.R $2/stream2_chromVar.out & 11 | date 12 | wait 13 | else 14 | echo "Error: Please provide a valid virtual environment. or create a new environment by run 'conda create -n stream2_chromVar R bioconductor-chromvar' " 15 | exit 16 | fi; -------------------------------------------------------------------------------- /stream2/tools/__init__.py: -------------------------------------------------------------------------------- 1 | """The core functionality.""" 2 | 3 | from ._dimension_reduction import dimension_reduction 4 | from ._elpigraph import learn_graph, seed_graph 5 | from ._pseudotime import infer_pseudotime 6 | from ._markers import ( 7 | detect_transition_markers, 8 | spearman_columns, 9 | spearman_pairwise, 10 | xicorr_columns, 11 | xicorr_pairwise, 12 | ) 13 | from ._graph_utils import ( 14 | add_path, 15 | del_path, 16 | find_paths, 17 | refit_graph, 18 | extend_leaves, 19 | prune_graph, 20 | get_weights, 21 | get_component, 22 | find_disconnected_components, 23 | ordinal_knn, 24 | smooth_ordinal_labels, 25 | early_groups, 26 | interpolate, 27 | use_graph_with_n_nodes, 28 | project_graph 29 | ) 30 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | if sys.version_info < (3, 7): 4 | sys.exit('stream2 requires Python >= 3.7') 5 | 6 | from setuptools import setup, find_packages 7 | from pathlib import Path 8 | setup( 9 | name='stream2', 10 | version='0.1a', 11 | author='Huidong Chen', 12 | author_email='hd7chen AT gmail DOT com', 13 | license='BSD', 14 | description="STREAM2: Fast and scalable trajectory analysis " 15 | "of single-cell omics data", 16 | long_description=Path('README.md').read_text('utf-8'), 17 | long_description_content_type="text/markdown", 18 | url='https://github.com/pinellolab/STREAM2', 19 | packages=find_packages(), 20 | classifiers=[ 21 | "Programming Language :: Python :: 3", 22 | "License :: OSI Approved :: BSD License", 23 | "Operating System :: OS Independent", 24 | ], 25 | python_requires='>=3.7', 26 | install_requires=[ 27 | p.strip() 28 | for p in Path('requirements.txt').read_text('utf-8').splitlines() 29 | ], 30 | ) 31 | -------------------------------------------------------------------------------- /stream2/readwrite.py: -------------------------------------------------------------------------------- 1 | """Reading and writing.""" 2 | import anndata as ad 3 | import os 4 | import pandas as pd 5 | 6 | from anndata import ( 7 | AnnData, 8 | read_h5ad, 9 | read_csv, 10 | read_excel, 11 | read_hdf, 12 | read_loom, 13 | read_mtx, 14 | read_text, 15 | read_umi_tools, 16 | read_zarr, 17 | ) 18 | 19 | 20 | def read_10X_output(file_path, assay="RNA", **kwargs): 21 | if file_path is None: 22 | file_path = "" 23 | _fp = lambda f: os.path.join(file_path, f) 24 | 25 | adata = ad.read_mtx(_fp("matrix.mtx"), **kwargs).T 26 | adata.X = adata.X 27 | adata.obs_names = pd.read_csv(_fp("barcodes.tsv"), header=None)[0] 28 | if assay == "ATAC": 29 | features = pd.read_csv(_fp("peaks.bed"), header=None, sep="\t") 30 | features.columns = ["seqnames", "start", "end"] 31 | features.index = ( 32 | features["seqnames"].astype(str) 33 | + "_" 34 | + features["start"].astype(str) 35 | + "_" 36 | + features["end"].astype(str) 37 | ) 38 | else: 39 | features = pd.read_csv(_fp("genes.tsv"), header=None, sep="\t") 40 | features.index = features.index.astype("str") 41 | adata.var = features 42 | 43 | return adata 44 | -------------------------------------------------------------------------------- /docs/source/API.rst: -------------------------------------------------------------------------------- 1 | .. automodule:: stream2 2 | 3 | API 4 | === 5 | 6 | Import stream2 as:: 7 | 8 | import stream2 as st2 9 | 10 | Configuration for STREAM2 11 | ~~~~~~~~~~~~~~~~~~~~~~~~~ 12 | .. autosummary:: 13 | :toctree: _autosummary 14 | 15 | settings.set_figure_params 16 | settings.set_workdir 17 | 18 | 19 | Reading 20 | ~~~~~~~ 21 | .. autosummary:: 22 | :toctree: _autosummary 23 | 24 | read_csv 25 | read_h5ad 26 | read_10X_output 27 | read_mtx 28 | 29 | 30 | See more at `anndata `_ 31 | 32 | 33 | Preprocessing 34 | ~~~~~~~~~~~~~ 35 | ..
autosummary:: 36 | :toctree: _autosummary 37 | 38 | pp.log_transform 39 | pp.normalize 40 | pp.cal_qc_rna 41 | pp.filter_genes 42 | pp.pca 43 | pp.select_variable_genes 44 | 45 | 46 | Tools 47 | ~~~~~ 48 | .. autosummary:: 49 | :toctree: _autosummary 50 | 51 | tl.dimension_reduction 52 | tl.seed_graph 53 | tl.learn_graph 54 | tl.infer_pseudotime 55 | tl.add_path 56 | tl.del_path 57 | tl.get_weights 58 | tl.extend_leaves 59 | tl.refit_graph 60 | tl.project_graph 61 | 62 | 63 | Plotting 64 | ~~~~~~~~ 65 | .. autosummary:: 66 | :toctree: _autosummary 67 | 68 | pl.pca_variance_ratio 69 | pl.variable_genes 70 | pl.violin 71 | pl.graph 72 | pl.dimension_reduction 73 | pl.stream_sc 74 | pl.stream -------------------------------------------------------------------------------- /.github/workflows/CI.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build-linux: 7 | runs-on: ubuntu-20.04 8 | strategy: 9 | max-parallel: 5 10 | matrix: 11 | python-version: [3.6, 3.7, 3.8, 3.9] 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | - name: Add conda to system path 20 | run: | 21 | # $CONDA is an environment variable pointing to the root of the miniconda directory 22 | echo $CONDA/bin >> $GITHUB_PATH 23 | - name: Install dependencies 24 | run: | 25 | conda config --add channels bioconda 26 | conda config --add channels conda-forge 27 | # conda env update --file environment.yml --name base 28 | python -m pip install --upgrade pip 29 | pip install -r requirements.txt 30 | pip install -e . 31 | - name: Lint with flake8 32 | run: | 33 | conda install flake8 34 | # stop the build if there are Python syntax errors or undefined names 35 | flake8 . --count --select=E9,E501,F63,F7,F82 --show-source --statistics 36 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 37 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 38 | - name: Test with pytest 39 | run: | 40 | conda install pytest 41 | pytest 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, Pinello Lab 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![CI](https://github.com/pinellolab/stream2/actions/workflows/CI.yml/badge.svg)](https://github.com/pinellolab/stream2/actions/workflows/CI.yml) 2 | 3 | ![stream2](./docs/source/_static/img/logo.png?raw=true) 4 | 5 | # STREAM2 6 | STREAM2 (**S**ingle-cell **T**rajectories **R**econstruction, **E**xploration **A**nd **M**apping) is an interactive pipeline capable of disentangling and visualizing complex trajectories from single-cell omics data. 7 | 8 | 9 | Installation 10 | ------------ 11 | ```sh 12 | $ pip install git+https://github.com/pinellolab/STREAM2 13 | ``` 14 | 15 | Tutorials 16 | --------- 17 | Preliminary tutorials for the usage of STREAM2 can be found in the **STREAM2_tutorials** repository: https://github.com/pinellolab/STREAM2_tutorials. 18 | 19 | 20 | Description 21 | ----------- 22 | The four key innovations of STREAM2 are: 23 | 1) STREAM2 can learn more biologically meaningful trajectories in a semi-supervised way by leveraging external information (e.g. time points, FACS labels, predefined relations of clusters, etc.); 24 | 2) STREAM2 is able to learn not only linear or tree-like structures but also more complex graphs with loops or disconnected components; 25 | 3) STREAM2 supports trajectory inference for various single-cell assays such as gene expression, chromatin accessibility, protein expression level, and DNA methylation; 26 | 4) STREAM2 introduces a flexible path-based marker detection procedure. In addition, we provide a scalable and fast Python package along with a comprehensive documentation website to facilitate STREAM2 analysis. 27 | 28 | ![stream2](./docs/source/_static/img/Fig1_V2.1.jpg?raw=true) 29 | -------------------------------------------------------------------------------- /stream2/preprocessing/_general.py: -------------------------------------------------------------------------------- 1 | """General preprocessing functions.""" 2 | 3 | import numpy as np 4 | from sklearn.utils import sparsefuncs 5 | from ._utils import cal_tf_idf 6 | 7 | 8 | def log_transform(adata): 9 | """Return the natural logarithm of one plus the input array, element-wise. 10 | 11 | Parameters 12 | ---------- 13 | adata: AnnData 14 | Annotated data matrix. 15 | 16 | Returns 17 | ------- 18 | updates `adata` with the following fields. 19 | X: `numpy.ndarray` (`adata.X`) 20 | Store #observations × #var_genes logarithmized data matrix. 21 | """ 22 | 23 | adata.X = np.log1p(adata.X) 24 | return None 25 | 26 | 27 | def normalize(adata, method="lib_size", scale_factor=1e4, save_raw=True): 28 | """Normalize count matrix. 29 | 30 | Parameters 31 | ---------- 32 | adata: AnnData 33 | Annotated data matrix. 34 | method: `str`, optional (default: 'lib_size') 35 | Choose from {{'lib_size','tf_idf'}}.
36 | Method used for normalization.\n 37 | 'lib_size': Total-count normalize (library-size correct).\n 38 | 'tf_idf': TF-IDF (term frequency–inverse document frequency) 39 | transformation. 40 | Returns 41 | ------- 42 | updates `adata` with the following fields. 43 | X: `numpy.ndarray` (`adata.X`) 44 | Store #observations × #var_genes normalized data matrix. 45 | """ 46 | 47 | if method not in ["lib_size", "tf_idf"]: 48 | raise ValueError("unrecognized method '%s'" % method) 49 | if save_raw: 50 | adata.layers["raw"] = adata.X.copy() 51 | if method == "lib_size": 52 | sparsefuncs.inplace_row_scale(adata.X, 1 / adata.X.sum(axis=1).A) 53 | adata.X = adata.X * scale_factor 54 | if method == "tf_idf": 55 | adata.X = cal_tf_idf(adata.X) 56 | -------------------------------------------------------------------------------- /stream2/preprocessing/_variable_genes.py: -------------------------------------------------------------------------------- 1 | """Preprocess.""" 2 | 3 | import numpy as np 4 | from scipy.sparse import ( 5 | csr_matrix, 6 | ) 7 | from sklearn.utils import sparsefuncs 8 | from skmisc.loess import loess 9 | 10 | 11 | def select_variable_genes( 12 | adata, 13 | layer="raw", 14 | span=0.3, 15 | n_top_genes=2000, 16 | ): 17 | """Select highly variable genes. 18 | 19 | This function implements the method 'vst' in Seurat v3. 20 | Inspired by Scanpy. 21 | 22 | Parameters 23 | ---------- 24 | 25 | 26 | Returns 27 | ------- 28 | """ 29 | 30 | if layer is None: 31 | X = adata.X 32 | else: 33 | X = adata.layers[layer].astype(np.float64).copy() 34 | mean, variance = sparsefuncs.mean_variance_axis(X, axis=0) 35 | variance_expected = np.zeros(adata.shape[1], dtype=np.float64) 36 | not_const = variance > 0 37 | 38 | model = loess( 39 | np.log10(mean[not_const]), 40 | np.log10(variance[not_const]), 41 | span=span, 42 | degree=2, 43 | ) 44 | model.fit() 45 | variance_expected[not_const] = 10**model.outputs.fitted_values 46 | N = adata.shape[0] 47 | clip_max = np.sqrt(N) 48 | clip_val = np.sqrt(variance_expected) * clip_max + mean 49 | 50 | X = csr_matrix(X) 51 | mask = X.data > clip_val[X.indices] 52 | X.data[mask] = clip_val[X.indices[mask]] 53 | 54 | squared_X_sum = np.array(X.power(2).sum(axis=0)) 55 | X_sum = np.array(X.sum(axis=0)) 56 | 57 | norm_gene_var = (1 / ((N - 1) * variance_expected)) * ( 58 | (N * np.square(mean)) + squared_X_sum - 2 * X_sum * mean 59 | ) 60 | norm_gene_var = norm_gene_var.flatten() 61 | 62 | adata.var["variances_norm"] = norm_gene_var 63 | adata.var["variances"] = variance 64 | adata.var["means"] = mean 65 | ids_top = norm_gene_var.argsort()[-n_top_genes:][::-1] 66 | adata.var["highly_variable"] = np.isin(range(adata.shape[1]), ids_top) 67 | print(f"{n_top_genes} variable genes are selected.") 68 | -------------------------------------------------------------------------------- /stream2/preprocessing/_utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions and classes.""" 2 | 3 | import numpy as np 4 | from kneed import KneeLocator 5 | from scipy.sparse import csr_matrix, diags 6 | 7 | 8 | def locate_elbow( 9 | x, 10 | y, 11 | S=10, 12 | min_elbow=0, 13 | curve="convex", 14 | direction="decreasing", 15 | online=False, 16 | **kwargs 17 | ): 18 | """Detect knee points.
19 | 20 | Parameters 21 | ---------- 22 | x : `array-like` 23 | x values 24 | y : `array-like` 25 | y values 26 | S : `float`, optional (default: 10) 27 | Sensitivity 28 | min_elbow: `int`, optional (default: 0) 29 | The minimum elbow location 30 | curve: `str`, optional (default: 'convex') 31 | Choose from {'convex','concave'} 32 | If 'concave', algorithm will detect knees, 33 | If 'convex', algorithm will detect elbows. 34 | direction: `str`, optional (default: 'decreasing') 35 | Choose from {'decreasing','increasing'} 36 | online: `bool`, optional (default: False) 37 | kneed will correct old knee points if True, 38 | kneed will return first knee if False. 39 | **kwargs: `dict`, optional 40 | Extra arguments to KneeLocator. 41 | 42 | Returns 43 | ------- 44 | elbow: `int` 45 | elbow point 46 | """ 47 | kneedle = KneeLocator( 48 | x[int(min_elbow):], 49 | y[int(min_elbow):], 50 | S=S, 51 | curve=curve, 52 | direction=direction, 53 | online=online, 54 | **kwargs, 55 | ) 56 | if kneedle.elbow is None: 57 | elbow = len(y) 58 | else: 59 | elbow = int(kneedle.elbow) 60 | return elbow 61 | 62 | 63 | def cal_tf_idf(mat): 64 | """Transform a count matrix to a tf-idf representation.""" 65 | mat = csr_matrix(mat) 66 | tf = csr_matrix(mat / (mat.sum(axis=0))) 67 | idf = np.array(np.log(1 + mat.shape[1] / mat.sum(axis=1))).flatten() 68 | tf_idf = csr_matrix(np.dot(diags(idf), tf)) 69 | return tf_idf 70 | -------------------------------------------------------------------------------- /tests/test_stream2_rnaseq.py: -------------------------------------------------------------------------------- 1 | import stream2 as st 2 | import pytest 3 | 4 | 5 | @pytest.fixture 6 | def adata(): 7 | return st.read_h5ad( 8 | "tests/data/rnaseq_paul15.h5ad") 9 | 10 | 11 | def test_stream2_rnaseq_paul15(adata, tmp_path): 12 | st.settings.set_workdir(tmp_path / "result_rnaseq_paul15") 13 | st.settings.set_figure_params(dpi=80, 14 | style='white', 15 | fig_size=[5, 5], 16 | rc={'image.cmap': 'viridis'}) 17 | st.pp.filter_genes(adata, min_n_cells=3) 18 | st.pp.cal_qc_rna(adata) 19 | st.pl.violin(adata, 20 | list_obs=['n_counts', 'n_genes', 'pct_mt']) 21 | st.pp.normalize(adata, method='lib_size') 22 | st.pp.log_transform(adata) 23 | st.pp.select_variable_genes(adata, n_top_genes=500) 24 | st.pl.variable_genes(adata, show_texts=False) 25 | st.pp.pca(adata, feature='highly_variable', n_components=40) 26 | st.pl.pca_variance_ratio(adata, log=True) 27 | 28 | st.tl.dimension_reduction(adata, obsm='X_pca', n_dim=40, n_jobs=1) 29 | st.pl.dimension_reduction(adata, color=['paul15_clusters', 'n_genes'], 30 | dict_drawing_order={ 31 | 'paul15_clusters': 'random', 32 | 'n_genes': 'sorted'}, 33 | fig_legend_ncol=2, 34 | fig_size=(5.5, 5)) 35 | 36 | st.tl.seed_graph(adata) 37 | st.tl.learn_graph(adata) 38 | st.pl.graph(adata, 39 | color=['paul15_clusters', 'n_genes'], 40 | show_text=True, 41 | show_node=True) 42 | st.tl.infer_pseudotime(adata, 43 | source=0, 44 | target=4) 45 | st.pl.graph(adata, 46 | color=['epg_pseudotime'], 47 | show_text=False, 48 | show_node=False) 49 | st.tl.infer_pseudotime(adata, source=0) 50 | st.pl.graph(adata, 51 | color=['epg_pseudotime'], 52 | show_text=False, 53 | show_node=False) 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.gitignore.io/api/osx,python,windows 2 | 3 | ### OSX ### 4 | *.DS_Store 5 | .AppleDouble 6 | .LSOverride 7 | 8 | # Icon 
must end with two \r 9 | Icon 10 | 11 | # Thumbnails 12 | ._* 13 | 14 | # Files that might appear in the root of a volume 15 | .DocumentRevisions-V100 16 | .fseventsd 17 | .Spotlight-V100 18 | .TemporaryItems 19 | .Trashes 20 | .VolumeIcon.icns 21 | .com.apple.timemachine.donotpresent 22 | 23 | # Directories potentially created on remote AFP share 24 | .AppleDB 25 | .AppleDesktop 26 | Network Trash Folder 27 | Temporary Items 28 | .apdisk 29 | 30 | ### Python ### 31 | # Byte-compiled / optimized / DLL files 32 | __pycache__/ 33 | *.py[cod] 34 | *$py.class 35 | 36 | # C extensions 37 | *.so 38 | 39 | # Distribution / packaging 40 | .Python 41 | build/ 42 | develop-eggs/ 43 | dist/ 44 | downloads/ 45 | eggs/ 46 | .eggs/ 47 | lib/ 48 | lib64/ 49 | parts/ 50 | sdist/ 51 | var/ 52 | wheels/ 53 | *.egg-info/ 54 | .installed.cfg 55 | *.egg 56 | 57 | # PyInstaller 58 | # Usually these files are written by a python script from a template 59 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 60 | *.manifest 61 | *.spec 62 | 63 | # Installer logs 64 | pip-log.txt 65 | pip-delete-this-directory.txt 66 | 67 | # Unit test / coverage reports 68 | htmlcov/ 69 | .tox/ 70 | .coverage 71 | .coverage.* 72 | .cache 73 | .pytest_cache/ 74 | nosetests.xml 75 | coverage.xml 76 | *.cover 77 | .hypothesis/ 78 | 79 | # Translations 80 | *.mo 81 | *.pot 82 | 83 | # Flask stuff: 84 | instance/ 85 | .webassets-cache 86 | 87 | # Scrapy stuff: 88 | .scrapy 89 | 90 | # Sphinx documentation 91 | docs/_build/ 92 | 93 | # PyBuilder 94 | target/ 95 | 96 | # Jupyter Notebook 97 | .ipynb_checkpoints 98 | 99 | # pyenv 100 | .python-version 101 | 102 | # celery beat schedule file 103 | celerybeat-schedule.* 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | 130 | ### Windows ### 131 | # Windows thumbnail cache files 132 | Thumbs.db 133 | ehthumbs.db 134 | ehthumbs_vista.db 135 | 136 | # Folder config file 137 | Desktop.ini 138 | 139 | # Recycle Bin used on file shares 140 | $RECYCLE.BIN/ 141 | 142 | # Windows Installer files 143 | *.cab 144 | *.msi 145 | *.msm 146 | *.msp 147 | 148 | # Windows shortcuts 149 | *.lnk 150 | 151 | 152 | # End of https://www.gitignore.io/api/osx,python,windows 153 | .vscode/settings.json 154 | -------------------------------------------------------------------------------- /stream2/tools/_dimension_reduction.py: -------------------------------------------------------------------------------- 1 | """UMAP (Uniform Manifold Approximation and Projection)""" 2 | 3 | import umap as umap_learn 4 | from sklearn.manifold import ( 5 | LocallyLinearEmbedding, 6 | Isomap, 7 | TSNE, 8 | SpectralEmbedding, 9 | ) 10 | 11 | 12 | def dimension_reduction( 13 | adata, 14 | n_neighbors=15, 15 | n_components=2, 16 | random_state=2020, 17 | layer=None, 18 | obsm=None, 19 | n_dim=None, 20 | method="umap", 21 | eigen_solver="auto", 22 | **kwargs, 23 | ): 24 | """perform dimension reduction 25 | 26 | Parameters 27 | ---------- 28 | adata: AnnData 29 | Annotated data matrix. 30 | method: `str`, optional (default: 'umap') 31 | Choose from {{'umap','se','mlle','tsne','isomap'}} 32 | Method used for dimension reduction. 
33 | 'umap': Uniform Manifold Approximation and Projection 34 | 'se': Spectral embedding algorithm 35 | 'mlle': Modified locally linear embedding algorithm 36 | 'tsne': T-distributed Stochastic Neighbor Embedding 37 | 'isomap': Isomap Embedding 38 | 39 | 40 | Returns 41 | ------- 42 | updates `adata` with the following fields: 43 | `.obsm['X_dr']` : `numpy.ndarray` 44 | Low-dimensional embedding coordinates of samples. 45 | """ 46 | 47 | if sum(list(map(lambda x: x is not None, [layer, obsm]))) == 2: 48 | raise ValueError("Only one of `layer` and `obsm` can be used") 49 | elif obsm is not None: 50 | X = adata.obsm[obsm] 51 | elif layer is not None: 52 | X = adata.layers[layer] 53 | else: 54 | X = adata.X 55 | if n_dim is not None: 56 | X = X[:, :n_dim] 57 | 58 | if method == "umap": 59 | reducer = umap_learn.UMAP( 60 | n_neighbors=n_neighbors, 61 | n_components=n_components, 62 | random_state=random_state, 63 | **kwargs, 64 | ) 65 | elif method == "se": 66 | reducer = SpectralEmbedding( 67 | n_neighbors=n_neighbors, 68 | n_components=n_components, 69 | random_state=random_state, 70 | **kwargs, 71 | ) 72 | elif method == "mlle": 73 | reducer = LocallyLinearEmbedding( 74 | n_neighbors=n_neighbors, 75 | n_components=n_components, 76 | eigen_solver=eigen_solver, 77 | random_state=random_state, 78 | **kwargs, 79 | ) 80 | elif method == "tsne": 81 | reducer = TSNE( 82 | n_components=n_components, 83 | random_state=random_state, 84 | **kwargs, 85 | ) 86 | elif method == "isomap": 87 | reducer = Isomap( 88 | n_neighbors=n_neighbors, 89 | n_components=n_components, 90 | eigen_solver=eigen_solver, 91 | **kwargs, 92 | ) 93 | 94 | reducer.fit(X) 95 | adata.obsm["X_dr"] = reducer.embedding_ -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = "stream2" 21 | copyright = "2021, huidong chen" 22 | author = "huidong chen" 23 | 24 | # The full version, including alpha/beta/rc tags 25 | release = "v0.1.0" 26 | 27 | # -- Retrieve notebooks ------------------------------- 28 | 29 | from urllib.request import urlretrieve # noqa: E402 30 | 31 | notebooks_url = "https://github.com/pinellolab/STREAM2_tutorials/raw/main/tutorial_notebooks/" # noqa 32 | notebooks_v1_0 = [ 33 | "complex_structure.ipynb", 34 | "supervision_ordinal.ipynb", 35 | "supervision_categorical.ipynb", 36 | "multiomics.ipynb", 37 | "stream_plots.ipynb", 38 | ] 39 | 40 | for nb in notebooks_v1_0: 41 | try: 42 | urlretrieve(notebooks_url + nb, nb) 43 | except Exception: 44 | pass 45 | 46 | # -- General configuration --------------------------------------------------- 47 | 48 | # Add any Sphinx extension module names here, as strings.
They can be 49 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 50 | # ones. 51 | extensions = [ 52 | "sphinx.ext.autodoc", 53 | "sphinx.ext.autosummary", 54 | 'sphinx.ext.napoleon', 55 | "sphinx.ext.intersphinx", 56 | "sphinx.ext.mathjax", 57 | "sphinx.ext.viewcode", 58 | "nbsphinx", 59 | ] 60 | autosummary_generate = True 61 | 62 | # Add any paths that contain templates here, relative to this directory. 63 | templates_path = ["_templates"] 64 | 65 | # List of patterns, relative to source directory, that match files and 66 | # directories to ignore when looking for source files. 67 | # This pattern also affects html_static_path and html_extra_path. 68 | exclude_patterns = ['_build'] 69 | 70 | 71 | # -- Options for HTML output ------------------------------------------------- 72 | 73 | # The theme to use for HTML and HTML Help pages. See the documentation for 74 | # a list of builtin themes. 75 | # 76 | html_theme = "sphinx_rtd_theme" 77 | 78 | github_repo = 'stream2' 79 | github_nb_repo = 'stream2_tutorials' 80 | 81 | # Add any paths that contain custom static files (such as style sheets) here, 82 | # relative to this directory. They are copied after the builtin static files, 83 | # so a file named "default.css" will overwrite the builtin "default.css". 84 | html_static_path = ["_static"] 85 | -------------------------------------------------------------------------------- /stream2/_settings.py: -------------------------------------------------------------------------------- 1 | """Configuration for STREAM2.""" 2 | 3 | import os 4 | import seaborn as sns 5 | import matplotlib as mpl 6 | 7 | 8 | class Stream2Config: 9 | """Configuration class for STREAM2.""" 10 | 11 | def __init__(self, workdir="./result_stream2", save_fig=False, n_jobs=1): 12 | self.workdir = workdir 13 | self.save_fig = save_fig 14 | self.n_jobs = n_jobs 15 | 16 | def set_figure_params( 17 | self, 18 | context="notebook", 19 | style="white", 20 | palette="deep", 21 | font="sans-serif", 22 | font_scale=1.1, 23 | color_codes=True, 24 | save_fig=False, 25 | dpi=80, 26 | dpi_save=150, 27 | fig_size=[5.4, 4.8], 28 | rc=None, 29 | ): 30 | """Set global parameters for figures. Modified from sns.set() 31 | 32 | Parameters 33 | ---------- 34 | context : string or dict 35 | Plotting context parameters, see seaborn :func:`plotting_context` 36 | style: `string`,optional (default: 'white') 37 | Axes style parameters, see seaborn :func:`axes_style` 38 | palette : string or sequence 39 | Color palette, see seaborn :func:`color_palette` 40 | font_scale: `float`, optional (default: 1.1) 41 | Separate scaling factor to independently 42 | scale the size of the font elements. 43 | color_codes : `bool`, optional (default: True) 44 | If ``True`` and ``palette`` is a seaborn palette, 45 | remap the shorthand color codes (e.g. "b", "g", "r", etc.) 46 | to the colors from this palette. 47 | dpi: `int`,optional (default: 80) 48 | Resolution of rendered figures. 49 | dpi_save: `int`,optional (default: 150) 50 | Resolution of saved figures. 51 | rc: `dict`,optional (default: None) 52 | rc settings properties. 53 | Parameter mappings to override the values in the preset style.
54 | Please see https://matplotlib.org/tutorials/introductory/customizing.html#a-sample-matplotlibrc-file # noqa 55 | 56 | Returns 57 | ------- 58 | """ 59 | # mpl.rcParams.update(mpl.rcParamsDefault) 60 | sns.set( 61 | context=context, 62 | style=style, 63 | palette=palette, 64 | font=font, 65 | font_scale=font_scale, 66 | color_codes=color_codes, 67 | rc={ 68 | "figure.dpi": dpi, 69 | "savefig.dpi": dpi_save, 70 | "figure.figsize": fig_size, 71 | "image.cmap": "viridis", 72 | "lines.markersize": 6, 73 | "legend.columnspacing": 0.1, 74 | "legend.borderaxespad": 0.1, 75 | "legend.handletextpad": 0.1, 76 | "pdf.fonttype": 42, 77 | }, 78 | ) 79 | if rc is not None: 80 | assert isinstance(rc, dict), "rc must be dict" 81 | for key, value in rc.items(): 82 | if key in mpl.rcParams.keys(): 83 | mpl.rcParams[key] = value 84 | else: 85 | raise Exception("unrecognized property '%s'" % key) 86 | 87 | def set_workdir(self, workdir=None): 88 | """Set working directory. 89 | 90 | Parameters 91 | ---------- 92 | workdir: `str`, optional (default: None) 93 | Working directory. 94 | 95 | Returns 96 | ------- 97 | """ 98 | if workdir is None: 99 | workdir = self.workdir 100 | print("Using default working directory.") 101 | if not os.path.exists(workdir): 102 | os.makedirs(workdir) 103 | self.workdir = workdir 104 | print("Saving results in: %s" % workdir) 105 | 106 | 107 | settings = Stream2Config() 108 | -------------------------------------------------------------------------------- /stream2/plotting/_utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions and classes.""" 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from pandas.api.types import ( 6 | is_numeric_dtype, 7 | is_string_dtype, 8 | is_categorical_dtype, 9 | ) 10 | import matplotlib as mpl 11 | 12 | from ._palettes import default_20, default_28, default_102 13 | 14 | 15 | def get_colors(arr, vmin=None, vmax=None, clip=False): 16 | """Generate a list of colors for a given array.""" 17 | 18 | if not isinstance(arr, (pd.Series, np.ndarray)): 19 | raise TypeError("`arr` must be pd.Series or np.ndarray") 20 | colors = [] 21 | if is_numeric_dtype(arr): 22 | image_cmap = mpl.rcParams["image.cmap"] 23 | cm = mpl.cm.get_cmap(image_cmap, 512) 24 | if vmin is None: 25 | vmin = min(arr) 26 | if vmax is None: 27 | vmax = max(arr) 28 | norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=clip) 29 | colors = [mpl.colors.to_hex(cm(norm(x))) for x in arr] 30 | elif is_string_dtype(arr) or is_categorical_dtype(arr): 31 | categories = np.unique(arr) 32 | length = len(categories) 33 | # check if default matplotlib palette has enough colors 34 | # mpl.style.use('default') 35 | if len(mpl.rcParams["axes.prop_cycle"].by_key()["color"]) >= length: 36 | cc = mpl.rcParams["axes.prop_cycle"]() 37 | palette = [ 38 | mpl.colors.rgb2hex(next(cc)["color"]) for _ in range(length) 39 | ] 40 | else: 41 | if length <= 20: 42 | palette = default_20 43 | elif length <= 28: 44 | palette = default_28 45 | elif length <= len(default_102): # 103 colors 46 | palette = default_102 47 | else: 48 | rgb_rainbow = mpl.cm.rainbow(np.linspace(0, 1, length)) 49 | palette = [ 50 | mpl.colors.rgb2hex(rgb_rainbow[i, :-1]) 51 | for i in range(length) 52 | ] 53 | colors = pd.Series([""] * len(arr)) 54 | for i, x in enumerate(categories): 55 | ids = np.where(arr == x)[0] 56 | colors[ids] = palette[i] 57 | colors = list(colors) 58 | else: 59 | raise TypeError("unsupported data type for `arr`") 60 | return colors 61 | 62 | 63 | 
def generate_palette(arr): 64 | """Generate a color palette for a given array.""" 65 | 66 | if not isinstance(arr, (pd.Series, np.ndarray)): 67 | raise TypeError("`arr` must be pd.Series or np.ndarray") 68 | colors = [] 69 | if is_string_dtype(arr) or is_categorical_dtype(arr): 70 | categories = np.unique(arr) 71 | length = len(categories) 72 | # check if default matplotlib palette has enough colors 73 | # mpl.style.use('default') 74 | if len(mpl.rcParams["axes.prop_cycle"].by_key()["color"]) >= length: 75 | cc = mpl.rcParams["axes.prop_cycle"]() 76 | palette = [ 77 | mpl.colors.rgb2hex(next(cc)["color"]) for _ in range(length) 78 | ] 79 | else: 80 | if length <= 20: 81 | palette = default_20 82 | elif length <= 28: 83 | palette = default_28 84 | elif length <= len(default_102): # 103 colors 85 | palette = default_102 86 | else: 87 | rgb_rainbow = mpl.cm.rainbow(np.linspace(0, 1, length)) 88 | palette = [ 89 | mpl.colors.rgb2hex(rgb_rainbow[i, :-1]) 90 | for i in range(length) 91 | ] 92 | colors = pd.Series([""] * len(arr)) 93 | for i, x in enumerate(categories): 94 | ids = np.where(arr == x)[0] 95 | colors[ids] = palette[i] 96 | colors = list(colors) 97 | else: 98 | raise TypeError("unsupported data type for `arr`") 99 | dict_palette = dict(zip(arr, colors)) 100 | return dict_palette 101 | -------------------------------------------------------------------------------- /stream2/tools/_pseudotime.py: -------------------------------------------------------------------------------- 1 | """Pseudotime inference.""" 2 | 3 | import numpy as np 4 | import networkx as nx 5 | 6 | 7 | def infer_pseudotime( 8 | adata, 9 | source, 10 | target=None, 11 | nodes_to_include=None, 12 | key="epg", 13 | copy=False, 14 | ): 15 | """Infer pseudotime 16 | Parameters 17 | ---------- 18 | adata: AnnData 19 | Annotated data matrix. 20 | copy: `bool`, optional (default: False) 21 | If ``True``, return a copy instead of writing to adata. 
22 | Returns 23 | ------- 24 | """ 25 | 26 | epg_edge = adata.uns[key]["edge"] 27 | epg_edge_len = adata.uns[key]["edge_len"] 28 | G = nx.Graph() 29 | edges_weighted = list(zip(epg_edge[:, 0], epg_edge[:, 1], epg_edge_len)) 30 | G.add_weighted_edges_from(edges_weighted, weight="len") 31 | if target is not None: 32 | if nodes_to_include is None: 33 | # nodes on the shortest path 34 | nodes_sp = nx.shortest_path( 35 | G, source=source, target=target, weight="len" 36 | ) 37 | else: 38 | assert isinstance( 39 | nodes_to_include, list 40 | ), "`nodes_to_include` must be list" 41 | # lists of simple paths, in order from shortest to longest 42 | list_paths = list( 43 | nx.shortest_simple_paths( 44 | G, source=source, target=target, weight="len" 45 | ) 46 | ) 47 | flag_exist = False 48 | for p in list_paths: 49 | if set(nodes_to_include).issubset(p): 50 | nodes_sp = p 51 | flag_exist = True 52 | break 53 | if not flag_exist: 54 | return f"no path that passes {nodes_to_include} exists" 55 | else: 56 | nodes_sp = [source] + [v for u, v in nx.bfs_edges(G, source)] 57 | G_sp = G.subgraph(nodes_sp).copy() 58 | index_nodes = { 59 | x: nodes_sp.index(x) if x in nodes_sp else G.number_of_nodes() 60 | for x in G.nodes 61 | } 62 | 63 | if target is None: 64 | dict_dist_to_source = nx.shortest_path_length( 65 | G_sp, source=source, weight="len" 66 | ) 67 | else: 68 | dict_dist_to_source = dict( 69 | zip( 70 | nodes_sp, 71 | np.cumsum( 72 | np.array( 73 | [0.0] 74 | + [ 75 | G.get_edge_data(nodes_sp[i], nodes_sp[i + 1])[ 76 | "len" 77 | ] 78 | for i in range(len(nodes_sp) - 1) 79 | ] 80 | ) 81 | ), 82 | ) 83 | ) 84 | 85 | cells = np.isin(adata.obs[f"{key}_node_id"], nodes_sp) 86 | id_edges_cell = adata.obs.loc[cells, f"{key}_edge_id"].tolist() 87 | edges_cell = adata.uns[key]["edge"][id_edges_cell, :] 88 | len_edges_cell = adata.uns[key]["edge_len"][id_edges_cell] 89 | 90 | # proportion on the edge 91 | prop_edge = np.clip( 92 | adata.obs.loc[cells, f"{key}_edge_loc"], a_min=0, a_max=1 93 | ).values 94 | 95 | dist_to_source = [] 96 | for i in np.arange(edges_cell.shape[0]): 97 | if index_nodes[edges_cell[i, 0]] > index_nodes[edges_cell[i, 1]]: 98 | dist_to_source.append(dict_dist_to_source[edges_cell[i, 1]]) 99 | prop_edge[i] = 1 - prop_edge[i] 100 | else: 101 | dist_to_source.append(dict_dist_to_source[edges_cell[i, 0]]) 102 | dist_to_source = np.array(dist_to_source) 103 | dist_on_edge = len_edges_cell * prop_edge 104 | dist = dist_to_source + dist_on_edge 105 | 106 | if copy: 107 | return dist 108 | else: 109 | adata.obs[f"{key}_pseudotime"] = np.nan 110 | adata.obs.loc[cells, f"{key}_pseudotime"] = dist 111 | adata.uns[f"{key}_pseudotime_params"] = { 112 | "source": source, 113 | "target": target, 114 | "nodes_to_include": nodes_to_include, 115 | } 116 | -------------------------------------------------------------------------------- /stream2/plotting/_palettes.py: -------------------------------------------------------------------------------- 1 | """Color palettes in addition to matplotlib's palettes. 
2 | 3 | This is modifed from 4 | scanpy palettes https://github.com/theislab/scanpy/blob/master/scanpy/plotting/palettes.py # noqa 5 | """ 6 | 7 | from matplotlib import cm, colors 8 | 9 | # Colorblindness adjusted vega_10 10 | # See https://github.com/theislab/scanpy/issues/387 11 | vega_10 = list(map(colors.to_hex, cm.tab10.colors)) 12 | vega_10_scanpy = vega_10.copy() 13 | vega_10_scanpy[2] = "#279e68" # green 14 | vega_10_scanpy[4] = "#aa40fc" # purple 15 | vega_10_scanpy[8] = "#b5bd61" # kakhi 16 | 17 | # default matplotlib 2.0 palette 18 | # see 'category20' on https://github.com/vega/vega/wiki/Scales#scale-range-literals # noqa 19 | vega_20 = list(map(colors.to_hex, cm.tab20.colors)) 20 | 21 | # reorderd, some removed, some added 22 | vega_20_scanpy = [ 23 | *vega_20[0:14:2], 24 | *vega_20[16::2], # dark without grey 25 | *vega_20[1:15:2], 26 | *vega_20[17::2], # light without grey 27 | "#ad494a", 28 | "#8c6d31", # manual additions 29 | ] 30 | vega_20_scanpy[2] = vega_10_scanpy[2] 31 | vega_20_scanpy[4] = vega_10_scanpy[4] 32 | vega_20_scanpy[7] = vega_10_scanpy[8] # kakhi shifted by missing grey 33 | # TODO: also replace pale colors if necessary 34 | 35 | default_20 = vega_20_scanpy 36 | 37 | # https://graphicdesign.stackexchange.com/questions/3682/where-can-i-find-a-large-palette-set-of-contrasting-colors-for-coloring-many-d 38 | # update 1 39 | # orig reference http://epub.wu.ac.at/1692/1/document.pdf 40 | zeileis_28 = [ 41 | "#023fa5", 42 | "#7d87b9", 43 | "#bec1d4", 44 | "#d6bcc0", 45 | "#bb7784", 46 | "#8e063b", 47 | "#4a6fe3", 48 | "#8595e1", 49 | "#b5bbe3", 50 | "#e6afb9", 51 | "#e07b91", 52 | "#d33f6a", 53 | "#11c638", 54 | "#8dd593", 55 | "#c6dec7", 56 | "#ead3c6", 57 | "#f0b98d", 58 | "#ef9708", 59 | "#0fcfc0", 60 | "#9cded6", 61 | "#d5eae7", 62 | "#f3e1eb", 63 | "#f6c4e1", 64 | "#f79cd4", 65 | "#7f7f7f", 66 | "#c7c7c7", 67 | "#1CE6FF", 68 | "#336600", # these last ones were added, 69 | ] 70 | 71 | default_28 = zeileis_28 72 | 73 | # from http://godsnotwheregodsnot.blogspot.de/2012/09/color-distribution-methodology.html # noqa 74 | godsnot_102 = [ 75 | # "#000000", 76 | # remove the black, as often, we have black colored annotation 77 | "#FFFF00", 78 | "#1CE6FF", 79 | "#FF34FF", 80 | "#FF4A46", 81 | "#008941", 82 | "#006FA6", 83 | "#A30059", 84 | "#FFDBE5", 85 | "#7A4900", 86 | "#0000A6", 87 | "#63FFAC", 88 | "#B79762", 89 | "#004D43", 90 | "#8FB0FF", 91 | "#997D87", 92 | "#5A0007", 93 | "#809693", 94 | "#6A3A4C", 95 | "#1B4400", 96 | "#4FC601", 97 | "#3B5DFF", 98 | "#4A3B53", 99 | "#FF2F80", 100 | "#61615A", 101 | "#BA0900", 102 | "#6B7900", 103 | "#00C2A0", 104 | "#FFAA92", 105 | "#FF90C9", 106 | "#B903AA", 107 | "#D16100", 108 | "#DDEFFF", 109 | "#000035", 110 | "#7B4F4B", 111 | "#A1C299", 112 | "#300018", 113 | "#0AA6D8", 114 | "#013349", 115 | "#00846F", 116 | "#372101", 117 | "#FFB500", 118 | "#C2FFED", 119 | "#A079BF", 120 | "#CC0744", 121 | "#C0B9B2", 122 | "#C2FF99", 123 | "#001E09", 124 | "#00489C", 125 | "#6F0062", 126 | "#0CBD66", 127 | "#EEC3FF", 128 | "#456D75", 129 | "#B77B68", 130 | "#7A87A1", 131 | "#788D66", 132 | "#885578", 133 | "#FAD09F", 134 | "#FF8A9A", 135 | "#D157A0", 136 | "#BEC459", 137 | "#456648", 138 | "#0086ED", 139 | "#886F4C", 140 | "#34362D", 141 | "#B4A8BD", 142 | "#00A6AA", 143 | "#452C2C", 144 | "#636375", 145 | "#A3C8C9", 146 | "#FF913F", 147 | "#938A81", 148 | "#575329", 149 | "#00FECF", 150 | "#B05B6F", 151 | "#8CD0FF", 152 | "#3B9700", 153 | "#04F757", 154 | "#C8A1A1", 155 | "#1E6E00", 156 | "#7900D7", 157 | "#A77500", 158 | 
"#6367A9", 159 | "#A05837", 160 | "#6B002C", 161 | "#772600", 162 | "#D790FF", 163 | "#9B9700", 164 | "#549E79", 165 | "#FFF69F", 166 | "#201625", 167 | "#72418F", 168 | "#BC23FF", 169 | "#99ADC0", 170 | "#3A2465", 171 | "#922329", 172 | "#5B4534", 173 | "#FDE8DC", 174 | "#404E55", 175 | "#0089A3", 176 | "#CB7E98", 177 | "#A4E804", 178 | "#324E72", 179 | ] 180 | 181 | default_102 = godsnot_102 182 | -------------------------------------------------------------------------------- /stream2/_utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions and classes.""" 2 | 3 | from kneed import KneeLocator 4 | from copy import deepcopy 5 | import networkx as nx 6 | import numpy as np 7 | import pandas as pd 8 | import scipy 9 | 10 | 11 | def locate_elbow( 12 | x, 13 | y, 14 | S=10, 15 | min_elbow=0, 16 | curve="convex", 17 | direction="decreasing", 18 | online=False, 19 | **kwargs, 20 | ): 21 | """Detect knee points 22 | Parameters 23 | ---------- 24 | x : `array-like` 25 | x values 26 | y : `array-like` 27 | y values 28 | S : `float`, optional (default: 10) 29 | Sensitivity 30 | min_elbow: `int`, optional (default: 0) 31 | The minimum elbow location 32 | curve: `str`, optional (default: 'convex') 33 | Choose from {'convex','concave'} 34 | If 'concave', algorithm will detect knees, 35 | If 'convex', algorithm will detect elbows. 36 | direction: `str`, optional (default: 'decreasing') 37 | Choose from {'decreasing','increasing'} 38 | online: `bool`, optional (default: False) 39 | kneed will correct old knee points if True, 40 | kneed will return first knee if False. 41 | **kwargs: `dict`, optional 42 | Extra arguments to KneeLocator. 43 | Returns 44 | ------- 45 | elbow: `int` 46 | elbow point 47 | """ 48 | kneedle = KneeLocator( 49 | x[int(min_elbow) :], 50 | y[int(min_elbow) :], 51 | S=S, 52 | curve=curve, 53 | direction=direction, 54 | online=online, 55 | **kwargs, 56 | ) 57 | if kneedle.elbow is None: 58 | elbow = len(y) 59 | else: 60 | elbow = int(kneedle.elbow) 61 | return elbow 62 | 63 | 64 | def get_path( 65 | adata, source=None, target=None, nodes_to_include=None, key="epg" 66 | ): 67 | #### Extract cells by provided nodes 68 | 69 | epg_edge = adata.uns[key]["edge"] 70 | epg_edge_len = adata.uns[key]["edge_len"] 71 | G = nx.Graph() 72 | edges_weighted = list(zip(epg_edge[:, 0], epg_edge[:, 1], epg_edge_len)) 73 | G.add_weighted_edges_from(edges_weighted, weight="len") 74 | 75 | if source is None: 76 | source = adata.uns[f"{key}_pseudotime_params"]["source"] 77 | if target is None: 78 | target = adata.uns[f"{key}_pseudotime_params"]["target"] 79 | if nodes_to_include is None: 80 | nodes_to_include = adata.uns[f"{key}_pseudotime_params"][ 81 | "nodes_to_include" 82 | ] 83 | 84 | if target is not None: 85 | if nodes_to_include is None: 86 | # nodes on the shortest path 87 | nodes_sp = nx.shortest_path( 88 | G, source=source, target=target, weight="len" 89 | ) 90 | else: 91 | assert isinstance( 92 | nodes_to_include, list 93 | ), "`nodes_to_include` must be list" 94 | # lists of simple paths, in order from shortest to longest 95 | list_paths = list( 96 | nx.shortest_simple_paths( 97 | G, source=source, target=target, weight="len" 98 | ) 99 | ) 100 | flag_exist = False 101 | for p in list_paths: 102 | if set(nodes_to_include).issubset(p): 103 | nodes_sp = p 104 | flag_exist = True 105 | break 106 | if not flag_exist: 107 | print(f"no path that passes {nodes_to_include} exists") 108 | else: 109 | nodes_sp = [source] + [v for u, v in 
nx.bfs_edges(G, source)] 110 | 111 | cells = adata.obs_names[np.isin(adata.obs[f"{key}_node_id"], nodes_sp)] 112 | path_alias = "Path_%s-%s-%s" % (source, nodes_to_include, target) 113 | print( 114 | len(cells), 115 | "Cells are selected for Path_Source_Nodes-to-include_Target : ", 116 | path_alias, 117 | ) 118 | return cells, path_alias 119 | 120 | 121 | def get_expdata( 122 | adata, source=None, target=None, nodes_to_include=None, key="epg" 123 | ): 124 | cells, path_alias = get_path(adata, source, target, nodes_to_include, key) 125 | 126 | if scipy.sparse.issparse(adata.X): 127 | mat = adata.X.todense() 128 | else: 129 | mat = adata.X 130 | 131 | df_sc = pd.DataFrame( 132 | index=adata.obs_names.tolist(), 133 | data=mat, 134 | columns=adata.var.index.tolist(), 135 | ) 136 | df_cells = deepcopy(df_sc.loc[cells]) 137 | df_cells[f"{key}_pseudotime"] = adata.obs[f"{key}_pseudotime"][cells] 138 | df_cells_sort = df_cells.sort_values( 139 | by=[f"{key}_pseudotime"], ascending=True 140 | ) 141 | 142 | return df_cells_sort, path_alias 143 | 144 | 145 | def stream2elpi(adata, key="epg"): 146 | PG = { 147 | "NodePositions": adata.uns[key]["node_pos"].astype(float), 148 | "Edges": [ 149 | adata.uns[key]["edge"], 150 | np.repeat( 151 | adata.uns[key]["params"]["epg_lambda"], 152 | len(adata.uns[key]["node_pos"]), 153 | ), 154 | ], 155 | "Lambda": adata.uns[key]["params"]["epg_lambda"], 156 | "Mu": adata.uns[key]["params"]["epg_mu"], 157 | "projection": {"edge_len": adata.uns[key]["edge_len"]}, 158 | } 159 | return PG 160 | -------------------------------------------------------------------------------- /stream2/preprocessing/_pca.py: -------------------------------------------------------------------------------- 1 | """Principal component analysis.""" 2 | 3 | import numpy as np 4 | from sklearn.decomposition import TruncatedSVD 5 | from ._utils import ( 6 | locate_elbow, 7 | ) 8 | 9 | 10 | def pca( 11 | adata, 12 | n_components=50, 13 | algorithm="randomized", 14 | n_iter=5, 15 | random_state=2021, 16 | tol=0.0, 17 | feature=None, 18 | **kwargs, 19 | ): 20 | """perform Principal Component Analysis (PCA) 21 | 22 | Parameters 23 | ---------- 24 | adata: AnnData 25 | Annotated data matrix. 26 | n_components: `int`, optional (default: 50) 27 | Desired dimensionality of output data 28 | algorithm: `str`, optional (default: 'randomized') 29 | SVD solver to use. Choose from {'arpack', 'randomized'}. 30 | n_iter: `int`, optional (default: '5') 31 | Number of iterations for randomized SVD solver. 32 | Not used by ARPACK. 33 | tol: `float`, optional (default: 0) 34 | Tolerance for ARPACK. 0 means machine precision. 35 | Ignored by randomized SVD solver. 36 | feature: `str`, optional (default: None) 37 | Feature used to perform PCA. 38 | The data type of `.var[feature]` needs to be `bool` 39 | If None, adata.X will be used. 40 | kwargs: 41 | Other keyword arguments are passed down to `TruncatedSVD()` 42 | 43 | Returns 44 | ------- 45 | updates `adata` with the following fields: 46 | `.obsm['X_pca']` : `array` 47 | PCA transformed X. 48 | `.uns['pca']['PCs']` : `array` 49 | Principal components in feature space, 50 | representing the directions of maximum variance in the data. 51 | `.uns['pca']['variance']` : `array` 52 | The variance of the training samples transformed by a 53 | projection to each component. 54 | `.uns['pca']['variance_ratio']` : `array` 55 | Percentage of variance explained by each of the selected components. 
56 | """ 57 | if feature is None: 58 | X = adata.X.copy() 59 | else: 60 | mask = adata.var[feature] 61 | X = adata[:, mask].X.copy() 62 | svd = TruncatedSVD( 63 | n_components=n_components, 64 | algorithm=algorithm, 65 | n_iter=n_iter, 66 | random_state=random_state, 67 | tol=tol, 68 | **kwargs, 69 | ) 70 | svd.fit(X) 71 | adata.obsm["X_pca"] = svd.transform(X) 72 | adata.uns["pca"] = dict() 73 | adata.uns["pca"]["n_pcs"] = n_components 74 | adata.uns["pca"]["PCs"] = svd.components_.T 75 | adata.uns["pca"]["variance"] = svd.explained_variance_ 76 | adata.uns["pca"]["variance_ratio"] = svd.explained_variance_ratio_ 77 | 78 | 79 | def select_pcs( 80 | adata, 81 | n_pcs=None, 82 | S=1, 83 | curve="convex", 84 | direction="decreasing", 85 | online=False, 86 | min_elbow=None, 87 | **kwargs, 88 | ): 89 | """select top PCs based on variance_ratio.""" 90 | if n_pcs is None: 91 | n_components = adata.obsm["X_pca"].shape[1] 92 | if min_elbow is None: 93 | min_elbow = n_components / 10 94 | n_pcs = locate_elbow( 95 | range(n_components), 96 | adata.uns["pca"]["variance_ratio"], 97 | S=S, 98 | curve=curve, 99 | min_elbow=min_elbow, 100 | direction=direction, 101 | online=online, 102 | **kwargs, 103 | ) 104 | adata.uns["pca"]["n_pcs"] = n_pcs 105 | else: 106 | adata.uns["pca"]["n_pcs"] = n_pcs 107 | 108 | 109 | def select_pcs_features( 110 | adata, 111 | S=1, 112 | curve="convex", 113 | direction="decreasing", 114 | online=False, 115 | min_elbow=None, 116 | **kwargs, 117 | ): 118 | """select features that contribute to the top PCs. 119 | 120 | S : `float`, optional (default: 10) 121 | Sensitivity 122 | min_elbow: `int`, optional (default: 0) 123 | The minimum elbow location 124 | curve: `str`, optional (default: 'convex') 125 | Choose from {'convex','concave'} 126 | If 'concave', algorithm will detect knees, 127 | If 'convex', algorithm will detect elbows. 128 | direction: `str`, optional (default: 'decreasing') 129 | Choose from {'decreasing','increasing'} 130 | online: `bool`, optional (default: False) 131 | kneed will correct old knee points if True, 132 | kneed will return first knee if False. 133 | **kwargs: `dict`, optional 134 | Extra arguments to KneeLocator. 
135 | """ 136 | n_pcs = adata.uns["pca"]["n_pcs"] 137 | n_features = adata.uns["pca"]["PCs"].shape[0] 138 | if min_elbow is None: 139 | min_elbow = n_features / 6 140 | adata.uns["pca"]["features"] = dict() 141 | ids_features = list() 142 | for i in range(n_pcs): 143 | elbow = locate_elbow( 144 | range(n_features), 145 | np.sort( 146 | np.abs( 147 | adata.uns["pca"]["PCs"][:, i], 148 | ) 149 | )[::-1], 150 | S=S, 151 | min_elbow=min_elbow, 152 | curve=curve, 153 | direction=direction, 154 | online=online, 155 | **kwargs, 156 | ) 157 | ids_features_i = list( 158 | np.argsort(np.abs(adata.uns["pca"]["PCs"][:, i],))[ 159 | ::-1 160 | ][:elbow] 161 | ) 162 | adata.uns["pca"]["features"][f"pc_{i}"] = ids_features_i 163 | ids_features = ids_features + ids_features_i 164 | print(f"#features selected from PC {i}: {len(ids_features_i)}") 165 | adata.var["top_pcs"] = False 166 | adata.var.loc[adata.var_names[np.unique(ids_features)], "top_pcs"] = True 167 | print(f'#features in total: {adata.var["top_pcs"].sum()}') 168 | -------------------------------------------------------------------------------- /stream2/tools/_markers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numba as nb 3 | import pandas as pd 4 | import math 5 | import scipy 6 | import os 7 | from copy import deepcopy 8 | from statsmodels.sandbox.stats.multicomp import multipletests 9 | 10 | from .. import _utils 11 | from .._settings import settings 12 | 13 | 14 | def spearman_columns(A, B): 15 | """Spearman correlation over columns 16 | A,B: np.arrays with same shape 17 | 18 | Returns 19 | ------- 20 | correlations: np.array 21 | correlations[i] = spearman_corrcoef(A[:,i],B[:,i]) 22 | """ 23 | assert A.shape == B.shape 24 | return pearson_corr(rankdata(A.T), rankdata(B.T)) 25 | 26 | 27 | def spearman_pairwise(A, B): 28 | """Spearman correlation matrix 29 | A,B: np.arrays with same shape 30 | 31 | Returns 32 | ------- 33 | correlations: np.array 34 | correlations[i,j] = spearman_corrcoef(A[:,i],B[:,j]) 35 | """ 36 | n, m = A.shape[1], B.shape[1] 37 | i, j = np.ones((n, m)).nonzero() 38 | return pearson_corr(rankdata(A.T[i]), rankdata(B.T[j])).reshape(n, m) 39 | 40 | 41 | @nb.njit(parallel=True) 42 | def xicorr_columns(A, B): 43 | """XI correlation over columns 44 | A,B: 2d np.arrays with same shape 45 | 46 | Returns 47 | ------- 48 | correlations: 49 | correlations[i] = xi_corrcoef(A[:,i],B[:,i]) 50 | """ 51 | assert A.shape == B.shape 52 | n, m = A.shape 53 | corrs = np.zeros(m) 54 | pvals = np.zeros(m) 55 | for i in nb.prange(m): 56 | corrs[i], pvals[i] = xicorr(A[:, i], B[:, i], n) 57 | return corrs, pvals 58 | 59 | 60 | @nb.njit(parallel=True) 61 | def xicorr_pairwise(A, B): 62 | """XI correlation over columns 63 | A,B: 2d np.arrays with same shape 64 | 65 | Returns 66 | ------- 67 | correlations: 68 | correlations[i] = xi_corrcoef(A[:,i],B[:,i]) 69 | """ 70 | assert A.shape == B.shape 71 | ns = len(A) 72 | 73 | n, m = A.shape[1], B.shape[1] 74 | 75 | corrs = np.ones((n, m)) 76 | pvals = np.ones((n, m)) 77 | for i in nb.prange(n): 78 | for j in nb.prange(m): 79 | corrs[i, j], pvals[i, j] = xicorr(A[:, i], B[:, j], ns) 80 | return corrs, pvals 81 | 82 | 83 | @nb.njit 84 | def _nb_unique1d(ar): 85 | """Numba speedup.""" 86 | ar = ar.flatten() 87 | perm = ar.argsort(kind="mergesort") 88 | aux = ar[perm] 89 | 90 | mask = np.empty(aux.shape, dtype=np.bool_) 91 | mask[:1] = True 92 | if aux.shape[0] > 0 and aux.dtype.kind in "cfmM" and np.isnan(aux[-1]): 93 | if ( 94 
| aux.dtype.kind == "c" 95 | ): # for complex all NaNs are considered equivalent 96 | aux_firstnan = np.searchsorted(np.isnan(aux), True, side="left") 97 | else: 98 | aux_firstnan = np.searchsorted(aux, aux[-1], side="left") 99 | mask[1:aux_firstnan] = aux[1:aux_firstnan] != aux[: aux_firstnan - 1] 100 | mask[aux_firstnan] = True 101 | mask[aux_firstnan + 1 :] = False 102 | else: 103 | mask[1:] = aux[1:] != aux[:-1] 104 | 105 | imask = np.cumsum(mask) - 1 106 | inv_idx = np.empty(mask.shape, dtype=np.intp) 107 | inv_idx[perm] = imask 108 | idx = np.append(np.nonzero(mask)[0], mask.size) 109 | 110 | # idx #inverse #counts 111 | return aux[mask], perm[mask], inv_idx, np.diff(idx) 112 | 113 | 114 | @nb.njit 115 | def normal_cdf(x): 116 | """Cumulative distribution function for the standard normal distribution""" 117 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 118 | 119 | 120 | @nb.njit 121 | def average_ties(X): 122 | """Same as scipy.stats.rankdata method="average".""" 123 | xi = np.argsort(X) 124 | xi_rank = np.argsort(xi) 125 | unique, _, inverse, c_ = _nb_unique1d(X) 126 | unique_rank_sum = np.zeros_like(unique) 127 | for i0, inv in enumerate(inverse): 128 | unique_rank_sum[inv] += xi_rank[i0] 129 | unique_count = np.zeros_like(unique) 130 | for i0, inv in enumerate(inverse): 131 | unique_count[inv] += 1 132 | unique_rank_mean = unique_rank_sum / unique_count 133 | rank_mean = unique_rank_mean[inverse] 134 | return rank_mean + 1 135 | 136 | 137 | def average_stat( 138 | mdata, 139 | transition_markers, 140 | ): 141 | """Average correlation coefficients 142 | doi.org/10.3758/BF03334037 143 | """ 144 | counts = pd.Series( 145 | { 146 | assay: sum(~np.isnan(adata.obs["epg_pseudotime"])) 147 | for assay, adata in mdata.mod.items() 148 | } 149 | ) 150 | 151 | joint_index = set.intersection( 152 | *[set(df.index) for df in transition_markers.values()] 153 | ) 154 | joint_dfs = { 155 | assay: df.loc[joint_index] for assay, df in transition_markers.items() 156 | } 157 | 158 | joint_stats = pd.concat( 159 | [df["stat"] for assay, df in joint_dfs.items()], axis=1 160 | ) 161 | joint_stats.columns = joint_dfs.keys() 162 | 163 | avg_stat = np.tanh( 164 | np.sum( 165 | np.arctanh(joint_stats) 166 | * (counts - 3) 167 | / (sum(counts) - 3 * len(counts)), 168 | axis=1, 169 | ) 170 | ) 171 | return avg_stat 172 | 173 | 174 | @nb.njit 175 | def xicorr(x, y, n): 176 | """Translated from R https://github.com/cran/XICOR/""" 177 | # ---corr 178 | PI = average_ties(x) 179 | fr = average_ties(y) / n 180 | gr = average_ties(-y) / n 181 | fr = fr[np.argsort(PI, kind="mergesort")] 182 | 183 | CU = np.mean(gr * (1 - gr)) 184 | A1 = np.abs(np.diff(fr)).sum() / (2 * n) 185 | xi = 1 - A1 / CU 186 | 187 | # ---pval 188 | qfr = np.sort(fr) 189 | ind = np.arange(n) + 1 190 | ind2 = np.array([2 * n - 2 * ind[i - 1] + 1 for i in ind]) 191 | 192 | ai = np.mean(ind2 * qfr * qfr) / n 193 | ci = np.mean(ind2 * qfr) / n 194 | cq = np.cumsum(qfr) 195 | 196 | m = (cq + (n - ind) * qfr) / n 197 | b = np.mean(m ** 2) 198 | v = (ai - 2 * b + np.square(ci)) / np.square(CU) 199 | 200 | # sd = np.sqrt(v/n) 201 | pval = 1 - normal_cdf(np.sqrt(n) * xi / np.sqrt(v)) 202 | return xi, pval 203 | 204 | 205 | @nb.njit(parallel=True) 206 | def xicorr_ps(x, Y): 207 | """Fast xi correlation coefficient 208 | x: 0d np.array 209 | Y: 2d np.array 210 | """ 211 | n = len(Y) 212 | corrs = np.zeros(Y.shape[1]) 213 | pvals = np.zeros(Y.shape[1]) 214 | for i in nb.prange(Y.shape[1]): 215 | corrs[i], pvals[i] = xicorr(x,Y[:, i], n) 216 | return 
corrs, pvals 217 | 218 | 219 | def spearman_ps(x, Y): 220 | """Fast spearman correlation coefficient 221 | X: 2d np.array 222 | y: 0d np.array 223 | """ 224 | return pearson_corr(rankdata(x[None]), rankdata(Y.T)) 225 | 226 | 227 | def pearson_corr(arr1, arr2): 228 | """Pearson correlation along the last dimension of two multidimensional 229 | arrays.""" 230 | mean1 = np.mean(arr1, axis=-1, keepdims=1) 231 | mean2 = np.mean(arr2, axis=-1, keepdims=1) 232 | dev1, dev2 = arr1 - mean1, arr2 - mean2 233 | sqdev1, sqdev2 = np.square(dev1), np.square(dev2) 234 | numer = np.sum(dev1 * dev2, axis=-1) # Covariance 235 | var1, var2 = np.sum(sqdev1, axis=-1), np.sum(sqdev2, axis=-1) # Variances 236 | denom = np.sqrt(var1 * var2) 237 | 238 | # Divide numerator by denominator, but use NaN where the denominator is 0 239 | return np.divide( 240 | numer, denom, out=np.full_like(numer, np.nan), where=(denom != 0) 241 | ) 242 | 243 | 244 | @nb.njit(parallel=True, fastmath=True) 245 | def rankdata(X): 246 | """reimplementing scipy.stats.rankdata faster.""" 247 | tmp = np.zeros_like(X) 248 | for i in nb.prange(X.shape[0]): 249 | tmp[i] = _rankdata_inner(X[i]) 250 | return tmp 251 | 252 | 253 | @nb.njit 254 | def _rankdata_inner(x): 255 | """inner loop for rankdata.""" 256 | sorter = np.argsort(x) 257 | 258 | inv = np.empty(sorter.size, dtype=np.intp) 259 | inv[sorter] = np.arange(sorter.size, dtype=np.intp) 260 | 261 | x = x[sorter] 262 | obs = np.concatenate((np.array([True]), x[1:] != x[:-1])) 263 | dense = obs.cumsum()[inv] 264 | 265 | # cumulative counts of each unique value 266 | count = np.append(np.nonzero(obs)[0], len(obs)) 267 | # average method 268 | return 0.5 * (count[dense] + count[dense - 1] + 1) 269 | 270 | 271 | def p_val(r, n): 272 | t = r * np.sqrt((n - 2) / (1 - r ** 2)) 273 | return scipy.stats.t.sf(np.abs(t), n - 1) * 2 274 | 275 | 276 | def scale_marker_expr(df_marker_detection, percentile_expr): 277 | # optimal version for STREAM1 278 | ind_neg = df_marker_detection.min() < 0 279 | ind_pos = df_marker_detection.min() >= 0 280 | df_neg = df_marker_detection.loc[:, ind_neg] 281 | df_pos = df_marker_detection.loc[:, ind_pos] 282 | 283 | if ind_neg.sum() > 0: 284 | print("Matrix contains negative values...") 285 | # genes with negative values 286 | minValues = df_neg.apply( 287 | lambda x: np.percentile(x[x < 0], 100 - percentile_expr), axis=0 288 | ) 289 | maxValues = df_neg.apply( 290 | lambda x: np.percentile(x[x > 0], percentile_expr), axis=0 291 | ) 292 | for i in range(df_neg.shape[1]): 293 | df_gene = df_neg.iloc[:, i].copy(deep=True) 294 | df_gene[df_gene < minValues[i]] = minValues[i] 295 | df_gene[df_gene > maxValues[i]] = maxValues[i] 296 | df_neg.iloc[:, i] = df_gene - minValues[i] 297 | df_neg = df_neg.copy(deep=True) 298 | maxValues = df_neg.max(axis=0) 299 | df_neg_scaled = df_neg / np.array(maxValues)[:, None].T 300 | else: 301 | df_neg_scaled = pd.DataFrame(index=df_neg.index) 302 | 303 | if ind_pos.sum() > 0: 304 | maxValues = df_pos.apply( 305 | lambda x: np.percentile(x[x > 0], percentile_expr), axis=0 306 | ) 307 | df_pos_scaled = df_pos / np.array(maxValues)[:, None].T 308 | df_pos_scaled[df_pos_scaled > 1] = 1 309 | else: 310 | df_pos_scaled = pd.DataFrame(index=df_pos.index) 311 | 312 | df_marker_detection_scaled = pd.concat( 313 | [df_neg_scaled, df_pos_scaled], axis=1 314 | ) 315 | 316 | return df_marker_detection_scaled 317 | 318 | 319 | def detect_transition_markers( 320 | adata, 321 | percentile_expr=95, 322 | min_num_cells=5, 323 | fc_cutoff=1, 324 | 
method="spearman", 325 | key="epg", 326 | ): 327 | 328 | file_path = os.path.join(settings.workdir, "transition_markers") 329 | if not os.path.exists(file_path): 330 | os.makedirs(file_path) 331 | 332 | # Extract cells by parameters in previous infer_pseudotime() step 333 | path_source = adata.uns[f"{key}_pseudotime_params"]["source"] 334 | path_target = adata.uns[f"{key}_pseudotime_params"]["target"] 335 | nodes_to_include_path = adata.uns[f"{key}_pseudotime_params"][ 336 | "nodes_to_include" 337 | ] 338 | 339 | if path_target is None: 340 | print( 341 | "Please re-run infer_pseudotime() and specify value for " 342 | "parameter target" 343 | ) 344 | exit() 345 | 346 | cells, path_alias = _utils.get_path( 347 | adata, path_source, path_target, nodes_to_include_path, key 348 | ) 349 | 350 | # Scale matrix with expressed markers 351 | input_markers = adata.var_names.tolist() 352 | if scipy.sparse.issparse(adata.X): 353 | mat = adata[:, input_markers].X.todense() 354 | else: 355 | mat = adata[:, input_markers].X 356 | 357 | df_sc = pd.DataFrame( 358 | index=adata.obs_names.tolist(), 359 | data=mat, 360 | columns=input_markers, 361 | ) 362 | 363 | print( 364 | "Filtering out markers that are expressed in less than " 365 | + str(min_num_cells) 366 | + " cells ..." 367 | ) 368 | input_markers_expressed = np.array(input_markers)[ 369 | np.where((df_sc[input_markers]>0).sum(axis=0) > min_num_cells)[0] 370 | ].tolist() 371 | df_marker_detection = df_sc[input_markers_expressed].copy() 372 | 373 | df_scaled_marker_expr = scale_marker_expr( 374 | df_marker_detection, percentile_expr 375 | ) 376 | adata.uns["scaled_marker_expr"] = df_scaled_marker_expr 377 | 378 | print(str(len(input_markers_expressed)) + " markers are being scanned ...") 379 | 380 | df_cells = deepcopy(df_scaled_marker_expr.loc[cells]) 381 | pseudotime_cells = adata.obs[f"{key}_pseudotime"][cells] 382 | df_cells_sort = df_cells.iloc[np.argsort(pseudotime_cells)] 383 | pseudotime_cells_sort = pseudotime_cells[np.argsort(pseudotime_cells)] 384 | 385 | dict_tg_edges = dict() 386 | 387 | id_initial = range(0, int(df_cells_sort.shape[0] * 0.2)) 388 | id_final = range( 389 | int(df_cells_sort.shape[0] * 0.8), int(df_cells_sort.shape[0] * 1) 390 | ) 391 | values_initial, values_final = ( 392 | df_cells_sort.iloc[id_initial, :], 393 | df_cells_sort.iloc[id_final, :], 394 | ) 395 | diff_initial_final = np.abs( 396 | values_final.mean(axis=0) - values_initial.mean(axis=0) 397 | ) 398 | 399 | # original expression 400 | df_cells_ori = deepcopy(df_marker_detection.loc[cells]) 401 | df_cells_sort_ori = df_cells_ori.iloc[np.argsort(pseudotime_cells)] 402 | values_initial_ori, values_final_ori = ( 403 | df_cells_sort_ori.iloc[id_initial, :], 404 | df_cells_sort_ori.iloc[id_final, :], 405 | ) 406 | 407 | ix_pos = diff_initial_final > 0 408 | logfc = pd.Series( 409 | np.zeros(len(diff_initial_final)), index=diff_initial_final.index 410 | ) 411 | logfc[ix_pos] = np.log2( 412 | ( 413 | np.maximum(values_final.mean(axis=0), values_initial.mean(axis=0)) 414 | + 0.01 415 | ) 416 | / ( 417 | np.minimum(values_final.mean(axis=0), values_initial.mean(axis=0)) 418 | + 0.01 419 | ) 420 | ) 421 | 422 | ix_cutoff = np.array(logfc > fc_cutoff) 423 | 424 | if sum(ix_cutoff) == 0: 425 | print( 426 | "No Transition markers are detected in branch with nodes " 427 | + str(path_source) 428 | + " to " 429 | + str(path_target) 430 | ) 431 | 432 | else: 433 | df_stat_pval_qval = pd.DataFrame( 434 | np.full((sum(ix_cutoff), 8), np.nan), 435 | columns=[ 436 | "stat", 437 | 
"logfc", 438 | "pval", 439 | "qval", 440 | "initial_mean", 441 | "final_mean", 442 | "initial_mean_ori", 443 | "final_mean_ori", 444 | ], 445 | index=df_cells_sort.columns[ix_cutoff], 446 | ) 447 | 448 | if method == "spearman": 449 | df_stat_pval_qval["stat"] = spearman_ps( 450 | np.array(pseudotime_cells_sort), 451 | np.array(df_cells_sort.iloc[:, ix_cutoff]), 452 | ) 453 | df_stat_pval_qval["pval"] = p_val( 454 | df_stat_pval_qval["stat"], len(pseudotime_cells_sort) 455 | ) 456 | elif method == "xi": 457 | # /!\ dont use df_cells_sort 458 | # and pseudotime_cells_sort, breaks xicorr 459 | res = xicorr_ps( 460 | np.array(pseudotime_cells), 461 | np.array(df_cells.iloc[:, ix_cutoff]) 462 | ) 463 | df_stat_pval_qval["stat"] = res[0] 464 | df_stat_pval_qval["pval"] = res[1] 465 | else: 466 | raise ValueError("method must be one of 'spearman', 'xi'") 467 | 468 | df_stat_pval_qval["logfc"] = logfc 469 | p_values = df_stat_pval_qval["pval"] 470 | q_values = multipletests(p_values, method="fdr_bh")[1] 471 | df_stat_pval_qval["qval"] = q_values 472 | df_stat_pval_qval["initial_mean"] = values_initial.mean(axis=0) 473 | df_stat_pval_qval["final_mean"] = values_final.mean(axis=0) 474 | df_stat_pval_qval["initial_mean_ori"] = values_initial_ori.mean(axis=0) 475 | df_stat_pval_qval["final_mean_ori"] = values_final_ori.mean(axis=0) 476 | 477 | dict_tg_edges[path_alias] = df_stat_pval_qval.sort_values(["qval"]) 478 | 479 | dict_tg_edges[path_alias].to_csv( 480 | os.path.join( 481 | file_path, 482 | "transition_markers_path_" 483 | + str(path_source) 484 | + "-" 485 | + str(path_target) 486 | + ".tsv", 487 | ), 488 | sep="\t", 489 | index=True, 490 | ) 491 | 492 | if "transition_markers" in adata.uns.keys(): 493 | adata.uns["transition_markers"].update(dict_tg_edges) 494 | else: 495 | adata.uns["transition_markers"] = dict_tg_edges 496 | -------------------------------------------------------------------------------- /stream2/preprocessing/_qc.py: -------------------------------------------------------------------------------- 1 | """Quality Control.""" 2 | 3 | import numpy as np 4 | from scipy.sparse import ( 5 | issparse, 6 | csr_matrix, 7 | ) 8 | import re 9 | 10 | 11 | def cal_qc(adata, expr_cutoff=1): 12 | """Calculate quality control metrics. 13 | 14 | Parameters 15 | ---------- 16 | adata: AnnData 17 | Annotated data matrix. 18 | expr_cutoff: `float`, optional (default: 1) 19 | Expression cutoff. 20 | If greater than expr_cutoff,the feature is considered 'expressed' 21 | assay: `str`, optional (default: 'rna') 22 | Choose from {'rna','atac'},case insensitive 23 | Returns 24 | ------- 25 | updates `adata` with the following fields. 26 | n_counts: `pandas.Series` (`adata.var['n_counts']`,dtype `int`) 27 | The number of read count each gene has. 28 | n_cells: `pandas.Series` (`adata.var['n_cells']`,dtype `int`) 29 | The number of cells in which each gene is expressed. 30 | pct_cells: `pandas.Series` (`adata.var['pct_cells']`,dtype `float`) 31 | The percentage of cells in which each gene is expressed. 32 | n_counts: `pandas.Series` (`adata.obs['n_counts']`,dtype `int`) 33 | The number of read count each cell has. 34 | n_genes: `pandas.Series` (`adata.obs['n_genes']`,dtype `int`) 35 | The number of genes expressed in each cell. 36 | pct_genes: `pandas.Series` (`adata.obs['pct_genes']`,dtype `float`) 37 | The percentage of genes expressed in each cell. 38 | n_peaks: `pandas.Series` (`adata.obs['n_peaks']`,dtype `int`) 39 | The number of peaks expressed in each cell. 
40 | pct_peaks: `pandas.Series` (`adata.obs['pct_peaks']`,dtype `int`) 41 | The percentage of peaks expressed in each cell. 42 | pct_mt: `pandas.Series` (`adata.obs['pct_mt']`,dtype `float`) 43 | the percentage of counts in mitochondrial genes 44 | """ 45 | 46 | if not issparse(adata.X): 47 | adata.X = csr_matrix(adata.X) 48 | 49 | n_counts = adata.X.sum(axis=0).A1 50 | adata.var["n_counts"] = n_counts 51 | n_samples = (adata.X >= expr_cutoff).sum(axis=0).A1 52 | adata.var["n_samples"] = n_samples 53 | adata.var["pct_samples"] = n_samples / adata.shape[0] 54 | 55 | n_counts = adata.X.sum(axis=1).A1 56 | adata.obs["n_counts"] = n_counts 57 | n_features = (adata.X >= expr_cutoff).sum(axis=1).A1 58 | adata.obs["n_features"] = n_features 59 | adata.obs["pct_features"] = n_features / adata.shape[1] 60 | 61 | 62 | def cal_qc_rna(adata, expr_cutoff=1): 63 | """Calculate quality control metrics. 64 | 65 | Parameters 66 | ---------- 67 | adata: AnnData 68 | Annotated data matrix. 69 | expr_cutoff: `float`, optional (default: 1) 70 | Expression cutoff. 71 | If greater than expr_cutoff,the feature is considered 'expressed' 72 | assay: `str`, optional (default: 'rna') 73 | Choose from {'rna','atac'},case insensitive 74 | Returns 75 | ------- 76 | updates `adata` with the following fields. 77 | n_counts: `pandas.Series` (`adata.var['n_counts']`,dtype `int`) 78 | The number of read count each gene has. 79 | n_cells: `pandas.Series` (`adata.var['n_cells']`,dtype `int`) 80 | The number of cells in which each gene is expressed. 81 | pct_cells: `pandas.Series` (`adata.var['pct_cells']`,dtype `float`) 82 | The percentage of cells in which each gene is expressed. 83 | n_counts: `pandas.Series` (`adata.obs['n_counts']`,dtype `int`) 84 | The number of read count each cell has. 85 | n_genes: `pandas.Series` (`adata.obs['n_genes']`,dtype `int`) 86 | The number of genes expressed in each cell. 87 | pct_genes: `pandas.Series` (`adata.obs['pct_genes']`,dtype `float`) 88 | The percentage of genes expressed in each cell. 89 | n_peaks: `pandas.Series` (`adata.obs['n_peaks']`,dtype `int`) 90 | The number of peaks expressed in each cell. 91 | pct_peaks: `pandas.Series` (`adata.obs['pct_peaks']`,dtype `int`) 92 | The percentage of peaks expressed in each cell. 93 | pct_mt: `pandas.Series` (`adata.obs['pct_mt']`,dtype `float`) 94 | the percentage of counts in mitochondrial genes 95 | """ 96 | 97 | if not issparse(adata.X): 98 | adata.X = csr_matrix(adata.X) 99 | 100 | n_counts = adata.X.sum(axis=0).A1 101 | adata.var["n_counts"] = n_counts 102 | n_cells = (adata.X >= expr_cutoff).sum(axis=0).A1 103 | adata.var["n_cells"] = n_cells 104 | adata.var["pct_cells"] = n_cells / adata.shape[0] 105 | 106 | n_counts = adata.X.sum(axis=1).A1 107 | adata.obs["n_counts"] = n_counts 108 | n_features = (adata.X >= expr_cutoff).sum(axis=1).A1 109 | adata.obs["n_genes"] = n_features 110 | adata.obs["pct_genes"] = n_features / adata.shape[1] 111 | r = re.compile("^MT-", flags=re.IGNORECASE) 112 | mt_genes = list(filter(r.match, adata.var_names)) 113 | if len(mt_genes) > 0: 114 | n_counts_mt = adata[:, mt_genes].X.sum(axis=1).A1 115 | adata.obs["pct_mt"] = n_counts_mt / n_counts 116 | else: 117 | adata.obs["pct_mt"] = 0 118 | 119 | 120 | def cal_qc_atac(adata, expr_cutoff=1): 121 | """Calculate quality control metrics. 122 | 123 | Parameters 124 | ---------- 125 | adata: AnnData 126 | Annotated data matrix. 127 | expr_cutoff: `float`, optional (default: 1) 128 | Expression cutoff. 
129 | If greater than expr_cutoff,the feature is considered 'expressed' 130 | assay: `str`, optional (default: 'rna') 131 | Choose from {'rna','atac'},case insensitive 132 | Returns 133 | ------- 134 | updates `adata` with the following fields. 135 | n_counts: `pandas.Series` (`adata.var['n_counts']`,dtype `int`) 136 | The number of read count each gene has. 137 | n_cells: `pandas.Series` (`adata.var['n_cells']`,dtype `int`) 138 | The number of cells in which each gene is expressed. 139 | pct_cells: `pandas.Series` (`adata.var['pct_cells']`,dtype `float`) 140 | The percentage of cells in which each gene is expressed. 141 | n_counts: `pandas.Series` (`adata.obs['n_counts']`,dtype `int`) 142 | The number of read count each cell has. 143 | n_genes: `pandas.Series` (`adata.obs['n_genes']`,dtype `int`) 144 | The number of genes expressed in each cell. 145 | pct_genes: `pandas.Series` (`adata.obs['pct_genes']`,dtype `float`) 146 | The percentage of genes expressed in each cell. 147 | n_peaks: `pandas.Series` (`adata.obs['n_peaks']`,dtype `int`) 148 | The number of peaks expressed in each cell. 149 | pct_peaks: `pandas.Series` (`adata.obs['pct_peaks']`,dtype `int`) 150 | The percentage of peaks expressed in each cell. 151 | pct_mt: `pandas.Series` (`adata.obs['pct_mt']`,dtype `float`) 152 | the percentage of counts in mitochondrial genes 153 | """ 154 | 155 | if not issparse(adata.X): 156 | adata.X = csr_matrix(adata.X) 157 | 158 | n_counts = adata.X.sum(axis=0).A1 159 | adata.var["n_counts"] = n_counts 160 | n_cells = (adata.X >= expr_cutoff).sum(axis=0).A1 161 | adata.var["n_cells"] = n_cells 162 | adata.var["pct_cells"] = n_cells / adata.shape[0] 163 | 164 | n_counts = adata.X.sum(axis=1).A1 165 | adata.obs["n_counts"] = n_counts 166 | n_features = (adata.X >= expr_cutoff).sum(axis=1).A1 167 | adata.obs["n_peaks"] = n_features 168 | adata.obs["pct_peaks"] = n_features / adata.shape[1] 169 | 170 | 171 | def filter_samples( 172 | adata, 173 | min_n_features=1, 174 | max_n_features=None, 175 | min_pct_features=None, 176 | max_pct_features=None, 177 | min_n_counts=None, 178 | max_n_counts=None, 179 | expr_cutoff=1, 180 | ): 181 | """Filter out samples based on different metrics. 182 | 183 | Parameters 184 | ---------- 185 | adata: AnnData 186 | Annotated data matrix. 187 | min_n_features: `int`, optional (default: None) 188 | Minimum number of features expressed 189 | min_pct_features: `float`, optional (default: None) 190 | Minimum percentage of features expressed 191 | min_n_counts: `int`, optional (default: None) 192 | Minimum number of read count for one cell 193 | expr_cutoff: `float`, optional (default: 1) 194 | Expression cutoff. 195 | If greater than expr_cutoff,the gene is considered 'expressed' 196 | assay: `str`, optional (default: 'rna') 197 | Choose from {{'rna','atac'}},case insensitive 198 | Returns 199 | ------- 200 | updates `adata` with a subset of cells that pass the filtering. 201 | updates `adata` with the following fields if cal_qc() was not performed. 202 | n_counts: `pandas.Series` (`adata.obs['n_counts']`,dtype `int`) 203 | The number of read count each cell has. 204 | n_genes: `pandas.Series` (`adata.obs['n_genes']`,dtype `int`) 205 | The number of genes expressed in each cell. 206 | pct_genes: `pandas.Series` (`adata.obs['pct_genes']`,dtype `float`) 207 | The percentage of genes expressed in each cell. 208 | n_peaks: `pandas.Series` (`adata.obs['n_peaks']`,dtype `int`) 209 | The number of peaks expressed in each cell. 
210 | pct_peaks: `pandas.Series` (`adata.obs['pct_peaks']`,dtype `int`) 211 | The percentage of peaks expressed in each cell. 212 | """ 213 | 214 | if not issparse(adata.X): 215 | adata.X = csr_matrix(adata.X) 216 | if "n_counts" in adata.obs_keys(): 217 | n_counts = adata.obs["n_counts"] 218 | else: 219 | n_counts = np.sum(adata.X, axis=1).A 220 | adata.obs["n_counts"] = n_counts 221 | if "n_features" in adata.obs_keys(): 222 | n_features = adata.obs["n_features"] 223 | else: 224 | n_features = np.sum(adata.X >= expr_cutoff, axis=1).A1 225 | adata.obs["n_features"] = n_features 226 | if "pct_features" in adata.obs_keys(): 227 | pct_features = adata.obs["pct_features"] 228 | else: 229 | pct_features = n_features / adata.shape[1] 230 | adata.obs["pct_features"] = pct_features 231 | 232 | print("before filtering: ") 233 | print(f"{adata.shape[0]} samples, {adata.shape[1]} feature") 234 | if ( 235 | sum( 236 | list( 237 | map( 238 | lambda x: x is None, 239 | [ 240 | min_n_features, 241 | min_pct_features, 242 | min_n_counts, 243 | max_n_features, 244 | max_pct_features, 245 | max_n_counts, 246 | ], 247 | ) 248 | ) 249 | ) 250 | == 6 251 | ): 252 | print("No filtering") 253 | else: 254 | cell_subset = np.ones(len(adata.obs_names), dtype=bool) 255 | if min_n_features is not None: 256 | print("filter samples based on min_n_features") 257 | cell_subset = (n_features >= min_n_features) & cell_subset 258 | if max_n_features is not None: 259 | print("filter samples based on max_n_features") 260 | cell_subset = (n_features <= max_n_features) & cell_subset 261 | if min_pct_features is not None: 262 | print("filter samples based on min_pct_features") 263 | cell_subset = (pct_features >= min_pct_features) & cell_subset 264 | if max_pct_features is not None: 265 | print("filter samples based on max_pct_features") 266 | cell_subset = (pct_features <= max_pct_features) & cell_subset 267 | if min_n_counts is not None: 268 | print("filter samples based on min_n_counts") 269 | cell_subset = (n_counts >= min_n_counts) & cell_subset 270 | if max_n_counts is not None: 271 | print("filter samples based on max_n_counts") 272 | cell_subset = (n_counts <= max_n_counts) & cell_subset 273 | adata._inplace_subset_obs(cell_subset) 274 | print("after filtering out low-quality samples: ") 275 | print(f"{adata.shape[0]} samples, {adata.shape[1]} feature") 276 | return None 277 | 278 | 279 | def filter_cells_rna( 280 | adata, 281 | min_n_genes=None, 282 | max_n_genes=None, 283 | min_pct_genes=None, 284 | max_pct_genes=None, 285 | min_n_counts=None, 286 | max_n_counts=None, 287 | expr_cutoff=1, 288 | ): 289 | """Filter out cells for RNA-seq based on different metrics. 290 | 291 | Parameters 292 | ---------- 293 | adata: AnnData 294 | Annotated data matrix. 295 | min_n_genes: `int`, optional (default: None) 296 | Minimum number of genes expressed 297 | min_pct_genes: `float`, optional (default: None) 298 | Minimum percentage of genes expressed 299 | min_n_counts: `int`, optional (default: None) 300 | Minimum number of read count for one cell 301 | expr_cutoff: `float`, optional (default: 1) 302 | Expression cutoff. 303 | If greater than expr_cutoff,the gene is considered 'expressed' 304 | assay: `str`, optional (default: 'rna') 305 | Choose from {{'rna','atac'}},case insensitive 306 | Returns 307 | ------- 308 | updates `adata` with a subset of cells that pass the filtering. 309 | updates `adata` with the following fields if cal_qc() was not performed. 
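A rough sketch of how the RNA QC helpers in this file fit together (`cal_qc_rna` above, `filter_cells_rna` here, and `filter_genes` further down); the toy count matrix and the thresholds are placeholders:

# Sketch only: toy counts and thresholds are placeholders.
import anndata
import numpy as np
import scipy.sparse as sp
import stream2 as st2

rng = np.random.default_rng(0)
counts = sp.csr_matrix(rng.poisson(1.0, size=(500, 300)).astype(float))
adata = anndata.AnnData(X=counts)

st2.pp.cal_qc_rna(adata)                        # adds n_counts / n_genes / pct_genes / pct_mt
st2.pp.filter_cells_rna(adata, min_n_genes=50)  # drop cells expressing too few genes
st2.pp.filter_genes(adata, min_n_cells=3)       # drop genes expressed in too few cells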
310 | n_counts: `pandas.Series` (`adata.obs['n_counts']`,dtype `int`) 311 | The number of read count each cell has. 312 | n_genes: `pandas.Series` (`adata.obs['n_genes']`,dtype `int`) 313 | The number of genes expressed in each cell. 314 | pct_genes: `pandas.Series` (`adata.obs['pct_genes']`,dtype `float`) 315 | The percentage of genes expressed in each cell. 316 | n_peaks: `pandas.Series` (`adata.obs['n_peaks']`,dtype `int`) 317 | The number of peaks expressed in each cell. 318 | pct_peaks: `pandas.Series` (`adata.obs['pct_peaks']`,dtype `int`) 319 | The percentage of peaks expressed in each cell. 320 | """ 321 | 322 | if not issparse(adata.X): 323 | adata.X = csr_matrix(adata.X) 324 | if "n_counts" in adata.obs_keys(): 325 | n_counts = adata.obs["n_counts"] 326 | else: 327 | n_counts = np.sum(adata.X, axis=1).A1 328 | adata.obs["n_counts"] = n_counts 329 | 330 | if "n_genes" in adata.obs_keys(): 331 | n_genes = adata.obs["n_genes"] 332 | else: 333 | n_genes = np.sum(adata.X >= expr_cutoff, axis=1).A1 334 | adata.obs["n_genes"] = n_genes 335 | if "pct_genes" in adata.obs_keys(): 336 | pct_genes = adata.obs["pct_genes"] 337 | else: 338 | pct_genes = n_genes / adata.shape[1] 339 | adata.obs["pct_genes"] = pct_genes 340 | 341 | print("before filtering: ") 342 | print(f"{adata.shape[0]} cells, {adata.shape[1]} genes") 343 | if ( 344 | sum( 345 | list( 346 | map( 347 | lambda x: x is None, 348 | [ 349 | min_n_genes, 350 | min_pct_genes, 351 | min_n_counts, 352 | max_n_genes, 353 | max_pct_genes, 354 | max_n_counts, 355 | ], 356 | ) 357 | ) 358 | ) 359 | == 6 360 | ): 361 | print("No filtering") 362 | else: 363 | cell_subset = np.ones(len(adata.obs_names), dtype=bool) 364 | if min_n_genes is not None: 365 | print("filter cells based on min_n_genes") 366 | cell_subset = (n_genes >= min_n_genes) & cell_subset 367 | if max_n_genes is not None: 368 | print("filter cells based on max_n_genes") 369 | cell_subset = (n_genes <= max_n_genes) & cell_subset 370 | if min_pct_genes is not None: 371 | print("filter cells based on min_pct_genes") 372 | cell_subset = (pct_genes >= min_pct_genes) & cell_subset 373 | if max_pct_genes is not None: 374 | print("filter cells based on max_pct_genes") 375 | cell_subset = (pct_genes <= max_pct_genes) & cell_subset 376 | if min_n_counts is not None: 377 | print("filter cells based on min_n_counts") 378 | cell_subset = (n_counts >= min_n_counts) & cell_subset 379 | if max_n_counts is not None: 380 | print("filter cells based on max_n_counts") 381 | cell_subset = (n_counts <= max_n_counts) & cell_subset 382 | adata._inplace_subset_obs(cell_subset) 383 | print("after filtering out low-quality cells: ") 384 | print(f"{adata.shape[0]} cells, {adata.shape[1]} genes") 385 | return None 386 | 387 | 388 | def filter_cells_atac( 389 | adata, 390 | min_n_peaks=None, 391 | max_n_peaks=None, 392 | min_pct_peaks=None, 393 | max_pct_peaks=None, 394 | min_n_counts=None, 395 | max_n_counts=None, 396 | expr_cutoff=1, 397 | ): 398 | """Filter out cells for ATAC-seq based on different metrics. 399 | 400 | Parameters 401 | ---------- 402 | adata: AnnData 403 | Annotated data matrix. 404 | min_n_peaks: `int`, optional (default: None) 405 | Minimum number of peaks expressed 406 | min_pct_peaks: `float`, optional (default: None) 407 | Minimum percentage of peaks expressed 408 | min_n_counts: `int`, optional (default: None) 409 | Minimum number of read count for one cell 410 | expr_cutoff: `float`, optional (default: 1) 411 | Expression cutoff. 
412 | If greater than expr_cutoff,the gene is considered 'expressed' 413 | assay: `str`, optional (default: 'rna') 414 | Choose from {{'rna','atac'}},case insensitive 415 | Returns 416 | ------- 417 | updates `adata` with a subset of cells that pass the filtering. 418 | updates `adata` with the following fields if cal_qc() was not performed. 419 | n_counts: `pandas.Series` (`adata.obs['n_counts']`,dtype `int`) 420 | The number of read count each cell has. 421 | n_genes: `pandas.Series` (`adata.obs['n_genes']`,dtype `int`) 422 | The number of genes expressed in each cell. 423 | pct_genes: `pandas.Series` (`adata.obs['pct_genes']`,dtype `float`) 424 | The percentage of genes expressed in each cell. 425 | n_peaks: `pandas.Series` (`adata.obs['n_peaks']`,dtype `int`) 426 | The number of peaks expressed in each cell. 427 | pct_peaks: `pandas.Series` (`adata.obs['pct_peaks']`,dtype `int`) 428 | The percentage of peaks expressed in each cell. 429 | """ 430 | 431 | if not issparse(adata.X): 432 | adata.X = csr_matrix(adata.X) 433 | if "n_counts" in adata.obs_keys(): 434 | n_counts = adata.obs["n_counts"] 435 | else: 436 | n_counts = np.sum(adata.X, axis=1).A1 437 | adata.obs["n_counts"] = n_counts 438 | 439 | if "n_peaks" in adata.obs_keys(): 440 | n_peaks = adata.obs["n_peaks"] 441 | else: 442 | n_peaks = np.sum(adata.X >= expr_cutoff, axis=1).A1 443 | adata.obs["n_peaks"] = n_peaks 444 | if "pct_peaks" in adata.obs_keys(): 445 | pct_peaks = adata.obs["pct_peaks"] 446 | else: 447 | pct_peaks = n_peaks / adata.shape[1] 448 | adata.obs["pct_peaks"] = pct_peaks 449 | 450 | print("before filtering: ") 451 | print(f"{adata.shape[0]} cells, {adata.shape[1]} peaks") 452 | if ( 453 | sum( 454 | list( 455 | map( 456 | lambda x: x is None, 457 | [ 458 | min_n_peaks, 459 | min_pct_peaks, 460 | min_n_counts, 461 | max_n_peaks, 462 | max_pct_peaks, 463 | max_n_counts, 464 | ], 465 | ) 466 | ) 467 | ) 468 | == 6 469 | ): 470 | print("No filtering") 471 | else: 472 | cell_subset = np.ones(len(adata.obs_names), dtype=bool) 473 | if min_n_peaks is not None: 474 | print("filter cells based on min_n_peaks") 475 | cell_subset = (n_peaks >= min_n_peaks) & cell_subset 476 | if max_n_peaks is not None: 477 | print("filter cells based on max_n_peaks") 478 | cell_subset = (n_peaks <= max_n_peaks) & cell_subset 479 | if min_pct_peaks is not None: 480 | print("filter cells based on min_pct_peaks") 481 | cell_subset = (pct_peaks >= min_pct_peaks) & cell_subset 482 | if max_pct_peaks is not None: 483 | print("filter cells based on max_pct_peaks") 484 | cell_subset = (pct_peaks <= max_pct_peaks) & cell_subset 485 | if min_n_counts is not None: 486 | print("filter cells based on min_n_counts") 487 | cell_subset = (n_counts >= min_n_counts) & cell_subset 488 | if max_n_counts is not None: 489 | print("filter cells based on max_n_counts") 490 | cell_subset = (n_counts <= max_n_counts) & cell_subset 491 | adata._inplace_subset_obs(cell_subset) 492 | print("after filtering out low-quality cells: ") 493 | print(f"{adata.shape[0]} cells, {adata.shape[1]} peaks") 494 | return None 495 | 496 | 497 | def filter_genes( 498 | adata, 499 | min_n_cells=3, 500 | max_n_cells=None, 501 | min_pct_cells=None, 502 | max_pct_cells=None, 503 | min_n_counts=None, 504 | max_n_counts=None, 505 | expr_cutoff=1, 506 | ): 507 | """Filter out features based on different metrics. 508 | 509 | Parameters 510 | ---------- 511 | adata: AnnData 512 | Annotated data matrix. 
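The ATAC-seq counterparts follow the same pattern on peak-level metrics (`cal_qc_atac` and `filter_cells_atac` above, `filter_peaks` further down in this file). A sketch with a toy binary accessibility matrix; thresholds are placeholders:

# Sketch only: toy binary peak matrix and thresholds are placeholders.
import anndata
import numpy as np
import scipy.sparse as sp
import stream2 as st2

rng = np.random.default_rng(1)
peaks = sp.csr_matrix((rng.random((500, 1000)) < 0.05).astype(float))
atac = anndata.AnnData(X=peaks)

st2.pp.cal_qc_atac(atac)                        # adds n_counts / n_peaks / pct_peaks
st2.pp.filter_cells_atac(atac, min_n_peaks=20)  # drop cells with too few accessible peaks
st2.pp.filter_peaks(atac, min_n_cells=5)        # drop peaks open in too few cells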
513 | min_n_cells: `int`, optional (default: 5) 514 | Minimum number of cells expressing one feature 515 | min_pct_cells: `float`, optional (default: None) 516 | Minimum percentage of cells expressing one feature 517 | min_n_counts: `int`, optional (default: None) 518 | Minimum number of read count for one feature 519 | expr_cutoff: `float`, optional (default: 1) 520 | Expression cutoff. 521 | If greater than expr_cutoff,the feature is considered 'expressed' 522 | assay: `str`, optional (default: 'rna') 523 | Choose from {{'rna','atac'}},case insensitive 524 | 525 | Returns 526 | ------- 527 | updates `adata` with a subset of features that pass the filtering. 528 | updates `adata` with the following fields if cal_qc() was not performed. 529 | n_counts: `pandas.Series` (`adata.var['n_counts']`,dtype `int`) 530 | The number of read count each gene has. 531 | n_cells: `pandas.Series` (`adata.var['n_cells']`,dtype `int`) 532 | The number of cells in which each gene is expressed. 533 | pct_cells: `pandas.Series` (`adata.var['pct_cells']`,dtype `float`) 534 | The percentage of cells in which each gene is expressed. 535 | """ 536 | 537 | feature = "genes" 538 | if not issparse(adata.X): 539 | adata.X = csr_matrix(adata.X) 540 | 541 | if "n_counts" in adata.var_keys(): 542 | n_counts = adata.var["n_counts"] 543 | else: 544 | n_counts = np.sum(adata.X, axis=0).A1 545 | adata.var["n_counts"] = n_counts 546 | if "n_cells" in adata.var_keys(): 547 | n_cells = adata.var["n_cells"] 548 | else: 549 | n_cells = np.sum(adata.X >= expr_cutoff, axis=0).A1 550 | adata.var["n_cells"] = n_cells 551 | if "pct_cells" in adata.var_keys(): 552 | pct_cells = adata.var["pct_cells"] 553 | else: 554 | pct_cells = n_cells / adata.shape[0] 555 | adata.var["pct_cells"] = pct_cells 556 | 557 | print("Before filtering: ") 558 | print( 559 | str(adata.shape[0]) + " cells, " + str(adata.shape[1]) + " " + feature 560 | ) 561 | if ( 562 | sum( 563 | list( 564 | map( 565 | lambda x: x is None, 566 | [ 567 | min_n_cells, 568 | min_pct_cells, 569 | min_n_counts, 570 | max_n_cells, 571 | max_pct_cells, 572 | max_n_counts, 573 | ], 574 | ) 575 | ) 576 | ) 577 | == 6 578 | ): 579 | print("No filtering") 580 | else: 581 | feature_subset = np.ones(len(adata.var_names), dtype=bool) 582 | if min_n_cells is not None: 583 | print("Filter " + feature + " based on min_n_cells") 584 | feature_subset = (n_cells >= min_n_cells) & feature_subset 585 | if max_n_cells is not None: 586 | print("Filter " + feature + " based on max_n_cells") 587 | feature_subset = (n_cells <= max_n_cells) & feature_subset 588 | if min_pct_cells is not None: 589 | print("Filter " + feature + " based on min_pct_cells") 590 | feature_subset = (pct_cells >= min_pct_cells) & feature_subset 591 | if max_pct_cells is not None: 592 | print("Filter " + feature + " based on max_pct_cells") 593 | feature_subset = (pct_cells <= max_pct_cells) & feature_subset 594 | if min_n_counts is not None: 595 | print("Filter " + feature + " based on min_n_counts") 596 | feature_subset = (n_counts >= min_n_counts) & feature_subset 597 | if max_n_counts is not None: 598 | print("Filter " + feature + " based on max_n_counts") 599 | feature_subset = (n_counts <= max_n_counts) & feature_subset 600 | adata._inplace_subset_var(feature_subset) 601 | print("After filtering out low-expressed " + feature + ": ") 602 | print( 603 | str(adata.shape[0]) 604 | + " cells, " 605 | + str(adata.shape[1]) 606 | + " " 607 | + feature 608 | ) 609 | return None 610 | 611 | 612 | def filter_peaks( 613 | adata, 614 
| min_n_cells=5, 615 | max_n_cells=None, 616 | min_pct_cells=None, 617 | max_pct_cells=None, 618 | min_n_counts=None, 619 | max_n_counts=None, 620 | expr_cutoff=1, 621 | ): 622 | """Filter out features based on different metrics. 623 | 624 | Parameters 625 | ---------- 626 | adata: AnnData 627 | Annotated data matrix. 628 | min_n_cells: `int`, optional (default: 5) 629 | Minimum number of cells expressing one feature 630 | min_pct_cells: `float`, optional (default: None) 631 | Minimum percentage of cells expressing one feature 632 | min_n_counts: `int`, optional (default: None) 633 | Minimum number of read count for one feature 634 | expr_cutoff: `float`, optional (default: 1) 635 | Expression cutoff. 636 | If greater than expr_cutoff,the feature is considered 'expressed' 637 | assay: `str`, optional (default: 'rna') 638 | Choose from {{'rna','atac'}},case insensitive 639 | 640 | Returns 641 | ------- 642 | updates `adata` with a subset of features that pass the filtering. 643 | updates `adata` with the following fields if cal_qc() was not performed. 644 | n_counts: `pandas.Series` (`adata.var['n_counts']`,dtype `int`) 645 | The number of read count each gene has. 646 | n_cells: `pandas.Series` (`adata.var['n_cells']`,dtype `int`) 647 | The number of cells in which each gene is expressed. 648 | pct_cells: `pandas.Series` (`adata.var['pct_cells']`,dtype `float`) 649 | The percentage of cells in which each gene is expressed. 650 | """ 651 | 652 | feature = "peaks" 653 | if not issparse(adata.X): 654 | adata.X = csr_matrix(adata.X) 655 | 656 | if "n_counts" in adata.var_keys(): 657 | n_counts = adata.var["n_counts"] 658 | else: 659 | n_counts = np.sum(adata.X, axis=0).A1 660 | adata.var["n_counts"] = n_counts 661 | if "n_cells" in adata.var_keys(): 662 | n_cells = adata.var["n_cells"] 663 | else: 664 | n_cells = np.sum(adata.X >= expr_cutoff, axis=0).A1 665 | adata.var["n_cells"] = n_cells 666 | if "pct_cells" in adata.var_keys(): 667 | pct_cells = adata.var["pct_cells"] 668 | else: 669 | pct_cells = n_cells / adata.shape[0] 670 | adata.var["pct_cells"] = pct_cells 671 | 672 | print("Before filtering: ") 673 | print( 674 | str(adata.shape[0]) + " cells, " + str(adata.shape[1]) + " " + feature 675 | ) 676 | if ( 677 | sum( 678 | list( 679 | map( 680 | lambda x: x is None, 681 | [ 682 | min_n_cells, 683 | min_pct_cells, 684 | min_n_counts, 685 | max_n_cells, 686 | max_pct_cells, 687 | max_n_counts, 688 | ], 689 | ) 690 | ) 691 | ) 692 | == 6 693 | ): 694 | print("No filtering") 695 | else: 696 | feature_subset = np.ones(len(adata.var_names), dtype=bool) 697 | if min_n_cells is not None: 698 | print("Filter " + feature + " based on min_n_cells") 699 | feature_subset = (n_cells >= min_n_cells) & feature_subset 700 | if max_n_cells is not None: 701 | print("Filter " + feature + " based on max_n_cells") 702 | feature_subset = (n_cells <= max_n_cells) & feature_subset 703 | if min_pct_cells is not None: 704 | print("Filter " + feature + " based on min_pct_cells") 705 | feature_subset = (pct_cells >= min_pct_cells) & feature_subset 706 | if max_pct_cells is not None: 707 | print("Filter " + feature + " based on max_pct_cells") 708 | feature_subset = (pct_cells <= max_pct_cells) & feature_subset 709 | if min_n_counts is not None: 710 | print("Filter " + feature + " based on min_n_counts") 711 | feature_subset = (n_counts >= min_n_counts) & feature_subset 712 | if max_n_counts is not None: 713 | print("Filter " + feature + " based on max_n_counts") 714 | feature_subset = (n_counts <= max_n_counts) & 
feature_subset
715 | adata._inplace_subset_var(feature_subset)
716 | print("After filtering out low-expressed " + feature + ": ")
717 | print(
718 | str(adata.shape[0])
719 | + " cells, "
720 | + str(adata.shape[1])
721 | + " "
722 | + feature
723 | )
724 | return None
725 |
726 |
727 | def filter_features(
728 | adata,
729 | min_n_samples=5,
730 | max_n_samples=None,
731 | min_pct_samples=None,
732 | max_pct_samples=None,
733 | min_n_counts=None,
734 | max_n_counts=None,
735 | expr_cutoff=1,
736 | ):
737 | """Filter out features based on different metrics.
738 |
739 | Parameters
740 | ----------
741 | adata: AnnData
742 | Annotated data matrix.
743 | min_n_samples: `int`, optional (default: 5)
744 | Minimum number of samples in which one feature is expressed
745 | min_pct_samples: `float`, optional (default: None)
746 | Minimum percentage of samples in which one feature is expressed
747 | min_n_counts: `int`, optional (default: None)
748 | Minimum number of read counts for one feature
749 | expr_cutoff: `float`, optional (default: 1)
750 | Expression cutoff.
751 | If greater than or equal to expr_cutoff, the feature is considered 'expressed'
752 | max_n_samples, max_pct_samples, max_n_counts: optional (default: None)
753 | Upper-bound counterparts of the thresholds above
754 |
755 | Returns
756 | -------
757 | updates `adata` with a subset of features that pass the filtering.
758 | updates `adata` with the following fields if cal_qc() was not performed.
759 | n_counts: `pandas.Series` (`adata.var['n_counts']`,dtype `int`)
760 | The number of read counts each feature has.
761 | n_samples: `pandas.Series` (`adata.var['n_samples']`,dtype `int`)
762 | The number of samples in which each feature is expressed.
763 | pct_samples: `pandas.Series` (`adata.var['pct_samples']`,dtype `float`)
764 | The percentage of samples in which each feature is expressed.
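The assay-agnostic variants (`cal_qc` and `filter_samples` earlier in this file, `filter_features` here) follow the same pattern with sample/feature naming. A short sketch with placeholder data and cutoffs:

# Sketch only: placeholder data and cutoffs.
import anndata
import numpy as np
import scipy.sparse as sp
import stream2 as st2

rng = np.random.default_rng(2)
adata = anndata.AnnData(X=sp.csr_matrix(rng.poisson(1.0, size=(200, 100)).astype(float)))

st2.pp.cal_qc(adata)                             # adds n_features / pct_features (obs) and n_samples / pct_samples (var)
st2.pp.filter_samples(adata, min_n_features=10)  # sample-level filtering
st2.pp.filter_features(adata, min_n_samples=5)   # feature-level filtering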
765 | """ 766 | 767 | if not issparse(adata.X): 768 | adata.X = csr_matrix(adata.X) 769 | if "n_counts" in adata.var_keys(): 770 | n_counts = adata.var["n_counts"] 771 | else: 772 | n_counts = np.sum(adata.X, axis=0).A1 773 | adata.var["n_counts"] = n_counts 774 | if "n_samples" in adata.var_keys(): 775 | n_samples = adata.var["n_samples"] 776 | else: 777 | n_samples = np.sum(adata.X >= expr_cutoff, axis=0).A1 778 | adata.var["n_samples"] = n_samples 779 | if "pct_samples" in adata.var_keys(): 780 | pct_samples = adata.var["pct_samples"] 781 | else: 782 | pct_samples = n_samples / adata.shape[0] 783 | adata.var["pct_samples"] = pct_samples 784 | 785 | print("Before filtering: ") 786 | print(f"{adata.shape[0]} samples, {adata.shape[1]} features") 787 | 788 | if ( 789 | sum( 790 | list( 791 | map( 792 | lambda x: x is None, 793 | [ 794 | min_n_samples, 795 | min_pct_samples, 796 | min_n_counts, 797 | max_n_samples, 798 | max_pct_samples, 799 | max_n_counts, 800 | ], 801 | ) 802 | ) 803 | ) 804 | == 6 805 | ): 806 | print("No filtering") 807 | else: 808 | feature_subset = np.ones(len(adata.var_names), dtype=bool) 809 | if min_n_samples is not None: 810 | print("Filter features based on min_n_samples") 811 | feature_subset = (n_samples >= min_n_samples) & feature_subset 812 | if max_n_samples is not None: 813 | print("Filter features based on max_n_samples") 814 | feature_subset = (n_samples <= max_n_samples) & feature_subset 815 | if min_pct_samples is not None: 816 | print("Filter features based on min_pct_samples") 817 | feature_subset = (pct_samples >= min_pct_samples) & feature_subset 818 | if max_pct_samples is not None: 819 | print("Filter features based on max_pct_samples") 820 | feature_subset = (pct_samples <= max_pct_samples) & feature_subset 821 | if min_n_counts is not None: 822 | print("Filter features based on min_n_counts") 823 | feature_subset = (n_counts >= min_n_counts) & feature_subset 824 | if max_n_counts is not None: 825 | print("Filter features based on max_n_counts") 826 | feature_subset = (n_counts <= max_n_counts) & feature_subset 827 | adata._inplace_subset_var(feature_subset) 828 | print("After filtering out low-expressed features: ") 829 | print(f"{adata.shape[0]} samples, {adata.shape[1]} features") 830 | return None 831 | -------------------------------------------------------------------------------- /stream2/tools/_graph_utils.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import elpigraph 3 | import numpy as np 4 | import scanpy as sc 5 | import scipy 6 | import copy 7 | import statsmodels.api 8 | from sklearn.neighbors import KNeighborsRegressor 9 | 10 | from ._elpigraph import ( 11 | learn_graph, 12 | _store_graph_attributes, 13 | _get_graph_data, 14 | _subset_adata, 15 | ) 16 | from .._utils import stream2elpi 17 | 18 | 19 | def project_graph(adata, to_basis="X_umap", key="epg"): 20 | 21 | obsm = adata.uns[key]["params"]["obsm"] 22 | layer = adata.uns[key]["params"]["layer"] 23 | if obsm is not None: 24 | X = adata.obsm[obsm].copy() 25 | from_basis = obsm 26 | elif layer is not None: 27 | X = adata.layers[layer].copy() 28 | from_basis = layer 29 | else: 30 | X = adata.X 31 | from_basis = 'X' 32 | 33 | suffix = f"_from_{from_basis}_to_{to_basis}" 34 | adata.uns[key + suffix] = copy.deepcopy(adata.uns[key]) 35 | adata.uns[key + suffix]["params"]["obsm"] = "X_umap" 36 | 37 | # proj 38 | adata.uns[key + suffix]["node_pos"] = elpigraph.utils.proj2embedding( 39 | X, 40 | adata.obsm[to_basis], 41 
| adata.uns[key]["node_pos"], 42 | ) 43 | empty_nodes = np.where( 44 | np.isnan(adata.uns[key + suffix]["node_pos"][:, 0]) 45 | )[0] 46 | for node in empty_nodes: 47 | neigh_nodes = adata.uns[key + suffix][ 48 | "node_pos" 49 | ][ 50 | np.unique( 51 | adata.uns[key + suffix]["edge"][ 52 | (adata.uns[key + suffix]["edge"] == node).any(axis=1) 53 | ] 54 | ) 55 | ] 56 | adata.uns[key + suffix]["node_pos"][node] = np.nanmean( 57 | neigh_nodes, axis=0 58 | ) 59 | 60 | 61 | def find_paths( 62 | adata, 63 | min_path_len=None, 64 | n_nodes=None, 65 | max_inner_fraction=0.1, 66 | min_node_n_points=None, 67 | max_n_points=None, 68 | min_compactness=0.5, 69 | radius=None, 70 | allow_same_branch=True, 71 | fit_loops=True, 72 | plot=False, 73 | verbose=False, 74 | inplace=False, 75 | use_weights=False, 76 | use_partition=False, 77 | epg_lambda=None, 78 | epg_mu=None, 79 | epg_cycle_lambda=None, 80 | epg_cycle_mu=None, 81 | ignore_equivalent=False, 82 | key="epg", 83 | ): 84 | """This function tries to add extra paths to the graph by computing a 85 | series of principal curves connecting two nodes and retaining plausible 86 | ones using heuristic parameters. 87 | 88 | min_path_len: int, default=None 89 | Minimum distance along the graph (in number of nodes) 90 | that separates the two nodes to connect with a principal curve 91 | n_nodes: int, default=None 92 | Number of nodes in the candidate principal curves 93 | max_inner_fraction: float in [0,1], default=0.1 94 | Maximum fraction of points inside vs outside the loop 95 | (controls how empty the loop formed with the added path should be) 96 | min_node_n_points: int, default=1 97 | Minimum number of points associated to nodes of the principal curve 98 | (prevents creating paths through empty space) 99 | max_n_points: int, default=5% of the number of points 100 | Maximum number of points inside the loop 101 | min_compactness: float in [0,1], default=0.5 102 | Minimum 'roundness' of the loop (1=more round) 103 | (if very narrow loops are not desired) 104 | radius: float, default=None 105 | Max distance in space that separates 106 | the two nodes to connect with a principal curve 107 | allow_same_branch: bool, default=True 108 | Whether to allow new paths to connect two nodes from the same branch 109 | fit_loops: bool, default=True 110 | Whether to refit the graph to data after adding the new paths 111 | plot: bool, default=False 112 | Whether to plot selected candidate paths 113 | verbose: bool, default=False 114 | copy: bool, default=False 115 | use_weights: bool, default=False 116 | Whether to use point weights 117 | use_partition: bool or list, default=False 118 | """ 119 | if use_partition: 120 | if verbose: 121 | print("Searching potential loops for each partition...") 122 | if type(use_partition) is bool: 123 | partitions = adata.obs["partition"].unique() 124 | elif type(use_partition) is list: 125 | partitions = use_partition 126 | else: 127 | raise ValueError( 128 | "use_partition should be a bool" + "or a list of partitions" 129 | ) 130 | 131 | merged_nodep = [] 132 | merged_edges = [] 133 | num_edges = 0 134 | for part in adata.obs["partition"].unique(): 135 | 136 | p_adata = _subset_adata(adata, part) 137 | if part in partitions: 138 | _find_paths( 139 | p_adata, 140 | min_path_len=min_path_len, 141 | n_nodes=n_nodes, 142 | max_inner_fraction=max_inner_fraction, 143 | min_node_n_points=min_node_n_points, 144 | max_n_points=max_n_points, 145 | min_compactness=min_compactness, 146 | radius=radius, 147 | allow_same_branch=allow_same_branch, 148 | 
fit_loops=fit_loops, 149 | epg_lambda=epg_lambda, 150 | epg_mu=epg_mu, 151 | epg_cycle_lambda=epg_cycle_lambda, 152 | epg_cycle_mu=epg_cycle_mu, 153 | use_weights=use_weights, 154 | ignore_equivalent=ignore_equivalent, 155 | plot=plot, 156 | verbose=verbose, 157 | inplace=inplace, 158 | key=key, 159 | ) 160 | 161 | merged_nodep.append(p_adata.uns[key]["node_pos"]) 162 | merged_edges.append(p_adata.uns[key]["edge"] + num_edges) 163 | num_edges += len(p_adata.uns[key]["node_pos"]) 164 | 165 | adata.uns[key] = {} 166 | adata.uns[key]["node_pos"] = np.concatenate(merged_nodep) 167 | adata.uns[key]["edge"] = np.concatenate((merged_edges)) 168 | adata.uns[key]["node_partition"] = np.repeat( 169 | adata.obs["partition"].unique(), 170 | [len(nodep) for nodep in merged_nodep], 171 | ).astype(str) 172 | adata.uns[key]["edge_partition"] = np.repeat( 173 | adata.obs["partition"].unique(), 174 | [len(edges) for edges in merged_edges], 175 | ).astype(str) 176 | adata.uns[key]["params"] = p_adata.uns[key]["params"] 177 | 178 | X = _get_graph_data(adata, key=key) 179 | _store_graph_attributes(adata, X, key=key) 180 | 181 | else: 182 | if verbose: 183 | print("Searching potential loops...") 184 | 185 | _find_paths( 186 | adata, 187 | min_path_len=min_path_len, 188 | n_nodes=n_nodes, 189 | max_inner_fraction=max_inner_fraction, 190 | min_node_n_points=min_node_n_points, 191 | max_n_points=max_n_points, 192 | min_compactness=min_compactness, 193 | radius=radius, 194 | allow_same_branch=allow_same_branch, 195 | fit_loops=fit_loops, 196 | epg_lambda=epg_lambda, 197 | epg_mu=epg_mu, 198 | epg_cycle_lambda=epg_cycle_lambda, 199 | epg_cycle_mu=epg_cycle_mu, 200 | use_weights=use_weights, 201 | ignore_equivalent=ignore_equivalent, 202 | plot=plot, 203 | verbose=verbose, 204 | inplace=inplace, 205 | key=key, 206 | ) 207 | 208 | 209 | def _find_paths( 210 | adata, 211 | min_path_len=None, 212 | n_nodes=None, 213 | max_inner_fraction=0.1, 214 | min_node_n_points=None, 215 | max_n_points=None, 216 | min_compactness=0.5, 217 | radius=None, 218 | allow_same_branch=True, 219 | fit_loops=True, 220 | epg_lambda=None, 221 | epg_mu=None, 222 | epg_cycle_lambda=None, 223 | epg_cycle_mu=None, 224 | use_weights=False, 225 | ignore_equivalent=False, 226 | plot=False, 227 | verbose=True, 228 | inplace=False, 229 | key="epg", 230 | ): 231 | 232 | # --- Init parameters, variables 233 | if use_weights: 234 | if "pointweights" not in adata.obs: 235 | raise ValueError( 236 | "adata.obs['pointweights'] missing. 
Run st2.tl.get_weights" 237 | ) 238 | weights = np.array(adata.obs["pointweights"]).reshape((-1, 1)) 239 | else: 240 | weights = None 241 | 242 | X = _get_graph_data(adata, key) 243 | PG = stream2elpi(adata, key) 244 | PG = elpigraph.findPaths( 245 | X, 246 | PG, 247 | Mu=epg_mu, 248 | Lambda=epg_lambda, 249 | cycle_Lambda=epg_cycle_lambda, 250 | cycle_Mu=epg_cycle_mu, 251 | min_path_len=min_path_len, 252 | nnodes=n_nodes, 253 | max_inner_fraction=max_inner_fraction, 254 | min_node_n_points=min_node_n_points, 255 | max_n_points=max_n_points, 256 | min_compactness=min_compactness, 257 | radius=radius, 258 | allow_same_branch=allow_same_branch, 259 | fit_loops=fit_loops, 260 | weights=weights, 261 | ignore_equivalent=ignore_equivalent, 262 | plot=plot, 263 | verbose=verbose, 264 | ) 265 | 266 | if PG is None: 267 | return 268 | if inplace: 269 | adata.uns[key]["node_pos"] = PG["addLoopsdict"]["merged_nodep"] 270 | adata.uns[key]["edge"] = PG["addLoopsdict"]["merged_edges"] 271 | # update edge_len, conn, data projection 272 | _store_graph_attributes(adata, X, key) 273 | 274 | 275 | def add_path( 276 | adata, 277 | source, 278 | target, 279 | n_nodes=None, 280 | use_weights=False, 281 | refit_graph=False, 282 | epg_mu=None, 283 | epg_lambda=None, 284 | epg_cycle_mu=None, 285 | epg_cycle_lambda=None, 286 | key="epg", 287 | ): 288 | 289 | # --- Init parameters, variables 290 | if epg_mu is None: 291 | epg_mu = adata.uns[key]["params"]["epg_mu"] 292 | if epg_lambda is None: 293 | epg_lambda = adata.uns[key]["params"]["epg_lambda"] 294 | if epg_cycle_mu is None: 295 | epg_cycle_mu = epg_mu 296 | if epg_cycle_lambda is None: 297 | epg_cycle_lambda = epg_lambda 298 | if use_weights: 299 | weights = np.array(adata.obs["pointweights"])[:, None] 300 | else: 301 | weights = None 302 | 303 | X = _get_graph_data(adata, key) 304 | PG = stream2elpi(adata, key) 305 | PG = elpigraph.addPath( 306 | X, 307 | PG=PG, 308 | source=source, 309 | target=target, 310 | n_nodes=n_nodes, 311 | weights=weights, 312 | refit_graph=refit_graph, 313 | Mu=epg_mu, 314 | Lambda=epg_lambda, 315 | cycle_Mu=epg_cycle_mu, 316 | cycle_Lambda=epg_cycle_lambda, 317 | ) 318 | 319 | adata.uns["epg"]["node_pos"] = PG["NodePositions"] 320 | adata.uns["epg"]["edge"] = PG["Edges"][0] 321 | 322 | # update edge_len, conn, data projection 323 | _store_graph_attributes(adata, X, key) 324 | 325 | 326 | def del_path( 327 | adata, 328 | source, 329 | target, 330 | nodes_to_include=None, 331 | use_weights=False, 332 | refit_graph=False, 333 | epg_mu=None, 334 | epg_lambda=None, 335 | epg_cycle_mu=None, 336 | epg_cycle_lambda=None, 337 | key="epg", 338 | ): 339 | 340 | # --- Init parameters, variables 341 | if epg_mu is None: 342 | epg_mu = adata.uns[key]["params"]["epg_mu"] 343 | if epg_lambda is None: 344 | epg_lambda = adata.uns[key]["params"]["epg_lambda"] 345 | if epg_cycle_mu is None: 346 | epg_cycle_mu = epg_mu 347 | if epg_cycle_lambda is None: 348 | epg_cycle_lambda = epg_lambda 349 | if use_weights: 350 | weights = np.array(adata.obs["pointweights"])[:, None] 351 | else: 352 | weights = None 353 | 354 | X = _get_graph_data(adata, key) 355 | PG = stream2elpi(adata, key) 356 | PG = elpigraph.delPath( 357 | X, 358 | PG=PG, 359 | source=source, 360 | target=target, 361 | nodes_to_include=nodes_to_include, 362 | weights=weights, 363 | refit_graph=refit_graph, 364 | Mu=epg_mu, 365 | Lambda=epg_lambda, 366 | cycle_Mu=epg_cycle_mu, 367 | cycle_Lambda=epg_cycle_lambda, 368 | ) 369 | 370 | adata.uns["epg"]["node_pos"] = PG["NodePositions"] 371 | 
adata.uns["epg"]["edge"] = PG["Edges"][0]
372 |
373 | # update edge_len, conn, data projection
374 | _store_graph_attributes(adata, X, key)
375 |
376 |
377 | def prune_graph(
378 | adata,
379 | mode="PointNumber",
380 | collapse_par=5,
381 | trimming_radius=np.inf,
382 | refit_graph=False,
383 | copy=False,
384 | ):
385 | pg = {
386 | "NodePositions": adata.uns["epg"]["node_pos"].copy(),
387 | "Edges": [adata.uns["epg"]["edge"].copy()],
388 | }
389 | pg2 = elpigraph.CollapseBranches(
390 | adata.obsm["X_dr"],
391 | pg,
392 | ControlPar=collapse_par,
393 | Mode=mode,
394 | TrimmingRadius=trimming_radius,
395 | )
396 | if not copy:
397 | if refit_graph:
398 | params = adata.uns["epg"]["params"]
399 | learn_graph(
400 | adata,
401 | method=params["method"],
402 | obsm=params["obsm"],
403 | layer=params["layer"],
404 | n_nodes=params["n_nodes"],
405 | epg_lambda=params["epg_lambda"],
406 | epg_mu=params["epg_mu"],
407 | epg_alpha=params["epg_alpha"],
408 | use_seed=False,
409 | InitNodePositions=pg2["Nodes"],
410 | InitEdges=pg2["Edges"],
411 | )
412 | else:
413 | adata.uns["epg"]["node_pos"] = pg2["Nodes"]
414 | adata.uns["epg"]["edge"] = pg2["Edges"]
415 | else:
416 | return pg2
417 |
418 |
419 | def find_disconnected_components(
420 | adata, groups="leiden", neighbors_key=None, verbose=True
421 | ):
422 | """Find whether the data contains disconnected components.
423 |
424 | Parameters
425 | ----------
426 | adata : anndata.AnnData class instance
427 |
428 | Returns
429 | -------
430 | adata.obs['partition']: component assignment of points
431 | """
432 |
433 | if groups not in adata.obs:
434 | raise ValueError(f"{groups} not found in adata.obs")
435 |
436 | sc.tl.paga(adata, groups=groups, neighbors_key=neighbors_key)
437 | # edges = np.argwhere(adata.uns["paga"]["connectivities"])
438 | # edges_tree = np.argwhere(adata.uns["paga"]["connectivities_tree"])
439 | g = nx.convert_matrix.from_scipy_sparse_array(
440 | adata.uns["paga"]["connectivities_tree"]
441 | )
442 | comps = [list(c) for c in nx.algorithms.components.connected_components(g)]
443 | clus_idx = [
444 | np.where(adata.obs[adata.uns["paga"]["groups"]].astype(int) == i)[0]
445 | for i in g.nodes
446 | ]
447 |
448 | partition = np.zeros(len(adata), dtype=object)
449 | for i, comp in enumerate(comps):
450 | comp_idx = np.concatenate([clus_idx[i] for i in comp])
451 | partition[comp_idx] = str(i)
452 | adata.obs["partition"] = partition
453 | print(f"Found {len(adata.obs['partition'].unique())} components")
454 |
455 |
456 | def get_weights(
457 | adata,
458 | obsm="X_dr",
459 | layer=None,
460 | bandwidth=1,
461 | griddelta=100,
462 | exponent=1,
463 | method="sklearn",
464 | **kwargs,
465 | ):
466 | if sum(list(map(lambda x: x is not None, [layer, obsm]))) == 2:
467 | raise ValueError("Only one of `layer` and `obsm` can be used")
468 | elif obsm is not None:
469 | if obsm in adata.obsm:
470 | mat = adata.obsm[obsm]
471 | else:
472 | raise ValueError(f"could not find {obsm} in `adata.obsm`")
473 | elif layer is not None:
474 | if layer in adata.layers:
475 | mat = adata.layers[layer]
476 | else:
477 | raise ValueError(f"could not find {layer} in `adata.layers`")
478 | else:
479 | mat = adata.X
480 |
481 | adata.obs["pointweights"] = elpigraph.utils.getWeights(
482 | mat, bandwidth, griddelta, exponent, method, **kwargs
483 | )
484 |
485 |
486 | def get_component(adata, component):
487 | sadata = _subset_adata(adata, component)
488 | for key in ["seed_epg", "epg"]:
489 | if key in sadata.uns:
490 | X = _get_graph_data(sadata, key)
491 | 
_store_graph_attributes(sadata, X, key) 492 | return sadata 493 | 494 | 495 | def ordinal_knn( 496 | adata, 497 | ordinal_label, 498 | obsm="X_pca", 499 | layer=None, 500 | n_neighbors=15, 501 | n_natural=1, 502 | metric="cosine", 503 | method="guide", 504 | return_sparse=False, 505 | stages=None, 506 | ): 507 | """Supervised (ordinal) nearest-neighbor search. 508 | 509 | Parameters 510 | ---------- 511 | n_neighbors: int 512 | Number of neighbors 513 | n_natural: int 514 | Number of natural neighbors (between 0 and n_neighbors-1) 515 | to force the graph to retain. Tunes the strength of supervision 516 | metric: str 517 | One of sklearn's distance metrics 518 | method : str (default='force') 519 | if 'force', for each point at stage[i] get n_neighbors, forcing: 520 | - n_neighbors/3 to be from stage[i-1] 521 | - n_neighbors/3 to be from stage[i] 522 | - n_neighbors/3 to be from stage[i+1] 523 | For stage[0] and stage[-1], 2*n_neighbors/3 are taken from stage[i] 524 | 525 | if 'guide', for each point at stage[i] get n_neighbors 526 | from points in {stage[i-1], stage[i], stage[i+1]}, 527 | without constraints on proportions 528 | return_sparse: bool 529 | Whether to return the graph in sparse form 530 | or as longform indices and distances 531 | stages: list 532 | Ordered list of ordinal label stages (low to high). 533 | If None, taken as np.unique(ordinal_label) 534 | 535 | Returns 536 | ------- 537 | Supervised nearest-neighbors as a graph in sparse form 538 | or as longform indices and distances 539 | 540 | """ 541 | 542 | if sum(list(map(lambda x: x is not None, [layer, obsm]))) == 2: 543 | raise ValueError("Only one of `layer` and `obsm` can be used") 544 | elif obsm is not None: 545 | if obsm in adata.obsm: 546 | mat = adata.obsm[obsm] 547 | else: 548 | raise ValueError(f"could not find {obsm} in `adata.obsm`") 549 | elif layer is not None: 550 | if layer in adata.layers: 551 | mat = adata.layers[layer] 552 | else: 553 | raise ValueError(f"could not find {layer} in `adata.layers`") 554 | else: 555 | mat = adata.X 556 | 557 | out = elpigraph.utils.supervised_knn( 558 | mat, 559 | stages_labels=adata.obs[ordinal_label], 560 | stages=stages, 561 | method=method, 562 | n_neighbors=n_neighbors, 563 | n_natural=n_natural, 564 | m=metric, 565 | return_sparse=return_sparse, 566 | ) 567 | 568 | if return_sparse: 569 | return out 570 | else: 571 | knn_dists, knn_idx = out 572 | return knn_dists, knn_idx 573 | 574 | 575 | def smooth_ordinal_labels( 576 | adata, 577 | root, 578 | ordinal_label, 579 | obsm="X_pca", 580 | layer=None, 581 | n_neighbors=15, 582 | n_natural=1, 583 | metric="euclidean", 584 | method="guide", 585 | stages=None, 586 | ): 587 | """Smooth ordinal labels into a continuous vector 588 | 589 | Parameters 590 | ---------- 591 | root: int 592 | Index of chosen root data points 593 | n_neighbors: int 594 | Number of neighbors 595 | n_natural: int 596 | Number of natural neighbors (between 0 and n_neighbors-1) 597 | to force the graph to retain. 
Tunes the strength of supervision 598 | metric: str 599 | One of sklearn's distance metrics 600 | method : str (default='force') 601 | if 'force', for each point at stage[i] get n_neighbors, forcing: 602 | - n_neighbors/3 to be from stage[i-1] 603 | - n_neighbors/3 to be from stage[i] 604 | - n_neighbors/3 to be from stage[i+1] 605 | For stage[0] and stage[-1], 2*n_neighbors/3 are taken from stage[i] 606 | 607 | if 'guide', for each point at stage[i] get n_neighbors 608 | from points in {stage[i-1], stage[i], stage[i+1]}, 609 | without constraints on proportions 610 | return_sparse: bool 611 | Whether to return the graph in sparse form 612 | or as longform indices and distances 613 | stages: list 614 | Ordered list of ordinal label stages (low to high). 615 | If None, taken as np.unique(ordinal_label) 616 | 617 | Returns 618 | ------- 619 | adata.obs['ps']: smoothed ordinal labels 620 | 621 | """ 622 | if sum(list(map(lambda x: x is not None, [layer, obsm]))) == 2: 623 | raise ValueError("Only one of `layer` and `obsm` can be used") 624 | elif obsm is not None: 625 | if obsm in adata.obsm: 626 | mat = adata.obsm[obsm] 627 | else: 628 | raise ValueError(f"could not find {obsm} in `adata.obsm`") 629 | elif layer is not None: 630 | if layer in adata.layers: 631 | mat = adata.layers[layer] 632 | else: 633 | raise ValueError(f"could not find {layer} in `adata.layers`") 634 | else: 635 | mat = adata.X 636 | 637 | g = elpigraph.utils.supervised_knn( 638 | mat, 639 | stages_labels=adata.obs[ordinal_label], 640 | stages=stages, 641 | n_natural=n_natural, 642 | n_neighbors=n_neighbors, 643 | m=metric, 644 | method=method, 645 | return_sparse=True, 646 | ) 647 | 648 | adata.obs["ps"] = elpigraph.utils.geodesic_pseudotime( 649 | mat, n_neighbors, root=root, g=g 650 | ) 651 | 652 | 653 | def refit_graph( 654 | adata, 655 | use_weights=False, 656 | shift_nodes_pos={}, 657 | epg_mu=None, 658 | epg_lambda=None, 659 | cycle_epg_mu=None, 660 | cycle_epg_lambda=None, 661 | ): 662 | """Refit graph to data 663 | 664 | Parameters 665 | ---------- 666 | use_weights: bool 667 | Whether to weight points with adata.obs['pointweights'] 668 | shift_nodes_pos: dict 669 | Optional dict to hold some nodes fixed at specified positions 670 | e.g., {2:[.5,.2]} will hold node 2 at coordinates [.5,.2] 671 | epg_mu: float 672 | ElPiGraph Mu parameter 673 | epg_lambda: float 674 | ElPiGraph Lambda parameter 675 | cycle_epg_mu: float 676 | ElPiGraph Mu parameter, specific for nodes that are part of cycles 677 | cycle_epg_lambda: float 678 | ElPiGraph Lambda parameter, specific for nodes that are part of cycles 679 | """ 680 | # --- Init parameters, variables 681 | if epg_mu is None: 682 | epg_mu = adata.uns["epg"]["params"]["epg_mu"] 683 | if epg_lambda is None: 684 | epg_lambda = adata.uns["epg"]["params"]["epg_lambda"] 685 | if cycle_epg_mu is None: 686 | cycle_epg_mu = epg_mu 687 | if cycle_epg_lambda is None: 688 | cycle_epg_lambda = epg_lambda 689 | if use_weights: 690 | weights = np.array(adata.obs["pointweights"])[:, None] 691 | else: 692 | weights = None 693 | 694 | X = _get_graph_data(adata, "epg") 695 | PG = stream2elpi(adata, "epg") 696 | elpigraph._graph_editing.refitGraph( 697 | X, 698 | PG=PG, 699 | shift_nodes_pos=shift_nodes_pos, 700 | PointWeights=weights, 701 | Mu=epg_mu, 702 | Lambda=epg_lambda, 703 | cycle_Mu=cycle_epg_mu, 704 | cycle_Lambda=cycle_epg_lambda, 705 | ) 706 | 707 | adata.uns["epg"]["node_pos"] = PG["NodePositions"] 708 | 709 | # update edge_len, conn, data projection 710 | 
_store_graph_attributes(adata, X, "epg")
711 |
712 |
713 | def extend_leaves(
714 | adata,
715 | Mode="QuantDists",
716 | ControlPar=0.5,
717 | DoSA=True,
718 | DoSA_maxiter=200,
719 | LeafIDs=None,
720 | TrimmingRadius=float("inf"),
721 | key="epg",
722 | ):
723 | """Extend leaves with additional nodes.
724 |
725 | Parameters
726 | ----------
727 | Mode: str, the mode used to extend the graph.
728 | "QuantDists", "QuantCentroid", "WeightedCentroid"
729 | LeafIDs: int vector,
730 | The id of nodes to extend. If None, all the vertices will be extended.
731 | TrimmingRadius: positive numeric
732 | The trimming radius used to control distance
733 | DoSA: bool
734 | Should optimization (via simulated annealing)
735 | be performed when Mode = "QuantDists"?
736 | ControlPar: positive numeric
737 | The parameter used to control the contribution of
738 | the different data points
739 |
740 | The value of ControlPar has a different interpretation
741 | depending on the value of Mode.
742 | In each case, only the extreme points,
743 | i.e., the points associated with the leaf node that
744 | do not have a projection on any edge, are considered.
745 |
746 | If Mode = "QuantCentroid", for each leaf node,
747 | the extreme points are ordered by their distance from the node
748 | and the centroid of the points farther away
749 | than ControlPar is returned.
750 |
751 | If Mode = "WeightedCentroid", for each leaf node,
752 | a weight is computed for each point
753 | by raising the distance to the ControlPar power.
754 | Hence, larger values of ControlPar result in a larger influence
755 | of points farther from the node.
756 | """
757 |
758 | X = _get_graph_data(adata, key)
759 |
760 | PG = elpigraph.ExtendLeaves(
761 | X.astype(float),
762 | PG=stream2elpi(adata, key),
763 | Mode=Mode,
764 | ControlPar=ControlPar,
765 | DoSA=DoSA,
766 | DoSA_maxiter=DoSA_maxiter,
767 | LeafIDs=LeafIDs,
768 | TrimmingRadius=TrimmingRadius,
769 | )
770 |
771 | adata.uns[key]["node_pos"] = PG["NodePositions"]
772 | adata.uns[key]["edge"] = PG["Edges"][0]
773 | _store_graph_attributes(adata, X, key)
774 |
775 | def grow_leaves(
776 | adata,
777 | n_nodes=20,
778 | use_weights=False,
779 | epg_mu=None,
780 | epg_lambda=None,
781 | epg_cycle_mu=None,
782 | epg_cycle_lambda=None,
783 | key="epg",
784 | ):
785 | """Grow leaves using elpigraph optimization.
786 |
787 | Parameters
788 | ----------
789 | n_nodes: int
790 | Number of additional nodes to grow from the leaves
791 | use_weights: bool
792 | Whether to weight points with adata.obs['pointweights']
793 | epg_mu: float
794 | ElPiGraph Mu parameter
795 | epg_lambda: float
796 | ElPiGraph Lambda parameter
797 | epg_cycle_mu: float
798 | ElPiGraph Mu parameter, specific for nodes that are part of cycles
799 | epg_cycle_lambda: float
800 | ElPiGraph Lambda parameter, specific for nodes that are part of cycles
801 | key: str, the adata.uns key of the graph to extend (default 'epg')
802 | """
803 | # --- Init parameters, variables
804 | if epg_mu is None:
805 | epg_mu = adata.uns[key]["params"]["epg_mu"]
806 | if epg_lambda is None:
807 | epg_lambda = adata.uns[key]["params"]["epg_lambda"]
808 | if epg_cycle_mu is None:
809 | epg_cycle_mu = epg_mu
810 | if epg_cycle_lambda is None:
811 | epg_cycle_lambda = epg_lambda
812 | if use_weights:
813 | weights = np.array(adata.obs["pointweights"])[:, None]
814 | else:
815 | weights = None
816 |
817 | X = _get_graph_data(adata, key)
818 | PG = elpigraph.GrowLeaves(
819 | X,
820 | 
NumNodes=n_nodes+len(adata.uns["epg"]["node_pos"]), 821 | InitNodePositions=adata.uns["epg"]["node_pos"], 822 | InitEdges=adata.uns["epg"]["edge"], 823 | PointWeights=weights, 824 | Mu=epg_mu, 825 | Lambda=epg_lambda, 826 | verbose=1, 827 | Do_PCA=False, 828 | CenterData=False 829 | )[0] 830 | 831 | adata.uns["epg"]["node_pos"] = PG["NodePositions"] 832 | adata.uns["epg"]["edge"] = PG["Edges"][0] 833 | 834 | # update edge_len, conn, data projection 835 | _store_graph_attributes(adata, X, key) 836 | 837 | def nodes_info(adata,key='epg'): 838 | '''Return dict of graph nodes classified into leaf, branching, branch 839 | ''' 840 | g = elpigraph.src.graphs.ConstructGraph(stream2elpi(adata, key=key)) 841 | leaf = np.where(np.array(g.degree()) == 1)[0] 842 | branching = np.where(np.array(g.degree()) > 2)[0] 843 | branch = np.where(np.array(g.degree()) == 2)[0] 844 | return {'leaf':leaf, 'branching':branching, 'branch':branch} 845 | 846 | def use_graph_with_n_nodes(adata, n_nodes): 847 | """Use the graph at n_nodes. 848 | This requires having run st2.tl.learn_graph with store_evolution=True 849 | """ 850 | 851 | adata.uns["epg"]["node_pos"] = adata.uns["epg"]["graph_evolution"][ 852 | "all_node_pos" 853 | ][n_nodes] 854 | adata.uns["epg"]["edge"] = elpigraph.src.core.DecodeElasticMatrix2( 855 | adata.uns["epg"]["graph_evolution"]["all_edge"][n_nodes] 856 | )[0] 857 | adata.uns["epg"]["conn"] = scipy.sparse.csr_matrix( 858 | adata.uns["epg"]["graph_evolution"]["all_edge"][n_nodes] 859 | ) 860 | X = _get_graph_data(adata, "epg") 861 | _store_graph_attributes(adata, X, "epg") 862 | 863 | 864 | def early_groups( 865 | adata, 866 | branch_nodes, 867 | source, 868 | target, 869 | nodes_to_include=None, 870 | flavor="ot_unbalanced", 871 | n_windows=20, 872 | n_neighbors=20, 873 | ot_reg_e=0.01, 874 | ot_reg_m=0.001, 875 | key="epg", 876 | ): 877 | """ 878 | Split data between source and target (with target a branching node) 879 | into n_windows slices along pseudotime. 880 | Then try to guess which branch the data prior 881 | to the branching most resembles. 882 | branch_nodes are adjacent to target and represent the separate branches. 
883 | Labels are propagated back in pseudotime for each of the n_windows slices
884 | (e.g., from branch_nodes to slice[n_windows-1],
885 | then from slice[n_windows-1] to slice[n_windows-2], etc.)
886 |
887 | Parameters
888 | ----------
889 | branch_nodes: list[int]
890 | List of node labels adjacent to target branch node
891 | source: int
892 | Root node label
893 | target: int
894 | Branching node label
895 | nodes_to_include: list[int]
896 | Nodes to include in the path between source and target
897 | flavor: str
898 | How to propagate labels from branch_nodes
899 | to the previous pseudotime slice
900 | "ot" for optimal transport
901 | "ot_unbalanced" for unbalanced OT
902 | "ot_equal" for OT with the weight of each branch node equalized
903 | "knn" for simple nearest-neighbor search
904 | n_windows: int
905 | How many slices along pseudotime to make
906 | with data between source and target
907 | n_neighbors: int
908 | Number of nearest neighbors used for label propagation
909 | ot_reg_e: float
910 | Unbalanced optimal transport entropic regularization parameter
911 | ot_reg_m: float
912 | Unbalanced optimal transport unbalanced parameter
913 | key: str
914 | Graph key
915 | """
916 | X = _get_graph_data(adata, key)
917 | PG = stream2elpi(adata, key)
918 | elpigraph.utils.early_groups(
919 | X,
920 | PG,
921 | branch_nodes=branch_nodes,
922 | source=source,
923 | target=target,
924 | nodes_to_include=nodes_to_include,
925 | flavor=flavor,
926 | n_windows=n_windows,
927 | n_neighbors=n_neighbors,
928 | ot_reg_e=ot_reg_e,
929 | ot_reg_m=ot_reg_m,
930 | )
931 |
932 | s = "-".join(str(x) for x in branch_nodes)
933 | adata.obs[f"early_groups_{source}->{s}"] = PG[
934 | f"early_groups_{source}->{s}"
935 | ]
936 | adata.obs[f"early_groups_{source}->{s}_clusters"] = PG[
937 | f"early_groups_{source}->{s}_clusters"
938 | ]
939 |
940 |
941 | def interpolate(
942 | adata,
943 | t_len=200,
944 | method="knn",
945 | frac=0.1,
946 | n_neighbors="auto",
947 | weights="uniform",
948 | key="epg",
949 | ):
950 | """Resample adata.X by interpolation along pseudotime with t_len values
951 |
952 | Parameters
953 | ----------
954 | t_len: int
955 | Number of pseudotime values to resample
956 | method: str
957 | 'knn' for sklearn.neighbors.KNeighborsRegressor
958 | 'lowess' for statsmodels.api.nonparametric.lowess (can be slow)
959 | frac: float 0-1
960 | lowess frac parameter
961 | n_neighbors: int
962 | KNeighborsRegressor n_neighbors parameter
963 | weights: str, 'uniform' or 'distance'
964 | KNeighborsRegressor weights parameter. 
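Example (a minimal sketch; it assumes pseudotime has already been
stored in adata.obs['epg_pseudotime'] by the pseudotime step, and
only uses the signature defined above):

>>> t_new, interp_X = interpolate(adata, t_len=100, method="knn")
>>> t_new.shape     # (100, 1) resampled pseudotime values
>>> interp_X.shape  # (100, adata.n_vars) resampled adata.X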
965 | 966 | Returns 967 | ------- 968 | t_new: np.array 969 | Resampled pseudotime values 970 | interp: np.array 971 | Resampled adata.X 972 | """ 973 | 974 | X = adata.X 975 | pseudotime = adata.obs[f"{key}_pseudotime"] 976 | 977 | idx_path = ~np.isnan(pseudotime) 978 | X_path = X[idx_path] 979 | 980 | t_path = np.array(pseudotime[idx_path]).reshape(-1, 1) 981 | t_new = np.linspace(pseudotime.min(), pseudotime.max(), t_len).reshape( 982 | -1, 1 983 | ) 984 | 985 | if method == "knn": 986 | if n_neighbors == "auto": 987 | n_neighbors = int(len(X_path) * 0.05) 988 | reg = KNeighborsRegressor(n_neighbors=n_neighbors, weights=weights) 989 | interp = reg.fit(X=t_path, y=X_path).predict(t_new) 990 | 991 | elif method == "lowess": # very slow 992 | interp = np.zeros((t_len, X_path.shape[1])) 993 | for i in range(X_path.shape[1]): 994 | interp[:, i] = statsmodels.api.nonparametric.lowess( 995 | X_path[:, i], 996 | t_path.flat, 997 | it=1, 998 | frac=frac, 999 | xvals=t_new.flat, 1000 | return_sorted=False, 1001 | ) 1002 | else: 1003 | raise ValueError("method must be one of 'knn','lowess'") 1004 | return t_new, interp 1005 | -------------------------------------------------------------------------------- /stream2/tools/_elpigraph.py: -------------------------------------------------------------------------------- 1 | """Functions to calculate principal graph.""" 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import scipy 6 | import elpigraph 7 | import networkx as nx 8 | from copy import deepcopy 9 | from sklearn.cluster import SpectralClustering, AffinityPropagation, KMeans 10 | from sklearn.metrics.pairwise import pairwise_distances, euclidean_distances 11 | 12 | from .._settings import settings 13 | 14 | 15 | def learn_graph( 16 | adata, 17 | method="principal_tree", 18 | obsm="X_dr", 19 | layer=None, 20 | n_nodes=50, 21 | epg_lambda=0.01, 22 | epg_mu=0.1, 23 | epg_alpha=0.02, 24 | epg_trimming_radius=float("inf"), 25 | use_seed=True, 26 | use_partition=False, 27 | use_weights=False, 28 | ordinal_label=None, 29 | ordinal_supervision_strength=1, 30 | ordinal_root_point=None, 31 | n_jobs=None, 32 | GPU=False, 33 | max_candidates={"AddNode2Node": 20, "BisectEdge": 20, "ShrinkEdge": 50}, 34 | store_evolution=False, 35 | **kwargs, 36 | ): 37 | """Learn principal graph. 38 | 39 | Parameters 40 | ---------- 41 | adata: `AnnData` 42 | Anndata object. 43 | method: `str`, (default: 'principal_curve'); 44 | Method used to calculate the graph. 45 | obsm: `str`, optional (default: 'X_dr') 46 | The multi-dimensional annotation of observations 47 | used to learn the graph 48 | layer: `str`, optional (default: None) 49 | The layer used to learn the graph 50 | use_seed: bool 51 | Whether to use the seed graph in adata.uns['seed_epg'] 52 | generated by st.seed_graph. 53 | If True, ignores obsm and layer parameters 54 | use_partition: bool 55 | Whether to learn a disconnected graph 56 | for each category in adata.uns['partition'] 57 | use_weights: bool 58 | Whether to weight points with adata.obs['pointweights'] 59 | GPU: 60 | Whether to perform computations using GPU (requires cupy library) 61 | max_candidates: 62 | Max number of candidates to generate with each graph grammar 63 | when exploring graph topology at each iteration. 64 | Setting numbers lower can increase speed, 65 | especially for higher number of nodes. 
66 | store_evolution: 67 | Store the evolution of the graph for each number of nodes 68 | **kwargs: 69 | Additional arguments to each method 70 | 71 | Returns 72 | ------- 73 | updates `adata.uns['epg']` with the following fields. 74 | conn: `sparse matrix` (`.uns['epg']['conn']`) 75 | A connectivity sparse matrix. 76 | node_pos: `array` (`.uns['epg']['node_pos']`) 77 | Node positions. 78 | edge: `array` (`.uns['epg']['edge']`) 79 | Node edges. 80 | """ 81 | if use_partition: 82 | print("Learning elastic principal graph for each partition...") 83 | if type(use_partition) is bool: 84 | partitions = adata.obs["partition"].unique() 85 | elif type(use_partition) is list: 86 | partitions = use_partition 87 | else: 88 | raise ValueError( 89 | "use_partition should be a bool or a list of partitions" 90 | ) 91 | 92 | if ordinal_label is not None: 93 | raise ValueError( 94 | "use_partition can't be used together with ordinal_label" 95 | ) 96 | if store_evolution: 97 | raise ValueError( 98 | "can't use store_evolution=True when use_partition=True" 99 | ) 100 | 101 | merged_nodep = [] 102 | merged_edges = [] 103 | num_edges = 0 104 | for part in adata.obs["partition"].unique(): 105 | 106 | if part not in partitions: 107 | p_adata = _subset_adata(adata, part) 108 | else: 109 | if use_seed: 110 | p_adata = _subset_adata(adata, part) 111 | if len(p_adata.uns["seed_epg"]["node_pos"]) < n_nodes: 112 | nnodes = n_nodes 113 | else: 114 | nnodes = len(p_adata.uns["seed_epg"]["node_pos"]) + 1 115 | else: 116 | p_adata = adata[adata.obs["partition"] == part].copy() 117 | 118 | _learn_graph( 119 | p_adata, 120 | method=method, 121 | obsm=obsm, 122 | layer=layer, 123 | n_nodes=nnodes, 124 | epg_lambda=epg_lambda, 125 | epg_mu=epg_mu, 126 | epg_alpha=epg_alpha, 127 | epg_trimming_radius=epg_trimming_radius, 128 | use_seed=use_seed, 129 | use_weights=use_weights, 130 | n_jobs=n_jobs, 131 | ordinal_label=ordinal_label, 132 | ordinal_supervision_strength=ordinal_supervision_strength, 133 | ordinal_root_point=[ordinal_root_point], 134 | GPU=GPU, 135 | max_candidates=max_candidates, 136 | store_evolution=store_evolution, 137 | **kwargs, 138 | ) 139 | 140 | merged_nodep.append(p_adata.uns["epg"]["node_pos"]) 141 | merged_edges.append(p_adata.uns["epg"]["edge"] + num_edges) 142 | num_edges += len(p_adata.uns["epg"]["node_pos"]) 143 | 144 | adata.uns["epg"] = {} 145 | adata.uns["epg"]["node_pos"] = np.concatenate(merged_nodep) 146 | adata.uns["epg"]["edge"] = np.concatenate((merged_edges)) 147 | adata.uns["epg"]["node_partition"] = np.repeat( 148 | adata.obs["partition"].unique(), 149 | [len(nodep) for nodep in merged_nodep], 150 | ).astype(str) 151 | adata.uns["epg"]["edge_partition"] = np.repeat( 152 | adata.obs["partition"].unique(), 153 | [len(edges) for edges in merged_edges], 154 | ).astype(str) 155 | adata.uns["epg"]["params"] = p_adata.uns["epg"]["params"] 156 | 157 | X = _get_graph_data(adata, key="epg") 158 | _store_graph_attributes(adata, X, key="epg") 159 | 160 | else: 161 | _learn_graph( 162 | adata, 163 | method=method, 164 | obsm=obsm, 165 | layer=layer, 166 | n_nodes=n_nodes, 167 | epg_lambda=epg_lambda, 168 | epg_mu=epg_mu, 169 | epg_alpha=epg_alpha, 170 | epg_trimming_radius=epg_trimming_radius, 171 | use_seed=use_seed, 172 | use_weights=use_weights, 173 | ordinal_label=ordinal_label, 174 | ordinal_supervision_strength=ordinal_supervision_strength, 175 | ordinal_root_point=[ordinal_root_point], 176 | n_jobs=n_jobs, 177 | GPU=GPU, 178 | max_candidates=max_candidates, 179 | 
store_evolution=store_evolution, 180 | **kwargs, 181 | ) 182 | 183 | 184 | def _learn_graph( 185 | adata, 186 | method="principal_tree", 187 | obsm="X_dr", 188 | layer=None, 189 | n_nodes=50, 190 | epg_lambda=0.01, 191 | epg_mu=0.1, 192 | epg_alpha=0.02, 193 | epg_trimming_radius=float("inf"), 194 | use_seed=True, 195 | use_weights=False, 196 | ordinal_label=None, 197 | ordinal_supervision_strength=1, 198 | ordinal_root_point=None, 199 | n_jobs=None, 200 | GPU=False, 201 | max_candidates={"AddNode2Node": 20, "BisectEdge": 20, "ShrinkEdge": 50}, 202 | store_evolution=False, 203 | **kwargs, 204 | ): 205 | """Learn principal graph. 206 | 207 | Parameters 208 | ---------- 209 | adata: `AnnData` 210 | Anndata object. 211 | method: `str`, (default: 'principal_curve'); 212 | Method used to calculate the graph. 213 | obsm: `str`, optional (default: 'X_dr') 214 | The multi-dimensional annotation of observations 215 | used to learn the graph 216 | layer: `str`, optional (default: None) 217 | The layer used to learn the graph 218 | use_seed: bool 219 | Whether to use the seed graph in adata.uns['seed_epg'] 220 | generated by st.seed_graph. 221 | If True, ignores obsm and layer parameters 222 | **kwargs: 223 | Additional arguments to each method 224 | 225 | Returns 226 | ------- 227 | updates `adata.uns['epg']` with the following fields. 228 | conn: `sparse matrix` (`.uns['epg']['conn']`) 229 | A connectivity sparse matrix. 230 | node_pos: `array` (`.uns['epg']['node_pos']`) 231 | Node positions. 232 | """ 233 | 234 | assert method in [ 235 | "principal_curve", 236 | "principal_tree", 237 | "principal_circle", 238 | ], ( 239 | "`method` must be one of " 240 | "['principal_curve','principal_tree','principal_circle']" 241 | ) 242 | 243 | if use_seed and (method == "principal_tree"): 244 | if "seed_epg" not in adata.uns: 245 | raise ValueError( 246 | "could not find 'seed_epg' in `adata.uns. Please run" 247 | " st.tl.seed_graph" 248 | ) 249 | if n_nodes <= len(adata.uns["seed_epg"]["node_pos"]): 250 | raise ValueError( 251 | f"The seed graph already has at least {n_nodes} nodes. 
Please"
252 | " run st.tl.learn_graph with higher n_nodes"
253 | )
254 | kwargs["InitNodePositions"] = adata.uns["seed_epg"]["node_pos"]
255 | kwargs["InitEdges"] = adata.uns["seed_epg"]["edge"]
256 | if adata.uns["seed_epg"]["params"]["obsm"] is not None:
257 | mat = adata.obsm[adata.uns["seed_epg"]["params"]["obsm"]]
258 | elif adata.uns["seed_epg"]["params"]["layer"] is not None:
259 | mat = adata.layers[adata.uns["seed_epg"]["params"]["layer"]]
260 | else:
261 | print("Learning the graph on adata.X")
262 | mat = adata.X
263 | else:
264 | if (
265 | use_seed
266 | and (method != "principal_tree")
267 | and ("seed_epg" in adata.uns)
268 | ):
269 | print(f"WARNING: seed graph is ignored when using method {method}")
270 |
271 | kwargs["InitNodePositions"] = None
272 | kwargs["InitEdges"] = None
273 | if sum(list(map(lambda x: x is not None, [layer, obsm]))) == 2:
274 | raise ValueError("Only one of `layer` and `obsm` can be used")
275 | elif obsm is not None:
276 | if obsm in adata.obsm:
277 | mat = adata.obsm[obsm]
278 | else:
279 | raise ValueError(f"could not find {obsm} in `adata.obsm`")
280 | elif layer is not None:
281 | if layer in adata.layers:
282 | mat = adata.layers[layer]
283 | else:
284 | raise ValueError(f"could not find {layer} in `adata.layers`")
285 | else:
286 | mat = adata.X
287 |
288 | if n_jobs is None:
289 | n_jobs = settings.n_jobs
290 |
291 | if use_weights:
292 | if "pointweights" not in adata.obs:
293 | raise ValueError(
294 | "adata.obs['pointweights'] not found. Please run"
295 | " st2.tl.get_weights"
296 | )
297 | weights = np.array(adata.obs["pointweights"]).reshape((-1, 1))
298 | else:
299 | weights = None
300 |
301 | if ordinal_label is not None:
302 | if isinstance(ordinal_label, str) and ordinal_label in adata.obs:
303 | kwargs["pseudotime"] = adata.obs[ordinal_label].to_numpy()
304 | else:
305 | raise ValueError(f"ordinal_label key {ordinal_label} not found in adata.obs")
306 | kwargs["pseudotimeLambda"] = ordinal_supervision_strength
307 | kwargs["FixNodesAtPoints"] = [ordinal_root_point]
308 |
309 | if method == "principal_curve":
310 | dict_epg = elpigraph.computeElasticPrincipalCurve(
311 | X=mat,
312 | NumNodes=n_nodes,
313 | TrimmingRadius=epg_trimming_radius,
314 | n_cores=n_jobs,
315 | Do_PCA=False,
316 | CenterData=False,
317 | Lambda=epg_lambda,
318 | Mu=epg_mu,
319 | alpha=epg_alpha,
320 | PointWeights=weights,
321 | GPU=GPU,
322 | MaxNumberOfGraphCandidatesDict=max_candidates,
323 | StoreGraphEvolution=store_evolution,
324 | **kwargs,
325 | )[0]
326 | if method == "principal_tree":
327 | dict_epg = elpigraph.computeElasticPrincipalTree(
328 | X=mat,
329 | NumNodes=n_nodes,
330 | TrimmingRadius=epg_trimming_radius,
331 | n_cores=n_jobs,
332 | Do_PCA=False,
333 | CenterData=False,
334 | Lambda=epg_lambda,
335 | Mu=epg_mu,
336 | alpha=epg_alpha,
337 | PointWeights=weights,
338 | GPU=GPU,
339 | MaxNumberOfGraphCandidatesDict=max_candidates,
340 | StoreGraphEvolution=store_evolution,
341 | **kwargs,
342 | )[0]
343 | if method == "principal_circle":
344 | dict_epg = elpigraph.computeElasticPrincipalCircle(
345 | X=mat,
346 | NumNodes=n_nodes,
347 | TrimmingRadius=epg_trimming_radius,
348 | n_cores=n_jobs,
349 | Do_PCA=False,
350 | CenterData=False,
351 | Lambda=epg_lambda,
352 | Mu=epg_mu,
353 | alpha=epg_alpha,
354 | PointWeights=weights,
355 | InitNodes=3,
356 | GPU=GPU,
357 | MaxNumberOfGraphCandidatesDict=max_candidates,
358 | StoreGraphEvolution=store_evolution,
359 | **kwargs,
360 | )[0]
361 |
362 | adata.uns["epg"] = dict()
363 |
364 | adata.uns["epg"]["node"] = np.arange(n_nodes)
365 | 
adata.uns["epg"]["node_pos"] = dict_epg["NodePositions"] 366 | adata.uns["epg"]["edge"] = dict_epg["Edges"][0] 367 | adata.uns["epg"]["params"] = { 368 | "method": method, 369 | "obsm": obsm, 370 | "layer": layer, 371 | "n_nodes": n_nodes, 372 | "epg_lambda": epg_lambda, 373 | "epg_mu": epg_mu, 374 | "epg_alpha": epg_alpha, 375 | "use_seed": use_seed, 376 | } 377 | if store_evolution: 378 | adata.uns["epg"]["graph_evolution"] = { 379 | "all_node_pos": dict_epg["AllNodePositions"], 380 | "all_edge": dict_epg["AllElasticMatrices"], 381 | } 382 | _store_graph_attributes(adata, mat, key="epg") 383 | 384 | 385 | def seed_graph( 386 | adata, 387 | obsm="X_dr", 388 | layer=None, 389 | clustering="kmeans", 390 | damping=0.75, 391 | pref_perc=50, 392 | n_clusters=10, 393 | max_n_clusters=200, 394 | n_neighbors=50, 395 | nb_pct=None, 396 | paths_favored=[], 397 | paths_disfavored=[], 398 | label=None, 399 | label_strength=2, 400 | force=False, 401 | use_weights=False, 402 | use_partition=False, 403 | ): 404 | """Seeding the initial elastic principal graph. 405 | 406 | Parameters 407 | ---------- 408 | adata: AnnData 409 | Annotated data matrix. 410 | obsm: `str`, optional (default: 'X_dr') 411 | The multi-dimensional annotation of observations 412 | used to learn the graph 413 | layer: `str`, optional (default: None) 414 | The layer used to learn the graph 415 | init_nodes_pos: `array`, shape = [n_nodes,n_dimension], 416 | optional (default: `None`) 417 | initial node positions 418 | init_edges: `array`, shape = [n_edges,2], optional (default: `None`) 419 | initial edges, all the initial nodes should be included 420 | in the tree structure 421 | clustering: `str`, optional (default: 'kmeans') 422 | Choose from {{'ap','kmeans','sc'}} 423 | clustering method used to infer the initial nodes. 424 | 'ap' affinity propagation 425 | 'kmeans' K-Means clustering 426 | 'sc' spectral clustering 427 | damping: `float`, optional (default: 0.75) 428 | Damping factor (between 0.5 and 1) for affinity propagation. 429 | pref_perc: `int`, optional (default: 50) 430 | Preference percentile (between 0 and 100). 431 | The percentile of the input similarities for affinity propagation. 432 | n_clusters: `int`, optional (default: 10) 433 | Number of clusters (only valid once 'clustering' 434 | is specified as 'sc' or 'kmeans'). 435 | max_n_clusters: `int`, optional (default: 200) 436 | The allowed maximum number of clusters for 'ap'. 437 | n_neighbors: `int`, optional (default: 50) 438 | The number of neighbor cells used for spectral clustering. 439 | nb_pct: `float`, optional (default: None) 440 | The percentage of neighbor cells 441 | (when specified, it will overwrite n_neighbors). 
442 | paths_favored: list of lists, optional (default: []) 443 | Favored paths between categorical labels used 444 | for supervised MST initialization 445 | paths_disfavored: list of lists, optional (default: []) 446 | Disfavored paths between categorical labels 447 | used for supervised MST initialization 448 | label: `str`, optional (default: None) 449 | Categorical labels for supervised MST initialization 450 | label_strength: float in [1,oo) 451 | Strength of supervised MST initialization 452 | force: bool 453 | (experimental feature) 454 | Force supervised MST initialization to follow 455 | specified paths rather than using soft constraint 456 | use_weights: bool 457 | Whether to weight points with adata.obs['pointweights'] 458 | use_partition: bool 459 | Whether to learn a disconnected graph 460 | for each category in adata.uns['partition'] 461 | 462 | Returns 463 | ------- 464 | adata.obs['clustering']: `pandas.core.series.Series` 465 | (`adata.obs['clustering']`,dtype `str`) 466 | Array of dim (number of samples) that stores 467 | the clustering labels ('0', '1', …) for each cell. 468 | adata.uns['seed_epg'] : dict 469 | Elastic principal graph structure. 470 | 471 | """ 472 | 473 | if use_partition: 474 | print("Seeding initial graph for each partition...") 475 | if type(use_partition) is bool: 476 | partitions = adata.obs["partition"].unique() 477 | elif type(use_partition) is list: 478 | partitions = use_partition 479 | else: 480 | raise ValueError( 481 | "use_partition should be a bool or a list of partitions" 482 | ) 483 | 484 | merged_nodep = [] 485 | merged_edges = [] 486 | num_edges = 0 487 | 488 | for part in adata.obs["partition"].unique(): 489 | if type(use_partition) is list: 490 | p_adata = _subset_adata(adata, part) 491 | else: 492 | p_adata = adata[adata.obs["partition"] == part].copy() 493 | 494 | if part in partitions: 495 | _seed_graph( 496 | p_adata, 497 | obsm=obsm, 498 | layer=layer, 499 | clustering=clustering, 500 | damping=damping, 501 | pref_perc=pref_perc, 502 | n_clusters=n_clusters, 503 | max_n_clusters=max_n_clusters, 504 | n_neighbors=n_neighbors, 505 | nb_pct=nb_pct, 506 | paths_favored=paths_favored, 507 | paths_disfavored=paths_disfavored, 508 | label=label, 509 | label_strength=label_strength, 510 | force=force, 511 | use_weights=use_weights, 512 | verbose=False, 513 | ) 514 | 515 | merged_nodep.append(p_adata.uns["seed_epg"]["node_pos"]) 516 | merged_edges.append(p_adata.uns["seed_epg"]["edge"] + num_edges) 517 | num_edges += len(p_adata.uns["seed_epg"]["node_pos"]) 518 | 519 | adata.uns["seed_epg"] = {} 520 | adata.uns["seed_epg"]["node_pos"] = np.concatenate(merged_nodep) 521 | adata.uns["seed_epg"]["edge"] = np.concatenate((merged_edges)) 522 | adata.uns["seed_epg"]["node_partition"] = np.repeat( 523 | adata.obs["partition"].unique(), 524 | [len(nodep) for nodep in merged_nodep], 525 | ).astype(str) 526 | adata.uns["seed_epg"]["edge_partition"] = np.repeat( 527 | adata.obs["partition"].unique(), 528 | [len(edges) for edges in merged_edges], 529 | ).astype(str) 530 | adata.uns["seed_epg"]["params"] = p_adata.uns["seed_epg"]["params"] 531 | 532 | X = _get_graph_data(adata, key="seed_epg") 533 | _store_graph_attributes(adata, X, key="seed_epg") 534 | 535 | else: 536 | _seed_graph( 537 | adata, 538 | obsm=obsm, 539 | layer=layer, 540 | clustering=clustering, 541 | damping=damping, 542 | pref_perc=pref_perc, 543 | n_clusters=n_clusters, 544 | max_n_clusters=max_n_clusters, 545 | n_neighbors=n_neighbors, 546 | nb_pct=nb_pct, 547 | 
paths_favored=paths_favored,
548 | paths_disfavored=paths_disfavored,
549 | label=label,
550 | label_strength=label_strength,
551 | force=force,
552 | use_weights=use_weights,
553 | )
554 |
555 |
556 | def _seed_graph(
557 | adata,
558 | obsm="X_dr",
559 | layer=None,
560 | clustering="kmeans",
561 | damping=0.75,
562 | pref_perc=50,
563 | n_clusters=10,
564 | max_n_clusters=200,
565 | n_neighbors=50,
566 | nb_pct=None,
567 | paths_favored=[],
568 | paths_disfavored=[],
569 | label=None,
570 | label_strength=2,
571 | force=False,
572 | use_weights=False,
573 | verbose=True,
574 | ):
575 | """Internal method to seed_graph"""
576 |
577 | if verbose:
578 | print("Seeding initial graph...")
579 |
580 | if sum(list(map(lambda x: x is not None, [layer, obsm]))) == 2:
581 | raise ValueError("Only one of `layer` and `obsm` can be used")
582 | elif obsm is not None:
583 | if obsm in adata.obsm:
584 | mat = adata.obsm[obsm]
585 | adata.uns["seed"] = obsm
586 | else:
587 | raise ValueError(f"could not find {obsm} in `adata.obsm`")
588 | elif layer is not None:
589 | if layer in adata.layers:
590 | mat = adata.layers[layer]
591 | adata.uns["seed"] = layer
592 | else:
593 | raise ValueError(f"could not find {layer} in `adata.layers`")
594 | else:
595 | mat = adata.X
596 |
597 | if nb_pct is not None:
598 | n_neighbors = int(np.around(mat.shape[0] * nb_pct))
599 |
600 | if label_strength < 1:
601 | raise ValueError("label_strength should be >=1")
602 |
603 | if verbose:
604 | print("Clustering...")
605 | if clustering == "ap":
606 | if verbose:
607 | print("Affinity propagation ...")
608 | ap = AffinityPropagation(
609 | damping=damping,
610 | random_state=42,
611 | preference=np.percentile(
612 | -euclidean_distances(mat, squared=True), pref_perc
613 | ),
614 | ).fit(mat)
615 | # ap = AffinityPropagation(damping=damping).fit(mat)
616 | if ap.cluster_centers_.shape[0] > max_n_clusters:
617 | if verbose:
618 | print(
619 | "The number of clusters is "
620 | + str(ap.cluster_centers_.shape[0])
621 | )
622 | if verbose:
623 | print(
624 | "Too many clusters are generated, please lower pref_perc"
625 | " or increase damping and retry it"
626 | )
627 | return
628 | cluster_labels = ap.labels_
629 | init_nodes_pos = ap.cluster_centers_
630 | elif clustering == "sc":
631 | if verbose:
632 | print("Spectral clustering ...")
633 | sc = SpectralClustering(
634 | n_clusters=n_clusters,
635 | affinity="nearest_neighbors",
636 | n_neighbors=n_neighbors,
637 | eigen_solver="arpack",
638 | random_state=42,
639 | ).fit(mat)
640 | cluster_labels = sc.labels_
641 | init_nodes_pos = np.empty((0, mat.shape[1])) # cluster centers
642 | for x in np.unique(cluster_labels):
643 | id_cells = np.array(range(mat.shape[0]))[cluster_labels == x]
644 | init_nodes_pos = np.vstack(
645 | (init_nodes_pos, np.median(mat[id_cells, :], axis=0))
646 | )
647 | elif clustering == "kmeans":
648 | if verbose:
649 | print("K-Means clustering ...")
650 | if use_weights:
651 | if "pointweights" not in adata.obs:
652 | raise ValueError(
653 | "adata.obs['pointweights'] not found. 
Please run"
654 | " st2.tl.get_weights"
655 | )
656 | weights = np.array(adata.obs["pointweights"]).flatten()
657 | else:
658 | weights = None
659 | kmeans = KMeans(
660 | n_clusters=n_clusters, init="k-means++",
661 | n_init=10, max_iter=300, tol=0.0001,
662 | algorithm='lloyd', random_state=42
663 | ).fit(mat, sample_weight=weights)
664 | cluster_labels = kmeans.labels_
665 | init_nodes_pos = kmeans.cluster_centers_
666 | else:
667 | raise ValueError(
668 | f"clustering method '{clustering}' is not supported")
669 | adata.obs[clustering] = ["cluster " + str(x) for x in cluster_labels]
670 |
671 | # Minimum Spanning Tree ###
672 | if verbose:
673 | print("Calculating minimum spanning tree...")
674 |
675 | # ---if supervised adjacency matrix option
676 | if (
677 | ((len(paths_favored) > 0) or (len(paths_disfavored) > 0))
678 | and label is None
679 | ) or (
680 | ((len(paths_favored) == 0) and (len(paths_disfavored) == 0))
681 | and label is not None
682 | ):
683 | raise ValueError(
684 | "Both a label key (label: str) and cluster paths (paths: list of"
685 | " list) need to be provided for path-supervised initialization"
686 | )
687 | elif (
688 | (len(paths_favored) > 0) or (len(paths_disfavored) > 0)
689 | ) and label is not None:
690 | (
691 | init_nodes_pos,
692 | clus_adjmat,
693 | adjmat_strength,
694 | num_modes,
695 | num_labels,
696 | labels_ignored,
697 | ) = _categorical_adjmat(
698 | mat,
699 | init_nodes_pos,
700 | paths_favored,
701 | paths_disfavored,
702 | adata.obs[label],
703 | label_strength,
704 | )
705 | D = pairwise_distances(init_nodes_pos)
706 | G = nx.from_numpy_array(D * clus_adjmat)
707 |
708 | # ---else unsupervised
709 | else:
710 | D = pairwise_distances(init_nodes_pos)
711 | G = nx.from_numpy_array(D)
712 |
713 | # ---get edges from mst
714 | mst = nx.minimum_spanning_tree(G, ignore_nan=True)
715 | init_edges = np.array(mst.edges())
716 | if force and label is not None:
717 | init_edges = _force_missing_connections(
718 | D, num_labels, num_modes, init_edges, paths_favored, clus_adjmat
719 | )
720 |
721 | # Store results ###
722 | adata.uns["seed_epg"] = dict()
723 | adata.uns["seed_epg"]["node_pos"] = init_nodes_pos
724 | adata.uns["seed_epg"]["edge"] = init_edges
725 | adata.uns["seed_epg"]["params"] = dict(
726 | obsm=obsm,
727 | layer=layer,
728 | clustering=clustering,
729 | damping=damping,
730 | pref_perc=pref_perc,
731 | n_clusters=n_clusters,
732 | max_n_clusters=max_n_clusters,
733 | n_neighbors=n_neighbors,
734 | nb_pct=nb_pct,
735 | )
736 | _store_graph_attributes(adata, mat, key="seed_epg")
737 |
738 |
739 | def _store_graph_attributes(adata, mat, key):
740 | """Compute graph attributes and store them in adata.uns[key]"""
741 |
742 | G = nx.Graph()
743 | G.add_edges_from(adata.uns[key]["edge"].tolist(), weight=1)
744 | mat_conn = nx.to_scipy_sparse_array(
745 | G,
746 | nodelist=np.arange(len(adata.uns[key]["node_pos"])),
747 | weight="weight",
748 | )
749 |
750 | # partition points
751 | node_id, node_dist = elpigraph.src.core.PartitionData(
752 | X=mat,
753 | NodePositions=adata.uns[key]["node_pos"],
754 | MaxBlockSize=len(adata.uns[key]["node_pos"]) ** 4,
755 | SquaredX=np.sum(mat ** 2, axis=1, keepdims=1),
756 | )
757 | # project points onto edges
758 | dict_proj = elpigraph.src.reporting.project_point_onto_graph(
759 | X=mat,
760 | NodePositions=adata.uns[key]["node_pos"],
761 | Edges=adata.uns[key]["edge"],
762 | Partition=node_id,
763 | )
764 | edge_dist = np.linalg.norm(
765 | mat - dict_proj['X_projected'],
766 | axis=1)
767 |
768 | 
adata.obs[f"{key}_node_id"] = node_id.flatten() 769 | adata.obs[f"{key}_node_dist"] = node_dist 770 | adata.obs[f"{key}_edge_id"] = dict_proj["EdgeID"].astype(int) 771 | adata.obs[f"{key}_edge_loc"] = dict_proj["ProjectionValues"] 772 | adata.obs[f"{key}_edge_dist"] = edge_dist 773 | 774 | # adata.obsm[f"X_{key}_proj"] = dict_proj["X_projected"] 775 | 776 | adata.uns[key]["conn"] = mat_conn 777 | adata.uns[key]["edge_len"] = dict_proj["EdgeLen"] 778 | 779 | 780 | def _get_branch_id(adata, key="epg"): 781 | """add adata.obs['branch_id']""" 782 | # get branches 783 | net = elpigraph.src.graphs.ConstructGraph( 784 | {"Edges": [adata.uns[key]["edge"]]} 785 | ) 786 | branches = elpigraph.src.graphs.GetSubGraph(net, "branches") 787 | _dict_branches = { 788 | (b[0], b[-1]): b for i, b in enumerate(branches) 789 | } # temporary branch node lists (not in order) 790 | 791 | ordered_edges, ordered_nodes = elpigraph.src.supervised.bf_search( 792 | _dict_branches, root_node=np.where(np.array(net.degree()) == 1)[0][0] 793 | ) 794 | # create ordered dict 795 | dict_branches = {} 796 | for i, e in enumerate(ordered_edges): # for each branch 797 | # store branch in order (both the key and the list) 798 | if e not in _dict_branches: 799 | dict_branches[e] = _dict_branches[e[::-1]][::-1] 800 | else: 801 | dict_branches[e] = _dict_branches[e] 802 | 803 | # disable warning 804 | pd.options.mode.chained_assignment = None 805 | 806 | point_edges = adata.uns[key]["edge"][adata.obs[f"{key}_edge_id"]] 807 | adata.obs[f"{key}_branch_id"] = "" 808 | for i, e in enumerate(point_edges): 809 | for k, v in dict_branches.items(): 810 | if all(np.isin(e, v)): 811 | adata.obs[f"{key}_branch_id"][i] = k 812 | 813 | # reactivate warning 814 | pd.options.mode.chained_assignment = "warn" 815 | 816 | 817 | # Categorical MST initialization utils ## 818 | 819 | 820 | def _force_missing_connections( 821 | D, num_labels, num_modes, init_edges, paths_favored, clus_adjmat 822 | ): 823 | found_missing = True 824 | while found_missing: 825 | 826 | found_missing = False 827 | edges_labels = np.array(list(num_labels.keys()))[num_modes][ 828 | init_edges 829 | ].tolist() 830 | for path in paths_favored: 831 | for i in range(len(path) - 1): 832 | if [path[i], path[i + 1]] not in edges_labels and [ 833 | path[i + 1], 834 | path[i], 835 | ] not in edges_labels: 836 | print(path[i], path[i + 1]) 837 | print(num_labels[path[i]], num_labels[path[i + 1]]) 838 | 839 | missing_is = np.where(num_modes == num_labels[path[i]])[0] 840 | missing_js = np.where( 841 | num_modes == num_labels[path[i + 1]] 842 | )[0] 843 | x = D[missing_is[:, None], missing_js] 844 | _i, _j = np.where(x == x.min()) 845 | mi, mj = missing_is[_i], missing_js[_j] 846 | D[mi, mj] = D[mj, mi] = -1.0 847 | found_missing = True 848 | 849 | # ---get edges from mst 850 | G = nx.from_numpy_array(D * clus_adjmat) 851 | mst = nx.minimum_spanning_tree(G, ignore_nan=True) 852 | init_edges = np.array(mst.edges()) 853 | return init_edges 854 | 855 | 856 | def _get_partition_modes(mat, init_nodes_pos, labels): 857 | """Return most frequent label assigned to each node.""" 858 | labels = np.array(labels) 859 | part = elpigraph.src.core.PartitionData( 860 | mat, init_nodes_pos, 10 ** 6, np.sum(mat ** 2, axis=1, keepdims=1) 861 | )[0].flatten() 862 | modes = np.empty(len(init_nodes_pos), dtype=labels.dtype) 863 | 864 | for i in range(len(init_nodes_pos)): 865 | modes[i] = scipy.stats.mode(labels[part == i]).mode[0] 866 | return modes 867 | 868 | 869 | def _get_labels_adjmat(labels_u, 
labels_ignored, paths_favored, 870 | paths_disfavored, label_strength): 871 | """Create adjmat given labels and paths. 872 | 873 | labels_ignored are connected to all other labels 874 | """ 875 | num_labels = { 876 | s: i for i, s in enumerate(np.append(labels_u, labels_ignored)) 877 | } 878 | len_labels = len(labels_u) + len(labels_ignored) 879 | adjmat = np.ones((len_labels, len_labels)) 880 | 881 | # allow connections given from paths 882 | for p in paths_favored: 883 | for i in range(len(p) - 1): 884 | adjmat[num_labels[p[i]], num_labels[p[i + 1]]] = adjmat[ 885 | num_labels[p[i + 1]], num_labels[p[i]] 886 | ] = 1/label_strength 887 | 888 | # remove forbidden connections given from paths_disfavored 889 | for p in paths_disfavored: 890 | for i in range(len(p) - 1): 891 | adjmat[num_labels[p[i]], num_labels[p[i + 1]]] = adjmat[ 892 | num_labels[p[i + 1]], num_labels[p[i]] 893 | ] = label_strength 894 | 895 | return adjmat, num_labels 896 | 897 | def _get_clus_adjmat(adjmat_strength, num_modes, n_clusters): 898 | """Create clus_adjmat given labels adjmat 899 | and kmeans label assignment.""" 900 | 901 | adjmat_clus = np.ones((n_clusters, n_clusters)) 902 | 903 | for ei in range(len(adjmat_strength)): 904 | for ej in range(len(adjmat_strength)): 905 | clus_ei = np.where(num_modes == ei)[0] 906 | clus_ej = np.where(num_modes == ej)[0] 907 | adjmat_clus[ 908 | clus_ei[:, None], 909 | np.repeat(clus_ej[None], len(clus_ei), axis=0), 910 | ] = adjmat_strength[ei, ej] 911 | return adjmat_clus 912 | 913 | 914 | def _categorical_adjmat( 915 | mat, init_nodes_pos, paths_favored, 916 | paths_disfavored, labels, label_strength 917 | ): 918 | """Main function, create categorical adjmat given 919 | node positions, cluster paths, point labels.""" 920 | 921 | labels_u = np.unique([c for p in paths_favored for c in p]) 922 | labels_ignored = np.setdiff1d(labels, labels_u) 923 | # label adjacency matrix 924 | adjmat_strength, num_labels = _get_labels_adjmat( 925 | labels_u, labels_ignored, paths_favored, 926 | paths_disfavored, label_strength 927 | ) 928 | # assign label to nodes 929 | modes = _get_partition_modes(mat, init_nodes_pos, labels) 930 | num_modes = np.array([num_labels[m] for m in modes]) 931 | 932 | # add centroids if necessary to prevent bug 933 | # (if some label has no kmean assigned to it) 934 | labels_miss = np.setdiff1d(labels_u, modes) 935 | if len(labels_miss) > 0: 936 | print( 937 | f"Found label(s) {labels_miss} with no representative node. 
Adding" 938 | " label centroid(s) as node(s)" 939 | ) 940 | centroids = np.vstack( 941 | [mat[labels == s].mean(axis=0) for s in labels_miss] 942 | ) 943 | init_nodes_pos = np.vstack((init_nodes_pos, centroids)) 944 | modes = np.hstack((modes, labels_miss)) 945 | num_modes = np.array([num_labels[m] for m in modes]) 946 | 947 | # nodes adjacency matrix 948 | clus_adjmat = _get_clus_adjmat( 949 | adjmat_strength, 950 | num_modes, 951 | n_clusters=len(init_nodes_pos), 952 | ) 953 | return ( 954 | init_nodes_pos, 955 | clus_adjmat, 956 | adjmat_strength, 957 | num_modes, 958 | num_labels, 959 | labels_ignored, 960 | ) 961 | 962 | 963 | 964 | 965 | 966 | 967 | def _get_graph_data(adata, key): 968 | """get data matrix used to learn the graph.""" 969 | obsm = adata.uns[key]["params"]["obsm"] 970 | layer = adata.uns[key]["params"]["layer"] 971 | 972 | if obsm is not None: 973 | if obsm in adata.obsm: 974 | mat = adata.obsm[obsm] 975 | else: 976 | raise ValueError(f"could not find {obsm} in `adata.obsm`") 977 | elif layer is not None: 978 | if layer in adata.layers: 979 | mat = adata.layers[layer] 980 | else: 981 | raise ValueError(f"could not find {layer} in `adata.layers`") 982 | else: 983 | mat = adata.X 984 | return mat 985 | 986 | 987 | def _subset_adata(adata, part): 988 | p_adata = adata[adata.obs["partition"] == part].copy() 989 | for key in ["seed_epg", "epg"]: 990 | if key in p_adata.uns: 991 | p_adata.uns[key] = deepcopy(adata.uns[key]) 992 | p_adata.uns[key]["node_pos"] = p_adata.uns[key]["node_pos"][ 993 | p_adata.uns[key]["node_partition"] == part 994 | ] 995 | p_adata.uns[key]["edge"] = p_adata.uns[key]["edge"][ 996 | p_adata.uns[key]["edge_partition"] == part 997 | ] 998 | p_adata.uns[key]["edge"] -= p_adata.uns[key]["edge"].min() 999 | return p_adata 1000 | --------------------------------------------------------------------------------
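The graph utilities and the seeding/learning functions above are typically chained as a small pipeline: compute optional point weights, detect disconnected partitions, seed an initial tree, fit the elastic principal graph, then refine it. The following is a minimal sketch of that chain, not a prescribed workflow: it assumes an AnnData object `adata` that already carries a low-dimensional embedding in `adata.obsm["X_dr"]` (the default `obsm` used throughout these modules), that scanpy is available for the neighbors/leiden step required by `find_disconnected_components`, and that these tools are re-exported under `st2.tl`; all variable names are illustrative.

import scanpy as sc
import stream2 as st2

# adata: anndata.AnnData with an embedding already stored in adata.obsm["X_dr"]

# Optional: density-based point weights, stored in adata.obs["pointweights"]
st2.tl.get_weights(adata, obsm="X_dr")

# Optional: detect disconnected components so a separate graph is learned
# per partition (find_disconnected_components runs PAGA on a clustering)
sc.pp.neighbors(adata, use_rep="X_dr")
sc.tl.leiden(adata)
st2.tl.find_disconnected_components(adata, groups="leiden")

# Seed an initial tree (clustering + minimum spanning tree), then fit the
# elastic principal graph starting from that seed, one graph per partition
st2.tl.seed_graph(adata, obsm="X_dr", n_clusters=10, use_partition=True)
st2.tl.learn_graph(
    adata,
    method="principal_tree",
    obsm="X_dr",
    n_nodes=50,
    use_seed=True,
    use_partition=True,
    use_weights=True,
)

# Optional post-hoc refinement of the fitted graph in adata.uns["epg"]
st2.tl.extend_leaves(adata, Mode="QuantDists")

After these calls the node positions, edges and point-to-graph projections live in `adata.uns["epg"]` and the per-cell assignments (node id, edge id, distances) in `adata.obs`, as set up by `_store_graph_attributes`.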