├── .gitignore ├── .readthedocs.yaml ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs ├── Makefile ├── make.bat ├── requirements.txt └── source │ ├── _static │ └── img │ │ └── scTour_head_image.png │ ├── _templates │ └── autosummary │ │ └── class.rst │ ├── api_predict.rst │ ├── api_reverse_time.rst │ ├── api_train.rst │ ├── api_vf.rst │ ├── conf.py │ ├── index.rst │ └── notebook │ ├── scTour_inference_PostInference_adjustment.ipynb │ ├── scTour_inference_basic.ipynb │ └── scTour_prediction.ipynb ├── sctour ├── __init__.py ├── _utils.py ├── data.py ├── logger.py ├── model.py ├── module.py ├── predict.py ├── train.py └── vector_field.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | /**/__pycache__/ 2 | sctour.egg-info 3 | build_doc.sh 4 | build/ 5 | build_wheel.sh 6 | /dist/ 7 | docs/source/generated/ 8 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: "3.11" 7 | 8 | sphinx: 9 | configuration: docs/source/conf.py 10 | 11 | python: 12 | install: 13 | - requirements: docs/requirements.txt 14 | - method: pip 15 | path: . 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Qian Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | recursive-exclude * __pycache__ 4 | recursive-exclude * *.pyc 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # scTour 3 | 4 | 5 | 6 | scTour is an innovative and comprehensive method for dissecting cellular dynamics by analysing datasets derived from single-cell genomics. 7 | 8 | It provides a unifying framework to depict the full picture of developmental processes from multiple angles including the developmental pseudotime, vector field and latent space. 9 | 10 | It further generalises these functionalities to a multi-task architecture for within-dataset inference and cross-dataset prediction of cellular dynamics in a batch-insensitive manner. 11 | 12 | ## Key features 13 | 14 | - cell pseudotime estimation with no need for specifying starting cells. 15 | - transcriptomic vector field inference with no discrimination between spliced and unspliced mRNAs. 16 | - latent space mapping by combining intrinsic transcriptomic structure with extrinsic pseudotime ordering. 17 | - model-based prediction of pseudotime, vector field, and latent space for query cells/datasets/time intervals. 18 | - insensitive to batch effects; robust to cell subsampling; scalable to large datasets. 19 | 20 | ## Installation 21 | 22 | [![PyPI](https://img.shields.io/pypi/v/sctour.svg?color=brightgreen&style=flat)](https://pypi.org/project/sctour) 23 | 24 | ```console 25 | pip install sctour 26 | ``` 27 | 28 | [![Conda](https://img.shields.io/conda/vn/conda-forge/sctour.svg?color=brightgreen&style=flat)](https://anaconda.org/conda-forge/sctour) 29 | 30 | ```console 31 | conda install -c conda-forge sctour 32 | ``` 33 | 34 | ## Documentation 35 | 36 | [![Documentation Status](https://readthedocs.org/projects/sctour/badge/?version=latest)](https://sctour.readthedocs.io/en/latest/?badge=latest) 37 | 38 | Full documentation can be found [here](https://sctour.readthedocs.io/en/latest/). 39 | 40 | ## Reference 41 | 42 | [Qian Li, scTour: a deep learning architecture for robust inference and accurate prediction of cellular dynamics. Genome Biology, 2023](https://genomebiology.biomedcentral.com/articles/10.1186/s13059-023-02988-9) 43 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.1 2 | torchdiffeq==0.2.2 3 | numpy>=1.19.2 4 | scanpy>=1.7.1 5 | anndata>=0.7.5 6 | scipy>=1.5.2 7 | tqdm>=4.32.2 8 | scikit-learn>=0.24.1 9 | myst_parser==0.15.2 10 | sphinx>=3.5.2 11 | nbsphinx==0.8.8 12 | ipython_genutils==0.2.0 13 | jinja2<3.1 14 | -------------------------------------------------------------------------------- /docs/source/_static/img/scTour_head_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiQian-XC/sctour/86e21fe356012776d55a07f7c2fa7013dc19def8/docs/source/_static/img/scTour_head_image.png -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. add toctree option to make autodoc generate the pages 6 | 7 | .. autoclass:: {{ objname }} 8 | 9 | {% block attributes %} 10 | {% if attributes %} 11 | .. rubric:: Attributes 12 | 13 | .. autosummary:: 14 | :toctree: . 15 | {% for item in attributes %} 16 | {% if has_attr(fullname, item) %} 17 | ~{{ fullname }}.{{ item }} 18 | {% endif %} 19 | {%- endfor %} 20 | {% endif %} 21 | {% endblock %} 22 | 23 | {% block methods %} 24 | {% if methods %} 25 | .. rubric:: Methods 26 | 27 | .. autosummary:: 28 | :toctree: . 29 | {% for item in methods %} 30 | {%- if item != '__init__' %} 31 | ~{{ fullname }}.{{ item }} 32 | {%- endif -%} 33 | {%- endfor %} 34 | {% endif %} 35 | {% endblock %} 36 | -------------------------------------------------------------------------------- /docs/source/api_predict.rst: -------------------------------------------------------------------------------- 1 | Prediction 2 | ================================== 3 | 4 | .. autosummary:: 5 | :toctree: generated/ 6 | :nosignatures: 7 | 8 | sctour.predict.load_model 9 | sctour.predict.predict_time 10 | sctour.predict.predict_latentsp 11 | sctour.predict.predict_vector_field 12 | sctour.predict.predict_ltsp_from_time 13 | -------------------------------------------------------------------------------- /docs/source/api_reverse_time.rst: -------------------------------------------------------------------------------- 1 | Post-inference adjustment 2 | ================================== 3 | 4 | .. autofunction:: sctour.train.reverse_time 5 | -------------------------------------------------------------------------------- /docs/source/api_train.rst: -------------------------------------------------------------------------------- 1 | Model training 2 | ================================== 3 | 4 | .. autosummary:: 5 | :toctree: generated/ 6 | :nosignatures: 7 | 8 | sctour.train.Trainer 9 | -------------------------------------------------------------------------------- /docs/source/api_vf.rst: -------------------------------------------------------------------------------- 1 | Vector field visualization 2 | ================================== 3 | 4 | .. autofunction:: sctour.vf.plot_vector_field 5 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, os.path.abspath('../../')) 4 | 5 | 6 | # -- Project information ----------------------------------------------------- 7 | from datetime import datetime 8 | import sctour 9 | 10 | project = 'sctour' 11 | author = 'Qian Li' 12 | copyright = f'{datetime.now():%Y}, {author}' 13 | release = sctour.__version__ 14 | 15 | 16 | # -- General configuration --------------------------------------------------- 17 | extensions = [ 18 | 'sphinx.ext.autodoc', 19 | 'sphinx.ext.mathjax', 20 | 'sphinx.ext.napoleon', 21 | 'sphinx.ext.intersphinx', 22 | 'sphinx.ext.viewcode', 23 | 'myst_parser', 24 | 'nbsphinx', 25 | 'sphinx.ext.autosummary', 26 | ] 27 | autosummary_generate = True 28 | autodoc_member_order = 'bysource' 29 | napoleon_include_init_with_doc = False 30 | napoleon_numpy_docstring = True 31 | napoleon_use_rtype = True 32 | napoleon_use_param = True 33 | 34 | # settings 35 | master_doc = 'index' 36 | templates_path = ["_templates"] 37 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 38 | pygments_style = 'sphinx' 39 | 40 | intersphinx_mapping = dict( 41 | python=('https://docs.python.org/3/', None), 42 | numpy=('https://numpy.org/doc/stable/', None), 43 | pandas=('https://pandas.pydata.org/pandas-docs/stable/', None), 44 | anndata=('https://anndata.readthedocs.io/en/stable/', None), 45 | scanpy=('https://scanpy.readthedocs.io/en/stable/', None), 46 | scipy=('https://docs.scipy.org/doc/scipy/reference/', None), 47 | torch=('https://pytorch.org/docs/master/', None), 48 | matplotlib=('https://matplotlib.org/stable/', None), 49 | ) 50 | 51 | 52 | # -- Options for HTML output ------------------------------------------------- 53 | html_theme = 'sphinx_rtd_theme' 54 | html_static_path = ['_static'] 55 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | |PyPI| |Docs| 2 | 3 | scTour: a deep learning architecture for robust inference and accurate prediction of cellular dynamics 4 | ====================================================================================================== 5 | 6 | .. image:: https://raw.githubusercontent.com/LiQian-XC/sctour/main/docs/source/_static/img/scTour_head_image.png 7 | :width: 400px 8 | :align: left 9 | 10 | scTour is an innovative and comprehensive method for dissecting cellular dynamics by analysing datasets derived from single-cell genomics. 11 | 12 | It provides a unifying framework to depict the full picture of developmental processes from multiple angles including the developmental pseudotime, vector field and latent space. 13 | 14 | It further generalises these functionalities to a multi-task architecture for within-dataset inference and cross-dataset prediction of cellular dynamics in a batch-insensitive manner. 15 | 16 | Key features 17 | ------------ 18 | 19 | - cell pseudotime estimation with no need for specifying starting cells. 20 | - transcriptomic vector field inference with no discrimination between spliced and unspliced mRNAs. 21 | - latent space mapping by combining intrinsic transcriptomic structure with extrinsic pseudotime ordering. 22 | - model-based prediction of pseudotime, vector field, and latent space for query cells/datasets/time intervals. 23 | - insensitive to batch effects; robust to cell subsampling; scalable to large datasets. 24 | 25 | Installation 26 | ------------ 27 | 28 | scTour requires Python ≥ 3.7:: 29 | 30 | pip install sctour 31 | 32 | Reference 33 | --------- 34 | 35 | Qian Li, scTour: a deep learning architecture for robust inference and accurate prediction of cellular dynamics, 2023, `Genome Biology `_. 36 | 37 | .. toctree:: 38 | :maxdepth: 2 39 | :caption: Tutorials 40 | :hidden: 41 | 42 | notebook/scTour_inference_basic 43 | notebook/scTour_inference_PostInference_adjustment 44 | notebook/scTour_prediction 45 | 46 | .. toctree:: 47 | :maxdepth: 2 48 | :caption: API 49 | :hidden: 50 | 51 | api_train 52 | api_predict 53 | api_vf 54 | api_reverse_time 55 | 56 | .. |PyPI| image:: https://img.shields.io/pypi/v/sctour.svg 57 | :target: https://pypi.org/project/sctour 58 | 59 | .. |Docs| image:: https://readthedocs.org/projects/sctour/badge/?version=latest 60 | :target: https://sctour.readthedocs.io 61 | -------------------------------------------------------------------------------- /sctour/__init__.py: -------------------------------------------------------------------------------- 1 | from . import train 2 | from . import predict 3 | from . import vector_field as vf 4 | __version__ = '1.0.0' 5 | -------------------------------------------------------------------------------- /sctour/_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy.sparse import issparse 5 | 6 | 7 | ##calculate KL divergence 8 | def normal_kl(mu1, lv1, mu2, lv2): 9 | """ 10 | Calculate KL divergence 11 | This function is from torchdiffeq: https://github.com/rtqichen/torchdiffeq/blob/master/examples/latent_ode.py 12 | """ 13 | v1 = torch.exp(lv1) 14 | v2 = torch.exp(lv2) 15 | lstd1 = lv1/2. 16 | lstd2 = lv2/2. 17 | 18 | kl = lstd2 - lstd1 + (v1 + (mu1-mu2)**2.)/(2.*v2) - 0.5 19 | return kl 20 | 21 | 22 | ## get step size 23 | def get_step_size(step_size, t1, t2, t_size): 24 | if step_size is None: 25 | options = {} 26 | else: 27 | step_size = (t2 - t1)/t_size/step_size 28 | options = dict(step_size = step_size) 29 | return options 30 | 31 | 32 | ##calculate log zinb probability 33 | def log_zinb(x, mu, theta, pi, eps=1e-8): 34 | """ 35 | Calculate log probability under zero-inflated negative binomial distribution 36 | This function is from scvi-tools: https://github.com/YosefLab/scvi-tools/blob/6dae6482efa2d235182bf4ad10dbd9483b7d57cd/scvi/distributions/_negative_binomial.py 37 | """ 38 | softplus_pi = F.softplus(-pi) 39 | log_theta_eps = torch.log(theta + eps) 40 | log_theta_mu_eps = torch.log(theta + mu + eps) 41 | pi_theta_log = -pi + theta * (log_theta_eps - log_theta_mu_eps) 42 | 43 | case_zero = F.softplus(pi_theta_log) - softplus_pi 44 | mul_case_zero = torch.mul((x < eps).type(torch.float32), case_zero) 45 | 46 | case_non_zero = ( 47 | -softplus_pi 48 | + pi_theta_log 49 | + x * (torch.log(mu + eps) - log_theta_mu_eps) 50 | + torch.lgamma(x + theta) 51 | - torch.lgamma(theta) 52 | - torch.lgamma(x + 1) 53 | ) 54 | mul_case_non_zero = torch.mul((x > eps).type(torch.float32), case_non_zero) 55 | 56 | res = mul_case_zero + mul_case_non_zero 57 | return res 58 | 59 | 60 | ##calculate log nb probability 61 | def log_nb(x, mu, theta, eps=1e-8): 62 | """ 63 | Calculate log probability under negative binomial distribution 64 | This function is from scvi-tools: https://github.com/YosefLab/scvi-tools/blob/6dae6482efa2d235182bf4ad10dbd9483b7d57cd/scvi/distributions/_negative_binomial.py 65 | """ 66 | log_theta_mu_eps = torch.log(theta + mu + eps) 67 | 68 | res = ( 69 | theta * (torch.log(theta + eps) - log_theta_mu_eps) 70 | + x * (torch.log(mu + eps) - log_theta_mu_eps) 71 | + torch.lgamma(x + theta) 72 | - torch.lgamma(theta) 73 | - torch.lgamma(x + 1) 74 | ) 75 | 76 | return res 77 | 78 | 79 | ##L2 norm 80 | def l2_norm(x, axis=-1): 81 | if issparse(x): 82 | return np.sqrt(x.multiply(x).sum(axis=axis).A1) 83 | else: 84 | return np.sqrt(np.sum(x * x, axis = axis)) 85 | -------------------------------------------------------------------------------- /sctour/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from scipy import sparse 4 | import numpy as np 5 | from anndata import AnnData 6 | 7 | 8 | def split_data( 9 | adata: AnnData, 10 | percent: float, 11 | val_frac: float = 0.1, 12 | ): 13 | """ 14 | Split the dataset for training and validation 15 | 16 | Parameters 17 | ---------- 18 | adata 19 | The `AnnData` object for the whole dataset 20 | percent 21 | The percentage to be used for training the model 22 | val_frac 23 | The percentage to be used for validation 24 | 25 | Returns 26 | ---------- 27 | `AnnData` object for training and validation 28 | """ 29 | 30 | n_cells = adata.n_obs 31 | n_train = int(np.ceil(n_cells * percent)) 32 | n_val = min(int(np.floor(n_train * val_frac)), n_cells - n_train) 33 | 34 | indices = np.random.permutation(n_cells) 35 | train_idx = np.random.choice(indices, n_train, replace = False) 36 | indices2 = np.setdiff1d(indices, train_idx) 37 | val_idx = np.random.choice(indices2, n_val, replace = False) 38 | # train_idx = indices[:n_train] 39 | # val_idx = indices[n_train:(n_train + n_val)] 40 | 41 | train_data = adata[train_idx, :] 42 | val_data = adata[val_idx, :] 43 | 44 | return train_data, val_data 45 | 46 | 47 | def split_index( 48 | n_cells: int, 49 | percent: float, 50 | val_frac: float = 0.1, 51 | ): 52 | """ 53 | Split the indices for training and validation 54 | 55 | Parameters 56 | ---------- 57 | n_cells 58 | The total number of cells 59 | percent 60 | The percentage to be used for training the model 61 | val_frac 62 | The percentage to be used for validation 63 | 64 | Returns 65 | ---------- 66 | 2-tuple of indices 67 | """ 68 | 69 | n_train = int(np.ceil(n_cells * percent)) 70 | n_val = min(int(np.floor(n_train * val_frac)), n_cells - n_train) 71 | indices = np.random.permutation(n_cells) 72 | train_idx = np.random.choice(indices, n_train, replace = False) 73 | indices2 = np.setdiff1d(indices, train_idx) 74 | val_idx = np.random.choice(indices2, n_val, replace = False) 75 | # train_idx = indices[:n_train] 76 | # val_idx = indices[n_train:(n_train + n_val)] 77 | return train_idx, val_idx 78 | 79 | 80 | class MakeDataset(Dataset): 81 | """ 82 | A class to generate Dataset 83 | 84 | Parameters 85 | ---------- 86 | adata 87 | An `AnnData` object 88 | """ 89 | 90 | def __init__( 91 | self, 92 | adata: AnnData, 93 | loss_mode: str, 94 | ): 95 | X = adata.X 96 | if loss_mode in ['nb', 'zinb']: 97 | X = np.log1p(X) 98 | if sparse.issparse(X): 99 | X = X.A 100 | self.data = torch.tensor(X) 101 | self.library_size = self.data.sum(-1) 102 | 103 | def __len__(self): 104 | return self.data.size(0) 105 | 106 | def __getitem__(self, idx): 107 | return self.data[idx, :], self.library_size[idx] 108 | 109 | 110 | class BatchSampler(): 111 | """ 112 | A class to generate mini-batches through two rounds of randomization 113 | 114 | Parameters 115 | ---------- 116 | n 117 | Total number of cells 118 | batch_size 119 | Size of mini-batches 120 | drop_last 121 | Whether or not drop the last batch when its size is smaller than the batch_size 122 | """ 123 | def __init__( 124 | self, 125 | n: int, 126 | batch_size: int, 127 | drop_last: bool = False, 128 | ): 129 | self.batch_size = batch_size 130 | self.n = n 131 | self.drop_last = drop_last 132 | 133 | def __iter__(self): 134 | seq_n = torch.randperm(self.n) 135 | lb = self.n // self.batch_size 136 | idxs = np.arange(self.n) 137 | for i in range(lb): 138 | idx = np.random.choice(idxs, self.batch_size, replace=False) 139 | yield seq_n[idx].tolist() 140 | idxs = np.setdiff1d(idxs, idx) 141 | if (not self.drop_last) and (len(idxs) > 0): 142 | yield seq_n[idxs].tolist() 143 | 144 | def __len__(self): 145 | if self.drop_last: 146 | return self.n // self.batch_size 147 | else: 148 | return int(np.ceil(self.n / self.batch_size)) 149 | -------------------------------------------------------------------------------- /sctour/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logging.basicConfig(level=logging.INFO, format="%(message)s") 4 | logger = logging.getLogger(__name__) 5 | 6 | info = logger.info 7 | warn = logger.warning 8 | error = logger.error 9 | debug = logger.debug 10 | -------------------------------------------------------------------------------- /sctour/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.distributions import Normal, kl_divergence 5 | from torchdiffeq import odeint 6 | from typing import Optional 7 | from typing_extensions import Literal 8 | 9 | from .module import LatentODEfunc, Encoder, Decoder 10 | from ._utils import get_step_size, normal_kl, log_zinb, log_nb 11 | 12 | 13 | class TNODE(nn.Module): 14 | """ 15 | Class to automatically infer cellular dynamics using VAE and neural ODE. 16 | 17 | Parameters 18 | ---------- 19 | device 20 | The torch device. 21 | n_int 22 | The dimensionality of the input. 23 | n_latent 24 | The dimensionality of the latent space. 25 | (Default: 5) 26 | n_ode_hidden 27 | The dimensionality of the hidden layer for the latent ODE function. 28 | (Default: 25) 29 | n_vae_hidden 30 | The dimensionality of the hidden layer for the VAE. 31 | (Default: 128) 32 | batch_norm 33 | Whether to include `BatchNorm` layer. 34 | (Default: `False`) 35 | ode_method 36 | Solver for integration. 37 | (Default: `'euler'`) 38 | step_size 39 | The step size during integration. 40 | alpha_recon_lec 41 | Scaling factor for reconstruction loss from encoder-derived latent space. 42 | (Default: 0.5) 43 | alpha_recon_lode 44 | Scaling factor for reconstruction loss from ODE-solver latent space. 45 | (Default: 0.5) 46 | alpha_kl 47 | Scaling factor for KL divergence. 48 | (Default: 1.0) 49 | loss_mode 50 | The mode for calculating the reconstruction error. 51 | (Default: `'nb'`) 52 | """ 53 | 54 | def __init__( 55 | self, 56 | device, 57 | n_int: int, 58 | n_latent: int = 5, 59 | n_ode_hidden: int = 25, 60 | n_vae_hidden: int = 128, 61 | batch_norm: bool = False, 62 | ode_method: str = 'euler', 63 | step_size: Optional[int] = None, 64 | alpha_recon_lec: float = 0.5, 65 | alpha_recon_lode: float = 0.5, 66 | alpha_kl: float = 1., 67 | loss_mode: Literal['mse', 'nb', 'zinb'] = 'nb', 68 | ): 69 | super().__init__() 70 | self.n_int = n_int 71 | self.n_latent = n_latent 72 | self.n_ode_hidden = n_ode_hidden 73 | self.n_vae_hidden = n_vae_hidden 74 | self.batch_norm = batch_norm 75 | self.ode_method = ode_method 76 | self.step_size = step_size 77 | self.alpha_recon_lec = alpha_recon_lec 78 | self.alpha_recon_lode = alpha_recon_lode 79 | self.alpha_kl = alpha_kl 80 | self.loss_mode = loss_mode 81 | self.device = device 82 | 83 | self.lode_func = LatentODEfunc(n_latent, n_ode_hidden) 84 | self.encoder = Encoder(n_int, n_latent, n_vae_hidden, batch_norm).to(self.device) 85 | self.decoder = Decoder(n_int, n_latent, n_vae_hidden, batch_norm, loss_mode).to(self.device) 86 | 87 | def forward(self, x: torch.Tensor, y: torch.Tensor) -> tuple: 88 | """ 89 | Given the transcriptomes of cells, this function derives the time and latent space of the cells, as well as reconstructs the transcriptomes. 90 | 91 | Parameters 92 | ---------- 93 | x 94 | The input data. 95 | y 96 | The library size. 97 | 98 | Returns 99 | ---------- 100 | 5-tuple of :class:`torch.Tensor` 101 | Tensors for loss, including: 102 | 1) total loss, 103 | 2) reconstruction loss from encoder-derived latent space, 104 | 3) reconstruction loss from ODE-solver latent space, 105 | 4) KL divergence, 106 | 5) divergence between encoder-derived latent space and ODE-solver latent space 107 | """ 108 | 109 | ## get the time and latent space through Encoder 110 | T, qz_mean, qz_logvar = self.encoder(x) 111 | T = T.ravel() ## odeint requires 1-D Tensor for time 112 | epsilon = torch.randn(qz_mean.size()).to(T.device) 113 | z = epsilon * torch.exp(.5 * qz_logvar) + qz_mean 114 | 115 | index = torch.argsort(T) 116 | T = T[index] 117 | x = x[index] 118 | z = z[index] 119 | y = y[index] 120 | # qz_mean = qz_mean[index] 121 | # qz_logvar = qz_logvar[index] 122 | index2 = (T[:-1] != T[1:]) 123 | index2 = torch.cat((index2, torch.tensor([True]).to(index2.device))) ## index2 is used to get unique time points as odeint requires strictly increasing/decreasing time points 124 | T = T[index2] 125 | x = x[index2] 126 | z = z[index2] 127 | y = y[index2] 128 | # qz_mean = qz_mean[index2] 129 | # qz_logvar = qz_logvar[index2] 130 | 131 | ## infer the latent space through ODE solver based on z0, t, and LatentODEfunc 132 | z0 = z[0] 133 | options = get_step_size(self.step_size, T[0], T[-1], len(T)) 134 | pred_z = odeint(self.lode_func, z0.to('cpu'), T.to('cpu'), method = self.ode_method, options = options).view(-1, self.n_latent) 135 | pred_z = pred_z.to(z.device) 136 | 137 | ## reconstruct the input through Decoder and compute reconstruction loss 138 | if self.loss_mode == 'mse': 139 | pred_x1 = self.decoder(z) ## decode through latent space returned by Encoder 140 | pred_x2 = self.decoder(pred_z) ## decode through latent space returned by ODE solver 141 | recon_loss_ec = F.mse_loss(x, pred_x1, reduction='none').sum(-1).mean() 142 | recon_loss_ode = F.mse_loss(x, pred_x2, reduction='none').sum(-1).mean() 143 | if self.loss_mode == 'nb': 144 | pred_x1 = self.decoder(z) ## decode through latent space returned by Encoder 145 | pred_x2 = self.decoder(pred_z) ## decode through latent space returned by ODE solver 146 | y = y.unsqueeze(1).expand(pred_x1.size(0), pred_x1.size(1)) 147 | pred_x1 = pred_x1 * y 148 | pred_x2 = pred_x2 * y 149 | disp = torch.exp(self.decoder.disp) 150 | recon_loss_ec = -log_nb(x, pred_x1, disp).sum(-1).mean() 151 | recon_loss_ode = -log_nb(x, pred_x2, disp).sum(-1).mean() 152 | if self.loss_mode == 'zinb': 153 | pred_x1, dp1 = self.decoder(z) 154 | pred_x2, dp2 = self.decoder(pred_z) 155 | y = y.unsqueeze(1).expand(pred_x1.size(0), pred_x1.size(1)) 156 | pred_x1 = pred_x1 * y 157 | pred_x2 = pred_x2 * y 158 | disp = torch.exp(self.decoder.disp) 159 | recon_loss_ec = -log_zinb(x, pred_x1, disp, dp1).sum(-1).mean() 160 | recon_loss_ode = -log_zinb(x, pred_x2, disp, dp2).sum(-1).mean() 161 | 162 | ## compute KL divergence and z divergence 163 | z_div = F.mse_loss(z, pred_z, reduction='none').sum(-1).mean() 164 | pz_mean = torch.zeros_like(qz_mean) 165 | pz_logvar = torch.zeros_like(qz_mean) 166 | kl_div = normal_kl(qz_mean, qz_logvar, pz_mean, pz_logvar).sum(-1).mean() 167 | 168 | loss = self.alpha_recon_lec * recon_loss_ec + self.alpha_recon_lode * recon_loss_ode + z_div + self.alpha_kl * kl_div 169 | 170 | return loss, recon_loss_ec, recon_loss_ode, kl_div, z_div 171 | -------------------------------------------------------------------------------- /sctour/module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing_extensions import Literal 4 | 5 | 6 | class LatentODEfunc(nn.Module): 7 | """ 8 | A class modelling the latent state derivatives with respect to time. 9 | 10 | Parameters 11 | ---------- 12 | n_latent 13 | The dimensionality of the latent space. 14 | (Default: 5) 15 | n_hidden 16 | The dimensionality of the hidden layer. 17 | (Default: 25) 18 | """ 19 | 20 | def __init__( 21 | self, 22 | n_latent: int = 5, 23 | n_hidden: int = 25, 24 | ): 25 | super().__init__() 26 | self.elu = nn.ELU() 27 | self.fc1 = nn.Linear(n_latent, n_hidden) 28 | self.fc2 = nn.Linear(n_hidden, n_latent) 29 | 30 | def forward(self, t: torch.Tensor, x: torch.Tensor): 31 | """ 32 | Compute the gradient at a given time t and a given state x. 33 | 34 | Parameters 35 | ---------- 36 | t 37 | A given time point. 38 | x 39 | A given latent state. 40 | 41 | Returns 42 | ---------- 43 | :class:`torch.Tensor` 44 | A tensor 45 | """ 46 | out = self.fc1(x) 47 | out = self.elu(out) 48 | out = self.fc2(out) 49 | return out 50 | 51 | 52 | class Encoder(nn.Module): 53 | """ 54 | Encoder class generating the time and latent space. 55 | 56 | Parameters 57 | ---------- 58 | n_int 59 | The dimensionality of the input. 60 | n_latent 61 | The dimensionality of the latent space. 62 | (Default: 5) 63 | n_hidden 64 | The dimensionality of the hidden layer. 65 | (Default: 128) 66 | batch_norm 67 | Whether to include `BatchNorm` layer or not. 68 | (Default: `False`) 69 | """ 70 | 71 | def __init__( 72 | self, 73 | n_int: int, 74 | n_latent: int = 5, 75 | n_hidden: int = 128, 76 | batch_norm: bool = False, 77 | ): 78 | super().__init__() 79 | self.n_latent = n_latent 80 | self.fc = nn.Sequential() 81 | self.fc.add_module('L1', nn.Linear(n_int, n_hidden)) 82 | if batch_norm: 83 | self.fc.add_module('N1', nn.BatchNorm1d(n_hidden)) 84 | self.fc.add_module('A1', nn.ReLU()) 85 | self.fc2 = nn.Linear(n_hidden, n_latent*2) 86 | self.fc3 = nn.Linear(n_hidden, 1) 87 | 88 | def forward(self, x:torch.Tensor) -> tuple: 89 | x = self.fc(x) 90 | out = self.fc2(x) 91 | qz_mean, qz_logvar = out[:, :self.n_latent], out[:, self.n_latent:] 92 | t = self.fc3(x).sigmoid() 93 | return t, qz_mean, qz_logvar 94 | 95 | 96 | class Decoder(nn.Module): 97 | """ 98 | Decoder class to reconstruct the original input based on its latent space. 99 | 100 | Parameters 101 | ---------- 102 | n_latent 103 | The dimensionality of the latent space. 104 | (Default: 5) 105 | n_int 106 | The dimensionality of the original input. 107 | n_hidden 108 | The dimensionality of the hidden layer. 109 | (Default: 128) 110 | batch_norm 111 | Whether to include `BatchNorm` layer or not. 112 | (Default: `False`) 113 | loss_mode 114 | The mode for reconstructing the original data. 115 | (Default: `'nb'`) 116 | """ 117 | 118 | def __init__( 119 | self, 120 | n_int: int, 121 | n_latent: int = 5, 122 | n_hidden: int = 128, 123 | batch_norm: bool = False, 124 | loss_mode: Literal['mse', 'nb', 'zinb'] = 'nb', 125 | ): 126 | super().__init__() 127 | self.loss_mode = loss_mode 128 | if loss_mode in ['nb', 'zinb']: 129 | self.disp = nn.Parameter(torch.randn(n_int)) 130 | 131 | self.fc = nn.Sequential() 132 | self.fc.add_module('L1', nn.Linear(n_latent, n_hidden)) 133 | if batch_norm: 134 | self.fc.add_module('N1', nn.BatchNorm1d(n_hidden)) 135 | self.fc.add_module('A1', nn.ReLU()) 136 | 137 | if loss_mode == 'mse': 138 | self.fc2 = nn.Linear(n_hidden, n_int) 139 | if loss_mode in ['nb', 'zinb']: 140 | self.fc2 = nn.Sequential(nn.Linear(n_hidden, n_int), nn.Softmax(dim = -1)) 141 | if loss_mode == 'zinb': 142 | self.fc3 = nn.Linear(n_hidden, n_int) 143 | 144 | def forward(self, z: torch.Tensor): 145 | out = self.fc(z) 146 | recon_x = self.fc2(out) 147 | if self.loss_mode == 'zinb': 148 | disp = self.fc3(out) 149 | return recon_x, disp 150 | else: 151 | return recon_x 152 | -------------------------------------------------------------------------------- /sctour/predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchdiffeq import odeint 3 | from typing import Optional, Union 4 | from typing_extensions import Literal 5 | import numpy as np 6 | from anndata import AnnData 7 | from scipy import sparse 8 | from scipy.sparse import spmatrix 9 | import os 10 | 11 | from ._utils import get_step_size 12 | from .train import Trainer 13 | 14 | 15 | def _check_data( 16 | adata1: AnnData, 17 | adata2: AnnData, 18 | loss_mode: str, 19 | ) -> np.ndarray: 20 | """ 21 | Check the query data. 22 | 23 | Parameters 24 | ---------- 25 | adata1 26 | An :class:`~anndata.AnnData` object for the query dataset. 27 | adata2 28 | An :class:`~anndata.AnnData` object for the training dataset. 29 | loss_mode 30 | The `loss_mode` used for model training. 31 | 32 | Returns 33 | ---------- 34 | :class:`~numpy.ndarray` 35 | The expression matrix for the query dataset. 36 | """ 37 | 38 | if len(adata1.var_names.intersection(adata2.var_names)) != adata2.n_vars: 39 | raise ValueError( 40 | "The query AnnData must contain all the genes that are used for training in the training dataset." 41 | ) 42 | 43 | X = adata1[:, adata2.var_names].X 44 | if loss_mode == 'mse': 45 | if (X.min() < 0) or (X.max() > np.log1p(1e6)): 46 | raise ValueError( 47 | "Invalid expression matrix in `.X`. Model trained from `mse` mode expects log1p(normalized expression) in `.X` of the query AnnData." 48 | ) 49 | else: 50 | data = X.data if sparse.issparse(X) else X 51 | if (data.min() < 0) or np.any(~np.equal(np.mod(data, 1), 0)): 52 | raise ValueError( 53 | f"Invalid expression matrix in `.X`. Model trained from `{loss_mode}` mode expects raw UMI counts in `.X` of the query AnnData." 54 | ) 55 | else: 56 | X = np.log1p(X) 57 | 58 | return X 59 | 60 | 61 | def load_model(model: str): 62 | """ 63 | Load the trained scTour model for prediction. 64 | 65 | Parameters 66 | ---------- 67 | model 68 | Filename for the scTour model trained and saved. 69 | 70 | Returns 71 | ---------- 72 | :class:`sctour.train.Trainer` 73 | The trained scTour model. 74 | """ 75 | 76 | if not os.path.isfile(model): 77 | raise FileNotFoundError( 78 | f'No such file: `{model}`.' 79 | ) 80 | 81 | checkpoint = torch.load(model, map_location=torch.device('cpu')) 82 | # checkpoint = torch.load(model) 83 | model_kwargs = checkpoint['model_kwargs'] 84 | del model_kwargs['device'] 85 | del model_kwargs['n_int'] 86 | tnode = Trainer( 87 | adata = checkpoint['adata'], 88 | percent = checkpoint['percent'], 89 | nepoch = checkpoint['nepoch'], 90 | batch_size = checkpoint['batch_size'], 91 | drop_last = checkpoint['drop_last'], 92 | lr = checkpoint['lr'], 93 | wt_decay = checkpoint['wt_decay'], 94 | eps = checkpoint['eps'], 95 | random_state = checkpoint['random_state'], 96 | val_frac = checkpoint['val_frac'], 97 | use_gpu = checkpoint['use_gpu'], 98 | **model_kwargs, 99 | ) 100 | tnode.model.load_state_dict(checkpoint['model_state_dict']) 101 | tnode.time_reverse = checkpoint['time_reverse'] 102 | return tnode 103 | 104 | 105 | def predict_time( 106 | model: Trainer, 107 | adata: AnnData, 108 | reverse: bool = False, 109 | ) -> np.ndarray: 110 | """ 111 | Predict the pseudotime for query cells. 112 | 113 | Parameters 114 | ---------- 115 | model 116 | A :class:`sctour.train.Trainer` for trained scTour model. 117 | adata 118 | An :class:`~anndata.AnnData` object for the query dataset. 119 | reverse 120 | Whether to reverse the predicted pseudotime. When the pseudotime returned by `get_time()` function for the training data was in reverse order and you used the post-inference adjustment (`reverse_time()` function), please set this parameter to `True`. 121 | (Default: `False`) 122 | 123 | Returns 124 | ---------- 125 | :class:`~numpy.ndarray` 126 | The pseudotime predicted for the query cells. 127 | """ 128 | 129 | if model.time_reverse is None: 130 | raise RuntimeError( 131 | 'It seems you did not run `get_time()` function after model training. Please run `get_time()` first after training for the training data before you run `predict_time()` for the query data.' 132 | ) 133 | 134 | X = _check_data(adata, model.adata, model.loss_mode) 135 | ts = model._get_time(model = model.model, X = X) 136 | if model.time_reverse: 137 | ts = 1 - ts 138 | if reverse: 139 | ts = 1 - ts 140 | return ts.cpu().numpy() 141 | 142 | 143 | def predict_vector_field( 144 | model: Trainer, 145 | T: np.ndarray, 146 | Z: np.ndarray, 147 | ) -> np.ndarray: 148 | """ 149 | Predict the vector field for query cells. 150 | 151 | Parameters 152 | ---------- 153 | model 154 | A :class:`sctour.train.Trainer` for trained scTour model. 155 | T 156 | The predicted pseudotime for query cells. 157 | Z 158 | The predicted latent representations for query cells. 159 | 160 | Returns 161 | ---------- 162 | :class:`~numpy.ndarray` 163 | The vector field predicted for query cells. 164 | """ 165 | 166 | vf = model._get_vector_field( 167 | model = model.model, 168 | T = T, 169 | Z = Z, 170 | time_reverse = model.time_reverse, 171 | ) 172 | return vf 173 | 174 | 175 | def predict_latentsp( 176 | model: Trainer, 177 | adata: AnnData, 178 | mode: Literal['coarse', 'fine'] = 'fine', 179 | alpha_z: float = .5, 180 | alpha_predz: float = .5, 181 | step_size: Optional[int] = None, 182 | step_wise: bool = False, 183 | batch_size: Optional[int] = None, 184 | ) -> tuple: 185 | """ 186 | Predict the latent representations for query cells given their transcriptomes. 187 | 188 | Parameters 189 | ---------- 190 | model 191 | A :class:`sctour.train.Trainer` for trained scTour model. 192 | adata 193 | An :class:`~anndata.AnnData` object for the query dataset. 194 | mode 195 | The mode for deriving the latent space for the query dataset. 196 | Two modes are included: 197 | ``'fine'``: derive the latent space by taking the training data into consideration; 198 | ``'coarse'``: derive the latent space directly from the query data without involving the training data. 199 | alpha_z 200 | Scaling factor for encoder-derived latent space. 201 | (Default: 0.5) 202 | alpha_predz 203 | Scaling factor for ODE-solver-derived latent space. 204 | (Default: 0.5) 205 | step_size 206 | The step size during integration. 207 | step_wise 208 | Whether to perform step-wise integration by iteratively considering only two time points each time. 209 | (Default: `False`) 210 | batch_size 211 | Batch size when deriving the latent space. The default is no mini-batching. 212 | 213 | Returns 214 | ---------- 215 | tuple 216 | 3-tuple of weighted combined latent space, encoder-derived latent space, and ODE-solver-derived latent space. 217 | """ 218 | 219 | X = _check_data(adata, model.adata, model.loss_mode) 220 | if mode == 'coarse': 221 | mix_zs, zs, pred_zs = model._get_latentsp( 222 | model = model.model, 223 | X = X, 224 | alpha_z = alpha_z, 225 | alpha_predz = alpha_predz, 226 | step_size = step_size, 227 | step_wise = step_wise, 228 | batch_size = batch_size, 229 | ) 230 | if mode == 'fine': 231 | X2 = model.adata.X 232 | if model.loss_mode in ['nb', 'zinb']: 233 | X2 = np.log1p(X2) 234 | if sparse.issparse(X2): 235 | X2 = X2.A 236 | if sparse.issparse(X): 237 | X = X.A 238 | mix_zs, zs, pred_zs = model._get_latentsp( 239 | model = model.model, 240 | X = np.vstack((X, X2)), 241 | alpha_z = alpha_z, 242 | alpha_predz = alpha_predz, 243 | step_size = step_size, 244 | step_wise = step_wise, 245 | batch_size = batch_size, 246 | ) 247 | mix_zs = mix_zs[:len(X)] 248 | zs = zs[:len(X)] 249 | pred_zs = pred_zs[:len(X)] 250 | 251 | return mix_zs, zs, pred_zs 252 | 253 | 254 | @torch.no_grad() 255 | def predict_ltsp_from_time( 256 | model: Trainer, 257 | T: np.ndarray, 258 | reverse: bool = False, 259 | step_wise: bool = True, 260 | step_size: Optional[int] = None, 261 | alpha_z: float = 0.5, 262 | alpha_predz: float = 0.5, 263 | k: int = 20, 264 | ) -> np.ndarray: 265 | """ 266 | Predict the transcriptomic latent space for query (unobserved) time intervals. 267 | 268 | Parameters 269 | ---------- 270 | model 271 | A :class:`sctour.train.Trainer` for trained scTour model. 272 | T 273 | A 1D numpy array containing the query time points (with values between 0 and 1). The latent space corresponding to these time points will be predicted. 274 | reverse 275 | When the pseudotime returned by `get_time()` function for the training data was in reverse order and you used the post-inference adjustment (`reverse_time()` function), please set this parameter to `True`. 276 | (Default: `False`) 277 | step_wise 278 | Whether to perform step-wise integration by iteratively considering only two time points when inferring the reference latent space from the training data. 279 | (Default: `True`) 280 | step_size 281 | The step size during integration. 282 | alpha_z 283 | Scaling factor for encoder-derived latent space. 284 | (Default: 0.5) 285 | alpha_predz 286 | Scaling factor for ODE-solver-derived latent space. 287 | (Default: 0.5) 288 | k 289 | The k nearest neighbors in the time space considered when predicting the latent representation for each query time point. 290 | (Default: 20) 291 | 292 | Returns 293 | ---------- 294 | :class:`~numpy.ndarray` 295 | Predicted latent space corresponding to the query time interval. 296 | """ 297 | 298 | mdl = model.model 299 | 300 | if not isinstance(T, np.ndarray): 301 | raise TypeError( 302 | "The input time interval must be a numpy array." 303 | ) 304 | if len(T.shape) > 1: 305 | raise TypeError( 306 | "The input time interval must be a 1D numpy array." 307 | ) 308 | if np.any(T < 0) or np.any(T > 1): 309 | raise ValueError( 310 | "The input time points must be in [0, 1]." 311 | ) 312 | 313 | ridx = np.random.permutation(len(T)) 314 | rT = torch.tensor(T[ridx]) 315 | ## get the reference time and latent space from the training data 316 | X = model.adata.X 317 | if model.loss_mode in ['nb', 'zinb']: 318 | X = np.log1p(X) 319 | mix_zs, zs, pred_zs = model._get_latentsp(model = mdl, 320 | X = X, 321 | alpha_z = alpha_z, 322 | alpha_predz = alpha_predz, 323 | step_wise = step_wise, 324 | step_size = step_size, 325 | ) 326 | ts = model._get_time(model = mdl, X = X) 327 | if model.time_reverse: 328 | ts = 1 - ts 329 | if reverse: 330 | ts = 1 - ts 331 | 332 | ts = ts.cpu() 333 | zs = torch.tensor(mix_zs) 334 | 335 | pred_T_zs = torch.empty((len(rT), mdl.n_latent)) 336 | for i, t in enumerate(rT): 337 | diff = torch.abs(t - ts) 338 | idxs = torch.argsort(diff) 339 | # n = (diff == 0).sum() 340 | # idxs = idxs[n:(k + n)] 341 | if (diff == 0).any(): 342 | pred_T_zs[i] = zs[idxs[0]].clone() 343 | else: 344 | idxs = idxs[:k] 345 | k_zs = torch.empty((k, mdl.n_latent)) 346 | for j, idx in enumerate(idxs): 347 | z0 = zs[idx].clone() 348 | t0 = ts[idx].clone() 349 | pred_t = torch.stack((t0, t)) 350 | if pred_t[0] < pred_t[1]: 351 | options = get_step_size(step_size, pred_t[0], pred_t[-1], len(pred_t)) 352 | else: 353 | options = get_step_size(step_size, pred_t[-1], pred_t[0], len(pred_t)) 354 | k_zs[j] = odeint( 355 | mdl.lode_func, 356 | z0, 357 | pred_t, 358 | method = mdl.ode_method, 359 | options = options 360 | )[1] 361 | k_zs = torch.mean(k_zs, dim = 0) 362 | pred_T_zs[i] = k_zs 363 | ts = torch.cat((ts, t.unsqueeze(0))) 364 | zs = torch.cat((zs, k_zs.unsqueeze(0))) 365 | 366 | pred_T_zs = pred_T_zs[np.argsort(ridx)] 367 | return pred_T_zs.numpy() 368 | -------------------------------------------------------------------------------- /sctour/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchdiffeq import odeint 4 | from typing import Optional, Union 5 | from typing_extensions import Literal 6 | import numpy as np 7 | from anndata import AnnData 8 | from scipy import sparse 9 | from scipy.sparse import spmatrix 10 | from tqdm import tqdm 11 | import os 12 | from collections import defaultdict 13 | 14 | from .model import TNODE 15 | from ._utils import get_step_size 16 | from .data import split_data, MakeDataset, BatchSampler 17 | from . import logger 18 | 19 | 20 | ##reverse time 21 | def reverse_time( 22 | T: np.ndarray, 23 | ) -> np.ndarray: 24 | """ 25 | Post-inference adjustment to reverse the pseudotime. 26 | 27 | Parameters 28 | ---------- 29 | T 30 | The pseudotime inferred for each cell. 31 | 32 | Returns 33 | ---------- 34 | :class:`~numpy.ndarray` 35 | The reversed pseudotime. 36 | """ 37 | 38 | return 1 - T 39 | 40 | 41 | class Trainer: 42 | """ 43 | Class for implementing the scTour training process. 44 | 45 | Parameters 46 | ---------- 47 | adata 48 | An :class:`~anndata.AnnData` object for the training data. 49 | percent 50 | The percentage of cells used for model training. Default to 0.2 when the cell number > 10,000 and to 0.9 otherwise. 51 | n_latent 52 | The dimensionality of the latent space. 53 | (Default: 5) 54 | n_ode_hidden 55 | The dimensionality of the hidden layer for the latent ODE function. 56 | (Default: 25) 57 | n_vae_hidden 58 | The dimensionality of the hidden layer for the VAE. 59 | (Default: 128) 60 | batch_norm 61 | Whether to include a `BatchNorm` layer. 62 | (Default: `False`) 63 | ode_method 64 | The solver for ODE. List of ODE solvers can be found in `torchdiffeq`. 65 | (Default: `'euler'`) 66 | step_size 67 | The step size during integration. 68 | alpha_recon_lec 69 | The scaling factor for the reconstruction error from encoder-derived latent space. 70 | (Default: 0.5) 71 | alpha_recon_lode 72 | The scaling factor for the reconstruction error from ODE-solver-derived latent space. 73 | (Default: 0.5) 74 | alpha_kl 75 | The scaling factor for the KL divergence in the loss function. 76 | (Default: 1.0) 77 | loss_mode 78 | The mode for calculating the reconstruction error. 79 | (Default: `'nb'`) 80 | Three modes are included: 81 | ``'mse'``: mean squared error; 82 | ``'nb'``: negative binomial conditioned likelihood; 83 | ``'zinb'``: zero-inflated negative binomial conditioned likelihood. 84 | nepoch 85 | Number of epochs. 86 | batch_size 87 | The batch size during training. 88 | (Default: 1024) 89 | drop_last 90 | Whether or not drop the last batch when its size is smaller than `batch_size`. 91 | (Default: `False`) 92 | lr 93 | The learning rate. 94 | (Default: 1e-3) 95 | wt_decay 96 | The weight decay (L2 penalty) for Adam optimizer. 97 | (Default: 1e-6) 98 | eps 99 | The `eps` parameter for Adam optimizer. 100 | (Default: 0.01) 101 | random_state 102 | The seed for generating random numbers. 103 | (Default: 0) 104 | val_frac 105 | The percentage of data used for validation. 106 | (Default: 0.1) 107 | use_gpu 108 | Whether to use GPU when available. 109 | (Default: `True`) 110 | """ 111 | 112 | def __init__( 113 | self, 114 | adata: AnnData, 115 | percent: Optional[float] = None, 116 | n_latent: int = 5, 117 | n_ode_hidden: int = 25, 118 | n_vae_hidden: int = 128, 119 | batch_norm: bool = False, 120 | ode_method: str = 'euler', 121 | step_size: Optional[int] = None, 122 | alpha_recon_lec: float = 0.5, 123 | alpha_recon_lode: float = 0.5, 124 | alpha_kl: float = 1., 125 | loss_mode: Literal['mse', 'nb', 'zinb'] = 'nb', 126 | nepoch: Optional[int] = None, 127 | batch_size: int = 1024, 128 | drop_last: bool = False, 129 | lr: float = 1e-3, 130 | wt_decay: float = 1e-6, 131 | eps: float = 0.01, 132 | random_state: int = 0, 133 | val_frac: float = 0.1, 134 | use_gpu: bool = True, 135 | ): 136 | self.loss_mode = loss_mode 137 | if self.loss_mode not in ['mse', 'nb', 'zinb']: 138 | raise ValueError( 139 | f"`loss_mode` must be one of ['mse', 'nb', 'zinb'], but input was '{self.loss_mode}'." 140 | ) 141 | 142 | if (alpha_recon_lec < 0) or (alpha_recon_lec > 1): 143 | raise ValueError( 144 | '`alpha_recon_lec` must be between 0 and 1.' 145 | ) 146 | if (alpha_recon_lode < 0) or (alpha_recon_lode > 1): 147 | raise ValueError( 148 | '`alpha_recon_lode` must be between 0 and 1.' 149 | ) 150 | if alpha_recon_lec + alpha_recon_lode != 1: 151 | raise ValueError( 152 | 'The sum of `alpha_recon_lec` and `alpha_recon_lode` must be 1.' 153 | ) 154 | 155 | self.adata = adata 156 | if 'n_genes_by_counts' not in self.adata.obs: 157 | raise KeyError( 158 | "`n_genes_by_counts` not found in `.obs` of the AnnData. Please run `scanpy.pp.calculate_qc_metrics` first to calculate the number of genes detected in each cell." 159 | ) 160 | if loss_mode == 'mse': 161 | if (self.adata.X.min() < 0) or (self.adata.X.max() > np.log1p(1e6)): 162 | raise ValueError( 163 | "Invalid expression matrix in `.X`. `mse` mode expects log1p(normalized expression) in `.X` of the AnnData." 164 | ) 165 | else: 166 | X = self.adata.X.data if sparse.issparse(self.adata.X) else self.adata.X 167 | if (X.min() < 0) or np.any(~np.equal(np.mod(X, 1), 0)): 168 | raise ValueError( 169 | f"Invalid expression matrix in `.X`. `{self.loss_mode}` mode expects raw UMI counts in `.X` of the AnnData." 170 | ) 171 | 172 | self.n_cells = adata.n_obs 173 | self.batch_size = batch_size 174 | self.drop_last = drop_last 175 | self.percent = percent 176 | if self.percent is None: 177 | if self.n_cells > 10000: 178 | self.percent = .2 179 | else: 180 | self.percent = .9 181 | else: 182 | if (self.percent < 0) or (self.percent > 1): 183 | raise ValueError( 184 | "`percent` must be between 0 and 1." 185 | ) 186 | self.val_frac = val_frac 187 | if (self.val_frac < 0) or (self.val_frac > 1): 188 | raise ValueError( 189 | '`val_frac` must be between 0 and 1.' 190 | ) 191 | 192 | if nepoch is None: 193 | ncells = round(self.n_cells * self.percent) 194 | self.nepoch = np.min([round((10000 / ncells) * 400), 400]) 195 | else: 196 | self.nepoch = nepoch 197 | 198 | self.lr = lr 199 | self.wt_decay = wt_decay 200 | self.eps = eps 201 | self.time_reverse = None 202 | 203 | self.random_state = random_state 204 | np.random.seed(random_state) 205 | # random.seed(random_state) 206 | torch.manual_seed(random_state) 207 | # torch.backends.cudnn.benchmark = False 208 | # torch.use_deterministic_algorithms(True) 209 | 210 | self.use_gpu = use_gpu 211 | gpu = torch.cuda.is_available() and use_gpu 212 | if gpu: 213 | torch.cuda.manual_seed(random_state) 214 | self.device = torch.device('cuda') 215 | logger.info('Running using GPU.') 216 | else: 217 | self.device = torch.device('cpu') 218 | logger.info('Running using CPU.') 219 | 220 | self.n_int = adata.n_vars 221 | self.model_kwargs = dict( 222 | device = self.device, 223 | n_int = self.n_int, 224 | n_latent = n_latent, 225 | n_ode_hidden = n_ode_hidden, 226 | n_vae_hidden = n_vae_hidden, 227 | batch_norm = batch_norm, 228 | ode_method = ode_method, 229 | step_size = step_size, 230 | alpha_recon_lec = alpha_recon_lec, 231 | alpha_recon_lode = alpha_recon_lode, 232 | alpha_kl = alpha_kl, 233 | loss_mode = loss_mode, 234 | ) 235 | self.model = TNODE(**self.model_kwargs) 236 | self.log = defaultdict(list) 237 | 238 | 239 | def _get_data_loaders(self) -> None: 240 | """ 241 | Generate Data Loaders for training and validation datasets. 242 | """ 243 | 244 | train_data, val_data = split_data(self.adata, self.percent, self.val_frac) 245 | self.train_dataset = MakeDataset(train_data, self.loss_mode) 246 | self.val_dataset = MakeDataset(val_data, self.loss_mode) 247 | 248 | # sampler = BatchSampler(train_data.n_obs, self.batch_size, self.drop_last) 249 | # self.train_dl = DataLoader(self.train_dataset, batch_sampler = sampler) 250 | self.train_dl = DataLoader(self.train_dataset, batch_size = self.batch_size, shuffle = True) 251 | self.val_dl = DataLoader(self.val_dataset, batch_size = self.batch_size) 252 | 253 | 254 | def train(self): 255 | """ 256 | Model training. 257 | """ 258 | self._get_data_loaders() 259 | 260 | params = filter(lambda p: p.requires_grad, self.model.parameters()) 261 | self.optimizer = torch.optim.Adam(params, lr = self.lr, weight_decay = self.wt_decay, eps = self.eps) 262 | 263 | with tqdm(total=self.nepoch, unit='epoch') as t: 264 | for tepoch in range(t.total): 265 | train_loss = self._on_epoch_train(self.train_dl) 266 | val_loss = self._on_epoch_val(self.val_dl) 267 | self.log['train_loss'].append(train_loss) 268 | self.log['validation_loss'].append(val_loss) 269 | t.set_description(f"Epoch {tepoch + 1}") 270 | t.set_postfix({'train_loss': train_loss, 'val_loss': val_loss}, refresh=False) 271 | t.update() 272 | 273 | 274 | def _on_epoch_train(self, DL) -> float: 275 | """ 276 | Go through the model and update the model parameters. 277 | 278 | Parameters 279 | ---------- 280 | DL 281 | DataLoader for training dataset. 282 | 283 | Returns 284 | ---------- 285 | float 286 | Training loss for the current epoch. 287 | """ 288 | 289 | self.model.train() 290 | total_loss = .0 291 | ss = 0 292 | for X, Y in DL: 293 | self.optimizer.zero_grad() 294 | X = X.to(self.device) 295 | Y = Y.to(self.device) 296 | loss, recon_loss_ec, recon_loss_ode, kl_div, z_div = self.model(X, Y) 297 | loss.backward() 298 | self.optimizer.step() 299 | 300 | total_loss += loss.item() * X.size(0) 301 | ss += X.size(0) 302 | 303 | train_loss = total_loss/ss 304 | return train_loss 305 | 306 | 307 | @torch.no_grad() 308 | def _on_epoch_val(self, DL) -> float: 309 | """ 310 | Validate using validation dataset. 311 | 312 | Parameters 313 | ---------- 314 | DL 315 | DataLoader for validation dataset. 316 | 317 | Returns 318 | ---------- 319 | float 320 | Validation loss for the current epoch. 321 | """ 322 | 323 | self.model.eval() 324 | total_loss = .0 325 | ss = 0 326 | for X, Y in DL: 327 | X = X.to(self.device) 328 | Y = Y.to(self.device) 329 | loss, recon_loss_ec, recon_loss_ode, kl_div, z_div = self.model(X, Y) 330 | total_loss += loss.item() * X.size(0) 331 | ss += X.size(0) 332 | 333 | val_loss = total_loss/ss 334 | return val_loss 335 | 336 | 337 | def get_time( 338 | self, 339 | ) -> np.ndarray: 340 | """ 341 | Infer the developmental pseudotime. 342 | 343 | Returns 344 | ---------- 345 | :class:`~numpy.ndarray` 346 | The pseudotime inferred for each cell. 347 | """ 348 | 349 | X = self.adata.X 350 | if self.loss_mode in ['nb', 'zinb']: 351 | X = np.log1p(X) 352 | ts = self._get_time(self.model, X) 353 | 354 | ## The model might return pseudotime in reverse order. Check this based on number of genes expressed in each cell. 355 | if self.time_reverse is None: 356 | n_genes = torch.tensor(self.adata.obs['n_genes_by_counts'].values).float().log1p().to(self.device) 357 | m_ts = ts.mean() 358 | m_ngenes = n_genes.mean() 359 | beta_direction = (ts * n_genes).sum() - len(ts) * m_ts * m_ngenes 360 | if beta_direction > 0: 361 | self.time_reverse = True 362 | else: 363 | self.time_reverse = False 364 | if self.time_reverse: 365 | ts = 1 - ts 366 | 367 | return ts.cpu().numpy() 368 | 369 | 370 | def get_vector_field( 371 | self, 372 | T: np.ndarray, 373 | Z: np.ndarray, 374 | ) -> np.ndarray: 375 | """ 376 | Infer the vector field. 377 | 378 | Parameters 379 | ---------- 380 | T 381 | The pseudotime estimated for each cell. 382 | Z 383 | The latent representation for each cell. 384 | 385 | Returns 386 | ---------- 387 | :class:`~numpy.ndarray` 388 | The estimated vector field. 389 | """ 390 | 391 | vf = self._get_vector_field( 392 | self.model, 393 | T, 394 | Z, 395 | self.time_reverse, 396 | ) 397 | return vf 398 | 399 | 400 | def get_latentsp( 401 | self, 402 | alpha_z: float = .5, 403 | alpha_predz: float = .5, 404 | step_size: Optional[int] = None, 405 | step_wise: bool = False, 406 | batch_size: Optional[int] = None, 407 | ) -> tuple: 408 | """ 409 | Infer the latent space. 410 | 411 | Parameters 412 | ---------- 413 | alpha_z 414 | Scaling factor for encoder-derived latent space. 415 | (Default: 0.5) 416 | alpha_predz 417 | Scaling factor for ODE-solver-derived latent space. 418 | (Default: 0.5) 419 | step_size 420 | Step size during integration. 421 | step_wise 422 | Whether to perform step-wise integration by iteratively considering only two time points each time. 423 | (Default: `False`) 424 | batch_size 425 | Batch size when deriving the latent space. The default is no mini-batching. 426 | 427 | Returns 428 | ---------- 429 | tuple 430 | 3-tuple of weighted combined latent space, encoder-derived latent space, and ODE-solver-derived latent space. 431 | """ 432 | 433 | X = self.adata.X 434 | if self.model.loss_mode in ['nb', 'zinb']: 435 | X = np.log1p(X) 436 | mix_zs, zs, pred_zs = self._get_latentsp(self.model, 437 | X, 438 | alpha_z, 439 | alpha_predz, 440 | step_size, 441 | step_wise, 442 | batch_size, 443 | ) 444 | return mix_zs, zs, pred_zs 445 | 446 | 447 | def save_model( 448 | self, 449 | save_dir: str, 450 | save_prefix: str, 451 | ) -> None: 452 | """ 453 | Save the trained scTour model. 454 | 455 | Parameters 456 | ---------- 457 | save_dir 458 | The directory where the model will be saved. 459 | save_prefix 460 | The prefix for model name. The model will be saved in 'save_dir/save_prefix.pth'. 461 | """ 462 | 463 | save_path = os.path.abspath(os.path.join(save_dir, f'{save_prefix}.pth')) 464 | # save_path = os.path.abspath(os.path.join(save_dir, f'{save_prefix}.tar')) 465 | torch.save( 466 | { 467 | 'model_state_dict': self.model.state_dict(), 468 | 'optimizer_state_dict': self.optimizer.state_dict(), 469 | 'model_kwargs': self.model_kwargs, 470 | 'time_reverse': self.time_reverse, 471 | 'adata': self.adata, 472 | 'percent': self.percent, 473 | 'nepoch': self.nepoch, 474 | 'batch_size': self.batch_size, 475 | 'random_state': self.random_state, 476 | 'drop_last': self.drop_last, 477 | 'lr': self.lr, 478 | 'wt_decay': self.wt_decay, 479 | 'eps': self.eps, 480 | 'val_frac': self.val_frac, 481 | 'use_gpu': self.use_gpu, 482 | }, 483 | save_path 484 | ) 485 | 486 | 487 | @staticmethod 488 | @torch.no_grad() 489 | def _get_time( 490 | model: TNODE, 491 | X: Union[np.ndarray, spmatrix], 492 | ) -> torch.tensor: 493 | """ 494 | Derive the developmental pseudotime for cells. 495 | 496 | Parameters 497 | ---------- 498 | model 499 | The trained scTour model. 500 | X 501 | The data matrix. 502 | 503 | Returns 504 | ---------- 505 | :class:`torch.Tensor` 506 | The pseudotime estimated for each cell. 507 | """ 508 | 509 | model.eval() 510 | if sparse.issparse(X): 511 | X = X.A 512 | X = torch.tensor(X).to(model.device) 513 | ts, _, _ = model.encoder(X) 514 | ts = ts.ravel() 515 | return ts 516 | 517 | 518 | @staticmethod 519 | @torch.no_grad() 520 | def _get_vector_field( 521 | model: TNODE, 522 | T: np.ndarray, 523 | Z: np.ndarray, 524 | time_reverse: bool, 525 | ) -> np.ndarray: 526 | """ 527 | Derive the vector field for cells. 528 | 529 | Parameters 530 | ---------- 531 | model 532 | The trained scTour model. 533 | T 534 | The pseudotime for each cell. 535 | Z 536 | The latent representation for each cell. 537 | time_reverse 538 | Whether to reverse the vector field. 539 | 540 | Returns 541 | ---------- 542 | :class:`~numpy.ndarray` 543 | The estimated vector field. 544 | """ 545 | 546 | model.eval() 547 | if not (isinstance(T, np.ndarray) and isinstance(Z, np.ndarray)): 548 | raise TypeError( 549 | 'The inputs must be numpy arrays.' 550 | ) 551 | Z = torch.tensor(Z) 552 | T = torch.tensor(T) 553 | if time_reverse is None: 554 | raise RuntimeError( 555 | 'It seems you did not run `get_time()` function first after model training.' 556 | ) 557 | direction = 1 558 | if time_reverse: 559 | direction = -1 560 | return direction * model.lode_func(T, Z).numpy() 561 | 562 | 563 | @staticmethod 564 | @torch.no_grad() 565 | def _get_latentsp( 566 | model: TNODE, 567 | X: Union[np.ndarray, spmatrix], 568 | alpha_z: float = .5, 569 | alpha_predz: float = .5, 570 | step_size: Optional[int] = None, 571 | step_wise: bool = False, 572 | batch_size: Optional[int] = None, 573 | ): 574 | """ 575 | Derive the latent representations of cells. 576 | 577 | Parameters 578 | ---------- 579 | model 580 | The trained scTour model. 581 | X 582 | The data matrix. 583 | alpha_z 584 | Scaling factor for encoder-derived latent space. 585 | (Default: 0.5) 586 | alpha_predz 587 | Scaling factor for ODE-solver-derived latent space. 588 | (Default: 0.5) 589 | step_size 590 | Step size during integration. 591 | step_wise 592 | Whether to perform step-wise integration by iteratively considering only two time points each time. 593 | (Default: `False`) 594 | batch_size 595 | Batch size when deriving the latent space. The default is no mini-batching. 596 | 597 | Returns 598 | ---------- 599 | tuple 600 | 3-tuple of weighted combined latent space, encoder-derived latent space, and ODE-solver-derived latent space. 601 | """ 602 | 603 | model.eval() 604 | 605 | if (alpha_z < 0) or (alpha_z > 1): 606 | raise ValueError( 607 | '`alpha_z` must be between 0 and 1.' 608 | ) 609 | if (alpha_predz < 0) or (alpha_predz > 1): 610 | raise ValueError( 611 | '`alpha_predz` must be between 0 and 1.' 612 | ) 613 | if alpha_z + alpha_predz != 1: 614 | raise ValueError( 615 | 'The sum of `alpha_z` and `alpha_predz` must be 1.' 616 | ) 617 | 618 | if sparse.issparse(X): 619 | X = X.A 620 | X = torch.tensor(X).to(model.device) 621 | T, qz_mean, qz_logvar = model.encoder(X) 622 | T = T.ravel().cpu() 623 | epsilon = torch.randn(qz_mean.size()) 624 | zs = epsilon * torch.exp(.5 * qz_logvar.cpu()) + qz_mean.cpu() 625 | 626 | sort_T, sort_idx, sort_ridx = np.unique(T, return_index=True, return_inverse=True) 627 | sort_T = torch.tensor(sort_T) 628 | sort_zs = zs[sort_idx] 629 | 630 | pred_zs = [] 631 | if batch_size is None: 632 | batch_size = len(sort_T) 633 | times = int(np.ceil(len(sort_T) / batch_size)) 634 | for i in range(times): 635 | idx1 = i * batch_size 636 | idx2 = np.min([(i + 1)*batch_size, len(sort_T)]) 637 | t = sort_T[idx1:idx2] 638 | z = sort_zs[idx1:idx2] 639 | z0 = z[0] 640 | 641 | if not step_wise: 642 | options = get_step_size(step_size, t[0], t[-1], len(t)) 643 | pred_z = odeint( 644 | model.lode_func, 645 | z0, 646 | t, 647 | method = model.ode_method, 648 | options = options 649 | ).view(-1, model.n_latent) 650 | else: 651 | pred_z = torch.empty((len(t), z.size(1))) 652 | pred_z[0] = z0 653 | for j in range(len(t) - 1): 654 | t2 = t[j:(j + 2)] 655 | options = get_step_size(step_size, t2[0], t2[-1], len(t2)) 656 | pred_z[j + 1] = odeint( 657 | model.lode_func, 658 | z[j], 659 | t2, 660 | method = model.ode_method, 661 | options = options 662 | )[1] 663 | 664 | pred_zs += [pred_z] 665 | 666 | pred_zs = torch.cat(pred_zs) 667 | pred_zs = pred_zs[sort_ridx] 668 | mix_zs = alpha_z * zs + alpha_predz * pred_zs 669 | 670 | return mix_zs.numpy(), zs.numpy(), pred_zs.numpy() 671 | -------------------------------------------------------------------------------- /sctour/vector_field.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scanpy as sc 3 | from scipy.sparse import coo_matrix, csr_matrix 4 | from scipy.stats import norm 5 | from sklearn.neighbors import NearestNeighbors 6 | import matplotlib.pyplot as plt 7 | from matplotlib.axes import Axes 8 | from anndata import AnnData 9 | from typing import Optional, Union 10 | 11 | from ._utils import l2_norm 12 | from . import logger 13 | 14 | 15 | def cosine_similarity( 16 | adata: AnnData, 17 | zs_key: str, 18 | reverse: bool = False, 19 | use_rep_neigh: Optional[str] = None, 20 | vf_key: str = 'X_VF', 21 | run_neigh: bool = True, 22 | n_neigh: int = 20, 23 | t_key: Optional[str] = None, 24 | var_stabilize_transform: bool = False, 25 | ) -> csr_matrix: 26 | """ 27 | Calculate the cosine similarity between the vector field and the cell-neighbor latent state difference for each cell. 28 | The calculation borrows the ideas from scvelo: https://github.com/theislab/scvelo/blob/master/scvelo/tools/velocity_graph.py. 29 | 30 | Parameters 31 | ---------- 32 | adata 33 | An :class:`~anndata.AnnData` object. 34 | reverse 35 | Whether to reverse the direction of the vector field. When the pseudotime returned by `get_time()` function was in reverse order and you used the post-inference adjustment (`reverse_time()` function), please set this parameter to `True`. 36 | (Default: `False`) 37 | zs_key 38 | The key in `.obsm` for storing the latent space. 39 | vf_key 40 | The key in `.obsm` for storing the vector field. 41 | (Default: `'X_VF'`) 42 | run_neigh 43 | Whether to run neighbor detection. 44 | (Default: `True`) 45 | use_rep_neigh 46 | The representation in `.obsm` which will be used for neighbor detection. 47 | n_neigh 48 | The number of neighbors considered for each cell. 49 | (Default: 20) 50 | t_key: 51 | The key in `.obs` for estimated pseudotime which will be considered when detecting neighbors. 52 | var_stabilize_transform 53 | Whether to perform variance-stabilizing transformation for vector field and cell-neighbor latent state difference. 54 | (Default: `False`) 55 | 56 | Returns 57 | ---------- 58 | :class:`~scipy.sparse.csr_matrix` 59 | A sparse matrix with cosine similarities. 60 | """ 61 | 62 | Z = np.array(adata.obsm[zs_key]) 63 | V = np.array(adata.obsm[vf_key]) 64 | if reverse: 65 | V = -V 66 | if var_stabilize_transform: 67 | V = np.sqrt(np.abs(V)) * np.sign(V) 68 | 69 | ncells = adata.n_obs 70 | 71 | if run_neigh or ('neighbors' not in adata.uns): 72 | if use_rep_neigh is None: 73 | use_rep_neigh = zs_key 74 | logger.warn(f"Warning: the parameter `use_rep_neigh` in function `plot_vector_field` is not provided. Use `{zs_key}` in `.obsm` of the AnnData instead.") 75 | else: 76 | if use_rep_neigh not in adata.obsm: 77 | raise KeyError( 78 | f"`{use_rep_neigh}` not found in `.obsm` of the AnnData. Please provide valid `use_rep_neigh` for neighbor detection." 79 | ) 80 | sc.pp.neighbors(adata, use_rep = use_rep_neigh, n_neighbors = n_neigh) 81 | n_neigh = adata.uns['neighbors']['params']['n_neighbors'] - 1 82 | # indices_matrix = adata.obsp['distances'].indices.reshape(-1, n_neigh) 83 | 84 | if t_key is not None: 85 | if t_key not in adata.obs: 86 | raise KeyError( 87 | f"`{t_key}` not found in `.obs` of the AnnData. Please provide valid `t_key` for estimated pseudotime." 88 | ) 89 | ts = adata.obs[t_key].values 90 | indices_matrix2 = np.zeros((ncells, n_neigh), dtype = int) 91 | for i in range(ncells): 92 | idx = np.abs(ts - ts[i]).argsort()[:(n_neigh + 1)] 93 | idx = np.setdiff1d(idx, i) if i in idx else idx[:-1] 94 | indices_matrix2[i] = idx 95 | # indices_matrix = np.hstack([indices_matrix, indices_matrix2]) 96 | 97 | vals, rows, cols = [], [], [] 98 | for i in range(ncells): 99 | # idx = np.unique(indices_matrix[i]) 100 | # idx2 = indices_matrix[idx].flatten() 101 | # idx2 = np.setdiff1d(idx2, i) 102 | idx = adata.obsp['distances'][i].indices 103 | idx2 = adata.obsp['distances'][idx].indices 104 | idx2 = np.setdiff1d(idx2, i) 105 | idx = np.unique(np.concatenate([idx, idx2])) if t_key is None else np.unique(np.concatenate([idx, idx2, indices_matrix2[i]])) 106 | dZ = Z[idx] - Z[i, None] 107 | if var_stabilize_transform: 108 | dZ = np.sqrt(np.abs(dZ)) * np.sign(dZ) 109 | cos_sim = np.einsum("ij, j", dZ, V[i]) / (l2_norm(dZ, axis = 1) * l2_norm(V[i])) 110 | cos_sim[np.isnan(cos_sim)] = 0 111 | vals.extend(cos_sim) 112 | rows.extend(np.repeat(i, len(idx))) 113 | cols.extend(idx) 114 | 115 | res = coo_matrix((vals, (rows, cols)), shape = (ncells, ncells)) 116 | res.data = np.clip(res.data, -1, 1) 117 | return res.tocsr() 118 | 119 | 120 | def quiver_autoscale( 121 | E: np.ndarray, 122 | V: np.ndarray, 123 | ): 124 | """ 125 | Get the autoscaling in quiver. 126 | This function is from scvelo: https://github.com/theislab/scvelo/blob/master/scvelo/tools/velocity_embedding.py. 127 | 128 | Parameters 129 | ---------- 130 | E 131 | The embedding. 132 | V 133 | The weighted unitary displacement. 134 | 135 | Returns 136 | ---------- 137 | The autoscaling factor. 138 | """ 139 | 140 | fig, ax = plt.subplots() 141 | scale_factor = np.abs(E).max() 142 | 143 | Q = ax.quiver( 144 | E[:, 0] / scale_factor, 145 | E[:, 1] / scale_factor, 146 | V[:, 0], 147 | V[:, 1], 148 | angles = 'xy', 149 | scale = None, 150 | scale_units = 'xy', 151 | ) 152 | Q._init() 153 | fig.clf() 154 | plt.close(fig) 155 | return Q.scale / scale_factor 156 | 157 | 158 | def vector_field_embedding( 159 | adata: AnnData, 160 | T_key: str, 161 | E_key: str, 162 | scale: int = 10, 163 | self_transition: bool = False, 164 | ): 165 | """ 166 | Calculate the weighted unitary displacement vectors under a certain embedding. 167 | This function borrows the ideas from scvelo: https://github.com/theislab/scvelo/blob/master/scvelo/tools/velocity_embedding.py. 168 | 169 | Parameters 170 | ---------- 171 | adata 172 | An :class:`~anndata.AnnData` object. 173 | T_key 174 | The key in `.obsp` for cosine similarity. 175 | E_key 176 | The key in `.obsm` for embedding. 177 | scale 178 | Scale factor for cosine similarity. 179 | (Default: 10) 180 | self_transition 181 | Whether to take self-transition into consideration. 182 | (Default: `False`) 183 | 184 | Returns 185 | ---------- 186 | The weighted unitary displacement vectors. 187 | """ 188 | 189 | T = adata.obsp[T_key].copy() 190 | 191 | if self_transition: 192 | max_t = T.max(1).A.flatten() 193 | ub = np.percentile(max_t, 98) 194 | self_t = np.clip(ub - max_t, 0, 1) 195 | T.setdiag(self_t) 196 | 197 | T = T.sign().multiply(np.expm1(abs(T * scale))) 198 | T = T.multiply(csr_matrix(1.0 / abs(T).sum(1))) 199 | if self_transition: 200 | T.setdiag(0) 201 | T.eliminate_zeros() 202 | 203 | E = np.array(adata.obsm[E_key]) 204 | V = np.zeros(E.shape) 205 | 206 | for i in range(adata.n_obs): 207 | idx = T[i].indices 208 | dE = E[idx] - E[i, None] 209 | dE /= l2_norm(dE)[:, None] 210 | dE[np.isnan(dE)] = 0 211 | prob = T[i].data 212 | V[i] = prob.dot(dE) - prob.mean() * dE.sum(0) 213 | 214 | V /= 3 * quiver_autoscale(E, V) 215 | return V 216 | 217 | 218 | def vector_field_embedding_grid( 219 | E: np.ndarray, 220 | V: np.ndarray, 221 | smooth: float = 0.5, 222 | stream: bool = False, 223 | density: float = 1.0, 224 | ) -> tuple: 225 | """ 226 | Estimate the unitary displacement vectors within a grid. 227 | This function borrows the ideas from scvelo: https://github.com/theislab/scvelo/blob/master/scvelo/plotting/velocity_embedding_grid.py. 228 | 229 | Parameters 230 | ---------- 231 | E 232 | The embedding. 233 | V 234 | The unitary displacement vectors under the embedding. 235 | smooth 236 | The factor for scale in Gaussian pdf. 237 | (Default: 0.5) 238 | stream 239 | Whether to adjust for streamplot. 240 | (Default: `False`) 241 | density 242 | grid density 243 | (Default: 1.0) 244 | 245 | Returns 246 | ---------- 247 | tuple 248 | The embedding and unitary displacement vectors at grid level. 249 | """ 250 | 251 | grs = [] 252 | for i in range(E.shape[1]): 253 | m, M = np.min(E[:, i]), np.max(E[:, i]) 254 | diff = M - m 255 | m = m - 0.01 * diff 256 | M = M + 0.01 * diff 257 | gr = np.linspace(m, M, int(50 * density)) 258 | grs.append(gr) 259 | 260 | meshes = np.meshgrid(*grs) 261 | E_grid = np.vstack([i.flat for i in meshes]).T 262 | 263 | n_neigh = int(E.shape[0] / 50) 264 | nn = NearestNeighbors(n_neighbors = n_neigh, n_jobs = -1) 265 | nn.fit(E) 266 | dists, neighs = nn.kneighbors(E_grid) 267 | 268 | scale = np.mean([g[1] - g[0] for g in grs]) * smooth 269 | weight = norm.pdf(x = dists, scale = scale) 270 | weight_sum = weight.sum(1) 271 | 272 | V_grid = (V[neighs] * weight[:, :, None]).sum(1) 273 | V_grid /= np.maximum(1, weight_sum)[:, None] 274 | 275 | if stream: 276 | E_grid = np.stack(grs) 277 | ns = E_grid.shape[1] 278 | V_grid = V_grid.T.reshape(2, ns, ns) 279 | 280 | mass = np.sqrt((V_grid * V_grid).sum(0)) 281 | min_mass = 1e-5 282 | min_mass = np.clip(min_mass, None, np.percentile(mass, 99) * 0.01) 283 | cutoff1 = (mass < min_mass) 284 | 285 | length = np.sum(np.mean(np.abs(V[neighs]), axis = 1), axis = 1).reshape(ns, ns) 286 | cutoff2 = (length < np.percentile(length, 5)) 287 | 288 | cutoff = (cutoff1 | cutoff2) 289 | V_grid[0][cutoff] = np.nan 290 | else: 291 | min_weight = np.percentile(weight_sum, 99) * 0.01 292 | E_grid, V_grid = E_grid[weight_sum > min_weight], V_grid[weight_sum > min_weight] 293 | V_grid /= 3 * quiver_autoscale(E_grid, V_grid) 294 | 295 | return E_grid, V_grid 296 | 297 | 298 | def plot_vector_field( 299 | adata: AnnData, 300 | zs_key: str, 301 | reverse: bool = False, 302 | vf_key: str = 'X_VF', 303 | run_neigh: bool = True, 304 | use_rep_neigh: Optional[str] = None, 305 | t_key: Optional[str] = None, 306 | n_neigh: int = 20, 307 | var_stabilize_transform: bool = False, 308 | E_key: str = 'X_umap', 309 | scale: int = 10, 310 | self_transition: bool = False, 311 | smooth: float = 0.5, 312 | density: float = 1., 313 | grid: bool = False, 314 | stream: bool = True, 315 | stream_density: int = 2, 316 | stream_color: str = 'k', 317 | stream_linewidth: int = 1, 318 | stream_arrowsize: int = 1, 319 | grid_density: float = 1., 320 | grid_arrowcolor: str = 'grey', 321 | grid_arrowlength: int = 1, 322 | grid_arrowsize: int = 1, 323 | show: bool = True, 324 | save: Optional[Union[str, bool]] = None, 325 | # color: Optional[str] = None, 326 | # ax: Optional[Axes] = None, 327 | **kwargs, 328 | ): 329 | """ 330 | Visualize the vector field. 331 | The visualization of vector field under an embedding borrows the ideas from scvelo: https://github.com/theislab/scvelo. 332 | 333 | Parameters 334 | ---------- 335 | adata 336 | An :class:`~anndata.AnnData` object. 337 | zs_key 338 | The key in `.obsm` for storing the latent space. 339 | reverse 340 | Whether to reverse the direction of the vector field. When the pseudotime returned by `get_time()` function was in reverse order and you used the post-inference adjustment (`reverse_time()` function), please set this parameter to `True`. 341 | (Default: `False`) 342 | vf_key 343 | The key in `.obsm` for storing the vector field. 344 | run_neigh 345 | Whether to run neighbor detection. 346 | (Default: `True`) 347 | use_rep_neigh 348 | The representation in `.obsm` which will be used for neighbor detection. 349 | t_key: 350 | The key in `.obs` for estimated pseudotime which will be considered when detecting neighbors. 351 | n_neigh 352 | The number of neighbors considered for each cell. 353 | (Default: 20) 354 | var_stabilize_transform 355 | Whether to perform variance-stabilizing transformation for vector field and cell-neighbor latent state difference. 356 | (Default: `False`) 357 | E_key 358 | The key in `.obsm` for embedding. 359 | (Default: `'X_umap'`) 360 | scale 361 | Scale factor for cosine similarity. 362 | (Default: 10) 363 | self_transition 364 | Whether to take self-transition into consideration. 365 | (Default: `False`) 366 | smooth 367 | The factor for scale in Gaussian pdf. 368 | (Default: 0.5) 369 | density 370 | Percentage of cells to show when displaying the vector field at per-cell level. 371 | (Default: 1.0) 372 | grid 373 | Whether to display vector field as arrows at grid level. 374 | (Default: `False`) 375 | stream 376 | Whether to display vector field as streamplot. 377 | (Default: `True`) 378 | stream_density 379 | The density parameter in streamplot for controlling the closeness of the streamlines. 380 | (Default: 2) 381 | stream_color 382 | The streamline color for streamplot. 383 | (Default: 'k') 384 | stream_linewidth 385 | The line width for streamplot. 386 | (Default: 1) 387 | stream_arrowsize 388 | The arrow size for streamplot. 389 | (Default: 1) 390 | grid_density 391 | The density for showing vector field as arrows at grid level. 392 | (Default: 1.0) 393 | grid_arrowcolor 394 | The arrow color when showing vector field as arrows at grid level. 395 | (Default: `'grey'`) 396 | grid_arrowlength 397 | The arrow length when showing vector field as arrows at grid level. 398 | (Default: 1) 399 | grid_arrowsize 400 | The arrow size when showing vector field as arrows at grid level. 401 | (Default: 1) 402 | show 403 | Whether to show the plot. 404 | (Default: `True`) 405 | save 406 | Whether to save the figure. If `True` or a `str`, the figure will be saved as 'sctour_vector_field.png' (if `True` provided) or a given filename (if a `str` provided). 407 | kwargs 408 | Parameters passed to :func:`scanpy.pl.embedding`. 409 | 410 | Returns 411 | ---------- 412 | None 413 | """ 414 | 415 | if zs_key not in adata.obsm: 416 | raise KeyError( 417 | f"`{zs_key}` not found in `.obsm` of the AnnData. Please provide valid `zs_key` for latent space." 418 | ) 419 | if vf_key not in adata.obsm: 420 | raise KeyError( 421 | f"`{vf_key}` not found in `.obsm` of the AnnData. Please provide valid `vf_key` for vector field." 422 | ) 423 | if E_key not in adata.obsm: 424 | raise KeyError( 425 | f"`{E_key}` not found in `.obsm` of the AnnData. Please provide valid `E_key` for embedding." 426 | ) 427 | if (grid_density < 0) or (grid_density > 1): 428 | raise ValueError( 429 | "`grid_density` must be between 0 and 1." 430 | ) 431 | if (density < 0) or (density > 1): 432 | raise ValueError( 433 | "`density` must be between 0 and 1." 434 | ) 435 | 436 | ##calculate cosine similarity 437 | adata.obsp['cosine_similarity'] = cosine_similarity( 438 | adata, 439 | reverse = reverse, 440 | zs_key = zs_key, 441 | vf_key = vf_key, 442 | run_neigh = run_neigh, 443 | use_rep_neigh = use_rep_neigh, 444 | t_key = t_key, 445 | n_neigh = n_neigh, 446 | var_stabilize_transform = var_stabilize_transform, 447 | ) 448 | ##get weighted unitary displacement vectors under a certain embedding 449 | adata.obsm['X_DV'] = vector_field_embedding( 450 | adata, 451 | T_key = 'cosine_similarity', 452 | E_key = E_key, 453 | scale = scale, 454 | self_transition = self_transition, 455 | ) 456 | 457 | E = np.array(adata.obsm[E_key]) 458 | V = adata.obsm['X_DV'] 459 | 460 | if grid: 461 | stream = False 462 | 463 | if grid or stream: 464 | E, V = vector_field_embedding_grid( 465 | E = E, 466 | V = V, 467 | smooth = smooth, 468 | stream = stream, 469 | density = grid_density, 470 | ) 471 | 472 | ax = sc.pl.embedding(adata, basis = E_key, show=False, **kwargs) 473 | if stream: 474 | lengths = np.sqrt((V * V).sum(0)) 475 | stream_linewidth *= 2 * lengths / lengths[~np.isnan(lengths)].max() 476 | stream_kwargs = dict( 477 | linewidth = stream_linewidth, 478 | density = stream_density, 479 | zorder = 3, 480 | color = stream_color, 481 | arrowsize = stream_arrowsize, 482 | arrowstyle = '-|>', 483 | maxlength = 4, 484 | integration_direction = 'both', 485 | ) 486 | ax.streamplot(E[0], E[1], V[0], V[1], **stream_kwargs) 487 | else: 488 | if not grid: 489 | if density < 1: 490 | idx = np.random.choice(len(E), int(len(E) * density), replace = False) 491 | E = E[idx] 492 | V = V[idx] 493 | scale = 1 / grid_arrowlength 494 | hl, hw, hal = 6 * grid_arrowsize, 5 * grid_arrowsize, 4 * grid_arrowsize 495 | quiver_kwargs = dict( 496 | angles = 'xy', 497 | scale_units = 'xy', 498 | edgecolors = 'k', 499 | scale = scale, 500 | width = 0.001, 501 | headlength = hl, 502 | headwidth = hw, 503 | headaxislength = hal, 504 | color = grid_arrowcolor, 505 | linewidth = 0.2, 506 | zorder = 3, 507 | ) 508 | ax.quiver(E[:, 0], E[:, 1], V[:, 0], V[:, 1], **quiver_kwargs) 509 | 510 | if save: 511 | if isinstance(save, str): 512 | plt.savefig(save) 513 | else: 514 | plt.savefig('sctour_vector_field.png') 515 | if show: 516 | plt.show() 517 | if save: 518 | plt.close() 519 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | def get_readme(): 4 | with open("README.md", "rt", encoding="utf-8") as fh: 5 | return fh.read() 6 | 7 | setuptools.setup( 8 | name='sctour', 9 | version='1.0.0', 10 | author='Qian Li', 11 | author_email='liqian.picb@gmail.com', 12 | description='a deep learning architecture for robust inference and accurate prediction of cellular dynamics', 13 | long_description=get_readme(), 14 | long_description_content_type="text/markdown", 15 | url='https://github.com/LiQian-XC/sctour', 16 | packages=setuptools.find_packages(), 17 | install_requires=[ 18 | 'torch>=1.9.1', 19 | 'torchdiffeq>=0.2.2', 20 | 'numpy>=1.19.2', 21 | 'scanpy>=1.7.1', 22 | 'anndata>=0.7.5', 23 | 'scipy>=1.5.2', 24 | 'tqdm>=4.32.2', 25 | 'scikit-learn>=0.24.1', 26 | 'leidenalg>=0.8.4', 27 | ], 28 | classifiers=[ 29 | "Programming Language :: Python :: 3", 30 | "Operating System :: OS Independent", 31 | "Intended Audience :: Science/Research", 32 | "Topic :: Scientific/Engineering :: Bio-Informatics", 33 | "Development Status :: 4 - Beta", 34 | ], 35 | python_requires='>=3.7', 36 | ) 37 | --------------------------------------------------------------------------------