├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── conftest.py
├── contrastive_vi
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── dataloaders
│   │   │   ├── __init__.py
│   │   │   ├── contrastive_dataloader.py
│   │   │   └── data_splitting.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── haber_2017.py
│   │   │   ├── mcfarland_2020.py
│   │   │   ├── norman_2019.py
│   │   │   ├── papalexi_2021.py
│   │   │   └── zheng_2017.py
│   │   └── utils.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── base
│   │   │   ├── __init__.py
│   │   │   └── training_mixin.py
│   │   ├── contrastive_vi.py
│   │   └── total_contrastive_vi.py
│   └── module
│       ├── __init__.py
│       ├── contrastive_vi.py
│       ├── total_contrastive_vi.py
│       └── utils.py
├── environment.yml
├── pyproject.toml
├── setup.cfg
├── sketch.png
└── tests
    ├── __init__.py
    ├── data
    │   └── dataloaders
    │       ├── test_contrastive_dataloader.py
    │       └── test_data_splitting.py
    ├── model
    │   ├── __init__.py
    │   └── test_contrastive_vi.py
    ├── module
    │   ├── __init__.py
    │   └── test_contrastive_vi.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # IDE
2 | .idea
3 |
4 | # Jupyter notebook
5 | .ipynb_checkpoints
6 |
7 | # Distribution
8 | *.egg-info
9 |
10 | # Python
11 | *__pycache__
12 |
13 | # OS
14 | *.DS_Store
15 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v4.0.1
4 | hooks:
5 | - id: trailing-whitespace
6 | - id: end-of-file-fixer
7 | - id: check-docstring-first
8 | - repo: https://github.com/python/black
9 | rev: 21.9b0
10 | hooks:
11 | - id: black
12 | additional_dependencies: [toml, click==8.0.4]
13 | args:
14 | - --line-length=88
15 | - repo: https://github.com/PyCQA/pydocstyle
16 | rev: 6.1.1
17 | hooks:
18 | - id: pydocstyle
19 | additional_dependencies: [toml]
20 | args:
21 | - --ignore=D102,D107,D202,D203,D212,D205,D400,D401,D410,D411,D413,D415
22 | exclude: "(tests/.*|conftest.py)"
23 | - repo: https://github.com/pycqa/flake8
24 | rev: 4.0.1
25 | hooks:
26 | - id: flake8
27 | args:
28 | - --max-line-length=88
29 | - --ignore=E203,W503
30 | - repo: https://github.com/pycqa/isort
31 | rev: 5.8.0
32 | hooks:
33 | - id: isort
34 | name: isort (python)
35 | additional_dependencies: [toml]
36 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2021 Ethan Weinberger, Chris Lin, AIMS Lab
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # contrastiveVI
2 |
3 |
4 |
5 |
6 |
7 | contrastiveVI is a generative model designed to isolate factors of variation specific to
8 | a group of "target" cells (e.g. from specimens with a given disease) from those shared
9 | with a group of "background" cells (e.g. from healthy specimens). contrastiveVI is
10 | implemented in [scvi-tools](https://scvi-tools.org/).
11 |
12 | ## User guide
13 |
14 | **Note: This implementation of contrastiveVI is no longer maintained. Please see the main [scvi-tools](https://scvi-tools.org/) repository for the most up to date version of contrastiveVI, along with an updated tutorial [here](https://docs.scvi-tools.org/en/stable/tutorials/notebooks/scrna/contrastiveVI_tutorial.html).**
15 |
16 | ### Installation
17 |
18 | To install the latest version of contrastiveVI via pip:
19 |
20 | ```
21 | pip install contrastive-vi
22 | ```
23 |
24 | Installation should take no more than 5 minutes.
25 |
26 | ### What you can do with contrastiveVI
27 |
28 | * If you have a dataset with cells in a background condition (e.g. from healthy
29 | controls) and a target condition (e.g. from diseased patients), you can train
30 | contrastiveVI to isolate latent factors of variation specific to the target cells
31 | from those shared with the background cells into separate latent spaces (see the sketch below).
32 | * Run clustering algorithms on the target-specific latent space to discover sub-groups
33 | of target cells
34 | * Perform differential expression testing for discovered sub-groups of target cells
35 | using a procedure similar to that of [scVI
36 | ](https://www.nature.com/articles/s41592-018-0229-2).
37 |
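For orientation, that workflow looks roughly like the sketch below. It is illustrative only: `setup_anndata` and `train` follow the API in this repository, while the `get_latent_representation(representation_kind=...)` call and the `adata.obs["condition"]` labels are assumptions made for the example.

```
import numpy as np
from contrastive_vi.model import ContrastiveVI

# adata: AnnData with raw counts in adata.layers["count"] and a
# hypothetical condition annotation in adata.obs["condition"].
ContrastiveVI.setup_anndata(adata, layer="count")
model = ContrastiveVI(adata)

background_indices = np.where(adata.obs["condition"] == "healthy")[0].tolist()
target_indices = np.where(adata.obs["condition"] == "disease")[0].tolist()
model.train(background_indices=background_indices, target_indices=target_indices)

# Salient (target-specific) latent space for clustering and DE testing.
salient_rep = model.get_latent_representation(representation_kind="salient")
```
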
38 | ### Colab Notebook Examples
39 |
40 | * [Applying contrastiveVI to see the effects of stem cell transplants for leukemia patients
41 | ](https://colab.research.google.com/drive/1yOTCVNWY6BydS1bppOYCWHvrxuvhMxZV?usp=sharing)
42 | * [Applying contrastiveVI to separate mouse intestinal epithelial cells
43 | infected with different pathogens by pathogen type
44 | ](https://colab.research.google.com/drive/1z0AcKQg7juArXGCx1XKj6skojWKRlDMC?usp=sharing)
45 | * [Applying contrastiveVI to better understand the results of a MIX-Seq
46 | small-molecule drug perturbation experiment
47 | ](https://colab.research.google.com/drive/1cMaJpMe3g0awCiwsw13oG7RvGnmXNCac?usp=sharing)
48 | * [Applying contrastiveVI to better understand heterogeneity in cellular responses to Alzheimer's disease
49 | ](https://colab.research.google.com/drive/1_R1YWQQUJzgQ6kz1XqglL5xZn8b8h1TX?usp=sharing)
50 |
51 |
52 | ## Development guide
53 |
54 | ### Set up the environment
55 | 1. Git clone this repository.
56 | 2. `cd contrastive-vi`.
57 | 3. Create and activate the specified conda environment by running
58 | ```
59 | conda env create -f environment.yml
60 | conda activate contrastive-vi-env
61 | ```
62 | 4. Install the `contrastive_vi` package and necessary dependencies for
63 | development by running `pip install -e ".[dev]"`.
64 | 5. Git pre-commit hooks (https://pre-commit.com/) are used to automatically
65 | check and fix formatting errors before a Git commit happens. Run
66 | `pre-commit install` to install all the hooks.
67 | 6. Test that the pre-commit hooks work by running `pre-commit run --all-files`.
68 |
69 | ### Testing
70 | It's good practice to include unit tests during development.
71 | Run `pytest tests` to verify existing tests.
72 |
73 |
74 | ## References
75 |
76 | If you find contrastiveVI useful for your work, please consider citing our preprint:
77 |
78 | ```
79 | @article{contrastiveVI,
80 | title={Isolating salient variations of interest in single-cell transcriptomic data with contrastiveVI},
81 | author={Weinberger, Ethan and Lin, Chris and Lee, Su-In},
82 | journal={bioRxiv},
83 | year={2021},
84 | publisher={Cold Spring Harbor Laboratory}
85 | }
86 | ```
87 |
--------------------------------------------------------------------------------
/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import scvi
3 | from scvi.model._utils import _init_library_size
4 |
5 | from contrastive_vi.data.dataloaders.contrastive_dataloader import ContrastiveDataLoader
6 | from contrastive_vi.model.contrastive_vi import ContrastiveVIModel
7 | from tests.utils import get_next_batch
8 |
9 |
10 | @pytest.fixture
11 | def mock_adata():
12 | adata = scvi.data.synthetic_iid(n_batches=2) # Same number of cells in each batch.
13 | # Make number of cells unequal across batches to test edge cases.
14 | adata = adata[:-3, :]
15 | adata.layers["raw_counts"] = adata.X.copy()
16 | ContrastiveVIModel.setup_anndata(
17 | adata=adata,
18 | batch_key="batch",
19 | labels_key="labels",
20 | layer="raw_counts",
21 | )
22 | return adata
23 |
24 |
25 | @pytest.fixture
26 | def mock_adata_manager(mock_adata):
27 | return ContrastiveVIModel._setup_adata_manager_store[mock_adata.uns["_scvi_uuid"]]
28 |
29 |
30 | @pytest.fixture
31 | def mock_library_log_means_and_vars(mock_adata_manager):
32 | return _init_library_size(mock_adata_manager, n_batch=2)
33 |
34 |
35 | @pytest.fixture
36 | def mock_library_log_means(mock_library_log_means_and_vars):
37 | return mock_library_log_means_and_vars[0]
38 |
39 |
40 | @pytest.fixture
41 | def mock_library_log_vars(mock_library_log_means_and_vars):
42 | return mock_library_log_means_and_vars[1]
43 |
44 |
45 | @pytest.fixture
46 | def mock_n_input(mock_adata):
47 | return mock_adata.X.shape[1]
48 |
49 |
50 | @pytest.fixture
51 | def mock_n_batch(mock_adata):
52 | return len(mock_adata.obs["batch"].unique())
53 |
54 |
55 | @pytest.fixture
56 | def mock_adata_background_indices(mock_adata):
57 | return (
58 | mock_adata.obs.index[(mock_adata.obs["batch"] == "batch_0")]
59 | .astype(int)
60 | .tolist()
61 | )
62 |
63 |
64 | @pytest.fixture
65 | def mock_adata_background_label(mock_adata):
66 | return 0
67 |
68 |
69 | @pytest.fixture
70 | def mock_adata_target_indices(mock_adata):
71 | return (
72 | mock_adata.obs.index[(mock_adata.obs["batch"] == "batch_1")]
73 | .astype(int)
74 | .tolist()
75 | )
76 |
77 |
78 | @pytest.fixture
79 | def mock_adata_target_label(mock_adata):
80 | return 1
81 |
82 |
83 | @pytest.fixture
84 | def mock_contrastive_dataloader(
85 | mock_adata_manager, mock_adata_background_indices, mock_adata_target_indices
86 | ):
87 | return ContrastiveDataLoader(
88 | mock_adata_manager,
89 | mock_adata_background_indices,
90 | mock_adata_target_indices,
91 | batch_size=32,
92 | shuffle=False,
93 | )
94 |
95 |
96 | @pytest.fixture
97 | def mock_contrastive_batch(mock_contrastive_dataloader):
98 | return get_next_batch(mock_contrastive_dataloader)
99 |
--------------------------------------------------------------------------------
/contrastive_vi/__init__.py:
--------------------------------------------------------------------------------
1 | """contrastive_vi setup file. copied from the scvi-tools-skeleton repo."""
2 |
3 | import logging
4 |
5 | from rich.console import Console
6 | from rich.logging import RichHandler
7 |
8 | # https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094
9 | # https://github.com/python-poetry/poetry/issues/144#issuecomment-623927302
10 | try:
11 | import importlib.metadata as importlib_metadata
12 | except ModuleNotFoundError:
13 | import importlib_metadata
14 |
15 | package_name = "contrastive_vi"
16 | __version__ = importlib_metadata.version(package_name)
17 |
18 | logger = logging.getLogger(__name__)
19 | # set the logging level
20 | logger.setLevel(logging.INFO)
21 |
22 | # nice logging outputs
23 | console = Console(force_terminal=True)
24 | console.is_jupyter = False
25 | ch = RichHandler(show_path=False, console=console, show_time=False)
26 | formatter = logging.Formatter("contrastive_vi: %(message)s")
27 | ch.setFormatter(formatter)
28 | logger.addHandler(ch)
29 |
30 | # this prevents double outputs
31 | logger.propagate = False
32 |
--------------------------------------------------------------------------------
/contrastive_vi/data/__init__.py:
--------------------------------------------------------------------------------
1 | """Data preprocessing modules."""
2 |
--------------------------------------------------------------------------------
/contrastive_vi/data/dataloaders/__init__.py:
--------------------------------------------------------------------------------
1 | """Data loader modules for mini-batching."""
2 |
--------------------------------------------------------------------------------
/contrastive_vi/data/dataloaders/contrastive_dataloader.py:
--------------------------------------------------------------------------------
1 | """Data loader for contrastive learning."""
2 | from itertools import cycle
3 | from typing import List, Optional, Union
4 |
5 | from scvi.data import AnnDataManager
6 | from scvi.dataloaders._concat_dataloader import ConcatDataLoader
7 |
8 |
9 | class _ContrastiveIterator:
10 | """
11 | Iterator for background and target dataloader pairs as found in the contrastive
12 | analysis setting.
13 |
14 | Each iteration of this iterator returns a dictionary with two elements:
15 | "background", containing one batch of data from the background dataloader, and
16 | "target", containing one batch of data from the target dataloader.
17 | """
18 |
19 | def __init__(self, background, target):
20 | self.background = iter(background)
21 | self.target = iter(target)
22 |
23 | def __iter__(self):
24 | return self
25 |
26 | def __next__(self):
27 | bg_samples = next(self.background)
28 | tg_samples = next(self.target)
29 | return {"background": bg_samples, "target": tg_samples}
30 |
31 |
32 | class ContrastiveDataLoader(ConcatDataLoader):
33 | """
34 | Data loader to load background and target data for contrastive learning.
35 |
36 | Each iteration of the data loader returns a dictionary containing background and
37 | target data points, indexed by "background" and "target", respectively.
38 | Args:
39 | ----
40 | adata_manager: AnnDataManager object that has been created via `setup_anndata`.
41 | background_indices: Indices for background samples in `adata`.
42 | target_indices: Indices for target samples in `adata`.
43 | shuffle: Whether the data should be shuffled.
44 | batch_size: Mini-batch size to load for background and target data.
45 | data_and_attributes: Dictionary with keys representing keys in data
46 | registry (`adata.uns["_scvi"]`) and value equal to desired numpy
47 | loading type (later made into torch tensor). If `None`, defaults to all
48 | registered data.
49 | drop_last: If int, drops the last batch if its length is less than
50 | `drop_last`. If `drop_last == True`, drops last non-full batch.
51 | If `drop_last == False`, iterate over all batches.
52 | **data_loader_kwargs: Keyword arguments for `torch.utils.data.DataLoader`.
53 | """
54 |
55 | def __init__(
56 | self,
57 | adata_manager: AnnDataManager,
58 | background_indices: List[int],
59 | target_indices: List[int],
60 | shuffle: bool = False,
61 | batch_size: int = 128,
62 | data_and_attributes: Optional[dict] = None,
63 | drop_last: Union[bool, int] = False,
64 | **data_loader_kwargs,
65 | ) -> None:
66 | super().__init__(
67 | adata_manager=adata_manager,
68 | indices_list=[background_indices, target_indices],
69 | shuffle=shuffle,
70 | batch_size=batch_size,
71 | data_and_attributes=data_and_attributes,
72 | drop_last=drop_last,
73 | **data_loader_kwargs,
74 | )
75 | self.background_indices = background_indices
76 | self.target_indices = target_indices
77 |
78 | def __iter__(self):
79 | """
80 |
81 |         Iter method for contrastive data loader.
82 |
83 |         Iterates over the dataloader with the most data while cycling through
84 |         the data in the other dataloader.
85 | """
86 |
87 | iter_list = [
88 | cycle(dl) if dl != self.largest_dl else dl for dl in self.dataloaders
89 | ]
90 |
91 | return _ContrastiveIterator(background=iter_list[0], target=iter_list[1])
92 |
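# ---------------------------------------------------------------------------
# Editor's usage sketch, not part of the original module: each iteration yields
# a dict of two scvi-tools batches keyed "background" and "target", so paired
# mini-batches can be unpacked directly. `adata_manager` and the index lists
# are assumed to come from a registered AnnData, as in this repo's conftest.py.
# ---------------------------------------------------------------------------
def _example_usage(adata_manager, background_indices, target_indices):
    """Illustrative only: show the structure returned by one iteration."""
    loader = ContrastiveDataLoader(
        adata_manager, background_indices, target_indices, batch_size=128
    )
    batch = next(iter(loader))
    x_background = batch["background"]["X"]  # (batch_size, n_genes) counts
    x_target = batch["target"]["X"]  # (batch_size, n_genes) counts
    return x_background, x_target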
--------------------------------------------------------------------------------
/contrastive_vi/data/dataloaders/data_splitting.py:
--------------------------------------------------------------------------------
1 | """Utilities for splitting a dataset into training, validation, and test set."""
2 |
3 | from typing import List, Optional
4 |
5 | import numpy as np
6 | import pytorch_lightning as pl
7 | from scvi import settings
8 | from scvi.data import AnnDataManager
9 | from scvi.dataloaders._data_splitting import validate_data_split
10 | from scvi.model._utils import parse_use_gpu_arg
11 |
12 | from contrastive_vi.data.dataloaders.contrastive_dataloader import ContrastiveDataLoader
13 |
14 |
15 | class ContrastiveDataSplitter(pl.LightningDataModule):
16 | """
17 | Create ContrastiveDataLoader for training, validation, and test set.
18 |
19 | Args:
20 | ----
21 |     adata_manager: AnnDataManager for the AnnData object registered via
22 |         `setup_anndata`.
22 | background_indices: Indices for background samples in `adata`.
23 | target_indices: Indices for target samples in `adata`.
24 | train_size: Proportion of data to include in the training set.
25 | validation_size: Proportion of data to include in the validation set. The
26 | remaining proportion after `train_size` and `validation_size` is used for
27 | the test set.
28 | use_gpu: Use default GPU if available (if None or True); or index of GPU to
29 | use (if int); or name of GPU (if str, e.g., `'cuda:0'`); or use CPU
30 | (if False).
31 | **kwargs: Keyword args for data loader (`ContrastiveDataLoader`).
32 | """
33 |
34 | def __init__(
35 | self,
36 |         adata_manager: AnnDataManager,
37 | background_indices: List[int],
38 | target_indices: List[int],
39 | train_size: float = 0.9,
40 | validation_size: Optional[float] = None,
41 | use_gpu: bool = False,
42 | **kwargs,
43 | ) -> None:
44 | super().__init__()
45 |         self.adata_manager = adata_manager
46 | self.background_indices = background_indices
47 | self.target_indices = target_indices
48 | self.train_size = train_size
49 | self.validation_size = validation_size
50 | self.use_gpu = use_gpu
51 | self.data_loader_kwargs = kwargs
52 |
53 | self.n_background = len(background_indices)
54 | self.n_target = len(target_indices)
55 | self.n_background_train, self.n_background_val = validate_data_split(
56 | len(self.background_indices), self.train_size, self.validation_size
57 | )
58 | self.n_target_train, self.n_target_val = validate_data_split(
59 | len(self.target_indices), self.train_size, self.validation_size
60 | )
61 |
62 | def setup(self, stage: Optional[str] = None):
63 | random_state = np.random.RandomState(seed=settings.seed)
64 |
65 | target_permutation = random_state.permutation(self.target_indices)
66 | n_target_train = self.n_target_train
67 | n_target_val = self.n_target_val
68 | self.target_val_idx = target_permutation[:n_target_val]
69 | self.target_train_idx = target_permutation[
70 | n_target_val : (n_target_val + n_target_train)
71 | ]
72 | self.target_test_idx = target_permutation[(n_target_val + n_target_train) :]
73 |
74 | background_permutation = random_state.permutation(self.background_indices)
75 | n_background_train = self.n_background_train
76 | n_background_val = self.n_background_val
77 | self.background_val_idx = background_permutation[:n_background_val]
78 | self.background_train_idx = background_permutation[
79 | n_background_val : (n_background_val + n_background_train)
80 | ]
81 | self.background_test_idx = background_permutation[
82 | (n_background_val + n_background_train) :
83 | ]
84 |
85 | self.train_idx = np.concatenate(
86 | (self.background_train_idx, self.target_train_idx)
87 | )
88 | self.val_idx = np.concatenate((self.background_val_idx, self.target_val_idx))
89 | self.test_idx = np.concatenate((self.background_test_idx, self.target_test_idx))
90 |
91 | gpus, self.device = parse_use_gpu_arg(self.use_gpu, return_device=True)
92 | self.pin_memory = (
93 | True if (settings.dl_pin_memory_gpu_training and gpus != 0) else False
94 | )
95 |
96 | def _get_contrastive_dataloader(
97 | self, background_indices: List[int], target_indices: List[int]
98 | ) -> ContrastiveDataLoader:
99 | return ContrastiveDataLoader(
100 |             self.adata_manager,
101 | background_indices,
102 | target_indices,
103 | shuffle=True,
104 | drop_last=3,
105 | pin_memory=self.pin_memory,
106 | **self.data_loader_kwargs,
107 | )
108 |
109 | def train_dataloader(self) -> ContrastiveDataLoader:
110 | return self._get_contrastive_dataloader(
111 | self.background_train_idx, self.target_train_idx
112 | )
113 |
114 | def val_dataloader(self) -> ContrastiveDataLoader:
115 | if len(self.background_val_idx) > 0 and len(self.target_val_idx) > 0:
116 | return self._get_contrastive_dataloader(
117 | self.background_val_idx, self.target_val_idx
118 | )
119 |         else:
120 |             return None
121 |
122 | def test_dataloader(self) -> ContrastiveDataLoader:
123 | if len(self.background_test_idx) > 0 and len(self.target_test_idx) > 0:
124 | return self._get_contrastive_dataloader(
125 | self.background_test_idx, self.target_test_idx
126 | )
127 |         else:
128 |             return None
129 |
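# ---------------------------------------------------------------------------
# Editor's sketch, not part of the original module: `validate_data_split` is
# applied separately to the background and target index lists, so every split
# preserves the background/target proportions. With train_size=0.8 and
# validation_size=None, roughly 80% of each group goes to training, the
# remaining 20% to validation, and the test set is empty.
# ---------------------------------------------------------------------------
def _example_split(adata_manager, background_indices, target_indices):
    """Illustrative only: build a splitter and materialize its index splits."""
    splitter = ContrastiveDataSplitter(
        adata_manager, background_indices, target_indices, train_size=0.8
    )
    splitter.setup()
    return splitter.train_idx, splitter.val_idx, splitter.test_idx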
--------------------------------------------------------------------------------
/contrastive_vi/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | """Modules for downloading, reading, and preprocessing specific datasets."""
2 |
--------------------------------------------------------------------------------
/contrastive_vi/data/datasets/haber_2017.py:
--------------------------------------------------------------------------------
1 | """
2 | Download, read, and preprocess Haber et al. (2017) expression data.
3 |
4 | Single-cell expression data from Haber et al. A single-cell survey of the small
5 | intestinal epithelium. Nature (2017).
6 | """
7 | import gzip
8 | import os
9 |
10 | import pandas as pd
11 | import scanpy as sc
12 | from anndata import AnnData
13 |
14 | from contrastive_vi.data.utils import download_binary_file
15 |
16 |
17 | def download_haber_2017(output_path: str) -> None:
18 | """
19 | Download Haber et al. 2017 data from the hosting URLs.
20 |
21 | Args:
22 | ----
23 | output_path: Output path to store the downloaded and unzipped
24 | directories.
25 |
26 | Returns
27 | -------
28 | None. File directories are downloaded to output_path.
29 | """
30 |
31 | url = (
32 | "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE92nnn/GSE92332/suppl/GSE92332"
33 | "_SalmHelm_UMIcounts.txt.gz"
34 | )
35 |
36 | output_filename = os.path.join(output_path, url.split("/")[-1])
37 |
38 | download_binary_file(url, output_filename)
39 |
40 |
41 | def read_haber_2017(file_directory: str) -> pd.DataFrame:
42 | """
43 |     Read the expression data for Haber et al. 2017 in the given directory.
44 |
45 | Args:
46 | ----
47 | file_directory: Directory containing Haber et al. 2017 data.
48 |
49 | Returns
50 | -------
51 | A data frame containing single-cell gene expression count, with cell
52 | identification barcodes as column names and gene IDs as indices.
53 | """
54 |
55 | with gzip.open(
56 | os.path.join(file_directory, "GSE92332_SalmHelm_UMIcounts.txt.gz"), "rb"
57 | ) as f:
58 | df = pd.read_csv(f, sep="\t")
59 |
60 | return df
61 |
62 |
63 | def preprocess_haber_2017(download_path: str, n_top_genes: int) -> AnnData:
64 | """
65 | Preprocess expression data from Haber et al. 2017.
66 |
67 | Args:
68 | ----
69 | download_path: Path containing the downloaded Haber et al. 2017 data file.
70 | n_top_genes: Number of most variable genes to retain.
71 |
72 | Returns
73 | -------
74 | An AnnData object containing single-cell expression data. The layer
75 | "count" contains the count data for the most variable genes. The X
76 | variable contains the total-count-normalized and log-transformed data
77 | for the most variable genes (a copy with all the genes is stored in
78 | .raw).
79 | """
80 |
81 | df = read_haber_2017(download_path)
82 | df = df.transpose()
83 |
84 | cell_groups = []
85 | barcodes = []
86 | conditions = []
87 | cell_types = []
88 |
89 | for cell in df.index:
90 | cell_group, barcode, condition, cell_type = cell.split("_")
91 | cell_groups.append(cell_group)
92 | barcodes.append(barcode)
93 | conditions.append(condition)
94 | cell_types.append(cell_type)
95 |
96 | metadata_df = pd.DataFrame(
97 | {
98 | "cell_group": cell_groups,
99 | "barcode": barcodes,
100 | "condition": conditions,
101 | "cell_type": cell_types,
102 | }
103 | )
104 |
105 | adata = AnnData(X=df.values, obs=metadata_df)
106 | adata = adata[adata.obs["condition"] != "Hpoly.Day3"]
107 | adata.layers["count"] = adata.X.copy()
108 | sc.pp.normalize_total(adata)
109 | sc.pp.log1p(adata)
110 | adata.raw = adata
111 | sc.pp.highly_variable_genes(
112 | adata, flavor="seurat_v3", n_top_genes=n_top_genes, layer="count", subset=True
113 | )
114 | adata = adata[adata.layers["count"].sum(1) != 0] # Remove cells with all zeros.
115 | return adata
116 |
--------------------------------------------------------------------------------
/contrastive_vi/data/datasets/mcfarland_2020.py:
--------------------------------------------------------------------------------
1 | """
2 | Download, read, and preprocess Mcfarland et al. (2020) expression data.
3 |
4 | Single-cell expression data from Mcfarland et al. Multiplexed single-cell
5 | transcriptional response profiling to define cancer vulnerabilities and therapeutic
6 | mechanism of action. Nature Communications (2020).
7 | """
8 | import os
9 | import shutil
10 | from typing import Tuple
11 |
12 | import anndata
13 | import numpy as np
14 | import pandas as pd
15 | import scanpy as sc
16 | from anndata import AnnData
17 | from scipy.io import mmread
18 |
19 | from contrastive_vi.data.utils import download_binary_file
20 |
21 |
22 | def download_mcfarland_2020(output_path: str) -> None:
23 | """
24 | Download Mcfarland et al. 2020 data from the hosting URLs.
25 |
26 | Args:
27 | ----
28 | output_path: Output path to store the downloaded and unzipped
29 | directories.
30 |
31 | Returns
32 | -------
33 | None. File directories are downloaded and unzipped in output_path.
34 | """
35 | idasanutlin_url = "https://figshare.com/ndownloader/files/18716351"
36 | idasanutlin_output_filename = os.path.join(output_path, "idasanutlin.zip")
37 |
38 | download_binary_file(idasanutlin_url, idasanutlin_output_filename)
39 | idasanutlin_output_dir = idasanutlin_output_filename.replace(".zip", "")
40 | shutil.unpack_archive(idasanutlin_output_filename, idasanutlin_output_dir)
41 |
42 | dmso_url = "https://figshare.com/ndownloader/files/18716354"
43 | dmso_output_filename = os.path.join(output_path, "dmso.zip")
44 |
45 | download_binary_file(dmso_url, dmso_output_filename)
46 | dmso_output_dir = dmso_output_filename.replace(".zip", "")
47 | shutil.unpack_archive(dmso_output_filename, dmso_output_dir)
48 |
49 |
50 | def _read_mixseq_df(directory: str) -> pd.DataFrame:
51 | data = mmread(os.path.join(directory, "matrix.mtx"))
52 | barcodes = pd.read_table(os.path.join(directory, "barcodes.tsv"), header=None)
53 | classifications = pd.read_csv(os.path.join(directory, "classifications.csv"))
54 | classifications["cell_line"] = np.array(
55 | [x.split("_")[0] for x in classifications.singlet_ID.values]
56 | )
57 | gene_names = pd.read_table(os.path.join(directory, "genes.tsv"), header=None)
58 |
59 | df = pd.DataFrame(
60 | data.toarray(),
61 | columns=barcodes.iloc[:, 0].values,
62 | index=gene_names.iloc[:, 0].values,
63 | )
64 | return df
65 |
66 |
67 | def _get_tp53_mutation_status(directory: str) -> np.array:
68 | # Taken from https://cancerdatascience.org/blog/posts/mix-seq/
69 | TP53_WT = [
70 | "LNCAPCLONEFGC_PROSTATE",
71 | "DKMG_CENTRAL_NERVOUS_SYSTEM",
72 | "NCIH226_LUNG",
73 | "RCC10RGB_KIDNEY",
74 | "SNU1079_BILIARY_TRACT",
75 | "CCFSTTG1_CENTRAL_NERVOUS_SYSTEM",
76 | "COV434_OVARY",
77 | ]
78 |
79 | classifications = pd.read_csv(os.path.join(directory, "classifications.csv"))
80 | TP53_mutation_status = [
81 | "Wild Type" if x in TP53_WT else "Mutation"
82 | for x in classifications.singlet_ID.values
83 | ]
84 | return np.array(TP53_mutation_status)
85 |
86 |
87 | def read_mcfarland_2020(file_directory: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
88 | """
89 | Read the expression data for Mcfarland et al. 2020 in the given directory.
90 |
91 | Args:
92 | ----
93 | file_directory: Directory containing Mcfarland et al. 2020 data.
94 |
95 | Returns
96 | -------
97 | Two data frames of raw count expression data. The first contains
98 | single-cell gene expression count data from cancer cell lines exposed to
99 | idasanutlin with cell identification barcodes as column names and gene IDs as
100 | indices. The second contains count data with the same format from samples
101 | exposed to a control solution (DMSO).
102 | """
103 | idasanutlin_dir = os.path.join(
104 | file_directory, "idasanutlin", "Idasanutlin_24hr_expt1"
105 | )
106 | idasanutlin_df = _read_mixseq_df(idasanutlin_dir)
107 |
108 | dmso_dir = os.path.join(file_directory, "dmso", "DMSO_24hr_expt1")
109 | dmso_df = _read_mixseq_df(dmso_dir)
110 |
111 | return idasanutlin_df, dmso_df
112 |
113 |
114 | def preprocess_mcfarland_2020(download_path: str, n_top_genes: int) -> AnnData:
115 | """
116 | Preprocess expression data from Mcfarland et al., 2020.
117 |
118 | Args:
119 | ----
120 | download_path: Path containing the downloaded Mcfarland et al. 2020 data files.
121 | n_top_genes: Number of most variable genes to retain.
122 |
123 | Returns
124 | -------
125 | An AnnData object containing single-cell expression data. The layer
126 | "count" contains the count data for the most variable genes. The X
127 | variable contains the total-count-normalized and log-transformed data
128 | for the most variable genes (a copy with all the genes is stored in
129 | .raw).
130 | """
131 |
132 | idasanutlin_df, dmso_df = read_mcfarland_2020(download_path)
133 | idasanutlin_df, dmso_df = idasanutlin_df.transpose(), dmso_df.transpose()
134 |
135 | idasanutlin_adata = AnnData(idasanutlin_df)
136 | idasanutlin_dir = os.path.join(
137 | download_path, "idasanutlin", "Idasanutlin_24hr_expt1"
138 | )
139 | idasanutlin_adata.obs["TP53_mutation_status"] = _get_tp53_mutation_status(
140 | idasanutlin_dir
141 | )
142 | idasanutlin_adata.obs["condition"] = np.repeat(
143 | "Idasanutlin", idasanutlin_adata.shape[0]
144 | )
145 |
146 | dmso_adata = AnnData(dmso_df)
147 | dmso_dir = os.path.join(download_path, "dmso", "DMSO_24hr_expt1")
148 | dmso_adata.obs["TP53_mutation_status"] = _get_tp53_mutation_status(dmso_dir)
149 | dmso_adata.obs["condition"] = np.repeat("DMSO", dmso_adata.shape[0])
150 |
151 | full_adata = anndata.concat([idasanutlin_adata, dmso_adata])
152 | full_adata.layers["count"] = full_adata.X.copy()
153 | sc.pp.normalize_total(full_adata)
154 | sc.pp.log1p(full_adata)
155 | full_adata.raw = full_adata
156 | sc.pp.highly_variable_genes(
157 | full_adata,
158 | flavor="seurat_v3",
159 | n_top_genes=n_top_genes,
160 | layer="count",
161 | subset=True,
162 | )
163 | full_adata = full_adata[
164 | full_adata.layers["count"].sum(1) != 0
165 | ] # Remove cells with all zeros.
166 | return full_adata
167 |
--------------------------------------------------------------------------------
/contrastive_vi/data/datasets/norman_2019.py:
--------------------------------------------------------------------------------
1 | """
2 | Download, read, and preprocess Norman et al. (2019) expression data.
3 |
4 | Single-cell expression data from Norman et al. Exploring genetic interaction
5 | manifolds constructed from rich single-cell phenotypes. Science (2019).
6 | """
7 |
8 | import gzip
9 | import os
10 | import re
11 |
12 | import pandas as pd
13 | import scanpy as sc
14 | from anndata import AnnData
15 | from scipy.io import mmread
16 | from scipy.sparse import coo_matrix
17 |
18 | from contrastive_vi.data.utils import download_binary_file
19 |
20 | # Gene program lists obtained by cross-referencing the heatmap here
21 | # https://github.com/thomasmaxwellnorman/Perturbseq_GI/blob/master/GI_optimal_umap.ipynb
22 | # with Figure 2b in Norman 2019
23 | G1_CYCLE = [
24 | "CDKN1C+CDKN1B",
25 | "CDKN1B+ctrl",
26 | "CDKN1B+CDKN1A",
27 | "CDKN1C+ctrl",
28 | "ctrl+CDKN1A",
29 | "CDKN1C+CDKN1A",
30 | "CDKN1A+ctrl",
31 | ]
32 |
33 | ERYTHROID = [
34 | "BPGM+SAMD1",
35 | "ATL1+ctrl",
36 | "UBASH3B+ZBTB25",
37 | "PTPN12+PTPN9",
38 | "PTPN12+UBASH3A",
39 | "CBL+CNN1",
40 | "UBASH3B+CNN1",
41 | "CBL+UBASH3B",
42 | "UBASH3B+PTPN9",
43 | "PTPN1+ctrl",
44 | "CBL+PTPN9",
45 | "CNN1+UBASH3A",
46 | "CBL+PTPN12",
47 | "PTPN12+ZBTB25",
48 | "UBASH3B+PTPN12",
49 | "SAMD1+PTPN12",
50 | "SAMD1+UBASH3B",
51 | "UBASH3B+UBASH3A",
52 | ]
53 |
54 | PIONEER_FACTORS = [
55 | "ZBTB10+SNAI1",
56 | "FOXL2+MEIS1",
57 | "POU3F2+CBFA2T3",
58 | "DUSP9+SNAI1",
59 | "FOXA3+FOXA1",
60 | "FOXA3+ctrl",
61 | "LYL1+IER5L",
62 | "FOXA1+FOXF1",
63 | "FOXF1+HOXB9",
64 | "FOXA1+HOXB9",
65 | "FOXA3+HOXB9",
66 | "FOXA3+FOXA1",
67 | "FOXA3+FOXL2",
68 | "POU3F2+FOXL2",
69 | "FOXF1+FOXL2",
70 | "FOXA1+FOXL2",
71 | "HOXA13+ctrl",
72 | "ctrl+HOXC13",
73 | "HOXC13+ctrl",
74 | "MIDN+ctrl",
75 | "TP73+ctrl",
76 | ]
77 |
78 | GRANULOCYTE_APOPTOSIS = [
79 | "SPI1+ctrl",
80 | "ctrl+SPI1",
81 | "ctrl+CEBPB",
82 | "CEBPB+ctrl",
83 | "JUN+CEBPA",
84 | "CEBPB+CEBPA",
85 | "FOSB+CEBPE",
86 | "ZC3HAV1+CEBPA",
87 | "KLF1+CEBPA",
88 | "ctrl+CEBPA",
89 | "CEBPA+ctrl",
90 | "CEBPE+CEBPA",
91 | "CEBPE+SPI1",
92 | "CEBPE+ctrl",
93 | "ctrl+CEBPE",
94 | "CEBPE+RUNX1T1",
95 | "CEBPE+CEBPB",
96 | "FOSB+CEBPB",
97 | "ETS2+CEBPE",
98 | ]
99 |
100 | MEGAKARYOCYTE = [
101 | "ctrl+ETS2",
102 | "MAPK1+ctrl",
103 | "ctrl+MAPK1",
104 | "ETS2+MAPK1",
105 | "CEBPB+MAPK1",
106 | "MAPK1+TGFBR2",
107 | ]
108 |
109 | PRO_GROWTH = [
110 | "CEBPE+KLF1",
111 | "KLF1+MAP2K6",
112 | "AHR+KLF1",
113 | "ctrl+KLF1",
114 | "KLF1+ctrl",
115 | "KLF1+BAK1",
116 | "KLF1+TGFBR2",
117 | ]
118 |
119 |
120 | def download_norman_2019(output_path: str) -> None:
121 | """
122 | Download Norman et al. 2019 data and metadata files from the hosting URLs.
123 |
124 | Args:
125 | ----
126 | output_path: Output path to store the downloaded and unzipped
127 | directories.
128 |
129 | Returns
130 | -------
131 | None. File directories are downloaded to output_path.
132 | """
133 |
134 | file_urls = (
135 | "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE133nnn/GSE133344/suppl"
136 | "/GSE133344_filtered_matrix.mtx.gz",
137 | "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE133nnn/GSE133344/suppl"
138 | "/GSE133344_filtered_genes.tsv.gz",
139 | "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE133nnn/GSE133344/suppl"
140 | "/GSE133344_filtered_barcodes.tsv.gz",
141 | "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE133nnn/GSE133344/suppl"
142 | "/GSE133344_filtered_cell_identities.csv.gz",
143 | )
144 |
145 | for url in file_urls:
146 | output_filename = os.path.join(output_path, url.split("/")[-1])
147 | download_binary_file(url, output_filename)
148 |
149 |
150 | def read_norman_2019(file_directory: str) -> coo_matrix:
151 | """
152 | Read the expression data for Norman et al. 2019 in the given directory.
153 |
154 | Args:
155 | ----
156 | file_directory: Directory containing Norman et al. 2019 data.
157 |
158 | Returns
159 | -------
160 | A sparse matrix containing single-cell gene expression count, with rows
161 | representing genes and columns representing cells.
162 | """
163 |
164 | with gzip.open(
165 | os.path.join(file_directory, "GSE133344_filtered_matrix.mtx.gz"), "rb"
166 | ) as f:
167 | matrix = mmread(f)
168 |
169 | return matrix
170 |
171 |
172 | def preprocess_norman_2019(download_path: str, n_top_genes: int) -> AnnData:
173 | """
174 | Preprocess expression data from Norman et al. 2019.
175 |
176 | Args:
177 | ----
178 | download_path: Path containing the downloaded Norman et al. 2019 data file.
179 | n_top_genes: Number of most variable genes to retain.
180 |
181 | Returns
182 | -------
183 | An AnnData object containing single-cell expression data. The layer
184 | "count" contains the count data for the most variable genes. The .X
185 | variable contains the normalized and log-transformed data for the most variable
186 | genes. A copy of data with all genes is stored in .raw.
187 | """
188 | matrix = read_norman_2019(download_path)
189 |
190 | # List of cell barcodes. The barcodes in this list are stored in the same order
191 | # as cells are in the count matrix.
192 | cell_barcodes = pd.read_csv(
193 | os.path.join(download_path, "GSE133344_filtered_barcodes.tsv.gz"),
194 | sep="\t",
195 | header=None,
196 | names=["cell_barcode"],
197 | )
198 |
199 | # IDs/names of the gene features.
200 | gene_list = pd.read_csv(
201 | os.path.join(download_path, "GSE133344_filtered_genes.tsv.gz"),
202 | sep="\t",
203 | header=None,
204 | names=["gene_id", "gene_name"],
205 | )
206 |
207 | # Dataframe where each row corresponds to a cell, and each column corresponds
208 | # to a gene feature.
209 | matrix = pd.DataFrame(
210 | matrix.transpose().todense(),
211 | columns=gene_list["gene_id"],
212 | index=cell_barcodes["cell_barcode"],
213 | dtype="int32",
214 | )
215 |
216 | # Dataframe mapping cell barcodes to metadata about that cell (e.g. which CRISPR
217 | # guides were applied to that cell). Unfortunately, this list has a different
218 | # ordering from the count matrix, so we have to be careful combining the metadata
219 | # and count data.
220 | cell_identities = pd.read_csv(
221 | os.path.join(download_path, "GSE133344_filtered_cell_identities.csv.gz")
222 | ).set_index("cell_barcode")
223 |
224 | # This merge call reorders our metadata dataframe to match the ordering in the
225 | # count matrix. Some cells in `cell_barcodes` do not have metadata associated with
226 | # them, and their metadata values will be filled in as NaN.
227 | aligned_metadata = pd.merge(
228 | cell_barcodes,
229 | cell_identities,
230 | left_on="cell_barcode",
231 | right_index=True,
232 | how="left",
233 | ).set_index("cell_barcode")
234 |
235 | adata = AnnData(matrix)
236 | adata.obs = aligned_metadata
237 |
238 | # Filter out any cells that don't have metadata values.
239 | rows_without_nans = [
240 | index for index, row in adata.obs.iterrows() if not row.isnull().any()
241 | ]
242 | adata = adata[rows_without_nans, :]
243 |
244 | # Remove these as suggested by the authors. See lines referring to
245 | # NegCtrl1_NegCtrl0 in GI_generate_populations.ipynb in the Norman 2019 paper's
246 | # Github repo https://github.com/thomasmaxwellnorman/Perturbseq_GI/
247 | adata = adata[adata.obs["guide_identity"] != "NegCtrl1_NegCtrl0__NegCtrl1_NegCtrl0"]
248 |
249 | # We create a new metadata column with cleaner representations of CRISPR guide
250 |     # identities. The original format is <gene_1>_<gene_2>__<suffix>.
251 | adata.obs["guide_merged"] = adata.obs["guide_identity"]
252 |
253 | control_regex = re.compile(r"NegCtrl(.*)_NegCtrl(.*)+NegCtrl(.*)_NegCtrl(.*)")
254 | for i in adata.obs["guide_merged"].unique():
255 | if control_regex.match(i):
256 | # For any cells that only had control guides, we don't care about the
257 | # specific IDs of the guides. Here we relabel them just as "ctrl".
258 | adata.obs["guide_merged"].replace(i, "ctrl", inplace=True)
259 | else:
260 |             # Otherwise, we reformat the guide label to be <gene_1>+<gene_2>.
261 |             # If either guide was a control, we replace it with "ctrl".
262 | split = i.split("__")[0]
263 | split = split.split("_")
264 | for j, string in enumerate(split):
265 | if "NegCtrl" in split[j]:
266 | split[j] = "ctrl"
267 | adata.obs["guide_merged"].replace(i, f"{split[0]}+{split[1]}", inplace=True)
268 |
269 | guides_to_programs = {}
270 | guides_to_programs.update(dict.fromkeys(G1_CYCLE, "G1 cell cycle arrest"))
271 | guides_to_programs.update(dict.fromkeys(ERYTHROID, "Erythroid"))
272 | guides_to_programs.update(dict.fromkeys(PIONEER_FACTORS, "Pioneer factors"))
273 | guides_to_programs.update(
274 | dict.fromkeys(GRANULOCYTE_APOPTOSIS, "Granulocyte/apoptosis")
275 | )
276 | guides_to_programs.update(dict.fromkeys(PRO_GROWTH, "Pro-growth"))
277 | guides_to_programs.update(dict.fromkeys(MEGAKARYOCYTE, "Megakaryocyte"))
278 | guides_to_programs.update(dict.fromkeys(["ctrl"], "Ctrl"))
279 |
280 | # We only keep cells whose guides were either controls or are labeled with a
281 | # specific gene program
282 | adata = adata[adata.obs["guide_merged"].isin(guides_to_programs.keys())]
283 | adata.obs["gene_program"] = [
284 | guides_to_programs[x] for x in adata.obs["guide_merged"]
285 | ]
286 |
287 | adata.obs["good_coverage"] = adata.obs["good_coverage"].astype(bool)
288 |
289 | adata.layers["count"] = adata.X.copy()
290 | sc.pp.normalize_total(adata)
291 | sc.pp.log1p(adata)
292 | adata.raw = adata
293 | sc.pp.highly_variable_genes(
294 | adata, flavor="seurat_v3", n_top_genes=n_top_genes, layer="count", subset=True
295 | )
296 | adata = adata[adata.layers["count"].sum(1) != 0] # Remove cells with all zeros.
297 | return adata
298 |
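# ---------------------------------------------------------------------------
# Editor's sketch, not part of the original module: the guide-relabeling loop
# above, applied to a single (hypothetical) raw `guide_identity` string.
# ---------------------------------------------------------------------------
def _example_guide_relabel() -> str:
    """Illustrative only: relabel one raw guide identity string."""
    raw = "KLF1_NegCtrl0__KLF1_NegCtrl0"  # hypothetical raw value
    split = raw.split("__")[0].split("_")  # -> ["KLF1", "NegCtrl0"]
    split = ["ctrl" if "NegCtrl" in s else s for s in split]
    return f"{split[0]}+{split[1]}"  # -> "KLF1+ctrl", a PRO_GROWTH guide pair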
--------------------------------------------------------------------------------
/contrastive_vi/data/datasets/papalexi_2021.py:
--------------------------------------------------------------------------------
1 | """
2 | Download, read, and preprocess Papalexi et al. (2021) expression data.
3 |
4 | Single-cell expression data from Papalexi et al. Characterizing the molecular regulation
5 | of inhibitory immune checkpoints with multimodal single-cell screens. (Nature Genetics
6 | 2021)
7 | """
8 | import os
9 | import shutil
10 |
11 | import pandas as pd
12 | import scanpy as sc
13 | from anndata import AnnData
14 |
15 | from contrastive_vi.data.utils import download_binary_file
16 |
17 | # Obsm key for CITE-seq protein counts; replaces the missing `constants` import.
18 | PROTEIN_EXPRESSION_KEY = "protein_expression"
19 | def download_papalexi_2021(output_path: str) -> None:
20 | """
21 | Download Papalexi et al. 2021 data from the hosting URLs.
22 |
23 | Args:
24 | ----
25 | output_path: Output path to store the downloaded and unzipped
26 | directories.
27 |
28 | Returns
29 | -------
30 | None. File directories are downloaded to output_path.
31 | """
32 |
33 | counts_data_url = (
34 | "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE153056&format=file"
35 | )
36 | data_output_filename = os.path.join(output_path, "GSE153056_RAW.tar")
37 | download_binary_file(counts_data_url, data_output_filename)
38 | shutil.unpack_archive(data_output_filename, output_path)
39 |
40 | metadata_url = (
41 | "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE153056&"
42 | "format=file&file=GSE153056_ECCITE_metadata.tsv.gz"
43 | )
44 | metadata_filename = os.path.join(output_path, metadata_url.split("=")[-1])
45 | download_binary_file(metadata_url, metadata_filename)
46 |
47 |
48 | def read_papalexi_2021(file_directory: str) -> pd.DataFrame:
49 | """
50 | Read the expression data for Papalexi et al. 2021 in the given directory.
51 |
52 | Args:
53 | ----
54 | file_directory: Directory containing Papalexi et al. 2021 data.
55 |
56 | Returns
57 | -------
58 | A pandas dataframe, with each column representing a cell
59 | and each row representing a gene feature.
60 | """
61 |
62 | matrix = pd.read_csv(
63 | os.path.join(file_directory, "GSM4633614_ECCITE_cDNA_counts.tsv.gz"),
64 | sep="\t",
65 | index_col=0,
66 | )
67 | return matrix
68 |
69 |
70 | def preprocess_papalexi_2021(download_path: str, n_top_genes: int) -> AnnData:
71 | """
72 | Preprocess expression data from Papalexi et al. 2021.
73 |
74 | Args:
75 | ----
76 | download_path: Path containing the downloaded Papalexi et al. 2021 data files.
77 | n_top_genes: Number of most variable genes to retain.
78 |
79 | Returns
80 | -------
81 | An AnnData object containing single-cell expression data. The layer
82 | "count" contains the count data for the most variable genes. The .X
83 | variable contains the normalized and log-transformed data for the most variable
84 | genes. A copy of data with all genes is stored in .raw.
85 | """
86 |
87 | df = read_papalexi_2021(download_path)
88 |
89 | # Switch dataframe from gene rows and cell columns to cell rows and gene columns
90 | df = df.transpose()
91 |
92 | metadata = pd.read_csv(
93 | os.path.join(download_path, "GSE153056_ECCITE_metadata.tsv.gz"),
94 | sep="\t",
95 | index_col=0,
96 | )
97 |
98 | # Note: By initializing the anndata object from a dataframe, variable names
99 | # are automatically stored in adata.var
100 | adata = AnnData(df)
101 | adata.obs = metadata
102 |
103 | # Protein measurements also collected as part of CITE-Seq
104 | protein_counts_df = pd.read_csv(
105 | os.path.join(download_path, "GSM4633615_ECCITE_ADT_counts.tsv.gz"),
106 | sep="\t",
107 | index_col=0,
108 | )
109 |
110 | # Switch dataframe from protein rows and cell columns to cell rows and protein
111 | # columns
112 | protein_counts_df = protein_counts_df.transpose()
113 |
114 | # Storing protein counts in an obsm field as expected by totalVI
115 | # (see https://docs.scvi-tools.org/en/stable/tutorials/notebooks/totalVI.html
116 | # for an example). Since `protein_counts_df` is annotated with protein names,
117 | # our obsm field will retain them as well.
118 |     adata.obsm[PROTEIN_EXPRESSION_KEY] = protein_counts_df
119 |
120 | adata.layers["count"] = adata.X.copy()
121 | sc.pp.normalize_total(adata)
122 | sc.pp.log1p(adata)
123 | adata.raw = adata
124 | sc.pp.highly_variable_genes(
125 | adata, flavor="seurat_v3", n_top_genes=n_top_genes, layer="count", subset=True
126 | )
127 | adata = adata[adata.layers["count"].sum(1) != 0] # Remove cells with all zeros.
128 | return adata
129 |
--------------------------------------------------------------------------------
/contrastive_vi/data/datasets/zheng_2017.py:
--------------------------------------------------------------------------------
1 | """
2 | Download, read, and preprocess Zheng et al. (2017) expression data.
3 |
4 | Single-cell expression data from Zheng et al. Massively parallel digital
5 | transcriptional profiling of single cells. Nature Communications (2017).
6 | """
7 | import os
8 | import shutil
9 |
10 | import numpy as np
11 | import pandas as pd
12 | import scanpy as sc
13 | from anndata import AnnData
14 | from scipy.io import mmread
15 |
16 | from contrastive_vi.data.utils import download_binary_file
17 |
18 |
19 | def download_zheng_2017(output_path: str) -> None:
20 | """
21 | Download Zheng et al. 2017 data from the hosting URLs.
22 |
23 | Args:
24 | ----
25 | output_path: Output path to store the downloaded and unzipped
26 | directories.
27 |
28 | Returns
29 | -------
30 | None. File directories are downloaded and unzipped in output_path.
31 | """
32 | host = "https://cf.10xgenomics.com/samples/cell-exp/1.1.0/"
33 | host_directories = [
34 | (
35 | "aml027_post_transplant/"
36 | "aml027_post_transplant_filtered_gene_bc_matrices.tar.gz"
37 | ),
38 | (
39 | "aml027_pre_transplant/"
40 | "aml027_pre_transplant_filtered_gene_bc_matrices.tar.gz"
41 | ),
42 | (
43 | "aml035_post_transplant/"
44 | "aml035_post_transplant_filtered_gene_bc_matrices.tar.gz"
45 | ),
46 | (
47 | "aml035_pre_transplant/"
48 | "aml035_pre_transplant_filtered_gene_bc_matrices.tar.gz"
49 | ),
50 | (
51 | "frozen_bmmc_healthy_donor1/"
52 | "frozen_bmmc_healthy_donor1_filtered_gene_bc_matrices.tar.gz"
53 | ),
54 | (
55 | "frozen_bmmc_healthy_donor2/"
56 | "frozen_bmmc_healthy_donor2_filtered_gene_bc_matrices.tar.gz"
57 | ),
58 | ]
59 | urls = [host + host_directory for host_directory in host_directories]
60 | output_filenames = [os.path.join(output_path, url.split("/")[-1]) for url in urls]
61 | for url, output_filename in zip(urls, output_filenames):
62 | download_binary_file(url, output_filename)
63 | output_dir = output_filename.replace(".tar.gz", "")
64 | shutil.unpack_archive(output_filename, output_dir)
65 |
66 |
67 | def read_zheng_2017(file_directory: str) -> pd.DataFrame:
68 | """
69 |     Read the expression data in a downloaded file directory.
70 |
71 | Args:
72 | ----
73 | file_directory: A downloaded and unzipped file directory.
74 |
75 | Returns
76 | -------
77 | A data frame containing single-cell gene expression count, with cell
78 | identification barcodes as column names and gene IDs as indices.
79 | """
80 | data = mmread(
81 | os.path.join(file_directory, "filtered_matrices_mex/hg19/matrix.mtx")
82 | ).toarray()
83 | genes = pd.read_table(
84 | os.path.join(file_directory, "filtered_matrices_mex/hg19/genes.tsv"),
85 | header=None,
86 | )
87 | barcodes = pd.read_table(
88 | os.path.join(file_directory, "filtered_matrices_mex/hg19/barcodes.tsv"),
89 | header=None,
90 | )
91 | return pd.DataFrame(
92 | data, index=genes.iloc[:, 0].values, columns=barcodes.iloc[:, 0].values
93 | )
94 |
95 |
96 | def preprocess_zheng_2017(download_path: str, n_top_genes: int) -> AnnData:
97 | """
98 | Preprocess expression data from Zheng et al. 2017.
99 |
100 | Args:
101 | ----
102 | download_path: Path containing the downloaded and unzipped file
103 | directories.
104 | n_top_genes: Number of most variable genes to retain.
105 |
106 | Returns
107 | -------
108 | An AnnData object containing single-cell expression data. The layer
109 | "count" contains the count data for the most variable genes. The X
110 | variable contains the total-count-normalized and log-transformed data
111 | for the most variable genes (a copy with all the genes is stored in
112 | .raw).
113 | """
114 | file_directory_dict = {
115 | "aml027_pre_transplant": ("aml027_pre_transplant_filtered_gene_bc_matrices"),
116 | "aml027_post_transplant": ("aml027_post_transplant_filtered_gene_bc_matrices"),
117 | "aml035_pre_transplant": ("aml035_pre_transplant_filtered_gene_bc_matrices"),
118 | "aml035_post_transplant": ("aml035_post_transplant_filtered_gene_bc_matrices"),
119 | "donor1_healthy": ("frozen_bmmc_healthy_donor1_filtered_gene_bc_matrices"),
120 | "donor2_healthy": ("frozen_bmmc_healthy_donor2_filtered_gene_bc_matrices"),
121 | }
122 | df_dict = {
123 | sample_id: read_zheng_2017(os.path.join(download_path, file_directory))
124 | for sample_id, file_directory in file_directory_dict.items()
125 | }
126 | gene_set_list = []
127 | for sample_id, df in df_dict.items():
128 | df = df.iloc[:, np.sum(df.values, axis=0) != 0]
129 | df = df.iloc[np.sum(df.values, axis=1) != 0, :]
130 | df = df.transpose()
131 | gene_set_list.append(set(df.columns))
132 | patient_id, condition = sample_id.split("_", 1)
133 | df["patient_id"] = patient_id
134 | df["condition"] = condition
135 | df_dict[sample_id] = df
136 | shared_genes = list(set.intersection(*gene_set_list))
137 | data_list = []
138 | meta_data_list = []
139 | for df in df_dict.values():
140 | data_list.append(df[shared_genes])
141 | meta_data_list.append(df[["patient_id", "condition"]])
142 | data = pd.concat(data_list)
143 | meta_data = pd.concat(meta_data_list)
144 | adata = AnnData(X=data.reset_index(drop=True), obs=meta_data.reset_index(drop=True))
145 | adata.layers["count"] = adata.X.copy()
146 | sc.pp.normalize_total(adata)
147 | sc.pp.log1p(adata)
148 | adata.raw = adata
149 | sc.pp.highly_variable_genes(
150 | adata,
151 | flavor="seurat_v3",
152 | n_top_genes=n_top_genes,
153 | layer="count",
154 | subset=True,
155 | )
156 | adata = adata[adata.layers["count"].sum(1) != 0] # Remove cells with all zeros.
157 | return adata
158 |
--------------------------------------------------------------------------------
/contrastive_vi/data/utils.py:
--------------------------------------------------------------------------------
1 | """Data preprocessing utilities."""
2 | import os
3 |
4 | import requests
5 | from anndata import AnnData
6 |
7 |
8 | def download_binary_file(file_url: str, output_path: str) -> None:
9 | """
10 | Download binary data file from a URL.
11 |
12 | Args:
13 | ----
14 | file_url: URL where the file is hosted.
15 | output_path: Output path for the downloaded file.
16 |
17 | Returns
18 | -------
19 | None.
20 | """
21 | request = requests.get(file_url)
22 | with open(output_path, "wb") as f:
23 | f.write(request.content)
24 |     print(f"Downloaded data from {file_url} to {output_path}")
25 |
26 |
27 | def save_preprocessed_adata(adata: AnnData, output_path: str) -> None:
28 | """
29 | Save given AnnData object with preprocessed data to disk using our dataset file
30 | naming convention.
31 |
32 | Args:
33 | ----
34 | adata: AnnData object containing expression count data as well as metadata.
35 | output_path: Path to save resulting file.
36 |
37 | Returns
38 | -------
39 | None. Provided AnnData object is saved to disk in a subdirectory called
40 | "preprocessed" in output_path.
41 | """
42 | preprocessed_directory = os.path.join(output_path, "preprocessed")
43 | os.makedirs(preprocessed_directory, exist_ok=True)
44 | n_genes = adata.shape[1]
45 | filename = os.path.join(
46 | preprocessed_directory,
47 | f"adata_top_{n_genes}_genes.h5ad",
48 | )
49 | adata.write_h5ad(filename=filename)
50 |
--------------------------------------------------------------------------------
/contrastive_vi/model/__init__.py:
--------------------------------------------------------------------------------
1 | """scvi-tools Model classes for contrastive-VI."""
2 | from .contrastive_vi import ContrastiveVIModel as ContrastiveVI
3 | from .total_contrastive_vi import TotalContrastiveVIModel as TotalContrastiveVI
4 |
5 | __all__ = ["ContrastiveVI", "TotalContrastiveVI"]
6 |
--------------------------------------------------------------------------------
/contrastive_vi/model/base/__init__.py:
--------------------------------------------------------------------------------
1 | """Reusable Model classes for inheritance."""
2 |
--------------------------------------------------------------------------------
/contrastive_vi/model/base/training_mixin.py:
--------------------------------------------------------------------------------
1 | """
2 | Mixin classes for pre-coded features.
3 | For more details on Mixin classes, see
4 | https://docs.scvi-tools.org/en/0.9.0/user_guide/notebooks/model_user_guide.html#Mixing-in-pre-coded-features
5 | """
6 |
7 | from typing import List, Optional, Union
8 |
9 | import numpy as np
10 | from scvi.train import TrainingPlan, TrainRunner
11 |
12 | from contrastive_vi.data.dataloaders.data_splitting import ContrastiveDataSplitter
13 |
14 |
15 | class ContrastiveTrainingMixin:
16 | """General methods for contrastive learning."""
17 |
18 | def train(
19 | self,
20 | background_indices: List[int],
21 | target_indices: List[int],
22 | max_epochs: Optional[int] = None,
23 | use_gpu: Optional[Union[str, int, bool]] = None,
24 | train_size: float = 0.9,
25 | validation_size: Optional[float] = None,
26 | batch_size: int = 128,
27 | early_stopping: bool = False,
28 | plan_kwargs: Optional[dict] = None,
29 | **trainer_kwargs,
30 | ) -> None:
31 | """
32 | Train a contrastive model.
33 |
34 | Args:
35 | ----
36 | background_indices: Indices for background samples in `adata`.
37 | target_indices: Indices for target samples in `adata`.
38 | max_epochs: Number of passes through the dataset. If `None`, default to
39 | `np.min([round((20000 / n_cells) * 400), 400])`.
40 | use_gpu: Use default GPU if available (if `None` or `True`), or index of
41 | GPU to use (if `int`), or name of GPU (if `str`, e.g., `"cuda:0"`),
42 | or use CPU (if `False`).
43 | train_size: Size of training set in the range [0.0, 1.0].
44 | validation_size: Size of the validation set. If `None`, default to
45 | `1 - train_size`. If `train_size + validation_size < 1`, the remaining
46 | cells belong to the test set.
47 | batch_size: Mini-batch size to use during training.
48 | early_stopping: Perform early stopping. Additional arguments can be passed
49 | in `**kwargs`. See :class:`~scvi.train.Trainer` for further options.
50 | plan_kwargs: Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword
51 | arguments passed to `train()` will overwrite values present
52 | in `plan_kwargs`, when appropriate.
53 | **trainer_kwargs: Other keyword args for :class:`~scvi.train.Trainer`.
54 |
55 | Returns
56 | -------
57 | None. The model is trained.
58 | """
59 | if max_epochs is None:
60 | n_cells = self.adata.n_obs
61 | max_epochs = np.min([round((20000 / n_cells) * 400), 400])
62 |
63 | plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict()
64 |
65 | data_splitter = ContrastiveDataSplitter(
66 | self.adata_manager,
67 | background_indices,
68 | target_indices,
69 | train_size=train_size,
70 | validation_size=validation_size,
71 | batch_size=batch_size,
72 | use_gpu=use_gpu,
73 | )
74 | training_plan = TrainingPlan(self.module, **plan_kwargs)
75 |
76 | es = "early_stopping"
77 | trainer_kwargs[es] = (
78 | early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es]
79 | )
80 | runner = TrainRunner(
81 | self,
82 | training_plan=training_plan,
83 | data_splitter=data_splitter,
84 | max_epochs=max_epochs,
85 | use_gpu=use_gpu,
86 | **trainer_kwargs,
87 | )
88 | return runner()
89 |
--------------------------------------------------------------------------------
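
A hedged usage sketch for `ContrastiveTrainingMixin.train` as defined above. It assumes `model` is an instance of a class that mixes in `ContrastiveTrainingMixin` (such as `ContrastiveVIModel` below) and that `adata` carries a hypothetical `condition` column whose labels split cells into background and target groups; the column name and label values are illustrative, not part of the repository.

    import numpy as np

    # Hypothetical obs column; "control"/"treated" labels are illustrative.
    background_indices = np.where(adata.obs["condition"] == "control")[0].tolist()
    target_indices = np.where(adata.obs["condition"] == "treated")[0].tolist()

    model.train(
        background_indices=background_indices,
        target_indices=target_indices,
        use_gpu=True,              # or False, a GPU index, or a name like "cuda:0"
        train_size=0.8,
        validation_size=0.1,       # remaining 10% of cells form the test set
        batch_size=128,
        early_stopping=True,       # forwarded to scvi.train.Trainer
        plan_kwargs={"lr": 1e-3},  # forwarded to scvi.train.TrainingPlan
    )

When `max_epochs` is omitted, the heuristic above caps training at `min(round((20000 / n_cells) * 400), 400)` epochs, so small datasets train for the full 400 epochs while large ones stop earlier.
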
/contrastive_vi/model/contrastive_vi.py:
--------------------------------------------------------------------------------
1 | """Model class for contrastive-VI for single cell expression data."""
2 |
3 | import logging
4 | import warnings
5 | from functools import partial
6 | from typing import Dict, Iterable, List, Optional, Sequence, Union
7 |
8 | import numpy as np
9 | import pandas as pd
10 | import torch
11 | from anndata import AnnData
12 | from scvi import REGISTRY_KEYS
13 | from scvi.data import AnnDataManager
14 | from scvi.data.fields import (
15 | CategoricalJointObsField,
16 | CategoricalObsField,
17 | LayerField,
18 | NumericalJointObsField,
19 | NumericalObsField,
20 | )
21 | from scvi.dataloaders import AnnDataLoader
22 | from scvi.model._utils import (
23 | _get_batch_code_from_category,
24 | _init_library_size,
25 | scrna_raw_counts_properties,
26 | )
27 | from scvi.model.base import BaseModelClass
28 | from scvi.model.base._utils import _de_core
29 | from scvi.utils import setup_anndata_dsp
30 |
31 | from contrastive_vi.model.base.training_mixin import ContrastiveTrainingMixin
32 | from contrastive_vi.module.contrastive_vi import ContrastiveVIModule
33 |
34 | logger = logging.getLogger(__name__)
35 | Number = Union[int, float]
36 |
37 |
38 | class ContrastiveVIModel(ContrastiveTrainingMixin, BaseModelClass):
39 | """
40 | Model class for contrastive-VI.
41 | Args:
42 | ----
43 | adata: AnnData object that has been registered via
44 | `ContrastiveVIModel.setup_anndata`.
45 | n_batch: Number of batches. If 0, no batch effect correction is performed.
46 | n_hidden: Number of nodes per hidden layer.
47 | n_latent: Dimensionality of the latent space.
48 | n_layers: Number of hidden layers used for encoder and decoder NNs.
49 | dropout_rate: Dropout rate for neural networks.
50 | use_observed_lib_size: Use observed library size for RNA as scaling factor in
51 | mean of conditional distribution.
52 | disentangle: Whether to disentangle the salient and background latent variables.
53 | use_mmd: Whether to use the maximum mean discrepancy loss to force background
54 | latent variables to have the same distribution for background and target
55 | data.
56 | mmd_weight: Weight used for the MMD loss.
57 | gammas: Gamma parameters for the MMD loss.
58 | """
59 |
60 | def __init__(
61 | self,
62 | adata: AnnData,
63 | n_batch: int = 0,
64 | n_hidden: int = 128,
65 | n_background_latent: int = 10,
66 | n_salient_latent: int = 10,
67 | n_layers: int = 1,
68 | dropout_rate: float = 0.1,
69 | use_observed_lib_size: bool = True,
70 | wasserstein_penalty: float = 0,
71 | ) -> None:
72 | super(ContrastiveVIModel, self).__init__(adata)
73 |
74 | n_cats_per_cov = (
75 | self.adata_manager.get_state_registry(
76 | REGISTRY_KEYS.CAT_COVS_KEY
77 | ).n_cats_per_key
78 | if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry
79 | else None
80 | )
81 | n_batch = self.summary_stats.n_batch
82 | use_size_factor_key = (
83 | REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry
84 | )
85 | library_log_means, library_log_vars = None, None
86 | if not use_size_factor_key:
87 | library_log_means, library_log_vars = _init_library_size(
88 | self.adata_manager, n_batch
89 | )
90 |
91 | self.module = ContrastiveVIModule(
92 | n_input=self.summary_stats["n_vars"],
93 | n_batch=n_batch,
94 | n_hidden=n_hidden,
95 | n_background_latent=n_background_latent,
96 | n_salient_latent=n_salient_latent,
97 | n_layers=n_layers,
98 | dropout_rate=dropout_rate,
99 | use_observed_lib_size=use_observed_lib_size,
100 | library_log_means=library_log_means,
101 | library_log_vars=library_log_vars,
102 | wasserstein_penalty=wasserstein_penalty,
103 | )
104 | self._model_summary_string = "Contrastive-VI."
105 | # Necessary line to get params to be used for saving and loading.
106 | self.init_params_ = self._get_init_params(locals())
107 | logger.info("The model has been initialized")
108 |
109 | @classmethod
110 | @setup_anndata_dsp.dedent
111 | def setup_anndata(
112 | cls,
113 | adata: AnnData,
114 | layer: Optional[str] = None,
115 | batch_key: Optional[str] = None,
116 | labels_key: Optional[str] = None,
117 | size_factor_key: Optional[str] = None,
118 | categorical_covariate_keys: Optional[List[str]] = None,
119 | continuous_covariate_keys: Optional[List[str]] = None,
120 | **kwargs,
121 | ) -> Optional[AnnData]:
122 | """
123 | Set up AnnData instance for contrastive-VI model.
124 |
125 | Args:
126 | ----
127 | adata: AnnData object containing raw counts. Rows represent cells, columns
128 | represent features.
129 | layer: If not None, uses this as the key in adata.layers for raw count data.
130 | batch_key: Key in `adata.obs` for batch information. Categories will
131 | automatically be converted into integer categories and saved to
132 | `adata.obs["_scvi_batch"]`. If None, assign the same batch to all the
133 | data.
134 | labels_key: Key in `adata.obs` for label information. Categories will
135 | automatically be converted into integer categories and saved to
136 | `adata.obs["_scvi_labels"]`. If None, assign the same label to all the
137 | data.
138 | size_factor_key: Key in `adata.obs` for size factor information. Instead of
139 | using library size as a size factor, the provided size factor column
140 | will be used as offset in the mean of the likelihood. Assumed to be on
141 | linear scale.
142 | categorical_covariate_keys: Keys in `adata.obs` corresponding to categorical
143 | data. Used in some models.
144 | continuous_covariate_keys: Keys in `adata.obs` corresponding to continuous
145 | data. Used in some models.
146 |
147 |         Returns
148 |         -------
149 |         None. The provided `adata` is registered in place for use with the
150 |         contrastive-VI model.
151 |         """
152 | setup_method_args = cls._get_setup_method_args(**locals())
153 | anndata_fields = [
154 | LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
155 | CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
156 | CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
157 | NumericalObsField(
158 | REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False
159 | ),
160 | CategoricalJointObsField(
161 | REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys
162 | ),
163 | NumericalJointObsField(
164 | REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys
165 | ),
166 | ]
167 | adata_manager = AnnDataManager(
168 | fields=anndata_fields, setup_method_args=setup_method_args
169 | )
170 | adata_manager.register_fields(adata, **kwargs)
171 | cls.register_manager(adata_manager)
172 |
173 | @torch.no_grad()
174 | def get_latent_representation(
175 | self,
176 | adata: Optional[AnnData] = None,
177 | indices: Optional[Sequence[int]] = None,
178 | give_mean: bool = True,
179 | batch_size: Optional[int] = None,
180 | representation_kind: str = "salient",
181 | ) -> np.ndarray:
182 | """
183 | Return the background or salient latent representation for each cell.
184 |
185 | Args:
186 | ----
187 | adata: AnnData object with equivalent structure to initial AnnData. If `None`,
188 | defaults to the AnnData object used to initialize the model.
189 | indices: Indices of cells in adata to use. If `None`, all cells are used.
190 | give_mean: Give mean of distribution or sample from it.
191 | batch_size: Mini-batch size for data loading into model. Defaults to
192 | `scvi.settings.batch_size`.
193 | representation_kind: Either "background" or "salient" for the corresponding
194 | representation kind.
195 |
196 | Returns
197 | -------
198 | A numpy array with shape `(n_cells, n_latent)`.
199 | """
200 | available_representation_kinds = ["background", "salient"]
201 | assert representation_kind in available_representation_kinds, (
202 | f"representation_kind = {representation_kind} is not one of"
203 | f" {available_representation_kinds}"
204 | )
205 |
206 | adata = self._validate_anndata(adata)
207 | data_loader = self._make_data_loader(
208 | adata=adata,
209 | indices=indices,
210 | batch_size=batch_size,
211 | shuffle=False,
212 | data_loader_class=AnnDataLoader,
213 | )
214 | latent = []
215 | for tensors in data_loader:
216 | x = tensors[REGISTRY_KEYS.X_KEY]
217 | batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
218 | outputs = self.module._generic_inference(
219 | x=x, batch_index=batch_index, n_samples=1
220 | )
221 |
222 | if representation_kind == "background":
223 | latent_m = outputs["qz_m"]
224 | latent_sample = outputs["z"]
225 | else:
226 | latent_m = outputs["qs_m"]
227 | latent_sample = outputs["s"]
228 |
229 | if give_mean:
230 | latent_sample = latent_m
231 |
232 | latent += [latent_sample.detach().cpu()]
233 | return torch.cat(latent).numpy()
234 |
235 | def get_normalized_expression_fold_change(
236 | self,
237 | adata: Optional[AnnData] = None,
238 | indices: Optional[Sequence[int]] = None,
239 | transform_batch: Optional[Sequence[Union[Number, str]]] = None,
240 | gene_list: Optional[Sequence[str]] = None,
241 | library_size: Union[float, str] = 1.0,
242 | n_samples: int = 1,
243 | batch_size: Optional[int] = None,
244 | ) -> np.ndarray:
245 | """
246 |         Return the fold change of salient over background normalized expression.
247 |
248 | Args:
249 | ----
250 | adata: AnnData object with equivalent structure to initial AnnData. If `None`,
251 | defaults to the AnnData object used to initialize the model.
252 | indices: Indices of cells in adata to use. If `None`, all cells are used.
253 | transform_batch: Batch to condition on. If transform_batch is:
254 | - None, then real observed batch is used.
255 | - int, then batch transform_batch is used.
256 | gene_list: Return frequencies of expression for a subset of genes. This can
257 | save memory when working with large datasets and few genes are of interest.
258 | library_size: Scale the expression frequencies to a common library size. This
259 | allows gene expression levels to be interpreted on a common scale of
260 | relevant magnitude. If set to `"latent"`, use the latent library size.
261 | n_samples: Number of posterior samples to use for estimation.
262 | batch_size: Mini-batch size for data loading into model. Defaults to
263 | `scvi.settings.batch_size`.
264 |
265 | Returns
266 | -------
267 | If `n_samples` > 1, then the shape is `(samples, cells, genes)`. Otherwise,
268 | shape is `(cells, genes)`. Each element is fold change of salient normalized
269 | expression divided by background normalized expression.
270 | """
271 | exprs = self.get_normalized_expression(
272 | adata=adata,
273 | indices=indices,
274 | transform_batch=transform_batch,
275 | gene_list=gene_list,
276 | library_size=library_size,
277 | n_samples=n_samples,
278 | batch_size=batch_size,
279 | return_mean=False,
280 | return_numpy=True,
281 | )
282 | salient_exprs = exprs["salient"]
283 | background_exprs = exprs["background"]
284 | fold_change = salient_exprs / background_exprs
285 | return fold_change
286 |
287 | @torch.no_grad()
288 | def get_normalized_expression(
289 | self,
290 | adata: Optional[AnnData] = None,
291 | indices: Optional[Sequence[int]] = None,
292 | transform_batch: Optional[Sequence[Union[Number, str]]] = None,
293 | gene_list: Optional[Sequence[str]] = None,
294 | library_size: Union[float, str] = 1.0,
295 | n_samples: int = 1,
296 | n_samples_overall: Optional[int] = None,
297 | batch_size: Optional[int] = None,
298 | return_mean: bool = True,
299 | return_numpy: Optional[bool] = None,
300 | ) -> Dict[str, Union[np.ndarray, pd.DataFrame]]:
301 | """
302 | Return the normalized (decoded) gene expression.
303 |
304 | Args:
305 | ----
306 | adata: AnnData object with equivalent structure to initial AnnData. If `None`,
307 | defaults to the AnnData object used to initialize the model.
308 | indices: Indices of cells in adata to use. If `None`, all cells are used.
309 | transform_batch: Batch to condition on. If transform_batch is:
310 | - None, then real observed batch is used.
311 | - int, then batch transform_batch is used.
312 | gene_list: Return frequencies of expression for a subset of genes. This can
313 | save memory when working with large datasets and few genes are of interest.
314 | library_size: Scale the expression frequencies to a common library size. This
315 | allows gene expression levels to be interpreted on a common scale of
316 | relevant magnitude. If set to `"latent"`, use the latent library size.
317 | n_samples: Number of posterior samples to use for estimation.
318 | n_samples_overall: The number of random samples in `adata` to use.
319 | batch_size: Mini-batch size for data loading into model. Defaults to
320 | `scvi.settings.batch_size`.
321 | return_mean: Whether to return the mean of the samples.
322 | return_numpy: Return a `numpy.ndarray` instead of a `pandas.DataFrame`.
323 | DataFrame includes gene names as columns. If either `n_samples=1` or
324 | `return_mean=True`, defaults to `False`. Otherwise, it defaults to `True`.
325 |
326 | Returns
327 | -------
328 |         A dictionary with keys "background" and "salient", with values as follows.
329 | If `n_samples` > 1 and `return_mean` is `False`, then the shape is
330 | `(samples, cells, genes)`. Otherwise, shape is `(cells, genes)`. In this
331 | case, return type is `pandas.DataFrame` unless `return_numpy` is `True`.
332 | """
333 | adata = self._validate_anndata(adata)
334 | if indices is None:
335 | indices = np.arange(adata.n_obs)
336 | if n_samples_overall is not None:
337 | indices = np.random.choice(indices, n_samples_overall)
338 | data_loader = self._make_data_loader(
339 | adata=adata,
340 | indices=indices,
341 | batch_size=batch_size,
342 | shuffle=False,
343 | data_loader_class=AnnDataLoader,
344 | )
345 |
346 | transform_batch = _get_batch_code_from_category(
347 | self.get_anndata_manager(adata, required=True), transform_batch
348 | )
349 |
350 | if gene_list is None:
351 | gene_mask = slice(None)
352 | else:
353 | all_genes = adata.var_names
354 | gene_mask = [True if gene in gene_list else False for gene in all_genes]
355 |
356 | if n_samples > 1 and return_mean is False:
357 | if return_numpy is False:
358 | warnings.warn(
359 | "return_numpy must be True if n_samples > 1 and"
360 | " return_mean is False, returning np.ndarray"
361 | )
362 | return_numpy = True
363 | if library_size == "latent":
364 | generative_output_key = "px_rate"
365 | scaling = 1
366 | else:
367 | generative_output_key = "px_scale"
368 | scaling = library_size
369 |
370 | background_exprs = []
371 | salient_exprs = []
372 | for tensors in data_loader:
373 | x = tensors[REGISTRY_KEYS.X_KEY]
374 | batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
375 | background_per_batch_exprs = []
376 | salient_per_batch_exprs = []
377 | for batch in transform_batch:
378 | if batch is not None:
379 | batch_index = torch.ones_like(batch_index) * batch
380 | inference_outputs = self.module._generic_inference(
381 | x=x, batch_index=batch_index, n_samples=n_samples
382 | )
383 | z = inference_outputs["z"]
384 | s = inference_outputs["s"]
385 | library = inference_outputs["library"]
386 | background_generative_outputs = self.module._generic_generative(
387 | z=z, s=torch.zeros_like(s), library=library, batch_index=batch_index
388 | )
389 | salient_generative_outputs = self.module._generic_generative(
390 | z=z, s=s, library=library, batch_index=batch_index
391 | )
392 | background_outputs = self._preprocess_normalized_expression(
393 | background_generative_outputs,
394 | generative_output_key,
395 | gene_mask,
396 | scaling,
397 | )
398 | background_per_batch_exprs.append(background_outputs)
399 | salient_outputs = self._preprocess_normalized_expression(
400 | salient_generative_outputs,
401 | generative_output_key,
402 | gene_mask,
403 | scaling,
404 | )
405 | salient_per_batch_exprs.append(salient_outputs)
406 |
407 | background_per_batch_exprs = np.stack(
408 | background_per_batch_exprs
409 | ) # Shape is (len(transform_batch) x batch_size x n_var).
410 | salient_per_batch_exprs = np.stack(salient_per_batch_exprs)
411 | background_exprs += [background_per_batch_exprs.mean(0)]
412 | salient_exprs += [salient_per_batch_exprs.mean(0)]
413 |
414 | if n_samples > 1:
415 | # The -2 axis correspond to cells.
416 | background_exprs = np.concatenate(background_exprs, axis=-2)
417 | salient_exprs = np.concatenate(salient_exprs, axis=-2)
418 | else:
419 | background_exprs = np.concatenate(background_exprs, axis=0)
420 | salient_exprs = np.concatenate(salient_exprs, axis=0)
421 | if n_samples > 1 and return_mean:
422 | background_exprs = background_exprs.mean(0)
423 | salient_exprs = salient_exprs.mean(0)
424 |
425 | if return_numpy is None or return_numpy is False:
426 | genes = adata.var_names[gene_mask]
427 | samples = adata.obs_names[indices]
428 | background_exprs = pd.DataFrame(
429 | background_exprs, columns=genes, index=samples
430 | )
431 | salient_exprs = pd.DataFrame(salient_exprs, columns=genes, index=samples)
432 | return {"background": background_exprs, "salient": salient_exprs}
433 |
434 | @torch.no_grad()
435 | def get_salient_normalized_expression(
436 | self,
437 | adata: Optional[AnnData] = None,
438 | indices: Optional[Sequence[int]] = None,
439 | transform_batch: Optional[Sequence[Union[Number, str]]] = None,
440 | gene_list: Optional[Sequence[str]] = None,
441 | library_size: Union[float, str] = 1.0,
442 | n_samples: int = 1,
443 | n_samples_overall: Optional[int] = None,
444 | batch_size: Optional[int] = None,
445 | return_mean: bool = True,
446 | return_numpy: Optional[bool] = None,
447 | ) -> Union[np.ndarray, pd.DataFrame]:
448 | """
449 |         Return the salient normalized (decoded) gene expression.
450 |
451 | Gene expressions are decoded from both the background and salient latent space.
452 |
453 | Args:
454 | ----
455 | adata: AnnData object with equivalent structure to initial AnnData. If `None`,
456 | defaults to the AnnData object used to initialize the model.
457 | indices: Indices of cells in adata to use. If `None`, all cells are used.
458 | transform_batch: Batch to condition on. If transform_batch is:
459 | - None, then real observed batch is used.
460 | - int, then batch transform_batch is used.
461 | gene_list: Return frequencies of expression for a subset of genes. This can
462 | save memory when working with large datasets and few genes are of interest.
463 | library_size: Scale the expression frequencies to a common library size. This
464 | allows gene expression levels to be interpreted on a common scale of
465 | relevant magnitude. If set to `"latent"`, use the latent library size.
466 | n_samples: Number of posterior samples to use for estimation.
467 | n_samples_overall: The number of random samples in `adata` to use.
468 | batch_size: Mini-batch size for data loading into model. Defaults to
469 | `scvi.settings.batch_size`.
470 | return_mean: Whether to return the mean of the samples.
471 | return_numpy: Return a `numpy.ndarray` instead of a `pandas.DataFrame`.
472 | DataFrame includes gene names as columns. If either `n_samples=1` or
473 | `return_mean=True`, defaults to `False`. Otherwise, it defaults to `True`.
474 |
475 | Returns
476 | -------
477 | If `n_samples` > 1 and `return_mean` is `False`, then the shape is
478 | `(samples, cells, genes)`. Otherwise, shape is `(cells, genes)`. In this
479 | case, return type is `pandas.DataFrame` unless `return_numpy` is `True`.
480 | """
481 | exprs = self.get_normalized_expression(
482 | adata=adata,
483 | indices=indices,
484 | transform_batch=transform_batch,
485 | gene_list=gene_list,
486 | library_size=library_size,
487 | n_samples=n_samples,
488 | n_samples_overall=n_samples_overall,
489 | batch_size=batch_size,
490 | return_mean=return_mean,
491 | return_numpy=return_numpy,
492 | )
493 | return exprs["salient"]
494 |
495 | @torch.no_grad()
496 | def get_specific_normalized_expression(
497 | self,
498 | adata: Optional[AnnData] = None,
499 | indices: Optional[Sequence[int]] = None,
500 | transform_batch: Optional[Sequence[Union[Number, str]]] = None,
501 | gene_list: Optional[Sequence[str]] = None,
502 | library_size: Union[float, str] = 1,
503 | n_samples: int = 1,
504 | n_samples_overall: Optional[int] = None,
505 | batch_size: Optional[int] = None,
506 | return_mean: bool = True,
507 | return_numpy: Optional[bool] = None,
508 | expression_type: Optional[str] = None,
509 | indices_to_return_salient: Optional[Sequence[int]] = None,
510 | ):
511 | """
512 | Return the normalized (decoded) gene expression.
513 |
514 | Gene expressions are decoded from either the background or salient latent space.
515 |         Exactly one of `expression_type` or `indices_to_return_salient` must be
516 |         provided.
517 |
518 | Args:
519 | ----
520 | adata: AnnData object with equivalent structure to initial AnnData. If `None`,
521 | defaults to the AnnData object used to initialize the model.
522 | indices: Indices of cells in adata to use. If `None`, all cells are used.
523 | transform_batch: Batch to condition on. If transform_batch is:
524 | - None, then real observed batch is used.
525 | - int, then batch transform_batch is used.
526 | gene_list: Return frequencies of expression for a subset of genes. This can
527 | save memory when working with large datasets and few genes are of interest.
528 | library_size: Scale the expression frequencies to a common library size. This
529 | allows gene expression levels to be interpreted on a common scale of
530 | relevant magnitude. If set to `"latent"`, use the latent library size.
531 | n_samples: Number of posterior samples to use for estimation.
532 | n_samples_overall: The number of random samples in `adata` to use.
533 | batch_size: Mini-batch size for data loading into model. Defaults to
534 | `scvi.settings.batch_size`.
535 | return_mean: Whether to return the mean of the samples.
536 | return_numpy: Return a `numpy.ndarray` instead of a `pandas.DataFrame`.
537 | DataFrame includes gene names as columns. If either `n_samples=1` or
538 | `return_mean=True`, defaults to `False`. Otherwise, it defaults to `True`.
539 | expression_type: One of {"salient", "background"} to specify the type of
540 | normalized expression to return.
541 | indices_to_return_salient: If `indices` is a subset of
542 | `indices_to_return_salient`, normalized expressions derived from background
543 | and salient latent embeddings are returned. If `indices` is not `None` and
544 | is not a subset of `indices_to_return_salient`, normalized expressions
545 | derived only from background latent embeddings are returned.
546 |
547 | Returns
548 | -------
549 | If `n_samples` > 1 and `return_mean` is `False`, then the shape is
550 | `(samples, cells, genes)`. Otherwise, shape is `(cells, genes)`. In this
551 | case, return type is `pandas.DataFrame` unless `return_numpy` is `True`.
552 | """
553 | is_expression_type_none = expression_type is None
554 | is_indices_to_return_salient_none = indices_to_return_salient is None
555 | if is_expression_type_none and is_indices_to_return_salient_none:
556 | raise ValueError(
557 | "Both expression_type and indices_to_return_salient are None! "
558 | "Exactly one of them needs to be supplied with an input argument."
559 | )
560 | elif (not is_expression_type_none) and (not is_indices_to_return_salient_none):
561 | raise ValueError(
562 | "Both expression_type and indices_to_return_salient have an input "
563 | "argument! Exactly one of them needs to be supplied with an input "
564 | "argument."
565 | )
566 | else:
567 | exprs = self.get_normalized_expression(
568 | adata=adata,
569 | indices=indices,
570 | transform_batch=transform_batch,
571 | gene_list=gene_list,
572 | library_size=library_size,
573 | n_samples=n_samples,
574 | n_samples_overall=n_samples_overall,
575 | batch_size=batch_size,
576 | return_mean=return_mean,
577 | return_numpy=return_numpy,
578 | )
579 | if not is_expression_type_none:
580 | return exprs[expression_type]
581 | else:
582 | if indices is None:
583 | indices = np.arange(adata.n_obs)
584 | if set(indices).issubset(set(indices_to_return_salient)):
585 | return exprs["salient"]
586 | else:
587 | return exprs["background"]
588 |
589 | def differential_expression(
590 | self,
591 | adata: Optional[AnnData] = None,
592 | groupby: Optional[str] = None,
593 | group1: Optional[Iterable[str]] = None,
594 | group2: Optional[str] = None,
595 | idx1: Optional[Union[Sequence[int], Sequence[bool], str]] = None,
596 | idx2: Optional[Union[Sequence[int], Sequence[bool], str]] = None,
597 | mode: str = "change",
598 | delta: float = 0.25,
599 | batch_size: Optional[int] = None,
600 | all_stats: bool = True,
601 | batch_correction: bool = False,
602 | batchid1: Optional[Iterable[str]] = None,
603 | batchid2: Optional[Iterable[str]] = None,
604 | fdr_target: float = 0.05,
605 | silent: bool = False,
606 | target_idx: Optional[Sequence[int]] = None,
607 | n_samples: int = 1,
608 | **kwargs,
609 | ) -> pd.DataFrame:
610 | r"""
611 | Perform differential expression analysis.
612 |
613 | Args:
614 | ----
615 | adata: AnnData object with equivalent structure to initial AnnData. If `None`,
616 | defaults to the AnnData object used to initialize the model.
617 | groupby: The key of the observations grouping to consider.
618 | group1: Subset of groups, e.g. ["g1", "g2", "g3"], to which comparison shall be
619 | restricted, or all groups in `groupby` (default).
620 | group2: If `None`, compare each group in `group1` to the union of the rest of
621 | the groups in `groupby`. If a group identifier, compare with respect to this
622 | group.
623 | idx1: `idx1` and `idx2` can be used as an alternative to the AnnData keys.
624 | Custom identifier for `group1` that can be of three sorts:
625 | (1) a boolean mask, (2) indices, or (3) a string. If it is a string, then
626 | it will query indices that verifies conditions on adata.obs, as described
627 | in `pandas.DataFrame.query()`. If `idx1` is not `None`, this option
628 | overrides `group1` and `group2`.
629 | idx2: Custom identifier for `group2` that has the same properties as `idx1`.
630 | By default, includes all cells not specified in `idx1`.
631 | mode: Method for differential expression. See
632 | https://docs.scvi-tools.org/en/0.14.1/user_guide/background/differential_expression.html
633 | for more details.
634 |             delta: Effect size threshold for the "change" mode. Genes whose estimated
635 |                 log fold change falls outside the interval [-delta, delta] are
636 |                 considered differentially expressed.
637 | batch_size: Mini-batch size for data loading into model. Defaults to
638 | scvi.settings.batch_size.
639 | all_stats: Concatenate count statistics (e.g., mean expression group 1) to DE
640 | results.
641 | batch_correction: Whether to correct for batch effects in DE inference.
642 | batchid1: Subset of categories from `batch_key` registered in `setup_anndata`,
643 | e.g. ["batch1", "batch2", "batch3"], for `group1`. Only used if
644 | `batch_correction` is `True`, and by default all categories are used.
645 | batchid2: Same as `batchid1` for `group2`. `batchid2` must either have null
646 | intersection with `batchid1`, or be exactly equal to `batchid1`. When the
647 | two sets are exactly equal, cells are compared by decoding on the same
648 | batch. When sets have null intersection, cells from `group1` and `group2`
649 | are decoded on each group in `group1` and `group2`, respectively.
650 | fdr_target: Tag features as DE based on posterior expected false discovery rate.
651 | silent: If `True`, disables the progress bar. Default: `False`.
652 | target_idx: If not `None`, a boolean or integer identifier should be used for
653 | cells in the contrastive target group. Normalized expression values derived
654 | from both salient and background latent embeddings are used when
655 | {group1, group2} is a subset of the target group, otherwise background
656 | normalized expression values are used.
657 |             n_samples: Number of posterior samples to use for estimation.
658 | **kwargs: Keyword args for
659 | `scvi.model.base.DifferentialComputation.get_bayes_factors`.
660 |
661 | Returns
662 | -------
663 | Differential expression DataFrame.
664 | """
665 | adata = self._validate_anndata(adata)
666 | col_names = adata.var_names
667 |
668 | if target_idx is not None:
669 | target_idx = np.array(target_idx)
670 | if target_idx.dtype is np.dtype("bool"):
671 | assert (
672 | len(target_idx) == adata.n_obs
673 | ), "target_idx mask must be the same length as adata!"
674 | target_idx = np.arange(adata.n_obs)[target_idx]
675 | model_fn = partial(
676 | self.get_specific_normalized_expression,
677 | return_numpy=True,
678 | n_samples=n_samples,
679 | batch_size=batch_size,
680 | expression_type=None,
681 | indices_to_return_salient=target_idx,
682 | )
683 | else:
684 | model_fn = partial(
685 | self.get_specific_normalized_expression,
686 | return_numpy=True,
687 | n_samples=n_samples,
688 | batch_size=batch_size,
689 | expression_type="salient",
690 | indices_to_return_salient=None,
691 | )
692 |
693 | result = _de_core(
694 | self.get_anndata_manager(adata, required=True),
695 | model_fn,
696 | groupby=groupby,
697 | group1=group1,
698 | group2=group2,
699 | idx1=idx1,
700 | idx2=idx2,
701 | all_stats=all_stats,
702 | all_stats_fn=scrna_raw_counts_properties,
703 | col_names=col_names,
704 | mode=mode,
705 | batchid1=batchid1,
706 | batchid2=batchid2,
707 | delta=delta,
708 | batch_correction=batch_correction,
709 | fdr=fdr_target,
710 | silent=silent,
711 | **kwargs,
712 | )
713 | return result
714 |
715 | @staticmethod
716 | @torch.no_grad()
717 | def _preprocess_normalized_expression(
718 | generative_outputs: Dict[str, torch.Tensor],
719 | generative_output_key: str,
720 | gene_mask: Union[list, slice],
721 | scaling: float,
722 | ) -> np.ndarray:
723 | output = generative_outputs[generative_output_key]
724 | output = output[..., gene_mask]
725 | output *= scaling
726 | output = output.cpu().numpy()
727 | return output
728 |
729 | @torch.no_grad()
730 | def get_latent_library_size(
731 | self,
732 | adata: Optional[AnnData] = None,
733 | indices: Optional[Sequence[int]] = None,
734 | give_mean: bool = True,
735 | batch_size: Optional[int] = None,
736 | ) -> np.ndarray:
737 | r"""
738 | Returns the latent library size for each cell.
739 | This is denoted as :math:`\ell_n` in the scVI paper.
740 | Parameters
741 | ----------
742 | adata
743 | AnnData object with equivalent structure to initial AnnData. If `None`,
744 | defaults to the AnnData object used to initialize the model.
745 | indices
746 | Indices of cells in adata to use. If `None`, all cells are used.
747 | give_mean
748 | Return the mean or a sample from the posterior distribution.
749 | batch_size
750 | Minibatch size for data loading into model. Defaults to
751 | `scvi.settings.batch_size`.
752 | """
753 | self._check_if_trained(warn=False)
754 |
755 | adata = self._validate_anndata(adata)
756 | scdl = self._make_data_loader(
757 | adata=adata, indices=indices, batch_size=batch_size
758 | )
759 | libraries = []
760 | for tensors in scdl:
761 | x = tensors[REGISTRY_KEYS.X_KEY]
762 | batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
763 | outputs = self.module._generic_inference(x=x, batch_index=batch_index)
764 |
765 | library = outputs["library"]
766 |             if not give_mean:
767 |                 library = torch.exp(library)
768 |             else:
769 |                 ql_m, ql_v = outputs["ql_m"], outputs["ql_v"]
770 |                 if ql_m is None or ql_v is None:
771 |                     raise RuntimeError(
772 |                         "The module for this model does not compute the posterior "
773 |                         "distribution for the library size. Set `give_mean` to False "
774 |                         "to use the observed library size instead."
775 |                     )
776 |                 # `ql_v` is a variance, so take its square root for the LogNormal scale.
777 |                 library = torch.distributions.LogNormal(ql_m, ql_v.sqrt()).mean
778 | libraries += [library.cpu()]
779 | return torch.cat(libraries).numpy()
780 |
--------------------------------------------------------------------------------
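
An end-to-end sketch of the `ContrastiveVIModel` API defined above, from data registration through training to downstream queries. The AnnData object `adata`, its `batch` column, and the `condition` labels are assumptions for illustration, not part of the repository.

    import numpy as np
    from contrastive_vi.model import ContrastiveVI

    # Register raw counts plus batch annotations with the model class.
    ContrastiveVI.setup_anndata(adata, batch_key="batch")
    model = ContrastiveVI(adata, n_background_latent=10, n_salient_latent=10)

    # Background cells (e.g., controls) vs. target cells (e.g., perturbed).
    background_indices = np.where(adata.obs["condition"] == "control")[0].tolist()
    target_indices = np.where(adata.obs["condition"] == "treated")[0].tolist()
    model.train(
        background_indices=background_indices, target_indices=target_indices
    )

    # The salient latent space isolates variation specific to the target cells;
    # representation_kind="background" would return the shared variation instead.
    salient = model.get_latent_representation(representation_kind="salient")

    # Contrastive differential expression: with target_idx supplied, salient
    # normalized expression is used only when the compared cells are targets.
    de_results = model.differential_expression(
        adata,
        groupby="condition",
        group1="treated",
        group2="control",
        target_idx=target_indices,
    )
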
/contrastive_vi/model/total_contrastive_vi.py:
--------------------------------------------------------------------------------
1 | """Model class for contrastive-VI for single cell expression data."""
2 |
3 | import logging
4 | import warnings
5 | from collections.abc import Iterable as IterableClass
6 | from functools import partial
7 | from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
8 |
9 | import numpy as np
10 | import pandas as pd
11 | import torch
12 | from anndata import AnnData
13 | from scvi import REGISTRY_KEYS
14 | from scvi._compat import Literal
15 | from scvi.data import AnnDataManager
16 | from scvi.data.fields import (
17 | CategoricalJointObsField,
18 | CategoricalObsField,
19 | LayerField,
20 | NumericalJointObsField,
21 | NumericalObsField,
22 | ProteinObsmField,
23 | )
24 | from scvi.dataloaders import AnnDataLoader
25 | from scvi.model._utils import (
26 | _get_batch_code_from_category,
27 | _init_library_size,
28 | cite_seq_raw_counts_properties,
29 | )
30 | from scvi.model.base import BaseModelClass
31 | from scvi.model.base._utils import _de_core
32 | from scvi.utils._docstrings import setup_anndata_dsp
33 |
34 | from contrastive_vi.model.base.training_mixin import ContrastiveTrainingMixin
35 | from contrastive_vi.module.total_contrastive_vi import TotalContrastiveVIModule
36 |
37 | logger = logging.getLogger(__name__)
38 | Number = Union[int, float]
39 |
40 |
41 | class TotalContrastiveVIModel(ContrastiveTrainingMixin, BaseModelClass):
42 | """
43 | Model class for total-contrastiveVI.
44 | Args:
45 | ----
46 | adata: AnnData object that has been registered via
47 | `TotalContrastiveVIModel.setup_anndata`.
48 | n_batch: Number of batches. If 0, no batch effect correction is performed.
49 | n_hidden: Number of nodes per hidden layer.
50 | n_background_latent: Dimensionality of the background latent space.
51 | n_salient_latent: Dimensionality of the salient latent space.
52 | n_layers: Number of hidden layers used for encoder and decoder NNs.
53 | dropout_rate: Dropout rate for neural networks.
54 | protein_batch_mask: Dictionary where each key is a batch code, and value is for
55 | each protein, whether it was observed or not.
56 | use_observed_lib_size: Use observed library size for RNA as scaling factor in
57 | mean of conditional distribution.
58 | empirical_protein_background_prior: Set the initialization of protein
59 | background prior empirically. This option fits a GMM for each of
60 | 100 cells per batch and averages the distributions. Note that even with
61 | this option set to `True`, this only initializes a parameter that is
62 | learned during inference. If `False`, randomly initializes. The default
63 | (`None`), sets this to `True` if greater than 10 proteins are used.
64 | """
65 |
66 | def __init__(
67 | self,
68 | adata: AnnData,
69 | n_hidden: int = 128,
70 | n_background_latent: int = 10,
71 | n_salient_latent: int = 10,
72 | gene_dispersion: Literal[
73 | "gene", "gene-batch", "gene-label", "gene-cell"
74 | ] = "gene",
75 | protein_dispersion: Literal[
76 | "protein", "protein-batch", "protein-label"
77 | ] = "protein",
78 | gene_likelihood: Literal["zinb", "nb"] = "nb",
79 | latent_distribution: Literal["normal", "ln"] = "normal",
80 | empirical_protein_background_prior: Optional[bool] = None,
81 | override_missing_proteins: bool = False,
82 | wasserstein_penalty: float = 0,
83 | **model_kwargs,
84 | ) -> None:
85 | super(TotalContrastiveVIModel, self).__init__(adata)
86 |
87 | self.protein_state_registry = self.adata_manager.get_state_registry(
88 | REGISTRY_KEYS.PROTEIN_EXP_KEY
89 | )
90 | if (
91 | ProteinObsmField.PROTEIN_BATCH_MASK in self.protein_state_registry
92 | and not override_missing_proteins
93 | ):
94 | batch_mask = self.protein_state_registry.protein_batch_mask
95 | msg = (
96 | "Some proteins have all 0 counts in some batches. "
97 | + "These proteins will be treated as missing measurements; however, "
98 | + "this can occur due to experimental design/biology. "
99 | + "Reinitialize the model with `override_missing_proteins=True`,"
100 | + "to override this behavior."
101 | )
102 | warnings.warn(msg, UserWarning)
103 | self._use_adversarial_classifier = True
104 | else:
105 | batch_mask = None
106 | self._use_adversarial_classifier = False
107 |
108 | emp_prior = (
109 | empirical_protein_background_prior
110 | if empirical_protein_background_prior is not None
111 | else (self.summary_stats.n_proteins > 10)
112 | )
113 | if emp_prior:
114 | prior_mean, prior_scale = self._get_totalvi_protein_priors(adata)
115 | else:
116 | prior_mean, prior_scale = None, None
117 |
118 | n_cats_per_cov = (
119 | self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)[
120 | CategoricalJointObsField.N_CATS_PER_KEY
121 | ]
122 | if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry
123 | else None
124 | )
125 |
126 | n_batch = self.summary_stats.n_batch
127 | use_size_factor_key = (
128 | REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry
129 | )
130 | library_log_means, library_log_vars = None, None
131 | if not use_size_factor_key:
132 | library_log_means, library_log_vars = _init_library_size(
133 | self.adata_manager, n_batch
134 | )
135 |
136 | self.module = TotalContrastiveVIModule(
137 | n_input_genes=self.summary_stats["n_vars"],
138 | n_input_proteins=self.summary_stats["n_proteins"],
139 | n_batch=n_batch,
140 | n_background_latent=n_background_latent,
141 | n_salient_latent=n_salient_latent,
142 | n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0),
143 | n_cats_per_cov=n_cats_per_cov,
144 | gene_dispersion=gene_dispersion,
145 | protein_dispersion=protein_dispersion,
146 | gene_likelihood=gene_likelihood,
147 | latent_distribution=latent_distribution,
148 | protein_batch_mask=batch_mask,
149 | protein_background_prior_mean=prior_mean,
150 | protein_background_prior_scale=prior_scale,
151 | use_size_factor_key=use_size_factor_key,
152 | library_log_means=library_log_means,
153 | library_log_vars=library_log_vars,
154 | wasserstein_penalty=wasserstein_penalty,
155 | **model_kwargs,
156 | )
157 | self._model_summary_string = "totalContrastiveVI."
158 | # Necessary line to get params to be used for saving and loading.
159 | self.init_params_ = self._get_init_params(locals())
160 | logger.info("The model has been initialized")
161 |
162 | @classmethod
163 | @setup_anndata_dsp.dedent
164 | def setup_anndata(
165 | cls,
166 | adata: AnnData,
167 | protein_expression_obsm_key: str,
168 | protein_names_uns_key: Optional[str] = None,
169 | batch_key: Optional[str] = None,
170 | layer: Optional[str] = None,
171 | size_factor_key: Optional[str] = None,
172 | categorical_covariate_keys: Optional[List[str]] = None,
173 | continuous_covariate_keys: Optional[List[str]] = None,
174 | **kwargs,
175 | ) -> Optional[AnnData]:
176 | """
177 | %(summary)s.
178 | Parameters
179 | ----------
180 | %(param_adata)s
181 | protein_expression_obsm_key
182 | key in `adata.obsm` for protein expression data.
183 | protein_names_uns_key
184 | key in `adata.uns` for protein names. If None, will use the column names of
185 | `adata.obsm[protein_expression_obsm_key]` if it is a DataFrame, else will
186 | assign sequential names to proteins.
187 | %(param_batch_key)s
188 | %(param_layer)s
189 | %(param_size_factor_key)s
190 | %(param_cat_cov_keys)s
191 | %(param_cont_cov_keys)s
192 | %(param_copy)s
193 | Returns
194 | -------
195 | %(returns)s
196 | """
197 | setup_method_args = cls._get_setup_method_args(**locals())
198 | batch_field = CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key)
199 | anndata_fields = [
200 | LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
201 | CategoricalObsField(
202 | REGISTRY_KEYS.LABELS_KEY, None
203 | ), # Default labels field for compatibility with TOTALVAE
204 | batch_field,
205 | NumericalObsField(
206 | REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False
207 | ),
208 | CategoricalJointObsField(
209 | REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys
210 | ),
211 | NumericalJointObsField(
212 | REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys
213 | ),
214 | ProteinObsmField(
215 | REGISTRY_KEYS.PROTEIN_EXP_KEY,
216 | protein_expression_obsm_key,
217 | use_batch_mask=True,
218 | batch_key=batch_field.attr_key,
219 | colnames_uns_key=protein_names_uns_key,
220 | is_count_data=True,
221 | ),
222 | ]
223 | adata_manager = AnnDataManager(
224 | fields=anndata_fields, setup_method_args=setup_method_args
225 | )
226 | adata_manager.register_fields(adata, **kwargs)
227 | cls.register_manager(adata_manager)
228 |
229 | @torch.no_grad()
230 | def get_latent_representation(
231 | self,
232 | adata: Optional[AnnData] = None,
233 | indices: Optional[Sequence[int]] = None,
234 | give_mean: bool = True,
235 | batch_size: Optional[int] = None,
236 | representation_kind: str = "salient",
237 | ) -> np.ndarray:
238 | """
239 | Return the background or salient latent representation for each cell.
240 |
241 | Args:
242 | ----
243 | adata: AnnData object with equivalent structure to initial AnnData. If `None`,
244 | defaults to the AnnData object used to initialize the model.
245 | indices: Indices of cells in adata to use. If `None`, all cells are used.
246 | give_mean: Give mean of distribution or sample from it.
247 | batch_size: Mini-batch size for data loading into model. Defaults to
248 | `scvi.settings.batch_size`.
249 | representation_kind: Either "background" or "salient" for the corresponding
250 | representation kind.
251 |
252 | Returns
253 | -------
254 | A numpy array with shape `(n_cells, n_latent)`.
255 | """
256 | available_representation_kinds = ["background", "salient"]
257 | assert representation_kind in available_representation_kinds, (
258 | f"representation_kind = {representation_kind} is not one of"
259 | f" {available_representation_kinds}"
260 | )
261 |
262 | adata = self._validate_anndata(adata)
263 | data_loader = self._make_data_loader(
264 | adata=adata,
265 | indices=indices,
266 | batch_size=batch_size,
267 | shuffle=False,
268 | data_loader_class=AnnDataLoader,
269 | )
270 | latent = []
271 | for tensors in data_loader:
272 | x = tensors[REGISTRY_KEYS.X_KEY]
273 | y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY]
274 | batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
275 | outputs = self.module._generic_inference(
276 | x=x, y=y, batch_index=batch_index, n_samples=1
277 | )
278 |
279 | if representation_kind == "background":
280 | latent_m = outputs["qz_m"]
281 | latent_sample = outputs["z"]
282 | else:
283 | latent_m = outputs["qs_m"]
284 | latent_sample = outputs["s"]
285 |
286 | if give_mean:
287 | latent_sample = latent_m
288 |
289 | latent += [latent_sample.detach().cpu()]
290 | return torch.cat(latent).numpy()
291 |
292 | @torch.no_grad()
293 | def get_normalized_expression(
294 | self,
295 | adata=None,
296 | indices=None,
297 | n_samples_overall: Optional[int] = None,
298 | transform_batch: Optional[Sequence[Union[Number, str]]] = None,
299 | gene_list: Optional[Sequence[str]] = None,
300 | protein_list: Optional[Sequence[str]] = None,
301 | library_size: Optional[Union[float, Literal["latent"]]] = 1,
302 | n_samples: int = 1,
303 | sample_protein_mixing: bool = False,
304 | scale_protein: bool = False,
305 | include_protein_background: bool = False,
306 | batch_size: Optional[int] = None,
307 | return_mean: bool = True,
308 | return_numpy: Optional[bool] = None,
309 | ) -> Dict[
310 | str, Tuple[Union[np.ndarray, pd.DataFrame], Union[np.ndarray, pd.DataFrame]]
311 | ]:
312 | r"""
313 | Returns the normalized gene expression and protein expression.
314 |         This is denoted as :math:`\rho_n` for genes and
315 |         :math:`(1-\pi_{nt})\alpha_{nt}\beta_{nt}` for proteins in the totalVI paper.
316 | Parameters
317 | ----------
318 | adata
319 | AnnData object with equivalent structure to initial AnnData. If `None`,
320 | defaults to the AnnData object used to initialize the model.
321 | indices
322 | Indices of cells in adata to use. If `None`, all cells are used.
323 | n_samples_overall
324 | Number of samples to use in total
325 | transform_batch
326 | Batch to condition on.
327 | If transform_batch is:
328 | - None, then real observed batch is used
329 | - int, then batch transform_batch is used
330 | - List[int], then average over batches in list
331 | gene_list
332 | Return frequencies of expression for a subset of genes.
333 | This can save memory when working with large datasets and few genes are
334 | of interest.
335 | protein_list
336 |             Return protein expression for a subset of proteins.
337 |             This can save memory when working with large datasets and few proteins
338 |             are of interest.
339 | library_size
340 | Scale the expression frequencies to a common library size.
341 | This allows gene expression levels to be interpreted on a common scale of
342 | relevant magnitude.
343 | n_samples
344 |             Number of posterior samples to use for estimation.
345 | sample_protein_mixing
346 | Sample mixing bernoulli, setting background to zero
347 | scale_protein
348 | Make protein expression sum to 1
349 | include_protein_background
350 | Include background component for protein expression
351 | batch_size
352 | Minibatch size for data loading into model. Defaults to
353 | `scvi.settings.batch_size`.
354 | return_mean
355 | Whether to return the mean of the samples.
356 | return_numpy
357 | Return a `np.ndarray` instead of a `pd.DataFrame`. Includes gene
358 | names as columns. If either n_samples=1 or return_mean=True, defaults to
359 | False. Otherwise, it defaults to True.
360 | Returns
361 | -------
362 | - **gene_normalized_expression** - normalized expression for RNA
363 | - **protein_normalized_expression** - normalized expression for proteins
364 | If ``n_samples`` > 1 and ``return_mean`` is False, then the shape is
365 | ``(samples, cells, genes)``. Otherwise, shape is ``(cells, genes)``.
366 | Return type is ``pd.DataFrame`` unless ``return_numpy`` is True.
367 | """
368 | adata = self._validate_anndata(adata)
369 | adata_manager = self.get_anndata_manager(adata)
370 | if indices is None:
371 | indices = np.arange(adata.n_obs)
372 | if n_samples_overall is not None:
373 | indices = np.random.choice(indices, n_samples_overall)
374 | post = self._make_data_loader(
375 | adata=adata, indices=indices, batch_size=batch_size
376 | )
377 |
378 | if gene_list is None:
379 | gene_mask = slice(None)
380 | else:
381 | all_genes = adata.var_names
382 | gene_mask = [True if gene in gene_list else False for gene in all_genes]
383 | if protein_list is None:
384 | protein_mask = slice(None)
385 | else:
386 |             all_proteins = self.protein_state_registry.column_names
387 | protein_mask = [True if p in protein_list else False for p in all_proteins]
388 | if indices is None:
389 | indices = np.arange(adata.n_obs)
390 |
391 | if n_samples > 1 and return_mean is False:
392 | if return_numpy is False:
393 | warnings.warn(
394 | "return_numpy must be True if n_samples > 1 and return_mean is "
395 | "False, returning np.ndarray"
396 | )
397 | return_numpy = True
398 |
399 | if not isinstance(transform_batch, IterableClass):
400 | transform_batch = [transform_batch]
401 |
402 | transform_batch = _get_batch_code_from_category(adata_manager, transform_batch)
403 |
404 | results = {}
405 | for expression_type in ["salient", "background"]:
406 | scale_list_gene = []
407 | scale_list_pro = []
408 |
409 | for tensors in post:
410 | x = tensors[REGISTRY_KEYS.X_KEY]
411 | y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY]
412 | batch_original = tensors[REGISTRY_KEYS.BATCH_KEY]
413 | px_scale = torch.zeros_like(x)
414 | py_scale = torch.zeros_like(y)
415 | if n_samples > 1:
416 | px_scale = torch.stack(n_samples * [px_scale])
417 | py_scale = torch.stack(n_samples * [py_scale])
418 | for b in transform_batch:
419 | inference_outputs = self.module._generic_inference(
420 | x=x, y=y, batch_index=batch_original, n_samples=n_samples
421 | )
422 |
423 | if expression_type == "salient":
424 | s = inference_outputs["s"]
425 | elif expression_type == "background":
426 | s = torch.zeros_like(inference_outputs["s"])
427 | else:
428 | raise NotImplementedError("Invalid expression type provided")
429 |
430 | generative_outputs = self.module._generic_generative(
431 | z=inference_outputs["z"],
432 | s=s,
433 | library_gene=inference_outputs["library_gene"],
434 | batch_index=b,
435 | )
436 |
437 | if library_size == "latent":
438 | px_scale += generative_outputs["px_"]["rate"].cpu()
439 | else:
440 | px_scale += generative_outputs["px_"]["scale"].cpu()
441 | px_scale = px_scale[..., gene_mask]
442 |
443 | py_ = generative_outputs["py_"]
444 | # probability of background
445 | protein_mixing = 1 / (1 + torch.exp(-py_["mixing"].cpu()))
446 | if sample_protein_mixing is True:
447 | protein_mixing = torch.distributions.Bernoulli(
448 | protein_mixing
449 | ).sample()
450 | protein_val = py_["rate_fore"].cpu() * (1 - protein_mixing)
451 | if include_protein_background is True:
452 | protein_val += py_["rate_back"].cpu() * protein_mixing
453 |
454 | if scale_protein is True:
455 | protein_val = torch.nn.functional.normalize(
456 | protein_val, p=1, dim=-1
457 | )
458 | protein_val = protein_val[..., protein_mask]
459 | py_scale += protein_val
460 | px_scale /= len(transform_batch)
461 | py_scale /= len(transform_batch)
462 | scale_list_gene.append(px_scale)
463 | scale_list_pro.append(py_scale)
464 |
465 | if n_samples > 1:
466 | # concatenate along batch dimension
467 | # -> result shape = (samples, cells, features)
468 | scale_list_gene = torch.cat(scale_list_gene, dim=1)
469 | scale_list_pro = torch.cat(scale_list_pro, dim=1)
470 | # (cells, features, samples)
471 | scale_list_gene = scale_list_gene.permute(1, 2, 0)
472 | scale_list_pro = scale_list_pro.permute(1, 2, 0)
473 | else:
474 | scale_list_gene = torch.cat(scale_list_gene, dim=0)
475 | scale_list_pro = torch.cat(scale_list_pro, dim=0)
476 |
477 | if return_mean is True and n_samples > 1:
478 | scale_list_gene = torch.mean(scale_list_gene, dim=-1)
479 | scale_list_pro = torch.mean(scale_list_pro, dim=-1)
480 |
481 | scale_list_gene = scale_list_gene.cpu().numpy()
482 | scale_list_pro = scale_list_pro.cpu().numpy()
483 | if return_numpy is None or return_numpy is False:
484 | gene_df = pd.DataFrame(
485 | scale_list_gene,
486 | columns=adata.var_names[gene_mask],
487 | index=adata.obs_names[indices],
488 | )
489 | protein_names = self.protein_state_registry.column_names
490 | pro_df = pd.DataFrame(
491 | scale_list_pro,
492 | columns=protein_names[protein_mask],
493 | index=adata.obs_names[indices],
494 | )
495 |
496 | results[expression_type] = (gene_df, pro_df)
497 | else:
498 | results[expression_type] = (scale_list_gene, scale_list_pro)
499 | return results
500 |
501 | @torch.no_grad()
502 | def get_specific_normalized_expression(
503 | self,
504 | adata=None,
505 | indices=None,
506 | n_samples_overall=None,
507 | transform_batch: Optional[Sequence[Union[Number, str]]] = None,
508 | scale_protein=False,
509 | batch_size: Optional[int] = None,
510 | n_samples=1,
511 | sample_protein_mixing=False,
512 | include_protein_background=False,
513 | return_mean=True,
514 | return_numpy=True,
515 | expression_type: Optional[str] = None,
516 | indices_to_return_salient: Optional[Sequence[int]] = None,
517 | ):
518 | """
519 | Return normalized (decoded) gene and protein expression.
520 |
521 | Gene + protein expressions are decoded from either the background or salient
522 |         latent space. Exactly one of `expression_type` or
523 |         `indices_to_return_salient` must be provided.
524 |
525 | Args:
526 | ----
527 | adata:
528 | AnnData object with equivalent structure to initial AnnData. If `None`,
529 | defaults to the AnnData object used to initialize the model.
530 | indices: Indices of cells in adata to use. If `None`, all cells are used.
531 | n_samples_overall: The number of random samples in `adata` to use.
532 | transform_batch:
533 | Batch to condition on.
534 | If transform_batch is:
535 | - None, then real observed batch is used
536 | - int, then batch transform_batch is used
537 | - List[int], then average over batches in list
538 | scale_protein: Make protein expression sum to 1
539 | batch_size:
540 | Minibatch size for data loading into model. Defaults to
541 | `scvi.settings.batch_size`.
542 | sample_protein_mixing: Sample mixing bernoulli, setting background to zero
543 | include_protein_background: Include background component for protein expression
544 | return_mean: Whether to return the mean of the samples.
545 | return_numpy:
546 | Return a `np.ndarray` instead of a `pd.DataFrame`. Includes gene
547 | names as columns. If either n_samples=1 or return_mean=True, defaults to
548 | False. Otherwise, it defaults to True.
549 | expression_type: One of {"salient", "background"} to specify the type of
550 | normalized expression to return.
551 | indices_to_return_salient: If `indices` is a subset of
552 | `indices_to_return_salient`, normalized expressions derived from background
553 | and salient latent embeddings are returned. If `indices` is not `None` and
554 | is not a subset of `indices_to_return_salient`, normalized expressions
555 | derived only from background latent embeddings are returned.
556 |
557 | Returns
558 | -------
559 | If `n_samples` > 1 and `return_mean` is `False`, then the shape is
560 | `((samples, cells, genes), (samples, cells, proteins))`. Otherwise, shape
561 | is `((cells, genes), (cells, proteins))`. In this case, return type is
562 | Tuple[`pandas.DataFrame`] unless `return_numpy` is `True`.
563 | """
564 | is_expression_type_none = expression_type is None
565 | is_indices_to_return_salient_none = indices_to_return_salient is None
566 | if is_expression_type_none and is_indices_to_return_salient_none:
567 | raise ValueError(
568 | "Both expression_type and indices_to_return_salient are None! "
569 | "Exactly one of them needs to be supplied with an input argument."
570 | )
571 | elif (not is_expression_type_none) and (not is_indices_to_return_salient_none):
572 | raise ValueError(
573 | "Both expression_type and indices_to_return_salient have an input "
574 | "argument! Exactly one of them needs to be supplied with an input "
575 | "argument."
576 | )
577 | else:
578 | exprs = self.get_normalized_expression(
579 | adata=adata,
580 | indices=indices,
581 | n_samples_overall=n_samples_overall,
582 | transform_batch=transform_batch,
583 | return_numpy=return_numpy,
584 | return_mean=return_mean,
585 | n_samples=n_samples,
586 | batch_size=batch_size,
587 | scale_protein=scale_protein,
588 | sample_protein_mixing=sample_protein_mixing,
589 | include_protein_background=include_protein_background,
590 | )
591 | if not is_expression_type_none:
592 | return exprs[expression_type]
593 | else:
594 | if indices is None:
595 | indices = np.arange(adata.n_obs)
596 | if set(indices).issubset(set(indices_to_return_salient)):
597 | return exprs["salient"]
598 | else:
599 | return exprs["background"]
600 |
601 | def _expression_for_de(
602 | self,
603 | adata=None,
604 | indices=None,
605 | n_samples_overall=None,
606 | transform_batch: Optional[Sequence[Union[Number, str]]] = None,
607 | scale_protein=False,
608 | batch_size: Optional[int] = None,
609 | n_samples=1,
610 | sample_protein_mixing=False,
611 | include_protein_background=False,
612 | protein_prior_count=0.5,
613 | expression_type: Optional[str] = None,
614 | indices_to_return_salient: Optional[Sequence[int]] = None,
615 | ):
616 | rna, protein = self.get_specific_normalized_expression(
617 | adata=adata,
618 | indices=indices,
619 | n_samples_overall=n_samples_overall,
620 | transform_batch=transform_batch,
621 | return_numpy=True,
622 | n_samples=n_samples,
623 | batch_size=batch_size,
624 | scale_protein=scale_protein,
625 | sample_protein_mixing=sample_protein_mixing,
626 | include_protein_background=include_protein_background,
627 | expression_type=expression_type,
628 | indices_to_return_salient=indices_to_return_salient,
629 | )
630 | protein += protein_prior_count
631 |
632 | joint = np.concatenate([rna, protein], axis=1)
633 | return joint
634 |
635 | def differential_expression(
636 | self,
637 | adata: Optional[AnnData] = None,
638 | groupby: Optional[str] = None,
639 | group1: Optional[Iterable[str]] = None,
640 | group2: Optional[str] = None,
641 | idx1: Optional[Union[Sequence[int], Sequence[bool], str]] = None,
642 | idx2: Optional[Union[Sequence[int], Sequence[bool], str]] = None,
643 | mode: Literal["vanilla", "change"] = "change",
644 | delta: float = 0.25,
645 | batch_size: Optional[int] = None,
646 | all_stats: bool = True,
647 | batch_correction: bool = False,
648 | batchid1: Optional[Iterable[str]] = None,
649 | batchid2: Optional[Iterable[str]] = None,
650 | fdr_target: float = 0.05,
651 | silent: bool = False,
652 | protein_prior_count: float = 0.1,
653 | scale_protein: bool = False,
654 | sample_protein_mixing: bool = False,
655 | include_protein_background: bool = False,
656 | target_idx: Optional[Sequence[int]] = None,
657 | **kwargs,
658 | ) -> pd.DataFrame:
659 | r"""
660 | A unified method for differential expression analysis.
661 | Implements `"vanilla"` DE [Lopez18]_ and `"change"` mode DE [Boyeau19]_.
662 |
663 | Args:
664 | ----
665 | protein_prior_count:
666 | Prior count added to protein expression before LFC computation
667 | scale_protein:
668 | Force protein values to sum to one in every single cell
669 | (post-hoc normalization).
670 | sample_protein_mixing:
671 | Sample the protein mixture component, i.e., sample a Bernoulli from the
672 | mixing parameter to determine whether expression is foreground or background.
673 | include_protein_background:
674 | Include the protein background component as part of the protein expression
675 | target_idx: If not `None`, a boolean mask or integer indices identifying
676 | the cells in the contrastive target group. Normalized expression values
677 | derived from both salient and background latent embeddings are used when
678 | {group1, group2} is a subset of the target group; otherwise, background
679 | normalized expression values are used.
680 | **kwargs:
681 | Keyword args for
682 | :meth:`scvi.model.base.DifferentialComputation.get_bayes_factors`
683 |
684 | Returns
685 | -------
686 | Differential expression DataFrame.
687 | """
688 | adata = self._validate_anndata(adata)
689 | col_names = np.concatenate(
690 | [
691 | np.asarray(adata.var_names),
692 | self.protein_state_registry.column_names,
693 | ]
694 | )
695 |
696 | if target_idx is not None:
697 | target_idx = np.array(target_idx)
698 | if target_idx.dtype is np.dtype("bool"):
699 | assert (
700 | len(target_idx) == adata.n_obs
701 | ), "target_idx mask must be the same length as adata!"
702 | target_idx = np.arange(adata.n_obs)[target_idx]
703 | model_fn = partial(
704 | self._expression_for_de,
705 | scale_protein=scale_protein,
706 | sample_protein_mixing=sample_protein_mixing,
707 | include_protein_background=include_protein_background,
708 | protein_prior_count=protein_prior_count,
709 | batch_size=batch_size,
710 | expression_type=None,
711 | indices_to_return_salient=target_idx,
712 | n_samples=100,
713 | )
714 | else:
715 | model_fn = partial(
716 | self._expression_for_de,
717 | scale_protein=scale_protein,
718 | sample_protein_mixing=sample_protein_mixing,
719 | include_protein_background=include_protein_background,
720 | protein_prior_count=protein_prior_count,
721 | batch_size=batch_size,
722 | expression_type="salient",
723 | n_samples=100,
724 | )
725 |
726 | result = _de_core(
727 | self.get_anndata_manager(adata, required=True),
728 | model_fn,
729 | groupby,
730 | group1,
731 | group2,
732 | idx1,
733 | idx2,
734 | all_stats,
735 | cite_seq_raw_counts_properties,
736 | col_names,
737 | mode,
738 | batchid1,
739 | batchid2,
740 | delta,
741 | batch_correction,
742 | fdr_target,
743 | silent,
744 | **kwargs,
745 | )
746 |
747 | return result
748 |
749 | @torch.no_grad()
750 | def get_latent_library_size(
751 | self,
752 | adata: Optional[AnnData] = None,
753 | indices: Optional[Sequence[int]] = None,
754 | give_mean: bool = True,
755 | batch_size: Optional[int] = None,
756 | ) -> np.ndarray:
757 | r"""
758 | Returns the latent RNA library size for each cell.
759 | This is denoted as :math:`\ell_n` in the totalVI paper.
760 | Parameters
761 | ----------
762 | adata
763 | AnnData object with equivalent structure to initial AnnData. If `None`,
764 | defaults to the AnnData object used to initialize the model.
765 | indices
766 | Indices of cells in adata to use. If `None`, all cells are used.
767 | give_mean
768 | Return the mean or a sample from the posterior distribution.
769 | batch_size
770 | Minibatch size for data loading into model. Defaults to
771 | `scvi.settings.batch_size`.
772 | """
773 | self._check_if_trained(warn=False)
774 |
775 | adata = self._validate_anndata(adata)
776 | scdl = self._make_data_loader(
777 | adata=adata, indices=indices, batch_size=batch_size
778 | )
779 | libraries = []
780 | for tensors in scdl:
781 | x = tensors[REGISTRY_KEYS.X_KEY]
782 | y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY]
783 | batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
784 | outputs = self.module._generic_inference(x=x, y=y, batch_index=batch_index)
785 |
786 | library = outputs["library_gene"]
787 | # `library_gene` is already on the count scale (the observed sum or the
788 | # exp-transformed latent), so a sample needs no further transform.
789 | if give_mean:
790 | ql = (outputs["ql_m"], outputs["ql_v"])
791 | if ql[0] is None or ql[1] is None:
792 | raise RuntimeError(
793 | "The module for this model does not compute the posterior "
794 | "distribution for the library size. Set `give_mean` to False to "
795 | "use the observed library size instead."
796 | )
797 | library = torch.distributions.LogNormal(ql[0], ql[1].sqrt()).mean
798 | libraries += [library.cpu()]
799 | return torch.cat(libraries).numpy()
800 |
--------------------------------------------------------------------------------
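Note: the `differential_expression` workflow above is easiest to follow with a
small end-to-end sketch. The snippet below is illustrative only and is not part
of the repository; it assumes `model` is a trained total-contrastiveVI model
registered with an AnnData object `adata`, and the `.obs` columns
"condition"/"perturbation" and their labels are hypothetical placeholders.

# Hypothetical boolean mask marking cells in the contrastive target group.
target_mask = (adata.obs["condition"] == "treated").to_numpy()

# Compare two (hypothetical) perturbations within the treated cells. Because
# both groups lie inside the target group, normalized expression derived from
# the salient (plus background) latent space is used; a comparison that
# leaves the target group would fall back to background-only expression.
de_results = model.differential_expression(
    adata=adata,
    groupby="perturbation",
    group1="drug_A",
    group2="drug_B",
    target_idx=target_mask,
)
print(de_results.head())

--------------------------------------------------------------------------------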
/contrastive_vi/module/__init__.py:
--------------------------------------------------------------------------------
1 | """PyTorch modules for models."""
2 |
--------------------------------------------------------------------------------
/contrastive_vi/module/contrastive_vi.py:
--------------------------------------------------------------------------------
1 | """PyTorch module for Contrastive VI for single cell expression data."""
2 |
3 | from typing import Dict, Optional, Tuple
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from scvi import REGISTRY_KEYS
10 | from scvi.distributions import ZeroInflatedNegativeBinomial
11 | from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data
12 | from scvi.nn import DecoderSCVI, Encoder, one_hot
13 | from torch.distributions import Normal
14 | from torch.distributions import kl_divergence as kl
15 |
16 | from contrastive_vi.module.utils import gram_matrix
17 |
18 | torch.backends.cudnn.benchmark = True
19 |
20 |
21 | class ContrastiveVIModule(BaseModuleClass):
22 | """
23 | PyTorch module for Contrastive VI (Variational Inference).
24 |
25 | Args:
26 | ----
27 | n_input: Number of input genes.
28 | n_batch: Number of batches. If 0, no batch effect correction is performed.
29 | n_hidden: Number of nodes per hidden layer.
30 | n_background_latent: Dimensionality of the background latent space.
31 | n_salient_latent: Dimensionality of the salient latent space.
32 | n_layers: Number of hidden layers used for encoder and decoder NNs.
33 | dropout_rate: Dropout rate for neural networks.
34 | use_observed_lib_size: Use observed library size for RNA as scaling factor in
35 | mean of conditional distribution.
36 | library_log_means: 1 x n_batch array of means of the log library sizes.
37 | Parameterize prior on library size if not using observed library size.
38 | library_log_vars: 1 x n_batch array of variances of the log library sizes.
39 | Parameterize prior on library size if not using observed library size.
40 | wasserstein_penalty: Weight of the Wasserstein distance loss that further
41 | discourages shared variations from leaking into the salient latent space.
42 | """
43 |
44 | def __init__(
45 | self,
46 | n_input: int,
47 | n_batch: int = 0,
48 | n_hidden: int = 128,
49 | n_background_latent: int = 10,
50 | n_salient_latent: int = 10,
51 | n_layers: int = 1,
52 | dropout_rate: float = 0.1,
53 | use_observed_lib_size: bool = True,
54 | library_log_means: Optional[np.ndarray] = None,
55 | library_log_vars: Optional[np.ndarray] = None,
56 | wasserstein_penalty: float = 0,
57 | ) -> None:
58 | super().__init__()
59 | self.n_input = n_input
60 | self.n_batch = n_batch
61 | self.n_hidden = n_hidden
62 | self.n_background_latent = n_background_latent
63 | self.n_salient_latent = n_salient_latent
64 | self.n_layers = n_layers
65 | self.dropout_rate = dropout_rate
66 | self.latent_distribution = "normal"
67 | self.dispersion = "gene"
68 | self.px_r = torch.nn.Parameter(torch.randn(n_input))
69 | self.use_observed_lib_size = use_observed_lib_size
70 | self.wasserstein_penalty = wasserstein_penalty
71 |
72 | if not self.use_observed_lib_size:
73 | if library_log_means is None or library_log_vars is None:
74 | raise ValueError(
75 | "If not using observed_lib_size, "
76 | "must provide library_log_means and library_log_vars."
77 | )
78 | self.register_buffer(
79 | "library_log_means", torch.from_numpy(library_log_means).float()
80 | )
81 | self.register_buffer(
82 | "library_log_vars", torch.from_numpy(library_log_vars).float()
83 | )
84 |
85 | cat_list = [n_batch]
86 | # Background encoder.
87 | self.z_encoder = Encoder(
88 | n_input,
89 | n_background_latent,
90 | n_cat_list=cat_list,
91 | n_layers=n_layers,
92 | n_hidden=n_hidden,
93 | dropout_rate=dropout_rate,
94 | distribution=self.latent_distribution,
95 | inject_covariates=True,
96 | use_batch_norm=True,
97 | use_layer_norm=False,
98 | var_activation=None,
99 | )
100 | # Salient encoder.
101 | self.s_encoder = Encoder(
102 | n_input,
103 | n_salient_latent,
104 | n_cat_list=cat_list,
105 | n_layers=n_layers,
106 | n_hidden=n_hidden,
107 | dropout_rate=dropout_rate,
108 | distribution=self.latent_distribution,
109 | inject_covariates=True,
110 | use_batch_norm=True,
111 | use_layer_norm=False,
112 | var_activation=None,
113 | )
114 | # Library size encoder.
115 | self.l_encoder = Encoder(
116 | n_input,
117 | 1,
118 | n_layers=1,
119 | n_cat_list=cat_list,
120 | n_hidden=n_hidden,
121 | dropout_rate=dropout_rate,
122 | inject_covariates=True,
123 | use_batch_norm=True,
124 | use_layer_norm=False,
125 | var_activation=None,
126 | )
127 | # Decoder from latent variable to distribution parameters in data space.
128 | n_total_latent = n_background_latent + n_salient_latent
129 | self.decoder = DecoderSCVI(
130 | n_total_latent,
131 | n_input,
132 | n_cat_list=cat_list,
133 | n_layers=n_layers,
134 | n_hidden=n_hidden,
135 | inject_covariates=True,
136 | use_batch_norm=True,
137 | use_layer_norm=False,
138 | )
139 |
140 | @auto_move_data
141 | def _compute_local_library_params(
142 | self, batch_index: torch.Tensor
143 | ) -> Tuple[torch.Tensor, torch.Tensor]:
144 | n_batch = self.library_log_means.shape[1]
145 | local_library_log_means = F.linear(
146 | one_hot(batch_index, n_batch), self.library_log_means
147 | )
148 | local_library_log_vars = F.linear(
149 | one_hot(batch_index, n_batch), self.library_log_vars
150 | )
151 | return local_library_log_means, local_library_log_vars
152 |
153 | @staticmethod
154 | def _get_min_batch_size(concat_tensors: Dict[str, Dict[str, torch.Tensor]]) -> int:
155 | return min(
156 | concat_tensors["background"][REGISTRY_KEYS.X_KEY].shape[0],
157 | concat_tensors["target"][REGISTRY_KEYS.X_KEY].shape[0],
158 | )
159 |
160 | @staticmethod
161 | def _reduce_tensors_to_min_batch_size(
162 | tensors: Dict[str, torch.Tensor], min_batch_size: int
163 | ) -> None:
164 | for name, tensor in tensors.items():
165 | tensors[name] = tensor[:min_batch_size, :]
166 |
167 | @staticmethod
168 | def _get_inference_input_from_concat_tensors(
169 | concat_tensors: Dict[str, Dict[str, torch.Tensor]], index: str
170 | ) -> Dict[str, torch.Tensor]:
171 | tensors = concat_tensors[index]
172 | x = tensors[REGISTRY_KEYS.X_KEY]
173 | batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
174 | input_dict = dict(x=x, batch_index=batch_index)
175 | return input_dict
176 |
177 | def _get_inference_input(
178 | self, concat_tensors: Dict[str, Dict[str, torch.Tensor]]
179 | ) -> Dict[str, Dict[str, torch.Tensor]]:
180 | background = self._get_inference_input_from_concat_tensors(
181 | concat_tensors, "background"
182 | )
183 | target = self._get_inference_input_from_concat_tensors(concat_tensors, "target")
184 | # Ensure batch sizes are the same.
185 | min_batch_size = self._get_min_batch_size(concat_tensors)
186 | self._reduce_tensors_to_min_batch_size(background, min_batch_size)
187 | self._reduce_tensors_to_min_batch_size(target, min_batch_size)
188 | return dict(background=background, target=target)
189 |
190 | @staticmethod
191 | def _get_generative_input_from_concat_tensors(
192 | concat_tensors: Dict[str, Dict[str, torch.Tensor]], index: str
193 | ) -> Dict[str, torch.Tensor]:
194 | tensors = concat_tensors[index]
195 | batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
196 | input_dict = dict(batch_index=batch_index)
197 | return input_dict
198 |
199 | @staticmethod
200 | def _get_generative_input_from_inference_outputs(
201 | inference_outputs: Dict[str, Dict[str, torch.Tensor]], data_source: str
202 | ) -> Dict[str, torch.Tensor]:
203 | z = inference_outputs[data_source]["z"]
204 | s = inference_outputs[data_source]["s"]
205 | library = inference_outputs[data_source]["library"]
206 | return dict(z=z, s=s, library=library)
207 |
208 | def _get_generative_input(
209 | self,
210 | concat_tensors: Dict[str, Dict[str, torch.Tensor]],
211 | inference_outputs: Dict[str, Dict[str, torch.Tensor]],
212 | ) -> Dict[str, Dict[str, torch.Tensor]]:
213 | background_tensor_input = self._get_generative_input_from_concat_tensors(
214 | concat_tensors, "background"
215 | )
216 | target_tensor_input = self._get_generative_input_from_concat_tensors(
217 | concat_tensors, "target"
218 | )
219 | # Ensure batch sizes are the same.
220 | min_batch_size = self._get_min_batch_size(concat_tensors)
221 | self._reduce_tensors_to_min_batch_size(background_tensor_input, min_batch_size)
222 | self._reduce_tensors_to_min_batch_size(target_tensor_input, min_batch_size)
223 |
224 | background_inference_outputs = (
225 | self._get_generative_input_from_inference_outputs(
226 | inference_outputs, "background"
227 | )
228 | )
229 | target_inference_outputs = self._get_generative_input_from_inference_outputs(
230 | inference_outputs, "target"
231 | )
232 | background = {**background_tensor_input, **background_inference_outputs}
233 | target = {**target_tensor_input, **target_inference_outputs}
234 | return dict(background=background, target=target)
235 |
236 | @staticmethod
237 | def _reshape_tensor_for_samples(tensor: torch.Tensor, n_samples: int):
238 | return tensor.unsqueeze(0).expand((n_samples, tensor.size(0), tensor.size(1)))
239 |
240 | @auto_move_data
241 | def _generic_inference(
242 | self,
243 | x: torch.Tensor,
244 | batch_index: torch.Tensor,
245 | n_samples: int = 1,
246 | ) -> Dict[str, torch.Tensor]:
247 | x_ = x
248 | if self.use_observed_lib_size:
249 | library = torch.log(x.sum(1)).unsqueeze(1)
250 | x_ = torch.log(1 + x_)
251 |
252 | qz_m, qz_v, z = self.z_encoder(x_, batch_index)
253 | qs_m, qs_v, s = self.s_encoder(x_, batch_index)
254 |
255 | ql_m, ql_v = None, None
256 | if not self.use_observed_lib_size:
257 | ql_m, ql_v, library_encoded = self.l_encoder(x_, batch_index)
258 | library = library_encoded
259 |
260 | if n_samples > 1:
261 | qz_m = self._reshape_tensor_for_samples(qz_m, n_samples)
262 | qz_v = self._reshape_tensor_for_samples(qz_v, n_samples)
263 | z = self._reshape_tensor_for_samples(z, n_samples)
264 | qs_m = self._reshape_tensor_for_samples(qs_m, n_samples)
265 | qs_v = self._reshape_tensor_for_samples(qs_v, n_samples)
266 | s = self._reshape_tensor_for_samples(s, n_samples)
267 |
268 | if self.use_observed_lib_size:
269 | library = self._reshape_tensor_for_samples(library, n_samples)
270 | else:
271 | ql_m = self._reshape_tensor_for_samples(ql_m, n_samples)
272 | ql_v = self._reshape_tensor_for_samples(ql_v, n_samples)
273 | library = Normal(ql_m, ql_v.sqrt()).sample()
274 |
275 | outputs = dict(
276 | z=z,
277 | qz_m=qz_m,
278 | qz_v=qz_v,
279 | s=s,
280 | qs_m=qs_m,
281 | qs_v=qs_v,
282 | library=library,
283 | ql_m=ql_m,
284 | ql_v=ql_v,
285 | )
286 | return outputs
287 |
288 | @auto_move_data
289 | def inference(
290 | self,
291 | background: Dict[str, torch.Tensor],
292 | target: Dict[str, torch.Tensor],
293 | n_samples: int = 1,
294 | ) -> Dict[str, Dict[str, torch.Tensor]]:
295 | background_batch_size = background["x"].shape[0]
296 | target_batch_size = target["x"].shape[0]
297 | inference_input = {}
298 | for key in background.keys():
299 | inference_input[key] = torch.cat([background[key], target[key]], dim=0)
300 | outputs = self._generic_inference(**inference_input, n_samples=n_samples)
301 | batch_size_dim = 0 if n_samples == 1 else 1
302 | background_outputs, target_outputs = {}, {}
303 | for key in outputs.keys():
304 | if outputs[key] is not None:
305 | background_tensor, target_tensor = torch.split(
306 | outputs[key],
307 | [background_batch_size, target_batch_size],
308 | dim=batch_size_dim,
309 | )
310 | else:
311 | background_tensor, target_tensor = None, None
312 | background_outputs[key] = background_tensor
313 | target_outputs[key] = target_tensor
314 | background_outputs["s"] = torch.zeros_like(background_outputs["s"])
315 | return dict(background=background_outputs, target=target_outputs)
316 |
317 | @auto_move_data
318 | def _generic_generative(
319 | self,
320 | z: torch.Tensor,
321 | s: torch.Tensor,
322 | library: torch.Tensor,
323 | batch_index: torch.Tensor,
324 | ) -> Dict[str, torch.Tensor]:
325 | latent = torch.cat([z, s], dim=-1)
326 | px_scale, px_r, px_rate, px_dropout = self.decoder(
327 | self.dispersion,
328 | latent,
329 | library,
330 | batch_index,
331 | )
332 | px_r = torch.exp(self.px_r)
333 | return dict(
334 | px_scale=px_scale, px_r=px_r, px_rate=px_rate, px_dropout=px_dropout
335 | )
336 |
337 | @auto_move_data
338 | def generative(
339 | self,
340 | background: Dict[str, torch.Tensor],
341 | target: Dict[str, torch.Tensor],
342 | ) -> Dict[str, Dict[str, torch.Tensor]]:
343 | latent_z_shape = background["z"].shape
344 | batch_size_dim = 0 if len(latent_z_shape) == 2 else 1
345 | background_batch_size = background["z"].shape[batch_size_dim]
346 | target_batch_size = target["z"].shape[batch_size_dim]
347 | generative_input = {}
348 | for key in ["z", "s", "library"]:
349 | generative_input[key] = torch.cat(
350 | [background[key], target[key]], dim=batch_size_dim
351 | )
352 | generative_input["batch_index"] = torch.cat(
353 | [background["batch_index"], target["batch_index"]], dim=0
354 | )
355 | outputs = self._generic_generative(**generative_input)
356 | background_outputs, target_outputs = {}, {}
357 | for key in ["px_scale", "px_rate", "px_dropout"]:
358 | if outputs[key] is not None:
359 | background_tensor, target_tensor = torch.split(
360 | outputs[key],
361 | [background_batch_size, target_batch_size],
362 | dim=batch_size_dim,
363 | )
364 | else:
365 | background_tensor, target_tensor = None, None
366 | background_outputs[key] = background_tensor
367 | target_outputs[key] = target_tensor
368 | background_outputs["px_r"] = outputs["px_r"]
369 | target_outputs["px_r"] = outputs["px_r"]
370 | return dict(background=background_outputs, target=target_outputs)
371 |
372 | @staticmethod
373 | def reconstruction_loss(
374 | x: torch.Tensor,
375 | px_rate: torch.Tensor,
376 | px_r: torch.Tensor,
377 | px_dropout: torch.Tensor,
378 | ) -> torch.Tensor:
379 | """
380 | Compute likelihood loss for zero-inflated negative binomial distribution.
381 |
382 | Args:
383 | ----
384 | x: Input data.
385 | px_rate: Mean of distribution.
386 | px_r: Inverse dispersion.
387 | px_dropout: Logits scale of zero inflation probability.
388 |
389 | Returns
390 | -------
391 | Negative log likelihood (reconstruction loss) for each data point. If number
392 | of latent samples == 1, the tensor has shape `(batch_size, )`. If number
393 | of latent samples > 1, the tensor has shape `(n_samples, batch_size)`.
394 | """
395 | recon_loss = (
396 | -ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout)
397 | .log_prob(x)
398 | .sum(dim=-1)
399 | )
400 | return recon_loss
401 |
402 | @staticmethod
403 | def latent_kl_divergence(
404 | variational_mean: torch.Tensor,
405 | variational_var: torch.Tensor,
406 | prior_mean: torch.Tensor,
407 | prior_var: torch.Tensor,
408 | ) -> torch.Tensor:
409 | """
410 | Compute KL divergence between a variational posterior and prior Gaussian.
411 | Args:
412 | ----
413 | variational_mean: Mean of the variational posterior Gaussian.
414 | variational_var: Variance of the variational posterior Gaussian.
415 | prior_mean: Mean of the prior Gaussian.
416 | prior_var: Variance of the prior Gaussian.
417 |
418 | Returns
419 | -------
420 | KL divergence for each data point. If number of latent samples == 1,
421 | the tensor has shape `(batch_size, )`. If number of latent
422 | samples > 1, the tensor has shape `(n_samples, batch_size)`.
423 | """
424 | return kl(
425 | Normal(variational_mean, variational_var.sqrt()),
426 | Normal(prior_mean, prior_var.sqrt()),
427 | ).sum(dim=-1)
428 |
429 | def library_kl_divergence(
430 | self,
431 | batch_index: torch.Tensor,
432 | variational_library_mean: torch.Tensor,
433 | variational_library_var: torch.Tensor,
434 | library: torch.Tensor,
435 | ) -> torch.Tensor:
436 | """
437 | Compute KL divergence between library size variational posterior and prior.
438 |
439 | Both the variational posterior and prior are Log-Normal.
440 | Args:
441 | ----
442 | batch_index: Batch indices for batch-specific library size mean and
443 | variance.
444 | variational_library_mean: Mean of variational Log-Normal.
445 | variational_library_var: Variance of variational Log-Normal.
446 | library: Sampled library size.
447 |
448 | Returns
449 | -------
450 | KL divergence for each data point. If number of latent samples == 1,
451 | the tensor has shape `(batch_size, )`. If number of latent
452 | samples > 1, the tensor has shape `(n_samples, batch_size)`.
453 | """
454 | if not self.use_observed_lib_size:
455 | (
456 | local_library_log_means,
457 | local_library_log_vars,
458 | ) = self._compute_local_library_params(batch_index)
459 |
460 | kl_library = kl(
461 | Normal(variational_library_mean, variational_library_var.sqrt()),
462 | Normal(local_library_log_means, local_library_log_vars.sqrt()),
463 | )
464 | else:
465 | kl_library = torch.zeros_like(library)
466 | return kl_library.sum(dim=-1)
467 |
468 | def _generic_loss(
469 | self,
470 | tensors: Dict[str, torch.Tensor],
471 | inference_outputs: Dict[str, torch.Tensor],
472 | generative_outputs: Dict[str, torch.Tensor],
473 | ) -> Dict[str, torch.Tensor]:
474 | x = tensors[REGISTRY_KEYS.X_KEY]
475 | batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
476 |
477 | qz_m = inference_outputs["qz_m"]
478 | qz_v = inference_outputs["qz_v"]
479 | qs_m = inference_outputs["qs_m"]
480 | qs_v = inference_outputs["qs_v"]
481 | library = inference_outputs["library"]
482 | ql_m = inference_outputs["ql_m"]
483 | ql_v = inference_outputs["ql_v"]
484 | px_rate = generative_outputs["px_rate"]
485 | px_r = generative_outputs["px_r"]
486 | px_dropout = generative_outputs["px_dropout"]
487 |
488 | prior_z_m = torch.zeros_like(qz_m)
489 | prior_z_v = torch.ones_like(qz_v)
490 | prior_s_m = torch.zeros_like(qs_m)
491 | prior_s_v = torch.ones_like(qs_v)
492 |
493 | recon_loss = self.reconstruction_loss(x, px_rate, px_r, px_dropout)
494 | kl_z = self.latent_kl_divergence(qz_m, qz_v, prior_z_m, prior_z_v)
495 | kl_s = self.latent_kl_divergence(qs_m, qs_v, prior_s_m, prior_s_v)
496 | kl_library = self.library_kl_divergence(batch_index, ql_m, ql_v, library)
497 | return dict(
498 | recon_loss=recon_loss,
499 | kl_z=kl_z,
500 | kl_s=kl_s,
501 | kl_library=kl_library,
502 | )
503 |
504 | def loss(
505 | self,
506 | concat_tensors: Dict[str, Dict[str, torch.Tensor]],
507 | inference_outputs: Dict[str, Dict[str, torch.Tensor]],
508 | generative_outputs: Dict[str, Dict[str, torch.Tensor]],
509 | kl_weight: float = 1.0,
510 | ) -> LossRecorder:
511 | """
512 | Compute loss terms for contrastive-VI.
513 | Args:
514 | ----
515 | concat_tensors: Dictionary of data mini-batches. The "background" entry
516 | contains the background mini-batch, and the "target" entry contains
517 | the target mini-batch.
518 | inference_outputs: Dictionary of inference step outputs. The keys
519 | are "background" and "target" for the corresponding outputs.
520 | generative_outputs: Dictionary of generative step outputs. The keys
521 | are "background" and "target" for the corresponding outputs.
522 | kl_weight: Importance weight for KL divergence of background and salient
523 | latent variables, relative to KL divergence of library size.
524 |
525 | Returns
526 | -------
527 | An scvi.module.base.LossRecorder instance that records the following:
528 | loss: One-dimensional tensor for overall loss used for optimization.
529 | reconstruction_loss: Reconstruction loss with shape
530 | `(n_samples, batch_size)` if number of latent samples > 1, or
531 | `(batch_size, )` if number of latent samples == 1.
532 | kl_local: KL divergence term with shape
533 | `(n_samples, batch_size)` if number of latent samples > 1, or
534 | `(batch_size, )` if number of latent samples == 1.
535 | kl_global: One-dimensional tensor for global KL divergence term.
536 | """
537 | background_tensors = concat_tensors["background"]
538 | target_tensors = concat_tensors["target"]
539 | # Ensure batch sizes are the same.
540 | min_batch_size = self._get_min_batch_size(concat_tensors)
541 | self._reduce_tensors_to_min_batch_size(background_tensors, min_batch_size)
542 | self._reduce_tensors_to_min_batch_size(target_tensors, min_batch_size)
543 |
544 | background_losses = self._generic_loss(
545 | background_tensors,
546 | inference_outputs["background"],
547 | generative_outputs["background"],
548 | )
549 | target_losses = self._generic_loss(
550 | target_tensors,
551 | inference_outputs["target"],
552 | generative_outputs["target"],
553 | )
554 | reconst_loss = background_losses["recon_loss"] + target_losses["recon_loss"]
555 | kl_divergence_z = background_losses["kl_z"] + target_losses["kl_z"]
556 | kl_divergence_s = target_losses["kl_s"]
557 | kl_divergence_l = background_losses["kl_library"] + target_losses["kl_library"]
558 |
559 | wasserstein_loss = (
560 | torch.norm(inference_outputs["background"]["qs_m"], dim=-1) ** 2
561 | + torch.sum(inference_outputs["background"]["qs_v"], dim=-1)
562 | )
563 |
564 | kl_local_for_warmup = kl_divergence_z + kl_divergence_s
565 | kl_local_no_warmup = kl_divergence_l
566 |
567 | weighted_kl_local = kl_weight * (self.wasserstein_penalty * wasserstein_loss
568 | + kl_local_for_warmup) + kl_local_no_warmup
569 |
570 | loss = torch.mean(reconst_loss + weighted_kl_local)
571 |
572 | kl_local = dict(
573 | kl_divergence_l=kl_divergence_l,
574 | kl_divergence_z=kl_divergence_z,
575 | kl_divergence_s=kl_divergence_s
576 | )
577 | kl_global = torch.tensor(0.0)
578 |
579 | # LossRecorder internally sums the `reconst_loss`, `kl_local`, and `kl_global`
580 | # terms before logging, so we do the same for our `wasserstein_loss` term.
581 | return LossRecorder(
582 | loss,
583 | reconst_loss,
584 | kl_local,
585 | kl_global,
586 | wasserstein_loss=torch.sum(wasserstein_loss)
587 | )
588 |
589 | @torch.no_grad()
590 | def sample(self):
591 | raise NotImplementedError
592 |
593 | @torch.no_grad()
594 | @auto_move_data
595 | def marginal_ll(self):
596 | raise NotImplementedError
597 |
--------------------------------------------------------------------------------
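Aside: the `wasserstein_loss` term in `loss` above has a closed form worth
spelling out. For a diagonal Gaussian posterior N(qs_m, diag(qs_v)), the
quantity ||qs_m||^2 + sum(qs_v) is the squared 2-Wasserstein distance to a
point mass at zero, which is what discourages background cells from carrying
salient signal. A self-contained numerical sketch with toy tensors (not
repository code):

import torch

qs_m = torch.randn(4, 10)  # toy salient posterior means (4 cells, 10 dims)
qs_v = torch.rand(4, 10)  # toy salient posterior variances

# Squared 2-Wasserstein distance between N(qs_m, diag(qs_v)) and a point
# mass at zero, per cell, exactly as computed in `loss` above.
penalty = torch.norm(qs_m, dim=-1) ** 2 + torch.sum(qs_v, dim=-1)
assert penalty.shape == (4,)

--------------------------------------------------------------------------------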
/contrastive_vi/module/total_contrastive_vi.py:
--------------------------------------------------------------------------------
1 | """PyTorch module for Contrastive VI for single cell expression data."""
2 |
3 | from typing import Dict, Iterable, Literal, Optional, Tuple, Union
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | from scvi import REGISTRY_KEYS
9 | from scvi.distributions import NegativeBinomial, NegativeBinomialMixture
10 | from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data
11 | from scvi.nn import DecoderTOTALVI, EncoderTOTALVI, one_hot
12 | from torch.distributions import Normal
13 | from torch.distributions import kl_divergence as kl
14 |
15 | torch.backends.cudnn.benchmark = True
16 |
17 |
18 | class TotalContrastiveVIModule(BaseModuleClass):
19 | """
20 | PyTorch module for total-contrastiveVI (contrastive analysis for CITE-seq).
21 |
22 | Args:
23 | ----
24 | n_input_genes: Number of input genes.
25 | n_input_proteins: Number of input proteins.
26 | n_batch: Number of batches. If 0, no batch effect correction is performed.
27 | n_hidden: Number of nodes per hidden layer.
28 | n_background_latent: Dimensionality of the background latent space.
29 | n_salient_latent: Dimensionality of the salient latent space.
30 | n_layers: Number of hidden layers used for encoder and decoder NNs.
31 | dropout_rate: Dropout rate for neural networks.
32 | protein_batch_mask: Dictionary where each key is a batch code, and value is for
33 | each protein, whether it was observed or not.
34 | use_observed_lib_size: Use observed library size for RNA as scaling factor in
35 | mean of conditional distribution.
36 | library_log_means: 1 x n_batch array of means of the log library sizes.
37 | Parameterize prior on library size if not using observed library size.
38 | library_log_vars: 1 x n_batch array of variances of the log library sizes.
39 | Parameterize prior on library size if not using observed library size.
40 | wasserstein_penalty: Weight of the Wasserstein distance loss that further
41 | discourages shared variations from leaking into the salient latent space.
42 | """
43 |
44 | def __init__(
45 | self,
46 | n_input_genes: int,
47 | n_input_proteins: int,
48 | n_batch: int = 0,
49 | n_labels: int = 0,
50 | n_hidden: int = 128,
51 | n_background_latent: int = 10,
52 | n_salient_latent: int = 10,
53 | n_layers_encoder: int = 2,
54 | n_layers_decoder: int = 1,
55 | n_continuous_cov: int = 0,
56 | n_cats_per_cov: Optional[Iterable[int]] = None,
57 | dropout_rate_decoder: float = 0.2,
58 | dropout_rate_encoder: float = 0.2,
59 | gene_dispersion: str = "gene",
60 | protein_dispersion: str = "protein",
61 | log_variational: bool = True,
62 | gene_likelihood: str = "nb",
63 | latent_distribution: str = "normal",
64 | protein_batch_mask: Optional[Dict[Union[str, int], np.ndarray]] = None,
65 | encode_covariates: bool = True,
66 | protein_background_prior_mean: Optional[np.ndarray] = None,
67 | protein_background_prior_scale: Optional[np.ndarray] = None,
68 | use_size_factor_key: bool = False,
69 | use_observed_lib_size: bool = True,
70 | library_log_means: Optional[np.ndarray] = None,
71 | library_log_vars: Optional[np.ndarray] = None,
72 | use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both",
73 | use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none",
74 | wasserstein_penalty: float = 0,
75 | ) -> None:
76 | super().__init__()
77 | self.gene_dispersion = gene_dispersion
78 | self.n_background_latent = n_background_latent
79 | self.n_salient_latent = n_salient_latent
80 | self.log_variational = log_variational
81 | self.gene_likelihood = gene_likelihood
82 | self.n_batch = n_batch
83 | self.n_labels = n_labels
84 | self.n_input_genes = n_input_genes
85 | self.n_input_proteins = n_input_proteins
86 | self.protein_dispersion = protein_dispersion
87 | self.latent_distribution = latent_distribution
88 | self.protein_batch_mask = protein_batch_mask
89 | self.encode_covariates = encode_covariates
90 | self.use_size_factor_key = use_size_factor_key
91 | self.use_observed_lib_size = use_size_factor_key or use_observed_lib_size
92 | self.wasserstein_penalty = wasserstein_penalty
93 |
94 | if not self.use_observed_lib_size:
95 | if library_log_means is None or library_log_vars is None:
96 | raise ValueError(
97 | "If not using observed_lib_size, "
98 | "must provide library_log_means and library_log_vars."
99 | )
100 |
101 | self.register_buffer(
102 | "library_log_means", torch.from_numpy(library_log_means).float()
103 | )
104 | self.register_buffer(
105 | "library_log_vars", torch.from_numpy(library_log_vars).float()
106 | )
107 |
108 | # parameters for prior on rate_back (background protein mean)
109 | if protein_background_prior_mean is None:
110 | if n_batch > 0:
111 | self.background_pro_alpha = torch.nn.Parameter(
112 | torch.randn(n_input_proteins, n_batch)
113 | )
114 | self.background_pro_log_beta = torch.nn.Parameter(
115 | torch.clamp(torch.randn(n_input_proteins, n_batch), -10, 1)
116 | )
117 | else:
118 | self.background_pro_alpha = torch.nn.Parameter(
119 | torch.randn(n_input_proteins)
120 | )
121 | self.background_pro_log_beta = torch.nn.Parameter(
122 | torch.clamp(torch.randn(n_input_proteins), -10, 1)
123 | )
124 | else:
125 | if protein_background_prior_mean.shape[1] == 1 and n_batch != 1:
126 | init_mean = protein_background_prior_mean.ravel()
127 | init_scale = protein_background_prior_scale.ravel()
128 | else:
129 | init_mean = protein_background_prior_mean
130 | init_scale = protein_background_prior_scale
131 | self.background_pro_alpha = torch.nn.Parameter(
132 | torch.from_numpy(init_mean.astype(np.float32))
133 | )
134 | self.background_pro_log_beta = torch.nn.Parameter(
135 | torch.log(torch.from_numpy(init_scale.astype(np.float32)))
136 | )
137 |
138 | if self.gene_dispersion == "gene":
139 | self.px_r = torch.nn.Parameter(torch.randn(n_input_genes))
140 | elif self.gene_dispersion == "gene-batch":
141 | self.px_r = torch.nn.Parameter(torch.randn(n_input_genes, n_batch))
142 | elif self.gene_dispersion == "gene-label":
143 | self.px_r = torch.nn.Parameter(torch.randn(n_input_genes, n_labels))
144 | else: # gene-cell
145 | pass
146 |
147 | if self.protein_dispersion == "protein":
148 | self.py_r = torch.nn.Parameter(2 * torch.rand(self.n_input_proteins))
149 | elif self.protein_dispersion == "protein-batch":
150 | self.py_r = torch.nn.Parameter(
151 | 2 * torch.rand(self.n_input_proteins, n_batch)
152 | )
153 | elif self.protein_dispersion == "protein-label":
154 | self.py_r = torch.nn.Parameter(
155 | 2 * torch.rand(self.n_input_proteins, n_labels)
156 | )
157 | else: # protein-cell
158 | pass
159 |
160 | use_batch_norm_encoder = use_batch_norm == "encoder" or use_batch_norm == "both"
161 | use_batch_norm_decoder = use_batch_norm == "decoder" or use_batch_norm == "both"
162 | use_layer_norm_encoder = use_layer_norm == "encoder" or use_layer_norm == "both"
163 | use_layer_norm_decoder = use_layer_norm == "decoder" or use_layer_norm == "both"
164 |
165 | # z encoder goes from the n_input-dimensional data to an n_latent-d
166 | # latent space representation
167 | n_input = n_input_genes + self.n_input_proteins
168 | n_input_encoder = n_input + n_continuous_cov * encode_covariates
169 | cat_list = [n_batch] + list([] if n_cats_per_cov is None else n_cats_per_cov)
170 | encoder_cat_list = cat_list if encode_covariates else None
171 | self.z_encoder = EncoderTOTALVI(
172 | n_input_encoder,
173 | n_background_latent,
174 | n_layers=n_layers_encoder,
175 | n_cat_list=encoder_cat_list,
176 | n_hidden=n_hidden,
177 | dropout_rate=dropout_rate_encoder,
178 | distribution=latent_distribution,
179 | use_batch_norm=use_batch_norm_encoder,
180 | use_layer_norm=use_layer_norm_encoder,
181 | )
182 |
183 | self.s_encoder = EncoderTOTALVI(
184 | n_input_encoder,
185 | n_salient_latent,
186 | n_layers=n_layers_encoder,
187 | n_cat_list=encoder_cat_list,
188 | n_hidden=n_hidden,
189 | dropout_rate=dropout_rate_encoder,
190 | distribution=latent_distribution,
191 | use_batch_norm=use_batch_norm_encoder,
192 | use_layer_norm=use_layer_norm_encoder,
193 | )
194 | n_total_latent = n_background_latent + n_salient_latent
195 | self.decoder = DecoderTOTALVI(
196 | n_total_latent + n_continuous_cov,
197 | n_input_genes,
198 | self.n_input_proteins,
199 | n_layers=n_layers_decoder,
200 | n_cat_list=cat_list,
201 | n_hidden=n_hidden,
202 | dropout_rate=dropout_rate_decoder,
203 | use_batch_norm=use_batch_norm_decoder,
204 | use_layer_norm=use_layer_norm_decoder,
205 | scale_activation="softplus" if use_size_factor_key else "softmax",
206 | )
207 |
208 | @auto_move_data
209 | def _compute_local_library_params(
210 | self, batch_index: torch.Tensor
211 | ) -> Tuple[torch.Tensor, torch.Tensor]:
212 | n_batch = self.library_log_means.shape[1]
213 | local_library_log_means = F.linear(
214 | one_hot(batch_index, n_batch), self.library_log_means
215 | )
216 | local_library_log_vars = F.linear(
217 | one_hot(batch_index, n_batch), self.library_log_vars
218 | )
219 | return local_library_log_means, local_library_log_vars
220 |
221 | @staticmethod
222 | def _get_min_batch_size(concat_tensors: Dict[str, Dict[str, torch.Tensor]]) -> int:
223 | return min(
224 | concat_tensors["background"][REGISTRY_KEYS.X_KEY].shape[0],
225 | concat_tensors["target"][REGISTRY_KEYS.X_KEY].shape[0],
226 | )
227 |
228 | @staticmethod
229 | def _reduce_tensors_to_min_batch_size(
230 | tensors: Dict[str, torch.Tensor], min_batch_size: int
231 | ) -> None:
232 | for name, tensor in tensors.items():
233 | tensors[name] = tensor[:min_batch_size, :]
234 |
235 | @staticmethod
236 | def _get_inference_input_from_concat_tensors(
237 | concat_tensors: Dict[str, Dict[str, torch.Tensor]], index: str
238 | ) -> Dict[str, torch.Tensor]:
239 | tensors = concat_tensors[index]
240 | x = tensors[REGISTRY_KEYS.X_KEY]
241 | y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY]
242 | batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
243 | input_dict = dict(x=x, y=y, batch_index=batch_index)
244 | return input_dict
245 |
246 | def _get_inference_input(
247 | self, concat_tensors: Dict[str, Dict[str, torch.Tensor]]
248 | ) -> Dict[str, Dict[str, torch.Tensor]]:
249 | background = self._get_inference_input_from_concat_tensors(
250 | concat_tensors, "background"
251 | )
252 | target = self._get_inference_input_from_concat_tensors(concat_tensors, "target")
253 | # Ensure batch sizes are the same.
254 | min_batch_size = self._get_min_batch_size(concat_tensors)
255 | self._reduce_tensors_to_min_batch_size(background, min_batch_size)
256 | self._reduce_tensors_to_min_batch_size(target, min_batch_size)
257 | return dict(background=background, target=target)
258 |
259 | @staticmethod
260 | def _get_generative_input_from_concat_tensors(
261 | concat_tensors: Dict[str, Dict[str, torch.Tensor]], index: str
262 | ) -> Dict[str, torch.Tensor]:
263 | tensors = concat_tensors[index]
264 | batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
265 | input_dict = dict(batch_index=batch_index)
266 | return input_dict
267 |
268 | @staticmethod
269 | def _get_generative_input_from_inference_outputs(
270 | inference_outputs: Dict[str, Dict[str, torch.Tensor]], data_source: str
271 | ) -> Dict[str, torch.Tensor]:
272 | z = inference_outputs[data_source]["z"]
273 | s = inference_outputs[data_source]["s"]
274 | library_gene = inference_outputs[data_source]["library_gene"]
275 | return dict(z=z, s=s, library_gene=library_gene)
276 |
277 | def _get_generative_input(
278 | self,
279 | concat_tensors: Dict[str, Dict[str, torch.Tensor]],
280 | inference_outputs: Dict[str, Dict[str, torch.Tensor]],
281 | ) -> Dict[str, Dict[str, torch.Tensor]]:
282 | background_tensor_input = self._get_generative_input_from_concat_tensors(
283 | concat_tensors, "background"
284 | )
285 | target_tensor_input = self._get_generative_input_from_concat_tensors(
286 | concat_tensors, "target"
287 | )
288 | # Ensure batch sizes are the same.
289 | min_batch_size = self._get_min_batch_size(concat_tensors)
290 | self._reduce_tensors_to_min_batch_size(background_tensor_input, min_batch_size)
291 | self._reduce_tensors_to_min_batch_size(target_tensor_input, min_batch_size)
292 |
293 | background_inference_outputs = (
294 | self._get_generative_input_from_inference_outputs(
295 | inference_outputs, "background"
296 | )
297 | )
298 | target_inference_outputs = self._get_generative_input_from_inference_outputs(
299 | inference_outputs, "target"
300 | )
301 | background = {**background_tensor_input, **background_inference_outputs}
302 | target = {**target_tensor_input, **target_inference_outputs}
303 | return dict(background=background, target=target)
304 |
305 | @staticmethod
306 | def _reshape_tensor_for_samples(tensor: torch.Tensor, n_samples: int):
307 | return tensor.unsqueeze(0).expand((n_samples, tensor.size(0), tensor.size(1)))
308 |
309 | @auto_move_data
310 | def _generic_inference(
311 | self,
312 | x: torch.Tensor,
313 | y: torch.Tensor,
314 | batch_index: torch.Tensor,
315 | n_samples: int = 1,
316 | ) -> Dict[str, torch.Tensor]:
317 | x_ = x
318 | y_ = y
319 | if self.use_observed_lib_size:
320 | library_gene = x.sum(1).unsqueeze(1)
321 | x_ = torch.log(1 + x_)
322 | y_ = torch.log(1 + y_)
323 | encoder_input = torch.cat((x_, y_), dim=-1)
324 |
325 | (
326 | qz_m,
327 | qz_v,
328 | ql_m,
329 | ql_v,
330 | background_latent,
331 | untran_background_latent,
332 | ) = self.z_encoder(encoder_input, batch_index)
333 | z = background_latent["z"]
334 | untran_z = untran_background_latent["z"]
335 | untran_l = untran_background_latent["l"] # Library encoder used and updated.
336 | (qs_m, qs_v, _, _, salient_latent, untran_salient_latent) = self.s_encoder(
337 | encoder_input, batch_index
338 | )
339 | s = salient_latent["z"]
340 | untran_s = untran_salient_latent["z"]
341 | # Library encoder not used and not updated.
342 |
343 | if not self.use_observed_lib_size:
344 | library_gene = background_latent["l"]
345 |
346 | if n_samples > 1:
347 | qz_m = self._reshape_tensor_for_samples(qz_m, n_samples)
348 | qz_v = self._reshape_tensor_for_samples(qz_v, n_samples)
349 | untran_z = Normal(qz_m, qz_v.sqrt()).sample()
350 | z = self.z_encoder.z_transformation(untran_z)
351 |
352 | qs_m = self._reshape_tensor_for_samples(qs_m, n_samples)
353 | qs_v = self._reshape_tensor_for_samples(qs_v, n_samples)
354 | untran_s = Normal(qs_m, qs_v.sqrt()).sample()
355 | s = self.s_encoder.z_transformation(untran_s)
356 |
357 | ql_m = self._reshape_tensor_for_samples(ql_m, n_samples)
358 | ql_v = self._reshape_tensor_for_samples(ql_v, n_samples)
359 | untran_l = Normal(ql_m, ql_v.sqrt()).sample()
360 |
361 | if self.use_observed_lib_size:
362 | library_gene = self._reshape_tensor_for_samples(library_gene, n_samples)
363 | else:
364 | library_gene = self.z_encoder.l_transformation(untran_l)
365 |
366 | if self.n_batch > 0:
367 | py_back_alpha_prior = F.linear(
368 | one_hot(batch_index, self.n_batch), self.background_pro_alpha
369 | )
370 | py_back_beta_prior = F.linear(
371 | one_hot(batch_index, self.n_batch),
372 | torch.exp(self.background_pro_log_beta),
373 | )
374 | else:
375 | py_back_alpha_prior = self.background_pro_alpha
376 | py_back_beta_prior = torch.exp(self.background_pro_log_beta)
377 |
378 | back_mean_prior = Normal(py_back_alpha_prior, py_back_beta_prior)
379 |
380 | outputs = dict(
381 | untran_z=untran_z,
382 | z=z,
383 | qz_m=qz_m,
384 | qz_v=qz_v,
385 | untran_s=untran_s,
386 | s=s,
387 | qs_m=qs_m,
388 | qs_v=qs_v,
389 | library_gene=library_gene,
390 | ql_m=ql_m,
391 | ql_v=ql_v,
392 | untran_l=untran_l,
393 | back_mean_prior=back_mean_prior,
394 | )
395 | return outputs
396 |
397 | @auto_move_data
398 | def inference(
399 | self,
400 | background: Dict[str, torch.Tensor],
401 | target: Dict[str, torch.Tensor],
402 | n_samples: int = 1,
403 | ) -> Dict[str, Dict[str, torch.Tensor]]:
404 | background_outputs = self._generic_inference(**background, n_samples=n_samples)
405 | target_outputs = self._generic_inference(**target, n_samples=n_samples)
406 | background_outputs["s"] = torch.zeros_like(background_outputs["s"])
407 | return dict(background=background_outputs, target=target_outputs)
408 |
409 | @auto_move_data
410 | def _generic_generative(
411 | self,
412 | z: torch.Tensor,
413 | s: torch.Tensor,
414 | library_gene: torch.Tensor,
415 | batch_index: torch.Tensor,
416 | ) -> Dict[str, torch.Tensor]:
417 | latent = torch.cat([z, s], dim=-1)
418 | px_, py_, log_pro_back_mean = self.decoder(
419 | latent,
420 | library_gene,
421 | batch_index,
422 | )
423 | px_r = torch.exp(self.px_r)
424 | py_r = torch.exp(self.py_r)
425 | px_["r"] = px_r
426 | py_["r"] = py_r
427 | return dict(px_=px_, py_=py_, log_pro_back_mean=log_pro_back_mean)
428 |
429 | @auto_move_data
430 | def generative(
431 | self,
432 | background: Dict[str, torch.Tensor],
433 | target: Dict[str, torch.Tensor],
434 | ) -> Dict[str, Dict[str, torch.Tensor]]:
435 | latent_z_shape = background["z"].shape
436 | batch_size_dim = 0 if len(latent_z_shape) == 2 else 1
437 | background_batch_size = background["z"].shape[batch_size_dim]
438 | target_batch_size = target["z"].shape[batch_size_dim]
439 | generative_input = {}
440 | for key in ["z", "s", "library_gene"]:
441 | generative_input[key] = torch.cat(
442 | [background[key], target[key]], dim=batch_size_dim
443 | )
444 | generative_input["batch_index"] = torch.cat(
445 | [background["batch_index"], target["batch_index"]], dim=0
446 | )
447 | outputs = self._generic_generative(**generative_input)
448 |
449 | # Split outputs into corresponding background and target set.
450 | background_outputs = {"px_": {}, "py_": {}}
451 | target_outputs = {"px_": {}, "py_": {}}
452 | for modality in ["px_", "py_"]:
453 | for key in outputs[modality].keys():
454 | if key == "r":
455 | background_tensor = outputs[modality][key]
456 | target_tensor = outputs[modality][key]
457 | else:
458 | if outputs[modality][key] is not None:
459 | background_tensor, target_tensor = torch.split(
460 | outputs[modality][key],
461 | [background_batch_size, target_batch_size],
462 | dim=batch_size_dim,
463 | )
464 | else:
465 | background_tensor, target_tensor = None, None
466 | background_outputs[modality][key] = background_tensor
467 | target_outputs[modality][key] = target_tensor
468 |
469 | if outputs["log_pro_back_mean"] is not None:
470 | background_tensor, target_tensor = torch.split(
471 | outputs["log_pro_back_mean"],
472 | [background_batch_size, target_batch_size],
473 | dim=batch_size_dim,
474 | )
475 | else:
476 | background_tensor, target_tensor = None, None
477 | background_outputs["log_pro_back_mean"] = background_tensor
478 | target_outputs["log_pro_back_mean"] = target_tensor
479 |
480 | return dict(background=background_outputs, target=target_outputs)
481 |
482 | @staticmethod
483 | def get_reconstruction_loss(
484 | x: torch.Tensor,
485 | y: torch.Tensor,
486 | px_dict: Dict[str, torch.Tensor],
487 | py_dict: Dict[str, torch.Tensor],
488 | pro_batch_mask_minibatch: Optional[torch.Tensor] = None,
489 | ) -> Tuple[torch.Tensor, torch.Tensor]:
490 | """Compute reconstruction loss."""
491 | px_ = px_dict
492 | py_ = py_dict
493 |
494 | reconst_loss_gene = (
495 | -NegativeBinomial(mu=px_["rate"], theta=px_["r"]).log_prob(x).sum(dim=-1)
496 | )
497 |
498 | py_conditional = NegativeBinomialMixture(
499 | mu1=py_["rate_back"],
500 | mu2=py_["rate_fore"],
501 | theta1=py_["r"],
502 | mixture_logits=py_["mixing"],
503 | )
504 | reconst_loss_protein_full = -py_conditional.log_prob(y)
505 | if pro_batch_mask_minibatch is not None:
506 | temp_pro_loss_full = torch.zeros_like(reconst_loss_protein_full)
507 | temp_pro_loss_full.masked_scatter_(
508 | pro_batch_mask_minibatch.bool(), reconst_loss_protein_full
509 | )
510 |
511 | reconst_loss_protein = temp_pro_loss_full.sum(dim=-1)
512 | else:
513 | reconst_loss_protein = reconst_loss_protein_full.sum(dim=-1)
514 |
515 | return reconst_loss_gene, reconst_loss_protein
516 |
517 | @staticmethod
518 | def latent_kl_divergence(
519 | variational_mean: torch.Tensor,
520 | variational_var: torch.Tensor,
521 | prior_mean: torch.Tensor,
522 | prior_var: torch.Tensor,
523 | ) -> torch.Tensor:
524 | """
525 | Compute KL divergence between a variational posterior and prior Gaussian.
526 | Args:
527 | ----
528 | variational_mean: Mean of the variational posterior Gaussian.
529 | variational_var: Variance of the variational posterior Gaussian.
530 | prior_mean: Mean of the prior Gaussian.
531 | prior_var: Variance of the prior Gaussian.
532 |
533 | Returns
534 | -------
535 | KL divergence for each data point. If number of latent samples == 1,
536 | the tensor has shape `(batch_size, )`. If number of latent
537 | samples > 1, the tensor has shape `(n_samples, batch_size)`.
538 | """
539 | return kl(
540 | Normal(variational_mean, variational_var.sqrt()),
541 | Normal(prior_mean, prior_var.sqrt()),
542 | ).sum(dim=-1)
543 |
544 | def library_gene_kl_divergence(
545 | self,
546 | batch_index: torch.Tensor,
547 | variational_library_mean: torch.Tensor,
548 | variational_library_var: torch.Tensor,
549 | library: torch.Tensor,
550 | ) -> torch.Tensor:
551 | """
552 | Compute KL divergence between library size variational posterior and prior.
553 |
554 | Both the variational posterior and prior are Log-Normal.
555 | Args:
556 | ----
557 | batch_index: Batch indices for batch-specific library size mean and
558 | variance.
559 | variational_library_mean: Mean of variational Log-Normal.
560 | variational_library_var: Variance of variational Log-Normal.
561 | library: Sampled library size.
562 |
563 | Returns
564 | -------
565 | KL divergence for each data point. If number of latent samples == 1,
566 | the tensor has shape `(batch_size, )`. If number of latent
567 | samples > 1, the tensor has shape `(n_samples, batch_size)`.
568 | """
569 | if not self.use_observed_lib_size:
570 | (
571 | local_library_log_means,
572 | local_library_log_vars,
573 | ) = self._compute_local_library_params(batch_index)
574 |
575 | kl_library = kl(
576 | Normal(variational_library_mean, variational_library_var.sqrt()),
577 | Normal(local_library_log_means, local_library_log_vars.sqrt()),
578 | )
579 | else:
580 | kl_library = torch.zeros_like(library)
581 | return kl_library.sum(dim=-1)
582 |
583 | def _generic_loss(
584 | self,
585 | tensors: Dict[str, torch.Tensor],
586 | inference_outputs: Dict[str, torch.Tensor],
587 | generative_outputs: Dict[str, torch.Tensor],
588 | ) -> Dict[str, torch.Tensor]:
589 | x = tensors[REGISTRY_KEYS.X_KEY]
590 | y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY]
591 | batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
592 |
593 | qz_m = inference_outputs["qz_m"]
594 | qz_v = inference_outputs["qz_v"]
595 | qs_m = inference_outputs["qs_m"]
596 | qs_v = inference_outputs["qs_v"]
597 | ql_m = inference_outputs["ql_m"]
598 | ql_v = inference_outputs["ql_v"]
599 | library_gene = inference_outputs["library_gene"]
600 | px_ = generative_outputs["px_"]
601 | py_ = generative_outputs["py_"]
602 |
603 | prior_z_m = torch.zeros_like(qz_m)
604 | prior_z_v = torch.ones_like(qz_v)
605 | prior_s_m = torch.zeros_like(qs_m)
606 | prior_s_v = torch.ones_like(qs_v)
607 |
608 | if self.protein_batch_mask is not None:
609 | pro_batch_mask_minibatch = torch.zeros_like(y)
610 | for b in torch.unique(batch_index):
611 | b_indices = (batch_index == b).reshape(-1)
612 | pro_batch_mask_minibatch[b_indices] = torch.tensor(
613 | self.protein_batch_mask[b.item()].astype(np.float32),
614 | device=y.device,
615 | )
616 | else:
617 | pro_batch_mask_minibatch = None
618 | reconst_loss_gene, reconst_loss_protein = self.get_reconstruction_loss(
619 | x, y, px_, py_, pro_batch_mask_minibatch
620 | )
621 |
622 | kl_z = self.latent_kl_divergence(qz_m, qz_v, prior_z_m, prior_z_v)
623 | kl_s = self.latent_kl_divergence(qs_m, qs_v, prior_s_m, prior_s_v)
624 | kl_library_gene = self.library_gene_kl_divergence(
625 | batch_index, ql_m, ql_v, library_gene
626 | )
627 |
628 | kl_div_back_pro_full = kl(
629 | Normal(py_["back_alpha"], py_["back_beta"]), inference_outputs["back_mean_prior"]
630 | )
631 | if pro_batch_mask_minibatch is not None:
632 | kl_div_back_pro = (pro_batch_mask_minibatch * kl_div_back_pro_full).sum(
633 | dim=-1
634 | )
635 | else:
636 | kl_div_back_pro = kl_div_back_pro_full.sum(dim=-1)
637 |
638 | return dict(
639 | reconst_loss_gene=reconst_loss_gene,
640 | reconst_loss_protein=reconst_loss_protein,
641 | kl_z=kl_z,
642 | kl_s=kl_s,
643 | kl_library_gene=kl_library_gene,
644 | kl_div_back_pro=kl_div_back_pro,
645 | )
646 |
647 | def loss(
648 | self,
649 | concat_tensors: Dict[str, Dict[str, torch.Tensor]],
650 | inference_outputs: Dict[str, Dict[str, torch.Tensor]],
651 | generative_outputs: Dict[str, Dict[str, torch.Tensor]],
652 | kl_weight: float = 1.0,
653 | ) -> LossRecorder:
654 | """
655 | Compute loss terms for total-contrastiveVI.
656 | Args:
657 | ----
658 | concat_tensors: Dictionary of data mini-batches. The "background" entry
659 | contains the background mini-batch, and the "target" entry contains
660 | the target mini-batch.
661 | inference_outputs: Dictionary of inference step outputs. The keys
662 | are "background" and "target" for the corresponding outputs.
663 | generative_outputs: Dictionary of generative step outputs. The keys
664 | are "background" and "target" for the corresponding outputs.
665 |
666 | Returns
667 | -------
668 | An scvi.module.base.LossRecorder instance that records the losses.
669 | """
670 | background_tensors = concat_tensors["background"]
671 | target_tensors = concat_tensors["target"]
672 | # Ensure batch sizes are the same.
673 | min_batch_size = self._get_min_batch_size(concat_tensors)
674 | self._reduce_tensors_to_min_batch_size(background_tensors, min_batch_size)
675 | self._reduce_tensors_to_min_batch_size(target_tensors, min_batch_size)
676 |
677 | background_losses = self._generic_loss(
678 | background_tensors,
679 | inference_outputs["background"],
680 | generative_outputs["background"],
681 | )
682 | target_losses = self._generic_loss(
683 | target_tensors,
684 | inference_outputs["target"],
685 | generative_outputs["target"],
686 | )
687 |
688 | reconst_loss_gene = (
689 | background_losses["reconst_loss_gene"] + target_losses["reconst_loss_gene"]
690 | )
691 | reconst_loss_protein = (
692 | background_losses["reconst_loss_protein"]
693 | + target_losses["reconst_loss_protein"]
694 | )
695 |
696 | wasserstein_loss = (
697 | torch.norm(inference_outputs["background"]["qs_m"], dim=-1) ** 2
698 | + torch.sum(inference_outputs["background"]["qs_v"], dim=-1)
699 | )
700 |
701 | kl_div_z = background_losses["kl_z"] + target_losses["kl_z"]
702 | kl_div_s = target_losses["kl_s"]
703 | kl_div_l_gene = (
704 | background_losses["kl_library_gene"] + target_losses["kl_library_gene"]
705 | )
706 | kl_div_back_pro = (
707 | background_losses["kl_div_back_pro"] + target_losses["kl_div_back_pro"]
708 | )
709 |
710 | loss = torch.mean(
711 | reconst_loss_gene
712 | + reconst_loss_protein
713 | + kl_weight * kl_div_z
714 | + kl_weight * kl_div_s
715 | + kl_weight * self.wasserstein_penalty * wasserstein_loss
716 | + kl_div_l_gene
717 | + kl_weight * kl_div_back_pro
718 | )
719 |
720 | reconst_losses = dict(
721 | reconst_loss_gene=reconst_loss_gene,
722 | reconst_loss_protein=reconst_loss_protein,
723 | )
724 | kl_local = dict(
725 | kl_div_z=kl_div_z,
726 | kl_div_s=kl_div_s,
727 | kl_div_l_gene=kl_div_l_gene,
728 | kl_div_back_pro=kl_div_back_pro,
729 | )
730 | kl_global = torch.tensor(0.0)
731 | return LossRecorder(
732 | loss,
733 | reconst_losses,
734 | kl_local,
735 | kl_global=kl_global,
736 | wasserstein_loss=torch.sum(wasserstein_loss)
737 | )
738 |
739 | @torch.no_grad()
740 | def sample_mean(
741 | self,
742 | tensors: Dict[str, torch.Tensor],
743 | data_source: str,
744 | n_samples: int = 1,
745 | ) -> torch.Tensor:
746 | """Sample posterior mean."""
747 | raise NotImplementedError
748 |
749 | @torch.no_grad()
750 | def sample(self):
751 | raise NotImplementedError
752 |
753 | @torch.no_grad()
754 | @auto_move_data
755 | def marginal_ll(self):
756 | raise NotImplementedError
757 |
--------------------------------------------------------------------------------
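Aside: the protein likelihood used in `get_reconstruction_loss` above (a
two-component negative binomial mixture separating background from foreground
protein signal, as in totalVI) can be exercised in isolation. A toy sketch
with made-up shapes and parameters (not repository code):

import torch
from scvi.distributions import NegativeBinomialMixture

y = torch.randint(0, 50, (4, 8)).float()  # toy counts: 4 cells, 8 proteins
py_conditional = NegativeBinomialMixture(
    mu1=torch.rand(4, 8) * 5,  # background component means
    mu2=torch.rand(4, 8) * 50,  # foreground component means
    theta1=torch.ones(4, 8),  # inverse dispersion
    mixture_logits=torch.zeros(4, 8),  # even prior odds for each component
)
reconst_loss_protein = -py_conditional.log_prob(y).sum(dim=-1)
assert reconst_loss_protein.shape == (4,)

--------------------------------------------------------------------------------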
/contrastive_vi/module/utils.py:
--------------------------------------------------------------------------------
1 | """Utilities for contrastiveVI modules."""
2 | import torch
3 |
4 |
5 | def gram_matrix(x: torch.Tensor, y: torch.Tensor, gammas: torch.Tensor) -> torch.Tensor:
6 | """
7 | Calculate the maximum mean discrepancy gram matrix with multiple gamma values.
8 |
9 | Args:
10 | ----
11 | x: Tensor with shape (B, P, M) or (P, M).
12 | y: Tensor with shape (B, R, M) or (R, M).
13 | gammas: 1-D tensor with the gamma values.
14 |
15 | Returns
16 | -------
17 | A tensor with shape (B, P, R) or (P, R) with the multi-scale RBF kernel values
18 | (summed over `gammas`) for pairs of data points in `x` and `y`.
19 | """
20 | gammas = gammas.unsqueeze(1)
21 | pairwise_distances = torch.cdist(x, y, p=2.0)
22 |
23 | pairwise_distances_sq = torch.square(pairwise_distances)
24 | tmp = torch.matmul(gammas, torch.reshape(pairwise_distances_sq, (1, -1)))
25 | tmp = torch.reshape(torch.sum(torch.exp(-tmp), 0), pairwise_distances_sq.shape)
26 | return tmp
27 |
--------------------------------------------------------------------------------
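Aside: `gram_matrix` evaluates a multi-scale RBF kernel, and a standard use of
such kernels is a (biased) estimate of the maximum mean discrepancy (MMD)
between two samples. An illustrative sketch (not repository code; the shift
applied to `y` is arbitrary):

import torch

from contrastive_vi.module.utils import gram_matrix

x = torch.randn(50, 10)  # sample from distribution P
y = torch.randn(60, 10) + 1.0  # sample from a shifted distribution Q
gammas = torch.tensor([0.1, 1.0, 10.0])  # kernel bandwidths

# Biased MMD^2 estimate: E[k(x, x')] - 2 E[k(x, y)] + E[k(y, y')].
mmd_sq = (
    gram_matrix(x, x, gammas).mean()
    - 2 * gram_matrix(x, y, gammas).mean()
    + gram_matrix(y, y, gammas).mean()
)
print(float(mmd_sq))  # noticeably positive for these shifted samples

--------------------------------------------------------------------------------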
/environment.yml:
--------------------------------------------------------------------------------
1 | name: contrastive-vi-env
2 | channels:
3 | - defaults
4 | - conda-forge
5 | dependencies:
6 | - python==3.9.*
7 | - pip
8 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [tool.black]
6 | line-length = 88
7 |
8 | [tool.isort]
9 | profile = "black"
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = contrastive-vi
3 | version = 0.2.0
4 |
5 | [options]
6 | package_dir =
7 | packages =
8 | find:
9 | python_requires =
10 | >=3.7
11 | install_requires =
12 | scanpy>=1.8.1
13 | protobuf<=3.20.1
14 | scvi-tools>=0.15.0, <=0.16.2
15 |
16 | [options.packages.find]
17 | where =
18 |
19 | [options.extras_require]
20 | dev =
21 | pre-commit==2.15.0
22 | toml==0.10.2
23 | pytest==6.2.5
24 |
--------------------------------------------------------------------------------
/sketch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/suinleelab/contrastiveVI/2835d2925f7ef60cdbebd2422435046cb9165f24/sketch.png
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/suinleelab/contrastiveVI/2835d2925f7ef60cdbebd2422435046cb9165f24/tests/__init__.py
--------------------------------------------------------------------------------
/tests/data/dataloaders/test_contrastive_dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from scvi import REGISTRY_KEYS
3 |
4 | from contrastive_vi.data.dataloaders.contrastive_dataloader import ContrastiveDataLoader
5 | from tests.utils import get_next_batch
6 |
7 |
8 | class TestContrastiveDataLoader:
9 | def test_one_batch(
10 | self,
11 | mock_adata,
12 | mock_adata_manager,
13 | mock_adata_background_indices,
14 | mock_adata_background_label,
15 | mock_adata_target_indices,
16 | mock_adata_target_label,
17 | ):
18 | batch_size = 32
19 | dataloader = ContrastiveDataLoader(
20 | mock_adata_manager,
21 | mock_adata_background_indices,
22 | mock_adata_target_indices,
23 | batch_size=batch_size,
24 | shuffle=False,
25 | )
26 | batch = get_next_batch(dataloader)
27 | assert type(batch) == dict
28 | assert len(batch.keys()) == 2
29 | assert "background" in batch.keys()
30 | assert "target" in batch.keys()
31 |
32 | expected_background_data = torch.Tensor(
33 | mock_adata.layers["raw_counts"][mock_adata_background_indices, :][
34 | :batch_size, :
35 | ]
36 | )
37 | expected_target_data = torch.Tensor(
38 | mock_adata.layers["raw_counts"][mock_adata_target_indices, :][
39 | :batch_size, :
40 | ]
41 | )
42 |
43 | assert torch.equal(
44 | batch["background"][REGISTRY_KEYS.X_KEY], expected_background_data
45 | )
46 | assert torch.equal(batch["target"][REGISTRY_KEYS.X_KEY], expected_target_data)
47 |
48 | assert (
49 | batch["background"][REGISTRY_KEYS.BATCH_KEY] == mock_adata_background_label
50 | ).sum() == batch_size
51 | assert (
52 | batch["target"][REGISTRY_KEYS.BATCH_KEY] == mock_adata_target_label
53 | ).sum() == batch_size
54 |
--------------------------------------------------------------------------------
/tests/data/dataloaders/test_data_splitting.py:
--------------------------------------------------------------------------------
1 | from contrastive_vi.data.dataloaders.data_splitting import ContrastiveDataSplitter
2 |
3 |
4 | class TestContrastiveDataSplitter:
5 | def test_num_batches(
6 | self,
7 | mock_adata_manager,
8 | mock_adata_background_indices,
9 | mock_adata_target_indices,
10 | ) -> None:
11 | train_size = 0.8
12 | validation_size = 0.1
13 | test_size = 0.1
14 | batch_size = 20
15 | n_max = max(
16 | len(mock_adata_background_indices),
17 | len(mock_adata_target_indices),
18 | )
19 | expected_train_num_batches = n_max * train_size / batch_size
20 | expected_val_num_batches = n_max * validation_size / batch_size
21 | expected_test_num_batches = n_max * test_size / batch_size
22 |
23 | data_splitter = ContrastiveDataSplitter(
24 | mock_adata_manager,
25 | mock_adata_background_indices,
26 | mock_adata_target_indices,
27 | train_size=train_size,
28 | validation_size=validation_size,
29 | batch_size=batch_size,
30 | )
31 | data_splitter.setup()
32 | train_dataloader = data_splitter.train_dataloader()
33 | val_dataloader = data_splitter.val_dataloader()
34 | test_dataloader = data_splitter.test_dataloader()
35 |
36 | assert len(train_dataloader) == expected_train_num_batches
37 | assert len(val_dataloader) == expected_val_num_batches
38 | assert len(test_dataloader) == expected_test_num_batches
39 |
--------------------------------------------------------------------------------
/tests/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/suinleelab/contrastiveVI/2835d2925f7ef60cdbebd2422435046cb9165f24/tests/model/__init__.py
--------------------------------------------------------------------------------
/tests/model/test_contrastive_vi.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import pytest
4 | import torch
5 |
6 | from contrastive_vi.model.contrastive_vi import ContrastiveVIModel
7 | from tests.utils import copy_module_state_dict
8 |
9 |
10 | @pytest.fixture(
11 | params=[True, False], ids=["with_observed_lib_size", "without_observed_lib_size"]
12 | )
13 | def mock_contrastive_vi_model(
14 | mock_adata,
15 | mock_adata_background_indices,
16 | mock_adata_target_indices,
17 | mock_library_log_means,
18 | mock_library_log_vars,
19 | request,
20 | ):
21 | if request.param:
22 | return ContrastiveVIModel(
23 | mock_adata,
24 | n_hidden=16,
25 | n_background_latent=4,
26 | n_salient_latent=4,
27 | n_layers=2,
28 | use_observed_lib_size=True,
29 | )
30 | else:
31 | return ContrastiveVIModel(
32 | mock_adata,
33 | n_hidden=16,
34 | n_background_latent=4,
35 | n_salient_latent=4,
36 | n_layers=2,
37 | use_observed_lib_size=False,
38 | )
39 |
40 |
41 | class TestContrastiveVIModel:
42 | def test_train(
43 | self,
44 | mock_contrastive_vi_model,
45 | mock_adata_background_indices,
46 | mock_adata_target_indices,
47 | ):
48 | init_state_dict = copy_module_state_dict(mock_contrastive_vi_model.module)
49 | mock_contrastive_vi_model.train(
50 | background_indices=mock_adata_background_indices,
51 | target_indices=mock_adata_target_indices,
52 | max_epochs=10,
53 | batch_size=20, # Unequal final batches to test edge case.
54 | use_gpu=False,
55 | )
56 | trained_state_dict = copy_module_state_dict(mock_contrastive_vi_model.module)
57 | for param_key in mock_contrastive_vi_model.module.state_dict().keys():
58 | is_library_param = (
59 | param_key == "library_log_means" or param_key == "library_log_vars"
60 | )
61 | is_px_r_decoder_param = "px_r_decoder" in param_key
62 | is_l_encoder_param = "l_encoder" in param_key
63 |
64 | if (
65 | is_library_param
66 | or is_px_r_decoder_param
67 | or (
68 | is_l_encoder_param
69 | and mock_contrastive_vi_model.module.use_observed_lib_size
70 | )
71 | ):
72 | # There are three cases where parameters are not updated.
73 | # 1. Library means and vars are derived from input data and should
74 | # not be updated.
75 | # 2. In ContrastiveVIModel, dispersion is assumed to be gene-dependent
76 | # but not cell-dependent, so parameters in the dispersion (px_r)
77 | # decoder are not used and should not be updated.
78 | # 3. When observed library size is used, the library encoder is not
79 | # used and its parameters not updated.
80 | assert torch.equal(
81 | init_state_dict[param_key], trained_state_dict[param_key]
82 | )
83 | else:
84 | # Other parameters should be updated after training.
85 | assert not torch.equal(
86 | init_state_dict[param_key], trained_state_dict[param_key]
87 | )
88 |
89 | @pytest.mark.parametrize("representation_kind", ["background", "salient"])
90 | def test_get_latent_representation(
91 | self, mock_contrastive_vi_model, representation_kind
92 | ):
93 | n_cells = mock_contrastive_vi_model.adata.n_obs
94 | if representation_kind == "background":
95 | n_latent = mock_contrastive_vi_model.module.n_background_latent
96 | else:
97 | n_latent = mock_contrastive_vi_model.module.n_salient_latent
98 | representation = mock_contrastive_vi_model.get_latent_representation(
99 | representation_kind=representation_kind
100 | )
101 | assert representation.shape == (n_cells, n_latent)
102 |
103 | @pytest.mark.parametrize("representation_kind", ["background", "salient"])
104 | def test_get_normalized_expression(
105 | self, mock_contrastive_vi_model, representation_kind
106 | ):
107 | n_samples = 50
108 | n_cells = mock_contrastive_vi_model.adata.n_obs
109 | n_genes = mock_contrastive_vi_model.adata.n_vars
110 | one_sample_exprs = mock_contrastive_vi_model.get_normalized_expression(
111 | n_samples=1, return_numpy=True
112 | )
113 | one_sample_exprs = one_sample_exprs[representation_kind]
114 | assert type(one_sample_exprs) == np.ndarray
115 | assert one_sample_exprs.shape == (n_cells, n_genes)
116 |
117 | many_sample_exprs = mock_contrastive_vi_model.get_normalized_expression(
118 | n_samples=n_samples,
119 | return_mean=False,
120 | )
121 | many_sample_exprs = many_sample_exprs[representation_kind]
122 | assert type(many_sample_exprs) == np.ndarray
123 | assert many_sample_exprs.shape == (n_samples, n_cells, n_genes)
124 |
125 | exprs_df = mock_contrastive_vi_model.get_normalized_expression(
126 | n_samples=1,
127 | return_numpy=False,
128 | )
129 | exprs_df = exprs_df[representation_kind]
130 | assert type(exprs_df) == pd.DataFrame
131 | assert exprs_df.shape == (n_cells, n_genes)
132 |
133 | def test_get_salient_normalized_expression(self, mock_contrastive_vi_model):
134 | n_samples = 50
135 | n_cells = mock_contrastive_vi_model.adata.n_obs
136 | n_genes = mock_contrastive_vi_model.adata.n_vars
137 |
138 | one_sample_expr = mock_contrastive_vi_model.get_salient_normalized_expression(
139 | n_samples=1, return_numpy=True
140 | )
141 | assert type(one_sample_expr) == np.ndarray
142 | assert one_sample_expr.shape == (n_cells, n_genes)
143 |
144 | many_sample_expr = mock_contrastive_vi_model.get_salient_normalized_expression(
145 | n_samples=n_samples,
146 | return_mean=False,
147 | )
148 | assert type(many_sample_expr) == np.ndarray
149 | assert many_sample_expr.shape == (n_samples, n_cells, n_genes)
150 |
151 | expr_df = mock_contrastive_vi_model.get_salient_normalized_expression(
152 | n_samples=1,
153 | return_numpy=False,
154 | )
155 | assert type(expr_df) == pd.DataFrame
156 | assert expr_df.shape == (n_cells, n_genes)
157 |
158 | def test_get_normalized_expression_fold_change(self, mock_contrastive_vi_model):
159 | n_samples = 50
160 | n_cells = mock_contrastive_vi_model.adata.n_obs
161 | n_genes = mock_contrastive_vi_model.adata.n_vars
162 | one_sample_fc = mock_contrastive_vi_model.get_normalized_expression_fold_change(
163 | n_samples=1
164 | )
165 | assert one_sample_fc.shape == (n_cells, n_genes)
166 | many_sample_fc = (
167 | mock_contrastive_vi_model.get_normalized_expression_fold_change(
168 |                 n_samples=n_samples
169 | )
170 | )
171 | assert many_sample_fc.shape == (n_samples, n_cells, n_genes)
172 |
173 | def test_differential_expression(self, mock_contrastive_vi_model):
174 | de_df = mock_contrastive_vi_model.differential_expression(
175 | groupby="labels",
176 | group1=["label_0"],
177 | )
178 | n_vars = mock_contrastive_vi_model.adata.n_vars
179 | assert type(de_df) == pd.DataFrame
180 | assert de_df.shape[0] == n_vars
181 |
--------------------------------------------------------------------------------
/tests/module/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/suinleelab/contrastiveVI/2835d2925f7ef60cdbebd2422435046cb9165f24/tests/module/__init__.py
--------------------------------------------------------------------------------
/tests/module/test_contrastive_vi.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from scvi.module.base import LossRecorder
4 |
5 | from contrastive_vi.module.contrastive_vi import ContrastiveVIModule
6 |
7 | required_data_sources = ["background", "target"]
8 | required_inference_input_keys = ["x", "batch_index"]
9 | required_inference_output_keys = [
10 | "z",
11 | "qz_m",
12 | "qz_v",
13 | "s",
14 | "qs_m",
15 | "qs_v",
16 | "library",
17 | "ql_m",
18 | "ql_v",
19 | ]
20 | required_generative_input_keys_from_concat_tensors = ["batch_index"]
21 | required_generative_input_keys_from_inference_outputs = ["z", "s", "library"]
22 | required_generative_output_keys = [
23 | "px_scale",
24 | "px_r",
25 | "px_rate",
26 | "px_dropout",
27 | ]
28 |
29 |
30 | @pytest.fixture(
31 | params=[True, False], ids=["with_observed_lib_size", "without_observed_lib_size"]
32 | )
33 | def mock_contrastive_vi_module(
34 | mock_n_input, mock_n_batch, mock_library_log_means, mock_library_log_vars, request
35 | ):
36 | if request.param:
37 | return ContrastiveVIModule(
38 | n_input=mock_n_input,
39 | n_batch=mock_n_batch,
40 | n_hidden=10,
41 | n_background_latent=4,
42 | n_salient_latent=4,
43 | n_layers=2,
44 | use_observed_lib_size=True,
45 | library_log_means=None,
46 | library_log_vars=None,
47 | )
48 | else:
49 | return ContrastiveVIModule(
50 | n_input=mock_n_input,
51 | n_batch=mock_n_batch,
52 | n_hidden=10,
53 | n_background_latent=4,
54 | n_salient_latent=4,
55 | n_layers=2,
56 | use_observed_lib_size=False,
57 | library_log_means=mock_library_log_means,
58 | library_log_vars=mock_library_log_vars,
59 | )
60 |
61 |
62 | @pytest.fixture(params=[1, 2], ids=["one_latent_sample", "two_latent_samples"])
63 | def mock_contrastive_vi_data(
64 | mock_contrastive_batch,
65 | mock_contrastive_vi_module,
66 | request,
67 | ):
68 | concat_tensors = mock_contrastive_batch
69 | inference_input = mock_contrastive_vi_module._get_inference_input(concat_tensors)
70 | inference_outputs = mock_contrastive_vi_module.inference(
71 | **inference_input, n_samples=request.param
72 | )
73 | generative_input = mock_contrastive_vi_module._get_generative_input(
74 | concat_tensors, inference_outputs
75 | )
76 | generative_outputs = mock_contrastive_vi_module.generative(**generative_input)
77 | return dict(
78 | concat_tensors=concat_tensors,
79 | inference_input=inference_input,
80 | inference_outputs=inference_outputs,
81 | generative_input=generative_input,
82 | generative_outputs=generative_outputs,
83 | )
84 |
85 |
86 | class TestContrastiveVIModuleInference:
87 | def test_get_inference_input_from_concat_tensors(
88 | self,
89 | mock_contrastive_vi_module,
90 | mock_contrastive_batch,
91 | mock_n_input,
92 | ):
93 | inference_input = (
94 | mock_contrastive_vi_module._get_inference_input_from_concat_tensors(
95 | mock_contrastive_batch, "background"
96 | )
97 | )
98 | for key in required_inference_input_keys:
99 | assert key in inference_input.keys()
100 | x = inference_input["x"]
101 | batch_index = inference_input["batch_index"]
102 | batch_size = x.shape[0]
103 | assert x.shape == (batch_size, mock_n_input)
104 | assert batch_index.shape == (batch_size, 1)
105 |
106 | def test_get_inference_input(
107 | self,
108 | mock_contrastive_vi_module,
109 | mock_contrastive_batch,
110 | mock_adata_background_label,
111 | mock_adata_target_label,
112 | ):
113 | inference_input = mock_contrastive_vi_module._get_inference_input(
114 | mock_contrastive_batch
115 | )
116 | for data_source in required_data_sources:
117 | assert data_source in inference_input.keys()
118 |
119 | background_input = inference_input["background"]
120 | background_input_keys = background_input.keys()
121 | target_input = inference_input["target"]
122 | target_input_keys = target_input.keys()
123 |
124 | for key in required_inference_input_keys:
125 | assert key in background_input_keys
126 | assert key in target_input_keys
127 |
128 | # Check background vs. target labels are consistent.
129 | assert (
130 | background_input["batch_index"] != mock_adata_background_label
131 | ).sum() == 0
132 | assert (target_input["batch_index"] != mock_adata_target_label).sum() == 0
133 |
134 | @pytest.mark.parametrize("n_samples", [1, 2])
135 | def test_generic_inference(
136 | self,
137 | mock_contrastive_vi_module,
138 | mock_contrastive_batch,
139 | n_samples,
140 | ):
141 | inference_input = (
142 | mock_contrastive_vi_module._get_inference_input_from_concat_tensors(
143 | mock_contrastive_batch, "background"
144 | )
145 | )
146 | batch_size = inference_input["x"].shape[0]
147 | n_background_latent = mock_contrastive_vi_module.n_background_latent
148 | n_salient_latent = mock_contrastive_vi_module.n_salient_latent
149 |
150 | inference_outputs = mock_contrastive_vi_module._generic_inference(
151 | **inference_input, n_samples=n_samples
152 | )
153 | for key in required_inference_output_keys:
154 | assert key in inference_outputs.keys()
155 |
156 | if n_samples > 1:
157 | expected_background_latent_shape = (
158 | n_samples,
159 | batch_size,
160 | n_background_latent,
161 | )
162 | expected_salient_latent_shape = (n_samples, batch_size, n_salient_latent)
163 | expected_library_shape = (n_samples, batch_size, 1)
164 | else:
165 | expected_background_latent_shape = (batch_size, n_background_latent)
166 | expected_salient_latent_shape = (batch_size, n_salient_latent)
167 | expected_library_shape = (batch_size, 1)
168 |
169 | assert inference_outputs["z"].shape == expected_background_latent_shape
170 | assert inference_outputs["qz_m"].shape == expected_background_latent_shape
171 | assert inference_outputs["qz_v"].shape == expected_background_latent_shape
172 | assert inference_outputs["s"].shape == expected_salient_latent_shape
173 | assert inference_outputs["qs_m"].shape == expected_salient_latent_shape
174 | assert inference_outputs["qs_v"].shape == expected_salient_latent_shape
175 | assert inference_outputs["library"].shape == expected_library_shape
176 | assert (
177 | inference_outputs["ql_m"] is None
178 | or inference_outputs["ql_m"].shape == expected_library_shape
179 | )
180 | assert (
181 | inference_outputs["ql_v"] is None
182 |             or inference_outputs["ql_v"].shape == expected_library_shape
183 | )
184 |
185 | def test_inference(
186 | self,
187 | mock_contrastive_vi_module,
188 | mock_contrastive_batch,
189 | ):
190 | inference_input = mock_contrastive_vi_module._get_inference_input(
191 | mock_contrastive_batch
192 | )
193 | inference_outputs = mock_contrastive_vi_module.inference(**inference_input)
194 | for data_source in required_data_sources:
195 | assert data_source in inference_outputs.keys()
196 | background_s = inference_outputs["background"]["s"]
197 |
198 | # Background salient variables should be all zeros.
199 | assert torch.equal(background_s, torch.zeros_like(background_s))
200 |
201 |
202 | class TestContrastiveVIModuleGenerative:
203 | def test_get_generative_input_from_concat_tensors(
204 | self,
205 | mock_contrastive_vi_module,
206 | mock_contrastive_batch,
207 | mock_n_input,
208 | ):
209 | generative_input = (
210 | mock_contrastive_vi_module._get_generative_input_from_concat_tensors(
211 | mock_contrastive_batch, "background"
212 | )
213 | )
214 | for key in required_generative_input_keys_from_concat_tensors:
215 | assert key in generative_input.keys()
216 | batch_index = generative_input["batch_index"]
217 | assert batch_index.shape[1] == 1
218 |
219 | def test_get_generative_input_from_inference_outputs(
220 | self,
221 | mock_contrastive_vi_module,
222 | mock_contrastive_batch,
223 | ):
224 | inference_outputs = mock_contrastive_vi_module.inference(
225 | **mock_contrastive_vi_module._get_inference_input(mock_contrastive_batch)
226 | )
227 | generative_input = (
228 | mock_contrastive_vi_module._get_generative_input_from_inference_outputs(
229 | inference_outputs, required_data_sources[0]
230 | )
231 | )
232 | for key in required_generative_input_keys_from_inference_outputs:
233 | assert key in generative_input
234 |
235 | z = generative_input["z"]
236 | s = generative_input["s"]
237 | library = generative_input["library"]
238 | batch_size = z.shape[0]
239 |
240 | assert z.shape == (batch_size, mock_contrastive_vi_module.n_background_latent)
241 | assert s.shape == (batch_size, mock_contrastive_vi_module.n_salient_latent)
242 | assert library.shape == (batch_size, 1)
243 |
244 | def test_get_generative_input(
245 | self,
246 | mock_contrastive_vi_module,
247 | mock_contrastive_batch,
248 | mock_adata_background_label,
249 | mock_adata_target_label,
250 | ):
251 | inference_outputs = mock_contrastive_vi_module.inference(
252 | **mock_contrastive_vi_module._get_inference_input(mock_contrastive_batch)
253 | )
254 | generative_input = mock_contrastive_vi_module._get_generative_input(
255 | mock_contrastive_batch, inference_outputs
256 | )
257 | for data_source in required_data_sources:
258 | assert data_source in generative_input.keys()
259 | background_generative_input = generative_input["background"]
260 | background_generative_input_keys = background_generative_input.keys()
261 | target_generative_input = generative_input["target"]
262 | target_generative_input_keys = target_generative_input.keys()
263 | for key in (
264 | required_generative_input_keys_from_concat_tensors
265 | + required_generative_input_keys_from_inference_outputs
266 | ):
267 | assert key in background_generative_input_keys
268 | assert key in target_generative_input_keys
269 |
270 | # Check background vs. target labels are consistent.
271 | assert (
272 | background_generative_input["batch_index"] != mock_adata_background_label
273 | ).sum() == 0
274 | assert (
275 | target_generative_input["batch_index"] != mock_adata_target_label
276 | ).sum() == 0
277 |
278 | @pytest.mark.parametrize("n_samples", [1, 2])
279 | def test_generic_generative(
280 | self,
281 | mock_contrastive_vi_module,
282 | mock_contrastive_batch,
283 | n_samples,
284 | ):
285 | inference_outputs = mock_contrastive_vi_module.inference(
286 | **mock_contrastive_vi_module._get_inference_input(mock_contrastive_batch),
287 | n_samples=n_samples,
288 | )
289 | generative_input = mock_contrastive_vi_module._get_generative_input(
290 | mock_contrastive_batch, inference_outputs
291 | )["background"]
292 | generative_outputs = mock_contrastive_vi_module._generic_generative(
293 | **generative_input
294 | )
295 | for key in required_generative_output_keys:
296 | assert key in generative_outputs.keys()
297 | px_scale = generative_outputs["px_scale"]
298 | px_r = generative_outputs["px_r"]
299 | px_rate = generative_outputs["px_rate"]
300 | px_dropout = generative_outputs["px_dropout"]
301 | batch_size = px_scale.shape[-2]
302 | n_input = mock_contrastive_vi_module.n_input
303 |
304 | if n_samples > 1:
305 | expected_shape = (n_samples, batch_size, n_input)
306 | else:
307 | expected_shape = (batch_size, n_input)
308 |
309 | assert px_scale.shape == expected_shape
310 | assert px_r.shape == (n_input,) # One dispersion parameter per gene.
311 | assert px_rate.shape == expected_shape
312 | assert px_dropout.shape == expected_shape
313 |
314 | def test_generative(
315 | self,
316 | mock_contrastive_vi_module,
317 | mock_contrastive_batch,
318 | ):
319 | inference_outputs = mock_contrastive_vi_module.inference(
320 | **mock_contrastive_vi_module._get_inference_input(mock_contrastive_batch),
321 | )
322 | generative_input = mock_contrastive_vi_module._get_generative_input(
323 | mock_contrastive_batch, inference_outputs
324 | )
325 | generative_outputs = mock_contrastive_vi_module.generative(**generative_input)
326 | for data_source in required_data_sources:
327 | assert data_source in generative_outputs.keys()
328 |
329 |
330 | class TestContrastiveVIModuleLoss:
331 | def test_reconstruction_loss(
332 | self, mock_contrastive_vi_module, mock_contrastive_vi_data
333 | ):
334 | inference_input = mock_contrastive_vi_data["inference_input"]["background"]
335 | generative_outputs = mock_contrastive_vi_data["generative_outputs"][
336 | "background"
337 | ]
338 | x = inference_input["x"]
339 | px_rate = generative_outputs["px_rate"]
340 | px_r = generative_outputs["px_r"]
341 | px_dropout = generative_outputs["px_dropout"]
342 | recon_loss = mock_contrastive_vi_module.reconstruction_loss(
343 | x, px_rate, px_r, px_dropout
344 | )
345 | if len(px_rate.shape) == 3:
346 | expected_shape = px_rate.shape[:2]
347 | else:
348 | expected_shape = px_rate.shape[:1]
349 | assert recon_loss.shape == expected_shape
350 |
351 | def test_latent_kl_divergence(
352 | self, mock_contrastive_vi_module, mock_contrastive_vi_data
353 | ):
354 | inference_outputs = mock_contrastive_vi_data["inference_outputs"]["background"]
355 | qz_m = inference_outputs["qz_m"]
356 | qz_v = inference_outputs["qz_v"]
357 | kl_z = mock_contrastive_vi_module.latent_kl_divergence(
358 | variational_mean=qz_m,
359 | variational_var=qz_v,
360 | prior_mean=torch.zeros_like(qz_m),
361 | prior_var=torch.ones_like(qz_v),
362 | )
363 | assert kl_z.shape == qz_m.shape[:-1]
364 |
365 | def test_library_kl_divergence(
366 | self, mock_contrastive_vi_module, mock_contrastive_vi_data
367 | ):
368 | inference_input = mock_contrastive_vi_data["inference_input"]["background"]
369 | inference_outputs = mock_contrastive_vi_data["inference_outputs"]["background"]
370 | batch_index = inference_input["batch_index"]
371 | ql_m = inference_outputs["ql_m"]
372 | ql_v = inference_outputs["ql_v"]
373 | library = inference_outputs["library"]
374 | kl_library = mock_contrastive_vi_module.library_kl_divergence(
375 | batch_index, ql_m, ql_v, library
376 | )
377 | expected_shape = library.shape[:-1]
378 | assert kl_library.shape == expected_shape
379 | if mock_contrastive_vi_module.use_observed_lib_size:
380 | assert torch.equal(kl_library, torch.zeros(expected_shape))
381 |
382 | def test_loss(self, mock_contrastive_vi_module, mock_contrastive_vi_data):
383 | expected_shape = mock_contrastive_vi_data["inference_outputs"]["background"][
384 | "qz_m"
385 | ].shape[:-1]
386 | losses = mock_contrastive_vi_module.loss(
387 | mock_contrastive_vi_data["concat_tensors"],
388 | mock_contrastive_vi_data["inference_outputs"],
389 | mock_contrastive_vi_data["generative_outputs"],
390 | )
391 | loss = losses.loss
392 | recon_loss = losses.reconstruction_loss
393 | kl_local = losses.kl_local
394 | kl_global = losses.kl_global
395 |
396 | assert loss.shape == tuple()
397 | assert recon_loss.shape == expected_shape
398 | assert kl_local.shape == expected_shape
399 | assert kl_global.shape == tuple()
400 |
401 | @pytest.mark.parametrize("compute_loss", [True, False])
402 | def test_forward(
403 | self,
404 | mock_contrastive_vi_module,
405 | mock_contrastive_vi_data,
406 | compute_loss,
407 | ):
408 | concat_tensors = mock_contrastive_vi_data["concat_tensors"]
409 | if compute_loss:
410 | inference_outputs, generative_outputs, losses = mock_contrastive_vi_module(
411 | concat_tensors, compute_loss=compute_loss
412 | )
413 | assert isinstance(losses, LossRecorder)
414 | else:
415 | inference_outputs, generative_outputs = mock_contrastive_vi_module(
416 | concat_tensors, compute_loss=compute_loss
417 | )
418 | for data_source in required_data_sources:
419 | assert data_source in inference_outputs.keys()
420 | assert data_source in generative_outputs.keys()
421 | for key in required_inference_output_keys:
422 | assert key in inference_outputs[data_source].keys()
423 | for key in required_generative_output_keys:
424 | assert key in generative_outputs[data_source].keys()
425 |
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 | """Helper utilities for testing."""
2 |
3 | from typing import Dict
4 |
5 | import torch
6 |
7 |
8 | def get_next_batch(dataloader):
9 |     return next(iter(dataloader))
10 |
11 |
12 | def copy_module_state_dict(module) -> Dict[str, torch.Tensor]:
13 | copy = {}
14 | for name, param in module.state_dict().items():
15 | copy[name] = param.detach().cpu().clone()
16 | return copy
17 |
--------------------------------------------------------------------------------