├── .gitignore
├── .readthedocs.yaml
├── LICENSE
├── MANIFEST.in
├── README.md
├── docs
│   ├── Makefile
│   ├── make.bat
│   ├── requirements.txt
│   └── source
│       ├── _static
│       │   └── img
│       │       └── scTour_head_image.png
│       ├── _templates
│       │   └── autosummary
│       │       └── class.rst
│       ├── api_predict.rst
│       ├── api_reverse_time.rst
│       ├── api_train.rst
│       ├── api_vf.rst
│       ├── conf.py
│       ├── index.rst
│       └── notebook
│           ├── scTour_inference_PostInference_adjustment.ipynb
│           ├── scTour_inference_basic.ipynb
│           └── scTour_prediction.ipynb
├── sctour
│   ├── __init__.py
│   ├── _utils.py
│   ├── data.py
│   ├── logger.py
│   ├── model.py
│   ├── module.py
│   ├── predict.py
│   ├── train.py
│   └── vector_field.py
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
/**/__pycache__/
sctour.egg-info
build_doc.sh
build/
build_wheel.sh
/dist/
docs/source/generated/
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
version: 2

build:
  os: ubuntu-22.04
  tools:
    python: "3.11"

sphinx:
  configuration: docs/source/conf.py

python:
  install:
    - requirements: docs/requirements.txt
    - method: pip
      path: .
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Qian Li

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
include README.md
include LICENSE
recursive-exclude * __pycache__
recursive-exclude * *.pyc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# scTour

scTour is an innovative and comprehensive method for dissecting cellular dynamics by analysing datasets derived from single-cell genomics.

It provides a unifying framework to depict the full picture of developmental processes from multiple angles including the developmental pseudotime, vector field and latent space.

It further generalises these functionalities to a multi-task architecture for within-dataset inference and cross-dataset prediction of cellular dynamics in a batch-insensitive manner.

## Key features

- cell pseudotime estimation with no need for specifying starting cells.
- transcriptomic vector field inference with no discrimination between spliced and unspliced mRNAs.
- latent space mapping by combining intrinsic transcriptomic structure with extrinsic pseudotime ordering.
- model-based prediction of pseudotime, vector field, and latent space for query cells/datasets/time intervals.
- insensitive to batch effects; robust to cell subsampling; scalable to large datasets.

## Installation

[PyPI](https://pypi.org/project/sctour)

```console
pip install sctour
```

[conda-forge](https://anaconda.org/conda-forge/sctour)

```console
conda install -c conda-forge sctour
```

## Documentation

[Read the Docs](https://sctour.readthedocs.io/en/latest/?badge=latest)

Full documentation can be found [here](https://sctour.readthedocs.io/en/latest/).

## Reference

[Qian Li, scTour: a deep learning architecture for robust inference and accurate prediction of cellular dynamics. Genome Biology, 2023](https://genomebiology.biomedcentral.com/articles/10.1186/s13059-023-02988-9)
--------------------------------------------------------------------------------
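A note on the README's installation step: after `pip install sctour` or the conda-forge install, one way to confirm the package is visible to the current interpreter is to query its installed version. The sketch below is illustrative only and uses the standard-library `importlib.metadata` module (Python 3.8+); the helper name `installed_version` is hypothetical and not part of scTour.

```python
from importlib import metadata  # standard library since Python 3.8


def installed_version(package: str = "sctour"):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_version()
    print(version if version else "sctour is not installed in this environment")
```

The package also exposes `sctour.__version__` (see `/sctour/__init__.py`), which can be read directly once the import succeeds.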
/docs/Makefile:
--------------------------------------------------------------------------------
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS    ?=
SPHINXBUILD   ?= sphinx-build
SOURCEDIR     = source
BUILDDIR      = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build

if "%1" == "" goto help

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.http://sphinx-doc.org/
	exit /b 1
)

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
torch>=1.9.1
torchdiffeq==0.2.2
numpy>=1.19.2
scanpy>=1.7.1
anndata>=0.7.5
scipy>=1.5.2
tqdm>=4.32.2
scikit-learn>=0.24.1
myst_parser==0.15.2
sphinx>=3.5.2
nbsphinx==0.8.8
ipython_genutils==0.2.0
jinja2<3.1
--------------------------------------------------------------------------------
/docs/source/_static/img/scTour_head_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiQian-XC/sctour/86e21fe356012776d55a07f7c2fa7013dc19def8/docs/source/_static/img/scTour_head_image.png
--------------------------------------------------------------------------------
/docs/source/_templates/autosummary/class.rst:
--------------------------------------------------------------------------------
{{ fullname | escape | underline}}

.. currentmodule:: {{ module }}

.. add toctree option to make autodoc generate the pages

.. autoclass:: {{ objname }}

   {% block attributes %}
   {% if attributes %}
   .. rubric:: Attributes

   .. autosummary::
      :toctree: .
   {% for item in attributes %}
   {% if has_attr(fullname, item) %}
      ~{{ fullname }}.{{ item }}
   {% endif %}
   {%- endfor %}
   {% endif %}
   {% endblock %}

   {% block methods %}
   {% if methods %}
   .. rubric:: Methods

   .. autosummary::
      :toctree: .
   {% for item in methods %}
      {%- if item != '__init__' %}
      ~{{ fullname }}.{{ item }}
      {%- endif -%}
   {%- endfor %}
   {% endif %}
   {% endblock %}
--------------------------------------------------------------------------------
/docs/source/api_predict.rst:
--------------------------------------------------------------------------------
Prediction
==================================

.. autosummary::
   :toctree: generated/
   :nosignatures:

   sctour.predict.load_model
   sctour.predict.predict_time
   sctour.predict.predict_latentsp
   sctour.predict.predict_vector_field
   sctour.predict.predict_ltsp_from_time
--------------------------------------------------------------------------------
/docs/source/api_reverse_time.rst:
--------------------------------------------------------------------------------
Post-inference adjustment
==================================

.. autofunction:: sctour.train.reverse_time
--------------------------------------------------------------------------------
/docs/source/api_train.rst:
--------------------------------------------------------------------------------
Model training
==================================

.. autosummary::
   :toctree: generated/
   :nosignatures:

   sctour.train.Trainer
--------------------------------------------------------------------------------
/docs/source/api_vf.rst:
--------------------------------------------------------------------------------
Vector field visualization
==================================

.. autofunction:: sctour.vf.plot_vector_field
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
import os
import sys
sys.path.insert(0, os.path.abspath('../../'))


# -- Project information -----------------------------------------------------
from datetime import datetime
import sctour

project = 'sctour'
author = 'Qian Li'
copyright = f'{datetime.now():%Y}, {author}'
release = sctour.__version__


# -- General configuration ---------------------------------------------------
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.mathjax',
    'sphinx.ext.napoleon',
    'sphinx.ext.intersphinx',
    'sphinx.ext.viewcode',
    'myst_parser',
    'nbsphinx',
    'sphinx.ext.autosummary',
]
autosummary_generate = True
autodoc_member_order = 'bysource'
napoleon_include_init_with_doc = False
napoleon_numpy_docstring = True
napoleon_use_rtype = True
napoleon_use_param = True

# settings
master_doc = 'index'
templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
pygments_style = 'sphinx'

intersphinx_mapping = dict(
    python=('https://docs.python.org/3/', None),
    numpy=('https://numpy.org/doc/stable/', None),
    pandas=('https://pandas.pydata.org/pandas-docs/stable/', None),
    anndata=('https://anndata.readthedocs.io/en/stable/', None),
    scanpy=('https://scanpy.readthedocs.io/en/stable/', None),
    scipy=('https://docs.scipy.org/doc/scipy/reference/', None),
    torch=('https://pytorch.org/docs/master/', None),
    matplotlib=('https://matplotlib.org/stable/', None),
)


# -- Options for HTML output -------------------------------------------------
html_theme = 'sphinx_rtd_theme'
html_static_path = ['_static']
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
|PyPI| |Docs|

scTour: a deep learning architecture for robust inference and accurate prediction of cellular dynamics
======================================================================================================

.. image:: https://raw.githubusercontent.com/LiQian-XC/sctour/main/docs/source/_static/img/scTour_head_image.png
   :width: 400px
   :align: left

scTour is an innovative and comprehensive method for dissecting cellular dynamics by analysing datasets derived from single-cell genomics.

It provides a unifying framework to depict the full picture of developmental processes from multiple angles including the developmental pseudotime, vector field and latent space.

It further generalises these functionalities to a multi-task architecture for within-dataset inference and cross-dataset prediction of cellular dynamics in a batch-insensitive manner.

Key features
------------

- cell pseudotime estimation with no need for specifying starting cells.
- transcriptomic vector field inference with no discrimination between spliced and unspliced mRNAs.
- latent space mapping by combining intrinsic transcriptomic structure with extrinsic pseudotime ordering.
- model-based prediction of pseudotime, vector field, and latent space for query cells/datasets/time intervals.
- insensitive to batch effects; robust to cell subsampling; scalable to large datasets.

Installation
------------

scTour requires Python ≥ 3.7::

    pip install sctour

Reference
---------

Qian Li, scTour: a deep learning architecture for robust inference and accurate prediction of cellular dynamics, 2023, `Genome Biology <https://genomebiology.biomedcentral.com/articles/10.1186/s13059-023-02988-9>`_.

.. toctree::
   :maxdepth: 2
   :caption: Tutorials
   :hidden:

   notebook/scTour_inference_basic
   notebook/scTour_inference_PostInference_adjustment
   notebook/scTour_prediction

.. toctree::
   :maxdepth: 2
   :caption: API
   :hidden:

   api_train
   api_predict
   api_vf
   api_reverse_time

.. |PyPI| image:: https://img.shields.io/pypi/v/sctour.svg
   :target: https://pypi.org/project/sctour

.. |Docs| image:: https://readthedocs.org/projects/sctour/badge/?version=latest
   :target: https://sctour.readthedocs.io
--------------------------------------------------------------------------------
/sctour/__init__.py:
--------------------------------------------------------------------------------
from . import train
from . import predict
from . import vector_field as vf
__version__ = '1.0.0'
--------------------------------------------------------------------------------
/sctour/_utils.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import numpy as np
from scipy.sparse import issparse


def normal_kl(mu1, lv1, mu2, lv2):
    """
    Calculate the element-wise KL divergence KL(N(mu1, exp(lv1)) || N(mu2, exp(lv2))),
    where `lv1` and `lv2` are log-variances.
    This function is from torchdiffeq: https://github.com/rtqichen/torchdiffeq/blob/master/examples/latent_ode.py
    """
    v1 = torch.exp(lv1)
    v2 = torch.exp(lv2)
    lstd1 = lv1/2.
    lstd2 = lv2/2.

    kl = lstd2 - lstd1 + (v1 + (mu1-mu2)**2.)/(2.*v2) - 0.5
    return kl


def get_step_size(step_size, t1, t2, t_size):
    """
    Build the `options` dict passed to the ODE solver.

    If `step_size` is None, no option is set; otherwise the solver step is
    `(t2 - t1) / t_size / step_size`.
    """
    if step_size is None:
        options = {}
    else:
        step_size = (t2 - t1)/t_size/step_size
        options = dict(step_size = step_size)
    return options


def log_zinb(x, mu, theta, pi, eps=1e-8):
    """
    Calculate log probability under zero-inflated negative binomial distribution
    This function is from scvi-tools: https://github.com/YosefLab/scvi-tools/blob/6dae6482efa2d235182bf4ad10dbd9483b7d57cd/scvi/distributions/_negative_binomial.py
    """
    softplus_pi = F.softplus(-pi)
    log_theta_eps = torch.log(theta + eps)
    log_theta_mu_eps = torch.log(theta + mu + eps)
    pi_theta_log = -pi + theta * (log_theta_eps - log_theta_mu_eps)

    # log-probability of observing a zero count
    case_zero = F.softplus(pi_theta_log) - softplus_pi
    mul_case_zero = torch.mul((x < eps).type(torch.float32), case_zero)

    # log-probability of observing a non-zero count
    case_non_zero = (
        -softplus_pi
        + pi_theta_log
        + x * (torch.log(mu + eps) - log_theta_mu_eps)
        + torch.lgamma(x + theta)
        - torch.lgamma(theta)
        - torch.lgamma(x + 1)
    )
    mul_case_non_zero = torch.mul((x > eps).type(torch.float32), case_non_zero)

    res = mul_case_zero + mul_case_non_zero
    return res


def log_nb(x, mu, theta, eps=1e-8):
    """
    Calculate log probability under negative binomial distribution
    This function is from scvi-tools: https://github.com/YosefLab/scvi-tools/blob/6dae6482efa2d235182bf4ad10dbd9483b7d57cd/scvi/distributions/_negative_binomial.py
    """
    log_theta_mu_eps = torch.log(theta + mu + eps)

    res = (
        theta * (torch.log(theta + eps) - log_theta_mu_eps)
        + x * (torch.log(mu + eps) - log_theta_mu_eps)
        + torch.lgamma(x + theta)
        - torch.lgamma(theta)
        - torch.lgamma(x + 1)
    )

    return res


def l2_norm(x, axis=-1):
    """
    Calculate the L2 norm along `axis` for a dense array or a sparse matrix.
    """
    if issparse(x):
        return np.sqrt(x.multiply(x).sum(axis=axis).A1)
    else:
        return np.sqrt(np.sum(x * x, axis = axis))
--------------------------------------------------------------------------------
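As a rough illustration of the two closed-form expressions in `_utils.py`, the sketch below re-implements `normal_kl` and `log_nb` for plain Python scalars using only the standard library (the originals operate on `torch` tensors). It is an assumption-laden check rather than part of the package: it treats `theta` as the inverse dispersion, as in the scvi-tools parameterisation the source cites, and the parameter values are arbitrary.

```python
import math


def normal_kl(mu1, lv1, mu2, lv2):
    """KL(N(mu1, exp(lv1)) || N(mu2, exp(lv2))) for scalars; lv1 and lv2 are log-variances."""
    v1, v2 = math.exp(lv1), math.exp(lv2)
    return lv2 / 2.0 - lv1 / 2.0 + (v1 + (mu1 - mu2) ** 2) / (2.0 * v2) - 0.5


def log_nb(x, mu, theta, eps=1e-8):
    """Log-probability of count x under a negative binomial with mean mu (scalar version)."""
    log_theta_mu = math.log(theta + mu + eps)
    return (
        theta * (math.log(theta + eps) - log_theta_mu)
        + x * (math.log(mu + eps) - log_theta_mu)
        + math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
    )


# Identical Gaussians have zero divergence.
kl_same = normal_kl(0.3, -1.0, 0.3, -1.0)

# Two unit-variance Gaussians whose means differ by 1.
kl_shift = normal_kl(1.0, 0.0, 0.0, 0.0)

# The probabilities over all counts should sum to (almost exactly) one;
# the sum is truncated at 500, far into the tail for mu=3, theta=2.
total_prob = sum(math.exp(log_nb(x, mu=3.0, theta=2.0)) for x in range(500))

print(kl_same, kl_shift, total_prob)
```

The small `eps` terms keep the logarithms finite when `mu` or `theta` is zero, at the cost of a negligible deviation from exact normalisation.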
/sctour/data.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset
from scipy import sparse
import numpy as np
from anndata import AnnData


def split_data(
    adata: AnnData,
    percent: float,
    val_frac: float = 0.1,
):
    """
    Split the dataset for training and validation

    Parameters
    ----------
    adata
        The `AnnData` object for the whole dataset
    percent
        The fraction (between 0 and 1) of cells to be used for training the model
    val_frac
        The size of the validation set as a fraction of the training set size,
        capped at the number of cells left over after the training split

    Returns
    -------
    `AnnData` object for training and validation
    """

    n_cells = adata.n_obs
    n_train = int(np.ceil(n_cells * percent))
    n_val = min(int(np.floor(n_train * val_frac)), n_cells - n_train)

    indices = np.random.permutation(n_cells)
    train_idx = np.random.choice(indices, n_train, replace = False)
    indices2 = np.setdiff1d(indices, train_idx)
    val_idx = np.random.choice(indices2, n_val, replace = False)
    # train_idx = indices[:n_train]
    # val_idx = indices[n_train:(n_train + n_val)]

    train_data = adata[train_idx, :]
    val_data = adata[val_idx, :]

    return train_data, val_data


def split_index(
    n_cells: int,
    percent: float,
    val_frac: float = 0.1,
):
    """
    Split the indices for training and validation

    Parameters
    ----------
    n_cells
        The total number of cells
    percent
        The fraction (between 0 and 1) of cells to be used for training the model
    val_frac
        The size of the validation set as a fraction of the training set size,
        capped at the number of cells left over after the training split

    Returns
    -------
    2-tuple of indices
    """

    n_train = int(np.ceil(n_cells * percent))
    n_val = min(int(np.floor(n_train * val_frac)), n_cells - n_train)
    indices = np.random.permutation(n_cells)
    train_idx = np.random.choice(indices, n_train, replace = False)
    indices2 = np.setdiff1d(indices, train_idx)
    val_idx = np.random.choice(indices2, n_val, replace = False)
    # train_idx = indices[:n_train]
    # val_idx = indices[n_train:(n_train + n_val)]
    return train_idx, val_idx


class MakeDataset(Dataset):
    """
    A class to generate Dataset

    Parameters
    ----------
    adata
        An `AnnData` object
    loss_mode
        The loss mode used for training; for `'nb'` and `'zinb'`, `adata.X` is log1p-transformed
    """

    def __init__(
        self,
        adata: AnnData,
        loss_mode: str,
    ):
        X = adata.X
        if loss_mode in ['nb', 'zinb']:
            X = np.log1p(X)
        if sparse.issparse(X):
            # `.toarray()` replaces the `.A` shorthand, which newer SciPy releases no longer provide
            X = X.toarray()
        self.data = torch.tensor(X)
        self.library_size = self.data.sum(-1)

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, idx):
        return self.data[idx, :], self.library_size[idx]


class BatchSampler():
    """
    A class to generate mini-batches through two rounds of randomization

    Parameters
    ----------
    n
        Total number of cells
    batch_size
        Size of mini-batches
    drop_last
        Whether or not to drop the last batch when its size is smaller than the batch_size
    """
    def __init__(
        self,
        n: int,
        batch_size: int,
        drop_last: bool = False,
    ):
        self.batch_size = batch_size
        self.n = n
        self.drop_last = drop_last

    def __iter__(self):
        # round 1: random permutation of the cell indices
        seq_n = torch.randperm(self.n)
        lb = self.n // self.batch_size
        idxs = np.arange(self.n)
        for i in range(lb):
            # round 2: draw positions into the permutation without replacement
            idx = np.random.choice(idxs, self.batch_size, replace=False)
            yield seq_n[idx].tolist()
            idxs = np.setdiff1d(idxs, idx)
        if (not self.drop_last) and (len(idxs) > 0):
            yield seq_n[idxs].tolist()

    def __len__(self):
        if self.drop_last:
            return self.n // self.batch_size
        else:
            return int(np.ceil(self.n / self.batch_size))
--------------------------------------------------------------------------------
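The split-size arithmetic in `split_index` and the two-round shuffling in `BatchSampler` can be sketched without NumPy or PyTorch. The version below is a minimal standard-library illustration, not the package's implementation: the helper names `split_sizes` and `two_round_batches` and the optional `rng` argument are hypothetical, and `random` stands in for `np.random` and `torch.randperm`.

```python
import math
import random


def split_sizes(n_cells, percent, val_frac=0.1):
    """Sizes of the training and validation sets, mirroring the arithmetic in split_index."""
    n_train = int(math.ceil(n_cells * percent))
    # The validation size is relative to the training size and capped by the leftover cells.
    n_val = min(int(math.floor(n_train * val_frac)), n_cells - n_train)
    return n_train, n_val


def two_round_batches(n, batch_size, drop_last=False, rng=None):
    """Yield mini-batches of cell indices using two rounds of randomization."""
    rng = rng or random.Random()
    seq = list(range(n))
    rng.shuffle(seq)                  # round 1: permute the cell indices
    positions = list(range(n))        # positions into the permuted sequence
    for _ in range(n // batch_size):
        chosen = rng.sample(positions, batch_size)   # round 2: sample positions
        yield [seq[p] for p in chosen]
        chosen_set = set(chosen)
        positions = [p for p in positions if p not in chosen_set]
    if not drop_last and positions:
        yield [seq[p] for p in positions]            # smaller final batch


sizes = split_sizes(1000, 0.5, 0.25)
capped = split_sizes(100, 1.0, 0.1)   # no cells are left over for validation
batches = list(two_round_batches(10, 3, rng=random.Random(0)))
print(sizes, capped, [len(b) for b in batches])
```

Every index appears in exactly one batch per pass, so an epoch still visits each cell once when `drop_last` is false.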
/sctour/logger.py:
--------------------------------------------------------------------------------
import logging

logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)

info = logger.info
warn = logger.warning
error = logger.error
debug = logger.debug
--------------------------------------------------------------------------------
/sctour/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence
from torchdiffeq import odeint
from typing import Optional
from typing_extensions import Literal

from .module import LatentODEfunc, Encoder, Decoder
from ._utils import get_step_size, normal_kl, log_zinb, log_nb


class TNODE(nn.Module):
    """
    Class to automatically infer cellular dynamics using VAE and neural ODE.

    Parameters
    ----------
    device
        The torch device.
    n_int
        The dimensionality of the input.
    n_latent
        The dimensionality of the latent space.
        (Default: 5)
    n_ode_hidden
        The dimensionality of the hidden layer for the latent ODE function.
        (Default: 25)
    n_vae_hidden
        The dimensionality of the hidden layer for the VAE.
        (Default: 128)
    batch_norm
        Whether to include `BatchNorm` layer.
        (Default: `False`)
    ode_method
        Solver for integration.
        (Default: `'euler'`)
    step_size
        The step size during integration.
        (Default: `None`)
    alpha_recon_lec
        Scaling factor for reconstruction loss from encoder-derived latent space.
        (Default: 0.5)
    alpha_recon_lode
        Scaling factor for reconstruction loss from ODE-solver latent space.
        (Default: 0.5)
    alpha_kl
        Scaling factor for KL divergence.
        (Default: 1.0)
    loss_mode
        The mode for calculating the reconstruction error.
        (Default: `'nb'`)
    """

    def __init__(
        self,
        device,
        n_int: int,
        n_latent: int = 5,
        n_ode_hidden: int = 25,
        n_vae_hidden: int = 128,
        batch_norm: bool = False,
        ode_method: str = 'euler',
        step_size: Optional[int] = None,
        alpha_recon_lec: float = 0.5,
        alpha_recon_lode: float = 0.5,
        alpha_kl: float = 1.,
        loss_mode: Literal['mse', 'nb', 'zinb'] = 'nb',
    ):
        super().__init__()
        self.n_int = n_int
        self.n_latent = n_latent
        self.n_ode_hidden = n_ode_hidden
        self.n_vae_hidden = n_vae_hidden
        self.batch_norm = batch_norm
        self.ode_method = ode_method
        self.step_size = step_size
        self.alpha_recon_lec = alpha_recon_lec
        self.alpha_recon_lode = alpha_recon_lode
        self.alpha_kl = alpha_kl
        self.loss_mode = loss_mode
        self.device = device

        ## the ODE function is not moved to `device`: forward() runs odeint on the CPU
        self.lode_func = LatentODEfunc(n_latent, n_ode_hidden)
        self.encoder = Encoder(n_int, n_latent, n_vae_hidden, batch_norm).to(self.device)
        self.decoder = Decoder(n_int, n_latent, n_vae_hidden, batch_norm, loss_mode).to(self.device)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> tuple:
        """
        Given the transcriptomes of cells, this function derives the time and latent space of the cells, as well as reconstructs the transcriptomes.

        Parameters
        ----------
        x
            The input data.
        y
            The library size.

        Returns
        -------
        5-tuple of :class:`torch.Tensor`
            Tensors for loss, including:
            1) total loss,
            2) reconstruction loss from encoder-derived latent space,
            3) reconstruction loss from ODE-solver latent space,
            4) KL divergence,
            5) divergence between encoder-derived latent space and ODE-solver latent space
        """

        ## get the time and latent space through Encoder
        T, qz_mean, qz_logvar = self.encoder(x)
        T = T.ravel()  ## odeint requires 1-D Tensor for time
        epsilon = torch.randn(qz_mean.size()).to(T.device)
        z = epsilon * torch.exp(.5 * qz_logvar) + qz_mean

        ## sort the cells by their inferred time
        index = torch.argsort(T)
        T = T[index]
        x = x[index]
        z = z[index]
        y = y[index]
        # qz_mean = qz_mean[index]
        # qz_logvar = qz_logvar[index]

        ## index2 is used to get unique time points,
        ## as odeint requires strictly increasing/decreasing time points
        index2 = (T[:-1] != T[1:])
        index2 = torch.cat((index2, torch.tensor([True]).to(index2.device)))
        T = T[index2]
        x = x[index2]
        z = z[index2]
        y = y[index2]
        # qz_mean = qz_mean[index2]
        # qz_logvar = qz_logvar[index2]

        ## infer the latent space through ODE solver based on z0, t, and LatentODEfunc
        z0 = z[0]
        options = get_step_size(self.step_size, T[0], T[-1], len(T))
        pred_z = odeint(self.lode_func, z0.to('cpu'), T.to('cpu'), method = self.ode_method, options = options).view(-1, self.n_latent)
        pred_z = pred_z.to(z.device)

        ## reconstruct the input through Decoder and compute reconstruction loss
        if self.loss_mode == 'mse':
            pred_x1 = self.decoder(z)  ## decode through latent space returned by Encoder
            pred_x2 = self.decoder(pred_z)  ## decode through latent space returned by ODE solver
            recon_loss_ec = F.mse_loss(x, pred_x1, reduction='none').sum(-1).mean()
            recon_loss_ode = F.mse_loss(x, pred_x2, reduction='none').sum(-1).mean()
        if self.loss_mode == 'nb':
            pred_x1 = self.decoder(z)  ## decode through latent space returned by Encoder
            pred_x2 = self.decoder(pred_z)  ## decode through latent space returned by ODE solver
            y = y.unsqueeze(1).expand(pred_x1.size(0), pred_x1.size(1))
            pred_x1 = pred_x1 * y
            pred_x2 = pred_x2 * y
            disp = torch.exp(self.decoder.disp)
            recon_loss_ec = -log_nb(x, pred_x1, disp).sum(-1).mean()
            recon_loss_ode = -log_nb(x, pred_x2, disp).sum(-1).mean()
        if self.loss_mode == 'zinb':
            pred_x1, dp1 = self.decoder(z)
            pred_x2, dp2 = self.decoder(pred_z)
            y = y.unsqueeze(1).expand(pred_x1.size(0), pred_x1.size(1))
            pred_x1 = pred_x1 * y
            pred_x2 = pred_x2 * y
            disp = torch.exp(self.decoder.disp)
            recon_loss_ec = -log_zinb(x, pred_x1, disp, dp1).sum(-1).mean()
            recon_loss_ode = -log_zinb(x, pred_x2, disp, dp2).sum(-1).mean()

        ## compute KL divergence and z divergence
        z_div = F.mse_loss(z, pred_z, reduction='none').sum(-1).mean()
        pz_mean = torch.zeros_like(qz_mean)
        pz_logvar = torch.zeros_like(qz_mean)
        kl_div = normal_kl(qz_mean, qz_logvar, pz_mean, pz_logvar).sum(-1).mean()

        loss = self.alpha_recon_lec * recon_loss_ec + self.alpha_recon_lode * recon_loss_ode + z_div + self.alpha_kl * kl_div

        return loss, recon_loss_ec, recon_loss_ode, kl_div, z_div
--------------------------------------------------------------------------------
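The preprocessing that `TNODE.forward` applies before calling the ODE solver (sort cells by inferred time, keep one cell per distinct time point, then derive the solver options) can be illustrated on plain Python lists. This is a sketch under stated assumptions, not the model code: `sort_and_dedup` is a hypothetical helper, Python's stable `sorted` stands in for `torch.argsort` (whose tie order may differ), and `get_step_size` is copied in simplified form from `_utils.py`.

```python
def sort_and_dedup(times):
    """Return (kept original indices, kept times) with strictly increasing times."""
    order = sorted(range(len(times)), key=lambda i: times[i])   # analogous to argsort
    sorted_t = [times[i] for i in order]
    # Keep an entry when it differs from its successor; the last entry is always kept.
    keep = [sorted_t[i] != sorted_t[i + 1] for i in range(len(sorted_t) - 1)] + [True]
    kept_idx = [i for i, flag in zip(order, keep) if flag]
    return kept_idx, [times[i] for i in kept_idx]


def get_step_size(step_size, t1, t2, t_size):
    """Solver options: empty when step_size is None, else a fixed step length."""
    if step_size is None:
        return {}
    return {"step_size": (t2 - t1) / t_size / step_size}


times = [0.7, 0.2, 0.7, 0.5, 0.2]
kept_idx, kept_times = sort_and_dedup(times)
default_options = get_step_size(None, kept_times[0], kept_times[-1], len(kept_times))
fixed_options = get_step_size(2, 0.0, 1.0, 4)
print(kept_idx, kept_times, default_options, fixed_options)
```

Cells that share a time point with a later cell in the sorted order are dropped from that forward pass, which is why the reconstruction losses are computed on the filtered `x`, `z` and `y`.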
/sctour/module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from typing_extensions import Literal
4 |
5 |
6 | class LatentODEfunc(nn.Module):
7 | """
8 | A class modelling the latent state derivatives with respect to time.
9 |
10 | Parameters
11 | ----------
12 | n_latent
13 | The dimensionality of the latent space.
14 | (Default: 5)
15 | n_hidden
16 | The dimensionality of the hidden layer.
17 | (Default: 25)
18 | """
19 |
20 | def __init__(
21 | self,
22 | n_latent: int = 5,
23 | n_hidden: int = 25,
24 | ):
25 | super().__init__()
26 | self.elu = nn.ELU()
27 | self.fc1 = nn.Linear(n_latent, n_hidden)
28 | self.fc2 = nn.Linear(n_hidden, n_latent)
29 |
30 | def forward(self, t: torch.Tensor, x: torch.Tensor):
31 | """
32 | Compute the gradient at a given time t and a given state x.
33 |
34 | Parameters
35 | ----------
36 | t
37 | A given time point.
38 | x
39 | A given latent state.
40 |
41 | Returns
42 | ----------
43 | :class:`torch.Tensor`
44 | A tensor
45 | """
46 | out = self.fc1(x)
47 | out = self.elu(out)
48 | out = self.fc2(out)
49 | return out
50 |
51 |
52 | class Encoder(nn.Module):
53 | """
54 | Encoder class generating the time and latent space.
55 |
56 | Parameters
57 | ----------
58 | n_int
59 | The dimensionality of the input.
60 | n_latent
61 | The dimensionality of the latent space.
62 | (Default: 5)
63 | n_hidden
64 | The dimensionality of the hidden layer.
65 | (Default: 128)
66 | batch_norm
67 | Whether to include `BatchNorm` layer or not.
68 | (Default: `False`)
69 | """
70 |
71 | def __init__(
72 | self,
73 | n_int: int,
74 | n_latent: int = 5,
75 | n_hidden: int = 128,
76 | batch_norm: bool = False,
77 | ):
78 | super().__init__()
79 | self.n_latent = n_latent
80 | self.fc = nn.Sequential()
81 | self.fc.add_module('L1', nn.Linear(n_int, n_hidden))
82 | if batch_norm:
83 | self.fc.add_module('N1', nn.BatchNorm1d(n_hidden))
84 | self.fc.add_module('A1', nn.ReLU())
85 | self.fc2 = nn.Linear(n_hidden, n_latent*2)
86 | self.fc3 = nn.Linear(n_hidden, 1)
87 |
88 | def forward(self, x:torch.Tensor) -> tuple:
89 | x = self.fc(x)
90 | out = self.fc2(x)
91 | qz_mean, qz_logvar = out[:, :self.n_latent], out[:, self.n_latent:]
92 | t = self.fc3(x).sigmoid()
93 | return t, qz_mean, qz_logvar
94 |
95 |
96 | class Decoder(nn.Module):
97 | """
98 | Decoder class to reconstruct the original input based on its latent space.
99 |
100 | Parameters
101 | ----------
102 | n_latent
103 | The dimensionality of the latent space.
104 | (Default: 5)
105 | n_int
106 | The dimensionality of the original input.
107 | n_hidden
108 | The dimensionality of the hidden layer.
109 | (Default: 128)
110 | batch_norm
111 | Whether to include `BatchNorm` layer or not.
112 | (Default: `False`)
113 | loss_mode
114 | The mode for reconstructing the original data.
115 | (Default: `'nb'`)
116 | """
117 |
118 | def __init__(
119 | self,
120 | n_int: int,
121 | n_latent: int = 5,
122 | n_hidden: int = 128,
123 | batch_norm: bool = False,
124 | loss_mode: Literal['mse', 'nb', 'zinb'] = 'nb',
125 | ):
126 | super().__init__()
127 | self.loss_mode = loss_mode
128 | if loss_mode in ['nb', 'zinb']:
129 | self.disp = nn.Parameter(torch.randn(n_int))
130 |
131 | self.fc = nn.Sequential()
132 | self.fc.add_module('L1', nn.Linear(n_latent, n_hidden))
133 | if batch_norm:
134 | self.fc.add_module('N1', nn.BatchNorm1d(n_hidden))
135 | self.fc.add_module('A1', nn.ReLU())
136 |
137 | if loss_mode == 'mse':
138 | self.fc2 = nn.Linear(n_hidden, n_int)
139 | if loss_mode in ['nb', 'zinb']:
140 | self.fc2 = nn.Sequential(nn.Linear(n_hidden, n_int), nn.Softmax(dim = -1))
141 | if loss_mode == 'zinb':
142 | self.fc3 = nn.Linear(n_hidden, n_int)
143 |
144 | def forward(self, z: torch.Tensor):
145 | out = self.fc(z)
146 | recon_x = self.fc2(out)
147 | if self.loss_mode == 'zinb':
148 | disp = self.fc3(out)
149 | return recon_x, disp
150 | else:
151 | return recon_x
152 |
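The encoder above emits one row of width `2 * n_latent` per cell, which `forward` splits into a mean and a log-variance, plus a sigmoid-squashed time value. The following is a minimal pure-Python sketch of that split and of the reparameterised draw `z = mean + eps * exp(0.5 * logvar)` that `_get_latentsp` in `train.py` applies to these outputs. The helper names (`split_encoder_output`, `sigmoid`, `sample_latent`) are hypothetical and not part of scTour; plain lists stand in for tensors.

```python
import math
import random


def split_encoder_output(out_row, n_latent):
    """Split one row of width 2 * n_latent into (mean, logvar) halves."""
    if len(out_row) != 2 * n_latent:
        raise ValueError("expected a row of width 2 * n_latent")
    return out_row[:n_latent], out_row[n_latent:]


def sigmoid(v):
    """Squash a raw value into (0, 1), mirroring `fc3(x).sigmoid()`."""
    return 1.0 / (1.0 + math.exp(-v))


def sample_latent(mean, logvar, rng):
    """Reparameterised draw: z = mean + eps * exp(0.5 * logvar), eps ~ N(0, 1)."""
    return [
        m + rng.gauss(0.0, 1.0) * math.exp(0.5 * lv)
        for m, lv in zip(mean, logvar)
    ]


# Illustrative values only: one cell, n_latent = 2.
mean, logvar = split_encoder_output([0.1, -0.2, 0.0, -1.0], n_latent=2)
z = sample_latent(mean, logvar, random.Random(0))
t = sigmoid(0.0)
```

Because the time head ends in a sigmoid, every inferred pseudotime lies strictly between 0 and 1, which is why the reversal elsewhere in the package can be written as `1 - t`.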
--------------------------------------------------------------------------------
/sctour/predict.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchdiffeq import odeint
3 | from typing import Optional, Union
4 | from typing_extensions import Literal
5 | import numpy as np
6 | from anndata import AnnData
7 | from scipy import sparse
8 | from scipy.sparse import spmatrix
9 | import os
10 |
11 | from ._utils import get_step_size
12 | from .train import Trainer
13 |
14 |
15 | def _check_data(
16 | adata1: AnnData,
17 | adata2: AnnData,
18 | loss_mode: str,
19 | ) -> Union[np.ndarray, spmatrix]:
20 | """
21 | Validate the query data and return the expression matrix expected by the model.
22 |
23 | Parameters
24 | ----------
25 | adata1
26 | An :class:`~anndata.AnnData` object for the query dataset.
27 | adata2
28 | An :class:`~anndata.AnnData` object for the training dataset.
29 | loss_mode
30 | The `loss_mode` used for model training.
31 |
32 | Returns
33 | ----------
34 | :class:`~numpy.ndarray` or :class:`~scipy.sparse.spmatrix`
35 | The log1p-scale expression matrix for the query dataset, restricted to the genes used for training.
36 | """
37 |
38 | if len(adata1.var_names.intersection(adata2.var_names)) != adata2.n_vars:
39 | raise ValueError(
40 | "The query AnnData must contain all the genes that are used for training in the training dataset."
41 | )
42 |
43 | X = adata1[:, adata2.var_names].X
44 | if loss_mode == 'mse':
45 | if (X.min() < 0) or (X.max() > np.log1p(1e6)):
46 | raise ValueError(
47 | "Invalid expression matrix in `.X`. Model trained from `mse` mode expects log1p(normalized expression) in `.X` of the query AnnData."
48 | )
49 | else:
50 | data = X.data if sparse.issparse(X) else X
51 | if (data.min() < 0) or np.any(~np.equal(np.mod(data, 1), 0)):
52 | raise ValueError(
53 | f"Invalid expression matrix in `.X`. Model trained from `{loss_mode}` mode expects raw UMI counts in `.X` of the query AnnData."
54 | )
55 | else:
56 | X = np.log1p(X)
57 |
58 | return X
59 |
60 |
61 | def load_model(model: str):
62 | """
63 | Load the trained scTour model for prediction.
64 |
65 | Parameters
66 | ----------
67 | model
68 | Filename for the scTour model trained and saved.
69 |
70 | Returns
71 | ----------
72 | :class:`sctour.train.Trainer`
73 | The trained scTour model.
74 | """
75 |
76 | if not os.path.isfile(model):
77 | raise FileNotFoundError(
78 | f'No such file: `{model}`.'
79 | )
80 |
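81 | checkpoint = torch.load(model, map_location=torch.device('cpu'), weights_only=False)
82 | # `weights_only=False` is required on recent PyTorch versions because the checkpoint also stores the AnnData object.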
83 | model_kwargs = checkpoint['model_kwargs']
84 | del model_kwargs['device']
85 | del model_kwargs['n_int']
86 | tnode = Trainer(
87 | adata = checkpoint['adata'],
88 | percent = checkpoint['percent'],
89 | nepoch = checkpoint['nepoch'],
90 | batch_size = checkpoint['batch_size'],
91 | drop_last = checkpoint['drop_last'],
92 | lr = checkpoint['lr'],
93 | wt_decay = checkpoint['wt_decay'],
94 | eps = checkpoint['eps'],
95 | random_state = checkpoint['random_state'],
96 | val_frac = checkpoint['val_frac'],
97 | use_gpu = checkpoint['use_gpu'],
98 | **model_kwargs,
99 | )
100 | tnode.model.load_state_dict(checkpoint['model_state_dict'])
101 | tnode.time_reverse = checkpoint['time_reverse']
102 | return tnode
103 |
104 |
105 | def predict_time(
106 | model: Trainer,
107 | adata: AnnData,
108 | reverse: bool = False,
109 | ) -> np.ndarray:
110 | """
111 | Predict the pseudotime for query cells.
112 |
113 | Parameters
114 | ----------
115 | model
116 | A :class:`sctour.train.Trainer` for trained scTour model.
117 | adata
118 | An :class:`~anndata.AnnData` object for the query dataset.
119 | reverse
120 | Whether to reverse the predicted pseudotime. Set this to `True` if the pseudotime returned by `get_time()` for the training data was in reverse order and you corrected it with the post-inference adjustment (`reverse_time()`).
121 | (Default: `False`)
122 |
123 | Returns
124 | ----------
125 | :class:`~numpy.ndarray`
126 | The pseudotime predicted for the query cells.
127 | """
128 |
129 | if model.time_reverse is None:
130 | raise RuntimeError(
131 | '`get_time()` has not been run on the training data. Please run `get_time()` after model training before calling `predict_time()` on the query data.'
132 | )
133 |
134 | X = _check_data(adata, model.adata, model.loss_mode)
135 | ts = model._get_time(model = model.model, X = X)
136 | if model.time_reverse:
137 | ts = 1 - ts
138 | if reverse:
139 | ts = 1 - ts
140 | return ts.cpu().numpy()
141 |
142 |
143 | def predict_vector_field(
144 | model: Trainer,
145 | T: np.ndarray,
146 | Z: np.ndarray,
147 | ) -> np.ndarray:
148 | """
149 | Predict the vector field for query cells.
150 |
151 | Parameters
152 | ----------
153 | model
154 | A :class:`sctour.train.Trainer` for trained scTour model.
155 | T
156 | The predicted pseudotime for query cells.
157 | Z
158 | The predicted latent representations for query cells.
159 |
160 | Returns
161 | ----------
162 | :class:`~numpy.ndarray`
163 | The vector field predicted for query cells.
164 | """
165 |
166 | vf = model._get_vector_field(
167 | model = model.model,
168 | T = T,
169 | Z = Z,
170 | time_reverse = model.time_reverse,
171 | )
172 | return vf
173 |
174 |
175 | def predict_latentsp(
176 | model: Trainer,
177 | adata: AnnData,
178 | mode: Literal['coarse', 'fine'] = 'fine',
179 | alpha_z: float = .5,
180 | alpha_predz: float = .5,
181 | step_size: Optional[float] = None,
182 | step_wise: bool = False,
183 | batch_size: Optional[int] = None,
184 | ) -> tuple:
185 | """
186 | Predict the latent representations for query cells given their transcriptomes.
187 |
188 | Parameters
189 | ----------
190 | model
191 | A :class:`sctour.train.Trainer` for trained scTour model.
192 | adata
193 | An :class:`~anndata.AnnData` object for the query dataset.
194 | mode
195 | The mode for deriving the latent space for the query dataset. (Default: `'fine'`)
196 | Two modes are included:
197 | ``'fine'``: derive the latent space by taking the training data into consideration;
198 | ``'coarse'``: derive the latent space directly from the query data without involving the training data.
199 | alpha_z
200 | Scaling factor for encoder-derived latent space.
201 | (Default: 0.5)
202 | alpha_predz
203 | Scaling factor for ODE-solver-derived latent space.
204 | (Default: 0.5)
205 | step_size
206 | The step size during integration.
207 | step_wise
208 | Whether to perform step-wise integration by iteratively considering only two time points each time.
209 | (Default: `False`)
210 | batch_size
211 | Batch size when deriving the latent space. The default is no mini-batching.
212 |
213 | Returns
214 | ----------
215 | tuple
216 | 3-tuple of weighted combined latent space, encoder-derived latent space, and ODE-solver-derived latent space.
217 | """
218 |
219 | X = _check_data(adata, model.adata, model.loss_mode)
220 | if mode == 'coarse':
221 | mix_zs, zs, pred_zs = model._get_latentsp(
222 | model = model.model,
223 | X = X,
224 | alpha_z = alpha_z,
225 | alpha_predz = alpha_predz,
226 | step_size = step_size,
227 | step_wise = step_wise,
228 | batch_size = batch_size,
229 | )
230 | if mode == 'fine':
231 | X2 = model.adata.X
232 | if model.loss_mode in ['nb', 'zinb']:
233 | X2 = np.log1p(X2)
234 | if sparse.issparse(X2):
235 | X2 = X2.toarray()
236 | if sparse.issparse(X):
237 | X = X.toarray()
238 | mix_zs, zs, pred_zs = model._get_latentsp(
239 | model = model.model,
240 | X = np.vstack((X, X2)),
241 | alpha_z = alpha_z,
242 | alpha_predz = alpha_predz,
243 | step_size = step_size,
244 | step_wise = step_wise,
245 | batch_size = batch_size,
246 | )
247 | mix_zs = mix_zs[:len(X)]
248 | zs = zs[:len(X)]
249 | pred_zs = pred_zs[:len(X)]
250 |
251 | return mix_zs, zs, pred_zs
252 |
253 |
254 | @torch.no_grad()
255 | def predict_ltsp_from_time(
256 | model: Trainer,
257 | T: np.ndarray,
258 | reverse: bool = False,
259 | step_wise: bool = True,
260 | step_size: Optional[float] = None,
261 | alpha_z: float = 0.5,
262 | alpha_predz: float = 0.5,
263 | k: int = 20,
264 | ) -> np.ndarray:
265 | """
266 | Predict the transcriptomic latent space for query (unobserved) time intervals.
267 |
268 | Parameters
269 | ----------
270 | model
271 | A :class:`sctour.train.Trainer` for trained scTour model.
272 | T
273 | A 1D numpy array containing the query time points (with values between 0 and 1). The latent space corresponding to these time points will be predicted.
274 | reverse
275 | Whether to reverse the reference pseudotime. Set this to `True` if the pseudotime returned by `get_time()` for the training data was in reverse order and you corrected it with the post-inference adjustment (`reverse_time()`).
276 | (Default: `False`)
277 | step_wise
278 | Whether to perform step-wise integration by iteratively considering only two time points when inferring the reference latent space from the training data.
279 | (Default: `True`)
280 | step_size
281 | The step size during integration.
282 | alpha_z
283 | Scaling factor for encoder-derived latent space.
284 | (Default: 0.5)
285 | alpha_predz
286 | Scaling factor for ODE-solver-derived latent space.
287 | (Default: 0.5)
288 | k
289 | The k nearest neighbors in the time space considered when predicting the latent representation for each query time point.
290 | (Default: 20)
291 |
292 | Returns
293 | ----------
294 | :class:`~numpy.ndarray`
295 | Predicted latent space corresponding to the query time interval.
296 | """
297 |
298 | mdl = model.model
299 |
300 | if not isinstance(T, np.ndarray):
301 | raise TypeError(
302 | "The input time interval must be a numpy array."
303 | )
304 | if len(T.shape) > 1:
305 | raise TypeError(
306 | "The input time interval must be a 1D numpy array."
307 | )
308 | if np.any(T < 0) or np.any(T > 1):
309 | raise ValueError(
310 | "The input time points must be in [0, 1]."
311 | )
312 |
313 | ridx = np.random.permutation(len(T))
314 | rT = torch.tensor(T[ridx])
315 | ## get the reference time and latent space from the training data
316 | X = model.adata.X
317 | if model.loss_mode in ['nb', 'zinb']:
318 | X = np.log1p(X)
319 | mix_zs, zs, pred_zs = model._get_latentsp(model = mdl,
320 | X = X,
321 | alpha_z = alpha_z,
322 | alpha_predz = alpha_predz,
323 | step_wise = step_wise,
324 | step_size = step_size,
325 | )
326 | ts = model._get_time(model = mdl, X = X)
327 | if model.time_reverse:
328 | ts = 1 - ts
329 | if reverse:
330 | ts = 1 - ts
331 |
332 | ts = ts.cpu()
333 | zs = torch.tensor(mix_zs)
334 |
335 | pred_T_zs = torch.empty((len(rT), mdl.n_latent))
336 | for i, t in enumerate(rT):
337 | diff = torch.abs(t - ts)
338 | idxs = torch.argsort(diff)
339 | # n = (diff == 0).sum()
340 | # idxs = idxs[n:(k + n)]
341 | if (diff == 0).any():
342 | pred_T_zs[i] = zs[idxs[0]].clone()
343 | else:
344 | idxs = idxs[:k]
345 | k_zs = torch.empty((k, mdl.n_latent))
346 | for j, idx in enumerate(idxs):
347 | z0 = zs[idx].clone()
348 | t0 = ts[idx].clone()
349 | pred_t = torch.stack((t0, t))
350 | if pred_t[0] < pred_t[1]:
351 | options = get_step_size(step_size, pred_t[0], pred_t[-1], len(pred_t))
352 | else:
353 | options = get_step_size(step_size, pred_t[-1], pred_t[0], len(pred_t))
354 | k_zs[j] = odeint(
355 | mdl.lode_func,
356 | z0,
357 | pred_t,
358 | method = mdl.ode_method,
359 | options = options
360 | )[1]
361 | k_zs = torch.mean(k_zs, dim = 0)
362 | pred_T_zs[i] = k_zs
363 | ts = torch.cat((ts, t.unsqueeze(0)))
364 | zs = torch.cat((zs, k_zs.unsqueeze(0)))
365 |
366 | pred_T_zs = pred_T_zs[np.argsort(ridx)]
367 | return pred_T_zs.numpy()
368 |
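Both `predict_time` and `predict_ltsp_from_time` orient pseudotime with two independent flips: one driven by the stored `model.time_reverse` flag and one by the caller's `reverse` argument. A minimal sketch of that logic follows, assuming plain Python lists in place of tensors. The function name `orient_time` is hypothetical and not part of scTour.

```python
def orient_time(ts, time_reverse, reverse=False):
    """Apply the two optional flips to pseudotime values in [0, 1].

    time_reverse: flag inferred automatically by `get_time()` on the training data.
    reverse: manual post-inference adjustment requested by the caller.
    """
    if time_reverse:
        ts = [1 - t for t in ts]
    if reverse:
        ts = [1 - t for t in ts]
    return ts


# Illustrative values only.
raw = [0.0, 0.25, 1.0]
flipped_once = orient_time(raw, time_reverse=True)
flipped_twice = orient_time(raw, time_reverse=True, reverse=True)
```

The two flips cancel when both flags are set, so `reverse=True` undoes the automatic orientation rather than adding a second, unrelated transformation.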
--------------------------------------------------------------------------------
/sctour/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from torchdiffeq import odeint
4 | from typing import Optional, Union
5 | from typing_extensions import Literal
6 | import numpy as np
7 | from anndata import AnnData
8 | from scipy import sparse
9 | from scipy.sparse import spmatrix
10 | from tqdm import tqdm
11 | import os
12 | from collections import defaultdict
13 |
14 | from .model import TNODE
15 | from ._utils import get_step_size
16 | from .data import split_data, MakeDataset, BatchSampler
17 | from . import logger
18 |
19 |
20 | ##reverse time
21 | def reverse_time(
22 | T: np.ndarray,
23 | ) -> np.ndarray:
24 | """
25 | Post-inference adjustment to reverse the pseudotime.
26 |
27 | Parameters
28 | ----------
29 | T
30 | The pseudotime inferred for each cell.
31 |
32 | Returns
33 | ----------
34 | :class:`~numpy.ndarray`
35 | The reversed pseudotime.
36 | """
37 |
38 | return 1 - T
39 |
40 |
41 | class Trainer:
42 | """
43 | Class for implementing the scTour training process.
44 |
45 | Parameters
46 | ----------
47 | adata
48 | An :class:`~anndata.AnnData` object for the training data.
49 | percent
50 | The fraction of cells used for model training. Defaults to 0.2 when the number of cells exceeds 10,000 and to 0.9 otherwise.
51 | n_latent
52 | The dimensionality of the latent space.
53 | (Default: 5)
54 | n_ode_hidden
55 | The dimensionality of the hidden layer for the latent ODE function.
56 | (Default: 25)
57 | n_vae_hidden
58 | The dimensionality of the hidden layer for the VAE.
59 | (Default: 128)
60 | batch_norm
61 | Whether to include a `BatchNorm` layer.
62 | (Default: `False`)
63 | ode_method
64 | The ODE solver. The available solvers are listed in the `torchdiffeq` documentation.
65 | (Default: `'euler'`)
66 | step_size
67 | The step size during integration.
68 | alpha_recon_lec
69 | The scaling factor for the reconstruction error from encoder-derived latent space.
70 | (Default: 0.5)
71 | alpha_recon_lode
72 | The scaling factor for the reconstruction error from ODE-solver-derived latent space.
73 | (Default: 0.5)
74 | alpha_kl
75 | The scaling factor for the KL divergence in the loss function.
76 | (Default: 1.0)
77 | loss_mode
78 | The mode for calculating the reconstruction error.
79 | (Default: `'nb'`)
80 | Three modes are included:
81 | ``'mse'``: mean squared error;
82 | ``'nb'``: negative binomial conditioned likelihood;
83 | ``'zinb'``: zero-inflated negative binomial conditioned likelihood.
84 | nepoch
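85 | Number of epochs. Defaults to `min(round(10000 / n * 400), 400)`, where `n = round(n_cells * percent)` is the number of cells used for training.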
86 | batch_size
87 | The batch size during training.
88 | (Default: 1024)
89 | drop_last
90 | Whether to drop the last batch when it is smaller than `batch_size`.
91 | (Default: `False`)
92 | lr
93 | The learning rate.
94 | (Default: 1e-3)
95 | wt_decay
96 | The weight decay (L2 penalty) for Adam optimizer.
97 | (Default: 1e-6)
98 | eps
99 | The `eps` parameter for Adam optimizer.
100 | (Default: 0.01)
101 | random_state
102 | The seed for generating random numbers.
103 | (Default: 0)
104 | val_frac
105 | The fraction of data used for validation.
106 | (Default: 0.1)
107 | use_gpu
108 | Whether to use GPU when available.
109 | (Default: `True`)
110 | """
111 |
112 | def __init__(
113 | self,
114 | adata: AnnData,
115 | percent: Optional[float] = None,
116 | n_latent: int = 5,
117 | n_ode_hidden: int = 25,
118 | n_vae_hidden: int = 128,
119 | batch_norm: bool = False,
120 | ode_method: str = 'euler',
121 | step_size: Optional[float] = None,
122 | alpha_recon_lec: float = 0.5,
123 | alpha_recon_lode: float = 0.5,
124 | alpha_kl: float = 1.,
125 | loss_mode: Literal['mse', 'nb', 'zinb'] = 'nb',
126 | nepoch: Optional[int] = None,
127 | batch_size: int = 1024,
128 | drop_last: bool = False,
129 | lr: float = 1e-3,
130 | wt_decay: float = 1e-6,
131 | eps: float = 0.01,
132 | random_state: int = 0,
133 | val_frac: float = 0.1,
134 | use_gpu: bool = True,
135 | ):
136 | self.loss_mode = loss_mode
137 | if self.loss_mode not in ['mse', 'nb', 'zinb']:
138 | raise ValueError(
139 | f"`loss_mode` must be one of ['mse', 'nb', 'zinb'], but input was '{self.loss_mode}'."
140 | )
141 |
142 | if (alpha_recon_lec < 0) or (alpha_recon_lec > 1):
143 | raise ValueError(
144 | '`alpha_recon_lec` must be between 0 and 1.'
145 | )
146 | if (alpha_recon_lode < 0) or (alpha_recon_lode > 1):
147 | raise ValueError(
148 | '`alpha_recon_lode` must be between 0 and 1.'
149 | )
150 | if not np.isclose(alpha_recon_lec + alpha_recon_lode, 1):
151 | raise ValueError(
152 | 'The sum of `alpha_recon_lec` and `alpha_recon_lode` must be 1.'
153 | )
154 |
155 | self.adata = adata
156 | if 'n_genes_by_counts' not in self.adata.obs:
157 | raise KeyError(
158 | "`n_genes_by_counts` not found in `.obs` of the AnnData. Please run `scanpy.pp.calculate_qc_metrics` first to calculate the number of genes detected in each cell."
159 | )
160 | if loss_mode == 'mse':
161 | if (self.adata.X.min() < 0) or (self.adata.X.max() > np.log1p(1e6)):
162 | raise ValueError(
163 | "Invalid expression matrix in `.X`. `mse` mode expects log1p(normalized expression) in `.X` of the AnnData."
164 | )
165 | else:
166 | X = self.adata.X.data if sparse.issparse(self.adata.X) else self.adata.X
167 | if (X.min() < 0) or np.any(~np.equal(np.mod(X, 1), 0)):
168 | raise ValueError(
169 | f"Invalid expression matrix in `.X`. `{self.loss_mode}` mode expects raw UMI counts in `.X` of the AnnData."
170 | )
171 |
172 | self.n_cells = adata.n_obs
173 | self.batch_size = batch_size
174 | self.drop_last = drop_last
175 | self.percent = percent
176 | if self.percent is None:
177 | if self.n_cells > 10000:
178 | self.percent = .2
179 | else:
180 | self.percent = .9
181 | else:
182 | if (self.percent < 0) or (self.percent > 1):
183 | raise ValueError(
184 | "`percent` must be between 0 and 1."
185 | )
186 | self.val_frac = val_frac
187 | if (self.val_frac < 0) or (self.val_frac > 1):
188 | raise ValueError(
189 | '`val_frac` must be between 0 and 1.'
190 | )
191 |
192 | if nepoch is None:
193 | ncells = round(self.n_cells * self.percent)
194 | self.nepoch = np.min([round((10000 / ncells) * 400), 400])
195 | else:
196 | self.nepoch = nepoch
197 |
198 | self.lr = lr
199 | self.wt_decay = wt_decay
200 | self.eps = eps
201 | self.time_reverse = None
202 |
203 | self.random_state = random_state
204 | np.random.seed(random_state)
205 | # random.seed(random_state)
206 | torch.manual_seed(random_state)
207 | # torch.backends.cudnn.benchmark = False
208 | # torch.use_deterministic_algorithms(True)
209 |
210 | self.use_gpu = use_gpu
211 | gpu = torch.cuda.is_available() and use_gpu
212 | if gpu:
213 | torch.cuda.manual_seed(random_state)
214 | self.device = torch.device('cuda')
215 | logger.info('Running using GPU.')
216 | else:
217 | self.device = torch.device('cpu')
218 | logger.info('Running using CPU.')
219 |
220 | self.n_int = adata.n_vars
221 | self.model_kwargs = dict(
222 | device = self.device,
223 | n_int = self.n_int,
224 | n_latent = n_latent,
225 | n_ode_hidden = n_ode_hidden,
226 | n_vae_hidden = n_vae_hidden,
227 | batch_norm = batch_norm,
228 | ode_method = ode_method,
229 | step_size = step_size,
230 | alpha_recon_lec = alpha_recon_lec,
231 | alpha_recon_lode = alpha_recon_lode,
232 | alpha_kl = alpha_kl,
233 | loss_mode = loss_mode,
234 | )
235 | self.model = TNODE(**self.model_kwargs)
236 | self.log = defaultdict(list)
237 |
238 |
239 | def _get_data_loaders(self) -> None:
240 | """
241 | Generate Data Loaders for training and validation datasets.
242 | """
243 |
244 | train_data, val_data = split_data(self.adata, self.percent, self.val_frac)
245 | self.train_dataset = MakeDataset(train_data, self.loss_mode)
246 | self.val_dataset = MakeDataset(val_data, self.loss_mode)
247 |
248 | # sampler = BatchSampler(train_data.n_obs, self.batch_size, self.drop_last)
249 | # self.train_dl = DataLoader(self.train_dataset, batch_sampler = sampler)
250 | self.train_dl = DataLoader(self.train_dataset, batch_size = self.batch_size, shuffle = True, drop_last = self.drop_last)
251 | self.val_dl = DataLoader(self.val_dataset, batch_size = self.batch_size)
252 |
253 |
254 | def train(self):
255 | """
256 | Model training.
257 | """
258 | self._get_data_loaders()
259 |
260 | params = filter(lambda p: p.requires_grad, self.model.parameters())
261 | self.optimizer = torch.optim.Adam(params, lr = self.lr, weight_decay = self.wt_decay, eps = self.eps)
262 |
263 | with tqdm(total=self.nepoch, unit='epoch') as t:
264 | for tepoch in range(t.total):
265 | train_loss = self._on_epoch_train(self.train_dl)
266 | val_loss = self._on_epoch_val(self.val_dl)
267 | self.log['train_loss'].append(train_loss)
268 | self.log['validation_loss'].append(val_loss)
269 | t.set_description(f"Epoch {tepoch + 1}")
270 | t.set_postfix({'train_loss': train_loss, 'val_loss': val_loss}, refresh=False)
271 | t.update()
272 |
273 |
274 | def _on_epoch_train(self, DL) -> float:
275 | """
276 | Go through the model and update the model parameters.
277 |
278 | Parameters
279 | ----------
280 | DL
281 | DataLoader for training dataset.
282 |
283 | Returns
284 | ----------
285 | float
286 | Training loss for the current epoch.
287 | """
288 |
289 | self.model.train()
290 | total_loss = .0
291 | ss = 0
292 | for X, Y in DL:
293 | self.optimizer.zero_grad()
294 | X = X.to(self.device)
295 | Y = Y.to(self.device)
296 | loss, recon_loss_ec, recon_loss_ode, kl_div, z_div = self.model(X, Y)
297 | loss.backward()
298 | self.optimizer.step()
299 |
300 | total_loss += loss.item() * X.size(0)
301 | ss += X.size(0)
302 |
303 | train_loss = total_loss/ss
304 | return train_loss
305 |
306 |
307 | @torch.no_grad()
308 | def _on_epoch_val(self, DL) -> float:
309 | """
310 | Validate using validation dataset.
311 |
312 | Parameters
313 | ----------
314 | DL
315 | DataLoader for validation dataset.
316 |
317 | Returns
318 | ----------
319 | float
320 | Validation loss for the current epoch.
321 | """
322 |
323 | self.model.eval()
324 | total_loss = .0
325 | ss = 0
326 | for X, Y in DL:
327 | X = X.to(self.device)
328 | Y = Y.to(self.device)
329 | loss, recon_loss_ec, recon_loss_ode, kl_div, z_div = self.model(X, Y)
330 | total_loss += loss.item() * X.size(0)
331 | ss += X.size(0)
332 |
333 | val_loss = total_loss/ss
334 | return val_loss
335 |
336 |
337 | def get_time(
338 | self,
339 | ) -> np.ndarray:
340 | """
341 | Infer the developmental pseudotime.
342 |
343 | Returns
344 | ----------
345 | :class:`~numpy.ndarray`
346 | The pseudotime inferred for each cell.
347 | """
348 |
349 | X = self.adata.X
350 | if self.loss_mode in ['nb', 'zinb']:
351 | X = np.log1p(X)
352 | ts = self._get_time(self.model, X)
353 |
354 | ## The model might return pseudotime in reverse order. Check this based on number of genes expressed in each cell.
355 | if self.time_reverse is None:
356 | n_genes = torch.tensor(self.adata.obs['n_genes_by_counts'].values).float().log1p().to(self.device)
357 | m_ts = ts.mean()
358 | m_ngenes = n_genes.mean()
359 | beta_direction = (ts * n_genes).sum() - len(ts) * m_ts * m_ngenes
360 | if beta_direction > 0:
361 | self.time_reverse = True
362 | else:
363 | self.time_reverse = False
364 | if self.time_reverse:
365 | ts = 1 - ts
366 |
367 | return ts.cpu().numpy()
368 |
369 |
370 | def get_vector_field(
371 | self,
372 | T: np.ndarray,
373 | Z: np.ndarray,
374 | ) -> np.ndarray:
375 | """
376 | Infer the vector field.
377 |
378 | Parameters
379 | ----------
380 | T
381 | The pseudotime estimated for each cell.
382 | Z
383 | The latent representation for each cell.
384 |
385 | Returns
386 | ----------
387 | :class:`~numpy.ndarray`
388 | The estimated vector field.
389 | """
390 |
391 | vf = self._get_vector_field(
392 | self.model,
393 | T,
394 | Z,
395 | self.time_reverse,
396 | )
397 | return vf
398 |
399 |
400 | def get_latentsp(
401 | self,
402 | alpha_z: float = .5,
403 | alpha_predz: float = .5,
404 | step_size: Optional[float] = None,
405 | step_wise: bool = False,
406 | batch_size: Optional[int] = None,
407 | ) -> tuple:
408 | """
409 | Infer the latent space.
410 |
411 | Parameters
412 | ----------
413 | alpha_z
414 | Scaling factor for encoder-derived latent space.
415 | (Default: 0.5)
416 | alpha_predz
417 | Scaling factor for ODE-solver-derived latent space.
418 | (Default: 0.5)
419 | step_size
420 | Step size during integration.
421 | step_wise
422 | Whether to perform step-wise integration by iteratively considering only two time points each time.
423 | (Default: `False`)
424 | batch_size
425 | Batch size when deriving the latent space. The default is no mini-batching.
426 |
427 | Returns
428 | ----------
429 | tuple
430 | 3-tuple of weighted combined latent space, encoder-derived latent space, and ODE-solver-derived latent space.
431 | """
432 |
433 | X = self.adata.X
434 | if self.model.loss_mode in ['nb', 'zinb']:
435 | X = np.log1p(X)
436 | mix_zs, zs, pred_zs = self._get_latentsp(self.model,
437 | X,
438 | alpha_z,
439 | alpha_predz,
440 | step_size,
441 | step_wise,
442 | batch_size,
443 | )
444 | return mix_zs, zs, pred_zs
445 |
446 |
447 | def save_model(
448 | self,
449 | save_dir: str,
450 | save_prefix: str,
451 | ) -> None:
452 | """
453 | Save the trained scTour model.
454 |
455 | Parameters
456 | ----------
457 | save_dir
458 | The directory where the model will be saved.
459 | save_prefix
460 | The prefix for the model file name. The model is saved to `save_dir/save_prefix.pth`.
461 | """
462 |
463 | save_path = os.path.abspath(os.path.join(save_dir, f'{save_prefix}.pth'))
464 | # save_path = os.path.abspath(os.path.join(save_dir, f'{save_prefix}.tar'))
465 | torch.save(
466 | {
467 | 'model_state_dict': self.model.state_dict(),
468 | 'optimizer_state_dict': self.optimizer.state_dict(),
469 | 'model_kwargs': self.model_kwargs,
470 | 'time_reverse': self.time_reverse,
471 | 'adata': self.adata,
472 | 'percent': self.percent,
473 | 'nepoch': self.nepoch,
474 | 'batch_size': self.batch_size,
475 | 'random_state': self.random_state,
476 | 'drop_last': self.drop_last,
477 | 'lr': self.lr,
478 | 'wt_decay': self.wt_decay,
479 | 'eps': self.eps,
480 | 'val_frac': self.val_frac,
481 | 'use_gpu': self.use_gpu,
482 | },
483 | save_path
484 | )
485 |
486 |
487 | @staticmethod
488 | @torch.no_grad()
489 | def _get_time(
490 | model: TNODE,
491 | X: Union[np.ndarray, spmatrix],
492 | ) -> torch.Tensor:
493 | """
494 | Derive the developmental pseudotime for cells.
495 |
496 | Parameters
497 | ----------
498 | model
499 | The trained scTour model.
500 | X
501 | The data matrix.
502 |
503 | Returns
504 | ----------
505 | :class:`torch.Tensor`
506 | The pseudotime estimated for each cell.
507 | """
508 |
509 | model.eval()
510 | if sparse.issparse(X):
511 | X = X.toarray()
512 | X = torch.tensor(X).to(model.device)
513 | ts, _, _ = model.encoder(X)
514 | ts = ts.ravel()
515 | return ts
516 |
517 |
518 | @staticmethod
519 | @torch.no_grad()
520 | def _get_vector_field(
521 | model: TNODE,
522 | T: np.ndarray,
523 | Z: np.ndarray,
524 | time_reverse: bool,
525 | ) -> np.ndarray:
526 | """
527 | Derive the vector field for cells.
528 |
529 | Parameters
530 | ----------
531 | model
532 | The trained scTour model.
533 | T
534 | The pseudotime for each cell.
535 | Z
536 | The latent representation for each cell.
537 | time_reverse
538 | Whether to reverse the vector field.
539 |
540 | Returns
541 | ----------
542 | :class:`~numpy.ndarray`
543 | The estimated vector field.
544 | """
545 |
546 | model.eval()
547 | if not (isinstance(T, np.ndarray) and isinstance(Z, np.ndarray)):
548 | raise TypeError(
549 | 'The inputs must be numpy arrays.'
550 | )
551 | Z = torch.tensor(Z)
552 | T = torch.tensor(T)
553 | if time_reverse is None:
554 | raise RuntimeError(
555 | 'It seems you did not run `get_time()` function first after model training.'
556 | )
557 | direction = 1
558 | if time_reverse:
559 | direction = -1
560 | return direction * model.lode_func(T, Z).numpy()
561 |
562 |
563 | @staticmethod
564 | @torch.no_grad()
565 | def _get_latentsp(
566 | model: TNODE,
567 | X: Union[np.ndarray, spmatrix],
568 | alpha_z: float = .5,
569 | alpha_predz: float = .5,
570 | step_size: Optional[float] = None,
571 | step_wise: bool = False,
572 | batch_size: Optional[int] = None,
573 | ):
574 | """
575 | Derive the latent representations of cells.
576 |
577 | Parameters
578 | ----------
579 | model
580 | The trained scTour model.
581 | X
582 | The data matrix.
583 | alpha_z
584 | Scaling factor for encoder-derived latent space.
585 | (Default: 0.5)
586 | alpha_predz
587 | Scaling factor for ODE-solver-derived latent space.
588 | (Default: 0.5)
589 | step_size
590 | Step size during integration.
591 | step_wise
592 | Whether to perform step-wise integration by iteratively considering only two time points each time.
593 | (Default: `False`)
594 | batch_size
595 | Batch size when deriving the latent space. The default is no mini-batching.
596 |
597 | Returns
598 | ----------
599 | tuple
600 | 3-tuple of weighted combined latent space, encoder-derived latent space, and ODE-solver-derived latent space.
601 | """
602 |
603 | model.eval()
604 |
605 | if (alpha_z < 0) or (alpha_z > 1):
606 | raise ValueError(
607 | '`alpha_z` must be between 0 and 1.'
608 | )
609 | if (alpha_predz < 0) or (alpha_predz > 1):
610 | raise ValueError(
611 | '`alpha_predz` must be between 0 and 1.'
612 | )
613 | if not np.isclose(alpha_z + alpha_predz, 1):
614 | raise ValueError(
615 | 'The sum of `alpha_z` and `alpha_predz` must be 1.'
616 | )
617 |
618 | if sparse.issparse(X):
619 | X = X.toarray()
620 | X = torch.tensor(X).to(model.device)
621 | T, qz_mean, qz_logvar = model.encoder(X)
622 | T = T.ravel().cpu()
623 | epsilon = torch.randn(qz_mean.size())
624 | zs = epsilon * torch.exp(.5 * qz_logvar.cpu()) + qz_mean.cpu()
625 |
626 | sort_T, sort_idx, sort_ridx = np.unique(T, return_index=True, return_inverse=True)
627 | sort_T = torch.tensor(sort_T)
628 | sort_zs = zs[sort_idx]
629 |
630 | pred_zs = []
631 | if batch_size is None:
632 | batch_size = len(sort_T)
633 | times = int(np.ceil(len(sort_T) / batch_size))
634 | for i in range(times):
635 | idx1 = i * batch_size
636 | idx2 = np.min([(i + 1)*batch_size, len(sort_T)])
637 | t = sort_T[idx1:idx2]
638 | z = sort_zs[idx1:idx2]
639 | z0 = z[0]
640 |
641 | if not step_wise:
642 | options = get_step_size(step_size, t[0], t[-1], len(t))
643 | pred_z = odeint(
644 | model.lode_func,
645 | z0,
646 | t,
647 | method = model.ode_method,
648 | options = options
649 | ).view(-1, model.n_latent)
650 | else:
651 | pred_z = torch.empty((len(t), z.size(1)))
652 | pred_z[0] = z0
653 | for j in range(len(t) - 1):
654 | t2 = t[j:(j + 2)]
655 | options = get_step_size(step_size, t2[0], t2[-1], len(t2))
656 | pred_z[j + 1] = odeint(
657 | model.lode_func,
658 | z[j],
659 | t2,
660 | method = model.ode_method,
661 | options = options
662 | )[1]
663 |
664 | pred_zs += [pred_z]
665 |
666 | pred_zs = torch.cat(pred_zs)
667 | pred_zs = pred_zs[sort_ridx]
668 | mix_zs = alpha_z * zs + alpha_predz * pred_zs
669 |
670 | return mix_zs.numpy(), zs.numpy(), pred_zs.numpy()
671 |
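`get_time` decides whether to flip the inferred pseudotime from the sign of `(ts * n_genes).sum() - len(ts) * m_ts * m_ngenes`, which is the (unnormalised) covariance between pseudotime and the log1p-transformed number of detected genes per cell. A minimal pure-Python sketch of that check follows. The function name `infer_time_reverse` is hypothetical, and lists stand in for the tensors used in the method.

```python
import math


def infer_time_reverse(ts, n_genes_by_counts):
    """Return True when pseudotime increases with the number of detected genes.

    Mirrors the sign test in `Trainer.get_time`: a positive covariance between
    pseudotime and log1p(gene count) marks the time axis as reversed.
    """
    log_genes = [math.log1p(g) for g in n_genes_by_counts]
    n = len(ts)
    mean_t = sum(ts) / n
    mean_g = sum(log_genes) / n
    beta_direction = sum(t * g for t, g in zip(ts, log_genes)) - n * mean_t * mean_g
    return beta_direction > 0


# Illustrative values only: three cells with increasing gene counts.
needs_flip = infer_time_reverse([0.1, 0.5, 0.9], [1000, 2000, 3000])
already_ok = infer_time_reverse([0.9, 0.5, 0.1], [1000, 2000, 3000])
```

When the check returns `True`, the method stores `time_reverse = True` and returns `1 - ts`, so cells with more detected genes end up earlier in pseudotime.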
--------------------------------------------------------------------------------
/sctour/vector_field.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scanpy as sc
3 | from scipy.sparse import coo_matrix, csr_matrix
4 | from scipy.stats import norm
5 | from sklearn.neighbors import NearestNeighbors
6 | import matplotlib.pyplot as plt
7 | from matplotlib.axes import Axes
8 | from anndata import AnnData
9 | from typing import Optional, Union
10 |
11 | from ._utils import l2_norm
12 | from . import logger
13 |
14 |
15 | def cosine_similarity(
16 | adata: AnnData,
17 | zs_key: str,
18 | reverse: bool = False,
19 | use_rep_neigh: Optional[str] = None,
20 | vf_key: str = 'X_VF',
21 | run_neigh: bool = True,
22 | n_neigh: int = 20,
23 | t_key: Optional[str] = None,
24 | var_stabilize_transform: bool = False,
25 | ) -> csr_matrix:
26 | """
27 | Calculate the cosine similarity between the vector field and the cell-neighbor latent state difference for each cell.
28 | The calculation borrows the ideas from scvelo: https://github.com/theislab/scvelo/blob/master/scvelo/tools/velocity_graph.py.
29 |
30 | Parameters
31 | ----------
32 | adata
33 | An :class:`~anndata.AnnData` object.
34 | reverse
35 |         Whether to reverse the direction of the vector field. Set this to `True` if the pseudotime returned by `get_time()` was in reverse order and you applied the post-inference adjustment (`reverse_time()`).
36 | (Default: `False`)
37 | zs_key
38 | The key in `.obsm` for storing the latent space.
39 | vf_key
40 | The key in `.obsm` for storing the vector field.
41 | (Default: `'X_VF'`)
42 | run_neigh
43 | Whether to run neighbor detection.
44 | (Default: `True`)
45 | use_rep_neigh
46 |         The representation in `.obsm` which will be used for neighbor detection. If `None`, `zs_key` is used. (Default: `None`)
47 | n_neigh
48 | The number of neighbors considered for each cell.
49 | (Default: 20)
50 |     t_key
51 | The key in `.obs` for estimated pseudotime which will be considered when detecting neighbors.
52 | var_stabilize_transform
53 | Whether to perform variance-stabilizing transformation for vector field and cell-neighbor latent state difference.
54 | (Default: `False`)
55 |
56 | Returns
57 | ----------
58 | :class:`~scipy.sparse.csr_matrix`
59 | A sparse matrix with cosine similarities.
60 | """
61 |
62 | Z = np.array(adata.obsm[zs_key])
63 | V = np.array(adata.obsm[vf_key])
64 | if reverse:
65 | V = -V
66 | if var_stabilize_transform:
67 | V = np.sqrt(np.abs(V)) * np.sign(V)
68 |
69 | ncells = adata.n_obs
70 |
71 | if run_neigh or ('neighbors' not in adata.uns):
72 | if use_rep_neigh is None:
73 | use_rep_neigh = zs_key
74 | logger.warn(f"Warning: the parameter `use_rep_neigh` in function `plot_vector_field` is not provided. Use `{zs_key}` in `.obsm` of the AnnData instead.")
75 | else:
76 | if use_rep_neigh not in adata.obsm:
77 | raise KeyError(
78 | f"`{use_rep_neigh}` not found in `.obsm` of the AnnData. Please provide valid `use_rep_neigh` for neighbor detection."
79 | )
80 | sc.pp.neighbors(adata, use_rep = use_rep_neigh, n_neighbors = n_neigh)
81 | n_neigh = adata.uns['neighbors']['params']['n_neighbors'] - 1
82 | # indices_matrix = adata.obsp['distances'].indices.reshape(-1, n_neigh)
83 |
84 | if t_key is not None:
85 | if t_key not in adata.obs:
86 | raise KeyError(
87 | f"`{t_key}` not found in `.obs` of the AnnData. Please provide valid `t_key` for estimated pseudotime."
88 | )
89 | ts = adata.obs[t_key].values
90 | indices_matrix2 = np.zeros((ncells, n_neigh), dtype = int)
91 | for i in range(ncells):
92 | idx = np.abs(ts - ts[i]).argsort()[:(n_neigh + 1)]
93 | idx = np.setdiff1d(idx, i) if i in idx else idx[:-1]
94 | indices_matrix2[i] = idx
95 | # indices_matrix = np.hstack([indices_matrix, indices_matrix2])
96 |
97 | vals, rows, cols = [], [], []
98 | for i in range(ncells):
99 | # idx = np.unique(indices_matrix[i])
100 | # idx2 = indices_matrix[idx].flatten()
101 | # idx2 = np.setdiff1d(idx2, i)
102 | idx = adata.obsp['distances'][i].indices
103 | idx2 = adata.obsp['distances'][idx].indices
104 | idx2 = np.setdiff1d(idx2, i)
105 | idx = np.unique(np.concatenate([idx, idx2])) if t_key is None else np.unique(np.concatenate([idx, idx2, indices_matrix2[i]]))
106 | dZ = Z[idx] - Z[i, None]
107 | if var_stabilize_transform:
108 | dZ = np.sqrt(np.abs(dZ)) * np.sign(dZ)
109 | cos_sim = np.einsum("ij, j", dZ, V[i]) / (l2_norm(dZ, axis = 1) * l2_norm(V[i]))
110 | cos_sim[np.isnan(cos_sim)] = 0
111 | vals.extend(cos_sim)
112 | rows.extend(np.repeat(i, len(idx)))
113 | cols.extend(idx)
114 |
115 | res = coo_matrix((vals, (rows, cols)), shape = (ncells, ncells))
116 | res.data = np.clip(res.data, -1, 1)
117 | return res.tocsr()
118 |
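The core of `cosine_similarity` is the per-cell comparison between the vector field at a cell and the displacement towards each of its neighbors. A minimal NumPy sketch of that step follows, assuming `l2_norm` is the Euclidean norm (its definition lives in `_utils.py` and is not shown here); the helper name `cosine_to_neighbors` and the toy inputs are hypothetical.

```python
import numpy as np

def cosine_to_neighbors(z_i, v_i, z_neigh):
    """Cosine similarity between v_i and each displacement z_neigh[k] - z_i."""
    dZ = z_neigh - z_i[None, :]
    denom = np.linalg.norm(dZ, axis=1) * np.linalg.norm(v_i)
    with np.errstate(divide="ignore", invalid="ignore"):
        cos = dZ @ v_i / denom
    cos[~np.isfinite(cos)] = 0.0  # zero-length displacement or zero vector
    return np.clip(cos, -1.0, 1.0)

z_i = np.array([0.0, 0.0])          # latent state of the cell
v_i = np.array([1.0, 0.0])          # vector field at the cell
z_neigh = np.array([[2.0, 0.0],     # neighbor along the vector field
                    [0.0, 3.0],     # orthogonal neighbor
                    [-1.0, 0.0],    # neighbor against the vector field
                    [0.0, 0.0]])    # neighbor at the same position
cos = cosine_to_neighbors(z_i, v_i, z_neigh)
```

Neighbors lying in the direction of the vector field score close to 1, those in the opposite direction close to -1, and degenerate cases fall back to 0.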
119 |
120 | def quiver_autoscale(
121 | E: np.ndarray,
122 | V: np.ndarray,
123 | ):
124 | """
125 | Get the autoscaling in quiver.
126 | This function is from scvelo: https://github.com/theislab/scvelo/blob/master/scvelo/tools/velocity_embedding.py.
127 |
128 | Parameters
129 | ----------
130 | E
131 | The embedding.
132 | V
133 | The weighted unitary displacement.
134 |
135 | Returns
136 | ----------
137 | The autoscaling factor.
138 | """
139 |
140 | fig, ax = plt.subplots()
141 | scale_factor = np.abs(E).max()
142 |
143 | Q = ax.quiver(
144 | E[:, 0] / scale_factor,
145 | E[:, 1] / scale_factor,
146 | V[:, 0],
147 | V[:, 1],
148 | angles = 'xy',
149 | scale = None,
150 | scale_units = 'xy',
151 | )
152 | Q._init()
153 | fig.clf()
154 | plt.close(fig)
155 | return Q.scale / scale_factor
156 |
157 |
158 | def vector_field_embedding(
159 | adata: AnnData,
160 | T_key: str,
161 | E_key: str,
162 | scale: int = 10,
163 | self_transition: bool = False,
164 | ):
165 | """
166 | Calculate the weighted unitary displacement vectors under a certain embedding.
167 | This function borrows the ideas from scvelo: https://github.com/theislab/scvelo/blob/master/scvelo/tools/velocity_embedding.py.
168 |
169 | Parameters
170 | ----------
171 | adata
172 | An :class:`~anndata.AnnData` object.
173 | T_key
174 | The key in `.obsp` for cosine similarity.
175 | E_key
176 | The key in `.obsm` for embedding.
177 | scale
178 | Scale factor for cosine similarity.
179 | (Default: 10)
180 | self_transition
181 | Whether to take self-transition into consideration.
182 | (Default: `False`)
183 |
184 | Returns
185 | ----------
186 | The weighted unitary displacement vectors.
187 | """
188 |
189 | T = adata.obsp[T_key].copy()
190 |
191 | if self_transition:
192 |         max_t = T.max(1).toarray().flatten()  # .toarray() is the portable form of the .A shorthand
193 | ub = np.percentile(max_t, 98)
194 | self_t = np.clip(ub - max_t, 0, 1)
195 | T.setdiag(self_t)
196 |
197 | T = T.sign().multiply(np.expm1(abs(T * scale)))
198 | T = T.multiply(csr_matrix(1.0 / abs(T).sum(1)))
199 | if self_transition:
200 | T.setdiag(0)
201 | T.eliminate_zeros()
202 |
203 | E = np.array(adata.obsm[E_key])
204 | V = np.zeros(E.shape)
205 |
206 | for i in range(adata.n_obs):
207 | idx = T[i].indices
208 | dE = E[idx] - E[i, None]
209 | dE /= l2_norm(dE)[:, None]
210 | dE[np.isnan(dE)] = 0
211 | prob = T[i].data
212 | V[i] = prob.dot(dE) - prob.mean() * dE.sum(0)
213 |
214 | V /= 3 * quiver_autoscale(E, V)
215 | return V
216 |
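In `vector_field_embedding`, the cosine similarities are sharpened with a signed `expm1` transform and each row is divided by the sum of its absolute values before being used as weights. A dense NumPy sketch of that transform follows; the function name `signed_exp_weights` and the toy matrix are hypothetical, and the dense array is a stand-in for the sparse matrix used in the source.

```python
import numpy as np

def signed_exp_weights(C, scale=10.0):
    """Apply sign(C) * expm1(|C * scale|) and normalize rows by their absolute sum."""
    C = np.asarray(C, dtype=float)
    W = np.sign(C) * np.expm1(np.abs(C * scale))
    row_abs_sum = np.abs(W).sum(axis=1)
    row_abs_sum[row_abs_sum == 0] = 1.0  # keep all-zero rows as zeros
    return W / row_abs_sum[:, None]

# Toy cosine similarities for three cells (zeros mean "not a neighbor").
C = [[0.0, 0.1, -0.1],
     [0.2, 0.0, 0.0],
     [0.0, 0.0, 0.0]]
W = signed_exp_weights(C)
```

A larger `scale` concentrates the weight on the neighbors with the largest absolute similarity, while the sign keeps track of whether a neighbor lies along or against the vector field.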
217 |
218 | def vector_field_embedding_grid(
219 | E: np.ndarray,
220 | V: np.ndarray,
221 | smooth: float = 0.5,
222 | stream: bool = False,
223 | density: float = 1.0,
224 | ) -> tuple:
225 | """
226 | Estimate the unitary displacement vectors within a grid.
227 | This function borrows the ideas from scvelo: https://github.com/theislab/scvelo/blob/master/scvelo/plotting/velocity_embedding_grid.py.
228 |
229 | Parameters
230 | ----------
231 | E
232 | The embedding.
233 | V
234 | The unitary displacement vectors under the embedding.
235 | smooth
236 | The factor for scale in Gaussian pdf.
237 | (Default: 0.5)
238 | stream
239 | Whether to adjust for streamplot.
240 | (Default: `False`)
241 | density
242 |         The grid density; each embedding axis is divided into `int(50 * density)` grid points.
243 | (Default: 1.0)
244 |
245 | Returns
246 | ----------
247 | tuple
248 | The embedding and unitary displacement vectors at grid level.
249 | """
250 |
251 | grs = []
252 | for i in range(E.shape[1]):
253 | m, M = np.min(E[:, i]), np.max(E[:, i])
254 | diff = M - m
255 | m = m - 0.01 * diff
256 | M = M + 0.01 * diff
257 | gr = np.linspace(m, M, int(50 * density))
258 | grs.append(gr)
259 |
260 | meshes = np.meshgrid(*grs)
261 | E_grid = np.vstack([i.flat for i in meshes]).T
262 |
263 |     n_neigh = max(1, int(E.shape[0] / 50))  # at least one neighbor, even with fewer than 50 cells
264 | nn = NearestNeighbors(n_neighbors = n_neigh, n_jobs = -1)
265 | nn.fit(E)
266 | dists, neighs = nn.kneighbors(E_grid)
267 |
268 | scale = np.mean([g[1] - g[0] for g in grs]) * smooth
269 | weight = norm.pdf(x = dists, scale = scale)
270 | weight_sum = weight.sum(1)
271 |
272 | V_grid = (V[neighs] * weight[:, :, None]).sum(1)
273 | V_grid /= np.maximum(1, weight_sum)[:, None]
274 |
275 | if stream:
276 | E_grid = np.stack(grs)
277 | ns = E_grid.shape[1]
278 | V_grid = V_grid.T.reshape(2, ns, ns)
279 |
280 | mass = np.sqrt((V_grid * V_grid).sum(0))
281 | min_mass = 1e-5
282 | min_mass = np.clip(min_mass, None, np.percentile(mass, 99) * 0.01)
283 | cutoff1 = (mass < min_mass)
284 |
285 | length = np.sum(np.mean(np.abs(V[neighs]), axis = 1), axis = 1).reshape(ns, ns)
286 | cutoff2 = (length < np.percentile(length, 5))
287 |
288 | cutoff = (cutoff1 | cutoff2)
289 | V_grid[0][cutoff] = np.nan
290 | else:
291 | min_weight = np.percentile(weight_sum, 99) * 0.01
292 | E_grid, V_grid = E_grid[weight_sum > min_weight], V_grid[weight_sum > min_weight]
293 | V_grid /= 3 * quiver_autoscale(E_grid, V_grid)
294 |
295 | return E_grid, V_grid
296 |
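The grid-level estimate in `vector_field_embedding_grid` is a Gaussian-weighted sum of the vectors of the nearest cells around each grid point. A minimal sketch of that idea follows, using a brute-force neighbor search in plain NumPy in place of `NearestNeighbors` and an explicit Gaussian density in place of `scipy.stats.norm.pdf`; the function name `smooth_to_grid`, its defaults, and the random toy data are hypothetical.

```python
import numpy as np

def smooth_to_grid(E, V, n_points=5, n_neigh=3, smooth=0.5):
    """Average the vectors V of the nearest cells onto a regular grid over E."""
    # One evenly spaced axis per embedding dimension, then the full grid.
    grs = [np.linspace(E[:, d].min(), E[:, d].max(), n_points) for d in range(E.shape[1])]
    E_grid = np.vstack([m.ravel() for m in np.meshgrid(*grs)]).T

    # Brute-force k nearest cells for every grid point (stand-in for NearestNeighbors).
    all_d = np.linalg.norm(E_grid[:, None, :] - E[None, :, :], axis=2)
    neighs = np.argsort(all_d, axis=1)[:, :n_neigh]
    dists = np.take_along_axis(all_d, neighs, axis=1)

    # Gaussian density of the distance, with a bandwidth tied to the grid spacing.
    scale = np.mean([g[1] - g[0] for g in grs]) * smooth
    weight = np.exp(-0.5 * (dists / scale) ** 2) / (scale * np.sqrt(2 * np.pi))
    weight_sum = weight.sum(1)

    V_grid = (V[neighs] * weight[:, :, None]).sum(1)
    V_grid /= np.maximum(1, weight_sum)[:, None]  # only rescales when weights sum above 1
    return E_grid, V_grid

rng = np.random.default_rng(0)
E = rng.uniform(0, 10, size=(40, 2))   # toy 2-D embedding
V = np.tile([1.0, 0.0], (40, 1))       # every cell points along the first axis
E_grid, V_grid = smooth_to_grid(E, V)
```

Grid points far from any cell receive small weights and therefore short vectors, which is what the later `min_weight` and `min_mass` cutoffs in the source filter on.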
297 |
298 | def plot_vector_field(
299 | adata: AnnData,
300 | zs_key: str,
301 | reverse: bool = False,
302 | vf_key: str = 'X_VF',
303 | run_neigh: bool = True,
304 | use_rep_neigh: Optional[str] = None,
305 | t_key: Optional[str] = None,
306 | n_neigh: int = 20,
307 | var_stabilize_transform: bool = False,
308 | E_key: str = 'X_umap',
309 | scale: int = 10,
310 | self_transition: bool = False,
311 | smooth: float = 0.5,
312 | density: float = 1.,
313 | grid: bool = False,
314 | stream: bool = True,
315 | stream_density: int = 2,
316 | stream_color: str = 'k',
317 | stream_linewidth: int = 1,
318 | stream_arrowsize: int = 1,
319 | grid_density: float = 1.,
320 | grid_arrowcolor: str = 'grey',
321 | grid_arrowlength: int = 1,
322 | grid_arrowsize: int = 1,
323 | show: bool = True,
324 | save: Optional[Union[str, bool]] = None,
325 | # color: Optional[str] = None,
326 | # ax: Optional[Axes] = None,
327 | **kwargs,
328 | ):
329 | """
330 | Visualize the vector field.
331 | The visualization of vector field under an embedding borrows the ideas from scvelo: https://github.com/theislab/scvelo.
332 |
333 | Parameters
334 | ----------
335 | adata
336 | An :class:`~anndata.AnnData` object.
337 | zs_key
338 | The key in `.obsm` for storing the latent space.
339 | reverse
340 |         Whether to reverse the direction of the vector field. Set this to `True` if the pseudotime returned by `get_time()` was in reverse order and you applied the post-inference adjustment (`reverse_time()`).
341 | (Default: `False`)
342 | vf_key
343 |         The key in `.obsm` for storing the vector field. (Default: `'X_VF'`)
344 | run_neigh
345 | Whether to run neighbor detection.
346 | (Default: `True`)
347 | use_rep_neigh
348 |         The representation in `.obsm` which will be used for neighbor detection. If `None`, `zs_key` is used. (Default: `None`)
349 |     t_key
350 | The key in `.obs` for estimated pseudotime which will be considered when detecting neighbors.
351 | n_neigh
352 | The number of neighbors considered for each cell.
353 | (Default: 20)
354 | var_stabilize_transform
355 | Whether to perform variance-stabilizing transformation for vector field and cell-neighbor latent state difference.
356 | (Default: `False`)
357 | E_key
358 | The key in `.obsm` for embedding.
359 | (Default: `'X_umap'`)
360 | scale
361 | Scale factor for cosine similarity.
362 | (Default: 10)
363 | self_transition
364 | Whether to take self-transition into consideration.
365 | (Default: `False`)
366 | smooth
367 | The factor for scale in Gaussian pdf.
368 | (Default: 0.5)
369 | density
370 |         Fraction of cells (between 0 and 1) to show when displaying the vector field at per-cell level.
371 | (Default: 1.0)
372 | grid
373 | Whether to display vector field as arrows at grid level.
374 | (Default: `False`)
375 | stream
376 | Whether to display vector field as streamplot.
377 | (Default: `True`)
378 | stream_density
379 | The density parameter in streamplot for controlling the closeness of the streamlines.
380 | (Default: 2)
381 | stream_color
382 | The streamline color for streamplot.
383 |         (Default: `'k'`)
384 | stream_linewidth
385 | The line width for streamplot.
386 | (Default: 1)
387 | stream_arrowsize
388 | The arrow size for streamplot.
389 | (Default: 1)
390 | grid_density
391 | The density for showing vector field as arrows at grid level.
392 | (Default: 1.0)
393 | grid_arrowcolor
394 | The arrow color when showing vector field as arrows at grid level.
395 | (Default: `'grey'`)
396 | grid_arrowlength
397 | The arrow length when showing vector field as arrows at grid level.
398 | (Default: 1)
399 | grid_arrowsize
400 | The arrow size when showing vector field as arrows at grid level.
401 | (Default: 1)
402 | show
403 | Whether to show the plot.
404 | (Default: `True`)
405 | save
406 | Whether to save the figure. If `True` or a `str`, the figure will be saved as 'sctour_vector_field.png' (if `True` provided) or a given filename (if a `str` provided).
407 | kwargs
408 | Parameters passed to :func:`scanpy.pl.embedding`.
409 |
410 | Returns
411 | ----------
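412 |     None. As a side effect, the cosine similarities are stored in `.obsp['cosine_similarity']` and the displacement vectors in `.obsm['X_DV']` of `adata`.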
413 | """
414 |
415 | if zs_key not in adata.obsm:
416 | raise KeyError(
417 | f"`{zs_key}` not found in `.obsm` of the AnnData. Please provide valid `zs_key` for latent space."
418 | )
419 | if vf_key not in adata.obsm:
420 | raise KeyError(
421 | f"`{vf_key}` not found in `.obsm` of the AnnData. Please provide valid `vf_key` for vector field."
422 | )
423 | if E_key not in adata.obsm:
424 | raise KeyError(
425 | f"`{E_key}` not found in `.obsm` of the AnnData. Please provide valid `E_key` for embedding."
426 | )
427 |     if (grid_density <= 0) or (grid_density > 1):
428 |         raise ValueError(
429 |             "`grid_density` must be greater than 0 and at most 1."
430 |         )
431 | if (density < 0) or (density > 1):
432 | raise ValueError(
433 | "`density` must be between 0 and 1."
434 | )
435 |
436 |     # Calculate cosine similarity
437 | adata.obsp['cosine_similarity'] = cosine_similarity(
438 | adata,
439 | reverse = reverse,
440 | zs_key = zs_key,
441 | vf_key = vf_key,
442 | run_neigh = run_neigh,
443 | use_rep_neigh = use_rep_neigh,
444 | t_key = t_key,
445 | n_neigh = n_neigh,
446 | var_stabilize_transform = var_stabilize_transform,
447 | )
448 |     # Get weighted unitary displacement vectors under a certain embedding
449 | adata.obsm['X_DV'] = vector_field_embedding(
450 | adata,
451 | T_key = 'cosine_similarity',
452 | E_key = E_key,
453 | scale = scale,
454 | self_transition = self_transition,
455 | )
456 |
457 | E = np.array(adata.obsm[E_key])
458 | V = adata.obsm['X_DV']
459 |
460 | if grid:
461 | stream = False
462 |
463 | if grid or stream:
464 | E, V = vector_field_embedding_grid(
465 | E = E,
466 | V = V,
467 | smooth = smooth,
468 | stream = stream,
469 | density = grid_density,
470 | )
471 |
472 | ax = sc.pl.embedding(adata, basis = E_key, show=False, **kwargs)
473 | if stream:
474 | lengths = np.sqrt((V * V).sum(0))
475 | stream_linewidth *= 2 * lengths / lengths[~np.isnan(lengths)].max()
476 | stream_kwargs = dict(
477 | linewidth = stream_linewidth,
478 | density = stream_density,
479 | zorder = 3,
480 | color = stream_color,
481 | arrowsize = stream_arrowsize,
482 | arrowstyle = '-|>',
483 | maxlength = 4,
484 | integration_direction = 'both',
485 | )
486 | ax.streamplot(E[0], E[1], V[0], V[1], **stream_kwargs)
487 | else:
488 | if not grid:
489 | if density < 1:
490 | idx = np.random.choice(len(E), int(len(E) * density), replace = False)
491 | E = E[idx]
492 | V = V[idx]
493 | scale = 1 / grid_arrowlength
494 | hl, hw, hal = 6 * grid_arrowsize, 5 * grid_arrowsize, 4 * grid_arrowsize
495 | quiver_kwargs = dict(
496 | angles = 'xy',
497 | scale_units = 'xy',
498 | edgecolors = 'k',
499 | scale = scale,
500 | width = 0.001,
501 | headlength = hl,
502 | headwidth = hw,
503 | headaxislength = hal,
504 | color = grid_arrowcolor,
505 | linewidth = 0.2,
506 | zorder = 3,
507 | )
508 | ax.quiver(E[:, 0], E[:, 1], V[:, 0], V[:, 1], **quiver_kwargs)
509 |
510 | if save:
511 | if isinstance(save, str):
512 | plt.savefig(save)
513 | else:
514 | plt.savefig('sctour_vector_field.png')
515 | if show:
516 | plt.show()
517 | if save:
518 | plt.close()
519 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | def get_readme():
4 | with open("README.md", "rt", encoding="utf-8") as fh:
5 | return fh.read()
6 |
7 | setuptools.setup(
8 | name='sctour',
9 | version='1.0.0',
10 | author='Qian Li',
11 | author_email='liqian.picb@gmail.com',
12 | description='a deep learning architecture for robust inference and accurate prediction of cellular dynamics',
13 | long_description=get_readme(),
14 | long_description_content_type="text/markdown",
15 | url='https://github.com/LiQian-XC/sctour',
16 | packages=setuptools.find_packages(),
17 | install_requires=[
18 | 'torch>=1.9.1',
19 | 'torchdiffeq>=0.2.2',
20 | 'numpy>=1.19.2',
21 | 'scanpy>=1.7.1',
22 | 'anndata>=0.7.5',
23 | 'scipy>=1.5.2',
24 | 'tqdm>=4.32.2',
25 | 'scikit-learn>=0.24.1',
26 | 'leidenalg>=0.8.4',
27 | ],
28 | classifiers=[
29 | "Programming Language :: Python :: 3",
30 | "Operating System :: OS Independent",
31 | "Intended Audience :: Science/Research",
32 | "Topic :: Scientific/Engineering :: Bio-Informatics",
33 | "Development Status :: 4 - Beta",
34 | ],
35 | python_requires='>=3.7',
36 | )
37 |
--------------------------------------------------------------------------------