├── .github ├── CONTRIBUTING.md └── workflows │ ├── black.yaml │ ├── python-conda-build.yaml │ └── semantic-release.yaml ├── .gitignore ├── .readthedocs.yaml ├── README.md ├── docs ├── Contacts.rst ├── Installation.rst ├── Introduction.rst ├── License.rst ├── Makefile ├── README.rst ├── Reference.rst ├── Release_notes.md ├── _static │ ├── ambient_signal_hypothesis.png │ ├── bgd.png │ ├── overview_scAR.png │ ├── scAR_favicon.png │ ├── scAR_logo_black.png │ ├── scAR_logo_transparent.png │ ├── scAR_logo_white.png │ └── scar-styles.css ├── _templates │ └── layout.html ├── conf.py ├── index.rst ├── make.bat ├── requirements.txt ├── tutorials │ ├── README.rst │ ├── index.rst │ ├── scAR_tutorial_ambient_profile.ipynb │ ├── scAR_tutorial_batch_denoising_scRNAseq.ipynb │ ├── scAR_tutorial_denoising_CITEseq.ipynb │ ├── scAR_tutorial_denoising_scATACseq.ipynb │ ├── scAR_tutorial_denoising_scRNAseq.ipynb │ ├── scAR_tutorial_identity_barcode.ipynb │ ├── scAR_tutorial_sgRNA_assignment.ipynb │ └── synthetic_assignment.ipynb └── usages │ ├── index.rst │ ├── processing.rst │ ├── synthetic_dataset.rst │ └── training.rst ├── pyproject.toml ├── scar-cpu.yml ├── scar-gpu.yml └── scar ├── __init__.py ├── main ├── __init__.py ├── __main__.py ├── _activation_functions.py ├── _data_generater.py ├── _loss_functions.py ├── _scar.py ├── _setup.py ├── _utils.py └── _vae.py └── test ├── ambient_profile.pickle ├── citeseq_ambient_profile.pickle ├── citeseq_native_counts.pickle ├── citeseq_raw_counts.pickle ├── output_assignment.pickle ├── raw_counts.pickle ├── test_activation_functions.py └── test_scar.py /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## How to contribute to scAR 2 | 3 | #### **Where to find the documentation?** 4 | 5 | * If you want to try **scAR**, [first check the documentation](https://scar-tutorials.readthedocs.io/en/latest/). Make sure the documentation version matches your installed version. 6 | 7 | #### **Did you find a bug?** 8 | 9 | * **Ensure the bug was not already reported** by searching on GitHub under [Issues](https://github.com/Novartis/scAR/issues). 10 | 11 | * If you're unable to find an open issue addressing the problem, [open a new one](https://github.com/Novartis/scAR/issues/new). Be sure to include a **title and clear description**, as much relevant information as possible, and a **code sample** or an **executable test case** demonstrating the expected behavior that is not occurring. 12 | 13 | #### **Did you write a patch that fixes a bug?** 14 | 15 | * Open a new GitHub pull request with the patch. **Do not add a [semantic versioning label](https://github.com/Novartis/scAR/labels) to the PR, as this may trigger a new release.** 16 | 17 | * Ensure the PR description clearly describes the problem and solution. 18 | 19 | #### **Do you have questions about scAR documentation?** 20 | 21 | * Discuss any non-code question about scAR in the [readthedocs discussion](https://scar-tutorials.readthedocs.io/en/latest/Contacts.html#comments). 
22 | 23 | #### **Do you want to contribute to scAR?** 24 | 25 | * Please contact me at caibin.sheng@novartis.com 26 | 27 | Thanks, Caibin 28 | -------------------------------------------------------------------------------- /.github/workflows/black.yaml: -------------------------------------------------------------------------------- 1 | name: Black lint 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | lint: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v2 10 | - uses: psf/black@stable 11 | with: 12 | options: "--diff --verbose --line-length 127" 13 | src: "./scar" 14 | version: "22.3.0" 15 | -------------------------------------------------------------------------------- /.github/workflows/python-conda-build.yaml: -------------------------------------------------------------------------------- 1 | name: Python Package using Conda 2 | 3 | on: [ pull_request ] 4 | 5 | jobs: 6 | build-linux: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | max-parallel: 5 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 3.12.3 13 | uses: actions/setup-python@v2 14 | with: 15 | python-version: 3.12.3 16 | - name: Add conda to system path 17 | run: | 18 | # $CONDA is an environment variable pointing to the root of the miniconda directory 19 | echo $CONDA/bin >> $GITHUB_PATH 20 | - name: Install dependencies 21 | run: | 22 | conda env create -f scar-cpu.yml 23 | - name: Run binary 24 | run: | 25 | export PATH="$GITHUB_PATH:$PATH" 26 | source activate scar 27 | scar --help 28 | # - name: Run pylint 29 | # run: | 30 | # export PATH="$GITHUB_PATH:$PATH" 31 | # source activate scar 32 | # conda install -c anaconda pylint astroid 33 | # pylint scar --fail-under 8.5 --disable=R,C --generated-members=torch.* 34 | - name: Run unit tests 35 | run: | 36 | export PATH="$GITHUB_PATH:$PATH" 37 | source activate scar 38 | conda install pytest 39 | pytest scar 40 | -------------------------------------------------------------------------------- /.github/workflows/semantic-release.yaml: -------------------------------------------------------------------------------- 1 | name: Semantic Release 2 | 3 | on: 4 | pull_request: 5 | types: 6 | - labeled 7 | branches: 8 | - 'main' 9 | 10 | jobs: 11 | release: 12 | if: ${{ github.event.label.name == 'semantic versioning' }} 13 | runs-on: ubuntu-latest 14 | concurrency: release 15 | permissions: 16 | id-token: write 17 | contents: write 18 | steps: 19 | - uses: actions/checkout@v3 20 | with: 21 | ref: main 22 | fetch-depth: 0 23 | 24 | - name: Python Semantic Release 25 | uses: relekang/python-semantic-release@master 26 | with: 27 | github_token: ${{ secrets.GITHUB_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **.ipynb_checkpoints 2 | **pycache__ 3 | **.DS_Store 4 | **build 5 | **.egg-info 6 | **.vscode 7 | **dist 8 | **.h5 9 | **.h5ad -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: "ubuntu-20.04" 5 | tools: 6 | python: "3.11" 7 | 8 | sphinx: 9 | configuration: docs/conf.py 10 | 11 | python: 12 | install: 13 | - requirements: docs/requirements.txt 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |

3 | 4 |

5 | 6 | 7 | 8 | [![scAR](https://anaconda.org/bioconda/scar/badges/version.svg)](https://anaconda.org/bioconda/scar) 9 | [![install with bioconda](https://img.shields.io/badge/install%20with-bioconda-brightgreen.svg?style=flat)](http://bioconda.github.io/recipes/scar/README.html) 10 | [![code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 11 | [![Documentation Status](https://readthedocs.org/projects/scar-tutorials/badge/?version=latest)](https://scar-tutorials.readthedocs.io/en/latest/?badge=latest) 12 | [![semantic-release: angular](https://img.shields.io/badge/semantic--release-angular-e10079?logo=semantic-release)](https://github.com/semantic-release/semantic-release) 13 | [![test](https://github.com/Novartis/scAR/actions/workflows/python-conda-build.yaml/badge.svg)](https://github.com/Novartis/scAR/actions/workflows/python-conda-build.yaml) 14 | [![Stars](https://img.shields.io/github/stars/Novartis/scar?logo=GitHub&color=red)](https://github.com/Novartis/scAR) 15 | [![Downloads](https://anaconda.org/bioconda/scar/badges/downloads.svg)](https://anaconda.org/bioconda/scar/files) 16 | 17 | **scAR** (single-cell Ambient Remover) is a tool designed for denoising ambient signals in droplet-based single-cell omics data. It can be employed for a wide range of applications, such as, **sgRNA assignment** in scCRISPRseq, **identity barcode assignment** in cell indexing, **protein denoising** in CITE-seq, **mRNA denoising** in scRNAseq, and **ATAC signal denoising** in scATACseq, among others. 18 | 19 | # Table of Contents 20 | 21 | - [Installation](#Installation) 22 | - [Dependencies](#Dependencies) 23 | - [Resources](#Resources) 24 | - [License](#License) 25 | - [Reference](#Reference) 26 | 27 | ## [Installation](https://scar-tutorials.readthedocs.io/en/latest/Installation.html) 28 | ## Dependencies 29 | 30 | [![PyTorch 1.8](https://img.shields.io/badge/PyTorch-1.8.0-greeen.svg)](https://pytorch.org/) 31 | [![Python 3.8.6](https://img.shields.io/badge/python-3.8.6-blue.svg)](https://www.python.org/) 32 | [![torchvision 0.9.0](https://img.shields.io/badge/torchvision-0.9.0-red.svg)](https://pytorch.org/vision/stable/index.html) 33 | [![tqdm 4.62.3](https://img.shields.io/badge/tqdm-4.62.3-orange.svg)](https://github.com/tqdm/tqdm) 34 | [![scikit-learn 1.0.1](https://img.shields.io/badge/scikit_learn-1.0.1-green.svg)](https://scikit-learn.org/) 35 | 36 | ## Resources 37 | 38 | - Installation, Usages and Tutorials can be found in the [documentation](https://scar-tutorials.readthedocs.io/en/latest/). 39 | - If you'd like to contribute, please read [contributing guidelines](https://github.com/Novartis/scAR/blob/main/.github/CONTRIBUTING.md). 40 | - Please use the [issues](https://github.com/Novartis/scAR/issues) to submit bug reports. 41 | 42 | ## License 43 | 44 | This project is licensed under the terms of [License](docs/License.rst). 45 | Copyright 2022 Novartis International AG. 
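## Quick start

For orientation, a minimal sketch of denoising with the Python API; the file path, `feature_type`, and training settings below are placeholders, and the [tutorials](https://scar-tutorials.readthedocs.io/en/latest/) are the authoritative, tested walkthroughs.

```python
import scanpy as sc
from scar import model

# load a filtered (cell-containing) count matrix produced by cellranger (placeholder path)
adata = sc.read_10x_h5("filtered_feature_bc_matrix.h5")
adata = adata[:, adata.X.sum(axis=0) > 0].copy()  # drop zero-count features

# train the denoising model; when no ambient profile is given,
# scar estimates one by averaging the cell pool
scar_model = model(raw_count=adata, feature_type="mRNA")
scar_model.train(epochs=200, batch_size=64)

# infer the native (ambient-free) counts and store them alongside the raw data
scar_model.inference()
adata.layers["denoised"] = scar_model.native_counts.toarray()
```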
46 | 47 | ## Reference 48 | 49 | If you use scAR in your research, please consider citing our [manuscript](https://doi.org/10.1101/2022.01.14.476312), 50 | 51 | ``` 52 | @article {Sheng2022.01.14.476312, 53 | author = {Sheng, Caibin and Lopes, Rui and Li, Gang and Schuierer, Sven and Waldt, Annick and Cuttat, Rachel and Dimitrieva, Slavica and Kauffmann, Audrey and Durand, Eric and Galli, Giorgio G and Roma, Guglielmo and de Weck, Antoine}, 54 | title = {Probabilistic modeling of ambient noise in single-cell omics data}, 55 | elocation-id = {2022.01.14.476312}, 56 | year = {2022}, 57 | doi = {10.1101/2022.01.14.476312}, 58 | publisher = {Cold Spring Harbor Laboratory}, 59 | URL = {https://www.biorxiv.org/content/early/2022/01/14/2022.01.14.476312}, 60 | eprint = {https://www.biorxiv.org/content/early/2022/01/14/2022.01.14.476312.full.pdf}, 61 | journal = {bioRxiv} 62 | } 63 | ``` 64 | -------------------------------------------------------------------------------- /docs/Contacts.rst: -------------------------------------------------------------------------------- 1 | Contacts 2 | =============== 3 | 4 | Authors 5 | ------------------------------------------------ 6 | The following authors contributed to the design, implementation, integration, and maintenance of scAR. 7 | 8 | `@Caibin `_ 9 | `@AlexMTYZ `_ 10 | `@fgypas `_ 11 | `@mr-nvs `_ 12 | `@Tobias-Ternent `_ 13 | `@adeweck `_ 14 | 15 | Contributing 16 | ------------------------------------------------ 17 | All kinds of contributions are welcome! Please file `issues `_. 18 | 19 | Comments 20 | ------------------------------------------------ 21 | Post your comments or questions here: 22 | 23 | .. disqus:: 24 | :disqus_identifier: scar-discussion -------------------------------------------------------------------------------- /docs/Installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ================ 3 | 4 | To use ``scar``, first install it using conda or Git + pip. 5 | 6 | Conda install 7 | ------------------------------- 8 | 9 | 1, Install `conda `_ 10 | 11 | 2, Create conda environment:: 12 | 13 | conda create -n scar 14 | 15 | 3, Activate conda environment:: 16 | 17 | conda activate scar 18 | 19 | 4, Install `PyTorch `_ 20 | 21 | 5, Install scar:: 22 | 23 | conda install bioconda::scar 24 | 25 | 6, Activate the scar conda environment:: 26 | 27 | conda activate scar 28 | 29 | Git + pip 30 | ------------------------------------------- 31 | 32 | 1, Clone scar repository:: 33 | 34 | git clone https://github.com/Novartis/scar.git 35 | 36 | 2, Enter the cloned directory:: 37 | 38 | cd scar 39 | 40 | 3, Create a conda environment 41 | 42 | .. tabs:: 43 | 44 | .. tab:: GPU version 45 | 46 | .. code-block:: 47 | :caption: Please use ``scar-gpu`` if you have an nvidia graphics card and the corresponding driver installed 48 | 49 | conda env create -f scar-gpu.yml 50 | 51 | .. tab:: CPU version 52 | 53 | .. code-block:: 54 | :caption: Please use ``scar-cpu`` if you don't have a graphics card available 55 | 56 | conda env create -f scar-cpu.yml 57 | 58 | 4, Activate the scar conda environment:: 59 | 60 | conda activate scar 61 | 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /docs/Introduction.rst: -------------------------------------------------------------------------------- 1 | Introduction 2 | =============== 3 | 4 | What is ambient signal? 5 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 6 | 7 | ..
image:: _static/ambient_signal_hypothesis.png 8 | :width: 500 9 | :align: center 10 | 11 | 12 | When preparing a single-cell solution, cell lysis releases RNA and protein molecules that become encapsulated in droplets. These exogenous molecules are mixed with native ones and barcoded by the same 10x beads, leading to overestimated count data. The presence of this ambient signal can compromise downstream analysis and even introduce significant bias in certain cases, such as scCRISPR-seq and cell multiplexing. 13 | 14 | The design of scAR 15 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 16 | 17 | .. image:: _static/overview_scAR.png 18 | :width: 600 19 | :align: center 20 | 21 | 22 | scAR employs a latent variable model that represents both biological and technical components in the observed count data. The model is developed under the ambient signal hypothesis, where the probability of each ambient transcript's occurrence can be estimated empirically from cell-free droplets. The model has two hidden variables, namely the contamination level per cell and the probability of native transcript occurrence. Together with the empirically estimated ambient profile, these three parameters allow scAR to reconstruct the noisy observations. We train neural networks, specifically a variational autoencoder, to learn the hidden variables by minimizing the differences between the reconstructions and the original noisy observations. Once the model converges, contamination levels and native expression are inferred, and downstream analysis can be performed using these values. For more information, please refer to our manuscript [Sheng2022]_. 23 | 24 | What types of data can scAR process? 25 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 26 | We validated the effectiveness of scAR across a range of droplet-based single-cell omics technologies, including scRNAseq, scCRISPR-seq, CITEseq, and scATACseq. scAR was able to remove ambient mRNA, assign sgRNAs, assign cell tags, clean noisy protein counts (ADT), and clean peak counts, resulting in significant improvements in data quality in all tested datasets. Notably, scAR was able to recover a substantial proportion (33% to 50%) of cells in scCRISPR-seq and cell multiplexing experiments. Given that ambient contamination is a common issue in droplet-based single-cell omics, particularly in complex experiments or samples, scAR represents a viable solution for addressing this challenge. 27 | 28 | What are the alternative approaches? 29 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 30 | Several methods exist for modeling noise in single-cell omics data. In general, these methods can be classified into two categories: those that deal with background noise and those that model stochastic noise. Some examples of these methods are provided below.
31 | 32 | +-------------------------------------------+-------------------------------------------+ 33 | | Background noise | Stochastic noise | 34 | +===========================================+===========================================+ 35 | | CellBender [Fleming2023]_ | scVI [Lopez2018]_ | 36 | +-------------------------------------------+-------------------------------------------+ 37 | | SoupX [Young2020]_ | DCA [Eraslan2019]_ | 38 | +-------------------------------------------+-------------------------------------------+ 39 | | DecontX [Yang2020]_ | | 40 | +-------------------------------------------+-------------------------------------------+ 41 | | totalVI (protein counts) [Gayoso2021]_ | | 42 | +-------------------------------------------+-------------------------------------------+ 43 | | DSB (protein counts) [Mulè2022]_ | | 44 | +-------------------------------------------+-------------------------------------------+ -------------------------------------------------------------------------------- /docs/License.rst: -------------------------------------------------------------------------------- 1 | License 2 | =============== 3 | 4 | MIT License 5 | 6 | Copyright (c) 2022 Novartis 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of this software and associated documentation files (the "Software"), to deal 10 | in the Software without restriction, including without limitation the rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the Software is 13 | furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all 16 | copies or substantial portions of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | SOFTWARE. -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.rst: -------------------------------------------------------------------------------- 1 | Documentation 2 | ============= 3 | 4 | This folder contains all documentation of scar. See details in the `documentation `__. 
-------------------------------------------------------------------------------- /docs/Reference.rst: -------------------------------------------------------------------------------- 1 | Reference 2 | =============== 3 | 4 | .. [Dixit2016] Dixit *et al.* (2016), 5 | `Perturb-Seq: Dissecting Molecular Circuits with Scalable Single-Cell RNA Profiling of Pooled Genetic Screens `__, 6 | Cell. 7 | 8 | .. [Eraslan2019] Eraslan *et al.* (2019), 9 | `Single-cell RNA-seq denoising using a deep count autoencoder `__, 10 | Nature Communications. 11 | 12 | .. [Fleming2023] Fleming *et al.* (2023), 13 | `Unsupervised removal of systematic background noise from droplet-based single-cell experiments using CellBender `__, 14 | Nature Methods. 15 | 16 | .. [Gayoso2021] Gayoso *et al.* (2021), 17 | `Joint probabilistic modeling of single-cell multi-omic data with totalVI `__, 18 | Nature Methods. 19 | 20 | .. [Lopez2018] Lopez *et al.* (2018), 21 | `Deep generative modeling for single-cell transcriptomics `__, 22 | Nature Methods. 23 | 24 | .. [Lun2019] Lun *et al.* (2019), 25 | `EmptyDrops: Distinguishing cells from empty droplets in droplet-based single-cell RNA sequencing data `__, 26 | Genome Biology. 27 | 28 | .. [Ly2020] Ly *et al.* (2020), 29 | `The Bayesian Methodology of Sir Harold Jeffreys as a Practical Alternative to the P Value Hypothesis Test `__, 30 | Computational Brain & Behavior. 31 | 32 | .. [Mulè2022] Mulè *et al.* (2022), 33 | `Normalizing and denoising protein expression data from droplet-based single cell profiling `__, 34 | Nature Communications. 35 | 36 | .. [Sheng2022] Sheng *et al.* (2022), 37 | `Probabilistic machine learning ensures accurate ambient denoising in droplet-based single-cell omics `__, 38 | bioRxiv. 39 | 40 | .. [Yang2020] Yang *et al.* (2020), 41 | `Decontamination of ambient RNA in single-cell RNA-seq with DecontX `__, 42 | Genome Biology. 43 | 44 | .. [Young2020] Young *et al.* (2020), 45 | `SoupX removes ambient RNA contamination from droplet-based single-cell RNA sequencing data `__, 46 | GigaScience. 
-------------------------------------------------------------------------------- /docs/_static/ambient_signal_hypothesis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/docs/_static/ambient_signal_hypothesis.png -------------------------------------------------------------------------------- /docs/_static/bgd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/docs/_static/bgd.png -------------------------------------------------------------------------------- /docs/_static/overview_scAR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/docs/_static/overview_scAR.png -------------------------------------------------------------------------------- /docs/_static/scAR_favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/docs/_static/scAR_favicon.png -------------------------------------------------------------------------------- /docs/_static/scAR_logo_black.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/docs/_static/scAR_logo_black.png -------------------------------------------------------------------------------- /docs/_static/scAR_logo_transparent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/docs/_static/scAR_logo_transparent.png -------------------------------------------------------------------------------- /docs/_static/scAR_logo_white.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/docs/_static/scAR_logo_white.png -------------------------------------------------------------------------------- /docs/_static/scar-styles.css: -------------------------------------------------------------------------------- 1 | body { 2 | margin: 0; 3 | font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, "Noto Sans", "Liberation Sans", sans-serif, "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol", "Noto Color Emoji"; 4 | } -------------------------------------------------------------------------------- /docs/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {# Import the theme's layout. #} 2 | {% extends "!layout.html" %} 3 | 4 | {# Custom CSS overrides #} 5 | {% set bootswatch_css_custom = ['_static/scar-styles.css'] %} -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. 
For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | 13 | import os 14 | import sys 15 | 16 | sys.path.insert(0, os.path.abspath(".")) 17 | sys.path.insert(0, os.path.abspath("..")) 18 | from scar import __version__ 19 | 20 | # -- Project information ----------------------------------------------------- 21 | project = "scAR" 22 | copyright = "Novartis Institute for BioMedical Research, 2022" 23 | author = "Caibin Sheng" 24 | release = __version__ 25 | 26 | 27 | # -- General configuration --------------------------------------------------- 28 | 29 | # Add any Sphinx extension module names here, as strings. They can be 30 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 31 | # ones. 32 | extensions = [ 33 | "myst_parser", 34 | "sphinx.ext.autosectionlabel", 35 | "nbsphinx", 36 | "sphinx_gallery.gen_gallery", 37 | "sphinx_disqus.disqus", 38 | "sphinxarg.ext", 39 | "sphinx.ext.autodoc", 40 | "sphinx.ext.autosummary", 41 | "sphinx.ext.napoleon", 42 | "autodocsumm", 43 | "matplotlib.sphinxext.plot_directive", 44 | "sphinx_design", 45 | "sphinx_tabs.tabs", 46 | ] 47 | 48 | nbsphinx_execute = "never" 49 | nbsphinx_allow_errors = True 50 | autosummary_generate = True 51 | # Add type of source files 52 | 53 | nb_custom_formats = { 54 | ".md": ["jupytext.reads", {"fmt": "mystnb"}], 55 | } 56 | 57 | # Add any paths that contain templates here, relative to this directory. 58 | 59 | # List of patterns, relative to source directory, that match files and 60 | # directories to ignore when looking for source files. 61 | # This pattern also affects html_static_path and html_extra_path. 62 | exclude_patterns = ["Thumbs.db", ".DS_Store"] 63 | 64 | # Add comments 65 | disqus_shortname = "scar-discussion" 66 | 67 | # -- Options for HTML output ------------------------------------------------- 68 | 69 | # The theme to use for HTML and HTML Help pages. See the documentation for 70 | # a list of builtin themes. 71 | # 72 | html_theme = "pydata_sphinx_theme" 73 | templates_path = ["_templates"] 74 | html_static_path = ["_static"] 75 | html_css_files = ["scar-styles.css"] 76 | html_show_sourcelink = False 77 | 78 | # Add any paths that contain custom static files (such as style sheets) here, 79 | # relative to this directory. They are copied after the builtin static files, 80 | # so a file named "default.css" will overwrite the builtin "default.css". 
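# Theme logo configuration (clarifying note): html_logo below provides the default
# logo, while the "logo" options inside html_theme_options supply light/dark variants
# that the pydata theme switches between based on the reader's color scheme.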
81 | html_logo = "_static/scAR_logo_transparent.png" 82 | html_theme_options = { 83 | "logo": { 84 | "image_light": "scAR_logo_white.png", 85 | "image_dark": "scAR_logo_black.png", 86 | }, 87 | "pygment_light_style": "tango", 88 | "pygment_dark_style": "native", 89 | "icon_links": [ 90 | { 91 | "name": "GitHub", 92 | "url": "https://github.com/Novartis/scar", 93 | "icon": "fab fa-github-square", 94 | "type": "fontawesome", 95 | } 96 | ], 97 | "use_edit_page_button": False, 98 | "favicons": [ 99 | { 100 | "rel": "icon", 101 | "sizes": "32x32", 102 | "href": "_static/scAR_favicon.png", 103 | } 104 | ], 105 | } 106 | html_context = { 107 | "github_user": "Novartis", 108 | "github_repo": "scar", 109 | "github_version": "develop", 110 | "doc_path": "docs", 111 | } 112 | 113 | # html_sidebars = { 114 | # "**": ["search-field.html", "sidebar-nav-bs.html", "sidebar-ethical-ads.html"] 115 | # } 116 | 117 | autodoc_mock_imports = ["django"] 118 | autodoc_default_options = { 119 | "autosummary": True, 120 | } 121 | numpydoc_show_class_members = False 122 | 123 | # Options for plot examples 124 | plot_include_source = True 125 | plot_formats = [("png", 120)] 126 | plot_html_show_formats = False 127 | plot_html_show_source_link = False 128 | 129 | sphinx_gallery_conf = { 130 | "examples_dirs": "tutorials", # path to your example scripts 131 | "gallery_dirs": "_build/tutorial_gallery", # path to where to save gallery generated output 132 | "filename_pattern": "/scAR_tutorial_", 133 | # "ignore_pattern": r"__init__\.py", 134 | } 135 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. scAR documentation master file, created by 2 | sphinx-quickstart on Fri Apr 22 15:48:44 2022. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | scAR documentation 7 | ================================ 8 | 9 | **Version**: |release| 10 | 11 | **Useful links**: 12 | `Binary Installers `__ | 13 | `Source Repository `__ | 14 | `Issues `__ | 15 | `Contacts `__ 16 | 17 | 18 | :mod:`scAR` (single-cell Ambient Remover) is an explainable machine learning model designed to remove ambient signals from droplet-based single cell omics data. It is suitable for various applications such as **sgRNA assignment** in scCRISPR-seq, **identity barcode assignment** in cell multiplexing, **protein denoising** in CITE-seq, **mRNA denoising** in scRNAseq, **peak count denoising** in scATACseq and more. 19 | 20 | It is developed by Oncology Data Science, Novartis Institute for BioMedical Research. 21 | 22 | 23 | .. grid:: 1 2 2 2 24 | :gutter: 2 25 | 26 | .. grid-item-card:: Getting started 27 | :link: Introduction 28 | :link-type: doc 29 | :img-background: _static/bgd.png 30 | :class-card: sd-text-black 31 | 32 | New to *scAR*? Check out the getting started guide. It contains an introduction to *scAR's* main concepts. 33 | 34 | +++ 35 | .. button-ref:: Introduction 36 | :ref-type: doc 37 | :color: primary 38 | :shadow: 39 | :align: center 40 | 41 | What is scAR? 42 | 43 | .. grid-item-card:: Installation 44 | :link: Installation 45 | :link-type: doc 46 | :img-background: _static/bgd.png 47 | 48 | Want to install *scAR*? Check out the installation guide. It contains steps to install *scAR*. 49 | 50 | +++ 51 | .. button-ref:: Installation 52 | :ref-type: doc 53 | :color: primary 54 | :shadow: 55 | :align: center 56 | 57 | How to install scAR? 
58 | 59 | .. grid-item-card:: API reference 60 | :link: usages/index 61 | :link-type: doc 62 | :img-background: _static/bgd.png 63 | 64 | The API reference contains detailed descriptions of the scAR API. 65 | 66 | +++ 67 | .. button-ref:: usages/index 68 | :ref-type: doc 69 | :color: primary 70 | :shadow: 71 | :align: center 72 | 73 | To the API 74 | 75 | .. grid-item-card:: Tutorials 76 | :link: tutorials/index 77 | :link-type: doc 78 | :img-background: _static/bgd.png 79 | 80 | The tutorials walk you through the applications of scAR. 81 | 82 | +++ 83 | .. button-ref:: tutorials/index 84 | :ref-type: doc 85 | :color: primary 86 | :shadow: 87 | :align: center 88 | 89 | To Tutorials 90 | 91 | | 92 | 93 | .. toctree:: 94 | :hidden: 95 | 96 | Introduction 97 | Installation 98 | usages/index 99 | tutorials/index 100 | Release_notes 101 | Reference 102 | License 103 | Contacts -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | jinja2<3.1.0 2 | myst_parser 3 | nbsphinx 4 | sphinx-gallery<0.17.0 5 | sphinx-argparse 6 | sphinx-disqus 7 | autodocsumm 8 | ipykernel 9 | protobuf 10 | scanpy 11 | sphinx-design 12 | pydata-sphinx-theme 13 | sphinx_tabs 14 | git+https://github.com/Novartis/scAR.git 15 | -------------------------------------------------------------------------------- /docs/tutorials/README.rst: -------------------------------------------------------------------------------- 1 | This folder contains the notebooks used in the scAR tutorials. -------------------------------------------------------------------------------- /docs/tutorials/index.rst: -------------------------------------------------------------------------------- 1 | Tutorials 2 | ============== 3 | 4 | There are two ways to run ``scar``. For Python users, we recommend the Python API; for R users, we recommend the command line tool. 5 | 6 | Run scar with Python API 7 | ------------------------ 8 | .. nbgallery:: 9 | 10 | scAR_tutorial_ambient_profile 11 | scAR_tutorial_sgRNA_assignment 12 | scAR_tutorial_identity_barcode 13 | scAR_tutorial_denoising_CITEseq 14 | scAR_tutorial_denoising_scRNAseq 15 | scAR_tutorial_denoising_scATACseq 16 | scAR_tutorial_batch_denoising_scRNAseq 17 | 18 | Run scar with the command line tool 19 | ----------------------------------- 20 | 21 | The command line tool supports two formats of input.
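At a glance, the two invocations look like this (the paths and the ``feature_type`` value are placeholders; both forms are described in detail below)::

    # denoise a filtered cellranger .h5 matrix directly
    scar filtered_feature_bc_matrix.h5 -ft mRNA -o output

    # denoise a pickled count matrix, optionally with a custom ambient profile
    scar raw_count_matrix.pickle -ft mRNA -ap ambient_profile.pickle -o output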
22 | 23 | Use ``.h5`` files as the input 24 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 25 | 26 | We can use the output of cellranger count *filtered_feature_bc_matrix.h5* as the input for ``scar``:: 27 | 28 | scar filtered_feature_bc_matrix.h5 -ft feature_type -o output 29 | 30 | ``filtered_feature_bc_matrix.h5``, a filtered .h5 file produced by cellranger count. 31 | 32 | ``feature_type``, a string, either 'mRNA' or 'sgRNA' or 'ADT' or 'tag' or 'CMO' or 'ATAC'. 33 | 34 | .. note:: 35 | The ambient profile is calculated by averaging the cell pool under this mode. If you want to use a more accurate ambient profile, please consider calculating it and using ``.pickle`` files as the input, as detailed below. 36 | 37 | The output folder contains an h5ad file:: 38 | 39 | output 40 | └── filtered_feature_bc_matrix_denoised_feature_type.h5ad 41 | 42 | The h5ad file can be read by `scanpy.read `__ as an `anndata `__ object: 43 | 44 | - anndata.X, denoised counts. 45 | - anndata.obs['``noise_ratio``'], estimated noise ratio per cell. 46 | - anndata.layers['``native_frequencies``'], estimated native frequencies. 47 | - anndata.layers['``BayesFactor``'], Bayes factor of ambient contamination. 48 | - anndata.obs['``sgRNAs``' or '``tags``'], optional, feature assignment, e.g., sgRNA, tag, CMO, etc. 49 | 50 | 51 | Use ``.pickle`` files as the input 52 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 53 | We can also run ``scar`` by:: 54 | 55 | scar raw_count_matrix.pickle -ft feature_type -o output 56 | 57 | ``raw_count_matrix.pickle``, a file of a raw count matrix (MxN) with cells in rows and features in columns. 58 | 59 | +--------+--------+--------+-----+--------+ 60 | | cells | gene_0 | gene_1 | ... | gene_y | 61 | +========+========+========+=====+========+ 62 | | cell_0 | 12 | 3 | ... | 82 | 63 | +--------+--------+--------+-----+--------+ 64 | | cell_1 | 13 | 0 | ... | 78 | 65 | +--------+--------+--------+-----+--------+ 66 | | cell_2 | 35 | 30 | ... | 170 | 67 | +--------+--------+--------+-----+--------+ 68 | | ... | ... | ... | ... | ... | 69 | +--------+--------+--------+-----+--------+ 70 | | cell_x | 16 | 5 | ... | 112 | 71 | +--------+--------+--------+-----+--------+ 72 | 73 | 74 | ``feature_type``, a string, either 'mRNA' or 'sgRNA' or 'ADT' or 'tag' or 'CMO' or 'ATAC'. 75 | 76 | .. note:: 77 | An extra argument ``ambient_profile`` is recommended to achieve deeper noise reduction. 78 | 79 | 80 | ``ambient_profile`` represents the probability of occurrence of each ambient transcript and can be empirically estimated by averaging cell-free droplets. 81 | 82 | +--------+-----------------+ 83 | | genes | ambient profile | 84 | +========+=================+ 85 | | gene_0 | .0003 | 86 | +--------+-----------------+ 87 | | gene_1 | .00004 | 88 | +--------+-----------------+ 89 | | gene_2 | .00003 | 90 | +--------+-----------------+ 91 | | ... | ... | 92 | +--------+-----------------+ 93 | | gene_y | .0012 | 94 | +--------+-----------------+ 95 | 96 | .. warning:: 97 | ``ambient_profile`` should sum to one. The gene order should be consistent with ``raw_count_matrix``. 98 | 99 | For other optional arguments and parameters, run:: 100 | 101 | scar --help 102 | 103 | The output folder contains four (or five) files:: 104 | 105 | output 106 | ├── denoised_counts.pickle 107 | ├── expected_noise_ratio.pickle 108 | ├── BayesFactor.pickle 109 | ├── expected_native_freq.pickle 110 | └── assignment.pickle 111 | 112 | In the folder structure above: 113 | 114 | - ``expected_noise_ratio.pickle``, estimated noise ratio. 
115 | - ``denoised_counts.pickle``, denoised count matrix. 116 | - ``BayesFactor.pickle``, Bayes factor of ambient contamination. 117 | - ``expected_native_freq.pickle``, estimated native frequencies. 118 | - ``assignment.pickle``, optional, feature assignment, e.g., sgRNA, tag, etc. 119 | -------------------------------------------------------------------------------- /docs/usages/index.rst: -------------------------------------------------------------------------------- 1 | API 2 | =============== 3 | 4 | Python API 5 | ---------------------- 6 | 7 | Processing 8 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 9 | Calculate ambient profile 10 | 11 | .. currentmodule:: scar.main._setup 12 | 13 | .. autosummary:: 14 | :nosignatures: 15 | 16 | setup_anndata 17 | 18 | 19 | Training 20 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 21 | The core module of scar 22 | 23 | .. currentmodule:: scar.main._scar 24 | 25 | .. autosummary:: 26 | :nosignatures: 27 | 28 | model 29 | 30 | Synthetic_dataset 31 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 32 | Generate synthetic datasets (scRNAseq, CITE-seq, scCRISPRseq) with ambient contamination 33 | 34 | .. currentmodule:: scar.main._data_generater 35 | 36 | .. autosummary:: 37 | :nosignatures: 38 | 39 | scrnaseq 40 | citeseq 41 | cropseq 42 | 43 | Plotting 44 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 45 | Plotting functions (under development). 46 | 47 | Reporting 48 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 49 | Generate denoising reports (under development). 50 | 51 | Command Line Interface 52 | ----------------------------------- 53 | 54 | .. argparse:: 55 | :module: scar.main.__main__ 56 | :func: scar_parser 57 | :prog: scar 58 | -------------------------------------------------------------------------------- /docs/usages/processing.rst: -------------------------------------------------------------------------------- 1 | Setup anndata 2 | ============== 3 | Calculate ambient profile for relevant feature types 4 | 5 | .. automodule:: scar.main._setup 6 | 7 | .. autofunction:: setup_anndata -------------------------------------------------------------------------------- /docs/usages/synthetic_dataset.rst: -------------------------------------------------------------------------------- 1 | Generate synthetic single-cell datasets 2 | ========================================== 3 | .. automodule:: scar.main._data_generater 4 | :members: 5 | :member-order: bysource 6 | -------------------------------------------------------------------------------- /docs/usages/training.rst: -------------------------------------------------------------------------------- 1 | Denoising model 2 | =============================== 3 | 4 | .. automodule:: scar.main._scar 5 | 6 | ..
autoclass:: model 7 | :members: 8 | :member-order: bysource 9 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=68.1.2"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "scar" 7 | version = "0.7.0" 8 | requires-python = ">= 3.10" 9 | dependencies = [ 10 | "torch >= 1.10.0", 11 | "torchvision >= 0.9.0", 12 | "tqdm >= 4.62.3", 13 | "seaborn >= 0.11.2", 14 | "scikit-learn >= 1.0.1", 15 | "pyro-ppl >= 1.8.0", 16 | "scanpy >= 1.9.2" 17 | ] 18 | authors = [ 19 | {name = "Caibin Sheng", email = "caibin.sheng.res@gmail.com"} 20 | ] 21 | description = "scAR (single-cell Ambient Remover) is a package for denoising the ambient signals in droplet-based single cell omics" 22 | readme = "README.md" 23 | license = {text = "MIT License"} 24 | keywords = ["single cell omics", "variational autoencoder", "machine learning", "generative model", "cite-seq", "scCRISPRseq", "scRNAseq"] 25 | 26 | [project.urls] 27 | Homepage = "https://github.com/Novartis/scAR" 28 | Documentation = "https://scar-tutorials.readthedocs.io/en/main/" 29 | Repository = "https://github.com/Novartis/scar.git" 30 | Issues = "https://github.com/Novartis/scAR/issues" 31 | Changelog = "https://github.com/Novartis/scAR/blob/main/docs/Release_notes.md" 32 | 33 | [tool.semantic_release] 34 | version_toml = ["pyproject.toml:project.version"] 35 | major_on_zero = false 36 | branch = "develop" 37 | upload_to_release = false 38 | hvcs = "github" 39 | upload_to_repository = false 40 | upload_to_pypi = false 41 | patch_without_tag = false 42 | 43 | [tool.semantic_release.changelog] 44 | changelog_file="docs/Release_notes.md" 45 | 46 | [project.gui-scripts] 47 | scar = "scar.main.__main__:main" 48 | -------------------------------------------------------------------------------- /scar-cpu.yml: -------------------------------------------------------------------------------- 1 | name: scar 2 | channels: 3 | - nvidia 4 | - pytorch 5 | - conda-forge 6 | dependencies: 7 | - conda-forge::python>=3.10 8 | - nvidia::cudatoolkit>=11.1 9 | - pytorch::pytorch>=1.10.0 10 | - pytorch::torchvision>=0.9.0 11 | - pytorch-mutex=*=cpu 12 | - conda-forge::tqdm>=4.62.3 13 | - conda-forge::seaborn>=0.11.2 14 | - conda-forge::scikit-learn>=1.0.1 15 | - conda-forge::scanpy 16 | - conda-forge::pyro-ppl>=1.8.0 17 | - conda-forge::pip 18 | - pip: 19 | - . 20 | -------------------------------------------------------------------------------- /scar-gpu.yml: -------------------------------------------------------------------------------- 1 | name: scar 2 | channels: 3 | - nvidia 4 | - pytorch 5 | - conda-forge 6 | dependencies: 7 | - conda-forge::python>=3.10 8 | - nvidia::cudatoolkit>=11.1 9 | - pytorch::pytorch>=1.10.0 10 | - pytorch::torchvision>=0.9.0 11 | - pytorch-mutex=*=cuda 12 | - conda-forge::tqdm>=4.62.3 13 | - conda-forge::seaborn>=0.11.2 14 | - conda-forge::scikit-learn>=1.0.1 15 | - conda-forge::scanpy 16 | - conda-forge::pyro-ppl>=1.8.0 17 | - conda-forge::pip 18 | - pip: 19 | - .
20 | -------------------------------------------------------------------------------- /scar/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from importlib.metadata import version 3 | __version__ = version("scar") 4 | 5 | from .main._scar import model 6 | from .main._setup import setup_anndata 7 | from .main import _data_generater as data_generator 8 | -------------------------------------------------------------------------------- /scar/main/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /scar/main/__main__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """command line of scar""" 3 | 4 | import argparse 5 | 6 | import os 7 | import pandas as pd, scanpy as sc 8 | from ._scar import model 9 | from ..__init__ import __version__ 10 | from ._utils import get_logger 11 | 12 | def main(): 13 | """main function for command line interface""" 14 | args = Config() 15 | count_matrix_path = args.count_matrix[0] 16 | ambient_profile_path = args.ambient_profile 17 | feature_type = args.feature_type 18 | output_dir = ( 19 | os.getcwd() if not args.output else args.output 20 | ) # if None, output to current directory 21 | count_model = args.count_model 22 | nn_layer1 = args.hidden_layer1 23 | nn_layer2 = args.hidden_layer2 24 | latent_dim = args.latent_dim 25 | epochs = args.epochs 26 | device = args.device 27 | sparsity = args.sparsity 28 | batchkey = args.batchkey 29 | cachecapacity = args.cachecapacity 30 | gnf = bool(args.get_native_frequencies) 31 | save_model = args.save_model 32 | batch_size = args.batchsize 33 | batch_size_infer = args.batchsize_infer 34 | adjust = args.adjust 35 | cutoff = args.cutoff 36 | moi = args.moi 37 | round_to_int = args.round2int 38 | clip_to_obs = args.clip_to_obs 39 | verbose = args.verbose 40 | 41 | main_logger = get_logger("scar", verbose=verbose) 42 | 43 | _, file_extension = os.path.splitext(count_matrix_path) 44 | 45 | if file_extension == ".pickle": 46 | count_matrix = pd.read_pickle(count_matrix_path) 47 | 48 | # Denoising transcriptomic data 49 | elif file_extension == ".h5": 50 | adata = sc.read_10x_h5(count_matrix_path, gex_only=False) 51 | 52 | main_logger.info( 53 | "unprocessed data contains: {0} cells and {1} genes".format( 54 | adata.shape[0], adata.shape[1] 55 | ) 56 | ) 57 | adata = adata[:, adata.X.sum(axis=0) > 0] # filter out features of zero counts 58 | main_logger.info( 59 | "filter out features of zero counts, remaining data contains: {0} cells and {1} genes".format( 60 | adata.shape[0], adata.shape[1] 61 | ) 62 | ) 63 | 64 | if feature_type.lower() == "all": 65 | features = adata.var["feature_types"].unique() 66 | count_matrix = adata.copy() 67 | 68 | # Denoising mRNAs 69 | elif feature_type.lower() in ["mrna", "mrnas"]: 70 | features = "Gene Expression" 71 | adata_fb = adata[:, adata.var["feature_types"] == features] 72 | count_matrix = adata_fb.copy() 73 | 74 | # Denoising sgRNAs 75 | elif feature_type.lower() in ["sgrna", "sgrnas"]: 76 | features = "CRISPR Guide Capture" 77 | adata_fb = adata[:, adata.var["feature_types"] == features] 78 | count_matrix = adata_fb.copy() 79 | 80 | # Denoising CMO tags 81 | elif feature_type.lower() in ["tag", "tags"]: 82 | features = "Multiplexing Capture" 83 | adata_fb = adata[:, adata.var["feature_types"] ==
features] 84 | count_matrix = adata_fb.copy() 85 | 86 | # Denoising ADTs 87 | elif feature_type.lower() in ["adt", "adts"]: 88 | features = "Antibody Capture" 89 | adata_fb = adata[:, adata.var["feature_types"] == features] 90 | count_matrix = adata_fb.copy() 91 | 92 | # Denoising ATAC peaks 93 | elif feature_type.lower() in ["atac"]: 94 | features = "Peaks" 95 | adata_fb = adata[:, adata.var["feature_types"] == features] 96 | count_matrix = adata_fb.copy() 97 | 98 | main_logger.info(f"modalities to denoise: {features}") 99 | 100 | else: 101 | raise Exception(file_extension + " files are not supported.") 102 | 103 | if ambient_profile_path: 104 | _, ambient_profile_file_extension = os.path.splitext(ambient_profile_path) 105 | if ambient_profile_file_extension == ".pickle": 106 | ambient_profile = pd.read_pickle(ambient_profile_path) 107 | 108 | # Currently, use the default approach to calculate the ambient profile in the case of h5 109 | elif ambient_profile_file_extension == ".h5": 110 | ambient_profile = None 111 | 112 | else: 113 | raise Exception( 114 | ambient_profile_file_extension + " files are not supported." 115 | ) 116 | else: 117 | ambient_profile = None 118 | 119 | main_logger.info(f"feature_type: {feature_type}") 120 | main_logger.info(f"count_model: {count_model}") 121 | main_logger.info(f"output_dir: {output_dir}") 122 | main_logger.info(f"count_matrix_path: {count_matrix_path}") 123 | main_logger.info(f"ambient_profile_path: {ambient_profile_path}") 124 | main_logger.info(f"expected data sparsity: {sparsity:.2f}") 125 | 126 | if not os.path.isdir(output_dir): 127 | os.makedirs(output_dir) 128 | 129 | # Run model 130 | scar_model = model( 131 | raw_count=count_matrix, 132 | ambient_profile=ambient_profile, 133 | nn_layer1=nn_layer1, 134 | nn_layer2=nn_layer2, 135 | latent_dim=latent_dim, 136 | feature_type=feature_type, 137 | count_model=count_model, 138 | batch_key=batchkey, 139 | cache_capacity=cachecapacity, 140 | sparsity=sparsity, 141 | device=device, 142 | ) 143 | 144 | scar_model.train( 145 | batch_size=batch_size, 146 | epochs=epochs, 147 | save_model=save_model, 148 | ) 149 | 150 | scar_model.inference( 151 | adjust=adjust, 152 | get_native_frequencies=gnf, 153 | round_to_int=round_to_int, 154 | batch_size=batch_size_infer, 155 | clip_to_obs=clip_to_obs, 156 | ) 157 | 158 | if feature_type.lower() in ["sgrna", "sgrnas", "tag", "tags", "cmo", "cmos"]: 159 | scar_model.assignment(cutoff=cutoff, moi=moi) 160 | 161 | main_logger.info("Saving results...") 162 | 163 | # save results 164 | if file_extension == ".pickle": 165 | output_path01, output_path02, output_path03, output_path04 = ( 166 | os.path.join(output_dir, "denoised_counts.pickle"), 167 | os.path.join(output_dir, "BayesFactor.pickle"), 168 | os.path.join(output_dir, "native_frequency.pickle"), 169 | os.path.join(output_dir, "noise_ratio.pickle"), 170 | ) 171 | 172 | pd.DataFrame( 173 | scar_model.native_counts.toarray(), 174 | index=count_matrix.index, 175 | columns=count_matrix.columns, 176 | ).to_pickle(output_path01) 177 | main_logger.info(f"denoised counts saved in: {output_path01}") 178 | 179 | pd.DataFrame( 180 | scar_model.noise_ratio.toarray(), 181 | index=count_matrix.index, 182 | columns=["noise_ratio"] 183 | ).to_pickle(output_path04) 184 | main_logger.info(f"expected noise ratio saved in: {output_path04}") 185 | 186 | if scar_model.native_frequencies is not None: 187 | pd.DataFrame( 188 | scar_model.native_frequencies.toarray(), 189 | index=count_matrix.index, 190 | columns=count_matrix.columns, 
191 | ).to_pickle(output_path03) 192 | main_logger.info(f"expected native frequencies saved in: {output_path03}") 193 | 194 | if feature_type.lower() in ["sgrna", "sgrnas", "tag", "tags", "cmo", "cmos"]: 195 | pd.DataFrame( 196 | scar_model.bayesfactor.toarray(), 197 | index=count_matrix.index, 198 | columns=count_matrix.columns, 199 | ).to_pickle(output_path02) 200 | main_logger.info(f"BayesFactor matrix saved in: {output_path02}") 201 | 202 | output_path05 = os.path.join(output_dir, "assignment.pickle") 203 | scar_model.feature_assignment.to_pickle(output_path05) 204 | main_logger.info(f"assignment saved in: {output_path05}") 205 | 206 | elif file_extension == ".h5": 207 | output_path_h5ad = os.path.join( 208 | output_dir, f"filtered_feature_bc_matrix_denoised_{feature_type}.h5ad" 209 | ) 210 | 211 | denoised_adata = adata.copy() 212 | denoised_adata.X = scar_model.native_counts 213 | denoised_adata.obs["noise_ratio"] = pd.DataFrame( 214 | scar_model.noise_ratio.toarray(), 215 | index=count_matrix.obs_names, 216 | columns=["noise_ratio"], 217 | ) 218 | if scar_model.native_frequencies is not None: 219 | denoised_adata.layers["native_frequencies"] = scar_model.native_frequencies.toarray() 220 | 221 | if feature_type.lower() in ["sgrna", "sgrnas", "tag", "tags", "cmo", "cmos"]: 222 | denoised_adata.obs = denoised_adata.obs.join(scar_model.feature_assignment) 223 | denoised_adata.layers["BayesFactor"] = scar_model.bayesfactor.toarray() 224 | 225 | denoised_adata.write(output_path_h5ad) 226 | main_logger.info(f"the denoised h5ad file saved in: {output_path_h5ad}") 227 | 228 | 229 | class Config: 230 | """ 231 | The configuration options. Options can be specified as command-line arguments. 232 | """ 233 | 234 | def __init__(self) -> None: 235 | """Initialize configuration values.""" 236 | self.parser = scar_parser() 237 | self.namespace = vars(self.parser.parse_args()) 238 | 239 | def __getattr__(self, option): 240 | return self.namespace[option] 241 | 242 | 243 | def scar_parser(): 244 | """Argument parser""" 245 | 246 | parser = argparse.ArgumentParser( 247 | description="scAR (single-cell Ambient Remover) is a deep learning model for removal of the ambient signals in droplet-based single cell omics", 248 | formatter_class=argparse.RawTextHelpFormatter, 249 | ) 250 | parser.add_argument( 251 | "--version", 252 | action="version", 253 | version=f"%(prog)s version: {__version__}", 254 | ) 255 | parser.add_argument( 256 | "count_matrix", 257 | type=str, 258 | nargs="+", 259 | help="The file of raw count matrix, 2D array (cells x genes) or the path of a filtered_feature_bc_matrix.h5", 260 | ) 261 | parser.add_argument( 262 | "-ap", 263 | "--ambient_profile", 264 | type=str, 265 | default=None, 266 | help="The file of empty profile obtained from empty droplets, 1D array", 267 | ) 268 | parser.add_argument( 269 | "-ft", 270 | "--feature_type", 271 | type=str, 272 | default="mRNA", 273 | help="The feature types, e.g. 
mRNA, sgRNA, ADT, tag, CMO and ATAC", 274 | ) 275 | parser.add_argument( 276 | "-o", "--output", type=str, default=None, help="Output directory" 277 | ) 278 | parser.add_argument( 279 | "-m", "--count_model", type=str, default="binomial", help="Count model" 280 | ) 281 | parser.add_argument( 282 | "-sp", 283 | "--sparsity", 284 | type=float, 285 | default=0.9, 286 | help="The sparsity of expected native signals", 287 | ) 288 | parser.add_argument( 289 | "-bk", 290 | "--batchkey", 291 | type=str, 292 | default=None, 293 | help="The batch key for batch correction", 294 | ) 295 | parser.add_argument( 296 | "-cache", 297 | "--cachecapacity", 298 | type=int, 299 | default=20000, 300 | help="The capacity of cache for batch correction", 301 | ) 302 | parser.add_argument( 303 | "-gnf", 304 | "--get_native_frequencies", 305 | type=int, 306 | default=0, 307 | help="Whether to get native frequencies, 0 or 1, by default 0, not to get native frequencies", 308 | ) 309 | parser.add_argument( 310 | "-hl1", 311 | "--hidden_layer1", 312 | type=int, 313 | default=150, 314 | help="Number of neurons in the first layer", 315 | ) 316 | parser.add_argument( 317 | "-hl2", 318 | "--hidden_layer2", 319 | type=int, 320 | default=100, 321 | help="Number of neurons in the second layer", 322 | ) 323 | parser.add_argument( 324 | "-ls", 325 | "--latent_dim", 326 | type=int, 327 | default=15, 328 | help="Dimension of latent space", 329 | ) 330 | parser.add_argument( 331 | "-epo", "--epochs", type=int, default=800, help="Training epochs" 332 | ) 333 | parser.add_argument( 334 | "-d", 335 | "--device", 336 | type=str, 337 | default="auto", 338 | help="Device used for training, either 'auto', 'cpu', or 'cuda'", 339 | ) 340 | parser.add_argument( 341 | "-s", 342 | "--save_model", 343 | type=int, 344 | default=False, 345 | help="Save the trained model", 346 | ) 347 | parser.add_argument( 348 | "-batchsize", 349 | "--batchsize", 350 | type=int, 351 | default=64, 352 | help="Batch size for training, set a small value upon out of memory error", 353 | ) 354 | parser.add_argument( 355 | "-batchsize_infer", 356 | "--batchsize_infer", 357 | type=int, 358 | default=4096, 359 | help="Batch size for inference, set a small value upon out of memory error", 360 | ) 361 | parser.add_argument( 362 | "-adjust", 363 | "--adjust", 364 | type=str, 365 | default="micro", 366 | help="""Only used for calculating Bayesfactors to improve performance, 367 | 368 | | 'micro' -- adjust the estimated native counts per cell. Default. 369 | | 'global' -- adjust the estimated native counts globally. 370 | | False -- no adjustment, use the model-returned native counts.""", 371 | ) 372 | parser.add_argument( 373 | "-cutoff", 374 | "--cutoff", 375 | type=float, 376 | default=3, 377 | help="Cutoff for Bayesfactors. See [Ly2020]_", 378 | ) 379 | parser.add_argument( 380 | "-round2int", 381 | "--round2int", 382 | type=str, 383 | default="stochastic_rounding", 384 | help="Round the counts", 385 | ) 386 | 387 | parser.add_argument( 388 | "-clip_to_obs", 389 | "--clip_to_obs", 390 | type=bool, 391 | default=False, 392 | help="clip the predicted native counts by observed counts, \ 393 | use it with caution, as it may lead to overestimation of overall noise.", 394 | ) 395 | parser.add_argument( 396 | "-moi", 397 | "--moi", 398 | type=float, 399 | default=None, 400 | help="Multiplicity of Infection. If assigned, it will allow optimized thresholding, \ 401 | which tests a series of cutoffs to find the best one based on distributions of infections under given moi. 
\ 402 | See [Dixit2016]_ for details. Under development.", 403 | ) 404 | parser.add_argument( 405 | "-verbose", 406 | "--verbose", 407 | type=bool, 408 | default=True, 409 | help="Whether to print the logging messages", 410 | ) 411 | return parser 412 | 413 | 414 | if __name__ == "__main__": 415 | main() 416 | -------------------------------------------------------------------------------- /scar/main/_activation_functions.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Customized activation functions 4 | """ 5 | 6 | import torch 7 | 8 | 9 | class mytanh(torch.nn.Module): 10 | def __init__(self): 11 | super().__init__() # init the base class 12 | 13 | def forward(self, input_x): 14 | var_tanh = torch.tanh(input_x) 15 | output = (1 + var_tanh) / 2 16 | return output 17 | 18 | 19 | class hnormalization(torch.nn.Module): 20 | def __init__(self): 21 | super().__init__() 22 | 23 | def forward(self, input_x): 24 | return input_x / (input_x.sum(dim=1).view(-1, 1) + 1e-5) 25 | 26 | 27 | class mysoftplus(torch.nn.Module): 28 | def __init__(self, sparsity=0.9): 29 | super().__init__() # init the base class 30 | self.sparsity = sparsity 31 | 32 | def forward(self, input_x): 33 | return self._mysoftplus(input_x) 34 | 35 | def _mysoftplus(self, input_x): 36 | """customized softplus activation, output range: [0, inf)""" 37 | var_sp = torch.nn.functional.softplus(input_x) 38 | threshold = torch.nn.functional.softplus( 39 | torch.tensor(-(1 - self.sparsity) * 10.0, device=input_x.device) 40 | ) 41 | var_sp = var_sp - threshold 42 | zero = torch.zeros_like(threshold) 43 | var_out = torch.where(var_sp <= zero, zero, var_sp) 44 | return var_out 45 | -------------------------------------------------------------------------------- /scar/main/_data_generater.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Module to generate synthetic datasets with ambient contamination""" 3 | 4 | import numpy as np 5 | from numpy import random 6 | import matplotlib.pyplot as plt 7 | import seaborn as sns 8 | 9 | ################################################################################### 10 | # Synthetic scrnaseq datasets 11 | 12 | 13 | class scrnaseq: 14 | """Generate synthetic single-cell RNAseq data with ambient contamination 15 | 16 | Parameters 17 | ---------- 18 | n_cells : int 19 | number of cells 20 | n_celltypes : int 21 | number of cell types 22 | n_features : int 23 | number of features (mRNA) 24 | n_total_molecules : int, optional 25 | total molecules per cell, by default 8000 26 | capture_rate : float, optional 27 | the probability of being captured by beads, by default 0.7 28 | 29 | Examples 30 | -------- 31 | .. 
plot::
31 | :context: close-figs
32 | 
33 | import numpy as np
34 | from scar import data_generator
35 | 
36 | n_features = 1000 # 1000 genes; visualization is poor with too large a number
37 | n_cells = 6000 # cells
38 | n_total_molecules = 20000 # total mRNAs
39 | n_celltypes = 8 # cell types
40 | 
41 | np.random.seed(8)
42 | scRNAseq = data_generator.scrnaseq(n_cells, n_celltypes, n_features, n_total_molecules=n_total_molecules)
43 | scRNAseq.generate(dirichlet_concentration_hyper=1)
44 | scRNAseq.heatmap(vmax=5)
45 | 
46 | """
47 | 
48 | def __init__(
49 | self, n_cells, n_celltypes, n_features, n_total_molecules=8000, capture_rate=0.7
50 | ):
51 | """initialization"""
52 | self.n_cells = n_cells
53 | """int, number of cells"""
54 | self.n_celltypes = n_celltypes
55 | """int, number of cell types"""
56 | self.n_features = n_features
57 | """int, number of features (mRNA, sgRNA, ADT, tag, CMO, etc.)"""
58 | self.n_total_molecules = n_total_molecules
59 | """int, number of total molecules per cell"""
60 | self.capture_rate = capture_rate
61 | """float, the probability of being captured by beads"""
62 | self.obs_count = None
63 | """vector, observed counts"""
64 | self.ambient_profile = None
65 | """vector, the probability of occurrence of each ambient transcript"""
66 | self.cell_identity = None
67 | """matrix, the onehot encoding of the identity of cell types"""
68 | self.noise_ratio = None
69 | """vector, contamination level per cell"""
70 | self.celltype = None
71 | """vector, the identity of cell types"""
72 | self.ambient_signals = None
73 | """matrix, the real ambient signals"""
74 | self.native_signals = None
75 | """matrix, the real native signals"""
76 | self.native_profile = None
77 | """matrix, the frequencies of the real native signals"""
78 | self.total_counts = None
79 | """vector, the total observed counts per cell"""
80 | self.empty_droplets = None
81 | """matrix, synthetic cell-free droplets"""
82 | 
83 | def generate(self, dirichlet_concentration_hyper=0.05):
84 | """Generate a synthetic scRNAseq dataset.
85 | 
86 | Parameters
87 | ----------
88 | dirichlet_concentration_hyper : None or real, optional
89 | the concentration hyperparameter of the dirichlet distribution, \
90 | determining the sparsity of native signals. \
91 | If None, 1 / n_features is used; by default 0.05.
92 | 
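            A hedged illustration of the effect (values below are for illustration only, not from the source): smaller concentrations yield sparser cell-type profiles.

            >>> alpha = np.ones(1000) * 0.05        # concentration << 1
            >>> theta = np.random.dirichlet(alpha)  # most mass on a few genes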
93 | 94 | Returns 95 | ------- 96 | After running, several attributes are added 97 | """ 98 | 99 | if dirichlet_concentration_hyper: 100 | alpha = np.ones(self.n_features) * dirichlet_concentration_hyper 101 | else: 102 | alpha = np.ones(self.n_features) / self.n_features 103 | 104 | # simulate native expression frequencies for each cell 105 | cell_comp_prior = random.dirichlet(np.ones(self.n_celltypes)) 106 | celltype = random.choice( 107 | a=self.n_celltypes, size=self.n_cells, p=cell_comp_prior 108 | ) 109 | cell_identity = np.identity(self.n_celltypes)[celltype] 110 | theta_celltype = random.dirichlet(alpha, size=self.n_celltypes) 111 | 112 | beta_in_each_cell = cell_identity.dot(theta_celltype) 113 | 114 | # simulate total molecules for a droplet in ambient pool 115 | n_total_mol = random.randint( 116 | low=self.n_total_molecules / 5, high=self.n_total_molecules / 2, size=1 117 | ) 118 | 119 | # simulate ambient signals 120 | beta0 = random.dirichlet(np.ones(self.n_features)) 121 | tot_count0 = random.negative_binomial( 122 | n_total_mol, self.capture_rate, size=self.n_cells 123 | ) 124 | ambient_signals = np.vstack( 125 | [random.multinomial(n=tot_c, pvals=beta0) for tot_c in tot_count0] 126 | ) 127 | 128 | # add empty droplets 129 | tot_count0_empty = random.negative_binomial( 130 | n_total_mol, self.capture_rate, size=self.n_cells 131 | ) 132 | ambient_signals_empty = np.vstack( 133 | [random.multinomial(n=tot_c, pvals=beta0) for tot_c in tot_count0_empty] 134 | ) 135 | 136 | # simulate native signals 137 | tot_trails = random.randint( 138 | low=self.n_total_molecules / 2, 139 | high=self.n_total_molecules, 140 | size=self.n_celltypes, 141 | ) 142 | tot_count1 = [ 143 | random.negative_binomial(tot, self.capture_rate) 144 | for tot in cell_identity.dot(tot_trails) 145 | ] 146 | 147 | native_signals = np.vstack( 148 | [ 149 | random.multinomial(n=tot_c, pvals=theta1) 150 | for tot_c, theta1 in zip(tot_count1, beta_in_each_cell) 151 | ] 152 | ) 153 | obs = ambient_signals + native_signals 154 | 155 | noise_ratio = tot_count0 / (tot_count0 + tot_count1) 156 | 157 | self.obs_count = obs 158 | self.ambient_profile = beta0 159 | self.cell_identity = cell_identity 160 | self.noise_ratio = noise_ratio 161 | self.celltype = celltype 162 | self.ambient_signals = ambient_signals 163 | self.native_signals = native_signals 164 | self.native_profile = beta_in_each_cell 165 | self.total_counts = obs.sum(axis=1) 166 | self.empty_droplets = ambient_signals_empty.astype(int) 167 | 168 | def heatmap( 169 | self, feature_type="mRNA", return_obj=False, figsize=(12, 4), vmin=0, vmax=10 170 | ): 171 | """Heatmap of synthetic data. 
172 | 173 | Parameters 174 | ---------- 175 | feature_type : str, optional 176 | the feature types, by default "mRNA" 177 | return_obj : bool, optional 178 | whether to output figure object, by default False 179 | figsize : tuple, optional 180 | figure size, by default (15, 5) 181 | vmin : int, optional 182 | colorbar minimum, by default 0 183 | vmax : int, optional 184 | colorbar maximum, by default 10 185 | 186 | Returns 187 | ------- 188 | fig object 189 | if return_obj, return a fig object 190 | """ 191 | sort_cell_idx = [] 192 | for f in self.ambient_profile.argsort(): 193 | sort_cell_idx += list(np.where(self.celltype == f)[0]) 194 | 195 | native_signals = self.native_signals[sort_cell_idx][ 196 | :, self.ambient_profile.argsort() 197 | ] 198 | ambient_signals = self.ambient_signals[sort_cell_idx][ 199 | :, self.ambient_profile.argsort() 200 | ] 201 | obs = self.obs_count[sort_cell_idx][:, self.ambient_profile.argsort()] 202 | 203 | fig, axs = plt.subplots(ncols=3, figsize=figsize) 204 | sns.heatmap( 205 | np.log2(obs + 1), 206 | yticklabels=False, 207 | vmin=vmin, 208 | vmax=vmax, 209 | cmap="coolwarm", 210 | center=1, 211 | ax=axs[0], 212 | rasterized=True, 213 | cbar_kws={"label": "log2(counts + 1)"}, 214 | ) 215 | axs[0].set_title("noisy observation") 216 | 217 | sns.heatmap( 218 | np.log2(ambient_signals + 1), 219 | yticklabels=False, 220 | vmin=vmin, 221 | vmax=vmax, 222 | cmap="coolwarm", 223 | center=1, 224 | ax=axs[1], 225 | rasterized=True, 226 | cbar_kws={"label": "log2(counts + 1)"}, 227 | ) 228 | axs[1].set_title("ambient signals") 229 | 230 | sns.heatmap( 231 | np.log2(native_signals + 1), 232 | yticklabels=False, 233 | vmin=vmin, 234 | vmax=vmax, 235 | cmap="coolwarm", 236 | center=1, 237 | ax=axs[2], 238 | rasterized=True, 239 | cbar_kws={"label": "log2(counts + 1)"}, 240 | ) 241 | axs[2].set_title("native signals") 242 | 243 | fig.supxlabel(feature_type) 244 | fig.supylabel("cells") 245 | plt.tight_layout() 246 | 247 | if return_obj: 248 | return fig 249 | 250 | 251 | ###################################################################################### 252 | # Synthetic citeseq datasets 253 | class citeseq(scrnaseq): 254 | """Generate synthetic ADT count data for CITE-seq with ambient contamination 255 | 256 | Parameters 257 | ---------- 258 | n_cells : int 259 | number of cells 260 | n_celltypes : int 261 | number of cell types 262 | n_features : int 263 | number of distinct antibodies (ADTs) 264 | n_total_molecules : int, optional 265 | number of total molecules, by default 8000 266 | capture_rate : float, optional 267 | the probabilities of being captured by beads, by default 0.7 268 | 269 | Examples 270 | -------- 271 | .. plot:: 272 | :context: close-figs 273 | 274 | import numpy as np 275 | from scar import data_generator 276 | 277 | n_features = 50 # 50 ADTs 278 | n_cells = 6000 # 6000 cells 279 | n_celltypes = 6 # cell types 280 | 281 | # generate a synthetic ADT count dataset 282 | np.random.seed(8) 283 | citeseq = data_generator.citeseq(n_cells, n_celltypes, n_features) 284 | citeseq.generate() 285 | citeseq.heatmap() 286 | 287 | """ 288 | 289 | def __init__( 290 | self, n_cells, n_celltypes, n_features, n_total_molecules=8000, capture_rate=0.7 291 | ): 292 | super().__init__( 293 | n_cells, n_celltypes, n_features, n_total_molecules, capture_rate 294 | ) 295 | 296 | def generate(self, dirichlet_concentration_hyper=None): 297 | """Generate a synthetic ADT dataset. 
298 | 299 | Parameters 300 | ---------- 301 | dirichlet_concentration_hyper : None or real, optional 302 | the concentration hyperparameters of dirichlet distribution. \ 303 | If None, 1 / n_features, by default None 304 | 305 | Returns 306 | ------- 307 | After running, several attributes are added 308 | """ 309 | 310 | if dirichlet_concentration_hyper: 311 | alpha = np.ones(self.n_features) * dirichlet_concentration_hyper 312 | else: 313 | alpha = np.ones(self.n_features) / self.n_features 314 | 315 | # simulate native expression frequencies for each cell 316 | cell_comp_prior = random.dirichlet(np.ones(self.n_celltypes)) 317 | celltype = random.choice( 318 | a=self.n_celltypes, size=self.n_cells, p=cell_comp_prior 319 | ) 320 | cell_identity = np.identity(self.n_celltypes)[celltype] 321 | theta_celltype = random.dirichlet(alpha, size=self.n_celltypes) 322 | beta_in_each_cell = cell_identity.dot(theta_celltype) 323 | 324 | # simulate total molecules for a droplet in ambient pool 325 | n_total_mol = random.randint( 326 | low=self.n_total_molecules / 5, high=self.n_total_molecules / 2, size=1 327 | ) 328 | 329 | # simulate ambient signals 330 | beta0 = random.dirichlet(np.ones(self.n_features)) 331 | tot_count0 = random.negative_binomial( 332 | n_total_mol, self.capture_rate, size=self.n_cells 333 | ) 334 | ambient_signals = np.vstack( 335 | [random.multinomial(n=tot_c, pvals=beta0) for tot_c in tot_count0] 336 | ) 337 | 338 | # add empty droplets 339 | tot_count0_empty = random.negative_binomial( 340 | n_total_mol, self.capture_rate, size=self.n_cells 341 | ) 342 | ambient_signals_empty = np.vstack( 343 | [random.multinomial(n=tot_c, pvals=beta0) for tot_c in tot_count0_empty] 344 | ) 345 | 346 | # simulate native signals 347 | tot_trails = random.randint( 348 | low=self.n_total_molecules / 2, 349 | high=self.n_total_molecules, 350 | size=self.n_celltypes, 351 | ) 352 | tot_count1 = [ 353 | random.negative_binomial(tot, self.capture_rate) 354 | for tot in cell_identity.dot(tot_trails) 355 | ] 356 | 357 | native_signals = np.vstack( 358 | [ 359 | random.multinomial(n=tot_c, pvals=theta1) 360 | for tot_c, theta1 in zip(tot_count1, beta_in_each_cell) 361 | ] 362 | ) 363 | obs = ambient_signals + native_signals 364 | 365 | noise_ratio = tot_count0 / (tot_count0 + tot_count1) 366 | 367 | self.obs_count = obs 368 | self.ambient_profile = beta0 369 | self.cell_identity = cell_identity 370 | self.noise_ratio = noise_ratio 371 | self.celltype = celltype 372 | self.ambient_signals = ambient_signals 373 | self.native_signals = native_signals 374 | self.native_profile = beta_in_each_cell 375 | self.total_counts = obs.sum(axis=1) 376 | self.empty_droplets = ambient_signals_empty.astype(int) 377 | 378 | def heatmap( 379 | self, feature_type="ADT", return_obj=False, figsize=(12, 4), vmin=0, vmax=10 380 | ): 381 | """Heatmap of synthetic data. 
382 | 383 | Parameters 384 | ---------- 385 | feature_type : str, optional 386 | the feature types, by default "ADT" 387 | return_obj : bool, optional 388 | whether to output figure object, by default False 389 | figsize : tuple, optional 390 | figure size, by default (15, 5) 391 | vmin : int, optional 392 | colorbar minimum, by default 0 393 | vmax : int, optional 394 | colorbar maximum, by default 10 395 | 396 | Returns 397 | ------- 398 | fig object 399 | if return_obj, return a fig object 400 | """ 401 | sort_cell_idx = [] 402 | for f in self.ambient_profile.argsort(): 403 | sort_cell_idx += list(np.where(self.celltype == f)[0]) 404 | 405 | native_signals = self.native_signals[sort_cell_idx][ 406 | :, self.ambient_profile.argsort() 407 | ] 408 | ambient_signals = self.ambient_signals[sort_cell_idx][ 409 | :, self.ambient_profile.argsort() 410 | ] 411 | obs = self.obs_count[sort_cell_idx][:, self.ambient_profile.argsort()] 412 | 413 | fig, axs = plt.subplots(ncols=3, figsize=figsize) 414 | sns.heatmap( 415 | np.log2(obs + 1), 416 | yticklabels=False, 417 | vmin=vmin, 418 | vmax=vmax, 419 | cmap="coolwarm", 420 | center=1, 421 | ax=axs[0], 422 | rasterized=True, 423 | cbar_kws={"label": "log2(counts + 1)"}, 424 | ) 425 | axs[0].set_title("noisy observation") 426 | 427 | sns.heatmap( 428 | np.log2(ambient_signals + 1), 429 | yticklabels=False, 430 | vmin=vmin, 431 | vmax=vmax, 432 | cmap="coolwarm", 433 | center=1, 434 | ax=axs[1], 435 | rasterized=True, 436 | cbar_kws={"label": "log2(counts + 1)"}, 437 | ) 438 | axs[1].set_title("ambient signals") 439 | 440 | sns.heatmap( 441 | np.log2(native_signals + 1), 442 | yticklabels=False, 443 | vmin=vmin, 444 | vmax=vmax, 445 | cmap="coolwarm", 446 | center=1, 447 | ax=axs[2], 448 | rasterized=True, 449 | cbar_kws={"label": "log2(counts + 1)"}, 450 | ) 451 | axs[2].set_title("native signals") 452 | 453 | fig.supxlabel(feature_type) 454 | fig.supylabel("cells") 455 | plt.tight_layout() 456 | 457 | if return_obj: 458 | return fig 459 | 460 | 461 | ########################################################################################## 462 | # Synthetic cropseq datasets 463 | 464 | 465 | class cropseq(scrnaseq): 466 | """Generate synthetic sgRNA count data for scCRISPRseq with ambient contamination 467 | 468 | Parameters 469 | ---------- 470 | n_cells : int 471 | number of cells 472 | n_celltypes : int 473 | number of cell types 474 | n_features : int 475 | number of dinstinct sgRNAs 476 | library_pattern : str, optional 477 | the pattern of sgRNA libraries, three possibilities: 478 | 479 | | "uniform" - each sgRNA has equal frequency in the libraries 480 | | "pyramid" - a few sgRNAs have significantly higher frequencies in the libraries 481 | | "reverse_pyramid" - a few sgRNAs have significantly lower frequencies in the libraries 482 | | By default "pyramid". 483 | noise_ratio : float, optional 484 | global contamination level, by default 0.005 485 | average_counts_per_cell : int, optional 486 | average total sgRNA counts per cell, by default 2000 487 | doublet_rate : int, optional 488 | doublet rate, by default 0 489 | missing_rate : int, optional 490 | the fraction of droplets which have zero sgRNAs integrated, by default 0 491 | 492 | Examples 493 | -------- 494 | 495 | .. 
plot:: 496 | :context: close-figs 497 | 498 | import numpy as np 499 | from scar import data_generator 500 | 501 | n_features = 100 # 100 sgRNAs in the libraries 502 | n_cells = 6000 # 6000 cells 503 | n_celltypes = 1 # single cell line 504 | 505 | # generate a synthetic sgRNA count dataset 506 | np.random.seed(8) 507 | cropseq = data_generator.cropseq(n_cells, n_celltypes, n_features) 508 | cropseq.generate(noise_ratio=0.98) 509 | cropseq.heatmap(vmax=6) 510 | """ 511 | 512 | def __init__( 513 | self, 514 | n_cells, 515 | n_celltypes, 516 | n_features, 517 | ): 518 | super().__init__(n_cells, n_celltypes, n_features) 519 | 520 | self.sgrna_freq = None 521 | """vector, sgRNA frequencies in the libraries 522 | """ 523 | 524 | # generate a pool of sgrnas 525 | def _set_sgrna_frequency(self): 526 | """set the pattern of sgrna library""" 527 | if self.library_pattern == "uniform": 528 | self.sgrna_freq = 1.0 / self.n_features 529 | elif self.library_pattern == "pyramid": 530 | uniform_spaced_values = np.random.permutation(self.n_features + 1) / ( 531 | self.n_features + 1 532 | ) 533 | uniform_spaced_values = uniform_spaced_values[uniform_spaced_values != 0] 534 | log_values = np.log(uniform_spaced_values) 535 | self.sgrna_freq = log_values / np.sum(log_values) 536 | elif self.library_pattern == "reverse_pyramid": 537 | uniform_spaced_values = ( 538 | np.random.permutation(self.n_features) / self.n_features 539 | ) 540 | log_values = (uniform_spaced_values + 0.001) ** (1 / 10) 541 | self.sgrna_freq = log_values / np.sum(log_values) 542 | 543 | def _set_native_signals(self): 544 | """generatation of native signals""" 545 | 546 | self._set_sgrna_frequency() 547 | 548 | # cells without any sgrnas 549 | n_cells_miss = int(self.n_cells * self.missing_rate) 550 | 551 | # Doublets 552 | n_doublets = int(self.n_cells * self.doublet_rate) 553 | 554 | # total number of single sgrnas which are integrated into cells 555 | # (cells with double sgrnas will be counted twice) 556 | n_cells_integrated = self.n_cells - n_cells_miss + n_doublets 557 | 558 | # create cells with sgrnas based on sgrna frequencies 559 | self.celltype = random.choice( 560 | a=range(self.n_features), size=n_cells_integrated, p=self.sgrna_freq 561 | ) 562 | self.cell_identity = np.eye(self.n_features)[self.celltype] # cell_identity 563 | 564 | def _add_ambient(self): 565 | """add ambient signals""" 566 | self._set_native_signals() 567 | sgrna_mixed_freq = ( 568 | 1 - self.noise_ratio 569 | ) * self.cell_identity + self.noise_ratio * self.sgrna_freq 570 | sgrna_mixed_freq = sgrna_mixed_freq / sgrna_mixed_freq.sum( 571 | axis=1, keepdims=True 572 | ) 573 | return sgrna_mixed_freq 574 | 575 | # function to generate counts per cell 576 | def generate( 577 | self, 578 | dirichlet_concentration_hyper=None, 579 | library_pattern="pyramid", 580 | noise_ratio=0.96, 581 | average_counts_per_cell=2000, 582 | doublet_rate=0, 583 | missing_rate=0, 584 | ): 585 | """Generate a synthetic sgRNA count dataset. 
586 | 
587 | Parameters
588 | ----------
589 | library_pattern : str, optional
590 | library pattern, by default "pyramid"
591 | noise_ratio : float, optional
592 | global contamination level, by default 0.96
593 | average_counts_per_cell : int, optional
594 | average total sgRNA counts per cell, by default 2000
595 | doublet_rate : float, optional
596 | doublet rate, by default 0
597 | missing_rate : float, optional
598 | the fraction of droplets which have zero sgRNAs integrated, by default 0
599 | 
600 | Returns
601 | -------
602 | After running, several attributes are added
603 | """
604 | 
605 | assert library_pattern.lower() in ["uniform", "pyramid", "reverse_pyramid"]
606 | self.library_pattern = library_pattern.lower()
607 | """str, library pattern
608 | """
609 | self.doublet_rate = doublet_rate
610 | """float, doublet rate
611 | """
612 | self.missing_rate = missing_rate
613 | """float, the fraction of droplets which have zero sgRNAs integrated.
614 | """
615 | self.noise_ratio = noise_ratio
616 | """float, global contamination level
617 | """
618 | self.average_counts_per_cell = average_counts_per_cell
619 | """int, the mean of total sgRNA counts per cell
620 | """
621 | 
622 | # generate total counts per cell
623 | total_counts = random.negative_binomial(
624 | self.average_counts_per_cell, 0.7, size=self.n_cells
625 | )
626 | 
627 | # the mixed sgrna expression profile
628 | sgrna_mixed_freq = self._add_ambient()
629 | 
630 | # final count matrix:
631 | obs = np.vstack(
632 | [
633 | random.multinomial(n=tot_c, pvals=p)
634 | for tot_c, p in zip(total_counts, sgrna_mixed_freq)
635 | ]
636 | )
637 | 
638 | self.obs_count = obs
639 | self.total_counts = total_counts
640 | self.ambient_profile = self.sgrna_freq
641 | self.native_signals = (
642 | self.total_counts.reshape(-1, 1)
643 | * (1 - self.noise_ratio)
644 | * self.cell_identity
645 | )
646 | self.ambient_signals = np.clip(obs - self.native_signals, 0, None)
647 | 
648 | def heatmap(
649 | self, feature_type="sgRNAs", return_obj=False, figsize=(12, 4), vmin=0, vmax=7
650 | ):
651 | """Heatmap of synthetic data. 
652 | 653 | Parameters 654 | ---------- 655 | feature_type : str, optional 656 | the feature types, by default "sgRNAs" 657 | return_obj : bool, optional 658 | whether to output figure object, by default False 659 | figsize : tuple, optional 660 | figure size, by default (15, 5) 661 | vmin : int, optional 662 | colorbar minimum, by default 0 663 | vmax : int, optional 664 | colorbar maximum, by default 10 665 | 666 | Returns 667 | ------- 668 | fig object 669 | if return_obj, return a fig object 670 | """ 671 | sort_cell_idx = [] 672 | for f in self.ambient_profile.argsort(): 673 | sort_cell_idx += list(np.where(self.celltype == f)[0]) 674 | 675 | native_signals = self.native_signals[sort_cell_idx][ 676 | :, self.ambient_profile.argsort() 677 | ] 678 | ambient_signals = self.ambient_signals[sort_cell_idx][ 679 | :, self.ambient_profile.argsort() 680 | ] 681 | obs = self.obs_count[sort_cell_idx][:, self.ambient_profile.argsort()] 682 | 683 | fig, axs = plt.subplots(ncols=3, figsize=figsize) 684 | sns.heatmap( 685 | np.log2(obs + 1), 686 | yticklabels=False, 687 | vmin=vmin, 688 | vmax=vmax, 689 | cmap="coolwarm", 690 | center=1, 691 | ax=axs[0], 692 | rasterized=True, 693 | cbar_kws={"label": "log2(counts + 1)"}, 694 | ) 695 | axs[0].set_title("noisy observation") 696 | 697 | sns.heatmap( 698 | np.log2(ambient_signals + 1), 699 | yticklabels=False, 700 | vmin=vmin, 701 | vmax=vmax, 702 | cmap="coolwarm", 703 | center=1, 704 | ax=axs[1], 705 | rasterized=True, 706 | cbar_kws={"label": "log2(counts + 1)"}, 707 | ) 708 | axs[1].set_title("ambient signals") 709 | 710 | sns.heatmap( 711 | np.log2(native_signals + 1), 712 | yticklabels=False, 713 | vmin=vmin, 714 | vmax=vmax, 715 | cmap="coolwarm", 716 | center=1, 717 | ax=axs[2], 718 | rasterized=True, 719 | cbar_kws={"label": "log2(counts + 1)"}, 720 | ) 721 | axs[2].set_title("native signals") 722 | 723 | fig.supxlabel(feature_type) 724 | fig.supylabel("cells") 725 | plt.tight_layout() 726 | 727 | if return_obj: 728 | return fig 729 | -------------------------------------------------------------------------------- /scar/main/_loss_functions.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Loss functions""" 3 | 4 | import torch 5 | from torch.distributions import Normal, kl_divergence, Binomial, Poisson 6 | from pyro.distributions.zero_inflated import ZeroInflatedPoisson 7 | 8 | def kld(means, var): 9 | """KL divergence""" 10 | mean = torch.zeros_like(means) 11 | scale = torch.ones_like(var) 12 | return kl_divergence(Normal(means, torch.sqrt(var)), Normal(mean, scale)).sum(dim=1) 13 | 14 | 15 | def get_reconstruction_loss( 16 | input_matrix, dec_nr, dec_prob, amb_prob, dec_dp, count_model 17 | ): 18 | """reconstruction loss""" 19 | tot_count = input_matrix.sum(dim=1).view(-1, 1) 20 | prob_tot = dec_prob * (1 - dec_nr) + amb_prob * dec_nr 21 | 22 | if count_model.lower() == "zeroinflatedpoisson": 23 | recon_loss = -ZeroInflatedPoisson( 24 | rate=tot_count * prob_tot / (1 - dec_dp), gate=dec_dp, validate_args=False 25 | ).log_prob(input_matrix) 26 | recon_loss = torch.nan_to_num(recon_loss, nan=1e-7, posinf=1e15, neginf=-1e15) 27 | recon_loss = recon_loss.sum(axis=1).mean() 28 | 29 | elif count_model.lower() == "binomial": 30 | recon_loss = -Binomial(tot_count, probs=prob_tot, validate_args=False).log_prob( 31 | input_matrix 32 | ) 33 | recon_loss = torch.nan_to_num(recon_loss, nan=1e-7, posinf=1e15, neginf=-1e15) 34 | recon_loss = recon_loss.sum(axis=1).mean() 35 | 36 | elif 
count_model.lower() == "poisson": 37 | recon_loss = -Poisson(rate=tot_count * prob_tot, validate_args=False).log_prob( 38 | input_matrix 39 | ) # add 1 to avoid a situation where all counts are zeros 40 | recon_loss = torch.nan_to_num(recon_loss, nan=1e-7, posinf=1e15, neginf=-1e15) 41 | recon_loss = recon_loss.sum(axis=1).mean() 42 | 43 | return recon_loss 44 | 45 | 46 | def loss_fn( 47 | input_matrix, 48 | dec_nr, 49 | dec_prob, 50 | means, 51 | var, 52 | amb_prob, 53 | reconstruction_weight, 54 | kld_weight=1e-5, 55 | dec_dp=None, 56 | count_model="binomial", 57 | ): 58 | """loss function""" 59 | 60 | recon_loss = get_reconstruction_loss( 61 | input_matrix, dec_nr, dec_prob, amb_prob, dec_dp=dec_dp, count_model=count_model 62 | ) 63 | kld_loss = kld(means, var).sum() 64 | total_loss = recon_loss * reconstruction_weight + kld_loss * kld_weight 65 | 66 | return recon_loss, kld_loss, total_loss 67 | -------------------------------------------------------------------------------- /scar/main/_scar.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """The main module of scar""" 4 | 5 | import sys, time, contextlib, torch 6 | from typing import Optional, Union 7 | from scipy import sparse 8 | import numpy as np, pandas as pd, anndata as ad 9 | 10 | from torch.utils.data import Dataset, random_split, DataLoader 11 | from tqdm import tqdm 12 | from tqdm.contrib import DummyTqdmFile 13 | 14 | from ._vae import VAE 15 | from ._loss_functions import loss_fn 16 | from ._utils import get_logger 17 | 18 | 19 | @contextlib.contextmanager 20 | def std_out_err_redirect_tqdm(): 21 | """ 22 | Writing progressbar into stdout rather than stderr, 23 | from https://github.com/tqdm/tqdm/blob/master/examples/redirect_print.py 24 | """ 25 | orig_out_err = sys.stdout, sys.stderr 26 | try: 27 | sys.stdout, sys.stderr = map(DummyTqdmFile, orig_out_err) 28 | yield orig_out_err[0] 29 | # Relay exceptions 30 | except Exception as exc: 31 | raise exc 32 | # Always restore sys.stdout/err if necessary 33 | finally: 34 | sys.stdout, sys.stderr = orig_out_err 35 | 36 | 37 | # scar class 38 | class model: 39 | """The scar model 40 | 41 | Parameters 42 | ---------- 43 | raw_count : Union[str, np.ndarray, pd.DataFrame, ad.AnnData] 44 | Raw count matrix or Anndata object. 45 | 46 | .. note:: 47 | scar takes the raw UMI counts as input. No size normalization or log transformation. 48 | 49 | ambient_profile : Optional[Union[str, np.ndarray, pd.DataFrame]], optional 50 | the probability of occurrence of each ambient transcript.\ 51 | If None, averaging cells to estimate the ambient profile, by default None 52 | nn_layer1 : int, optional 53 | number of neurons of the 1st layer, by default 150 54 | nn_layer2 : int, optional 55 | number of neurons of the 2nd layer, by default 100 56 | latent_dim : int, optional 57 | number of neurons of the bottleneck layer, by default 15 58 | dropout_prob : float, optional 59 | dropout probability of neurons, by default 0 60 | feature_type : str, optional 61 | the feature to be denoised. One of the following: 62 | 63 | | 'mRNA' -- transcriptome data, including scRNAseq and snRNAseq 64 | | 'ADT' -- protein counts in CITE-seq 65 | | 'sgRNA' -- sgRNA counts for scCRISPRseq 66 | | 'tag' -- identity barcodes or any data types of high sparsity. 
\
67 | E.g., in cell indexing experiments, we would expect a single true signal \
68 | (1) and many negative signals (0) for each cell
69 | | 'CMO' -- Cell Multiplexing Oligo counts for cell hashing
70 | | 'ATAC' -- peak counts for scATACseq
71 | .. versionadded:: 0.5.2
72 | | By default "mRNA"
73 | count_model : str, optional
74 | the model to generate the UMI count. One of the following:
75 | 
76 | | 'binomial' -- binomial model,
77 | | 'poisson' -- poisson model,
78 | | 'zeroinflatedpoisson' -- zeroinflatedpoisson model, by default "binomial"
79 | sparsity : float, optional
80 | range: [0, 1]. The sparsity of expected native signals. \
81 | It varies between datasets, e.g. if one prefilters genes -- \
82 | using only highly variable genes -- \
83 | the sparsity should be low; on the other hand, it should be set high \
84 | in the case of unfiltered genes. \
85 | Forced to be one in the mode of "sgRNA(s)" and "tag(s)". \
86 | Thanks to Will Macnair for the valuable feedback.
87 | 
88 | .. versionadded:: 0.4.0
89 | cache_capacity : int, optional
90 | the capacity of caching data on GPU. Set a smaller value upon GPU memory issue. By default 20000 cells are cached.
91 | 
92 | .. versionadded:: 0.7.0
93 | batch_key : str, optional
94 | batch key in AnnData.obs, by default None. \
95 | If assigned, batch-wise ambient removal will be performed and \
96 | the ambient profile will be estimated for each batch.
97 | 
98 | .. versionadded:: 0.7.0
99 | 
100 | device : str, optional
101 | either "auto", "cpu", "cuda" or "mps", by default "auto"
102 | verbose : bool, optional
103 | whether to print the details, by default True
104 | 
105 | Raises
106 | ------
107 | TypeError
108 | if raw_count is not str or np.ndarray or pd.DataFrame
109 | TypeError
110 | if ambient_profile is not str or np.ndarray or pd.DataFrame or None
111 | 
112 | Examples
113 | --------
114 | >>> # Real data
115 | >>> import scanpy as sc
116 | >>> from scar import model
117 | >>> adata = sc.read("...")  # load an anndata object
118 | >>> scarObj = model(adata, ambient_profile)  # initialize the scar model
119 | >>> scarObj.train()  # start training
120 | >>> scarObj.inference()  # inference
121 | >>> adata.layers["X_scar_denoised"] = scarObj.native_counts  # results are saved in scarObj
122 | >>> adata.obsm["X_scar_assignment"] = scarObj.feature_assignment  # 'sgRNA' or 'tag' feature type
123 | 
124 | Examples
125 | -------------------------
126 | .. 
plot:: 127 | :context: close-figs 128 | 129 | # Synthetic data 130 | import numpy as np 131 | import seaborn as sns 132 | import matplotlib.pyplot as plt 133 | from scar import data_generator, model 134 | 135 | # Generate a synthetic ADT count dataset 136 | np.random.seed(8) 137 | n_features = 50 # 50 ADTs 138 | n_cells = 6000 # 6000 cells 139 | n_celltypes = 6 # cell types 140 | citeseq = data_generator.citeseq(n_cells, n_celltypes, n_features) 141 | citeseq.generate() 142 | 143 | # Train scAR 144 | citeseq_denoised = model(citeseq.obs_count, citeseq.ambient_profile, feature_type="ADT", sparsity=0.6) # initialize scar model 145 | citeseq_denoised.train(epochs=100, verbose=False) # start training 146 | citeseq_denoised.inference() # inference 147 | 148 | # Visualization 149 | sorted_noisy_counts = citeseq.obs_count[citeseq.celltype.argsort()][ 150 | :, citeseq.ambient_profile.argsort() 151 | ] # noisy observation 152 | sorted_native_counts = citeseq.native_signals[citeseq.celltype.argsort()][ 153 | :, citeseq.ambient_profile.argsort() 154 | ] # native counts 155 | sorted_denoised_counts = citeseq_denoised.native_counts.toarray()[citeseq.celltype.argsort()][ 156 | :, citeseq.ambient_profile.argsort() 157 | ] # denoised counts 158 | 159 | fig, axs = plt.subplots(ncols=3, figsize=(12,4)) 160 | sns.heatmap( 161 | np.log2(sorted_noisy_counts + 1), 162 | yticklabels=False, 163 | vmin=0, 164 | vmax=10, 165 | cmap="coolwarm", 166 | center=1, 167 | ax=axs[0], 168 | cbar_kws={"label": "log2(counts + 1)"}, 169 | ) 170 | axs[0].set_title("noisy observation") 171 | 172 | sns.heatmap( 173 | np.log2(sorted_native_counts + 1), 174 | yticklabels=False, 175 | vmin=0, 176 | vmax=10, 177 | cmap="coolwarm", 178 | center=1, 179 | ax=axs[1], 180 | cbar_kws={"label": "log2(counts + 1)"}, 181 | ) 182 | axs[1].set_title("native counts (ground truth)") 183 | 184 | sns.heatmap( 185 | np.log2(sorted_denoised_counts + 1), 186 | yticklabels=False, 187 | vmin=0, 188 | vmax=10, 189 | cmap="coolwarm", 190 | center=1, 191 | ax=axs[2], 192 | cbar_kws={"label": "log2(counts + 1)"}, 193 | ) 194 | axs[2].set_title("denoised counts (prediction)") 195 | 196 | fig.supxlabel("ADTs") 197 | fig.supylabel("cells") 198 | plt.tight_layout() 199 | 200 | """ 201 | 202 | def __init__( 203 | self, 204 | raw_count: Union[str, np.ndarray, pd.DataFrame, ad.AnnData], 205 | ambient_profile: Optional[Union[str, np.ndarray, pd.DataFrame]] = None, 206 | nn_layer1: int = 150, 207 | nn_layer2: int = 100, 208 | latent_dim: int = 15, 209 | dropout_prob: float = 0, 210 | feature_type: str = "mRNA", 211 | count_model: str = "binomial", 212 | sparsity: float = 0.9, 213 | batch_key: str = None, 214 | device: str = "auto", 215 | cache_capacity: int = 20000, 216 | verbose: bool = True, 217 | ): 218 | """initialize object""" 219 | 220 | self.logger = get_logger("model", verbose=verbose) 221 | """logging.Logger, the logger for this class. 222 | """ 223 | 224 | if device == "auto": 225 | if torch.cuda.is_available(): 226 | self.device = torch.device("cuda") 227 | self.logger.info(f"{self.device} is detected and will be used.") 228 | elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): 229 | self.device = torch.device("mps") 230 | self.logger.info(f"{self.device} is detected and will be used.") 231 | else: 232 | self.device = torch.device("cpu") 233 | self.logger.info(f"No GPU detected. 
{self.device} will be used.") 234 | else: 235 | self.device = device 236 | self.logger.info(f"{device} will be used.") 237 | 238 | """str, either "auto, "cpu" or "cuda". 239 | """ 240 | self.nn_layer1 = nn_layer1 241 | """int, number of neurons of the 1st layer. 242 | """ 243 | self.nn_layer2 = nn_layer2 244 | """int, number of neurons of the 2nd layer. 245 | """ 246 | self.latent_dim = latent_dim 247 | """int, number of neurons of the bottleneck layer. 248 | """ 249 | self.dropout_prob = dropout_prob 250 | """float, dropout probability of neurons. 251 | """ 252 | self.feature_type = feature_type 253 | """str, the feature to be denoised. One of the following: 254 | 255 | | 'mRNA' -- transcriptome 256 | | 'ADT' -- protein counts in CITE-seq 257 | | 'sgRNA' -- sgRNA counts for scCRISPRseq 258 | | 'tag' -- identity barcodes or any data types of super high sparsity. \ 259 | E.g., in cell indexing experiments, we would expect a single true signal \ 260 | (1) and many negative signals (0) for each cell. 261 | | 'CMO' -- Cell Multiplexing Oligo counts for cell hashing 262 | | 'ATAC' -- peak counts for scATACseq 263 | | By default "mRNA" 264 | """ 265 | self.count_model = count_model 266 | """str, the model to generate the UMI count. One of the following: 267 | 268 | | 'binomial' -- binomial model, 269 | | 'poisson' -- poisson model, 270 | | 'zeroinflatedpoisson' -- zeroinflatedpoisson model. 271 | """ 272 | self.sparsity = sparsity 273 | """float, the sparsity of expected native signals. (0, 1]. \ 274 | Forced to be one in the mode of "sgRNA(s)" and "tag(s)". 275 | """ 276 | self.cache_capacity = cache_capacity 277 | """int, the capacity of caching data on GPU. Set a smaller value upon GPU memory issue. By default 20000 cells are cached on GPU/MPS. 278 | 279 | .. versionadded:: 0.7.0 280 | """ 281 | 282 | if isinstance(raw_count, ad.AnnData): 283 | if batch_key is not None: 284 | if batch_key not in raw_count.obs.columns: 285 | raise ValueError(f"{batch_key} not found in AnnData.obs.") 286 | 287 | self.logger.info( 288 | f"Found {raw_count.obs[batch_key].nunique()} batches defined by {batch_key} in AnnData.obs. Estimating ambient profile per batch..." 289 | ) 290 | batch_id_per_cell = pd.Categorical(raw_count.obs[batch_key]).codes 291 | ambient_profile = np.empty((len(np.unique(batch_id_per_cell)),raw_count.shape[1])) 292 | for batch_id in np.unique(batch_id_per_cell): 293 | subset = raw_count[batch_id_per_cell==batch_id] 294 | ambient_profile[batch_id, :] = subset.X.sum(axis=0) / subset.X.sum() 295 | 296 | # add a mapper to locate the batch id 297 | self.batch_id = batch_id_per_cell 298 | self.n_batch = len(np.unique(batch_id_per_cell)) 299 | else: 300 | # get ambient profile from AnnData.uns 301 | if "ambient_profile_all" in raw_count.uns: 302 | self.logger.info( 303 | "Found ambient profile in AnnData.uns['ambient_profile_all']" 304 | ) 305 | ambient_profile = raw_count.uns["ambient_profile_all"] 306 | else: 307 | self.logger.info( 308 | "Ambient profile not found in AnnData.uns['ambient_profile'], estimating it by averaging pooled cells..." 
309 | ) 310 | 311 | elif isinstance(raw_count, str): 312 | # read pickle file into dataframe 313 | raw_count = pd.read_pickle(raw_count) 314 | 315 | elif isinstance(raw_count, np.ndarray): 316 | # convert np.array to pd.DataFrame 317 | raw_count = pd.DataFrame( 318 | raw_count, 319 | index=range(raw_count.shape[0]), 320 | columns=range(raw_count.shape[1]), 321 | ) 322 | 323 | elif isinstance(raw_count, pd.DataFrame): 324 | pass 325 | else: 326 | raise TypeError( 327 | f"Expecting str or np.array or pd.DataFrame or AnnData object, but get a {type(raw_count)}" 328 | ) 329 | 330 | self.raw_count = raw_count 331 | """raw_count : np.ndarray, raw count matrix. 332 | """ 333 | self.n_features = raw_count.shape[1] 334 | """int, number of features. 335 | """ 336 | self.cell_id = raw_count.index.to_list() if isinstance(raw_count, pd.DataFrame) else raw_count.obs_names.to_list() 337 | """list, cell id. 338 | """ 339 | self.feature_names = raw_count.columns.to_list() if isinstance(raw_count, pd.DataFrame) else raw_count.var_names.to_list() 340 | """list, feature names. 341 | """ 342 | 343 | if isinstance(ambient_profile, str): 344 | ambient_profile = pd.read_pickle(ambient_profile) 345 | ambient_profile = ambient_profile.fillna(0).values # missing vals -> zeros 346 | elif isinstance(ambient_profile, pd.DataFrame): 347 | ambient_profile = ambient_profile.fillna(0).values # missing vals -> zeros 348 | elif isinstance(ambient_profile, np.ndarray): 349 | ambient_profile = np.nan_to_num(ambient_profile) # missing vals -> zeros 350 | elif not ambient_profile: 351 | self.logger.info(" Evaluate ambient profile from cells") 352 | if isinstance(raw_count, pd.DataFrame): 353 | ambient_profile = raw_count.sum() / raw_count.sum().sum() 354 | ambient_profile = ambient_profile.fillna(0).values 355 | elif isinstance(raw_count, ad.AnnData): 356 | ambient_profile = np.array(raw_count.X.sum(axis=0)/raw_count.X.sum()) 357 | ambient_profile = np.nan_to_num(ambient_profile).flatten() 358 | else: 359 | raise TypeError( 360 | f"Expecting str / np.array / None / pd.DataFrame, but get a {type(ambient_profile)}" 361 | ) 362 | 363 | if ambient_profile.squeeze().ndim == 1: 364 | ambient_profile = ( 365 | ambient_profile.squeeze() 366 | .reshape(1, -1) 367 | ) 368 | # add a mapper to locate the artificial batch id 369 | self.batch_id = np.zeros(raw_count.shape[0], dtype=int)#.reshape(-1, 1) 370 | self.n_batch = 1 371 | 372 | self.ambient_profile = ambient_profile 373 | """ambient_profile : np.ndarray, the probability of occurrence of each ambient transcript. 374 | """ 375 | 376 | self.runtime = None 377 | """int, runtime in seconds. 378 | """ 379 | self.loss_values = None 380 | """list, loss values during training. 381 | """ 382 | self.trained_model = None 383 | """nn.Module object, added after training. 
384 | """
385 | self.native_counts = None
386 | """np.ndarray, denoised counts, added after inference
387 | """
388 | self.bayesfactor = None
389 | """np.ndarray, Bayes factor of whether native signals are present, added after inference
390 | """
391 | self.native_frequencies = None
392 | """np.ndarray, probability of native transcripts (normalized denoised counts), added after inference
393 | """
394 | self.noise_ratio = None
395 | """np.ndarray, noise ratio per cell, added after inference
396 | """
397 | self.feature_assignment = None
398 | """pd.DataFrame, assignment of sgRNA or tag or other feature barcodes, added after inference or assignment
399 | """
400 | 
401 | def train(
402 | self,
403 | batch_size: int = 64,
404 | train_size: float = 0.998,
405 | shuffle: bool = True,
406 | kld_weight: float = 1e-5,
407 | lr: float = 1e-3,
408 | lr_step_size: int = 5,
409 | lr_gamma: float = 0.97,
410 | epochs: int = 400,
411 | reconstruction_weight: float = 1,
412 | dropout_prob: float = 0,
413 | save_model: bool = False,
414 | verbose: bool = True,
415 | ):
416 | """Train the scar model
417 | 
418 | Parameters
419 | ----------
420 | batch_size : int, optional
421 | batch size, by default 64
422 | train_size : float, optional
423 | the size of training samples, by default 0.998
424 | shuffle : bool, optional
425 | whether to shuffle the data, by default True
426 | kld_weight : float, optional
427 | weight of the KL loss, by default 1e-5
428 | lr : float, optional
429 | initial learning rate, by default 1e-3
430 | lr_step_size : int, optional
431 | period of learning rate decay \
432 | (cf. ``torch.optim.lr_scheduler.StepLR``), \
433 | by default 5
434 | lr_gamma : float, optional
435 | multiplicative factor of learning rate decay, by default 0.97
436 | epochs : int, optional
437 | training iterations, by default 400
438 | reconstruction_weight : float, optional
439 | weight on the reconstruction error, by default 1
440 | dropout_prob : float, optional
441 | dropout probability of neurons, by default 0
442 | save_model : bool, optional
443 | whether to save the trained model (under development), by default False
444 | verbose : bool, optional
445 | whether to print the details, by default True
446 | Returns
447 | -------
448 | After training, a trained_model attribute will be added.
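
        Examples
        --------
        A hedged sketch (assuming ``scarObj`` was initialized as in the class
        example above; the epoch number here is only illustrative):

        >>> scarObj.train(epochs=100, batch_size=64, verbose=False)
        >>> scarObj.loss_values[-1]  # average total loss of the final epoch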
449 | 450 | """ 451 | # Generators 452 | total_dataset = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, cache_capacity=self.cache_capacity) 453 | training_set, validation_set = random_split(total_dataset, [train_size, 1 - train_size]) 454 | training_generator = DataLoader( 455 | training_set, batch_size=batch_size, shuffle=shuffle, 456 | drop_last=True 457 | ) 458 | self.dataset = total_dataset 459 | 460 | loss_values = [] 461 | 462 | # Define model 463 | vae_nets = VAE( 464 | n_features=self.n_features, 465 | nn_layer1=self.nn_layer1, 466 | nn_layer2=self.nn_layer2, 467 | latent_dim=self.latent_dim, 468 | dropout_prob=dropout_prob, 469 | feature_type=self.feature_type, 470 | count_model=self.count_model, 471 | sparsity=self.sparsity, 472 | n_batch=self.n_batch, 473 | verbose=verbose, 474 | ).to(self.device) 475 | # Define optimizer 476 | optim = torch.optim.Adam(vae_nets.parameters(), lr=lr) 477 | scheduler = torch.optim.lr_scheduler.StepLR( 478 | optim, step_size=lr_step_size, gamma=lr_gamma 479 | ) 480 | 481 | self.logger.info(f"kld_weight: {kld_weight:.2e}") 482 | self.logger.info(f"learning rate: {lr:.2e}") 483 | self.logger.info(f"lr_step_size: {lr_step_size:d}") 484 | self.logger.info(f"lr_gamma: {lr_gamma:.2f}") 485 | 486 | # Run training 487 | training_start_time = time.time() 488 | with std_out_err_redirect_tqdm() as orig_stdout: 489 | # Initialize progress bar 490 | progress_bar = tqdm( 491 | total=epochs, 492 | file=orig_stdout, 493 | dynamic_ncols=True, 494 | desc="Training", 495 | ) 496 | progress_bar.clear() 497 | for _ in range(epochs): 498 | train_tot_loss = 0 499 | train_kld_loss = 0 500 | train_recon_loss = 0 501 | 502 | vae_nets.train() 503 | for x_batch, ambient_freq, batch_id_onehot in training_generator: 504 | optim.zero_grad() 505 | dec_nr, dec_prob, means, var, dec_dp = vae_nets(x_batch, batch_id_onehot) 506 | recon_loss_minibatch, kld_loss_minibatch, loss_minibatch = loss_fn( 507 | x_batch, 508 | dec_nr, 509 | dec_prob, 510 | means, 511 | var, 512 | ambient_freq, 513 | reconstruction_weight=reconstruction_weight, 514 | kld_weight=kld_weight, 515 | dec_dp=dec_dp, 516 | count_model=self.count_model, 517 | ) 518 | loss_minibatch.backward() 519 | optim.step() 520 | 521 | train_tot_loss += loss_minibatch.detach().item() 522 | train_recon_loss += recon_loss_minibatch.detach().item() 523 | train_kld_loss += kld_loss_minibatch.detach().item() 524 | 525 | scheduler.step() 526 | 527 | avg_train_tot_loss = train_tot_loss / len(training_generator) 528 | loss_values.append(avg_train_tot_loss) 529 | 530 | progress_bar.set_postfix({"Loss": "{:.4e}".format(avg_train_tot_loss)}) 531 | progress_bar.update() 532 | 533 | progress_bar.close() 534 | 535 | if save_model: 536 | torch.save(vae_nets, save_model) 537 | 538 | self.loss_values = loss_values 539 | self.trained_model = vae_nets 540 | self.runtime = time.time() - training_start_time 541 | 542 | # Inference 543 | @torch.no_grad() 544 | def inference( 545 | self, 546 | batch_size=4096, 547 | count_model_inf="poisson", 548 | adjust="micro", 549 | cutoff=3, 550 | round_to_int="stochastic_rounding", 551 | clip_to_obs=False, 552 | get_native_frequencies=False, 553 | moi=None, 554 | ): 555 | """inference infering the expected native signals, noise ratios, Bayesfactors and expected native frequencies 556 | 557 | Parameters 558 | ---------- 559 | batch_size : int, optional 560 | batch size, set a small value upon GPU memory issue, by default 4096 561 | count_model_inf : str, optional 562 | inference 
model for evaluation of ambient presence, by default "poisson" 563 | adjust : str, optional 564 | Only used for calculating Bayesfactors to improve performance. \ 565 | One of the following: 566 | 567 | | 'micro' -- adjust the estimated native counts per cell. \ 568 | This can overcome the issue of over- or under-estimation of noise. 569 | | 'global' -- adjust the estimated native counts globally. \ 570 | This can overcome the issue of over- or under-estimation of noise. 571 | | False -- no adjustment, use the model-returned native counts. 572 | | Defaults to "micro" 573 | cutoff : int, optional 574 | cutoff for Bayesfactors, by default 3 575 | round_to_int : str, optional 576 | whether to round the counts, by default "stochastic_rounding" 577 | 578 | .. versionadded:: 0.4.1 579 | 580 | clip_to_obs : bool, optional 581 | whether to clip the predicted native counts to the observation in order to ensure \ 582 | that denoised counts are not greater than the observation, by default False. \ 583 | Use it with caution, as it may lead to over-estimation of overall noise. 584 | 585 | .. versionadded:: 0.5.0 586 | 587 | get_native_frequencies : bool, optional 588 | whether to get native frequencies, by default False 589 | 590 | .. versionadded:: 0.7.0 591 | 592 | moi : int, optional (under development) 593 | multiplicity of infection. If assigned, it will allow optimized thresholding, \ 594 | which tests a series of cutoffs to find the best one \ 595 | based on distributions of infections under given moi.\ 596 | See Perturb-seq [Dixit2016]_ for details, by default None 597 | Returns 598 | ------- 599 | After inferring, several attributes will be added, inc. native_counts, bayesfactor,\ 600 | native_frequencies, and noise_ratio. \ 601 | A feature_assignment will be added in 'sgRNA' or 'tag' or 'CMO' feature type. 
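
        Examples
        --------
        A hedged sketch (assuming a trained model object ``scarObj``, as in
        the class example):

        >>> scarObj.inference(batch_size=4096)
        >>> denoised = scarObj.native_counts.toarray()  # csr matrix -> dense
        >>> per_cell_noise = scarObj.noise_ratio.toarray().flatten()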
602 | """ 603 | n_features = self.n_features 604 | sample_size = self.raw_count.shape[0] 605 | 606 | dt = np.int64 if round_to_int=="stochastic_rounding" else np.float32 607 | native_counts = sparse.lil_matrix((sample_size, n_features), dtype=dt) 608 | noise_ratio = sparse.lil_matrix((sample_size, 1), dtype=np.float32) 609 | 610 | native_frequencies = sparse.lil_matrix((sample_size, n_features), dtype=np.float32) if get_native_frequencies else None 611 | 612 | if self.feature_type.lower() in [ 613 | "sgrna", 614 | "sgrnas", 615 | "tag", 616 | "tags", 617 | "cmo", 618 | "cmos", 619 | "atac", 620 | ]: 621 | bayesfactor = sparse.lil_matrix((sample_size, n_features), dtype=np.float32) 622 | else: 623 | bayesfactor = None 624 | 625 | if not batch_size: 626 | batch_size = sample_size 627 | i = 0 628 | generator_full_data = DataLoader( 629 | self.dataset, batch_size=batch_size, shuffle=False 630 | ) 631 | 632 | for x_batch_tot, ambient_freq_tot, x_batch_id_onehot_tot in generator_full_data: 633 | minibatch_size = x_batch_tot.shape[ 634 | 0 635 | ] # if not the last batch, equals to batch size 636 | 637 | ( 638 | native_counts_batch, 639 | bayesfactor_batch, 640 | native_frequencies_batch, 641 | noise_ratio_batch, 642 | ) = self.trained_model.inference( 643 | x_batch_tot, 644 | x_batch_id_onehot_tot, 645 | ambient_freq_tot[0, :], 646 | count_model_inf=count_model_inf, 647 | adjust=adjust, 648 | round_to_int=round_to_int, 649 | clip_to_obs=clip_to_obs, 650 | ) 651 | native_counts[ 652 | i * batch_size : i * batch_size + minibatch_size, : 653 | ] = native_counts_batch 654 | noise_ratio[ 655 | i * batch_size : i * batch_size + minibatch_size, : 656 | ] = noise_ratio_batch 657 | if native_frequencies is not None: 658 | native_frequencies[ 659 | i * batch_size : i * batch_size + minibatch_size, : 660 | ] = native_frequencies_batch 661 | if bayesfactor is not None: 662 | bayesfactor[ 663 | i * batch_size : i * batch_size + minibatch_size, : 664 | ] = bayesfactor_batch 665 | 666 | i += 1 667 | 668 | self.native_counts = native_counts.tocsr() 669 | self.noise_ratio = noise_ratio.tocsr() 670 | self.bayesfactor = bayesfactor.tocsr() if bayesfactor is not None else None 671 | self.native_frequencies = native_frequencies.tocsr() if native_frequencies is not None else None 672 | 673 | if self.feature_type.lower() in [ 674 | "sgrna", 675 | "sgrnas", 676 | "tag", 677 | "tags", 678 | "cmo", 679 | "cmos", 680 | "atac", 681 | ]: 682 | self.assignment(cutoff=cutoff, moi=moi) 683 | else: 684 | self.feature_assignment = None 685 | 686 | def assignment(self, cutoff=3, moi=None): 687 | """assignment assignment of feature barcodes. Re-run it can test different cutoffs for your experiments. 688 | 689 | Parameters 690 | ---------- 691 | cutoff : int, optional 692 | cutoff for Bayesfactors, by default 3 693 | moi : float, optional 694 | multiplicity of infection. (under development)\ 695 | If assigned, it will allow optimized thresholding,\ 696 | which tests a series of cutoffs to find the best one \ 697 | based on distributions of infections under given moi.\ 698 | See Perturb-seq [Dixit2016]_, by default None 699 | Returns 700 | ------- 701 | After running, a attribute 'feature_assignment' will be added,\ 702 | in 'sgRNA' or 'tag' or 'CMO' feature type. 
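
        A hedged re-run sketch (the object and the cutoff value are
        illustrative):

        >>> scarObj.assignment(cutoff=5)  # stricter than the default of 3
        >>> scarObj.feature_assignment.head()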
703 | Raises 704 | ------ 705 | NotImplementedError 706 | if moi is not None 707 | """ 708 | 709 | feature_assignment = pd.DataFrame( 710 | index=self.cell_id, columns=[self.feature_type, f"n_{self.feature_type}"] 711 | ) 712 | bayesfactor_df = pd.DataFrame( 713 | self.bayesfactor.toarray(), index=self.cell_id, columns=self.feature_names 714 | ) 715 | bayesfactor_df[bayesfactor_df < cutoff] = 0 # Apply the cutoff for Bayesfactors 716 | 717 | for cell, row in bayesfactor_df.iterrows(): 718 | bayesfactor_max = row[row == row.max()] 719 | if row.max() == 0: 720 | feature_assignment.loc[cell, f"n_{self.feature_type}"] = 0 721 | feature_assignment.loc[cell, self.feature_type] = "" 722 | elif len(bayesfactor_max) == 1: 723 | feature_assignment.loc[cell, f"n_{self.feature_type}"] = 1 724 | feature_assignment.loc[cell, self.feature_type] = bayesfactor_max.index[ 725 | 0 726 | ] 727 | else: 728 | feature_assignment.loc[cell, f"n_{self.feature_type}"] = len( 729 | bayesfactor_max 730 | ) 731 | feature_assignment.loc[cell, self.feature_type] = (", ").join( 732 | bayesfactor_max.index.astype(str) 733 | ) 734 | 735 | self.feature_assignment = feature_assignment 736 | 737 | if moi: 738 | raise NotImplementedError 739 | 740 | class UMIDataset(Dataset): 741 | """Characterizes dataset for PyTorch""" 742 | 743 | def __init__(self, raw_count, ambient_profile, batch_id, device, cache_capacity=20000): 744 | """Initialization""" 745 | 746 | self.raw_count = torch.from_numpy(raw_count.fillna(0).values).int() if isinstance(raw_count, pd.DataFrame) else raw_count 747 | self.ambient_profile = torch.from_numpy(ambient_profile).float().to(device) 748 | self.batch_id = torch.from_numpy(batch_id).to(torch.int64).to(device) 749 | self.batch_onehot = torch.from_numpy(np.eye(len(np.unique(batch_id)))).to(torch.int64).to(device) 750 | self.device = device 751 | self.cache_capacity = cache_capacity 752 | 753 | # Cache data 754 | self.cache = {} 755 | 756 | def __len__(self): 757 | """Denotes the total number of samples""" 758 | return self.raw_count.shape[0] 759 | 760 | def __getitem__(self, index): 761 | """Generates one sample of data""" 762 | 763 | if index in self.cache: 764 | return self.cache[index] 765 | else: 766 | # Select samples 767 | sc_count = self.raw_count[index].to(self.device) if isinstance(self.raw_count, torch.Tensor) else torch.from_numpy(self.raw_count[index].X.toarray().flatten()).int().to(self.device) 768 | sc_ambient = self.ambient_profile[self.batch_id[index], :] 769 | sc_batch_id_onehot = self.batch_onehot[self.batch_id[index], :] 770 | 771 | # Cache samples 772 | if len(self.cache) <= self.cache_capacity: 773 | self.cache[index] = (sc_count, sc_ambient, sc_batch_id_onehot) 774 | 775 | return sc_count, sc_ambient, sc_batch_id_onehot 776 | -------------------------------------------------------------------------------- /scar/main/_setup.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import numpy as np 3 | import pandas as pd 4 | import matplotlib.pyplot as plt 5 | import seaborn as sns 6 | from anndata import AnnData 7 | import torch 8 | from torch.distributions.multinomial import Multinomial 9 | 10 | from ._utils import get_logger 11 | 12 | 13 | def setup_anndata( 14 | adata: AnnData, 15 | raw_adata: AnnData, 16 | feature_type: Union[str, list] = None, 17 | prob: float = 0.995, 18 | min_raw_counts: int = 2, 19 | iterations: int = 3, 20 | n_batch: int = None, 21 | sample: int = None, 22 | kneeplot: bool = True, 23 | verbose: bool 
= True,
24 | figsize: tuple = (6, 6),
25 | ):
26 | """Calculate the ambient profile for relevant features
27 | 
28 | Identify the cell-free droplets through a multinomial distribution. See EmptyDrops [Lun2019]_ for details.
29 | 
30 | 
31 | Parameters
32 | ----------
33 | adata : AnnData
34 | A filtered adata object, loaded from filtered_feature_bc_matrix using `scanpy.read`; gene filtering is recommended to save memory
35 | raw_adata : AnnData
36 | A raw adata object, loaded from raw_feature_bc_matrix using `scanpy.read`
37 | feature_type : Union[str, list], optional
38 | Feature type, e.g. 'Gene Expression', 'Antibody Capture', 'CRISPR Guide Capture' or 'Multiplexing Capture'; all feature types are calculated if None, by default None
39 | prob : float, optional
40 | The per-gene probability threshold: a droplet is considered to contain only ambient RNA if its joint probability (the product over all genes) is greater than prob raised to the number of genes, by default 0.995
41 | min_raw_counts : int, optional
42 | Total counts filter for raw_adata, filtering out low counts to save memory, by default 2
43 | iterations : int, optional
44 | Total iterations, by default 3
45 | n_batch : int, optional
46 | Total number of batches, set it to a bigger number when an out-of-memory issue occurs, by default None
47 | sample : int, optional
48 | Randomly sample this many droplets to test; if greater than the total number of droplets, all droplets are used. By default None (use all droplets)
49 | kneeplot : bool, optional
50 | Kneeplot to show subpopulations of droplets, by default True
51 | verbose : bool, optional
52 | Whether to display messages, by default True
53 | figsize : tuple, optional
54 | Figure size, by default (6, 6)
55 | 
56 | Returns
57 | -------
58 | The relevant ambient profile is added in `adata.uns`
59 | 
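    Notes
    -----
    A conceptual sketch of the test implemented below (names follow the
    function body): each droplet's count vector is scored under the current
    ambient estimate with
    ``Multinomial(probs=torch.tensor(ambient_prof), validate_args=False).log_prob(torch.Tensor(count_batch))``,
    and a droplet is labelled cell-free when this joint log probability is at
    least ``np.log(prob) * n_features``. The ambient profile is then
    re-estimated from the cell-free droplets and the test is repeated for
    ``iterations`` rounds.
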
63 |         :context: close-figs
64 | 
65 |         import scanpy as sc
66 |         from scar import setup_anndata
67 |         # read filtered data
68 |         adata = sc.read_10x_h5(filename='500_hgmm_3p_LT_Chromium_Controller_filtered_feature_bc_matrix.h5',
69 |                                backup_url='https://cf.10xgenomics.com/samples/cell-exp/6.1.0/500_hgmm_3p_LT_Chromium_Controller/500_hgmm_3p_LT_Chromium_Controller_filtered_feature_bc_matrix.h5');
70 |         adata.var_names_make_unique();
71 |         # read raw data
72 |         adata_raw = sc.read_10x_h5(filename='500_hgmm_3p_LT_Chromium_Controller_raw_feature_bc_matrix.h5',
73 |                                    backup_url='https://cf.10xgenomics.com/samples/cell-exp/6.1.0/500_hgmm_3p_LT_Chromium_Controller/500_hgmm_3p_LT_Chromium_Controller_raw_feature_bc_matrix.h5');
74 |         adata_raw.var_names_make_unique();
75 |         # gene and cell filtering
76 |         sc.pp.filter_genes(adata, min_counts=200);
77 |         sc.pp.filter_genes(adata, max_counts=6000);
78 |         sc.pp.filter_cells(adata, min_genes=200);
79 |         # setup anndata
80 |         setup_anndata(
81 |             adata,
82 |             adata_raw,
83 |             feature_type = "Gene Expression",
84 |             prob = 0.975,
85 |             min_raw_counts = 2,
86 |             kneeplot = True,
87 |         )
88 |     """
89 | 
90 |     setup_logger = get_logger("setup_anndata", verbose=verbose)
91 | 
92 |     if feature_type is None:
93 |         feature_type = adata.var["feature_types"].unique()
94 |     elif isinstance(feature_type, str):
95 |         feature_type = [feature_type]
96 | 
97 |     # subset to shared genes to save memory
98 |     # raw_adata._inplace_subset_var(raw_adata.var_names.isin(adata.var_names))
99 |     # raw_adata._inplace_subset_obs(raw_adata.X.sum(axis=1) >= min_raw_counts)
100 |     raw_adata = raw_adata[:, raw_adata.var_names.isin(adata.var_names)]
101 |     raw_adata = raw_adata[raw_adata.X.sum(axis=1) >= min_raw_counts]
102 | 
103 |     raw_adata.obs["total_counts"] = raw_adata.X.sum(axis=1)
104 | 
105 |     if sample is not None:
106 |         sample = int(sample)
107 |         setup_logger.info(
108 |             f"Randomly sample {sample:d} droplets from {raw_adata.shape[0]:d} droplets."
109 |         )
110 |     else:
111 |         sample = raw_adata.shape[0]
112 |         setup_logger.info(f"Use all {sample:d} droplets.")
113 | 
114 |     # check n_batch
115 |     if n_batch is None:
116 |         n_batch = int(np.ceil(sample / 5000))
117 |     else:
118 |         n_batch = int(n_batch)
119 | 
120 |     # check whether each batch contains too many droplets
121 |     if sample / n_batch > 5000:
122 |         setup_logger.info(
123 |             "The number of droplets per batch is too large; this may cause a memory issue. Please increase the number of batches."
124 |         )
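    # Iterative, EmptyDrops-style estimation (the loop below): score every
    # droplet under a multinomial given by the current ambient profile,
    # flag droplets whose counts are well explained by that profile
    # (joint log-probability above n_genes * log(prob)) as cell-free,
    # then re-estimate the profile from the flagged droplets and repeat.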
125 | 
126 |     idx = np.random.choice(
127 |         raw_adata.shape[0], size=min(raw_adata.shape[0], sample), replace=False
128 |     )
129 |     raw_adata = raw_adata[idx]
130 | 
131 |     # initial estimation of the ambient profile; updated iteratively below
132 |     ambient_prof = raw_adata.X.sum(axis=0) / raw_adata.X.sum()
133 | 
134 |     setup_logger.info(f"Estimating ambient profile for {feature_type}...")
135 | 
136 |     i = 0
137 |     while i < iterations:
138 |         # calculate the joint (log) probability of being a cell-free droplet, for each droplet
139 |         log_prob = []
140 |         batch_idx = np.floor(
141 |             np.array(range(raw_adata.shape[0])) / raw_adata.shape[0] * n_batch
142 |         )
143 |         for b in range(n_batch):
144 |             try:
145 |                 count_batch = raw_adata[batch_idx == b].X.astype(int).toarray()
146 |             except MemoryError:
147 |                 raise MemoryError("Use more batches by setting a higher n_batch")
148 |             log_prob_batch = Multinomial(
149 |                 probs=torch.tensor(ambient_prof), validate_args=False
150 |             ).log_prob(torch.Tensor(count_batch))
151 |             log_prob.append(log_prob_batch)
152 | 
153 |         log_prob = np.concatenate(log_prob, axis=0)
154 |         raw_adata.obs["log_prob"] = log_prob
155 |         raw_adata.obs["droplets"] = "other droplets"
156 | 
157 |         # cell-containing droplets
158 |         raw_adata.obs.loc[
159 |             raw_adata.obs_names.isin(adata.obs_names), "droplets"
160 |         ] = "cells"
161 | 
162 |         # identify cell-free droplets
163 |         raw_adata.obs["droplets"] = raw_adata.obs["droplets"].mask(
164 |             raw_adata.obs["log_prob"] >= np.log(prob) * raw_adata.shape[1],
165 |             "cell-free droplets",
166 |         )
167 |         emptydrops = raw_adata[raw_adata.obs["droplets"] == "cell-free droplets"]
168 | 
169 |         if emptydrops.shape[0] < 50:
170 |             raise Exception("Too few empty droplets! Lower the prob parameter")
171 | 
172 |         ambient_prof = emptydrops.X.sum(axis=0) / emptydrops.X.sum()
173 | 
174 |         i += 1
175 | 
176 |         setup_logger.info(f"Iteration: {i:d}")
177 | 
178 |     # update the ambient profile for each feature type
179 |     for ft in feature_type:
180 |         tmp = emptydrops[:, emptydrops.var["feature_types"] == ft]
181 |         adata.uns[f"ambient_profile_{ft}"] = pd.DataFrame(
182 |             tmp.X.sum(axis=0).reshape(-1, 1) / tmp.X.sum(),
183 |             index=tmp.var_names,
184 |             columns=[f"ambient_profile_{ft}"],
185 |         )
186 | 
187 |         setup_logger.info(f"Estimated ambient profile for {ft} saved in adata.uns")
188 | 
189 |     # update the ambient profile for all feature types
190 |     adata.uns["ambient_profile_all"] = pd.DataFrame(
191 |         emptydrops.X.sum(axis=0).reshape(-1, 1) / emptydrops.X.sum(),
192 |         index=emptydrops.var_names,
193 |         columns=["ambient_profile_all"],
194 |     )
195 | 
196 |     setup_logger.info("Estimated ambient profile for all features saved in adata.uns")
197 | 
198 |     if kneeplot:
199 |         _, axs = plt.subplots(2, figsize=figsize)
200 | 
201 |         all_droplets = raw_adata.obs.copy()
202 |         all_droplets = (
203 |             all_droplets.sort_values(by="total_counts", ascending=False)
204 |             .reset_index()
205 |             .rename_axis("rank_by_counts")
206 |             .reset_index()
207 |         )
208 |         all_droplets = all_droplets.loc[all_droplets["total_counts"] >= min_raw_counts]
209 |         all_droplets = all_droplets.set_index("index").rename_axis("cells")
210 |         all_droplets = (
211 |             all_droplets.sort_values(by="log_prob", ascending=True)
212 |             .reset_index()
213 |             .rename_axis("rank_by_log_prob")
214 |             .reset_index()
215 |             .set_index("cells")
216 |         )
217 | 
218 |         ax = sns.lineplot(
219 |             data=all_droplets,
220 |             x="rank_by_counts",
221 |             y="total_counts",
222 |             hue="droplets",
223 |             hue_order=["cells", "other droplets", "cell-free droplets"],
224 |             palette=sns.color_palette()[-3:],
225 |             markers=False,
226 |             lw=2,
227 |             ci=None,
228 |             ax=axs[0],
229 |         )
230 | 
231 |         ax.set_xscale("log")
232 |         ax.set_yscale("log")
233 |         ax.set_xlabel("")
234 |         ax.set_title("cell-free droplets have lower counts")
235 | 
236 |         all_droplets["prob"] = np.exp(all_droplets["log_prob"])
237 |         ax = sns.lineplot(
238 |             data=all_droplets,
239 |             x="rank_by_log_prob",
240 |             y="prob",
241 |             hue="droplets",
242 |             hue_order=["cells", "other droplets", "cell-free droplets"],
243 |             palette=sns.color_palette()[-3:],
244 |             markers=False,
245 |             lw=2,
246 |             ci=None,
247 |             ax=axs[1],
248 |         )
249 |         ax.set_xscale("log")
250 |         ax.set_xlabel("sorted droplets")
251 |         ax.set_title("cell-free droplets have relatively higher probs")
252 | 
253 |         plt.tight_layout()
254 | --------------------------------------------------------------------------------
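The profiles written to `adata.uns` above are what the scar `model` consumes downstream. As a bridge between `setup_anndata` and the model API exercised in `scar/test/test_scar.py` further below, here is a minimal, hypothetical end-to-end sketch: the 10x file names are placeholders, epochs and batch_size are illustrative rather than recommended values, and the `model`/`train`/`inference` calls mirror the tests.

```python
import scanpy as sc
from scar import setup_anndata, model

# Placeholder file names - substitute your own 10x outputs
adata = sc.read_10x_h5("filtered_feature_bc_matrix.h5")
raw_adata = sc.read_10x_h5("raw_feature_bc_matrix.h5")
adata.var_names_make_unique()
raw_adata.var_names_make_unique()

# Estimate the ambient profile; the result lands in
# adata.uns["ambient_profile_Gene Expression"]
setup_anndata(adata, raw_adata, feature_type="Gene Expression")

# Denoise with scAR, reusing the estimated profile
# (usage mirrors scar/test/test_scar.py; epochs/batch_size are illustrative)
scar_model = model(
    raw_count=adata.X.toarray(),
    ambient_profile=adata.uns["ambient_profile_Gene Expression"],
    feature_type="mRNA",
)
scar_model.train(epochs=100, batch_size=64)
scar_model.inference()
denoised_counts = scar_model.native_counts  # denoised count matrix
```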
/scar/main/_utils.py: --------------------------------------------------------------------------------
1 | import logging
2 | 
3 | # Create a logger for the given name; INFO level when verbose, WARNING otherwise
4 | def get_logger(name, verbose=True):
5 |     logger = logging.getLogger(name)
6 |     logger.handlers.clear()
7 |     logger.setLevel(logging.INFO if verbose else logging.WARNING)
8 |     formatter = logging.Formatter(
9 |         fmt="%(asctime)s|%(levelname)s|%(name)s|%(message)s",
10 |         datefmt="%Y-%m-%d %H:%M:%S",
11 |     )
12 |     handler = logging.StreamHandler()
13 |     handler.setFormatter(formatter)
14 |     logger.addHandler(handler)
15 |     return logger
16 | --------------------------------------------------------------------------------
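A quick usage sketch for this helper (illustrative only; the logger names below are arbitrary, and the same pattern is used by `setup_anndata` above and `VAE` below):

```python
from scar.main._utils import get_logger

chatty = get_logger("demo", verbose=True)
chatty.info("shown at INFO level")

quiet = get_logger("demo_quiet", verbose=False)
quiet.info("suppressed: this logger is at WARNING level")
quiet.warning("still shown")
```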
/scar/main/_vae.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | The variational autoencoder
4 | """
5 | 
6 | import numpy as np
7 | from scipy import stats
8 | 
9 | import torch
10 | from torch import nn
11 | 
12 | from ._activation_functions import mytanh, hnormalization, mysoftplus
13 | from ._utils import get_logger
14 | 
15 | #########################################################################
16 | ## Variational autoencoder
17 | #########################################################################
18 | 
19 | 
20 | class VAE(nn.Module):
21 |     """A variational autoencoder class
22 | 
23 |     Parameters
24 |     ----------
25 |     n_features : int
26 |         number of features (e.g. mRNA, sgRNA, ADT, tag, CMO, ...)
27 |     nn_layer1 : int, optional
28 |         number of neurons in the 1st layer, by default 150
29 |     nn_layer2 : int, optional
30 |         number of neurons in the 2nd layer, by default 100
31 |     latent_dim : int, optional
32 |         number of neurons in the bottleneck layer, by default 15
33 |     dropout_prob : float, optional
34 |         dropout probability, by default 0
35 |     feature_type : str, optional
36 |         the feature to be denoised, one of 'mRNA', 'sgRNA', 'ADT', 'tag', 'CMO', 'ATAC', by default "mRNA"
37 |     count_model : str, optional
38 |         the model to generate the UMI count, one of "binomial", "poisson", "zeroinflatedpoisson", by default "binomial"
39 |     sparsity : float, optional
40 |         the sparsity of the expected data, by default 0.9
41 |     verbose : bool, optional
42 |         whether to display information, by default True
43 |     """
44 | 
45 |     def __init__(
46 |         self,
47 |         n_features,
48 |         nn_layer1=150,
49 |         nn_layer2=100,
50 |         latent_dim=15,
51 |         dropout_prob=0,
52 |         feature_type="mRNA",
53 |         count_model="binomial",
54 |         n_batch=1,
55 |         sparsity=0.9,
56 |         verbose=True,
57 |     ):
58 |         super().__init__()
59 |         assert feature_type.lower() in [
60 |             "mrna",
61 |             "mrnas",
62 |             "sgrna",
63 |             "sgrnas",
64 |             "adt",
65 |             "adts",
66 |             "tag",
67 |             "tags",
68 |             "cmo",
69 |             "cmos",
70 |             "atac",
71 |         ]
72 |         assert count_model.lower() in ["binomial", "poisson", "zeroinflatedpoisson"]
73 |         # force the sparsity to one for the sgRNA, tag, and CMO modes
74 |         if feature_type.lower() in [
75 |             "sgrna",
76 |             "sgrnas",
77 |             "tag",
78 |             "tags",
79 |             "cmo",
80 |             "cmos",
81 |         ]:
82 |             sparsity = 1
83 | 
84 |         self.encoder = Encoder(
85 |             n_features, n_batch, nn_layer1, nn_layer2, latent_dim, dropout_prob
86 |         )
87 |         self.decoder = Decoder(
88 |             n_features,
89 |             n_batch,
90 |             nn_layer1,
91 |             nn_layer2,
92 |             latent_dim,
93 |             dropout_prob,
94 |             count_model,
95 |             sparsity,
96 |         )
97 | 
98 |         vae_logger = get_logger("VAE", verbose=verbose)
99 | 
100 |         vae_logger.info("Running VAE using the following param set:")
101 |         vae_logger.info(f"...denoised count type: {feature_type}")
102 |         vae_logger.info(f"...count model: {count_model}")
103 |         vae_logger.info(f"...num_input_feature: {n_features:d}")
104 |         vae_logger.info(f"...NN_layer1: {nn_layer1:d}")
105 |         vae_logger.info(f"...NN_layer2: {nn_layer2:d}")
106 |         vae_logger.info(f"...latent_space: {latent_dim:d}")
107 |         vae_logger.info(f"...dropout_prob: {dropout_prob:.2f}")
108 |         vae_logger.info(f"...expected data sparsity: {sparsity:.2f}")
109 | 
110 |     def forward(self, input_matrix, batch_id_onehot=None):
111 |         """forward function"""
112 |         sampling, means, var = self.encoder(input_matrix, batch_id_onehot)
113 |         dec_nr, dec_prob, dec_dp = self.decoder(sampling, batch_id_onehot)
114 |         return dec_nr, dec_prob, means, var, dec_dp
115 | 
116 |     @torch.no_grad()
117 |     def inference(
118 |         self,
119 |         input_matrix,
120 |         batch_id_onehot,
121 |         amb_prob,
122 |         count_model_inf="poisson",
123 |         adjust="micro",
124 |         round_to_int="stochastic_rounding",
125 |         clip_to_obs=False,
126 |     ):
127 |         """
128 |         Inference of the presence of native signals
129 |         """
130 |         assert count_model_inf.lower() in ["poisson", "binomial", "zeroinflatedpoisson"]
131 |         assert adjust in [False, "global", "micro"]
132 | 
133 |         # Estimate native signals
134 |         dec_nr, dec_prob, _, _, _ = self.forward(input_matrix, batch_id_onehot)
135 | 
136 |         # Move tensors to the CPU as NumPy arrays
137 |         input_matrix_np = input_matrix.cpu().numpy()
138 |         noise_ratio = dec_nr.cpu().numpy().reshape(-1, 1)
139 |         nat_prob = dec_prob.cpu().numpy()
140 |         amb_prob = amb_prob.cpu().numpy().reshape(1, -1)
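        # The decomposition computed next: each cell's total UMIs are split
        # into a native and an ambient fraction,
        #   expected_native_counts[i, j] = total_counts[i] * (1 - noise_ratio[i]) * nat_prob[i, j]
        #   expected_amb_counts[i, j]    = total_counts[i] * noise_ratio[i] * amb_prob[j]
        # where noise_ratio and nat_prob come from the decoder and amb_prob is
        # the fixed ambient profile. The real-valued expectations are then
        # converted to integers by stochastic rounding (floor plus a Bernoulli
        # draw on the fractional part), which preserves the expected value.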
141 | 
142 |         total_count_per_cell = input_matrix_np.sum(axis=1).reshape(-1, 1)
143 |         expected_native_counts = total_count_per_cell * (1 - noise_ratio) * nat_prob
144 |         expected_amb_counts = total_count_per_cell * noise_ratio * amb_prob
145 |         tot_amb = expected_amb_counts.sum(axis=1).reshape(-1, 1)
146 | 
147 |         if not round_to_int:
148 |             pass
149 |         elif round_to_int.lower() == "stochastic_rounding":
150 |             expected_native_counts = (
151 |                 np.floor(expected_native_counts)
152 |                 + (
153 |                     np.random.rand(*expected_native_counts.shape)
154 |                     < expected_native_counts - np.floor(expected_native_counts)
155 |                 ).astype(int)
156 |             ).astype(int)
157 | 
158 |             expected_amb_counts = (
159 |                 np.floor(expected_amb_counts)
160 |                 + (
161 |                     np.random.rand(*expected_amb_counts.shape)
162 |                     < expected_amb_counts - np.floor(expected_amb_counts)
163 |                 ).astype(int)
164 |             ).astype(int)
165 | 
166 |         if clip_to_obs:
167 |             expected_native_counts = np.clip(
168 |                 expected_native_counts,
169 |                 a_min=np.zeros_like(input_matrix_np),
170 |                 a_max=input_matrix_np,
171 |             )
172 | 
173 |         if not adjust:
174 |             adjust = 0
175 |         elif adjust == "global":
176 |             adjust = (total_count_per_cell.sum() - tot_amb.sum()) / len(
177 |                 input_matrix_np.flatten()
178 |             )
179 |         elif adjust == "micro":
180 |             adjust = (total_count_per_cell - tot_amb) / input_matrix_np.shape[1]
181 |             adjust = np.repeat(adjust, input_matrix_np.shape[1], axis=1)
182 | 
183 |         ### Calculate the Bayes factors
184 |         # The probability that the observed UMI counts do not come purely
185 |         # from the expected distribution of ambient signals.
186 |         # H1: x is drawn from a distribution (binomial, poisson, or
187 |         # zeroinflatedpoisson) with prob > amb_prob
188 |         # H2: x is drawn from a distribution (binomial, poisson, or
189 |         # zeroinflatedpoisson) with prob = amb_prob
190 | 
191 |         if count_model_inf.lower() == "binomial":
192 |             probs_h1 = stats.binom.logcdf(input_matrix_np, tot_amb + adjust, amb_prob)
193 |             probs_h2 = stats.binom.logpmf(input_matrix_np, tot_amb + adjust, amb_prob)
194 | 
195 |         elif count_model_inf.lower() == "poisson":
196 |             probs_h1 = stats.poisson.logcdf(
197 |                 input_matrix_np, expected_amb_counts + adjust
198 |             )
199 |             probs_h2 = stats.poisson.logpmf(
200 |                 input_matrix_np, expected_amb_counts + adjust
201 |             )
202 | 
203 |         elif count_model_inf.lower() == "zeroinflatedpoisson":
204 |             raise NotImplementedError
205 | 
206 |         bayesian_factor = np.clip(probs_h1 - probs_h2 + 1e-22, -709.78, 709.78)  # clip to the float64-safe range of np.exp
207 |         bayesian_factor = np.exp(bayesian_factor)
208 | 
209 |         return (
210 |             expected_native_counts,
211 |             bayesian_factor,
212 |             dec_prob.cpu().numpy(),
213 |             dec_nr.cpu().numpy(),
214 |         )
215 | 
216 | 
217 | #########################################################################
218 | ## Encoder
219 | #########################################################################
220 | class Encoder(nn.Module):
221 |     """
222 |     Encoder that takes the original expressions of feature barcodes and produces the encoding.
223 | 
224 |     Consists of 2 FC layers.
225 | """ 226 | 227 | def __init__(self, n_features, n_batch, nn_layer1, nn_layer2, latent_dim, dropout_prob): 228 | """initialization""" 229 | super().__init__() 230 | self.activation = nn.SELU() 231 | # if n_batch > 1: 232 | # n_features += n_batch 233 | self.fc1 = nn.Linear(n_features + n_batch, nn_layer1) 234 | self.bn1 = nn.BatchNorm1d(nn_layer1, momentum=0.01, eps=0.001) 235 | self.dp1 = nn.Dropout(p=dropout_prob) 236 | self.fc2 = nn.Linear(nn_layer1, nn_layer2) 237 | self.bn2 = nn.BatchNorm1d(nn_layer2, momentum=0.01, eps=0.001) 238 | self.dp2 = nn.Dropout(p=dropout_prob) 239 | 240 | self.linear_means = nn.Linear(nn_layer2, latent_dim) 241 | self.linear_log_vars = nn.Linear(nn_layer2, latent_dim) 242 | self.z_transformation = nn.Softmax(dim=-1) 243 | 244 | def reparametrize(self, means, log_vars): 245 | """reparameterization""" 246 | var = log_vars.exp() + 1e-4 247 | return torch.distributions.Normal(means, var.sqrt()).rsample(), var 248 | 249 | def forward(self, input_matrix, batch_id_onehot): 250 | """forward function""" 251 | input_matrix = (input_matrix + 1).log2() # log transformation of count data 252 | input_matrix = torch.cat([input_matrix, batch_id_onehot], 1) 253 | enc = self.fc1(input_matrix) 254 | enc = self.bn1(enc) 255 | enc = self.activation(enc) 256 | enc = self.dp1(enc) 257 | enc = self.fc2(enc) 258 | enc = self.bn2(enc) 259 | enc = self.activation(enc) 260 | enc = torch.clamp(enc, min=None, max=1e7) 261 | enc = self.dp2(enc) 262 | 263 | means = self.linear_means(enc) 264 | log_vars = self.linear_log_vars(enc) 265 | sampling, var = self.reparametrize(means, log_vars) 266 | latent_transform = self.z_transformation(sampling) 267 | 268 | return latent_transform, means, var 269 | 270 | 271 | ######################################################################### 272 | ## Decoder 273 | ######################################################################### 274 | class Decoder(nn.Module): 275 | """ 276 | A decoder model that takes the encodings and a batch (source) matrix and produces decodings. 277 | 278 | Made up of 2 FC layers. 
279 | """ 280 | 281 | def __init__( 282 | self, 283 | n_features, 284 | n_batch, 285 | nn_layer1, 286 | nn_layer2, 287 | latent_dim, 288 | dropout_prob, 289 | count_model, 290 | sparsity, 291 | ): 292 | """initialization""" 293 | super().__init__() 294 | self.activation = nn.SELU() 295 | self.normalization_native_freq = hnormalization() 296 | self.noise_activation = mytanh() 297 | self.activation_native_freq = mysoftplus(sparsity) 298 | self.fc4 = nn.Linear(latent_dim + n_batch, nn_layer2) 299 | self.bn4 = nn.BatchNorm1d(nn_layer2, momentum=0.01, eps=0.001) 300 | self.dp4 = nn.Dropout(p=dropout_prob) 301 | self.fc5 = nn.Linear(nn_layer2, nn_layer1) 302 | self.bn5 = nn.BatchNorm1d(nn_layer1, momentum=0.01, eps=0.001) 303 | self.dp5 = nn.Dropout(p=dropout_prob) 304 | 305 | self.noise_fc = nn.Linear(nn_layer1, 1) 306 | self.out_fc = nn.Linear(nn_layer1, n_features) 307 | self.count_model = count_model 308 | if count_model.lower() == "zeroinflatedpoisson": 309 | self.dropoutprob = nn.Linear(nn_layer1, 1) 310 | self.dropout_activation = mytanh() 311 | 312 | def forward(self, sampling, batch_id_onehot): 313 | """forward function""" 314 | # decoder 315 | cond_sampling = torch.cat([sampling, batch_id_onehot], 1) 316 | dec = self.fc4(cond_sampling) 317 | dec = self.bn4(dec) 318 | dec = self.activation(dec) 319 | dec = self.fc5(dec) 320 | dec = self.bn5(dec) 321 | dec = self.activation(dec) 322 | dec = torch.clamp(dec, min=None, max=1e7) 323 | 324 | # final layers to produce prob parameters 325 | dec_prob = self.out_fc(dec) 326 | dec_prob = self.activation_native_freq(dec_prob) 327 | dec_prob = self.normalization_native_freq(dec_prob) 328 | 329 | # final layers to produce noise_ratio parameters 330 | dec_nr = self.noise_fc(dec) 331 | dec_nr = self.noise_activation(dec_nr) 332 | 333 | # final layers to learn the dropout probability 334 | if self.count_model.lower() == "zeroinflatedpoisson": 335 | dec_dp = self.dropoutprob(dec) 336 | dec_dp = self.dropout_activation(dec_dp) 337 | dec_dp = torch.nan_to_num(dec_dp, nan=1e-7) 338 | else: 339 | dec_dp = None 340 | 341 | return dec_nr, dec_prob, dec_dp 342 | -------------------------------------------------------------------------------- /scar/test/ambient_profile.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/scar/test/ambient_profile.pickle -------------------------------------------------------------------------------- /scar/test/citeseq_ambient_profile.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/scar/test/citeseq_ambient_profile.pickle -------------------------------------------------------------------------------- /scar/test/citeseq_native_counts.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/scar/test/citeseq_native_counts.pickle -------------------------------------------------------------------------------- /scar/test/citeseq_raw_counts.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/scar/test/citeseq_raw_counts.pickle -------------------------------------------------------------------------------- 
/scar/test/output_assignment.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/scar/test/output_assignment.pickle -------------------------------------------------------------------------------- /scar/test/raw_counts.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Novartis/scar/df7adef31162b471358b2b442445a449f939859d/scar/test/raw_counts.pickle -------------------------------------------------------------------------------- /scar/test/test_activation_functions.py: -------------------------------------------------------------------------------- 1 | """ This tests the activation functions. """ 2 | import decimal 3 | import unittest 4 | import numpy 5 | import torch 6 | from decimal import Decimal 7 | from scar.main._activation_functions import mytanh, hnormalization, mysoftplus 8 | 9 | 10 | class ActivationFunctionsTest(unittest.TestCase): 11 | """ 12 | Test activation_functions.py functions. 13 | """ 14 | 15 | def test_mytanh(self): 16 | """ 17 | Test mytanh(). 18 | """ 19 | self.assertEqual( 20 | Decimal(mytanh()(torch.tensor(1.0, dtype=torch.float32)).item()).quantize( 21 | decimal.Decimal(".01"), rounding=decimal.ROUND_DOWN 22 | ), 23 | Decimal(0.88).quantize(decimal.Decimal(".01"), rounding=decimal.ROUND_DOWN), 24 | ) 25 | 26 | def test_hnormalization(self): 27 | """ 28 | Test hnormalization(). 29 | """ 30 | self.assertTrue( 31 | torch.allclose( 32 | hnormalization()(torch.tensor(numpy.full((20, 8), 1))).double(), 33 | torch.tensor(numpy.full((20, 8), 0.1250)), 34 | ) 35 | ) 36 | 37 | def test_mysoftplus(self): 38 | """ 39 | Test mysoftplus(). 40 | """ 41 | self.assertTrue( 42 | torch.allclose( 43 | mysoftplus()( 44 | torch.tensor(numpy.full((20, 8), 0.01), dtype=torch.float32) 45 | ).double(), 46 | torch.tensor(numpy.full((20, 8), 0.3849)), 47 | ) 48 | ) 49 | -------------------------------------------------------------------------------- /scar/test/test_scar.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from sklearn.metrics.pairwise import euclidean_distances 4 | 5 | from scar import model, data_generator 6 | import unittest 7 | 8 | 9 | class ScarIntegration(unittest.TestCase): 10 | """ 11 | Functional testing 12 | """ 13 | 14 | def test_scar(self): 15 | raw_count = pd.read_pickle("scar/test/raw_counts.pickle") 16 | ambient_profile = pd.read_pickle("scar/test/ambient_profile.pickle") 17 | expected_output = pd.read_pickle("scar/test/output_assignment.pickle") 18 | 19 | scarObj = model( 20 | raw_count=raw_count.values, 21 | ambient_profile=ambient_profile, 22 | feature_type="sgRNAs", 23 | ) 24 | 25 | scarObj.train(epochs=40, batch_size=32) 26 | 27 | scarObj.inference() 28 | 29 | self.assertTrue(scarObj.feature_assignment.equals(expected_output)) 30 | 31 | def test_scar_data_generator(self): 32 | """ 33 | Functional testing of data_generator module 34 | """ 35 | np.random.seed(8) 36 | citeseq = data_generator.citeseq(6000, 6, 50) 37 | citeseq.generate() 38 | 39 | citeseq_raw_counts = pd.read_pickle("scar/test/citeseq_raw_counts.pickle") 40 | 41 | self.assertTrue(np.array_equal(citeseq.obs_count, citeseq_raw_counts.values, equal_nan=True)) 42 | 43 | def test_scar_citeseq(self): 44 | """ 45 | Functional testing of scAR 46 | """ 47 | citeseq_raw_counts = 
pd.read_pickle("scar/test/citeseq_raw_counts.pickle")
48 |         citeseq_ambient_profile = pd.read_pickle(
49 |             "scar/test/citeseq_ambient_profile.pickle"
50 |         )
51 |         citeseq_native_signals = pd.read_pickle(
52 |             "scar/test/citeseq_native_counts.pickle"
53 |         )
54 | 
55 |         citeseq_scar = model(
56 |             raw_count=citeseq_raw_counts.values,
57 |             ambient_profile=citeseq_ambient_profile.values,
58 |             feature_type="ADTs",
59 |         )
60 | 
61 |         citeseq_scar.train(epochs=200, batch_size=32, verbose=False)
62 |         citeseq_scar.inference()
63 | 
64 |         dist = euclidean_distances(
65 |             citeseq_native_signals.values, citeseq_scar.native_counts
66 |         )
67 |         mean_dist = (np.eye(dist.shape[0]) * dist).sum() / dist.shape[0]  # mean of the diagonal: each cell vs its own ground truth
68 | 
69 |         self.assertLess(mean_dist, 50, "Functional test of scAR fails")
70 | --------------------------------------------------------------------------------
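A final note on the stochastic rounding used by `VAE.inference` in `scar/main/_vae.py`: flooring and then adding 1 with probability equal to the fractional part preserves the expected value, unlike deterministic rounding, which can bias per-cell totals. A self-contained sketch of the same trick (the function name here is mine, not the package's):

```python
import numpy as np

def stochastic_round(x: np.ndarray) -> np.ndarray:
    """Floor, then add 1 with probability equal to the fractional part."""
    floor = np.floor(x)
    return (floor + (np.random.rand(*x.shape) < (x - floor)).astype(int)).astype(int)

x = np.full(100_000, 0.3)
print(stochastic_round(x).mean())  # ~0.3: expectation preserved
print(np.round(x).mean())          # 0.0: deterministic rounding biases totals
```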