├── .editorconfig ├── .flake8 ├── .github └── workflows │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── codecov.yml ├── data └── synthetic │ └── scripts │ ├── symSim.R │ └── symsim_env.yaml ├── docs ├── Makefile ├── _static │ └── css │ │ └── custom.css ├── _templates │ ├── autosummary │ │ └── class.rst │ └── class_no_inherited.rst ├── conf.py ├── extensions │ └── typed_returns.py ├── index.rst ├── installation.rst ├── make.bat ├── references.bib ├── references.rst └── release_notes │ ├── index.rst │ └── v0.1.0.rst ├── mrvi ├── __init__.py ├── _components.py ├── _constants.py ├── _model.py ├── _module.py └── _utils.py ├── pyproject.toml ├── readthedocs.yml ├── setup.py ├── tests ├── __init__.py └── test_mrvi.py └── workflow ├── Snakefile ├── config └── config.yaml ├── envs ├── process_data.yaml └── run_model.yaml ├── notebooks ├── semisynthetic.py ├── snrna.py └── synthetic.py └── scripts ├── compute_local_scores.py ├── process_data.py ├── run_model.py └── utils ├── __init__.py ├── _base_model.py ├── _baselines.py ├── _metrics.py ├── _milo.py └── _mrvi.py /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.bat] 14 | indent_style = tab 15 | end_of_line = crlf 16 | 17 | [LICENSE] 18 | insert_final_newline = false 19 | 20 | [Makefile] 21 | indent_style = tab 22 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E203, E266, E501, W503, W605, N812, F821 3 | exclude = .git,docs,workflow/.snakemake,workflow/notebooks 4 | max-line-length = 119 5 | 6 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: mrvi 5 | 6 | on: 7 | push: 8 | branches: [main] 9 | pull_request: 10 | branches: [main] 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-latest 15 | strategy: 16 | matrix: 17 | python-version: ["3.9", "3.10", "3.11"] 18 | 19 | steps: 20 | - uses: actions/checkout@v3 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v4 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Cache pip 26 | uses: actions/cache@v2 27 | with: 28 | path: ~/.cache/pip 29 | key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} 30 | restore-keys: | 31 | ${{ runner.os }}-pip- 32 | - name: Install dependencies 33 | run: | 34 | pip install pytest-cov 35 | pip install .[dev] 36 | - name: Lint with flake8 37 | run: | 38 | flake8 39 | - name: Format with black 40 | run: | 41 | black --check . 
42 | - name: Test with pytest 43 | run: | 44 | pytest --cov-report=xml --cov=mrvi 45 | - name: After success 46 | run: | 47 | bash <(curl -s https://codecov.io/bash) 48 | pip list 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # DS_Store 2 | .DS_Store 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | docs/api/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | # vscode 136 | .vscode/settings.json 137 | 138 | # snakemake 139 | workflow/.snakemake/ 140 | workflow/results/ 141 | workflow/data/ 142 | workflow/figures/ 143 | 144 | # mrvi data 145 | data/*/*.h5ad -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 24.1.1 4 | hooks: 5 | - id: black 6 | - repo: https://github.com/PyCQA/flake8 7 | rev: 7.0.0 8 | hooks: 9 | - id: flake8 10 | - repo: https://github.com/pycqa/isort 11 | rev: 5.13.2 12 | hooks: 13 | - id: isort 14 | name: isort (python) 15 | additional_dependencies: [toml] 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, Yosef Lab 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-resolution Variational Inference 2 | 3 | **DEPRECATED: please refer to the current [version](https://github.com/YosefLab/mrvi) of the project for up-to-date code** 4 | 5 | Multi-resolution Variational Inference (MrVI) is a package for analysis of sample-level heterogeneity in multi-site, multi-sample single-cell omics data. Built with [scvi-tools](https://scvi-tools.org). 6 | 7 | --- 8 | 9 | To install, run: 10 | 11 | ``` 12 | pip install mrvi 13 | ``` 14 | 15 | `mrvi.MrVI` follows the same API used in scvi-tools. 16 | 17 | ```python 18 | import mrvi 19 | import anndata 20 | 21 | adata = anndata.read_h5ad("path/to/adata.h5ad") 22 | # Sample (e.g. donors, perturbations, etc.) should go in sample_key 23 | # Sites, plates, and other factors should go in categorical_nuisance_keys 24 | mrvi.MrVI.setup_anndata(adata, sample_key="donor", categorical_nuisance_keys=["site"]) 25 | mrvi_model = mrvi.MrVI(adata) 26 | mrvi_model.train() 27 | # Get z representation 28 | adata.obsm["X_mrvi_z"] = mrvi_model.get_latent_representation(give_z=True) 29 | # Get u representation 30 | adata.obsm["X_mrvi_u"] = mrvi_model.get_latent_representation(give_z=False) 31 | # Cells by n_sample by n_latent 32 | cell_sample_representations = mrvi_model.get_local_sample_representation() 33 | # Cells by n_sample by n_sample 34 | cell_sample_sample_distances = mrvi_model.get_local_sample_representation(return_distances=True) 35 | ``` 36 | 37 | ## Citation 38 | 39 | ``` 40 | @article {Boyeau2022.10.04.510898, 41 | author = {Boyeau, Pierre and Hong, Justin and Gayoso, Adam and Jordan, Michael and Azizi, Elham and Yosef, Nir}, 42 | title = {Deep generative modeling for quantifying sample-level heterogeneity in single-cell omics}, 43 | elocation-id = {2022.10.04.510898}, 44 | year = {2022}, 45 | doi = {10.1101/2022.10.04.510898}, 46 | publisher = {Cold Spring Harbor Laboratory}, 47 | abstract = {Contemporary single-cell omics technologies have enabled complex experimental designs incorporating hundreds of samples accompanied by detailed information on sample-level conditions. Current approaches for analyzing condition-level heterogeneity in these experiments often rely on a simplification of the data such as an aggregation at the cell-type or cell-state-neighborhood level. Here we present MrVI, a deep generative model that provides sample-sample comparisons at a single-cell resolution, permitting the discovery of subtle sample-specific effects across cell populations. Additionally, the output of MrVI can be used to quantify the association between sample-level metadata and cell state variation. We benchmarked MrVI against conventional meta-analysis procedures on two synthetic datasets and one real dataset with a well-controlled experimental structure. This work introduces a novel approach to understanding sample-level heterogeneity while leveraging the full resolution of single-cell sequencing data.Competing Interest StatementN.Y.
is an advisor and/or has equity in Cellarity, Celsius Therapeutics, and Rheos Medicine.}, 48 | URL = {https://www.biorxiv.org/content/early/2022/10/06/2022.10.04.510898}, 49 | eprint = {https://www.biorxiv.org/content/early/2022/10/06/2022.10.04.510898.full.pdf}, 50 | journal = {bioRxiv} 51 | } 52 | ``` 53 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | # Run to check if valid 2 | # curl --data-binary @codecov.yml https://codecov.io/validate 3 | coverage: 4 | status: 5 | project: 6 | default: 7 | target: 80% 8 | threshold: 1% 9 | patch: off 10 | -------------------------------------------------------------------------------- /data/synthetic/scripts/symSim.R: -------------------------------------------------------------------------------- 1 | library(SymSim) # Uses the fork https://github.com/justjhong/SymSim to expose DiscreteEVF 2 | library(ape) 3 | library(anndata) 4 | 5 | genCorTree <- function(nTips, cor) { 6 | root <- stree(1, tip.label = "branching_point") 7 | tip <- stree(nTips, tip.label = sprintf("CT%i:1", 1:nTips)) 8 | tree <- bind.tree(root, tip, where = 1) 9 | tree$edge.length <- c(cor, rep(1 - cor, nTips)) 10 | return(tree) 11 | } 12 | 13 | genMetaEVFs <- function(n_cells_total, nCellTypes, nMetadata, metadataCor, n_nd_evf, nEVFsCellType, nEVFsPerMetadata, save_dir = NULL, randseed = 1) { 14 | set.seed(randseed) 15 | seed <- sample(c(1:1e5), size = nMetadata + 2) 16 | 17 | if (length(metadataCor == 1)) { 18 | metadataCor <- rep(metadataCor, nMetadata) 19 | } 20 | if (length(nEVFsPerMetadata == 1)) { 21 | nEVFsPerMetadata <- rep(nEVFsPerMetadata, nMetadata) 22 | } 23 | 24 | ct_tree <- genCorTree(nCellTypes, 0) 25 | base_evf_res <- DiscreteEVF(ct_tree, n_cells_total, n_cells_total / nCellTypes, 1, 0.4, n_nd_evf, nEVFsCellType, "all", 1, seed[[1]]) 26 | evf_mtxs <- base_evf_res[[1]] 27 | base_evf_mtx_ncols <- c() 28 | for (j in 1:3) { 29 | colnames(evf_mtxs[[j]]) <- paste(colnames(evf_mtxs[[j]]), "base", sep = "_") 30 | base_evf_mtx_ncols <- c(base_evf_mtx_ncols, ncol(evf_mtxs[[j]])) 31 | } 32 | 33 | ct_mapping <- ct_tree$tip.label 34 | names(ct_mapping) <- seq_len(length(ct_mapping)) 35 | meta <- data.frame("celltype" = ct_mapping[base_evf_res[[2]]$pop]) 36 | 37 | for (i in 1:nMetadata) { 38 | shuffled_row_idxs <- sample(1:n_cells_total) 39 | meta_tree <- genCorTree(2, metadataCor[[i]]) 40 | meta_evf_res <- DiscreteEVF(meta_tree, n_cells_total, n_cells_total / 2, 1, 1.0, 0, nEVFsPerMetadata[[i]], "all", 1, seed[[i + 2]]) 41 | meta_evf_mtxs <- meta_evf_res[[1]] 42 | for (j in 1:3) { 43 | colnames(meta_evf_mtxs[[j]]) <- paste(colnames(meta_evf_mtxs[[j]]), sprintf("meta_%d", i), sep = "_") 44 | evf_mtxs[[j]] <- cbind(evf_mtxs[[j]], meta_evf_mtxs[[j]][shuffled_row_idxs, ]) 45 | } 46 | 47 | meta_mapping <- meta_tree$tip.label 48 | names(meta_mapping) <- seq_len(length(meta_mapping)) 49 | meta[sprintf("meta_%d", i)] <- meta_mapping[meta_evf_res[[2]]$pop][shuffled_row_idxs] 50 | } 51 | 52 | 53 | # random meta_evfs for cell type 2 54 | nc_tree <- genCorTree(1, 0) 55 | nc_evf_res <- DiscreteEVF(nc_tree, n_cells_total / 2, n_cells_total / 2, 1, 1.0, 0, sum(nEVFsPerMetadata), "all", 1, seed[[2]]) 56 | nc_evf_mtxs <- nc_evf_res[[1]] 57 | ct2_idxs <- which(meta$celltype == "CT2:1") 58 | for (j in 1:3) { 59 | evf_mtxs[[j]][ct2_idxs, (base_evf_mtx_ncols[[j]] + 1):ncol(evf_mtxs[[j]])] = nc_evf_mtxs[[j]] 60 | } 61 | 62 | return(list(evf_mtxs, meta, nc_evf_mtxs)) 
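  # Return value, as constructed above: evf_mtxs is the list of three EVF matrices
  # used by SymSim (cell-type EVFs concatenated with the metadata EVFs), meta is a
  # data.frame with each cell's celltype and meta_<i> labels, and nc_evf_mtxs holds
  # the neutral EVFs that overwrite the metadata EVFs for CT2 cells, so the simulated
  # metadata only affects cells of type CT1.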
63 | } 64 | 65 | 66 | MetaSim <- function(nMetadata, metadataCor, nEVFsPerMetadata, nEVFsCellType, write = F, save_path = NULL, randseed = 1) { 67 | ncells <- 20000 68 | ngenes <- 2000 69 | data(gene_len_pool) 70 | gene_len <- sample(gene_len_pool, ngenes, replace = FALSE) 71 | 72 | evf_res <- genMetaEVFs(n_cells_total = ncells, nCellTypes = 2, nMetadata = nMetadata, metadataCor = metadataCor, n_nd_evf = 60, nEVFsCellType = nEVFsCellType, nEVFsPerMetadata = nEVFsPerMetadata, randseed = randseed) 73 | 74 | print("simulating true counts") 75 | true_counts_res <- SimulateTrueCountsFromEVF(evf_res, ngenes = ngenes, randseed = randseed) 76 | rm(evf_res) 77 | gc() 78 | 79 | print("simulating observed counts") 80 | observed_counts <- True2ObservedCounts(true_counts = true_counts_res[[1]], meta_cell = true_counts_res[[3]], protocol = "UMI", alpha_mean = 0.05, alpha_sd = 0.02, gene_len = gene_len, depth_mean = 5e4, depth_sd = 3e3) 81 | rm(true_counts_res) 82 | gc() 83 | 84 | print("simulating batch effects") 85 | observed_counts_2batches <- DivideBatches(observed_counts_res = observed_counts, nbatch = 2, batch_effect_size = 1) 86 | rm(observed_counts) 87 | gc() 88 | 89 | print("converting to anndata") 90 | meta_keys <- c("celltype", "batch") 91 | meta_keys <- c(meta_keys, paste("meta_", 1:nMetadata, sep = "")) 92 | ad_results <- AnnData(X = t(observed_counts_2batches$counts), obs = observed_counts_2batches$cell_meta[meta_keys]) 93 | 94 | if (write == T && !is.null(save_path)) { 95 | write_h5ad(ad_results, save_path) 96 | } 97 | 98 | return(list(ad_results, evf_res)) 99 | } 100 | 101 | meta_results <- MetaSim(3, metadataCor = c(0, 0.5, 0.9), nEVFsPerMetadata = 7, nEVFsCellType = 40, write = T, save_path = "3_meta_sim_20k.h5ad", randseed = 126) -------------------------------------------------------------------------------- /data/synthetic/scripts/symsim_env.yaml: -------------------------------------------------------------------------------- 1 | name: symsim 2 | channels: 3 | - r 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=conda_forge 8 | - _openmp_mutex=4.5=2_gnu 9 | - _r-mutex=1.0.0=anacondar_1 10 | - binutils_impl_linux-64=2.36.1=h193b22a_2 11 | - bwidget=1.9.14=ha770c72_1 12 | - bzip2=1.0.8=h7f98852_4 13 | - c-ares=1.18.1=h7f98852_0 14 | - ca-certificates=2022.9.14=ha878542_0 15 | - cairo=1.16.0=ha61ee94_1014 16 | - curl=7.83.1=h2283fc2_0 17 | - expat=2.4.8=h27087fc_0 18 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 19 | - font-ttf-inconsolata=3.000=h77eed37_0 20 | - font-ttf-source-code-pro=2.038=h77eed37_0 21 | - font-ttf-ubuntu=0.83=hab24e00_0 22 | - fontconfig=2.14.0=hc2a2eb6_1 23 | - fonts-conda-ecosystem=1=0 24 | - fonts-conda-forge=1=0 25 | - freetype=2.12.1=hca18f0e_0 26 | - fribidi=1.0.10=h36c2ea0_0 27 | - gcc_impl_linux-64=12.1.0=hea43390_16 28 | - gettext=0.19.8.1=h73d1719_1008 29 | - gfortran_impl_linux-64=12.1.0=h1db8e46_16 30 | - graphite2=1.3.13=h58526e2_1001 31 | - gsl=2.7=he838d99_0 32 | - gxx_impl_linux-64=12.1.0=hea43390_16 33 | - harfbuzz=5.2.0=hf9f4e7c_0 34 | - icu=70.1=h27087fc_0 35 | - jpeg=9e=h166bdaf_2 36 | - kernel-headers_linux-64=2.6.32=he073ed8_15 37 | - keyutils=1.6.1=h166bdaf_0 38 | - krb5=1.19.3=h08a2579_0 39 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 40 | - lerc=4.0.0=h27087fc_0 41 | - libblas=3.9.0=16_linux64_openblas 42 | - libcblas=3.9.0=16_linux64_openblas 43 | - libcurl=7.83.1=h2283fc2_0 44 | - libdeflate=1.14=h166bdaf_0 45 | - libedit=3.1.20191231=he28a2e2_2 46 | - libev=4.33=h516909a_1 47 | - libffi=3.4.2=h7f98852_5 48 | - 
libgcc-devel_linux-64=12.1.0=h1ec3361_16 49 | - libgcc-ng=12.1.0=h8d9b700_16 50 | - libgfortran-ng=12.1.0=h69a702a_16 51 | - libgfortran5=12.1.0=hdcd56e2_16 52 | - libglib=2.72.1=h2d90d5f_0 53 | - libgomp=12.1.0=h8d9b700_16 54 | - libiconv=1.16=h516909a_0 55 | - liblapack=3.9.0=16_linux64_openblas 56 | - libnghttp2=1.47.0=hff17c54_1 57 | - libopenblas=0.3.21=pthreads_h78a6416_3 58 | - libpng=1.6.38=h753d276_0 59 | - libsanitizer=12.1.0=ha89aaad_16 60 | - libssh2=1.10.0=hf14f497_3 61 | - libstdcxx-devel_linux-64=12.1.0=h1ec3361_16 62 | - libstdcxx-ng=12.1.0=ha89aaad_16 63 | - libtiff=4.4.0=h55922b4_4 64 | - libuuid=2.32.1=h7f98852_1000 65 | - libwebp-base=1.2.4=h166bdaf_0 66 | - libxcb=1.13=h7f98852_1004 67 | - libxml2=2.9.14=h22db469_4 68 | - libzlib=1.2.12=h166bdaf_3 69 | - make=4.3=hd18ef5c_1 70 | - ncurses=6.3=h27087fc_1 71 | - openssl=3.0.5=h166bdaf_2 72 | - pango=1.50.10=hc4f8a73_0 73 | - pcre=8.45=h9c3ff4c_0 74 | - pcre2=10.37=hc3806b6_1 75 | - pixman=0.40.0=h36c2ea0_0 76 | - pthread-stubs=0.4=h36c2ea0_1001 77 | - r-askpass=1.1=r42h76d94ec_0 78 | - r-base=4.2.1=ha8c3e7c_1 79 | - r-brew=1.0_7=r42h6115d3f_0 80 | - r-brio=1.1.3=r42h76d94ec_0 81 | - r-cachem=1.0.6=r42h76d94ec_0 82 | - r-callr=3.7.0=r42h6115d3f_0 83 | - r-cli=3.3.0=r42h884c59f_0 84 | - r-clipr=0.8.0=r42h6115d3f_0 85 | - r-commonmark=1.8.0=r42h76d94ec_0 86 | - r-cpp11=0.4.2=r42h6115d3f_0 87 | - r-crayon=1.5.1=r42h6115d3f_0 88 | - r-credentials=1.3.2=r42h142f84f_0 89 | - r-curl=4.3.2=r42h76d94ec_0 90 | - r-desc=1.4.1=r42h6115d3f_0 91 | - r-devtools=2.4.3=r42h6115d3f_0 92 | - r-diffobj=0.3.5=r42h76d94ec_0 93 | - r-digest=0.6.29=r42h884c59f_0 94 | - r-ellipsis=0.3.2=r42h76d94ec_0 95 | - r-evaluate=0.15=r42h6115d3f_0 96 | - r-fansi=1.0.3=r42h76d94ec_0 97 | - r-fastmap=1.1.0=r42h884c59f_0 98 | - r-fs=1.5.2=r42h884c59f_0 99 | - r-gert=1.6.0=r42h76d94ec_0 100 | - r-gh=1.3.0=r42h142f84f_0 101 | - r-gitcreds=0.1.1=r42h6115d3f_0 102 | - r-glue=1.6.2=r42h76d94ec_0 103 | - r-highr=0.9=r42h6115d3f_0 104 | - r-httr=1.4.3=r42h6115d3f_0 105 | - r-ini=0.3.1=r42h142f84f_0 106 | - r-jsonlite=1.8.0=r42h76d94ec_0 107 | - r-knitr=1.39=r42h6115d3f_0 108 | - r-lifecycle=1.0.1=r42h142f84f_0 109 | - r-magrittr=2.0.3=r42h76d94ec_0 110 | - r-memoise=2.0.1=r42h6115d3f_0 111 | - r-mime=0.12=r42h76d94ec_0 112 | - r-openssl=2.0.2=r42h76d94ec_0 113 | - r-pillar=1.7.0=r42h6115d3f_0 114 | - r-pkgbuild=1.3.1=r42h142f84f_0 115 | - r-pkgconfig=2.0.3=r42h6115d3f_0 116 | - r-pkgload=1.2.4=r42h142f84f_0 117 | - r-praise=1.0.0=r42h6115d3f_4 118 | - r-prettyunits=1.1.1=r42h142f84f_0 119 | - r-processx=3.5.3=r42h76d94ec_0 120 | - r-ps=1.7.0=r42h76d94ec_0 121 | - r-purrr=0.3.4=r42h76d94ec_0 122 | - r-r6=2.5.1=r42h6115d3f_0 123 | - r-rappdirs=0.3.3=r42h76d94ec_0 124 | - r-rcmdcheck=1.4.0=r42h142f84f_0 125 | - r-rematch2=2.1.2=r42h142f84f_0 126 | - r-remotes=2.4.2=r42h142f84f_0 127 | - r-rlang=1.0.2=r42h884c59f_0 128 | - r-roxygen2=7.2.0=r42h884c59f_0 129 | - r-rprojroot=2.0.3=r42h6115d3f_0 130 | - r-rstudioapi=0.13=r42h6115d3f_0 131 | - r-rversions=2.1.1=r42h6115d3f_0 132 | - r-sessioninfo=1.2.2=r42h142f84f_0 133 | - r-stringi=1.7.6=r42h884c59f_0 134 | - r-stringr=1.4.0=r42h6115d3f_0 135 | - r-sys=3.4=r42h76d94ec_0 136 | - r-testthat=3.1.4=r42h884c59f_0 137 | - r-tibble=3.1.7=r42h76d94ec_0 138 | - r-usethis=2.1.6=r42h142f84f_0 139 | - r-utf8=1.2.2=r42h76d94ec_0 140 | - r-vctrs=0.4.1=r42h884c59f_0 141 | - r-waldo=0.4.0=r42h6115d3f_0 142 | - r-whisker=0.4=r42h6115d3f_0 143 | - r-withr=2.5.0=r42h6115d3f_0 144 | - r-xfun=0.31=r42h76d94ec_0 145 | - 
r-xml2=1.3.3=r42h884c59f_0 146 | - r-xopen=1.0.0=r42h142f84f_0 147 | - r-yaml=2.3.5=r42h76d94ec_0 148 | - r-zip=2.2.0=r42h76d94ec_0 149 | - readline=8.1.2=h0f457ee_0 150 | - sed=4.8=he412f7d_0 151 | - sysroot_linux-64=2.12=he073ed8_15 152 | - tk=8.6.12=h27826a3_0 153 | - tktable=2.10=hb7b940f_3 154 | - xorg-kbproto=1.0.7=h7f98852_1002 155 | - xorg-libice=1.0.10=h7f98852_0 156 | - xorg-libsm=1.2.3=hd9c2040_1000 157 | - xorg-libx11=1.7.2=h7f98852_0 158 | - xorg-libxau=1.0.9=h7f98852_0 159 | - xorg-libxdmcp=1.1.3=h7f98852_0 160 | - xorg-libxext=1.3.4=h7f98852_1 161 | - xorg-libxrender=0.9.10=h7f98852_1003 162 | - xorg-libxt=1.2.1=h7f98852_2 163 | - xorg-renderproto=0.11.1=h7f98852_1002 164 | - xorg-xextproto=7.3.0=h7f98852_1002 165 | - xorg-xproto=7.0.31=h7f98852_1007 166 | - xz=5.2.6=h166bdaf_0 167 | - zlib=1.2.12=h166bdaf_3 168 | - zstd=1.5.2=h6239696_4 169 | prefix: /home/justin/miniconda3/envs/symsim 170 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = mrvi 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | /* influenced by and borrowed from: https://github.com/cvxgrp/pymde/blob/main/docs_src/source/_static/css/custom.css */ 2 | 3 | @import url('https://fonts.googleapis.com/css2?family=Roboto:ital,wght@0,300;0,400;0,600;1,300;1,400;1,600&display=swap'); 4 | 5 | :root { 6 | --sidebarcolor: #003262; 7 | --sidebarfontcolor: #ffffff; 8 | --sidebarhover: #295e97; 9 | 10 | --bodyfontcolor: #333; 11 | --webfont: 'Roboto'; 12 | 13 | --contentwidth: 1000px; 14 | } 15 | 16 | /* Fonts and text */ 17 | h1, h2, h3, h4, h5, h6 { 18 | font-family: var(--webfont), 'Helvetica Neue', Helvetica, Arial, sans-serif; 19 | font-weight: 400; 20 | } 21 | 22 | h2, h3, h4, h5, h6 { 23 | padding-top: 0.25em; 24 | margin-bottom: 0.5em; 25 | } 26 | 27 | h1 { 28 | font-size: 225%; 29 | } 30 | 31 | body { 32 | font-family: var(--webfont), 'Helvetica Neue', Helvetica, Arial, sans-serif; 33 | color: var(--bodyfontcolor); 34 | } 35 | 36 | p { 37 | font-size: 1em; 38 | line-height: 150%; 39 | } 40 | 41 | 42 | /* Sidebar */ 43 | .wy-side-nav-search { 44 | background-color: var(--sidebarcolor); 45 | } 46 | 47 | .wy-nav-side { 48 | background: var(--sidebarcolor); 49 | } 50 | 51 | .wy-menu-vertical header, .wy-menu-vertical p.caption { 52 | color: var(--sidebarfontcolor); 53 | } 54 | 55 | .wy-menu-vertical a { 56 | color: var(--sidebarfontcolor); 57 | } 58 | 59 | .wy-side-nav-search > div.version { 60 | color: var(--sidebarfontcolor); 61 | } 62 | 63 | .wy-menu-vertical a:hover { 64 | background-color: var(--sidebarhover); 65 | } 66 | 67 | /* Main content */ 68 | .wy-nav-content { 69 | max-width: 
var(--contentwidth); 70 | } 71 | 72 | 73 | html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) dl:not(.field-list) > dt{ 74 | margin-bottom: 6px; 75 | border-left: none; 76 | background: none; 77 | color: #555; 78 | } 79 | -------------------------------------------------------------------------------- /docs/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. add toctree option to make autodoc generate the pages 6 | 7 | .. autoclass:: {{ objname }} 8 | 9 | {% block attributes %} 10 | {% if attributes %} 11 | .. rubric:: Attributes 12 | 13 | .. autosummary:: 14 | :toctree: . 15 | {% for item in attributes %} 16 | ~{{ fullname }}.{{ item }} 17 | {%- endfor %} 18 | {% endif %} 19 | {% endblock %} 20 | 21 | {% block methods %} 22 | {% if methods %} 23 | .. rubric:: Methods 24 | 25 | .. autosummary:: 26 | :toctree: . 27 | {% for item in methods %} 28 | {%- if item != '__init__' %} 29 | ~{{ fullname }}.{{ item }} 30 | {%- endif -%} 31 | {%- endfor %} 32 | {% endif %} 33 | {% endblock %} 34 | -------------------------------------------------------------------------------- /docs/_templates/class_no_inherited.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. add toctree option to make autodoc generate the pages 6 | 7 | .. autoclass:: {{ objname }} 8 | 9 | {% block methods %} 10 | {% if methods %} 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | :toctree: . 15 | {% for item in methods %} 16 | {%- if item != '__init__' and item not in inherited_members%} 17 | ~{{ fullname }}.{{ item }} 18 | {%- endif -%} 19 | 20 | {%- endfor %} 21 | {% endif %} 22 | {% endblock %} 23 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | import sys 10 | from pathlib import Path 11 | 12 | HERE = Path(__file__).parent 13 | sys.path[:0] = [str(HERE.parent), str(HERE / "extensions")] 14 | 15 | import mrvi # noqa 16 | 17 | 18 | # -- General configuration --------------------------------------------- 19 | 20 | # If your documentation needs a minimal Sphinx version, state it here. 21 | # 22 | needs_sphinx = "3.0" # Nicer param docs 23 | 24 | # Add any Sphinx extension module names here, as strings. They can be 25 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 
26 | extensions = [ 27 | "sphinx.ext.autodoc", 28 | "sphinx.ext.viewcode", 29 | "nbsphinx", 30 | "nbsphinx_link", 31 | "sphinx.ext.mathjax", 32 | "sphinx.ext.napoleon", 33 | "sphinx_autodoc_typehints", # needs to be after napoleon 34 | "sphinx.ext.intersphinx", 35 | "sphinx.ext.autosummary", 36 | "scanpydoc.elegant_typehints", 37 | "scanpydoc.definition_list_typed_field", 38 | "scanpydoc.autosummary_generate_imported", 39 | *[p.stem for p in (HERE / "extensions").glob("*.py")], 40 | ] 41 | 42 | # nbsphinx specific settings 43 | exclude_patterns = ["_build", "**.ipynb_checkpoints"] 44 | nbsphinx_execute = "never" 45 | 46 | # Add any paths that contain templates here, relative to this directory. 47 | templates_path = ["_templates"] 48 | 49 | # The suffix(es) of source filenames. 50 | # You can specify multiple suffix as a list of string: 51 | # 52 | # source_suffix = ['.rst', '.md'] 53 | source_suffix = ".rst" 54 | 55 | # Generate the API documentation when building 56 | autosummary_generate = True 57 | autodoc_member_order = "bysource" 58 | napoleon_google_docstring = False 59 | napoleon_numpy_docstring = True 60 | napoleon_include_init_with_doc = False 61 | napoleon_use_rtype = True 62 | napoleon_use_param = True 63 | napoleon_custom_sections = [("Params", "Parameters")] 64 | todo_include_todos = False 65 | numpydoc_show_class_members = False 66 | annotate_defaults = True 67 | # The master toctree document. 68 | master_doc = "index" 69 | 70 | 71 | intersphinx_mapping = dict( 72 | anndata=("https://anndata.readthedocs.io/en/stable/", None), 73 | ipython=("https://ipython.readthedocs.io/en/stable/", None), 74 | matplotlib=("https://matplotlib.org/", None), 75 | numpy=("https://docs.scipy.org/doc/numpy/", None), 76 | pandas=("https://pandas.pydata.org/pandas-docs/stable/", None), 77 | python=("https://docs.python.org/3", None), 78 | scipy=("https://docs.scipy.org/doc/scipy/reference/", None), 79 | sklearn=("https://scikit-learn.org/stable/", None), 80 | torch=("https://pytorch.org/docs/master/", None), 81 | scanpy=("https://scanpy.readthedocs.io/en/stable/", None), 82 | pytorch_lightning=("https://pytorch-lightning.readthedocs.io/en/stable/", None), 83 | ) 84 | 85 | 86 | # General information about the project. 87 | project = "mrvi" 88 | copyright = "2022, Yosef Lab, UC Berkeley" 89 | author = "Pierre Boyeau, Justin Hong, Adam Gayoso" 90 | 91 | # The version info for the project you're documenting, acts as replacement 92 | # for |version| and |release|, also used in various other places throughout 93 | # the built documents. 94 | # 95 | # The short X.Y version. 96 | version = mrvi.__version__ 97 | # The full version, including alpha/beta/rc tags. 98 | release = mrvi.__version__ 99 | 100 | # The language for content autogenerated by Sphinx. Refer to documentation 101 | # for a list of supported languages. 102 | # 103 | # This is also used if you do content translation via gettext catalogs. 104 | # Usually you set "language" from the command line for these cases. 105 | language = None 106 | 107 | # List of patterns, relative to source directory, that match files and 108 | # directories to ignore when looking for source files. 109 | # This patterns also effect to html_static_path and html_extra_path 110 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 111 | 112 | # The name of the Pygments (syntax highlighting) style to use. 113 | pygments_style = "tango" 114 | 115 | # If true, `todo` and `todoList` produce output, else they produce nothing. 
116 | todo_include_todos = False 117 | 118 | # -- Options for HTML output ------------------------------------------------- 119 | 120 | # The theme to use for HTML and HTML Help pages. See the documentation for 121 | # a list of builtin themes. 122 | # 123 | html_theme = "sphinx_rtd_theme" 124 | 125 | html_show_sourcelink = False 126 | 127 | html_show_copyright = False 128 | 129 | display_version = True 130 | 131 | # Add any paths that contain custom static files (such as style sheets) here, 132 | # relative to this directory. They are copied after the builtin static files, 133 | # so a file named "default.css" will overwrite the builtin "default.css". 134 | html_static_path = ["_static"] 135 | 136 | html_css_files = [ 137 | "css/custom.css", 138 | ] 139 | 140 | html_favicon = "favicon.ico" 141 | -------------------------------------------------------------------------------- /docs/extensions/typed_returns.py: -------------------------------------------------------------------------------- 1 | # code from https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py 2 | # with some minor adjustment 3 | import re 4 | 5 | from sphinx.application import Sphinx 6 | from sphinx.ext.napoleon import NumpyDocstring 7 | 8 | 9 | def process_return(lines): 10 | for line in lines: 11 | m = re.fullmatch(r"(?P\w+)\s+:\s+(?P[\w.]+)", line) 12 | if m: 13 | # Once this is in scanpydoc, we can use the fancy hover stuff 14 | yield f'-{m["param"]} (:class:`~{m["type"]}`)' 15 | else: 16 | yield line 17 | 18 | 19 | def scanpy_parse_returns_section(self, section): 20 | lines_raw = list(process_return(self._dedent(self._consume_to_next_section()))) 21 | lines = self._format_block(":returns: ", lines_raw) 22 | if lines and lines[-1]: 23 | lines.append("") 24 | return lines 25 | 26 | 27 | def setup(app: Sphinx): 28 | NumpyDocstring._parse_returns_section = scanpy_parse_returns_section 29 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | ======================== 2 | mrvi documentation 3 | ======================== 4 | 5 | Welcome! This is the documentation website for the `mrvi 6 | `_ package. 7 | 8 | .. toctree:: 9 | :maxdepth: 1 10 | :hidden: 11 | 12 | installation 13 | api/index 14 | release_notes/index 15 | references 16 | -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | Prerequisites 5 | ~~~~~~~~~~~~~~ 6 | 7 | `mrvi` can be installed via PyPI. 8 | 9 | conda prerequisites 10 | ################### 11 | 12 | 1. Install Conda. We typically use the Miniconda_ Python distribution. Use Python version >=3.7. 13 | 14 | 2. Create a new conda environment:: 15 | 16 | conda create -n mrvi-env python=3.7 17 | 18 | 3. Activate your environment:: 19 | 20 | source activate mrvi-env 21 | 22 | pip prerequisites: 23 | ################## 24 | 25 | 1. Install Python_, we prefer the `pyenv `_ version management system, along with `pyenv-virtualenv `_. 26 | 27 | 2. Install PyTorch_. If you have an Nvidia GPU, be sure to install a version of PyTorch that supports it -- scvi-tools runs much faster with a discrete GPU. 28 | 29 | .. _Miniconda: https://conda.io/miniconda.html 30 | .. _Python: https://www.python.org/downloads/ 31 | .. 
_PyTorch: http://pytorch.org 32 | 33 | mrvi installation 34 | ~~~~~~~~~~~~~~~~~~~~~~~ 35 | 36 | Install mrvi in one of the following ways: 37 | 38 | Through **pip**:: 39 | 40 | pip install 41 | 42 | Through pip with packages to run notebooks. This installs scanpy, etc.:: 43 | 44 | pip install [tutorials] 45 | 46 | Nightly version - clone this repo and run:: 47 | 48 | pip install . 49 | 50 | For development - clone this repo and run:: 51 | 52 | pip install -e .[dev,docs] 53 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=mrvi 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 20 | echo.then set the SPHINXBUILD environment variable to point to the full 21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 22 | echo.Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/references.bib: -------------------------------------------------------------------------------- 1 | @article{Boyeau2022mrvi, 2 | abstract = {Contemporary single-cell omics technologies have enabled complex experimental designs incorporating hundreds of samples accompanied by detailed information on sample-level conditions. Current approaches for analyzing condition-level heterogeneity in these experiments often rely on a simplification of the data such as an aggregation at the cell-type or cell-state-neighborhood level. Here we present MrVI, a deep generative model that provides sample-sample comparisons at a single-cell resolution, permitting the discovery of subtle sample-specific effects across cell populations. Additionally, the output of MrVI can be used to quantify the association between sample-level metadata and cell state variation. We benchmarked MrVI against conventional meta-analysis procedures on two synthetic datasets and one real dataset with a well-controlled experimental structure. This work introduces a novel approach to understanding sample-level heterogeneity while leveraging the full resolution of single-cell sequencing data.Competing Interest StatementN.Y. is an advisor and/or has equity in Cellarity, Celsius Therapeutics, and Rheos Medicine.}, 3 | author = {Boyeau, Pierre and Hong, Justin and Gayoso, Adam and Jordan, Michael I. 
and Azizi, Elham and Yosef, Nir}, 4 | doi = {10.1101/2022.10.04.510898}, 5 | elocation-id = {2022.10.04.510898}, 6 | eprint = {https://www.biorxiv.org/content/early/2022/10/06/2022.10.04.510898.full.pdf}, 7 | journal = {bioRxiv}, 8 | publisher = {Cold Spring Harbor Laboratory}, 9 | title = {Deep generative modeling for quantifying sample-level heterogeneity in single-cell omics}, 10 | url = {https://www.biorxiv.org/content/early/2022/10/06/2022.10.04.510898}, 11 | year = {2022}, 12 | bdsk-url-1 = {https://www.biorxiv.org/content/early/2022/10/06/2022.10.04.510898}, 13 | bdsk-url-2 = {https://doi.org/10.1101/2022.10.04.510898}} 14 | -------------------------------------------------------------------------------- /docs/references.rst: -------------------------------------------------------------------------------- 1 | References 2 | ========== 3 | **Deep generative modeling for quantifying sample-level heterogeneity in single-cell omics** 4 | Pierre Boyeau*, Justin Hong*, Adam Gayoso, Michael I. Jordan, Elham Azizi, Nir Yosef 5 | bioRxiv 2022. `Link `_. 6 | 7 | -------------------------------------------------------------------------------- /docs/release_notes/index.rst: -------------------------------------------------------------------------------- 1 | Release notes 2 | ============= 3 | 4 | This is the list of changes to ``mrvi`` between each release. Full commit history 5 | is available in the `commit logs `_. 6 | 7 | Version 0.1 8 | ----------- 9 | .. toctree:: 10 | :maxdepth: 2 11 | 12 | v0.1.0 13 | -------------------------------------------------------------------------------- /docs/release_notes/v0.1.0.rst: -------------------------------------------------------------------------------- 1 | New in 0.1.0 (2020-10-04) 2 | ------------------------- 3 | Initial release of ``mrvi``. Complies with ``scvi-tools>=0.9.0a0``. 
4 | -------------------------------------------------------------------------------- /mrvi/__init__.py: -------------------------------------------------------------------------------- 1 | """scvi-tools-skeleton.""" 2 | 3 | import logging 4 | 5 | from rich.console import Console 6 | from rich.logging import RichHandler 7 | 8 | from ._model import MrVI, MrVAE 9 | 10 | # https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094 11 | # https://github.com/python-poetry/poetry/issues/144#issuecomment-623927302 12 | try: 13 | import importlib.metadata as importlib_metadata 14 | except ModuleNotFoundError: 15 | import importlib_metadata 16 | 17 | package_name = "mrvi" 18 | __version__ = importlib_metadata.version(package_name) 19 | 20 | logger = logging.getLogger(__name__) 21 | # set the logging level 22 | logger.setLevel(logging.INFO) 23 | 24 | # nice logging outputs 25 | console = Console(force_terminal=True) 26 | if console.is_jupyter is True: 27 | console.is_jupyter = False 28 | ch = RichHandler(show_path=False, console=console, show_time=False) 29 | formatter = logging.Formatter("mypackage: %(message)s") 30 | ch.setFormatter(formatter) 31 | logger.addHandler(ch) 32 | 33 | # this prevents double outputs 34 | logger.propagate = False 35 | 36 | __all__ = ["MrVI", "MrVAE"] 37 | -------------------------------------------------------------------------------- /mrvi/_components.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from scvi.distributions import NegativeBinomial 4 | from scvi.nn import one_hot 5 | 6 | from ._utils import ResnetFC 7 | 8 | 9 | class ExpActivation(nn.Module): 10 | def __init__(self) -> None: 11 | super().__init__() 12 | 13 | def forward(self, x: torch.Tensor) -> torch.Tensor: 14 | return torch.exp(x) 15 | 16 | 17 | class DecoderZX(nn.Module): 18 | """Parameterizes the counts likelihood for the data given the latent variables.""" 19 | 20 | def __init__( 21 | self, 22 | n_in, 23 | n_out, 24 | n_nuisance, 25 | linear_decoder, 26 | n_hidden=128, 27 | activation="softmax", 28 | ): 29 | super().__init__() 30 | if activation == "softmax": 31 | activation_ = nn.Softmax(-1) 32 | elif activation == "softplus": 33 | activation_ = nn.Softplus() 34 | elif activation == "exp": 35 | activation_ = ExpActivation() 36 | elif activation == "sigmoid": 37 | activation_ = nn.Sigmoid() 38 | else: 39 | raise ValueError("activation must be one of 'softmax' or 'softplus'") 40 | self.linear_decoder = linear_decoder 41 | self.n_nuisance = n_nuisance 42 | self.n_latent = n_in - n_nuisance 43 | if linear_decoder: 44 | self.amat = nn.Linear(self.n_latent, n_out, bias=False) 45 | self.amat_site = nn.Parameter( 46 | torch.randn(self.n_nuisance, self.n_latent, n_out) 47 | ) 48 | self.offsets = nn.Parameter(torch.randn(self.n_nuisance, n_out)) 49 | self.dropout_ = nn.Dropout(0.1) 50 | self.activation_ = activation_ 51 | 52 | else: 53 | self.px_mean = ResnetFC( 54 | n_in=n_in, 55 | n_out=n_out, 56 | n_hidden=n_hidden, 57 | activation=activation_, 58 | ) 59 | self.px_r = nn.Parameter(torch.randn(n_out)) 60 | 61 | def forward(self, z, size_factor): 62 | if self.linear_decoder: 63 | nuisance_oh = z[..., -self.n_nuisance :] 64 | z0 = z[..., : -self.n_nuisance] 65 | x1 = self.amat(z0) 66 | 67 | nuisance_ids = torch.argmax(nuisance_oh, -1) 68 | As = self.amat_site[nuisance_ids] 69 | z0_detach = self.dropout_(z0.detach())[..., None] 70 | x2 = (As * z0_detach).sum(-2) 71 | offsets = self.offsets[nuisance_ids] 72 | mu = 
x1 + x2 + offsets 73 | mu = self.activation_(mu) 74 | else: 75 | mu = self.px_mean(z) 76 | mu = mu * size_factor 77 | return NegativeBinomial(mu=mu, theta=self.px_r.exp()) 78 | 79 | 80 | class LinearDecoderUZ(nn.Module): 81 | def __init__( 82 | self, 83 | n_latent, 84 | n_sample, 85 | n_out, 86 | scaler=False, 87 | scaler_n_hidden=32, 88 | ): 89 | super().__init__() 90 | self.n_latent = n_latent 91 | self.n_sample = n_sample 92 | self.n_out = n_out 93 | 94 | self.amat_sample = nn.Parameter(torch.randn(n_sample, self.n_latent, n_out)) 95 | self.offsets = nn.Parameter(torch.randn(n_sample, n_out)) 96 | 97 | self.scaler = None 98 | if scaler: 99 | self.scaler = nn.Sequential( 100 | nn.Linear(n_latent + n_sample, scaler_n_hidden), 101 | nn.LayerNorm(scaler_n_hidden), 102 | nn.ReLU(), 103 | nn.Linear(scaler_n_hidden, 1), 104 | nn.Sigmoid(), 105 | ) 106 | 107 | def forward(self, u, sample_id): 108 | sample_id_ = sample_id.long().squeeze() 109 | As = self.amat_sample[sample_id_] 110 | 111 | u_detach = u.detach()[..., None] 112 | z2 = (As * u_detach).sum(-2) 113 | offsets = self.offsets[sample_id_] 114 | delta = z2 + offsets 115 | if self.scaler is not None: 116 | sample_oh = one_hot(sample_id, self.n_sample) 117 | if u.ndim != sample_oh.ndim: 118 | sample_oh = sample_oh[None].expand(u.shape[0], *sample_oh.shape) 119 | inputs = torch.cat([u.detach(), sample_oh], -1) 120 | delta = delta * self.scaler(inputs) 121 | return u + delta 122 | 123 | 124 | class DecoderUZ(nn.Module): 125 | def __init__( 126 | self, 127 | n_latent, 128 | n_latent_sample, 129 | n_out, 130 | dropout_rate=0.0, 131 | n_layers=1, 132 | n_hidden=128, 133 | ): 134 | super().__init__() 135 | self.n_latent = n_latent 136 | self.n_latent_sample = n_latent_sample 137 | self.n_in = n_latent + n_latent_sample 138 | self.n_out = n_out 139 | 140 | arch_mod = self.construct_arch(self.n_in, n_hidden, n_layers, dropout_rate) + [ 141 | nn.Linear(n_hidden, self.n_out, bias=False) 142 | ] 143 | self.mod = nn.Sequential(*arch_mod) 144 | 145 | arch_scaler = self.construct_arch( 146 | self.n_latent, n_hidden, n_layers, dropout_rate 147 | ) + [nn.Linear(n_hidden, 1)] 148 | self.scaler = nn.Sequential(*arch_scaler) 149 | self.scaler.append(nn.Sigmoid()) 150 | 151 | @staticmethod 152 | def construct_arch(n_inputs, n_hidden, n_layers, dropout_rate): 153 | """Initializes MLP architecture""" 154 | 155 | block_inputs = [ 156 | nn.Linear(n_inputs, n_hidden), 157 | nn.BatchNorm1d(n_hidden), 158 | nn.Dropout(p=dropout_rate), 159 | nn.ReLU(), 160 | ] 161 | 162 | block_inner = n_layers * [ 163 | nn.Linear(n_hidden, n_hidden), 164 | nn.BatchNorm1d(n_hidden), 165 | nn.ReLU(), 166 | ] 167 | return block_inputs + block_inner 168 | 169 | def forward(self, u): 170 | u_ = u.clone() 171 | if u_.dim() == 3: 172 | n_samples, n_cells, n_features = u_.shape 173 | u0_ = u_[:, :, : self.n_latent].reshape(-1, self.n_latent) 174 | u_ = u_.reshape(-1, n_features) 175 | pred_ = self.mod(u_).reshape(n_samples, n_cells, -1) 176 | scaler_ = self.scaler(u0_).reshape(n_samples, n_cells, -1) 177 | else: 178 | pred_ = self.mod(u) 179 | scaler_ = self.scaler(u[:, : self.n_latent]) 180 | mean = u[..., : self.n_latent] + scaler_ * pred_ 181 | return mean 182 | -------------------------------------------------------------------------------- /mrvi/_constants.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | 4 | class _MRVI_REGISTRY_KEYS_NT(NamedTuple): 5 | SAMPLE_KEY: str = "sample" 6 | CATEGORICAL_NUISANCE_KEYS: 
str = "categorical_nuisance_keys" 7 | 8 | 9 | MRVI_REGISTRY_KEYS = _MRVI_REGISTRY_KEYS_NT() 10 | -------------------------------------------------------------------------------- /mrvi/_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | from copy import deepcopy 5 | from typing import List, Optional 6 | 7 | import numpy as np 8 | import torch 9 | from anndata import AnnData 10 | from scvi import REGISTRY_KEYS 11 | from scvi.data import AnnDataManager 12 | from scvi.data.fields import CategoricalJointObsField, CategoricalObsField, LayerField 13 | from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin, VAEMixin 14 | from sklearn.metrics import pairwise_distances 15 | from tqdm import tqdm 16 | 17 | from ._constants import MRVI_REGISTRY_KEYS 18 | from ._module import MrVAE 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | DEFAULT_TRAIN_KWARGS = dict( 23 | early_stopping=True, 24 | early_stopping_patience=15, 25 | check_val_every_n_epoch=1, 26 | batch_size=256, 27 | train_size=0.9, 28 | plan_kwargs=dict( 29 | lr=1e-2, 30 | n_epochs_kl_warmup=20, 31 | ), 32 | ) 33 | 34 | 35 | class MrVI(UnsupervisedTrainingMixin, VAEMixin, BaseModelClass): 36 | """ 37 | Multi-resolution Variational Inference (MrVI) :cite:`Boyeau2022mrvi`. 38 | 39 | Parameters 40 | ---------- 41 | adata 42 | AnnData object that has been registered via 43 | :meth:`~scvi.model.MrVI.setup_anndata`. 44 | n_latent 45 | Dimensionality of the latent space. 46 | n_latent_donor 47 | Dimensionality of the latent space for sample embeddings. 48 | linear_decoder_zx 49 | Whether to use a linear decoder for the decoder from z to x. 50 | linear_decoder_uz 51 | Whether to use a linear decoder for the decoder from u to z. 52 | linear_decoder_uz_scaler 53 | Whether to incorporate a learned scaler term in the decoder from u to z. 54 | linear_decoder_uz_scaler_n_hidden 55 | If `linear_decoder_uz_scaler` is True, the number of hidden 56 | units in the neural network used to produce the scaler term 57 | in decoder from u to z. 58 | px_kwargs 59 | Keyword args for :class:`~mrvi.components.DecoderZX`. 60 | pz_kwargs 61 | Keyword args for :class:`~mrvi.components.DecoderUZ`. 
62 | """ 63 | 64 | def __init__( 65 | self, 66 | adata: AnnData, 67 | **model_kwargs, 68 | ): 69 | super().__init__(adata) 70 | n_cats_per_nuisance_keys = ( 71 | self.adata_manager.get_state_registry( 72 | MRVI_REGISTRY_KEYS.CATEGORICAL_NUISANCE_KEYS 73 | ).n_cats_per_key 74 | if MRVI_REGISTRY_KEYS.CATEGORICAL_NUISANCE_KEYS 75 | in self.adata_manager.data_registry 76 | else [] 77 | ) 78 | 79 | n_sample = self.summary_stats.n_sample 80 | n_obs_per_sample = ( 81 | adata.obs.groupby( 82 | self.adata_manager.get_state_registry(MRVI_REGISTRY_KEYS.SAMPLE_KEY)[ 83 | "original_key" 84 | ] 85 | ) 86 | .size() 87 | .loc[ 88 | self.adata_manager.get_state_registry(MRVI_REGISTRY_KEYS.SAMPLE_KEY)[ 89 | "categorical_mapping" 90 | ] 91 | ] 92 | .values 93 | ) 94 | n_obs_per_sample = torch.from_numpy(n_obs_per_sample).float() 95 | self.data_splitter = None 96 | self.module = MrVAE( 97 | n_input=self.summary_stats.n_vars, 98 | n_sample=n_sample, 99 | n_obs_per_sample=n_obs_per_sample, 100 | n_cats_per_nuisance_keys=n_cats_per_nuisance_keys, 101 | **model_kwargs, 102 | ) 103 | self.init_params_ = self._get_init_params(locals()) 104 | 105 | @classmethod 106 | def setup_anndata( 107 | cls, 108 | adata: AnnData, 109 | layer: Optional[str] = None, 110 | sample_key: Optional[str] = None, 111 | categorical_nuisance_keys: Optional[List[str]] = None, 112 | **kwargs, 113 | ): 114 | setup_method_args = cls._get_setup_method_args(**locals()) 115 | anndata_fields = [ 116 | LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), 117 | CategoricalObsField(MRVI_REGISTRY_KEYS.SAMPLE_KEY, sample_key), 118 | CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, None), 119 | CategoricalJointObsField( 120 | MRVI_REGISTRY_KEYS.CATEGORICAL_NUISANCE_KEYS, categorical_nuisance_keys 121 | ), 122 | ] 123 | adata_manager = AnnDataManager( 124 | fields=anndata_fields, setup_method_args=setup_method_args 125 | ) 126 | adata_manager.register_fields(adata, **kwargs) 127 | cls.register_manager(adata_manager) 128 | 129 | def train( 130 | self, 131 | max_epochs: int | None = None, 132 | accelerator: str = "auto", 133 | devices: int | list[int] | str = "auto", 134 | train_size: float = 0.9, 135 | validation_size: float | None = None, 136 | batch_size: int = 128, 137 | early_stopping: bool = False, 138 | plan_kwargs: dict | None = None, 139 | **trainer_kwargs, 140 | ): 141 | train_kwargs = dict( 142 | max_epochs=max_epochs, 143 | accelerator=accelerator, 144 | devices=devices, 145 | train_size=train_size, 146 | validation_size=validation_size, 147 | batch_size=batch_size, 148 | early_stopping=early_stopping, 149 | **trainer_kwargs, 150 | ) 151 | train_kwargs = dict(deepcopy(DEFAULT_TRAIN_KWARGS), **train_kwargs) 152 | plan_kwargs = plan_kwargs or {} 153 | train_kwargs["plan_kwargs"] = dict( 154 | deepcopy(DEFAULT_TRAIN_KWARGS["plan_kwargs"]), **plan_kwargs 155 | ) 156 | super().train(**train_kwargs) 157 | 158 | @torch.no_grad() 159 | def get_latent_representation( 160 | self, 161 | adata: AnnData | None = None, 162 | indices: list[int] | None = None, 163 | mc_samples: int = 5000, 164 | batch_size: int | None = None, 165 | give_z: bool = False, 166 | ) -> np.ndarray: 167 | self._check_if_trained(warn=False) 168 | adata = self._validate_anndata(adata) 169 | scdl = self._make_data_loader( 170 | adata=adata, indices=indices, batch_size=batch_size 171 | ) 172 | 173 | u = [] 174 | z = [] 175 | for tensors in tqdm(scdl): 176 | inference_inputs = self.module._get_inference_input(tensors) 177 | outputs = self.module.inference(mc_samples=mc_samples, 
**inference_inputs) 178 | u.append(outputs["u"].mean(0).cpu()) 179 | z.append(outputs["z"].mean(0).cpu()) 180 | 181 | u = torch.cat(u, 0).numpy() 182 | z = torch.cat(z, 0).numpy() 183 | return z if give_z else u 184 | 185 | @staticmethod 186 | def compute_distance_matrix_from_representations( 187 | representations: np.ndarray, metric: str = "euclidean" 188 | ) -> np.ndarray: 189 | """ 190 | Compute distance matrices from counterfactual sample representations. 191 | 192 | Parameters 193 | ---------- 194 | representations 195 | Counterfactual sample representations of shape 196 | (n_cells, n_sample, n_features). 197 | metric 198 | Metric to use for computing distance matrix. 199 | """ 200 | n_cells, n_donors, _ = representations.shape 201 | pairwise_dists = np.zeros((n_cells, n_donors, n_donors)) 202 | for i, cell_rep in enumerate(representations): 203 | d_ = pairwise_distances(cell_rep, metric=metric) 204 | pairwise_dists[i, :, :] = d_ 205 | return pairwise_dists 206 | 207 | @torch.no_grad() 208 | def get_local_sample_representation( 209 | self, 210 | adata: AnnData | None = None, 211 | batch_size: int = 256, 212 | mc_samples: int = 10, 213 | return_distances: bool = False, 214 | ): 215 | """ 216 | Computes the local sample representation of the cells in the adata object. 217 | 218 | For each cell, it returns a matrix of size (n_sample, n_features). 219 | 220 | Parameters 221 | ---------- 222 | adata 223 | AnnData object to use for computing the local sample representation. 224 | batch_size 225 | Batch size to use for computing the local sample representation. 226 | mc_samples 227 | Number of Monte Carlo samples to use for computing the local sample 228 | representation. 229 | return_distances 230 | If ``return_distances`` is ``True``, returns a distance matrix of 231 | size (n_sample, n_sample) for each cell. 
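        Examples
        --------
        A short sketch of typical calls, assuming ``model`` is a trained
        :class:`~mrvi.MrVI` instance; shapes follow the descriptions above.

        >>> reps = model.get_local_sample_representation()
        >>> reps.shape  # (n_cells, n_sample, n_latent)
        >>> dists = model.get_local_sample_representation(return_distances=True)
        >>> dists.shape  # (n_cells, n_sample, n_sample)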
232 | """ 233 | adata = self.adata if adata is None else adata 234 | self._check_if_trained(warn=False) 235 | adata = self._validate_anndata(adata) 236 | scdl = self._make_data_loader(adata=adata, indices=None, batch_size=batch_size) 237 | 238 | reps = [] 239 | for tensors in tqdm(scdl): 240 | xs = [] 241 | for sample in range(self.summary_stats.n_sample): 242 | cf_sample = sample * torch.ones_like( 243 | tensors[MRVI_REGISTRY_KEYS.SAMPLE_KEY] 244 | ) 245 | inference_inputs = self.module._get_inference_input(tensors) 246 | inference_outputs = self.module.inference( 247 | mc_samples=mc_samples, cf_sample=cf_sample, **inference_inputs 248 | ) 249 | new = inference_outputs["z"] 250 | 251 | xs.append(new[:, :, None]) 252 | 253 | xs = torch.cat(xs, 2).mean(0) 254 | reps.append(xs.cpu().numpy()) 255 | # n_cells, n_sample, n_latent 256 | reps = np.concatenate(reps, 0) 257 | 258 | if return_distances: 259 | return self.compute_distance_matrix_from_representations(reps) 260 | 261 | return reps 262 | -------------------------------------------------------------------------------- /mrvi/_module.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | import torch.distributions as db 5 | import torch.nn as nn 6 | from scvi import REGISTRY_KEYS 7 | from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data 8 | from scvi.nn import one_hot 9 | from torch.distributions import kl_divergence as kl 10 | 11 | from ._components import DecoderUZ, DecoderZX, LinearDecoderUZ 12 | from ._constants import MRVI_REGISTRY_KEYS 13 | from ._utils import ConditionalBatchNorm1d, NormalNN 14 | 15 | DEFAULT_PX_HIDDEN = 32 16 | DEFAULT_PZ_LAYERS = 1 17 | DEFAULT_PZ_HIDDEN = 32 18 | 19 | 20 | class MrVAE(BaseModuleClass): 21 | def __init__( 22 | self, 23 | n_input: int, 24 | n_sample: int, 25 | n_obs_per_sample: int, 26 | n_cats_per_nuisance_keys: list[int], 27 | n_latent: int = 10, 28 | n_latent_sample: int = 2, 29 | linear_decoder_zx: bool = True, 30 | linear_decoder_uz: bool = True, 31 | linear_decoder_uz_scaler: bool = False, 32 | linear_decoder_uz_scaler_n_hidden: int = 32, 33 | px_kwargs: dict | None = None, 34 | pz_kwargs: dict | None = None, 35 | ): 36 | super().__init__() 37 | _px_kwargs = dict(n_hidden=DEFAULT_PX_HIDDEN) 38 | if px_kwargs is not None: 39 | _px_kwargs.update(px_kwargs) 40 | _pz_kwargs = dict(n_layers=DEFAULT_PZ_LAYERS, n_hidden=DEFAULT_PZ_HIDDEN) 41 | if pz_kwargs is not None: 42 | _pz_kwargs.update(pz_kwargs) 43 | 44 | self.n_cats_per_nuisance_keys = n_cats_per_nuisance_keys 45 | self.n_sample = n_sample 46 | assert n_latent_sample != 0 47 | self.sample_embeddings = nn.Embedding(n_sample, n_latent_sample) 48 | 49 | n_nuisance = sum(self.n_cats_per_nuisance_keys) 50 | # Generative model 51 | self.px = DecoderZX( 52 | n_latent + n_nuisance, 53 | n_input, 54 | n_nuisance=n_nuisance, 55 | linear_decoder=linear_decoder_zx, 56 | **_px_kwargs, 57 | ) 58 | self.qu = NormalNN(128 + n_latent_sample, n_latent, n_categories=1) 59 | self.ql = NormalNN(n_input, 1, n_categories=1) 60 | 61 | self.linear_decoder_uz = linear_decoder_uz 62 | if linear_decoder_uz: 63 | self.pz = LinearDecoderUZ( 64 | n_latent, 65 | self.n_sample, 66 | n_latent, 67 | scaler=linear_decoder_uz_scaler, 68 | scaler_n_hidden=linear_decoder_uz_scaler_n_hidden, 69 | ) 70 | else: 71 | self.pz = DecoderUZ( 72 | n_latent, 73 | n_latent_sample, 74 | n_latent, 75 | **_pz_kwargs, 76 | ) 77 | self.n_obs_per_sample = nn.Parameter(n_obs_per_sample, 
requires_grad=False) 78 | 79 | self.x_featurizer = nn.Sequential(nn.Linear(n_input, 128), nn.ReLU()) 80 | self.bnn = ConditionalBatchNorm1d(128, n_sample) 81 | self.x_featurizer2 = nn.Sequential(nn.Linear(128, 128), nn.ReLU()) 82 | self.bnn2 = ConditionalBatchNorm1d(128, n_sample) 83 | 84 | def _get_inference_input( 85 | self, tensors: dict[str, torch.Tensor], **kwargs 86 | ) -> dict[str, torch.Tensor]: 87 | x = tensors[REGISTRY_KEYS.X_KEY] 88 | sample_index = tensors[MRVI_REGISTRY_KEYS.SAMPLE_KEY] 89 | categorical_nuisance_keys = tensors[ 90 | MRVI_REGISTRY_KEYS.CATEGORICAL_NUISANCE_KEYS 91 | ] 92 | return dict( 93 | x=x, 94 | sample_index=sample_index, 95 | categorical_nuisance_keys=categorical_nuisance_keys, 96 | ) 97 | 98 | @auto_move_data 99 | def inference( 100 | self, 101 | x: torch.Tensor, 102 | sample_index: torch.Tensor, 103 | categorical_nuisance_keys: torch.Tensor, 104 | mc_samples: int = 1, 105 | cf_sample: torch.Tensor | None = None, 106 | use_mean: bool = False, 107 | ) -> dict[str, torch.Tensor]: 108 | x_ = torch.log1p(x) 109 | 110 | sample_index_cf = sample_index if cf_sample is None else cf_sample 111 | zsample = self.sample_embeddings(sample_index_cf.long().squeeze(-1)) 112 | zsample_ = zsample 113 | if mc_samples >= 2: 114 | zsample_ = zsample[None].expand(mc_samples, *zsample.shape) 115 | 116 | nuisance_oh = [] 117 | for dim in range(categorical_nuisance_keys.shape[-1]): 118 | nuisance_oh.append( 119 | one_hot( 120 | categorical_nuisance_keys[:, [dim]], 121 | self.n_cats_per_nuisance_keys[dim], 122 | ) 123 | ) 124 | nuisance_oh = torch.cat(nuisance_oh, dim=-1) 125 | 126 | x_feat = self.x_featurizer(x_) 127 | x_feat = self.bnn(x_feat, sample_index) 128 | x_feat = self.x_featurizer2(x_feat) 129 | x_feat = self.bnn2(x_feat, sample_index) 130 | if x_.ndim != zsample_.ndim: 131 | x_feat_ = x_feat[None].expand(mc_samples, *x_feat.shape) 132 | nuisance_oh = nuisance_oh[None].expand(mc_samples, *nuisance_oh.shape) 133 | else: 134 | x_feat_ = x_feat 135 | 136 | inputs = torch.cat([x_feat_, zsample_], -1) 137 | # inputs = x_feat_ 138 | qu = self.qu(inputs) 139 | if use_mean: 140 | u = qu.loc 141 | else: 142 | u = qu.rsample() 143 | 144 | if self.linear_decoder_uz: 145 | z = self.pz(u, sample_index_cf) 146 | else: 147 | inputs = torch.cat([u, zsample_], -1) 148 | z = self.pz(inputs) 149 | library = torch.log(x.sum(1)).unsqueeze(1) 150 | 151 | return dict( 152 | qu=qu, 153 | u=u, 154 | z=z, 155 | zsample=zsample, 156 | library=library, 157 | nuisance_oh=nuisance_oh, 158 | ) 159 | 160 | def get_z( 161 | self, 162 | u: torch.Tensor, 163 | zsample: torch.Tensor | None = None, 164 | sample_index: torch.Tensor | None = None, 165 | ) -> torch.Tensor: 166 | if sample_index is not None: 167 | zsample = self.sample_embeddings(sample_index.long().squeeze(-1)) 168 | zsample = zsample 169 | else: 170 | zsample_ = zsample 171 | inputs = torch.cat([u, zsample_], -1) 172 | z = self.pz(inputs) 173 | return z 174 | 175 | def _get_generative_input( 176 | self, 177 | tensors: dict[str, torch.Tensor], 178 | inference_outputs: dict[str, torch.Tensor], 179 | **kwargs, 180 | ) -> dict[str, torch.Tensor]: 181 | res = dict( 182 | z=inference_outputs["z"], 183 | library=inference_outputs["library"], 184 | nuisance_oh=inference_outputs["nuisance_oh"], 185 | ) 186 | 187 | return res 188 | 189 | @auto_move_data 190 | def generative( 191 | self, 192 | z: torch.Tensor, 193 | library: torch.Tensor, 194 | nuisance_oh: torch.Tensor, 195 | ) -> dict[str, torch.Tensor]: 196 | inputs = torch.concat([z, nuisance_oh], 
dim=-1) 197 | px = self.px(inputs, size_factor=library.exp()) 198 | h = px.mu / library.exp() 199 | 200 | pu = db.Normal(0, 1) 201 | return dict(px=px, pu=pu, h=h) 202 | 203 | def loss( 204 | self, 205 | tensors: dict[str, torch.Tensor], 206 | inference_outputs: dict[str, torch.Tensor], 207 | generative_outputs: dict[str, torch.Tensor], 208 | kl_weight: float = 1.0, 209 | ) -> LossOutput: 210 | reconstruction_loss = ( 211 | -generative_outputs["px"].log_prob(tensors[REGISTRY_KEYS.X_KEY]).sum(-1) 212 | ) 213 | kl_u = kl(inference_outputs["qu"], generative_outputs["pu"]).sum(-1) 214 | kl_local_for_warmup = kl_u 215 | 216 | weighted_kl_local = kl_weight * kl_local_for_warmup 217 | loss = torch.mean(reconstruction_loss + weighted_kl_local) 218 | 219 | kl_local = torch.tensor(0.0) 220 | kl_global = torch.tensor(0.0) 221 | 222 | return LossOutput( 223 | loss=loss, 224 | reconstruction_loss=reconstruction_loss, 225 | kl_local=kl_local, 226 | kl_global=kl_global, 227 | ) 228 | -------------------------------------------------------------------------------- /mrvi/_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.distributions as db 5 | import torch.nn as nn 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class ResnetFC(nn.Module): 11 | def __init__( 12 | self, 13 | n_in, 14 | n_out, 15 | n_hidden=128, 16 | activation=nn.Softmax(-1), 17 | ): 18 | super().__init__() 19 | self.module1 = nn.Sequential( 20 | nn.Linear(n_in, n_hidden), 21 | nn.BatchNorm1d(n_hidden), 22 | nn.ReLU(), 23 | ) 24 | self.module2 = nn.Sequential( 25 | nn.Linear(n_hidden, n_out), 26 | nn.BatchNorm1d(n_out), 27 | ) 28 | if n_in != n_hidden: 29 | self.id_map1 = nn.Linear(n_in, n_hidden) 30 | else: 31 | self.id_map1 = None 32 | self.activation = activation 33 | 34 | def forward(self, inputs): 35 | need_reshaping = False 36 | if inputs.ndim == 3: 37 | n_d1, nd2 = inputs.shape[:2] 38 | inputs = inputs.reshape(n_d1 * nd2, -1) 39 | need_reshaping = True 40 | h = self.module1(inputs) 41 | if self.id_map1 is not None: 42 | h = h + self.id_map1(inputs) 43 | h = self.module2(h) 44 | if need_reshaping: 45 | h = h.view(n_d1, nd2, -1) 46 | if self.activation is not None: 47 | return self.activation(h) 48 | return h 49 | 50 | 51 | class _NormalNN(nn.Module): 52 | def __init__( 53 | self, 54 | n_in, 55 | n_out, 56 | n_hidden=128, 57 | n_layers=1, 58 | use_batch_norm=True, 59 | use_layer_norm=False, 60 | do_orthogonal=False, 61 | ): 62 | super().__init__() 63 | self.n_layers = n_layers 64 | 65 | self.hidden = ResnetFC(n_in, n_out=n_hidden, activation=nn.ReLU()) 66 | self._mean = nn.Linear(n_hidden, n_out) 67 | self._var = nn.Sequential(nn.Linear(n_hidden, n_out), nn.Softplus()) 68 | 69 | def forward(self, inputs): 70 | if self.n_layers >= 1: 71 | h = self.hidden(inputs) 72 | mean = self._mean(h) 73 | var = self._var(h) 74 | else: 75 | mean = self._mean(inputs) 76 | k = mean.shape[0] 77 | var = self._var[None].expand(k, -1) 78 | return mean, var 79 | 80 | 81 | class NormalNN(nn.Module): 82 | def __init__( 83 | self, 84 | n_in, 85 | n_out, 86 | n_categories, 87 | n_hidden=128, 88 | n_layers=1, 89 | use_batch_norm=True, 90 | use_layer_norm=False, 91 | ): 92 | super().__init__() 93 | nn_kwargs = dict( 94 | n_in=n_in, 95 | n_out=n_out, 96 | n_hidden=n_hidden, 97 | n_layers=n_layers, 98 | use_batch_norm=use_batch_norm, 99 | use_layer_norm=use_layer_norm, 100 | ) 101 | self.n_out = n_out 102 | self._mymodules = nn.ModuleList( 103 | 
[_NormalNN(**nn_kwargs) for _ in range(n_categories)] 104 | ) 105 | 106 | def forward(self, inputs, categories=None): 107 | means = [] 108 | vars = [] 109 | for idx, module in enumerate(self._mymodules): 110 | _means, _vars = module(inputs) 111 | means.append(_means[..., None]) 112 | vars.append(_vars[..., None]) 113 | means = torch.cat(means, -1) 114 | vars = torch.cat(vars, -1) 115 | if categories is not None: 116 | # categories (minibatch, 1) 117 | n_batch = categories.shape[0] 118 | cat_ = categories.unsqueeze(-1).long().expand(n_batch, self.n_out, 1) 119 | if means.ndim == 4: 120 | d1, n_batch, _, _ = means.shape 121 | cat_ = ( 122 | categories[None, :, None].long().expand(d1, n_batch, self.n_out, 1) 123 | ) 124 | means = torch.gather(means, -1, cat_) 125 | vars = torch.gather(vars, -1, cat_) 126 | means = means.squeeze(-1) 127 | vars = vars.squeeze(-1) 128 | return db.Normal(means, vars + 1e-5) 129 | 130 | 131 | class ConditionalBatchNorm1d(nn.Module): 132 | def __init__(self, num_features, num_classes): 133 | super().__init__() 134 | self.num_features = num_features 135 | self.bn = nn.BatchNorm1d(num_features, affine=False) 136 | self.embed = nn.Embedding(num_classes, num_features * 2) 137 | self.embed.weight.data[:, :num_features].normal_( 138 | 1, 0.02 139 | ) # Initialise scale at N(1, 0.02) 140 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 141 | 142 | def forward(self, x, y): 143 | need_reshaping = False 144 | if x.ndim == 3: 145 | n_d1, nd2 = x.shape[:2] 146 | x = x.view(n_d1 * nd2, -1) 147 | need_reshaping = True 148 | 149 | y = y[None].expand(n_d1, nd2, -1) 150 | y = y.contiguous().view(n_d1 * nd2, -1) 151 | 152 | out = self.bn(x) 153 | gamma, beta = self.embed(y.squeeze(-1).long()).chunk(2, 1) 154 | out = gamma * out + beta 155 | 156 | if need_reshaping: 157 | out = out.view(n_d1, nd2, -1) 158 | 159 | return out 160 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.isort] 2 | include_trailing_comma = true 3 | multi_line_output = 3 4 | profile = "black" 5 | skip_glob = ["docs/*", "mrvi/__init__.py"] 6 | 7 | [tool.poetry] 8 | authors = ["Pierre Boyeau ", "Justin Hong ", "Adam Gayoso "] 9 | classifiers = [ 10 | "Development Status :: 4 - Beta", 11 | "Intended Audience :: Science/Research", 12 | "Natural Language :: English", 13 | "Programming Language :: Python :: 3.6", 14 | "Programming Language :: Python :: 3.7", 15 | "Operating System :: MacOS :: MacOS X", 16 | "Operating System :: Microsoft :: Windows", 17 | "Operating System :: POSIX :: Linux", 18 | "Topic :: Scientific/Engineering :: Bio-Informatics", 19 | ] 20 | description = "Multi-resolution analysis of single-cell data." 
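# Note: most of the dependencies declared below are optional and are grouped into the
# "dev" and "docs" extras under [tool.poetry.extras] further down; from a checkout they
# can be installed with, e.g., `pip install ".[dev]"`.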
21 | documentation = "https://scvi-tools.org" 22 | homepage = "https://github.com/YosefLab/mrvi" 23 | license = "BSD-3-Clause" 24 | name = "mrvi" 25 | packages = [ 26 | {include = "mrvi"}, 27 | ] 28 | readme = "README.md" 29 | version = "0.2.0" 30 | 31 | [tool.poetry.dependencies] 32 | anndata = ">=0.7.5" 33 | black = {version = ">=20.8b1", optional = true} 34 | codecov = {version = ">=2.0.8", optional = true} 35 | flake8 = {version = ">=3.7.7", optional = true} 36 | importlib-metadata = {version = "^1.0", python = "<3.8"} 37 | ipython = {version = ">=7.1.1", optional = true} 38 | isort = {version = ">=5.7", optional = true} 39 | jupyter = {version = ">=1.0", optional = true} 40 | leidenalg = {version = "*", optional = true} 41 | loompy = {version = ">=3.0.6", optional = true} 42 | nbconvert = {version = ">=5.4.0", optional = true} 43 | nbformat = {version = ">=4.4.0", optional = true} 44 | nbsphinx = {version = "*", optional = true} 45 | nbsphinx-link = {version = "*", optional = true} 46 | pre-commit = {version = ">=2.7.1", optional = true} 47 | pydata-sphinx-theme = {version = ">=0.4.0", optional = true} 48 | pytest = {version = ">=4.4", optional = true} 49 | python = ">=3.7.2,<4.0" 50 | python-igraph = {version = "*", optional = true} 51 | scanpy = {version = ">=1.6", optional = true} 52 | scanpydoc = {version = ">=0.5", optional = true} 53 | scikit-misc = {version = ">=0.1.3", optional = true} 54 | scvi-tools = ">=1.0.0" 55 | sphinx = {version = ">=4.1,<4.4", optional = true} 56 | sphinx-autodoc-typehints = {version = "*", optional = true} 57 | sphinx-rtd-theme = {version = "*", optional = true} 58 | typing_extensions = {version = "*", python = "<3.8"} 59 | 60 | [tool.poetry.extras] 61 | dev = ["black", "pytest", "flake8", "codecov", "scanpy", "loompy", "jupyter", "nbformat", "nbconvert", "pre-commit", "isort"] 62 | docs = [ 63 | "sphinx", 64 | "scanpydoc", 65 | "nbsphinx", 66 | "nbsphinx-link", 67 | "ipython", 68 | "pydata-sphinx-theme", 69 | "typing_extensions", 70 | "sphinx-autodoc-typehints", 71 | "sphinx-rtd-theme", 72 | "sphinxcontrib-bibtex", 73 | ] 74 | 75 | [tool.poetry.dev-dependencies] 76 | 77 | [build-system] 78 | requires = [ 79 | "poetry-core>=1.0.0", 80 | ] 81 | build-backend = "poetry.core.masonry.api" 82 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | build: 3 | image: latest 4 | sphinx: 5 | configuration: docs/conf.py 6 | python: 7 | version: 3.7 8 | install: 9 | - method: pip 10 | path: . 
11 | extra_requirements: 12 | - docs 13 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # This is a shim to hopefully allow Github to detect the package, build is done with poetry 4 | 5 | import setuptools 6 | 7 | if __name__ == "__main__": 8 | setuptools.setup(name="mrvi") 9 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YosefLab/mrvi_archive/86112e3d32e3f1d499efa5ffcace26746b022603/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_mrvi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scvi.data import synthetic_iid 3 | 4 | from mrvi import MrVI 5 | 6 | 7 | def test_mrvi(): 8 | adata = synthetic_iid() 9 | adata.obs["sample"] = np.random.choice(15, size=adata.shape[0]) 10 | MrVI.setup_anndata(adata, sample_key="sample", categorical_nuisance_keys=["batch"]) 11 | for linear_decoder_uz in [True, False]: 12 | for linear_decoder_zx in [True, False]: 13 | model = MrVI( 14 | adata, 15 | n_latent_sample=5, 16 | linear_decoder_zx=linear_decoder_zx, 17 | linear_decoder_uz=linear_decoder_uz, 18 | ) 19 | model.train(1, check_val_every_n_epoch=1, train_size=0.5) 20 | model.history 21 | 22 | model = MrVI( 23 | adata, 24 | n_latent_sample=5, 25 | linear_decoder_zx=True, 26 | linear_decoder_uz=True, 27 | linear_decoder_uz_scaler=True, 28 | ) 29 | model.train(1, check_val_every_n_epoch=1, train_size=0.5) 30 | model.get_latent_representation() 31 | assert model.get_local_sample_representation().shape == (adata.shape[0], 15, 10) 32 | assert model.get_local_sample_representation(return_distances=True).shape == ( 33 | adata.shape[0], 34 | 15, 35 | 15, 36 | ) 37 | # tests __repr__ 38 | print(model) 39 | -------------------------------------------------------------------------------- /workflow/Snakefile: -------------------------------------------------------------------------------- 1 | configfile: "config/config.yaml" 2 | 3 | MODELS = [ 4 | "MrVISmall", 5 | "MrVILinear", 6 | "MrVILinear50", 7 | "MrVILinear50COMP", 8 | "MrVILinear10COMP", 9 | "MrVILinearLinear10COMP", 10 | "MrVILinearLinear10", 11 | "MrVILinearLinear10SCALER", 12 | "MILO", 13 | "MILOSCVI", 14 | "PCAKNN", 15 | "CompositionPCA", 16 | "CompositionSCVI", 17 | "SCVI", 18 | ] 19 | 20 | import random 21 | 22 | def get_n_replicates(dataset): 23 | dataset_config = config[dataset] 24 | n_replicates = dataset_config["nReplicates"] if "nReplicates" in dataset_config else 1 25 | return n_replicates 26 | 27 | def get_random_seeds(dataset, n_models): 28 | dataset_config = config[dataset] 29 | n_replicates = dataset_config["nReplicates"] if "nReplicates" in dataset_config else 1 30 | random_seed = dataset_config["randomSeed"] if "randomSeed" in dataset_config else 1 31 | random.seed(random_seed) 32 | rep_random_seeds = [random.randint(0, 2**32) for _ in range(n_replicates * n_models)] 33 | return rep_random_seeds 34 | 35 | rule all: 36 | input: 37 | "results/synthetic_experiment.done", 38 | "results/semisynthetic_experiment.done", 39 | "results/snrna_experiment.done", 40 | 41 | rule synthetic_experiment: 42 | output: touch("results/synthetic_experiment.done") 43 | input: 44 | expand( 45 | 
"results/synthetic/final_adata_{model}_{seed}.h5ad", 46 | zip, 47 | model=MODELS * get_n_replicates("synthetic"), 48 | seed=get_random_seeds("synthetic", len(MODELS)), 49 | ) 50 | 51 | rule semisynthetic_experiment: 52 | output: touch("results/semisynthetic_experiment.done") 53 | input: 54 | expand( 55 | "results/semisynthetic/final_adata_{model}_{seed}.h5ad", 56 | zip, 57 | model=MODELS * get_n_replicates("semisynthetic"), 58 | seed=get_random_seeds("semisynthetic", len(MODELS)), 59 | ), 60 | 61 | rule snrna_experiment: 62 | output: touch("results/snrna_experiment.done") 63 | input: 64 | expand( 65 | "results/snrna/final_adata_{model}_{seed}.h5ad", 66 | zip, 67 | model=MODELS * get_n_replicates("snrna"), 68 | seed=get_random_seeds("snrna", len(MODELS)), 69 | ) 70 | 71 | 72 | def get_s3_dataset_path(wildcards): 73 | return config[wildcards.dataset]["s3FilePath"] 74 | 75 | rule load_dataset_from_s3: 76 | params: get_s3_dataset_path 77 | output: "data/{dataset,[A-Za-z0-9]+}/adata.h5ad" 78 | shell: 79 | "aws s3 cp s3://{params} {output}" 80 | 81 | rule process_dataset: 82 | input: 83 | "data/{dataset}/adata.h5ad" 84 | output: 85 | "data/{dataset,[A-Za-z0-9]+}/adata.processed.h5ad", 86 | conda: 87 | "envs/process_data.yaml" 88 | script: 89 | "scripts/process_data.py" 90 | 91 | rule run_model: 92 | input: 93 | "data/{dataset}/adata.processed.h5ad", 94 | output: 95 | "results/{dataset}/adata_{model,[A-Za-z0-9]+}_{seed, \d+}.h5ad" 96 | threads: 97 | 8 98 | log: 99 | "logs/{dataset}_{model}_{seed}.log" 100 | conda: 101 | "envs/run_model.yaml" 102 | resources: 103 | nvidia_gpu=1 104 | script: 105 | "scripts/run_model.py" 106 | 107 | rule compute_local_scores: 108 | input: 109 | "results/{dataset}/adata_{model}_{seed}.h5ad" 110 | output: 111 | "results/{dataset}/final_adata_{model,[A-Za-z0-9]+}_{seed, \d+}.h5ad" 112 | threads: 113 | 8 114 | conda: 115 | "envs/run_model.yaml" 116 | script: 117 | "scripts/compute_local_scores.py" 118 | -------------------------------------------------------------------------------- /workflow/config/config.yaml: -------------------------------------------------------------------------------- 1 | synthetic: 2 | s3FilePath: largedonor/symsim_new.h5ad 3 | nReplicates: 5 4 | randomSeed: 123 5 | nEpochs: 400 6 | batchSize: 256 7 | keyMapping: 8 | donorKey: donor 9 | cellTypeKey: celltype 10 | nuisanceKeys: 11 | - batch 12 | relevantKeys: 13 | - batch 14 | - donor_meta_1 15 | - donor_meta_2 16 | - donor_meta_3 17 | 18 | semisynthetic: 19 | origFilePath: None 20 | nEpochs: 400 21 | batchSize: 256 22 | s3FilePath: largedonor/scvi_pbmcs.h5ad 23 | keyMapping: 24 | donorKey: batch 25 | cellTypeKey: str_labels 26 | nuisanceKeys: 27 | - Site 28 | relevantKeys: 29 | - Site 30 | - tree_id1_0 31 | - tree_id1_1 32 | - tree_id1_2 33 | - tree_id1_3 34 | - tree_id1_4 35 | - tree_id2_0 36 | - tree_id2_1 37 | - tree_id2_2 38 | - tree_id2_3 39 | - tree_id2_4 40 | 41 | snrna: 42 | origFilePath: SCP259 43 | s3FilePath: largedonor/nucleus.h5ad 44 | nEpochs: 300 45 | batchSize: 256 46 | preprocessing: 47 | filter_genes: 48 | min_cells: 500 49 | highly_variable_genes: 50 | n_top_genes: 3000 51 | flavor: seurat_v3 52 | keyMapping: 53 | donorKey: library_uuid 54 | cellTypeKey: cell_type 55 | nuisanceKeys: 56 | - suspension_type 57 | relevantKeys: 58 | - library_uuid 59 | - suspension_type 60 | -------------------------------------------------------------------------------- /workflow/envs/process_data.yaml: -------------------------------------------------------------------------------- 1 | 
channels: 2 | - conda-forge 3 | dependencies: 4 | - python=3.8 5 | - pytorch=1.11.0 6 | - pip: 7 | - scanpy==1.9.1 8 | - scikit-misc==0.1.4 9 | -------------------------------------------------------------------------------- /workflow/envs/run_model.yaml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - conda-forge 4 | - bioconda 5 | - defaults 6 | dependencies: 7 | - python~=3.8 8 | - cudatoolkit=11.3.1=ha36c431_9 9 | - scvi-tools=0.17.4 10 | - pytorch=1.11.0=py3.8_cuda11.3_cudnn8.2.0_0 11 | - torchvision=0.12.0=py38_cu113 12 | - pandas=1.3.5 13 | - leidenalg=0.8.10 14 | - scanpy=1.9.1 15 | - pip 16 | - r-essentials=4.1 17 | - bioconductor-edger==3.36.0 18 | - r-statmod=1.4.36 19 | - scipy<1.9.0 20 | - pip: 21 | - scikit-learn==1.0.2 22 | - scikit-misc==0.1.4 23 | - scanpy==1.9.1 24 | - regex 25 | - git+https://github.com/YosefLab/mrvi.git@main 26 | - git+https://github.com/emdann/milopy.git@master 27 | - mnnpy==0.1.9.5 28 | -------------------------------------------------------------------------------- /workflow/notebooks/semisynthetic.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import glob 3 | import os 4 | import re 5 | from collections import defaultdict 6 | 7 | import ete3 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import pandas as pd 11 | import plotnine as p9 12 | import scanpy as sc 13 | from scipy import stats 14 | from scipy.cluster.hierarchy import linkage, to_tree 15 | from sklearn.metrics import pairwise_distances, precision_score, recall_score 16 | from statsmodels.stats.multitest import multipletests 17 | from tqdm import tqdm 18 | 19 | # %% 20 | workflow_dir = "../" 21 | base_dir = os.path.join(workflow_dir, "results/semisynthetic") 22 | full_data_path = os.path.join(workflow_dir, "data/semisynthetic/adata.processed.h5ad") 23 | input_files = glob.glob(os.path.join(base_dir, "final_adata*")) 24 | 25 | figure_dir = os.path.join(workflow_dir, "figures/semisynthetic/") 26 | 27 | 28 | model_adatas = defaultdict(list) 29 | mapper = None 30 | for file in tqdm(input_files): 31 | print(file) 32 | adata_ = sc.read_h5ad(file) 33 | file_re = re.match(r".*final_adata_([\w\d]+)_(\d+).h5ad", file) 34 | if file_re is None: 35 | continue 36 | model_name, seed = file_re.groups() 37 | print(model_name, seed) 38 | 39 | uns_keys = list(adata_.uns.keys()) 40 | for uns_key in uns_keys: 41 | uns_vals = adata_.uns[uns_key] 42 | if uns_key.endswith("local_donor_rep"): 43 | meta_keys = [ 44 | key 45 | for key in adata_.uns.keys() 46 | if key.endswith("_local_donor_rep_metadata") 47 | ] 48 | print(uns_key) 49 | if len(meta_keys) != 0: 50 | # SORT UNS by donor_key values 51 | 52 | meta_key = meta_keys[0] 53 | print(uns_key, meta_key) 54 | mapper = adata_.uns["mapper"] 55 | donor_key = mapper["donor_key"] 56 | metad = adata_.uns[meta_key].reset_index().sort_values(donor_key) 57 | # print(metad.index.values) 58 | uns_vals = uns_vals[:, metad.index.values] 59 | else: 60 | uns_vals = adata_.uns[uns_key] 61 | for key in uns_vals: 62 | uns_vals[key] = uns_vals[key].sort_index() 63 | # print(uns_vals[key].index) 64 | 65 | adata_.uns[uns_key] = uns_vals 66 | model_adatas[model_name].append(dict(adata=adata_, seed=seed)) 67 | 68 | # %% 69 | adata_ = model_adatas["MrVILinearLinear10"][0]["adata"] 70 | affected_ct_a = adata_.obs.affected_ct1.unique()[0] 71 | ordered_tree_a = adata_.uns[ 72 | "MrVILinearLinear10_local_donor_rep_metadata" 73 | ].tree_id1.sort_values() 74 | 
t_gt_a = ete3.Tree( 75 | "((({}, {}), ({}, {})), (({},{}), ({}, {})));".format( 76 | *["t{}".format(val + 1) for val in ordered_tree_a.index] 77 | ) 78 | ) 79 | d_order_a = ordered_tree_a.index 80 | 81 | affected_ct_b = adata_.obs.affected_ct2.unique()[0] 82 | ordered_tree_b = adata_.uns[ 83 | "MrVILinearLinear10_local_donor_rep_metadata" 84 | ].tree_id2.sort_values() 85 | t_gt_b = ete3.Tree( 86 | "((({}, {}), ({}, {})), (({},{}), ({}, {})));".format( 87 | *["t{}".format(val + 1) for val in ordered_tree_b.index] 88 | ) 89 | ) 90 | d_order_b = ordered_tree_b.index 91 | print(affected_ct_a, affected_ct_b) 92 | 93 | 94 | # %% 95 | 96 | similarity_mat = np.zeros((len(ordered_tree_b), len(ordered_tree_b))) 97 | for i in range(len(ordered_tree_b)): 98 | for j in range(len(ordered_tree_b)): 99 | x = np.array([int(v) for v in ordered_tree_b.iloc[i]], dtype=bool) 100 | y = np.array([int(v) for v in ordered_tree_b.iloc[j]], dtype=bool) 101 | shared = x == y 102 | shared_ = np.where(~shared)[0] 103 | if len(shared_) == 0: 104 | similarity_mat[i, j] = 5 105 | else: 106 | similarity_mat[i, j] = np.where(~shared)[0][0] 107 | dissimilarity_mat = 5 - similarity_mat 108 | # %% 109 | ct_key = "leiden" 110 | sample_key = mapper["donor_key"] 111 | 112 | # %% 113 | # Unsupervised analysis 114 | MODELS = [ 115 | dict(model_name="CompositionPCA", cell_specific=False), 116 | dict(model_name="CompositionSCVI", cell_specific=False), 117 | dict(model_name="MrVILinearLinear10", cell_specific=True), 118 | ] 119 | 120 | 121 | # %% 122 | def compute_aggregate_dmat(reps): 123 | n_cells, n_donors, _ = reps.shape 124 | pairwise_ds = np.zeros((n_donors, n_donors)) 125 | for x in tqdm(reps): 126 | d_ = pairwise_distances(x, metric=METRIC) 127 | pairwise_ds += d_ / n_cells 128 | return pairwise_ds 129 | 130 | 131 | # %% 132 | METRIC = "euclidean" 133 | 134 | dist_mtxs = defaultdict(list) 135 | for model_params in MODELS: 136 | model_name = model_params["model_name"] 137 | rep_key = f"{model_name}_local_donor_rep" 138 | metadata_key = f"{model_name}_local_donor_rep_metadata" 139 | is_cell_specific = model_params["cell_specific"] 140 | 141 | for model_res in model_adatas[model_name]: 142 | adata, seed = model_res["adata"], model_res["seed"] 143 | if rep_key not in adata.uns: 144 | continue 145 | rep = adata.uns[rep_key] 146 | 147 | print(model_name) 148 | for cluster in adata.obs[ct_key].unique(): 149 | if not is_cell_specific: 150 | rep_ct = rep[cluster] 151 | ss_matrix = pairwise_distances(rep_ct.values, metric=METRIC) 152 | cats = metad.set_index(sample_key).loc[rep_ct.index].index.values 153 | else: 154 | good_cells = adata.obs[ct_key] == cluster 155 | rep_ct = rep[good_cells] 156 | subobs = adata.obs[good_cells] 157 | observed_donors = subobs[sample_key].value_counts() 158 | observed_donors = observed_donors[observed_donors >= 1].index 159 | good_d_idx = metad.loc[ 160 | lambda x: x[sample_key].isin(observed_donors) 161 | ].index.values 162 | 163 | ss_matrix = compute_aggregate_dmat(rep_ct[:, good_d_idx, :]) 164 | cats = metad.loc[lambda x: x[sample_key].isin(observed_donors)][ 165 | sample_key 166 | ].values 167 | 168 | dist_mtxs[model_name].append( 169 | dict(dist_matrix=ss_matrix, cats=cats, seed=seed, ct=cluster) 170 | ) 171 | 172 | 173 | # %% 174 | # https://stackoverflow.com/questions/9364609/converting-ndarray-generated-by-hcluster-into-a-newick-string-for-use-with-ete2/17657426#17657426 175 | def linkage_to_ete(linkage_obj): 176 | R = to_tree(linkage_obj) 177 | root = ete3.Tree() 178 | root.dist = 0 179 | root.name 
= "root" 180 | item2node = {R.get_id(): root} 181 | to_visit = [R] 182 | 183 | while to_visit: 184 | node = to_visit.pop() 185 | cl_dist = node.dist / 2.0 186 | 187 | for ch_node in [node.get_left(), node.get_right()]: 188 | if ch_node: 189 | ch_node_id = ch_node.get_id() 190 | ch_node_name = ( 191 | f"t{int(ch_node_id) + 1}" if ch_node.is_leaf() else str(ch_node_id) 192 | ) 193 | ch = ete3.Tree() 194 | ch.dist = cl_dist 195 | ch.name = ch_node_name 196 | 197 | item2node[node.get_id()].add_child(ch) 198 | item2node[ch_node_id] = ch 199 | to_visit.append(ch_node) 200 | return root 201 | 202 | 203 | # %% 204 | fig, ax = plt.subplots() 205 | im = ax.imshow(dissimilarity_mat) 206 | fig.savefig(os.path.join(figure_dir, "CT B_GT_dist_matrix.svg")) 207 | plt.show() 208 | plt.close() 209 | # %% 210 | rf_df = [] 211 | ete_trees = defaultdict(list) 212 | for model in dist_mtxs: 213 | for model_dist_mtx in dist_mtxs[model]: 214 | dist_mtx = model_dist_mtx["dist_matrix"] 215 | seed = model_dist_mtx["seed"] 216 | ct = model_dist_mtx["ct"] 217 | cats = model_dist_mtx["cats"] 218 | # cats = [cat[10:] if cat[:10] == "donor_meta" else cat for cat in cats] 219 | if ct == affected_ct_a: 220 | # Heatmaps 221 | fig, ax = plt.subplots() 222 | vmin = np.quantile(dist_mtx, 0.05) 223 | vmax = np.quantile(dist_mtx, 0.7) 224 | im = ax.imshow( 225 | dist_mtx[d_order_a][:, d_order_a], 226 | vmin=vmin, 227 | vmax=vmax, 228 | ) 229 | 230 | # Show all ticks and label them with the respective list entries 231 | ax.set_xticks(np.arange(len(cats)), fontsize=5) 232 | ax.set_yticks(np.arange(len(cats)), fontsize=5) 233 | 234 | # Rotate the tick labels and set their alignment. 235 | plt.setp( 236 | ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor" 237 | ) 238 | 239 | cell_type = "CT A" 240 | ax.set_title(f"{model} {cell_type}") 241 | 242 | fig.tight_layout() 243 | fig.savefig( 244 | os.path.join(figure_dir, f"{cell_type}_{model}_{seed}_dist_matrix.svg") 245 | ) 246 | plt.show() 247 | plt.close() 248 | 249 | z = linkage(dist_mtx, method="ward") 250 | z_ete = linkage_to_ete(z) 251 | ete_trees[model].append(dict(t=z_ete, ct=ct)) 252 | 253 | rf_dist = t_gt_a.robinson_foulds(z_ete) 254 | norm_rf = rf_dist[0] / rf_dist[1] 255 | rf_df.append( 256 | dict( 257 | ct=ct, 258 | model=model, 259 | rf=norm_rf, 260 | ) 261 | ) 262 | 263 | if ct == affected_ct_b: 264 | fig, ax = plt.subplots() 265 | vmin = np.quantile(dist_mtx, 0.05) 266 | vmax = np.quantile(dist_mtx, 0.7) 267 | im = ax.imshow( 268 | dist_mtx[d_order_b][:, d_order_b], 269 | vmin=vmin, 270 | vmax=vmax, 271 | ) 272 | 273 | # Show all ticks and label them with the respective list entries 274 | ax.set_xticks(np.arange(len(cats)), fontsize=5) 275 | ax.set_yticks(np.arange(len(cats)), fontsize=5) 276 | 277 | # Rotate the tick labels and set their alignment. 
278 | plt.setp( 279 | ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor" 280 | ) 281 | 282 | cell_type = "CT B" 283 | ax.set_title(f"{model} {cell_type}") 284 | 285 | fig.tight_layout() 286 | fig.savefig( 287 | os.path.join(figure_dir, f"{cell_type}_{model}_{seed}_dist_matrix.svg") 288 | ) 289 | plt.show() 290 | plt.close() 291 | 292 | # Dendrograms 293 | z = linkage(dist_mtx, method="ward") 294 | z_ete = linkage_to_ete(z) 295 | ete_trees[model].append(dict(t=z_ete, ct=ct)) 296 | 297 | rf_dist = t_gt_b.robinson_foulds(z_ete) 298 | norm_rf = rf_dist[0] / rf_dist[1] 299 | rf_df.append( 300 | dict( 301 | ct=ct, 302 | model=model, 303 | rf=norm_rf, 304 | ) 305 | ) 306 | rf_df = pd.DataFrame(rf_df) 307 | 308 | # %% 309 | 310 | for ct in [0, 1]: 311 | subs = rf_df.query(f"ct == '{ct}'") 312 | pop1 = subs.query("model == 'MrVILinearLinear10'").rf 313 | pop2 = subs.query("model == 'CompositionSCVI'").rf 314 | pop3 = subs.query("model == 'CompositionPCA'").rf 315 | 316 | pval21 = stats.ttest_rel(pop1, pop2, alternative="less").pvalue 317 | print("mrVI vs CompositionSCVI", pval21) 318 | 319 | pval31 = stats.ttest_rel(pop1, pop3, alternative="less").pvalue 320 | print("mrVI vs CPCA", pval31) 321 | print() 322 | 323 | # %% 324 | rf_subplot = rf_df.loc[ 325 | lambda x: x["model"].isin( 326 | [ 327 | "CompositionSCVI", 328 | "CompositionPCA", 329 | "MrVILinearLinear10", 330 | ] 331 | ) 332 | ] 333 | 334 | fig = ( 335 | p9.ggplot(rf_subplot, p9.aes(x="model", y="rf")) 336 | + p9.geom_bar(p9.aes(fill="model"), stat="summary") 337 | + p9.coord_flip() 338 | + p9.facet_wrap("~ct") 339 | + p9.theme_classic() 340 | + p9.theme( 341 | axis_text=p9.element_text(size=12), 342 | axis_title=p9.element_text(size=15), 343 | aspect_ratio=1, 344 | strip_background=p9.element_blank(), 345 | legend_position="none", 346 | ) 347 | + p9.labs(x="", y="Robinson-Foulds Distance") 348 | + p9.scale_y_continuous(expand=[0, 0]) 349 | ) 350 | fig.save(os.path.join(figure_dir, "semisynth_rf_dist.svg"), verbose=False) 351 | fig 352 | 353 | 354 | # %% 355 | def correct_for_mult(x): 356 | # print(x.name) 357 | if x.name.startswith("MILO"): 358 | print("MILO already FDR controlled; skipping") 359 | return x 360 | return multipletests(x, method="fdr_bh")[1] 361 | 362 | 363 | padj_dfs = [] 364 | ct_obs = adata.obs[adata.uns["mapper"]["cell_type_key"]] 365 | for model in model_adatas: 366 | for model_adata in model_adatas[model]: 367 | adata = model_adata["adata"] 368 | # Selecting right columns in adata.obs 369 | sig_keys = [col for col in adata.obs.columns if col.endswith("significance")] 370 | # Adjust for multiple testing if needed 371 | padj_dfs.append(adata.obs.loc[:, sig_keys].apply(correct_for_mult, axis=0)) 372 | 373 | # Select right cell subpopulations 374 | padjs = pd.concat(padj_dfs, axis=1) 375 | target_a = ct_obs == "1" 376 | good_cols_a = [col for col in padjs.columns if "tree_id2" in col] 377 | 378 | 379 | ALPHA = 0.05 380 | 381 | 382 | def get_pr(padj, target): 383 | fdr = 1.0 - precision_score(target, padj <= ALPHA, zero_division=1) 384 | tpr = recall_score(target, padj <= ALPHA, zero_division=1) 385 | res = pd.Series(dict(FDP=fdr, TPR=tpr)) 386 | return res 387 | 388 | 389 | plot_df = ( 390 | padjs.loc[:, good_cols_a] 391 | .apply(get_pr, target=target_a, axis=0) 392 | .stack() 393 | .to_frame("score") 394 | .reset_index() 395 | .rename(columns=dict(level_0="metric", level_1="approach")) 396 | .assign( 397 | model=lambda x: x.approach.str.split("_").str[0], 398 | test=lambda x: 
x.approach.str.split("_").str[-2], 399 | model_test=lambda x: x.model + " " + x.test, 400 | metadata=lambda x: x.approach.str.split("_") 401 | .str[1:-2] 402 | .apply(lambda y: "_".join(y)), 403 | ) 404 | ) 405 | plot_df 406 | 407 | # %% 408 | MODEL_SELECTION = [ 409 | "MrVILinearLinear10", 410 | "MILOSCVI", 411 | ] 412 | 413 | plot_df_ = plot_df.loc[ 414 | lambda x: (x["model"].isin(MODEL_SELECTION)) 415 | & ( 416 | x["test"].isin( 417 | ( 418 | "ks", 419 | "LFC", 420 | ) 421 | ) 422 | ) 423 | & (x["metadata"] != "batch") 424 | ] 425 | 426 | fig = ( 427 | p9.ggplot( 428 | plot_df_.query('metric == "TPR"'), 429 | p9.aes(x="factor(metadata)", y="score", fill="model_test"), 430 | ) 431 | + p9.geom_bar(stat="identity", position="dodge") 432 | # + p9.facet_wrap("~metadata", scales="free") 433 | # + p9.coord_flip() 434 | + p9.theme_classic() 435 | + p9.theme( 436 | axis_text=p9.element_text(size=12), 437 | axis_title=p9.element_text(size=15), 438 | aspect_ratio=1.5, 439 | ) 440 | + p9.labs(x="", y="TPR", fill="") 441 | + p9.scale_y_continuous(expand=(0, 0)) 442 | ) 443 | fig.save( 444 | os.path.join(figure_dir, "semisynth_TPR_comparison_synth.svg"), 445 | ) 446 | fig 447 | 448 | # %% 449 | fig = ( 450 | p9.ggplot( 451 | plot_df_.query('metric == "FDP"'), 452 | p9.aes(x="factor(metadata)", y="score", fill="model_test"), 453 | ) 454 | + p9.geom_bar(stat="identity", position="dodge") 455 | # + p9.facet_wrap("~metadata", scales="free") 456 | # + p9.coord_flip() 457 | + p9.theme_classic() 458 | + p9.labs(x="", y="FDP", fill="") 459 | + p9.theme( 460 | axis_text=p9.element_text(size=12), 461 | axis_title=p9.element_text(size=15), 462 | aspect_ratio=1.5, 463 | ) 464 | + p9.scale_y_continuous(expand=(0, 0)) 465 | ) 466 | fig.save( 467 | os.path.join(figure_dir, "semisynth_FDR_comparison_synth.svg"), 468 | ) 469 | fig 470 | -------------------------------------------------------------------------------- /workflow/notebooks/snrna.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import glob 3 | import os 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import pandas as pd 8 | import plotnine as p9 9 | import scanpy as sc 10 | import scipy.stats as stats 11 | from scib.metrics import silhouette, silhouette_batch 12 | from sklearn.metrics import pairwise_distances 13 | from sklearn.preprocessing import OneHotEncoder 14 | from tqdm import tqdm 15 | 16 | METRIC = "euclidean" 17 | 18 | 19 | def compute_aggregate_dmat(reps, metric="cosine"): 20 | n_cells, n_donors, _ = reps.shape 21 | pairwise_ds = np.zeros((n_donors, n_donors)) 22 | for x in tqdm(reps): 23 | d_ = pairwise_distances(x, metric=metric) 24 | pairwise_ds += d_ / n_cells 25 | return pairwise_ds 26 | 27 | 28 | def return_distances_cluster_rep(rep, metric="cosine"): 29 | ordered_metad = metad.set_index(sample_key).loc[rep.index] 30 | ss_matrix = pairwise_distances(rep.values, metric=metric) 31 | cats = ordered_metad[bio_group_key].values[:, None] 32 | suspension_cats = ordered_metad[techno_key].values[:, None] 33 | return ss_matrix, cats, suspension_cats, ordered_metad 34 | 35 | 36 | def return_distances_cell_specific(rep, good_cells, metric="cosine"): 37 | good_cells = adata.obs[ct_key] == cluster 38 | rep_ct = rep[good_cells] 39 | subobs = adata.obs[good_cells] 40 | 41 | observed_donors = subobs[sample_key].value_counts() 42 | observed_donors = observed_donors[observed_donors >= 1].index 43 | meta_ = metad.reset_index() 44 | good_d_idx = meta_.loc[lambda x: 
x[sample_key].isin(observed_donors)].index.values 45 | ss_matrix = compute_aggregate_dmat(rep_ct[:, good_d_idx, :], metric=metric) 46 | cats = meta_.loc[good_d_idx][bio_group_key].values[:, None] 47 | suspension_cats = meta_.loc[good_d_idx][techno_key].values[:, None] 48 | return ss_matrix, cats, suspension_cats, meta_.loc[good_d_idx] 49 | 50 | 51 | # %% 52 | dataset_name = "snrna" 53 | 54 | base_dir = "workflow/results_V1/{}".format(dataset_name) 55 | full_data_path = "workflow/data/{}/adata.processed.h5ad".format(dataset_name) 56 | model_path = os.path.join(base_dir, "MrVI") 57 | 58 | input_files = glob.glob(os.path.join(base_dir, "final_adata*")) 59 | input_files 60 | 61 | # %% 62 | adata = sc.read_h5ad(input_files[0]) 63 | print(adata.shape) 64 | for file in tqdm(input_files[1:]): 65 | adata_ = sc.read_h5ad(file) 66 | print(file, adata_.shape) 67 | new_cols = np.setdiff1d(adata_.obs.columns, adata.obs.columns) 68 | adata.obs.loc[:, new_cols] = adata_.obs.loc[:, new_cols] 69 | 70 | new_uns = np.setdiff1d(list(adata_.uns.keys()), list(adata.uns.keys())) 71 | for new_uns_ in new_uns: 72 | uns_vals = adata_.uns[new_uns_] 73 | if new_uns_.endswith("local_donor_rep"): 74 | meta_keys = [ 75 | key 76 | for key in adata_.uns.keys() 77 | if key.endswith("_local_donor_rep_metadata") 78 | ] 79 | print(new_uns_) 80 | if len(meta_keys) != 0: 81 | # SORT UNS by donor_key values 82 | meta_key = meta_keys[0] 83 | print(new_uns_, meta_key) 84 | donor_key = adata_.uns["mapper"]["donor_key"] 85 | metad = adata_.uns[meta_key].reset_index().sort_values(donor_key) 86 | print(metad.index.values) 87 | uns_vals = uns_vals[:, metad.index.values] 88 | else: 89 | uns_vals = adata_.uns[new_uns_] 90 | for key in uns_vals: 91 | uns_vals[key] = uns_vals[key].sort_index() 92 | print(uns_vals[key].index) 93 | 94 | adata.uns[new_uns_] = uns_vals 95 | 96 | new_obsm = np.setdiff1d(list(adata_.obsm.keys()), list(adata.obsm.keys())) 97 | for new_obsm_ in new_obsm: 98 | adata.obsm[new_obsm_] = adata_.obsm[new_obsm_] 99 | 100 | 101 | # %% 102 | mapper = adata.uns["mapper"] 103 | ct_key = mapper["cell_type_key"] 104 | sample_key = mapper["donor_key"] 105 | bio_group_key = "donor_uuid" 106 | techno_key = "suspension_type" 107 | 108 | adata.obs.loc[:, "Sample_id"] = adata.obs["sample"].str[:3] 109 | metad.loc[:, "Sample_id"] = metad["sample"].str[:3] 110 | metad.loc[:, "Sample_id"] = metad.loc[:, "Sample_id"].astype("category") 111 | # %% 112 | adata.obs["sample"].value_counts() 113 | 114 | 115 | # %% 116 | # Unservised analysis 117 | MODELS = [ 118 | dict(model_name="CompositionPCA", cell_specific=False), 119 | dict(model_name="CompositionSCVI", cell_specific=False), 120 | dict(model_name="MrVILinearLinear10", cell_specific=True), 121 | ] 122 | # %% 123 | _dd_plot_df = pd.DataFrame() 124 | 125 | for model_params in MODELS: 126 | rep_key = "{}_local_donor_rep".format(model_params["model_name"]) 127 | metadata_key = "{}_local_donor_rep_metadata".format(model_params["model_name"]) 128 | is_cell_specific = model_params["cell_specific"] 129 | rep = adata.uns[rep_key] 130 | 131 | for cluster in adata.obs[ct_key].unique(): 132 | if not is_cell_specific: 133 | ss_matrix, cats, suspension_cats, meta_ = return_distances_cluster_rep( 134 | rep[cluster], metric=METRIC 135 | ) 136 | else: 137 | good_cells = adata.obs[ct_key] == cluster 138 | ss_matrix, cats, suspension_cats, meta_ = return_distances_cell_specific( 139 | rep, good_cells, metric=METRIC 140 | ) 141 | 142 | suspension_cats_oh = 
OneHotEncoder(sparse=False).fit_transform(suspension_cats) 143 | where_similar_sus = suspension_cats_oh @ suspension_cats_oh.T 144 | where_similar_sus = where_similar_sus.astype(bool) 145 | 146 | cats_oh = OneHotEncoder(sparse=False).fit_transform(cats) 147 | where_similar = cats_oh @ cats_oh.T 148 | n_donors = where_similar.shape[0] 149 | # where_similar = (where_similar - np.eye(where_similar.shape[0])).astype(bool) 150 | where_similar = where_similar.astype(bool) 151 | print(where_similar.shape) 152 | offdiag = ~np.eye(where_similar.shape[0], dtype=bool) 153 | 154 | if "library_uuid" not in meta_.columns: 155 | meta_ = meta_.reset_index() 156 | library_names = np.concatenate( 157 | np.array( 158 | [n_donors * [lib_name] for lib_name in meta_["library_uuid"].values] 159 | )[None] 160 | ) 161 | 162 | new_vals = pd.DataFrame( 163 | dict( 164 | dist=ss_matrix.reshape(-1), 165 | is_similar=where_similar.reshape(-1), 166 | has_similar_suspension=where_similar_sus.reshape(-1), 167 | library_name1=library_names.reshape(-1), 168 | library_name2=library_names.T.reshape(-1), 169 | ) 170 | ).assign(model=model_params["model_name"], cluster=cluster) 171 | _dd_plot_df = pd.concat([_dd_plot_df, new_vals], axis=0) 172 | 173 | # %% 174 | # Construct metadata 175 | final_meta = metad.copy() 176 | final_meta.loc[:, "donor_name"] = metad["sample"].str[:3] 177 | 178 | lib_to_sample_name = pd.Series( 179 | { 180 | "24723d89-8db6-4e5b-a227-5805b49bb8e6": "C41", 181 | "4059d4aa-b0d5-4b88-92f3-f5623e744c2f": "C58_TST", 182 | "7bdadd5c-74cc-4aef-9baf-cd2a75382a0c": "C58_RESEQ", 183 | "7ec5239b-b687-46ea-9c6b-9e2ea970ba21": "C72_RESEQ", 184 | "7ec5239b-b687-46ea-9c6b-9e2ea970ba21_split": "C72_RESEQ_split", 185 | "c557eece-31dc-4825-83c6-7af195076696": "C41_TST", 186 | "d62041ea-a566-4b7b-8280-2a8e5f776270": "C70_RESEQ", 187 | "da945071-1938-4ed5-b0fb-bcf5eef6f92f": "C70_TST", 188 | "f4a052f1-ffd8-4372-ae45-777811d945ee": "C72_TST", 189 | } 190 | ).to_frame("sample_name") 191 | final_meta = final_meta.merge( 192 | lib_to_sample_name, left_on="library_uuid", right_index=True 193 | ).assign( 194 | sample_name1=lambda x: x["sample_name"], 195 | sample_name2=lambda x: x["sample_name"], 196 | donor_name1=lambda x: x["donor_name"], 197 | donor_name2=lambda x: x["donor_name"], 198 | ) 199 | 200 | 201 | dd_plot_df_full = ( 202 | _dd_plot_df.merge( 203 | final_meta.loc[:, ["library_uuid", "sample_name1", "donor_name1"]], 204 | left_on="library_name1", 205 | right_on="library_uuid", 206 | how="left", 207 | ) 208 | .merge( 209 | final_meta.loc[:, ["library_uuid", "sample_name2", "donor_name2"]], 210 | left_on="library_name2", 211 | right_on="library_uuid", 212 | how="left", 213 | ) 214 | .assign( 215 | sample_name1=lambda x: pd.Categorical( 216 | x["sample_name1"], categories=np.sort(x.sample_name1.unique()) 217 | ), 218 | sample_name2=lambda x: pd.Categorical( 219 | x["sample_name2"], categories=np.sort(x.sample_name2.unique()) 220 | ), 221 | donor_name1=lambda x: pd.Categorical( 222 | x["donor_name1"], categories=np.sort(x.donor_name1.unique()) 223 | ), 224 | donor_name2=lambda x: pd.Categorical( 225 | x["donor_name2"], categories=np.sort(x.donor_name2.unique()) 226 | ), 227 | ) 228 | ) 229 | dd_plot_df_full.loc[:, "sample_name2_r"] = pd.Categorical( 230 | dd_plot_df_full["sample_name2"].values, 231 | categories=dd_plot_df_full["sample_name2"].cat.categories[::-1], 232 | ) 233 | dd_plot_df = dd_plot_df_full.loc[lambda x: x.library_name1 != x.library_name2] 234 | 235 | 236 | # %% 237 | for model in 
dd_plot_df.model.unique(): 238 | plot_ = dd_plot_df_full.query("cluster == 'periportal region hepatocyte'").loc[ 239 | lambda x: x.model == model 240 | ] 241 | vmin, vmax = np.quantile(plot_.dist, [0.2, 0.8]) 242 | plot_.loc[:, "dist_clip"] = np.clip(plot_.dist, vmin, vmax) 243 | fig = ( 244 | p9.ggplot(p9.aes(x="sample_name1", y="sample_name2_r")) 245 | + p9.geom_raster(plot_, p9.aes(fill="dist_clip")) 246 | + p9.geom_tile( 247 | plot_.query("is_similar"), 248 | color="#ff3f05", 249 | fill="none", 250 | size=2, 251 | ) 252 | + p9.theme_void() 253 | + p9.theme( 254 | axis_ticks=p9.element_blank(), 255 | ) 256 | + p9.scale_y_discrete() 257 | + p9.labs( 258 | title=model, 259 | x="", 260 | y="", 261 | ) 262 | + p9.coord_flip() 263 | + p9.scale_fill_cmap("viridis") 264 | ) 265 | fig.save("figures/snrna_heatmaps_{}.svg".format(model)) 266 | fig.draw() 267 | 268 | 269 | # %% 270 | gp1 = dd_plot_df.groupby(["model", "cluster", "donor_name1"]) 271 | v1 = gp1.apply(lambda x: x.query("is_similar").dist.median()) 272 | v2 = gp1.apply( 273 | lambda x: x.loc[lambda x: x.has_similar_suspension & ~(x.is_similar)].dist.median() 274 | ) 275 | 276 | min_dist = ( 277 | dd_plot_df.query("sample_name1 == 'C72_RESEQ'") 278 | .query("sample_name2 == 'C72_RESEQ_split'") 279 | .groupby(["model", "cluster"]) 280 | .dist.mean() 281 | .to_frame("min_dist") 282 | .reset_index() 283 | # averages the two pairwise distances 284 | ) 285 | v2_s = v2.to_frame("denom_dist").reset_index().query("donor_name1 == 'C72'") 286 | lower_bound_ratio = min_dist.merge(v2_s, on=["model", "cluster"]).assign( 287 | lower_bound_ratio=lambda x: x.min_dist / x.denom_dist 288 | ) 289 | 290 | 291 | ratio = (v1 / v2).to_frame("ratio").reset_index().dropna() 292 | 293 | SUBSELECTED_MODELS = [ 294 | "MrVILinearLinear10", 295 | "CompositionPCA", 296 | "CompositionSCVI", 297 | ] 298 | ratio_plot = ratio.loc[lambda x: x.model.isin(SUBSELECTED_MODELS)] 299 | lower_bound_ratio_plot = lower_bound_ratio.loc[ 300 | lambda x: x.model.isin(SUBSELECTED_MODELS) 301 | ] 302 | 303 | fig = ( 304 | p9.ggplot(ratio_plot, p9.aes(x="model", y="ratio", fill="model")) 305 | + p9.geom_boxplot() 306 | + p9.theme_classic() 307 | + p9.coord_flip() 308 | + p9.theme(legend_position="none") 309 | # + p9.labs(y="Ratio of mean distance between similar over dissimilar samples", x="") 310 | ) 311 | fig 312 | 313 | # %% 314 | 315 | 316 | ( 317 | p9.ggplot(ratio_plot, p9.aes(x="ratio", color="model")) 318 | + p9.stat_ecdf() 319 | + p9.xlim(0, 2) 320 | + p9.labs(y="ECDF", x="Ratio of mean distances") 321 | ) 322 | 323 | # %% 324 | vplot = pd.DataFrame( 325 | dict( 326 | ratio=[ 327 | 1, 328 | ], 329 | ) 330 | ) 331 | 332 | fig = ( 333 | p9.ggplot(ratio_plot, p9.aes(y="ratio", x="model", fill="model")) 334 | + p9.facet_wrap("donor_name1") 335 | + p9.geom_boxplot() 336 | + p9.geom_hline(p9.aes(yintercept="ratio"), data=vplot, linetype="dashed", size=1) 337 | + p9.coord_flip() 338 | + p9.theme_classic() 339 | + p9.theme( 340 | legend_position="none", 341 | strip_background=p9.element_blank(), 342 | aspect_ratio=0.6, 343 | axis_ticks_major_y=p9.element_blank(), 344 | axis_title=p9.element_text(size=15), 345 | strip_text=p9.element_text(size=15), 346 | axis_text=p9.element_text(size=15), 347 | ) 348 | + p9.labs(y="Ratio", x="") 349 | ) 350 | fig.save("figures/snrna_ratios.svg") 351 | fig 352 | 353 | 354 | # %% 355 | for donor_name in ratio_plot.donor_name1.unique(): 356 | subs = ratio_plot.query("donor_name1 == @donor_name") 357 | pop1 = subs.query("model == 
'MrVILinearLinear10'").ratio 358 | pop2 = subs.query("model == 'CompositionSCVI'").ratio 359 | pop3 = subs.query("model == 'CompositionPCA'").ratio 360 | 361 | print(donor_name) 362 | pval21 = stats.ttest_rel(pop1, pop2, alternative="less").pvalue 363 | print("mrVI vs CompositionSCVI", pval21) 364 | 365 | pval31 = stats.ttest_rel(pop1, pop3, alternative="less").pvalue 366 | print("mrVI vs CPCA", pval31) 367 | print() 368 | 369 | 370 | # %% 371 | MODELS 372 | 373 | 374 | def return_distances_cell_specific(rep, good_cells, metric="cosine"): 375 | good_cells = adata.obs[ct_key] == cluster 376 | rep_ct = rep[good_cells] 377 | subobs = adata.obs[good_cells] 378 | 379 | observed_donors = subobs[sample_key].value_counts() 380 | observed_donors = observed_donors[observed_donors >= 1].index 381 | meta_ = metad.reset_index() 382 | good_d_idx = meta_.loc[lambda x: x[sample_key].isin(observed_donors)].index.values 383 | ss_matrix = compute_aggregate_dmat(rep_ct[:, good_d_idx, :], metric=metric) 384 | cats = meta_.loc[good_d_idx][bio_group_key].values[:, None] 385 | suspension_cats = meta_.loc[good_d_idx][techno_key].values[:, None] 386 | return ss_matrix, cats, suspension_cats, meta_.loc[good_d_idx] 387 | 388 | 389 | # %% 390 | _dd_plot_df = pd.DataFrame() 391 | 392 | MODELS = [ 393 | dict(model_name="CompositionPCA", cell_specific=False), 394 | dict(model_name="CompositionSCVI", cell_specific=False), 395 | dict(model_name="MrVILinearLinear10", cell_specific=True), 396 | ] 397 | 398 | 399 | rdm_idx = np.random.choice(adata.n_obs, 10000, replace=False) 400 | 401 | for model_params in MODELS: 402 | rep_key = "{}_local_donor_rep".format(model_params["model_name"]) 403 | metadata_key = "{}_local_donor_rep_metadata".format(model_params["model_name"]) 404 | is_cell_specific = model_params["cell_specific"] 405 | rep = adata.uns[rep_key] 406 | 407 | rep_ = rep[rdm_idx] 408 | ss_matrix = compute_aggregate_dmat(rep_, metric=METRIC) 409 | ds = [] 410 | donors = metad.Sample_id.cat.codes.values[..., None] 411 | select_num = OneHotEncoder(sparse=False).fit_transform(donors) 412 | where_same_d = select_num @ select_num.T 413 | 414 | donors = metad.Sample_id.cat.codes.values[..., None] 415 | select_num = OneHotEncoder(sparse=False).fit_transform(donors) 416 | where_same_d = select_num @ select_num.T 417 | for x in tqdm(rep_): 418 | d_ = pairwise_distances(x, metric=METRIC) 419 | 420 | # %% 421 | cell_reps = [ 422 | "MrVILinearLinear10_cell", 423 | "SCVI_cell", 424 | ] 425 | 426 | mixing_df = [] 427 | 428 | for rep in cell_reps: 429 | algo_name = rep.split("_")[0] 430 | batch_aws = silhouette_batch( 431 | adata, "suspension_type", "author_cell_type", rep, verbose=False 432 | ) 433 | sample_aws = silhouette_batch( 434 | adata, "library_uuid", "author_cell_type", rep, verbose=False 435 | ) 436 | ct_asw = silhouette(adata, "author_cell_type", rep) 437 | 438 | mixing_df.append( 439 | dict( 440 | batch_aws=batch_aws, 441 | sample_aws=sample_aws, 442 | algo_name=algo_name, 443 | ct_asw=ct_asw, 444 | ) 445 | ) 446 | mixing_df = pd.DataFrame(mixing_df) 447 | # %% 448 | mixing_df 449 | # %% 450 | for u_key in [ 451 | "MrVILinearLinear10_cell", 452 | ]: 453 | sc.pp.neighbors(adata, n_neighbors=15, use_rep=u_key) 454 | sc.tl.umap(adata) 455 | 456 | savename = "_".join([dataset_name, "u_sample_suspension_type"]) 457 | savename += ".png" 458 | with plt.rc_context({"figure.dpi": 500}): 459 | sc.pl.umap( 460 | adata, 461 | color=["Sample_id", "suspension_type", "author_cell_type"], 462 | title=u_key, 463 | save=savename, 464 | ) 
465 | -------------------------------------------------------------------------------- /workflow/notebooks/synthetic.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import glob 3 | import os 4 | import re 5 | from collections import defaultdict 6 | from itertools import product 7 | 8 | import ete3 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import pandas as pd 12 | import scanpy as sc 13 | import seaborn as sns 14 | from scipy.cluster.hierarchy import dendrogram, linkage, to_tree 15 | from sklearn.metrics import average_precision_score, pairwise_distances, precision_score 16 | from statsmodels.stats.multitest import multipletests 17 | from tqdm import tqdm 18 | 19 | # %% 20 | workflow_dir = "../" 21 | base_dir = os.path.join(workflow_dir, "results/synthetic") 22 | full_data_path = os.path.join(workflow_dir, "data/synthetic/adata.processed.h5ad") 23 | input_files = glob.glob(os.path.join(base_dir, "final_adata*")) 24 | 25 | figure_dir = os.path.join(workflow_dir, "figures/synthetic/") 26 | 27 | # %% 28 | # Ground truth similarity matrix 29 | meta_corr = [0, 0.5, 0.9] 30 | donor_combos = list(product(*[[0, 1] for _ in range(len(meta_corr))])) 31 | 32 | # E[||x - y||_2^2] 33 | meta_dist = 2 - 2 * np.array(meta_corr) 34 | dist_mtx = np.zeros((len(donor_combos), len(donor_combos))) 35 | for i in range(len(donor_combos)): 36 | for j in range(i): 37 | donor_i, donor_j = donor_combos[i], donor_combos[j] 38 | donor_diff = abs(np.array(donor_i) - np.array(donor_j)) 39 | dist_mtx[i, j] = dist_mtx[j, i] = np.sum(donor_diff * meta_dist) 40 | dist_mtx = np.sqrt(dist_mtx) # E[||x - y||_2] 41 | 42 | donor_replicates = 4 43 | 44 | gt_donor_combos = sum([donor_replicates * [str(dc)] for dc in donor_combos], []) 45 | gt_dist_mtx = np.zeros((len(gt_donor_combos), len(gt_donor_combos))) 46 | for i in range(len(gt_donor_combos)): 47 | for j in range(i + 1): 48 | gt_dist_mtx[i, j] = gt_dist_mtx[j, i] = dist_mtx[ 49 | i // donor_replicates, j // donor_replicates 50 | ] 51 | gt_control_dist_mtx = 1 - np.eye(len(gt_donor_combos)) 52 | 53 | # %% 54 | model_adatas = defaultdict(list) 55 | 56 | mapper = None 57 | for file in tqdm(input_files): 58 | adata_ = sc.read_h5ad(file) 59 | file_re = re.match(r".*final_adata_([\w\d]+)_(\d+).h5ad", file) 60 | if file_re is None: 61 | continue 62 | model_name, seed = file_re.groups() 63 | print(model_name, seed) 64 | 65 | uns_keys = list(adata_.uns.keys()) 66 | for uns_key in uns_keys: 67 | uns_vals = adata_.uns[uns_key] 68 | if uns_key.endswith("local_donor_rep"): 69 | meta_keys = [ 70 | key 71 | for key in adata_.uns.keys() 72 | if key.endswith("_local_donor_rep_metadata") 73 | ] 74 | print(uns_key) 75 | if len(meta_keys) != 0: 76 | # SORT UNS by donor_key values 77 | meta_key = meta_keys[0] 78 | print(uns_key, meta_key) 79 | mapper = adata_.uns["mapper"] 80 | donor_key = mapper["donor_key"] 81 | metad = adata_.uns[meta_key].reset_index().sort_values(donor_key) 82 | # print(metad.index.values) 83 | uns_vals = uns_vals[:, metad.index.values] 84 | else: 85 | uns_vals = adata_.uns[uns_key] 86 | for key in uns_vals: 87 | uns_vals[key] = uns_vals[key].sort_index() 88 | # print(uns_vals[key].index) 89 | 90 | adata_.uns[uns_key] = uns_vals 91 | model_adatas[model_name].append(dict(adata=adata_, seed=seed)) 92 | 93 | # %% 94 | ct_key = mapper["cell_type_key"] 95 | sample_key = mapper["donor_key"] 96 | 97 | # %% 98 | # Unsupervised analysis 99 | MODELS = [ 100 | dict(model_name="CompositionPCA", cell_specific=False), 
101 | dict(model_name="CompositionSCVI", cell_specific=False), 102 | dict(model_name="MrVISmall", cell_specific=True), 103 | dict(model_name="MrVILinear", cell_specific=True), 104 | dict(model_name="MrVILinear50", cell_specific=True), 105 | dict(model_name="MrVILinear10COMP", cell_specific=True), 106 | dict(model_name="MrVILinear50COMP", cell_specific=True), 107 | ] 108 | 109 | 110 | # %% 111 | def compute_aggregate_dmat(reps): 112 | # return pairwise_distances(reps.mean(0)) 113 | n_cells, n_donors, _ = reps.shape 114 | pairwise_ds = np.zeros((n_donors, n_donors)) 115 | for x in tqdm(reps): 116 | d_ = pairwise_distances(x, metric=METRIC) 117 | pairwise_ds += d_ / n_cells 118 | return pairwise_ds 119 | 120 | 121 | # %% 122 | METRIC = "euclidean" 123 | 124 | dist_mtxs = defaultdict(list) 125 | for model_params in MODELS: 126 | model_name = model_params["model_name"] 127 | rep_key = f"{model_name}_local_donor_rep" 128 | metadata_key = f"{model_name}_local_donor_rep_metadata" 129 | is_cell_specific = model_params["cell_specific"] 130 | 131 | for model_res in model_adatas[model_name]: 132 | adata, seed = model_res["adata"], model_res["seed"] 133 | if rep_key not in adata.uns: 134 | continue 135 | rep = adata.uns[rep_key] 136 | 137 | print(model_name) 138 | for cluster in adata.obs[ct_key].unique(): 139 | if not is_cell_specific: 140 | rep_ct = rep[cluster] 141 | ss_matrix = pairwise_distances(rep_ct.values, metric=METRIC) 142 | cats = metad.set_index(sample_key).loc[rep_ct.index].index.values 143 | else: 144 | good_cells = adata.obs[ct_key] == cluster 145 | rep_ct = rep[good_cells] 146 | subobs = adata.obs[good_cells] 147 | observed_donors = subobs[sample_key].value_counts() 148 | observed_donors = observed_donors[observed_donors >= 1].index 149 | good_d_idx = metad.loc[ 150 | lambda x: x[sample_key].isin(observed_donors) 151 | ].index.values 152 | 153 | ss_matrix = compute_aggregate_dmat(rep_ct[:, good_d_idx, :]) 154 | cats = metad.loc[lambda x: x[sample_key].isin(observed_donors)][ 155 | sample_key 156 | ].values 157 | 158 | dist_mtxs[model_name].append( 159 | dict(dist_matrix=ss_matrix, cats=cats, seed=seed, ct=cluster) 160 | ) 161 | 162 | # %% 163 | dist_mtxs["GroundTruth"] = [ 164 | dict(dist_matrix=gt_dist_mtx, cats=gt_donor_combos, seed=None, ct="CT1:1"), 165 | dict(dist_matrix=gt_control_dist_mtx, cats=gt_donor_combos, seed=None, ct="CT2:1"), 166 | ] 167 | 168 | 169 | # %% 170 | # https://stackoverflow.com/questions/9364609/converting-ndarray-generated-by-hcluster-into-a-newick-string-for-use-with-ete2/17657426#17657426 171 | def linkage_to_ete(linkage_obj): 172 | R = to_tree(linkage_obj) 173 | root = ete3.Tree() 174 | root.dist = 0 175 | root.name = "root" 176 | item2node = {R.get_id(): root} 177 | to_visit = [R] 178 | 179 | while to_visit: 180 | node = to_visit.pop() 181 | cl_dist = node.dist / 2.0 182 | 183 | for ch_node in [node.get_left(), node.get_right()]: 184 | if ch_node: 185 | ch_node_id = ch_node.get_id() 186 | ch_node_name = ( 187 | f"t{int(ch_node_id) + 1}" if ch_node.is_leaf() else str(ch_node_id) 188 | ) 189 | ch = ete3.Tree() 190 | ch.dist = cl_dist 191 | ch.name = ch_node_name 192 | 193 | item2node[node.get_id()].add_child(ch) 194 | item2node[ch_node_id] = ch 195 | to_visit.append(ch_node) 196 | return root 197 | 198 | 199 | # %% 200 | ete_trees = defaultdict(list) 201 | for model in dist_mtxs: 202 | for model_dist_mtx in dist_mtxs[model]: 203 | dist_mtx = model_dist_mtx["dist_matrix"] 204 | seed = model_dist_mtx["seed"] 205 | ct = model_dist_mtx["ct"] 206 | cats = 
model_dist_mtx["cats"] 207 | cats = [cat[10:] if cat[:10] == "donor_meta" else cat for cat in cats] 208 | 209 | # Heatmaps 210 | fig, ax = plt.subplots() 211 | im = ax.imshow(dist_mtx) 212 | 213 | # Show all ticks and label them with the respective list entries 214 | ax.set_xticks(np.arange(len(cats)), labels=cats, fontsize=5) 215 | ax.set_yticks(np.arange(len(cats)), labels=cats, fontsize=5) 216 | 217 | # Rotate the tick labels and set their alignment. 218 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") 219 | 220 | cell_type = "Experimental Cell Type" if ct == "CT1:1" else "Control Cell Type" 221 | ax.set_title(f"{model} {cell_type}") 222 | 223 | fig.tight_layout() 224 | fig.savefig(os.path.join(figure_dir, f"{model}_{ct}_{seed}_dist_matrix.svg")) 225 | plt.show() 226 | plt.close() 227 | 228 | # Dendrograms 229 | z = linkage(dist_mtx, method="ward") 230 | 231 | fig, ax = plt.subplots() 232 | dn = dendrogram(z, ax=ax, orientation="top") 233 | ax.set_title(f"{model} {cell_type}") 234 | fig.tight_layout() 235 | fig.savefig(os.path.join(figure_dir, f"{model}_{ct}_{seed}_dendrogram.svg")) 236 | plt.close() 237 | 238 | z_ete = linkage_to_ete(z) 239 | ete_trees[model].append(dict(t=z_ete, ct=ct)) 240 | 241 | # %% 242 | # Boxplot for RF distance 243 | for gt_trees in ete_trees["GroundTruth"]: 244 | if gt_trees["ct"] == "CT1:1": 245 | gt_tree = gt_trees["t"] 246 | elif gt_trees["ct"] == "CT2:1": 247 | gt_control_tree = gt_trees["t"] 248 | else: 249 | continue 250 | 251 | rf_df_rows = [] 252 | for model, ts in ete_trees.items(): 253 | if model == "GroundTruth": 254 | continue 255 | 256 | for t_dict in ts: 257 | t = t_dict["t"] 258 | ct = t_dict["ct"] 259 | if ct == "CT1:1": 260 | rf_dist = gt_tree.robinson_foulds(t) 261 | elif ct == "CT2:1": 262 | rf_dist = gt_control_tree.robinson_foulds(t) 263 | else: 264 | continue 265 | norm_rf = rf_dist[0] / rf_dist[1] 266 | 267 | rf_df_rows.append((model, ct, norm_rf)) 268 | 269 | rf_df = pd.DataFrame(rf_df_rows, columns=["model", "cell_type", "rf_dist"]) 270 | 271 | # %% 272 | # Experimental Cell Type Plot 273 | fig, ax = plt.subplots(figsize=(2, 5)) 274 | sns.barplot( 275 | data=rf_df[rf_df["cell_type"] == "CT1:1"], 276 | x="model", 277 | y="rf_dist", 278 | ax=ax, 279 | ) 280 | sns.swarmplot( 281 | data=rf_df[rf_df["cell_type"] == "CT1:1"], 282 | x="model", 283 | y="rf_dist", 284 | color="black", 285 | ax=ax, 286 | ) 287 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") 288 | ax.set_xticklabels(ax.get_xticklabels(), fontsize=7) 289 | ax.set_xlabel("Robinson-Foulds Distance") 290 | ax.set_ylabel("Model") 291 | ax.set_title("Robinson-Foulds Comparison to Ground Truth for Experimental Cell Type") 292 | fig.tight_layout() 293 | fig.savefig( 294 | os.path.join(figure_dir, "robinson_foulds_CT1:1_boxplot.svg"), bbox_inches="tight" 295 | ) 296 | plt.show() 297 | 298 | 299 | # %% 300 | def correct_for_mult(x): 301 | # print(x.name) 302 | if x.name.startswith("MILO"): 303 | print("MILO already FDR controlled; skipping") 304 | return x 305 | return multipletests(x, method="fdr_bh")[1] 306 | 307 | 308 | padj_dfs = [] 309 | ct_obs = None 310 | for model in model_adatas: 311 | for model_adata in model_adatas[model]: 312 | adata = model_adata["adata"] 313 | if ct_obs is None: 314 | ct_obs = adata.obs["celltype"] 315 | 316 | # Selecting right columns in adata.obs 317 | sig_keys = [col for col in adata.obs.columns if col.endswith("significance")] 318 | # Adjust for multiple testing if needed 319 | 
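        # Each selected column follows the "{model}_{metadata}_{test}_significance"
        # naming written by compute_local_scores.py. MILO columns already hold
        # SpatialFDR values, so correct_for_mult leaves them untouched and
        # BH-adjusts the remaining columns one by one; downstream, thresholding
        # these adjusted p-values at a target FDR and taking 1 - precision against
        # the CT1:1 labels gives the empirical false discovery proportion.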
padj_dfs.append(adata.obs.loc[:, sig_keys].apply(correct_for_mult, axis=0)) 320 | 321 | # %% 322 | # Select right cell subpopulations 323 | padjs = pd.concat(padj_dfs, axis=1) 324 | target = ct_obs == "CT1:1" 325 | 326 | plot_df = ( 327 | padjs.apply( 328 | lambda x: pd.Series( 329 | { 330 | **{ 331 | target_fdr: 1.0 332 | - precision_score(target, x <= target_fdr, zero_division=0) 333 | for target_fdr in [0.05, 0.1, 0.2] 334 | }, 335 | } 336 | ), 337 | axis=0, 338 | ) 339 | .stack() 340 | .to_frame("FDP") 341 | .reset_index() 342 | .rename(columns=dict(level_0="targetFDR", level_1="approach")) 343 | .assign( 344 | model=lambda x: x.approach.str.split("_").str[0], 345 | test=lambda x: x.approach.str.split("_").str[-2], 346 | model_test=lambda x: x.model + " " + x.test, 347 | metadata=lambda x: x.approach.str.split("_") 348 | .str[1:-2] 349 | .apply(lambda y: "_".join(y)), 350 | ) 351 | ) 352 | # %% 353 | # Keep MrVILinear Manova, PCAKNN Manova, SCVI MANOVA, MRVILINEAR50 KS 354 | # MrVILinear KS, PCAKNN KS, SCVI KS, MrVILiear50 KS, MILOSCVI LFC 355 | for metadata in plot_df.metadata.unique(): 356 | fig, ax = plt.subplots() 357 | sns.barplot( 358 | data=plot_df[ 359 | (plot_df["metadata"] == metadata) 360 | & ( 361 | plot_df["model"].isin( 362 | ( 363 | "MrVILinear", 364 | "MrVILinear50", 365 | "PCAKNN", 366 | "SCVI", 367 | "MILOSCVI", 368 | ) 369 | ) 370 | ) 371 | & ( 372 | plot_df["test"].isin( 373 | ( 374 | # "manova", 375 | "ks", 376 | "LFC", 377 | ) 378 | ) 379 | ) 380 | ], 381 | x="targetFDR", 382 | y="FDP", 383 | hue="model_test", 384 | ax=ax, 385 | hue_order=[ 386 | "MrVILinear ks", 387 | "MrVILinear50 ks", 388 | "SCVI ks", 389 | "MILOSCVI LFC", 390 | "PCAKNN ks", 391 | ], 392 | ) 393 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") 394 | ax.set_xticklabels(ax.get_xticklabels(), fontsize=7) 395 | ax.set_xlabel("Target FDR") 396 | ax.set_ylabel("FDP") 397 | ax.set_title(f"False Discovery Rate Comparison ({metadata[6:]})") 398 | sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1)) 399 | 400 | fig.tight_layout() 401 | fig.savefig( 402 | os.path.join(figure_dir, f"FDR_comparison_{metadata}_ks_only.svg"), 403 | bbox_inches="tight", 404 | ) 405 | plt.show() 406 | # %% 407 | # AP Comparison 408 | ap_plot_df = ( 409 | padjs.apply( 410 | lambda x: pd.Series( 411 | { 412 | **{ 413 | target_fdr: average_precision_score(target, x <= target_fdr) 414 | for target_fdr in [0.05, 0.1, 0.2] 415 | }, 416 | } 417 | ), 418 | axis=0, 419 | ) 420 | .stack() 421 | .to_frame("AP") 422 | .reset_index() 423 | .rename(columns=dict(level_0="targetFDR", level_1="approach")) 424 | .assign( 425 | model=lambda x: x.approach.str.split("_").str[0], 426 | test=lambda x: x.approach.str.split("_").str[-2], 427 | model_test=lambda x: x.model + " " + x.test, 428 | metadata=lambda x: x.approach.str.split("_") 429 | .str[1:-2] 430 | .apply(lambda y: "_".join(y)), 431 | ) 432 | ) 433 | 434 | # %% 435 | for metadata in ap_plot_df.metadata.unique(): 436 | fig, ax = plt.subplots() 437 | sns.barplot( 438 | data=ap_plot_df[ 439 | (ap_plot_df["metadata"] == metadata) 440 | & ( 441 | plot_df["model"].isin( 442 | ( 443 | "MrVILinear", 444 | "MrVILinear50", 445 | "PCAKNN", 446 | "SCVI", 447 | "MILOSCVI", 448 | ) 449 | ) 450 | ) 451 | & ( 452 | plot_df["test"].isin( 453 | ( 454 | # "manova", 455 | "ks", 456 | "LFC", 457 | ) 458 | ) 459 | ) 460 | ], 461 | x="targetFDR", 462 | y="AP", 463 | hue="model_test", 464 | ax=ax, 465 | hue_order=[ 466 | "MrVILinear ks", 467 | "MrVILinear50 ks", 468 | "SCVI 
ks", 469 | "MILOSCVI LFC", 470 | "PCAKNN ks", 471 | ], 472 | ) 473 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") 474 | ax.set_xticklabels(ax.get_xticklabels(), fontsize=7) 475 | ax.set_xlabel("Target FDR") 476 | ax.set_ylabel("AP") 477 | ax.set_ylim((0.5, 1.05)) 478 | ax.set_title(f"Average Precision Rate Comparison ({metadata})") 479 | sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1)) 480 | 481 | fig.tight_layout() 482 | fig.savefig( 483 | os.path.join(figure_dir, f"AP_comparison_{metadata}_ks_only.svg"), 484 | bbox_inches="tight", 485 | ) 486 | plt.show() 487 | 488 | # %% 489 | -------------------------------------------------------------------------------- /workflow/scripts/compute_local_scores.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from functools import partial 3 | 4 | import anndata as ad 5 | import numpy as np 6 | import pynndescent 7 | import scanpy as sc 8 | import torch 9 | from tqdm import tqdm 10 | from utils import compute_ks, compute_manova 11 | 12 | 13 | def compute_autocorrelation_metric( 14 | local_representation, 15 | donor_labels, 16 | metric_fn, 17 | has_significance, 18 | batch_size, 19 | minibatched=True, 20 | desc="desc", 21 | ): 22 | # Compute autocorrelation 23 | if minibatched: 24 | scores = [] 25 | pvals = [] 26 | for local_rep in tqdm(local_representation.split(batch_size), desc=desc): 27 | local_rep_ = local_rep.to("cuda") 28 | if has_significance: 29 | score, pval = metric_fn(local_rep_, donor_labels) 30 | scores.append(score) 31 | pvals.append(pval) 32 | else: 33 | scores_ = metric_fn(local_rep_, donor_labels=donor_labels) 34 | scores.append(scores_) 35 | scores = np.concatenate(scores, 0) 36 | pvals = np.concatenate(pvals, 0) if has_significance else None 37 | else: 38 | if has_significance: 39 | scores, pvals = metric_fn( 40 | local_representation.numpy(), donor_labels=donor_labels 41 | ) 42 | else: 43 | scores = metric_fn(local_representation.numpy(), donor_labels=donor_labels) 44 | pvals = None 45 | return scores, pvals 46 | 47 | 48 | def compute_mrvi(adata, model_name, obs_key, batch_size=256, redo=True): 49 | uns_key = "{}_local_donor_rep".format(model_name) 50 | local_donor_rep = adata.uns[uns_key] 51 | if not isinstance(local_donor_rep, np.ndarray): 52 | logging.warn(f"{model_name} local donor rep not compatible with metric.") 53 | return 54 | local_representation = torch.from_numpy(local_donor_rep) 55 | # Assumes local_representation is n_cells, n_donors, n_donors 56 | metadata_key = "{}_local_donor_rep_metadata".format(model_name) 57 | metadata = adata.uns[metadata_key] 58 | donor_labels = metadata[obs_key].values 59 | 60 | # Compute and save various metric scores 61 | _compute_ks = partial(compute_ks, do_smoothing=False) 62 | configs = [ 63 | dict( 64 | metric_name="ks", 65 | metric_fn=_compute_ks, 66 | has_significance=True, 67 | minibatched=True, 68 | ), 69 | dict( 70 | metric_name="manova", 71 | metric_fn=compute_manova, 72 | has_significance=True, 73 | minibatched=False, 74 | ), 75 | ] 76 | 77 | for config in configs: 78 | metric_name = config.pop("metric_name") 79 | scores, pvals = compute_autocorrelation_metric( 80 | local_representation, 81 | donor_labels, 82 | batch_size=batch_size, 83 | **config, 84 | ) 85 | output_key = f"{model_name}_{obs_key}_{metric_name}_score" 86 | sig_key = f"{model_name}_{obs_key}_{metric_name}_significance" 87 | 88 | is_obs = scores.shape[0] == adata.shape[0] 89 | adata.uns[output_key] = scores 90 | if 
pvals is not None: 91 | adata.uns[sig_key] = pvals 92 | if is_obs: 93 | adata.obs[output_key] = scores 94 | if pvals is not None: 95 | adata.obs[sig_key] = pvals 96 | 97 | 98 | def compute_milo(adata, model_name, obs_key): 99 | import milopy.core as milo 100 | 101 | sample_col = adata.uns["nhood_adata"].uns["sample_col"] 102 | if sample_col == obs_key: 103 | logging.warning("Milo cannot run a GLM against the sample col, skipping.") 104 | return 105 | 106 | try: 107 | adata.obs[obs_key] = adata.obs[obs_key].astype(str).astype("category") 108 | design = f"~ {obs_key}" 109 | 110 | # Issue with None valued uns values being dropped. 111 | if "nhood_neighbors_key" not in adata.uns: 112 | adata.uns["nhood_neighbors_key"] = None 113 | milo.DA_nhoods(adata, design=design) 114 | milo_results = adata.uns["nhood_adata"].obs 115 | 116 | except Exception as e: 117 | logging.warning( 118 | f"Skipping test since key {obs_key} may be invalid or model did not complete" 119 | f" with the error message: {e.__class__.__name__} - {str(e)}." 120 | ) 121 | return 122 | 123 | is_index_cell_key = f"{model_name}_{obs_key}_is_index_cell" 124 | is_index_cell = adata.obs.index.isin(milo_results["index_cell"]) 125 | adata.obs[is_index_cell_key] = is_index_cell 126 | 127 | cell_rep = adata.obsm["_cell_rep"] 128 | cell_anchors = cell_rep[is_index_cell] 129 | nn_donor = pynndescent.NNDescent(cell_anchors) 130 | nn_indices = nn_donor.query(cell_rep, k=1)[0].squeeze(-1) 131 | 132 | output_key = f"{model_name}_{obs_key}_LFC_score" 133 | lfc_score = milo_results["logFC"].values[nn_indices] 134 | adata.obs[output_key] = lfc_score 135 | 136 | sig_key = f"{model_name}_{obs_key}_LFC_significance" 137 | lfc_sig = milo_results["SpatialFDR"].values[nn_indices] 138 | adata.obs[sig_key] = lfc_sig 139 | return 140 | 141 | 142 | def compute_metrics(adata, model_name, obs_key, batch_size=256, redo=True): 143 | if model_name.startswith("MILO"): 144 | compute_milo(adata, model_name, obs_key) 145 | else: 146 | compute_mrvi(adata, model_name, obs_key, batch_size, redo) 147 | 148 | 149 | def process_predictions( 150 | model_name, 151 | path_to_h5ad, 152 | path_to_output, 153 | donor_obs_keys, 154 | ): 155 | adata = sc.read_h5ad(path_to_h5ad) 156 | for donor_obs_key in donor_obs_keys: 157 | compute_metrics(adata, model_name, donor_obs_key) 158 | adata_ = ad.AnnData( 159 | obs=adata.obs, 160 | obsm=adata.obsm, 161 | var=adata.var, 162 | uns=adata.uns, 163 | ) 164 | adata_.write(path_to_output) 165 | 166 | 167 | if __name__ == "__main__": 168 | process_predictions( 169 | model_name=snakemake.wildcards.model, 170 | path_to_h5ad=snakemake.input[0], 171 | path_to_output=snakemake.output[0], 172 | donor_obs_keys=snakemake.config[snakemake.wildcards.dataset]["keyMapping"][ 173 | "relevantKeys" 174 | ], 175 | ) 176 | -------------------------------------------------------------------------------- /workflow/scripts/process_data.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import string 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import scanpy as sc 7 | import scipy.sparse as sp 8 | 9 | 10 | def create_joint_obs_key(adata, keys): 11 | joint_key_name = "_".join(keys) 12 | adata.obs[joint_key_name] = ( 13 | adata.obs[keys].astype(str).apply(lambda x: "-".join(x), axis=1) 14 | ) 15 | return joint_key_name 16 | 17 | 18 | def make_categorical(adata, obs_key): 19 | adata.obs[obs_key] = adata.obs[obs_key].astype("category") 20 | 21 | 22 | def create_obs_mapper(adata, dataset_name): 23 | 
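    """Normalize the dataset's obs keys from the Snakemake keyMapping config.

    A multi-column donor key is merged into a single categorical column, the
    cell-type and nuisance columns are made categorical, and the resulting key
    names are recorded in adata.uns["mapper"] for the model wrappers
    (run_model.py) to consume.
    """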
dataset_config = snakemake.config[snakemake.wildcards.dataset]["keyMapping"] 24 | 25 | donor_key = dataset_config["donorKey"] 26 | if isinstance(donor_key, list): 27 | donor_key = create_joint_obs_key(adata, donor_key) 28 | make_categorical(adata, donor_key) 29 | 30 | cell_type_key = dataset_config["cellTypeKey"] 31 | make_categorical(adata, cell_type_key) 32 | 33 | nuisance_keys = dataset_config["nuisanceKeys"] 34 | for nuisance_key in nuisance_keys: 35 | make_categorical(adata, nuisance_key) 36 | 37 | adata.uns["mapper"] = dict( 38 | donor_key=donor_key, 39 | categorical_nuisance_keys=nuisance_keys, 40 | cell_type_key=cell_type_key, 41 | ) 42 | 43 | 44 | def assign_symsim_donors(adata): 45 | np.random.seed(1) 46 | dataset_config = snakemake.config["synthetic"]["keyMapping"] 47 | donor_key = dataset_config["donorKey"] 48 | batch_key = dataset_config["nuisanceKeys"][0] 49 | 50 | n_donors = 32 51 | n_meta = len([k for k in adata.obs.keys() if "meta_" in k]) 52 | meta_keys = [f"meta_{i + 1}" for i in range(n_meta)] 53 | make_categorical(adata, batch_key) 54 | batches = adata.obs[batch_key].cat.categories.tolist() 55 | n_batch = len(batches) 56 | 57 | meta_combos = list(itertools.product([0, 1], repeat=n_meta)) 58 | donors_per_meta_batch_combo = n_donors // len(meta_combos) // n_batch 59 | 60 | # Assign donors uniformly at random for cells with matching metadata. 61 | donor_assignment = np.empty(adata.n_obs, dtype=object) 62 | for batch in batches: 63 | batch_donors = [] 64 | for meta_combo in meta_combos: 65 | match_cats = [f"CT{meta_combo[i]+1}:1" for i in range(n_meta)] 66 | eligible_cell_idxs = ( 67 | ( 68 | np.all( 69 | adata.obs[meta_keys].values == match_cats, 70 | axis=1, 71 | ) 72 | & (adata.obs[batch_key] == batch) 73 | ) 74 | .to_numpy() 75 | .nonzero()[0] 76 | ) 77 | meta_donors = [ 78 | f"donor_meta{meta_combo}_batch{batch}_{ch}" 79 | for ch in string.ascii_lowercase[:donors_per_meta_batch_combo] 80 | ] 81 | donor_assignment[eligible_cell_idxs] = np.random.choice( 82 | meta_donors, replace=True, size=len(eligible_cell_idxs) 83 | ) 84 | batch_donors += meta_donors 85 | 86 | adata.obs[donor_key] = donor_assignment 87 | 88 | donor_meta = adata.obs[donor_key].str.extractall( 89 | r"donor_meta\(([0-1]), ([0-1]), ([0-1])\)_batch[0-9]_[a-z]" 90 | ) 91 | for match_idx, meta_key in enumerate(meta_keys): 92 | adata.obs[f"donor_{meta_key}"] = donor_meta[match_idx].astype(int).tolist() 93 | 94 | 95 | def construct_tree_semisynth(adata, depth_tree=3, dataset_name="semisynthetic"): 96 | """Modifies gene expression in two cell subpopulations according to a controlled 97 | donor tree structure 98 | 99 | """ 100 | # construct donors 101 | n_donors = int(2**depth_tree) 102 | np.random.seed(0) 103 | random_donors = np.random.randint(0, n_donors, adata.n_obs) 104 | n_modules = sum([int(2**k) for k in range(1, depth_tree + 1)]) 105 | 106 | # ct_key = snakemake.config[dataset_name]["keyMapping"]["cellTypeKey"] 107 | # cts = adata.obs.groupby(ct_key).size().sort_values(ascending=False)[:2] 108 | # ct1, ct2 = cts.index.values 109 | ct_key = "leiden" 110 | ct1, ct2 = "0", "1" 111 | 112 | # construct donor trees 113 | leaves_id = np.array( 114 | [format(i, "0{}b".format(depth_tree)) for i in range(n_donors)] 115 | ) # ids of leaves 116 | 117 | all_node_ids = [] 118 | for dep in range(1, depth_tree + 1): 119 | node_ids = [format(i, "0{}b".format(dep)) for i in range(2**dep)] 120 | all_node_ids += node_ids # ids of all nodes in the tree 121 | 122 | def perturb_gene_exp(ct, X_perturbed, all_node_ids, 
leaves_id): 123 | leaves_id1 = leaves_id.copy() 124 | np.random.shuffle(leaves_id1) 125 | genes = np.arange(adata.n_vars) 126 | np.random.shuffle(genes) 127 | gene_modules = np.array_split(genes, n_modules) 128 | gene_modules = { 129 | node_id: gene_modules[i] for i, node_id in enumerate(all_node_ids) 130 | } 131 | gene_modules = { 132 | node_id: np.isin(np.arange(adata.n_vars), gene_modules[node_id]) 133 | for node_id in all_node_ids 134 | } 135 | # convert to one hots to make life easier for saving 136 | 137 | # modifying gene expression 138 | # e.g., 001 has perturbed modules 0, 00, and 001 139 | subpop = adata.obs.loc[:, ct_key].values == ct 140 | print("perturbing {}".format(ct)) 141 | for donor_id in range(n_donors): 142 | selected_pop = subpop & (random_donors == donor_id) 143 | leaf_id = leaves_id1[donor_id] 144 | perturbed_mod_ids = [leaf_id[:i] for i in range(1, depth_tree + 1)] 145 | perturbed_modules = np.zeros(adata.n_vars, dtype=bool) 146 | for id in perturbed_mod_ids: 147 | perturbed_modules = perturbed_modules | gene_modules[id] 148 | 149 | Xmat = X_perturbed[selected_pop].copy() 150 | print( 151 | "Perturbing {} genes in {} cells".format( 152 | perturbed_modules.sum(), selected_pop.sum() 153 | ) 154 | ) 155 | print( 156 | "Non-zero values in the relevant subpopulation and modules: ", 157 | (Xmat[:, perturbed_modules] != 0).sum(), 158 | ) 159 | Xmat[:, perturbed_modules] = Xmat[:, perturbed_modules] * 2 160 | 161 | X_perturbed[selected_pop] = Xmat 162 | return X_perturbed, gene_modules, leaves_id1 163 | 164 | X_pert = adata.X.copy() 165 | X_pert, gene_mod1, leaves1 = perturb_gene_exp(ct1, X_pert, all_node_ids, leaves_id) 166 | X_pert, gene_mod2, leaves2 = perturb_gene_exp(ct2, X_pert, all_node_ids, leaves_id) 167 | 168 | gene_modules1 = pd.DataFrame(gene_mod1) 169 | gene_modules2 = pd.DataFrame(gene_mod2) 170 | donor_metadata = pd.DataFrame( 171 | dict( 172 | donor_id=np.arange(n_donors), 173 | tree_id1=leaves1, 174 | affected_ct1=ct1, 175 | tree_id2=leaves2, 176 | affected_ct2=ct2, 177 | ) 178 | ) 179 | 180 | meta_id1 = pd.DataFrame([list(x) for x in donor_metadata.tree_id1.values]).astype( 181 | int 182 | ) 183 | n_features1 = meta_id1.shape[1] 184 | new_cols1 = ["tree_id1_{}".format(i) for i in range(n_features1)] 185 | donor_metadata.loc[:, new_cols1] = meta_id1.values 186 | meta_id2 = pd.DataFrame([list(x) for x in donor_metadata.tree_id2.values]).astype( 187 | int 188 | ) 189 | n_features2 = meta_id2.shape[1] 190 | new_cols2 = ["tree_id2_{}".format(i) for i in range(n_features2)] 191 | donor_metadata.loc[:, new_cols2] = meta_id2.values 192 | 193 | donor_metadata.loc[:, new_cols1 + new_cols2] = donor_metadata.loc[ 194 | :, new_cols1 + new_cols2 195 | ].astype("category") 196 | 197 | adata.obs.loc[:, "batch"] = random_donors 198 | adata.obs.loc[:, "Site"] = 1 199 | original_index = adata.obs.index.copy() 200 | adata.obs = adata.obs.merge( 201 | donor_metadata, left_on="batch", right_on="donor_id", how="left" 202 | ) 203 | adata.obs.index = original_index 204 | adata.uns["gene_modules1"] = gene_modules1 205 | adata.uns["gene_modules2"] = gene_modules2 206 | adata.uns["donor_metadata"] = donor_metadata 207 | adata = sc.AnnData(X_pert, obs=adata.obs, var=adata.var, uns=adata.uns) 208 | return adata 209 | 210 | 211 | def process_snrna(adata): 212 | # We first remove samples processed with specific protocols that 213 | # are not used in other samples 214 | select_cell = ~(adata.obs["sample"].isin(["C41_CST", "C41_NST"])) 215 | adata = adata[select_cell].copy() 216 | 217 | # 
We also artificially subsample 218 | # one of the samples into two samples 219 | # as a negative control 220 | subplit_sample = "C72_RESEQ" 221 | mask_selected_lib = (adata.obs["sample"] == subplit_sample).values 222 | np.random.seed(0) 223 | mask_split = ( 224 | np.random.randint(0, 2, size=mask_selected_lib.shape[0]).astype(bool) 225 | * mask_selected_lib 226 | ) 227 | libraries = adata.obs["library_uuid"].astype(str) 228 | libraries.loc[mask_split] = libraries.loc[mask_split] + "_split" 229 | assert libraries.unique().shape[0] == 9 230 | adata.obs.loc[:, "library_uuid"] = libraries.astype("category") 231 | return adata 232 | 233 | 234 | def process_dataset(dataset_name, input_h5ad, output_h5ad): 235 | adata = sc.read_h5ad(input_h5ad) 236 | 237 | if isinstance(adata.X, sp.csc_matrix): 238 | adata.X = adata.X.tocsr() 239 | 240 | dataset_info = snakemake.config[dataset_name] 241 | preprocessing_kwargs = dataset_info.get("preprocessing") 242 | if preprocessing_kwargs is not None: 243 | if "subsample" in preprocessing_kwargs: 244 | sc.pp.subsample(adata, **preprocessing_kwargs.get("subsample")) 245 | if "filter_genes" in preprocessing_kwargs: 246 | sc.pp.filter_genes(adata, **preprocessing_kwargs.get("filter_genes")) 247 | if "highly_variable_genes" in preprocessing_kwargs: 248 | sc.pp.highly_variable_genes( 249 | adata, **preprocessing_kwargs.get("highly_variable_genes") 250 | ) 251 | adata = adata[:, adata.var.highly_variable].copy() 252 | if dataset_name == "synthetic": 253 | assign_symsim_donors(adata) 254 | elif dataset_name == "semisynthetic": 255 | adata = construct_tree_semisynth( 256 | adata, 257 | depth_tree=4, 258 | ) 259 | elif dataset_name == "snrna": 260 | adata = process_snrna(adata) 261 | 262 | create_obs_mapper(adata, dataset_name) 263 | 264 | adata.write(output_h5ad) 265 | 266 | 267 | if __name__ == "__main__": 268 | process_dataset( 269 | dataset_name=snakemake.wildcards.dataset, 270 | input_h5ad=snakemake.input[0], 271 | output_h5ad=snakemake.output[0], 272 | ) 273 | -------------------------------------------------------------------------------- /workflow/scripts/run_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import anndata as ad 5 | import numpy as np 6 | import scanpy as sc 7 | from utils import MILO, PCAKNN, CompositionBaseline, MrVIWrapper, SCVIModel 8 | 9 | profile = snakemake.config[snakemake.wildcards.dataset] 10 | N_EPOCHS = profile["nEpochs"] 11 | if "batchSize" in profile: 12 | BATCH_SIZE = profile["batchSize"] 13 | else: 14 | BATCH_SIZE = 256 15 | 16 | MRVI_BASE_MODEL_KWARGS = dict( 17 | observe_library_sizes=True, 18 | # n_latent_donor=5, 19 | px_kwargs=dict(n_hidden=32), 20 | pz_kwargs=dict( 21 | n_layers=1, 22 | n_hidden=32, 23 | ), 24 | ) 25 | MRVI_BASE_TRAIN_KWARGS = dict( 26 | max_epochs=N_EPOCHS, 27 | early_stopping=True, 28 | early_stopping_patience=15, 29 | check_val_every_n_epoch=1, 30 | batch_size=BATCH_SIZE, 31 | train_size=0.9, 32 | ) 33 | 34 | MODELS = dict( 35 | PCAKNN=(PCAKNN, dict()), 36 | MILO=(MILO, dict(model_kwargs=dict(embedding="mnn"))), 37 | MILOSCVI=( 38 | MILO, 39 | dict( 40 | model_kwargs=dict( 41 | embedding="scvi", 42 | dropout_rate=0.0, 43 | dispersion="gene", 44 | gene_likelihood="nb", 45 | ), 46 | train_kwargs=dict( 47 | batch_size=BATCH_SIZE, 48 | plan_kwargs=dict(lr=1e-2, n_epochs_kl_warmup=20), 49 | max_epochs=N_EPOCHS, 50 | early_stopping=True, 51 | early_stopping_patience=15, 52 | ), 53 | ), 54 | ), 55 | CompositionPCA=( 56 | 
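        # This entry and CompositionSCVI below both wrap CompositionBaseline;
        # they differ only in the dimensionality-reduction backend ("PCA" vs
        # "SCVI") applied before the per-cluster donor compositions are computed.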
CompositionBaseline, 57 | dict( 58 | model_kwargs=dict( 59 | dim_reduction_approach="PCA", 60 | n_dim=50, 61 | clustering_on="celltype", 62 | ), 63 | train_kwargs=None, 64 | ), 65 | ), 66 | CompositionSCVI=( 67 | CompositionBaseline, 68 | dict( 69 | model_kwargs=dict( 70 | dim_reduction_approach="SCVI", 71 | n_dim=10, 72 | clustering_on="celltype", 73 | ), 74 | train_kwargs=dict( 75 | batch_size=BATCH_SIZE, 76 | plan_kwargs=dict(lr=1e-2), 77 | max_epochs=N_EPOCHS, 78 | early_stopping=True, 79 | early_stopping_patience=15, 80 | ), 81 | ), 82 | ), 83 | SCVI=( 84 | SCVIModel, 85 | dict( 86 | model_kwargs=dict( 87 | dropout_rate=0.0, 88 | feed_nuisance=False, 89 | dispersion="gene", 90 | gene_likelihood="nb", 91 | ), 92 | train_kwargs=dict( 93 | batch_size=BATCH_SIZE, 94 | plan_kwargs=dict(lr=1e-2, n_epochs_kl_warmup=20), 95 | max_epochs=N_EPOCHS, 96 | early_stopping=True, 97 | early_stopping_patience=15, 98 | ), 99 | ), 100 | ), 101 | MrVISmall=( 102 | MrVIWrapper, 103 | dict( 104 | model_kwargs=dict( 105 | linear_decoder_zx=False, 106 | **MRVI_BASE_MODEL_KWARGS, 107 | ), 108 | train_kwargs=dict( 109 | plan_kwargs=dict( 110 | lr=1e-2, n_epochs_kl_warmup=20, do_comp=True, lambd=1.0 111 | ), 112 | **MRVI_BASE_TRAIN_KWARGS, 113 | ), 114 | ), 115 | ), 116 | MrVILinear=( 117 | MrVIWrapper, 118 | dict( 119 | model_kwargs=dict( 120 | linear_decoder_zx=True, 121 | n_latent=10, 122 | **MRVI_BASE_MODEL_KWARGS, 123 | ), 124 | train_kwargs=dict( 125 | plan_kwargs=dict( 126 | lr=1e-2, n_epochs_kl_warmup=20, do_comp=False, lambd=1.0 127 | ), 128 | **MRVI_BASE_TRAIN_KWARGS, 129 | ), 130 | ), 131 | ), 132 | MrVILinear50=( 133 | MrVIWrapper, 134 | dict( 135 | model_kwargs=dict( 136 | linear_decoder_zx=True, 137 | n_latent=50, 138 | n_latent_donor=2, 139 | **MRVI_BASE_MODEL_KWARGS, 140 | ), 141 | train_kwargs=dict( 142 | plan_kwargs=dict( 143 | lr=1e-2, n_epochs_kl_warmup=20, do_comp=False, lambd=1.0 144 | ), 145 | **MRVI_BASE_TRAIN_KWARGS, 146 | ), 147 | ), 148 | ), 149 | MrVILinear50COMP=( 150 | MrVIWrapper, 151 | dict( 152 | model_kwargs=dict( 153 | linear_decoder_zx=True, 154 | n_latent=50, 155 | n_latent_donor=2, 156 | **MRVI_BASE_MODEL_KWARGS, 157 | ), 158 | train_kwargs=dict( 159 | plan_kwargs=dict( 160 | lr=1e-2, 161 | n_epochs_kl_warmup=20, 162 | do_comp=True, 163 | lambd=0.1, 164 | ), 165 | **MRVI_BASE_TRAIN_KWARGS, 166 | ), 167 | ), 168 | ), 169 | MrVILinear10COMP=( 170 | MrVIWrapper, 171 | dict( 172 | model_kwargs=dict( 173 | linear_decoder_zx=True, 174 | n_latent=10, 175 | n_latent_donor=2, 176 | **MRVI_BASE_MODEL_KWARGS, 177 | ), 178 | train_kwargs=dict( 179 | plan_kwargs=dict( 180 | lr=1e-2, 181 | n_epochs_kl_warmup=20, 182 | do_comp=True, 183 | lambd=0.1, 184 | ), 185 | **MRVI_BASE_TRAIN_KWARGS, 186 | ), 187 | ), 188 | ), 189 | MrVILinearLinear10COMP=( 190 | MrVIWrapper, 191 | dict( 192 | model_kwargs=dict( 193 | linear_decoder_zx=True, 194 | linear_decoder_uz=True, 195 | n_latent=10, 196 | n_latent_donor=2, 197 | **MRVI_BASE_MODEL_KWARGS, 198 | ), 199 | train_kwargs=dict( 200 | plan_kwargs=dict( 201 | lr=1e-2, 202 | n_epochs_kl_warmup=20, 203 | do_comp=True, 204 | lambd=0.1, 205 | ), 206 | **MRVI_BASE_TRAIN_KWARGS, 207 | ), 208 | ), 209 | ), 210 | MrVILinearLinear10=( 211 | MrVIWrapper, 212 | dict( 213 | model_kwargs=dict( 214 | linear_decoder_zx=True, 215 | linear_decoder_uz=True, 216 | n_latent=10, 217 | n_latent_donor=2, 218 | **MRVI_BASE_MODEL_KWARGS, 219 | ), 220 | train_kwargs=dict( 221 | plan_kwargs=dict( 222 | lr=1e-2, 223 | n_epochs_kl_warmup=20, 224 | do_comp=False, 225 | lambd=0.1, 
226 | ), 227 | **MRVI_BASE_TRAIN_KWARGS, 228 | ), 229 | ), 230 | ), 231 | MrVILinearLinear10SCALER=( 232 | MrVIWrapper, 233 | dict( 234 | model_kwargs=dict( 235 | linear_decoder_zx=True, 236 | linear_decoder_uz=True, 237 | linear_decoder_uz_scaler=True, 238 | n_latent=10, 239 | n_latent_donor=2, 240 | **MRVI_BASE_MODEL_KWARGS, 241 | ), 242 | train_kwargs=dict( 243 | plan_kwargs=dict( 244 | lr=1e-2, 245 | n_epochs_kl_warmup=20, 246 | do_comp=False, 247 | lambd=0.1, 248 | ), 249 | **MRVI_BASE_TRAIN_KWARGS, 250 | ), 251 | ), 252 | ), 253 | ) 254 | 255 | 256 | def compute_model_predictions(model_name, path_to_h5ad, path_to_output, random_seed): 257 | np.random.seed(random_seed) 258 | adata = sc.read_h5ad(path_to_h5ad) 259 | logging.info("adata shape: {}".format(adata.shape)) 260 | mapper = adata.uns["mapper"] 261 | algo_cls, algo_kwargs = MODELS[model_name] 262 | model = algo_cls(adata=adata, **algo_kwargs, **mapper) 263 | model.fit() 264 | if model.has_donor_representation: 265 | rep = model.get_donor_representation().assign(model=model_name) 266 | rep.columns = rep.columns.astype(str) 267 | adata.uns["{}_donor".format(model_name)] = rep 268 | if model.has_cell_representation: 269 | repb = model.get_cell_representation() 270 | adata.obsm["{}_cell".format(model_name)] = repb 271 | if model.has_local_donor_representation: 272 | _adata = None 273 | scores = model.get_local_sample_representation(adata=_adata) 274 | adata.uns["{}_local_donor_rep".format(model_name)] = scores 275 | if hasattr(model, "get_donor_representation_metadata"): 276 | metadata = model.get_donor_representation_metadata() 277 | adata.uns["{}_local_donor_rep_metadata".format(model_name)] = metadata 278 | if model.has_custom_representation: 279 | adata = model.compute() 280 | 281 | adata.uns["model_name"] = model_name 282 | adata.X = None 283 | adata_ = ad.AnnData( 284 | obs=adata.obs, 285 | var=adata.var, 286 | obsm=adata.obsm, 287 | uns=adata.uns, 288 | ) 289 | adata_.write_h5ad(path_to_output) 290 | if model.has_save: 291 | dir_path = os.path.dirname(path_to_output) 292 | model_dir_path = os.path.join(dir_path, model_name) 293 | model.save(model_dir_path) 294 | 295 | 296 | if __name__ == "__main__": 297 | compute_model_predictions( 298 | model_name=snakemake.wildcards.model, 299 | path_to_h5ad=snakemake.input[0], 300 | path_to_output=snakemake.output[0], 301 | random_seed=int(snakemake.wildcards.seed), 302 | ) 303 | -------------------------------------------------------------------------------- /workflow/scripts/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from ._baselines import ( 2 | PCAKNN, 3 | CompositionBaseline, 4 | CTypeProportions, 5 | PseudoBulkPCA, 6 | SCVIModel, 7 | StatifiedPseudoBulkPCA, 8 | ) 9 | from ._metrics import ( 10 | compute_cramers, 11 | compute_geary, 12 | compute_hotspot_morans, 13 | compute_ks, 14 | compute_manova, 15 | ) 16 | from ._milo import MILO 17 | from ._mrvi import MrVIWrapper 18 | 19 | __all__ = [ 20 | "MrVIWrapper", 21 | "CTypeProportions", 22 | "CompositionBaseline", 23 | "PseudoBulkPCA", 24 | "MILO", 25 | "SCVIModel", 26 | "StatifiedPseudoBulkPCA", 27 | "PCAKNN", 28 | "compute_geary", 29 | "compute_hotspot_morans", 30 | "compute_cramers", 31 | "compute_ks", 32 | "compute_manova", 33 | ] 34 | -------------------------------------------------------------------------------- /workflow/scripts/utils/_base_model.py: -------------------------------------------------------------------------------- 1 | import scanpy as sc 2 | 3 | 4 | 
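# Shared interface for the model wrappers in workflow/scripts/utils. Subclasses
# flip the has_* capability flags, and run_model.compute_model_predictions
# inspects them to decide what to extract and where to store it: a global donor
# representation (uns["<model>_donor"]), a per-cell embedding
# (obsm["<model>_cell"]), a local donor representation plus its metadata
# (uns["<model>_local_donor_rep"] / uns["<model>_local_donor_rep_metadata"]),
# a fully custom AnnData via compute(), and an optional on-disk save of the model.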
class BaseModelClass: 5 | has_donor_representation = False 6 | has_cell_representation = False 7 | has_local_donor_representation = False 8 | has_custom_representation = False 9 | has_save = False 10 | 11 | def __init__( 12 | self, 13 | adata, 14 | cell_type_key, 15 | donor_key, 16 | categorical_nuisance_keys=None, 17 | n_hvg=None, 18 | ): 19 | self.adata = adata 20 | self.cell_type_key = cell_type_key 21 | self.donor_key = donor_key 22 | 23 | self.n_genes = self.adata.X.shape[1] 24 | self.n_donors = self.adata.obs[self.donor_key].unique().shape[0] 25 | self.categorical_nuisance_keys = categorical_nuisance_keys 26 | self.n_hvg = n_hvg 27 | 28 | def get_donor_representation(self, adata=None): 29 | return None 30 | 31 | def _filter_hvg(self): 32 | if (self.n_hvg is not None) and (self.n_hvg <= self.n_genes - 1): 33 | adata_ = self.adata.copy() 34 | sc.pp.highly_variable_genes( 35 | adata=adata_, n_top_genes=self.n_hvg, flavor="seurat_v3" 36 | ) 37 | self.adata = adata_[:, self.highly_variable] 38 | 39 | def preprocess_data(self): 40 | self._filter_hvg() 41 | 42 | def fit(self): 43 | return None 44 | 45 | def get_cell_representation(self, adata=None): 46 | return None 47 | 48 | def save(self, save_path): 49 | return 50 | -------------------------------------------------------------------------------- /workflow/scripts/utils/_baselines.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pynndescent 4 | import scanpy as sc 5 | import torch 6 | from scvi import REGISTRY_KEYS 7 | from scvi.model import SCVI 8 | from sklearn.decomposition import PCA 9 | from tqdm import tqdm 10 | 11 | from ._base_model import BaseModelClass 12 | 13 | 14 | class CTypeProportions(BaseModelClass): 15 | has_cell_representation = False 16 | has_donor_representation = True 17 | 18 | def __init__(self, **kwargs): 19 | super().__init__(**kwargs) 20 | 21 | def get_donor_representation(self, adata=None): 22 | adata = self.adata if adata is None else adata 23 | ct_props = ( 24 | self.adata.obs.groupby(self.donor_key)[self.cell_type_key] 25 | .value_counts() 26 | .unstack() 27 | .fillna(0.0) 28 | .apply(lambda x: x / x.sum(), axis=1) 29 | ) 30 | return ct_props 31 | 32 | 33 | class PseudoBulkPCA(BaseModelClass): 34 | has_cell_representation = False 35 | has_donor_representation = True 36 | 37 | def __init__(self, n_components=50, **kwargs): 38 | super().__init__(**kwargs) 39 | self.n_components = np.minimum(n_components, self.n_donors) 40 | 41 | def get_donor_representation(self, adata=None): 42 | adata = self.adata if adata is None else adata 43 | idx_donor = adata.obs[self.donor_key] 44 | 45 | X = np.zeros((self.n_donors, self.n_genes)) 46 | unique_donors = idx_donor.unique() 47 | for idx, unique_donor in enumerate(unique_donors): 48 | cell_is_selected = idx_donor == unique_donor 49 | X[idx, :] = adata.X[cell_is_selected].sum(0) 50 | X_cpm = 1e6 * X / X.sum(-1, keepdims=True) 51 | X_logcpm = np.log1p(X_cpm) 52 | z = PCA(n_components=self.n_components).fit_transform(X_logcpm) 53 | return pd.DataFrame(z, index=unique_donors) 54 | 55 | 56 | class StatifiedPseudoBulkPCA(BaseModelClass): 57 | has_cell_representation = False 58 | has_donor_representation = False 59 | has_local_donor_representation = True 60 | 61 | def __init__(self, n_components=50, **kwargs): 62 | super().__init__(**kwargs) 63 | self.n_components = np.minimum(n_components, self.n_donors) 64 | self.cell_type_key = None 65 | 66 | def fit(self): 67 | pass 68 | 69 | def 
get_local_sample_representation(self, adata=None): 70 | self.cell_type_key = self.adata.uns["mapper"]["cell_type_key"] 71 | adata = self.adata 72 | idx_donor = adata.obs[self.donor_key] 73 | cell_types = adata.obs[self.cell_type_key] 74 | 75 | X = [] 76 | unique_donors = idx_donor.unique() 77 | unique_types = cell_types.unique() 78 | reps_all = dict() 79 | for cell_type in unique_types: 80 | X = [] 81 | donors = [] 82 | # Computing the pseudo-bulk for each cell type 83 | for unique_donor in unique_donors: 84 | cell_is_selected = (idx_donor == unique_donor) & ( 85 | cell_types == cell_type 86 | ) 87 | new_counts = adata.X[cell_is_selected].sum(0) 88 | if new_counts.sum() > 0: 89 | X.append(new_counts) 90 | donors.append(unique_donor) 91 | X = np.array(X).squeeze(1) 92 | donors = np.array(donors).astype(str) 93 | X_cpm = 1e6 * X / X.sum(-1, keepdims=True) 94 | X_logcpm = np.log1p(X_cpm) 95 | n_comps = np.minimum(self.n_components, X.shape[0]) 96 | z = PCA(n_components=n_comps).fit_transform(X_logcpm) 97 | reps = pd.DataFrame(z, index=donors) 98 | reps.columns = ["PC_{}".format(i) for i in range(n_comps)] 99 | # reps.index = pd.CategoricalIndex(donors, categories=donors) 100 | reps_all[cell_type] = reps 101 | return reps_all 102 | 103 | 104 | class CompositionBaseline(BaseModelClass): 105 | has_cell_representation = True 106 | has_donor_representation = False 107 | has_local_donor_representation = True 108 | 109 | default_model_kwargs = dict( 110 | dim_reduction_approach="PCA", 111 | n_dim=50, 112 | clustering_on="leiden", 113 | ) 114 | default_train_kwargs = {} 115 | 116 | def __init__(self, model_kwargs=None, train_kwargs=None, **kwargs): 117 | super().__init__(**kwargs) 118 | self.model_kwargs = ( 119 | model_kwargs if model_kwargs is not None else self.default_model_kwargs 120 | ) 121 | self.train_kwargs = ( 122 | train_kwargs if train_kwargs is not None else self.default_train_kwargs 123 | ) 124 | 125 | self.dim_reduction_approach = model_kwargs.pop("dim_reduction_approach") 126 | self.n_dim = model_kwargs.pop("n_dim") 127 | self.clustering_on = model_kwargs.pop( 128 | "clustering_on" 129 | ) # one of leiden, celltype 130 | 131 | def preprocess_data(self): 132 | super().preprocess_data() 133 | self.adata_ = self.adata.copy() 134 | if self.dim_reduction_approach == "PCA": 135 | self.adata_ = self.adata.copy() 136 | sc.pp.normalize_total(self.adata_, target_sum=1e4) 137 | sc.pp.log1p(self.adata_) 138 | 139 | def fit(self): 140 | self.preprocess_data() 141 | if self.dim_reduction_approach == "PCA": 142 | sc.pp.pca(self.adata_, n_comps=self.n_dim) # saves "X_pca" in obsm 143 | self.adata_.obsm["X_red"] = self.adata_.obsm["X_pca"] 144 | elif self.dim_reduction_approach == "SCVI": 145 | SCVI.setup_anndata( 146 | self.adata_, categorical_covariate_keys=self.categorical_nuisance_keys 147 | ) 148 | scvi_model = SCVI(self.adata_, **self.model_kwargs) 149 | scvi_model.train(**self.train_kwargs) 150 | self.adata_.obsm["X_red"] = scvi_model.get_latent_representation() 151 | 152 | def get_cell_representation(self, adata=None): 153 | assert adata is None 154 | return self.adata_.obsm["X_red"] 155 | 156 | def get_local_sample_representation(self, adata=None): 157 | if self.clustering_on == "leiden": 158 | sc.pp.neighbors(self.adata_, n_neighbors=30, use_rep="X_red") 159 | sc.tl.leiden(self.adata_, resolution=1.0, key_added="leiden_1.0") 160 | clustering_key = "leiden_1.0" 161 | elif self.clustering_on == "celltype": 162 | clustering_key = self.cell_type_key 163 | 164 | freqs_all = dict() 165 | for 
unique_cluster in self.adata_.obs[clustering_key].unique(): 166 | cell_is_selected = self.adata_.obs[clustering_key] == unique_cluster 167 | subann = self.adata_[cell_is_selected].copy() 168 | 169 | # Step 1: subcluster 170 | sc.pp.neighbors(subann, n_neighbors=30, use_rep="X_red") 171 | sc.tl.leiden(subann, resolution=1.0, key_added=clustering_key) 172 | 173 | szs = ( 174 | subann.obs.groupby([clustering_key, self.donor_key]) 175 | .size() 176 | .to_frame("n_cells") 177 | .reset_index() 178 | ) 179 | szs_total = ( 180 | szs.groupby(self.donor_key) 181 | .sum() 182 | .rename(columns={"n_cells": "n_cells_total"}) 183 | ) 184 | comps = szs.merge(szs_total, on=self.donor_key).assign( 185 | freqs=lambda x: x.n_cells / x.n_cells_total 186 | ) 187 | freqs = ( 188 | comps.loc[:, [self.donor_key, clustering_key, "freqs"]] 189 | .set_index([self.donor_key, clustering_key]) 190 | .squeeze() 191 | .unstack() 192 | ) 193 | freqs_ = freqs 194 | freqs_all[unique_cluster] = freqs_ 195 | # n_donors, n_clusters 196 | return freqs_all 197 | 198 | def get_donor_representation_metadata(self, adata=None): 199 | pass 200 | 201 | 202 | class PCAKNN(BaseModelClass): 203 | has_cell_representation = True 204 | has_donor_representation = False 205 | has_local_donor_representation = True 206 | 207 | def __init__(self, n_components=25, **kwargs): 208 | super().__init__(**kwargs) 209 | self.n_components = n_components 210 | self.pca = None 211 | self.adata_ = None 212 | self.donor_order = None 213 | 214 | def preprocess_data(self): 215 | super().preprocess_data() 216 | self.adata_ = self.adata.copy() 217 | sc.pp.normalize_total(self.adata_, target_sum=1e4) 218 | sc.pp.log1p(self.adata_) 219 | 220 | def fit(self): 221 | self.preprocess_data() 222 | sc.pp.pca(self.adata_, n_comps=self.n_components) 223 | 224 | def get_cell_representation(self, adata=None): 225 | assert adata is None 226 | return self.adata_.obsm["X_pca"] 227 | 228 | def get_local_sample_representation(self, adata=None): 229 | # for each cell, compute nearest neighbor in given donor 230 | pca_rep = self.adata_.obsm["X_pca"] 231 | 232 | local_reps = [] 233 | self.donor_order = self.adata_.obs[self.donor_key].unique() 234 | for donor in self.donor_order: 235 | donor_is_selected = self.adata_.obs[self.donor_key] == donor 236 | pca_donor = pca_rep[donor_is_selected] 237 | nn_donor = pynndescent.NNDescent(pca_donor) 238 | nn_indices = nn_donor.query(pca_rep, k=1)[0].squeeze(-1) 239 | nn_rep = pca_donor[nn_indices][:, None, :] 240 | local_reps.append(nn_rep) 241 | local_reps = np.concatenate(local_reps, axis=1) 242 | return local_reps 243 | 244 | def get_donor_representation_metadata(self): 245 | donor_to_id_map = pd.DataFrame( 246 | self.donor_order, columns=[self.donor_key] 247 | ).assign(donor_order=lambda x: np.arange(len(x))) 248 | res = self.adata_.obs.drop_duplicates(self.donor_key) 249 | res = res.merge(donor_to_id_map, on=self.donor_key, how="left").sort_values( 250 | "donor_order" 251 | ) 252 | return res 253 | 254 | 255 | class SCVIModel(BaseModelClass): 256 | has_cell_representation = True 257 | has_local_donor_representation = True 258 | has_save = True 259 | 260 | default_model_kwargs = dict( 261 | dropout_rate=0.0, 262 | dispersion="gene", 263 | gene_likelihood="nb", 264 | ) 265 | default_train_kwargs = dict( 266 | max_epochs=100, 267 | check_val_every_n_epoch=1, 268 | batch_size=256, 269 | plan_kwargs=dict(lr=1e-2, n_epochs_kl_warmup=20), 270 | ) 271 | 272 | def __init__(self, model_kwargs=None, train_kwargs=None, **kwargs): 273 | 
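        # Note: unlike MrVIWrapper, which merges user-supplied kwargs into its
        # defaults, the model_kwargs / train_kwargs passed here replace
        # default_model_kwargs / default_train_kwargs wholesale.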
super().__init__(**kwargs) 274 | self.model_kwargs = ( 275 | self.default_model_kwargs if model_kwargs is None else model_kwargs 276 | ) 277 | self.train_kwargs = ( 278 | self.default_train_kwargs if train_kwargs is None else train_kwargs 279 | ) 280 | self.adata_ = None 281 | 282 | @property 283 | def has_donor_representation(self): 284 | has_donor_embedding = self.model_kwargs.get("do_batch_embedding", False) 285 | return has_donor_embedding 286 | 287 | def fit(self): 288 | self.preprocess_data() 289 | adata_ = self.adata.copy() 290 | 291 | feed_nuisance = self.model_kwargs.pop("feed_nuisance", True) 292 | categorical_nuisance_keys = ( 293 | self.categorical_nuisance_keys if feed_nuisance else None 294 | ) 295 | SCVI.setup_anndata( 296 | adata_, 297 | batch_key=self.donor_key, 298 | categorical_covariate_keys=categorical_nuisance_keys, 299 | ) 300 | self.adata_ = adata_ 301 | self.model = SCVI(adata=adata_, **self.model_kwargs) 302 | self.model.train(**self.train_kwargs) 303 | return True 304 | 305 | def save(self, save_path, overwrite=True): 306 | self.model.save(save_path, overwrite=overwrite) 307 | return True 308 | 309 | # def get_donor_representation(self, adata=None): 310 | # assert self.has_donor_representation 311 | 312 | def get_cell_representation(self, adata=None, batch_size=256): 313 | return self.model.get_latent_representation(adata, batch_size=batch_size) 314 | 315 | @torch.no_grad() 316 | def get_local_sample_representation( 317 | self, adata=None, batch_size=256, x_dim=50, eps=1e-8, mc_samples=10 318 | ): 319 | # z = self.model.get_latent_representation(adata, batch_size=batch_size) 320 | # index = pynndescent.NNDescent(z, n_neighbors=n_neighbors) 321 | # index.prepare() 322 | 323 | # neighbors, _ = index.query(z) 324 | # return neighbors[:, 1:] 325 | 326 | adata = self.adata_ if adata is None else adata 327 | self.model._check_if_trained(warn=False) 328 | adata = self.model._validate_anndata(adata) 329 | scdl = self.model._make_data_loader( 330 | adata=adata, indices=None, batch_size=batch_size 331 | ) 332 | 333 | # hs = self.get_normalized_expression(adata, batch_size=batch_size, eps=eps) 334 | hs = [] 335 | for tensors in tqdm(scdl): 336 | inference_inputs = self.model.module._get_inference_input(tensors) 337 | inference_outputs = self.model.module.inference(**inference_inputs) 338 | 339 | generative_inputs = self.model.module._get_generative_input( 340 | tensors=tensors, inference_outputs=inference_outputs 341 | ) 342 | generative_outputs = self.model.module.generative(**generative_inputs) 343 | new = (eps + generative_outputs["px"].scale).log() 344 | hs.append(new.cpu()) 345 | hs = torch.cat(hs, dim=0).numpy() 346 | means = np.mean(hs, axis=0) 347 | stds = np.std(hs, axis=0) 348 | hs = (hs - means) / stds 349 | pca = PCA(n_components=x_dim).fit(hs) 350 | w = torch.tensor( 351 | pca.components_, dtype=torch.float32, device=self.model.device 352 | ).T 353 | means = torch.tensor(means, dtype=torch.float32, device=self.model.device) 354 | stds = torch.tensor(stds, dtype=torch.float32, device=self.model.device) 355 | 356 | reps = [] 357 | for tensors in tqdm(scdl): 358 | xs = [] 359 | for batch in range(self.model.summary_stats.n_batch): 360 | cf_batch = batch * torch.ones_like(tensors["batch"]) 361 | tensors[REGISTRY_KEYS.BATCH_KEY] = cf_batch 362 | inference_inputs = self.model.module._get_inference_input(tensors) 363 | inference_outputs = self.model.module.inference( 364 | n_samples=mc_samples, **inference_inputs 365 | ) 366 | 367 | generative_inputs = 
self.model.module._get_generative_input( 368 | tensors=tensors, inference_outputs=inference_outputs 369 | ) 370 | generative_outputs = self.model.module.generative(**generative_inputs) 371 | new = (eps + generative_outputs["px"].scale).log() 372 | if x_dim is not None: 373 | new = (new - means) / stds 374 | new = new @ w 375 | xs.append(new[:, :, None]) 376 | 377 | xs = torch.cat(xs, 2).mean(0) 378 | reps.append(xs.cpu().numpy()) 379 | # n_cells, n_donors, n_donors 380 | reps = np.concatenate(reps, 0) 381 | return reps 382 | 383 | def get_donor_representation_metadata(self): 384 | return ( 385 | self.model.adata.obs.drop_duplicates("_scvi_batch") 386 | .set_index("_scvi_batch") 387 | .sort_index() 388 | ) 389 | -------------------------------------------------------------------------------- /workflow/scripts/utils/_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import scipy.stats as stats 4 | import statsmodels.api as sm 5 | import torch 6 | from joblib import Parallel, delayed 7 | from sklearn.preprocessing import OneHotEncoder 8 | from statsmodels.multivariate.manova import MANOVA 9 | 10 | 11 | def smooth_distance(sq_dists, mask_farther_than_k=False): 12 | n_donors = sq_dists.shape[-1] 13 | k = int(np.sqrt(n_donors)) 14 | topk_vals, topk_idx = per_cell_knn(sq_dists, k=k) 15 | bandwidth_idx = k // 3 16 | bandwidth_vals = topk_vals[:, :, bandwidth_idx].unsqueeze( 17 | -1 18 | ) # n_cells x n_donors x 1 19 | w_mtx = torch.exp(-sq_dists / bandwidth_vals) # n_cells x n_donors x n_donors 20 | 21 | if mask_farther_than_k: 22 | masked_w_mtx = torch.zeros_like(w_mtx) 23 | masked_w_mtx = masked_w_mtx.scatter( 24 | -1, topk_idx, w_mtx 25 | ) # n_cells x n_donors x n_donors 26 | w_mtx = masked_w_mtx 27 | 28 | return w_mtx 29 | 30 | 31 | @torch.no_grad() 32 | def per_cell_knn(dists, k): 33 | # Given n_cells x n_donors x n_donors returns n_cells x n_donors x k 34 | # tensor of dists and indices of the k nearest neighbors for each donor in each cell 35 | topkp1 = torch.topk(dists, k + 1, dim=-1, largest=False, sorted=True) 36 | topk_values, topk_indices = topkp1.values[:, :, 1:], topkp1.indices[:, :, 1:] 37 | return topk_values, topk_indices 38 | 39 | 40 | @torch.no_grad() 41 | def compute_geary(xs, donor_labels): 42 | oh_feats = OneHotEncoder(sparse=False).fit_transform( 43 | donor_labels[:, None] 44 | ) # n_donors x n_labels 45 | w_mat = torch.tensor(oh_feats @ oh_feats.T, device="cuda") # n_donors x n_donors 46 | xs = xs.unsqueeze(-2) # n_cells x n_donors x 1 x n_donor_latent 47 | sq_dists = ((xs - xs.transpose(-2, -3)) ** 2).sum( 48 | -1 49 | ) # n_cells x n_donors x n_donors 50 | scores_ = (sq_dists * w_mat).sum([-1, -2]) / w_mat.sum([-1, -2]) 51 | var_estim = xs.var(1).sum(-1).squeeze() 52 | scores_ = scores_ / (2.0 * var_estim) 53 | return scores_.cpu().numpy() 54 | 55 | 56 | @torch.no_grad() 57 | def compute_hotspot_morans(xs, donor_labels): 58 | oh_feats = OneHotEncoder(sparse=False).fit_transform(donor_labels[:, None]) 59 | like_label_mtx = torch.tensor(oh_feats @ oh_feats.T, device="cuda") 60 | xx_mtx = (like_label_mtx * 2) - 1 # n_donors x n_donors 61 | 62 | xs = xs.unsqueeze(-2) # n_cells x n_donors x 1 x n_donor_latent 63 | sq_dists = ((xs - xs.transpose(-2, -3)) ** 2).sum( 64 | -1 65 | ) # n_cells x n_donors x n_donors 66 | w_mtx = smooth_distance(sq_dists) 67 | w_norm_mtx = w_mtx / w_mtx.sum(-1, keepdim=True) # n_cells x n_donors x n_donors 68 | 69 | scores_ = (w_norm_mtx * xx_mtx).sum([-1, 
-2]) # n_cells 70 | return scores_.cpu().numpy() 71 | 72 | 73 | @torch.no_grad() 74 | def compute_cramers(xs, donor_labels): 75 | oh_feats = OneHotEncoder(sparse=False).fit_transform( 76 | donor_labels[:, None] 77 | ) # n_donors x n_labels 78 | oh_feats = torch.tensor(oh_feats, device="cuda").float() # n_donors x n_labels 79 | 80 | xs = xs.unsqueeze(-2) # n_cells x n_donors x 1 x n_donor_latent 81 | sq_dists = ((xs - xs.transpose(-2, -3)) ** 2).sum( 82 | -1 83 | ) # n_cells x n_donors x n_donors 84 | w_mtx = smooth_distance(sq_dists, mask_farther_than_k=True) 85 | 86 | c_ij = w_mtx @ oh_feats # n_cells x n_donors x n_labels 87 | contingency_X = c_ij.transpose(-1, -2) @ oh_feats # n_cells x n_labels x n_labels 88 | 89 | scores = [] 90 | sig_scores = [] 91 | for i in range(contingency_X.shape[0]): 92 | contingency_Xi = contingency_X[i].cpu().numpy() 93 | chi_sq, sig_score, _, _ = stats.chi2_contingency(contingency_Xi) 94 | n = np.sum(contingency_Xi) 95 | min_dim = contingency_Xi.shape[0] - 1 96 | scores.append(np.sqrt(chi_sq / (n * min_dim))) 97 | sig_scores.append(sig_score) 98 | return np.array(scores), np.array(sig_scores) 99 | 100 | 101 | def _random_subsample(d1, n_mc_samples): 102 | n_minibatch = d1.shape[0] 103 | if d1.shape[0] >= n_mc_samples: 104 | d1_ = torch.zeros(n_minibatch, n_mc_samples) 105 | for i in range(n_minibatch): 106 | rdm_idx1 = torch.randperm(d1.shape[0])[:n_mc_samples] 107 | d1_[i] = d1[i][rdm_idx1] 108 | return d1 109 | 110 | 111 | @torch.no_grad() 112 | def compute_manova( 113 | xs, 114 | donor_labels, 115 | ): 116 | target = pd.Series(donor_labels, dtype="category").cat.codes 117 | target = sm.add_constant(target) 118 | 119 | def _compute(xi): 120 | try: 121 | res = ( 122 | MANOVA(endog=xi, exog=target) 123 | .mv_test() 124 | .results["x1"]["stat"] 125 | .loc["Wilks' lambda", ["F Value", "Pr > F"]] 126 | .values 127 | ) 128 | except Exception: 129 | return np.array([1000.0, 1.0]) 130 | return res 131 | 132 | all_res = np.array( 133 | Parallel(n_jobs=10)(delayed(_compute)(xi) for xi in xs), dtype=np.float32 134 | ) 135 | if isinstance(xs, np.ndarray): 136 | xs = torch.from_numpy(xs).to(torch.float32) 137 | all_res = np.array( 138 | Parallel(n_jobs=10)(delayed(_compute)(xi.cpu().numpy()) for xi in xs), 139 | dtype=np.float32, 140 | ) 141 | stats, pvals = all_res[:, 0], all_res[:, 1] 142 | return stats, pvals 143 | 144 | 145 | def compute_ks( 146 | xs, donor_labels, n_mc_samples=5000, alternative="two-sided", do_smoothing=False 147 | ): 148 | n_minibatch = xs.shape[0] 149 | oh_feats = OneHotEncoder(sparse=False).fit_transform( 150 | donor_labels[:, None] 151 | ) # n_donors x n_labels 152 | oh_feats = torch.tensor(oh_feats, device="cuda").float() # n_donors x n_labels 153 | wmat = oh_feats @ oh_feats.T 154 | wmat = wmat - torch.eye(wmat.shape[0], device="cuda").float() 155 | is_off_diag = 1.0 - torch.eye(wmat.shape[0], device="cuda").float() 156 | 157 | xs = xs.unsqueeze(-2) 158 | sq_dists = ((xs - xs.transpose(-2, -3)) ** 2).sum( 159 | -1 160 | ) # n_cells x n_donors x n_donors 161 | if do_smoothing: 162 | sq_dists = smooth_distance(sq_dists) 163 | sq_dists = sq_dists.reshape(n_minibatch, -1) 164 | 165 | d1 = sq_dists[:, wmat.reshape(-1).bool()] 166 | d1 = _random_subsample(d1, n_mc_samples).cpu().numpy() 167 | d2 = sq_dists[:, is_off_diag.reshape(-1).bool()] 168 | d2 = _random_subsample(d2, n_mc_samples).cpu().numpy() 169 | 170 | ks_results = stats.ttest_ind(d1, d2, 1, equal_var=False) 171 | effect_sizes = ks_results.statistic 172 | p_values = ks_results.pvalue 173 
| return effect_sizes, p_values 174 | -------------------------------------------------------------------------------- /workflow/scripts/utils/_milo.py: -------------------------------------------------------------------------------- 1 | import scanpy as sc 2 | from scvi.model import SCVI 3 | 4 | from ._base_model import BaseModelClass 5 | 6 | 7 | class MILO(BaseModelClass): 8 | has_donor_representation = False 9 | has_cell_representation = False 10 | has_local_donor_representation = False 11 | has_custom_representation = True 12 | default_model_kwargs = dict( 13 | dropout_rate=0.0, 14 | dispersion="gene", 15 | gene_likelihood="nb", 16 | ) 17 | default_train_kwargs = dict( 18 | max_epochs=100, 19 | check_val_every_n_epoch=1, 20 | batch_size=256, 21 | plan_kwargs=dict(lr=1e-2, n_epochs_kl_warmup=20), 22 | ) 23 | 24 | def __init__(self, model_kwargs=None, train_kwargs=None, **kwargs): 25 | super().__init__(**kwargs) 26 | self.model_kwargs = ( 27 | self.default_model_kwargs if model_kwargs is None else model_kwargs 28 | ) 29 | self.train_kwargs = ( 30 | self.default_train_kwargs if train_kwargs is None else train_kwargs 31 | ) 32 | 33 | def fit(self, **kwargs): 34 | import milopy.core as milo 35 | 36 | adata_ = self.adata.copy() 37 | embedding = self.model_kwargs.pop("embedding", "mnn") 38 | 39 | if embedding == "mnn": 40 | # Run MNN 41 | alldata = [] 42 | batch_key = self.categorical_nuisance_keys[0] 43 | adata_.obs[batch_key] = adata_.obs[batch_key].astype("category") 44 | for batch_cat in adata_.obs[batch_key].cat.categories.tolist(): 45 | alldata.append(adata_[adata_.obs[batch_key] == batch_cat,]) 46 | 47 | cdata = sc.external.pp.mnn_correct( 48 | *alldata, svd_dim=50, batch_key=batch_key, n_jobs=8 49 | )[0] 50 | if isinstance(cdata, tuple): 51 | cdata = cdata[0] 52 | 53 | # Run PCA 54 | cell_rep = sc.tl.pca(cdata.X, svd_solver="arpack", return_info=False) 55 | elif embedding == "scvi": 56 | # Run scVI 57 | self.preprocess_data() 58 | adata_ = self.adata.copy() 59 | 60 | batch_key = self.categorical_nuisance_keys[0] 61 | SCVI.setup_anndata( 62 | adata_, 63 | batch_key=batch_key, 64 | categorical_covariate_keys=( 65 | self.categorical_nuisance_keys[1:] 66 | if len(self.categorical_nuisance_keys) > 1 67 | else None 68 | ), 69 | ) 70 | self.adata_ = adata_ 71 | scvi_model = SCVI(adata_, **self.model_kwargs) 72 | scvi_model.train(**self.train_kwargs) 73 | cell_rep = scvi_model.get_latent_representation() 74 | else: 75 | raise ValueError(f"Unknown embedding: {self.embedding}") 76 | 77 | adata_.obsm["_cell_rep"] = cell_rep 78 | sc.pp.neighbors(adata_, n_neighbors=10, use_rep="_cell_rep") 79 | 80 | ## Assign cells to neighbourhoods 81 | milo.make_nhoods(adata_, prop=0.05) 82 | 83 | ## Count cells from each sample in each nhood 84 | milo.count_nhoods(adata_, sample_col=self.donor_key) 85 | self.adata_ = adata_ 86 | 87 | def compute(self): 88 | return self.adata_ 89 | -------------------------------------------------------------------------------- /workflow/scripts/utils/_mrvi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch 4 | from tqdm import tqdm 5 | 6 | from mrvi import MrVI 7 | 8 | from ._base_model import BaseModelClass 9 | 10 | 11 | class MrVIWrapper(BaseModelClass): 12 | has_cell_representation = True 13 | has_donor_representation = True 14 | has_local_donor_representation = True 15 | has_save = True 16 | 17 | default_model_kwargs = dict( 18 | observe_library_sizes=True, 19 | 
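        # Presumably the dimensionality of the learned donor embedding (the
        # weight matrix read back by get_donor_representation below); several
        # run_model.py configurations override it, e.g. n_latent_donor=2.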
n_latent_donor=5, 20 | ) 21 | default_train_kwargs = dict( 22 | max_epochs=100, 23 | check_val_every_n_epoch=1, 24 | batch_size=256, 25 | plan_kwargs=dict(lr=1e-2, n_epochs_kl_warmup=20), 26 | ) 27 | model_name = "mrvi" 28 | 29 | def __init__(self, model_kwargs=None, train_kwargs=None, **kwargs): 30 | super().__init__(**kwargs) 31 | self.model_kwargs = self.default_model_kwargs.copy() 32 | if model_kwargs is not None: 33 | self.model_kwargs.update(model_kwargs) 34 | self.train_kwargs = self.default_train_kwargs.copy() 35 | if train_kwargs is not None: 36 | self.train_kwargs.update(train_kwargs) 37 | 38 | def fit(self): 39 | self.preprocess_data() 40 | adata_ = self.adata.copy() 41 | self.adata_ = adata_ 42 | 43 | MrVI.setup_anndata( 44 | adata_, 45 | batch_key=self.donor_key, 46 | categorical_nuisance_keys=self.categorical_nuisance_keys, 47 | categorical_biological_keys=None, 48 | ) 49 | self.model = MrVI(adata=adata_, **self.model_kwargs) 50 | self.model.train(**self.train_kwargs) 51 | return True 52 | 53 | def save(self, save_path, overwrite=True): 54 | self.model.save(save_path, overwrite=overwrite) 55 | return True 56 | 57 | @classmethod 58 | def load(self, adata, save_path): 59 | mapper = adata.uns["mapper"] 60 | cls = MrVIWrapper(adata=adata, **mapper) 61 | cls.model = MrVI.load(save_path, adata) 62 | adata_ = cls.adata.copy() 63 | cls.adata_ = adata_ 64 | return cls 65 | 66 | def get_donor_representation(self): 67 | d_embeddings = self.model.module.donor_embeddings.weight.cpu().detach().numpy() 68 | index_ = ( 69 | self.adata_.obs.drop_duplicates("_scvi_batch") 70 | .set_index("_scvi_batch") 71 | .sort_index() 72 | .loc[:, self.donor_key] 73 | ) 74 | return pd.DataFrame(d_embeddings, index=index_) 75 | 76 | def get_cell_representation(self, adata=None, batch_size=512, give_z=False): 77 | return self.model.get_latent_representation( 78 | adata, batch_size=batch_size, give_z=give_z 79 | ) 80 | 81 | @torch.no_grad() 82 | def get_normalized_expression( 83 | self, 84 | adata=None, 85 | x_log=True, 86 | batch_size=256, 87 | eps=1e-6, 88 | cf_site=0.0, 89 | ): 90 | adata = self.adata_ if adata is None else adata 91 | self.model._check_if_trained(warn=False) 92 | adata = self.model._validate_anndata(adata) 93 | scdl = self.model._make_data_loader( 94 | adata=adata, indices=None, batch_size=batch_size 95 | ) 96 | 97 | reps = [] 98 | for tensors in tqdm(scdl): 99 | xs = [] 100 | if cf_site is not None: 101 | tensors[ 102 | "categorical_nuisance_keys" 103 | ] *= cf_site # set to 0 all nuisance factors 104 | inference_inputs = self.model.module._get_inference_input(tensors) 105 | outputs_n = self.model.module.inference(use_mean=True, **inference_inputs) 106 | outs_g = self.model.module.generative( 107 | **self.model.module._get_generative_input( 108 | tensors, inference_outputs=outputs_n 109 | ) 110 | ) 111 | xs = outs_g["h"] 112 | if x_log: 113 | xs = (eps + xs).log() 114 | reps.append(xs.cpu().numpy()) 115 | # n_cells, n_donors, n_donors 116 | reps = np.concatenate(reps, 0) 117 | return reps 118 | 119 | @torch.no_grad() 120 | def get_average_expression( 121 | self, 122 | adata=None, 123 | indices=None, 124 | batch_size=256, 125 | eps=1e-6, 126 | mc_samples=10, 127 | ): 128 | adata = self.adata_ if adata is None else adata 129 | # self.model._check_if_trained(warn=False) 130 | # adata = self.model._validate_anndata(adata) 131 | scdl = self.model._make_data_loader( 132 | adata=adata, indices=indices, batch_size=batch_size 133 | ) 134 | 135 | reps = np.zeros((self.model.summary_stats.n_batch, 
adata.n_vars)) 136 | for tensors in tqdm(scdl): 137 | xs = [] 138 | for batch in range(self.model.summary_stats.n_batch): 139 | tensors[ 140 | "categorical_nuisance_keys" 141 | ] *= 0.0 # set to 0 all nuisance factors 142 | 143 | cf_batch = batch * torch.ones_like(tensors["batch"]) 144 | inference_inputs = self.model.module._get_inference_input(tensors) 145 | inference_outputs = self.model.module.inference( 146 | n_samples=mc_samples, cf_batch=cf_batch, **inference_inputs 147 | ) 148 | generative_inputs = self.model.module._get_generative_input( 149 | tensors=tensors, inference_outputs=inference_outputs 150 | ) 151 | generative_outputs = self.model.module.generative(**generative_inputs) 152 | new = generative_outputs["h"] 153 | new = (eps + generative_outputs["h"]).log() 154 | xs.append(new[:, :, None]) 155 | 156 | xs = torch.cat(xs, 2).mean(0) # size (n_cells, n_donors, n_genes) 157 | reps += xs.mean(0).cpu().numpy() 158 | return reps 159 | 160 | @torch.no_grad() 161 | def get_local_sample_representation(self, adata=None, **kwargs): 162 | return self.model.get_local_sample_representation(adata=adata, **kwargs) 163 | 164 | def get_donor_representation_metadata(self): 165 | return ( 166 | self.model.adata.obs.drop_duplicates("_scvi_batch") 167 | .set_index("_scvi_batch") 168 | .sort_index() 169 | ) 170 | --------------------------------------------------------------------------------
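# ---------------------------------------------------------------------------
# Illustrative usage sketch (not one of the repository files above): roughly how
# run_model.compute_model_predictions drives a wrapper, here by hand rather than
# through Snakemake. The h5ad path and the "MrVI_*" key names are hypothetical;
# the class, the adata.uns["mapper"] contract, and the method names are the ones
# defined in workflow/scripts/utils and workflow/scripts/process_data.py.
# ---------------------------------------------------------------------------
import scanpy as sc

from utils import MrVIWrapper  # assumes the working directory is workflow/scripts

# AnnData previously prepared by process_data.py (hypothetical path).
adata = sc.read_h5ad("adata.processed.h5ad")
mapper = adata.uns["mapper"]  # donor_key, cell_type_key, categorical_nuisance_keys

# train_kwargs are merged into the wrapper's defaults.
model = MrVIWrapper(adata=adata, train_kwargs=dict(max_epochs=50), **mapper)
model.fit()

# Store outputs using the same layout run_model.py produces.
adata.obsm["MrVI_cell"] = model.get_cell_representation()
adata.uns["MrVI_local_donor_rep"] = model.get_local_sample_representation()
adata.uns["MrVI_local_donor_rep_metadata"] = model.get_donor_representation_metadata()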