├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── demo.ipynb ├── environment.yml ├── models ├── __init__.py ├── cf.py ├── likelihoods.py ├── mefisto.py ├── sf.py └── sfh.py ├── nsf-paper.Rproj ├── requirements.txt ├── scrna ├── .gitignore ├── 01_viz_spatial_importance.Rmd ├── sshippo │ ├── 01_data_loading.Rmd │ ├── 02_data_loading.ipy │ ├── 03_benchmark.ipy │ ├── 04_exploratory.ipy │ ├── 05_benchmark_viz.Rmd │ ├── 06_interpret_genes.Rmd │ ├── 07_traditional.ipy │ └── results │ │ └── benchmark.csv ├── utils │ └── interpret_genes.R ├── visium_brain_sagittal │ ├── 01_data_loading.ipy │ ├── 02_exploratory.ipy │ ├── 03_benchmark.ipy │ ├── 04_benchmark_viz.Rmd │ ├── 05_interpret_genes.Rmd │ ├── 06_traditional.ipy │ └── results │ │ └── benchmark.csv └── xyzeq_liver │ ├── 01_data_loading.ipy │ ├── 02_exploratory.ipy │ ├── 03_benchmark.ipy │ ├── 04_benchmark_viz.Rmd │ ├── 05_interpret_genes.Rmd │ ├── 06_traditional.ipy │ └── results │ └── benchmark.csv ├── simulations ├── .gitignore ├── __init__.py ├── benchmark.py ├── benchmark.slurm ├── benchmark_gof.py ├── benchmark_gof.slurm ├── bm_mixed │ ├── 01_data_generation.ipy │ ├── 02_benchmark.ipy │ ├── 03_benchmark_viz.Rmd │ └── results │ │ └── benchmark.csv ├── bm_sp │ ├── 01_data_generation.ipy │ ├── 02_benchmark.ipy │ ├── 03_benchmark_viz.Rmd │ ├── 04_quilt_exploratory.ipy │ ├── 05_ggblocks_exploratory.ipy │ ├── data │ │ ├── S1.h5ad │ │ └── S6.h5ad │ └── results │ │ └── benchmark.csv └── sim.py └── utils ├── __init__.py ├── benchmark.py ├── benchmark_array.slurm ├── benchmark_gof.py ├── benchmark_gof.slurm ├── misc.py ├── nnfu.py ├── postprocess.py ├── preprocess.py ├── training.py └── visualize.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.jpg binary 2 | *.png binary 3 | *.pdf binary 4 | *.RData binary 5 | *.h5ad binary 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Added by WT 2 | .ipynb_checkpoints/ 3 | # data 4 | # plots 5 | model_checkpoints 6 | *.rds 7 | .spyproject 8 | __pycache__ 9 | .DS_Store 10 | tf_ckpts 11 | slurm-*.out 12 | slurm-*.err 13 | *.pyc 14 | *.pickle 15 | resources 16 | 17 | # History files 18 | .Rhistory 19 | .Rapp.history 20 | 21 | # Session Data files 22 | .RData 23 | 24 | # User-specific files 25 | .Ruserdata 26 | 27 | # Example code in package build process 28 | *-Ex.R 29 | 30 | # Output files from R CMD build 31 | /*.tar.gz 32 | 33 | # Output files from R CMD check 34 | /*.Rcheck/ 35 | 36 | # RStudio files 37 | .Rproj.user/ 38 | 39 | # produced vignettes 40 | vignettes/*.html 41 | vignettes/*.pdf 42 | 43 | # OAuth2 token, see https://github.com/hadley/httr/releases/tag/v0.3 44 | .httr-oauth 45 | 46 | # knitr and R markdown default cache directories 47 | *_cache/ 48 | /cache/ 49 | 50 | # Temporary files created by R markdown 51 | *.utf8.md 52 | *.knit.md 53 | 54 | # R Environment Variables 55 | .Renviron 56 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU LESSER GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 
7 | 8 | 9 | This version of the GNU Lesser General Public License incorporates 10 | the terms and conditions of version 3 of the GNU General Public 11 | License, supplemented by the additional permissions listed below. 12 | 13 | 0. Additional Definitions. 14 | 15 | As used herein, "this License" refers to version 3 of the GNU Lesser 16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU 17 | General Public License. 18 | 19 | "The Library" refers to a covered work governed by this License, 20 | other than an Application or a Combined Work as defined below. 21 | 22 | An "Application" is any work that makes use of an interface provided 23 | by the Library, but which is not otherwise based on the Library. 24 | Defining a subclass of a class defined by the Library is deemed a mode 25 | of using an interface provided by the Library. 26 | 27 | A "Combined Work" is a work produced by combining or linking an 28 | Application with the Library. The particular version of the Library 29 | with which the Combined Work was made is also called the "Linked 30 | Version". 31 | 32 | The "Minimal Corresponding Source" for a Combined Work means the 33 | Corresponding Source for the Combined Work, excluding any source code 34 | for portions of the Combined Work that, considered in isolation, are 35 | based on the Application, and not on the Linked Version. 36 | 37 | The "Corresponding Application Code" for a Combined Work means the 38 | object code and/or source code for the Application, including any data 39 | and utility programs needed for reproducing the Combined Work from the 40 | Application, but excluding the System Libraries of the Combined Work. 41 | 42 | 1. Exception to Section 3 of the GNU GPL. 43 | 44 | You may convey a covered work under sections 3 and 4 of this License 45 | without being bound by section 3 of the GNU GPL. 46 | 47 | 2. Conveying Modified Versions. 48 | 49 | If you modify a copy of the Library, and, in your modifications, a 50 | facility refers to a function or data to be supplied by an Application 51 | that uses the facility (other than as an argument passed when the 52 | facility is invoked), then you may convey a copy of the modified 53 | version: 54 | 55 | a) under this License, provided that you make a good faith effort to 56 | ensure that, in the event an Application does not supply the 57 | function or data, the facility still operates, and performs 58 | whatever part of its purpose remains meaningful, or 59 | 60 | b) under the GNU GPL, with none of the additional permissions of 61 | this License applicable to that copy. 62 | 63 | 3. Object Code Incorporating Material from Library Header Files. 64 | 65 | The object code form of an Application may incorporate material from 66 | a header file that is part of the Library. You may convey such object 67 | code under terms of your choice, provided that, if the incorporated 68 | material is not limited to numerical parameters, data structure 69 | layouts and accessors, or small macros, inline functions and templates 70 | (ten or fewer lines in length), you do both of the following: 71 | 72 | a) Give prominent notice with each copy of the object code that the 73 | Library is used in it and that the Library and its use are 74 | covered by this License. 75 | 76 | b) Accompany the object code with a copy of the GNU GPL and this license 77 | document. 78 | 79 | 4. Combined Works. 
80 | 81 | You may convey a Combined Work under terms of your choice that, 82 | taken together, effectively do not restrict modification of the 83 | portions of the Library contained in the Combined Work and reverse 84 | engineering for debugging such modifications, if you also do each of 85 | the following: 86 | 87 | a) Give prominent notice with each copy of the Combined Work that 88 | the Library is used in it and that the Library and its use are 89 | covered by this License. 90 | 91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license 92 | document. 93 | 94 | c) For a Combined Work that displays copyright notices during 95 | execution, include the copyright notice for the Library among 96 | these notices, as well as a reference directing the user to the 97 | copies of the GNU GPL and this license document. 98 | 99 | d) Do one of the following: 100 | 101 | 0) Convey the Minimal Corresponding Source under the terms of this 102 | License, and the Corresponding Application Code in a form 103 | suitable for, and under terms that permit, the user to 104 | recombine or relink the Application with a modified version of 105 | the Linked Version to produce a modified Combined Work, in the 106 | manner specified by section 6 of the GNU GPL for conveying 107 | Corresponding Source. 108 | 109 | 1) Use a suitable shared library mechanism for linking with the 110 | Library. A suitable mechanism is one that (a) uses at run time 111 | a copy of the Library already present on the user's computer 112 | system, and (b) will operate properly with a modified version 113 | of the Library that is interface-compatible with the Linked 114 | Version. 115 | 116 | e) Provide Installation Information, but only if you would otherwise 117 | be required to provide such information under section 6 of the 118 | GNU GPL, and only to the extent that such information is 119 | necessary to install and execute a modified version of the 120 | Combined Work produced by recombining or relinking the 121 | Application with a modified version of the Linked Version. (If 122 | you use option 4d0, the Installation Information must accompany 123 | the Minimal Corresponding Source and Corresponding Application 124 | Code. If you use option 4d1, you must provide the Installation 125 | Information in the manner specified by section 6 of the GNU GPL 126 | for conveying Corresponding Source.) 127 | 128 | 5. Combined Libraries. 129 | 130 | You may place library facilities that are a work based on the 131 | Library side by side in a single library together with other library 132 | facilities that are not Applications and are not covered by this 133 | License, and convey such a combined library under terms of your 134 | choice, if you do both of the following: 135 | 136 | a) Accompany the combined library with a copy of the same work based 137 | on the Library, uncombined with any other library facilities, 138 | conveyed under the terms of this License. 139 | 140 | b) Give prominent notice with the combined library that part of it 141 | is a work based on the Library, and explaining where to find the 142 | accompanying uncombined form of the same work. 143 | 144 | 6. Revised Versions of the GNU Lesser General Public License. 145 | 146 | The Free Software Foundation may publish revised and/or new versions 147 | of the GNU Lesser General Public License from time to time. Such new 148 | versions will be similar in spirit to the present version, but may 149 | differ in detail to address new problems or concerns. 
150 | 151 | Each version is given a distinguishing version number. If the 152 | Library as you received it specifies that a certain numbered version 153 | of the GNU Lesser General Public License "or any later version" 154 | applies to it, you have the option of following the terms and 155 | conditions either of that published version or of any later version 156 | published by the Free Software Foundation. If the Library as you 157 | received it does not specify a version number of the GNU Lesser 158 | General Public License, you may choose any version of the GNU Lesser 159 | General Public License ever published by the Free Software Foundation. 160 | 161 | If the Library as you received it specifies that a proxy can decide 162 | whether future versions of the GNU Lesser General Public License shall 163 | apply, that proxy's public statement of acceptance of any version is 164 | permanent authorization for you to choose that version for the 165 | Library. 166 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Nonnegative spatial factorization for multivariate count data 2 | 3 | [![DOI](https://zenodo.org/badge/415147174.svg)](https://zenodo.org/badge/latestdoi/415147174) 4 | 5 | This repository contains supporting code to facilitate reproducible analysis. For details see the [preprint](https://arxiv.org/abs/2110.06122). If you find bugs please create a github issue. An [installable python package](https://github.com/willtownes/spatial-factorization-py) 6 | is also under development. 7 | 8 | ### Authors 9 | 10 | Will Townes and Barbara Engelhardt 11 | 12 | ### Abstract 13 | 14 | Gaussian processes are widely used for the analysis of spatial data due to their nonparametric flexibility and ability to quantify uncertainty, and recently developed scalable approximations have facilitated application to massive datasets. For multivariate outcomes, linear models of coregionalization combine dimension reduction with spatial correlation. However, their real-valued latent factors and loadings are difficult to interpret because, unlike nonnegative models, they do not recover a parts-based representation. We present nonnegative spatial factorization (NSF), a spatially-aware probabilistic dimension reduction model that naturally encourages sparsity. We compare NSF to real-valued spatial factorizations such as MEFISTO and nonspatial dimension reduction methods using simulations and high-dimensional spatial transcriptomics data. NSF identifies generalizable spatial patterns of gene expression. Since not all patterns of gene expression are spatial, we also propose a hybrid extension of NSF that combines spatial and nonspatial components, enabling quantification of spatial importance for both observations and features. 15 | 16 | ## Demo 17 | 18 | A basic demonstration ([**demo.ipynb**](https://github.com/willtownes/nsf-paper/blob/main/demo.ipynb)) using simulated data is provided as a [jupyter](https://jupyter.org) notebook. The expected output is a series of heatmap plots. The runtime should be about 5 minutes. 19 | 20 | ## Description of Repository Contents 21 | All scripts should be run from the top level directory. Files with the suffix `.ipy` are essentially text-only versions of jupyter notebooks and can best be used through the [Spyder IDE](https://www.spyder-ide.org). 
They can be converted to full jupyter notebooks using [jupytext](https://jupytext.readthedocs.io/en/latest/). 22 | 23 | ### models 24 | 25 | TensorFlow implementations of probabilistic factor models 26 | * *cf.py* - nonspatial models (factor analysis and probabilistic nonnegative matrix factorization). 27 | * *mefisto.py* - wrapper around the MEFISTO implementation in the [mofapy2](https://github.com/bioFAM/mofapy2/commit/8f6ffcb5b18d22b3f44ff2a06bcb92f2806afed0) python package. 28 | * *sf.py* - nonnegative and real-valued spatial process factorization (NSF and RSF). 29 | * *sfh.py* - NSF hybrid model, includes both spatial and nonspatial components. 30 | 31 | ### scrna 32 | 33 | Analysis of spatial transcriptomics data 34 | * *sshippo* - Slide-seqV2 mouse hippocampus 35 | * *visium_brain_sagittal* - Visium mouse brain (anterior sagittal section) 36 | * *xyzeq_liver* - XYZeq mouse liver/tumor 37 | 38 | ### simulations 39 | 40 | Data generation and model fitting for the ggblocks and quilt simulations. 41 | * *benchmark.py* - can be called as a command line script to facilitate benchmarking of large numbers of scenarios and parameter combinations. 42 | * *benchmark_gof.py* - compute goodness of fit and other metrics on fitted models. 43 | * *bm_mixed* - mixed spatial and nonspatial factors 44 | * *bm_sp* - spatial factors only. Within this folder, the notebooks 45 | `04_quilt_exploratory.ipy` and `05_ggblocks_exploratory.ipy` have many 46 | visualizations of the various models compared in the paper. 47 | * *sim.py* - functions for creating the simulated datasets. 48 | 49 | ### utils 50 | 51 | Python modules containing functions and classes needed by scripts and model implementation classes. 52 | * *benchmark.py* - functions used in fitting models to datasets and pickling the objects for later evaluation. Can be called as a command line script to facilitate automation. 53 | * *benchmark_gof.py* - script with basic command line interface for computing goodness-of-fit, sparsity, and timing statistics on large numbers of fitted model objects 54 | * *misc.py* - miscellaneous convenience functions useful in preprocessing (normalization and reversing normalization), postprocessing, computing benchmarking statistics, parameter manipulation, and reading and writing pickle and CSV files. 55 | * *nnfu.py* - nonnegative factor model utility functions for rescaling and regularization. Useful in initialization and postprocessing. 56 | * *postprocess.py* - postprocessing functions to facilitate interpretation of nonnegative factor models. 57 | * *preprocess.py* - data loading and preprocessing functions. Normalization of count data, rescaling spatial coordinates for numerical stability, deviance functions for feature selection (analogous to [scry](https://doi.org/doi:10.18129/B9.bioc.scry)), conversions between AnnData and TensorFlow objects. 58 | * *training.py* - classes for fitting TensorFlow models to data, including caching with checkpoints, automatic handling of numeric instabilities, and ConvergenceChecker, which uses a cubic spline to detect convergence of a stochastic optimizer trace. 59 | * *visualize.py* - plotting functions for making heatmaps to visualize spatial and nonspatial factors, as well as some goodness-of-fit metrics. 60 | 61 | ## System requirements 62 | 63 | We used the following versions in our analyses: Python 3.8.10, tensorflow 2.5.0, tensorflow probability 0.13.0, scanpy 1.8.0, squidpy 1.1.0, scikit-learn 0.24.2, pandas 1.2.5, numpy 1.19.5, scipy 1.7.0. 
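To check that a local environment roughly matches these versions, an optional sanity check (a minimal sketch, not part of the original analysis; adjust the package list as needed, and note that scikit-learn imports as `sklearn`) is:

```Python
# Optional: print the installed versions of the main Python dependencies.
import importlib

for pkg in ["tensorflow", "tensorflow_probability", "scanpy", "squidpy",
            "sklearn", "pandas", "numpy", "scipy"]:
    try:
        print(pkg, importlib.import_module(pkg).__version__)
    except ImportError:
        print(pkg, "not installed")
```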
64 | We used the MEFISTO implementation from the mofapy2 Python package, installed from the GitHub development branch at commit 8f6ffcb5b18d22b3f44ff2a06bcb92f2806afed0. 65 | 66 | ```Shell 67 | pip install git+git://github.com/bioFAM/mofapy2.git@8f6ffcb5b18d22b3f44ff2a06bcb92f2806afed0 68 | ``` 69 | 70 | Graphics were generated using either matplotlib 3.4.2 in Python or ggplot2 3.3.5 in R (version 4.1.0). The R packages Seurat 0.4.3, SeuratData 0.2.1, and SeuratDisk 0.0.0.9019 were used for some initial data manipulations. 71 | 72 | Computationally-intensive model fitting was done on Princeton's [Della cluster](https://researchcomputing.princeton.edu/systems/della). Analyses that were less computationally intensive were done on personal computers with operating system MacOS version 12.4. 73 | 74 | ## Installation 75 | 76 | ```Shell 77 | git clone https://github.com/willtownes/nsf-paper.git 78 | ``` 79 | 80 | This should only take a few seconds on an ordinary computer with a good internet connection. 81 | 82 | ## Instructions for use 83 | 84 | Data should be stored as a Scanpy AnnData object with the raw counts in the layer "counts" and spatial coordinates in the obsm["spatial"] slot. Utility functions to convert this into the required Tensorflow objects for model fitting are demonstrated in the demo. To reproduce results from the manuscript, use the numbered ipython scripts in each dataset's subfolder. Intermediate results from benchmarking are cached in `results/benchmark.csv` files which can be used to produce many of the plots in the manuscript. 85 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: nsf-paper 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - abseil-cpp=20220623.0=he4e09e4_6 6 | - absl-py=1.4.0=pyhd8ed1ab_0 7 | - aiohttp=3.8.4=py38hb192615_1 8 | - aiosignal=1.3.1=pyhd8ed1ab_0 9 | - anndata=0.9.1=pyhd8ed1ab_0 10 | - aom=3.5.0=h7ea286d_0 11 | - appnope=0.1.3=pyhd8ed1ab_0 12 | - arpack=3.7.0=h58ebc17_2 13 | - arrow-cpp=11.0.0=hce30654_5_cpu 14 | - asciitree=0.3.3=py_2 15 | - asttokens=2.2.1=pyhd8ed1ab_0 16 | - astunparse=1.6.3=pyhd8ed1ab_0 17 | - async-timeout=4.0.2=pyhd8ed1ab_0 18 | - attrs=23.1.0=pyh71513ae_1 19 | - aws-c-auth=0.6.24=he8f13b4_5 20 | - aws-c-cal=0.5.20=h9571af1_6 21 | - aws-c-common=0.8.11=h1a8c8d9_0 22 | - aws-c-compression=0.2.16=h7334ab6_3 23 | - aws-c-event-stream=0.2.18=ha663d55_6 24 | - aws-c-http=0.7.4=h49dec38_2 25 | - aws-c-io=0.13.17=h323b671_2 26 | - aws-c-mqtt=0.8.6=hdc0f556_6 27 | - aws-c-s3=0.2.4=hbb4c6b3_3 28 | - aws-c-sdkutils=0.1.7=h7334ab6_3 29 | - aws-checksums=0.1.14=h7334ab6_3 30 | - aws-crt-cpp=0.19.7=h6f6c549_7 31 | - aws-sdk-cpp=1.10.57=hbe10753_4 32 | - backcall=0.2.0=pyh9f0ad1d_0 33 | - backports=1.0=pyhd8ed1ab_3 34 | - backports.functools_lru_cache=1.6.5=pyhd8ed1ab_0 35 | - blinker=1.6.2=pyhd8ed1ab_0 36 | - blosc=1.21.4=hc338f07_0 37 | - bokeh=3.1.1=pyhd8ed1ab_0 38 | - brotli=1.0.9=h1a8c8d9_9 39 | - brotli-bin=1.0.9=h1a8c8d9_9 40 | - brotlipy=0.7.0=py38hb991d35_1005 41 | - brunsli=0.1=h9f76cd9_0 42 | - bzip2=1.0.8=h3422bc3_4 43 | - c-ares=1.19.1=hb547adb_0 44 | - c-blosc2=2.9.3=h068da5f_0 45 | - ca-certificates=2023.5.7=hf0a4a13_0 46 | - cached-property=1.5.2=hd8ed1ab_1 47 | - cached_property=1.5.2=pyha770c72_1 48 | - cachetools=5.3.0=pyhd8ed1ab_0 49 | - certifi=2023.5.7=pyhd8ed1ab_0 50 | - cffi=1.15.1=py38ha45ccd6_3 51 | - cfitsio=4.2.0=h2f961c4_0 52 | - charls=2.4.2=h13dd4ca_0 53 | - 
charset-normalizer=3.1.0=pyhd8ed1ab_0 54 | - click=8.1.3=unix_pyhd8ed1ab_2 55 | - cloudpickle=2.2.1=pyhd8ed1ab_0 56 | - colorama=0.4.6=pyhd8ed1ab_0 57 | - comm=0.1.3=pyhd8ed1ab_0 58 | - contourpy=1.1.0=py38h9afee92_0 59 | - cryptography=41.0.1=py38h92a0862_0 60 | - cycler=0.11.0=pyhd8ed1ab_0 61 | - cytoolz=0.12.0=py38hb991d35_1 62 | - dask=2023.5.0=pyhd8ed1ab_0 63 | - dask-core=2023.5.0=pyhd8ed1ab_0 64 | - dask-image=2023.3.0=pyhd8ed1ab_0 65 | - dav1d=1.2.1=hb547adb_0 66 | - debugpy=1.6.7=py38h2b1e499_0 67 | - decorator=5.1.1=pyhd8ed1ab_0 68 | - dill=0.3.6=pyhd8ed1ab_1 69 | - distributed=2023.5.0=pyhd8ed1ab_0 70 | - dm-tree=0.1.7=py38h55de146_0 71 | - docrep=0.3.2=pyh44b312d_0 72 | - entrypoints=0.4=pyhd8ed1ab_0 73 | - executing=1.2.0=pyhd8ed1ab_0 74 | - fasteners=0.17.3=pyhd8ed1ab_0 75 | - flatbuffers=22.12.06=hb7217d7_2 76 | - fonttools=4.40.0=py38hb192615_0 77 | - freetype=2.12.1=hd633e50_1 78 | - frozenlist=1.3.3=py38hb991d35_0 79 | - fsspec=2023.6.0=pyh1a96a4e_0 80 | - gast=0.4.0=pyh9f0ad1d_0 81 | - gflags=2.2.2=hc88da5d_1004 82 | - giflib=5.2.1=h1a8c8d9_3 83 | - glog=0.6.0=h6da1cb0_0 84 | - glpk=5.0=h6d7a090_0 85 | - gmp=6.2.1=h9f76cd9_0 86 | - google-auth=2.21.0=pyh1a96a4e_0 87 | - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0 88 | - google-pasta=0.2.0=pyh8c360ce_0 89 | - grpc-cpp=1.51.1=h44b9a77_1 90 | - grpcio=1.51.1=py38h171e7b7_1 91 | - h5py=3.9.0=nompi_py38h8a8aaa0_101 92 | - hdf5=1.14.1=nompi_h3aba7b3_100 93 | - icu=70.1=h6b3803e_0 94 | - idna=3.4=pyhd8ed1ab_0 95 | - igraph=0.10.3=h80e09cb_0 96 | - imagecodecs=2023.1.23=py38h57345ed_0 97 | - imageio=2.31.1=pyh24c5eb1_0 98 | - importlib-metadata=6.7.0=pyha770c72_0 99 | - importlib-resources=5.12.0=pyhd8ed1ab_0 100 | - importlib_metadata=6.7.0=hd8ed1ab_0 101 | - importlib_resources=5.12.0=pyhd8ed1ab_0 102 | - inflect=6.0.4=pyhd8ed1ab_0 103 | - ipykernel=6.23.3=pyh5fb750a_0 104 | - ipython=8.12.2=pyhd1c38e8_0 105 | - jax=0.3.25=pyhd8ed1ab_0 106 | - jaxlib=0.3.25=cpu_py38ha029f96_1 107 | - jedi=0.18.2=pyhd8ed1ab_0 108 | - jinja2=3.1.2=pyhd8ed1ab_1 109 | - joblib=1.3.0=pyhd8ed1ab_0 110 | - jpeg=9e=h1a8c8d9_3 111 | - jupyter_client=8.3.0=pyhd8ed1ab_0 112 | - jupyter_core=5.3.1=py38h10201cd_0 113 | - jxrlib=1.1=h27ca646_2 114 | - keras=2.11.0=pyhd8ed1ab_0 115 | - keras-preprocessing=1.1.2=pyhd8ed1ab_0 116 | - kiwisolver=1.4.4=py38h9dc3d6a_1 117 | - krb5=1.20.1=h69eda48_0 118 | - lazy_loader=0.2=pyhd8ed1ab_0 119 | - lcms2=2.15=h481adae_0 120 | - leidenalg=0.9.1=py38h2b1e499_0 121 | - lerc=4.0.0=h9a09cb3_0 122 | - libabseil=20220623.0=cxx17_h28b99d4_6 123 | - libaec=1.0.6=hb7217d7_1 124 | - libarrow=11.0.0=h0b9b5d1_5_cpu 125 | - libavif=0.11.1=h9f83d30_2 126 | - libblas=3.9.0=17_osxarm64_openblas 127 | - libbrotlicommon=1.0.9=h1a8c8d9_9 128 | - libbrotlidec=1.0.9=h1a8c8d9_9 129 | - libbrotlienc=1.0.9=h1a8c8d9_9 130 | - libcblas=3.9.0=17_osxarm64_openblas 131 | - libcrc32c=1.1.2=hbdafb3b_0 132 | - libcurl=8.1.2=h912dcd9_0 133 | - libcxx=16.0.6=h4653b0c_0 134 | - libdeflate=1.17=h1a8c8d9_0 135 | - libedit=3.1.20191231=hc8eb9b7_2 136 | - libev=4.33=h642e427_1 137 | - libevent=2.1.10=h7673551_4 138 | - libffi=3.4.2=h3422bc3_5 139 | - libgfortran=5.0.0=12_2_0_hd922786_31 140 | - libgfortran5=12.2.0=h0eea778_31 141 | - libgoogle-cloud=2.7.0=hcf11473_1 142 | - libgrpc=1.51.1=hb15be72_1 143 | - libiconv=1.17=he4db4b2_0 144 | - liblapack=3.9.0=17_osxarm64_openblas 145 | - libllvm11=11.1.0=hfa12f05_5 146 | - libllvm14=14.0.6=hd1a9a77_3 147 | - libnghttp2=1.52.0=hae82a92_0 148 | - libopenblas=0.3.23=openmp_hc731615_0 149 | - libpng=1.6.39=h76d750c_0 
150 | - libprotobuf=3.21.12=hb5ab8b9_0 151 | - libsodium=1.0.18=h27ca646_1 152 | - libsqlite=3.42.0=hb31c410_0 153 | - libssh2=1.11.0=h7a5bd25_0 154 | - libthrift=0.18.0=h6635e49_0 155 | - libtiff=4.5.0=h5dffbdd_2 156 | - libutf8proc=2.8.0=h1a8c8d9_0 157 | - libwebp-base=1.3.0=h1a8c8d9_0 158 | - libxcb=1.13=h9b22ae9_1004 159 | - libxml2=2.10.3=h67585b2_4 160 | - libzlib=1.2.13=h53f4e23_5 161 | - libzopfli=1.0.3=h9f76cd9_0 162 | - llvm-openmp=16.0.6=h1c12783_0 163 | - llvmlite=0.38.1=py38h8a5a59d_0 164 | - locket=1.0.0=pyhd8ed1ab_0 165 | - lz4=4.3.2=py38h76a69a3_0 166 | - lz4-c=1.9.4=hb7217d7_0 167 | - markdown=3.4.3=pyhd8ed1ab_0 168 | - markupsafe=2.1.3=py38hb192615_0 169 | - matplotlib=3.7.1=py38h150bfb4_0 170 | - matplotlib-base=3.7.1=py38hbbe890c_0 171 | - matplotlib-inline=0.1.6=pyhd8ed1ab_0 172 | - matplotlib-scalebar=0.8.1=pyhd8ed1ab_0 173 | - metis=5.1.0=h9f76cd9_1006 174 | - mpfr=4.2.0=he09a6ba_0 175 | - msgpack-python=1.0.5=py38h9dc3d6a_0 176 | - multidict=6.0.4=py38hb991d35_0 177 | - munkres=1.1.4=pyh9f0ad1d_0 178 | - natsort=8.4.0=pyhd8ed1ab_0 179 | - ncurses=6.4=h7ea286d_0 180 | - nest-asyncio=1.5.6=pyhd8ed1ab_0 181 | - networkx=3.1=pyhd8ed1ab_0 182 | - numba=0.55.2=py38h25e2f74_0 183 | - numcodecs=0.11.0=py38h2b1e499_1 184 | - numpy=1.22.4=py38he1fcd3f_0 185 | - oauthlib=3.2.2=pyhd8ed1ab_0 186 | - omnipath=1.0.7=pyhd8ed1ab_0 187 | - openjpeg=2.5.0=hbc2ba62_2 188 | - openssl=3.1.1=h53f4e23_1 189 | - opt_einsum=3.3.0=pyhd8ed1ab_1 190 | - orc=1.8.2=hef0d403_2 191 | - packaging=23.1=pyhd8ed1ab_0 192 | - pandas=1.5.3=py38h61dac83_1 193 | - parquet-cpp=1.5.1=2 194 | - parso=0.8.3=pyhd8ed1ab_0 195 | - partd=1.4.0=pyhd8ed1ab_0 196 | - patsy=0.5.3=pyhd8ed1ab_0 197 | - pexpect=4.8.0=pyh1a96a4e_2 198 | - pickleshare=0.7.5=py_1003 199 | - pillow=9.4.0=py38h1bb68ce_1 200 | - pims=0.6.1=pyhd8ed1ab_1 201 | - pip=23.1.2=pyhd8ed1ab_0 202 | - platformdirs=2.6.0=pyhd8ed1ab_0 203 | - prompt-toolkit=3.0.38=pyha770c72_0 204 | - prompt_toolkit=3.0.38=hd8ed1ab_0 205 | - protobuf=4.21.12=py38h2b1e499_0 206 | - psutil=5.9.5=py38hb991d35_0 207 | - pthread-stubs=0.4=h27ca646_1001 208 | - ptyprocess=0.7.0=pyhd3deb0d_0 209 | - pure_eval=0.2.2=pyhd8ed1ab_0 210 | - pyarrow=11.0.0=py38h32b283d_5_cpu 211 | - pyasn1=0.4.8=py_0 212 | - pyasn1-modules=0.2.7=py_0 213 | - pycparser=2.21=pyhd8ed1ab_0 214 | - pydantic=1.9.2=py38he5c2ac2_0 215 | - pygments=2.15.1=pyhd8ed1ab_0 216 | - pyjwt=2.7.0=pyhd8ed1ab_0 217 | - pynndescent=0.5.10=pyh1a96a4e_0 218 | - pyopenssl=23.2.0=pyhd8ed1ab_1 219 | - pyparsing=3.1.0=pyhd8ed1ab_0 220 | - pysocks=1.7.1=pyha2e5f31_6 221 | - python=3.8.17=h3ba56d0_0_cpython 222 | - python-dateutil=2.8.2=pyhd8ed1ab_0 223 | - python-flatbuffers=23.5.26=pyhd8ed1ab_0 224 | - python-igraph=0.10.4=py38h0504639_0 225 | - python-tzdata=2023.3=pyhd8ed1ab_0 226 | - python_abi=3.8=3_cp38 227 | - pytz=2023.3=pyhd8ed1ab_0 228 | - pyu2f=0.1.5=pyhd8ed1ab_0 229 | - pywavelets=1.4.1=py38hb39dbe9_0 230 | - pyyaml=6.0=py38hb991d35_5 231 | - pyzmq=25.1.0=py38hef91016_0 232 | - re2=2023.02.01=hb7217d7_0 233 | - readline=8.2=h92ec313_1 234 | - requests=2.31.0=pyhd8ed1ab_0 235 | - requests-oauthlib=1.3.1=pyhd8ed1ab_0 236 | - rsa=4.9=pyhd8ed1ab_0 237 | - scanpy=1.9.3=pyhd8ed1ab_0 238 | - scikit-image=0.20.0=py38hfaca753_1 239 | - scikit-learn=1.2.2=py38h971c870_2 240 | - scipy=1.9.1=py38h3aeb131_0 241 | - seaborn=0.12.2=hd8ed1ab_0 242 | - seaborn-base=0.12.2=pyhd8ed1ab_0 243 | - session-info=1.0.0=pyhd8ed1ab_0 244 | - setuptools=68.0.0=pyhd8ed1ab_0 245 | - six=1.15.0=pyh9f0ad1d_0 246 | - slicerator=1.1.0=pyhd8ed1ab_0 
247 | - snappy=1.1.10=h17c5cce_0 248 | - sortedcontainers=2.4.0=pyhd8ed1ab_0 249 | - spyder-kernels=2.4.3=unix_pyhd8ed1ab_0 250 | - sqlite=3.42.0=h203b68d_0 251 | - squidpy=1.2.3=pyhd8ed1ab_0 252 | - stack_data=0.6.2=pyhd8ed1ab_0 253 | - statsmodels=0.14.0=py38h58b515d_1 254 | - stdlib-list=0.8.0=pyhd8ed1ab_0 255 | - suitesparse=5.10.1=h7cd81ec_1 256 | - tbb=2021.9.0=hffc8910_0 257 | - tblib=1.7.0=pyhd8ed1ab_0 258 | - tensorboard=2.11.2=pyhd8ed1ab_0 259 | - tensorboard-data-server=0.6.1=py38h23f6d3d_4 260 | - tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0 261 | - tensorflow=2.11.0=cpu_py38h4487131_0 262 | - tensorflow-base=2.11.0=cpu_py38h968a1bb_0 263 | - tensorflow-estimator=2.11.0=cpu_py38h8d8ceda_0 264 | - tensorflow-probability=0.19.0=pyhd8ed1ab_1 265 | - termcolor=1.1.0=pyhd8ed1ab_3 266 | - texttable=1.6.7=pyhd8ed1ab_0 267 | - threadpoolctl=3.1.0=pyh8a188c0_0 268 | - tifffile=2023.4.12=pyhd8ed1ab_0 269 | - tk=8.6.12=he1e0b03_0 270 | - toolz=0.12.0=pyhd8ed1ab_0 271 | - tornado=6.3.2=py38hb192615_0 272 | - tqdm=4.65.0=pyhd8ed1ab_1 273 | - traitlets=5.9.0=pyhd8ed1ab_0 274 | - typing-extensions=3.7.4.3=0 275 | - typing_extensions=3.7.4.3=py_0 276 | - umap-learn=0.5.3=py38h10201cd_1 277 | - unicodedata2=15.0.0=py38hb991d35_0 278 | - urllib3=1.26.15=pyhd8ed1ab_0 279 | - validators=0.20.0=pyhd8ed1ab_0 280 | - wcwidth=0.2.6=pyhd8ed1ab_0 281 | - werkzeug=2.3.6=pyhd8ed1ab_0 282 | - wheel=0.40.0=pyhd8ed1ab_0 283 | - wrapt=1.12.1=py38hea4295b_3 284 | - wurlitzer=3.0.3=pyhd8ed1ab_0 285 | - xarray=2023.1.0=pyhd8ed1ab_0 286 | - xorg-libxau=1.0.11=hb547adb_0 287 | - xorg-libxdmcp=1.1.3=h27ca646_0 288 | - xyzservices=2023.5.0=pyhd8ed1ab_1 289 | - xz=5.2.6=h57fd34a_0 290 | - yaml=0.2.5=h3422bc3_2 291 | - yarl=1.9.2=py38hb192615_0 292 | - zarr=2.15.0=pyhd8ed1ab_0 293 | - zeromq=4.3.4=hbdafb3b_1 294 | - zfp=1.0.0=hb6e4faa_3 295 | - zict=3.0.0=pyhd8ed1ab_0 296 | - zipp=3.15.0=pyhd8ed1ab_0 297 | - zlib=1.2.13=h53f4e23_5 298 | - zlib-ng=2.0.7=h1a8c8d9_0 299 | - zstd=1.5.2=hf913c23_6 300 | prefix: /opt/homebrew/Caskroom/mambaforge/base/envs/nsf-paper 301 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/willtownes/nsf-paper/0cacf8352e09d223ab8d4421025195358bbde8df/models/__init__.py -------------------------------------------------------------------------------- /models/likelihoods.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Jun 29 10:52:16 2021 5 | 6 | @author: townesf 7 | """ 8 | from numpy import tile 9 | import tensorflow_probability as tfp 10 | tfb = tfp.bijectors 11 | tfd = tfp.distributions 12 | tv = tfp.util.TransformedVariable 13 | 14 | 15 | class InvalidLikelihoodError(ValueError): 16 | pass 17 | 18 | def init_lik(lik, J, disp="default", dtp="float32"): 19 | """ 20 | Given a likelihood, number of features (J), 21 | and a dispersion parameter initialization value, 22 | Return a TransformedVariable representing a vector of trainable, 23 | feature-specific dispersion parameters. 
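  For the Poisson likelihood ("poi") there is no dispersion parameter and None is returned.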
24 | """ 25 | if lik=="gau": 26 | #disp is the scale (st dev) of the normal distribution 27 | if disp=="default": disp=1.0 28 | elif lik=="nb": 29 | #var= mu + disp*mu^2, disp->0 is Poisson limit 30 | if disp=="default": disp=0.01 31 | elif lik=="poi": #lik="poi" 32 | disp = None 33 | else: 34 | raise InvalidLikelihoodError 35 | if disp is not None: #ie lik in ("gau","nb") 36 | disp = tv(tile(disp,J), tfb.Softplus(), dtype=dtp, name="dispersion") 37 | return disp 38 | 39 | def lik_to_distr(lik, mu, disp): 40 | """ 41 | Given a likelihood and a tensorflow TransformedVariable 'disp', 42 | Return a tensorflow probability distribution object with means 'pred_means' 43 | """ 44 | if lik=="poi": 45 | return tfd.Poisson(mu) 46 | elif lik=="gau": 47 | return tfd.Normal(mu, disp) 48 | elif lik=="nb": 49 | return tfd.NegativeBinomial.experimental_from_mean_dispersion(mu, disp) 50 | else: 51 | raise InvalidLikelihoodError 52 | 53 | def choose_nmf_pars(lik): 54 | if lik in ("poi","nb"): 55 | return {"beta_loss":"kullback-leibler", "solver":"mu", "init":"nndsvda"} 56 | elif lik=="gau": 57 | return {"beta_loss":"frobenius", "solver":"cd", "init":"nndsvd"} 58 | else: 59 | raise InvalidLikelihoodError 60 | -------------------------------------------------------------------------------- /models/mefisto.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | MEFISTO wrapper functions 5 | 6 | Created on Mon Jul 19 18:25:31 2021 7 | 8 | @author: townesf 9 | """ 10 | #import json 11 | import numpy as np 12 | from os import path 13 | from time import time,process_time 14 | #from mofax import mofa_model 15 | from mofapy2.run.entry_point import entry_point 16 | 17 | from utils import misc 18 | # from models.likelihoods import InvalidLikelihoodError 19 | 20 | 21 | class MEFISTO(object): 22 | def __init__(self, Dtr, num_factors, inducing_pts=1000, quiet=False, 23 | pickle_path=None): #lik="gau" 24 | """Dtr should be normalized and log transformed data""" 25 | #https://nbviewer.jupyter.org/github/bioFAM/MEFISTO_tutorials/blob/master/MEFISTO_ST.ipynb 26 | # if lik=="gau": lik = "gaussian" 27 | # elif lik=="poi": lik = "poisson" 28 | # else: raise InvalidLikelihoodError 29 | lik = "gaussian" 30 | ent = entry_point() 31 | ent.set_data_options(use_float32=True, center_groups=True) 32 | ent.set_data_matrix([[Dtr["Y"]]],likelihoods=lik) 33 | #ent.set_data_from_anndata(adtr, likelihoods="gaussian") 34 | ent.set_model_options(factors=num_factors,ard_weights=False) 35 | ent.set_train_options(quiet=quiet)#iter=ne) 36 | ent.set_covariates([Dtr["X"]]) 37 | #ent.set_covariates([adtr.obsm["spatial"]]) 38 | Mfrac = float(inducing_pts)/Dtr["X"].shape[0] 39 | if Mfrac<1.0: 40 | ent.set_smooth_options(sparseGP=True, frac_inducing=Mfrac) 41 | else: 42 | ent.set_smooth_options(sparseGP=False) 43 | #ent.set_stochastic_options(learning_rate=0.5) 44 | ent.build() 45 | self.ent = ent 46 | self.ptime = 0.0 47 | self.wtime = 0.0 48 | self.elbos = np.array([]) 49 | self.epoch = 0 50 | self.set_pickle_path(pickle_path) 51 | self.converged=False 52 | self.feature_means = Dtr["Y"].mean(axis=0) 53 | self.L = num_factors 54 | self.M = inducing_pts 55 | 56 | def init_loadings(self,*args,**kwargs): 57 | pass #only here for compatibility with PF, CF, etc 58 | 59 | def set_pickle_path(self,pickle_path): 60 | if pickle_path is not None: 61 | misc.mkdir_p(pickle_path) 62 | self.pickle_path=pickle_path 63 | 64 | def generate_pickle_path(self,base=None): 65 | 
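    # Encode the key hyperparameters (number of factors L, likelihood, model name, kernel, and inducing points M) into a subdirectory name used for caching pickled fits.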
pars = {"L":self.L, 66 | "lik":"gau", 67 | "model":"MEFISTO", 68 | "kernel":"ExponentiatedQuadratic", 69 | "M":self.M 70 | } 71 | pth = misc.params2key(pars) 72 | if base: pth = path.join(base,pth) 73 | return pth 74 | 75 | def pickle(self): 76 | """ 77 | *args passed to update_times method 78 | """ 79 | if self.converged: 80 | fname = "converged.pickle" 81 | else: 82 | fname = "epoch{}.pickle".format(self.epoch) 83 | misc.pickle_to_file(self, path.join(self.pickle_path,fname)) 84 | 85 | @staticmethod 86 | def from_pickle(pth,epoch=None): 87 | if epoch: 88 | fname = "epoch{}.pickle".format(epoch) 89 | else: 90 | fname = "converged.pickle" 91 | return misc.unpickle_from_file(path.join(pth,fname)) 92 | 93 | def train(self): 94 | ptic = process_time() 95 | wtic = time() 96 | self.ent.run() 97 | self.converged = self.ent.model.trained 98 | elbos = self.ent.model.getTrainingStats()["elbo"] 99 | self.elbos = elbos[~np.isnan(elbos)] 100 | self.epoch = max(len(self.elbos)-1, 0) 101 | self.ptime = process_time()-ptic 102 | self.wtime = time()-wtic 103 | if self.pickle_path: self.pickle() 104 | 105 | def get_loadings(self): 106 | return self.ent.model.nodes["W"].getExpectations()[0]["E"] 107 | 108 | def get_factors(self): 109 | return self.ent.model.nodes["Z"].getExpectations()["E"] 110 | 111 | def _reverse_normalization(self,Lambda,sz): 112 | return misc.reverse_normalization(Lambda, feature_means=self.feature_means, 113 | transform=np.expm1, sz=sz) 114 | 115 | def predict(self,Dtr,Dval=None,S=None): 116 | """ 117 | Here Dtr,Dval should be raw counts (not normalized or log-transformed) 118 | 119 | returns the predicted training data mean and validation data mean 120 | on the original count scale 121 | 122 | S is not used, only here for compatibility with visualization functions 123 | """ 124 | Wt = self.get_loadings().T 125 | Ftr = self.get_factors() 126 | sz_tr = Dtr["Y"].sum(axis=1) 127 | Mu_tr = self._reverse_normalization(Ftr@Wt, sz=sz_tr) 128 | if Dval: 129 | Xu,idx = np.unique(Dval["X"],axis=0,return_inverse=True) 130 | if Xu.shape[0]0 51 | if T is None: T = ceil(L/2.) 
52 | assert T>0 and T% group_by(data) %>% summarize(med=median(spatial_wt),avg=mean(spatial_wt)) %>% arrange(med) 43 | sg$data<-factor(sg$data,levels=levs) 44 | ggplot(sg,aes(x=data,y=spatial_wt))+geom_boxplot()+xlab("dataset")+ylab("spatial score per gene") 45 | ggsave(fp(plt_pth,"gene_scores.pdf"),width=6,height=4) 46 | for(d in levs){ 47 | pd<-subset(sg,data==d) 48 | ggplot(pd,aes(x=spatial_wt))+geom_histogram(bins=50,fill="darkblue",color="darkblue")+xlab("spatial importance score")+ylab("number of genes") 49 | ggsave(fp("scrna",d,"results/plots",paste0(d,"_gene_spatial_score.pdf")),width=4,height=2.7) 50 | } 51 | ``` 52 | Compare to Hotspot 53 | ```{r} 54 | hs<-do.call(rbind,lapply(names(LL),load_hotspot)) 55 | hs$data<-factor(hs$data,levels=levs) 56 | colnames(hs)[which(colnames(hs)=="Gene")]<-"gene" 57 | hs$hs_spatial<-hs$FDR<0.05 58 | sg$nsfh_spatial<-sg$spatial_wt>0.5 59 | # with(sg,tapply(is_spatial,data,mean)) 60 | res<-merge(sg,hs,by=c("data","gene")) 61 | with(res,table(nsfh_spatial,hs_spatial,data)) 62 | res %>% group_by(data) %>% summarise(corr = cor(spatial_wt,Z,method="spearman")) 63 | ``` 64 | 65 | Spatial scores per cell 66 | 67 | ```{r} 68 | sc<-do.call(rbind,lapply(names(LL),load_csv,"spatial_cell_weights")) 69 | sc$data<-factor(sc$data,levels=levs) 70 | ggplot(sc,aes(x=data,y=spatial_wt))+geom_boxplot()+xlab("dataset")+ylab("spatial score per observation") 71 | ggsave(fp(plt_pth,"obs_scores.pdf"),width=6,height=4) 72 | for(d in levs){ 73 | pd<-subset(sc,data==d) 74 | ggplot(pd,aes(x=spatial_wt))+geom_histogram(bins=50,fill="darkblue",color="darkblue")+xlab("spatial importance score")+ylab("number of observations") 75 | ggsave(fp("scrna",d,"results/plots",paste0(d,"_obs_spatial_score.pdf")),width=4,height=2.7) 76 | } 77 | ``` 78 | 79 | Importance of spatial factors (SPDE style) 80 | 81 | ```{r} 82 | sgi<-do.call(rbind,lapply(names(LL),load_csv2,"dim_weights_spde")) 83 | sgi$data<-factor(sgi$data,levels=levs) 84 | ggplot(sgi,aes(x=id,y=weight,fill=factor_type))+geom_bar(stat="identity")+facet_wrap(~data,nrow=2,scales="free")+theme(legend.position="top")+xlab("factor")+ylab("importance") 85 | ggsave(fp(plt_pth,"importance_spde.pdf"),width=5,height=4) 86 | ``` 87 | 88 | Importance of spatial factors (LDA style) 89 | 90 | ```{r} 91 | sci<-do.call(rbind,lapply(names(LL),load_csv2,"dim_weights_lda")) 92 | sci$data<-factor(sci$data,levels=levs) 93 | ggplot(sci,aes(x=id,y=weight,fill=factor_type))+geom_bar(stat="identity")+facet_wrap(~data,nrow=2,scales="free")+theme(legend.position="none")+xlab("factor")+ylab("importance") 94 | ggsave(fp(plt_pth,"importance_lda.pdf"),width=5,height=3) 95 | ``` 96 | 97 | spatial autocorrelation of components 98 | 99 | ```{r} 100 | #sshippo 101 | dac<-read.csv(fp(pth,"sshippo/results/NSFH_dim_autocorr_spde_L20_T10.csv")) 102 | dac$type<-rep(c("spatial","nonspatial"),each=10) 103 | #dac$id<-as.character(rep(1:10,2)) 104 | dac$component<-factor(dac$component,levels=dac$component) 105 | ggplot(dac,aes(x=component,y=moran_i,fill=type))+geom_bar(stat="identity")+ylab("Moran's I")+theme(legend.position="none") 106 | ggsave(fp(pth,"results/plots/sshippo_L20_moranI.pdf"),width=5,height=3) 107 | 108 | #xyzeq 109 | dac<-read.csv(fp(pth,"xyzeq_liver/results/NSFH_dim_autocorr_spde_L6_T3.csv")) 110 | dac$type<-rep(c("spatial","nonspatial"),each=3) 111 | #dac$id<-as.character(rep(1:10,2)) 112 | dac$component<-factor(dac$component,levels=dac$component) 113 | ggplot(dac,aes(x=component,y=moran_i,fill=type))+geom_bar(stat="identity")+ylab("Moran's 
I")+theme(legend.position="none") 114 | ggsave(fp(pth,"results/plots/xyz_liv_L6_moranI.pdf"),width=3,height=1.75) 115 | 116 | #visium 117 | dac<-read.csv(fp(pth,"visium_brain_sagittal/results/NSFH_dim_autocorr_spde_L20_T10.csv")) 118 | dac$type<-rep(c("spatial","nonspatial"),each=10) 119 | #dac$id<-as.character(rep(1:10,2)) 120 | dac$component<-factor(dac$component,levels=dac$component) 121 | ggplot(dac,aes(x=component,y=moran_i,fill=type))+geom_bar(stat="identity")+ylab("Moran's I")+theme(legend.position="none") 122 | ggsave(fp(pth,"results/plots/vz_brn_L20_moranI.pdf"),width=5,height=3) 123 | 124 | dac<-read.csv(fp(pth,"visium_brain_sagittal/results/NSFH_dim_autocorr_spde_L60_T30.csv")) 125 | dac$type<-rep(c("spatial","nonspatial"),each=30) 126 | #dac$id<-as.character(rep(1:10,2)) 127 | dac$component<-factor(dac$component,levels=dac$component) 128 | ggplot(dac,aes(x=component,y=moran_i,fill=type))+geom_bar(stat="identity")+ylab("Moran's I")+theme(legend.position="none", axis.text.x=element_text(angle=45,hjust=1)) 129 | ggsave(fp(pth,"results/plots/vz_brn_L60_moranI.pdf"),width=5.5,height=3) 130 | ``` -------------------------------------------------------------------------------- /scrna/sshippo/01_data_loading.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Slide-seq v2 hippocampus" 3 | author: "Will Townes" 4 | output: html_document 5 | --- 6 | 7 | ```{r} 8 | library(Seurat) 9 | library(SeuratData) #remotes::install_github("satijalab/seurat-data") 10 | library(SeuratDisk) #remotes::install_github("mojaveazure/seurat-disk") 11 | library(scry) #bioconductor 12 | 13 | fp<-file.path 14 | pth<-"scrna/sshippo" 15 | dpth<-fp(pth,"data") 16 | if(!dir.exists(fp(dpth,"original"))){ 17 | dir.create(fp(dpth,"original"),recursive=TRUE) 18 | } 19 | ``` 20 | 21 | Download the Seurat dataset 22 | 23 | ```{r} 24 | #InstallData("ssHippo") #run once to store to disk 25 | slide.seq <- LoadData("ssHippo") 26 | ``` 27 | 28 | rank genes by deviance 29 | 30 | ```{r} 31 | X<-slide.seq@images[[1]]@coordinates 32 | slide.seq<-AddMetaData(slide.seq,X) 33 | Y<-slide.seq@assays[[1]]@counts 34 | #Y<-Y[rowSums(Y)>0,] 35 | dev<-devianceFeatureSelection(Y,fam="poisson") 36 | dev[is.na(dev)]<-0 37 | slide.seq@assays[[1]]<-AddMetaData(slide.seq@assays[[1]], dev, col.name="deviance_poisson") 38 | 39 | o<-order(dev,decreasing=TRUE) 40 | #Y<-Y[o,] 41 | #dev<-dev[o] 42 | plot(dev[o],type="l",log="y") 43 | abline(v=1000,col="red") 44 | ``` 45 | 46 | Converting to H5AD, based on https://mojaveazure.github.io/seurat-disk/articles/convert-anndata.html 47 | 48 | ```{r} 49 | dfile<-fp(dpth,"original/sshippo.h5Seurat") 50 | SaveH5Seurat(slide.seq, filename=dfile) 51 | Convert(dfile, dest="h5ad") 52 | ``` 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /scrna/sshippo/02_data_loading.ipy: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # --- 3 | # jupyter: 4 | # jupytext: 5 | # text_representation: 6 | # extension: .ipy 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.6.0 10 | # kernelspec: 11 | # display_name: Python 3 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | # %% [markdown] 17 | """ 18 | First run 01_data_loading.Rmd to get data via Seurat, compute 19 | Poisson deviance for each gene, and export to H5AD. 
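The exported file is read below from scrna/sshippo/data/original/sshippo.h5ad.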
20 | """ 21 | 22 | # %% imports 23 | import random 24 | import numpy as np 25 | import scanpy as sc 26 | from os import path 27 | 28 | random.seed(101) 29 | pth = "scrna/sshippo" 30 | 31 | # %% load the pre-processed dataset 32 | #ad = sq.datasets.slideseqv2() 33 | ad = sc.read_h5ad(path.join(pth,"data/original/sshippo.h5ad")) 34 | 35 | # %% Desiderata for dataset [markdown] 36 | # 1. Spatial coordinates 37 | # 2. Features sorted in decreasing order of deviance 38 | # 3. Observations randomly shuffled 39 | 40 | #%% organize anndata 41 | ad.obsm['spatial'] = ad.obs[["x","y"]].to_numpy() 42 | ad.obs.drop(columns=["x","y"],inplace=True) 43 | ad.X = ad.raw.X 44 | ad.raw = None 45 | 46 | # %% QC, loosely following MEFISTO tutorials 47 | # https://nbviewer.jupyter.org/github/bioFAM/MEFISTO_tutorials/blob/master/MEFISTO_ST.ipynb#QC-and-preprocessing 48 | #ad.var_names_make_unique() 49 | ad.var["mt"] = ad.var_names.str.startswith("MT-") 50 | sc.pp.calculate_qc_metrics(ad, qc_vars=["mt"], inplace=True) 51 | ad.obs.pct_counts_mt.hist(bins=100) 52 | ad = ad[ad.obs.pct_counts_mt < 20] #from 53K to 45K 53 | tc = ad.obs.total_counts 54 | tc.hist(bins=100) 55 | tc[tc<500].hist(bins=100) 56 | (tc<100).sum() #8000 spots 57 | sc.pp.filter_cells(ad, min_counts=100) 58 | sc.pp.filter_genes(ad, min_cells=1) 59 | ad.layers = {"counts":ad.X.copy()} #store raw counts before normalization changes ad.X 60 | sc.pp.normalize_total(ad, inplace=True, layers=None, key_added="sizefactor") 61 | sc.pp.log1p(ad) 62 | 63 | # %% sort by deviance 64 | o = np.argsort(-ad.var['deviance_poisson']) 65 | idx = list(range(ad.shape[0])) 66 | random.shuffle(idx) 67 | ad = ad[idx,o] 68 | ad.var["deviance_poisson"].plot() 69 | ad.write_h5ad(path.join(pth,"data/sshippo.h5ad"),compression="gzip") 70 | #ad = sc.read_h5ad(path.join(pth,"data/sshippo.h5ad")) 71 | ad2 = ad[:,:2000] 72 | ad2.write_h5ad(path.join(pth,"data/sshippo_J2000.h5ad"),compression="gzip") 73 | -------------------------------------------------------------------------------- /scrna/sshippo/03_benchmark.ipy: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # --- 3 | # jupyter: 4 | # jupytext: 5 | # text_representation: 6 | # extension: .ipy 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.6.0 10 | # kernelspec: 11 | # display_name: Python 3 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | #%% 17 | import pandas as pd 18 | from os import path 19 | from utils import misc,benchmark 20 | 21 | pth = "scrna/sshippo" 22 | dpth = path.join(pth,"data") 23 | mpth = path.join(pth,"models/V5") 24 | rpth = path.join(pth,"results") 25 | misc.mkdir_p(rpth) 26 | 27 | #%% Create CSV with benchmarking parameters 28 | csv_path = path.join(rpth,"benchmark.csv") 29 | try: 30 | par = pd.read_csv(csv_path) 31 | except FileNotFoundError: 32 | L = [6,12,20] 33 | sp_mods = ["NSF-P", "NSF-N", "NSF-G", "RSF-G", "NSFH-P", "NSFH-N"] 34 | ns_mods = ["PNMF-P", "PNMF-N", "FA-G"] 35 | sz = ["constant","scanpy"] 36 | M = [1000,2000,3000] 37 | par = benchmark.make_param_df(L,sp_mods,ns_mods,M,sz) 38 | par.to_csv(csv_path,index=False) 39 | 40 | #%% Benchmarking at command line [markdown] 41 | """ 42 | To run on local computer, use 43 | `python -m utils.benchmark 2 scrna/sshippo/data/sshippo_J2000.h5ad` 44 | where 2 is a row ID of benchmark.csv, min value 2, max possible value is 115 45 | 46 | To run on cluster first load anaconda environment 47 | ``` 48 | tmux 49 | interactive 50 | module load 
anaconda3/2021.5 51 | conda activate fwt 52 | python -m utils.benchmark 2 scrna/sshippo/data/sshippo_J2000.h5ad 53 | ``` 54 | 55 | To run on cluster as a job array, subset of rows. Recommend 24hr time limit. 56 | ``` 57 | DAT=./scrna/sshippo/data/sshippo_J2000.h5ad 58 | sbatch --mem=180G --array=61,64 ./utils/benchmark_array.slurm $DAT 59 | ``` 60 | 61 | To run on cluster as a job array, all rows of CSV file 62 | ``` 63 | CSV=./scrna/sshippo/results/benchmark.csv 64 | DAT=./scrna/sshippo/data/sshippo_J2000.h5ad 65 | sbatch --mem=180G --array=2-$(wc -l < $CSV) ./utils/benchmark_array.slurm $DAT 66 | ``` 67 | """ 68 | 69 | #%% Compute metrics for each model 70 | """ 71 | DAT=./scrna/sshippo/data/sshippo_J2000.h5ad 72 | sbatch --mem=180G ./utils/benchmark_gof.slurm $DAT 5 73 | """ 74 | 75 | #%% Examine one result 76 | from matplotlib import pyplot as plt 77 | from utils import training 78 | tro = training.ModelTrainer.from_pickle(path.join(mpth,"L20/poi/NSFH_T10_MaternThreeHalves_M1000")) 79 | plt.plot(tro.loss["train"][-200:-1]) 80 | -------------------------------------------------------------------------------- /scrna/sshippo/05_benchmark_viz.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Benchmarking Visualization" 3 | author: "Will Townes" 4 | output: html_document 5 | --- 6 | 7 | ```{r} 8 | library(tidyverse) 9 | theme_set(theme_bw()) 10 | fp<-file.path 11 | pth<-"scrna/sshippo" 12 | plt_pth<-fp(pth,"results/plots") 13 | if(!dir.exists(plt_pth)){ 14 | dir.create(plt_pth,recursive=TRUE) 15 | } 16 | ``` 17 | 18 | Convergence failures 19 | * MEFISTO with L>6 or M>2000 (ran out of memory) 20 | 21 | ```{r} 22 | d0<-read.csv(fp(pth,"results/benchmark.csv")) 23 | d0$converged<-as.logical(d0$converged) 24 | #d0$model<-plyr::mapvalues(d0$model,c("FA","RSF","PNMF","NSFH","NSF"),c("FA","RSF","PNMF","NSFH","NSF")) 25 | d<-subset(d0,converged==TRUE) 26 | da<-subset(d,model %in% c("FA","MEFISTO","RSF")) 27 | db<-subset(d,(model %in% c("PNMF","NSFH","NSF")) & (sz=="scanpy")) 28 | d<-rbind(da,db) 29 | d$model<-paste0(d$model," (",d$lik,")") 30 | unique(d$model) 31 | keep<-c("FA (gau)","MEFISTO (gau)","RSF (gau)","PNMF (nb)","PNMF (poi)","NSFH (nb)","NSFH (poi)","NSF (nb)","NSF (poi)") 32 | d<-subset(d,model %in% keep) 33 | d$model<-factor(d$model,levels=keep) 34 | d$dim<-factor(d$L,levels=sort(unique(d$L))) 35 | d$M[is.na(d$M)]<-3000 36 | d$IPs<-factor(d$M,levels=sort(unique(d$M))) 37 | ``` 38 | subset of models for simplified main figure 39 | ```{r} 40 | d2<-subset(d,M==2000 | (model %in% c("PNMF (poi)","FA (gau)"))) 41 | d2a<-subset(d2,(model %in% c("PNMF (poi)","NSFH (poi)","NSF (poi)")) & sz=="scanpy") 42 | d2b<-subset(d2,model %in% c("FA (gau)","RSF (gau)","MEFISTO (gau)")) 43 | d2<-rbind(d2a,d2b) 44 | d2$model<-factor(d2$model,levels=) 45 | d2$model<-plyr::mapvalues(d2$model,c("FA (gau)","MEFISTO (gau)","RSF (gau)","PNMF (poi)","NSFH (poi)","NSF (poi)"),c("FA","MEFISTO","RSF","PNMF","NSFH","NSF")) 46 | 47 | ggplot(d2,aes(x=model,y=dev_tr_mean,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.8))+ylab("training deviance (mean)")+ylim(range(d2$dev_tr_mean)*c(.95,1.02)) 48 | ggsave(fp(plt_pth,"sshippo_gof_dev_tr_mean_main.pdf"),width=5,height=2.5) 49 | 50 | ggplot(d2,aes(x=model,y=dev_val_mean,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.8))+ylab("validation deviance (mean)")+ylim(range(d2$dev_val_mean)*c(.95,1.02)) 51 | 
ggsave(fp(plt_pth,"sshippo_gof_dev_val_mean_main.pdf"),width=5,height=2.5) 52 | 53 | ggplot(d2,aes(x=model,y=rmse_val,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.8))+ylab("validation RMSE")+ylim(range(d2$rmse_val)*c(.95,1.02)) 54 | ggsave(fp(plt_pth,"sshippo_gof_rmse_val_simple.pdf"),width=5,height=2.5) 55 | 56 | #linear regressions for statistical significance 57 | d2$realvalued<-d2$model %in% c("FA","MEFISTO","RSF") 58 | #d2$dim_numeric<-as.numeric(as.character(d2$dim)) 59 | summary(lm(dev_val_mean~realvalued+dim,data=d2)) 60 | d2$unsupervised<-d2$model %in% c("FA","PNMF") 61 | d2$unsupervised[d2$model=="MEFISTO"]<-NA 62 | summary(lm(dev_val_mean~realvalued+dim+unsupervised,data=d2)) 63 | t.test(dev_val_mean~model,data=subset(d2,model %in% c("NSFH","NSF")), var.equal=TRUE,alternative="greater") 64 | 65 | #time 66 | ggplot(d2,aes(x=model,y=wtime/60,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.8))+ylab("time to converge (min)")+ylim(range(d2$wtime/60)*c(.95,1.02)) 67 | ggsave(fp(plt_pth,"sshippo_wtime_simple.pdf"),width=5,height=2.5) 68 | 69 | #sparsity 70 | ggplot(d2,aes(x=model,y=sparsity,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.8))+ylab("loadings zero fraction")+ylim(range(d2$sparsity)*c(.95,1.02)) 71 | ggsave(fp(plt_pth,"sshippo_sparsity_simple.pdf"),width=5,height=2.5) 72 | 73 | #time 74 | ggplot(d2,aes(x=model,y=wtime/60,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.8))+ylab("time to converge (min)")+ylim(range(d2$wtime/60)*c(.95,1.02)) 75 | ggsave(fp(plt_pth,"sshippo_wtime_simple.pdf"),width=5,height=2.5) 76 | ``` 77 | NB vs Poi likelihood 78 | ```{r} 79 | d3<-subset(d,lik %in% c("poi","nb")) 80 | 81 | ggplot(d3,aes(x=model,y=dev_tr_mean,color=dim,shape=IPs))+geom_point(size=4,position=position_dodge(width=0.6))+ylab("training deviance (mean)") 82 | ggsave(fp(plt_pth,"sshippo_gof_dev_tr_nb_vs_poi.pdf"),width=6,height=3) 83 | 84 | ggplot(d3,aes(x=model,y=dev_val_mean,color=IPs,shape=dim))+geom_point(size=4,position=position_dodge(width=0.6))+ylab("validation deviance (mean)") 85 | ggsave(fp(plt_pth,"sshippo_gof_dev_val_nb_vs_poi.pdf"),width=6,height=3) 86 | 87 | ggplot(d3,aes(x=model,y=rmse_val,color=IPs,shape=dim))+geom_point(size=4,position=position_dodge(width=0.6))+ylab("validation RMSE") 88 | ggsave(fp(plt_pth,"sshippo_gof_rmse_val_nb_vs_poi.pdf"),width=6,height=3) 89 | 90 | #time 91 | ggplot(d3,aes(x=model,y=wtime/60,color=IPs,shape=dim))+geom_point(size=4,position=position_dodge(width=0.8))+ylab("time to converge (min)")+scale_y_log10() 92 | ggsave(fp(plt_pth,"sshippo_wtime_nb_vs_poi.pdf"),width=6,height=3) 93 | ``` 94 | 95 | NSF vs NSFH 96 | 97 | ```{r} 98 | d3<-subset(d,grepl("NSF",model)) 99 | ggplot(d3,aes(x=model,y=dev_tr_mean,color=dim,shape=IPs))+geom_point(size=4,position=position_dodge(width=0.8))+ylab("training deviance (mean)") 100 | ggsave(fp(plt_pth,"sshippo_gof_dev_tr_nsf_vs_nsfh.pdf"),width=6,height=3) 101 | ``` 102 | 103 | ```{r} 104 | #training deviance - average 105 | ggplot(d,aes(x=model,y=dev_tr_mean,color=dim,shape=IPs))+geom_jitter(size=5,width=.2,height=0)+ylab("training deviance (mean)")+theme(legend.position = "top") 106 | ggsave(fp(plt_pth,"sshippo_gof_dev_tr_mean.pdf"),width=6,height=3) 107 | 108 | #training deviance - max 109 | ggplot(d,aes(x=model,y=dev_tr_max,color=dim,shape=sz))+geom_jitter(size=5,width=.2,height=0)+ylab("training deviance (max)")#+scale_y_log10() 110 | 111 | #validation deviance - mean 112 | 
ggplot(d,aes(x=model,y=dev_val_mean,color=dim,shape=sz))+geom_jitter(size=4,width=.4,height=0)+ylab("validation deviance (mean)")+theme(legend.position="none") 113 | ggsave(fp(plt_pth,"sshippo_gof_dev_val_mean.pdf"),width=6,height=2.5) 114 | 115 | #validation deviance - max 116 | ggplot(d,aes(x=model,y=dev_val_max,color=dim,shape=sz))+geom_jitter(size=3,width=.2,height=0)+ylab("validation deviance (max)")#+scale_y_log10() 117 | 118 | #sparsity 119 | ggplot(d,aes(x=model,y=sparsity,color=dim,shape=sz))+geom_jitter(size=4,width=.3,height=0)+ylab("sparsity of loadings")+theme(legend.position="none") 120 | ggsave(fp(plt_pth,"sshippo_sparsity.pdf"),width=6,height=2.5) 121 | 122 | #wall time 123 | ggplot(d,aes(x=model,y=wtime/3600,color=IPs,shape=dim))+geom_jitter(size=5,width=.2,height=0)+ylab("wall time (hr)")+scale_y_log10()+theme(legend.position="top") 124 | ggsave(fp(plt_pth,"sshippo_wtime.pdf"),width=6,height=3) 125 | ggplot(d,aes(x=dim,y=wtime/60,color=IPs,shape=sz))+geom_jitter(size=3,width=.2,height=0)+ylab("wall time (min)")+scale_y_log10() 126 | 127 | #processor time 128 | ggplot(d,aes(x=model,y=ptime/3600,color=IPs,shape=dim))+geom_jitter(size=5,width=.2,height=0)+ylab("processor time")+scale_y_log10()+theme(legend.position="none") 129 | ggsave(fp(plt_pth,"sshippo_ptime.pdf"),width=6,height=2.5) 130 | ``` 131 | 132 | Effect of using size factors 133 | 134 | ```{r} 135 | d2<-subset(d,model %in% c("NSF (nb)","NSF (poi)","NSFH (nb)","NSFH (poi)")) 136 | d2$model<-factor(d2$model) 137 | d2$dim<-factor(d2$L,levels=sort(unique(d2$L))) 138 | d2$M[is.na(d2$M)]<-3000 139 | d2$IPs<-factor(d2$M,levels=sort(unique(d2$M))) 140 | ``` 141 | 142 | ```{r} 143 | ggplot(d2,aes(x=sz,y=dev_tr_mean,color=dim,shape=IPs))+geom_jitter(size=5,width=.2,height=0)+ylab("training deviance (mean)")+theme(legend.position = "top")+facet_wrap(~model,ncol=2,scales="free") 144 | 145 | ggplot(d2,aes(x=sz,y=dev_val_mean,color=dim,shape=IPs))+geom_jitter(size=5,width=.2,height=0)+ylab("validation deviance (mean)")+theme(legend.position = "top")+facet_wrap(~model,ncol=2,scales="free") 146 | ``` 147 | 148 | Effect of inducing points 149 | 150 | ```{r} 151 | d2<-subset(d,model %in% c("MEFISTO (gau)","RSF (gau)","NSFH (nb)","NSFH (poi)","NSF (nb)","NSF (poi)")) 152 | ggplot(d2,aes(x=IPs,y=dev_tr_mean,colour=dim,group=dim))+geom_line(size=2,lineend="round")+ylab("training deviance (mean)")+theme(legend.position = "top")+facet_wrap(~model,scales="free") 153 | ggsave(fp(plt_pth,"sshippo_gof_dev_tr_ips.pdf"),width=6,height=3) 154 | 155 | ggplot(d2,aes(x=IPs,y=rmse_tr,colour=dim,group=dim))+geom_line(size=2,lineend="round")+ylab("training RMSE")+theme(legend.position = "top")+facet_wrap(~model,scales="free") 156 | ggsave(fp(plt_pth,"sshippo_gof_rmse_tr_ips.pdf"),width=6,height=3) 157 | 158 | ggplot(d2,aes(x=IPs,y=dev_val_mean,colour=dim,group=dim))+geom_line(size=2,lineend="round")+ylab("validation deviance (mean)")+theme(legend.position = "top")+facet_wrap(~model,scales="free") 159 | ggsave(fp(plt_pth,"sshippo_gof_dev_val_ips.pdf"),width=6,height=3) 160 | 161 | ggplot(d2,aes(x=IPs,y=rmse_val,colour=dim,group=dim))+geom_line(size=2,lineend="round")+ylab("validation RMSE")+theme(legend.position = "top")+facet_wrap(~model,scales="free") 162 | ggsave(fp(plt_pth,"sshippo_gof_rmse_val_ips.pdf"),width=6,height=3) 163 | 164 | ggplot(d2,aes(x=IPs,y=wtime/3600,colour=dim,group=dim))+geom_line(size=2,lineend="round")+ylab("wall time (hr)")+theme(legend.position = "top")+facet_wrap(~model,scales="free") 165 | 
ggsave(fp(plt_pth,"sshippo_wtime_ips.pdf"),width=6,height=3) 166 | ``` 167 | -------------------------------------------------------------------------------- /scrna/sshippo/06_interpret_genes.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "06_interpretation_genes" 3 | author: "Will Townes" 4 | output: html_document 5 | --- 6 | 7 | ```{r} 8 | library(biomaRt) 9 | source("./scrna/utils/interpret_genes.R") 10 | 11 | fp<-file.path 12 | pth<-"scrna/sshippo" 13 | ``` 14 | loadings from NSFH 15 | ```{r} 16 | W<-read.csv(fp(pth,"results/sshippo_nsfh20_spde_loadings.csv"),header=TRUE,row.names=1) 17 | rownames(W)<-toupper(rownames(W)) 18 | # #format rownames in title case 19 | # g<-str_to_title(rownames(W)) 20 | # rownames(W)<-sub("Mt-","mt-",g,fixed=TRUE) 21 | W2<-t(apply(as.matrix(W),1,function(x){as.numeric(x==max(x))})) 22 | colSums(W2) #clustering of genes 23 | ``` 24 | Gene Ontology terms, used for both NSFH and hotspot. 25 | ```{r} 26 | # bg<-rownames(W) 27 | db<-useMart("ensembl",host="https://may2021.archive.ensembl.org",dataset='mmusculus_gene_ensembl') 28 | go_ids<-getBM(attributes=c('go_id', 'external_gene_name','namespace_1003'), filters='external_gene_name', values=rownames(W), mart=db) 29 | go_ids[,2]<-toupper(go_ids[,2]) 30 | gene_2_GO<-unstack(go_ids[,c(1,2)]) 31 | ``` 32 | GO analysis for NSFH 33 | ```{r} 34 | # g<-strsplit(res[l,"genes"],", ",fixed=TRUE)[[1]] 35 | # expr <- getMarkers(include = g[c(1:2,4,6,7)])$ 36 | res<-loadings2go(W,gene_2_GO)#,numtopgenes=50,min_genescore=1) 37 | res$type<-rep(c("spat","nsp"),each=10) 38 | res$dim<-rep(1:10,2) 39 | write.csv(res[,c(1,4,2,3)],fp(pth,"results/sshippo_nsfh20_spde_goterms.csv"),row.names=FALSE) 40 | ``` 41 | Get cell types from [Panglao](https://panglaodb.se). First manually download the list of all marker genes for all cell types from https://panglaodb.se/markers.html?cell_type=%27all_cells%27 . Store this TSV in the scrna/resources folder. 
42 | ```{r} 43 | res<-read.csv(fp(pth,"results/sshippo_nsfh20_spde_goterms.csv"),header=TRUE) 44 | panglao_tsv="scrna/resources/PanglaoDB_markers_27_Mar_2020.tsv.gz" 45 | mk<-read.table(gzfile(panglao_tsv),header=TRUE,sep="\t") 46 | mk<-subset(mk,species %in% c("Mm","Mm Hs")) 47 | mk<-subset(mk,organ %in% c("Brain","Connective tissue","Epithelium")) 48 | #"Immune system","Vasculature","Blood", 49 | pg<-tapply(mk$official.gene.symbol,mk$cell.type,function(x){x},simplify=FALSE) 50 | ct<-lapply(res$genes,genes2celltype,pg)#,ss=3) 51 | ct[is.na(ct)]<-"" 52 | res$celltypes<-paste(ct,sep=", ") 53 | write.csv(res,fp(pth,"results/sshippo_nsfh20_spde_goterms.csv"),row.names=FALSE) 54 | ``` 55 | Shorter GO table 56 | ```{r} 57 | res<-read.csv(fp(pth,"results/sshippo_nsfh20_spde_goterms.csv"),header=TRUE) 58 | res2<-res 59 | g<-strsplit(res2$genes,", ",fixed=TRUE) 60 | res2$genes<-sapply(g,function(x){paste(x[1:5],collapse=", ")}) 61 | gt<-strsplit(res2$go_bp,"; ",fixed=TRUE) 62 | res2$go_bp<-sapply(gt,function(x){paste(x[1:2],collapse=", ")}) 63 | write.csv(res2,fp(pth,"results/sshippo_nsfh20_spde_goterms_short.csv"),row.names=FALSE) 64 | ``` 65 | 66 | GO analysis for hotspot 67 | ```{r} 68 | hs<-read.csv(fp(pth,"results/hotspot.csv"),header=TRUE) 69 | rownames(hs)<-hs$Gene 70 | Wh<-model.matrix(~0+factor(Module),data=subset(hs,Module!=-1)) 71 | colnames(Wh)<-paste0("X",1:ncol(Wh)) 72 | rownames(Wh)<-toupper(rownames(Wh)) 73 | # all(rownames(Wh) %in% rownames(W)) 74 | hres<-loadings2go(Wh,gene_2_GO, rowmean_divide=TRUE) 75 | colnames(hres)[which(colnames(hres)=="dim")]<-"cluster" 76 | write.csv(hres,fp(pth,"results/sshippo_hotspot_goterms.csv"),row.names=FALSE) 77 | hres2<-hres 78 | g<-strsplit(hres2$genes,", ",fixed=TRUE) 79 | hres2$genes<-sapply(g,function(x){paste(x[1:5],collapse=", ")}) 80 | gt<-strsplit(hres2$go_bp,"; ",fixed=TRUE) 81 | hres2$go_bp<-sapply(gt,function(x){paste(x[1:2],collapse=", ")}) 82 | write.csv(hres2,fp(pth,"results/sshippo_hotspot_goterms_short.csv"),row.names=FALSE) 83 | ``` 84 | -------------------------------------------------------------------------------- /scrna/sshippo/07_traditional.ipy: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # --- 3 | # jupyter: 4 | # jupytext: 5 | # text_representation: 6 | # extension: .ipy 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.6.0 10 | # kernelspec: 11 | # display_name: Python 3 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | #%% imports 17 | # import numpy as np 18 | import pandas as pd 19 | import scanpy as sc 20 | import matplotlib.pyplot as plt 21 | from os import path 22 | from hotspot import Hotspot 23 | 24 | from utils import misc,visualize 25 | 26 | dtp = "float32" 27 | pth = "scrna/sshippo" 28 | dpth = path.join(pth,"data") 29 | mpth = path.join(pth,"models") 30 | rpth = path.join(pth,"results") 31 | plt_pth = path.join(rpth,"plots") 32 | 33 | # %% Data Loading from scanpy 34 | J = 2000 35 | dfile = path.join(dpth,"sshippo.h5ad") 36 | adata = sc.read_h5ad(dfile) 37 | sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=J) 38 | X0 = adata.obsm["spatial"] 39 | X0[:,1] = -X0[:,1] 40 | 41 | #%% Traditional scanpy analysis (unsupervised clustering) 42 | #https://scanpy-tutorials.readthedocs.io/en/latest/spatial/basic-analysis.html 43 | sc.pp.pca(adata) 44 | sc.pp.neighbors(adata) 45 | sc.tl.umap(adata) 46 | sc.tl.leiden(adata, resolution=1.0, key_added="clusters") 47 | 
plt.rcParams["figure.figsize"] = (4, 4) 48 | sc.pl.umap(adata, color="clusters", wspace=0.4) 49 | sc.pl.embedding(adata, "spatial", color="clusters") 50 | cl = pd.get_dummies(adata.obs["clusters"]).to_numpy() 51 | tgnames = [str(i) for i in range(1,cl.shape[1]+1)] 52 | fig,axes=visualize.multiheatmap(X0, cl, (3,4), figsize=(6,4), cmap="turbo", 53 | bgcol="gray", s=0.01, marker=".", 54 | subplot_space=0, spinecolor="white") 55 | visualize.set_titles(fig, tgnames, x=0.03, y=.88, fontsize="small", c="white", 56 | ha="left", va="top") 57 | fig.savefig(path.join(plt_pth,"sshippo_heatmap_scanpy_clusters.png"), 58 | bbox_inches='tight', dpi=300) 59 | 60 | #%% Hotspot analysis (gene clusters) 61 | #https://hotspot.readthedocs.io/en/latest/Spatial_Tutorial.html 62 | J = 2000 63 | dfile = path.join(dpth,"sshippo_J{}.h5ad".format(J)) 64 | adata = sc.read_h5ad(dfile) 65 | adata.layers["counts"] = adata.layers["counts"].tocsc() 66 | hs = Hotspot(adata, layer_key="counts", model="danb", 67 | latent_obsm_key="spatial", umi_counts_obs_key="total_counts") 68 | hs.create_knn_graph(weighted_graph=False, n_neighbors=20) 69 | hs_results = hs.compute_autocorrelations() 70 | # hs_results.tail() 71 | hs_genes = hs_results.index#[hs_results.FDR < 0.05] 72 | lcz = hs.compute_local_correlations(hs_genes) 73 | modules = hs.create_modules(min_gene_threshold=20, core_only=False, 74 | fdr_threshold=0.05) 75 | # modules.value_counts() 76 | hs_results = hs_results.join(modules,how="left") 77 | hs_results.to_csv(path.join(rpth,"hotspot.csv")) 78 | misc.pickle_to_file(hs,path.join(mpth,"hotspot.pickle")) 79 | -------------------------------------------------------------------------------- /scrna/utils/interpret_genes.R: -------------------------------------------------------------------------------- 1 | library(topGO) 2 | 3 | loadings2go<-function(W, gene_2_GO, numtopgenes=100, min_genescore=1.0, 4 | rowmean_divide=TRUE){ 5 | #This blog was helpful: https://datacatz.wordpress.com/2018/01/19/gene-set-enrichment-analysis-with-topgo-part-1/ 6 | #divide each gene by its rowmean to adjust for constant high expression 7 | if(rowmean_divide){ 8 | W<-W/rowMeans(W) 9 | } 10 | W<-as.matrix(W) #enables preservation of rownames when subsetting cols 11 | qtl<-1-(numtopgenes/nrow(W)) 12 | topGenes<-function(x){ 13 | cutoff<-max(quantile(x,qtl),min_genescore) 14 | x>cutoff 15 | } 16 | #W must be a matrix with gene names in rownames 17 | L<-ncol(W) 18 | res<-data.frame(dim=1:L,genes="",go_bp="") 19 | # gg<-names(tg[[12]]) 20 | # gg<-gg[gg %in% names(gene_2_GO)] 21 | # geneList=factor(as.integer(bg %in% gg)) 22 | for(l in 1:L){ 23 | w<-W[,l] 24 | # names(geneList)<-rownames(W) 25 | tg<-head(sort(w,decreasing=TRUE),10) 26 | res[l,"genes"]<-paste(names(tg),collapse=", ") 27 | if(l==1){ 28 | GOdata<-new('topGOdata', ontology='BP', allGenes=w, geneSel=topGenes, annot=annFUN.gene2GO, gene2GO=gene_2_GO, nodeSize=5) 29 | } else { 30 | GOdata<-updateGenes(GOdata,w,topGenes) 31 | } 32 | #wk<-runTest(GOdata,algorithm='weight01', statistic='ks') 33 | wf<-runTest(GOdata, algorithm='weight01', statistic='fisher') 34 | gotab=GenTable(GOdata,weightFisher=wf,orderBy='weightFisher',numChar=1000,topNodes=5) 35 | res[l,"go_bp"]<-paste(gotab$Term,collapse="; ") 36 | } 37 | res 38 | } 39 | 40 | jaccard<-function(a,b){ 41 | i<-length(intersect(a,b)) 42 | denom<-ifelse(i>0, length(a)+length(b)-i, 1) 43 | i/denom 44 | } 45 | 46 | # panglao_make_ref<-function(panglao_tsv="scrna/resources/PanglaoDB_markers_27_Mar_2020.tsv.gz"){ 47 | # 
mk<-read.table(gzfile(panglao_tsv),header=TRUE,sep="\t")
48 | # tapply(mk$official.gene.symbol,mk$cell.type,function(x){x},simplify=FALSE)
49 | # }
50 |
51 | genes2celltype<-function(genelist,panglao_ref,gsplit=TRUE,ss=0){
52 | #genelist: a string with gene names separated by commas OR a list of genes
53 | #panglao_ref: a list whose names are cell types and values are lists of genes
54 | #gsplit: if TRUE, splits genelist into a list. If FALSE, assumes genelist already split
55 | #returns : cell type with highest jaccard similarity to the gene list
56 | g<-if(gsplit){ strsplit(genelist,", ",fixed=TRUE)[[1]] } else { genelist }
57 | if(ss>0 && ss<length(g)){ g<-g[1:ss] } #use at most the top ss genes
58 | #jaccard similarity between the gene list and each cell type's marker genes
59 | j<-sapply(panglao_ref,jaccard,g)
60 | jmx<-max(j)
61 | if(jmx>0){
62 | res<-j[j==jmx]
63 | return(names(res))
64 | } else { #case where no cell types were found
65 | return(NA)
66 | }
67 | }
-------------------------------------------------------------------------------- /scrna/visium_brain_sagittal/01_data_loading.ipy: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Data loading for 10x Visium mouse brain sagittal section anterior 1.
5 |
6 | Created on Sat Jun 5 15:22:19 2021
7 |
8 | @author: townesf
9 | """
10 |
11 | # %% imports
12 | import random
13 | import numpy as np
14 | import scanpy as sc
15 | from os import path
16 | from scipy import sparse
17 |
18 | from utils import preprocess, training, misc
19 |
20 | random.seed(101)
21 | pth = "scrna/visium_brain_sagittal"
22 |
23 | # %% Download original data files
24 | %%sh
25 | mkdir -p scrna/visium_brain_sagittal/data/original
26 | pushd scrna/visium_brain_sagittal/data/original
27 | wget https://cf.10xgenomics.com/samples/spatial-exp/1.1.0/V1_Mouse_Brain_Sagittal_Anterior/V1_Mouse_Brain_Sagittal_Anterior_filtered_feature_bc_matrix.h5
28 | wget https://cf.10xgenomics.com/samples/spatial-exp/1.1.0/V1_Mouse_Brain_Sagittal_Anterior/V1_Mouse_Brain_Sagittal_Anterior_spatial.tar.gz
29 | tar -xzf V1_Mouse_Brain_Sagittal_Anterior_spatial.tar.gz
30 | rm V1_Mouse_Brain_Sagittal_Anterior_spatial.tar.gz
31 | popd
32 |
33 | # %% Desiderata for dataset [markdown]
34 | # 1. Spatial coordinates
35 | # 2. Features sorted in decreasing order of deviance
36 | # 3.
Observations randomly shuffled 37 | 38 | # %% QC, loosely following MEFISTO tutorials 39 | # https://nbviewer.jupyter.org/github/bioFAM/MEFISTO_tutorials/blob/master/MEFISTO_ST.ipynb#QC-and-preprocessing 40 | ad = sc.read_visium(path.join(pth,"data/original"), 41 | count_file="V1_Mouse_Brain_Sagittal_Anterior_filtered_feature_bc_matrix.h5") 42 | ad.var_names_make_unique() 43 | ad.var["mt"] = ad.var_names.str.startswith("mt-") 44 | sc.pp.calculate_qc_metrics(ad, qc_vars=["mt"], inplace=True) 45 | ad = ad[ad.obs.pct_counts_mt < 20] 46 | sc.pp.filter_genes(ad, min_cells=1) 47 | sc.pp.filter_cells(ad, min_counts=100) 48 | ad.layers = {"counts":ad.X.copy()} #store raw counts before normalization changes ad.X 49 | sc.pp.normalize_total(ad, inplace=True, layers=None, key_added="sizefactor") 50 | sc.pp.log1p(ad) 51 | #Y = misc.reverse_normalization(np.expm1(ad.X),ad.obs["sizefactor"]) 52 | #np.max(np.abs(Y-np.layers["counts"])) 53 | 54 | # %% normalization, feature selection and train/test split 55 | ad.var['deviance_poisson'] = preprocess.deviancePoisson(ad.layers["counts"]) 56 | o = np.argsort(-ad.var['deviance_poisson']) 57 | idx = list(range(ad.shape[0])) 58 | random.shuffle(idx) 59 | ad = ad[idx,o] 60 | ad.write_h5ad(path.join(pth,"data/visium_brain_sagittal.h5ad"),compression="gzip") 61 | #ad = sc.read_h5ad(path.join(pth,"data/visium_brain_sagittal.h5ad")) 62 | ad2 = ad[:,:2000] 63 | ad2.write_h5ad(path.join(pth,"data/visium_brain_sagittal_J2000.h5ad"),compression="gzip") 64 | -------------------------------------------------------------------------------- /scrna/visium_brain_sagittal/03_benchmark.ipy: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # --- 3 | # jupyter: 4 | # jupytext: 5 | # text_representation: 6 | # extension: .ipy 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.6.0 10 | # kernelspec: 11 | # display_name: Python 3 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | #%% 17 | import pandas as pd 18 | from os import path 19 | from utils import misc,benchmark 20 | 21 | pth = "scrna/visium_brain_sagittal" 22 | dpth = path.join(pth,"data") 23 | mpth = path.join(pth,"models/V5") 24 | rpth = path.join(pth,"results") 25 | misc.mkdir_p(rpth) 26 | 27 | #%% Create CSV with benchmarking parameters 28 | csv_path = path.join(rpth,"benchmark.csv") 29 | try: 30 | par = pd.read_csv(csv_path) 31 | except FileNotFoundError: 32 | L = [6,12,20] 33 | sp_mods = ["NSF-P", "NSF-N", "NSF-G", "RSF-G", "NSFH-P", "NSFH-N"] 34 | ns_mods = ["PNMF-P", "PNMF-N", "FA-G"] 35 | sz = ["constant","scanpy"] 36 | M = [500,1000,2363] 37 | V = [5] 38 | kernels=["MaternThreeHalves"] 39 | par = benchmark.make_param_df(L,sp_mods,ns_mods,M,sz,V=V,kernels=kernels) 40 | par.to_csv(csv_path,index=False) 41 | 42 | #%% merge old benchmark csv with new 43 | # old = pd.read_csv(path.join(rpth,"benchmark1.csv")) 44 | # new = par.merge(old,on="key",how="outer",copy=True) 45 | # new["converged"] = new["converged_y"] 46 | # new["converged"].fillna(False, inplace=True) 47 | # new.drop(["converged_x","converged_y"],axis=1,inplace=True) 48 | # new.to_csv(path.join(rpth,"benchmark2.csv"),index=False) 49 | ##rename benchmark2 to benchmark manually 50 | ##some additional scenarios added manually as well (NSF,NSFH with L=36) 51 | 52 | #%% Additional scenarios added in response to reviewer comments 53 | # manually merge with original benchmark.csv 54 | # all original scenarios with 80/20 train/val split (only Matern32 kernel, 
no NB lik)
55 | csv_path = path.join(rpth,"benchmark2.csv")
56 | try:
57 | par = pd.read_csv(csv_path)
58 | except FileNotFoundError:
59 | L = [6,12,20]
60 | sp_mods = ["NSF-P", "RSF-G", "NSFH-P"]
61 | ns_mods = ["PNMF-P", "FA-G"]
62 | sz = ["constant","scanpy"]
63 | M = [500,1000,2363]
64 | V = [20]
65 | kernels=["MaternThreeHalves"]
66 | par = benchmark.make_param_df(L,sp_mods,ns_mods,M,sz,V=V,kernels=kernels)
67 | par.to_csv(csv_path,index=False)
68 | # NSF, RSF, and NSFH with ExponentiatedQuadratic kernel
69 | # need to manually delete duplicate MEFISTO scenarios in this one
70 | csv_path = path.join(rpth,"benchmark3.csv")
71 | try:
72 | par = pd.read_csv(csv_path)
73 | except FileNotFoundError:
74 | L = [6,12,20]
75 | sp_mods = ["NSF-P", "RSF-G", "NSFH-P"]
76 | ns_mods = []
77 | sz = ["constant","scanpy"]
78 | M = [500,1000,2363]
79 | V = [5]
80 | kernels=["ExponentiatedQuadratic"]
81 | par = benchmark.make_param_df(L,sp_mods,ns_mods,M,sz,V=V,kernels=kernels)
82 | par.to_csv(csv_path,index=False)
83 |
84 | #%% Benchmarking at command line [markdown]
85 | """
86 | To run on local computer, use
87 | `python -m utils.benchmark 197 scrna/visium_brain_sagittal/data/visium_brain_sagittal_J2000.h5ad`
88 | where 197 is a row ID of benchmark.csv, min value 2, max possible value is 241
89 |
90 | To run on cluster first load anaconda environment
91 | ```
92 | tmux
93 | interactive
94 | module load anaconda3/2021.5
95 | conda activate fwt
96 | python -m utils.benchmark 14 scrna/visium_brain_sagittal/data/visium_brain_sagittal_J2000.h5ad
97 | ```
98 |
99 | To run on cluster as a job array, subset of rows, recommend 6hr time limit.
100 | ```
101 | DAT=./scrna/visium_brain_sagittal/data/visium_brain_sagittal_J2000.h5ad
102 | sbatch --mem=72G --array=135-196,198-241 ./utils/benchmark_array.slurm $DAT
103 | ```
104 |
105 | To run on cluster as a job array, all rows of CSV file
106 | ```
107 | CSV=./scrna/visium_brain_sagittal/results/benchmark.csv
108 | DAT=./scrna/visium_brain_sagittal/data/visium_brain_sagittal_J2000.h5ad
109 | sbatch --mem=72G --array=2-$(wc -l < $CSV) ./utils/benchmark_array.slurm $DAT
110 | ```
111 | """
112 |
113 | #%% Compute metrics for each model (as a job)
114 | """
115 | DAT=./scrna/visium_brain_sagittal/data/visium_brain_sagittal_J2000.h5ad
116 | sbatch --mem=72G ./utils/benchmark_gof.slurm $DAT 5
117 | #wait until job finishes, then run below
118 | DAT=./scrna/visium_brain_sagittal/data/visium_brain_sagittal_J2000.h5ad
119 | sbatch --mem=72G ./utils/benchmark_gof.slurm $DAT 20
120 | """
121 |
122 | #%% Compute metrics for each model (manually)
123 | from utils import benchmark
124 | dat = "scrna/visium_brain_sagittal/data/visium_brain_sagittal_J2000.h5ad"
125 | res = benchmark.update_results(dat, val_pct=20, todisk=True)
126 |
127 | #%% Examine one result
128 | from matplotlib import pyplot as plt
129 | from utils import training
130 | tro = training.ModelTrainer.from_pickle(path.join(mpth,"L4/poi/NSF_MaternThreeHalves_M3000"))
131 | plt.plot(tro.loss["train"][-200:-1])
132 |
133 | #%%
134 | csv_file = path.join(rpth,"benchmark.csv")
135 | Ntr = tro.model.Z.shape[0]
136 | benchmark.correct_inducing_pts(csv_file, Ntr)
137 | -------------------------------------------------------------------------------- /scrna/visium_brain_sagittal/04_benchmark_viz.Rmd: --------------------------------------------------------------------------------
1 | ---
2 | title: "Benchmarking Visualization"
3 | author: "Will Townes"
4 | output: html_document
5 | ---
6 |
7 | ```{r}
8 |
library(tidyverse) 9 | theme_set(theme_bw()) 10 | fp<-file.path 11 | pth<-"scrna/visium_brain_sagittal" 12 | plt_pth<-fp(pth,"results/plots") 13 | if(!dir.exists(plt_pth)){ 14 | dir.create(plt_pth,recursive=TRUE) 15 | } 16 | ``` 17 | 18 | ```{r} 19 | d0<-read.csv(fp(pth,"results/benchmark.csv")) 20 | d0$converged<-as.logical(d0$converged) 21 | #d0$model<-plyr::mapvalues(d0$model,c("FA","RSF","PNMF","NSFH","NSF"),c("FA","RSF","PNMF","NSFH","NSF")) 22 | d<-subset(d0,converged==TRUE) 23 | d$model<-paste0(d$model," (",d$lik,")") 24 | unique(d$model) 25 | keep<-c("FA (gau)","MEFISTO (gau)","RSF (gau)","PNMF (nb)","PNMF (poi)","NSFH (nb)","NSFH (poi)","NSF (nb)","NSF (poi)") 26 | 27 | d<-subset(d,model %in% keep) 28 | d$model<-factor(d$model,levels=keep) 29 | d$dim<-factor(d$L,levels=sort(unique(d$L))) 30 | d$M[is.na(d$M)]<-2363 31 | d$IPs<-factor(d$M,levels=sort(unique(d$M))) 32 | d$standard_kernel<-TRUE 33 | d$standard_kernel[(d$model %in% c("RSF (gau)","NSFH (nb)","NSFH (poi)","NSF (nb)","NSF (poi)")) & d$kernel=="ExponentiatedQuadratic"]<-FALSE 34 | d1<-subset(d,V==5 & standard_kernel) 35 | ``` 36 | subset of models for simplified main figure 37 | ```{r} 38 | d2<-subset(d1,M==2363 | (model %in% c("PNMF (poi)","FA (gau)"))) 39 | d2a<-subset(d2,(model %in% c("PNMF (poi)","NSFH (poi)","NSF (poi)")) & sz=="scanpy") 40 | d2b<-subset(d2,model %in% c("FA (gau)","RSF (gau)","MEFISTO (gau)")) 41 | d2<-rbind(d2a,d2b) 42 | #d2$model<-factor(d2$model,levels=) 43 | d2$model<-plyr::mapvalues(d2$model,c("FA (gau)","MEFISTO (gau)","RSF (gau)","PNMF (poi)","NSFH (poi)","NSF (poi)"),c("FA","MEFISTO","RSF","PNMF","NSFH","NSF")) 44 | ggplot(d2,aes(x=model,y=dev_val_mean,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.5))+ylab("validation deviance (mean)")+ylim(range(d2$dev_val_mean)*c(.95,1.02)) 45 | ggsave(fp(plt_pth,"vz_brn_gof_dev_val_mean_main.pdf"),width=5,height=2.5) 46 | 47 | #linear regressions for statistical significance 48 | d2$realvalued<-d2$model %in% c("FA","MEFISTO","RSF") 49 | #d2$dim_numeric<-as.numeric(as.character(d2$dim)) 50 | summary(lm(dev_val_mean~realvalued+dim,data=d2)) 51 | summary(lm(dev_val_mean~model,data=d2)) 52 | t.test(dev_val_mean~model,data=subset(d2,model %in% c("NSF","RSF")), var.equal=TRUE, alternative="less") 53 | d2$unsupervised<-d2$model %in% c("FA","PNMF") 54 | d2$unsupervised[d2$model=="MEFISTO"]<-NA 55 | summary(lm(dev_val_mean~realvalued+dim+unsupervised,data=d2)) 56 | t.test(dev_val_mean~model,data=subset(d2,model %in% c("RSF","MEFISTO")), var.equal=TRUE,alternative="greater") 57 | 58 | d2_no60<-subset(d2,dim!="60") 59 | d2_no60$dim<-factor(d2_no60$dim) 60 | ggplot(d2_no60,aes(x=model,y=rmse_val,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.5))+ylab("validation RMSE")+ylim(range(d2$rmse_val)*c(.95,1.02)) 61 | ggsave(fp(plt_pth,"vz_brn_gof_rmse_val_simple.pdf"),width=5,height=2.5) 62 | 63 | ggplot(d2_no60,aes(x=model,y=sparsity,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.5))+ylab("loadings zero fraction")+ylim(range(d2$sparsity)*c(.95,1.02)) 64 | ggsave(fp(plt_pth,"vz_brn_sparsity_simple.pdf"),width=5,height=2.5) 65 | ``` 66 | 67 | NSF vs NSFH 68 | 69 | ```{r} 70 | d3<-subset(d,grepl("NSF",model)&kernel=="MaternThreeHalves" & sz=="scanpy" & V==5 & dim!="60") 71 | ggplot(d3,aes(x=model,y=dev_tr_mean,color=dim,shape=IPs))+geom_point(size=4,position=position_dodge(width=0.6))+ylab("training deviance (mean)") 72 | ggsave(fp(plt_pth,"vz_brn_gof_dev_tr_nsf_vs_nsfh.pdf"),width=6,height=3) 73 | ``` 
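As an optional companion to the significance tests above (not part of the original analysis), the sketch below simply tabulates the mean validation deviance and RMSE per model and dimension from the d2 subset built in the simplified-main-figure chunk; it assumes that chunk has already been run.
```{r}
#Optional summary (not in the original analysis): mean validation metrics by model and dim,
#using the d2 data frame from the "simplified main figure" chunk above.
d2 %>% group_by(model,dim) %>%
  summarize(dev_val=mean(dev_val_mean), rmse=mean(rmse_val), n=n(), .groups="drop") %>%
  arrange(dev_val)
```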
74 | 75 | ```{r} 76 | #training deviance - average 77 | ggplot(d1,aes(x=model,y=dev_tr_mean,color=dim,shape=IPs))+geom_jitter(size=5,width=.2,height=0)+ylab("training deviance (mean)")+theme(legend.position = "top") 78 | ggsave(fp(plt_pth,"vz_brn_gof_dev_tr_mean.pdf"),width=8,height=3) 79 | 80 | #training deviance - max 81 | ggplot(d1,aes(x=model,y=dev_tr_max,color=dim,shape=IPs))+geom_jitter(size=5,width=.2,height=0)+ylab("training deviance (max)")#+scale_y_log10() 82 | 83 | #validation deviance - mean 84 | ggplot(d1,aes(x=model,y=dev_val_mean,color=dim,shape=IPs))+geom_jitter(size=4,width=.4,height=0)+ylab("validation deviance (mean)")+theme(legend.position="none") 85 | ggsave(fp(plt_pth,"vz_brn_gof_dev_val_mean.pdf"),width=8,height=2.5) 86 | 87 | #validation deviance - max 88 | ggplot(d1,aes(x=model,y=dev_val_max,color=dim,shape=IPs))+geom_jitter(size=3,width=.2,height=0)+ylab("validation deviance (max)")#+scale_y_log10() 89 | 90 | #sparsity 91 | ggplot(d1,aes(x=model,y=sparsity,color=dim,shape=IPs))+geom_jitter(size=4,width=.3,height=0)+ylab("sparsity of loadings")+theme(legend.position="none") 92 | ggsave(fp(plt_pth,"vz_brn_sparsity.pdf"),width=6,height=2.5) 93 | 94 | #wall time 95 | ggplot(d1,aes(x=model,y=wtime/60,color=IPs,shape=dim))+geom_jitter(size=5,width=.2,height=0)+ylab("wall time (min)")+scale_y_log10()+theme(legend.position="top") 96 | ggsave(fp(plt_pth,"vz_brn_wtime.pdf"),width=6,height=3) 97 | ggplot(d1,aes(x=dim,y=wtime/60,color=IPs))+geom_jitter(size=3,width=.2,height=0)+ylab("wall time (min)")+scale_y_log10() 98 | 99 | #processor time 100 | ggplot(d1,aes(x=model,y=ptime/60,color=IPs,shape=dim))+geom_jitter(size=5,width=.2,height=0)+ylab("processor time")+scale_y_log10()+theme(legend.position="none") 101 | ggsave(fp(plt_pth,"vz_brn_ptime.pdf"),width=6,height=2.5) 102 | ``` 103 | 104 | Effect of using size factors 105 | 106 | ```{r} 107 | d2<-subset(d1,model %in% c("NSF (nb)","NSF (poi)","NSFH (nb)","NSFH (poi)","PNMF (nb)","PNMF (poi)")) 108 | d2$model<-factor(d2$model) 109 | d2$dim<-factor(d2$L,levels=sort(unique(d2$L))) 110 | d2$M[is.na(d2$M)]<-2363 111 | d2$IPs<-factor(d2$M,levels=sort(unique(d2$M))) 112 | ``` 113 | 114 | ```{r} 115 | ggplot(d2,aes(x=sz,y=dev_tr_mean,color=dim,shape=IPs))+geom_jitter(size=5,width=.2,height=0)+ylab("training deviance (mean)")+theme(legend.position = "top")+facet_wrap(~model,ncol=2,scales="free") 116 | 117 | ggplot(d2,aes(x=sz,y=dev_val_mean,color=dim,shape=IPs))+geom_jitter(size=5,width=.2,height=0)+ylab("validation deviance (mean)")+theme(legend.position = "top")+facet_wrap(~model,ncol=2,scales="free") 118 | ``` 119 | 120 | Reviewer comment: GOF metrics with validation 20% of data instead of 5% 121 | ```{r} 122 | d2<-subset(d,V==20 & standard_kernel) 123 | d2<-subset(d2,M==2363 | (model %in% c("PNMF (poi)","FA (gau)"))) 124 | d2a<-subset(d2,(model %in% c("PNMF (poi)","NSFH (poi)","NSF (poi)")) & sz=="scanpy") 125 | d2b<-subset(d2,model %in% c("FA (gau)","RSF (gau)","MEFISTO (gau)")) 126 | d2<-rbind(d2a,d2b) 127 | #d2$model<-factor(d2$model,levels=) 128 | d2$model<-plyr::mapvalues(d2$model,c("FA (gau)","MEFISTO (gau)","RSF (gau)","PNMF (poi)","NSFH (poi)","NSF (poi)"),c("FA","MEFISTO","RSF","PNMF","NSFH","NSF")) 129 | ggplot(d2,aes(x=model,y=dev_val_mean,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.5))+ylab("validation deviance (mean)")+ylim(range(d2$dev_val_mean)*c(.95,1.02)) 130 | ggsave(fp(plt_pth,"vz_brn_gof_dev_val_mean_main_V20.pdf"),width=5,height=2.5) 131 | 132 | 
ggplot(d2,aes(x=model,y=rmse_val,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.5))+ylab("validation RMSE")+ylim(range(d2$rmse_val)*c(.95,1.02)) 133 | ggsave(fp(plt_pth,"vz_brn_gof_rmse_val_simple_V20.pdf"),width=5,height=2.5) 134 | ``` 135 | 136 | Reviewer comment: Matern vs RBF kernels 137 | ```{r} 138 | d2<-subset(d,V==5 & M==1000) 139 | d2a<-subset(d2,(model %in% c("NSFH (poi)","NSF (poi)")) & sz=="scanpy") 140 | d2b<-subset(d2,model %in% c("RSF (gau)","MEFISTO (gau)")) 141 | d2<-rbind(d2a,d2b) 142 | #d2$model<-factor(d2$model,levels=) 143 | d2$model<-plyr::mapvalues(d2$model,c("MEFISTO (gau)","RSF (gau)","NSFH (poi)","NSF (poi)"),c("MEFISTO","RSF","NSFH","NSF")) 144 | ggplot(d2,aes(x=model,y=dev_val_mean,color=dim,shape=kernel))+geom_point(size=6,position=position_dodge(width=0.7))+ylab("validation deviance (mean)")+ylim(range(d2$dev_val_mean)*c(.95,1.02)) 145 | ggsave(fp(plt_pth,"vz_brn_gof_dev_val_kernels.pdf"),width=5,height=2.5) 146 | 147 | ggplot(d2,aes(x=model,y=rmse_val,color=dim,shape=kernel))+geom_point(size=6,position=position_dodge(width=0.7))+ylab("validation RMSE")+ylim(range(d2$rmse_val)*c(.95,1.02)) 148 | ggsave(fp(plt_pth,"vz_brn_gof_rmse_val_kernels.pdf"),width=5,height=2.5) 149 | ``` 150 | 151 | numerical stability of different kernel choices 152 | ```{r} 153 | d0<-read.csv(fp(pth,"results/benchmark.csv")) 154 | d0$converged<-as.logical(d0$converged) 155 | table(d0$converged) 156 | d1<-subset(d0,V==5)# & M==2363) 157 | d1a<-subset(d1,model=="RSF" & lik=="gau") 158 | d1b<-subset(d1,model %in% c("NSF","NSFH") & lik=="poi") 159 | d<-rbind(d1a,d1b) 160 | d2 <- d %>% group_by(kernel,M,model,lik) %>% summarize(total_runs=length(converged),converged=sum(converged)) 161 | ``` 162 | -------------------------------------------------------------------------------- /scrna/visium_brain_sagittal/05_interpret_genes.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "05_interpretation_genes" 3 | author: "Will Townes" 4 | output: html_document 5 | --- 6 | 7 | ```{r} 8 | # library(tidyverse) 9 | library(biomaRt) 10 | # library(rPanglaoDB) 11 | source("./scrna/utils/interpret_genes.R") 12 | 13 | fp<-file.path 14 | pth<-"scrna/visium_brain_sagittal" 15 | ``` 16 | loadings from NSFH 17 | ```{r} 18 | W<-read.csv(fp(pth,"results/vz_brn_nsfh20_spde_loadings.csv"),header=TRUE,row.names=1) 19 | rownames(W)<-toupper(rownames(W)) 20 | W2<-t(apply(as.matrix(W),1,function(x){as.numeric(x==max(x))})) 21 | colSums(W2) #clustering of genes 22 | ``` 23 | Gene Ontology terms, used for both NSFH and hotspot. 24 | ```{r} 25 | # bg<-rownames(W) 26 | db<-useMart("ensembl",host="https://may2021.archive.ensembl.org",dataset='mmusculus_gene_ensembl') 27 | go_ids<-getBM(attributes=c('go_id', 'external_gene_name','namespace_1003'), filters='external_gene_name', values=rownames(W), mart=db) 28 | go_ids[,2]<-toupper(go_ids[,2]) 29 | gene_2_GO<-unstack(go_ids[,c(1,2)]) 30 | ``` 31 | GO analysis for NSFH 32 | ```{r} 33 | res<-loadings2go(W,gene_2_GO) 34 | res$type<-rep(c("spat","nsp"),each=10) 35 | res$dim<-rep(1:10,2) 36 | write.csv(res[,c(1,4,2,3)],fp(pth,"results/vz_brn_nsfh20_spde_goterms.csv"),row.names=FALSE) 37 | ``` 38 | Get cell types from [Panglao](https://panglaodb.se). First manually download the list of all marker genes for all cell types from https://panglaodb.se/markers.html?cell_type=%27all_cells%27 . Store this TSV in the scrna/resources folder. 
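The next chunk assumes that file is already in place; this optional guard (not part of the original workflow) fails early with a pointer to the download page if it is missing.
```{r, eval=FALSE}
#Optional guard (not in the original workflow): check the manually downloaded TSV exists.
panglao_tsv<-"scrna/resources/PanglaoDB_markers_27_Mar_2020.tsv.gz"
if(!file.exists(panglao_tsv)){
  stop("missing ",panglao_tsv,"; download it from https://panglaodb.se/markers.html")
}
```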
39 | ```{r} 40 | res<-read.csv(fp(pth,"results/vz_brn_nsfh20_spde_goterms.csv"),header=TRUE) 41 | # g<-strsplit(res[10,"genes"],", ",fixed=TRUE)[[1]] 42 | #pg<-panglao_make_ref() 43 | panglao_tsv="scrna/resources/PanglaoDB_markers_27_Mar_2020.tsv.gz" 44 | mk<-read.table(gzfile(panglao_tsv),header=TRUE,sep="\t") 45 | mk<-subset(mk,species %in% c("Mm","Mm Hs")) 46 | mk<-subset(mk,organ %in% c("Brain","Connective tissue","Epithelium","Olfactory system")) 47 | #,"Immune system","Vasculature","Blood" 48 | pg<-tapply(mk$official.gene.symbol,mk$cell.type,function(x){x},simplify=FALSE) 49 | # pg2<-pg[!(names(pg) %in% c("Plasma cells","Purkinje neurons","Delta cells", 50 | # bad<-c("Plasma cells","Purkinje neurons","Delta cells","Purkinje fiber cells","Erythroid-like and erythroid precurser cells") 51 | ct<-lapply(res$genes,genes2celltype,pg)#,ss=3) 52 | ct[is.na(ct)]<-"" 53 | res$celltypes<-paste(ct,sep=", ") 54 | write.csv(res,fp(pth,"results/vz_brn_nsfh20_spde_goterms.csv"),row.names=FALSE) 55 | ``` 56 | Shorter GO table 57 | ```{r} 58 | res<-read.csv(fp(pth,"results/vz_brn_nsfh20_spde_goterms.csv"),header=TRUE) 59 | res2<-res 60 | g<-strsplit(res2$genes,", ",fixed=TRUE) 61 | res2$genes<-sapply(g,function(x){paste(x[1:5],collapse=", ")}) 62 | gt<-strsplit(res2$go_bp,"; ",fixed=TRUE) 63 | res2$go_bp<-sapply(gt,function(x){paste(x[1:2],collapse=", ")}) 64 | write.csv(res2,fp(pth,"results/vz_brn_nsfh20_spde_goterms_short.csv"),row.names=FALSE) 65 | ``` 66 | 67 | GO analysis for hotspot 68 | ```{r} 69 | hs<-read.csv(fp(pth,"results/hotspot.csv"),header=TRUE) 70 | rownames(hs)<-hs$Gene 71 | Wh<-model.matrix(~0+factor(Module),data=subset(hs,Module!=-1)) 72 | colnames(Wh)<-paste0("X",1:ncol(Wh)) 73 | rownames(Wh)<-toupper(rownames(Wh)) 74 | # all(rownames(Wh) %in% rownames(W)) 75 | hres<-loadings2go(Wh,gene_2_GO, rowmean_divide=TRUE) 76 | colnames(hres)[which(colnames(hres)=="dim")]<-"cluster" 77 | write.csv(hres,fp(pth,"results/vz_brn_hotspot_goterms.csv"),row.names=FALSE) 78 | hres2<-hres 79 | g<-strsplit(hres2$genes,", ",fixed=TRUE) 80 | hres2$genes<-sapply(g,function(x){paste(x[1:5],collapse=", ")}) 81 | gt<-strsplit(hres2$go_bp,"; ",fixed=TRUE) 82 | hres2$go_bp<-sapply(gt,function(x){paste(x[1:2],collapse=", ")}) 83 | write.csv(hres2,fp(pth,"results/vz_brn_hotspot_goterms_short.csv"),row.names=FALSE) 84 | ``` 85 | -------------------------------------------------------------------------------- /scrna/visium_brain_sagittal/06_traditional.ipy: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # --- 3 | # jupyter: 4 | # jupytext: 5 | # text_representation: 6 | # extension: .ipy 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.6.0 10 | # kernelspec: 11 | # display_name: Python 3 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | #%% imports 17 | # import numpy as np 18 | import pandas as pd 19 | import scanpy as sc 20 | import matplotlib.pyplot as plt 21 | from os import path 22 | from hotspot import Hotspot 23 | 24 | from utils import misc,visualize 25 | 26 | dtp = "float32" 27 | pth = "scrna/visium_brain_sagittal" 28 | dpth = path.join(pth,"data") 29 | mpth = path.join(pth,"models") 30 | rpth = path.join(pth,"results") 31 | plt_pth = path.join(rpth,"plots") 32 | 33 | # %% Data Loading from scanpy 34 | J = 2000 35 | # dfile = path.join(dpth,"visium_brain_sagittal_J{}.h5ad".format(J)) 36 | dfile = path.join(dpth,"visium_brain_sagittal.h5ad") 37 | adata = sc.read_h5ad(dfile) 38 | 
sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=J) 39 | X0 = adata.obsm["spatial"] 40 | X0[:,1] = -X0[:,1] 41 | 42 | #%% Traditional scanpy analysis (unsupervised clustering) 43 | #https://scanpy-tutorials.readthedocs.io/en/latest/spatial/basic-analysis.html 44 | sc.pp.pca(adata) 45 | sc.pp.neighbors(adata) 46 | sc.tl.umap(adata) 47 | sc.tl.leiden(adata, resolution=1.0, key_added="clusters") 48 | plt.rcParams["figure.figsize"] = (4, 4) 49 | sc.pl.umap(adata, color="clusters", wspace=0.4) 50 | sc.pl.embedding(adata, "spatial", color="clusters") 51 | cl = pd.get_dummies(adata.obs["clusters"]).to_numpy() 52 | tgnames = [str(i) for i in range(1,cl.shape[1]+1)] 53 | hmkw = {"figsize":(10,8), "s":0.5, "marker":"D", "subplot_space":0, 54 | "spinecolor":"white"} 55 | fig,axes=visualize.multiheatmap(X0, cl, (4,5), **hmkw) 56 | visualize.set_titles(fig, tgnames, x=0.03, y=.88, fontsize="small", c="white", 57 | ha="left", va="top") 58 | fig.savefig(path.join(plt_pth,"vz_brn_heatmap_scanpy_clusters.pdf"), 59 | bbox_inches='tight') 60 | 61 | #%% Hotspot analysis (gene clusters) 62 | #https://hotspot.readthedocs.io/en/latest/Spatial_Tutorial.html 63 | J = 2000 64 | dfile = path.join(dpth,"visium_brain_sagittal_J{}.h5ad".format(J)) 65 | adata = sc.read_h5ad(dfile) 66 | adata.layers["counts"] = adata.layers["counts"].tocsc() 67 | hs = Hotspot(adata, layer_key="counts", model="danb", 68 | latent_obsm_key="spatial", umi_counts_obs_key="total_counts") 69 | hs.create_knn_graph(weighted_graph=False, n_neighbors=20) 70 | hs_results = hs.compute_autocorrelations() 71 | # hs_results.tail() 72 | hs_genes = hs_results.index#[hs_results.FDR < 0.05] 73 | lcz = hs.compute_local_correlations(hs_genes) 74 | modules = hs.create_modules(min_gene_threshold=20, core_only=False, 75 | fdr_threshold=0.05) 76 | # modules.value_counts() 77 | hs_results = hs_results.join(modules,how="left") 78 | hs_results.to_csv(path.join(rpth,"hotspot.csv")) 79 | misc.pickle_to_file(hs,path.join(mpth,"hotspot.pickle")) 80 | -------------------------------------------------------------------------------- /scrna/xyzeq_liver/01_data_loading.ipy: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # --- 3 | # jupyter: 4 | # jupytext: 5 | # text_representation: 6 | # extension: .ipy 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.6.0 10 | # kernelspec: 11 | # display_name: Python 3 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | # %% imports 17 | import random 18 | import numpy as np 19 | import pandas as pd 20 | import scanpy as sc 21 | from os import path 22 | from scipy import sparse 23 | from matplotlib import pyplot as plt 24 | 25 | from utils import preprocess, training, misc 26 | 27 | random.seed(101) 28 | pth = "scrna/xyzeq_liver" 29 | dpth = path.join(pth,"data") 30 | 31 | # %% Download original data files 32 | %%sh 33 | mkdir -p scrna/xyzeq_liver/data/original 34 | pushd scrna/xyzeq_liver/data/original 35 | wget ftp://ftp.ncbi.nlm.nih.gov/geo/samples/GSM5009nnn/GSM5009531/suppl/GSM5009531_L20C1.h5ad.gz 36 | unpigz GSM5009531_L20C1.h5ad.gz 37 | popd 38 | # Additional data files manually downloaded, provided by authors in email 39 | 40 | # %% Data Loading from original 41 | coords = pd.read_csv(path.join(dpth,"original","plate23_map.csv")) 42 | coords.rename(columns={coords.columns[0]: "barcode"},inplace=True) 43 | labs = pd.read_csv(path.join(dpth,"original","L20C1.csv")) 44 | #merge spatial coordinates with cell 
type labels and metadata 45 | labs["barcode"] = [w.split(".")[1] for w in labs["index"]] 46 | labs = labs.merge(coords,on="barcode") 47 | ad = sc.read_h5ad(path.join(dpth,"original","GSM5009531_L20C1.h5ad")) 48 | #match cell type labels with anndata rownames 49 | ad = ad[labs["index"]] 50 | labs.set_index("index",inplace=True,verify_integrity=True) 51 | ad.obs = labs 52 | mouse_cells = ad.obs["cell_call"]=="M" 53 | mouse_genes = ad.var_names.str.startswith("mm10_") 54 | ad = ad[mouse_cells,mouse_genes] #mouse cells only 55 | nz_genes = np.ravel(ad.X.sum(axis=0)>0) 56 | ad = ad[:,nz_genes] 57 | #rename genes to remove mm10_ prefix 58 | ad.var_names = ad.var_names.str.replace("mm10_","") 59 | #how many unique barcodes in this slice: only 289 60 | print("Unique locations: {}".format(len(ad.obs["barcode"].unique()))) 61 | ad.obsm["spatial"] = ad.obs[["X","Y"]].to_numpy() 62 | ad.obs.drop(columns=["X","Y"],inplace=True) 63 | X = ad.obsm["spatial"] 64 | #rectangle marker code: https://stackoverflow.com/a/62572367 65 | plt.scatter(X[:,0],X[:,1],marker='$\u25AE$',s=120) 66 | plt.gca().invert_yaxis() 67 | plt.title("Mouse cell locations") 68 | plt.show() 69 | 70 | #%% QC 71 | ad.var["mt"] = ad.var_names.str.startswith("mt-") 72 | sc.pp.calculate_qc_metrics(ad, qc_vars=["mt"], inplace=True) 73 | ad.obs.pct_counts_mt.hist() #all less than 2%, no need to filter cells 74 | #all cells and genes passed the below criteria 75 | ad = ad[ad.obs.pct_counts_mt < 20] 76 | sc.pp.filter_genes(ad, min_cells=1) 77 | sc.pp.filter_cells(ad, min_counts=100) 78 | ad.layers = {"counts":ad.X.copy()} #store raw counts before normalization changes ad.X 79 | sc.pp.normalize_total(ad, inplace=True, layers=None, key_added="sizefactor") 80 | sc.pp.log1p(ad) 81 | 82 | # %% normalization, feature selection and train/test split 83 | ad.var['deviance_poisson'] = preprocess.deviancePoisson(ad.layers["counts"]) 84 | o = np.argsort(-ad.var['deviance_poisson']) 85 | idx = list(range(ad.shape[0])) 86 | random.shuffle(idx) 87 | ad = ad[idx,o] 88 | ad.write_h5ad(path.join(dpth,"xyzeq_liver_L20C1_mouseonly.h5ad"),compression="gzip") 89 | #ad = sc.read_h5ad(path.join(dpth,"xyzeq_liver_L20C1_mouseonly.h5ad")) 90 | plt.plot(ad.var["deviance_poisson"].to_numpy()) 91 | ad2 = ad[:,:2000] 92 | ad2.write_h5ad(path.join(pth,"data/xyzeq_liver_L20C1_mouseonly_J2000.h5ad"),compression="gzip") 93 | -------------------------------------------------------------------------------- /scrna/xyzeq_liver/03_benchmark.ipy: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # --- 3 | # jupyter: 4 | # jupytext: 5 | # text_representation: 6 | # extension: .ipy 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.6.0 10 | # kernelspec: 11 | # display_name: Python 3 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | #%% 17 | import pandas as pd 18 | from os import path 19 | from utils import misc,benchmark 20 | 21 | pth = "scrna/xyzeq_liver" 22 | dpth = path.join(pth,"data") 23 | mpth = path.join(pth,"models/V5") 24 | rpth = path.join(pth,"results") 25 | misc.mkdir_p(rpth) 26 | 27 | #%% Create CSV with benchmarking parameters 28 | csv_path = path.join(rpth,"benchmark.csv") 29 | try: 30 | par = pd.read_csv(csv_path) 31 | except FileNotFoundError: 32 | L = [6,12,20] 33 | sp_mods = ["NSF-P", "NSF-N", "NSF-G", "RSF-G", "NSFH-P", "NSFH-N"] 34 | ns_mods = ["PNMF-P", "PNMF-N", "FA-G"] 35 | sz = ["constant","scanpy"] 36 | M = [288] 37 | par = 
benchmark.make_param_df(L,sp_mods,ns_mods,M,sz)
38 | par.to_csv(csv_path,index=False)
39 |
40 | #%% Benchmarking at command line [markdown]
41 | """
42 | To run on local computer, use
43 | `python -um utils.benchmark 28 scrna/xyzeq_liver/data/xyzeq_liver_L20C1_mouseonly_J2000.h5ad`
44 | where 28 is a row ID of benchmark.csv, min value 2, max possible value is 49
45 |
46 | To run on cluster first load anaconda environment
47 | ```
48 | tmux
49 | interactive
50 | module load anaconda3/2021.5
51 | conda activate fwt
52 | python -um utils.benchmark 2 scrna/xyzeq_liver/data/xyzeq_liver_L20C1_mouseonly_J2000.h5ad
53 | ```
54 |
55 | To run on cluster as a job array, subset of rows, recommend 2hr time limit
56 | ```
57 | DAT=./scrna/xyzeq_liver/data/xyzeq_liver_L20C1_mouseonly_J2000.h5ad
58 | sbatch --mem=48G --array=2-4 ./utils/benchmark_array.slurm $DAT
59 | ```
60 |
61 | To run on cluster as a job array, all rows of CSV file
62 | ```
63 | CSV=./scrna/xyzeq_liver/results/benchmark.csv
64 | DAT=./scrna/xyzeq_liver/data/xyzeq_liver_L20C1_mouseonly_J2000.h5ad
65 | sbatch --mem=48G --array=2-$(wc -l < $CSV) ./utils/benchmark_array.slurm $DAT
66 | ```
67 | """
68 |
69 | #%% Compute metrics for each model
70 | """
71 | DAT=./scrna/xyzeq_liver/data/xyzeq_liver_L20C1_mouseonly_J2000.h5ad
72 | sbatch --mem=48G ./utils/benchmark_gof.slurm $DAT 5
73 | """
74 |
75 | #%% Examine one result
76 | from matplotlib import pyplot as plt
77 | from utils import training
78 | tro = training.ModelTrainer.from_pickle(path.join(mpth,"L20/poi/NSFH_T10_MaternThreeHalves_M1000"))
79 | plt.plot(tro.loss["train"][-200:-1])
80 | -------------------------------------------------------------------------------- /scrna/xyzeq_liver/04_benchmark_viz.Rmd: --------------------------------------------------------------------------------
1 | ---
2 | title: "Benchmarking Visualization"
3 | author: "Will Townes"
4 | output: html_document
5 | ---
6 |
7 | ```{r}
8 | library(tidyverse)
9 | theme_set(theme_bw())
10 | fp<-file.path
11 | pth<-"scrna/xyzeq_liver"
12 | plt_pth<-fp(pth,"results/plots")
13 | if(!dir.exists(plt_pth)){
14 | dir.create(plt_pth,recursive=TRUE)
15 | }
16 | ```
17 |
18 | ```{r}
19 | d0<-read.csv(fp(pth,"results/benchmark.csv"))
20 | d0$converged<-as.logical(d0$converged)
21 | #d0$model<-plyr::mapvalues(d0$model,c("RCF","RPF","NCF","NPFH","NPF"),c("FA","RSF","PNMF","NSFH","NSF"))
22 | d<-subset(d0,converged==TRUE)
23 | d$model<-paste0(d$model," (",d$lik,")")
24 | unique(d$model)
25 | # keep<-c("NPF (nb)", "NPF (poi)", "NPFH (nb)", "NPFH (poi)", "NCF (nb)", "NCF (poi)", "RPF (gau)", "RCF (gau)")#"MEFISTO (gau)",
26 | keep<-c("FA (gau)","MEFISTO (gau)","RSF (gau)","PNMF (nb)","PNMF (poi)","NSFH (nb)","NSFH (poi)","NSF (nb)","NSF (poi)")
27 | d<-subset(d,model %in% keep)
28 | d$model<-factor(d$model,levels=keep)
29 | d$dim<-factor(d$L,levels=sort(unique(d$L)))
30 | d$M[is.na(d$M)]<-288
31 | d$IPs<-factor(d$M,levels=sort(unique(d$M)))
32 | ```
33 | subset of models for simplified main figure
34 | ```{r}
35 | d2<-subset(d,M==288 | (model %in% c("PNMF (poi)","FA (gau)")))
36 | d2a<-subset(d2,(model %in% c("PNMF (poi)","NSFH (poi)","NSF (poi)")) & sz=="scanpy")
37 | d2b<-subset(d2,model %in% c("FA (gau)","RSF (gau)","MEFISTO (gau)"))
38 | d2<-rbind(d2a,d2b)
39 | #d2$model<-factor(d2$model,levels=)
40 | d2$model<-plyr::mapvalues(d2$model,c("FA (gau)","MEFISTO (gau)","RSF (gau)","PNMF (poi)","NSFH (poi)","NSF (poi)"),c("FA","MEFISTO","RSF","PNMF","NSFH","NSF"))
41 |
ggplot(d2,aes(x=model,y=dev_val_mean,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.8))+ylab("validation deviance (mean)")+ylim(range(d2$dev_val_mean)*c(.95,1.02)) 42 | ggsave(fp(plt_pth,"xyzeq_liver_gof_dev_val_mean_main.pdf"),width=5,height=2.5) 43 | 44 | #linear regressions for significance 45 | d2$real_valued<-d2$model %in% c("FA","MEFISTO","RSF") 46 | d2$spatial_aware<-d2$model %in% c("MEFISTO","RSF","NSFH","NSF") 47 | summary(lm(dev_val_mean~real_valued+spatial_aware+dim,data=d2)) 48 | t.test(dev_val_mean~model,data=subset(d2,model %in% c("NSFH","NSF")), var.equal=TRUE, alternative="greater") 49 | 50 | ggplot(d2,aes(x=model,y=rmse_val,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.8))+ylab("validation RMSE")+ylim(range(d2$rmse_val)*c(.95,1.02)) 51 | ggsave(fp(plt_pth,"xyzeq_liver_gof_rmse_val_simple.pdf"),width=5,height=2.5) 52 | 53 | ggplot(d2,aes(x=model,y=sparsity,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.8))+ylab("loadings zero fraction")+ylim(range(d2$sparsity)*c(.95,1.02)) 54 | ggsave(fp(plt_pth,"xyzeq_liver_sparsity_simple.pdf"),width=5,height=2.5) 55 | ``` 56 | 57 | NSF vs NSFH 58 | 59 | ```{r} 60 | d3<-subset(d,grepl("NSF",model) & sz=="scanpy") 61 | ggplot(d3,aes(x=model,y=dev_tr_mean,color=dim))+geom_point(size=4,position=position_dodge(width=0.8))+ylab("training deviance (mean)") 62 | ggsave(fp(plt_pth,"xyzeq_liver_gof_dev_tr_nsf_vs_nsfh.pdf"),width=6,height=3) 63 | ``` 64 | 65 | ```{r} 66 | #training deviance - average 67 | ggplot(d,aes(x=model,y=dev_tr_mean,color=dim,shape=sz))+geom_jitter(size=5,width=.2,height=0)+ylab("training deviance (mean)")+theme(legend.position = "top") 68 | ggsave(fp(plt_pth,"xyzeq_liver_gof_dev_tr_mean.pdf"),width=6,height=3) 69 | 70 | #training deviance - max 71 | ggplot(d,aes(x=model,y=dev_tr_max,color=dim,shape=sz))+geom_jitter(size=5,width=.2,height=0)+ylab("training deviance (max)")#+scale_y_log10() 72 | 73 | #validation deviance - mean 74 | ggplot(d,aes(x=model,y=dev_val_mean,color=dim,shape=sz))+geom_jitter(size=4,width=.4,height=0)+ylab("validation deviance (mean)")+theme(legend.position="none") 75 | ggsave(fp(plt_pth,"xyzeq_liver_gof_dev_val_mean.pdf"),width=6,height=2.5) 76 | 77 | #validation deviance - max 78 | ggplot(d,aes(x=model,y=dev_val_max,color=dim,shape=sz))+geom_jitter(size=3,width=.2,height=0)+ylab("validation deviance (max)")#+scale_y_log10() 79 | 80 | #sparsity 81 | ggplot(d,aes(x=model,y=sparsity,color=dim,shape=sz))+geom_jitter(size=4,width=.3,height=0)+ylab("sparsity of loadings")+theme(legend.position="none") 82 | ggsave(fp(plt_pth,"xyzeq_liver_sparsity.pdf"),width=6,height=2.5) 83 | 84 | #wall time 85 | ggplot(d,aes(x=model,y=wtime/60,color=dim,shape=sz))+geom_jitter(size=5,width=.2,height=0)+ylab("wall time (min)")+scale_y_log10()+theme(legend.position="top") 86 | ggsave(fp(plt_pth,"xyzeq_liver_wtime.pdf"),width=6,height=3) 87 | ggplot(d,aes(x=dim,y=wtime/60,color=sz))+geom_jitter(size=3,width=.2,height=0)+ylab("wall time (min)")+scale_y_log10() 88 | 89 | #processor time 90 | ggplot(d,aes(x=model,y=ptime/60,color=dim,shape=sz))+geom_jitter(size=5,width=.2,height=0)+ylab("processor time")+scale_y_log10()+theme(legend.position="none") 91 | ggsave(fp(plt_pth,"xyzeq_liver_ptime.pdf"),width=6,height=2.5) 92 | ``` 93 | -------------------------------------------------------------------------------- /scrna/xyzeq_liver/05_interpret_genes.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: 
"05_interpretation_genes" 3 | author: "Will Townes" 4 | output: html_document 5 | --- 6 | 7 | ```{r} 8 | # library(tidyverse) 9 | library(biomaRt) 10 | # library(rPanglaoDB) 11 | source("./scrna/utils/interpret_genes.R") 12 | 13 | fp<-file.path 14 | pth<-"scrna/xyzeq_liver" 15 | ``` 16 | loadings from NSFH 17 | ```{r} 18 | W<-read.csv(fp(pth,"results/xyz_liv_nsfh6_spde_loadings.csv"),header=TRUE,row.names=1) 19 | rownames(W)<-toupper(rownames(W)) 20 | W2<-t(apply(as.matrix(W),1,function(x){as.numeric(x==max(x))})) 21 | colSums(W2) #clustering of genes 22 | ``` 23 | Gene Ontology terms, used for both NSFH and hotspot. 24 | ```{r} 25 | # bg<-rownames(W) 26 | db<-useMart("ensembl",host="https://may2021.archive.ensembl.org",dataset='mmusculus_gene_ensembl') 27 | go_ids<-getBM(attributes=c('go_id', 'external_gene_name','namespace_1003'), filters='external_gene_name', values=rownames(W), mart=db) 28 | go_ids[,2]<-toupper(go_ids[,2]) 29 | gene_2_GO<-unstack(go_ids[,c(1,2)]) 30 | ``` 31 | GO analysis for NSFH 32 | ```{r} 33 | # g<-strsplit(res[l,"genes"],", ",fixed=TRUE)[[1]] 34 | # expr <- getMarkers(include = g[c(1:2,4,6,7)])$ 35 | res<-loadings2go(W,gene_2_GO) 36 | res$type<-rep(c("spat","nsp"),each=3) 37 | res$dim<-rep(1:3,2) 38 | write.csv(res[,c(1,4,2,3)],fp(pth,"results/xyz_liv_nsfh6_spde_goterms.csv"),row.names=FALSE) 39 | ``` 40 | Get cell types from [Panglao](https://panglaodb.se). First manually download the list of all marker genes for all cell types from https://panglaodb.se/markers.html?cell_type=%27all_cells%27 . Store this TSV in the scrna/resources folder. 41 | ```{r} 42 | res<-read.csv(fp(pth,"results/xyz_liv_nsfh6_spde_goterms.csv"),header=TRUE) 43 | panglao_tsv="scrna/resources/PanglaoDB_markers_27_Mar_2020.tsv.gz" 44 | mk<-read.table(gzfile(panglao_tsv),header=TRUE,sep="\t") 45 | mk<-subset(mk,species %in% c("Mm","Mm Hs")) 46 | mk<-subset(mk,organ %in% c("Liver","Immune system","Vasculature","Blood")) 47 | pg<-tapply(mk$official.gene.symbol,mk$cell.type,function(x){x},simplify=FALSE) 48 | ct<-lapply(res$genes,genes2celltype,pg,ss=1) 49 | ct[is.na(ct)]<-"" 50 | #the automated search did not work so we tried a manual search on the panglao website 51 | ct[[1]]<-"Hepatocytes" 52 | ct[[3]]<-"Macrophages" 53 | ct[[4]]<-"Fibroblasts" 54 | ct[[6]]<-"Macrophages" 55 | res$celltypes<-paste(ct,sep=", ") 56 | write.csv(res,fp(pth,"results/xyz_liv_nsfh6_spde_goterms.csv"),row.names=FALSE) 57 | ``` 58 | Shorter GO table 59 | ```{r} 60 | res<-read.csv(fp(pth,"results/xyz_liv_nsfh6_spde_goterms.csv"),header=TRUE) 61 | res2<-res 62 | g<-strsplit(res2$genes,", ",fixed=TRUE) 63 | res2$genes<-sapply(g,function(x){paste(x[1:5],collapse=", ")}) 64 | gt<-strsplit(res2$go_bp,"; ",fixed=TRUE) 65 | res2$go_bp<-sapply(gt,function(x){paste(x[1:2],collapse=", ")}) 66 | write.csv(res2,fp(pth,"results/xyz_liv_nsfh6_spde_goterms_short.csv"),row.names=FALSE) 67 | ``` 68 | 69 | GO analysis for hotspot 70 | ```{r} 71 | hs<-read.csv(fp(pth,"results/hotspot.csv"),header=TRUE) 72 | rownames(hs)<-hs$Gene 73 | Wh<-model.matrix(~0+factor(Module),data=subset(hs,Module!=-1)) 74 | colnames(Wh)<-paste0("X",1:ncol(Wh)) 75 | rownames(Wh)<-toupper(rownames(Wh)) 76 | # all(rownames(Wh) %in% rownames(W)) 77 | hres<-loadings2go(Wh,gene_2_GO, rowmean_divide=TRUE) 78 | colnames(hres)[which(colnames(hres)=="dim")]<-"cluster" 79 | write.csv(hres,fp(pth,"results/xyz_liv_hotspot_goterms.csv"),row.names=FALSE) 80 | hres2<-hres 81 | g<-strsplit(hres2$genes,", ",fixed=TRUE) 82 | 
hres2$genes<-sapply(g,function(x){paste(x[1:5],collapse=", ")}) 83 | gt<-strsplit(hres2$go_bp,"; ",fixed=TRUE) 84 | hres2$go_bp<-sapply(gt,function(x){paste(x[1:2],collapse=", ")}) 85 | write.csv(hres2,fp(pth,"results/xyz_liv_hotspot_goterms_short.csv"),row.names=FALSE) 86 | ``` 87 | -------------------------------------------------------------------------------- /scrna/xyzeq_liver/06_traditional.ipy: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # --- 3 | # jupyter: 4 | # jupytext: 5 | # text_representation: 6 | # extension: .ipy 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.6.0 10 | # kernelspec: 11 | # display_name: Python 3 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | #%% imports 17 | import numpy as np 18 | import pandas as pd 19 | import scanpy as sc 20 | import matplotlib.pyplot as plt 21 | from os import path 22 | from hotspot import Hotspot 23 | 24 | from utils import misc,visualize 25 | 26 | dtp = "float32" 27 | pth = "scrna/xyzeq_liver" 28 | dpth = path.join(pth,"data") 29 | mpth = path.join(pth,"models") 30 | rpth = path.join(pth,"results") 31 | plt_pth = path.join(rpth,"plots") 32 | 33 | # %% Data Loading from scanpy 34 | J = 2000 35 | dfile = path.join(dpth,"xyzeq_liver_L20C1_mouseonly.h5ad") 36 | adata = sc.read_h5ad(dfile) 37 | sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=J) 38 | X0 = adata.obsm["spatial"] 39 | X0[:,1] = -X0[:,1] 40 | Z = np.unique(X0, axis=0) 41 | #find mapping between Xtr and Z 42 | from scipy.spatial.distance import cdist 43 | ZX = 1-(cdist(Z,X0)>0) #Ntr x M matrix 44 | ZX2 = ZX/ZX.sum(axis=1)[:,None] #premultiply by this to average within locations 45 | 46 | #%% Traditional scanpy analysis (unsupervised clustering) 47 | #https://scanpy-tutorials.readthedocs.io/en/latest/spatial/basic-analysis.html 48 | sc.pp.pca(adata) 49 | sc.pp.neighbors(adata) 50 | sc.tl.umap(adata) 51 | sc.tl.leiden(adata, resolution=1.0, key_added="clusters") 52 | plt.rcParams["figure.figsize"] = (4, 4) 53 | sc.pl.umap(adata, color="clusters", wspace=0.4) 54 | sc.pl.embedding(adata, "spatial", color="clusters") 55 | cl = pd.get_dummies(adata.obs["clusters"]).to_numpy() 56 | 57 | #%% Visualize clustering 58 | hmkw = {"figsize":(6.5,3.5),"subplot_space":0,"spinecolor":"white","marker":"$\u25AE$"} 59 | tgnames = [str(i) for i in range(1,cl.shape[1]+1)] 60 | fig,axes=visualize.multiheatmap(Z, ZX2@cl, (3,5), s=10, **hmkw) 61 | visualize.set_titles(fig, tgnames, x=0.05, y=.85, fontsize="small", c="white", 62 | ha="left", va="top") 63 | fig.savefig(path.join(plt_pth,"xyz_liv_heatmap_scanpy_clusters.pdf"), 64 | bbox_inches='tight') 65 | 66 | #%% Hotspot analysis (gene clusters) 67 | #https://hotspot.readthedocs.io/en/latest/Spatial_Tutorial.html 68 | J = 2000 69 | dfile = path.join(dpth,"xyzeq_liver_L20C1_mouseonly_J{}.h5ad".format(J)) 70 | adata = sc.read_h5ad(dfile) 71 | adata.layers["counts"] = adata.layers["counts"].tocsc() 72 | hs = Hotspot(adata, layer_key="counts", model="danb", 73 | latent_obsm_key="spatial", umi_counts_obs_key="total_counts") 74 | hs.create_knn_graph(weighted_graph=False, n_neighbors=20) 75 | hs_results = hs.compute_autocorrelations() 76 | # hs_results.tail() 77 | hs_genes = hs_results.index#[hs_results.FDR < 0.05] 78 | lcz = hs.compute_local_correlations(hs_genes) 79 | modules = hs.create_modules(min_gene_threshold=20, core_only=False, 80 | fdr_threshold=0.05) 81 | # modules.value_counts() 82 | hs_results = 
hs_results.join(modules,how="left") 83 | hs_results.to_csv(path.join(rpth,"hotspot.csv")) 84 | misc.pickle_to_file(hs,path.join(mpth,"hotspot.pickle")) 85 | 86 | #%% Hotspot module scores (optional) 87 | import mplscience 88 | module_scores = hs.calculate_module_scores() 89 | # module_scores.head() 90 | module_cols = [] 91 | for c in module_scores.columns: 92 | key = f"Module {c}" 93 | adata.obs[key] = module_scores[c] 94 | module_cols.append(key) 95 | with mplscience.style_context(): 96 | sc.pl.spatial(adata, color=module_cols, frameon=False, vmin="p0", vmax="p99", spot_size=1) 97 | -------------------------------------------------------------------------------- /scrna/xyzeq_liver/results/benchmark.csv: -------------------------------------------------------------------------------- 1 | V,L,model,kernel,M,lik,sz,key,converged,epochs,ptime,wtime,elbo_avg_tr,dev_tr_mean,dev_tr_argmax,dev_tr_max,dev_tr_med,rmse_tr,dev_val_mean,dev_val_argmax,dev_val_max,dev_val_med,rmse_val,sparsity,elbo_avg_val 2 | 5,6,MEFISTO,ExponentiatedQuadratic,288,gau,constant,V5/L6/gau/MEFISTO_ExponentiatedQuadratic_M288,TRUE,33,475.0265762,40.65122175,-577.3475467,4.134167,1,15.144811,4.1147914,1.9141325,4.113402888,1,17.16539548,4.098728814,1.963800171,0.22425, 3 | 5,12,MEFISTO,ExponentiatedQuadratic,288,gau,constant,V5/L12/gau/MEFISTO_ExponentiatedQuadratic_M288,TRUE,33,980.6648019,85.63033605,-582.2198654,4.120909,1,15.31339,4.101304,1.9123095,4.113208661,1,16.21541236,4.104301253,1.968408833,0.602541667, 4 | 5,20,MEFISTO,ExponentiatedQuadratic,288,gau,constant,V5/L20/gau/MEFISTO_ExponentiatedQuadratic_M288,TRUE,34,1491.377479,129.0016003,-589.3510638,4.136579,1,15.082991,4.117299,1.9100683,4.143368165,1,16.62485705,4.128560757,1.971220716,0.76875, 5 | 5,6,PNMF,,,nb,constant,V5/L6/nb_sz-constant/PNMF,TRUE,920,2524.103027,246.219986,-1720.036011,5.041337,0,6.126661,5.0378885,1.7476368,7.603628159,3,19.91540909,7.524975777,3.220266819,0.198833333,-1798.789551 6 | 5,6,PNMF,,,nb,scanpy,V5/L6/nb_sz-scanpy/PNMF,TRUE,940,2491.810059,243.1531219,-1517.83667,5.0389776,0,6.230536,5.0354776,1.7511828,6.049181461,1,34.24664688,5.948660374,5.179713249,0.203083333,-1606.150024 7 | 5,6,PNMF,,,poi,constant,V5/L6/poi_sz-constant/PNMF,TRUE,360,301.4029236,39.87991333,-2776.731445,5.0833573,5,5.1073666,5.083555,1.6904502,8.216876984,0,10.31153202,8.245800018,2.171020985,0.2005,-2931.80835 8 | 5,6,PNMF,,,poi,scanpy,V5/L6/poi_sz-scanpy/PNMF,TRUE,330,226.1258087,30.06169701,-2404.501709,5.074917,5,5.0900927,5.074988,1.6942964,6.917968273,0,13.38948631,6.929588318,2.131427288,0.202416667,-2541.011475 9 | 5,12,PNMF,,,nb,constant,V5/L12/nb_sz-constant/PNMF,TRUE,940,2676.507324,259.4359436,-1683.691284,4.868257,0,7.232113,4.8613777,1.7105677,7.549126148,3,27.5600853,7.375882626,6.447946072,0.326208333,-1761.467041 10 | 5,12,PNMF,,,nb,scanpy,V5/L12/nb_sz-scanpy/PNMF,TRUE,910,2454.509033,245.877243,-1494.77002,4.867705,0,7.1815634,4.8610163,1.7017567,6.839495182,3,54.47519684,6.406699181,21.46104813,0.338625,-1582.869751 11 | 5,12,PNMF,,,poi,constant,V5/L12/poi_sz-constant/PNMF,TRUE,370,270.802063,35.36241531,-2704.347412,4.9247146,0,5.0390167,4.9248867,1.5932386,8.489975929,0,12.06497669,8.515666962,2.344451904,0.320958333,-2869.649658 12 | 5,12,PNMF,,,poi,scanpy,V5/L12/poi_sz-scanpy/PNMF,TRUE,370,261.4364014,35.36388016,-2378.112793,4.9237103,0,5.0717387,4.9236794,1.5926809,7.006500721,0,11.88849831,7.025444031,2.141976118,0.322875,-2502.074951 13 | 
5,20,PNMF,,,nb,constant,V5/L20/nb_sz-constant/PNMF,TRUE,930,2510.842773,246.0563812,-1652.982666,4.718672,2,8.630838,4.705968,1.9244496,7.915697575,1,12.51301765,7.900341988,2.530363083,0.41855,-1709.746704 14 | 5,20,PNMF,,,nb,scanpy,V5/L20/nb_sz-scanpy/PNMF,TRUE,900,2409.625244,236.7058563,-1468.537842,4.7337065,2,8.253958,4.721704,1.8141147,6.628312588,1,13.39372826,6.609304428,2.209197044,0.4285,-1550.239258 15 | 5,20,PNMF,,,poi,constant,V5/L20/poi_sz-constant/PNMF,TRUE,370,271.4530029,36.1639328,-2636.20874,4.788446,4,4.8216515,4.788552,1.4579074,8.605758667,0,11.07707977,8.640348434,2.193706274,0.42125,-2758.330078 16 | 5,20,PNMF,,,poi,scanpy,V5/L20/poi_sz-scanpy/PNMF,TRUE,320,221.6009521,30.11948395,-2283.694824,4.799126,4,4.849787,4.7993846,1.4611748,7.195641518,0,11.38509369,7.226406574,2.085646868,0.424725,-2395.852295 17 | 5,6,NSF,MaternThreeHalves,288,gau,constant,V5/L6/gau/NSF_MaternThreeHalves_M288,TRUE,330,514.3267212,67.1272583,-22716.13086,4.0782957,1,16.631784,4.0586843,1.9377972,4.038724422,1,19.72306061,4.026765347,2.007090092,0.182166667,-24909.04883 18 | 5,6,NSF,MaternThreeHalves,288,nb,constant,V5/L6/nb_sz-constant/NSF_MaternThreeHalves_M288,TRUE,1080,4598.710449,448.9979553,-1197.514282,6.1589003,1,6.669904,6.1590214,1.9816872,6.435605526,0,7.976939678,6.452908039,2.02231884,0.181416667,-1290.170288 19 | 5,6,NSF,MaternThreeHalves,288,nb,scanpy,V5/L6/nb_sz-scanpy/NSF_MaternThreeHalves_M288,TRUE,980,3909.734863,400.6589661,-1157.806885,5.4424324,0,7.322513,5.4386325,1.8403312,5.85012722,0,8.568701744,5.868012428,1.899519563,0.1945,-1247.425659 20 | 5,6,NSF,MaternThreeHalves,288,poi,constant,V5/L6/poi_sz-constant/NSF_MaternThreeHalves_M288,TRUE,190,319.7627563,41.72207642,-1945.525513,6.1958456,0,6.313511,6.1993403,1.9779277,6.510575294,0,7.618911266,6.531297684,2.039809704,0.204666667,-2105.933594 21 | 5,6,NSF,MaternThreeHalves,288,poi,scanpy,V5/L6/poi_sz-scanpy/NSF_MaternThreeHalves_M288,TRUE,160,271.1428833,36.14931488,-1750.665161,5.4353385,0,5.6069202,5.436984,1.7997899,5.829191208,0,7.138292313,5.85113287,1.86504066,0.228416667,-1879.915771 22 | 5,12,NSF,MaternThreeHalves,288,gau,constant,V5/L12/gau/NSF_MaternThreeHalves_M288,TRUE,330,729.2949829,81.42583466,-22867.52734,4.06159,1,16.795048,4.042221,1.9832529,4.002580643,1,22.01260757,3.99038434,2.108330011,0.29475,-25116.64063 23 | 5,12,NSF,MaternThreeHalves,288,nb,constant,V5/L12/nb_sz-constant/NSF_MaternThreeHalves_M288,TRUE,1060,5142.533691,495.576355,-1193.328491,6.0695033,0,12.234192,6.0640707,2.4376976,6.468214989,0,14.01273441,6.480208397,2.521373749,0.247791667,-1316.597778 24 | 5,12,NSF,MaternThreeHalves,288,nb,scanpy,V5/L12/nb_sz-scanpy/NSF_MaternThreeHalves_M288,TRUE,2570,12512.71582,1208.075195,-1247.108032,5.347305,0,13.777866,5.3390913,2.9746828,5.852240086,0,16.2920723,5.864924431,2.987389803,0.245,-1377.313843 25 | 5,12,NSF,MaternThreeHalves,288,poi,constant,V5/L12/poi_sz-constant/NSF_MaternThreeHalves_M288,TRUE,170,395.9271851,46.07463837,-1936.16333,6.1576486,0,6.3947344,6.162737,1.9496453,6.505647659,0,8.054733276,6.526143074,2.027067661,0.312333333,-2123.313965 26 | 5,12,NSF,MaternThreeHalves,288,poi,scanpy,V5/L12/poi_sz-scanpy/NSF_MaternThreeHalves_M288,TRUE,210,481.3469238,61.67857361,-1740.683472,5.406471,18,5.4939284,5.4104576,1.7529149,5.860135555,0,7.361698627,5.884509563,1.859587193,0.35775,-1902.700195 27 | 
5,20,NSF,MaternThreeHalves,288,gau,constant,V5/L20/gau/NSF_MaternThreeHalves_M288,TRUE,330,966.2658691,107.8690338,-23064.33203,4.0386744,1,15.490185,4.017086,2.0383823,3.966190338,1,16.15410233,3.954101324,2.110280037,0.390325,-25329.31445 28 | 5,20,NSF,MaternThreeHalves,288,nb,constant,V5/L20/nb_sz-constant/NSF_MaternThreeHalves_M288,TRUE,1090,5905.898926,600.9832153,-1186.449829,5.9732676,0,18.473352,5.9590197,3.865973,6.459947586,0,20.63406563,6.465190887,3.858034372,0.28045,-1337.837524 29 | 5,20,NSF,MaternThreeHalves,288,nb,scanpy,V5/L20/nb_sz-scanpy/NSF_MaternThreeHalves_M288,TRUE,970,5574.790527,538.3094482,-1146.770874,5.2692337,0,20.756817,5.250762,5.047426,5.878831387,0,25.57594109,5.885680199,4.886908531,0.285075,-1294.078247 30 | 5,20,NSF,MaternThreeHalves,288,poi,constant,V5/L20/poi_sz-constant/NSF_MaternThreeHalves_M288,TRUE,130,421.626709,48.69716644,-1932.29895,6.166916,226,6.225311,6.1753907,1.9243623,6.500660896,0,7.525696278,6.523514748,1.995321751,0.35825,-2144.512207 31 | 5,20,NSF,MaternThreeHalves,288,poi,scanpy,V5/L20/poi_sz-scanpy/NSF_MaternThreeHalves_M288,TRUE,220,686.5084839,74.90756226,-1719.314697,5.290785,0,5.424759,5.290799,1.684984,5.8166008,0,7.393985748,5.837240696,1.810169578,0.429025,-1924.939697 32 | 5,6,NSFH,MaternThreeHalves,288,nb,constant,V5/L6/nb_sz-constant/NSFH_T3_MaternThreeHalves_M288,TRUE,1010,4063.513672,400.4521179,-1519.04895,5.208853,3,31.806583,5.1926947,2.4516542,6.999220848,3,45.26802063,6.974718094,2.682162762,0.211166667,-1612.036987 33 | 5,6,NSFH,MaternThreeHalves,288,nb,scanpy,V5/L6/nb_sz-scanpy/NSFH_T3_MaternThreeHalves_M288,TRUE,950,3661.553711,372.4749756,-1265.884644,5.2266626,0,7.021208,5.224578,1.7982568,6.028244019,0,8.544682503,6.021720886,2.196630001,0.24175,-1350.790039 34 | 5,6,NSFH,MaternThreeHalves,288,poi,constant,V5/L6/poi_sz-constant/NSFH_T3_MaternThreeHalves_M288,TRUE,550,810.0413208,105.2309952,-2577.969727,5.2439775,3,10.560041,5.2404513,1.8573103,7.400186062,3,17.4361763,7.400484085,2.282255411,0.21225,-2746.791992 35 | 5,6,NSFH,MaternThreeHalves,288,poi,scanpy,V5/L6/poi_sz-scanpy/NSFH_T3_MaternThreeHalves_M288,TRUE,230,344.7897034,47.61157227,-1875.325195,5.2310743,0,5.37592,5.231882,1.7471918,6.22217989,0,7.375682354,6.249918938,1.910659194,0.287166667,-1992.292114 36 | 5,12,NSFH,MaternThreeHalves,288,nb,constant,V5/L12/nb_sz-constant/NSFH_T6_MaternThreeHalves_M288,TRUE,960,4144.26123,408.4788513,-1505.599365,5.036107,0,35.802643,5.0135765,2.824309,7.508047104,0,42.9197998,7.464859009,3.275102854,0.295041667,-1601.634155 37 | 5,12,NSFH,MaternThreeHalves,288,nb,scanpy,V5/L12/nb_sz-scanpy/NSFH_T6_MaternThreeHalves_M288,TRUE,940,3446.153076,343.68396,-1249.673828,5.1136417,0,14.9740715,5.107809,1.995267,6.045698643,0,22.48182297,6.050631046,2.206949472,0.332375,-1340.674072 38 | 5,12,NSFH,MaternThreeHalves,288,poi,constant,V5/L12/poi_sz-constant/NSFH_T6_MaternThreeHalves_M288,TRUE,480,724.5327148,89.72132111,-2447.5979,5.072291,0,6.2428656,5.071582,1.6995671,7.525919437,0,10.80947208,7.541521549,2.137618065,0.310208333,-2571.26001 39 | 5,12,NSFH,MaternThreeHalves,288,poi,scanpy,V5/L12/poi_sz-scanpy/NSFH_T6_MaternThreeHalves_M288,TRUE,260,421.4099426,55.64730453,-1839.009766,5.132295,0,6.134202,5.132534,1.6707282,6.011397362,0,10.5695591,6.034472466,1.912448764,0.378625,-1958.931885 40 | 
5,20,NSFH,MaternThreeHalves,288,nb,constant,V5/L20/nb_sz-constant/NSFH_T10_MaternThreeHalves_M288,TRUE,930,3572.945557,359.5145569,-1486.776611,4.953084,0,11.2164545,4.943058,2.6832116,7.717191219,0,16.1084938,7.714572906,2.997145176,0.37005,-1586.578613 41 | 5,20,NSFH,MaternThreeHalves,288,nb,scanpy,V5/L20/nb_sz-scanpy/NSFH_T10_MaternThreeHalves_M288,TRUE,930,3610.885986,363.6177673,-1275.165161,4.9190106,0,26.494144,4.9006205,2.5711167,6.319021702,0,33.04694748,6.318647385,2.331879377,0.40405,-1384.026123 42 | 5,20,NSFH,MaternThreeHalves,288,poi,constant,V5/L20/poi_sz-constant/NSFH_T10_MaternThreeHalves_M288,TRUE,330,664.0027466,85.68494415,-2215.143066,5.0840487,0,5.842009,5.0865946,1.6648625,7.702683449,0,8.923201561,7.737083435,2.080332994,0.4205,-2381.046387 43 | 5,20,NSFH,MaternThreeHalves,288,poi,scanpy,V5/L20/poi_sz-scanpy/NSFH_T10_MaternThreeHalves_M288,TRUE,300,586.2957764,70.1000824,-1871.017578,5.001067,0,5.0868735,5.0011663,1.601763,6.247176647,0,7.752833366,6.273864746,1.897039056,0.4557,-2026.676025 44 | 5,6,FA,,,gau,constant,V5/L6/gau/FA,TRUE,330,259.3052979,36.04651642,-25486.67188,3.684314,1,12.101881,3.6632836,1.8165487,4.752997875,1,21.94866753,4.741261005,2.187690258,0,-27738.12109 45 | 5,12,FA,,,gau,constant,V5/L12/gau/FA,TRUE,520,374.788208,47.83795547,-26551.11719,3.585647,1,13.961616,3.5635648,1.7608483,4.825815201,1,21.52416992,4.813059807,2.269764185,0.000125,-28805.42578 46 | 5,20,FA,,,gau,constant,V5/L20/gau/FA,TRUE,560,409.4586487,53.65220261,-27751.89844,3.5128248,1,14.16262,3.4883509,1.7047457,4.784780502,1,21.31032944,4.769587517,2.233421087,0.0001,-30139.4668 47 | 5,6,RSF,MaternThreeHalves,288,gau,constant,V5/L6/gau/RSF_MaternThreeHalves_M288,TRUE,330,515.1203613,66.48441315,-23721.80469,4.062435,1,15.900438,4.042738,1.9017575,3.97942543,1,18.46582603,3.969529152,1.962098241,0,-25971.41602 48 | 5,12,RSF,MaternThreeHalves,288,gau,constant,V5/L12/gau/RSF_MaternThreeHalves_M288,TRUE,330,710.8594971,82.67900085,-23931.0625,4.034443,1,15.513014,4.014183,1.8843172,3.930668831,1,18.57084656,3.918994904,1.947594881,0.000125,-26235.27734 49 | 5,20,RSF,MaternThreeHalves,288,gau,constant,V5/L20/gau/RSF_MaternThreeHalves_M288,TRUE,330,823.336853,91.08987427,-24120.2793,4.0221286,1,15.5932865,4.0022526,1.865359,3.934701681,1,17.80091286,3.926838636,1.922031164,0,-26500.32617 -------------------------------------------------------------------------------- /simulations/.gitignore: -------------------------------------------------------------------------------- 1 | models 2 | -------------------------------------------------------------------------------- /simulations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/willtownes/nsf-paper/0cacf8352e09d223ab8d4421025195358bbde8df/simulations/__init__.py -------------------------------------------------------------------------------- /simulations/benchmark.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Example usage: 5 | 6 | python -m simulations.benchmark 2 simulations/bm_sp 7 | 8 | Fit a model to a dataset and save the results. 9 | 10 | Slightly different from utils.benchmark, which assumes there is only one possible 11 | data file under the parent directory. Here, we allow the data directory to contain 12 | multiple H5AD files and ensure the saved models indicate the scenario in the "key" (filepath). 
13 | 14 | Result: 15 | * Loads data, fits a model 16 | * pickles a fitted ModelTrainer object under [dataset]/models/ directory 17 | 18 | Pickled model naming conventions: 19 | 20 | file scheme for spatial models: 21 | [dataset]/models/S[scenario]/V[val frac]/L[factors]/[likelihood]/[model]_[kernel]_M[inducing_pts]/epoch[epoch].pickle 22 | 23 | file scheme for nonspatial models: 24 | [dataset]/models/S[scenario]/V[val frac]/L[factors]/[model]/epoch[epoch].pickle 25 | 26 | @author: townesf 27 | """ 28 | from os import path 29 | from argparse import ArgumentParser 30 | from utils.misc import read_csv_oneline 31 | from utils import benchmark as ubm 32 | 33 | def benchmark(ID,pth): 34 | """ 35 | Run benchmarking on dataset for the model specified in benchmark.csv in row ID. 36 | """ 37 | # dsplit = dataset.split("/data/") 38 | # pth = dsplit[0] 39 | csv_file = path.join(pth,"results/benchmark.csv") 40 | #header of CSV is row zero 41 | p = read_csv_oneline(csv_file,ID-1) 42 | opath = path.join(pth,"models",p["key"]) #p["key"] includes "S{}/" for simulations 43 | print("{}".format(p["key"])) 44 | if path.isfile(path.join(opath,"converged.pickle")): 45 | print("Benchmark already complete, exiting.") 46 | return None 47 | else: 48 | print("Starting benchmark.") 49 | train_frac = ubm.val2train_frac(p["V"]) 50 | dataset = path.join(pth,"data/S{}.h5ad".format(p["scenario"])) #different in simulations 51 | D,fmeans = ubm.load_data(dataset,model=p['model'],lik=p['lik'],sz=p['sz'], 52 | train_frac=train_frac, flip_yaxis=False) 53 | fit = ubm.init_model(D,p,opath,fmeans=fmeans) 54 | tro = ubm.fit_model(D,fit,p,opath) 55 | return tro 56 | 57 | def arghandler(args=None): 58 | """parses a list of arguments (default is sys.argv[1:])""" 59 | parser = ArgumentParser() 60 | parser.add_argument("id", type=int, 61 | help="line in benchmark csv from which to get parameters") 62 | parser.add_argument("path", type=str, 63 | help="top level directory containing with subfolders 'data', 'models', and 'results'.") 64 | args = parser.parse_args(args) #if args is None, this will automatically parse sys.argv[1:] 65 | return args 66 | 67 | if __name__=="__main__": 68 | # #input from argparse 69 | # ID = 2 #slurm job array uses 1-based indexing (2,3,...,43) 70 | # #start with 2 to avoid header row of CSV 71 | # DATASET = "simulations/bm_sp/data/S1.h5ad" 72 | args = arghandler() 73 | tro = benchmark(args.id, args.path) 74 | -------------------------------------------------------------------------------- /simulations/benchmark.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=benchmark # create a short name for your job 3 | #SBATCH --nodes=1 # node count 4 | #SBATCH --ntasks=1 # total number of tasks across all nodes 5 | #SBATCH --cpus-per-task=4 # cpu-cores per task (>1 if multi-threaded tasks) 6 | #SBATCH --time=01:02:00 # total run time limit (HH:MM:SS) 7 | #SBATCH --mail-type=END,FAIL 8 | #SBATCH --mail-user=ftownes@princeton.edu 9 | 10 | #example usage, default is 4G memory per core (--mem-per-cpu), --mem is total 11 | #CSV=./simulations/bm_sp/results/benchmark.csv 12 | #PTH=./simulations/bm_sp 13 | #sbatch --mem=16G --array=1-$(wc -l < $CSV) ./simulations/benchmark.slurm $PTH 14 | 15 | module purge 16 | module load anaconda3/2021.5 17 | conda activate fwt 18 | 19 | #first command line arg $1 is file path to parent directory of dataset 20 | python -um simulations.benchmark $SLURM_ARRAY_TASK_ID $1 21 | 
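The `key` column that simulations/benchmark.py reads from benchmark.csv is what selects each run's output directory under models/, following the naming scheme documented in that file's docstring (S[scenario]/V[val frac]/L[factors]/[likelihood]/[model]_[kernel]_M[inducing_pts]). The real keys are precomputed by utils.misc.params2key during data generation; the helper below is a hypothetical reconstruction of the documented pattern, shown only to make the directory layout concrete, not repository code.

```python
from os import path

def example_key(scenario, V, L, lik, sz, model, kernel=None, M=None):
    """Hypothetical sketch of the documented key scheme (not repo code)."""
    lik_dir = lik if lik == "gau" else "{}_sz-{}".format(lik, sz)
    leaf = model if kernel is None else "{}_{}_M{}".format(model, kernel, M)
    return "S{}/V{}/L{}/{}/{}".format(scenario, V, L, lik_dir, leaf)

key = example_key(1, 5, 4, "poi", "constant", "NSF", "MaternThreeHalves", 1296)
opath = path.join("simulations/bm_sp", "models", key)
# opath == "simulations/bm_sp/models/S1/V5/L4/poi_sz-constant/NSF_MaternThreeHalves_M1296"
```

This is the same directory that the exploratory scripts later reload, e.g. 04_quilt_exploratory.ipy loads its NSF fit from the pickle under models/S1/V5/L4/poi_sz-constant/NSF_MaternThreeHalves_M1296.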
-------------------------------------------------------------------------------- /simulations/benchmark_gof.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | python -m simulations.benchmark_gof simulations/bm_sp 5 | """ 6 | import numpy as np 7 | import pandas as pd 8 | from os import path 9 | from argparse import ArgumentParser 10 | from contextlib import suppress 11 | from scanpy import read_h5ad 12 | from scipy.spatial.distance import pdist,cdist 13 | from scipy.stats.stats import pearsonr,spearmanr 14 | from sklearn.cluster import KMeans 15 | from sklearn import metrics 16 | 17 | from utils import postprocess 18 | from utils.preprocess import load_data 19 | from utils import benchmark as ubm 20 | 21 | def compare_to_truth(ad, Ntr, fit, model): 22 | ad = ad[:Ntr,:] 23 | X = ad.obsm["spatial"] 24 | #extract ground truth factors and loadings 25 | tru = postprocess.interpret_nonneg(ad.obsm["spfac"], ad.varm["spload"], 26 | lda_mode=False) 27 | F0 = tru["factors"] 28 | W0 = tru["loadings"] 29 | FFd0 = pdist(F0) #cell-cell distances (vectorized) 30 | WWd0 = pdist(W0) #gene-gene distances (vectorized) 31 | ifit = postprocess.interpret_fit(fit,X,model) 32 | F = ifit["factors"] 33 | W = ifit["loadings"] 34 | Fpc = cdist(F0.T,F.T,metric=lambda x,y: abs(pearsonr(x,y)[0])).max(axis=1) 35 | Fsc = cdist(F0.T,F.T,metric=lambda x,y: abs(spearmanr(x,y)[0])).max(axis=1) 36 | Wpc = cdist(W0.T,W.T,metric=lambda x,y: abs(pearsonr(x,y)[0])).max(axis=1) 37 | Wsc = cdist(W0.T,W.T,metric=lambda x,y: abs(spearmanr(x,y)[0])).max(axis=1) 38 | res = {} 39 | res["factors_pearson"] = Fpc.tolist() 40 | res["factors_spearman"] = Fsc.tolist() 41 | res["loadings_pearson"] = Wpc.tolist() 42 | res["loadings_spearman"] = Wsc.tolist() 43 | FFd = pdist(F) 44 | WWd = pdist(W) 45 | res["dfactors_pearson"] = pearsonr(FFd0,FFd)[0] 46 | res["dfactors_spearman"] = spearmanr(FFd0,FFd)[0] 47 | res["dloadings_pearson"] = pearsonr(WWd0,WWd)[0] 48 | res["dloadings_spearman"] = spearmanr(WWd0,WWd)[0] 49 | nclust = W0.shape[1] 50 | km0 = KMeans(n_clusters=nclust).fit(W0).labels_ 51 | km1 = KMeans(n_clusters=nclust).fit(W).labels_ 52 | res["loadings_clust_ari"] = metrics.adjusted_rand_score(km0, km1) 53 | return res 54 | 55 | def compare_to_truth_mixed(ad, Ntr, fit, model): 56 | ad = ad[:Ntr,:].copy() 57 | X = ad.obsm["spatial"] 58 | #extract factors and loadings 59 | tru = postprocess.interpret_nonneg_mixed(ad.obsm["spfac"], ad.varm["spload"], 60 | ad.obsm["nsfac"], ad.varm["nsload"], 61 | lda_mode=False) 62 | F0 = tru["spatial"]["factors"] 63 | W0 = tru["spatial"]["loadings"] 64 | H0 = tru["nonspatial"]["factors"] 65 | V0 = tru["nonspatial"]["loadings"] 66 | FH0 = np.concatenate((F0,H0),axis=1) 67 | WV0 = np.concatenate((W0,V0),axis=1) 68 | alpha0 = W0.sum(axis=1) #true spatial importances 69 | ifit = postprocess.interpret_fit(fit,X,model) 70 | if model == "NSFH": 71 | F = ifit["spatial"]["factors"] 72 | W = ifit["spatial"]["loadings"] 73 | H = ifit["nonspatial"]["factors"] 74 | V = ifit["nonspatial"]["loadings"] 75 | FH = np.concatenate((F,H),axis=1) 76 | WV = np.concatenate((W,V),axis=1) 77 | alpha = W.sum(axis=1) #estimated spatial importances 78 | else: 79 | FH = ifit["factors"] 80 | WV = ifit["loadings"] 81 | if model =="NSF": 82 | alpha = np.ones_like(alpha0) 83 | elif model=="PNMF": 84 | alpha = np.zeros_like(alpha0) 85 | else: 86 | alpha = None 87 | Fpc = cdist(FH0.T,FH.T,metric=lambda x,y: abs(pearsonr(x,y)[0])).max(axis=1) 88 
| Fsc = cdist(FH0.T,FH.T,metric=lambda x,y: abs(spearmanr(x,y)[0])).max(axis=1) 89 | Wpc = cdist(WV0.T,WV.T,metric=lambda x,y: abs(pearsonr(x,y)[0])).max(axis=1) 90 | Wsc = cdist(WV0.T,WV.T,metric=lambda x,y: abs(spearmanr(x,y)[0])).max(axis=1) 91 | res = {} 92 | res["factors_pearson"] = Fpc.tolist() 93 | res["factors_spearman"] = Fsc.tolist() 94 | res["loadings_pearson"] = Wpc.tolist() 95 | res["loadings_spearman"] = Wsc.tolist() 96 | if alpha is None: 97 | sp_imp_dist = None 98 | else: 99 | sp_imp_dist = np.abs(alpha-alpha0).sum() #L1 distance 100 | res["spatial_importance_dist"] = sp_imp_dist 101 | nclust = WV0.shape[1] 102 | km0 = KMeans(n_clusters=nclust).fit(WV0).labels_ 103 | km1 = KMeans(n_clusters=nclust).fit(WV).labels_ 104 | res["loadings_clust_ari"] = metrics.adjusted_rand_score(km0, km1) 105 | return res 106 | 107 | def row_metrics(row,pth,verbose=True,mode="bm_sp"): 108 | row = dict(row) #originally row is a pandas.Series object 109 | train_frac = ubm.val2train_frac(row["V"]) 110 | if not "converged" in row or not row['converged']:# or row[mnames].isnull().any(): 111 | dataset = path.join(pth,"data","S{}.h5ad".format(row["scenario"])) 112 | ad = read_h5ad(path.normpath(dataset)) 113 | pkl = path.join(pth,"models",row["key"]) 114 | if row["sz"]=="scanpy": 115 | D,fmeans = load_data(ad, model="NSF", lik="poi", sz="scanpy", 116 | train_frac=train_frac, flip_yaxis=False) 117 | else: 118 | D,fmeans = load_data(ad, model=None, lik=None, sz="constant", 119 | train_frac=train_frac, flip_yaxis=False) 120 | with suppress(FileNotFoundError): 121 | fit,tro = ubm.load(pkl) 122 | if verbose: print(row["key"]) 123 | metrics = ubm.get_metrics(fit,D["raw"]["tr"],D["raw"]["val"],tro=tro) 124 | row.update(metrics) 125 | Ntr = D["raw"]["tr"]["X"].shape[0] 126 | if mode=="bm_sp": 127 | metrics2 = compare_to_truth(ad, Ntr, fit, row["model"]) 128 | elif mode=="bm_mixed": 129 | metrics2 = compare_to_truth_mixed(ad, Ntr, fit, row["model"]) 130 | else: 131 | raise ValueError("mode must be either 'bm_sp' or 'bm_mixed'") 132 | row.update(metrics2) 133 | return row 134 | 135 | def update_results(pth,todisk=True,verbose=True): 136 | """ 137 | different than the utils.benchmark.update_results because it loads a separate 138 | dataset and model for each row of the results.csv to accommmodate multiple 139 | simulation scenarios 140 | """ 141 | # pth = "simulations/bm_sp" 142 | csv_file = path.join(pth,"results/benchmark.csv") 143 | res = pd.read_csv(csv_file) 144 | res = res.apply(row_metrics, args=(pth,), axis=1, result_type="expand", 145 | verbose=verbose, mode=path.split(pth)[-1]) 146 | if todisk: res.to_csv(csv_file,index=False) 147 | return res 148 | 149 | def arghandler(args=None): 150 | """parses a list of arguments (default is sys.argv[1:])""" 151 | parser = ArgumentParser() 152 | parser.add_argument("path", type=str, 153 | help="top level directory containing subfolders 'data', 'models', and 'results'.") 154 | args = parser.parse_args(args) #if args is None, this will automatically parse sys.argv[1:] 155 | return args 156 | 157 | if __name__=="__main__": 158 | args = arghandler() 159 | res = update_results(args.path, todisk=True) 160 | -------------------------------------------------------------------------------- /simulations/benchmark_gof.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=benchmark_gof # create a short name for your job 3 | #SBATCH --nodes=1 # node count 4 | #SBATCH --ntasks=1 # total number of tasks 
across all nodes 5 | #SBATCH --cpus-per-task=4 # cpu-cores per task (>1 if multi-threaded tasks) 6 | #SBATCH --time=00:20:00 # total run time limit (HH:MM:SS) 7 | #SBATCH --mail-type=END,FAIL 8 | #SBATCH --mail-user=ftownes@princeton.edu 9 | 10 | #example usage, --mem-per-cpu default is 4G per core 11 | #PTH=./simulations/bm_sp 12 | #sbatch --mem=16G ./simulations/benchmark_gof.slurm $PTH 13 | 14 | module purge 15 | module load anaconda3/2021.5 16 | conda activate fwt 17 | 18 | #first command line arg $1 is file path to parent directory 19 | python -um simulations.benchmark_gof $1 20 | -------------------------------------------------------------------------------- /simulations/bm_mixed/01_data_generation.ipy: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # --- 3 | # jupyter: 4 | # jupytext: 5 | # text_representation: 6 | # extension: .ipy 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.6.0 10 | # kernelspec: 11 | # display_name: Python 3 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | #%% imports 17 | from os import path 18 | from copy import deepcopy 19 | import numpy as np 20 | import pandas as pd 21 | from scanpy import read_h5ad 22 | from janitor import expand_grid 23 | from utils import visualize,misc 24 | from simulations import sim 25 | pth = "simulations/bm_mixed" 26 | dpth = path.join(pth,"data") 27 | rpth = path.join(pth,"results") 28 | mpth = path.join(pth,"models") 29 | # misc.mkdir_p(dpth) # or use symlink to dropbox 30 | # misc.mkdir_p(rpth) 31 | # misc.mkdir_p(mpth) 32 | 33 | # %% 34 | cfg = {"sim":["quilt","ggblocks","both"], "nside":36, "nzprob_nsp":0.2, 35 | "bkg_mean":0.2, "nb_shape":10.0, 36 | "J":[(250,0,250),(0,500,0)], #"Jsp":0, "Jmix":500, "Jns":0, 37 | "expr_mean":20.0, "mix_frac_spat":0.6, 38 | "seed":[1,2,3,4,5], "V":5} 39 | a = expand_grid(others=cfg) 40 | a.rename(columns={"J_0":"Jsp", "J_1":"Jmix", "J_2":"Jns"}, inplace=True) 41 | b = pd.DataFrame({"sim":["quilt","ggblocks","both"], 42 | "Lsp":[4,4,8],"Lns":[3,3,6]}) 43 | a = a.merge(b,how="left",on="sim") 44 | a["scenario"] = list(range(1,a.shape[0]+1)) 45 | a.to_csv("simulations/bm_mixed/scenarios.csv",index=False) 46 | 47 | # %% generate the simulated datasets and store to disk 48 | a = pd.read_csv(path.join(pth,"scenarios.csv")).convert_dtypes() 49 | def sim2disk(p): 50 | p = deepcopy(p) 51 | scen = p.pop("scenario") 52 | Lsp = p.pop("Lsp") 53 | ad = sim.sim(p["sim"], **p) 54 | ad.write_h5ad(path.join(dpth,"S{}.h5ad".format(scen)),compression="gzip") 55 | a.apply(sim2disk,axis=1) 56 | 57 | # %% check the hdf5 file is correct 58 | ad = read_h5ad(path.join(dpth,"S1.h5ad")) 59 | X = ad.obsm["spatial"] 60 | Y = ad.layers["counts"] 61 | Yn = ad.X 62 | visualize.heatmap(X,Y[:,0],cmap="Blues") 63 | visualize.heatmap(X,Yn[:,0],cmap="Blues") 64 | #check distribution of validation data points 65 | N = Y.shape[0] 66 | z = np.zeros(N) 67 | Ntr = round(0.95*N) 68 | z[Ntr:] = 1 69 | visualize.heatmap(X,z,cmap="Blues") 70 | 71 | # %% merge with models to make results csv for tracking model runs 72 | m = pd.read_csv(path.join(pth,"models.csv")).convert_dtypes() #this CSV was manually created 73 | d = a.merge(m,how="cross") 74 | d["L"] = d["Lsp"]+d["Lns"] 75 | d["T"] = d["Lsp"] 76 | d["key"] = d.apply(misc.params2key,axis=1) 77 | d["key"] = d.agg(lambda x: f"S{x['scenario']}/{x['key']}", axis=1) 78 | d["converged"] = False 79 | d.to_csv(path.join(rpth,"benchmark.csv"),index=False) 80 | 
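A note on the scenario grid built above: cfg mixes scalar settings with list-valued ones (sim, J, seed), and the janitor.expand_grid call produces one row per combination of the list entries while carrying the scalars along, i.e. 3 simulations × 2 gene-mixture settings × 5 seeds = 30 scenarios. A rough standard-library equivalent is sketched below as an illustration only (it is not what the script runs); it also mimics the splitting of the J tuple into the Jsp/Jmix/Jns columns.

```python
import itertools
import pandas as pd

cfg = {"sim": ["quilt", "ggblocks", "both"], "nside": 36, "nzprob_nsp": 0.2,
       "bkg_mean": 0.2, "nb_shape": 10.0,
       "J": [(250, 0, 250), (0, 500, 0)],
       "expr_mean": 20.0, "mix_frac_spat": 0.6,
       "seed": [1, 2, 3, 4, 5], "V": 5}
# treat every scalar as a length-1 list, then take the Cartesian product
grid = {k: (v if isinstance(v, list) else [v]) for k, v in cfg.items()}
rows = [dict(zip(grid, combo)) for combo in itertools.product(*grid.values())]
a = pd.DataFrame(rows)
# split the (Jsp, Jmix, Jns) tuples into separate columns, as the script does
a[["Jsp", "Jmix", "Jns"]] = pd.DataFrame(a.pop("J").tolist(), index=a.index)
len(a)  # 30 scenarios: 3 sims x 2 gene-mixture settings x 5 seeds
```

Writing the expanded grid to scenarios.csv and re-reading it keeps the scenario definitions identical between the data-generation step and the later merge with models.csv that builds results/benchmark.csv.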
-------------------------------------------------------------------------------- /simulations/bm_mixed/02_benchmark.ipy: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # --- 3 | # jupyter: 4 | # jupytext: 5 | # text_representation: 6 | # extension: .ipy 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.6.0 10 | # kernelspec: 11 | # display_name: Python 3 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | # %% 17 | from os import path 18 | # from copy import deepcopy 19 | # import numpy as np 20 | import pandas as pd 21 | # from matplotlib import pyplot as plt 22 | # from scipy.spatial.distance import pdist,cdist 23 | # from scipy.stats.stats import pearsonr,spearmanr 24 | # from sklearn.cluster import KMeans 25 | # from sklearn import metrics 26 | from scanpy import read_h5ad 27 | from tensorflow_probability import math as tm 28 | tfk = tm.psd_kernels 29 | 30 | from utils import training,postprocess,visualize 31 | # from simulations import sim 32 | pth = "simulations/bm_mixed" 33 | dpth = path.join(pth,"data") 34 | rpth = path.join(pth,"results") 35 | mpth = path.join(pth,"models") 36 | 37 | #%% Benchmarking at command line [markdown] 38 | """ 39 | To run on local computer, use 40 | `python -m simulations.benchmark 2 simulations/bm_mixed` 41 | where 2 is a row ID of benchmark.csv, min value 2, max possible value is 91 42 | 43 | To run on cluster first load anaconda environment 44 | ``` 45 | tmux 46 | interactive 47 | module load anaconda3/2021.5 48 | conda activate fwt 49 | python -m simulations.benchmark 2 simulations/bm_mixed 50 | ``` 51 | 52 | To run on cluster as a job array, subset of rows. Recommend 10min time limit. 53 | ``` 54 | PTH=./simulations/bm_mixed 55 | sbatch --mem=16G --array=5-91 ./simulations/benchmark.slurm $PTH 56 | ``` 57 | 58 | To run on cluster as a job array, all rows of CSV file 59 | ``` 60 | CSV=./simulations/bm_mixed/results/benchmark.csv 61 | PTH=./simulations/bm_mixed 62 | sbatch --mem=16G --array=1-$(wc -l < $CSV) ./simulations/benchmark.slurm $PTH 63 | ``` 64 | """ 65 | 66 | #%% Load dataset and set kernel and IPs 67 | ad = read_h5ad(path.join(dpth,"S1.h5ad")) 68 | #include only the training observations 69 | Ntr = round(0.95*ad.shape[0]) 70 | ad = ad[:Ntr,:] 71 | X = ad.obsm["spatial"] 72 | #extract factors and loadings 73 | tru = postprocess.interpret_nonneg_mixed(ad.obsm["spfac"], ad.varm["spload"], 74 | ad.obsm["nsfac"],ad.varm["nsload"], 75 | lda_mode=False) 76 | F0 = tru["spatial"]["factors"] 77 | W0 = tru["spatial"]["loadings"] 78 | alpha = W0.sum(axis=1) 79 | pd1 = pd.DataFrame({"spatial_wt":alpha}) 80 | pd1.spatial_wt.hist(bins=100) #spatial importance by feature 81 | #set hyperparams 82 | T = W0.shape[1] 83 | L = T+tru["nonspatial"]["loadings"].shape[1] 84 | M = 1296 85 | ker = tfk.MaternThreeHalves 86 | hmkw = {"figsize":(6,1.5),"s":1.5,"marker":"s","subplot_space":0, 87 | "spinecolor":"gray"} 88 | fig,axes=visualize.multiheatmap(X, tru["spatial"]["factors"], (1,4), cmap="Blues", **hmkw) 89 | fig,axes=visualize.multiheatmap(X, tru["nonspatial"]["factors"], (1,4), cmap="Blues", **hmkw) 90 | 91 | #%% Compare inferred to true factors 92 | pp = path.join(mpth,"S1/V5/L{}/poi_sz-constant/NSFH_T{}_{}_M{}".format(L,T,ker.__name__, M)) 93 | tro = training.ModelTrainer.from_pickle(pp) 94 | fit = tro.model 95 | insfh = postprocess.interpret_nsfh(fit,X) 96 | F = insfh["spatial"]["factors"] 97 | fig,axes=visualize.multiheatmap(X, F, (1,4), 
cmap="Blues", **hmkw) 98 | fig,axes=visualize.multiheatmap(X, insfh["nonspatial"]["factors"], (1,4), cmap="Blues", **hmkw) 99 | 100 | #%% Compute goodness-of-fit metrics 101 | """ 102 | ``` 103 | python -m simulations.benchmark_gof simulations/bm_mixed 104 | ``` 105 | or 106 | ``` 107 | PTH=./simulations/bm_mixed 108 | sbatch --mem=16G ./simulations/benchmark_gof.slurm $PTH 109 | ``` 110 | """ 111 | -------------------------------------------------------------------------------- /simulations/bm_mixed/03_benchmark_viz.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Benchmarking Visualization" 3 | author: "Will Townes" 4 | output: html_document 5 | --- 6 | 7 | ```{r} 8 | library(tidyverse) 9 | theme_set(theme_bw()) 10 | fp<-file.path 11 | pth<-"simulations/bm_mixed" 12 | plt_pth<-fp(pth,"results/plots") 13 | if(!dir.exists(plt_pth)){ 14 | dir.create(plt_pth,recursive=TRUE) 15 | } 16 | ``` 17 | 18 | ```{r} 19 | d<-read.csv(fp(pth,"results/benchmark.csv")) 20 | d$converged<-as.logical(d$converged) 21 | d$model<-factor(d$model,levels=c("PNMF","NSF","NSFH")) 22 | sims<-c("ggblocks","quilt","both") 23 | d$sim<-factor(d$sim,levels=sims) 24 | d$simulation<-plyr::mapvalues(d$sim,sims,c("blocks (4+3)","quilt (4+3)","both (8+6)")) 25 | d$mixed_genes<-d$Jmix>0 26 | 27 | summarize_multicol<-function(x,f){ 28 | #x is a character column whose entries are lists of numbers "[3.02, 4.33, 7.71,...]" 29 | #expand each entry into a vector and summarize it by 30 | #a function (f) like min, max, mean 31 | #return a numeric vector with the summary stat, has same length as x 32 | x<-sub("[","c(",x,fixed=TRUE) 33 | x<-sub("]",")",x,fixed=TRUE) 34 | vapply(x,function(t){f(eval(str2lang(t)))},FUN.VALUE=1.0) 35 | } 36 | 37 | d$factors_pearson_min<-summarize_multicol(d$factors_pearson,min) 38 | d$loadings_pearson_min<-summarize_multicol(d$loadings_pearson,min) 39 | d$factors_spearman_min<-summarize_multicol(d$factors_spearman,min) 40 | d$loadings_spearman_min<-summarize_multicol(d$loadings_spearman,min) 41 | ``` 42 | 43 | ```{r} 44 | ggplot(d,aes(x=model,y=dev_val_mean,color=simulation,shape=mixed_genes))+geom_point(size=3,position=position_jitterdodge())+scale_y_log10()+ylab("validation deviance (mean)") 45 | ggsave(fp(plt_pth,"bm_mixed_dev_val_mean.pdf"),width=5,height=2.5) 46 | 47 | # ggplot(d,aes(x=model,y=factors_pearson_min,color=simulation,fill=simulation))+geom_boxplot()+scale_y_log10()+ylab("minimum factors correlation") 48 | # ggsave(fp(plt_pth,"bm_mixed_factors_pcor_min.pdf"),width=5,height=2.5) 49 | # 50 | # ggplot(d,aes(x=model,y=loadings_pearson_min,fill=simulation,color=simulation))+geom_boxplot()+scale_y_log10()+ylab("minimum loadings correlation") 51 | # ggsave(fp(plt_pth,"bm_mixed_loadings_pcor_min.pdf"),width=5,height=2.5) 52 | 53 | ggplot(d,aes(x=model,y=spatial_importance_dist,color=simulation,shape=mixed_genes))+geom_point(size=3,position=position_jitterdodge())+ylab("spatial importance distance") 54 | ggsave(fp(plt_pth,"bm_mixed_spat_importance_dist.pdf"),width=5,height=2.5) 55 | 56 | #ggplot(d,aes(x=model,y=loadings_clust_ari,fill=simulation,color=simulation))+geom_boxplot()+scale_y_log10()+ylab("Concordance of feature clusters with truth (ARI)") 57 | ``` 58 | 59 | Linear regressions to get statistical significance 60 | 61 | ```{r} 62 | #NSFH vs NSF 63 | d$is_NSFH<-(d$model=="NSFH") 64 | summary(lm(dev_val_mean~is_NSFH+simulation,data=d)) 65 | ``` 66 | -------------------------------------------------------------------------------- 
/simulations/bm_sp/01_data_generation.ipy: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # --- 3 | # jupyter: 4 | # jupytext: 5 | # text_representation: 6 | # extension: .ipy 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.6.0 10 | # kernelspec: 11 | # display_name: Python 3 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | # %% 17 | from os import path 18 | from copy import deepcopy 19 | import numpy as np 20 | import pandas as pd 21 | from scanpy import read_h5ad 22 | from janitor import expand_grid 23 | from utils import misc,visualize 24 | from simulations import sim 25 | pth = "simulations/bm_sp" 26 | dpth = path.join(pth,"data") 27 | rpth = path.join(pth,"results") 28 | mpth = path.join(pth,"models") 29 | # misc.mkdir_p(dpth) # or use symlink to dropbox 30 | # misc.mkdir_p(rpth) 31 | # misc.mkdir_p(mpth) 32 | 33 | # %% define scenarios 34 | cfg = {"sim":["quilt","ggblocks","both"], "nside":36, "bkg_mean":0.2, 35 | "nb_shape":10.0, "Jsp":200, "Jmix":0, "Jns":0, "expr_mean":20.0, 36 | "seed":[1,2,3,4,5], "V":5} 37 | a = expand_grid(others=cfg) 38 | b = pd.DataFrame({"sim":["quilt","ggblocks","both"], "L":[4,4,8]}) 39 | a = a.merge(b,how="left",on="sim") 40 | a["scenario"] = list(range(1,a.shape[0]+1)) 41 | a.to_csv(path.join(pth,"scenarios.csv"),index=False) #store separately for data generation 42 | 43 | # %% generate the simulated datasets and store to disk 44 | a = pd.read_csv(path.join(pth,"scenarios.csv")).convert_dtypes() 45 | def sim2disk(p): 46 | p = deepcopy(p) 47 | scen = p.pop("scenario") 48 | ad = sim.sim(p["sim"], Lns=0, **p) 49 | ad.write_h5ad(path.join(dpth,"S{}.h5ad".format(scen)),compression="gzip") 50 | a.apply(sim2disk,axis=1) 51 | 52 | # %% check the hdf5 file is correct 53 | ad = read_h5ad(path.join(dpth,"S1.h5ad")) 54 | X = ad.obsm["spatial"] 55 | Y = ad.layers["counts"] 56 | Yn = ad.X 57 | visualize.heatmap(X,Y[:,0],cmap="Blues") 58 | visualize.heatmap(X,Yn[:,0],cmap="Blues") 59 | #check distribution of validation data points 60 | N = Y.shape[0] 61 | z = np.zeros(N) 62 | Ntr = round(0.95*N) 63 | z[Ntr:] = 1 64 | visualize.heatmap(X,z,cmap="Blues") 65 | 66 | # %% merge with models to make results csv for tracking model runs 67 | m = pd.read_csv(path.join(pth,"models.csv")).convert_dtypes() #this CSV was manually created 68 | d = a.merge(m,how="cross") 69 | d["key"] = d.apply(misc.params2key,axis=1) 70 | d["key"] = d.agg(lambda x: f"S{x['scenario']}/{x['key']}", axis=1) 71 | # d["scenario"].to_string()+"/"+d["key"] 72 | d["converged"] = False 73 | d.to_csv(path.join(rpth,"benchmark.csv"),index=False) 74 | -------------------------------------------------------------------------------- /simulations/bm_sp/02_benchmark.ipy: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # --- 3 | # jupyter: 4 | # jupytext: 5 | # text_representation: 6 | # extension: .ipy 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.6.0 10 | # kernelspec: 11 | # display_name: Python 3 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | # %% 17 | from os import path 18 | # from copy import deepcopy 19 | import numpy as np 20 | # import pandas as pd 21 | from matplotlib import pyplot as plt 22 | from scipy.spatial.distance import pdist,cdist 23 | from scipy.stats.stats import pearsonr,spearmanr 24 | from sklearn.cluster import KMeans 25 | from sklearn import metrics 26 | from 
scanpy import read_h5ad 27 | from tensorflow_probability import math as tm 28 | tfk = tm.psd_kernels 29 | 30 | from utils import training,postprocess,visualize 31 | # from simulations import sim 32 | pth = "simulations/bm_sp" 33 | dpth = path.join(pth,"data") 34 | rpth = path.join(pth,"results") 35 | mpth = path.join(pth,"models") 36 | 37 | #%% Benchmarking at command line [markdown] 38 | """ 39 | To run on local computer, use 40 | `python -m simulations.benchmark 2 simulations/bm_sp` 41 | where 2 is a row ID of benchmark.csv, min value 2, max possible value is 115 42 | 43 | To run on cluster first load anaconda environment 44 | ``` 45 | tmux 46 | interactive 47 | module load anaconda3/2021.5 48 | conda activate fwt 49 | python -m simulations.benchmark 2 simulations/bm_sp 50 | ``` 51 | 52 | To run on cluster as a job array, subset of rows. Recommend 10min time limit. 53 | ``` 54 | PTH=./simulations/bm_sp 55 | sbatch --mem=16G --array=2-26,52-76 ./simulations/benchmark.slurm $PTH 56 | ``` 57 | 58 | To run on cluster as a job array, all rows of CSV file 59 | ``` 60 | CSV=./simulations/bm_sp/results/benchmark.csv 61 | PTH=./simulations/bm_sp 62 | sbatch --mem=16G --array=1-$(wc -l < $CSV) ./simulations/benchmark.slurm $PTH 63 | ``` 64 | """ 65 | 66 | #%% Load dataset and set kernel and IPs 67 | ad = read_h5ad(path.join(dpth,"S12.h5ad")) 68 | #include only the training observations 69 | Ntr = round(0.95*ad.shape[0]) 70 | ad = ad[:Ntr,:] 71 | X = ad.obsm["spatial"] 72 | #extract factors and loadings 73 | tru = postprocess.interpret_nonneg(ad.obsm["spfac"],ad.varm["spload"],lda_mode=False) 74 | F0 = tru["factors"] 75 | W0 = tru["loadings"] 76 | FFd0 = pdist(F0) 77 | WWd0 = pdist(W0) 78 | #set hyperparams 79 | M = 1296 80 | ker = tfk.MaternThreeHalves 81 | hmkw = {"figsize":(6,3),"s":1.5,"marker":"s","subplot_space":0, 82 | "spinecolor":"gray"} 83 | fig,axes=visualize.multiheatmap(X, F0, (2,4), cmap="Blues", **hmkw) 84 | 85 | #%% Compare inferred to true factors 86 | pp = path.join(mpth,"S12/V5/L8/poi_sz-constant/NSF_{}_M{}".format(ker.__name__, M)) 87 | tro = training.ModelTrainer.from_pickle(pp) 88 | fit = tro.model 89 | insf = postprocess.interpret_nsf(fit,X) 90 | F = insf["factors"] 91 | fig,axes=visualize.multiheatmap(X, F, (2,4), cmap="Blues", **hmkw) 92 | cdist(F0.T,F.T,metric=lambda x,y: abs(pearsonr(x,y)[0])).max(axis=1) 93 | cdist(F0.T,F.T,metric=lambda x,y: abs(spearmanr(x,y)[0])).max(axis=1) 94 | plt.scatter(F0[:,0],F[:,0]) 95 | 96 | FFd = pdist(insf["factors"]) 97 | W = insf["loadings"] 98 | WWd = pdist(W) 99 | plt.hexbin(FFd0,FFd,gridsize=100,cmap="Greys",bins="log") 100 | plt.scatter(WWd0,WWd) 101 | pearsonr(FFd0,FFd)[0] 102 | spearmanr(FFd0,FFd)[0] 103 | pearsonr(WWd0,WWd)[0] 104 | spearmanr(WWd0,WWd)[0] 105 | nclust = W0.shape[1] 106 | km0 = KMeans(n_clusters=nclust).fit(W0).labels_ 107 | km1 = KMeans(n_clusters=nclust).fit(W).labels_ 108 | ari = metrics.adjusted_rand_score(km0, km1) 109 | 110 | #%% Compute goodness-of-fit metrics 111 | """ 112 | ``` 113 | python -m simulations.benchmark_gof simulations/bm_sp 114 | ``` 115 | or 116 | ``` 117 | PTH=./simulations/bm_sp 118 | sbatch --mem=16G ./simulations/benchmark_gof.slurm $PTH 119 | ``` 120 | """ 121 | 122 | #%% Visualize results 123 | import pandas as pd 124 | import seaborn as sns 125 | from ast import literal_eval 126 | d = pd.read_csv(path.join(rpth,"benchmark.csv")) 127 | # d = d[d["converged"]] 128 | d["factors_pearson"] = d["factors_pearson"].map(lambda x: np.array(literal_eval(x))) 129 | d["factors_pearson_min"] = 
d["factors_pearson"].map(min) 130 | d["factors_pearson_mean"] = d["factors_pearson"].map(np.mean) 131 | # d["factors_spearman"] = d["factors_spearman"].map(lambda x: np.array(literal_eval(x))) 132 | # d["factors_spearman_min"] = d["factors_spearman"].map(min) 133 | # d["factors_spearman_mean"] = d["factors_spearman"].map(np.mean) 134 | sns.stripplot(x="model",y="factors_pearson_min",hue="sim",dodge=True,data=d) 135 | 136 | 137 | -------------------------------------------------------------------------------- /simulations/bm_sp/03_benchmark_viz.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Benchmarking Visualization" 3 | author: "Will Townes" 4 | output: html_document 5 | --- 6 | 7 | ```{r} 8 | library(tidyverse) 9 | theme_set(theme_bw()) 10 | fp<-file.path 11 | pth<-"simulations/bm_sp" 12 | plt_pth<-fp(pth,"results/plots") 13 | if(!dir.exists(plt_pth)){ 14 | dir.create(plt_pth,recursive=TRUE) 15 | } 16 | ``` 17 | 18 | ```{r} 19 | d<-read.csv(fp(pth,"results/benchmark.csv")) 20 | d$converged<-as.logical(d$converged) 21 | d$model<-factor(d$model,levels=c("FA","MEFISTO","RSF","PNMF","NSF")) 22 | sims<-c("ggblocks","quilt","both") 23 | d$sim<-factor(d$sim,levels=sims) 24 | d$simulation<-plyr::mapvalues(d$sim,sims,c("blocks (4)","quilt (4)","both (8)")) 25 | 26 | summarize_multicol<-function(x,f){ 27 | #x is a character column whose entries are lists of numbers "[3.02, 4.33, 7.71,...]" 28 | #expand each entry into a vector and summarize it by 29 | #a function (f) like min, max, mean 30 | #return a numeric vector with the summary stat, has same length as x 31 | x<-sub("[","c(",x,fixed=TRUE) 32 | x<-sub("]",")",x,fixed=TRUE) 33 | vapply(x,function(t){f(eval(str2lang(t)))},FUN.VALUE=1.0) 34 | } 35 | 36 | d$factors_pearson_min<-summarize_multicol(d$factors_pearson,min) 37 | d$loadings_pearson_min<-summarize_multicol(d$loadings_pearson,min) 38 | d$factors_spearman_min<-summarize_multicol(d$factors_spearman,min) 39 | d$loadings_spearman_min<-summarize_multicol(d$loadings_spearman,min) 40 | ``` 41 | 42 | ```{r} 43 | ggplot(d,aes(x=model,y=dev_val_mean,color=simulation))+geom_point(size=3,position=position_jitterdodge())+scale_y_log10()+ylab("validation deviance (mean)") 44 | ggsave(fp(plt_pth,"bm_sp_dev_val_mean.pdf"),width=5,height=2.5) 45 | 46 | ggplot(d,aes(x=model,y=factors_pearson_min,color=simulation))+geom_point(size=3,position=position_jitterdodge())+scale_y_log10()+ylab("minimum factors correlation") 47 | ggsave(fp(plt_pth,"bm_sp_factors_pcor_min.pdf"),width=5,height=2.5) 48 | 49 | ggplot(d,aes(x=model,y=loadings_pearson_min,color=simulation))+geom_point(size=3,position=position_jitterdodge())+scale_y_log10()+ylab("minimum loadings correlation") 50 | ggsave(fp(plt_pth,"bm_sp_loadings_pcor_min.pdf"),width=5,height=2.5) 51 | 52 | #ggplot(d,aes(x=model,y=factors_spearman_min,fill=simulation,color=simulation))+geom_boxplot()+scale_y_log10()+ylab("Spearman correlation with true factors") 53 | 54 | #ggplot(d,aes(x=model,y=loadings_spearman_min,fill=simulation,color=simulation))+geom_boxplot()+scale_y_log10()+ylab("Spearman correlation with true loadings") 55 | 56 | #ggplot(d,aes(x=model,y=loadings_clust_ari,fill=simulation,color=simulation))+geom_boxplot()+scale_y_log10()+ylab("Concordance of feature clusters with truth (ARI)") 57 | ``` 58 | 59 | Linear regressions to get statistical significance 60 | 61 | ```{r} 62 | #nonnegative vs real-valued 63 | d$nonneg<-(d$model %in% c("PNMF","NSF")) 64 | 
summary(lm(factors_pearson_min~nonneg+simulation,data=d)) 65 | summary(lm(loadings_pearson_min~nonneg+simulation,data=d)) 66 | d$sp_aware<-(d$model %in% c("RSF","NSF")) 67 | summary(lm(dev_val_mean~sp_aware+simulation,data=d)) 68 | ``` 69 | -------------------------------------------------------------------------------- /simulations/bm_sp/04_quilt_exploratory.ipy: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # --- 3 | # jupyter: 4 | # jupytext: 5 | # text_representation: 6 | # extension: .ipy 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.6.0 10 | # kernelspec: 11 | # display_name: Python 3 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | # %% 17 | # import numpy as np 18 | from os import path 19 | from scanpy import read_h5ad 20 | from tensorflow_probability import math as tm 21 | tfk = tm.psd_kernels 22 | 23 | from utils import preprocess,training,misc,postprocess,visualize 24 | 25 | # rng = np.random.default_rng() 26 | pth = "simulations/bm_sp" 27 | dpth = path.join(pth,"data") 28 | mpth = path.join(pth,"models/S1/V5") 29 | plt_pth = path.join(pth,"results/plots") 30 | misc.mkdir_p(plt_pth) 31 | 32 | #%% Data loading 33 | ad = read_h5ad(path.join(dpth,"S1.h5ad")) 34 | N = ad.shape[0] 35 | Ntr = round(0.95*N) 36 | ad = ad[:Ntr,:] 37 | J = ad.shape[1] 38 | X = ad.obsm["spatial"] 39 | # D_n,_ = preprocess.anndata_to_train_val(ad,train_frac=1.0,flip_yaxis=False) 40 | 41 | #%% Save heatmap of true values and sampled data 42 | hmkw = {"figsize":(8,1.9), "bgcol":"gray", "subplot_space":0.1, "marker":"s", 43 | "s":2.9} 44 | Ftrue = ad.obsm["spfac"] 45 | fig,axes=visualize.multiheatmap(X, Ftrue, (1,4), cmap="Blues", **hmkw) 46 | fig.savefig(path.join(plt_pth,"quilt_true_factors.png"),bbox_inches='tight') 47 | # fig.savefig(path.join(plt_pth,"quilt_true_factors.pdf"),bbox_inches='tight') 48 | #saving directly to pdf looks weird. 
Try imagemagick instead 49 | #convert quilt_true_factors.png quilt_true_factors.pdf 50 | 51 | Yss = ad.layers["counts"][:,(4,0,1,2)] 52 | fig,axes=visualize.multiheatmap(X, Yss, (1,4), cmap="Blues", **hmkw) 53 | fig.savefig(path.join(plt_pth,"quilt_data.png"),bbox_inches='tight') 54 | # fig.savefig(path.join(plt_pth,"quilt_data.pdf"),bbox_inches='tight') 55 | 56 | # %% Initialize inducing points 57 | L = 4 58 | M = N #number of inducing points 59 | Z = X 60 | ker = tfk.MaternThreeHalves 61 | 62 | #%% NSF 63 | try: 64 | pp = path.join(mpth,"L{}/poi_sz-constant/NSF_{}_M{}".format(L,ker.__name__,M)) 65 | tro = training.ModelTrainer.from_pickle(pp) 66 | fit = tro.model 67 | # except FileNotFoundError: 68 | # fit = sf.SpatialFactorization(J,L,Z,psd_kernel=ker,nonneg=True,lik="poi") 69 | # fit.init_loadings(D["Y"],X=X,sz=D["sz"],shrinkage=0.3) 70 | # pp = fit.generate_pickle_path("constant",base=mpth) 71 | # tro = training.ModelTrainer(fit,pickle_path=pp) 72 | # %time tro.train_model(*Dtf) #12 mins 73 | insf = postprocess.interpret_nsf(fit,X,S=100,lda_mode=False) 74 | Fplot = insf["factors"][:,[3,1,2,0]] 75 | fig,axes=visualize.multiheatmap(X, Fplot, (1,4), cmap="Blues", **hmkw) 76 | fig.savefig(path.join(plt_pth,"quilt_nsf.png"),bbox_inches='tight') 77 | 78 | #%% PNMF 79 | try: 80 | pp = path.join(mpth,"L{}/poi_sz-constant/PNMF".format(L)) 81 | tro = training.ModelTrainer.from_pickle(pp) 82 | fit = tro.model 83 | # except FileNotFoundError: 84 | # fit = cf.CountFactorization(N,J,L,nonneg=True,lik="poi") 85 | # fit.init_loadings(D["Y"],sz=D["sz"],shrinkage=0.3) 86 | # pp = fit.generate_pickle_path("constant",base=mpth) 87 | # tro = training.ModelTrainer(fit,pickle_path=pp) 88 | # %time tro.train_model(*Dtf) #3 mins 89 | ipnmf = postprocess.interpret_pnmf(fit,S=100,lda_mode=False) 90 | Fplot = ipnmf["factors"][:,[3,1,2,0]] 91 | fig,axes=visualize.multiheatmap(X, Fplot, (1,4), cmap="Blues", **hmkw) 92 | fig.savefig(path.join(plt_pth,"quilt_pnmf.png"),bbox_inches='tight') 93 | 94 | #%% MEFISTO-Gaussian 95 | from models.mefisto import MEFISTO 96 | pp = path.join(mpth,"L{}/gau/MEFISTO_ExponentiatedQuadratic_M{}".format(L,M)) 97 | try: 98 | mef = MEFISTO.from_pickle(pp) 99 | # except FileNotFoundError: 100 | # mef = MEFISTO(D_n, L, inducing_pts=M, pickle_path=pp) 101 | # %time mef.train() #also saves to pickle file- 28min 102 | Fplot = mef.get_factors() 103 | fig,axes=visualize.multiheatmap(X, Fplot, (1,4), cmap="RdBu", **hmkw) 104 | fig.savefig(path.join(plt_pth,"quilt_mefisto.png"),bbox_inches='tight') 105 | 106 | #%% FA: Non-spatial, real-valued 107 | try: 108 | pp = path.join(mpth,"L{}/gau/FA".format(L)) 109 | tro = training.ModelTrainer.from_pickle(pp) 110 | fit = tro.model 111 | # except FileNotFoundError: 112 | # fit = cf.CountFactorization(N, J, L, nonneg=False, lik="gau", 113 | # feature_means=fmeans) 114 | # fit.init_loadings(D_c["Y"]) 115 | # pp = fit.generate_pickle_path(None,base=mpth) 116 | # tro = training.ModelTrainer(fit,pickle_path=pp) 117 | # %time tro.train_model(*Dtf_c) #14sec 118 | Fplot = postprocess.interpret_fa(fit,S=100)["factors"] 119 | fig,axes=visualize.multiheatmap(X, Fplot, (1,4), cmap="RdBu", **hmkw) 120 | fig.savefig(path.join(plt_pth,"quilt_fa.png"),bbox_inches='tight') 121 | 122 | #%% RSF 123 | try: 124 | pp = path.join(mpth,"L{}/gau/RSF_{}_M{}".format(L,ker.__name__,M)) 125 | tro = training.ModelTrainer.from_pickle(pp) 126 | fit = tro.model 127 | # except FileNotFoundError: 128 | # fit = sf.SpatialFactorization(J,L,Z,psd_kernel=ker,nonneg=False,lik="gau") 129 | # 
fit.init_loadings(D_c["Y"],X=X) 130 | # pp = fit.generate_pickle_path(None,base=mpth) 131 | # tro = training.ModelTrainer(fit,pickle_path=pp) 132 | # %time tro.train_model(*Dtf_c) #5 mins 133 | Fplot = postprocess.interpret_rsf(fit,X,S=100)["factors"] 134 | fig,axes=visualize.multiheatmap(X, Fplot, (1,4), cmap="RdBu", **hmkw) 135 | fig.savefig(path.join(plt_pth,"quilt_rsf.png"),bbox_inches='tight') 136 | 137 | #%% Traditional scanpy analysis (unsupervised clustering) 138 | #https://scanpy-tutorials.readthedocs.io/en/latest/spatial/basic-analysis.html 139 | import pandas as pd 140 | import scanpy as sc 141 | sc.pp.pca(ad) 142 | sc.pp.neighbors(ad) 143 | sc.tl.umap(ad) 144 | sc.tl.leiden(ad, resolution=1.0, key_added="clusters") 145 | # plt.rcParams["figure.figsize"] = (4, 4) 146 | sc.pl.umap(ad, color="clusters", wspace=0.4) 147 | sc.pl.embedding(ad, "spatial", color="clusters") 148 | cl = pd.get_dummies(ad.obs["clusters"]).to_numpy() 149 | # tgnames = [str(i) for i in range(1,cl.shape[1]+1)] 150 | # hmkw = {"figsize":(7.7,5.9), "bgcol":"gray", "subplot_space":0.05, "marker":"s", 151 | # "s":2.9} 152 | hmkw = {"figsize":(8,2.5), "bgcol":"gray", "subplot_space":0.05, "marker":"s", 153 | "s":1} 154 | fig,axes=visualize.multiheatmap(X, cl, (2,6), cmap="Blues", **hmkw) 155 | # visualize.set_titles(fig, tgnames, x=0.03, y=.88, fontsize="small", c="white", 156 | # ha="left", va="top") 157 | fig.savefig(path.join(plt_pth,"quilt_scanpy_clusters.png"), 158 | bbox_inches='tight') 159 | -------------------------------------------------------------------------------- /simulations/bm_sp/05_ggblocks_exploratory.ipy: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # --- 3 | # jupyter: 4 | # jupytext: 5 | # text_representation: 6 | # extension: .ipy 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.6.0 10 | # kernelspec: 11 | # display_name: Python 3 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | # %% 17 | # import numpy as np 18 | from os import path 19 | from scanpy import read_h5ad 20 | from tensorflow_probability import math as tm 21 | tfk = tm.psd_kernels 22 | 23 | from utils import preprocess,training,misc,postprocess,visualize 24 | 25 | # rng = np.random.default_rng() 26 | pth = "simulations/bm_sp" 27 | dpth = path.join(pth,"data") 28 | mpth = path.join(pth,"models/S6/V5") 29 | plt_pth = path.join(pth,"results/plots") 30 | misc.mkdir_p(plt_pth) 31 | 32 | #%% Data loading 33 | ad = read_h5ad(path.join(dpth,"S6.h5ad")) 34 | N = ad.shape[0] 35 | Ntr = round(0.95*N) 36 | ad = ad[:Ntr,:] 37 | J = ad.shape[1] 38 | X = ad.obsm["spatial"] 39 | D_n,_ = preprocess.anndata_to_train_val(ad,train_frac=1.0,flip_yaxis=False) 40 | 41 | #%% Save heatmap of true values and sampled data 42 | hmkw = {"figsize":(8,1.9), "bgcol":"gray", "subplot_space":0.1, "marker":"s", 43 | "s":2.9} 44 | Ftrue = ad.obsm["spfac"] 45 | fig,axes=visualize.multiheatmap(X, Ftrue, (1,4), cmap="Blues", **hmkw) 46 | fig.savefig(path.join(plt_pth,"ggblocks_true_factors.png"),bbox_inches='tight') 47 | 48 | Yss = ad.layers["counts"][:,(4,0,1,2)] 49 | fig,axes=visualize.multiheatmap(X, Yss, (1,4), cmap="Blues", **hmkw) 50 | fig.savefig(path.join(plt_pth,"ggblocks_data.png"),bbox_inches='tight') 51 | 52 | # %% Initialize inducing points 53 | L = 4 54 | M = N #number of inducing points 55 | Z = X 56 | ker = tfk.MaternThreeHalves 57 | 58 | #%% NSF 59 | try: 60 | pp = 
path.join(mpth,"L{}/poi_sz-constant/NSF_{}_M{}".format(L,ker.__name__,M)) 61 | tro = training.ModelTrainer.from_pickle(pp) 62 | fit = tro.model 63 | # except FileNotFoundError: 64 | # fit = sf.SpatialFactorization(J,L,Z,psd_kernel=ker,nonneg=True,lik="poi") 65 | # fit.init_loadings(D["Y"],X=X,sz=D["sz"],shrinkage=0.3) 66 | # pp = fit.generate_pickle_path("constant",base=mpth) 67 | # tro = training.ModelTrainer(fit,pickle_path=pp) 68 | # %time tro.train_model(*Dtf) #12 mins 69 | insf = postprocess.interpret_nsf(fit,X,S=100,lda_mode=False) 70 | Fplot = insf["factors"][:,[3,1,2,0]] 71 | fig,axes=visualize.multiheatmap(X, Fplot, (1,4), cmap="Blues", **hmkw) 72 | fig.savefig(path.join(plt_pth,"ggblocks_nsf.png"),bbox_inches='tight') 73 | 74 | #%% PNMF 75 | try: 76 | pp = path.join(mpth,"L{}/poi_sz-constant/PNMF".format(L)) 77 | tro = training.ModelTrainer.from_pickle(pp) 78 | fit = tro.model 79 | # except FileNotFoundError: 80 | # fit = cf.CountFactorization(N,J,L,nonneg=True,lik="poi") 81 | # fit.init_loadings(D["Y"],sz=D["sz"],shrinkage=0.3) 82 | # pp = fit.generate_pickle_path("constant",base=mpth) 83 | # tro = training.ModelTrainer(fit,pickle_path=pp) 84 | # %time tro.train_model(*Dtf) #3 mins 85 | ipnmf = postprocess.interpret_pnmf(fit,S=100,lda_mode=False) 86 | Fplot = ipnmf["factors"][:,[3,1,2,0]] 87 | fig,axes=visualize.multiheatmap(X, Fplot, (1,4), cmap="Blues", **hmkw) 88 | fig.savefig(path.join(plt_pth,"ggblocks_pnmf.png"),bbox_inches='tight') 89 | 90 | #%% MEFISTO-Gaussian 91 | from models.mefisto import MEFISTO 92 | pp = path.join(mpth,"L{}/gau/MEFISTO_ExponentiatedQuadratic_M{}".format(L,M)) 93 | try: 94 | mef = MEFISTO.from_pickle(pp) 95 | except FileNotFoundError: 96 | mef = MEFISTO(D_n, L, inducing_pts=M, pickle_path=pp) 97 | %time mef.train() #also saves to pickle file- 9min 98 | Fplot = mef.get_factors() 99 | fig,axes=visualize.multiheatmap(X, Fplot, (1,4), cmap="RdBu", **hmkw) 100 | fig.savefig(path.join(plt_pth,"ggblocks_mefisto.png"),bbox_inches='tight') 101 | 102 | #%% FA: Non-spatial, real-valued 103 | try: 104 | pp = path.join(mpth,"L{}/gau/FA".format(L)) 105 | tro = training.ModelTrainer.from_pickle(pp) 106 | fit = tro.model 107 | # except FileNotFoundError: 108 | # fit = cf.CountFactorization(N, J, L, nonneg=False, lik="gau", 109 | # feature_means=fmeans) 110 | # fit.init_loadings(D_c["Y"]) 111 | # pp = fit.generate_pickle_path(None,base=mpth) 112 | # tro = training.ModelTrainer(fit,pickle_path=pp) 113 | # %time tro.train_model(*Dtf_c) #14sec 114 | Fplot = postprocess.interpret_fa(fit,S=100)["factors"] 115 | fig,axes=visualize.multiheatmap(X, Fplot, (1,4), cmap="RdBu", **hmkw) 116 | fig.savefig(path.join(plt_pth,"ggblocks_fa.png"),bbox_inches='tight') 117 | 118 | #%% RSF 119 | try: 120 | pp = path.join(mpth,"L{}/gau/RSF_{}_M{}".format(L,ker.__name__,M)) 121 | tro = training.ModelTrainer.from_pickle(pp) 122 | fit = tro.model 123 | # except FileNotFoundError: 124 | # fit = sf.SpatialFactorization(J,L,Z,psd_kernel=ker,nonneg=False,lik="gau") 125 | # fit.init_loadings(D_c["Y"],X=X) 126 | # pp = fit.generate_pickle_path(None,base=mpth) 127 | # tro = training.ModelTrainer(fit,pickle_path=pp) 128 | # %time tro.train_model(*Dtf_c) #5 mins 129 | Fplot = postprocess.interpret_rsf(fit,X,S=100)["factors"] 130 | fig,axes=visualize.multiheatmap(X, Fplot, (1,4), cmap="RdBu", **hmkw) 131 | fig.savefig(path.join(plt_pth,"ggblocks_rsf.png"),bbox_inches='tight') 132 | 133 | #%% Traditional scanpy analysis (unsupervised clustering) 134 | 
#https://scanpy-tutorials.readthedocs.io/en/latest/spatial/basic-analysis.html 135 | import pandas as pd 136 | import scanpy as sc 137 | sc.pp.pca(ad) 138 | sc.pp.neighbors(ad) 139 | sc.tl.umap(ad) 140 | sc.tl.leiden(ad, resolution=1.0, key_added="clusters") 141 | # plt.rcParams["figure.figsize"] = (4, 4) 142 | sc.pl.umap(ad, color="clusters", wspace=0.4) 143 | sc.pl.embedding(ad, "spatial", color="clusters") 144 | cl = pd.get_dummies(ad.obs["clusters"]).to_numpy() 145 | # tgnames = [str(i) for i in range(1,cl.shape[1]+1)] 146 | # hmkw = {"figsize":(7.7,5.9), "bgcol":"gray", "subplot_space":0.05, "marker":"s", 147 | # "s":2.9} 148 | # hmkw = {"figsize":(6,4), "bgcol":"gray", "subplot_space":0.05, "marker":"s", 149 | # "s":3.4} 150 | hmkw = {"figsize":(8,1.6), "bgcol":"gray", "subplot_space":0.05, "marker":"s", 151 | "s":1.6} 152 | fig,axes=visualize.multiheatmap(X, cl, (1,5), cmap="Blues", **hmkw) 153 | # visualize.set_titles(fig, tgnames, x=0.03, y=.88, fontsize="small", c="white", 154 | # ha="left", va="top") 155 | fig.savefig(path.join(plt_pth,"ggblocks_scanpy_clusters.png"), 156 | bbox_inches='tight') 157 | -------------------------------------------------------------------------------- /simulations/bm_sp/data/S1.h5ad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/willtownes/nsf-paper/0cacf8352e09d223ab8d4421025195358bbde8df/simulations/bm_sp/data/S1.h5ad -------------------------------------------------------------------------------- /simulations/bm_sp/data/S6.h5ad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/willtownes/nsf-paper/0cacf8352e09d223ab8d4421025195358bbde8df/simulations/bm_sp/data/S6.h5ad -------------------------------------------------------------------------------- /simulations/sim.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri Feb 4 09:05:19 2022 5 | 6 | @author: townesf 7 | """ 8 | import random 9 | import numpy as np 10 | from pandas import get_dummies 11 | from anndata import AnnData 12 | from scanpy import pp 13 | from utils import misc,preprocess 14 | dtp = "float32" 15 | random.seed(101) 16 | 17 | def squares(): 18 | A = np.zeros([12,12]) 19 | A[1:5,1:5] = 1 20 | A[7:11,1:5] = 1 21 | A[1:5,7:11] = 1 22 | A[7:11,7:11] = 1 23 | return A 24 | 25 | def corners(): 26 | B = np.zeros([6,6]) 27 | for i in range(6): 28 | B[i,i:] = 1 29 | A = np.flip(B,axis=1) 30 | AB = np.hstack((A,B)) 31 | CD = np.flip(AB,axis=0) 32 | return np.vstack((AB,CD)) 33 | 34 | def scotland(): 35 | A = np.eye(12) 36 | for i in range(12): 37 | A[-i-1,i] = 1 38 | return A 39 | 40 | def checkers(): 41 | A = np.zeros([4,4]) 42 | B = np.ones([4,4]) 43 | AB = np.hstack((A,B,A)) 44 | BA = np.hstack((B,A,B)) 45 | return np.vstack((AB,BA,AB)) 46 | 47 | def quilt(): 48 | A = np.zeros([4,144]) 49 | A[0,:] = squares().flatten() 50 | A[1,:] = corners().flatten() 51 | A[2,:] = scotland().flatten() 52 | A[3,:] = checkers().flatten() 53 | return A #basic block size is 12x12 54 | 55 | def ggblocks(): 56 | A = np.zeros( [ 4 , 36 ] ) 57 | A[0, [ 1 , 6 , 7 , 8 , 13 ] ] = 1 58 | A[1, [ 3 , 4 , 5 , 9 , 11 , 15 , 16 , 17 ] ] = 1 59 | A[2, [ 18 , 24 , 25 , 30 , 31 , 32 ] ] = 1 60 | A[3, [ 21 , 22 , 23 , 28 , 34 ] ] = 1 61 | return A #basic block size is 6x6 62 | 63 | def sqrt_int(x): 64 | z = int(round(x**.5)) 65 | if x==z**2: 66 | return z 67 | else: 68 | raise 
ValueError("x must be a square integer") 69 | 70 | def gen_spatial_factors(scenario="quilt",nside=36): 71 | """ 72 | Generate the factors matrix for either the 'quilt' or 'ggblocks' scenario 73 | There are 4 basic patterns (L=4) 74 | There are N=(nside^2) observations. 75 | Returns: 76 | an [N x 4] matrix of factor values 77 | """ 78 | if scenario=="quilt": 79 | A = quilt() 80 | elif scenario=="ggblocks": 81 | A = ggblocks() 82 | else: 83 | raise ValueError("scenario must be 'quilt' or 'ggblocks'") 84 | unit = sqrt_int(A.shape[1]) #quilt: 12, ggblocks: 6 85 | assert nside%unit==0 86 | ncopy = nside//unit 87 | N = nside**2 #36x36=1296 88 | L = A.shape[0] #4 89 | A = A.reshape((L,unit,unit)) 90 | A = np.kron(A,np.ones((1,ncopy,ncopy))) 91 | F = A.reshape((L,N)).T #NxL 92 | return F 93 | 94 | def gen_spatial_coords(N): #N is number of observations 95 | X = misc.make_grid(N) 96 | X[:,1] = -X[:,1] #flip y axis so the plotted orientation matches the pattern matrices 97 | return preprocess.rescale_spatial_coords(X) 98 | 99 | def gen_nonspatial_factors(N,L=3,nzprob=0.2,seed=101): 100 | rng = np.random.default_rng(seed) 101 | return rng.binomial(1,nzprob,size=(N,L)) 102 | 103 | def gen_loadings(Lsp, Lns=3, Jsp=0, Jmix=500, Jns=0, expr_mean=20.0, 104 | mix_frac_spat=0.55, seed=101, **kwargs): 105 | """ 106 | generate spatial (JxLsp) and nonspatial (JxLns) loadings matrices, where J=Jsp+Jmix+Jns features 107 | kwargs are currently ignored 108 | """ 109 | rng = np.random.default_rng(seed) 110 | J = Jsp+Jmix+Jns #total number of features 111 | if Lsp>0: 112 | w = rng.choice(Lsp,J,replace=True) #spatial loadings 113 | W = get_dummies(w).to_numpy(dtype=dtp) #JxLsp indicator matrix 114 | else: 115 | W = np.zeros((J,0)) 116 | if Lns>0: 117 | v = rng.choice(Lns,J,replace=True) #nonspatial loadings 118 | V = get_dummies(v).to_numpy(dtype=dtp) #JxLns indicator matrix 119 | else: 120 | V = np.zeros((J,0)) 121 | #pure spatial features 122 | W[:Jsp,:]*=expr_mean 123 | V[:Jsp,:]=0 124 | #features with mixed assignment to spatial and nonspatial components 125 | W[Jsp:(Jsp+Jmix),:]*=(mix_frac_spat*expr_mean) 126 | V[Jsp:(Jsp+Jmix),:]*=((1-mix_frac_spat)*expr_mean) 127 | #pure nonspatial features 128 | W[(Jsp+Jmix):,:]=0 129 | V[(Jsp+Jmix):,:]*=expr_mean 130 | return W,V 131 | 132 | def sim2anndata(locs, outcome, spfac, spload, nsfac=None, nsload=None): 133 | """ 134 | locs: spatial coordinates (NxD), outcome: count matrix (NxJ), spfac/spload: spatial factors and loadings, nsfac/nsload: optional nonspatial factors and loadings 135 | returns: an AnnData object with raw counts in .layers["counts"] and log1p-transformed counts in .X 136 | obsm slots: spatial, spfac, nsfac (coordinates and factors) 137 | varm slots: spload, nsload (loadings) 138 | """ 139 | obsm = {"spatial":locs, "spfac":spfac, "nsfac":nsfac} 140 | varm = {"spload":spload, "nsload":nsload} 141 | ad = AnnData(outcome, obsm=obsm, varm=varm) 142 | ad.layers = {"counts":ad.X.copy()} #store raw counts before the log transformation changes ad.X 143 | #here we do not normalize because the total counts are actually meaningful! 
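# raw counts remain available in ad.layers["counts"] for the count-likelihood
# models (lik="poi" with size factors in the exploratory scripts), while ad.X is
# only log1p-transformed below for use by the Gaussian-likelihood models; see
# utils/preprocess.load_data, which reads layer="counts" for raw counts and
# layer=None for the transformed data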
144 | # pp.normalize_total(ad, inplace=True, layers=None, key_added="sizefactor") 145 | pp.log1p(ad) 146 | #shuffle indices to disperse validation data throughout the field of view 147 | idx = list(range(ad.shape[0])) 148 | random.shuffle(idx) 149 | ad = ad[idx,:] 150 | return ad 151 | 152 | def sim(scenario, nside=36, nzprob_nsp=0.2, bkg_mean=0.2, nb_shape=10.0, 153 | seed=101, **kwargs): 154 | """ 155 | scenario: either 'quilt'(L=4), 'ggblocks'(L=4), or 'both'(L=8) 156 | N=number of observations is nside**2 157 | nzprob_nsp: for nonspatial factors, the probability of a "one" (else zero) 158 | bkg_mean: negative binomial mean for observations that are "zero" in the factors 159 | nb_shape: shape parameter of negative binomial distribution 160 | seed: for random number generation reproducibility 161 | kwargs: passed to gen_loadings 162 | """ 163 | if scenario=="both": 164 | F1 = gen_spatial_factors(nside=nside,scenario="ggblocks") 165 | F2 = gen_spatial_factors(nside=nside,scenario="quilt") 166 | F = np.hstack((F1,F2)) 167 | else: 168 | F = gen_spatial_factors(scenario=scenario,nside=nside) 169 | rng = np.random.default_rng(seed) 170 | N = nside**2 171 | X = gen_spatial_coords(N) 172 | W,V = gen_loadings(F.shape[1],seed=seed, **kwargs) 173 | U = gen_nonspatial_factors(N,L=V.shape[1],nzprob=nzprob_nsp,seed=seed) 174 | Lambda = bkg_mean+F@W.T+U@V.T #NxJ 175 | r = nb_shape 176 | Y = rng.negative_binomial(r,r/(Lambda+r)) 177 | return sim2anndata(X,Y,F,W,nsfac=U,nsload=V) 178 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/willtownes/nsf-paper/0cacf8352e09d223ab8d4421025195358bbde8df/utils/__init__.py -------------------------------------------------------------------------------- /utils/benchmark_array.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=benchmark # create a short name for your job 3 | #SBATCH --nodes=1 # node count 4 | #SBATCH --ntasks=1 # total number of tasks across all nodes 5 | #SBATCH --cpus-per-task=12 # cpu-cores per task (>1 if multi-threaded tasks) 6 | #SBATCH --time=12:00:00 # total run time limit (HH:MM:SS) 7 | #SBATCH --mail-type=END,FAIL 8 | #SBATCH --mail-user=ftownes@princeton.edu 9 | 10 | #example usage; the default is 4G of memory per core (--mem-per-cpu), --mem is the total 11 | #CSV=./scrna/visium_brain_sagittal/results/benchmark.csv 12 | #DAT=./scrna/visium_brain_sagittal/data/visium_brain_sagittal_J2000.h5ad 13 | #sbatch --mem=72G --array=1-$(wc -l < $CSV) ./utils/benchmark_array.slurm $DAT 14 | 15 | module purge 16 | module load anaconda3/2021.5 17 | conda activate fwt 18 | 19 | #first command line arg $1 is file path to dataset 20 | python -um utils.benchmark $SLURM_ARRAY_TASK_ID $1 21 | -------------------------------------------------------------------------------- /utils/benchmark_gof.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri Sep 17 08:50:51 2021 5 | 6 | @author: townesf 7 | """ 8 | 9 | from argparse import ArgumentParser 10 | from utils import benchmark 11 | 12 | def arghandler(args=None): 13 | """parses a list of arguments (default is sys.argv[1:])""" 14 | parser = ArgumentParser() 15 | parser.add_argument("dataset", type=str, 16 | help="location of scanpy H5AD data file") 17 | 
parser.add_argument("val_pct", type=int, 18 | help="percentage of data to be used as validation set (0-100)") 19 | args = parser.parse_args(args) #if args is None, this will automatically parse sys.argv[1:] 20 | return args 21 | 22 | if __name__=="__main__": 23 | args = arghandler() 24 | # dat = "scrna/sshippo/data/sshippo_J2000.h5ad" 25 | res = benchmark.update_results(args.dataset, val_pct=args.val_pct, todisk=True) 26 | -------------------------------------------------------------------------------- /utils/benchmark_gof.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=benchmark_gof # create a short name for your job 3 | #SBATCH --nodes=1 # node count 4 | #SBATCH --ntasks=1 # total number of tasks across all nodes 5 | #SBATCH --cpus-per-task=12 # cpu-cores per task (>1 if multi-threaded tasks) 6 | #SBATCH --time=3:00:00 # total run time limit (HH:MM:SS) 7 | #SBATCH --mail-type=END,FAIL 8 | #SBATCH --mail-user=ftownes@princeton.edu 9 | 10 | #example usage, --mem-per-cpu default is 4G per core 11 | #DAT=./scrna/sshippo/data/sshippo_J2000.h5ad 12 | #sbatch --mem=180G ./utils/benchmark_gof.slurm $DAT 13 | 14 | module purge 15 | module load anaconda3/2021.5 16 | conda activate fwt 17 | 18 | #first command line arg $1 is file path to dataset 19 | #second command line arg $2 is an integer for the pct of data to be validation set (typically 5, implying 95% of observations are for training data) 20 | python -um utils.benchmark_gof $1 $2 21 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Apr 6 18:10:06 2021 5 | 6 | @author: townesf 7 | """ 8 | import pathlib 9 | import numpy as np 10 | #from pickle import dump,load 11 | from math import ceil 12 | from copy import deepcopy 13 | from dill import dump,load 14 | from pandas import read_csv as pd_read_csv 15 | from tensorflow import clip_by_value 16 | from scipy.sparse import issparse 17 | from sklearn.cluster import KMeans 18 | from sklearn.utils import sparsefuncs 19 | from anndata import AnnData 20 | from squidpy.gr import spatial_neighbors,spatial_autocorr 21 | 22 | def mkdir_p(pth): 23 | pathlib.Path(pth).mkdir(parents=True,exist_ok=True) 24 | 25 | def rm_glob(pth,glob="*"): 26 | """ 27 | Remove all files in the pth folder matching glob. 28 | """ 29 | for i in pathlib.Path(pth).glob(glob): 30 | i.unlink(missing_ok=True) 31 | 32 | def poisson_loss(y,mu): 33 | """ 34 | Equivalent to the Tensorflow Poisson loss 35 | https://www.tensorflow.org/api_docs/python/tf/keras/losses/Poisson 36 | It's the negative log-likelihood of Poisson without the log y! 
constant 37 | """ 38 | with np.errstate(divide='ignore',invalid='ignore'): 39 | res = mu-y*np.log(mu) 40 | return np.mean(res[np.isfinite(res)]) 41 | 42 | def poisson_deviance(y,mu,agg="sum",axis=None): 43 | """ 44 | Equivalent to "KL divergence" between y and mu: 45 | https://scikit-learn.org/stable/modules/decomposition.html#nmf 46 | """ 47 | with np.errstate(divide='ignore',invalid='ignore'): 48 | term1 = y*np.log(y/mu) 49 | if agg=="sum": 50 | aggfunc = np.sum 51 | elif agg=="mean": 52 | aggfunc = np.mean 53 | else: 54 | raise ValueError("agg must be 'sum' or 'mean'") 55 | term1 = aggfunc(term1[np.isfinite(term1)], axis=axis) 56 | return term1 + aggfunc(mu - y, axis=axis) 57 | 58 | def rmse(y,mu): 59 | return np.sqrt(((y-mu)**2).mean()) 60 | 61 | def dev2ss(dev): 62 | """ 63 | dev: a 1d array of deviance values (one per feature) 64 | returns: a dictionary with summary statistics (mean, argmax, and max) 65 | """ 66 | return {"mean":dev.mean(), "argmax":dev.argmax(), "max":dev.max(), "med":np.median(dev)} 67 | 68 | def make_nonneg(x): 69 | return clip_by_value(x,0.0,np.inf) 70 | 71 | def make_grid(N,xmin=-2,xmax=2,dtype="float32"): 72 | x = np.linspace(xmin,xmax,num=int(np.sqrt(N)),dtype=dtype) 73 | return np.stack([X.ravel() for X in np.meshgrid(x,x)],axis=1) 74 | 75 | def kmeans_inducing_pts(X,M): 76 | M = int(M) 77 | Z = np.unique(X, axis=0) 78 | unique_locs = Z.shape[0] 79 | if M0 207 | #if row_id=1, skip=[] 208 | #if row_id=2, skip=[1] 209 | #if row_id=(file length-1), skip=[1,2,3,...,file_length-2] 210 | skip = range(1,row_id) 211 | #returns a pandas Series object 212 | return pd_read_csv(csv_file,skiprows=skip,nrows=1).iloc[0,:] 213 | 214 | def dims_autocorr(factors,coords,sort=True): 215 | """ 216 | factors: (num observations) x (num latent dimensions) array 217 | coords: (num observations) x (num spatial dimensions) array 218 | sort: if True (default), returns the index and I statistics in decreasing 219 | order of autocorrelation. If False, returns the index and I statistics 220 | according to the ordering of factors. 221 | 222 | returns: an integer array of length (num latent dims), "idx" 223 | and a numpy array containing the Moran's I values for each dimension 224 | 225 | indexing factors[:,idx] will sort the factors in decreasing order of spatial 226 | autocorrelation. 227 | """ 228 | ad = AnnData(X=factors,obsm={"spatial":coords}) 229 | spatial_neighbors(ad) 230 | df = spatial_autocorr(ad,mode="moran",copy=True) 231 | if not sort: #revert to original sort order 232 | df.sort_index(inplace=True) 233 | idx = np.array([int(i) for i in df.index]) 234 | return idx,df["I"].to_numpy() 235 | 236 | def nan_to_zero(X): 237 | X = deepcopy(X) 238 | X[np.isnan(X)] = 0 239 | return X 240 | -------------------------------------------------------------------------------- /utils/nnfu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Utility functions for working with nonnegative factor models. See also 5 | postprocess.py for more complex functions to facilitate interpretation. 
6 | 7 | Created on Thu Sep 23 13:30:10 2021 8 | 9 | @author: townesf 10 | """ 11 | import numpy as np 12 | from sklearn.decomposition import NMF 13 | # from scipy.special import logsumexp 14 | from utils.misc import lnormal_approx_dirichlet 15 | 16 | def normalize_cols(W): 17 | """ 18 | Rescale the columns of a matrix to sum to one 19 | """ 20 | wsum = W.sum(axis=0) 21 | return W/wsum, wsum 22 | 23 | def normalize_rows(W): 24 | """ 25 | Rescale the rows of a matrix to sum to one 26 | """ 27 | wsum = W.sum(axis=1) 28 | return (W.T/wsum).T, wsum 29 | 30 | def shrink_factors(F,shrinkage=0.2): 31 | a = shrinkage 32 | if 0=N: Dval = None #avoid returning an empty array 144 | return Dtr,Dval 145 | 146 | def center_data(Dtr_n,Dval_n=None): 147 | Dtr_c = deepcopy(Dtr_n) 148 | feature_means=Dtr_c["Y"].mean(axis=0) 149 | Dtr_c["Y"] -= feature_means 150 | if Dval_n: 151 | Dval_c = deepcopy(Dval_n) 152 | Dval_c["Y"] -= feature_means 153 | else: 154 | Dval_c = None 155 | return feature_means,Dtr_c,Dval_c 156 | 157 | def minibatch_size_adjust(num_obs,batch_size): 158 | """ 159 | Calculate adjusted minibatch size that divides 160 | num_obs as evenly as possible 161 | num_obs : number of observations in full data 162 | batch_size : maximum size of a minibatch 163 | """ 164 | nbatch = ceil(num_obs/float(batch_size)) 165 | return int(ceil(num_obs/nbatch)) 166 | 167 | def prepare_datasets_tf(Dtrain,Dval=None,shuffle=False,batch_size=None): 168 | """ 169 | Dtrain and Dval are dicts containing numpy np.arrays of data. 170 | Dtrain must contain the key "Y" 171 | Returns a from_tensor_slices conversion of Dtrain and a dict of tensors for Dval 172 | """ 173 | Ntr = Dtrain["Y"].shape[0] 174 | if batch_size is None: 175 | #ie one batch containing all observations by default 176 | batch_size = Ntr 177 | else: 178 | batch_size = minibatch_size_adjust(Ntr,batch_size) 179 | Dtrain = Dataset.from_tensor_slices(Dtrain) 180 | if shuffle: 181 | Dtrain = Dtrain.shuffle(Ntr) 182 | Dtrain = Dtrain.batch(batch_size) 183 | if Dval is not None: 184 | Dval = {i:constant(Dval[i]) for i in Dval} 185 | return Dtrain, Ntr, Dval 186 | 187 | def load_data(dataset, model=None, lik=None, train_frac=0.95, sz="constant", 188 | flip_yaxis=True): 189 | """ 190 | dataset: the file path of a scanpy anndata h5ad file 191 | --OR-- the AnnData object itself 192 | p: a dict-like object of model parameters 193 | """ 194 | try: 195 | ad = read_h5ad(path.normpath(dataset)) 196 | except TypeError: 197 | ad = dataset 198 | kw1 = {"nfeat":None, "train_frac":train_frac, "dtp":"float32", 199 | "flip_yaxis":flip_yaxis} 200 | kw2 = {"shuffle":False,"batch_size":None} 201 | D = {"raw":{}} 202 | Dtr,Dval = anndata_to_train_val(ad,layer="counts",sz=sz,**kw1) 203 | D["raw"]["tr"] = Dtr 204 | D["raw"]["val"] = Dval 205 | D["raw"]["tf"] = prepare_datasets_tf(Dtr,Dval=Dval,**kw2) 206 | fmeans=None 207 | if lik is None or lik=="gau": 208 | #normalized data 209 | Dtr_n,Dval_n = anndata_to_train_val(ad,layer=None,sz="constant",**kw1) 210 | D["norm"] = {} 211 | D["norm"]["tr"] = Dtr_n 212 | D["norm"]["val"] = Dval_n 213 | D["norm"]["tf"] = prepare_datasets_tf(Dtr_n,Dval=Dval_n,**kw2) 214 | if model is None or model in ("RSF","FA"): 215 | #centered features 216 | fmeans,Dtr_c,Dval_c = center_data(Dtr_n,Dval_n) 217 | D["ctr"] = {} 218 | D["ctr"]["tr"] = Dtr_c 219 | D["ctr"]["val"] = Dval_c 220 | D["ctr"]["tf"] = prepare_datasets_tf(Dtr_c,Dval=Dval_c,**kw2) 221 | return D,fmeans 222 | 223 | # def split_data_tuple(D,train_frac=0.8,shuffle=True): 224 | # """ 225 | # D is 
list of data [X,Y,sz] 226 | # leading dimension of each element of D must be the same 227 | # train_frac: fraction of observations that should go into training data 228 | # 1-train_frac: observations for validation data 229 | # """ 230 | # n = D[0].shape[0] 231 | # idx = list(range(n)) 232 | # if shuffle: random.shuffle(idx) 233 | # ntr = round(train_frac*n) 234 | # itr = idx[:ntr] 235 | # ival = idx[ntr:n] 236 | # Dtr = [] 237 | # Dval = [] 238 | # for d in D: 239 | # shp = d.shape 240 | # if len(shp)==1: 241 | # Dtr.append(d[itr]) 242 | # Dval.append(d[ival]) 243 | # else: 244 | # Dtr.append(d[itr,:]) 245 | # Dval.append(d[ival,:]) 246 | # return tuple(Dtr),tuple(Dval) 247 | 248 | # def calc_hidden_layer_sizes(J,T,N): 249 | # """ 250 | # J is input dimension, T is output dimension, N is number of hidden layers. 251 | # Returns a tuple of length (N) containing the widths of hidden layers. 252 | # The spacing forms a linear decline from J to T 253 | # """ 254 | # delta = float(J-T)/(N+1) #increment size 255 | # res = np.round(J-delta*np.array(range(1,N+1))) 256 | # #res = res.astype("int32").tolist() 257 | # #return [J]+res+[T] 258 | # return res.astype("int32") 259 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Functions for visualization 5 | 6 | Created on Wed May 19 09:58:59 2021 7 | 8 | @author: townesf 9 | """ 10 | 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | from contextlib import suppress 14 | from scipy.interpolate import interp1d 15 | from scipy.spatial import Delaunay 16 | 17 | from utils.misc import poisson_deviance,dev2ss,rmse 18 | 19 | def heatmap(X,y,figsize=(6,4),bgcol="gray",cmap="turbo",**kwargs): 20 | fig,ax=plt.subplots(figsize=figsize) 21 | ax.set_facecolor(bgcol) 22 | ax.scatter(X[:,0],X[:,1],c=y,cmap=cmap,**kwargs) 23 | # fig.show() 24 | 25 | def hide_spines(ax): 26 | for side in ax.spines: 27 | ax.spines[side].set_visible(False) 28 | 29 | def color_spines(ax,col="black"): 30 | # ax.spines['top'].set_color(col) 31 | # ax.spines['right'].set_color(col) 32 | # ax.spines['bottom'].set_color(col) 33 | # ax.spines['left'].set_color(col) 34 | for side in ax.spines: 35 | ax.spines[side].set_color(col) 36 | 37 | def set_titles(fig,titles,**kwargs): 38 | for i in range(len(titles)): 39 | ax = fig.axes[i] 40 | ax.set_title(titles[i],**kwargs) 41 | 42 | def hide_axes(ax): 43 | ax.tick_params(bottom=False,left=False,labelbottom=False,labelleft=False) 44 | 45 | def multiheatmap(X, Y, grid, figsize=(6,4), cmap="turbo", bgcol="gray", 46 | axhide=True, subplot_space=None, spinecolor=None, 47 | savepath=None, **kwargs): 48 | if subplot_space is not None: 49 | gridspec_kw = {'wspace':subplot_space, 'hspace':subplot_space} 50 | else: 51 | gridspec_kw = {} 52 | fig, axgrid = plt.subplots(*grid, figsize=figsize, gridspec_kw=gridspec_kw) 53 | # if subplot_space is not None: 54 | # plt.subplots_adjust(wspace=subplot_space, hspace=subplot_space) 55 | for i in range(len(fig.axes)): 56 | ax = fig.axes[i] 57 | ax.set_facecolor(bgcol) 58 | if i=0 201 | 202 | def hull_tile(Z, N): 203 | Zg = bounding_box_tile(Z,N) 204 | return Zg[in_hull(Zg,Z)] 205 | 206 | --------------------------------------------------------------------------------
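A minimal usage sketch (not one of the repository files above) showing how the simulation and visualization utilities fit together. It assumes the repository root is on the Python path so that simulations.sim and utils.visualize import as packaged, and the output filename is arbitrary; it simulates the "ggblocks" scenario from sim.py and plots the ground-truth spatial factors with visualize.multiheatmap, using the same heatmap settings as the exploratory script.

#!/usr/bin/env python3
# Illustrative sketch: simulate "ggblocks" data and plot the true spatial factors.
from simulations import sim
from utils import visualize

ad = sim.sim("ggblocks", nside=36)   # AnnData with counts, coordinates, true factors
X = ad.obsm["spatial"]               # (N x 2) spatial coordinates
F = ad.obsm["spfac"]                 # (N x 4) ground-truth spatial factor values
hmkw = {"figsize": (8, 1.6), "bgcol": "gray", "subplot_space": 0.05,
        "marker": "s", "s": 1.6}     # heatmap settings matching the exploratory script
fig, axes = visualize.multiheatmap(X, F, (1, 4), cmap="Blues", **hmkw)
fig.savefig("ggblocks_true_factors.png", bbox_inches="tight")  # hypothetical output path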