├── .gitignore ├── .readthedocs.yml ├── CITATION.cff ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs ├── Makefile ├── about.rst ├── conf.py ├── getting_started.rst ├── index.rst ├── install.rst ├── references.rst ├── release.rst └── requirements.txt ├── figures └── HumanBoneMarrow.png ├── notebooks ├── AuxiliaryFunctions.ipynb ├── Figure1_Description.ipynb ├── Figure2_ErythroidMouse.ipynb ├── Figure3_BoneMarrow.ipynb ├── Figure4_IntestinalOrganoid.ipynb ├── README.md ├── SuppFig1_Simulation.ipynb ├── SuppFig5_DentateGyrus.ipynb ├── SuppFig6_Oligodendrocyte.ipynb ├── SuppFig6_Pancreas.ipynb ├── fig1_toydata.csv └── simulated │ ├── cellinfo.csv │ ├── dimension.csv │ ├── spliced.mtx │ └── unspliced.mtx ├── pyproject.toml ├── setup.py └── unitvelo ├── __init__.py ├── config.py ├── eval_utils.py ├── gene_influence.py ├── individual_gene.py ├── main.py ├── model.py ├── optimize_utils.py ├── pl.py ├── utils.py └── velocity.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store* 2 | .vscode/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # Temp data 10 | data 11 | notebooks/data 12 | *.h5ad 13 | tempdata/ 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | env/ 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # dotenv 92 | .env 93 | 94 | # virtualenv 95 | .venv 96 | venv/ 97 | ENV/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | 112 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | sphinx: 2 | configuration: docs/conf.py 3 | build: 4 | image: latest 5 | python: 6 | version: 3.7 7 | install: 8 | - requirements: docs/requirements.txt 9 | - requirements: requirements.txt 10 | - method: pip 11 | path: . 12 | extra_requirements: 13 | - docs 14 | - method: setuptools 15 | path: package 16 | system_packages: true 17 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this package, please cite it as below." 
3 | authors: 4 | - family-names: "GAO" 5 | given-names: "MINGZE" 6 | orcid: "https://orcid.org/0000-0002-8795-3861" 7 | title: "UniTVelo" 8 | version: 0.2.3 9 | doi: 10.5281/zenodo.7112387 10 | date-released: 2022-09-25 11 | url: "https://github.com/StatBiomed/UniTVelo" 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, michaelgmz 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UniTVelo for RNA Velocity Analysis 2 | 3 | Temporally unified RNA velocity for single cell trajectory inference (UniTVelo) is implemented on Python 3 and TensorFlow 2. The model estimates velocity of each gene and updates cell time based on phase portraits concurrently. 4 | ![human bone marrow velocity stream](figures/HumanBoneMarrow.png) 5 | 6 | The major features of UniTVelo are, 7 | 8 | * Using spliced RNA oriented design to model RNA velocity and transcription rates 9 | * Introducing a unified latent time (`Unified-time mode`) across whole transcriptome to incorporate stably and monotonically changed genes 10 | * Retaining gene-specific time matrices (`Independent mode`) for complex datasets 11 | 12 | UniTVelo has proved its robustness in 10 different datasets. Details can be found via our manuscript in bioRxiv which is currently under review ([UniTVelo](https://www.biorxiv.org/content/10.1101/2022.04.27.489808v1)). 
13 | 14 | ## Installation 15 | 16 | ### GPU Acceleration 17 | 18 | UniTVelo is designed based on TensorFlow's automatic differentiation architecture. Please make sure [TensorFlow 2](https://www.tensorflow.org/install) and related [CUDA](https://developer.nvidia.com/cuda-downloads) dependencies are correctly installed. 19 | 20 | Use the following scripts to confirm TensorFlow is using the GPU. 21 | 22 | ```python3 23 | import tensorflow as tf 24 | print ("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU'))) 25 | ``` 26 | 27 | If GPU is not available, UniTVelo will automatically switch to CPU for model fitting or it can be specified in `config.py` (see `Getting Started` below). 28 | 29 | ### Main Module 30 | 31 | (Optional) Create a separate conda environment for version control and to avoid potential conflicts. 32 | 33 | ```python3 34 | conda create -n unitvelo python=3.7 35 | conda activate unitvelo 36 | ``` 37 | 38 | UniTVelo package can be conveniently installed via PyPI or directly from GitHub repository. 39 | 40 | ```python3 41 | pip install unitvelo 42 | ``` 43 | 44 | or 45 | 46 | ```python3 47 | pip install git+https://github.com/StatBiomed/UniTVelo 48 | ``` 49 | 50 | ## Getting Started 51 | 52 | ### Analyzed Notebooks 53 | https://drive.google.com/drive/folders/1A-Gcu0zhjVv4N8UZHttM_RULSztUzaUU?usp=sharing 54 | 55 | ### Public Datasets 56 | 57 | Examples of UniTVelo and steps for reproducible results are provided in Jupyter Notebook under `notebooks` folder. Specifically, please refer to records analyzing [Mouse Erythroid](notebooks/Figure2_ErythroidMouse.ipynb) and [Human Bone Marrow](notebooks/Figure3_BoneMarrow.ipynb) datasets. 58 | 59 | UniTVelo has proved its performance through 10 different datasets and 4 of them have been incorporated within scVelo package, see [datasets](notebooks/README.md). 
Others can be obtained via [link](https://connecthkuhk-my.sharepoint.com/:f:/g/personal/gmz1229_connect_hku_hk/EkC47RWWUrtOqcWzJ0neDGEBKLZTHWZW7PPe3vhUo9sn6g?e=QyoLFJ). 60 | 61 | ### RNA Velocity on New Dataset 62 | 63 | UniTVelo provides an integrated function for velocity analysis by default whilst specific configurations might need to be adjusted accordingly. 64 | 65 | 1. Import package 66 | 67 | ```python3 68 | import unitvelo as utv 69 | ``` 70 | 71 | 2. Sub-class and override base configuration file (here lists a few frequently used), please refer `config.py` for detailed arguments. 72 | 73 | ```python3 74 | velo = utv.config.Configuration() 75 | velo.R2_ADJUST = True 76 | velo.IROOT = None 77 | velo.FIT_OPTION = '1' 78 | velo.GPU = 0 79 | ``` 80 | 81 | * Arguments: 82 | * -- `velo.R2_ADJUST` (bool), linear regression R-squared on extreme quantile (default) or full data (adjusted) 83 | * -- `velo.IROOT` (str), specify root cell cluster would enable diffusion map based time initialization, default None 84 | * -- `velo.FIT_OPTION` (str), '1' Unified-time mode (default), '2' Independent mode 85 | * -- `velo.GPU` (int), specify the GPU card used for fitting, -1 will switch to CPU mode, default 0. 86 | 87 | 3. Run model (label refers to column name in adata.obs specifying celltypes) 88 | 89 | ```python3 90 | adata = utv.run_model(path_to_adata, label, config_file=velo) 91 | scv.pl.velocity_embedding_stream(adata, color=label, dpi=100, title='') 92 | ``` 93 | 94 | 4. 
Evaluation metrics (Optional) 95 | 96 | ```python3 97 | # Cross Boundary Direction Correctness 98 | # Ground truth should be given via `cluster_edges` 99 | metrics = {} 100 | metrics = utv.evaluate(adata, cluster_edges, label, 'velocity') 101 | 102 | # Latent time estimation 103 | scv.pl.scatter(adata, color='latent_time', color_map='gnuplot', size=20) 104 | 105 | # Phase portraits for individual genes (experimental) 106 | utv.pl.plot_range(gene_name, adata, velo, show_ax=True, time_metric='latent_time') 107 | ``` 108 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 
21 | 22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext 23 | 24 | help: 25 | @echo "Please use \`make ' where is one of" 26 | @echo " html to make standalone HTML files" 27 | @echo " dirhtml to make HTML files named index.html in directories" 28 | @echo " singlehtml to make a single large HTML file" 29 | @echo " pickle to make pickle files" 30 | @echo " json to make JSON files" 31 | @echo " htmlhelp to make HTML files and a HTML help project" 32 | @echo " qthelp to make HTML files and a qthelp project" 33 | @echo " devhelp to make HTML files and a Devhelp project" 34 | @echo " epub to make an epub" 35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 36 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 38 | @echo " text to make text files" 39 | @echo " man to make manual pages" 40 | @echo " texinfo to make Texinfo files" 41 | @echo " info to make Texinfo files and run them through makeinfo" 42 | @echo " gettext to make PO message catalogs" 43 | @echo " changes to make an overview of all changed/added/deprecated items" 44 | @echo " xml to make Docutils-native XML files" 45 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 46 | @echo " linkcheck to check all external links for integrity" 47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 48 | 49 | clean: 50 | rm -rf $(BUILDDIR)/* 51 | 52 | html: 53 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 54 | @echo 55 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 56 | 57 | dirhtml: 58 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 59 | @echo 60 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 
61 | 62 | singlehtml: 63 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 64 | @echo 65 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 66 | 67 | pickle: 68 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 69 | @echo 70 | @echo "Build finished; now you can process the pickle files." 71 | 72 | json: 73 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 74 | @echo 75 | @echo "Build finished; now you can process the JSON files." 76 | 77 | htmlhelp: 78 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 79 | @echo 80 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 81 | ".hhp project file in $(BUILDDIR)/htmlhelp." 82 | 83 | qthelp: 84 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 85 | @echo 86 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 87 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 88 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/mytest.qhcp" 89 | @echo "To view the help file:" 90 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/mytest.qhc" 91 | 92 | devhelp: 93 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 94 | @echo 95 | @echo "Build finished." 96 | @echo "To view the help file:" 97 | @echo "# mkdir -p $$HOME/.local/share/devhelp/mytest" 98 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/mytest" 99 | @echo "# devhelp" 100 | 101 | epub: 102 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 103 | @echo 104 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 105 | 106 | latex: 107 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 108 | @echo 109 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 110 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 111 | "(use \`make latexpdf' here to do that automatically)." 
112 | 113 | latexpdf: 114 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 115 | @echo "Running LaTeX files through pdflatex..." 116 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 117 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 118 | 119 | latexpdfja: 120 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 121 | @echo "Running LaTeX files through platex and dvipdfmx..." 122 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 123 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 124 | 125 | text: 126 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 127 | @echo 128 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 129 | 130 | man: 131 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 132 | @echo 133 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 134 | 135 | texinfo: 136 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 137 | @echo 138 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 139 | @echo "Run \`make' in that directory to run these through makeinfo" \ 140 | "(use \`make info' here to do that automatically)." 141 | 142 | info: 143 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 144 | @echo "Running Texinfo files through makeinfo..." 145 | make -C $(BUILDDIR)/texinfo info 146 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 147 | 148 | gettext: 149 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 150 | @echo 151 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 152 | 153 | changes: 154 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 155 | @echo 156 | @echo "The overview file is in $(BUILDDIR)/changes." 157 | 158 | linkcheck: 159 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 160 | @echo 161 | @echo "Link check complete; look for any errors in the above output " \ 162 | "or in $(BUILDDIR)/linkcheck/output.txt." 
163 | 164 | doctest: 165 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 166 | @echo "Testing of doctests in the sources finished, look at the " \ 167 | "results in $(BUILDDIR)/doctest/output.txt." 168 | 169 | xml: 170 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 171 | @echo 172 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 173 | 174 | pseudoxml: 175 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 176 | @echo 177 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 178 | -------------------------------------------------------------------------------- /docs/about.rst: -------------------------------------------------------------------------------- 1 | UniTVelo 2 | ======== 3 | 4 | Temporally unified RNA velocity for single cell trajectory inference (UniTVelo) is implemented on Python 3 and TensorFlow 2. 5 | The model estimates velocity of each gene and updates cell time based on phase portraits concurrently. 6 | 7 | .. image:: https://github.com/StatBiomed/UniTVelo/blob/main/figures/HumanBoneMarrow.png?raw=true 8 | :width: 300px 9 | :align: center 10 | 11 | The major features of UniTVelo are, 12 | 13 | * Using spliced RNA oriented design to model RNA velocity and transcription rates 14 | * Introducing a unified latent time (`Unified-time mode`) across whole transcriptome to incorporate stably and monotonically changed genes 15 | * Retaining gene-specific time matrices (`Independent mode`) for complex datasets 16 | 17 | UniTVelo has proved its robustness in 10 different datasets. Details can be found via our manuscript in bioRxiv which is currently under review (UniTVelo_). 18 | 19 | .. 
_UniTVelo: https://www.biorxiv.org/content/10.1101/2022.04.27.489808v1 -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # mytest documentation build configuration file, created by 4 | # sphinx-quickstart on Sat Dec 19 19:34:52 2015. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 14 | 15 | import sys 16 | import os 17 | from datetime import datetime 18 | 19 | # If extensions (or modules to document with autodoc) are in another directory, 20 | # add these directories to sys.path here. If the directory is relative to the 21 | # documentation root, use os.path.abspath to make it absolute, like shown here. 22 | #sys.path.insert(0, os.path.abspath('.')) 23 | 24 | from pathlib import Path 25 | HERE = Path(__file__).parent 26 | sys.path.insert(0, f"{HERE.parent.parent}") 27 | sys.path.insert(0, os.path.abspath("_ext")) 28 | 29 | # -- Retrieve notebooks ------------------------------------------------ 30 | 31 | from urllib.request import urlretrieve 32 | 33 | notebooks_url = "https://github.com/StatBiomed/UniTVelo/raw/main/notebooks/" 34 | notebooks = [ 35 | "Figure2_ErythroidMouse.ipynb", 36 | "Figure3_BoneMarrow.ipynb", 37 | "Figure4_IntestinalOrganoid.ipynb", 38 | "AuxiliaryFunctions.ipynb" 39 | ] 40 | for nb in notebooks: 41 | try: 42 | urlretrieve(notebooks_url + nb, nb) 43 | except: 44 | raise ValueError(f'{nb} cannot be retrieved.') 45 | 46 | 47 | # -- General configuration ------------------------------------------------ 48 | 49 | # If your documentation needs a minimal Sphinx version, state it here. 
50 | #needs_sphinx = '1.0' 51 | 52 | # Add any Sphinx extension module names here, as strings. They can be 53 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 54 | # ones. 55 | 56 | needs_sphinx = "1.7" 57 | 58 | extensions = [ 59 | "sphinx.ext.autodoc", 60 | "sphinx.ext.doctest", 61 | "sphinx.ext.coverage", 62 | "sphinx.ext.mathjax", 63 | "sphinx.ext.autosummary", 64 | "sphinx.ext.napoleon", 65 | "sphinx.ext.intersphinx", 66 | "sphinx.ext.githubpages", 67 | "sphinx_autodoc_typehints", 68 | "nbsphinx", 69 | # "edit_on_github", 70 | ] 71 | 72 | # Add any paths that contain templates here, relative to this directory. 73 | templates_path = ['_templates'] 74 | 75 | # The suffix of source filenames. 76 | source_suffix = ['.rst', ".ipynb"] 77 | 78 | # The encoding of source files. 79 | #source_encoding = 'utf-8-sig' 80 | 81 | # The master toctree document. 82 | master_doc = 'index' 83 | 84 | # General information about the project. 85 | project = "UniTVelo" 86 | author = "Mingze Gao" 87 | title = "Temporally unified RNA velocity" 88 | copyright = f"{datetime.now():%Y}, {author}" 89 | 90 | # Disable pdf and epub generation 91 | enable_pdf_build = False 92 | enable_epub_build = False 93 | 94 | # The version info for the project you're documenting, acts as replacement for 95 | # |version| and |release|, also used in various other places throughout the 96 | # built documents. 97 | # 98 | # The short X.Y version. 99 | import unitvelo 100 | version = unitvelo.__version__ #'0.1.3' # 101 | # The full version, including alpha/beta/rc tags. 102 | release = version 103 | 104 | # The language for content autogenerated by Sphinx. Refer to documentation 105 | # for a list of supported languages. 106 | #language = None 107 | 108 | # There are two options for replacing |today|: either, you set today to some 109 | # non-false value, then it is used: 110 | #today = '' 111 | # Else, today_fmt is used as the format for a strftime call. 
112 | #today_fmt = '%B %d, %Y' 113 | 114 | # List of patterns, relative to source directory, that match files and 115 | # directories to ignore when looking for source files. 116 | exclude_patterns = ['_build'] 117 | 118 | # The reST default role (used for this markup: `text`) to use for all 119 | # documents. 120 | #default_role = None 121 | 122 | # If true, '()' will be appended to :func: etc. cross-reference text. 123 | #add_function_parentheses = True 124 | 125 | # If true, the current module name will be prepended to all description 126 | # unit titles (such as .. function::). 127 | #add_module_names = True 128 | 129 | # If true, sectionauthor and moduleauthor directives will be shown in the 130 | # output. They are ignored by default. 131 | #show_authors = False 132 | 133 | # The name of the Pygments (syntax highlighting) style to use. 134 | pygments_style = 'sphinx' 135 | 136 | # A list of ignored prefixes for module index sorting. 137 | #modindex_common_prefix = [] 138 | 139 | # If true, keep warnings as "system message" paragraphs in the built documents. 140 | #keep_warnings = False 141 | 142 | 143 | # -- Options for HTML output ---------------------------------------------- 144 | 145 | # The theme to use for HTML and HTML Help pages. See the documentation for 146 | # a list of builtin themes. 147 | # html_theme = 'default' 148 | # html_theme = "bizstyle" 149 | # html_theme = "nature" 150 | html_theme = 'sphinx_rtd_theme' 151 | github_repo = 'unitvelo' 152 | github_nb_repo = 'unitvelo' 153 | html_theme_options = dict(navigation_depth=1, titles_only=True) 154 | 155 | 156 | # import sphinx_bootstrap_theme 157 | # html_theme = 'bootstrap' 158 | # html_theme_path = sphinx_bootstrap_theme.get_html_theme_path() 159 | 160 | # Theme options are theme-specific and customize the look and feel of a theme 161 | # further. For a list of options available for each theme, see the 162 | # documentation. 
163 | #html_theme_options = {} 164 | 165 | 166 | # Add any paths that contain custom themes here, relative to this directory. 167 | #html_theme_path = [] 168 | 169 | # The name for this set of Sphinx documents. If None, it defaults to 170 | # " v documentation". 171 | #html_title = None 172 | 173 | # A shorter title for the navigation bar. Default is the same as html_title. 174 | #html_short_title = None 175 | 176 | # The name of an image file (relative to this directory) to place at the top 177 | # of the sidebar. 178 | #html_logo = None 179 | 180 | # The name of an image file (within the static path) to use as favicon of the 181 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 182 | # pixels large. 183 | #html_favicon = None 184 | 185 | # Add any paths that contain custom static files (such as style sheets) here, 186 | # relative to this directory. They are copied after the builtin static files, 187 | # so a file named "default.css" will overwrite the builtin "default.css". 188 | html_static_path = ['_static'] 189 | 190 | # Add any extra paths that contain custom files (such as robots.txt or 191 | # .htaccess) here, relative to this directory. These files are copied 192 | # directly to the root of the documentation. 193 | #html_extra_path = [] 194 | 195 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 196 | # using the given strftime format. 197 | #html_last_updated_fmt = '%b %d, %Y' 198 | 199 | # If true, SmartyPants will be used to convert quotes and dashes to 200 | # typographically correct entities. 201 | #html_use_smartypants = True 202 | 203 | # Custom sidebar templates, maps document names to template names. 204 | # html_sidebars = { 205 | # '**': ['install.html','usage.html', 'tutorial.html', 'api.html', 'faq.html', 'release.html']} 206 | 207 | # Additional templates that should be rendered to pages, maps page names to 208 | # template names. 
209 | #html_additional_pages = {} 210 | 211 | # If false, no module index is generated. 212 | #html_domain_indices = True 213 | 214 | # If false, no index is generated. 215 | #html_use_index = True 216 | 217 | # If true, the index is split into individual pages for each letter. 218 | #html_split_index = False 219 | 220 | # If true, links to the reST sources are added to the pages. 221 | #html_show_sourcelink = True 222 | 223 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 224 | #html_show_sphinx = True 225 | 226 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 227 | #html_show_copyright = True 228 | 229 | # If true, an OpenSearch description file will be output, and all pages will 230 | # contain a tag referring to it. The value of this option must be the 231 | # base URL from which the finished HTML is served. 232 | #html_use_opensearch = '' 233 | 234 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 235 | #html_file_suffix = None 236 | 237 | # Output file base name for HTML help builder. 238 | htmlhelp_basename = 'UniTVelo' 239 | 240 | 241 | # -- Options for LaTeX output --------------------------------------------- 242 | 243 | latex_elements = { 244 | # The paper size ('letterpaper' or 'a4paper'). 245 | #'papersize': 'letterpaper', 246 | 247 | # The font size ('10pt', '11pt' or '12pt'). 248 | #'pointsize': '10pt', 249 | 250 | # Additional stuff for the LaTeX preamble. 251 | #'preamble': '', 252 | } 253 | 254 | # Grouping the document tree into LaTeX files. List of tuples 255 | # (source start file, target name, title, 256 | # author, documentclass [howto, manual, or own class]). 257 | latex_documents = [ 258 | ('index', 'UniTVelo.tex', u'UniTVelo Documentation', 259 | u'Mingze Gao', 'manual'), 260 | ] 261 | 262 | # The name of an image file (relative to this directory) to place at the top of 263 | # the title page. 
264 | #latex_logo = None 265 | 266 | # For "manual" documents, if this is true, then toplevel headings are parts, 267 | # not chapters. 268 | #latex_use_parts = False 269 | 270 | # If true, show page references after internal links. 271 | #latex_show_pagerefs = False 272 | 273 | # If true, show URL addresses after external links. 274 | #latex_show_urls = False 275 | 276 | # Documents to append as an appendix to all manuals. 277 | #latex_appendices = [] 278 | 279 | # If false, no module index is generated. 280 | #latex_domain_indices = True 281 | 282 | 283 | # -- Options for manual page output --------------------------------------- 284 | 285 | # One entry per manual page. List of tuples 286 | # (source start file, name, description, authors, manual section). 287 | man_pages = [ 288 | ('index', 'UniTVelo', u'UniTVelo Documentation', 289 | [u'Mingze Gao'], 1) 290 | ] 291 | 292 | # If true, show URL addresses after external links. 293 | #man_show_urls = False 294 | 295 | 296 | # -- Options for Texinfo output ------------------------------------------- 297 | 298 | # Grouping the document tree into Texinfo files. List of tuples 299 | # (source start file, target name, title, author, 300 | # dir menu entry, description, category) 301 | texinfo_documents = [ 302 | ('index', 'UniTVelo', u'UniTVelo Documentation', 303 | u'Mingze Gao', 'UniTVelo', 'One line description of project.', 304 | 'Miscellaneous'), 305 | ] 306 | 307 | # Documents to append as an appendix to all manuals. 308 | #texinfo_appendices = [] 309 | 310 | # If false, no module index is generated. 311 | #texinfo_domain_indices = True 312 | 313 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 314 | #texinfo_show_urls = 'footnote' 315 | 316 | # If true, do not generate a @detailmenu in the "Top" node's menu. 
317 | #texinfo_no_detailmenu = False 318 | 319 | -------------------------------------------------------------------------------- /docs/getting_started.rst: -------------------------------------------------------------------------------- 1 | Getting Started 2 | =============== 3 | 4 | Public Datasets 5 | --------------- 6 | 7 | Examples of UniTVelo and steps for reproducible results are provided in Jupyter Notebook under notebooks_ folder. 8 | For start, please refer to records analyzing `Mouse Erythroid`_ and `Human BoneMarrow`_ datasets. 9 | 10 | RNA Velocity on New Dataset 11 | --------------------------- 12 | 13 | UniTVelo provides an integrated function for velocity analysis by default whilst specific configurations might need to be adjusted accordingly. 14 | 15 | # 1. Import package:: 16 | 17 | import unitvelo as utv 18 | 19 | # 2. Sub-class and override base configuration file (here lists a few frequently used), please refer `config.py` for detailed arguments:: 20 | 21 | velo = utv.config.Configuration() 22 | velo.R2_ADJUST = True 23 | velo.IROOT = None 24 | velo.FIT_OPTION = '1' 25 | velo.GPU = 0 26 | 27 | Arguments: 28 | 29 | - `velo.R2_ADJUST` (bool), linear regression R-squared on extreme quantile (default) or full data (adjusted). 30 | - `velo.IROOT` (str), specify root cell cluster would enable diffusion map based time initialization, default None. 31 | - `velo.FIT_OPTION` (str), '1' Unified-time mode (default), '2' Independent mode 32 | - `velo.GPU` (int), specify the GPU card used for fitting, -1 will switch to CPU mode, default 0 33 | 34 | # 3. Run model (label refers to column name in adata.obs specifying celltypes):: 35 | 36 | adata = utv.run_model(path_to_adata, label, config_file=velo) 37 | scv.pl.velocity_embedding_stream(adata, color=label, dpi=100, title='') 38 | 39 | # 4. 
Evaluation metrics (Optional):: 40 | 41 | # Cross Boundary Direction Correctness 42 | # Ground truth should be given via `cluster_edges` 43 | metrics = {} 44 | metrics = utv.evaluate(adata, cluster_edges, label, 'velocity') 45 | 46 | # Latent time estimation 47 | scv.pl.scatter(adata, color='latent_time', color_map='gnuplot', size=20) 48 | 49 | # Phase portraits for individual genes (experimental) 50 | utv.pl.plot_range(gene_name, adata, velo, show_ax=True, time_metric='latent_time') 51 | 52 | .. _notebooks: https://github.com/StatBiomed/UniTVelo/tree/main/notebooks 53 | .. _`Mouse Erythroid`: ./Figure2_ErythroidMouse 54 | .. _`Human BoneMarrow`: ./Figure3_BoneMarrow -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | |PyPI| |Docs| |Downloads| 2 | 3 | .. |PyPI| image:: https://img.shields.io/pypi/v/unitvelo.svg 4 | :target: https://pypi.org/project/unitvelo 5 | .. |Docs| image:: https://readthedocs.org/projects/unitvelo/badge/?version=latest 6 | :target: https://unitvelo.readthedocs.io 7 | .. |Downloads| image:: https://static.pepy.tech/personalized-badge/unitvelo?period=total&units=international_system&left_color=black&right_color=orange&left_text=Downloads 8 | :target: https://github.com/StatBiomed/UniTVelo 9 | 10 | UniTVelo for temporally unified RNA velocity analysis 11 | ===================================================== 12 | 13 | Temporally unified RNA velocity for single cell trajectory inference (UniTVelo) is implemented on Python 3 and TensorFlow 2. 14 | The model estimates velocity of each gene and updates cell time based on phase portraits concurrently. 15 | 16 | ..
image:: https://github.com/StatBiomed/UniTVelo/blob/main/figures/HumanBoneMarrow.png?raw=true 17 | :width: 300px 18 | :align: center 19 | 20 | The major features of UniTVelo are, 21 | 22 | * Using spliced RNA oriented design to model RNA velocity and transcription rates 23 | * Introducing a unified latent time (`Unified-time mode`) across whole transcriptome to incorporate stably and monotonically changed genes 24 | * Retaining gene-specific time metrics (`Independent mode`) for complex datasets 25 | 26 | UniTVelo has proved its robustness in 10 different datasets. Details can be found via our manuscript in bioRxiv which is currently under review (UniTVelo_). 27 | 28 | .. _UniTVelo: https://www.biorxiv.org/content/10.1101/2022.04.27.489808v1 29 | 30 | .. toctree:: 31 | :caption: Home 32 | :maxdepth: 1 33 | :hidden: 34 | 35 | about 36 | install 37 | release 38 | references 39 | 40 | .. toctree:: 41 | :caption: Tutorials 42 | :maxdepth: 1 43 | :hidden: 44 | 45 | getting_started 46 | Figure2_ErythroidMouse 47 | Figure3_BoneMarrow 48 | Figure4_IntestinalOrganoid 49 | AuxiliaryFunctions -------------------------------------------------------------------------------- /docs/install.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | GPU Acceleration 5 | ---------------- 6 | 7 | UniTVelo is designed based on TensorFlow's automatic differentiation architecture. 8 | Please make sure TensorFlow_ and related CUDA_ dependencies are correctly installed. 9 | 10 | Please use the following scripts to confirm TensorFlow is using the GPU:: 11 | 12 | import tensorflow as tf 13 | print ("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU'))) 14 | 15 | If GPU is not available, UniTVelo will automatically switch to CPU for model fitting or it can be specified in `config.py` (see `Getting Started`_). 16 | 17 | Main Module 18 | ----------- 19 | 20 | UniTVelo requires Python 3.7 or later.
21 | We recommend to use Anaconda_ environment for version control and to avoid potential conflicts:: 22 | 23 | conda create -n unitvelo python=3.7 24 | conda activate unitvelo 25 | 26 | UniTVelo package can be conveniently installed via PyPI (for stable version) :: 27 | 28 | pip install unitvelo 29 | 30 | or directly from GitHub repository (for development version):: 31 | 32 | pip install git+https://github.com/StatBiomed/UniTVelo 33 | 34 | Dependencies 35 | ------------ 36 | 37 | Most required dependencies are automatically installed, e.g. 38 | 39 | - `scvelo `_ for a few pre- and post-processing analysis 40 | - `statsmodels `_ for regression analysis 41 | - `jupyter `_ for running RNA velocity within notebooks 42 | 43 | If you run into any issues or errors are raised during the installation process, feel free to contact us at GitHub_. 44 | 45 | .. _Tensorflow: https://www.tensorflow.org/install 46 | .. _CUDA: https://developer.nvidia.com/cuda-downloads 47 | .. _Anaconda: https://www.anaconda.com/ 48 | .. _GitHub: https://github.com/StatBiomed/UniTVelo 49 | .. _`Getting Started`: getting_started -------------------------------------------------------------------------------- /docs/references.rst: -------------------------------------------------------------------------------- 1 | References 2 | ========== 3 | 4 | .. [Gao22] Gao *et al.* (2022), 5 | *UniTVelo: Temporally Unified RNA velocity reinforces single-cell trajectory inference*, 6 | `bioRxiv `__. 7 | 8 | .. [Bergen20] Bergen *et al.* (2020), 9 | *Generalizing RNA velocity to transient cell states through dynamical modeling*, 10 | `Nature Biotech `__. 11 | 12 | .. [Manno18] La Manno *et al.* (2018), 13 | *RNA velocity of single cells*, 14 | `Nature `__. 
-------------------------------------------------------------------------------- /docs/release.rst: -------------------------------------------------------------------------------- 1 | Release History 2 | =============== 3 | 4 | Version 0.2.5 5 | ------------- 6 | - Fix issues # 19_ and # 20_ in GitHub repo 7 | - Fix bugs in plotting function, plot_range 8 | - Structurize model outputs 9 | 10 | Version 0.2.4 11 | ------------- 12 | - Support the input of one or multiple genes trends to initialize cell time, see config.py, parameter self.IROOT 13 | - Re-formulate the structure of configuration file 14 | - Fixed bugs on progress bar 15 | 16 | Version 0.2.3 17 | ------------- 18 | - Change threshold for self.AGENES_R2 19 | 20 | Version 0.2.2 21 | ------------- 22 | - Add benchmarking function to scVelo 23 | - Provide prediction script which uses down-sampled data to predict RNA velocity and cell time on entire dataset 24 | - Add notebooks for auxiliary functions 25 | 26 | Version 0.2.1 27 | ------------- 28 | - Support input of both raw path and adata objects 29 | - Fix bugs on logging file 30 | 31 | Version 0.2.0 32 | ------------- 33 | - Beta version of UniTVelo released 34 | - Provide option of using gene counts for model initialization 35 | - Provide reference script for choosing unified-time mode or independent mode 36 | - Provide sampling script when dataset is oversized and GPU memory is bottleneck 37 | - Re-organize configuration file 38 | - Number of velocity genes can be amplified (an adjustable hyper-parameter) during optimization which allows post-analysis on more genes 39 | 40 | Version 0.1.6 41 | ------------- 42 | - Fix bug for early stopping scenarios 43 | 44 | Version 0.1.5 45 | ------------- 46 | - Fix bug in saving files 47 | - Add adjustable parameters for penalty in configuration file 48 | 49 | Version 0.1.4 50 | ------------- 51 | - Add penalty on loss function 52 | - Support informative gene selection in unified-time mode 53 | 54 | Version 0.1.3 
55 | ------------- 56 | - Support CPU mode for model fitting 57 | - Fix bugs in documentations 58 | 59 | Version 0.1.2 60 | ------------- 61 | - Fix bugs in setup.py 62 | - Add tutorials and documentations 63 | 64 | Version 0.1.0 65 | ------------- 66 | - Alpha version of UniTVelo released 67 | 68 | .. _19: https://github.com/StatBiomed/UniTVelo/issues/19 69 | .. _20: https://github.com/StatBiomed/UniTVelo/issues/20 70 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # Just until rtd.org understands pyproject.toml 2 | unitvelo 3 | setuptools 4 | setuptools_scm 5 | typing_extensions 6 | importlib_metadata 7 | sphinx_rtd_theme>=0.3 8 | sphinx_autodoc_typehints 9 | 10 | # converting notebooks to html 11 | ipykernel 12 | sphinx 13 | nbsphinx<=0.8.6 -------------------------------------------------------------------------------- /figures/HumanBoneMarrow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StatBiomed/UniTVelo/c7e9419186a24fbdca8c2a571af3b19f105d497a/figures/HumanBoneMarrow.png -------------------------------------------------------------------------------- /notebooks/README.md: -------------------------------------------------------------------------------- 1 | # Datasets used & validated by UniTVelo 2 | 3 | ## Available Datasets 4 | 5 | - **Mouse erythroid differentiation**: `scv.datasets.gastrulation_erythroid()` 6 | 7 | - **Human bone marrow hematopoieses**: `scv.datasets.bonemarrow()` 8 | 9 | - **Dentate gyrus neurogenesis development**: `scv.datasets.dentategyrus()` 10 | 11 | - **Pancreatic endocrinogenesis**: `scv.datasets.pancreas()` 12 | 13 | ## Available by requesting 14 | 15 | - **Mouse developing retina**: [Dr. Peter Kharchenko](https://www.nature.com/articles/s41592-021-01171-x) 16 | 17 | - **Intestinal organoid differentiation**: [Dr. 
Alexander Van Oudenaarden](https://www.science.org/doi/10.1126/science.aax3072) 18 | 19 | - **Neuron genesis with KCI stimulation**: [Dr. Qi Qiu](https://www.nature.com/articles/s41592-020-0935-4) 20 | 21 | - **Hindbrain (pons) of adolescent mice**: [Dr. Gioele La Manno](https://www.nature.com/articles/s41586-018-0414-6) 22 | 23 | - **Human erythroid differentiation**: [Dr. Melania Barlie](https://genomebiology.biomedcentral.com/articles/10.1186/s13059-021-02414-y) 24 | -------------------------------------------------------------------------------- /notebooks/fig1_toydata.csv: -------------------------------------------------------------------------------- 1 | ,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19 2 | 0,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.739579850633163,0.9473157599859405,0.09069689782336354,0.18439479594235308,0.17572523953276686,0.5148382729385048,0.6425475077412557,0.9999999873980414,0.2337445785815362,0.6178726164216641,0.03067689191084355,0.6098699489666615,0.13071023509837687,0.13504501330316998,0.7555851855431683,0.08102700798190199,0.9486495378951076 3 | 1,0.9999999873980414,0.6028676149435341,0.21640546576236375,0.05768589457147755,0.9999999873980414,0.8902967543690465,0.0,0.11637212257483043,0.029676558478968218,0.050683560548350215,0.610536837921245,0.9999999873980414,0.9999999873980414,0.7189062930410728,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.20540179801173508 4 | 2,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.817605858319439,0.9143047567340545,0.03634544802480377,0.10536845482420176,0.6328776178997941,0.6055351707618684,0.7435811843606643,0.9999999873980414,0.17039012789609842,0.1840613514650613,0.8339446377067361,0.7145715148362797,0.2827609167434275,0.682560845016269,0.6538846199691761,0.8196065251831897,0.0 5 | 
3,0.9999999873980414,0.30043347403989173,0.6702233993564732,0.012337445659795776,0.9999999873980414,0.9999999873980414,0.10436812139232643,0.7509169628610834,0.05035011607105844,0.11203734437003732,0.2950983624032233,0.9999999873980414,0.9999999873980414,0.7939313004317228,0.9999999873980414,0.9986662094888743,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.36578859158908017 6 | 4,0.9999999873980414,0.8692897522996645,0.8249416368198581,0.2560853585600853,0.06768922889023088,0.7045681805175263,0.28809602838009596,0.8262754147290252,0.22674224455840886,0.30643547463114373,0.8616205293219537,0.7532510742021259,0.5788596125785261,0.9253084244846832,0.13737912464421242,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.6175391719443724,0.0 7 | 5,0.9999999873980414,0.5885295024199877,0.7302434052689932,0.02334111341042444,0.9999999873980414,0.0,0.0,0.13071023509837687,0.8672890854359139,0.05101700502564199,0.6508836196735501,0.9999999873980414,0.9999999873980414,0.2487495800596662,0.9999999873980414,0.9999999873980414,0.09636545393732376,0.9999999873980414,0.9999999873980414,0.7195731819956563 8 | 6,0.9999999873980414,0.9999999873980414,0.9673224286234472,0.7712570759758819,0.9366455367126036,0.08269423036836088,0.10436812139232643,0.31377125313156284,0.6415471743093804,0.19673224160214886,0.9999999873980414,0.9126375343475956,0.15071690373588353,0.8406135272525717,0.13237745748483576,0.6912304014258552,0.21340446546673775,0.37479159247595817,0.8452817499346565,0.09203067573253065 9 | 7,0.9999999873980414,0.5971990588295739,0.6112037268758286,0.07769256320898421,0.9999999873980414,0.0,0.9689896510099061,0.09936645423294976,0.018006001773755997,0.8719573081179988,0.5791930570558179,0.9999999873980414,0.9999999873980414,0.2007335753296502,0.9999999873980414,0.9999999873980414,0.15605201537255198,0.9999999873980414,0.9999999873980414,0.6798932891979348 10 | 
8,0.03534511459292844,0.6332110623770859,0.17172390580526553,0.029676558478968218,0.9999999873980414,0.0,0.13971323598525487,0.08936311991419643,0.017005668341880664,0.08369456380023621,0.6562187313102186,0.9999999873980414,0.0,0.7939313004317228,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.9669889841461554,0.3221073650638573 11 | 9,0.9999999873980414,0.5198399400978815,0.7819272992492188,0.8676225299132057,0.9999999873980414,0.9189729794161394,0.9999999873980414,0.15005001478129998,0.02500833579688333,0.029676558478968218,0.37679225933970883,0.9999999873980414,0.9999999873980414,0.3207735871546902,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.7752584097033832 12 | 10,0.9999999873980414,0.8859619761642534,0.9999999873980414,0.7659219643392134,0.9999999873980414,0.03101033638813533,0.12204067868879065,0.13271090196212754,0.09203067573253065,0.21773924367153086,0.9999999873980414,0.13437812434858643,0.7055685139494017,0.5848616131697781,0.9093030895746779,0.5795265015331097,0.9999999873980414,0.9613204280321952,0.8012670789321419,0.8696231967769563 13 | 11,0.0,0.3894631494767964,0.6248749504447915,0.7985995231138077,0.9999999873980414,0.0,0.9816605411469936,0.1437145697127562,0.06668889545835555,0.044681559957098216,0.679559844720643,0.11103701093816198,0.9999999873980414,0.7852617440221366,0.9999999873980414,0.9999999873980414,0.874624863936333,0.9999999873980414,0.9999999873980414,0.5521840543951839 14 | 12,0.7679226312029641,0.10003334318753332,0.9893297641247045,0.7312437387008686,0.9299766471667681,0.05768589457147755,0.11403801123378798,0.2154051323304884,0.6002000591251999,0.7055685139494017,0.9999999873980414,0.8439479720254894,0.1287095682346262,0.9543180940090679,0.4241413751151413,0.7412470730196219,0.9999999873980414,0.7579192968842108,0.045015004434389994,0.9119706453930121 15 | 
13,0.040346781752305105,0.5915305027156137,0.7739246317942161,0.0763587852998171,0.9999999873980414,0.0,0.08669556409586221,0.18739579623797908,0.05001667159376666,0.06268756173085421,0.6598866205604281,0.9999999873980414,0.9999999873980414,0.742580850928789,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.6275425062631257 16 | 14,0.9999999873980414,0.5981993922614492,0.700566846790025,0.007002334023127332,0.9999999873980414,0.0043347782047931105,0.9999999873980414,0.17939312878297642,0.034344781161053106,0.046348782343557104,0.5045014941424597,0.9999999873980414,0.9999999873980414,0.7235745157231577,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.7112370700633619 17 | 15,0.9999999873980414,0.5671890558733139,0.2394131346954964,0.730576849746285,0.9999999873980414,0.8156051914556883,0.9999999873980414,0.9363120922353119,0.0,0.05868622800335288,0.3474491453380324,0.9999999873980414,0.9999999873980414,0.2617539146740455,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.12037345630233176 18 | 16,0.670556843833765,0.9999999873980414,0.9999999873980414,0.8009336344548501,0.9999999873980414,0.06768922889023088,0.028342780569801107,0.7149049593135715,0.6732243996520992,0.7562520744977519,0.9999999873980414,0.05301767188939266,0.6992330688808579,0.9283094247803092,0.25241746930987574,0.6072023931483272,0.8102700798190199,0.760586852702545,0.04534844891168177,0.8939646436192561 19 | 17,0.9999999873980414,0.9999999873980414,0.9303100916440599,0.736578850337537,0.8799599755730014,0.8179393027967308,0.1213737897342071,0.5718572785553988,0.20006668637506664,0.7132377369271126,0.9999999873980414,0.9999999873980414,0.14404801419004798,0.883627864823211,0.17005668341880664,0.703567847085651,0.16805601655505598,0.5891963913745712,0.8796265310957097,0.9263087579165585 20 | 
18,0.0,0.9999999873980414,0.7692564091121312,0.27009002660633996,0.8399466382979881,0.26142047019675374,0.18072690669214353,0.8816271979594603,0.1587195711908862,0.2484161355823744,0.9999999873980414,0.7475825180881657,0.15505168194067664,0.9523174271453172,0.07769256320898421,0.9999999873980414,0.9999999873980414,0.9903300975565799,0.6352117292408366,0.7409136285423301 21 | 19,0.0,0.9999999873980414,0.7819272992492188,0.6138712826941628,0.0,0.6178726164216641,0.16138712700922042,0.9579859832592774,0.0043347782047931105,0.14904968134942465,0.829609859501943,0.10703567721066065,0.0,0.9656552062369883,0.042014004138763994,0.9999999873980414,0.07369122948148288,0.9999999873980414,0.23241080067236908,0.13871290255337954 22 | 20,0.9999999873980414,0.6282093952177092,0.8032677457958926,0.8022674123640172,0.9999999873980414,0.024008002365007997,0.9999999873980414,0.8219406365242321,0.0020006668637506664,0.9546515384863596,0.3371123665419873,0.9999999873980414,0.9999999873980414,0.7325775166100357,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.17305768371443264 23 | 21,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.7159052927454468,0.9999999873980414,0.03634544802480377,0.5841947242151946,0.18739579623797908,0.7072357363358606,0.7602534082252532,0.9999999873980414,0.5975325033068657,0.26008669228758663,0.901633866596967,0.6968989575398155,0.8222740810015239,0.712570847972529,0.14938312582671642,0.7589196303160861,0.041680559661472216 24 | 22,0.07669222977710888,0.9999999873980414,0.9816605411469936,0.7115705145406537,0.9999999873980414,0.03267755877459422,0.11737245600670576,0.12537512346170843,0.5878626134654041,0.2527509137871675,0.9999999873980414,0.9086362006200943,0.7142380703589879,0.9296432026894763,0.3067689191084355,0.2960986958350986,0.9999999873980414,0.1887295741471462,0.8119373022054788,0.9296432026894763 25 | 
23,0.9999999873980414,0.8309436374111101,0.9999999873980414,0.6738912886066828,0.9189729794161394,0.8222740810015239,0.17405801714630798,0.11637212257483043,0.3701233697938733,0.6648882877198048,0.9999999873980414,0.7009002912673168,0.5358452750078868,0.9453150931221899,0.5841947242151946,0.14004668046254665,0.13304434643941931,0.17772590639651753,0.8432810830709059,0.06035345038981177 26 | 24,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.13437812434858643,0.9173057570296805,0.04001333727501333,0.09269756468711421,0.22640880008111708,0.6802267336752266,0.7169056261773221,0.9999999873980414,0.760586852702545,0.1363787912123371,0.12370790107524954,0.7092364031996112,0.7019006246991921,0.6955651796306483,0.24774924662779085,0.7799266323854681,0.07002334023127332 27 | 25,0.9999999873980414,0.9186395349388476,0.9999999873980414,0.7979326341592241,0.9999999873980414,0.08436145275481977,0.08636211961857043,0.715571848268155,0.6512170641508419,0.6972324020171072,0.9999999873980414,0.7279092939279508,0.1497165703040082,0.9423140928265639,0.20173390876152553,0.25108369140070863,0.7975991896819323,0.3141046976088546,0.0,0.09636545393732376 28 | 26,0.6912304014258552,0.9999999873980414,0.9999999873980414,0.08969656439148821,0.9999999873980414,0.055018338753143325,0.00033344447729177773,0.1123707888473291,0.640546840877505,0.21307102098944597,0.9999999873980414,0.9999999873980414,0.6708902883110568,0.9999999873980414,0.27142380451550707,0.3444481450424064,0.8299433039792348,0.7622540750890039,0.046348782343557104,0.06668889545835555 29 | 27,0.19506501921568997,0.9999999873980414,0.9999999873980414,0.07369122948148288,0.9686562065326143,0.04068022622959688,0.015671890432713553,0.7499166294292081,0.21007002069381997,0.1677225720777642,0.9999999873980414,0.8999666442105081,0.19439813026110642,0.9793264298059512,0.6965655130625237,0.3341113662463613,0.9129709788248874,0.7599199637479614,0.017339112819172442,0.06968989575398155 30 | 
28,0.9999999873980414,0.5488496096222661,0.7675891867256723,0.04568189338897355,0.9999999873980414,0.9039679779380094,0.0,0.049016338161891326,0.838612860388821,0.9253084244846832,0.6295431731268764,0.9999999873980414,0.9999999873980414,0.21940646605798975,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.0016672223864588886,0.7822607437265106 31 | 29,0.5911970582383219,0.9003000886877999,0.9999999873980414,0.6898966235166881,0.9999999873980414,0.8029343013186008,0.015338445955421776,0.7569189634523354,0.6588862871285528,0.18539512937422842,0.9999999873980414,0.9999999873980414,0.6632210653333459,0.9999999873980414,0.7205735154275317,0.5651883890095633,0.9999999873980414,0.7792597434308846,0.016672223864588886,0.9303100916440599 32 | 30,0.014004668046254665,0.6462153969914652,0.7279092939279508,0.0613537838216871,0.9999999873980414,0.9999999873980414,0.11037012198357843,0.11503834466566332,0.05601867218501866,0.11937312287045643,0.29443147344863974,0.9699899844417814,0.9999999873980414,0.7709236314985901,0.9999999873980414,0.9999999873980414,0.07435811843606643,0.9999999873980414,0.9999999873980414,0.3487829232471995 33 | 31,0.9999999873980414,0.5148382729385048,0.19506501921568997,0.05735245009418577,0.9999999873980414,0.0,0.0,0.04068022622959688,0.01900633520563133,0.8299433039792348,0.5905301692837384,0.9999999873980414,0.9999999873980414,0.2830943612207193,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.0,0.15971990462276153 34 | 32,0.0,0.1603867935773451,0.961653872509487,0.7175725151319057,0.9276425358257256,0.9999999873980414,0.055018338753143325,0.5555184991681017,0.21907302158069797,0.6842280674027279,0.9999999873980414,0.11770590048399754,0.754584852111293,0.9999999873980414,0.7872624108858872,0.5558519436453935,0.9999999873980414,0.5721907230326906,0.11603867809753865,0.8376125269569457 35 | 
33,0.9999999873980414,0.13971323598525487,0.9713237623509485,0.7079026252904441,0.0,0.1420473473262973,0.20406802010256797,0.7129042924498208,0.5261753851664253,0.14571523657650687,0.9999999873980414,0.23007668933132663,0.6658886211516801,0.9063020892790519,0.6422140632639639,0.1587195711908862,0.9999999873980414,0.7942647449090146,0.06302100620814599,0.904634866892593 36 | 34,0.9999999873980414,0.9999999873980414,0.9709903178736567,0.6932310682896059,0.8962987549602985,0.7852617440221366,0.2527509137871675,0.10836945511982776,0.36145381338428706,0.4014671506593004,0.9999999873980414,0.7672557422483806,0.3337779217690695,0.9999999873980414,0.3784594817261677,0.12070690077962354,0.9999999873980414,0.1573857932817191,0.6722240662202239,0.0 37 | 35,0.9999999873980414,0.6438812856504228,0.7569189634523354,0.00033344447729177773,0.9999999873980414,0.0,0.02334111341042444,0.07269089604960755,0.04001333727501333,0.04368122652522288,0.6532177310145926,0.9999999873980414,0.9999999873980414,0.7559186300204601,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.21140379860298708 38 | 36,0.044681559957098216,0.3544514793611597,0.7822607437265106,0.040346781752305105,0.9999999873980414,0.0,0.05168389398022555,0.14338112523546442,0.02800933609250933,0.9683227620553225,0.7079026252904441,0.9999999873980414,0.9999999873980414,0.22707568903570063,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.6682227324927226 39 | 37,0.9999999873980414,0.9999999873980414,0.9629876504186541,0.7622540750890039,0.0,0.0913637867779471,0.13071023509837687,0.18072690669214353,0.5398466087353881,0.6412137298320886,0.9999999873980414,0.7729242983623408,0.697565846494399,0.895631866005715,0.34978325667907484,0.7835945216356777,0.9999999873980414,0.7582527413615026,0.8982994218240492,0.09103034230065532 40 | 
38,0.9999999873980414,0.9199733128480148,0.9699899844417814,0.17472490610089153,0.9999999873980414,0.7779259655217174,0.26008669228758663,0.09469823155086488,0.6972324020171072,0.6228742835810408,0.9999999873980414,0.748582851520041,0.6808936226298101,0.13404467987129465,0.7612537416571286,0.8532844173896592,0.715571848268155,0.1630543493956793,0.09236412020982243,0.02367455788771622 41 | 39,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.12537512346170843,0.9999999873980414,0.09336445364169776,0.13071023509837687,0.6898966235166881,0.6712237327883486,0.7555851855431683,0.9999999873980414,0.7072357363358606,0.18039346221485175,0.9999999873980414,0.7062354029039852,0.3311103659507353,0.9999999873980414,0.6568856202648021,0.16105368253192864,0.08402800827752799 42 | 40,0.19839946398860775,0.9003000886877999,0.925641868961975,0.7849282995448448,0.9313104250759352,0.8659553075267468,0.9999999873980414,0.2097365762165282,0.6728909551748075,0.7385795172012877,0.9999999873980414,0.9999999873980414,0.2457485797640402,0.8892964209371712,0.7042347360402346,0.6475491749006324,0.9999999873980414,0.8106035242963117,0.03201066982001066,0.8949649770511314 43 | 41,0.0,0.5871957245108206,0.6812270671071019,0.05035011607105844,0.9999999873980414,0.0,0.11303767780191265,0.1227075676433742,0.04801600473001599,0.07369122948148288,0.5971990588295739,0.9999999873980414,0.9999999873980414,0.8042680792277679,0.9999999873980414,0.9986662094888743,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.3517839235428255 44 | 42,0.9999999873980414,0.9999999873980414,0.9779926518967841,0.7625875195662957,0.9999999873980414,0.09103034230065532,0.11603867809753865,0.30176725194905885,0.22640880008111708,0.19806601951131597,0.9999999873980414,0.898632866301341,0.1213737897342071,0.10370123243774287,0.8452817499346565,0.664554843242513,0.1287095682346262,0.34844947876990773,0.8592864179809112,0.052350782934809104 45 | 
43,0.16138712700922042,0.9999999873980414,0.9469823155086488,0.7962654117727652,0.9999999873980414,0.04434811547980644,0.025341780274175107,0.6605535095150117,0.20273424219340086,0.8306101929338183,0.8556185287307017,0.9373124256671872,0.19806601951131597,0.8582860845490359,0.664554843242513,0.6358786181954201,0.9999999873980414,0.760586852702545,0.9999999873980414,0.03334444772917777 46 | 44,0.17739246191922575,0.9233077576209325,0.9999999873980414,0.7809269658173434,0.9999999873980414,0.04734911577543244,0.9999999873980414,0.7159052927454468,0.21773924367153086,0.1183727894385811,0.9999999873980414,0.8599533069354948,0.6768922889023088,0.9809936521924101,0.6835611784481443,0.5518506099178921,0.9999999873980414,0.739579850633163,0.041680559661472216,0.9263087579165585 47 | 45,0.9999999873980414,0.9999999873980414,0.9349783143261448,0.7535845186794177,0.9153050901659299,0.8549516397761181,0.09869956527836621,0.619539838808123,0.6578859536966775,0.17705901744193397,0.9999999873980414,0.9269756468711421,0.7672557422483806,0.8832944203459192,0.14271423628088087,0.7049016249948181,0.15971990462276153,0.5868622800335288,0.8702900857315399,0.07235745157231577 48 | 46,0.9999999873980414,0.889629865414463,0.2487495800596662,0.3221073650638573,0.0,0.2594198033330031,0.14871623687213287,0.03201066982001066,0.9213070907571819,0.907635867188219,0.8212737475696485,0.9999999873980414,0.9999999873980414,0.21907302158069797,0.9999999873980414,0.9699899844417814,0.9999999873980414,0.9999999873980414,0.1527175705996342,0.16105368253192864 49 | 47,0.9999999873980414,0.9999999873980414,0.8209403030923568,0.2830943612207193,0.8612870848446619,0.3937979276815895,0.24174724603653885,0.0,0.1483827923948411,0.20006668637506664,0.9999999873980414,0.15771923775901087,0.0823607858910691,0.9999999873980414,0.9633210948959459,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.5765255012374837,0.08102700798190199 50 | 
48,0.9999999873980414,0.9999999873980414,0.8222740810015239,0.21373790994402952,0.8432810830709059,0.6805601781525183,0.2274091335129924,0.09103034230065532,0.19806601951131597,0.3411137002694886,0.9999999873980414,0.24008002365007997,0.6118706158304121,0.9373124256671872,0.18172724012401886,0.9999999873980414,0.9999999873980414,0.021007002069381997,0.6018672815116588,0.9086362006200943 51 | 49,0.0,0.9999999873980414,0.9603200946003199,0.07535845186794177,0.9999999873980414,0.8186061917513143,0.05168389398022555,0.28676225047092885,0.682560845016269,0.718572848563781,0.9999999873980414,0.8432810830709059,0.7135711814044043,0.9266422023938503,0.10503501034690998,0.3204401426773984,0.18539512937422842,0.3097699194040615,0.1093697885517031,0.8482827502302825 52 | 50,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.7659219643392134,0.9999999873980414,0.04801600473001599,0.613537838216871,0.18772924071527086,0.7122374034952372,0.8122707466827706,0.9999999873980414,0.634544840286253,0.22374124426278286,0.8592864179809112,0.661553842946887,0.847615861275699,0.6878959566529375,0.19539846369298175,0.7529176297248341,0.907635867188219 53 | 51,0.018672890728339553,0.9999999873980414,0.9999999873980414,0.826608859206317,0.9999999873980414,0.7919306335679721,0.07802600768627599,0.4911637150507886,0.6458819525141735,0.7239079602004495,0.9999999873980414,0.12937645718920976,0.6835611784481443,0.9489829823723994,0.2050683535344433,0.35278425697470084,0.9999999873980414,0.29709902926697396,0.018339446251047775,0.07502500739064999 54 | 52,0.0,0.9999999873980414,0.9999999873980414,0.778592854476301,0.9339779808942694,0.02034011311479844,0.10970323302899487,0.28676225047092885,0.6348782847635448,0.7342447389964946,0.9999999873980414,0.9999999873980414,0.12470823450712487,0.9999999873980414,0.2050683535344433,0.6325441734225024,0.1467155700083822,0.565521833486855,0.8886295319825877,0.04834944920730777 55 | 
53,0.03801267041126266,0.595531836443115,0.7082360697677359,0.018339446251047775,0.9999999873980414,0.9999999873980414,0.11503834466566332,0.039013003843137994,0.0,0.12204067868879065,0.30343447433551773,0.9329776474623941,0.0,0.16238746044109575,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.36312103577074595 56 | 54,0.9999999873980414,0.853617861866951,0.7505835183837917,0.2750916937657166,0.049349782639183104,0.5455151648493484,0.2680893597425893,0.02200733550125733,0.09169723125523888,0.8156051914556883,0.9999999873980414,0.11137045541545376,0.12370790107524954,0.9999999873980414,0.041680559661472216,0.9823274301015772,0.9999999873980414,0.9999999873980414,0.2967655847896822,0.1513837926904671 57 | 55,0.9999999873980414,0.8929643101873808,0.9999999873980414,0.6055351707618684,0.8812937534821685,0.8286095260700677,0.19139712996548042,0.16105368253192864,0.4744914911861997,0.6718906217429321,0.9999999873980414,0.7825941882038023,0.5651883890095633,0.9999999873980414,0.3984661503636744,0.19539846369298175,0.9999999873980414,0.08269423036836088,0.7775925210444257,0.9753250960784499 58 | 56,0.9999999873980414,0.8936311991419643,0.9123040898703039,0.10603534377878532,0.9333110919396859,0.01167055670521222,0.09236412020982243,0.4551517115032766,0.6288762841722928,0.7055685139494017,0.9999999873980414,0.913637867779471,0.7555851855431683,0.8699566412542481,0.15605201537255198,0.29543180688051507,0.6838946229254361,0.5251750517345499,0.11503834466566332,0.859619862458203 59 | 57,0.9999999873980414,0.9999999873980414,0.8342780821840279,0.5048349386197515,0.0,0.2624208036286291,0.1660553496913053,0.0,0.0006668889545835555,0.07602534082252532,0.8649549740948714,0.9293097582121845,0.9999999873980414,0.9999999873980414,0.03401133668376133,0.9999999873980414,0.9999999873980414,0.9999999873980414,0.2337445785815362,0.8569523066398688 60 | 
58,0.9999999873980414,0.9999999873980414,0.7819272992492188,0.27009002660633996,0.04801600473001599,0.6022007259889506,0.27442480481113307,0.07002334023127332,0.20040013085235842,0.6518839531054255,0.9999999873980414,0.13437812434858643,0.1467155700083822,0.9999999873980414,0.1123707888473291,0.9999999873980414,0.7592530747933779,0.00733577850041911,0.5531843878270593,0.1153717891429551 61 | 59,0.9999999873980414,0.12637545689358376,0.8496165281394497,0.24008002365007997,0.8392797493434045,0.7412470730196219,0.24508169080945663,0.04968322711647488,0.2514171358780004,0.3637879247253295,0.9999999873980414,0.21007002069381997,0.5951983919658232,0.9596532056457363,0.1767255729646422,0.9999999873980414,0.9999999873980414,0.006002000591251999,0.6422140632639639,0.9303100916440599 62 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=42" 4 | ] 5 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import io 3 | import os 4 | from setuptools import setup, find_packages 5 | 6 | # version of cudatoolkit and tensorflow 7 | REQUIRED = [ 8 | 'numpy>=1.20', 9 | 'scikit-learn>=0.22', 10 | 'pandas', 11 | 'scipy>=1.4.1', 12 | 'seaborn', 13 | 'matplotlib>=3.3.0', 14 | 'tqdm', 15 | 'scanpy>=1.5', 16 | 'statsmodels', 17 | 'anndata>=0.7.5', 18 | 'scvelo>=0.2.2', 19 | 'IPython', 20 | 'ipykernel', 21 | 'IProgress', 22 | 'ipywidgets', 23 | 'jupyter', 24 | 'tensorflow>=2.4.1' 25 | ] 26 | 27 | here = os.path.abspath(os.path.dirname(__file__)) 28 | with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f: 29 | long_description = '\n' + f.read() 30 | 31 | setup( 32 | name='unitvelo', 33 | version='0.2.5', 34 | # 
use_scm_version=True, 35 | # setup_requires=['setuptools_scm'], 36 | 37 | description='Temporally unified RNA velocity inference', 38 | long_description=long_description, 39 | long_description_content_type='text/markdown', 40 | author='Mingze Gao', 41 | author_email='gmz1229@connect.hku.hk', 42 | python_requires='>=3.7.0', 43 | url='https://github.com/StatBiomed/UniTVelo', 44 | packages=find_packages(), 45 | 46 | entry_points={ 47 | 'console_scripts': ['unitvelo = unitvelo.main:run_model'], 48 | }, 49 | 50 | install_requires=REQUIRED, 51 | extras_require={'parallel': ['multiprocessing>=3.8']}, 52 | include_package_data=True, 53 | license='BSD', 54 | keywords=[ 55 | 'RNA velocity', 56 | 'Unified time', 57 | 'Transcriptomics', 58 | 'Kinetic', 59 | 'Trajectory inference' 60 | ], 61 | 62 | classifiers=[ 63 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 64 | 'License :: OSI Approved :: BSD License', 65 | 'Development Status :: 4 - Beta', 66 | 'Intended Audience :: Science/Research', 67 | 'Natural Language :: English', 68 | 'Programming Language :: Python', 69 | 'Programming Language :: Python :: 3', 70 | 'Programming Language :: Python :: 3.7', 71 | 'Programming Language :: Python :: 3.8', 72 | 'Programming Language :: Python :: 3.9', 73 | 'Topic :: Scientific/Engineering :: Bio-Informatics', 74 | 'Topic :: Scientific/Engineering :: Visualization' 75 | ] 76 | ) -------------------------------------------------------------------------------- /unitvelo/__init__.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import os 3 | from time import gmtime, strftime 4 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 5 | 6 | try: 7 | from setuptools_scm import get_version 8 | __version__ = get_version(root="..", relative_to=__file__) 9 | del get_version 10 | 11 | except (LookupError, ImportError): 12 | try: 13 | from importlib_metadata import version 14 | except: 15 | from importlib.metadata import version 16 | 
__version__ = version(__name__) 17 | del version 18 | 19 | print (f'(Running UniTVelo {__version__})') 20 | print (strftime("%Y-%m-%d %H:%M:%S", gmtime())) 21 | 22 | from .main import run_model 23 | from .config import Configuration 24 | from .eval_utils import evaluate 25 | from .gene_influence import influence 26 | from .utils import choose_mode, subset_adata -------------------------------------------------------------------------------- /unitvelo/config.py: -------------------------------------------------------------------------------- 1 | #%% 2 | #! Base Configuration Class 3 | #! Don't use this class directly. 4 | #! Instead, sub-class it and override the configurations you need to change. 5 | 6 | 7 | class Preprocessing(object): 8 | def __init__(self): 9 | self.MIN_SHARED_COUNTS = 20 10 | 11 | # (int) # of highly variable genes selected for pre-processing, default 2000 12 | # consider decreasing to 1500 when # of cells > 10k 13 | self.N_TOP_GENES = 2000 14 | 15 | self.N_PCS = 30 16 | self.N_NEIGHBORS = 30 17 | 18 | # (bool) use raw un/spliced counts or first order moments 19 | self.USE_RAW = False 20 | 21 | # (bool) rescaled Mu/Ms as input based on variance, default True 22 | self.RESCALE_DATA = True 23 | 24 | class Regularization(object): 25 | def __init__(self): 26 | # (bool) regularization on loss function to push peak time away from 0.5 27 | # mainly used in unified time mode for linear phase portraits 28 | self.REG_LOSS = True 29 | # (float) gloablly adjust the magnitude of the penalty, recommend < 0.1 30 | self.REG_TIMES = 0.075 31 | # (float) scaling parameter of the regularizer 32 | self.REG_SCALE = 1 33 | 34 | # (list of tuples) [(gene1, trend1), (gene2, trend2), (gene3, trend3), ...], 35 | # a list of genes with trend can be one of {increase, decrease}, default None 36 | self.GENE_PRIOR = None 37 | self.GENE_PRIOR_SCALE = 5000 38 | 39 | class Optimizer(object): 40 | def __init__(self): 41 | # (float) learning rate of the main optimizer 42 | 
self.LEARNING_RATE = 1e-2 43 | # (int) maximum iteration rate of main optimizer 44 | self.MAX_ITER = 12000 45 | 46 | class FittingOption(object): 47 | def __init__(self): 48 | # Fitting options under Gaussian model 49 | # '1' = Unified-time mode 50 | # '2' = Independent mode 51 | self.FIT_OPTION = '1' 52 | 53 | # (str, experimental) methods to aggregate time metrix, default 'SVD' 54 | # Max SVD Raw 55 | self.DENSITY = 'SVD' 56 | # (str) whether to reorder cell based on relative positions for time assignment 57 | # Soft_Reorder (default) Hard (for Independent mode) 58 | self.REORDER_CELL = 'Soft_Reorder' 59 | # (bool) aggregate gene-specific time to cell time during fitting 60 | # controlled by self.FIT_OPTION 61 | self.AGGREGATE_T = True 62 | 63 | # (bool, experimental), whether clip negative predictions to 0, default False 64 | self.ASSIGN_POS_U = False 65 | 66 | # (bool, experimental) cell time restricted to (0, 1) if False, default False 67 | self.RESCALE_TIME = False 68 | 69 | class VelocityGenes(object): 70 | def __init__(self): 71 | # (bool) linear regression $R^2$ on extreme quantile (default) or full data (adjusted) 72 | # valid when self.VGENES = 'basic' 73 | self.R2_ADJUST = True 74 | 75 | # (str) selection creteria for velocity genes used in RNA velocity construction, default basic 76 | # 1. raws, all highly variable genes specified by self.N_TOP_GENES will be used 77 | # 2. offset, linear regression $R^2$ and coefficient with offset, will override self.R2_ADJUST 78 | # 3. basic, linear regression $R^2$ and coefficient without offset 79 | # 4. single gene name, fit this designated gene alone, for model validation purpose only 80 | # 5. 
[list of gene names], manually provide a list of genes as velocity genes in string, might improve performance, see scNT 81 | self.VGENES = 'basic' 82 | 83 | # (float) threshold of R2 at later stage of the optimization proces 84 | # to capture the dynamics of more genes beside initially selected velocity genes 85 | # Note: self.AGENES_R2 = 1 will switch to origianl mode with no amplification 86 | self.AGENES_R2 = 1 87 | self.AGENES_THRES = 0.61 88 | 89 | # (bool, experimental) exclude cell that have 0 expression in either un/spliced when contributing to loss function 90 | self.FILTER_CELLS = False 91 | 92 | class CellInitialization(object): 93 | def __init__(self): 94 | # (str) criteria for cell latent time initialization, default None 95 | # 1. None, initialized based on the exact order of input expression matrix 96 | # 2. gcount, str, initialized based on gene counts (https://www.science.org/doi/abs/10.1126/science.aax0249) 97 | # 3. cluster name, str, use diffusion map based time as initialization 98 | # 4. 
[(gene1, trend1), (gene2, trend2), (gene3, trend3), ...], list of tuples, 99 | # a list of genes with trend can be one of {increase, decrease} 100 | self.IROOT = None 101 | 102 | # (int) number of random initializations of time points, default 1 103 | # in rare cases, velocity field generated might be reversed, possibly because stably and monotonically changed genes 104 | # change this parameter to 2 might do the trick 105 | self.NUM_REP = 1 106 | # when self.NUM_REP = 2, the following parameter will determine how the second time will be initialized 107 | # re_pre, reverse the inferred cell time of first run 108 | # re_init, reverse the initialization time of first run 109 | self.NUM_REP_TIME = 're_pre' 110 | 111 | class Configuration(): 112 | def __init__(self): 113 | Preprocessing.__init__(self) 114 | Regularization.__init__(self) 115 | Optimizer.__init__(self) 116 | FittingOption.__init__(self) 117 | VelocityGenes.__init__(self) 118 | CellInitialization.__init__(self) 119 | 120 | # (int) speficy the GPU card for acceleration, default 0 121 | # -1 will switch to CPU mode 122 | self.GPU = 0 123 | 124 | # Gaussian Mixture 125 | self.BASE_FUNCTION = 'Gaussian' 126 | 127 | # Deterministic Curve Linear 128 | self.GENERAL = 'Curve' 129 | 130 | # (str) embedding format of adata, e.g. pca, tsne, umap, 131 | # if None (default), algorithm will choose one automatically 132 | self.BASIS = None 133 | 134 | # (int, experimental) window size for sliding smoothing of distribution with highest probability 135 | # useful when self.DENSITY == 'Max' 136 | # self.WIN_SIZE = 50 -------------------------------------------------------------------------------- /unitvelo/eval_utils.py: -------------------------------------------------------------------------------- 1 | #%% 2 | """ 3 | Evaluation utility functions. 4 | This module contains util functions for computing evaluation scores. 
5 | """ 6 | 7 | import numpy as np 8 | from sklearn.metrics.pairwise import cosine_similarity 9 | 10 | def summary_scores(all_scores): 11 | """Summarize group scores. 12 | 13 | Args: 14 | all_scores (dict{str,list}): 15 | {group name: score list of individual cells}. 16 | 17 | Returns: 18 | dict{str,float}: 19 | Group-wise aggregation scores. 20 | float: 21 | score aggregated on all samples 22 | 23 | """ 24 | sep_scores = {k:np.mean(s) for k, s in all_scores.items() if s} 25 | overal_agg = np.mean([s for k, s in sep_scores.items() if s]) 26 | return sep_scores, overal_agg 27 | 28 | def keep_type(adata, nodes, target, k_cluster): 29 | """Select cells of targeted type 30 | 31 | Args: 32 | adata (Anndata): 33 | Anndata object. 34 | nodes (list): 35 | Indexes for cells 36 | target (str): 37 | Cluster name. 38 | k_cluster (str): 39 | Cluster key in adata.obs dataframe 40 | 41 | Returns: 42 | list: 43 | Selected cells. 44 | 45 | """ 46 | return nodes[adata.obs[k_cluster][nodes].values == target] 47 | 48 | def cross_boundary_correctness( 49 | adata, 50 | k_cluster, 51 | k_velocity, 52 | cluster_edges, 53 | return_raw=False, 54 | x_emb="X_umap" 55 | ): 56 | """Cross-Boundary Direction Correctness Score (A->B) 57 | 58 | Args: 59 | adata (Anndata): 60 | Anndata object. 61 | k_cluster (str): 62 | key to the cluster column in adata.obs DataFrame. 63 | k_velocity (str): 64 | key to the velocity matrix in adata.obsm. 65 | cluster_edges (list of tuples("A", "B")): 66 | pairs of clusters has transition direction A->B 67 | return_raw (bool): 68 | return aggregated or raw scores. 69 | x_emb (str): 70 | key to x embedding for visualization. 71 | 72 | Returns: 73 | dict: 74 | all_scores indexed by cluster_edges or mean scores indexed by cluster_edges 75 | float: 76 | averaged score over all cells. 
77 | 78 | """ 79 | scores = {} 80 | all_scores = {} 81 | 82 | if x_emb == "X_umap": 83 | v_emb = adata.obsm['{}_umap'.format(k_velocity)] 84 | else: 85 | v_emb = adata.obsm[[key for key in adata.obsm if key.startswith(k_velocity)][0]] 86 | 87 | x_emb = adata.obsm[x_emb] 88 | 89 | for u, v in cluster_edges: 90 | sel = adata.obs[k_cluster] == u 91 | nbs = adata.uns['neighbors']['indices'][sel] # [n * 30] 92 | 93 | boundary_nodes = map(lambda nodes:keep_type(adata, nodes, v, k_cluster), nbs) 94 | x_points = x_emb[sel] 95 | x_velocities = v_emb[sel] 96 | 97 | type_score = [] 98 | for x_pos, x_vel, nodes in zip(x_points, x_velocities, boundary_nodes): 99 | if len(nodes) == 0: continue 100 | 101 | position_dif = x_emb[nodes] - x_pos 102 | dir_scores = cosine_similarity(position_dif, x_vel.reshape(1,-1)).flatten() 103 | type_score.append(np.mean(dir_scores)) 104 | 105 | scores[(u, v)] = np.mean(type_score) 106 | all_scores[(u, v)] = type_score 107 | 108 | if return_raw: 109 | return all_scores 110 | 111 | return scores, np.mean([sc for sc in scores.values()]) 112 | 113 | def inner_cluster_coh(adata, k_cluster, k_velocity, return_raw=False): 114 | """In-cluster Coherence Score. 115 | 116 | Args: 117 | adata (Anndata): 118 | Anndata object. 119 | k_cluster (str): 120 | key to the cluster column in adata.obs DataFrame. 121 | k_velocity (str): 122 | key to the velocity matrix in adata.obsm. 123 | return_raw (bool): 124 | return aggregated or raw scores. 125 | 126 | Returns: 127 | dict: 128 | all_scores indexed by cluster_edges mean scores indexed by cluster_edges 129 | float: 130 | averaged score over all cells. 
131 | 132 | """ 133 | clusters = np.unique(adata.obs[k_cluster]) 134 | scores = {} 135 | all_scores = {} 136 | 137 | for cat in clusters: 138 | sel = adata.obs[k_cluster] == cat 139 | nbs = adata.uns['neighbors']['indices'][sel] 140 | same_cat_nodes = map(lambda nodes:keep_type(adata, nodes, cat, k_cluster), nbs) 141 | 142 | velocities = adata.layers[k_velocity] 143 | cat_vels = velocities[sel] 144 | cat_score = [cosine_similarity(cat_vels[[ith]], velocities[nodes]).mean() 145 | for ith, nodes in enumerate(same_cat_nodes) 146 | if len(nodes) > 0] 147 | all_scores[cat] = cat_score 148 | scores[cat] = np.mean(cat_score) 149 | 150 | if return_raw: 151 | return all_scores 152 | 153 | return scores, np.mean([sc for sc in scores.values()]) 154 | 155 | def evaluate( 156 | adata, 157 | cluster_edges, 158 | k_cluster, 159 | k_velocity="velocity", 160 | x_emb="X_umap", 161 | verbose=True 162 | ): 163 | """Evaluate velocity estimation results using 5 metrics. 164 | 165 | Args: 166 | adata (Anndata): 167 | Anndata object. 168 | cluster_edges (list of tuples("A", "B")): 169 | pairs of clusters has transition direction A->B 170 | k_cluster (str): 171 | key to the cluster column in adata.obs DataFrame. 172 | k_velocity (str): 173 | key to the velocity matrix in adata.obsm. 174 | x_emb (str): 175 | key to x embedding for visualization. 176 | 177 | Returns: 178 | dict: 179 | aggregated metric scores. 
180 | 181 | """ 182 | 183 | from .eval_utils import cross_boundary_correctness 184 | from .eval_utils import inner_cluster_coh 185 | crs_bdr_crc = cross_boundary_correctness(adata, k_cluster, k_velocity, cluster_edges, True, x_emb) 186 | ic_coh = inner_cluster_coh(adata, k_cluster, k_velocity, True) 187 | 188 | if verbose: 189 | print("# Cross-Boundary Direction Correctness (A->B)\n{}\nTotal Mean: {}".format(*summary_scores(crs_bdr_crc))) 190 | print("# In-cluster Coherence\n{}\nTotal Mean: {}".format(*summary_scores(ic_coh))) 191 | 192 | return { 193 | "Cross-Boundary Direction Correctness (A->B)": crs_bdr_crc, 194 | "In-cluster Coherence": ic_coh, 195 | } 196 | -------------------------------------------------------------------------------- /unitvelo/gene_influence.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import numpy as np 3 | import pandas as pd 4 | from scipy.sparse import coo_matrix, csr_matrix 5 | 6 | def vals_to_csr(vals, rows, cols, shape): 7 | graph = coo_matrix((vals, (rows, cols)), shape=shape) 8 | graph_neg = graph.copy() 9 | 10 | graph.data = np.clip(graph.data, 0, 1) 11 | graph_neg.data = np.clip(graph_neg.data, -1, 0) 12 | 13 | graph.eliminate_zeros() 14 | graph_neg.eliminate_zeros() 15 | 16 | return graph.tocsr(), graph_neg.tocsr() 17 | 18 | def get_iterative_indices(indices, index, n_recurse_neighbors=2): 19 | def iterate_indices(indices, index, n_recurse_neighbors): 20 | if n_recurse_neighbors > 1: 21 | index = iterate_indices(indices, index, n_recurse_neighbors - 1) 22 | ix = np.append(index, indices[index]) # direct and indirect neighbors 23 | if np.isnan(ix).any(): 24 | ix = ix[~np.isnan(ix)] 25 | return ix.astype(int) 26 | 27 | indices = np.unique(iterate_indices(indices, index, n_recurse_neighbors)) 28 | return indices 29 | 30 | def cosine_correlation(dX, Vi): 31 | dx = dX - dX.mean(-1)[:, None] 32 | Vi_norm = vector_norm(Vi) 33 | 34 | if Vi_norm == 0: 35 | result = 
np.zeros(dx.shape[0]) 36 | else: 37 | result = np.einsum("ij, j", dx, Vi) / (norm(dx) * Vi_norm)[None, :] 38 | 39 | return result 40 | 41 | def vector_norm(x): 42 | """computes the L2-norm along axis 1, equivalent to np.linalg.norm(A, axis=1) 43 | """ 44 | return np.sqrt(np.einsum("i, i -> ", x, x)) 45 | 46 | def norm(A): 47 | """computes the L2-norm along axis 1, equivalent to np.linalg.norm(A, axis=1) 48 | """ 49 | return np.sqrt(np.einsum("ij, ij -> i", A, A) if A.ndim > 1 else np.sum(A * A)) 50 | 51 | class Influence(): 52 | def __init__(self) -> None: 53 | pass 54 | 55 | def get_n_jobs(self, n_jobs): 56 | import os 57 | 58 | if n_jobs is None or (n_jobs < 0 and os.cpu_count() + 1 + n_jobs <= 0): 59 | return 1, os.cpu_count() 60 | elif n_jobs > os.cpu_count(): 61 | return os.cpu_count(), os.cpu_count() 62 | elif n_jobs < 0: 63 | return os.cpu_count() + 1 + n_jobs, os.cpu_count() 64 | else: 65 | return n_jobs, os.cpu_count() 66 | 67 | def get_indices(self, D): 68 | D.data += 1e-6 69 | A = D > 0 70 | n_counts = A.sum(1).A1 # summation over axis 1, equivalent to np.sum(A, 1) 71 | 72 | D.eliminate_zeros() 73 | D.data -= 1e-6 74 | indices = D.indices.reshape((-1, n_counts.min())) 75 | 76 | return indices, D 77 | 78 | def compute_cosines(self): 79 | vals, rows, cols = [], [], [] 80 | 81 | for i in range(self.X.shape[0]): 82 | if self.V[i].max() != 0 or self.V[i].min() != 0: 83 | neighs_idx = get_iterative_indices(self.indices, i) 84 | 85 | dX = self.X[neighs_idx] - self.X[i, None] # 60% of runtime 86 | dX = np.sqrt(np.abs(dX)) * np.sign(dX) 87 | val = cosine_correlation(dX, self.V[i]) # 40% of runtime 88 | 89 | vals.extend(val) 90 | rows.extend(np.ones(len(neighs_idx)) * i) 91 | cols.extend(neighs_idx) 92 | 93 | vals = np.hstack(vals) 94 | vals[np.isnan(vals)] = 0 95 | 96 | graph, graph_neg = \ 97 | vals_to_csr(vals, rows, cols, shape=(self.X.shape[0], self.X.shape[0])) 98 | return graph, graph_neg 99 | 100 | def transition_matrix(self, gene_list): 101 | #! 
vgraph = VelocityGraph(adata, sqrt_transform=True, gene_subset=gene_list) 102 | used_genes = self.var_names.isin(gene_list) 103 | self.X = np.array(self.Ms[:, used_genes], dtype=np.float32) 104 | self.V = np.array(self.velocity[:, used_genes], dtype=np.float32) 105 | 106 | self.V = np.sqrt(np.abs(self.V)) * np.sign(self.V) 107 | self.V -= np.nanmean(self.V, axis=1)[:, None] 108 | 109 | self.indices = self.get_indices(self.distance)[0] 110 | 111 | #! vgraph.compute_cosines() 112 | graph, graph_neg = self.compute_cosines() 113 | 114 | #! transition matrix in sparse format 115 | graph = csr_matrix(graph) 116 | 117 | confidence = graph.max(1).A.flatten() 118 | ub = np.percentile(confidence, 98) 119 | self_prob = np.clip(ub - confidence, 0, 1) 120 | graph.setdiag(self_prob) 121 | 122 | T = np.expm1(graph * 10) # np.exp(graph.A * 10) - 1 123 | T -= np.expm1(-graph_neg * 10) 124 | 125 | T = T.multiply(csr_matrix(1.0 / np.abs(T).sum(1))) # original `normalize` function 126 | T.setdiag(0) 127 | T.eliminate_zeros() 128 | return T.A 129 | 130 | def get_importance_simplify(self, args): 131 | gene_list, gene = args 132 | T = self.transition_matrix(gene_list) 133 | 134 | #! 
self.get_importance by clusters or entire dataset 135 | cosine = [] 136 | cosine.append(np.abs(self.Tref - T).sum()) 137 | 138 | for type in self.ctypes: 139 | index = np.squeeze(np.argwhere(self.label_list == type)) 140 | cosine.append(np.abs(self.Tref[index, :] - T[index, :]).sum()) 141 | 142 | return [cosine, gene] 143 | 144 | def verify_neighbors(self, adata): 145 | valid = "neighbors" in adata.uns.keys() and "params" in adata.uns["neighbors"] 146 | 147 | if valid: 148 | n_neighs = (adata.obsp["distances"] > 0).sum(1) 149 | valid = n_neighs.min() * 2 > n_neighs.max() 150 | 151 | if not valid: 152 | raise ValueError("You need to run scv.pp.neighbors first.") 153 | 154 | def recover_importance( 155 | self, 156 | adata, 157 | n_jobs=None 158 | ): 159 | """Rank genes importance (influence) on estimated velocity fields 160 | by filtering out one gene by another and compare the cosine similarity 161 | between original velocity embedding (equivalent to transition matrix) 162 | with the new velocity embedding minus one genes 163 | 164 | Args: 165 | adata (Anndata): 166 | Anndata object. 167 | n_jobs (int): 168 | number of CPU cores to use. 
169 | 170 | Returns: 171 | dataframe: 172 | calculated scores for each gene (unranked version) 173 | need to run `self.rank_importance(df)` 174 | 175 | """ 176 | 177 | import time 178 | import multiprocessing 179 | self.verify_neighbors(adata) 180 | 181 | basis = adata.uns['basis'] 182 | 183 | if f"X_{basis}" not in adata.obsm_keys(): 184 | raise ValueError("You need to compute the embedding first.") 185 | 186 | if f'velocity_graph' not in adata.uns.keys(): 187 | raise ValueError('Need to run `scv.tl.velocity_graph(adata, gene_subset=None)` first') 188 | 189 | vgenes = adata.var.loc[adata.var['velocity_genes'] == True].index 190 | n_jobs, total_jobs = self.get_n_jobs(n_jobs=n_jobs) 191 | 192 | self.adata = adata 193 | self.var_names = adata.var_names 194 | self.Ms = adata.layers["Ms"] 195 | self.velocity = adata.layers["velocity"] 196 | self.distance = adata.obsp["distances"] 197 | self.label_list = adata.obs[adata.uns['label']].values 198 | self.ctypes = list(set(adata.obs[adata.uns['label']].values)) 199 | self.Tref = self.transition_matrix(list(vgenes)) 200 | 201 | ctime = time.time() 202 | print (f"(using {n_jobs}/{total_jobs} cores)") 203 | with multiprocessing.Pool(n_jobs) as pool: 204 | self.res = pool.map_async( 205 | self.get_importance_simplify, 206 | self.bufferize(vgenes)).get() 207 | print ('Time elapsed ', int(time.time() - ctime), ' seconds.') 208 | 209 | columns = ['Overall'] 210 | columns.extend([str(type) for type in self.ctypes]) 211 | df_aggre = pd.DataFrame(index=vgenes, data=0, dtype=np.float32, columns=columns) 212 | 213 | for _res in self.res: 214 | df_aggre.at[_res[1], :] = _res[0] 215 | 216 | return (df_aggre - df_aggre.min(axis=0)) / \ 217 | (df_aggre.max(axis=0) - df_aggre.min(axis=0)) 218 | 219 | def bufferize(self, velocity_genes): 220 | buffer = [] 221 | 222 | for gene in velocity_genes: 223 | used_genes = list(velocity_genes).copy() 224 | used_genes.remove(gene) 225 | 226 | buffer.append([used_genes, gene]) 227 | 228 | return buffer 
229 | 230 | def rank_importance(self, df): 231 | ranking = pd.DataFrame(index=list(range(len(df))), data=np.nan, 232 | dtype=np.float32, columns=df.columns) 233 | 234 | for col in ranking.columns: 235 | ranking.at[:, col] = df.sort_values(by=[col], ascending=False).index 236 | 237 | return ranking 238 | 239 | def influence(adata, n_jobs): 240 | import os 241 | from .gene_influence import Influence 242 | val = Influence() 243 | 244 | gene_score = val.recover_importance(adata, n_jobs) 245 | ranking = val.rank_importance(gene_score) 246 | 247 | gene_score = gene_score.sort_values(by=['Overall'], ascending=False) 248 | adata.uns['gene_rank'] = ranking 249 | adata.uns['gene_score'] = gene_score 250 | 251 | adata.write(os.path.join(adata.uns['temp'], 'temp.h5ad'), compression='gzip') 252 | 253 | return ranking, gene_score -------------------------------------------------------------------------------- /unitvelo/individual_gene.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import os 3 | import pandas as pd 4 | import numpy as np 5 | import seaborn as sns 6 | import matplotlib.pyplot as plt 7 | from IPython import display 8 | 9 | def atoi(text): 10 | return int(text) if text.isdigit() else text 11 | 12 | def natural_keys(text): 13 | import re 14 | return [atoi(c) for c in re.split(r'(\d+)', text)] 15 | 16 | class Validation(): 17 | def __init__(self, adata, time_metric='latent_time') -> None: 18 | self.adata = adata 19 | self.time_metric = time_metric 20 | 21 | if 'latent_time' not in adata.obs.columns: 22 | import scvelo as scv 23 | scv.tl.latent_time(adata, min_likelihood=None) 24 | 25 | if len(set(adata.obs[adata.uns['label']])) > 20: 26 | self.palette = 'viridis' 27 | else: 28 | self.palette = 'tab20' 29 | 30 | def init_data(self, adata): 31 | DIR = adata.uns['temp'] 32 | f = os.listdir(DIR) 33 | f.sort(key=natural_keys) 34 | 35 | self.mu = pd.read_csv(f'{DIR}/Mu.csv', index_col=0).add_suffix('_u') 36 | self.ms = 
pd.read_csv(f'{DIR}/Ms.csv', index_col=0).add_suffix('_s') 37 | 38 | self.fs = pd.DataFrame(index=self.ms.index) 39 | self.fu = pd.DataFrame(index=self.ms.index) 40 | self.var = pd.DataFrame(index=[col[:-2] for col in self.ms.columns]) 41 | 42 | for i in f: 43 | if i.startswith('fits'): 44 | self.fs = self.fs.join(pd.read_csv(f'{DIR}/{i}', index_col=0).add_suffix(f'_fits')) 45 | 46 | if i.startswith('fitu'): 47 | self.fu = self.fu.join(pd.read_csv(f'{DIR}/{i}', index_col=0).add_suffix(f'_fitu')) 48 | 49 | if i.startswith('fitvar'): 50 | df = pd.read_csv(f'{DIR}/{i}', index_col=0) 51 | df.index = df.index.map(str) 52 | self.var = self.var.join(df) 53 | 54 | self.lt = adata.obs[self.time_metric].values 55 | self.label = adata.obs[adata.uns['label']].values 56 | 57 | def concat_data(self): 58 | self.msmu = self.ms.join(self.mu) 59 | self.msmu['labels'] = self.label 60 | 61 | self.msfs = self.ms.join(self.fs) 62 | self.msfs['labels'] = self.label 63 | 64 | self.mufu = self.mu.join(self.fu) 65 | self.mufu['labels'] = self.label 66 | 67 | self.fsfu = self.fs.join(self.fu) 68 | self.fsfu['labels'] = self.label 69 | 70 | def inspect_genes(self, gene_name, adata): 71 | fig, axes = plt.subplots(2, 2, figsize=(14, 12)) 72 | 73 | self.plot_mf(adata, self.msfs[gene_name + '_s'], self.msfs[gene_name + f'_fits'], axes[0][0]) 74 | axes[0][0].plot( 75 | [np.min(self.fs[gene_name + f'_fits']), np.max(self.fs[gene_name + f'_fits'])], 76 | [np.min(self.fs[gene_name + f'_fits']), np.max(self.fs[gene_name + f'_fits'])], 77 | ls='--', c='red') 78 | axes[0][0].set_title('x: ms -- y: fits') 79 | 80 | self.plot_mf(adata, self.mufu[gene_name + '_u'], self.mufu[gene_name + f'_fitu'], axes[0][1]) 81 | axes[0][1].plot( 82 | [np.min(self.fu[gene_name + f'_fitu']), np.max(self.fu[gene_name + f'_fitu'])], 83 | [np.min(self.fu[gene_name + f'_fitu']), np.max(self.fu[gene_name + f'_fitu'])], 84 | ls='--', c='red') 85 | axes[0][1].set_title('x: mu -- y: fitu') 86 | 87 | self.plot_mf(adata, 
self.msmu[gene_name + '_s'], self.msmu[gene_name + '_u'], axes[1][0]) 88 | axes[1][0].set_title('x: ms -- y: mu') 89 | 90 | self.plot_mf(adata, self.fsfu[gene_name + f'_fits'], self.fsfu[gene_name + f'_fitu'], axes[1][1]) 91 | axes[1][1].set_title('x: fits -- y: fitu') 92 | 93 | plt.show() 94 | 95 | def spliced_time(self, gene_name): 96 | fig, axes = plt.subplots(1, 2, figsize=(12, 4)) 97 | self.sns_plot(self.ms, gene_name + '_s', 'time', 'ms', axes[0]) 98 | self.sns_plot(self.fs, gene_name + f'_fits', 'time', 'fits', axes[1]) 99 | 100 | def unspliced_time(self, gene_name): 101 | fig, axes = plt.subplots(1, 2, figsize=(12, 4)) 102 | self.sns_plot(self.mu, gene_name + '_u', 'time', 'mu', axes[0]) 103 | self.sns_plot(self.fu, gene_name + f'_fitu', 'time', 'fitu', axes[1]) 104 | plt.show() 105 | 106 | def sns_plot(self, data, expression, x, y, loc): 107 | df = pd.DataFrame(data=[self.lt, data[expression].values], index=[x, y]).T 108 | df['labels'] = self.label 109 | sns.scatterplot(x=x, y=y, data=df, hue='labels', sizes=1, palette=self.palette, ax=loc) 110 | 111 | def putative_trans_time(self): 112 | pass 113 | 114 | def vars_trends(self, gene_name, adata): 115 | par_names = adata.uns['par_names'] 116 | para = pd.DataFrame(index=par_names, columns=['Values']) 117 | para['Values'] = self.var.loc[gene_name].values 118 | self.para = para.iloc[:, -1].T.astype(np.float32) 119 | 120 | def plot_mf(self, adata, s, u, ax=None, hue=None): 121 | data = [np.squeeze(s), np.squeeze(u)] 122 | df = pd.DataFrame(data=data, index=['spliced', 'unspliced']).T 123 | 124 | if adata.shape[0] == np.squeeze(s).shape[0]: 125 | df['labels'] = adata.obs[adata.uns['label']].values 126 | hue = 'labels' 127 | 128 | sns.scatterplot(x='spliced', y='unspliced', data=df, sizes=1, 129 | palette=self.palette, ax=ax, hue=hue) 130 | 131 | def func(self, validate=None, t_cell=None): 132 | s, u = validate.get_s_u(self.para.values, t_cell) 133 | return s, u 134 | 135 | def plot_range(self, gene_name, adata, 
ctype=None): 136 | #! solving the scaling of unspliced problem afterwards 137 | from .optimize_utils import exp_args 138 | from .optimize_utils import Model_Utils 139 | validate = Model_Utils(config=adata.uns['config']) 140 | self.vars_trends(gene_name, adata) 141 | 142 | columns = exp_args(adata) 143 | for col in columns: 144 | self.para[col] = np.log(self.para[col]) 145 | 146 | boundary = ( 147 | np.reshape(self.para['t'] - 3 * (1 / np.sqrt(2 * np.exp(self.para['a']))), (1, 1)), 148 | np.reshape(self.para['t'] + 3 * (1 / np.sqrt(2 * np.exp(self.para['a']))), (1, 1)) 149 | ) 150 | 151 | adata = adata[:, gene_name] 152 | if ctype != None: 153 | adata = adata[adata.obs[adata.obs[adata.uns['label']] == ctype].index, :] 154 | 155 | spre, upre = self.func(validate, validate.init_time(boundary, (3000, 1))) 156 | sone, uone = self.func(validate, validate.init_time((0, 1), (3000, 1))) 157 | sfit, ufit = self.func(validate, adata.obs['latent_time'].values) 158 | 159 | display.clear_output(wait=True) 160 | fig, axes = plt.subplots(1, 2, figsize=(12, 4)) 161 | self.plot_mf(adata, adata.layers['Ms'], adata.layers['Mu'] / adata.var['scaling'].values, axes[0]) 162 | self.plot_mf(adata, spre, upre, axes[0]) 163 | self.plot_mf(adata, sone, uone, axes[0]) 164 | self.plot_mf(adata, sfit, ufit, axes[1]) 165 | 166 | if 'li_coef' in adata.var.columns: 167 | x = np.linspace(np.min(adata.layers['Ms']), np.max(adata.layers['Ms']), 1000) 168 | y = x * adata.var['li_coef'].values 169 | axes[0].plot(x, y, ls='--', c='red') 170 | 171 | plt.show() 172 | 173 | def plot_scv_fit(self, gene_name, adata): 174 | DIR = adata.uns['temp'] 175 | self.mu = pd.read_csv(f'{DIR}/Mu.csv', index_col=0).add_suffix('_u') 176 | self.ms = pd.read_csv(f'{DIR}/Ms.csv', index_col=0).add_suffix('_s') 177 | 178 | self.fu = pd.read_csv(f'{DIR}/scvu.csv', index_col=0).add_suffix(f'_fitu') 179 | self.fs = pd.read_csv(f'{DIR}/scvs.csv', index_col=0).add_suffix(f'_fits') 180 | 181 | self.lt = 
adata.obs['latent_time'].values 182 | self.label = adata.obs[adata.uns['label']].values 183 | 184 | self.concat_data() 185 | self.inspect_genes(gene_name, adata) 186 | self.spliced_time(gene_name) 187 | self.unspliced_time(gene_name) 188 | 189 | def exam_genes(adata, gene_name=None, time_metric='latent_time'): 190 | display.clear_output(wait=True) 191 | from .individual_gene import Validation 192 | examine = Validation(adata, time_metric=time_metric) 193 | examine.init_data(adata) 194 | examine.concat_data() 195 | examine.inspect_genes(gene_name, adata) 196 | examine.spliced_time(gene_name) 197 | examine.unspliced_time(gene_name) 198 | examine.vars_trends(gene_name, adata) 199 | 200 | def exam_scv(data_path, gene_name, basis, label): 201 | try: 202 | import scvelo as scv 203 | except ModuleNotFoundError: 204 | print ('Install scVelo via `pip install scvelo`') 205 | 206 | adata = scv.read(data_path) 207 | adata.uns['datapath'] = data_path 208 | adata.uns['label'] = label 209 | adata.uns['basis'] = basis 210 | 211 | scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=2000) 212 | scv.pp.moments(adata, n_pcs=30, n_neighbors=30) 213 | scv.tl.recover_dynamics(adata, n_jobs=20) 214 | scv.tl.velocity(adata, mode='dynamical') 215 | scv.tl.velocity_graph(adata) 216 | scv.tl.latent_time(adata) 217 | 218 | if basis != None: 219 | scv.pl.velocity_embedding_stream( 220 | adata, basis=basis, color=label, 221 | legend_loc='far right', dpi=200, 222 | title='scVelo dynamical model' 223 | ) 224 | 225 | scv.pl.scatter(adata, color='latent_time', color_map='gnuplot', size=50) 226 | 227 | if gene_name != None: 228 | examine = Validation(adata, time_metric='latent_time') 229 | examine.plot_scv_fit(gene_name, adata) 230 | 231 | return adata -------------------------------------------------------------------------------- /unitvelo/main.py: -------------------------------------------------------------------------------- 1 | from .velocity import Velocity 2 | import scvelo 
def run_model(
    adata,
    label,
    config_file=None,
    normalize=True,
):
    """Preparation and pre-processing function of RNA velocity calculation.

    Args:
        adata (str or AnnData):
            takes relative or absolute path of an Anndata object as input,
            or directly an adata object as well.
        label (str):
            column name in adata.obs indicating cell clusters.
        config_file (object):
            model configuration object, default: None.
        normalize (bool):
            whether to run count normalization during pre-processing,
            default: True.

    Returns:
        adata:
            modified Anndata object — the lowest-loss replicate, or the
            last run in linear mode.
    """

    from .utils import init_config_summary, init_adata_and_logs
    config, _ = init_config_summary(config=config_file)
    adata, data_path = init_adata_and_logs(adata, config, normalize=normalize)

    scv.settings.presenter_view = True
    scv.settings.verbosity = 0
    scv.settings.file_format_figs = 'png'

    # Track the best (lowest-loss) replicate across repeated runs.
    replicates, pre = None, 1e15
    adata_second = adata.copy()

    for rep in range(config.NUM_REP):
        if rep >= 1:
            # Restart from a fresh copy but seed with the previous run's time.
            adata = adata_second.copy()
            adata.obs['latent_time_gm'] = pre_time_gm

        adata.uns['datapath'] = data_path
        adata.uns['label'] = label
        adata.uns['base_function'] = 'Gaussian'

        if config.BASIS is None:
            # Pick the most refined embedding available, umap > tsne > pca.
            basis_keys = ["pca", "tsne", "umap"]
            basis = [key for key in basis_keys if f"X_{key}" in adata.obsm.keys()][-1]
        elif f"X_{config.BASIS}" in adata.obsm.keys():
            basis = config.BASIS
        else:
            raise ValueError('Invalid embedding parameter config.BASIS')
        adata.uns['basis'] = basis

        if 'scNT' in data_path:
            # Dataset-specific velocity-gene selection via splicing-time FDR.
            import pandas as pd
            gene_ids = pd.read_csv('../data/scNT/brie_neuron_splicing_time.tsv', delimiter='\t', index_col='GeneID')
            config.VGENES = list(gene_ids.loc[gene_ids['time_FDR'] < 0.01].index)

        model = Velocity(adata, config=config)
        model.get_velo_genes()

        adata = model.fit_velo_genes(basis=basis, rep=rep)
        pre_time_gm = adata.obs['latent_time'].values

        if config.GENERAL != 'Linear':
            replicates = adata if adata.uns['loss'] < pre else replicates
            pre = adata.uns['loss'] if adata.uns['loss'] < pre else pre

    # In linear mode no replicate is tracked ('loss' comparison is skipped),
    # so fall back to the last run. Previously `replicates.write(...)` was
    # called unconditionally and crashed with AttributeError on None there.
    result = replicates if config.GENERAL != 'Linear' else adata
    result.write(os.path.join(adata.uns['temp'], f'temp_{config.FIT_OPTION}.h5ad'))

    if 'examine_genes' in adata.uns.keys():
        from .individual_gene import exam_genes
        exam_genes(adata, adata.uns['examine_genes'])

    return result
if config.GENERAL != 'Linear': 68 | replicates = adata if adata.uns['loss'] < pre else replicates 69 | pre = adata.uns['loss'] if adata.uns['loss'] < pre else pre 70 | 71 | #? change adata to replicates? 72 | replicates.write(os.path.join(adata.uns['temp'], f'temp_{config.FIT_OPTION}.h5ad')) 73 | 74 | if 'examine_genes' in adata.uns.keys(): 75 | from .individual_gene import exam_genes 76 | exam_genes(adata, adata.uns['examine_genes']) 77 | 78 | return replicates if config.GENERAL != 'Linear' else adata -------------------------------------------------------------------------------- /unitvelo/model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from scvelo.tools.utils import make_unique_list 4 | from tqdm import tqdm 5 | import scvelo as scv 6 | from .optimize_utils import Model_Utils, exp_args 7 | from .utils import save_vars, new_adata_col, min_max 8 | from .pl import plot_loss 9 | 10 | exp = tf.math.exp 11 | log = tf.math.log 12 | sum = tf.math.reduce_sum 13 | mean = tf.math.reduce_mean 14 | sqrt = tf.math.sqrt 15 | abs = tf.math.abs 16 | square = tf.math.square 17 | pow = tf.math.pow 18 | std = tf.math.reduce_std 19 | var = tf.math.reduce_variance 20 | 21 | class Recover_Paras(Model_Utils): 22 | def __init__( 23 | self, 24 | adata, 25 | Ms, 26 | Mu, 27 | var_names, 28 | idx=None, 29 | rep=1, 30 | config=None 31 | ): 32 | super().__init__( 33 | adata=adata, 34 | var_names=var_names, 35 | Ms=Ms, 36 | Mu=Mu, 37 | config = config 38 | ) 39 | 40 | self.idx = idx 41 | self.rep = rep 42 | self.scaling = adata.var['scaling'].values 43 | self.flag = True 44 | 45 | self.init_pars() 46 | self.init_vars() 47 | self.init_weights() 48 | self.adata.uns['par_names'] = self.default_pars_names 49 | self.t_cell = self.compute_cell_time(args=None) 50 | self.pi = tf.constant(np.pi, dtype=tf.float32) 51 | 52 | def compute_cell_time(self, args=None, iter=None, show=False): 53 | if args != 
None: 54 | boundary = (args[4] - 3 * (1 / sqrt(2 * exp(args[3]))), 55 | args[4] + 3 * (1 / sqrt(2 * exp(args[3])))) 56 | t_range = boundary if self.config.RESCALE_TIME else (0, 1) 57 | x = self.init_time(t_range, (3000, self.adata.n_vars)) 58 | 59 | s_predict, u_predict = self.get_s_u(args, x) 60 | s_predict = tf.expand_dims(s_predict, axis=0) # 1 3000 d 61 | u_predict = tf.expand_dims(u_predict, axis=0) 62 | Mu = tf.expand_dims(self.Mu, axis=1) # n 1 d 63 | Ms = tf.expand_dims(self.Ms, axis=1) 64 | 65 | t_cell = self.match_time(Ms, Mu, s_predict, u_predict, x.numpy(), iter) 66 | 67 | if self.config.AGGREGATE_T: 68 | t_cell = tf.reshape(t_cell, (-1, 1)) 69 | t_cell = tf.broadcast_to(t_cell, self.adata.shape) 70 | 71 | # plot_phase_portrait(self.adata, args, Ms, Mu, s_predict, u_predict) 72 | 73 | else: 74 | boundary = (self.t - 3 * (1 / sqrt(2 * exp(self.log_a))), 75 | self.t + 3 * (1 / sqrt(2 * exp(self.log_a)))) 76 | 77 | t_cell = self.init_time((0, 1), self.adata.shape) 78 | 79 | if self.rep == 1: 80 | if self.config.NUM_REP_TIME == 're_init': 81 | t_cell = 1 - t_cell 82 | if self.config.NUM_REP_TIME == 're_pre': 83 | t_cell = 1 - self.adata.obs['latent_time_gm'].values 84 | t_cell = tf.broadcast_to(t_cell.reshape(-1, 1), self.adata.shape) 85 | 86 | if self.rep > 1: 87 | tf.random.set_seed(np.ceil((self.rep - 1) / 2)) 88 | shuffle = tf.random.shuffle(t_cell) 89 | t_cell = shuffle if self.rep % 2 == 1 else 1 - shuffle 90 | 91 | if self.config.IROOT == 'gcount': 92 | print ('---> Use Gene Counts as initial.') 93 | self.adata.obs['gcount'] = np.sum(self.adata.X.todense() > 0, axis=1) 94 | g_time = 1 - min_max(self.adata.obs.groupby(self.adata.uns['label'])['gcount'].mean()) 95 | 96 | for id in list(g_time.index): 97 | self.adata.obs.loc[self.adata.obs[self.adata.uns['label']] == id, 'gcount'] = g_time[id] 98 | 99 | scv.pl.scatter(self.adata, color='gcount', cmap='gnuplot', dpi=100) 100 | t_cell = tf.cast( 101 | tf.broadcast_to( 102 | 
self.adata.obs['gcount'].values.reshape(-1, 1), 103 | self.adata.shape), 104 | tf.float32) 105 | 106 | elif type(self.config.IROOT) == list: 107 | t_cell, perc = [], [] 108 | for prior in self.config.IROOT: 109 | expr = np.array(self.adata[:, prior[0]].layers['Ms']) 110 | 111 | perc.append(np.max(expr) * 0.75) # modify 0.75 for parameter tuning 112 | t_cell.append(min_max(expr) if prior[1] == 'increase' else 1 - min_max(expr)) 113 | 114 | perc_total = np.sum(perc) 115 | perc = [perc[i] / perc_total for i in range(len(perc))] 116 | print (f'assigned weights of IROOT {list(np.around(np.array(perc), 2))}') 117 | t_cell = [perc[i] * t_cell[i] for i in range(len(perc))] 118 | t_cell = tf.cast(tf.broadcast_to(np.sum(t_cell, axis=0).reshape(-1, 1), self.adata.shape), tf.float32) 119 | 120 | elif self.config.IROOT in self.adata.obs[self.adata.uns['label']].values: 121 | print ('Use diffusion pseudotime as initial.') 122 | import scanpy as sc 123 | sc.tl.diffmap(self.adata) 124 | self.adata.uns['iroot'] = \ 125 | np.flatnonzero( 126 | self.adata.obs[self.adata.uns['label']] == self.config.IROOT 127 | )[0] 128 | sc.tl.dpt(self.adata) 129 | 130 | if show: 131 | scv.pl.scatter(self.adata, color='dpt_pseudotime', cmap='gnuplot', dpi=100) 132 | 133 | t_cell = tf.cast( 134 | tf.broadcast_to( 135 | self.adata.obs['dpt_pseudotime'].values.reshape(-1, 1), 136 | self.adata.shape), 137 | tf.float32) 138 | 139 | else: 140 | pass 141 | 142 | return t_cell 143 | 144 | def get_loglikelihood(self, distx=None, varx=None): 145 | n = np.clip(len(distx) - len(self.u) * 0.01, 2, None) 146 | loglik = -1 / 2 / n * np.sum(distx) / varx 147 | loglik -= 1 / 2 * np.log(2 * np.pi * varx) 148 | 149 | def amplify_gene(self, t_cell, iter): 150 | from sklearn import linear_model 151 | from sklearn.metrics import r2_score 152 | print (f'\nExaming genes which are not initially considered as velocity genes') 153 | 154 | r2 = np.repeat(-1., self.Ms.shape[1]) 155 | reg = linear_model.LinearRegression() 156 | 
157 | for col in tqdm(range(self.Ms.shape[1])): 158 | if not self.idx[col]: 159 | y = np.reshape(self.Ms[:, col], (-1, 1)) 160 | x = np.reshape(t_cell, (-1, 1)) 161 | 162 | reg.fit(x, y) 163 | y_pred = reg.predict(x) 164 | r2[col] = r2_score(y, y_pred) 165 | 166 | self.agenes = r2 >= self.config.AGENES_R2 167 | self.adata.var['amplify_r2'] = r2 168 | self.adata.var['amplify_genes'] = self.agenes 169 | self.flag = False 170 | 171 | _ = self.get_log(sum(self.s_r2, axis=0) + sum(self.u_r2, axis=0), True, iter=iter) 172 | self.used_agenes = np.array(self.adata.var['amplify_genes'].values) 173 | self.total_genes = self.idx | self.used_agenes 174 | 175 | print (f'# of amplified genes {self.agenes.sum()}, # of used {self.used_agenes.sum()}') 176 | print (f'# of (velocity + used) genes {self.total_genes.sum()}') 177 | 178 | self.infi_genes = ~np.logical_xor(~self.agenes, self.used_agenes) 179 | self.adata.var['amplify_infi'] = self.infi_genes 180 | print (f'# of infinite (or nan) genes {self.infi_genes.sum()}') 181 | 182 | def compute_loss(self, args, t_cell, Ms, Mu, iter, progress_bar): 183 | self.s_func, self.u_func = self.get_s_u(args, t_cell) 184 | udiff, sdiff = Mu - self.u_func, Ms - self.s_func 185 | 186 | if (self.config.AGENES_R2 < 1) & (iter > self.agenes_thres): 187 | self.u_r2 = square(udiff) 188 | self.s_r2 = square(sdiff) 189 | 190 | if self.flag: 191 | self.amplify_gene(t_cell.numpy()[:, 0], iter=iter) 192 | 193 | if iter > int(0.9 * self.config.MAX_ITER) & self.config.REG_LOSS: 194 | self.s_r2 = self.s_r2 + \ 195 | std(Ms, axis=0) * self.config.REG_TIMES * \ 196 | exp(-square(args[4] - 0.5) / self.config.REG_SCALE) 197 | 198 | #compute variance, equivalent to np.var(np.sign(sdiff) * np.sqrt(distx)) 199 | self.vars = mean(self.s_r2, axis=0) \ 200 | - square(mean(tf.math.sign(sdiff) * sqrt(self.s_r2), axis=0)) 201 | self.varu = mean(self.u_r2 * square(self.scaling), axis=0) \ 202 | - square(mean(tf.math.sign(udiff) * sqrt(self.u_r2) * self.scaling, axis=0)) 
203 | 204 | #! edge case of mRNAs levels to be the same across all cells 205 | self.vars += tf.cast(self.vars == 0, tf.float32) 206 | self.varu += tf.cast(self.varu == 0, tf.float32) 207 | 208 | self.u_log_likeli = \ 209 | - (Mu.shape[0] / 2) * log(2 * self.pi * self.varu) \ 210 | - sum(self.u_r2 * square(self.scaling), axis=0) / (2 * self.varu) 211 | self.s_log_likeli = \ 212 | - (Ms.shape[0] / 2) * log(2 * self.pi * self.vars) \ 213 | - sum(self.s_r2, axis=0) / (2 * self.vars) 214 | 215 | error_1 = np.sum(sum(self.u_r2, axis=0).numpy()[self.total_genes]) / np.sum(self.total_genes) 216 | error_2 = np.sum(sum(self.s_r2, axis=0).numpy()[self.total_genes]) / np.sum(self.total_genes) 217 | self.se.append(error_1 + error_2) 218 | progress_bar.set_description(f'Loss (Total): {self.se[-1]:.3f}, (Spliced): {error_2:.3f}, (Unspliced): {error_1:.3f}') 219 | 220 | return self.get_loss(iter, 221 | sum(self.s_r2, axis=0), 222 | sum(self.u_r2, axis=0)) 223 | 224 | else: 225 | self.u_r2 = square(udiff) 226 | self.s_r2 = square(sdiff) 227 | 228 | if (self.config.FIT_OPTION == '1') & \ 229 | (iter > int(0.9 * self.config.MAX_ITER)) & self.config.REG_LOSS: 230 | self.s_r2 = self.s_r2 + \ 231 | std(Ms, axis=0) * self.config.REG_TIMES * \ 232 | exp(-square(args[4] - 0.5) / self.config.REG_SCALE) 233 | 234 | #! 
convert for self.varu to account for scaling in pre-processing 235 | self.vars = mean(self.s_r2, axis=0) \ 236 | - square(mean(tf.math.sign(sdiff) * sqrt(self.s_r2), axis=0)) 237 | self.varu = mean(self.u_r2 * square(self.scaling), axis=0) \ 238 | - square(mean(tf.math.sign(udiff) * sqrt(self.u_r2) * self.scaling, axis=0)) 239 | 240 | self.u_log_likeli = \ 241 | - (Mu.shape[0] / 2) * log(2 * self.pi * self.varu) \ 242 | - sum(self.u_r2 * square(self.scaling), axis=0) / (2 * self.varu) 243 | self.s_log_likeli = \ 244 | - (Ms.shape[0] / 2) * log(2 * self.pi * self.vars) \ 245 | - sum(self.s_r2, axis=0) / (2 * self.vars) 246 | 247 | error_1 = np.sum(sum(self.u_r2, axis=0).numpy()[self.idx]) / np.sum(self.idx) 248 | error_2 = np.sum(sum(self.s_r2, axis=0).numpy()[self.idx]) / np.sum(self.idx) 249 | self.se.append(error_1 + error_2) 250 | progress_bar.set_description(f'Loss (Total): {self.se[-1]:.3f}, (Spliced): {error_2:.3f}, (Unspliced): {error_1:.3f}') 251 | 252 | self.vgene_loss = self.se[-1] 253 | return self.get_loss(iter, 254 | sum(self.s_r2, axis=0), 255 | sum(self.u_r2, axis=0)) 256 | 257 | def fit_likelihood(self): 258 | Ms, Mu, t_cell = self.Ms, self.Mu, self.t_cell 259 | log_gamma, log_beta, offset = self.log_gamma, self.log_beta, self.offset 260 | intercept = self.intercept 261 | log_a, t, log_h = self.log_a, self.t, self.log_h 262 | 263 | from packaging import version 264 | if version.parse(tf.__version__) >= version.parse('2.11.0'): 265 | optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=self.init_lr, amsgrad=True) 266 | else: 267 | optimizer = tf.keras.optimizers.Adam(learning_rate=self.init_lr, amsgrad=True) 268 | 269 | pre = tf.repeat(1e6, Ms.shape[1]) # (2000, ) 270 | self.se, self.m_args, self.m_ur2, self.m_sr2 = [], None, None, None 271 | self.m_ullf, self.m_sllf = None, None 272 | 273 | progress_bar = tqdm(range(self.config.MAX_ITER)) 274 | for iter in progress_bar: 275 | with tf.GradientTape() as tape: 276 | args = [ 277 | log_gamma, 278 | 
log_beta, 279 | offset, 280 | log_a, 281 | t, 282 | log_h, 283 | intercept 284 | ] 285 | obj = self.compute_loss(args, t_cell, Ms, Mu, iter, progress_bar) 286 | 287 | stop_cond = self.get_stop_cond(iter, pre, obj) 288 | 289 | if iter > self.agenes_thres + 1: 290 | self.m_args = self.get_optimal_res(args, self.m_args) 291 | self.m_ur2 = self.get_optimal_res(self.u_r2, self.m_ur2) 292 | self.m_sr2 = self.get_optimal_res(self.s_r2, self.m_sr2) 293 | self.m_ullf = self.get_optimal_res(self.u_log_likeli, self.m_ullf) 294 | self.m_sllf = self.get_optimal_res(self.s_log_likeli, self.m_sllf) 295 | 296 | if (iter > self.agenes_thres) & \ 297 | (iter == self.config.MAX_ITER - 1 or \ 298 | tf.math.reduce_all(stop_cond) == True or \ 299 | (min(self.se[self.agenes_thres + 1:]) * 1.1 < self.se[-1] 300 | if (iter > self.agenes_thres + 1) else False)): 301 | 302 | if (iter > int(0.9 * self.config.MAX_ITER)) & self.config.REG_LOSS & \ 303 | (min(self.se[self.agenes_thres:]) * 1.1 >= self.se[-1]): 304 | self.m_args = args 305 | self.m_ur2 = self.u_r2 306 | self.m_sr2 = self.s_r2 307 | self.m_ullf = self.u_log_likeli 308 | self.m_sllf = self.s_log_likeli 309 | 310 | t_cell = self.compute_cell_time(args=self.m_args, iter=iter) 311 | _ = self.get_fit_s(self.m_args, t_cell) 312 | s_derivative = self.get_s_deri(self.m_args, t_cell) 313 | # s_derivative = exp(args[1]) * Mu - exp(args[0]) * Ms 314 | 315 | self.post_utils(iter, self.m_args) 316 | break 317 | 318 | args_to_optimize = self.get_opt_args(iter, args) 319 | gradients = tape.gradient(target=obj, sources=args_to_optimize) 320 | 321 | # convert gradients of variables with unused genes to 0 322 | # keep other gradients by multiplying 1 323 | if (self.config.AGENES_R2 < 1) & (iter > self.agenes_thres): 324 | convert = tf.cast(self.total_genes, tf.float32) 325 | processed_grads = [g * convert for g in gradients] 326 | else: 327 | convert = tf.cast(self.idx, tf.float32) 328 | processed_grads = [g * convert for g in gradients] 329 | 330 
| optimizer.apply_gradients(zip(processed_grads, args_to_optimize)) 331 | pre = obj 332 | 333 | if iter > 0 and int(iter % 800) == 0: 334 | t_cell = self.compute_cell_time(args=args, iter=iter) 335 | 336 | self.adata.layers['fit_t'] = t_cell.numpy() if self.config.AGGREGATE_T else t_cell 337 | self.adata.var['velocity_genes'] = self.total_genes if not self.flag else self.idx 338 | self.adata.layers['fit_t'][:, ~self.adata.var['velocity_genes'].values] = np.nan 339 | 340 | return self.get_interim_t(t_cell, self.adata.var['velocity_genes'].values), s_derivative.numpy(), self.adata 341 | 342 | def get_optimal_res(self, current, opt): 343 | return current if min(self.se[self.agenes_thres + 1:]) == self.se[-1] else opt 344 | 345 | def post_utils(self, iter, args): 346 | # Reshape un/spliced variance to (ngenes, ) and save 347 | self.adata.var['fit_vars'] = np.squeeze(self.vars) 348 | self.adata.var['fit_varu'] = np.squeeze(self.varu) 349 | 350 | # Save predicted parameters of RBF kernel to adata 351 | self.save_pars([item.numpy() for item in args]) 352 | self.adata.var['fit_beta'] /= self.scaling 353 | self.adata.var['fit_intercept'] *= self.scaling 354 | 355 | # Plotting function for examining model loss 356 | plot_loss(iter, self.se, self.agenes_thres) 357 | 358 | # Save observations, predictinos and variables locally 359 | save_vars(self.adata, args, 360 | self.s_func.numpy(), self.u_func.numpy(), 361 | self.K, self.scaling) 362 | 363 | #! 
Model loss, log likelihood and BIC based on unspliced counts 364 | gene_loss = sum(self.m_ur2, axis=0) / self.nobs \ 365 | if self.config.FILTER_CELLS \ 366 | else sum(self.m_ur2, axis=0) / self.Ms.shape[0] 367 | 368 | list_name = ['fit_loss', 'fit_llf'] 369 | list_data = [gene_loss.numpy(), self.m_ullf.numpy()] 370 | new_adata_col(self.adata, list_name, list_data) 371 | 372 | # Mimimum loss during optimization, might not be the actual minimum 373 | r2_spliced = 1 - sum(self.m_sr2, axis=0) / var(self.Ms, axis=0) \ 374 | / (self.adata.shape[0] - 1) 375 | r2_unspliced = 1 - sum(self.m_ur2, axis=0) / var(self.Mu, axis=0) \ 376 | / (self.adata.shape[0] - 1) 377 | new_adata_col(self.adata, ['fit_sr2', 'fit_ur2'], [r2_spliced.numpy(), r2_unspliced.numpy()]) 378 | 379 | tloss = min(self.se[self.agenes_thres + 1:]) 380 | self.adata.uns['loss'] = self.vgene_loss 381 | print (f'Total loss {tloss:.3f}, vgene loss {self.vgene_loss:.3f}') 382 | 383 | def save_pars(self, paras): 384 | columns = exp_args(self.adata, 1) 385 | for i, name in enumerate(self.default_pars_names): 386 | self.adata.var[f"fit_{name}"] = np.transpose(np.squeeze(paras[i])) 387 | 388 | if name in columns: 389 | self.adata.var[f"fit_{name}"] = np.exp(self.adata.var[f"fit_{name}"]) 390 | 391 | def lagrange( 392 | adata, 393 | idx=None, 394 | Ms=None, 395 | Mu=None, 396 | var_names="velocity_genes", 397 | rep=1, 398 | config=None 399 | ): 400 | if len(set(adata.var_names)) != len(adata.var_names): 401 | adata.var_names_make_unique() 402 | 403 | var_names = adata.var_names[idx] 404 | var_names = make_unique_list(var_names, allow_array=True) 405 | 406 | model = Recover_Paras( 407 | adata, 408 | Ms, 409 | Mu, 410 | var_names, 411 | idx=idx, 412 | rep=rep, 413 | config=config 414 | ) 415 | 416 | latent_time_gm, s_derivative, adata = model.fit_likelihood() 417 | 418 | if 'latent_time' in adata.obs.columns: 419 | del adata.obs['latent_time'] 420 | adata.obs['latent_time_gm'] = min_max(latent_time_gm[:, 0]) 421 | 422 
def inv(obj):
    """Numerically-safe elementwise reciprocal (adds 1e-6 to avoid div-by-zero)."""
    return tf.math.reciprocal(obj + 1e-6)

def col_minmax(matrix, gene_id=None):
    """Column-wise min-max scaling of `matrix` into [0, 1].

    Args:
        matrix (ndarray): values to rescale; columns are scaled independently.
        gene_id: optional gene label. When given, a constant (single-gene)
            column is reported and returned unscaled — the division below
            would otherwise produce NaN/inf.

    Returns:
        ndarray: rescaled matrix, or the input itself for a constant column.
    """
    if gene_id is not None:
        if (np.max(matrix, axis=0) == np.min(matrix, axis=0)):
            print (gene_id)
            return matrix

    return (matrix - np.min(matrix, axis=0)) \
        / (np.max(matrix, axis=0) - np.min(matrix, axis=0))

def exp_args(adata, K=1):
    """Names of fitted parameters stored in log-space (to be exponentiated).

    Args:
        adata (AnnData): object whose `.uns['base_function']` selects the model.
        K (int): number of kernel components for the non-Gaussian base, default 1.

    Returns:
        list[str]: parameter-column names kept in log space.
    """
    if adata.uns['base_function'] == 'Gaussian':
        columns = ['a', 'h', 'gamma', 'beta']
    else:
        columns = ['gamma', 'beta']
        columns.extend([f'a{k}' for k in range(K)])

    return columns
1 / scaling of Gaussian 65 | self.offset = tf.Variable(ones * 0, name='offset') 66 | 67 | self.log_h = tf.Variable(ones * log(tf.math.reduce_max(self.Ms, axis=0)), name='log_h') 68 | 69 | init_gamma = self.adata.var['velocity_gamma'].values 70 | self.log_gamma = tf.Variable( 71 | log(tf.reshape(init_gamma, (1, self.adata.n_vars))), 72 | name='log_gamma') 73 | 74 | for id in np.where(init_gamma <= 0)[0]: 75 | logging.info(f'name: {self.adata.var.index[id]}, gamma: {init_gamma[id]}') 76 | 77 | self.log_gamma = tf.Variable( 78 | tf.where(tf.math.is_finite(self.log_gamma), self.log_gamma, 0), 79 | name='log_gamma') 80 | 81 | if self.config.VGENES == 'offset': 82 | init_inter = self.adata.var['velocity_inter'].values 83 | self.intercept = \ 84 | tf.Variable( 85 | tf.reshape(init_inter, (1, self.adata.n_vars)), 86 | name='intercept') 87 | 88 | if type(self.config.GENE_PRIOR) == list: 89 | vgenes_temp = [] 90 | prior_t = np.ones((1, ngenes), dtype=np.float32) * 0.5 91 | for prior in self.config.GENE_PRIOR: 92 | prior_t[0][prior[2]] = 1.1 if prior[1] == 'increase' else -0.2 93 | vgenes_temp.append(prior + (prior_t[0][prior[2]], )) 94 | self.t = tf.Variable(prior_t, name='t') 95 | print (f'Modified GENE_PRIOR {vgenes_temp} with init_tau') 96 | 97 | def init_pars(self): 98 | self.default_pars_names = ['gamma', 'beta'] 99 | self.default_pars_names += ['offset', 'a', 't', 'h', 'intercept'] 100 | 101 | def init_weights(self, weighted=False): 102 | nonzero_s, nonzero_u = self.Ms > 0, self.Mu > 0 103 | weights = np.array(nonzero_s & nonzero_u, dtype=bool) 104 | 105 | if weighted: 106 | ub_s = np.percentile(self.s[weights], self.perc) 107 | ub_u = np.percentile(self.u[weights], self.perc) 108 | if ub_s > 0: 109 | weights &= np.ravel(self.s <= ub_s) 110 | if ub_u > 0: 111 | weights &= np.ravel(self.u <= ub_u) 112 | 113 | self.weights = tf.cast(weights, dtype=tf.float32) 114 | self.nobs = np.sum(weights, axis=0) 115 | 116 | def init_lr(self): 117 | if self.config.LEARNING_RATE == 
None: 118 | lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay( 119 | initial_learning_rate=1e-2, 120 | decay_steps=2000, 121 | decay_rate=0.9, 122 | staircase=True 123 | ) 124 | 125 | else: 126 | lr_schedule = self.config.LEARNING_RATE 127 | 128 | return lr_schedule 129 | 130 | def get_fit_s(self, args, t_cell): 131 | self.fit_s = exp(args[5]) * \ 132 | exp(-exp(args[3]) * square(t_cell - args[4])) + \ 133 | args[2] 134 | 135 | return self.fit_s 136 | 137 | def get_s_deri(self, args, t_cell): 138 | self.s_deri = (self.fit_s - args[2]) * \ 139 | (-exp(args[3]) * 2 * (t_cell - args[4])) 140 | 141 | return self.s_deri 142 | 143 | def get_fit_u(self, args): 144 | return (self.s_deri + exp(args[0]) * self.fit_s) / exp(args[1]) + args[6] 145 | 146 | def get_s_u(self, args, t_cell): 147 | s = self.get_fit_s(args, t_cell) 148 | s_deri = self.get_s_deri(args, t_cell) 149 | u = self.get_fit_u(args) 150 | 151 | if self.config.ASSIGN_POS_U: 152 | s = tf.clip_by_value(s, 0, 1000) 153 | u = tf.clip_by_value(u, 0, 1000) 154 | 155 | return s, u 156 | 157 | def max_density(self, dis): 158 | if self.config.DENSITY == 'Max': 159 | sort_list = np.sort(dis) 160 | diff = np.insert(np.diff(sort_list), 0, 0, axis=1) 161 | kernel = np.ones(self.win_size) 162 | 163 | loc = [] 164 | for row in range(dis.shape[0]): 165 | idx = np.argmin(np.convolve(diff[row, :], kernel, 'valid')) 166 | window = sort_list[row, idx:idx + self.win_size] 167 | loc.append(np.mean(window)) 168 | 169 | return np.array(loc, dtype=np.float32) 170 | 171 | if self.config.DENSITY == 'SVD': 172 | s, u, v = tf.linalg.svd(dis) 173 | s = s[0:50] 174 | u = u[:, :tf.size(s)] 175 | v = v[:tf.size(s), :tf.size(s)] 176 | dis_approx = tf.matmul(u, tf.matmul(tf.linalg.diag(s), v, adjoint_b=True)) 177 | return tf.cast(mean(dis_approx, axis=1), tf.float32) 178 | 179 | if self.config.DENSITY == 'Raw': 180 | # weight = self.gene_prior_perc(dis) 181 | return tf.cast(mean(dis, axis=1), tf.float32) 182 | 183 | def 
gene_prior_perc(self, dis): 184 | perc, perc_idx = [], [] 185 | aggregrate_weight = np.ones((dis.shape[1], ), dtype=np.float32) 186 | 187 | for prior in self.config.GENE_PRIOR: 188 | expr = np.array(self.adata[:, prior[0]].layers['Ms']) 189 | perc.append(np.max(expr) * 0.75) # modify 0.75 for parameter tuning 190 | 191 | perc_total = np.sum(perc) 192 | perc = [perc[i] / perc_total for i in range(len(perc))] 193 | 194 | for sequence, prior in enumerate(self.config.GENE_PRIOR): 195 | temp_id = np.sum(self.boolean[:prior[2]]) 196 | aggregrate_weight[temp_id] = perc[sequence] * self.config.GENE_PRIOR_SCALE 197 | perc_idx.append(temp_id) 198 | 199 | # print (f'assigned weights of GENE_PRIOR {list(np.around(np.array(perc), 2)), perc_idx}') 200 | weight_total = np.sum(aggregrate_weight) 201 | aggregrate_weight = [aggregrate_weight[i] / weight_total for i in range(len(aggregrate_weight))] 202 | return np.reshape(aggregrate_weight, (1, dis.shape[1])) 203 | 204 | def match_time(self, Ms, Mu, s_predict, u_predict, x, iter): 205 | val = x[1, :] - x[0, :] 206 | cell_time = np.zeros((Ms.shape[0], Ms.shape[2])) 207 | 208 | if (self.config.AGENES_R2 < 1) & (iter > self.agenes_thres): 209 | self.index_list, self.boolean = np.squeeze(tf.where(self.total_genes)), self.total_genes 210 | else: 211 | self.index_list, self.boolean = np.squeeze(tf.where(self.idx)), self.idx 212 | 213 | for index in range(Ms.shape[2]): 214 | if index in self.index_list: 215 | sobs = Ms[:, :, index:index + 1] # n * 1 * 1 216 | uobs = Mu[:, :, index:index + 1] 217 | spre = s_predict[:, :, index:index + 1] # 1 * 3000 * 1 218 | upre = u_predict[:, :, index:index + 1] 219 | 220 | u_r2 = square(uobs - upre) # n * 3000 * 1 221 | s_r2 = square(sobs - spre) 222 | euclidean = sqrt(u_r2 + s_r2) 223 | assign_loc = tf.math.argmin(euclidean, axis=1).numpy() # n * 1 224 | 225 | if self.config.REORDER_CELL == 'Soft_Reorder': 226 | cell_time[:, index:index + 1] = \ 227 | col_minmax(self.reorder(assign_loc), 
self.adata.var.index[index]) 228 | if self.config.REORDER_CELL == 'Soft': 229 | cell_time[:, index:index + 1] = \ 230 | col_minmax(assign_loc, self.adata.var.index[index]) 231 | if self.config.REORDER_CELL == 'Hard': 232 | cell_time[:, index:index + 1] = \ 233 | x[0, index:index + 1] + val[index:index + 1] * assign_loc 234 | 235 | if self.config.AGGREGATE_T: 236 | return self.max_density(cell_time[:, self.index_list]) #! sampling? 237 | else: 238 | return cell_time 239 | 240 | def reorder(self, loc): 241 | new_loc = np.zeros(loc.shape) 242 | 243 | for gid in range(loc.shape[1]): 244 | ref = sorted([(val, idx) for idx, val in enumerate(loc[:, gid])], 245 | key=lambda x:x[0]) 246 | 247 | pre, count, rep = ref[0][0], 0, 0 248 | for item in ref: 249 | if item[0] > pre: 250 | count += rep 251 | rep = 1 252 | else: 253 | rep += 1 254 | 255 | new_loc[item[1], gid] = count 256 | pre = item[0] 257 | 258 | return new_loc 259 | 260 | def init_time(self, boundary, shape=None): 261 | x = tf.linspace(boundary[0], boundary[1], shape[0]) 262 | 263 | try: 264 | if type(boundary[0]) == int or boundary[0].shape[1] == 1: 265 | x = tf.reshape(x, (-1, 1)) 266 | x = tf.broadcast_to(x, shape) 267 | else: 268 | x = tf.squeeze(x) 269 | except: 270 | x = tf.reshape(x, (-1, 1)) 271 | x = tf.broadcast_to(x, shape) 272 | 273 | return tf.cast(x, dtype=tf.float32) 274 | 275 | def get_opt_args(self, iter, args): 276 | remain = iter % 400 277 | 278 | if self.config.FIT_OPTION == '1': 279 | if iter < self.config.MAX_ITER / 2: 280 | args_to_optimize = [args[2], args[3], args[4], args[5]] \ 281 | if remain < 200 else [args[0], args[1], args[6]] 282 | 283 | else: 284 | args_to_optimize = [args[0], args[1], 285 | args[2], args[3], 286 | args[4], args[5], args[6]] 287 | 288 | if self.config.FIT_OPTION == '2': 289 | if iter < self.config.MAX_ITER / 2: 290 | args_to_optimize = [args[3], args[5]] \ 291 | if remain < 200 else [args[0], args[1]] 292 | 293 | else: 294 | args_to_optimize = [args[0], args[1], 
args[3], args[5]] 295 | 296 | return args_to_optimize 297 | 298 | def get_log(self, loss, amplify=False, iter=None): 299 | self.finite = tf.math.is_finite(loss) # location of weird genes out of 2000 300 | glog = self.adata.var.iloc[np.squeeze(tf.where(~self.finite))].index.values 301 | 302 | for gene in glog: 303 | if gene not in self.gene_log: 304 | logging.info(f'{gene}, iter {iter}') 305 | self.gene_log.append(gene) 306 | 307 | if amplify: 308 | if self.adata.var.at[gene, 'amplify_genes'] == True: 309 | logging.info(f'{gene}, iter {iter}, amplify') 310 | self.adata.var.at[gene, 'amplify_genes'] = False 311 | 312 | loss = tf.where(self.finite, loss, 0) 313 | return loss 314 | 315 | def get_loss(self, iter, s_r2, u_r2): 316 | if iter < self.config.MAX_ITER / 2: 317 | remain = iter % 400 318 | loss = s_r2 if remain < 200 else u_r2 319 | else: 320 | loss = s_r2 + u_r2 321 | 322 | loss = self.get_log(loss, amplify=False, iter=iter) 323 | return loss 324 | 325 | def get_stop_cond(self, iter, pre, obj): 326 | stop_s = tf.zeros(self.Ms.shape[1], tf.bool) 327 | stop_u = tf.zeros(self.Ms.shape[1], tf.bool) 328 | remain = iter % 400 329 | 330 | if remain > 1 and remain < 200: 331 | stop_s = abs(pre - obj) <= abs(pre) * 1e-4 332 | if remain > 201 and remain < 400: 333 | stop_u = abs(pre - obj) <= abs(pre) * 1e-4 334 | 335 | return tf.math.logical_and(stop_s, stop_u) 336 | 337 | def get_interim_t(self, t_cell, idx): 338 | if self.config.AGGREGATE_T: 339 | return t_cell 340 | 341 | else: 342 | #? 
modify this later for independent mode 343 | t_interim = np.zeros(t_cell.shape) 344 | 345 | for i in range(t_cell.shape[1]): 346 | if idx[i]: 347 | temp = np.reshape(t_cell[:, i], (-1, 1)) 348 | t_interim[:, i] = \ 349 | np.squeeze(col_minmax(temp, self.adata.var.index[i])) 350 | 351 | t_interim = self.max_density(t_interim) 352 | t_interim = tf.reshape(t_interim, (-1, 1)) 353 | t_interim = tf.broadcast_to(t_interim, self.adata.shape) 354 | return t_interim -------------------------------------------------------------------------------- /unitvelo/pl.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import matplotlib.pyplot as plt 3 | import scvelo as scv 4 | from IPython import display 5 | import seaborn as sns 6 | import numpy as np 7 | import os 8 | from .individual_gene import Validation 9 | 10 | def plot_zero_gene_example(gene_name, adata, nonzero_para, zero_para): 11 | import math 12 | nrows = math.ceil(len(gene_name[:8]) / 2) 13 | 14 | fig, axes = plt.subplots(nrows, 2, figsize=(12, 5 * nrows)) 15 | for i in range(nrows): 16 | sns.scatterplot(np.squeeze(adata[:, gene_name[i * 2]].layers['Ms']), 17 | np.squeeze(adata[:, gene_name[i * 2]].layers['Mu']), 18 | sizes=1, ax=axes[i][0], 19 | hue=adata.obs[adata.uns['label']], palette='tab20') 20 | axes[i][0].set_title(f'{gene_name[i * 2]}') 21 | 22 | nonzero_coef = np.round(nonzero_para.loc[gene_name[i * 2]]['coef'], 2) 23 | zero_coef = np.round(zero_para.loc[gene_name[i * 2]]['coef'], 2) 24 | nonzero_r2 = np.round(nonzero_para.loc[gene_name[i * 2]]['r2'], 2) 25 | zero_r2 = np.round(zero_para.loc[gene_name[i * 2]]['r2'], 2) 26 | 27 | axes[i][0].set_xlabel(f'({nonzero_coef}, {zero_coef})') 28 | axes[i][0].set_ylabel(f'({nonzero_r2}, {zero_r2})') 29 | 30 | sns.scatterplot(np.squeeze(adata[:, gene_name[i * 2 + 1]].layers['Ms']), 31 | np.squeeze(adata[:, gene_name[i * 2 + 1]].layers['Mu']), 32 | sizes=1, ax=axes[i][1], 33 | hue=adata.obs[adata.uns['label']], 
def rbf(x, height, sigma, tau, offset_rbf):
    """Gaussian radial-basis profile of spliced expression over time.

    Evaluates ``height * exp(-sigma * (x - tau)**2) + offset_rbf``.
    """
    dev = x - tau
    return height * np.exp(-sigma * dev * dev) + offset_rbf

def rbf_deri(x, height, sigma, tau, offset_rbf):
    """First derivative of `rbf` with respect to x (offset term drops out)."""
    peak = rbf(x, height, sigma, tau, offset_rbf) - offset_rbf
    return peak * (-sigma * 2 * (x - tau))

def rbf_u(x, height, sigma, tau, offset_rbf, beta, gamma, intercept):
    """Unspliced abundance implied by the RBF spliced dynamics.

    Derived from du/dt = beta*u - gamma*s rearranged for u, plus a fitted
    intercept term.
    """
    s_fit = rbf(x, height, sigma, tau, offset_rbf)
    s_slope = rbf_deri(x, height, sigma, tau, offset_rbf)
    return (s_slope + gamma * s_fit) / beta + intercept
config_file=None, 78 | save_fig=False, 79 | show_ax=False, 80 | show_legend=True, 81 | show_details=False, 82 | time_metric='latent_time', 83 | palette='tab20', 84 | size=20, 85 | ncols=1 86 | ): 87 | """ 88 | Plotting function of phase portraits of individual genes. 89 | 90 | Args: 91 | gene_name (str): name of that gene to be illusrated, would extend to list of genes in next release 92 | adata (AnnData) 93 | config_file (.Config class): configuration file used for velocity estimation 94 | save_fig (bool): if True, save fig, default False 95 | 96 | show_ax (bool) 97 | show_legend (bool) 98 | show_details (bool): if True, plot detailed regression results together with estimated temporal change 99 | time_metric (str): inferred cell time, default 'latent_time' 100 | 101 | show_temporal (bool, experimental): whether plot temporal changes 102 | show_positive (bool, experimental): related to self.ASSIGN_POS_U 103 | t_left (float, experimental): starting time of phase portraits 104 | t_right (float, experimental): ending time of phase portraits 105 | """ 106 | 107 | if config_file == None: 108 | raise ValueError('Please set attribute `config_file`') 109 | 110 | if time_metric == 'latent_time': 111 | if 'latent_time' not in adata.obs.columns: 112 | scv.tl.latent_time(adata, min_likelihood=None) 113 | 114 | if show_details: 115 | from .individual_gene import exam_genes 116 | exam_genes(adata, gene_name, time_metric=time_metric) 117 | 118 | else: 119 | gene_name = gene_name if type(gene_name) == list else [gene_name] 120 | figs = [] 121 | 122 | for gn in gene_name: 123 | fig, axes = plt.subplots( 124 | nrows=1, 125 | ncols=3, 126 | figsize=(18, 4) 127 | ) 128 | gdata = adata[:, gn] 129 | 130 | boundary = (gdata.var.fit_t.values - 3 * (1 / np.sqrt(2 * np.exp(gdata.var.fit_a.values))), 131 | gdata.var.fit_t.values + 3 * (1 / np.sqrt(2 * np.exp(gdata.var.fit_a.values)))) 132 | 133 | t_one = np.linspace(0, 1, 1000) 134 | t_boundary = np.linspace(boundary[0], boundary[1], 2000) 
135 | 136 | spre = np.squeeze(rbf(t_boundary, gdata.var.fit_h.values, gdata.var.fit_a.values, gdata.var.fit_t.values, gdata.var.fit_offset.values)) 137 | sone = np.squeeze(rbf(t_one, gdata.var.fit_h.values, gdata.var.fit_a.values, gdata.var.fit_t.values, gdata.var.fit_offset.values)) 138 | 139 | upre = np.squeeze(rbf_u(t_boundary, gdata.var.fit_h.values, gdata.var.fit_a.values, gdata.var.fit_t.values, gdata.var.fit_offset.values, gdata.var.fit_beta.values, gdata.var.fit_gamma.values, gdata.var.fit_intercept.values)) 140 | uone = np.squeeze(rbf_u(t_one, gdata.var.fit_h.values, gdata.var.fit_a.values, gdata.var.fit_t.values, gdata.var.fit_offset.values, gdata.var.fit_beta.values, gdata.var.fit_gamma.values, gdata.var.fit_intercept.values)) 141 | 142 | g1 = sns.scatterplot(x=np.squeeze(gdata.layers['Ms']), 143 | y=np.squeeze(gdata.layers['Mu']), 144 | s=size, hue=adata.obs[adata.uns['label']], 145 | palette=palette, ax=axes[0]) 146 | axes[0].plot(spre, upre, color='lightgrey', linewidth=2, label='Predicted Curve') 147 | axes[0].plot(sone, uone, color='black', linewidth=2, label='Predicted Curve Time 0-1') 148 | axes[0].set_xlabel('Spliced Reads') 149 | axes[0].set_ylabel('Unspliced Reads') 150 | 151 | axes[0].set_xlim([-0.005 if gdata.layers['Ms'].min() < 1 152 | else gdata.layers['Ms'].min() * 0.95, 153 | gdata.layers['Ms'].max() * 1.05]) 154 | axes[0].set_ylim([-0.005 if gdata.layers['Mu'].min() < 1 155 | else gdata.layers['Mu'].min() * 0.95, 156 | gdata.layers['Mu'].max() * 1.05]) 157 | 158 | g2 = sns.scatterplot(x=np.squeeze(adata.obs[time_metric].values), 159 | y=np.squeeze(gdata.layers['Ms']), 160 | s=size, hue=adata.obs[adata.uns['label']], 161 | palette=palette, ax=axes[1]) 162 | sns.lineplot(x=t_one, y=sone, color='black', linewidth=2, ax=axes[1]) 163 | 164 | axes[1].set_xlabel('Inferred Cell Time') 165 | axes[1].set_ylabel('Spliced') 166 | 167 | g3 = sns.scatterplot(x=np.squeeze(adata.obs[time_metric].values), 168 | y=np.squeeze(gdata.layers['Mu']), 169 | 
s=size, hue=adata.obs[adata.uns['label']], 170 | palette=palette, ax=axes[2]) 171 | sns.lineplot(x=t_one, y=uone, color='black', linewidth=2, ax=axes[2]) 172 | 173 | axes[2].set_xlabel('Inferred Cell Time') 174 | axes[2].set_ylabel('Unspliced') 175 | 176 | # if not show_ax: 177 | # axes.axis("off") 178 | 179 | if not show_legend: 180 | g1.get_legend().remove() 181 | g2.get_legend().remove() 182 | g3.get_legend().remove() 183 | 184 | axes[1].set_title(gn, fontsize=12) 185 | plt.show() 186 | 187 | if save_fig: 188 | plt.savefig(os.path.join(adata.uns['temp'], f'GM_{gn}.png'), dpi=300, bbox_inches='tight') 189 | 190 | def plot_phase_portrait(adata, args, sobs, uobs, spre, upre): 191 | if 'examine_genes' in adata.uns.keys(): 192 | display.clear_output(wait=True) 193 | examine = Validation(adata) 194 | fig, axes = plt.subplots(1, 1, figsize=(6, 4)) 195 | 196 | examine.plot_mf(adata, sobs, uobs, axes[0]) 197 | examine.plot_mf(adata, spre, upre, axes[0]) 198 | plt.show() 199 | 200 | else: 201 | pass 202 | 203 | def plot_cell_time(adata): 204 | if 'examine_genes' in adata.uns.keys(): 205 | raise ValueError( 206 | f'self.VGENES in configuration file should not be a specified gene name.\n' 207 | f'Please re-run the model use alternative setting.' 
def plot_loss(iter, loss, thres=None):
    """Plot the per-iteration training loss around the cutoff iteration.

    Left panel covers iterations 800 up to the cutoff ``thres``; right
    panel covers iterations from just past the cutoff to the last one.
    ``loss`` is indexed by iteration number.
    """
    fig, panels = plt.subplots(1, 2, figsize=(12, 4))
    iters = range(iter + 1)

    panels[0].plot(iters[800:thres - 1], loss[800:thres - 1])
    panels[0].set_title('Iter # from 800 to cutoff')
    panels[0].set_ylabel('Euclidean Loss')

    start = int(thres * 1.01)
    panels[1].plot(iters[start:], loss[start:])
    panels[1].set_title('Iter # from cutoff to terminated state')

    plt.show()
    plt.close()
def plot_reverse_tran_scatter(adata):
    """Scatter RBF vs. quadratic R^2 for velocity genes, with a y = x guide line."""
    vgenes = adata.var.loc[adata.var['velocity_genes'] == True]
    sns.scatterplot(x='rbf_r2', y='qua_r2', data=vgenes)
    plt.axline((0, 0), (0.5, 0.5), color='r')

    plt.title(f'$R^2$ comparison of RBF and Quadratic model')
    plt.show()
    plt.close()
def new_adata_col(adata, var_names, values):
    """Attach per-gene columns to ``adata.var``.

    Each name in *var_names* becomes a column filled with the matching
    entry of *values* (parallel sequences, matched by position).
    """
    for idx, name in enumerate(var_names):
        adata.var[name] = values[idx]
bottom-up strategies (UniTVelo Independent mode) 84 | Could serve as a systematic comparison (e.g. the entire transcriptome) 85 | between two methods 86 | """ 87 | 88 | import scvelo as scv 89 | import scanpy as sc 90 | import numpy as np 91 | import pandas as pd 92 | from .utils import col_spearman 93 | from scipy.stats import spearmanr 94 | 95 | print ('---> Running scVelo dynamical mode') 96 | scvdata = scv.read(adata.uns['datapath']) 97 | scv.pp.filter_and_normalize(scvdata, min_shared_counts=20, n_top_genes=n_top_genes) 98 | scv.pp.moments(scvdata, n_pcs=30, n_neighbors=30) 99 | scv.tl.recover_dynamics(scvdata, n_jobs=20) 100 | scv.tl.velocity(scvdata, mode='dynamical') 101 | scv.tl.latent_time(scvdata) 102 | 103 | print ('---> Caculating spearman correlation with reference') 104 | if 'organoids' in adata.uns['datapath']: 105 | scvdata.var.index = scvdata.var['gene'].values 106 | 107 | index = adata.var.loc[adata.var['velocity_genes'] == True].index.intersection(scvdata.var.loc[scvdata.var['velocity_genes'] == True].index) 108 | tadata = adata[:, index].layers['fit_t'] 109 | tscv = scvdata[:, index].layers['fit_t'] 110 | 111 | if iroot != None: 112 | print ('---> Generating reference with diffusion pseudotime') 113 | sc.tl.diffmap(adata) 114 | adata.uns['iroot'] = np.flatnonzero(adata.obs[adata.uns['label']] == iroot)[0] 115 | sc.tl.dpt(adata) 116 | 117 | gt = adata.obs['dpt_pseudotime'].values 118 | gt = np.broadcast_to(np.reshape(gt, (-1, 1)), tadata.shape) 119 | resadata = col_spearman(gt, tadata) 120 | resscv = col_spearman(gt, tscv) 121 | 122 | results, _ = spearmanr(min_max(scvdata.obs['latent_time']), min_max(gt)) 123 | print (np.reshape(results, (-1, 1))) 124 | 125 | if file_name != None: 126 | print ('---> Generating reference pseudotime with Slingshot') 127 | slingshot = pd.read_csv(f'./slingshot/{file_name}') 128 | slingshot = slingshot['x'].values 129 | 130 | slingshot = np.broadcast_to(np.reshape(slingshot, (-1, 1)), tadata.shape) 131 | resadata = 
def col_minmax(matrix, gene=None):
    """Column-wise min-max scaling of *matrix* onto the [0, 1] range.

    Args:
        matrix (np.ndarray): 2-D array (cells x genes); each column is
            rescaled independently.
        gene (optional): gene identifier, accepted for compatibility with
            callers that pass the gene name alongside the data (e.g.
            ``col_minmax(temp, self.adata.var.index[i])`` in the optimizer,
            which previously raised TypeError against the 1-arg signature);
            currently unused.

    Returns:
        np.ndarray: same shape as *matrix*, each column rescaled. Columns
        with zero range still produce NaN/inf, as before.
    """
    col_min = np.min(matrix, axis=0)
    col_max = np.max(matrix, axis=0)
    return (matrix - col_min) / (col_max - col_min)
def min_max(data):
    """Linearly rescale *data* onto [0, 1] using its global min and max."""
    lo, hi = np.min(data), np.max(data)
    return (data - lo) / (hi - lo)
def OLS(x, y):
    """Closed-form simple linear regression, fitted per column.

    Returns the (slope, intercept) pair minimizing the column-wise squared
    error of ``y - slope * x - intercept``.
    """
    mx = np.mean(x, axis=0)
    my = np.mean(y, axis=0)

    slope = np.sum(x * y - my * x, axis=0) / np.sum(x * x - mx * x, axis=0)
    intercept = my - slope * mx
    return slope, intercept
310 | sigma_max = np.max(adata.var.loc[adata.var['velocity_genes'] == True]['fit_a']) 311 | celltime = adata.obs[time_metric].values 312 | 313 | def quadratic(x, a, b, c): 314 | return a * (x ** 2) + b * x + c 315 | 316 | def rbf(x, h, sigma, tau): 317 | return h * np.exp(-sigma * (x - tau) * (x - tau)) 318 | 319 | for index, row in tqdm(adata.var.iterrows()): 320 | if row['velocity_genes']: 321 | spliced = np.squeeze(np.array(adata[:, index].layers['Ms'])) 322 | popt, _ = curve_fit(quadratic, celltime, spliced) 323 | 324 | fitted = quadratic(celltime, popt[0], popt[1], popt[2]) 325 | ss_res = np.sum((spliced - fitted) ** 2) 326 | ss_tot = np.sum((spliced - np.mean(spliced)) ** 2) 327 | r2 = 1 - (ss_res / ss_tot) 328 | adata.var.loc[index, 'qua_r2'] = r2 329 | 330 | try: 331 | popt_rbf, _ = curve_fit(rbf, celltime, spliced, 332 | maxfev=10000, 333 | bounds=([1e-2, 1e-3, -np.inf], 334 | [np.max(spliced), sigma_max, +np.inf]) 335 | ) 336 | 337 | fitted_rbf = rbf(celltime, popt_rbf[0], popt_rbf[1], popt_rbf[2]) 338 | ss_res_rbf = np.sum((spliced - fitted_rbf) ** 2) 339 | ss_tot_rbf = np.sum((spliced - np.mean(spliced)) ** 2) 340 | r2_rbf= 1 - (ss_res_rbf / ss_tot_rbf) 341 | adata.var.loc[index, 'rbf_r2'] = r2_rbf 342 | 343 | except: 344 | r2_rbf = 0 345 | adata.var.loc[index, 'rbf_r2'] = r2_rbf 346 | 347 | if r2 - r2_rbf > 0.075: 348 | adata.var.loc[index, 're_transit'] = True 349 | 350 | from .pl import plot_reverse_tran_scatter 351 | plot_reverse_tran_scatter(adata) 352 | 353 | re_tran_num = adata.var.loc[adata.var['re_transit'] == True].shape[0] 354 | re_tran_perc = re_tran_num / adata.var.loc[adata.var['velocity_genes'] == True].shape[0] 355 | logging.info(f'# of genes which are identified as reverse transient {re_tran_num}') 356 | logging.info(f'percentage of genes which are identified as reverse transient {re_tran_perc}') 357 | 358 | return adata 359 | 360 | def choose_mode(adata, label=None): 361 | print ('This function works as a reference only.') 362 | print 
('For less certain scenario, we also suggest users to try both.') 363 | print ('---> Checking cell cycle scores...') 364 | 365 | from .utils import get_cgene_list 366 | import scvelo as scv 367 | scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=2000) 368 | 369 | s, g2m = get_cgene_list() 370 | num_s = len(adata.var.index.intersection(s)) 371 | num_g2m = len(adata.var.index.intersection(g2m)) 372 | 373 | print (f'---> Number of S genes {num_s}/{len(s)}') 374 | print (f'---> Number of G2M genes {num_g2m}/{len(g2m)}') 375 | 376 | if (num_s / len(s) > 0.5) or (num_g2m / len(g2m) > 0.5): 377 | print ('Independent mode is recommended, consider setting config.FIT_OPTION = 2') 378 | 379 | else: 380 | print ('# of cycle genes failed to pass thresholds') 381 | print ('---> Checking sparse cell types...') 382 | scv.pp.moments(adata, n_pcs=30, n_neighbors=30) 383 | adata.obs['cid'] = list(range(adata.shape[0])) 384 | 385 | try: 386 | neighbors = adata.uns['neighbors']['indices'] 387 | except: 388 | scv.pp.neighbors(adata, n_pcs=30, n_neighbors=30) 389 | neighbors = adata.uns['neighbors']['indices'] 390 | 391 | ctype_perc = [] 392 | ctype = list(set(adata.obs[label].values)) 393 | 394 | for type in ctype: 395 | temp = adata[adata.obs.loc[adata.obs[label] == type].index, :] 396 | temp_id = temp.obs['cid'].values 397 | temp_nei = neighbors[temp_id, 1:].flatten() 398 | 399 | temp_nei = [True if nei in temp_id else False for nei in temp_nei] 400 | ctype_perc.append(np.sum(temp_nei) / len(temp_nei)) 401 | 402 | if np.sum(np.array(ctype_perc) > 0.95) >= 3: 403 | print ('More than two sparse cell types have been detected') 404 | print ('Independent mode is recommended, consider setting config.FIT_OPTION = 2') 405 | else: 406 | print ('Unified-time mode is recommended, consider setting config.FIT_OPTION = 1') 407 | 408 | def subset_adata(adata, label=None, proportion=0.5, min_cells=50): 409 | adata.obs['cid'] = list(range(adata.shape[0])) 410 | ctype = 
list(set(adata.obs[label].values)) 411 | 412 | subset = [] 413 | 414 | for type in ctype: 415 | temp = adata[adata.obs.loc[adata.obs[label] == type].index, :] 416 | temp_id = temp.obs['cid'].values 417 | 418 | if len(temp_id) <= min_cells: 419 | subset.extend(list(temp_id)) 420 | elif int(temp.shape[0] * proportion) <= min_cells: 421 | subset.extend(list(np.random.choice(temp_id, size=min_cells, replace=False))) 422 | else: 423 | subset.extend(list(np.random.choice(temp_id, size=int(temp.shape[0] * proportion), replace=False))) 424 | 425 | return adata[np.array(subset), :] 426 | 427 | def subset_prediction(adata_subset, adata, config=None): 428 | from .optimize_utils import Model_Utils 429 | 430 | model = Model_Utils(adata, config=config) 431 | x = model.init_time((0, 1), (3000, adata.n_vars)) 432 | 433 | adata.var = adata_subset.var.copy() 434 | adata.uns['basis'] = adata_subset.uns['basis'] 435 | adata.uns['label'] = adata_subset.uns['label'] 436 | adata.uns['par_names'] = adata_subset.uns['par_names'] 437 | adata.uns['base_function'] = adata_subset.uns['base_function'] 438 | 439 | model.total_genes = adata.var['velocity_genes'].values 440 | model.idx = adata.var['velocity_genes'].values 441 | scaling = adata.var['scaling'].values 442 | 443 | args_shape = (1, adata.n_vars) 444 | args = [ 445 | np.broadcast_to(np.log(np.array(adata.var['fit_gamma'].values)), args_shape), 446 | np.broadcast_to(np.log(np.array(adata.var['fit_beta'].values * scaling)), args_shape), 447 | np.broadcast_to(np.array(adata.var['fit_offset'].values), args_shape), 448 | np.broadcast_to(np.log(np.array(adata.var['fit_a'].values)), args_shape), 449 | np.broadcast_to(np.array(adata.var['fit_t'].values), args_shape), 450 | np.broadcast_to(np.log(np.array(adata.var['fit_h'].values)), args_shape), 451 | np.broadcast_to(np.array(adata.var['fit_intercept'].values / scaling), args_shape) 452 | ] 453 | 454 | s_predict, s_deri_predict, u_predict = \ 455 | model.get_fit_s(args, x), 
model.get_s_deri(args, x), model.get_fit_u(x) 456 | s_predict = tf.expand_dims(s_predict, axis=0) # 1 3000 d 457 | u_predict = tf.expand_dims(u_predict, axis=0) 458 | Mu = tf.expand_dims(adata.layers['Mu'] / scaling, axis=1) # n 1 d 459 | Ms = tf.expand_dims(adata.layers['Ms'], axis=1) 460 | 461 | t_cell = model.match_time(Ms, Mu, s_predict, u_predict, x.numpy(), config.MAX_ITER) 462 | t_cell = np.reshape(t_cell, (-1, 1)) 463 | t_cell = np.broadcast_to(t_cell, adata.shape) 464 | 465 | adata.layers['fit_t'] = t_cell.copy() 466 | adata.layers['fit_t'][:, ~adata.var['velocity_genes'].values] = np.nan 467 | 468 | model.fit_s = model.get_fit_s(args, t_cell).numpy() 469 | model.s_deri = model.get_s_deri(args, t_cell).numpy() 470 | model.fit_u = model.get_fit_u(args).numpy() 471 | adata.layers['velocity'] = model.s_deri 472 | 473 | adata.obs['latent_time_gm'] = min_max(np.nanmean(adata.layers['fit_t'], axis=1)) 474 | scv.tl.velocity_graph(adata, sqrt_transform=True) 475 | scv.tl.velocity_embedding(adata, basis=adata.uns['basis']) 476 | scv.tl.latent_time(adata, min_likelihood=None) 477 | 478 | if config.FIT_OPTION == '1': 479 | adata.obs['latent_time'] = adata.obs['latent_time_gm'] 480 | del adata.obs['latent_time_gm'] 481 | 482 | if os.path.exists(os.path.join(adata_subset.uns['temp'], 'prediction')): 483 | pass 484 | else: os.mkdir(os.path.join(adata_subset.uns['temp'], 'prediction')) 485 | adata.uns['temp'] = os.path.join(adata_subset.uns['temp'], 'prediction') 486 | 487 | import shutil 488 | shutil.copyfile(os.path.join(adata_subset.uns['temp'], 'fitvar.csv'), 489 | os.path.join(adata.uns['temp'], 'fitvar.csv')) 490 | 491 | import pandas as pd 492 | s = pd.DataFrame(data=model.fit_s, index=adata.obs.index, columns=adata.var.index) 493 | u = pd.DataFrame(data=model.fit_u, index=adata.obs.index, columns=adata.var.index) 494 | ms = pd.DataFrame(data=adata.layers['Ms'], index=adata.obs.index, columns=adata.var.index) 495 | mu = pd.DataFrame(data=adata.layers['Mu'], 
def prior_trend_valid(adata, gene_list=None, name='IROOT'):
    """Validate user-supplied prior genes against the fitted velocity genes.

    Dumps the current velocity-gene names to ``vgenes.txt`` under
    ``adata.uns['temp']``, then augments every prior tuple whose gene is a
    velocity gene with that gene's column index in ``adata.var``. Exits the
    process when no prior survives the check.
    """
    import sys

    vgenes = adata.var.loc[adata.var['velocity_genes'] == True].index
    file_path = os.path.join(adata.uns['temp'], 'vgenes.txt')

    with open(file_path, 'w') as fp:
        fp.writelines("%s\n" % item for item in vgenes)

    all_genes = list(adata.var.index)
    validated = []
    for prior in gene_list:
        if prior[0] in vgenes:
            validated.append(prior + (all_genes.index(prior[0]), ))
        else:
            print (f'{prior[0]} of {name} has not been identified as a velocity gene')

    if not validated:
        print (f'---> No genes has been identified as velocity genes')
        print (f'Consider selecting one from {file_path}')
        sys.exit()
    else:
        print (f'Modified {name} {validated} with index')
        return validated
Default settings with unified-time mode will be used.') 535 | config = Configuration() 536 | 537 | if config.FIT_OPTION == '1': 538 | config.DENSITY = 'SVD' if config.GENE_PRIOR == None else 'Raw' 539 | config.REORDER_CELL = 'Soft_Reorder' 540 | config.AGGREGATE_T = True 541 | 542 | elif config.FIT_OPTION == '2': 543 | config.DENSITY = 'Raw' 544 | config.REORDER_CELL = 'Hard' 545 | config.AGGREGATE_T = False 546 | 547 | else: 548 | raise ValueError('config.FIT_OPTION is invalid') 549 | 550 | print ('------> Manully Specified Parameters <------') 551 | config_ref = Configuration() 552 | dict_input, dict_ref = vars(config), vars(config_ref) 553 | 554 | para_used = [] 555 | for parameter in dict_ref: 556 | if dict_input[parameter] != dict_ref[parameter]: 557 | print (parameter, dict_input[parameter], sep=f':\t') 558 | para_used.append(parameter) 559 | 560 | print ('------> Model Configuration Settings <------') 561 | default_para = ['N_TOP_GENES', 562 | 'LEARNING_RATE', 563 | 'FIT_OPTION', 564 | 'DENSITY', 565 | 'REORDER_CELL', 566 | 'AGGREGATE_T', 567 | 'R2_ADJUST', 568 | 'GENE_PRIOR', 569 | 'VGENES', 570 | 'IROOT'] 571 | 572 | for parameter in default_para: 573 | if parameter not in para_used: 574 | print (parameter, dict_ref[parameter], sep=f':\t') 575 | 576 | print ('--------------------------------------------') 577 | print ('') 578 | return config, para_used 579 | 580 | def init_adata_and_logs(adata, config, normalize=True): 581 | if type(adata) == str: 582 | data_path = adata 583 | adata = scv.read(data_path) 584 | 585 | else: 586 | cwd = os.getcwd() 587 | if os.path.exists(os.path.join(cwd, 'res')): 588 | pass 589 | else: os.mkdir(os.path.join(cwd, 'res')) 590 | 591 | print (f'Current working dir is {cwd}.') 592 | print (f'Results will be stored in res folder') 593 | data_path = os.path.join(cwd, 'res', 'temp.h5ad') 594 | 595 | from .utils import remove_dir 596 | remove_dir(data_path, adata) 597 | logging.basicConfig(filename=os.path.join(adata.uns['temp'], 
'logging.txt'), 598 | filemode='a', 599 | format='%(asctime)s, %(levelname)s, %(message)s', 600 | datefmt='%H:%M:%S', 601 | level=logging.INFO) 602 | 603 | if normalize: 604 | scv.pp.filter_and_normalize(adata, 605 | min_shared_counts=config.MIN_SHARED_COUNTS, 606 | n_top_genes=config.N_TOP_GENES) 607 | print (f"Extracted {adata.var[adata.var['highly_variable'] == True].shape[0]} highly variable genes.") 608 | 609 | print (f'Computing moments for {len(adata.var)} genes with n_neighbors: {config.N_NEIGHBORS} and n_pcs: {config.N_PCS}') 610 | scv.pp.moments(adata, 611 | n_pcs=config.N_PCS, 612 | n_neighbors=config.N_NEIGHBORS) 613 | else: 614 | scv.pp.neighbors(adata) 615 | 616 | print ('') 617 | return adata, data_path -------------------------------------------------------------------------------- /unitvelo/velocity.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pandas as pd 4 | from .model import lagrange 5 | from .utils import make_dense, get_weight, R2 6 | import scvelo as scv 7 | import tensorflow as tf 8 | 9 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 10 | np.random.seed(42) 11 | 12 | class Velocity: 13 | def __init__( 14 | self, 15 | adata=None, 16 | min_ratio=0.01, 17 | min_r2=0.01, 18 | fit_offset=False, 19 | perc=[5, 95], 20 | vkey='velocity', 21 | config=None 22 | ): 23 | self.adata = adata 24 | self.vkey = vkey 25 | 26 | self.Ms = adata.layers["spliced"] if config.USE_RAW else adata.layers["Ms"].copy() 27 | self.Mu = adata.layers["unspliced"] if config.USE_RAW else adata.layers["Mu"].copy() 28 | self.Ms, self.Mu = make_dense(self.Ms), make_dense(self.Mu) 29 | 30 | self.min_r2 = min_r2 31 | self.min_ratio = min_ratio 32 | 33 | n_obs, n_vars = self.Ms.shape 34 | self.gamma = np.zeros(n_vars, dtype=np.float32) 35 | self.r2 = np.zeros(n_vars, dtype=np.float32) 36 | self.velocity_genes = np.ones(n_vars, dtype=bool) 37 | self.residual_scale = np.zeros([n_obs, n_vars], 
dtype=np.float32) 38 | 39 | self.perc = perc 40 | self.fit_offset = fit_offset 41 | self.config = config 42 | 43 | def get_velo_genes(self): 44 | variable = self.adata.var 45 | if variable.index[0].startswith('ENSMUSG'): 46 | variable.index = variable['gene'] 47 | variable.index.name = 'index' 48 | 49 | weights = get_weight(self.Ms, self.Mu, perc=95) 50 | Ms, Mu = weights * self.Ms, weights * self.Mu 51 | 52 | self.gamma_quantile = np.sum(Mu * Ms, axis=0) / np.sum(Ms * Ms, axis=0) 53 | self.scaling = np.std(self.Mu, axis=0) / np.std(self.Ms, axis=0) 54 | self.adata.layers['Mu_scale'] = self.Mu / self.scaling 55 | 56 | if self.config.R2_ADJUST: 57 | Ms, Mu = self.Ms, self.Mu 58 | 59 | self.gene_index = variable.index 60 | self.gamma_ref = np.sum(Mu * Ms, axis=0) / np.sum(Ms * Ms, axis=0) 61 | self.residual_scale = self.Mu - self.gamma_ref * self.Ms 62 | self.r2 = R2(self.residual_scale, total=self.Mu - np.mean(self.Mu, axis=0)) 63 | 64 | self.velocity_genes = np.ones(Ms.shape[1], dtype=bool) 65 | 66 | if type(self.config.VGENES) == str and \ 67 | self.config.VGENES in self.adata.var.index: 68 | self.adata.uns['examine_genes'] = self.config.VGENES 69 | self.velocity_genes = np.zeros(Ms.shape[1], dtype=np.bool) 70 | self.velocity_genes[ 71 | np.argwhere(self.adata.var.index == self.config.VGENES)] = True 72 | 73 | elif type(self.config.VGENES) == list: 74 | temp = [] 75 | for gene in variable.index: 76 | if gene in self.config.VGENES: 77 | temp.append(True) 78 | else: 79 | temp.append(False) 80 | 81 | self.velocity_genes = np.array(temp) 82 | self.adata.var['velocity_gamma'] = self.gamma_ref 83 | 84 | elif self.config.VGENES == 'raws': 85 | self.velocity_genes = np.ones(Ms.shape[1], dtype=np.bool) 86 | 87 | elif self.config.VGENES == 'offset': 88 | self.fit_linear(None, self.Ms, self.Mu, 'vgene_offset', coarse=False) 89 | 90 | if self.config.FILTER_CELLS: 91 | self.fit_linear(None, self.Ms, self.Mu, 'vgene_offset', coarse=True) 92 | 93 | elif self.config.VGENES == 
'basic': 94 | self.velocity_genes = ( 95 | (self.r2 > self.min_r2) 96 | & (self.r2 < 0.95) 97 | & (self.gamma_quantile > self.min_ratio) 98 | & (self.gamma_ref > self.min_ratio) 99 | & (np.max(self.Ms > 0, axis=0) > 0) 100 | & (np.max(self.Mu > 0, axis=0) > 0) 101 | ) 102 | print (f'# of velocity genes {self.velocity_genes.sum()} (Criterion: positive regression coefficient between un/spliced counts)') 103 | 104 | if self.config.R2_ADJUST: 105 | lb, ub = np.nanpercentile(self.scaling, [10, 90]) 106 | self.velocity_genes = ( 107 | self.velocity_genes 108 | & (self.scaling > np.min([lb, 0.03])) 109 | & (self.scaling < np.max([ub, 3])) 110 | ) 111 | print (f'# of velocity genes {self.velocity_genes.sum()} (Criterion: std of un/spliced reads should be moderate, w/o extreme values)') 112 | 113 | self.adata.var['velocity_gamma'] = self.gamma_ref 114 | self.adata.var['velocity_r2'] = self.r2 115 | 116 | else: 117 | raise ValueError('Plase specify the correct self.VGENES in configuration file') 118 | 119 | if True: 120 | self.init_weights() 121 | self.velocity_genes = self.velocity_genes & (self.nobs > 0.05 * Ms.shape[1]) 122 | 123 | self.adata.var['scaling'] = self.scaling 124 | self.adata.var['velocity_genes'] = self.velocity_genes 125 | self.adata.uns[f"{self.vkey}_params"] = {"mode": self.config.GENERAL, "perc": self.perc} 126 | 127 | if np.sum(self.velocity_genes) < 2: 128 | print ('---> Low signal in splicing dynamics.') 129 | 130 | from .utils import prior_trend_valid 131 | if type(self.config.IROOT) == list: 132 | self.config.IROOT = prior_trend_valid(self.adata, self.config.IROOT, 'IROOT') 133 | 134 | if type(self.config.GENE_PRIOR) == list: 135 | self.config.GENE_PRIOR = prior_trend_valid(self.adata, self.config.GENE_PRIOR, 'GENE_PRIOR') 136 | 137 | def init_weights(self): 138 | nonzero_s, nonzero_u = self.Ms > 0, self.Mu > 0 139 | weights = np.array(nonzero_s & nonzero_u, dtype=bool) 140 | self.nobs = np.sum(weights, axis=0) 141 | 142 | def 
fit_deterministic(self, idx, Ms, Mu, Ms_scale, Mu_scale): 143 | df_gamma = pd.DataFrame(index=self.gene_index, data=0, 144 | dtype=np.float32, columns=['coef', 'inter']) 145 | 146 | if self.fit_offset: 147 | pass 148 | 149 | else: 150 | weights_new = get_weight(Ms, Mu, perc=self.perc) 151 | x, y = weights_new * Ms_scale, weights_new * Mu_scale 152 | df_gamma['coef'][idx] = np.sum(y * x, axis=0) / np.sum(x * x, axis=0) 153 | 154 | residual = self.Mu \ 155 | - np.broadcast_to(df_gamma['coef'].values, self.Mu.shape) * self.Ms \ 156 | - np.broadcast_to(df_gamma['inter'].values, self.Mu.shape) 157 | 158 | self.adata.var['velocity_gamma'] = df_gamma['coef'].values 159 | self.adata.var['intercept'] = df_gamma['inter'].values 160 | return residual 161 | 162 | def fit_linear(self, idx, Ms, Mu, method='vgene_offset', coarse=False): 163 | ''' 164 | [bic] for linear BIC comparison with algorithm BIC 165 | [kinetic] for selection of cyclic kinetic gene and monotonic expression gene 166 | [vgene_offset] for determination of velocity gene alternatively using offset 167 | ''' 168 | 169 | if method == 'vgene_offset': 170 | from sklearn import linear_model 171 | from sklearn.metrics import r2_score 172 | 173 | index = self.adata.var.index 174 | linear = pd.DataFrame(index=index, data=0, dtype=np.float32, 175 | columns=['coef', 'inter', 'r2']) 176 | 177 | reg = linear_model.LinearRegression() 178 | for col in range(Ms.shape[1]): 179 | if not coarse: 180 | sobs = np.reshape(Ms[:, col], (-1, 1)) 181 | uobs = np.reshape(Mu[:, col], (-1, 1)) 182 | 183 | else: 184 | if self.config.FILTER_CELLS: 185 | nonzero_s = Ms[:, col] > 0 186 | nonzero_u = Mu[:, col] > 0 187 | valid = np.array(nonzero_s & nonzero_u, dtype=bool) 188 | sobs = np.reshape(Ms[:, col][valid], (-1, 1)) 189 | uobs = np.reshape(Mu[:, col][valid], (-1, 1)) 190 | 191 | else: 192 | sobs = np.reshape(Ms[:, col], (-1, 1)) 193 | uobs = np.reshape(Mu[:, col], (-1, 1)) 194 | 195 | reg.fit(sobs, uobs) 196 | u_pred = reg.predict(sobs) 
197 | 198 | linear.loc[index[col], 'coef'] = float(reg.coef_) 199 | linear.loc[index[col], 'inter'] = float(reg.intercept_) 200 | linear.loc[index[col], 'r2'] = r2_score(uobs, u_pred) 201 | 202 | self.adata.var['velocity_inter'] = np.array(linear['inter'].values) 203 | self.adata.var['velocity_gamma'] = np.array(linear['coef'].values) 204 | self.adata.var['velocity_r2'] = np.array(linear['r2'].values) 205 | self.gamma_ref = np.array(linear['coef'].values) 206 | self.r2 = np.array(linear['r2'].values) 207 | 208 | self.velocity_genes = ( 209 | self.velocity_genes 210 | & (self.r2 > self.min_r2) 211 | & (self.r2 < 0.95) 212 | & (np.array(linear['coef'].values) > self.min_ratio) 213 | & (np.max(self.Ms > 0, axis=0) > 0) 214 | & (np.max(self.Mu > 0, axis=0) > 0) 215 | ) 216 | print (f'# of velocity genes {self.velocity_genes.sum()} (Criterion: positive regression coefficient between un/spliced counts)') 217 | 218 | lb, ub = np.nanpercentile(self.scaling, [10, 90]) 219 | self.velocity_genes = ( 220 | self.velocity_genes 221 | & (self.scaling > np.min([lb, 0.03])) 222 | & (self.scaling < np.max([ub, 3])) 223 | ) 224 | print (f'# of velocity genes {self.velocity_genes.sum()} (Criterion: std of un/spliced reads should be moderate, w/o extreme values)') 225 | 226 | def fit_curve(self, adata, idx, Ms_scale, Mu_scale, rep=1): 227 | physical_devices = tf.config.list_physical_devices('GPU') 228 | 229 | if len(physical_devices) == 0 or self.config.GPU == -1: 230 | tf.config.set_visible_devices([], 'GPU') 231 | device = '/cpu:0' 232 | print ('No GPU device has been detected. Switch to CPU mode.') 233 | 234 | else: 235 | assert self.config.GPU < len(physical_devices), 'Please specify the correct GPU card.' 
236 | tf.config.set_visible_devices(physical_devices[self.config.GPU], 'GPU') 237 | 238 | os.environ["CUDA_VISIBLE_DEVICES"] = f'{self.config.GPU}' 239 | for gpu in physical_devices: 240 | tf.config.experimental.set_memory_growth(gpu, True) 241 | 242 | device = f'/gpu:{self.config.GPU}' 243 | print (f'Using GPU card: {self.config.GPU}') 244 | 245 | with tf.device(device): 246 | residual, adata = lagrange( 247 | adata, idx=idx, 248 | Ms=Ms_scale, Mu=Mu_scale, 249 | rep=rep, config=self.config 250 | ) 251 | 252 | return residual, adata 253 | 254 | def fit_velo_genes(self, basis='umap', rep=1): 255 | idx = self.velocity_genes 256 | print (f'# of velocity genes {idx.sum()} (Criterion: genes have reads in more than 5% of total cells)') 257 | 258 | if self.config.RESCALE_DATA: 259 | Ms_scale, Mu_scale = \ 260 | self.Ms, self.Mu / (np.std(self.Mu, axis=0) / np.std(self.Ms, axis=0)) 261 | else: 262 | Ms_scale, Mu_scale = self.Ms, self.Mu 263 | 264 | assert self.config.GENERAL in ['Deterministic', 'Curve', 'Linear'], \ 265 | 'Please specify the correct self.GENERAL in configuration file.' 
266 | 267 | if self.config.GENERAL == 'Curve': 268 | residual, adata = self.fit_curve(self.adata, idx, Ms_scale, Mu_scale, rep=rep) 269 | 270 | if self.config.GENERAL == 'Deterministic': 271 | residual = self.fit_deterministic(idx, self.Ms, self.Mu, Ms_scale, Mu_scale) 272 | 273 | if self.config.GENERAL == 'Linear': 274 | self.fit_linear(idx, Ms_scale, Mu_scale, method='kinetic') 275 | print (np.sum(self.adata[:, idx].var['kinetic_gene'].values) / \ 276 | np.sum(self.adata[:, idx].var['velocity_genes'].values)) 277 | return self.adata 278 | 279 | adata.layers[self.vkey] = residual 280 | 281 | if 'examine_genes' not in adata.uns.keys() and basis != None: 282 | scv.tl.velocity_graph(adata, sqrt_transform=True) 283 | scv.tl.velocity_embedding(adata, basis=basis) 284 | scv.tl.latent_time(adata, min_likelihood=None) 285 | 286 | if self.config.FIT_OPTION == '1': 287 | adata.obs['latent_time'] = adata.obs['latent_time_gm'] 288 | del adata.obs['latent_time_gm'] 289 | 290 | return adata --------------------------------------------------------------------------------