├── .gitignore ├── .readthedocs.yaml ├── LICENSE ├── README.md ├── data └── README.md ├── dev-requirements.txt ├── docs ├── Makefile ├── make.bat └── source │ ├── conf.py │ ├── index.rst │ ├── installation.rst │ ├── intro.rst │ ├── notebooks │ ├── Starfysh_tutorial_real.ipynb │ ├── Starfysh_tutorial_simulation.ipynb │ └── update_tutorial.sh │ └── starfysh.rst ├── figure ├── github_figure_1.png ├── github_figure_2.png └── logo.png ├── notebooks ├── Starfysh_tutorial_integration.ipynb ├── Starfysh_tutorial_real-with_poe.ipynb ├── Starfysh_tutorial_real_without_poe.ipynb ├── Starfysh_tutorial_real_wo_signatures.ipynb ├── Starfysh_tutorial_simulation.ipynb ├── generate_image.ipynb └── slideseq_starfysh_tutorial_on_later.ipynb ├── reinstall.sh ├── requirements.txt ├── setup.py ├── starfysh ├── AA.py ├── __init__.py ├── dataloader.py ├── gener_img.py ├── plot_utils.py ├── post_analysis.py ├── starfysh.py ├── utils.py └── utils_integrate.py └── tests └── test_modules.py /.gitignore: -------------------------------------------------------------------------------- 1 | # OS generated files 2 | .DS_Store 3 | 4 | # Compiled 5 | __pycache__/ 6 | *.py[cod] 7 | 8 | # R code 9 | .Rdata 10 | .Rhistory 11 | 12 | # Packaging 13 | build/ 14 | dist/ 15 | *.egg 16 | *.egg-info/ 17 | lib/ 18 | lib64/ 19 | bdist.linux-x86_64/ 20 | MANIFEST 21 | 22 | # Environments 23 | .env 24 | env/ 25 | .venv 26 | /venv/ 27 | 28 | # Others 29 | .csv 30 | .idea 31 | .ipynb_checkpoints 32 | data*/ 33 | **/results/ 34 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Required 2 | version: 2 3 | 4 | # Set the version of Python and other tools you might need 5 | build: 6 | os: ubuntu-20.04 7 | tools: 8 | python: "3.8" 9 | 10 | # Build documentation in the docs/ directory with Sphinx 11 | sphinx: 12 | configuration: docs/source/conf.py 13 | builder: html 14 | 15 | # If using Sphinx, optionally build your docs in additional formats such as PDF 16 | # formats: 17 | # - pdf 18 | 19 | # Optionally declare the Python requirements required to build your docs 20 | python: 21 | install: 22 | - requirements: requirements.txt 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, azizilab 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | [![Documentation Status](https://readthedocs.org/projects/starfysh/badge/?version=latest)](https://starfysh.readthedocs.io/en/latest/)
4 | [![License: BSD 3-Clause](https://img.shields.io/github/license/azizilab/starfysh)](https://github.com/azizilab/starfysh/blob/master/LICENSE)
5 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1a_mxF6Ot5vA_xzr5EaNz-GllYY-pXhwA)
6 |
7 | ## Starfysh: Spatial Transcriptomic Analysis using Reference-Free deep generative modeling with archetYpes and Shared Histology
8 |
9 | Starfysh is an end-to-end toolbox for the analysis and integration of Spatial Transcriptomic (ST) datasets. In summary, the Starfysh framework enables reference-free deconvolution of cell types and cell states, and can be improved with the integration of paired histology images of tissues, if available. Starfysh is capable of integrating data from multiple tissues. In particular, Starfysh identifies common or sample-specific spatial “hubs” with unique compositions of cell types. To uncover mechanisms underlying local and long-range communication, Starfysh can be used to perform downstream analysis on the spatial organization of hubs.
10 |
11 |
12 |
13 |
14 |
15 | ### Quickstart tutorials
16 | - [1. Basic deconvolution on an example breast cancer dataset with pre-defined signatures, without PoE (dataset & signature files included).](notebooks/Starfysh_tutorial_real_without_poe.ipynb)
17 | - [2. Histology integration & deconvolution without pre-defined signatures.](notebooks/Starfysh_tutorial_real_wo_signatures.ipynb)
18 | - [3. Basic deconvolution on an example breast cancer dataset with pre-defined signatures, with histology image integration (PoE) (dataset & signature files included).](notebooks/Starfysh_tutorial_real-with_poe.ipynb)
19 | - [4. Multi-sample integration, with archetypes and PoE.](notebooks/Starfysh_tutorial_integration.ipynb)
20 |
21 |
22 | Please refer to the [Starfysh documentation](http://starfysh.readthedocs.io) for additional tutorials & APIs.
23 |
24 | ### Installation
25 | GitHub-version installation:
26 | ```bash
27 | # Step 1: Clone the repository
28 | git clone https://github.com/azizilab/starfysh.git
29 |
30 | # Step 2: Navigate to the repository
31 | cd starfysh
32 |
33 | # Step 3: Install the package
34 | pip install .
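# (Optional) Step 4: a minimal smoke test to confirm the install worked.
# This check is illustrative and not part of the original instructions:
python -c "import starfysh; print('Starfysh imported successfully')"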
35 | ```
36 |
37 |
38 | ### Model Input:
39 | - Spatial Transcriptomics count matrix
40 | - Annotated signature gene sets (see [example](https://drive.google.com/file/d/1AXWQy_mwzFEKNjAdrJjXuegB3onxJoOM/view?usp=share_link))
41 | - (Optional): paired H&E image
42 |
43 | ### Features:
44 | - Deconvolving cell types & discovering novel, unannotated cell states
45 | - Integration with paired histology images & multi-sample integration
46 | - Downstream analysis: spatial hub identification, cell-type colocalization networks & receptor-ligand (R-L) interactions
47 |
48 | ### Directories
49 |
50 | ```
51 | .
52 | ├── data: Spatial Transcriptomics & synthetic simulation datasets
53 | ├── notebooks: Sample tutorial notebooks
54 | ├── starfysh: Starfysh core model
55 | ```
56 |
57 | ### How to cite Starfysh
58 | Please cite the [Starfysh paper published in Nature Biotechnology](https://www.nature.com/articles/s41587-024-02173-8#citeas):
59 | ```
60 | He, S., Jin, Y., Nazaret, A. et al.
61 | Starfysh integrates spatial transcriptomic and histologic data to reveal heterogeneous tumor–immune hubs.
62 | Nat Biotechnol (2024).
63 | https://doi.org/10.1038/s41587-024-02173-8
64 | ```
65 |
66 | ### BibTeX
67 | ```
68 | @article{He2024,
69 |   title = {Starfysh integrates spatial transcriptomic and histologic data to reveal heterogeneous tumor–immune hubs},
70 |   ISSN = {1546-1696},
71 |   url = {http://dx.doi.org/10.1038/s41587-024-02173-8},
72 |   DOI = {10.1038/s41587-024-02173-8},
73 |   journal = {Nature Biotechnology},
74 |   publisher = {Springer Science and Business Media LLC},
75 |   author = {He, Siyu and Jin, Yinuo and Nazaret, Achille and Shi, Lingting and Chen, Xueer and Rampersaud, Sham and Dhillon, Bahawar S. and Valdez, Izabella and Friend, Lauren E. and Fan, Joy Linyue and Park, Cameron Y. and Mintz, Rachel L. and Lao, Yeh-Hsing and Carrera, David and Fang, Kaylee W. and Mehdi, Kaleem and Rohde, Madeline and McFaline-Figueroa, José L. and Blei, David and Leong, Kam W. and Rudensky, Alexander Y. and Plitas, George and Azizi, Elham},
76 |   year = {2024},
77 |   month = mar
78 | }
79 | ```
80 |
81 | If you have questions, please contact the authors:
82 |
83 | - Siyu He - sh3846@columbia.edu
84 | - Yinuo Jin - yj2589@columbia.edu
85 |
86 |
87 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | ### Description
2 |
3 | The following public datasets from the Wu *et al.* (2021) study (**A single-cell and spatially resolved atlas of human breast cancers**) are downloaded from Zenodo ([link](https://zenodo.org/record/4739739#.Y4ohMbLMKxs)):
4 |
5 | - 1142243F_TNBC
6 | - 1160920F_TNBC
7 | - CID4465_TNBC
8 | - CID44971_TNBC
9 | - CID4535_ER
10 | - CID4290_ER
11 |
12 | **Reference**:
13 | Wu, S. Z., Al-Eryani, G., Roden, D. L., Junankar, S., Harvey, K., Andersson, A., ... & Swarbrick, A. (2021). A single-cell and spatially resolved atlas of human breast cancers. Nature genetics, 53(9), 1334-1347. 14 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | ipykernel>=6.17.0 2 | nbsphinx>=0.8.9 3 | sphinx>=4.2.0 4 | sphinxcontrib-apidoc>=4.2.0 -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
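# (Project-specific note: two directory levels up from docs/source/ is the
# repository root, so adding it to sys.path lets autodoc `import starfysh`.)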
12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('../..')) 16 | 17 | # -- Project information ----------------------------------------------------- 18 | 19 | project = 'Starfysh' 20 | copyright = '2022, Siyu He, Yinuo Jin @ Azizi Lab' 21 | author = 'Siyu He, Yinuo Jin' 22 | 23 | # The full version, including alpha/beta/rc tags 24 | release = '1.0.0' 25 | 26 | 27 | # -- General configuration --------------------------------------------------- 28 | 29 | # Add any Sphinx extension module names here, as strings. They can be 30 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 31 | # ones. 32 | extensions = [ 33 | 'nbsphinx', 34 | 'sphinx.ext.autodoc', 35 | 'sphinx.ext.viewcode', 36 | 'sphinx.ext.napoleon' 37 | ] 38 | 39 | # Add any paths that contain templates here, relative to this directory. 40 | templates_path = ['_templates'] 41 | 42 | # List of patterns, relative to source directory, that match files and 43 | # directories to ignore when looking for source files. 44 | # This pattern also affects html_static_path and html_extra_path. 45 | exclude_patterns = [] 46 | 47 | 48 | # -- Options for HTML output ------------------------------------------------- 49 | 50 | # The theme to use for HTML and HTML Help pages. See the documentation for 51 | # a list of builtin themes. 52 | # 53 | html_theme = 'sphinx_rtd_theme' 54 | 55 | # Add any paths that contain custom static files (such as style sheets) here, 56 | # relative to this directory. They are copied after the builtin static files, 57 | # so a file named "default.css" will overwrite the builtin "default.css". 58 | html_static_path = ['_static'] 59 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. Starfysh documentation master file, created by 2 | sphinx-quickstart on Tue Nov 1 02:30:02 2022. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to Starfysh's documentation! 7 | ==================================== 8 | .. toctree:: 9 | :maxdepth: 3 10 | :caption: Basics: 11 | 12 | intro 13 | installation 14 | 15 | .. toctree:: 16 | :maxdepth: 3 17 | :caption: Examples: 18 | 19 | notebooks/Starfysh_tutorial_simulation.ipynb 20 | notebooks/Starfysh_tutorial_real.ipynb 21 | 22 | .. toctree:: 23 | :maxdepth: 3 24 | :caption: API Reference: 25 | 26 | starfysh 27 | 28 | Indices and tables 29 | ================== 30 | 31 | * :ref:`genindex` 32 | * :ref:`modindex` 33 | * :ref:`search` 34 | 35 | -------------------------------------------------------------------------------- /docs/source/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ************ 3 | .. code-block:: bash 4 | 5 | pip install Starfysh 6 | 7 | Quickstart 8 | ********** 9 | 10 | .. 
code-block:: python
11 |
12 |     import os
13 |     import numpy as np
14 |     import pandas as pd
15 |     import torch
16 |
17 |     from starfysh import (AA, utils, plot_utils, post_analysis)
18 |     from starfysh import starfysh as sf_model
19 |
20 |     # (1) Loading dataset & signature gene sets
21 |     data_path = 'data/'  # specify data directory
22 |     sig_path = 'signature/signatures.csv'  # specify signature directory
23 |     sample_id = 'SAMPLE_ID'
24 |
25 |     # --- (a) ST matrix ---
26 |     adata, adata_normed = utils.load_adata(
27 |         data_path,
28 |         sample_id,
29 |         n_genes=2000
30 |     )
31 |
32 |     # --- (b) paired H&E image + spots info ---
33 |     img_metadata = utils.preprocess_img(
34 |         data_path,
35 |         sample_id,
36 |         adata_index=adata.obs.index,
37 |         hchannel=False
38 |     )
39 |
40 |     # --- (c) signature gene sets ---
41 |     gene_sig = utils.filter_gene_sig(
42 |         pd.read_csv(sig_path),
43 |         adata.to_df()
44 |     )
45 |
46 |     # (2) Starfysh deconvolution
47 |
48 |     # --- (a) Preparing arguments for model training ---
49 |     visium_args = utils.VisiumArguments(adata,
50 |                                         adata_normed,
51 |                                         gene_sig,
52 |                                         img_metadata,
53 |                                         n_anchors=60,
54 |                                         window_size=3,
55 |                                         sample_id=sample_id
56 |                                         )
57 |
58 |     adata, adata_normed = visium_args.get_adata()
59 |     anchors_df = visium_args.get_anchors()
60 |
61 |     # --- (b) Model training ---
62 |     n_repeats = 3
63 |     epochs = 200
64 |     patience = 50
65 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
66 |
67 |     # Run models
68 |     model, loss = utils.run_starfysh(
69 |         visium_args,
70 |         n_repeats=n_repeats,
71 |         epochs=epochs,
72 |         patience=patience,
73 |         device=device
74 |     )
75 |
76 |     # (3) Parse deconvolution outputs
77 |     inference_outputs, generative_outputs = sf_model.model_eval(
78 |         model,
79 |         adata,
80 |         visium_args,
81 |         poe=False,
82 |         device=device
83 |     )
84 |
85 |
--------------------------------------------------------------------------------
/docs/source/intro.rst:
--------------------------------------------------------------------------------
1 | Overview
2 | ========
3 | Starfysh is an end-to-end toolbox for the analysis and integration of ST datasets.
4 | In summary, the Starfysh framework consists of reference-free deconvolution of cell types and cell states, which can be improved with the integration of paired histology images of tissues, if available. To facilitate comparison of tissues between healthy or disease contexts and to derive differential spatial patterns, Starfysh is capable of integrating data from multiple tissues and further identifies common or sample-specific spatial “hubs”, defined as neighborhoods with a unique composition of cell types. To uncover mechanisms underlying local and long-range communication, Starfysh performs downstream analysis on the spatial organization of hubs and identifies critical genes with spatially varying patterns as well as cell-cell interaction networks.
5 |
6 |
7 | Features
8 | ********
9 |
10 | * Deconvolving cell types / cell states
11 | * Discovering and learning novel cell states
12 | * Integration with histology images and multi-sample integration
13 | * Downstream analysis: spatial hub identification, cell-type colocalization networks & receptor-ligand (R-L) interactions
14 |
15 | Model Specifications
16 | ********************
17 |
18 | Starfysh performs cell-type deconvolution followed by various downstream analyses to discover spatial interactions in the tumor microenvironment.
19 | The core deconvolution model is based on a semi-supervised Auxiliary Variational Autoencoder (AVAE).
We further provide optional Archetypal Analysis (AA) & Product-of-Experts (PoE) for cell-type annotation and H&E image integration to further aid deconvolution.
20 | Specifically, Starfysh looks for *anchor spots*, the presumed purest spots with the highest proportion of a given cell type guided by signatures, and further deconvolves the remaining spots. Starfysh provides the following options:
21 |
22 | **Base feature**:
23 |
24 | * Auxiliary Variational AutoEncoder (AVAE):
25 |   Spot-level deconvolution with expected cell types and corresponding annotated *signature* gene sets (default)
26 |
27 | **Optional**:
28 |
29 | * Archetypal Analysis (AA):
30 |   If signature is not provided:
31 |
32 |   * Unsupervised cell type annotation
33 |
34 |   If signature is provided:
35 |
36 |   * Novel cell type / cell state discovery (complementary to known cell types from the *signatures*)
37 |   * Refine known marker genes by appending archetype-specific differentially expressed genes, and update anchor spots accordingly
38 |
39 | * Product-of-Experts (PoE) integration:
40 |   Multi-modal integrative predictions with *expression* & *histology image* by leveraging additional side information (e.g., cell density) from the H&E image.
41 |
42 |
43 | I/O
44 | ***
45 | - Input:
46 |
47 |   - Spatial Transcriptomics count matrix
48 |   - Annotated signature gene sets (`see example <https://drive.google.com/file/d/1AXWQy_mwzFEKNjAdrJjXuegB3onxJoOM/view?usp=share_link>`_)
49 |   - (Optional): paired H&E image
50 |
51 | - Output:
52 |
53 |   - Spot-wise deconvolution matrix (`q(c)`)
54 |   - Low-dimensional manifold representation (`q(z)`)
55 |   - Spatial hubs (in-sample or multi-sample integration)
56 |   - Co-localization networks across cell types and spatial receptor-ligand (R-L) interactions
57 |   - Reconstructed count matrix (`p(x)`)
58 |
59 |
60 |
--------------------------------------------------------------------------------
/docs/source/notebooks/Starfysh_tutorial_real.ipynb:
--------------------------------------------------------------------------------
1 | ../../../notebooks/Starfysh_tutorial_real.ipynb
--------------------------------------------------------------------------------
/docs/source/notebooks/Starfysh_tutorial_simulation.ipynb:
--------------------------------------------------------------------------------
1 | ../../../notebooks/Starfysh_tutorial_simulation.ipynb
--------------------------------------------------------------------------------
/docs/source/notebooks/update_tutorial.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ln -sf ../../../notebooks/Starfysh_tutorial_simulation.ipynb Starfysh_tutorial_simulation.ipynb
4 | ln -sf ../../../notebooks/Starfysh_tutorial_real.ipynb Starfysh_tutorial_real.ipynb
5 |
--------------------------------------------------------------------------------
/docs/source/starfysh.rst:
--------------------------------------------------------------------------------
1 | starfysh package
2 | ================
3 |
4 | Submodules
5 | ----------
6 |
7 | starfysh.AA module
8 | ------------------
9 |
10 | .. automodule:: starfysh.AA
11 |    :members:
12 |    :undoc-members:
13 |    :show-inheritance:
14 |
15 | starfysh.dataloader module
16 | --------------------------
17 |
18 | .. automodule:: starfysh.dataloader
19 |    :members:
20 |    :undoc-members:
21 |    :show-inheritance:
22 |
23 | starfysh.plot\_utils module
24 | ---------------------------
25 |
26 | ..
automodule:: starfysh.plot_utils
27 |    :members:
28 |    :undoc-members:
29 |    :show-inheritance:
30 |
31 | starfysh.post\_analysis module
32 | ------------------------------
33 |
34 | .. automodule:: starfysh.post_analysis
35 |    :members:
36 |    :undoc-members:
37 |    :show-inheritance:
38 |
39 | starfysh.starfysh module
40 | ------------------------
41 |
42 | .. automodule:: starfysh.starfysh
43 |    :members:
44 |    :undoc-members:
45 |    :show-inheritance:
46 |
47 | starfysh.utils module
48 | ---------------------
49 |
50 | .. automodule:: starfysh.utils
51 |    :members:
52 |    :undoc-members:
53 |    :show-inheritance:
54 |
55 | Module contents
56 | ---------------
57 |
58 | .. automodule:: starfysh
59 |    :members:
60 |    :undoc-members:
61 |    :show-inheritance:
62 |
--------------------------------------------------------------------------------
/figure/github_figure_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/azizilab/starfysh/56fb01ef734401d067eb3078280dd805b97621d1/figure/github_figure_1.png
--------------------------------------------------------------------------------
/figure/github_figure_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/azizilab/starfysh/56fb01ef734401d067eb3078280dd805b97621d1/figure/github_figure_2.png
--------------------------------------------------------------------------------
/figure/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/azizilab/starfysh/56fb01ef734401d067eb3078280dd805b97621d1/figure/logo.png
--------------------------------------------------------------------------------
/reinstall.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | pip uninstall starfysh
4 | python setup.py install --user
5 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | histomicstk>=1.2.0
2 | matplotlib>=3.4.2
3 | networkx>=2.6.3
4 | numba>=0.56.4
5 | numpy>=1.23.5
6 | opencv_python_headless>=4.5.1.48
7 | pandas>=1.3.3
8 | py_pcha>=0.1.3
9 | scanpy>=1.9.2
10 | scikit_dimension>=0.3
11 | scikit_image>=0.19.2
12 | scikit_learn>=1.2.1
13 | scipy>=1.7.1
14 | seaborn>=0.11.2
15 | setuptools>=50.3.2
16 | threadpoolctl>=3.1.0
17 | torch>=2.0.0
18 | torchvision>=0.15.1
19 | tqdm>=4.62.2
20 | umap_learn>=0.5.3
21 | nbsphinx
22 | ipykernel
23 | sphinx
24 | sphinxcontrib-apidoc
25 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | with open("requirements.txt", 'r') as ifile:
4 |     requirements = ifile.read().splitlines()
5 |
6 | nb_requirements = [
7 |     'nbconvert>=6.1.0',
8 |     'nbformat>=5.1.3',
9 |     'notebook>=6.4.11',
10 |     'jupyter>=7.0.0',
11 |     'jupyterlab>=3.4.3',
12 |     'ipython>=7.27.0',
13 | ]
14 |
15 | setup(
16 |     name="Starfysh",
17 |     version="1.2.0",
18 |     description="Spatial Transcriptomic Analysis using Reference-Free auxiliarY deep generative modeling and Shared Histology",
19 |     author="Siyu He, Yinuo Jin, Achille Nazaret",
20 |     url="https://starfysh.readthedocs.io",
21 |     packages=find_packages(),
22 |     python_requires='>=3.7',
23 |     install_requires=requirements,
24 |     zip_safe=False,
25 |     classifiers=[
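        # Trove classifiers below describe package metadata only
        # (Python 3 support, BSD license, OS independence):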
26 | "Programming Language :: Python :: 3", 27 | "License :: OSI Approved :: BSD License", 28 | "Operating System :: OS Independent", 29 | ], 30 | 31 | dependency_links=[ 32 | 'https://girder.github.io/large_image_wheels' 33 | ] 34 | 35 | # extras_require={ 36 | # 'notebooks': nb_requirements, 37 | # 'dev': open('dev-requirements.txt').read().splitlines(), 38 | # } 39 | ) 40 | 41 | -------------------------------------------------------------------------------- /starfysh/AA.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import scanpy as sc 5 | import umap 6 | import skdim 7 | import matplotlib.pyplot as plt 8 | import seaborn as sns 9 | 10 | from py_pcha import PCHA 11 | from matplotlib.pyplot import cm 12 | from scipy.spatial.distance import cdist, euclidean 13 | from sklearn.neighbors import NearestNeighbors 14 | from starfysh import LOGGER 15 | 16 | 17 | class ArchetypalAnalysis: 18 | # Todo: add assertion that input `adata` should be raw counts to ensure math assumptions 19 | def __init__( 20 | self, 21 | adata_orig, 22 | r=100, 23 | u=None, 24 | u_3d=None, 25 | verbose=True, 26 | outdir=None, 27 | filename=None, 28 | savefig=False, 29 | ): 30 | """ 31 | Parameters 32 | ---------- 33 | adata_orig : sc.AnnData 34 | ST raw count matrix 35 | 36 | r : int 37 | Resolution parameter to control granularity of major archetypes 38 | If two archetypes reside within r nearest neighbors, the latter 39 | one will be merged. 40 | """ 41 | 42 | # Check adata raw counts 43 | 44 | 45 | self.adata = adata_orig.copy() 46 | 47 | # Perform dim. reduction with PCA, select the first 30 PCs 48 | if 'X_umap' in self.adata.obsm_keys(): 49 | del self.adata.obsm['X_umap'] 50 | sc.pp.normalize_total(self.adata, target_sum=1e6) 51 | sc.pp.pca(self.adata, n_comps=30) 52 | 53 | self.r = r # granularity parameter 54 | self.count = self.adata.obsm['X_pca'] 55 | self.n_spots = self.count.shape[0] 56 | 57 | self.verbose = verbose 58 | self.outdir = outdir 59 | self.filename = filename 60 | self.savefig = savefig 61 | 62 | self.archetype = None 63 | self.major_archetype = None 64 | self.major_idx = None 65 | self.arche_dict = None 66 | self.arche_df = None 67 | self.kmin = 0 68 | 69 | self.U = u 70 | self.U_3d = u_3d 71 | 72 | def compute_archetypes( 73 | self, 74 | cn=30, 75 | n_iters=20, 76 | converge=1e-3, 77 | display=False 78 | ): 79 | """ 80 | Estimate the upper bound of archetype count (k) by calculating intrinsic dimension 81 | Compute hierarchical archetypes (major + raw) with given granularity 82 | 83 | Parameters 84 | ---------- 85 | cn : int 86 | Conditional Number to choose PCs for intrinsic estimator as 87 | lower bound # archetype estimation. Please refer to: 88 | https://scikit-dimension.readthedocs.io/en/latest/skdim.id.FisherS.html#skdim.id.FisherS 89 | 90 | n_iters : int 91 | Max. 
# iterations of AA to find the best k estimation
93 |
94 |         converge : float
95 |             Convergence criteria for AA iteration with diff(explained variance)
96 |
97 |         display : bool
98 |             Whether to display Intrinsic Dimension (ID) estimation plots
99 |
100 |         Returns
101 |         -------
102 |         archetype : np.ndarray (dim=[K, G])
103 |             Raw archetypes as linear combination of subset of spot counts
104 |
105 |         arche_dict : dict
106 |             Hierarchical structure of major_archetype -> its fine-grained neighbor archetypes
107 |
108 |         major_idx : np.ndarray
109 |             Indices of major archetypes among `k` raw candidates after merging
110 |
111 |         evs : list
112 |             Explained variance with different Ks
113 |         """
114 |
115 |         # TMP: across-sample comparison: fix # principal components for all samples
116 |
117 |         if self.verbose:
118 |             LOGGER.info('Computing intrinsic dimension to estimate k...')
119 |
120 |         # Estimate ID
121 |         id_model = skdim.id.FisherS(conditional_number=cn,
122 |                                     produce_plots=display,
123 |                                     verbose=self.verbose)
124 |
125 |         self.kmin = max(1, int(id_model.fit(self.count).dimension_))
126 |
127 |         # Compute raw archetypes
128 |         if self.verbose:
129 |             LOGGER.info('Estimating lower bound of # archetype as {0}...'.format(self.kmin))
130 |         X = self.count.T
131 |         archetypes = []
132 |         evs = []
133 |
134 |         # TODO: speedup with multiprocessing
135 |         for i, k in enumerate(range(self.kmin, self.kmin+n_iters, 2)):
136 |             archetype, _, _, _, ev = PCHA(X, noc=k)
137 |             evs.append(ev)
138 |             archetypes.append(np.array(archetype).T)
139 |             if i > 0 and ev - evs[i-1] < converge:
140 |                 # early stopping
141 |                 break
142 |         self.archetype = archetypes[-1]
143 |
144 |         # Merge raw archetypes to get major archetypes
145 |         if self.verbose:
146 |             LOGGER.info('{0} variance explained by raw archetypes.\n'
147 |                         'Merging raw archetypes within {1} NNs to get major archetypes'.format(np.round(ev, 4), self.r))
148 |
149 |         arche_dict, major_idx = self._merge_archetypes()
150 |         self.major_archetype = self.archetype[major_idx]
151 |         self.major_idx = np.array(major_idx)
152 |         self.arche_dict = arche_dict
153 |
154 |         # return all archetypes for Silhouette score calculation
155 |         return archetypes, arche_dict, major_idx, evs
156 |
157 |     def _merge_archetypes(self):
158 |         """
159 |         Merge raw archetypes into major ones by absorbing candidates that fall
160 |         within `r` nearest neighbors of previously identified archetypes
161 |         """
162 |         assert self.archetype is not None, "Please compute archetypes first!"
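        # Approach (descriptive summary of the code below): build a single kNN
        # graph over all spots + archetypes, then sweep archetypes in index
        # order; any archetype that appears within the r nearest neighbors of
        # an earlier, still-active archetype is absorbed into that archetype's
        # group and removed from the candidate set.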
162 | 163 | n_archetypes = self.archetype.shape[0] 164 | X_concat = np.vstack([self.count, self.archetype]) 165 | nbrs = NearestNeighbors(n_neighbors=self.r).fit(X_concat) 166 | nn_graph = nbrs.kneighbors(X_concat)[1][self.n_spots:, 1:] # retrieve NN-graph of only archetype spots 167 | 168 | idxs_to_remove = set() 169 | arche_dict = {} 170 | for i in range(n_archetypes): 171 | if i not in idxs_to_remove: 172 | query = np.arange(self.n_spots+i, self.n_spots+n_archetypes) 173 | nbrs = np.setdiff1d( 174 | nn_graph[i][np.isin(nn_graph[i], query)] - self.n_spots, 175 | list(idxs_to_remove) # avoid over-assign merged archetypes to multiple major archetypes 176 | ) 177 | if len(nbrs) != 0: 178 | arche_dict[i] = np.insert(nbrs, 0, i) 179 | idxs_to_remove.update(nbrs) 180 | 181 | major_idx = np.setdiff1d(np.arange(n_archetypes), list(idxs_to_remove)) 182 | return arche_dict, major_idx 183 | 184 | def find_archetypal_spots(self, major=True): 185 | """ 186 | Assign N-nearest-neighbor spots to each archetype as `archetypal spots` (archetype community) 187 | 188 | Parameters 189 | ---------- 190 | major : bool 191 | Whether to find NNs for only major archetypes 192 | 193 | Returns 194 | ------- 195 | arche_df : pd.DataFrame 196 | Dataframe of archetypal spots 197 | """ 198 | assert self.archetype is not None, "Please compute archetypes first!" 199 | if self.verbose: 200 | LOGGER.info('Finding {} nearest neighbors for each archetype...'.format(self.r)) 201 | 202 | indices = self.major_idx if major else np.arange(self.archetype.shape[0]) 203 | x_concat = np.vstack([self.count, self.archetype]) 204 | nbrs = self._get_knns(x_concat, n_nbrs=self.r, indices=indices+self.n_spots) 205 | self.arche_df = pd.DataFrame({ 206 | 'arch_{}'.format(idx): g 207 | for (idx, g) in zip(indices, nbrs) 208 | }) 209 | 210 | return self.arche_df 211 | 212 | def find_markers(self, n_markers=30, display=False): 213 | """ 214 | Find marker genes for each archetype community via Wilcoxon rank sum test (in-group vs. out-of-group) 215 | 216 | Parameters 217 | ---------- 218 | n_markers : int 219 | Number of top marker genes to find for each archetype community 220 | 221 | Returns 222 | ------- 223 | marker_df : pd.DataFrame 224 | Dataframe of marker genes for each archetype community 225 | """ 226 | assert self.arche_df is not None, "Please compute archetypes & assign nearest-neighbors first!" 227 | if self.verbose: 228 | LOGGER.info('Finding {} top marker genes for each archetype...'.format(n_markers)) 229 | 230 | adata = self.adata.copy() 231 | markers = [] 232 | for col in self.arche_df.columns: 233 | # Annotate in-group (current archetype) vs. 
out-of-group
234 |             annots = np.zeros(self.n_spots, dtype=np.int64).astype(str)
235 |             annots[self.arche_df[col]] = col
236 |             adata.obs[col] = annots
237 |             adata.obs[col] = adata.obs[col].astype('category')
238 |
239 |             # Identify marker genes
240 |             sc.tl.rank_genes_groups(adata, col, use_raw=False, method='wilcoxon')
241 |             markers.append(adata.uns['rank_genes_groups']['names'][col][:n_markers])
242 |
243 |             if display:
244 |                 plt.rcParams['figure.figsize'] = (8, 3)
245 |                 plt.rcParams['figure.dpi'] = 300
246 |                 sc.pl.rank_genes_groups_violin(adata, groups=[col], n_genes=n_markers)
247 |
248 |         return pd.DataFrame(np.stack(markers, axis=1), columns=self.arche_df.columns)
249 |
250 |     def assign_archetypes(self, anchor_df, r=30):
251 |         """
252 |         Stable-matching to obtain the best 1-1 mapping of each archetype community to its closest anchor community
253 |         (cell-type specific anchor spots)
254 |
255 |         Parameters
256 |         ----------
257 |         anchor_df : pd.DataFrame
258 |             Dataframe of anchor spot indices
259 |
260 |         r : int
261 |             Resolution parameter to threshold the archetype - anchor mapping
262 |
263 |         Returns
264 |         -------
265 |         overlaps_df : pd.DataFrame
266 |             DataFrame of overlapping spot ratio of each anchor `i` to archetype `j`
267 |
268 |         map_dict : dict
269 |             Dictionary of cell type -> mapped archetype
270 |         """
271 |         assert self.arche_df is not None, "Please compute archetypes & assign nearest-neighbors first!"
272 |
273 |         x_concat = np.vstack([self.count, self.archetype])
274 |         anchor_nbrs = anchor_df.values
275 |         archetypal_nbrs = self._get_knns(
276 |             x_concat, n_nbrs=r, indices=self.n_spots + self.major_idx).T  # r-nearest nbrs to each archetype
277 |
278 |         overlaps = np.array(
279 |             [
280 |                 [
281 |                     len(np.intersect1d(anchor_nbrs[:, i], archetypal_nbrs[:, j]))
282 |                     for j in range(archetypal_nbrs.shape[1])
283 |                 ]
284 |                 for i in range(anchor_nbrs.shape[1])
285 |             ]
286 |         )
287 |         overlaps_df = pd.DataFrame(overlaps, index=anchor_df.columns, columns=self.arche_df.columns)
288 |
289 |         # Stable marriage matching: archetype -> anchor clusters
290 |         map_idx_df = self._stable_matching(overlaps)
291 |         map_dict = {overlaps_df.index[k]: overlaps_df.columns[v] for k, v in map_idx_df.items()}
292 |         return overlaps_df, map_dict
293 |
294 |     def find_distant_archetypes(self, anchor_df, map_dict=None, n=3):
295 |         """
296 |         Sort and return the top n archetypes that are unmapped and farthest from anchor spots of known cell types.
297 |         They are more likely to represent novel cell types / states.
298 |
299 |         Parameters
300 |         ----------
301 |         anchor_df : pd.DataFrame
302 |             Dataframe of anchor spot indices
303 |
304 |         map_dict : dict
305 |             Dictionary of cell type -> mapped archetype
306 |
307 |         n : int
308 |             Number of distant archetypes to return
309 |
310 |         Returns
311 |         -------
312 |         distant_archetypes : list
313 |             List of archetype labels (farthest --> closest to anchors)
314 |         """
315 |         assert self.arche_df is not None, "Please compute archetypes & assign nearest-neighbors first!"
316 |
317 |         cell_types = anchor_df.columns
318 |         arche_lbls = self.arche_df.columns
319 |
320 |         # Find the unmapped archetypes
321 |         if map_dict is None:
322 |             _, map_dict = self.assign_archetypes(anchor_df=anchor_df)
323 |         unmapped_archetypes = np.setdiff1d(
324 |             arche_lbls,
325 |             list(set([v for k, v in map_dict.items()]))
326 |         )
327 |
328 |         # Sort unmapped archetypes in descending order with avg.
distance to its 2 closest anchor spot centroids
329 |         if n > len(unmapped_archetypes):
330 |             LOGGER.warning('Insufficient candidates to find {0} distant archetypes\nSet n={1}'.format(
331 |                 n, len(unmapped_archetypes)))
332 |             n = len(unmapped_archetypes)  # clamp n to the available candidates
333 |         anchor_centroids = self.count[anchor_df[anchor_df.columns]].mean(0)
334 |         arche_centroids = self.count[self.arche_df[self.arche_df.columns]].mean(0)
335 |         dist_df = pd.DataFrame(
336 |             cdist(anchor_centroids, arche_centroids),
337 |             index=cell_types,
338 |             columns=arche_lbls
339 |         )
340 |         dist_unmapped = dist_df[unmapped_archetypes].values  # subset only distance to `unmapped` archetypes
341 |         dist_to_nbrs = np.sort(dist_unmapped, axis=0)[:2].mean(0)
342 |         distant_arches = [unmapped_archetypes[idx] for idx in np.argsort(-dist_to_nbrs)][:n]  # sorted in descending order of distance
343 |
344 |         return distant_arches
345 |
346 |     def _get_knns(self, x, n_nbrs, indices):
347 |         """Compute kNNs (actual spots) to each archetype"""
348 |         assert 0 <= indices.min() < indices.max() < x.shape[0], \
349 |             "Invalid indices of interest to compute k-NNs"
350 |         nbrs = np.zeros((len(indices), n_nbrs), dtype=np.int32)
351 |         for i, index in enumerate(indices):
352 |             u = x[index]
353 |             dist = np.ones(x.shape[0])*np.inf
354 |             for j, v in enumerate(x[:self.n_spots]):
355 |                 dist[j] = np.linalg.norm(u-v)
356 |             nbrs[i] = np.argsort(dist)[:n_nbrs]
357 |         return nbrs
358 |
359 |     def _stable_matching(self, A):
360 |         matching = {}
361 |         free_rows, free_cols = set(range(A.shape[0])), set(range(A.shape[1]))
362 |
363 |         while free_rows:
364 |             i = free_rows.pop()
365 |             for j in np.argsort(A[i])[::-1]:  # iter cols in decreasing vals
366 |                 if j in free_cols:
367 |                     matching[i] = j
368 |                     free_cols.remove(j)
369 |                     break
370 |                 else:  # Check matched cols & compare tie-breaking conditions
371 |                     i_prime = next(k for k, v in matching.items() if v == j)
372 |                     if A[i, j] > A[i_prime, j]:
373 |                         matching[i] = j
374 |                         free_rows.add(i_prime)
375 |                         del matching[i_prime]
376 |                         break
377 |
378 |         return matching
379 |
380 |     # -------------------
381 |     # Plotting functions
382 |     # -------------------
383 |
384 |     def _get_umap(self, ndim=2, random_state=42):
385 |         assert ndim == 2 or ndim == 3, "Invalid dimension for UMAP: {}".format(ndim)
386 |         LOGGER.info('Calculating UMAPs for counts + Archetypes...')
387 |         reducer = umap.UMAP(n_neighbors=self.r+10, n_components=ndim, random_state=random_state)
388 |         U = reducer.fit_transform(np.vstack([self.count, self.archetype]))
389 |         return U
390 |
391 |     def _save_fig(self, fig, lgds, default_name):
392 |         filename = self.filename if self.filename is not None else default_name
393 |         if not os.path.exists(self.outdir):
394 |             os.makedirs(self.outdir)
395 |
396 |         fig.savefig(
397 |             os.path.join(self.outdir, filename+'.svg'),
398 |             bbox_extra_artists=lgds, bbox_inches='tight',
399 |             format='svg'
400 |         )
401 |
402 |     def plot_archetypes(
403 |         self,
404 |         major=True,
405 |         do_3d=False,
406 |         lgd_ncol=1,
407 |         figsize=(6, 4),
408 |         disp_cluster=True,
409 |         disp_arche=True
410 |     ):
411 |         """
412 |         Display archetype & archetypal spot communities
413 |         """
414 |         assert self.arche_df is not None, "Please compute archetypes & assign nearest-neighbors first!"
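        # Rendering summary: spots + archetypes are shown on a cached 2D/3D
        # UMAP; background spots in light gray, archetypal communities in
        # tab20 colors, and archetypes as blue '^' markers with index labels.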
415 | n_archetypes = self.arche_df.shape[1] 416 | arche_indices = self.major_idx if major else np.arange(n_archetypes) 417 | U = self.U_3d if do_3d else self.U 418 | colors = cm.tab20(np.linspace(0, 1, n_archetypes)) 419 | 420 | if do_3d: 421 | if self.U_3d is None: 422 | self.U_3d = self._get_umap(ndim=3) 423 | U = self.U_3d 424 | 425 | fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=200, subplot_kw=dict(projection='3d')) 426 | 427 | # Color background spots & archetypal spots 428 | 429 | ax.scatter( 430 | U[:self.n_spots, 0], 431 | U[:self.n_spots, 1], 432 | U[:self.n_spots, 2], 433 | s=1, alpha=0.7, linewidth=.3, 434 | edgecolors='black', c='lightgray' 435 | ) 436 | 437 | if disp_cluster: 438 | for i, label in enumerate(self.arche_df.columns): 439 | lbl = int(label.split('_')[-1]) 440 | if lbl in arche_indices: 441 | idxs = self.arche_df[label] 442 | ax.scatter( 443 | U[idxs, 0], 444 | U[idxs, 1], 445 | U[idxs, 2], 446 | marker='o', s=3, 447 | color=colors[i], label=label 448 | ) 449 | 450 | # Highlight archetype 451 | if disp_arche: 452 | ax.scatter( 453 | U[self.n_spots+arche_indices, 0], 454 | U[self.n_spots+arche_indices, 1], 455 | U[self.n_spots+arche_indices, 2], 456 | s=10, c='blue', marker='^' 457 | ) 458 | for j, z in zip(arche_indices, U[self.n_spots+arche_indices]): 459 | ax.text(z[0], z[1], z[2], str(j), fontsize=10, c='blue') 460 | 461 | lgd = ax.legend(loc='right', bbox_to_anchor=(0.5, 0, 1.5, 0.5), ncol=lgd_ncol) 462 | 463 | ax.grid(False) 464 | ax.set_xlabel('UMAP1') 465 | ax.set_ylabel('UMAP2') 466 | ax.set_zlabel('UMAP3') 467 | 468 | ax.set_xticklabels([]) 469 | ax.set_yticklabels([]) 470 | ax.set_zticklabels([]) 471 | 472 | ax.xaxis.pane.set_edgecolor('black') 473 | ax.yaxis.pane.set_edgecolor('black') 474 | 475 | ax.set_xticks([]) 476 | ax.set_yticks([]) 477 | ax.set_zticks([]) 478 | 479 | ax.xaxis.pane.fill = False 480 | ax.yaxis.pane.fill = False 481 | ax.zaxis.pane.fill = False 482 | 483 | ax.view_init(20, 135) 484 | 485 | else: # 2D plot 486 | if self.U is None: 487 | self.U = self._get_umap(ndim=2) 488 | U = self.U 489 | 490 | fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=300) 491 | 492 | # Color background & archetypal spots 493 | ax.scatter( 494 | U[:self.n_spots, 0], 495 | U[:self.n_spots, 1], 496 | alpha=1, s=1, color='lightgray') 497 | 498 | if disp_cluster: 499 | for i, label in enumerate(self.arche_df.columns): 500 | lbl = int(label.split('_')[-1]) 501 | if lbl in arche_indices: 502 | idxs = self.arche_df[label] 503 | ax.scatter( 504 | U[idxs, 0], 505 | U[idxs, 1], 506 | marker='o', s=3, 507 | color=colors[i], label=label 508 | ) 509 | 510 | if disp_arche: 511 | ax.scatter( 512 | U[self.n_spots+arche_indices, 0], 513 | U[self.n_spots+arche_indices, 1], 514 | s=10, c='blue', marker='^' 515 | ) 516 | 517 | for j, z in zip(arche_indices, U[self.n_spots+arche_indices]): 518 | ax.text(z[0], z[1], str(j), fontsize=10, c='blue') 519 | lgd = ax.legend(loc='right', bbox_to_anchor=(1, 0.5), ncol=lgd_ncol) 520 | 521 | ax.grid(False) 522 | ax.axis('off') 523 | 524 | if self.savefig and self.outdir is not None: 525 | self._save_fig(fig, (lgd,), 'archetypes') 526 | return fig, ax 527 | 528 | def plot_anchor_archetype_clusters( 529 | self, 530 | anchor_df, 531 | cell_types=None, 532 | arche_lbls=None, 533 | lgd_ncol=2, 534 | do_3d=False 535 | ): 536 | """ 537 | Joint display subset of anchor spots & archetypal spots (to visualize overlapping degree) 538 | """ 539 | assert self.arche_df is not None, "Please compute archetypes & assign nearest-neighbors first!" 
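        # Rendering summary: two side-by-side panels on the shared UMAP --
        # left: anchor spots per cell type ('^' markers), right: archetypal
        # spot communities ('o' markers) with selected archetype indices
        # annotated at their centroids.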
540 |
541 |         cell_types = anchor_df.columns if cell_types is None else np.intersect1d(cell_types, anchor_df.columns)
542 |         arche_lbls = self.arche_df.columns if arche_lbls is None else np.intersect1d(arche_lbls, self.arche_df.columns)
543 |         # NOTE: archetype centroids are computed below, once the UMAP embedding `U` is resolved
544 |
545 |         anchor_colors = cm.RdBu_r(np.linspace(0, 1, len(cell_types)))
546 |         arche_colors = cm.RdBu_r(np.linspace(0, 1, len(arche_lbls)))
547 |
548 |         if do_3d:
549 |             if self.U_3d is None:
550 |                 self.U_3d = self._get_umap(ndim=3)
551 |             U = self.U_3d
552 |             u_centroids = U[self.arche_df[arche_lbls]].mean(0)
553 |             fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 5), dpi=300, subplot_kw=dict(projection='3d'))
554 |
555 |             # Display anchors
556 |             ax1.scatter(
557 |                 U[:self.n_spots, 0],
558 |                 U[:self.n_spots, 1],
559 |                 U[:self.n_spots, 2],
560 |                 c='gray', marker='.', s=1, alpha=0.2
561 |             )
562 |             for c, label in zip(anchor_colors, cell_types):
563 |                 idxs = anchor_df[label]
564 |                 ax1.scatter(
565 |                     U[idxs, 0],
566 |                     U[idxs, 1],
567 |                     U[idxs, 2],
568 |                     color=c, marker='^', s=5,
569 |                     alpha=0.9, label=label
570 |                 )
571 |
572 |             ax1.grid(False)
573 |             ax1.set_xticklabels([])
574 |             ax1.set_yticklabels([])
575 |             ax1.set_zticklabels([])
576 |             ax1.view_init(30, 45)
577 |
578 |             lgd1 = ax1.legend(loc='lower center', bbox_to_anchor=(0.5, -1), ncol=lgd_ncol)
579 |
580 |             # Display archetypal spots
581 |             ax2.scatter(U[:self.n_spots, 0], U[:self.n_spots, 1], U[:self.n_spots, 2], c='gray', marker='.', s=1, alpha=0.2)
582 |             for c, label in zip(arche_colors, arche_lbls):
583 |                 idxs = self.arche_df[label]
584 |                 ax2.scatter(U[idxs, 0], U[idxs, 1], U[idxs, 2], color=c, marker='o', s=3, alpha=0.9, label=label)
585 |
586 |             # Highlight selected archetypes
587 |             for label, z in zip(arche_lbls, u_centroids):
588 |                 idx = int(label.split('_')[-1])
589 |                 ax2.text(z[0], z[1], z[2], str(idx))
590 |
591 |             ax2.grid(False)
592 |             ax2.set_xticklabels([])
593 |             ax2.set_yticklabels([])
594 |             ax2.set_zticklabels([])
595 |             ax2.view_init(30, 45)
596 |
597 |             lgd2 = ax2.legend(loc='lower center', bbox_to_anchor=(0.5, -1), ncol=lgd_ncol)
598 |
599 |         else:
600 |             if self.U is None:
601 |                 self.U = self._get_umap(ndim=2)
602 |             U = self.U
603 |             u_centroids = U[self.arche_df[arche_lbls]].mean(0)
604 |             fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 3), dpi=300)
605 |
606 |             # Display anchors
607 |             ax1.scatter(
608 |                 U[:self.n_spots, 0],
609 |                 U[:self.n_spots, 1],
610 |                 c='gray', marker='.', s=1, alpha=0.2
611 |             )
612 |
613 |             for c, label in zip(anchor_colors, cell_types):
614 |                 idxs = anchor_df[label]
615 |                 ax1.scatter(
616 |                     U[idxs, 0],
617 |                     U[idxs, 1],
618 |                     color=c, marker='^', s=5,
619 |                     alpha=0.9, label=label
620 |                 )
621 |
622 |             lgd1 = ax1.legend(loc='lower center', bbox_to_anchor=(0.5, -1.75), ncol=lgd_ncol)
623 |
624 |             # Display archetypal spots
625 |             ax2.scatter(U[:self.n_spots, 0], U[:self.n_spots, 1], c='gray', marker='.', s=1, alpha=0.2)
626 |             for c, label in zip(arche_colors, arche_lbls):
627 |                 idxs = self.arche_df[label]
628 |                 ax2.scatter(
629 |                     U[idxs, 0],
630 |                     U[idxs, 1],
631 |                     color=c, marker='o', s=3,
632 |                     alpha=0.9, label=label
633 |                 )
634 |
635 |             # Highlight selected archetypes
636 |             for label, z in zip(arche_lbls, u_centroids):
637 |                 idx = int(label.split('_')[-1])
638 |                 ax2.text(z[0], z[1], str(idx))
639 |             lgd2 = ax2.legend(loc='lower center', bbox_to_anchor=(0.5, -1.85), ncol=lgd_ncol)
640 |
641 |         if self.savefig and self.outdir is not None:
642 |             self._save_fig(fig, (lgd1, lgd2), 'anchor_archetypal_spots')
643 |         return fig, (ax1, ax2)
644 |
645 |     def plot_mapping(self, map_df, figsize=(6, 5)):
646 |         """
647 |         Display anchor -
archetype mapping (overlapping # spot ratio) 648 | """ 649 | filename = 'cluster' if self.filename is None else self.filename 650 | g = sns.clustermap( 651 | map_df, 652 | method='ward', 653 | figsize=figsize, 654 | xticklabels=True, 655 | yticklabels=True, 656 | square=True, 657 | annot_kws={'size': 15} 658 | ) 659 | 660 | text = g.ax_heatmap.set_title('# Overlapped NN Spots (k={})'.format(map_df.shape[1]), 661 | fontsize=20, x=0.6, y=1.3) 662 | # g.ax_row_dendrogram.set_visible(False) 663 | # g.ax_col_dendrogram.set_visible(False) 664 | 665 | if self.savefig and self.outdir is not None: 666 | if not os.path.exists(self.outdir): 667 | os.makedirs(self.outdir) 668 | g.figure.savefig( 669 | os.path.join(self.outdir, filename + '.eps'), 670 | bbox_extra_artists=(text,), bbox_inches='tight', format='eps' 671 | ) 672 | 673 | return g 674 | -------------------------------------------------------------------------------- /starfysh/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import multiprocessing 3 | import logging 4 | 5 | n_cores = multiprocessing.cpu_count() 6 | os.environ["NUMEXPR_MAX_THREADS"] = str(n_cores) 7 | 8 | # Configure global logging format 9 | logging.basicConfig( 10 | level=logging.INFO, 11 | format='[%(asctime)s] %(message)s', 12 | datefmt='%Y-%m-%d %H:%M:%S', 13 | force=True 14 | ) 15 | 16 | LOGGER = logging.getLogger('Starfysh') 17 | -------------------------------------------------------------------------------- /starfysh/dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.utils.data import Dataset 6 | 7 | from starfysh import LOGGER 8 | 9 | 10 | #--------------------------- 11 | # Single Sample dataloader 12 | #--------------------------- 13 | 14 | class VisiumDataset(Dataset): 15 | """ 16 | Loading a single preprocessed ST AnnData, gene signature & Anchor spots for Starfysh training 17 | """ 18 | 19 | def __init__( 20 | self, 21 | adata, 22 | args, 23 | ): 24 | spots = adata.obs_names 25 | genes = adata.var_names 26 | 27 | x = adata.X if isinstance(adata.X, np.ndarray) else adata.X.A 28 | self.expr_mat = pd.DataFrame(x, index=spots, columns=genes) 29 | self.gexp = args.sig_mean_norm 30 | self.anchor_idx = args.pure_idx 31 | self.library_n = args.win_loglib 32 | 33 | def __len__(self): 34 | return len(self.expr_mat) 35 | 36 | def __getitem__(self, idx): 37 | if torch.is_tensor(idx): 38 | idx = idx.tolist() 39 | sample = torch.Tensor( 40 | np.array(self.expr_mat.iloc[idx, :], dtype='float') 41 | ) 42 | 43 | return (sample, 44 | torch.Tensor(self.gexp.iloc[idx, :]), # normalized signature exprs 45 | torch.Tensor(self.anchor_idx[idx, :]), # anchors 46 | torch.Tensor(self.library_n[idx,None]), # library size 47 | ) 48 | 49 | 50 | class VisiumPoEDataSet(VisiumDataset): 51 | 52 | def __init__( 53 | self, 54 | adata, 55 | args, 56 | ): 57 | super(VisiumPoEDataSet, self).__init__(adata, args) 58 | self.image = args.img.astype(np.float64) 59 | self.map_info = args.map_info 60 | self.r = args.params['patch_r'] 61 | self.spot_img_stack = [] 62 | 63 | self.density_std = args.img.std() 64 | 65 | assert self.image is not None,\ 66 | "Empty paired H&E image," \ 67 | "please use regular `Starfysh` without PoE integration" \ 68 | "if your dataset doesn't contain histology image" 69 | 70 | # Retrieve image patch around each spot 71 | scalef = 
args.scalefactor['tissue_hires_scalef'] # High-res scale factor 72 | h, w = self.image.shape[:2] 73 | patch_dim = (self.r*2, self.r*2, 3) if self.image.ndim == 3 else (self.r*2, self.r*2) 74 | 75 | for i in range(len(self.expr_mat)): 76 | xc = int(np.round(self.map_info.iloc[i]['imagecol'] * scalef)) 77 | yc = int(np.round(self.map_info.iloc[i]['imagerow'] * scalef)) 78 | 79 | # boundary conditions: edge spots 80 | yl, yr = max(0, yc-self.r), min(self.image.shape[0], yc+self.r) 81 | xl, xr = max(0, xc-self.r), min(self.image.shape[1], xc+self.r) 82 | top = max(0, self.r-yc) 83 | bottom = h if h > (yc+self.r) else h-(yc+self.r) 84 | left = max(0, self.r-xc) 85 | right = w if w > (xc+self.r) else w-(xc+self.r) 86 | 87 | #try: 88 | patch = np.zeros(patch_dim) 89 | patch[top:bottom, left:right] = self.image[yl:yr, xl:xr] 90 | self.spot_img_stack.append(patch) 91 | #except ValueError: 92 | # LOGGER.warning('Skipping the patch loading of an edge spot...') 93 | 94 | 95 | def __len__(self): 96 | return len(self.expr_mat) 97 | 98 | def __getitem__(self, idx): 99 | if torch.is_tensor(idx): 100 | idx = idx.tolist() 101 | sample = torch.Tensor( 102 | np.array(self.expr_mat.iloc[idx, :], dtype='float') 103 | ) 104 | spot_img_stack = self.spot_img_stack[idx] 105 | return (sample, 106 | torch.Tensor(self.anchor_idx[idx, :]), 107 | torch.Tensor(self.library_n[idx, None]), 108 | spot_img_stack, 109 | self.map_info.index[idx], 110 | torch.Tensor(self.gexp.iloc[idx, :]), 111 | ) 112 | 113 | #--------------------------- 114 | # Integrative Dataloader 115 | #--------------------------- 116 | 117 | 118 | class IntegrativeDataset(VisiumDataset): 119 | """ 120 | Loading multiple preprocessed ST sample AnnDatas, gene signature & Anchor spots for Starfysh training 121 | """ 122 | 123 | def __init__( 124 | self, 125 | adata, 126 | args, 127 | ): 128 | super(IntegrativeDataset, self).__init__(adata, args) 129 | self.image = args.img 130 | self.map_info = args.map_info 131 | self.r = args.params['patch_r'] 132 | 133 | def __len__(self): 134 | return len(self.expr_mat) 135 | 136 | def __getitem__(self, idx): 137 | if torch.is_tensor(idx): 138 | idx = idx.tolist() 139 | sample = torch.Tensor( 140 | np.array(self.expr_mat.iloc[idx, :], dtype='float') 141 | ) 142 | return (sample, 143 | torch.Tensor(self.gexp.iloc[idx, :]), 144 | torch.Tensor(self.anchor_idx[idx, :]), 145 | torch.Tensor(self.library_n[idx, None]) 146 | ) 147 | 148 | 149 | class IntegrativePoEDataset(VisiumDataset): 150 | 151 | def __init__( 152 | self, 153 | adata, 154 | args, 155 | ): 156 | super(IntegrativePoEDataset, self).__init__(adata, args) 157 | self.image = args.img 158 | self.map_info = args.map_info 159 | self.r = args.params['patch_r'] 160 | 161 | 162 | assert self.image is not None,\ 163 | "Empty paired H&E image," \ 164 | "please use regular `Starfysh` without PoE integration" \ 165 | "if your dataset doesn't contain histology image" 166 | spot_img_all = [] 167 | 168 | # Retrieve image patch around each spot 169 | for sample_id in args.img.keys(): 170 | 171 | scalef_i = args.scalefactor[sample_id]['tissue_hires_scalef'] # High-res scale factor 172 | h, w = self.image[sample_id].shape[:2] 173 | patch_dim = (self.r*2, self.r*2, 3) if self.image[sample_id].ndim == 3 else (self.r*2, self.r*2) 174 | 175 | list_ = adata.obs['sample'] == sample_id 176 | for i in range(len(self.expr_mat.loc[list_,:])): 177 | xc = int(np.round(self.map_info.loc[list_,:].iloc[i]['imagecol'] * scalef_i)) 178 | yc = 
int(np.round(self.map_info.loc[list_,:].iloc[i]['imagerow'] * scalef_i)) 179 | 180 | # boundary conditions: edge spots 181 | yl, yr = max(0, yc-self.r), min(self.image[sample_id].shape[0], yc+self.r) 182 | xl, xr = max(0, xc-self.r), min(self.image[sample_id].shape[1], xc+self.r) 183 | top = max(0, self.r-yc) 184 | bottom = h if h > (yc+self.r) else h-(yc+self.r) 185 | left = max(0, self.r-xc) 186 | right = w if w > (xc+self.r) else w-(xc+self.r) 187 | 188 | try: 189 | patch = np.zeros(patch_dim) 190 | patch[top:bottom, left:right] = self.image[sample_id][yl:yr, xl:xr] 191 | spot_img_all.append(patch) 192 | except ValueError: 193 | LOGGER.warning('Skipping the patch loading of an edge spot...') 194 | 195 | self.spot_img_stack = list(spot_img_all) 196 | #print(self.spot_img_stack.shape) 197 | 198 | def __len__(self): 199 | return len(self.expr_mat) 200 | 201 | def __getitem__(self, idx): 202 | if torch.is_tensor(idx): 203 | idx = idx.tolist() 204 | sample = torch.Tensor( 205 | np.array(self.expr_mat.iloc[idx, :], dtype='float') 206 | ) 207 | return (sample, 208 | torch.Tensor(self.anchor_idx[idx, :]), 209 | torch.Tensor(self.library_n[idx, None]), 210 | self.spot_img_stack[idx], 211 | self.map_info.index[idx], 212 | torch.Tensor(self.gexp.iloc[idx, :]), 213 | ) 214 | 215 | -------------------------------------------------------------------------------- /starfysh/gener_img.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import scipy.stats as stats 5 | import pandas as pd 6 | import pickle 7 | import anndata 8 | import json 9 | import scanpy as sc 10 | import scipy 11 | import PIL 12 | import seaborn as sns 13 | import torch 14 | import matplotlib.pyplot as plt 15 | from skimage import io 16 | from torch.utils.data import DataLoader,ConcatDataset 17 | import torch 18 | import torch.nn as nn 19 | from torch.nn import functional as F 20 | from typing import List, Callable, Union, Any, TypeVar, Tuple 21 | from tqdm import tqdm 22 | from torchvision.utils import save_image 23 | 24 | sys.path.append('HE-Net') 25 | torch.manual_seed(0) 26 | from henet import utils 27 | from henet import datasets 28 | from henet import model 29 | 30 | 31 | class dataset(torch.utils.data.Dataset): 32 | 33 | def __init__( 34 | self, 35 | spot, 36 | exp_spot, 37 | barcode_spot, 38 | #img_size, 39 | #histo_img, 40 | transform=None 41 | ): 42 | 43 | super(dataset, self).__init__() 44 | self.spot = spot 45 | self.exp_spot = exp_spot 46 | self.barcode_spot = barcode_spot 47 | self.transform = transform 48 | #self.img_size = img_size 49 | #self.histo_img = histo_img 50 | 51 | def __getitem__(self, idx): 52 | #print(idx) 53 | spot_x = self.spot[idx,0] 54 | spot_y = self.spot[idx,1] 55 | exp_spot = self.exp_spot[idx,:] 56 | barcode_spot = self.barcode_spot[idx] 57 | 58 | #x_l = int(spot_x-self.img_size/2) 59 | #x_r = int(spot_x+self.img_size/2) 60 | #y_l = int(spot_y-self.img_size/2) 61 | #y_r = int(spot_y+self.img_size/2) 62 | 63 | #img_spot = img[y_l:y_r,x_l:x_r,:] 64 | #img_spot = self.histo_img[y_l:y_r,x_l:x_r] 65 | #img_spot = img_spot-img_spot.min() 66 | #img_spot = img_spot/img_spot.max() 67 | 68 | #if self.transform is not None: 69 | # img_spot = self.transform(img_spot) 70 | return exp_spot, barcode_spot#, img_spot 71 | #return exp_spot 72 | 73 | def __len__(self): 74 | 75 | return len(self.spot) 76 | 77 | def prep_dataset(adata): 78 | 79 | #library_id = list(adata.uns.get("spatial",{}).keys())[0] 80 | #histo_img = 
adata.uns['spatial'][str(library_id)]['images']['hires'] 81 | #spot_diameter_fullres = adata.uns['spatial'][str(library_id)]['scalefactors']['spot_diameter_fullres'] 82 | #tissue_hires_scalef = adata.uns['spatial'][str(library_id)]['scalefactors']['tissue_hires_scalef'] 83 | #circle_radius = spot_diameter_fullres *tissue_hires_scalef * 0.5 84 | image_spot = adata.obsm['spatial'] #*tissue_hires_scalef 85 | 86 | #img_size=32 87 | N_spot = image_spot.shape[0] 88 | train_index = np.random.choice(image_spot.shape[0], size=int(N_spot*0.6), replace=False) 89 | #val_index=np.random.choice(image_spot.shape[0], size=int(N_spot*0.1), replace=False) 90 | test_index = np.random.choice(image_spot.shape[0], size=int(N_spot*0.4), replace=True) 91 | 92 | #train_transforms = transforms.Compose([transforms.ToTensor()]) 93 | #val_transforms = transforms.Compose([transforms.ToTensor()]) 94 | #test_transforms = transforms.Compose([transforms.ToTensor()]) 95 | 96 | train_spot = image_spot[train_index] 97 | test_spot = image_spot[test_index] 98 | #val_spot = image_spot[val_index] 99 | 100 | adata_df = adata.to_df() 101 | 102 | train_exp_spot = np.array(adata_df.iloc[train_index]) 103 | test_exp_spot = np.array(adata_df.iloc[test_index]) 104 | #val_exp_spot = np.array(adata_df.iloc[val_index]) 105 | 106 | 107 | train_barcode_spot = adata_df.index[train_index] 108 | test_barcode_spot = adata_df.index[test_index] 109 | #val_barcode_spot = adata_df.index[val_index] 110 | 111 | train_set = dataset(train_spot, train_exp_spot, train_barcode_spot, None) 112 | test_set = dataset(test_spot, test_exp_spot, test_barcode_spot, None) 113 | 114 | all_set = dataset(image_spot, np.array(adata_df), adata_df.index, None) 115 | 116 | return train_set, test_set, all_set 117 | 118 | 119 | def generate_img(dat_path, train_flag=True): 120 | """ 121 | Parameters 122 | ---------- 123 | dat_path : path to the input expression CSV file 124 | """ 125 | 126 | dat_folder = 'data' 127 | dat_name = 'CID44971' 128 | n_genes = 6000 129 | adata1,variable_gene = utils.load_adata(dat_folder=dat_folder, 130 | dat_name=dat_name, 131 | use_other_gene=False, 132 | other_gene_list=None, 133 | n_genes=n_genes) 134 | 135 | 136 | adata5 = sc.read_csv(dat_path) 137 | adata5.obs_names = adata5.to_df().index 138 | adata5.var_names = adata5.to_df().columns 139 | adata5.var_names_make_unique() 140 | 141 | adata5.var["mt"] = adata5.var_names.str.startswith("MT-") 142 | sc.pp.calculate_qc_metrics(adata5, qc_vars=["mt"], inplace=True) 143 | print('The dataset has', adata5.n_obs, 'spots and', adata5.n_vars, 'genes') 144 | adata5.var['rp'] = adata5.var_names.str.startswith('RPS') | adata5.var_names.str.startswith('RPL') 145 | sc.pp.calculate_qc_metrics(adata5, qc_vars=['rp'], percent_top=None, log1p=False, inplace=True) 146 | 147 | # Remove mitochondrial & ribosomal genes 148 | #print('removing mt/rp genes') 149 | adata5 = adata5[:, ~adata5.var['mt']] 150 | adata5 = adata5[:, ~adata5.var['rp']] 151 | 152 | adata5_new = pd.DataFrame(np.zeros([adata5.to_df().shape[0], 6000]), index=adata5.obs_names, columns=variable_gene) 153 | 154 | inter_gene = np.intersect1d(variable_gene, adata5.var_names) 155 | 156 | for i in inter_gene: 157 | adata5_new.loc[:,i] = np.array(adata5[:,i].X) 158 | 159 | adata5_new_anndata = anndata.AnnData(np.log( 160 | #np.clip(adata5_new,0,2000) 161 | adata5_new 162 | +1)) 163 | 164 | other_list_temp = variable_gene.intersection(adata5.var.index) 165 | adata5 = adata5[:,other_list_temp] 166 | sc.pp.normalize_total(adata5,inplace=True) 167 | sc.pp.log1p(adata5) 168 | 169 | map_info_list = [] 170 | for 
i in range(2500): 171 | x,y = np.where(np.arange(2500).reshape(50, 50)==i) 172 | map_info_list.append([x[0]+5,y[0]+5]) 173 | map_info = pd.DataFrame(map_info_list,columns=['array_row','array_col'],index = adata5.obs_names) 174 | 175 | tissue_hires_scalef = 0.20729685 176 | adata5_new_anndata.obsm['spatial']=np.array(map_info)*32 177 | adata5.obsm['spatial']=np.array(map_info)*32 178 | 179 | 180 | 181 | 182 | train_set1,test_set1,all_set1 = datasets.prep_dataset(adata1, 'CID44971') 183 | 184 | train_set = ConcatDataset([train_set1, 185 | 186 | ]) 187 | test_set = ConcatDataset([test_set1, 188 | 189 | ]) 190 | 191 | 192 | train_dataloader = DataLoader(train_set, shuffle=True, num_workers=1, batch_size=8) 193 | test_dataloader= DataLoader(test_set, shuffle=True, num_workers=1, batch_size=8) 194 | all_dataloader1 = DataLoader(all_set1, shuffle=True, num_workers=1, batch_size=8) 195 | 196 | 197 | train_set5,test_set5,all_set5 = prep_dataset(adata5_new_anndata) 198 | all_dataloader5 = DataLoader(all_set5, batch_size=8) 199 | 200 | # Model Hyperparameters 201 | cuda = False 202 | DEVICE = torch.device("cuda" if cuda else "cpu") 203 | batch_size = 20 204 | x_dim =1024 205 | hidden_dim = 400 206 | latent_dim = 20 207 | lr = 1e-3 208 | epochs = 1000 209 | 210 | encoder = model.Encoder(input_dim=x_dim, hidden_dim=hidden_dim, latent_dim=latent_dim,DEVICE=DEVICE) 211 | decoder = model.Decoder(latent_dim=latent_dim, hidden_dim = hidden_dim, output_dim = x_dim) 212 | 213 | vae = model.Model(Encoder=encoder, Decoder=decoder).to(DEVICE) 214 | check_point_filename = 'HE-Net/trained_model_on_lab1A.pt' 215 | vae.load_state_dict(torch.load(check_point_filename,map_location=torch.device('cpu') )) 216 | 217 | 218 | train_net = model.ResNet.resnet18() 219 | 220 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 221 | #device = torch.device("cpu") 222 | train_net.to(device) 223 | lr = 0.001 224 | #optimizer = torch.optim.SGD(oct_net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4) 225 | optimizer = torch.optim.Adam(train_net.parameters(), lr=lr, amsgrad=True) 226 | 227 | 228 | vae.eval() 229 | if train_flag: 230 | for epoch in range(10): 231 | losses = [] 232 | for batch_idx, (exp_spot, barcode_spot, img_spot_i) in enumerate(train_dataloader): 233 | #print(batch_idx) 234 | img_spot_i = img_spot_i.transpose(3,2).transpose(2,1) 235 | img_spot_i = img_spot_i.reshape(-1, 1024) 236 | 237 | img_spot_i = img_spot_i.to(DEVICE) 238 | 239 | x_hat, mean, log_var = vae(img_spot_i) 240 | 241 | mean = mean.view(exp_spot.shape[0], 3, 20) 242 | mean = mean.view(exp_spot.shape[0], 60) 243 | #mean = mean.cpu().detach().numpy() 244 | x_hat= x_hat.view(exp_spot.shape[0], 3, 32, 32) 245 | img_spot_i= img_spot_i.view(exp_spot.shape[0], 3, 32, 32) 246 | 247 | #print(exp_spot.shape) 248 | output1 = train_net(exp_spot[:,None,:].to(DEVICE)) 249 | loss_model = nn.L1Loss() 250 | #print(output1.flatten()) 251 | loss = F.mse_loss(output1.reshape(-1), mean.reshape(-1))#+0.3*loss_model(output1, y1) 252 | if train_flag: 253 | loss.backward() 254 | optimizer.step() 255 | optimizer.zero_grad() 256 | losses.append(loss.cpu().detach().numpy()) 257 | #print(np.mean(losses)) 258 | print("\tEpoch", epoch + 1, "complete!", "\tAverage Loss: ", np.mean(losses)) 259 | 260 | 261 | adata_test = adata5_new_anndata#adata5_new_anndata#adata5_new_anndata 262 | all_dataloader_test = all_dataloader5#all_dataloader5 263 | 264 | 265 | train_net.eval() 266 | #library_id = list(adata_test.uns.get("spatial",{}).keys())[0] 267 | img_size=32 268 
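The `map_info` loop above places the 2,500 simulated spots on a 50 x 50 grid by running `np.where` once per spot. An equivalent vectorized construction, for reference (`obs_names` is a placeholder for `adata5.obs_names`):

```python
import numpy as np
import pandas as pd

n_side = 50
rows, cols = np.divmod(np.arange(n_side * n_side), n_side)  # row-major grid indices
obs_names = [f"spot_{i}" for i in range(n_side * n_side)]   # placeholder barcodes

map_info = pd.DataFrame(
    {"array_row": rows + 5, "array_col": cols + 5},  # same +5 offset as above
    index=obs_names,
)
spatial = map_info.values * 32  # pixel coordinates, as in `np.array(map_info) * 32`
```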
| #tissue_hires_scalef = adata_test.uns['spatial'][str(library_id)]['scalefactors']['tissue_hires_scalef'] 269 | #recon = np.ones(histo_img.shape) 270 | recon = np.ones([2000,2000,3]) 271 | for batch_idx, (exp_spot, barcode_spot) in enumerate(all_dataloader_test): 272 | 273 | #img_spot_i = img_spot_i.transpose(3,2).transpose(2,1) 274 | #img_spot_i = img_spot_i.reshape(-1, 1024) 275 | 276 | #img_spot_i = img_spot_i.to(DEVICE) 277 | 278 | #x_hat, mean, log_var = vae(img_spot_i) 279 | 280 | #mean = mean.view(exp_spot.shape[0], 3, 20) 281 | #mean = mean.view(exp_spot.shape[0], 60) 282 | #mean = mean.cpu().detach().numpy() 283 | #x_hat= x_hat.view(exp_spot.shape[0], 3, 32, 32) 284 | #img_spot_i= img_spot_i.view(exp_spot.shape[0], 3, 32, 32) 285 | 286 | #print(exp_spot.shape) 287 | output1 = train_net(exp_spot[:,None,:].to(DEVICE)) 288 | #plt.figure() 289 | #plt.imshow(mean.cpu().detach()) 290 | #plt.figure() 291 | #plt.imshow(output1.cpu().detach()) 292 | 293 | 294 | output1 = output1.view(exp_spot.shape[0],3,20) 295 | #print(mean.shape) 296 | #print(output1.shape) 297 | 298 | 299 | output1_1 = output1[:,0,:] 300 | output1_2 = output1[:,1,:] 301 | output1_3 = output1[:,2,:] 302 | 303 | x_predicted_1 = vae.Decoder(output1_1) 304 | x_predicted_1 = x_predicted_1.view(exp_spot.shape[0],32,32).cpu().detach().numpy() 305 | x_predicted_2 = vae.Decoder(output1_2) 306 | x_predicted_2 = x_predicted_2.view(exp_spot.shape[0],32,32).cpu().detach().numpy() 307 | x_predicted_3 = vae.Decoder(output1_3) 308 | x_predicted_3 = x_predicted_3.view(exp_spot.shape[0],32,32).cpu().detach().numpy() 309 | 310 | for j,i in zip(range(len(barcode_spot)),barcode_spot): 311 | spot_x = ((adata_test[i].obsm['spatial'] )[0])[1] 312 | spot_y = ((adata_test[i].obsm['spatial'] )[0])[0] 313 | 314 | x_l = int(spot_x-img_size/2) 315 | x_r = int(spot_x+img_size/2) 316 | y_l = int(spot_y-img_size/2) 317 | y_r = int(spot_y+img_size/2) 318 | 319 | recon[y_l:y_r,x_l:x_r,0]=x_predicted_1[j] 320 | recon[y_l:y_r,x_l:x_r,1]=x_predicted_2[j] 321 | recon[y_l:y_r,x_l:x_r,2]=x_predicted_3[j] 322 | 323 | save_path = dat_path.split('/')[0]+'/'+dat_path.split('/')[1] 324 | if not os.path.exists(save_path+'/spatial'): 325 | os.makedirs(save_path+'/spatial') 326 | io.imsave(save_path+'/spatial/tissue_hires_image.png',(recon*255).astype('uint8')) 327 | map_info[['img_col','img_row']]=adata5_new_anndata.obsm['spatial'] 328 | map_info.to_csv(save_path+'/spatial/tissue_positions_list.csv',index=True) 329 | adata5_new_anndata.write(save_path+'/'+dat_path.split('/')[1]+'.h5ad') 330 | 331 | 332 | return recon 333 | 334 | -------------------------------------------------------------------------------- /starfysh/plot_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import scanpy as sc 6 | import matplotlib.pyplot as plt 7 | import seaborn as sns 8 | 9 | from sklearn.metrics import r2_score 10 | from scipy.stats import pearsonr, gaussian_kde 11 | 12 | 13 | # Module import 14 | from .post_analysis import get_z_umap 15 | from .utils import extract_feature 16 | 17 | 18 | def plot_integrated_spatial_feature(data, 19 | sample_id, 20 | vmin=None, 21 | vmax=None, 22 | spot_size = 1, 23 | figsize=(2.5, 2), 24 | fig_dpi = 300, 25 | cmap = 'Blues', 26 | legend_on = True, 27 | legend_loc = (2,0.5) 28 | ): 29 | 30 | adata_ann = data[data.obs['sample']==sample_id] 31 | #color_idx_list = np.array(adata_ann.obs['pheno_louvain']).astype(int) 32 | color_idx_list = 
np.array(adata_ann.obs['pheno_louvain']).astype(int) 33 | all_loc = np.array(adata_ann.obsm['spatial']) 34 | fig,axs= plt.subplots(1,1,figsize=figsize,dpi=500) 35 | for i in np.unique(color_idx_list): 36 | g=axs.scatter(all_loc[color_idx_list==i,0],-all_loc[color_idx_list==i, 1],s=spot_size,marker='h',c=cmap[i]) 37 | 38 | if legend_on: 39 | plt.legend(list(np.unique(color_idx_list)),title='Hub region',loc='right',bbox_to_anchor=legend_loc) 40 | #fig.colorbar(g,label=label) 41 | #plt.title(sample_id) 42 | axs.set_xticks([]) 43 | axs.set_yticks([]) 44 | plt.axis('off') 45 | 46 | def plot_spatial_density( 47 | data, 48 | vmin=None, 49 | vmax=None, 50 | spot_size = 1, 51 | figsize=(2.5, 2), 52 | fig_dpi = 300, 53 | cmap = 'Blues', 54 | colorbar_on = True, 55 | ): 56 | fig, axes = plt.subplots(1, 1, figsize=figsize, dpi=fig_dpi) 57 | g = axes.scatter(data.obsm['spatial'][:,0], 58 | -data.obsm['spatial'][:,1], 59 | c=data.obsm['ql_m'], 60 | vmin=vmin, 61 | vmax=vmax, 62 | cmap=cmap, 63 | s=spot_size, 64 | ) 65 | if colorbar_on: 66 | fig.colorbar(g, label='Estimated density') 67 | axes.axis('off') 68 | 69 | def plot_spatial_cell_type_frac( data = None, 70 | cell_type = None, 71 | vmin=None,# adjust 72 | vmax=None,# adjust 73 | spot_size=2,# adjust 74 | figsize = (3,2.5), 75 | fig_dpi = 300, # >300 for high quality img 76 | cmap = 'magma', 77 | colorbar_on = True, 78 | label=None, 79 | title=None, 80 | ): 81 | 82 | 83 | fig,axs= plt.subplots(1,1,figsize=figsize,dpi=fig_dpi) 84 | idx = data.uns['cell_types'].index(cell_type) 85 | g=axs.scatter(data.obsm['spatial'][:,0], 86 | -data.obsm['spatial'][:,1], 87 | c=data.obsm['qc_m'][:,idx], 88 | cmap=cmap, 89 | vmin=vmin, 90 | vmax=vmax, 91 | s=spot_size 92 | ) 93 | if title is not None: 94 | plt.title(title) 95 | else: 96 | plt.title(cell_type) 97 | if colorbar_on: 98 | fig.colorbar(g,label=label) 99 | plt.axis('off') 100 | 101 | 102 | 103 | def plot_z_umap_cell_type_frac(data = None, 104 | cell_type = None, 105 | vmin=None,# adjust 106 | vmax=None,# adjust 107 | spot_size=2,# adjust 108 | figsize = (3,2.5), 109 | fig_dpi = 300, # >300 for high quality img 110 | cmap = 'magma', 111 | colorbar_on = True, 112 | label=None, 113 | title=None, 114 | ): 115 | 116 | 117 | fig,axs= plt.subplots(1,1,figsize=figsize,dpi=fig_dpi) 118 | idx = data.uns['cell_types'].index(cell_type) 119 | g=axs.scatter(data.obsm['z_umap'][:,0], 120 | data.obsm['z_umap'][:,1], 121 | c=data.obsm['qc_m'][:,idx], 122 | cmap=cmap, 123 | vmin=vmin, 124 | vmax=vmax, 125 | s=spot_size 126 | ) 127 | if title is not None: 128 | plt.title(title) 129 | else: 130 | plt.title(cell_type) 131 | if colorbar_on: 132 | fig.colorbar(g,label=label) 133 | plt.axis('off') 134 | 135 | def plot_spatial_feature(data = None, 136 | feature = None, 137 | vmin=None,# adjust 138 | vmax=None,# adjust 139 | spot_size=2,# adjust 140 | figsize = (3,2.5), 141 | fig_dpi = 300, # >300 for high quality img 142 | cmap = 'magma', 143 | colorbar_on = True, 144 | label=None 145 | ): 146 | 147 | 148 | fig,axs= plt.subplots(1,1,figsize=figsize,dpi=fig_dpi) 149 | 150 | g=axs.scatter(data.obsm['spatial'][:,0], 151 | -data.obsm['spatial'][:,1], 152 | c=feature, 153 | cmap=cmap, 154 | vmin=vmin, 155 | vmax=vmax, 156 | s=spot_size 157 | ) 158 | 159 | if colorbar_on: 160 | fig.colorbar(g,label=label) 161 | plt.axis('off') 162 | 163 | def plot_spatial_gene(data, 164 | data_normed, 165 | gene_name = None, 166 | log_gene = False, 167 | vmin=None, 168 | vmax=None, 169 | spot_size=5, 170 | figsize = (2,2), 171 | fig_dpi = 300, 172 | cmap 
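Note that `plot_integrated_spatial_feature` indexes its `cmap` argument per hub (`c=cmap[i]`), so despite the `'Blues'` default it expects a sequence of colors, one entry per `pheno_louvain` label. A hedged usage sketch (the AnnData fields and sample name are assumptions):

```python
hub_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']

plot_integrated_spatial_feature(
    data=adata,            # assumes obs['sample'], obs['pheno_louvain'], obsm['spatial']
    sample_id='sample_A',  # hypothetical sample label
    cmap=hub_colors,       # a plain colormap string would be indexed character by character
    spot_size=2,
)
```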
= 'magma', 173 | colorbar_on = True, 174 | ): 175 | 176 | fig,axs= plt.subplots(1,1,figsize=figsize,dpi=fig_dpi) 177 | if log_gene: 178 | g=axs.scatter(data_normed.obsm['spatial'][:,0], 179 | -data_normed.obsm['spatial'][:,1], 180 | c=data_normed.to_df().loc[:,gene_name], 181 | cmap=cmap, 182 | vmin=vmin, 183 | vmax=vmax, 184 | s=spot_size 185 | ) 186 | else: 187 | g=axs.scatter(data.obsm['spatial'][:,0], 188 | -data.obsm['spatial'][:,1], 189 | c=data.to_df().loc[:,gene_name], 190 | cmap=cmap, 191 | vmin=vmin, 192 | vmax=vmax, 193 | s=spot_size 194 | ) 195 | if colorbar_on: 196 | fig.colorbar(g,label=gene_name) 197 | plt.axis('off') 198 | 199 | 200 | def plot_anchor_spots(umap_plot, 201 | pure_spots, 202 | sig_mean, 203 | bbox_x=2, 204 | ): 205 | fig,ax = plt.subplots(1,1,dpi=300,figsize=(3,3)) 206 | ax.scatter(umap_plot['umap1'], 207 | umap_plot['umap2'], 208 | s=2, 209 | alpha=1, 210 | color='lightgray') 211 | for i in range(len(pure_spots)): 212 | ax.scatter(umap_plot['umap1'][pure_spots[i]], 213 | umap_plot['umap2'][pure_spots[i]], 214 | s=8) 215 | plt.legend(['all']+[i for i in sig_mean.columns], 216 | loc='right', 217 | bbox_to_anchor=(bbox_x,0.5),) 218 | ax.grid(False) 219 | ax.axis('off') 220 | 221 | 222 | def plot_evs(evs, kmin): 223 | fig, ax = plt.subplots(1, 1, dpi=300, figsize=(6, 3)) 224 | plt.plot(np.arange(len(evs))+kmin, evs, '.-') 225 | plt.xlabel('ks') 226 | plt.ylabel('Explained Variance') 227 | plt.show() 228 | 229 | 230 | def pl_spatial_inf_feature( 231 | adata, 232 | feature, 233 | factor=None, 234 | vmin=0, 235 | vmax=None, 236 | spot_size=100, 237 | alpha=0, 238 | cmap='Spectral_r' 239 | ): 240 | """Spatial visualization of Starfysh inference features""" 241 | if isinstance(factor, str): 242 | assert factor in adata.uns['cell_types'], \ 243 | "Invalid Starfysh inference factor (cell type): ".format(factor) 244 | elif isinstance(factor, list): 245 | for f in factor: 246 | assert f in adata.uns['cell_types'], \ 247 | "Invalid Starfysh inference factor (cell type): ".format(f) 248 | else: 249 | factor = adata.uns['cell_types'] # if None, display for all cell types 250 | 251 | adata_pl = extract_feature(adata, feature) 252 | 253 | if feature == 'qc_m': 254 | if isinstance(factor, list): 255 | title = [f + ' (Inferred proportion - Spatial)' for f in factor] 256 | else: 257 | title = factor + ' (Inferred proportion - Spatial)' 258 | sc.pl.spatial( 259 | adata_pl, 260 | color=factor, spot_size=spot_size, color_map=cmap, 261 | ncols=3, vmin=vmin, vmax=vmax, alpha_img=alpha, 262 | title=title, legend_fontsize=8 263 | ) 264 | elif feature == 'ql_m': 265 | title = 'Estimated tissue density' 266 | sc.pl.spatial( 267 | adata_pl, 268 | color='density', spot_size=spot_size, color_map=cmap, 269 | vmin=vmin, vmax=vmax, alpha_img=alpha, 270 | title=title, legend_fontsize=8 271 | ) 272 | elif feature == 'qz_m': 273 | # Visualize deconvolution on UMAP of inferred Z-space 274 | qz_u = get_z_umap(adata_pl.obs.values) 275 | qc_df = extract_feature(adata, 'qc_m').obs 276 | if isinstance(factor, list): 277 | for cell_type in factor: 278 | title = cell_type + ' (Inferred proportion - UMAP of Z)' 279 | pl_umap_feature(qz_u, qc_df[cell_type].values, cmap, title, 280 | vmin=vmin, vmax=vmax) 281 | else: 282 | title = factor + ' (Inferred proportion - UMAP of Z)' 283 | fig, ax = pl_umap_feature(qz_u, qc_df[factor].values, cmap, title, 284 | vmin=vmin, vmax=vmax) 285 | return fig, ax 286 | else: 287 | raise ValueError('Invalid Starfysh inference results `{}`, please choose from `qc_m`, `qz_m` & 
`ql_m`'.format(feature)) 288 | 289 | pass 290 | 291 | 292 | def pl_umap_feature(qz_u, qc, cmap, title, spot_size=3, vmin=0, vmax=None): 293 | """Single Z-UMAP visualization of Starfysh deconvolutions""" 294 | fig, ax = plt.subplots(1, 1, figsize=(4, 3), dpi=200) 295 | g = ax.scatter( 296 | qz_u[:, 0], qz_u[:, 1], 297 | cmap=cmap, c=qc, s=spot_size, vmin=vmin, vmax=vmax, 298 | ) 299 | ax.set_xticks([]) 300 | ax.set_yticks([]) 301 | ax.set_title(title) 302 | ax.axis('off') 303 | fig.colorbar(g, label='Inferred proportions') 304 | 305 | return fig, ax 306 | 307 | 308 | def pl_spatial_inf_gene( 309 | adata=None, 310 | factor=None, 311 | feature=None, 312 | vmin=None, 313 | vmax=None, 314 | spot_size=100, 315 | alpha=0, 316 | figsize = (3,2.5), 317 | fig_dpi = 500, 318 | cmap='Spectral_r', 319 | colorbar_on = True, 320 | title = None , 321 | 322 | ): 323 | 324 | if isinstance(feature, str): 325 | assert feature in set(adata.var_names), \ 326 | "Gene {0} isn't HVG, please choose from `adata.var_names`".format(feature) 327 | title_new = feature + ' (Predicted expression)' 328 | else: 329 | for f in feature: 330 | assert f in set(adata.var_names), \ 331 | "Gene {0} isn't HVG, please choose from `adata.var_names`".format(f) 332 | title_new = [f + ' (Predicted expression)' for f in feature] 333 | 334 | if title is not None: 335 | title_new = title 336 | # Assign dummy `var_names` to avoid gene name in both obs & var 337 | adata_expr = extract_feature(adata, factor+'_inferred_exprs') 338 | adata_expr.var_names = np.arange(adata_expr.shape[1]) 339 | 340 | sc.settings.set_figure_params(figsize=figsize,dpi=fig_dpi,facecolor="white") 341 | 342 | sc.pl.spatial( 343 | adata_expr, 344 | color=feature, spot_size=spot_size, color_map=cmap, 345 | ncols=3, vmin=vmin, vmax=vmax, alpha_img=alpha, 346 | title=title_new, 347 | legend_fontsize=8, 348 | 349 | ) 350 | 351 | pass 352 | 353 | 354 | # -------------------------- 355 | # util funcs for benchmark 356 | # -------------------------- 357 | 358 | def _dist2gt(A, A_gt): 359 | """ 360 | Calculate the distance to ground-truth correlation matrix (proportions) 361 | """ 362 | return np.linalg.norm(A - A_gt, ord='fro') 363 | 364 | 365 | def _calc_rmse(y_true, y_pred): 366 | """Calculate per-spot RMSE between ground-truth & predicted proportions""" 367 | assert y_true.shape == y_pred.shape, "proportion matrices need to be the same shape to calculate RMSE" 368 | n_cts = y_true.shape[1] 369 | rmse = np.sqrt(((y_true.values-y_pred.values)**2).sum(1) / n_cts) 370 | return rmse 371 | 372 | 373 | def bootstrap_dists(corr_df, corr_gt_df, n_iter=1000, size=10): 374 | """ 375 | Calculate the avg. 
distance to ground-truth (sub)-matrix based on random subsampling 376 | """ 377 | if size is None: 378 | size = corr_df.shape[0] 379 | n = min(size, corr_df.shape[0]) 380 | labels = corr_df.columns 381 | dists = np.zeros(n_iter) 382 | 383 | for i in range(n_iter): 384 | lbl = np.random.choice(corr_df.columns, n) 385 | A = corr_df.loc[lbl, lbl].values 386 | A_gt = corr_gt_df.loc[lbl, lbl].values 387 | dists[i] = _dist2gt(A, A_gt) 388 | 389 | return dists 390 | 391 | 392 | def disp_rmse(y_true, y_preds, labels, title=None, return_rmse=False): 393 | """ 394 | Boxplot of per-spot RMSEs for each prediction 395 | """ 396 | n_spots, n_cts = y_true.shape 397 | rmses = np.array([ 398 | np.sqrt(((y_true.values-y_pred.values)**2).sum(1) / n_cts) 399 | for y_pred in y_preds 400 | ]) 401 | 402 | lbls = np.repeat(labels, n_spots) 403 | df = pd.DataFrame({ 404 | 'RMSE': rmses.flatten(), 405 | 'Method': lbls 406 | }) 407 | plt.figure(figsize=(10, 6)) 408 | g = sns.boxplot(x='Method', y='RMSE', data=df) 409 | g.set_xticklabels(labels, rotation=60) 410 | plt.suptitle(title) 411 | plt.show() 412 | 413 | return rmses if return_rmse else None 414 | 415 | 416 | def disp_corr( 417 | y_true, y_pred, 418 | outdir=None, 419 | figsize=(3.2, 3.2), 420 | fontsize=5, 421 | title=None, 422 | filename=None, 423 | savefig=False, 424 | format='png', 425 | return_corr=False 426 | ): 427 | """ 428 | Calculate & plot correlation of cell proportion (or absolute cell abundance) 429 | between ground-truth & predictions (both [S x F]) 430 | """ 431 | 432 | assert y_true.shape[0] == y_pred.shape[0], 'Inconsistent sample sizes between ground-truth & prediction' 433 | if savefig: 434 | assert format in ('png', 'eps', 'svg'), "Invalid saving format" 435 | 436 | v1 = y_true.values 437 | v2 = y_pred.values 438 | 439 | n_factor1, n_factor2 = v1.shape[1], v2.shape[1] 440 | corr = np.zeros((n_factor1, n_factor2)) 441 | gt_corr = y_true.corr().values 442 | 443 | for i in range(n_factor1): 444 | for j in range(n_factor2): 445 | corr[i, j], _ = np.round(pearsonr(v1[:, i], v2[:, j]), 3) 446 | 447 | fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=300) 448 | ax = sns.heatmap( 449 | corr, annot=True, 450 | cmap='RdBu_r', vmin=-1, vmax=1, 451 | annot_kws={"fontsize": fontsize}, 452 | cbar_kws={'label': 'Cell type proportion corr.'}, 453 | ax=ax 454 | ) 455 | 456 | ax.set_xticks(np.arange(n_factor2) + 0.5) 457 | ax.set_yticks(np.arange(n_factor1) + 0.5) 458 | ax.set_xticklabels(y_pred.columns, rotation=90) 459 | ax.set_yticklabels(y_true.columns, rotation=0) 460 | ax.set_xlabel('Estimated proportion') 461 | ax.set_ylabel('Ground truth proportion') 462 | 463 | if title is not None: 464 | # ax.set_title(title+'\n'+'Distance = %.3f' % (dist2identity(corr))) 465 | ax.set_title(title + '\n' + 'Distance = %.3f' % (_calc_rmse(y_true, y_pred).mean())) 466 | 467 | for item in (ax.get_xticklabels() + ax.get_yticklabels()): 468 | item.set_fontsize(12) 469 | if savefig and (outdir is not None and filename is not None): 470 | if not os.path.exists(outdir): 471 | os.makedirs(outdir) 472 | fig.savefig(os.path.join(outdir, filename + '.' 
+ format), bbox_inches='tight', format=format) 473 | plt.show() 474 | 475 | return corr if return_corr else None 476 | 477 | 478 | def disp_prop_scatter( 479 | y_true, y_pred, 480 | outdir=None, 481 | filename=None, 482 | savefig=False, 483 | format='png' 484 | ): 485 | """ 486 | Scatter plot of spot-wise proportion between ground-truth & predictions 487 | """ 488 | assert y_true.shape == y_pred.shape, 'Inconsistent dimension between ground-truth & prediction' 489 | if savefig: 490 | assert format == 'png' or format == 'eps' or format == 'svg', "Invalid saving format" 491 | 492 | n_factors = y_true.shape[1] 493 | y_true_vals = y_true.values 494 | y_pred_vals = y_pred.values 495 | ncols = int(np.ceil(n_factors / 2)) 496 | 497 | fig, (ax1, ax2) = plt.subplots(2, ncols, figsize=(2 * ncols, 4.4), dpi=300) 498 | 499 | for i in range(n_factors): 500 | v1 = y_true_vals[:, i] 501 | v2 = y_pred_vals[:, i] 502 | r2 = r2_score(v1, v2) 503 | 504 | v_stacked = np.vstack([v1, v2]) 505 | den = gaussian_kde(v_stacked)(v_stacked) 506 | 507 | ax = ax1[i] if i < ncols else ax2[i % ncols] 508 | ax.scatter(v1, v2, c=den, s=.2, cmap='turbo', vmax=den.max() / 3) 509 | 510 | ax.set_aspect('equal') 511 | ax.spines['right'].set_visible(False) 512 | ax.spines['top'].set_visible(False) 513 | ax.axis('equal') 514 | 515 | # Only show ticks on the left and bottom spines 516 | ax.yaxis.set_ticks_position('left') 517 | ax.xaxis.set_ticks_position('bottom') 518 | 519 | ax.set_title(y_pred.columns[i]) 520 | ax.annotate(r"$R^2$ = {:.3f}".format(r2), (0, 1), fontsize=8) 521 | 522 | ax.set_xlim([-0.1, 1.1]) 523 | ax.set_ylim([-0.1, 1.1]) 524 | ax.set_xticks(np.arange(0, 1.1, 0.5)) 525 | ax.set_yticks(np.arange(0, 1.1, 0.5)) 526 | 527 | ax.set_xlabel('Ground truth proportions') 528 | ax.set_ylabel('Predicted proportions') 529 | 530 | plt.tight_layout() 531 | if savefig and (outdir is not None and filename is not None): 532 | if not os.path.exists(outdir): 533 | os.makedirs(outdir) 534 | fig.savefig(os.path.join(outdir, filename + '.' 
+ format), bbox_inches='tight', format=format) 535 | 536 | plt.show() 537 | 538 | -------------------------------------------------------------------------------- /starfysh/post_analysis.py: -------------------------------------------------------------------------------- 1 | import json 2 | import scanpy as sc 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import pandas as pd 6 | import os 7 | import networkx as nx 8 | import seaborn as sns 9 | from starfysh import utils 10 | import umap 11 | from scipy.stats import pearsonr, gaussian_kde 12 | from sklearn.neighbors import KNeighborsRegressor 13 | 14 | 15 | 16 | def get_z_umap(qz_m): 17 | fit = umap.UMAP(n_neighbors=45, min_dist=0.5) 18 | u = fit.fit_transform(qz_m) 19 | return u 20 | 21 | 22 | def plot_type_all(model, adata, proportions, figsize=(4, 4)): 23 | u = get_z_umap(adata.obsm['qz_m']) 24 | qc_m = adata.obsm["qc_m"] 25 | group_c = np.argmax(qc_m,axis=1) 26 | 27 | fig, ax = plt.subplots(figsize=figsize, dpi=300) 28 | cmaps = ['Blues','Greens','Reds','Oranges','Purples'] 29 | for i in range(proportions.shape[1]): 30 | plt.scatter(u[group_c==i,0],u[group_c==i,1],s=1,c = qc_m[group_c==i,i], cmap=cmaps[i]) 31 | 32 | # project the model's u on the umap 33 | knr = KNeighborsRegressor(10) 34 | knr.fit(adata.obsm["qz_m"], u) 35 | qu_umap = knr.predict(adata.uns['qu']) 36 | ax.scatter(*qu_umap.T, c='yellow', edgecolors='black', 37 | s=np.exp(model.qs_logm.cpu().detach()).sum(1)**1/2) 38 | 39 | plt.legend(proportions.columns,loc='right', bbox_to_anchor=(2.2,0.5),) 40 | plt.axis('off') 41 | return fig, ax 42 | 43 | 44 | def get_corr_map(inference_outputs, proportions): 45 | qc_m_n = inference_outputs["qc_m"].detach().cpu().numpy() 46 | corr_map_qcm = np.zeros([qc_m_n.shape[1],qc_m_n.shape[1]]) 47 | 48 | for i in range(corr_map_qcm.shape[0]): 49 | for j in range(corr_map_qcm.shape[0]): 50 | corr_map_qcm[i, j], _ = pearsonr(qc_m_n[:,i], proportions.iloc[:, j]) 51 | 52 | 53 | plt.figure(dpi=300,figsize=(3.2,3.2)) 54 | ax = sns.heatmap(corr_map_qcm.T, annot=True, 55 | cmap='RdBu_r',vmax=1,vmin=-1, 56 | cbar_kws={'label': 'Cell type proportion corr.'} 57 | ) 58 | plt.xticks(np.array(range(qc_m_n.shape[1]))+0.5,labels=proportions.columns,rotation=90) 59 | plt.yticks(np.array(range(qc_m_n.shape[1]))+0.5,labels=proportions.columns,rotation=0) 60 | plt.xlabel('Estimated proportion') 61 | plt.ylabel('Ground truth proportion') 62 | 63 | 64 | 65 | def display_reconst( 66 | df_true, 67 | df_pred, 68 | density=False, 69 | marker_genes=None, 70 | sample_rate=0.1, 71 | size=(3, 3), 72 | spot_size=1, 73 | title=None, 74 | x_label='', 75 | y_label='', 76 | x_min=0, 77 | x_max=10, 78 | y_min=0, 79 | y_max=10, 80 | ): 81 | """ 82 | Scatter plot - raw gexp vs. 
reconstructed gexp 83 | """ 84 | assert 0 < sample_rate <= 1, \ 85 | "Invalid downsampling rate for reconstruct scatter plot: {}".format(sample_rate) 86 | 87 | if marker_genes is not None: 88 | marker_genes = set(marker_genes) 89 | 90 | df_true_sample = df_true.sample(frac=sample_rate, random_state=0) 91 | df_pred_sample = df_pred.loc[df_true_sample.index] 92 | 93 | plt.rcParams["figure.figsize"] = size 94 | plt.figure(dpi=300) 95 | ax = plt.gca() 96 | 97 | xx = df_true_sample.T.to_numpy().flatten() 98 | yy = df_pred_sample.T.to_numpy().flatten() 99 | 100 | if density: 101 | for gene in df_true_sample.columns: 102 | try: 103 | gene_true = df_true_sample[gene].values 104 | gene_pred = df_pred_sample[gene].values 105 | gexp_stacked = np.vstack([df_true_sample[gene].values, df_pred_sample[gene].values]) 106 | 107 | z = gaussian_kde(gexp_stacked)(gexp_stacked) 108 | ax.scatter(gene_true, gene_pred, c=z, s=spot_size, alpha=0.5) 109 | except np.linalg.LinAlgError as e: 110 | pass 111 | 112 | elif marker_genes is not None: 113 | color_dict = {True: 'red', False: 'green'} 114 | gene_colors = np.vectorize( 115 | lambda x: color_dict[x in marker_genes] 116 | )(df_true_sample.columns) 117 | colors = np.repeat(gene_colors, df_true_sample.shape[0]) 118 | 119 | ax.scatter(xx, yy, c=colors, s=spot_size, alpha=0.5) 120 | 121 | else: 122 | ax.scatter(xx, yy, s=spot_size, alpha=0.5) 123 | 124 | min_val = min(xx.min(), yy.min()) 125 | max_val = max(xx.max(), yy.max()) 126 | #ax.set_xlim(min_val, 400) 127 | ax.set_xlim(x_min, x_max) 128 | ax.set_ylim(y_min, y_max) 129 | #ax.set_ylim(min_val, 400) 130 | 131 | plt.suptitle(title) 132 | plt.xlabel(x_label) 133 | plt.ylabel(y_label) 134 | # Hide the right and top spines 135 | ax.spines['right'].set_visible(False) 136 | ax.spines['top'].set_visible(False) 137 | ax.axis('equal') 138 | # Only show ticks on the left and bottom spines 139 | ax.yaxis.set_ticks_position('left') 140 | ax.xaxis.set_ticks_position('bottom') 141 | 142 | plt.show() 143 | 144 | 145 | def gene_mean_vs_inferred_prop(inference_outputs, visium_args,idx): 146 | 147 | sig_mean_n_df = pd.DataFrame( 148 | np.array(visium_args.sig_mean_norm)/(np.sum(np.array(visium_args.sig_mean_norm),axis=1,keepdims=True)+1e-5), 149 | columns=visium_args.sig_mean_norm.columns, 150 | index=visium_args.sig_mean_norm.index 151 | ) 152 | 153 | qc_m = inference_outputs["qc_m"].detach().cpu().numpy() 154 | 155 | figs,ax=plt.subplots(1,1,dpi=300,figsize=(2,2)) 156 | 157 | v1 = sig_mean_n_df.iloc[:,idx].values 158 | v2 = qc_m[:,idx] 159 | 160 | v_stacked = np.vstack([v1, v2]) 161 | den = gaussian_kde(v_stacked)(v_stacked) 162 | 163 | ax.scatter(v1,v2,c=den,s=1,cmap='jet',vmax=den.max()/3) 164 | 165 | ax.set_aspect('equal') 166 | ax.spines['right'].set_visible(False) 167 | ax.spines['top'].set_visible(False) 168 | ax.axis('equal') 169 | # Only show ticks on the left and bottom spines 170 | ax.yaxis.set_ticks_position('left') 171 | ax.xaxis.set_ticks_position('bottom') 172 | plt.title(visium_args.gene_sig.columns[idx]) 173 | plt.xlim([v1.min()-0.1,v1.max()+0.1]) 174 | plt.ylim([v2.min()-0.1,v2.max()+0.1]) 175 | #plt.xticks(np.arange(0,1.1,0.5)) 176 | #plt.yticks(np.arange(0,1.1,0.5)) 177 | plt.xlabel('Gene signature mean') 178 | plt.ylabel('Predicted proportions') 179 | 180 | 181 | def plot_stacked_prop(results, category_names): 182 | """ 183 | Parameters 184 | ---------- 185 | results : dict 186 | A mapping from question labels to a list of answers per category. 
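Several plotting helpers here and in `plot_utils.py` (`display_reconst`, `gene_mean_vs_inferred_prop`, `disp_prop_scatter`) color scatter points by local density with `gaussian_kde`. The pattern in isolation, on random data:

```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde

x, y = np.random.rand(2, 500)
xy = np.vstack([x, y])
density = gaussian_kde(xy)(xy)  # KDE evaluated at each sample point

fig, ax = plt.subplots(figsize=(3, 3), dpi=150)
ax.scatter(x, y, c=density, s=2, cmap='turbo', vmax=density.max() / 3)
plt.show()
```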
187 | It is assumed all lists contain the same number of entries and that 188 | it matches the length of *category_names*. 189 | 190 | category_names : list of str 191 | The category labels. 192 | """ 193 | labels = list(results.keys()) 194 | data = np.array(list(results.values())) 195 | data_cum = data.cumsum(axis=1) 196 | category_colors = plt.get_cmap('rainbow')( 197 | np.linspace(0.15, 0.85, data.shape[1])) 198 | #category_colors = np.array(['b','g','r','oragne','purple']) 199 | fig, ax = plt.subplots(figsize=(2.5,1.8),dpi=300) 200 | ax.invert_yaxis() 201 | ax.xaxis.set_visible(False) 202 | ax.set_xlim(0, np.sum(data, axis=1).max()) 203 | 204 | for i, (colname, color) in enumerate(zip(category_names, category_colors)): 205 | widths = data[:, i] 206 | starts = data_cum[:, i] - widths 207 | ax.barh(labels, widths, left=starts, height=0.6,label=colname, color=color) 208 | xcenters = starts + widths / 2 209 | 210 | r, g, b, _ = color 211 | text_color = 'black' #if r * g * b < 0.5 else 'darkgrey' 212 | #for y, (x, c) in enumerate(zip(xcenters, widths)): 213 | # ax.text(x, y, str(round(c,2)), ha='center', va='center', 214 | # color=text_color) 215 | #ax.legend(ncol=len(category_names), bbox_to_anchor=(0, 1), 216 | # loc='right', fontsize='small') 217 | ax.legend(category_names,loc='right',bbox_to_anchor=(2, 0.5)) 218 | return fig, ax 219 | 220 | 221 | def plot_density(results, category_names): 222 | """ 223 | Parameters 224 | ---------- 225 | results : dict 226 | A mapping from question labels to a list of answers per category. 227 | It is assumed all lists contain the same number of entries and that 228 | it matches the length of *category_names*. 229 | 230 | category_names : list of str 231 | The category labels. 232 | """ 233 | labels = list(results.keys()) 234 | data = np.array(list(results.values())) 235 | category_colors = plt.get_cmap('RdBu_r')( 236 | np.linspace(0.15, 0.85, data.shape[1])) 237 | fig, ax = plt.subplots(figsize=(2.5,1.8),dpi=300) 238 | ax.invert_yaxis() 239 | #ax.xaxis.set_visible(False) 240 | ax.set_xlim(0, np.sum(data, axis=1).max()) 241 | 242 | for i, (colname, color) in enumerate(zip(category_names, category_colors)): 243 | widths = data[:, i] 244 | starts = 0 245 | ax.barh(labels, widths, left=starts, height=0.6,label=colname, color=color) 246 | 247 | r, g, b, _ = color 248 | return fig, ax 249 | 250 | 251 | def get_factor_dist(sample_ids,file_path): 252 | qc_p_dist = {} 253 | # Opening JSON file 254 | for sample_id in sample_ids: 255 | print(sample_id) 256 | f = open(file_path+sample_id+'_factor.json','r') 257 | data = json.load(f) 258 | qc_p_dist[sample_id] = data['qc_m'] 259 | f.close() 260 | return qc_p_dist 261 | 262 | 263 | def get_adata(sample_ids, data_folder): 264 | adata_sample_all = [] 265 | map_info_all = [] 266 | adata_image_all = [] 267 | for sample_id in sample_ids: 268 | print('loading...',sample_id) 269 | if (sample_id.startswith('MBC'))|(sample_id.startswith('CT')): 270 | 271 | adata_sample = sc.read_visium(path=os.path.join(data_folder, sample_id),library_id = sample_id) 272 | adata_sample.var_names_make_unique() 273 | adata_sample.obs['sample']=sample_id 274 | adata_sample.obs['sample_type']='MBC' 275 | #adata_sample.obs_names = adata_sample.obs_names+'-'+sample_id 276 | #adata_sample.obs_names = adata_sample.obs_names+'_'+sample_id 277 | if '_index' in adata_sample.var.columns: 278 | adata_sample.var_names=adata_sample.var['_index'] 279 | 280 | else: 281 | adata_sample = sc.read_h5ad(os.path.join(data_folder,sample_id, sample_id+'.h5ad')) 282 
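The two stacked-bar helpers above take a dict mapping row labels to one value per category. A minimal sketch of the expected input (hub names and proportions are invented):

```python
results = {
    'Hub 0': [0.42, 0.18, 0.25, 0.10, 0.05],
    'Hub 1': [0.10, 0.55, 0.12, 0.13, 0.10],
    'Hub 2': [0.05, 0.07, 0.61, 0.17, 0.10],
}
category_names = ['Tumor', 'T cells', 'B cells', 'Macrophages', 'Stroma']

fig, ax = plot_stacked_prop(results, category_names)  # one horizontal bar per hub
```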
| adata_sample.var_names_make_unique() 283 | adata_sample.obs['sample']=sample_id 284 | adata_sample.obs['sample_type']='TNBC' 285 | #adata_sample.obs_names = adata_sample.obs_names+'-'+sample_id 286 | #adata_sample.obs_names = adata_sample.obs_names+'_'+sample_id 287 | if '_index' in adata_sample.var.columns: 288 | adata_sample.var_names=adata_sample.var['_index'] 289 | 290 | if data_folder =='simu_data': 291 | map_info = utils.get_simu_map_info(umap_df) 292 | else: 293 | adata_image,map_info = utils.preprocess_img(data_folder,sample_id,adata_sample.obs.index,hchannal=False) 294 | 295 | adata_sample.obs_names = adata_sample.obs_names+'-'+sample_id 296 | map_info.index = map_info.index+'-'+sample_id 297 | adata_sample_all.append(adata_sample) 298 | map_info_all.append(map_info) 299 | adata_image_all.append(adata_image) 300 | return adata_sample_all,map_info_all,adata_image_all 301 | 302 | def get_Moran(W, X): 303 | N = W.shape[0] 304 | term1 = N / W.sum().sum() 305 | x_m = X.mean() 306 | term2 = np.matmul(np.matmul(np.diag(X-x_m),W),np.diag(X-x_m)) 307 | term3 = term2.sum().sum() 308 | term4 = ((X-x_m)**2).sum() 309 | term5 = term1 * term3 / term4 310 | return term5 311 | 312 | def get_LISA(W, X): 313 | lisa_score = np.zeros(X.shape) 314 | N = W.shape[0] 315 | x_m = X.mean() 316 | term1 = X-x_m 317 | term2 = ((X-x_m)**2).sum() 318 | for i in range(term1.shape[0]): 319 | #term3 = np.zeros(X.shape) 320 | term3 = (W[i,:]*(X-x_m)).sum() 321 | #for j in range(W.shape[0]): 322 | # term3[j]=W[i,j]*(X[j]-x_m) 323 | #term3 = term3.sum() 324 | lisa_score[i] = np.sign(X[i]-x_m) * N * (X[i]-x_m) * term3 / term2 325 | #lisa_score[i] = N * (X[i]-x_m) * term3 / term2 326 | 327 | return lisa_score 328 | 329 | def get_SCI(W, X, Y): 330 | 331 | N = W.shape[0] 332 | term1 = N / (2*W.sum().sum()) 333 | 334 | x_m = X.mean() 335 | y_m = Y.mean() 336 | term2 = np.matmul(np.matmul(np.diag(X-x_m),W),np.diag(Y-y_m)) 337 | term3 = term2.sum().sum() 338 | 339 | term4 = np.sqrt(((X-x_m)**2).sum()) * np.sqrt(((Y-y_m)**2).sum()) 340 | 341 | term5 = term1 * term3 / term4 342 | 343 | return term5 344 | 345 | def get_cormtx(sample_id, hub_num ): 346 | # TODO: get_cormtx 347 | prop_i = proportions_df[ids_df['source']==sample_id][cluster_df['cluster']==hub_num] 348 | loc_i = np.array(map_info_all.loc[prop_i.index].loc[:,['array_col','array_row',]]) 349 | W = np.zeros([loc_i.shape[0],loc_i.shape[0]]) 350 | 351 | cor_matrix = np.zeros([gene_sig.shape[1],gene_sig.shape[1]]) 352 | for i in range(loc_i.shape[0]): 353 | for j in range(i,loc_i.shape[0]): 354 | if np.sqrt((loc_i[i,0]-loc_i[j,0])**2+(loc_i[i,1]-loc_i[j,1])**2)<=3: 355 | W[i,j] = 1 356 | W[j,i] = 1 357 | #indices = vor.regions[vor.point_region[i]] 358 | #neighbor_i = np.concatenate([vor.ridge_points[np.where(vor.ridge_points[:,0] == i)],np.flip(vor.ridge_points[np.where(vor.ridge_points[:,1] == i)],axis=1)],axis=0)[:,1] 359 | #W[i,neighbor_i]=1 360 | #W[neighbor_i,i]=1 361 | print('spots in hub ',hub_num, '= ',prop_i.shape[0]) 362 | if prop_i.shape[0]>1: 363 | for i in range(gene_sig.shape[1]): 364 | for j in range(i+1,gene_sig.shape[1]): 365 | cor_matrix[i,j]=get_SCI(W, np.array(prop_i.iloc[:,i]), np.array(prop_i.iloc[:,j])) 366 | cor_matrix[j,i]=cor_matrix[i,j] 367 | return cor_matrix 368 | 369 | def get_hub_cormtx(sample_ids, hub_num): 370 | cor_matrix = np.zeros([gene_sig.shape[1],gene_sig.shape[1]]) 371 | for sample_id in sample_ids: 372 | print(sample_id) 373 | cor_matrix = cor_matrix + get_cormtx(sample_id = sample_id, hub_num=hub_num) 374 | #print(cor_matrix) 
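`get_Moran` above is the standard global Moran's I, $I = \frac{N}{\sum_{ij} w_{ij}} \cdot \frac{\sum_{ij} w_{ij}(x_i-\bar{x})(x_j-\bar{x})}{\sum_i (x_i-\bar{x})^2}$, with `get_LISA` its local decomposition and `get_SCI` the bivariate analogue used for hub correlations. A tiny sanity check on a 4-spot chain (toy values):

```python
import numpy as np

# Four spots in a row; adjacency = immediate neighbors only
W = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)

x_smooth = np.array([1.0, 2.0, 3.0, 4.0])  # neighbors similar
x_rough = np.array([1.0, 4.0, 1.0, 4.0])   # neighbors alternate

print(get_Moran(W, x_smooth))  # ~0.33: positive spatial autocorrelation
print(get_Moran(W, x_rough))   # -1.0: neighbors anti-correlated
```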
375 | cor_matrix = cor_matrix/len(sample_ids) 376 | #cor_matrix = pd.DataFrame(cor_matrix) 377 | return cor_matrix 378 | 379 | def create_corr_network_5(G, node_size_list,corr_direction, min_correlation): 380 | ##Creates a copy of the graph 381 | H = G.copy() 382 | 383 | ##Checks all the edges and removes some based on corr_direction 384 | for stock1, stock2, weight in G.edges(data=True): 385 | #print(weight) 386 | ##if we only want to see the positive correlations we then delete the edges with weight smaller than 0 387 | if corr_direction == "positive": 388 | ####it adds a minimum value for correlation. 389 | ####If correlation weaker than the min, then it deletes the edge 390 | if weight["weight"] <0 or weight["weight"] < min_correlation: 391 | H.remove_edge(stock1, stock2) 392 | ##this part runs if the corr_direction is negative and removes edges with weights equal or largen than 0 393 | else: 394 | ####it adds a minimum value for correlation. 395 | ####If correlation weaker than the min, then it deletes the edge 396 | if weight["weight"] >=0 or weight["weight"] > min_correlation: 397 | H.remove_edge(stock1, stock2) 398 | 399 | 400 | #crates a list for edges and for the weights 401 | edges,weights = zip(*nx.get_edge_attributes(H,'weight').items()) 402 | 403 | 404 | ### increases the value of weights, so that they are more visible in the graph 405 | #weights = tuple([(0.5+abs(x))**1 for x in weights]) 406 | weights = tuple([x*2 for x in weights]) 407 | #print(len(weights)) 408 | #####calculates the degree of each node 409 | d = nx.degree(H) 410 | #print(d) 411 | #####creates list of nodes and a list their degrees that will be used later for their sizes 412 | nodelist, node_sizes = zip(*dict(d).items()) 413 | #import sys, networkx as nx, matplotlib.pyplot as plt 414 | 415 | # Create a list of 10 nodes numbered [0, 9] 416 | #nodes = range(10) 417 | node_sizes = [] 418 | labels = {} 419 | for n in nodelist: 420 | node_sizes.append( node_size_list[n] ) 421 | labels[n] = 1 * n 422 | 423 | # Node sizes: [0, 100, 200, 300, 400, 500, 600, 700, 800, 900] 424 | 425 | # Connect each node to its successor 426 | #edges = [ (i, i+1) for i in range(len(nodes)-1) ] 427 | 428 | # Create the graph and draw it with the node labels 429 | #g = nx.Graph() 430 | #g.add_nodes_from(nodes) 431 | #g.add_edges_from(edges) 432 | 433 | #nx.draw_random(g, node_size = node_sizes, labels=labels, with_labels=True) 434 | #plt.show() 435 | 436 | #positions 437 | positions=nx.circular_layout(H) 438 | #print(positions) 439 | 440 | #Figure size 441 | plt.figure(figsize=(2,2),dpi=500) 442 | 443 | #draws nodes, 444 | #options = {"edgecolors": "tab:gray", "alpha": 0.9} 445 | nx.draw_networkx_nodes(H,positions, 446 | #node_color='#DA70D6', 447 | nodelist=nodelist, 448 | #####the node size will be now based on its degree 449 | node_color=_colors['leiden_colors'][hub_num],# 'lightgreen',#pink, 'lightblue',#'#FFACB7',lightgreen B19CD9。#FFACB7 brown 450 | alpha = 0.8, 451 | node_size=tuple([x**1 for x in node_sizes]), 452 | #**options 453 | ) 454 | 455 | #Styling for labels 456 | nx.draw_networkx_labels(H, positions, font_size=4, 457 | font_family='sans-serif') 458 | 459 | ###edge colors based on weight direction 460 | if corr_direction == "positive": 461 | edge_colour = plt.cm.GnBu#PiYG_r#RdBu_r#Spectral_r#GnBu#RdPu#PuRd#Blues#PuRd#GnBu OrRd 462 | else: 463 | edge_colour = plt.cm.PuRd 464 | 465 | #draws the edges 466 | print(min(weights)) 467 | print(max(weights)) 468 | 469 | nx.draw_networkx_edges(H, positions, 
edgelist=edges,style='solid', 470 | ###adds width=weights and edge_color = weights 471 | ###so that edges are based on the weight parameter 472 | ###edge_cmap is for the color scale based on the weight 473 | ### edge_vmin and edge_vmax assign the min and max weights for the width 474 | width=weights, edge_color = weights, edge_cmap = edge_colour, 475 | edge_vmin = 0,#min(weights),#0.55,#min(weights), 476 | edge_vmax= 0.7,#max(weights),#0.6,#max(weights) 477 | #edge_vmin = min(weights),#0.55,#min(weights), 478 | #edge_vmax= max(weights),#0.6,#max(weights) 479 | ) 480 | 481 | # displays the graph without axis 482 | plt.axis('off') 483 | #plt.legend(['r','r']) 484 | #saves image 485 | #plt.savefig("part5" + corr_direction + ".png", format="PNG") 486 | #plt.show() 487 | -------------------------------------------------------------------------------- /starfysh/starfysh.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import os 6 | import random 7 | 8 | os.environ['CUDA_LAUNCH_BLOCKING'] = "1" 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from torchvision import transforms 14 | from torchvision.utils import make_grid 15 | from torch.distributions import constraints, Distribution, Normal, Gamma, Poisson, Dirichlet 16 | from torch.distributions import kl_divergence as kl 17 | 18 | # Module import 19 | from starfysh import LOGGER 20 | from .post_analysis import get_z_umap 21 | os.environ['CUDA_LAUNCH_BLOCKING'] = "1" 22 | random.seed(0) 23 | np.random.seed(0) 24 | 25 | 26 | # TODO: 27 | # inherit `AVAE` (expr model) w/ `AVAE_PoE` (expr + histology model), update latest PoE model 28 | class AVAE(nn.Module): 29 | """ 30 | Model design 31 | p(x|z)=f(z) 32 | p(z|x)~N(0,1) 33 | q(z|x)~g(x) 34 | """ 35 | 36 | def __init__( 37 | self, 38 | adata, 39 | gene_sig, 40 | win_loglib, 41 | alpha_mul=50, 42 | seed=0, 43 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu') 44 | ) -> None: 45 | """ 46 | Auxiliary Variational AutoEncoder (AVAE) - Core model for 47 | spatial deconvolution without H&E image integration 48 | 49 | Paramters 50 | --------- 51 | adata : sc.AnnData 52 | ST raw expression count (dim: [S, G]) 53 | 54 | gene_sig : pd.DataFrame 55 | Normalized avg. signature expressions for each annotated cell type 56 | 57 | win_loglib : float 58 | Log-library size smoothed with neighboring spots 59 | 60 | alpha_mul : float (default=50) 61 | Multiplier of Dirichlet concentration parameter to control 62 | signature prior's confidence 63 | """ 64 | super().__init__() 65 | torch.manual_seed(seed) 66 | 67 | self.win_loglib=torch.Tensor(win_loglib) 68 | 69 | self.c_in = adata.shape[1] # c_in : Num. input features (# input genes) 70 | self.c_bn = 10 # c_bn : latent number, numbers of bottle-necks 71 | self.c_hidden = 256 72 | self.c_kn = gene_sig.shape[1] 73 | self.eps = 1e-5 # for r.v. 
w/ numerical constraints 74 | self.device = device 75 | 76 | self.alpha = torch.ones(self.c_kn) * alpha_mul 77 | self.alpha = self.alpha.to(device) 78 | 79 | self.qs_logm = torch.nn.Parameter(torch.zeros(self.c_kn, self.c_bn), requires_grad=True) 80 | self.qu_m = torch.nn.Parameter(torch.randn(self.c_kn, self.c_bn), requires_grad=True) 81 | self.qu_logv = torch.nn.Parameter(torch.zeros(self.c_kn, self.c_bn), requires_grad=True) 82 | 83 | self.c_enc = nn.Sequential( 84 | nn.Linear(self.c_in, self.c_hidden, bias=True), 85 | nn.BatchNorm1d(self.c_hidden, momentum=0.01, eps=0.001), 86 | nn.ReLU() 87 | ) 88 | 89 | self.c_enc_m = nn.Sequential( 90 | nn.Linear(self.c_hidden, self.c_kn, bias=True), 91 | nn.BatchNorm1d(self.c_kn, momentum=0.01, eps=0.001), 92 | nn.Softmax(dim=-1) 93 | ) 94 | 95 | self.l_enc = nn.Sequential( 96 | nn.Linear(self.c_in, self.c_hidden, bias=True), 97 | nn.BatchNorm1d(self.c_hidden, momentum=0.01, eps=0.001), 98 | nn.ReLU() 99 | ) 100 | 101 | self.l_enc_m = nn.Linear(self.c_hidden, 1) 102 | self.l_enc_logv = nn.Linear(self.c_hidden, 1) 103 | 104 | # neural network f1 to get the z, p(z|x), f1(x,\phi_1)=[z_m,torch.exp(z_logv)] 105 | self.z_enc = nn.Sequential( 106 | # nn.Linear(self.c_in+self.c_kn, self.c_hidden, bias=True), 107 | nn.Linear(self.c_in, self.c_hidden, bias=True), 108 | nn.BatchNorm1d(self.c_hidden, momentum=0.01, eps=0.001), 109 | nn.ReLU(), 110 | ) 111 | 112 | self.z_enc_m = nn.Linear(self.c_hidden, self.c_bn * self.c_kn) 113 | self.z_enc_logv = nn.Linear(self.c_hidden, self.c_bn * self.c_kn) 114 | 115 | # gene dispersion 116 | self._px_r = torch.nn.Parameter(torch.randn(self.c_in), requires_grad=True) 117 | 118 | # neural network g to get the x_m and x_v, p(x|z), g(z,\phi_3)=[x_m,x_v] 119 | self.px_hidden_decoder = nn.Sequential( 120 | nn.Linear(self.c_bn, self.c_hidden, bias=True), 121 | nn.ReLU(), 122 | ) 123 | self.px_scale_decoder = nn.Sequential( 124 | nn.Linear(self.c_hidden, self.c_in), 125 | nn.Softmax(dim=-1) 126 | ) 127 | 128 | def reparameterize(self, mu, log_var): 129 | """ 130 | :param mu: mean from the encoder's latent space 131 | :param log_var: log variance from the encoder's latent space 132 | """ 133 | std = torch.exp(0.5 * log_var) # standard deviation 134 | eps = torch.randn_like(std) # `randn_like` as we need the same size 135 | sample = mu + (eps * std) # sampling 136 | return sample 137 | 138 | def inference(self, x): 139 | x_n = torch.log1p(x) 140 | 141 | hidden = self.l_enc(x_n) 142 | ql_m = self.l_enc_m(hidden) 143 | ql_logv = self.l_enc_logv(hidden) 144 | ql = self.reparameterize(ql_m, ql_logv) 145 | 146 | hidden = self.c_enc(x_n) 147 | qc_m = self.c_enc_m(hidden) 148 | qc = Dirichlet(qc_m * self.alpha + self.eps).rsample()[:,:,None] 149 | 150 | hidden = self.z_enc(x_n) 151 | qz_m_ct = self.z_enc_m(hidden).reshape([x_n.shape[0], self.c_kn, self.c_bn]) 152 | qz_m_ct = qc * qz_m_ct 153 | qz_m = qz_m_ct.sum(axis=1) 154 | 155 | qz_logv_ct = self.z_enc_logv(hidden).reshape([x_n.shape[0], self.c_kn, self.c_bn]) 156 | qz_logv_ct = qc * qz_logv_ct 157 | qz_logv = qz_logv_ct.sum(axis=1) 158 | qz = self.reparameterize(qz_m, qz_logv) 159 | 160 | qu = self.reparameterize(self.qu_m, self.qu_logv) 161 | 162 | return dict( 163 | # q(u) 164 | qu=qu, 165 | 166 | # q(c | x) 167 | qc_m=qc_m, 168 | qc=qc, 169 | 170 | # q(z | c, x) 171 | qz_m=qz_m, 172 | qz_m_ct=qz_m_ct, 173 | qz_logv=qz_logv, 174 | qz_logv_ct=qz_logv_ct, 175 | qz=qz, 176 | 177 | # q(l | x) 178 | ql_m=ql_m, 179 | ql_logv=ql_logv, 180 | ql=ql, 181 | ) 182 | 183 | def generative( 184 | 
self, 185 | inference_outputs, 186 | xs_k, 187 | ): 188 | qz = inference_outputs['qz'] 189 | ql = inference_outputs['ql'] 190 | 191 | hidden = self.px_hidden_decoder(qz) 192 | px_scale = self.px_scale_decoder(hidden) 193 | px_rate = torch.exp(ql) * px_scale + self.eps 194 | pc_p = xs_k + self.eps 195 | 196 | return dict( 197 | px_rate=px_rate, 198 | px_r=self.px_r, 199 | pc_p=pc_p, 200 | xs_k=xs_k, 201 | ) 202 | 203 | def get_loss( 204 | self, 205 | generative_outputs, 206 | inference_outputs, 207 | x, 208 | library, 209 | device 210 | ): 211 | # Variational params 212 | qs_logm = self.qs_logm 213 | qu_m, qu_logv, qu = self.qu_m, self.qu_logv, inference_outputs["qu"] 214 | qc_m, qc = inference_outputs["qc_m"], inference_outputs["qc"] 215 | qz_m, qz_logv = inference_outputs["qz_m"], inference_outputs["qz_logv"] 216 | ql_m, ql_logv = inference_outputs["ql_m"], inference_outputs['ql_logv'] 217 | 218 | # p(x | z), p(c; \alpha), p(u; \sigma) 219 | px_rate = generative_outputs["px_rate"] 220 | px_r = generative_outputs["px_r"] 221 | pc_p = generative_outputs["pc_p"] 222 | 223 | pu_m = torch.zeros_like(qu_m) 224 | pu_std = torch.ones_like(qu_logv) * 10 225 | 226 | # Regularization terms 227 | kl_divergence_u = kl( 228 | Normal(qu_m, torch.exp(qu_logv / 2)), 229 | Normal(pu_m, pu_std) 230 | ).sum(dim=1).mean() 231 | 232 | kl_divergence_l = kl( 233 | Normal(ql_m, torch.exp(ql_logv / 2)), 234 | Normal(library, torch.ones_like(ql_m)) 235 | ).sum(dim=1).mean() 236 | 237 | kl_divergence_c = kl( 238 | Dirichlet(qc_m * self.alpha), 239 | Dirichlet(pc_p * self.alpha) 240 | ).mean() 241 | 242 | pz_m = (qu.unsqueeze(0) * qc).sum(axis=1) 243 | pz_std = (torch.exp(qs_logm / 2).unsqueeze(0) * qc).sum(axis=1) 244 | 245 | kl_divergence_z = kl( 246 | Normal(qz_m, torch.exp(qz_logv / 2)), 247 | Normal(pz_m, pz_std) 248 | ).sum(dim=1).mean() 249 | 250 | # Reconstruction term 251 | reconst_loss = -NegBinom(px_rate, torch.exp(px_r)).log_prob(x).sum(-1).mean() 252 | 253 | loss = reconst_loss.to(device) + \ 254 | kl_divergence_u.to(device) + \ 255 | kl_divergence_z.to(device) + \ 256 | kl_divergence_c.to(device) + \ 257 | kl_divergence_l.to(device) 258 | 259 | return (loss, 260 | reconst_loss, 261 | kl_divergence_u, 262 | kl_divergence_z, 263 | kl_divergence_c, 264 | kl_divergence_l 265 | ) 266 | 267 | @property 268 | def px_r(self): 269 | return F.softplus(self._px_r) + self.eps 270 | 271 | 272 | class AVAE_PoE(nn.Module): 273 | """ 274 | Model design: 275 | p(x|z)=f(z) 276 | p(z|x)~N(0,1) 277 | q(z|x)~g(x) 278 | """ 279 | 280 | def __init__( 281 | self, 282 | adata, 283 | gene_sig, 284 | patch_r, 285 | win_loglib, 286 | alpha_mul=50, 287 | n_img_chan=1, 288 | seed=0, 289 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu') 290 | ) -> None: 291 | """ 292 | Auxiliary Variational AutoEncoder (AVAE) with Joint H&E inference 293 | - Core model for spatial deconvolution w/ H&E image integration 294 | 295 | Paramters 296 | --------- 297 | adata : sc.AnnData 298 | ST raw expression count (dim: [S, G]) 299 | 300 | gene_sig : pd.DataFrame 301 | Signature gene sets for each annotated cell type 302 | 303 | patch_r : int 304 | Mini-patch size sampled around each spot from raw H&E image 305 | 306 | win_loglib : float 307 | Log-library size smoothed with neighboring spots 308 | 309 | alpha_mul : float (default=50) 310 | Multiplier of Dirichlet concentration parameter to control 311 | signature prior's confidence 312 | 313 | """ 314 | super().__init__() 315 | torch.manual_seed(seed) 316 | 317 | self.win_loglib = 
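For orientation, the five terms summed in `AVAE.get_loss` above form a negative ELBO. In the code's notation ($\rho$ the decoded expression scale, $\pi$ the signature prior `xs_k`, $\bar{\ell}$ the smoothed log-library, $\sigma_k = \exp(q_{s,\log m}/2)$), the objective is roughly:

$$
\mathcal{L} = -\,\mathbb{E}_q\big[\log \mathrm{NB}\big(x \mid e^{\ell}\rho(z),\, r\big)\big]
+ \mathrm{KL}\big(q(u)\,\|\,\mathcal{N}(0,\,10^2 I)\big)
+ \mathrm{KL}\Big(q(z\mid c,x)\,\Big\|\,\mathcal{N}\big(\textstyle\sum_k c_k u_k,\ \sum_k c_k \sigma_k\big)\Big)
+ \mathrm{KL}\big(\mathrm{Dir}(\alpha\, q_c)\,\|\,\mathrm{Dir}(\alpha\,\pi)\big)
+ \mathrm{KL}\big(q(\ell\mid x)\,\|\,\mathcal{N}(\bar{\ell},\,1)\big)
$$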
torch.Tensor(win_loglib) 318 | self.patch_r = patch_r 319 | self.c_in = adata.shape[1] # c_in : Num. input features (# input genes) 320 | self.nimg_chan = n_img_chan 321 | self.c_in_img = self.patch_r**2 * 4 * n_img_chan # c_in_img : (# pixels for the spot's img patch) 322 | self.c_bn = 10 # c_bn : latent number, numbers of bottleneck 323 | self.c_hidden = 256 324 | self.c_kn = gene_sig.shape[1] 325 | 326 | self.eps = 1e-5 # for r.v. w/ numerical constraints 327 | self.alpha = torch.ones(self.c_kn) * alpha_mul 328 | self.alpha = self.alpha.to(device) 329 | 330 | # --- neural nets for Expression view --- 331 | self.qs_logm = torch.nn.Parameter(torch.zeros(self.c_kn, self.c_bn), requires_grad=True) 332 | self.qu_m = torch.nn.Parameter(torch.randn(self.c_kn, self.c_bn), requires_grad=True) 333 | self.qu_logv = torch.nn.Parameter(torch.zeros(self.c_kn, self.c_bn), requires_grad=True) 334 | 335 | self.c_enc = nn.Sequential( 336 | nn.Linear(self.c_in, self.c_hidden, bias=True), 337 | nn.BatchNorm1d(self.c_hidden, momentum=0.01, eps=0.001), 338 | nn.ReLU() 339 | ) 340 | 341 | self.c_enc_m = nn.Sequential( 342 | nn.Linear(self.c_hidden, self.c_kn, bias=True), 343 | nn.BatchNorm1d(self.c_kn, momentum=0.01, eps=0.001), 344 | nn.Softmax(dim=-1) 345 | ) 346 | 347 | self.l_enc = nn.Sequential( 348 | nn.Linear(self.c_in+self.c_in_img, self.c_hidden, bias=True), 349 | nn.BatchNorm1d(self.c_hidden, momentum=0.01, eps=0.001), 350 | nn.ReLU(), 351 | ) 352 | self.l_enc_m = nn.Linear(self.c_hidden, 1) 353 | self.l_enc_logv = nn.Linear(self.c_hidden, 1) 354 | 355 | # neural network f1 to get the z, p(z|x), f1(x,\phi_1)=[z_m,torch.exp(z_logv)] 356 | self.z_enc = nn.Sequential( 357 | nn.Linear(self.c_in, self.c_hidden, bias=True), 358 | nn.BatchNorm1d(self.c_hidden, momentum=0.01, eps=0.001), 359 | nn.ReLU(), 360 | ) 361 | self.z_enc_m = nn.Linear(self.c_hidden, self.c_bn * self.c_kn) 362 | self.z_enc_logv = nn.Linear(self.c_hidden, self.c_bn * self.c_kn) 363 | 364 | # gene dispersion 365 | self._px_r = torch.nn.Parameter(torch.randn(self.c_in), requires_grad=True) 366 | 367 | # neural network g to get the x_m and x_v, p(x|z), g(z,\phi_3)=[x_m,x_v] 368 | self.z_to_hidden_decoder = nn.Sequential( 369 | nn.Linear(self.c_bn, self.c_hidden, bias=True), 370 | nn.ReLU(), 371 | ) 372 | self.px_scale_decoder = nn.Sequential( 373 | nn.Linear(self.c_hidden, self.c_in), 374 | nn.Softmax(dim=-1) 375 | ) 376 | 377 | # --- neural nets for Histology view --- 378 | # encoder paths: 379 | self.img_z_enc = nn.Sequential( 380 | nn.Linear(self.c_in_img, self.c_hidden, bias=True), 381 | nn.BatchNorm1d(self.c_hidden, momentum=0.01, eps=0.001), 382 | nn.ReLU() 383 | ) 384 | self.img_z_enc_m = nn.Linear(self.c_hidden, self.c_bn) 385 | self.img_z_enc_logv = nn.Linear(self.c_hidden, self.c_bn) 386 | 387 | # decoder paths: 388 | self.img_z_to_hidden_decoder = nn.Linear(self.c_bn, self.c_hidden, bias=True) 389 | self.py_mu_decoder = nn.Sequential( 390 | nn.Linear(self.c_hidden, self.c_in_img), 391 | nn.BatchNorm1d(self.c_in_img, momentum=0.01, eps=0.001), 392 | nn.ReLU() 393 | ) 394 | self.py_logv_decoder = nn.Sequential( 395 | nn.Linear(self.c_hidden, self.c_in_img), 396 | nn.ReLU() 397 | ) 398 | 399 | # --- PoE view --- 400 | self.z_to_hidden_poe_decoder = nn.Linear(self.c_bn, self.c_hidden, bias=True) 401 | self.px_scale_poe_decoder = nn.Sequential( 402 | nn.Linear(self.c_hidden, self.c_in), 403 | nn.ReLU() 404 | ) 405 | self._px_r_poe = torch.nn.Parameter(torch.randn(self.c_in), requires_grad=True) 406 | 407 | self.py_mu_poe_decoder = 
nn.Sequential( 408 | nn.Linear(self.c_hidden, self.c_in_img), 409 | nn.BatchNorm1d(self.c_in_img, momentum=0.01, eps=0.001), 410 | nn.ReLU() 411 | ) 412 | self.py_logv_poe_decoder = nn.Sequential( 413 | nn.Linear(self.c_hidden, self.c_in_img), 414 | nn.ReLU() 415 | ) 416 | 417 | def reparameterize(self, mu, log_var): 418 | """ 419 | :param mu: mean from the encoder's latent space 420 | :param log_var: log variance from the encoder's latent space 421 | """ 422 | std = torch.exp(0.5 * log_var) # standard deviation 423 | eps = torch.randn_like(std) # `randn_like` as we need the same size 424 | sample = mu + (eps * std) # sampling 425 | return sample 426 | 427 | def inference(self, x, y): 428 | # q(l | x) 429 | x_n = torch.log1p(x) # l is inferred from log(x) 430 | y_n = torch.log1p(y) 431 | 432 | hidden = self.l_enc(torch.concat([x_n,y_n],axis=1)) 433 | ql_m = self.l_enc_m(hidden) 434 | ql_logv = self.l_enc_logv(hidden) 435 | ql = self.reparameterize(ql_m, ql_logv) 436 | 437 | # q(c | x) 438 | hidden = self.c_enc(x_n) 439 | qc_m = self.c_enc_m(hidden) 440 | qc = Dirichlet(qc_m * self.alpha + self.eps).rsample()[:,:,None] 441 | 442 | # q(z | c, x) 443 | hidden = self.z_enc(x_n) 444 | qz_m_ct = self.z_enc_m(hidden).reshape([x_n.shape[0], self.c_kn, self.c_bn]) 445 | qz_m_ct = qc * qz_m_ct 446 | qz_m = qz_m_ct.sum(axis=1) 447 | 448 | qz_logv_ct = self.z_enc_logv(hidden).reshape([x_n.shape[0], self.c_kn, self.c_bn]) 449 | qz_logv_ct = qc * qz_logv_ct 450 | qz_logv = qz_logv_ct.sum(axis=1) 451 | qz = self.reparameterize(qz_m, qz_logv) 452 | 453 | # q(u), mean-field VI 454 | qu = self.reparameterize(self.qu_m, self.qu_logv) 455 | 456 | return dict( 457 | # q(u) 458 | qu=qu, 459 | 460 | # q(c | x) 461 | qc_m=qc_m, 462 | qc=qc, 463 | 464 | # q(z | x) 465 | qz_m=qz_m, 466 | qz_m_ct=qz_m_ct, 467 | qz_logv=qz_logv, 468 | qz_logv_ct=qz_logv_ct, 469 | qz=qz, 470 | 471 | # q(l | x) 472 | ql_m=ql_m, 473 | ql_logv=ql_logv, 474 | ql=ql, 475 | ) 476 | 477 | def generative( 478 | self, 479 | inference_outputs, 480 | xs_k, 481 | ): 482 | """ 483 | xs_k : torch.Tensor 484 | Z-normed avg. 
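The image branch treats each H&E patch as a flat vector of length `c_in_img = patch_r**2 * 4 * n_img_chan`, i.e. all pixels of the `(2r x 2r)` patches produced by the dataloaders. A quick check of that arithmetic (batch size and radius are illustrative):

```python
import torch

patch_r, n_chan, batch = 16, 1, 8
patch = torch.rand(batch, 2 * patch_r, 2 * patch_r, n_chan)  # (B, 2r, 2r, C) patches
y = patch.reshape(batch, -1)                                 # flattened for the linear encoders
assert y.shape[1] == patch_r**2 * 4 * n_chan                 # matches c_in_img
```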
gene exprs 485 | """ 486 | qz = inference_outputs['qz'] 487 | ql = inference_outputs['ql'] 488 | 489 | hidden = self.z_to_hidden_decoder(qz) 490 | px_scale = self.px_scale_decoder(hidden) 491 | px_rate = torch.exp(ql) * px_scale + self.eps 492 | pc_p = xs_k + self.eps 493 | 494 | return dict( 495 | px_rate=px_rate, 496 | px_r=self.px_r, 497 | px_scale=px_scale, 498 | pc_p=pc_p, 499 | xs_k=xs_k, 500 | ) 501 | 502 | def predictor_img(self, y): 503 | """Inference & generative paths for image view""" 504 | # --- Inference path --- 505 | y_n = torch.log1p(y) 506 | 507 | # q(z | y) 508 | hidden_z = self.img_z_enc(y_n) 509 | qz_m = self.img_z_enc_m(hidden_z) 510 | qz_logv = self.img_z_enc_logv(hidden_z) 511 | qz = self.reparameterize(qz_m, qz_logv) 512 | 513 | # --- Generative path --- 514 | hidden_y = self.img_z_to_hidden_decoder(qz) 515 | py_m = self.py_mu_decoder(hidden_y) 516 | py_logv = self.py_logv_decoder(hidden_y) 517 | 518 | return dict( 519 | # q(z | y) 520 | qz_m_img=qz_m, 521 | qz_logv_img=qz_logv, 522 | qz_img=qz, 523 | 524 | # p(y | z) (image reconst) 525 | py_m=py_m, 526 | py_logv=py_logv 527 | ) 528 | 529 | def generative_img(self, z): 530 | """Generative path of histology view given z""" 531 | hidden_y = self.img_z_to_hidden_decoder(z) 532 | py_m = self.py_mu_decoder(hidden_y) 533 | py_logv = self.py_logv_decoder(hidden_y) 534 | return dict( 535 | py_m=py_m, 536 | py_logv=py_logv 537 | ) 538 | 539 | def predictor_poe( 540 | self, 541 | inference_outputs, 542 | img_outputs, 543 | ): 544 | """Inference & generative paths for Joint view""" 545 | 546 | # Variational params. (expression branch) 547 | ql = inference_outputs['ql'] 548 | qz_m = inference_outputs['qz_m'] 549 | qz_logv = inference_outputs['qz_logv'] 550 | 551 | # Variational params. 
(img branch)
552 |         qz_m_img = img_outputs['qz_m_img']
553 |         qz_logv_img = img_outputs['qz_logv_img']
554 | 
555 |         batch, _ = qz_m.shape
556 | 
557 |         # --- Joint posterior q(z | x, y) via Product-of-Experts (PoE) ---
558 |         # precision-weighted fusion of the expression & image posteriors
559 |         qz_var_poe = torch.div(
560 |             1.,
561 |             torch.div(1., torch.exp(qz_logv) + self.eps) + torch.div(1., torch.exp(qz_logv_img) + self.eps)
562 |         )
563 |         qz_m_poe = qz_var_poe * (
564 |             qz_m * torch.div(1., torch.exp(qz_logv) + self.eps) +
565 |             qz_m_img * torch.div(1., torch.exp(qz_logv_img) + self.eps)
566 |         )
567 |         qz = self.reparameterize(qz_m_poe, torch.log(qz_var_poe + self.eps))  # Joint posterior
568 | 
569 |         # PoE joint & view-specific decoders
570 |         hidden = self.z_to_hidden_poe_decoder(qz)
571 | 
572 |         # p(x | z_poe)
573 |         px_scale = self.px_scale_poe_decoder(hidden)
574 |         px_rate = torch.exp(ql) * px_scale + self.eps
575 | 
576 |         # p(y | z_poe)
577 |         py_m = self.py_mu_poe_decoder(hidden)
578 |         py_logv = self.py_logv_poe_decoder(hidden)
579 | 
580 |         return dict(
581 |             # PoE q(z | x, y)
582 |             qz_m=qz_m_poe,
583 |             qz_logv=torch.log(qz_var_poe + self.eps),  # log-variance of the joint posterior
584 |             qz=qz,
585 | 
586 |             # PoE p(x | z, l) & p(y | z)
587 |             px_rate=px_rate,
588 |             px_r=self.px_r_poe,
589 |             py_m=py_m,
590 |             py_logv=py_logv
591 |         )
592 | 
593 |     def get_loss(
594 |         self,
595 |         generative_outputs,
596 |         inference_outputs,
597 |         img_outputs,
598 |         poe_outputs,
599 |         x,
600 |         library,
601 |         y,
602 |         device
603 |     ):
604 |         lambda_poe = 0.2
605 | 
606 |         # --- Parse variables ---
607 |         # Variational params
608 |         qc_m, qc = inference_outputs["qc_m"], inference_outputs["qc"]
609 | 
610 |         qs_logm = self.qs_logm
611 |         qu_m, qu_logv, qu = self.qu_m, self.qu_logv, inference_outputs["qu"]
612 | 
613 |         qz_m, qz_logv = inference_outputs["qz_m"], inference_outputs["qz_logv"]
614 |         ql_m, ql_logv = inference_outputs["ql_m"], inference_outputs['ql_logv']
615 | 
616 |         qz_m_img, qz_logv_img = img_outputs['qz_m_img'], img_outputs['qz_logv_img']
617 |         qz_m_poe, qz_logv_poe = poe_outputs['qz_m'], poe_outputs['qz_logv']
618 | 
619 |         # Generative params
620 |         px_rate = generative_outputs["px_rate"]
621 |         px_r = generative_outputs["px_r"]
622 |         pc_p = generative_outputs["pc_p"]
623 | 
624 |         py_m, py_logv = img_outputs['py_m'], img_outputs['py_logv']
625 | 
626 |         px_rate_poe = poe_outputs['px_rate']
627 |         px_r_poe = poe_outputs['px_r']
628 |         py_m_poe, py_logv_poe = poe_outputs['py_m'], poe_outputs['py_logv']
629 | 
630 | 
631 |         # --- Losses ---
632 |         # (1). Joint loss
633 |         y_n = torch.log1p(y)
634 |         reconst_loss_x_poe = -NegBinom(px_rate_poe, torch.exp(px_r_poe)).log_prob(x).sum(-1).mean()
635 |         reconst_loss_y_poe = -Normal(py_m_poe, torch.exp(py_logv_poe/2)).log_prob(y_n).sum(-1).mean()
636 | 
637 |         # prior: p(z | c, u)
638 |         pz_m = (qu.unsqueeze(0) * qc).sum(axis=1)
639 |         pz_std = (torch.exp(qs_logm / 2).unsqueeze(0) * qc).sum(axis=1)
640 | 
641 |         kl_divergence_z_poe = kl(
642 |             Normal(qz_m_poe, torch.exp(qz_logv_poe / 2)),
643 |             Normal(pz_m, pz_std)
644 |         ).sum(dim=1).mean()
645 | 
646 |         Loss_IBJ = reconst_loss_x_poe + reconst_loss_y_poe + kl_divergence_z_poe.to(device)
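        # Editor's note (illustrative sketch, not part of the original code): the
        # joint posterior used above follows the product-of-Gaussian-experts rule.
        # For experts N(m1, v1) and N(m2, v2), the product is Gaussian with
        #     v_poe = 1 / (1/v1 + 1/v2),   m_poe = v_poe * (m1/v1 + m2/v2)
        # e.g. fusing N(0, 1) with N(2, 1) gives N(1, 0.5): the joint mean is the
        # precision-weighted average of the views, and agreement between views
        # shrinks the joint variance.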
647 | 
648 |         # (2). View-specific losses
649 |         # Expression view
650 |         kl_divergence_u = kl(
651 |             Normal(qu_m, torch.exp(qu_logv / 2)),
652 |             Normal(torch.zeros_like(qu_m), torch.ones_like(qu_m) * 10)
653 |         ).sum(dim=1).mean()
654 | 
655 |         kl_divergence_z = kl(
656 |             Normal(qz_m, torch.exp(qz_logv / 2)),
657 |             Normal(pz_m, pz_std)
658 |         ).sum(dim=1).mean()
659 | 
660 |         kl_divergence_l = kl(
661 |             Normal(ql_m, torch.exp(ql_logv / 2)),
662 |             Normal(library, torch.ones_like(ql_m))
663 |         ).sum(dim=1).mean()
664 | 
665 |         kl_divergence_c = kl(
666 |             Dirichlet(qc_m * self.alpha),  # q(c | x; α) = Dir(α * λ(x))
667 |             Dirichlet(pc_p * self.alpha)
668 |         ).mean()
669 | 
670 |         reconst_loss_x = -NegBinom(px_rate, torch.exp(px_r)).log_prob(x).sum(-1).mean()
671 |         loss_exp = reconst_loss_x.to(device) + \
672 |             kl_divergence_u.to(device) + \
673 |             kl_divergence_z.to(device) + \
674 |             kl_divergence_c.to(device) + \
675 |             kl_divergence_l.to(device)
676 | 
677 |         # Image view
678 |         kl_divergence_z_img = kl(
679 |             Normal(qz_m_img, torch.exp(qz_logv_img / 2)),
680 |             Normal(pz_m, pz_std)
681 |         ).sum(dim=1).mean()
682 | 
683 |         reconst_loss_y = -Normal(py_m, torch.exp(py_logv/2)).log_prob(y_n).sum(-1).mean()
684 |         loss_img = reconst_loss_y.to(device) + kl_divergence_z_img.to(device)
685 | 
686 |         # Total loss: lambda_poe * joint (PoE) loss + sum of view-specific losses
687 |         Loss_IBM = (loss_exp + loss_img)
688 |         loss = lambda_poe * Loss_IBJ + Loss_IBM
689 | 
690 | 
691 |         return (
692 |             # Total loss
693 |             loss,
694 | 
695 |             # sum of marginal reconstruction losses
696 |             lambda_poe*(reconst_loss_x_poe+reconst_loss_y_poe) + (reconst_loss_x+reconst_loss_y),
697 | 
698 |             # KL divergences
699 |             kl_divergence_u,
700 |             lambda_poe*kl_divergence_z_poe + kl_divergence_z + kl_divergence_z_img,
701 |             kl_divergence_c,
702 |             kl_divergence_l
703 |         )
704 | 
705 |     @property
706 |     def px_r(self):
707 |         return F.softplus(self._px_r) + self.eps
708 | 
709 |     @property
710 |     def px_r_poe(self):
711 |         return F.softplus(self._px_r_poe) + self.eps
712 | 
713 | 
714 | def train(
715 |     model,
716 |     dataloader,
717 |     device,
718 |     optimizer,
719 | ):
720 |     model.train()
721 | 
722 |     running_loss = 0.0
723 |     running_u = 0.0
724 |     running_z = 0.0
725 |     running_c = 0.0
726 |     running_l = 0.0
727 |     running_reconst = 0.0
728 |     counter = 0
729 |     corr_list = []
730 |     for i, (x, xs_k, x_peri, library_i) in enumerate(dataloader):
731 | 
732 |         counter += 1
733 |         x = x.float()
734 |         x = x.to(device)
735 |         xs_k = xs_k.to(device)
736 |         x_peri = x_peri.to(device)
737 |         library_i = library_i.to(device)
738 | 
739 |         inference_outputs = model.inference(x)
740 |         generative_outputs = model.generative(inference_outputs, xs_k)
741 | 
742 |         # Check for NaNs
743 |         #if torch.isnan(loss) or any(torch.isnan(p).any() for p in model.parameters()):
744 |         if any(torch.isnan(p).any() for p in model.parameters()):
745 |             LOGGER.warning('NaNs detected in model parameters, skipping current batch...')
746 |             continue
747 | 
748 |         (loss,
749 |          reconst_loss,
750 |          kl_divergence_u,
751 |          kl_divergence_z,
752 |          kl_divergence_c,
753 |          kl_divergence_l
754 |         ) = model.get_loss(
755 |             generative_outputs,
756 |             inference_outputs,
757 |             x,
758 |             library_i,
759 |             device
760 |         )
761 | 
762 |         optimizer.zero_grad()
763 |         loss.backward()
764 | 
765 |         nn.utils.clip_grad_norm_(model.parameters(), 5)
766 |         optimizer.step()
767 | 
768 |         running_loss += loss.item()
769 |         running_reconst += reconst_loss.item()
770 |         running_u += kl_divergence_u.item()
771 |         running_z += kl_divergence_z.item()
772 |         running_c 
+= kl_divergence_c.item() 773 | running_l += kl_divergence_l.item() 774 | 775 | train_loss = running_loss / counter 776 | train_reconst = running_reconst / counter 777 | train_u = running_u / counter 778 | train_z = running_z / counter 779 | train_c = running_c / counter 780 | train_l = running_l / counter 781 | 782 | return train_loss, train_reconst, train_u, train_z, train_c, train_l, corr_list 783 | 784 | 785 | def train_poe( 786 | model, 787 | dataloader, 788 | device, 789 | optimizer, 790 | ): 791 | model.train() 792 | 793 | running_loss = 0.0 794 | running_z = 0.0 795 | running_c = 0.0 796 | running_l = 0.0 797 | running_u = 0.0 798 | running_reconst = 0.0 799 | counter = 0 800 | corr_list = [] 801 | for i, (x, 802 | x_peri, 803 | library_i, 804 | img, 805 | data_loc, 806 | xs_k, 807 | ) in enumerate(dataloader): 808 | counter += 1 809 | mini_batch, _ = x.shape 810 | 811 | x = x.float() 812 | x = x.to(device) 813 | x_peri = x_peri.to(device) 814 | library_i = library_i.to(device) 815 | xs_k = xs_k.to(device) 816 | 817 | img = img.reshape(mini_batch, -1).float() 818 | img = img.to(device) 819 | 820 | inference_outputs = model.inference(x,img) # inference for 1D expr. data 821 | generative_outputs = model.generative(inference_outputs, xs_k) 822 | img_outputs = model.predictor_img(img) # inference & generative for 2D img. data 823 | poe_outputs = model.predictor_poe(inference_outputs, img_outputs) # PoE generative outputs 824 | 825 | # Check for NaNs 826 | if any(torch.isnan(p).any() for p in model.parameters()): 827 | LOGGER.warning('NaNs detected in model parameters, Skipping current epoch...') 828 | continue 829 | 830 | (loss, 831 | reconst_loss, 832 | kl_divergence_u, 833 | kl_divergence_z, 834 | kl_divergence_c, 835 | kl_divergence_l 836 | ) = model.get_loss( 837 | generative_outputs, 838 | inference_outputs, 839 | img_outputs, 840 | poe_outputs, 841 | x, 842 | library_i, 843 | img, 844 | device 845 | ) 846 | 847 | optimizer.zero_grad() 848 | loss.backward() 849 | 850 | nn.utils.clip_grad_norm_(model.parameters(), 5) 851 | optimizer.step() 852 | 853 | running_loss += loss.item() 854 | running_reconst += reconst_loss.item() 855 | running_z += kl_divergence_z.item() 856 | running_c += kl_divergence_c.item() 857 | running_l += kl_divergence_l.item() 858 | running_u += kl_divergence_u.item() 859 | 860 | train_loss = running_loss / counter 861 | train_reconst = running_reconst / counter 862 | train_z = running_z / counter 863 | train_c = running_c / counter 864 | train_l = running_l / counter 865 | train_u = running_u / counter 866 | 867 | return train_loss, train_reconst, train_u, train_z, train_c, train_l, corr_list 868 | 869 | 870 | # Reference: 871 | # https://github.com/YosefLab/scvi-tools/blob/master/scvi/distributions/_negative_binomial.py 872 | class NegBinom(Distribution): 873 | """ 874 | Gamma-Poisson mixture approximation of Negative Binomial(mean, dispersion) 875 | 876 | lambda ~ Gamma(mu, theta) 877 | x ~ Poisson(lambda) 878 | """ 879 | arg_constraints = { 880 | 'mu': constraints.greater_than_eq(0), 881 | 'theta': constraints.greater_than_eq(0), 882 | } 883 | support = constraints.nonnegative_integer 884 | 885 | def __init__(self, mu, theta, eps=1e-10): 886 | """ 887 | Parameters 888 | ---------- 889 | mu : torch.Tensor 890 | mean of NegBinom. distribution 891 | shape - [# genes,] 892 | 893 | theta : torch.Tensor 894 | dispersion of NegBinom. 
distribution 895 | shape - [# genes,] 896 | """ 897 | self.mu = mu 898 | self.theta = theta 899 | self.eps = eps 900 | super(NegBinom, self).__init__(validate_args=True) 901 | 902 | def sample(self): 903 | lambdas = Gamma( 904 | concentration=self.theta + self.eps, 905 | rate=(self.theta + self.eps) / (self.mu + self.eps), 906 | ).rsample() 907 | 908 | x = Poisson(lambdas).sample() 909 | 910 | return x 911 | 912 | def log_prob(self, x): 913 | """log-likelihood""" 914 | ll = torch.lgamma(x + self.theta) - \ 915 | torch.lgamma(x + 1) - \ 916 | torch.lgamma(self.theta) + \ 917 | self.theta * (torch.log(self.theta + self.eps) - torch.log(self.theta + self.mu + self.eps)) + \ 918 | x * (torch.log(self.mu + self.eps) - torch.log(self.theta + self.mu + self.eps)) 919 | 920 | return ll 921 | 922 | 923 | def model_eval( 924 | model, 925 | adata, 926 | visium_args, 927 | poe=False, 928 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu') 929 | ): 930 | adata_ = adata.copy() 931 | model.eval() 932 | model = model.to(device) 933 | 934 | x_in = torch.Tensor(adata_.to_df().values).to(device) 935 | sig_means = torch.Tensor(visium_args.sig_mean_norm.values).to(device) 936 | anchor_idx = torch.Tensor(visium_args.pure_idx).to(device) 937 | 938 | with torch.no_grad(): 939 | if not poe: 940 | inference_outputs = model.inference(x_in) 941 | generative_outputs = model.generative(inference_outputs, sig_means) 942 | 943 | if poe: 944 | 945 | img_in = torch.Tensor(visium_args.get_img_patches()).float().to(device) 946 | inference_outputs = model.inference(x_in,img_in) 947 | generative_outputs = model.generative(inference_outputs, sig_means) 948 | 949 | img_outputs = model.predictor_img(img_in) 950 | poe_outputs = model.predictor_poe(inference_outputs, img_outputs) 951 | 952 | # Parse image view / PoE inference & generative outputs 953 | # Save to `inference_outputs` & `generative_outputs` 954 | for k, v in img_outputs.items(): 955 | if 'q' in k: 956 | inference_outputs[k] = v 957 | else: 958 | generative_outputs[k] = v 959 | 960 | for k, v in poe_outputs.items(): 961 | if 'q' in k: 962 | inference_outputs[k+'_poe'] = v 963 | else: 964 | generative_outputs[k+'_poe'] = v 965 | 966 | try: 967 | px = NegBinom( 968 | mu=generative_outputs["px_rate"], 969 | theta=torch.exp(generative_outputs["px_r"]) 970 | ).sample().detach().cpu().numpy() 971 | adata_.obsm['px'] = px 972 | except ValueError as ve: 973 | LOGGER.warning('Invalid Gamma distribution parameters `px_rate` or `px_r`, unable to sample inferred p(x | z)') 974 | 975 | # Save inference & generative outputs in adata 976 | for rv in inference_outputs.keys(): 977 | val = inference_outputs[rv].detach().cpu().numpy().squeeze() 978 | if "qu" not in rv and "qs" not in rv: 979 | adata_.obsm[rv] = val 980 | else: 981 | adata_.uns[rv] = val 982 | 983 | for rv in generative_outputs.keys(): 984 | try: 985 | if rv == 'px_r' or rv == 'px_r_poe': 986 | val = generative_outputs[rv].data.detach().cpu().numpy().squeeze() 987 | adata_.varm[rv] = val 988 | else: 989 | val = generative_outputs[rv].data.detach().cpu().numpy().squeeze() 990 | adata_.obsm[rv] = val 991 | except: 992 | print("rv: {} can't be stored".format(rv)) 993 | 994 | qz_umap = get_z_umap(adata_.obsm['qz_m']) 995 | adata_.obsm['z_umap'] = qz_umap 996 | return inference_outputs, generative_outputs, adata_ 997 | 998 | 999 | def model_eval_integrate( 1000 | model, 1001 | adata, 1002 | visium_args, 1003 | poe=False, 1004 | device=torch.device('cpu') 1005 | ): 1006 | """ 1007 | Model evaluation for sample 
integration 1008 | TODO: code refactor 1009 | """ 1010 | model.eval() 1011 | model = model.to(device) 1012 | adata_ = adata.copy() 1013 | x_in = torch.Tensor(adata_.to_df().values).to(device) 1014 | sig_means = torch.Tensor(visium_args.sig_mean_norm.values).to(device) 1015 | anchor_idx = torch.Tensor(visium_args.pure_idx).to(device) 1016 | 1017 | with torch.no_grad(): 1018 | if not poe: 1019 | inference_outputs = model.inference(x_in) 1020 | generative_outputs = model.generative(inference_outputs, sig_means) 1021 | if poe: 1022 | img_in = torch.Tensor(visium_args.get_img_patches()).float().to(device) 1023 | 1024 | inference_outputs = model.inference(x_in, img_in) 1025 | generative_outputs = model.generative(inference_outputs, sig_means) 1026 | 1027 | img_outputs = model.predictor_img(img_in) if img_in.max() <= 1 else model.predictor_img(img_in/255) 1028 | poe_outputs = model.predictor_poe(inference_outputs, img_outputs) 1029 | 1030 | # Parse image view / PoE inference & generative outputs 1031 | # Save to `inference_outputs` & `generative_outputs` 1032 | for k, v in img_outputs.items(): 1033 | if 'q' in k: 1034 | inference_outputs[k] = v 1035 | else: 1036 | generative_outputs[k] = v 1037 | 1038 | for k, v in poe_outputs.items(): 1039 | if 'q' in k: 1040 | inference_outputs[k+'_poe'] = v 1041 | else: 1042 | generative_outputs[k+'_poe'] = v 1043 | 1044 | # TODO: move histology reconstruction to a separate func 1045 | 1046 | # # Reconst histology prediction 1047 | # # reconst_img_patches = img_outputs['py_m'] 1048 | # reconst_poe_img_patches = poe_outputs['py_m'] 1049 | # # reconst_img_all = {} 1050 | # reconst_poe_img_all = {} 1051 | 1052 | # batch_idx = 0 # img metadata counter for each counter 1053 | 1054 | # for sample_id in visium_args.adata.obs['sample'].unique(): 1055 | 1056 | # img_dim = visium_args.img[sample_id].shape 1057 | # # reconst_img = np.ones((img_dim + (model.nimg_chan,))) * (-1/255) 1058 | # reconst_poe_img = np.ones((img_dim + (model.nimg_chan,))) * (-1/255) 1059 | 1060 | # # image_col = img_metadata[i]['map_info']['imagecol']*img_metadata[i]['scalefactor']['tissue_hires_scalef'] 1061 | # # image_row = img_metadata[i]['map_info']['imagerow']*img_metadata[i]['scalefactor']['tissue_hires_scalef'] 1062 | # map_info = visium_args.map_info[adata.obs['sample'] == sample_id] 1063 | # scalefactor = visium_args.scalefactor[sample_id] 1064 | # image_col = map_info['imagecol'] * scalefactor['tissue_hires_scalef'] 1065 | # image_row = map_info['imagerow'] * scalefactor['tissue_hires_scalef'] 1066 | 1067 | # for idx in range(image_col.shape[0]): 1068 | 1069 | # patch_y = slice(int(image_row[idx])-model.patch_r, int(image_row[idx])+model.patch_r) 1070 | # patch_x = slice(int(image_col[idx])-model.patch_r, int(image_col[idx])+model.patch_r) 1071 | 1072 | # """ 1073 | # reconst_img[patch_y, patch_x, :] = reconst_img_patches[idx+batch_idx].reshape([ 1074 | # model.patch_r*2, 1075 | # model.patch_r*2, 1076 | # model.nimg_chan 1077 | # ]).cpu().detach().numpy() 1078 | # """ 1079 | # reconst_poe_img[patch_y, patch_x, :] = reconst_poe_img_patches[idx+batch_idx].reshape([ 1080 | # model.patch_r*2, 1081 | # model.patch_r*2, 1082 | # model.nimg_chan 1083 | # ]).cpu().detatch().numpy() 1084 | 1085 | # # reconst_img_all[sample_id] = reconst_img 1086 | # reconst_poe_img_all[sample_id] = reconst_poe_img 1087 | # batch_idx += image_col.shape[0] 1088 | 1089 | # Update reconstructed image 1090 | # adata.uns['reconst_img'] = reconst_poe_img_all 1091 | 1092 | try: 1093 | px = NegBinom( 1094 | 
mu=generative_outputs["px_rate"],
1095 |             theta=torch.exp(generative_outputs["px_r"])
1096 |         ).sample().detach().cpu().numpy()
1097 |         adata_.obsm['px'] = px
1098 |     except ValueError as ve:
1099 |         LOGGER.warning('Invalid Gamma distribution parameters `px_rate` or `px_r`, unable to sample inferred p(x | z)')
1100 | 
1101 |     # Save inference & generative outputs in adata
1102 |     for rv in inference_outputs.keys():
1103 |         val = inference_outputs[rv].detach().cpu().numpy().squeeze()
1104 |         if "qu" not in rv and "qs" not in rv:
1105 |             adata_.obsm[rv] = val
1106 |         else:
1107 |             adata_.uns[rv] = val
1108 | 
1109 |     for rv in generative_outputs.keys():
1110 |         try:
1111 |             if rv == 'px_r' or rv == 'px_r_poe':  # gene-level dispersion params -> `varm`
1112 |                 val = generative_outputs[rv].data.detach().cpu().numpy().squeeze()
1113 |                 adata_.varm[rv] = val
1114 |             else:
1115 |                 val = generative_outputs[rv].data.detach().cpu().numpy().squeeze()
1116 |                 adata_.obsm[rv] = val
1117 |         except Exception:
1118 |             print("rv: {} can't be stored".format(rv))
1119 | 
1120 |     return inference_outputs, generative_outputs, adata_
1121 | 
1122 | 
1123 | def model_ct_exp(
1124 |     model,
1125 |     adata,
1126 |     visium_args,
1127 |     poe=False,
1128 |     device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
1129 | ):
1130 |     """
1131 |     Obtain generative cell-type specific expression in each spot (after model training)
1132 |     """
1133 |     sig_means = torch.Tensor(visium_args.sig_mean_norm.values).to(device)
1134 |     anchor_idx = torch.Tensor(visium_args.pure_idx).to(device)
1135 |     x_in = torch.Tensor(adata.to_df().values).to(device)
1136 |     if poe:
1137 |         y_in = torch.Tensor(visium_args.get_img_patches()).float().to(device)
1138 | 
1139 | 
1140 |     model.eval()
1141 |     model = model.to(device)
1142 |     pred_exprs = {}
1143 | 
1144 |     for ct_idx, cell_type in enumerate(adata.uns['cell_types']):
1145 |         # Get inference outputs for the given cell type
1146 | 
1147 |         if poe:
1148 |             inference_outputs = model.inference(x_in, y_in)
1149 |         else:
1150 |             inference_outputs = model.inference(x_in)
1151 |         inference_outputs['qz'] = inference_outputs['qz_m_ct'][:, ct_idx, :]
1152 | 
1153 |         # Get generative outputs
1154 |         generative_outputs = model.generative(inference_outputs, sig_means)
1155 | 
1156 |         px = NegBinom(
1157 |             mu=generative_outputs["px_rate"],
1158 |             theta=torch.exp(generative_outputs["px_r"])
1159 |         ).sample()
1160 |         px = px.detach().cpu().numpy()
1161 | 
1162 |         # Save results in adata.obsm
1163 |         px_df = pd.DataFrame(px, index=adata.obs_names, columns=adata.var_names)
1164 |         pred_exprs[cell_type] = px_df
1165 |         adata.obsm[cell_type + '_inferred_exprs'] = px
1166 | 
1167 |     return pred_exprs
1168 | 
1169 | 
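# --- Editor's note: illustrative usage sketch (not part of the original API).
# `model`, `adata` and `visium_args` are assumed to come from `run_starfysh` /
# `VisiumArguments`; set `poe=True` if the model was trained with histology.
def _example_rank_ct_genes(model, adata, visium_args, cell_type, poe=False):
    """Minimal sketch: rank genes by avg. inferred expression for one cell type."""
    pred_exprs = model_ct_exp(model, adata, visium_args, poe=poe)
    ct_df = pred_exprs[cell_type]  # spots x genes pd.DataFrame
    return ct_df.mean(axis=0).sort_values(ascending=False)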
1170 | def model_ct_img(
1171 |     model,
1172 |     adata,
1173 |     visium_args,
1174 |     device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
1175 | ):
1176 |     """
1177 |     Obtain generative cell-type specific images (after model training)
1178 |     """
1179 |     x_in = torch.Tensor(adata.to_df().values).to(device)
1180 |     y_in = torch.Tensor(visium_args.get_img_patches()).float().to(device)
1181 | 
1182 |     model.eval()
1183 |     model = model.to(device)
1184 |     ct_imgs = {}
1185 |     ct_imgs_poe = {}
1186 | 
1187 |     for ct_idx, cell_type in enumerate(adata.uns['cell_types']):
1188 |         # Get inference outputs for the given cell type
1189 |         inference_outputs = model.inference(x_in, y_in)
1190 |         qz_m_ct, qz_logv_ct = inference_outputs['qz_m_ct'][:, ct_idx, :], inference_outputs['qz_logv_ct'][:, ct_idx, :]
1191 |         qz_ct = model.reparameterize(qz_m_ct, qz_logv_ct)
1192 |         inference_outputs['qz_m'], inference_outputs['qz'] = qz_m_ct, qz_ct
1193 | 
1194 |         # Image-view low-dim representation z (`predictor_img` has no per-cell-type axis)
1195 |         img_outputs = model.predictor_img(y_in)
1196 |         qz_m_img, qz_logv_img = img_outputs['qz_m_img'], img_outputs['qz_logv_img']
1197 |         qz_img = model.reparameterize(qz_m_img, qz_logv_img)
1198 |         img_outputs['qz_img'] = qz_img  # cell-type specificity enters via the PoE fusion below
1199 |         generative_outputs_img = model.generative_img(qz_img)
1200 | 
1201 |         poe_outputs = model.predictor_poe(inference_outputs, img_outputs)
1202 | 
1203 |         ct_imgs[cell_type] = generative_outputs_img['py_m'].detach().cpu().numpy()
1204 |         ct_imgs_poe[cell_type] = poe_outputs['py_m'].detach().cpu().numpy()
1205 | 
1206 |     return ct_imgs, ct_imgs_poe
1207 | 
1208 | 
--------------------------------------------------------------------------------
/starfysh/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import json
4 | import numpy as np
5 | import pandas as pd
6 | import scanpy as sc
7 | 
8 | import torch
9 | import torch.nn as nn
10 | import torch.optim as optim
11 | 
12 | from scipy.stats import median_abs_deviation
13 | from torch.utils.data import DataLoader
14 | 
15 | import sys
16 | import histomicstk as htk
17 | from skimage import io
18 | 
19 | # Module import
20 | from starfysh import LOGGER
21 | from .dataloader import VisiumDataset, VisiumPoEDataSet
22 | from .starfysh import AVAE, AVAE_PoE, train, train_poe
23 | 
24 | 
25 | # -------------------
26 | # Model Parameters
27 | # -------------------
28 | 
29 | class VisiumArguments:
30 |     """
31 |     Load Visium AnnData; perform preprocessing, library-size smoothing & anchor spot detection
32 | 
33 |     Parameters
34 |     ----------
35 |     adata : AnnData
36 |         annotated visium count matrix
37 | 
38 |     adata_norm : AnnData
39 |         annotated visium count matrix after normalization & log-transform
40 | 
41 |     gene_sig : pd.DataFrame
42 |         list of signature genes for each cell type. 
(dim: [S, Cell_type]) 43 | 44 | img_metadata : dict 45 | Spatial information metadata (histology image, coordinates, scalefactor) 46 | """ 47 | def __init__( 48 | self, 49 | adata, 50 | adata_norm, 51 | gene_sig, 52 | img_metadata, 53 | **kwargs 54 | ): 55 | 56 | self.adata = adata 57 | self.adata_norm = adata_norm 58 | self.gene_sig = gene_sig 59 | self.map_info = img_metadata['map_info'].iloc[:, :4].astype(float) 60 | self.scalefactor = img_metadata['scalefactor'] 61 | self.img = img_metadata['img'] 62 | self.img_patches = None 63 | self.eps = 1e-6 64 | 65 | self.params = { 66 | 'sample_id': 'ST', 67 | 'n_anchors': int(adata.shape[0]), 68 | 'patch_r': 16, 69 | 'signif_level': 3, 70 | 'window_size': 1, 71 | 'n_img_chan': 1 72 | } 73 | 74 | # Update parameters for library smoothing & anchor spot identification 75 | for k, v in kwargs.items(): 76 | if k in self.params.keys(): 77 | self.params[k] = v 78 | 79 | # Center expression for gene score calculation 80 | adata_scale = self.adata_norm.copy() 81 | sc.pp.scale(adata_scale) 82 | 83 | # Store cell types 84 | self.adata.uns['cell_types'] = list(self.gene_sig.columns) 85 | 86 | # Filter out signature genes X listed in expression matrix 87 | LOGGER.info('Subsetting highly variable & signature genes ...') 88 | self.adata, self.adata_norm = get_adata_wsig(adata, adata_norm, gene_sig) 89 | self.adata_scale = adata_scale[:, adata.var_names] 90 | 91 | # Update spatial metadata 92 | self._update_spatial_info(self.params['sample_id']) 93 | 94 | # Get smoothed library size 95 | LOGGER.info('Smoothing library size by taking averaging with neighbor spots...') 96 | log_lib = np.log1p(self.adata.X.sum(1)) 97 | self.log_lib = np.squeeze(np.asarray(log_lib)) if log_lib.ndim > 1 else log_lib 98 | 99 | self.win_loglib = get_windowed_library(self.adata, 100 | self.map_info, 101 | self.log_lib, 102 | window_size=self.params['window_size']) 103 | 104 | # Retrieve & normalize signature gene expressions by 105 | # comparing w/ randomized base expression (`sc.tl.score_genes`) 106 | LOGGER.info('Retrieving & normalizing signature gene expressions...') 107 | self.sig_mean = self._get_sig_mean() 108 | self.sig_mean_norm = self._calc_gene_scores() 109 | 110 | # Get anchor spots 111 | LOGGER.info('Identifying anchor spots (highly expression of specific cell-type signatures)...') 112 | anchor_info = self._compute_anchors() 113 | 114 | # row-norm; "ReLU" on signature scores for valid dirichlet param 115 | self.sig_mean_norm[self.sig_mean_norm < 0] = self.eps 116 | self.sig_mean_norm = self.sig_mean_norm.div(self.sig_mean_norm.sum(1), axis=0) 117 | self.sig_mean_norm.fillna(1/self.sig_mean_norm.shape[1], inplace=True) 118 | 119 | self.pure_spots, self.pure_dict, self.pure_idx = anchor_info 120 | del self.adata.raw, self.adata_norm.raw 121 | 122 | def get_adata(self): 123 | """Return adata after preprocessing & HVG gene selection""" 124 | return self.adata, self.adata_norm 125 | 126 | def get_anchors(self): 127 | """Return indices of anchor spots for each cell type""" 128 | anchors_df = pd.DataFrame.from_dict(self.pure_dict, orient='index') 129 | anchors_df = anchors_df.transpose() 130 | 131 | # Check whether empty anchors detected for any factor 132 | empty_indices = np.where( 133 | (~pd.isna(anchors_df)).sum(0) == 0 134 | )[0] 135 | 136 | if len(empty_indices) > 0: 137 | raise ValueError("Cell type(s) {} has no anchors significantly enriched for its signatures," 138 | "please lower outlier stats `signif_level` or try zscore-based signature".format( 139 | 
anchors_df.columns[empty_indices].to_list() 140 | )) 141 | 142 | return anchors_df.applymap( 143 | lambda x: 144 | -1 if x is None else np.where(self.adata.obs.index == x)[0][0] 145 | ) 146 | 147 | def get_img_patches(self): 148 | assert self.img_patches is not None, "Please run Starfysh PoE first" 149 | return self.img_patches 150 | 151 | def append_factors(self, arche_markers): 152 | """ 153 | Append list of archetypes (w/ corresponding markers) as additional cell type(s) / state(s) to the `gene_sig` 154 | """ 155 | self.gene_sig = pd.concat((self.gene_sig, arche_markers), axis=1) 156 | 157 | # Update factor names & anchor spots 158 | self.adata.uns['cell_types'] = list(self.gene_sig.columns) 159 | self._update_anchors() 160 | return None 161 | 162 | def replace_factors(self, factors_to_repl, arche_markers): 163 | """ 164 | Replace factor(s) with archetypes & their corresponding markers in the `gene_sig` 165 | """ 166 | if isinstance(factors_to_repl, str): 167 | assert isinstance(arche_markers, pd.Series),\ 168 | "Please pick only one archetype to replace the factor {}".format(factors_to_repl) 169 | factors_to_repl = [factors_to_repl] 170 | archetypes = [arche_markers.name] 171 | else: 172 | assert len(factors_to_repl) == len(arche_markers.columns), \ 173 | "Unequal # cell types & archetypes to replace with" 174 | archetypes = arche_markers.columns 175 | 176 | self.gene_sig.rename( 177 | columns={ 178 | f: a 179 | for (f, a) in zip(factors_to_repl, archetypes) 180 | }, inplace=True 181 | ) 182 | self.gene_sig[archetypes] = pd.DataFrame(arche_markers) 183 | 184 | # Update factor names & anchor spots 185 | self.adata.uns['cell_types'] = list(self.gene_sig.columns) 186 | self._update_anchors() 187 | return None 188 | 189 | # --- Private methods --- 190 | def _compute_anchors(self): 191 | """ 192 | Calculate top `anchor_spots` significantly enriched for given cell type(s) 193 | determined by gene set scores from signatures 194 | """ 195 | score_df = self.sig_mean_norm 196 | n_anchor = self.params['n_anchors'] 197 | 198 | top_expr_spots = (-score_df.values).argsort(axis=0)[:n_anchor, :] 199 | pure_spots = np.transpose(np.array(score_df.index)[top_expr_spots]) 200 | 201 | pure_dict = { 202 | ct: spot 203 | for (spot, ct) in zip(pure_spots, score_df.columns) 204 | } 205 | 206 | pure_indices = np.zeros([score_df.shape[0], 1]) 207 | idx = [np.where(score_df.index == i)[0][0] 208 | for i in sorted({x for v in pure_dict.values() for x in v})] 209 | pure_indices[idx] = 1 210 | return pure_spots, pure_dict, pure_indices 211 | 212 | def _update_anchors(self): 213 | """Re-calculate anchor spots given updated gene signatures""" 214 | self.sig_mean = self._get_sig_mean() 215 | self.sig_mean_norm = self._calc_gene_scores() 216 | self.adata.uns['cell_types'] = list(self.gene_sig.columns) 217 | 218 | LOGGER.info('Recalculating anchor spots (highly expression of specific cell-type signatures)...') 219 | anchor_info = self._compute_anchors() 220 | self.sig_mean_norm[self.sig_mean_norm < 0] = self.eps 221 | self.sig_mean_norm.fillna(1/self.sig_mean_norm.shape[1], inplace=True) 222 | self.pure_spots, self.pure_dict, self.pure_idx = anchor_info 223 | 224 | def _get_sig_mean(self): 225 | sig_mean_expr = pd.DataFrame() 226 | cnt_df = self.adata_norm.to_df() 227 | 228 | # Calculate avg. 
signature expressions for each cell type 229 | for i, cell_type in enumerate(self.gene_sig.columns): 230 | sigs = np.intersect1d(cnt_df.columns, self.gene_sig.iloc[:, i].astype(str)) 231 | 232 | if len(sigs) == 0: 233 | raise ValueError("Empty signatures for {}," 234 | "please double check your `gene_sig` input or set a higher" 235 | "`n_gene` threshold upon dataloading".format(cell_type)) 236 | 237 | else: 238 | sig_mean_expr[cell_type] = cnt_df.loc[:, sigs].mean(axis=1) 239 | 240 | sig_mean_expr.index = self.adata.obs_names 241 | sig_mean_expr.columns = self.gene_sig.columns 242 | return sig_mean_expr 243 | 244 | def _update_spatial_info(self, sample_id): 245 | """Update paired spatial information to ST adata""" 246 | # Update image channel count for RGB input (`y`) 247 | if self.img is not None and self.img.ndim == 3: 248 | self.params['n_img_chan'] = 3 249 | 250 | if 'spatial' not in self.adata.uns_keys(): 251 | self.adata.uns['spatial'] = { 252 | sample_id: { 253 | 'images': {'hires': (self.img - self.img.min()) / (self.img.max() - self.img.min())}, 254 | 'scalefactors': self.scalefactor 255 | }, 256 | } 257 | 258 | self.adata_norm.uns['spatial'] = { 259 | sample_id: { 260 | 'images': {'hires': (self.img - self.img.min()) / (self.img.max() - self.img.min())}, 261 | 'scalefactors': self.scalefactor 262 | }, 263 | } 264 | 265 | self.adata.obsm['spatial'] = self.map_info[['imagecol', 'imagerow']].values 266 | self.adata_norm.obsm['spatial'] = self.map_info[['imagecol', 'imagerow']].values 267 | 268 | # Typecast: spatial coords. 269 | self.adata_norm.obsm['spatial'] = self.map_info[['imagecol', 'imagerow']].values 270 | self.adata_norm.obsm['spatial'] = self.adata_norm.obsm['spatial'].astype(np.float32) 271 | return None 272 | 273 | def _update_img_patches(self, dl_poe): 274 | imgs = torch.Tensor(dl_poe.spot_img_stack) 275 | self.img_patches = imgs.reshape(imgs.shape[0], -1) 276 | return None 277 | 278 | def _norm_sig(self): 279 | # col-norm for each cell type: divided by mean 280 | gexp = self.sig_mean.apply(lambda x: x / x.mean(), axis=0) 281 | return gexp 282 | 283 | def _calc_gene_scores(self): 284 | """Calculate gene set enrichment scores for each signature sets""" 285 | adata = self.adata_scale.copy() 286 | #adata = self.adata_norm.copy() 287 | for cell_type in self.gene_sig.columns: 288 | sig = self.gene_sig[cell_type][~pd.isna(self.gene_sig[cell_type])].to_list() 289 | sc.tl.score_genes(adata, sig, use_raw=False, score_name=cell_type+'_score') 290 | 291 | gsea_df = adata.obs[[cell_type+'_score' for cell_type in self.gene_sig.columns]] 292 | gsea_df.columns = self.gene_sig.columns 293 | return gsea_df 294 | 295 | # -------------------------------- 296 | # Running starfysh with 3-restart 297 | # -------------------------------- 298 | 299 | def init_weights(module): 300 | if type(module) == nn.Linear: 301 | nn.init.xavier_uniform_(module.weight) 302 | 303 | elif type(module) == nn.BatchNorm1d: 304 | module.bias.data.zero_() 305 | module.weight.data.fill_(1.0) 306 | 307 | 308 | def run_starfysh( 309 | visium_args, 310 | n_repeats=3, 311 | lr=1e-4, 312 | epochs=100, 313 | batch_size=32, 314 | alpha_mul=50, 315 | poe=False, 316 | device=torch.device('cpu'), 317 | seed=0, 318 | verbose=True 319 | ): 320 | """ 321 | Wrapper to run starfysh deconvolution. 322 | 323 | Parameters 324 | ---------- 325 | visium_args : VisiumArguments 326 | Preprocessed metadata calculated from input visium matrix: 327 | e.g. mean signature expression, library size, anchor spots, etc. 
328 | 
329 |     n_repeats : int
330 |         Number of restarts when training Starfysh
331 | 
332 |     epochs : int
333 |         Max. number of training epochs
334 | 
335 |     poe : bool
336 |         Whether to perform inference with PoE (histology image integration)
337 | 
338 |     Returns
339 |     -------
340 |     best_model : starfysh.AVAE or starfysh.AVAE_PoE
341 |         Trained Starfysh model with deconvolution results
342 | 
343 |     loss : np.ndarray
344 |         Training losses
345 |     """
346 |     np.random.seed(seed)
347 | 
348 |     # Loading parameters
349 |     adata = visium_args.adata
350 |     win_loglib = visium_args.win_loglib
351 |     gene_sig, sig_mean_norm = visium_args.gene_sig, visium_args.sig_mean_norm
352 | 
353 |     models = [None] * n_repeats
354 |     losses = []
355 |     loss_c_list = np.repeat(np.inf, n_repeats)
356 | 
357 |     if poe:
358 |         dl_func = VisiumPoEDataSet  # dataloader
359 |         train_func = train_poe  # training wrapper
360 |     else:
361 |         dl_func = VisiumDataset
362 |         train_func = train
363 | 
364 |     trainset = dl_func(adata=adata, args=visium_args)
365 |     trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, drop_last=True)
366 | 
367 |     # Running Starfysh with multiple restarts
368 |     LOGGER.info('Running Starfysh with {} restarts, choosing the model with the best parameters...'.format(n_repeats))
369 | 
370 |     count = 0
371 |     while count < n_repeats:
372 |         best_loss_c = np.inf
373 |         if poe:
374 |             model = AVAE_PoE(
375 |                 adata=adata,
376 |                 gene_sig=sig_mean_norm,
377 |                 patch_r=visium_args.params['patch_r'],
378 |                 win_loglib=win_loglib,
379 |                 alpha_mul=alpha_mul,
380 |                 n_img_chan=visium_args.params['n_img_chan']
381 |             )
382 |             # Update patched & flattened image patches
383 |             visium_args._update_img_patches(trainset)
384 |         else:
385 |             model = AVAE(
386 |                 adata=adata,
387 |                 gene_sig=sig_mean_norm,
388 |                 win_loglib=win_loglib,
389 |                 alpha_mul=alpha_mul
390 |             )
391 | 
392 |         model = model.to(device)
393 |         loss_dict = {
394 |             'reconst': [],
395 |             'c': [],
396 |             'u': [],
397 |             'z': [],
398 |             'n': [],
399 |             'tot': []
400 |         }
401 | 
402 |         # Initialize model params
403 |         if verbose:
404 |             LOGGER.info('Initializing model parameters...')
405 | 
406 |         model.apply(init_weights)
407 |         optimizer = optim.Adam(model.parameters(), lr=lr)
408 |         scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)
409 | 
410 |         try:
411 |             for epoch in range(epochs):
412 |                 result = train_func(model, trainloader, device, optimizer)
413 |                 torch.cuda.empty_cache()
414 | 
415 |                 loss_tot, loss_reconst, loss_u, loss_z, loss_c, loss_n, corr_list = result
416 |                 if loss_c < best_loss_c:
417 |                     models[count] = model
418 |                     best_loss_c = loss_c
419 | 
420 |                 torch.cuda.empty_cache()
421 | 
422 |                 loss_dict['tot'].append(loss_tot)
423 |                 loss_dict['reconst'].append(loss_reconst)
424 |                 loss_dict['u'].append(loss_u)
425 |                 loss_dict['z'].append(loss_z)
426 |                 loss_dict['c'].append(loss_c)
427 |                 loss_dict['n'].append(loss_n)
428 | 
429 |                 if (epoch + 1) % 10 == 0 and verbose:
430 |                     LOGGER.info("Epoch[{}/{}], train_loss: {:.4f}, train_reconst: {:.4f}, train_u: {:.4f}, train_z: {:.4f}, train_c: {:.4f}, train_l: {:.4f}".format(
431 |                         epoch + 1, epochs, loss_tot, loss_reconst, loss_u, loss_z, loss_c, loss_n)
432 |                     )
433 |                 scheduler.step()
434 | 
435 |             losses.append(loss_dict)
436 |             loss_c_list[count] = best_loss_c
437 | 
438 |             count += 1
439 | 
440 |         except ValueError:  # Bad model initialization -> numerical instability
441 |             continue
442 | 
443 |     if verbose:
444 |         LOGGER.info('Saving the best-performance model...')
445 |         LOGGER.info(" === Finished training === \n")
446 | 
447 |     idx = np.argmin(loss_c_list)
448 |     best_model = models[idx]
449 |     loss = losses[idx]
450 | 
451 |     return best_model, loss
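# --- Editor's sketch: minimal end-to-end usage of `run_starfysh` (illustrative;
# the folder layout, sample name and signature filename below are assumptions) ---
#
#   from starfysh.starfysh import model_eval
#
#   adata, adata_norm = load_adata('data', 'sample_A', n_genes=2000)
#   gene_sig, _ = load_signatures('data/signatures.csv', adata)
#   img_metadata = preprocess_img('data', 'sample_A', adata_index=adata.obs_names)
#   visium_args = VisiumArguments(adata, adata_norm, gene_sig, img_metadata,
#                                 sample_id='sample_A')
#   model, loss = run_starfysh(visium_args, n_repeats=3, epochs=100, poe=False)
#   _, _, adata_out = model_eval(model, visium_args.adata, visium_args, poe=False)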
452 | 
453 | 
454 | # -------------------
455 | # Preprocessing & IO
456 | # -------------------
457 | 
458 | def preprocess(
459 |     adata_raw,
460 |     min_perc=None,
461 |     max_perc=None,
462 |     n_top_genes=2000,
463 |     mt_thld=100,
464 |     verbose=True,
465 |     multiple_data=False
466 | ):
467 |     """
468 |     Preprocess ST gexp matrix; remove ribosomal & mitochondrial genes
469 | 
470 |     Parameters
471 |     ----------
472 |     adata_raw : AnnData
473 |         Spot x Gene raw expression matrix [S x G]
474 | 
475 |     min_perc : float
476 |         lower-bound percentile of non-zero gexps for filtering spots
477 | 
478 |     max_perc : float
479 |         upper-bound percentile of non-zero gexps for filtering spots
480 | 
481 |     n_top_genes : int
482 |         number of highly variable genes to keep
483 | 
484 |     mt_thld : float
485 |         max. percentage of mitochondrial gexps for filtering spots
486 |         with excessive MT expressions
487 | 
488 |     multiple_data : bool
489 |         whether multiple datasets will be integrated (if so, counts are CPM-normalized)
490 |     """
491 |     adata = adata_raw.copy()
492 | 
493 |     if min_perc and max_perc:
494 |         assert 0 < min_perc < max_perc < 100, \
495 |             "Invalid thresholds for cells: {0}, {1}".format(min_perc, max_perc)
496 |         min_counts = np.percentile(adata.obs['total_counts'], min_perc)
497 |         sc.pp.filter_cells(adata, min_counts=min_counts)
498 | 
499 |     # Remove cells with excessive MT expressions
500 |     # Remove MT & RB genes
501 |     if verbose:
502 |         LOGGER.info('Preprocessing (1): remove MT & RP genes')
503 | 
504 |     adata.var['mt'] = np.logical_or(
505 |         adata.var_names.str.startswith('MT-'),
506 |         adata.var_names.str.startswith('mt-')
507 |     )
508 |     adata.var['rb'] = adata.var_names.str.startswith(('RP', 'Rp', 'rp'))
509 | 
510 |     sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], inplace=True)
511 |     mask_cell = adata.obs['pct_counts_mt'] < mt_thld
512 |     mask_gene = np.logical_and(~adata.var['mt'], ~adata.var['rb'])
513 | 
514 |     adata = adata[mask_cell, mask_gene]
515 |     sc.pp.filter_genes(adata, min_cells=1)
516 | 
517 |     # Normalize & take log-transform
518 |     if verbose:
519 |         LOGGER.info('Preprocessing (2): normalize library size')
520 |     if multiple_data:
521 |         sc.pp.normalize_total(adata, target_sum=1e6, inplace=True)
522 |     else:
523 |         sc.pp.normalize_total(adata, inplace=True)
524 | 
525 |     # Preprocessing (3): log1p transform
526 |     if verbose:
527 |         LOGGER.info('Preprocessing (3): log1p transform')
528 |     sc.pp.log1p(adata)
529 | 
530 |     # Preprocessing (4): select highly variable genes
531 |     if verbose:
532 |         LOGGER.info('Preprocessing (4): select highly variable genes')
533 |     sc.pp.highly_variable_genes(adata, flavor='seurat', n_top_genes=n_top_genes, inplace=True)
534 | 
535 |     # Filter corresponding `obs` & `var` in raw-count matrix
536 |     adata_raw = adata_raw[adata.obs_names, adata.var_names]
537 |     adata_raw.var['highly_variable'] = adata.var['highly_variable']
538 |     adata_raw.obs = adata.obs
539 | 
540 |     return adata_raw, adata
541 | 
542 | 
543 | def load_adata(data_folder, sample_id, n_genes, multiple_data=False):
544 |     """
545 |     Load Visium adata with raw counts, preprocess & extract highly variable genes
546 | 
547 |     Parameters
548 |     ----------
549 |     data_folder : str
550 |         Root directory of the data
551 | 
552 |     sample_id : str
553 |         Sample subdirectory under `data_folder`
554 | 
555 |     n_genes : int
556 |         number of highly variable genes to keep for training
557 | 
558 |     multiple_data : bool
559 |         whether the study includes multiple datasets to integrate
560 | 
561 |     Returns
562 |     -------
563 |     adata : sc.AnnData
564 |         Processed ST raw counts
565 | 
566 |     adata_norm : sc.AnnData
567 |         Processed ST 
normalized & log-transformed data 568 | """ 569 | has_feature_h5 = os.path.isfile( 570 | os.path.join(data_folder, sample_id, 'filtered_feature_bc_matrix.h5') 571 | ) # whether dataset stored in h5 with spatial info. 572 | 573 | if has_feature_h5: 574 | adata = sc.read_visium(path=os.path.join(data_folder, sample_id), library_id=sample_id) 575 | adata.var_names_make_unique() 576 | adata.obs['sample'] = sample_id 577 | elif sample_id.startswith('simu'): # simulations 578 | adata = sc.read_csv(os.path.join(data_folder, sample_id, 'counts.st_synth.csv')) 579 | else: 580 | filenames = [ 581 | f[:-5] for f in os.listdir(os.path.join(data_folder, sample_id)) 582 | if f[-5:] == '.h5ad' 583 | ] 584 | assert len(filenames) == 1, \ 585 | "None or more than `h5ad` file in the data directory," \ 586 | "please contain only 1 target ST file in the given directory" 587 | adata = sc.read_h5ad(os.path.join(data_folder, sample_id, filenames[0] + '.h5ad')) 588 | adata.var_names_make_unique() 589 | adata.obs['sample'] = sample_id 590 | 591 | if '_index' in adata.var.columns: 592 | adata.var_names = adata.var['_index'] 593 | adata.var_names.name = 'Genes' 594 | adata.var.drop('_index', axis=1, inplace=True) 595 | 596 | adata, adata_norm = preprocess(adata, n_top_genes=n_genes, multiple_data=multiple_data) 597 | return adata, adata_norm 598 | 599 | 600 | def load_signatures(filename, adata): 601 | """ 602 | load annotated signature gene sets 603 | 604 | Parameters 605 | ---------- 606 | filename : str 607 | Signature file 608 | 609 | adata : sc.AnnData 610 | ST count matrix 611 | 612 | Returns 613 | ------- 614 | gene_sig : pd.DataFrame 615 | signatures per cell type / state 616 | """ 617 | assert os.path.isfile(filename), "Unable to find the signature file" 618 | gene_sig = pd.read_csv(filename, index_col=0) 619 | gene_sig = filter_gene_sig(gene_sig, adata.to_df()) 620 | sigs = np.unique( 621 | gene_sig.apply( 622 | lambda x: 623 | pd.unique(x[~pd.isna(x)]) 624 | ).values 625 | ) 626 | 627 | return gene_sig, np.unique(sigs) 628 | 629 | 630 | def preprocess_img( 631 | data_path, 632 | sample_id, 633 | adata_index, 634 | rgb_channels=True 635 | ): 636 | """ 637 | Load and preprocess visium paired H&E image & spatial coords 638 | 639 | Parameters 640 | ---------- 641 | data_path : str 642 | Root directory of the data 643 | 644 | sample_id : str 645 | Sample subdirectory under `data_path` 646 | 647 | rgb_channels : bool 648 | Whether to apply binary color deconvolution to extract 1D `eosin` channel 649 | Please refer to: 650 | https://digitalslidearchive.github.io/HistomicsTK/examples/color_deconvolution.html 651 | 652 | Returns 653 | ------- 654 | img : np.ndarray 655 | Processed histology image 656 | 657 | map_info : np.ndarray 658 | Spatial coords of spots (dim: [S, 2]) 659 | """ 660 | filename = os.path.join(data_path, sample_id, 'spatial', 'tissue_hires_image.png') 661 | if os.path.isfile(filename): 662 | if rgb_channels: 663 | img = io.imread(filename) 664 | #img = (img-img.min())/(img.max()-img.min()) 665 | 666 | else: 667 | img = io.imread(filename) 668 | if img.max() <= 1: 669 | img = (img * 255).astype(np.uint8) 670 | 671 | # Create stain matrix 672 | stains = ['hematoxylin','eosin', 'null'] 673 | stain_cmap = htk.preprocessing.color_deconvolution.stain_color_map 674 | W = np.array([stain_cmap[st] for st in stains]).T 675 | 676 | # Color deconvolution 677 | imDeconvolved = htk.preprocessing.color_deconvolution.color_deconvolution(img, W) 678 | img = imDeconvolved.Stains[:,:,0] # H-channel 679 | 680 | # 
Take inverse of H-channel (approx. cell density) 681 | img = (img - img.min()) / (img.max()-img.min()) 682 | img = img.max() - img 683 | img = (img*255).astype(np.uint8) 684 | else: 685 | img = None 686 | 687 | # Mapping images to location 688 | f = open(os.path.join(data_path, sample_id, 'spatial', 'scalefactors_json.json', )) 689 | json_info = json.load(f) 690 | f.close() 691 | 692 | tissue_position_list = pd.read_csv(os.path.join(data_path, sample_id, 'spatial', 'tissue_positions_list.csv'), header=None, index_col=0) 693 | tissue_position_list = tissue_position_list.loc[adata_index, :] 694 | map_info = tissue_position_list.iloc[:, -4:-2] 695 | map_info.columns = ['array_row', 'array_col'] 696 | map_info.loc[:, 'imagerow'] = tissue_position_list.iloc[:, -2] 697 | map_info.loc[:, 'imagecol'] = tissue_position_list.iloc[:, -1] 698 | map_info.loc[:, 'sample'] = sample_id 699 | 700 | return { 701 | 'img': img, 702 | 'map_info': map_info, 703 | 'scalefactor': json_info 704 | } 705 | 706 | 707 | def get_adata_wsig(adata, adata_norm, gene_sig): 708 | """ 709 | Select intersection of HVGs from dataset & signature annotations 710 | """ 711 | # TODO: in-place operators for `adata` 712 | hvgs = adata.var_names[adata.var.highly_variable] 713 | unique_sigs = np.unique(gene_sig.values[~pd.isna(gene_sig)]) 714 | genes_to_keep = np.union1d( 715 | hvgs, 716 | np.intersect1d(adata.var_names, unique_sigs) 717 | ) 718 | return adata[:, genes_to_keep], adata_norm[:, genes_to_keep] 719 | 720 | 721 | def filter_gene_sig(gene_sig, adata_df): 722 | for i in range(gene_sig.shape[0]): 723 | for j in range(gene_sig.shape[1]): 724 | gene = gene_sig.iloc[i, j] 725 | if gene in adata_df.columns: 726 | # We don't filter signature genes based on expression level (prev: threshold=20) 727 | if adata_df.loc[:, gene].sum() < 0: 728 | gene_sig.iloc[i, j] = 'NaN' 729 | return gene_sig 730 | 731 | 732 | def get_umap(adata_sample, display=False): 733 | sc.tl.pca(adata_sample, svd_solver='arpack') 734 | sc.pp.neighbors(adata_sample, n_neighbors=15, n_pcs=40) 735 | sc.tl.umap(adata_sample, min_dist=0.2) 736 | if display: 737 | sc.pl.umap(adata_sample) 738 | umap_plot = pd.DataFrame(adata_sample.obsm['X_umap'], 739 | columns=['umap1', 'umap2'], 740 | index=adata_sample.obs_names) 741 | return umap_plot 742 | 743 | 744 | def get_simu_map_info(umap_plot): 745 | map_info = [] 746 | map_info = [-umap_plot['umap2'] * 10, umap_plot['umap1'] * 10] 747 | map_info = pd.DataFrame(np.transpose(map_info), 748 | columns=['array_row', 'array_col'], 749 | index=umap_plot.index) 750 | return map_info 751 | 752 | 753 | def get_windowed_library(adata_sample, map_info, library, window_size): 754 | library_n = [] 755 | 756 | for i in adata_sample.obs_names: 757 | window_size = window_size 758 | dist_arr = np.sqrt( 759 | (map_info.loc[:, 'array_col'] - map_info.loc[i, 'array_col']) ** 2 + 760 | (map_info.loc[:, 'array_row'] - map_info.loc[i, 'array_row']) ** 2 761 | ) 762 | 763 | library_n.append(library[dist_arr < window_size].mean()) 764 | library_n = np.array(library_n) 765 | return library_n 766 | 767 | 768 | def append_sigs(gene_sig, factor, sigs, n_genes=10): 769 | """ 770 | Append list of genes to a given cell type as additional signatures or 771 | add novel cell type / states & their signatures 772 | """ 773 | assert len(sigs) > 0, "Signature list must have positive length" 774 | gene_sig_new = gene_sig.copy() 775 | 776 | if not isinstance(sigs, list): 777 | sigs = sigs.to_list() 778 | if n_genes < len(sigs): 779 | sigs = sigs[:n_genes] 780 | 
781 | markers = set([i for i in gene_sig[factor] if str(i) != 'nan'] + 782 | [i for i in sigs if str(i) != 'nan']) 783 | nrow_diff = int(np.abs(len(markers)-gene_sig_new.shape[0])) 784 | if len(markers) > gene_sig_new.shape[0]: 785 | df_dummy = pd.DataFrame([[np.nan] * gene_sig.shape[1]] * nrow_diff, 786 | columns=gene_sig_new.columns) 787 | gene_sig_new = pd.concat([gene_sig_new, df_dummy], ignore_index=True) 788 | else: 789 | markers = list(markers) + [np.nan]*nrow_diff 790 | gene_sig_new[factor] = list(markers) 791 | 792 | return gene_sig_new 793 | 794 | 795 | def refine_anchors( 796 | visium_args, 797 | aa_model, 798 | anchor_threshold=0.1, 799 | n_genes=10 800 | ): 801 | """ 802 | Refine anchor spots & marker genes with archetypal analysis. We append DEGs 803 | computed from archetypes to their best-matched anchors followed by re-computing 804 | new anchor spots 805 | 806 | Parameters 807 | ---------- 808 | visium_args : VisiumArgument 809 | Default parameter set for Starfysh upon dataloading 810 | 811 | aa_model : ArchetypalAnalysis 812 | Pre-computed archetype object 813 | 814 | anchor_threshold : float 815 | Top percent of anchor spots per cell-type 816 | for archetypal mapping 817 | 818 | n_genes : int 819 | # archetypal marker genes to append per refinement iteration 820 | 821 | Returns 822 | ------- 823 | visimu_args : VisiumArgument 824 | updated parameter set for Starfysh 825 | """ 826 | # TODO: integrate into `visium_args` class 827 | 828 | n_spots = visium_args.adata.shape[0] 829 | gene_sig = visium_args.gene_sig.copy() 830 | anchors_df = visium_args.get_anchors() 831 | n_top_anchors = int(anchor_threshold*n_spots) 832 | 833 | # Retrieve anchor-archetype mapping scores 834 | map_df, map_dict = aa_model.assign_archetypes(anchor_df=anchors_df[:n_top_anchors], 835 | r=n_top_anchors) 836 | markers_df = aa_model.find_markers(display=False) 837 | 838 | # (1). Update signatures 839 | for cell_type, archetype in map_dict.items(): 840 | gene_sig = append_sigs(gene_sig=gene_sig, 841 | factor=cell_type, 842 | sigs=markers_df[archetype], 843 | n_genes=n_genes) 844 | 845 | # (2). Update data args. 846 | visium_args.gene_sig = gene_sig 847 | visium_args._update_anchors() 848 | return visium_args 849 | 850 | 851 | # ------------------- 852 | # Post-processing 853 | # ------------------- 854 | 855 | def extract_feature(adata, key): 856 | """ 857 | Extract generative / inference output from adata.obsm 858 | generate dummy tmp. 
adata for plotting 859 | """ 860 | assert key in adata.obsm.keys(), "Unfounded Starfysh generative / inference output: {}".format(key) 861 | 862 | if key == 'qc_m': 863 | cols = adata.uns['cell_types'] # cell type deconvolution 864 | elif key == 'qz_m': 865 | cols = ['z'+str(i) for i in range(adata.obsm[key].shape[1])] # inferred qz (low-dim manifold) 866 | elif '_inferred_exprs' in key: 867 | cols = adata.var_names # inferred cell-type specific expressions 868 | else: 869 | cols = ['density'] 870 | adata_dummy = adata.copy() 871 | adata_dummy.obs = pd.DataFrame(adata.obsm[key], index=adata.obs.index, columns=cols) 872 | return adata_dummy 873 | 874 | 875 | def get_reconst_img(args, img_patches): 876 | """ 877 | Reconst original histology image (H x W) from the given patched image (S x P) 878 | """ 879 | reconst_img = np.zeros_like(args.img, dtype=np.float64) 880 | r = args.params['patch_r'] 881 | patch_size = (r * 2, r * 2, 3) if args.img.ndim == 3 else (r * 2, r * 2) 882 | scale_factor = args.scalefactor['tissue_hires_scalef'] 883 | img_col = args.map_info['imagecol'] * scale_factor 884 | img_row = args.map_info['imagerow'] * scale_factor 885 | 886 | for i in range(len(img_col)): 887 | patch_y = slice(int(img_row[i]) - r, int(img_row[i]) + r) 888 | patch_x = slice(int(img_col[i]) - r, int(img_col[i]) + r) 889 | 890 | sy, sx = reconst_img[patch_y, patch_x].shape[:2] 891 | img_patch = img_patches[i].reshape(patch_size) 892 | reconst_img[patch_y, patch_x] = img_patch[:sy, :sx] # edge patch cases 893 | 894 | return reconst_img 895 | 896 | 897 | -------------------------------------------------------------------------------- /starfysh/utils_integrate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import numpy as np 5 | import pandas as pd 6 | import scanpy as sc 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | 12 | from scipy.stats import median_abs_deviation 13 | from torch.utils.data import DataLoader 14 | 15 | import sys 16 | from skimage import io 17 | 18 | # Module import 19 | from starfysh import LOGGER 20 | from .dataloader import IntegrativeDataset, IntegrativePoEDataset 21 | from .starfysh import AVAE, AVAE_PoE, train, train_poe 22 | 23 | import numpy as np 24 | import pandas as pd 25 | import logging 26 | import torch 27 | import torch.nn.functional as F 28 | from torch.utils.data import Dataset 29 | 30 | 31 | class VisiumArguments_integrate: 32 | """ 33 | Loading Visium AnnData, perform preprocessing, library-size smoothing & Anchor spot detection 34 | 35 | Parameters 36 | ---------- 37 | adata : AnnData 38 | annotated visium count matrix 39 | 40 | adata_norm : AnnData 41 | annotated visium count matrix after normalization & log-transform 42 | 43 | gene_sig : pd.DataFrame 44 | list of signature genes for each cell type. 
(dim: [S, Cell_type]) 45 | 46 | img_metadata : dict 47 | Spatial information metadata (histology image, coordinates, scalefactor) 48 | """ 49 | def __init__( 50 | self, 51 | adata, 52 | adata_norm, 53 | gene_sig, 54 | img_metadata, 55 | individual_args, 56 | **kwargs 57 | ): 58 | 59 | self.adata = adata 60 | self.adata_norm = adata_norm 61 | self.gene_sig = gene_sig 62 | self.eps = 1e-6 63 | 64 | self.params = { 65 | 'sample_id': 'ST', 66 | 'n_anchors': int(adata.shape[0]), 67 | 'patch_r': 13, 68 | 'signif_level': 3, 69 | 'window_size': 1, 70 | 'n_img_chan': 1 71 | } 72 | 73 | for k, v in kwargs.items(): 74 | if k in self.params.keys(): 75 | self.params[k] = v 76 | 77 | 78 | map_info_temp_all = [] 79 | for i in self.params['sample_id']: 80 | map_info_temp = img_metadata[i]['map_info'].iloc[:, :4].astype(float) 81 | map_info_temp_all.append(map_info_temp) 82 | self.map_info = pd.concat(map_info_temp_all) 83 | 84 | img_temp = {} 85 | for i in self.params['sample_id']: 86 | img_temp[i] = img_metadata[i]['img'] 87 | self.img = img_temp 88 | 89 | self.img_patches = None 90 | 91 | scalefactor_temp = {} 92 | for i in self.params['sample_id']: 93 | scalefactor_temp[i] = img_metadata[i]['scalefactor'] 94 | self.scalefactor = scalefactor_temp 95 | 96 | # Update parameters for library smoothing & anchor spot identification 97 | 98 | 99 | # Center expression for gene score calculation 100 | adata_scale = self.adata_norm.copy() 101 | sc.pp.scale(adata_scale) 102 | 103 | # Store cell types 104 | self.adata.uns['cell_types'] = list(self.gene_sig.columns) 105 | 106 | # Filter out signature genes X listed in expression matrix 107 | LOGGER.info('Subsetting highly variable & signature genes ...') 108 | self.adata, self.adata_norm = get_adata_wsig(adata, adata_norm, gene_sig) 109 | self.adata_scale = adata_scale[:, adata.var_names] 110 | 111 | # Update spatial metadata 112 | self._update_spatial_info(self.params['sample_id']) 113 | 114 | # Get smoothed library size 115 | LOGGER.info('Smoothing library size by taking averaging with neighbor spots...') 116 | log_lib = np.log1p(self.adata.X.sum(1)) 117 | self.log_lib = np.squeeze(np.asarray(log_lib)) if log_lib.ndim > 1 else log_lib 118 | 119 | win_loglib_temp_all = [] 120 | for i in self.params['sample_id']: 121 | win_loglib_temp = get_windowed_library(self.adata[self.adata.obs['sample']==i], 122 | self.map_info[self.adata.obs['sample']==i], 123 | self.log_lib[self.adata.obs['sample']==i], 124 | window_size=self.params['window_size'] 125 | ) 126 | win_loglib_temp_all.append(pd.DataFrame(win_loglib_temp)) 127 | 128 | self.win_loglib = pd.concat(win_loglib_temp_all) 129 | self.win_loglib = np.array(self.win_loglib) 130 | 131 | # Retrieve & normalize signature gexp 132 | LOGGER.info('Retrieving & normalizing signature gene expressions...') 133 | sig_mean_temp = [] 134 | for i in individual_args.keys(): 135 | sig_mean_temp.append(individual_args[i].sig_mean) 136 | 137 | self.sig_mean = pd.concat(sig_mean_temp,axis=0) 138 | 139 | sig_mean_norm_temp = [] 140 | for i in individual_args.keys(): 141 | sig_mean_norm_temp.append(individual_args[i].sig_mean_norm) 142 | 143 | self.sig_mean_norm = pd.concat(sig_mean_norm_temp,axis=0) 144 | 145 | # Get anchor spots 146 | LOGGER.info('Identifying anchor spots (highly expression of specific cell-type signatures)...') 147 | anchor_info = self._compute_anchors() 148 | 149 | self.pure_spots, self.pure_dict, self.pure_idx = anchor_info 150 | del self.adata.raw, self.adata_norm.raw 151 | 152 | def get_adata(self): 153 | """Return 
    def get_adata(self):
        """Return adata & adata_norm after preprocessing & HVG selection"""
        return self.adata, self.adata_norm

    def get_anchors(self):
        """Return indices of anchor spots for each cell type"""
        anchors_df = pd.DataFrame.from_dict(self.pure_dict, orient='index')
        anchors_df = anchors_df.transpose()

        # Check whether empty anchors were detected for any factor
        empty_indices = np.where(
            (~pd.isna(anchors_df)).sum(0) == 0
        )[0]

        if len(empty_indices) > 0:
            raise ValueError("Cell type(s) {} have no anchors significantly enriched for their signatures; "
                             "please lower the outlier stat `signif_level`".format(
                                 anchors_df.columns[empty_indices].to_list()
                             ))

        return anchors_df.applymap(
            lambda x:
            -1 if x is None else np.where(self.adata.obs.index == x)[0][0]
        )

    def get_img_patches(self):
        assert self.img_patches is not None, "Please run Starfysh PoE first"
        return self.img_patches

    def append_factors(self, arche_markers):
        """
        Append a list of archetypes (w/ corresponding markers) as additional cell type(s) / state(s) to `gene_sig`
        """
        self.gene_sig = pd.concat((self.gene_sig, arche_markers), axis=1)

        # Update factor names & anchor spots
        self.adata.uns['cell_types'] = list(self.gene_sig.columns)
        self._update_anchors()
        return None

    def replace_factors(self, factors_to_repl, arche_markers):
        """
        Replace factor(s) with archetypes & their corresponding markers in `gene_sig`
        """
        if isinstance(factors_to_repl, str):
            assert isinstance(arche_markers, pd.Series), \
                "Please pick only one archetype to replace the factor {}".format(factors_to_repl)
            factors_to_repl = [factors_to_repl]
            archetypes = [arche_markers.name]
        else:
            assert len(factors_to_repl) == len(arche_markers.columns), \
                "Unequal # of cell types & archetypes to replace with"
            archetypes = arche_markers.columns

        self.gene_sig.rename(
            columns={
                f: a
                for (f, a) in zip(factors_to_repl, archetypes)
            }, inplace=True
        )
        self.gene_sig[archetypes] = pd.DataFrame(arche_markers)

        # Update factor names & anchor spots
        self.adata.uns['cell_types'] = list(self.gene_sig.columns)
        self._update_anchors()
        return None

    # --- Private methods ---
    def _compute_anchors(self):
        """
        Select the top `n_anchors` spots most significantly enriched for each cell type,
        as determined by gene-set scores computed from the signatures
        """
        score_df = self.sig_mean_norm
        n_anchor = self.params['n_anchors']

        top_expr_spots = (-score_df.values).argsort(axis=0)[:n_anchor, :]
        pure_spots = np.transpose(np.array(score_df.index)[top_expr_spots])

        pure_dict = {
            ct: spot
            for (spot, ct) in zip(pure_spots, score_df.columns)
        }

        pure_indices = np.zeros([score_df.shape[0], 1])
        idx = [np.where(score_df.index == i)[0][0]
               for i in sorted({x for v in pure_dict.values() for x in v})]
        pure_indices[idx] = 1
        return pure_spots, pure_dict, pure_indices
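
    # A usage sketch (hedged): `get_anchors` maps each anchor spot barcode to
    # its integer row index in `adata.obs` (-1 where no anchor was found):
    #
    #     anchors_df = visium_args.get_anchors()
    #     anchors_df.iloc[:5, :]   # top-5 anchor indices per cell type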
    def _update_anchors(self):
        """Re-calculate anchor spots given the updated gene signatures"""
        self.sig_mean = self._get_sig_mean()
        self.sig_mean_norm = self._calc_gene_scores()
        self.adata.uns['cell_types'] = list(self.gene_sig.columns)

        LOGGER.info('Recalculating anchor spots (high expression of specific cell-type signatures)...')
        anchor_info = self._compute_anchors()
        self.sig_mean_norm[self.sig_mean_norm < 0] = self.eps
        self.sig_mean_norm.fillna(1 / self.sig_mean_norm.shape[1], inplace=True)
        self.pure_spots, self.pure_dict, self.pure_idx = anchor_info

    def _get_sig_mean(self):
        sig_mean_expr = pd.DataFrame()
        cnt_df = self.adata_norm.to_df()

        # Calculate avg. signature expression for each cell type
        for i, cell_type in enumerate(self.gene_sig.columns):
            sigs = np.intersect1d(cnt_df.columns, self.gene_sig.iloc[:, i].astype(str))

            if len(sigs) == 0:
                raise ValueError("Empty signatures for {}; "
                                 "please double check your `gene_sig` input or set a higher "
                                 "`n_gene` threshold upon dataloading".format(cell_type))
            else:
                sig_mean_expr[cell_type] = cnt_df.loc[:, sigs].mean(axis=1)

        sig_mean_expr.index = self.adata.obs_names
        sig_mean_expr.columns = self.gene_sig.columns
        return sig_mean_expr

    def _update_spatial_info(self, sample_id):
        """Update paired spatial information in the ST adata"""
        # Update image channel count for RGB input (`y`)
        if self.img is not None and self.img[sample_id.iloc[0]].ndim == 3:
            self.params['n_img_chan'] = 3

        if 'spatial' not in self.adata.uns_keys():
            self.adata.uns['spatial'] = {
                i: {
                    'images': {'hires': (self.img[i] - self.img[i].min()) / (self.img[i].max() - self.img[i].min())},
                    'scalefactors': self.scalefactor[i]
                } for i in sample_id
            }

            self.adata_norm.uns['spatial'] = {
                i: {
                    'images': {'hires': (self.img[i] - self.img[i].min()) / (self.img[i].max() - self.img[i].min())},
                    'scalefactors': self.scalefactor[i]
                } for i in sample_id
            }
        self.adata.obsm['spatial'] = self.map_info[['imagecol', 'imagerow']].values
        self.adata_norm.obsm['spatial'] = self.map_info[['imagecol', 'imagerow']].values

        # Typecast: spatial coords.
        self.adata.obsm['spatial'] = self.adata.obsm['spatial'].astype(np.float32)
        self.adata_norm.obsm['spatial'] = self.adata_norm.obsm['spatial'].astype(np.float32)
        return None
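
    # Note (hedged): the `uns['spatial']` layout written above follows the
    # structure scanpy's spatial plotting reads, e.g.:
    #
    #     sc.pl.spatial(visium_args.adata, library_id='S1',  # hypothetical sample id
    #                   img_key='hires', color='density')    # hypothetical obs column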
    def _update_img_patches(self, dl_poe):
        dl_poe.spot_img_stack = np.array(dl_poe.spot_img_stack)
        dl_poe.spot_img_stack = dl_poe.spot_img_stack.reshape(dl_poe.spot_img_stack.shape[0], -1)
        imgs = torch.Tensor(dl_poe.spot_img_stack)
        self.img_patches = imgs
        return None

    def _norm_sig(self):
        # Column-normalize for each cell type: divide by the mean
        gexp = self.sig_mean.apply(lambda x: x / x.mean(), axis=0)
        return gexp

    def _calc_gene_scores(self):
        """Calculate gene-set enrichment scores for each signature set"""
        adata = self.adata_scale.copy()
        for cell_type in self.gene_sig.columns:
            sig = self.gene_sig[cell_type][~pd.isna(self.gene_sig[cell_type])].to_list()
            sc.tl.score_genes(adata, sig, score_name=cell_type + '_score', use_raw=False)

        gsea_df = adata.obs[[cell_type + '_score' for cell_type in self.gene_sig.columns]]
        gsea_df.columns = self.gene_sig.columns
        return gsea_df


def get_adata_wsig(adata, adata_norm, gene_sig):
    """
    Subset genes to the union of dataset HVGs & signature genes present in the expression matrix
    """
    hvgs = adata.var_names[adata.var.highly_variable]
    unique_sigs = np.unique(gene_sig.values[~pd.isna(gene_sig)])
    genes_to_keep = np.union1d(
        hvgs,
        np.intersect1d(adata.var_names, unique_sigs)
    )
    return adata[:, genes_to_keep], adata_norm[:, genes_to_keep]


def get_windowed_library(adata_sample, map_info, library, window_size):
    """Smooth each spot's library size by averaging over spots within `window_size` (array coords.)"""
    library_n = []

    for i in adata_sample.obs_names:
        dist_arr = np.sqrt(
            (map_info.loc[:, 'array_col'] - map_info.loc[i, 'array_col']) ** 2 +
            (map_info.loc[:, 'array_row'] - map_info.loc[i, 'array_row']) ** 2
        )

        library_n.append(library[dist_arr < window_size].mean())
    library_n = np.array(library_n)

    return library_n
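
# A worked sketch (hedged): with the default window_size=1, only the spot itself
# satisfies dist_arr < 1 on an integer array grid, so smoothing is a no-op; a
# larger window averages in nearby spots. Per-sample inputs below are hypothetical:
#
#     win = get_windowed_library(adata_s, map_info_s, log_lib_s, window_size=2)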
def init_weights(module):
    if type(module) == nn.Linear:
        torch.nn.init.kaiming_uniform_(module.weight)

    elif type(module) == nn.BatchNorm1d:
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)


def run_starfysh(
    visium_args,
    n_repeats=3,
    lr=1e-4,
    epochs=100,
    batch_size=32,
    alpha_mul=50,
    poe=False,
    device=torch.device('cpu'),
    verbose=True,
):
    """
    Wrapper to run Starfysh deconvolution.

    Parameters
    ----------
    visium_args : VisiumArguments_integrate
        Preprocessed metadata calculated from the input visium matrix:
        e.g. mean signature expression, library size, anchor spots, etc.

    n_repeats : int
        Number of restarts when running Starfysh

    epochs : int
        Max. number of training epochs per restart

    poe : bool
        Whether to perform PoE inference w/ histology image integration

    Returns
    -------
    best_model : starfysh.AVAE or starfysh.AVAE_PoE
        Trained Starfysh model with deconvolution results

    loss : dict
        Per-epoch training losses for each term of the best restart
    """
    np.random.seed(0)

    # Loading parameters
    adata = visium_args.adata
    win_loglib = visium_args.win_loglib
    gene_sig, sig_mean_norm = visium_args.gene_sig, visium_args.sig_mean_norm

    models = [None] * n_repeats
    losses = []
    loss_c_list = np.repeat(np.inf, n_repeats)

    if poe:
        dl_func = IntegrativePoEDataset  # dataloader
        train_func = train_poe  # training wrapper
    else:
        dl_func = IntegrativeDataset
        train_func = train

    trainset = dl_func(adata=adata, args=visium_args)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, drop_last=True)

    # Run Starfysh with multiple restarts, keeping the best-performing model
    LOGGER.info('Running Starfysh with {} restarts, choosing the model with the best parameters...'.format(n_repeats))
    for i in range(n_repeats):
        if verbose:
            LOGGER.info(" === Restart Starfysh {0} === \n".format(i + 1))
        best_loss_c = np.inf

        if poe:
            model = AVAE_PoE(
                adata=adata,
                gene_sig=sig_mean_norm,
                patch_r=visium_args.params['patch_r'],
                win_loglib=win_loglib,
                alpha_mul=alpha_mul,
                n_img_chan=visium_args.params['n_img_chan']
            )
            # Update patched & flattened image patches
            visium_args._update_img_patches(trainset)
        else:
            model = AVAE(
                adata=adata,
                gene_sig=sig_mean_norm,
                win_loglib=win_loglib,
                alpha_mul=alpha_mul
            )

        model = model.to(device)
        loss_dict = {
            'reconst': [],
            'c': [],
            'u': [],
            'z': [],
            'n': [],
            'tot': []
        }

        # Initialize model params
        if verbose:
            LOGGER.info('Initializing model parameters...')

        model.apply(init_weights)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

        for epoch in range(epochs):
            result = train_func(model, trainloader, device, optimizer)
            torch.cuda.empty_cache()

            loss_tot, loss_reconst, loss_u, loss_z, loss_c, loss_n, corr_list = result
            if loss_c < best_loss_c:
                models[i] = model
                best_loss_c = loss_c

            loss_dict['tot'].append(loss_tot)
            loss_dict['reconst'].append(loss_reconst)
            loss_dict['u'].append(loss_u)
            loss_dict['z'].append(loss_z)
            loss_dict['c'].append(loss_c)
            loss_dict['n'].append(loss_n)

            if (epoch + 1) % 10 == 0 and verbose:
                LOGGER.info("Epoch[{}/{}], train_loss: {:.4f}, train_reconst: {:.4f}, train_u: {:.4f}, train_z: {:.4f}, train_c: {:.4f}, train_n: {:.4f}".format(
                    epoch + 1, epochs, loss_tot, loss_reconst, loss_u, loss_z, loss_c, loss_n)
                )
            scheduler.step()

        losses.append(loss_dict)
        loss_c_list[i] = best_loss_c
        if verbose:
            LOGGER.info('Saving the best-performance model...')
            LOGGER.info(" === Finished training === \n")

    idx = np.argmin(loss_c_list)
    best_model = models[idx]
    loss = losses[idx]

    return best_model, loss
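
# A minimal end-to-end sketch (hedged), assuming `visium_args` was built with
# VisiumArguments_integrate above; `device` falls back to CPU by default:
#
#     best_model, loss = run_starfysh(
#         visium_args, n_repeats=3, epochs=100, poe=False,
#         device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
#     )
#     # loss['tot'] holds the per-epoch total training loss of the best restart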
--------------------------------------------------------------------------------
/tests/test_modules.py:
--------------------------------------------------------------------------------
import unittest
import importlib


class TestModules(unittest.TestCase):
    def test_import_modules(self):
        modules = [
            "AA",
            "dataloader",
            "plot_utils",
            "post_analysis",
            "starfysh",
            "utils",
            "utils_integrate",
        ]

        for module_name in modules:
            try:
                importlib.import_module("starfysh." + module_name)
                print(f"Module {module_name} imported successfully.")
            except ImportError as e:
                self.fail(f"Failed to import module {module_name}: {e}")


if __name__ == '__main__':
    unittest.main()
--------------------------------------------------------------------------------
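
The import smoke test above can be run from the repository root (a hedged sketch; assumes the package was installed per the README):

```bash
python -m unittest tests/test_modules.py
# or directly, via the __main__ guard:
python tests/test_modules.py
```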