├── .github
│   ├── CONTRIBUTING.md
│   └── workflows
│       ├── black.yaml
│       ├── python-conda-build.yaml
│       └── semantic-release.yaml
├── .gitignore
├── .readthedocs.yaml
├── README.md
├── docs
│   ├── Contacts.rst
│   ├── Installation.rst
│   ├── Introduction.rst
│   ├── License.rst
│   ├── Makefile
│   ├── README.rst
│   ├── Reference.rst
│   ├── Release_notes.md
│   ├── _static
│   │   ├── ambient_signal_hypothesis.png
│   │   ├── bgd.png
│   │   ├── overview_scAR.png
│   │   ├── scAR_favicon.png
│   │   ├── scAR_logo_black.png
│   │   ├── scAR_logo_transparent.png
│   │   ├── scAR_logo_white.png
│   │   └── scar-styles.css
│   ├── _templates
│   │   └── layout.html
│   ├── conf.py
│   ├── index.rst
│   ├── make.bat
│   ├── requirements.txt
│   ├── tutorials
│   │   ├── README.rst
│   │   ├── index.rst
│   │   ├── scAR_tutorial_ambient_profile.ipynb
│   │   ├── scAR_tutorial_batch_denoising_scRNAseq.ipynb
│   │   ├── scAR_tutorial_denoising_CITEseq.ipynb
│   │   ├── scAR_tutorial_denoising_scATACseq.ipynb
│   │   ├── scAR_tutorial_denoising_scRNAseq.ipynb
│   │   ├── scAR_tutorial_identity_barcode.ipynb
│   │   ├── scAR_tutorial_sgRNA_assignment.ipynb
│   │   └── synthetic_assignment.ipynb
│   └── usages
│       ├── index.rst
│       ├── processing.rst
│       ├── synthetic_dataset.rst
│       └── training.rst
├── pyproject.toml
├── scar-cpu.yml
├── scar-gpu.yml
└── scar
    ├── __init__.py
    ├── main
    │   ├── __init__.py
    │   ├── __main__.py
    │   ├── _activation_functions.py
    │   ├── _data_generater.py
    │   ├── _loss_functions.py
    │   ├── _scar.py
    │   ├── _setup.py
    │   ├── _utils.py
    │   └── _vae.py
    └── test
        ├── ambient_profile.pickle
        ├── citeseq_ambient_profile.pickle
        ├── citeseq_native_counts.pickle
        ├── citeseq_raw_counts.pickle
        ├── output_assignment.pickle
        ├── raw_counts.pickle
        ├── test_activation_functions.py
        └── test_scar.py
/.github/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | ## How to contribute to scAR
2 |
3 | #### **Where can I find the documentation?**
4 |
5 | * If you want to try **scAR**, [first check the documentation](https://scar-tutorials.readthedocs.io/en/latest/). Be aware of the versions.
6 |
7 | #### **Did you find a bug?**
8 |
9 | * **Ensure the bug was not already reported** by searching on GitHub under [Issues](https://github.com/Novartis/scAR/issues).
10 |
11 | * If you're unable to find an open issue addressing the problem, [open a new one](https://github.com/Novartis/scAR/issues/new). Be sure to include a **title and clear description**, as much relevant information as possible, and a **code sample** or an **executable test case** demonstrating the expected behavior that is not occurring.
12 |
13 | #### **Did you write a patch that fixes a bug?**
14 |
15 | * Open a new GitHub pull request with the patch. **Don't add a [semantic versioning label](https://github.com/Novartis/scAR/labels) to the PR, as this may trigger a new release.**
16 |
17 | * Ensure the PR description clearly describes the problem and solution.
18 |
19 | #### **Do you have questions about scAR documentation?**
20 |
21 | * Discuss any non-code question about scAR in the [readthedocs discussion](https://scar-tutorials.readthedocs.io/en/latest/Contacts.html#comments).
22 |
23 | #### **Do you want to contribute to scAR?**
24 |
25 | * Please contact me at caibin.sheng@novartis.com.
26 |
27 | Thanks, Caibin
28 |
--------------------------------------------------------------------------------
/.github/workflows/black.yaml:
--------------------------------------------------------------------------------
1 | name: Black lint
2 |
3 | on: [push, pull_request]
4 |
5 | jobs:
6 | lint:
7 | runs-on: ubuntu-latest
8 | steps:
9 | - uses: actions/checkout@v2
10 | - uses: psf/black@stable
11 | with:
12 | options: "--diff --verbose --line-length 127"
13 | src: "./scar"
14 | version: "22.3.0"
15 |
--------------------------------------------------------------------------------
/.github/workflows/python-conda-build.yaml:
--------------------------------------------------------------------------------
1 | name: Python Package using Conda
2 |
3 | on: [ pull_request ]
4 |
5 | jobs:
6 | build-linux:
7 | runs-on: ubuntu-latest
8 | strategy:
9 | max-parallel: 5
10 | steps:
11 | - uses: actions/checkout@v2
12 | - name: Set up Python 3.12.3
13 | uses: actions/setup-python@v2
14 | with:
15 | python-version: 3.12.3
16 | - name: Add conda to system path
17 | run: |
18 | # $CONDA is an environment variable pointing to the root of the miniconda directory
19 | echo $CONDA/bin >> $GITHUB_PATH
20 | - name: Install dependencies
21 | run: |
22 | conda env create -f scar-cpu.yml
23 | - name: Run binary
24 | run: |
25 | export PATH="$GITHUB_PATH:$PATH"
26 | source activate scar
27 | scar --help
28 | # - name: Run pylint
29 | # run: |
30 | # export PATH="$GITHUB_PATH:$PATH"
31 | # source activate scar
32 | # conda install -c anaconda pylint astroid
33 | # pylint scar --fail-under 8.5 --disable=R,C --generated-members=torch.*
34 | - name: Run unit tests
35 | run: |
36 | export PATH="$GITHUB_PATH:$PATH"
37 | source activate scar
38 | conda install pytest
39 | pytest scar
40 |
--------------------------------------------------------------------------------
/.github/workflows/semantic-release.yaml:
--------------------------------------------------------------------------------
1 | name: Semantic Release
2 |
3 | on:
4 | pull_request:
5 | types:
6 | - labeled
7 | branches:
8 | - 'main'
9 |
10 | jobs:
11 | release:
12 | if: ${{ github.event.label.name == 'semantic versioning' }}
13 | runs-on: ubuntu-latest
14 | concurrency: release
15 | permissions:
16 | id-token: write
17 | contents: write
18 | steps:
19 | - uses: actions/checkout@v3
20 | with:
21 | ref: main
22 | fetch-depth: 0
23 |
24 | - name: Python Semantic Release
25 | uses: relekang/python-semantic-release@master
26 | with:
27 | github_token: ${{ secrets.GITHUB_TOKEN }}
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **.ipynb_checkpoints
2 | **pycache__
3 | **.DS_Store
4 | **build
5 | **.egg-info
6 | **.vscode
7 | **dist
8 | **.h5
9 | **.h5ad
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | build:
4 | os: "ubuntu-20.04"
5 | tools:
6 | python: "3.11"
7 |
8 | sphinx:
9 | configuration: docs/conf.py
10 |
11 | python:
12 | install:
13 | - requirements: docs/requirements.txt
14 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
8 | [](https://anaconda.org/bioconda/scar)
9 | [](http://bioconda.github.io/recipes/scar/README.html)
10 | [](https://github.com/psf/black)
11 | [](https://scar-tutorials.readthedocs.io/en/latest/?badge=latest)
12 | [](https://github.com/semantic-release/semantic-release)
13 | [](https://github.com/Novartis/scAR/actions/workflows/python-conda-build.yaml)
14 | [](https://github.com/Novartis/scAR)
15 | [](https://anaconda.org/bioconda/scar/files)
16 |
17 | **scAR** (single-cell Ambient Remover) is a tool designed for denoising ambient signals in droplet-based single-cell omics data. It can be employed for a wide range of applications, such as **sgRNA assignment** in scCRISPR-seq, **identity barcode assignment** in cell indexing, **protein denoising** in CITE-seq, **mRNA denoising** in scRNAseq, and **ATAC signal denoising** in scATACseq.
18 |
19 | # Table of Contents
20 |
21 | - [Installation](#Installation)
22 | - [Dependencies](#Dependencies)
23 | - [Resources](#Resources)
24 | - [License](#License)
25 | - [Reference](#Reference)
26 |
27 | ## [Installation](https://scar-tutorials.readthedocs.io/en/latest/Installation.html)
28 | ## Dependencies
29 |
30 | [](https://pytorch.org/)
31 | [](https://www.python.org/)
32 | [](https://pytorch.org/vision/stable/index.html)
33 | [](https://github.com/tqdm/tqdm)
34 | [](https://scikit-learn.org/)
35 |
36 | ## Resources
37 |
38 | - Installation, Usages and Tutorials can be found in the [documentation](https://scar-tutorials.readthedocs.io/en/latest/).
39 | - If you'd like to contribute, please read [contributing guidelines](https://github.com/Novartis/scAR/blob/main/.github/CONTRIBUTING.md).
40 | - Please use the [issues](https://github.com/Novartis/scAR/issues) to submit bug reports.
41 |
42 | ## License
43 |
44 | This project is licensed under the terms of the [MIT License](docs/License.rst).
45 | Copyright 2022 Novartis International AG.
46 |
47 | ## Reference
48 |
49 | If you use scAR in your research, please consider citing our [manuscript](https://doi.org/10.1101/2022.01.14.476312),
50 |
51 | ```
52 | @article {Sheng2022.01.14.476312,
53 | author = {Sheng, Caibin and Lopes, Rui and Li, Gang and Schuierer, Sven and Waldt, Annick and Cuttat, Rachel and Dimitrieva, Slavica and Kauffmann, Audrey and Durand, Eric and Galli, Giorgio G and Roma, Guglielmo and de Weck, Antoine},
54 | title = {Probabilistic modeling of ambient noise in single-cell omics data},
55 | elocation-id = {2022.01.14.476312},
56 | year = {2022},
57 | doi = {10.1101/2022.01.14.476312},
58 | publisher = {Cold Spring Harbor Laboratory},
59 | URL = {https://www.biorxiv.org/content/early/2022/01/14/2022.01.14.476312},
60 | eprint = {https://www.biorxiv.org/content/early/2022/01/14/2022.01.14.476312.full.pdf},
61 | journal = {bioRxiv}
62 | }
63 | ```
64 |
--------------------------------------------------------------------------------
/docs/Contacts.rst:
--------------------------------------------------------------------------------
1 | Contacts
2 | ===============
3 |
4 | Authors
5 | ------------------------------------------------
6 | The following authors contributed to the design, implementation, integration, and maintenance of scAR.
7 |
8 | `@Caibin `_
9 | `@AlexMTYZ `_
10 | `@fgypas `_
11 | `@mr-nvs `_
12 | `@Tobias-Ternent `_
13 | `@adeweck `_
14 |
15 | Contributing
16 | ------------------------------------------------
17 | All kinds of contributions are welcome! Please file `issues `_.
18 |
19 | Comments
20 | ------------------------------------------------
21 | Post your comments or questions here:
22 |
23 | .. disqus::
24 | :disqus_identifier: scar-discussion
--------------------------------------------------------------------------------
/docs/Installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ================
3 |
4 | To use ``scar``, first install it either via conda or from source with Git + pip.
5 |
6 | Conda install
7 | -------------------------------
8 |
9 | 1, Install `conda `_
10 |
11 | 2, Create conda environment::
12 |
13 | conda create -n scar
14 |
15 | 3, Activate conda environment::
16 |
17 | conda activate scar
18 |
19 | 4, Install `PyTorch `_
20 |
21 | 5, Install scar::
22 |
23 | conda install bioconda::scar
24 |
25 | 6, Activate the scar conda environment::
26 |
27 | conda activate scar
28 |
29 | Git + pip
30 | -------------------------------------------
31 |
32 | 1, Clone scar repository::
33 |
34 | git clone https://github.com/Novartis/scar.git
35 |
36 | 2, Enter the cloned directory::
37 |
38 | cd scar
39 |
40 | 3, Create a conda environment
41 |
42 | .. tabs::
43 |
44 | .. tab:: GPU version
45 |
46 | .. code-block::
47 | :caption: Please use ``scar-gpu`` if you have an nvidia graphics card and the corresponding driver installed
48 |
49 | conda env create -f scar-gpu.yml
50 |
51 | .. tab:: CPU version
52 |
53 | .. code-block::
54 | :caption: Please use ``scar-cpu`` if you don't have a graphics card available
55 |
56 | conda env create -f scar-cpu.yml
57 |
58 | 4, Activate the scar conda environment::
59 |
60 | conda activate scar
61 |
62 |
63 |
64 |
65 |
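Verify the installation
-------------------------------

As a quick check (with the ``scar`` environment activated), both of the following should run without error::

    scar --help
    python -c "import scar; print(scar.__version__)"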
--------------------------------------------------------------------------------
/docs/Introduction.rst:
--------------------------------------------------------------------------------
1 | Introduction
2 | ===============
3 |
4 | What is ambient signal?
5 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
6 |
7 | .. image:: _static/ambient_signal_hypothesis.png
8 | :width: 500
9 | :align: center
10 |
11 |
12 | During the preparation of a single-cell solution, cell lysis releases RNA and protein molecules, which are then encapsulated in droplets. These exogenous molecules mix with native ones and are barcoded by the same 10x beads, leading to overestimated counts. The presence of this ambient signal can compromise downstream analyses and even introduce significant bias in certain applications, such as scCRISPR-seq and cell multiplexing.
13 |
14 | The design of scAR
15 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
16 |
17 | .. image:: _static/overview_scAR.png
18 | :width: 600
19 | :align: center
20 |
21 |
22 | scAR employs a latent variable model that represents both the biological and technical components of the observed count data. The model is built on the ambient signal hypothesis, under which the probability of each ambient transcript's occurrence can be estimated empirically from cell-free droplets. The model has two hidden variables, namely the contamination level per cell and the probability of native transcript occurrence. Together with the empirically estimated ambient profile, these three parameters allow scAR to reconstruct the noisy observations. We train neural networks, specifically a variational autoencoder, to learn the hidden variables by minimizing the difference between the reconstructions and the original noisy observations. Once the model converges, contamination levels and native expression are inferred, and downstream analysis can proceed from these values. For more information, please refer to our manuscript [Sheng2022]_.
23 |
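In simplified form, the reconstruction can be sketched as follows (an illustrative summary of the idea rather than the exact likelihood of the manuscript; see [Sheng2022]_ for the full model):

.. math::

   x_{ij} \sim \mathrm{Binomial}\left(n_i,\; (1-\varepsilon_i)\,\beta_{ij} + \varepsilon_i\,\alpha_j\right)

where :math:`x_{ij}` is the observed count of feature :math:`j` in cell :math:`i`, :math:`n_i` the total counts of cell :math:`i`, :math:`\varepsilon_i` the contamination level, :math:`\beta_{ij}` the native frequency, and :math:`\alpha_j` the empirically estimated ambient profile.
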
24 | What types of data can scAR process?
25 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
26 | We validated the effectiveness of scAR across a range of droplet-based single-cell omics technologies, including scRNAseq, scCRISPR-seq, CITEseq, and scATACseq. scAR was able to remove ambient mRNA, assign sgRNAs, assign cell tags, clean noisy protein counts (ADT), and clean peak counts, resulting in significant improvements in data quality in all tested datasets. Notably, scAR was able to recover a substantial proportion (33% to 50%) of cells in scCRISPR-seq and cell multiplexing experiments. Given that ambient contamination is a common issue in droplet-based single-cell omics, particularly in complex experiments or samples, scAR represents a viable solution for addressing this challenge.
27 |
28 | What are the alternative approaches?
29 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
30 | Several methods exist for modeling noise in single-cell omics data. In general, these methods can be classified into two categories: those that deal with background noise and those that model stochastic noise. Some examples of these methods are provided below.
31 |
32 | +-------------------------------------------+-------------------------------------------+
33 | | Background noise                          | Stochastic noise                          |
34 | +===========================================+===========================================+
35 | | CellBender [Fleming2023]_                 | scVI [Lopez2018]_                         |
36 | +-------------------------------------------+-------------------------------------------+
37 | | SoupX [Young2020]_                        | DCA [Eraslan2019]_                        |
38 | +-------------------------------------------+-------------------------------------------+
39 | | DecontX [Yang2020]_                       |                                           |
40 | +-------------------------------------------+-------------------------------------------+
41 | | totalVI (protein counts) [Gayoso2021]_    |                                           |
42 | +-------------------------------------------+-------------------------------------------+
43 | | DSB (protein counts) [Mulè2022]_          |                                           |
44 | +-------------------------------------------+-------------------------------------------+
--------------------------------------------------------------------------------
/docs/License.rst:
--------------------------------------------------------------------------------
1 | License
2 | ===============
3 |
4 | MIT License
5 |
6 | Copyright (c) 2022 Novartis
7 |
8 | Permission is hereby granted, free of charge, to any person obtaining a copy
9 | of this software and associated documentation files (the "Software"), to deal
10 | in the Software without restriction, including without limitation the rights
11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | copies of the Software, and to permit persons to whom the Software is
13 | furnished to do so, subject to the following conditions:
14 |
15 | The above copyright notice and this permission notice shall be included in all
16 | copies or substantial portions of the Software.
17 |
18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | SOFTWARE.
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/README.rst:
--------------------------------------------------------------------------------
1 | Documentation
2 | =============
3 |
4 | This folder contains all documentation for scar. See details in the `documentation `__.
--------------------------------------------------------------------------------
/docs/Reference.rst:
--------------------------------------------------------------------------------
1 | Reference
2 | ===============
3 |
4 | .. [Dixit2016] Dixit *et al.* (2016),
5 | `Perturb-Seq: Dissecting Molecular Circuits with Scalable Single-Cell RNA Profiling of Pooled Genetic Screens `__,
6 | Cell.
7 |
8 | .. [Eraslan2019] Eraslan *et al.* (2019),
9 | `Single-cell RNA-seq denoising using a deep count autoencoder `__,
10 | Nature Communications.
11 |
12 | .. [Fleming2023] Fleming *et al.* (2023),
13 | `Unsupervised removal of systematic background noise from droplet-based single-cell experiments using CellBender `__,
14 | Nature Methods.
15 |
16 | .. [Gayoso2021] Gayoso *et al.* (2021),
17 | `Joint probabilistic modeling of single-cell multi-omic data with totalVI `__,
18 | Nature Methods.
19 |
20 | .. [Lopez2018] Lopez *et al.* (2018),
21 | `Deep generative modeling for single-cell transcriptomics `__,
22 | Nature Methods.
23 |
24 | .. [Lun2019] Lun *et al.* (2019),
25 | `EmptyDrops: Distinguishing cells from empty droplets in droplet-based single-cell RNA sequencing data `__,
26 | Genome Biology.
27 |
28 | .. [Ly2020] Ly *et al.* (2020),
29 | `The Bayesian Methodology of Sir Harold Jeffreys as a Practical Alternative to the P Value Hypothesis Test `__,
30 | Computational Brain & Behavior.
31 |
32 | .. [Mulè2022] Mulè *et al.* (2022),
33 | `Normalizing and denoising protein expression data from droplet-based single cell profiling `__,
34 | Nature Communications.
35 |
36 | .. [Sheng2022] Sheng *et al.* (2022),
37 | `Probabilistic machine learning ensures accurate ambient denoising in droplet-based single-cell omics `__,
38 | bioRxiv.
39 |
40 | .. [Yang2020] Yang *et al.* (2020),
41 | `Decontamination of ambient RNA in single-cell RNA-seq with DecontX `__,
42 | Genome Biology.
43 |
44 | .. [Young2020] Young *et al.* (2020),
45 | `SoupX removes ambient RNA contamination from droplet-based single-cell RNA sequencing data `__,
46 | GigaScience.
--------------------------------------------------------------------------------
/docs/_static/ambient_signal_hypothesis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/docs/_static/ambient_signal_hypothesis.png
--------------------------------------------------------------------------------
/docs/_static/bgd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/docs/_static/bgd.png
--------------------------------------------------------------------------------
/docs/_static/overview_scAR.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/docs/_static/overview_scAR.png
--------------------------------------------------------------------------------
/docs/_static/scAR_favicon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/docs/_static/scAR_favicon.png
--------------------------------------------------------------------------------
/docs/_static/scAR_logo_black.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/docs/_static/scAR_logo_black.png
--------------------------------------------------------------------------------
/docs/_static/scAR_logo_transparent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/docs/_static/scAR_logo_transparent.png
--------------------------------------------------------------------------------
/docs/_static/scAR_logo_white.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/docs/_static/scAR_logo_white.png
--------------------------------------------------------------------------------
/docs/_static/scar-styles.css:
--------------------------------------------------------------------------------
1 | body {
2 | margin: 0;
3 | font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, "Noto Sans", "Liberation Sans", sans-serif, "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol", "Noto Color Emoji";
4 | }
--------------------------------------------------------------------------------
/docs/_templates/layout.html:
--------------------------------------------------------------------------------
1 | {# Import the theme's layout. #}
2 | {% extends "!layout.html" %}
3 |
4 | {# Custom CSS overrides #}
5 | {% set bootswatch_css_custom = ['_static/scar-styles.css'] %}
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 |
13 | import os
14 | import sys
15 |
16 | sys.path.insert(0, os.path.abspath("."))
17 | sys.path.insert(0, os.path.abspath(".."))
18 | from scar import __version__
19 |
20 | # -- Project information -----------------------------------------------------
21 | project = "scAR"
22 | copyright = "Novartis Institute for BioMedical Research, 2022"
23 | author = "Caibin Sheng"
24 | release = __version__
25 |
26 |
27 | # -- General configuration ---------------------------------------------------
28 |
29 | # Add any Sphinx extension module names here, as strings. They can be
30 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
31 | # ones.
32 | extensions = [
33 | "myst_parser",
34 | "sphinx.ext.autosectionlabel",
35 | "nbsphinx",
36 | "sphinx_gallery.gen_gallery",
37 | "sphinx_disqus.disqus",
38 | "sphinxarg.ext",
39 | "sphinx.ext.autodoc",
40 | "sphinx.ext.autosummary",
41 | "sphinx.ext.napoleon",
42 | "autodocsumm",
43 | "matplotlib.sphinxext.plot_directive",
44 | "sphinx_design",
45 | "sphinx_tabs.tabs",
46 | ]
47 |
48 | nbsphinx_execute = "never"
49 | nbsphinx_allow_errors = True
50 | autosummary_generate = True
51 | # Add type of source files
52 |
53 | nb_custom_formats = {
54 | ".md": ["jupytext.reads", {"fmt": "mystnb"}],
55 | }
56 |
57 | # Add any paths that contain templates here, relative to this directory.
58 |
59 | # List of patterns, relative to source directory, that match files and
60 | # directories to ignore when looking for source files.
61 | # This pattern also affects html_static_path and html_extra_path.
62 | exclude_patterns = ["Thumbs.db", ".DS_Store"]
63 |
64 | # Add comments
65 | disqus_shortname = "scar-discussion"
66 |
67 | # -- Options for HTML output -------------------------------------------------
68 |
69 | # The theme to use for HTML and HTML Help pages. See the documentation for
70 | # a list of builtin themes.
71 | #
72 | html_theme = "pydata_sphinx_theme"
73 | templates_path = ["_templates"]
74 | html_static_path = ["_static"]
75 | html_css_files = ["scar-styles.css"]
76 | html_show_sourcelink = False
77 |
78 | # Add any paths that contain custom static files (such as style sheets) here,
79 | # relative to this directory. They are copied after the builtin static files,
80 | # so a file named "default.css" will overwrite the builtin "default.css".
81 | html_logo = "_static/scAR_logo_transparent.png"
82 | html_theme_options = {
83 | "logo": {
84 | "image_light": "scAR_logo_white.png",
85 | "image_dark": "scAR_logo_black.png",
86 | },
87 | "pygment_light_style": "tango",
88 | "pygment_dark_style": "native",
89 | "icon_links": [
90 | {
91 | "name": "GitHub",
92 | "url": "https://github.com/Novartis/scar",
93 | "icon": "fab fa-github-square",
94 | "type": "fontawesome",
95 | }
96 | ],
97 | "use_edit_page_button": False,
98 | "favicons": [
99 | {
100 | "rel": "icon",
101 | "sizes": "32x32",
102 | "href": "_static/scAR_favicon.png",
103 | }
104 | ],
105 | }
106 | html_context = {
107 | "github_user": "Novartis",
108 | "github_repo": "scar",
109 | "github_version": "develop",
110 | "doc_path": "docs",
111 | }
112 |
113 | # html_sidebars = {
114 | # "**": ["search-field.html", "sidebar-nav-bs.html", "sidebar-ethical-ads.html"]
115 | # }
116 |
117 | autodoc_mock_imports = ["django"]
118 | autodoc_default_options = {
119 | "autosummary": True,
120 | }
121 | numpydoc_show_class_members = False
122 |
123 | # Options for plot examples
124 | plot_include_source = True
125 | plot_formats = [("png", 120)]
126 | plot_html_show_formats = False
127 | plot_html_show_source_link = False
128 |
129 | sphinx_gallery_conf = {
130 | "examples_dirs": "tutorials", # path to your example scripts
131 | "gallery_dirs": "_build/tutorial_gallery", # path to where to save gallery generated output
132 | "filename_pattern": "/scAR_tutorial_",
133 | # "ignore_pattern": r"__init__\.py",
134 | }
135 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. scAR documentation master file, created by
2 | sphinx-quickstart on Fri Apr 22 15:48:44 2022.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | scAR documentation
7 | ================================
8 |
9 | **Version**: |release|
10 |
11 | **Useful links**:
12 | `Binary Installers `__ |
13 | `Source Repository `__ |
14 | `Issues `__ |
15 | `Contacts `__
16 |
17 |
18 | :mod:`scAR` (single-cell Ambient Remover) is an explainable machine learning model designed to remove ambient signals from droplet-based single cell omics data. It is suitable for various applications such as **sgRNA assignment** in scCRISPR-seq, **identity barcode assignment** in cell multiplexing, **protein denoising** in CITE-seq, **mRNA denoising** in scRNAseq, **peak count denoising** in scATACseq and more.
19 |
20 | It is developed by Oncology Data Science, Novartis Institute for BioMedical Research.
21 |
22 |
23 | .. grid:: 1 2 2 2
24 | :gutter: 2
25 |
26 | .. grid-item-card:: Getting started
27 | :link: Introduction
28 | :link-type: doc
29 | :img-background: _static/bgd.png
30 | :class-card: sd-text-black
31 |
32 | New to *scAR*? Check out the getting started guide. It contains an introduction to *scAR's* main concepts.
33 |
34 | +++
35 | .. button-ref:: Introduction
36 | :ref-type: doc
37 | :color: primary
38 | :shadow:
39 | :align: center
40 |
41 | What is scAR?
42 |
43 | .. grid-item-card:: Installation
44 | :link: Installation
45 | :link-type: doc
46 | :img-background: _static/bgd.png
47 |
48 | Want to install *scAR*? Check out the installation guide. It contains steps to install *scAR*.
49 |
50 | +++
51 | .. button-ref:: Installation
52 | :ref-type: doc
53 | :color: primary
54 | :shadow:
55 | :align: center
56 |
57 | How to install scAR?
58 |
59 | .. grid-item-card:: API reference
60 | :link: usages/index
61 | :link-type: doc
62 | :img-background: _static/bgd.png
63 |
64 | The API reference contains detailed descriptions of scAR API.
65 |
66 | +++
67 | .. button-ref:: usages/index
68 | :ref-type: doc
69 | :color: primary
70 | :shadow:
71 | :align: center
72 |
73 | To the API
74 |
75 | .. grid-item-card:: Tutorials
76 | :link: tutorials/index
77 | :link-type: doc
78 | :img-background: _static/bgd.png
79 |
80 | The tutorials walk you through the applications of scAR.
81 |
82 | +++
83 | .. button-ref:: tutorials/index
84 | :ref-type: doc
85 | :color: primary
86 | :shadow:
87 | :align: center
88 |
89 | To Tutorials
90 |
91 | |
92 |
93 | .. toctree::
94 | :hidden:
95 |
96 | Introduction
97 | Installation
98 | usages/index
99 | tutorials/index
100 | Release_notes
101 | Reference
102 | License
103 | Contacts
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | jinja2<3.1.0
2 | myst_parser
3 | nbsphinx
4 | sphinx-gallery<0.17.0
5 | sphinx-argparse
6 | sphinx-disqus
7 | autodocsumm
8 | ipykernel
9 | protobuf
10 | scanpy
11 | sphinx-design
12 | pydata-sphinx-theme
13 | sphinx_tabs
14 | git+https://github.com/Novartis/scAR.git
15 |
--------------------------------------------------------------------------------
/docs/tutorials/README.rst:
--------------------------------------------------------------------------------
1 | This folder contains notebooks used in scAR tutorials
--------------------------------------------------------------------------------
/docs/tutorials/index.rst:
--------------------------------------------------------------------------------
1 | Tutorials
2 | ==============
3 |
4 | There are two ways to run ``scar``. For Python users, we recommend the Python API; for R users, we recommend the command line tool.
5 |
6 | Run scar with Python API
7 | ------------------------
8 | .. nbgallery::
9 |
10 | scAR_tutorial_ambient_profile
11 | scAR_tutorial_sgRNA_assignment
12 | scAR_tutorial_identity_barcode
13 | scAR_tutorial_denoising_CITEseq
14 | scAR_tutorial_denoising_scRNAseq
15 | scAR_tutorial_denoising_scATACseq
16 | scAR_tutorial_batch_denoising_scRNAseq
17 |
18 | Run scar with the command line tool
19 | -----------------------------------
20 |
21 | The command line tool supports two formats of input.
22 |
23 | Use ``.h5`` files as the input
24 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
25 |
26 | We can use the output of cellranger count *filtered_feature_bc_matrix.h5* as the input for ``scar``::
27 |
28 | scar filtered_feature_bc_matrix.h5 -ft feature_type -o output
29 |
30 | ``filtered_feature_bc_matrix.h5``, a filtered .h5 file produced by cellranger count.
31 |
32 | ``feature_type``, a string, either 'mRNA' or 'sgRNA' or 'ADT' or 'tag' or 'CMO' or 'ATAC'.
33 |
34 | .. note::
35 | The ambient profile is calculated by averaging the cell pool under this mode. If you want to use a more accurate ambient profile, please consider calculating it and using ``.pickle`` files as the input, as detailed below.
36 |
37 | The output folder contains an h5ad file::
38 |
39 | output
40 | └── filtered_feature_bc_matrix_denoised_feature_type.h5ad
41 |
42 | The h5ad file can be read by `scanpy.read `__ as an `anndata `__ object:
43 |
44 | - anndata.X, denoised counts.
45 | - anndata.obs['``noise_ratio``'], estimated noise ratio per cell.
46 | - anndata.layers['``native_frequencies``'], estimated native frequencies.
47 | - anndata.layers['``BayesFactor``'], Bayes factor of ambient contamination.
48 | - anndata.obs['``sgRNAs``' or '``tags``'], optional, feature assignment, e.g., sgRNA, tag, or CMO.
49 |
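The denoised object can then be inspected directly, for example (a minimal sketch; the file name assumes ``feature_type`` was ``mRNA``)::

    import scanpy as sc

    adata = sc.read_h5ad("output/filtered_feature_bc_matrix_denoised_mRNA.h5ad")
    adata.X                   # denoised counts
    adata.obs["noise_ratio"]  # estimated noise ratio per cell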
50 |
51 | Use ``.pickle`` files as the input
52 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
53 | We can also run ``scar`` by::
54 |
55 | scar raw_count_matrix.pickle -ft feature_type -o output
56 |
57 | ``raw_count_matrix.pickle``, a file of raw count matrix (MxN) with cells in rows and features in columns.
58 |
59 | +--------+--------+--------+-----+--------+
60 | | cells | gene_0 | gene_1 | ... | gene_y |
61 | +========+========+========+=====+========+
62 | | cell_0 | 12 | 3 | ... | 82 |
63 | +--------+--------+--------+-----+--------+
64 | | cell_1 | 13 | 0 | ... | 78 |
65 | +--------+--------+--------+-----+--------+
66 | | cell_2 | 35 | 30 | ... | 170 |
67 | +--------+--------+--------+-----+--------+
68 | | ... | ... | ... | ... | ... |
69 | +--------+--------+--------+-----+--------+
70 | | cell_x | 16 | 5 | ... | 112 |
71 | +--------+--------+--------+-----+--------+
72 |
73 |
74 | ``feature_type``, a string, either 'mRNA' or 'sgRNA' or 'ADT' or 'tag' or 'CMO' or 'ATAC'.
75 |
76 | .. note::
77 | An extra argument ``ambient_profile`` is recommended to achieve deeper noise reduction.
78 |
79 |
80 | ``ambient_profile`` represents the probability of occurrence of each ambient transcript and can be empirically estimated by averaging cell-free droplets.
81 |
82 | +--------+-----------------+
83 | | genes | ambient profile |
84 | +========+=================+
85 | | gene_0 | .0003 |
86 | +--------+-----------------+
87 | | gene_1 | .00004 |
88 | +--------+-----------------+
89 | | gene_2 | .00003 |
90 | +--------+-----------------+
91 | | ... | ... |
92 | +--------+-----------------+
93 | | gene_y | .0012 |
94 | +--------+-----------------+
95 |
96 | .. warning::
97 | ``ambient_profile`` should sum to one. The gene order should be consistent with ``raw_count_matrix``.
98 |
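For instance, an ambient profile can be estimated from cell-free droplets along the following lines (an illustrative sketch only; the count cutoff is a placeholder, and the ambient-profile tutorial describes a more careful procedure)::

    import numpy as np
    import scanpy as sc

    raw = sc.read_10x_h5("raw_feature_bc_matrix.h5", gex_only=False)
    counts_per_droplet = np.asarray(raw.X.sum(axis=1)).ravel()
    cell_free = raw[counts_per_droplet < 100]                    # placeholder cutoff
    ambient_profile = np.asarray(cell_free.X.sum(axis=0)).ravel()
    ambient_profile = ambient_profile / ambient_profile.sum()   # sums to one
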
99 | For other optional arguments and parameters, run::
100 |
101 | scar --help
102 |
103 | The output folder contains up to five files::
104 |
105 |     output
106 |     ├── denoised_counts.pickle
107 |     ├── noise_ratio.pickle
108 |     ├── BayesFactor.pickle
109 |     ├── native_frequency.pickle
110 |     └── assignment.pickle
111 |
112 | In the folder structure above:
113 |
114 | - ``denoised_counts.pickle``, denoised count matrix.
115 | - ``noise_ratio.pickle``, estimated noise ratio per cell.
116 | - ``BayesFactor.pickle``, Bayes factor of ambient contamination, written for sgRNA/tag/CMO feature types.
117 | - ``native_frequency.pickle``, estimated native frequencies, written when ``--get_native_frequencies`` is set to 1.
118 | - ``assignment.pickle``, feature assignment (e.g., sgRNA, tag, CMO), written for sgRNA/tag/CMO feature types.
119 |
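These outputs are pickled pandas objects and can be loaded back directly, for example::

    import pandas as pd

    denoised_counts = pd.read_pickle("output/denoised_counts.pickle")  # cells x features
    noise_ratio = pd.read_pickle("output/noise_ratio.pickle")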
--------------------------------------------------------------------------------
/docs/usages/index.rst:
--------------------------------------------------------------------------------
1 | API
2 | ===============
3 |
4 | Python API
5 | ----------------------
6 |
7 | Processing
8 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
9 | Calculate ambient profile
10 |
11 | .. currentmodule:: scar.main._setup
12 |
13 | .. autosummary::
14 | :nosignatures:
15 |
16 | setup_anndata
17 |
18 |
19 | Training
20 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
21 | The core module of scar
22 |
23 | .. currentmodule:: scar.main._scar
24 |
25 | .. autosummary::
26 | :nosignatures:
27 |
28 | model
29 |
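A typical denoising workflow looks roughly as follows (a minimal sketch based on the command-line wrapper; here ``count_matrix`` is a cells-by-features count matrix, e.g. a pandas DataFrame, and ``ambient_profile`` is optional)::

    from scar import model

    scar_model = model(raw_count=count_matrix, ambient_profile=ambient_profile, feature_type="mRNA")
    scar_model.train(batch_size=64, epochs=800)
    scar_model.inference()
    denoised_counts = scar_model.native_counts
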
30 | Synthetic_dataset
31 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
32 | Generate synthetic datasets (scRNAseq, CITE-seq, scCRISPRseq) with ambient contamination
33 |
34 | .. currentmodule:: scar.main._data_generater
35 |
36 | .. autosummary::
37 | :nosignatures:
38 |
39 | scrnaseq
40 | citeseq
41 | cropseq
42 |
43 | Plotting
44 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
45 | Plotting functions (under development).
46 |
47 | Reporting
48 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
49 | Generate denoising reports (under development).
50 |
51 | Command Line Interface
52 | -----------------------------------
53 |
54 | .. argparse::
55 | :module: scar.main.__main__
56 | :func: scar_parser
57 | :prog: scar
58 |
--------------------------------------------------------------------------------
/docs/usages/processing.rst:
--------------------------------------------------------------------------------
1 | Setup anndata
2 | ==============
3 | Calculate ambient profile for relevant feature types
4 |
5 | .. automodule:: scar.main._setup
6 |
7 | .. autofunction:: setup_anndata
--------------------------------------------------------------------------------
/docs/usages/synthetic_dataset.rst:
--------------------------------------------------------------------------------
1 | Generate synthetic single-cell datasets
2 | ==========================================
3 | .. automodule:: scar.main._data_generater
4 | :members:
5 | :member-order: bysource
6 |
--------------------------------------------------------------------------------
/docs/usages/training.rst:
--------------------------------------------------------------------------------
1 | Denoising model
2 | ===============================
3 |
4 | .. automodule:: scar.main._scar
5 |
6 | .. autoclass:: model
7 | :members:
8 | :member-order: bysource
9 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=68.1.2"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "scar"
7 | version = "0.7.0"
8 | requires-python = ">= 3.10"
9 | dependencies = [
10 | "torch >= 1.10.0",
11 | "torchvision >= 0.9.0",
12 | "tqdm >= 4.62.3",
13 | "seaborn >= 0.11.2",
14 | "scikit-learn >= 1.0.1",
15 | "pyro-ppl >= 1.8.0",
16 | "scanpy >= 1.9.2"
17 | ]
18 | authors = [
19 | {name = "Caibin Sheng", email = "caibin.sheng.res@gmail.com"}
20 | ]
21 | description = "scAR (single-cell Ambient Remover) is a package for denoising the ambient signals in droplet-based single cell omics"
22 | readme = "README.md"
23 | license = {text = "MIT License"}
24 | keywords = ["single cell omics", "variational autoencoder", "machine learning", "generative model", "cite-seq", "scCRISPRseq", "scRNAseq"]
25 |
26 | [project.urls]
27 | Homepage = "https://github.com/Novartis/scAR"
28 | Documentation = "https://scar-tutorials.readthedocs.io/en/main/"
29 | Repository = "https://github.com/Novartis/scar.git"
30 | Issues = "https://github.com/Novartis/scAR/issues"
31 | Changelog = "https://github.com/Novartis/scAR/blob/main/docs/Release_notes.md"
32 |
33 | [tool.semantic_release]
34 | version_toml = ["pyproject.toml:project.version"]
35 | major_on_zero = false
36 | branch = "develop"
37 | upload_to_release = false
38 | hvcs = "github"
39 | upload_to_repository = false
40 | upload_to_pypi = false
41 | patch_without_tag = false
42 |
43 | [tool.semantic_release.changelog]
44 | changelog_file="docs/Release_notes.md"
45 |
46 | [project.gui-scripts]
47 | scar = "scar.main.__main__:main"
48 |
--------------------------------------------------------------------------------
/scar-cpu.yml:
--------------------------------------------------------------------------------
1 | name: scar
2 | channels:
3 | - nvidia
4 | - pytorch
5 | - conda-forge
6 | dependencies:
7 | - conda-forge::python>=3.10
8 | - nvidia::cudatoolkit>=11.1
9 | - pytorch::pytorch>=1.10.0
10 | - pytorch::torchvision>=0.9.0
11 | - pytorch-mutex=*=cpu
12 | - conda-forge::tqdm>=4.62.3
13 | - conda-forge::seaborn>=0.11.2
14 | - conda-forge::scikit-learn>=1.0.1
15 | - conda-forge::scanpy
16 | - conda-forge::pyro-ppl>=1.8.0
17 | - conda-forge::pip
18 | - pip:
19 | - .
20 |
--------------------------------------------------------------------------------
/scar-gpu.yml:
--------------------------------------------------------------------------------
1 | name: scar
2 | channels:
3 | - nvidia
4 | - pytorch
5 | - conda-forge
6 | dependencies:
7 | - conda-forge::python>=3.10
8 | - nvidia::cudatoolkit>=11.1
9 | - pytorch::pytorch>=1.10.0
10 | - pytorch::torchvision>=0.9.0
11 | - pytorch-mutex=*=cuda
12 | - conda-forge::tqdm>=4.62.3
13 | - conda-forge::seaborn>=0.11.2
14 | - conda-forge::scikit-learn>=1.0.1
15 | - conda-forge::scanpy
16 | - conda-forge::pyro-ppl>=1.8.0
17 | - conda-forge::pip
18 | - pip:
19 | - .
20 |
--------------------------------------------------------------------------------
/scar/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from importlib.metadata import version
3 | __version__ = version("scar")
4 |
5 | from .main._scar import model
6 | from .main._setup import setup_anndata
7 | from .main import _data_generater as data_generator
8 |
--------------------------------------------------------------------------------
/scar/main/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/scar/main/__main__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """command line of scar"""
3 |
4 | import argparse
5 |
6 | import os
7 | import pandas as pd, scanpy as sc
8 | from ._scar import model
9 | from ..__init__ import __version__
10 | from ._utils import get_logger
11 |
12 | def main():
13 | """main function for command line interface"""
14 | args = Config()
15 | count_matrix_path = args.count_matrix[0]
16 | ambient_profile_path = args.ambient_profile
17 | feature_type = args.feature_type
18 | output_dir = (
19 | os.getcwd() if not args.output else args.output
20 | ) # if None, output to current directory
21 | count_model = args.count_model
22 | nn_layer1 = args.hidden_layer1
23 | nn_layer2 = args.hidden_layer2
24 | latent_dim = args.latent_dim
25 | epochs = args.epochs
26 | device = args.device
27 | sparsity = args.sparsity
28 | batchkey = args.batchkey
29 | cachecapacity = args.cachecapacity
30 | gnf = bool(args.get_native_frequencies)
31 | save_model = args.save_model
32 | batch_size = args.batchsize
33 | batch_size_infer = args.batchsize_infer
34 | adjust = args.adjust
35 | cutoff = args.cutoff
36 | moi = args.moi
37 | round_to_int = args.round2int
38 | clip_to_obs = args.clip_to_obs
39 | verbose = args.verbose
40 |
41 | main_logger = get_logger("scar", verbose=verbose)
42 |
43 | _, file_extension = os.path.splitext(count_matrix_path)
44 |
45 | if file_extension == ".pickle":
46 | count_matrix = pd.read_pickle(count_matrix_path)
47 |
48 |     # Denoising transcriptomic data
49 | elif file_extension == ".h5":
50 | adata = sc.read_10x_h5(count_matrix_path, gex_only=False)
51 |
52 | main_logger.info(
53 | "unprocessed data contains: {0} cells and {1} genes".format(
54 | adata.shape[0], adata.shape[1]
55 | )
56 | )
57 | adata = adata[:, adata.X.sum(axis=0) > 0] # filter out features of zero counts
58 | main_logger.info(
59 | "filter out features of zero counts, remaining data contains: {0} cells and {1} genes".format(
60 | adata.shape[0], adata.shape[1]
61 | )
62 | )
63 |
64 | if feature_type.lower() == "all":
65 | features = adata.var["feature_types"].unique()
66 | count_matrix = adata.copy()
67 |
68 | # Denoising mRNAs
69 | elif feature_type.lower() in ["mrna", "mrnas"]:
70 | features = "Gene Expression"
71 | adata_fb = adata[:, adata.var["feature_types"] == features]
72 | count_matrix = adata_fb.copy()
73 |
74 | # Denoising sgRNAs
75 | elif feature_type.lower() in ["sgrna", "sgrnas"]:
76 | features = "CRISPR Guide Capture"
77 | adata_fb = adata[:, adata.var["feature_types"] == features]
78 | count_matrix = adata_fb.copy()
79 |
80 | # Denoising CMO tags
81 | elif feature_type.lower() in ["tag", "tags"]:
82 | features = "Multiplexing Capture"
83 | adata_fb = adata[:, adata.var["feature_types"] == features]
84 | count_matrix = adata_fb.copy()
85 |
86 | # Denoising ADTs
87 | elif feature_type.lower() in ["adt", "adts"]:
88 | features = "Antibody Capture"
89 | adata_fb = adata[:, adata.var["feature_types"] == features]
90 | count_matrix = adata_fb.copy()
91 |
92 | # Denoising ATAC peaks
93 | elif feature_type.lower() in ["atac"]:
94 | features = "Peaks"
95 | adata_fb = adata[:, adata.var["feature_types"] == features]
96 | count_matrix = adata_fb.copy()
97 |
98 | main_logger.info(f"modalities to denoise: {features}")
99 |
100 | else:
101 | raise Exception(file_extension + " files are not supported.")
102 |
103 | if ambient_profile_path:
104 | _, ambient_profile_file_extension = os.path.splitext(ambient_profile_path)
105 | if ambient_profile_file_extension == ".pickle":
106 | ambient_profile = pd.read_pickle(ambient_profile_path)
107 |
108 | # Currently, use the default approach to calculate the ambient profile in the case of h5
109 | elif ambient_profile_file_extension == ".h5":
110 | ambient_profile = None
111 |
112 | else:
113 | raise Exception(
114 | ambient_profile_file_extension + " files are not supported."
115 | )
116 | else:
117 | ambient_profile = None
118 |
119 | main_logger.info(f"feature_type: {feature_type}")
120 | main_logger.info(f"count_model: {count_model}")
121 | main_logger.info(f"output_dir: {output_dir}")
122 | main_logger.info(f"count_matrix_path: {count_matrix_path}")
123 | main_logger.info(f"ambient_profile_path: {ambient_profile_path}")
124 | main_logger.info(f"expected data sparsity: {sparsity:.2f}")
125 |
126 | if not os.path.isdir(output_dir):
127 | os.makedirs(output_dir)
128 |
129 | # Run model
130 | scar_model = model(
131 | raw_count=count_matrix,
132 | ambient_profile=ambient_profile,
133 | nn_layer1=nn_layer1,
134 | nn_layer2=nn_layer2,
135 | latent_dim=latent_dim,
136 | feature_type=feature_type,
137 | count_model=count_model,
138 | batch_key=batchkey,
139 | cache_capacity=cachecapacity,
140 | sparsity=sparsity,
141 | device=device,
142 | )
143 |
144 | scar_model.train(
145 | batch_size=batch_size,
146 | epochs=epochs,
147 | save_model=save_model,
148 | )
149 |
150 | scar_model.inference(
151 | adjust=adjust,
152 | get_native_frequencies=gnf,
153 | round_to_int=round_to_int,
154 | batch_size=batch_size_infer,
155 | clip_to_obs=clip_to_obs,
156 | )
157 |
158 | if feature_type.lower() in ["sgrna", "sgrnas", "tag", "tags", "cmo", "cmos"]:
159 | scar_model.assignment(cutoff=cutoff, moi=moi)
160 |
161 | main_logger.info("Saving results...")
162 |
163 | # save results
164 | if file_extension == ".pickle":
165 | output_path01, output_path02, output_path03, output_path04 = (
166 | os.path.join(output_dir, "denoised_counts.pickle"),
167 | os.path.join(output_dir, "BayesFactor.pickle"),
168 | os.path.join(output_dir, "native_frequency.pickle"),
169 | os.path.join(output_dir, "noise_ratio.pickle"),
170 | )
171 |
172 | pd.DataFrame(
173 | scar_model.native_counts.toarray(),
174 | index=count_matrix.index,
175 | columns=count_matrix.columns,
176 | ).to_pickle(output_path01)
177 | main_logger.info(f"denoised counts saved in: {output_path01}")
178 |
179 | pd.DataFrame(
180 | scar_model.noise_ratio.toarray(),
181 | index=count_matrix.index,
182 | columns=["noise_ratio"]
183 | ).to_pickle(output_path04)
184 | main_logger.info(f"expected noise ratio saved in: {output_path04}")
185 |
186 | if scar_model.native_frequencies is not None:
187 | pd.DataFrame(
188 | scar_model.native_frequencies.toarray(),
189 | index=count_matrix.index,
190 | columns=count_matrix.columns,
191 | ).to_pickle(output_path03)
192 | main_logger.info(f"expected native frequencies saved in: {output_path03}")
193 |
194 | if feature_type.lower() in ["sgrna", "sgrnas", "tag", "tags", "cmo", "cmos"]:
195 | pd.DataFrame(
196 | scar_model.bayesfactor.toarray(),
197 | index=count_matrix.index,
198 | columns=count_matrix.columns,
199 | ).to_pickle(output_path02)
200 | main_logger.info(f"BayesFactor matrix saved in: {output_path02}")
201 |
202 | output_path05 = os.path.join(output_dir, "assignment.pickle")
203 | scar_model.feature_assignment.to_pickle(output_path05)
204 | main_logger.info(f"assignment saved in: {output_path05}")
205 |
206 | elif file_extension == ".h5":
207 | output_path_h5ad = os.path.join(
208 | output_dir, f"filtered_feature_bc_matrix_denoised_{feature_type}.h5ad"
209 | )
210 |
211 | denoised_adata = adata.copy()
212 | denoised_adata.X = scar_model.native_counts
213 | denoised_adata.obs["noise_ratio"] = pd.DataFrame(
214 | scar_model.noise_ratio.toarray(),
215 | index=count_matrix.obs_names,
216 | columns=["noise_ratio"],
217 | )
218 | if scar_model.native_frequencies is not None:
219 | denoised_adata.layers["native_frequencies"] = scar_model.native_frequencies.toarray()
220 |
221 | if feature_type.lower() in ["sgrna", "sgrnas", "tag", "tags", "cmo", "cmos"]:
222 | denoised_adata.obs = denoised_adata.obs.join(scar_model.feature_assignment)
223 | denoised_adata.layers["BayesFactor"] = scar_model.bayesfactor.toarray()
224 |
225 | denoised_adata.write(output_path_h5ad)
226 | main_logger.info(f"the denoised h5ad file saved in: {output_path_h5ad}")
227 |
228 |
229 | class Config:
230 | """
231 | The configuration options. Options can be specified as command-line arguments.
232 | """
233 |
234 | def __init__(self) -> None:
235 | """Initialize configuration values."""
236 | self.parser = scar_parser()
237 | self.namespace = vars(self.parser.parse_args())
238 |
239 | def __getattr__(self, option):
240 | return self.namespace[option]
241 |
242 |
243 | def scar_parser():
244 | """Argument parser"""
245 |
246 | parser = argparse.ArgumentParser(
247 | description="scAR (single-cell Ambient Remover) is a deep learning model for removal of the ambient signals in droplet-based single cell omics",
248 | formatter_class=argparse.RawTextHelpFormatter,
249 | )
250 | parser.add_argument(
251 | "--version",
252 | action="version",
253 | version=f"%(prog)s version: {__version__}",
254 | )
255 | parser.add_argument(
256 | "count_matrix",
257 | type=str,
258 | nargs="+",
259 | help="The file of raw count matrix, 2D array (cells x genes) or the path of a filtered_feature_bc_matrix.h5",
260 | )
261 | parser.add_argument(
262 | "-ap",
263 | "--ambient_profile",
264 | type=str,
265 | default=None,
266 | help="The file of empty profile obtained from empty droplets, 1D array",
267 | )
268 | parser.add_argument(
269 | "-ft",
270 | "--feature_type",
271 | type=str,
272 | default="mRNA",
273 | help="The feature types, e.g. mRNA, sgRNA, ADT, tag, CMO and ATAC",
274 | )
275 | parser.add_argument(
276 | "-o", "--output", type=str, default=None, help="Output directory"
277 | )
278 | parser.add_argument(
279 | "-m", "--count_model", type=str, default="binomial", help="Count model"
280 | )
281 | parser.add_argument(
282 | "-sp",
283 | "--sparsity",
284 | type=float,
285 | default=0.9,
286 | help="The sparsity of expected native signals",
287 | )
288 | parser.add_argument(
289 | "-bk",
290 | "--batchkey",
291 | type=str,
292 | default=None,
293 | help="The batch key for batch correction",
294 | )
295 | parser.add_argument(
296 | "-cache",
297 | "--cachecapacity",
298 | type=int,
299 | default=20000,
300 | help="The capacity of cache for batch correction",
301 | )
302 | parser.add_argument(
303 | "-gnf",
304 | "--get_native_frequencies",
305 | type=int,
306 | default=0,
307 | help="Whether to get native frequencies, 0 or 1, by default 0, not to get native frequencies",
308 | )
309 | parser.add_argument(
310 | "-hl1",
311 | "--hidden_layer1",
312 | type=int,
313 | default=150,
314 | help="Number of neurons in the first layer",
315 | )
316 | parser.add_argument(
317 | "-hl2",
318 | "--hidden_layer2",
319 | type=int,
320 | default=100,
321 | help="Number of neurons in the second layer",
322 | )
323 | parser.add_argument(
324 | "-ls",
325 | "--latent_dim",
326 | type=int,
327 | default=15,
328 | help="Dimension of latent space",
329 | )
330 | parser.add_argument(
331 | "-epo", "--epochs", type=int, default=800, help="Training epochs"
332 | )
333 | parser.add_argument(
334 | "-d",
335 | "--device",
336 | type=str,
337 | default="auto",
338 | help="Device used for training, either 'auto', 'cpu', or 'cuda'",
339 | )
340 | parser.add_argument(
341 | "-s",
342 | "--save_model",
343 | type=int,
344 | default=False,
345 | help="Save the trained model",
346 | )
347 | parser.add_argument(
348 | "-batchsize",
349 | "--batchsize",
350 | type=int,
351 | default=64,
352 | help="Batch size for training, set a small value upon out of memory error",
353 | )
354 | parser.add_argument(
355 | "-batchsize_infer",
356 | "--batchsize_infer",
357 | type=int,
358 | default=4096,
359 | help="Batch size for inference, set a small value upon out of memory error",
360 | )
361 | parser.add_argument(
362 | "-adjust",
363 | "--adjust",
364 | type=str,
365 | default="micro",
366 | help="""Only used for calculating Bayesfactors to improve performance,
367 |
368 | | 'micro' -- adjust the estimated native counts per cell. Default.
369 | | 'global' -- adjust the estimated native counts globally.
370 | | False -- no adjustment, use the model-returned native counts.""",
371 | )
372 | parser.add_argument(
373 | "-cutoff",
374 | "--cutoff",
375 | type=float,
376 | default=3,
377 | help="Cutoff for Bayesfactors. See [Ly2020]_",
378 | )
379 | parser.add_argument(
380 | "-round2int",
381 | "--round2int",
382 | type=str,
383 | default="stochastic_rounding",
384 | help="Round the counts",
385 | )
386 |
387 | parser.add_argument(
388 | "-clip_to_obs",
389 | "--clip_to_obs",
390 |         type=int,
391 |         default=0,
392 |         help="Whether to clip the predicted native counts by observed counts, 0 or 1, by default 0. \
393 |             Use with caution, as clipping may lead to overestimation of overall noise.",
394 | )
395 | parser.add_argument(
396 | "-moi",
397 | "--moi",
398 | type=float,
399 | default=None,
400 | help="Multiplicity of Infection. If assigned, it will allow optimized thresholding, \
401 | which tests a series of cutoffs to find the best one based on distributions of infections under given moi. \
402 | See [Dixit2016]_ for details. Under development.",
403 | )
404 | parser.add_argument(
405 | "-verbose",
406 | "--verbose",
407 |         type=int,
408 |         default=1,
409 |         help="Whether to print the logging messages, 0 or 1, by default 1",
410 | )
411 | return parser
412 |
413 |
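# Example invocations (illustrative paths; run `scar --help` for all options):
#   scar filtered_feature_bc_matrix.h5 -ft mRNA -o output
#   scar raw_count_matrix.pickle -ft sgRNA -ap ambient_profile.pickle -o output
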
414 | if __name__ == "__main__":
415 | main()
416 |
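# Illustrative invocation (a sketch only; the flags are defined above, while
# the entry point and positional input argument are defined elsewhere in the
# package, so the exact command may differ):
#
#   python -m scar.main <count_matrix.pickle> -sp 0.9 -epo 800 -d auto \
#       -round2int stochastic_rounding -gnf 0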
--------------------------------------------------------------------------------
/scar/main/_activation_functions.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Customized activation functions
4 | """
5 |
6 | import torch
7 |
8 |
9 | class mytanh(torch.nn.Module):  # rescaled tanh: (1 + tanh(x)) / 2, output in (0, 1)
10 | def __init__(self):
11 | super().__init__() # init the base class
12 |
13 | def forward(self, input_x):
14 | var_tanh = torch.tanh(input_x)
15 | output = (1 + var_tanh) / 2
16 | return output
17 |
18 |
19 | class hnormalization(torch.nn.Module):  # row-wise normalization so each row sums to ~1
20 | def __init__(self):
21 | super().__init__()
22 |
23 | def forward(self, input_x):
24 | return input_x / (input_x.sum(dim=1).view(-1, 1) + 1e-5)
25 |
26 |
27 | class mysoftplus(torch.nn.Module):  # shifted softplus clipped at zero; `sparsity` sets the shift
28 | def __init__(self, sparsity=0.9):
29 | super().__init__() # init the base class
30 | self.sparsity = sparsity
31 |
32 | def forward(self, input_x):
33 | return self._mysoftplus(input_x)
34 |
35 | def _mysoftplus(self, input_x):
36 | """customized softplus activation, output range: [0, inf)"""
37 | var_sp = torch.nn.functional.softplus(input_x)
38 | threshold = torch.nn.functional.softplus(
39 | torch.tensor(-(1 - self.sparsity) * 10.0, device=input_x.device)
40 | )
41 | var_sp = var_sp - threshold
42 | zero = torch.zeros_like(threshold)
43 | var_out = torch.where(var_sp <= zero, zero, var_sp)
44 | return var_out
45 |
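# Minimal usage sketch (illustrative, not part of the module): all three
# modules act on 2D tensors.
#
#   x = torch.tensor([[-2.0, 0.0, 2.0]])
#   mytanh()(x)                   # values in (0, 1), here ~[[0.02, 0.50, 0.98]]
#   hnormalization()(x.abs())     # rows rescaled to sum to ~1
#   mysoftplus(sparsity=0.9)(x)   # small activations are snapped to exactly 0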
--------------------------------------------------------------------------------
/scar/main/_data_generater.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Module to generate synthetic datasets with ambient contamination"""
3 |
4 | import numpy as np
5 | from numpy import random
6 | import matplotlib.pyplot as plt
7 | import seaborn as sns
8 |
9 | ###################################################################################
10 | # Synthetic scrnaseq datasets
11 |
12 |
13 | class scrnaseq:
14 | """Generate synthetic single-cell RNAseq data with ambient contamination
15 |
16 | Parameters
17 | ----------
18 | n_cells : int
19 | number of cells
20 | n_celltypes : int
21 | number of cell types
22 | n_features : int
23 | number of features (mRNA)
24 | n_total_molecules : int, optional
25 | total molecules per cell, by default 8000
26 | capture_rate : float, optional
27 | the probability of being captured by beads, by default 0.7
28 |
29 | Examples
30 | --------
31 | .. plot::
32 | :context: close-figs
33 |
34 | import numpy as np
35 | from scar import data_generator
36 |
37 |         n_features = 1000 # 1000 genes; too many genes make the heatmap hard to read
38 | n_cells = 6000 # cells
39 | n_total_molecules = 20000 # total mRNAs
40 | n_celltypes = 8 # cell types
41 |
42 | np.random.seed(8)
43 | scRNAseq = data_generator.scrnaseq(n_cells, n_celltypes, n_features, n_total_molecules=n_total_molecules)
44 | scRNAseq.generate(dirichlet_concentration_hyper=1)
45 | scRNAseq.heatmap(vmax=5)
46 |
47 | """
48 |
49 | def __init__(
50 | self, n_cells, n_celltypes, n_features, n_total_molecules=8000, capture_rate=0.7
51 | ):
52 | """initilization"""
53 | self.n_cells = n_cells
54 | """int, number of cells"""
55 | self.n_celltypes = n_celltypes
56 | """int, number of cell types"""
57 | self.n_features = n_features
58 | """int, number of features (mRNA, sgRNA, ADT, tag, CMO, and etc.)"""
59 | self.n_total_molecules = n_total_molecules
60 | """int, number of total molecules per cell"""
61 | self.capture_rate = capture_rate
62 | """float, the probability of being captured by beads"""
63 |         self.obs_count = None
64 |         """vector, observed counts"""
65 |         self.ambient_profile = None
66 |         """vector, the probability of occurrence of each ambient transcript"""
67 |         self.cell_identity = None
68 |         """matrix, the onehot expression of the identity of cell types"""
69 |         self.noise_ratio = None
70 |         """vector, contamination level per cell"""
71 |         self.celltype = None
72 |         """vector, the identity of cell types"""
73 |         self.ambient_signals = None
74 |         """matrix, the real ambient signals"""
75 |         self.native_signals = None
76 |         """matrix, the real native signals"""
77 |         self.native_profile = None
78 |         """matrix, the frequencies of the real native signals"""
79 |         self.total_counts = None
80 |         """vector, the total observed counts per cell"""
81 | self.empty_droplets = None
82 | """matrix, synthetic cell-free droplets"""
83 |
84 | def generate(self, dirichlet_concentration_hyper=0.05):
85 | """Generate a synthetic scRNAseq dataset.
86 |
87 | Parameters
88 | ----------
89 | dirichlet_concentration_hyper : None or real, optional
90 | the concentration hyperparameters of dirichlet distribution. \
91 | Determining the sparsity of native signals. \
92 |             If None, 1 / n_features. By default 0.05.
93 |
94 | Returns
95 | -------
96 | After running, several attributes are added
97 | """
98 |
99 | if dirichlet_concentration_hyper:
100 | alpha = np.ones(self.n_features) * dirichlet_concentration_hyper
101 | else:
102 | alpha = np.ones(self.n_features) / self.n_features
103 |
104 | # simulate native expression frequencies for each cell
105 | cell_comp_prior = random.dirichlet(np.ones(self.n_celltypes))
106 | celltype = random.choice(
107 | a=self.n_celltypes, size=self.n_cells, p=cell_comp_prior
108 | )
109 | cell_identity = np.identity(self.n_celltypes)[celltype]
110 | theta_celltype = random.dirichlet(alpha, size=self.n_celltypes)
111 |
112 | beta_in_each_cell = cell_identity.dot(theta_celltype)
113 |
114 | # simulate total molecules for a droplet in ambient pool
115 | n_total_mol = random.randint(
116 | low=self.n_total_molecules / 5, high=self.n_total_molecules / 2, size=1
117 | )
118 |
119 | # simulate ambient signals
120 | beta0 = random.dirichlet(np.ones(self.n_features))
121 | tot_count0 = random.negative_binomial(
122 | n_total_mol, self.capture_rate, size=self.n_cells
123 | )
124 | ambient_signals = np.vstack(
125 | [random.multinomial(n=tot_c, pvals=beta0) for tot_c in tot_count0]
126 | )
127 |
128 | # add empty droplets
129 | tot_count0_empty = random.negative_binomial(
130 | n_total_mol, self.capture_rate, size=self.n_cells
131 | )
132 | ambient_signals_empty = np.vstack(
133 | [random.multinomial(n=tot_c, pvals=beta0) for tot_c in tot_count0_empty]
134 | )
135 |
136 | # simulate native signals
137 |         tot_trials = random.randint(
138 | low=self.n_total_molecules / 2,
139 | high=self.n_total_molecules,
140 | size=self.n_celltypes,
141 | )
142 | tot_count1 = [
143 | random.negative_binomial(tot, self.capture_rate)
144 |             for tot in cell_identity.dot(tot_trials)
145 | ]
146 |
147 | native_signals = np.vstack(
148 | [
149 | random.multinomial(n=tot_c, pvals=theta1)
150 | for tot_c, theta1 in zip(tot_count1, beta_in_each_cell)
151 | ]
152 | )
153 | obs = ambient_signals + native_signals
154 |
155 | noise_ratio = tot_count0 / (tot_count0 + tot_count1)
156 |
157 | self.obs_count = obs
158 | self.ambient_profile = beta0
159 | self.cell_identity = cell_identity
160 | self.noise_ratio = noise_ratio
161 | self.celltype = celltype
162 | self.ambient_signals = ambient_signals
163 | self.native_signals = native_signals
164 | self.native_profile = beta_in_each_cell
165 | self.total_counts = obs.sum(axis=1)
166 | self.empty_droplets = ambient_signals_empty.astype(int)
167 |
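    # Sanity-check sketch (hypothetical variable name `syn`): after generate(),
    # the observation decomposes exactly into the two simulated components:
    #
    #   syn = scrnaseq(n_cells=100, n_celltypes=3, n_features=200)
    #   syn.generate()
    #   assert (syn.obs_count == syn.ambient_signals + syn.native_signals).all()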
168 | def heatmap(
169 | self, feature_type="mRNA", return_obj=False, figsize=(12, 4), vmin=0, vmax=10
170 | ):
171 | """Heatmap of synthetic data.
172 |
173 | Parameters
174 | ----------
175 | feature_type : str, optional
176 | the feature types, by default "mRNA"
177 | return_obj : bool, optional
178 | whether to output figure object, by default False
179 | figsize : tuple, optional
180 |             figure size, by default (12, 4)
181 | vmin : int, optional
182 | colorbar minimum, by default 0
183 | vmax : int, optional
184 | colorbar maximum, by default 10
185 |
186 | Returns
187 | -------
188 | fig object
189 | if return_obj, return a fig object
190 | """
191 | sort_cell_idx = []
192 | for f in self.ambient_profile.argsort():
193 | sort_cell_idx += list(np.where(self.celltype == f)[0])
194 |
195 | native_signals = self.native_signals[sort_cell_idx][
196 | :, self.ambient_profile.argsort()
197 | ]
198 | ambient_signals = self.ambient_signals[sort_cell_idx][
199 | :, self.ambient_profile.argsort()
200 | ]
201 | obs = self.obs_count[sort_cell_idx][:, self.ambient_profile.argsort()]
202 |
203 | fig, axs = plt.subplots(ncols=3, figsize=figsize)
204 | sns.heatmap(
205 | np.log2(obs + 1),
206 | yticklabels=False,
207 | vmin=vmin,
208 | vmax=vmax,
209 | cmap="coolwarm",
210 | center=1,
211 | ax=axs[0],
212 | rasterized=True,
213 | cbar_kws={"label": "log2(counts + 1)"},
214 | )
215 | axs[0].set_title("noisy observation")
216 |
217 | sns.heatmap(
218 | np.log2(ambient_signals + 1),
219 | yticklabels=False,
220 | vmin=vmin,
221 | vmax=vmax,
222 | cmap="coolwarm",
223 | center=1,
224 | ax=axs[1],
225 | rasterized=True,
226 | cbar_kws={"label": "log2(counts + 1)"},
227 | )
228 | axs[1].set_title("ambient signals")
229 |
230 | sns.heatmap(
231 | np.log2(native_signals + 1),
232 | yticklabels=False,
233 | vmin=vmin,
234 | vmax=vmax,
235 | cmap="coolwarm",
236 | center=1,
237 | ax=axs[2],
238 | rasterized=True,
239 | cbar_kws={"label": "log2(counts + 1)"},
240 | )
241 | axs[2].set_title("native signals")
242 |
243 | fig.supxlabel(feature_type)
244 | fig.supylabel("cells")
245 | plt.tight_layout()
246 |
247 | if return_obj:
248 | return fig
249 |
250 |
251 | ######################################################################################
252 | # Synthetic citeseq datasets
253 | class citeseq(scrnaseq):
254 | """Generate synthetic ADT count data for CITE-seq with ambient contamination
255 |
256 | Parameters
257 | ----------
258 | n_cells : int
259 | number of cells
260 | n_celltypes : int
261 | number of cell types
262 | n_features : int
263 | number of distinct antibodies (ADTs)
264 | n_total_molecules : int, optional
265 | number of total molecules, by default 8000
266 | capture_rate : float, optional
267 | the probabilities of being captured by beads, by default 0.7
268 |
269 | Examples
270 | --------
271 | .. plot::
272 | :context: close-figs
273 |
274 | import numpy as np
275 | from scar import data_generator
276 |
277 | n_features = 50 # 50 ADTs
278 | n_cells = 6000 # 6000 cells
279 | n_celltypes = 6 # cell types
280 |
281 | # generate a synthetic ADT count dataset
282 | np.random.seed(8)
283 | citeseq = data_generator.citeseq(n_cells, n_celltypes, n_features)
284 | citeseq.generate()
285 | citeseq.heatmap()
286 |
287 | """
288 |
289 | def __init__(
290 | self, n_cells, n_celltypes, n_features, n_total_molecules=8000, capture_rate=0.7
291 | ):
292 | super().__init__(
293 | n_cells, n_celltypes, n_features, n_total_molecules, capture_rate
294 | )
295 |
296 | def generate(self, dirichlet_concentration_hyper=None):
297 | """Generate a synthetic ADT dataset.
298 |
299 | Parameters
300 | ----------
301 | dirichlet_concentration_hyper : None or real, optional
302 | the concentration hyperparameters of dirichlet distribution. \
303 | If None, 1 / n_features, by default None
304 |
305 | Returns
306 | -------
307 | After running, several attributes are added
308 | """
309 |
310 | if dirichlet_concentration_hyper:
311 | alpha = np.ones(self.n_features) * dirichlet_concentration_hyper
312 | else:
313 | alpha = np.ones(self.n_features) / self.n_features
314 |
315 | # simulate native expression frequencies for each cell
316 | cell_comp_prior = random.dirichlet(np.ones(self.n_celltypes))
317 | celltype = random.choice(
318 | a=self.n_celltypes, size=self.n_cells, p=cell_comp_prior
319 | )
320 | cell_identity = np.identity(self.n_celltypes)[celltype]
321 | theta_celltype = random.dirichlet(alpha, size=self.n_celltypes)
322 | beta_in_each_cell = cell_identity.dot(theta_celltype)
323 |
324 | # simulate total molecules for a droplet in ambient pool
325 | n_total_mol = random.randint(
326 | low=self.n_total_molecules / 5, high=self.n_total_molecules / 2, size=1
327 | )
328 |
329 | # simulate ambient signals
330 | beta0 = random.dirichlet(np.ones(self.n_features))
331 | tot_count0 = random.negative_binomial(
332 | n_total_mol, self.capture_rate, size=self.n_cells
333 | )
334 | ambient_signals = np.vstack(
335 | [random.multinomial(n=tot_c, pvals=beta0) for tot_c in tot_count0]
336 | )
337 |
338 | # add empty droplets
339 | tot_count0_empty = random.negative_binomial(
340 | n_total_mol, self.capture_rate, size=self.n_cells
341 | )
342 | ambient_signals_empty = np.vstack(
343 | [random.multinomial(n=tot_c, pvals=beta0) for tot_c in tot_count0_empty]
344 | )
345 |
346 | # simulate native signals
347 |         tot_trials = random.randint(
348 | low=self.n_total_molecules / 2,
349 | high=self.n_total_molecules,
350 | size=self.n_celltypes,
351 | )
352 | tot_count1 = [
353 | random.negative_binomial(tot, self.capture_rate)
354 |             for tot in cell_identity.dot(tot_trials)
355 | ]
356 |
357 | native_signals = np.vstack(
358 | [
359 | random.multinomial(n=tot_c, pvals=theta1)
360 | for tot_c, theta1 in zip(tot_count1, beta_in_each_cell)
361 | ]
362 | )
363 | obs = ambient_signals + native_signals
364 |
365 | noise_ratio = tot_count0 / (tot_count0 + tot_count1)
366 |
367 | self.obs_count = obs
368 | self.ambient_profile = beta0
369 | self.cell_identity = cell_identity
370 | self.noise_ratio = noise_ratio
371 | self.celltype = celltype
372 | self.ambient_signals = ambient_signals
373 | self.native_signals = native_signals
374 | self.native_profile = beta_in_each_cell
375 | self.total_counts = obs.sum(axis=1)
376 | self.empty_droplets = ambient_signals_empty.astype(int)
377 |
378 | def heatmap(
379 | self, feature_type="ADT", return_obj=False, figsize=(12, 4), vmin=0, vmax=10
380 | ):
381 | """Heatmap of synthetic data.
382 |
383 | Parameters
384 | ----------
385 | feature_type : str, optional
386 | the feature types, by default "ADT"
387 | return_obj : bool, optional
388 | whether to output figure object, by default False
389 | figsize : tuple, optional
390 |             figure size, by default (12, 4)
391 | vmin : int, optional
392 | colorbar minimum, by default 0
393 | vmax : int, optional
394 | colorbar maximum, by default 10
395 |
396 | Returns
397 | -------
398 | fig object
399 | if return_obj, return a fig object
400 | """
401 | sort_cell_idx = []
402 | for f in self.ambient_profile.argsort():
403 | sort_cell_idx += list(np.where(self.celltype == f)[0])
404 |
405 | native_signals = self.native_signals[sort_cell_idx][
406 | :, self.ambient_profile.argsort()
407 | ]
408 | ambient_signals = self.ambient_signals[sort_cell_idx][
409 | :, self.ambient_profile.argsort()
410 | ]
411 | obs = self.obs_count[sort_cell_idx][:, self.ambient_profile.argsort()]
412 |
413 | fig, axs = plt.subplots(ncols=3, figsize=figsize)
414 | sns.heatmap(
415 | np.log2(obs + 1),
416 | yticklabels=False,
417 | vmin=vmin,
418 | vmax=vmax,
419 | cmap="coolwarm",
420 | center=1,
421 | ax=axs[0],
422 | rasterized=True,
423 | cbar_kws={"label": "log2(counts + 1)"},
424 | )
425 | axs[0].set_title("noisy observation")
426 |
427 | sns.heatmap(
428 | np.log2(ambient_signals + 1),
429 | yticklabels=False,
430 | vmin=vmin,
431 | vmax=vmax,
432 | cmap="coolwarm",
433 | center=1,
434 | ax=axs[1],
435 | rasterized=True,
436 | cbar_kws={"label": "log2(counts + 1)"},
437 | )
438 | axs[1].set_title("ambient signals")
439 |
440 | sns.heatmap(
441 | np.log2(native_signals + 1),
442 | yticklabels=False,
443 | vmin=vmin,
444 | vmax=vmax,
445 | cmap="coolwarm",
446 | center=1,
447 | ax=axs[2],
448 | rasterized=True,
449 | cbar_kws={"label": "log2(counts + 1)"},
450 | )
451 | axs[2].set_title("native signals")
452 |
453 | fig.supxlabel(feature_type)
454 | fig.supylabel("cells")
455 | plt.tight_layout()
456 |
457 | if return_obj:
458 | return fig
459 |
460 |
461 | ##########################################################################################
462 | # Synthetic cropseq datasets
463 |
464 |
465 | class cropseq(scrnaseq):
466 | """Generate synthetic sgRNA count data for scCRISPRseq with ambient contamination
467 |
468 | Parameters
469 | ----------
470 | n_cells : int
471 | number of cells
472 | n_celltypes : int
473 | number of cell types
474 | n_features : int
475 |         number of distinct sgRNAs
476 | library_pattern : str, optional
477 | the pattern of sgRNA libraries, three possibilities:
478 |
479 | | "uniform" - each sgRNA has equal frequency in the libraries
480 | | "pyramid" - a few sgRNAs have significantly higher frequencies in the libraries
481 | | "reverse_pyramid" - a few sgRNAs have significantly lower frequencies in the libraries
482 | | By default "pyramid".
483 |     noise_ratio : float, optional
484 |         global contamination level, by default 0.96
485 |     average_counts_per_cell : int, optional
486 |         average total sgRNA counts per cell, by default 2000
487 |     doublet_rate : float, optional
488 |         doublet rate, by default 0
489 |     missing_rate : float, optional
490 |         the fraction of droplets which have zero sgRNAs integrated, by default 0
491 |
492 | Examples
493 | --------
494 |
495 | .. plot::
496 | :context: close-figs
497 |
498 | import numpy as np
499 | from scar import data_generator
500 |
501 | n_features = 100 # 100 sgRNAs in the libraries
502 | n_cells = 6000 # 6000 cells
503 | n_celltypes = 1 # single cell line
504 |
505 | # generate a synthetic sgRNA count dataset
506 | np.random.seed(8)
507 | cropseq = data_generator.cropseq(n_cells, n_celltypes, n_features)
508 | cropseq.generate(noise_ratio=0.98)
509 | cropseq.heatmap(vmax=6)
510 | """
511 |
512 | def __init__(
513 | self,
514 | n_cells,
515 | n_celltypes,
516 | n_features,
517 | ):
518 | super().__init__(n_cells, n_celltypes, n_features)
519 |
520 | self.sgrna_freq = None
521 | """vector, sgRNA frequencies in the libraries
522 | """
523 |
524 | # generate a pool of sgrnas
525 | def _set_sgrna_frequency(self):
526 | """set the pattern of sgrna library"""
527 | if self.library_pattern == "uniform":
528 | self.sgrna_freq = 1.0 / self.n_features
529 | elif self.library_pattern == "pyramid":
530 | uniform_spaced_values = np.random.permutation(self.n_features + 1) / (
531 | self.n_features + 1
532 | )
533 | uniform_spaced_values = uniform_spaced_values[uniform_spaced_values != 0]
534 | log_values = np.log(uniform_spaced_values)
535 | self.sgrna_freq = log_values / np.sum(log_values)
536 | elif self.library_pattern == "reverse_pyramid":
537 | uniform_spaced_values = (
538 | np.random.permutation(self.n_features) / self.n_features
539 | )
540 | log_values = (uniform_spaced_values + 0.001) ** (1 / 10)
541 | self.sgrna_freq = log_values / np.sum(log_values)
542 |
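    # Illustrative outcome (assuming n_features=4): each pattern yields a
    # probability vector over sgRNAs summing to 1, e.g.
    #   "uniform"         -> [0.25, 0.25, 0.25, 0.25]
    #   "pyramid"         -> a few sgRNAs carry most of the probability mass
    #   "reverse_pyramid" -> a few sgRNAs carry much less mass than the rest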
543 | def _set_native_signals(self):
544 | """generatation of native signals"""
545 |
546 | self._set_sgrna_frequency()
547 |
548 | # cells without any sgrnas
549 | n_cells_miss = int(self.n_cells * self.missing_rate)
550 |
551 | # Doublets
552 | n_doublets = int(self.n_cells * self.doublet_rate)
553 |
554 | # total number of single sgrnas which are integrated into cells
555 | # (cells with double sgrnas will be counted twice)
556 | n_cells_integrated = self.n_cells - n_cells_miss + n_doublets
557 |
558 | # create cells with sgrnas based on sgrna frequencies
559 | self.celltype = random.choice(
560 | a=range(self.n_features), size=n_cells_integrated, p=self.sgrna_freq
561 | )
562 | self.cell_identity = np.eye(self.n_features)[self.celltype] # cell_identity
563 |
564 | def _add_ambient(self):
565 | """add ambient signals"""
566 | self._set_native_signals()
567 | sgrna_mixed_freq = (
568 | 1 - self.noise_ratio
569 | ) * self.cell_identity + self.noise_ratio * self.sgrna_freq
570 | sgrna_mixed_freq = sgrna_mixed_freq / sgrna_mixed_freq.sum(
571 | axis=1, keepdims=True
572 | )
573 | return sgrna_mixed_freq
574 |
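    # Worked example of the mixing step above (assumed numbers): with
    # noise_ratio = 0.96, a cell carrying sgRNA j mixes its one-hot identity
    # with the library frequencies,
    #   mixed_j = 0.04 * onehot_j + 0.96 * sgrna_freq,
    # followed by row renormalization, so most observed counts are ambient.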
575 | # function to generate counts per cell
576 | def generate(
577 | self,
578 | dirichlet_concentration_hyper=None,
579 | library_pattern="pyramid",
580 | noise_ratio=0.96,
581 | average_counts_per_cell=2000,
582 | doublet_rate=0,
583 | missing_rate=0,
584 | ):
585 | """Generate a synthetic sgRNA count dataset.
586 |
587 | Parameters
588 | ----------
589 | library_pattern : str, optional
590 | library pattern, by default "pyramid"
591 |         noise_ratio : float, optional
592 |             global contamination level, by default 0.96
593 |         average_counts_per_cell : int, optional
594 |             average total sgRNA counts per cell, by default 2000
595 |         doublet_rate : float, optional
596 |             doublet rate, by default 0
597 |         missing_rate : float, optional
598 |             the fraction of droplets which have zero sgRNAs integrated, by default 0
599 |
600 | Returns
601 | -------
602 | After running, several attributes are added
603 | """
604 |
605 | assert library_pattern.lower() in ["uniform", "pyramid", "reverse_pyramid"]
606 | self.library_pattern = library_pattern.lower()
607 | """str, library pattern
608 | """
609 | self.doublet_rate = doublet_rate
610 | """float, doublet rate
611 | """
612 | self.missing_rate = missing_rate
613 | """float, the fraction of droplets which have zero sgRNAs integrated.
614 | """
615 | self.noise_ratio = noise_ratio
616 | """float, global contamination level
617 | """
618 | self.average_counts_per_cell = average_counts_per_cell
619 | """int, the mean of total sgRNA counts per cell
620 | """
621 |
622 | # generate total counts per cell
623 | total_counts = random.negative_binomial(
624 | self.average_counts_per_cell, 0.7, size=self.n_cells
625 | )
626 |
627 | # the mixed sgrna expression profile
628 | sgrna_mixed_freq = self._add_ambient()
629 |
630 | # final count matrix:
631 | obs = np.vstack(
632 | [
633 | random.multinomial(n=tot_c, pvals=p)
634 | for tot_c, p in zip(total_counts, sgrna_mixed_freq)
635 | ]
636 | )
637 |
638 | self.obs_count = obs
639 | self.total_counts = total_counts
640 | self.ambient_profile = self.sgrna_freq
641 | self.native_signals = (
642 | self.total_counts.reshape(-1, 1)
643 | * (1 - self.noise_ratio)
644 | * self.cell_identity
645 | )
646 | self.ambient_signals = np.clip(obs - self.native_signals, 0, None)
647 |
648 | def heatmap(
649 | self, feature_type="sgRNAs", return_obj=False, figsize=(12, 4), vmin=0, vmax=7
650 | ):
651 | """Heatmap of synthetic data.
652 |
653 | Parameters
654 | ----------
655 | feature_type : str, optional
656 | the feature types, by default "sgRNAs"
657 | return_obj : bool, optional
658 | whether to output figure object, by default False
659 | figsize : tuple, optional
660 |             figure size, by default (12, 4)
661 | vmin : int, optional
662 | colorbar minimum, by default 0
663 | vmax : int, optional
664 |             colorbar maximum, by default 7
665 |
666 | Returns
667 | -------
668 | fig object
669 | if return_obj, return a fig object
670 | """
671 | sort_cell_idx = []
672 | for f in self.ambient_profile.argsort():
673 | sort_cell_idx += list(np.where(self.celltype == f)[0])
674 |
675 | native_signals = self.native_signals[sort_cell_idx][
676 | :, self.ambient_profile.argsort()
677 | ]
678 | ambient_signals = self.ambient_signals[sort_cell_idx][
679 | :, self.ambient_profile.argsort()
680 | ]
681 | obs = self.obs_count[sort_cell_idx][:, self.ambient_profile.argsort()]
682 |
683 | fig, axs = plt.subplots(ncols=3, figsize=figsize)
684 | sns.heatmap(
685 | np.log2(obs + 1),
686 | yticklabels=False,
687 | vmin=vmin,
688 | vmax=vmax,
689 | cmap="coolwarm",
690 | center=1,
691 | ax=axs[0],
692 | rasterized=True,
693 | cbar_kws={"label": "log2(counts + 1)"},
694 | )
695 | axs[0].set_title("noisy observation")
696 |
697 | sns.heatmap(
698 | np.log2(ambient_signals + 1),
699 | yticklabels=False,
700 | vmin=vmin,
701 | vmax=vmax,
702 | cmap="coolwarm",
703 | center=1,
704 | ax=axs[1],
705 | rasterized=True,
706 | cbar_kws={"label": "log2(counts + 1)"},
707 | )
708 | axs[1].set_title("ambient signals")
709 |
710 | sns.heatmap(
711 | np.log2(native_signals + 1),
712 | yticklabels=False,
713 | vmin=vmin,
714 | vmax=vmax,
715 | cmap="coolwarm",
716 | center=1,
717 | ax=axs[2],
718 | rasterized=True,
719 | cbar_kws={"label": "log2(counts + 1)"},
720 | )
721 | axs[2].set_title("native signals")
722 |
723 | fig.supxlabel(feature_type)
724 | fig.supylabel("cells")
725 | plt.tight_layout()
726 |
727 | if return_obj:
728 | return fig
729 |
--------------------------------------------------------------------------------
/scar/main/_loss_functions.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Loss functions"""
3 |
4 | import torch
5 | from torch.distributions import Normal, kl_divergence, Binomial, Poisson
6 | from pyro.distributions.zero_inflated import ZeroInflatedPoisson
7 |
8 | def kld(means, var):
9 | """KL divergence"""
10 | mean = torch.zeros_like(means)
11 | scale = torch.ones_like(var)
12 | return kl_divergence(Normal(means, torch.sqrt(var)), Normal(mean, scale)).sum(dim=1)
13 |
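# For reference, the kl_divergence call above reduces to the closed form
#   KL( N(mu, var) || N(0, 1) ) = 0.5 * (var + mu^2 - 1 - log(var)),
# summed over the latent dimensions (dim=1).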
14 |
15 | def get_reconstruction_loss(
16 | input_matrix, dec_nr, dec_prob, amb_prob, dec_dp, count_model
17 | ):
18 | """reconstruction loss"""
19 | tot_count = input_matrix.sum(dim=1).view(-1, 1)
20 | prob_tot = dec_prob * (1 - dec_nr) + amb_prob * dec_nr
21 |
22 | if count_model.lower() == "zeroinflatedpoisson":
23 | recon_loss = -ZeroInflatedPoisson(
24 | rate=tot_count * prob_tot / (1 - dec_dp), gate=dec_dp, validate_args=False
25 | ).log_prob(input_matrix)
26 | recon_loss = torch.nan_to_num(recon_loss, nan=1e-7, posinf=1e15, neginf=-1e15)
27 | recon_loss = recon_loss.sum(axis=1).mean()
28 |
29 | elif count_model.lower() == "binomial":
30 | recon_loss = -Binomial(tot_count, probs=prob_tot, validate_args=False).log_prob(
31 | input_matrix
32 | )
33 | recon_loss = torch.nan_to_num(recon_loss, nan=1e-7, posinf=1e15, neginf=-1e15)
34 | recon_loss = recon_loss.sum(axis=1).mean()
35 |
36 | elif count_model.lower() == "poisson":
37 | recon_loss = -Poisson(rate=tot_count * prob_tot, validate_args=False).log_prob(
38 | input_matrix
39 |         )
40 | recon_loss = torch.nan_to_num(recon_loss, nan=1e-7, posinf=1e15, neginf=-1e15)
41 | recon_loss = recon_loss.sum(axis=1).mean()
42 |
43 | return recon_loss
44 |
45 |
46 | def loss_fn(
47 | input_matrix,
48 | dec_nr,
49 | dec_prob,
50 | means,
51 | var,
52 | amb_prob,
53 | reconstruction_weight,
54 | kld_weight=1e-5,
55 | dec_dp=None,
56 | count_model="binomial",
57 | ):
58 | """loss function"""
59 |
60 | recon_loss = get_reconstruction_loss(
61 | input_matrix, dec_nr, dec_prob, amb_prob, dec_dp=dec_dp, count_model=count_model
62 | )
63 | kld_loss = kld(means, var).sum()
64 | total_loss = recon_loss * reconstruction_weight + kld_loss * kld_weight
65 |
66 | return recon_loss, kld_loss, total_loss
67 |
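# Minimal smoke-test sketch (illustrative shapes only, not part of the module):
#
#   B, F, L = 8, 20, 5                                  # cells, features, latent dims
#   x = torch.randint(0, 10, (B, F)).float()            # raw counts
#   dec_nr = torch.full((B, 1), 0.1)                    # predicted noise ratios
#   dec_prob = torch.softmax(torch.randn(B, F), dim=1)  # native frequencies
#   amb_prob = torch.softmax(torch.randn(F), dim=0)     # ambient profile
#   means, var = torch.zeros(B, L), torch.ones(B, L)
#   recon, kl, total = loss_fn(x, dec_nr, dec_prob, means, var, amb_prob,
#                              reconstruction_weight=1, count_model="binomial")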
--------------------------------------------------------------------------------
/scar/main/_scar.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """The main module of scar"""
4 |
5 | import sys, time, contextlib, torch
6 | from typing import Optional, Union
7 | from scipy import sparse
8 | import numpy as np, pandas as pd, anndata as ad
9 |
10 | from torch.utils.data import Dataset, random_split, DataLoader
11 | from tqdm import tqdm
12 | from tqdm.contrib import DummyTqdmFile
13 |
14 | from ._vae import VAE
15 | from ._loss_functions import loss_fn
16 | from ._utils import get_logger
17 |
18 |
19 | @contextlib.contextmanager
20 | def std_out_err_redirect_tqdm():
21 | """
22 | Writing progressbar into stdout rather than stderr,
23 | from https://github.com/tqdm/tqdm/blob/master/examples/redirect_print.py
24 | """
25 | orig_out_err = sys.stdout, sys.stderr
26 | try:
27 | sys.stdout, sys.stderr = map(DummyTqdmFile, orig_out_err)
28 | yield orig_out_err[0]
29 | # Relay exceptions
30 | except Exception as exc:
31 | raise exc
32 | # Always restore sys.stdout/err if necessary
33 | finally:
34 | sys.stdout, sys.stderr = orig_out_err
35 |
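# Usage sketch: print() calls inside the block are routed through tqdm, so log
# messages do not break the progress bar (this mirrors how train() uses it):
#
#   with std_out_err_redirect_tqdm() as orig_stdout:
#       for _ in tqdm(range(3), file=orig_stdout, dynamic_ncols=True):
#           print("a message that will not clobber the bar")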
36 |
37 | # scar class
38 | class model:
39 | """The scar model
40 |
41 | Parameters
42 | ----------
43 | raw_count : Union[str, np.ndarray, pd.DataFrame, ad.AnnData]
44 |         Raw count matrix or AnnData object.
45 |
46 | .. note::
47 | scar takes the raw UMI counts as input. No size normalization or log transformation.
48 |
49 | ambient_profile : Optional[Union[str, np.ndarray, pd.DataFrame]], optional
50 | the probability of occurrence of each ambient transcript.\
51 | If None, averaging cells to estimate the ambient profile, by default None
52 | nn_layer1 : int, optional
53 | number of neurons of the 1st layer, by default 150
54 | nn_layer2 : int, optional
55 | number of neurons of the 2nd layer, by default 100
56 | latent_dim : int, optional
57 | number of neurons of the bottleneck layer, by default 15
58 | dropout_prob : float, optional
59 | dropout probability of neurons, by default 0
60 | feature_type : str, optional
61 | the feature to be denoised. One of the following:
62 |
63 | | 'mRNA' -- transcriptome data, including scRNAseq and snRNAseq
64 | | 'ADT' -- protein counts in CITE-seq
65 | | 'sgRNA' -- sgRNA counts for scCRISPRseq
66 | | 'tag' -- identity barcodes or any data types of high sparsity. \
67 | E.g., in cell indexing experiments, we would expect a single true signal \
68 | (1) and many negative signals (0) for each cell
69 | | 'CMO' -- Cell Multiplexing Oligo counts for cell hashing
70 | | 'ATAC' -- peak counts for scATACseq
71 | .. versionadded:: 0.5.2
72 | | By default "mRNA"
73 | count_model : str, optional
74 | the model to generate the UMI count. One of the following:
75 |
76 | | 'binomial' -- binomial model,
77 | | 'poisson' -- poisson model,
78 | | 'zeroinflatedpoisson' -- zeroinflatedpoisson model, by default "binomial"
79 | sparsity : float, optional
80 | range: [0, 1]. The sparsity of expected native signals. \
81 | It varies between datasets, e.g. if one prefilters genes -- \
82 | use only highly variable genes -- \
83 | the sparsity should be low; on the other hand, it should be set high \
84 |         in the case of unfiltered genes. \
85 |         It is forced to 1 for the "sgRNA(s)" and "tag(s)" feature types. \
86 |         Thanks to Will Macnair for the valuable feedback.
87 |
88 | .. versionadded:: 0.4.0
89 | cache_capacity : int, optional
90 | the capacity of caching data on GPU. Set a smaller value upon GPU memory issue. By default 20000 cells are cached.
91 |
92 | .. versionadded:: 0.7.0
93 | batch_key : str, optional
94 | batch key in AnnData.obs, by default None. \
95 |         If assigned, batch-wise ambient removal will be performed and \
96 | the ambient profile will be estimated for each batch.
97 |
98 | .. versionadded:: 0.7.0
99 |
100 | device : str, optional
101 | either "auto, "cpu" or "cuda" or "mps", by default "auto"
102 | verbose : bool, optional
103 | whether to print the details, by default True
104 |
105 | Raises
106 | ------
107 | TypeError
108 | if raw_count is not str or np.ndarray or pd.DataFrame
109 | TypeError
110 | if ambient_profile is not str or np.ndarray or pd.DataFrame or None
111 |
112 | Examples
113 | --------
114 | >>> # Real data
115 | >>> import scanpy as sc
116 | >>> from scar import model
117 | >>> adata = sc.read("...") # load an anndata object
118 | >>> scarObj = model(adata, ambient_profile) # initialize scar model
119 | >>> scarObj.train() # start training
120 | >>> scarObj.inference() # inference
121 | >>> adata.layers["X_scar_denoised"] = scarObj.native_counts # results are saved in scarObj
122 | >>> adata.obsm["X_scar_assignment"] = scarObj.feature_assignment #'sgRNA' or 'tag' feature type
123 |
124 | Examples
125 | -------------------------
126 | .. plot::
127 | :context: close-figs
128 |
129 | # Synthetic data
130 | import numpy as np
131 | import seaborn as sns
132 | import matplotlib.pyplot as plt
133 | from scar import data_generator, model
134 |
135 | # Generate a synthetic ADT count dataset
136 | np.random.seed(8)
137 | n_features = 50 # 50 ADTs
138 | n_cells = 6000 # 6000 cells
139 | n_celltypes = 6 # cell types
140 | citeseq = data_generator.citeseq(n_cells, n_celltypes, n_features)
141 | citeseq.generate()
142 |
143 | # Train scAR
144 | citeseq_denoised = model(citeseq.obs_count, citeseq.ambient_profile, feature_type="ADT", sparsity=0.6) # initialize scar model
145 | citeseq_denoised.train(epochs=100, verbose=False) # start training
146 | citeseq_denoised.inference() # inference
147 |
148 | # Visualization
149 | sorted_noisy_counts = citeseq.obs_count[citeseq.celltype.argsort()][
150 | :, citeseq.ambient_profile.argsort()
151 | ] # noisy observation
152 | sorted_native_counts = citeseq.native_signals[citeseq.celltype.argsort()][
153 | :, citeseq.ambient_profile.argsort()
154 | ] # native counts
155 | sorted_denoised_counts = citeseq_denoised.native_counts.toarray()[citeseq.celltype.argsort()][
156 | :, citeseq.ambient_profile.argsort()
157 | ] # denoised counts
158 |
159 | fig, axs = plt.subplots(ncols=3, figsize=(12,4))
160 | sns.heatmap(
161 | np.log2(sorted_noisy_counts + 1),
162 | yticklabels=False,
163 | vmin=0,
164 | vmax=10,
165 | cmap="coolwarm",
166 | center=1,
167 | ax=axs[0],
168 | cbar_kws={"label": "log2(counts + 1)"},
169 | )
170 | axs[0].set_title("noisy observation")
171 |
172 | sns.heatmap(
173 | np.log2(sorted_native_counts + 1),
174 | yticklabels=False,
175 | vmin=0,
176 | vmax=10,
177 | cmap="coolwarm",
178 | center=1,
179 | ax=axs[1],
180 | cbar_kws={"label": "log2(counts + 1)"},
181 | )
182 | axs[1].set_title("native counts (ground truth)")
183 |
184 | sns.heatmap(
185 | np.log2(sorted_denoised_counts + 1),
186 | yticklabels=False,
187 | vmin=0,
188 | vmax=10,
189 | cmap="coolwarm",
190 | center=1,
191 | ax=axs[2],
192 | cbar_kws={"label": "log2(counts + 1)"},
193 | )
194 | axs[2].set_title("denoised counts (prediction)")
195 |
196 | fig.supxlabel("ADTs")
197 | fig.supylabel("cells")
198 | plt.tight_layout()
199 |
200 | """
201 |
202 | def __init__(
203 | self,
204 | raw_count: Union[str, np.ndarray, pd.DataFrame, ad.AnnData],
205 | ambient_profile: Optional[Union[str, np.ndarray, pd.DataFrame]] = None,
206 | nn_layer1: int = 150,
207 | nn_layer2: int = 100,
208 | latent_dim: int = 15,
209 | dropout_prob: float = 0,
210 | feature_type: str = "mRNA",
211 | count_model: str = "binomial",
212 | sparsity: float = 0.9,
213 | batch_key: str = None,
214 | device: str = "auto",
215 | cache_capacity: int = 20000,
216 | verbose: bool = True,
217 | ):
218 | """initialize object"""
219 |
220 | self.logger = get_logger("model", verbose=verbose)
221 | """logging.Logger, the logger for this class.
222 | """
223 |
224 | if device == "auto":
225 | if torch.cuda.is_available():
226 | self.device = torch.device("cuda")
227 | self.logger.info(f"{self.device} is detected and will be used.")
228 | elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
229 | self.device = torch.device("mps")
230 | self.logger.info(f"{self.device} is detected and will be used.")
231 | else:
232 | self.device = torch.device("cpu")
233 | self.logger.info(f"No GPU detected. {self.device} will be used.")
234 | else:
235 | self.device = device
236 | self.logger.info(f"{device} will be used.")
237 |
238 | """str, either "auto, "cpu" or "cuda".
239 | """
240 | self.nn_layer1 = nn_layer1
241 | """int, number of neurons of the 1st layer.
242 | """
243 | self.nn_layer2 = nn_layer2
244 | """int, number of neurons of the 2nd layer.
245 | """
246 | self.latent_dim = latent_dim
247 | """int, number of neurons of the bottleneck layer.
248 | """
249 | self.dropout_prob = dropout_prob
250 | """float, dropout probability of neurons.
251 | """
252 | self.feature_type = feature_type
253 | """str, the feature to be denoised. One of the following:
254 |
255 | | 'mRNA' -- transcriptome
256 | | 'ADT' -- protein counts in CITE-seq
257 | | 'sgRNA' -- sgRNA counts for scCRISPRseq
258 | | 'tag' -- identity barcodes or any data types of super high sparsity. \
259 | E.g., in cell indexing experiments, we would expect a single true signal \
260 | (1) and many negative signals (0) for each cell.
261 | | 'CMO' -- Cell Multiplexing Oligo counts for cell hashing
262 | | 'ATAC' -- peak counts for scATACseq
263 | | By default "mRNA"
264 | """
265 | self.count_model = count_model
266 | """str, the model to generate the UMI count. One of the following:
267 |
268 | | 'binomial' -- binomial model,
269 | | 'poisson' -- poisson model,
270 | | 'zeroinflatedpoisson' -- zeroinflatedpoisson model.
271 | """
272 | self.sparsity = sparsity
273 | """float, the sparsity of expected native signals. (0, 1]. \
274 | Forced to be one in the mode of "sgRNA(s)" and "tag(s)".
275 | """
276 | self.cache_capacity = cache_capacity
277 | """int, the capacity of caching data on GPU. Set a smaller value upon GPU memory issue. By default 20000 cells are cached on GPU/MPS.
278 |
279 | .. versionadded:: 0.7.0
280 | """
281 |
282 | if isinstance(raw_count, ad.AnnData):
283 | if batch_key is not None:
284 | if batch_key not in raw_count.obs.columns:
285 | raise ValueError(f"{batch_key} not found in AnnData.obs.")
286 |
287 | self.logger.info(
288 | f"Found {raw_count.obs[batch_key].nunique()} batches defined by {batch_key} in AnnData.obs. Estimating ambient profile per batch..."
289 | )
290 | batch_id_per_cell = pd.Categorical(raw_count.obs[batch_key]).codes
291 | ambient_profile = np.empty((len(np.unique(batch_id_per_cell)),raw_count.shape[1]))
292 | for batch_id in np.unique(batch_id_per_cell):
293 | subset = raw_count[batch_id_per_cell==batch_id]
294 | ambient_profile[batch_id, :] = subset.X.sum(axis=0) / subset.X.sum()
295 |
296 | # add a mapper to locate the batch id
297 | self.batch_id = batch_id_per_cell
298 | self.n_batch = len(np.unique(batch_id_per_cell))
299 | else:
300 | # get ambient profile from AnnData.uns
301 | if "ambient_profile_all" in raw_count.uns:
302 | self.logger.info(
303 | "Found ambient profile in AnnData.uns['ambient_profile_all']"
304 | )
305 | ambient_profile = raw_count.uns["ambient_profile_all"]
306 | else:
307 | self.logger.info(
308 | "Ambient profile not found in AnnData.uns['ambient_profile'], estimating it by averaging pooled cells..."
309 | )
310 |
311 | elif isinstance(raw_count, str):
312 | # read pickle file into dataframe
313 | raw_count = pd.read_pickle(raw_count)
314 |
315 | elif isinstance(raw_count, np.ndarray):
316 | # convert np.array to pd.DataFrame
317 | raw_count = pd.DataFrame(
318 | raw_count,
319 | index=range(raw_count.shape[0]),
320 | columns=range(raw_count.shape[1]),
321 | )
322 |
323 | elif isinstance(raw_count, pd.DataFrame):
324 | pass
325 | else:
326 | raise TypeError(
327 | f"Expecting str or np.array or pd.DataFrame or AnnData object, but get a {type(raw_count)}"
328 | )
329 |
330 | self.raw_count = raw_count
331 | """raw_count : np.ndarray, raw count matrix.
332 | """
333 | self.n_features = raw_count.shape[1]
334 | """int, number of features.
335 | """
336 | self.cell_id = raw_count.index.to_list() if isinstance(raw_count, pd.DataFrame) else raw_count.obs_names.to_list()
337 | """list, cell id.
338 | """
339 | self.feature_names = raw_count.columns.to_list() if isinstance(raw_count, pd.DataFrame) else raw_count.var_names.to_list()
340 | """list, feature names.
341 | """
342 |
343 | if isinstance(ambient_profile, str):
344 | ambient_profile = pd.read_pickle(ambient_profile)
345 | ambient_profile = ambient_profile.fillna(0).values # missing vals -> zeros
346 | elif isinstance(ambient_profile, pd.DataFrame):
347 | ambient_profile = ambient_profile.fillna(0).values # missing vals -> zeros
348 | elif isinstance(ambient_profile, np.ndarray):
349 | ambient_profile = np.nan_to_num(ambient_profile) # missing vals -> zeros
350 | elif not ambient_profile:
351 | self.logger.info(" Evaluate ambient profile from cells")
352 | if isinstance(raw_count, pd.DataFrame):
353 | ambient_profile = raw_count.sum() / raw_count.sum().sum()
354 | ambient_profile = ambient_profile.fillna(0).values
355 | elif isinstance(raw_count, ad.AnnData):
356 | ambient_profile = np.array(raw_count.X.sum(axis=0)/raw_count.X.sum())
357 | ambient_profile = np.nan_to_num(ambient_profile).flatten()
358 | else:
359 | raise TypeError(
360 | f"Expecting str / np.array / None / pd.DataFrame, but get a {type(ambient_profile)}"
361 | )
362 |
363 | if ambient_profile.squeeze().ndim == 1:
364 | ambient_profile = (
365 | ambient_profile.squeeze()
366 | .reshape(1, -1)
367 | )
368 | # add a mapper to locate the artificial batch id
369 |             self.batch_id = np.zeros(raw_count.shape[0], dtype=int)
370 | self.n_batch = 1
371 |
372 | self.ambient_profile = ambient_profile
373 | """ambient_profile : np.ndarray, the probability of occurrence of each ambient transcript.
374 | """
375 |
376 | self.runtime = None
377 | """int, runtime in seconds.
378 | """
379 | self.loss_values = None
380 | """list, loss values during training.
381 | """
382 | self.trained_model = None
383 | """nn.Module object, added after training.
384 | """
385 | self.native_counts = None
386 | """np.ndarray, denoised counts, added after inference
387 | """
388 | self.bayesfactor = None
389 | """np.ndarray, bayesian factor of whether native signals are present, added after inference
390 | """
391 | self.native_frequencies = None
392 | """np.ndarray, probability of native transcripts (normalized denoised counts), added after inference
393 | """
394 | self.noise_ratio = None
395 | """np.ndarray, noise ratio per cell, added after inference
396 | """
397 | self.feature_assignment = None
398 | """pd.DataFrame, assignment of sgRNA or tag or other feature barcodes, added after inference or assignment
399 | """
400 |
401 | def train(
402 | self,
403 | batch_size: int = 64,
404 | train_size: float = 0.998,
405 | shuffle: bool = True,
406 | kld_weight: float = 1e-5,
407 | lr: float = 1e-3,
408 | lr_step_size: int = 5,
409 | lr_gamma: float = 0.97,
410 | epochs: int = 400,
411 | reconstruction_weight: float = 1,
412 | dropout_prob: float = 0,
413 | save_model: bool = False,
414 | verbose: bool = True,
415 | ):
416 | """train training scar model
417 |
418 | Parameters
419 | ----------
420 | batch_size : int, optional
421 | batch size, by default 64
422 | train_size : float, optional
423 | the size of training samples, by default 0.998
424 | shuffle : bool, optional
425 | whether to shuffle the data, by default True
426 | kld_weight : float, optional
427 | weight of KL loss, by default 1e-5
428 | lr : float, optional
429 | initial learning rate, by default 1e-3
430 |         lr_step_size : int, optional
431 |             period of learning rate decay \
432 |             (see ``torch.optim.lr_scheduler.StepLR``), \
433 |             by default 5
434 | lr_gamma : float, optional
435 | multiplicative factor of learning rate decay, by default 0.97
436 | epochs : int, optional
437 |             training iterations, by default 400
438 | reconstruction_weight : float, optional
439 | weight on reconstruction error, by default 1
440 | dropout_prob : float, optional
441 | dropout probability of neurons, by default 0
442 | save_model : bool, optional
443 |             whether to save the trained model (under development), by default False
444 | verbose : bool, optional
445 | whether to print the details, by default True
446 | Returns
447 | -------
448 | After training, a trained_model attribute will be added.
449 |
450 | """
451 | # Generators
452 | total_dataset = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, cache_capacity=self.cache_capacity)
453 | training_set, validation_set = random_split(total_dataset, [train_size, 1 - train_size])
454 | training_generator = DataLoader(
455 | training_set, batch_size=batch_size, shuffle=shuffle,
456 | drop_last=True
457 | )
458 | self.dataset = total_dataset
459 |
460 | loss_values = []
461 |
462 | # Define model
463 | vae_nets = VAE(
464 | n_features=self.n_features,
465 | nn_layer1=self.nn_layer1,
466 | nn_layer2=self.nn_layer2,
467 | latent_dim=self.latent_dim,
468 | dropout_prob=dropout_prob,
469 | feature_type=self.feature_type,
470 | count_model=self.count_model,
471 | sparsity=self.sparsity,
472 | n_batch=self.n_batch,
473 | verbose=verbose,
474 | ).to(self.device)
475 | # Define optimizer
476 | optim = torch.optim.Adam(vae_nets.parameters(), lr=lr)
477 | scheduler = torch.optim.lr_scheduler.StepLR(
478 | optim, step_size=lr_step_size, gamma=lr_gamma
479 | )
480 |
481 | self.logger.info(f"kld_weight: {kld_weight:.2e}")
482 | self.logger.info(f"learning rate: {lr:.2e}")
483 | self.logger.info(f"lr_step_size: {lr_step_size:d}")
484 | self.logger.info(f"lr_gamma: {lr_gamma:.2f}")
485 |
486 | # Run training
487 | training_start_time = time.time()
488 | with std_out_err_redirect_tqdm() as orig_stdout:
489 | # Initialize progress bar
490 | progress_bar = tqdm(
491 | total=epochs,
492 | file=orig_stdout,
493 | dynamic_ncols=True,
494 | desc="Training",
495 | )
496 | progress_bar.clear()
497 | for _ in range(epochs):
498 | train_tot_loss = 0
499 | train_kld_loss = 0
500 | train_recon_loss = 0
501 |
502 | vae_nets.train()
503 | for x_batch, ambient_freq, batch_id_onehot in training_generator:
504 | optim.zero_grad()
505 | dec_nr, dec_prob, means, var, dec_dp = vae_nets(x_batch, batch_id_onehot)
506 | recon_loss_minibatch, kld_loss_minibatch, loss_minibatch = loss_fn(
507 | x_batch,
508 | dec_nr,
509 | dec_prob,
510 | means,
511 | var,
512 | ambient_freq,
513 | reconstruction_weight=reconstruction_weight,
514 | kld_weight=kld_weight,
515 | dec_dp=dec_dp,
516 | count_model=self.count_model,
517 | )
518 | loss_minibatch.backward()
519 | optim.step()
520 |
521 | train_tot_loss += loss_minibatch.detach().item()
522 | train_recon_loss += recon_loss_minibatch.detach().item()
523 | train_kld_loss += kld_loss_minibatch.detach().item()
524 |
525 | scheduler.step()
526 |
527 | avg_train_tot_loss = train_tot_loss / len(training_generator)
528 | loss_values.append(avg_train_tot_loss)
529 |
530 | progress_bar.set_postfix({"Loss": "{:.4e}".format(avg_train_tot_loss)})
531 | progress_bar.update()
532 |
533 | progress_bar.close()
534 |
535 | if save_model:
536 | torch.save(vae_nets, save_model)
537 |
538 | self.loss_values = loss_values
539 | self.trained_model = vae_nets
540 | self.runtime = time.time() - training_start_time
541 |
542 | # Inference
543 | @torch.no_grad()
544 | def inference(
545 | self,
546 | batch_size=4096,
547 | count_model_inf="poisson",
548 | adjust="micro",
549 | cutoff=3,
550 | round_to_int="stochastic_rounding",
551 | clip_to_obs=False,
552 | get_native_frequencies=False,
553 | moi=None,
554 | ):
555 | """inference infering the expected native signals, noise ratios, Bayesfactors and expected native frequencies
556 |
557 | Parameters
558 | ----------
559 | batch_size : int, optional
560 | batch size, set a small value upon GPU memory issue, by default 4096
561 | count_model_inf : str, optional
562 | inference model for evaluation of ambient presence, by default "poisson"
563 | adjust : str, optional
564 | Only used for calculating Bayesfactors to improve performance. \
565 | One of the following:
566 |
567 | | 'micro' -- adjust the estimated native counts per cell. \
568 | This can overcome the issue of over- or under-estimation of noise.
569 | | 'global' -- adjust the estimated native counts globally. \
570 | This can overcome the issue of over- or under-estimation of noise.
571 | | False -- no adjustment, use the model-returned native counts.
572 | | Defaults to "micro"
573 | cutoff : int, optional
574 | cutoff for Bayesfactors, by default 3
575 | round_to_int : str, optional
576 | whether to round the counts, by default "stochastic_rounding"
577 |
578 | .. versionadded:: 0.4.1
579 |
580 | clip_to_obs : bool, optional
581 | whether to clip the predicted native counts to the observation in order to ensure \
582 | that denoised counts are not greater than the observation, by default False. \
583 | Use it with caution, as it may lead to over-estimation of overall noise.
584 |
585 | .. versionadded:: 0.5.0
586 |
587 | get_native_frequencies : bool, optional
588 | whether to get native frequencies, by default False
589 |
590 | .. versionadded:: 0.7.0
591 |
592 | moi : int, optional (under development)
593 | multiplicity of infection. If assigned, it will allow optimized thresholding, \
594 | which tests a series of cutoffs to find the best one \
595 | based on distributions of infections under given moi.\
596 | See Perturb-seq [Dixit2016]_ for details, by default None
597 | Returns
598 | -------
599 |         After inferring, several attributes will be added, including native_counts, bayesfactor,\
600 | native_frequencies, and noise_ratio. \
601 | A feature_assignment will be added in 'sgRNA' or 'tag' or 'CMO' feature type.
602 | """
603 | n_features = self.n_features
604 | sample_size = self.raw_count.shape[0]
605 |
606 | dt = np.int64 if round_to_int=="stochastic_rounding" else np.float32
607 | native_counts = sparse.lil_matrix((sample_size, n_features), dtype=dt)
608 | noise_ratio = sparse.lil_matrix((sample_size, 1), dtype=np.float32)
609 |
610 | native_frequencies = sparse.lil_matrix((sample_size, n_features), dtype=np.float32) if get_native_frequencies else None
611 |
612 | if self.feature_type.lower() in [
613 | "sgrna",
614 | "sgrnas",
615 | "tag",
616 | "tags",
617 | "cmo",
618 | "cmos",
619 | "atac",
620 | ]:
621 | bayesfactor = sparse.lil_matrix((sample_size, n_features), dtype=np.float32)
622 | else:
623 | bayesfactor = None
624 |
625 | if not batch_size:
626 | batch_size = sample_size
627 | i = 0
628 | generator_full_data = DataLoader(
629 | self.dataset, batch_size=batch_size, shuffle=False
630 | )
631 |
632 | for x_batch_tot, ambient_freq_tot, x_batch_id_onehot_tot in generator_full_data:
633 | minibatch_size = x_batch_tot.shape[
634 | 0
635 | ] # if not the last batch, equals to batch size
636 |
637 | (
638 | native_counts_batch,
639 | bayesfactor_batch,
640 | native_frequencies_batch,
641 | noise_ratio_batch,
642 | ) = self.trained_model.inference(
643 | x_batch_tot,
644 | x_batch_id_onehot_tot,
645 | ambient_freq_tot[0, :],
646 | count_model_inf=count_model_inf,
647 | adjust=adjust,
648 | round_to_int=round_to_int,
649 | clip_to_obs=clip_to_obs,
650 | )
651 | native_counts[
652 | i * batch_size : i * batch_size + minibatch_size, :
653 | ] = native_counts_batch
654 | noise_ratio[
655 | i * batch_size : i * batch_size + minibatch_size, :
656 | ] = noise_ratio_batch
657 | if native_frequencies is not None:
658 | native_frequencies[
659 | i * batch_size : i * batch_size + minibatch_size, :
660 | ] = native_frequencies_batch
661 | if bayesfactor is not None:
662 | bayesfactor[
663 | i * batch_size : i * batch_size + minibatch_size, :
664 | ] = bayesfactor_batch
665 |
666 | i += 1
667 |
668 | self.native_counts = native_counts.tocsr()
669 | self.noise_ratio = noise_ratio.tocsr()
670 | self.bayesfactor = bayesfactor.tocsr() if bayesfactor is not None else None
671 | self.native_frequencies = native_frequencies.tocsr() if native_frequencies is not None else None
672 |
673 | if self.feature_type.lower() in [
674 | "sgrna",
675 | "sgrnas",
676 | "tag",
677 | "tags",
678 | "cmo",
679 | "cmos",
680 | "atac",
681 | ]:
682 | self.assignment(cutoff=cutoff, moi=moi)
683 | else:
684 | self.feature_assignment = None
685 |
686 | def assignment(self, cutoff=3, moi=None):
687 | """assignment assignment of feature barcodes. Re-run it can test different cutoffs for your experiments.
688 |
689 | Parameters
690 | ----------
691 | cutoff : int, optional
692 | cutoff for Bayesfactors, by default 3
693 | moi : float, optional
694 | multiplicity of infection. (under development)\
695 | If assigned, it will allow optimized thresholding,\
696 | which tests a series of cutoffs to find the best one \
697 | based on distributions of infections under given moi.\
698 | See Perturb-seq [Dixit2016]_, by default None
699 | Returns
700 | -------
701 |         After running, an attribute 'feature_assignment' will be added,\
702 | in 'sgRNA' or 'tag' or 'CMO' feature type.
703 | Raises
704 | ------
705 | NotImplementedError
706 | if moi is not None
707 | """
708 |
709 | feature_assignment = pd.DataFrame(
710 | index=self.cell_id, columns=[self.feature_type, f"n_{self.feature_type}"]
711 | )
712 | bayesfactor_df = pd.DataFrame(
713 | self.bayesfactor.toarray(), index=self.cell_id, columns=self.feature_names
714 | )
715 | bayesfactor_df[bayesfactor_df < cutoff] = 0 # Apply the cutoff for Bayesfactors
716 |
717 | for cell, row in bayesfactor_df.iterrows():
718 | bayesfactor_max = row[row == row.max()]
719 | if row.max() == 0:
720 | feature_assignment.loc[cell, f"n_{self.feature_type}"] = 0
721 | feature_assignment.loc[cell, self.feature_type] = ""
722 | elif len(bayesfactor_max) == 1:
723 | feature_assignment.loc[cell, f"n_{self.feature_type}"] = 1
724 | feature_assignment.loc[cell, self.feature_type] = bayesfactor_max.index[
725 | 0
726 | ]
727 | else:
728 | feature_assignment.loc[cell, f"n_{self.feature_type}"] = len(
729 | bayesfactor_max
730 | )
731 | feature_assignment.loc[cell, self.feature_type] = (", ").join(
732 | bayesfactor_max.index.astype(str)
733 | )
734 |
735 | self.feature_assignment = feature_assignment
736 |
737 | if moi:
738 | raise NotImplementedError
739 |
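# Usage sketch (hypothetical object name `scarObj`): assignment() can be re-run
# cheaply after inference() to compare Bayesfactor cutoffs:
#
#   scarObj.assignment(cutoff=3)
#   lenient = scarObj.feature_assignment.copy()
#   scarObj.assignment(cutoff=10)   # higher cutoff, more stringent assignment
#   stringent = scarObj.feature_assignment
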
740 | class UMIDataset(Dataset):
741 | """Characterizes dataset for PyTorch"""
742 |
743 | def __init__(self, raw_count, ambient_profile, batch_id, device, cache_capacity=20000):
744 | """Initialization"""
745 |
746 | self.raw_count = torch.from_numpy(raw_count.fillna(0).values).int() if isinstance(raw_count, pd.DataFrame) else raw_count
747 | self.ambient_profile = torch.from_numpy(ambient_profile).float().to(device)
748 | self.batch_id = torch.from_numpy(batch_id).to(torch.int64).to(device)
749 | self.batch_onehot = torch.from_numpy(np.eye(len(np.unique(batch_id)))).to(torch.int64).to(device)
750 | self.device = device
751 | self.cache_capacity = cache_capacity
752 |
753 | # Cache data
754 | self.cache = {}
755 |
756 | def __len__(self):
757 | """Denotes the total number of samples"""
758 | return self.raw_count.shape[0]
759 |
760 | def __getitem__(self, index):
761 | """Generates one sample of data"""
762 |
763 | if index in self.cache:
764 | return self.cache[index]
765 | else:
766 | # Select samples
767 | sc_count = self.raw_count[index].to(self.device) if isinstance(self.raw_count, torch.Tensor) else torch.from_numpy(self.raw_count[index].X.toarray().flatten()).int().to(self.device)
768 | sc_ambient = self.ambient_profile[self.batch_id[index], :]
769 | sc_batch_id_onehot = self.batch_onehot[self.batch_id[index], :]
770 |
771 | # Cache samples
772 | if len(self.cache) <= self.cache_capacity:
773 | self.cache[index] = (sc_count, sc_ambient, sc_batch_id_onehot)
774 |
775 | return sc_count, sc_ambient, sc_batch_id_onehot
776 |
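# Construction sketch (illustrative; mirrors how model.train() wires the
# dataset, assuming `counts` is a cells-by-features pd.DataFrame,
# `ambient_profile` is a (1, n_features) float array, and there is one batch):
#
#   ds = UMIDataset(counts, ambient_profile, np.zeros(len(counts), dtype=int),
#                   device="cpu", cache_capacity=20000)
#   loader = DataLoader(ds, batch_size=64, shuffle=True, drop_last=True)
#   x, amb, onehot = next(iter(loader))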
--------------------------------------------------------------------------------
/scar/main/_setup.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import numpy as np
3 | import pandas as pd
4 | import matplotlib.pyplot as plt
5 | import seaborn as sns
6 | from anndata import AnnData
7 | import torch
8 | from torch.distributions.multinomial import Multinomial
9 |
10 | from ._utils import get_logger
11 |
12 |
13 | def setup_anndata(
14 | adata: AnnData,
15 | raw_adata: AnnData,
16 | feature_type: Union[str, list] = None,
17 | prob: float = 0.995,
18 | min_raw_counts: int = 2,
19 | iterations: int = 3,
20 | n_batch: int = None,
21 | sample: int = None,
22 | kneeplot: bool = True,
23 | verbose: bool = True,
24 | figsize: tuple = (6, 6),
25 | ):
26 | """Calculate ambient profile for relevant features
27 |
28 | Identify the cell-free droplets through a multinomial distribution. See EmptyDrops [Lun2019]_ for details.
29 |
30 |
31 | Parameters
32 | ----------
33 | adata : AnnData
34 |         A filtered adata object, loaded from filtered_feature_bc_matrix using `scanpy.read`; gene filtering is recommended to save memory
35 |     raw_adata : AnnData
36 |         A raw adata object, loaded from raw_feature_bc_matrix using `scanpy.read`
37 | feature_type : Union[str, list], optional
38 | Feature type, e.g. 'Gene Expression', 'Antibody Capture', 'CRISPR Guide Capture' or 'Multiplexing Capture', all feature types are calculated if None, by default None
39 | prob : float, optional
40 |         The probability threshold: a droplet is considered cell-free if its joint probability (the product over all genes) of containing only ambient RNA is greater than prob, by default 0.995
41 | min_raw_counts : int, optional
42 | Total counts filter for raw_adata, filtering out low counts to save memory, by default 2
43 | iterations : int, optional
44 | Total iterations, by default 3
45 | n_batch : int, optional
46 |         Total number of batches; set it to a larger number if an out-of-memory error occurs, by default None
47 | sample : int, optional
48 |         Number of droplets to randomly sample for testing; if greater than the total number of droplets, all droplets are used. By default None (use all droplets)
49 | kneeplot : bool, optional
50 | Kneeplot to show subpopulations of droplets, by default True
51 | verbose : bool, optional
52 |         Whether to display messages, by default True
53 |     figsize : tuple, optional
54 | Figure size, by default (6, 6)
55 |
56 | Returns
57 | -------
58 |     The estimated ambient profile(s) are added to `adata.uns`
59 |
60 | Examples
61 | ---------
62 | .. plot::
63 | :context: close-figs
64 |
65 | import scanpy as sc
66 | from scar import setup_anndata
67 | # read filtered data
68 |         adata = sc.read_10x_h5(filename='500_hgmm_3p_LT_Chromium_Controller_filtered_feature_bc_matrix.h5',
69 | backup_url='https://cf.10xgenomics.com/samples/cell-exp/6.1.0/500_hgmm_3p_LT_Chromium_Controller/500_hgmm_3p_LT_Chromium_Controller_filtered_feature_bc_matrix.h5');
70 | adata.var_names_make_unique();
71 | # read raw data
72 |         adata_raw = sc.read_10x_h5(filename='500_hgmm_3p_LT_Chromium_Controller_raw_feature_bc_matrix.h5',
73 | backup_url='https://cf.10xgenomics.com/samples/cell-exp/6.1.0/500_hgmm_3p_LT_Chromium_Controller/500_hgmm_3p_LT_Chromium_Controller_raw_feature_bc_matrix.h5');
74 | adata_raw.var_names_make_unique();
75 | # gene and cell filter
76 | sc.pp.filter_genes(adata, min_counts=200);
77 | sc.pp.filter_genes(adata, max_counts=6000);
78 | sc.pp.filter_cells(adata, min_genes=200);
79 | # setup anndata
80 | setup_anndata(
81 | adata,
82 | adata_raw,
83 | feature_type = "Gene Expression",
84 | prob = 0.975,
85 | min_raw_counts = 2,
86 | kneeplot = True,
87 | )
88 | """
89 |
90 | setup_logger = get_logger("setup_anndata", verbose=verbose)
91 |
92 | if feature_type is None:
93 | feature_type = adata.var["feature_types"].unique()
94 | elif isinstance(feature_type, str):
95 | feature_type = [feature_type]
96 |
97 | # take subset genes to save memory
98 | # raw_adata._inplace_subset_var(raw_adata.var_names.isin(adata.var_names))
99 | # raw_adata._inplace_subset_obs(raw_adata.X.sum(axis=1) >= min_raw_counts)
100 | raw_adata = raw_adata[:, raw_adata.var_names.isin(adata.var_names)]
101 | raw_adata = raw_adata[raw_adata.X.sum(axis=1) >= min_raw_counts]
102 |
103 | raw_adata.obs["total_counts"] = raw_adata.X.sum(axis=1)
104 |
105 | if sample is not None:
106 | sample = int(sample)
107 | setup_logger.info(
108 | f"Randomly sample {sample:d} droplets from {raw_adata.shape[0]:d} droplets."
109 | )
110 | else:
111 | sample = raw_adata.shape[0]
112 | setup_logger.info(f"Use all {sample:d} droplets.")
113 |
114 | # check n_batch
115 | if n_batch is None:
116 | n_batch = int(np.ceil(sample / 5000))
117 | else:
118 | n_batch = int(n_batch)
119 |
120 |     # check whether each batch contains too many droplets
121 | if sample / n_batch > 5000:
122 | setup_logger.info(
123 |             "The number of droplets per batch is large; this may cause memory issues. Consider increasing n_batch."
124 | )
125 |
126 | idx = np.random.choice(
127 | raw_adata.shape[0], size=min(raw_adata.shape[0], sample), replace=False
128 | )
129 | raw_adata = raw_adata[idx]
130 |
131 |     # initial estimation of the ambient profile; it will be updated iteratively
132 | ambient_prof = raw_adata.X.sum(axis=0) / raw_adata.X.sum()
133 |
134 | setup_logger.info(f"Estimating ambient profile for {feature_type}...")
135 |
136 | i = 0
137 | while i < iterations:
138 | # calculate joint probability (log) of being cell-free droplets for each droplet
139 | log_prob = []
140 | batch_idx = np.floor(
141 |             np.arange(raw_adata.shape[0]) / raw_adata.shape[0] * n_batch
142 | )
143 | for b in range(n_batch):
144 | try:
145 | count_batch = raw_adata[batch_idx == b].X.astype(int).toarray()
146 | except MemoryError:
147 | raise MemoryError("use more batches by setting a higher n_batch")
148 | log_prob_batch = Multinomial(
149 | probs=torch.tensor(ambient_prof), validate_args=False
150 | ).log_prob(torch.Tensor(count_batch))
151 | log_prob.append(log_prob_batch)
152 |
153 | log_prob = np.concatenate(log_prob, axis=0)
154 | raw_adata.obs["log_prob"] = log_prob
155 | raw_adata.obs["droplets"] = "other droplets"
156 |
157 | # cell-containing droplets
158 | raw_adata.obs.loc[
159 | raw_adata.obs_names.isin(adata.obs_names), "droplets"
160 | ] = "cells"
161 |
162 | # identify cell-free droplets
163 | raw_adata.obs["droplets"] = raw_adata.obs["droplets"].mask(
164 | raw_adata.obs["log_prob"] >= np.log(prob) * raw_adata.shape[1],
165 | "cell-free droplets",
166 | )
167 | emptydrops = raw_adata[raw_adata.obs["droplets"] == "cell-free droplets"]
168 |
169 | if emptydrops.shape[0] < 50:
170 |             raise Exception("Too few empty droplets! Lower the prob parameter.")
171 |
172 | ambient_prof = emptydrops.X.sum(axis=0) / emptydrops.X.sum()
173 |
174 | i += 1
175 |
176 | setup_logger.info(f"Iteration: {i:d}")
177 |
178 | # update ambient profile for each feature type
179 | for ft in feature_type:
180 | tmp = emptydrops[:, emptydrops.var["feature_types"] == ft]
181 | adata.uns[f"ambient_profile_{ft}"] = pd.DataFrame(
182 | tmp.X.sum(axis=0).reshape(-1, 1) / tmp.X.sum(),
183 | index=tmp.var_names,
184 | columns=[f"ambient_profile_{ft}"],
185 | )
186 |
187 | setup_logger.info(f"Estimated ambient profile for {ft} saved in adata.uns")
188 |
189 | # update ambient profile for all feature types
190 |     adata.uns["ambient_profile_all"] = pd.DataFrame(
191 |         emptydrops.X.sum(axis=0).reshape(-1, 1) / emptydrops.X.sum(),
192 |         index=emptydrops.var_names,
193 |         columns=["ambient_profile_all"],
194 |     )
195 |
196 | setup_logger.info("Estimated ambient profile for all features saved in adata.uns")
197 |
198 | if kneeplot:
199 | _, axs = plt.subplots(2, figsize=figsize)
200 |
201 | all_droplets = raw_adata.obs.copy()
202 | all_droplets = (
203 | all_droplets.sort_values(by="total_counts", ascending=False)
204 | .reset_index()
205 | .rename_axis("rank_by_counts")
206 | .reset_index()
207 | )
208 | all_droplets = all_droplets.loc[all_droplets["total_counts"] >= min_raw_counts]
209 | all_droplets = all_droplets.set_index("index").rename_axis("cells")
210 | all_droplets = (
211 | all_droplets.sort_values(by="log_prob", ascending=True)
212 | .reset_index()
213 | .rename_axis("rank_by_log_prob")
214 | .reset_index()
215 | .set_index("cells")
216 | )
217 |
218 | ax = sns.lineplot(
219 | data=all_droplets,
220 | x="rank_by_counts",
221 | y="total_counts",
222 | hue="droplets",
223 | hue_order=["cells", "other droplets", "cell-free droplets"],
224 | palette=sns.color_palette()[-3:],
225 | markers=False,
226 | lw=2,
227 | ci=None,
228 | ax=axs[0],
229 | )
230 |
231 | ax.set_xscale("log")
232 | ax.set_yscale("log")
233 | ax.set_xlabel("")
234 | ax.set_title("cell-free droplets have lower counts")
235 |
236 | all_droplets["prob"] = np.exp(all_droplets["log_prob"])
237 | ax = sns.lineplot(
238 | data=all_droplets,
239 | x="rank_by_log_prob",
240 | y="prob",
241 | hue="droplets",
242 | hue_order=["cells", "other droplets", "cell-free droplets"],
243 | palette=sns.color_palette()[-3:],
244 | markers=False,
245 | lw=2,
246 | ci=None,
247 | ax=axs[1],
248 | )
249 | ax.set_xscale("log")
250 | ax.set_xlabel("sorted droplets")
251 | ax.set_title("cell-free droplets have relatively higher probs")
252 |
253 | plt.tight_layout()
254 |
--------------------------------------------------------------------------------
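A follow-up sketch, assuming the docstring example above has been run: the profile that setup_anndata() stores in `adata.uns` can be handed to the denoising model, mirroring how the test file at the end of this repository constructs `model(raw_count=..., ambient_profile=...)`. The epoch and batch-size values are illustrative, and passing a dense array for `raw_count` is an assumption based on the tests.

from scar import model

# key written by setup_anndata for feature_type="Gene Expression"
ambient_profile = adata.uns["ambient_profile_Gene Expression"]

scar_model = model(
    raw_count=adata.to_df().values,   # dense cell-by-gene count matrix
    ambient_profile=ambient_profile,
    feature_type="mRNA",
)
scar_model.train(epochs=40, batch_size=64)
scar_model.inference()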
/scar/main/_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | # Create a named logger and attach a single stream handler
4 | def get_logger(name, verbose=True):
5 | logger = logging.getLogger(name)
6 | logger.handlers.clear()
7 | logger.setLevel(logging.INFO if verbose else logging.WARNING)
8 | formatter = logging.Formatter(
9 | fmt="%(asctime)s|%(levelname)s|%(name)s|%(message)s",
10 | datefmt="%Y-%m-%d %H:%M:%S",
11 | )
12 | handler = logging.StreamHandler()
13 | handler.setFormatter(formatter)
14 | logger.addHandler(handler)
15 | return logger
16 |
--------------------------------------------------------------------------------
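A short usage sketch for get_logger(): the verbose flag maps directly to the logger level, so INFO messages are emitted only when verbose=True.

from scar.main._utils import get_logger

logger = get_logger("demo", verbose=True)
logger.info("shown: level is INFO")
quiet_logger = get_logger("quiet_demo", verbose=False)
quiet_logger.info("suppressed: level is WARNING")
quiet_logger.warning("still shown")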
/scar/main/_vae.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | The variational autoencoder
4 | """
5 |
6 | import numpy as np
7 | from scipy import stats
8 |
9 | import torch
10 | from torch import nn
11 |
12 | from ._activation_functions import mytanh, hnormalization, mysoftplus
13 | from ._utils import get_logger
14 |
15 | #########################################################################
16 | ## Variational autoencoder
17 | #########################################################################
18 |
19 |
20 | class VAE(nn.Module):
21 |     """A variational autoencoder
22 |
23 | Parameters
24 | ----------
25 | n_features : int
26 | number of features (e.g. mRNA, sgRNA, ADT, tag, CMO, ...)
27 | nn_layer1 : int, optional
28 | number of neurons in the 1st layer, by default 150
29 | nn_layer2 : int, optional
30 | number of neurons in the 2nd layer, by default 100
31 | latent_dim : int, optional
32 | number of neurons in the bottleneck layer, by default 15
33 |     dropout_prob : float, optional
34 | dropout probability, by default 0
35 | feature_type : str, optional
36 | the feature to be denoised, either of 'mRNA', 'sgRNA', 'ADT', 'tag', 'CMO', 'ATAC', by default "mRNA"
37 | count_model : str, optional
38 | the model to generate the UMI count, either of "binomial", "poisson", "zeroinflatedpoisson", by default "binomial"
39 | sparsity : float, optional
40 | the sparsity of expected data, by default 0.9
41 | verbose : bool, optional
42 | whether to display information, by default True
43 | """
44 |
45 | def __init__(
46 | self,
47 | n_features,
48 | nn_layer1=150,
49 | nn_layer2=100,
50 | latent_dim=15,
51 | dropout_prob=0,
52 | feature_type="mRNA",
53 | count_model="binomial",
54 | n_batch=1,
55 | sparsity=0.9,
56 | verbose=True,
57 | ):
58 | super().__init__()
59 | assert feature_type.lower() in [
60 | "mrna",
61 | "mrnas",
62 | "sgrna",
63 | "sgrnas",
64 | "adt",
65 | "adts",
66 | "tag",
67 | "tags",
68 | "cmo",
69 | "cmos",
70 | "atac",
71 | ]
72 | assert count_model.lower() in ["binomial", "poisson", "zeroinflatedpoisson"]
73 |         # force sparsity to 1 for the "sgRNA", "tag", and "CMO" feature types
74 | if feature_type.lower() in [
75 | "sgrna",
76 | "sgrnas",
77 | "tag",
78 | "tags",
79 | "cmo",
80 | "cmos",
81 | ]:
82 | sparsity = 1
83 |
84 | self.encoder = Encoder(
85 | n_features, n_batch, nn_layer1, nn_layer2, latent_dim, dropout_prob
86 | )
87 | self.decoder = Decoder(
88 | n_features,
89 | n_batch,
90 | nn_layer1,
91 | nn_layer2,
92 | latent_dim,
93 | dropout_prob,
94 | count_model,
95 | sparsity,
96 | )
97 |
98 | vae_logger = get_logger("VAE", verbose=verbose)
99 |
100 | vae_logger.info("Running VAE using the following param set:")
101 | vae_logger.info(f"...denoised count type: {feature_type}")
102 | vae_logger.info(f"...count model: {count_model}")
103 | vae_logger.info(f"...num_input_feature: {n_features:d}")
104 | vae_logger.info(f"...NN_layer1: {nn_layer1:d}")
105 | vae_logger.info(f"...NN_layer2: {nn_layer2:d}")
106 | vae_logger.info(f"...latent_space: {latent_dim:d}")
107 | vae_logger.info(f"...dropout_prob: {dropout_prob:.2f}")
108 | vae_logger.info(f"...expected data sparsity: {sparsity:.2f}")
109 |
110 | def forward(self, input_matrix, batch_id_onehot=None):
111 | """forward function"""
112 | sampling, means, var = self.encoder(input_matrix, batch_id_onehot)
113 | dec_nr, dec_prob, dec_dp = self.decoder(sampling, batch_id_onehot)
114 | return dec_nr, dec_prob, means, var, dec_dp
115 |
116 | @torch.no_grad()
117 | def inference(
118 | self,
119 | input_matrix,
120 | batch_id_onehot,
121 | amb_prob,
122 | count_model_inf="poisson",
123 | adjust="micro",
124 | round_to_int="stochastic_rounding",
125 | clip_to_obs=False,
126 | ):
127 | """
128 |         Infer the presence of native signals.
129 | """
130 | assert count_model_inf.lower() in ["poisson", "binomial", "zeroinflatedpoisson"]
131 | assert adjust in [False, "global", "micro"]
132 |
133 | # Estimate native signals
134 | dec_nr, dec_prob, _, _, _ = self.forward(input_matrix, batch_id_onehot)
135 |
136 | # Copy tensor to CPU
137 | input_matrix_np = input_matrix.cpu().numpy()
138 | noise_ratio = dec_nr.cpu().numpy().reshape(-1, 1)
139 | nat_prob = dec_prob.cpu().numpy()
140 | amb_prob = amb_prob.cpu().numpy().reshape(1, -1)
141 |
142 | total_count_per_cell = input_matrix_np.sum(axis=1).reshape(-1, 1)
143 | expected_native_counts = total_count_per_cell * (1 - noise_ratio) * nat_prob
144 | expected_amb_counts = total_count_per_cell * noise_ratio * amb_prob
145 | tot_amb = expected_amb_counts.sum(axis=1).reshape(-1, 1)
146 |
147 | if not round_to_int:
148 | pass
149 | elif round_to_int.lower() == "stochastic_rounding":
150 | expected_native_counts = (
151 | np.floor(expected_native_counts)
152 | + (
153 | np.random.rand(*expected_native_counts.shape)
154 | < expected_native_counts - np.floor(expected_native_counts)
155 | ).astype(int)
156 | ).astype(int)
157 |
158 | expected_amb_counts = (
159 | np.floor(expected_amb_counts)
160 | + (
161 | np.random.rand(*expected_amb_counts.shape)
162 | < expected_amb_counts - np.floor(expected_amb_counts)
163 | ).astype(int)
164 | ).astype(int)
165 |
166 | if clip_to_obs:
167 | expected_native_counts = np.clip(
168 | expected_native_counts,
169 | a_min=np.zeros_like(input_matrix_np),
170 | a_max=input_matrix_np,
171 | )
172 |
173 | if not adjust:
174 | adjust = 0
175 | elif adjust == "global":
176 | adjust = (total_count_per_cell.sum() - tot_amb.sum()) / len(
177 | input_matrix_np.flatten()
178 | )
179 | elif adjust == "micro":
180 | adjust = (total_count_per_cell - tot_amb) / input_matrix_np.shape[1]
181 | adjust = np.repeat(adjust, input_matrix_np.shape[1], axis=1)
182 |
183 | ### Calculate the Bayesian factors
184 | # The probability that observed UMI counts do not purely
185 | # come from expected distribution of ambient signals.
186 |         # H1: x is drawn from a distribution (binomial, poisson, or
187 |         # zeroinflatedpoisson) with prob > amb_prob
188 |         # H2: x is drawn from a distribution (binomial, poisson, or
189 |         # zeroinflatedpoisson) with prob = amb_prob
190 |
191 | if count_model_inf.lower() == "binomial":
192 | probs_h1 = stats.binom.logcdf(input_matrix_np, tot_amb + adjust, amb_prob)
193 | probs_h2 = stats.binom.logpmf(input_matrix_np, tot_amb + adjust, amb_prob)
194 |
195 | elif count_model_inf.lower() == "poisson":
196 | probs_h1 = stats.poisson.logcdf(
197 | input_matrix_np, expected_amb_counts + adjust
198 | )
199 | probs_h2 = stats.poisson.logpmf(
200 | input_matrix_np, expected_amb_counts + adjust
201 | )
202 |
203 | elif count_model_inf.lower() == "zeroinflatedpoisson":
204 | raise NotImplementedError
205 |
206 | bayesian_factor = np.clip(probs_h1 - probs_h2 + 1e-22, -709.78, 709.78)
207 | bayesian_factor = np.exp(bayesian_factor)
208 |
209 | return (
210 | expected_native_counts,
211 | bayesian_factor,
212 | dec_prob.cpu().numpy(),
213 | dec_nr.cpu().numpy(),
214 | )
215 |
216 |
217 | #########################################################################
218 | ## Encoder
219 | #########################################################################
220 | class Encoder(nn.Module):
221 | """
222 | Encoder that takes the original expressions of feature barcodes and produces the encoding.
223 |
224 | Consists of 2 FC layers.
225 | """
226 |
227 | def __init__(self, n_features, n_batch, nn_layer1, nn_layer2, latent_dim, dropout_prob):
228 | """initialization"""
229 | super().__init__()
230 | self.activation = nn.SELU()
231 | # if n_batch > 1:
232 | # n_features += n_batch
233 | self.fc1 = nn.Linear(n_features + n_batch, nn_layer1)
234 | self.bn1 = nn.BatchNorm1d(nn_layer1, momentum=0.01, eps=0.001)
235 | self.dp1 = nn.Dropout(p=dropout_prob)
236 | self.fc2 = nn.Linear(nn_layer1, nn_layer2)
237 | self.bn2 = nn.BatchNorm1d(nn_layer2, momentum=0.01, eps=0.001)
238 | self.dp2 = nn.Dropout(p=dropout_prob)
239 |
240 | self.linear_means = nn.Linear(nn_layer2, latent_dim)
241 | self.linear_log_vars = nn.Linear(nn_layer2, latent_dim)
242 | self.z_transformation = nn.Softmax(dim=-1)
243 |
244 | def reparametrize(self, means, log_vars):
245 | """reparameterization"""
246 | var = log_vars.exp() + 1e-4
247 | return torch.distributions.Normal(means, var.sqrt()).rsample(), var
248 |
249 | def forward(self, input_matrix, batch_id_onehot):
250 | """forward function"""
251 | input_matrix = (input_matrix + 1).log2() # log transformation of count data
252 | input_matrix = torch.cat([input_matrix, batch_id_onehot], 1)
253 | enc = self.fc1(input_matrix)
254 | enc = self.bn1(enc)
255 | enc = self.activation(enc)
256 | enc = self.dp1(enc)
257 | enc = self.fc2(enc)
258 | enc = self.bn2(enc)
259 | enc = self.activation(enc)
260 | enc = torch.clamp(enc, min=None, max=1e7)
261 | enc = self.dp2(enc)
262 |
263 | means = self.linear_means(enc)
264 | log_vars = self.linear_log_vars(enc)
265 | sampling, var = self.reparametrize(means, log_vars)
266 | latent_transform = self.z_transformation(sampling)
267 |
268 | return latent_transform, means, var
269 |
270 |
271 | #########################################################################
272 | ## Decoder
273 | #########################################################################
274 | class Decoder(nn.Module):
275 | """
276 | A decoder model that takes the encodings and a batch (source) matrix and produces decodings.
277 |
278 | Made up of 2 FC layers.
279 | """
280 |
281 | def __init__(
282 | self,
283 | n_features,
284 | n_batch,
285 | nn_layer1,
286 | nn_layer2,
287 | latent_dim,
288 | dropout_prob,
289 | count_model,
290 | sparsity,
291 | ):
292 | """initialization"""
293 | super().__init__()
294 | self.activation = nn.SELU()
295 | self.normalization_native_freq = hnormalization()
296 | self.noise_activation = mytanh()
297 | self.activation_native_freq = mysoftplus(sparsity)
298 | self.fc4 = nn.Linear(latent_dim + n_batch, nn_layer2)
299 | self.bn4 = nn.BatchNorm1d(nn_layer2, momentum=0.01, eps=0.001)
300 | self.dp4 = nn.Dropout(p=dropout_prob)
301 | self.fc5 = nn.Linear(nn_layer2, nn_layer1)
302 | self.bn5 = nn.BatchNorm1d(nn_layer1, momentum=0.01, eps=0.001)
303 | self.dp5 = nn.Dropout(p=dropout_prob)
304 |
305 | self.noise_fc = nn.Linear(nn_layer1, 1)
306 | self.out_fc = nn.Linear(nn_layer1, n_features)
307 | self.count_model = count_model
308 | if count_model.lower() == "zeroinflatedpoisson":
309 | self.dropoutprob = nn.Linear(nn_layer1, 1)
310 | self.dropout_activation = mytanh()
311 |
312 | def forward(self, sampling, batch_id_onehot):
313 | """forward function"""
314 | # decoder
315 | cond_sampling = torch.cat([sampling, batch_id_onehot], 1)
316 | dec = self.fc4(cond_sampling)
317 | dec = self.bn4(dec)
318 | dec = self.activation(dec)
319 | dec = self.fc5(dec)
320 | dec = self.bn5(dec)
321 | dec = self.activation(dec)
322 | dec = torch.clamp(dec, min=None, max=1e7)
323 |
324 | # final layers to produce prob parameters
325 | dec_prob = self.out_fc(dec)
326 | dec_prob = self.activation_native_freq(dec_prob)
327 | dec_prob = self.normalization_native_freq(dec_prob)
328 |
329 | # final layers to produce noise_ratio parameters
330 | dec_nr = self.noise_fc(dec)
331 | dec_nr = self.noise_activation(dec_nr)
332 |
333 | # final layers to learn the dropout probability
334 | if self.count_model.lower() == "zeroinflatedpoisson":
335 | dec_dp = self.dropoutprob(dec)
336 | dec_dp = self.dropout_activation(dec_dp)
337 | dec_dp = torch.nan_to_num(dec_dp, nan=1e-7)
338 | else:
339 | dec_dp = None
340 |
341 | return dec_nr, dec_prob, dec_dp
342 |
--------------------------------------------------------------------------------
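A shape-check sketch (not part of the repository) for the VAE above, using the default binomial count model and a single batch; the layer sizes and cell counts are arbitrary.

import torch
from scar.main._vae import VAE

vae = VAE(n_features=100, nn_layer1=64, nn_layer2=32, latent_dim=10)
counts = torch.randint(0, 20, (8, 100)).float()  # 8 cells x 100 features
batch_onehot = torch.ones((8, 1))                # one-hot encoding of the single batch

dec_nr, dec_prob, means, var, dec_dp = vae(counts, batch_onehot)
assert dec_nr.shape == (8, 1)      # per-cell noise ratio
assert dec_prob.shape == (8, 100)  # per-cell native frequencies
assert dec_dp is None              # dropout prob is returned only for zeroinflatedpoisson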
/scar/test/ambient_profile.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/scar/test/ambient_profile.pickle
--------------------------------------------------------------------------------
/scar/test/citeseq_ambient_profile.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/scar/test/citeseq_ambient_profile.pickle
--------------------------------------------------------------------------------
/scar/test/citeseq_native_counts.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/scar/test/citeseq_native_counts.pickle
--------------------------------------------------------------------------------
/scar/test/citeseq_raw_counts.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/scar/test/citeseq_raw_counts.pickle
--------------------------------------------------------------------------------
/scar/test/output_assignment.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/scar/test/output_assignment.pickle
--------------------------------------------------------------------------------
/scar/test/raw_counts.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/scar/test/raw_counts.pickle
--------------------------------------------------------------------------------
/scar/test/test_activation_functions.py:
--------------------------------------------------------------------------------
1 | """ This tests the activation functions. """
2 | import decimal
3 | import unittest
4 | import numpy
5 | import torch
6 | from decimal import Decimal
7 | from scar.main._activation_functions import mytanh, hnormalization, mysoftplus
8 |
9 |
10 | class ActivationFunctionsTest(unittest.TestCase):
11 | """
12 | Test activation_functions.py functions.
13 | """
14 |
15 | def test_mytanh(self):
16 | """
17 | Test mytanh().
18 | """
19 | self.assertEqual(
20 | Decimal(mytanh()(torch.tensor(1.0, dtype=torch.float32)).item()).quantize(
21 | decimal.Decimal(".01"), rounding=decimal.ROUND_DOWN
22 | ),
23 | Decimal(0.88).quantize(decimal.Decimal(".01"), rounding=decimal.ROUND_DOWN),
24 | )
25 |
26 | def test_hnormalization(self):
27 | """
28 | Test hnormalization().
29 | """
30 | self.assertTrue(
31 | torch.allclose(
32 | hnormalization()(torch.tensor(numpy.full((20, 8), 1))).double(),
33 | torch.tensor(numpy.full((20, 8), 0.1250)),
34 | )
35 | )
36 |
37 | def test_mysoftplus(self):
38 | """
39 | Test mysoftplus().
40 | """
41 | self.assertTrue(
42 | torch.allclose(
43 | mysoftplus()(
44 | torch.tensor(numpy.full((20, 8), 0.01), dtype=torch.float32)
45 | ).double(),
46 | torch.tensor(numpy.full((20, 8), 0.3849)),
47 | )
48 | )
49 |
--------------------------------------------------------------------------------
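The constants asserted above can be reproduced interactively; per the tests, mytanh maps 1.0 to roughly 0.88 and hnormalization maps any constant 20x8 matrix to 1/8 = 0.125 (a float input is assumed to be accepted, as the integer-input test implies).

import torch
from scar.main._activation_functions import mytanh, hnormalization

print(mytanh()(torch.tensor(1.0)))                # ~0.88, as asserted above
print(hnormalization()(torch.ones(20, 8))[0, 0])  # 0.1250, i.e. 1/8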
/scar/test/test_scar.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | from sklearn.metrics.pairwise import euclidean_distances
4 |
5 | from scar import model, data_generator
6 | import unittest
7 |
8 |
9 | class ScarIntegration(unittest.TestCase):
10 | """
11 | Functional testing
12 | """
13 |
14 | def test_scar(self):
15 | raw_count = pd.read_pickle("scar/test/raw_counts.pickle")
16 | ambient_profile = pd.read_pickle("scar/test/ambient_profile.pickle")
17 | expected_output = pd.read_pickle("scar/test/output_assignment.pickle")
18 |
19 | scarObj = model(
20 | raw_count=raw_count.values,
21 | ambient_profile=ambient_profile,
22 | feature_type="sgRNAs",
23 | )
24 |
25 | scarObj.train(epochs=40, batch_size=32)
26 |
27 | scarObj.inference()
28 |
29 | self.assertTrue(scarObj.feature_assignment.equals(expected_output))
30 |
31 | def test_scar_data_generator(self):
32 | """
33 | Functional testing of data_generator module
34 | """
35 | np.random.seed(8)
36 | citeseq = data_generator.citeseq(6000, 6, 50)
37 | citeseq.generate()
38 |
39 | citeseq_raw_counts = pd.read_pickle("scar/test/citeseq_raw_counts.pickle")
40 |
41 | self.assertTrue(np.array_equal(citeseq.obs_count, citeseq_raw_counts.values, equal_nan=True))
42 |
43 | def test_scar_citeseq(self):
44 | """
45 | Functional testing of scAR
46 | """
47 | citeseq_raw_counts = pd.read_pickle("scar/test/citeseq_raw_counts.pickle")
48 | citeseq_ambient_profile = pd.read_pickle(
49 | "scar/test/citeseq_ambient_profile.pickle"
50 | )
51 | citeseq_native_signals = pd.read_pickle(
52 | "scar/test/citeseq_native_counts.pickle"
53 | )
54 |
55 | citeseq_scar = model(
56 | raw_count=citeseq_raw_counts.values,
57 | ambient_profile=citeseq_ambient_profile.values,
58 | feature_type="ADTs",
59 | )
60 |
61 | citeseq_scar.train(epochs=200, batch_size=32, verbose=False)
62 | citeseq_scar.inference()
63 |
64 | dist = euclidean_distances(
65 | citeseq_native_signals.values, citeseq_scar.native_counts
66 | )
67 |         mean_dist = np.diag(dist).mean()  # mean distance between matched cells
68 |
69 |         self.assertLess(mean_dist, 50, "Functional test of scAR fails")
70 |
--------------------------------------------------------------------------------
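A sketch for running these functional tests programmatically with the standard unittest loader; it assumes the pickle fixtures under scar/test/ are present locally (this listing records only their download URLs).

import unittest

suite = unittest.defaultTestLoader.discover("scar/test", pattern="test_*.py")
unittest.TextTestRunner(verbosity=2).run(suite)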