├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── conftest.py
├── contrastive_vi
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── dataloaders
│   │   │   ├── __init__.py
│   │   │   ├── contrastive_dataloader.py
│   │   │   └── data_splitting.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── haber_2017.py
│   │   │   ├── mcfarland_2020.py
│   │   │   ├── norman_2019.py
│   │   │   ├── papalexi_2021.py
│   │   │   └── zheng_2017.py
│   │   └── utils.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── base
│   │   │   ├── __init__.py
│   │   │   └── training_mixin.py
│   │   ├── contrastive_vi.py
│   │   └── total_contrastive_vi.py
│   └── module
│       ├── __init__.py
│       ├── contrastive_vi.py
│       ├── total_contrastive_vi.py
│       └── utils.py
├── environment.yml
├── pyproject.toml
├── setup.cfg
├── sketch.png
└── tests
    ├── __init__.py
    ├── data
    │   └── dataloaders
    │       ├── test_contrastive_dataloader.py
    │       └── test_data_splitting.py
    ├── model
    │   ├── __init__.py
    │   └── test_contrastive_vi.py
    ├── module
    │   ├── __init__.py
    │   └── test_contrastive_vi.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # IDE
2 | .idea
3 | 
4 | # Jupyter notebook
5 | .ipynb_checkpoints
6 | 
7 | # Distribution
8 | *.egg-info
9 | 
10 | # Python
11 | *__pycache__
12 | 
13 | # OS
14 | *.DS_Store
15 | 
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: https://github.com/pre-commit/pre-commit-hooks
3 |     rev: v4.0.1
4 |     hooks:
5 |       - id: trailing-whitespace
6 |       - id: end-of-file-fixer
7 |       - id: check-docstring-first
8 |   - repo: https://github.com/python/black
9 |     rev: 21.9b0
10 |     hooks:
11 |       - id: black
12 |         additional_dependencies: [toml, click==8.0.4]
13 |         args:
14 |           - --line-length=88
15 |   - repo: https://github.com/PyCQA/pydocstyle
16 |     rev: 6.1.1
17 |     hooks:
18 |       - id: pydocstyle
19 |         additional_dependencies: [toml]
20 |         args:
21 |           - --ignore=D102,D107,D202,D203,D212,D205,D400,D401,D410,D411,D413,D415
22 |         exclude: "(tests/.*|conftest.py)"
23 |   - repo: https://github.com/pycqa/flake8
24 |     rev: 4.0.1
25 |     hooks:
26 |       - id: flake8
27 |         args:
28 |           - --max-line-length=88
29 |           - --ignore=E203,W503
30 |   - repo: https://github.com/pycqa/isort
31 |     rev: 5.8.0
32 |     hooks:
33 |       - id: isort
34 |         name: isort (python)
35 |         additional_dependencies: [toml]
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 | 
3 | Copyright (c) 2021 Ethan Weinberger, Chris Lin, AIMS Lab
4 | All rights reserved.
5 | 
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 | 
9 | * Redistributions of source code must retain the above copyright notice, this
10 |   list of conditions and the following disclaimer.
11 | 
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 |   this list of conditions and the following disclaimer in the documentation
14 |   and/or other materials provided with the distribution.
15 | 
16 | * Neither the name of the copyright holder nor the names of its
17 |   contributors may be used to endorse or promote products derived from
18 |   this software without specific prior written permission.
19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # contrastiveVI 2 | 3 |

4 | ![contrastiveVI sketch](sketch.png)

6 | 
7 | contrastiveVI is a generative model designed to isolate factors of variation specific to
8 | a group of "target" cells (e.g. from specimens with a given disease) from those shared
9 | with a group of "background" cells (e.g. from healthy specimens). contrastiveVI is
10 | implemented in [scvi-tools](https://scvi-tools.org/).
11 | 
12 | ## User guide
13 | 
14 | **Note: This implementation of contrastiveVI is no longer maintained. Please see the main [scvi-tools](https://scvi-tools.org/) repository for the most up-to-date version of contrastiveVI, along with an updated tutorial [here](https://docs.scvi-tools.org/en/stable/tutorials/notebooks/scrna/contrastiveVI_tutorial.html).**
15 | 
16 | ### Installation
17 | 
18 | To install the latest version of contrastiveVI via pip, run
19 | 
20 | ```
21 | pip install contrastive-vi
22 | ```
23 | 
24 | Installation should take no more than 5 minutes.
25 | 
26 | ### What you can do with contrastiveVI
27 | 
28 | * If you have a dataset with cells in a background condition (e.g. from healthy
29 | controls) and a target condition (e.g. from diseased patients), you can train
30 | contrastiveVI to isolate latent factors of variation specific to the target cells
31 | from those shared with the background cells, placing them in separate latent spaces.
32 | * Run clustering algorithms on the target-specific latent space to discover sub-groups
33 | of target cells.
34 | * Perform differential expression testing for discovered sub-groups of target cells
35 | using a procedure similar to that of [scVI
36 | ](https://www.nature.com/articles/s41592-018-0229-2). A sketch of this end-to-end workflow appears at the end of the development guide below.
37 | 
38 | ### Colab Notebook Examples
39 | 
40 | * [Applying contrastiveVI to see the effects of stem cell transplants for leukemia patients
41 | ](https://colab.research.google.com/drive/1yOTCVNWY6BydS1bppOYCWHvrxuvhMxZV?usp=sharing)
42 | * [Applying contrastiveVI to separate mouse intestinal epithelial cells
43 | infected with different pathogens by pathogen type
44 | ](https://colab.research.google.com/drive/1z0AcKQg7juArXGCx1XKj6skojWKRlDMC?usp=sharing)
45 | * [Applying contrastiveVI to better understand the results of a MIX-Seq
46 | small-molecule drug perturbation experiment
47 | ](https://colab.research.google.com/drive/1cMaJpMe3g0awCiwsw13oG7RvGnmXNCac?usp=sharing)
48 | * [Applying contrastiveVI to better understand heterogeneity in cellular responses to Alzheimer's disease
49 | ](https://colab.research.google.com/drive/1_R1YWQQUJzgQ6kz1XqglL5xZn8b8h1TX?usp=sharing)
50 | 
51 | 
52 | ## Development guide
53 | 
54 | ### Set up the environment
55 | 1. Git clone this repository.
56 | 2. `cd contrastive-vi`.
57 | 3. Create and activate the specified conda environment by running
58 | ```
59 | conda env create -f environment.yml
60 | conda activate contrastive-vi-env
61 | ```
62 | 4. Install the `contrastive_vi` package and necessary dependencies for
63 | development by running `pip install -e ".[dev]"`.
64 | 5. Git pre-commit hooks (https://pre-commit.com/) are used to automatically
65 | check and fix formatting errors before a Git commit happens. Run
66 | `pre-commit install` to install all the hooks.
67 | 6. Test that the pre-commit hooks work by running `pre-commit run --all-files`.
68 | 
69 | ### Testing
70 | It's good practice to include unit tests during development.
71 | Run `pytest tests` to run the existing tests.
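
### Example workflow (sketch)

A minimal end-to-end sketch of the workflow described above. This is an illustration rather than a tested script: the `.h5ad` file name and the `condition` column with `"healthy"`/`"disease"` values are placeholder assumptions, while the `ContrastiveVI` calls mirror the model class defined in this repository.

```python
import anndata
import numpy as np

from contrastive_vi.model import ContrastiveVI

# Placeholder dataset; raw counts are assumed to be stored in adata.X.
adata = anndata.read_h5ad("my_dataset.h5ad")
ContrastiveVI.setup_anndata(adata)

# Background = cells without the variation of interest (e.g. healthy controls);
# target = cells expected to carry it (e.g. diseased samples).
background_indices = np.where(adata.obs["condition"] == "healthy")[0].tolist()
target_indices = np.where(adata.obs["condition"] == "disease")[0].tolist()

model = ContrastiveVI(adata)
model.train(
    background_indices=background_indices,
    target_indices=target_indices,
)

# The salient latent space captures variation specific to the target cells;
# cluster this representation to discover sub-groups of target cells.
salient_rep = model.get_latent_representation(
    indices=target_indices, representation_kind="salient"
)
```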
72 | 
73 | 
74 | ## References
75 | 
76 | If you find contrastiveVI useful for your work, please consider citing our preprint:
77 | 
78 | ```
79 | @article{contrastiveVI,
80 |   title={Isolating salient variations of interest in single-cell transcriptomic data with contrastiveVI},
81 |   author={Weinberger, Ethan and Lin, Chris and Lee, Su-In},
82 |   journal={bioRxiv},
83 |   year={2021},
84 |   publisher={Cold Spring Harbor Laboratory}
85 | }
86 | ```
87 | 
--------------------------------------------------------------------------------
/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import scvi
3 | from scvi.model._utils import _init_library_size
4 | 
5 | from contrastive_vi.data.dataloaders.contrastive_dataloader import ContrastiveDataLoader
6 | from contrastive_vi.model.contrastive_vi import ContrastiveVIModel
7 | from tests.utils import get_next_batch
8 | 
9 | 
10 | @pytest.fixture
11 | def mock_adata():
12 |     adata = scvi.data.synthetic_iid(n_batches=2)  # Same number of cells in each batch.
13 |     # Make number of cells unequal across batches to test edge cases.
14 |     adata = adata[:-3, :]
15 |     adata.layers["raw_counts"] = adata.X.copy()
16 |     ContrastiveVIModel.setup_anndata(
17 |         adata=adata,
18 |         batch_key="batch",
19 |         labels_key="labels",
20 |         layer="raw_counts",
21 |     )
22 |     return adata
23 | 
24 | 
25 | @pytest.fixture
26 | def mock_adata_manager(mock_adata):
27 |     return ContrastiveVIModel._setup_adata_manager_store[mock_adata.uns["_scvi_uuid"]]
28 | 
29 | 
30 | @pytest.fixture
31 | def mock_library_log_means_and_vars(mock_adata_manager):
32 |     return _init_library_size(mock_adata_manager, n_batch=2)
33 | 
34 | 
35 | @pytest.fixture
36 | def mock_library_log_means(mock_library_log_means_and_vars):
37 |     return mock_library_log_means_and_vars[0]
38 | 
39 | 
40 | @pytest.fixture
41 | def mock_library_log_vars(mock_library_log_means_and_vars):
42 |     return mock_library_log_means_and_vars[1]
43 | 
44 | 
45 | @pytest.fixture
46 | def mock_n_input(mock_adata):
47 |     return mock_adata.X.shape[1]
48 | 
49 | 
50 | @pytest.fixture
51 | def mock_n_batch(mock_adata):
52 |     return len(mock_adata.obs["batch"].unique())
53 | 
54 | 
55 | @pytest.fixture
56 | def mock_adata_background_indices(mock_adata):
57 |     return (
58 |         mock_adata.obs.index[(mock_adata.obs["batch"] == "batch_0")]
59 |         .astype(int)
60 |         .tolist()
61 |     )
62 | 
63 | 
64 | @pytest.fixture
65 | def mock_adata_background_label(mock_adata):
66 |     return 0
67 | 
68 | 
69 | @pytest.fixture
70 | def mock_adata_target_indices(mock_adata):
71 |     return (
72 |         mock_adata.obs.index[(mock_adata.obs["batch"] == "batch_1")]
73 |         .astype(int)
74 |         .tolist()
75 |     )
76 | 
77 | 
78 | @pytest.fixture
79 | def mock_adata_target_label(mock_adata):
80 |     return 1
81 | 
82 | 
83 | @pytest.fixture
84 | def mock_contrastive_dataloader(
85 |     mock_adata_manager, mock_adata_background_indices, mock_adata_target_indices
86 | ):
87 |     return ContrastiveDataLoader(
88 |         mock_adata_manager,
89 |         mock_adata_background_indices,
90 |         mock_adata_target_indices,
91 |         batch_size=32,
92 |         shuffle=False,
93 |     )
94 | 
95 | 
96 | @pytest.fixture
97 | def mock_contrastive_batch(mock_contrastive_dataloader):
98 |     return get_next_batch(mock_contrastive_dataloader)
99 | 
--------------------------------------------------------------------------------
/contrastive_vi/__init__.py:
--------------------------------------------------------------------------------
1 | """contrastive_vi setup file.
copied from the scvi-tools-skeleton repo.""" 2 | 3 | import logging 4 | 5 | from rich.console import Console 6 | from rich.logging import RichHandler 7 | 8 | # https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094 9 | # https://github.com/python-poetry/poetry/issues/144#issuecomment-623927302 10 | try: 11 | import importlib.metadata as importlib_metadata 12 | except ModuleNotFoundError: 13 | import importlib_metadata 14 | 15 | package_name = "contrastive_vi" 16 | __version__ = importlib_metadata.version(package_name) 17 | 18 | logger = logging.getLogger(__name__) 19 | # set the logging level 20 | logger.setLevel(logging.INFO) 21 | 22 | # nice logging outputs 23 | console = Console(force_terminal=True) 24 | console.is_jupyter = False 25 | ch = RichHandler(show_path=False, console=console, show_time=False) 26 | formatter = logging.Formatter("contrastive_vi: %(message)s") 27 | ch.setFormatter(formatter) 28 | logger.addHandler(ch) 29 | 30 | # this prevents double outputs 31 | logger.propagate = False 32 | -------------------------------------------------------------------------------- /contrastive_vi/data/__init__.py: -------------------------------------------------------------------------------- 1 | """Data preprocessing modules.""" 2 | -------------------------------------------------------------------------------- /contrastive_vi/data/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | """Data loader modules for mini-batching.""" 2 | -------------------------------------------------------------------------------- /contrastive_vi/data/dataloaders/contrastive_dataloader.py: -------------------------------------------------------------------------------- 1 | """Data loader for contrastive learning.""" 2 | from itertools import cycle 3 | from typing import List, Optional, Union 4 | 5 | from scvi.data import AnnDataManager 6 | from scvi.dataloaders._concat_dataloader import ConcatDataLoader 7 | 8 | 9 | class _ContrastiveIterator: 10 | """ 11 | Iterator for background and target dataloader pairs as found in the contrastive 12 | analysis setting. 13 | 14 | Each iteration of this iterator returns a dictionary with two elements: 15 | "background", containing one batch of data from the background dataloader, and 16 | "target", containing one batch of data from the target dataloader. 17 | """ 18 | 19 | def __init__(self, background, target): 20 | self.background = iter(background) 21 | self.target = iter(target) 22 | 23 | def __iter__(self): 24 | return self 25 | 26 | def __next__(self): 27 | bg_samples = next(self.background) 28 | tg_samples = next(self.target) 29 | return {"background": bg_samples, "target": tg_samples} 30 | 31 | 32 | class ContrastiveDataLoader(ConcatDataLoader): 33 | """ 34 | Data loader to load background and target data for contrastive learning. 35 | 36 | Each iteration of the data loader returns a dictionary containing background and 37 | target data points, indexed by "background" and "target", respectively. 38 | Args: 39 | ---- 40 | adata_manager: AnnDataManager object that has been created via `setup_anndata`. 41 | background_indices: Indices for background samples in `adata`. 42 | target_indices: Indices for target samples in `adata`. 43 | shuffle: Whether the data should be shuffled. 44 | batch_size: Mini-batch size to load for background and target data. 
45 |         data_and_attributes: Dictionary with keys representing keys in data
46 |             registry (`adata.uns["_scvi"]`) and value equal to desired numpy
47 |             loading type (later made into torch tensor). If `None`, defaults to all
48 |             registered data.
49 |         drop_last: If int, drops the last batch if its length is less than
50 |             `drop_last`. If `drop_last == True`, drops last non-full batch.
51 |             If `drop_last == False`, iterate over all batches.
52 |         **data_loader_kwargs: Keyword arguments for `torch.utils.data.DataLoader`.
53 |     """
54 | 
55 |     def __init__(
56 |         self,
57 |         adata_manager: AnnDataManager,
58 |         background_indices: List[int],
59 |         target_indices: List[int],
60 |         shuffle: bool = False,
61 |         batch_size: int = 128,
62 |         data_and_attributes: Optional[dict] = None,
63 |         drop_last: Union[bool, int] = False,
64 |         **data_loader_kwargs,
65 |     ) -> None:
66 |         super().__init__(
67 |             adata_manager=adata_manager,
68 |             indices_list=[background_indices, target_indices],
69 |             shuffle=shuffle,
70 |             batch_size=batch_size,
71 |             data_and_attributes=data_and_attributes,
72 |             drop_last=drop_last,
73 |             **data_loader_kwargs,
74 |         )
75 |         self.background_indices = background_indices
76 |         self.target_indices = target_indices
77 | 
78 |     def __iter__(self):
79 |         """
80 | 
81 |         Iterator method for the contrastive data loader.
82 | 
83 |         Will iterate over the dataloader with the most data while cycling through
84 |         the data in the other dataloaders.
85 |         """
86 | 
87 |         iter_list = [
88 |             cycle(dl) if dl != self.largest_dl else dl for dl in self.dataloaders
89 |         ]
90 | 
91 |         return _ContrastiveIterator(background=iter_list[0], target=iter_list[1])
--------------------------------------------------------------------------------
/contrastive_vi/data/dataloaders/data_splitting.py:
--------------------------------------------------------------------------------
1 | """Utilities for splitting a dataset into training, validation, and test set."""
2 | 
3 | from typing import List, Optional
4 | 
5 | import numpy as np
6 | import pytorch_lightning as pl
7 | from anndata import AnnData
8 | from scvi import settings
9 | from scvi.dataloaders._data_splitting import validate_data_split
10 | from scvi.model._utils import parse_use_gpu_arg
11 | 
12 | from contrastive_vi.data.dataloaders.contrastive_dataloader import ContrastiveDataLoader
13 | 
14 | 
15 | class ContrastiveDataSplitter(pl.LightningDataModule):
16 |     """
17 |     Create ContrastiveDataLoader for training, validation, and test set.
18 | 
19 |     Args:
20 |     ----
21 |         adata: AnnData object that has been registered via `setup_anndata`.
22 |         background_indices: Indices for background samples in `adata`.
23 |         target_indices: Indices for target samples in `adata`.
24 |         train_size: Proportion of data to include in the training set.
25 |         validation_size: Proportion of data to include in the validation set. The
26 |             remaining proportion after `train_size` and `validation_size` is used for
27 |             the test set.
28 |         use_gpu: Use default GPU if available (if None or True); or index of GPU to
29 |             use (if int); or name of GPU (if str, e.g., `'cuda:0'`); or use CPU
30 |             (if False).
31 |         **kwargs: Keyword args for data loader (`ContrastiveDataLoader`).
32 | """ 33 | 34 | def __init__( 35 | self, 36 | adata: AnnData, 37 | background_indices: List[int], 38 | target_indices: List[int], 39 | train_size: float = 0.9, 40 | validation_size: Optional[float] = None, 41 | use_gpu: bool = False, 42 | **kwargs, 43 | ) -> None: 44 | super().__init__() 45 | self.adata = adata 46 | self.background_indices = background_indices 47 | self.target_indices = target_indices 48 | self.train_size = train_size 49 | self.validation_size = validation_size 50 | self.use_gpu = use_gpu 51 | self.data_loader_kwargs = kwargs 52 | 53 | self.n_background = len(background_indices) 54 | self.n_target = len(target_indices) 55 | self.n_background_train, self.n_background_val = validate_data_split( 56 | len(self.background_indices), self.train_size, self.validation_size 57 | ) 58 | self.n_target_train, self.n_target_val = validate_data_split( 59 | len(self.target_indices), self.train_size, self.validation_size 60 | ) 61 | 62 | def setup(self, stage: Optional[str] = None): 63 | random_state = np.random.RandomState(seed=settings.seed) 64 | 65 | target_permutation = random_state.permutation(self.target_indices) 66 | n_target_train = self.n_target_train 67 | n_target_val = self.n_target_val 68 | self.target_val_idx = target_permutation[:n_target_val] 69 | self.target_train_idx = target_permutation[ 70 | n_target_val : (n_target_val + n_target_train) 71 | ] 72 | self.target_test_idx = target_permutation[(n_target_val + n_target_train) :] 73 | 74 | background_permutation = random_state.permutation(self.background_indices) 75 | n_background_train = self.n_background_train 76 | n_background_val = self.n_background_val 77 | self.background_val_idx = background_permutation[:n_background_val] 78 | self.background_train_idx = background_permutation[ 79 | n_background_val : (n_background_val + n_background_train) 80 | ] 81 | self.background_test_idx = background_permutation[ 82 | (n_background_val + n_background_train) : 83 | ] 84 | 85 | self.train_idx = np.concatenate( 86 | (self.background_train_idx, self.target_train_idx) 87 | ) 88 | self.val_idx = np.concatenate((self.background_val_idx, self.target_val_idx)) 89 | self.test_idx = np.concatenate((self.background_test_idx, self.target_test_idx)) 90 | 91 | gpus, self.device = parse_use_gpu_arg(self.use_gpu, return_device=True) 92 | self.pin_memory = ( 93 | True if (settings.dl_pin_memory_gpu_training and gpus != 0) else False 94 | ) 95 | 96 | def _get_contrastive_dataloader( 97 | self, background_indices: List[int], target_indices: List[int] 98 | ) -> ContrastiveDataLoader: 99 | return ContrastiveDataLoader( 100 | self.adata, 101 | background_indices, 102 | target_indices, 103 | shuffle=True, 104 | drop_last=3, 105 | pin_memory=self.pin_memory, 106 | **self.data_loader_kwargs, 107 | ) 108 | 109 | def train_dataloader(self) -> ContrastiveDataLoader: 110 | return self._get_contrastive_dataloader( 111 | self.background_train_idx, self.target_train_idx 112 | ) 113 | 114 | def val_dataloader(self) -> ContrastiveDataLoader: 115 | if len(self.background_val_idx) > 0 and len(self.target_val_idx) > 0: 116 | return self._get_contrastive_dataloader( 117 | self.background_val_idx, self.target_val_idx 118 | ) 119 | else: 120 | pass 121 | 122 | def test_dataloader(self) -> ContrastiveDataLoader: 123 | if len(self.background_test_idx) > 0 and len(self.target_test_idx) > 0: 124 | return self._get_contrastive_dataloader( 125 | self.background_test_idx, self.target_test_idx 126 | ) 127 | else: 128 | pass 129 | 
--------------------------------------------------------------------------------
/contrastive_vi/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | """Modules for downloading, reading, and preprocessing specific datasets."""
2 | 
--------------------------------------------------------------------------------
/contrastive_vi/data/datasets/haber_2017.py:
--------------------------------------------------------------------------------
1 | """
2 | Download, read, and preprocess Haber et al. (2017) expression data.
3 | 
4 | Single-cell expression data from Haber et al. A single-cell survey of the small
5 | intestinal epithelium. Nature (2017).
6 | """
7 | import gzip
8 | import os
9 | 
10 | import pandas as pd
11 | import scanpy as sc
12 | from anndata import AnnData
13 | 
14 | from contrastive_vi.data.utils import download_binary_file
15 | 
16 | 
17 | def download_haber_2017(output_path: str) -> None:
18 |     """
19 |     Download Haber et al. 2017 data from the hosting URLs.
20 | 
21 |     Args:
22 |     ----
23 |         output_path: Output path to store the downloaded and unzipped
24 |             directories.
25 | 
26 |     Returns
27 |     -------
28 |         None. File directories are downloaded to output_path.
29 |     """
30 | 
31 |     url = (
32 |         "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE92nnn/GSE92332/suppl/GSE92332"
33 |         "_SalmHelm_UMIcounts.txt.gz"
34 |     )
35 | 
36 |     output_filename = os.path.join(output_path, url.split("/")[-1])
37 | 
38 |     download_binary_file(url, output_filename)
39 | 
40 | 
41 | def read_haber_2017(file_directory: str) -> pd.DataFrame:
42 |     """
43 |     Read the expression data for Haber et al. 2017 in the given directory.
44 | 
45 |     Args:
46 |     ----
47 |         file_directory: Directory containing Haber et al. 2017 data.
48 | 
49 |     Returns
50 |     -------
51 |         A data frame containing single-cell gene expression counts, with cell
52 |             identification barcodes as column names and gene IDs as indices.
53 |     """
54 | 
55 |     with gzip.open(
56 |         os.path.join(file_directory, "GSE92332_SalmHelm_UMIcounts.txt.gz"), "rb"
57 |     ) as f:
58 |         df = pd.read_csv(f, sep="\t")
59 | 
60 |     return df
61 | 
62 | 
63 | def preprocess_haber_2017(download_path: str, n_top_genes: int) -> AnnData:
64 |     """
65 |     Preprocess expression data from Haber et al. 2017.
66 | 
67 |     Args:
68 |     ----
69 |         download_path: Path containing the downloaded Haber et al. 2017 data file.
70 |         n_top_genes: Number of most variable genes to retain.
71 | 
72 |     Returns
73 |     -------
74 |         An AnnData object containing single-cell expression data. The layer
75 |             "count" contains the count data for the most variable genes. The X
76 |             variable contains the total-count-normalized and log-transformed data
77 |             for the most variable genes (a copy with all the genes is stored in
78 |             .raw).
79 | """ 80 | 81 | df = read_haber_2017(download_path) 82 | df = df.transpose() 83 | 84 | cell_groups = [] 85 | barcodes = [] 86 | conditions = [] 87 | cell_types = [] 88 | 89 | for cell in df.index: 90 | cell_group, barcode, condition, cell_type = cell.split("_") 91 | cell_groups.append(cell_group) 92 | barcodes.append(barcode) 93 | conditions.append(condition) 94 | cell_types.append(cell_type) 95 | 96 | metadata_df = pd.DataFrame( 97 | { 98 | "cell_group": cell_groups, 99 | "barcode": barcodes, 100 | "condition": conditions, 101 | "cell_type": cell_types, 102 | } 103 | ) 104 | 105 | adata = AnnData(X=df.values, obs=metadata_df) 106 | adata = adata[adata.obs["condition"] != "Hpoly.Day3"] 107 | adata.layers["count"] = adata.X.copy() 108 | sc.pp.normalize_total(adata) 109 | sc.pp.log1p(adata) 110 | adata.raw = adata 111 | sc.pp.highly_variable_genes( 112 | adata, flavor="seurat_v3", n_top_genes=n_top_genes, layer="count", subset=True 113 | ) 114 | adata = adata[adata.layers["count"].sum(1) != 0] # Remove cells with all zeros. 115 | return adata 116 | -------------------------------------------------------------------------------- /contrastive_vi/data/datasets/mcfarland_2020.py: -------------------------------------------------------------------------------- 1 | """ 2 | Download, read, and preprocess Mcfarland et al. (2020) expression data. 3 | 4 | Single-cell expression data from Mcfarland et al. Multiplexed single-cell 5 | transcriptional response profiling to define cancer vulnerabilities and therapeutic 6 | mechanism of action. Nature Communications (2020). 7 | """ 8 | import os 9 | import shutil 10 | from typing import Tuple 11 | 12 | import anndata 13 | import numpy as np 14 | import pandas as pd 15 | import scanpy as sc 16 | from anndata import AnnData 17 | from scipy.io import mmread 18 | 19 | from contrastive_vi.data.utils import download_binary_file 20 | 21 | 22 | def download_mcfarland_2020(output_path: str) -> None: 23 | """ 24 | Download Mcfarland et al. 2020 data from the hosting URLs. 25 | 26 | Args: 27 | ---- 28 | output_path: Output path to store the downloaded and unzipped 29 | directories. 30 | 31 | Returns 32 | ------- 33 | None. File directories are downloaded and unzipped in output_path. 
34 | """ 35 | idasanutlin_url = "https://figshare.com/ndownloader/files/18716351" 36 | idasanutlin_output_filename = os.path.join(output_path, "idasanutlin.zip") 37 | 38 | download_binary_file(idasanutlin_url, idasanutlin_output_filename) 39 | idasanutlin_output_dir = idasanutlin_output_filename.replace(".zip", "") 40 | shutil.unpack_archive(idasanutlin_output_filename, idasanutlin_output_dir) 41 | 42 | dmso_url = "https://figshare.com/ndownloader/files/18716354" 43 | dmso_output_filename = os.path.join(output_path, "dmso.zip") 44 | 45 | download_binary_file(dmso_url, dmso_output_filename) 46 | dmso_output_dir = dmso_output_filename.replace(".zip", "") 47 | shutil.unpack_archive(dmso_output_filename, dmso_output_dir) 48 | 49 | 50 | def _read_mixseq_df(directory: str) -> pd.DataFrame: 51 | data = mmread(os.path.join(directory, "matrix.mtx")) 52 | barcodes = pd.read_table(os.path.join(directory, "barcodes.tsv"), header=None) 53 | classifications = pd.read_csv(os.path.join(directory, "classifications.csv")) 54 | classifications["cell_line"] = np.array( 55 | [x.split("_")[0] for x in classifications.singlet_ID.values] 56 | ) 57 | gene_names = pd.read_table(os.path.join(directory, "genes.tsv"), header=None) 58 | 59 | df = pd.DataFrame( 60 | data.toarray(), 61 | columns=barcodes.iloc[:, 0].values, 62 | index=gene_names.iloc[:, 0].values, 63 | ) 64 | return df 65 | 66 | 67 | def _get_tp53_mutation_status(directory: str) -> np.array: 68 | # Taken from https://cancerdatascience.org/blog/posts/mix-seq/ 69 | TP53_WT = [ 70 | "LNCAPCLONEFGC_PROSTATE", 71 | "DKMG_CENTRAL_NERVOUS_SYSTEM", 72 | "NCIH226_LUNG", 73 | "RCC10RGB_KIDNEY", 74 | "SNU1079_BILIARY_TRACT", 75 | "CCFSTTG1_CENTRAL_NERVOUS_SYSTEM", 76 | "COV434_OVARY", 77 | ] 78 | 79 | classifications = pd.read_csv(os.path.join(directory, "classifications.csv")) 80 | TP53_mutation_status = [ 81 | "Wild Type" if x in TP53_WT else "Mutation" 82 | for x in classifications.singlet_ID.values 83 | ] 84 | return np.array(TP53_mutation_status) 85 | 86 | 87 | def read_mcfarland_2020(file_directory: str) -> Tuple[pd.DataFrame, pd.DataFrame]: 88 | """ 89 | Read the expression data for Mcfarland et al. 2020 in the given directory. 90 | 91 | Args: 92 | ---- 93 | file_directory: Directory containing Mcfarland et al. 2020 data. 94 | 95 | Returns 96 | ------- 97 | Two data frames of raw count expression data. The first contains 98 | single-cell gene expression count data from cancer cell lines exposed to 99 | idasanutlin with cell identification barcodes as column names and gene IDs as 100 | indices. The second contains count data with the same format from samples 101 | exposed to a control solution (DMSO). 102 | """ 103 | idasanutlin_dir = os.path.join( 104 | file_directory, "idasanutlin", "Idasanutlin_24hr_expt1" 105 | ) 106 | idasanutlin_df = _read_mixseq_df(idasanutlin_dir) 107 | 108 | dmso_dir = os.path.join(file_directory, "dmso", "DMSO_24hr_expt1") 109 | dmso_df = _read_mixseq_df(dmso_dir) 110 | 111 | return idasanutlin_df, dmso_df 112 | 113 | 114 | def preprocess_mcfarland_2020(download_path: str, n_top_genes: int) -> AnnData: 115 | """ 116 | Preprocess expression data from Mcfarland et al., 2020. 117 | 118 | Args: 119 | ---- 120 | download_path: Path containing the downloaded Mcfarland et al. 2020 data files. 121 | n_top_genes: Number of most variable genes to retain. 122 | 123 | Returns 124 | ------- 125 | An AnnData object containing single-cell expression data. The layer 126 | "count" contains the count data for the most variable genes. 
The X 127 | variable contains the total-count-normalized and log-transformed data 128 | for the most variable genes (a copy with all the genes is stored in 129 | .raw). 130 | """ 131 | 132 | idasanutlin_df, dmso_df = read_mcfarland_2020(download_path) 133 | idasanutlin_df, dmso_df = idasanutlin_df.transpose(), dmso_df.transpose() 134 | 135 | idasanutlin_adata = AnnData(idasanutlin_df) 136 | idasanutlin_dir = os.path.join( 137 | download_path, "idasanutlin", "Idasanutlin_24hr_expt1" 138 | ) 139 | idasanutlin_adata.obs["TP53_mutation_status"] = _get_tp53_mutation_status( 140 | idasanutlin_dir 141 | ) 142 | idasanutlin_adata.obs["condition"] = np.repeat( 143 | "Idasanutlin", idasanutlin_adata.shape[0] 144 | ) 145 | 146 | dmso_adata = AnnData(dmso_df) 147 | dmso_dir = os.path.join(download_path, "dmso", "DMSO_24hr_expt1") 148 | dmso_adata.obs["TP53_mutation_status"] = _get_tp53_mutation_status(dmso_dir) 149 | dmso_adata.obs["condition"] = np.repeat("DMSO", dmso_adata.shape[0]) 150 | 151 | full_adata = anndata.concat([idasanutlin_adata, dmso_adata]) 152 | full_adata.layers["count"] = full_adata.X.copy() 153 | sc.pp.normalize_total(full_adata) 154 | sc.pp.log1p(full_adata) 155 | full_adata.raw = full_adata 156 | sc.pp.highly_variable_genes( 157 | full_adata, 158 | flavor="seurat_v3", 159 | n_top_genes=n_top_genes, 160 | layer="count", 161 | subset=True, 162 | ) 163 | full_adata = full_adata[ 164 | full_adata.layers["count"].sum(1) != 0 165 | ] # Remove cells with all zeros. 166 | return full_adata 167 | -------------------------------------------------------------------------------- /contrastive_vi/data/datasets/norman_2019.py: -------------------------------------------------------------------------------- 1 | """ 2 | Download, read, and preprocess Norman et al. (2019) expression data. 3 | 4 | Single-cell expression data from Norman et al. Exploring genetic interaction 5 | manifolds constructed from rich single-cell phenotypes. Science (2019). 
6 | """ 7 | 8 | import gzip 9 | import os 10 | import re 11 | 12 | import pandas as pd 13 | import scanpy as sc 14 | from anndata import AnnData 15 | from scipy.io import mmread 16 | from scipy.sparse import coo_matrix 17 | 18 | from contrastive_vi.data.utils import download_binary_file 19 | 20 | # Gene program lists obtained by cross-referencing the heatmap here 21 | # https://github.com/thomasmaxwellnorman/Perturbseq_GI/blob/master/GI_optimal_umap.ipynb 22 | # with Figure 2b in Norman 2019 23 | G1_CYCLE = [ 24 | "CDKN1C+CDKN1B", 25 | "CDKN1B+ctrl", 26 | "CDKN1B+CDKN1A", 27 | "CDKN1C+ctrl", 28 | "ctrl+CDKN1A", 29 | "CDKN1C+CDKN1A", 30 | "CDKN1A+ctrl", 31 | ] 32 | 33 | ERYTHROID = [ 34 | "BPGM+SAMD1", 35 | "ATL1+ctrl", 36 | "UBASH3B+ZBTB25", 37 | "PTPN12+PTPN9", 38 | "PTPN12+UBASH3A", 39 | "CBL+CNN1", 40 | "UBASH3B+CNN1", 41 | "CBL+UBASH3B", 42 | "UBASH3B+PTPN9", 43 | "PTPN1+ctrl", 44 | "CBL+PTPN9", 45 | "CNN1+UBASH3A", 46 | "CBL+PTPN12", 47 | "PTPN12+ZBTB25", 48 | "UBASH3B+PTPN12", 49 | "SAMD1+PTPN12", 50 | "SAMD1+UBASH3B", 51 | "UBASH3B+UBASH3A", 52 | ] 53 | 54 | PIONEER_FACTORS = [ 55 | "ZBTB10+SNAI1", 56 | "FOXL2+MEIS1", 57 | "POU3F2+CBFA2T3", 58 | "DUSP9+SNAI1", 59 | "FOXA3+FOXA1", 60 | "FOXA3+ctrl", 61 | "LYL1+IER5L", 62 | "FOXA1+FOXF1", 63 | "FOXF1+HOXB9", 64 | "FOXA1+HOXB9", 65 | "FOXA3+HOXB9", 66 | "FOXA3+FOXA1", 67 | "FOXA3+FOXL2", 68 | "POU3F2+FOXL2", 69 | "FOXF1+FOXL2", 70 | "FOXA1+FOXL2", 71 | "HOXA13+ctrl", 72 | "ctrl+HOXC13", 73 | "HOXC13+ctrl", 74 | "MIDN+ctrl", 75 | "TP73+ctrl", 76 | ] 77 | 78 | GRANULOCYTE_APOPTOSIS = [ 79 | "SPI1+ctrl", 80 | "ctrl+SPI1", 81 | "ctrl+CEBPB", 82 | "CEBPB+ctrl", 83 | "JUN+CEBPA", 84 | "CEBPB+CEBPA", 85 | "FOSB+CEBPE", 86 | "ZC3HAV1+CEBPA", 87 | "KLF1+CEBPA", 88 | "ctrl+CEBPA", 89 | "CEBPA+ctrl", 90 | "CEBPE+CEBPA", 91 | "CEBPE+SPI1", 92 | "CEBPE+ctrl", 93 | "ctrl+CEBPE", 94 | "CEBPE+RUNX1T1", 95 | "CEBPE+CEBPB", 96 | "FOSB+CEBPB", 97 | "ETS2+CEBPE", 98 | ] 99 | 100 | MEGAKARYOCYTE = [ 101 | "ctrl+ETS2", 102 | "MAPK1+ctrl", 103 | "ctrl+MAPK1", 104 | "ETS2+MAPK1", 105 | "CEBPB+MAPK1", 106 | "MAPK1+TGFBR2", 107 | ] 108 | 109 | PRO_GROWTH = [ 110 | "CEBPE+KLF1", 111 | "KLF1+MAP2K6", 112 | "AHR+KLF1", 113 | "ctrl+KLF1", 114 | "KLF1+ctrl", 115 | "KLF1+BAK1", 116 | "KLF1+TGFBR2", 117 | ] 118 | 119 | 120 | def download_norman_2019(output_path: str) -> None: 121 | """ 122 | Download Norman et al. 2019 data and metadata files from the hosting URLs. 123 | 124 | Args: 125 | ---- 126 | output_path: Output path to store the downloaded and unzipped 127 | directories. 128 | 129 | Returns 130 | ------- 131 | None. File directories are downloaded to output_path. 132 | """ 133 | 134 | file_urls = ( 135 | "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE133nnn/GSE133344/suppl" 136 | "/GSE133344_filtered_matrix.mtx.gz", 137 | "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE133nnn/GSE133344/suppl" 138 | "/GSE133344_filtered_genes.tsv.gz", 139 | "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE133nnn/GSE133344/suppl" 140 | "/GSE133344_filtered_barcodes.tsv.gz", 141 | "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE133nnn/GSE133344/suppl" 142 | "/GSE133344_filtered_cell_identities.csv.gz", 143 | ) 144 | 145 | for url in file_urls: 146 | output_filename = os.path.join(output_path, url.split("/")[-1]) 147 | download_binary_file(url, output_filename) 148 | 149 | 150 | def read_norman_2019(file_directory: str) -> coo_matrix: 151 | """ 152 | Read the expression data for Norman et al. 2019 in the given directory. 
153 | 154 | Args: 155 | ---- 156 | file_directory: Directory containing Norman et al. 2019 data. 157 | 158 | Returns 159 | ------- 160 | A sparse matrix containing single-cell gene expression count, with rows 161 | representing genes and columns representing cells. 162 | """ 163 | 164 | with gzip.open( 165 | os.path.join(file_directory, "GSE133344_filtered_matrix.mtx.gz"), "rb" 166 | ) as f: 167 | matrix = mmread(f) 168 | 169 | return matrix 170 | 171 | 172 | def preprocess_norman_2019(download_path: str, n_top_genes: int) -> AnnData: 173 | """ 174 | Preprocess expression data from Norman et al. 2019. 175 | 176 | Args: 177 | ---- 178 | download_path: Path containing the downloaded Norman et al. 2019 data file. 179 | n_top_genes: Number of most variable genes to retain. 180 | 181 | Returns 182 | ------- 183 | An AnnData object containing single-cell expression data. The layer 184 | "count" contains the count data for the most variable genes. The .X 185 | variable contains the normalized and log-transformed data for the most variable 186 | genes. A copy of data with all genes is stored in .raw. 187 | """ 188 | matrix = read_norman_2019(download_path) 189 | 190 | # List of cell barcodes. The barcodes in this list are stored in the same order 191 | # as cells are in the count matrix. 192 | cell_barcodes = pd.read_csv( 193 | os.path.join(download_path, "GSE133344_filtered_barcodes.tsv.gz"), 194 | sep="\t", 195 | header=None, 196 | names=["cell_barcode"], 197 | ) 198 | 199 | # IDs/names of the gene features. 200 | gene_list = pd.read_csv( 201 | os.path.join(download_path, "GSE133344_filtered_genes.tsv.gz"), 202 | sep="\t", 203 | header=None, 204 | names=["gene_id", "gene_name"], 205 | ) 206 | 207 | # Dataframe where each row corresponds to a cell, and each column corresponds 208 | # to a gene feature. 209 | matrix = pd.DataFrame( 210 | matrix.transpose().todense(), 211 | columns=gene_list["gene_id"], 212 | index=cell_barcodes["cell_barcode"], 213 | dtype="int32", 214 | ) 215 | 216 | # Dataframe mapping cell barcodes to metadata about that cell (e.g. which CRISPR 217 | # guides were applied to that cell). Unfortunately, this list has a different 218 | # ordering from the count matrix, so we have to be careful combining the metadata 219 | # and count data. 220 | cell_identities = pd.read_csv( 221 | os.path.join(download_path, "GSE133344_filtered_cell_identities.csv.gz") 222 | ).set_index("cell_barcode") 223 | 224 | # This merge call reorders our metadata dataframe to match the ordering in the 225 | # count matrix. Some cells in `cell_barcodes` do not have metadata associated with 226 | # them, and their metadata values will be filled in as NaN. 227 | aligned_metadata = pd.merge( 228 | cell_barcodes, 229 | cell_identities, 230 | left_on="cell_barcode", 231 | right_index=True, 232 | how="left", 233 | ).set_index("cell_barcode") 234 | 235 | adata = AnnData(matrix) 236 | adata.obs = aligned_metadata 237 | 238 | # Filter out any cells that don't have metadata values. 239 | rows_without_nans = [ 240 | index for index, row in adata.obs.iterrows() if not row.isnull().any() 241 | ] 242 | adata = adata[rows_without_nans, :] 243 | 244 | # Remove these as suggested by the authors. 
See lines referring to
245 |     # NegCtrl1_NegCtrl0 in GI_generate_populations.ipynb in the Norman 2019 paper's
246 |     # Github repo https://github.com/thomasmaxwellnorman/Perturbseq_GI/
247 |     adata = adata[adata.obs["guide_identity"] != "NegCtrl1_NegCtrl0__NegCtrl1_NegCtrl0"]
248 | 
249 |     # We create a new metadata column with cleaner representations of CRISPR guide
250 |     # identities. The original format is <gene_1>_<gene_2>__<gene_1>_<gene_2>
251 |     adata.obs["guide_merged"] = adata.obs["guide_identity"]
252 | 
253 |     control_regex = re.compile(r"NegCtrl(.*)_NegCtrl(.*)+NegCtrl(.*)_NegCtrl(.*)")
254 |     for i in adata.obs["guide_merged"].unique():
255 |         if control_regex.match(i):
256 |             # For any cells that only had control guides, we don't care about the
257 |             # specific IDs of the guides. Here we relabel them just as "ctrl".
258 |             adata.obs["guide_merged"].replace(i, "ctrl", inplace=True)
259 |         else:
260 |             # Otherwise, we reformat the guide label to be <gene_1>+<gene_2>. If either
261 |             # gene's guide was a control, we replace it with "ctrl".
262 |             split = i.split("__")[0]
263 |             split = split.split("_")
264 |             for j, string in enumerate(split):
265 |                 if "NegCtrl" in split[j]:
266 |                     split[j] = "ctrl"
267 |             adata.obs["guide_merged"].replace(i, f"{split[0]}+{split[1]}", inplace=True)
268 | 
269 |     guides_to_programs = {}
270 |     guides_to_programs.update(dict.fromkeys(G1_CYCLE, "G1 cell cycle arrest"))
271 |     guides_to_programs.update(dict.fromkeys(ERYTHROID, "Erythroid"))
272 |     guides_to_programs.update(dict.fromkeys(PIONEER_FACTORS, "Pioneer factors"))
273 |     guides_to_programs.update(
274 |         dict.fromkeys(GRANULOCYTE_APOPTOSIS, "Granulocyte/apoptosis")
275 |     )
276 |     guides_to_programs.update(dict.fromkeys(PRO_GROWTH, "Pro-growth"))
277 |     guides_to_programs.update(dict.fromkeys(MEGAKARYOCYTE, "Megakaryocyte"))
278 |     guides_to_programs.update(dict.fromkeys(["ctrl"], "Ctrl"))
279 | 
280 |     # We only keep cells whose guides were either controls or are labeled with a
281 |     # specific gene program
282 |     adata = adata[adata.obs["guide_merged"].isin(guides_to_programs.keys())]
283 |     adata.obs["gene_program"] = [
284 |         guides_to_programs[x] for x in adata.obs["guide_merged"]
285 |     ]
286 | 
287 |     adata.obs["good_coverage"] = adata.obs["good_coverage"].astype(bool)
288 | 
289 |     adata.layers["count"] = adata.X.copy()
290 |     sc.pp.normalize_total(adata)
291 |     sc.pp.log1p(adata)
292 |     adata.raw = adata
293 |     sc.pp.highly_variable_genes(
294 |         adata, flavor="seurat_v3", n_top_genes=n_top_genes, layer="count", subset=True
295 |     )
296 |     adata = adata[adata.layers["count"].sum(1) != 0]  # Remove cells with all zeros.
297 |     return adata
--------------------------------------------------------------------------------
/contrastive_vi/data/datasets/papalexi_2021.py:
--------------------------------------------------------------------------------
1 | """
2 | Download, read, and preprocess Papalexi et al. (2021) expression data.
3 | 
4 | Single-cell expression data from Papalexi et al. Characterizing the molecular regulation
5 | of inhibitory immune checkpoints with multimodal single-cell screens. Nature
6 | Genetics (2021).
7 | """
8 | import os
9 | import shutil
10 | 
11 | import constants
12 | import pandas as pd
13 | import scanpy as sc
14 | from anndata import AnnData
15 | 
16 | from contrastive_vi.data.utils import download_binary_file
17 | 
18 | 
19 | def download_papalexi_2021(output_path: str) -> None:
20 |     """
21 |     Download Papalexi et al. 2021 data from the hosting URLs.
22 | 23 | Args: 24 | ---- 25 | output_path: Output path to store the downloaded and unzipped 26 | directories. 27 | 28 | Returns 29 | ------- 30 | None. File directories are downloaded to output_path. 31 | """ 32 | 33 | counts_data_url = ( 34 | "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE153056&format=file" 35 | ) 36 | data_output_filename = os.path.join(output_path, "GSE153056_RAW.tar") 37 | download_binary_file(counts_data_url, data_output_filename) 38 | shutil.unpack_archive(data_output_filename, output_path) 39 | 40 | metadata_url = ( 41 | "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE153056&" 42 | "format=file&file=GSE153056_ECCITE_metadata.tsv.gz" 43 | ) 44 | metadata_filename = os.path.join(output_path, metadata_url.split("=")[-1]) 45 | download_binary_file(metadata_url, metadata_filename) 46 | 47 | 48 | def read_papalexi_2021(file_directory: str) -> pd.DataFrame: 49 | """ 50 | Read the expression data for Papalexi et al. 2021 in the given directory. 51 | 52 | Args: 53 | ---- 54 | file_directory: Directory containing Papalexi et al. 2021 data. 55 | 56 | Returns 57 | ------- 58 | A pandas dataframe, with each column representing a cell 59 | and each row representing a gene feature. 60 | """ 61 | 62 | matrix = pd.read_csv( 63 | os.path.join(file_directory, "GSM4633614_ECCITE_cDNA_counts.tsv.gz"), 64 | sep="\t", 65 | index_col=0, 66 | ) 67 | return matrix 68 | 69 | 70 | def preprocess_papalexi_2021(download_path: str, n_top_genes: int) -> AnnData: 71 | """ 72 | Preprocess expression data from Papalexi et al. 2021. 73 | 74 | Args: 75 | ---- 76 | download_path: Path containing the downloaded Papalexi et al. 2021 data files. 77 | n_top_genes: Number of most variable genes to retain. 78 | 79 | Returns 80 | ------- 81 | An AnnData object containing single-cell expression data. The layer 82 | "count" contains the count data for the most variable genes. The .X 83 | variable contains the normalized and log-transformed data for the most variable 84 | genes. A copy of data with all genes is stored in .raw. 85 | """ 86 | 87 | df = read_papalexi_2021(download_path) 88 | 89 | # Switch dataframe from gene rows and cell columns to cell rows and gene columns 90 | df = df.transpose() 91 | 92 | metadata = pd.read_csv( 93 | os.path.join(download_path, "GSE153056_ECCITE_metadata.tsv.gz"), 94 | sep="\t", 95 | index_col=0, 96 | ) 97 | 98 | # Note: By initializing the anndata object from a dataframe, variable names 99 | # are automatically stored in adata.var 100 | adata = AnnData(df) 101 | adata.obs = metadata 102 | 103 | # Protein measurements also collected as part of CITE-Seq 104 | protein_counts_df = pd.read_csv( 105 | os.path.join(download_path, "GSM4633615_ECCITE_ADT_counts.tsv.gz"), 106 | sep="\t", 107 | index_col=0, 108 | ) 109 | 110 | # Switch dataframe from protein rows and cell columns to cell rows and protein 111 | # columns 112 | protein_counts_df = protein_counts_df.transpose() 113 | 114 | # Storing protein counts in an obsm field as expected by totalVI 115 | # (see https://docs.scvi-tools.org/en/stable/tutorials/notebooks/totalVI.html 116 | # for an example). Since `protein_counts_df` is annotated with protein names, 117 | # our obsm field will retain them as well. 
118 |     adata.obsm[constants.PROTEIN_EXPRESSION_KEY] = protein_counts_df
119 | 
120 |     adata.layers["count"] = adata.X.copy()
121 |     sc.pp.normalize_total(adata)
122 |     sc.pp.log1p(adata)
123 |     adata.raw = adata
124 |     sc.pp.highly_variable_genes(
125 |         adata, flavor="seurat_v3", n_top_genes=n_top_genes, layer="count", subset=True
126 |     )
127 |     adata = adata[adata.layers["count"].sum(1) != 0]  # Remove cells with all zeros.
128 |     return adata
129 | 
--------------------------------------------------------------------------------
/contrastive_vi/data/datasets/zheng_2017.py:
--------------------------------------------------------------------------------
1 | """
2 | Download, read, and preprocess Zheng et al. (2017) expression data.
3 | 
4 | Single-cell expression data from Zheng et al. Massively parallel digital
5 | transcriptional profiling of single cells. Nature Communications (2017).
6 | """
7 | import os
8 | import shutil
9 | 
10 | import numpy as np
11 | import pandas as pd
12 | import scanpy as sc
13 | from anndata import AnnData
14 | from scipy.io import mmread
15 | 
16 | from contrastive_vi.data.utils import download_binary_file
17 | 
18 | 
19 | def download_zheng_2017(output_path: str) -> None:
20 |     """
21 |     Download Zheng et al. 2017 data from the hosting URLs.
22 | 
23 |     Args:
24 |     ----
25 |         output_path: Output path to store the downloaded and unzipped
26 |             directories.
27 | 
28 |     Returns
29 |     -------
30 |         None. File directories are downloaded and unzipped in output_path.
31 |     """
32 |     host = "https://cf.10xgenomics.com/samples/cell-exp/1.1.0/"
33 |     host_directories = [
34 |         (
35 |             "aml027_post_transplant/"
36 |             "aml027_post_transplant_filtered_gene_bc_matrices.tar.gz"
37 |         ),
38 |         (
39 |             "aml027_pre_transplant/"
40 |             "aml027_pre_transplant_filtered_gene_bc_matrices.tar.gz"
41 |         ),
42 |         (
43 |             "aml035_post_transplant/"
44 |             "aml035_post_transplant_filtered_gene_bc_matrices.tar.gz"
45 |         ),
46 |         (
47 |             "aml035_pre_transplant/"
48 |             "aml035_pre_transplant_filtered_gene_bc_matrices.tar.gz"
49 |         ),
50 |         (
51 |             "frozen_bmmc_healthy_donor1/"
52 |             "frozen_bmmc_healthy_donor1_filtered_gene_bc_matrices.tar.gz"
53 |         ),
54 |         (
55 |             "frozen_bmmc_healthy_donor2/"
56 |             "frozen_bmmc_healthy_donor2_filtered_gene_bc_matrices.tar.gz"
57 |         ),
58 |     ]
59 |     urls = [host + host_directory for host_directory in host_directories]
60 |     output_filenames = [os.path.join(output_path, url.split("/")[-1]) for url in urls]
61 |     for url, output_filename in zip(urls, output_filenames):
62 |         download_binary_file(url, output_filename)
63 |         output_dir = output_filename.replace(".tar.gz", "")
64 |         shutil.unpack_archive(output_filename, output_dir)
65 | 
66 | 
67 | def read_zheng_2017(file_directory: str) -> pd.DataFrame:
68 |     """
69 |     Read the expression data in a downloaded and unzipped file directory.
70 | 
71 |     Args:
72 |     ----
73 |         file_directory: A downloaded and unzipped file directory.
74 | 
75 |     Returns
76 |     -------
77 |         A data frame containing single-cell gene expression counts, with cell
78 |             identification barcodes as column names and gene IDs as indices.
79 | """ 80 | data = mmread( 81 | os.path.join(file_directory, "filtered_matrices_mex/hg19/matrix.mtx") 82 | ).toarray() 83 | genes = pd.read_table( 84 | os.path.join(file_directory, "filtered_matrices_mex/hg19/genes.tsv"), 85 | header=None, 86 | ) 87 | barcodes = pd.read_table( 88 | os.path.join(file_directory, "filtered_matrices_mex/hg19/barcodes.tsv"), 89 | header=None, 90 | ) 91 | return pd.DataFrame( 92 | data, index=genes.iloc[:, 0].values, columns=barcodes.iloc[:, 0].values 93 | ) 94 | 95 | 96 | def preprocess_zheng_2017(download_path: str, n_top_genes: int) -> AnnData: 97 | """ 98 | Preprocess expression data from Zheng et al. 2017. 99 | 100 | Args: 101 | ---- 102 | download_path: Path containing the downloaded and unzipped file 103 | directories. 104 | n_top_genes: Number of most variable genes to retain. 105 | 106 | Returns 107 | ------- 108 | An AnnData object containing single-cell expression data. The layer 109 | "count" contains the count data for the most variable genes. The X 110 | variable contains the total-count-normalized and log-transformed data 111 | for the most variable genes (a copy with all the genes is stored in 112 | .raw). 113 | """ 114 | file_directory_dict = { 115 | "aml027_pre_transplant": ("aml027_pre_transplant_filtered_gene_bc_matrices"), 116 | "aml027_post_transplant": ("aml027_post_transplant_filtered_gene_bc_matrices"), 117 | "aml035_pre_transplant": ("aml035_pre_transplant_filtered_gene_bc_matrices"), 118 | "aml035_post_transplant": ("aml035_post_transplant_filtered_gene_bc_matrices"), 119 | "donor1_healthy": ("frozen_bmmc_healthy_donor1_filtered_gene_bc_matrices"), 120 | "donor2_healthy": ("frozen_bmmc_healthy_donor2_filtered_gene_bc_matrices"), 121 | } 122 | df_dict = { 123 | sample_id: read_zheng_2017(os.path.join(download_path, file_directory)) 124 | for sample_id, file_directory in file_directory_dict.items() 125 | } 126 | gene_set_list = [] 127 | for sample_id, df in df_dict.items(): 128 | df = df.iloc[:, np.sum(df.values, axis=0) != 0] 129 | df = df.iloc[np.sum(df.values, axis=1) != 0, :] 130 | df = df.transpose() 131 | gene_set_list.append(set(df.columns)) 132 | patient_id, condition = sample_id.split("_", 1) 133 | df["patient_id"] = patient_id 134 | df["condition"] = condition 135 | df_dict[sample_id] = df 136 | shared_genes = list(set.intersection(*gene_set_list)) 137 | data_list = [] 138 | meta_data_list = [] 139 | for df in df_dict.values(): 140 | data_list.append(df[shared_genes]) 141 | meta_data_list.append(df[["patient_id", "condition"]]) 142 | data = pd.concat(data_list) 143 | meta_data = pd.concat(meta_data_list) 144 | adata = AnnData(X=data.reset_index(drop=True), obs=meta_data.reset_index(drop=True)) 145 | adata.layers["count"] = adata.X.copy() 146 | sc.pp.normalize_total(adata) 147 | sc.pp.log1p(adata) 148 | adata.raw = adata 149 | sc.pp.highly_variable_genes( 150 | adata, 151 | flavor="seurat_v3", 152 | n_top_genes=n_top_genes, 153 | layer="count", 154 | subset=True, 155 | ) 156 | adata = adata[adata.layers["count"].sum(1) != 0] # Remove cells with all zeros. 157 | return adata 158 | -------------------------------------------------------------------------------- /contrastive_vi/data/utils.py: -------------------------------------------------------------------------------- 1 | """Data preprocessing utilities.""" 2 | import os 3 | 4 | import requests 5 | from anndata import AnnData 6 | 7 | 8 | def download_binary_file(file_url: str, output_path: str) -> None: 9 | """ 10 | Download binary data file from a URL. 
11 | 12 | Args: 13 | ---- 14 | file_url: URL where the file is hosted. 15 | output_path: Output path for the downloaded file. 16 | 17 | Returns 18 | ------- 19 | None. 20 | """ 21 | request = requests.get(file_url) 22 | with open(output_path, "wb") as f: 23 | f.write(request.content) 24 | print(f"Downloaded data from {file_url} at {output_path}") 25 | 26 | 27 | def save_preprocessed_adata(adata: AnnData, output_path: str) -> None: 28 | """ 29 | Save given AnnData object with preprocessed data to disk using our dataset file 30 | naming convention. 31 | 32 | Args: 33 | ---- 34 | adata: AnnData object containing expression count data as well as metadata. 35 | output_path: Path to save resulting file. 36 | 37 | Returns 38 | ------- 39 | None. Provided AnnData object is saved to disk in a subdirectory called 40 | "preprocessed" in output_path. 41 | """ 42 | preprocessed_directory = os.path.join(output_path, "preprocessed") 43 | os.makedirs(preprocessed_directory, exist_ok=True) 44 | n_genes = adata.shape[1] 45 | filename = os.path.join( 46 | preprocessed_directory, 47 | f"adata_top_{n_genes}_genes.h5ad", 48 | ) 49 | adata.write_h5ad(filename=filename) 50 | -------------------------------------------------------------------------------- /contrastive_vi/model/__init__.py: -------------------------------------------------------------------------------- 1 | """scvi-tools Model classes for contrastive-VI.""" 2 | from .contrastive_vi import ContrastiveVIModel as ContrastiveVI 3 | from .total_contrastive_vi import TotalContrastiveVIModel as TotalContrastiveVI 4 | 5 | __all__ = ["ContrastiveVI", "TotalContrastiveVI"] 6 | -------------------------------------------------------------------------------- /contrastive_vi/model/base/__init__.py: -------------------------------------------------------------------------------- 1 | """Reusable Model classes for inheritance.""" 2 | -------------------------------------------------------------------------------- /contrastive_vi/model/base/training_mixin.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mixin classes for pre-coded features. 3 | For more details on Mixin classes, see 4 | https://docs.scvi-tools.org/en/0.9.0/user_guide/notebooks/model_user_guide.html#Mixing-in-pre-coded-features 5 | """ 6 | 7 | from typing import List, Optional, Union 8 | 9 | import numpy as np 10 | from scvi.train import TrainingPlan, TrainRunner 11 | 12 | from contrastive_vi.data.dataloaders.data_splitting import ContrastiveDataSplitter 13 | 14 | 15 | class ContrastiveTrainingMixin: 16 | """General methods for contrastive learning.""" 17 | 18 | def train( 19 | self, 20 | background_indices: List[int], 21 | target_indices: List[int], 22 | max_epochs: Optional[int] = None, 23 | use_gpu: Optional[Union[str, int, bool]] = None, 24 | train_size: float = 0.9, 25 | validation_size: Optional[float] = None, 26 | batch_size: int = 128, 27 | early_stopping: bool = False, 28 | plan_kwargs: Optional[dict] = None, 29 | **trainer_kwargs, 30 | ) -> None: 31 | """ 32 | Train a contrastive model. 33 | 34 | Args: 35 | ---- 36 | background_indices: Indices for background samples in `adata`. 37 | target_indices: Indices for target samples in `adata`. 38 | max_epochs: Number of passes through the dataset. If `None`, default to 39 | `np.min([round((20000 / n_cells) * 400), 400])`. 
40 | use_gpu: Use default GPU if available (if `None` or `True`), or index of 41 | GPU to use (if `int`), or name of GPU (if `str`, e.g., `"cuda:0"`), 42 | or use CPU (if `False`). 43 | train_size: Size of training set in the range [0.0, 1.0]. 44 | validation_size: Size of the validation set. If `None`, default to 45 | `1 - train_size`. If `train_size + validation_size < 1`, the remaining 46 | cells belong to the test set. 47 | batch_size: Mini-batch size to use during training. 48 | early_stopping: Perform early stopping. Additional arguments can be passed 49 | in `**kwargs`. See :class:`~scvi.train.Trainer` for further options. 50 | plan_kwargs: Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword 51 | arguments passed to `train()` will overwrite values present 52 | in `plan_kwargs`, when appropriate. 53 | **trainer_kwargs: Other keyword args for :class:`~scvi.train.Trainer`. 54 | 55 | Returns 56 | ------- 57 | None. The model is trained. 58 | """ 59 | if max_epochs is None: 60 | n_cells = self.adata.n_obs 61 | max_epochs = np.min([round((20000 / n_cells) * 400), 400]) 62 | 63 | plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict() 64 | 65 | data_splitter = ContrastiveDataSplitter( 66 | self.adata_manager, 67 | background_indices, 68 | target_indices, 69 | train_size=train_size, 70 | validation_size=validation_size, 71 | batch_size=batch_size, 72 | use_gpu=use_gpu, 73 | ) 74 | training_plan = TrainingPlan(self.module, **plan_kwargs) 75 | 76 | es = "early_stopping" 77 | trainer_kwargs[es] = ( 78 | early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es] 79 | ) 80 | runner = TrainRunner( 81 | self, 82 | training_plan=training_plan, 83 | data_splitter=data_splitter, 84 | max_epochs=max_epochs, 85 | use_gpu=use_gpu, 86 | **trainer_kwargs, 87 | ) 88 | return runner() 89 | -------------------------------------------------------------------------------- /contrastive_vi/model/contrastive_vi.py: -------------------------------------------------------------------------------- 1 | """Model class for contrastive-VI for single cell expression data.""" 2 | 3 | import logging 4 | import warnings 5 | from functools import partial 6 | from typing import Dict, Iterable, List, Optional, Sequence, Union 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | from anndata import AnnData 12 | from scvi import REGISTRY_KEYS 13 | from scvi.data import AnnDataManager 14 | from scvi.data.fields import ( 15 | CategoricalJointObsField, 16 | CategoricalObsField, 17 | LayerField, 18 | NumericalJointObsField, 19 | NumericalObsField, 20 | ) 21 | from scvi.dataloaders import AnnDataLoader 22 | from scvi.model._utils import ( 23 | _get_batch_code_from_category, 24 | _init_library_size, 25 | scrna_raw_counts_properties, 26 | ) 27 | from scvi.model.base import BaseModelClass 28 | from scvi.model.base._utils import _de_core 29 | from scvi.utils import setup_anndata_dsp 30 | 31 | from contrastive_vi.model.base.training_mixin import ContrastiveTrainingMixin 32 | from contrastive_vi.module.contrastive_vi import ContrastiveVIModule 33 | 34 | logger = logging.getLogger(__name__) 35 | Number = Union[int, float] 36 | 37 | 38 | class ContrastiveVIModel(ContrastiveTrainingMixin, BaseModelClass): 39 | """ 40 | Model class for contrastive-VI. 41 | Args: 42 | ---- 43 | adata: AnnData object that has been registered via 44 | `ContrastiveVIModel.setup_anndata`. 45 | n_batch: Number of batches. If 0, no batch effect correction is performed. 
46 |     n_hidden: Number of nodes per hidden layer.
47 |     n_background_latent: Dimensionality of the background latent space, which
48 |         captures variations shared between background and target data.
49 |     n_salient_latent: Dimensionality of the salient latent space, which
50 |         captures variations enriched in the target data.
51 |     n_layers: Number of hidden layers used for encoder and decoder NNs.
52 |     dropout_rate: Dropout rate for neural networks.
53 |     use_observed_lib_size: Use observed library size for RNA as scaling factor in
54 |         mean of conditional distribution.
55 |     wasserstein_penalty: Weight of the Wasserstein distance loss that further
56 |         discourages shared variations from leaking into the salient latent
57 |         space.
58 |     """
59 | 
60 |     def __init__(
61 |         self,
62 |         adata: AnnData,
63 |         n_batch: int = 0,
64 |         n_hidden: int = 128,
65 |         n_background_latent: int = 10,
66 |         n_salient_latent: int = 10,
67 |         n_layers: int = 1,
68 |         dropout_rate: float = 0.1,
69 |         use_observed_lib_size: bool = True,
70 |         wasserstein_penalty: float = 0,
71 |     ) -> None:
72 |         super(ContrastiveVIModel, self).__init__(adata)
73 | 
74 |         n_cats_per_cov = (
75 |             self.adata_manager.get_state_registry(
76 |                 REGISTRY_KEYS.CAT_COVS_KEY
77 |             ).n_cats_per_key
78 |             if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry
79 |             else None
80 |         )
81 |         n_batch = self.summary_stats.n_batch
82 |         use_size_factor_key = (
83 |             REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry
84 |         )
85 |         library_log_means, library_log_vars = None, None
86 |         if not use_size_factor_key:
87 |             library_log_means, library_log_vars = _init_library_size(
88 |                 self.adata_manager, n_batch
89 |             )
90 | 
91 |         self.module = ContrastiveVIModule(
92 |             n_input=self.summary_stats["n_vars"],
93 |             n_batch=n_batch,
94 |             n_hidden=n_hidden,
95 |             n_background_latent=n_background_latent,
96 |             n_salient_latent=n_salient_latent,
97 |             n_layers=n_layers,
98 |             dropout_rate=dropout_rate,
99 |             use_observed_lib_size=use_observed_lib_size,
100 |             library_log_means=library_log_means,
101 |             library_log_vars=library_log_vars,
102 |             wasserstein_penalty=wasserstein_penalty,
103 |         )
104 |         self._model_summary_string = "Contrastive-VI."
105 |         # Necessary line to get params to be used for saving and loading.
106 |         self.init_params_ = self._get_init_params(locals())
107 |         logger.info("The model has been initialized")
108 | 
109 |     @classmethod
110 |     @setup_anndata_dsp.dedent
111 |     def setup_anndata(
112 |         cls,
113 |         adata: AnnData,
114 |         layer: Optional[str] = None,
115 |         batch_key: Optional[str] = None,
116 |         labels_key: Optional[str] = None,
117 |         size_factor_key: Optional[str] = None,
118 |         categorical_covariate_keys: Optional[List[str]] = None,
119 |         continuous_covariate_keys: Optional[List[str]] = None,
120 |         **kwargs,
121 |     ) -> Optional[AnnData]:
122 |         """
123 |         Set up AnnData instance for contrastive-VI model.
124 | 
125 |         Args:
126 |         ----
127 |         adata: AnnData object containing raw counts. Rows represent cells, columns
128 |             represent features.
129 |         layer: If not None, uses this as the key in adata.layers for raw count data.
130 |         batch_key: Key in `adata.obs` for batch information. Categories will
131 |             automatically be converted into integer categories and saved to
132 |             `adata.obs["_scvi_batch"]`. If None, assign the same batch to all the
133 |             data.
134 |         labels_key: Key in `adata.obs` for label information. Categories will
135 |             automatically be converted into integer categories and saved to
136 |             `adata.obs["_scvi_labels"]`. If None, assign the same label to all the
137 |             data.
138 |         size_factor_key: Key in `adata.obs` for size factor information. Instead of
139 |             using library size as a size factor, the provided size factor column
140 |             will be used as offset in the mean of the likelihood. Assumed to be on
141 |             linear scale.
142 |         categorical_covariate_keys: Keys in `adata.obs` corresponding to categorical
143 |             data. Used in some models.
144 |         continuous_covariate_keys: Keys in `adata.obs` corresponding to continuous
145 |             data. Used in some models.
146 | 
147 |         Returns
148 |         -------
149 |         None. The provided `adata` is set up in place for the contrastive-VI
150 |         model.
151 |         """
152 |         setup_method_args = cls._get_setup_method_args(**locals())
153 |         anndata_fields = [
154 |             LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
155 |             CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
156 |             CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
157 |             NumericalObsField(
158 |                 REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False
159 |             ),
160 |             CategoricalJointObsField(
161 |                 REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys
162 |             ),
163 |             NumericalJointObsField(
164 |                 REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys
165 |             ),
166 |         ]
167 |         adata_manager = AnnDataManager(
168 |             fields=anndata_fields, setup_method_args=setup_method_args
169 |         )
170 |         adata_manager.register_fields(adata, **kwargs)
171 |         cls.register_manager(adata_manager)
172 | 
173 |     @torch.no_grad()
174 |     def get_latent_representation(
175 |         self,
176 |         adata: Optional[AnnData] = None,
177 |         indices: Optional[Sequence[int]] = None,
178 |         give_mean: bool = True,
179 |         batch_size: Optional[int] = None,
180 |         representation_kind: str = "salient",
181 |     ) -> np.ndarray:
182 |         """
183 |         Return the background or salient latent representation for each cell.
184 | 
185 |         Args:
186 |         ----
187 |         adata: AnnData object with equivalent structure to initial AnnData. If `None`,
188 |             defaults to the AnnData object used to initialize the model.
189 |         indices: Indices of cells in adata to use. If `None`, all cells are used.
190 |         give_mean: Give mean of distribution or sample from it.
191 |         batch_size: Mini-batch size for data loading into model. Defaults to
192 |             `scvi.settings.batch_size`.
193 |         representation_kind: Either "background" or "salient" for the corresponding
194 |             representation kind.
195 | 
196 |         Returns
197 |         -------
198 |         A numpy array with shape `(n_cells, n_latent)`.
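
        Example
        -------
        A minimal sketch (assumes `model` is a trained `ContrastiveVIModel`;
        variable names are illustrative)::

            salient = model.get_latent_representation(representation_kind="salient")
            background = model.get_latent_representation(
                representation_kind="background"
            )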
199 | """ 200 | available_representation_kinds = ["background", "salient"] 201 | assert representation_kind in available_representation_kinds, ( 202 | f"representation_kind = {representation_kind} is not one of" 203 | f" {available_representation_kinds}" 204 | ) 205 | 206 | adata = self._validate_anndata(adata) 207 | data_loader = self._make_data_loader( 208 | adata=adata, 209 | indices=indices, 210 | batch_size=batch_size, 211 | shuffle=False, 212 | data_loader_class=AnnDataLoader, 213 | ) 214 | latent = [] 215 | for tensors in data_loader: 216 | x = tensors[REGISTRY_KEYS.X_KEY] 217 | batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] 218 | outputs = self.module._generic_inference( 219 | x=x, batch_index=batch_index, n_samples=1 220 | ) 221 | 222 | if representation_kind == "background": 223 | latent_m = outputs["qz_m"] 224 | latent_sample = outputs["z"] 225 | else: 226 | latent_m = outputs["qs_m"] 227 | latent_sample = outputs["s"] 228 | 229 | if give_mean: 230 | latent_sample = latent_m 231 | 232 | latent += [latent_sample.detach().cpu()] 233 | return torch.cat(latent).numpy() 234 | 235 | def get_normalized_expression_fold_change( 236 | self, 237 | adata: Optional[AnnData] = None, 238 | indices: Optional[Sequence[int]] = None, 239 | transform_batch: Optional[Sequence[Union[Number, str]]] = None, 240 | gene_list: Optional[Sequence[str]] = None, 241 | library_size: Union[float, str] = 1.0, 242 | n_samples: int = 1, 243 | batch_size: Optional[int] = None, 244 | ) -> np.ndarray: 245 | """ 246 | Return the normalized (decoded) gene expression. 247 | 248 | Args: 249 | ---- 250 | adata: AnnData object with equivalent structure to initial AnnData. If `None`, 251 | defaults to the AnnData object used to initialize the model. 252 | indices: Indices of cells in adata to use. If `None`, all cells are used. 253 | transform_batch: Batch to condition on. If transform_batch is: 254 | - None, then real observed batch is used. 255 | - int, then batch transform_batch is used. 256 | gene_list: Return frequencies of expression for a subset of genes. This can 257 | save memory when working with large datasets and few genes are of interest. 258 | library_size: Scale the expression frequencies to a common library size. This 259 | allows gene expression levels to be interpreted on a common scale of 260 | relevant magnitude. If set to `"latent"`, use the latent library size. 261 | n_samples: Number of posterior samples to use for estimation. 262 | batch_size: Mini-batch size for data loading into model. Defaults to 263 | `scvi.settings.batch_size`. 264 | 265 | Returns 266 | ------- 267 | If `n_samples` > 1, then the shape is `(samples, cells, genes)`. Otherwise, 268 | shape is `(cells, genes)`. Each element is fold change of salient normalized 269 | expression divided by background normalized expression. 
270 | """ 271 | exprs = self.get_normalized_expression( 272 | adata=adata, 273 | indices=indices, 274 | transform_batch=transform_batch, 275 | gene_list=gene_list, 276 | library_size=library_size, 277 | n_samples=n_samples, 278 | batch_size=batch_size, 279 | return_mean=False, 280 | return_numpy=True, 281 | ) 282 | salient_exprs = exprs["salient"] 283 | background_exprs = exprs["background"] 284 | fold_change = salient_exprs / background_exprs 285 | return fold_change 286 | 287 | @torch.no_grad() 288 | def get_normalized_expression( 289 | self, 290 | adata: Optional[AnnData] = None, 291 | indices: Optional[Sequence[int]] = None, 292 | transform_batch: Optional[Sequence[Union[Number, str]]] = None, 293 | gene_list: Optional[Sequence[str]] = None, 294 | library_size: Union[float, str] = 1.0, 295 | n_samples: int = 1, 296 | n_samples_overall: Optional[int] = None, 297 | batch_size: Optional[int] = None, 298 | return_mean: bool = True, 299 | return_numpy: Optional[bool] = None, 300 | ) -> Dict[str, Union[np.ndarray, pd.DataFrame]]: 301 | """ 302 | Return the normalized (decoded) gene expression. 303 | 304 | Args: 305 | ---- 306 | adata: AnnData object with equivalent structure to initial AnnData. If `None`, 307 | defaults to the AnnData object used to initialize the model. 308 | indices: Indices of cells in adata to use. If `None`, all cells are used. 309 | transform_batch: Batch to condition on. If transform_batch is: 310 | - None, then real observed batch is used. 311 | - int, then batch transform_batch is used. 312 | gene_list: Return frequencies of expression for a subset of genes. This can 313 | save memory when working with large datasets and few genes are of interest. 314 | library_size: Scale the expression frequencies to a common library size. This 315 | allows gene expression levels to be interpreted on a common scale of 316 | relevant magnitude. If set to `"latent"`, use the latent library size. 317 | n_samples: Number of posterior samples to use for estimation. 318 | n_samples_overall: The number of random samples in `adata` to use. 319 | batch_size: Mini-batch size for data loading into model. Defaults to 320 | `scvi.settings.batch_size`. 321 | return_mean: Whether to return the mean of the samples. 322 | return_numpy: Return a `numpy.ndarray` instead of a `pandas.DataFrame`. 323 | DataFrame includes gene names as columns. If either `n_samples=1` or 324 | `return_mean=True`, defaults to `False`. Otherwise, it defaults to `True`. 325 | 326 | Returns 327 | ------- 328 | A dictionary with keys "background" and "salient", with value as follows. 329 | If `n_samples` > 1 and `return_mean` is `False`, then the shape is 330 | `(samples, cells, genes)`. Otherwise, shape is `(cells, genes)`. In this 331 | case, return type is `pandas.DataFrame` unless `return_numpy` is `True`. 
332 | """ 333 | adata = self._validate_anndata(adata) 334 | if indices is None: 335 | indices = np.arange(adata.n_obs) 336 | if n_samples_overall is not None: 337 | indices = np.random.choice(indices, n_samples_overall) 338 | data_loader = self._make_data_loader( 339 | adata=adata, 340 | indices=indices, 341 | batch_size=batch_size, 342 | shuffle=False, 343 | data_loader_class=AnnDataLoader, 344 | ) 345 | 346 | transform_batch = _get_batch_code_from_category( 347 | self.get_anndata_manager(adata, required=True), transform_batch 348 | ) 349 | 350 | if gene_list is None: 351 | gene_mask = slice(None) 352 | else: 353 | all_genes = adata.var_names 354 | gene_mask = [True if gene in gene_list else False for gene in all_genes] 355 | 356 | if n_samples > 1 and return_mean is False: 357 | if return_numpy is False: 358 | warnings.warn( 359 | "return_numpy must be True if n_samples > 1 and" 360 | " return_mean is False, returning np.ndarray" 361 | ) 362 | return_numpy = True 363 | if library_size == "latent": 364 | generative_output_key = "px_rate" 365 | scaling = 1 366 | else: 367 | generative_output_key = "px_scale" 368 | scaling = library_size 369 | 370 | background_exprs = [] 371 | salient_exprs = [] 372 | for tensors in data_loader: 373 | x = tensors[REGISTRY_KEYS.X_KEY] 374 | batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] 375 | background_per_batch_exprs = [] 376 | salient_per_batch_exprs = [] 377 | for batch in transform_batch: 378 | if batch is not None: 379 | batch_index = torch.ones_like(batch_index) * batch 380 | inference_outputs = self.module._generic_inference( 381 | x=x, batch_index=batch_index, n_samples=n_samples 382 | ) 383 | z = inference_outputs["z"] 384 | s = inference_outputs["s"] 385 | library = inference_outputs["library"] 386 | background_generative_outputs = self.module._generic_generative( 387 | z=z, s=torch.zeros_like(s), library=library, batch_index=batch_index 388 | ) 389 | salient_generative_outputs = self.module._generic_generative( 390 | z=z, s=s, library=library, batch_index=batch_index 391 | ) 392 | background_outputs = self._preprocess_normalized_expression( 393 | background_generative_outputs, 394 | generative_output_key, 395 | gene_mask, 396 | scaling, 397 | ) 398 | background_per_batch_exprs.append(background_outputs) 399 | salient_outputs = self._preprocess_normalized_expression( 400 | salient_generative_outputs, 401 | generative_output_key, 402 | gene_mask, 403 | scaling, 404 | ) 405 | salient_per_batch_exprs.append(salient_outputs) 406 | 407 | background_per_batch_exprs = np.stack( 408 | background_per_batch_exprs 409 | ) # Shape is (len(transform_batch) x batch_size x n_var). 410 | salient_per_batch_exprs = np.stack(salient_per_batch_exprs) 411 | background_exprs += [background_per_batch_exprs.mean(0)] 412 | salient_exprs += [salient_per_batch_exprs.mean(0)] 413 | 414 | if n_samples > 1: 415 | # The -2 axis correspond to cells. 
416 | background_exprs = np.concatenate(background_exprs, axis=-2) 417 | salient_exprs = np.concatenate(salient_exprs, axis=-2) 418 | else: 419 | background_exprs = np.concatenate(background_exprs, axis=0) 420 | salient_exprs = np.concatenate(salient_exprs, axis=0) 421 | if n_samples > 1 and return_mean: 422 | background_exprs = background_exprs.mean(0) 423 | salient_exprs = salient_exprs.mean(0) 424 | 425 | if return_numpy is None or return_numpy is False: 426 | genes = adata.var_names[gene_mask] 427 | samples = adata.obs_names[indices] 428 | background_exprs = pd.DataFrame( 429 | background_exprs, columns=genes, index=samples 430 | ) 431 | salient_exprs = pd.DataFrame(salient_exprs, columns=genes, index=samples) 432 | return {"background": background_exprs, "salient": salient_exprs} 433 | 434 | @torch.no_grad() 435 | def get_salient_normalized_expression( 436 | self, 437 | adata: Optional[AnnData] = None, 438 | indices: Optional[Sequence[int]] = None, 439 | transform_batch: Optional[Sequence[Union[Number, str]]] = None, 440 | gene_list: Optional[Sequence[str]] = None, 441 | library_size: Union[float, str] = 1.0, 442 | n_samples: int = 1, 443 | n_samples_overall: Optional[int] = None, 444 | batch_size: Optional[int] = None, 445 | return_mean: bool = True, 446 | return_numpy: Optional[bool] = None, 447 | ) -> Union[np.ndarray, pd.DataFrame]: 448 | """ 449 | Return the normalized (decoded) gene expression. 450 | 451 | Gene expressions are decoded from both the background and salient latent space. 452 | 453 | Args: 454 | ---- 455 | adata: AnnData object with equivalent structure to initial AnnData. If `None`, 456 | defaults to the AnnData object used to initialize the model. 457 | indices: Indices of cells in adata to use. If `None`, all cells are used. 458 | transform_batch: Batch to condition on. If transform_batch is: 459 | - None, then real observed batch is used. 460 | - int, then batch transform_batch is used. 461 | gene_list: Return frequencies of expression for a subset of genes. This can 462 | save memory when working with large datasets and few genes are of interest. 463 | library_size: Scale the expression frequencies to a common library size. This 464 | allows gene expression levels to be interpreted on a common scale of 465 | relevant magnitude. If set to `"latent"`, use the latent library size. 466 | n_samples: Number of posterior samples to use for estimation. 467 | n_samples_overall: The number of random samples in `adata` to use. 468 | batch_size: Mini-batch size for data loading into model. Defaults to 469 | `scvi.settings.batch_size`. 470 | return_mean: Whether to return the mean of the samples. 471 | return_numpy: Return a `numpy.ndarray` instead of a `pandas.DataFrame`. 472 | DataFrame includes gene names as columns. If either `n_samples=1` or 473 | `return_mean=True`, defaults to `False`. Otherwise, it defaults to `True`. 474 | 475 | Returns 476 | ------- 477 | If `n_samples` > 1 and `return_mean` is `False`, then the shape is 478 | `(samples, cells, genes)`. Otherwise, shape is `(cells, genes)`. In this 479 | case, return type is `pandas.DataFrame` unless `return_numpy` is `True`. 
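
        Example
        -------
        A minimal sketch (assumes `model` is a trained `ContrastiveVIModel`)::

            salient_expr = model.get_salient_normalized_expression(n_samples=10)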
480 | """ 481 | exprs = self.get_normalized_expression( 482 | adata=adata, 483 | indices=indices, 484 | transform_batch=transform_batch, 485 | gene_list=gene_list, 486 | library_size=library_size, 487 | n_samples=n_samples, 488 | n_samples_overall=n_samples_overall, 489 | batch_size=batch_size, 490 | return_mean=return_mean, 491 | return_numpy=return_numpy, 492 | ) 493 | return exprs["salient"] 494 | 495 | @torch.no_grad() 496 | def get_specific_normalized_expression( 497 | self, 498 | adata: Optional[AnnData] = None, 499 | indices: Optional[Sequence[int]] = None, 500 | transform_batch: Optional[Sequence[Union[Number, str]]] = None, 501 | gene_list: Optional[Sequence[str]] = None, 502 | library_size: Union[float, str] = 1, 503 | n_samples: int = 1, 504 | n_samples_overall: Optional[int] = None, 505 | batch_size: Optional[int] = None, 506 | return_mean: bool = True, 507 | return_numpy: Optional[bool] = None, 508 | expression_type: Optional[str] = None, 509 | indices_to_return_salient: Optional[Sequence[int]] = None, 510 | ): 511 | """ 512 | Return the normalized (decoded) gene expression. 513 | 514 | Gene expressions are decoded from either the background or salient latent space. 515 | One of `expression_type` or `indices_to_return_salient` should have an input 516 | argument. 517 | 518 | Args: 519 | ---- 520 | adata: AnnData object with equivalent structure to initial AnnData. If `None`, 521 | defaults to the AnnData object used to initialize the model. 522 | indices: Indices of cells in adata to use. If `None`, all cells are used. 523 | transform_batch: Batch to condition on. If transform_batch is: 524 | - None, then real observed batch is used. 525 | - int, then batch transform_batch is used. 526 | gene_list: Return frequencies of expression for a subset of genes. This can 527 | save memory when working with large datasets and few genes are of interest. 528 | library_size: Scale the expression frequencies to a common library size. This 529 | allows gene expression levels to be interpreted on a common scale of 530 | relevant magnitude. If set to `"latent"`, use the latent library size. 531 | n_samples: Number of posterior samples to use for estimation. 532 | n_samples_overall: The number of random samples in `adata` to use. 533 | batch_size: Mini-batch size for data loading into model. Defaults to 534 | `scvi.settings.batch_size`. 535 | return_mean: Whether to return the mean of the samples. 536 | return_numpy: Return a `numpy.ndarray` instead of a `pandas.DataFrame`. 537 | DataFrame includes gene names as columns. If either `n_samples=1` or 538 | `return_mean=True`, defaults to `False`. Otherwise, it defaults to `True`. 539 | expression_type: One of {"salient", "background"} to specify the type of 540 | normalized expression to return. 541 | indices_to_return_salient: If `indices` is a subset of 542 | `indices_to_return_salient`, normalized expressions derived from background 543 | and salient latent embeddings are returned. If `indices` is not `None` and 544 | is not a subset of `indices_to_return_salient`, normalized expressions 545 | derived only from background latent embeddings are returned. 546 | 547 | Returns 548 | ------- 549 | If `n_samples` > 1 and `return_mean` is `False`, then the shape is 550 | `(samples, cells, genes)`. Otherwise, shape is `(cells, genes)`. In this 551 | case, return type is `pandas.DataFrame` unless `return_numpy` is `True`. 
552 | """ 553 | is_expression_type_none = expression_type is None 554 | is_indices_to_return_salient_none = indices_to_return_salient is None 555 | if is_expression_type_none and is_indices_to_return_salient_none: 556 | raise ValueError( 557 | "Both expression_type and indices_to_return_salient are None! " 558 | "Exactly one of them needs to be supplied with an input argument." 559 | ) 560 | elif (not is_expression_type_none) and (not is_indices_to_return_salient_none): 561 | raise ValueError( 562 | "Both expression_type and indices_to_return_salient have an input " 563 | "argument! Exactly one of them needs to be supplied with an input " 564 | "argument." 565 | ) 566 | else: 567 | exprs = self.get_normalized_expression( 568 | adata=adata, 569 | indices=indices, 570 | transform_batch=transform_batch, 571 | gene_list=gene_list, 572 | library_size=library_size, 573 | n_samples=n_samples, 574 | n_samples_overall=n_samples_overall, 575 | batch_size=batch_size, 576 | return_mean=return_mean, 577 | return_numpy=return_numpy, 578 | ) 579 | if not is_expression_type_none: 580 | return exprs[expression_type] 581 | else: 582 | if indices is None: 583 | indices = np.arange(adata.n_obs) 584 | if set(indices).issubset(set(indices_to_return_salient)): 585 | return exprs["salient"] 586 | else: 587 | return exprs["background"] 588 | 589 | def differential_expression( 590 | self, 591 | adata: Optional[AnnData] = None, 592 | groupby: Optional[str] = None, 593 | group1: Optional[Iterable[str]] = None, 594 | group2: Optional[str] = None, 595 | idx1: Optional[Union[Sequence[int], Sequence[bool], str]] = None, 596 | idx2: Optional[Union[Sequence[int], Sequence[bool], str]] = None, 597 | mode: str = "change", 598 | delta: float = 0.25, 599 | batch_size: Optional[int] = None, 600 | all_stats: bool = True, 601 | batch_correction: bool = False, 602 | batchid1: Optional[Iterable[str]] = None, 603 | batchid2: Optional[Iterable[str]] = None, 604 | fdr_target: float = 0.05, 605 | silent: bool = False, 606 | target_idx: Optional[Sequence[int]] = None, 607 | n_samples: int = 1, 608 | **kwargs, 609 | ) -> pd.DataFrame: 610 | r""" 611 | Perform differential expression analysis. 612 | 613 | Args: 614 | ---- 615 | adata: AnnData object with equivalent structure to initial AnnData. If `None`, 616 | defaults to the AnnData object used to initialize the model. 617 | groupby: The key of the observations grouping to consider. 618 | group1: Subset of groups, e.g. ["g1", "g2", "g3"], to which comparison shall be 619 | restricted, or all groups in `groupby` (default). 620 | group2: If `None`, compare each group in `group1` to the union of the rest of 621 | the groups in `groupby`. If a group identifier, compare with respect to this 622 | group. 623 | idx1: `idx1` and `idx2` can be used as an alternative to the AnnData keys. 624 | Custom identifier for `group1` that can be of three sorts: 625 | (1) a boolean mask, (2) indices, or (3) a string. If it is a string, then 626 | it will query indices that verifies conditions on adata.obs, as described 627 | in `pandas.DataFrame.query()`. If `idx1` is not `None`, this option 628 | overrides `group1` and `group2`. 629 | idx2: Custom identifier for `group2` that has the same properties as `idx1`. 630 | By default, includes all cells not specified in `idx1`. 631 | mode: Method for differential expression. See 632 | https://docs.scvi-tools.org/en/0.14.1/user_guide/background/differential_expression.html 633 | for more details. 
634 |         delta: Specific case of region inducing differential expression. In this case,
635 |             we suppose that R\[-delta, delta] does not induce differential expression
636 |             (change model default case).
637 |         batch_size: Mini-batch size for data loading into model. Defaults to
638 |             scvi.settings.batch_size.
639 |         all_stats: Concatenate count statistics (e.g., mean expression group 1) to DE
640 |             results.
641 |         batch_correction: Whether to correct for batch effects in DE inference.
642 |         batchid1: Subset of categories from `batch_key` registered in `setup_anndata`,
643 |             e.g. ["batch1", "batch2", "batch3"], for `group1`. Only used if
644 |             `batch_correction` is `True`, and by default all categories are used.
645 |         batchid2: Same as `batchid1` for `group2`. `batchid2` must either have null
646 |             intersection with `batchid1`, or be exactly equal to `batchid1`. When the
647 |             two sets are exactly equal, cells are compared by decoding on the same
648 |             batch. When sets have null intersection, cells from `group1` and `group2`
649 |             are decoded on each group in `group1` and `group2`, respectively.
650 |         fdr_target: Tag features as DE based on posterior expected false discovery rate.
651 |         silent: If `True`, disables the progress bar. Default: `False`.
652 |         target_idx: If not `None`, a boolean or integer identifier should be used for
653 |             cells in the contrastive target group. Normalized expression values derived
654 |             from both salient and background latent embeddings are used when
655 |             {group1, group2} is a subset of the target group, otherwise background
656 |             normalized expression values are used.
657 |         n_samples: Number of posterior samples to use for estimation.
658 |         **kwargs: Keyword args for
659 |             `scvi.model.base.DifferentialComputation.get_bayes_factors`.
660 | 
661 |         Returns
662 |         -------
663 |         Differential expression DataFrame.
664 |         """
665 |         adata = self._validate_anndata(adata)
666 |         col_names = adata.var_names
667 | 
668 |         if target_idx is not None:
669 |             target_idx = np.array(target_idx)
670 |             if target_idx.dtype is np.dtype("bool"):
671 |                 assert (
672 |                     len(target_idx) == adata.n_obs
673 |                 ), "target_idx mask must be the same length as adata!"
674 |                 target_idx = np.arange(adata.n_obs)[target_idx]
675 |             model_fn = partial(
676 |                 self.get_specific_normalized_expression,
677 |                 return_numpy=True,
678 |                 n_samples=n_samples,
679 |                 batch_size=batch_size,
680 |                 expression_type=None,
681 |                 indices_to_return_salient=target_idx,
682 |             )
683 |         else:
684 |             model_fn = partial(
685 |                 self.get_specific_normalized_expression,
686 |                 return_numpy=True,
687 |                 n_samples=n_samples,
688 |                 batch_size=batch_size,
689 |                 expression_type="salient",
690 |                 indices_to_return_salient=None,
691 |             )
692 | 
693 |         result = _de_core(
694 |             self.get_anndata_manager(adata, required=True),
695 |             model_fn,
696 |             groupby=groupby,
697 |             group1=group1,
698 |             group2=group2,
699 |             idx1=idx1,
700 |             idx2=idx2,
701 |             all_stats=all_stats,
702 |             all_stats_fn=scrna_raw_counts_properties,
703 |             col_names=col_names,
704 |             mode=mode,
705 |             batchid1=batchid1,
706 |             batchid2=batchid2,
707 |             delta=delta,
708 |             batch_correction=batch_correction,
709 |             fdr=fdr_target,
710 |             silent=silent,
711 |             **kwargs,
712 |         )
713 |         return result
714 | 
715 |     @staticmethod
716 |     @torch.no_grad()
717 |     def _preprocess_normalized_expression(
718 |         generative_outputs: Dict[str, torch.Tensor],
719 |         generative_output_key: str,
720 |         gene_mask: Union[list, slice],
721 |         scaling: float,
722 |     ) -> np.ndarray:
723 |         output = generative_outputs[generative_output_key]
724 |         output = output[..., gene_mask]
725 |         output *= scaling
726 |         output = output.cpu().numpy()
727 |         return output
728 | 
729 |     @torch.no_grad()
730 |     def get_latent_library_size(
731 |         self,
732 |         adata: Optional[AnnData] = None,
733 |         indices: Optional[Sequence[int]] = None,
734 |         give_mean: bool = True,
735 |         batch_size: Optional[int] = None,
736 |     ) -> np.ndarray:
737 |         r"""
738 |         Returns the latent library size for each cell.
739 |         This is denoted as :math:`\ell_n` in the scVI paper.
740 |         Parameters
741 |         ----------
742 |         adata
743 |             AnnData object with equivalent structure to initial AnnData. If `None`,
744 |             defaults to the AnnData object used to initialize the model.
745 |         indices
746 |             Indices of cells in adata to use. If `None`, all cells are used.
747 |         give_mean
748 |             Return the mean or a sample from the posterior distribution.
749 |         batch_size
750 |             Minibatch size for data loading into model. Defaults to
751 |             `scvi.settings.batch_size`.
752 |         """
753 |         self._check_if_trained(warn=False)
754 | 
755 |         adata = self._validate_anndata(adata)
756 |         scdl = self._make_data_loader(
757 |             adata=adata, indices=indices, batch_size=batch_size
758 |         )
759 |         libraries = []
760 |         for tensors in scdl:
761 |             x = tensors[REGISTRY_KEYS.X_KEY]
762 |             batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
763 |             outputs = self.module._generic_inference(x=x, batch_index=batch_index)
764 | 
765 |             library = outputs["library"]
766 |             if not give_mean:
767 |                 library = torch.exp(library)
768 |             else:
769 |                 ql_m = outputs.get("ql_m")
770 |                 ql_v = outputs.get("ql_v")
771 |                 if ql_m is None or ql_v is None:
772 |                     raise RuntimeError(
773 |                         "The module for this model does not compute the posterior "
774 |                         "distribution for the library size. Set `give_mean` to False "
775 |                         "to use the observed library size instead."
776 |                     )
777 |                 library = torch.distributions.LogNormal(ql_m, ql_v).mean
778 |             libraries += [library.cpu()]
779 |         return torch.cat(libraries).numpy()
780 | 
--------------------------------------------------------------------------------
/contrastive_vi/model/total_contrastive_vi.py:
--------------------------------------------------------------------------------
1 | """Model class for total-contrastiveVI for single-cell RNA and protein data."""
2 | 
3 | import logging
4 | import warnings
5 | from collections.abc import Iterable as IterableClass
6 | from functools import partial
7 | from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
8 | 
9 | import numpy as np
10 | import pandas as pd
11 | import torch
12 | from anndata import AnnData
13 | from scvi import REGISTRY_KEYS
14 | from scvi._compat import Literal
15 | from scvi.data import AnnDataManager
16 | from scvi.data.fields import (
17 |     CategoricalJointObsField,
18 |     CategoricalObsField,
19 |     LayerField,
20 |     NumericalJointObsField,
21 |     NumericalObsField,
22 |     ProteinObsmField,
23 | )
24 | from scvi.dataloaders import AnnDataLoader
25 | from scvi.model._utils import (
26 |     _get_batch_code_from_category,
27 |     _init_library_size,
28 |     cite_seq_raw_counts_properties,
29 | )
30 | from scvi.model.base import BaseModelClass
31 | from scvi.model.base._utils import _de_core
32 | from scvi.utils._docstrings import setup_anndata_dsp
33 | 
34 | from contrastive_vi.model.base.training_mixin import ContrastiveTrainingMixin
35 | from contrastive_vi.module.total_contrastive_vi import TotalContrastiveVIModule
36 | 
37 | logger = logging.getLogger(__name__)
38 | Number = Union[int, float]
39 | 
40 | 
41 | class TotalContrastiveVIModel(ContrastiveTrainingMixin, BaseModelClass):
42 |     """
43 |     Model class for total-contrastiveVI.
44 |     Args:
45 |     ----
46 |     adata: AnnData object that has been registered via
47 |         `TotalContrastiveVIModel.setup_anndata`.
48 |     n_hidden: Number of nodes per hidden layer.
49 |     n_background_latent: Dimensionality of the background latent space.
50 |     n_salient_latent: Dimensionality of the salient latent space.
51 |     gene_dispersion: Granularity of the gene dispersion parameter, one of
52 |         "gene", "gene-batch", "gene-label", or "gene-cell".
53 |     protein_dispersion: Granularity of the protein dispersion parameter, one of
54 |         "protein", "protein-batch", or "protein-label".
55 |     gene_likelihood: Likelihood for gene counts, one of "zinb" or "nb".
56 |     latent_distribution: Latent space distribution, one of "normal" or "ln".
57 |     empirical_protein_background_prior: Initialize the protein background prior
58 |         empirically by fitting a GMM for each of 100 cells per batch and averaging;
59 |         this only initializes a parameter that is learned during inference. The
60 |         default (`None`) sets this to `True` if greater than 10 proteins are used.
61 |     override_missing_proteins: If `True`, keep all-0 proteins instead of masking them.
62 |     wasserstein_penalty: Weight of the Wasserstein distance loss that further
63 |         discourages shared variations from leaking into the salient latent space.
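
    Example
    -------
    A minimal sketch (the `protein_expression` obsm key and the index variables
    are illustrative assumptions)::

        TotalContrastiveVIModel.setup_anndata(
            adata, protein_expression_obsm_key="protein_expression"
        )
        model = TotalContrastiveVIModel(adata)
        model.train(background_indices=bg_idx, target_indices=tg_idx)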
64 | """ 65 | 66 | def __init__( 67 | self, 68 | adata: AnnData, 69 | n_hidden: int = 128, 70 | n_background_latent: int = 10, 71 | n_salient_latent: int = 10, 72 | gene_dispersion: Literal[ 73 | "gene", "gene-batch", "gene-label", "gene-cell" 74 | ] = "gene", 75 | protein_dispersion: Literal[ 76 | "protein", "protein-batch", "protein-label" 77 | ] = "protein", 78 | gene_likelihood: Literal["zinb", "nb"] = "nb", 79 | latent_distribution: Literal["normal", "ln"] = "normal", 80 | empirical_protein_background_prior: Optional[bool] = None, 81 | override_missing_proteins: bool = False, 82 | wasserstein_penalty: float = 0, 83 | **model_kwargs, 84 | ) -> None: 85 | super(TotalContrastiveVIModel, self).__init__(adata) 86 | 87 | self.protein_state_registry = self.adata_manager.get_state_registry( 88 | REGISTRY_KEYS.PROTEIN_EXP_KEY 89 | ) 90 | if ( 91 | ProteinObsmField.PROTEIN_BATCH_MASK in self.protein_state_registry 92 | and not override_missing_proteins 93 | ): 94 | batch_mask = self.protein_state_registry.protein_batch_mask 95 | msg = ( 96 | "Some proteins have all 0 counts in some batches. " 97 | + "These proteins will be treated as missing measurements; however, " 98 | + "this can occur due to experimental design/biology. " 99 | + "Reinitialize the model with `override_missing_proteins=True`," 100 | + "to override this behavior." 101 | ) 102 | warnings.warn(msg, UserWarning) 103 | self._use_adversarial_classifier = True 104 | else: 105 | batch_mask = None 106 | self._use_adversarial_classifier = False 107 | 108 | emp_prior = ( 109 | empirical_protein_background_prior 110 | if empirical_protein_background_prior is not None 111 | else (self.summary_stats.n_proteins > 10) 112 | ) 113 | if emp_prior: 114 | prior_mean, prior_scale = self._get_totalvi_protein_priors(adata) 115 | else: 116 | prior_mean, prior_scale = None, None 117 | 118 | n_cats_per_cov = ( 119 | self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)[ 120 | CategoricalJointObsField.N_CATS_PER_KEY 121 | ] 122 | if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry 123 | else None 124 | ) 125 | 126 | n_batch = self.summary_stats.n_batch 127 | use_size_factor_key = ( 128 | REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry 129 | ) 130 | library_log_means, library_log_vars = None, None 131 | if not use_size_factor_key: 132 | library_log_means, library_log_vars = _init_library_size( 133 | self.adata_manager, n_batch 134 | ) 135 | 136 | self.module = TotalContrastiveVIModule( 137 | n_input_genes=self.summary_stats["n_vars"], 138 | n_input_proteins=self.summary_stats["n_proteins"], 139 | n_batch=n_batch, 140 | n_background_latent=n_background_latent, 141 | n_salient_latent=n_salient_latent, 142 | n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0), 143 | n_cats_per_cov=n_cats_per_cov, 144 | gene_dispersion=gene_dispersion, 145 | protein_dispersion=protein_dispersion, 146 | gene_likelihood=gene_likelihood, 147 | latent_distribution=latent_distribution, 148 | protein_batch_mask=batch_mask, 149 | protein_background_prior_mean=prior_mean, 150 | protein_background_prior_scale=prior_scale, 151 | use_size_factor_key=use_size_factor_key, 152 | library_log_means=library_log_means, 153 | library_log_vars=library_log_vars, 154 | wasserstein_penalty=wasserstein_penalty, 155 | **model_kwargs, 156 | ) 157 | self._model_summary_string = "totalContrastiveVI." 158 | # Necessary line to get params to be used for saving and loading. 
159 | self.init_params_ = self._get_init_params(locals()) 160 | logger.info("The model has been initialized") 161 | 162 | @classmethod 163 | @setup_anndata_dsp.dedent 164 | def setup_anndata( 165 | cls, 166 | adata: AnnData, 167 | protein_expression_obsm_key: str, 168 | protein_names_uns_key: Optional[str] = None, 169 | batch_key: Optional[str] = None, 170 | layer: Optional[str] = None, 171 | size_factor_key: Optional[str] = None, 172 | categorical_covariate_keys: Optional[List[str]] = None, 173 | continuous_covariate_keys: Optional[List[str]] = None, 174 | **kwargs, 175 | ) -> Optional[AnnData]: 176 | """ 177 | %(summary)s. 178 | Parameters 179 | ---------- 180 | %(param_adata)s 181 | protein_expression_obsm_key 182 | key in `adata.obsm` for protein expression data. 183 | protein_names_uns_key 184 | key in `adata.uns` for protein names. If None, will use the column names of 185 | `adata.obsm[protein_expression_obsm_key]` if it is a DataFrame, else will 186 | assign sequential names to proteins. 187 | %(param_batch_key)s 188 | %(param_layer)s 189 | %(param_size_factor_key)s 190 | %(param_cat_cov_keys)s 191 | %(param_cont_cov_keys)s 192 | %(param_copy)s 193 | Returns 194 | ------- 195 | %(returns)s 196 | """ 197 | setup_method_args = cls._get_setup_method_args(**locals()) 198 | batch_field = CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key) 199 | anndata_fields = [ 200 | LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), 201 | CategoricalObsField( 202 | REGISTRY_KEYS.LABELS_KEY, None 203 | ), # Default labels field for compatibility with TOTALVAE 204 | batch_field, 205 | NumericalObsField( 206 | REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False 207 | ), 208 | CategoricalJointObsField( 209 | REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys 210 | ), 211 | NumericalJointObsField( 212 | REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys 213 | ), 214 | ProteinObsmField( 215 | REGISTRY_KEYS.PROTEIN_EXP_KEY, 216 | protein_expression_obsm_key, 217 | use_batch_mask=True, 218 | batch_key=batch_field.attr_key, 219 | colnames_uns_key=protein_names_uns_key, 220 | is_count_data=True, 221 | ), 222 | ] 223 | adata_manager = AnnDataManager( 224 | fields=anndata_fields, setup_method_args=setup_method_args 225 | ) 226 | adata_manager.register_fields(adata, **kwargs) 227 | cls.register_manager(adata_manager) 228 | 229 | @torch.no_grad() 230 | def get_latent_representation( 231 | self, 232 | adata: Optional[AnnData] = None, 233 | indices: Optional[Sequence[int]] = None, 234 | give_mean: bool = True, 235 | batch_size: Optional[int] = None, 236 | representation_kind: str = "salient", 237 | ) -> np.ndarray: 238 | """ 239 | Return the background or salient latent representation for each cell. 240 | 241 | Args: 242 | ---- 243 | adata: AnnData object with equivalent structure to initial AnnData. If `None`, 244 | defaults to the AnnData object used to initialize the model. 245 | indices: Indices of cells in adata to use. If `None`, all cells are used. 246 | give_mean: Give mean of distribution or sample from it. 247 | batch_size: Mini-batch size for data loading into model. Defaults to 248 | `scvi.settings.batch_size`. 249 | representation_kind: Either "background" or "salient" for the corresponding 250 | representation kind. 251 | 252 | Returns 253 | ------- 254 | A numpy array with shape `(n_cells, n_latent)`. 
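
        Example
        -------
        A minimal sketch (assumes `model` is a trained `TotalContrastiveVIModel`)::

            s_mean = model.get_latent_representation(representation_kind="salient")
            s_draw = model.get_latent_representation(
                representation_kind="salient", give_mean=False
            )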
255 | """ 256 | available_representation_kinds = ["background", "salient"] 257 | assert representation_kind in available_representation_kinds, ( 258 | f"representation_kind = {representation_kind} is not one of" 259 | f" {available_representation_kinds}" 260 | ) 261 | 262 | adata = self._validate_anndata(adata) 263 | data_loader = self._make_data_loader( 264 | adata=adata, 265 | indices=indices, 266 | batch_size=batch_size, 267 | shuffle=False, 268 | data_loader_class=AnnDataLoader, 269 | ) 270 | latent = [] 271 | for tensors in data_loader: 272 | x = tensors[REGISTRY_KEYS.X_KEY] 273 | y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY] 274 | batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] 275 | outputs = self.module._generic_inference( 276 | x=x, y=y, batch_index=batch_index, n_samples=1 277 | ) 278 | 279 | if representation_kind == "background": 280 | latent_m = outputs["qz_m"] 281 | latent_sample = outputs["z"] 282 | else: 283 | latent_m = outputs["qs_m"] 284 | latent_sample = outputs["s"] 285 | 286 | if give_mean: 287 | latent_sample = latent_m 288 | 289 | latent += [latent_sample.detach().cpu()] 290 | return torch.cat(latent).numpy() 291 | 292 | @torch.no_grad() 293 | def get_normalized_expression( 294 | self, 295 | adata=None, 296 | indices=None, 297 | n_samples_overall: Optional[int] = None, 298 | transform_batch: Optional[Sequence[Union[Number, str]]] = None, 299 | gene_list: Optional[Sequence[str]] = None, 300 | protein_list: Optional[Sequence[str]] = None, 301 | library_size: Optional[Union[float, Literal["latent"]]] = 1, 302 | n_samples: int = 1, 303 | sample_protein_mixing: bool = False, 304 | scale_protein: bool = False, 305 | include_protein_background: bool = False, 306 | batch_size: Optional[int] = None, 307 | return_mean: bool = True, 308 | return_numpy: Optional[bool] = None, 309 | ) -> Dict[ 310 | str, Tuple[Union[np.ndarray, pd.DataFrame], Union[np.ndarray, pd.DataFrame]] 311 | ]: 312 | r""" 313 | Returns the normalized gene expression and protein expression. 314 | This is denoted as :math:`\rho_n` in the totalVI paper for genes, and TODO 315 | for proteins, :math:`(1-\pi_{nt})\alpha_{nt}\beta_{nt}`. 316 | Parameters 317 | ---------- 318 | adata 319 | AnnData object with equivalent structure to initial AnnData. If `None`, 320 | defaults to the AnnData object used to initialize the model. 321 | indices 322 | Indices of cells in adata to use. If `None`, all cells are used. 323 | n_samples_overall 324 | Number of samples to use in total 325 | transform_batch 326 | Batch to condition on. 327 | If transform_batch is: 328 | - None, then real observed batch is used 329 | - int, then batch transform_batch is used 330 | - List[int], then average over batches in list 331 | gene_list 332 | Return frequencies of expression for a subset of genes. 333 | This can save memory when working with large datasets and few genes are 334 | of interest. 335 | protein_list 336 | Return protein expression for a subset of genes. 337 | This can save memory when working with large datasets and few genes are 338 | of interest. 339 | library_size 340 | Scale the expression frequencies to a common library size. 341 | This allows gene expression levels to be interpreted on a common scale of 342 | relevant magnitude. 343 | n_samples 344 | Get sample scale from multiple samples. 
345 |         sample_protein_mixing
346 |             Sample mixing bernoulli, setting background to zero
347 |         scale_protein
348 |             Make protein expression sum to 1
349 |         include_protein_background
350 |             Include background component for protein expression
351 |         batch_size
352 |             Minibatch size for data loading into model. Defaults to
353 |             `scvi.settings.batch_size`.
354 |         return_mean
355 |             Whether to return the mean of the samples.
356 |         return_numpy
357 |             Return a `np.ndarray` instead of a `pd.DataFrame`. Includes gene
358 |             names as columns. If either n_samples=1 or return_mean=True, defaults to
359 |             False. Otherwise, it defaults to True.
360 |         Returns
361 |         -------
362 |         - **gene_normalized_expression** - normalized expression for RNA
363 |         - **protein_normalized_expression** - normalized expression for proteins
364 |         If ``n_samples`` > 1 and ``return_mean`` is False, then the shape is
365 |         ``(samples, cells, genes)``. Otherwise, shape is ``(cells, genes)``.
366 |         Return type is ``pd.DataFrame`` unless ``return_numpy`` is True.
367 |         """
368 |         adata = self._validate_anndata(adata)
369 |         adata_manager = self.get_anndata_manager(adata)
370 |         if indices is None:
371 |             indices = np.arange(adata.n_obs)
372 |         if n_samples_overall is not None:
373 |             indices = np.random.choice(indices, n_samples_overall)
374 |         post = self._make_data_loader(
375 |             adata=adata, indices=indices, batch_size=batch_size
376 |         )
377 | 
378 |         if gene_list is None:
379 |             gene_mask = slice(None)
380 |         else:
381 |             all_genes = adata.var_names
382 |             gene_mask = [True if gene in gene_list else False for gene in all_genes]
383 |         if protein_list is None:
384 |             protein_mask = slice(None)
385 |         else:
386 |             all_proteins = self.protein_state_registry.column_names
387 |             protein_mask = [True if p in protein_list else False for p in all_proteins]
388 |         if indices is None:
389 |             indices = np.arange(adata.n_obs)
390 | 
391 |         if n_samples > 1 and return_mean is False:
392 |             if return_numpy is False:
393 |                 warnings.warn(
394 |                     "return_numpy must be True if n_samples > 1 and return_mean is "
395 |                     "False, returning np.ndarray"
396 |                 )
397 |             return_numpy = True
398 | 
399 |         if not isinstance(transform_batch, IterableClass):
400 |             transform_batch = [transform_batch]
401 | 
402 |         transform_batch = _get_batch_code_from_category(adata_manager, transform_batch)
403 | 
404 |         results = {}
405 |         for expression_type in ["salient", "background"]:
406 |             scale_list_gene = []
407 |             scale_list_pro = []
408 | 
409 |             for tensors in post:
410 |                 x = tensors[REGISTRY_KEYS.X_KEY]
411 |                 y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY]
412 |                 batch_original = tensors[REGISTRY_KEYS.BATCH_KEY]
413 |                 px_scale = torch.zeros_like(x)
414 |                 py_scale = torch.zeros_like(y)
415 |                 if n_samples > 1:
416 |                     px_scale = torch.stack(n_samples * [px_scale])
417 |                     py_scale = torch.stack(n_samples * [py_scale])
418 |                 for b in transform_batch:
419 |                     inference_outputs = self.module._generic_inference(
420 |                         x=x, y=y, batch_index=batch_original, n_samples=n_samples
421 |                     )
422 | 
423 |                     if expression_type == "salient":
424 |                         s = inference_outputs["s"]
425 |                     elif expression_type == "background":
426 |                         s = torch.zeros_like(inference_outputs["s"])
427 |                     else:
428 |                         raise NotImplementedError("Invalid expression type provided")
429 | 
430 |                     generative_outputs = self.module._generic_generative(
431 |                         z=inference_outputs["z"],
432 |                         s=s,
433 |                         library_gene=inference_outputs["library_gene"],
434 |                         batch_index=b,
435 |                     )
436 | 
437 |                     if library_size == "latent":
438 |                         px_scale += generative_outputs["px_"]["rate"].cpu()
439 |                     else:
440 |                         px_scale += generative_outputs["px_"]["scale"].cpu()
441 |                     px_scale = px_scale[..., gene_mask]
442 | 
443 |                     py_ = generative_outputs["py_"]
444 |                     # probability of background
445 |                     protein_mixing = 1 / (1 + torch.exp(-py_["mixing"].cpu()))
446 |                     if sample_protein_mixing is True:
447 |                         protein_mixing = torch.distributions.Bernoulli(
448 |                             protein_mixing
449 |                         ).sample()
450 |                     protein_val = py_["rate_fore"].cpu() * (1 - protein_mixing)
451 |                     if include_protein_background is True:
452 |                         protein_val += py_["rate_back"].cpu() * protein_mixing
453 | 
454 |                     if scale_protein is True:
455 |                         protein_val = torch.nn.functional.normalize(
456 |                             protein_val, p=1, dim=-1
457 |                         )
458 |                     protein_val = protein_val[..., protein_mask]
459 |                     py_scale += protein_val
460 |                 px_scale /= len(transform_batch)
461 |                 py_scale /= len(transform_batch)
462 |                 scale_list_gene.append(px_scale)
463 |                 scale_list_pro.append(py_scale)
464 | 
465 |             if n_samples > 1:
466 |                 # concatenate along batch dimension
467 |                 # -> result shape = (samples, cells, features)
468 |                 scale_list_gene = torch.cat(scale_list_gene, dim=1)
469 |                 scale_list_pro = torch.cat(scale_list_pro, dim=1)
470 |                 # (cells, features, samples)
471 |                 scale_list_gene = scale_list_gene.permute(1, 2, 0)
472 |                 scale_list_pro = scale_list_pro.permute(1, 2, 0)
473 |             else:
474 |                 scale_list_gene = torch.cat(scale_list_gene, dim=0)
475 |                 scale_list_pro = torch.cat(scale_list_pro, dim=0)
476 | 
477 |             if return_mean is True and n_samples > 1:
478 |                 scale_list_gene = torch.mean(scale_list_gene, dim=-1)
479 |                 scale_list_pro = torch.mean(scale_list_pro, dim=-1)
480 | 
481 |             scale_list_gene = scale_list_gene.cpu().numpy()
482 |             scale_list_pro = scale_list_pro.cpu().numpy()
483 |             if return_numpy is None or return_numpy is False:
484 |                 gene_df = pd.DataFrame(
485 |                     scale_list_gene,
486 |                     columns=adata.var_names[gene_mask],
487 |                     index=adata.obs_names[indices],
488 |                 )
489 |                 protein_names = self.protein_state_registry.column_names
490 |                 pro_df = pd.DataFrame(
491 |                     scale_list_pro,
492 |                     columns=protein_names[protein_mask],
493 |                     index=adata.obs_names[indices],
494 |                 )
495 | 
496 |                 results[expression_type] = (gene_df, pro_df)
497 |             else:
498 |                 results[expression_type] = (scale_list_gene, scale_list_pro)
499 |         return results
500 | 
501 |     @torch.no_grad()
502 |     def get_specific_normalized_expression(
503 |         self,
504 |         adata=None,
505 |         indices=None,
506 |         n_samples_overall=None,
507 |         transform_batch: Optional[Sequence[Union[Number, str]]] = None,
508 |         scale_protein=False,
509 |         batch_size: Optional[int] = None,
510 |         n_samples=1,
511 |         sample_protein_mixing=False,
512 |         include_protein_background=False,
513 |         return_mean=True,
514 |         return_numpy=True,
515 |         expression_type: Optional[str] = None,
516 |         indices_to_return_salient: Optional[Sequence[int]] = None,
517 |     ):
518 |         """
519 |         Return normalized (decoded) gene and protein expression.
520 | 
521 |         Gene + protein expressions are decoded from either the background or salient
522 |         latent space. Exactly one of `expression_type` or
523 |         `indices_to_return_salient` must be provided.
524 | 
525 |         Args:
526 |         ----
527 |         adata:
528 |             AnnData object with equivalent structure to initial AnnData. If `None`,
529 |             defaults to the AnnData object used to initialize the model.
530 |         indices: Indices of cells in adata to use. If `None`, all cells are used.
531 |         n_samples_overall: The number of random samples in `adata` to use.
532 |         transform_batch:
533 |             Batch to condition on.
534 | If transform_batch is: 535 | - None, then real observed batch is used 536 | - int, then batch transform_batch is used 537 | - List[int], then average over batches in list 538 | scale_protein: Make protein expression sum to 1 539 | batch_size: 540 | Minibatch size for data loading into model. Defaults to 541 | `scvi.settings.batch_size`. 542 | sample_protein_mixing: Sample mixing bernoulli, setting background to zero 543 | include_protein_background: Include background component for protein expression 544 | return_mean: Whether to return the mean of the samples. 545 | return_numpy: 546 | Return a `np.ndarray` instead of a `pd.DataFrame`. Includes gene 547 | names as columns. If either n_samples=1 or return_mean=True, defaults to 548 | False. Otherwise, it defaults to True. 549 | expression_type: One of {"salient", "background"} to specify the type of 550 | normalized expression to return. 551 | indices_to_return_salient: If `indices` is a subset of 552 | `indices_to_return_salient`, normalized expressions derived from background 553 | and salient latent embeddings are returned. If `indices` is not `None` and 554 | is not a subset of `indices_to_return_salient`, normalized expressions 555 | derived only from background latent embeddings are returned. 556 | 557 | Returns 558 | ------- 559 | If `n_samples` > 1 and `return_mean` is `False`, then the shape is 560 | `((samples, cells, genes), (samples, cells, proteins))`. Otherwise, shape 561 | is `((cells, genes), (cells, proteins))`. In this case, return type is 562 | Tuple[`pandas.DataFrame`] unless `return_numpy` is `True`. 563 | """ 564 | is_expression_type_none = expression_type is None 565 | is_indices_to_return_salient_none = indices_to_return_salient is None 566 | if is_expression_type_none and is_indices_to_return_salient_none: 567 | raise ValueError( 568 | "Both expression_type and indices_to_return_salient are None! " 569 | "Exactly one of them needs to be supplied with an input argument." 570 | ) 571 | elif (not is_expression_type_none) and (not is_indices_to_return_salient_none): 572 | raise ValueError( 573 | "Both expression_type and indices_to_return_salient have an input " 574 | "argument! Exactly one of them needs to be supplied with an input " 575 | "argument." 
576 | ) 577 | else: 578 | exprs = self.get_normalized_expression( 579 | adata=adata, 580 | indices=indices, 581 | n_samples_overall=n_samples_overall, 582 | transform_batch=transform_batch, 583 | return_numpy=return_numpy, 584 | return_mean=return_mean, 585 | n_samples=n_samples, 586 | batch_size=batch_size, 587 | scale_protein=scale_protein, 588 | sample_protein_mixing=sample_protein_mixing, 589 | include_protein_background=include_protein_background, 590 | ) 591 | if not is_expression_type_none: 592 | return exprs[expression_type] 593 | else: 594 | if indices is None: 595 | indices = np.arange(adata.n_obs) 596 | if set(indices).issubset(set(indices_to_return_salient)): 597 | return exprs["salient"] 598 | else: 599 | return exprs["background"] 600 | 601 | def _expression_for_de( 602 | self, 603 | adata=None, 604 | indices=None, 605 | n_samples_overall=None, 606 | transform_batch: Optional[Sequence[Union[Number, str]]] = None, 607 | scale_protein=False, 608 | batch_size: Optional[int] = None, 609 | n_samples=1, 610 | sample_protein_mixing=False, 611 | include_protein_background=False, 612 | protein_prior_count=0.5, 613 | expression_type: Optional[str] = None, 614 | indices_to_return_salient: Optional[Sequence[int]] = None, 615 | ): 616 | rna, protein = self.get_specific_normalized_expression( 617 | adata=adata, 618 | indices=indices, 619 | n_samples_overall=n_samples_overall, 620 | transform_batch=transform_batch, 621 | return_numpy=True, 622 | n_samples=n_samples, 623 | batch_size=batch_size, 624 | scale_protein=scale_protein, 625 | sample_protein_mixing=sample_protein_mixing, 626 | include_protein_background=include_protein_background, 627 | expression_type=expression_type, 628 | indices_to_return_salient=indices_to_return_salient, 629 | ) 630 | protein += protein_prior_count 631 | 632 | joint = np.concatenate([rna, protein], axis=1) 633 | return joint 634 | 635 | def differential_expression( 636 | self, 637 | adata: Optional[AnnData] = None, 638 | groupby: Optional[str] = None, 639 | group1: Optional[Iterable[str]] = None, 640 | group2: Optional[str] = None, 641 | idx1: Optional[Union[Sequence[int], Sequence[bool], str]] = None, 642 | idx2: Optional[Union[Sequence[int], Sequence[bool], str]] = None, 643 | mode: Literal["vanilla", "change"] = "change", 644 | delta: float = 0.25, 645 | batch_size: Optional[int] = None, 646 | all_stats: bool = True, 647 | batch_correction: bool = False, 648 | batchid1: Optional[Iterable[str]] = None, 649 | batchid2: Optional[Iterable[str]] = None, 650 | fdr_target: float = 0.05, 651 | silent: bool = False, 652 | protein_prior_count: float = 0.1, 653 | scale_protein: bool = False, 654 | sample_protein_mixing: bool = False, 655 | include_protein_background: bool = False, 656 | target_idx: Optional[Sequence[int]] = None, 657 | **kwargs, 658 | ) -> pd.DataFrame: 659 | r""" 660 | A unified method for differential expression analysis. 661 | Implements `"vanilla"` DE [Lopez18]_ and `"change"` mode DE [Boyeau19]_. 662 | 663 | Args: 664 | ---- 665 | protein_prior_count: 666 | Prior count added to protein expression before LFC computation 667 | scale_protein: 668 | Force protein values to sum to one in every single cell 669 | (post-hoc normalization). 670 | sample_protein_mixing: 671 | Sample the protein mixture component, i.e., use the parameter to sample a 672 | Bernoulli that determines if expression is from foreground/background. 
673 | include_protein_background: 674 | Include the protein background component as part of the protein expression 675 | target_idx: If not `None`, a boolean or integer identifier should be used for 676 | cells in the contrastive target group. Normalized expression values derived 677 | from both salient and background latent embeddings are used when 678 | {group1, group2} is a subset of the target group, otherwise background 679 | normalized expression values are used. 680 | **kwargs: 681 | Keyword args for 682 | :meth:`scvi.model.base.DifferentialComputation.get_bayes_factors` 683 | 684 | Returns 685 | ------- 686 | Differential expression DataFrame. 687 | """ 688 | adata = self._validate_anndata(adata) 689 | col_names = np.concatenate( 690 | [ 691 | np.asarray(adata.var_names), 692 | self.protein_state_registry.column_names, 693 | ] 694 | ) 695 | 696 | if target_idx is not None: 697 | target_idx = np.array(target_idx) 698 | if target_idx.dtype is np.dtype("bool"): 699 | assert ( 700 | len(target_idx) == adata.n_obs 701 | ), "target_idx mask must be the same length as adata!" 702 | target_idx = np.arange(adata.n_obs)[target_idx] 703 | model_fn = partial( 704 | self._expression_for_de, 705 | scale_protein=scale_protein, 706 | sample_protein_mixing=sample_protein_mixing, 707 | include_protein_background=include_protein_background, 708 | protein_prior_count=protein_prior_count, 709 | batch_size=batch_size, 710 | expression_type=None, 711 | indices_to_return_salient=target_idx, 712 | n_samples=100, 713 | ) 714 | else: 715 | model_fn = partial( 716 | self._expression_for_de, 717 | scale_protein=scale_protein, 718 | sample_protein_mixing=sample_protein_mixing, 719 | include_protein_background=include_protein_background, 720 | protein_prior_count=protein_prior_count, 721 | batch_size=batch_size, 722 | expression_type="salient", 723 | n_samples=100, 724 | ) 725 | 726 | result = _de_core( 727 | self.get_anndata_manager(adata, required=True), 728 | model_fn, 729 | groupby, 730 | group1, 731 | group2, 732 | idx1, 733 | idx2, 734 | all_stats, 735 | cite_seq_raw_counts_properties, 736 | col_names, 737 | mode, 738 | batchid1, 739 | batchid2, 740 | delta, 741 | batch_correction, 742 | fdr_target, 743 | silent, 744 | **kwargs, 745 | ) 746 | 747 | return result 748 | 749 | @torch.no_grad() 750 | def get_latent_library_size( 751 | self, 752 | adata: Optional[AnnData] = None, 753 | indices: Optional[Sequence[int]] = None, 754 | give_mean: bool = True, 755 | batch_size: Optional[int] = None, 756 | ) -> np.ndarray: 757 | r""" 758 | Returns the latent RNA library size for each cell. 759 | This is denoted as :math:`\ell_n` in the totalVI paper. 760 | Parameters 761 | ---------- 762 | adata 763 | AnnData object with equivalent structure to initial AnnData. If `None`, 764 | defaults to the AnnData object used to initialize the model. 765 | indices 766 | Indices of cells in adata to use. If `None`, all cells are used. 767 | give_mean 768 | Return the mean or a sample from the posterior distribution. 769 | batch_size 770 | Minibatch size for data loading into model. Defaults to 771 | `scvi.settings.batch_size`. 
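
        Example
        -------
        A minimal sketch (assumes `model` is a trained `TotalContrastiveVIModel`)::

            lib = model.get_latent_library_size(give_mean=False)
            # With give_mean=True (the default) the posterior mean is returned,
            # provided the module infers a library-size posterior.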
772 | """ 773 | self._check_if_trained(warn=False) 774 | 775 | adata = self._validate_anndata(adata) 776 | scdl = self._make_data_loader( 777 | adata=adata, indices=indices, batch_size=batch_size 778 | ) 779 | libraries = [] 780 | for tensors in scdl: 781 | x = tensors[REGISTRY_KEYS.X_KEY] 782 | y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY] 783 | batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] 784 | outputs = self.module._generic_inference(x=x, y=y, batch_index=batch_index) 785 | 786 | library = outputs["library_gene"] 787 | if not give_mean: 788 | library = torch.exp(library) 789 | else: 790 | ql = (outputs["ql_m"], outputs["ql_v"]) 791 | if ql is None: 792 | raise RuntimeError( 793 | "The module for this model does not compute the posterior" 794 | "distribution for the library size. Set `give_mean` to False to" 795 | "use the observed library size instead." 796 | ) 797 | library = torch.distributions.LogNormal(ql[0], ql[1]).mean 798 | libraries += [library.cpu()] 799 | return torch.cat(libraries).numpy() 800 | -------------------------------------------------------------------------------- /contrastive_vi/module/__init__.py: -------------------------------------------------------------------------------- 1 | """PyTorch modules for models.""" 2 | -------------------------------------------------------------------------------- /contrastive_vi/module/contrastive_vi.py: -------------------------------------------------------------------------------- 1 | """PyTorch module for Contrastive VI for single cell expression data.""" 2 | 3 | from typing import Dict, Optional, Tuple 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from scvi import REGISTRY_KEYS 10 | from scvi.distributions import ZeroInflatedNegativeBinomial 11 | from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data 12 | from scvi.nn import DecoderSCVI, Encoder, one_hot 13 | from torch.distributions import Normal 14 | from torch.distributions import kl_divergence as kl 15 | 16 | from contrastive_vi.module.utils import gram_matrix 17 | 18 | torch.backends.cudnn.benchmark = True 19 | 20 | 21 | class ContrastiveVIModule(BaseModuleClass): 22 | """ 23 | PyTorch module for Contrastive VI (Variational Inference). 24 | 25 | Args: 26 | ---- 27 | n_input: Number of input genes. 28 | n_batch: Number of batches. If 0, no batch effect correction is performed. 29 | n_hidden: Number of nodes per hidden layer. 30 | n_background_latent: Dimensionality of the background latent space. 31 | n_salient_latent: Dimensionality of the salient latent space. 32 | n_layers: Number of hidden layers used for encoder and decoder NNs. 33 | dropout_rate: Dropout rate for neural networks. 34 | use_observed_lib_size: Use observed library size for RNA as scaling factor in 35 | mean of conditional distribution. 36 | library_log_means: 1 x n_batch array of means of the log library sizes. 37 | Parameterize prior on library size if not using observed library size. 38 | library_log_vars: 1 x n_batch array of variances of the log library sizes. 39 | Parameterize prior on library size if not using observed library size. 40 | wasserstein_penalty: Weight of the Wasserstein distance loss that further 41 | discourages shared variations from leaking into the salient latent space. 
42 | """ 43 | 44 | def __init__( 45 | self, 46 | n_input: int, 47 | n_batch: int = 0, 48 | n_hidden: int = 128, 49 | n_background_latent: int = 10, 50 | n_salient_latent: int = 10, 51 | n_layers: int = 1, 52 | dropout_rate: float = 0.1, 53 | use_observed_lib_size: bool = True, 54 | library_log_means: Optional[np.ndarray] = None, 55 | library_log_vars: Optional[np.ndarray] = None, 56 | wasserstein_penalty: float = 0 57 | ) -> None: 58 | super().__init__() 59 | self.n_input = n_input 60 | self.n_batch = n_batch 61 | self.n_hidden = n_hidden 62 | self.n_background_latent = n_background_latent 63 | self.n_salient_latent = n_salient_latent 64 | self.n_layers = n_layers 65 | self.dropout_rate = dropout_rate 66 | self.latent_distribution = "normal" 67 | self.dispersion = "gene" 68 | self.px_r = torch.nn.Parameter(torch.randn(n_input)) 69 | self.use_observed_lib_size = use_observed_lib_size 70 | self.wasserstein_penalty = wasserstein_penalty 71 | 72 | if not self.use_observed_lib_size: 73 | if library_log_means is None or library_log_vars is None: 74 | raise ValueError( 75 | "If not using observed_lib_size, " 76 | "must provide library_log_means and library_log_vars." 77 | ) 78 | self.register_buffer( 79 | "library_log_means", torch.from_numpy(library_log_means).float() 80 | ) 81 | self.register_buffer( 82 | "library_log_vars", torch.from_numpy(library_log_vars).float() 83 | ) 84 | 85 | cat_list = [n_batch] 86 | # Background encoder. 87 | self.z_encoder = Encoder( 88 | n_input, 89 | n_background_latent, 90 | n_cat_list=cat_list, 91 | n_layers=n_layers, 92 | n_hidden=n_hidden, 93 | dropout_rate=dropout_rate, 94 | distribution=self.latent_distribution, 95 | inject_covariates=True, 96 | use_batch_norm=True, 97 | use_layer_norm=False, 98 | var_activation=None, 99 | ) 100 | # Salient encoder. 101 | self.s_encoder = Encoder( 102 | n_input, 103 | n_salient_latent, 104 | n_cat_list=cat_list, 105 | n_layers=n_layers, 106 | n_hidden=n_hidden, 107 | dropout_rate=dropout_rate, 108 | distribution=self.latent_distribution, 109 | inject_covariates=True, 110 | use_batch_norm=True, 111 | use_layer_norm=False, 112 | var_activation=None, 113 | ) 114 | # Library size encoder. 115 | self.l_encoder = Encoder( 116 | n_input, 117 | 1, 118 | n_layers=1, 119 | n_cat_list=cat_list, 120 | n_hidden=n_hidden, 121 | dropout_rate=dropout_rate, 122 | inject_covariates=True, 123 | use_batch_norm=True, 124 | use_layer_norm=False, 125 | var_activation=None, 126 | ) 127 | # Decoder from latent variable to distribution parameters in data space. 
128 | n_total_latent = n_background_latent + n_salient_latent 129 | self.decoder = DecoderSCVI( 130 | n_total_latent, 131 | n_input, 132 | n_cat_list=cat_list, 133 | n_layers=n_layers, 134 | n_hidden=n_hidden, 135 | inject_covariates=True, 136 | use_batch_norm=True, 137 | use_layer_norm=False, 138 | ) 139 | 140 | @auto_move_data 141 | def _compute_local_library_params( 142 | self, batch_index: torch.Tensor 143 | ) -> Tuple[torch.Tensor, torch.Tensor]: 144 | n_batch = self.library_log_means.shape[1] 145 | local_library_log_means = F.linear( 146 | one_hot(batch_index, n_batch), self.library_log_means 147 | ) 148 | local_library_log_vars = F.linear( 149 | one_hot(batch_index, n_batch), self.library_log_vars 150 | ) 151 | return local_library_log_means, local_library_log_vars 152 | 153 | @staticmethod 154 | def _get_min_batch_size(concat_tensors: Dict[str, Dict[str, torch.Tensor]]) -> int: 155 | return min( 156 | concat_tensors["background"][REGISTRY_KEYS.X_KEY].shape[0], 157 | concat_tensors["target"][REGISTRY_KEYS.X_KEY].shape[0], 158 | ) 159 | 160 | @staticmethod 161 | def _reduce_tensors_to_min_batch_size( 162 | tensors: Dict[str, torch.Tensor], min_batch_size: int 163 | ) -> None: 164 | for name, tensor in tensors.items(): 165 | tensors[name] = tensor[:min_batch_size, :] 166 | 167 | @staticmethod 168 | def _get_inference_input_from_concat_tensors( 169 | concat_tensors: Dict[str, Dict[str, torch.Tensor]], index: str 170 | ) -> Dict[str, torch.Tensor]: 171 | tensors = concat_tensors[index] 172 | x = tensors[REGISTRY_KEYS.X_KEY] 173 | batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] 174 | input_dict = dict(x=x, batch_index=batch_index) 175 | return input_dict 176 | 177 | def _get_inference_input( 178 | self, concat_tensors: Dict[str, Dict[str, torch.Tensor]] 179 | ) -> Dict[str, Dict[str, torch.Tensor]]: 180 | background = self._get_inference_input_from_concat_tensors( 181 | concat_tensors, "background" 182 | ) 183 | target = self._get_inference_input_from_concat_tensors(concat_tensors, "target") 184 | # Ensure batch sizes are the same. 
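        # Editor's note: background and target tensors are processed in
        # lock-step (they are concatenated for a single encoder pass in
        # `inference` and split apart again), so both mini-batches are
        # truncated here to the size of the smaller one.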
185 | min_batch_size = self._get_min_batch_size(concat_tensors) 186 | self._reduce_tensors_to_min_batch_size(background, min_batch_size) 187 | self._reduce_tensors_to_min_batch_size(target, min_batch_size) 188 | return dict(background=background, target=target) 189 | 190 | @staticmethod 191 | def _get_generative_input_from_concat_tensors( 192 | concat_tensors: Dict[str, Dict[str, torch.Tensor]], index: str 193 | ) -> Dict[str, torch.Tensor]: 194 | tensors = concat_tensors[index] 195 | batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] 196 | input_dict = dict(batch_index=batch_index) 197 | return input_dict 198 | 199 | @staticmethod 200 | def _get_generative_input_from_inference_outputs( 201 | inference_outputs: Dict[str, Dict[str, torch.Tensor]], data_source: str 202 | ) -> Dict[str, torch.Tensor]: 203 | z = inference_outputs[data_source]["z"] 204 | s = inference_outputs[data_source]["s"] 205 | library = inference_outputs[data_source]["library"] 206 | return dict(z=z, s=s, library=library) 207 | 208 | def _get_generative_input( 209 | self, 210 | concat_tensors: Dict[str, Dict[str, torch.Tensor]], 211 | inference_outputs: Dict[str, Dict[str, torch.Tensor]], 212 | ) -> Dict[str, Dict[str, torch.Tensor]]: 213 | background_tensor_input = self._get_generative_input_from_concat_tensors( 214 | concat_tensors, "background" 215 | ) 216 | target_tensor_input = self._get_generative_input_from_concat_tensors( 217 | concat_tensors, "target" 218 | ) 219 | # Ensure batch sizes are the same. 220 | min_batch_size = self._get_min_batch_size(concat_tensors) 221 | self._reduce_tensors_to_min_batch_size(background_tensor_input, min_batch_size) 222 | self._reduce_tensors_to_min_batch_size(target_tensor_input, min_batch_size) 223 | 224 | background_inference_outputs = ( 225 | self._get_generative_input_from_inference_outputs( 226 | inference_outputs, "background" 227 | ) 228 | ) 229 | target_inference_outputs = self._get_generative_input_from_inference_outputs( 230 | inference_outputs, "target" 231 | ) 232 | background = {**background_tensor_input, **background_inference_outputs} 233 | target = {**target_tensor_input, **target_inference_outputs} 234 | return dict(background=background, target=target) 235 | 236 | @staticmethod 237 | def _reshape_tensor_for_samples(tensor: torch.Tensor, n_samples: int): 238 | return tensor.unsqueeze(0).expand((n_samples, tensor.size(0), tensor.size(1))) 239 | 240 | @auto_move_data 241 | def _generic_inference( 242 | self, 243 | x: torch.Tensor, 244 | batch_index: torch.Tensor, 245 | n_samples: int = 1, 246 | ) -> Dict[str, torch.Tensor]: 247 | x_ = x 248 | if self.use_observed_lib_size: 249 | library = torch.log(x.sum(1)).unsqueeze(1) 250 | x_ = torch.log(1 + x_) 251 | 252 | qz_m, qz_v, z = self.z_encoder(x_, batch_index) 253 | qs_m, qs_v, s = self.s_encoder(x_, batch_index) 254 | 255 | ql_m, ql_v = None, None 256 | if not self.use_observed_lib_size: 257 | ql_m, ql_v, library_encoded = self.l_encoder(x_, batch_index) 258 | library = library_encoded 259 | 260 | if n_samples > 1: 261 | qz_m = self._reshape_tensor_for_samples(qz_m, n_samples) 262 | qz_v = self._reshape_tensor_for_samples(qz_v, n_samples) 263 | z = self._reshape_tensor_for_samples(z, n_samples) 264 | qs_m = self._reshape_tensor_for_samples(qs_m, n_samples) 265 | qs_v = self._reshape_tensor_for_samples(qs_v, n_samples) 266 | s = self._reshape_tensor_for_samples(s, n_samples) 267 | 268 | if self.use_observed_lib_size: 269 | library = self._reshape_tensor_for_samples(library, n_samples) 270 | else: 271 | ql_m = 
self._reshape_tensor_for_samples(ql_m, n_samples) 272 | ql_v = self._reshape_tensor_for_samples(ql_v, n_samples) 273 | library = Normal(ql_m, ql_v.sqrt()).sample() 274 | 275 | outputs = dict( 276 | z=z, 277 | qz_m=qz_m, 278 | qz_v=qz_v, 279 | s=s, 280 | qs_m=qs_m, 281 | qs_v=qs_v, 282 | library=library, 283 | ql_m=ql_m, 284 | ql_v=ql_v, 285 | ) 286 | return outputs 287 | 288 | @auto_move_data 289 | def inference( 290 | self, 291 | background: Dict[str, torch.Tensor], 292 | target: Dict[str, torch.Tensor], 293 | n_samples: int = 1, 294 | ) -> Dict[str, Dict[str, torch.Tensor]]: 295 | background_batch_size = background["x"].shape[0] 296 | target_batch_size = target["x"].shape[0] 297 | inference_input = {} 298 | for key in background.keys(): 299 | inference_input[key] = torch.cat([background[key], target[key]], dim=0) 300 | outputs = self._generic_inference(**inference_input, n_samples=n_samples) 301 | batch_size_dim = 0 if n_samples == 1 else 1 302 | background_outputs, target_outputs = {}, {} 303 | for key in outputs.keys(): 304 | if outputs[key] is not None: 305 | background_tensor, target_tensor = torch.split( 306 | outputs[key], 307 | [background_batch_size, target_batch_size], 308 | dim=batch_size_dim, 309 | ) 310 | else: 311 | background_tensor, target_tensor = None, None 312 | background_outputs[key] = background_tensor 313 | target_outputs[key] = target_tensor 314 | background_outputs["s"] = torch.zeros_like(background_outputs["s"]) 315 | return dict(background=background_outputs, target=target_outputs) 316 | 317 | @auto_move_data 318 | def _generic_generative( 319 | self, 320 | z: torch.Tensor, 321 | s: torch.Tensor, 322 | library: torch.Tensor, 323 | batch_index: torch.Tensor, 324 | ) -> Dict[str, torch.Tensor]: 325 | latent = torch.cat([z, s], dim=-1) 326 | px_scale, px_r, px_rate, px_dropout = self.decoder( 327 | self.dispersion, 328 | latent, 329 | library, 330 | batch_index, 331 | ) 332 | px_r = torch.exp(self.px_r) 333 | return dict( 334 | px_scale=px_scale, px_r=px_r, px_rate=px_rate, px_dropout=px_dropout 335 | ) 336 | 337 | @auto_move_data 338 | def generative( 339 | self, 340 | background: Dict[str, torch.Tensor], 341 | target: Dict[str, torch.Tensor], 342 | ) -> Dict[str, Dict[str, torch.Tensor]]: 343 | latent_z_shape = background["z"].shape 344 | batch_size_dim = 0 if len(latent_z_shape) == 2 else 1 345 | background_batch_size = background["z"].shape[batch_size_dim] 346 | target_batch_size = target["z"].shape[batch_size_dim] 347 | generative_input = {} 348 | for key in ["z", "s", "library"]: 349 | generative_input[key] = torch.cat( 350 | [background[key], target[key]], dim=batch_size_dim 351 | ) 352 | generative_input["batch_index"] = torch.cat( 353 | [background["batch_index"], target["batch_index"]], dim=0 354 | ) 355 | outputs = self._generic_generative(**generative_input) 356 | background_outputs, target_outputs = {}, {} 357 | for key in ["px_scale", "px_rate", "px_dropout"]: 358 | if outputs[key] is not None: 359 | background_tensor, target_tensor = torch.split( 360 | outputs[key], 361 | [background_batch_size, target_batch_size], 362 | dim=batch_size_dim, 363 | ) 364 | else: 365 | background_tensor, target_tensor = None, None 366 | background_outputs[key] = background_tensor 367 | target_outputs[key] = target_tensor 368 | background_outputs["px_r"] = outputs["px_r"] 369 | target_outputs["px_r"] = outputs["px_r"] 370 | return dict(background=background_outputs, target=target_outputs) 371 | 372 | @staticmethod 373 | def reconstruction_loss( 374 | x: torch.Tensor, 375 
| px_rate: torch.Tensor, 376 | px_r: torch.Tensor, 377 | px_dropout: torch.Tensor, 378 | ) -> torch.Tensor: 379 | """ 380 | Compute likelihood loss for zero-inflated negative binomial distribution. 381 | 382 | Args: 383 | ---- 384 | x: Input data. 385 | px_rate: Mean of distribution. 386 | px_r: Inverse dispersion. 387 | px_dropout: Logits scale of zero inflation probability. 388 | 389 | Returns 390 | ------- 391 | Negative log likelihood (reconstruction loss) for each data point. If number 392 | of latent samples == 1, the tensor has shape `(batch_size, )`. If number 393 | of latent samples > 1, the tensor has shape `(n_samples, batch_size)`. 394 | """ 395 | recon_loss = ( 396 | -ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout) 397 | .log_prob(x) 398 | .sum(dim=-1) 399 | ) 400 | return recon_loss 401 | 402 | @staticmethod 403 | def latent_kl_divergence( 404 | variational_mean: torch.Tensor, 405 | variational_var: torch.Tensor, 406 | prior_mean: torch.Tensor, 407 | prior_var: torch.Tensor, 408 | ) -> torch.Tensor: 409 | """ 410 | Compute KL divergence between a variational posterior and prior Gaussian. 411 | Args: 412 | ---- 413 | variational_mean: Mean of the variational posterior Gaussian. 414 | variational_var: Variance of the variational posterior Gaussian. 415 | prior_mean: Mean of the prior Gaussian. 416 | prior_var: Variance of the prior Gaussian. 417 | 418 | Returns 419 | ------- 420 | KL divergence for each data point. If number of latent samples == 1, 421 | the tensor has shape `(batch_size, )`. If number of latent 422 | samples > 1, the tensor has shape `(n_samples, batch_size)`. 423 | """ 424 | return kl( 425 | Normal(variational_mean, variational_var.sqrt()), 426 | Normal(prior_mean, prior_var.sqrt()), 427 | ).sum(dim=-1) 428 | 429 | def library_kl_divergence( 430 | self, 431 | batch_index: torch.Tensor, 432 | variational_library_mean: torch.Tensor, 433 | variational_library_var: torch.Tensor, 434 | library: torch.Tensor, 435 | ) -> torch.Tensor: 436 | """ 437 | Compute KL divergence between library size variational posterior and prior. 438 | 439 | Both the variational posterior and prior are Log-Normal. 440 | Args: 441 | ---- 442 | batch_index: Batch indices for batch-specific library size mean and 443 | variance. 444 | variational_library_mean: Mean of variational Log-Normal. 445 | variational_library_var: Variance of variational Log-Normal. 446 | library: Sampled library size. 447 | 448 | Returns 449 | ------- 450 | KL divergence for each data point. If number of latent samples == 1, 451 | the tensor has shape `(batch_size, )`. If number of latent 452 | samples > 1, the tensor has shape `(n_samples, batch_size)`. 
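        Editor's note: for the diagonal Gaussians used here, this reduces to
        the closed form
        KL = 0.5 * sum_i [log(prior_var_i / variational_var_i)
             + (variational_var_i + (variational_mean_i - prior_mean_i)^2) / prior_var_i - 1],
        which is what `torch.distributions.kl_divergence` evaluates below.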
453 | """ 454 | if not self.use_observed_lib_size: 455 | ( 456 | local_library_log_means, 457 | local_library_log_vars, 458 | ) = self._compute_local_library_params(batch_index) 459 | 460 | kl_library = kl( 461 | Normal(variational_library_mean, variational_library_var.sqrt()), 462 | Normal(local_library_log_means, local_library_log_vars.sqrt()), 463 | ) 464 | else: 465 | kl_library = torch.zeros_like(library) 466 | return kl_library.sum(dim=-1) 467 | 468 | def _generic_loss( 469 | self, 470 | tensors: Dict[str, torch.Tensor], 471 | inference_outputs: Dict[str, torch.Tensor], 472 | generative_outputs: Dict[str, torch.Tensor], 473 | ) -> Dict[str, torch.Tensor]: 474 | x = tensors[REGISTRY_KEYS.X_KEY] 475 | batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] 476 | 477 | qz_m = inference_outputs["qz_m"] 478 | qz_v = inference_outputs["qz_v"] 479 | qs_m = inference_outputs["qs_m"] 480 | qs_v = inference_outputs["qs_v"] 481 | library = inference_outputs["library"] 482 | ql_m = inference_outputs["ql_m"] 483 | ql_v = inference_outputs["ql_v"] 484 | px_rate = generative_outputs["px_rate"] 485 | px_r = generative_outputs["px_r"] 486 | px_dropout = generative_outputs["px_dropout"] 487 | 488 | prior_z_m = torch.zeros_like(qz_m) 489 | prior_z_v = torch.ones_like(qz_v) 490 | prior_s_m = torch.zeros_like(qs_m) 491 | prior_s_v = torch.ones_like(qs_v) 492 | 493 | recon_loss = self.reconstruction_loss(x, px_rate, px_r, px_dropout) 494 | kl_z = self.latent_kl_divergence(qz_m, qz_v, prior_z_m, prior_z_v) 495 | kl_s = self.latent_kl_divergence(qs_m, qs_v, prior_s_m, prior_s_v) 496 | kl_library = self.library_kl_divergence(batch_index, ql_m, ql_v, library) 497 | return dict( 498 | recon_loss=recon_loss, 499 | kl_z=kl_z, 500 | kl_s=kl_s, 501 | kl_library=kl_library, 502 | ) 503 | 504 | def loss( 505 | self, 506 | concat_tensors: Dict[str, Dict[str, torch.Tensor]], 507 | inference_outputs: Dict[str, Dict[str, torch.Tensor]], 508 | generative_outputs: Dict[str, Dict[str, torch.Tensor]], 509 | kl_weight: float = 1.0, 510 | ) -> LossRecorder: 511 | """ 512 | Compute loss terms for contrastive-VI. 513 | Args: 514 | ---- 515 | concat_tensors: Tuple of data mini-batch. The first element contains 516 | background data mini-batch. The second element contains target data 517 | mini-batch. 518 | inference_outputs: Dictionary of inference step outputs. The keys 519 | are "background" and "target" for the corresponding outputs. 520 | generative_outputs: Dictionary of generative step outputs. The keys 521 | are "background" and "target" for the corresponding outputs. 522 | kl_weight: Importance weight for KL divergence of background and salient 523 | latent variables, relative to KL divergence of library size. 524 | 525 | Returns 526 | ------- 527 | An scvi.module.base.LossRecorder instance that records the following: 528 | loss: One-dimensional tensor for overall loss used for optimization. 529 | reconstruction_loss: Reconstruction loss with shape 530 | `(n_samples, batch_size)` if number of latent samples > 1, or 531 | `(batch_size, )` if number of latent samples == 1. 532 | kl_local: KL divergence term with shape 533 | `(n_samples, batch_size)` if number of latent samples > 1, or 534 | `(batch_size, )` if number of latent samples == 1. 535 | kl_global: One-dimensional tensor for global KL divergence term. 536 | """ 537 | background_tensors = concat_tensors["background"] 538 | target_tensors = concat_tensors["target"] 539 | # Ensure batch sizes are the same. 
540 | min_batch_size = self._get_min_batch_size(concat_tensors) 541 | self._reduce_tensors_to_min_batch_size(background_tensors, min_batch_size) 542 | self._reduce_tensors_to_min_batch_size(target_tensors, min_batch_size) 543 | 544 | background_losses = self._generic_loss( 545 | background_tensors, 546 | inference_outputs["background"], 547 | generative_outputs["background"], 548 | ) 549 | target_losses = self._generic_loss( 550 | target_tensors, 551 | inference_outputs["target"], 552 | generative_outputs["target"], 553 | ) 554 | reconst_loss = background_losses["recon_loss"] + target_losses["recon_loss"] 555 | kl_divergence_z = background_losses["kl_z"] + target_losses["kl_z"] 556 | kl_divergence_s = target_losses["kl_s"] 557 | kl_divergence_l = background_losses["kl_library"] + target_losses["kl_library"] 558 | 559 | wasserstein_loss = ( 560 | torch.norm(inference_outputs["background"]["qs_m"], dim=-1)**2 561 | + torch.sum(inference_outputs["background"]["qs_v"], dim=-1) 562 | ) 563 | 564 | kl_local_for_warmup = kl_divergence_z + kl_divergence_s 565 | kl_local_no_warmup = kl_divergence_l 566 | 567 | weighted_kl_local = kl_weight * (self.wasserstein_penalty*wasserstein_loss 568 | + kl_local_for_warmup) + kl_local_no_warmup 569 | 570 | loss = torch.mean(reconst_loss + weighted_kl_local) 571 | 572 | kl_local = dict( 573 | kl_divergence_l=kl_divergence_l, 574 | kl_divergence_z=kl_divergence_z, 575 | kl_divergence_s=kl_divergence_s 576 | ) 577 | kl_global = torch.tensor(0.0) 578 | 579 | # LossRecorder internally sums the `reconst_loss`, `kl_local`, and `kl_global` 580 | # terms before logging, so we do the same for our `wasserstein_loss` term. 581 | return LossRecorder( 582 | loss, 583 | reconst_loss, 584 | kl_local, 585 | kl_global, 586 | wasserstein_loss=torch.sum(wasserstein_loss) 587 | ) 588 | 589 | @torch.no_grad() 590 | def sample(self): 591 | raise NotImplementedError 592 | 593 | @torch.no_grad() 594 | @auto_move_data 595 | def marginal_ll(self): 596 | raise NotImplementedError 597 | -------------------------------------------------------------------------------- /contrastive_vi/module/total_contrastive_vi.py: -------------------------------------------------------------------------------- 1 | """PyTorch module for Contrastive VI for single cell expression data.""" 2 | 3 | from typing import Dict, Optional, Tuple, Union, Literal, Iterable 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from scvi import REGISTRY_KEYS 9 | from scvi.distributions import NegativeBinomial, NegativeBinomialMixture 10 | from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data 11 | from scvi.nn import DecoderTOTALVI, EncoderTOTALVI, one_hot 12 | from torch.distributions import Normal 13 | from torch.distributions import kl_divergence as kl 14 | 15 | torch.backends.cudnn.benchmark = True 16 | 17 | 18 | class TotalContrastiveVIModule(BaseModuleClass): 19 | """ 20 | PyTorch module for total-contrastiveVI (contrastive analysis for CITE-seq). 21 | 22 | Args: 23 | ---- 24 | n_input_genes: Number of input genes. 25 | n_input_proteins: Number of input proteins. 26 | n_batch: Number of batches. If 0, no batch effect correction is performed. 27 | n_hidden: Number of nodes per hidden layer. 28 | n_background_latent: Dimensionality of the background latent space. 29 | n_salient_latent: Dimensionality of the salient latent space. 30 | n_layers: Number of hidden layers used for encoder and decoder NNs. 31 | dropout_rate: Dropout rate for neural networks. 
32 |     protein_batch_mask: Dictionary mapping each batch code to a boolean array
33 |         indicating, for each protein, whether it was observed in that batch.
34 |     use_observed_lib_size: Use observed library size for RNA as scaling factor in
35 |         mean of conditional distribution.
36 |     library_log_means: 1 x n_batch array of means of the log library sizes.
37 |         Parameterize prior on library size if not using observed library size.
38 |     library_log_vars: 1 x n_batch array of variances of the log library sizes.
39 |         Parameterize prior on library size if not using observed library size.
40 |     wasserstein_penalty: Weight of the Wasserstein distance loss that further
41 |         discourages shared variations from leaking into the salient latent space.
42 |     """
43 | 
44 |     def __init__(
45 |         self,
46 |         n_input_genes: int,
47 |         n_input_proteins: int,
48 |         n_batch: int = 0,
49 |         n_labels: int = 0,
50 |         n_hidden: int = 128,
51 |         n_background_latent: int = 10,
52 |         n_salient_latent: int = 10,
53 |         n_layers_encoder: int = 2,
54 |         n_layers_decoder: int = 1,
55 |         n_continuous_cov: int = 0,
56 |         n_cats_per_cov: Optional[Iterable[int]] = None,
57 |         dropout_rate_decoder: float = 0.2,
58 |         dropout_rate_encoder: float = 0.2,
59 |         gene_dispersion: str = "gene",
60 |         protein_dispersion: str = "protein",
61 |         log_variational: bool = True,
62 |         gene_likelihood: str = "nb",
63 |         latent_distribution: str = "normal",
64 |         protein_batch_mask: Optional[Dict[Union[str, int], np.ndarray]] = None,
65 |         encode_covariates: bool = True,
66 |         protein_background_prior_mean: Optional[np.ndarray] = None,
67 |         protein_background_prior_scale: Optional[np.ndarray] = None,
68 |         use_size_factor_key: bool = False,
69 |         use_observed_lib_size: bool = True,
70 |         library_log_means: Optional[np.ndarray] = None,
71 |         library_log_vars: Optional[np.ndarray] = None,
72 |         use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both",
73 |         use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none",
74 |         wasserstein_penalty: float = 0,
75 |     ) -> None:
76 |         super().__init__()
77 |         self.gene_dispersion = gene_dispersion
78 |         self.n_background_latent = n_background_latent
79 |         self.n_salient_latent = n_salient_latent
80 |         self.log_variational = log_variational
81 |         self.gene_likelihood = gene_likelihood
82 |         self.n_batch = n_batch
83 |         self.n_labels = n_labels
84 |         self.n_input_genes = n_input_genes
85 |         self.n_input_proteins = n_input_proteins
86 |         self.protein_dispersion = protein_dispersion
87 |         self.latent_distribution = latent_distribution
88 |         self.protein_batch_mask = protein_batch_mask
89 |         self.encode_covariates = encode_covariates
90 |         self.use_size_factor_key = use_size_factor_key
91 |         self.use_observed_lib_size = use_size_factor_key or use_observed_lib_size
92 |         self.wasserstein_penalty = wasserstein_penalty
93 | 
94 |         if not self.use_observed_lib_size:
95 |             if library_log_means is None or library_log_vars is None:
96 |                 raise ValueError(
97 |                     "If not using observed_lib_size, "
98 |                     "must provide library_log_means and library_log_vars."
99 | ) 100 | 101 | self.register_buffer( 102 | "library_log_means", torch.from_numpy(library_log_means).float() 103 | ) 104 | self.register_buffer( 105 | "library_log_vars", torch.from_numpy(library_log_vars).float() 106 | ) 107 | 108 | # parameters for prior on rate_back (background protein mean) 109 | if protein_background_prior_mean is None: 110 | if n_batch > 0: 111 | self.background_pro_alpha = torch.nn.Parameter( 112 | torch.randn(n_input_proteins, n_batch) 113 | ) 114 | self.background_pro_log_beta = torch.nn.Parameter( 115 | torch.clamp(torch.randn(n_input_proteins, n_batch), -10, 1) 116 | ) 117 | else: 118 | self.background_pro_alpha = torch.nn.Parameter( 119 | torch.randn(n_input_proteins) 120 | ) 121 | self.background_pro_log_beta = torch.nn.Parameter( 122 | torch.clamp(torch.randn(n_input_proteins), -10, 1) 123 | ) 124 | else: 125 | if protein_background_prior_mean.shape[1] == 1 and n_batch != 1: 126 | init_mean = protein_background_prior_mean.ravel() 127 | init_scale = protein_background_prior_scale.ravel() 128 | else: 129 | init_mean = protein_background_prior_mean 130 | init_scale = protein_background_prior_scale 131 | self.background_pro_alpha = torch.nn.Parameter( 132 | torch.from_numpy(init_mean.astype(np.float32)) 133 | ) 134 | self.background_pro_log_beta = torch.nn.Parameter( 135 | torch.log(torch.from_numpy(init_scale.astype(np.float32))) 136 | ) 137 | 138 | if self.gene_dispersion == "gene": 139 | self.px_r = torch.nn.Parameter(torch.randn(n_input_genes)) 140 | elif self.gene_dispersion == "gene-batch": 141 | self.px_r = torch.nn.Parameter(torch.randn(n_input_genes, n_batch)) 142 | elif self.gene_dispersion == "gene-label": 143 | self.px_r = torch.nn.Parameter(torch.randn(n_input_genes, n_labels)) 144 | else: # gene-cell 145 | pass 146 | 147 | if self.protein_dispersion == "protein": 148 | self.py_r = torch.nn.Parameter(2 * torch.rand(self.n_input_proteins)) 149 | elif self.protein_dispersion == "protein-batch": 150 | self.py_r = torch.nn.Parameter( 151 | 2 * torch.rand(self.n_input_proteins, n_batch) 152 | ) 153 | elif self.protein_dispersion == "protein-label": 154 | self.py_r = torch.nn.Parameter( 155 | 2 * torch.rand(self.n_input_proteins, n_labels) 156 | ) 157 | else: # protein-cell 158 | pass 159 | 160 | use_batch_norm_encoder = use_batch_norm == "encoder" or use_batch_norm == "both" 161 | use_batch_norm_decoder = use_batch_norm == "decoder" or use_batch_norm == "both" 162 | use_layer_norm_encoder = use_layer_norm == "encoder" or use_layer_norm == "both" 163 | use_layer_norm_decoder = use_layer_norm == "decoder" or use_layer_norm == "both" 164 | 165 | # z encoder goes from the n_input-dimensional data to an n_latent-d 166 | # latent space representation 167 | n_input = n_input_genes + self.n_input_proteins 168 | n_input_encoder = n_input + n_continuous_cov * encode_covariates 169 | cat_list = [n_batch] + list([] if n_cats_per_cov is None else n_cats_per_cov) 170 | encoder_cat_list = cat_list if encode_covariates else None 171 | self.z_encoder = EncoderTOTALVI( 172 | n_input_encoder, 173 | n_background_latent, 174 | n_layers=n_layers_encoder, 175 | n_cat_list=encoder_cat_list, 176 | n_hidden=n_hidden, 177 | dropout_rate=dropout_rate_encoder, 178 | distribution=latent_distribution, 179 | use_batch_norm=use_batch_norm_encoder, 180 | use_layer_norm=use_layer_norm_encoder, 181 | ) 182 | 183 | self.s_encoder = EncoderTOTALVI( 184 | n_input_encoder, 185 | n_salient_latent, 186 | n_layers=n_layers_encoder, 187 | n_cat_list=encoder_cat_list, 188 | n_hidden=n_hidden, 189 
| dropout_rate=dropout_rate_encoder, 190 | distribution=latent_distribution, 191 | use_batch_norm=use_batch_norm_encoder, 192 | use_layer_norm=use_layer_norm_encoder, 193 | ) 194 | n_total_latent = n_background_latent + n_salient_latent 195 | self.decoder = DecoderTOTALVI( 196 | n_total_latent + n_continuous_cov, 197 | n_input_genes, 198 | self.n_input_proteins, 199 | n_layers=n_layers_decoder, 200 | n_cat_list=cat_list, 201 | n_hidden=n_hidden, 202 | dropout_rate=dropout_rate_decoder, 203 | use_batch_norm=use_batch_norm_decoder, 204 | use_layer_norm=use_layer_norm_decoder, 205 | scale_activation="softplus" if use_size_factor_key else "softmax", 206 | ) 207 | 208 | @auto_move_data 209 | def _compute_local_library_params( 210 | self, batch_index: torch.Tensor 211 | ) -> Tuple[torch.Tensor, torch.Tensor]: 212 | n_batch = self.library_log_means.shape[1] 213 | local_library_log_means = F.linear( 214 | one_hot(batch_index, n_batch), self.library_log_means 215 | ) 216 | local_library_log_vars = F.linear( 217 | one_hot(batch_index, n_batch), self.library_log_vars 218 | ) 219 | return local_library_log_means, local_library_log_vars 220 | 221 | @staticmethod 222 | def _get_min_batch_size(concat_tensors: Dict[str, Dict[str, torch.Tensor]]) -> int: 223 | return min( 224 | concat_tensors["background"][REGISTRY_KEYS.X_KEY].shape[0], 225 | concat_tensors["target"][REGISTRY_KEYS.X_KEY].shape[0], 226 | ) 227 | 228 | @staticmethod 229 | def _reduce_tensors_to_min_batch_size( 230 | tensors: Dict[str, torch.Tensor], min_batch_size: int 231 | ) -> None: 232 | for name, tensor in tensors.items(): 233 | tensors[name] = tensor[:min_batch_size, :] 234 | 235 | @staticmethod 236 | def _get_inference_input_from_concat_tensors( 237 | concat_tensors: Dict[str, Dict[str, torch.Tensor]], index: str 238 | ) -> Dict[str, torch.Tensor]: 239 | tensors = concat_tensors[index] 240 | x = tensors[REGISTRY_KEYS.X_KEY] 241 | y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY] 242 | batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] 243 | input_dict = dict(x=x, y=y, batch_index=batch_index) 244 | return input_dict 245 | 246 | def _get_inference_input( 247 | self, concat_tensors: Dict[str, Dict[str, torch.Tensor]] 248 | ) -> Dict[str, Dict[str, torch.Tensor]]: 249 | background = self._get_inference_input_from_concat_tensors( 250 | concat_tensors, "background" 251 | ) 252 | target = self._get_inference_input_from_concat_tensors(concat_tensors, "target") 253 | # Ensure batch sizes are the same. 
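        # Editor's note: each mini-batch here carries both RNA counts
        # (REGISTRY_KEYS.X_KEY) and protein counts
        # (REGISTRY_KEYS.PROTEIN_EXP_KEY); the two modalities are log1p
        # transformed and concatenated into a single encoder input in
        # `_generic_inference` below.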
254 | min_batch_size = self._get_min_batch_size(concat_tensors) 255 | self._reduce_tensors_to_min_batch_size(background, min_batch_size) 256 | self._reduce_tensors_to_min_batch_size(target, min_batch_size) 257 | return dict(background=background, target=target) 258 | 259 | @staticmethod 260 | def _get_generative_input_from_concat_tensors( 261 | concat_tensors: Dict[str, Dict[str, torch.Tensor]], index: str 262 | ) -> Dict[str, torch.Tensor]: 263 | tensors = concat_tensors[index] 264 | batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] 265 | input_dict = dict(batch_index=batch_index) 266 | return input_dict 267 | 268 | @staticmethod 269 | def _get_generative_input_from_inference_outputs( 270 | inference_outputs: Dict[str, Dict[str, torch.Tensor]], data_source: str 271 | ) -> Dict[str, torch.Tensor]: 272 | z = inference_outputs[data_source]["z"] 273 | s = inference_outputs[data_source]["s"] 274 | library_gene = inference_outputs[data_source]["library_gene"] 275 | return dict(z=z, s=s, library_gene=library_gene) 276 | 277 | def _get_generative_input( 278 | self, 279 | concat_tensors: Dict[str, Dict[str, torch.Tensor]], 280 | inference_outputs: Dict[str, Dict[str, torch.Tensor]], 281 | ) -> Dict[str, Dict[str, torch.Tensor]]: 282 | background_tensor_input = self._get_generative_input_from_concat_tensors( 283 | concat_tensors, "background" 284 | ) 285 | target_tensor_input = self._get_generative_input_from_concat_tensors( 286 | concat_tensors, "target" 287 | ) 288 | # Ensure batch sizes are the same. 289 | min_batch_size = self._get_min_batch_size(concat_tensors) 290 | self._reduce_tensors_to_min_batch_size(background_tensor_input, min_batch_size) 291 | self._reduce_tensors_to_min_batch_size(target_tensor_input, min_batch_size) 292 | 293 | background_inference_outputs = ( 294 | self._get_generative_input_from_inference_outputs( 295 | inference_outputs, "background" 296 | ) 297 | ) 298 | target_inference_outputs = self._get_generative_input_from_inference_outputs( 299 | inference_outputs, "target" 300 | ) 301 | background = {**background_tensor_input, **background_inference_outputs} 302 | target = {**target_tensor_input, **target_inference_outputs} 303 | return dict(background=background, target=target) 304 | 305 | @staticmethod 306 | def _reshape_tensor_for_samples(tensor: torch.Tensor, n_samples: int): 307 | return tensor.unsqueeze(0).expand((n_samples, tensor.size(0), tensor.size(1))) 308 | 309 | @auto_move_data 310 | def _generic_inference( 311 | self, 312 | x: torch.Tensor, 313 | y: torch.Tensor, 314 | batch_index: torch.Tensor, 315 | n_samples: int = 1, 316 | ) -> Dict[str, torch.Tensor]: 317 | x_ = x 318 | y_ = y 319 | if self.use_observed_lib_size: 320 | library_gene = x.sum(1).unsqueeze(1) 321 | x_ = torch.log(1 + x_) 322 | y_ = torch.log(1 + y_) 323 | encoder_input = torch.cat((x_, y_), dim=-1) 324 | 325 | ( 326 | qz_m, 327 | qz_v, 328 | ql_m, 329 | ql_v, 330 | background_latent, 331 | untran_background_latent, 332 | ) = self.z_encoder(encoder_input, batch_index) 333 | z = background_latent["z"] 334 | untran_z = untran_background_latent["z"] 335 | untran_l = untran_background_latent["l"] # Library encoder used and updated. 336 | (qs_m, qs_v, _, _, salient_latent, untran_salient_latent) = self.s_encoder( 337 | encoder_input, batch_index 338 | ) 339 | s = salient_latent["z"] 340 | untran_s = untran_salient_latent["z"] 341 | # Library encoder not used and not updated. 
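        # Editor's note: EncoderTOTALVI also returns library-size outputs from
        # the salient encoder, but they are discarded above (the `_, _` in the
        # tuple unpacking), so only the background (z) encoder's posterior
        # parameterizes `library_gene`.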
342 | 343 | if not self.use_observed_lib_size: 344 | library_gene = background_latent["l"] 345 | 346 | if n_samples > 1: 347 | qz_m = self._reshape_tensor_for_samples(qz_m, n_samples) 348 | qz_v = self._reshape_tensor_for_samples(qz_v, n_samples) 349 | untran_z = Normal(qz_m, qz_v.sqrt()).sample() 350 | z = self.z_encoder.z_transformation(untran_z) 351 | 352 | qs_m = self._reshape_tensor_for_samples(qs_m, n_samples) 353 | qs_v = self._reshape_tensor_for_samples(qs_v, n_samples) 354 | untran_s = Normal(qs_m, qs_v.sqrt()).sample() 355 | s = self.s_encoder.z_transformation(untran_s) 356 | 357 | ql_m = self._reshape_tensor_for_samples(ql_m, n_samples) 358 | ql_v = self._reshape_tensor_for_samples(ql_v, n_samples) 359 | untran_l = Normal(ql_m, ql_v.sqrt()).sample() 360 | 361 | if self.use_observed_lib_size: 362 | library_gene = self._reshape_tensor_for_samples(library_gene, n_samples) 363 | else: 364 | library_gene = self.z_encoder.l_transformation(untran_l) 365 | 366 | if self.n_batch > 0: 367 | py_back_alpha_prior = F.linear( 368 | one_hot(batch_index, self.n_batch), self.background_pro_alpha 369 | ) 370 | py_back_beta_prior = F.linear( 371 | one_hot(batch_index, self.n_batch), 372 | torch.exp(self.background_pro_log_beta), 373 | ) 374 | else: 375 | py_back_alpha_prior = self.background_pro_alpha 376 | py_back_beta_prior = torch.exp(self.background_pro_log_beta) 377 | 378 | back_mean_prior = Normal(py_back_alpha_prior, py_back_beta_prior) 379 | 380 | outputs = dict( 381 | untran_z=untran_z, 382 | z=z, 383 | qz_m=qz_m, 384 | qz_v=qz_v, 385 | untran_s=untran_s, 386 | s=s, 387 | qs_m=qs_m, 388 | qs_v=qs_v, 389 | library_gene=library_gene, 390 | ql_m=ql_m, 391 | ql_v=ql_v, 392 | untran_l=untran_l, 393 | back_mean_prior=back_mean_prior 394 | ) 395 | return outputs 396 | 397 | @auto_move_data 398 | def inference( 399 | self, 400 | background: Dict[str, torch.Tensor], 401 | target: Dict[str, torch.Tensor], 402 | n_samples: int = 1, 403 | ) -> Dict[str, Dict[str, torch.Tensor]]: 404 | background_outputs = self._generic_inference(**background, n_samples=n_samples) 405 | target_outputs = self._generic_inference(**target, n_samples=n_samples) 406 | background_outputs["s"] = torch.zeros_like(background_outputs["s"]) 407 | return dict(background=background_outputs, target=target_outputs) 408 | 409 | @auto_move_data 410 | def _generic_generative( 411 | self, 412 | z: torch.Tensor, 413 | s: torch.Tensor, 414 | library_gene: torch.Tensor, 415 | batch_index: torch.Tensor, 416 | ) -> Dict[str, torch.Tensor]: 417 | latent = torch.cat([z, s], dim=-1) 418 | px_, py_, log_pro_back_mean = self.decoder( 419 | latent, 420 | library_gene, 421 | batch_index, 422 | ) 423 | px_r = torch.exp(self.px_r) 424 | py_r = torch.exp(self.py_r) 425 | px_["r"] = px_r 426 | py_["r"] = py_r 427 | return dict(px_=px_, py_=py_, log_pro_back_mean=log_pro_back_mean) 428 | 429 | @auto_move_data 430 | def generative( 431 | self, 432 | background: Dict[str, torch.Tensor], 433 | target: Dict[str, torch.Tensor], 434 | ) -> Dict[str, Dict[str, torch.Tensor]]: 435 | latent_z_shape = background["z"].shape 436 | batch_size_dim = 0 if len(latent_z_shape) == 2 else 1 437 | background_batch_size = background["z"].shape[batch_size_dim] 438 | target_batch_size = target["z"].shape[batch_size_dim] 439 | generative_input = {} 440 | for key in ["z", "s", "library_gene"]: 441 | generative_input[key] = torch.cat( 442 | [background[key], target[key]], dim=batch_size_dim 443 | ) 444 | generative_input["batch_index"] = torch.cat( 445 | 
[background["batch_index"], target["batch_index"]], dim=0 446 | ) 447 | outputs = self._generic_generative(**generative_input) 448 | 449 | # Split outputs into corresponding background and target set. 450 | background_outputs = {"px_": {}, "py_": {}} 451 | target_outputs = {"px_": {}, "py_": {}} 452 | for modality in ["px_", "py_"]: 453 | for key in outputs[modality].keys(): 454 | if key == "r": 455 | background_tensor = outputs[modality][key] 456 | target_tensor = outputs[modality][key] 457 | else: 458 | if outputs[modality][key] is not None: 459 | background_tensor, target_tensor = torch.split( 460 | outputs[modality][key], 461 | [background_batch_size, target_batch_size], 462 | dim=batch_size_dim, 463 | ) 464 | else: 465 | background_tensor, target_tensor = None, None 466 | background_outputs[modality][key] = background_tensor 467 | target_outputs[modality][key] = target_tensor 468 | 469 | if outputs["log_pro_back_mean"] is not None: 470 | background_tensor, target_tensor = torch.split( 471 | outputs["log_pro_back_mean"], 472 | [background_batch_size, target_batch_size], 473 | dim=batch_size_dim, 474 | ) 475 | else: 476 | background_tensor, target_tensor = None, None 477 | background_outputs["log_pro_back_mean"] = background_tensor 478 | target_outputs["log_pro_back_mean"] = target_tensor 479 | 480 | return dict(background=background_outputs, target=target_outputs) 481 | 482 | @staticmethod 483 | def get_reconstruction_loss( 484 | x: torch.Tensor, 485 | y: torch.Tensor, 486 | px_dict: Dict[str, torch.Tensor], 487 | py_dict: Dict[str, torch.Tensor], 488 | pro_batch_mask_minibatch: Optional[torch.Tensor] = None, 489 | ) -> Tuple[torch.Tensor, torch.Tensor]: 490 | """Compute reconstruction loss.""" 491 | px_ = px_dict 492 | py_ = py_dict 493 | 494 | reconst_loss_gene = ( 495 | -NegativeBinomial(mu=px_["rate"], theta=px_["r"]).log_prob(x).sum(dim=-1) 496 | ) 497 | 498 | py_conditional = NegativeBinomialMixture( 499 | mu1=py_["rate_back"], 500 | mu2=py_["rate_fore"], 501 | theta1=py_["r"], 502 | mixture_logits=py_["mixing"], 503 | ) 504 | reconst_loss_protein_full = -py_conditional.log_prob(y) 505 | if pro_batch_mask_minibatch is not None: 506 | temp_pro_loss_full = torch.zeros_like(reconst_loss_protein_full) 507 | temp_pro_loss_full.masked_scatter_( 508 | pro_batch_mask_minibatch.bool(), reconst_loss_protein_full 509 | ) 510 | 511 | reconst_loss_protein = temp_pro_loss_full.sum(dim=-1) 512 | else: 513 | reconst_loss_protein = reconst_loss_protein_full.sum(dim=-1) 514 | 515 | return reconst_loss_gene, reconst_loss_protein 516 | 517 | @staticmethod 518 | def latent_kl_divergence( 519 | variational_mean: torch.Tensor, 520 | variational_var: torch.Tensor, 521 | prior_mean: torch.Tensor, 522 | prior_var: torch.Tensor, 523 | ) -> torch.Tensor: 524 | """ 525 | Compute KL divergence between a variational posterior and prior Gaussian. 526 | Args: 527 | ---- 528 | variational_mean: Mean of the variational posterior Gaussian. 529 | variational_var: Variance of the variational posterior Gaussian. 530 | prior_mean: Mean of the prior Gaussian. 531 | prior_var: Variance of the prior Gaussian. 532 | 533 | Returns 534 | ------- 535 | KL divergence for each data point. If number of latent samples == 1, 536 | the tensor has shape `(batch_size, )`. If number of latent 537 | samples > 1, the tensor has shape `(n_samples, batch_size)`. 
538 | """ 539 | return kl( 540 | Normal(variational_mean, variational_var.sqrt()), 541 | Normal(prior_mean, prior_var.sqrt()), 542 | ).sum(dim=-1) 543 | 544 | def library_gene_kl_divergence( 545 | self, 546 | batch_index: torch.Tensor, 547 | variational_library_mean: torch.Tensor, 548 | variational_library_var: torch.Tensor, 549 | library: torch.Tensor, 550 | ) -> torch.Tensor: 551 | """ 552 | Compute KL divergence between library size variational posterior and prior. 553 | 554 | Both the variational posterior and prior are Log-Normal. 555 | Args: 556 | ---- 557 | batch_index: Batch indices for batch-specific library size mean and 558 | variance. 559 | variational_library_mean: Mean of variational Log-Normal. 560 | variational_library_var: Variance of variational Log-Normal. 561 | library: Sampled library size. 562 | 563 | Returns 564 | ------- 565 | KL divergence for each data point. If number of latent samples == 1, 566 | the tensor has shape `(batch_size, )`. If number of latent 567 | samples > 1, the tensor has shape `(n_samples, batch_size)`. 568 | """ 569 | if not self.use_observed_lib_size: 570 | ( 571 | local_library_log_means, 572 | local_library_log_vars, 573 | ) = self._compute_local_library_params(batch_index) 574 | 575 | kl_library = kl( 576 | Normal(variational_library_mean, variational_library_var.sqrt()), 577 | Normal(local_library_log_means, local_library_log_vars.sqrt()), 578 | ) 579 | else: 580 | kl_library = torch.zeros_like(library) 581 | return kl_library.sum(dim=-1) 582 | 583 | def _generic_loss( 584 | self, 585 | tensors: Dict[str, torch.Tensor], 586 | inference_outputs: Dict[str, torch.Tensor], 587 | generative_outputs: Dict[str, torch.Tensor], 588 | ) -> Dict[str, torch.Tensor]: 589 | x = tensors[REGISTRY_KEYS.X_KEY] 590 | y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY] 591 | batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] 592 | 593 | qz_m = inference_outputs["qz_m"] 594 | qz_v = inference_outputs["qz_v"] 595 | qs_m = inference_outputs["qs_m"] 596 | qs_v = inference_outputs["qs_v"] 597 | ql_m = inference_outputs["ql_m"] 598 | ql_v = inference_outputs["ql_v"] 599 | library_gene = inference_outputs["library_gene"] 600 | px_ = generative_outputs["px_"] 601 | py_ = generative_outputs["py_"] 602 | 603 | prior_z_m = torch.zeros_like(qz_m) 604 | prior_z_v = torch.ones_like(qz_v) 605 | prior_s_m = torch.zeros_like(qs_m) 606 | prior_s_v = torch.ones_like(qs_v) 607 | 608 | if self.protein_batch_mask is not None: 609 | pro_batch_mask_minibatch = torch.zeros_like(y) 610 | for b in torch.unique(batch_index): 611 | b_indices = (batch_index == b).reshape(-1) 612 | pro_batch_mask_minibatch[b_indices] = torch.tensor( 613 | self.protein_batch_mask[b.item()].astype(np.float32), 614 | device=y.device, 615 | ) 616 | else: 617 | pro_batch_mask_minibatch = None 618 | reconst_loss_gene, reconst_loss_protein = self.get_reconstruction_loss( 619 | x, y, px_, py_, pro_batch_mask_minibatch 620 | ) 621 | 622 | kl_z = self.latent_kl_divergence(qz_m, qz_v, prior_z_m, prior_z_v) 623 | kl_s = self.latent_kl_divergence(qs_m, qs_v, prior_s_m, prior_s_v) 624 | kl_library_gene = self.library_gene_kl_divergence( 625 | batch_index, ql_m, ql_v, library_gene 626 | ) 627 | 628 | kl_div_back_pro_full = kl( 629 | Normal(py_["back_alpha"], py_["back_beta"]), inference_outputs["back_mean_prior"] 630 | ) 631 | if pro_batch_mask_minibatch is not None: 632 | kl_div_back_pro = (pro_batch_mask_minibatch * kl_div_back_pro_full).sum( 633 | dim=-1 634 | ) 635 | else: 636 | kl_div_back_pro = kl_div_back_pro_full.sum(dim=-1) 
637 | 638 | return dict( 639 | reconst_loss_gene=reconst_loss_gene, 640 | reconst_loss_protein=reconst_loss_protein, 641 | kl_z=kl_z, 642 | kl_s=kl_s, 643 | kl_library_gene=kl_library_gene, 644 | kl_div_back_pro=kl_div_back_pro, 645 | ) 646 | 647 | def loss( 648 | self, 649 | concat_tensors: Dict[str, Dict[str, torch.Tensor]], 650 | inference_outputs: Dict[str, Dict[str, torch.Tensor]], 651 | generative_outputs: Dict[str, Dict[str, torch.Tensor]], 652 | kl_weight: float = 1.0, 653 | ) -> LossRecorder: 654 | """ 655 | Compute loss terms for contrastive-VI. 656 | Args: 657 | ---- 658 | concat_tensors: Tuple of data mini-batch. The first element contains 659 | background data mini-batch. The second element contains target data 660 | mini-batch. 661 | inference_outputs: Dictionary of inference step outputs. The keys 662 | are "background" and "target" for the corresponding outputs. 663 | generative_outputs: Dictionary of generative step outputs. The keys 664 | are "background" and "target" for the corresponding outputs. 665 | 666 | Returns 667 | ------- 668 | An scvi.module.base.LossRecorder instance that records the losses. 669 | """ 670 | background_tensors = concat_tensors["background"] 671 | target_tensors = concat_tensors["target"] 672 | # Ensure batch sizes are the same. 673 | min_batch_size = self._get_min_batch_size(concat_tensors) 674 | self._reduce_tensors_to_min_batch_size(background_tensors, min_batch_size) 675 | self._reduce_tensors_to_min_batch_size(target_tensors, min_batch_size) 676 | 677 | background_losses = self._generic_loss( 678 | background_tensors, 679 | inference_outputs["background"], 680 | generative_outputs["background"], 681 | ) 682 | target_losses = self._generic_loss( 683 | target_tensors, 684 | inference_outputs["target"], 685 | generative_outputs["target"], 686 | ) 687 | 688 | reconst_loss_gene = ( 689 | background_losses["reconst_loss_gene"] + target_losses["reconst_loss_gene"] 690 | ) 691 | reconst_loss_protein = ( 692 | background_losses["reconst_loss_protein"] 693 | + target_losses["reconst_loss_protein"] 694 | ) 695 | 696 | wasserstein_loss = ( 697 | torch.norm(inference_outputs["background"]["qs_m"], dim=-1)**2 698 | + torch.sum(inference_outputs["background"]["qs_v"], dim=-1) 699 | ) 700 | 701 | kl_div_z = background_losses["kl_z"] + target_losses["kl_z"] 702 | kl_div_s = target_losses["kl_s"] 703 | kl_div_l_gene = ( 704 | background_losses["kl_library_gene"] + target_losses["kl_library_gene"] 705 | ) 706 | kl_div_back_pro = ( 707 | background_losses["kl_div_back_pro"] + target_losses["kl_div_back_pro"] 708 | ) 709 | 710 | loss = torch.mean( 711 | reconst_loss_gene 712 | + reconst_loss_protein 713 | + kl_weight * kl_div_z 714 | + kl_weight * kl_div_s 715 | + kl_weight * self.wasserstein_penalty*wasserstein_loss 716 | + kl_div_l_gene 717 | + kl_weight * kl_div_back_pro 718 | ) 719 | 720 | reconst_losses = dict( 721 | reconst_loss_gene=reconst_loss_gene, 722 | reconst_loss_protein=reconst_loss_protein, 723 | ) 724 | kl_local = dict( 725 | kl_div_z=kl_div_z, 726 | kl_div_s=kl_div_s, 727 | kl_div_l_gene=kl_div_l_gene, 728 | kl_div_back_pro=kl_div_back_pro, 729 | ) 730 | kl_global = torch.tensor(0.0) 731 | return LossRecorder( 732 | loss, 733 | reconst_losses, 734 | kl_local, 735 | kl_global=kl_global, 736 | wasserstein_loss=torch.sum(wasserstein_loss) 737 | ) 738 | 739 | @torch.no_grad() 740 | def sample_mean( 741 | self, 742 | tensors: Dict[str, torch.Tensor], 743 | data_source: str, 744 | n_samples: int = 1, 745 | ) -> torch.Tensor: 746 | """Sample 
posterior mean.""" 747 | raise NotImplementedError 748 | 749 | @torch.no_grad() 750 | def sample(self): 751 | raise NotImplementedError 752 | 753 | @torch.no_grad() 754 | @auto_move_data 755 | def marginal_ll(self): 756 | raise NotImplementedError 757 | -------------------------------------------------------------------------------- /contrastive_vi/module/utils.py: -------------------------------------------------------------------------------- 1 | """Utilities for contrastiveVI modules.""" 2 | import torch 3 | 4 | 5 | def gram_matrix(x: torch.Tensor, y: torch.Tensor, gammas: torch.Tensor) -> torch.Tensor: 6 | """ 7 | Calculate the maximum mean discrepancy gram matrix with multiple gamma values. 8 | 9 | Args: 10 | ---- 11 | x: Tensor with shape (B, P, M) or (P, M). 12 | y: Tensor with shape (B, R, M) or (R, M). 13 | gammas: 1-D tensor with the gamma values. 14 | 15 | Returns 16 | ------- 17 | A tensor with shape (B, P, R) or (P, R) for the distance between pairs of data 18 | points in `x` and `y`. 19 | """ 20 | gammas = gammas.unsqueeze(1) 21 | pairwise_distances = torch.cdist(x, y, p=2.0) 22 | 23 | pairwise_distances_sq = torch.square(pairwise_distances) 24 | tmp = torch.matmul(gammas, torch.reshape(pairwise_distances_sq, (1, -1))) 25 | tmp = torch.reshape(torch.sum(torch.exp(-tmp), 0), pairwise_distances_sq.shape) 26 | return tmp 27 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: contrastive-vi-env 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - python==3.9.* 7 | - pip 8 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.black] 6 | line-length = 88 7 | 8 | [tool.isort] 9 | profile = "black" -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = contrastive-vi 3 | version = 0.2.0 4 | 5 | [options] 6 | package_dir = 7 | packages = 8 | find: 9 | python_requires = 10 | >=3.7 11 | install_requires = 12 | scanpy>=1.8.1 13 | protobuf<=3.20.1 14 | scvi-tools>=0.15.0, <=0.16.2 15 | 16 | [options.packages.find] 17 | where = 18 | 19 | [options.extras_require] 20 | dev = 21 | pre-commit==2.15.0 22 | toml==0.10.2 23 | pytest==6.2.5 24 | -------------------------------------------------------------------------------- /sketch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/suinleelab/contrastiveVI/2835d2925f7ef60cdbebd2422435046cb9165f24/sketch.png -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/suinleelab/contrastiveVI/2835d2925f7ef60cdbebd2422435046cb9165f24/tests/__init__.py -------------------------------------------------------------------------------- /tests/data/dataloaders/test_contrastive_dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scvi import REGISTRY_KEYS 3 | 4 | from contrastive_vi.data.dataloaders.contrastive_dataloader import 
ContrastiveDataLoader 5 | from tests.utils import get_next_batch 6 | 7 | 8 | class TestContrastiveDataLoader: 9 | def test_one_batch( 10 | self, 11 | mock_adata, 12 | mock_adata_manager, 13 | mock_adata_background_indices, 14 | mock_adata_background_label, 15 | mock_adata_target_indices, 16 | mock_adata_target_label, 17 | ): 18 | batch_size = 32 19 | dataloader = ContrastiveDataLoader( 20 | mock_adata_manager, 21 | mock_adata_background_indices, 22 | mock_adata_target_indices, 23 | batch_size=batch_size, 24 | shuffle=False, 25 | ) 26 | batch = get_next_batch(dataloader) 27 | assert type(batch) == dict 28 | assert len(batch.keys()) == 2 29 | assert "background" in batch.keys() 30 | assert "target" in batch.keys() 31 | 32 | expected_background_data = torch.Tensor( 33 | mock_adata.layers["raw_counts"][mock_adata_background_indices, :][ 34 | :batch_size, : 35 | ] 36 | ) 37 | expected_target_data = torch.Tensor( 38 | mock_adata.layers["raw_counts"][mock_adata_target_indices, :][ 39 | :batch_size, : 40 | ] 41 | ) 42 | 43 | assert torch.equal( 44 | batch["background"][REGISTRY_KEYS.X_KEY], expected_background_data 45 | ) 46 | assert torch.equal(batch["target"][REGISTRY_KEYS.X_KEY], expected_target_data) 47 | 48 | assert ( 49 | batch["background"][REGISTRY_KEYS.BATCH_KEY] == mock_adata_background_label 50 | ).sum() == batch_size 51 | assert ( 52 | batch["target"][REGISTRY_KEYS.BATCH_KEY] == mock_adata_target_label 53 | ).sum() == batch_size 54 | -------------------------------------------------------------------------------- /tests/data/dataloaders/test_data_splitting.py: -------------------------------------------------------------------------------- 1 | from contrastive_vi.data.dataloaders.data_splitting import ContrastiveDataSplitter 2 | 3 | 4 | class TestContrastiveDataSplitter: 5 | def test_num_batches( 6 | self, 7 | mock_adata_manager, 8 | mock_adata_background_indices, 9 | mock_adata_target_indices, 10 | ) -> None: 11 | train_size = 0.8 12 | validation_size = 0.1 13 | test_size = 0.1 14 | batch_size = 20 15 | n_max = max( 16 | len(mock_adata_background_indices), 17 | len(mock_adata_target_indices), 18 | ) 19 | expected_train_num_batches = n_max * train_size / batch_size 20 | expected_val_num_batches = n_max * validation_size / batch_size 21 | expected_test_num_batches = n_max * test_size / batch_size 22 | 23 | data_splitter = ContrastiveDataSplitter( 24 | mock_adata_manager, 25 | mock_adata_background_indices, 26 | mock_adata_target_indices, 27 | train_size=train_size, 28 | validation_size=validation_size, 29 | batch_size=batch_size, 30 | ) 31 | data_splitter.setup() 32 | train_dataloader = data_splitter.train_dataloader() 33 | val_dataloader = data_splitter.val_dataloader() 34 | test_dataloader = data_splitter.test_dataloader() 35 | 36 | assert len(train_dataloader) == expected_train_num_batches 37 | assert len(val_dataloader) == expected_val_num_batches 38 | assert len(test_dataloader) == expected_test_num_batches 39 | -------------------------------------------------------------------------------- /tests/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/suinleelab/contrastiveVI/2835d2925f7ef60cdbebd2422435046cb9165f24/tests/model/__init__.py -------------------------------------------------------------------------------- /tests/model/test_contrastive_vi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | 
import torch 5 | 6 | from contrastive_vi.model.contrastive_vi import ContrastiveVIModel 7 | from tests.utils import copy_module_state_dict 8 | 9 | 10 | @pytest.fixture( 11 | params=[True, False], ids=["with_observed_lib_size", "without_observed_lib_size"] 12 | ) 13 | def mock_contrastive_vi_model( 14 | mock_adata, 15 | mock_adata_background_indices, 16 | mock_adata_target_indices, 17 | mock_library_log_means, 18 | mock_library_log_vars, 19 | request, 20 | ): 21 | if request.param: 22 | return ContrastiveVIModel( 23 | mock_adata, 24 | n_hidden=16, 25 | n_background_latent=4, 26 | n_salient_latent=4, 27 | n_layers=2, 28 | use_observed_lib_size=True, 29 | ) 30 | else: 31 | return ContrastiveVIModel( 32 | mock_adata, 33 | n_hidden=16, 34 | n_background_latent=4, 35 | n_salient_latent=4, 36 | n_layers=2, 37 | use_observed_lib_size=False, 38 | ) 39 | 40 | 41 | class TestContrastiveVIModel: 42 | def test_train( 43 | self, 44 | mock_contrastive_vi_model, 45 | mock_adata_background_indices, 46 | mock_adata_target_indices, 47 | ): 48 | init_state_dict = copy_module_state_dict(mock_contrastive_vi_model.module) 49 | mock_contrastive_vi_model.train( 50 | background_indices=mock_adata_background_indices, 51 | target_indices=mock_adata_target_indices, 52 | max_epochs=10, 53 | batch_size=20, # Unequal final batches to test edge case. 54 | use_gpu=False, 55 | ) 56 | trained_state_dict = copy_module_state_dict(mock_contrastive_vi_model.module) 57 | for param_key in mock_contrastive_vi_model.module.state_dict().keys(): 58 | is_library_param = ( 59 | param_key == "library_log_means" or param_key == "library_log_vars" 60 | ) 61 | is_px_r_decoder_param = "px_r_decoder" in param_key 62 | is_l_encoder_param = "l_encoder" in param_key 63 | 64 | if ( 65 | is_library_param 66 | or is_px_r_decoder_param 67 | or ( 68 | is_l_encoder_param 69 | and mock_contrastive_vi_model.module.use_observed_lib_size 70 | ) 71 | ): 72 | # There are three cases where parameters are not updated. 73 | # 1. Library means and vars are derived from input data and should 74 | # not be updated. 75 | # 2. In ContrastiveVIModel, dispersion is assumed to be gene-dependent 76 | # but not cell-dependent, so parameters in the dispersion (px_r) 77 | # decoder are not used and should not be updated. 78 | # 3. When observed library size is used, the library encoder is not 79 | # used and its parameters not updated. 80 | assert torch.equal( 81 | init_state_dict[param_key], trained_state_dict[param_key] 82 | ) 83 | else: 84 | # Other parameters should be updated after training. 
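                # (Comparing full state dicts before and after a short
                # training run is a cheap gradient-flow smoke test: every
                # parameter that feeds the loss should move, while unused
                # parameters must stay bit-for-bit identical.)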
85 |                 assert not torch.equal(
86 |                     init_state_dict[param_key], trained_state_dict[param_key]
87 |                 )
88 | 
89 |     @pytest.mark.parametrize("representation_kind", ["background", "salient"])
90 |     def test_get_latent_representation(
91 |         self, mock_contrastive_vi_model, representation_kind
92 |     ):
93 |         n_cells = mock_contrastive_vi_model.adata.n_obs
94 |         if representation_kind == "background":
95 |             n_latent = mock_contrastive_vi_model.module.n_background_latent
96 |         else:
97 |             n_latent = mock_contrastive_vi_model.module.n_salient_latent
98 |         representation = mock_contrastive_vi_model.get_latent_representation(
99 |             representation_kind=representation_kind
100 |         )
101 |         assert representation.shape == (n_cells, n_latent)
102 | 
103 |     @pytest.mark.parametrize("representation_kind", ["background", "salient"])
104 |     def test_get_normalized_expression(
105 |         self, mock_contrastive_vi_model, representation_kind
106 |     ):
107 |         n_samples = 50
108 |         n_cells = mock_contrastive_vi_model.adata.n_obs
109 |         n_genes = mock_contrastive_vi_model.adata.n_vars
110 |         one_sample_exprs = mock_contrastive_vi_model.get_normalized_expression(
111 |             n_samples=1, return_numpy=True
112 |         )
113 |         one_sample_exprs = one_sample_exprs[representation_kind]
114 |         assert type(one_sample_exprs) == np.ndarray
115 |         assert one_sample_exprs.shape == (n_cells, n_genes)
116 | 
117 |         many_sample_exprs = mock_contrastive_vi_model.get_normalized_expression(
118 |             n_samples=n_samples,
119 |             return_mean=False,
120 |         )
121 |         many_sample_exprs = many_sample_exprs[representation_kind]
122 |         assert type(many_sample_exprs) == np.ndarray
123 |         assert many_sample_exprs.shape == (n_samples, n_cells, n_genes)
124 | 
125 |         exprs_df = mock_contrastive_vi_model.get_normalized_expression(
126 |             n_samples=1,
127 |             return_numpy=False,
128 |         )
129 |         exprs_df = exprs_df[representation_kind]
130 |         assert type(exprs_df) == pd.DataFrame
131 |         assert exprs_df.shape == (n_cells, n_genes)
132 | 
133 |     def test_get_salient_normalized_expression(self, mock_contrastive_vi_model):
134 |         n_samples = 50
135 |         n_cells = mock_contrastive_vi_model.adata.n_obs
136 |         n_genes = mock_contrastive_vi_model.adata.n_vars
137 | 
138 |         one_sample_expr = mock_contrastive_vi_model.get_salient_normalized_expression(
139 |             n_samples=1, return_numpy=True
140 |         )
141 |         assert type(one_sample_expr) == np.ndarray
142 |         assert one_sample_expr.shape == (n_cells, n_genes)
143 | 
144 |         many_sample_expr = mock_contrastive_vi_model.get_salient_normalized_expression(
145 |             n_samples=n_samples,
146 |             return_mean=False,
147 |         )
148 |         assert type(many_sample_expr) == np.ndarray
149 |         assert many_sample_expr.shape == (n_samples, n_cells, n_genes)
150 | 
151 |         expr_df = mock_contrastive_vi_model.get_salient_normalized_expression(
152 |             n_samples=1,
153 |             return_numpy=False,
154 |         )
155 |         assert type(expr_df) == pd.DataFrame
156 |         assert expr_df.shape == (n_cells, n_genes)
157 | 
158 |     def test_get_normalized_expression_fold_change(self, mock_contrastive_vi_model):
159 |         n_samples = 50
160 |         n_cells = mock_contrastive_vi_model.adata.n_obs
161 |         n_genes = mock_contrastive_vi_model.adata.n_vars
162 |         one_sample_fc = mock_contrastive_vi_model.get_normalized_expression_fold_change(
163 |             n_samples=1
164 |         )
165 |         assert one_sample_fc.shape == (n_cells, n_genes)
166 |         many_sample_fc = (
167 |             mock_contrastive_vi_model.get_normalized_expression_fold_change(
168 |                 n_samples=n_samples
169 |             )
170 |         )
171 |         assert many_sample_fc.shape == (n_samples, n_cells, n_genes)
172 | 
173 |     def test_differential_expression(self, 
mock_contrastive_vi_model): 174 | de_df = mock_contrastive_vi_model.differential_expression( 175 | groupby="labels", 176 | group1=["label_0"], 177 | ) 178 | n_vars = mock_contrastive_vi_model.adata.n_vars 179 | assert type(de_df) == pd.DataFrame 180 | assert de_df.shape[0] == n_vars 181 | -------------------------------------------------------------------------------- /tests/module/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/suinleelab/contrastiveVI/2835d2925f7ef60cdbebd2422435046cb9165f24/tests/module/__init__.py -------------------------------------------------------------------------------- /tests/module/test_contrastive_vi.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from scvi.module.base import LossRecorder 4 | 5 | from contrastive_vi.module.contrastive_vi import ContrastiveVIModule 6 | 7 | required_data_sources = ["background", "target"] 8 | required_inference_input_keys = ["x", "batch_index"] 9 | required_inference_output_keys = [ 10 | "z", 11 | "qz_m", 12 | "qz_v", 13 | "s", 14 | "qs_m", 15 | "qs_v", 16 | "library", 17 | "ql_m", 18 | "ql_v", 19 | ] 20 | required_generative_input_keys_from_concat_tensors = ["batch_index"] 21 | required_generative_input_keys_from_inference_outputs = ["z", "s", "library"] 22 | required_generative_output_keys = [ 23 | "px_scale", 24 | "px_r", 25 | "px_rate", 26 | "px_dropout", 27 | ] 28 | 29 | 30 | @pytest.fixture( 31 | params=[True, False], ids=["with_observed_lib_size", "without_observed_lib_size"] 32 | ) 33 | def mock_contrastive_vi_module( 34 | mock_n_input, mock_n_batch, mock_library_log_means, mock_library_log_vars, request 35 | ): 36 | if request.param: 37 | return ContrastiveVIModule( 38 | n_input=mock_n_input, 39 | n_batch=mock_n_batch, 40 | n_hidden=10, 41 | n_background_latent=4, 42 | n_salient_latent=4, 43 | n_layers=2, 44 | use_observed_lib_size=True, 45 | library_log_means=None, 46 | library_log_vars=None, 47 | ) 48 | else: 49 | return ContrastiveVIModule( 50 | n_input=mock_n_input, 51 | n_batch=mock_n_batch, 52 | n_hidden=10, 53 | n_background_latent=4, 54 | n_salient_latent=4, 55 | n_layers=2, 56 | use_observed_lib_size=False, 57 | library_log_means=mock_library_log_means, 58 | library_log_vars=mock_library_log_vars, 59 | ) 60 | 61 | 62 | @pytest.fixture(params=[1, 2], ids=["one_latent_sample", "two_latent_samples"]) 63 | def mock_contrastive_vi_data( 64 | mock_contrastive_batch, 65 | mock_contrastive_vi_module, 66 | request, 67 | ): 68 | concat_tensors = mock_contrastive_batch 69 | inference_input = mock_contrastive_vi_module._get_inference_input(concat_tensors) 70 | inference_outputs = mock_contrastive_vi_module.inference( 71 | **inference_input, n_samples=request.param 72 | ) 73 | generative_input = mock_contrastive_vi_module._get_generative_input( 74 | concat_tensors, inference_outputs 75 | ) 76 | generative_outputs = mock_contrastive_vi_module.generative(**generative_input) 77 | return dict( 78 | concat_tensors=concat_tensors, 79 | inference_input=inference_input, 80 | inference_outputs=inference_outputs, 81 | generative_input=generative_input, 82 | generative_outputs=generative_outputs, 83 | ) 84 | 85 | 86 | class TestContrastiveVIModuleInference: 87 | def test_get_inference_input_from_concat_tensors( 88 | self, 89 | mock_contrastive_vi_module, 90 | mock_contrastive_batch, 91 | mock_n_input, 92 | ): 93 | inference_input = ( 94 | 
mock_contrastive_vi_module._get_inference_input_from_concat_tensors(
95 |                 mock_contrastive_batch, "background"
96 |             )
97 |         )
98 |         for key in required_inference_input_keys:
99 |             assert key in inference_input.keys()
100 |         x = inference_input["x"]
101 |         batch_index = inference_input["batch_index"]
102 |         batch_size = x.shape[0]
103 |         assert x.shape == (batch_size, mock_n_input)
104 |         assert batch_index.shape == (batch_size, 1)
105 | 
106 |     def test_get_inference_input(
107 |         self,
108 |         mock_contrastive_vi_module,
109 |         mock_contrastive_batch,
110 |         mock_adata_background_label,
111 |         mock_adata_target_label,
112 |     ):
113 |         inference_input = mock_contrastive_vi_module._get_inference_input(
114 |             mock_contrastive_batch
115 |         )
116 |         for data_source in required_data_sources:
117 |             assert data_source in inference_input.keys()
118 | 
119 |         background_input = inference_input["background"]
120 |         background_input_keys = background_input.keys()
121 |         target_input = inference_input["target"]
122 |         target_input_keys = target_input.keys()
123 | 
124 |         for key in required_inference_input_keys:
125 |             assert key in background_input_keys
126 |             assert key in target_input_keys
127 | 
128 |         # Check background vs. target labels are consistent.
129 |         assert (
130 |             background_input["batch_index"] != mock_adata_background_label
131 |         ).sum() == 0
132 |         assert (target_input["batch_index"] != mock_adata_target_label).sum() == 0
133 | 
134 |     @pytest.mark.parametrize("n_samples", [1, 2])
135 |     def test_generic_inference(
136 |         self,
137 |         mock_contrastive_vi_module,
138 |         mock_contrastive_batch,
139 |         n_samples,
140 |     ):
141 |         inference_input = (
142 |             mock_contrastive_vi_module._get_inference_input_from_concat_tensors(
143 |                 mock_contrastive_batch, "background"
144 |             )
145 |         )
146 |         batch_size = inference_input["x"].shape[0]
147 |         n_background_latent = mock_contrastive_vi_module.n_background_latent
148 |         n_salient_latent = mock_contrastive_vi_module.n_salient_latent
149 | 
150 |         inference_outputs = mock_contrastive_vi_module._generic_inference(
151 |             **inference_input, n_samples=n_samples
152 |         )
153 |         for key in required_inference_output_keys:
154 |             assert key in inference_outputs.keys()
155 | 
156 |         if n_samples > 1:
157 |             expected_background_latent_shape = (
158 |                 n_samples,
159 |                 batch_size,
160 |                 n_background_latent,
161 |             )
162 |             expected_salient_latent_shape = (n_samples, batch_size, n_salient_latent)
163 |             expected_library_shape = (n_samples, batch_size, 1)
164 |         else:
165 |             expected_background_latent_shape = (batch_size, n_background_latent)
166 |             expected_salient_latent_shape = (batch_size, n_salient_latent)
167 |             expected_library_shape = (batch_size, 1)
168 | 
169 |         assert inference_outputs["z"].shape == expected_background_latent_shape
170 |         assert inference_outputs["qz_m"].shape == expected_background_latent_shape
171 |         assert inference_outputs["qz_v"].shape == expected_background_latent_shape
172 |         assert inference_outputs["s"].shape == expected_salient_latent_shape
173 |         assert inference_outputs["qs_m"].shape == expected_salient_latent_shape
174 |         assert inference_outputs["qs_v"].shape == expected_salient_latent_shape
175 |         assert inference_outputs["library"].shape == expected_library_shape
176 |         assert (
177 |             inference_outputs["ql_m"] is None
178 |             or inference_outputs["ql_m"].shape == expected_library_shape
179 |         )
180 |         assert (
181 |             inference_outputs["ql_v"] is None
182 |             or inference_outputs["ql_v"].shape == expected_library_shape
183 |         )
184 | 
185 |     def test_inference(
186 |         self,
187 |         
mock_contrastive_vi_module, 188 | mock_contrastive_batch, 189 | ): 190 | inference_input = mock_contrastive_vi_module._get_inference_input( 191 | mock_contrastive_batch 192 | ) 193 | inference_outputs = mock_contrastive_vi_module.inference(**inference_input) 194 | for data_source in required_data_sources: 195 | assert data_source in inference_outputs.keys() 196 | background_s = inference_outputs["background"]["s"] 197 | 198 | # Background salient variables should be all zeros. 199 | assert torch.equal(background_s, torch.zeros_like(background_s)) 200 | 201 | 202 | class TestContrastiveVIModuleGenerative: 203 | def test_get_generative_input_from_concat_tensors( 204 | self, 205 | mock_contrastive_vi_module, 206 | mock_contrastive_batch, 207 | mock_n_input, 208 | ): 209 | generative_input = ( 210 | mock_contrastive_vi_module._get_generative_input_from_concat_tensors( 211 | mock_contrastive_batch, "background" 212 | ) 213 | ) 214 | for key in required_generative_input_keys_from_concat_tensors: 215 | assert key in generative_input.keys() 216 | batch_index = generative_input["batch_index"] 217 | assert batch_index.shape[1] == 1 218 | 219 | def test_get_generative_input_from_inference_outputs( 220 | self, 221 | mock_contrastive_vi_module, 222 | mock_contrastive_batch, 223 | ): 224 | inference_outputs = mock_contrastive_vi_module.inference( 225 | **mock_contrastive_vi_module._get_inference_input(mock_contrastive_batch) 226 | ) 227 | generative_input = ( 228 | mock_contrastive_vi_module._get_generative_input_from_inference_outputs( 229 | inference_outputs, required_data_sources[0] 230 | ) 231 | ) 232 | for key in required_generative_input_keys_from_inference_outputs: 233 | assert key in generative_input 234 | 235 | z = generative_input["z"] 236 | s = generative_input["s"] 237 | library = generative_input["library"] 238 | batch_size = z.shape[0] 239 | 240 | assert z.shape == (batch_size, mock_contrastive_vi_module.n_background_latent) 241 | assert s.shape == (batch_size, mock_contrastive_vi_module.n_salient_latent) 242 | assert library.shape == (batch_size, 1) 243 | 244 | def test_get_generative_input( 245 | self, 246 | mock_contrastive_vi_module, 247 | mock_contrastive_batch, 248 | mock_adata_background_label, 249 | mock_adata_target_label, 250 | ): 251 | inference_outputs = mock_contrastive_vi_module.inference( 252 | **mock_contrastive_vi_module._get_inference_input(mock_contrastive_batch) 253 | ) 254 | generative_input = mock_contrastive_vi_module._get_generative_input( 255 | mock_contrastive_batch, inference_outputs 256 | ) 257 | for data_source in required_data_sources: 258 | assert data_source in generative_input.keys() 259 | background_generative_input = generative_input["background"] 260 | background_generative_input_keys = background_generative_input.keys() 261 | target_generative_input = generative_input["target"] 262 | target_generative_input_keys = target_generative_input.keys() 263 | for key in ( 264 | required_generative_input_keys_from_concat_tensors 265 | + required_generative_input_keys_from_inference_outputs 266 | ): 267 | assert key in background_generative_input_keys 268 | assert key in target_generative_input_keys 269 | 270 | # Check background vs. target labels are consistent. 
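        # (Background and target cells were registered under different batch
        # labels; _get_generative_input should pass those labels through
        # unchanged from the concat tensors.)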
271 | assert ( 272 | background_generative_input["batch_index"] != mock_adata_background_label 273 | ).sum() == 0 274 | assert ( 275 | target_generative_input["batch_index"] != mock_adata_target_label 276 | ).sum() == 0 277 | 278 | @pytest.mark.parametrize("n_samples", [1, 2]) 279 | def test_generic_generative( 280 | self, 281 | mock_contrastive_vi_module, 282 | mock_contrastive_batch, 283 | n_samples, 284 | ): 285 | inference_outputs = mock_contrastive_vi_module.inference( 286 | **mock_contrastive_vi_module._get_inference_input(mock_contrastive_batch), 287 | n_samples=n_samples, 288 | ) 289 | generative_input = mock_contrastive_vi_module._get_generative_input( 290 | mock_contrastive_batch, inference_outputs 291 | )["background"] 292 | generative_outputs = mock_contrastive_vi_module._generic_generative( 293 | **generative_input 294 | ) 295 | for key in required_generative_output_keys: 296 | assert key in generative_outputs.keys() 297 | px_scale = generative_outputs["px_scale"] 298 | px_r = generative_outputs["px_r"] 299 | px_rate = generative_outputs["px_rate"] 300 | px_dropout = generative_outputs["px_dropout"] 301 | batch_size = px_scale.shape[-2] 302 | n_input = mock_contrastive_vi_module.n_input 303 | 304 | if n_samples > 1: 305 | expected_shape = (n_samples, batch_size, n_input) 306 | else: 307 | expected_shape = (batch_size, n_input) 308 | 309 | assert px_scale.shape == expected_shape 310 | assert px_r.shape == (n_input,) # One dispersion parameter per gene. 311 | assert px_rate.shape == expected_shape 312 | assert px_dropout.shape == expected_shape 313 | 314 | def test_generative( 315 | self, 316 | mock_contrastive_vi_module, 317 | mock_contrastive_batch, 318 | ): 319 | inference_outputs = mock_contrastive_vi_module.inference( 320 | **mock_contrastive_vi_module._get_inference_input(mock_contrastive_batch), 321 | ) 322 | generative_input = mock_contrastive_vi_module._get_generative_input( 323 | mock_contrastive_batch, inference_outputs 324 | ) 325 | generative_outputs = mock_contrastive_vi_module.generative(**generative_input) 326 | for data_source in required_data_sources: 327 | assert data_source in generative_outputs.keys() 328 | 329 | 330 | class TestContrastiveVIModuleLoss: 331 | def test_reconstruction_loss( 332 | self, mock_contrastive_vi_module, mock_contrastive_vi_data 333 | ): 334 | inference_input = mock_contrastive_vi_data["inference_input"]["background"] 335 | generative_outputs = mock_contrastive_vi_data["generative_outputs"][ 336 | "background" 337 | ] 338 | x = inference_input["x"] 339 | px_rate = generative_outputs["px_rate"] 340 | px_r = generative_outputs["px_r"] 341 | px_dropout = generative_outputs["px_dropout"] 342 | recon_loss = mock_contrastive_vi_module.reconstruction_loss( 343 | x, px_rate, px_r, px_dropout 344 | ) 345 | if len(px_rate.shape) == 3: 346 | expected_shape = px_rate.shape[:2] 347 | else: 348 | expected_shape = px_rate.shape[:1] 349 | assert recon_loss.shape == expected_shape 350 | 351 | def test_latent_kl_divergence( 352 | self, mock_contrastive_vi_module, mock_contrastive_vi_data 353 | ): 354 | inference_outputs = mock_contrastive_vi_data["inference_outputs"]["background"] 355 | qz_m = inference_outputs["qz_m"] 356 | qz_v = inference_outputs["qz_v"] 357 | kl_z = mock_contrastive_vi_module.latent_kl_divergence( 358 | variational_mean=qz_m, 359 | variational_var=qz_v, 360 | prior_mean=torch.zeros_like(qz_m), 361 | prior_var=torch.ones_like(qz_v), 362 | ) 363 | assert kl_z.shape == qz_m.shape[:-1] 364 | 365 | def test_library_kl_divergence( 366 | 
self, mock_contrastive_vi_module, mock_contrastive_vi_data
367 |     ):
368 |         inference_input = mock_contrastive_vi_data["inference_input"]["background"]
369 |         inference_outputs = mock_contrastive_vi_data["inference_outputs"]["background"]
370 |         batch_index = inference_input["batch_index"]
371 |         ql_m = inference_outputs["ql_m"]
372 |         ql_v = inference_outputs["ql_v"]
373 |         library = inference_outputs["library"]
374 |         kl_library = mock_contrastive_vi_module.library_kl_divergence(
375 |             batch_index, ql_m, ql_v, library
376 |         )
377 |         expected_shape = library.shape[:-1]
378 |         assert kl_library.shape == expected_shape
379 |         if mock_contrastive_vi_module.use_observed_lib_size:
380 |             assert torch.equal(kl_library, torch.zeros(expected_shape))
381 | 
382 |     def test_loss(self, mock_contrastive_vi_module, mock_contrastive_vi_data):
383 |         expected_shape = mock_contrastive_vi_data["inference_outputs"]["background"][
384 |             "qz_m"
385 |         ].shape[:-1]
386 |         losses = mock_contrastive_vi_module.loss(
387 |             mock_contrastive_vi_data["concat_tensors"],
388 |             mock_contrastive_vi_data["inference_outputs"],
389 |             mock_contrastive_vi_data["generative_outputs"],
390 |         )
391 |         loss = losses.loss
392 |         recon_loss = losses.reconstruction_loss
393 |         kl_local = losses.kl_local
394 |         kl_global = losses.kl_global
395 | 
396 |         assert loss.shape == tuple()
397 |         assert recon_loss.shape == expected_shape
398 |         assert kl_local.shape == expected_shape
399 |         assert kl_global.shape == tuple()
400 | 
401 |     @pytest.mark.parametrize("compute_loss", [True, False])
402 |     def test_forward(
403 |         self,
404 |         mock_contrastive_vi_module,
405 |         mock_contrastive_vi_data,
406 |         compute_loss,
407 |     ):
408 |         concat_tensors = mock_contrastive_vi_data["concat_tensors"]
409 |         if compute_loss:
410 |             inference_outputs, generative_outputs, losses = mock_contrastive_vi_module(
411 |                 concat_tensors, compute_loss=compute_loss
412 |             )
413 |             assert isinstance(losses, LossRecorder)
414 |         else:
415 |             inference_outputs, generative_outputs = mock_contrastive_vi_module(
416 |                 concat_tensors, compute_loss=compute_loss
417 |             )
418 |         for data_source in required_data_sources:
419 |             assert data_source in inference_outputs.keys()
420 |             assert data_source in generative_outputs.keys()
421 |             for key in required_inference_output_keys:
422 |                 assert key in inference_outputs[data_source].keys()
423 |             for key in required_generative_output_keys:
424 |                 assert key in generative_outputs[data_source].keys()
425 | 
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 | """Helper utilities for testing."""
2 | 
3 | from typing import Dict
4 | 
5 | import torch
6 | 
7 | 
8 | def get_next_batch(dataloader):
9 |     return next(iter(dataloader))
10 | 
11 | 
12 | def copy_module_state_dict(module) -> Dict[str, torch.Tensor]:
13 |     copy = {}
14 |     for name, param in module.state_dict().items():
15 |         copy[name] = param.detach().cpu().clone()
16 |     return copy
17 | 
--------------------------------------------------------------------------------
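For orientation, the snippet below sketches the end-to-end API exercised by the
tests above (model construction, contrastive training, and latent-space
queries). It is an illustration only, not code from the repository: the
synthetic AnnData, the setup_anndata registration call, and all hyperparameter
values are assumptions that should be checked against the actual
ContrastiveVIModel implementation.

# Hypothetical usage sketch -- mirrors the calls exercised in
# tests/model/test_contrastive_vi.py; not part of the package.
import anndata
import numpy as np

from contrastive_vi.model.contrastive_vi import ContrastiveVIModel

# Small synthetic dataset: 100 "background" (e.g. control) cells followed by
# 100 "target" (e.g. perturbed) cells, with Poisson pseudo-counts.
counts = np.random.poisson(1.0, size=(200, 50)).astype(np.float32)
adata = anndata.AnnData(counts)
adata.obs["condition"] = ["background"] * 100 + ["target"] * 100

# Assumed scvi-tools-style registration step; the real call may differ
# (e.g. it may take a layer argument such as layer="raw_counts").
ContrastiveVIModel.setup_anndata(adata)

background_indices = np.where(adata.obs["condition"] == "background")[0]
target_indices = np.where(adata.obs["condition"] == "target")[0]

model = ContrastiveVIModel(
    adata,
    n_background_latent=10,  # variation shared by both populations
    n_salient_latent=10,  # variation enriched in the target population
)
model.train(
    background_indices=background_indices,
    target_indices=target_indices,
    max_epochs=50,
    use_gpu=False,
)

# Salient latent space; background cells encode to zeros here, as checked in
# TestContrastiveVIModuleInference.test_inference.
salient = model.get_latent_representation(representation_kind="salient")
shared = model.get_latent_representation(representation_kind="background")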