├── .gitignore
├── .readthedocs.yaml
├── LICENSE
├── README.md
├── data
│   └── README.md
├── dev-requirements.txt
├── docs
│   ├── Makefile
│   ├── make.bat
│   └── source
│       ├── conf.py
│       ├── index.rst
│       ├── installation.rst
│       ├── intro.rst
│       ├── notebooks
│       │   ├── Starfysh_tutorial_real.ipynb
│       │   ├── Starfysh_tutorial_simulation.ipynb
│       │   └── update_tutorial.sh
│       └── starfysh.rst
├── figure
│   ├── github_figure_1.png
│   ├── github_figure_2.png
│   └── logo.png
├── notebooks
│   ├── Starfysh_tutorial_integration.ipynb
│   ├── Starfysh_tutorial_real-with_poe.ipynb
│   ├── Starfysh_tutorial_real_without_poe.ipynb
│   ├── Starfysh_tutorial_real_wo_signatures.ipynb
│   ├── Starfysh_tutorial_simulation.ipynb
│   ├── generate_image.ipynb
│   └── slideseq_starfysh_tutorial_on_later.ipynb
├── reinstall.sh
├── requirements.txt
├── setup.py
├── starfysh
│   ├── AA.py
│   ├── __init__.py
│   ├── dataloader.py
│   ├── gener_img.py
│   ├── plot_utils.py
│   ├── post_analysis.py
│   ├── starfysh.py
│   ├── utils.py
│   └── utils_integrate.py
└── tests
    └── test_modules.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # OS generated files
2 | .DS_Store
3 |
4 | # Compiled
5 | __pycache__/
6 | *.py[cod]
7 |
8 | # R code
9 | .Rdata
10 | .Rhistory
11 |
12 | # Packaging
13 | build/
14 | dist/
15 | *.egg
16 | *.egg-info/
17 | lib/
18 | lib64/
19 | bdist.linux-x86_64/
20 | MANIFEST
21 |
22 | # Environments
23 | .env
24 | env/
25 | .venv
26 | /venv/
27 |
28 | # Others
29 | *.csv
30 | .idea
31 | .ipynb_checkpoints
32 | data*/
33 | **/results/
34 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # Required
2 | version: 2
3 |
4 | # Set the version of Python and other tools you might need
5 | build:
6 | os: ubuntu-20.04
7 | tools:
8 | python: "3.8"
9 |
10 | # Build documentation in the docs/ directory with Sphinx
11 | sphinx:
12 | configuration: docs/source/conf.py
13 | builder: html
14 |
15 | # If using Sphinx, optionally build your docs in additional formats such as PDF
16 | # formats:
17 | # - pdf
18 |
19 | # Optionally declare the Python requirements required to build your docs
20 | python:
21 | install:
22 | - requirements: requirements.txt
23 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2022, azizilab
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | [![Documentation Status](https://readthedocs.org/projects/starfysh/badge/?version=latest)](https://starfysh.readthedocs.io)
4 | [License](https://github.com/azizilab/starfysh/blob/master/LICENSE)
5 | [Open in Colab](https://colab.research.google.com/drive/1a_mxF6Ot5vA_xzr5EaNz-GllYY-pXhwA)
6 |
7 | ## Starfysh: Spatial Transcriptomic Analysis using Reference-Free deep generative modeling with archetYpes and Shared Histology
8 |
9 | Starfysh is an end-to-end toolbox for the analysis and integration of Spatial Transcriptomic (ST) datasets. In summary, the Starfysh framework enables reference-free deconvolution of cell types and cell states, and can be improved by integrating paired histology images of tissues, if available. Starfysh is capable of integrating data from multiple tissues; in particular, it identifies common or sample-specific spatial “hubs” with unique compositions of cell types. To uncover mechanisms underlying local and long-range communication, Starfysh can be used to perform downstream analysis on the spatial organization of hubs.
10 |
11 |
12 |
13 |
14 |
15 | ### Quickstart tutorials
16 | - [1. Basic deconvolution on an example breast cancer dataset with pre-defined signatures, without PoE (dataset & signature files included).](notebooks/Starfysh_tutorial_real_without_poe.ipynb)
17 | - [2. Histology integration & deconvolution without pre-defined signatures.](notebooks/Starfysh_tutorial_real_wo_signatures.ipynb)
18 | - [3. Basic deconvolution on an example breast cancer dataset with pre-defined signatures and histology image integration (PoE) (dataset & signature files included).](notebooks/Starfysh_tutorial_real-with_poe.ipynb)
19 | - [4. Multi-sample integration with archetypes and PoE.](notebooks/Starfysh_tutorial_integration.ipynb)
20 |
21 |
22 | Please refer to the [Starfysh Documentation](http://starfysh.readthedocs.io) for additional tutorials & APIs.
23 |
24 | ### Installation
25 | Github-version installation:
26 | ```bash
27 | # Step 1: Clone the Repository
28 | git clone https://github.com/azizilab/starfysh.git
29 |
30 | # Step 2: Navigate to the Repository
31 | cd starfysh
32 |
33 | # Step 3: Install the Package
34 | pip install .
35 | ```
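
Starfysh can also be installed from PyPI (`pip install Starfysh`; see `docs/source/installation.rst`). A condensed usage sketch, adapted from the Quickstart in `docs/source/installation.rst` (`data_path`, `sig_path` and `sample_id` are placeholders for your own files):

```python
import pandas as pd
import torch

from starfysh import utils
from starfysh import starfysh as sf_model

# (1) Load the ST count matrix, paired H&E image metadata & signature gene sets
adata, adata_normed = utils.load_adata(data_path, sample_id, n_genes=2000)
img_metadata = utils.preprocess_img(data_path, sample_id, adata_index=adata.obs.index, hchannel=False)
gene_sig = utils.filter_gene_sig(pd.read_csv(sig_path), adata.to_df())

# (2) Train the deconvolution model
visium_args = utils.VisiumArguments(adata, adata_normed, gene_sig, img_metadata,
                                    n_anchors=60, window_size=3, sample_id=sample_id)
adata, adata_normed = visium_args.get_adata()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, loss = utils.run_starfysh(visium_args, n_repeats=3, epochs=200, patience=50, device=device)

# (3) Parse deconvolution outputs
inference_outputs, generative_outputs = sf_model.model_eval(
    model, adata, visium_args, poe=False, device=device
)
```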
36 |
37 |
38 | ### Model Input:
39 | - Spatial Transcriptomics count matrix
40 | - Annotated signature gene sets (see [example](https://drive.google.com/file/d/1AXWQy_mwzFEKNjAdrJjXuegB3onxJoOM/view?usp=share_link))
41 | - (Optional): paired H&E image
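
Signatures are read with `pandas.read_csv` (see the Quickstart in `docs/source/installation.rst`); in the tutorials, each column is a cell type and the rows beneath it list that cell type's marker genes. A minimal, purely hypothetical layout:

```
Tumor,Tcell,Macrophage
GENE_A,GENE_D,GENE_F
GENE_B,GENE_E,GENE_G
```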
42 |
43 | ### Features:
44 | - Deconvolving cell types & discovering novel, unannotated cell states
45 | - Integration with paired histology images and across multiple samples
46 | - Downstream analysis: spatial hub identification, cell-type colocalization networks & receptor-ligand (R-L) interactions
47 |
48 | ### Directories
49 |
50 | ```
51 | .
52 | ├── data: Spatial Transcriptomics & synthetic simulation datasets
53 | ├── notebooks: Sample tutorial notebooks
54 | ├── starfysh: Starfysh core model
55 | ```
56 |
57 | ### How to cite Starfysh
58 | Please cite the [Starfysh paper published in Nature Biotechnology](https://www.nature.com/articles/s41587-024-02173-8#citeas):
59 | ```
60 | He, S., Jin, Y., Nazaret, A. et al.
61 | Starfysh integrates spatial transcriptomic and histologic data to reveal heterogeneous tumor–immune hubs.
62 | Nat Biotechnol (2024).
63 | https://doi.org/10.1038/s41587-024-02173-8
64 | ```
65 |
66 | ### BibTex
67 | ```
68 | @article{He2024,
69 | title = {Starfysh integrates spatial transcriptomic and histologic data to reveal heterogeneous tumor–immune hubs},
70 | ISSN = {1546-1696},
71 | url = {http://dx.doi.org/10.1038/s41587-024-02173-8},
72 | DOI = {10.1038/s41587-024-02173-8},
73 | journal = {Nature Biotechnology},
74 | publisher = {Springer Science and Business Media LLC},
75 | author = {He, Siyu and Jin, Yinuo and Nazaret, Achille and Shi, Lingting and Chen, Xueer and Rampersaud, Sham and Dhillon, Bahawar S. and Valdez, Izabella and Friend, Lauren E. and Fan, Joy Linyue and Park, Cameron Y. and Mintz, Rachel L. and Lao, Yeh-Hsing and Carrera, David and Fang, Kaylee W. and Mehdi, Kaleem and Rohde, Madeline and McFaline-Figueroa, José L. and Blei, David and Leong, Kam W. and Rudensky, Alexander Y. and Plitas, George and Azizi, Elham},
76 | year = {2024},
77 | month = mar
78 | }
79 | ```
80 |
81 | If you have questions, please contact the authors:
82 |
83 | - Siyu He - sh3846@columbia.edu
84 | - Yinuo Jin - yj2589@columbia.edu
85 |
86 |
87 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | ### Description
2 |
3 | The following public datasets from the Wu *et al.* (2021) study (**A single-cell and spatially resolved atlas of human breast cancers**) are downloaded from Zenodo ([link](https://zenodo.org/record/4739739#.Y4ohMbLMKxs)):
4 |
5 | - 1142243F_TNBC
6 | - 1160920F_TNBC
7 | - CID4465_TNBC
8 | - CID44971_TNBC
9 | - CID4535_ER
10 | - CID4290_ER
11 |
12 | **Reference**:
13 | Wu, S. Z., Al-Eryani, G., Roden, D. L., Junankar, S., Harvey, K., Andersson, A., ... & Swarbrick, A. (2021). A single-cell and spatially resolved atlas of human breast cancers. Nature genetics, 53(9), 1334-1347.
14 |
--------------------------------------------------------------------------------
/dev-requirements.txt:
--------------------------------------------------------------------------------
1 | ipykernel>=6.17.0
2 | nbsphinx>=0.8.9
3 | sphinx>=4.2.0
4 | sphinxcontrib-apidoc
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.https://www.sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | import os
14 | import sys
15 | sys.path.insert(0, os.path.abspath('../..'))
16 |
17 | # -- Project information -----------------------------------------------------
18 |
19 | project = 'Starfysh'
20 | copyright = '2022, Siyu He, Yinuo Jin @ Azizi Lab'
21 | author = 'Siyu He, Yinuo Jin'
22 |
23 | # The full version, including alpha/beta/rc tags
24 | release = '1.0.0'
25 |
26 |
27 | # -- General configuration ---------------------------------------------------
28 |
29 | # Add any Sphinx extension module names here, as strings. They can be
30 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
31 | # ones.
32 | extensions = [
33 | 'nbsphinx',
34 | 'sphinx.ext.autodoc',
35 | 'sphinx.ext.viewcode',
36 | 'sphinx.ext.napoleon'
37 | ]
38 |
39 | # Add any paths that contain templates here, relative to this directory.
40 | templates_path = ['_templates']
41 |
42 | # List of patterns, relative to source directory, that match files and
43 | # directories to ignore when looking for source files.
44 | # This pattern also affects html_static_path and html_extra_path.
45 | exclude_patterns = []
46 |
47 |
48 | # -- Options for HTML output -------------------------------------------------
49 |
50 | # The theme to use for HTML and HTML Help pages. See the documentation for
51 | # a list of builtin themes.
52 | #
53 | html_theme = 'sphinx_rtd_theme'
54 |
55 | # Add any paths that contain custom static files (such as style sheets) here,
56 | # relative to this directory. They are copied after the builtin static files,
57 | # so a file named "default.css" will overwrite the builtin "default.css".
58 | html_static_path = ['_static']
59 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. Starfysh documentation master file, created by
2 | sphinx-quickstart on Tue Nov 1 02:30:02 2022.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to Starfysh's documentation!
7 | ====================================
8 | .. toctree::
9 | :maxdepth: 3
10 | :caption: Basics:
11 |
12 | intro
13 | installation
14 |
15 | .. toctree::
16 | :maxdepth: 3
17 | :caption: Examples:
18 |
19 | notebooks/Starfysh_tutorial_simulation.ipynb
20 | notebooks/Starfysh_tutorial_real.ipynb
21 |
22 | .. toctree::
23 | :maxdepth: 3
24 | :caption: API Reference:
25 |
26 | starfysh
27 |
28 | Indices and tables
29 | ==================
30 |
31 | * :ref:`genindex`
32 | * :ref:`modindex`
33 | * :ref:`search`
34 |
35 |
--------------------------------------------------------------------------------
/docs/source/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ************
3 | .. code-block:: bash
4 |
5 | pip install Starfysh
6 |
7 | Quickstart
8 | **********
9 |
10 | .. code-block:: python
11 |
12 | import os
13 | import numpy as np
14 | import pandas as pd
15 | import torch
16 |
17 | from starfysh import (AA, utils, plot_utils, post_analysis)
18 | from starfysh import starfysh as sf_model
19 |
20 | # (1) Loading dataset & signature gene sets
21 | data_path = 'data/' # specify data directory
22 | sig_path = 'signature/signatures.csv' # specify signature file path
23 | sample_id = 'SAMPLE_ID'
24 |
25 | # --- (a) ST matrix ---
26 | adata, adata_normed = utils.load_adata(
27 | data_path,
28 | sample_id,
29 | n_genes=2000
30 | )
31 |
32 | # --- (b) paired H&E image + spots info ---
33 | img_metadata = utils.preprocess_img(
34 | data_path,
35 | sample_id,
36 | adata_index=adata.obs.index,
37 | hchannel=False
38 | )
39 |
40 | # --- (c) signature gene sets ---
41 | gene_sig = utils.filter_gene_sig(
42 | pd.read_csv(sig_path),
43 | adata.to_df()
44 | )
45 |
46 | # (2) Starfysh deconvolution
47 |
48 | # --- (a) Preparing arguments for model training ---
49 | visium_args = utils.VisiumArguments(adata,
50 |                                     adata_normed,
51 |                                     gene_sig,
52 |                                     img_metadata,
53 |                                     n_anchors=60,
54 |                                     window_size=3,
55 |                                     sample_id=sample_id
56 |                                     )
57 |
58 | adata, adata_normed = visium_args.get_adata()
59 | anchors_df = visium_args.get_anchors()
60 |
61 | # --- (b) Model training ---
62 | n_repeats = 3
63 | epochs = 200
64 | patience = 50
65 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
66 |
67 | # Run models
68 | model, loss = utils.run_starfysh(
69 | visium_args,
70 | n_repeats=n_repeats,
71 | epochs=epochs,
72 | patience=patience,
73 | device=device
74 | )
75 |
76 | # (3). Parse deconvolution outputs
77 | inference_outputs, generative_outputs = sf_model.model_eval(
78 | model,
79 | adata,
80 | visium_args,
81 | poe=False,
82 | device=device
83 | )
84 |
85 |
--------------------------------------------------------------------------------
/docs/source/intro.rst:
--------------------------------------------------------------------------------
1 | Overview
2 | ========
3 | Starfysh is an end-to-end toolbox for the analysis and integration of ST datasets.
4 | In summary, the Starfysh framework consists of reference-free deconvolution of cell types and cell states, which can be improved by integrating paired histology images of tissues, if available. To facilitate comparison of tissues across healthy and disease contexts and to derive differential spatial patterns, Starfysh is capable of integrating data from multiple tissues and further identifies common or sample-specific spatial “hubs”, defined as neighborhoods with a unique composition of cell types. To uncover mechanisms underlying local and long-range communication, Starfysh performs downstream analysis on the spatial organization of hubs and identifies critical genes with spatially varying patterns as well as cell-cell interaction networks.
5 |
6 |
7 | Features
8 | ********
9 |
10 | * Deconvolving cell types / cell states
11 | * Discovering and learning novel cell states
12 | * Integration with paired histology images and across multiple samples
13 | * Downstream analysis: spatial hub identification, cell-type colocalization networks & receptor-ligand (R-L) interactions
14 |
15 | Model Specifications
16 | ********************
17 |
18 | Starfysh performs cell-type deconvolution followed by various downstream analyses to discover spatial interactions in the tumor microenvironment.
19 | The core deconvolution model is based on a semi-supervised Auxiliary Variational Autoencoder (AVAE). Optional Archetypal Analysis (AA) & Product-of-Experts (PoE) modules provide cell-type annotation and H&E image integration to further aid deconvolution.
20 | Specifically, Starfysh looks for *anchor spots*, the presumed purest spots with the highest proportion of a given cell type guided by signatures, and then deconvolves the remaining spots. Starfysh provides the following options:
21 |
22 | **Base feature**:
23 |
24 | * Auxiliary Variational AutoEncoder (AVAE):
25 | Spot-level deconvolution with expected cell types and corresponding annotated *signature* gene sets (default)
26 |
27 | **Optional**:
28 |
29 | * Archetypal Analysis (AA):
30 |     If a *signature* is not provided:
31 |
32 |     * Unsupervised cell-type annotation
33 |
34 |     If a *signature* is provided:
35 |
36 |     * Novel cell type / cell state discovery (complementary to known cell types from the *signatures*)
37 |     * Refine known marker genes by appending archetype-specific differentially expressed genes, and update anchor spots accordingly
38 |
39 | * Product-of-Experts (PoE) integration:
40 |     Multi-modal integrative prediction with *expression* & *histology image*, leveraging additional side information (e.g. cell density) from the H&E image.
41 |
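
A minimal sketch of the optional Archetypal Analysis step (method names follow ``starfysh/AA.py``; here ``adata`` is assumed to be the raw-count AnnData and ``anchor_df`` the output of ``VisiumArguments.get_anchors()`` from the Quickstart):

.. code-block:: python

    from starfysh import AA

    aa_model = AA.ArchetypalAnalysis(adata_orig=adata)
    archetypes, arche_dict, major_idx, evs = aa_model.compute_archetypes()

    arche_df = aa_model.find_archetypal_spots(major=True)   # archetype communities
    marker_df = aa_model.find_markers(n_markers=30)         # archetype marker genes
    overlaps_df, map_dict = aa_model.assign_archetypes(anchor_df)

    # Unmapped archetypes far from all anchors suggest novel cell types / states
    distant_arches = aa_model.find_distant_archetypes(anchor_df, map_dict, n=3)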
42 |
43 | I/O
44 | ***
45 | - Input:
46 |
47 | - Spatial Transcriptomics count matrix
48 | - Annotated signature gene sets (`see example <https://drive.google.com/file/d/1AXWQy_mwzFEKNjAdrJjXuegB3onxJoOM/view?usp=share_link>`_)
49 | - (Optional): paired H&E image
50 |
51 | - Output:
52 |
53 | - Spot-wise deconvolution matrix (`q(c)`)
54 | - Low-dimensional manifold representation (`q(z)`)
55 | - Spatial hubs (in-sample or multiple-sample integration)
56 |   - Co-localization networks across cell types and spatial receptor-ligand (R-L) interactions
57 | - Reconstructed count matrix (`p(x)`)
58 |
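
All outputs above are parsed from the trained model with ``starfysh.model_eval`` (see the Quickstart); a sketch, assuming the Quickstart variables are in scope:

.. code-block:: python

    inference_outputs, generative_outputs = sf_model.model_eval(
        model, adata, visium_args, poe=False, device=device
    )
    # `inference_outputs` carries the posterior estimates (deconvolution q(c),
    # latent representation q(z)); `generative_outputs` carries the
    # reconstruction p(x). Key names vary; inspect the returned dictionaries.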
59 |
60 |
--------------------------------------------------------------------------------
/docs/source/notebooks/Starfysh_tutorial_real.ipynb:
--------------------------------------------------------------------------------
1 | ../../../notebooks/Starfysh_tutorial_real.ipynb
--------------------------------------------------------------------------------
/docs/source/notebooks/Starfysh_tutorial_simulation.ipynb:
--------------------------------------------------------------------------------
1 | ../../../notebooks/Starfysh_tutorial_simulation.ipynb
--------------------------------------------------------------------------------
/docs/source/notebooks/update_tutorial.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ln -sf ../../../notebooks/Starfysh_tutorial_simulation.ipynb Starfysh_tutorial_simulation.ipynb
4 | ln -sf ../../../notebooks/Starfysh_tutorial_real.ipynb Starfysh_tutorial_real.ipynb
5 |
6 |
--------------------------------------------------------------------------------
/docs/source/starfysh.rst:
--------------------------------------------------------------------------------
1 | starfysh package
2 | ================
3 |
4 | Submodules
5 | ----------
6 |
7 | starfysh.AA module
8 | ------------------
9 |
10 | .. automodule:: starfysh.AA
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | starfysh.dataloader module
16 | --------------------------
17 |
18 | .. automodule:: starfysh.dataloader
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | starfysh.plot\_utils module
24 | ---------------------------
25 |
26 | .. automodule:: starfysh.plot_utils
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | starfysh.post\_analysis module
32 | ------------------------------
33 |
34 | .. automodule:: starfysh.post_analysis
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 | starfysh.starfysh module
40 | ------------------------
41 |
42 | .. automodule:: starfysh.starfysh
43 | :members:
44 | :undoc-members:
45 | :show-inheritance:
46 |
47 | starfysh.utils module
48 | ---------------------
49 |
50 | .. automodule:: starfysh.utils
51 | :members:
52 | :undoc-members:
53 | :show-inheritance:
54 |
55 | Module contents
56 | ---------------
57 |
58 | .. automodule:: starfysh
59 | :members:
60 | :undoc-members:
61 | :show-inheritance:
62 |
--------------------------------------------------------------------------------
/figure/github_figure_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/azizilab/starfysh/56fb01ef734401d067eb3078280dd805b97621d1/figure/github_figure_1.png
--------------------------------------------------------------------------------
/figure/github_figure_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/azizilab/starfysh/56fb01ef734401d067eb3078280dd805b97621d1/figure/github_figure_2.png
--------------------------------------------------------------------------------
/figure/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/azizilab/starfysh/56fb01ef734401d067eb3078280dd805b97621d1/figure/logo.png
--------------------------------------------------------------------------------
/reinstall.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | pip uninstall starfysh
4 | python setup.py install --user
5 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | histomicstk>=1.2.0
2 | matplotlib>=3.4.2
3 | networkx>=2.6.3
4 | numba>=0.56.4
5 | numpy>=1.23.5
6 | opencv_python_headless>=4.5.1.48
7 | pandas>=1.3.3
8 | py_pcha>=0.1.3
9 | scanpy>=1.9.2
10 | scikit_dimension>=0.3
11 | scikit_image>=0.19.2
12 | scikit_learn>=1.2.1
13 | scipy>=1.7.1
14 | seaborn>=0.11.2
15 | setuptools>=50.3.2
16 | threadpoolctl>=3.1.0
17 | torch>=2.0.0
18 | torchvision>=0.15.1
19 | tqdm>=4.62.2
20 | umap_learn>=0.5.3
21 | nbsphinx
22 | ipykernel
23 | sphinx
24 | sphinxcontrib-apidoc
25 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | with open("requirements.txt", 'r') as ifile:
4 | requirements = ifile.read().splitlines()
5 |
6 | nb_requirements = [
7 | 'nbconvert>=6.1.0',
8 | 'nbformat>=5.1.3',
9 | 'notebook>=6.4.11',
10 | 'jupyter>=7.0.0',
11 | 'jupyterlab>=3.4.3',
12 | 'ipython>=7.27.0',
13 | ]
14 |
15 | setup(
16 | name="Starfysh",
17 | version="1.2.0",
18 | description="Spatial Transcriptomic Analysis using Reference-Free auxiliarY deep generative modeling and Shared Histology",
19 |     author="Siyu He, Yinuo Jin, Achille Nazaret",
20 | url="https://starfysh.readthedocs.io",
21 | packages=find_packages(),
22 | python_requires='>=3.7',
23 | install_requires=requirements,
24 | zip_safe=False,
25 | classifiers=[
26 | "Programming Language :: Python :: 3",
27 | "License :: OSI Approved :: BSD License",
28 | "Operating System :: OS Independent",
29 | ],
30 |
31 | dependency_links=[
32 | 'https://girder.github.io/large_image_wheels'
33 | ]
34 |
35 | # extras_require={
36 | # 'notebooks': nb_requirements,
37 | # 'dev': open('dev-requirements.txt').read().splitlines(),
38 | # }
39 | )
40 |
41 |
--------------------------------------------------------------------------------
/starfysh/AA.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pandas as pd
4 | import scanpy as sc
5 | import umap
6 | import skdim
7 | import matplotlib.pyplot as plt
8 | import seaborn as sns
9 |
10 | from py_pcha import PCHA
11 | from matplotlib.pyplot import cm
12 | from scipy.spatial.distance import cdist, euclidean
13 | from sklearn.neighbors import NearestNeighbors
14 | from starfysh import LOGGER
15 |
16 |
17 | class ArchetypalAnalysis:
18 | # Todo: add assertion that input `adata` should be raw counts to ensure math assumptions
19 | def __init__(
20 | self,
21 | adata_orig,
22 | r=100,
23 | u=None,
24 | u_3d=None,
25 | verbose=True,
26 | outdir=None,
27 | filename=None,
28 | savefig=False,
29 | ):
30 | """
31 | Parameters
32 | ----------
33 | adata_orig : sc.AnnData
34 | ST raw count matrix
35 |
36 | r : int
37 | Resolution parameter to control granularity of major archetypes
38 | If two archetypes reside within r nearest neighbors, the latter
39 | one will be merged.
40 | """
41 |
42 | # Check adata raw counts
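        # A sketch of the intended check (hypothetical, per the TODO above):
        # raw counts should be non-negative integers, e.g.
        #   x = adata_orig.X if isinstance(adata_orig.X, np.ndarray) else adata_orig.X.A
        #   assert (x >= 0).all() and np.allclose(x, np.round(x)), \
        #       "`adata_orig` should contain raw counts"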
43 |
44 |
45 | self.adata = adata_orig.copy()
46 |
47 | # Perform dim. reduction with PCA, select the first 30 PCs
48 | if 'X_umap' in self.adata.obsm_keys():
49 | del self.adata.obsm['X_umap']
50 | sc.pp.normalize_total(self.adata, target_sum=1e6)
51 | sc.pp.pca(self.adata, n_comps=30)
52 |
53 | self.r = r # granularity parameter
54 | self.count = self.adata.obsm['X_pca']
55 | self.n_spots = self.count.shape[0]
56 |
57 | self.verbose = verbose
58 | self.outdir = outdir
59 | self.filename = filename
60 | self.savefig = savefig
61 |
62 | self.archetype = None
63 | self.major_archetype = None
64 | self.major_idx = None
65 | self.arche_dict = None
66 | self.arche_df = None
67 | self.kmin = 0
68 |
69 | self.U = u
70 | self.U_3d = u_3d
71 |
72 | def compute_archetypes(
73 | self,
74 | cn=30,
75 | n_iters=20,
76 | converge=1e-3,
77 | display=False
78 | ):
79 | """
80 | Estimate the upper bound of archetype count (k) by calculating intrinsic dimension
81 | Compute hierarchical archetypes (major + raw) with given granularity
82 |
83 | Parameters
84 | ----------
85 | cn : int
86 | Conditional Number to choose PCs for intrinsic estimator as
87 | lower bound # archetype estimation. Please refer to:
88 | https://scikit-dimension.readthedocs.io/en/latest/skdim.id.FisherS.html#skdim.id.FisherS
89 |
90 | n_iters : int
91 | Max. # iterations of AA to find the best k estimation
92 |
93 |         converge : float
94 |             Convergence criterion for AA iterations, as diff(explained variance)
95 |
96 | display : bool
97 | Whether to display Intrinsic Dimension (ID) estimation plots
98 |
99 | Returns
100 | -------
101 | archetype : np.ndarray (dim=[K, G])
102 | Raw archetypes as linear combination of subset of spot counts
103 |
104 | arche_dict : dict
105 | Hierarchical structure of major_archetype -> its fine-grained neighbor archetypes
106 |
107 |         major_idx : np.ndarray
108 |             Indices of major archetypes among `k` raw candidates after merging
109 |
110 | evs : list
111 | Explained variance with different Ks
112 | """
113 |
114 |         # TMP: across-sample comparison: fix # principal components for all samples
115 |
116 | if self.verbose:
117 | LOGGER.info('Computing intrinsic dimension to estimate k...')
118 |
119 | # Estimate ID
120 | id_model = skdim.id.FisherS(conditional_number=cn,
121 | produce_plots=display,
122 | verbose=self.verbose)
123 |
124 | self.kmin = max(1, int(id_model.fit(self.count).dimension_))
125 |
126 | # Compute raw archetypes
127 | if self.verbose:
128 | LOGGER.info('Estimating lower bound of # archetype as {0}...'.format(self.kmin))
129 | X = self.count.T
130 | archetypes = []
131 | evs = []
132 |
133 | # TODO: speedup with multiprocessing
134 | for i, k in enumerate(range(self.kmin, self.kmin+n_iters, 2)):
135 | archetype, _, _, _, ev = PCHA(X, noc=k)
136 | evs.append(ev)
137 | archetypes.append(np.array(archetype).T)
138 | if i > 0 and ev - evs[i-1] < converge:
139 | # early stopping
140 | break
141 | self.archetype = archetypes[-1]
142 |
143 | # Merge raw archetypes to get major archetypes
144 | if self.verbose:
145 | LOGGER.info('{0} variance explained by raw archetypes.\n'
146 | 'Merging raw archetypes within {1} NNs to get major archetypes'.format(np.round(ev, 4), self.r))
147 |
148 | arche_dict, major_idx = self._merge_archetypes()
149 | self.major_archetype = self.archetype[major_idx]
150 | self.major_idx = np.array(major_idx)
151 | self.arche_dict = arche_dict
152 |
153 | # return all archetypes for Silhouette score calculation
154 | return archetypes, arche_dict, major_idx, evs
155 |
156 | def _merge_archetypes(self):
157 | """
158 |         Merge raw archetypes into major ones by removing candidates that fall within
159 |         `r` nearest neighbors of a previously identified archetype
160 | """
161 | assert self.archetype is not None, "Please compute archetypes first!"
162 |
163 | n_archetypes = self.archetype.shape[0]
164 | X_concat = np.vstack([self.count, self.archetype])
165 | nbrs = NearestNeighbors(n_neighbors=self.r).fit(X_concat)
166 | nn_graph = nbrs.kneighbors(X_concat)[1][self.n_spots:, 1:] # retrieve NN-graph of only archetype spots
167 |
168 | idxs_to_remove = set()
169 | arche_dict = {}
170 | for i in range(n_archetypes):
171 | if i not in idxs_to_remove:
172 | query = np.arange(self.n_spots+i, self.n_spots+n_archetypes)
173 | nbrs = np.setdiff1d(
174 | nn_graph[i][np.isin(nn_graph[i], query)] - self.n_spots,
175 | list(idxs_to_remove) # avoid over-assign merged archetypes to multiple major archetypes
176 | )
177 | if len(nbrs) != 0:
178 | arche_dict[i] = np.insert(nbrs, 0, i)
179 | idxs_to_remove.update(nbrs)
180 |
181 | major_idx = np.setdiff1d(np.arange(n_archetypes), list(idxs_to_remove))
182 | return arche_dict, major_idx
183 |
184 | def find_archetypal_spots(self, major=True):
185 | """
186 | Assign N-nearest-neighbor spots to each archetype as `archetypal spots` (archetype community)
187 |
188 | Parameters
189 | ----------
190 | major : bool
191 | Whether to find NNs for only major archetypes
192 |
193 | Returns
194 | -------
195 | arche_df : pd.DataFrame
196 | Dataframe of archetypal spots
197 | """
198 | assert self.archetype is not None, "Please compute archetypes first!"
199 | if self.verbose:
200 | LOGGER.info('Finding {} nearest neighbors for each archetype...'.format(self.r))
201 |
202 | indices = self.major_idx if major else np.arange(self.archetype.shape[0])
203 | x_concat = np.vstack([self.count, self.archetype])
204 | nbrs = self._get_knns(x_concat, n_nbrs=self.r, indices=indices+self.n_spots)
205 | self.arche_df = pd.DataFrame({
206 | 'arch_{}'.format(idx): g
207 | for (idx, g) in zip(indices, nbrs)
208 | })
209 |
210 | return self.arche_df
211 |
212 | def find_markers(self, n_markers=30, display=False):
213 | """
214 | Find marker genes for each archetype community via Wilcoxon rank sum test (in-group vs. out-of-group)
215 |
216 | Parameters
217 | ----------
218 | n_markers : int
219 | Number of top marker genes to find for each archetype community
220 |
221 | Returns
222 | -------
223 | marker_df : pd.DataFrame
224 | Dataframe of marker genes for each archetype community
225 | """
226 | assert self.arche_df is not None, "Please compute archetypes & assign nearest-neighbors first!"
227 | if self.verbose:
228 | LOGGER.info('Finding {} top marker genes for each archetype...'.format(n_markers))
229 |
230 | adata = self.adata.copy()
231 | markers = []
232 | for col in self.arche_df.columns:
233 | # Annotate in-group (current archetype) vs. out-of-group
234 | annots = np.zeros(self.n_spots, dtype=np.int64).astype(str)
235 | annots[self.arche_df[col]] = col
236 | adata.obs[col] = annots
237 | adata.obs[col] = adata.obs[col].astype('category')
238 |
239 | # Identify marker genes
240 | sc.tl.rank_genes_groups(adata, col, use_raw=False, method='wilcoxon')
241 | markers.append(adata.uns['rank_genes_groups']['names'][col][:n_markers])
242 |
243 | if display:
244 | plt.rcParams['figure.figsize'] = (8, 3)
245 | plt.rcParams['figure.dpi'] = 300
246 | sc.pl.rank_genes_groups_violin(adata, groups=[col], n_genes=n_markers)
247 |
248 | return pd.DataFrame(np.stack(markers, axis=1), columns=self.arche_df.columns)
249 |
250 | def assign_archetypes(self, anchor_df, r=30):
251 | """
252 | Stable-matching to obtain best 1-1 mapping of archetype community to its closest anchor community
253 | (cell-type specific anchor spots)
254 |
255 | Parameters
256 | ----------
257 | anchor_df : pd.DataFrame
258 | Dataframe of anchor spot indices
259 |
260 |         r : int
261 | Resolution parameter to threshold archetype - anchor mapping
262 |
263 | Returns
264 | -------
265 | overlaps_df : pd.DataFrame
266 | DataFrame of overlapping spot ratio of each anchor `i` to archetype `j`
267 |
268 | map_dict : dict
269 | Dictionary of cell type -> mapped archetype
270 | """
271 | assert self.arche_df is not None, "Please compute archetypes & assign nearest-neighbors first!"
272 |
273 | x_concat = np.vstack([self.count, self.archetype])
274 | anchor_nbrs = anchor_df.values
275 | archetypal_nbrs = self._get_knns(
276 | x_concat, n_nbrs=r, indices=self.n_spots + self.major_idx).T # r-nearest nbrs to each archetype
277 |
278 | overlaps = np.array(
279 | [
280 | [
281 | len(np.intersect1d(anchor_nbrs[:, i], archetypal_nbrs[:, j]))
282 | for j in range(archetypal_nbrs.shape[1])
283 | ]
284 | for i in range(anchor_nbrs.shape[1])
285 | ]
286 | )
287 | overlaps_df = pd.DataFrame(overlaps, index=anchor_df.columns, columns=self.arche_df.columns)
288 |
289 | # Stable marriage matching: archetype -> anchor clusters
290 | map_idx_df = self._stable_matching(overlaps)
291 | map_dict = {overlaps_df.index[k]: overlaps_df.columns[v] for k, v in map_idx_df.items()}
292 | return overlaps_df, map_dict
293 |
294 | def find_distant_archetypes(self, anchor_df, map_dict=None, n=3):
295 | """
296 |         Sort and return the top n archetypes that are unmapped and farthest from anchor spots of known cell types.
297 |         These are more likely to represent novel cell types / states
298 |
299 | Parameters
300 | ----------
301 | anchor_df : pd.DataFrame
302 | Dataframe of anchor spot indices
303 |
304 | map_dict : dict
305 | Dictionary of cell type -> mapped archetype
306 |
307 | n : int
308 | Number of distant archetypes to return
309 |
310 | Returns
311 | -------
312 | distant_archetypes : list
313 | List of archetype labels (farthest --> closest to anchors)
314 | """
315 | assert self.arche_df is not None, "Please compute archetypes & assign nearest-neighbors first!"
316 |
317 | cell_types = anchor_df.columns
318 | arche_lbls = self.arche_df.columns
319 |
320 | # Find the unmapped archetypes
321 | if map_dict is None:
322 | _, map_dict = self.assign_archetypes(anchor_df=anchor_df)
323 | unmapped_archetypes = np.setdiff1d(
324 | arche_lbls,
325 | list(set([v for k, v in map_dict.items()]))
326 | )
327 |
328 | # Sort unmapped archetypes in descending orders with avg. distance to its 2 closest anchor spot centroid
329 | if n > len(unmapped_archetypes):
330 |             LOGGER.warning('Insufficient candidates to find {0} distant archetypes\nSet n={1}'.format(
331 |                 n, len(unmapped_archetypes)))
332 |             n = len(unmapped_archetypes)
333 | anchor_centroids = self.count[anchor_df[anchor_df.columns]].mean(0)
334 | arche_centroids = self.count[self.arche_df[self.arche_df.columns]].mean(0)
335 | dist_df = pd.DataFrame(
336 | cdist(anchor_centroids, arche_centroids),
337 | index=cell_types,
338 | columns=arche_lbls
339 | )
340 | dist_unmapped = dist_df[unmapped_archetypes].values # subset only distance to `unmapped` archetypes
341 | dist_to_nbrs = np.sort(dist_unmapped, axis=0)[:2].mean(0)
342 |         distant_arches = [unmapped_archetypes[idx] for idx in np.argsort(-dist_to_nbrs)][:n]  # sort by distance, descending
343 |
344 | return distant_arches
345 |
346 | def _get_knns(self, x, n_nbrs, indices):
347 | """Compute kNNs (actual spots) to each archetype"""
348 | assert 0 <= indices.min() < indices.max() < x.shape[0], \
349 | "Invalid indices of interest to compute k-NNs"
350 | nbrs = np.zeros((len(indices), n_nbrs), dtype=np.int32)
351 | for i, index in enumerate(indices):
352 | u = x[index]
353 | dist = np.ones(x.shape[0])*np.inf
354 | for j, v in enumerate(x[:self.n_spots]):
355 | dist[j] = np.linalg.norm(u-v)
356 | nbrs[i] = np.argsort(dist)[:n_nbrs]
357 | return nbrs
358 |
359 | def _stable_matching(self, A):
360 | matching = {}
361 | free_rows, free_cols = set(range(A.shape[0])), set(range(A.shape[1]))
362 |
363 |         while free_rows:
364 |             i = free_rows.pop()
365 |             for j in np.argsort(A[i])[::-1]:  # iter cols in decreasing vals
366 |                 if j in free_cols:
367 |                     matching[i] = j
368 |                     free_cols.remove(j)
369 |                     break
370 |                 else:  # col already matched: displace the weaker match, else try the next col
371 |                     i_prime = next(k for k, v in matching.items() if v == j)
372 |                     if A[i, j] > A[i_prime, j]:
373 |                         matching[i] = j
374 |                         free_rows.add(i_prime)
375 |                         del matching[i_prime]
376 |                         break
377 |
378 | return matching
379 |
380 | # -------------------
381 | # Plotting functions
382 | # -------------------
383 |
384 | def _get_umap(self, ndim=2, random_state=42):
385 | assert ndim == 2 or ndim == 3, "Invalid dimension for UMAP: {}".format(ndim)
386 | LOGGER.info('Calculating UMAPs for counts + Archetypes...')
387 | reducer = umap.UMAP(n_neighbors=self.r+10, n_components=ndim, random_state=random_state)
388 | U = reducer.fit_transform(np.vstack([self.count, self.archetype]))
389 | return U
390 |
391 | def _save_fig(self, fig, lgds, default_name):
392 | filename = self.filename if self.filename is not None else default_name
393 | if not os.path.exists(self.outdir):
394 | os.makedirs(self.outdir)
395 |
396 | fig.savefig(
397 | os.path.join(self.outdir, filename+'.svg'),
398 | bbox_extra_artists=lgds, bbox_inches='tight',
399 | format='svg'
400 | )
401 |
402 | def plot_archetypes(
403 | self,
404 | major=True,
405 | do_3d=False,
406 | lgd_ncol=1,
407 | figsize=(6, 4),
408 | disp_cluster=True,
409 | disp_arche=True
410 | ):
411 | """
412 | Display archetype & archetypal spot communities
413 | """
414 | assert self.arche_df is not None, "Please compute archetypes & assign nearest-neighbors first!"
415 | n_archetypes = self.arche_df.shape[1]
416 | arche_indices = self.major_idx if major else np.arange(n_archetypes)
417 | U = self.U_3d if do_3d else self.U
418 | colors = cm.tab20(np.linspace(0, 1, n_archetypes))
419 |
420 | if do_3d:
421 | if self.U_3d is None:
422 | self.U_3d = self._get_umap(ndim=3)
423 | U = self.U_3d
424 |
425 | fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=200, subplot_kw=dict(projection='3d'))
426 |
427 | # Color background spots & archetypal spots
428 |
429 | ax.scatter(
430 | U[:self.n_spots, 0],
431 | U[:self.n_spots, 1],
432 | U[:self.n_spots, 2],
433 | s=1, alpha=0.7, linewidth=.3,
434 | edgecolors='black', c='lightgray'
435 | )
436 |
437 | if disp_cluster:
438 | for i, label in enumerate(self.arche_df.columns):
439 | lbl = int(label.split('_')[-1])
440 | if lbl in arche_indices:
441 | idxs = self.arche_df[label]
442 | ax.scatter(
443 | U[idxs, 0],
444 | U[idxs, 1],
445 | U[idxs, 2],
446 | marker='o', s=3,
447 | color=colors[i], label=label
448 | )
449 |
450 | # Highlight archetype
451 | if disp_arche:
452 | ax.scatter(
453 | U[self.n_spots+arche_indices, 0],
454 | U[self.n_spots+arche_indices, 1],
455 | U[self.n_spots+arche_indices, 2],
456 | s=10, c='blue', marker='^'
457 | )
458 | for j, z in zip(arche_indices, U[self.n_spots+arche_indices]):
459 | ax.text(z[0], z[1], z[2], str(j), fontsize=10, c='blue')
460 |
461 | lgd = ax.legend(loc='right', bbox_to_anchor=(0.5, 0, 1.5, 0.5), ncol=lgd_ncol)
462 |
463 | ax.grid(False)
464 | ax.set_xlabel('UMAP1')
465 | ax.set_ylabel('UMAP2')
466 | ax.set_zlabel('UMAP3')
467 |
468 | ax.set_xticklabels([])
469 | ax.set_yticklabels([])
470 | ax.set_zticklabels([])
471 |
472 | ax.xaxis.pane.set_edgecolor('black')
473 | ax.yaxis.pane.set_edgecolor('black')
474 |
475 | ax.set_xticks([])
476 | ax.set_yticks([])
477 | ax.set_zticks([])
478 |
479 | ax.xaxis.pane.fill = False
480 | ax.yaxis.pane.fill = False
481 | ax.zaxis.pane.fill = False
482 |
483 | ax.view_init(20, 135)
484 |
485 | else: # 2D plot
486 | if self.U is None:
487 | self.U = self._get_umap(ndim=2)
488 | U = self.U
489 |
490 | fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=300)
491 |
492 | # Color background & archetypal spots
493 | ax.scatter(
494 | U[:self.n_spots, 0],
495 | U[:self.n_spots, 1],
496 | alpha=1, s=1, color='lightgray')
497 |
498 | if disp_cluster:
499 | for i, label in enumerate(self.arche_df.columns):
500 | lbl = int(label.split('_')[-1])
501 | if lbl in arche_indices:
502 | idxs = self.arche_df[label]
503 | ax.scatter(
504 | U[idxs, 0],
505 | U[idxs, 1],
506 | marker='o', s=3,
507 | color=colors[i], label=label
508 | )
509 |
510 | if disp_arche:
511 | ax.scatter(
512 | U[self.n_spots+arche_indices, 0],
513 | U[self.n_spots+arche_indices, 1],
514 | s=10, c='blue', marker='^'
515 | )
516 |
517 | for j, z in zip(arche_indices, U[self.n_spots+arche_indices]):
518 | ax.text(z[0], z[1], str(j), fontsize=10, c='blue')
519 | lgd = ax.legend(loc='right', bbox_to_anchor=(1, 0.5), ncol=lgd_ncol)
520 |
521 | ax.grid(False)
522 | ax.axis('off')
523 |
524 | if self.savefig and self.outdir is not None:
525 | self._save_fig(fig, (lgd,), 'archetypes')
526 | return fig, ax
527 |
528 | def plot_anchor_archetype_clusters(
529 | self,
530 | anchor_df,
531 | cell_types=None,
532 | arche_lbls=None,
533 | lgd_ncol=2,
534 | do_3d=False
535 | ):
536 | """
537 | Joint display subset of anchor spots & archetypal spots (to visualize overlapping degree)
538 | """
539 | assert self.arche_df is not None, "Please compute archetypes & assign nearest-neighbors first!"
540 |
541 | cell_types = anchor_df.columns if cell_types is None else np.intersect1d(cell_types, anchor_df.columns)
542 | arche_lbls = self.arche_df.columns if arche_lbls is None else np.intersect1d(arche_lbls, self.arche_df.columns)
543 |         # NOTE: archetype centroids are computed below, once the UMAP embedding `U` is defined
544 |
545 | anchor_colors = cm.RdBu_r(np.linspace(0, 1, len(cell_types)))
546 | arche_colors = cm.RdBu_r(np.linspace(0, 1, len(arche_lbls)))
547 |
548 | if do_3d:
549 | if self.U_3d is None:
550 | self.U_3d = self._get_umap(ndim=3)
551 | U = self.U_3d
552 |
553 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 5), dpi=300, subplot_kw=dict(projection='3d'))
554 |
555 | # Display anchors
556 | ax1.scatter(
557 | U[:self.n_spots, 0],
558 | U[:self.n_spots, 1],
559 | U[:self.n_spots, 2],
560 | c='gray', marker='.', s=1, alpha=0.2
561 | )
562 | for c, label in zip(anchor_colors, cell_types):
563 | idxs = anchor_df[label]
564 | ax1.scatter(
565 | U[idxs, 0],
566 | U[idxs, 1],
567 | U[idxs, 2],
568 | color=c, marker='^', s=5,
569 | alpha=0.9, label=label
570 | )
571 |
572 | ax1.grid(False)
573 | ax1.set_xticklabels([])
574 | ax1.set_yticklabels([])
575 | ax1.set_zticklabels([])
576 | ax1.view_init(30, 45)
577 |
578 | lgd1 = ax1.legend(loc='lower center', bbox_to_anchor=(0.5, -1), ncol=lgd_ncol)
579 |
580 | # Display archetypal spots
581 | ax2.scatter(U[:self.n_spots, 0], U[:self.n_spots, 1], U[:self.n_spots, 2], c='gray', marker='.', s=1, alpha=0.2)
582 | for c, label in zip(arche_colors, arche_lbls):
583 | idxs = self.arche_df[label]
584 | ax2.scatter(U[idxs, 0], U[idxs, 1], U[idxs, 2], color=c, marker='o', s=3, alpha=0.9, label=label)
585 |
586 |             # Highlight selected archetypes at their UMAP centroids
587 |             for label, z in zip(arche_lbls, U[self.arche_df[arche_lbls]].mean(0)):
588 | idx = int(label.split('_')[-1])
589 | ax2.text(z[0], z[1], z[2], str(idx))
590 |
591 | ax2.grid(False)
592 | ax2.set_xticklabels([])
593 | ax2.set_yticklabels([])
594 | ax2.set_zticklabels([])
595 | ax2.view_init(30, 45)
596 |
597 | lgd2 = ax2.legend(loc='lower center', bbox_to_anchor=(0.5, -1), ncol=lgd_ncol)
598 |
599 | else:
600 | if self.U is None:
601 | self.U = self._get_umap(ndim=2)
602 | U = self.U
603 |
604 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 3), dpi=300)
605 |
606 | # Display anchors
607 | ax1.scatter(
608 | U[:self.n_spots, 0],
609 | U[:self.n_spots, 1],
610 | c='gray', marker='.', s=1, alpha=0.2
611 | )
612 |
613 | for c, label in zip(anchor_colors, cell_types):
614 | idxs = anchor_df[label]
615 | ax1.scatter(
616 | U[idxs, 0],
617 | U[idxs, 1],
618 | color=c, marker='^', s=5,
619 | alpha=0.9, label=label
620 | )
621 |
622 | lgd1 = ax1.legend(loc='lower center', bbox_to_anchor=(0.5, -1.75), ncol=lgd_ncol)
623 |
624 | # Display archetypal spots
625 | ax2.scatter(U[:self.n_spots, 0], U[:self.n_spots, 1], c='gray', marker='.', s=1, alpha=0.2)
626 | for c, label in zip(arche_colors, arche_lbls):
627 | idxs = self.arche_df[label]
628 | ax2.scatter(
629 | U[idxs, 0],
630 | U[idxs, 1],
631 | color=c, marker='o', s=3,
632 | alpha=0.9, label=label
633 | )
634 |
635 |             # Highlight selected archetypes at their UMAP centroids
636 |             for label, z in zip(arche_lbls, U[self.arche_df[arche_lbls]].mean(0)):
637 | idx = int(label.split('_')[-1])
638 | ax2.text(z[0], z[1], str(idx))
639 | lgd2 = ax2.legend(loc='lower center', bbox_to_anchor=(0.5, -1.85), ncol=lgd_ncol)
640 |
641 | if self.savefig and self.outdir is not None:
642 | self._save_fig(fig, (lgd1, lgd2), 'anchor_archetypal_spots')
643 | return fig, (ax1, ax2)
644 |
645 | def plot_mapping(self, map_df, figsize=(6, 5)):
646 | """
647 | Display anchor - archetype mapping (overlapping # spot ratio)
648 | """
649 | filename = 'cluster' if self.filename is None else self.filename
650 | g = sns.clustermap(
651 | map_df,
652 | method='ward',
653 | figsize=figsize,
654 | xticklabels=True,
655 | yticklabels=True,
656 | square=True,
657 | annot_kws={'size': 15}
658 | )
659 |
660 | text = g.ax_heatmap.set_title('# Overlapped NN Spots (k={})'.format(map_df.shape[1]),
661 | fontsize=20, x=0.6, y=1.3)
662 | # g.ax_row_dendrogram.set_visible(False)
663 | # g.ax_col_dendrogram.set_visible(False)
664 |
665 | if self.savefig and self.outdir is not None:
666 | if not os.path.exists(self.outdir):
667 | os.makedirs(self.outdir)
668 | g.figure.savefig(
669 | os.path.join(self.outdir, filename + '.eps'),
670 | bbox_extra_artists=(text,), bbox_inches='tight', format='eps'
671 | )
672 |
673 | return g
674 |
--------------------------------------------------------------------------------
/starfysh/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import multiprocessing
3 | import logging
4 |
5 | n_cores = multiprocessing.cpu_count()
6 | os.environ["NUMEXPR_MAX_THREADS"] = str(n_cores)
7 |
8 | # Configure global logging format
9 | logging.basicConfig(
10 | level=logging.INFO,
11 | format='[%(asctime)s] %(message)s',
12 | datefmt='%Y-%m-%d %H:%M:%S',
13 | force=True
14 | )
15 |
16 | LOGGER = logging.getLogger('Starfysh')
17 |
--------------------------------------------------------------------------------
/starfysh/dataloader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import torch
4 | import torch.nn.functional as F
5 | from torch.utils.data import Dataset
6 |
7 | from starfysh import LOGGER
8 |
9 |
10 | #---------------------------
11 | # Single Sample dataloader
12 | #---------------------------
13 |
14 | class VisiumDataset(Dataset):
15 | """
16 | Loading a single preprocessed ST AnnData, gene signature & Anchor spots for Starfysh training
17 | """
18 |
19 | def __init__(
20 | self,
21 | adata,
22 | args,
23 | ):
24 | spots = adata.obs_names
25 | genes = adata.var_names
26 |
27 | x = adata.X if isinstance(adata.X, np.ndarray) else adata.X.A
28 | self.expr_mat = pd.DataFrame(x, index=spots, columns=genes)
29 | self.gexp = args.sig_mean_norm
30 | self.anchor_idx = args.pure_idx
31 | self.library_n = args.win_loglib
32 |
33 | def __len__(self):
34 | return len(self.expr_mat)
35 |
36 | def __getitem__(self, idx):
37 | if torch.is_tensor(idx):
38 | idx = idx.tolist()
39 | sample = torch.Tensor(
40 | np.array(self.expr_mat.iloc[idx, :], dtype='float')
41 | )
42 |
43 | return (sample,
44 | torch.Tensor(self.gexp.iloc[idx, :]), # normalized signature exprs
45 | torch.Tensor(self.anchor_idx[idx, :]), # anchors
46 | torch.Tensor(self.library_n[idx,None]), # library size
47 | )
48 |
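# Usage sketch (hypothetical variable names): `adata` & `visium_args` are assumed
# to come from `utils.load_adata` / `utils.VisiumArguments`
# (see docs/source/installation.rst):
#
#   from torch.utils.data import DataLoader
#   loader = DataLoader(VisiumDataset(adata, visium_args), batch_size=32, shuffle=True)
#   exprs, sig_exprs, anchors, lib_sizes = next(iter(loader))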
49 |
50 | class VisiumPoEDataSet(VisiumDataset):
51 |
52 | def __init__(
53 | self,
54 | adata,
55 | args,
56 | ):
57 | super(VisiumPoEDataSet, self).__init__(adata, args)
58 | self.image = args.img.astype(np.float64)
59 | self.map_info = args.map_info
60 | self.r = args.params['patch_r']
61 | self.spot_img_stack = []
62 |
63 | self.density_std = args.img.std()
64 |
65 |         assert self.image is not None, \
66 |             "Empty paired H&E image, " \
67 |             "please use regular `Starfysh` without PoE integration " \
68 |             "if your dataset doesn't contain a histology image"
69 |
70 | # Retrieve image patch around each spot
71 | scalef = args.scalefactor['tissue_hires_scalef'] # High-res scale factor
72 | h, w = self.image.shape[:2]
73 | patch_dim = (self.r*2, self.r*2, 3) if self.image.ndim == 3 else (self.r*2, self.r*2)
74 |
75 | for i in range(len(self.expr_mat)):
76 | xc = int(np.round(self.map_info.iloc[i]['imagecol'] * scalef))
77 | yc = int(np.round(self.map_info.iloc[i]['imagerow'] * scalef))
78 |
79 | # boundary conditions: edge spots
80 | yl, yr = max(0, yc-self.r), min(self.image.shape[0], yc+self.r)
81 | xl, xr = max(0, xc-self.r), min(self.image.shape[1], xc+self.r)
82 |             top = max(0, self.r-yc)
83 |             bottom = h if h >= (yc+self.r) else h-(yc+self.r)  # `>=`: exact-border spots keep the full window
84 |             left = max(0, self.r-xc)
85 |             right = w if w >= (xc+self.r) else w-(xc+self.r)
86 |
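            # NOTE: for interior spots `bottom`/`right` equal h/w, so the patch
            # assignment below clamps to the full 2r x 2r window; for border spots
            # they become negative offsets, matching the truncated
            # image[yl:yr, xl:xr] crop exactly.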
87 |             patch = np.zeros(patch_dim)
88 |             patch[top:bottom, left:right] = self.image[yl:yr, xl:xr]
89 |             self.spot_img_stack.append(patch)
93 |
94 |
95 | def __len__(self):
96 | return len(self.expr_mat)
97 |
98 | def __getitem__(self, idx):
99 | if torch.is_tensor(idx):
100 | idx = idx.tolist()
101 | sample = torch.Tensor(
102 | np.array(self.expr_mat.iloc[idx, :], dtype='float')
103 | )
104 | spot_img_stack = self.spot_img_stack[idx]
105 | return (sample,
106 | torch.Tensor(self.anchor_idx[idx, :]),
107 | torch.Tensor(self.library_n[idx, None]),
108 | spot_img_stack,
109 | self.map_info.index[idx],
110 | torch.Tensor(self.gexp.iloc[idx, :]),
111 | )
112 |
113 | #---------------------------
114 | # Integrative Dataloader
115 | #---------------------------
116 |
117 |
118 | class IntegrativeDataset(VisiumDataset):
119 | """
120 | Loading multiple preprocessed ST sample AnnDatas, gene signature & Anchor spots for Starfysh training
121 | """
122 |
123 | def __init__(
124 | self,
125 | adata,
126 | args,
127 | ):
128 | super(IntegrativeDataset, self).__init__(adata, args)
129 | self.image = args.img
130 | self.map_info = args.map_info
131 | self.r = args.params['patch_r']
132 |
133 | def __len__(self):
134 | return len(self.expr_mat)
135 |
136 | def __getitem__(self, idx):
137 | if torch.is_tensor(idx):
138 | idx = idx.tolist()
139 | sample = torch.Tensor(
140 | np.array(self.expr_mat.iloc[idx, :], dtype='float')
141 | )
142 | return (sample,
143 | torch.Tensor(self.gexp.iloc[idx, :]),
144 | torch.Tensor(self.anchor_idx[idx, :]),
145 | torch.Tensor(self.library_n[idx, None])
146 | )
147 |
148 |
149 | class IntegrativePoEDataset(VisiumDataset):
150 |
151 | def __init__(
152 | self,
153 | adata,
154 | args,
155 | ):
156 | super(IntegrativePoEDataset, self).__init__(adata, args)
157 | self.image = args.img
158 | self.map_info = args.map_info
159 | self.r = args.params['patch_r']
160 |
161 |
162 |         assert self.image is not None, \
163 |             "Empty paired H&E image, " \
164 |             "please use regular `Starfysh` without PoE integration " \
165 |             "if your dataset doesn't contain a histology image"
166 | spot_img_all = []
167 |
168 | # Retrieve image patch around each spot
169 | for sample_id in args.img.keys():
170 |
171 | scalef_i = args.scalefactor[sample_id]['tissue_hires_scalef'] # High-res scale factor
172 | h, w = self.image[sample_id].shape[:2]
173 | patch_dim = (self.r*2, self.r*2, 3) if self.image[sample_id].ndim == 3 else (self.r*2, self.r*2)
174 |
175 | list_ = adata.obs['sample'] == sample_id
176 | for i in range(len(self.expr_mat.loc[list_,:])):
177 | xc = int(np.round(self.map_info.loc[list_,:].iloc[i]['imagecol'] * scalef_i))
178 | yc = int(np.round(self.map_info.loc[list_,:].iloc[i]['imagerow'] * scalef_i))
179 |
180 | # boundary conditions: edge spots
181 | yl, yr = max(0, yc-self.r), min(self.image[sample_id].shape[0], yc+self.r)
182 | xl, xr = max(0, xc-self.r), min(self.image[sample_id].shape[1], xc+self.r)
183 |                 top = max(0, self.r-yc)
184 |                 bottom = h if h >= (yc+self.r) else h-(yc+self.r)  # `>=`: exact-border spots keep the full window
185 |                 left = max(0, self.r-xc)
186 |                 right = w if w >= (xc+self.r) else w-(xc+self.r)
187 |
188 | try:
189 | patch = np.zeros(patch_dim)
190 | patch[top:bottom, left:right] = self.image[sample_id][yl:yr, xl:xr]
191 | spot_img_all.append(patch)
192 | except ValueError:
193 | LOGGER.warning('Skipping the patch loading of an edge spot...')
194 |
195 | self.spot_img_stack = list(spot_img_all)
197 |
198 | def __len__(self):
199 | return len(self.expr_mat)
200 |
201 | def __getitem__(self, idx):
202 | if torch.is_tensor(idx):
203 | idx = idx.tolist()
204 | sample = torch.Tensor(
205 | np.array(self.expr_mat.iloc[idx, :], dtype='float')
206 | )
207 | return (sample,
208 | torch.Tensor(self.anchor_idx[idx, :]),
209 | torch.Tensor(self.library_n[idx, None]),
210 | self.spot_img_stack[idx],
211 | self.map_info.index[idx],
212 | torch.Tensor(self.gexp.iloc[idx, :]),
213 | )
214 |
215 |
--------------------------------------------------------------------------------
/starfysh/gener_img.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | import scipy.stats as stats
5 | import pandas as pd
6 | import pickle
7 | import anndata
8 | import json
9 | import scanpy as sc
10 | import scipy
11 | import PIL
12 | import seaborn as sns
13 | import torch
14 | import matplotlib.pyplot as plt
15 | from skimage import io
16 | from torch.utils.data import DataLoader,ConcatDataset
18 | import torch.nn as nn
19 | from torch.nn import functional as F
20 | from typing import List, Callable, Union, Any, TypeVar, Tuple
21 | from tqdm import tqdm
22 | from torchvision.utils import save_image
23 |
24 | sys.path.append('HE-Net')
25 | torch.manual_seed(0)
26 | from henet import utils
27 | from henet import datasets
28 | from henet import model
29 |
30 |
31 | class dataset(torch.utils.data.Dataset):
32 |
33 | def __init__(
34 | self,
35 | spot,
36 | exp_spot,
37 | barcode_spot,
38 | #img_size,
39 | #histo_img,
40 | transform=None
41 | ):
42 |
43 | super(dataset, self).__init__()
44 | self.spot = spot
45 | self.exp_spot = exp_spot
46 | self.barcode_spot = barcode_spot
47 | self.transform = transform
48 | #self.img_size = img_size
49 | #self.histo_img = histo_img
50 |
51 | def __getitem__(self, idx):
52 | #print(idx)
53 | spot_x = self.spot[idx,0]
54 | spot_y = self.spot[idx,1]
55 | exp_spot = self.exp_spot[idx,:]
56 | barcode_spot = self.barcode_spot[idx]
57 |
58 | #x_l = int(spot_x-self.img_size/2)
59 | #x_r = int(spot_x+self.img_size/2)
60 | #y_l = int(spot_y-self.img_size/2)
61 | #y_r = int(spot_y+self.img_size/2)
62 |
63 | #img_spot = img[y_l:y_r,x_l:x_r,:]
64 | #img_spot = self.histo_img[y_l:y_r,x_l:x_r]
65 | #img_spot = img_spot-img_spot.min()
66 | #img_spot = img_spot/img_spot.max()
67 |
68 | #if self.transform is not None:
69 | # img_spot = self.transform(img_spot)
70 | return exp_spot, barcode_spot#, img_spot
71 | #return exp_spot
72 |
73 | def __len__(self):
74 |
75 | return len(self.spot)
76 |
77 | def prep_dataset(adata):
78 |
79 | #library_id = list(adata.uns.get("spatial",{}).keys())[0]
80 | #histo_img = adata.uns['spatial'][str(library_id)]['images']['hires']
81 | #spot_diameter_fullres = adata.uns['spatial'][str(library_id)]['scalefactors']['spot_diameter_fullres']
82 | #tissue_hires_scalef = adata.uns['spatial'][str(library_id)]['scalefactors']['tissue_hires_scalef']
83 | #circle_radius = spot_diameter_fullres *tissue_hires_scalef * 0.5
84 | image_spot = adata.obsm['spatial'] #*tissue_hires_scalef
85 |
86 | #img_size=32
87 | N_spot = image_spot.shape[0]
88 | train_index = np.random.choice(image_spot.shape[0], size=int(N_spot*0.6), replace=False)
89 | #val_index=np.random.choice(image_spot.shape[0], size=int(N_spot*0.1), replace=False)
90 | test_index=np.random.choice(image_spot.shape[0], size=int(N_spot*0.4), replace=True)
91 |
92 | #train_transforms = transforms.Compose([transforms.ToTensor()])
93 | #val_transforms = transforms.Compose([transforms.ToTensor()])
94 | #test_transforms = transforms.Compose([transforms.ToTensor()])
95 |
96 | train_spot = image_spot[train_index]
97 | test_spot = image_spot[test_index]
98 | #val_spot = image_spot[val_index]
99 |
100 | adata_df = adata.to_df()
101 |
102 | train_exp_spot = np.array(adata_df.iloc[train_index])
103 | test_exp_spot = np.array(adata_df.iloc[test_index])
104 | #val_exp_spot = np.array(adata_df.iloc[val_index])
105 |
106 |
107 | train_barcode_spot = adata_df.index[train_index]
108 | test_barcode_spot = adata_df.index[test_index]
109 | #val_barcode_spot = adata_df.index[val_index]
110 |
111 | train_set = dataset(train_spot, train_exp_spot, train_barcode_spot, None)
112 | test_set = dataset(test_spot, test_exp_spot, test_barcode_spot,None)
113 |
114 | all_set = dataset( image_spot,np.array(adata_df), adata_df.index, None)
115 |
116 | return train_set,test_set,all_set
117 |
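# Note: `prep_dataset` above draws train and test indices independently (test
# even with replacement), so the two sets can overlap. A sketch of a disjoint
# 60/40 split, assuming that was the intent:
#
#   train_index = np.random.choice(N_spot, int(N_spot * 0.6), replace=False)
#   test_index = np.setdiff1d(np.arange(N_spot), train_index)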
118 |
119 | def generate_img(dat_path, train_flag=True):
120 |     """
121 |     Generate a pseudo H&E image from an expression count matrix.
122 | 
123 |     dat_path : path to the input .csv; outputs are saved under its parent folder
124 |     """
125 |
126 | dat_folder = 'data'
127 | dat_name = 'CID44971'
128 | n_genes = 6000
129 | adata1,variable_gene = utils.load_adata(dat_folder=dat_folder,
130 | dat_name=dat_name,
131 | use_other_gene=False,
132 | other_gene_list=None,
133 | n_genes=n_genes)
134 |
135 |
136 | adata5 = sc.read_csv(dat_path)
137 | adata5.obs_names = adata5.to_df().index
138 | adata5.var_names = adata5.to_df().columns
139 | adata5.var_names_make_unique()
140 |
141 | adata5.var["mt"] = adata5.var_names.str.startswith("MT-")
142 | sc.pp.calculate_qc_metrics(adata5, qc_vars=["mt"], inplace=True)
143 |     print('The dataset has', adata5.n_obs, 'spots and', adata5.n_vars, 'genes')
144 | adata5.var['rp'] = adata5.var_names.str.startswith('RPS') + adata5.var_names.str.startswith('RPL')
145 | sc.pp.calculate_qc_metrics(adata5, qc_vars=['rp'], percent_top=None, log1p=False, inplace=True)
146 |
147 |     # Remove mitochondrial & ribosomal genes
148 |     # (`~` is the boolean complement; unary `-` on a boolean mask is invalid)
149 |     adata5 = adata5[:, ~adata5.var['mt']]
150 |     adata5 = adata5[:, ~adata5.var['rp']]
151 |
152 | adata5_new = pd.DataFrame(np.zeros([adata5.to_df().shape[0],6000]),index =adata5.obs_names, columns=variable_gene )
153 |
154 | inter_gene = np.intersect1d(variable_gene,adata5.var_names)
155 |
156 | for i in inter_gene:
157 | adata5_new.loc[:,i] = np.array(adata5[:,i].X)
158 |
159 | adata5_new_anndata = anndata.AnnData(np.log(
160 | #np.clip(adata5_new,0,2000)
161 | adata5_new
162 | +1))
163 |
164 | other_list_temp = variable_gene.intersection(adata5.var.index)
165 | adata5 = adata5[:,other_list_temp]
166 | sc.pp.normalize_total(adata5,inplace=True)
167 | sc.pp.log1p(adata5)
168 |
169 | map_info_list = []
170 | for i in range(2500):
171 | x,y = np.where(np.arange(2500).reshape(50, 50)==i)
172 | map_info_list.append([x[0]+5,y[0]+5])
173 | map_info = pd.DataFrame(map_info_list,columns=['array_row','array_col'],index = adata5.obs_names)
174 |
175 | tissue_hires_scalef = 0.20729685
176 | adata5_new_anndata.obsm['spatial']=np.array(map_info)*32
177 | adata5.obsm['spatial']=np.array(map_info)*32
178 |
179 |
180 |
181 |
182 | train_set1,test_set1,all_set1 = datasets.prep_dataset(adata1, 'CID44971')
183 |
184 | train_set = ConcatDataset([train_set1,
185 |
186 | ])
187 | test_set = ConcatDataset([test_set1,
188 |
189 | ])
190 |
191 |
192 | train_dataloader = DataLoader(train_set, shuffle=True, num_workers=1, batch_size=8)
193 | test_dataloader= DataLoader(test_set, shuffle=True, num_workers=1, batch_size=8)
194 | all_dataloader1 = DataLoader(all_set1, shuffle=True, num_workers=1, batch_size=8)
195 |
196 |
197 | train_set5,test_set5,all_set5 = prep_dataset(adata5_new_anndata)
198 | all_dataloader5 = DataLoader(all_set5, batch_size=8)
199 |
200 | # Model Hyperparameters
201 | cuda = False
202 | DEVICE = torch.device("cuda" if cuda else "cpu")
203 | batch_size = 20
204 | x_dim =1024
205 | hidden_dim = 400
206 | latent_dim = 20
207 | lr = 1e-3
208 | epochs = 1000
209 |
210 | encoder = model.Encoder(input_dim=x_dim, hidden_dim=hidden_dim, latent_dim=latent_dim,DEVICE=DEVICE)
211 | decoder = model.Decoder(latent_dim=latent_dim, hidden_dim = hidden_dim, output_dim = x_dim)
212 |
213 | vae = model.Model(Encoder=encoder, Decoder=decoder).to(DEVICE)
214 | check_point_filename = 'HE-Net/trained_model_on_lab1A.pt'
215 | vae.load_state_dict(torch.load(check_point_filename,map_location=torch.device('cpu') ))
216 |
217 |
218 | train_net = model.ResNet.resnet18()
219 |
220 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
221 | #device = torch.device("cpu")
222 | train_net.to(device)
223 | lr = 0.001
224 | #optimizer = torch.optim.SGD(oct_net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
225 | optimizer = torch.optim.Adam(train_net.parameters(), lr=lr, amsgrad=True)
226 |
227 |
228 | vae.eval()
229 | if train_flag:
230 | for epoch in range(10):
231 | losses = []
232 | for batch_idx, (exp_spot, barcode_spot, img_spot_i) in enumerate(train_dataloader):
233 | #print(batch_idx)
234 | img_spot_i = img_spot_i.transpose(3,2).transpose(2,1)
235 | img_spot_i = img_spot_i.reshape(-1, 1024)
236 |
237 | img_spot_i = img_spot_i.to(DEVICE)
238 |
239 | x_hat, mean, log_var = vae(img_spot_i)
240 |
241 | mean = mean.view(exp_spot.shape[0], 3, 20)
242 | mean = mean.view(exp_spot.shape[0], 60)
243 | #mean = mean.cpu().detach().numpy()
244 | x_hat= x_hat.view(exp_spot.shape[0], 3, 32, 32)
245 | img_spot_i= img_spot_i.view(exp_spot.shape[0], 3, 32, 32)
246 |
247 | #print(exp_spot.shape)
248 | output1 = train_net(exp_spot[:,None,:].to(DEVICE))
249 | loss_model = nn.L1Loss()
250 | #print(output1.flatten())
251 | loss = F.mse_loss(output1.reshape(-1), mean.reshape(-1))#+0.3*loss_model(output1, y1)
252 | if train_flag:
253 | loss.backward()
254 | optimizer.step()
255 | optimizer.zero_grad()
256 | losses.append(loss.cpu().detach().numpy())
257 | #print(np.mean(losses))
258 | print("\tEpoch", epoch + 1, "complete!", "\tAverage Loss: ", np.mean(losses))
259 |
260 |
261 | adata_test = adata5_new_anndata#adata5_new_anndata#adata5_new_anndata
262 | all_dataloader_test = all_dataloader5#all_dataloader5
263 |
264 |
265 | train_net.eval()
266 | #library_id = list(adata_test.uns.get("spatial",{}).keys())[0]
267 | img_size=32
268 | #tissue_hires_scalef = adata_test.uns['spatial'][str(library_id)]['scalefactors']['tissue_hires_scalef']
269 | #recon = np.ones(histo_img.shape)
270 | recon = np.ones([2000,2000,3])
271 | for batch_idx, (exp_spot, barcode_spot) in enumerate(all_dataloader_test):
272 |
273 | #img_spot_i = img_spot_i.transpose(3,2).transpose(2,1)
274 | #img_spot_i = img_spot_i.reshape(-1, 1024)
275 |
276 | #img_spot_i = img_spot_i.to(DEVICE)
277 |
278 | #x_hat, mean, log_var = vae(img_spot_i)
279 |
280 | #mean = mean.view(exp_spot.shape[0], 3, 20)
281 | #mean = mean.view(exp_spot.shape[0], 60)
282 | #mean = mean.cpu().detach().numpy()
283 | #x_hat= x_hat.view(exp_spot.shape[0], 3, 32, 32)
284 | #img_spot_i= img_spot_i.view(exp_spot.shape[0], 3, 32, 32)
285 |
286 | #print(exp_spot.shape)
287 | output1 = train_net(exp_spot[:,None,:].to(DEVICE))
288 | #plt.figure()
289 | #plt.imshow(mean.cpu().detach())
290 | #plt.figure()
291 | #plt.imshow(output1.cpu().detach())
292 |
293 |
294 | output1 = output1.view(exp_spot.shape[0],3,20)
295 | #print(mean.shape)
296 | #print(output1.shape)
297 |
298 |
299 | output1_1 = output1[:,0,:]
300 | output1_2 = output1[:,1,:]
301 | output1_3 = output1[:,2,:]
302 |
303 | x_predicted_1 = vae.Decoder(output1_1)
304 | x_predicted_1 = x_predicted_1.view(exp_spot.shape[0],32,32).cpu().detach().numpy()
305 | x_predicted_2 = vae.Decoder(output1_2)
306 | x_predicted_2 = x_predicted_2.view(exp_spot.shape[0],32,32).cpu().detach().numpy()
307 | x_predicted_3 = vae.Decoder(output1_3)
308 | x_predicted_3 = x_predicted_3.view(exp_spot.shape[0],32,32).cpu().detach().numpy()
309 |
310 | for j,i in zip(range(len(barcode_spot)),barcode_spot):
311 | spot_x = ((adata_test[i].obsm['spatial'] )[0])[1]
312 | spot_y = ((adata_test[i].obsm['spatial'] )[0])[0]
313 |
314 | x_l = int(spot_x-img_size/2)
315 | x_r = int(spot_x+img_size/2)
316 | y_l = int(spot_y-img_size/2)
317 | y_r = int(spot_y+img_size/2)
318 |
319 | recon[y_l:y_r,x_l:x_r,0]=x_predicted_1[j]
320 | recon[y_l:y_r,x_l:x_r,1]=x_predicted_2[j]
321 | recon[y_l:y_r,x_l:x_r,2]=x_predicted_3[j]
322 |
323 | save_path = dat_path.split('/')[0]+'/'+dat_path.split('/')[1]
324 | if not os.path.exists(save_path+'/spatial'):
325 | os.makedirs(save_path+'/spatial')
326 | io.imsave(save_path+'/spatial/tissue_hires_image.png',(recon*255).astype('uint8'))
327 | map_info[['img_col','img_row']]=adata5_new_anndata.obsm['spatial']
328 | map_info.to_csv(save_path+'/spatial/tissue_positions_list.csv',index=True)
329 | adata5_new_anndata.write(save_path+'/'+dat_path.split('/')[1]+'.h5ad')
330 |
331 |
332 | return recon
333 |
334 |
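335 | if __name__ == '__main__':
336 |     # Sanity check: the per-spot `np.where` loop above that builds
337 |     # `map_info_list` is equivalent to this vectorized form on the
338 |     # 50x50 simulation grid.
339 |     rows, cols = np.divmod(np.arange(2500), 50)
340 |     grid = np.stack([rows + 5, cols + 5], axis=1)
341 |     assert grid[0].tolist() == [5, 5] and grid[-1].tolist() == [54, 54]
342 | 
343 |     # Typical entry point (hypothetical path; requires the local HE-Net repo
344 |     # and its pretrained checkpoint 'HE-Net/trained_model_on_lab1A.pt'):
345 |     # recon = generate_img('data/simu_sample/counts.csv', train_flag=True)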
--------------------------------------------------------------------------------
/starfysh/plot_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import scanpy as sc
6 | import matplotlib.pyplot as plt
7 | import seaborn as sns
8 |
9 | from sklearn.metrics import r2_score
10 | from scipy.stats import pearsonr, gaussian_kde
11 |
12 |
13 | # Module import
14 | from .post_analysis import get_z_umap
15 | from .utils import extract_feature
16 |
17 |
18 | def plot_integrated_spatial_feature(data,
19 | sample_id,
20 | vmin=None,
21 | vmax=None,
22 | spot_size = 1,
23 | figsize=(2.5, 2),
24 | fig_dpi = 300,
25 |                                     cmap=plt.cm.tab20.colors,  # indexable sequence: one color per hub
26 | legend_on = True,
27 | legend_loc = (2,0.5)
28 | ):
29 |
30 | adata_ann = data[data.obs['sample']==sample_id]
31 | 
32 |     color_idx_list = np.array(adata_ann.obs['pheno_louvain']).astype(int)
33 | all_loc = np.array(adata_ann.obsm['spatial'])
34 |     fig, axs = plt.subplots(1, 1, figsize=figsize, dpi=fig_dpi)
35 | for i in np.unique(color_idx_list):
36 | g=axs.scatter(all_loc[color_idx_list==i,0],-all_loc[color_idx_list==i, 1],s=spot_size,marker='h',c=cmap[i])
37 |
38 | if legend_on:
39 | plt.legend(list(np.unique(color_idx_list)),title='Hub region',loc='right',bbox_to_anchor=legend_loc)
40 | #fig.colorbar(g,label=label)
41 | #plt.title(sample_id)
42 | axs.set_xticks([])
43 | axs.set_yticks([])
44 | plt.axis('off')
45 |
46 | def plot_spatial_density(
47 | data,
48 | vmin=None,
49 | vmax=None,
50 | spot_size = 1,
51 | figsize=(2.5, 2),
52 | fig_dpi = 300,
53 | cmap = 'Blues',
54 | colorbar_on = True,
55 | ):
56 | fig, axes = plt.subplots(1, 1, figsize=figsize, dpi=fig_dpi)
57 | g = axes.scatter(data.obsm['spatial'][:,0],
58 | -data.obsm['spatial'][:,1],
59 | c=data.obsm['ql_m'],
60 | vmin=vmin,
61 | vmax=vmax,
62 | cmap=cmap,
63 | s=spot_size,
64 | )
65 | if colorbar_on:
66 | fig.colorbar(g, label='Estimated density')
67 | axes.axis('off')
68 |
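# Example call (hypothetical AnnData): expects per-spot density estimates in
# `data.obsm['ql_m']` and coordinates in `data.obsm['spatial']`, e.g.
#
#   plot_spatial_density(adata, vmin=0, vmax=1, cmap='Blues')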
69 | def plot_spatial_cell_type_frac( data = None,
70 | cell_type = None,
71 | vmin=None,# adjust
72 | vmax=None,# adjust
73 | spot_size=2,# adjust
74 | figsize = (3,2.5),
75 | fig_dpi = 300, # >300 for high quality img
76 | cmap = 'magma',
77 | colorbar_on = True,
78 | label=None,
79 | title=None,
80 | ):
81 |
82 |
83 | fig,axs= plt.subplots(1,1,figsize=figsize,dpi=fig_dpi)
84 | idx = data.uns['cell_types'].index(cell_type)
85 | g=axs.scatter(data.obsm['spatial'][:,0],
86 | -data.obsm['spatial'][:,1],
87 | c=data.obsm['qc_m'][:,idx],
88 | cmap=cmap,
89 | vmin=vmin,
90 | vmax=vmax,
91 | s=spot_size
92 | )
93 | if title is not None:
94 | plt.title(title)
95 | else:
96 | plt.title(cell_type)
97 | if colorbar_on:
98 | fig.colorbar(g,label=label)
99 | plt.axis('off')
100 |
101 |
102 |
103 | def plot_z_umap_cell_type_frac(data = None,
104 | cell_type = None,
105 | vmin=None,# adjust
106 | vmax=None,# adjust
107 | spot_size=2,# adjust
108 | figsize = (3,2.5),
109 | fig_dpi = 300, # >300 for high quality img
110 | cmap = 'magma',
111 | colorbar_on = True,
112 | label=None,
113 | title=None,
114 | ):
115 |
116 |
117 | fig,axs= plt.subplots(1,1,figsize=figsize,dpi=fig_dpi)
118 | idx = data.uns['cell_types'].index(cell_type)
119 | g=axs.scatter(data.obsm['z_umap'][:,0],
120 | data.obsm['z_umap'][:,1],
121 | c=data.obsm['qc_m'][:,idx],
122 | cmap=cmap,
123 | vmin=vmin,
124 | vmax=vmax,
125 | s=spot_size
126 | )
127 | if title is not None:
128 | plt.title(title)
129 | else:
130 | plt.title(cell_type)
131 | if colorbar_on:
132 | fig.colorbar(g,label=label)
133 | plt.axis('off')
134 |
135 | def plot_spatial_feature(data = None,
136 | feature = None,
137 | vmin=None,# adjust
138 | vmax=None,# adjust
139 | spot_size=2,# adjust
140 | figsize = (3,2.5),
141 | fig_dpi = 300, # >300 for high quality img
142 | cmap = 'magma',
143 | colorbar_on = True,
144 | label=None
145 | ):
146 |
147 |
148 | fig,axs= plt.subplots(1,1,figsize=figsize,dpi=fig_dpi)
149 |
150 | g=axs.scatter(data.obsm['spatial'][:,0],
151 | -data.obsm['spatial'][:,1],
152 | c=feature,
153 | cmap=cmap,
154 | vmin=vmin,
155 | vmax=vmax,
156 | s=spot_size
157 | )
158 |
159 | if colorbar_on:
160 | fig.colorbar(g,label=label)
161 | plt.axis('off')
162 |
163 | def plot_spatial_gene(data,
164 | data_normed,
165 | gene_name = None,
166 | log_gene = False,
167 | vmin=None,
168 | vmax=None,
169 | spot_size=5,
170 | figsize = (2,2),
171 | fig_dpi = 300,
172 | cmap = 'magma',
173 | colorbar_on = True,
174 | ):
175 |
176 | fig,axs= plt.subplots(1,1,figsize=figsize,dpi=fig_dpi)
177 | if log_gene:
178 | g=axs.scatter(data_normed.obsm['spatial'][:,0],
179 | -data_normed.obsm['spatial'][:,1],
180 | c=data_normed.to_df().loc[:,gene_name],
181 | cmap=cmap,
182 | vmin=vmin,
183 | vmax=vmax,
184 | s=spot_size
185 | )
186 | else:
187 | g=axs.scatter(data.obsm['spatial'][:,0],
188 | -data.obsm['spatial'][:,1],
189 | c=data.to_df().loc[:,gene_name],
190 | cmap=cmap,
191 | vmin=vmin,
192 | vmax=vmax,
193 | s=spot_size
194 | )
195 | if colorbar_on:
196 | fig.colorbar(g,label=gene_name)
197 | plt.axis('off')
198 |
199 |
200 | def plot_anchor_spots(umap_plot,
201 | pure_spots,
202 | sig_mean,
203 | bbox_x=2,
204 | ):
205 | fig,ax = plt.subplots(1,1,dpi=300,figsize=(3,3))
206 | ax.scatter(umap_plot['umap1'],
207 | umap_plot['umap2'],
208 | s=2,
209 | alpha=1,
210 | color='lightgray')
211 | for i in range(len(pure_spots)):
212 | ax.scatter(umap_plot['umap1'][pure_spots[i]],
213 | umap_plot['umap2'][pure_spots[i]],
214 | s=8)
215 | plt.legend(['all']+[i for i in sig_mean.columns],
216 | loc='right',
217 | bbox_to_anchor=(bbox_x,0.5),)
218 | ax.grid(False)
219 | ax.axis('off')
220 |
221 |
222 | def plot_evs(evs, kmin):
223 | fig, ax = plt.subplots(1, 1, dpi=300, figsize=(6, 3))
224 | plt.plot(np.arange(len(evs))+kmin, evs, '.-')
225 | plt.xlabel('ks')
226 | plt.ylabel('Explained Variance')
227 | plt.show()
228 |
229 |
230 | def pl_spatial_inf_feature(
231 | adata,
232 | feature,
233 | factor=None,
234 | vmin=0,
235 | vmax=None,
236 | spot_size=100,
237 | alpha=0,
238 | cmap='Spectral_r'
239 | ):
240 | """Spatial visualization of Starfysh inference features"""
241 | if isinstance(factor, str):
242 | assert factor in adata.uns['cell_types'], \
243 |             "Invalid Starfysh inference factor (cell type): {}".format(factor)
244 | elif isinstance(factor, list):
245 | for f in factor:
246 | assert f in adata.uns['cell_types'], \
247 |                 "Invalid Starfysh inference factor (cell type): {}".format(f)
248 | else:
249 | factor = adata.uns['cell_types'] # if None, display for all cell types
250 |
251 | adata_pl = extract_feature(adata, feature)
252 |
253 | if feature == 'qc_m':
254 | if isinstance(factor, list):
255 | title = [f + ' (Inferred proportion - Spatial)' for f in factor]
256 | else:
257 | title = factor + ' (Inferred proportion - Spatial)'
258 | sc.pl.spatial(
259 | adata_pl,
260 | color=factor, spot_size=spot_size, color_map=cmap,
261 | ncols=3, vmin=vmin, vmax=vmax, alpha_img=alpha,
262 | title=title, legend_fontsize=8
263 | )
264 | elif feature == 'ql_m':
265 | title = 'Estimated tissue density'
266 | sc.pl.spatial(
267 | adata_pl,
268 | color='density', spot_size=spot_size, color_map=cmap,
269 | vmin=vmin, vmax=vmax, alpha_img=alpha,
270 | title=title, legend_fontsize=8
271 | )
272 | elif feature == 'qz_m':
273 | # Visualize deconvolution on UMAP of inferred Z-space
274 | qz_u = get_z_umap(adata_pl.obs.values)
275 | qc_df = extract_feature(adata, 'qc_m').obs
276 | if isinstance(factor, list):
277 | for cell_type in factor:
278 | title = cell_type + ' (Inferred proportion - UMAP of Z)'
279 | pl_umap_feature(qz_u, qc_df[cell_type].values, cmap, title,
280 | vmin=vmin, vmax=vmax)
281 | else:
282 | title = factor + ' (Inferred proportion - UMAP of Z)'
283 | fig, ax = pl_umap_feature(qz_u, qc_df[factor].values, cmap, title,
284 | vmin=vmin, vmax=vmax)
285 | return fig, ax
286 | else:
287 | raise ValueError('Invalid Starfysh inference results `{}`, please choose from `qc_m`, `qz_m` & `ql_m`'.format(feature))
288 |
289 | 
290 |
291 |
292 | def pl_umap_feature(qz_u, qc, cmap, title, spot_size=3, vmin=0, vmax=None):
293 | """Single Z-UMAP visualization of Starfysh deconvolutions"""
294 | fig, ax = plt.subplots(1, 1, figsize=(4, 3), dpi=200)
295 | g = ax.scatter(
296 | qz_u[:, 0], qz_u[:, 1],
297 | cmap=cmap, c=qc, s=spot_size, vmin=vmin, vmax=vmax,
298 | )
299 | ax.set_xticks([])
300 | ax.set_yticks([])
301 | ax.set_title(title)
302 | ax.axis('off')
303 | fig.colorbar(g, label='Inferred proportions')
304 |
305 | return fig, ax
306 |
307 |
308 | def pl_spatial_inf_gene(
309 | adata=None,
310 | factor=None,
311 | feature=None,
312 | vmin=None,
313 | vmax=None,
314 | spot_size=100,
315 | alpha=0,
316 | figsize = (3,2.5),
317 | fig_dpi = 500,
318 | cmap='Spectral_r',
319 | colorbar_on = True,
320 | title = None ,
321 |
322 | ):
323 |
324 | if isinstance(feature, str):
325 | assert feature in set(adata.var_names), \
326 | "Gene {0} isn't HVG, please choose from `adata.var_names`".format(feature)
327 | title_new = feature + ' (Predicted expression)'
328 | else:
329 | for f in feature:
330 | assert f in set(adata.var_names), \
331 | "Gene {0} isn't HVG, please choose from `adata.var_names`".format(f)
332 | title_new = [f + ' (Predicted expression)' for f in feature]
333 |
334 | if title is not None:
335 | title_new = title
336 | # Assign dummy `var_names` to avoid gene name in both obs & var
337 | adata_expr = extract_feature(adata, factor+'_inferred_exprs')
338 | adata_expr.var_names = np.arange(adata_expr.shape[1])
339 |
340 | sc.settings.set_figure_params(figsize=figsize,dpi=fig_dpi,facecolor="white")
341 |
342 | sc.pl.spatial(
343 | adata_expr,
344 | color=feature, spot_size=spot_size, color_map=cmap,
345 | ncols=3, vmin=vmin, vmax=vmax, alpha_img=alpha,
346 | title=title_new,
347 | legend_fontsize=8,
348 |
349 | )
350 |
351 | 
352 |
353 |
354 | # --------------------------
355 | # util funcs for benchmark
356 | # --------------------------
357 |
358 | def _dist2gt(A, A_gt):
359 | """
360 | Calculate the distance to ground-truth correlation matrix (proportions)
361 | """
362 | return np.linalg.norm(A - A_gt, ord='fro')
363 |
364 |
365 | def _calc_rmse(y_true, y_pred):
366 | """Calculate per-spot RMSE between ground-truth & predicted proportions"""
367 | assert y_true.shape == y_pred.shape, "proportion matrices need to be the same shape to calculate RMSE"
368 | n_cts = y_true.shape[1]
369 | rmse = np.sqrt(((y_true.values-y_pred.values)**2).sum(1) / n_cts)
370 | return rmse
371 |
372 |
373 | def bootstrap_dists(corr_df, corr_gt_df, n_iter=1000, size=10):
374 | """
375 | Calculate the avg. distance to ground-truth (sub)-matrix based on random subsampling
376 | """
377 |     if size is None:
378 |         size = corr_df.shape[0]
379 |     n = min(size, corr_df.shape[0])
380 | labels = corr_df.columns
381 | dists = np.zeros(n_iter)
382 |
383 | for i in range(n_iter):
384 | lbl = np.random.choice(corr_df.columns, n)
385 | A = corr_df.loc[lbl, lbl].values
386 | A_gt = corr_gt_df.loc[lbl, lbl].values
387 | dists[i] = _dist2gt(A, A_gt)
388 |
389 | return dists
390 |
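# A quick synthetic check for `bootstrap_dists` (all names below are
# hypothetical): distances to a mildly perturbed copy of an identity
# correlation matrix should be small and non-negative.
#
#   rng = np.random.default_rng(0)
#   lbl = [f'ct{i}' for i in range(12)]
#   gt = pd.DataFrame(np.eye(12), index=lbl, columns=lbl)
#   noise = rng.normal(scale=0.05, size=(12, 12))
#   est = gt + (noise + noise.T) / 2
#   dists = bootstrap_dists(est, gt, n_iter=100, size=6)
#   assert dists.shape == (100,) and (dists >= 0).all()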
391 |
392 | def disp_rmse(y_true, y_preds, labels, title=None, return_rmse=False):
393 | """
394 | Boxplot of per-spot RMSEs for each prediction
395 | """
396 | n_spots, n_cts = y_true.shape
397 |     rmses = np.array([
398 | np.sqrt(((y_true.values-y_pred.values)**2).sum(1) / n_cts)
399 | for y_pred in y_preds
400 | ])
401 |
402 | lbls = np.repeat(labels, n_spots)
403 | df = pd.DataFrame({
404 | 'RMSE': rmses.flatten(),
405 | 'Method': lbls
406 | })
407 | plt.figure(figsize=(10, 6))
408 | g = sns.boxplot(x='Method', y='RMSE', data=df)
409 | g.set_xticklabels(labels, rotation=60)
410 | plt.suptitle(title)
411 | plt.show()
412 |
413 | return rmses if return_rmse else None
414 |
415 |
416 | def disp_corr(
417 | y_true, y_pred,
418 | outdir=None,
419 | figsize=(3.2, 3.2),
420 | fontsize=5,
421 | title=None,
422 | filename=None,
423 | savefig=False,
424 | format='png',
425 | return_corr=False
426 | ):
427 | """
428 | Calculate & plot correlation of cell proportion (or absolute cell abundance)
429 | between ground-truth & predictions (both [S x F])
430 | """
431 |
432 | assert y_true.shape[0] == y_pred.shape[0], 'Inconsistent sample sizes between ground-truth & prediction'
433 | if savefig:
434 | assert format == 'png' or format == 'eps' or format == 'svg', "Invalid saving format"
435 |
436 | v1 = y_true.values
437 | v2 = y_pred.values
438 |
439 | n_factor1, n_factor2 = v1.shape[1], v2.shape[1]
440 | corr = np.zeros((n_factor1, n_factor2))
441 | gt_corr = y_true.corr().values
442 |
443 | for i in range(n_factor1):
444 | for j in range(n_factor2):
445 | corr[i, j], _ = np.round(pearsonr(v1[:, i], v2[:, j]), 3)
446 |
447 | fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=300)
448 | ax = sns.heatmap(
449 | corr, annot=True,
450 | cmap='RdBu_r', vmin=-1, vmax=1,
451 | annot_kws={"fontsize": fontsize},
452 | cbar_kws={'label': 'Cell type proportion corr.'},
453 | ax=ax
454 | )
455 |
456 | ax.set_xticks(np.arange(n_factor2) + 0.5)
457 | ax.set_yticks(np.arange(n_factor1) + 0.5)
458 | ax.set_xticklabels(y_pred.columns, rotation=90)
459 | ax.set_yticklabels(y_true.columns, rotation=0)
460 | ax.set_xlabel('Estimated proportion')
461 | ax.set_ylabel('Ground truth proportion')
462 |
463 | if title is not None:
464 | # ax.set_title(title+'\n'+'Distance = %.3f' % (dist2identity(corr)))
465 | ax.set_title(title + '\n' + 'Distance = %.3f' % (_calc_rmse(y_true, y_pred).mean()))
466 |
467 | for item in (ax.get_xticklabels() + ax.get_yticklabels()):
468 | item.set_fontsize(12)
469 | if savefig and (outdir is not None and filename is not None):
470 | if not os.path.exists(outdir):
471 | os.makedirs(outdir)
472 | fig.savefig(os.path.join(outdir, filename + '.' + format), bbox_inches='tight', format=format)
473 | plt.show()
474 |
475 | return corr if return_corr else None
476 |
477 |
478 | def disp_prop_scatter(
479 | y_true, y_pred,
480 | outdir=None,
481 | filename=None,
482 | savefig=False,
483 | format='png'
484 | ):
485 | """
486 | Scatter plot of spot-wise proportion between ground-truth & predictions
487 | """
488 | assert y_true.shape == y_pred.shape, 'Inconsistent dimension between ground-truth & prediction'
489 | if savefig:
490 | assert format == 'png' or format == 'eps' or format == 'svg', "Invalid saving format"
491 |
492 | n_factors = y_true.shape[1]
493 | y_true_vals = y_true.values
494 | y_pred_vals = y_pred.values
495 | ncols = int(np.ceil(n_factors / 2))
496 |
497 | fig, (ax1, ax2) = plt.subplots(2, ncols, figsize=(2 * ncols, 4.4), dpi=300)
498 |
499 | for i in range(n_factors):
500 | v1 = y_true_vals[:, i]
501 | v2 = y_pred_vals[:, i]
502 | r2 = r2_score(v1, v2)
503 |
504 | v_stacked = np.vstack([v1, v2])
505 | den = gaussian_kde(v_stacked)(v_stacked)
506 |
507 | ax = ax1[i] if i < ncols else ax2[i % ncols]
508 | ax.scatter(v1, v2, c=den, s=.2, cmap='turbo', vmax=den.max() / 3)
509 |
510 | ax.set_aspect('equal')
511 | ax.spines['right'].set_visible(False)
512 | ax.spines['top'].set_visible(False)
513 | ax.axis('equal')
514 |
515 | # Only show ticks on the left and bottom spines
516 | ax.yaxis.set_ticks_position('left')
517 | ax.xaxis.set_ticks_position('bottom')
518 |
519 | ax.set_title(y_pred.columns[i])
520 | ax.annotate(r"$R^2$ = {:.3f}".format(r2), (0, 1), fontsize=8)
521 |
522 | ax.set_xlim([-0.1, 1.1])
523 | ax.set_ylim([-0.1, 1.1])
524 | ax.set_xticks(np.arange(0, 1.1, 0.5))
525 | ax.set_yticks(np.arange(0, 1.1, 0.5))
526 |
527 | ax.set_xlabel('Ground truth proportions')
528 | ax.set_ylabel('Predicted proportions')
529 |
530 | plt.tight_layout()
531 | if savefig and (outdir is not None and filename is not None):
532 | if not os.path.exists(outdir):
533 | os.makedirs(outdir)
534 | fig.savefig(os.path.join(outdir, filename + '.' + format), bbox_inches='tight', format=format)
535 |
536 | plt.show()
537 |
538 |
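539 | 
540 | if __name__ == '__main__':
541 |     # Smoke test of the benchmark helpers on synthetic proportions; the
542 |     # cell-type names and noise scale are arbitrary assumptions.
543 |     rng = np.random.default_rng(0)
544 |     cts = ['Tumor', 'Tcell', 'Bcell', 'Myeloid']
545 |     y_true = pd.DataFrame(rng.dirichlet(np.ones(4), size=500), columns=cts)
546 |     y_pred = (y_true + rng.normal(scale=0.05, size=y_true.shape)).clip(0, 1)
547 |     disp_corr(y_true, y_pred, title='synthetic check')
548 |     disp_prop_scatter(y_true, y_pred)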
--------------------------------------------------------------------------------
/starfysh/post_analysis.py:
--------------------------------------------------------------------------------
1 | import json
2 | import scanpy as sc
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import pandas as pd
6 | import os
7 | import networkx as nx
8 | import seaborn as sns
9 | from starfysh import utils
10 | import umap
11 | from scipy.stats import pearsonr, gaussian_kde
12 | from sklearn.neighbors import KNeighborsRegressor
13 |
14 |
15 |
16 | def get_z_umap(qz_m):
17 | fit = umap.UMAP(n_neighbors=45, min_dist=0.5)
18 | u = fit.fit_transform(qz_m)
19 | return u
20 |
21 |
22 | def plot_type_all(model, adata, proportions, figsize=(4, 4)):
23 | u = get_z_umap(adata.obsm['qz_m'])
24 | qc_m = adata.obsm["qc_m"]
25 | group_c = np.argmax(qc_m,axis=1)
26 |
27 | fig, ax = plt.subplots(figsize=figsize, dpi=300)
28 | cmaps = ['Blues','Greens','Reds','Oranges','Purples']
29 | for i in range(proportions.shape[1]):
30 |         plt.scatter(u[group_c==i,0], u[group_c==i,1], s=1, c=qc_m[group_c==i,i], cmap=cmaps[i % len(cmaps)])
31 |
32 | # project the model's u on the umap
33 | knr = KNeighborsRegressor(10)
34 | knr.fit(adata.obsm["qz_m"], u)
35 | qu_umap = knr.predict(adata.uns['qu'])
36 | ax.scatter(*qu_umap.T, c='yellow', edgecolors='black',
37 |                s=np.exp(model.qs_logm.cpu().detach()).sum(1) ** 0.5)  # sqrt-scaled marker sizes
38 |
39 | plt.legend(proportions.columns,loc='right', bbox_to_anchor=(2.2,0.5),)
40 | plt.axis('off')
41 | return fig, ax
42 |
43 |
44 | def get_corr_map(inference_outputs, proportions):
45 | qc_m_n = inference_outputs["qc_m"].detach().cpu().numpy()
46 | corr_map_qcm = np.zeros([qc_m_n.shape[1],qc_m_n.shape[1]])
47 |
48 | for i in range(corr_map_qcm.shape[0]):
49 | for j in range(corr_map_qcm.shape[0]):
50 | corr_map_qcm[i, j], _ = pearsonr(qc_m_n[:,i], proportions.iloc[:, j])
51 |
52 |
53 | plt.figure(dpi=300,figsize=(3.2,3.2))
54 | ax = sns.heatmap(corr_map_qcm.T, annot=True,
55 | cmap='RdBu_r',vmax=1,vmin=-1,
56 | cbar_kws={'label': 'Cell type proportion corr.'}
57 | )
58 | plt.xticks(np.array(range(qc_m_n.shape[1]))+0.5,labels=proportions.columns,rotation=90)
59 | plt.yticks(np.array(range(qc_m_n.shape[1]))+0.5,labels=proportions.columns,rotation=0)
60 | plt.xlabel('Estimated proportion')
61 | plt.ylabel('Ground truth proportion')
62 |
63 |
64 |
65 | def display_reconst(
66 | df_true,
67 | df_pred,
68 | density=False,
69 | marker_genes=None,
70 | sample_rate=0.1,
71 | size=(3, 3),
72 | spot_size=1,
73 | title=None,
74 | x_label='',
75 | y_label='',
76 | x_min=0,
77 | x_max=10,
78 | y_min=0,
79 | y_max=10,
80 | ):
81 | """
82 | Scatter plot - raw gexp vs. reconstructed gexp
83 | """
84 | assert 0 < sample_rate <= 1, \
85 | "Invalid downsampling rate for reconstruct scatter plot: {}".format(sample_rate)
86 |
87 | if marker_genes is not None:
88 | marker_genes = set(marker_genes)
89 |
90 | df_true_sample = df_true.sample(frac=sample_rate, random_state=0)
91 | df_pred_sample = df_pred.loc[df_true_sample.index]
92 |
93 | plt.rcParams["figure.figsize"] = size
94 | plt.figure(dpi=300)
95 | ax = plt.gca()
96 |
97 | xx = df_true_sample.T.to_numpy().flatten()
98 | yy = df_pred_sample.T.to_numpy().flatten()
99 |
100 | if density:
101 | for gene in df_true_sample.columns:
102 | try:
103 | gene_true = df_true_sample[gene].values
104 | gene_pred = df_pred_sample[gene].values
105 | gexp_stacked = np.vstack([df_true_sample[gene].values, df_pred_sample[gene].values])
106 |
107 | z = gaussian_kde(gexp_stacked)(gexp_stacked)
108 | ax.scatter(gene_true, gene_pred, c=z, s=spot_size, alpha=0.5)
109 | except np.linalg.LinAlgError as e:
110 | pass
111 |
112 | elif marker_genes is not None:
113 | color_dict = {True: 'red', False: 'green'}
114 | gene_colors = np.vectorize(
115 | lambda x: color_dict[x in marker_genes]
116 | )(df_true_sample.columns)
117 | colors = np.repeat(gene_colors, df_true_sample.shape[0])
118 |
119 | ax.scatter(xx, yy, c=colors, s=spot_size, alpha=0.5)
120 |
121 | else:
122 | ax.scatter(xx, yy, s=spot_size, alpha=0.5)
123 |
124 | min_val = min(xx.min(), yy.min())
125 | max_val = max(xx.max(), yy.max())
126 | #ax.set_xlim(min_val, 400)
127 | ax.set_xlim(x_min, x_max)
128 | ax.set_ylim(y_min, y_max)
129 | #ax.set_ylim(min_val, 400)
130 |
131 | plt.suptitle(title)
132 | plt.xlabel(x_label)
133 | plt.ylabel(y_label)
134 | # Hide the right and top spines
135 | ax.spines['right'].set_visible(False)
136 | ax.spines['top'].set_visible(False)
137 | ax.axis('equal')
138 | # Only show ticks on the left and bottom spines
139 | ax.yaxis.set_ticks_position('left')
140 | ax.xaxis.set_ticks_position('bottom')
141 |
142 | plt.show()
143 |
144 |
145 | def gene_mean_vs_inferred_prop(inference_outputs, visium_args,idx):
146 |
147 | sig_mean_n_df = pd.DataFrame(
148 | np.array(visium_args.sig_mean_norm)/(np.sum(np.array(visium_args.sig_mean_norm),axis=1,keepdims=True)+1e-5),
149 | columns=visium_args.sig_mean_norm.columns,
150 | index=visium_args.sig_mean_norm.index
151 | )
152 |
153 | qc_m = inference_outputs["qc_m"].detach().cpu().numpy()
154 |
155 | figs,ax=plt.subplots(1,1,dpi=300,figsize=(2,2))
156 |
157 | v1 = sig_mean_n_df.iloc[:,idx].values
158 | v2 = qc_m[:,idx]
159 |
160 | v_stacked = np.vstack([v1, v2])
161 | den = gaussian_kde(v_stacked)(v_stacked)
162 |
163 | ax.scatter(v1,v2,c=den,s=1,cmap='jet',vmax=den.max()/3)
164 |
165 | ax.set_aspect('equal')
166 | ax.spines['right'].set_visible(False)
167 | ax.spines['top'].set_visible(False)
168 | ax.axis('equal')
169 | # Only show ticks on the left and bottom spines
170 | ax.yaxis.set_ticks_position('left')
171 | ax.xaxis.set_ticks_position('bottom')
172 | plt.title(visium_args.gene_sig.columns[idx])
173 | plt.xlim([v1.min()-0.1,v1.max()+0.1])
174 | plt.ylim([v2.min()-0.1,v2.max()+0.1])
175 | #plt.xticks(np.arange(0,1.1,0.5))
176 | #plt.yticks(np.arange(0,1.1,0.5))
177 | plt.xlabel('Gene signature mean')
178 | plt.ylabel('Predicted proportions')
179 |
180 |
181 | def plot_stacked_prop(results, category_names):
182 | """
183 | Parameters
184 | ----------
185 | results : dict
186 | A mapping from question labels to a list of answers per category.
187 | It is assumed all lists contain the same number of entries and that
188 | it matches the length of *category_names*.
189 |
190 | category_names : list of str
191 | The category labels.
192 | """
193 | labels = list(results.keys())
194 | data = np.array(list(results.values()))
195 | data_cum = data.cumsum(axis=1)
196 | category_colors = plt.get_cmap('rainbow')(
197 | np.linspace(0.15, 0.85, data.shape[1]))
198 | #category_colors = np.array(['b','g','r','oragne','purple'])
199 | fig, ax = plt.subplots(figsize=(2.5,1.8),dpi=300)
200 | ax.invert_yaxis()
201 | ax.xaxis.set_visible(False)
202 | ax.set_xlim(0, np.sum(data, axis=1).max())
203 |
204 | for i, (colname, color) in enumerate(zip(category_names, category_colors)):
205 | widths = data[:, i]
206 | starts = data_cum[:, i] - widths
207 | ax.barh(labels, widths, left=starts, height=0.6,label=colname, color=color)
208 | xcenters = starts + widths / 2
209 |
210 | r, g, b, _ = color
211 | text_color = 'black' #if r * g * b < 0.5 else 'darkgrey'
212 | #for y, (x, c) in enumerate(zip(xcenters, widths)):
213 | # ax.text(x, y, str(round(c,2)), ha='center', va='center',
214 | # color=text_color)
215 | #ax.legend(ncol=len(category_names), bbox_to_anchor=(0, 1),
216 | # loc='right', fontsize='small')
217 | ax.legend(category_names,loc='right',bbox_to_anchor=(2, 0.5))
218 | return fig, ax
219 |
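# Example input (hypothetical values): one dict entry per hub, one list
# element per cell type; each row is drawn as a horizontal stacked bar.
#
#   plot_stacked_prop({'hub0': [0.2, 0.5, 0.3], 'hub1': [0.6, 0.1, 0.3]},
#                     ['Tumor', 'Tcell', 'Myeloid'])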
220 |
221 | def plot_density(results, category_names):
222 | """
223 | Parameters
224 | ----------
225 | results : dict
226 | A mapping from question labels to a list of answers per category.
227 | It is assumed all lists contain the same number of entries and that
228 | it matches the length of *category_names*.
229 |
230 | category_names : list of str
231 | The category labels.
232 | """
233 | labels = list(results.keys())
234 | data = np.array(list(results.values()))
235 | category_colors = plt.get_cmap('RdBu_r')(
236 | np.linspace(0.15, 0.85, data.shape[1]))
237 | fig, ax = plt.subplots(figsize=(2.5,1.8),dpi=300)
238 | ax.invert_yaxis()
239 | #ax.xaxis.set_visible(False)
240 | ax.set_xlim(0, np.sum(data, axis=1).max())
241 |
242 | for i, (colname, color) in enumerate(zip(category_names, category_colors)):
243 | widths = data[:, i]
244 | starts = 0
245 | ax.barh(labels, widths, left=starts, height=0.6,label=colname, color=color)
246 |
247 | r, g, b, _ = color
248 | return fig, ax
249 |
250 |
251 | def get_factor_dist(sample_ids,file_path):
252 | qc_p_dist = {}
253 |     # Load each sample's saved factor proportions (qc_m) from JSON
254 |     for sample_id in sample_ids:
255 |         print(sample_id)
256 |         with open(file_path + sample_id + '_factor.json', 'r') as f:
257 |             data = json.load(f)
258 |         qc_p_dist[sample_id] = data['qc_m']
259 | 
260 | return qc_p_dist
261 |
262 |
263 | def get_adata(sample_ids, data_folder):
264 | adata_sample_all = []
265 | map_info_all = []
266 | adata_image_all = []
267 | for sample_id in sample_ids:
268 | print('loading...',sample_id)
269 | if (sample_id.startswith('MBC'))|(sample_id.startswith('CT')):
270 |
271 | adata_sample = sc.read_visium(path=os.path.join(data_folder, sample_id),library_id = sample_id)
272 | adata_sample.var_names_make_unique()
273 | adata_sample.obs['sample']=sample_id
274 | adata_sample.obs['sample_type']='MBC'
275 | #adata_sample.obs_names = adata_sample.obs_names+'-'+sample_id
276 | #adata_sample.obs_names = adata_sample.obs_names+'_'+sample_id
277 | if '_index' in adata_sample.var.columns:
278 | adata_sample.var_names=adata_sample.var['_index']
279 |
280 | else:
281 | adata_sample = sc.read_h5ad(os.path.join(data_folder,sample_id, sample_id+'.h5ad'))
282 | adata_sample.var_names_make_unique()
283 | adata_sample.obs['sample']=sample_id
284 | adata_sample.obs['sample_type']='TNBC'
285 | #adata_sample.obs_names = adata_sample.obs_names+'-'+sample_id
286 | #adata_sample.obs_names = adata_sample.obs_names+'_'+sample_id
287 | if '_index' in adata_sample.var.columns:
288 | adata_sample.var_names=adata_sample.var['_index']
289 |
290 |         if data_folder == 'simu_data':
291 |             adata_image, map_info = None, utils.get_simu_map_info(umap_df)  # NB: `umap_df` must exist in scope; simulated data has no image
292 | else:
293 | adata_image,map_info = utils.preprocess_img(data_folder,sample_id,adata_sample.obs.index,hchannal=False)
294 |
295 | adata_sample.obs_names = adata_sample.obs_names+'-'+sample_id
296 | map_info.index = map_info.index+'-'+sample_id
297 | adata_sample_all.append(adata_sample)
298 | map_info_all.append(map_info)
299 | adata_image_all.append(adata_image)
300 | return adata_sample_all,map_info_all,adata_image_all
301 |
302 | def get_Moran(W, X):
303 | N = W.shape[0]
304 | term1 = N / W.sum().sum()
305 | x_m = X.mean()
306 | term2 = np.matmul(np.matmul(np.diag(X-x_m),W),np.diag(X-x_m))
307 | term3 = term2.sum().sum()
308 | term4 = ((X-x_m)**2).sum()
309 | term5 = term1 * term3 / term4
310 | return term5
311 |
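# `get_Moran` above computes global Moran's I:
#   I = (N / sum_ij w_ij) * sum_ij w_ij (x_i - xbar)(x_j - xbar) / sum_i (x_i - xbar)^2
# I > 0 indicates spatial clustering, I < 0 dispersion
# (a checkerboard on a rook-adjacency grid gives exactly -1).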
312 | def get_LISA(W, X):
313 | lisa_score = np.zeros(X.shape)
314 | N = W.shape[0]
315 | x_m = X.mean()
316 | term1 = X-x_m
317 | term2 = ((X-x_m)**2).sum()
318 | for i in range(term1.shape[0]):
319 | #term3 = np.zeros(X.shape)
320 | term3 = (W[i,:]*(X-x_m)).sum()
321 | #for j in range(W.shape[0]):
322 | # term3[j]=W[i,j]*(X[j]-x_m)
323 | #term3 = term3.sum()
324 | lisa_score[i] = np.sign(X[i]-x_m) * N * (X[i]-x_m) * term3 / term2
325 | #lisa_score[i] = N * (X[i]-x_m) * term3 / term2
326 |
327 | return lisa_score
328 |
329 | def get_SCI(W, X, Y):
330 |
331 | N = W.shape[0]
332 | term1 = N / (2*W.sum().sum())
333 |
334 | x_m = X.mean()
335 | y_m = Y.mean()
336 | term2 = np.matmul(np.matmul(np.diag(X-x_m),W),np.diag(Y-y_m))
337 | term3 = term2.sum().sum()
338 |
339 | term4 = np.sqrt(((X-x_m)**2).sum()) * np.sqrt(((Y-y_m)**2).sum())
340 |
341 | term5 = term1 * term3 / term4
342 |
343 | return term5
344 |
345 | def get_cormtx(sample_id, hub_num ):
346 | # TODO: get_cormtx
347 | prop_i = proportions_df[ids_df['source']==sample_id][cluster_df['cluster']==hub_num]
348 | loc_i = np.array(map_info_all.loc[prop_i.index].loc[:,['array_col','array_row',]])
349 | W = np.zeros([loc_i.shape[0],loc_i.shape[0]])
350 |
351 | cor_matrix = np.zeros([gene_sig.shape[1],gene_sig.shape[1]])
352 | for i in range(loc_i.shape[0]):
353 | for j in range(i,loc_i.shape[0]):
354 | if np.sqrt((loc_i[i,0]-loc_i[j,0])**2+(loc_i[i,1]-loc_i[j,1])**2)<=3:
355 | W[i,j] = 1
356 | W[j,i] = 1
357 | #indices = vor.regions[vor.point_region[i]]
358 | #neighbor_i = np.concatenate([vor.ridge_points[np.where(vor.ridge_points[:,0] == i)],np.flip(vor.ridge_points[np.where(vor.ridge_points[:,1] == i)],axis=1)],axis=0)[:,1]
359 | #W[i,neighbor_i]=1
360 | #W[neighbor_i,i]=1
361 | print('spots in hub ',hub_num, '= ',prop_i.shape[0])
362 | if prop_i.shape[0]>1:
363 | for i in range(gene_sig.shape[1]):
364 | for j in range(i+1,gene_sig.shape[1]):
365 | cor_matrix[i,j]=get_SCI(W, np.array(prop_i.iloc[:,i]), np.array(prop_i.iloc[:,j]))
366 | cor_matrix[j,i]=cor_matrix[i,j]
367 | return cor_matrix
368 |
369 | def get_hub_cormtx(sample_ids, hub_num):
370 | cor_matrix = np.zeros([gene_sig.shape[1],gene_sig.shape[1]])
371 | for sample_id in sample_ids:
372 | print(sample_id)
373 | cor_matrix = cor_matrix + get_cormtx(sample_id = sample_id, hub_num=hub_num)
374 | #print(cor_matrix)
375 | cor_matrix = cor_matrix/len(sample_ids)
376 | #cor_matrix = pd.DataFrame(cor_matrix)
377 | return cor_matrix
378 |
379 | def create_corr_network_5(G, node_size_list,corr_direction, min_correlation):
380 | ##Creates a copy of the graph
381 | H = G.copy()
382 |
383 | ##Checks all the edges and removes some based on corr_direction
384 | for stock1, stock2, weight in G.edges(data=True):
385 | #print(weight)
386 | ##if we only want to see the positive correlations we then delete the edges with weight smaller than 0
387 | if corr_direction == "positive":
388 | ####it adds a minimum value for correlation.
389 | ####If correlation weaker than the min, then it deletes the edge
390 | if weight["weight"] <0 or weight["weight"] < min_correlation:
391 | H.remove_edge(stock1, stock2)
392 |         ##this part runs if the corr_direction is negative and removes edges with weights equal or larger than 0
393 | else:
394 | ####it adds a minimum value for correlation.
395 | ####If correlation weaker than the min, then it deletes the edge
396 | if weight["weight"] >=0 or weight["weight"] > min_correlation:
397 | H.remove_edge(stock1, stock2)
398 |
399 |
400 |     #creates a list for edges and for the weights
401 | edges,weights = zip(*nx.get_edge_attributes(H,'weight').items())
402 |
403 |
404 | ### increases the value of weights, so that they are more visible in the graph
405 | #weights = tuple([(0.5+abs(x))**1 for x in weights])
406 | weights = tuple([x*2 for x in weights])
407 | #print(len(weights))
408 | #####calculates the degree of each node
409 | d = nx.degree(H)
410 | #print(d)
411 | #####creates list of nodes and a list their degrees that will be used later for their sizes
412 | nodelist, node_sizes = zip(*dict(d).items())
413 | #import sys, networkx as nx, matplotlib.pyplot as plt
414 |
415 | # Create a list of 10 nodes numbered [0, 9]
416 | #nodes = range(10)
417 | node_sizes = []
418 | labels = {}
419 | for n in nodelist:
420 | node_sizes.append( node_size_list[n] )
421 | labels[n] = 1 * n
422 |
423 | # Node sizes: [0, 100, 200, 300, 400, 500, 600, 700, 800, 900]
424 |
425 | # Connect each node to its successor
426 | #edges = [ (i, i+1) for i in range(len(nodes)-1) ]
427 |
428 | # Create the graph and draw it with the node labels
429 | #g = nx.Graph()
430 | #g.add_nodes_from(nodes)
431 | #g.add_edges_from(edges)
432 |
433 | #nx.draw_random(g, node_size = node_sizes, labels=labels, with_labels=True)
434 | #plt.show()
435 |
436 | #positions
437 | positions=nx.circular_layout(H)
438 | #print(positions)
439 |
440 | #Figure size
441 | plt.figure(figsize=(2,2),dpi=500)
442 |
443 | #draws nodes,
444 | #options = {"edgecolors": "tab:gray", "alpha": 0.9}
445 | nx.draw_networkx_nodes(H,positions,
446 | #node_color='#DA70D6',
447 | nodelist=nodelist,
448 | #####the node size will be now based on its degree
449 |                            node_color=_colors['leiden_colors'][hub_num],  # NB: relies on globals `_colors` and `hub_num`
450 | alpha = 0.8,
451 | node_size=tuple([x**1 for x in node_sizes]),
452 | #**options
453 | )
454 |
455 | #Styling for labels
456 | nx.draw_networkx_labels(H, positions, font_size=4,
457 | font_family='sans-serif')
458 |
459 | ###edge colors based on weight direction
460 | if corr_direction == "positive":
461 |         edge_colour = plt.cm.GnBu
462 | else:
463 | edge_colour = plt.cm.PuRd
464 |
465 | #draws the edges
466 | print(min(weights))
467 | print(max(weights))
468 |
469 | nx.draw_networkx_edges(H, positions, edgelist=edges,style='solid',
470 | ###adds width=weights and edge_color = weights
471 | ###so that edges are based on the weight parameter
472 | ###edge_cmap is for the color scale based on the weight
473 | ### edge_vmin and edge_vmax assign the min and max weights for the width
474 | width=weights, edge_color = weights, edge_cmap = edge_colour,
475 |                            # fix the color range (rather than using min(weights) /
476 |                            # max(weights)) so edge colors are comparable across hubs
477 |                            edge_vmin=0,
478 |                            edge_vmax=0.7,
479 | )
480 |
481 | # displays the graph without axis
482 | plt.axis('off')
483 | #plt.legend(['r','r'])
484 | #saves image
485 | #plt.savefig("part5" + corr_direction + ".png", format="PNG")
486 | #plt.show()
487 |
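488 | 
489 | if __name__ == '__main__':
490 |     # Self-contained checks for the spatial statistics above; the 5x5 rook-
491 |     # adjacency grid and checkerboard signal are illustrative assumptions.
492 |     n = 5
493 |     coords = np.array([(i, j) for i in range(n) for j in range(n)])
494 |     W_demo = (np.abs(coords[:, None] - coords[None, :]).sum(-1) == 1).astype(float)
495 |     X_demo = (coords.sum(1) % 2).astype(float)               # checkerboard
496 |     assert np.isclose(get_Moran(W_demo, X_demo), -1.0)       # perfect dispersion
497 |     assert np.isclose(get_SCI(W_demo, X_demo, X_demo),
498 |                       get_Moran(W_demo, X_demo) / 2)         # SCI(X, X) = I / 2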
--------------------------------------------------------------------------------
/starfysh/starfysh.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import os
6 | import random
7 |
8 | os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from torchvision import transforms
14 | from torchvision.utils import make_grid
15 | from torch.distributions import constraints, Distribution, Normal, Gamma, Poisson, Dirichlet
16 | from torch.distributions import kl_divergence as kl
17 |
18 | # Module import
19 | from starfysh import LOGGER
20 | from .post_analysis import get_z_umap
21 | 
22 | random.seed(0)
23 | np.random.seed(0)
24 |
25 |
26 | # TODO:
27 | # inherit `AVAE` (expr model) w/ `AVAE_PoE` (expr + histology model), update latest PoE model
28 | class AVAE(nn.Module):
29 | """
30 | Model design
31 | p(x|z)=f(z)
32 |     p(z) ~ N(0, I)
33 | q(z|x)~g(x)
34 | """
35 |
36 | def __init__(
37 | self,
38 | adata,
39 | gene_sig,
40 | win_loglib,
41 | alpha_mul=50,
42 | seed=0,
43 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
44 | ) -> None:
45 | """
46 | Auxiliary Variational AutoEncoder (AVAE) - Core model for
47 | spatial deconvolution without H&E image integration
48 |
49 |         Parameters
50 |         ----------
51 | adata : sc.AnnData
52 | ST raw expression count (dim: [S, G])
53 |
54 | gene_sig : pd.DataFrame
55 | Normalized avg. signature expressions for each annotated cell type
56 |
57 | win_loglib : float
58 | Log-library size smoothed with neighboring spots
59 |
60 | alpha_mul : float (default=50)
61 | Multiplier of Dirichlet concentration parameter to control
62 | signature prior's confidence
63 | """
64 | super().__init__()
65 | torch.manual_seed(seed)
66 |
67 | self.win_loglib=torch.Tensor(win_loglib)
68 |
69 | self.c_in = adata.shape[1] # c_in : Num. input features (# input genes)
70 |         self.c_bn = 10  # c_bn : dimension of the latent bottleneck z
71 | self.c_hidden = 256
72 | self.c_kn = gene_sig.shape[1]
73 | self.eps = 1e-5 # for r.v. w/ numerical constraints
74 | self.device = device
75 |
76 | self.alpha = torch.ones(self.c_kn) * alpha_mul
77 | self.alpha = self.alpha.to(device)
78 |
79 | self.qs_logm = torch.nn.Parameter(torch.zeros(self.c_kn, self.c_bn), requires_grad=True)
80 | self.qu_m = torch.nn.Parameter(torch.randn(self.c_kn, self.c_bn), requires_grad=True)
81 | self.qu_logv = torch.nn.Parameter(torch.zeros(self.c_kn, self.c_bn), requires_grad=True)
82 |
83 | self.c_enc = nn.Sequential(
84 | nn.Linear(self.c_in, self.c_hidden, bias=True),
85 | nn.BatchNorm1d(self.c_hidden, momentum=0.01, eps=0.001),
86 | nn.ReLU()
87 | )
88 |
89 | self.c_enc_m = nn.Sequential(
90 | nn.Linear(self.c_hidden, self.c_kn, bias=True),
91 | nn.BatchNorm1d(self.c_kn, momentum=0.01, eps=0.001),
92 | nn.Softmax(dim=-1)
93 | )
94 |
95 | self.l_enc = nn.Sequential(
96 | nn.Linear(self.c_in, self.c_hidden, bias=True),
97 | nn.BatchNorm1d(self.c_hidden, momentum=0.01, eps=0.001),
98 | nn.ReLU()
99 | )
100 |
101 | self.l_enc_m = nn.Linear(self.c_hidden, 1)
102 | self.l_enc_logv = nn.Linear(self.c_hidden, 1)
103 |
104 | # neural network f1 to get the z, p(z|x), f1(x,\phi_1)=[z_m,torch.exp(z_logv)]
105 | self.z_enc = nn.Sequential(
106 | # nn.Linear(self.c_in+self.c_kn, self.c_hidden, bias=True),
107 | nn.Linear(self.c_in, self.c_hidden, bias=True),
108 | nn.BatchNorm1d(self.c_hidden, momentum=0.01, eps=0.001),
109 | nn.ReLU(),
110 | )
111 |
112 | self.z_enc_m = nn.Linear(self.c_hidden, self.c_bn * self.c_kn)
113 | self.z_enc_logv = nn.Linear(self.c_hidden, self.c_bn * self.c_kn)
114 |
115 | # gene dispersion
116 | self._px_r = torch.nn.Parameter(torch.randn(self.c_in), requires_grad=True)
117 |
118 | # neural network g to get the x_m and x_v, p(x|z), g(z,\phi_3)=[x_m,x_v]
119 | self.px_hidden_decoder = nn.Sequential(
120 | nn.Linear(self.c_bn, self.c_hidden, bias=True),
121 | nn.ReLU(),
122 | )
123 | self.px_scale_decoder = nn.Sequential(
124 | nn.Linear(self.c_hidden, self.c_in),
125 | nn.Softmax(dim=-1)
126 | )
127 |
128 | def reparameterize(self, mu, log_var):
129 |         """Reparameterization trick: z = mu + eps * sigma, with eps ~ N(0, I).
130 |         :param mu: mean from the encoder's latent space
131 |         :param log_var: log variance from the encoder's latent space
132 |         """
133 | std = torch.exp(0.5 * log_var) # standard deviation
134 | eps = torch.randn_like(std) # `randn_like` as we need the same size
135 | sample = mu + (eps * std) # sampling
136 | return sample
137 |
138 | def inference(self, x):
139 | x_n = torch.log1p(x)
140 |
141 | hidden = self.l_enc(x_n)
142 | ql_m = self.l_enc_m(hidden)
143 | ql_logv = self.l_enc_logv(hidden)
144 | ql = self.reparameterize(ql_m, ql_logv)
145 |
146 | hidden = self.c_enc(x_n)
147 | qc_m = self.c_enc_m(hidden)
148 | qc = Dirichlet(qc_m * self.alpha + self.eps).rsample()[:,:,None]
149 |
150 | hidden = self.z_enc(x_n)
151 | qz_m_ct = self.z_enc_m(hidden).reshape([x_n.shape[0], self.c_kn, self.c_bn])
152 | qz_m_ct = qc * qz_m_ct
153 | qz_m = qz_m_ct.sum(axis=1)
154 |
155 | qz_logv_ct = self.z_enc_logv(hidden).reshape([x_n.shape[0], self.c_kn, self.c_bn])
156 | qz_logv_ct = qc * qz_logv_ct
157 | qz_logv = qz_logv_ct.sum(axis=1)
158 | qz = self.reparameterize(qz_m, qz_logv)
159 |
160 | qu = self.reparameterize(self.qu_m, self.qu_logv)
161 |
162 | return dict(
163 | # q(u)
164 | qu=qu,
165 |
166 | # q(c | x)
167 | qc_m=qc_m,
168 | qc=qc,
169 |
170 | # q(z | c, x)
171 | qz_m=qz_m,
172 | qz_m_ct=qz_m_ct,
173 | qz_logv=qz_logv,
174 | qz_logv_ct=qz_logv_ct,
175 | qz=qz,
176 |
177 | # q(l | x)
178 | ql_m=ql_m,
179 | ql_logv=ql_logv,
180 | ql=ql,
181 | )
182 |
183 | def generative(
184 | self,
185 | inference_outputs,
186 | xs_k,
187 | ):
188 | qz = inference_outputs['qz']
189 | ql = inference_outputs['ql']
190 |
191 | hidden = self.px_hidden_decoder(qz)
192 | px_scale = self.px_scale_decoder(hidden)
193 | px_rate = torch.exp(ql) * px_scale + self.eps
194 | pc_p = xs_k + self.eps
195 |
196 | return dict(
197 | px_rate=px_rate,
198 | px_r=self.px_r,
199 | pc_p=pc_p,
200 | xs_k=xs_k,
201 | )
202 |
203 | def get_loss(
204 | self,
205 | generative_outputs,
206 | inference_outputs,
207 | x,
208 | library,
209 | device
210 | ):
211 | # Variational params
212 | qs_logm = self.qs_logm
213 | qu_m, qu_logv, qu = self.qu_m, self.qu_logv, inference_outputs["qu"]
214 | qc_m, qc = inference_outputs["qc_m"], inference_outputs["qc"]
215 | qz_m, qz_logv = inference_outputs["qz_m"], inference_outputs["qz_logv"]
216 | ql_m, ql_logv = inference_outputs["ql_m"], inference_outputs['ql_logv']
217 |
218 | # p(x | z), p(c; \alpha), p(u; \sigma)
219 | px_rate = generative_outputs["px_rate"]
220 | px_r = generative_outputs["px_r"]
221 | pc_p = generative_outputs["pc_p"]
222 |
223 | pu_m = torch.zeros_like(qu_m)
224 | pu_std = torch.ones_like(qu_logv) * 10
225 |
226 | # Regularization terms
227 | kl_divergence_u = kl(
228 | Normal(qu_m, torch.exp(qu_logv / 2)),
229 | Normal(pu_m, pu_std)
230 | ).sum(dim=1).mean()
231 |
232 | kl_divergence_l = kl(
233 | Normal(ql_m, torch.exp(ql_logv / 2)),
234 | Normal(library, torch.ones_like(ql_m))
235 | ).sum(dim=1).mean()
236 |
237 | kl_divergence_c = kl(
238 | Dirichlet(qc_m * self.alpha),
239 | Dirichlet(pc_p * self.alpha)
240 | ).mean()
241 |
242 | pz_m = (qu.unsqueeze(0) * qc).sum(axis=1)
243 | pz_std = (torch.exp(qs_logm / 2).unsqueeze(0) * qc).sum(axis=1)
244 |
245 | kl_divergence_z = kl(
246 | Normal(qz_m, torch.exp(qz_logv / 2)),
247 | Normal(pz_m, pz_std)
248 | ).sum(dim=1).mean()
249 |
250 | # Reconstruction term
251 | reconst_loss = -NegBinom(px_rate, torch.exp(px_r)).log_prob(x).sum(-1).mean()
252 |
253 | loss = reconst_loss.to(device) + \
254 | kl_divergence_u.to(device) + \
255 | kl_divergence_z.to(device) + \
256 | kl_divergence_c.to(device) + \
257 | kl_divergence_l.to(device)
258 |
259 | return (loss,
260 | reconst_loss,
261 | kl_divergence_u,
262 | kl_divergence_z,
263 | kl_divergence_c,
264 | kl_divergence_l
265 | )
266 |
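    # The loss assembled above is the negative ELBO:
    #   L = E_q[ -log NB(x | exp(l) * px_scale, r) ]
    #       + KL(q(u) || N(0, 10^2 I)) + KL(q(z) || N(pz_m, pz_std))
    #       + KL(q(c) || Dir(alpha * pc_p)) + KL(q(l) || N(library, 1))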
267 | @property
268 | def px_r(self):
269 | return F.softplus(self._px_r) + self.eps
270 |
271 |
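# A minimal smoke-test sketch for `AVAE` on synthetic data; the shapes,
# counts and names below are illustrative assumptions, not a shipped API:
#
#   import anndata
#   counts = np.random.poisson(1.0, size=(64, 200)).astype(np.float32)
#   sig_demo = pd.DataFrame(np.random.rand(200, 5))      # 5 pseudo cell types
#   model = AVAE(anndata.AnnData(counts), sig_demo, win_loglib=np.log1p(counts.sum(1)))
#   out = model.inference(torch.tensor(counts))
#   assert out['qc_m'].shape == (64, 5) and out['qz_m'].shape == (64, 10)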
272 | class AVAE_PoE(nn.Module):
273 | """
274 | Model design:
275 | p(x|z)=f(z)
276 |     p(z) ~ N(0, I)
277 | q(z|x)~g(x)
278 | """
279 |
280 | def __init__(
281 | self,
282 | adata,
283 | gene_sig,
284 | patch_r,
285 | win_loglib,
286 | alpha_mul=50,
287 | n_img_chan=1,
288 | seed=0,
289 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
290 | ) -> None:
291 | """
292 | Auxiliary Variational AutoEncoder (AVAE) with Joint H&E inference
293 | - Core model for spatial deconvolution w/ H&E image integration
294 |
295 |         Parameters
296 |         ----------
297 | adata : sc.AnnData
298 | ST raw expression count (dim: [S, G])
299 |
300 | gene_sig : pd.DataFrame
301 | Signature gene sets for each annotated cell type
302 |
303 | patch_r : int
304 | Mini-patch size sampled around each spot from raw H&E image
305 |
306 | win_loglib : float
307 | Log-library size smoothed with neighboring spots
308 |
309 | alpha_mul : float (default=50)
310 | Multiplier of Dirichlet concentration parameter to control
311 | signature prior's confidence
312 |
313 | """
314 | super().__init__()
315 | torch.manual_seed(seed)
316 |
317 | self.win_loglib = torch.Tensor(win_loglib)
318 | self.patch_r = patch_r
319 | self.c_in = adata.shape[1] # c_in : Num. input features (# input genes)
320 | self.nimg_chan = n_img_chan
321 | self.c_in_img = self.patch_r**2 * 4 * n_img_chan # c_in_img : (# pixels for the spot's img patch)
322 |         self.c_bn = 10  # c_bn : dimension of the latent bottleneck z
323 | self.c_hidden = 256
324 | self.c_kn = gene_sig.shape[1]
325 |
326 | self.eps = 1e-5 # for r.v. w/ numerical constraints
327 | self.alpha = torch.ones(self.c_kn) * alpha_mul
328 | self.alpha = self.alpha.to(device)
329 |
330 | # --- neural nets for Expression view ---
331 | self.qs_logm = torch.nn.Parameter(torch.zeros(self.c_kn, self.c_bn), requires_grad=True)
332 | self.qu_m = torch.nn.Parameter(torch.randn(self.c_kn, self.c_bn), requires_grad=True)
333 | self.qu_logv = torch.nn.Parameter(torch.zeros(self.c_kn, self.c_bn), requires_grad=True)
334 |
335 | self.c_enc = nn.Sequential(
336 | nn.Linear(self.c_in, self.c_hidden, bias=True),
337 | nn.BatchNorm1d(self.c_hidden, momentum=0.01, eps=0.001),
338 | nn.ReLU()
339 | )
340 |
341 | self.c_enc_m = nn.Sequential(
342 | nn.Linear(self.c_hidden, self.c_kn, bias=True),
343 | nn.BatchNorm1d(self.c_kn, momentum=0.01, eps=0.001),
344 | nn.Softmax(dim=-1)
345 | )
346 |
347 | self.l_enc = nn.Sequential(
348 | nn.Linear(self.c_in+self.c_in_img, self.c_hidden, bias=True),
349 | nn.BatchNorm1d(self.c_hidden, momentum=0.01, eps=0.001),
350 | nn.ReLU(),
351 | )
352 | self.l_enc_m = nn.Linear(self.c_hidden, 1)
353 | self.l_enc_logv = nn.Linear(self.c_hidden, 1)
354 |
355 | # neural network f1 to get the z, p(z|x), f1(x,\phi_1)=[z_m,torch.exp(z_logv)]
356 | self.z_enc = nn.Sequential(
357 | nn.Linear(self.c_in, self.c_hidden, bias=True),
358 | nn.BatchNorm1d(self.c_hidden, momentum=0.01, eps=0.001),
359 | nn.ReLU(),
360 | )
361 | self.z_enc_m = nn.Linear(self.c_hidden, self.c_bn * self.c_kn)
362 | self.z_enc_logv = nn.Linear(self.c_hidden, self.c_bn * self.c_kn)
363 |
364 | # gene dispersion
365 | self._px_r = torch.nn.Parameter(torch.randn(self.c_in), requires_grad=True)
366 |
367 | # neural network g to get the x_m and x_v, p(x|z), g(z,\phi_3)=[x_m,x_v]
368 | self.z_to_hidden_decoder = nn.Sequential(
369 | nn.Linear(self.c_bn, self.c_hidden, bias=True),
370 | nn.ReLU(),
371 | )
372 | self.px_scale_decoder = nn.Sequential(
373 | nn.Linear(self.c_hidden, self.c_in),
374 | nn.Softmax(dim=-1)
375 | )
376 |
377 | # --- neural nets for Histology view ---
378 | # encoder paths:
379 | self.img_z_enc = nn.Sequential(
380 | nn.Linear(self.c_in_img, self.c_hidden, bias=True),
381 | nn.BatchNorm1d(self.c_hidden, momentum=0.01, eps=0.001),
382 | nn.ReLU()
383 | )
384 | self.img_z_enc_m = nn.Linear(self.c_hidden, self.c_bn)
385 | self.img_z_enc_logv = nn.Linear(self.c_hidden, self.c_bn)
386 |
387 | # decoder paths:
388 | self.img_z_to_hidden_decoder = nn.Linear(self.c_bn, self.c_hidden, bias=True)
389 | self.py_mu_decoder = nn.Sequential(
390 | nn.Linear(self.c_hidden, self.c_in_img),
391 | nn.BatchNorm1d(self.c_in_img, momentum=0.01, eps=0.001),
392 | nn.ReLU()
393 | )
394 | self.py_logv_decoder = nn.Sequential(
395 | nn.Linear(self.c_hidden, self.c_in_img),
396 | nn.ReLU()
397 | )
398 |
399 | # --- PoE view ---
400 | self.z_to_hidden_poe_decoder = nn.Linear(self.c_bn, self.c_hidden, bias=True)
401 | self.px_scale_poe_decoder = nn.Sequential(
402 | nn.Linear(self.c_hidden, self.c_in),
403 | nn.ReLU()
404 | )
405 | self._px_r_poe = torch.nn.Parameter(torch.randn(self.c_in), requires_grad=True)
406 |
407 | self.py_mu_poe_decoder = nn.Sequential(
408 | nn.Linear(self.c_hidden, self.c_in_img),
409 | nn.BatchNorm1d(self.c_in_img, momentum=0.01, eps=0.001),
410 | nn.ReLU()
411 | )
412 | self.py_logv_poe_decoder = nn.Sequential(
413 | nn.Linear(self.c_hidden, self.c_in_img),
414 | nn.ReLU()
415 | )
416 |
417 | def reparameterize(self, mu, log_var):
418 | """
419 | :param mu: mean from the encoder's latent space
420 | :param log_var: log variance from the encoder's latent space
421 | """
422 | std = torch.exp(0.5 * log_var) # standard deviation
423 | eps = torch.randn_like(std) # `randn_like` as we need the same size
424 | sample = mu + (eps * std) # sampling
425 | return sample
426 |
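    # A minimal sanity check of the reparameterization trick above (illustrative
    # sketch, not executed as part of the module): with eps ~ N(0, I),
    # mu + exp(0.5 * log_var) * eps is distributed as N(mu, exp(log_var)) while
    # keeping mu and log_var differentiable through the sample.
    #
    #     mu = torch.zeros(10000)
    #     log_var = torch.log(torch.full((10000,), 4.0))
    #     z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)
    #     # z.var() ~= 4.0 == exp(log_var)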
427 | def inference(self, x, y):
428 | # q(l | x)
429 |         x_n = torch.log1p(x)  # l is inferred from log1p-transformed counts
430 |         y_n = torch.log1p(y)
431 | 
432 |         hidden = self.l_enc(torch.cat([x_n, y_n], dim=1))
433 | ql_m = self.l_enc_m(hidden)
434 | ql_logv = self.l_enc_logv(hidden)
435 | ql = self.reparameterize(ql_m, ql_logv)
436 |
437 | # q(c | x)
438 | hidden = self.c_enc(x_n)
439 | qc_m = self.c_enc_m(hidden)
440 | qc = Dirichlet(qc_m * self.alpha + self.eps).rsample()[:,:,None]
441 |
442 | # q(z | c, x)
443 | hidden = self.z_enc(x_n)
444 | qz_m_ct = self.z_enc_m(hidden).reshape([x_n.shape[0], self.c_kn, self.c_bn])
445 | qz_m_ct = qc * qz_m_ct
446 | qz_m = qz_m_ct.sum(axis=1)
447 |
448 | qz_logv_ct = self.z_enc_logv(hidden).reshape([x_n.shape[0], self.c_kn, self.c_bn])
449 | qz_logv_ct = qc * qz_logv_ct
450 | qz_logv = qz_logv_ct.sum(axis=1)
451 | qz = self.reparameterize(qz_m, qz_logv)
452 |
453 | # q(u), mean-field VI
454 | qu = self.reparameterize(self.qu_m, self.qu_logv)
455 |
456 | return dict(
457 | # q(u)
458 | qu=qu,
459 |
460 | # q(c | x)
461 | qc_m=qc_m,
462 | qc=qc,
463 |
464 | # q(z | x)
465 | qz_m=qz_m,
466 | qz_m_ct=qz_m_ct,
467 | qz_logv=qz_logv,
468 | qz_logv_ct=qz_logv_ct,
469 | qz=qz,
470 |
471 | # q(l | x)
472 | ql_m=ql_m,
473 | ql_logv=ql_logv,
474 | ql=ql,
475 | )
476 |
477 | def generative(
478 | self,
479 | inference_outputs,
480 | xs_k,
481 | ):
482 |         """
483 |         xs_k : torch.Tensor
484 |             Z-scored mean signature gene expression per cell type
485 |             (serves as the Dirichlet prior parameter for c)
486 |         """
486 | qz = inference_outputs['qz']
487 | ql = inference_outputs['ql']
488 |
489 | hidden = self.z_to_hidden_decoder(qz)
490 | px_scale = self.px_scale_decoder(hidden)
491 | px_rate = torch.exp(ql) * px_scale + self.eps
492 | pc_p = xs_k + self.eps
493 |
494 | return dict(
495 | px_rate=px_rate,
496 | px_r=self.px_r,
497 | px_scale=px_scale,
498 | pc_p=pc_p,
499 | xs_k=xs_k,
500 | )
501 |
502 | def predictor_img(self, y):
503 | """Inference & generative paths for image view"""
504 | # --- Inference path ---
505 | y_n = torch.log1p(y)
506 |
507 | # q(z | y)
508 | hidden_z = self.img_z_enc(y_n)
509 | qz_m = self.img_z_enc_m(hidden_z)
510 | qz_logv = self.img_z_enc_logv(hidden_z)
511 | qz = self.reparameterize(qz_m, qz_logv)
512 |
513 | # --- Generative path ---
514 | hidden_y = self.img_z_to_hidden_decoder(qz)
515 | py_m = self.py_mu_decoder(hidden_y)
516 | py_logv = self.py_logv_decoder(hidden_y)
517 |
518 | return dict(
519 | # q(z | y)
520 | qz_m_img=qz_m,
521 | qz_logv_img=qz_logv,
522 | qz_img=qz,
523 |
524 | # p(y | z) (image reconst)
525 | py_m=py_m,
526 | py_logv=py_logv
527 | )
528 |
529 | def generative_img(self, z):
530 | """Generative path of histology view given z"""
531 | hidden_y = self.img_z_to_hidden_decoder(z)
532 | py_m = self.py_mu_decoder(hidden_y)
533 | py_logv = self.py_logv_decoder(hidden_y)
534 | return dict(
535 | py_m=py_m,
536 | py_logv=py_logv
537 | )
538 |
539 | def predictor_poe(
540 | self,
541 | inference_outputs,
542 | img_outputs,
543 | ):
544 | """Inference & generative paths for Joint view"""
545 |
546 | # Variational params. (expression branch)
547 | ql = inference_outputs['ql']
548 | qz_m = inference_outputs['qz_m']
549 | qz_logv = inference_outputs['qz_logv']
550 |
551 | # Variational params. (img branch)
552 | qz_m_img = img_outputs['qz_m_img']
553 | qz_logv_img = img_outputs['qz_logv_img']
554 |
555 | batch, _ = qz_m.shape
556 |
557 |         # --- Joint posterior q(z | x, y) via Product-of-Experts ---
558 |         # Precision-weighted combination of the expression & image experts
559 |         qz_var_poe = 1. / (
560 |             1. / (torch.exp(qz_logv) + self.eps) +
561 |             1. / (torch.exp(qz_logv_img) + self.eps)
562 |         )
563 |         qz_m_poe = qz_var_poe * (
564 |             qz_m / (torch.exp(qz_logv) + self.eps) +
565 |             qz_m_img / (torch.exp(qz_logv_img) + self.eps)
566 |         )
567 |         qz = self.reparameterize(qz_m_poe, torch.log(qz_var_poe))  # joint posterior sample
568 |
569 | # PoE joint & view-specific decoders
570 | hidden = self.z_to_hidden_poe_decoder(qz)
571 |
572 | # p(x | z_poe)
573 | px_scale = self.px_scale_poe_decoder(hidden)
574 | px_rate = torch.exp(ql) * px_scale + self.eps
575 |
576 | # p(y | z_poe)
577 | py_m = self.py_mu_poe_decoder(hidden)
578 | py_logv = self.py_logv_poe_decoder(hidden)
579 |
580 | return dict(
581 | # PoE q(z | x, y)
582 | qz_m=qz_m_poe,
583 |             qz_logv=torch.log(qz_var_poe),  # log-variance, matching the sample drawn above
584 | qz=qz,
585 |
586 | # PoE p(x | z, l) & p(y | z)
587 | px_rate=px_rate,
588 | px_r=self.px_r_poe,
589 | py_m=py_m,
590 | py_logv=py_logv
591 | )
592 |
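    # Background for `predictor_poe` (illustrative): the product of two diagonal
    # Gaussian experts N(m1, v1) and N(m2, v2) is itself Gaussian N(m, v) with
    #
    #     1/v = 1/v1 + 1/v2
    #     m   = v * (m1/v1 + m2/v2)
    #
    # i.e. the precision-weighted combination computed above. E.g. with m1=0,
    # m2=2 and v1=v2=1, the joint posterior has v=0.5 and m=1 (the midpoint,
    # since both experts are equally confident).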
593 | def get_loss(
594 | self,
595 | generative_outputs,
596 | inference_outputs,
597 | img_outputs,
598 | poe_outputs,
599 | x,
600 | library,
601 | y,
602 | device
603 | ):
604 |         lambda_poe = 0.2  # weight of the joint (PoE) loss relative to the view-specific losses
605 |
606 | # --- Parse variables ---
607 | # Variational params
608 | qc_m, qc = inference_outputs["qc_m"], inference_outputs["qc"]
609 |
610 | qs_logm = self.qs_logm
611 | qu_m, qu_logv, qu = self.qu_m, self.qu_logv, inference_outputs["qu"]
612 |
613 | qz_m, qz_logv = inference_outputs["qz_m"], inference_outputs["qz_logv"]
614 | ql_m, ql_logv = inference_outputs["ql_m"], inference_outputs['ql_logv']
615 |
616 | qz_m_img, qz_logv_img = img_outputs['qz_m_img'], img_outputs['qz_logv_img']
617 | qz_m_poe, qz_logv_poe = poe_outputs['qz_m'], poe_outputs['qz_logv']
618 |
619 | # Generative params
620 | px_rate = generative_outputs["px_rate"]
621 | px_r = generative_outputs["px_r"]
622 | pc_p = generative_outputs["pc_p"]
623 |
624 | py_m, py_logv = img_outputs['py_m'], img_outputs['py_logv']
625 |
626 | px_rate_poe = poe_outputs['px_rate']
627 | px_r_poe = poe_outputs['px_r']
628 | py_m_poe, py_logv_poe = poe_outputs['py_m'], poe_outputs['py_logv']
629 |
630 |
631 | # --- Losses ---
632 | # (1). Joint Loss
633 | y_n = torch.log1p(y)
634 | reconst_loss_x_poe = -NegBinom(px_rate_poe, torch.exp(px_r_poe)).log_prob(x).sum(-1).mean()
635 | reconst_loss_y_poe = -Normal(py_m_poe, torch.exp(py_logv_poe/2)).log_prob(y_n).sum(-1).mean()
636 |
637 | # prior: p(z | c, u)
638 | pz_m = (qu.unsqueeze(0) * qc).sum(axis=1)
639 | pz_std = (torch.exp(qs_logm / 2).unsqueeze(0) * qc).sum(axis=1)
640 |
641 | kl_divergence_z_poe = kl(
642 | Normal(qz_m_poe, torch.exp(qz_logv_poe / 2)),
643 | Normal(pz_m, pz_std)
644 | ).sum(dim=1).mean()
645 |
646 | Loss_IBJ = reconst_loss_x_poe + reconst_loss_y_poe + kl_divergence_z_poe.to(device)
647 |
648 | # (2). View-specific losses
649 | # Expression view
650 | kl_divergence_u = kl(
651 | Normal(qu_m, torch.exp(qu_logv / 2)),
652 | Normal(torch.zeros_like(qu_m), torch.ones_like(qu_m) * 10)
653 | ).sum(dim=1).mean()
654 |
655 | kl_divergence_z = kl(
656 | Normal(qz_m, torch.exp(qz_logv / 2)),
657 | Normal(pz_m, pz_std)
658 | ).sum(dim=1).mean()
659 |
660 | kl_divergence_l = kl(
661 | Normal(ql_m, torch.exp(ql_logv / 2)),
662 | Normal(library, torch.ones_like(ql_m))
663 | ).sum(dim=1).mean()
664 |
665 | kl_divergence_c = kl(
666 | Dirichlet(qc_m * self.alpha), # q(c | x; α) = Dir(α * λ(x))
667 | Dirichlet(pc_p * self.alpha)
668 | ).mean()
669 |
670 | reconst_loss_x = -NegBinom(px_rate, torch.exp(px_r)).log_prob(x).sum(-1).mean()
671 | loss_exp = reconst_loss_x.to(device) + \
672 | kl_divergence_u.to(device) + \
673 | kl_divergence_z.to(device) + \
674 | kl_divergence_c.to(device) + \
675 | kl_divergence_l.to(device)
676 |
677 | # Image view
678 | kl_divergence_z_img = kl(
679 |             Normal(qz_m_img, torch.exp(qz_logv_img / 2)),  # std = exp(logv / 2), consistent with the other views
680 | Normal(pz_m, pz_std)
681 | ).sum(dim=1).mean()
682 |
683 | reconst_loss_y = -Normal(py_m, torch.exp(py_logv/2)).log_prob(y_n).sum(-1).mean()
684 | loss_img = reconst_loss_y.to(device) + kl_divergence_z_img.to(device)
685 |
686 |         # PoE total loss: lambda * joint loss + sum of view-specific (marginal) losses
687 |         Loss_IBM = loss_exp + loss_img
688 |         loss = lambda_poe * Loss_IBJ + Loss_IBM
690 |
691 | return (
692 | # Total loss
693 | loss,
694 |
695 | # sum of marginal reconstruction losses
696 | lambda_poe*(reconst_loss_x_poe+reconst_loss_y_poe) + (reconst_loss_x+reconst_loss_y),
697 |
698 | # KL divergence
699 | kl_divergence_u,
700 | lambda_poe*kl_divergence_z_poe + kl_divergence_z + kl_divergence_z_img,
701 | kl_divergence_c,
702 | kl_divergence_l
703 | )
704 |
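    # Summary of `get_loss`: the total loss is a negative ELBO split into
    # interpretable terms, schematically
    #
    #     loss = lambda_poe * [reconst_x_poe + reconst_y_poe + KL(q(z|x,y) || p(z|c,u))]  # joint (PoE) view
    #            + reconst_x + KL_u + KL_z + KL_c + KL_l                                  # expression view
    #            + reconst_y + KL_z_img                                                   # histology view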
705 | @property
706 | def px_r(self):
707 | return F.softplus(self._px_r) + self.eps
708 |
709 | @property
710 | def px_r_poe(self):
711 | return F.softplus(self._px_r_poe) + self.eps
712 |
713 |
714 | def train(
715 | model,
716 | dataloader,
717 | device,
718 | optimizer,
719 | ):
720 | model.train()
721 |
722 | running_loss = 0.0
723 | running_u = 0.0
724 | running_z = 0.0
725 | running_c = 0.0
726 | running_l = 0.0
727 | running_reconst = 0.0
728 | counter = 0
729 | corr_list = []
730 | for i, (x, xs_k, x_peri, library_i) in enumerate(dataloader):
731 |
732 | counter += 1
733 | x = x.float()
734 | x = x.to(device)
735 | xs_k = xs_k.to(device)
736 | x_peri = x_peri.to(device)
737 | library_i = library_i.to(device)
738 |
739 | inference_outputs = model.inference(x)
740 | generative_outputs = model.generative(inference_outputs, xs_k)
741 |
742 |         # Check for NaNs in model parameters (e.g. from a bad initialization)
743 |         if any(torch.isnan(p).any() for p in model.parameters()):
744 |             LOGGER.warning('NaNs detected in model parameters, skipping current batch...')
746 | continue
747 |
748 | (loss,
749 | reconst_loss,
750 | kl_divergence_u,
751 | kl_divergence_z,
752 | kl_divergence_c,
753 | kl_divergence_l
754 | ) = model.get_loss(
755 | generative_outputs,
756 | inference_outputs,
757 | x,
758 | library_i,
759 | device
760 | )
761 |
762 | optimizer.zero_grad()
763 | loss.backward()
764 |
765 | nn.utils.clip_grad_norm_(model.parameters(), 5)
766 | optimizer.step()
767 |
768 | running_loss += loss.item()
769 | running_reconst += reconst_loss.item()
770 | running_u += kl_divergence_u.item()
771 | running_z += kl_divergence_z.item()
772 | running_c += kl_divergence_c.item()
773 | running_l += kl_divergence_l.item()
774 |
775 | train_loss = running_loss / counter
776 | train_reconst = running_reconst / counter
777 | train_u = running_u / counter
778 | train_z = running_z / counter
779 | train_c = running_c / counter
780 | train_l = running_l / counter
781 |
782 | return train_loss, train_reconst, train_u, train_z, train_c, train_l, corr_list
783 |
784 |
785 | def train_poe(
786 | model,
787 | dataloader,
788 | device,
789 | optimizer,
790 | ):
791 | model.train()
792 |
793 | running_loss = 0.0
794 | running_z = 0.0
795 | running_c = 0.0
796 | running_l = 0.0
797 | running_u = 0.0
798 | running_reconst = 0.0
799 | counter = 0
800 | corr_list = []
801 | for i, (x,
802 | x_peri,
803 | library_i,
804 | img,
805 | data_loc,
806 | xs_k,
807 | ) in enumerate(dataloader):
808 | counter += 1
809 | mini_batch, _ = x.shape
810 |
811 | x = x.float()
812 | x = x.to(device)
813 | x_peri = x_peri.to(device)
814 | library_i = library_i.to(device)
815 | xs_k = xs_k.to(device)
816 |
817 | img = img.reshape(mini_batch, -1).float()
818 | img = img.to(device)
819 |
820 | inference_outputs = model.inference(x,img) # inference for 1D expr. data
821 | generative_outputs = model.generative(inference_outputs, xs_k)
822 | img_outputs = model.predictor_img(img) # inference & generative for 2D img. data
823 | poe_outputs = model.predictor_poe(inference_outputs, img_outputs) # PoE generative outputs
824 |
825 |         # Check for NaNs in model parameters (e.g. from a bad initialization)
826 |         if any(torch.isnan(p).any() for p in model.parameters()):
827 |             LOGGER.warning('NaNs detected in model parameters, skipping current batch...')
828 | continue
829 |
830 | (loss,
831 | reconst_loss,
832 | kl_divergence_u,
833 | kl_divergence_z,
834 | kl_divergence_c,
835 | kl_divergence_l
836 | ) = model.get_loss(
837 | generative_outputs,
838 | inference_outputs,
839 | img_outputs,
840 | poe_outputs,
841 | x,
842 | library_i,
843 | img,
844 | device
845 | )
846 |
847 | optimizer.zero_grad()
848 | loss.backward()
849 |
850 | nn.utils.clip_grad_norm_(model.parameters(), 5)
851 | optimizer.step()
852 |
853 | running_loss += loss.item()
854 | running_reconst += reconst_loss.item()
855 | running_z += kl_divergence_z.item()
856 | running_c += kl_divergence_c.item()
857 | running_l += kl_divergence_l.item()
858 | running_u += kl_divergence_u.item()
859 |
860 | train_loss = running_loss / counter
861 | train_reconst = running_reconst / counter
862 | train_z = running_z / counter
863 | train_c = running_c / counter
864 | train_l = running_l / counter
865 | train_u = running_u / counter
866 |
867 | return train_loss, train_reconst, train_u, train_z, train_c, train_l, corr_list
868 |
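# A minimal, hypothetical loop around the training wrappers above; `run_starfysh`
# in utils.py is the supported entry point and additionally handles restarts,
# LR scheduling and model selection. `trainset` stands in for a VisiumDataset /
# VisiumPoEDataSet built elsewhere:
#
#     loader = torch.utils.data.DataLoader(trainset, batch_size=32,
#                                          shuffle=True, drop_last=True)
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#     for epoch in range(100):
#         outputs = train(model, loader, device, optimizer)  # or train_poe(...)
#         loss_tot, loss_reconst, kl_u, kl_z, kl_c, kl_l, _ = outputs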
869 |
870 | # Reference:
871 | # https://github.com/YosefLab/scvi-tools/blob/master/scvi/distributions/_negative_binomial.py
872 | class NegBinom(Distribution):
873 |     """
874 |     Gamma-Poisson mixture parameterization of Negative Binomial(mean, dispersion)
875 | 
876 |     lambda ~ Gamma(concentration=theta, rate=theta / mu)
877 |     x | lambda ~ Poisson(lambda)
878 |     """
879 | arg_constraints = {
880 | 'mu': constraints.greater_than_eq(0),
881 | 'theta': constraints.greater_than_eq(0),
882 | }
883 | support = constraints.nonnegative_integer
884 |
885 | def __init__(self, mu, theta, eps=1e-10):
886 | """
887 | Parameters
888 | ----------
889 | mu : torch.Tensor
890 | mean of NegBinom. distribution
891 | shape - [# genes,]
892 |
893 | theta : torch.Tensor
894 | dispersion of NegBinom. distribution
895 | shape - [# genes,]
896 | """
897 | self.mu = mu
898 | self.theta = theta
899 | self.eps = eps
900 | super(NegBinom, self).__init__(validate_args=True)
901 |
902 | def sample(self):
903 | lambdas = Gamma(
904 | concentration=self.theta + self.eps,
905 | rate=(self.theta + self.eps) / (self.mu + self.eps),
906 | ).rsample()
907 |
908 | x = Poisson(lambdas).sample()
909 |
910 | return x
911 |
912 | def log_prob(self, x):
913 | """log-likelihood"""
914 | ll = torch.lgamma(x + self.theta) - \
915 | torch.lgamma(x + 1) - \
916 | torch.lgamma(self.theta) + \
917 | self.theta * (torch.log(self.theta + self.eps) - torch.log(self.theta + self.mu + self.eps)) + \
918 | x * (torch.log(self.mu + self.eps) - torch.log(self.theta + self.mu + self.eps))
919 |
920 | return ll
921 |
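# A quick moment check for the parameterization above (illustrative): under the
# Gamma-Poisson mixture, E[x] = mu and Var[x] = mu + mu^2 / theta, so larger
# `theta` means weaker overdispersion:
#
#     nb = NegBinom(mu=torch.full((100000,), 5.0), theta=torch.full((100000,), 2.0))
#     x = nb.sample()
#     # x.mean() ~= 5.0;  x.var() ~= 5 + 25/2 = 17.5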
922 |
923 | def model_eval(
924 | model,
925 | adata,
926 | visium_args,
927 | poe=False,
928 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
929 | ):
930 | adata_ = adata.copy()
931 | model.eval()
932 | model = model.to(device)
933 |
934 | x_in = torch.Tensor(adata_.to_df().values).to(device)
935 | sig_means = torch.Tensor(visium_args.sig_mean_norm.values).to(device)
936 | anchor_idx = torch.Tensor(visium_args.pure_idx).to(device)
937 |
938 | with torch.no_grad():
939 |         if not poe:
940 |             inference_outputs = model.inference(x_in)
941 |             generative_outputs = model.generative(inference_outputs, sig_means)
942 |         else:
943 |             img_in = torch.Tensor(visium_args.get_img_patches()).float().to(device)
944 |             inference_outputs = model.inference(x_in, img_in)
945 |             generative_outputs = model.generative(inference_outputs, sig_means)
948 |
949 | img_outputs = model.predictor_img(img_in)
950 | poe_outputs = model.predictor_poe(inference_outputs, img_outputs)
951 |
952 | # Parse image view / PoE inference & generative outputs
953 | # Save to `inference_outputs` & `generative_outputs`
954 | for k, v in img_outputs.items():
955 | if 'q' in k:
956 | inference_outputs[k] = v
957 | else:
958 | generative_outputs[k] = v
959 |
960 | for k, v in poe_outputs.items():
961 | if 'q' in k:
962 | inference_outputs[k+'_poe'] = v
963 | else:
964 | generative_outputs[k+'_poe'] = v
965 |
966 | try:
967 | px = NegBinom(
968 | mu=generative_outputs["px_rate"],
969 | theta=torch.exp(generative_outputs["px_r"])
970 | ).sample().detach().cpu().numpy()
971 | adata_.obsm['px'] = px
972 |     except ValueError:
973 | LOGGER.warning('Invalid Gamma distribution parameters `px_rate` or `px_r`, unable to sample inferred p(x | z)')
974 |
975 | # Save inference & generative outputs in adata
976 | for rv in inference_outputs.keys():
977 | val = inference_outputs[rv].detach().cpu().numpy().squeeze()
978 | if "qu" not in rv and "qs" not in rv:
979 | adata_.obsm[rv] = val
980 | else:
981 | adata_.uns[rv] = val
982 |
983 |     for rv in generative_outputs.keys():
984 |         try:
985 |             val = generative_outputs[rv].data.detach().cpu().numpy().squeeze()
986 |             if rv == 'px_r' or rv == 'px_r_poe':
987 |                 adata_.varm[rv] = val  # gene-level params
988 |             else:
989 |                 adata_.obsm[rv] = val  # spot-level params
990 |         except Exception:
991 |             LOGGER.warning("rv: {} can't be stored".format(rv))
993 |
994 | qz_umap = get_z_umap(adata_.obsm['qz_m'])
995 | adata_.obsm['z_umap'] = qz_umap
996 | return inference_outputs, generative_outputs, adata_
997 |
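# Hypothetical post-training usage of `model_eval` (variable names are
# placeholders; `run_starfysh` is defined in utils.py). The returned AnnData
# carries spot-level posteriors in `.obsm` and global params in `.uns`:
#
#     model, loss = run_starfysh(visium_args, n_repeats=3, epochs=100, poe=False)
#     _, _, adata_out = model_eval(model, adata, visium_args, poe=False)
#     proportions = adata_out.obsm['qc_m']   # [spots x cell types] deconvolution
#     embedding = adata_out.obsm['z_umap']   # 2D UMAP of the latent space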
998 |
999 | def model_eval_integrate(
1000 | model,
1001 | adata,
1002 | visium_args,
1003 | poe=False,
1004 | device=torch.device('cpu')
1005 | ):
1006 | """
1007 | Model evaluation for sample integration
1008 | TODO: code refactor
1009 | """
1010 | model.eval()
1011 | model = model.to(device)
1012 | adata_ = adata.copy()
1013 | x_in = torch.Tensor(adata_.to_df().values).to(device)
1014 | sig_means = torch.Tensor(visium_args.sig_mean_norm.values).to(device)
1015 | anchor_idx = torch.Tensor(visium_args.pure_idx).to(device)
1016 |
1017 | with torch.no_grad():
1018 | if not poe:
1019 | inference_outputs = model.inference(x_in)
1020 | generative_outputs = model.generative(inference_outputs, sig_means)
1021 |         else:
1022 | img_in = torch.Tensor(visium_args.get_img_patches()).float().to(device)
1023 |
1024 | inference_outputs = model.inference(x_in, img_in)
1025 | generative_outputs = model.generative(inference_outputs, sig_means)
1026 |
1027 |             # Rescale image patches to [0, 1] before image-view inference, if needed
1028 |             img_outputs = model.predictor_img(img_in if img_in.max() <= 1 else img_in / 255)
1028 | poe_outputs = model.predictor_poe(inference_outputs, img_outputs)
1029 |
1030 | # Parse image view / PoE inference & generative outputs
1031 | # Save to `inference_outputs` & `generative_outputs`
1032 | for k, v in img_outputs.items():
1033 | if 'q' in k:
1034 | inference_outputs[k] = v
1035 | else:
1036 | generative_outputs[k] = v
1037 |
1038 | for k, v in poe_outputs.items():
1039 | if 'q' in k:
1040 | inference_outputs[k+'_poe'] = v
1041 | else:
1042 | generative_outputs[k+'_poe'] = v
1043 |
1044 |     # TODO: move histology reconstruction to a separate function
1045 |     # (see `get_reconst_img` in utils.py for patch-to-image reconstruction;
1046 |     #  decoded PoE image patches are available as `poe_outputs['py_m']`)
1091 |
1092 | try:
1093 | px = NegBinom(
1094 | mu=generative_outputs["px_rate"],
1095 | theta=torch.exp(generative_outputs["px_r"])
1096 | ).sample().detach().cpu().numpy()
1097 | adata_.obsm['px'] = px
1098 |     except ValueError:
1099 | LOGGER.warning('Invalid Gamma distribution parameters `px_rate` or `px_r`, unable to sample inferred p(x | z)')
1100 |
1101 | # Save inference & generative outputs in adata
1102 | for rv in inference_outputs.keys():
1103 | val = inference_outputs[rv].detach().cpu().numpy().squeeze()
1104 | if "qu" not in rv and "qs" not in rv:
1105 | adata_.obsm[rv] = val
1106 | else:
1107 | adata_.uns[rv] = val
1108 |
1109 |     for rv in generative_outputs.keys():
1110 |         try:
1111 |             val = generative_outputs[rv].data.detach().cpu().numpy().squeeze()
1112 |             if rv == 'px_r' or rv == 'reconstruction':  # gene-level params (posterior avg. znorm signature means)
1113 |                 adata_.varm[rv] = val
1114 |             else:
1115 |                 adata_.obsm[rv] = val  # spot-level params
1116 |         except Exception:
1117 |             LOGGER.warning("rv: {} can't be stored".format(rv))
1119 |
1120 | return inference_outputs, generative_outputs, adata_
1121 |
1122 |
1123 | def model_ct_exp(
1124 | model,
1125 | adata,
1126 | visium_args,
1127 | poe = False,
1128 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
1129 | ):
1130 | """
1131 | Obtain generative cell-type specific expression in each spot (after model training)
1132 | """
1133 | sig_means = torch.Tensor(visium_args.sig_mean_norm.values).to(device)
1134 | anchor_idx = torch.Tensor(visium_args.pure_idx).to(device)
1135 | x_in = torch.Tensor(adata.to_df().values).to(device)
1136 | if poe:
1137 | y_in = torch.Tensor(visium_args.get_img_patches()).float().to(device)
1138 |
1139 |
1140 | model.eval()
1141 | model = model.to(device)
1142 | pred_exprs = {}
1143 |
1144 | for ct_idx, cell_type in enumerate(adata.uns['cell_types']):
1145 | # Get inference outputs for the given cell type
1146 |
1147 | if poe:
1148 | inference_outputs = model.inference(x_in,y_in)
1149 | else:
1150 | inference_outputs = model.inference(x_in)
1151 | inference_outputs['qz'] = inference_outputs['qz_m_ct'][:, ct_idx, :]
1152 |
1153 | # Get generative outputs
1154 | generative_outputs = model.generative(inference_outputs, sig_means)
1155 |
1156 | px = NegBinom(
1157 | mu=generative_outputs["px_rate"],
1158 | theta=torch.exp(generative_outputs["px_r"])
1159 | ).sample()
1160 | px = px.detach().cpu().numpy()
1161 |
1162 | # Save results in adata.obsm
1163 | px_df = pd.DataFrame(px, index=adata.obs_names, columns=adata.var_names)
1164 | pred_exprs[cell_type] = px_df
1165 | adata.obsm[cell_type + '_inferred_exprs'] = px
1166 |
1167 | return pred_exprs
1168 |
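# Illustrative use of `model_ct_exp` after training; dict keys are the cell-type
# names stored in `adata.uns['cell_types']` ('Tcell' below is a placeholder):
#
#     pred_exprs = model_ct_exp(model, adata_out, visium_args, poe=False)
#     tcell_expr = pred_exprs['Tcell']   # pd.DataFrame, [spots x genes]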
1169 |
1170 | def model_ct_img(
1171 | model,
1172 | adata,
1173 | visium_args,
1174 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
1175 | ):
1176 | """
1177 | Obtain generative cell-type specific images (after model training)
1178 | """
1179 | x_in = torch.Tensor(adata.to_df().values).to(device)
1180 | y_in = torch.Tensor(visium_args.get_img_patches()).float().to(device)
1181 |
1182 | model.eval()
1183 | model = model.to(device)
1184 | ct_imgs = {}
1185 | ct_imgs_poe = {}
1186 |
1187 | for ct_idx, cell_type in enumerate(adata.uns['cell_types']):
1188 | # Get inference outputs for the given cell type
1189 |         inference_outputs = model.inference(x_in, y_in)  # AVAE_PoE inference takes both views
1190 | qz_m_ct, qz_logv_ct = inference_outputs['qz_m_ct'][:, ct_idx, :], inference_outputs['qz_logv_ct'][:, ct_idx, :]
1191 | qz_ct = model.reparameterize(qz_m_ct, qz_logv_ct)
1192 | inference_outputs['qz_m'], inference_outputs['qz'] = qz_m_ct, qz_ct
1193 |
1194 |         # Decode images from the cell-type specific latent.
1195 |         # NOTE: `predictor_img` returns a single image posterior shared across
1196 |         # cell types (no `qz_m_ct_img` key), so the expression-side cell-type
1197 |         # latent `qz_ct` is decoded instead.
1198 |         img_outputs = model.predictor_img(y_in)
1199 |         generative_outputs_img = model.generative_img(qz_ct)
1200 |
1201 | poe_outputs = model.predictor_poe(inference_outputs, img_outputs)
1202 |
1203 | ct_imgs[cell_type] = generative_outputs_img['py_m'].detach().cpu().numpy()
1204 | ct_imgs_poe[cell_type] = poe_outputs['py_m'].detach().cpu().numpy()
1205 |
1206 | return ct_imgs, ct_imgs_poe
1207 |
1208 |
--------------------------------------------------------------------------------
/starfysh/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import json
4 | import numpy as np
5 | import pandas as pd
6 | import scanpy as sc
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.optim as optim
11 |
12 | from scipy.stats import median_abs_deviation
13 | from torch.utils.data import DataLoader
14 |
15 | import sys
16 | import histomicstk as htk
17 | from skimage import io
18 |
19 | # Module import
20 | from starfysh import LOGGER
21 | from .dataloader import VisiumDataset, VisiumPoEDataSet
22 | from .starfysh import AVAE, AVAE_PoE, train, train_poe
23 |
24 |
25 | # -------------------
26 | # Model Parameters
27 | # -------------------
28 |
29 | class VisiumArguments:
30 | """
31 |     Load Visium AnnData; perform preprocessing, library-size smoothing & anchor spot detection
32 |
33 | Parameters
34 | ----------
35 | adata : AnnData
36 | annotated visium count matrix
37 |
38 | adata_norm : AnnData
39 | annotated visium count matrix after normalization & log-transform
40 |
41 | gene_sig : pd.DataFrame
42 |         Signature genes for each cell type (dim: [# marker genes, # cell types])
43 |
44 | img_metadata : dict
45 | Spatial information metadata (histology image, coordinates, scalefactor)
46 | """
47 | def __init__(
48 | self,
49 | adata,
50 | adata_norm,
51 | gene_sig,
52 | img_metadata,
53 | **kwargs
54 | ):
55 |
56 | self.adata = adata
57 | self.adata_norm = adata_norm
58 | self.gene_sig = gene_sig
59 | self.map_info = img_metadata['map_info'].iloc[:, :4].astype(float)
60 | self.scalefactor = img_metadata['scalefactor']
61 | self.img = img_metadata['img']
62 | self.img_patches = None
63 | self.eps = 1e-6
64 |
65 | self.params = {
66 | 'sample_id': 'ST',
67 | 'n_anchors': int(adata.shape[0]),
68 | 'patch_r': 16,
69 | 'signif_level': 3,
70 | 'window_size': 1,
71 | 'n_img_chan': 1
72 | }
73 |
74 | # Update parameters for library smoothing & anchor spot identification
75 | for k, v in kwargs.items():
76 | if k in self.params.keys():
77 | self.params[k] = v
78 |
79 | # Center expression for gene score calculation
80 | adata_scale = self.adata_norm.copy()
81 | sc.pp.scale(adata_scale)
82 |
83 | # Store cell types
84 | self.adata.uns['cell_types'] = list(self.gene_sig.columns)
85 |
86 |         # Subset to highly variable & signature genes present in the expression matrix
87 | LOGGER.info('Subsetting highly variable & signature genes ...')
88 | self.adata, self.adata_norm = get_adata_wsig(adata, adata_norm, gene_sig)
89 | self.adata_scale = adata_scale[:, adata.var_names]
90 |
91 | # Update spatial metadata
92 | self._update_spatial_info(self.params['sample_id'])
93 |
94 | # Get smoothed library size
95 |         LOGGER.info('Smoothing library size by averaging over neighboring spots...')
96 | log_lib = np.log1p(self.adata.X.sum(1))
97 | self.log_lib = np.squeeze(np.asarray(log_lib)) if log_lib.ndim > 1 else log_lib
98 |
99 | self.win_loglib = get_windowed_library(self.adata,
100 | self.map_info,
101 | self.log_lib,
102 | window_size=self.params['window_size'])
103 |
104 | # Retrieve & normalize signature gene expressions by
105 | # comparing w/ randomized base expression (`sc.tl.score_genes`)
106 | LOGGER.info('Retrieving & normalizing signature gene expressions...')
107 | self.sig_mean = self._get_sig_mean()
108 | self.sig_mean_norm = self._calc_gene_scores()
109 |
110 | # Get anchor spots
111 |         LOGGER.info('Identifying anchor spots (high expression of specific cell-type signatures)...')
112 | anchor_info = self._compute_anchors()
113 |
114 |         # Clip negative signature scores ("ReLU") & row-normalize for valid Dirichlet params
115 | self.sig_mean_norm[self.sig_mean_norm < 0] = self.eps
116 | self.sig_mean_norm = self.sig_mean_norm.div(self.sig_mean_norm.sum(1), axis=0)
117 | self.sig_mean_norm.fillna(1/self.sig_mean_norm.shape[1], inplace=True)
118 |
119 | self.pure_spots, self.pure_dict, self.pure_idx = anchor_info
120 | del self.adata.raw, self.adata_norm.raw
121 |
122 | def get_adata(self):
123 | """Return adata after preprocessing & HVG gene selection"""
124 | return self.adata, self.adata_norm
125 |
126 | def get_anchors(self):
127 | """Return indices of anchor spots for each cell type"""
128 | anchors_df = pd.DataFrame.from_dict(self.pure_dict, orient='index')
129 | anchors_df = anchors_df.transpose()
130 |
131 | # Check whether empty anchors detected for any factor
132 | empty_indices = np.where(
133 | (~pd.isna(anchors_df)).sum(0) == 0
134 | )[0]
135 |
136 | if len(empty_indices) > 0:
137 | raise ValueError("Cell type(s) {} has no anchors significantly enriched for its signatures,"
138 | "please lower outlier stats `signif_level` or try zscore-based signature".format(
139 | anchors_df.columns[empty_indices].to_list()
140 | ))
141 |
142 | return anchors_df.applymap(
143 | lambda x:
144 |             -1 if pd.isna(x) else np.where(self.adata.obs.index == x)[0][0]  # NaN-safe lookup
145 | )
146 |
147 | def get_img_patches(self):
148 | assert self.img_patches is not None, "Please run Starfysh PoE first"
149 | return self.img_patches
150 |
151 | def append_factors(self, arche_markers):
152 | """
153 | Append list of archetypes (w/ corresponding markers) as additional cell type(s) / state(s) to the `gene_sig`
154 | """
155 | self.gene_sig = pd.concat((self.gene_sig, arche_markers), axis=1)
156 |
157 | # Update factor names & anchor spots
158 | self.adata.uns['cell_types'] = list(self.gene_sig.columns)
159 | self._update_anchors()
160 | return None
161 |
162 | def replace_factors(self, factors_to_repl, arche_markers):
163 | """
164 | Replace factor(s) with archetypes & their corresponding markers in the `gene_sig`
165 | """
166 | if isinstance(factors_to_repl, str):
167 | assert isinstance(arche_markers, pd.Series),\
168 | "Please pick only one archetype to replace the factor {}".format(factors_to_repl)
169 | factors_to_repl = [factors_to_repl]
170 | archetypes = [arche_markers.name]
171 | else:
172 | assert len(factors_to_repl) == len(arche_markers.columns), \
173 | "Unequal # cell types & archetypes to replace with"
174 | archetypes = arche_markers.columns
175 |
176 | self.gene_sig.rename(
177 | columns={
178 | f: a
179 | for (f, a) in zip(factors_to_repl, archetypes)
180 | }, inplace=True
181 | )
182 | self.gene_sig[archetypes] = pd.DataFrame(arche_markers)
183 |
184 | # Update factor names & anchor spots
185 | self.adata.uns['cell_types'] = list(self.gene_sig.columns)
186 | self._update_anchors()
187 | return None
188 |
189 | # --- Private methods ---
190 | def _compute_anchors(self):
191 | """
192 |         Identify the top anchor spots most enriched for each cell type,
193 |         as ranked by signature gene-set scores
194 | """
195 | score_df = self.sig_mean_norm
196 | n_anchor = self.params['n_anchors']
197 |
198 | top_expr_spots = (-score_df.values).argsort(axis=0)[:n_anchor, :]
199 | pure_spots = np.transpose(np.array(score_df.index)[top_expr_spots])
200 |
201 | pure_dict = {
202 | ct: spot
203 | for (spot, ct) in zip(pure_spots, score_df.columns)
204 | }
205 |
206 | pure_indices = np.zeros([score_df.shape[0], 1])
207 | idx = [np.where(score_df.index == i)[0][0]
208 | for i in sorted({x for v in pure_dict.values() for x in v})]
209 | pure_indices[idx] = 1
210 | return pure_spots, pure_dict, pure_indices
211 |
212 | def _update_anchors(self):
213 | """Re-calculate anchor spots given updated gene signatures"""
214 | self.sig_mean = self._get_sig_mean()
215 | self.sig_mean_norm = self._calc_gene_scores()
216 | self.adata.uns['cell_types'] = list(self.gene_sig.columns)
217 |
218 |         LOGGER.info('Recalculating anchor spots (high expression of specific cell-type signatures)...')
219 | anchor_info = self._compute_anchors()
220 |         self.sig_mean_norm[self.sig_mean_norm < 0] = self.eps
221 |         self.sig_mean_norm = self.sig_mean_norm.div(self.sig_mean_norm.sum(1), axis=0)  # row-norm, matching __init__
222 |         self.sig_mean_norm.fillna(1/self.sig_mean_norm.shape[1], inplace=True)
222 | self.pure_spots, self.pure_dict, self.pure_idx = anchor_info
223 |
224 | def _get_sig_mean(self):
225 | sig_mean_expr = pd.DataFrame()
226 | cnt_df = self.adata_norm.to_df()
227 |
228 | # Calculate avg. signature expressions for each cell type
229 | for i, cell_type in enumerate(self.gene_sig.columns):
230 | sigs = np.intersect1d(cnt_df.columns, self.gene_sig.iloc[:, i].astype(str))
231 |
232 | if len(sigs) == 0:
233 | raise ValueError("Empty signatures for {},"
234 | "please double check your `gene_sig` input or set a higher"
235 | "`n_gene` threshold upon dataloading".format(cell_type))
236 |
237 | else:
238 | sig_mean_expr[cell_type] = cnt_df.loc[:, sigs].mean(axis=1)
239 |
240 | sig_mean_expr.index = self.adata.obs_names
241 | sig_mean_expr.columns = self.gene_sig.columns
242 | return sig_mean_expr
243 |
244 | def _update_spatial_info(self, sample_id):
245 | """Update paired spatial information to ST adata"""
246 | # Update image channel count for RGB input (`y`)
247 | if self.img is not None and self.img.ndim == 3:
248 | self.params['n_img_chan'] = 3
249 |
250 | if 'spatial' not in self.adata.uns_keys():
251 | self.adata.uns['spatial'] = {
252 | sample_id: {
253 | 'images': {'hires': (self.img - self.img.min()) / (self.img.max() - self.img.min())},
254 | 'scalefactors': self.scalefactor
255 | },
256 | }
257 |
258 | self.adata_norm.uns['spatial'] = {
259 | sample_id: {
260 | 'images': {'hires': (self.img - self.img.min()) / (self.img.max() - self.img.min())},
261 | 'scalefactors': self.scalefactor
262 | },
263 | }
264 |
265 |         # Spatial coords (typecast to float32)
266 |         self.adata.obsm['spatial'] = self.map_info[['imagecol', 'imagerow']].values.astype(np.float32)
267 |         self.adata_norm.obsm['spatial'] = self.map_info[['imagecol', 'imagerow']].values.astype(np.float32)
271 | return None
272 |
273 | def _update_img_patches(self, dl_poe):
274 | imgs = torch.Tensor(dl_poe.spot_img_stack)
275 | self.img_patches = imgs.reshape(imgs.shape[0], -1)
276 | return None
277 |
278 | def _norm_sig(self):
279 | # col-norm for each cell type: divided by mean
280 | gexp = self.sig_mean.apply(lambda x: x / x.mean(), axis=0)
281 | return gexp
282 |
283 | def _calc_gene_scores(self):
284 | """Calculate gene set enrichment scores for each signature sets"""
285 | adata = self.adata_scale.copy()
286 | #adata = self.adata_norm.copy()
287 | for cell_type in self.gene_sig.columns:
288 | sig = self.gene_sig[cell_type][~pd.isna(self.gene_sig[cell_type])].to_list()
289 | sc.tl.score_genes(adata, sig, use_raw=False, score_name=cell_type+'_score')
290 |
291 | gsea_df = adata.obs[[cell_type+'_score' for cell_type in self.gene_sig.columns]]
292 | gsea_df.columns = self.gene_sig.columns
293 | return gsea_df
294 |
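# A minimal construction sketch for the class above (paths and sample id are
# placeholders; the loaders are defined later in this module):
#
#     adata, adata_norm = load_adata('data/', 'sample_A', n_genes=2000)
#     gene_sig, _ = load_signatures('signatures.csv', adata)
#     img_metadata = preprocess_img('data/', 'sample_A', adata.obs.index)
#     visium_args = VisiumArguments(adata, adata_norm, gene_sig, img_metadata,
#                                   sample_id='sample_A', n_anchors=60, window_size=3)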
295 | # --------------------------------
296 | # Running starfysh with 3-restart
297 | # --------------------------------
298 |
299 | def init_weights(module):
300 | if type(module) == nn.Linear:
301 | nn.init.xavier_uniform_(module.weight)
302 |
303 | elif type(module) == nn.BatchNorm1d:
304 | module.bias.data.zero_()
305 | module.weight.data.fill_(1.0)
306 |
307 |
308 | def run_starfysh(
309 | visium_args,
310 | n_repeats=3,
311 | lr=1e-4,
312 | epochs=100,
313 | batch_size=32,
314 | alpha_mul=50,
315 | poe=False,
316 | device=torch.device('cpu'),
317 | seed=0,
318 | verbose=True
319 | ):
320 | """
321 | Wrapper to run starfysh deconvolution.
322 |
323 | Parameters
324 | ----------
325 | visium_args : VisiumArguments
326 | Preprocessed metadata calculated from input visium matrix:
327 | e.g. mean signature expression, library size, anchor spots, etc.
328 |
329 |     n_repeats : int
330 |         Number of restarts when running Starfysh
331 | 
332 |     epochs : int
333 |         Max. number of training epochs
334 | 
335 |     poe : bool
336 |         Whether to perform PoE inference (histology image integration)
337 |
338 | Returns
339 | -------
340 | best_model : starfysh.AVAE or starfysh.AVAE_PoE
341 | Trained Starfysh model with deconvolution results
342 |
343 | loss : np.ndarray
344 | Training losses
345 | """
346 | np.random.seed(seed)
347 |
348 | # Loading parameters
349 | adata = visium_args.adata
350 | win_loglib = visium_args.win_loglib
351 | gene_sig, sig_mean_norm = visium_args.gene_sig, visium_args.sig_mean_norm
352 |
353 | models = [None] * n_repeats
354 | losses = []
355 | loss_c_list = np.repeat(np.inf, n_repeats)
356 |
357 | if poe:
358 | dl_func = VisiumPoEDataSet # dataloader
359 | train_func = train_poe # training wrapper
360 | else:
361 | dl_func = VisiumDataset
362 | train_func = train
363 |
364 | trainset = dl_func(adata=adata, args=visium_args)
365 | trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, drop_last=True)
366 |
367 | # Running Starfysh with multiple starts
368 | LOGGER.info('Running Starfysh with {} restarts, choose the model with best parameters...'.format(n_repeats))
369 |
370 | count = 0
371 | while count < n_repeats:
372 | best_loss_c = np.inf
373 | if poe:
374 | model = AVAE_PoE(
375 | adata=adata,
376 | gene_sig=sig_mean_norm,
377 | patch_r=visium_args.params['patch_r'],
378 | win_loglib=win_loglib,
379 | alpha_mul=alpha_mul,
380 | n_img_chan=visium_args.params['n_img_chan']
381 | )
382 | # Update patched & flattened image patches
383 | visium_args._update_img_patches(trainset)
384 | else:
385 | model = AVAE(
386 | adata=adata,
387 | gene_sig=sig_mean_norm,
388 | win_loglib=win_loglib,
389 | alpha_mul=alpha_mul
390 | )
391 |
392 | model = model.to(device)
393 | loss_dict = {
394 | 'reconst': [],
395 | 'c': [],
396 | 'u': [],
397 | 'z': [],
398 | 'n': [],
399 | 'tot': []
400 | }
401 |
402 | # Initialize model params
403 | if verbose:
404 | LOGGER.info('Initializing model parameters...')
405 |
406 | model.apply(init_weights)
407 | optimizer = optim.Adam(model.parameters(), lr=lr)
408 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)
409 |
410 | try:
411 | for epoch in range(epochs):
412 | result = train_func(model, trainloader, device, optimizer)
413 | torch.cuda.empty_cache()
414 |
415 | loss_tot, loss_reconst, loss_u, loss_z, loss_c, loss_n, corr_list = result
416 | if loss_c < best_loss_c:
417 | models[count] = model
418 | best_loss_c = loss_c
419 |
420 | torch.cuda.empty_cache()
421 |
422 | loss_dict['tot'].append(loss_tot)
423 | loss_dict['reconst'].append(loss_reconst)
424 | loss_dict['u'].append(loss_u)
425 | loss_dict['z'].append(loss_z)
426 | loss_dict['c'].append(loss_c)
427 | loss_dict['n'].append(loss_n)
428 |
429 | if (epoch + 1) % 10 == 0 and verbose:
430 | LOGGER.info("Epoch[{}/{}], train_loss: {:.4f}, train_reconst: {:.4f}, train_u: {:.4f},train_z: {:.4f},train_c: {:.4f},train_l: {:.4f}".format(
431 | epoch + 1, epochs, loss_tot, loss_reconst, loss_u, loss_z, loss_c, loss_n)
432 | )
433 | scheduler.step()
434 |
435 | losses.append(loss_dict)
436 | loss_c_list[count] = best_loss_c
437 |
438 | count += 1
439 |
440 |         except ValueError:  # Bad model initialization -> numerical instability; retry
441 | continue
442 |
443 | if verbose:
444 | LOGGER.info('Saving the best-performance model...')
445 | LOGGER.info(" === Finished training === \n")
446 |
447 | idx = np.argmin(loss_c_list)
448 | best_model = models[idx]
449 | loss = losses[idx]
450 |
451 | return best_model, loss
452 |
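# An end-to-end sketch with placeholder hyper-parameters (assumes `visium_args`
# was built as shown above):
#
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     model, losses = run_starfysh(visium_args, n_repeats=3, lr=1e-4,
#                                  epochs=100, poe=False, device=device)
#     # losses['tot'] traces the per-epoch total loss of the best restart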
453 |
454 | # -------------------
455 | # Preprocessing & IO
456 | # -------------------
457 |
458 | def preprocess(
459 | adata_raw,
460 | min_perc=None,
461 | max_perc=None,
462 | n_top_genes=2000,
463 | mt_thld=100,
464 | verbose=True,
465 | multiple_data=False
466 | ):
467 | """
468 | Preprocessing ST gexp matrix, remove Ribosomal & Mitochondrial genes
469 |
470 | Parameters
471 | ----------
472 |     adata_raw : AnnData
473 |         Spot x Gene raw expression matrix [S x G]
474 | 
475 |     min_perc : float
476 |         lower-bound percentile of non-zero gexps for filtering spots
477 | 
478 |     max_perc : float
479 |         upper-bound percentile of non-zero gexps for filtering spots
480 | 
481 |     n_top_genes : int
482 |         number of highly variable genes to select
483 | 
484 |     mt_thld : float
485 |         max. percentage of mitochondrial gexps for filtering spots
486 |         with excessive MT expressions
487 | 
488 |     multiple_data : bool
489 |         whether multiple datasets will be integrated (affects normalization target)
490 | """
491 | adata = adata_raw.copy()
492 |
493 |     if min_perc and max_perc:
494 |         assert 0 < min_perc < max_perc < 100, \
495 |             "Invalid thresholds for cells: {0}, {1}".format(min_perc, max_perc)
496 |         min_counts = np.percentile(adata.obs['total_counts'], min_perc)
497 |         max_counts = np.percentile(adata.obs['total_counts'], max_perc)
498 |         sc.pp.filter_cells(adata, min_counts=min_counts)
499 |         sc.pp.filter_cells(adata, max_counts=max_counts)
498 |
499 | # Remove cells with excessive MT expressions
500 | # Remove MT & RB genes
501 | if verbose:
502 |         LOGGER.info('Preprocessing1: remove MT & RP genes')
503 |
504 | adata.var['mt'] = np.logical_or(
505 | adata.var_names.str.startswith('MT-'),
506 | adata.var_names.str.startswith('mt-')
507 | )
508 | adata.var['rb'] = adata.var_names.str.startswith(('RP', 'Rp', 'rp'))
509 |
510 | sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], inplace=True)
511 | mask_cell = adata.obs['pct_counts_mt'] < mt_thld
512 | mask_gene = np.logical_and(~adata.var['mt'], ~adata.var['rb'])
513 |
514 | adata = adata[mask_cell, mask_gene]
515 | sc.pp.filter_genes(adata, min_cells=1)
516 |
517 | # Normalize & take log-transform
518 | if verbose:
519 | LOGGER.info('Preprocessing2: Normalize')
520 | if multiple_data:
521 | sc.pp.normalize_total(adata, target_sum=1e6, inplace=True)
522 | else:
523 | sc.pp.normalize_total(adata, inplace=True)
524 |
525 | # Preprocessing3: Logarithm
526 | if verbose:
527 | LOGGER.info('Preprocessing3: Logarithm')
528 | sc.pp.log1p(adata)
529 |
530 | # Preprocessing4: Find the variable genes
531 | if verbose:
532 | LOGGER.info('Preprocessing4: Find the variable genes')
533 | sc.pp.highly_variable_genes(adata, flavor='seurat', n_top_genes=n_top_genes, inplace=True)
534 |
535 | # Filter corresponding `obs` & `var` in raw-count matrix
536 | adata_raw = adata_raw[adata.obs_names, adata.var_names]
537 | adata_raw.var['highly_variable'] = adata.var['highly_variable']
538 | adata_raw.obs = adata.obs
539 |
540 | return adata_raw, adata
541 |
542 |
543 | def load_adata(data_folder, sample_id, n_genes, multiple_data=False):
544 | """
545 | load visium adata with raw counts, preprocess & extract highly variable genes
546 |
547 | Parameters
548 | ----------
549 | data_folder : str
550 | Root directory of the data
551 |
552 | sample_id : str
553 | Sample subdirectory under `data_folder`
554 |
555 |     n_genes : int
556 |         number of top highly variable genes to retain for training
557 | 
558 |     multiple_data : bool
559 |         whether the study includes multiple datasets to integrate
560 |
561 | Returns
562 | -------
563 | adata : sc.AnnData
564 | Processed ST raw counts
565 |
566 | adata_norm : sc.AnnData
567 | Processed ST normalized & log-transformed data
568 | """
569 | has_feature_h5 = os.path.isfile(
570 | os.path.join(data_folder, sample_id, 'filtered_feature_bc_matrix.h5')
571 | ) # whether dataset stored in h5 with spatial info.
572 |
573 | if has_feature_h5:
574 | adata = sc.read_visium(path=os.path.join(data_folder, sample_id), library_id=sample_id)
575 | adata.var_names_make_unique()
576 | adata.obs['sample'] = sample_id
577 | elif sample_id.startswith('simu'): # simulations
578 | adata = sc.read_csv(os.path.join(data_folder, sample_id, 'counts.st_synth.csv'))
579 | else:
580 | filenames = [
581 | f[:-5] for f in os.listdir(os.path.join(data_folder, sample_id))
582 | if f[-5:] == '.h5ad'
583 | ]
584 |         assert len(filenames) == 1, \
585 |             "Found zero or multiple `.h5ad` files in the data directory; " \
586 |             "please keep exactly one target ST file in the given directory"
587 | adata = sc.read_h5ad(os.path.join(data_folder, sample_id, filenames[0] + '.h5ad'))
588 | adata.var_names_make_unique()
589 | adata.obs['sample'] = sample_id
590 |
591 | if '_index' in adata.var.columns:
592 | adata.var_names = adata.var['_index']
593 | adata.var_names.name = 'Genes'
594 | adata.var.drop('_index', axis=1, inplace=True)
595 |
596 | adata, adata_norm = preprocess(adata, n_top_genes=n_genes, multiple_data=multiple_data)
597 | return adata, adata_norm
598 |
599 |
600 | def load_signatures(filename, adata):
601 | """
602 | load annotated signature gene sets
603 |
604 | Parameters
605 | ----------
606 | filename : str
607 | Signature file
608 |
609 | adata : sc.AnnData
610 | ST count matrix
611 |
612 |     Returns
613 |     -------
614 |     gene_sig : pd.DataFrame
615 |         signatures per cell type / state
616 | 
617 |     sigs : np.ndarray
618 |         unique signature gene names
619 |     """
617 | assert os.path.isfile(filename), "Unable to find the signature file"
618 | gene_sig = pd.read_csv(filename, index_col=0)
619 | gene_sig = filter_gene_sig(gene_sig, adata.to_df())
620 | sigs = np.unique(
621 | gene_sig.apply(
622 | lambda x:
623 | pd.unique(x[~pd.isna(x)])
624 | ).values
625 | )
626 |
627 |     return gene_sig, sigs  # `sigs` is already unique
628 |
629 |
630 | def preprocess_img(
631 | data_path,
632 | sample_id,
633 | adata_index,
634 | rgb_channels=True
635 | ):
636 | """
637 | Load and preprocess visium paired H&E image & spatial coords
638 |
639 | Parameters
640 | ----------
641 | data_path : str
642 | Root directory of the data
643 |
644 |     sample_id : str
645 |         Sample subdirectory under `data_path`
646 | 
647 |     adata_index : pd.Index
648 |         Spot barcodes used to align the spatial coordinates
649 | 
650 |     rgb_channels : bool
651 |         Whether to keep all RGB channels; if False, apply color deconvolution
652 |         to extract the 1D hematoxylin (H) channel. Please refer to:
653 |         https://digitalslidearchive.github.io/HistomicsTK/examples/color_deconvolution.html
651 |
652 | Returns
653 | -------
654 | img : np.ndarray
655 | Processed histology image
656 |
657 | map_info : np.ndarray
658 | Spatial coords of spots (dim: [S, 2])
659 | """
660 | filename = os.path.join(data_path, sample_id, 'spatial', 'tissue_hires_image.png')
661 |     if os.path.isfile(filename):
662 |         img = io.imread(filename)
663 |         if not rgb_channels:
664 |             # Deconvolve stains & keep the 1D hematoxylin (H) channel
665 |             if img.max() <= 1:
666 |                 img = (img * 255).astype(np.uint8)
670 |
671 | # Create stain matrix
672 | stains = ['hematoxylin','eosin', 'null']
673 | stain_cmap = htk.preprocessing.color_deconvolution.stain_color_map
674 | W = np.array([stain_cmap[st] for st in stains]).T
675 |
676 | # Color deconvolution
677 | imDeconvolved = htk.preprocessing.color_deconvolution.color_deconvolution(img, W)
678 | img = imDeconvolved.Stains[:,:,0] # H-channel
679 |
680 | # Take inverse of H-channel (approx. cell density)
681 | img = (img - img.min()) / (img.max()-img.min())
682 | img = img.max() - img
683 | img = (img*255).astype(np.uint8)
684 | else:
685 | img = None
686 |
687 | # Mapping images to location
688 |     with open(os.path.join(data_path, sample_id, 'spatial', 'scalefactors_json.json')) as f:
689 |         json_info = json.load(f)
691 |
692 | tissue_position_list = pd.read_csv(os.path.join(data_path, sample_id, 'spatial', 'tissue_positions_list.csv'), header=None, index_col=0)
693 | tissue_position_list = tissue_position_list.loc[adata_index, :]
694 | map_info = tissue_position_list.iloc[:, -4:-2]
695 | map_info.columns = ['array_row', 'array_col']
696 | map_info.loc[:, 'imagerow'] = tissue_position_list.iloc[:, -2]
697 | map_info.loc[:, 'imagecol'] = tissue_position_list.iloc[:, -1]
698 | map_info.loc[:, 'sample'] = sample_id
699 |
700 | return {
701 | 'img': img,
702 | 'map_info': map_info,
703 | 'scalefactor': json_info
704 | }
705 |
706 |
707 | def get_adata_wsig(adata, adata_norm, gene_sig):
708 | """
709 | Select intersection of HVGs from dataset & signature annotations
710 | """
711 | # TODO: in-place operators for `adata`
712 | hvgs = adata.var_names[adata.var.highly_variable]
713 | unique_sigs = np.unique(gene_sig.values[~pd.isna(gene_sig)])
714 | genes_to_keep = np.union1d(
715 | hvgs,
716 | np.intersect1d(adata.var_names, unique_sigs)
717 | )
718 | return adata[:, genes_to_keep], adata_norm[:, genes_to_keep]
719 |
720 |
721 | def filter_gene_sig(gene_sig, adata_df):
722 | for i in range(gene_sig.shape[0]):
723 | for j in range(gene_sig.shape[1]):
724 | gene = gene_sig.iloc[i, j]
725 | if gene in adata_df.columns:
726 | # We don't filter signature genes based on expression level (prev: threshold=20)
727 | if adata_df.loc[:, gene].sum() < 0:
728 | gene_sig.iloc[i, j] = 'NaN'
729 | return gene_sig
730 |
731 |
732 | def get_umap(adata_sample, display=False):
733 | sc.tl.pca(adata_sample, svd_solver='arpack')
734 | sc.pp.neighbors(adata_sample, n_neighbors=15, n_pcs=40)
735 | sc.tl.umap(adata_sample, min_dist=0.2)
736 | if display:
737 | sc.pl.umap(adata_sample)
738 | umap_plot = pd.DataFrame(adata_sample.obsm['X_umap'],
739 | columns=['umap1', 'umap2'],
740 | index=adata_sample.obs_names)
741 | return umap_plot
742 |
743 |
744 | def get_simu_map_info(umap_plot):
745 | map_info = []
746 | map_info = [-umap_plot['umap2'] * 10, umap_plot['umap1'] * 10]
747 | map_info = pd.DataFrame(np.transpose(map_info),
748 | columns=['array_row', 'array_col'],
749 | index=umap_plot.index)
750 | return map_info
751 |
752 |
753 | def get_windowed_library(adata_sample, map_info, library, window_size):
754 | library_n = []
755 |
756 | for i in adata_sample.obs_names:
758 | dist_arr = np.sqrt(
759 | (map_info.loc[:, 'array_col'] - map_info.loc[i, 'array_col']) ** 2 +
760 | (map_info.loc[:, 'array_row'] - map_info.loc[i, 'array_row']) ** 2
761 | )
762 |
763 | library_n.append(library[dist_arr < window_size].mean())
764 | library_n = np.array(library_n)
765 | return library_n
766 |
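# Note: the smoothing above is a plain moving average - each spot's log-library
# size is replaced by the mean over all spots whose Euclidean distance (in array
# coordinates) falls below `window_size`.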
767 |
768 | def append_sigs(gene_sig, factor, sigs, n_genes=10):
769 | """
770 | Append list of genes to a given cell type as additional signatures or
771 | add novel cell type / states & their signatures
772 | """
773 | assert len(sigs) > 0, "Signature list must have positive length"
774 | gene_sig_new = gene_sig.copy()
775 |
776 | if not isinstance(sigs, list):
777 | sigs = sigs.to_list()
778 | if n_genes < len(sigs):
779 | sigs = sigs[:n_genes]
780 |
781 | markers = set([i for i in gene_sig[factor] if str(i) != 'nan'] +
782 | [i for i in sigs if str(i) != 'nan'])
783 | nrow_diff = int(np.abs(len(markers)-gene_sig_new.shape[0]))
784 | if len(markers) > gene_sig_new.shape[0]:
785 | df_dummy = pd.DataFrame([[np.nan] * gene_sig.shape[1]] * nrow_diff,
786 | columns=gene_sig_new.columns)
787 | gene_sig_new = pd.concat([gene_sig_new, df_dummy], ignore_index=True)
788 | else:
789 | markers = list(markers) + [np.nan]*nrow_diff
790 | gene_sig_new[factor] = list(markers)
791 |
792 | return gene_sig_new
793 |
794 |
795 | def refine_anchors(
796 | visium_args,
797 | aa_model,
798 | anchor_threshold=0.1,
799 | n_genes=10
800 | ):
801 | """
802 |     Refine anchor spots & marker genes with archetypal analysis: append DEGs
803 |     computed from archetypes to their best-matched anchors, then re-compute
804 |     the anchor spots
805 |
806 | Parameters
807 | ----------
808 |     visium_args : VisiumArguments
809 | Default parameter set for Starfysh upon dataloading
810 |
811 | aa_model : ArchetypalAnalysis
812 | Pre-computed archetype object
813 |
814 | anchor_threshold : float
815 | Top percent of anchor spots per cell-type
816 | for archetypal mapping
817 |
818 | n_genes : int
819 |         Number of archetypal marker genes to append per refinement iteration
820 |
821 | Returns
822 | -------
823 |     visium_args : VisiumArguments
824 | updated parameter set for Starfysh
825 | """
826 | # TODO: integrate into `visium_args` class
827 |
828 | n_spots = visium_args.adata.shape[0]
829 | gene_sig = visium_args.gene_sig.copy()
830 | anchors_df = visium_args.get_anchors()
831 | n_top_anchors = int(anchor_threshold*n_spots)
832 |
833 | # Retrieve anchor-archetype mapping scores
834 | map_df, map_dict = aa_model.assign_archetypes(anchor_df=anchors_df[:n_top_anchors],
835 | r=n_top_anchors)
836 | markers_df = aa_model.find_markers(display=False)
837 |
838 | # (1). Update signatures
839 | for cell_type, archetype in map_dict.items():
840 | gene_sig = append_sigs(gene_sig=gene_sig,
841 | factor=cell_type,
842 | sigs=markers_df[archetype],
843 | n_genes=n_genes)
844 |
845 | # (2). Update data args.
846 | visium_args.gene_sig = gene_sig
847 | visium_args._update_anchors()
848 | return visium_args
849 |
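# Illustrative refinement loop (`aa_model` would come from the archetypal
# analysis module `starfysh.AA`, not shown here):
#
#     for _ in range(3):   # a few refinement rounds
#         visium_args = refine_anchors(visium_args, aa_model,
#                                      anchor_threshold=0.1, n_genes=10)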
850 |
851 | # -------------------
852 | # Post-processing
853 | # -------------------
854 |
855 | def extract_feature(adata, key):
856 | """
857 | Extract generative / inference output from adata.obsm
858 | generate dummy tmp. adata for plotting
859 | """
860 |     assert key in adata.obsm.keys(), "Starfysh generative / inference output not found in `adata.obsm`: {}".format(key)
861 |
862 | if key == 'qc_m':
863 | cols = adata.uns['cell_types'] # cell type deconvolution
864 | elif key == 'qz_m':
865 | cols = ['z'+str(i) for i in range(adata.obsm[key].shape[1])] # inferred qz (low-dim manifold)
866 | elif '_inferred_exprs' in key:
867 | cols = adata.var_names # inferred cell-type specific expressions
868 | else:
869 | cols = ['density']
870 | adata_dummy = adata.copy()
871 | adata_dummy.obs = pd.DataFrame(adata.obsm[key], index=adata.obs.index, columns=cols)
872 | return adata_dummy
873 |
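# `extract_feature` pairs naturally with scanpy's spatial plotting; a hedged
# sketch (assumes `adata_out` from `model_eval`):
#
#     adata_plot = extract_feature(adata_out, 'qc_m')
#     sc.pl.spatial(adata_plot, color=adata_plot.obs.columns.to_list(), ncols=4)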
874 |
875 | def get_reconst_img(args, img_patches):
876 | """
877 | Reconst original histology image (H x W) from the given patched image (S x P)
878 | """
879 | reconst_img = np.zeros_like(args.img, dtype=np.float64)
880 | r = args.params['patch_r']
881 | patch_size = (r * 2, r * 2, 3) if args.img.ndim == 3 else (r * 2, r * 2)
882 | scale_factor = args.scalefactor['tissue_hires_scalef']
883 | img_col = args.map_info['imagecol'] * scale_factor
884 | img_row = args.map_info['imagerow'] * scale_factor
885 |
886 | for i in range(len(img_col)):
887 | patch_y = slice(int(img_row[i]) - r, int(img_row[i]) + r)
888 | patch_x = slice(int(img_col[i]) - r, int(img_col[i]) + r)
889 |
890 | sy, sx = reconst_img[patch_y, patch_x].shape[:2]
891 | img_patch = img_patches[i].reshape(patch_size)
892 | reconst_img[patch_y, patch_x] = img_patch[:sy, :sx] # edge patch cases
893 |
894 | return reconst_img
895 |
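# Hypothetical use with PoE outputs: the decoded image patches (`py_m`) are
# stored in `.obsm` by `model_eval(..., poe=True)` in starfysh.py:
#
#     reconst = get_reconst_img(visium_args, adata_out.obsm['py_m'])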
896 |
897 |
--------------------------------------------------------------------------------
/starfysh/utils_integrate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import json
4 | import numpy as np
5 | import pandas as pd
6 | import scanpy as sc
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.optim as optim
11 |
12 | from scipy.stats import median_abs_deviation
13 | from torch.utils.data import DataLoader
14 |
15 | import sys
16 | from skimage import io
17 |
18 | # Module import
19 | from starfysh import LOGGER
20 | from .dataloader import IntegrativeDataset, IntegrativePoEDataset
21 | from .starfysh import AVAE, AVAE_PoE, train, train_poe
22 |
23 | import logging
24 | import torch.nn.functional as F
25 | from torch.utils.data import Dataset
29 |
30 |
31 | class VisiumArguments_integrate:
32 | """
33 |     Load Visium AnnData for multi-sample integration; perform preprocessing, library-size smoothing & anchor spot detection
34 |
35 | Parameters
36 | ----------
37 | adata : AnnData
38 | annotated visium count matrix
39 |
40 | adata_norm : AnnData
41 | annotated visium count matrix after normalization & log-transform
42 |
43 | gene_sig : pd.DataFrame
44 | list of signature genes for each cell type. (dim: [S, Cell_type])
45 |
46 | img_metadata : dict
47 | Spatial information metadata (histology image, coordinates, scalefactor)
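
    individual_args : dict
        Per-sample argument objects (keyed by sample id) providing precomputed
        signature means (`sig_mean` & `sig_mean_norm`)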
48 | """
49 | def __init__(
50 | self,
51 | adata,
52 | adata_norm,
53 | gene_sig,
54 | img_metadata,
55 | individual_args,
56 | **kwargs
57 | ):
58 |
59 | self.adata = adata
60 | self.adata_norm = adata_norm
61 | self.gene_sig = gene_sig
62 | self.eps = 1e-6
63 |
64 |         self.params = {
65 |             'sample_id': ['ST'],  # list of sample ids; typically overridden via kwargs
66 |             'n_anchors': int(adata.shape[0]),
67 |             'patch_r': 13,
68 |             'signif_level': 3,
69 |             'window_size': 1,
70 |             'n_img_chan': 1
71 |         }
72 |
73 | for k, v in kwargs.items():
74 | if k in self.params.keys():
75 | self.params[k] = v
76 |
77 |
78 | map_info_temp_all = []
79 | for i in self.params['sample_id']:
80 | map_info_temp = img_metadata[i]['map_info'].iloc[:, :4].astype(float)
81 | map_info_temp_all.append(map_info_temp)
82 | self.map_info = pd.concat(map_info_temp_all)
83 |
84 | img_temp = {}
85 | for i in self.params['sample_id']:
86 | img_temp[i] = img_metadata[i]['img']
87 | self.img = img_temp
88 |
89 | self.img_patches = None
90 |
91 | scalefactor_temp = {}
92 | for i in self.params['sample_id']:
93 | scalefactor_temp[i] = img_metadata[i]['scalefactor']
94 | self.scalefactor = scalefactor_temp
95 |
99 | # Center expression for gene score calculation
100 | adata_scale = self.adata_norm.copy()
101 | sc.pp.scale(adata_scale)
102 |
103 | # Store cell types
104 | self.adata.uns['cell_types'] = list(self.gene_sig.columns)
105 |
106 |         # Subset to highly variable & signature genes present in the expression matrix
107 | LOGGER.info('Subsetting highly variable & signature genes ...')
108 | self.adata, self.adata_norm = get_adata_wsig(adata, adata_norm, gene_sig)
109 | self.adata_scale = adata_scale[:, adata.var_names]
110 |
111 | # Update spatial metadata
112 | self._update_spatial_info(self.params['sample_id'])
113 |
114 | # Get smoothed library size
115 |         LOGGER.info('Smoothing library size by averaging over neighboring spots...')
116 | log_lib = np.log1p(self.adata.X.sum(1))
117 | self.log_lib = np.squeeze(np.asarray(log_lib)) if log_lib.ndim > 1 else log_lib
118 |
119 | win_loglib_temp_all = []
120 | for i in self.params['sample_id']:
121 | win_loglib_temp = get_windowed_library(self.adata[self.adata.obs['sample']==i],
122 | self.map_info[self.adata.obs['sample']==i],
123 | self.log_lib[self.adata.obs['sample']==i],
124 | window_size=self.params['window_size']
125 | )
126 | win_loglib_temp_all.append(pd.DataFrame(win_loglib_temp))
127 |
128 | self.win_loglib = pd.concat(win_loglib_temp_all)
129 | self.win_loglib = np.array(self.win_loglib)
130 |
131 | # Retrieve & normalize signature gexp
132 | LOGGER.info('Retrieving & normalizing signature gene expressions...')
133 | sig_mean_temp = []
134 | for i in individual_args.keys():
135 | sig_mean_temp.append(individual_args[i].sig_mean)
136 |
137 | self.sig_mean = pd.concat(sig_mean_temp,axis=0)
138 |
139 | sig_mean_norm_temp = []
140 | for i in individual_args.keys():
141 | sig_mean_norm_temp.append(individual_args[i].sig_mean_norm)
142 |
143 | self.sig_mean_norm = pd.concat(sig_mean_norm_temp,axis=0)
144 |
145 | # Get anchor spots
146 |         LOGGER.info('Identifying anchor spots (spots highly expressing specific cell-type signatures)...')
147 | anchor_info = self._compute_anchors()
148 |
149 | self.pure_spots, self.pure_dict, self.pure_idx = anchor_info
150 | del self.adata.raw, self.adata_norm.raw
151 |
152 | def get_adata(self):
153 | """Return adata after preprocessing & HVG gene selection"""
154 | return self.adata, self.adata_norm
155 |
156 | def get_anchors(self):
157 | """Return indices of anchor spots for each cell type"""
158 | anchors_df = pd.DataFrame.from_dict(self.pure_dict, orient='index')
159 | anchors_df = anchors_df.transpose()
160 |
161 | # Check whether empty anchors detected for any factor
162 | empty_indices = np.where(
163 | (~pd.isna(anchors_df)).sum(0) == 0
164 | )[0]
165 |
166 | if len(empty_indices) > 0:
167 |             raise ValueError("Cell type(s) {} have no anchors significantly enriched for their signatures; "
168 |                              "please lower the outlier threshold `signif_level`".format(
169 | anchors_df.columns[empty_indices].to_list()
170 | ))
171 |
172 |         return anchors_df.applymap(
173 |             lambda x:
174 |             -1 if pd.isna(x) else np.where(self.adata.obs.index == x)[0][0]  # -1 marks missing anchors
175 |         )
176 |
177 | def get_img_patches(self):
178 | assert self.img_patches is not None, "Please run Starfysh PoE first"
179 | return self.img_patches
180 |
181 | def append_factors(self, arche_markers):
182 | """
183 | Append list of archetypes (w/ corresponding markers) as additional cell type(s) / state(s) to the `gene_sig`
184 | """
185 | self.gene_sig = pd.concat((self.gene_sig, arche_markers), axis=1)
186 |
187 | # Update factor names & anchor spots
188 | self.adata.uns['cell_types'] = list(self.gene_sig.columns)
189 | self._update_anchors()
190 | return None
191 |
192 | def replace_factors(self, factors_to_repl, arche_markers):
193 | """
194 | Replace factor(s) with archetypes & their corresponding markers in the `gene_sig`
195 | """
196 | if isinstance(factors_to_repl, str):
197 | assert isinstance(arche_markers, pd.Series),\
198 | "Please pick only one archetype to replace the factor {}".format(factors_to_repl)
199 | factors_to_repl = [factors_to_repl]
200 | archetypes = [arche_markers.name]
201 | else:
202 | assert len(factors_to_repl) == len(arche_markers.columns), \
203 | "Unequal # cell types & archetypes to replace with"
204 | archetypes = arche_markers.columns
205 |
206 | self.gene_sig.rename(
207 | columns={
208 | f: a
209 | for (f, a) in zip(factors_to_repl, archetypes)
210 | }, inplace=True
211 | )
212 | self.gene_sig[archetypes] = pd.DataFrame(arche_markers)
213 |
214 | # Update factor names & anchor spots
215 | self.adata.uns['cell_types'] = list(self.gene_sig.columns)
216 | self._update_anchors()
217 | return None
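
    # A minimal usage sketch (illustrative; assumes `markers_df` holds archetypal
    # marker genes, e.g. returned by an archetypal-analysis model's `find_markers()`):
    #
    #   visium_args.append_factors(markers_df[['arch_0']])            # add a new state
    #   visium_args.replace_factors('Unknown', markers_df['arch_1'])  # swap one factor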
218 |
219 | # --- Private methods ---
220 | def _compute_anchors(self):
221 | """
222 | Calculate top `anchor_spots` significantly enriched for given cell type(s)
223 | determined by gene set scores from signatures
224 | """
225 | score_df = self.sig_mean_norm
226 | signif_level = self.params['signif_level']
227 | n_anchor = self.params['n_anchors']
228 |
229 |         top_expr_spots = (-score_df.values).argsort(axis=0)[:n_anchor, :]  # rank spots per cell type by descending score
230 | pure_spots = np.transpose(np.array(score_df.index)[top_expr_spots])
231 |
232 | pure_dict = {
233 | ct: spot
234 | for (spot, ct) in zip(pure_spots, score_df.columns)
235 | }
236 |
237 |         pure_indices = np.zeros([score_df.shape[0], 1])  # binary indicator: 1 = anchor for >= 1 cell type
238 | idx = [np.where(score_df.index == i)[0][0]
239 | for i in sorted({x for v in pure_dict.values() for x in v})]
240 | pure_indices[idx] = 1
241 | return pure_spots, pure_dict, pure_indices
242 |
243 | def _update_anchors(self):
244 | """Re-calculate anchor spots given updated gene signatures"""
245 | self.sig_mean = self._get_sig_mean()
246 | self.sig_mean_norm = self._calc_gene_scores()
247 | self.adata.uns['cell_types'] = list(self.gene_sig.columns)
248 |
249 |         LOGGER.info('Recalculating anchor spots (spots highly expressing specific cell-type signatures)...')
250 | anchor_info = self._compute_anchors()
251 | self.sig_mean_norm[self.sig_mean_norm < 0] = self.eps
252 | self.sig_mean_norm.fillna(1/self.sig_mean_norm.shape[1], inplace=True)
253 | self.pure_spots, self.pure_dict, self.pure_idx = anchor_info
254 |
255 | def _get_sig_mean(self):
256 | sig_mean_expr = pd.DataFrame()
257 | cnt_df = self.adata_norm.to_df()
258 |
259 | # Calculate avg. signature expressions for each cell type
260 | for i, cell_type in enumerate(self.gene_sig.columns):
261 | sigs = np.intersect1d(cnt_df.columns, self.gene_sig.iloc[:, i].astype(str))
262 |
263 | if len(sigs) == 0:
264 |                 raise ValueError("Empty signatures for {}; "
265 |                                  "please double-check your `gene_sig` input or set a higher "
266 |                                  "`n_gene` threshold upon data loading".format(cell_type))
267 |
268 | else:
269 | sig_mean_expr[cell_type] = cnt_df.loc[:, sigs].mean(axis=1)
270 |
271 | sig_mean_expr.index = self.adata.obs_names
272 | sig_mean_expr.columns = self.gene_sig.columns
273 | return sig_mean_expr
274 |
275 | def _update_spatial_info(self, sample_id):
276 | """Update paired spatial information to ST adata"""
277 | # Update image channel count for RGB input (`y`)
278 |         if self.img is not None and self.img[list(sample_id)[0]].ndim == 3:  # robust to list / pd.Series sample ids
279 | self.params['n_img_chan'] = 3
280 |
281 | if 'spatial' not in self.adata.uns_keys():
282 | self.adata.uns['spatial'] = {
283 | i: {
284 | 'images': {'hires': (self.img[i] - self.img[i].min()) / (self.img[i].max() - self.img[i].min())},
285 |                     'scalefactors': self.scalefactor[i]  # per-sample scalefactors, matching `adata_norm` below
286 | } for i in sample_id
287 | }
288 |
289 | self.adata_norm.uns['spatial'] = {
290 | i: {
291 | 'images': {'hires': (self.img[i] - self.img[i].min()) / (self.img[i].max() - self.img[i].min())},
292 | 'scalefactors': self.scalefactor[i]
293 | } for i in sample_id
294 | }
295 |         self.adata.obsm['spatial'] = self.map_info[['imagecol', 'imagerow']].values
296 |         self.adata_norm.obsm['spatial'] = self.map_info[['imagecol', 'imagerow']].values.astype(np.float32)  # typecast spatial coords.
301 | return None
302 |
303 | def _update_img_patches(self, dl_poe):
304 | dl_poe.spot_img_stack = np.array(dl_poe.spot_img_stack)
305 | dl_poe.spot_img_stack = dl_poe.spot_img_stack.reshape(dl_poe.spot_img_stack.shape[0], -1)
306 | imgs = torch.Tensor(dl_poe.spot_img_stack)
307 | self.img_patches = imgs
308 | return None
309 |
310 | def _norm_sig(self):
311 | # col-norm for each cell type: divided by mean
312 | gexp = self.sig_mean.apply(lambda x: x / x.mean(), axis=0)
313 | return gexp
314 |
315 | def _calc_gene_scores(self):
316 | """Calculate gene set enrichment scores for each signature sets"""
317 | adata = self.adata_scale.copy()
318 | #adata = self.adata_norm.copy()
319 | for cell_type in self.gene_sig.columns:
320 | sig = self.gene_sig[cell_type][~pd.isna(self.gene_sig[cell_type])].to_list()
321 | sc.tl.score_genes(adata, sig, score_name=cell_type+'_score',use_raw=False)
322 |
323 | gsea_df = adata.obs[[cell_type+'_score' for cell_type in self.gene_sig.columns]]
324 | gsea_df.columns = self.gene_sig.columns
325 | return gsea_df
326 |
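# A minimal construction sketch (illustrative; sample ids and the per-sample
# argument objects in `individual_args` are assumptions):
#
#   visium_args = VisiumArguments_integrate(
#       adata, adata_norm, gene_sig, img_metadata,
#       individual_args=individual_args,   # {sample_id: per-sample arguments}
#       sample_id=['sample_A', 'sample_B'],
#       window_size=3,
#   )
#   adata, adata_norm = visium_args.get_adata()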
327 |
328 | def get_adata_wsig(adata, adata_norm, gene_sig):
329 | """
330 | Select intersection of HVGs from dataset & signature annotations
331 | """
332 | hvgs = adata.var_names[adata.var.highly_variable]
333 | unique_sigs = np.unique(gene_sig.values[~pd.isna(gene_sig)])
334 | genes_to_keep = np.union1d(
335 | hvgs,
336 | np.intersect1d(adata.var_names, unique_sigs)
337 | )
338 | return adata[:, genes_to_keep], adata_norm[:, genes_to_keep]
339 |
340 |
341 | def get_windowed_library(adata_sample, map_info, library, window_size):
342 |     """
343 |     Smooth per-spot (log-)library sizes by averaging over all spots within
344 |     `window_size` (in array coordinates) of each spot.
345 |     """
346 |     library_n = []
347 |     for i in adata_sample.obs_names:
348 |         # Euclidean distance (in array coords.) from spot `i` to every spot
349 |         dist_arr = np.sqrt(
350 |             (map_info.loc[:, 'array_col'] - map_info.loc[i, 'array_col']) ** 2 +
351 |             (map_info.loc[:, 'array_row'] - map_info.loc[i, 'array_row']) ** 2
352 |         )
353 |         library_n.append(library[dist_arr < window_size].mean())
354 |     return np.array(library_n)
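
# A toy illustration of the windowing (illustrative; assumes an AnnData `adata_toy`
# whose `obs_names` are ['a', 'b', 'c']):
#
#   mi = pd.DataFrame({'array_col': [0., 1., 5.], 'array_row': [0., 0., 0.]},
#                     index=['a', 'b', 'c'])
#   lib = np.log1p(np.array([10., 20., 30.]))
#   get_windowed_library(adata_toy, mi, lib, window_size=2)
#   # -> 'a' and 'b' average each other; 'c' has no neighbor within 2 and keeps its own value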
355 |
356 |
357 | def init_weights(module):
358 | if type(module) == nn.Linear:
359 | torch.nn.init.kaiming_uniform_(module.weight)
360 |
361 | elif type(module) == nn.BatchNorm1d:
362 | module.bias.data.zero_()
363 | module.weight.data.fill_(1.0)
364 |
365 | def run_starfysh(
366 | visium_args,
367 | n_repeats=3,
368 | lr=1e-4,
369 | epochs=100,
370 | batch_size=32,
371 | alpha_mul=50,
372 | poe=False,
373 | device=torch.device('cpu'),
374 |     verbose=True,
376 | ):
377 | """
378 |     Wrapper to run Starfysh deconvolution.
379 |
380 | Parameters
381 | ----------
382 | visium_args : VisiumArguments
383 | Preprocessed metadata calculated from input visium matrix:
384 | e.g. mean signature expression, library size, anchor spots, etc.
385 |
386 |     n_repeats : int
387 |         Number of restarts to run Starfysh
388 |
389 |     epochs : int
390 |         Max. number of training epochs
391 |
392 |     poe : bool
393 |         Whether to perform inference with the PoE model (histology image integration)
394 |
395 | Returns
396 | -------
397 | best_model : starfysh.AVAE or starfysh.AVAE_PoE
398 | Trained Starfysh model with deconvolution results
399 |
400 | loss : np.ndarray
401 | Training losses
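
    Example
    -------
    A minimal sketch (illustrative; assumes `visium_args` was constructed upstream)::

        best_model, loss = run_starfysh(visium_args, n_repeats=3, epochs=100,
                                        poe=False, device=torch.device('cpu'))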
402 | """
403 | np.random.seed(0)
404 |
405 | # Loading parameters
406 | adata = visium_args.adata
407 | win_loglib = visium_args.win_loglib
408 | gene_sig, sig_mean_norm = visium_args.gene_sig, visium_args.sig_mean_norm
409 |
410 | models = [None] * n_repeats
411 | losses = []
412 | loss_c_list = np.repeat(np.inf, n_repeats)
413 |
414 | if poe:
415 | dl_func = IntegrativePoEDataset # dataloader
416 | train_func = train_poe # training wrapper
417 | else:
418 | dl_func = IntegrativeDataset
419 | train_func = train
420 |
421 | trainset = dl_func(adata=adata, args=visium_args)
422 | trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, drop_last=True)
423 |
424 | # Running Starfysh with multiple starts
425 |     LOGGER.info('Running Starfysh with {} restarts, choosing the model with the best parameters...'.format(n_repeats))
426 | for i in range(n_repeats):
427 | if verbose:
428 | LOGGER.info(" === Restart Starfysh {0} === \n".format(i + 1))
429 | best_loss_c = np.inf
430 |
432 |         if poe:
434 | model = AVAE_PoE(
435 | adata=adata,
436 | gene_sig=sig_mean_norm,
437 | patch_r=visium_args.params['patch_r'],
438 | win_loglib=win_loglib,
439 | alpha_mul=alpha_mul,
440 | n_img_chan=visium_args.params['n_img_chan']
441 | )
442 | # Update patched & flattened image patches
443 | visium_args._update_img_patches(trainset)
444 | else:
445 | model = AVAE(
446 | adata=adata,
447 | gene_sig=sig_mean_norm,
448 | win_loglib=win_loglib,
449 | alpha_mul=alpha_mul
450 | )
451 |
452 | model = model.to(device)
453 | loss_dict = {
454 | 'reconst': [],
455 | 'c': [],
456 | 'u': [],
457 | 'z': [],
458 | 'n': [],
459 | 'tot': []
460 | }
461 |
462 | # Initialize model params
463 | if verbose:
464 | LOGGER.info('Initializing model parameters...')
465 |
466 | model.apply(init_weights)
467 | optimizer = optim.Adam(model.parameters(), lr=lr)
468 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
469 |
470 | for epoch in range(epochs):
471 | result = train_func(model, trainloader, device, optimizer)
472 | torch.cuda.empty_cache()
473 |
474 | loss_tot, loss_reconst, loss_u, loss_z, loss_c, loss_n, corr_list = result
475 | if loss_c < best_loss_c:
476 | models[i] = model
477 | best_loss_c = loss_c
478 |
479 | torch.cuda.empty_cache()
480 |
481 | loss_dict['tot'].append(loss_tot)
482 | loss_dict['reconst'].append(loss_reconst)
483 | loss_dict['u'].append(loss_u)
484 | loss_dict['z'].append(loss_z)
485 | loss_dict['c'].append(loss_c)
486 | loss_dict['n'].append(loss_n)
487 |
488 | if (epoch + 1) % 10 == 0 and verbose:
489 | LOGGER.info("Epoch[{}/{}], train_loss: {:.4f}, train_reconst: {:.4f}, train_u: {:.4f},train_z: {:.4f},train_c: {:.4f},train_n: {:.4f}".format(
490 | epoch + 1, epochs, loss_tot, loss_reconst, loss_u, loss_z, loss_c, loss_n)
491 | )
492 | scheduler.step()
493 |
494 | losses.append(loss_dict)
495 | loss_c_list[i] = best_loss_c
496 | if verbose:
497 |         LOGGER.info('Saving the best-performing model...')
498 | LOGGER.info(" === Finished training === \n")
499 |
500 | idx = np.argmin(loss_c_list)
501 | best_model = models[idx]
502 | loss = losses[idx]
503 |
504 | return best_model, loss
505 |
--------------------------------------------------------------------------------
/tests/test_modules.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import importlib
3 |
4 | class TestModules(unittest.TestCase):
5 | def test_import_modules(self):
6 | modules = [
7 | "AA",
8 | "dataloader",
9 | "plot_utils",
10 | "post_analysis",
11 | "starfysh",
12 | "utils",
13 | "utils_integrate",
14 | ]
15 |
16 | for module_name in modules:
17 | try:
18 | importlib.import_module("starfysh." + module_name)
19 | print(f"Module {module_name} imported successfully.")
20 | except ImportError as e:
21 | self.fail(f"Failed to import module {module_name}: {e}")
22 |
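# Run from the repository root, e.g.: python -m unittest tests.test_modules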
23 | if __name__ == '__main__':
24 | unittest.main()
25 |
--------------------------------------------------------------------------------