├── ChEA
│   └── ChEA_2016.txt
├── ENCODE.zip
├── README.md
├── TFvelo
│   ├── __init__.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── _anndata.py
│   │   ├── _arithmetic.py
│   │   ├── _base.py
│   │   ├── _linear_models.py
│   │   ├── _metrics.py
│   │   ├── _models.py
│   │   ├── _parallelize.py
│   │   └── tests
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-37.pyc
│   │       │   ├── test_anndata.cpython-37.pyc
│   │       │   ├── test_arithmetic.cpython-37.pyc
│   │       │   ├── test_base.cpython-37.pyc
│   │       │   ├── test_linear_models.cpython-37.pyc
│   │       │   ├── test_metrics.cpython-37.pyc
│   │       │   └── test_models.cpython-37.pyc
│   │       ├── test_anndata.py
│   │       ├── test_arithmetic.py
│   │       ├── test_base.py
│   │       ├── test_linear_models.py
│   │       ├── test_metrics.py
│   │       └── test_models.py
│   ├── datasets.py
│   ├── logging.py
│   ├── pl.py
│   ├── plotting
│   │   ├── __init__.py
│   │   ├── docs.py
│   │   ├── gridspec.py
│   │   ├── heatmap.py
│   │   ├── paga.py
│   │   ├── palettes.py
│   │   ├── proportions.py
│   │   ├── pseudotime.py
│   │   ├── scatter.py
│   │   ├── simulation.py
│   │   ├── utils.py
│   │   ├── velocity.py
│   │   ├── velocity_embedding.py
│   │   ├── velocity_embedding_grid.py
│   │   ├── velocity_embedding_stream.py
│   │   └── velocity_graph.py
│   ├── pp.py
│   ├── preprocessing
│   │   ├── __init__.py
│   │   ├── moments.py
│   │   ├── neighbors.py
│   │   └── utils.py
│   ├── read_load.py
│   ├── settings.py
│   ├── tl.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── _velocity_graph.py
│   │   ├── dynamical_model.py
│   │   ├── dynamical_model_utils.py
│   │   ├── paga.py
│   │   ├── rank_velocity_genes.py
│   │   ├── terminal_states.py
│   │   ├── transition_matrix.py
│   │   ├── utils.py
│   │   ├── velocity_confidence.py
│   │   ├── velocity_embedding.py
│   │   ├── velocity_graph.py
│   │   └── velocity_pseudotime.py
│   └── utils.py
├── TFvelo_analysis_demo.py
├── TFvelo_demo.ipynb
├── TFvelo_run_demo.py
├── baselines
│   ├── MultiVelo_run.ipynb
│   ├── Multivelo_analysis.ipynb
│   ├── baseline_TI_demo.py
│   ├── baseline_cellDancer_Demo.py
│   ├── baseline_dynamo_demo.py
│   ├── baseline_scvelo_demo.py
│   ├── baseline_unitvelo_demo.py
│   ├── compare_baselines_metrics_demo.py
│   └── compare_baselines_phase_demo.py
├── data
│   ├── 10x_mouse_brain
│   │   └── adata_rna.h5ad
│   └── TF_names_v_1.01.txt
├── figures
│   └── demo.png
└── simulation
    ├── TFvelo_synthetic_demo.py
    └── simulate_phase_delay.py

/ENCODE.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaoyeye/TFvelo/ec6cdb940af94f02fe32c8fcdb98494cdb4beb96/ENCODE.zip
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# TFvelo


This is the code for TFvelo: gene regulation inspired RNA velocity estimation. The paper is available at: [Li, Jiachen, et al. "TFvelo: gene regulation inspired RNA velocity estimation." bioRxiv (2023): 2023-07](https://doi.org/10.1101/2023.07.12.548785).

Due to the wide usage of scVelo (Bergen, Volker, et al. "Generalizing RNA velocity to transient cell states through dynamical modeling." Nature Biotechnology 38.12 (2020): 1408-1414, [link](https://github.com/theislab/scvelo)) and its clean, well-organized code, we developed TFvelo based on the framework of scVelo.

In TFvelo, gene regulatory relationships are taken into consideration when modeling the time derivative of RNA abundance, which allows a more accurate phase portrait fitting for each gene.

![Image text](https://github.com/xiaoyeye/TFvelo/blob/main/figures/demo.png)

TFvelo_run_demo.py provides the demo for running TFvelo, and TFvelo_analysis_demo.py is for visualizing the results. The TFvelo package can be downloaded directly for use.
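
Once the demos below have been run, the resulting AnnData file can also be inspected programmatically; a minimal sketch (assuming the pancreas demo has already produced 'TFvelo_pancreas_demo/TFvelo.h5ad'):
```
import anndata as ad

adata = ad.read_h5ad("TFvelo_pancreas_demo/TFvelo.h5ad")
print(adata)  # lists the layers, obs and var annotations written by TFvelo
```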

Please decompress the ENCODE TF-target database file first after downloading this repository. On Linux, you can run:
```
unzip ENCODE.zip
```

## Environment:
```
conda create -n TFvelo_env python=3.8.12
conda activate TFvelo_env
pip install pandas==1.2.3
pip install anndata==0.8.0
pip install scanpy==1.8.2
pip install numpy==1.21.6
pip install scipy==1.10.1
pip install numba==0.57.0
pip install matplotlib==3.3.4
pip install scvelo==0.2.4
pip install typing_extensions
```

## Reproduce:
Running the program with default parameters reproduces the results in the manuscript.

To reproduce TFvelo on pancreas:
```
python TFvelo_run_demo.py --dataset_name pancreas
```
This will automatically download, preprocess and run the TFvelo model on the pancreas dataset. The result will be stored in 'TFvelo_pancreas/rc.h5ad'.


After that, the visualization of the results can be obtained with
```
python TFvelo_analysis_demo.py --dataset_name pancreas
```
This will show the pseudotime and the velocity stream plot on the UMAP embedding, as well as the phase portrait fitting of the best-fitted genes.
The result will be stored in 'TFvelo_pancreas_demo/TFvelo.h5ad', and figures will be saved in the folder 'figures'.


## Usage:
To apply TFvelo to any other scRNA-seq dataset:

you can define a custom name for the dataset, and simply add the following code to the preprocess() function in TFvelo_run_demo.py:
```
if args.dataset_name == your_dataset_name:
    adata = ad.read_h5ad(your_h5ad_file_path)
```
Then run the code with:
```
python TFvelo_run_demo.py --dataset_name your_dataset_name
python TFvelo_analysis_demo.py --dataset_name your_dataset_name
```
As a result, all generated h5ad files will be put in the folder named "TFvelo_" + your_dataset_name + "_demo", and figures will be saved in the folder "figures".

## Hyperparameters:
```
--n_jobs: The number of CPUs to use.
--init_weight_method: The method used to initialize the weights. Correlation is adopted by default.
--WX_method: The method used to optimize the weights. lsq_linear is adopted by default.
--n_neighbors: The number of neighbors.
--WX_thres: The maximum absolute value for weights.
--TF_databases: The way to select candidate TFs. ENCODE and ChEA are used by default.
--max_n_TF: The maximum number of TFs used for modeling each gene.
--max_iter: The maximum number of iterations in the generalized EM algorithm.
--n_time_points: The number of time points in the time assignment (E step of the generalized EM algorithm).
--save_name: The name of the folder in which all generated files will be put.
```

## Baselines and metrics:
The code for baselines and metrics is provided in the folder "baselines". You may need to create an environment and install the required packages for each baseline method.

--------------------------------------------------------------------------------
/TFvelo/__init__.py:
--------------------------------------------------------------------------------
"""TFvelo - gene regulation inspired RNA velocity estimation"""
from anndata import AnnData
from scanpy import read, read_loom
from . import datasets, logging, pl, pp, settings, tl, utils
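# Note: the top-level namespaces mirror scvelo's API: `pp` (preprocessing),
# `tl` (tools) and `pl` (plotting). A hedged usage sketch (the calls follow
# the scvelo-style interface; see TFvelo_run_demo.py for the actual pipeline):
#
#     import TFvelo as TFv
#     TFv.pp.neighbors(adata)
#     TFv.tl.velocity_graph(adata)
#     TFv.pl.velocity_embedding_stream(adata)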
from .core import get_df
from .plotting.gridspec import GridSpec
from .preprocessing.neighbors import Neighbors
from .read_load import DataFrame, load, read_csv
from .settings import set_figure_params
from .tools.utils import round
from .tools.velocity_graph import VelocityGraph



__all__ = [
    "AnnData",
    "DataFrame",
    "datasets",
    "get_df",
    "GridSpec",
    "load",
    "logging",
    "Neighbors",
    "pl",
    "pp",
    "read",
    "read_csv",
    "read_loom",
    "round",
    "set_figure_params",
    "settings",
    "tl",
    "utils",
    "VelocityGraph",
]

--------------------------------------------------------------------------------
/TFvelo/core/__init__.py:
--------------------------------------------------------------------------------
from ._anndata import (
    clean_obs_names,
    cleanup,
    get_df,
    get_initial_size,
    get_modality,
    get_size,
    make_dense,
    make_sparse,
    merge,
    set_initial_size,
    set_modality,
    show_proportions,
)
from ._arithmetic import clipped_log, invert, prod_sum, sum
from ._linear_models import LinearRegression
from ._metrics import l2_norm
from ._models import SplicingDynamics
from ._parallelize import get_n_jobs, parallelize

__all__ = [
    "clean_obs_names",
    "cleanup",
    "clipped_log",
    "get_df",
    "get_initial_size",
    "get_modality",
    "get_n_jobs",
    "get_size",
    "invert",
    "l2_norm",
    "LinearRegression",
    "make_dense",
    "make_sparse",
    "merge",
    "parallelize",
    "prod_sum",
    "set_initial_size",
    "set_modality",
    "show_proportions",
    "SplicingDynamics",
    "sum",
]

--------------------------------------------------------------------------------
/TFvelo/core/_arithmetic.py:
--------------------------------------------------------------------------------
import warnings
from typing import Optional, Union

import numpy as np
from numpy import ndarray
from scipy.sparse import issparse, spmatrix


def clipped_log(x: ndarray, lb: float = 0, ub: float = 1, eps: float = 1e-6) -> ndarray:
    """Logarithmize between [lb + epsilon, ub - epsilon].

    Arguments
    ---------
    x
        Array to logarithmize.
    lb
        Lower bound of interval to which array entries are clipped.
    ub
        Upper bound of interval to which array entries are clipped.
    eps
        Offset of boundaries of clipping interval.

    Returns
    -------
    ndarray
        Logarithm of clipped array.
    """

    return np.log(np.clip(x, lb + eps, ub - eps))


def invert(x: ndarray) -> ndarray:
    """Invert array and set infinity to NaN.

    Arguments
    ---------
    x
        Array to invert.

    Returns
    -------
    ndarray
        Inverted array.
    """

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        x_inv = 1 / x * (x != 0)
    return x_inv


def prod_sum(
    a1: Union[ndarray, spmatrix], a2: Union[ndarray, spmatrix], axis: Optional[int]
) -> ndarray:
    """Take sum of product of two arrays along given axis.

    Arguments
    ---------
    a1
        First array.
    a2
        Second array.
    axis
        Axis along which to sum elements. If `None`, all elements will be summed.

    Returns
    -------
    ndarray
        Sum of product of arrays along given axis.
    """

    if issparse(a1):
        return a1.multiply(a2).sum(axis=axis).A1
    elif axis == 0:
        return np.einsum("ij, ij -> j", a1, a2) if a1.ndim > 1 else (a1 * a2).sum()
    elif axis == 1:
        return np.einsum("ij, ij -> i", a1, a2) if a1.ndim > 1 else (a1 * a2).sum()


def sum(a: Union[ndarray, spmatrix], axis: Optional[int] = None) -> ndarray:
    """Sum array elements over a given axis.

    Arguments
    ---------
    a
        Elements to sum.
    axis
        Axis along which to sum elements. If `None`, all elements will be summed.
        Defaults to `None`.

    Returns
    -------
    ndarray
        Sum of array along given axis.
    """

    if a.ndim == 1:
        axis = 0

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        return a.sum(axis=axis).A1 if issparse(a) else a.sum(axis=axis)

--------------------------------------------------------------------------------
/TFvelo/core/_base.py:
--------------------------------------------------------------------------------
from abc import ABC, abstractmethod
from typing import Dict, Tuple, Union

from numpy import ndarray


class DynamicsBase(ABC):
    @abstractmethod
    def get_solution(
        self, t: ndarray, stacked: bool = True, with_keys: bool = False
    ) -> Union[Dict, Tuple[ndarray], ndarray]:
        """Calculate solution of dynamics.

        Arguments
        ---------
        t
            Time steps at which to evaluate solution.
        stacked
            Whether to stack states or return them individually. Defaults to `True`.
        with_keys
            Whether to return solution labelled by variables in form of a dictionary.
            Defaults to `False`.

        Returns
        -------
        Union[Dict, Tuple[ndarray], ndarray]
            Solution of system. If `with_keys=True`, the solution is returned in form of
            a dictionary with variables as keys. Otherwise, the solution is given as
            a `numpy.ndarray` of form `(n_steps, n_vars)`.
        """

        return

--------------------------------------------------------------------------------
/TFvelo/core/_linear_models.py:
--------------------------------------------------------------------------------
from typing import List, Optional, Tuple, Union

import numpy as np
from numpy import ndarray
from scipy.sparse import csr_matrix, issparse

from ._arithmetic import prod_sum, sum


class LinearRegression:
    """Extreme quantile and constrained least squares linear regression.

    Arguments
    ---------
    percentile
        Percentile of data on which linear regression line is fit. If `None`, all data
        is used; if a single value is given, it is interpreted as the upper quantile.
        Defaults to `None`.
    fit_intercept
        Whether to calculate the intercept for the model. Defaults to `False`.
    positive_intercept
        Whether the intercept is constrained to positive values. Only plays a role when
        `fit_intercept=True`. Defaults to `True`.
    constrain_ratio
        Ratio to which coefficients are clipped. If `None`, the coefficients are not
        constrained. Defaults to `None`.

    Attributes
    ----------
    coef_
        Estimated coefficients of the linear regression line.

    intercept_
        Fitted intercept of linear model. Set to `0.0` if `fit_intercept=False`.
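
    Examples
    --------
    A minimal sketch, mirroring a case exercised in ``tests/test_linear_models.py``:

    >>> import numpy as np
    >>> x = np.array([[0.0], [1.0], [2.0], [3.0]])
    >>> lr = LinearRegression(fit_intercept=True).fit(x, 2 * x + 1)
    >>> float(lr.coef_[0]), float(lr.intercept_[0])
    (2.0, 1.0)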
35 | 36 | """ 37 | 38 | def __init__( 39 | self, 40 | percentile: Optional[Union[Tuple, int, float]] = None, 41 | fit_intercept: bool = False, 42 | positive_intercept: bool = True, 43 | constrain_ratio: Optional[Union[Tuple, float]] = None, 44 | ): 45 | if not fit_intercept and isinstance(percentile, (list, tuple)): 46 | self.percentile = percentile[1] 47 | else: 48 | self.percentile = percentile 49 | self.fit_intercept = fit_intercept 50 | self.positive_intercept = positive_intercept 51 | 52 | if constrain_ratio is None: 53 | self.constrain_ratio = [-np.inf, np.inf] 54 | elif len(constrain_ratio) == 1: 55 | self.constrain_ratio = [-np.inf, constrain_ratio] 56 | else: 57 | self.constrain_ratio = constrain_ratio 58 | 59 | def _trim_data(self, data: List) -> List: 60 | """Trim data to extreme values. 61 | 62 | Arguments 63 | --------- 64 | data 65 | Data to be trimmed to extreme quantiles. 66 | 67 | Returns 68 | ------- 69 | List 70 | Number of non-trivial entries per column and trimmed data. 71 | """ 72 | 73 | if not isinstance(data, List): 74 | data = [data] 75 | 76 | data = np.array( 77 | [data_mat.A if issparse(data_mat) else data_mat for data_mat in data] 78 | ) 79 | 80 | # TODO: Add explanatory comment 81 | normalized_data = np.sum( 82 | data / data.max(axis=1, keepdims=True).clip(1e-3, None), axis=0 83 | ) 84 | 85 | bound = np.percentile(normalized_data, self.percentile, axis=0) 86 | 87 | if bound.ndim == 1: 88 | trimmer = csr_matrix(normalized_data >= bound).astype(bool) 89 | else: 90 | trimmer = csr_matrix( 91 | (normalized_data <= bound[0]) | (normalized_data >= bound[1]) 92 | ).astype(bool) 93 | 94 | return [trimmer.getnnz(axis=0)] + [ 95 | trimmer.multiply(data_mat).tocsr() for data_mat in data 96 | ] 97 | 98 | def fit(self, x: ndarray, y: ndarray): 99 | """Fit linear model per column. 100 | 101 | Arguments 102 | --------- 103 | x 104 | Training data of shape `(n_obs, n_vars)`. 105 | y 106 | Target values of shape `(n_obs, n_vars)`. 107 | 108 | Returns 109 | ------- 110 | self 111 | Returns an instance of self. 
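
        Example
        -------
        A hedged sketch of the column-wise fit (no intercept by default, one
        coefficient per column):

        >>> import numpy as np
        >>> x = np.arange(8, dtype=float).reshape(4, 2)
        >>> lr = LinearRegression().fit(x, x * np.array([2.0, -0.5]))
        >>> np.allclose(lr.coef_, [2.0, -0.5])
        True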
112 | """ 113 | 114 | n_obs = x.shape[0] 115 | 116 | if self.percentile is not None: 117 | n_obs, x, y = self._trim_data(data=[x, y]) 118 | 119 | _xx = prod_sum(x, x, axis=0) 120 | _xy = prod_sum(x, y, axis=0) 121 | 122 | if self.fit_intercept: 123 | _x = sum(x, axis=0) / n_obs 124 | _y = sum(y, axis=0) / n_obs 125 | self.coef_ = (_xy / n_obs - _x * _y) / (_xx / n_obs - _x ** 2) 126 | self.intercept_ = _y - self.coef_ * _x 127 | 128 | if self.positive_intercept: 129 | idx = self.intercept_ < 0 130 | if self.coef_.ndim > 0: 131 | self.coef_[idx] = _xy[idx] / _xx[idx] 132 | else: 133 | self.coef_ = _xy / _xx 134 | self.intercept_ = np.clip(self.intercept_, 0, None) 135 | else: 136 | self.coef_ = _xy / _xx 137 | self.intercept_ = np.zeros(x.shape[1]) if x.ndim > 1 else 0 138 | 139 | if not np.isscalar(self.coef_): 140 | self.coef_[np.isnan(self.coef_)] = 0 141 | self.intercept_[np.isnan(self.intercept_)] = 0 142 | else: 143 | if np.isnan(self.coef_): 144 | self.coef_ = 0 145 | if np.isnan(self.intercept_): 146 | self.intercept_ = 0 147 | 148 | self.coef_ = np.clip(self.coef_, *self.constrain_ratio) 149 | 150 | return self 151 | -------------------------------------------------------------------------------- /TFvelo/core/_metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | from numpy import ndarray 5 | from scipy.sparse import issparse, spmatrix 6 | 7 | 8 | # TODO: Add case `axis == None` 9 | def l2_norm(x: Union[ndarray, spmatrix], axis: int = 1) -> Union[float, ndarray]: 10 | """Calculate l2 norm along a given axis. 11 | 12 | Arguments 13 | --------- 14 | x 15 | Array to calculate l2 norm of. 16 | axis 17 | Axis along which to calculate l2 norm. 18 | 19 | Returns 20 | ------- 21 | Union[float, ndarray] 22 | L2 norm along a given axis. 23 | """ 24 | 25 | if issparse(x): 26 | return np.sqrt(x.multiply(x).sum(axis=axis).A1) 27 | elif x.ndim == 1: 28 | return np.sqrt(np.einsum("i, i -> ", x, x)) 29 | elif axis == 0: 30 | return np.sqrt(np.einsum("ij, ij -> j", x, x)) 31 | elif axis == 1: 32 | return np.sqrt(np.einsum("ij, ij -> i", x, x)) 33 | -------------------------------------------------------------------------------- /TFvelo/core/_models.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple, Union 2 | 3 | import numpy as np 4 | from numpy import ndarray 5 | 6 | from ._arithmetic import invert 7 | from ._base import DynamicsBase 8 | 9 | 10 | # TODO: Improve parameter names: alpha -> transcription_rate; beta -> splicing_rate; 11 | # gamma -> degradation_rate 12 | # TODO: Handle cases beta = 0, gamma == 0, beta == gamma 13 | class SplicingDynamics(DynamicsBase): 14 | """Splicing dynamics. 15 | 16 | Arguments 17 | --------- 18 | alpha 19 | Transcription rate. 20 | beta 21 | Translation rate. 22 | gamma 23 | Splicing degradation rate. 24 | initial_state 25 | Initial state of system. Defaults to `[0, 0]`. 26 | 27 | Attributes 28 | ---------- 29 | alpha 30 | Transcription rate. 31 | beta 32 | Translation rate. 33 | gamma 34 | Splicing degradation rate. 35 | initial_state 36 | Initial state of system. Defaults to `[0, 0]`. 37 | u0 38 | Initial abundance of unspliced RNA. 39 | s0 40 | Initial abundance of spliced RNA. 
41 | 42 | """ 43 | 44 | def __init__( 45 | self, 46 | alpha: float, 47 | beta: float, 48 | omega: float, 49 | theta: float, 50 | gamma: float, 51 | array_flag = False 52 | #initial_state: Union[List, ndarray] = [0, 0], 53 | ): 54 | self.alpha = alpha 55 | self.beta = beta 56 | self.gamma = gamma 57 | self.omega = omega 58 | self.theta = theta 59 | self.array_flag = array_flag 60 | 61 | if self.array_flag: 62 | self.alpha=self.alpha.reshape(-1,1) 63 | self.beta=self.beta.reshape(-1,1) 64 | self.gamma=self.gamma.reshape(-1,1) 65 | self.omega=self.omega.reshape(-1,1) 66 | self.theta=self.theta.reshape(-1,1) 67 | 68 | @property 69 | def initial_state(self): 70 | return self._initial_state 71 | 72 | @initial_state.setter 73 | def initial_state(self, val): 74 | if isinstance(val, list) or (isinstance(val, ndarray) and (val.ndim == 1)): 75 | self.u0 = val[0] 76 | self.s0 = val[1] 77 | else: 78 | self.u0 = val[:, 0] 79 | self.s0 = val[:, 1] 80 | self._initial_state = val 81 | 82 | def get_solution( 83 | self, t: ndarray, stacked: bool = True, with_keys: bool = False 84 | ) -> Union[Dict, ndarray]: 85 | """Calculate solution of dynamics. 86 | 87 | Arguments 88 | --------- 89 | t 90 | Time steps at which to evaluate solution. 91 | stacked 92 | Whether to stack states or return them individually. Defaults to `True`. 93 | with_keys 94 | Whether to return solution labelled by variables in form of a dictionary. 95 | Defaults to `False`. 96 | 97 | Returns 98 | ------- 99 | Union[Dict, ndarray] 100 | Solution of system. If `with_keys=True`, the solution is returned in form of 101 | a dictionary with variables as keys. Otherwise, the solution is given as 102 | a `numpy.ndarray` of form `(n_steps, 2)`. 103 | """ 104 | if self.array_flag: 105 | t = t.reshape(1,-1) 106 | 107 | phi = np.arctan(self.omega/self.gamma) 108 | tmp1 = self.omega * t + self.theta 109 | tmp2 = np.sqrt(self.omega*self.omega + self.gamma*self.gamma) 110 | 111 | y = self.alpha * np.sin(tmp1) + self.beta 112 | WX = self.alpha * tmp2 * np.sin(tmp1+phi) + self.beta*self.gamma 113 | 114 | if with_keys: 115 | return {"WX": WX, "y": y} 116 | elif not stacked: 117 | return WX, y 118 | else: 119 | if isinstance(t, np.ndarray) and t.ndim == 2: 120 | return np.stack([WX, y], axis=2) 121 | else: 122 | return np.column_stack([WX, y]) 123 | 124 | -------------------------------------------------------------------------------- /TFvelo/core/_parallelize.py: -------------------------------------------------------------------------------- 1 | import os 2 | from multiprocessing import Manager 3 | from threading import Thread 4 | from typing import Any, Callable, Optional, Sequence, Union 5 | 6 | from joblib import delayed, Parallel 7 | 8 | import numpy as np 9 | from scipy.sparse import issparse, spmatrix 10 | 11 | from .. 
import logging as logg

_msg_shown = False


def get_n_jobs(n_jobs):
    if n_jobs is None or (n_jobs < 0 and os.cpu_count() + 1 + n_jobs <= 0):
        return 1
    elif n_jobs > os.cpu_count():
        return os.cpu_count()
    elif n_jobs < 0:
        return os.cpu_count() + 1 + n_jobs
    else:
        return n_jobs


def parallelize(
    callback: Callable[[Any], Any],
    collection: Union[spmatrix, Sequence[Any]],
    n_jobs: Optional[int] = None,
    n_split: Optional[int] = None,
    unit: str = "",
    as_array: bool = True,
    use_ixs: bool = False,
    backend: str = "loky",
    extractor: Optional[Callable[[Any], Any]] = None,
    show_progress_bar: bool = True,
) -> Union[np.ndarray, Any]:
    """
    Parallelize function call over a collection of elements.

    Parameters
    ----------
    callback
        Function to parallelize.
    collection
        Sequence of items to chunk.
    n_jobs
        Number of parallel jobs.
    n_split
        Split :paramref:`collection` into :paramref:`n_split` chunks.
        If `None`, split into :paramref:`n_jobs` chunks.
    unit
        Unit of the progress bar.
    as_array
        Whether to convert the results to :class:`numpy.ndarray`.
    use_ixs
        Whether to pass indices to the callback.
    backend
        Which backend to use for multiprocessing. See :class:`joblib.Parallel` for valid
        options.
    extractor
        Function to apply to the result after all jobs have finished.
    show_progress_bar
        Whether to show a progress bar.

    Returns
    -------
    :class:`numpy.ndarray`
        Result depending on :paramref:`extractor` and :paramref:`as_array`.
    """

    if show_progress_bar:
        try:
            try:
                from tqdm.notebook import tqdm
            except ImportError:
                from tqdm import tqdm_notebook as tqdm
            import ipywidgets  # noqa
        except ImportError:
            global _msg_shown
            tqdm = None

            if not _msg_shown:
                logg.warn(
                    "Unable to create progress bar. "
                    "Consider installing `tqdm` as `pip install tqdm` "
                    "and `ipywidgets` as `pip install ipywidgets`,\n"
                    "or disable the progress bar using `show_progress_bar=False`."
                )
                _msg_shown = True
    else:
        tqdm = None

    def update(pbar, queue, n_total):
        n_finished = 0
        while n_finished < n_total:
            try:
                res = queue.get()
            except EOFError as e:
                if n_finished != n_total:
                    raise RuntimeError(
                        f"Finished only {n_finished} out of {n_total} tasks."
                    ) from e
                break
            assert res in (None, (1, None), 1)  # (None, 1) means only 1 job
            if res == (1, None):
                n_finished += 1
                if pbar is not None:
                    pbar.update()
            elif res is None:
                n_finished += 1
            elif pbar is not None:
                pbar.update()

        if pbar is not None:
            pbar.close()

    def wrapper(*args, **kwargs):
        if pass_queue and show_progress_bar:
            pbar = None if tqdm is None else tqdm(total=col_len, unit=unit)
            queue = Manager().Queue()
            thread = Thread(target=update, args=(pbar, queue, len(collections)))
            thread.start()
        else:
            pbar, queue, thread = None, None, None

        res = Parallel(n_jobs=n_jobs, backend=backend)(
            delayed(callback)(
                *((i, cs) if use_ixs else (cs,)),
                *args,
                **kwargs,
                queue=queue,
            )
            for i, cs in enumerate(collections)
        )

        res = np.array(res) if as_array else res
        if thread is not None:
            thread.join()

        return res if extractor is None else extractor(res)

    col_len = collection.shape[0] if issparse(collection) else len(collection)

    if n_split is None:
        n_split = get_n_jobs(n_jobs=n_jobs)

    if issparse(collection):
        if n_split == collection.shape[0]:
            collections = [collection[[ix], :] for ix in range(collection.shape[0])]
        else:
            step = collection.shape[0] // n_split

            ixs = [
                np.arange(i * step, min((i + 1) * step, collection.shape[0]))
                for i in range(n_split)
            ]
            ixs[-1] = np.append(
                ixs[-1], np.arange(ixs[-1][-1] + 1, collection.shape[0])
            )

            collections = [collection[ix, :] for ix in filter(len, ixs)]
    else:
        collections = list(filter(len, np.array_split(collection, n_split)))

    pass_queue = not hasattr(callback, "py_func")  # we'd be inside a numba function

    return wrapper

--------------------------------------------------------------------------------
/TFvelo/core/tests/__init__.py:
--------------------------------------------------------------------------------
from .test_base import get_adata, TestBase

__all__ = ["get_adata", "TestBase"]

--------------------------------------------------------------------------------
/TFvelo/core/tests/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaoyeye/TFvelo/ec6cdb940af94f02fe32c8fcdb98494cdb4beb96/TFvelo/core/tests/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/TFvelo/core/tests/__pycache__/test_anndata.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaoyeye/TFvelo/ec6cdb940af94f02fe32c8fcdb98494cdb4beb96/TFvelo/core/tests/__pycache__/test_anndata.cpython-37.pyc
--------------------------------------------------------------------------------
/TFvelo/core/tests/__pycache__/test_arithmetic.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaoyeye/TFvelo/ec6cdb940af94f02fe32c8fcdb98494cdb4beb96/TFvelo/core/tests/__pycache__/test_arithmetic.cpython-37.pyc -------------------------------------------------------------------------------- /TFvelo/core/tests/__pycache__/test_base.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoyeye/TFvelo/ec6cdb940af94f02fe32c8fcdb98494cdb4beb96/TFvelo/core/tests/__pycache__/test_base.cpython-37.pyc -------------------------------------------------------------------------------- /TFvelo/core/tests/__pycache__/test_linear_models.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoyeye/TFvelo/ec6cdb940af94f02fe32c8fcdb98494cdb4beb96/TFvelo/core/tests/__pycache__/test_linear_models.cpython-37.pyc -------------------------------------------------------------------------------- /TFvelo/core/tests/__pycache__/test_metrics.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoyeye/TFvelo/ec6cdb940af94f02fe32c8fcdb98494cdb4beb96/TFvelo/core/tests/__pycache__/test_metrics.cpython-37.pyc -------------------------------------------------------------------------------- /TFvelo/core/tests/__pycache__/test_models.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoyeye/TFvelo/ec6cdb940af94f02fe32c8fcdb98494cdb4beb96/TFvelo/core/tests/__pycache__/test_models.cpython-37.pyc -------------------------------------------------------------------------------- /TFvelo/core/tests/test_anndata.py: -------------------------------------------------------------------------------- 1 | import hypothesis.strategies as st 2 | from hypothesis import given 3 | 4 | import numpy as np 5 | from numpy.testing import assert_array_equal 6 | from scipy.sparse import issparse 7 | 8 | from anndata import AnnData 9 | 10 | from scvelo.core import get_modality, make_dense, make_sparse, set_modality 11 | from .test_base import get_adata, TestBase 12 | 13 | 14 | class TestGetModality(TestBase): 15 | @given(adata=get_adata()) 16 | def test_get_modality(self, adata: AnnData): 17 | modality_to_get = self._subset_modalities(adata, 1)[0] 18 | modality_retrieved = get_modality(adata=adata, modality=modality_to_get) 19 | 20 | if modality_to_get == "X": 21 | assert_array_equal(adata.X, modality_retrieved) 22 | elif modality_to_get in adata.layers: 23 | assert_array_equal(adata.layers[modality_to_get], modality_retrieved) 24 | else: 25 | assert_array_equal(adata.obsm[modality_to_get], modality_retrieved) 26 | 27 | 28 | class TestMakeDense(TestBase): 29 | @given( 30 | adata=get_adata(sparse_entries=True), 31 | inplace=st.booleans(), 32 | n_modalities=st.integers(min_value=0), 33 | ) 34 | def test_make_dense(self, adata: AnnData, inplace: bool, n_modalities: int): 35 | modalities_to_densify = self._subset_modalities(adata, n_modalities) 36 | 37 | returned_adata = make_dense( 38 | adata=adata, modalities=modalities_to_densify, inplace=inplace 39 | ) 40 | 41 | if inplace: 42 | assert returned_adata is None 43 | assert np.all( 44 | [ 45 | not issparse(get_modality(adata=adata, modality=modality)) 46 | for modality in modalities_to_densify 47 | ] 48 | ) 49 | else: 50 | assert isinstance(returned_adata, AnnData) 51 | assert np.all( 52 | [ 53 | not issparse(get_modality(adata=returned_adata, 
modality=modality)) 54 | for modality in modalities_to_densify 55 | ] 56 | ) 57 | assert np.all( 58 | [ 59 | issparse(get_modality(adata=adata, modality=modality)) 60 | for modality in modalities_to_densify 61 | ] 62 | ) 63 | 64 | 65 | class TestMakeSparse(TestBase): 66 | @given( 67 | adata=get_adata(), 68 | inplace=st.booleans(), 69 | n_modalities=st.integers(min_value=0), 70 | ) 71 | def test_make_sparse(self, adata: AnnData, inplace: bool, n_modalities: int): 72 | modalities_to_make_sparse = self._subset_modalities(adata, n_modalities) 73 | 74 | returned_adata = make_sparse( 75 | adata=adata, modalities=modalities_to_make_sparse, inplace=inplace 76 | ) 77 | 78 | if inplace: 79 | assert returned_adata is None 80 | assert np.all( 81 | [ 82 | issparse(get_modality(adata=adata, modality=modality)) 83 | for modality in modalities_to_make_sparse 84 | if modality != "X" 85 | ] 86 | ) 87 | else: 88 | assert isinstance(returned_adata, AnnData) 89 | assert np.all( 90 | [ 91 | issparse(get_modality(adata=returned_adata, modality=modality)) 92 | for modality in modalities_to_make_sparse 93 | if modality != "X" 94 | ] 95 | ) 96 | assert np.all( 97 | [ 98 | not issparse(get_modality(adata=adata, modality=modality)) 99 | for modality in modalities_to_make_sparse 100 | if modality != "X" 101 | ] 102 | ) 103 | 104 | 105 | class TestSetModality(TestBase): 106 | @given(adata=get_adata(), inplace=st.booleans()) 107 | def test_set_modality(self, adata: AnnData, inplace: bool): 108 | modality_to_set = self._subset_modalities(adata, 1)[0] 109 | 110 | if (modality_to_set == "X") or (modality_to_set in adata.layers): 111 | new_value = np.random.randn(adata.n_obs, adata.n_vars) 112 | else: 113 | new_value = np.random.randn( 114 | adata.n_obs, np.random.randint(low=1, high=10000) 115 | ) 116 | 117 | returned_adata = set_modality( 118 | adata=adata, new_value=new_value, modality=modality_to_set, inplace=inplace 119 | ) 120 | 121 | if inplace: 122 | assert returned_adata is None 123 | if modality_to_set == "X": 124 | assert_array_equal(adata.X, new_value) 125 | elif modality_to_set in adata.layers: 126 | assert_array_equal(adata.layers[modality_to_set], new_value) 127 | else: 128 | assert_array_equal(adata.obsm[modality_to_set], new_value) 129 | else: 130 | assert isinstance(returned_adata, AnnData) 131 | if modality_to_set == "X": 132 | assert_array_equal(returned_adata.X, new_value) 133 | elif modality_to_set in adata.layers: 134 | assert_array_equal(returned_adata.layers[modality_to_set], new_value) 135 | else: 136 | assert_array_equal(returned_adata.obsm[modality_to_set], new_value) 137 | -------------------------------------------------------------------------------- /TFvelo/core/tests/test_arithmetic.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from hypothesis import given 4 | from hypothesis import strategies as st 5 | from hypothesis.extra.numpy import arrays 6 | 7 | import numpy as np 8 | from numpy import ndarray 9 | from numpy.testing import assert_almost_equal, assert_array_equal 10 | 11 | from scvelo.core import clipped_log, invert, prod_sum, sum 12 | 13 | 14 | class TestClippedLog: 15 | @given( 16 | a=arrays( 17 | float, 18 | shape=st.integers(min_value=1, max_value=100), 19 | elements=st.floats( 20 | min_value=-1e3, max_value=1e3, allow_infinity=False, allow_nan=False 21 | ), 22 | ), 23 | bounds=st.lists( 24 | st.floats( 25 | min_value=0, max_value=100, allow_infinity=False, allow_nan=False 26 | ), 27 | min_size=2, 28 | 
max_size=2, 29 | unique=True, 30 | ), 31 | eps=st.floats( 32 | min_value=1e-6, max_value=1, allow_infinity=False, allow_nan=False 33 | ), 34 | ) 35 | def test_flat_arrays(self, a: ndarray, bounds: List[float], eps: float): 36 | lb = min(bounds) 37 | ub = max(bounds) + 2 * eps 38 | 39 | a_logged = clipped_log(a, lb=lb, ub=ub, eps=eps) 40 | 41 | assert a_logged.shape == a.shape 42 | if (a <= lb).any(): 43 | assert_almost_equal(np.abs(a_logged - np.log(lb + eps)).min(), 0) 44 | else: 45 | assert (a_logged >= np.log(lb + eps)).all() 46 | if (a >= ub).any(): 47 | assert_almost_equal(np.abs(a_logged - np.log(ub - eps)).min(), 0) 48 | else: 49 | assert (a_logged <= np.log(ub - eps)).all() 50 | 51 | @given( 52 | a=arrays( 53 | float, 54 | shape=st.tuples( 55 | st.integers(min_value=1, max_value=100), 56 | st.integers(min_value=1, max_value=100), 57 | ), 58 | elements=st.floats( 59 | min_value=-1e3, max_value=1e3, allow_infinity=False, allow_nan=False 60 | ), 61 | ), 62 | bounds=st.lists( 63 | st.floats( 64 | min_value=0, max_value=100, allow_infinity=False, allow_nan=False 65 | ), 66 | min_size=2, 67 | max_size=2, 68 | unique=True, 69 | ), 70 | eps=st.floats( 71 | min_value=1e-6, max_value=1, allow_infinity=False, allow_nan=False 72 | ), 73 | ) 74 | def test_2d_arrays(self, a: ndarray, bounds: List[float], eps: float): 75 | lb = min(bounds) 76 | ub = max(bounds) + 2 * eps 77 | 78 | a_logged = clipped_log(a, lb=lb, ub=ub, eps=eps) 79 | 80 | assert a_logged.shape == a.shape 81 | if (a <= lb).any(): 82 | assert_almost_equal(np.abs(a_logged - np.log(lb + eps)).min(), 0) 83 | else: 84 | assert (a_logged >= np.log(lb + eps)).all() 85 | if (a >= ub).any(): 86 | assert_almost_equal(np.abs(a_logged - np.log(ub - eps)).min(), 0) 87 | else: 88 | assert (a_logged <= np.log(ub - eps)).all() 89 | 90 | 91 | class TestInvert: 92 | @given( 93 | a=arrays( 94 | float, 95 | shape=st.integers(min_value=1, max_value=100), 96 | elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False), 97 | ) 98 | ) 99 | def test_flat_arrays(self, a: ndarray): 100 | a_inv = invert(a) 101 | 102 | if a[a != 0].size == 0: 103 | assert a_inv[a != 0].size == 0 104 | else: 105 | assert_array_equal(a_inv[a != 0], 1 / a[a != 0]) 106 | 107 | if 0 in a: 108 | assert np.isnan(a_inv[a == 0]).all() 109 | else: 110 | assert set(a_inv[a == 0]) == set() 111 | 112 | @given( 113 | a=arrays( 114 | float, 115 | shape=st.tuples( 116 | st.integers(min_value=1, max_value=100), 117 | st.integers(min_value=1, max_value=100), 118 | ), 119 | elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False), 120 | ) 121 | ) 122 | def test_2d_arrays(self, a: ndarray): 123 | a_inv = invert(a) 124 | 125 | if a[a != 0].size == 0: 126 | assert a_inv[a != 0].size == 0 127 | else: 128 | assert_array_equal(a_inv[a != 0], 1 / a[a != 0]) 129 | 130 | if 0 in a: 131 | assert np.isnan(a_inv[a == 0]).all() 132 | else: 133 | assert set(a_inv[a == 0]) == set() 134 | 135 | 136 | # TODO: Extend test to generate sparse inputs as well 137 | # TODO: Make test to generate two different arrays a1, a2 138 | # TODO: Check why tests fail with assert_almost_equal 139 | class TestProdSum: 140 | @given( 141 | a=arrays( 142 | float, 143 | shape=st.integers(min_value=1, max_value=100), 144 | elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False), 145 | ), 146 | axis=st.integers(min_value=0, max_value=1), 147 | ) 148 | def test_flat_array(self, a: ndarray, axis: int): 149 | assert np.allclose((a * a).sum(axis=0), prod_sum(a, a, axis=axis)) 150 | 151 | 
@given( 152 | a=arrays( 153 | float, 154 | shape=st.tuples( 155 | st.integers(min_value=1, max_value=100), 156 | st.integers(min_value=1, max_value=100), 157 | ), 158 | elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False), 159 | ), 160 | axis=st.integers(min_value=0, max_value=1), 161 | ) 162 | def test_2d_array(self, a: ndarray, axis: int): 163 | assert np.allclose((a * a).sum(axis=axis), prod_sum(a, a, axis=axis)) 164 | 165 | 166 | # TODO: Extend test to generate sparse inputs as well 167 | class TestSum: 168 | @given( 169 | a=arrays( 170 | float, 171 | shape=st.integers(min_value=1, max_value=100), 172 | elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False), 173 | ), 174 | ) 175 | def test_flat_arrays(self, a: ndarray): 176 | a_summed = sum(a=a, axis=0) 177 | 178 | assert_array_equal(a_summed, a.sum(axis=0)) 179 | 180 | @given( 181 | a=arrays( 182 | float, 183 | shape=st.tuples( 184 | st.integers(min_value=1, max_value=100), 185 | st.integers(min_value=1, max_value=100), 186 | ), 187 | elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False), 188 | ), 189 | axis=st.integers(min_value=0, max_value=1), 190 | ) 191 | def test_2d_arrays(self, a: ndarray, axis: int): 192 | a_summed = sum(a=a, axis=axis) 193 | 194 | if a.ndim == 1: 195 | axis = 0 196 | 197 | assert_array_equal(a_summed, a.sum(axis=axis)) 198 | -------------------------------------------------------------------------------- /TFvelo/core/tests/test_base.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List, Optional, Union 3 | 4 | import hypothesis.strategies as st 5 | from hypothesis import given 6 | from hypothesis.extra.numpy import arrays 7 | 8 | import numpy as np 9 | from scipy.sparse import csr_matrix, issparse 10 | 11 | from anndata import AnnData 12 | 13 | 14 | # TODO: Add possibility to generate adata object with floats as counts 15 | @st.composite 16 | def get_adata( 17 | draw, 18 | n_obs: Optional[int] = None, 19 | n_vars: Optional[int] = None, 20 | min_obs: Optional[int] = 1, 21 | max_obs: Optional[int] = 100, 22 | min_vars: Optional[int] = 1, 23 | max_vars: Optional[int] = 100, 24 | layer_keys: Optional[Union[List, str]] = None, 25 | min_layers: Optional[int] = 2, 26 | max_layers: int = 2, 27 | obsm_keys: Optional[Union[List, str]] = None, 28 | min_obsm: Optional[int] = 2, 29 | max_obsm: Optional[int] = 2, 30 | sparse_entries: bool = False, 31 | ) -> AnnData: 32 | """Generate an AnnData object. 33 | 34 | The largest possible value of a numerical entry is `1e5`. 35 | 36 | Arguments 37 | --------- 38 | n_obs: 39 | Number of observations. If set to `None`, a random integer between `1` and 40 | `max_obs` will be drawn. Defaults to `None`. 41 | n_vars: 42 | Number of variables. If set to `None`, a random integer between `1` and 43 | `max_vars` will be drawn. Defaults to `None`. 44 | min_obs: 45 | Minimum number of observations. If set to `None`, there is no lower limit. 46 | Defaults to `1`. 47 | max_obs: 48 | Maximum number of observations. If set to `None`, there is no upper limit. 49 | Defaults to `100`. 50 | min_vars: 51 | Minimum number of variables. If set to `None`, there is no lower limit. 52 | Defaults to `1`. 53 | max_vars: 54 | Maximum number of variables. If set to `None`, there is no upper limit. 55 | Defaults to `100`. 56 | layer_keys: 57 | Names of layers. If set to `None`, layers will be named at random. Defaults 58 | to `None`. 
59 | min_layers: 60 | Minimum number of layers. Is set to the number of provided layer names if 61 | `layer_keys` is not `None`. Defaults to `2`. 62 | max_layers: Maximum number of layers. Is set to the number of provided layer 63 | names if `layer_keys` is not `None`. Defaults to `2`. 64 | obsm_keys: 65 | Names of multi-dimensional observations annotation. If set to `None`, names 66 | will be generated at random. Defaults to `None`. 67 | min_obsm: 68 | Minimum number of multi-dimensional observations annotation. Is set to the 69 | number of keys if `obsm_keys` is not `None`. Defaults to `2`. 70 | max_obsm: 71 | Maximum number of multi-dimensional observations annotation. Is set to the 72 | number of keys if `obsm_keys` is not `None`. Defaults to `2`. 73 | sparse_entries: 74 | Whether or not to make AnnData entries sparse. 75 | 76 | Returns 77 | ------- 78 | AnnData 79 | Generated :class:`~anndata.AnnData` object. 80 | """ 81 | 82 | if n_obs is None: 83 | n_obs = draw(st.integers(min_value=min_obs, max_value=max_obs)) 84 | if n_vars is None: 85 | n_vars = draw(st.integers(min_value=min_vars, max_value=max_vars)) 86 | 87 | if isinstance(layer_keys, str): 88 | layer_keys = [layer_keys] 89 | if isinstance(obsm_keys, str): 90 | obsm_keys = [obsm_keys] 91 | 92 | if layer_keys is not None: 93 | min_layers = len(layer_keys) 94 | max_layers = len(layer_keys) 95 | if obsm_keys is not None: 96 | min_obsm = len(obsm_keys) 97 | max_obsm = len(obsm_keys) 98 | 99 | X = draw( 100 | arrays( 101 | dtype=int, 102 | elements=st.integers(min_value=0, max_value=1e2), 103 | shape=(n_obs, n_vars), 104 | ) 105 | ) 106 | 107 | layers = draw( 108 | st.dictionaries( 109 | st.text(min_size=1) if layer_keys is None else st.sampled_from(layer_keys), 110 | arrays( 111 | dtype=int, 112 | elements=st.integers(min_value=0, max_value=1e2), 113 | shape=(n_obs, n_vars), 114 | ), 115 | min_size=min_layers, 116 | max_size=max_layers, 117 | ) 118 | ) 119 | 120 | obsm = draw( 121 | st.dictionaries( 122 | st.text(min_size=1) if obsm_keys is None else st.sampled_from(obsm_keys), 123 | arrays( 124 | dtype=int, 125 | elements=st.integers(min_value=0, max_value=1e2), 126 | shape=st.tuples( 127 | st.integers(min_value=n_obs, max_value=n_obs), 128 | st.integers(min_value=min_vars, max_value=max_vars), 129 | ), 130 | ), 131 | min_size=min_obsm, 132 | max_size=max_obsm, 133 | ) 134 | ) 135 | 136 | # Make keys for layers and obsm unique 137 | for key in set(layers.keys()).intersection(obsm.keys()): 138 | layers[f"{key}_"] = layers.pop(key) 139 | 140 | if sparse_entries: 141 | layers = {key: csr_matrix(val) for key, val in layers.items()} 142 | obsm = {key: csr_matrix(val) for key, val in obsm.items()} 143 | return AnnData(X=csr_matrix(X), layers=layers, obsm=obsm) 144 | else: 145 | return AnnData(X=X, layers=layers, obsm=obsm) 146 | 147 | 148 | class TestAdataGeneration: 149 | @given(adata=get_adata()) 150 | def test_default_adata_generation(self, adata: AnnData): 151 | assert type(adata) is AnnData 152 | 153 | @given(adata=get_adata(sparse_entries=True)) 154 | def test_sparse_adata_generation(self, adata: AnnData): 155 | assert type(adata) is AnnData 156 | assert issparse(adata.X) 157 | assert np.all([issparse(adata.layers[layer]) for layer in adata.layers]) 158 | assert np.all([issparse(adata.obsm[name]) for name in adata.obsm]) 159 | 160 | @given( 161 | adata=get_adata( 162 | n_obs=2, n_vars=2, layer_keys=["unspliced", "spliced"], obsm_keys="X_umap" 163 | ) 164 | ) 165 | def test_custom_adata_generation(self, adata: AnnData): 166 | 
assert adata.X.shape == (2, 2) 167 | assert len(adata.layers) == 2 168 | assert len(adata.obsm) == 1 169 | assert set(adata.layers.keys()) == {"unspliced", "spliced"} 170 | assert set(adata.obsm.keys()) == {"X_umap"} 171 | 172 | 173 | class TestBase: 174 | def _subset_modalities( 175 | self, 176 | adata: AnnData, 177 | n_modalities: int, 178 | from_layers: bool = True, 179 | from_obsm: bool = True, 180 | ): 181 | """Subset modalities of an AnnData object.""" 182 | 183 | modalities = ["X"] 184 | if from_layers: 185 | modalities += list(adata.layers.keys()) 186 | if from_obsm: 187 | modalities += list(adata.obsm.keys()) 188 | return random.sample(modalities, min(len(modalities), n_modalities)) 189 | 190 | def _convert_to_float(self, adata: AnnData): 191 | """Convert AnnData entries in `layer` and `obsm` into floats.""" 192 | 193 | for layer in adata.layers: 194 | adata.layers[layer] = adata.layers[layer].astype(float) 195 | for obs in adata.obsm: 196 | adata.obsm[obs] = adata.obsm[obs].astype(float) 197 | -------------------------------------------------------------------------------- /TFvelo/core/tests/test_linear_models.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from hypothesis import given 3 | from hypothesis import strategies as st 4 | from hypothesis.extra.numpy import arrays 5 | 6 | import numpy as np 7 | from numpy import ndarray 8 | from numpy.testing import assert_almost_equal, assert_array_equal 9 | 10 | from scvelo.core import LinearRegression 11 | 12 | 13 | class TestLinearRegression: 14 | @given( 15 | x=arrays( 16 | float, 17 | shape=st.integers(min_value=1, max_value=100), 18 | elements=st.floats( 19 | min_value=-1e3, max_value=1e3, allow_infinity=False, allow_nan=False 20 | ), 21 | ), 22 | coef=st.floats( 23 | min_value=-1000, max_value=1000, allow_infinity=False, allow_nan=False 24 | ), 25 | ) 26 | def test_perfect_fit(self, x: ndarray, coef: float): 27 | lr = LinearRegression() 28 | lr.fit(x, x * coef) 29 | 30 | assert lr.intercept_ == 0 31 | if set(x) != {0}: # fit is only unique if x is non-trivial 32 | assert_almost_equal(lr.coef_, coef) 33 | 34 | @given( 35 | x=arrays( 36 | float, 37 | shape=st.tuples( 38 | st.integers(min_value=1, max_value=100), 39 | st.integers(min_value=1, max_value=100), 40 | ), 41 | elements=st.floats( 42 | min_value=-1e3, max_value=1e3, allow_infinity=False, allow_nan=False 43 | ), 44 | ), 45 | coef=arrays( 46 | float, 47 | shape=100, 48 | elements=st.floats( 49 | min_value=-1000, max_value=1000, allow_infinity=False, allow_nan=False 50 | ), 51 | ), 52 | ) 53 | # TODO: Extend test to use `percentile`. Zero columns (after trimming) make the 54 | # previous implementation of the unit test fail 55 | # TODO: Check why test fails if number of columns is increased to e.g. 
1000 (500)
    def test_perfect_fit_2d(self, x: ndarray, coef: ndarray):
        coef = coef[: x.shape[1]]
        lr = LinearRegression()
        lr.fit(x, x * coef)

        assert lr.coef_.shape == (x.shape[1],)
        assert lr.intercept_.shape == (x.shape[1],)
        assert_array_equal(lr.intercept_, np.zeros(x.shape[1]))
        if set(x.flatten()) != {0}:  # fit is only unique if x is non-trivial
            assert_almost_equal(lr.coef_, coef)

    # TODO: Use hypothesis
    # TODO: Integrate into `test_perfect_fit_2d`
    @pytest.mark.parametrize(
        "x, coef, intercept",
        [
            (np.array([[0], [1], [2], [3]]), 0, 1),
            (np.array([[0], [1], [2], [3]]), 2, 1),
            (np.array([[0], [1], [2], [3]]), 2, -1),
        ],
    )
    def test_perfect_fit_with_intercept(
        self, x: ndarray, coef: float, intercept: float
    ):
        lr = LinearRegression(fit_intercept=True, positive_intercept=False)
        lr.fit(x, x * coef + intercept)

        assert lr.coef_.shape == (x.shape[1],)
        assert lr.intercept_.shape == (x.shape[1],)
        assert_array_equal(lr.intercept_, intercept)
        assert_array_equal(lr.coef_, coef)

--------------------------------------------------------------------------------
/TFvelo/core/tests/test_metrics.py:
--------------------------------------------------------------------------------
from hypothesis import given
from hypothesis import strategies as st
from hypothesis.extra.numpy import arrays

import numpy as np
from numpy import ndarray

from scvelo.core import l2_norm


# TODO: Extend test to generate sparse inputs as well
@given(
    a=arrays(
        float,
        shape=st.integers(min_value=1, max_value=100),
        elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False),
    ),
    axis=st.integers(min_value=0, max_value=1),
)
def test_l2_norm(a: ndarray, axis: int):
    if a.ndim == 1:
        assert np.allclose(np.linalg.norm(a), l2_norm(a, axis=axis))
    else:
        assert np.allclose(np.linalg.norm(a, axis=axis), l2_norm(a, axis=axis))

--------------------------------------------------------------------------------
/TFvelo/core/tests/test_models.py:
--------------------------------------------------------------------------------
from typing import List

import pytest
from hypothesis import given
from hypothesis import strategies as st
from hypothesis.extra.numpy import arrays

import numpy as np
from numpy import ndarray
from scipy.integrate import odeint

from scvelo.core import SplicingDynamics


class TestSplicingDynamics:
    @given(
        alpha=st.floats(min_value=0, allow_infinity=False),
        beta=st.floats(min_value=0, max_value=1, exclude_min=True),
        gamma=st.floats(min_value=0, max_value=1, exclude_min=True),
        initial_state=st.lists(
            st.floats(min_value=0, allow_infinity=False), min_size=2, max_size=2
        ),
        t=arrays(
            float,
            shape=st.integers(min_value=1, max_value=100),
            elements=st.floats(
                min_value=0, max_value=1e3, allow_infinity=False, allow_nan=False
            ),
        ),
        with_keys=st.booleans(),
    )
    def test_output_form(
        self,
        alpha: float,
        beta: float,
        gamma: float,
        initial_state: List[float],
        t: ndarray,
        with_keys: bool,
    ):
        if beta == gamma:
            gamma = gamma + 1e-6

        splicing_dynamics = SplicingDynamics(
            alpha=alpha, beta=beta, gamma=gamma, initial_state=initial_state
        )
        solution = splicing_dynamics.get_solution(t=t, with_keys=with_keys)

        if not with_keys:
            assert type(solution) == ndarray
            assert solution.shape == (len(t), 2)
        else:
            assert len(solution) == 2
            assert type(solution) == dict
            assert list(solution.keys()) == ["u", "s"]
            assert all([len(var) == len(t) for var in solution.values()])

    # TODO: Check how / if hypothesis can be used instead.
    @pytest.mark.parametrize(
        "alpha, beta, gamma, initial_state",
        [
            (5, 0.5, 0.4, [0, 1]),
        ],
    )
    def test_solution(self, alpha, beta, gamma, initial_state):
        def model(y, t, alpha, beta, gamma):
            dydt = np.zeros(2)
            dydt[0] = alpha - beta * y[0]
            dydt[1] = beta * y[0] - gamma * y[1]

            return dydt

        t = np.linspace(0, 20, 10000)
        splicing_dynamics = SplicingDynamics(
            alpha=alpha, beta=beta, gamma=gamma, initial_state=initial_state
        )
        exact_solution = splicing_dynamics.get_solution(t=t)

        numerical_solution = odeint(
            model,
            np.array(initial_state),
            t,
            args=(
                alpha,
                beta,
                gamma,
            ),
        )

        assert np.allclose(numerical_solution, exact_solution)

--------------------------------------------------------------------------------
/TFvelo/logging.py:
--------------------------------------------------------------------------------
"""Logging and Profiling
"""
from datetime import datetime
from platform import python_version
from sys import stdout
from time import time as get_time

from packaging.version import parse

from anndata.logging import get_memory_usage

from . import settings

_VERBOSITY_LEVELS_FROM_STRINGS = {"error": 0, "warn": 1, "info": 2, "hint": 3}


def info(*args, **kwargs):
    return msg(*args, v="info", **kwargs)


def error(*args, **kwargs):
    args = ("Error:",) + args
    return msg(*args, v="error", **kwargs)


def warn(*args, **kwargs):
    args = ("WARNING:",) + args
    return msg(*args, v="warn", **kwargs)


def hint(*args, **kwargs):
    return msg(*args, v="hint", **kwargs)


def _settings_verbosity_greater_or_equal_than(v):
    if isinstance(settings.verbosity, str):
        settings_v = _VERBOSITY_LEVELS_FROM_STRINGS[settings.verbosity]
    else:
        settings_v = settings.verbosity
    return settings_v >= v


def msg(
    *msg,
    v=None,
    time=False,
    memory=False,
    reset=False,
    end="\n",
    no_indent=False,
    t=None,
    m=None,
    r=None,
):
    """Write message to logging output.
    Log output defaults to standard output but can be set to a file
    by setting `sc.settings.log_file = 'mylogfile.txt'`.
    v : {'error', 'warn', 'info', 'hint'} or int, (default: 4)
        0/'error', 1/'warn', 2/'info', 3/'hint', 4, 5, 6...
    time, t : bool, optional (default: False)
        Print timing information; restart the clock.
    memory, m : bool, optional (default: False)
        Print memory information.
    reset, r : bool, optional (default: False)
        Reset timing and memory measurement. Is automatically reset
        when passing one of ``time`` or ``memory``.
    end : str (default: '\n')
        Same meaning as in builtin ``print()`` function.
    no_indent : bool (default: False)
        Do not indent for ``v >= 4``.
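
    Example
    -------
    A hedged sketch (assuming the default `settings.verbosity >= 3`):

    >>> from TFvelo import logging as logg
    >>> logg.hint("computing velocities")
    --> computing velocities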
71 | """ 72 | # variable shortcuts 73 | if t is not None: 74 | time = t 75 | if m is not None: 76 | memory = m 77 | if r is not None: 78 | reset = r 79 | if v is None: 80 | v = 4 81 | if isinstance(v, str): 82 | v = _VERBOSITY_LEVELS_FROM_STRINGS[v] 83 | if v == 3: # insert "--> " before hints 84 | msg = ("-->",) + msg 85 | if v >= 4 and not no_indent: 86 | msg = (" ",) + msg 87 | if _settings_verbosity_greater_or_equal_than(v): 88 | if not time and not memory and len(msg) > 0: 89 | _write_log(*msg, end=end) 90 | if reset: 91 | try: 92 | settings._previous_memory_usage, _ = get_memory_usage() 93 | except Exception: 94 | pass 95 | settings._previous_time = get_time() 96 | if time: 97 | elapsed = get_passed_time() 98 | msg = msg + (f"({_sec_to_str(elapsed)})",) 99 | _write_log(*msg, end=end) 100 | if memory: 101 | _write_log(get_memory_usage(), end=end) 102 | 103 | 104 | m = msg 105 | 106 | 107 | def _write_log(*msg, end="\n"): 108 | """Write message to log output, ignoring the verbosity level. 109 | This is the most basic function. 110 | Parameters 111 | ---------- 112 | *msg : 113 | One or more arguments to be formatted as string. Same behavior as print 114 | function. 115 | """ 116 | from .settings import logfile 117 | 118 | if logfile == "": 119 | print(*msg, end=end) 120 | else: 121 | out = "" 122 | for s in msg: 123 | out += f"{s} " 124 | with open(logfile, "a") as f: 125 | f.write(out + end) 126 | 127 | 128 | def _sec_to_str(t, show_microseconds=False): 129 | """Format time in seconds. 130 | Parameters 131 | ---------- 132 | t : int 133 | Time in seconds. 134 | """ 135 | from functools import reduce 136 | 137 | t_str = "%d:%02d:%02d.%02d" % reduce( 138 | lambda ll, b: divmod(ll[0], b) + ll[1:], [(t * 100,), 100, 60, 60] 139 | ) 140 | return t_str if show_microseconds else t_str[:-3] 141 | 142 | 143 | def get_passed_time(): 144 | now = get_time() 145 | elapsed = now - settings._previous_time 146 | settings._previous_time = now 147 | return elapsed 148 | 149 | 150 | def print_passed_time(): 151 | return _sec_to_str(get_passed_time()) 152 | 153 | 154 | def timeout(func, args=(), timeout_duration=2, default=None, **kwargs): 155 | """This will spwan a thread and run the given function using the args, kwargs and 156 | return the given default value if the timeout_duration is exceeded 157 | """ 158 | import threading 159 | 160 | class InterruptableThread(threading.Thread): 161 | def __init__(self): 162 | threading.Thread.__init__(self) 163 | self.result = default 164 | 165 | def run(self): 166 | try: 167 | self.result = func(*args, **kwargs) 168 | except Exception: 169 | pass 170 | 171 | it = InterruptableThread() 172 | it.start() 173 | it.join(timeout_duration) 174 | return it.result 175 | 176 | 177 | def get_latest_pypi_version(): 178 | from subprocess import CalledProcessError, check_output 179 | 180 | try: # needs to work offline as well 181 | result = check_output(["pip", "search", "scvelo"]) 182 | return f"{result.split()[-1]}"[2:-1] 183 | except CalledProcessError: 184 | return "0.0.0" 185 | 186 | 187 | def check_if_latest_version(): 188 | from . 
import __version__ 189 | 190 | latest_version = timeout( 191 | get_latest_pypi_version, timeout_duration=2, default="0.0.0" 192 | ) 193 | if parse(__version__.rsplit(".dev")[0]) < parse(latest_version.rsplit(".dev")[0]): 194 | warn( 195 | "There is a newer scvelo version available on PyPI:\n", 196 | "Your version: \t\t", 197 | __version__, 198 | "\nLatest version: \t", 199 | latest_version, 200 | ) 201 | 202 | 203 | def print_version(): 204 | from . import __version__ 205 | 206 | _write_log( 207 | f"Running scvelo {__version__} " 208 | f"(python {python_version()}) on {get_date_string()}.", 209 | ) 210 | check_if_latest_version() 211 | 212 | 213 | def print_versions(): 214 | for mod in [ 215 | "scvelo", 216 | "scanpy", 217 | "anndata", 218 | "loompy", 219 | "numpy", 220 | "scipy", 221 | "matplotlib", 222 | "sklearn", 223 | "pandas", 224 | ]: 225 | mod_name = mod[0] if isinstance(mod, tuple) else mod 226 | mod_install = mod[1] if isinstance(mod, tuple) else mod 227 | try: 228 | mod_version = __import__(mod_name).__version__ 229 | _write_log(f"{mod_install}=={mod_version}", end=" ") 230 | except (ImportError, AttributeError): 231 | pass 232 | _write_log("") 233 | check_if_latest_version() 234 | 235 | 236 | def get_date_string(): 237 | return datetime.now().strftime("%Y-%m-%d %H:%M") 238 | 239 | 240 | def switch_verbosity(mode="on", module=None): 241 | if module is None: 242 | from . import settings 243 | elif module == "scanpy": 244 | from scanpy import settings 245 | else: 246 | settings = __import__(module, fromlist=["settings"]).settings # exec() would not bind a local `settings` here 247 | 248 | if mode == "on" and hasattr(settings, "tmp_verbosity"): 249 | settings.verbosity = settings.tmp_verbosity 250 | del settings.tmp_verbosity 251 | 252 | elif mode == "off": 253 | settings.tmp_verbosity = settings.verbosity 254 | settings.verbosity = 0 255 | 256 | elif not isinstance(mode, str): 257 | settings.tmp_verbosity = settings.verbosity 258 | settings.verbosity = mode 259 | 260 | 261 | class ProgressReporter: 262 | def __init__(self, total, interval=3): 263 | self.count = 0 264 | self.total = total 265 | self.timestamp = get_time() 266 | self.interval = interval 267 | 268 | def update(self): 269 | self.count += 1 270 | if settings.verbosity > 1 and ( 271 | get_time() - self.timestamp > self.interval or self.count == self.total 272 | ): 273 | self.timestamp = get_time() 274 | percent = int(self.count * 100 / self.total) 275 | stdout.write(f"\r... {percent}%") 276 | stdout.flush() 277 | 278 | def finish(self): 279 | if settings.verbosity > 1: 280 | stdout.write("\r") 281 | stdout.flush() 282 | 283 | 284 | def profiler(command, filename="profile.stats", n_stats=10): 285 | """Profiler for a python program 286 | 287 | Runs cProfile and outputs ordered statistics that describe 288 | how often and for how long various parts of the program are executed. 289 | 290 | Stats can be visualized with `!snakeviz profile.stats`. 291 | 292 | Parameters 293 | ---------- 294 | command: str 295 | Command string to be executed. 296 | filename: str 297 | Name under which to store the stats. 298 | n_stats: int or None 299 | Number of top stats to show.
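Example
-------
A minimal sketch (the profiled statement is an arbitrary placeholder, not a
TFvelo-specific command):

.. code:: python

    from TFvelo.logging import profiler

    # profile a statement and print the five most expensive calls
    stats = profiler("sum(range(10**6))", filename="profile.stats", n_stats=5)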
300 | """ 301 | import cProfile 302 | import pstats 303 | 304 | cProfile.run(command, filename) 305 | stats = pstats.Stats(filename).strip_dirs().sort_stats("time") 306 | return stats.print_stats(n_stats or {}) 307 | -------------------------------------------------------------------------------- /TFvelo/pl.py: -------------------------------------------------------------------------------- 1 | from .plotting import * # noqa 2 | -------------------------------------------------------------------------------- /TFvelo/plotting/__init__.py: -------------------------------------------------------------------------------- 1 | from scanpy.plotting import paga_compare, rank_genes_groups 2 | 3 | from .gridspec import gridspec 4 | from .heatmap import heatmap 5 | from .paga import paga 6 | from .proportions import proportions 7 | from .scatter import diffmap, draw_graph, pca, phate, scatter, tsne, umap 8 | from .simulation import simulation 9 | from .utils import hist, plot 10 | from .velocity import velocity 11 | from .velocity_embedding import velocity_embedding 12 | from .velocity_embedding_grid import velocity_embedding_grid 13 | from .velocity_embedding_stream import velocity_embedding_stream 14 | from .velocity_graph import velocity_graph 15 | 16 | __all__ = [ 17 | "diffmap", 18 | "draw_graph", 19 | "gridspec", 20 | "heatmap", 21 | "hist", 22 | "paga", 23 | "paga_compare", 24 | "pca", 25 | "phate", 26 | "plot", 27 | "proportions", 28 | "rank_genes_groups", 29 | "scatter", 30 | "simulation", 31 | "tsne", 32 | "umap", 33 | "velocity", 34 | "velocity_embedding", 35 | "velocity_embedding_grid", 36 | "velocity_embedding_stream", 37 | "velocity_graph", 38 | ] 39 | -------------------------------------------------------------------------------- /TFvelo/plotting/docs.py: -------------------------------------------------------------------------------- 1 | """Shared docstrings for plotting function parameters. 2 | """ 3 | from textwrap import dedent 4 | 5 | 6 | def doc_params(**kwds): 7 | """\ 8 | Docstrings should start with "\" in the first line for proper formatting. 9 | """ 10 | 11 | def dec(obj): 12 | obj.__doc__ = dedent(obj.__doc__).format(**kwds) 13 | return obj 14 | 15 | return dec 16 | 17 | 18 | doc_scatter = """\ 19 | basis: `str` or list of `str` (default: `None`) 20 | Key for embedding. If not specified, use 'umap', 'tsne' or 'pca' (ordered by 21 | preference). 22 | vkey: `str` or list of `str` (default: `None`) 23 | Key for velocity / steady-state ratio to be visualized. 24 | color: `str`, list of `str` or `None` (default: `None`) 25 | Key for annotations of observations/cells or variables/genes 26 | use_raw : `bool` (default: `None`) 27 | Use `raw` attribute of `adata` if present. 28 | layer: `str`, list of `str` or `None` (default: `None`) 29 | Specify the layer for `color`. 30 | color_map: `str` (default: `matplotlib.rcParams['image.cmap']`) 31 | String denoting matplotlib color map. 32 | colorbar: `bool` (default: `False`) 33 | Whether to show colorbar. 34 | palette: list of `str` (default: `None`) 35 | Colors to use for plotting groups (categorical annotation). 36 | size: `float` (default: 5) 37 | Point size. 38 | alpha: `float` (default: 1) 39 | Set blending - 0 transparent to 1 opaque. 40 | linewidth: `float` (default: 1) 41 | Scaling factor for the width of occurring lines. 42 | linecolor: `str` ir list of `str` (default: 'k') 43 | Color of lines from velocity fits, linear fits and polynomial fits 44 | perc: tuple, e.g. 
[2,98] (default: `None`) 45 | Specify percentile for continuous coloring. 46 | groups: `str` or list of `str` (default: `all groups`) 47 | Restrict to a few categories in categorical observation annotation. 48 | Multiple categories can be passed as list with ['cluster_1', 'cluster_3'], 49 | or as string with 'cluster_1, cluster_3'. 50 | sort_order: `bool` (default: `True`) 51 | For continuous annotations used as color parameter, 52 | plot data points with higher values on top of others. 53 | components: `str` or list of `str` (default: '1,2') 54 | For instance, ['1,2', '2,3']. 55 | projection: {'2d', '3d'} (default: '2d') 56 | Projection of plot. 57 | legend_loc: str (default: 'none') 58 | Location of legend, either 'on data', 'right margin' or valid keywords 59 | for matplotlib.legend. 60 | legend_fontsize: `int` (default: `None`) 61 | Legend font size. 62 | legend_fontweight: {'normal', 'bold', ...} (default: `None`) 63 | Legend font weight. A numeric value in range 0-1000 or a string. 64 | Defaults to 'bold' if `legend_loc = 'on data'`, otherwise to 'normal'. 65 | Available are `['light', 'normal', 'medium', 'semibold', 'bold', 'heavy', 'black']`. 66 | legend_fontoutline: float (default: `None`) 67 | Line width of the legend font outline in pt. Draws a white outline using 68 | the path effect :class:`~matplotlib.patheffects.withStroke`. 69 | legend_align_text: bool or str (default: `None`) 70 | Aligns the positions of the legend texts. Set the axis along which the best 71 | alignment should be determined. This can be 'y' or True (vertically), 72 | 'x' (horizontally), or 'xy' (best alignment in both directions). 73 | right_margin: `float` or list of `float` (default: `None`) 74 | Adjust the width of the space right of each plotting panel. 75 | left_margin: `float` or list of `float` (default: `None`) 76 | Adjust the width of the space left of each plotting panel. 77 | xlabel: `str` (default: `None`) 78 | Label of x-axis. 79 | ylabel: `str` (default: `None`) 80 | Label of y-axis. 81 | title: `str` (default: `None`) 82 | Provide title for panels, either as single `str` or as list, e.g. `["title1", "title2", ...]`. 83 | fontsize: `float` (default: `None`) 84 | Label font size. 85 | figsize: tuple (default: `(7,5)`) 86 | Figure size. 87 | xlim: tuple, e.g. [0,1] or `None` (default: `None`) 88 | Restrict x-limits of the axis. 89 | ylim: tuple, e.g. [0,1] or `None` (default: `None`) 90 | Restrict y-limits of the axis. 91 | add_density: `bool` or `str` or `None` (default: `None`) 92 | Whether to show density of values along x and y axes. 93 | Color of the density plot can also be passed as `str`. 94 | add_assignments: `bool` or `str` or `None` (default: `None`) 95 | Whether to add assignments to the model curve. 96 | Color of the assignments can also be passed as `str`. 97 | add_linfit: `bool` or `str` or `None` (default: `None`) 98 | Whether to add linear regression fit to the data points. 99 | Color of the line can also be passed as `str`. 100 | Fitting with or without an intercept by passing `'intercept'` or `'no_intercept'`. 101 | A colored regression line with intercept is obtained with `'intercept, blue'`. 102 | add_polyfit: `bool` or `str` or `int` or `None` (default: `None`) 103 | Whether to add polynomial fit to the data points. Color of the polyfit plot can also 104 | be passed as `str`. The degree of the polynomial fit can be passed as `int` 105 | (default is 2 for quadratic fit). 106 | Fitting with or without an intercept by passing `'intercept'` or `'no_intercept'`.
107 | A colored regression line with intercept is obtained with `'intercept, blue'`. 108 | add_rug: `str` or `None` (default: `None`) 109 | If categorical observation annotation (e.g. 'clusters') is given, a rugplot is 110 | attached to the x-axis showing the data membership to each of the categories. 111 | add_text: `str` (default: `None`) 112 | Text to be added to the plot, passed as `str`. 113 | add_text_pos: `tuple`, e.g. [0.05, 0.95] (default: `[0.05, 0.95]`) 114 | Text position. Default is `[0.05, 0.95]`, positioning the text at the top left corner. 115 | add_margin: `float` (default: `None`) 116 | A value between [-1, 1] to add (positive) and reduce (negative) figure margins. 117 | add_outline: `bool` or `str` (default: `False`) 118 | Whether to show an outline around scatter plot dots. 119 | Alternatively a string of cluster names can be passed, e.g. 'cluster_1, clusters_3'. 120 | outline_width: tuple of type `scalar` or `None` (default: `(0.3, 0.05)`) 121 | Width of the inner and outer outline. 122 | outline_color: tuple of type `str` or `None` (default: `('black', 'white')`) 123 | Inner and outer matplotlib color of the outline. 124 | n_convolve: `int` or `None` (default: `None`) 125 | If `int` is given, data is smoothed by convolution 126 | along the x-axis with kernel size `n_convolve`. 127 | smooth: `bool` or `int` (default: `None`) 128 | Whether to convolve/average the color values over the nearest neighbors. 129 | If `int`, it specifies number of neighbors. 130 | normalize_data: `bool` (default: `None`) 131 | Whether to rescale values for x, y to [0,1]. 132 | rescale_color: `tuple` (default: `None`) 133 | Boundaries for color rescaling, e.g. [0, 1], setting min/max values of the colorbar. 134 | color_gradients: `str` or `np.ndarray` (default: `None`) 135 | Key for `.obsm` or array with color gradients by categories. 136 | dpi: `int` (default: 80) 137 | Figure dpi. 138 | frameon: `bool` (default: `True`) 139 | Draw a frame around the scatter plot. 140 | ncols: `int` (default: `None`) 141 | Number of panels per row. 142 | nrows: `int` (default: `None`) 143 | Number of panels per column. 144 | wspace : `float` (default: None) 145 | Adjust the width of the space between multiple panels. 146 | hspace : `float` (default: None) 147 | Adjust the height of the space between multiple panels. 148 | show: `bool`, optional (default: `None`) 149 | Show the plot, do not return axis. 150 | save: `bool` or `str`, optional (default: `None`) 151 | If `True` or a `str`, save the figure. A string is appended to the default filename. 152 | Infer the filetype if ending on {'.pdf', '.png', '.svg'}. 153 | ax: `matplotlib.Axes`, optional (default: `None`) 154 | A matplotlib axes object.
Only works if plotting a single component.\ 155 | """ 156 | -------------------------------------------------------------------------------- /TFvelo/plotting/gridspec.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import matplotlib.pyplot as pl 4 | 5 | # todo: auto-complete and docs wrapper 6 | from .scatter import scatter 7 | from .utils import get_figure_params, hist 8 | from .velocity_embedding import velocity_embedding 9 | from .velocity_embedding_grid import velocity_embedding_grid 10 | from .velocity_embedding_stream import velocity_embedding_stream 11 | from .velocity_graph import velocity_graph 12 | 13 | 14 | def _wraps_plot(wrapper, func): 15 | args = {"self", "kwargs"} 16 | annots_orig = {k: v for k, v in wrapper.__annotations__.items() if k not in args} 17 | annots = {k: v for k, v in func.__annotations__.items()} 18 | wrapper.__annotations__ = {**annots, **annots_orig} 19 | wrapper.__wrapped__ = func 20 | return wrapper 21 | 22 | 23 | _wraps_plot_scatter = partial(_wraps_plot, func=scatter) 24 | _wraps_plot_hist = partial(_wraps_plot, func=hist) 25 | _wraps_plot_velocity_graph = partial(_wraps_plot, func=velocity_graph) 26 | _wraps_plot_velocity_embedding = partial(_wraps_plot, func=velocity_embedding) 27 | _wraps_plot_velocity_embedding_grid = partial(_wraps_plot, func=velocity_embedding_grid) 28 | _wraps_plot_velocity_embedding_stream = partial( 29 | _wraps_plot, func=velocity_embedding_stream 30 | ) 31 | 32 | 33 | def gridspec(ncols=4, nrows=1, figsize=None, dpi=None): 34 | figsize, dpi = get_figure_params(figsize, dpi, ncols) 35 | gs = pl.GridSpec( 36 | nrows, ncols, pl.figure(None, (figsize[0] * ncols, figsize[1] * nrows), dpi=dpi) 37 | ) 38 | return gs 39 | 40 | 41 | class GridSpec: 42 | def __init__(self, ncols=4, nrows=1, figsize=None, dpi=None, **scatter_kwargs): 43 | """Specifies the geometry of the grid that subplots can be placed in. 44 | 45 | Example 46 | 47 | .. code:: python 48 | 49 | with scv.GridSpec() as pl: 50 | pl.scatter(adata, basis='pca') 51 | pl.scatter(adata, basis='umap') 52 | pl.hist(adata.obs.initial_size) 53 | 54 | Parameters 55 | ---------- 56 | ncols: `int` (default: 4) 57 | Number of panels per row. 58 | nrows: `int` (default: 1) 59 | Number of panels per column. 60 | figsize: tuple (default: `None`) 61 | Figure size. 62 | dpi: `int` (default: `None`) 63 | Figure dpi. 64 | scatter_kwargs: 65 | Arguments to be passed to all scatter plots, e.g. `frameon=False`.
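The same pattern works with shared ``scatter_kwargs`` (a sketch; any scatter
argument such as ``frameon=False`` is forwarded to every panel, and the
stream plot assumes velocities have already been computed):

.. code:: python

    with scv.GridSpec(ncols=2, frameon=False) as pl:
        pl.scatter(adata, basis="umap", color="clusters")
        pl.velocity_embedding_stream(adata, basis="umap")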
66 | """ 67 | self.ncols, self.nrows, self.figsize, self.dpi = ncols, nrows, figsize, dpi 68 | self.scatter_kwargs = scatter_kwargs 69 | self.scatter_kwargs.update({"show": False}) 70 | self.get_new_grid() 71 | self.new_row = None 72 | 73 | def __enter__(self): 74 | return self 75 | 76 | def __exit__(self, exc_type, exc_val, exc_tb): 77 | if self.new_row and self.count < self.max_count: 78 | ax = pl.subplot(self.gs[self.max_count - 1]) 79 | ax.axis("off") 80 | pl.show() 81 | 82 | def get_new_grid(self): 83 | self.gs = gridspec(self.ncols, self.nrows, self.figsize, self.dpi) 84 | geo = self.gs[0].get_geometry() 85 | self.max_count, self.count, self.new_row = geo[0] * geo[1], 0, True 86 | 87 | def get_ax(self): 88 | if self.count >= self.max_count: 89 | self.get_new_grid() 90 | self.count += 1 91 | return pl.subplot(self.gs[self.count - 1]) 92 | 93 | def get_kwargs(self, kwargs=None): 94 | _kwargs = self.scatter_kwargs.copy() 95 | if kwargs is not None: 96 | _kwargs.update(kwargs) 97 | _kwargs.update({"ax": self.get_ax(), "show": False}) 98 | return _kwargs 99 | 100 | @_wraps_plot_scatter 101 | def scatter(self, adata, **kwargs): 102 | return scatter(adata, **self.get_kwargs(kwargs)) 103 | 104 | @_wraps_plot_velocity_embedding 105 | def velocity_embedding(self, adata, **kwargs): 106 | return velocity_embedding(adata, **self.get_kwargs(kwargs)) 107 | 108 | @_wraps_plot_velocity_embedding_grid 109 | def velocity_embedding_grid(self, adata, **kwargs): 110 | return velocity_embedding_grid(adata, **self.get_kwargs(kwargs)) 111 | 112 | @_wraps_plot_velocity_embedding_stream 113 | def velocity_embedding_stream(self, adata, **kwargs): 114 | return velocity_embedding_stream(adata, **self.get_kwargs(kwargs)) 115 | 116 | @_wraps_plot_velocity_graph 117 | def velocity_graph(self, adata, **kwargs): 118 | return velocity_graph(adata, **self.get_kwargs(kwargs)) 119 | 120 | @_wraps_plot_hist 121 | def hist(self, array, **kwargs): 122 | return hist(array, **self.get_kwargs(kwargs)) 123 | -------------------------------------------------------------------------------- /TFvelo/plotting/heatmap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from scipy.sparse import issparse 4 | 5 | from .. import logging as logg 6 | from .utils import ( 7 | interpret_colorkey, 8 | is_categorical, 9 | savefig_or_show, 10 | set_colors_for_categorical_obs, 11 | strings_to_categoricals, 12 | to_list, 13 | ) 14 | 15 | 16 | def heatmap( 17 | adata, 18 | var_names, 19 | sortby="latent_time", 20 | layer="M_total", 21 | color_map="viridis", 22 | col_color=None, 23 | palette="viridis", 24 | n_convolve=30, 25 | standard_scale=0, 26 | sort=True, 27 | filter_start=0, 28 | filter_end=1, 29 | colorbar=None, 30 | col_cluster=False, 31 | row_cluster=False, 32 | context=None, 33 | font_scale=None, 34 | figsize=(8, 4), 35 | show=None, 36 | save=None, 37 | **kwargs, 38 | ): 39 | """\ 40 | Plot time series for genes as heatmap. 41 | 42 | Arguments 43 | --------- 44 | adata: :class:`~anndata.AnnData` 45 | Annotated data matrix. 46 | var_names: `str`, list of `str` 47 | Names of variables to use for the plot. 48 | sortby: `str` (default: `'latent_time'`) 49 | Observation key to extract time data from. 50 | layer: `str` (default: `'Ms'`) 51 | Layer key to extract count data from. 52 | color_map: `str` (default: `'viridis'`) 53 | String denoting matplotlib color map. 
54 | col_color: `str` or list of `str` (default: `None`) 55 | Key(s) of observation annotation(s) to color the columns by. 56 | palette: list of `str` (default: `'viridis'`) 57 | Colors to use for plotting groups (categorical annotation). 58 | n_convolve: `int` or `None` (default: `30`) 59 | If `int` is given, data is smoothed by convolution 60 | along the x-axis with kernel size n_convolve. 61 | standard_scale : `int` or `None` (default: `0`) 62 | Either 0 (rows) or 1 (columns). Whether or not to standardize that dimension 63 | (each row or column), subtract minimum and divide each by its maximum. 64 | sort: `bool` (default: `True`) 65 | Whether to sort the expression values given by xkey. 66 | colorbar: `bool` or `None` (default: `None`) 67 | Whether to show colorbar. 68 | {row,col}_cluster : `bool` or `None` 69 | If True, cluster the {rows, columns}. 70 | context : `None`, or one of {paper, notebook, talk, poster} 71 | A dictionary of parameters or the name of a preconfigured set. 72 | font_scale : float, optional 73 | Scaling factor to scale the size of the font elements. 74 | figsize: tuple (default: `(8,4)`) 75 | Figure size. 76 | show: `bool`, optional (default: `None`) 77 | Show the plot, do not return axis. 78 | save: `bool` or `str`, optional (default: `None`) 79 | If `True` or a `str`, save the figure. A string is appended to the default 80 | filename. Infer the filetype if ending on {'.pdf', '.png', '.svg'}. 81 | kwargs: 82 | Arguments passed to seaborn's clustermap, 83 | e.g., set `yticklabels=True` to display all gene names in all rows. 84 | 85 | Returns 86 | ------- 87 | If `show==False` a `matplotlib.Axis` 88 | """ 89 | 90 | import seaborn as sns 91 | 92 | var_names = [name for name in var_names if name in adata.var_names] 93 | 94 | tkey, xkey = kwargs.pop("tkey", sortby), kwargs.pop("xkey", layer) 95 | time = adata.obs[tkey].values 96 | time = time[np.isfinite(time)] 97 | 98 | X = ( 99 | adata[:, var_names].layers[xkey] 100 | if xkey in adata.layers.keys() 101 | else adata[:, var_names].X 102 | ) 103 | if issparse(X): 104 | X = X.A 105 | df = pd.DataFrame(X[np.argsort(time)], columns=var_names) 106 | 107 | if n_convolve is not None: 108 | weights = np.ones(n_convolve) / n_convolve 109 | for gene in var_names: 110 | try: 111 | df[gene] = np.convolve(df[gene].values, weights, mode="same") 112 | except Exception: 113 | pass # e.g.
all-zero counts or nans cannot be convolved 114 | 115 | if sort: 116 | max_sort = np.argsort(np.argmax(df.values, axis=0)) 117 | df = pd.DataFrame(df.values[:, max_sort], columns=df.columns[max_sort]) 118 | n_genes = len(var_names) 119 | df = pd.DataFrame(df.values[:, int(filter_start * n_genes): int(filter_end * n_genes)], columns=df.columns[int(filter_start * n_genes): int(filter_end * n_genes)]) 120 | strings_to_categoricals(adata) 121 | 122 | if col_color is not None: 123 | col_colors = to_list(col_color) 124 | col_color = [] 125 | for _, col in enumerate(col_colors): 126 | if not is_categorical(adata, col): 127 | obs_col = adata.obs[col] 128 | cat_col = np.round(obs_col / np.max(obs_col), 2) * np.max(obs_col) 129 | adata.obs[f"{col}_categorical"] = pd.Categorical(cat_col) 130 | col += "_categorical" 131 | set_colors_for_categorical_obs(adata, col, palette) 132 | col_color.append(interpret_colorkey(adata, col)[np.argsort(time)]) 133 | 134 | if "dendrogram_ratio" not in kwargs: 135 | kwargs["dendrogram_ratio"] = ( 136 | 0.1 if row_cluster else 0, 137 | 0.2 if col_cluster else 0, 138 | ) 139 | if "cbar_pos" not in kwargs or not colorbar: 140 | kwargs["cbar_pos"] = None 141 | 142 | kwargs.update( 143 | dict( 144 | col_colors=col_color, 145 | col_cluster=col_cluster, 146 | row_cluster=row_cluster, 147 | cmap=color_map, 148 | xticklabels=False, 149 | standard_scale=standard_scale, 150 | figsize=figsize, 151 | ) 152 | ) 153 | 154 | args = {} 155 | if font_scale is not None: 156 | args = {"font_scale": font_scale} 157 | context = context or "notebook" 158 | 159 | with sns.plotting_context(context=context, **args): 160 | try: 161 | cm = sns.clustermap(df.T, **kwargs) 162 | except Exception: 163 | logg.warn("Please upgrade seaborn with `pip install -U seaborn`.") 164 | kwargs.pop("dendrogram_ratio") 165 | kwargs.pop("cbar_pos") 166 | cm = sns.clustermap(df.T, **kwargs) 167 | 168 | savefig_or_show("heatmap", save=save, show=show) 169 | #if show is False: 170 | # return cm 171 | return df 172 | -------------------------------------------------------------------------------- /TFvelo/plotting/palettes.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping, Sequence 2 | 3 | from matplotlib import cm, colors 4 | 5 | """Color palettes in addition to matplotlib's palettes.""" 6 | 7 | 8 | # Colorblindness adjusted vega_10 9 | # See https://github.com/theislab/scanpy/issues/387 10 | vega_10 = list(map(colors.to_hex, cm.tab10.colors)) 11 | vega_10_scanpy = vega_10.copy() 12 | vega_10_scanpy[2] = "#279e68" # green 13 | vega_10_scanpy[4] = "#aa40fc" # purple 14 | vega_10_scanpy[8] = "#b5bd61" # khaki 15 | 16 | # default matplotlib 2.0 palette 17 | # see 'category20' on https://github.com/vega/vega/wiki/Scales#scale-range-literals 18 | vega_20 = list(map(colors.to_hex, cm.tab20.colors)) 19 | 20 | # reordered, some removed, some added 21 | vega_20_scanpy = [ 22 | *vega_20[0:14:2], 23 | *vega_20[16::2], # dark without grey 24 | *vega_20[1:15:2], 25 | *vega_20[17::2], # light without grey 26 | "#ad494a", 27 | "#8c6d31", # manual additions 28 | ] 29 | vega_20_scanpy[2] = vega_10_scanpy[2] 30 | vega_20_scanpy[4] = vega_10_scanpy[4] 31 | vega_20_scanpy[7] = vega_10_scanpy[8] # khaki shifted by missing grey 32 | 33 | default_20 = vega_20_scanpy 34 | 35 | # fmt: off 36 | # orig reference http://epub.wu.ac.at/1692/1/document.pdf 37 | zeileis_26 = [ 38 | "#023fa5", "#7d87b9", "#bec1d4", "#d6bcc0", "#bb7784", "#8e063b", "#4a6fe3", 39 | "#8595e1", "#b5bbe3",
"#e6afb9", "#e07b91", "#d33f6a", "#11c638", "#8dd593", 40 | "#c6dec7", "#ead3c6", "#f0b98d", "#ef9708", "#0fcfc0", "#9cded6", "#d5eae7", 41 | "#f3e1eb", "#f6c4e1", "#f79cd4", "#7f7f7f", "#c7c7c7", "#1CE6FF", "#336600", 42 | ] 43 | 44 | default_26 = zeileis_26 45 | 46 | # from godsnotwheregodsnot.blogspot.de/2012/09/color-distribution-methodology.html 47 | godsnot_64 = [ 48 | # "#000000", # remove the black, as often, we have black colored annotation 49 | "#FFFF00", "#1CE6FF", "#FF34FF", "#FF4A46", "#008941", "#006FA6", "#A30059", 50 | "#FFDBE5", "#7A4900", "#0000A6", "#63FFAC", "#B79762", "#004D43", "#8FB0FF", 51 | "#997D87", "#5A0007", "#809693", "#FEFFE6", "#1B4400", "#4FC601", "#3B5DFF", 52 | "#4A3B53", "#FF2F80", "#61615A", "#BA0900", "#6B7900", "#00C2A0", "#FFAA92", 53 | "#FF90C9", "#B903AA", "#D16100", "#DDEFFF", "#000035", "#7B4F4B", "#A1C299", 54 | "#300018", "#0AA6D8", "#013349", "#00846F", "#372101", "#FFB500", "#C2FFED", 55 | "#A079BF", "#CC0744", "#C0B9B2", "#C2FF99", "#001E09", "#00489C", "#6F0062", 56 | "#0CBD66", "#EEC3FF", "#456D75", "#B77B68", "#7A87A1", "#788D66", "#885578", 57 | "#FAD09F", "#FF8A9A", "#D157A0", "#BEC459", "#456648", "#0086ED", "#886F4C", 58 | "#34362D", "#B4A8BD", "#00A6AA", "#452C2C", "#636375", "#A3C8C9", "#FF913F", 59 | "#938A81", "#575329", "#00FECF", "#B05B6F", "#8CD0FF", "#3B9700", "#04F757", 60 | "#C8A1A1", "#1E6E00", "#7900D7", "#A77500", "#6367A9", "#A05837", "#6B002C", 61 | "#772600", "#D790FF", "#9B9700", "#549E79", "#FFF69F", "#201625", "#72418F", 62 | "#BC23FF", "#99ADC0", "#3A2465", "#922329", "#5B4534", "#FDE8DC", "#404E55", 63 | "#0089A3", "#CB7E98", "#A4E804", "#324E72", "#6A3A4C" 64 | ] 65 | 66 | default_64 = godsnot_64 67 | 68 | 69 | # colors in addition to matplotlib's colors 70 | additional_colors = { 71 | 'gold2': '#eec900', 'firebrick3': '#cd2626', 'khaki2': '#eee685', 72 | 'slategray3': '#9fb6cd', 'palegreen3': '#7ccd7c', 'tomato2': '#ee5c42', 73 | 'grey80': '#cccccc', 'grey90': '#e5e5e5', 'wheat4': '#8b7e66', 'grey65': '#a6a6a6', 74 | 'grey10': '#1a1a1a', 'grey20': '#333333', 'grey50': '#7f7f7f', 'grey30': '#4d4d4d', 75 | 'grey40': '#666666', 'antiquewhite2': '#eedfcc', 'grey77': '#c4c4c4', 76 | 'snow4': '#8b8989', 'chartreuse3': '#66cd00', 'yellow4': '#8b8b00', 77 | 'darkolivegreen2': '#bcee68', 'olivedrab3': '#9acd32', 'azure3': '#c1cdcd', 78 | 'violetred': '#d02090', 'mediumpurple3': '#8968cd', 'purple4': '#551a8b', 79 | 'seagreen4': '#2e8b57', 'lightblue3': '#9ac0cd', 'orchid3': '#b452cd', 80 | 'indianred 3': '#cd5555', 'grey60': '#999999', 'mediumorchid1': '#e066ff', 81 | 'plum3': '#cd96cd', 'palevioletred3': '#cd6889' 82 | } 83 | # fmt: on 84 | 85 | 86 | def _plot_color_cylce(clists: Mapping[str, Sequence[str]]): 87 | import numpy as np 88 | 89 | import matplotlib.pyplot as plt 90 | from matplotlib.colors import BoundaryNorm, ListedColormap 91 | 92 | fig, axes = plt.subplots(nrows=len(clists)) # type: plt.Figure, plt.Axes 93 | fig.subplots_adjust(top=0.95, bottom=0.01, left=0.3, right=0.99) 94 | axes[0].set_title("Color Maps/Cycles", fontsize=14) 95 | 96 | for ax, (name, clist) in zip(axes, clists.items()): 97 | n = len(clist) 98 | ax.imshow( 99 | np.arange(n)[None, :].repeat(2, 0), 100 | aspect="auto", 101 | cmap=ListedColormap(clist), 102 | norm=BoundaryNorm(np.arange(n + 1) - 0.5, n), 103 | ) 104 | pos = list(ax.get_position().bounds) 105 | x_text = pos[0] - 0.01 106 | y_text = pos[1] + pos[3] / 2.0 107 | fig.text(x_text, y_text, name, va="center", ha="right", fontsize=10) 108 | 109 | # Turn off all ticks & spines 110 | 
for ax in axes: 111 | ax.set_axis_off() 112 | fig.show() 113 | 114 | 115 | if __name__ == "__main__": 116 | _plot_color_cycle( 117 | {name: colors for name, colors in globals().items() if isinstance(colors, list)} 118 | ) 119 | -------------------------------------------------------------------------------- /TFvelo/plotting/proportions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import matplotlib.pyplot as pl 4 | 5 | from ..core import sum 6 | from .utils import savefig_or_show 7 | 8 | 9 | def proportions( 10 | adata, 11 | groupby="clusters", 12 | layers=None, 13 | highlight="unspliced", 14 | add_labels_pie=True, 15 | add_labels_bar=True, 16 | fontsize=8, 17 | figsize=(10, 2), 18 | dpi=100, 19 | use_raw=True, 20 | show=True, 21 | save=None, 22 | ): 23 | """Plot pie chart of spliced/unspliced proportions. 24 | 25 | Arguments 26 | --------- 27 | adata: :class:`~anndata.AnnData` 28 | Annotated data matrix. 29 | groupby: `str` (default: 'clusters') 30 | Key of observations grouping to consider. 31 | layers: list of `str` (default: `['spliced', 'unspliced', 'ambiguous']`) 32 | Specify the layers of count matrices for computing proportions. 33 | highlight: `str` (default: 'unspliced') 34 | Which proportions to highlight in pie chart. 35 | add_labels_pie: `bool` (default: True) 36 | Whether to add percentage labels in pie chart. 37 | add_labels_bar: `bool` (default: True) 38 | Whether to add percentage labels in bar chart. 39 | fontsize: `float` (default: 8) 40 | Label font size. 41 | figsize: tuple (default: `(10,2)`) 42 | Figure size. 43 | dpi: `int` (default: 100) 44 | Figure dpi. 45 | use_raw : `bool` (default: `True`) 46 | Use initial cell sizes before normalization and filtering. 47 | show: `bool` (default: True) 48 | Show the plot, do not return axis. 49 | save: `bool` or `str`, optional (default: `None`) 50 | If `True` or a `str`, save the figure. A string is appended to the default 51 | filename. Infer the filetype if ending on {'.pdf', '.png', '.svg'}. 52 | 53 | Returns 54 | ------- 55 | Plots the proportions of abundances as pie chart.
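Example
-------
A minimal sketch (``scv.datasets.pancreas()`` is an assumed loader inherited
from the scvelo framework; any AnnData with spliced/unspliced layers works):

.. code:: python

    import TFvelo as scv

    adata = scv.datasets.pancreas()
    scv.pl.proportions(adata, groupby="clusters")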
56 | """ 57 | 58 | # get counts per cell for each layer 59 | if layers is None: 60 | layers = ["spliced", "unspliced", "ambigious"] 61 | layers_keys = [key for key in layers if key in adata.layers.keys()] 62 | counts_layers = [sum(adata.layers[key], axis=1) for key in layers_keys] 63 | 64 | if use_raw: 65 | ikey, obs = "initial_size_", adata.obs 66 | counts_layers = [ 67 | obs[ikey + layer_key] if ikey + layer_key in obs.keys() else c 68 | for layer_key, c in zip(layers_keys, counts_layers) 69 | ] 70 | counts_total = np.sum(counts_layers, 0) 71 | counts_total += counts_total == 0 72 | counts_layers = np.array([counts / counts_total for counts in counts_layers]) 73 | 74 | gspec = pl.GridSpec(1, 2, pl.figure(None, figsize, dpi=dpi)) 75 | colors = pl.get_cmap("tab20b")(np.linspace(0.10, 0.65, len(layers_keys))) 76 | 77 | # pie chart of total abundances 78 | ax = pl.subplot(gspec[0]) 79 | if highlight is None: 80 | highlight = "none" 81 | explode = [ 82 | 0.1 if (layer_key == highlight or layer_key in highlight) else 0 83 | for layer_key in layers_keys 84 | ] 85 | 86 | autopct = "%1.0f%%" if add_labels_pie else None 87 | pie = ax.pie( 88 | np.mean(counts_layers, axis=1), 89 | colors=colors, 90 | explode=explode, 91 | autopct=autopct, 92 | shadow=True, 93 | startangle=45, 94 | ) 95 | if autopct is not None: 96 | for pct, color in zip(pie[-1], colors): 97 | r, g, b, _ = color 98 | pct.set_color("white" if r * g * b < 0.5 else "darkgrey") 99 | pct.set_fontweight("bold") 100 | pct.set_fontsize(fontsize) 101 | ax.legend( 102 | layers_keys, 103 | ncol=len(layers_keys), 104 | bbox_to_anchor=(0, 1), 105 | loc="lower left", 106 | fontsize=fontsize, 107 | ) 108 | 109 | # bar chart of abundances per category 110 | if groupby is not None and groupby in adata.obs.keys(): 111 | counts_groups = dict() 112 | for cluster in adata.obs[groupby].cat.categories: 113 | counts_groups[cluster] = np.mean( 114 | counts_layers[:, adata.obs[groupby] == cluster], axis=1 115 | ) 116 | 117 | labels = list(counts_groups.keys()) 118 | data = np.array(list(counts_groups.values())) 119 | data_cum = data.cumsum(axis=1) 120 | 121 | ax2 = pl.subplot(gspec[1]) 122 | for i, (colname, color) in enumerate(zip(layers_keys, colors)): 123 | starts, widths = data_cum[:, i] - data[:, i], data[:, i] 124 | xpos = starts + widths / 2 125 | curr_xpos = xpos[0] 126 | for i, (x, w) in enumerate(zip(xpos, widths)): 127 | if not (x - w / 2 + 0.05 < curr_xpos < x + w / 2 - 0.05): 128 | curr_xpos = x 129 | xpos[i] = curr_xpos 130 | 131 | ax2.barh( 132 | labels, widths, left=starts, height=0.9, label=colname, color=color 133 | ) 134 | 135 | if add_labels_bar: 136 | r, g, b, _ = color 137 | text_color = "white" if r * g * b < 0.5 else "darkgrey" 138 | for y, (x, c) in enumerate(zip(xpos, widths)): 139 | ax2.text( 140 | x, 141 | y, 142 | f"{(c * 100):.0f}%", 143 | ha="center", 144 | va="center", 145 | color=text_color, 146 | fontsize=fontsize, 147 | fontweight="bold", 148 | ) 149 | 150 | ax2.legend( 151 | ncol=len(layers_keys), 152 | bbox_to_anchor=(0, 1), 153 | loc="lower left", 154 | fontsize=fontsize, 155 | ) 156 | ax2.invert_yaxis() 157 | ax2.set_xlim(0, np.nansum(data, axis=1).max()) 158 | ax2.margins(0) 159 | 160 | ax2.set_xlabel("proportions", fontweight="bold", fontsize=fontsize * 1.2) 161 | ax2.set_ylabel(groupby, fontweight="bold", fontsize=fontsize * 1.2) 162 | ax2.tick_params(axis="both", which="major", labelsize=fontsize) 163 | ax = [ax, ax2] 164 | 165 | savefig_or_show("proportions", dpi=dpi, save=save, show=show) 166 | if show is False: 
167 | return ax 168 | -------------------------------------------------------------------------------- /TFvelo/plotting/pseudotime.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import matplotlib.pyplot as pl 4 | from matplotlib.ticker import MaxNLocator 5 | 6 | 7 | def principal_curve(adata): 8 | X_curve = adata.uns["principal_curve"]["projections"] 9 | ixsort = adata.uns["principal_curve"]["ixsort"] 10 | pl.plot(X_curve[ixsort, 0], X_curve[ixsort, 1], c="k", lw=3, zorder=2000000) 11 | 12 | 13 | def pseudotime(adata, gene_list, ckey="velocity", reverse=False): 14 | ixsort = adata.uns["principal_curve"]["ixsort"] 15 | arclength = adata.uns["principal_curve"]["arclength"] 16 | if reverse: 17 | arclength /= np.max(arclength) 18 | else: 19 | arclength = (np.max(arclength) - arclength) / np.max(arclength) 20 | cell_subset = adata.uns["principal_curve"]["cell_subset"] 21 | 22 | adata_subset = adata[cell_subset].copy() 23 | 24 | gs = pl.GridSpec(1, len(gene_list)) 25 | for n, gene in enumerate(gene_list): 26 | i = adata_subset.var_names.get_loc(gene) 27 | ax = pl.subplot(gs[n]) 28 | 29 | lb, ub = np.percentile(adata_subset.obsm[ckey][:, i], [0.5, 99.5]) 30 | c = np.clip(adata_subset.obsm[ckey][ixsort, i], lb, ub) 31 | # pl.scatter(arclength[ixsort], adata2.obsm['Mu'][ixsort, i], label="unspliced") 32 | pl.scatter( 33 | arclength[ixsort], 34 | adata_subset.obsm["Ms"][ixsort, i] 35 | * adata_subset.uns["velocity_pars"]["gamma"][i], 36 | c=c, 37 | cmap="coolwarm", 38 | alpha=1, 39 | s=1, 40 | label="spliced", 41 | ) 42 | 43 | c /= abs(c).max() 44 | c *= ( 45 | adata_subset.obsm["Ms"][ixsort, i] 46 | * adata_subset.uns["velocity_pars"]["gamma"][i] 47 | ).max() 48 | 49 | z = np.ma.polyfit(arclength[ixsort], c, 8) 50 | fun = np.poly1d(z) 51 | pl.plot(arclength[ixsort], fun(arclength[ixsort]), label=ckey) 52 | 53 | # Hide the right and top spines 54 | ax.spines["right"].set_visible(False) 55 | ax.spines["top"].set_visible(False) 56 | ax.xaxis.set_major_locator(MaxNLocator(nbins=3)) 57 | ax.yaxis.set_major_locator(MaxNLocator(nbins=3)) 58 | 59 | pl.ylabel(gene) 60 | pl.title(f"Colored by {ckey}") 61 | -------------------------------------------------------------------------------- /TFvelo/plotting/simulation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import matplotlib.pyplot as pl 4 | from matplotlib import rcParams 5 | 6 | from ..core import SplicingDynamics 7 | from ..tools.dynamical_model_utils import get_vars 8 | from .utils import make_dense 9 | 10 | 11 | def get_dynamics(adata, key="fit", extrapolate=False, sorted=False, t=None): 12 | alpha, beta, gamma, weight, scaling, t_ = get_vars(adata, key=key) 13 | if extrapolate: 14 | u0_ = unspliced(t_, 0, alpha, beta) # NOTE: `unspliced`, `omega_inv` and `vectorize` are legacy scvelo helpers that are not imported in this module 15 | tmax = t_ + omega_inv(u0_ * 1e-4, u0=u0_, alpha=0, beta=beta) 16 | t = np.concatenate( 17 | [np.linspace(0, t_, num=500), t_ + np.linspace(0, tmax, num=500)] 18 | ) 19 | elif t is None or t is True: 20 | t = adata.obs[f"{key}_t"].values if key == "true" else adata.layers[f"{key}_t"] 21 | 22 | omega, alpha, u0, s0 = vectorize(np.sort(t) if sorted else t, t_, alpha, beta, gamma, weight) 23 | ut, st = SplicingDynamics( 24 | alpha=alpha, beta=beta, gamma=gamma, weight=weight, initial_state=[u0, s0] 25 | ).get_solution(omega) 26 | return alpha, ut, st 27 | 28 | 29 | def compute_dynamics( 30 | adata, basis, key="true" 31 | ): 32 | idx = adata.var_names.get_loc(basis) if isinstance(basis, str) else
basis 33 | key = "fit" if f"{key}_gamma" not in adata.var_keys() else key 34 | alpha, beta, omega, theta, gamma, delta = get_vars(adata[:, basis], key=key) 35 | 36 | #omega, alpha, u0, s0 = vectorize(np.sort(t) if sort else t, t_, alpha, beta, gamma, weight) 37 | num = np.clip(int(adata.X.shape[0] / 5), 1000, 2000) 38 | tpoints = np.linspace(0, 1, num=num) 39 | WX_t, y_t = SplicingDynamics(alpha=alpha, beta=beta, omega=omega, theta=theta, 40 | gamma=gamma).get_solution(tpoints, stacked=False) 41 | return alpha, WX_t, y_t, tpoints 42 | 43 | 44 | def show_full_dynamics( 45 | adata, 46 | basis, 47 | key="true", 48 | use_raw=False, 49 | linewidth=1, 50 | linecolor=None, 51 | show_assignments=None, 52 | ax=None, 53 | show_full_dynamic=False, 54 | ): 55 | if ax is None: 56 | ax = pl.gca() 57 | color = linecolor if linecolor else "grey" if key == "true" else "purple" 58 | linewidth = 0.5 * linewidth if key == "true" else linewidth 59 | label = "learned dynamics" if key == "fit" else "true dynamics" 60 | line = None 61 | 62 | if key != "true": 63 | _, WX_t, y_t, t_points = compute_dynamics( 64 | adata, basis, key 65 | ) 66 | if not show_full_dynamic: 67 | data = adata[:,basis].layers['fit_t_raw'].reshape(-1) 68 | if 0 not in data: 69 | data = np.insert(data, 0, 0) 70 | if 1 not in data: 71 | data = np.insert(data, len(data), 1) 72 | sorted_data = np.sort(data) 73 | intervals = np.diff(sorted_data) 74 | blank_start_id = np.argmax(intervals) 75 | if (blank_start_id==0) or (blank_start_id==len(intervals)-1): 76 | nonblank_idx = (t_points>sorted_data[1]) & (t_points<sorted_data[-2]) 77 | mid_idx = int((nonblank_idx).sum()/2) 78 | else: 79 | nonblank_idx = (t_points<sorted_data[blank_start_id]) | (t_points>sorted_data[blank_start_id+1]) 80 | mid_idx = int((nonblank_idx).sum()/2) - (t_points>sorted_data[blank_start_id+1]).sum() 81 | WX_t = WX_t[nonblank_idx] 82 | y_t = y_t[nonblank_idx] 83 | mid_dt_idx = min(mid_idx+20, len(WX_t)-1) 84 | ax.quiver(y_t[mid_idx], WX_t[mid_idx], y_t[mid_dt_idx]-y_t[mid_idx], WX_t[mid_dt_idx]-WX_t[mid_idx], 85 | scale=0.2, width=0.05, color='purple', angles='xy', scale_units='xy') 86 | 87 | if not isinstance(show_assignments, str) or show_assignments != "only": 88 | ax.scatter(y_t, WX_t, color=color, s=1) 89 | if show_assignments is not None and show_assignments is not False: 90 | WX_key, y_key = ( 91 | "WX", "y" 92 | ) 93 | WX, y = ( 94 | make_dense(adata[:, basis].layers[WX_key]).flatten(), 95 | make_dense(adata[:, basis].layers[y_key]).flatten(), 96 | ) 97 | ax.plot( 98 | np.array([y, y_t]), 99 | np.array([WX, WX_t]), 100 | color="grey", 101 | linewidth=0.1 * linewidth, 102 | ) 103 | 104 | if not isinstance(show_assignments, str) or show_assignments != "only": 105 | _, WX_t, y_t, _ = compute_dynamics( 106 | adata, basis, key 107 | ) 108 | (line,) = ax.plot(y_t, WX_t, color=color, linewidth=linewidth, label=label) 109 | 110 | idx = adata.var_names.get_loc(basis) 111 | gamma = adata.var[f"{key}_gamma"][idx] 112 | #xnew = np.linspace(np.min(y_t), np.max(y_t)) 113 | #ynew = np.linspace(np.min(WX_t), np.max(WX_t))#gamma / weight * (xnew - np.min(xnew)) + np.min(ut) 114 | #ax.plot(xnew, ynew, color=color, linestyle="--", linewidth=linewidth) 115 | return line, label 116 | 117 | 118 | def simulation( 119 | adata, 120 | var_names="all", 121 | legend_loc="upper right", 122 | legend_fontsize=20, 123 | linewidth=None, 124 | dpi=None, 125 | xkey="true_t", 126 | ykey=None, 127 | colors=None, 128 | **kwargs, 129 | ): 130 | from ..tools.utils import make_dense 131 | from .scatter import scatter 132 | 133 | if ykey is None: 134 | ykey = ["unspliced", "spliced", "alpha"] 135 | if colors is None: 136 | colors = ["darkblue",
"darkgreen", "grey"] 137 | var_names = ( 138 | adata.var_names 139 | if isinstance(var_names, str) and var_names == "all" 140 | else [name for name in var_names if name in adata.var_names] 141 | ) 142 | 143 | figsize = rcParams["figure.figsize"] 144 | ncols = len(var_names) 145 | for i, gs in enumerate( 146 | pl.GridSpec( 147 | 1, ncols, pl.figure(None, (figsize[0] * ncols, figsize[1]), dpi=dpi) 148 | ) 149 | ): 150 | idx = adata.var_names.get_loc(var_names[i]) 151 | alpha, ut, st = compute_dynamics(adata, idx) 152 | t = ( 153 | adata.obs[xkey] 154 | if xkey in adata.obs.keys() 155 | else make_dense(adata.layers["fit_t"][:, idx]) 156 | ) 157 | idx_sorted = np.argsort(t) 158 | t = t[idx_sorted] 159 | 160 | ax = pl.subplot(gs) 161 | _kwargs = {"alpha": 0.3, "title": "", "xlabel": "time", "ylabel": "counts"} 162 | _kwargs.update(kwargs) 163 | linewidth = 1 if linewidth is None else linewidth 164 | 165 | ykey = [ykey] if isinstance(ykey, str) else ykey 166 | for j, key in enumerate(ykey): 167 | if key in adata.layers: 168 | y = make_dense(adata.layers[key][:, idx])[idx_sorted] 169 | ax = scatter(x=t, y=y, color=colors[j], ax=ax, show=False, **_kwargs) 170 | 171 | if key == "unspliced": 172 | ax.plot(t, ut, label="unspliced", color=colors[j], linewidth=linewidth) 173 | elif key == "spliced": 174 | ax.plot(t, st, label="spliced", color=colors[j], linewidth=linewidth) 175 | elif key == "alpha": 176 | largs = dict(linewidth=linewidth, linestyle="--") 177 | ax.plot(t, alpha, label="alpha", color=colors[j], **largs) 178 | 179 | pl.xlim(0) 180 | pl.ylim(0) 181 | if legend_loc != "none" and i == ncols - 1: 182 | pl.legend(loc=legend_loc, fontsize=legend_fontsize) 183 | -------------------------------------------------------------------------------- /TFvelo/plotting/velocity.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from scipy.sparse import issparse 4 | 5 | import matplotlib.pyplot as pl 6 | from matplotlib import rcParams 7 | 8 | from ..preprocessing.moments import second_order_moments 9 | from ..tools.rank_velocity_genes import rank_velocity_genes 10 | from .scatter import scatter 11 | from .utils import ( 12 | default_basis, 13 | default_size, 14 | get_basis, 15 | get_figure_params, 16 | savefig_or_show, 17 | ) 18 | 19 | 20 | def velocity( 21 | adata, 22 | var_names=None, 23 | basis=None, 24 | vkey="velocity", 25 | mode=None, 26 | fits=None, 27 | layers="all", 28 | color=None, 29 | color_map=None, 30 | colorbar=True, 31 | perc=[2, 98], 32 | alpha=0.5, 33 | size=None, 34 | groupby=None, 35 | groups=None, 36 | legend_loc="none", 37 | legend_fontsize=8, 38 | use_raw=False, 39 | fontsize=None, 40 | figsize=None, 41 | dpi=None, 42 | show=None, 43 | save=None, 44 | ax=None, 45 | ncols=None, 46 | **kwargs, 47 | ): 48 | """Phase and velocity plot for set of genes. 49 | 50 | The phase plot shows spliced against unspliced expressions with steady-state fit. 51 | Further the embedding is shown colored by velocity and expression. 52 | 53 | Arguments 54 | --------- 55 | adata: :class:`~anndata.AnnData` 56 | Annotated data matrix. 57 | var_names: `str` or list of `str` (default: `None`) 58 | Which variables to show. 59 | basis: `str` (default: `'umap'`) 60 | Key for embedding coordinates. 61 | mode: `'stochastic'` or `None` (default: `None`) 62 | Whether to show show covariability phase portrait. 63 | fits: `str` or list of `str` (default: `['velocity', 'dynamics']`) 64 | Which steady-state estimates to show. 
65 | layers: `str` or list of `str` (default: `'all'`) 66 | Which layers to show. 67 | color: `str`, list of `str` or `None` (default: `None`) 68 | Key for annotations of observations/cells or variables/genes. 69 | color_map: `str` or tuple (default: `['RdYlGn', 'gnuplot_r']`) 70 | String denoting matplotlib color map. If tuple is given, first and latter 71 | color map correspond to velocity and expression, respectively. 72 | perc: tuple, e.g. [2,98] (default: `[2,98]`) 73 | Specify percentile for continuous coloring. 74 | groups: `str`, `list` (default: `None`) 75 | Subset of groups, e.g. [‘g1’, ‘g2’], to which the plot shall be restricted. 76 | groupby: `str`, `list` or `np.ndarray` (default: `None`) 77 | Key of observations grouping to consider. 78 | legend_loc: str (default: 'none') 79 | Location of legend, either 'on data', 'right margin' 80 | or valid keywords for matplotlib.legend. 81 | size: `float` (default: 5) 82 | Point size. 83 | alpha: `float` (default: 1) 84 | Set blending - 0 transparent to 1 opaque. 85 | fontsize: `float` (default: `None`) 86 | Label font size. 87 | figsize: tuple (default: `(7,5)`) 88 | Figure size. 89 | dpi: `int` (default: 80) 90 | Figure dpi. 91 | show: `bool`, optional (default: `None`) 92 | Show the plot, do not return axis. 93 | save: `bool` or `str`, optional (default: `None`) 94 | If `True` or a `str`, save the figure. A string is appended to the default 95 | filename. Infer the filetype if ending on {'.pdf', '.png', '.svg'}. 96 | ax: `matplotlib.Axes`, optional (default: `None`) 97 | A matplotlib axes object. Only works if plotting a single component. 98 | ncols: `int` or `None` (default: `None`) 99 | Number of columns to arrange multiplots into. 100 | 101 | """ 102 | basis = default_basis(adata) if basis is None else get_basis(adata, basis) 103 | color, color_map = kwargs.pop("c", color), kwargs.pop("cmap", color_map) 104 | if fits is None: 105 | fits = ["velocity", "dynamics"] 106 | if color_map is None: 107 | color_map = ["RdYlGn", "gnuplot_r"] 108 | 109 | if isinstance(groupby, str) and groupby in adata.obs.keys(): 110 | if ( 111 | "rank_velocity_genes" not in adata.uns.keys() 112 | or adata.uns["rank_velocity_genes"]["params"]["groupby"] != groupby 113 | ): 114 | rank_velocity_genes(adata, vkey=vkey, n_genes=10, groupby=groupby) 115 | names = np.array(adata.uns["rank_velocity_genes"]["names"].tolist()) 116 | if groups is None: 117 | var_names = names[:, 0] 118 | else: 119 | groups = [groups] if isinstance(groups, str) else groups 120 | categories = adata.obs[groupby].cat.categories 121 | idx = np.array([any([g in group for g in groups]) for group in categories]) 122 | var_names = np.hstack(names[idx, : int(10 / idx.sum())]) 123 | elif var_names is not None: 124 | if isinstance(var_names, str): 125 | var_names = [var_names] 126 | else: 127 | var_names = [var for var in var_names if var in adata.var_names] 128 | else: 129 | raise ValueError("No var_names or groups specified.") 130 | var_names = pd.unique(var_names) 131 | 132 | if use_raw or "M_total" not in adata.layers.keys(): 133 | skey = 'total' 134 | #skey, ukey = "spliced", "unspliced" 135 | else: 136 | skey = 'M_total' 137 | #skey, ukey = "Ms", "Mu" 138 | layers = [vkey, skey] if layers == "all" else layers 139 | layers = [layer for layer in layers if layer in adata.layers.keys() or layer == "X"] 140 | 141 | fits = list(adata.layers.keys()) if fits == "all" else fits 142 | fits = [fit for fit in fits if f"{fit}_gamma" in adata.var.keys()] + ["dynamics"] 143 | stochastic_fits = [fit for fit
in fits if f"variance_{fit}" in adata.layers.keys()] 144 | 145 | nplts = 1 + len(layers) + (mode == "stochastic") * 2 146 | ncols = 1 if ncols is None else ncols 147 | nrows = int(np.ceil(len(var_names) / ncols)) 148 | ncols = int(ncols * nplts) 149 | figsize = rcParams["figure.figsize"] if figsize is None else figsize 150 | figsize, dpi = get_figure_params(figsize, dpi, ncols / 2) 151 | if ax is None: 152 | gs_figsize = (figsize[0] * ncols / 2, figsize[1] * nrows / 2) 153 | ax = pl.figure(figsize=gs_figsize, dpi=dpi) 154 | gs = pl.GridSpec(nrows, ncols, wspace=0.5, hspace=0.8) 155 | 156 | # half size, since fontsize is halved in width and height 157 | size = default_size(adata) / 2 if size is None else size 158 | fontsize = rcParams["font.size"] * 0.8 if fontsize is None else fontsize 159 | 160 | scatter_kwargs = dict(colorbar=colorbar, perc=perc, size=size, use_raw=use_raw) 161 | scatter_kwargs.update(dict(fontsize=fontsize, legend_fontsize=legend_fontsize)) 162 | 163 | for v, var in enumerate(var_names): 164 | _adata = adata[:, var] 165 | s = _adata.layers[skey] 166 | #s, u = _adata.layers[skey], _adata.layers[ukey] 167 | if issparse(s): 168 | #s, u = s.A, u.A 169 | s = s.A 170 | 171 | # spliced/unspliced phase portrait with steady-state estimate 172 | ax = pl.subplot(gs[v * nplts]) 173 | cmap = color_map 174 | if isinstance(color_map, (list, tuple)): 175 | cmap = color_map[-1] if color in ["X", skey] else color_map[0] 176 | if "xlabel" not in kwargs: 177 | kwargs["xlabel"] = "Target" 178 | if "ylabel" not in kwargs: 179 | kwargs["ylabel"] = "TFs" 180 | legend_loc_lines = "none" if v < len(var_names) - 1 else legend_loc 181 | #phase portrait plot 182 | scatter( 183 | adata, 184 | basis=var, 185 | color=color, 186 | color_map=cmap, 187 | frameon=True, 188 | title=var, 189 | alpha=alpha, 190 | vkey=fits, 191 | show=False, 192 | ax=ax, 193 | save=False, 194 | legend_loc_lines=legend_loc_lines, 195 | **scatter_kwargs, 196 | **kwargs, 197 | ) 198 | 199 | # velocity and expression plots 200 | for layer_id, layer in enumerate(layers): 201 | ax = pl.subplot(gs[v * nplts + layer_id + 1]) 202 | title = "expression" if layer in ["X", "y", skey] else layer 203 | # _kwargs = {} if title == 'expression' else kwargs 204 | cmap = color_map 205 | if isinstance(color_map, (list, tuple)): 206 | cmap = color_map[-1] if layer in ["X", skey] else color_map[0] 207 | scatter( 208 | adata, 209 | basis=basis, 210 | color=var, 211 | layer=layer, 212 | title=title, 213 | color_map=cmap, 214 | alpha=alpha, 215 | frameon=False, 216 | show=False, 217 | ax=ax, 218 | save=False, 219 | **scatter_kwargs, 220 | **kwargs, 221 | ) 222 | 223 | if mode == "stochastic": 224 | ss, us = second_order_moments(_adata) 225 | s, u, ss, us = s.flatten(), u.flatten(), ss.flatten(), us.flatten() # NOTE: `u` is undefined here since the unspliced handling above was removed, so mode="stochastic" will raise 226 | fit = stochastic_fits[0] 227 | 228 | ax = pl.subplot(gs[v * nplts + len(layers) + 1]) 229 | beta, offset = 1, 0 230 | if f"{fit}_beta" in adata.var.keys(): 231 | beta = _adata.var[f"{fit}_beta"] 232 | if f"{fit}_offset" in adata.var.keys(): 233 | offset = _adata.var[f"{fit}_offset"] 234 | x = np.array(2 * (ss - s ** 2) - s) 235 | y = np.array(2 * (us - u * s) + u + 2 * s * offset / beta) 236 | kwargs["xlabel"] = r"2 $\Sigma_s - \langle s \rangle$" 237 | kwargs["ylabel"] = r"2 $\Sigma_{us} + \langle u \rangle$" 238 | scatter( 239 | adata, 240 | x=x, 241 | y=y, 242 | color=color, 243 | title=var, 244 | frameon=True, 245 | ax=ax, 246 | save=False, 247 | show=False, 248 | **scatter_kwargs, 249 | **kwargs, 250 | ) 251 | 252 | xnew =
np.linspace(np.min(x), np.max(x) * 1.02) 253 | for fit in stochastic_fits: 254 | gamma, beta, offset2 = 1, 1, 0 255 | if f"{fit}_gamma" in adata.var.keys(): 256 | gamma = _adata.var[f"{fit}_gamma"].values 257 | if f"{fit}_beta" in adata.var.keys(): 258 | beta = _adata.var[f"{fit}_beta"].values 259 | if f"{fit}_offset2" in adata.var.keys(): 260 | offset2 = _adata.var[f"{fit}_offset2"].values 261 | ynew = gamma / beta * xnew + offset2 / beta 262 | pl.plot(xnew, ynew, c="k", linestyle="--") 263 | 264 | savefig_or_show(dpi=dpi, save=save, show=show) 265 | if show is False: 266 | return ax 267 | -------------------------------------------------------------------------------- /TFvelo/plotting/velocity_embedding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import matplotlib.pyplot as pl 4 | from matplotlib import rcParams 5 | from matplotlib.colors import is_color_like 6 | 7 | from ..tools.utils import groups_to_bool 8 | from ..tools.velocity_embedding import ( 9 | velocity_embedding as compute_velocity_embedding, 10 | ) 11 | from .docs import doc_params, doc_scatter 12 | from .scatter import scatter 13 | from .utils import ( 14 | default_arrow, 15 | default_basis, 16 | default_color, 17 | default_color_map, 18 | default_size, 19 | get_ax, 20 | get_components, 21 | get_figure_params, 22 | interpret_colorkey, 23 | make_unique_list, 24 | make_unique_valid_list, 25 | savefig_or_show, 26 | velocity_embedding_changed, 27 | ) 28 | 29 | 30 | @doc_params(scatter=doc_scatter) 31 | def velocity_embedding( 32 | adata, 33 | basis=None, 34 | vkey="velocity", 35 | density=None, 36 | arrow_size=None, 37 | arrow_length=None, 38 | scale=None, 39 | X=None, 40 | V=None, 41 | recompute=None, 42 | color=None, 43 | use_raw=None, 44 | layer=None, 45 | color_map=None, 46 | colorbar=True, 47 | palette=None, 48 | size=None, 49 | alpha=0.2, 50 | perc=None, 51 | sort_order=True, 52 | groups=None, 53 | components=None, 54 | projection="2d", 55 | legend_loc="none", 56 | legend_fontsize=None, 57 | legend_fontweight=None, 58 | xlabel=None, 59 | ylabel=None, 60 | title=None, 61 | fontsize=None, 62 | figsize=None, 63 | dpi=None, 64 | frameon=None, 65 | show=None, 66 | save=None, 67 | ax=None, 68 | ncols=None, 69 | **kwargs, 70 | ): 71 | """\ 72 | Scatter plot of velocities on the embedding. 73 | 74 | Arguments 75 | --------- 76 | adata: :class:`~anndata.AnnData` 77 | Annotated data matrix. 78 | density: `float` (default: 1) 79 | Amount of velocities to show - 0 none to 1 all 80 | arrow_size: `float` or triple `headlength, headwidth, headaxislength` (default: 1) 81 | Size of arrows. 82 | arrow_length: `float` (default: 1) 83 | Length of arrows. 84 | scale: `float` (default: 1) 85 | Length of velocities in the embedding. 
86 | {scatter} 87 | 88 | Returns 89 | ------- 90 | `matplotlib.Axis` if `show==False` 91 | """ 92 | 93 | if vkey == "all": 94 | lkeys = list(adata.layers.keys()) 95 | vkey = [key for key in lkeys if "velocity" in key and "_u" not in key] 96 | color, color_map = kwargs.pop("c", color), kwargs.pop("cmap", color_map) 97 | layers, vkeys = make_unique_list(layer), make_unique_list(vkey) 98 | colors = make_unique_list(color, allow_array=True) 99 | bases = make_unique_valid_list(adata, basis) 100 | bases = [default_basis(adata, **kwargs) if b is None else b for b in bases] 101 | 102 | if V is None: 103 | for key in vkeys: 104 | for bas in bases: 105 | if recompute or velocity_embedding_changed(adata, basis=bas, vkey=key): 106 | compute_velocity_embedding(adata, basis=bas, vkey=key) 107 | 108 | scatter_kwargs = { 109 | "perc": perc, 110 | "use_raw": use_raw, 111 | "sort_order": sort_order, 112 | "alpha": alpha, 113 | "components": components, 114 | "projection": projection, 115 | "legend_loc": legend_loc, 116 | "groups": groups, 117 | "legend_fontsize": legend_fontsize, 118 | "legend_fontweight": legend_fontweight, 119 | "palette": palette, 120 | "color_map": color_map, 121 | "frameon": frameon, 122 | "xlabel": xlabel, 123 | "ylabel": ylabel, 124 | "colorbar": colorbar, 125 | "dpi": dpi, 126 | "fontsize": fontsize, 127 | "show": False, 128 | "save": False, 129 | } 130 | 131 | multikey = ( 132 | colors 133 | if len(colors) > 1 134 | else layers 135 | if len(layers) > 1 136 | else vkeys 137 | if len(vkeys) > 1 138 | else bases 139 | if len(bases) > 1 140 | else None 141 | ) 142 | if multikey is not None: 143 | if title is None: 144 | title = list(multikey) 145 | elif isinstance(title, (list, tuple)): 146 | title *= int(np.ceil(len(multikey) / len(title))) 147 | ncols = len(multikey) if ncols is None else min(len(multikey), ncols) 148 | nrows = int(np.ceil(len(multikey) / ncols)) 149 | figsize = rcParams["figure.figsize"] if figsize is None else figsize 150 | figsize, dpi = get_figure_params(figsize, dpi, ncols) 151 | gs_figsize = (figsize[0] * ncols, figsize[1] * nrows) 152 | ax = [] 153 | for i, gs in enumerate( 154 | pl.GridSpec(nrows, ncols, pl.figure(None, gs_figsize, dpi=dpi)) 155 | ): 156 | if i < len(multikey): 157 | ax.append( 158 | velocity_embedding( 159 | adata, 160 | density=density, 161 | scale=scale, 162 | size=size, 163 | ax=pl.subplot(gs), 164 | arrow_size=arrow_size, 165 | arrow_length=arrow_length, 166 | basis=bases[i] if len(bases) > 1 else basis, 167 | color=colors[i] if len(colors) > 1 else color, 168 | layer=layers[i] if len(layers) > 1 else layer, 169 | vkey=vkeys[i] if len(vkeys) > 1 else vkey, 170 | title=title[i] if isinstance(title, (list, tuple)) else title, 171 | **scatter_kwargs, 172 | **kwargs, 173 | ) 174 | ) 175 | savefig_or_show(dpi=dpi, save=save, show=show) 176 | if show is False: 177 | return ax 178 | 179 | else: 180 | ax, show = get_ax(ax, show, figsize, dpi, projection) 181 | 182 | color, layer, vkey, basis = colors[0], layers[0], vkeys[0], bases[0] 183 | color = default_color(adata) if color is None else color 184 | color_map = default_color_map(adata, color) if color_map is None else color_map 185 | size = default_size(adata) / 2 if size is None else size 186 | if use_raw is None and "Ms" not in adata.layers.keys(): 187 | use_raw = True 188 | _adata = ( 189 | adata[groups_to_bool(adata, groups, groupby=color)] 190 | if groups is not None and color in adata.obs.keys() 191 | else adata 192 | ) 193 | 194 | quiver_kwargs = { 195 | "scale": scale, 196 | "cmap": 
color_map, 197 | "angles": "xy", 198 | "scale_units": "xy", 199 | "edgecolors": "k", 200 | "linewidth": 0.1, 201 | "width": None, 202 | } 203 | if basis in adata.var_names: 204 | if use_raw: 205 | x = adata[:, basis].layers["spliced"] 206 | y = adata[:, basis].layers["unspliced"] 207 | else: 208 | x = adata[:, basis].layers["Ms"] 209 | y = adata[:, basis].layers["Mu"] 210 | dx = adata[:, basis].layers[vkey] 211 | dy = np.zeros(adata.n_obs) 212 | if f"{vkey}_u" in adata.layers.keys(): 213 | dy = adata[:, basis].layers[f"{vkey}_u"] 214 | X = np.stack([np.ravel(x), np.ravel(y)]).T 215 | V = np.stack([np.ravel(dx), np.ravel(dy)]).T 216 | else: 217 | x = None if X is None else X[:, 0] 218 | y = None if X is None else X[:, 1] 219 | comps = get_components(components, basis, projection) 220 | X = _adata.obsm[f"X_{basis}"][:, comps] if X is None else X[:, :2] 221 | V = _adata.obsm[f"{vkey}_{basis}"][:, comps] if V is None else V[:, :2] 222 | 223 | hl, hw, hal = default_arrow(arrow_size) 224 | if arrow_length is not None: 225 | scale = 1 / arrow_length 226 | if scale is None: 227 | scale = 1 228 | quiver_kwargs.update({"scale": scale, "width": 0.0005, "headlength": hl}) 229 | quiver_kwargs.update({"headwidth": hw, "headaxislength": hal}) 230 | 231 | for arg in list(kwargs): 232 | if arg in quiver_kwargs: 233 | quiver_kwargs.update({arg: kwargs[arg]}) 234 | else: 235 | scatter_kwargs.update({arg: kwargs[arg]}) 236 | 237 | if ( 238 | basis in adata.var_names 239 | and isinstance(color, str) 240 | and color in adata.layers.keys() 241 | ): 242 | c = interpret_colorkey(_adata, basis, color, perc) 243 | else: 244 | c = interpret_colorkey(_adata, color, layer, perc) 245 | 246 | if density is not None and 0 < density < 1: 247 | s = int(density * _adata.n_obs) 248 | ix_choice = np.random.choice(_adata.n_obs, size=s, replace=False) 249 | c = c[ix_choice] if len(c) == _adata.n_obs else c 250 | X = X[ix_choice] 251 | V = V[ix_choice] 252 | 253 | if projection == "3d" and X.shape[1] > 2 and V.shape[1] > 2: 254 | V, size = V / scale / 5, size / 10 255 | x0, x1, x2 = X[:, 0], X[:, 1], X[:, 2] 256 | v0, v1, v2 = V[:, 0], V[:, 1], V[:, 2] 257 | quiver3d_kwargs = {"zorder": 3, "linewidth": 0.5, "arrow_length_ratio": 0.3} 258 | c = list(c) + [element for element in list(c) for _ in range(2)] 259 | if is_color_like(c[0]): 260 | ax.quiver(x0, x1, x2, v0, v1, v2, color=c, **quiver3d_kwargs) 261 | else: 262 | ax.quiver(x0, x1, x2, v0, v1, v2, c, **quiver3d_kwargs) 263 | else: 264 | quiver_kwargs.update({"zorder": 3}) 265 | if is_color_like(c[0]): 266 | ax.quiver(X[:, 0], X[:, 1], V[:, 0], V[:, 1], color=c, **quiver_kwargs) 267 | else: 268 | ax.quiver(X[:, 0], X[:, 1], V[:, 0], V[:, 1], c, **quiver_kwargs) 269 | 270 | scatter_kwargs.update({"basis": basis, "x": x, "y": y, "color": color}) 271 | scatter_kwargs.update({"vkey": vkey, "layer": layer}) 272 | ax = scatter(adata, size=size, title=title, ax=ax, zorder=0, **scatter_kwargs) 273 | 274 | savefig_or_show(dpi=dpi, save=save, show=show) 275 | if show is False: 276 | return ax 277 | -------------------------------------------------------------------------------- /TFvelo/pp.py: -------------------------------------------------------------------------------- 1 | from .preprocessing import * # noqa 2 | -------------------------------------------------------------------------------- /TFvelo/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from .moments import moments 2 | from .neighbors import neighbors, pca, 
remove_duplicate_cells 3 | 4 | from .utils import ( 5 | get_TFs, 6 | cleanup, 7 | filter_and_normalize, 8 | filter_genes, 9 | filter_genes_dispersion, 10 | log1p, 11 | normalize_per_cell, 12 | recipe_velocity, 13 | show_proportions, 14 | ) 15 | 16 | 17 | __all__ = [ 18 | "get_TFs", 19 | "cleanup", 20 | "filter_and_normalize", 21 | "filter_genes", 22 | "filter_genes_dispersion", 23 | "log1p", 24 | "moments", 25 | "neighbors", 26 | "normalize_per_cell", 27 | "pca", 28 | "recipe_velocity", 29 | "remove_duplicate_cells", 30 | "show_proportions", 31 | ] 32 | 33 | -------------------------------------------------------------------------------- /TFvelo/preprocessing/moments.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.sparse import csr_matrix, issparse 3 | 4 | from .. import logging as logg 5 | from .. import settings 6 | from .neighbors import get_connectivities, get_n_neighs, neighbors, verify_neighbors 7 | from .utils import normalize_per_cell, not_yet_normalized 8 | 9 | 10 | def moments( 11 | data, 12 | n_neighbors=30, 13 | n_pcs=None, 14 | mode="connectivities", 15 | method="umap", 16 | use_rep=None, 17 | use_highly_variable=True, 18 | copy=False, 19 | ): 20 | """Computes moments for velocity estimation. 21 | 22 | First-order moments are computed for each cell across its nearest neighbors, 23 | where the neighbor graph is obtained from euclidean distances in PCA space. 24 | 25 | Arguments 26 | --------- 27 | data: :class:`~anndata.AnnData` 28 | Annotated data matrix. 29 | n_neighbors: `int` (default: 30) 30 | Number of neighbors to use. 31 | n_pcs: `int` (default: None) 32 | Number of principal components to use. 33 | If not specified, the full space of a pre-computed PCA is used, 34 | or 30 components are used when PCA is computed internally. 35 | mode: `'connectivities'` or `'distances'` (default: `'connectivities'`) 36 | Distance metric to use for moment computation. 37 | method : {{'umap', 'hnsw', 'sklearn', `None`}} (default: `'umap'`) 38 | Method to compute neighbors, only differs in runtime. 39 | Connectivities are computed with adaptive kernel width as proposed in 40 | Haghverdi et al. 2016 (https://doi.org/10.1038/nmeth.3971). 41 | use_rep : `None`, `'X'` or any key for `.obsm` (default: None) 42 | Use the indicated representation. If `None`, the representation is chosen 43 | automatically: for .n_vars < 50, .X is used, otherwise 'X_pca' is used. 44 | use_highly_variable: `bool` (default: True) 45 | Whether to use highly variable genes only, stored in .var['highly_variable']. 46 | copy: `bool` (default: `False`) 47 | Return a copy instead of writing to adata. 48 | 49 | Returns 50 | ------- 51 | M_total: `.layers` 52 | dense matrix with first order moments of total counts. 53 | Note that only the `total` layer is smoothed here; spliced/unspliced 54 | moments (`Ms`/`Mu`) are not computed, since TFvelo models total 55 | transcript abundances. 56 |
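A minimal usage sketch (assuming `adata` carries a normalized `total` layer, prepared as in `TFvelo_run_demo.py`):

.. code:: python

    import TFvelo as TFv

    TFv.pp.moments(adata, n_pcs=30, n_neighbors=30)
    # smoothed totals are now available in adata.layers['M_total']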
57 | """ 58 | 59 | adata = data.copy() if copy else data 60 | 61 | layers = [layer for layer in {"total"} if layer in adata.layers] 62 | if any([not_yet_normalized(adata.layers[layer]) for layer in layers]): 63 | normalize_per_cell(adata) 64 | 65 | if n_neighbors is not None and n_neighbors > get_n_neighs(adata): 66 | neighbors( 67 | adata, 68 | n_neighbors=n_neighbors, 69 | use_rep=use_rep, 70 | use_highly_variable=use_highly_variable, 71 | n_pcs=n_pcs, 72 | method=method, 73 | ) 74 | verify_neighbors(adata) 75 | 76 | if "total" not in adata.layers.keys(): # or "unspliced" not in adata.layers.keys(): 77 | logg.warn("Skipping moments, because total counts were not found.") 78 | else: 79 | logg.info(f"computing moments based on {mode}", r=True) 80 | connectivities = get_connectivities( 81 | adata, mode, n_neighbors=n_neighbors, recurse_neighbors=False 82 | ) 83 | 84 | adata.layers["M_total"] = ( 85 | csr_matrix.dot(connectivities, csr_matrix(adata.layers["total"])) 86 | .astype(np.float32) 87 | .A 88 | ) 89 | # if renormalize: normalize_per_cell(adata, layers={'Ms', 'Mu'}, enforce=True) 90 | 91 | logg.info( 92 | " finished", time=True, end=" " if settings.verbosity > 2 else "\n" 93 | ) 94 | logg.hint( 95 | "added \n" 96 | " 'M_total', moments of total abundances (adata.layers)" 97 | ) 98 | return adata if copy else None 99 | 100 | 101 | def second_order_moments(adata, adjusted=False): 102 | """Computes second order moments for stochastic velocity estimation. 103 | 104 | Arguments 105 | --------- 106 | adata: `AnnData` 107 | Annotated data matrix. 108 | 109 | Returns 110 | ------- 111 | Mss: Second order moments for spliced abundances 112 | Mus: Second order moments for spliced with unspliced abundances 113 | """ 114 | 115 | if "neighbors" not in adata.uns: 116 | raise ValueError( 117 | "You need to run `pp.neighbors` first to compute a neighborhood graph." 118 | ) 119 | 120 | connectivities = get_connectivities(adata) 121 | s, u = csr_matrix(adata.layers["spliced"]), csr_matrix(adata.layers["unspliced"]) 122 | if s.shape[0] == 1: 123 | s, u = s.T, u.T 124 | Mss = csr_matrix.dot(connectivities, s.multiply(s)).astype(np.float32).A 125 | Mus = csr_matrix.dot(connectivities, s.multiply(u)).astype(np.float32).A 126 | if adjusted: 127 | Mss = 2 * Mss - adata.layers["Ms"].reshape(Mss.shape) 128 | Mus = 2 * Mus - adata.layers["Mu"].reshape(Mus.shape) 129 | return Mss, Mus 130 | 131 | 132 | def second_order_moments_u(adata): 133 | """Computes second order moments for stochastic velocity estimation. 134 | 135 | Arguments 136 | --------- 137 | adata: `AnnData` 138 | Annotated data matrix. 139 | 140 | Returns 141 | ------- 142 | Muu: Second order moments for unspliced abundances 143 | """ 144 | 145 | if "neighbors" not in adata.uns: 146 | raise ValueError( 147 | "You need to run `pp.neighbors` first to compute a neighborhood graph." 148 | ) 149 | 150 | connectivities = get_connectivities(adata) 151 | u = csr_matrix(adata.layers["unspliced"]) 152 | Muu = csr_matrix.dot(connectivities, u.multiply(u)).astype(np.float32).A 153 | 154 | return Muu 155 | 156 | 157 | def magic_impute(adata, knn=5, t=2, verbose=0, **kwargs): 158 | logg.info( 159 | "To be used carefully. Magic has not yet been tested for this application." 
160 | ) 161 | import magic 162 | 163 | magic_operator = magic.MAGIC(verbose=verbose, knn=knn, t=t, **kwargs) 164 | adata.layers["Ms"] = magic_operator.fit_transform(adata.layers["spliced"]) 165 | adata.layers["Mu"] = magic_operator.transform(adata.layers["unspliced"]) 166 | 167 | 168 | def get_moments( 169 | adata, layer=None, second_order=None, centered=True, mode="connectivities" 170 | ): 171 | """Computes moments for a specified layer. 172 | 173 | First and second order moments. 174 | If centered, that corresponds to means and variances across nearest neighbors. 175 | 176 | Arguments 177 | --------- 178 | adata: `AnnData` 179 | Annotated data matrix. 180 | layer: `str` (default: `None`) 181 | Key of layer with abundances to consider for moment computation. 182 | second_order: `bool` (default: `None`) 183 | Whether to compute second order moments from abundances. 184 | centered: `bool` (default: `True`) 185 | Whether to compute centered (=variance) or uncentered second order moments. 186 | mode: `'connectivities'` or `'distances'` (default: `'connectivities'`) 187 | Distance metric to use for moment computation. 188 | 189 | Returns 190 | ------- 191 | Mx: first or second order moments 192 | """ 193 | 194 | if "neighbors" not in adata.uns: 195 | raise ValueError( 196 | "You need to run `pp.neighbors` first to compute a neighborhood graph." 197 | ) 198 | connectivities = get_connectivities(adata, mode=mode) 199 | X = ( 200 | adata.X 201 | if layer is None 202 | else adata.layers[layer] 203 | if isinstance(layer, str) 204 | else layer 205 | ) 206 | X = ( 207 | csr_matrix(X) 208 | if isinstance(layer, str) and layer in {"spliced", "unspliced"} 209 | else np.array(X) 210 | if not issparse(X) 211 | else X 212 | ) 213 | if not issparse(X): 214 | X = X[:, ~np.isnan(X.sum(0))] 215 | if second_order: 216 | X2 = X.multiply(X) if issparse(X) else X ** 2 217 | Mx = ( 218 | csr_matrix.dot(connectivities, X2) 219 | if second_order 220 | else csr_matrix.dot(connectivities, X) 221 | ) 222 | if centered: 223 | mu = csr_matrix.dot(connectivities, X) 224 | mu2 = mu.multiply(mu) if issparse(mu) else mu ** 2 225 | Mx = Mx - mu2 226 | else: 227 | Mx = csr_matrix.dot(connectivities, X) 228 | if issparse(X): 229 | Mx = Mx.astype(np.float32).A 230 | return Mx 231 | -------------------------------------------------------------------------------- /TFvelo/read_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | from pathlib import Path 4 | from urllib.request import urlretrieve 5 | 6 | import numpy as np 7 | import pandas as pd 8 | 9 | from .core import clean_obs_names as _clean_obs_names 10 | from .core import get_df as _get_df 11 | from .core import merge as _merge 12 | from .core._anndata import obs_df as _obs_df 13 | from .core._anndata import var_df as _var_df 14 | 15 | 16 | def load(filename, backup_url=None, header="infer", index_col="infer", **kwargs): 17 | """Load a csv, txt, tsv or npy file.""" 18 | numpy_ext = {"npy", "npz"} 19 | pandas_ext = {"csv", "txt", "tsv"} 20 | 21 | if not os.path.exists(filename) and backup_url is None: 22 | raise FileNotFoundError(f"Did not find file {filename}.") 23 | 24 | elif not os.path.exists(filename): 25 | d = os.path.dirname(filename) 26 | if not os.path.exists(d): 27 | os.makedirs(d) 28 | urlretrieve(backup_url, filename) 29 | 30 | ext = Path(filename).suffixes[-1][1:] 31 | 32 | if ext in numpy_ext: 33 | return np.load(filename, **kwargs) 34 | 35 | elif ext in pandas_ext: 36 | df = pd.read_csv( 
37 | filename, 38 | header=header, 39 | index_col=None if index_col == "infer" else index_col, 40 | **kwargs, 41 | ) 42 | if index_col == "infer" and len(df.columns) > 1: 43 | is_int_index = all(np.arange(0, len(df)) == df.iloc[:, 0]) 44 | is_str_index = isinstance(df.iloc[0, 0], str) and all( 45 | [not isinstance(d, str) for d in df.iloc[0, 1:]] 46 | ) 47 | if is_int_index or is_str_index: 48 | df.set_index(df.columns[0], inplace=True) 49 | return df 50 | 51 | else: 52 | raise ValueError( 53 | f"'{filename}' does not end on a valid extension.\n" 54 | "Please provide one of the available extensions.\n" 55 | f"{numpy_ext | pandas_ext}\n" 56 | ) 57 | 58 | 59 | read_csv = load 60 | 61 | 62 | def clean_obs_names(data, base="[AGTCBDHKMNRSVWY]", ID_length=12, copy=False): 63 | warnings.warn( 64 | "`scvelo.read_load.clean_obs_names` is deprecated since scVelo v0.2.4 and will " 65 | "be removed in a future version. Please use `scvelo.core.clean_obs_names` " 66 | "instead.", 67 | DeprecationWarning, 68 | stacklevel=2, 69 | ) 70 | 71 | return _clean_obs_names(data=data, base=base, ID_length=ID_length, copy=copy) 72 | 73 | 74 | def merge(adata, ldata, copy=True): 75 | warnings.warn( 76 | "`scvelo.read_load.merge` is deprecated since scVelo v0.2.4 and will be " 77 | "removed in a future version. Please use `scvelo.core.merge` instead.", 78 | DeprecationWarning, 79 | stacklevel=2, 80 | ) 81 | 82 | return _merge(adata=adata, ldata=ldata, copy=copy)  # pass the caller's copy flag through 83 | 84 | 85 | def obs_df(adata, keys, layer=None): 86 | warnings.warn( 87 | "`scvelo.read_load.obs_df` is deprecated since scVelo v0.2.4 and will be " 88 | "removed in a future version. Please use `scvelo.core._anndata.obs_df` " 89 | "instead.", 90 | DeprecationWarning, 91 | stacklevel=2, 92 | ) 93 | 94 | return _obs_df(adata=adata, keys=keys, layer=layer) 95 | 96 | 97 | def var_df(adata, keys, layer=None): 98 | warnings.warn( 99 | "`scvelo.read_load.var_df` is deprecated since scVelo v0.2.4 and will be " 100 | "removed in a future version. Please use `scvelo.core._anndata.var_df` " 101 | "instead.", 102 | DeprecationWarning, 103 | stacklevel=2, 104 | ) 105 | 106 | return _var_df(adata=adata, keys=keys, layer=layer) 107 | 108 | 109 | def get_df( 110 | data, 111 | keys=None, 112 | layer=None, 113 | index=None, 114 | columns=None, 115 | sort_values=None, 116 | dropna="all", 117 | precision=None, 118 | ): 119 | warnings.warn( 120 | "`scvelo.read_load.get_df` is deprecated since scVelo v0.2.4 and will be " 121 | "removed in a future version.
Please use `scvelo.core.get_df` instead.", 122 | DeprecationWarning, 123 | stacklevel=2, 124 | ) 125 | 126 | return _get_df( 127 | data=data, 128 | keys=keys, 129 | layer=layer, 130 | index=index, 131 | columns=columns, 132 | sort_values=sort_values, 133 | dropna=dropna, 134 | precision=precision, 135 | ) 136 | 137 | 138 | DataFrame = get_df 139 | 140 | 141 | def load_biomart(): 142 | # human genes from https://biomart.genenames.org/martform 143 | # mouse genes from http://www.ensembl.org/biomart/martview 144 | # antibodies from https://www.biolegend.com/en-us/totalseq 145 | nb_url = "https://github.com/theislab/scvelo_notebooks/raw/master/" 146 | 147 | filename = "data/biomart/mart_export_human.txt" 148 | df = load(filename, sep="\t", backup_url=f"{nb_url}{filename}") 149 | df.columns = ["ensembl", "gene name"] 150 | df.index = df.pop("ensembl") 151 | 152 | filename = "data/biomart/mart_export_mouse.txt" 153 | df2 = load(filename, sep="\t", backup_url=f"{nb_url}{filename}") 154 | df2.columns = ["ensembl", "gene name"] 155 | df2.index = df2.pop("ensembl") 156 | 157 | df = pd.concat([df, df2]) 158 | df = df.drop_duplicates() 159 | return df 160 | 161 | 162 | def convert_to_gene_names(ensembl_names=None): 163 | """Retrieve gene names from ensembl IDs.""" 164 | df = load_biomart() 165 | if ensembl_names is not None: 166 | if isinstance(ensembl_names, str): 167 | ensembl_names = [ensembl_names] 168 | valid_names = [name for name in ensembl_names if name in df.index] 169 | if len(valid_names) > 0: 170 | df = df.loc[valid_names] 171 | 172 | gene_names = np.array(ensembl_names) 173 | idx = pd.DataFrame(ensembl_names).isin(df.index).values.flatten() 174 | gene_names[idx] = df["gene name"].values 175 | 176 | df = pd.DataFrame([ensembl_names, gene_names]).T 177 | df.columns = ["ensembl", "gene name"] 178 | df.index = df.pop("ensembl") 179 | return df 180 | 181 | 182 | def convert_to_ensembl(gene_names=None): 183 | """Retrieve ensembl IDs from a list of gene names.""" 184 | df = load_biomart() 185 | if gene_names is not None: 186 | if isinstance(gene_names, str): 187 | gene_names = [gene_names] 188 | valid_names = [name for name in gene_names if name in df["gene name"].tolist()] 189 | if len(valid_names) > 0: 190 | index = [i in valid_names for i in df["gene name"].tolist()] 191 | df = df[index] 192 | 193 | df["ensembl"] = df.index 194 | df = df.set_index("gene name") 195 | return df 196 | 197 | 198 | def gene_info(name, fields="name,symbol,refseq,generif,ensembl"): 199 | """Retrieve gene information from biothings client.""" 200 | try: 201 | from biothings_client import get_client 202 | except ImportError: 203 | raise ImportError( 204 | "Please install Biothings first via `pip install biothings_client`." 
205 | ) 206 | 207 | class MyGeneInfo(get_client("gene", instance=False)): 208 | def __init__(self): 209 | super(MyGeneInfo, self).__init__() 210 | 211 | if not name.startswith("ENS"): 212 | df = convert_to_gene_names() 213 | df.reset_index(inplace=True) 214 | df.set_index("gene name", inplace=True) 215 | if name in df.index: 216 | name = df.loc[name][0] 217 | 218 | info = MyGeneInfo().getgene(name, fields) 219 | return info 220 | -------------------------------------------------------------------------------- /TFvelo/tl.py: -------------------------------------------------------------------------------- 1 | from .tools import * # noqa 2 | -------------------------------------------------------------------------------- /TFvelo/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from scanpy.tools import diffmap, dpt, louvain, tsne, umap 2 | 3 | from .dynamical_model import ( 4 | DynamicsRecovery, 5 | latent_time, 6 | rank_dynamical_genes, 7 | recover_dynamics, 8 | recover_latent_time, 9 | ) 10 | from .paga import paga 11 | from .rank_velocity_genes import rank_velocity_genes, velocity_clusters 12 | from .terminal_states import eigs, terminal_states 13 | from .transition_matrix import transition_matrix 14 | from .velocity_confidence import velocity_confidence, velocity_confidence_transition 15 | from .velocity_embedding import velocity_embedding 16 | from .velocity_graph import velocity_graph 17 | from .velocity_pseudotime import velocity_map, velocity_pseudotime 18 | 19 | __all__ = [ 20 | "diffmap", 21 | "dpt", 22 | "DynamicsRecovery", 23 | "eigs", 24 | "latent_time", 25 | "louvain", 26 | "paga", 27 | "rank_dynamical_genes", 28 | "rank_velocity_genes", 29 | "recover_dynamics", 30 | "recover_latent_time", 31 | "terminal_states", 32 | "transition_matrix", 33 | "tsne", 34 | "umap", 35 | "velocity_clusters", 36 | "velocity_confidence", 37 | "velocity_confidence_transition", 38 | "velocity_embedding", 39 | "velocity_graph", 40 | "velocity_map", 41 | "velocity_pseudotime", 42 | ] 43 | -------------------------------------------------------------------------------- /TFvelo/tools/transition_matrix.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from scipy.sparse import csr_matrix, SparseEfficiencyWarning 6 | from scipy.spatial.distance import pdist, squareform 7 | 8 | from ..preprocessing.neighbors import get_connectivities, get_neighs 9 | from .utils import normalize 10 | 11 | warnings.simplefilter("ignore", SparseEfficiencyWarning) 12 | 13 | 14 | def transition_matrix( 15 | adata, 16 | vkey='velocity',#"velo_hat", 17 | basis=None, 18 | backward=False, 19 | self_transitions=True, 20 | scale=10, 21 | perc=None, 22 | threshold=None, 23 | use_negative_cosines=False, 24 | weight_diffusion=0, 25 | scale_diffusion=1, 26 | weight_indirect_neighbors=None, 27 | n_neighbors=None, 28 | vgraph=None, 29 | basis_constraint=None, 30 | ): 31 | """Computes cell-to-cell transition probabilities 32 | 33 | .. math:: 34 | \\tilde \\pi_{ij} = \\frac1{z_i} \\exp( \\pi_{ij} / \\sigma), 35 | 36 | from the velocity graph :math:`\\pi_{ij}`, with row-normalization :math:`z_i` and 37 | kernel width :math:`\\sigma` (scale parameter :math:`\\lambda = \\sigma^{-1}`). 38 | 39 | Alternatively, use :func:`cellrank.tl.transition_matrix` to account for uncertainty 40 | in the velocity estimates. 
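A minimal usage sketch (assuming the velocity graph has already been computed with `tl.velocity_graph`, as done in the demo scripts):

.. code:: python

    import TFvelo as TFv

    TFv.tl.velocity_graph(adata, vkey='velocity', xkey='M_total')
    T = TFv.utils.get_transition_matrix(adata, vkey='velocity', scale=10)
    # T: sparse, row-normalized matrix of cell-to-cell transition probabilities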
41 | 42 | Arguments 43 | --------- 44 | adata: :class:`~anndata.AnnData` 45 | Annotated data matrix. 46 | vkey: `str` (default: `'velocity'`) 47 | Name of velocity estimates to be used. 48 | basis: `str` or `None` (default: `None`) 49 | Restrict transition to embedding if specified 50 | backward: `bool` (default: `False`) 51 | Whether to use the transition matrix to 52 | push forward (`False`) or to pull backward (`True`) 53 | self_transitions: `bool` (default: `True`) 54 | Allow transitions from one node to itself. 55 | scale: `float` (default: 10) 56 | Scale parameter of gaussian kernel. 57 | perc: `float` between `0` and `100` or `None` (default: `None`) 58 | Determines threshold of transitions to include. 59 | use_negative_cosines: `bool` (default: `False`) 60 | If True, negatively similar transitions are taken into account. 61 | weight_diffusion: `float` (default: 0) 62 | Relative weight to be given to diffusion kernel (Brownian motion) 63 | scale_diffusion: `float` (default: 1) 64 | Scale of diffusion kernel. 65 | weight_indirect_neighbors: `float` between `0` and `1` or `None` (default: `None`) 66 | Weight to be assigned to indirect neighbors (i.e. neighbors of higher degrees). 67 | n_neighbors:`int` (default: None) 68 | Number of nearest neighbors to consider around each cell. 69 | vgraph: csr matrix or `None` (default: `None`) 70 | Velocity graph representation to use instead of adata.uns[f'{vkey}_graph']. 71 | 72 | Returns 73 | ------- 74 | Returns sparse matrix with transition probabilities. 75 | """ 76 | 77 | if f"{vkey}_graph" not in adata.uns: 78 | raise ValueError( 79 | "You need to run `tl.velocity_graph` first to compute cosine correlations." 80 | ) 81 | 82 | graph_neg = None 83 | if vgraph is not None: 84 | graph = vgraph.copy() 85 | else: 86 | if hasattr(adata, "obsp") and f"{vkey}_graph" in adata.obsp.keys(): 87 | graph = csr_matrix(adata.obsp[f"{vkey}_graph"]).copy() 88 | if f"{vkey}_graph_neg" in adata.obsp.keys(): 89 | graph_neg = adata.obsp[f"{vkey}_graph_neg"] 90 | else: 91 | graph = csr_matrix(adata.uns[f"{vkey}_graph"]).copy() 92 | if f"{vkey}_graph_neg" in adata.uns.keys(): 93 | graph_neg = adata.uns[f"{vkey}_graph_neg"] 94 | 95 | if basis_constraint is not None and f"X_{basis_constraint}" in adata.obsm.keys(): 96 | from sklearn.neighbors import NearestNeighbors 97 | 98 | neighs = NearestNeighbors(n_neighbors=100) 99 | neighs.fit(adata.obsm[f"X_{basis_constraint}"]) 100 | basis_graph = neighs.kneighbors_graph(mode="connectivity") > 0 101 | graph = graph.multiply(basis_graph) 102 | 103 | if self_transitions: 104 | confidence = graph.max(1).A.flatten() 105 | ub = np.percentile(confidence, 98) 106 | self_prob = np.clip(ub - confidence, 0, 1) 107 | graph.setdiag(self_prob) 108 | 109 | T = np.expm1(graph * scale) # equivalent to np.exp(graph.A * scale) - 1 110 | if graph_neg is not None: 111 | graph_neg = adata.uns[f"{vkey}_graph_neg"] 112 | if use_negative_cosines: 113 | T -= np.expm1(-graph_neg * scale) 114 | else: 115 | T += np.expm1(graph_neg * scale) 116 | T.data += 1 117 | 118 | # weight direct and indirect (recursed) neighbors 119 | if weight_indirect_neighbors is not None and weight_indirect_neighbors < 1: 120 | direct_neighbors = get_neighs(adata, "distances") > 0 121 | direct_neighbors.setdiag(1) 122 | w = weight_indirect_neighbors 123 | T = w * T + (1 - w) * direct_neighbors.multiply(T) 124 | 125 | if n_neighbors is not None: 126 | T = T.multiply( 127 | get_connectivities( 128 | adata, mode="distances", n_neighbors=n_neighbors, recurse_neighbors=True 129 | 
) 130 | ) 131 | 132 | if perc is not None or threshold is not None: 133 | if threshold is None: 134 | threshold = np.percentile(T.data, perc) 135 | T.data[T.data < threshold] = 0 136 | T.eliminate_zeros() 137 | 138 | if backward: 139 | T = T.T 140 | T = normalize(T) 141 | 142 | if f"X_{basis}" in adata.obsm.keys(): 143 | dists_emb = (T > 0).multiply(squareform(pdist(adata.obsm[f"X_{basis}"]))) 144 | scale_diffusion *= dists_emb.data.mean() 145 | 146 | diffusion_kernel = dists_emb.copy() 147 | diffusion_kernel.data = np.exp( 148 | -0.5 * dists_emb.data ** 2 / scale_diffusion ** 2 149 | ) 150 | T = T.multiply(diffusion_kernel) # combine velocity kernel & diffusion kernel 151 | 152 | if 0 < weight_diffusion < 1: # add diffusion kernel (Brownian motion - like) 153 | diffusion_kernel.data = np.exp( 154 | -0.5 * dists_emb.data ** 2 / (scale_diffusion / 2) ** 2 155 | ) 156 | T = (1 - weight_diffusion) * T + weight_diffusion * diffusion_kernel 157 | 158 | T = normalize(T) 159 | 160 | return T 161 | 162 | 163 | def get_cell_transitions( 164 | adata, 165 | starting_cell=0, 166 | basis=None, 167 | n_steps=100, 168 | n_neighbors=30, 169 | backward=False, 170 | random_state=None, 171 | **kwargs, 172 | ): 173 | """Simulate cell transitions 174 | 175 | Arguments 176 | --------- 177 | adata: :class:`~anndata.AnnData` 178 | Annotated data matrix. 179 | starting_cell: `int` (default: `0`) 180 | Index (`int`) or name (`obs_names) of starting cell. 181 | n_steps: `int` (default: `100`) 182 | Number of transitions/steps to be simulated. 183 | backward: `bool` (default: `False`) 184 | Whether to use the transition matrix to 185 | push forward (`False`) or to pull backward (`True`) 186 | random_state: `int` or `None` (default: `None`) 187 | Set to `int` for reproducibility, otherwise `None` for a random seed. 188 | **kwargs: 189 | To be passed to tl.transition_matrix. 190 | 191 | Returns 192 | ------- 193 | Returns embedding coordinates (if basis is specified), 194 | otherwise return indices of simulated cell transitions. 195 | """ 196 | 197 | np.random.seed(random_state) 198 | if isinstance(starting_cell, str) and starting_cell in adata.obs_names: 199 | starting_cell = adata.obs_names.get_loc(starting_cell) 200 | X = [starting_cell] 201 | T = transition_matrix( 202 | adata, 203 | backward=backward, 204 | basis_constraint=basis, 205 | self_transitions=False, 206 | **kwargs, 207 | ) 208 | for _ in range(n_steps): 209 | t = T[X[-1]] 210 | indices, p = t.indices, t.data 211 | if n_neighbors is not None and n_neighbors < len(p): 212 | idx = np.argsort(t.data)[::-1][:n_neighbors] 213 | indices, p = indices[idx], p[idx] 214 | if len(p) == 0: 215 | indices, p = [X[-1]], [1] 216 | p /= np.sum(p) 217 | ix = np.random.choice(indices, p=p) 218 | X.append(ix) 219 | X = pd.unique(X) 220 | if basis is not None and f"X_{basis}" in adata.obsm.keys(): 221 | X = adata.obsm[f"X_{basis}"][X].T 222 | if backward: 223 | X = np.flip(X, axis=-1) 224 | return X 225 | -------------------------------------------------------------------------------- /TFvelo/tools/velocity_confidence.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .. import logging as logg 4 | from ..core import l2_norm, prod_sum 5 | from ..preprocessing.neighbors import get_neighs 6 | from .transition_matrix import transition_matrix 7 | from .utils import get_indices, random_subsample 8 | 9 | 10 | def velocity_confidence(data, vkey="velocity", copy=False): 11 | """Computes confidences of velocities. 
12 | 13 | .. code:: python 14 | 15 | scv.tl.velocity_confidence(adata) 16 | scv.pl.scatter(adata, color='velocity_confidence', perc=[2,98]) 17 | 18 | .. image:: https://user-images.githubusercontent.com/31883718/69626334-b6df5200-1048-11ea-9171-495845c5bc7a.png 19 | :width: 600px 20 | 21 | 22 | Arguments 23 | --------- 24 | data: :class:`~anndata.AnnData` 25 | Annotated data matrix. 26 | vkey: `str` (default: `'velocity'`) 27 | Name of velocity estimates to be used. 28 | copy: `bool` (default: `False`) 29 | Return a copy instead of writing to adata. 30 | 31 | Returns 32 | ------- 33 | velocity_length: `.obs` 34 | Length of the velocity vectors for each individual cell 35 | velocity_confidence: `.obs` 36 | Confidence for each cell 37 | """ # noqa E501 38 | 39 | adata = data.copy() if copy else data 40 | if vkey not in adata.layers.keys(): 41 | raise ValueError("You need to run `tl.velocity` first.") 42 | 43 | V = np.array(adata.layers[vkey]) 44 | 45 | tmp_filter = np.invert(np.isnan(np.sum(V, axis=0))) 46 | if f"{vkey}_genes" in adata.var.keys(): 47 | tmp_filter &= np.array(adata.var[f"{vkey}_genes"], dtype=bool) 48 | if "spearmans_score" in adata.var.keys(): 49 | tmp_filter &= adata.var["spearmans_score"].values > 0.1 50 | 51 | V = V[:, tmp_filter] 52 | 53 | V -= V.mean(1)[:, None] 54 | V_norm = l2_norm(V, axis=1) 55 | R = np.zeros(adata.n_obs) 56 | 57 | indices = get_indices(dist=get_neighs(adata, "distances"))[0] 58 | for i in range(adata.n_obs): 59 | Vi_neighs = V[indices[i]] 60 | Vi_neighs -= Vi_neighs.mean(1)[:, None] 61 | R[i] = np.mean( 62 | np.einsum("ij, j", Vi_neighs, V[i]) 63 | / (l2_norm(Vi_neighs, axis=1) * V_norm[i])[None, :] 64 | ) 65 | 66 | adata.obs[f"{vkey}_length"] = V_norm.round(2) 67 | adata.obs[f"{vkey}_confidence"] = np.clip(R, 0, None) 68 | 69 | logg.hint(f"added '{vkey}_length' (adata.obs)") 70 | logg.hint(f"added '{vkey}_confidence' (adata.obs)") 71 | 72 | if f"{vkey}_confidence_transition" not in adata.obs.keys(): 73 | velocity_confidence_transition(adata, vkey) 74 | 75 | return adata if copy else None 76 | 77 | 78 | def velocity_confidence_transition(data, vkey="velocity", scale=10, copy=False): 79 | """Computes confidences of velocity transitions. 80 | 81 | Arguments 82 | --------- 83 | data: :class:`~anndata.AnnData` 84 | Annotated data matrix. 85 | vkey: `str` (default: `'velocity'`) 86 | Name of velocity estimates to be used. 87 | scale: `float` (default: 10) 88 | Scale parameter of gaussian kernel. 89 | copy: `bool` (default: `False`) 90 | Return a copy instead of writing to adata. 
91 | 92 | Returns 93 | ------- 94 | velocity_confidence_transition: `.obs` 95 | Confidence of transition for each cell. 96 | """ 97 | 98 | adata = data.copy() if copy else data 99 | if vkey not in adata.layers.keys(): 100 | raise ValueError("You need to run `tl.velocity` first.") 101 | 102 | X = np.array(adata.layers["M_total"]) 103 | V = np.array(adata.layers[vkey]) 104 | 105 | tmp_filter = np.invert(np.isnan(np.sum(V, axis=0))) 106 | if f"{vkey}_genes" in adata.var.keys(): 107 | tmp_filter &= np.array(adata.var[f"{vkey}_genes"], dtype=bool) 108 | if "spearmans_score" in adata.var.keys(): 109 | tmp_filter &= adata.var["spearmans_score"].values > 0.1 110 | 111 | V = V[:, tmp_filter] 112 | X = X[:, tmp_filter] 113 | 114 | T = transition_matrix(adata, vkey=vkey, scale=scale) 115 | dX = T.dot(X) - X 116 | dX -= dX.mean(1)[:, None] 117 | V -= V.mean(1)[:, None] 118 | 119 | norms = l2_norm(dX, axis=1) * l2_norm(V, axis=1) 120 | norms += norms == 0 121 | 122 | adata.obs[f"{vkey}_confidence_transition"] = prod_sum(dX, V, axis=1) / norms 123 | 124 | logg.hint(f"added '{vkey}_confidence_transition' (adata.obs)") 125 | 126 | return adata if copy else None 127 | 128 | 129 | def score_robustness( 130 | data, adata_subset=None, fraction=0.5, vkey="velocity", copy=False 131 | ): 132 | adata = data.copy() if copy else data 133 | 134 | if adata_subset is None: 135 | from scvelo.preprocessing.moments import moments 136 | from scvelo.preprocessing.neighbors import neighbors 137 | from .velocity import velocity 138 | 139 | logg.switch_verbosity("off") 140 | adata_subset = adata.copy() 141 | subset = random_subsample(adata_subset, fraction=fraction, return_subset=True) 142 | neighbors(adata_subset) 143 | moments(adata_subset) 144 | velocity(adata_subset, vkey=vkey) 145 | logg.switch_verbosity("on") 146 | else: 147 | subset = adata.obs_names.isin(adata_subset.obs_names) 148 | 149 | V = adata[subset].layers[vkey] 150 | V_subset = adata_subset.layers[vkey] 151 | 152 | score = np.full(adata.n_obs, np.nan)  # per-cell scores, NaN outside the subset 153 | score[subset] = prod_sum(V, V_subset, axis=1) / ( 154 | l2_norm(V, axis=1) * l2_norm(V_subset, axis=1) 155 | ) 156 | adata.obs[f"{vkey}_score_robustness"] = score 157 | 158 | return adata_subset if copy else None 159 | -------------------------------------------------------------------------------- /TFvelo/tools/velocity_embedding.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | from scipy.sparse import issparse 5 | 6 | from .. import logging as logg 7 | from .. import settings 8 | from ..core import l2_norm 9 | from .transition_matrix import transition_matrix 10 | 11 | 12 | def quiver_autoscale(X_emb, V_emb): 13 | import matplotlib.pyplot as pl 14 | 15 | scale_factor = np.abs(X_emb).max() # just so that it handles very large values 16 | fig, ax = pl.subplots() 17 | Q = ax.quiver( 18 | X_emb[:, 0] / scale_factor, 19 | X_emb[:, 1] / scale_factor, 20 | V_emb[:, 0], 21 | V_emb[:, 1], 22 | angles="xy", 23 | scale_units="xy", 24 | scale=None, 25 | ) 26 | Q._init() 27 | fig.clf() 28 | pl.close(fig) 29 | return Q.scale / scale_factor 30 | 31 | 32 | def velocity_embedding( 33 | data, 34 | basis=None, 35 | vkey="velocity", 36 | scale=10, 37 | self_transitions=True, 38 | use_negative_cosines=True, 39 | direct_pca_projection=None, 40 | retain_scale=False, 41 | autoscale=True, 42 | all_comps=True, 43 | T=None, 44 | copy=False, 45 | ): 46 | """Projects the single cell velocities into any embedding.
47 | 48 | Given the normalized difference of the embedding positions, 49 | :math:`\\tilde \\delta_{ij} = \\frac{x_j-x_i}{\\left\\lVert x_j-x_i \\right\\rVert}`, 50 | the projections are obtained as expected displacements with respect to 51 | the transition matrix :math:`\\tilde \\pi_{ij}` as 52 | 53 | .. math:: 54 | \\tilde \\nu_i = E_{\\tilde \\pi_{i\\cdot}} [\\tilde \\delta_{i \\cdot}] 55 | = \\sum_{j \\neq i} \\left( \\tilde \\pi_{ij} - \\frac1n \\right) 56 | \\tilde \\delta_{ij}. 57 | 58 | 59 | 60 | Arguments 61 | --------- 62 | data: :class:`~anndata.AnnData` 63 | Annotated data matrix. 64 | basis: `str` or `None` (default: `None`) 65 | Which embedding to use; if `None`, an existing pca/tsne/umap embedding is chosen automatically. 66 | vkey: `str` (default: `'velocity'`) 67 | Name of velocity estimates to be used. 68 | scale: `int` (default: 10) 69 | Scale parameter of gaussian kernel for transition matrix. 70 | self_transitions: `bool` (default: `True`) 71 | Whether to allow self transitions, based on the confidences of transitioning to 72 | neighboring cells. 73 | use_negative_cosines: `bool` (default: `True`) 74 | Whether to project cell-to-cell transitions with negative cosines into 75 | negative/opposite direction. 76 | direct_pca_projection: `bool` (default: `None`) 77 | Whether to directly project the velocities into PCA space, 78 | thus skipping the velocity graph. 79 | retain_scale: `bool` (default: `False`) 80 | Whether to retain scale from high dimensional space in embedding. 81 | autoscale: `bool` (default: `True`) 82 | Whether to scale the embedded velocities by a scalar multiplier, 83 | which simply ensures that the arrows in the embedding are properly scaled. 84 | all_comps: `bool` (default: `True`) 85 | Whether to compute the velocities on all embedding components. 86 | T: `csr_matrix` (default: `None`) 87 | Allows the user to directly pass a transition matrix. 88 | copy: `bool` (default: `False`) 89 | Return a copy instead of writing to `adata`.
90 | 91 | Returns 92 | ------- 93 | velocity_umap: `.obsm` 94 | coordinates of velocity projection on embedding (e.g., basis='umap') 95 | """ 96 | 97 | adata = data.copy() if copy else data 98 | 99 | if basis is None: 100 | keys = [ 101 | key for key in ["pca", "tsne", "umap"] if f"X_{key}" in adata.obsm.keys() 102 | ] 103 | if len(keys) > 0: 104 | basis = "pca" if direct_pca_projection else keys[-1] 105 | else: 106 | raise ValueError("No basis specified") 107 | 108 | if f"X_{basis}" not in adata.obsm_keys(): 109 | raise ValueError("You need to compute the embedding first.") 110 | 111 | if direct_pca_projection and "pca" in basis: 112 | logg.warn( 113 | "Directly projecting velocities into PCA space is for exploratory analysis " 114 | "on principal components.\n" 115 | " It does not reflect the actual velocity field from high " 116 | "dimensional gene expression space.\n" 117 | " To visualize velocities, consider applying " 118 | "`direct_pca_projection=False`.\n" 119 | ) 120 | 121 | logg.info("computing velocity embedding", r=True) 122 | 123 | V = np.array(adata.layers[vkey]) 124 | vgenes = np.ones(adata.n_vars, dtype=bool) 125 | if f"{vkey}_genes" in adata.var.keys(): 126 | vgenes &= np.array(adata.var[f"{vkey}_genes"], dtype=bool) 127 | vgenes &= ~np.isnan(V.sum(0)) 128 | V = V[:, vgenes] 129 | 130 | if direct_pca_projection and "pca" in basis: 131 | PCs = adata.varm["PCs"] if all_comps else adata.varm["PCs"][:, :2] 132 | PCs = PCs[vgenes] 133 | 134 | X_emb = adata.obsm[f"X_{basis}"] 135 | V_emb = (V - V.mean(0)).dot(PCs) 136 | 137 | else: 138 | X_emb = ( 139 | adata.obsm[f"X_{basis}"] if all_comps else adata.obsm[f"X_{basis}"][:, :2] 140 | ) 141 | V_emb = np.zeros(X_emb.shape) 142 | 143 | T = ( 144 | transition_matrix( 145 | adata, 146 | vkey=vkey, 147 | scale=scale, 148 | self_transitions=self_transitions, 149 | use_negative_cosines=use_negative_cosines, 150 | ) 151 | if T is None 152 | else T 153 | ) 154 | T.setdiag(0) 155 | T.eliminate_zeros() 156 | 157 | densify = adata.n_obs < 1e4 158 | TA = T.A if densify else None 159 | 160 | with warnings.catch_warnings(): 161 | warnings.simplefilter("ignore") 162 | for i in range(adata.n_obs): 163 | indices = T[i].indices 164 | dX = X_emb[indices] - X_emb[i, None] # shape (n_neighbors, 2) 165 | if not retain_scale: 166 | dX /= l2_norm(dX)[:, None] 167 | dX[np.isnan(dX)] = 0 # zero diff in a steady-state 168 | probs = TA[i, indices] if densify else T[i].data 169 | V_emb[i] = probs.dot(dX) - probs.mean() * dX.sum(0) 170 | 171 | if retain_scale: 172 | X = ( 173 | adata.layers["Ms"] 174 | if "Ms" in adata.layers.keys() 175 | else adata.layers["spliced"] 176 | ) 177 | delta = T.dot(X[:, vgenes]) - X[:, vgenes] 178 | if issparse(delta): 179 | delta = delta.A 180 | cos_proj = (V * delta).sum(1) / l2_norm(delta) 181 | V_emb *= np.clip(cos_proj[:, None] * 10, 0, 1) 182 | 183 | if autoscale: 184 | V_emb /= 3 * quiver_autoscale(X_emb, V_emb) 185 | 186 | if f"{vkey}_params" in adata.uns.keys(): 187 | adata.uns[f"{vkey}_params"]["embeddings"] = ( 188 | [] 189 | if "embeddings" not in adata.uns[f"{vkey}_params"] 190 | else list(adata.uns[f"{vkey}_params"]["embeddings"]) 191 | ) 192 | adata.uns[f"{vkey}_params"]["embeddings"].extend([basis]) 193 | 194 | vkey += f"_{basis}" 195 | adata.obsm[vkey] = V_emb 196 | 197 | logg.info(" finished", time=True, end=" " if settings.verbosity > 2 else "\n") 198 | logg.hint("added\n" f" '{vkey}', embedded velocity vectors (adata.obsm)") 199 | 200 | return adata if copy else None 201 | 
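# --- Usage sketch (illustrative; assumes an AnnData with a 'velocity' layer,
# a velocity graph, and an existing UMAP, e.g. the h5ad files written by the
# demo scripts in this repository) ---
if __name__ == "__main__":  # hypothetical demo guard, not part of the module
    import anndata as ad
    import TFvelo as TFv

    adata = ad.read_h5ad("TFvelo_pancreas_demo/rc.h5ad")
    TFv.tl.velocity_graph(adata, basis=None, vkey="velocity", xkey="M_total")
    TFv.tl.velocity_embedding(adata, basis="umap", vkey="velocity")
    # the projected vectors are now stored in adata.obsm["velocity_umap"]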
-------------------------------------------------------------------------------- /TFvelo/tools/velocity_pseudotime.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.sparse import issparse, linalg, spdiags 3 | 4 | from scanpy.tools._dpt import DPT 5 | 6 | from .. import logging as logg 7 | from ..preprocessing.moments import get_connectivities 8 | from .terminal_states import terminal_states 9 | from .utils import groups_to_bool, scale, strings_to_categoricals 10 | 11 | 12 | def principal_curve(data, basis="pca", n_comps=4, clusters_list=None, copy=False): 13 | """Computes the principal curve. 14 | Arguments 15 | --------- 16 | data: :class:`~anndata.AnnData` 17 | Annotated data matrix. 18 | basis: `str` (default: `'pca'`) 19 | Basis to use for computing the principal curve. 20 | n_comps: `int` (default: 4) 21 | Number of principal components to be used. 22 | copy: `bool` (default: `False`) 23 | Return a copy instead of writing to adata. 24 | 25 | Returns 26 | ------- 27 | principal_curve: `.uns` 28 | dictionary containing `projections`, `ixsort` and `arclength` 29 | """ 30 | 31 | adata = data.copy() if copy else data 32 | import rpy2.robjects as robjects 33 | from rpy2.robjects.packages import importr 34 | 35 | if clusters_list is not None: 36 | cell_subset = np.array( 37 | [label in clusters_list for label in adata.obs["clusters"]] 38 | ) 39 | X_emb = adata[cell_subset].obsm[f"X_{basis}"][:, :n_comps] 40 | else: 41 | cell_subset = None 42 | X_emb = adata.obsm[f"X_{basis}"][:, :n_comps] 43 | 44 | n_obs, n_dim = X_emb.shape 45 | 46 | # convert array to R matrix 47 | xvec = robjects.FloatVector(X_emb.T.reshape((X_emb.size))) 48 | X_R = robjects.r.matrix(xvec, nrow=n_obs, ncol=n_dim) 49 | 50 | fit = importr("princurve").principal_curve(X_R) 51 | 52 | adata.uns["principal_curve"] = dict() 53 | adata.uns["principal_curve"]["ixsort"] = ixsort = np.array(fit[1]) - 1 54 | adata.uns["principal_curve"]["projections"] = np.array(fit[0])[ixsort] 55 | adata.uns["principal_curve"]["arclength"] = np.array(fit[2]) 56 | adata.uns["principal_curve"]["cell_subset"] = cell_subset 57 | 58 | return adata if copy else None 59 | 60 | 61 | def velocity_map(adata=None, T=None, n_dcs=10, return_model=False): 62 | vpt = VPT(adata, n_dcs=n_dcs) 63 | if T is None: 64 | T = adata.uns["velocity_graph"] - adata.uns["velocity_graph_neg"] 65 | vpt._connectivities = T + T.T 66 | vpt.compute_transitions() 67 | vpt.compute_eigen(n_dcs) 68 | adata.obsm["X_vmap"] = vpt.eigen_basis 69 | return vpt if return_model else None 70 | 71 | 72 | class VPT(DPT): 73 | def set_iroot(self, root=None): 74 | if ( 75 | isinstance(root, str) 76 | and root in self._adata.obs.keys() 77 | and self._adata.obs[root].max() != 0 78 | ): 79 | self.iroot = get_connectivities(self._adata).dot(self._adata.obs[root]) 80 | self.iroot = scale(self.iroot).argmax() 81 | elif isinstance(root, str) and root in self._adata.obs_names: 82 | self.iroot = self._adata.obs_names.get_loc(root) 83 | elif isinstance(root, (int, np.integer)) and root < self._adata.n_obs: 84 | self.iroot = root 85 | else: 86 | self.iroot = None 87 | 88 | def compute_transitions(self, density_normalize=True): 89 | T = self._connectivities 90 | if density_normalize: 91 | q = np.asarray(T.sum(axis=0)) 92 | q += q == 0 93 | Q = ( 94 | spdiags(1.0 / q, 0, T.shape[0], T.shape[0]) 95 | if issparse(T) 96 | else np.diag(1.0 / q) 97 | ) 98 | K = Q.dot(T).dot(Q) 99 | else: 100 | K = T 101 | z = np.sqrt(np.asarray(K.sum(axis=0)))
102 | Z = ( 103 | spdiags(1.0 / z, 0, K.shape[0], K.shape[0]) 104 | if issparse(K) 105 | else np.diag(1.0 / z) 106 | ) 107 | self._transitions_sym = Z.dot(K).dot(Z) 108 | 109 | def compute_eigen(self, n_comps=10, sym=None, sort="decrease"): 110 | if self._transitions_sym is None: 111 | raise ValueError("Run `.compute_transitions` first.") 112 | n_comps = min(self._transitions_sym.shape[0] - 1, n_comps) 113 | evals, evecs = linalg.eigsh(self._transitions_sym, k=n_comps, which="LM") 114 | self._eigen_values = evals[::-1] 115 | self._eigen_basis = evecs[:, ::-1] 116 | 117 | def compute_pseudotime(self, inverse=False): 118 | if self.iroot is not None: 119 | self._set_pseudotime() 120 | self.pseudotime = 1 - self.pseudotime if inverse else self.pseudotime 121 | self.pseudotime[~np.isfinite(self.pseudotime)] = np.nan 122 | else: 123 | self.pseudotime = np.empty(self._adata.n_obs) 124 | self.pseudotime[:] = np.nan 125 | 126 | 127 | def velocity_pseudotime( 128 | adata, 129 | vkey="velocity", 130 | modality='M_total', 131 | groupby=None, 132 | groups=None, 133 | root_key=None, 134 | end_key=None, 135 | n_dcs=10, 136 | use_velocity_graph=True, 137 | save_diffmap=None, 138 | return_model=None, 139 | **kwargs, 140 | ): 141 | """Computes a pseudotime based on the velocity graph. 142 | 143 | Velocity pseudotime is a random-walk based distance measures on the velocity graph. 144 | After computing a distribution over root cells obtained from the velocity-inferred 145 | transition matrix, it measures the average number of steps it takes to reach a cell 146 | after start walking from one of the root cells. Contrarily to diffusion pseudotime, 147 | it implicitly infers the root cells and is based on the directed velocity graph 148 | instead of the similarity-based diffusion kernel. 149 | 150 | .. code:: python 151 | 152 | scv.tl.velocity_pseudotime(adata) 153 | scv.pl.scatter(adata, color='velocity_pseudotime', color_map='gnuplot') 154 | 155 | .. image:: https://user-images.githubusercontent.com/31883718/69545487-33fbc000-0f92-11ea-969b-194dc68400b0.png 156 | :width: 600px 157 | 158 | Arguments 159 | --------- 160 | adata: :class:`~anndata.AnnData` 161 | Annotated data matrix 162 | vkey: `str` (default: `'velocity'`) 163 | Name of velocity estimates to be used. 164 | groupby: `str`, `list` or `np.ndarray` (default: `None`) 165 | Key of observations grouping to consider. 166 | groups: `str`, `list` or `np.ndarray` (default: `None`) 167 | Groups selected to find terminal states on. Must be an element of 168 | adata.obs[groupby]. Only to be set, if each group is assumed to have a distinct 169 | lineage with an independent root and end point. 170 | root_key: `int` (default: `None`) 171 | Index of root cell to be used. 172 | Computed from velocity-inferred transition matrix if not specified. 173 | end_key: `int` (default: `None`) 174 | Index of end point to be used. 175 | Computed from velocity-inferred transition matrix if not specified. 176 | n_dcs: `int` (default: 10) 177 | The number of diffusion components to use. 178 | use_velocity_graph: `bool` (default: `True`) 179 | Whether to use the velocity graph. 180 | If False, it uses the similarity-based diffusion kernel. 181 | save_diffmap: `bool` (default: `None`) 182 | Whether to store diffmap coordinates. 183 | return_model: `bool` (default: `None`) 184 | Whether to return the vpt object for further inspection. 185 | **kwargs: 186 | Further arguments to pass to VPT (e.g. min_group_size, allow_kendall_tau_shift). 
187 | 188 | Returns 189 | ------- 190 | velocity_pseudotime: `.obs` 191 | Velocity pseudotime obtained from velocity graph. 192 | """ # noqa E501 193 | 194 | strings_to_categoricals(adata) 195 | if root_key is None and "root_cells" in adata.obs.keys(): 196 | root0 = adata.obs["root_cells"][0] 197 | if not np.isnan(root0) and not isinstance(root0, str): 198 | root_key = "root_cells" 199 | if end_key is None and "end_points" in adata.obs.keys(): 200 | end0 = adata.obs["end_points"][0] 201 | if not np.isnan(end0) and not isinstance(end0, str): 202 | end_key = "end_points" 203 | 204 | groupby = ( 205 | "cell_fate" if groupby is None and "cell_fate" in adata.obs.keys() else groupby 206 | ) 207 | if groupby is not None: 208 | logg.warn( 209 | "Only set groupby, when you have evident distinct clusters/lineages," 210 | " each with an own root and end point." 211 | ) 212 | categories = ( 213 | adata.obs[groupby].cat.categories 214 | if groupby is not None and groups is None 215 | else [None] 216 | ) 217 | for cat in categories: 218 | groups = cat if cat is not None else groups 219 | if ( 220 | root_key is None 221 | or root_key in adata.obs.keys() 222 | and np.max(adata.obs[root_key]) == np.min(adata.obs[root_key]) 223 | ): 224 | terminal_states(adata, vkey=vkey, groupby=groupby, groups=groups, modality=modality) 225 | root_key, end_key = "root_cells", "end_points" 226 | cell_subset = groups_to_bool(adata, groups=groups, groupby=groupby) 227 | data = adata.copy() if cell_subset is None else adata[cell_subset].copy() 228 | if "allow_kendall_tau_shift" not in kwargs: 229 | kwargs["allow_kendall_tau_shift"] = True 230 | vpt = VPT(data, n_dcs=n_dcs, **kwargs) 231 | 232 | if use_velocity_graph: 233 | T = data.uns[f"{vkey}_graph"] - data.uns[f"{vkey}_graph_neg"] 234 | vpt._connectivities = T + T.T 235 | 236 | vpt.compute_transitions() 237 | vpt.compute_eigen(n_comps=n_dcs) 238 | 239 | vpt.set_iroot(root_key) 240 | vpt.compute_pseudotime() 241 | dpt_root = vpt.pseudotime 242 | 243 | if end_key is not None: 244 | vpt.set_iroot(end_key) 245 | vpt.compute_pseudotime(inverse=True) 246 | dpt_end = vpt.pseudotime 247 | 248 | # merge dpt_root and inverse dpt_end together 249 | vpt.pseudotime = np.nan_to_num(dpt_root) + np.nan_to_num(dpt_end) 250 | vpt.pseudotime[np.isfinite(dpt_root) & np.isfinite(dpt_end)] /= 2 251 | vpt.pseudotime = scale(vpt.pseudotime) 252 | vpt.pseudotime[np.isnan(dpt_root) & np.isnan(dpt_end)] = np.nan 253 | 254 | if "n_branchings" in kwargs and kwargs["n_branchings"] > 0: 255 | vpt.branchings_segments() 256 | else: 257 | vpt.indices = vpt.pseudotime.argsort() 258 | 259 | if f"{vkey}_pseudotime" not in adata.obs.keys(): 260 | pseudotime = np.empty(adata.n_obs) 261 | pseudotime[:] = np.nan 262 | else: 263 | pseudotime = adata.obs[f"{vkey}_pseudotime"].values 264 | pseudotime[cell_subset] = vpt.pseudotime 265 | adata.obs[f"{vkey}_pseudotime"] = np.array(pseudotime, dtype=np.float64) 266 | 267 | if save_diffmap: 268 | diffmap = np.empty(shape=(adata.n_obs, n_dcs)) 269 | diffmap[:] = np.nan 270 | diffmap[cell_subset] = vpt.eigen_basis 271 | adata.obsm[f"X_diffmap_{groups}"] = diffmap 272 | 273 | return vpt if return_model else None 274 | -------------------------------------------------------------------------------- /TFvelo/utils.py: -------------------------------------------------------------------------------- 1 | from .core import ( 2 | clean_obs_names, 3 | cleanup, 4 | get_initial_size, 5 | merge, 6 | set_initial_size, 7 | show_proportions, 8 | ) 9 | from .plotting.simulation import 
compute_dynamics 10 | from .plotting.utils import ( 11 | clip, 12 | interpret_colorkey, 13 | is_categorical, 14 | rgb_custom_colormap, 15 | ) 16 | from .plotting.velocity_embedding_grid import compute_velocity_on_grid 17 | from .preprocessing.moments import get_moments 18 | from .preprocessing.neighbors import get_connectivities 19 | from .read_load import ( 20 | convert_to_ensembl, 21 | convert_to_gene_names, 22 | gene_info, 23 | load_biomart, 24 | ) 25 | from .tools.rank_velocity_genes import get_mean_var 26 | from .tools.transition_matrix import get_cell_transitions 27 | from .tools.transition_matrix import transition_matrix as get_transition_matrix 28 | from .tools.utils import * # noqa 29 | from .tools.velocity_graph import vals_to_csr 30 | 31 | __all__ = [ 32 | "cleanup", 33 | "clean_obs_names", 34 | "clip", 35 | "compute_dynamics", 36 | "compute_velocity_on_grid", 37 | "convert_to_ensembl", 38 | "convert_to_gene_names", 39 | "gene_info", 40 | "get_cell_transitions", 41 | "get_connectivities", 42 | "get_initial_size", 43 | "get_mean_var", 44 | "get_moments", 45 | "get_transition_matrix", 46 | "interpret_colorkey", 47 | "is_categorical", 48 | "load_biomart", 49 | "merge", 50 | "rgb_custom_colormap", 51 | "set_initial_size", 52 | "show_proportions", 53 | "vals_to_csr", 54 | ] 55 | -------------------------------------------------------------------------------- /TFvelo_analysis_demo.py: -------------------------------------------------------------------------------- 1 | import TFvelo as TFv 2 | import anndata as ad 3 | import scanpy as sc 4 | import numpy as np 5 | import scipy 6 | import matplotlib 7 | matplotlib.use('AGG') 8 | 9 | 10 | np.set_printoptions(suppress=True) 11 | 12 | 13 | def check_data_type(adata): 14 | for key in list(adata.var): 15 | if adata.var[key][0] in ['True', 'False']: 16 | print('Checking', key) 17 | adata.var[key] = adata.var[key].map({'True': True, 'False': False}) 18 | return 19 | 20 | 21 | def data_type_tostr(adata, key=None): 22 | if key is None: 23 | for key in list(adata.var): 24 | if adata.var[key][0] in [True, False]: 25 | print('Transferring', key) 26 | adata.var[key] = adata.var[key].map({True: 'True', False:'False'}) 27 | elif key in adata.var.keys(): 28 | if adata.var[key][0] in [True, False]: 29 | print('Transferring', key) 30 | adata.var[key] = adata.var[key].map({True: 'True', False:'False'}) 31 | return 32 | 33 | 34 | 35 | def get_pseudotime(adata): 36 | TFv.tl.velocity_graph(adata, basis=None, vkey='velocity', xkey='M_total') 37 | TFv.tl.velocity_pseudotime(adata, vkey='velocity', modality='M_total') 38 | TFv.pl.scatter(adata, basis=args.basis, color='velocity_pseudotime', cmap='gnuplot', fontsize=20, save='pseudotime') # note: reads the module-level args parsed in __main__ 39 | return adata 40 | 41 | 42 | def get_sort_positions(arr): 43 | positions = np.argsort(np.argsort(arr)) 44 | positions_normed = positions/(len(arr)-1) 45 | return positions_normed 46 | 47 | 48 | def get_metric_pseudotime(adata, t_key='latent_t'): 49 | n_cells, n_genes = adata.shape 50 | adata.var['spearmanr_pseudotime'] = 0.0 51 | for i in range(n_genes): 52 | correlation, _ = scipy.stats.spearmanr(adata.layers[t_key][:,i], adata.obs['velocity_pseudotime']) 53 | adata.var['spearmanr_pseudotime'][i] = correlation 54 | return adata 55 | 56 | 57 | def show_adata(args, adata, save_name, show_all=0): 58 | if show_all: 59 | for i in range(int((len(adata.var_names)-1)/20)+1): 60 | genes2show = adata.var_names[i*20: (i+1)*20] 61 | TFv.pl.velocity(adata, genes2show, ncols=4, add_outline=True, layers='na', dpi=300, fontsize=15,
save='WX_y_'+save_name+'_'+str(i)) #layers='all' 62 | if len(adata.obs['clusters'].cat.categories) > 10: 63 | legend_loc = 'right margin' 64 | else: 65 | legend_loc = 'on data' 66 | cutoff_perc = 20 67 | TFv.pl.velocity_embedding_stream(adata, vkey='velocity', use_derivative=False, density=2, basis=args.basis, \ 68 | cutoff_perc=cutoff_perc, smooth=0.5, fontsize=20, recompute=True, \ 69 | legend_loc=legend_loc, save='embedding_stream_'+save_name) # 70 | 71 | return 72 | 73 | 74 | 75 | def get_sort_t(adata): 76 | t = adata.layers['fit_t_raw'].copy() 77 | normed_t = adata.layers['fit_t_raw'].copy() 78 | n_bins = 20 79 | n_cells, n_genes = adata.shape 80 | sort_t = np.zeros([n_cells, n_genes]) 81 | non_blank_gene = np.zeros(n_genes, dtype=int) 82 | hist_all, bins_all = np.zeros([n_genes, n_bins]), np.zeros([n_genes, n_bins+1]) 83 | for i in range(n_genes): 84 | gene_name = adata.var_names[i] 85 | tmp = t[:,i].copy() 86 | if np.isnan(tmp).sum(): 87 | non_blank_gene[i] = 1 88 | continue 89 | hist, bins = np.histogram(tmp, bins=n_bins) 90 | hist_all[i], bins_all[i] = hist, bins 91 | if not (0 in list(hist)): 92 | if (tmp.min() < 0.1) and (tmp.max() > 0.8): 93 | blank_start_bin_id = np.argmin(hist) 94 | blank_end_bin_id = blank_start_bin_id 95 | non_blank_gene[i] = 1 96 | blank_start_bin = bins[blank_start_bin_id] 97 | blank_end_bin = bins[blank_end_bin_id] 98 | tmp = (tmp < blank_start_bin)*1 + tmp 99 | else: 100 | blank_end_bin = tmp.min() 101 | else: 102 | blank_start_bin_id = list(hist).index(0) 103 | for j in range(blank_start_bin_id+1, len(hist)): 104 | if hist[j] > 0: 105 | blank_end_bin_id = j 106 | break 107 | blank_start_bin = bins[blank_start_bin_id] 108 | blank_end_bin = bins[blank_end_bin_id] 109 | tmp = (tmp < blank_start_bin)*1 + tmp 110 | 111 | t[:,i] = tmp 112 | tmp = tmp - blank_end_bin 113 | tmp = tmp/tmp.max() 114 | normed_t[:,i] = tmp 115 | sort_t[:,i] = get_sort_positions(tmp) 116 | 117 | adata.layers['latent_t'] = sort_t.copy() 118 | adata.var['non_blank_gene'] = non_blank_gene.copy() 119 | return adata 120 | 121 | 122 | 123 | 124 | def main(args): 125 | adata = ad.read_h5ad(args.data_path+"rc.h5ad") 126 | adata.var_names_make_unique() 127 | check_data_type(adata) 128 | print(adata) 129 | 130 | losses = adata.varm['loss'].copy() 131 | losses[np.isnan(losses)] = 1e6 132 | adata.var['min_loss'] = losses.min(1) 133 | 134 | n_cells = adata.shape[0] 135 | expanded_scaling_y = np.expand_dims(np.array(adata.var['fit_scaling_y']),0).repeat(n_cells,axis=0) 136 | adata.layers['velocity'] = adata.layers['velo_hat'] / expanded_scaling_y 137 | 138 | if 'X_pca' not in adata.obsm.keys(): 139 | print('PCA ing') 140 | sc.tl.pca(adata, n_comps=50, svd_solver='arpack') 141 | if (args.basis=='umap') and ('X_umap' not in adata.obsm.keys()): 142 | print('Umap ing') 143 | if args.dataset_name == 'hesc1': 144 | sc.tl.pca(adata, n_comps=50, svd_solver='arpack') 145 | sc.pp.neighbors(adata, use_rep="X_pca", n_neighbors=30, n_pcs=5) 146 | sc.tl.umap(adata) 147 | else: 148 | sc.tl.umap(adata) 149 | sc.pl.umap(adata, color='clusters', save=True) 150 | 151 | adata = get_pseudotime(adata) 152 | 153 | adata_copy = adata.copy() 154 | adata_copy = get_sort_t(adata_copy) 155 | 156 | adata_copy_1 = adata_copy.copy() 157 | data_type_tostr(adata_copy_1) 158 | print(adata_copy_1) 159 | adata_copy_1.write(args.data_path + 'rc.h5ad') 160 | 161 | thres_loss = np.percentile(adata_copy.var['min_loss'], args.loss_percent_thres) 162 | adata_copy = adata_copy[:, adata_copy.var['min_loss'] < thres_loss] 163 | 164 | 
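    # Gene filtering for the figures: keep genes detected in more than 10% of
    # cells, then keep only genes that get_sort_t() did not flag (it sets
    # non_blank_gene=1 for NaN fits and for gap-free, full-range time fits).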
thres_n_cells = adata_copy.X.shape[0] * 0.1 165 | adata_copy = adata_copy[:, adata_copy.var['n_cells'] > thres_n_cells] 166 | 167 | adata_copy = adata_copy[:, adata_copy.var['non_blank_gene']==0] 168 | 169 | 170 | adata_copy = get_metric_pseudotime(adata_copy) 171 | adata_copy = adata_copy[:, adata_copy.var['spearmanr_pseudotime'] > args.spearmanr_thres] 172 | 173 | TFv.tl.velocity_graph(adata_copy, basis=None, vkey='velocity', xkey='M_total') 174 | adata_copy.uns['clusters_colors'] = adata.uns['clusters_colors'] 175 | show_adata(args, adata_copy, save_name='velo', show_all=1) 176 | 177 | 178 | data_type_tostr(adata_copy) 179 | print(adata_copy) 180 | adata_copy.write(args.data_path + 'TFvelo.h5ad') 181 | 182 | return 183 | 184 | 185 | 186 | 187 | if __name__ == '__main__': 188 | import argparse 189 | parser = argparse.ArgumentParser() 190 | parser.add_argument( '--dataset_name', type=str, default="pancreas", help='pancreas, gastrulation_erythroid, 10x_mouse_brain, hesc1') 191 | parser.add_argument( '--layer', type=str, default="M_total", help='M_total, total') 192 | parser.add_argument( '--basis', type=str, default="umap", help='umap, tsne, pca') 193 | parser.add_argument( '--loss_percent_thres', type=int, default=50, help='max loss of each gene') 194 | parser.add_argument( '--spearmanr_thres', type=float, default=0.8, help='min spearmanr') 195 | parser.add_argument( '--save_name', type=str, default='_demo', help='save_name') 196 | args = parser.parse_args() 197 | 198 | args.data_path = 'TFvelo_'+ args.dataset_name + args.save_name+ '/' 199 | print('------------------------------------------------------------') 200 | 201 | print(args) 202 | main(args) 203 | -------------------------------------------------------------------------------- /TFvelo_run_demo.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import TFvelo as TFv 3 | import anndata as ad 4 | import numpy as np 5 | import scanpy as sc 6 | import scvelo as scv 7 | import matplotlib 8 | matplotlib.use('AGG') 9 | 10 | import os, sys 11 | 12 | def check_data_type(adata): 13 | for key in list(adata.var): 14 | if adata.var[key][0] in ['True', 'False']: 15 | adata.var[key] = adata.var[key].map({'True': True, 'False': False}) 16 | return 17 | 18 | def data_type_tostr(adata, key): 19 | if key in adata.var.keys(): 20 | if adata.var[key][0] in [True, False]: 21 | adata.var[key] = adata.var[key].map({True: 'True', False:'False'}) 22 | return 23 | 24 | 25 | def preprocess(args): 26 | print('----------------------------------preprocess',args.dataset_name,'---------------------------------------------') 27 | if args.dataset_name == 'pancreas': 28 | adata = scv.datasets.pancreas() 29 | elif args.dataset_name == 'gastrulation_erythroid': 30 | adata = scv.datasets.gastrulation_erythroid() 31 | adata.uns['clusters_colors'] = adata.uns['celltype_colors'].copy() 32 | adata.obs['clusters'] = adata.obs['celltype'].copy() 33 | elif args.dataset_name == 'hesc1': 34 | expression = pd.read_table("data/hesc1/rpkm.txt", header=0, index_col=0, sep="\t").T 35 | adata = ad.AnnData(expression) 36 | adata.obs_names = expression.index 37 | adata.var_names = expression.columns 38 | adata.obs['time_gt'] = 'Nan' 39 | for ii, cell in enumerate(adata.obs_names): 40 | adata.obs['time_gt'][ii] = cell.split('.')[0] 41 | adata.obs['time_gt'] = adata.obs['time_gt'].astype('category') 42 | adata.obs['clusters'] = adata.obs['time_gt'].copy() 43 | elif args.dataset_name == '10x_mouse_brain': 44 | adata = 
ad.read_h5ad("data/10x_mouse_brain/adata_rna.h5ad") 45 | adata.obs['clusters'] = adata.obs['celltype'].copy() 46 | 47 | if not os.path.exists(args.result_path): 48 | os.makedirs(args.result_path) 49 | adata.var_names_make_unique() 50 | adata.obs_names_make_unique() 51 | 52 | adata.uns['genes_all'] = np.array(adata.var_names) 53 | 54 | if "spliced" in adata.layers: 55 | adata.layers["total"] = adata.layers["spliced"].todense() + adata.layers["unspliced"].todense() 56 | elif "new" in adata.layers: 57 | adata.layers["total"] = np.array(adata.layers["total"].todense()) 58 | else: 59 | adata.layers["total"] = adata.X 60 | adata.layers["total_raw"] = adata.layers["total"].copy() 61 | n_cells, n_genes = adata.X.shape 62 | sc.pp.filter_genes(adata, min_cells=int(n_cells/50)) 63 | sc.pp.filter_cells(adata, min_genes=int(n_genes/50)) 64 | TFv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=2000, log=True) #include the following steps 65 | adata.X = adata.layers["total"].copy() 66 | 67 | if not args.dataset_name in ['10x_mouse_brain']: 68 | adata.uns['clusters_colors'] = np.array(['red', 'orange', 'yellow', 'green','skyblue', 'blue','purple', 'pink', '#8fbc8f', '#f4a460', '#fdbf6f', '#ff7f00', '#b2df8a', '#1f78b4', 69 | '#6a3d9a', '#cab2d6'], dtype=object) 70 | 71 | gene_names = [] 72 | for tmp in adata.var_names: 73 | gene_names.append(tmp.upper()) 74 | adata.var_names = gene_names 75 | adata.var_names_make_unique() 76 | adata.obs_names_make_unique() 77 | 78 | TFv.pp.moments(adata, n_pcs=30, n_neighbors=args.n_neighbors) 79 | 80 | TFv.pp.get_TFs(adata, databases=args.TF_databases) 81 | print(adata) 82 | adata.uns['genes_pp'] = np.array(adata.var_names) 83 | adata.write(args.result_path + 'pp.h5ad') 84 | 85 | 86 | 87 | def main(args): 88 | print('--------------------------------') 89 | adata = ad.read_h5ad(args.result_path + 'pp.h5ad') 90 | 91 | n_jobs_max = np.max([int(os.cpu_count()/2), 1]) 92 | if args.n_jobs >= 1: 93 | n_jobs = np.min([args.n_jobs, n_jobs_max]) 94 | else: 95 | n_jobs = n_jobs_max 96 | print('n_jobs:', n_jobs) 97 | flag = TFv.tl.recover_dynamics(adata, n_jobs=n_jobs, max_iter=args.max_iter, var_names=args.var_names, 98 | WX_method = args.WX_method, WX_thres=args.WX_thres, max_n_TF=args.max_n_TF, n_top_genes=args.n_top_genes, 99 | fit_scaling=True, use_raw=args.use_raw, init_weight_method=args.init_weight_method, 100 | n_time_points=args.n_time_points) 101 | if flag==False: 102 | return adata, False 103 | if 'highly_variable_genes' in adata.var.keys(): 104 | data_type_tostr(adata, key='highly_variable_genes') 105 | adata.write(args.result_path + 'rc.h5ad') 106 | return 107 | 108 | 109 | if __name__ == '__main__': 110 | import argparse 111 | parser = argparse.ArgumentParser() 112 | parser.add_argument( '--dataset_name', type=str, default="pancreas", help='pancreas, gastrulation_erythroid, 10x_mouse_brain, hesc1') 113 | parser.add_argument( '--n_jobs', type=int, default=28, help='number of cpus to use') 114 | parser.add_argument( '--var_names', type=str, default="all", help='all, highly_variable_genes') 115 | parser.add_argument( '--init_weight_method', type=str, default= "correlation", help='use correlation to initialize the weights') 116 | parser.add_argument( '--WX_method', type=str, default= "lsq_linear", help='LS, LASSO, Ridge, constant, LS_constrant, lsq_linear') 117 | parser.add_argument( '--n_neighbors', type=int, default=30, help='number of neighbors') 118 | parser.add_argument( '--WX_thres', type=int, default=20, help='the threshold for weights') 119 | 
parser.add_argument( '--n_top_genes', type=int, default=2000, help='number of top highly-variable genes to keep')
120 |     parser.add_argument( '--TF_databases', nargs='+', default='ENCODE ChEA', help='knockTF ChEA ENCODE')  # note: the string default is kept as-is by argparse unless values are passed
121 |     parser.add_argument( '--max_n_TF', type=int, default=99, help='max number of TFs used for modeling each gene')
122 |     parser.add_argument( '--max_iter', type=int, default=20, help='max number of iterations in the generalized EM algorithm')
123 |     parser.add_argument( '--n_time_points', type=int, default=1000, help='number of time points in the time assignment (E step)')
124 |     parser.add_argument( '--save_name', type=str, default='_demo', help='save_name')
125 |     parser.add_argument( '--use_raw', type=int, default=0, help='whether to use raw data instead of moments (0/1)')
126 |     parser.add_argument( '--basis', type=str, default='umap', help='umap')
127 | 
128 |     args = parser.parse_args()
129 |     args.result_path = 'TFvelo_'+ args.dataset_name + args.save_name+ '/'
130 |     print('********************************************************************************************************')
131 |     print('********************************************************************************************************')
132 |     print(args)
133 |     preprocess(args)
134 |     main(args)
135 | 
136 | 
--------------------------------------------------------------------------------
/baselines/baseline_TI_demo.py:
--------------------------------------------------------------------------------
1 | import scvelo as scv
2 | import anndata as ad
3 | import numpy as np
4 | import scanpy as sc
5 | 
6 | import matplotlib
7 | matplotlib.use('AGG')
8 | 
9 | import os
10 | 
11 | 
12 | def run_paga(args, adata, iroot_tyre):
13 |     # PAGA trajectory inference
14 |     sc.tl.paga(adata, groups='clusters')
15 |     #sc.pl.paga(adata, color=['clusters'], save='')
16 | 
17 |     # dpt_pseudotime inference, rooted at the given starting cluster
18 |     adata.uns['iroot'] = np.flatnonzero(adata.obs['clusters'] == iroot_tyre)[0]
19 |     sc.tl.dpt(adata)
20 |     #sc.pl.umap(adata, color=['dpt_pseudotime'], legend_loc='on data', save='_'+args.dataset_name+'_dpt_pseudotime.png')
21 |     scv.pl.scatter(adata, color='dpt_pseudotime', color_map='gnuplot', size=20, save=args.dataset_name+'_dpt_pseudotime.png')
22 |     return adata
23 | 
24 | def run_palantir(args, adata, iroot_tyre):
25 |     sc.external.tl.palantir(adata, n_components=5, knn=30)
26 | 
27 |     iroot = np.flatnonzero(adata.obs['clusters'] == iroot_tyre)[0]
28 |     start_cell = adata.obs_names[iroot]
29 | 
30 |     pr_res = sc.external.tl.palantir_results(
31 |         adata,
32 |         early_cell=start_cell,
33 |         ms_data='X_palantir_multiscale',
34 |         num_waypoints=500,
35 |     )
36 |     adata.obs['pr_pseudotime'] = pr_res.pseudotime
37 |     adata.obs['pr_entropy'] = pr_res.entropy
38 |     #adata.obs['pr_branch_probs'] = pr_res.branch_probs
39 |     #adata.uns['pr_waypoints'] = pr_res.waypoints
40 | 
41 |     #sc.pl.umap(adata, color=['pr_pseudotime'], legend_loc='on data', save='_'+args.dataset_name+'_pr_pseudotime.png')
42 |     scv.pl.scatter(adata, color='pr_pseudotime', color_map='gnuplot', size=20, save=args.dataset_name+'_pr_pseudotime.png')
43 |     return adata
44 | 
45 | 
46 | def main(args):
47 |     adata = ad.read_h5ad(args.result_path + 'pp.h5ad')
48 |     #adata.X = adata.layers['M_total']
49 | 
50 |     if args.dataset_name == 'pancreas':
51 |         iroot_tyre = 'Ductal'
52 |     elif args.dataset_name == 'gastrulation_erythroid':
53 |         iroot_tyre = 'Blood progenitors 1'
54 |     elif args.dataset_name == 'hesc1':
55 |         iroot_tyre = 'E3'
56 |     elif args.dataset_name == '10x_mouse_brain':
57 |         iroot_tyre = 'RG, Astro, OPC'
58 | 
59 |     if 'X_pca' not in adata.obsm.keys():
60 |         print('Running PCA')
61 |         sc.tl.pca(adata, n_comps=50, svd_solver='arpack')
62 |         # sc.pl.pca(adata, 
color=['clusters'], show=False, save='_clusters.png') 63 | if ('X_umap' not in adata.obsm.keys()): 64 | print('Umap ing') 65 | if args.dataset_name == 'hesc1': 66 | sc.pp.neighbors(adata, use_rep="X_pca", n_neighbors=30, n_pcs=5) 67 | sc.tl.umap(adata) 68 | 69 | adata = run_paga(args, adata, iroot_tyre) 70 | adata = run_palantir(args, adata, iroot_tyre) 71 | 72 | adata.write(args.result_path + 'TI.h5ad') 73 | 74 | return 75 | 76 | 77 | if __name__ == '__main__': 78 | import argparse 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument( '--dataset_name', type=str, default="pancreas", help='pancreas, gastrulation_erythroid') 81 | parser.add_argument( '--n_jobs', type=int, default=16, help='n_jobs') 82 | parser.add_argument( '--save_name', type=str, default='', help='save_name') 83 | args = parser.parse_args() 84 | 85 | for args.dataset_name in ['pancreas', 'gastrulation_erythroid', '10x_mouse_brain', 'hesc1']: 86 | args.result_path = 'TFvelo_'+ args.dataset_name + args.save_name+ '/' 87 | if not os.path.exists(args.result_path): 88 | os.makedirs(args.result_path) 89 | print('********************************************************************************************************') 90 | print('********************************************************************************************************') 91 | print(args) 92 | main(args) 93 | 94 | -------------------------------------------------------------------------------- /baselines/baseline_cellDancer_Demo.py: -------------------------------------------------------------------------------- 1 | # import packages 2 | import os 3 | import sys 4 | import glob 5 | import pandas as pd 6 | import math 7 | import matplotlib.pyplot as plt 8 | import celldancer as cd 9 | import celldancer.cdplt as cdplt 10 | from celldancer.cdplt import colormap, build_colormap 11 | import celldancer.utilities as cdutil 12 | import scvelo as scv 13 | import anndata as ad 14 | import scanpy as sc 15 | import numpy as np 16 | 17 | args_dataset = 'pancreas' #gastrulation_erythroid, pancreas 18 | args_pp = 0 19 | args_train = 0 20 | args_postp = 0 21 | n_jobs = 8 22 | 23 | 24 | color_map = None 25 | if args_dataset == 'pancreas': 26 | adata = scv.datasets.pancreas() 27 | color_map = colormap.colormap_pancreas 28 | elif args_dataset == 'gastrulation_erythroid': 29 | adata = scv.datasets.gastrulation_erythroid() 30 | color_map = colormap.colormap_erythroid 31 | adata.obs['clusters'] = adata.obs['celltype'].copy() 32 | elif args_dataset == 'bonemarrow': 33 | adata = scv.datasets.bonemarrow() 34 | elif args_dataset == 'dentategyrus': 35 | adata = scv.datasets.dentategyrus() 36 | elif args_dataset == 'larry': 37 | adata = ad.read_h5ad("data/larry/larry.h5ad") 38 | adata.obs['clusters'] = adata.obs['state_info'] 39 | adata.obsm['X_umap'] = np.stack([np.array(adata.obs['SPRING-x']), np.array(adata.obs['SPRING-y'])]).T 40 | elif args_dataset == 'pons': 41 | adata = ad.read_h5ad("data/pons/oligo_lite.h5ad") 42 | adata.obs['clusters'] = adata.obs['celltype'] 43 | 44 | print(adata) 45 | 46 | if color_map is None: 47 | cluster_list = list(adata.obs['clusters'].cat.categories) 48 | color_map = build_colormap(cluster_list) 49 | 50 | try: 51 | adata_TFv = ad.read_h5ad("../TFvelo_master/TFvelo_"+args_dataset+"/TFvelo.h5ad") 52 | color_map = {} 53 | for i, c in enumerate(adata_TFv.obs['clusters'].cat.categories): 54 | color_map[c] = adata_TFv.uns['clusters_colors'][i] 55 | except: 56 | print('No TFvelo colors') 57 | 58 | save_folder = 'cellDancer_' + args_dataset +'/' 59 | if not 
os.path.exists(save_folder):
60 |     os.makedirs(save_folder)
61 | csv_path = save_folder + 'cell_type_u_s.csv'
62 | 
63 | 
64 | if args_pp:
65 |     scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=2000)
66 |     scv.pp.moments(adata, n_pcs=30, n_neighbors=30) # the number of cells will influence this setting
67 |     if 'X_umap' not in adata.obsm:
68 |         if 'X_tsne' in adata.obsm:
69 |             print('Copying tsne to umap')
70 |             adata.obsm['X_umap'] = adata.obsm['X_tsne']
71 |         else:
72 |             sc.pp.neighbors(adata, n_neighbors=30)
73 |             sc.tl.umap(adata)
74 | 
75 |     print(adata)
76 |     adata.write(save_folder + 'pp.h5ad')
77 | 
78 | 
79 |     cell_type_u_s = cdutil.adata_to_df_with_embed(adata,
80 |                                                   us_para=['Mu','Ms'],
81 |                                                   cell_type_para='clusters',
82 |                                                   embed_para='X_umap',
83 |                                                   save_path=csv_path)
84 |                                                   #gene_list=['Hba-x','Smim1']
85 | else:
86 |     adata = ad.read_h5ad(save_folder + 'pp.h5ad')
87 |     cell_type_u_s = pd.read_csv(csv_path)
88 | 
89 | print(cell_type_u_s)
90 | 
91 | 
92 | if args_train: # to train model on each gene
93 |     loss_df, cellDancer_df = cd.velocity(cell_type_u_s, permutation_ratio=0.5, n_jobs=n_jobs)
94 | else:
95 |     loss_df = pd.read_csv(save_folder+'loss.csv')
96 |     cellDancer_df = pd.read_csv(save_folder+'cellDancer_estimation.csv')
97 | 
98 | 
99 | if args_postp:
100 |     # Compute cell velocity
101 |     cellDancer_df=cd.compute_cell_velocity(cellDancer_df=cellDancer_df, projection_neighbor_size=100)
102 | 
103 |     # Plot cell velocity
104 |     fig, ax = plt.subplots(figsize=(10,10))
105 |     im = cdplt.scatter_cell(ax, cellDancer_df, colors=color_map, alpha=0.5, s=20, velocity=True, legend='on', min_mass=5, arrow_grid=(20,20))
106 |     ax.axis('off')
107 |     plt.savefig(save_folder+'arrowplot.png')
108 |     plt.close()
109 | 
110 | 
111 |     # set parameters for the diffusion-based pseudotime estimation
112 |     dt = 0.001
113 |     t_total = {dt: 10000}
114 |     n_repeats = 10
115 |     # estimate pseudotime
116 |     cellDancer_df = cd.pseudo_time(cellDancer_df=cellDancer_df,
117 |                                    grid=(30, 30),
118 |                                    dt=dt,
119 |                                    t_total=t_total[dt],
120 |                                    n_repeats=n_repeats,
121 |                                    speed_up=(60,60),
122 |                                    n_paths = 5,
123 |                                    psrng_seeds_diffusion=[i for i in range(n_repeats)],
124 |                                    n_jobs=n_jobs)
125 |     # plot pseudotime
126 |     fig, ax = plt.subplots(figsize=(10,10))
127 |     im=cdplt.scatter_cell(ax,cellDancer_df, colors='pseudotime', alpha=0.5, velocity=False)
128 |     ax.axis('off')
129 |     plt.savefig(save_folder+'pseudotime.png')
130 |     plt.close()
131 | 
132 |     cellDancer_df.to_csv(os.path.join(save_folder, ('cellDancer_estimation_final.csv')),index=False)
133 | 
134 | else:
135 |     cellDancer_df = pd.read_csv(save_folder+'cellDancer_estimation_final.csv')
136 | 
137 | 
138 | 
--------------------------------------------------------------------------------
/baselines/baseline_dynamo_demo.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import scvelo as scv
3 | import anndata as ad
4 | import numpy as np
5 | import scanpy as sc
6 | import dynamo as dyn
7 | 
8 | import warnings
9 | warnings.filterwarnings('ignore')
10 | 
11 | import matplotlib
12 | matplotlib.use('AGG')
13 | 
14 | import os, sys
15 | 
16 | def data_type_tostr(adata, key=None):  # make adata.var h5ad-serializable: boolean, None and NaN entries become strings
17 |     print(adata)
18 |     if key is None:
19 |         for key in list(adata.var):
20 |             if adata.var[key][0] in [True, False]:
21 |                 print('Transferring', key, 'because True/False')
22 |                 adata.var[key] = adata.var[key].map({True: 'True', False: 'False'})
23 |             if adata.var[key][0] is None:
24 |                 print('Transferring', key, 'because None')
25 |                 for j in range(len(adata.var[key])):
26 |                     if adata.var[key][j] is None:
27 |                         adata.var[key][j] = 'None'
28 |                     else:
29 |                         adata.var[key][j] = str(adata.var[key][j])
30 |             if adata.var[key][0] is np.nan:
31 |                 print('Transferring', key, 'because NaN')
32 |                 for j in range(len(adata.var[key])):
33 |                     if adata.var[key][j] is np.nan:  # replace NaN entries with the string 'NaN'
34 |                         adata.var[key][j] = 'NaN'
35 |                     else:
36 |                         adata.var[key][j] = str(adata.var[key][j])
37 |     elif key in adata.var.keys():
38 |         if adata.var[key][0] in [True, False]:
39 |             print('Transferring', key)
40 |             adata.var[key] = adata.var[key].map({True: 'True', False: 'False'})
41 |     if 'cell_phase_genes' in adata.uns:
42 |         del adata.uns['cell_phase_genes']
43 |     return
44 | 
45 | def dynamo_workflow_scNTseq(adata, **kwargs):
46 |     preprocessor = dyn.pp.Preprocessor(cell_cycle_score_enable=True)
47 |     preprocessor.preprocess_adata(adata, recipe='monocle', **kwargs)
48 | 
49 |     dyn.tl.dynamics(adata)
50 | 
51 |     dyn.tl.reduceDimension(adata)
52 | 
53 |     dyn.tl.cell_velocities(adata, calc_rnd_vel=True, transition_genes=adata.var_names)
54 | 
55 |     dyn.vf.VectorField(adata, basis='umap')
56 |     return
57 | 
58 | def main(args):
59 |     if args.dataset_name == 'pancreas':
60 |         adata = scv.datasets.pancreas()
61 |     elif args.dataset_name == 'gastrulation_erythroid':
62 |         adata = scv.datasets.gastrulation_erythroid()
63 |         adata.obs['clusters'] = adata.obs['celltype'].copy()
64 | 
65 |     print(adata)
66 | 
67 |     dyn.pp.recipe_monocle(adata)
68 |     dyn.tl.dynamics(adata, cores=3)
69 | 
70 |     dyn.tl.reduceDimension(adata)
71 |     dyn.tl.cell_velocities(adata)
72 | 
73 |     dyn.tl.cell_wise_confidence(adata)
74 |     dyn.vf.VectorField(adata)
75 | 
76 |     print(adata)
77 | 
78 |     data_type_tostr(adata)
79 |     adata.write(args.result_path + 'dynamo.h5ad')
80 |     return
81 | 
82 | def analysis(args):
83 |     adata_TFv = ad.read_h5ad(args.result_path+'TFvelo.h5ad')  # reuse TFvelo's cluster colors for a consistent visual comparison
84 |     n_colors = len(adata_TFv.obs['clusters'].cat.categories)
85 |     adata_TFv.uns['clusters_colors'] = adata_TFv.uns['clusters_colors'][:n_colors]
86 | 
87 |     method = 'dynamo'
88 |     adata = ad.read_h5ad(args.result_path+method+'.h5ad')
89 |     adata.uns['clusters_colors'] = adata_TFv.uns['clusters_colors'][adata_TFv.obs['clusters'].cat.categories.argsort()]
90 | 
91 |     dyn.vf.VectorField(adata, basis='umap', M=100)
92 |     dyn.ext.ddhodge(adata, basis='umap')
93 |     print(adata)
94 | 
95 |     save_kwargs = {"path": 'figures/', "prefix": 'dyn_'+args.dataset_name+'_embedding_stream', "dpi": 300, "ext": 'png'}
96 |     dyn.pl.streamline_plot(adata, color=['clusters'], color_key=adata.uns['clusters_colors'], save_show_or_return='save', save_kwargs=save_kwargs)
97 |     save_kwargs = {"path": 'figures/', "prefix": 'dyn_'+args.dataset_name+'_embedding_grid', "dpi": 300, "ext": 'png'}
98 |     dyn.pl.grid_vectors(adata, color=['clusters'], color_key=adata.uns['clusters_colors'], save_show_or_return='save', save_kwargs=save_kwargs)
99 |     #save_kwargs = {"path": 'figures/', "prefix": 'dyn_'+args.dataset_name+'_pseudotime', "dpi": 300, "ext": 'png'}
100 |     #dyn.pl.streamline_plot(adata, color=['umap_ddhodge_potential'], save_show_or_return='save', save_kwargs=save_kwargs)
101 |     scv.pl.scatter(adata, basis='umap', color='umap_ddhodge_potential', cmap='gnuplot', fontsize=20, save='dynamo_pseudotime.png')
102 |     adata.write(args.result_path + 'dynamo.h5ad')
103 |     return
104 | 
105 | if __name__ == '__main__':
106 |     import argparse
107 |     parser = argparse.ArgumentParser()
108 |     parser.add_argument( '--dataset_name', type=str, default="pancreas", help='pancreas, gastrulation_erythroid, pons, scNT_seq')
109 |     parser.add_argument( '--save_name', type=str, default='_demo', help='save_name')
110 | 
111 |     args = 
parser.parse_args() 112 | args.result_path = 'TFvelo_'+ args.dataset_name + args.save_name+ '/' 113 | if not os.path.exists(args.result_path): 114 | os.makedirs(args.result_path) 115 | print('********************************************************************************************************') 116 | print('********************************************************************************************************') 117 | print(args) 118 | 119 | #main(args) 120 | analysis(args) 121 | 122 | -------------------------------------------------------------------------------- /baselines/baseline_scvelo_demo.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import scvelo as scv 3 | import anndata as ad 4 | import numpy as np 5 | import scanpy as sc 6 | 7 | import matplotlib 8 | matplotlib.use('AGG') 9 | import umap 10 | 11 | import os, sys 12 | 13 | 14 | def main(args): 15 | scv.set_figure_params() 16 | if args.dataset_name == 'pancreas': 17 | adata = scv.datasets.pancreas() 18 | elif args.dataset_name == 'gastrulation_erythroid': 19 | adata = scv.datasets.gastrulation_erythroid() 20 | adata.uns['clusters_colors'] = adata.uns['celltype_colors'].copy() 21 | adata.obs['clusters'] = adata.obs['celltype'].copy() 22 | 23 | scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=2000) 24 | scv.pp.moments(adata, n_pcs=30, n_neighbors=30) 25 | scv.tl.recover_dynamics(adata, n_jobs=args.n_jobs) 26 | scv.tl.velocity(adata, mode='dynamical') 27 | scv.tl.velocity_graph(adata) 28 | adata.write(args.result_path + 'scvelo.h5ad') 29 | return 30 | 31 | 32 | def analysis(args): 33 | adata_TFv = ad.read_h5ad(args.result_path+'TFvelo.h5ad') 34 | n_colors = len(adata_TFv.obs['clusters'].cat.categories) 35 | adata_TFv.uns['clusters_colors'] = adata_TFv.uns['clusters_colors'][:n_colors] 36 | 37 | method = 'scvelo' 38 | adata = ad.read_h5ad(args.result_path+method+'.h5ad') 39 | adata.uns['clusters_colors'] = adata_TFv.uns['clusters_colors'].copy() 40 | if args.dataset_name == 'gastrulation_erythroid': 41 | adata.obs['clusters'] = adata.obs['celltype'].copy() 42 | 43 | 44 | scv.tl.latent_time(adata) 45 | print(adata) 46 | scv.pl.scatter(adata, color='velocity_pseudotime', cmap='gnuplot', fontsize=20, save=args.dataset_name+'_'+method+'_pseudotime.png') 47 | scv.pl.velocity_embedding_stream(adata, color='clusters', dpi=300, title='', save= args.dataset_name+'_'+method+'_embedding_stream.png') 48 | scv.pl.velocity_embedding_grid(adata, color='clusters', arrow_size=10, dpi=300, title='', save= args.dataset_name+'_'+method+'_embedding_grid.png') 49 | 50 | return 51 | 52 | if __name__ == '__main__': 53 | import argparse 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument( '--dataset_name', type=str, default="pancreas", help='pancreas, gastrulation_erythroid') 56 | parser.add_argument( '--n_jobs', type=int, default=28, help='n_jobs') 57 | parser.add_argument( '--save_name', type=str, default='_demo', help='save_name') 58 | 59 | args = parser.parse_args() 60 | args.result_path = 'TFvelo_'+ args.dataset_name + args.save_name+ '/' 61 | if not os.path.exists(args.result_path): 62 | os.makedirs(args.result_path) 63 | print('********************************************************************************************************') 64 | print('********************************************************************************************************') 65 | print(args) 66 | 67 | main(args) 68 | analysis(args) 69 | 70 | 
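71 | # Usage sketch (illustrative, using only the flags defined above):
72 | #   python baseline_scvelo_demo.py --dataset_name pancreas --n_jobs 8
73 | # main() fits scVelo's dynamical model and writes <result_path>/scvelo.h5ad
74 | # (e.g. TFvelo_pancreas_demo/scvelo.h5ad); analysis() then borrows TFvelo's
75 | # cluster colors from <result_path>/TFvelo.h5ad and saves the pseudotime,
76 | # stream and grid figures under scVelo's default 'figures/' directory.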
-------------------------------------------------------------------------------- /baselines/baseline_unitvelo_demo.py: -------------------------------------------------------------------------------- 1 | import unitvelo as utv 2 | import scvelo as scv 3 | import scanpy as sc 4 | import anndata as ad 5 | 6 | velo = utv.config.Configuration() 7 | velo.R2_ADJUST = True 8 | velo.IROOT = None 9 | velo.FIT_OPTION = '1' 10 | velo.GPU = 1 11 | 12 | 13 | 14 | def main(args): 15 | if args.dataset_name == 'pancreas': 16 | path_to_adata = 'data/Pancreas/endocrinogenesis_day15.h5ad' 17 | label = 'clusters' 18 | elif args.dataset_name == 'gastrulation_erythroid': 19 | path_to_adata = 'data/Gastrulation/erythroid_lineage.h5ad' 20 | label = 'celltype' 21 | 22 | adata = utv.run_model(path_to_adata, label, config_file=velo) 23 | 24 | scv.pp.neighbors(adata) 25 | 26 | adata.write(args.result_path+'unitvelo.h5ad') 27 | 28 | #subvar = adata.var.loc[adata.var['velocity_genes'] == True] 29 | #sub = adata[:, subvar.index] 30 | 31 | return 32 | 33 | def analysis(args): 34 | adata_TFv = ad.read_h5ad(args.result_path+'TFvelo.h5ad') 35 | n_colors = len(adata_TFv.obs['clusters'].cat.categories) 36 | adata_TFv.uns['clusters_colors'] = adata_TFv.uns['clusters_colors'][:n_colors] 37 | 38 | method = 'unitvelo' 39 | adata = ad.read_h5ad(args.result_path+method+'.h5ad') 40 | adata.uns['clusters_colors'] = adata_TFv.uns['clusters_colors'].copy() 41 | adata.uns['label'] = 'clusters' 42 | 43 | if args.dataset_name == 'pancreas': 44 | label = 'clusters' 45 | elif args.dataset_name == 'gastrulation_erythroid': 46 | adata.obs['clusters'] = adata.obs['celltype'].copy() 47 | elif args.dataset_name == 'pons': 48 | adata.obs['clusters'] = adata.obs['celltype'].copy() 49 | 50 | scv.pl.scatter(adata, color='velocity_pseudotime', cmap='gnuplot', fontsize=20, save=args.dataset_name+'_'+method+'_pseudotime.png') 51 | scv.pl.velocity_embedding_stream(adata, color='clusters', dpi=300, title='', save= args.dataset_name+'_'+method+'_embedding_stream.png') 52 | scv.pl.velocity_embedding_grid(adata, color='clusters', dpi=300, arrow_size=10, title='', save= args.dataset_name+'_'+method+'_embedding_grid.png') 53 | adata.write(args.result_path+'unitvelo.h5ad') 54 | 55 | return 56 | 57 | 58 | if __name__ == '__main__': 59 | import argparse 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument( '--dataset_name', type=str, default="gastrulation_erythroid", help='pancreas, gastrulation_erythroid, pons') 62 | parser.add_argument( '--save_name', type=str, default='_demo', help='save_name') 63 | args = parser.parse_args() 64 | 65 | args.result_path = 'TFvelo_'+ args.dataset_name + args.save_name+ '/' 66 | print('------------------------------------------------------------') 67 | print(args) 68 | #main(args) 69 | analysis(args) 70 | 71 | -------------------------------------------------------------------------------- /baselines/compare_baselines_metrics_demo.py: -------------------------------------------------------------------------------- 1 | import unitvelo as utv 2 | import scvelo as scv 3 | import scanpy as sc 4 | import anndata as ad 5 | import numpy as np 6 | import pandas as pd 7 | import TFvelo as TFv 8 | import os 9 | 10 | velo = utv.config.Configuration() 11 | velo.R2_ADJUST = True 12 | velo.IROOT = None 13 | velo.FIT_OPTION = '1' 14 | velo.GPU = 1 15 | 16 | 17 | 18 | def summary_scores(all_scores): 19 | sep_scores = {k:np.mean(s) for k, s in all_scores.items() if s} 20 | overal_agg = np.mean([s for k, s in sep_scores.items() if s]) 
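    # Worked example (illustrative values): summary_scores({'A->B': [0.9, 0.7], 'B->C': [0.2]})
    # first averages within each transition, giving {'A->B': 0.8, 'B->C': 0.2}, then averages
    # those per-transition means to 0.5, so every transition counts equally regardless of
    # how many cell pairs it covers.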
21 | return sep_scores, overal_agg 22 | 23 | def utv_eva(args, method='unitvelo'): 24 | print('-------------------', method, '--------------------') 25 | adata = ad.read_h5ad(args.result_path + method + '.h5ad') 26 | if method == 'scvelo': 27 | adata = adata[:, adata.var['velocity_genes'] == True] 28 | elif method == 'unitvelo': 29 | adata = adata[:, adata.var['velocity_genes'] == True] 30 | elif method == 'dynamo': 31 | subvar = adata.var.loc[adata.var['gamma'] != 'None'] 32 | adata = adata[:, subvar.index] 33 | adata.obsm['velocity_umap'] = np.nan_to_num(adata.obsm['velocity_umap']) 34 | adata.layers['velocity'] = adata.layers['velocity_S'] 35 | scv.pp.neighbors(adata) 36 | 37 | if args.dataset_name == 'pancreas': 38 | cluster_edges = [('Ductal', 'Ngn3 low EP'), ('Ngn3 low EP', 'Ngn3 high EP'), ('Ngn3 high EP', 'Pre-endocrine'), 39 | ('Pre-endocrine', 'Beta'), ('Pre-endocrine', 'Alpha'), ('Pre-endocrine', 'Delta'), ('Pre-endocrine', 'Epsilon')] 40 | 41 | elif args.dataset_name == 'gastrulation_erythroid': 42 | cluster_edges = [('Blood progenitors 1', 'Blood progenitors 2'), ('Blood progenitors 2', 'Erythroid1'), 43 | ('Erythroid1', 'Erythroid2'), ('Erythroid2', 'Erythroid3')] 44 | 45 | elif args.dataset_name == 'pons': 46 | cluster_edges = [('OPCs', 'COPs'), ('COPs', 'NFOLs'), ('NFOLs', 'MFOLs')] 47 | 48 | metrics = utv.evaluate(adata, cluster_edges, k_cluster='clusters', k_velocity='velocity') 49 | Cross_Boundary_Direction_Correctness = list(summary_scores(metrics['Cross-Boundary Direction Correctness (A->B)'])[0].values()) 50 | In_Cluster_Coherence = list(summary_scores(metrics['In-cluster Coherence'])[0].values()) 51 | return Cross_Boundary_Direction_Correctness, In_Cluster_Coherence 52 | 53 | 54 | def draw_plot(df, save_name, fig_type='boxplot'): 55 | import seaborn as sns 56 | import matplotlib.pyplot as plt 57 | my_palette = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", 58 | "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"] 59 | color_mapping = {method: color for method, color in zip(df.keys(), my_palette)} 60 | 61 | plt.figure() 62 | sns.set_palette(my_palette[:df.shape[-1]]) 63 | df_mean = df.mean() 64 | df_mean_sorted = df_mean.sort_values(ascending=False) 65 | colors_sorted = [color_mapping[method] for method in df_mean_sorted.index] 66 | if fig_type=='boxplot': 67 | #ax = sns.boxplot(data=df) 68 | #plt.ylim(np.percentile(df['unitvelo'], 10), 1) 69 | ax = sns.boxplot(data=df, order=df_mean_sorted.index.tolist(), palette=colors_sorted) 70 | elif fig_type=='violinplot': 71 | ax = sns.violinplot(data=df) 72 | elif fig_type=='barplot': 73 | ax = sns.barplot(x=df_mean_sorted.index, y=df_mean_sorted.values, palette=colors_sorted) 74 | for index, value in enumerate(list(df_mean_sorted.values)): 75 | ax.text(index, value, f"{value:.2f}", ha='center', va='bottom', fontsize=12) 76 | else: 77 | return 78 | if len(df.keys())>4: 79 | plt.xticks(rotation=30) 80 | plt.title(save_name.replace('_', ' '), fontsize=16) 81 | plt.xlabel('Methods', fontsize=16) 82 | plt.ylabel('Values', fontsize=16) 83 | #ax = plt.gca() 84 | ax.tick_params(axis='x', labelsize=16) 85 | ax.tick_params(axis='y', labelsize=16) 86 | plt.savefig('figures/'+save_name+'_'+fig_type+'.png', dpi=300, bbox_inches='tight', pad_inches=0.1) 87 | plt.close() 88 | return 89 | 90 | def run_utv_metric(args, methods): 91 | Cross_Boundary_Direction_Correctness_all, In_Cluster_Coherence_all = {}, {} 92 | for method in methods: 93 | Cross_Boundary_Direction_Correctness, In_Cluster_Coherence = utv_eva(args, method) 94 | 
Cross_Boundary_Direction_Correctness_all[method] = Cross_Boundary_Direction_Correctness 95 | In_Cluster_Coherence_all[method] = In_Cluster_Coherence 96 | 97 | Cross_Boundary_Direction_Correctness_df = pd.DataFrame(Cross_Boundary_Direction_Correctness_all) 98 | Cross_Boundary_Direction_Correctness_df.to_csv('figures/Cross_Boundary_Direction_Correctness_'+args.dataset_name+'.txt', sep='\t', index=False) 99 | In_Cluster_Coherence_all_df = pd.DataFrame(In_Cluster_Coherence_all) 100 | In_Cluster_Coherence_all_df.to_csv('figures/In_Cluster_Coherence_all_'+args.dataset_name+'.txt', sep='\t', index=False) 101 | return 102 | 103 | def draw_utv_metric(args): 104 | Cross_Boundary_Direction_Correctness_df = pd.read_csv('figures/Cross_Boundary_Direction_Correctness_'+args.dataset_name+'.txt', sep='\t') 105 | In_Cluster_Coherence_all_df = pd.read_csv('figures/In_Cluster_Coherence_all_'+args.dataset_name+'.txt', sep='\t') 106 | 107 | draw_plot(Cross_Boundary_Direction_Correctness_df, save_name='Cross_Boundary_Direction_Correctness', 108 | fig_type='boxplot') 109 | draw_plot(Cross_Boundary_Direction_Correctness_df, save_name='Cross_Boundary_Direction_Correctness', 110 | fig_type='barplot') 111 | draw_plot(In_Cluster_Coherence_all_df, save_name='In_Cluster_Coherence', 112 | fig_type='boxplot') 113 | draw_plot(In_Cluster_Coherence_all_df, save_name='In_Cluster_Coherence', 114 | fig_type='barplot') 115 | return 116 | 117 | def get_velocity_consistency(args, methods): 118 | confidence_all = {} 119 | for method in methods: 120 | print('-------------------', method, '--------------------') 121 | adata = ad.read_h5ad(args.result_path + method + '.h5ad') 122 | if method == 'dynamo': 123 | subvar = adata.var.loc[adata.var['gamma'] != 'None'] 124 | adata = adata[:, subvar.index] 125 | adata.obsm['velocity_umap'] = np.nan_to_num(adata.obsm['velocity_umap']) 126 | adata.layers['velocity'] = adata.layers['velocity_S'].todense() 127 | adata.layers['Ms'] = adata.layers['M_s'].todense() 128 | scv.pp.neighbors(adata) 129 | scv.tl.velocity_graph(adata) 130 | 131 | if method == 'TFvelo': 132 | TFv.tl.velocity_confidence(adata) 133 | else: 134 | scv.tl.velocity_confidence(adata) 135 | adata.obs['velocity_consistency'] = adata.obs['velocity_confidence'] 136 | del adata.obs['velocity_confidence'] 137 | scv.pl.scatter(adata, c='velocity_consistency', cmap='coolwarm', fontsize=20, save='velocity_consistency_'+method+'.png') 138 | confidence_all[method] = adata.obs['velocity_consistency'] 139 | print(adata.obs['velocity_consistency'].mean()) 140 | confidence_all_df = pd.DataFrame(confidence_all) 141 | confidence_all_df.to_csv('figures/velocity_consistency_'+args.dataset_name+'.txt', sep='\t', index=False) 142 | return 143 | 144 | def draw_confidence(args): 145 | confidence_all_df = pd.read_csv('figures/velocity_consistency_'+args.dataset_name+'.txt', sep='\t') 146 | draw_plot(confidence_all_df, save_name='velocity_consistency', 147 | fig_type='boxplot') 148 | draw_plot(confidence_all_df, save_name='velocity_consistency', 149 | fig_type='barplot') 150 | return 151 | 152 | 153 | if __name__ == '__main__': 154 | import argparse 155 | parser = argparse.ArgumentParser() 156 | parser.add_argument( '--dataset_name', type=str, default="pancreas", help='pancreas, gastrulation_erythroid, pons') 157 | parser.add_argument( '--save_name', type=str, default='_demo', help='save_name') 158 | args = parser.parse_args() 159 | 160 | args.result_path = 'TFvelo_'+ args.dataset_name + args.save_name+ '/' 161 | 
print('------------------------------------------------------------') 162 | print(args) 163 | methods = ['TFvelo', 'scvelo', 'unitvelo', 'dynamo', 'cellDancer'] 164 | 165 | if not os.path.exists('figures'): 166 | os.makedirs('figures') 167 | 168 | run_utv_metric(args, methods) 169 | draw_utv_metric(args) 170 | 171 | get_velocity_consistency(args, methods) 172 | draw_confidence(args) 173 | 174 | 175 | -------------------------------------------------------------------------------- /data/10x_mouse_brain/adata_rna.h5ad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoyeye/TFvelo/ec6cdb940af94f02fe32c8fcdb98494cdb4beb96/data/10x_mouse_brain/adata_rna.h5ad -------------------------------------------------------------------------------- /figures/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoyeye/TFvelo/ec6cdb940af94f02fe32c8fcdb98494cdb4beb96/figures/demo.png -------------------------------------------------------------------------------- /simulation/simulate_phase_delay.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | start = -0.2 5 | end = 1.2 6 | x = np.linspace(start, end, 500) 7 | t = (x-start)/(end-start) 8 | y1 = np.sin(x*np.pi) 9 | y1_n = y1 + np.random.normal(0, 0.2, len(x)) 10 | y2 = np.sin(x*np.pi-1) 11 | y2_n = y2 + np.random.normal(0, 0.2, len(x)) 12 | 13 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6)) 14 | # First subplot 15 | ax1.plot(t, y1, label='TF', color='blue', linewidth=5) 16 | ax1.plot(t, y2, label='Target', color='red', linewidth=5) 17 | ax1.scatter(t, y1_n, c=t, alpha=0.7) 18 | ax1.scatter(t, y2_n, c=t, alpha=0.7) 19 | ax1.set_xlabel('t', fontsize=25) 20 | ax1.set_ylabel('Abundance', fontsize=25) 21 | ax1.legend(fontsize=25) 22 | ax1.set_xticks([]) 23 | ax1.set_yticks([]) 24 | #ax1.tick_params(axis='both', labelsize=20) 25 | # Second subplot 26 | ax2.plot(y2, y1, color='purple', linewidth=5) 27 | scatter = ax2.scatter(y2_n, y1_n, c=t, alpha=0.7) 28 | ax2.set_xlabel('Target', fontsize=25) 29 | ax2.set_ylabel('TF', fontsize=25) 30 | ax2.set_xticks([]) 31 | ax2.set_yticks([]) 32 | #ax2.tick_params(axis='both', labelsize=20) 33 | cbar = plt.colorbar(scatter, ax=ax2) 34 | cbar.set_label('t', fontsize=25) 35 | cbar.ax.tick_params(labelsize=20) 36 | # Save the figure 37 | plt.savefig('figures/TF_target.png', dpi=300, bbox_inches='tight', pad_inches=0.1) 38 | plt.close() 39 | 40 | 41 | 42 | y2 = np.sin(x*np.pi-0.1) 43 | y2_n = y2 + np.random.normal(0, 0.2, len(x)) 44 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6)) 45 | # First subplot 46 | ax1.plot(t, y1, label='Unspliced', color='blue', linewidth=5) 47 | ax1.plot(t, y2, label='Spliced', color='red', linewidth=5) 48 | ax1.scatter(t, y1_n, c=t, alpha=0.7) 49 | ax1.scatter(t, y2_n, c=t, alpha=0.7) 50 | ax1.set_xlabel('t', fontsize=25) 51 | ax1.set_ylabel('Abundance', fontsize=25) 52 | ax1.legend(fontsize=25) 53 | ax1.set_xticks([]) 54 | ax1.set_yticks([]) 55 | #ax1.tick_params(axis='both', labelsize=20) 56 | # Second subplot 57 | ax2.plot(y2, y1, color='purple', linewidth=5) 58 | scatter = ax2.scatter(y2_n, y1_n, c=t, alpha=0.7) 59 | ax2.set_xlabel('Spliced', fontsize=25) 60 | ax2.set_ylabel('Unspliced', fontsize=25) 61 | ax2.set_xticks([]) 62 | ax2.set_yticks([]) 63 | #ax2.tick_params(axis='both', labelsize=20) 64 | cbar = plt.colorbar(scatter, ax=ax2) 65 | 
cbar.set_label('t', fontsize=25) 66 | cbar.ax.tick_params(labelsize=20) 67 | plt.tight_layout() 68 | # Save the figure 69 | plt.savefig('figures/U_S.png', dpi=300, bbox_inches='tight', pad_inches=0.1) 70 | plt.close() 71 | --------------------------------------------------------------------------------
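# Endnote (an illustrative sketch continuing simulate_phase_delay.py, not a new
# script): the phase delay simulated above can be recovered from the noisy
# samples by scanning candidate shifts for the best correlation, e.g.
#
#   shifts = np.linspace(0, 2, 201)
#   corrs = [np.corrcoef(np.sin(x * np.pi - s), y2_n)[0, 1] for s in shifts]
#   print(shifts[int(np.argmax(corrs))])
#
# Run right after either block, this returns ~1 for the TF/target pair and
# ~0.1 for the unspliced/spliced pair, matching how widely the corresponding
# phase portrait opens.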