├── .editorconfig
├── .flake8
├── .github
│   └── workflows
│       └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── codecov.yml
├── data
│   └── synthetic
│       └── scripts
│           ├── symSim.R
│           └── symsim_env.yaml
├── docs
│   ├── Makefile
│   ├── _static
│   │   └── css
│   │       └── custom.css
│   ├── _templates
│   │   ├── autosummary
│   │   │   └── class.rst
│   │   └── class_no_inherited.rst
│   ├── conf.py
│   ├── extensions
│   │   └── typed_returns.py
│   ├── index.rst
│   ├── installation.rst
│   ├── make.bat
│   ├── references.bib
│   ├── references.rst
│   └── release_notes
│       ├── index.rst
│       └── v0.1.0.rst
├── mrvi
│   ├── __init__.py
│   ├── _components.py
│   ├── _constants.py
│   ├── _model.py
│   ├── _module.py
│   └── _utils.py
├── pyproject.toml
├── readthedocs.yml
├── setup.py
├── tests
│   ├── __init__.py
│   └── test_mrvi.py
└── workflow
    ├── Snakefile
    ├── config
    │   └── config.yaml
    ├── envs
    │   ├── process_data.yaml
    │   └── run_model.yaml
    ├── notebooks
    │   ├── semisynthetic.py
    │   ├── snrna.py
    │   └── synthetic.py
    └── scripts
        ├── compute_local_scores.py
        ├── process_data.py
        ├── run_model.py
        └── utils
            ├── __init__.py
            ├── _base_model.py
            ├── _baselines.py
            ├── _metrics.py
            ├── _milo.py
            └── _mrvi.py
/.editorconfig:
--------------------------------------------------------------------------------
1 | # http://editorconfig.org
2 |
3 | root = true
4 |
5 | [*]
6 | indent_style = space
7 | indent_size = 4
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 | charset = utf-8
11 | end_of_line = lf
12 |
13 | [*.bat]
14 | indent_style = tab
15 | end_of_line = crlf
16 |
17 | [LICENSE]
18 | insert_final_newline = false
19 |
20 | [Makefile]
21 | indent_style = tab
22 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = E203, E266, E501, W503, W605, N812, F821
3 | exclude = .git,docs,workflow/.snakemake,workflow/notebooks
4 | max-line-length = 119
5 |
6 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: mrvi
5 |
6 | on:
7 | push:
8 | branches: [main]
9 | pull_request:
10 | branches: [main]
11 |
12 | jobs:
13 | build:
14 | runs-on: ubuntu-latest
15 | strategy:
16 | matrix:
17 | python-version: ["3.9", "3.10", "3.11"]
18 |
19 | steps:
20 | - uses: actions/checkout@v3
21 | - name: Set up Python ${{ matrix.python-version }}
22 | uses: actions/setup-python@v4
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 | - name: Cache pip
26 | uses: actions/cache@v2
27 | with:
28 | path: ~/.cache/pip
29 | key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
30 | restore-keys: |
31 | ${{ runner.os }}-pip-
32 | - name: Install dependencies
33 | run: |
34 | pip install pytest-cov
35 | pip install .[dev]
36 | - name: Lint with flake8
37 | run: |
38 | flake8
39 | - name: Format with black
40 | run: |
41 | black --check .
42 | - name: Test with pytest
43 | run: |
44 | pytest --cov-report=xml --cov=mrvi
45 | - name: After success
46 | run: |
47 | bash <(curl -s https://codecov.io/bash)
48 | pip list
49 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # DS_Store
2 | .DS_Store
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 | docs/api/
77 |
78 | # PyBuilder
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
99 | __pypackages__/
100 |
101 | # Celery stuff
102 | celerybeat-schedule
103 | celerybeat.pid
104 |
105 | # SageMath parsed files
106 | *.sage.py
107 |
108 | # Environments
109 | .env
110 | .venv
111 | env/
112 | venv/
113 | ENV/
114 | env.bak/
115 | venv.bak/
116 |
117 | # Spyder project settings
118 | .spyderproject
119 | .spyproject
120 |
121 | # Rope project settings
122 | .ropeproject
123 |
124 | # mkdocs documentation
125 | /site
126 |
127 | # mypy
128 | .mypy_cache/
129 | .dmypy.json
130 | dmypy.json
131 |
132 | # Pyre type checker
133 | .pyre/
134 |
135 | # vscode
136 | .vscode/settings.json
137 |
138 | # snakemake
139 | workflow/.snakemake/
140 | workflow/results/
141 | workflow/data/
142 | workflow/figures/
143 |
144 | # mrvi data
145 | data/*/*.h5ad
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/psf/black
3 | rev: 24.1.1
4 | hooks:
5 | - id: black
6 | - repo: https://github.com/PyCQA/flake8
7 | rev: 7.0.0
8 | hooks:
9 | - id: flake8
10 | - repo: https://github.com/pycqa/isort
11 | rev: 5.13.2
12 | hooks:
13 | - id: isort
14 | name: isort (python)
15 | additional_dependencies: [toml]
16 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2021, Yosef Lab
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multi-resolution Variational Inference
2 |
3 | **DEPRECATED: please refer to the current [version](https://github.com/YosefLab/mrvi) of the project for up-to-date code**
4 |
5 | Multi-resolution Variational Inference (MrVI) is a package for analysis of sample-level heterogeneity in multi-site, multi-sample single-cell omics data. Built with [scvi-tools](https://scvi-tools.org).
6 |
7 | ---
8 |
9 | To install, run:
10 |
11 | ```
12 | pip install mrvi
13 | ```
14 |
15 | `mrvi.MrVI` follows the same API used in scvi-tools.
16 |
17 | ```python
18 | import mrvi
19 | import anndata
20 |
21 | adata = anndata.read_h5ad("path/to/adata.h5ad")
22 | # Sample (e.g. donors, perturbations, etc.) should go in sample_key
23 | # Sites, plates, and other factors should go in categorical_nuisance_keys
24 | mrvi.MrVI.setup_anndata(adata, sample_key="donor", categorical_nuisance_keys=["site"])
25 | mrvi_model = mrvi.MrVI(adata)
26 | mrvi_model.train()
27 | # Get z representation
28 | adata.obsm["X_mrvi_z"] = mrvi_model.get_latent_representation(give_z=True)
29 | # Get u representation
30 | adata.obsm["X_mrvi_u"] = mrvi_model.get_latent_representation(give_z=False)
31 | # Cells by n_sample by n_latent
32 | cell_sample_representations = mrvi_model.get_local_sample_representation()
33 | # Cells by n_sample by n_sample
34 | cell_sample_sample_distances = mrvi_model.get_local_sample_representation(return_distances=True)
35 | ```
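
The returned arrays are plain NumPy arrays, so they can be aggregated with standard tooling. As a minimal sketch (not part of the `mrvi` API; the `"cell_type"` column, the population label, and the use of `scipy` are assumptions), the per-cell sample-sample distance matrices can be averaged within one cell population and used to cluster samples:

```python
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import squareform

# Average the (n_sample, n_sample) distance matrices over the cells of one population.
mask = (adata.obs["cell_type"] == "CD4 T").to_numpy()
mean_dists = cell_sample_sample_distances[mask].mean(axis=0)

# Hierarchically cluster samples from the averaged distance matrix.
sample_linkage = linkage(squareform(mean_dists, checks=False), method="average")
```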
36 |
37 | ## Citation
38 |
39 | ```
40 | @article {Boyeau2022.10.04.510898,
41 | author = {Boyeau, Pierre and Hong, Justin and Gayoso, Adam and Jordan, Michael and Azizi, Elham and Yosef, Nir},
42 | title = {Deep generative modeling for quantifying sample-level heterogeneity in single-cell omics},
43 | elocation-id = {2022.10.04.510898},
44 | year = {2022},
45 | doi = {10.1101/2022.10.04.510898},
46 | publisher = {Cold Spring Harbor Laboratory},
47 | abstract = {Contemporary single-cell omics technologies have enabled complex experimental designs incorporating hundreds of samples accompanied by detailed information on sample-level conditions. Current approaches for analyzing condition-level heterogeneity in these experiments often rely on a simplification of the data such as an aggregation at the cell-type or cell-state-neighborhood level. Here we present MrVI, a deep generative model that provides sample-sample comparisons at a single-cell resolution, permitting the discovery of subtle sample-specific effects across cell populations. Additionally, the output of MrVI can be used to quantify the association between sample-level metadata and cell state variation. We benchmarked MrVI against conventional meta-analysis procedures on two synthetic datasets and one real dataset with a well-controlled experimental structure. This work introduces a novel approach to understanding sample-level heterogeneity while leveraging the full resolution of single-cell sequencing data.Competing Interest StatementN.Y. is an advisor and/or has equity in Cellarity, Celsius Therapeutics, and Rheos Medicine.},
48 | URL = {https://www.biorxiv.org/content/early/2022/10/06/2022.10.04.510898},
49 | eprint = {https://www.biorxiv.org/content/early/2022/10/06/2022.10.04.510898.full.pdf},
50 | journal = {bioRxiv}
51 | }
52 | ```
53 |
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | # Run to check if valid
2 | # curl --data-binary @codecov.yml https://codecov.io/validate
3 | coverage:
4 | status:
5 | project:
6 | default:
7 | target: 80%
8 | threshold: 1%
9 | patch: off
10 |
--------------------------------------------------------------------------------
/data/synthetic/scripts/symSim.R:
--------------------------------------------------------------------------------
1 | library(SymSim) # Uses the fork https://github.com/justjhong/SymSim to expose DiscreteEVF
2 | library(ape)
3 | library(anndata)
4 |
5 | genCorTree <- function(nTips, cor) {
6 | root <- stree(1, tip.label = "branching_point")
7 | tip <- stree(nTips, tip.label = sprintf("CT%i:1", 1:nTips))
8 | tree <- bind.tree(root, tip, where = 1)
9 | tree$edge.length <- c(cor, rep(1 - cor, nTips))
10 | return(tree)
11 | }
12 |
13 | genMetaEVFs <- function(n_cells_total, nCellTypes, nMetadata, metadataCor, n_nd_evf, nEVFsCellType, nEVFsPerMetadata, save_dir = NULL, randseed = 1) {
14 | set.seed(randseed)
15 | seed <- sample(c(1:1e5), size = nMetadata + 2)
16 |
17 | if (length(metadataCor == 1)) {
18 | metadataCor <- rep(metadataCor, nMetadata)
19 | }
20 | if (length(nEVFsPerMetadata == 1)) {
21 | nEVFsPerMetadata <- rep(nEVFsPerMetadata, nMetadata)
22 | }
23 |
24 | ct_tree <- genCorTree(nCellTypes, 0)
25 | base_evf_res <- DiscreteEVF(ct_tree, n_cells_total, n_cells_total / nCellTypes, 1, 0.4, n_nd_evf, nEVFsCellType, "all", 1, seed[[1]])
26 | evf_mtxs <- base_evf_res[[1]]
27 | base_evf_mtx_ncols <- c()
28 | for (j in 1:3) {
29 | colnames(evf_mtxs[[j]]) <- paste(colnames(evf_mtxs[[j]]), "base", sep = "_")
30 | base_evf_mtx_ncols <- c(base_evf_mtx_ncols, ncol(evf_mtxs[[j]]))
31 | }
32 |
33 | ct_mapping <- ct_tree$tip.label
34 | names(ct_mapping) <- seq_len(length(ct_mapping))
35 | meta <- data.frame("celltype" = ct_mapping[base_evf_res[[2]]$pop])
36 |
37 | for (i in 1:nMetadata) {
38 | shuffled_row_idxs <- sample(1:n_cells_total)
39 | meta_tree <- genCorTree(2, metadataCor[[i]])
40 | meta_evf_res <- DiscreteEVF(meta_tree, n_cells_total, n_cells_total / 2, 1, 1.0, 0, nEVFsPerMetadata[[i]], "all", 1, seed[[i + 2]])
41 | meta_evf_mtxs <- meta_evf_res[[1]]
42 | for (j in 1:3) {
43 | colnames(meta_evf_mtxs[[j]]) <- paste(colnames(meta_evf_mtxs[[j]]), sprintf("meta_%d", i), sep = "_")
44 | evf_mtxs[[j]] <- cbind(evf_mtxs[[j]], meta_evf_mtxs[[j]][shuffled_row_idxs, ])
45 | }
46 |
47 | meta_mapping <- meta_tree$tip.label
48 | names(meta_mapping) <- seq_len(length(meta_mapping))
49 | meta[sprintf("meta_%d", i)] <- meta_mapping[meta_evf_res[[2]]$pop][shuffled_row_idxs]
50 | }
51 |
52 |
53 | # random meta_evfs for cell type 2
54 | nc_tree <- genCorTree(1, 0)
55 | nc_evf_res <- DiscreteEVF(nc_tree, n_cells_total / 2, n_cells_total / 2, 1, 1.0, 0, sum(nEVFsPerMetadata), "all", 1, seed[[2]])
56 | nc_evf_mtxs <- nc_evf_res[[1]]
57 | ct2_idxs <- which(meta$celltype == "CT2:1")
58 | for (j in 1:3) {
59 | evf_mtxs[[j]][ct2_idxs, (base_evf_mtx_ncols[[j]] + 1):ncol(evf_mtxs[[j]])] = nc_evf_mtxs[[j]]
60 | }
61 |
62 | return(list(evf_mtxs, meta, nc_evf_mtxs))
63 | }
64 |
65 |
66 | MetaSim <- function(nMetadata, metadataCor, nEVFsPerMetadata, nEVFsCellType, write = F, save_path = NULL, randseed = 1) {
67 | ncells <- 20000
68 | ngenes <- 2000
69 | data(gene_len_pool)
70 | gene_len <- sample(gene_len_pool, ngenes, replace = FALSE)
71 |
72 | evf_res <- genMetaEVFs(n_cells_total = ncells, nCellTypes = 2, nMetadata = nMetadata, metadataCor = metadataCor, n_nd_evf = 60, nEVFsCellType = nEVFsCellType, nEVFsPerMetadata = nEVFsPerMetadata, randseed = randseed)
73 |
74 | print("simulating true counts")
75 | true_counts_res <- SimulateTrueCountsFromEVF(evf_res, ngenes = ngenes, randseed = randseed)
76 | rm(evf_res)
77 | gc()
78 |
79 | print("simulating observed counts")
80 | observed_counts <- True2ObservedCounts(true_counts = true_counts_res[[1]], meta_cell = true_counts_res[[3]], protocol = "UMI", alpha_mean = 0.05, alpha_sd = 0.02, gene_len = gene_len, depth_mean = 5e4, depth_sd = 3e3)
81 | rm(true_counts_res)
82 | gc()
83 |
84 | print("simulating batch effects")
85 | observed_counts_2batches <- DivideBatches(observed_counts_res = observed_counts, nbatch = 2, batch_effect_size = 1)
86 | rm(observed_counts)
87 | gc()
88 |
89 | print("converting to anndata")
90 | meta_keys <- c("celltype", "batch")
91 | meta_keys <- c(meta_keys, paste("meta_", 1:nMetadata, sep = ""))
92 | ad_results <- AnnData(X = t(observed_counts_2batches$counts), obs = observed_counts_2batches$cell_meta[meta_keys])
93 |
94 | if (write == T && !is.null(save_path)) {
95 | write_h5ad(ad_results, save_path)
96 | }
97 |
98 | return(list(ad_results, evf_res))
99 | }
100 |
101 | meta_results <- MetaSim(3, metadataCor = c(0, 0.5, 0.9), nEVFsPerMetadata = 7, nEVFsCellType = 40, write = T, save_path = "3_meta_sim_20k.h5ad", randseed = 126)
--------------------------------------------------------------------------------
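As a quick sanity check of the simulated dataset written by `MetaSim()` above (a sketch, assuming the Python `anndata` package is available; the file name and column names are taken from the script):

```python
import anndata

# Load the synthetic dataset written by symSim.R.
adata = anndata.read_h5ad("3_meta_sim_20k.h5ad")
print(adata.shape)  # expected: (20000, 2000), cells x genes

# Metadata simulated above: cell type, batch, and the three metadata factors.
print(adata.obs[["celltype", "batch", "meta_1", "meta_2", "meta_3"]].head())
```
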
/data/synthetic/scripts/symsim_env.yaml:
--------------------------------------------------------------------------------
1 | name: symsim
2 | channels:
3 | - r
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=conda_forge
8 | - _openmp_mutex=4.5=2_gnu
9 | - _r-mutex=1.0.0=anacondar_1
10 | - binutils_impl_linux-64=2.36.1=h193b22a_2
11 | - bwidget=1.9.14=ha770c72_1
12 | - bzip2=1.0.8=h7f98852_4
13 | - c-ares=1.18.1=h7f98852_0
14 | - ca-certificates=2022.9.14=ha878542_0
15 | - cairo=1.16.0=ha61ee94_1014
16 | - curl=7.83.1=h2283fc2_0
17 | - expat=2.4.8=h27087fc_0
18 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
19 | - font-ttf-inconsolata=3.000=h77eed37_0
20 | - font-ttf-source-code-pro=2.038=h77eed37_0
21 | - font-ttf-ubuntu=0.83=hab24e00_0
22 | - fontconfig=2.14.0=hc2a2eb6_1
23 | - fonts-conda-ecosystem=1=0
24 | - fonts-conda-forge=1=0
25 | - freetype=2.12.1=hca18f0e_0
26 | - fribidi=1.0.10=h36c2ea0_0
27 | - gcc_impl_linux-64=12.1.0=hea43390_16
28 | - gettext=0.19.8.1=h73d1719_1008
29 | - gfortran_impl_linux-64=12.1.0=h1db8e46_16
30 | - graphite2=1.3.13=h58526e2_1001
31 | - gsl=2.7=he838d99_0
32 | - gxx_impl_linux-64=12.1.0=hea43390_16
33 | - harfbuzz=5.2.0=hf9f4e7c_0
34 | - icu=70.1=h27087fc_0
35 | - jpeg=9e=h166bdaf_2
36 | - kernel-headers_linux-64=2.6.32=he073ed8_15
37 | - keyutils=1.6.1=h166bdaf_0
38 | - krb5=1.19.3=h08a2579_0
39 | - ld_impl_linux-64=2.36.1=hea4e1c9_2
40 | - lerc=4.0.0=h27087fc_0
41 | - libblas=3.9.0=16_linux64_openblas
42 | - libcblas=3.9.0=16_linux64_openblas
43 | - libcurl=7.83.1=h2283fc2_0
44 | - libdeflate=1.14=h166bdaf_0
45 | - libedit=3.1.20191231=he28a2e2_2
46 | - libev=4.33=h516909a_1
47 | - libffi=3.4.2=h7f98852_5
48 | - libgcc-devel_linux-64=12.1.0=h1ec3361_16
49 | - libgcc-ng=12.1.0=h8d9b700_16
50 | - libgfortran-ng=12.1.0=h69a702a_16
51 | - libgfortran5=12.1.0=hdcd56e2_16
52 | - libglib=2.72.1=h2d90d5f_0
53 | - libgomp=12.1.0=h8d9b700_16
54 | - libiconv=1.16=h516909a_0
55 | - liblapack=3.9.0=16_linux64_openblas
56 | - libnghttp2=1.47.0=hff17c54_1
57 | - libopenblas=0.3.21=pthreads_h78a6416_3
58 | - libpng=1.6.38=h753d276_0
59 | - libsanitizer=12.1.0=ha89aaad_16
60 | - libssh2=1.10.0=hf14f497_3
61 | - libstdcxx-devel_linux-64=12.1.0=h1ec3361_16
62 | - libstdcxx-ng=12.1.0=ha89aaad_16
63 | - libtiff=4.4.0=h55922b4_4
64 | - libuuid=2.32.1=h7f98852_1000
65 | - libwebp-base=1.2.4=h166bdaf_0
66 | - libxcb=1.13=h7f98852_1004
67 | - libxml2=2.9.14=h22db469_4
68 | - libzlib=1.2.12=h166bdaf_3
69 | - make=4.3=hd18ef5c_1
70 | - ncurses=6.3=h27087fc_1
71 | - openssl=3.0.5=h166bdaf_2
72 | - pango=1.50.10=hc4f8a73_0
73 | - pcre=8.45=h9c3ff4c_0
74 | - pcre2=10.37=hc3806b6_1
75 | - pixman=0.40.0=h36c2ea0_0
76 | - pthread-stubs=0.4=h36c2ea0_1001
77 | - r-askpass=1.1=r42h76d94ec_0
78 | - r-base=4.2.1=ha8c3e7c_1
79 | - r-brew=1.0_7=r42h6115d3f_0
80 | - r-brio=1.1.3=r42h76d94ec_0
81 | - r-cachem=1.0.6=r42h76d94ec_0
82 | - r-callr=3.7.0=r42h6115d3f_0
83 | - r-cli=3.3.0=r42h884c59f_0
84 | - r-clipr=0.8.0=r42h6115d3f_0
85 | - r-commonmark=1.8.0=r42h76d94ec_0
86 | - r-cpp11=0.4.2=r42h6115d3f_0
87 | - r-crayon=1.5.1=r42h6115d3f_0
88 | - r-credentials=1.3.2=r42h142f84f_0
89 | - r-curl=4.3.2=r42h76d94ec_0
90 | - r-desc=1.4.1=r42h6115d3f_0
91 | - r-devtools=2.4.3=r42h6115d3f_0
92 | - r-diffobj=0.3.5=r42h76d94ec_0
93 | - r-digest=0.6.29=r42h884c59f_0
94 | - r-ellipsis=0.3.2=r42h76d94ec_0
95 | - r-evaluate=0.15=r42h6115d3f_0
96 | - r-fansi=1.0.3=r42h76d94ec_0
97 | - r-fastmap=1.1.0=r42h884c59f_0
98 | - r-fs=1.5.2=r42h884c59f_0
99 | - r-gert=1.6.0=r42h76d94ec_0
100 | - r-gh=1.3.0=r42h142f84f_0
101 | - r-gitcreds=0.1.1=r42h6115d3f_0
102 | - r-glue=1.6.2=r42h76d94ec_0
103 | - r-highr=0.9=r42h6115d3f_0
104 | - r-httr=1.4.3=r42h6115d3f_0
105 | - r-ini=0.3.1=r42h142f84f_0
106 | - r-jsonlite=1.8.0=r42h76d94ec_0
107 | - r-knitr=1.39=r42h6115d3f_0
108 | - r-lifecycle=1.0.1=r42h142f84f_0
109 | - r-magrittr=2.0.3=r42h76d94ec_0
110 | - r-memoise=2.0.1=r42h6115d3f_0
111 | - r-mime=0.12=r42h76d94ec_0
112 | - r-openssl=2.0.2=r42h76d94ec_0
113 | - r-pillar=1.7.0=r42h6115d3f_0
114 | - r-pkgbuild=1.3.1=r42h142f84f_0
115 | - r-pkgconfig=2.0.3=r42h6115d3f_0
116 | - r-pkgload=1.2.4=r42h142f84f_0
117 | - r-praise=1.0.0=r42h6115d3f_4
118 | - r-prettyunits=1.1.1=r42h142f84f_0
119 | - r-processx=3.5.3=r42h76d94ec_0
120 | - r-ps=1.7.0=r42h76d94ec_0
121 | - r-purrr=0.3.4=r42h76d94ec_0
122 | - r-r6=2.5.1=r42h6115d3f_0
123 | - r-rappdirs=0.3.3=r42h76d94ec_0
124 | - r-rcmdcheck=1.4.0=r42h142f84f_0
125 | - r-rematch2=2.1.2=r42h142f84f_0
126 | - r-remotes=2.4.2=r42h142f84f_0
127 | - r-rlang=1.0.2=r42h884c59f_0
128 | - r-roxygen2=7.2.0=r42h884c59f_0
129 | - r-rprojroot=2.0.3=r42h6115d3f_0
130 | - r-rstudioapi=0.13=r42h6115d3f_0
131 | - r-rversions=2.1.1=r42h6115d3f_0
132 | - r-sessioninfo=1.2.2=r42h142f84f_0
133 | - r-stringi=1.7.6=r42h884c59f_0
134 | - r-stringr=1.4.0=r42h6115d3f_0
135 | - r-sys=3.4=r42h76d94ec_0
136 | - r-testthat=3.1.4=r42h884c59f_0
137 | - r-tibble=3.1.7=r42h76d94ec_0
138 | - r-usethis=2.1.6=r42h142f84f_0
139 | - r-utf8=1.2.2=r42h76d94ec_0
140 | - r-vctrs=0.4.1=r42h884c59f_0
141 | - r-waldo=0.4.0=r42h6115d3f_0
142 | - r-whisker=0.4=r42h6115d3f_0
143 | - r-withr=2.5.0=r42h6115d3f_0
144 | - r-xfun=0.31=r42h76d94ec_0
145 | - r-xml2=1.3.3=r42h884c59f_0
146 | - r-xopen=1.0.0=r42h142f84f_0
147 | - r-yaml=2.3.5=r42h76d94ec_0
148 | - r-zip=2.2.0=r42h76d94ec_0
149 | - readline=8.1.2=h0f457ee_0
150 | - sed=4.8=he412f7d_0
151 | - sysroot_linux-64=2.12=he073ed8_15
152 | - tk=8.6.12=h27826a3_0
153 | - tktable=2.10=hb7b940f_3
154 | - xorg-kbproto=1.0.7=h7f98852_1002
155 | - xorg-libice=1.0.10=h7f98852_0
156 | - xorg-libsm=1.2.3=hd9c2040_1000
157 | - xorg-libx11=1.7.2=h7f98852_0
158 | - xorg-libxau=1.0.9=h7f98852_0
159 | - xorg-libxdmcp=1.1.3=h7f98852_0
160 | - xorg-libxext=1.3.4=h7f98852_1
161 | - xorg-libxrender=0.9.10=h7f98852_1003
162 | - xorg-libxt=1.2.1=h7f98852_2
163 | - xorg-renderproto=0.11.1=h7f98852_1002
164 | - xorg-xextproto=7.3.0=h7f98852_1002
165 | - xorg-xproto=7.0.31=h7f98852_1007
166 | - xz=5.2.6=h166bdaf_0
167 | - zlib=1.2.12=h166bdaf_3
168 | - zstd=1.5.2=h6239696_4
169 | prefix: /home/justin/miniconda3/envs/symsim
170 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = python -msphinx
7 | SPHINXPROJ = mrvi
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/_static/css/custom.css:
--------------------------------------------------------------------------------
1 | /* influenced by and borrowed from: https://github.com/cvxgrp/pymde/blob/main/docs_src/source/_static/css/custom.css */
2 |
3 | @import url('https://fonts.googleapis.com/css2?family=Roboto:ital,wght@0,300;0,400;0,600;1,300;1,400;1,600&display=swap');
4 |
5 | :root {
6 | --sidebarcolor: #003262;
7 | --sidebarfontcolor: #ffffff;
8 | --sidebarhover: #295e97;
9 |
10 | --bodyfontcolor: #333;
11 | --webfont: 'Roboto';
12 |
13 | --contentwidth: 1000px;
14 | }
15 |
16 | /* Fonts and text */
17 | h1, h2, h3, h4, h5, h6 {
18 | font-family: var(--webfont), 'Helvetica Neue', Helvetica, Arial, sans-serif;
19 | font-weight: 400;
20 | }
21 |
22 | h2, h3, h4, h5, h6 {
23 | padding-top: 0.25em;
24 | margin-bottom: 0.5em;
25 | }
26 |
27 | h1 {
28 | font-size: 225%;
29 | }
30 |
31 | body {
32 | font-family: var(--webfont), 'Helvetica Neue', Helvetica, Arial, sans-serif;
33 | color: var(--bodyfontcolor);
34 | }
35 |
36 | p {
37 | font-size: 1em;
38 | line-height: 150%;
39 | }
40 |
41 |
42 | /* Sidebar */
43 | .wy-side-nav-search {
44 | background-color: var(--sidebarcolor);
45 | }
46 |
47 | .wy-nav-side {
48 | background: var(--sidebarcolor);
49 | }
50 |
51 | .wy-menu-vertical header, .wy-menu-vertical p.caption {
52 | color: var(--sidebarfontcolor);
53 | }
54 |
55 | .wy-menu-vertical a {
56 | color: var(--sidebarfontcolor);
57 | }
58 |
59 | .wy-side-nav-search > div.version {
60 | color: var(--sidebarfontcolor);
61 | }
62 |
63 | .wy-menu-vertical a:hover {
64 | background-color: var(--sidebarhover);
65 | }
66 |
67 | /* Main content */
68 | .wy-nav-content {
69 | max-width: var(--contentwidth);
70 | }
71 |
72 |
73 | html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) dl:not(.field-list) > dt{
74 | margin-bottom: 6px;
75 | border-left: none;
76 | background: none;
77 | color: #555;
78 | }
79 |
--------------------------------------------------------------------------------
/docs/_templates/autosummary/class.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. add toctree option to make autodoc generate the pages
6 |
7 | .. autoclass:: {{ objname }}
8 |
9 | {% block attributes %}
10 | {% if attributes %}
11 | .. rubric:: Attributes
12 |
13 | .. autosummary::
14 | :toctree: .
15 | {% for item in attributes %}
16 | ~{{ fullname }}.{{ item }}
17 | {%- endfor %}
18 | {% endif %}
19 | {% endblock %}
20 |
21 | {% block methods %}
22 | {% if methods %}
23 | .. rubric:: Methods
24 |
25 | .. autosummary::
26 | :toctree: .
27 | {% for item in methods %}
28 | {%- if item != '__init__' %}
29 | ~{{ fullname }}.{{ item }}
30 | {%- endif -%}
31 | {%- endfor %}
32 | {% endif %}
33 | {% endblock %}
34 |
--------------------------------------------------------------------------------
/docs/_templates/class_no_inherited.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. add toctree option to make autodoc generate the pages
6 |
7 | .. autoclass:: {{ objname }}
8 |
9 | {% block methods %}
10 | {% if methods %}
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 | :toctree: .
15 | {% for item in methods %}
16 | {%- if item != '__init__' and item not in inherited_members%}
17 | ~{{ fullname }}.{{ item }}
18 | {%- endif -%}
19 |
20 | {%- endfor %}
21 | {% endif %}
22 | {% endblock %}
23 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | import sys
10 | from pathlib import Path
11 |
12 | HERE = Path(__file__).parent
13 | sys.path[:0] = [str(HERE.parent), str(HERE / "extensions")]
14 |
15 | import mrvi # noqa
16 |
17 |
18 | # -- General configuration ---------------------------------------------
19 |
20 | # If your documentation needs a minimal Sphinx version, state it here.
21 | #
22 | needs_sphinx = "3.0" # Nicer param docs
23 |
24 | # Add any Sphinx extension module names here, as strings. They can be
25 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
26 | extensions = [
27 | "sphinx.ext.autodoc",
28 | "sphinx.ext.viewcode",
29 | "nbsphinx",
30 | "nbsphinx_link",
31 | "sphinx.ext.mathjax",
32 | "sphinx.ext.napoleon",
33 | "sphinx_autodoc_typehints", # needs to be after napoleon
34 | "sphinx.ext.intersphinx",
35 | "sphinx.ext.autosummary",
36 | "scanpydoc.elegant_typehints",
37 | "scanpydoc.definition_list_typed_field",
38 | "scanpydoc.autosummary_generate_imported",
39 | *[p.stem for p in (HERE / "extensions").glob("*.py")],
40 | ]
41 |
42 | # nbsphinx specific settings
43 | exclude_patterns = ["_build", "**.ipynb_checkpoints"]
44 | nbsphinx_execute = "never"
45 |
46 | # Add any paths that contain templates here, relative to this directory.
47 | templates_path = ["_templates"]
48 |
49 | # The suffix(es) of source filenames.
50 | # You can specify multiple suffix as a list of string:
51 | #
52 | # source_suffix = ['.rst', '.md']
53 | source_suffix = ".rst"
54 |
55 | # Generate the API documentation when building
56 | autosummary_generate = True
57 | autodoc_member_order = "bysource"
58 | napoleon_google_docstring = False
59 | napoleon_numpy_docstring = True
60 | napoleon_include_init_with_doc = False
61 | napoleon_use_rtype = True
62 | napoleon_use_param = True
63 | napoleon_custom_sections = [("Params", "Parameters")]
64 | todo_include_todos = False
65 | numpydoc_show_class_members = False
66 | annotate_defaults = True
67 | # The master toctree document.
68 | master_doc = "index"
69 |
70 |
71 | intersphinx_mapping = dict(
72 | anndata=("https://anndata.readthedocs.io/en/stable/", None),
73 | ipython=("https://ipython.readthedocs.io/en/stable/", None),
74 | matplotlib=("https://matplotlib.org/", None),
75 | numpy=("https://docs.scipy.org/doc/numpy/", None),
76 | pandas=("https://pandas.pydata.org/pandas-docs/stable/", None),
77 | python=("https://docs.python.org/3", None),
78 | scipy=("https://docs.scipy.org/doc/scipy/reference/", None),
79 | sklearn=("https://scikit-learn.org/stable/", None),
80 | torch=("https://pytorch.org/docs/master/", None),
81 | scanpy=("https://scanpy.readthedocs.io/en/stable/", None),
82 | pytorch_lightning=("https://pytorch-lightning.readthedocs.io/en/stable/", None),
83 | )
84 |
85 |
86 | # General information about the project.
87 | project = "mrvi"
88 | copyright = "2022, Yosef Lab, UC Berkeley"
89 | author = "Pierre Boyeau, Justin Hong, Adam Gayoso"
90 |
91 | # The version info for the project you're documenting, acts as replacement
92 | # for |version| and |release|, also used in various other places throughout
93 | # the built documents.
94 | #
95 | # The short X.Y version.
96 | version = mrvi.__version__
97 | # The full version, including alpha/beta/rc tags.
98 | release = mrvi.__version__
99 |
100 | # The language for content autogenerated by Sphinx. Refer to documentation
101 | # for a list of supported languages.
102 | #
103 | # This is also used if you do content translation via gettext catalogs.
104 | # Usually you set "language" from the command line for these cases.
105 | language = None
106 |
107 | # List of patterns, relative to source directory, that match files and
108 | # directories to ignore when looking for source files.
109 | # This patterns also effect to html_static_path and html_extra_path
110 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"]
111 |
112 | # The name of the Pygments (syntax highlighting) style to use.
113 | pygments_style = "tango"
114 |
115 | # If true, `todo` and `todoList` produce output, else they produce nothing.
116 | todo_include_todos = False
117 |
118 | # -- Options for HTML output -------------------------------------------------
119 |
120 | # The theme to use for HTML and HTML Help pages. See the documentation for
121 | # a list of builtin themes.
122 | #
123 | html_theme = "sphinx_rtd_theme"
124 |
125 | html_show_sourcelink = False
126 |
127 | html_show_copyright = False
128 |
129 | display_version = True
130 |
131 | # Add any paths that contain custom static files (such as style sheets) here,
132 | # relative to this directory. They are copied after the builtin static files,
133 | # so a file named "default.css" will overwrite the builtin "default.css".
134 | html_static_path = ["_static"]
135 |
136 | html_css_files = [
137 | "css/custom.css",
138 | ]
139 |
140 | html_favicon = "favicon.ico"
141 |
--------------------------------------------------------------------------------
/docs/extensions/typed_returns.py:
--------------------------------------------------------------------------------
1 | # code from https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py
2 | # with some minor adjustment
3 | import re
4 |
5 | from sphinx.application import Sphinx
6 | from sphinx.ext.napoleon import NumpyDocstring
7 |
8 |
9 | def process_return(lines):
10 | for line in lines:
11 |         m = re.fullmatch(r"(?P<param>\w+)\s+:\s+(?P<type>[\w.]+)", line)
12 | if m:
13 | # Once this is in scanpydoc, we can use the fancy hover stuff
14 | yield f'-{m["param"]} (:class:`~{m["type"]}`)'
15 | else:
16 | yield line
17 |
18 |
19 | def scanpy_parse_returns_section(self, section):
20 | lines_raw = list(process_return(self._dedent(self._consume_to_next_section())))
21 | lines = self._format_block(":returns: ", lines_raw)
22 | if lines and lines[-1]:
23 | lines.append("")
24 | return lines
25 |
26 |
27 | def setup(app: Sphinx):
28 | NumpyDocstring._parse_returns_section = scanpy_parse_returns_section
29 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | ========================
2 | mrvi documentation
3 | ========================
4 |
5 | Welcome! This is the documentation website for the `mrvi
6 | `_ package.
7 |
8 | .. toctree::
9 | :maxdepth: 1
10 | :hidden:
11 |
12 | installation
13 | api/index
14 | release_notes/index
15 | references
16 |
--------------------------------------------------------------------------------
/docs/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 |
4 | Prerequisites
5 | ~~~~~~~~~~~~~~
6 |
7 | `mrvi` can be installed via PyPI.
8 |
9 | conda prerequisites
10 | ###################
11 |
12 | 1. Install Conda. We typically use the Miniconda_ Python distribution. Use Python version >=3.9.
13 |
14 | 2. Create a new conda environment::
15 |
16 |     conda create -n mrvi-env python=3.9
17 |
18 | 3. Activate your environment::
19 |
20 | source activate mrvi-env
21 |
22 | pip prerequisites:
23 | ##################
24 |
25 | 1. Install Python_, we prefer the `pyenv `_ version management system, along with `pyenv-virtualenv `_.
26 |
27 | 2. Install PyTorch_. If you have an Nvidia GPU, be sure to install a version of PyTorch that supports it -- scvi-tools runs much faster with a discrete GPU.
28 |
29 | .. _Miniconda: https://conda.io/miniconda.html
30 | .. _Python: https://www.python.org/downloads/
31 | .. _PyTorch: http://pytorch.org
32 |
33 | mrvi installation
34 | ~~~~~~~~~~~~~~~~~~~~~~~
35 |
36 | Install mrvi in one of the following ways:
37 |
38 | Through **pip**::
39 |
40 |     pip install mrvi
41 |
42 | Through pip with packages to run notebooks. This installs scanpy, etc.::
43 |
44 |     pip install mrvi[tutorials]
45 |
46 | Nightly version - clone this repo and run::
47 |
48 | pip install .
49 |
50 | For development - clone this repo and run::
51 |
52 | pip install -e .[dev,docs]
53 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=python -msphinx
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 | set SPHINXPROJ=mrvi
13 |
14 | if "%1" == "" goto help
15 |
16 | %SPHINXBUILD% >NUL 2>NUL
17 | if errorlevel 9009 (
18 | echo.
19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed,
20 | echo.then set the SPHINXBUILD environment variable to point to the full
21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the
22 | echo.Sphinx directory to PATH.
23 | echo.
24 | echo.If you don't have Sphinx installed, grab it from
25 | echo.http://sphinx-doc.org/
26 | exit /b 1
27 | )
28 |
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30 | goto end
31 |
32 | :help
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34 |
35 | :end
36 | popd
37 |
--------------------------------------------------------------------------------
/docs/references.bib:
--------------------------------------------------------------------------------
1 | @article{Boyeau2022mrvi,
2 | abstract = {Contemporary single-cell omics technologies have enabled complex experimental designs incorporating hundreds of samples accompanied by detailed information on sample-level conditions. Current approaches for analyzing condition-level heterogeneity in these experiments often rely on a simplification of the data such as an aggregation at the cell-type or cell-state-neighborhood level. Here we present MrVI, a deep generative model that provides sample-sample comparisons at a single-cell resolution, permitting the discovery of subtle sample-specific effects across cell populations. Additionally, the output of MrVI can be used to quantify the association between sample-level metadata and cell state variation. We benchmarked MrVI against conventional meta-analysis procedures on two synthetic datasets and one real dataset with a well-controlled experimental structure. This work introduces a novel approach to understanding sample-level heterogeneity while leveraging the full resolution of single-cell sequencing data.Competing Interest StatementN.Y. is an advisor and/or has equity in Cellarity, Celsius Therapeutics, and Rheos Medicine.},
3 | author = {Boyeau, Pierre and Hong, Justin and Gayoso, Adam and Jordan, Michael I. and Azizi, Elham and Yosef, Nir},
4 | doi = {10.1101/2022.10.04.510898},
5 | elocation-id = {2022.10.04.510898},
6 | eprint = {https://www.biorxiv.org/content/early/2022/10/06/2022.10.04.510898.full.pdf},
7 | journal = {bioRxiv},
8 | publisher = {Cold Spring Harbor Laboratory},
9 | title = {Deep generative modeling for quantifying sample-level heterogeneity in single-cell omics},
10 | url = {https://www.biorxiv.org/content/early/2022/10/06/2022.10.04.510898},
11 | year = {2022},
12 | bdsk-url-1 = {https://www.biorxiv.org/content/early/2022/10/06/2022.10.04.510898},
13 | bdsk-url-2 = {https://doi.org/10.1101/2022.10.04.510898}}
14 |
--------------------------------------------------------------------------------
/docs/references.rst:
--------------------------------------------------------------------------------
1 | References
2 | ==========
3 | **Deep generative modeling for quantifying sample-level heterogeneity in single-cell omics**
4 | Pierre Boyeau*, Justin Hong*, Adam Gayoso, Michael I. Jordan, Elham Azizi, Nir Yosef
5 | bioRxiv 2022. `Link `_.
6 |
7 |
--------------------------------------------------------------------------------
/docs/release_notes/index.rst:
--------------------------------------------------------------------------------
1 | Release notes
2 | =============
3 |
4 | This is the list of changes to ``mrvi`` between each release. Full commit history
5 | is available in the `commit logs `_.
6 |
7 | Version 0.1
8 | -----------
9 | .. toctree::
10 | :maxdepth: 2
11 |
12 | v0.1.0
13 |
--------------------------------------------------------------------------------
/docs/release_notes/v0.1.0.rst:
--------------------------------------------------------------------------------
1 | New in 0.1.0 (2020-10-04)
2 | -------------------------
3 | Initial release of ``mrvi``. Complies with ``scvi-tools>=0.9.0a0``.
4 |
--------------------------------------------------------------------------------
/mrvi/__init__.py:
--------------------------------------------------------------------------------
1 | """scvi-tools-skeleton."""
2 |
3 | import logging
4 |
5 | from rich.console import Console
6 | from rich.logging import RichHandler
7 |
8 | from ._model import MrVI, MrVAE
9 |
10 | # https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094
11 | # https://github.com/python-poetry/poetry/issues/144#issuecomment-623927302
12 | try:
13 | import importlib.metadata as importlib_metadata
14 | except ModuleNotFoundError:
15 | import importlib_metadata
16 |
17 | package_name = "mrvi"
18 | __version__ = importlib_metadata.version(package_name)
19 |
20 | logger = logging.getLogger(__name__)
21 | # set the logging level
22 | logger.setLevel(logging.INFO)
23 |
24 | # nice logging outputs
25 | console = Console(force_terminal=True)
26 | if console.is_jupyter is True:
27 | console.is_jupyter = False
28 | ch = RichHandler(show_path=False, console=console, show_time=False)
29 | formatter = logging.Formatter("mrvi: %(message)s")
30 | ch.setFormatter(formatter)
31 | logger.addHandler(ch)
32 |
33 | # this prevents double outputs
34 | logger.propagate = False
35 |
36 | __all__ = ["MrVI", "MrVAE"]
37 |
--------------------------------------------------------------------------------
/mrvi/_components.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from scvi.distributions import NegativeBinomial
4 | from scvi.nn import one_hot
5 |
6 | from ._utils import ResnetFC
7 |
8 |
9 | class ExpActivation(nn.Module):
10 | def __init__(self) -> None:
11 | super().__init__()
12 |
13 | def forward(self, x: torch.Tensor) -> torch.Tensor:
14 | return torch.exp(x)
15 |
16 |
17 | class DecoderZX(nn.Module):
18 | """Parameterizes the counts likelihood for the data given the latent variables."""
19 |
20 | def __init__(
21 | self,
22 | n_in,
23 | n_out,
24 | n_nuisance,
25 | linear_decoder,
26 | n_hidden=128,
27 | activation="softmax",
28 | ):
29 | super().__init__()
30 | if activation == "softmax":
31 | activation_ = nn.Softmax(-1)
32 | elif activation == "softplus":
33 | activation_ = nn.Softplus()
34 | elif activation == "exp":
35 | activation_ = ExpActivation()
36 | elif activation == "sigmoid":
37 | activation_ = nn.Sigmoid()
38 | else:
39 |             raise ValueError("activation must be one of 'softmax', 'softplus', 'exp', or 'sigmoid'")
40 | self.linear_decoder = linear_decoder
41 | self.n_nuisance = n_nuisance
42 | self.n_latent = n_in - n_nuisance
43 | if linear_decoder:
44 | self.amat = nn.Linear(self.n_latent, n_out, bias=False)
45 | self.amat_site = nn.Parameter(
46 | torch.randn(self.n_nuisance, self.n_latent, n_out)
47 | )
48 | self.offsets = nn.Parameter(torch.randn(self.n_nuisance, n_out))
49 | self.dropout_ = nn.Dropout(0.1)
50 | self.activation_ = activation_
51 |
52 | else:
53 | self.px_mean = ResnetFC(
54 | n_in=n_in,
55 | n_out=n_out,
56 | n_hidden=n_hidden,
57 | activation=activation_,
58 | )
59 | self.px_r = nn.Parameter(torch.randn(n_out))
60 |
61 | def forward(self, z, size_factor):
62 | if self.linear_decoder:
63 | nuisance_oh = z[..., -self.n_nuisance :]
64 | z0 = z[..., : -self.n_nuisance]
65 | x1 = self.amat(z0)
66 |
67 | nuisance_ids = torch.argmax(nuisance_oh, -1)
68 | As = self.amat_site[nuisance_ids]
69 | z0_detach = self.dropout_(z0.detach())[..., None]
70 | x2 = (As * z0_detach).sum(-2)
71 | offsets = self.offsets[nuisance_ids]
72 | mu = x1 + x2 + offsets
73 | mu = self.activation_(mu)
74 | else:
75 | mu = self.px_mean(z)
76 | mu = mu * size_factor
77 | return NegativeBinomial(mu=mu, theta=self.px_r.exp())
78 |
79 |
80 | class LinearDecoderUZ(nn.Module):
81 | def __init__(
82 | self,
83 | n_latent,
84 | n_sample,
85 | n_out,
86 | scaler=False,
87 | scaler_n_hidden=32,
88 | ):
89 | super().__init__()
90 | self.n_latent = n_latent
91 | self.n_sample = n_sample
92 | self.n_out = n_out
93 |
94 | self.amat_sample = nn.Parameter(torch.randn(n_sample, self.n_latent, n_out))
95 | self.offsets = nn.Parameter(torch.randn(n_sample, n_out))
96 |
97 | self.scaler = None
98 | if scaler:
99 | self.scaler = nn.Sequential(
100 | nn.Linear(n_latent + n_sample, scaler_n_hidden),
101 | nn.LayerNorm(scaler_n_hidden),
102 | nn.ReLU(),
103 | nn.Linear(scaler_n_hidden, 1),
104 | nn.Sigmoid(),
105 | )
106 |
107 | def forward(self, u, sample_id):
108 | sample_id_ = sample_id.long().squeeze()
109 | As = self.amat_sample[sample_id_]
110 |
111 | u_detach = u.detach()[..., None]
112 | z2 = (As * u_detach).sum(-2)
113 | offsets = self.offsets[sample_id_]
114 | delta = z2 + offsets
115 | if self.scaler is not None:
116 | sample_oh = one_hot(sample_id, self.n_sample)
117 | if u.ndim != sample_oh.ndim:
118 | sample_oh = sample_oh[None].expand(u.shape[0], *sample_oh.shape)
119 | inputs = torch.cat([u.detach(), sample_oh], -1)
120 | delta = delta * self.scaler(inputs)
121 | return u + delta
122 |
123 |
124 | class DecoderUZ(nn.Module):
125 | def __init__(
126 | self,
127 | n_latent,
128 | n_latent_sample,
129 | n_out,
130 | dropout_rate=0.0,
131 | n_layers=1,
132 | n_hidden=128,
133 | ):
134 | super().__init__()
135 | self.n_latent = n_latent
136 | self.n_latent_sample = n_latent_sample
137 | self.n_in = n_latent + n_latent_sample
138 | self.n_out = n_out
139 |
140 | arch_mod = self.construct_arch(self.n_in, n_hidden, n_layers, dropout_rate) + [
141 | nn.Linear(n_hidden, self.n_out, bias=False)
142 | ]
143 | self.mod = nn.Sequential(*arch_mod)
144 |
145 | arch_scaler = self.construct_arch(
146 | self.n_latent, n_hidden, n_layers, dropout_rate
147 | ) + [nn.Linear(n_hidden, 1)]
148 | self.scaler = nn.Sequential(*arch_scaler)
149 | self.scaler.append(nn.Sigmoid())
150 |
151 | @staticmethod
152 | def construct_arch(n_inputs, n_hidden, n_layers, dropout_rate):
153 | """Initializes MLP architecture"""
154 |
155 | block_inputs = [
156 | nn.Linear(n_inputs, n_hidden),
157 | nn.BatchNorm1d(n_hidden),
158 | nn.Dropout(p=dropout_rate),
159 | nn.ReLU(),
160 | ]
161 |
162 | block_inner = n_layers * [
163 | nn.Linear(n_hidden, n_hidden),
164 | nn.BatchNorm1d(n_hidden),
165 | nn.ReLU(),
166 | ]
167 | return block_inputs + block_inner
168 |
169 | def forward(self, u):
170 | u_ = u.clone()
171 | if u_.dim() == 3:
172 | n_samples, n_cells, n_features = u_.shape
173 | u0_ = u_[:, :, : self.n_latent].reshape(-1, self.n_latent)
174 | u_ = u_.reshape(-1, n_features)
175 | pred_ = self.mod(u_).reshape(n_samples, n_cells, -1)
176 | scaler_ = self.scaler(u0_).reshape(n_samples, n_cells, -1)
177 | else:
178 | pred_ = self.mod(u)
179 | scaler_ = self.scaler(u[:, : self.n_latent])
180 | mean = u[..., : self.n_latent] + scaler_ * pred_
181 | return mean
182 |
--------------------------------------------------------------------------------
/mrvi/_constants.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple
2 |
3 |
4 | class _MRVI_REGISTRY_KEYS_NT(NamedTuple):
5 | SAMPLE_KEY: str = "sample"
6 | CATEGORICAL_NUISANCE_KEYS: str = "categorical_nuisance_keys"
7 |
8 |
9 | MRVI_REGISTRY_KEYS = _MRVI_REGISTRY_KEYS_NT()
10 |
--------------------------------------------------------------------------------
/mrvi/_model.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 | from copy import deepcopy
5 | from typing import List, Optional
6 |
7 | import numpy as np
8 | import torch
9 | from anndata import AnnData
10 | from scvi import REGISTRY_KEYS
11 | from scvi.data import AnnDataManager
12 | from scvi.data.fields import CategoricalJointObsField, CategoricalObsField, LayerField
13 | from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin, VAEMixin
14 | from sklearn.metrics import pairwise_distances
15 | from tqdm import tqdm
16 |
17 | from ._constants import MRVI_REGISTRY_KEYS
18 | from ._module import MrVAE
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 | DEFAULT_TRAIN_KWARGS = dict(
23 | early_stopping=True,
24 | early_stopping_patience=15,
25 | check_val_every_n_epoch=1,
26 | batch_size=256,
27 | train_size=0.9,
28 | plan_kwargs=dict(
29 | lr=1e-2,
30 | n_epochs_kl_warmup=20,
31 | ),
32 | )
33 |
34 |
35 | class MrVI(UnsupervisedTrainingMixin, VAEMixin, BaseModelClass):
36 | """
37 | Multi-resolution Variational Inference (MrVI) :cite:`Boyeau2022mrvi`.
38 |
39 | Parameters
40 | ----------
41 | adata
42 | AnnData object that has been registered via
43 | :meth:`~scvi.model.MrVI.setup_anndata`.
44 | n_latent
45 | Dimensionality of the latent space.
46 |     n_latent_sample
47 | Dimensionality of the latent space for sample embeddings.
48 | linear_decoder_zx
49 | Whether to use a linear decoder for the decoder from z to x.
50 | linear_decoder_uz
51 | Whether to use a linear decoder for the decoder from u to z.
52 | linear_decoder_uz_scaler
53 | Whether to incorporate a learned scaler term in the decoder from u to z.
54 | linear_decoder_uz_scaler_n_hidden
55 | If `linear_decoder_uz_scaler` is True, the number of hidden
56 | units in the neural network used to produce the scaler term
57 | in decoder from u to z.
58 | px_kwargs
59 | Keyword args for :class:`~mrvi.components.DecoderZX`.
60 | pz_kwargs
61 | Keyword args for :class:`~mrvi.components.DecoderUZ`.
62 | """
63 |
64 | def __init__(
65 | self,
66 | adata: AnnData,
67 | **model_kwargs,
68 | ):
69 | super().__init__(adata)
70 | n_cats_per_nuisance_keys = (
71 | self.adata_manager.get_state_registry(
72 | MRVI_REGISTRY_KEYS.CATEGORICAL_NUISANCE_KEYS
73 | ).n_cats_per_key
74 | if MRVI_REGISTRY_KEYS.CATEGORICAL_NUISANCE_KEYS
75 | in self.adata_manager.data_registry
76 | else []
77 | )
78 |
79 | n_sample = self.summary_stats.n_sample
80 | n_obs_per_sample = (
81 | adata.obs.groupby(
82 | self.adata_manager.get_state_registry(MRVI_REGISTRY_KEYS.SAMPLE_KEY)[
83 | "original_key"
84 | ]
85 | )
86 | .size()
87 | .loc[
88 | self.adata_manager.get_state_registry(MRVI_REGISTRY_KEYS.SAMPLE_KEY)[
89 | "categorical_mapping"
90 | ]
91 | ]
92 | .values
93 | )
94 | n_obs_per_sample = torch.from_numpy(n_obs_per_sample).float()
95 | self.data_splitter = None
96 | self.module = MrVAE(
97 | n_input=self.summary_stats.n_vars,
98 | n_sample=n_sample,
99 | n_obs_per_sample=n_obs_per_sample,
100 | n_cats_per_nuisance_keys=n_cats_per_nuisance_keys,
101 | **model_kwargs,
102 | )
103 | self.init_params_ = self._get_init_params(locals())
104 |
105 | @classmethod
106 | def setup_anndata(
107 | cls,
108 | adata: AnnData,
109 | layer: Optional[str] = None,
110 | sample_key: Optional[str] = None,
111 | categorical_nuisance_keys: Optional[List[str]] = None,
112 | **kwargs,
113 | ):
114 | setup_method_args = cls._get_setup_method_args(**locals())
115 | anndata_fields = [
116 | LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
117 | CategoricalObsField(MRVI_REGISTRY_KEYS.SAMPLE_KEY, sample_key),
118 | CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, None),
119 | CategoricalJointObsField(
120 | MRVI_REGISTRY_KEYS.CATEGORICAL_NUISANCE_KEYS, categorical_nuisance_keys
121 | ),
122 | ]
123 | adata_manager = AnnDataManager(
124 | fields=anndata_fields, setup_method_args=setup_method_args
125 | )
126 | adata_manager.register_fields(adata, **kwargs)
127 | cls.register_manager(adata_manager)
128 |
129 | def train(
130 | self,
131 | max_epochs: int | None = None,
132 | accelerator: str = "auto",
133 | devices: int | list[int] | str = "auto",
134 | train_size: float = 0.9,
135 | validation_size: float | None = None,
136 | batch_size: int = 128,
137 | early_stopping: bool = False,
138 | plan_kwargs: dict | None = None,
139 | **trainer_kwargs,
140 | ):
141 | train_kwargs = dict(
142 | max_epochs=max_epochs,
143 | accelerator=accelerator,
144 | devices=devices,
145 | train_size=train_size,
146 | validation_size=validation_size,
147 | batch_size=batch_size,
148 | early_stopping=early_stopping,
149 | **trainer_kwargs,
150 | )
151 | train_kwargs = dict(deepcopy(DEFAULT_TRAIN_KWARGS), **train_kwargs)
152 | plan_kwargs = plan_kwargs or {}
153 | train_kwargs["plan_kwargs"] = dict(
154 | deepcopy(DEFAULT_TRAIN_KWARGS["plan_kwargs"]), **plan_kwargs
155 | )
156 | super().train(**train_kwargs)
157 |
158 | @torch.no_grad()
159 | def get_latent_representation(
160 | self,
161 | adata: AnnData | None = None,
162 | indices: list[int] | None = None,
163 | mc_samples: int = 5000,
164 | batch_size: int | None = None,
165 | give_z: bool = False,
166 | ) -> np.ndarray:
167 | self._check_if_trained(warn=False)
168 | adata = self._validate_anndata(adata)
169 | scdl = self._make_data_loader(
170 | adata=adata, indices=indices, batch_size=batch_size
171 | )
172 |
173 | u = []
174 | z = []
175 | for tensors in tqdm(scdl):
176 | inference_inputs = self.module._get_inference_input(tensors)
177 | outputs = self.module.inference(mc_samples=mc_samples, **inference_inputs)
178 | u.append(outputs["u"].mean(0).cpu())
179 | z.append(outputs["z"].mean(0).cpu())
180 |
181 | u = torch.cat(u, 0).numpy()
182 | z = torch.cat(z, 0).numpy()
183 | return z if give_z else u
184 |
185 | @staticmethod
186 | def compute_distance_matrix_from_representations(
187 | representations: np.ndarray, metric: str = "euclidean"
188 | ) -> np.ndarray:
189 | """
190 | Compute distance matrices from counterfactual sample representations.
191 |
192 | Parameters
193 | ----------
194 | representations
195 | Counterfactual sample representations of shape
196 | (n_cells, n_sample, n_features).
197 | metric
198 | Metric to use for computing distance matrix.
199 | """
200 | n_cells, n_donors, _ = representations.shape
201 | pairwise_dists = np.zeros((n_cells, n_donors, n_donors))
202 | for i, cell_rep in enumerate(representations):
203 | d_ = pairwise_distances(cell_rep, metric=metric)
204 | pairwise_dists[i, :, :] = d_
205 | return pairwise_dists
206 |
207 | @torch.no_grad()
208 | def get_local_sample_representation(
209 | self,
210 | adata: AnnData | None = None,
211 | batch_size: int = 256,
212 | mc_samples: int = 10,
213 | return_distances: bool = False,
214 | ):
215 | """
216 | Computes the local sample representation of the cells in the adata object.
217 |
218 | For each cell, it returns a matrix of size (n_sample, n_features).
219 |
220 | Parameters
221 | ----------
222 | adata
223 | AnnData object to use for computing the local sample representation.
224 | batch_size
225 | Batch size to use for computing the local sample representation.
226 | mc_samples
227 | Number of Monte Carlo samples to use for computing the local sample
228 | representation.
229 | return_distances
230 | If ``return_distances`` is ``True``, returns a distance matrix of
231 | size (n_sample, n_sample) for each cell.
232 | """
233 | adata = self.adata if adata is None else adata
234 | self._check_if_trained(warn=False)
235 | adata = self._validate_anndata(adata)
236 | scdl = self._make_data_loader(adata=adata, indices=None, batch_size=batch_size)
237 |
238 | reps = []
239 | for tensors in tqdm(scdl):
240 | xs = []
241 | for sample in range(self.summary_stats.n_sample):
242 | cf_sample = sample * torch.ones_like(
243 | tensors[MRVI_REGISTRY_KEYS.SAMPLE_KEY]
244 | )
245 | inference_inputs = self.module._get_inference_input(tensors)
246 | inference_outputs = self.module.inference(
247 | mc_samples=mc_samples, cf_sample=cf_sample, **inference_inputs
248 | )
249 | new = inference_outputs["z"]
250 |
251 | xs.append(new[:, :, None])
252 |
253 | xs = torch.cat(xs, 2).mean(0)
254 | reps.append(xs.cpu().numpy())
255 | # n_cells, n_sample, n_latent
256 | reps = np.concatenate(reps, 0)
257 |
258 | if return_distances:
259 | return self.compute_distance_matrix_from_representations(reps)
260 |
261 | return reps
262 |
--------------------------------------------------------------------------------
/mrvi/_module.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import torch
4 | import torch.distributions as db
5 | import torch.nn as nn
6 | from scvi import REGISTRY_KEYS
7 | from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
8 | from scvi.nn import one_hot
9 | from torch.distributions import kl_divergence as kl
10 |
11 | from ._components import DecoderUZ, DecoderZX, LinearDecoderUZ
12 | from ._constants import MRVI_REGISTRY_KEYS
13 | from ._utils import ConditionalBatchNorm1d, NormalNN
14 |
15 | DEFAULT_PX_HIDDEN = 32
16 | DEFAULT_PZ_LAYERS = 1
17 | DEFAULT_PZ_HIDDEN = 32
18 |
19 |
20 | class MrVAE(BaseModuleClass):
21 | def __init__(
22 | self,
23 | n_input: int,
24 | n_sample: int,
25 |         n_obs_per_sample: torch.Tensor,
26 | n_cats_per_nuisance_keys: list[int],
27 | n_latent: int = 10,
28 | n_latent_sample: int = 2,
29 | linear_decoder_zx: bool = True,
30 | linear_decoder_uz: bool = True,
31 | linear_decoder_uz_scaler: bool = False,
32 | linear_decoder_uz_scaler_n_hidden: int = 32,
33 | px_kwargs: dict | None = None,
34 | pz_kwargs: dict | None = None,
35 | ):
36 | super().__init__()
37 |         px_kwargs = {"n_hidden": DEFAULT_PX_HIDDEN, **(px_kwargs or {})}
38 |         pz_kwargs = {
39 |             "n_layers": DEFAULT_PZ_LAYERS,
40 |             "n_hidden": DEFAULT_PZ_HIDDEN,
41 |             **(pz_kwargs or {}),
42 |         }
43 |
44 | self.n_cats_per_nuisance_keys = n_cats_per_nuisance_keys
45 | self.n_sample = n_sample
46 | assert n_latent_sample != 0
47 | self.sample_embeddings = nn.Embedding(n_sample, n_latent_sample)
48 |
49 | n_nuisance = sum(self.n_cats_per_nuisance_keys)
50 | # Generative model
51 | self.px = DecoderZX(
52 | n_latent + n_nuisance,
53 | n_input,
54 | n_nuisance=n_nuisance,
55 | linear_decoder=linear_decoder_zx,
56 | **px_kwargs,
57 | )
58 | self.qu = NormalNN(128 + n_latent_sample, n_latent, n_categories=1)
59 | self.ql = NormalNN(n_input, 1, n_categories=1)
60 |
61 | self.linear_decoder_uz = linear_decoder_uz
62 | if linear_decoder_uz:
63 | self.pz = LinearDecoderUZ(
64 | n_latent,
65 | self.n_sample,
66 | n_latent,
67 | scaler=linear_decoder_uz_scaler,
68 | scaler_n_hidden=linear_decoder_uz_scaler_n_hidden,
69 | )
70 | else:
71 | self.pz = DecoderUZ(
72 | n_latent,
73 | n_latent_sample,
74 | n_latent,
75 | **pz_kwargs,
76 | )
77 | self.n_obs_per_sample = nn.Parameter(n_obs_per_sample, requires_grad=False)
78 |
79 | self.x_featurizer = nn.Sequential(nn.Linear(n_input, 128), nn.ReLU())
80 | self.bnn = ConditionalBatchNorm1d(128, n_sample)
81 | self.x_featurizer2 = nn.Sequential(nn.Linear(128, 128), nn.ReLU())
82 | self.bnn2 = ConditionalBatchNorm1d(128, n_sample)
83 |
84 | def _get_inference_input(
85 | self, tensors: dict[str, torch.Tensor], **kwargs
86 | ) -> dict[str, torch.Tensor]:
87 | x = tensors[REGISTRY_KEYS.X_KEY]
88 | sample_index = tensors[MRVI_REGISTRY_KEYS.SAMPLE_KEY]
89 | categorical_nuisance_keys = tensors[
90 | MRVI_REGISTRY_KEYS.CATEGORICAL_NUISANCE_KEYS
91 | ]
92 | return dict(
93 | x=x,
94 | sample_index=sample_index,
95 | categorical_nuisance_keys=categorical_nuisance_keys,
96 | )
97 |
98 | @auto_move_data
99 | def inference(
100 | self,
101 | x: torch.Tensor,
102 | sample_index: torch.Tensor,
103 | categorical_nuisance_keys: torch.Tensor,
104 | mc_samples: int = 1,
105 | cf_sample: torch.Tensor | None = None,
106 | use_mean: bool = False,
107 | ) -> dict[str, torch.Tensor]:
108 | x_ = torch.log1p(x)
109 |
110 | sample_index_cf = sample_index if cf_sample is None else cf_sample
111 | zsample = self.sample_embeddings(sample_index_cf.long().squeeze(-1))
112 | zsample_ = zsample
113 | if mc_samples >= 2:
114 | zsample_ = zsample[None].expand(mc_samples, *zsample.shape)
115 |
116 | nuisance_oh = []
117 | for dim in range(categorical_nuisance_keys.shape[-1]):
118 | nuisance_oh.append(
119 | one_hot(
120 | categorical_nuisance_keys[:, [dim]],
121 | self.n_cats_per_nuisance_keys[dim],
122 | )
123 | )
124 | nuisance_oh = torch.cat(nuisance_oh, dim=-1)
125 |
126 | x_feat = self.x_featurizer(x_)
127 | x_feat = self.bnn(x_feat, sample_index)
128 | x_feat = self.x_featurizer2(x_feat)
129 | x_feat = self.bnn2(x_feat, sample_index)
130 | if x_.ndim != zsample_.ndim:
131 | x_feat_ = x_feat[None].expand(mc_samples, *x_feat.shape)
132 | nuisance_oh = nuisance_oh[None].expand(mc_samples, *nuisance_oh.shape)
133 | else:
134 | x_feat_ = x_feat
135 |
136 | inputs = torch.cat([x_feat_, zsample_], -1)
137 | # inputs = x_feat_
138 | qu = self.qu(inputs)
139 | if use_mean:
140 | u = qu.loc
141 | else:
142 | u = qu.rsample()
143 |
144 | if self.linear_decoder_uz:
145 | z = self.pz(u, sample_index_cf)
146 | else:
147 | inputs = torch.cat([u, zsample_], -1)
148 | z = self.pz(inputs)
149 | library = torch.log(x.sum(1)).unsqueeze(1)
150 |
151 | return dict(
152 | qu=qu,
153 | u=u,
154 | z=z,
155 | zsample=zsample,
156 | library=library,
157 | nuisance_oh=nuisance_oh,
158 | )
159 |
160 | def get_z(
161 | self,
162 | u: torch.Tensor,
163 | zsample: torch.Tensor | None = None,
164 | sample_index: torch.Tensor | None = None,
165 | ) -> torch.Tensor:
166 | if sample_index is not None:
167 | zsample = self.sample_embeddings(sample_index.long().squeeze(-1))
168 |             zsample_ = zsample
169 | else:
170 | zsample_ = zsample
171 | inputs = torch.cat([u, zsample_], -1)
172 | z = self.pz(inputs)
173 | return z
174 |
175 | def _get_generative_input(
176 | self,
177 | tensors: dict[str, torch.Tensor],
178 | inference_outputs: dict[str, torch.Tensor],
179 | **kwargs,
180 | ) -> dict[str, torch.Tensor]:
181 | res = dict(
182 | z=inference_outputs["z"],
183 | library=inference_outputs["library"],
184 | nuisance_oh=inference_outputs["nuisance_oh"],
185 | )
186 |
187 | return res
188 |
189 | @auto_move_data
190 | def generative(
191 | self,
192 | z: torch.Tensor,
193 | library: torch.Tensor,
194 | nuisance_oh: torch.Tensor,
195 | ) -> dict[str, torch.Tensor]:
196 | inputs = torch.concat([z, nuisance_oh], dim=-1)
197 | px = self.px(inputs, size_factor=library.exp())
198 | h = px.mu / library.exp()
199 |
200 | pu = db.Normal(0, 1)
201 | return dict(px=px, pu=pu, h=h)
202 |
203 | def loss(
204 | self,
205 | tensors: dict[str, torch.Tensor],
206 | inference_outputs: dict[str, torch.Tensor],
207 | generative_outputs: dict[str, torch.Tensor],
208 | kl_weight: float = 1.0,
209 | ) -> LossOutput:
210 | reconstruction_loss = (
211 | -generative_outputs["px"].log_prob(tensors[REGISTRY_KEYS.X_KEY]).sum(-1)
212 | )
213 | kl_u = kl(inference_outputs["qu"], generative_outputs["pu"]).sum(-1)
214 | kl_local_for_warmup = kl_u
215 |
216 | weighted_kl_local = kl_weight * kl_local_for_warmup
217 | loss = torch.mean(reconstruction_loss + weighted_kl_local)
218 |
219 | kl_local = torch.tensor(0.0)
220 | kl_global = torch.tensor(0.0)
221 |
222 | return LossOutput(
223 | loss=loss,
224 | reconstruction_loss=reconstruction_loss,
225 | kl_local=kl_local,
226 | kl_global=kl_global,
227 | )
228 |
--------------------------------------------------------------------------------
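The constructor above merges user-supplied px_kwargs/pz_kwargs over module-level defaults. A minimal sketch of that default-override idiom (the helper name below is hypothetical and only illustrates the pattern):

    DEFAULT_PX_HIDDEN = 32

    def merge_px_kwargs(user_kwargs=None):
        # User-supplied values take precedence over the defaults.
        return {"n_hidden": DEFAULT_PX_HIDDEN, **(user_kwargs or {})}

    assert merge_px_kwargs() == {"n_hidden": 32}
    assert merge_px_kwargs({"n_hidden": 64}) == {"n_hidden": 64}
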
/mrvi/_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 | import torch.distributions as db
5 | import torch.nn as nn
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | class ResnetFC(nn.Module):
11 | def __init__(
12 | self,
13 | n_in,
14 | n_out,
15 | n_hidden=128,
16 | activation=nn.Softmax(-1),
17 | ):
18 | super().__init__()
19 | self.module1 = nn.Sequential(
20 | nn.Linear(n_in, n_hidden),
21 | nn.BatchNorm1d(n_hidden),
22 | nn.ReLU(),
23 | )
24 | self.module2 = nn.Sequential(
25 | nn.Linear(n_hidden, n_out),
26 | nn.BatchNorm1d(n_out),
27 | )
28 | if n_in != n_hidden:
29 | self.id_map1 = nn.Linear(n_in, n_hidden)
30 | else:
31 | self.id_map1 = None
32 | self.activation = activation
33 |
34 | def forward(self, inputs):
35 | need_reshaping = False
36 | if inputs.ndim == 3:
37 | n_d1, nd2 = inputs.shape[:2]
38 | inputs = inputs.reshape(n_d1 * nd2, -1)
39 | need_reshaping = True
40 | h = self.module1(inputs)
41 | if self.id_map1 is not None:
42 | h = h + self.id_map1(inputs)
43 | h = self.module2(h)
44 | if need_reshaping:
45 | h = h.view(n_d1, nd2, -1)
46 | if self.activation is not None:
47 | return self.activation(h)
48 | return h
49 |
50 |
51 | class _NormalNN(nn.Module):
52 | def __init__(
53 | self,
54 | n_in,
55 | n_out,
56 | n_hidden=128,
57 | n_layers=1,
58 | use_batch_norm=True,
59 | use_layer_norm=False,
60 | do_orthogonal=False,
61 | ):
62 | super().__init__()
63 | self.n_layers = n_layers
64 |
65 | self.hidden = ResnetFC(n_in, n_out=n_hidden, activation=nn.ReLU())
66 | self._mean = nn.Linear(n_hidden, n_out)
67 | self._var = nn.Sequential(nn.Linear(n_hidden, n_out), nn.Softplus())
68 |
69 | def forward(self, inputs):
70 | if self.n_layers >= 1:
71 | h = self.hidden(inputs)
72 | mean = self._mean(h)
73 | var = self._var(h)
74 | else:
75 |             # n_layers == 0: apply the mean/variance heads directly to the inputs
76 |             mean = self._mean(inputs)
77 |             var = self._var(inputs)
78 | return mean, var
79 |
80 |
81 | class NormalNN(nn.Module):
82 | def __init__(
83 | self,
84 | n_in,
85 | n_out,
86 | n_categories,
87 | n_hidden=128,
88 | n_layers=1,
89 | use_batch_norm=True,
90 | use_layer_norm=False,
91 | ):
92 | super().__init__()
93 | nn_kwargs = dict(
94 | n_in=n_in,
95 | n_out=n_out,
96 | n_hidden=n_hidden,
97 | n_layers=n_layers,
98 | use_batch_norm=use_batch_norm,
99 | use_layer_norm=use_layer_norm,
100 | )
101 | self.n_out = n_out
102 | self._mymodules = nn.ModuleList(
103 | [_NormalNN(**nn_kwargs) for _ in range(n_categories)]
104 | )
105 |
106 | def forward(self, inputs, categories=None):
107 | means = []
108 | vars = []
109 | for idx, module in enumerate(self._mymodules):
110 | _means, _vars = module(inputs)
111 | means.append(_means[..., None])
112 | vars.append(_vars[..., None])
113 | means = torch.cat(means, -1)
114 | vars = torch.cat(vars, -1)
115 | if categories is not None:
116 | # categories (minibatch, 1)
117 | n_batch = categories.shape[0]
118 | cat_ = categories.unsqueeze(-1).long().expand(n_batch, self.n_out, 1)
119 | if means.ndim == 4:
120 | d1, n_batch, _, _ = means.shape
121 | cat_ = (
122 | categories[None, :, None].long().expand(d1, n_batch, self.n_out, 1)
123 | )
124 | means = torch.gather(means, -1, cat_)
125 | vars = torch.gather(vars, -1, cat_)
126 | means = means.squeeze(-1)
127 | vars = vars.squeeze(-1)
128 | return db.Normal(means, vars + 1e-5)
129 |
130 |
131 | class ConditionalBatchNorm1d(nn.Module):
132 | def __init__(self, num_features, num_classes):
133 | super().__init__()
134 | self.num_features = num_features
135 | self.bn = nn.BatchNorm1d(num_features, affine=False)
136 | self.embed = nn.Embedding(num_classes, num_features * 2)
137 | self.embed.weight.data[:, :num_features].normal_(
138 | 1, 0.02
139 | ) # Initialise scale at N(1, 0.02)
140 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
141 |
142 | def forward(self, x, y):
143 | need_reshaping = False
144 | if x.ndim == 3:
145 | n_d1, nd2 = x.shape[:2]
146 | x = x.view(n_d1 * nd2, -1)
147 | need_reshaping = True
148 |
149 | y = y[None].expand(n_d1, nd2, -1)
150 | y = y.contiguous().view(n_d1 * nd2, -1)
151 |
152 | out = self.bn(x)
153 | gamma, beta = self.embed(y.squeeze(-1).long()).chunk(2, 1)
154 | out = gamma * out + beta
155 |
156 | if need_reshaping:
157 | out = out.view(n_d1, nd2, -1)
158 |
159 | return out
160 |
--------------------------------------------------------------------------------
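ConditionalBatchNorm1d above applies an affine-free BatchNorm1d and then a per-category scale and shift looked up from an embedding table; MrVAE uses one category per sample. A small usage sketch on 2-D activations (the shapes below are chosen for illustration only):

    import torch

    from mrvi._utils import ConditionalBatchNorm1d  # private module path, as imported in _module.py

    bn = ConditionalBatchNorm1d(num_features=128, num_classes=4)
    x = torch.randn(32, 128)           # (minibatch, features), e.g. featurized expression
    y = torch.randint(0, 4, (32, 1))   # one category (sample) index per observation
    out = bn(x, y)                     # batch-normalized, then scaled and shifted per category
    print(out.shape)                   # torch.Size([32, 128])
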
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.isort]
2 | include_trailing_comma = true
3 | multi_line_output = 3
4 | profile = "black"
5 | skip_glob = ["docs/*", "mrvi/__init__.py"]
6 |
7 | [tool.poetry]
8 | authors = ["Pierre Boyeau ", "Justin Hong ", "Adam Gayoso "]
9 | classifiers = [
10 | "Development Status :: 4 - Beta",
11 | "Intended Audience :: Science/Research",
12 | "Natural Language :: English",
13 |   "Programming Language :: Python :: 3.9",
14 |   "Programming Language :: Python :: 3.10",
15 | "Operating System :: MacOS :: MacOS X",
16 | "Operating System :: Microsoft :: Windows",
17 | "Operating System :: POSIX :: Linux",
18 | "Topic :: Scientific/Engineering :: Bio-Informatics",
19 | ]
20 | description = "Multi-resolution analysis of single-cell data."
21 | documentation = "https://scvi-tools.org"
22 | homepage = "https://github.com/YosefLab/mrvi"
23 | license = "BSD-3-Clause"
24 | name = "mrvi"
25 | packages = [
26 | {include = "mrvi"},
27 | ]
28 | readme = "README.md"
29 | version = "0.2.0"
30 |
31 | [tool.poetry.dependencies]
32 | anndata = ">=0.7.5"
33 | black = {version = ">=20.8b1", optional = true}
34 | codecov = {version = ">=2.0.8", optional = true}
35 | flake8 = {version = ">=3.7.7", optional = true}
36 | importlib-metadata = {version = "^1.0", python = "<3.8"}
37 | ipython = {version = ">=7.1.1", optional = true}
38 | isort = {version = ">=5.7", optional = true}
39 | jupyter = {version = ">=1.0", optional = true}
40 | leidenalg = {version = "*", optional = true}
41 | loompy = {version = ">=3.0.6", optional = true}
42 | nbconvert = {version = ">=5.4.0", optional = true}
43 | nbformat = {version = ">=4.4.0", optional = true}
44 | nbsphinx = {version = "*", optional = true}
45 | nbsphinx-link = {version = "*", optional = true}
46 | pre-commit = {version = ">=2.7.1", optional = true}
47 | pydata-sphinx-theme = {version = ">=0.4.0", optional = true}
48 | pytest = {version = ">=4.4", optional = true}
49 | python = ">=3.9,<4.0"
50 | python-igraph = {version = "*", optional = true}
51 | scanpy = {version = ">=1.6", optional = true}
52 | scanpydoc = {version = ">=0.5", optional = true}
53 | scikit-misc = {version = ">=0.1.3", optional = true}
54 | scvi-tools = ">=1.0.0"
55 | sphinx = {version = ">=4.1,<4.4", optional = true}
56 | sphinx-autodoc-typehints = {version = "*", optional = true}
57 | sphinx-rtd-theme = {version = "*", optional = true}
58 | typing_extensions = {version = "*", python = "<3.8"}
59 |
60 | [tool.poetry.extras]
61 | dev = ["black", "pytest", "flake8", "codecov", "scanpy", "loompy", "jupyter", "nbformat", "nbconvert", "pre-commit", "isort"]
62 | docs = [
63 | "sphinx",
64 | "scanpydoc",
65 | "nbsphinx",
66 | "nbsphinx-link",
67 | "ipython",
68 | "pydata-sphinx-theme",
69 | "typing_extensions",
70 | "sphinx-autodoc-typehints",
71 | "sphinx-rtd-theme",
72 | "sphinxcontrib-bibtex",
73 | ]
74 |
75 | [tool.poetry.dev-dependencies]
76 |
77 | [build-system]
78 | requires = [
79 | "poetry-core>=1.0.0",
80 | ]
81 | build-backend = "poetry.core.masonry.api"
82 |
--------------------------------------------------------------------------------
/readthedocs.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | build:
3 |   os: ubuntu-22.04
4 |   tools:
5 |     python: "3.10"
6 | sphinx:
7 |   configuration: docs/conf.py
8 | python:
9 |   install:
10 |     - method: pip
11 |       path: .
12 |       extra_requirements:
13 |         - docs
14 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # This is a shim to hopefully allow GitHub to detect the package; the build itself is done with Poetry
4 |
5 | import setuptools
6 |
7 | if __name__ == "__main__":
8 | setuptools.setup(name="mrvi")
9 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YosefLab/mrvi_archive/86112e3d32e3f1d499efa5ffcace26746b022603/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_mrvi.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scvi.data import synthetic_iid
3 |
4 | from mrvi import MrVI
5 |
6 |
7 | def test_mrvi():
8 | adata = synthetic_iid()
9 | adata.obs["sample"] = np.random.choice(15, size=adata.shape[0])
10 | MrVI.setup_anndata(adata, sample_key="sample", categorical_nuisance_keys=["batch"])
11 | for linear_decoder_uz in [True, False]:
12 | for linear_decoder_zx in [True, False]:
13 | model = MrVI(
14 | adata,
15 | n_latent_sample=5,
16 | linear_decoder_zx=linear_decoder_zx,
17 | linear_decoder_uz=linear_decoder_uz,
18 | )
19 | model.train(1, check_val_every_n_epoch=1, train_size=0.5)
20 | model.history
21 |
22 | model = MrVI(
23 | adata,
24 | n_latent_sample=5,
25 | linear_decoder_zx=True,
26 | linear_decoder_uz=True,
27 | linear_decoder_uz_scaler=True,
28 | )
29 | model.train(1, check_val_every_n_epoch=1, train_size=0.5)
30 | model.get_latent_representation()
31 | assert model.get_local_sample_representation().shape == (adata.shape[0], 15, 10)
32 | assert model.get_local_sample_representation(return_distances=True).shape == (
33 | adata.shape[0],
34 | 15,
35 | 15,
36 | )
37 | # tests __repr__
38 | print(model)
39 |
--------------------------------------------------------------------------------
/workflow/Snakefile:
--------------------------------------------------------------------------------
1 | configfile: "config/config.yaml"
2 |
3 | MODELS = [
4 | "MrVISmall",
5 | "MrVILinear",
6 | "MrVILinear50",
7 | "MrVILinear50COMP",
8 | "MrVILinear10COMP",
9 | "MrVILinearLinear10COMP",
10 | "MrVILinearLinear10",
11 | "MrVILinearLinear10SCALER",
12 | "MILO",
13 | "MILOSCVI",
14 | "PCAKNN",
15 | "CompositionPCA",
16 | "CompositionSCVI",
17 | "SCVI",
18 | ]
19 |
20 | import random
21 |
22 | def get_n_replicates(dataset):
23 | dataset_config = config[dataset]
24 | n_replicates = dataset_config["nReplicates"] if "nReplicates" in dataset_config else 1
25 | return n_replicates
26 |
27 | def get_random_seeds(dataset, n_models):
28 | dataset_config = config[dataset]
29 | n_replicates = dataset_config["nReplicates"] if "nReplicates" in dataset_config else 1
30 | random_seed = dataset_config["randomSeed"] if "randomSeed" in dataset_config else 1
31 | random.seed(random_seed)
32 | rep_random_seeds = [random.randint(0, 2**32) for _ in range(n_replicates * n_models)]
33 | return rep_random_seeds
34 |
35 | rule all:
36 | input:
37 | "results/synthetic_experiment.done",
38 | "results/semisynthetic_experiment.done",
39 | "results/snrna_experiment.done",
40 |
41 | rule synthetic_experiment:
42 | output: touch("results/synthetic_experiment.done")
43 | input:
44 | expand(
45 | "results/synthetic/final_adata_{model}_{seed}.h5ad",
46 | zip,
47 | model=MODELS * get_n_replicates("synthetic"),
48 | seed=get_random_seeds("synthetic", len(MODELS)),
49 | )
50 |
51 | rule semisynthetic_experiment:
52 | output: touch("results/semisynthetic_experiment.done")
53 | input:
54 | expand(
55 | "results/semisynthetic/final_adata_{model}_{seed}.h5ad",
56 | zip,
57 | model=MODELS * get_n_replicates("semisynthetic"),
58 | seed=get_random_seeds("semisynthetic", len(MODELS)),
59 | ),
60 |
61 | rule snrna_experiment:
62 | output: touch("results/snrna_experiment.done")
63 | input:
64 | expand(
65 | "results/snrna/final_adata_{model}_{seed}.h5ad",
66 | zip,
67 | model=MODELS * get_n_replicates("snrna"),
68 | seed=get_random_seeds("snrna", len(MODELS)),
69 | )
70 |
71 |
72 | def get_s3_dataset_path(wildcards):
73 | return config[wildcards.dataset]["s3FilePath"]
74 |
75 | rule load_dataset_from_s3:
76 | params: get_s3_dataset_path
77 | output: "data/{dataset,[A-Za-z0-9]+}/adata.h5ad"
78 | shell:
79 | "aws s3 cp s3://{params} {output}"
80 |
81 | rule process_dataset:
82 | input:
83 | "data/{dataset}/adata.h5ad"
84 | output:
85 | "data/{dataset,[A-Za-z0-9]+}/adata.processed.h5ad",
86 | conda:
87 | "envs/process_data.yaml"
88 | script:
89 | "scripts/process_data.py"
90 |
91 | rule run_model:
92 | input:
93 | "data/{dataset}/adata.processed.h5ad",
94 | output:
95 | "results/{dataset}/adata_{model,[A-Za-z0-9]+}_{seed, \d+}.h5ad"
96 | threads:
97 | 8
98 | log:
99 | "logs/{dataset}_{model}_{seed}.log"
100 | conda:
101 | "envs/run_model.yaml"
102 | resources:
103 | nvidia_gpu=1
104 | script:
105 | "scripts/run_model.py"
106 |
107 | rule compute_local_scores:
108 | input:
109 | "results/{dataset}/adata_{model}_{seed}.h5ad"
110 | output:
111 | "results/{dataset}/final_adata_{model,[A-Za-z0-9]+}_{seed, \d+}.h5ad"
112 | threads:
113 | 8
114 | conda:
115 | "envs/run_model.yaml"
116 | script:
117 | "scripts/compute_local_scores.py"
118 |
--------------------------------------------------------------------------------
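The expand(..., zip, ...) targets above pair each model with its own replicate seed instead of taking a full product. A short Python sketch of the pairing that get_random_seeds produces for a hypothetical two-model, two-replicate configuration (the values are illustrative only):

    import random

    models = ["MrVISmall", "MrVILinear"]
    n_replicates = 2
    random.seed(123)
    seeds = [random.randint(0, 2**32) for _ in range(n_replicates * len(models))]

    # expand(..., zip, model=models * n_replicates, seed=seeds) emits one file per pair:
    targets = [
        f"results/synthetic/final_adata_{m}_{s}.h5ad"
        for m, s in zip(models * n_replicates, seeds)
    ]
    print(targets)  # four targets, one per (model, seed) pair
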
/workflow/config/config.yaml:
--------------------------------------------------------------------------------
1 | synthetic:
2 | s3FilePath: largedonor/symsim_new.h5ad
3 | nReplicates: 5
4 | randomSeed: 123
5 | nEpochs: 400
6 | batchSize: 256
7 | keyMapping:
8 | donorKey: donor
9 | cellTypeKey: celltype
10 | nuisanceKeys:
11 | - batch
12 | relevantKeys:
13 | - batch
14 | - donor_meta_1
15 | - donor_meta_2
16 | - donor_meta_3
17 |
18 | semisynthetic:
19 | origFilePath: None
20 | nEpochs: 400
21 | batchSize: 256
22 | s3FilePath: largedonor/scvi_pbmcs.h5ad
23 | keyMapping:
24 | donorKey: batch
25 | cellTypeKey: str_labels
26 | nuisanceKeys:
27 | - Site
28 | relevantKeys:
29 | - Site
30 | - tree_id1_0
31 | - tree_id1_1
32 | - tree_id1_2
33 | - tree_id1_3
34 | - tree_id1_4
35 | - tree_id2_0
36 | - tree_id2_1
37 | - tree_id2_2
38 | - tree_id2_3
39 | - tree_id2_4
40 |
41 | snrna:
42 | origFilePath: SCP259
43 | s3FilePath: largedonor/nucleus.h5ad
44 | nEpochs: 300
45 | batchSize: 256
46 | preprocessing:
47 | filter_genes:
48 | min_cells: 500
49 | highly_variable_genes:
50 | n_top_genes: 3000
51 | flavor: seurat_v3
52 | keyMapping:
53 | donorKey: library_uuid
54 | cellTypeKey: cell_type
55 | nuisanceKeys:
56 | - suspension_type
57 | relevantKeys:
58 | - library_uuid
59 | - suspension_type
60 |
--------------------------------------------------------------------------------
/workflow/envs/process_data.yaml:
--------------------------------------------------------------------------------
1 | channels:
2 | - conda-forge
3 | dependencies:
4 | - python=3.8
5 | - pytorch=1.11.0
6 | - pip:
7 | - scanpy==1.9.1
8 | - scikit-misc==0.1.4
9 |
--------------------------------------------------------------------------------
/workflow/envs/run_model.yaml:
--------------------------------------------------------------------------------
1 | channels:
2 | - pytorch
3 | - conda-forge
4 | - bioconda
5 | - defaults
6 | dependencies:
7 | - python~=3.8
8 | - cudatoolkit=11.3.1=ha36c431_9
9 | - scvi-tools=0.17.4
10 | - pytorch=1.11.0=py3.8_cuda11.3_cudnn8.2.0_0
11 | - torchvision=0.12.0=py38_cu113
12 | - pandas=1.3.5
13 | - leidenalg=0.8.10
14 | - scanpy=1.9.1
15 | - pip
16 | - r-essentials=4.1
17 | - bioconductor-edger==3.36.0
18 | - r-statmod=1.4.36
19 | - scipy<1.9.0
20 | - pip:
21 | - scikit-learn==1.0.2
22 | - scikit-misc==0.1.4
23 | - scanpy==1.9.1
24 | - regex
25 | - git+https://github.com/YosefLab/mrvi.git@main
26 | - git+https://github.com/emdann/milopy.git@master
27 | - mnnpy==0.1.9.5
28 |
--------------------------------------------------------------------------------
/workflow/notebooks/semisynthetic.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import glob
3 | import os
4 | import re
5 | from collections import defaultdict
6 |
7 | import ete3
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 | import pandas as pd
11 | import plotnine as p9
12 | import scanpy as sc
13 | from scipy import stats
14 | from scipy.cluster.hierarchy import linkage, to_tree
15 | from sklearn.metrics import pairwise_distances, precision_score, recall_score
16 | from statsmodels.stats.multitest import multipletests
17 | from tqdm import tqdm
18 |
19 | # %%
20 | workflow_dir = "../"
21 | base_dir = os.path.join(workflow_dir, "results/semisynthetic")
22 | full_data_path = os.path.join(workflow_dir, "data/semisynthetic/adata.processed.h5ad")
23 | input_files = glob.glob(os.path.join(base_dir, "final_adata*"))
24 |
25 | figure_dir = os.path.join(workflow_dir, "figures/semisynthetic/")
26 |
27 |
28 | model_adatas = defaultdict(list)
29 | mapper = None
30 | for file in tqdm(input_files):
31 | print(file)
32 | adata_ = sc.read_h5ad(file)
33 | file_re = re.match(r".*final_adata_([\w\d]+)_(\d+).h5ad", file)
34 | if file_re is None:
35 | continue
36 | model_name, seed = file_re.groups()
37 | print(model_name, seed)
38 |
39 | uns_keys = list(adata_.uns.keys())
40 | for uns_key in uns_keys:
41 | uns_vals = adata_.uns[uns_key]
42 | if uns_key.endswith("local_donor_rep"):
43 | meta_keys = [
44 | key
45 | for key in adata_.uns.keys()
46 | if key.endswith("_local_donor_rep_metadata")
47 | ]
48 | print(uns_key)
49 | if len(meta_keys) != 0:
50 | # SORT UNS by donor_key values
51 |
52 | meta_key = meta_keys[0]
53 | print(uns_key, meta_key)
54 | mapper = adata_.uns["mapper"]
55 | donor_key = mapper["donor_key"]
56 | metad = adata_.uns[meta_key].reset_index().sort_values(donor_key)
57 | # print(metad.index.values)
58 | uns_vals = uns_vals[:, metad.index.values]
59 | else:
60 | uns_vals = adata_.uns[uns_key]
61 | for key in uns_vals:
62 | uns_vals[key] = uns_vals[key].sort_index()
63 | # print(uns_vals[key].index)
64 |
65 | adata_.uns[uns_key] = uns_vals
66 | model_adatas[model_name].append(dict(adata=adata_, seed=seed))
67 |
68 | # %%
69 | adata_ = model_adatas["MrVILinearLinear10"][0]["adata"]
70 | affected_ct_a = adata_.obs.affected_ct1.unique()[0]
71 | ordered_tree_a = adata_.uns[
72 | "MrVILinearLinear10_local_donor_rep_metadata"
73 | ].tree_id1.sort_values()
74 | t_gt_a = ete3.Tree(
75 | "((({}, {}), ({}, {})), (({},{}), ({}, {})));".format(
76 | *["t{}".format(val + 1) for val in ordered_tree_a.index]
77 | )
78 | )
79 | d_order_a = ordered_tree_a.index
80 |
81 | affected_ct_b = adata_.obs.affected_ct2.unique()[0]
82 | ordered_tree_b = adata_.uns[
83 | "MrVILinearLinear10_local_donor_rep_metadata"
84 | ].tree_id2.sort_values()
85 | t_gt_b = ete3.Tree(
86 | "((({}, {}), ({}, {})), (({},{}), ({}, {})));".format(
87 | *["t{}".format(val + 1) for val in ordered_tree_b.index]
88 | )
89 | )
90 | d_order_b = ordered_tree_b.index
91 | print(affected_ct_a, affected_ct_b)
92 |
93 |
94 | # %%
95 |
96 | similarity_mat = np.zeros((len(ordered_tree_b), len(ordered_tree_b)))
97 | for i in range(len(ordered_tree_b)):
98 | for j in range(len(ordered_tree_b)):
99 | x = np.array([int(v) for v in ordered_tree_b.iloc[i]], dtype=bool)
100 | y = np.array([int(v) for v in ordered_tree_b.iloc[j]], dtype=bool)
101 | shared = x == y
102 | shared_ = np.where(~shared)[0]
103 | if len(shared_) == 0:
104 | similarity_mat[i, j] = 5
105 | else:
106 | similarity_mat[i, j] = np.where(~shared)[0][0]
107 | dissimilarity_mat = 5 - similarity_mat
108 | # %%
109 | ct_key = "leiden"
110 | sample_key = mapper["donor_key"]
111 |
112 | # %%
113 | # Unsupervised analysis
114 | MODELS = [
115 | dict(model_name="CompositionPCA", cell_specific=False),
116 | dict(model_name="CompositionSCVI", cell_specific=False),
117 | dict(model_name="MrVILinearLinear10", cell_specific=True),
118 | ]
119 |
120 |
121 | # %%
122 | def compute_aggregate_dmat(reps):
123 | n_cells, n_donors, _ = reps.shape
124 | pairwise_ds = np.zeros((n_donors, n_donors))
125 | for x in tqdm(reps):
126 | d_ = pairwise_distances(x, metric=METRIC)
127 | pairwise_ds += d_ / n_cells
128 | return pairwise_ds
129 |
130 |
131 | # %%
132 | METRIC = "euclidean"
133 |
134 | dist_mtxs = defaultdict(list)
135 | for model_params in MODELS:
136 | model_name = model_params["model_name"]
137 | rep_key = f"{model_name}_local_donor_rep"
138 | metadata_key = f"{model_name}_local_donor_rep_metadata"
139 | is_cell_specific = model_params["cell_specific"]
140 |
141 | for model_res in model_adatas[model_name]:
142 | adata, seed = model_res["adata"], model_res["seed"]
143 | if rep_key not in adata.uns:
144 | continue
145 | rep = adata.uns[rep_key]
146 |
147 | print(model_name)
148 | for cluster in adata.obs[ct_key].unique():
149 | if not is_cell_specific:
150 | rep_ct = rep[cluster]
151 | ss_matrix = pairwise_distances(rep_ct.values, metric=METRIC)
152 | cats = metad.set_index(sample_key).loc[rep_ct.index].index.values
153 | else:
154 | good_cells = adata.obs[ct_key] == cluster
155 | rep_ct = rep[good_cells]
156 | subobs = adata.obs[good_cells]
157 | observed_donors = subobs[sample_key].value_counts()
158 | observed_donors = observed_donors[observed_donors >= 1].index
159 | good_d_idx = metad.loc[
160 | lambda x: x[sample_key].isin(observed_donors)
161 | ].index.values
162 |
163 | ss_matrix = compute_aggregate_dmat(rep_ct[:, good_d_idx, :])
164 | cats = metad.loc[lambda x: x[sample_key].isin(observed_donors)][
165 | sample_key
166 | ].values
167 |
168 | dist_mtxs[model_name].append(
169 | dict(dist_matrix=ss_matrix, cats=cats, seed=seed, ct=cluster)
170 | )
171 |
172 |
173 | # %%
174 | # https://stackoverflow.com/questions/9364609/converting-ndarray-generated-by-hcluster-into-a-newick-string-for-use-with-ete2/17657426#17657426
175 | def linkage_to_ete(linkage_obj):
176 | R = to_tree(linkage_obj)
177 | root = ete3.Tree()
178 | root.dist = 0
179 | root.name = "root"
180 | item2node = {R.get_id(): root}
181 | to_visit = [R]
182 |
183 | while to_visit:
184 | node = to_visit.pop()
185 | cl_dist = node.dist / 2.0
186 |
187 | for ch_node in [node.get_left(), node.get_right()]:
188 | if ch_node:
189 | ch_node_id = ch_node.get_id()
190 | ch_node_name = (
191 | f"t{int(ch_node_id) + 1}" if ch_node.is_leaf() else str(ch_node_id)
192 | )
193 | ch = ete3.Tree()
194 | ch.dist = cl_dist
195 | ch.name = ch_node_name
196 |
197 | item2node[node.get_id()].add_child(ch)
198 | item2node[ch_node_id] = ch
199 | to_visit.append(ch_node)
200 | return root
201 |
202 |
203 | # %%
204 | fig, ax = plt.subplots()
205 | im = ax.imshow(dissimilarity_mat)
206 | fig.savefig(os.path.join(figure_dir, "CT B_GT_dist_matrix.svg"))
207 | plt.show()
208 | plt.close()
209 | # %%
210 | rf_df = []
211 | ete_trees = defaultdict(list)
212 | for model in dist_mtxs:
213 | for model_dist_mtx in dist_mtxs[model]:
214 | dist_mtx = model_dist_mtx["dist_matrix"]
215 | seed = model_dist_mtx["seed"]
216 | ct = model_dist_mtx["ct"]
217 | cats = model_dist_mtx["cats"]
218 | # cats = [cat[10:] if cat[:10] == "donor_meta" else cat for cat in cats]
219 | if ct == affected_ct_a:
220 | # Heatmaps
221 | fig, ax = plt.subplots()
222 | vmin = np.quantile(dist_mtx, 0.05)
223 | vmax = np.quantile(dist_mtx, 0.7)
224 | im = ax.imshow(
225 | dist_mtx[d_order_a][:, d_order_a],
226 | vmin=vmin,
227 | vmax=vmax,
228 | )
229 |
230 | # Show all ticks and label them with the respective list entries
231 |             ax.set_xticks(np.arange(len(cats)), labels=cats, fontsize=5)
232 |             ax.set_yticks(np.arange(len(cats)), labels=cats, fontsize=5)
233 |
234 | # Rotate the tick labels and set their alignment.
235 | plt.setp(
236 | ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor"
237 | )
238 |
239 | cell_type = "CT A"
240 | ax.set_title(f"{model} {cell_type}")
241 |
242 | fig.tight_layout()
243 | fig.savefig(
244 | os.path.join(figure_dir, f"{cell_type}_{model}_{seed}_dist_matrix.svg")
245 | )
246 | plt.show()
247 | plt.close()
248 |
249 | z = linkage(dist_mtx, method="ward")
250 | z_ete = linkage_to_ete(z)
251 | ete_trees[model].append(dict(t=z_ete, ct=ct))
252 |
253 | rf_dist = t_gt_a.robinson_foulds(z_ete)
254 | norm_rf = rf_dist[0] / rf_dist[1]
255 | rf_df.append(
256 | dict(
257 | ct=ct,
258 | model=model,
259 | rf=norm_rf,
260 | )
261 | )
262 |
263 | if ct == affected_ct_b:
264 | fig, ax = plt.subplots()
265 | vmin = np.quantile(dist_mtx, 0.05)
266 | vmax = np.quantile(dist_mtx, 0.7)
267 | im = ax.imshow(
268 | dist_mtx[d_order_b][:, d_order_b],
269 | vmin=vmin,
270 | vmax=vmax,
271 | )
272 |
273 | # Show all ticks and label them with the respective list entries
274 |             ax.set_xticks(np.arange(len(cats)), labels=cats, fontsize=5)
275 |             ax.set_yticks(np.arange(len(cats)), labels=cats, fontsize=5)
276 |
277 | # Rotate the tick labels and set their alignment.
278 | plt.setp(
279 | ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor"
280 | )
281 |
282 | cell_type = "CT B"
283 | ax.set_title(f"{model} {cell_type}")
284 |
285 | fig.tight_layout()
286 | fig.savefig(
287 | os.path.join(figure_dir, f"{cell_type}_{model}_{seed}_dist_matrix.svg")
288 | )
289 | plt.show()
290 | plt.close()
291 |
292 | # Dendrograms
293 | z = linkage(dist_mtx, method="ward")
294 | z_ete = linkage_to_ete(z)
295 | ete_trees[model].append(dict(t=z_ete, ct=ct))
296 |
297 | rf_dist = t_gt_b.robinson_foulds(z_ete)
298 | norm_rf = rf_dist[0] / rf_dist[1]
299 | rf_df.append(
300 | dict(
301 | ct=ct,
302 | model=model,
303 | rf=norm_rf,
304 | )
305 | )
306 | rf_df = pd.DataFrame(rf_df)
307 |
308 | # %%
309 |
310 | for ct in [0, 1]:
311 | subs = rf_df.query(f"ct == '{ct}'")
312 | pop1 = subs.query("model == 'MrVILinearLinear10'").rf
313 | pop2 = subs.query("model == 'CompositionSCVI'").rf
314 | pop3 = subs.query("model == 'CompositionPCA'").rf
315 |
316 | pval21 = stats.ttest_rel(pop1, pop2, alternative="less").pvalue
317 | print("mrVI vs CompositionSCVI", pval21)
318 |
319 | pval31 = stats.ttest_rel(pop1, pop3, alternative="less").pvalue
320 | print("mrVI vs CPCA", pval31)
321 | print()
322 |
323 | # %%
324 | rf_subplot = rf_df.loc[
325 | lambda x: x["model"].isin(
326 | [
327 | "CompositionSCVI",
328 | "CompositionPCA",
329 | "MrVILinearLinear10",
330 | ]
331 | )
332 | ]
333 |
334 | fig = (
335 | p9.ggplot(rf_subplot, p9.aes(x="model", y="rf"))
336 | + p9.geom_bar(p9.aes(fill="model"), stat="summary")
337 | + p9.coord_flip()
338 | + p9.facet_wrap("~ct")
339 | + p9.theme_classic()
340 | + p9.theme(
341 | axis_text=p9.element_text(size=12),
342 | axis_title=p9.element_text(size=15),
343 | aspect_ratio=1,
344 | strip_background=p9.element_blank(),
345 | legend_position="none",
346 | )
347 | + p9.labs(x="", y="Robinson-Foulds Distance")
348 | + p9.scale_y_continuous(expand=[0, 0])
349 | )
350 | fig.save(os.path.join(figure_dir, "semisynth_rf_dist.svg"), verbose=False)
351 | fig
352 |
353 |
354 | # %%
355 | def correct_for_mult(x):
356 | # print(x.name)
357 | if x.name.startswith("MILO"):
358 | print("MILO already FDR controlled; skipping")
359 | return x
360 | return multipletests(x, method="fdr_bh")[1]
361 |
362 |
363 | padj_dfs = []
364 | ct_obs = adata.obs[adata.uns["mapper"]["cell_type_key"]]
365 | for model in model_adatas:
366 | for model_adata in model_adatas[model]:
367 | adata = model_adata["adata"]
368 | # Selecting right columns in adata.obs
369 | sig_keys = [col for col in adata.obs.columns if col.endswith("significance")]
370 | # Adjust for multiple testing if needed
371 | padj_dfs.append(adata.obs.loc[:, sig_keys].apply(correct_for_mult, axis=0))
372 |
373 | # Select right cell subpopulations
374 | padjs = pd.concat(padj_dfs, axis=1)
375 | target_a = ct_obs == "1"
376 | good_cols_a = [col for col in padjs.columns if "tree_id2" in col]
377 |
378 |
379 | ALPHA = 0.05
380 |
381 |
382 | def get_pr(padj, target):
383 | fdr = 1.0 - precision_score(target, padj <= ALPHA, zero_division=1)
384 | tpr = recall_score(target, padj <= ALPHA, zero_division=1)
385 | res = pd.Series(dict(FDP=fdr, TPR=tpr))
386 | return res
387 |
388 |
389 | plot_df = (
390 | padjs.loc[:, good_cols_a]
391 | .apply(get_pr, target=target_a, axis=0)
392 | .stack()
393 | .to_frame("score")
394 | .reset_index()
395 | .rename(columns=dict(level_0="metric", level_1="approach"))
396 | .assign(
397 | model=lambda x: x.approach.str.split("_").str[0],
398 | test=lambda x: x.approach.str.split("_").str[-2],
399 | model_test=lambda x: x.model + " " + x.test,
400 | metadata=lambda x: x.approach.str.split("_")
401 | .str[1:-2]
402 | .apply(lambda y: "_".join(y)),
403 | )
404 | )
405 | plot_df
406 |
407 | # %%
408 | MODEL_SELECTION = [
409 | "MrVILinearLinear10",
410 | "MILOSCVI",
411 | ]
412 |
413 | plot_df_ = plot_df.loc[
414 | lambda x: (x["model"].isin(MODEL_SELECTION))
415 | & (
416 | x["test"].isin(
417 | (
418 | "ks",
419 | "LFC",
420 | )
421 | )
422 | )
423 | & (x["metadata"] != "batch")
424 | ]
425 |
426 | fig = (
427 | p9.ggplot(
428 | plot_df_.query('metric == "TPR"'),
429 | p9.aes(x="factor(metadata)", y="score", fill="model_test"),
430 | )
431 | + p9.geom_bar(stat="identity", position="dodge")
432 | # + p9.facet_wrap("~metadata", scales="free")
433 | # + p9.coord_flip()
434 | + p9.theme_classic()
435 | + p9.theme(
436 | axis_text=p9.element_text(size=12),
437 | axis_title=p9.element_text(size=15),
438 | aspect_ratio=1.5,
439 | )
440 | + p9.labs(x="", y="TPR", fill="")
441 | + p9.scale_y_continuous(expand=(0, 0))
442 | )
443 | fig.save(
444 | os.path.join(figure_dir, "semisynth_TPR_comparison_synth.svg"),
445 | )
446 | fig
447 |
448 | # %%
449 | fig = (
450 | p9.ggplot(
451 | plot_df_.query('metric == "FDP"'),
452 | p9.aes(x="factor(metadata)", y="score", fill="model_test"),
453 | )
454 | + p9.geom_bar(stat="identity", position="dodge")
455 | # + p9.facet_wrap("~metadata", scales="free")
456 | # + p9.coord_flip()
457 | + p9.theme_classic()
458 | + p9.labs(x="", y="FDP", fill="")
459 | + p9.theme(
460 | axis_text=p9.element_text(size=12),
461 | axis_title=p9.element_text(size=15),
462 | aspect_ratio=1.5,
463 | )
464 | + p9.scale_y_continuous(expand=(0, 0))
465 | )
466 | fig.save(
467 | os.path.join(figure_dir, "semisynth_FDR_comparison_synth.svg"),
468 | )
469 | fig
470 |
--------------------------------------------------------------------------------
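The notebook above scores each model by turning its donor-distance matrix into a dendrogram and comparing that dendrogram to the ground-truth tree with a normalized Robinson-Foulds distance. A self-contained sketch of that step, reusing the linkage_to_ete helper defined in the notebook (the random distance matrix and the ground-truth topology are illustrative only):

    import ete3
    import numpy as np
    from scipy.cluster.hierarchy import linkage

    rng = np.random.default_rng(0)
    dist_mtx = rng.random((8, 8))
    dist_mtx = (dist_mtx + dist_mtx.T) / 2.0  # symmetrize
    np.fill_diagonal(dist_mtx, 0.0)

    z = linkage(dist_mtx, method="ward")      # same call as in the notebook
    z_ete = linkage_to_ete(z)                 # helper defined above; leaves are named t1..t8

    t_gt = ete3.Tree("(((t1, t2), (t3, t4)), ((t5, t6), (t7, t8)));")
    rf, max_rf = t_gt.robinson_foulds(z_ete)[:2]
    print(rf / max_rf)                        # 0 = identical topology, 1 = maximally different
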
/workflow/notebooks/snrna.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import glob
3 | import os
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import pandas as pd
8 | import plotnine as p9
9 | import scanpy as sc
10 | import scipy.stats as stats
11 | from scib.metrics import silhouette, silhouette_batch
12 | from sklearn.metrics import pairwise_distances
13 | from sklearn.preprocessing import OneHotEncoder
14 | from tqdm import tqdm
15 |
16 | METRIC = "euclidean"
17 |
18 |
19 | def compute_aggregate_dmat(reps, metric="cosine"):
20 | n_cells, n_donors, _ = reps.shape
21 | pairwise_ds = np.zeros((n_donors, n_donors))
22 | for x in tqdm(reps):
23 | d_ = pairwise_distances(x, metric=metric)
24 | pairwise_ds += d_ / n_cells
25 | return pairwise_ds
26 |
27 |
28 | def return_distances_cluster_rep(rep, metric="cosine"):
29 | ordered_metad = metad.set_index(sample_key).loc[rep.index]
30 | ss_matrix = pairwise_distances(rep.values, metric=metric)
31 | cats = ordered_metad[bio_group_key].values[:, None]
32 | suspension_cats = ordered_metad[techno_key].values[:, None]
33 | return ss_matrix, cats, suspension_cats, ordered_metad
34 |
35 |
36 | def return_distances_cell_specific(rep, good_cells, metric="cosine"):
37 | good_cells = adata.obs[ct_key] == cluster
38 | rep_ct = rep[good_cells]
39 | subobs = adata.obs[good_cells]
40 |
41 | observed_donors = subobs[sample_key].value_counts()
42 | observed_donors = observed_donors[observed_donors >= 1].index
43 | meta_ = metad.reset_index()
44 | good_d_idx = meta_.loc[lambda x: x[sample_key].isin(observed_donors)].index.values
45 | ss_matrix = compute_aggregate_dmat(rep_ct[:, good_d_idx, :], metric=metric)
46 | cats = meta_.loc[good_d_idx][bio_group_key].values[:, None]
47 | suspension_cats = meta_.loc[good_d_idx][techno_key].values[:, None]
48 | return ss_matrix, cats, suspension_cats, meta_.loc[good_d_idx]
49 |
50 |
51 | # %%
52 | dataset_name = "snrna"
53 |
54 | base_dir = "workflow/results_V1/{}".format(dataset_name)
55 | full_data_path = "workflow/data/{}/adata.processed.h5ad".format(dataset_name)
56 | model_path = os.path.join(base_dir, "MrVI")
57 |
58 | input_files = glob.glob(os.path.join(base_dir, "final_adata*"))
59 | input_files
60 |
61 | # %%
62 | adata = sc.read_h5ad(input_files[0])
63 | print(adata.shape)
64 | for file in tqdm(input_files[1:]):
65 | adata_ = sc.read_h5ad(file)
66 | print(file, adata_.shape)
67 | new_cols = np.setdiff1d(adata_.obs.columns, adata.obs.columns)
68 | adata.obs.loc[:, new_cols] = adata_.obs.loc[:, new_cols]
69 |
70 | new_uns = np.setdiff1d(list(adata_.uns.keys()), list(adata.uns.keys()))
71 | for new_uns_ in new_uns:
72 | uns_vals = adata_.uns[new_uns_]
73 | if new_uns_.endswith("local_donor_rep"):
74 | meta_keys = [
75 | key
76 | for key in adata_.uns.keys()
77 | if key.endswith("_local_donor_rep_metadata")
78 | ]
79 | print(new_uns_)
80 | if len(meta_keys) != 0:
81 | # SORT UNS by donor_key values
82 | meta_key = meta_keys[0]
83 | print(new_uns_, meta_key)
84 | donor_key = adata_.uns["mapper"]["donor_key"]
85 | metad = adata_.uns[meta_key].reset_index().sort_values(donor_key)
86 | print(metad.index.values)
87 | uns_vals = uns_vals[:, metad.index.values]
88 | else:
89 | uns_vals = adata_.uns[new_uns_]
90 | for key in uns_vals:
91 | uns_vals[key] = uns_vals[key].sort_index()
92 | print(uns_vals[key].index)
93 |
94 | adata.uns[new_uns_] = uns_vals
95 |
96 | new_obsm = np.setdiff1d(list(adata_.obsm.keys()), list(adata.obsm.keys()))
97 | for new_obsm_ in new_obsm:
98 | adata.obsm[new_obsm_] = adata_.obsm[new_obsm_]
99 |
100 |
101 | # %%
102 | mapper = adata.uns["mapper"]
103 | ct_key = mapper["cell_type_key"]
104 | sample_key = mapper["donor_key"]
105 | bio_group_key = "donor_uuid"
106 | techno_key = "suspension_type"
107 |
108 | adata.obs.loc[:, "Sample_id"] = adata.obs["sample"].str[:3]
109 | metad.loc[:, "Sample_id"] = metad["sample"].str[:3]
110 | metad.loc[:, "Sample_id"] = metad.loc[:, "Sample_id"].astype("category")
111 | # %%
112 | adata.obs["sample"].value_counts()
113 |
114 |
115 | # %%
116 | # Unsupervised analysis
117 | MODELS = [
118 | dict(model_name="CompositionPCA", cell_specific=False),
119 | dict(model_name="CompositionSCVI", cell_specific=False),
120 | dict(model_name="MrVILinearLinear10", cell_specific=True),
121 | ]
122 | # %%
123 | _dd_plot_df = pd.DataFrame()
124 |
125 | for model_params in MODELS:
126 | rep_key = "{}_local_donor_rep".format(model_params["model_name"])
127 | metadata_key = "{}_local_donor_rep_metadata".format(model_params["model_name"])
128 | is_cell_specific = model_params["cell_specific"]
129 | rep = adata.uns[rep_key]
130 |
131 | for cluster in adata.obs[ct_key].unique():
132 | if not is_cell_specific:
133 | ss_matrix, cats, suspension_cats, meta_ = return_distances_cluster_rep(
134 | rep[cluster], metric=METRIC
135 | )
136 | else:
137 | good_cells = adata.obs[ct_key] == cluster
138 | ss_matrix, cats, suspension_cats, meta_ = return_distances_cell_specific(
139 | rep, good_cells, metric=METRIC
140 | )
141 |
142 | suspension_cats_oh = OneHotEncoder(sparse=False).fit_transform(suspension_cats)
143 | where_similar_sus = suspension_cats_oh @ suspension_cats_oh.T
144 | where_similar_sus = where_similar_sus.astype(bool)
145 |
146 | cats_oh = OneHotEncoder(sparse=False).fit_transform(cats)
147 | where_similar = cats_oh @ cats_oh.T
148 | n_donors = where_similar.shape[0]
149 | # where_similar = (where_similar - np.eye(where_similar.shape[0])).astype(bool)
150 | where_similar = where_similar.astype(bool)
151 | print(where_similar.shape)
152 | offdiag = ~np.eye(where_similar.shape[0], dtype=bool)
153 |
154 | if "library_uuid" not in meta_.columns:
155 | meta_ = meta_.reset_index()
156 | library_names = np.concatenate(
157 | np.array(
158 | [n_donors * [lib_name] for lib_name in meta_["library_uuid"].values]
159 | )[None]
160 | )
161 |
162 | new_vals = pd.DataFrame(
163 | dict(
164 | dist=ss_matrix.reshape(-1),
165 | is_similar=where_similar.reshape(-1),
166 | has_similar_suspension=where_similar_sus.reshape(-1),
167 | library_name1=library_names.reshape(-1),
168 | library_name2=library_names.T.reshape(-1),
169 | )
170 | ).assign(model=model_params["model_name"], cluster=cluster)
171 | _dd_plot_df = pd.concat([_dd_plot_df, new_vals], axis=0)
172 |
173 | # %%
174 | # Construct metadata
175 | final_meta = metad.copy()
176 | final_meta.loc[:, "donor_name"] = metad["sample"].str[:3]
177 |
178 | lib_to_sample_name = pd.Series(
179 | {
180 | "24723d89-8db6-4e5b-a227-5805b49bb8e6": "C41",
181 | "4059d4aa-b0d5-4b88-92f3-f5623e744c2f": "C58_TST",
182 | "7bdadd5c-74cc-4aef-9baf-cd2a75382a0c": "C58_RESEQ",
183 | "7ec5239b-b687-46ea-9c6b-9e2ea970ba21": "C72_RESEQ",
184 | "7ec5239b-b687-46ea-9c6b-9e2ea970ba21_split": "C72_RESEQ_split",
185 | "c557eece-31dc-4825-83c6-7af195076696": "C41_TST",
186 | "d62041ea-a566-4b7b-8280-2a8e5f776270": "C70_RESEQ",
187 | "da945071-1938-4ed5-b0fb-bcf5eef6f92f": "C70_TST",
188 | "f4a052f1-ffd8-4372-ae45-777811d945ee": "C72_TST",
189 | }
190 | ).to_frame("sample_name")
191 | final_meta = final_meta.merge(
192 | lib_to_sample_name, left_on="library_uuid", right_index=True
193 | ).assign(
194 | sample_name1=lambda x: x["sample_name"],
195 | sample_name2=lambda x: x["sample_name"],
196 | donor_name1=lambda x: x["donor_name"],
197 | donor_name2=lambda x: x["donor_name"],
198 | )
199 |
200 |
201 | dd_plot_df_full = (
202 | _dd_plot_df.merge(
203 | final_meta.loc[:, ["library_uuid", "sample_name1", "donor_name1"]],
204 | left_on="library_name1",
205 | right_on="library_uuid",
206 | how="left",
207 | )
208 | .merge(
209 | final_meta.loc[:, ["library_uuid", "sample_name2", "donor_name2"]],
210 | left_on="library_name2",
211 | right_on="library_uuid",
212 | how="left",
213 | )
214 | .assign(
215 | sample_name1=lambda x: pd.Categorical(
216 | x["sample_name1"], categories=np.sort(x.sample_name1.unique())
217 | ),
218 | sample_name2=lambda x: pd.Categorical(
219 | x["sample_name2"], categories=np.sort(x.sample_name2.unique())
220 | ),
221 | donor_name1=lambda x: pd.Categorical(
222 | x["donor_name1"], categories=np.sort(x.donor_name1.unique())
223 | ),
224 | donor_name2=lambda x: pd.Categorical(
225 | x["donor_name2"], categories=np.sort(x.donor_name2.unique())
226 | ),
227 | )
228 | )
229 | dd_plot_df_full.loc[:, "sample_name2_r"] = pd.Categorical(
230 | dd_plot_df_full["sample_name2"].values,
231 | categories=dd_plot_df_full["sample_name2"].cat.categories[::-1],
232 | )
233 | dd_plot_df = dd_plot_df_full.loc[lambda x: x.library_name1 != x.library_name2]
234 |
235 |
236 | # %%
237 | for model in dd_plot_df.model.unique():
238 | plot_ = dd_plot_df_full.query("cluster == 'periportal region hepatocyte'").loc[
239 | lambda x: x.model == model
240 | ]
241 | vmin, vmax = np.quantile(plot_.dist, [0.2, 0.8])
242 | plot_.loc[:, "dist_clip"] = np.clip(plot_.dist, vmin, vmax)
243 | fig = (
244 | p9.ggplot(p9.aes(x="sample_name1", y="sample_name2_r"))
245 | + p9.geom_raster(plot_, p9.aes(fill="dist_clip"))
246 | + p9.geom_tile(
247 | plot_.query("is_similar"),
248 | color="#ff3f05",
249 | fill="none",
250 | size=2,
251 | )
252 | + p9.theme_void()
253 | + p9.theme(
254 | axis_ticks=p9.element_blank(),
255 | )
256 | + p9.scale_y_discrete()
257 | + p9.labs(
258 | title=model,
259 | x="",
260 | y="",
261 | )
262 | + p9.coord_flip()
263 | + p9.scale_fill_cmap("viridis")
264 | )
265 | fig.save("figures/snrna_heatmaps_{}.svg".format(model))
266 | fig.draw()
267 |
268 |
269 | # %%
270 | gp1 = dd_plot_df.groupby(["model", "cluster", "donor_name1"])
271 | v1 = gp1.apply(lambda x: x.query("is_similar").dist.median())
272 | v2 = gp1.apply(
273 | lambda x: x.loc[lambda x: x.has_similar_suspension & ~(x.is_similar)].dist.median()
274 | )
275 |
276 | min_dist = (
277 | dd_plot_df.query("sample_name1 == 'C72_RESEQ'")
278 | .query("sample_name2 == 'C72_RESEQ_split'")
279 | .groupby(["model", "cluster"])
280 | .dist.mean()
281 | .to_frame("min_dist")
282 | .reset_index()
283 | # averages the two pairwise distances
284 | )
285 | v2_s = v2.to_frame("denom_dist").reset_index().query("donor_name1 == 'C72'")
286 | lower_bound_ratio = min_dist.merge(v2_s, on=["model", "cluster"]).assign(
287 | lower_bound_ratio=lambda x: x.min_dist / x.denom_dist
288 | )
289 |
290 |
291 | ratio = (v1 / v2).to_frame("ratio").reset_index().dropna()
292 |
293 | SUBSELECTED_MODELS = [
294 | "MrVILinearLinear10",
295 | "CompositionPCA",
296 | "CompositionSCVI",
297 | ]
298 | ratio_plot = ratio.loc[lambda x: x.model.isin(SUBSELECTED_MODELS)]
299 | lower_bound_ratio_plot = lower_bound_ratio.loc[
300 | lambda x: x.model.isin(SUBSELECTED_MODELS)
301 | ]
302 |
303 | fig = (
304 | p9.ggplot(ratio_plot, p9.aes(x="model", y="ratio", fill="model"))
305 | + p9.geom_boxplot()
306 | + p9.theme_classic()
307 | + p9.coord_flip()
308 | + p9.theme(legend_position="none")
309 | # + p9.labs(y="Ratio of mean distance between similar over dissimilar samples", x="")
310 | )
311 | fig
312 |
313 | # %%
314 |
315 |
316 | (
317 | p9.ggplot(ratio_plot, p9.aes(x="ratio", color="model"))
318 | + p9.stat_ecdf()
319 | + p9.xlim(0, 2)
320 | + p9.labs(y="ECDF", x="Ratio of mean distances")
321 | )
322 |
323 | # %%
324 | vplot = pd.DataFrame(
325 | dict(
326 | ratio=[
327 | 1,
328 | ],
329 | )
330 | )
331 |
332 | fig = (
333 | p9.ggplot(ratio_plot, p9.aes(y="ratio", x="model", fill="model"))
334 | + p9.facet_wrap("donor_name1")
335 | + p9.geom_boxplot()
336 | + p9.geom_hline(p9.aes(yintercept="ratio"), data=vplot, linetype="dashed", size=1)
337 | + p9.coord_flip()
338 | + p9.theme_classic()
339 | + p9.theme(
340 | legend_position="none",
341 | strip_background=p9.element_blank(),
342 | aspect_ratio=0.6,
343 | axis_ticks_major_y=p9.element_blank(),
344 | axis_title=p9.element_text(size=15),
345 | strip_text=p9.element_text(size=15),
346 | axis_text=p9.element_text(size=15),
347 | )
348 | + p9.labs(y="Ratio", x="")
349 | )
350 | fig.save("figures/snrna_ratios.svg")
351 | fig
352 |
353 |
354 | # %%
355 | for donor_name in ratio_plot.donor_name1.unique():
356 | subs = ratio_plot.query("donor_name1 == @donor_name")
357 | pop1 = subs.query("model == 'MrVILinearLinear10'").ratio
358 | pop2 = subs.query("model == 'CompositionSCVI'").ratio
359 | pop3 = subs.query("model == 'CompositionPCA'").ratio
360 |
361 | print(donor_name)
362 | pval21 = stats.ttest_rel(pop1, pop2, alternative="less").pvalue
363 | print("mrVI vs CompositionSCVI", pval21)
364 |
365 | pval31 = stats.ttest_rel(pop1, pop3, alternative="less").pvalue
366 | print("mrVI vs CPCA", pval31)
367 | print()
368 |
369 |
370 | # %%
371 | MODELS
372 |
373 |
374 | def return_distances_cell_specific(rep, good_cells, metric="cosine"):
375 | good_cells = adata.obs[ct_key] == cluster
376 | rep_ct = rep[good_cells]
377 | subobs = adata.obs[good_cells]
378 |
379 | observed_donors = subobs[sample_key].value_counts()
380 | observed_donors = observed_donors[observed_donors >= 1].index
381 | meta_ = metad.reset_index()
382 | good_d_idx = meta_.loc[lambda x: x[sample_key].isin(observed_donors)].index.values
383 | ss_matrix = compute_aggregate_dmat(rep_ct[:, good_d_idx, :], metric=metric)
384 | cats = meta_.loc[good_d_idx][bio_group_key].values[:, None]
385 | suspension_cats = meta_.loc[good_d_idx][techno_key].values[:, None]
386 | return ss_matrix, cats, suspension_cats, meta_.loc[good_d_idx]
387 |
388 |
389 | # %%
390 | _dd_plot_df = pd.DataFrame()
391 |
392 | MODELS = [
393 | dict(model_name="CompositionPCA", cell_specific=False),
394 | dict(model_name="CompositionSCVI", cell_specific=False),
395 | dict(model_name="MrVILinearLinear10", cell_specific=True),
396 | ]
397 |
398 |
399 | rdm_idx = np.random.choice(adata.n_obs, 10000, replace=False)
400 |
401 | for model_params in MODELS:
402 | rep_key = "{}_local_donor_rep".format(model_params["model_name"])
403 | metadata_key = "{}_local_donor_rep_metadata".format(model_params["model_name"])
404 | is_cell_specific = model_params["cell_specific"]
405 | rep = adata.uns[rep_key]
406 |
407 | rep_ = rep[rdm_idx]
408 | ss_matrix = compute_aggregate_dmat(rep_, metric=METRIC)
409 | ds = []
410 | donors = metad.Sample_id.cat.codes.values[..., None]
411 | select_num = OneHotEncoder(sparse=False).fit_transform(donors)
412 | where_same_d = select_num @ select_num.T
413 |
414 | donors = metad.Sample_id.cat.codes.values[..., None]
415 | select_num = OneHotEncoder(sparse=False).fit_transform(donors)
416 | where_same_d = select_num @ select_num.T
417 | for x in tqdm(rep_):
418 | d_ = pairwise_distances(x, metric=METRIC)
419 |
420 | # %%
421 | cell_reps = [
422 | "MrVILinearLinear10_cell",
423 | "SCVI_cell",
424 | ]
425 |
426 | mixing_df = []
427 |
428 | for rep in cell_reps:
429 | algo_name = rep.split("_")[0]
430 | batch_aws = silhouette_batch(
431 | adata, "suspension_type", "author_cell_type", rep, verbose=False
432 | )
433 | sample_aws = silhouette_batch(
434 | adata, "library_uuid", "author_cell_type", rep, verbose=False
435 | )
436 | ct_asw = silhouette(adata, "author_cell_type", rep)
437 |
438 | mixing_df.append(
439 | dict(
440 | batch_aws=batch_aws,
441 | sample_aws=sample_aws,
442 | algo_name=algo_name,
443 | ct_asw=ct_asw,
444 | )
445 | )
446 | mixing_df = pd.DataFrame(mixing_df)
447 | # %%
448 | mixing_df
449 | # %%
450 | for u_key in [
451 | "MrVILinearLinear10_cell",
452 | ]:
453 | sc.pp.neighbors(adata, n_neighbors=15, use_rep=u_key)
454 | sc.tl.umap(adata)
455 |
456 | savename = "_".join([dataset_name, "u_sample_suspension_type"])
457 | savename += ".png"
458 | with plt.rc_context({"figure.dpi": 500}):
459 | sc.pl.umap(
460 | adata,
461 | color=["Sample_id", "suspension_type", "author_cell_type"],
462 | title=u_key,
463 | save=savename,
464 | )
465 |
--------------------------------------------------------------------------------
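compute_aggregate_dmat above averages the per-cell (n_donors, n_donors) distance matrices over all cells. A tiny worked example with two cells and three donors (the numbers are illustrative only):

    import numpy as np
    from sklearn.metrics import pairwise_distances

    reps = np.array(
        [
            [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]],  # cell 1: 3 donors x 2 features
            [[0.0, 0.0], [3.0, 0.0], [0.0, 4.0]],  # cell 2
        ]
    )
    n_cells = reps.shape[0]
    agg = sum(pairwise_distances(x, metric="euclidean") for x in reps) / n_cells
    print(agg.round(3))
    # [[0.    2.    3.   ]
    #  [2.    0.    3.618]
    #  [3.    3.618 0.   ]]
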
/workflow/notebooks/synthetic.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import glob
3 | import os
4 | import re
5 | from collections import defaultdict
6 | from itertools import product
7 |
8 | import ete3
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | import pandas as pd
12 | import scanpy as sc
13 | import seaborn as sns
14 | from scipy.cluster.hierarchy import dendrogram, linkage, to_tree
15 | from sklearn.metrics import average_precision_score, pairwise_distances, precision_score
16 | from statsmodels.stats.multitest import multipletests
17 | from tqdm import tqdm
18 |
19 | # %%
20 | workflow_dir = "../"
21 | base_dir = os.path.join(workflow_dir, "results/synthetic")
22 | full_data_path = os.path.join(workflow_dir, "data/synthetic/adata.processed.h5ad")
23 | input_files = glob.glob(os.path.join(base_dir, "final_adata*"))
24 |
25 | figure_dir = os.path.join(workflow_dir, "figures/synthetic/")
26 |
27 | # %%
28 | # Ground truth similarity matrix
29 | meta_corr = [0, 0.5, 0.9]
30 | donor_combos = list(product(*[[0, 1] for _ in range(len(meta_corr))]))
31 |
32 | # E[||x - y||_2^2]
33 | meta_dist = 2 - 2 * np.array(meta_corr)
34 | dist_mtx = np.zeros((len(donor_combos), len(donor_combos)))
35 | for i in range(len(donor_combos)):
36 | for j in range(i):
37 | donor_i, donor_j = donor_combos[i], donor_combos[j]
38 | donor_diff = abs(np.array(donor_i) - np.array(donor_j))
39 | dist_mtx[i, j] = dist_mtx[j, i] = np.sum(donor_diff * meta_dist)
40 | dist_mtx = np.sqrt(dist_mtx) # E[||x - y||_2]
41 |
42 | donor_replicates = 4
43 |
44 | gt_donor_combos = sum([donor_replicates * [str(dc)] for dc in donor_combos], [])
45 | gt_dist_mtx = np.zeros((len(gt_donor_combos), len(gt_donor_combos)))
46 | for i in range(len(gt_donor_combos)):
47 | for j in range(i + 1):
48 | gt_dist_mtx[i, j] = gt_dist_mtx[j, i] = dist_mtx[
49 | i // donor_replicates, j // donor_replicates
50 | ]
51 | gt_control_dist_mtx = 1 - np.eye(len(gt_donor_combos))
52 |
53 | # %%
54 | model_adatas = defaultdict(list)
55 |
56 | mapper = None
57 | for file in tqdm(input_files):
58 | adata_ = sc.read_h5ad(file)
59 | file_re = re.match(r".*final_adata_([\w\d]+)_(\d+).h5ad", file)
60 | if file_re is None:
61 | continue
62 | model_name, seed = file_re.groups()
63 | print(model_name, seed)
64 |
65 | uns_keys = list(adata_.uns.keys())
66 | for uns_key in uns_keys:
67 | uns_vals = adata_.uns[uns_key]
68 | if uns_key.endswith("local_donor_rep"):
69 | meta_keys = [
70 | key
71 | for key in adata_.uns.keys()
72 | if key.endswith("_local_donor_rep_metadata")
73 | ]
74 | print(uns_key)
75 | if len(meta_keys) != 0:
76 | # SORT UNS by donor_key values
77 | meta_key = meta_keys[0]
78 | print(uns_key, meta_key)
79 | mapper = adata_.uns["mapper"]
80 | donor_key = mapper["donor_key"]
81 | metad = adata_.uns[meta_key].reset_index().sort_values(donor_key)
82 | # print(metad.index.values)
83 | uns_vals = uns_vals[:, metad.index.values]
84 | else:
85 | uns_vals = adata_.uns[uns_key]
86 | for key in uns_vals:
87 | uns_vals[key] = uns_vals[key].sort_index()
88 | # print(uns_vals[key].index)
89 |
90 | adata_.uns[uns_key] = uns_vals
91 | model_adatas[model_name].append(dict(adata=adata_, seed=seed))
92 |
93 | # %%
94 | ct_key = mapper["cell_type_key"]
95 | sample_key = mapper["donor_key"]
96 |
97 | # %%
98 | # Unsupervised analysis
99 | MODELS = [
100 | dict(model_name="CompositionPCA", cell_specific=False),
101 | dict(model_name="CompositionSCVI", cell_specific=False),
102 | dict(model_name="MrVISmall", cell_specific=True),
103 | dict(model_name="MrVILinear", cell_specific=True),
104 | dict(model_name="MrVILinear50", cell_specific=True),
105 | dict(model_name="MrVILinear10COMP", cell_specific=True),
106 | dict(model_name="MrVILinear50COMP", cell_specific=True),
107 | ]
108 |
109 |
110 | # %%
111 | def compute_aggregate_dmat(reps):
112 | # return pairwise_distances(reps.mean(0))
113 | n_cells, n_donors, _ = reps.shape
114 | pairwise_ds = np.zeros((n_donors, n_donors))
115 | for x in tqdm(reps):
116 | d_ = pairwise_distances(x, metric=METRIC)
117 | pairwise_ds += d_ / n_cells
118 | return pairwise_ds
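# Assumed shape convention: `reps` is (n_cells, n_donors, n_latent), and the result is the
# (n_donors, n_donors) distance matrix averaged over cells, e.g.
# compute_aggregate_dmat(np.random.randn(100, 8, 5)) would return an (8, 8) array.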
119 |
120 |
121 | # %%
122 | METRIC = "euclidean"
123 |
124 | dist_mtxs = defaultdict(list)
125 | for model_params in MODELS:
126 | model_name = model_params["model_name"]
127 | rep_key = f"{model_name}_local_donor_rep"
128 | metadata_key = f"{model_name}_local_donor_rep_metadata"
129 | is_cell_specific = model_params["cell_specific"]
130 |
131 | for model_res in model_adatas[model_name]:
132 | adata, seed = model_res["adata"], model_res["seed"]
133 | if rep_key not in adata.uns:
134 | continue
135 | rep = adata.uns[rep_key]
136 |
137 | print(model_name)
138 | for cluster in adata.obs[ct_key].unique():
139 | if not is_cell_specific:
140 | rep_ct = rep[cluster]
141 | ss_matrix = pairwise_distances(rep_ct.values, metric=METRIC)
142 | cats = metad.set_index(sample_key).loc[rep_ct.index].index.values
143 | else:
144 | good_cells = adata.obs[ct_key] == cluster
145 | rep_ct = rep[good_cells]
146 | subobs = adata.obs[good_cells]
147 | observed_donors = subobs[sample_key].value_counts()
148 | observed_donors = observed_donors[observed_donors >= 1].index
149 | good_d_idx = metad.loc[
150 | lambda x: x[sample_key].isin(observed_donors)
151 | ].index.values
152 |
153 | ss_matrix = compute_aggregate_dmat(rep_ct[:, good_d_idx, :])
154 | cats = metad.loc[lambda x: x[sample_key].isin(observed_donors)][
155 | sample_key
156 | ].values
157 |
158 | dist_mtxs[model_name].append(
159 | dict(dist_matrix=ss_matrix, cats=cats, seed=seed, ct=cluster)
160 | )
161 |
162 | # %%
163 | dist_mtxs["GroundTruth"] = [
164 | dict(dist_matrix=gt_dist_mtx, cats=gt_donor_combos, seed=None, ct="CT1:1"),
165 | dict(dist_matrix=gt_control_dist_mtx, cats=gt_donor_combos, seed=None, ct="CT2:1"),
166 | ]
167 |
168 |
169 | # %%
170 | # https://stackoverflow.com/questions/9364609/converting-ndarray-generated-by-hcluster-into-a-newick-string-for-use-with-ete2/17657426#17657426
171 | def linkage_to_ete(linkage_obj):
172 | R = to_tree(linkage_obj)
173 | root = ete3.Tree()
174 | root.dist = 0
175 | root.name = "root"
176 | item2node = {R.get_id(): root}
177 | to_visit = [R]
178 |
179 | while to_visit:
180 | node = to_visit.pop()
181 | cl_dist = node.dist / 2.0
182 |
183 | for ch_node in [node.get_left(), node.get_right()]:
184 | if ch_node:
185 | ch_node_id = ch_node.get_id()
186 | ch_node_name = (
187 | f"t{int(ch_node_id) + 1}" if ch_node.is_leaf() else str(ch_node_id)
188 | )
189 | ch = ete3.Tree()
190 | ch.dist = cl_dist
191 | ch.name = ch_node_name
192 |
193 | item2node[node.get_id()].add_child(ch)
194 | item2node[ch_node_id] = ch
195 | to_visit.append(ch_node)
196 | return root
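# Minimal usage sketch (hypothetical inputs): convert a SciPy linkage matrix into an ete3 tree, e.g.
#   z_demo = linkage(np.random.rand(6, 4), method="ward")
#   print(linkage_to_ete(z_demo).get_ascii())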
197 |
198 |
199 | # %%
200 | ete_trees = defaultdict(list)
201 | for model in dist_mtxs:
202 | for model_dist_mtx in dist_mtxs[model]:
203 | dist_mtx = model_dist_mtx["dist_matrix"]
204 | seed = model_dist_mtx["seed"]
205 | ct = model_dist_mtx["ct"]
206 | cats = model_dist_mtx["cats"]
207 | cats = [cat[10:] if cat[:10] == "donor_meta" else cat for cat in cats]
208 |
209 | # Heatmaps
210 | fig, ax = plt.subplots()
211 | im = ax.imshow(dist_mtx)
212 |
213 | # Show all ticks and label them with the respective list entries
214 | ax.set_xticks(np.arange(len(cats)), labels=cats, fontsize=5)
215 | ax.set_yticks(np.arange(len(cats)), labels=cats, fontsize=5)
216 |
217 | # Rotate the tick labels and set their alignment.
218 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
219 |
220 | cell_type = "Experimental Cell Type" if ct == "CT1:1" else "Control Cell Type"
221 | ax.set_title(f"{model} {cell_type}")
222 |
223 | fig.tight_layout()
224 | fig.savefig(os.path.join(figure_dir, f"{model}_{ct}_{seed}_dist_matrix.svg"))
225 | plt.show()
226 | plt.close()
227 |
228 | # Dendrograms
229 | z = linkage(dist_mtx, method="ward")
230 |
231 | fig, ax = plt.subplots()
232 | dn = dendrogram(z, ax=ax, orientation="top")
233 | ax.set_title(f"{model} {cell_type}")
234 | fig.tight_layout()
235 | fig.savefig(os.path.join(figure_dir, f"{model}_{ct}_{seed}_dendrogram.svg"))
236 | plt.close()
237 |
238 | z_ete = linkage_to_ete(z)
239 | ete_trees[model].append(dict(t=z_ete, ct=ct))
240 |
241 | # %%
242 | # Bar/swarm plots of normalized RF distance to the ground-truth trees
243 | for gt_trees in ete_trees["GroundTruth"]:
244 | if gt_trees["ct"] == "CT1:1":
245 | gt_tree = gt_trees["t"]
246 | elif gt_trees["ct"] == "CT2:1":
247 | gt_control_tree = gt_trees["t"]
248 | else:
249 | continue
250 |
251 | rf_df_rows = []
252 | for model, ts in ete_trees.items():
253 | if model == "GroundTruth":
254 | continue
255 |
256 | for t_dict in ts:
257 | t = t_dict["t"]
258 | ct = t_dict["ct"]
259 | if ct == "CT1:1":
260 | rf_dist = gt_tree.robinson_foulds(t)
261 | elif ct == "CT2:1":
262 | rf_dist = gt_control_tree.robinson_foulds(t)
263 | else:
264 | continue
265 | norm_rf = rf_dist[0] / rf_dist[1]
266 |
267 | rf_df_rows.append((model, ct, norm_rf))
268 |
269 | rf_df = pd.DataFrame(rf_df_rows, columns=["model", "cell_type", "rf_dist"])
270 |
271 | # %%
272 | # Experimental Cell Type Plot
273 | fig, ax = plt.subplots(figsize=(2, 5))
274 | sns.barplot(
275 | data=rf_df[rf_df["cell_type"] == "CT1:1"],
276 | x="model",
277 | y="rf_dist",
278 | ax=ax,
279 | )
280 | sns.swarmplot(
281 | data=rf_df[rf_df["cell_type"] == "CT1:1"],
282 | x="model",
283 | y="rf_dist",
284 | color="black",
285 | ax=ax,
286 | )
287 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
288 | ax.set_xticklabels(ax.get_xticklabels(), fontsize=7)
289 | ax.set_xlabel("Robinson-Foulds Distance")
290 | ax.set_ylabel("Model")
291 | ax.set_title("Robinson-Foulds Comparison to Ground Truth for Experimental Cell Type")
292 | fig.tight_layout()
293 | fig.savefig(
294 | os.path.join(figure_dir, "robinson_foulds_CT1:1_boxplot.svg"), bbox_inches="tight"
295 | )
296 | plt.show()
297 |
298 |
299 | # %%
300 | def correct_for_mult(x):
301 | # print(x.name)
302 | if x.name.startswith("MILO"):
303 | print("MILO already FDR controlled; skipping")
304 | return x
305 | return multipletests(x, method="fdr_bh")[1]
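# multipletests returns (reject, pvals_corrected, alphacSidak, alphacBonf); index 1 holds
# the Benjamini-Hochberg-adjusted p-values, applied column-wise via DataFrame.apply below.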
306 |
307 |
308 | padj_dfs = []
309 | ct_obs = None
310 | for model in model_adatas:
311 | for model_adata in model_adatas[model]:
312 | adata = model_adata["adata"]
313 | if ct_obs is None:
314 | ct_obs = adata.obs["celltype"]
315 |
316 | # Selecting right columns in adata.obs
317 | sig_keys = [col for col in adata.obs.columns if col.endswith("significance")]
318 | # Adjust for multiple testing if needed
319 | padj_dfs.append(adata.obs.loc[:, sig_keys].apply(correct_for_mult, axis=0))
320 |
321 | # %%
322 | # Select right cell subpopulations
323 | padjs = pd.concat(padj_dfs, axis=1)
324 | target = ct_obs == "CT1:1"
325 |
326 | plot_df = (
327 | padjs.apply(
328 | lambda x: pd.Series(
329 | {
330 | **{
331 | target_fdr: 1.0
332 | - precision_score(target, x <= target_fdr, zero_division=0)
333 | for target_fdr in [0.05, 0.1, 0.2]
334 | },
335 | }
336 | ),
337 | axis=0,
338 | )
339 | .stack()
340 | .to_frame("FDP")
341 | .reset_index()
342 | .rename(columns=dict(level_0="targetFDR", level_1="approach"))
343 | .assign(
344 | model=lambda x: x.approach.str.split("_").str[0],
345 | test=lambda x: x.approach.str.split("_").str[-2],
346 | model_test=lambda x: x.model + " " + x.test,
347 | metadata=lambda x: x.approach.str.split("_")
348 | .str[1:-2]
349 | .apply(lambda y: "_".join(y)),
350 | )
351 | )
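# For each approach and target FDR threshold, FDP = 1 - precision of the calls with
# adjusted p-value <= threshold, i.e. the fraction of called cells that do not belong to
# the perturbed population (CT1:1).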
352 | # %%
353 | # Keep the KS results for MrVILinear, MrVILinear50, SCVI, and PCAKNN,
354 | # plus the LFC results for MILOSCVI.
355 | for metadata in plot_df.metadata.unique():
356 | fig, ax = plt.subplots()
357 | sns.barplot(
358 | data=plot_df[
359 | (plot_df["metadata"] == metadata)
360 | & (
361 | plot_df["model"].isin(
362 | (
363 | "MrVILinear",
364 | "MrVILinear50",
365 | "PCAKNN",
366 | "SCVI",
367 | "MILOSCVI",
368 | )
369 | )
370 | )
371 | & (
372 | plot_df["test"].isin(
373 | (
374 | # "manova",
375 | "ks",
376 | "LFC",
377 | )
378 | )
379 | )
380 | ],
381 | x="targetFDR",
382 | y="FDP",
383 | hue="model_test",
384 | ax=ax,
385 | hue_order=[
386 | "MrVILinear ks",
387 | "MrVILinear50 ks",
388 | "SCVI ks",
389 | "MILOSCVI LFC",
390 | "PCAKNN ks",
391 | ],
392 | )
393 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
394 | ax.set_xticklabels(ax.get_xticklabels(), fontsize=7)
395 | ax.set_xlabel("Target FDR")
396 | ax.set_ylabel("FDP")
397 | ax.set_title(f"False Discovery Rate Comparison ({metadata[6:]})")
398 | sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
399 |
400 | fig.tight_layout()
401 | fig.savefig(
402 | os.path.join(figure_dir, f"FDR_comparison_{metadata}_ks_only.svg"),
403 | bbox_inches="tight",
404 | )
405 | plt.show()
406 | # %%
407 | # AP Comparison
408 | ap_plot_df = (
409 | padjs.apply(
410 | lambda x: pd.Series(
411 | {
412 | **{
413 | target_fdr: average_precision_score(target, x <= target_fdr)
414 | for target_fdr in [0.05, 0.1, 0.2]
415 | },
416 | }
417 | ),
418 | axis=0,
419 | )
420 | .stack()
421 | .to_frame("AP")
422 | .reset_index()
423 | .rename(columns=dict(level_0="targetFDR", level_1="approach"))
424 | .assign(
425 | model=lambda x: x.approach.str.split("_").str[0],
426 | test=lambda x: x.approach.str.split("_").str[-2],
427 | model_test=lambda x: x.model + " " + x.test,
428 | metadata=lambda x: x.approach.str.split("_")
429 | .str[1:-2]
430 | .apply(lambda y: "_".join(y)),
431 | )
432 | )
433 |
434 | # %%
435 | for metadata in ap_plot_df.metadata.unique():
436 | fig, ax = plt.subplots()
437 | sns.barplot(
438 | data=ap_plot_df[
439 | (ap_plot_df["metadata"] == metadata)
440 | & (
441 |                 ap_plot_df["model"].isin(
442 | (
443 | "MrVILinear",
444 | "MrVILinear50",
445 | "PCAKNN",
446 | "SCVI",
447 | "MILOSCVI",
448 | )
449 | )
450 | )
451 | & (
452 |                 ap_plot_df["test"].isin(
453 | (
454 | # "manova",
455 | "ks",
456 | "LFC",
457 | )
458 | )
459 | )
460 | ],
461 | x="targetFDR",
462 | y="AP",
463 | hue="model_test",
464 | ax=ax,
465 | hue_order=[
466 | "MrVILinear ks",
467 | "MrVILinear50 ks",
468 | "SCVI ks",
469 | "MILOSCVI LFC",
470 | "PCAKNN ks",
471 | ],
472 | )
473 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
474 | ax.set_xticklabels(ax.get_xticklabels(), fontsize=7)
475 | ax.set_xlabel("Target FDR")
476 | ax.set_ylabel("AP")
477 | ax.set_ylim((0.5, 1.05))
478 | ax.set_title(f"Average Precision Rate Comparison ({metadata})")
479 | sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
480 |
481 | fig.tight_layout()
482 | fig.savefig(
483 | os.path.join(figure_dir, f"AP_comparison_{metadata}_ks_only.svg"),
484 | bbox_inches="tight",
485 | )
486 | plt.show()
487 |
488 | # %%
489 |
--------------------------------------------------------------------------------
/workflow/scripts/compute_local_scores.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from functools import partial
3 |
4 | import anndata as ad
5 | import numpy as np
6 | import pynndescent
7 | import scanpy as sc
8 | import torch
9 | from tqdm import tqdm
10 | from utils import compute_ks, compute_manova
11 |
12 |
13 | def compute_autocorrelation_metric(
14 | local_representation,
15 | donor_labels,
16 | metric_fn,
17 | has_significance,
18 | batch_size,
19 | minibatched=True,
20 | desc="desc",
21 | ):
22 | # Compute autocorrelation
23 | if minibatched:
24 | scores = []
25 | pvals = []
26 | for local_rep in tqdm(local_representation.split(batch_size), desc=desc):
27 | local_rep_ = local_rep.to("cuda")
28 | if has_significance:
29 | score, pval = metric_fn(local_rep_, donor_labels)
30 | scores.append(score)
31 | pvals.append(pval)
32 | else:
33 | scores_ = metric_fn(local_rep_, donor_labels=donor_labels)
34 | scores.append(scores_)
35 | scores = np.concatenate(scores, 0)
36 | pvals = np.concatenate(pvals, 0) if has_significance else None
37 | else:
38 | if has_significance:
39 | scores, pvals = metric_fn(
40 | local_representation.numpy(), donor_labels=donor_labels
41 | )
42 | else:
43 | scores = metric_fn(local_representation.numpy(), donor_labels=donor_labels)
44 | pvals = None
45 | return scores, pvals
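# Assumed shapes: local_representation is a torch tensor of shape
# (n_cells, n_donors, n_latent); metric_fn returns one score (and, when has_significance
# is True, one p-value) per cell, concatenated back to length n_cells when minibatched.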
46 |
47 |
48 | def compute_mrvi(adata, model_name, obs_key, batch_size=256, redo=True):
49 | uns_key = "{}_local_donor_rep".format(model_name)
50 | local_donor_rep = adata.uns[uns_key]
51 | if not isinstance(local_donor_rep, np.ndarray):
52 |         logging.warning(f"{model_name} local donor rep not compatible with metric.")
53 | return
54 | local_representation = torch.from_numpy(local_donor_rep)
55 | # Assumes local_representation is n_cells, n_donors, n_donors
56 | metadata_key = "{}_local_donor_rep_metadata".format(model_name)
57 | metadata = adata.uns[metadata_key]
58 | donor_labels = metadata[obs_key].values
59 |
60 | # Compute and save various metric scores
61 | _compute_ks = partial(compute_ks, do_smoothing=False)
62 | configs = [
63 | dict(
64 | metric_name="ks",
65 | metric_fn=_compute_ks,
66 | has_significance=True,
67 | minibatched=True,
68 | ),
69 | dict(
70 | metric_name="manova",
71 | metric_fn=compute_manova,
72 | has_significance=True,
73 | minibatched=False,
74 | ),
75 | ]
76 |
77 | for config in configs:
78 | metric_name = config.pop("metric_name")
79 | scores, pvals = compute_autocorrelation_metric(
80 | local_representation,
81 | donor_labels,
82 | batch_size=batch_size,
83 | **config,
84 | )
85 | output_key = f"{model_name}_{obs_key}_{metric_name}_score"
86 | sig_key = f"{model_name}_{obs_key}_{metric_name}_significance"
87 |
88 | is_obs = scores.shape[0] == adata.shape[0]
89 | adata.uns[output_key] = scores
90 | if pvals is not None:
91 | adata.uns[sig_key] = pvals
92 | if is_obs:
93 | adata.obs[output_key] = scores
94 | if pvals is not None:
95 | adata.obs[sig_key] = pvals
96 |
97 |
98 | def compute_milo(adata, model_name, obs_key):
99 | import milopy.core as milo
100 |
101 | sample_col = adata.uns["nhood_adata"].uns["sample_col"]
102 | if sample_col == obs_key:
103 | logging.warning("Milo cannot run a GLM against the sample col, skipping.")
104 | return
105 |
106 | try:
107 | adata.obs[obs_key] = adata.obs[obs_key].astype(str).astype("category")
108 | design = f"~ {obs_key}"
109 |
110 | # Issue with None valued uns values being dropped.
111 | if "nhood_neighbors_key" not in adata.uns:
112 | adata.uns["nhood_neighbors_key"] = None
113 | milo.DA_nhoods(adata, design=design)
114 | milo_results = adata.uns["nhood_adata"].obs
115 |
116 | except Exception as e:
117 | logging.warning(
118 | f"Skipping test since key {obs_key} may be invalid or model did not complete"
119 | f" with the error message: {e.__class__.__name__} - {str(e)}."
120 | )
121 | return
122 |
123 | is_index_cell_key = f"{model_name}_{obs_key}_is_index_cell"
124 | is_index_cell = adata.obs.index.isin(milo_results["index_cell"])
125 | adata.obs[is_index_cell_key] = is_index_cell
126 |
127 | cell_rep = adata.obsm["_cell_rep"]
128 | cell_anchors = cell_rep[is_index_cell]
129 | nn_donor = pynndescent.NNDescent(cell_anchors)
130 | nn_indices = nn_donor.query(cell_rep, k=1)[0].squeeze(-1)
131 |
132 | output_key = f"{model_name}_{obs_key}_LFC_score"
133 | lfc_score = milo_results["logFC"].values[nn_indices]
134 | adata.obs[output_key] = lfc_score
135 |
136 | sig_key = f"{model_name}_{obs_key}_LFC_significance"
137 | lfc_sig = milo_results["SpatialFDR"].values[nn_indices]
138 | adata.obs[sig_key] = lfc_sig
139 | return
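# Milo reports logFC / SpatialFDR per neighbourhood, each anchored at an index cell; the
# block above propagates those values to every cell by assigning it the result of its
# nearest index cell in the `_cell_rep` embedding.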
140 |
141 |
142 | def compute_metrics(adata, model_name, obs_key, batch_size=256, redo=True):
143 | if model_name.startswith("MILO"):
144 | compute_milo(adata, model_name, obs_key)
145 | else:
146 | compute_mrvi(adata, model_name, obs_key, batch_size, redo)
147 |
148 |
149 | def process_predictions(
150 | model_name,
151 | path_to_h5ad,
152 | path_to_output,
153 | donor_obs_keys,
154 | ):
155 | adata = sc.read_h5ad(path_to_h5ad)
156 | for donor_obs_key in donor_obs_keys:
157 | compute_metrics(adata, model_name, donor_obs_key)
158 | adata_ = ad.AnnData(
159 | obs=adata.obs,
160 | obsm=adata.obsm,
161 | var=adata.var,
162 | uns=adata.uns,
163 | )
164 | adata_.write(path_to_output)
165 |
166 |
167 | if __name__ == "__main__":
168 | process_predictions(
169 | model_name=snakemake.wildcards.model,
170 | path_to_h5ad=snakemake.input[0],
171 | path_to_output=snakemake.output[0],
172 | donor_obs_keys=snakemake.config[snakemake.wildcards.dataset]["keyMapping"][
173 | "relevantKeys"
174 | ],
175 | )
176 |
--------------------------------------------------------------------------------
/workflow/scripts/process_data.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import string
3 |
4 | import numpy as np
5 | import pandas as pd
6 | import scanpy as sc
7 | import scipy.sparse as sp
8 |
9 |
10 | def create_joint_obs_key(adata, keys):
11 | joint_key_name = "_".join(keys)
12 | adata.obs[joint_key_name] = (
13 | adata.obs[keys].astype(str).apply(lambda x: "-".join(x), axis=1)
14 | )
15 | return joint_key_name
16 |
17 |
18 | def make_categorical(adata, obs_key):
19 | adata.obs[obs_key] = adata.obs[obs_key].astype("category")
20 |
21 |
22 | def create_obs_mapper(adata, dataset_name):
23 | dataset_config = snakemake.config[snakemake.wildcards.dataset]["keyMapping"]
24 |
25 | donor_key = dataset_config["donorKey"]
26 | if isinstance(donor_key, list):
27 | donor_key = create_joint_obs_key(adata, donor_key)
28 | make_categorical(adata, donor_key)
29 |
30 | cell_type_key = dataset_config["cellTypeKey"]
31 | make_categorical(adata, cell_type_key)
32 |
33 | nuisance_keys = dataset_config["nuisanceKeys"]
34 | for nuisance_key in nuisance_keys:
35 | make_categorical(adata, nuisance_key)
36 |
37 | adata.uns["mapper"] = dict(
38 | donor_key=donor_key,
39 | categorical_nuisance_keys=nuisance_keys,
40 | cell_type_key=cell_type_key,
41 | )
42 |
43 |
44 | def assign_symsim_donors(adata):
45 | np.random.seed(1)
46 | dataset_config = snakemake.config["synthetic"]["keyMapping"]
47 | donor_key = dataset_config["donorKey"]
48 | batch_key = dataset_config["nuisanceKeys"][0]
49 |
50 | n_donors = 32
51 | n_meta = len([k for k in adata.obs.keys() if "meta_" in k])
52 | meta_keys = [f"meta_{i + 1}" for i in range(n_meta)]
53 | make_categorical(adata, batch_key)
54 | batches = adata.obs[batch_key].cat.categories.tolist()
55 | n_batch = len(batches)
56 |
57 | meta_combos = list(itertools.product([0, 1], repeat=n_meta))
58 | donors_per_meta_batch_combo = n_donors // len(meta_combos) // n_batch
59 |
60 | # Assign donors uniformly at random for cells with matching metadata.
61 | donor_assignment = np.empty(adata.n_obs, dtype=object)
62 | for batch in batches:
63 | batch_donors = []
64 | for meta_combo in meta_combos:
65 | match_cats = [f"CT{meta_combo[i]+1}:1" for i in range(n_meta)]
66 | eligible_cell_idxs = (
67 | (
68 | np.all(
69 | adata.obs[meta_keys].values == match_cats,
70 | axis=1,
71 | )
72 | & (adata.obs[batch_key] == batch)
73 | )
74 | .to_numpy()
75 | .nonzero()[0]
76 | )
77 | meta_donors = [
78 | f"donor_meta{meta_combo}_batch{batch}_{ch}"
79 | for ch in string.ascii_lowercase[:donors_per_meta_batch_combo]
80 | ]
81 | donor_assignment[eligible_cell_idxs] = np.random.choice(
82 | meta_donors, replace=True, size=len(eligible_cell_idxs)
83 | )
84 | batch_donors += meta_donors
85 |
86 | adata.obs[donor_key] = donor_assignment
87 |
88 | donor_meta = adata.obs[donor_key].str.extractall(
89 | r"donor_meta\(([0-1]), ([0-1]), ([0-1])\)_batch[0-9]_[a-z]"
90 | )
91 | for match_idx, meta_key in enumerate(meta_keys):
92 | adata.obs[f"donor_{meta_key}"] = donor_meta[match_idx].astype(int).tolist()
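# Example of the donor naming produced above: "donor_meta(0, 1, 1)_batch2_a" is parsed by
# the regex into donor_meta_1 = 0, donor_meta_2 = 1, donor_meta_3 = 1.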
93 |
94 |
95 | def construct_tree_semisynth(adata, depth_tree=3, dataset_name="semisynthetic"):
96 | """Modifies gene expression in two cell subpopulations according to a controlled
97 | donor tree structure
98 |
99 | """
100 | # construct donors
101 | n_donors = int(2**depth_tree)
102 | np.random.seed(0)
103 | random_donors = np.random.randint(0, n_donors, adata.n_obs)
104 | n_modules = sum([int(2**k) for k in range(1, depth_tree + 1)])
105 |
106 | # ct_key = snakemake.config[dataset_name]["keyMapping"]["cellTypeKey"]
107 | # cts = adata.obs.groupby(ct_key).size().sort_values(ascending=False)[:2]
108 | # ct1, ct2 = cts.index.values
109 | ct_key = "leiden"
110 | ct1, ct2 = "0", "1"
111 |
112 | # construct donor trees
113 | leaves_id = np.array(
114 | [format(i, "0{}b".format(depth_tree)) for i in range(n_donors)]
115 | ) # ids of leaves
116 |
117 | all_node_ids = []
118 | for dep in range(1, depth_tree + 1):
119 | node_ids = [format(i, "0{}b".format(dep)) for i in range(2**dep)]
120 | all_node_ids += node_ids # ids of all nodes in the tree
121 |
122 | def perturb_gene_exp(ct, X_perturbed, all_node_ids, leaves_id):
123 | leaves_id1 = leaves_id.copy()
124 | np.random.shuffle(leaves_id1)
125 | genes = np.arange(adata.n_vars)
126 | np.random.shuffle(genes)
127 | gene_modules = np.array_split(genes, n_modules)
128 | gene_modules = {
129 | node_id: gene_modules[i] for i, node_id in enumerate(all_node_ids)
130 | }
131 | gene_modules = {
132 | node_id: np.isin(np.arange(adata.n_vars), gene_modules[node_id])
133 | for node_id in all_node_ids
134 | }
135 | # convert to one hots to make life easier for saving
136 |
137 | # modifying gene expression
138 | # e.g., 001 has perturbed modules 0, 00, and 001
139 | subpop = adata.obs.loc[:, ct_key].values == ct
140 | print("perturbing {}".format(ct))
141 | for donor_id in range(n_donors):
142 | selected_pop = subpop & (random_donors == donor_id)
143 | leaf_id = leaves_id1[donor_id]
144 | perturbed_mod_ids = [leaf_id[:i] for i in range(1, depth_tree + 1)]
145 | perturbed_modules = np.zeros(adata.n_vars, dtype=bool)
146 | for id in perturbed_mod_ids:
147 | perturbed_modules = perturbed_modules | gene_modules[id]
148 |
149 | Xmat = X_perturbed[selected_pop].copy()
150 | print(
151 | "Perturbing {} genes in {} cells".format(
152 | perturbed_modules.sum(), selected_pop.sum()
153 | )
154 | )
155 | print(
156 | "Non-zero values in the relevant subpopulation and modules: ",
157 | (Xmat[:, perturbed_modules] != 0).sum(),
158 | )
159 | Xmat[:, perturbed_modules] = Xmat[:, perturbed_modules] * 2
160 |
161 | X_perturbed[selected_pop] = Xmat
162 | return X_perturbed, gene_modules, leaves_id1
163 |
164 | X_pert = adata.X.copy()
165 | X_pert, gene_mod1, leaves1 = perturb_gene_exp(ct1, X_pert, all_node_ids, leaves_id)
166 | X_pert, gene_mod2, leaves2 = perturb_gene_exp(ct2, X_pert, all_node_ids, leaves_id)
167 |
168 | gene_modules1 = pd.DataFrame(gene_mod1)
169 | gene_modules2 = pd.DataFrame(gene_mod2)
170 | donor_metadata = pd.DataFrame(
171 | dict(
172 | donor_id=np.arange(n_donors),
173 | tree_id1=leaves1,
174 | affected_ct1=ct1,
175 | tree_id2=leaves2,
176 | affected_ct2=ct2,
177 | )
178 | )
179 |
180 | meta_id1 = pd.DataFrame([list(x) for x in donor_metadata.tree_id1.values]).astype(
181 | int
182 | )
183 | n_features1 = meta_id1.shape[1]
184 | new_cols1 = ["tree_id1_{}".format(i) for i in range(n_features1)]
185 | donor_metadata.loc[:, new_cols1] = meta_id1.values
186 | meta_id2 = pd.DataFrame([list(x) for x in donor_metadata.tree_id2.values]).astype(
187 | int
188 | )
189 | n_features2 = meta_id2.shape[1]
190 | new_cols2 = ["tree_id2_{}".format(i) for i in range(n_features2)]
191 | donor_metadata.loc[:, new_cols2] = meta_id2.values
192 |
193 | donor_metadata.loc[:, new_cols1 + new_cols2] = donor_metadata.loc[
194 | :, new_cols1 + new_cols2
195 | ].astype("category")
196 |
197 | adata.obs.loc[:, "batch"] = random_donors
198 | adata.obs.loc[:, "Site"] = 1
199 | original_index = adata.obs.index.copy()
200 | adata.obs = adata.obs.merge(
201 | donor_metadata, left_on="batch", right_on="donor_id", how="left"
202 | )
203 | adata.obs.index = original_index
204 | adata.uns["gene_modules1"] = gene_modules1
205 | adata.uns["gene_modules2"] = gene_modules2
206 | adata.uns["donor_metadata"] = donor_metadata
207 | adata = sc.AnnData(X_pert, obs=adata.obs, var=adata.var, uns=adata.uns)
208 | return adata
209 |
210 |
211 | def process_snrna(adata):
212 | # We first remove samples processed with specific protocols that
213 | # are not used in other samples
214 | select_cell = ~(adata.obs["sample"].isin(["C41_CST", "C41_NST"]))
215 | adata = adata[select_cell].copy()
216 |
217 |     # We also artificially split
218 |     # one of the samples into two samples
219 |     # as a negative control
220 |     split_sample = "C72_RESEQ"
221 |     mask_selected_lib = (adata.obs["sample"] == split_sample).values
222 | np.random.seed(0)
223 | mask_split = (
224 | np.random.randint(0, 2, size=mask_selected_lib.shape[0]).astype(bool)
225 | * mask_selected_lib
226 | )
227 | libraries = adata.obs["library_uuid"].astype(str)
228 | libraries.loc[mask_split] = libraries.loc[mask_split] + "_split"
229 | assert libraries.unique().shape[0] == 9
230 | adata.obs.loc[:, "library_uuid"] = libraries.astype("category")
231 | return adata
232 |
233 |
234 | def process_dataset(dataset_name, input_h5ad, output_h5ad):
235 | adata = sc.read_h5ad(input_h5ad)
236 |
237 | if isinstance(adata.X, sp.csc_matrix):
238 | adata.X = adata.X.tocsr()
239 |
240 | dataset_info = snakemake.config[dataset_name]
241 | preprocessing_kwargs = dataset_info.get("preprocessing")
242 | if preprocessing_kwargs is not None:
243 | if "subsample" in preprocessing_kwargs:
244 | sc.pp.subsample(adata, **preprocessing_kwargs.get("subsample"))
245 | if "filter_genes" in preprocessing_kwargs:
246 | sc.pp.filter_genes(adata, **preprocessing_kwargs.get("filter_genes"))
247 | if "highly_variable_genes" in preprocessing_kwargs:
248 | sc.pp.highly_variable_genes(
249 | adata, **preprocessing_kwargs.get("highly_variable_genes")
250 | )
251 | adata = adata[:, adata.var.highly_variable].copy()
252 | if dataset_name == "synthetic":
253 | assign_symsim_donors(adata)
254 | elif dataset_name == "semisynthetic":
255 | adata = construct_tree_semisynth(
256 | adata,
257 | depth_tree=4,
258 | )
259 | elif dataset_name == "snrna":
260 | adata = process_snrna(adata)
261 |
262 | create_obs_mapper(adata, dataset_name)
263 |
264 | adata.write(output_h5ad)
265 |
266 |
267 | if __name__ == "__main__":
268 | process_dataset(
269 | dataset_name=snakemake.wildcards.dataset,
270 | input_h5ad=snakemake.input[0],
271 | output_h5ad=snakemake.output[0],
272 | )
273 |
--------------------------------------------------------------------------------
/workflow/scripts/run_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import anndata as ad
5 | import numpy as np
6 | import scanpy as sc
7 | from utils import MILO, PCAKNN, CompositionBaseline, MrVIWrapper, SCVIModel
8 |
9 | profile = snakemake.config[snakemake.wildcards.dataset]
10 | N_EPOCHS = profile["nEpochs"]
11 | if "batchSize" in profile:
12 | BATCH_SIZE = profile["batchSize"]
13 | else:
14 | BATCH_SIZE = 256
15 |
16 | MRVI_BASE_MODEL_KWARGS = dict(
17 | observe_library_sizes=True,
18 | # n_latent_donor=5,
19 | px_kwargs=dict(n_hidden=32),
20 | pz_kwargs=dict(
21 | n_layers=1,
22 | n_hidden=32,
23 | ),
24 | )
25 | MRVI_BASE_TRAIN_KWARGS = dict(
26 | max_epochs=N_EPOCHS,
27 | early_stopping=True,
28 | early_stopping_patience=15,
29 | check_val_every_n_epoch=1,
30 | batch_size=BATCH_SIZE,
31 | train_size=0.9,
32 | )
33 |
34 | MODELS = dict(
35 | PCAKNN=(PCAKNN, dict()),
36 | MILO=(MILO, dict(model_kwargs=dict(embedding="mnn"))),
37 | MILOSCVI=(
38 | MILO,
39 | dict(
40 | model_kwargs=dict(
41 | embedding="scvi",
42 | dropout_rate=0.0,
43 | dispersion="gene",
44 | gene_likelihood="nb",
45 | ),
46 | train_kwargs=dict(
47 | batch_size=BATCH_SIZE,
48 | plan_kwargs=dict(lr=1e-2, n_epochs_kl_warmup=20),
49 | max_epochs=N_EPOCHS,
50 | early_stopping=True,
51 | early_stopping_patience=15,
52 | ),
53 | ),
54 | ),
55 | CompositionPCA=(
56 | CompositionBaseline,
57 | dict(
58 | model_kwargs=dict(
59 | dim_reduction_approach="PCA",
60 | n_dim=50,
61 | clustering_on="celltype",
62 | ),
63 | train_kwargs=None,
64 | ),
65 | ),
66 | CompositionSCVI=(
67 | CompositionBaseline,
68 | dict(
69 | model_kwargs=dict(
70 | dim_reduction_approach="SCVI",
71 | n_dim=10,
72 | clustering_on="celltype",
73 | ),
74 | train_kwargs=dict(
75 | batch_size=BATCH_SIZE,
76 | plan_kwargs=dict(lr=1e-2),
77 | max_epochs=N_EPOCHS,
78 | early_stopping=True,
79 | early_stopping_patience=15,
80 | ),
81 | ),
82 | ),
83 | SCVI=(
84 | SCVIModel,
85 | dict(
86 | model_kwargs=dict(
87 | dropout_rate=0.0,
88 | feed_nuisance=False,
89 | dispersion="gene",
90 | gene_likelihood="nb",
91 | ),
92 | train_kwargs=dict(
93 | batch_size=BATCH_SIZE,
94 | plan_kwargs=dict(lr=1e-2, n_epochs_kl_warmup=20),
95 | max_epochs=N_EPOCHS,
96 | early_stopping=True,
97 | early_stopping_patience=15,
98 | ),
99 | ),
100 | ),
101 | MrVISmall=(
102 | MrVIWrapper,
103 | dict(
104 | model_kwargs=dict(
105 | linear_decoder_zx=False,
106 | **MRVI_BASE_MODEL_KWARGS,
107 | ),
108 | train_kwargs=dict(
109 | plan_kwargs=dict(
110 | lr=1e-2, n_epochs_kl_warmup=20, do_comp=True, lambd=1.0
111 | ),
112 | **MRVI_BASE_TRAIN_KWARGS,
113 | ),
114 | ),
115 | ),
116 | MrVILinear=(
117 | MrVIWrapper,
118 | dict(
119 | model_kwargs=dict(
120 | linear_decoder_zx=True,
121 | n_latent=10,
122 | **MRVI_BASE_MODEL_KWARGS,
123 | ),
124 | train_kwargs=dict(
125 | plan_kwargs=dict(
126 | lr=1e-2, n_epochs_kl_warmup=20, do_comp=False, lambd=1.0
127 | ),
128 | **MRVI_BASE_TRAIN_KWARGS,
129 | ),
130 | ),
131 | ),
132 | MrVILinear50=(
133 | MrVIWrapper,
134 | dict(
135 | model_kwargs=dict(
136 | linear_decoder_zx=True,
137 | n_latent=50,
138 | n_latent_donor=2,
139 | **MRVI_BASE_MODEL_KWARGS,
140 | ),
141 | train_kwargs=dict(
142 | plan_kwargs=dict(
143 | lr=1e-2, n_epochs_kl_warmup=20, do_comp=False, lambd=1.0
144 | ),
145 | **MRVI_BASE_TRAIN_KWARGS,
146 | ),
147 | ),
148 | ),
149 | MrVILinear50COMP=(
150 | MrVIWrapper,
151 | dict(
152 | model_kwargs=dict(
153 | linear_decoder_zx=True,
154 | n_latent=50,
155 | n_latent_donor=2,
156 | **MRVI_BASE_MODEL_KWARGS,
157 | ),
158 | train_kwargs=dict(
159 | plan_kwargs=dict(
160 | lr=1e-2,
161 | n_epochs_kl_warmup=20,
162 | do_comp=True,
163 | lambd=0.1,
164 | ),
165 | **MRVI_BASE_TRAIN_KWARGS,
166 | ),
167 | ),
168 | ),
169 | MrVILinear10COMP=(
170 | MrVIWrapper,
171 | dict(
172 | model_kwargs=dict(
173 | linear_decoder_zx=True,
174 | n_latent=10,
175 | n_latent_donor=2,
176 | **MRVI_BASE_MODEL_KWARGS,
177 | ),
178 | train_kwargs=dict(
179 | plan_kwargs=dict(
180 | lr=1e-2,
181 | n_epochs_kl_warmup=20,
182 | do_comp=True,
183 | lambd=0.1,
184 | ),
185 | **MRVI_BASE_TRAIN_KWARGS,
186 | ),
187 | ),
188 | ),
189 | MrVILinearLinear10COMP=(
190 | MrVIWrapper,
191 | dict(
192 | model_kwargs=dict(
193 | linear_decoder_zx=True,
194 | linear_decoder_uz=True,
195 | n_latent=10,
196 | n_latent_donor=2,
197 | **MRVI_BASE_MODEL_KWARGS,
198 | ),
199 | train_kwargs=dict(
200 | plan_kwargs=dict(
201 | lr=1e-2,
202 | n_epochs_kl_warmup=20,
203 | do_comp=True,
204 | lambd=0.1,
205 | ),
206 | **MRVI_BASE_TRAIN_KWARGS,
207 | ),
208 | ),
209 | ),
210 | MrVILinearLinear10=(
211 | MrVIWrapper,
212 | dict(
213 | model_kwargs=dict(
214 | linear_decoder_zx=True,
215 | linear_decoder_uz=True,
216 | n_latent=10,
217 | n_latent_donor=2,
218 | **MRVI_BASE_MODEL_KWARGS,
219 | ),
220 | train_kwargs=dict(
221 | plan_kwargs=dict(
222 | lr=1e-2,
223 | n_epochs_kl_warmup=20,
224 | do_comp=False,
225 | lambd=0.1,
226 | ),
227 | **MRVI_BASE_TRAIN_KWARGS,
228 | ),
229 | ),
230 | ),
231 | MrVILinearLinear10SCALER=(
232 | MrVIWrapper,
233 | dict(
234 | model_kwargs=dict(
235 | linear_decoder_zx=True,
236 | linear_decoder_uz=True,
237 | linear_decoder_uz_scaler=True,
238 | n_latent=10,
239 | n_latent_donor=2,
240 | **MRVI_BASE_MODEL_KWARGS,
241 | ),
242 | train_kwargs=dict(
243 | plan_kwargs=dict(
244 | lr=1e-2,
245 | n_epochs_kl_warmup=20,
246 | do_comp=False,
247 | lambd=0.1,
248 | ),
249 | **MRVI_BASE_TRAIN_KWARGS,
250 | ),
251 | ),
252 | ),
253 | )
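# Informal reading of the entries above (not an exhaustive spec): "Linear" sets
# linear_decoder_zx=True, a second "Linear" adds linear_decoder_uz=True, the trailing
# number is n_latent, "COMP" sets do_comp=True with lambd=0.1 in plan_kwargs, and
# "SCALER" adds linear_decoder_uz_scaler=True.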
254 |
255 |
256 | def compute_model_predictions(model_name, path_to_h5ad, path_to_output, random_seed):
257 | np.random.seed(random_seed)
258 | adata = sc.read_h5ad(path_to_h5ad)
259 | logging.info("adata shape: {}".format(adata.shape))
260 | mapper = adata.uns["mapper"]
261 | algo_cls, algo_kwargs = MODELS[model_name]
262 | model = algo_cls(adata=adata, **algo_kwargs, **mapper)
263 | model.fit()
264 | if model.has_donor_representation:
265 | rep = model.get_donor_representation().assign(model=model_name)
266 | rep.columns = rep.columns.astype(str)
267 | adata.uns["{}_donor".format(model_name)] = rep
268 | if model.has_cell_representation:
269 | repb = model.get_cell_representation()
270 | adata.obsm["{}_cell".format(model_name)] = repb
271 | if model.has_local_donor_representation:
272 | _adata = None
273 | scores = model.get_local_sample_representation(adata=_adata)
274 | adata.uns["{}_local_donor_rep".format(model_name)] = scores
275 | if hasattr(model, "get_donor_representation_metadata"):
276 | metadata = model.get_donor_representation_metadata()
277 | adata.uns["{}_local_donor_rep_metadata".format(model_name)] = metadata
278 | if model.has_custom_representation:
279 | adata = model.compute()
280 |
281 | adata.uns["model_name"] = model_name
282 | adata.X = None
283 | adata_ = ad.AnnData(
284 | obs=adata.obs,
285 | var=adata.var,
286 | obsm=adata.obsm,
287 | uns=adata.uns,
288 | )
289 | adata_.write_h5ad(path_to_output)
290 | if model.has_save:
291 | dir_path = os.path.dirname(path_to_output)
292 | model_dir_path = os.path.join(dir_path, model_name)
293 | model.save(model_dir_path)
294 |
295 |
296 | if __name__ == "__main__":
297 | compute_model_predictions(
298 | model_name=snakemake.wildcards.model,
299 | path_to_h5ad=snakemake.input[0],
300 | path_to_output=snakemake.output[0],
301 | random_seed=int(snakemake.wildcards.seed),
302 | )
303 |
--------------------------------------------------------------------------------
/workflow/scripts/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from ._baselines import (
2 | PCAKNN,
3 | CompositionBaseline,
4 | CTypeProportions,
5 | PseudoBulkPCA,
6 | SCVIModel,
7 | StatifiedPseudoBulkPCA,
8 | )
9 | from ._metrics import (
10 | compute_cramers,
11 | compute_geary,
12 | compute_hotspot_morans,
13 | compute_ks,
14 | compute_manova,
15 | )
16 | from ._milo import MILO
17 | from ._mrvi import MrVIWrapper
18 |
19 | __all__ = [
20 | "MrVIWrapper",
21 | "CTypeProportions",
22 | "CompositionBaseline",
23 | "PseudoBulkPCA",
24 | "MILO",
25 | "SCVIModel",
26 | "StatifiedPseudoBulkPCA",
27 | "PCAKNN",
28 | "compute_geary",
29 | "compute_hotspot_morans",
30 | "compute_cramers",
31 | "compute_ks",
32 | "compute_manova",
33 | ]
34 |
--------------------------------------------------------------------------------
/workflow/scripts/utils/_base_model.py:
--------------------------------------------------------------------------------
1 | import scanpy as sc
2 |
3 |
4 | class BaseModelClass:
5 | has_donor_representation = False
6 | has_cell_representation = False
7 | has_local_donor_representation = False
8 | has_custom_representation = False
9 | has_save = False
10 |
11 | def __init__(
12 | self,
13 | adata,
14 | cell_type_key,
15 | donor_key,
16 | categorical_nuisance_keys=None,
17 | n_hvg=None,
18 | ):
19 | self.adata = adata
20 | self.cell_type_key = cell_type_key
21 | self.donor_key = donor_key
22 |
23 | self.n_genes = self.adata.X.shape[1]
24 | self.n_donors = self.adata.obs[self.donor_key].unique().shape[0]
25 | self.categorical_nuisance_keys = categorical_nuisance_keys
26 | self.n_hvg = n_hvg
27 |
28 | def get_donor_representation(self, adata=None):
29 | return None
30 |
31 | def _filter_hvg(self):
32 | if (self.n_hvg is not None) and (self.n_hvg <= self.n_genes - 1):
33 | adata_ = self.adata.copy()
34 | sc.pp.highly_variable_genes(
35 | adata=adata_, n_top_genes=self.n_hvg, flavor="seurat_v3"
36 | )
37 |             self.adata = adata_[:, adata_.var["highly_variable"]].copy()
38 |
39 | def preprocess_data(self):
40 | self._filter_hvg()
41 |
42 | def fit(self):
43 | return None
44 |
45 | def get_cell_representation(self, adata=None):
46 | return None
47 |
48 | def save(self, save_path):
49 | return
50 |
--------------------------------------------------------------------------------
/workflow/scripts/utils/_baselines.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import pynndescent
4 | import scanpy as sc
5 | import torch
6 | from scvi import REGISTRY_KEYS
7 | from scvi.model import SCVI
8 | from sklearn.decomposition import PCA
9 | from tqdm import tqdm
10 |
11 | from ._base_model import BaseModelClass
12 |
13 |
14 | class CTypeProportions(BaseModelClass):
15 | has_cell_representation = False
16 | has_donor_representation = True
17 |
18 | def __init__(self, **kwargs):
19 | super().__init__(**kwargs)
20 |
21 | def get_donor_representation(self, adata=None):
22 | adata = self.adata if adata is None else adata
23 | ct_props = (
24 | self.adata.obs.groupby(self.donor_key)[self.cell_type_key]
25 | .value_counts()
26 | .unstack()
27 | .fillna(0.0)
28 | .apply(lambda x: x / x.sum(), axis=1)
29 | )
30 | return ct_props
31 |
32 |
33 | class PseudoBulkPCA(BaseModelClass):
34 | has_cell_representation = False
35 | has_donor_representation = True
36 |
37 | def __init__(self, n_components=50, **kwargs):
38 | super().__init__(**kwargs)
39 | self.n_components = np.minimum(n_components, self.n_donors)
40 |
41 | def get_donor_representation(self, adata=None):
42 | adata = self.adata if adata is None else adata
43 | idx_donor = adata.obs[self.donor_key]
44 |
45 | X = np.zeros((self.n_donors, self.n_genes))
46 | unique_donors = idx_donor.unique()
47 | for idx, unique_donor in enumerate(unique_donors):
48 | cell_is_selected = idx_donor == unique_donor
49 | X[idx, :] = adata.X[cell_is_selected].sum(0)
50 | X_cpm = 1e6 * X / X.sum(-1, keepdims=True)
51 | X_logcpm = np.log1p(X_cpm)
52 | z = PCA(n_components=self.n_components).fit_transform(X_logcpm)
53 | return pd.DataFrame(z, index=unique_donors)
54 |
55 |
56 | class StatifiedPseudoBulkPCA(BaseModelClass):
57 | has_cell_representation = False
58 | has_donor_representation = False
59 | has_local_donor_representation = True
60 |
61 | def __init__(self, n_components=50, **kwargs):
62 | super().__init__(**kwargs)
63 | self.n_components = np.minimum(n_components, self.n_donors)
64 | self.cell_type_key = None
65 |
66 | def fit(self):
67 | pass
68 |
69 | def get_local_sample_representation(self, adata=None):
70 | self.cell_type_key = self.adata.uns["mapper"]["cell_type_key"]
71 | adata = self.adata
72 | idx_donor = adata.obs[self.donor_key]
73 | cell_types = adata.obs[self.cell_type_key]
74 |
75 | X = []
76 | unique_donors = idx_donor.unique()
77 | unique_types = cell_types.unique()
78 | reps_all = dict()
79 | for cell_type in unique_types:
80 | X = []
81 | donors = []
82 | # Computing the pseudo-bulk for each cell type
83 | for unique_donor in unique_donors:
84 | cell_is_selected = (idx_donor == unique_donor) & (
85 | cell_types == cell_type
86 | )
87 | new_counts = adata.X[cell_is_selected].sum(0)
88 | if new_counts.sum() > 0:
89 | X.append(new_counts)
90 | donors.append(unique_donor)
91 | X = np.array(X).squeeze(1)
92 | donors = np.array(donors).astype(str)
93 | X_cpm = 1e6 * X / X.sum(-1, keepdims=True)
94 | X_logcpm = np.log1p(X_cpm)
95 | n_comps = np.minimum(self.n_components, X.shape[0])
96 | z = PCA(n_components=n_comps).fit_transform(X_logcpm)
97 | reps = pd.DataFrame(z, index=donors)
98 | reps.columns = ["PC_{}".format(i) for i in range(n_comps)]
99 | # reps.index = pd.CategoricalIndex(donors, categories=donors)
100 | reps_all[cell_type] = reps
101 | return reps_all
102 |
103 |
104 | class CompositionBaseline(BaseModelClass):
105 | has_cell_representation = True
106 | has_donor_representation = False
107 | has_local_donor_representation = True
108 |
109 | default_model_kwargs = dict(
110 | dim_reduction_approach="PCA",
111 | n_dim=50,
112 | clustering_on="leiden",
113 | )
114 | default_train_kwargs = {}
115 |
116 | def __init__(self, model_kwargs=None, train_kwargs=None, **kwargs):
117 | super().__init__(**kwargs)
118 |         self.model_kwargs = dict(
119 |             model_kwargs if model_kwargs is not None else self.default_model_kwargs
120 |         )
121 |         self.train_kwargs = (
122 |             train_kwargs if train_kwargs is not None else self.default_train_kwargs
123 |         )
124 |
125 |         self.dim_reduction_approach = self.model_kwargs.pop("dim_reduction_approach", "PCA")
126 |         self.n_dim = self.model_kwargs.pop("n_dim", 50)
127 |         self.clustering_on = self.model_kwargs.pop(
128 |             "clustering_on", "leiden"
129 |         )  # one of leiden, celltype
130 |
131 | def preprocess_data(self):
132 | super().preprocess_data()
133 | self.adata_ = self.adata.copy()
134 | if self.dim_reduction_approach == "PCA":
135 | self.adata_ = self.adata.copy()
136 | sc.pp.normalize_total(self.adata_, target_sum=1e4)
137 | sc.pp.log1p(self.adata_)
138 |
139 | def fit(self):
140 | self.preprocess_data()
141 | if self.dim_reduction_approach == "PCA":
142 | sc.pp.pca(self.adata_, n_comps=self.n_dim) # saves "X_pca" in obsm
143 | self.adata_.obsm["X_red"] = self.adata_.obsm["X_pca"]
144 | elif self.dim_reduction_approach == "SCVI":
145 | SCVI.setup_anndata(
146 | self.adata_, categorical_covariate_keys=self.categorical_nuisance_keys
147 | )
148 | scvi_model = SCVI(self.adata_, **self.model_kwargs)
149 | scvi_model.train(**self.train_kwargs)
150 | self.adata_.obsm["X_red"] = scvi_model.get_latent_representation()
151 |
152 | def get_cell_representation(self, adata=None):
153 | assert adata is None
154 | return self.adata_.obsm["X_red"]
155 |
156 | def get_local_sample_representation(self, adata=None):
157 | if self.clustering_on == "leiden":
158 | sc.pp.neighbors(self.adata_, n_neighbors=30, use_rep="X_red")
159 | sc.tl.leiden(self.adata_, resolution=1.0, key_added="leiden_1.0")
160 | clustering_key = "leiden_1.0"
161 | elif self.clustering_on == "celltype":
162 | clustering_key = self.cell_type_key
163 |
164 | freqs_all = dict()
165 | for unique_cluster in self.adata_.obs[clustering_key].unique():
166 | cell_is_selected = self.adata_.obs[clustering_key] == unique_cluster
167 | subann = self.adata_[cell_is_selected].copy()
168 |
169 | # Step 1: subcluster
170 | sc.pp.neighbors(subann, n_neighbors=30, use_rep="X_red")
171 | sc.tl.leiden(subann, resolution=1.0, key_added=clustering_key)
172 |
173 | szs = (
174 | subann.obs.groupby([clustering_key, self.donor_key])
175 | .size()
176 | .to_frame("n_cells")
177 | .reset_index()
178 | )
179 | szs_total = (
180 | szs.groupby(self.donor_key)
181 | .sum()
182 | .rename(columns={"n_cells": "n_cells_total"})
183 | )
184 | comps = szs.merge(szs_total, on=self.donor_key).assign(
185 | freqs=lambda x: x.n_cells / x.n_cells_total
186 | )
187 | freqs = (
188 | comps.loc[:, [self.donor_key, clustering_key, "freqs"]]
189 | .set_index([self.donor_key, clustering_key])
190 | .squeeze()
191 | .unstack()
192 | )
193 | freqs_ = freqs
194 | freqs_all[unique_cluster] = freqs_
195 | # n_donors, n_clusters
196 | return freqs_all
197 |
198 | def get_donor_representation_metadata(self, adata=None):
199 | pass
200 |
201 |
202 | class PCAKNN(BaseModelClass):
203 | has_cell_representation = True
204 | has_donor_representation = False
205 | has_local_donor_representation = True
206 |
207 | def __init__(self, n_components=25, **kwargs):
208 | super().__init__(**kwargs)
209 | self.n_components = n_components
210 | self.pca = None
211 | self.adata_ = None
212 | self.donor_order = None
213 |
214 | def preprocess_data(self):
215 | super().preprocess_data()
216 | self.adata_ = self.adata.copy()
217 | sc.pp.normalize_total(self.adata_, target_sum=1e4)
218 | sc.pp.log1p(self.adata_)
219 |
220 | def fit(self):
221 | self.preprocess_data()
222 | sc.pp.pca(self.adata_, n_comps=self.n_components)
223 |
224 | def get_cell_representation(self, adata=None):
225 | assert adata is None
226 | return self.adata_.obsm["X_pca"]
227 |
228 | def get_local_sample_representation(self, adata=None):
229 | # for each cell, compute nearest neighbor in given donor
230 | pca_rep = self.adata_.obsm["X_pca"]
231 |
232 | local_reps = []
233 | self.donor_order = self.adata_.obs[self.donor_key].unique()
234 | for donor in self.donor_order:
235 | donor_is_selected = self.adata_.obs[self.donor_key] == donor
236 | pca_donor = pca_rep[donor_is_selected]
237 | nn_donor = pynndescent.NNDescent(pca_donor)
238 | nn_indices = nn_donor.query(pca_rep, k=1)[0].squeeze(-1)
239 | nn_rep = pca_donor[nn_indices][:, None, :]
240 | local_reps.append(nn_rep)
241 | local_reps = np.concatenate(local_reps, axis=1)
242 | return local_reps
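# Assumed output convention: local_reps has shape (n_cells, n_donors, n_components),
# where entry [i, d] holds the PCA coordinates of cell i's nearest neighbour among the
# cells of donor d (donors ordered as in self.donor_order).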
243 |
244 | def get_donor_representation_metadata(self):
245 | donor_to_id_map = pd.DataFrame(
246 | self.donor_order, columns=[self.donor_key]
247 | ).assign(donor_order=lambda x: np.arange(len(x)))
248 | res = self.adata_.obs.drop_duplicates(self.donor_key)
249 | res = res.merge(donor_to_id_map, on=self.donor_key, how="left").sort_values(
250 | "donor_order"
251 | )
252 | return res
253 |
254 |
255 | class SCVIModel(BaseModelClass):
256 | has_cell_representation = True
257 | has_local_donor_representation = True
258 | has_save = True
259 |
260 | default_model_kwargs = dict(
261 | dropout_rate=0.0,
262 | dispersion="gene",
263 | gene_likelihood="nb",
264 | )
265 | default_train_kwargs = dict(
266 | max_epochs=100,
267 | check_val_every_n_epoch=1,
268 | batch_size=256,
269 | plan_kwargs=dict(lr=1e-2, n_epochs_kl_warmup=20),
270 | )
271 |
272 | def __init__(self, model_kwargs=None, train_kwargs=None, **kwargs):
273 | super().__init__(**kwargs)
274 | self.model_kwargs = (
275 | self.default_model_kwargs if model_kwargs is None else model_kwargs
276 | )
277 | self.train_kwargs = (
278 | self.default_train_kwargs if train_kwargs is None else train_kwargs
279 | )
280 | self.adata_ = None
281 |
282 | @property
283 | def has_donor_representation(self):
284 | has_donor_embedding = self.model_kwargs.get("do_batch_embedding", False)
285 | return has_donor_embedding
286 |
287 | def fit(self):
288 | self.preprocess_data()
289 | adata_ = self.adata.copy()
290 |
291 | feed_nuisance = self.model_kwargs.pop("feed_nuisance", True)
292 | categorical_nuisance_keys = (
293 | self.categorical_nuisance_keys if feed_nuisance else None
294 | )
295 | SCVI.setup_anndata(
296 | adata_,
297 | batch_key=self.donor_key,
298 | categorical_covariate_keys=categorical_nuisance_keys,
299 | )
300 | self.adata_ = adata_
301 | self.model = SCVI(adata=adata_, **self.model_kwargs)
302 | self.model.train(**self.train_kwargs)
303 | return True
304 |
305 | def save(self, save_path, overwrite=True):
306 | self.model.save(save_path, overwrite=overwrite)
307 | return True
308 |
309 | # def get_donor_representation(self, adata=None):
310 | # assert self.has_donor_representation
311 |
312 | def get_cell_representation(self, adata=None, batch_size=256):
313 | return self.model.get_latent_representation(adata, batch_size=batch_size)
314 |
315 | @torch.no_grad()
316 | def get_local_sample_representation(
317 | self, adata=None, batch_size=256, x_dim=50, eps=1e-8, mc_samples=10
318 | ):
319 | # z = self.model.get_latent_representation(adata, batch_size=batch_size)
320 | # index = pynndescent.NNDescent(z, n_neighbors=n_neighbors)
321 | # index.prepare()
322 |
323 | # neighbors, _ = index.query(z)
324 | # return neighbors[:, 1:]
325 |
326 | adata = self.adata_ if adata is None else adata
327 | self.model._check_if_trained(warn=False)
328 | adata = self.model._validate_anndata(adata)
329 | scdl = self.model._make_data_loader(
330 | adata=adata, indices=None, batch_size=batch_size
331 | )
332 |
333 | # hs = self.get_normalized_expression(adata, batch_size=batch_size, eps=eps)
334 | hs = []
335 | for tensors in tqdm(scdl):
336 | inference_inputs = self.model.module._get_inference_input(tensors)
337 | inference_outputs = self.model.module.inference(**inference_inputs)
338 |
339 | generative_inputs = self.model.module._get_generative_input(
340 | tensors=tensors, inference_outputs=inference_outputs
341 | )
342 | generative_outputs = self.model.module.generative(**generative_inputs)
343 | new = (eps + generative_outputs["px"].scale).log()
344 | hs.append(new.cpu())
345 | hs = torch.cat(hs, dim=0).numpy()
346 | means = np.mean(hs, axis=0)
347 | stds = np.std(hs, axis=0)
348 | hs = (hs - means) / stds
349 | pca = PCA(n_components=x_dim).fit(hs)
350 | w = torch.tensor(
351 | pca.components_, dtype=torch.float32, device=self.model.device
352 | ).T
353 | means = torch.tensor(means, dtype=torch.float32, device=self.model.device)
354 | stds = torch.tensor(stds, dtype=torch.float32, device=self.model.device)
355 |
356 | reps = []
357 | for tensors in tqdm(scdl):
358 | xs = []
359 | for batch in range(self.model.summary_stats.n_batch):
360 | cf_batch = batch * torch.ones_like(tensors["batch"])
361 | tensors[REGISTRY_KEYS.BATCH_KEY] = cf_batch
362 | inference_inputs = self.model.module._get_inference_input(tensors)
363 | inference_outputs = self.model.module.inference(
364 | n_samples=mc_samples, **inference_inputs
365 | )
366 |
367 | generative_inputs = self.model.module._get_generative_input(
368 | tensors=tensors, inference_outputs=inference_outputs
369 | )
370 | generative_outputs = self.model.module.generative(**generative_inputs)
371 | new = (eps + generative_outputs["px"].scale).log()
372 | if x_dim is not None:
373 | new = (new - means) / stds
374 | new = new @ w
375 | xs.append(new[:, :, None])
376 |
377 | xs = torch.cat(xs, 2).mean(0)
378 | reps.append(xs.cpu().numpy())
379 | # n_cells, n_donors, n_donors
380 | reps = np.concatenate(reps, 0)
381 | return reps
382 |
383 | def get_donor_representation_metadata(self):
384 | return (
385 | self.model.adata.obs.drop_duplicates("_scvi_batch")
386 | .set_index("_scvi_batch")
387 | .sort_index()
388 | )
389 |
--------------------------------------------------------------------------------
/workflow/scripts/utils/_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import scipy.stats as stats
4 | import statsmodels.api as sm
5 | import torch
6 | from joblib import Parallel, delayed
7 | from sklearn.preprocessing import OneHotEncoder
8 | from statsmodels.multivariate.manova import MANOVA
9 |
10 |
11 | def smooth_distance(sq_dists, mask_farther_than_k=False):
12 | n_donors = sq_dists.shape[-1]
13 | k = int(np.sqrt(n_donors))
14 | topk_vals, topk_idx = per_cell_knn(sq_dists, k=k)
15 | bandwidth_idx = k // 3
16 | bandwidth_vals = topk_vals[:, :, bandwidth_idx].unsqueeze(
17 | -1
18 | ) # n_cells x n_donors x 1
19 | w_mtx = torch.exp(-sq_dists / bandwidth_vals) # n_cells x n_donors x n_donors
20 |
21 | if mask_farther_than_k:
22 | masked_w_mtx = torch.zeros_like(w_mtx)
23 | masked_w_mtx = masked_w_mtx.scatter(
24 | -1, topk_idx, w_mtx
25 | ) # n_cells x n_donors x n_donors
26 | w_mtx = masked_w_mtx
27 |
28 | return w_mtx
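# In words: an adaptive Gaussian kernel over donors with k = int(sqrt(n_donors))
# neighbours, using a per-donor bandwidth taken from its (k // 3)-th nearest-neighbour
# distance; with mask_farther_than_k, weights beyond the k nearest neighbours are zeroed.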
29 |
30 |
31 | @torch.no_grad()
32 | def per_cell_knn(dists, k):
33 | # Given n_cells x n_donors x n_donors returns n_cells x n_donors x k
34 | # tensor of dists and indices of the k nearest neighbors for each donor in each cell
35 | topkp1 = torch.topk(dists, k + 1, dim=-1, largest=False, sorted=True)
36 | topk_values, topk_indices = topkp1.values[:, :, 1:], topkp1.indices[:, :, 1:]
37 | return topk_values, topk_indices
38 |
39 |
40 | @torch.no_grad()
41 | def compute_geary(xs, donor_labels):
42 | oh_feats = OneHotEncoder(sparse=False).fit_transform(
43 | donor_labels[:, None]
44 | ) # n_donors x n_labels
45 | w_mat = torch.tensor(oh_feats @ oh_feats.T, device="cuda") # n_donors x n_donors
46 | xs = xs.unsqueeze(-2) # n_cells x n_donors x 1 x n_donor_latent
47 | sq_dists = ((xs - xs.transpose(-2, -3)) ** 2).sum(
48 | -1
49 | ) # n_cells x n_donors x n_donors
50 | scores_ = (sq_dists * w_mat).sum([-1, -2]) / w_mat.sum([-1, -2])
51 | var_estim = xs.var(1).sum(-1).squeeze()
52 | scores_ = scores_ / (2.0 * var_estim)
53 | return scores_.cpu().numpy()
54 |
55 |
56 | @torch.no_grad()
57 | def compute_hotspot_morans(xs, donor_labels):
58 | oh_feats = OneHotEncoder(sparse=False).fit_transform(donor_labels[:, None])
59 | like_label_mtx = torch.tensor(oh_feats @ oh_feats.T, device="cuda")
60 | xx_mtx = (like_label_mtx * 2) - 1 # n_donors x n_donors
61 |
62 | xs = xs.unsqueeze(-2) # n_cells x n_donors x 1 x n_donor_latent
63 | sq_dists = ((xs - xs.transpose(-2, -3)) ** 2).sum(
64 | -1
65 | ) # n_cells x n_donors x n_donors
66 | w_mtx = smooth_distance(sq_dists)
67 | w_norm_mtx = w_mtx / w_mtx.sum(-1, keepdim=True) # n_cells x n_donors x n_donors
68 |
69 | scores_ = (w_norm_mtx * xx_mtx).sum([-1, -2]) # n_cells
70 | return scores_.cpu().numpy()
71 |
72 |
73 | @torch.no_grad()
74 | def compute_cramers(xs, donor_labels):
75 | oh_feats = OneHotEncoder(sparse=False).fit_transform(
76 | donor_labels[:, None]
77 | ) # n_donors x n_labels
78 | oh_feats = torch.tensor(oh_feats, device="cuda").float() # n_donors x n_labels
79 |
80 | xs = xs.unsqueeze(-2) # n_cells x n_donors x 1 x n_donor_latent
81 | sq_dists = ((xs - xs.transpose(-2, -3)) ** 2).sum(
82 | -1
83 | ) # n_cells x n_donors x n_donors
84 | w_mtx = smooth_distance(sq_dists, mask_farther_than_k=True)
85 |
86 | c_ij = w_mtx @ oh_feats # n_cells x n_donors x n_labels
87 | contingency_X = c_ij.transpose(-1, -2) @ oh_feats # n_cells x n_labels x n_labels
88 |
89 | scores = []
90 | sig_scores = []
91 | for i in range(contingency_X.shape[0]):
92 | contingency_Xi = contingency_X[i].cpu().numpy()
93 | chi_sq, sig_score, _, _ = stats.chi2_contingency(contingency_Xi)
94 | n = np.sum(contingency_Xi)
95 | min_dim = contingency_Xi.shape[0] - 1
96 | scores.append(np.sqrt(chi_sq / (n * min_dim)))
97 | sig_scores.append(sig_score)
98 | return np.array(scores), np.array(sig_scores)
99 |
100 |
101 | def _random_subsample(d1, n_mc_samples):
102 |     # Randomly keep at most n_mc_samples values per row of d1.
103 |     if d1.shape[1] <= n_mc_samples:
104 |         return d1
105 |     d1_ = torch.zeros(d1.shape[0], n_mc_samples, device=d1.device)
106 |     for i in range(d1.shape[0]):
107 |         d1_[i] = d1[i][torch.randperm(d1.shape[1], device=d1.device)[:n_mc_samples]]
108 |     return d1_
109 |
110 |
111 | @torch.no_grad()
112 | def compute_manova(
113 | xs,
114 | donor_labels,
115 | ):
116 | target = pd.Series(donor_labels, dtype="category").cat.codes
117 | target = sm.add_constant(target)
118 |
119 | def _compute(xi):
120 | try:
121 | res = (
122 | MANOVA(endog=xi, exog=target)
123 | .mv_test()
124 | .results["x1"]["stat"]
125 | .loc["Wilks' lambda", ["F Value", "Pr > F"]]
126 | .values
127 | )
128 | except Exception:
129 | return np.array([1000.0, 1.0])
130 | return res
131 |
132 |     # Convert to NumPy once so each per-cell representation can be passed to
133 |     # statsmodels directly, then run one MANOVA per cell in parallel.
134 |     if isinstance(xs, torch.Tensor):
135 |         xs = xs.cpu().numpy()
136 |     all_res = np.array(
137 |         Parallel(n_jobs=10)(delayed(_compute)(xi) for xi in xs),
138 |         dtype=np.float32,
139 |     )
140 |     # all_res has shape (n_cells, 2): Wilks' lambda F statistic and its p-value.
141 | stats, pvals = all_res[:, 0], all_res[:, 1]
142 | return stats, pvals
143 |
144 |
145 | def compute_ks(
146 | xs, donor_labels, n_mc_samples=5000, alternative="two-sided", do_smoothing=False
147 | ):
148 | n_minibatch = xs.shape[0]
149 | oh_feats = OneHotEncoder(sparse=False).fit_transform(
150 | donor_labels[:, None]
151 | ) # n_donors x n_labels
152 | oh_feats = torch.tensor(oh_feats, device="cuda").float() # n_donors x n_labels
153 | wmat = oh_feats @ oh_feats.T
154 | wmat = wmat - torch.eye(wmat.shape[0], device="cuda").float()
155 | is_off_diag = 1.0 - torch.eye(wmat.shape[0], device="cuda").float()
156 |
157 | xs = xs.unsqueeze(-2)
158 | sq_dists = ((xs - xs.transpose(-2, -3)) ** 2).sum(
159 | -1
160 | ) # n_cells x n_donors x n_donors
161 | if do_smoothing:
162 | sq_dists = smooth_distance(sq_dists)
163 | sq_dists = sq_dists.reshape(n_minibatch, -1)
164 |
165 | d1 = sq_dists[:, wmat.reshape(-1).bool()]
166 | d1 = _random_subsample(d1, n_mc_samples).cpu().numpy()
167 | d2 = sq_dists[:, is_off_diag.reshape(-1).bool()]
168 | d2 = _random_subsample(d2, n_mc_samples).cpu().numpy()
169 |     # Note: despite the "ks" name, this runs Welch's t-test per cell (axis=1).
170 |     ks_results = stats.ttest_ind(d1, d2, axis=1, equal_var=False)
171 | effect_sizes = ks_results.statistic
172 | p_values = ks_results.pvalue
173 | return effect_sizes, p_values
174 |
--------------------------------------------------------------------------------
/workflow/scripts/utils/_milo.py:
--------------------------------------------------------------------------------
1 | import scanpy as sc
2 | from scvi.model import SCVI
3 |
4 | from ._base_model import BaseModelClass
5 |
6 |
7 | class MILO(BaseModelClass):
8 | has_donor_representation = False
9 | has_cell_representation = False
10 | has_local_donor_representation = False
11 | has_custom_representation = True
12 | default_model_kwargs = dict(
13 | dropout_rate=0.0,
14 | dispersion="gene",
15 | gene_likelihood="nb",
16 | )
17 | default_train_kwargs = dict(
18 | max_epochs=100,
19 | check_val_every_n_epoch=1,
20 | batch_size=256,
21 | plan_kwargs=dict(lr=1e-2, n_epochs_kl_warmup=20),
22 | )
23 |
24 | def __init__(self, model_kwargs=None, train_kwargs=None, **kwargs):
25 | super().__init__(**kwargs)
26 |         self.model_kwargs = (
27 |             self.default_model_kwargs.copy() if model_kwargs is None else model_kwargs
28 |         )
29 |         self.train_kwargs = (
30 |             self.default_train_kwargs.copy() if train_kwargs is None else train_kwargs
31 |         )
32 |
33 | def fit(self, **kwargs):
34 | import milopy.core as milo
35 |
36 | adata_ = self.adata.copy()
37 | embedding = self.model_kwargs.pop("embedding", "mnn")
38 |
39 | if embedding == "mnn":
40 | # Run MNN
41 | alldata = []
42 | batch_key = self.categorical_nuisance_keys[0]
43 | adata_.obs[batch_key] = adata_.obs[batch_key].astype("category")
44 | for batch_cat in adata_.obs[batch_key].cat.categories.tolist():
45 |                 alldata.append(adata_[adata_.obs[batch_key] == batch_cat])
46 |
47 | cdata = sc.external.pp.mnn_correct(
48 | *alldata, svd_dim=50, batch_key=batch_key, n_jobs=8
49 | )[0]
50 | if isinstance(cdata, tuple):
51 | cdata = cdata[0]
52 |
53 | # Run PCA
54 | cell_rep = sc.tl.pca(cdata.X, svd_solver="arpack", return_info=False)
55 | elif embedding == "scvi":
56 | # Run scVI
57 | self.preprocess_data()
58 | adata_ = self.adata.copy()
59 |
60 | batch_key = self.categorical_nuisance_keys[0]
61 | SCVI.setup_anndata(
62 | adata_,
63 | batch_key=batch_key,
64 | categorical_covariate_keys=(
65 | self.categorical_nuisance_keys[1:]
66 | if len(self.categorical_nuisance_keys) > 1
67 | else None
68 | ),
69 | )
70 | self.adata_ = adata_
71 | scvi_model = SCVI(adata_, **self.model_kwargs)
72 | scvi_model.train(**self.train_kwargs)
73 | cell_rep = scvi_model.get_latent_representation()
74 | else:
75 |             raise ValueError(f"Unknown embedding: {embedding}")
76 |
77 | adata_.obsm["_cell_rep"] = cell_rep
78 | sc.pp.neighbors(adata_, n_neighbors=10, use_rep="_cell_rep")
79 |
80 | ## Assign cells to neighbourhoods
81 | milo.make_nhoods(adata_, prop=0.05)
82 |
83 | ## Count cells from each sample in each nhood
84 | milo.count_nhoods(adata_, sample_col=self.donor_key)
85 | self.adata_ = adata_
86 |
87 | def compute(self):
88 | return self.adata_
89 |
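90 | # Illustrative usage sketch (not part of the original module). It assumes the
91 | # BaseModelClass constructor accepts `adata`, `donor_key`, and
92 | # `categorical_nuisance_keys`, matching the attributes used above; the key names
93 | # ("donor", "site") are placeholders.
94 | #
95 | #   milo_model = MILO(
96 | #       adata=adata,
97 | #       donor_key="donor",
98 | #       categorical_nuisance_keys=["site"],
99 | #       model_kwargs={"embedding": "scvi", **MILO.default_model_kwargs},
100 | #   )
101 | #   milo_model.fit()
102 | #   nhood_adata = milo_model.compute()  # AnnData carrying the milopy neighbourhood assignments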
--------------------------------------------------------------------------------
/workflow/scripts/utils/_mrvi.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import torch
4 | from tqdm import tqdm
5 |
6 | from mrvi import MrVI
7 |
8 | from ._base_model import BaseModelClass
9 |
10 |
11 | class MrVIWrapper(BaseModelClass):
12 | has_cell_representation = True
13 | has_donor_representation = True
14 | has_local_donor_representation = True
15 | has_save = True
16 |
17 | default_model_kwargs = dict(
18 | observe_library_sizes=True,
19 | n_latent_donor=5,
20 | )
21 | default_train_kwargs = dict(
22 | max_epochs=100,
23 | check_val_every_n_epoch=1,
24 | batch_size=256,
25 | plan_kwargs=dict(lr=1e-2, n_epochs_kl_warmup=20),
26 | )
27 | model_name = "mrvi"
28 |
29 | def __init__(self, model_kwargs=None, train_kwargs=None, **kwargs):
30 | super().__init__(**kwargs)
31 | self.model_kwargs = self.default_model_kwargs.copy()
32 | if model_kwargs is not None:
33 | self.model_kwargs.update(model_kwargs)
34 | self.train_kwargs = self.default_train_kwargs.copy()
35 | if train_kwargs is not None:
36 | self.train_kwargs.update(train_kwargs)
37 |
38 | def fit(self):
39 | self.preprocess_data()
40 | adata_ = self.adata.copy()
41 | self.adata_ = adata_
42 |
43 | MrVI.setup_anndata(
44 | adata_,
45 | batch_key=self.donor_key,
46 | categorical_nuisance_keys=self.categorical_nuisance_keys,
47 | categorical_biological_keys=None,
48 | )
49 | self.model = MrVI(adata=adata_, **self.model_kwargs)
50 | self.model.train(**self.train_kwargs)
51 | return True
52 |
53 | def save(self, save_path, overwrite=True):
54 | self.model.save(save_path, overwrite=overwrite)
55 | return True
56 |
57 |     @classmethod
58 |     def load(cls, adata, save_path):
59 |         mapper = adata.uns["mapper"]
60 |         instance = cls(adata=adata, **mapper)
61 |         instance.model = MrVI.load(save_path, adata)
62 |         adata_ = instance.adata.copy()
63 |         instance.adata_ = adata_
64 |         return instance
65 |
66 | def get_donor_representation(self):
67 | d_embeddings = self.model.module.donor_embeddings.weight.cpu().detach().numpy()
68 | index_ = (
69 | self.adata_.obs.drop_duplicates("_scvi_batch")
70 | .set_index("_scvi_batch")
71 | .sort_index()
72 | .loc[:, self.donor_key]
73 | )
74 | return pd.DataFrame(d_embeddings, index=index_)
75 |
76 | def get_cell_representation(self, adata=None, batch_size=512, give_z=False):
77 | return self.model.get_latent_representation(
78 | adata, batch_size=batch_size, give_z=give_z
79 | )
80 |
81 | @torch.no_grad()
82 | def get_normalized_expression(
83 | self,
84 | adata=None,
85 | x_log=True,
86 | batch_size=256,
87 | eps=1e-6,
88 | cf_site=0.0,
89 | ):
90 | adata = self.adata_ if adata is None else adata
91 | self.model._check_if_trained(warn=False)
92 | adata = self.model._validate_anndata(adata)
93 | scdl = self.model._make_data_loader(
94 | adata=adata, indices=None, batch_size=batch_size
95 | )
96 |
97 | reps = []
98 | for tensors in tqdm(scdl):
99 |             # Counterfactually rescale nuisance covariates (cf_site=0.0 zeroes them all out).
100 |             if cf_site is not None:
101 |                 tensors[
102 |                     "categorical_nuisance_keys"
103 |                 ] *= cf_site
104 | inference_inputs = self.model.module._get_inference_input(tensors)
105 | outputs_n = self.model.module.inference(use_mean=True, **inference_inputs)
106 | outs_g = self.model.module.generative(
107 | **self.model.module._get_generative_input(
108 | tensors, inference_outputs=outputs_n
109 | )
110 | )
111 | xs = outs_g["h"]
112 | if x_log:
113 | xs = (eps + xs).log()
114 | reps.append(xs.cpu().numpy())
115 |         # concatenate minibatches along the cell axis
116 | reps = np.concatenate(reps, 0)
117 | return reps
118 |
119 | @torch.no_grad()
120 | def get_average_expression(
121 | self,
122 | adata=None,
123 | indices=None,
124 | batch_size=256,
125 | eps=1e-6,
126 | mc_samples=10,
127 | ):
128 | adata = self.adata_ if adata is None else adata
129 | # self.model._check_if_trained(warn=False)
130 | # adata = self.model._validate_anndata(adata)
131 | scdl = self.model._make_data_loader(
132 | adata=adata, indices=indices, batch_size=batch_size
133 | )
134 |
135 | reps = np.zeros((self.model.summary_stats.n_batch, adata.n_vars))
136 | for tensors in tqdm(scdl):
137 | xs = []
138 | for batch in range(self.model.summary_stats.n_batch):
139 | tensors[
140 | "categorical_nuisance_keys"
141 | ] *= 0.0 # set to 0 all nuisance factors
142 |
143 | cf_batch = batch * torch.ones_like(tensors["batch"])
144 | inference_inputs = self.model.module._get_inference_input(tensors)
145 | inference_outputs = self.model.module.inference(
146 | n_samples=mc_samples, cf_batch=cf_batch, **inference_inputs
147 | )
148 | generative_inputs = self.model.module._get_generative_input(
149 | tensors=tensors, inference_outputs=inference_outputs
150 | )
151 | generative_outputs = self.model.module.generative(**generative_inputs)
152 |                 # Log-transform with a small pseudocount for numerical stability.
153 |                 new = (eps + generative_outputs["h"]).log()
154 | xs.append(new[:, :, None])
155 |
156 | xs = torch.cat(xs, 2).mean(0) # size (n_cells, n_donors, n_genes)
157 | reps += xs.mean(0).cpu().numpy()
158 | return reps
159 |
160 | @torch.no_grad()
161 | def get_local_sample_representation(self, adata=None, **kwargs):
162 | return self.model.get_local_sample_representation(adata=adata, **kwargs)
163 |
164 | def get_donor_representation_metadata(self):
165 | return (
166 | self.model.adata.obs.drop_duplicates("_scvi_batch")
167 | .set_index("_scvi_batch")
168 | .sort_index()
169 | )
170 |
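171 | # Illustrative usage sketch (not part of the original module). It assumes the
172 | # BaseModelClass constructor accepts `adata`, `donor_key`, and
173 | # `categorical_nuisance_keys`; key names and paths are placeholders.
174 | #
175 | #   wrapper = MrVIWrapper(adata=adata, donor_key="donor", categorical_nuisance_keys=["site"])
176 | #   wrapper.fit()
177 | #   u = wrapper.get_cell_representation()          # cell embeddings (give_z=True for the z latent)
178 | #   donor_df = wrapper.get_donor_representation()  # one row per donor
179 | #   local_rep = wrapper.get_local_sample_representation()
180 | #   wrapper.save("path/to/mrvi_model")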
--------------------------------------------------------------------------------