├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── demo.ipynb
├── environment.yml
├── models
│   ├── __init__.py
│   ├── cf.py
│   ├── likelihoods.py
│   ├── mefisto.py
│   ├── sf.py
│   └── sfh.py
├── nsf-paper.Rproj
├── requirements.txt
├── scrna
│   ├── .gitignore
│   ├── 01_viz_spatial_importance.Rmd
│   ├── sshippo
│   │   ├── 01_data_loading.Rmd
│   │   ├── 02_data_loading.ipy
│   │   ├── 03_benchmark.ipy
│   │   ├── 04_exploratory.ipy
│   │   ├── 05_benchmark_viz.Rmd
│   │   ├── 06_interpret_genes.Rmd
│   │   ├── 07_traditional.ipy
│   │   └── results
│   │       └── benchmark.csv
│   ├── utils
│   │   └── interpret_genes.R
│   ├── visium_brain_sagittal
│   │   ├── 01_data_loading.ipy
│   │   ├── 02_exploratory.ipy
│   │   ├── 03_benchmark.ipy
│   │   ├── 04_benchmark_viz.Rmd
│   │   ├── 05_interpret_genes.Rmd
│   │   ├── 06_traditional.ipy
│   │   └── results
│   │       └── benchmark.csv
│   └── xyzeq_liver
│       ├── 01_data_loading.ipy
│       ├── 02_exploratory.ipy
│       ├── 03_benchmark.ipy
│       ├── 04_benchmark_viz.Rmd
│       ├── 05_interpret_genes.Rmd
│       ├── 06_traditional.ipy
│       └── results
│           └── benchmark.csv
├── simulations
│   ├── .gitignore
│   ├── __init__.py
│   ├── benchmark.py
│   ├── benchmark.slurm
│   ├── benchmark_gof.py
│   ├── benchmark_gof.slurm
│   ├── bm_mixed
│   │   ├── 01_data_generation.ipy
│   │   ├── 02_benchmark.ipy
│   │   ├── 03_benchmark_viz.Rmd
│   │   └── results
│   │       └── benchmark.csv
│   ├── bm_sp
│   │   ├── 01_data_generation.ipy
│   │   ├── 02_benchmark.ipy
│   │   ├── 03_benchmark_viz.Rmd
│   │   ├── 04_quilt_exploratory.ipy
│   │   ├── 05_ggblocks_exploratory.ipy
│   │   ├── data
│   │   │   ├── S1.h5ad
│   │   │   └── S6.h5ad
│   │   └── results
│   │       └── benchmark.csv
│   └── sim.py
└── utils
    ├── __init__.py
    ├── benchmark.py
    ├── benchmark_array.slurm
    ├── benchmark_gof.py
    ├── benchmark_gof.slurm
    ├── misc.py
    ├── nnfu.py
    ├── postprocess.py
    ├── preprocess.py
    ├── training.py
    └── visualize.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.jpg binary
2 | *.png binary
3 | *.pdf binary
4 | *.RData binary
5 | *.h5ad binary
6 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Added by WT
2 | .ipynb_checkpoints/
3 | # data
4 | # plots
5 | model_checkpoints
6 | *.rds
7 | .spyproject
8 | __pycache__
9 | .DS_Store
10 | tf_ckpts
11 | slurm-*.out
12 | slurm-*.err
13 | *.pyc
14 | *.pickle
15 | resources
16 |
17 | # History files
18 | .Rhistory
19 | .Rapp.history
20 |
21 | # Session Data files
22 | .RData
23 |
24 | # User-specific files
25 | .Ruserdata
26 |
27 | # Example code in package build process
28 | *-Ex.R
29 |
30 | # Output files from R CMD build
31 | /*.tar.gz
32 |
33 | # Output files from R CMD check
34 | /*.Rcheck/
35 |
36 | # RStudio files
37 | .Rproj.user/
38 |
39 | # produced vignettes
40 | vignettes/*.html
41 | vignettes/*.pdf
42 |
43 | # OAuth2 token, see https://github.com/hadley/httr/releases/tag/v0.3
44 | .httr-oauth
45 |
46 | # knitr and R markdown default cache directories
47 | *_cache/
48 | /cache/
49 |
50 | # Temporary files created by R markdown
51 | *.utf8.md
52 | *.knit.md
53 |
54 | # R Environment Variables
55 | .Renviron
56 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU LESSER GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 |
9 | This version of the GNU Lesser General Public License incorporates
10 | the terms and conditions of version 3 of the GNU General Public
11 | License, supplemented by the additional permissions listed below.
12 |
13 | 0. Additional Definitions.
14 |
15 | As used herein, "this License" refers to version 3 of the GNU Lesser
16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU
17 | General Public License.
18 |
19 | "The Library" refers to a covered work governed by this License,
20 | other than an Application or a Combined Work as defined below.
21 |
22 | An "Application" is any work that makes use of an interface provided
23 | by the Library, but which is not otherwise based on the Library.
24 | Defining a subclass of a class defined by the Library is deemed a mode
25 | of using an interface provided by the Library.
26 |
27 | A "Combined Work" is a work produced by combining or linking an
28 | Application with the Library. The particular version of the Library
29 | with which the Combined Work was made is also called the "Linked
30 | Version".
31 |
32 | The "Minimal Corresponding Source" for a Combined Work means the
33 | Corresponding Source for the Combined Work, excluding any source code
34 | for portions of the Combined Work that, considered in isolation, are
35 | based on the Application, and not on the Linked Version.
36 |
37 | The "Corresponding Application Code" for a Combined Work means the
38 | object code and/or source code for the Application, including any data
39 | and utility programs needed for reproducing the Combined Work from the
40 | Application, but excluding the System Libraries of the Combined Work.
41 |
42 | 1. Exception to Section 3 of the GNU GPL.
43 |
44 | You may convey a covered work under sections 3 and 4 of this License
45 | without being bound by section 3 of the GNU GPL.
46 |
47 | 2. Conveying Modified Versions.
48 |
49 | If you modify a copy of the Library, and, in your modifications, a
50 | facility refers to a function or data to be supplied by an Application
51 | that uses the facility (other than as an argument passed when the
52 | facility is invoked), then you may convey a copy of the modified
53 | version:
54 |
55 | a) under this License, provided that you make a good faith effort to
56 | ensure that, in the event an Application does not supply the
57 | function or data, the facility still operates, and performs
58 | whatever part of its purpose remains meaningful, or
59 |
60 | b) under the GNU GPL, with none of the additional permissions of
61 | this License applicable to that copy.
62 |
63 | 3. Object Code Incorporating Material from Library Header Files.
64 |
65 | The object code form of an Application may incorporate material from
66 | a header file that is part of the Library. You may convey such object
67 | code under terms of your choice, provided that, if the incorporated
68 | material is not limited to numerical parameters, data structure
69 | layouts and accessors, or small macros, inline functions and templates
70 | (ten or fewer lines in length), you do both of the following:
71 |
72 | a) Give prominent notice with each copy of the object code that the
73 | Library is used in it and that the Library and its use are
74 | covered by this License.
75 |
76 | b) Accompany the object code with a copy of the GNU GPL and this license
77 | document.
78 |
79 | 4. Combined Works.
80 |
81 | You may convey a Combined Work under terms of your choice that,
82 | taken together, effectively do not restrict modification of the
83 | portions of the Library contained in the Combined Work and reverse
84 | engineering for debugging such modifications, if you also do each of
85 | the following:
86 |
87 | a) Give prominent notice with each copy of the Combined Work that
88 | the Library is used in it and that the Library and its use are
89 | covered by this License.
90 |
91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license
92 | document.
93 |
94 | c) For a Combined Work that displays copyright notices during
95 | execution, include the copyright notice for the Library among
96 | these notices, as well as a reference directing the user to the
97 | copies of the GNU GPL and this license document.
98 |
99 | d) Do one of the following:
100 |
101 | 0) Convey the Minimal Corresponding Source under the terms of this
102 | License, and the Corresponding Application Code in a form
103 | suitable for, and under terms that permit, the user to
104 | recombine or relink the Application with a modified version of
105 | the Linked Version to produce a modified Combined Work, in the
106 | manner specified by section 6 of the GNU GPL for conveying
107 | Corresponding Source.
108 |
109 | 1) Use a suitable shared library mechanism for linking with the
110 | Library. A suitable mechanism is one that (a) uses at run time
111 | a copy of the Library already present on the user's computer
112 | system, and (b) will operate properly with a modified version
113 | of the Library that is interface-compatible with the Linked
114 | Version.
115 |
116 | e) Provide Installation Information, but only if you would otherwise
117 | be required to provide such information under section 6 of the
118 | GNU GPL, and only to the extent that such information is
119 | necessary to install and execute a modified version of the
120 | Combined Work produced by recombining or relinking the
121 | Application with a modified version of the Linked Version. (If
122 | you use option 4d0, the Installation Information must accompany
123 | the Minimal Corresponding Source and Corresponding Application
124 | Code. If you use option 4d1, you must provide the Installation
125 | Information in the manner specified by section 6 of the GNU GPL
126 | for conveying Corresponding Source.)
127 |
128 | 5. Combined Libraries.
129 |
130 | You may place library facilities that are a work based on the
131 | Library side by side in a single library together with other library
132 | facilities that are not Applications and are not covered by this
133 | License, and convey such a combined library under terms of your
134 | choice, if you do both of the following:
135 |
136 | a) Accompany the combined library with a copy of the same work based
137 | on the Library, uncombined with any other library facilities,
138 | conveyed under the terms of this License.
139 |
140 | b) Give prominent notice with the combined library that part of it
141 | is a work based on the Library, and explaining where to find the
142 | accompanying uncombined form of the same work.
143 |
144 | 6. Revised Versions of the GNU Lesser General Public License.
145 |
146 | The Free Software Foundation may publish revised and/or new versions
147 | of the GNU Lesser General Public License from time to time. Such new
148 | versions will be similar in spirit to the present version, but may
149 | differ in detail to address new problems or concerns.
150 |
151 | Each version is given a distinguishing version number. If the
152 | Library as you received it specifies that a certain numbered version
153 | of the GNU Lesser General Public License "or any later version"
154 | applies to it, you have the option of following the terms and
155 | conditions either of that published version or of any later version
156 | published by the Free Software Foundation. If the Library as you
157 | received it does not specify a version number of the GNU Lesser
158 | General Public License, you may choose any version of the GNU Lesser
159 | General Public License ever published by the Free Software Foundation.
160 |
161 | If the Library as you received it specifies that a proxy can decide
162 | whether future versions of the GNU Lesser General Public License shall
163 | apply, that proxy's public statement of acceptance of any version is
164 | permanent authorization for you to choose that version for the
165 | Library.
166 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Nonnegative spatial factorization for multivariate count data
2 |
3 | [![DOI](https://zenodo.org/badge/415147174.svg)](https://zenodo.org/badge/latestdoi/415147174)
4 |
5 | This repository contains supporting code to facilitate reproducible analysis. For details, see the [preprint](https://arxiv.org/abs/2110.06122). If you find bugs, please create a GitHub issue. An [installable Python package](https://github.com/willtownes/spatial-factorization-py)
6 | is also under development.
7 |
8 | ### Authors
9 |
10 | Will Townes and Barbara Engelhardt
11 |
12 | ### Abstract
13 |
14 | Gaussian processes are widely used for the analysis of spatial data due to their nonparametric flexibility and ability to quantify uncertainty, and recently developed scalable approximations have facilitated application to massive datasets. For multivariate outcomes, linear models of coregionalization combine dimension reduction with spatial correlation. However, their real-valued latent factors and loadings are difficult to interpret because, unlike nonnegative models, they do not recover a parts-based representation. We present nonnegative spatial factorization (NSF), a spatially-aware probabilistic dimension reduction model that naturally encourages sparsity. We compare NSF to real-valued spatial factorizations such as MEFISTO and nonspatial dimension reduction methods using simulations and high-dimensional spatial transcriptomics data. NSF identifies generalizable spatial patterns of gene expression. Since not all patterns of gene expression are spatial, we also propose a hybrid extension of NSF that combines spatial and nonspatial components, enabling quantification of spatial importance for both observations and features.
15 |
16 | ## Demo
17 |
18 | A basic demonstration ([**demo.ipynb**](https://github.com/willtownes/nsf-paper/blob/main/demo.ipynb)) using simulated data is provided as a [jupyter](https://jupyter.org) notebook. The expected output is a series of heatmap plots. The runtime should be about 5 minutes.
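One way to run it locally (a suggested invocation, assuming Jupyter is installed in the active environment):

```Shell
jupyter notebook demo.ipynb
```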
19 |
20 | ## Description of Repository Contents
21 | All scripts should be run from the top-level directory. Files with the suffix `.ipy` are essentially text-only versions of Jupyter notebooks and are best used through the [Spyder IDE](https://www.spyder-ide.org). They can be converted to full Jupyter notebooks using [jupytext](https://jupytext.readthedocs.io/en/latest/).
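For example, a single script could be converted along these lines (a sketch only; the exact flags may depend on the installed jupytext version and how it handles the `.ipy` extension):

```Shell
pip install jupytext
jupytext --to notebook scrna/sshippo/02_data_loading.ipy
```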
22 |
23 | ### models
24 |
25 | TensorFlow implementations of probabilistic factor models
26 | * *cf.py* - nonspatial models (factor analysis and probabilistic nonnegative matrix factorization).
27 | * *mefisto.py* - wrapper around the MEFISTO implementation in the [mofapy2](https://github.com/bioFAM/mofapy2/commit/8f6ffcb5b18d22b3f44ff2a06bcb92f2806afed0) python package.
28 | * *sf.py* - nonnegative and real-valued spatial process factorization (NSF and RSF).
29 | * *sfh.py* - NSF hybrid model, includes both spatial and nonspatial components.
30 |
31 | ### scrna
32 |
33 | Analysis of spatial transcriptomics data
34 | * *sshippo* - Slide-seqV2 mouse hippocampus
35 | * *visium_brain_sagittal* - Visium mouse brain (anterior sagittal section)
36 | * *xyzeq_liver* - XYZeq mouse liver/tumor
37 |
38 | ### simulations
39 |
40 | Data generation and model fitting for the ggblocks and quilt simulations.
41 | * *benchmark.py* - can be called as a command line script to facilitate benchmarking of large numbers of scenarios and parameter combinations.
42 | * *benchmark_gof.py* - compute goodness of fit and other metrics on fitted models.
43 | * *bm_mixed* - mixed spatial and nonspatial factors
44 | * *bm_sp* - spatial factors only. Within this folder, the notebooks
45 | `04_quilt_exploratory.ipy` and `05_ggblocks_exploratory.ipy` have many
46 | visualizations of the various models compared in the paper.
47 | * *sim.py* - functions for creating the simulated datasets.
48 |
49 | ### utils
50 |
51 | Python modules containing functions and classes needed by scripts and model implementation classes.
52 | * *benchmark.py* - functions used in fitting models to datasets and pickling the objects for later evaluation. Can be called as a command line script to facilitate automation.
53 | * *benchmark_gof.py* - script with a basic command-line interface for computing goodness-of-fit, sparsity, and timing statistics on large numbers of fitted model objects.
54 | * *misc.py* - miscellaneous convenience functions useful in preprocessing (normalization and reversing normalization), postprocessing, computing benchmarking statistics, parameter manipulation, and reading and writing pickle and CSV files.
55 | * *nnfu.py* - nonnegative factor model utility functions for rescaling and regularization. Useful in initialization and postprocessing.
56 | * *postprocess.py* - postprocessing functions to facilitate interpretation of nonnegative factor models.
57 | * *preprocess.py* - data loading and preprocessing functions. Normalization of count data, rescaling spatial coordinates for numerical stability, deviance functions for feature selection (analogous to [scry](https://doi.org/doi:10.18129/B9.bioc.scry)), conversions between AnnData and TensorFlow objects.
58 | * *training.py* - classes for fitting TensorFlow models to data, including caching with checkpoints, automatic handling of numeric instabilities, and ConvergenceChecker, which uses a cubic spline to detect convergence of a stochastic optimizer trace (a rough sketch of the spline idea is given after this list).
59 | * *visualize.py* - plotting functions for making heatmaps to visualize spatial and nonspatial factors, as well as some goodness-of-fit metrics.
60 |
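The ConvergenceChecker itself lives in *training.py*; the following is only an illustrative sketch of the cubic-spline convergence idea with made-up function and parameter names, not the repository's API:

```python
import numpy as np
from scipy.interpolate import UnivariateSpline

def trace_has_converged(loss_trace, window=50, rel_tol=1e-4):
    """Fit a cubic smoothing spline to the recent loss values and declare
    convergence when the smoothed slope is tiny relative to the loss scale."""
    y = np.asarray(loss_trace[-window:], dtype=float)
    if y.size < window:  # not enough history yet
        return False
    x = np.arange(y.size, dtype=float)
    spl = UnivariateSpline(x, y, k=3)   # cubic spline smooths out stochastic noise
    slope = spl.derivative()(x[-1])     # smoothed slope at the latest iteration
    return abs(slope) < rel_tol * max(abs(y[-1]), 1.0)
```
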
61 | ## System requirements
62 |
63 | We used the following versions in our analyses: Python 3.8.10, tensorflow 2.5.0, tensorflow probability 0.13.0, scanpy 1.8.0, squidpy 1.1.0, scikit-learn 0.24.2, pandas 1.2.5, numpy 1.19.5, scipy 1.7.0.
64 | We used the MEFISTO implementation from the mofapy2 Python package, installed from the GitHub development branch at commit 8f6ffcb5b18d22b3f44ff2a06bcb92f2806afed0.
65 |
66 | ```Shell
67 | pip install git+https://github.com/bioFAM/mofapy2.git@8f6ffcb5b18d22b3f44ff2a06bcb92f2806afed0
68 | ```
69 |
70 | Graphics were generated using either matplotlib 3.4.2 in Python or ggplot2 3.3.5 in R (version 4.1.0). The R packages Seurat 4.0.3, SeuratData 0.2.1, and SeuratDisk 0.0.0.9019 were used for some initial data manipulations.
71 |
72 | Computationally intensive model fitting was done on Princeton's [Della cluster](https://researchcomputing.princeton.edu/systems/della). Less computationally intensive analyses were done on personal computers running macOS 12.4.
73 |
74 | ## Installation
75 |
76 | ```Shell
77 | git clone https://github.com/willtownes/nsf-paper.git
78 | ```
79 |
80 | This should only take a few seconds on an ordinary computer with a good internet connection.
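A conda environment specification is provided in `environment.yml` (see also `requirements.txt`). One possible way to recreate it, assuming conda or mamba is installed:

```Shell
conda env create -f environment.yml
conda activate nsf-paper
```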
81 |
82 | ## Instructions for use
83 |
84 | Data should be stored as a Scanpy AnnData object with the raw counts in the layer "counts" and the spatial coordinates in the obsm["spatial"] slot. Utility functions to convert this into the TensorFlow objects required for model fitting are demonstrated in the demo. To reproduce results from the manuscript, use the numbered IPython scripts in each dataset's subfolder. Intermediate results from benchmarking are cached in `results/benchmark.csv` files, which can be used to produce many of the plots in the manuscript.
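As a minimal sketch of the expected layout (toy arrays; real data would instead come from a loader such as scanpy.read_h5ad):

```python
import numpy as np
from anndata import AnnData

counts = np.random.poisson(1.0, size=(200, 100)).astype(np.float32)  # toy cells x genes counts
xy = np.random.uniform(size=(200, 2)).astype(np.float32)             # toy spatial coordinates

ad = AnnData(X=counts)
ad.layers["counts"] = ad.X.copy()   # raw counts stored in the "counts" layer
ad.obsm["spatial"] = xy             # coordinates stored in obsm["spatial"]
```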
85 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: nsf-paper
2 | channels:
3 | - conda-forge
4 | dependencies:
5 | - abseil-cpp=20220623.0=he4e09e4_6
6 | - absl-py=1.4.0=pyhd8ed1ab_0
7 | - aiohttp=3.8.4=py38hb192615_1
8 | - aiosignal=1.3.1=pyhd8ed1ab_0
9 | - anndata=0.9.1=pyhd8ed1ab_0
10 | - aom=3.5.0=h7ea286d_0
11 | - appnope=0.1.3=pyhd8ed1ab_0
12 | - arpack=3.7.0=h58ebc17_2
13 | - arrow-cpp=11.0.0=hce30654_5_cpu
14 | - asciitree=0.3.3=py_2
15 | - asttokens=2.2.1=pyhd8ed1ab_0
16 | - astunparse=1.6.3=pyhd8ed1ab_0
17 | - async-timeout=4.0.2=pyhd8ed1ab_0
18 | - attrs=23.1.0=pyh71513ae_1
19 | - aws-c-auth=0.6.24=he8f13b4_5
20 | - aws-c-cal=0.5.20=h9571af1_6
21 | - aws-c-common=0.8.11=h1a8c8d9_0
22 | - aws-c-compression=0.2.16=h7334ab6_3
23 | - aws-c-event-stream=0.2.18=ha663d55_6
24 | - aws-c-http=0.7.4=h49dec38_2
25 | - aws-c-io=0.13.17=h323b671_2
26 | - aws-c-mqtt=0.8.6=hdc0f556_6
27 | - aws-c-s3=0.2.4=hbb4c6b3_3
28 | - aws-c-sdkutils=0.1.7=h7334ab6_3
29 | - aws-checksums=0.1.14=h7334ab6_3
30 | - aws-crt-cpp=0.19.7=h6f6c549_7
31 | - aws-sdk-cpp=1.10.57=hbe10753_4
32 | - backcall=0.2.0=pyh9f0ad1d_0
33 | - backports=1.0=pyhd8ed1ab_3
34 | - backports.functools_lru_cache=1.6.5=pyhd8ed1ab_0
35 | - blinker=1.6.2=pyhd8ed1ab_0
36 | - blosc=1.21.4=hc338f07_0
37 | - bokeh=3.1.1=pyhd8ed1ab_0
38 | - brotli=1.0.9=h1a8c8d9_9
39 | - brotli-bin=1.0.9=h1a8c8d9_9
40 | - brotlipy=0.7.0=py38hb991d35_1005
41 | - brunsli=0.1=h9f76cd9_0
42 | - bzip2=1.0.8=h3422bc3_4
43 | - c-ares=1.19.1=hb547adb_0
44 | - c-blosc2=2.9.3=h068da5f_0
45 | - ca-certificates=2023.5.7=hf0a4a13_0
46 | - cached-property=1.5.2=hd8ed1ab_1
47 | - cached_property=1.5.2=pyha770c72_1
48 | - cachetools=5.3.0=pyhd8ed1ab_0
49 | - certifi=2023.5.7=pyhd8ed1ab_0
50 | - cffi=1.15.1=py38ha45ccd6_3
51 | - cfitsio=4.2.0=h2f961c4_0
52 | - charls=2.4.2=h13dd4ca_0
53 | - charset-normalizer=3.1.0=pyhd8ed1ab_0
54 | - click=8.1.3=unix_pyhd8ed1ab_2
55 | - cloudpickle=2.2.1=pyhd8ed1ab_0
56 | - colorama=0.4.6=pyhd8ed1ab_0
57 | - comm=0.1.3=pyhd8ed1ab_0
58 | - contourpy=1.1.0=py38h9afee92_0
59 | - cryptography=41.0.1=py38h92a0862_0
60 | - cycler=0.11.0=pyhd8ed1ab_0
61 | - cytoolz=0.12.0=py38hb991d35_1
62 | - dask=2023.5.0=pyhd8ed1ab_0
63 | - dask-core=2023.5.0=pyhd8ed1ab_0
64 | - dask-image=2023.3.0=pyhd8ed1ab_0
65 | - dav1d=1.2.1=hb547adb_0
66 | - debugpy=1.6.7=py38h2b1e499_0
67 | - decorator=5.1.1=pyhd8ed1ab_0
68 | - dill=0.3.6=pyhd8ed1ab_1
69 | - distributed=2023.5.0=pyhd8ed1ab_0
70 | - dm-tree=0.1.7=py38h55de146_0
71 | - docrep=0.3.2=pyh44b312d_0
72 | - entrypoints=0.4=pyhd8ed1ab_0
73 | - executing=1.2.0=pyhd8ed1ab_0
74 | - fasteners=0.17.3=pyhd8ed1ab_0
75 | - flatbuffers=22.12.06=hb7217d7_2
76 | - fonttools=4.40.0=py38hb192615_0
77 | - freetype=2.12.1=hd633e50_1
78 | - frozenlist=1.3.3=py38hb991d35_0
79 | - fsspec=2023.6.0=pyh1a96a4e_0
80 | - gast=0.4.0=pyh9f0ad1d_0
81 | - gflags=2.2.2=hc88da5d_1004
82 | - giflib=5.2.1=h1a8c8d9_3
83 | - glog=0.6.0=h6da1cb0_0
84 | - glpk=5.0=h6d7a090_0
85 | - gmp=6.2.1=h9f76cd9_0
86 | - google-auth=2.21.0=pyh1a96a4e_0
87 | - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0
88 | - google-pasta=0.2.0=pyh8c360ce_0
89 | - grpc-cpp=1.51.1=h44b9a77_1
90 | - grpcio=1.51.1=py38h171e7b7_1
91 | - h5py=3.9.0=nompi_py38h8a8aaa0_101
92 | - hdf5=1.14.1=nompi_h3aba7b3_100
93 | - icu=70.1=h6b3803e_0
94 | - idna=3.4=pyhd8ed1ab_0
95 | - igraph=0.10.3=h80e09cb_0
96 | - imagecodecs=2023.1.23=py38h57345ed_0
97 | - imageio=2.31.1=pyh24c5eb1_0
98 | - importlib-metadata=6.7.0=pyha770c72_0
99 | - importlib-resources=5.12.0=pyhd8ed1ab_0
100 | - importlib_metadata=6.7.0=hd8ed1ab_0
101 | - importlib_resources=5.12.0=pyhd8ed1ab_0
102 | - inflect=6.0.4=pyhd8ed1ab_0
103 | - ipykernel=6.23.3=pyh5fb750a_0
104 | - ipython=8.12.2=pyhd1c38e8_0
105 | - jax=0.3.25=pyhd8ed1ab_0
106 | - jaxlib=0.3.25=cpu_py38ha029f96_1
107 | - jedi=0.18.2=pyhd8ed1ab_0
108 | - jinja2=3.1.2=pyhd8ed1ab_1
109 | - joblib=1.3.0=pyhd8ed1ab_0
110 | - jpeg=9e=h1a8c8d9_3
111 | - jupyter_client=8.3.0=pyhd8ed1ab_0
112 | - jupyter_core=5.3.1=py38h10201cd_0
113 | - jxrlib=1.1=h27ca646_2
114 | - keras=2.11.0=pyhd8ed1ab_0
115 | - keras-preprocessing=1.1.2=pyhd8ed1ab_0
116 | - kiwisolver=1.4.4=py38h9dc3d6a_1
117 | - krb5=1.20.1=h69eda48_0
118 | - lazy_loader=0.2=pyhd8ed1ab_0
119 | - lcms2=2.15=h481adae_0
120 | - leidenalg=0.9.1=py38h2b1e499_0
121 | - lerc=4.0.0=h9a09cb3_0
122 | - libabseil=20220623.0=cxx17_h28b99d4_6
123 | - libaec=1.0.6=hb7217d7_1
124 | - libarrow=11.0.0=h0b9b5d1_5_cpu
125 | - libavif=0.11.1=h9f83d30_2
126 | - libblas=3.9.0=17_osxarm64_openblas
127 | - libbrotlicommon=1.0.9=h1a8c8d9_9
128 | - libbrotlidec=1.0.9=h1a8c8d9_9
129 | - libbrotlienc=1.0.9=h1a8c8d9_9
130 | - libcblas=3.9.0=17_osxarm64_openblas
131 | - libcrc32c=1.1.2=hbdafb3b_0
132 | - libcurl=8.1.2=h912dcd9_0
133 | - libcxx=16.0.6=h4653b0c_0
134 | - libdeflate=1.17=h1a8c8d9_0
135 | - libedit=3.1.20191231=hc8eb9b7_2
136 | - libev=4.33=h642e427_1
137 | - libevent=2.1.10=h7673551_4
138 | - libffi=3.4.2=h3422bc3_5
139 | - libgfortran=5.0.0=12_2_0_hd922786_31
140 | - libgfortran5=12.2.0=h0eea778_31
141 | - libgoogle-cloud=2.7.0=hcf11473_1
142 | - libgrpc=1.51.1=hb15be72_1
143 | - libiconv=1.17=he4db4b2_0
144 | - liblapack=3.9.0=17_osxarm64_openblas
145 | - libllvm11=11.1.0=hfa12f05_5
146 | - libllvm14=14.0.6=hd1a9a77_3
147 | - libnghttp2=1.52.0=hae82a92_0
148 | - libopenblas=0.3.23=openmp_hc731615_0
149 | - libpng=1.6.39=h76d750c_0
150 | - libprotobuf=3.21.12=hb5ab8b9_0
151 | - libsodium=1.0.18=h27ca646_1
152 | - libsqlite=3.42.0=hb31c410_0
153 | - libssh2=1.11.0=h7a5bd25_0
154 | - libthrift=0.18.0=h6635e49_0
155 | - libtiff=4.5.0=h5dffbdd_2
156 | - libutf8proc=2.8.0=h1a8c8d9_0
157 | - libwebp-base=1.3.0=h1a8c8d9_0
158 | - libxcb=1.13=h9b22ae9_1004
159 | - libxml2=2.10.3=h67585b2_4
160 | - libzlib=1.2.13=h53f4e23_5
161 | - libzopfli=1.0.3=h9f76cd9_0
162 | - llvm-openmp=16.0.6=h1c12783_0
163 | - llvmlite=0.38.1=py38h8a5a59d_0
164 | - locket=1.0.0=pyhd8ed1ab_0
165 | - lz4=4.3.2=py38h76a69a3_0
166 | - lz4-c=1.9.4=hb7217d7_0
167 | - markdown=3.4.3=pyhd8ed1ab_0
168 | - markupsafe=2.1.3=py38hb192615_0
169 | - matplotlib=3.7.1=py38h150bfb4_0
170 | - matplotlib-base=3.7.1=py38hbbe890c_0
171 | - matplotlib-inline=0.1.6=pyhd8ed1ab_0
172 | - matplotlib-scalebar=0.8.1=pyhd8ed1ab_0
173 | - metis=5.1.0=h9f76cd9_1006
174 | - mpfr=4.2.0=he09a6ba_0
175 | - msgpack-python=1.0.5=py38h9dc3d6a_0
176 | - multidict=6.0.4=py38hb991d35_0
177 | - munkres=1.1.4=pyh9f0ad1d_0
178 | - natsort=8.4.0=pyhd8ed1ab_0
179 | - ncurses=6.4=h7ea286d_0
180 | - nest-asyncio=1.5.6=pyhd8ed1ab_0
181 | - networkx=3.1=pyhd8ed1ab_0
182 | - numba=0.55.2=py38h25e2f74_0
183 | - numcodecs=0.11.0=py38h2b1e499_1
184 | - numpy=1.22.4=py38he1fcd3f_0
185 | - oauthlib=3.2.2=pyhd8ed1ab_0
186 | - omnipath=1.0.7=pyhd8ed1ab_0
187 | - openjpeg=2.5.0=hbc2ba62_2
188 | - openssl=3.1.1=h53f4e23_1
189 | - opt_einsum=3.3.0=pyhd8ed1ab_1
190 | - orc=1.8.2=hef0d403_2
191 | - packaging=23.1=pyhd8ed1ab_0
192 | - pandas=1.5.3=py38h61dac83_1
193 | - parquet-cpp=1.5.1=2
194 | - parso=0.8.3=pyhd8ed1ab_0
195 | - partd=1.4.0=pyhd8ed1ab_0
196 | - patsy=0.5.3=pyhd8ed1ab_0
197 | - pexpect=4.8.0=pyh1a96a4e_2
198 | - pickleshare=0.7.5=py_1003
199 | - pillow=9.4.0=py38h1bb68ce_1
200 | - pims=0.6.1=pyhd8ed1ab_1
201 | - pip=23.1.2=pyhd8ed1ab_0
202 | - platformdirs=2.6.0=pyhd8ed1ab_0
203 | - prompt-toolkit=3.0.38=pyha770c72_0
204 | - prompt_toolkit=3.0.38=hd8ed1ab_0
205 | - protobuf=4.21.12=py38h2b1e499_0
206 | - psutil=5.9.5=py38hb991d35_0
207 | - pthread-stubs=0.4=h27ca646_1001
208 | - ptyprocess=0.7.0=pyhd3deb0d_0
209 | - pure_eval=0.2.2=pyhd8ed1ab_0
210 | - pyarrow=11.0.0=py38h32b283d_5_cpu
211 | - pyasn1=0.4.8=py_0
212 | - pyasn1-modules=0.2.7=py_0
213 | - pycparser=2.21=pyhd8ed1ab_0
214 | - pydantic=1.9.2=py38he5c2ac2_0
215 | - pygments=2.15.1=pyhd8ed1ab_0
216 | - pyjwt=2.7.0=pyhd8ed1ab_0
217 | - pynndescent=0.5.10=pyh1a96a4e_0
218 | - pyopenssl=23.2.0=pyhd8ed1ab_1
219 | - pyparsing=3.1.0=pyhd8ed1ab_0
220 | - pysocks=1.7.1=pyha2e5f31_6
221 | - python=3.8.17=h3ba56d0_0_cpython
222 | - python-dateutil=2.8.2=pyhd8ed1ab_0
223 | - python-flatbuffers=23.5.26=pyhd8ed1ab_0
224 | - python-igraph=0.10.4=py38h0504639_0
225 | - python-tzdata=2023.3=pyhd8ed1ab_0
226 | - python_abi=3.8=3_cp38
227 | - pytz=2023.3=pyhd8ed1ab_0
228 | - pyu2f=0.1.5=pyhd8ed1ab_0
229 | - pywavelets=1.4.1=py38hb39dbe9_0
230 | - pyyaml=6.0=py38hb991d35_5
231 | - pyzmq=25.1.0=py38hef91016_0
232 | - re2=2023.02.01=hb7217d7_0
233 | - readline=8.2=h92ec313_1
234 | - requests=2.31.0=pyhd8ed1ab_0
235 | - requests-oauthlib=1.3.1=pyhd8ed1ab_0
236 | - rsa=4.9=pyhd8ed1ab_0
237 | - scanpy=1.9.3=pyhd8ed1ab_0
238 | - scikit-image=0.20.0=py38hfaca753_1
239 | - scikit-learn=1.2.2=py38h971c870_2
240 | - scipy=1.9.1=py38h3aeb131_0
241 | - seaborn=0.12.2=hd8ed1ab_0
242 | - seaborn-base=0.12.2=pyhd8ed1ab_0
243 | - session-info=1.0.0=pyhd8ed1ab_0
244 | - setuptools=68.0.0=pyhd8ed1ab_0
245 | - six=1.15.0=pyh9f0ad1d_0
246 | - slicerator=1.1.0=pyhd8ed1ab_0
247 | - snappy=1.1.10=h17c5cce_0
248 | - sortedcontainers=2.4.0=pyhd8ed1ab_0
249 | - spyder-kernels=2.4.3=unix_pyhd8ed1ab_0
250 | - sqlite=3.42.0=h203b68d_0
251 | - squidpy=1.2.3=pyhd8ed1ab_0
252 | - stack_data=0.6.2=pyhd8ed1ab_0
253 | - statsmodels=0.14.0=py38h58b515d_1
254 | - stdlib-list=0.8.0=pyhd8ed1ab_0
255 | - suitesparse=5.10.1=h7cd81ec_1
256 | - tbb=2021.9.0=hffc8910_0
257 | - tblib=1.7.0=pyhd8ed1ab_0
258 | - tensorboard=2.11.2=pyhd8ed1ab_0
259 | - tensorboard-data-server=0.6.1=py38h23f6d3d_4
260 | - tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0
261 | - tensorflow=2.11.0=cpu_py38h4487131_0
262 | - tensorflow-base=2.11.0=cpu_py38h968a1bb_0
263 | - tensorflow-estimator=2.11.0=cpu_py38h8d8ceda_0
264 | - tensorflow-probability=0.19.0=pyhd8ed1ab_1
265 | - termcolor=1.1.0=pyhd8ed1ab_3
266 | - texttable=1.6.7=pyhd8ed1ab_0
267 | - threadpoolctl=3.1.0=pyh8a188c0_0
268 | - tifffile=2023.4.12=pyhd8ed1ab_0
269 | - tk=8.6.12=he1e0b03_0
270 | - toolz=0.12.0=pyhd8ed1ab_0
271 | - tornado=6.3.2=py38hb192615_0
272 | - tqdm=4.65.0=pyhd8ed1ab_1
273 | - traitlets=5.9.0=pyhd8ed1ab_0
274 | - typing-extensions=3.7.4.3=0
275 | - typing_extensions=3.7.4.3=py_0
276 | - umap-learn=0.5.3=py38h10201cd_1
277 | - unicodedata2=15.0.0=py38hb991d35_0
278 | - urllib3=1.26.15=pyhd8ed1ab_0
279 | - validators=0.20.0=pyhd8ed1ab_0
280 | - wcwidth=0.2.6=pyhd8ed1ab_0
281 | - werkzeug=2.3.6=pyhd8ed1ab_0
282 | - wheel=0.40.0=pyhd8ed1ab_0
283 | - wrapt=1.12.1=py38hea4295b_3
284 | - wurlitzer=3.0.3=pyhd8ed1ab_0
285 | - xarray=2023.1.0=pyhd8ed1ab_0
286 | - xorg-libxau=1.0.11=hb547adb_0
287 | - xorg-libxdmcp=1.1.3=h27ca646_0
288 | - xyzservices=2023.5.0=pyhd8ed1ab_1
289 | - xz=5.2.6=h57fd34a_0
290 | - yaml=0.2.5=h3422bc3_2
291 | - yarl=1.9.2=py38hb192615_0
292 | - zarr=2.15.0=pyhd8ed1ab_0
293 | - zeromq=4.3.4=hbdafb3b_1
294 | - zfp=1.0.0=hb6e4faa_3
295 | - zict=3.0.0=pyhd8ed1ab_0
296 | - zipp=3.15.0=pyhd8ed1ab_0
297 | - zlib=1.2.13=h53f4e23_5
298 | - zlib-ng=2.0.7=h1a8c8d9_0
299 | - zstd=1.5.2=hf913c23_6
300 | prefix: /opt/homebrew/Caskroom/mambaforge/base/envs/nsf-paper
301 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/willtownes/nsf-paper/0cacf8352e09d223ab8d4421025195358bbde8df/models/__init__.py
--------------------------------------------------------------------------------
/models/likelihoods.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Tue Jun 29 10:52:16 2021
5 |
6 | @author: townesf
7 | """
8 | from numpy import tile
9 | import tensorflow_probability as tfp
10 | tfb = tfp.bijectors
11 | tfd = tfp.distributions
12 | tv = tfp.util.TransformedVariable
13 |
14 |
15 | class InvalidLikelihoodError(ValueError):
16 |     pass
17 | 
18 | def init_lik(lik, J, disp="default", dtp="float32"):
19 |     """
20 |     Given a likelihood, number of features (J),
21 |     and a dispersion parameter initialization value,
22 |     return a TransformedVariable representing a vector of trainable,
23 |     feature-specific dispersion parameters.
24 |     """
25 |     if lik=="gau":
26 |         #disp is the scale (st dev) of the normal distribution
27 |         if disp=="default": disp=1.0
28 |     elif lik=="nb":
29 |         #var = mu + disp*mu^2, disp->0 is the Poisson limit
30 |         if disp=="default": disp=0.01
31 |     elif lik=="poi":
32 |         disp = None
33 |     else:
34 |         raise InvalidLikelihoodError
35 |     if disp is not None: #ie lik in ("gau","nb")
36 |         disp = tv(tile(disp,J), tfb.Softplus(), dtype=dtp, name="dispersion")
37 |     return disp
38 | 
39 | def lik_to_distr(lik, mu, disp):
40 |     """
41 |     Given a likelihood and a tensorflow TransformedVariable 'disp',
42 |     return a tensorflow probability distribution object with mean 'mu'.
43 |     """
44 |     if lik=="poi":
45 |         return tfd.Poisson(mu)
46 |     elif lik=="gau":
47 |         return tfd.Normal(mu, disp)
48 |     elif lik=="nb":
49 |         return tfd.NegativeBinomial.experimental_from_mean_dispersion(mu, disp)
50 |     else:
51 |         raise InvalidLikelihoodError
52 | 
53 | def choose_nmf_pars(lik):
54 |     if lik in ("poi","nb"):
55 |         return {"beta_loss":"kullback-leibler", "solver":"mu", "init":"nndsvda"}
56 |     elif lik=="gau":
57 |         return {"beta_loss":"frobenius", "solver":"cd", "init":"nndsvd"}
58 |     else:
59 |         raise InvalidLikelihoodError
60 |
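# Illustrative usage sketch (not from the original file): given a matrix of
# predicted means `mu` and observed counts `Y`,
#   disp = init_lik("nb", J=mu.shape[1])   # trainable per-feature dispersions
#   distr = lik_to_distr("nb", mu, disp)   # TFP NegativeBinomial with mean mu
#   loglik = distr.log_prob(Y)             # element-wise log-likelihoods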
--------------------------------------------------------------------------------
/models/mefisto.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | MEFISTO wrapper functions
5 |
6 | Created on Mon Jul 19 18:25:31 2021
7 |
8 | @author: townesf
9 | """
10 | #import json
11 | import numpy as np
12 | from os import path
13 | from time import time,process_time
14 | #from mofax import mofa_model
15 | from mofapy2.run.entry_point import entry_point
16 |
17 | from utils import misc
18 | # from models.likelihoods import InvalidLikelihoodError
19 |
20 |
 21 | class MEFISTO(object):
 22 |     def __init__(self, Dtr, num_factors, inducing_pts=1000, quiet=False,
 23 |                  pickle_path=None): #lik="gau"
 24 |         """Dtr should be normalized and log transformed data"""
 25 |         #https://nbviewer.jupyter.org/github/bioFAM/MEFISTO_tutorials/blob/master/MEFISTO_ST.ipynb
 26 |         # if lik=="gau": lik = "gaussian"
 27 |         # elif lik=="poi": lik = "poisson"
 28 |         # else: raise InvalidLikelihoodError
 29 |         lik = "gaussian"
 30 |         ent = entry_point()
 31 |         ent.set_data_options(use_float32=True, center_groups=True)
 32 |         ent.set_data_matrix([[Dtr["Y"]]],likelihoods=lik)
 33 |         #ent.set_data_from_anndata(adtr, likelihoods="gaussian")
 34 |         ent.set_model_options(factors=num_factors,ard_weights=False)
 35 |         ent.set_train_options(quiet=quiet)#iter=ne)
 36 |         ent.set_covariates([Dtr["X"]])
 37 |         #ent.set_covariates([adtr.obsm["spatial"]])
 38 |         Mfrac = float(inducing_pts)/Dtr["X"].shape[0]
 39 |         if Mfrac<1.0:
 40 |             ent.set_smooth_options(sparseGP=True, frac_inducing=Mfrac)
 41 |         else:
 42 |             ent.set_smooth_options(sparseGP=False)
 43 |         #ent.set_stochastic_options(learning_rate=0.5)
 44 |         ent.build()
 45 |         self.ent = ent
 46 |         self.ptime = 0.0
 47 |         self.wtime = 0.0
 48 |         self.elbos = np.array([])
 49 |         self.epoch = 0
 50 |         self.set_pickle_path(pickle_path)
 51 |         self.converged=False
 52 |         self.feature_means = Dtr["Y"].mean(axis=0)
 53 |         self.L = num_factors
 54 |         self.M = inducing_pts
 55 | 
 56 |     def init_loadings(self,*args,**kwargs):
 57 |         pass #only here for compatibility with PF, CF, etc
 58 | 
 59 |     def set_pickle_path(self,pickle_path):
 60 |         if pickle_path is not None:
 61 |             misc.mkdir_p(pickle_path)
 62 |         self.pickle_path=pickle_path
 63 | 
 64 |     def generate_pickle_path(self,base=None):
 65 |         pars = {"L":self.L,
 66 |                 "lik":"gau",
 67 |                 "model":"MEFISTO",
 68 |                 "kernel":"ExponentiatedQuadratic",
 69 |                 "M":self.M
 70 |                 }
 71 |         pth = misc.params2key(pars)
 72 |         if base: pth = path.join(base,pth)
 73 |         return pth
 74 | 
 75 |     def pickle(self):
 76 |         """
 77 |         Save the fitted model object to a pickle file under pickle_path.
 78 |         """
 79 |         if self.converged:
 80 |             fname = "converged.pickle"
 81 |         else:
 82 |             fname = "epoch{}.pickle".format(self.epoch)
 83 |         misc.pickle_to_file(self, path.join(self.pickle_path,fname))
 84 | 
 85 |     @staticmethod
 86 |     def from_pickle(pth,epoch=None):
 87 |         if epoch:
 88 |             fname = "epoch{}.pickle".format(epoch)
 89 |         else:
 90 |             fname = "converged.pickle"
 91 |         return misc.unpickle_from_file(path.join(pth,fname))
 92 | 
 93 |     def train(self):
 94 |         ptic = process_time()
 95 |         wtic = time()
 96 |         self.ent.run()
 97 |         self.converged = self.ent.model.trained
 98 |         elbos = self.ent.model.getTrainingStats()["elbo"]
 99 |         self.elbos = elbos[~np.isnan(elbos)]
100 |         self.epoch = max(len(self.elbos)-1, 0)
101 |         self.ptime = process_time()-ptic
102 |         self.wtime = time()-wtic
103 |         if self.pickle_path: self.pickle()
104 | 
105 |     def get_loadings(self):
106 |         return self.ent.model.nodes["W"].getExpectations()[0]["E"]
107 | 
108 |     def get_factors(self):
109 |         return self.ent.model.nodes["Z"].getExpectations()["E"]
110 | 
111 |     def _reverse_normalization(self,Lambda,sz):
112 |         return misc.reverse_normalization(Lambda, feature_means=self.feature_means,
113 |                                           transform=np.expm1, sz=sz)
114 | 
115 |     def predict(self,Dtr,Dval=None,S=None):
116 |         """
117 |         Here Dtr,Dval should be raw counts (not normalized or log-transformed)
118 | 
119 |         returns the predicted training data mean and validation data mean
120 |         on the original count scale
121 | 
122 |         S is not used, only here for compatibility with visualization functions
123 |         """
124 |         Wt = self.get_loadings().T
125 |         Ftr = self.get_factors()
126 |         sz_tr = Dtr["Y"].sum(axis=1)
127 |         Mu_tr = self._reverse_normalization(Ftr@Wt, sz=sz_tr)
128 |         if Dval:
129 |             Xu,idx = np.unique(Dval["X"],axis=0,return_inverse=True)
130 |             if Xu.shape[0] < Dval["X"].shape[0]:
--------------------------------------------------------------------------------
/scrna/01_viz_spatial_importance.Rmd:
--------------------------------------------------------------------------------
 42 |   group_by(data) %>% summarize(med=median(spatial_wt),avg=mean(spatial_wt)) %>% arrange(med)
43 | sg$data<-factor(sg$data,levels=levs)
44 | ggplot(sg,aes(x=data,y=spatial_wt))+geom_boxplot()+xlab("dataset")+ylab("spatial score per gene")
45 | ggsave(fp(plt_pth,"gene_scores.pdf"),width=6,height=4)
46 | for(d in levs){
47 | pd<-subset(sg,data==d)
48 | ggplot(pd,aes(x=spatial_wt))+geom_histogram(bins=50,fill="darkblue",color="darkblue")+xlab("spatial importance score")+ylab("number of genes")
49 | ggsave(fp("scrna",d,"results/plots",paste0(d,"_gene_spatial_score.pdf")),width=4,height=2.7)
50 | }
51 | ```
52 | Compare to Hotspot
53 | ```{r}
54 | hs<-do.call(rbind,lapply(names(LL),load_hotspot))
55 | hs$data<-factor(hs$data,levels=levs)
56 | colnames(hs)[which(colnames(hs)=="Gene")]<-"gene"
57 | hs$hs_spatial<-hs$FDR<0.05
58 | sg$nsfh_spatial<-sg$spatial_wt>0.5
59 | # with(sg,tapply(is_spatial,data,mean))
60 | res<-merge(sg,hs,by=c("data","gene"))
61 | with(res,table(nsfh_spatial,hs_spatial,data))
62 | res %>% group_by(data) %>% summarise(corr = cor(spatial_wt,Z,method="spearman"))
63 | ```
64 |
65 | Spatial scores per cell
66 |
67 | ```{r}
68 | sc<-do.call(rbind,lapply(names(LL),load_csv,"spatial_cell_weights"))
69 | sc$data<-factor(sc$data,levels=levs)
70 | ggplot(sc,aes(x=data,y=spatial_wt))+geom_boxplot()+xlab("dataset")+ylab("spatial score per observation")
71 | ggsave(fp(plt_pth,"obs_scores.pdf"),width=6,height=4)
72 | for(d in levs){
73 | pd<-subset(sc,data==d)
74 | ggplot(pd,aes(x=spatial_wt))+geom_histogram(bins=50,fill="darkblue",color="darkblue")+xlab("spatial importance score")+ylab("number of observations")
75 | ggsave(fp("scrna",d,"results/plots",paste0(d,"_obs_spatial_score.pdf")),width=4,height=2.7)
76 | }
77 | ```
78 |
79 | Importance of spatial factors (SPDE style)
80 |
81 | ```{r}
82 | sgi<-do.call(rbind,lapply(names(LL),load_csv2,"dim_weights_spde"))
83 | sgi$data<-factor(sgi$data,levels=levs)
84 | ggplot(sgi,aes(x=id,y=weight,fill=factor_type))+geom_bar(stat="identity")+facet_wrap(~data,nrow=2,scales="free")+theme(legend.position="top")+xlab("factor")+ylab("importance")
85 | ggsave(fp(plt_pth,"importance_spde.pdf"),width=5,height=4)
86 | ```
87 |
88 | Importance of spatial factors (LDA style)
89 |
90 | ```{r}
91 | sci<-do.call(rbind,lapply(names(LL),load_csv2,"dim_weights_lda"))
92 | sci$data<-factor(sci$data,levels=levs)
93 | ggplot(sci,aes(x=id,y=weight,fill=factor_type))+geom_bar(stat="identity")+facet_wrap(~data,nrow=2,scales="free")+theme(legend.position="none")+xlab("factor")+ylab("importance")
94 | ggsave(fp(plt_pth,"importance_lda.pdf"),width=5,height=3)
95 | ```
96 |
97 | spatial autocorrelation of components
98 |
99 | ```{r}
100 | #sshippo
101 | dac<-read.csv(fp(pth,"sshippo/results/NSFH_dim_autocorr_spde_L20_T10.csv"))
102 | dac$type<-rep(c("spatial","nonspatial"),each=10)
103 | #dac$id<-as.character(rep(1:10,2))
104 | dac$component<-factor(dac$component,levels=dac$component)
105 | ggplot(dac,aes(x=component,y=moran_i,fill=type))+geom_bar(stat="identity")+ylab("Moran's I")+theme(legend.position="none")
106 | ggsave(fp(pth,"results/plots/sshippo_L20_moranI.pdf"),width=5,height=3)
107 |
108 | #xyzeq
109 | dac<-read.csv(fp(pth,"xyzeq_liver/results/NSFH_dim_autocorr_spde_L6_T3.csv"))
110 | dac$type<-rep(c("spatial","nonspatial"),each=3)
111 | #dac$id<-as.character(rep(1:10,2))
112 | dac$component<-factor(dac$component,levels=dac$component)
113 | ggplot(dac,aes(x=component,y=moran_i,fill=type))+geom_bar(stat="identity")+ylab("Moran's I")+theme(legend.position="none")
114 | ggsave(fp(pth,"results/plots/xyz_liv_L6_moranI.pdf"),width=3,height=1.75)
115 |
116 | #visium
117 | dac<-read.csv(fp(pth,"visium_brain_sagittal/results/NSFH_dim_autocorr_spde_L20_T10.csv"))
118 | dac$type<-rep(c("spatial","nonspatial"),each=10)
119 | #dac$id<-as.character(rep(1:10,2))
120 | dac$component<-factor(dac$component,levels=dac$component)
121 | ggplot(dac,aes(x=component,y=moran_i,fill=type))+geom_bar(stat="identity")+ylab("Moran's I")+theme(legend.position="none")
122 | ggsave(fp(pth,"results/plots/vz_brn_L20_moranI.pdf"),width=5,height=3)
123 |
124 | dac<-read.csv(fp(pth,"visium_brain_sagittal/results/NSFH_dim_autocorr_spde_L60_T30.csv"))
125 | dac$type<-rep(c("spatial","nonspatial"),each=30)
126 | #dac$id<-as.character(rep(1:10,2))
127 | dac$component<-factor(dac$component,levels=dac$component)
128 | ggplot(dac,aes(x=component,y=moran_i,fill=type))+geom_bar(stat="identity")+ylab("Moran's I")+theme(legend.position="none", axis.text.x=element_text(angle=45,hjust=1))
129 | ggsave(fp(pth,"results/plots/vz_brn_L60_moranI.pdf"),width=5.5,height=3)
130 | ```
--------------------------------------------------------------------------------
/scrna/sshippo/01_data_loading.Rmd:
--------------------------------------------------------------------------------
1 | ---
2 | title: "Slide-seq v2 hippocampus"
3 | author: "Will Townes"
4 | output: html_document
5 | ---
6 |
7 | ```{r}
8 | library(Seurat)
9 | library(SeuratData) #remotes::install_github("satijalab/seurat-data")
10 | library(SeuratDisk) #remotes::install_github("mojaveazure/seurat-disk")
11 | library(scry) #bioconductor
12 |
13 | fp<-file.path
14 | pth<-"scrna/sshippo"
15 | dpth<-fp(pth,"data")
16 | if(!dir.exists(fp(dpth,"original"))){
17 | dir.create(fp(dpth,"original"),recursive=TRUE)
18 | }
19 | ```
20 |
21 | Download the Seurat dataset
22 |
23 | ```{r}
24 | #InstallData("ssHippo") #run once to store to disk
25 | slide.seq <- LoadData("ssHippo")
26 | ```
27 |
28 | rank genes by deviance
29 |
30 | ```{r}
31 | X<-slide.seq@images[[1]]@coordinates
32 | slide.seq<-AddMetaData(slide.seq,X)
33 | Y<-slide.seq@assays[[1]]@counts
34 | #Y<-Y[rowSums(Y)>0,]
35 | dev<-devianceFeatureSelection(Y,fam="poisson")
36 | dev[is.na(dev)]<-0
37 | slide.seq@assays[[1]]<-AddMetaData(slide.seq@assays[[1]], dev, col.name="deviance_poisson")
38 |
39 | o<-order(dev,decreasing=TRUE)
40 | #Y<-Y[o,]
41 | #dev<-dev[o]
42 | plot(dev[o],type="l",log="y")
43 | abline(v=1000,col="red")
44 | ```
45 |
46 | Converting to H5AD, based on https://mojaveazure.github.io/seurat-disk/articles/convert-anndata.html
47 |
48 | ```{r}
49 | dfile<-fp(dpth,"original/sshippo.h5Seurat")
50 | SaveH5Seurat(slide.seq, filename=dfile)
51 | Convert(dfile, dest="h5ad")
52 | ```
53 |
54 |
55 |
56 |
--------------------------------------------------------------------------------
/scrna/sshippo/02_data_loading.ipy:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # ---
3 | # jupyter:
4 | # jupytext:
5 | # text_representation:
6 | # extension: .ipy
7 | # format_name: percent
8 | # format_version: '1.3'
9 | # jupytext_version: 1.6.0
10 | # kernelspec:
11 | # display_name: Python 3
12 | # language: python
13 | # name: python3
14 | # ---
15 |
16 | # %% [markdown]
17 | """
18 | First run 01_data_loading.Rmd to get data via Seurat, compute
19 | Poisson deviance for each gene, and export to H5AD.
20 | """
21 |
22 | # %% imports
23 | import random
24 | import numpy as np
25 | import scanpy as sc
26 | from os import path
27 |
28 | random.seed(101)
29 | pth = "scrna/sshippo"
30 |
31 | # %% load the pre-processed dataset
32 | #ad = sq.datasets.slideseqv2()
33 | ad = sc.read_h5ad(path.join(pth,"data/original/sshippo.h5ad"))
34 |
35 | # %% Desiderata for dataset [markdown]
36 | # 1. Spatial coordinates
37 | # 2. Features sorted in decreasing order of deviance
38 | # 3. Observations randomly shuffled
39 |
40 | #%% organize anndata
41 | ad.obsm['spatial'] = ad.obs[["x","y"]].to_numpy()
42 | ad.obs.drop(columns=["x","y"],inplace=True)
43 | ad.X = ad.raw.X
44 | ad.raw = None
45 |
46 | # %% QC, loosely following MEFISTO tutorials
47 | # https://nbviewer.jupyter.org/github/bioFAM/MEFISTO_tutorials/blob/master/MEFISTO_ST.ipynb#QC-and-preprocessing
48 | #ad.var_names_make_unique()
49 | ad.var["mt"] = ad.var_names.str.startswith("MT-")
50 | sc.pp.calculate_qc_metrics(ad, qc_vars=["mt"], inplace=True)
51 | ad.obs.pct_counts_mt.hist(bins=100)
52 | ad = ad[ad.obs.pct_counts_mt < 20] #from 53K to 45K
53 | tc = ad.obs.total_counts
54 | tc.hist(bins=100)
55 | tc[tc<500].hist(bins=100)
56 | (tc<100).sum() #8000 spots
57 | sc.pp.filter_cells(ad, min_counts=100)
58 | sc.pp.filter_genes(ad, min_cells=1)
59 | ad.layers = {"counts":ad.X.copy()} #store raw counts before normalization changes ad.X
60 | sc.pp.normalize_total(ad, inplace=True, layers=None, key_added="sizefactor")
61 | sc.pp.log1p(ad)
62 |
63 | # %% sort by deviance
64 | o = np.argsort(-ad.var['deviance_poisson'])
65 | idx = list(range(ad.shape[0]))
66 | random.shuffle(idx)
67 | ad = ad[idx,o]
68 | ad.var["deviance_poisson"].plot()
69 | ad.write_h5ad(path.join(pth,"data/sshippo.h5ad"),compression="gzip")
70 | #ad = sc.read_h5ad(path.join(pth,"data/sshippo.h5ad"))
71 | ad2 = ad[:,:2000]
72 | ad2.write_h5ad(path.join(pth,"data/sshippo_J2000.h5ad"),compression="gzip")
73 |
--------------------------------------------------------------------------------
/scrna/sshippo/03_benchmark.ipy:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # ---
3 | # jupyter:
4 | # jupytext:
5 | # text_representation:
6 | # extension: .ipy
7 | # format_name: percent
8 | # format_version: '1.3'
9 | # jupytext_version: 1.6.0
10 | # kernelspec:
11 | # display_name: Python 3
12 | # language: python
13 | # name: python3
14 | # ---
15 |
16 | #%%
17 | import pandas as pd
18 | from os import path
19 | from utils import misc,benchmark
20 |
21 | pth = "scrna/sshippo"
22 | dpth = path.join(pth,"data")
23 | mpth = path.join(pth,"models/V5")
24 | rpth = path.join(pth,"results")
25 | misc.mkdir_p(rpth)
26 |
27 | #%% Create CSV with benchmarking parameters
28 | csv_path = path.join(rpth,"benchmark.csv")
29 | try:
30 |     par = pd.read_csv(csv_path)
31 | except FileNotFoundError:
32 |     L = [6,12,20]
33 |     sp_mods = ["NSF-P", "NSF-N", "NSF-G", "RSF-G", "NSFH-P", "NSFH-N"]
34 |     ns_mods = ["PNMF-P", "PNMF-N", "FA-G"]
35 |     sz = ["constant","scanpy"]
36 |     M = [1000,2000,3000]
37 |     par = benchmark.make_param_df(L,sp_mods,ns_mods,M,sz)
38 |     par.to_csv(csv_path,index=False)
39 |
40 | #%% Benchmarking at command line [markdown]
41 | """
42 | To run on local computer, use
43 | `python -m utils.benchmark 2 scrna/sshippo/data/sshippo_J2000.h5ad`
44 | where 2 is a row ID of benchmark.csv, min value 2, max possible value is 115
45 |
46 | To run on cluster first load anaconda environment
47 | ```
48 | tmux
49 | interactive
50 | module load anaconda3/2021.5
51 | conda activate fwt
52 | python -m utils.benchmark 2 scrna/sshippo/data/sshippo_J2000.h5ad
53 | ```
54 |
55 | To run on cluster as a job array, subset of rows. Recommend 24hr time limit.
56 | ```
57 | DAT=./scrna/sshippo/data/sshippo_J2000.h5ad
58 | sbatch --mem=180G --array=61,64 ./utils/benchmark_array.slurm $DAT
59 | ```
60 |
61 | To run on cluster as a job array, all rows of CSV file
62 | ```
63 | CSV=./scrna/sshippo/results/benchmark.csv
64 | DAT=./scrna/sshippo/data/sshippo_J2000.h5ad
65 | sbatch --mem=180G --array=2-$(wc -l < $CSV) ./utils/benchmark_array.slurm $DAT
66 | ```
67 | """
68 |
69 | #%% Compute metrics for each model
70 | """
71 | DAT=./scrna/sshippo/data/sshippo_J2000.h5ad
72 | sbatch --mem=180G ./utils/benchmark_gof.slurm $DAT 5
73 | """
74 |
75 | #%% Examine one result
76 | from matplotlib import pyplot as plt
77 | from utils import training
78 | tro = training.ModelTrainer.from_pickle(path.join(mpth,"L20/poi/NSFH_T10_MaternThreeHalves_M1000"))
79 | plt.plot(tro.loss["train"][-200:-1])
80 |
--------------------------------------------------------------------------------
/scrna/sshippo/05_benchmark_viz.Rmd:
--------------------------------------------------------------------------------
1 | ---
2 | title: "Benchmarking Visualization"
3 | author: "Will Townes"
4 | output: html_document
5 | ---
6 |
7 | ```{r}
8 | library(tidyverse)
9 | theme_set(theme_bw())
10 | fp<-file.path
11 | pth<-"scrna/sshippo"
12 | plt_pth<-fp(pth,"results/plots")
13 | if(!dir.exists(plt_pth)){
14 | dir.create(plt_pth,recursive=TRUE)
15 | }
16 | ```
17 |
18 | Convergence failures
19 | * MEFISTO with L>6 or M>2000 (ran out of memory)
20 |
21 | ```{r}
22 | d0<-read.csv(fp(pth,"results/benchmark.csv"))
23 | d0$converged<-as.logical(d0$converged)
24 | #d0$model<-plyr::mapvalues(d0$model,c("FA","RSF","PNMF","NSFH","NSF"),c("FA","RSF","PNMF","NSFH","NSF"))
25 | d<-subset(d0,converged==TRUE)
26 | da<-subset(d,model %in% c("FA","MEFISTO","RSF"))
27 | db<-subset(d,(model %in% c("PNMF","NSFH","NSF")) & (sz=="scanpy"))
28 | d<-rbind(da,db)
29 | d$model<-paste0(d$model," (",d$lik,")")
30 | unique(d$model)
31 | keep<-c("FA (gau)","MEFISTO (gau)","RSF (gau)","PNMF (nb)","PNMF (poi)","NSFH (nb)","NSFH (poi)","NSF (nb)","NSF (poi)")
32 | d<-subset(d,model %in% keep)
33 | d$model<-factor(d$model,levels=keep)
34 | d$dim<-factor(d$L,levels=sort(unique(d$L)))
35 | d$M[is.na(d$M)]<-3000
36 | d$IPs<-factor(d$M,levels=sort(unique(d$M)))
37 | ```
38 | subset of models for simplified main figure
39 | ```{r}
40 | d2<-subset(d,M==2000 | (model %in% c("PNMF (poi)","FA (gau)")))
41 | d2a<-subset(d2,(model %in% c("PNMF (poi)","NSFH (poi)","NSF (poi)")) & sz=="scanpy")
42 | d2b<-subset(d2,model %in% c("FA (gau)","RSF (gau)","MEFISTO (gau)"))
43 | d2<-rbind(d2a,d2b)
44 | d2$model<-factor(d2$model,levels=)
45 | d2$model<-plyr::mapvalues(d2$model,c("FA (gau)","MEFISTO (gau)","RSF (gau)","PNMF (poi)","NSFH (poi)","NSF (poi)"),c("FA","MEFISTO","RSF","PNMF","NSFH","NSF"))
46 |
47 | ggplot(d2,aes(x=model,y=dev_tr_mean,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.8))+ylab("training deviance (mean)")+ylim(range(d2$dev_tr_mean)*c(.95,1.02))
48 | ggsave(fp(plt_pth,"sshippo_gof_dev_tr_mean_main.pdf"),width=5,height=2.5)
49 |
50 | ggplot(d2,aes(x=model,y=dev_val_mean,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.8))+ylab("validation deviance (mean)")+ylim(range(d2$dev_val_mean)*c(.95,1.02))
51 | ggsave(fp(plt_pth,"sshippo_gof_dev_val_mean_main.pdf"),width=5,height=2.5)
52 |
53 | ggplot(d2,aes(x=model,y=rmse_val,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.8))+ylab("validation RMSE")+ylim(range(d2$rmse_val)*c(.95,1.02))
54 | ggsave(fp(plt_pth,"sshippo_gof_rmse_val_simple.pdf"),width=5,height=2.5)
55 |
56 | #linear regressions for statistical significance
57 | d2$realvalued<-d2$model %in% c("FA","MEFISTO","RSF")
58 | #d2$dim_numeric<-as.numeric(as.character(d2$dim))
59 | summary(lm(dev_val_mean~realvalued+dim,data=d2))
60 | d2$unsupervised<-d2$model %in% c("FA","PNMF")
61 | d2$unsupervised[d2$model=="MEFISTO"]<-NA
62 | summary(lm(dev_val_mean~realvalued+dim+unsupervised,data=d2))
63 | t.test(dev_val_mean~model,data=subset(d2,model %in% c("NSFH","NSF")), var.equal=TRUE,alternative="greater")
64 |
65 | #time
66 | ggplot(d2,aes(x=model,y=wtime/60,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.8))+ylab("time to converge (min)")+ylim(range(d2$wtime/60)*c(.95,1.02))
67 | ggsave(fp(plt_pth,"sshippo_wtime_simple.pdf"),width=5,height=2.5)
68 |
69 | #sparsity
70 | ggplot(d2,aes(x=model,y=sparsity,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.8))+ylab("loadings zero fraction")+ylim(range(d2$sparsity)*c(.95,1.02))
71 | ggsave(fp(plt_pth,"sshippo_sparsity_simple.pdf"),width=5,height=2.5)
72 |
73 | #time
74 | ggplot(d2,aes(x=model,y=wtime/60,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.8))+ylab("time to converge (min)")+ylim(range(d2$wtime/60)*c(.95,1.02))
75 | ggsave(fp(plt_pth,"sshippo_wtime_simple.pdf"),width=5,height=2.5)
76 | ```
77 | NB vs Poi likelihood
78 | ```{r}
79 | d3<-subset(d,lik %in% c("poi","nb"))
80 |
81 | ggplot(d3,aes(x=model,y=dev_tr_mean,color=dim,shape=IPs))+geom_point(size=4,position=position_dodge(width=0.6))+ylab("training deviance (mean)")
82 | ggsave(fp(plt_pth,"sshippo_gof_dev_tr_nb_vs_poi.pdf"),width=6,height=3)
83 |
84 | ggplot(d3,aes(x=model,y=dev_val_mean,color=IPs,shape=dim))+geom_point(size=4,position=position_dodge(width=0.6))+ylab("validation deviance (mean)")
85 | ggsave(fp(plt_pth,"sshippo_gof_dev_val_nb_vs_poi.pdf"),width=6,height=3)
86 |
87 | ggplot(d3,aes(x=model,y=rmse_val,color=IPs,shape=dim))+geom_point(size=4,position=position_dodge(width=0.6))+ylab("validation RMSE")
88 | ggsave(fp(plt_pth,"sshippo_gof_rmse_val_nb_vs_poi.pdf"),width=6,height=3)
89 |
90 | #time
91 | ggplot(d3,aes(x=model,y=wtime/60,color=IPs,shape=dim))+geom_point(size=4,position=position_dodge(width=0.8))+ylab("time to converge (min)")+scale_y_log10()
92 | ggsave(fp(plt_pth,"sshippo_wtime_nb_vs_poi.pdf"),width=6,height=3)
93 | ```
94 |
95 | NSF vs NSFH
96 |
97 | ```{r}
98 | d3<-subset(d,grepl("NSF",model))
99 | ggplot(d3,aes(x=model,y=dev_tr_mean,color=dim,shape=IPs))+geom_point(size=4,position=position_dodge(width=0.8))+ylab("training deviance (mean)")
100 | ggsave(fp(plt_pth,"sshippo_gof_dev_tr_nsf_vs_nsfh.pdf"),width=6,height=3)
101 | ```
102 |
103 | ```{r}
104 | #training deviance - average
105 | ggplot(d,aes(x=model,y=dev_tr_mean,color=dim,shape=IPs))+geom_jitter(size=5,width=.2,height=0)+ylab("training deviance (mean)")+theme(legend.position = "top")
106 | ggsave(fp(plt_pth,"sshippo_gof_dev_tr_mean.pdf"),width=6,height=3)
107 |
108 | #training deviance - max
109 | ggplot(d,aes(x=model,y=dev_tr_max,color=dim,shape=sz))+geom_jitter(size=5,width=.2,height=0)+ylab("training deviance (max)")#+scale_y_log10()
110 |
111 | #validation deviance - mean
112 | ggplot(d,aes(x=model,y=dev_val_mean,color=dim,shape=sz))+geom_jitter(size=4,width=.4,height=0)+ylab("validation deviance (mean)")+theme(legend.position="none")
113 | ggsave(fp(plt_pth,"sshippo_gof_dev_val_mean.pdf"),width=6,height=2.5)
114 |
115 | #validation deviance - max
116 | ggplot(d,aes(x=model,y=dev_val_max,color=dim,shape=sz))+geom_jitter(size=3,width=.2,height=0)+ylab("validation deviance (max)")#+scale_y_log10()
117 |
118 | #sparsity
119 | ggplot(d,aes(x=model,y=sparsity,color=dim,shape=sz))+geom_jitter(size=4,width=.3,height=0)+ylab("sparsity of loadings")+theme(legend.position="none")
120 | ggsave(fp(plt_pth,"sshippo_sparsity.pdf"),width=6,height=2.5)
121 |
122 | #wall time
123 | ggplot(d,aes(x=model,y=wtime/3600,color=IPs,shape=dim))+geom_jitter(size=5,width=.2,height=0)+ylab("wall time (hr)")+scale_y_log10()+theme(legend.position="top")
124 | ggsave(fp(plt_pth,"sshippo_wtime.pdf"),width=6,height=3)
125 | ggplot(d,aes(x=dim,y=wtime/60,color=IPs,shape=sz))+geom_jitter(size=3,width=.2,height=0)+ylab("wall time (min)")+scale_y_log10()
126 |
127 | #processor time
128 | ggplot(d,aes(x=model,y=ptime/3600,color=IPs,shape=dim))+geom_jitter(size=5,width=.2,height=0)+ylab("processor time")+scale_y_log10()+theme(legend.position="none")
129 | ggsave(fp(plt_pth,"sshippo_ptime.pdf"),width=6,height=2.5)
130 | ```
131 |
132 | Effect of using size factors
133 |
134 | ```{r}
135 | d2<-subset(d,model %in% c("NSF (nb)","NSF (poi)","NSFH (nb)","NSFH (poi)"))
136 | d2$model<-factor(d2$model)
137 | d2$dim<-factor(d2$L,levels=sort(unique(d2$L)))
138 | d2$M[is.na(d2$M)]<-3000
139 | d2$IPs<-factor(d2$M,levels=sort(unique(d2$M)))
140 | ```
141 |
142 | ```{r}
143 | ggplot(d2,aes(x=sz,y=dev_tr_mean,color=dim,shape=IPs))+geom_jitter(size=5,width=.2,height=0)+ylab("training deviance (mean)")+theme(legend.position = "top")+facet_wrap(~model,ncol=2,scales="free")
144 |
145 | ggplot(d2,aes(x=sz,y=dev_val_mean,color=dim,shape=IPs))+geom_jitter(size=5,width=.2,height=0)+ylab("validation deviance (mean)")+theme(legend.position = "top")+facet_wrap(~model,ncol=2,scales="free")
146 | ```
147 |
148 | Effect of inducing points
149 |
150 | ```{r}
151 | d2<-subset(d,model %in% c("MEFISTO (gau)","RSF (gau)","NSFH (nb)","NSFH (poi)","NSF (nb)","NSF (poi)"))
152 | ggplot(d2,aes(x=IPs,y=dev_tr_mean,colour=dim,group=dim))+geom_line(size=2,lineend="round")+ylab("training deviance (mean)")+theme(legend.position = "top")+facet_wrap(~model,scales="free")
153 | ggsave(fp(plt_pth,"sshippo_gof_dev_tr_ips.pdf"),width=6,height=3)
154 |
155 | ggplot(d2,aes(x=IPs,y=rmse_tr,colour=dim,group=dim))+geom_line(size=2,lineend="round")+ylab("training RMSE")+theme(legend.position = "top")+facet_wrap(~model,scales="free")
156 | ggsave(fp(plt_pth,"sshippo_gof_rmse_tr_ips.pdf"),width=6,height=3)
157 |
158 | ggplot(d2,aes(x=IPs,y=dev_val_mean,colour=dim,group=dim))+geom_line(size=2,lineend="round")+ylab("validation deviance (mean)")+theme(legend.position = "top")+facet_wrap(~model,scales="free")
159 | ggsave(fp(plt_pth,"sshippo_gof_dev_val_ips.pdf"),width=6,height=3)
160 |
161 | ggplot(d2,aes(x=IPs,y=rmse_val,colour=dim,group=dim))+geom_line(size=2,lineend="round")+ylab("validation RMSE")+theme(legend.position = "top")+facet_wrap(~model,scales="free")
162 | ggsave(fp(plt_pth,"sshippo_gof_rmse_val_ips.pdf"),width=6,height=3)
163 |
164 | ggplot(d2,aes(x=IPs,y=wtime/3600,colour=dim,group=dim))+geom_line(size=2,lineend="round")+ylab("wall time (hr)")+theme(legend.position = "top")+facet_wrap(~model,scales="free")
165 | ggsave(fp(plt_pth,"sshippo_wtime_ips.pdf"),width=6,height=3)
166 | ```
167 |
--------------------------------------------------------------------------------
/scrna/sshippo/06_interpret_genes.Rmd:
--------------------------------------------------------------------------------
1 | ---
2 | title: "06_interpretation_genes"
3 | author: "Will Townes"
4 | output: html_document
5 | ---
6 |
7 | ```{r}
8 | library(biomaRt)
9 | source("./scrna/utils/interpret_genes.R")
10 |
11 | fp<-file.path
12 | pth<-"scrna/sshippo"
13 | ```
14 | loadings from NSFH
15 | ```{r}
16 | W<-read.csv(fp(pth,"results/sshippo_nsfh20_spde_loadings.csv"),header=TRUE,row.names=1)
17 | rownames(W)<-toupper(rownames(W))
18 | # #format rownames in title case
19 | # g<-str_to_title(rownames(W))
20 | # rownames(W)<-sub("Mt-","mt-",g,fixed=TRUE)
21 | W2<-t(apply(as.matrix(W),1,function(x){as.numeric(x==max(x))}))
22 | colSums(W2) #clustering of genes
23 | ```
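A toy illustration of the hard assignment above (invented loadings, not the fitted NSFH factors): each gene is assigned to the factor with its largest loading, and the column sums count how many genes fall into each factor.
```{r}
W_toy<-rbind(geneA=c(0.1,0.9,0.0),geneB=c(0.7,0.2,0.1))
W2_toy<-t(apply(W_toy,1,function(x){as.numeric(x==max(x))}))
colSums(W2_toy) #one gene assigned to factor 1, one to factor 2, none to factor 3
```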
24 | Gene Ontology terms, used for both NSFH and hotspot.
25 | ```{r}
26 | # bg<-rownames(W)
27 | db<-useMart("ensembl",host="https://may2021.archive.ensembl.org",dataset='mmusculus_gene_ensembl')
28 | go_ids<-getBM(attributes=c('go_id', 'external_gene_name','namespace_1003'), filters='external_gene_name', values=rownames(W), mart=db)
29 | go_ids[,2]<-toupper(go_ids[,2])
30 | gene_2_GO<-unstack(go_ids[,c(1,2)])
31 | ```
32 | GO analysis for NSFH
33 | ```{r}
34 | # g<-strsplit(res[l,"genes"],", ",fixed=TRUE)[[1]]
35 | # expr <- getMarkers(include = g[c(1:2,4,6,7)])$
36 | res<-loadings2go(W,gene_2_GO)#,numtopgenes=50,min_genescore=1)
37 | res$type<-rep(c("spat","nsp"),each=10)
38 | res$dim<-rep(1:10,2)
39 | write.csv(res[,c(1,4,2,3)],fp(pth,"results/sshippo_nsfh20_spde_goterms.csv"),row.names=FALSE)
40 | ```
41 | Get cell types from [Panglao](https://panglaodb.se). First manually download the list of all marker genes for all cell types from https://panglaodb.se/markers.html?cell_type=%27all_cells%27 . Store this TSV in the scrna/resources folder.
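As an alternative to the manual download, a minimal sketch of fetching the same TSV from R is below (not run when knitting); the download URL is an assumption based on the PanglaoDB markers page and may need to be updated.
```{r, eval=FALSE}
#fetch the PanglaoDB marker table into scrna/resources (URL assumed, may change)
dir.create("scrna/resources",showWarnings=FALSE,recursive=TRUE)
download.file("https://panglaodb.se/markers/PanglaoDB_markers_27_Mar_2020.tsv.gz",
              destfile="scrna/resources/PanglaoDB_markers_27_Mar_2020.tsv.gz",mode="wb")
```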
42 | ```{r}
43 | res<-read.csv(fp(pth,"results/sshippo_nsfh20_spde_goterms.csv"),header=TRUE)
44 | panglao_tsv="scrna/resources/PanglaoDB_markers_27_Mar_2020.tsv.gz"
45 | mk<-read.table(gzfile(panglao_tsv),header=TRUE,sep="\t")
46 | mk<-subset(mk,species %in% c("Mm","Mm Hs"))
47 | mk<-subset(mk,organ %in% c("Brain","Connective tissue","Epithelium"))
48 | #"Immune system","Vasculature","Blood",
49 | pg<-tapply(mk$official.gene.symbol,mk$cell.type,function(x){x},simplify=FALSE)
50 | ct<-lapply(res$genes,genes2celltype,pg)#,ss=3)
51 | ct[is.na(ct)]<-""
52 | res$celltypes<-paste(ct,sep=", ")
53 | write.csv(res,fp(pth,"results/sshippo_nsfh20_spde_goterms.csv"),row.names=FALSE)
54 | ```
55 | Shorter GO table
56 | ```{r}
57 | res<-read.csv(fp(pth,"results/sshippo_nsfh20_spde_goterms.csv"),header=TRUE)
58 | res2<-res
59 | g<-strsplit(res2$genes,", ",fixed=TRUE)
60 | res2$genes<-sapply(g,function(x){paste(x[1:5],collapse=", ")})
61 | gt<-strsplit(res2$go_bp,"; ",fixed=TRUE)
62 | res2$go_bp<-sapply(gt,function(x){paste(x[1:2],collapse=", ")})
63 | write.csv(res2,fp(pth,"results/sshippo_nsfh20_spde_goterms_short.csv"),row.names=FALSE)
64 | ```
65 |
66 | GO analysis for hotspot
67 | ```{r}
68 | hs<-read.csv(fp(pth,"results/hotspot.csv"),header=TRUE)
69 | rownames(hs)<-hs$Gene
70 | Wh<-model.matrix(~0+factor(Module),data=subset(hs,Module!=-1))
71 | colnames(Wh)<-paste0("X",1:ncol(Wh))
72 | rownames(Wh)<-toupper(rownames(Wh))
73 | # all(rownames(Wh) %in% rownames(W))
74 | hres<-loadings2go(Wh,gene_2_GO, rowmean_divide=TRUE)
75 | colnames(hres)[which(colnames(hres)=="dim")]<-"cluster"
76 | write.csv(hres,fp(pth,"results/sshippo_hotspot_goterms.csv"),row.names=FALSE)
77 | hres2<-hres
78 | g<-strsplit(hres2$genes,", ",fixed=TRUE)
79 | hres2$genes<-sapply(g,function(x){paste(x[1:5],collapse=", ")})
80 | gt<-strsplit(hres2$go_bp,"; ",fixed=TRUE)
81 | hres2$go_bp<-sapply(gt,function(x){paste(x[1:2],collapse=", ")})
82 | write.csv(hres2,fp(pth,"results/sshippo_hotspot_goterms_short.csv"),row.names=FALSE)
83 | ```
84 |
--------------------------------------------------------------------------------
/scrna/sshippo/07_traditional.ipy:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # ---
3 | # jupyter:
4 | # jupytext:
5 | # text_representation:
6 | # extension: .ipy
7 | # format_name: percent
8 | # format_version: '1.3'
9 | # jupytext_version: 1.6.0
10 | # kernelspec:
11 | # display_name: Python 3
12 | # language: python
13 | # name: python3
14 | # ---
15 |
16 | #%% imports
17 | # import numpy as np
18 | import pandas as pd
19 | import scanpy as sc
20 | import matplotlib.pyplot as plt
21 | from os import path
22 | from hotspot import Hotspot
23 |
24 | from utils import misc,visualize
25 |
26 | dtp = "float32"
27 | pth = "scrna/sshippo"
28 | dpth = path.join(pth,"data")
29 | mpth = path.join(pth,"models")
30 | rpth = path.join(pth,"results")
31 | plt_pth = path.join(rpth,"plots")
32 |
33 | # %% Data Loading from scanpy
34 | J = 2000
35 | dfile = path.join(dpth,"sshippo.h5ad")
36 | adata = sc.read_h5ad(dfile)
37 | sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=J)
38 | X0 = adata.obsm["spatial"]
39 | X0[:,1] = -X0[:,1]
40 |
41 | #%% Traditional scanpy analysis (unsupervised clustering)
42 | #https://scanpy-tutorials.readthedocs.io/en/latest/spatial/basic-analysis.html
43 | sc.pp.pca(adata)
44 | sc.pp.neighbors(adata)
45 | sc.tl.umap(adata)
46 | sc.tl.leiden(adata, resolution=1.0, key_added="clusters")
47 | plt.rcParams["figure.figsize"] = (4, 4)
48 | sc.pl.umap(adata, color="clusters", wspace=0.4)
49 | sc.pl.embedding(adata, "spatial", color="clusters")
50 | cl = pd.get_dummies(adata.obs["clusters"]).to_numpy()
51 | tgnames = [str(i) for i in range(1,cl.shape[1]+1)]
52 | fig,axes=visualize.multiheatmap(X0, cl, (3,4), figsize=(6,4), cmap="turbo",
53 | bgcol="gray", s=0.01, marker=".",
54 | subplot_space=0, spinecolor="white")
55 | visualize.set_titles(fig, tgnames, x=0.03, y=.88, fontsize="small", c="white",
56 | ha="left", va="top")
57 | fig.savefig(path.join(plt_pth,"sshippo_heatmap_scanpy_clusters.png"),
58 | bbox_inches='tight', dpi=300)
59 |
60 | #%% Hotspot analysis (gene clusters)
61 | #https://hotspot.readthedocs.io/en/latest/Spatial_Tutorial.html
62 | J = 2000
63 | dfile = path.join(dpth,"sshippo_J{}.h5ad".format(J))
64 | adata = sc.read_h5ad(dfile)
65 | adata.layers["counts"] = adata.layers["counts"].tocsc()
66 | hs = Hotspot(adata, layer_key="counts", model="danb",
67 | latent_obsm_key="spatial", umi_counts_obs_key="total_counts")
68 | hs.create_knn_graph(weighted_graph=False, n_neighbors=20)
69 | hs_results = hs.compute_autocorrelations()
70 | # hs_results.tail()
71 | hs_genes = hs_results.index#[hs_results.FDR < 0.05]
72 | lcz = hs.compute_local_correlations(hs_genes)
73 | modules = hs.create_modules(min_gene_threshold=20, core_only=False,
74 | fdr_threshold=0.05)
75 | # modules.value_counts()
76 | hs_results = hs_results.join(modules,how="left")
77 | hs_results.to_csv(path.join(rpth,"hotspot.csv"))
78 | misc.pickle_to_file(hs,path.join(mpth,"hotspot.pickle"))
79 |
--------------------------------------------------------------------------------
/scrna/utils/interpret_genes.R:
--------------------------------------------------------------------------------
1 | library(topGO)
2 |
3 | loadings2go<-function(W, gene_2_GO, numtopgenes=100, min_genescore=1.0,
4 | rowmean_divide=TRUE){
5 | #This blog was helpful: https://datacatz.wordpress.com/2018/01/19/gene-set-enrichment-analysis-with-topgo-part-1/
6 | #divide each gene by its rowmean to adjust for constant high expression
7 | if(rowmean_divide){
8 | W<-W/rowMeans(W)
9 | }
10 | W<-as.matrix(W) #enables preservation of rownames when subsetting cols
11 | qtl<-1-(numtopgenes/nrow(W))
12 | topGenes<-function(x){
13 | cutoff<-max(quantile(x,qtl),min_genescore)
14 | x>cutoff
15 | }
16 | #W must be a matrix with gene names in rownames
17 | L<-ncol(W)
18 | res<-data.frame(dim=1:L,genes="",go_bp="")
19 | # gg<-names(tg[[12]])
20 | # gg<-gg[gg %in% names(gene_2_GO)]
21 | # geneList=factor(as.integer(bg %in% gg))
22 | for(l in 1:L){
23 | w<-W[,l]
24 | # names(geneList)<-rownames(W)
25 | tg<-head(sort(w,decreasing=TRUE),10)
26 | res[l,"genes"]<-paste(names(tg),collapse=", ")
27 | if(l==1){
28 | GOdata<-new('topGOdata', ontology='BP', allGenes=w, geneSel=topGenes, annot=annFUN.gene2GO, gene2GO=gene_2_GO, nodeSize=5)
29 | } else {
30 | GOdata<-updateGenes(GOdata,w,topGenes)
31 | }
32 | #wk<-runTest(GOdata,algorithm='weight01', statistic='ks')
33 | wf<-runTest(GOdata, algorithm='weight01', statistic='fisher')
34 | gotab=GenTable(GOdata,weightFisher=wf,orderBy='weightFisher',numChar=1000,topNodes=5)
35 | res[l,"go_bp"]<-paste(gotab$Term,collapse="; ")
36 | }
37 | res
38 | }
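#Illustration only (assumed toy values, not repo output): how the topGenes() cutoff
#behaves. With 100 genes and 5 "top" genes the quantile threshold is the 95th
#percentile, but the min_genescore floor of 1.0 is what actually sets the cutoff here.
if(FALSE){
w_toy<-c(rep(0.1,95),seq(1.2,2.0,length.out=5))
qtl_toy<-1-(5/length(w_toy)) #0.95
cutoff_toy<-max(quantile(w_toy,qtl_toy),1.0) #quantile is ~0.16, so the floor of 1.0 wins
sum(w_toy>cutoff_toy) #5 genes would be selected for the Fisher test
}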
39 |
40 | jaccard<-function(a,b){
41 | i<-length(intersect(a,b))
42 | denom<-ifelse(i>0, length(a)+length(b)-i, 1)
43 | i/denom
44 | }
45 |
46 | # panglao_make_ref<-function(panglao_tsv="scrna/resources/PanglaoDB_markers_27_Mar_2020.tsv.gz"){
47 | # mk<-read.table(gzfile(panglao_tsv),header=TRUE,sep="\t")
48 | # tapply(mk$official.gene.symbol,mk$cell.type,function(x){x},simplify=FALSE)
49 | # }
50 |
51 | genes2celltype<-function(genelist,panglao_ref,gsplit=TRUE,ss=0){
52 | #genelist: a string with gene names separated by commas OR a list of genes
53 | #panglao_ref: a list whose names are cell types and values are lists of genes
54 | #gsplit: if TRUE, splits genelist into a list. If FALSE, assumes genelist already split
55 | #returns : cell type with highest jaccard similarity to the gene list
56 | if(gsplit){ g<-strsplit(genelist,", ",fixed=TRUE)[[1]] } else { g<-genelist } #ifelse() would keep only the first gene
57 | if(ss>0 && ss<length(g)){ g<-g[1:ss] } #optionally restrict to the first ss genes
58 | #jaccard similarity between the gene list and each cell type's marker genes
59 | j<-sapply(panglao_ref,jaccard,b=g)
60 | jmx<-max(j)
61 | if(jmx>0){
62 | res<-j[j==jmx]
63 | return(names(res))
64 | } else { #case where no cell types were found
65 | return(NA)
66 | }
67 | }
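#Illustration only (toy marker reference, not the PanglaoDB table): the cell type
#whose marker list has the highest Jaccard overlap with the gene list is returned.
if(FALSE){
pg_toy<-list(Hepatocytes=c("ALB","APOA1","TTR"),Macrophages=c("CD68","LYZ2","CSF1R"))
jaccard(pg_toy$Hepatocytes,c("ALB","APOA1","CD68")) #2/4 = 0.5
genes2celltype("ALB, APOA1, CD68",pg_toy) #"Hepatocytes"
}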
--------------------------------------------------------------------------------
/scrna/visium_brain_sagittal/01_data_loading.ipy:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Data loading for 10x Visium mouse brain sagittal section anterior 1.
5 |
6 | Created on Sat Jun 5 15:22:19 2021
7 |
8 | @author: townesf
9 | """
10 |
11 | # %% imports
12 | import random
13 | import numpy as np
14 | import scanpy as sc
15 | from os import path
16 | from scipy import sparse
17 |
18 | from utils import preprocess, training, misc
19 |
20 | random.seed(101)
21 | pth = "scrna/visium_brain_sagittal"
22 |
23 | # %% Download original data files
24 | %%sh
25 | mkdir -p scrna/visium_brain_sagittal/data/original
26 | pushd scrna/visium_brain_sagittal/data/original
27 | wget https://cf.10xgenomics.com/samples/spatial-exp/1.1.0/V1_Mouse_Brain_Sagittal_Anterior/V1_Mouse_Brain_Sagittal_Anterior_filtered_feature_bc_matrix.h5
28 | wget https://cf.10xgenomics.com/samples/spatial-exp/1.1.0/V1_Mouse_Brain_Sagittal_Anterior/V1_Mouse_Brain_Sagittal_Anterior_spatial.tar.gz
29 | tar -xzf V1_Mouse_Brain_Sagittal_Anterior_spatial.tar.gz
30 | rm V1_Mouse_Brain_Sagittal_Anterior_spatial.tar.gz
31 | popd
32 |
33 | # %% Desiderata for dataset [markdown]
34 | # 1. Spatial coordinates
35 | # 2. Features sorted in decreasing order of deviance
36 | # 3. Observations randomly shuffled
37 |
38 | # %% QC, loosely following MEFISTO tutorials
39 | # https://nbviewer.jupyter.org/github/bioFAM/MEFISTO_tutorials/blob/master/MEFISTO_ST.ipynb#QC-and-preprocessing
40 | ad = sc.read_visium(path.join(pth,"data/original"),
41 | count_file="V1_Mouse_Brain_Sagittal_Anterior_filtered_feature_bc_matrix.h5")
42 | ad.var_names_make_unique()
43 | ad.var["mt"] = ad.var_names.str.startswith("mt-")
44 | sc.pp.calculate_qc_metrics(ad, qc_vars=["mt"], inplace=True)
45 | ad = ad[ad.obs.pct_counts_mt < 20]
46 | sc.pp.filter_genes(ad, min_cells=1)
47 | sc.pp.filter_cells(ad, min_counts=100)
48 | ad.layers = {"counts":ad.X.copy()} #store raw counts before normalization changes ad.X
49 | sc.pp.normalize_total(ad, inplace=True, layers=None, key_added="sizefactor")
50 | sc.pp.log1p(ad)
51 | #Y = misc.reverse_normalization(np.expm1(ad.X),ad.obs["sizefactor"])
52 | #np.max(np.abs(Y-np.layers["counts"]))
53 |
54 | # %% normalization, feature selection and train/test split
55 | ad.var['deviance_poisson'] = preprocess.deviancePoisson(ad.layers["counts"])
56 | o = np.argsort(-ad.var['deviance_poisson'])
57 | idx = list(range(ad.shape[0]))
58 | random.shuffle(idx)
59 | ad = ad[idx,o]
60 | ad.write_h5ad(path.join(pth,"data/visium_brain_sagittal.h5ad"),compression="gzip")
61 | #ad = sc.read_h5ad(path.join(pth,"data/visium_brain_sagittal.h5ad"))
62 | ad2 = ad[:,:2000]
63 | ad2.write_h5ad(path.join(pth,"data/visium_brain_sagittal_J2000.h5ad"),compression="gzip")
64 |
--------------------------------------------------------------------------------
/scrna/visium_brain_sagittal/03_benchmark.ipy:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # ---
3 | # jupyter:
4 | # jupytext:
5 | # text_representation:
6 | # extension: .ipy
7 | # format_name: percent
8 | # format_version: '1.3'
9 | # jupytext_version: 1.6.0
10 | # kernelspec:
11 | # display_name: Python 3
12 | # language: python
13 | # name: python3
14 | # ---
15 |
16 | #%%
17 | import pandas as pd
18 | from os import path
19 | from utils import misc,benchmark
20 |
21 | pth = "scrna/visium_brain_sagittal"
22 | dpth = path.join(pth,"data")
23 | mpth = path.join(pth,"models/V5")
24 | rpth = path.join(pth,"results")
25 | misc.mkdir_p(rpth)
26 |
27 | #%% Create CSV with benchmarking parameters
28 | csv_path = path.join(rpth,"benchmark.csv")
29 | try:
30 | par = pd.read_csv(csv_path)
31 | except FileNotFoundError:
32 | L = [6,12,20]
33 | sp_mods = ["NSF-P", "NSF-N", "NSF-G", "RSF-G", "NSFH-P", "NSFH-N"]
34 | ns_mods = ["PNMF-P", "PNMF-N", "FA-G"]
35 | sz = ["constant","scanpy"]
36 | M = [500,1000,2363]
37 | V = [5]
38 | kernels=["MaternThreeHalves"]
39 | par = benchmark.make_param_df(L,sp_mods,ns_mods,M,sz,V=V,kernels=kernels)
40 | par.to_csv(csv_path,index=False)
41 |
42 | #%% merge old benchmark csv with new
43 | # old = pd.read_csv(path.join(rpth,"benchmark1.csv"))
44 | # new = par.merge(old,on="key",how="outer",copy=True)
45 | # new["converged"] = new["converged_y"]
46 | # new["converged"].fillna(False, inplace=True)
47 | # new.drop(["converged_x","converged_y"],axis=1,inplace=True)
48 | # new.to_csv(path.join(rpth,"benchmark2.csv"),index=False)
49 | ##rename benchmark2 to benchmark manually
50 | ##some additional scenarios added manually as well (NSF,NSFH with L=36)
51 |
52 | #%% Additional scenarios added in response to reviewer comments
53 | # manually merge with original benchmark.csv
54 | # all original scenarios with 80/20 train/val split (only Matern32 kernel, no NB lik)
55 | csv_path = path.join(rpth,"benchmark2.csv")
56 | try:
57 | par = pd.read_csv(csv_path)
58 | except FileNotFoundError:
59 | L = [6,12,20]
60 | sp_mods = ["NSF-P", "RSF-G", "NSFH-P"]
61 | ns_mods = ["PNMF-P", "FA-G"]
62 | sz = ["constant","scanpy"]
63 | M = [500,1000,2363]
64 | V = [20]
65 | kernels=["MaternThreeHalves"]
66 | par = benchmark.make_param_df(L,sp_mods,ns_mods,M,sz,V=V,kernels=kernels)
67 | par.to_csv(csv_path,index=False)
68 | # NSF, RSF, and NSFH with ExponentiatedQuadratic kernel
69 | # need to manually delete duplicate MEFISTO scenarios in this one
70 | csv_path = path.join(rpth,"benchmark3.csv")
71 | try:
72 | par = pd.read_csv(csv_path)
73 | except FileNotFoundError:
74 | L = [6,12,20]
75 | sp_mods = ["NSF-P", "RSF-G", "NSFH-P"]
76 | ns_mods = []
77 | sz = ["constant","scanpy"]
78 | M = [500,1000,2363]
79 | V = [5]
80 | kernels=["ExponentiatedQuadratic"]
81 | par = benchmark.make_param_df(L,sp_mods,ns_mods,M,sz,V=V,kernels=kernels)
82 | par.to_csv(csv_path,index=False)
83 |
84 | #%% Benchmarking at command line [markdown]
85 | """
86 | To run on a local computer, use
87 | `python -m utils.benchmark 197 scrna/visium_brain_sagittal/data/visium_brain_sagittal_J2000.h5ad`
88 | where 197 is a row ID of benchmark.csv; the minimum value is 2 and the maximum possible value is 241.
89 |
90 | To run on the cluster, first load the anaconda environment:
91 | ```
92 | tmux
93 | interactive
94 | module load anaconda3/2021.5
95 | conda activate fwt
96 | python -m utils.benchmark 14 scrna/visium_brain_sagittal/data/visium_brain_sagittal_J2000.h5ad
97 | ```
98 |
99 | To run on the cluster as a job array over a subset of rows (a 6 hr time limit is recommended):
100 | ```
101 | DAT=./scrna/visium_brain_sagittal/data/visium_brain_sagittal_J2000.h5ad
102 | sbatch --mem=72G --array=135-196,198-241 ./utils/benchmark_array.slurm $DAT
103 | ```
104 |
105 | To run on the cluster as a job array over all rows of the CSV file:
106 | ```
107 | CSV=./scrna/visium_brain_sagittal/results/benchmark.csv
108 | DAT=./scrna/visium_brain_sagittal/data/visium_brain_sagittal_J2000.h5ad
109 | sbatch --mem=72G --array=2-$(wc -l < $CSV) ./utils/benchmark_array.slurm $DAT
110 | ```
111 | """
112 |
113 | #%% Compute metrics for each model (as a job)
114 | """
115 | DAT=./scrna/visium_brain_sagittal/data/visium_brain_sagittal_J2000.h5ad
116 | sbatch --mem=72G ./utils/benchmark_gof.slurm $DAT 5
117 | #wait until job finishes, then run below
118 | DAT=./scrna/visium_brain_sagittal/data/visium_brain_sagittal_J2000.h5ad
119 | sbatch --mem=72G ./utils/benchmark_gof.slurm $DAT 20
120 | """
121 |
122 | #%% Compute metrics for each model (manually)
123 | from utils import benchmark
124 | dat = "scrna/visium_brain_sagittal/data/visium_brain_sagittal_J2000.h5ad"
125 | res = benchmark.update_results(dat, val_pct=20, todisk=True)
126 |
127 | #%% Examine one result
128 | from matplotlib import pyplot as plt
129 | from utils import training
130 | tro = training.ModelTrainer.from_pickle(path.join(mpth,"L4/poi/NSF_MaternThreeHalves_M3000"))
131 | plt.plot(tro.loss["train"][-200:-1])
132 |
133 | #%%
134 | csv_file = path.join(rpth,"benchmark.csv")
135 | Ntr = tro.model.Z.shape[0]
136 | benchmark.correct_inducing_pts(csv_file, Ntr)
137 |
--------------------------------------------------------------------------------
/scrna/visium_brain_sagittal/04_benchmark_viz.Rmd:
--------------------------------------------------------------------------------
1 | ---
2 | title: "Benchmarking Visualization"
3 | author: "Will Townes"
4 | output: html_document
5 | ---
6 |
7 | ```{r}
8 | library(tidyverse)
9 | theme_set(theme_bw())
10 | fp<-file.path
11 | pth<-"scrna/visium_brain_sagittal"
12 | plt_pth<-fp(pth,"results/plots")
13 | if(!dir.exists(plt_pth)){
14 | dir.create(plt_pth,recursive=TRUE)
15 | }
16 | ```
17 |
18 | ```{r}
19 | d0<-read.csv(fp(pth,"results/benchmark.csv"))
20 | d0$converged<-as.logical(d0$converged)
21 | #d0$model<-plyr::mapvalues(d0$model,c("FA","RSF","PNMF","NSFH","NSF"),c("FA","RSF","PNMF","NSFH","NSF"))
22 | d<-subset(d0,converged==TRUE)
23 | d$model<-paste0(d$model," (",d$lik,")")
24 | unique(d$model)
25 | keep<-c("FA (gau)","MEFISTO (gau)","RSF (gau)","PNMF (nb)","PNMF (poi)","NSFH (nb)","NSFH (poi)","NSF (nb)","NSF (poi)")
26 |
27 | d<-subset(d,model %in% keep)
28 | d$model<-factor(d$model,levels=keep)
29 | d$dim<-factor(d$L,levels=sort(unique(d$L)))
30 | d$M[is.na(d$M)]<-2363
31 | d$IPs<-factor(d$M,levels=sort(unique(d$M)))
32 | d$standard_kernel<-TRUE
33 | d$standard_kernel[(d$model %in% c("RSF (gau)","NSFH (nb)","NSFH (poi)","NSF (nb)","NSF (poi)")) & d$kernel=="ExponentiatedQuadratic"]<-FALSE
34 | d1<-subset(d,V==5 & standard_kernel)
35 | ```
36 | subset of models for simplified main figure
37 | ```{r}
38 | d2<-subset(d1,M==2363 | (model %in% c("PNMF (poi)","FA (gau)")))
39 | d2a<-subset(d2,(model %in% c("PNMF (poi)","NSFH (poi)","NSF (poi)")) & sz=="scanpy")
40 | d2b<-subset(d2,model %in% c("FA (gau)","RSF (gau)","MEFISTO (gau)"))
41 | d2<-rbind(d2a,d2b)
42 | #d2$model<-factor(d2$model,levels=)
43 | d2$model<-plyr::mapvalues(d2$model,c("FA (gau)","MEFISTO (gau)","RSF (gau)","PNMF (poi)","NSFH (poi)","NSF (poi)"),c("FA","MEFISTO","RSF","PNMF","NSFH","NSF"))
44 | ggplot(d2,aes(x=model,y=dev_val_mean,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.5))+ylab("validation deviance (mean)")+ylim(range(d2$dev_val_mean)*c(.95,1.02))
45 | ggsave(fp(plt_pth,"vz_brn_gof_dev_val_mean_main.pdf"),width=5,height=2.5)
46 |
47 | #linear regressions for statistical significance
48 | d2$realvalued<-d2$model %in% c("FA","MEFISTO","RSF")
49 | #d2$dim_numeric<-as.numeric(as.character(d2$dim))
50 | summary(lm(dev_val_mean~realvalued+dim,data=d2))
51 | summary(lm(dev_val_mean~model,data=d2))
52 | t.test(dev_val_mean~model,data=subset(d2,model %in% c("NSF","RSF")), var.equal=TRUE, alternative="less")
53 | d2$unsupervised<-d2$model %in% c("FA","PNMF")
54 | d2$unsupervised[d2$model=="MEFISTO"]<-NA
55 | summary(lm(dev_val_mean~realvalued+dim+unsupervised,data=d2))
56 | t.test(dev_val_mean~model,data=subset(d2,model %in% c("RSF","MEFISTO")), var.equal=TRUE,alternative="greater")
57 |
58 | d2_no60<-subset(d2,dim!="60")
59 | d2_no60$dim<-factor(d2_no60$dim)
60 | ggplot(d2_no60,aes(x=model,y=rmse_val,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.5))+ylab("validation RMSE")+ylim(range(d2$rmse_val)*c(.95,1.02))
61 | ggsave(fp(plt_pth,"vz_brn_gof_rmse_val_simple.pdf"),width=5,height=2.5)
62 |
63 | ggplot(d2_no60,aes(x=model,y=sparsity,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.5))+ylab("loadings zero fraction")+ylim(range(d2$sparsity)*c(.95,1.02))
64 | ggsave(fp(plt_pth,"vz_brn_sparsity_simple.pdf"),width=5,height=2.5)
65 | ```
66 |
67 | NSF vs NSFH
68 |
69 | ```{r}
70 | d3<-subset(d,grepl("NSF",model)&kernel=="MaternThreeHalves" & sz=="scanpy" & V==5 & dim!="60")
71 | ggplot(d3,aes(x=model,y=dev_tr_mean,color=dim,shape=IPs))+geom_point(size=4,position=position_dodge(width=0.6))+ylab("training deviance (mean)")
72 | ggsave(fp(plt_pth,"vz_brn_gof_dev_tr_nsf_vs_nsfh.pdf"),width=6,height=3)
73 | ```
74 |
75 | ```{r}
76 | #training deviance - average
77 | ggplot(d1,aes(x=model,y=dev_tr_mean,color=dim,shape=IPs))+geom_jitter(size=5,width=.2,height=0)+ylab("training deviance (mean)")+theme(legend.position = "top")
78 | ggsave(fp(plt_pth,"vz_brn_gof_dev_tr_mean.pdf"),width=8,height=3)
79 |
80 | #training deviance - max
81 | ggplot(d1,aes(x=model,y=dev_tr_max,color=dim,shape=IPs))+geom_jitter(size=5,width=.2,height=0)+ylab("training deviance (max)")#+scale_y_log10()
82 |
83 | #validation deviance - mean
84 | ggplot(d1,aes(x=model,y=dev_val_mean,color=dim,shape=IPs))+geom_jitter(size=4,width=.4,height=0)+ylab("validation deviance (mean)")+theme(legend.position="none")
85 | ggsave(fp(plt_pth,"vz_brn_gof_dev_val_mean.pdf"),width=8,height=2.5)
86 |
87 | #validation deviance - max
88 | ggplot(d1,aes(x=model,y=dev_val_max,color=dim,shape=IPs))+geom_jitter(size=3,width=.2,height=0)+ylab("validation deviance (max)")#+scale_y_log10()
89 |
90 | #sparsity
91 | ggplot(d1,aes(x=model,y=sparsity,color=dim,shape=IPs))+geom_jitter(size=4,width=.3,height=0)+ylab("sparsity of loadings")+theme(legend.position="none")
92 | ggsave(fp(plt_pth,"vz_brn_sparsity.pdf"),width=6,height=2.5)
93 |
94 | #wall time
95 | ggplot(d1,aes(x=model,y=wtime/60,color=IPs,shape=dim))+geom_jitter(size=5,width=.2,height=0)+ylab("wall time (min)")+scale_y_log10()+theme(legend.position="top")
96 | ggsave(fp(plt_pth,"vz_brn_wtime.pdf"),width=6,height=3)
97 | ggplot(d1,aes(x=dim,y=wtime/60,color=IPs))+geom_jitter(size=3,width=.2,height=0)+ylab("wall time (min)")+scale_y_log10()
98 |
99 | #processor time
100 | ggplot(d1,aes(x=model,y=ptime/60,color=IPs,shape=dim))+geom_jitter(size=5,width=.2,height=0)+ylab("processor time (min)")+scale_y_log10()+theme(legend.position="none")
101 | ggsave(fp(plt_pth,"vz_brn_ptime.pdf"),width=6,height=2.5)
102 | ```
103 |
104 | Effect of using size factors
105 |
106 | ```{r}
107 | d2<-subset(d1,model %in% c("NSF (nb)","NSF (poi)","NSFH (nb)","NSFH (poi)","PNMF (nb)","PNMF (poi)"))
108 | d2$model<-factor(d2$model)
109 | d2$dim<-factor(d2$L,levels=sort(unique(d2$L)))
110 | d2$M[is.na(d2$M)]<-2363
111 | d2$IPs<-factor(d2$M,levels=sort(unique(d2$M)))
112 | ```
113 |
114 | ```{r}
115 | ggplot(d2,aes(x=sz,y=dev_tr_mean,color=dim,shape=IPs))+geom_jitter(size=5,width=.2,height=0)+ylab("training deviance (mean)")+theme(legend.position = "top")+facet_wrap(~model,ncol=2,scales="free")
116 |
117 | ggplot(d2,aes(x=sz,y=dev_val_mean,color=dim,shape=IPs))+geom_jitter(size=5,width=.2,height=0)+ylab("validation deviance (mean)")+theme(legend.position = "top")+facet_wrap(~model,ncol=2,scales="free")
118 | ```
119 |
120 | Reviewer comment: GOF metrics with validation 20% of data instead of 5%
121 | ```{r}
122 | d2<-subset(d,V==20 & standard_kernel)
123 | d2<-subset(d2,M==2363 | (model %in% c("PNMF (poi)","FA (gau)")))
124 | d2a<-subset(d2,(model %in% c("PNMF (poi)","NSFH (poi)","NSF (poi)")) & sz=="scanpy")
125 | d2b<-subset(d2,model %in% c("FA (gau)","RSF (gau)","MEFISTO (gau)"))
126 | d2<-rbind(d2a,d2b)
127 | #d2$model<-factor(d2$model,levels=)
128 | d2$model<-plyr::mapvalues(d2$model,c("FA (gau)","MEFISTO (gau)","RSF (gau)","PNMF (poi)","NSFH (poi)","NSF (poi)"),c("FA","MEFISTO","RSF","PNMF","NSFH","NSF"))
129 | ggplot(d2,aes(x=model,y=dev_val_mean,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.5))+ylab("validation deviance (mean)")+ylim(range(d2$dev_val_mean)*c(.95,1.02))
130 | ggsave(fp(plt_pth,"vz_brn_gof_dev_val_mean_main_V20.pdf"),width=5,height=2.5)
131 |
132 | ggplot(d2,aes(x=model,y=rmse_val,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.5))+ylab("validation RMSE")+ylim(range(d2$rmse_val)*c(.95,1.02))
133 | ggsave(fp(plt_pth,"vz_brn_gof_rmse_val_simple_V20.pdf"),width=5,height=2.5)
134 | ```
135 |
136 | Reviewer comment: Matern vs RBF kernels
137 | ```{r}
138 | d2<-subset(d,V==5 & M==1000)
139 | d2a<-subset(d2,(model %in% c("NSFH (poi)","NSF (poi)")) & sz=="scanpy")
140 | d2b<-subset(d2,model %in% c("RSF (gau)","MEFISTO (gau)"))
141 | d2<-rbind(d2a,d2b)
142 | #d2$model<-factor(d2$model,levels=)
143 | d2$model<-plyr::mapvalues(d2$model,c("MEFISTO (gau)","RSF (gau)","NSFH (poi)","NSF (poi)"),c("MEFISTO","RSF","NSFH","NSF"))
144 | ggplot(d2,aes(x=model,y=dev_val_mean,color=dim,shape=kernel))+geom_point(size=6,position=position_dodge(width=0.7))+ylab("validation deviance (mean)")+ylim(range(d2$dev_val_mean)*c(.95,1.02))
145 | ggsave(fp(plt_pth,"vz_brn_gof_dev_val_kernels.pdf"),width=5,height=2.5)
146 |
147 | ggplot(d2,aes(x=model,y=rmse_val,color=dim,shape=kernel))+geom_point(size=6,position=position_dodge(width=0.7))+ylab("validation RMSE")+ylim(range(d2$rmse_val)*c(.95,1.02))
148 | ggsave(fp(plt_pth,"vz_brn_gof_rmse_val_kernels.pdf"),width=5,height=2.5)
149 | ```
150 |
151 | numerical stability of different kernel choices
152 | ```{r}
153 | d0<-read.csv(fp(pth,"results/benchmark.csv"))
154 | d0$converged<-as.logical(d0$converged)
155 | table(d0$converged)
156 | d1<-subset(d0,V==5)# & M==2363)
157 | d1a<-subset(d1,model=="RSF" & lik=="gau")
158 | d1b<-subset(d1,model %in% c("NSF","NSFH") & lik=="poi")
159 | d<-rbind(d1a,d1b)
160 | d2 <- d %>% group_by(kernel,M,model,lik) %>% summarize(total_runs=length(converged),converged=sum(converged))
161 | ```
162 |
--------------------------------------------------------------------------------
/scrna/visium_brain_sagittal/05_interpret_genes.Rmd:
--------------------------------------------------------------------------------
1 | ---
2 | title: "05_interpretation_genes"
3 | author: "Will Townes"
4 | output: html_document
5 | ---
6 |
7 | ```{r}
8 | # library(tidyverse)
9 | library(biomaRt)
10 | # library(rPanglaoDB)
11 | source("./scrna/utils/interpret_genes.R")
12 |
13 | fp<-file.path
14 | pth<-"scrna/visium_brain_sagittal"
15 | ```
16 | loadings from NSFH
17 | ```{r}
18 | W<-read.csv(fp(pth,"results/vz_brn_nsfh20_spde_loadings.csv"),header=TRUE,row.names=1)
19 | rownames(W)<-toupper(rownames(W))
20 | W2<-t(apply(as.matrix(W),1,function(x){as.numeric(x==max(x))}))
21 | colSums(W2) #clustering of genes
22 | ```
23 | Gene Ontology terms, used for both NSFH and hotspot.
24 | ```{r}
25 | # bg<-rownames(W)
26 | db<-useMart("ensembl",host="https://may2021.archive.ensembl.org",dataset='mmusculus_gene_ensembl')
27 | go_ids<-getBM(attributes=c('go_id', 'external_gene_name','namespace_1003'), filters='external_gene_name', values=rownames(W), mart=db)
28 | go_ids[,2]<-toupper(go_ids[,2])
29 | gene_2_GO<-unstack(go_ids[,c(1,2)])
30 | ```
31 | GO analysis for NSFH
32 | ```{r}
33 | res<-loadings2go(W,gene_2_GO)
34 | res$type<-rep(c("spat","nsp"),each=10)
35 | res$dim<-rep(1:10,2)
36 | write.csv(res[,c(1,4,2,3)],fp(pth,"results/vz_brn_nsfh20_spde_goterms.csv"),row.names=FALSE)
37 | ```
38 | Get cell types from [Panglao](https://panglaodb.se). First manually download the list of all marker genes for all cell types from https://panglaodb.se/markers.html?cell_type=%27all_cells%27 . Store this TSV in the scrna/resources folder.
39 | ```{r}
40 | res<-read.csv(fp(pth,"results/vz_brn_nsfh20_spde_goterms.csv"),header=TRUE)
41 | # g<-strsplit(res[10,"genes"],", ",fixed=TRUE)[[1]]
42 | #pg<-panglao_make_ref()
43 | panglao_tsv="scrna/resources/PanglaoDB_markers_27_Mar_2020.tsv.gz"
44 | mk<-read.table(gzfile(panglao_tsv),header=TRUE,sep="\t")
45 | mk<-subset(mk,species %in% c("Mm","Mm Hs"))
46 | mk<-subset(mk,organ %in% c("Brain","Connective tissue","Epithelium","Olfactory system"))
47 | #,"Immune system","Vasculature","Blood"
48 | pg<-tapply(mk$official.gene.symbol,mk$cell.type,function(x){x},simplify=FALSE)
49 | # pg2<-pg[!(names(pg) %in% c("Plasma cells","Purkinje neurons","Delta cells",
50 | # bad<-c("Plasma cells","Purkinje neurons","Delta cells","Purkinje fiber cells","Erythroid-like and erythroid precurser cells")
51 | ct<-lapply(res$genes,genes2celltype,pg)#,ss=3)
52 | ct[is.na(ct)]<-""
53 | res$celltypes<-paste(ct,sep=", ")
54 | write.csv(res,fp(pth,"results/vz_brn_nsfh20_spde_goterms.csv"),row.names=FALSE)
55 | ```
56 | Shorter GO table
57 | ```{r}
58 | res<-read.csv(fp(pth,"results/vz_brn_nsfh20_spde_goterms.csv"),header=TRUE)
59 | res2<-res
60 | g<-strsplit(res2$genes,", ",fixed=TRUE)
61 | res2$genes<-sapply(g,function(x){paste(x[1:5],collapse=", ")})
62 | gt<-strsplit(res2$go_bp,"; ",fixed=TRUE)
63 | res2$go_bp<-sapply(gt,function(x){paste(x[1:2],collapse=", ")})
64 | write.csv(res2,fp(pth,"results/vz_brn_nsfh20_spde_goterms_short.csv"),row.names=FALSE)
65 | ```
66 |
67 | GO analysis for hotspot
68 | ```{r}
69 | hs<-read.csv(fp(pth,"results/hotspot.csv"),header=TRUE)
70 | rownames(hs)<-hs$Gene
71 | Wh<-model.matrix(~0+factor(Module),data=subset(hs,Module!=-1))
72 | colnames(Wh)<-paste0("X",1:ncol(Wh))
73 | rownames(Wh)<-toupper(rownames(Wh))
74 | # all(rownames(Wh) %in% rownames(W))
75 | hres<-loadings2go(Wh,gene_2_GO, rowmean_divide=TRUE)
76 | colnames(hres)[which(colnames(hres)=="dim")]<-"cluster"
77 | write.csv(hres,fp(pth,"results/vz_brn_hotspot_goterms.csv"),row.names=FALSE)
78 | hres2<-hres
79 | g<-strsplit(hres2$genes,", ",fixed=TRUE)
80 | hres2$genes<-sapply(g,function(x){paste(x[1:5],collapse=", ")})
81 | gt<-strsplit(hres2$go_bp,"; ",fixed=TRUE)
82 | hres2$go_bp<-sapply(gt,function(x){paste(x[1:2],collapse=", ")})
83 | write.csv(hres2,fp(pth,"results/vz_brn_hotspot_goterms_short.csv"),row.names=FALSE)
84 | ```
85 |
--------------------------------------------------------------------------------
/scrna/visium_brain_sagittal/06_traditional.ipy:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # ---
3 | # jupyter:
4 | # jupytext:
5 | # text_representation:
6 | # extension: .ipy
7 | # format_name: percent
8 | # format_version: '1.3'
9 | # jupytext_version: 1.6.0
10 | # kernelspec:
11 | # display_name: Python 3
12 | # language: python
13 | # name: python3
14 | # ---
15 |
16 | #%% imports
17 | # import numpy as np
18 | import pandas as pd
19 | import scanpy as sc
20 | import matplotlib.pyplot as plt
21 | from os import path
22 | from hotspot import Hotspot
23 |
24 | from utils import misc,visualize
25 |
26 | dtp = "float32"
27 | pth = "scrna/visium_brain_sagittal"
28 | dpth = path.join(pth,"data")
29 | mpth = path.join(pth,"models")
30 | rpth = path.join(pth,"results")
31 | plt_pth = path.join(rpth,"plots")
32 |
33 | # %% Data Loading from scanpy
34 | J = 2000
35 | # dfile = path.join(dpth,"visium_brain_sagittal_J{}.h5ad".format(J))
36 | dfile = path.join(dpth,"visium_brain_sagittal.h5ad")
37 | adata = sc.read_h5ad(dfile)
38 | sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=J)
39 | X0 = adata.obsm["spatial"]
40 | X0[:,1] = -X0[:,1]
41 |
42 | #%% Traditional scanpy analysis (unsupervised clustering)
43 | #https://scanpy-tutorials.readthedocs.io/en/latest/spatial/basic-analysis.html
44 | sc.pp.pca(adata)
45 | sc.pp.neighbors(adata)
46 | sc.tl.umap(adata)
47 | sc.tl.leiden(adata, resolution=1.0, key_added="clusters")
48 | plt.rcParams["figure.figsize"] = (4, 4)
49 | sc.pl.umap(adata, color="clusters", wspace=0.4)
50 | sc.pl.embedding(adata, "spatial", color="clusters")
51 | cl = pd.get_dummies(adata.obs["clusters"]).to_numpy()
52 | tgnames = [str(i) for i in range(1,cl.shape[1]+1)]
53 | hmkw = {"figsize":(10,8), "s":0.5, "marker":"D", "subplot_space":0,
54 | "spinecolor":"white"}
55 | fig,axes=visualize.multiheatmap(X0, cl, (4,5), **hmkw)
56 | visualize.set_titles(fig, tgnames, x=0.03, y=.88, fontsize="small", c="white",
57 | ha="left", va="top")
58 | fig.savefig(path.join(plt_pth,"vz_brn_heatmap_scanpy_clusters.pdf"),
59 | bbox_inches='tight')
60 |
61 | #%% Hotspot analysis (gene clusters)
62 | #https://hotspot.readthedocs.io/en/latest/Spatial_Tutorial.html
63 | J = 2000
64 | dfile = path.join(dpth,"visium_brain_sagittal_J{}.h5ad".format(J))
65 | adata = sc.read_h5ad(dfile)
66 | adata.layers["counts"] = adata.layers["counts"].tocsc()
67 | hs = Hotspot(adata, layer_key="counts", model="danb",
68 | latent_obsm_key="spatial", umi_counts_obs_key="total_counts")
69 | hs.create_knn_graph(weighted_graph=False, n_neighbors=20)
70 | hs_results = hs.compute_autocorrelations()
71 | # hs_results.tail()
72 | hs_genes = hs_results.index#[hs_results.FDR < 0.05]
73 | lcz = hs.compute_local_correlations(hs_genes)
74 | modules = hs.create_modules(min_gene_threshold=20, core_only=False,
75 | fdr_threshold=0.05)
76 | # modules.value_counts()
77 | hs_results = hs_results.join(modules,how="left")
78 | hs_results.to_csv(path.join(rpth,"hotspot.csv"))
79 | misc.pickle_to_file(hs,path.join(mpth,"hotspot.pickle"))
80 |
--------------------------------------------------------------------------------
/scrna/xyzeq_liver/01_data_loading.ipy:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # ---
3 | # jupyter:
4 | # jupytext:
5 | # text_representation:
6 | # extension: .ipy
7 | # format_name: percent
8 | # format_version: '1.3'
9 | # jupytext_version: 1.6.0
10 | # kernelspec:
11 | # display_name: Python 3
12 | # language: python
13 | # name: python3
14 | # ---
15 |
16 | # %% imports
17 | import random
18 | import numpy as np
19 | import pandas as pd
20 | import scanpy as sc
21 | from os import path
22 | from scipy import sparse
23 | from matplotlib import pyplot as plt
24 |
25 | from utils import preprocess, training, misc
26 |
27 | random.seed(101)
28 | pth = "scrna/xyzeq_liver"
29 | dpth = path.join(pth,"data")
30 |
31 | # %% Download original data files
32 | %%sh
33 | mkdir -p scrna/xyzeq_liver/data/original
34 | pushd scrna/xyzeq_liver/data/original
35 | wget ftp://ftp.ncbi.nlm.nih.gov/geo/samples/GSM5009nnn/GSM5009531/suppl/GSM5009531_L20C1.h5ad.gz
36 | unpigz GSM5009531_L20C1.h5ad.gz
37 | popd
38 | # Additional data files manually downloaded, provided by authors in email
39 |
40 | # %% Data Loading from original
41 | coords = pd.read_csv(path.join(dpth,"original","plate23_map.csv"))
42 | coords.rename(columns={coords.columns[0]: "barcode"},inplace=True)
43 | labs = pd.read_csv(path.join(dpth,"original","L20C1.csv"))
44 | #merge spatial coordinates with cell type labels and metadata
45 | labs["barcode"] = [w.split(".")[1] for w in labs["index"]]
46 | labs = labs.merge(coords,on="barcode")
47 | ad = sc.read_h5ad(path.join(dpth,"original","GSM5009531_L20C1.h5ad"))
48 | #match cell type labels with anndata rownames
49 | ad = ad[labs["index"]]
50 | labs.set_index("index",inplace=True,verify_integrity=True)
51 | ad.obs = labs
52 | mouse_cells = ad.obs["cell_call"]=="M"
53 | mouse_genes = ad.var_names.str.startswith("mm10_")
54 | ad = ad[mouse_cells,mouse_genes] #mouse cells only
55 | nz_genes = np.ravel(ad.X.sum(axis=0)>0)
56 | ad = ad[:,nz_genes]
57 | #rename genes to remove mm10_ prefix
58 | ad.var_names = ad.var_names.str.replace("mm10_","")
59 | #how many unique barcodes in this slice: only 289
60 | print("Unique locations: {}".format(len(ad.obs["barcode"].unique())))
61 | ad.obsm["spatial"] = ad.obs[["X","Y"]].to_numpy()
62 | ad.obs.drop(columns=["X","Y"],inplace=True)
63 | X = ad.obsm["spatial"]
64 | #rectangle marker code: https://stackoverflow.com/a/62572367
65 | plt.scatter(X[:,0],X[:,1],marker='$\u25AE$',s=120)
66 | plt.gca().invert_yaxis()
67 | plt.title("Mouse cell locations")
68 | plt.show()
69 |
70 | #%% QC
71 | ad.var["mt"] = ad.var_names.str.startswith("mt-")
72 | sc.pp.calculate_qc_metrics(ad, qc_vars=["mt"], inplace=True)
73 | ad.obs.pct_counts_mt.hist() #all less than 2%, no need to filter cells
74 | #all cells and genes passed the below criteria
75 | ad = ad[ad.obs.pct_counts_mt < 20]
76 | sc.pp.filter_genes(ad, min_cells=1)
77 | sc.pp.filter_cells(ad, min_counts=100)
78 | ad.layers = {"counts":ad.X.copy()} #store raw counts before normalization changes ad.X
79 | sc.pp.normalize_total(ad, inplace=True, layers=None, key_added="sizefactor")
80 | sc.pp.log1p(ad)
81 |
82 | # %% normalization, feature selection and train/test split
83 | ad.var['deviance_poisson'] = preprocess.deviancePoisson(ad.layers["counts"])
84 | o = np.argsort(-ad.var['deviance_poisson'])
85 | idx = list(range(ad.shape[0]))
86 | random.shuffle(idx)
87 | ad = ad[idx,o]
88 | ad.write_h5ad(path.join(dpth,"xyzeq_liver_L20C1_mouseonly.h5ad"),compression="gzip")
89 | #ad = sc.read_h5ad(path.join(dpth,"xyzeq_liver_L20C1_mouseonly.h5ad"))
90 | plt.plot(ad.var["deviance_poisson"].to_numpy())
91 | ad2 = ad[:,:2000]
92 | ad2.write_h5ad(path.join(pth,"data/xyzeq_liver_L20C1_mouseonly_J2000.h5ad"),compression="gzip")
93 |
--------------------------------------------------------------------------------
/scrna/xyzeq_liver/03_benchmark.ipy:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # ---
3 | # jupyter:
4 | # jupytext:
5 | # text_representation:
6 | # extension: .ipy
7 | # format_name: percent
8 | # format_version: '1.3'
9 | # jupytext_version: 1.6.0
10 | # kernelspec:
11 | # display_name: Python 3
12 | # language: python
13 | # name: python3
14 | # ---
15 |
16 | #%%
17 | import pandas as pd
18 | from os import path
19 | from utils import misc,benchmark
20 |
21 | pth = "scrna/xyzeq_liver"
22 | dpth = path.join(pth,"data")
23 | mpth = path.join(pth,"models/V5")
24 | rpth = path.join(pth,"results")
25 | misc.mkdir_p(rpth)
26 |
27 | #%% Create CSV with benchmarking parameters
28 | csv_path = path.join(rpth,"benchmark.csv")
29 | try:
30 | par = pd.read_csv(csv_path)
31 | except FileNotFoundError:
32 | L = [6,12,20]
33 | sp_mods = ["NSF-P", "NSF-N", "NSF-G", "RSF-G", "NSFH-P", "NSFH-N"]
34 | ns_mods = ["PNMF-P", "PNMF-N", "FA-G"]
35 | sz = ["constant","scanpy"]
36 | M = [288]
37 | par = benchmark.make_param_df(L,sp_mods,ns_mods,M,sz)
38 | par.to_csv(csv_path,index=False)
39 |
40 | #%% Benchmarking at command line [markdown]
41 | """
42 | To run on a local computer, use
43 | `python -um utils.benchmark 28 scrna/xyzeq_liver/data/xyzeq_liver_L20C1_mouseonly_J2000.h5ad`
44 | where 28 is a row ID of benchmark.csv; the minimum value is 2 and the maximum possible value is 49.
45 |
46 | To run on the cluster, first load the anaconda environment:
47 | ```
48 | tmux
49 | interactive
50 | module load anaconda3/2021.5
51 | conda activate fwt
52 | python -um utils.benchmark 2 scrna/xyzeq_liver/data/xyzeq_liver_L20C1_mouseonly_J2000.h5ad
53 | ```
54 |
55 | To run on the cluster as a job array over a subset of rows (a 2 hr time limit is recommended):
56 | ```
57 | DAT=./scrna/xyzeq_liver/data/xyzeq_liver_L20C1_mouseonly_J2000.h5ad
58 | sbatch --mem=48G --array=2-4 ./utils/benchmark_array.slurm $DAT
59 | ```
60 |
61 | To run on the cluster as a job array over all rows of the CSV file:
62 | ```
63 | CSV=./scrna/xyzeq_liver/results/benchmark.csv
64 | DAT=./scrna/xyzeq_liver/data/xyzeq_liver_L20C1_mouseonly_J2000.h5ad
65 | sbatch --mem=48G --array=2-$(wc -l < $CSV) ./utils/benchmark_array.slurm $DAT
66 | ```
67 | """
68 |
69 | #%% Compute metrics for each model
70 | """
71 | DAT=./scrna/xyzeq_liver/data/xyzeq_liver_L20C1_mouseonly_J2000.h5ad
72 | sbatch --mem=48G ./utils/benchmark_gof.slurm $DAT 5
73 | """
74 |
75 | #%% Examine one result
76 | from matplotlib import pyplot as plt
77 | from utils import training
78 | tro = training.ModelTrainer.from_pickle(path.join(mpth,"L20/poi/NSFH_T10_MaternThreeHalves_M1000"))
79 | plt.plot(tro.loss["train"][-200:-1])
80 |
--------------------------------------------------------------------------------
/scrna/xyzeq_liver/04_benchmark_viz.Rmd:
--------------------------------------------------------------------------------
1 | ---
2 | title: "Benchmarking Visualization"
3 | author: "Will Townes"
4 | output: html_document
5 | ---
6 |
7 | ```{r}
8 | library(tidyverse)
9 | theme_set(theme_bw())
10 | fp<-file.path
11 | pth<-"scrna/xyzeq_liver"
12 | plt_pth<-fp(pth,"results/plots")
13 | if(!dir.exists(plt_pth)){
14 | dir.create(plt_pth,recursive=TRUE)
15 | }
16 | ```
17 |
18 | ```{r}
19 | d0<-read.csv(fp(pth,"results/benchmark.csv"))
20 | d0$converged<-as.logical(d0$converged)
21 | #d0$model<-plyr::mapvalues(d0$model,c("RCF","RPF","NCF","NPFH","NPF"),c("FA","RSF","PNMF","NSFH","NSF"))
22 | d<-subset(d0,converged==TRUE)
23 | d$model<-paste0(d$model," (",d$lik,")")
24 | unique(d$model)
25 | # keep<-c("NPF (nb)", "NPF (poi)", "NPFH (nb)", "NPFH (poi)", "NCF (nb)", "NCF (poi)", "RPF (gau)", "RCF (gau)")#"MEFISTO (gau)",
26 | keep<-c("FA (gau)","MEFISTO (gau)","RSF (gau)","PNMF (nb)","PNMF (poi)","NSFH (nb)","NSFH (poi)","NSF (nb)","NSF (poi)")
27 | d<-subset(d,model %in% keep)
28 | d$model<-factor(d$model,levels=keep)
29 | d$dim<-factor(d$L,levels=sort(unique(d$L)))
30 | d$M[is.na(d$M)]<-288
31 | d$IPs<-factor(d$M,levels=sort(unique(d$M)))
32 | ```
33 | subset of models for simplified main figure
34 | ```{r}
35 | d2<-subset(d,M==288 | (model %in% c("PNMF (poi)","FA (gau)")))
36 | d2a<-subset(d2,(model %in% c("PNMF (poi)","NSFH (poi)","NSF (poi)")) & sz=="scanpy")
37 | d2b<-subset(d2,model %in% c("FA (gau)","RSF (gau)","MEFISTO (gau)"))
38 | d2<-rbind(d2a,d2b)
39 | #d2$model<-factor(d2$model,levels=)
40 | d2$model<-plyr::mapvalues(d2$model,c("FA (gau)","MEFISTO (gau)","RSF (gau)","PNMF (poi)","NSFH (poi)","NSF (poi)"),c("FA","MEFISTO","RSF","PNMF","NSFH","NSF"))
41 | ggplot(d2,aes(x=model,y=dev_val_mean,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.8))+ylab("validation deviance (mean)")+ylim(range(d2$dev_val_mean)*c(.95,1.02))
42 | ggsave(fp(plt_pth,"xyzeq_liver_gof_dev_val_mean_main.pdf"),width=5,height=2.5)
43 |
44 | #linear regressions for significance
45 | d2$real_valued<-d2$model %in% c("FA","MEFISTO","RSF")
46 | d2$spatial_aware<-d2$model %in% c("MEFISTO","RSF","NSFH","NSF")
47 | summary(lm(dev_val_mean~real_valued+spatial_aware+dim,data=d2))
48 | t.test(dev_val_mean~model,data=subset(d2,model %in% c("NSFH","NSF")), var.equal=TRUE, alternative="greater")
49 |
50 | ggplot(d2,aes(x=model,y=rmse_val,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.8))+ylab("validation RMSE")+ylim(range(d2$rmse_val)*c(.95,1.02))
51 | ggsave(fp(plt_pth,"xyzeq_liver_gof_rmse_val_simple.pdf"),width=5,height=2.5)
52 |
53 | ggplot(d2,aes(x=model,y=sparsity,color=dim,shape=lik))+geom_point(size=6,position=position_dodge(width=0.8))+ylab("loadings zero fraction")+ylim(range(d2$sparsity)*c(.95,1.02))
54 | ggsave(fp(plt_pth,"xyzeq_liver_sparsity_simple.pdf"),width=5,height=2.5)
55 | ```
56 |
57 | NSF vs NSFH
58 |
59 | ```{r}
60 | d3<-subset(d,grepl("NSF",model) & sz=="scanpy")
61 | ggplot(d3,aes(x=model,y=dev_tr_mean,color=dim))+geom_point(size=4,position=position_dodge(width=0.8))+ylab("training deviance (mean)")
62 | ggsave(fp(plt_pth,"xyzeq_liver_gof_dev_tr_nsf_vs_nsfh.pdf"),width=6,height=3)
63 | ```
64 |
65 | ```{r}
66 | #training deviance - average
67 | ggplot(d,aes(x=model,y=dev_tr_mean,color=dim,shape=sz))+geom_jitter(size=5,width=.2,height=0)+ylab("training deviance (mean)")+theme(legend.position = "top")
68 | ggsave(fp(plt_pth,"xyzeq_liver_gof_dev_tr_mean.pdf"),width=6,height=3)
69 |
70 | #training deviance - max
71 | ggplot(d,aes(x=model,y=dev_tr_max,color=dim,shape=sz))+geom_jitter(size=5,width=.2,height=0)+ylab("training deviance (max)")#+scale_y_log10()
72 |
73 | #validation deviance - mean
74 | ggplot(d,aes(x=model,y=dev_val_mean,color=dim,shape=sz))+geom_jitter(size=4,width=.4,height=0)+ylab("validation deviance (mean)")+theme(legend.position="none")
75 | ggsave(fp(plt_pth,"xyzeq_liver_gof_dev_val_mean.pdf"),width=6,height=2.5)
76 |
77 | #validation deviance - max
78 | ggplot(d,aes(x=model,y=dev_val_max,color=dim,shape=sz))+geom_jitter(size=3,width=.2,height=0)+ylab("validation deviance (max)")#+scale_y_log10()
79 |
80 | #sparsity
81 | ggplot(d,aes(x=model,y=sparsity,color=dim,shape=sz))+geom_jitter(size=4,width=.3,height=0)+ylab("sparsity of loadings")+theme(legend.position="none")
82 | ggsave(fp(plt_pth,"xyzeq_liver_sparsity.pdf"),width=6,height=2.5)
83 |
84 | #wall time
85 | ggplot(d,aes(x=model,y=wtime/60,color=dim,shape=sz))+geom_jitter(size=5,width=.2,height=0)+ylab("wall time (min)")+scale_y_log10()+theme(legend.position="top")
86 | ggsave(fp(plt_pth,"xyzeq_liver_wtime.pdf"),width=6,height=3)
87 | ggplot(d,aes(x=dim,y=wtime/60,color=sz))+geom_jitter(size=3,width=.2,height=0)+ylab("wall time (min)")+scale_y_log10()
88 |
89 | #processor time
90 | ggplot(d,aes(x=model,y=ptime/60,color=dim,shape=sz))+geom_jitter(size=5,width=.2,height=0)+ylab("processor time (min)")+scale_y_log10()+theme(legend.position="none")
91 | ggsave(fp(plt_pth,"xyzeq_liver_ptime.pdf"),width=6,height=2.5)
92 | ```
93 |
--------------------------------------------------------------------------------
/scrna/xyzeq_liver/05_interpret_genes.Rmd:
--------------------------------------------------------------------------------
1 | ---
2 | title: "05_interpretation_genes"
3 | author: "Will Townes"
4 | output: html_document
5 | ---
6 |
7 | ```{r}
8 | # library(tidyverse)
9 | library(biomaRt)
10 | # library(rPanglaoDB)
11 | source("./scrna/utils/interpret_genes.R")
12 |
13 | fp<-file.path
14 | pth<-"scrna/xyzeq_liver"
15 | ```
16 | loadings from NSFH
17 | ```{r}
18 | W<-read.csv(fp(pth,"results/xyz_liv_nsfh6_spde_loadings.csv"),header=TRUE,row.names=1)
19 | rownames(W)<-toupper(rownames(W))
20 | W2<-t(apply(as.matrix(W),1,function(x){as.numeric(x==max(x))}))
21 | colSums(W2) #clustering of genes
22 | ```
23 | Gene Ontology terms, used for both NSFH and hotspot.
24 | ```{r}
25 | # bg<-rownames(W)
26 | db<-useMart("ensembl",host="https://may2021.archive.ensembl.org",dataset='mmusculus_gene_ensembl')
27 | go_ids<-getBM(attributes=c('go_id', 'external_gene_name','namespace_1003'), filters='external_gene_name', values=rownames(W), mart=db)
28 | go_ids[,2]<-toupper(go_ids[,2])
29 | gene_2_GO<-unstack(go_ids[,c(1,2)])
30 | ```
31 | GO analysis for NSFH
32 | ```{r}
33 | # g<-strsplit(res[l,"genes"],", ",fixed=TRUE)[[1]]
34 | # expr <- getMarkers(include = g[c(1:2,4,6,7)])$
35 | res<-loadings2go(W,gene_2_GO)
36 | res$type<-rep(c("spat","nsp"),each=3)
37 | res$dim<-rep(1:3,2)
38 | write.csv(res[,c(1,4,2,3)],fp(pth,"results/xyz_liv_nsfh6_spde_goterms.csv"),row.names=FALSE)
39 | ```
40 | Get cell types from [Panglao](https://panglaodb.se). First manually download the list of all marker genes for all cell types from https://panglaodb.se/markers.html?cell_type=%27all_cells%27 . Store this TSV in the scrna/resources folder.
41 | ```{r}
42 | res<-read.csv(fp(pth,"results/xyz_liv_nsfh6_spde_goterms.csv"),header=TRUE)
43 | panglao_tsv="scrna/resources/PanglaoDB_markers_27_Mar_2020.tsv.gz"
44 | mk<-read.table(gzfile(panglao_tsv),header=TRUE,sep="\t")
45 | mk<-subset(mk,species %in% c("Mm","Mm Hs"))
46 | mk<-subset(mk,organ %in% c("Liver","Immune system","Vasculature","Blood"))
47 | pg<-tapply(mk$official.gene.symbol,mk$cell.type,function(x){x},simplify=FALSE)
48 | ct<-lapply(res$genes,genes2celltype,pg,ss=1)
49 | ct[is.na(ct)]<-""
50 | #the automated search did not work so we tried a manual search on the panglao website
51 | ct[[1]]<-"Hepatocytes"
52 | ct[[3]]<-"Macrophages"
53 | ct[[4]]<-"Fibroblasts"
54 | ct[[6]]<-"Macrophages"
55 | res$celltypes<-paste(ct,sep=", ")
56 | write.csv(res,fp(pth,"results/xyz_liv_nsfh6_spde_goterms.csv"),row.names=FALSE)
57 | ```
58 | Shorter GO table
59 | ```{r}
60 | res<-read.csv(fp(pth,"results/xyz_liv_nsfh6_spde_goterms.csv"),header=TRUE)
61 | res2<-res
62 | g<-strsplit(res2$genes,", ",fixed=TRUE)
63 | res2$genes<-sapply(g,function(x){paste(x[1:5],collapse=", ")})
64 | gt<-strsplit(res2$go_bp,"; ",fixed=TRUE)
65 | res2$go_bp<-sapply(gt,function(x){paste(x[1:2],collapse=", ")})
66 | write.csv(res2,fp(pth,"results/xyz_liv_nsfh6_spde_goterms_short.csv"),row.names=FALSE)
67 | ```
68 |
69 | GO analysis for hotspot
70 | ```{r}
71 | hs<-read.csv(fp(pth,"results/hotspot.csv"),header=TRUE)
72 | rownames(hs)<-hs$Gene
73 | Wh<-model.matrix(~0+factor(Module),data=subset(hs,Module!=-1))
74 | colnames(Wh)<-paste0("X",1:ncol(Wh))
75 | rownames(Wh)<-toupper(rownames(Wh))
76 | # all(rownames(Wh) %in% rownames(W))
77 | hres<-loadings2go(Wh,gene_2_GO, rowmean_divide=TRUE)
78 | colnames(hres)[which(colnames(hres)=="dim")]<-"cluster"
79 | write.csv(hres,fp(pth,"results/xyz_liv_hotspot_goterms.csv"),row.names=FALSE)
80 | hres2<-hres
81 | g<-strsplit(hres2$genes,", ",fixed=TRUE)
82 | hres2$genes<-sapply(g,function(x){paste(x[1:5],collapse=", ")})
83 | gt<-strsplit(hres2$go_bp,"; ",fixed=TRUE)
84 | hres2$go_bp<-sapply(gt,function(x){paste(x[1:2],collapse=", ")})
85 | write.csv(hres2,fp(pth,"results/xyz_liv_hotspot_goterms_short.csv"),row.names=FALSE)
86 | ```
87 |
--------------------------------------------------------------------------------
/scrna/xyzeq_liver/06_traditional.ipy:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # ---
3 | # jupyter:
4 | # jupytext:
5 | # text_representation:
6 | # extension: .ipy
7 | # format_name: percent
8 | # format_version: '1.3'
9 | # jupytext_version: 1.6.0
10 | # kernelspec:
11 | # display_name: Python 3
12 | # language: python
13 | # name: python3
14 | # ---
15 |
16 | #%% imports
17 | import numpy as np
18 | import pandas as pd
19 | import scanpy as sc
20 | import matplotlib.pyplot as plt
21 | from os import path
22 | from hotspot import Hotspot
23 |
24 | from utils import misc,visualize
25 |
26 | dtp = "float32"
27 | pth = "scrna/xyzeq_liver"
28 | dpth = path.join(pth,"data")
29 | mpth = path.join(pth,"models")
30 | rpth = path.join(pth,"results")
31 | plt_pth = path.join(rpth,"plots")
32 |
33 | # %% Data Loading from scanpy
34 | J = 2000
35 | dfile = path.join(dpth,"xyzeq_liver_L20C1_mouseonly.h5ad")
36 | adata = sc.read_h5ad(dfile)
37 | sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=J)
38 | X0 = adata.obsm["spatial"]
39 | X0[:,1] = -X0[:,1]
40 | Z = np.unique(X0, axis=0)
41 | #find mapping between Xtr and Z
42 | from scipy.spatial.distance import cdist
43 | ZX = 1-(cdist(Z,X0)>0) #M x Ntr indicator matrix (unique locations by cells)
44 | ZX2 = ZX/ZX.sum(axis=1)[:,None] #premultiply by this to average within locations
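#toy sketch of the averaging trick above (invented arrays, not the real data;
#uses np and cdist imported earlier in this script): two cells share location (0,0)
#and one cell sits at (1,1); row-normalizing the location-by-cell indicator and
#premultiplying one-hot labels averages the labels within each unique location
X_toy = np.array([[0,0],[0,0],[1,1]])
Z_toy = np.unique(X_toy, axis=0)            #2 unique locations
ZX_toy = 1-(cdist(Z_toy,X_toy)>0)           #2 x 3 location-by-cell indicator
ZX2_toy = ZX_toy/ZX_toy.sum(axis=1)[:,None] #rows sum to 1
cl_toy = np.array([[1,0],[0,1],[1,0]])      #one-hot cluster labels for the 3 cells
print(ZX2_toy @ cl_toy)                     #[[0.5 0.5],[1. 0.]]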
45 |
46 | #%% Traditional scanpy analysis (unsupervised clustering)
47 | #https://scanpy-tutorials.readthedocs.io/en/latest/spatial/basic-analysis.html
48 | sc.pp.pca(adata)
49 | sc.pp.neighbors(adata)
50 | sc.tl.umap(adata)
51 | sc.tl.leiden(adata, resolution=1.0, key_added="clusters")
52 | plt.rcParams["figure.figsize"] = (4, 4)
53 | sc.pl.umap(adata, color="clusters", wspace=0.4)
54 | sc.pl.embedding(adata, "spatial", color="clusters")
55 | cl = pd.get_dummies(adata.obs["clusters"]).to_numpy()
56 |
57 | #%% Visualize clustering
58 | hmkw = {"figsize":(6.5,3.5),"subplot_space":0,"spinecolor":"white","marker":"$\u25AE$"}
59 | tgnames = [str(i) for i in range(1,cl.shape[1]+1)]
60 | fig,axes=visualize.multiheatmap(Z, ZX2@cl, (3,5), s=10, **hmkw)
61 | visualize.set_titles(fig, tgnames, x=0.05, y=.85, fontsize="small", c="white",
62 | ha="left", va="top")
63 | fig.savefig(path.join(plt_pth,"xyz_liv_heatmap_scanpy_clusters.pdf"),
64 | bbox_inches='tight')
65 |
66 | #%% Hotspot analysis (gene clusters)
67 | #https://hotspot.readthedocs.io/en/latest/Spatial_Tutorial.html
68 | J = 2000
69 | dfile = path.join(dpth,"xyzeq_liver_L20C1_mouseonly_J{}.h5ad".format(J))
70 | adata = sc.read_h5ad(dfile)
71 | adata.layers["counts"] = adata.layers["counts"].tocsc()
72 | hs = Hotspot(adata, layer_key="counts", model="danb",
73 | latent_obsm_key="spatial", umi_counts_obs_key="total_counts")
74 | hs.create_knn_graph(weighted_graph=False, n_neighbors=20)
75 | hs_results = hs.compute_autocorrelations()
76 | # hs_results.tail()
77 | hs_genes = hs_results.index#[hs_results.FDR < 0.05]
78 | lcz = hs.compute_local_correlations(hs_genes)
79 | modules = hs.create_modules(min_gene_threshold=20, core_only=False,
80 | fdr_threshold=0.05)
81 | # modules.value_counts()
82 | hs_results = hs_results.join(modules,how="left")
83 | hs_results.to_csv(path.join(rpth,"hotspot.csv"))
84 | misc.pickle_to_file(hs,path.join(mpth,"hotspot.pickle"))
85 |
86 | #%% Hotspot module scores (optional)
87 | import mplscience
88 | module_scores = hs.calculate_module_scores()
89 | # module_scores.head()
90 | module_cols = []
91 | for c in module_scores.columns:
92 | key = f"Module {c}"
93 | adata.obs[key] = module_scores[c]
94 | module_cols.append(key)
95 | with mplscience.style_context():
96 | sc.pl.spatial(adata, color=module_cols, frameon=False, vmin="p0", vmax="p99", spot_size=1)
97 |
--------------------------------------------------------------------------------
/scrna/xyzeq_liver/results/benchmark.csv:
--------------------------------------------------------------------------------
1 | V,L,model,kernel,M,lik,sz,key,converged,epochs,ptime,wtime,elbo_avg_tr,dev_tr_mean,dev_tr_argmax,dev_tr_max,dev_tr_med,rmse_tr,dev_val_mean,dev_val_argmax,dev_val_max,dev_val_med,rmse_val,sparsity,elbo_avg_val
2 | 5,6,MEFISTO,ExponentiatedQuadratic,288,gau,constant,V5/L6/gau/MEFISTO_ExponentiatedQuadratic_M288,TRUE,33,475.0265762,40.65122175,-577.3475467,4.134167,1,15.144811,4.1147914,1.9141325,4.113402888,1,17.16539548,4.098728814,1.963800171,0.22425,
3 | 5,12,MEFISTO,ExponentiatedQuadratic,288,gau,constant,V5/L12/gau/MEFISTO_ExponentiatedQuadratic_M288,TRUE,33,980.6648019,85.63033605,-582.2198654,4.120909,1,15.31339,4.101304,1.9123095,4.113208661,1,16.21541236,4.104301253,1.968408833,0.602541667,
4 | 5,20,MEFISTO,ExponentiatedQuadratic,288,gau,constant,V5/L20/gau/MEFISTO_ExponentiatedQuadratic_M288,TRUE,34,1491.377479,129.0016003,-589.3510638,4.136579,1,15.082991,4.117299,1.9100683,4.143368165,1,16.62485705,4.128560757,1.971220716,0.76875,
5 | 5,6,PNMF,,,nb,constant,V5/L6/nb_sz-constant/PNMF,TRUE,920,2524.103027,246.219986,-1720.036011,5.041337,0,6.126661,5.0378885,1.7476368,7.603628159,3,19.91540909,7.524975777,3.220266819,0.198833333,-1798.789551
6 | 5,6,PNMF,,,nb,scanpy,V5/L6/nb_sz-scanpy/PNMF,TRUE,940,2491.810059,243.1531219,-1517.83667,5.0389776,0,6.230536,5.0354776,1.7511828,6.049181461,1,34.24664688,5.948660374,5.179713249,0.203083333,-1606.150024
7 | 5,6,PNMF,,,poi,constant,V5/L6/poi_sz-constant/PNMF,TRUE,360,301.4029236,39.87991333,-2776.731445,5.0833573,5,5.1073666,5.083555,1.6904502,8.216876984,0,10.31153202,8.245800018,2.171020985,0.2005,-2931.80835
8 | 5,6,PNMF,,,poi,scanpy,V5/L6/poi_sz-scanpy/PNMF,TRUE,330,226.1258087,30.06169701,-2404.501709,5.074917,5,5.0900927,5.074988,1.6942964,6.917968273,0,13.38948631,6.929588318,2.131427288,0.202416667,-2541.011475
9 | 5,12,PNMF,,,nb,constant,V5/L12/nb_sz-constant/PNMF,TRUE,940,2676.507324,259.4359436,-1683.691284,4.868257,0,7.232113,4.8613777,1.7105677,7.549126148,3,27.5600853,7.375882626,6.447946072,0.326208333,-1761.467041
10 | 5,12,PNMF,,,nb,scanpy,V5/L12/nb_sz-scanpy/PNMF,TRUE,910,2454.509033,245.877243,-1494.77002,4.867705,0,7.1815634,4.8610163,1.7017567,6.839495182,3,54.47519684,6.406699181,21.46104813,0.338625,-1582.869751
11 | 5,12,PNMF,,,poi,constant,V5/L12/poi_sz-constant/PNMF,TRUE,370,270.802063,35.36241531,-2704.347412,4.9247146,0,5.0390167,4.9248867,1.5932386,8.489975929,0,12.06497669,8.515666962,2.344451904,0.320958333,-2869.649658
12 | 5,12,PNMF,,,poi,scanpy,V5/L12/poi_sz-scanpy/PNMF,TRUE,370,261.4364014,35.36388016,-2378.112793,4.9237103,0,5.0717387,4.9236794,1.5926809,7.006500721,0,11.88849831,7.025444031,2.141976118,0.322875,-2502.074951
13 | 5,20,PNMF,,,nb,constant,V5/L20/nb_sz-constant/PNMF,TRUE,930,2510.842773,246.0563812,-1652.982666,4.718672,2,8.630838,4.705968,1.9244496,7.915697575,1,12.51301765,7.900341988,2.530363083,0.41855,-1709.746704
14 | 5,20,PNMF,,,nb,scanpy,V5/L20/nb_sz-scanpy/PNMF,TRUE,900,2409.625244,236.7058563,-1468.537842,4.7337065,2,8.253958,4.721704,1.8141147,6.628312588,1,13.39372826,6.609304428,2.209197044,0.4285,-1550.239258
15 | 5,20,PNMF,,,poi,constant,V5/L20/poi_sz-constant/PNMF,TRUE,370,271.4530029,36.1639328,-2636.20874,4.788446,4,4.8216515,4.788552,1.4579074,8.605758667,0,11.07707977,8.640348434,2.193706274,0.42125,-2758.330078
16 | 5,20,PNMF,,,poi,scanpy,V5/L20/poi_sz-scanpy/PNMF,TRUE,320,221.6009521,30.11948395,-2283.694824,4.799126,4,4.849787,4.7993846,1.4611748,7.195641518,0,11.38509369,7.226406574,2.085646868,0.424725,-2395.852295
17 | 5,6,NSF,MaternThreeHalves,288,gau,constant,V5/L6/gau/NSF_MaternThreeHalves_M288,TRUE,330,514.3267212,67.1272583,-22716.13086,4.0782957,1,16.631784,4.0586843,1.9377972,4.038724422,1,19.72306061,4.026765347,2.007090092,0.182166667,-24909.04883
18 | 5,6,NSF,MaternThreeHalves,288,nb,constant,V5/L6/nb_sz-constant/NSF_MaternThreeHalves_M288,TRUE,1080,4598.710449,448.9979553,-1197.514282,6.1589003,1,6.669904,6.1590214,1.9816872,6.435605526,0,7.976939678,6.452908039,2.02231884,0.181416667,-1290.170288
19 | 5,6,NSF,MaternThreeHalves,288,nb,scanpy,V5/L6/nb_sz-scanpy/NSF_MaternThreeHalves_M288,TRUE,980,3909.734863,400.6589661,-1157.806885,5.4424324,0,7.322513,5.4386325,1.8403312,5.85012722,0,8.568701744,5.868012428,1.899519563,0.1945,-1247.425659
20 | 5,6,NSF,MaternThreeHalves,288,poi,constant,V5/L6/poi_sz-constant/NSF_MaternThreeHalves_M288,TRUE,190,319.7627563,41.72207642,-1945.525513,6.1958456,0,6.313511,6.1993403,1.9779277,6.510575294,0,7.618911266,6.531297684,2.039809704,0.204666667,-2105.933594
21 | 5,6,NSF,MaternThreeHalves,288,poi,scanpy,V5/L6/poi_sz-scanpy/NSF_MaternThreeHalves_M288,TRUE,160,271.1428833,36.14931488,-1750.665161,5.4353385,0,5.6069202,5.436984,1.7997899,5.829191208,0,7.138292313,5.85113287,1.86504066,0.228416667,-1879.915771
22 | 5,12,NSF,MaternThreeHalves,288,gau,constant,V5/L12/gau/NSF_MaternThreeHalves_M288,TRUE,330,729.2949829,81.42583466,-22867.52734,4.06159,1,16.795048,4.042221,1.9832529,4.002580643,1,22.01260757,3.99038434,2.108330011,0.29475,-25116.64063
23 | 5,12,NSF,MaternThreeHalves,288,nb,constant,V5/L12/nb_sz-constant/NSF_MaternThreeHalves_M288,TRUE,1060,5142.533691,495.576355,-1193.328491,6.0695033,0,12.234192,6.0640707,2.4376976,6.468214989,0,14.01273441,6.480208397,2.521373749,0.247791667,-1316.597778
24 | 5,12,NSF,MaternThreeHalves,288,nb,scanpy,V5/L12/nb_sz-scanpy/NSF_MaternThreeHalves_M288,TRUE,2570,12512.71582,1208.075195,-1247.108032,5.347305,0,13.777866,5.3390913,2.9746828,5.852240086,0,16.2920723,5.864924431,2.987389803,0.245,-1377.313843
25 | 5,12,NSF,MaternThreeHalves,288,poi,constant,V5/L12/poi_sz-constant/NSF_MaternThreeHalves_M288,TRUE,170,395.9271851,46.07463837,-1936.16333,6.1576486,0,6.3947344,6.162737,1.9496453,6.505647659,0,8.054733276,6.526143074,2.027067661,0.312333333,-2123.313965
26 | 5,12,NSF,MaternThreeHalves,288,poi,scanpy,V5/L12/poi_sz-scanpy/NSF_MaternThreeHalves_M288,TRUE,210,481.3469238,61.67857361,-1740.683472,5.406471,18,5.4939284,5.4104576,1.7529149,5.860135555,0,7.361698627,5.884509563,1.859587193,0.35775,-1902.700195
27 | 5,20,NSF,MaternThreeHalves,288,gau,constant,V5/L20/gau/NSF_MaternThreeHalves_M288,TRUE,330,966.2658691,107.8690338,-23064.33203,4.0386744,1,15.490185,4.017086,2.0383823,3.966190338,1,16.15410233,3.954101324,2.110280037,0.390325,-25329.31445
28 | 5,20,NSF,MaternThreeHalves,288,nb,constant,V5/L20/nb_sz-constant/NSF_MaternThreeHalves_M288,TRUE,1090,5905.898926,600.9832153,-1186.449829,5.9732676,0,18.473352,5.9590197,3.865973,6.459947586,0,20.63406563,6.465190887,3.858034372,0.28045,-1337.837524
29 | 5,20,NSF,MaternThreeHalves,288,nb,scanpy,V5/L20/nb_sz-scanpy/NSF_MaternThreeHalves_M288,TRUE,970,5574.790527,538.3094482,-1146.770874,5.2692337,0,20.756817,5.250762,5.047426,5.878831387,0,25.57594109,5.885680199,4.886908531,0.285075,-1294.078247
30 | 5,20,NSF,MaternThreeHalves,288,poi,constant,V5/L20/poi_sz-constant/NSF_MaternThreeHalves_M288,TRUE,130,421.626709,48.69716644,-1932.29895,6.166916,226,6.225311,6.1753907,1.9243623,6.500660896,0,7.525696278,6.523514748,1.995321751,0.35825,-2144.512207
31 | 5,20,NSF,MaternThreeHalves,288,poi,scanpy,V5/L20/poi_sz-scanpy/NSF_MaternThreeHalves_M288,TRUE,220,686.5084839,74.90756226,-1719.314697,5.290785,0,5.424759,5.290799,1.684984,5.8166008,0,7.393985748,5.837240696,1.810169578,0.429025,-1924.939697
32 | 5,6,NSFH,MaternThreeHalves,288,nb,constant,V5/L6/nb_sz-constant/NSFH_T3_MaternThreeHalves_M288,TRUE,1010,4063.513672,400.4521179,-1519.04895,5.208853,3,31.806583,5.1926947,2.4516542,6.999220848,3,45.26802063,6.974718094,2.682162762,0.211166667,-1612.036987
33 | 5,6,NSFH,MaternThreeHalves,288,nb,scanpy,V5/L6/nb_sz-scanpy/NSFH_T3_MaternThreeHalves_M288,TRUE,950,3661.553711,372.4749756,-1265.884644,5.2266626,0,7.021208,5.224578,1.7982568,6.028244019,0,8.544682503,6.021720886,2.196630001,0.24175,-1350.790039
34 | 5,6,NSFH,MaternThreeHalves,288,poi,constant,V5/L6/poi_sz-constant/NSFH_T3_MaternThreeHalves_M288,TRUE,550,810.0413208,105.2309952,-2577.969727,5.2439775,3,10.560041,5.2404513,1.8573103,7.400186062,3,17.4361763,7.400484085,2.282255411,0.21225,-2746.791992
35 | 5,6,NSFH,MaternThreeHalves,288,poi,scanpy,V5/L6/poi_sz-scanpy/NSFH_T3_MaternThreeHalves_M288,TRUE,230,344.7897034,47.61157227,-1875.325195,5.2310743,0,5.37592,5.231882,1.7471918,6.22217989,0,7.375682354,6.249918938,1.910659194,0.287166667,-1992.292114
36 | 5,12,NSFH,MaternThreeHalves,288,nb,constant,V5/L12/nb_sz-constant/NSFH_T6_MaternThreeHalves_M288,TRUE,960,4144.26123,408.4788513,-1505.599365,5.036107,0,35.802643,5.0135765,2.824309,7.508047104,0,42.9197998,7.464859009,3.275102854,0.295041667,-1601.634155
37 | 5,12,NSFH,MaternThreeHalves,288,nb,scanpy,V5/L12/nb_sz-scanpy/NSFH_T6_MaternThreeHalves_M288,TRUE,940,3446.153076,343.68396,-1249.673828,5.1136417,0,14.9740715,5.107809,1.995267,6.045698643,0,22.48182297,6.050631046,2.206949472,0.332375,-1340.674072
38 | 5,12,NSFH,MaternThreeHalves,288,poi,constant,V5/L12/poi_sz-constant/NSFH_T6_MaternThreeHalves_M288,TRUE,480,724.5327148,89.72132111,-2447.5979,5.072291,0,6.2428656,5.071582,1.6995671,7.525919437,0,10.80947208,7.541521549,2.137618065,0.310208333,-2571.26001
39 | 5,12,NSFH,MaternThreeHalves,288,poi,scanpy,V5/L12/poi_sz-scanpy/NSFH_T6_MaternThreeHalves_M288,TRUE,260,421.4099426,55.64730453,-1839.009766,5.132295,0,6.134202,5.132534,1.6707282,6.011397362,0,10.5695591,6.034472466,1.912448764,0.378625,-1958.931885
40 | 5,20,NSFH,MaternThreeHalves,288,nb,constant,V5/L20/nb_sz-constant/NSFH_T10_MaternThreeHalves_M288,TRUE,930,3572.945557,359.5145569,-1486.776611,4.953084,0,11.2164545,4.943058,2.6832116,7.717191219,0,16.1084938,7.714572906,2.997145176,0.37005,-1586.578613
41 | 5,20,NSFH,MaternThreeHalves,288,nb,scanpy,V5/L20/nb_sz-scanpy/NSFH_T10_MaternThreeHalves_M288,TRUE,930,3610.885986,363.6177673,-1275.165161,4.9190106,0,26.494144,4.9006205,2.5711167,6.319021702,0,33.04694748,6.318647385,2.331879377,0.40405,-1384.026123
42 | 5,20,NSFH,MaternThreeHalves,288,poi,constant,V5/L20/poi_sz-constant/NSFH_T10_MaternThreeHalves_M288,TRUE,330,664.0027466,85.68494415,-2215.143066,5.0840487,0,5.842009,5.0865946,1.6648625,7.702683449,0,8.923201561,7.737083435,2.080332994,0.4205,-2381.046387
43 | 5,20,NSFH,MaternThreeHalves,288,poi,scanpy,V5/L20/poi_sz-scanpy/NSFH_T10_MaternThreeHalves_M288,TRUE,300,586.2957764,70.1000824,-1871.017578,5.001067,0,5.0868735,5.0011663,1.601763,6.247176647,0,7.752833366,6.273864746,1.897039056,0.4557,-2026.676025
44 | 5,6,FA,,,gau,constant,V5/L6/gau/FA,TRUE,330,259.3052979,36.04651642,-25486.67188,3.684314,1,12.101881,3.6632836,1.8165487,4.752997875,1,21.94866753,4.741261005,2.187690258,0,-27738.12109
45 | 5,12,FA,,,gau,constant,V5/L12/gau/FA,TRUE,520,374.788208,47.83795547,-26551.11719,3.585647,1,13.961616,3.5635648,1.7608483,4.825815201,1,21.52416992,4.813059807,2.269764185,0.000125,-28805.42578
46 | 5,20,FA,,,gau,constant,V5/L20/gau/FA,TRUE,560,409.4586487,53.65220261,-27751.89844,3.5128248,1,14.16262,3.4883509,1.7047457,4.784780502,1,21.31032944,4.769587517,2.233421087,0.0001,-30139.4668
47 | 5,6,RSF,MaternThreeHalves,288,gau,constant,V5/L6/gau/RSF_MaternThreeHalves_M288,TRUE,330,515.1203613,66.48441315,-23721.80469,4.062435,1,15.900438,4.042738,1.9017575,3.97942543,1,18.46582603,3.969529152,1.962098241,0,-25971.41602
48 | 5,12,RSF,MaternThreeHalves,288,gau,constant,V5/L12/gau/RSF_MaternThreeHalves_M288,TRUE,330,710.8594971,82.67900085,-23931.0625,4.034443,1,15.513014,4.014183,1.8843172,3.930668831,1,18.57084656,3.918994904,1.947594881,0.000125,-26235.27734
49 | 5,20,RSF,MaternThreeHalves,288,gau,constant,V5/L20/gau/RSF_MaternThreeHalves_M288,TRUE,330,823.336853,91.08987427,-24120.2793,4.0221286,1,15.5932865,4.0022526,1.865359,3.934701681,1,17.80091286,3.926838636,1.922031164,0,-26500.32617
--------------------------------------------------------------------------------
/simulations/.gitignore:
--------------------------------------------------------------------------------
1 | models
2 |
--------------------------------------------------------------------------------
/simulations/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/willtownes/nsf-paper/0cacf8352e09d223ab8d4421025195358bbde8df/simulations/__init__.py
--------------------------------------------------------------------------------
/simulations/benchmark.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Example usage:
5 |
6 | python -m simulations.benchmark 2 simulations/bm_sp
7 |
8 | Fit a model to a dataset and save the results.
9 |
10 | Slightly different from utils.benchmark, which assumes there is only one possible
11 | data file under the parent directory. Here, we allow the data directory to contain
12 | multiple H5AD files and ensure the saved models indicate the scenario in the "key" (filepath).
13 |
14 | Result:
15 | * Loads data, fits a model
16 | * pickles a fitted ModelTrainer object under [dataset]/models/ directory
17 |
18 | Pickled model naming conventions:
19 |
20 | file scheme for spatial models:
21 | [dataset]/models/S[scenario]/V[val frac]/L[factors]/[likelihood]/[model]_[kernel]_M[inducing_pts]/epoch[epoch].pickle
22 |
23 | file scheme for nonspatial models:
24 | [dataset]/models/S[scenario]/V[val frac]/L[factors]/[model]/epoch[epoch].pickle
25 |
26 | @author: townesf
27 | """
28 | from os import path
29 | from argparse import ArgumentParser
30 | from utils.misc import read_csv_oneline
31 | from utils import benchmark as ubm
32 |
33 | def benchmark(ID,pth):
34 | """
35 | Run benchmarking on dataset for the model specified in benchmark.csv in row ID.
36 | """
37 | # dsplit = dataset.split("/data/")
38 | # pth = dsplit[0]
39 | csv_file = path.join(pth,"results/benchmark.csv")
40 | #header of CSV is row zero
41 | p = read_csv_oneline(csv_file,ID-1)
42 | opath = path.join(pth,"models",p["key"]) #p["key"] includes "S{}/" for simulations
43 | print("{}".format(p["key"]))
44 | if path.isfile(path.join(opath,"converged.pickle")):
45 | print("Benchmark already complete, exiting.")
46 | return None
47 | else:
48 | print("Starting benchmark.")
49 | train_frac = ubm.val2train_frac(p["V"])
50 | dataset = path.join(pth,"data/S{}.h5ad".format(p["scenario"])) #different in simulations
51 | D,fmeans = ubm.load_data(dataset,model=p['model'],lik=p['lik'],sz=p['sz'],
52 | train_frac=train_frac, flip_yaxis=False)
53 | fit = ubm.init_model(D,p,opath,fmeans=fmeans)
54 | tro = ubm.fit_model(D,fit,p,opath)
55 | return tro
56 |
57 | def arghandler(args=None):
58 | """parses a list of arguments (default is sys.argv[1:])"""
59 | parser = ArgumentParser()
60 | parser.add_argument("id", type=int,
61 | help="line in benchmark csv from which to get parameters")
62 | parser.add_argument("path", type=str,
63 |                         help="top level directory containing subfolders 'data', 'models', and 'results'.")
64 | args = parser.parse_args(args) #if args is None, this will automatically parse sys.argv[1:]
65 | return args
66 |
67 | if __name__=="__main__":
68 | # #input from argparse
69 | # ID = 2 #slurm job array uses 1-based indexing (2,3,...,43)
70 | # #start with 2 to avoid header row of CSV
71 | # DATASET = "simulations/bm_sp/data/S1.h5ad"
72 | args = arghandler()
73 | tro = benchmark(args.id, args.path)
74 |
--------------------------------------------------------------------------------
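
As a rough illustration of the pickle-path naming scheme described in the `simulations/benchmark.py` docstring above: the repo's actual mapping comes from `misc.params2key` in `utils/misc.py`, so the hypothetical `example_key` helper below is only a sketch (for instance, it ignores the `T` suffix that appears in NSFH keys).

```
# Hypothetical sketch of the key naming convention (not the repo's implementation;
# see utils/misc.py params2key for the real logic).
def example_key(row):
    lik = row["lik"]
    sz = "" if lik == "gau" else "_sz-" + row["sz"]
    if row["model"] in ("NSF", "NSFH", "RSF", "MEFISTO"):  # spatial models
        leaf = "{}_{}_M{}".format(row["model"], row["kernel"], row["M"])
    else:  # nonspatial models such as PNMF or FA
        leaf = row["model"]
    return "S{}/V{}/L{}/{}{}/{}".format(row["scenario"], row["V"], row["L"],
                                        lik, sz, leaf)

row = {"scenario": 1, "V": 5, "L": 8, "lik": "poi", "sz": "constant",
       "model": "NSF", "kernel": "MaternThreeHalves", "M": 1296}
print(example_key(row))  # S1/V5/L8/poi_sz-constant/NSF_MaternThreeHalves_M1296
```
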
/simulations/benchmark.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=benchmark # create a short name for your job
3 | #SBATCH --nodes=1 # node count
4 | #SBATCH --ntasks=1 # total number of tasks across all nodes
5 | #SBATCH --cpus-per-task=4 # cpu-cores per task (>1 if multi-threaded tasks)
6 | #SBATCH --time=01:02:00 # total run time limit (HH:MM:SS)
7 | #SBATCH --mail-type=END,FAIL
8 | #SBATCH --mail-user=ftownes@princeton.edu
9 |
10 | #example usage, default is 4G memory per core (--mem-per-cpu), --mem is total
11 | #CSV=./simulations/bm_sp/results/benchmark.csv
12 | #PTH=./simulations/bm_sp
13 | #sbatch --mem=16G --array=1-$(wc -l < $CSV) ./simulations/benchmark.slurm $PTH
14 |
15 | module purge
16 | module load anaconda3/2021.5
17 | conda activate fwt
18 |
19 | #first command line arg $1 is file path to parent directory of dataset
20 | python -um simulations.benchmark $SLURM_ARRAY_TASK_ID $1
21 |
--------------------------------------------------------------------------------
/simulations/benchmark_gof.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | python -m simulations.benchmark_gof simulations/bm_sp
5 | """
6 | import numpy as np
7 | import pandas as pd
8 | from os import path
9 | from argparse import ArgumentParser
10 | from contextlib import suppress
11 | from scanpy import read_h5ad
12 | from scipy.spatial.distance import pdist,cdist
13 | from scipy.stats import pearsonr,spearmanr
14 | from sklearn.cluster import KMeans
15 | from sklearn import metrics
16 |
17 | from utils import postprocess
18 | from utils.preprocess import load_data
19 | from utils import benchmark as ubm
20 |
21 | def compare_to_truth(ad, Ntr, fit, model):
22 | ad = ad[:Ntr,:]
23 | X = ad.obsm["spatial"]
24 | #extract ground truth factors and loadings
25 | tru = postprocess.interpret_nonneg(ad.obsm["spfac"], ad.varm["spload"],
26 | lda_mode=False)
27 | F0 = tru["factors"]
28 | W0 = tru["loadings"]
29 | FFd0 = pdist(F0) #cell-cell distances (vectorized)
30 | WWd0 = pdist(W0) #gene-gene distances (vectorized)
31 | ifit = postprocess.interpret_fit(fit,X,model)
32 | F = ifit["factors"]
33 | W = ifit["loadings"]
34 | Fpc = cdist(F0.T,F.T,metric=lambda x,y: abs(pearsonr(x,y)[0])).max(axis=1)
35 | Fsc = cdist(F0.T,F.T,metric=lambda x,y: abs(spearmanr(x,y)[0])).max(axis=1)
36 | Wpc = cdist(W0.T,W.T,metric=lambda x,y: abs(pearsonr(x,y)[0])).max(axis=1)
37 | Wsc = cdist(W0.T,W.T,metric=lambda x,y: abs(spearmanr(x,y)[0])).max(axis=1)
38 | res = {}
39 | res["factors_pearson"] = Fpc.tolist()
40 | res["factors_spearman"] = Fsc.tolist()
41 | res["loadings_pearson"] = Wpc.tolist()
42 | res["loadings_spearman"] = Wsc.tolist()
43 | FFd = pdist(F)
44 | WWd = pdist(W)
45 | res["dfactors_pearson"] = pearsonr(FFd0,FFd)[0]
46 | res["dfactors_spearman"] = spearmanr(FFd0,FFd)[0]
47 | res["dloadings_pearson"] = pearsonr(WWd0,WWd)[0]
48 | res["dloadings_spearman"] = spearmanr(WWd0,WWd)[0]
49 | nclust = W0.shape[1]
50 | km0 = KMeans(n_clusters=nclust).fit(W0).labels_
51 | km1 = KMeans(n_clusters=nclust).fit(W).labels_
52 | res["loadings_clust_ari"] = metrics.adjusted_rand_score(km0, km1)
53 | return res
54 |
55 | def compare_to_truth_mixed(ad, Ntr, fit, model):
56 | ad = ad[:Ntr,:].copy()
57 | X = ad.obsm["spatial"]
58 | #extract factors and loadings
59 | tru = postprocess.interpret_nonneg_mixed(ad.obsm["spfac"], ad.varm["spload"],
60 | ad.obsm["nsfac"], ad.varm["nsload"],
61 | lda_mode=False)
62 | F0 = tru["spatial"]["factors"]
63 | W0 = tru["spatial"]["loadings"]
64 | H0 = tru["nonspatial"]["factors"]
65 | V0 = tru["nonspatial"]["loadings"]
66 | FH0 = np.concatenate((F0,H0),axis=1)
67 | WV0 = np.concatenate((W0,V0),axis=1)
68 | alpha0 = W0.sum(axis=1) #true spatial importances
69 | ifit = postprocess.interpret_fit(fit,X,model)
70 | if model == "NSFH":
71 | F = ifit["spatial"]["factors"]
72 | W = ifit["spatial"]["loadings"]
73 | H = ifit["nonspatial"]["factors"]
74 | V = ifit["nonspatial"]["loadings"]
75 | FH = np.concatenate((F,H),axis=1)
76 | WV = np.concatenate((W,V),axis=1)
77 | alpha = W.sum(axis=1) #estimated spatial importances
78 | else:
79 | FH = ifit["factors"]
80 | WV = ifit["loadings"]
81 | if model =="NSF":
82 | alpha = np.ones_like(alpha0)
83 | elif model=="PNMF":
84 | alpha = np.zeros_like(alpha0)
85 | else:
86 | alpha = None
87 | Fpc = cdist(FH0.T,FH.T,metric=lambda x,y: abs(pearsonr(x,y)[0])).max(axis=1)
88 | Fsc = cdist(FH0.T,FH.T,metric=lambda x,y: abs(spearmanr(x,y)[0])).max(axis=1)
89 | Wpc = cdist(WV0.T,WV.T,metric=lambda x,y: abs(pearsonr(x,y)[0])).max(axis=1)
90 | Wsc = cdist(WV0.T,WV.T,metric=lambda x,y: abs(spearmanr(x,y)[0])).max(axis=1)
91 | res = {}
92 | res["factors_pearson"] = Fpc.tolist()
93 | res["factors_spearman"] = Fsc.tolist()
94 | res["loadings_pearson"] = Wpc.tolist()
95 | res["loadings_spearman"] = Wsc.tolist()
96 | if alpha is None:
97 | sp_imp_dist = None
98 | else:
99 | sp_imp_dist = np.abs(alpha-alpha0).sum() #L1 distance
100 | res["spatial_importance_dist"] = sp_imp_dist
101 | nclust = WV0.shape[1]
102 | km0 = KMeans(n_clusters=nclust).fit(WV0).labels_
103 | km1 = KMeans(n_clusters=nclust).fit(WV).labels_
104 | res["loadings_clust_ari"] = metrics.adjusted_rand_score(km0, km1)
105 | return res
106 |
107 | def row_metrics(row,pth,verbose=True,mode="bm_sp"):
108 | row = dict(row) #originally row is a pandas.Series object
109 | train_frac = ubm.val2train_frac(row["V"])
110 | if not "converged" in row or not row['converged']:# or row[mnames].isnull().any():
111 | dataset = path.join(pth,"data","S{}.h5ad".format(row["scenario"]))
112 | ad = read_h5ad(path.normpath(dataset))
113 | pkl = path.join(pth,"models",row["key"])
114 | if row["sz"]=="scanpy":
115 | D,fmeans = load_data(ad, model="NSF", lik="poi", sz="scanpy",
116 | train_frac=train_frac, flip_yaxis=False)
117 | else:
118 | D,fmeans = load_data(ad, model=None, lik=None, sz="constant",
119 | train_frac=train_frac, flip_yaxis=False)
120 | with suppress(FileNotFoundError):
121 | fit,tro = ubm.load(pkl)
122 | if verbose: print(row["key"])
123 | metrics = ubm.get_metrics(fit,D["raw"]["tr"],D["raw"]["val"],tro=tro)
124 | row.update(metrics)
125 | Ntr = D["raw"]["tr"]["X"].shape[0]
126 | if mode=="bm_sp":
127 | metrics2 = compare_to_truth(ad, Ntr, fit, row["model"])
128 | elif mode=="bm_mixed":
129 | metrics2 = compare_to_truth_mixed(ad, Ntr, fit, row["model"])
130 | else:
131 | raise ValueError("mode must be either 'bm_sp' or 'bm_mixed'")
132 | row.update(metrics2)
133 | return row
134 |
135 | def update_results(pth,todisk=True,verbose=True):
136 | """
137 |     Differs from utils.benchmark.update_results because it loads a separate
138 |     dataset and model for each row of the results CSV, to accommodate multiple
139 |     simulation scenarios.
140 | """
141 | # pth = "simulations/bm_sp"
142 | csv_file = path.join(pth,"results/benchmark.csv")
143 | res = pd.read_csv(csv_file)
144 | res = res.apply(row_metrics, args=(pth,), axis=1, result_type="expand",
145 | verbose=verbose, mode=path.split(pth)[-1])
146 | if todisk: res.to_csv(csv_file,index=False)
147 | return res
148 |
149 | def arghandler(args=None):
150 | """parses a list of arguments (default is sys.argv[1:])"""
151 | parser = ArgumentParser()
152 | parser.add_argument("path", type=str,
153 | help="top level directory containing subfolders 'data', 'models', and 'results'.")
154 | args = parser.parse_args(args) #if args is None, this will automatically parse sys.argv[1:]
155 | return args
156 |
157 | if __name__=="__main__":
158 | args = arghandler()
159 | res = update_results(args.path, todisk=True)
160 |
--------------------------------------------------------------------------------
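
A self-contained toy example (with simulated stand-in data, not drawn from the repo) of the matching step used in `compare_to_truth` above: each true factor is paired with the inferred factor that maximizes the absolute Pearson correlation, so a column permutation of the truth still scores close to 1.

```
# Toy illustration of max-|Pearson| factor matching (assumed data).
import numpy as np
from scipy.spatial.distance import cdist
from scipy.stats import pearsonr

rng = np.random.default_rng(0)
F0 = rng.poisson(5.0, size=(100, 3)).astype(float)        # "true" factors, N x L
F = F0[:, [2, 0, 1]] + rng.normal(0, 0.5, size=(100, 3))  # noisy, permuted "inferred" factors

# For each true factor (row of F0.T), keep the best |Pearson r| over all inferred factors.
best_r = cdist(F0.T, F.T, metric=lambda x, y: abs(pearsonr(x, y)[0])).max(axis=1)
print(best_r)  # all three entries should be near 1 despite the permutation
```
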
/simulations/benchmark_gof.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=benchmark_gof # create a short name for your job
3 | #SBATCH --nodes=1 # node count
4 | #SBATCH --ntasks=1 # total number of tasks across all nodes
5 | #SBATCH --cpus-per-task=4 # cpu-cores per task (>1 if multi-threaded tasks)
6 | #SBATCH --time=00:20:00 # total run time limit (HH:MM:SS)
7 | #SBATCH --mail-type=END,FAIL
8 | #SBATCH --mail-user=ftownes@princeton.edu
9 |
10 | #example usage, --mem-per-cpu default is 4G per core
11 | #PTH=./simulations/bm_sp
12 | #sbatch --mem=16G ./simulations/benchmark_gof.slurm $PTH
13 |
14 | module purge
15 | module load anaconda3/2021.5
16 | conda activate fwt
17 |
18 | #first command line arg $1 is file path to parent directory
19 | python -um simulations.benchmark_gof $1
20 |
--------------------------------------------------------------------------------
/simulations/bm_mixed/01_data_generation.ipy:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # ---
3 | # jupyter:
4 | # jupytext:
5 | # text_representation:
6 | # extension: .ipy
7 | # format_name: percent
8 | # format_version: '1.3'
9 | # jupytext_version: 1.6.0
10 | # kernelspec:
11 | # display_name: Python 3
12 | # language: python
13 | # name: python3
14 | # ---
15 |
16 | #%% imports
17 | from os import path
18 | from copy import deepcopy
19 | import numpy as np
20 | import pandas as pd
21 | from scanpy import read_h5ad
22 | from janitor import expand_grid
23 | from utils import visualize,misc
24 | from simulations import sim
25 | pth = "simulations/bm_mixed"
26 | dpth = path.join(pth,"data")
27 | rpth = path.join(pth,"results")
28 | mpth = path.join(pth,"models")
29 | # misc.mkdir_p(dpth) # or use symlink to dropbox
30 | # misc.mkdir_p(rpth)
31 | # misc.mkdir_p(mpth)
32 |
33 | # %%
34 | cfg = {"sim":["quilt","ggblocks","both"], "nside":36, "nzprob_nsp":0.2,
35 | "bkg_mean":0.2, "nb_shape":10.0,
36 | "J":[(250,0,250),(0,500,0)], #"Jsp":0, "Jmix":500, "Jns":0,
37 | "expr_mean":20.0, "mix_frac_spat":0.6,
38 | "seed":[1,2,3,4,5], "V":5}
39 | a = expand_grid(others=cfg)
40 | a.rename(columns={"J_0":"Jsp", "J_1":"Jmix", "J_2":"Jns"}, inplace=True)
41 | b = pd.DataFrame({"sim":["quilt","ggblocks","both"],
42 | "Lsp":[4,4,8],"Lns":[3,3,6]})
43 | a = a.merge(b,how="left",on="sim")
44 | a["scenario"] = list(range(1,a.shape[0]+1))
45 | a.to_csv("simulations/bm_mixed/scenarios.csv",index=False)
46 |
47 | # %% generate the simulated datasets and store to disk
48 | a = pd.read_csv(path.join(pth,"scenarios.csv")).convert_dtypes()
49 | def sim2disk(p):
50 | p = deepcopy(p)
51 | scen = p.pop("scenario")
52 | Lsp = p.pop("Lsp")
53 | ad = sim.sim(p["sim"], **p)
54 | ad.write_h5ad(path.join(dpth,"S{}.h5ad".format(scen)),compression="gzip")
55 | a.apply(sim2disk,axis=1)
56 |
57 | # %% check the hdf5 file is correct
58 | ad = read_h5ad(path.join(dpth,"S1.h5ad"))
59 | X = ad.obsm["spatial"]
60 | Y = ad.layers["counts"]
61 | Yn = ad.X
62 | visualize.heatmap(X,Y[:,0],cmap="Blues")
63 | visualize.heatmap(X,Yn[:,0],cmap="Blues")
64 | #check distribution of validation data points
65 | N = Y.shape[0]
66 | z = np.zeros(N)
67 | Ntr = round(0.95*N)
68 | z[Ntr:] = 1
69 | visualize.heatmap(X,z,cmap="Blues")
70 |
71 | # %% merge with models to make results csv for tracking model runs
72 | m = pd.read_csv(path.join(pth,"models.csv")).convert_dtypes() #this CSV was manually created
73 | d = a.merge(m,how="cross")
74 | d["L"] = d["Lsp"]+d["Lns"]
75 | d["T"] = d["Lsp"]
76 | d["key"] = d.apply(misc.params2key,axis=1)
77 | d["key"] = d.agg(lambda x: f"S{x['scenario']}/{x['key']}", axis=1)
78 | d["converged"] = False
79 | d.to_csv(path.join(rpth,"benchmark.csv"),index=False)
80 |
--------------------------------------------------------------------------------
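
For readers unfamiliar with `janitor.expand_grid`, the stand-in below (my assumption of its behavior for this `cfg`, using only `itertools` and `pandas`) shows the kind of table the scenario grid above expands to: 3 patterns x 2 gene-count tuples x 5 seeds = 30 scenarios, with the `J` tuples split into `Jsp`, `Jmix`, `Jns` columns. It is for illustration only, not the code the repo runs.

```
# Approximate stand-in for janitor.expand_grid on the list-valued entries of cfg.
from itertools import product
import pandas as pd

cfg = {"sim": ["quilt", "ggblocks", "both"],
       "J": [(250, 0, 250), (0, 500, 0)],
       "seed": [1, 2, 3, 4, 5]}
rows = [dict(zip(cfg, combo)) for combo in product(*cfg.values())]
a = pd.DataFrame(rows)
a[["Jsp", "Jmix", "Jns"]] = pd.DataFrame(a.pop("J").tolist())  # split the J tuples
print(a.shape)  # (30, 5): one simulated dataset per scenario
```
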
/simulations/bm_mixed/02_benchmark.ipy:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # ---
3 | # jupyter:
4 | # jupytext:
5 | # text_representation:
6 | # extension: .ipy
7 | # format_name: percent
8 | # format_version: '1.3'
9 | # jupytext_version: 1.6.0
10 | # kernelspec:
11 | # display_name: Python 3
12 | # language: python
13 | # name: python3
14 | # ---
15 |
16 | # %%
17 | from os import path
18 | # from copy import deepcopy
19 | # import numpy as np
20 | import pandas as pd
21 | # from matplotlib import pyplot as plt
22 | # from scipy.spatial.distance import pdist,cdist
23 | # from scipy.stats.stats import pearsonr,spearmanr
24 | # from sklearn.cluster import KMeans
25 | # from sklearn import metrics
26 | from scanpy import read_h5ad
27 | from tensorflow_probability import math as tm
28 | tfk = tm.psd_kernels
29 |
30 | from utils import training,postprocess,visualize
31 | # from simulations import sim
32 | pth = "simulations/bm_mixed"
33 | dpth = path.join(pth,"data")
34 | rpth = path.join(pth,"results")
35 | mpth = path.join(pth,"models")
36 |
37 | #%% Benchmarking at command line [markdown]
38 | """
39 | To run on a local computer, use
40 | `python -m simulations.benchmark 2 simulations/bm_mixed`
41 | where 2 is a row ID of benchmark.csv (minimum 2, since row 1 is the header; maximum 91).
42 |
43 | To run on the cluster, first load the anaconda environment:
44 | ```
45 | tmux
46 | interactive
47 | module load anaconda3/2021.5
48 | conda activate fwt
49 | python -m simulations.benchmark 2 simulations/bm_mixed
50 | ```
51 |
52 | To run on the cluster as a job array over a subset of rows (a 10 minute time limit is recommended):
53 | ```
54 | PTH=./simulations/bm_mixed
55 | sbatch --mem=16G --array=5-91 ./simulations/benchmark.slurm $PTH
56 | ```
57 |
58 | To run on the cluster as a job array over all rows of the CSV file:
59 | ```
60 | CSV=./simulations/bm_mixed/results/benchmark.csv
61 | PTH=./simulations/bm_mixed
62 | sbatch --mem=16G --array=1-$(wc -l < $CSV) ./simulations/benchmark.slurm $PTH
63 | ```
64 | """
65 |
66 | #%% Load dataset and set kernel and IPs
67 | ad = read_h5ad(path.join(dpth,"S1.h5ad"))
68 | #include only the training observations
69 | Ntr = round(0.95*ad.shape[0])
70 | ad = ad[:Ntr,:]
71 | X = ad.obsm["spatial"]
72 | #extract factors and loadings
73 | tru = postprocess.interpret_nonneg_mixed(ad.obsm["spfac"], ad.varm["spload"],
74 | ad.obsm["nsfac"],ad.varm["nsload"],
75 | lda_mode=False)
76 | F0 = tru["spatial"]["factors"]
77 | W0 = tru["spatial"]["loadings"]
78 | alpha = W0.sum(axis=1)
79 | pd1 = pd.DataFrame({"spatial_wt":alpha})
80 | pd1.spatial_wt.hist(bins=100) #spatial importance by feature
81 | #set hyperparams
82 | T = W0.shape[1]
83 | L = T+tru["nonspatial"]["loadings"].shape[1]
84 | M = 1296
85 | ker = tfk.MaternThreeHalves
86 | hmkw = {"figsize":(6,1.5),"s":1.5,"marker":"s","subplot_space":0,
87 | "spinecolor":"gray"}
88 | fig,axes=visualize.multiheatmap(X, tru["spatial"]["factors"], (1,4), cmap="Blues", **hmkw)
89 | fig,axes=visualize.multiheatmap(X, tru["nonspatial"]["factors"], (1,4), cmap="Blues", **hmkw)
90 |
91 | #%% Compare inferred to true factors
92 | pp = path.join(mpth,"S1/V5/L{}/poi_sz-constant/NSFH_T{}_{}_M{}".format(L,T,ker.__name__, M))
93 | tro = training.ModelTrainer.from_pickle(pp)
94 | fit = tro.model
95 | insfh = postprocess.interpret_nsfh(fit,X)
96 | F = insfh["spatial"]["factors"]
97 | fig,axes=visualize.multiheatmap(X, F, (1,4), cmap="Blues", **hmkw)
98 | fig,axes=visualize.multiheatmap(X, insfh["nonspatial"]["factors"], (1,4), cmap="Blues", **hmkw)
99 |
100 | #%% Compute goodness-of-fit metrics
101 | """
102 | ```
103 | python -m simulations.benchmark_gof simulations/bm_mixed
104 | ```
105 | or
106 | ```
107 | PTH=./simulations/bm_mixed
108 | sbatch --mem=16G ./simulations/benchmark_gof.slurm $PTH
109 | ```
110 | """
111 |
--------------------------------------------------------------------------------
/simulations/bm_mixed/03_benchmark_viz.Rmd:
--------------------------------------------------------------------------------
1 | ---
2 | title: "Benchmarking Visualization"
3 | author: "Will Townes"
4 | output: html_document
5 | ---
6 |
7 | ```{r}
8 | library(tidyverse)
9 | theme_set(theme_bw())
10 | fp<-file.path
11 | pth<-"simulations/bm_mixed"
12 | plt_pth<-fp(pth,"results/plots")
13 | if(!dir.exists(plt_pth)){
14 | dir.create(plt_pth,recursive=TRUE)
15 | }
16 | ```
17 |
18 | ```{r}
19 | d<-read.csv(fp(pth,"results/benchmark.csv"))
20 | d$converged<-as.logical(d$converged)
21 | d$model<-factor(d$model,levels=c("PNMF","NSF","NSFH"))
22 | sims<-c("ggblocks","quilt","both")
23 | d$sim<-factor(d$sim,levels=sims)
24 | d$simulation<-plyr::mapvalues(d$sim,sims,c("blocks (4+3)","quilt (4+3)","both (8+6)"))
25 | d$mixed_genes<-d$Jmix>0
26 |
27 | summarize_multicol<-function(x,f){
28 | #x is a character column whose entries are lists of numbers "[3.02, 4.33, 7.71,...]"
29 | #expand each entry into a vector and summarize it by
30 | #a function (f) like min, max, mean
31 | #return a numeric vector with the summary stat, has same length as x
32 | x<-sub("[","c(",x,fixed=TRUE)
33 | x<-sub("]",")",x,fixed=TRUE)
34 | vapply(x,function(t){f(eval(str2lang(t)))},FUN.VALUE=1.0)
35 | }
36 |
37 | d$factors_pearson_min<-summarize_multicol(d$factors_pearson,min)
38 | d$loadings_pearson_min<-summarize_multicol(d$loadings_pearson,min)
39 | d$factors_spearman_min<-summarize_multicol(d$factors_spearman,min)
40 | d$loadings_spearman_min<-summarize_multicol(d$loadings_spearman,min)
41 | ```
42 |
43 | ```{r}
44 | ggplot(d,aes(x=model,y=dev_val_mean,color=simulation,shape=mixed_genes))+geom_point(size=3,position=position_jitterdodge())+scale_y_log10()+ylab("validation deviance (mean)")
45 | ggsave(fp(plt_pth,"bm_mixed_dev_val_mean.pdf"),width=5,height=2.5)
46 |
47 | # ggplot(d,aes(x=model,y=factors_pearson_min,color=simulation,fill=simulation))+geom_boxplot()+scale_y_log10()+ylab("minimum factors correlation")
48 | # ggsave(fp(plt_pth,"bm_mixed_factors_pcor_min.pdf"),width=5,height=2.5)
49 | #
50 | # ggplot(d,aes(x=model,y=loadings_pearson_min,fill=simulation,color=simulation))+geom_boxplot()+scale_y_log10()+ylab("minimum loadings correlation")
51 | # ggsave(fp(plt_pth,"bm_mixed_loadings_pcor_min.pdf"),width=5,height=2.5)
52 |
53 | ggplot(d,aes(x=model,y=spatial_importance_dist,color=simulation,shape=mixed_genes))+geom_point(size=3,position=position_jitterdodge())+ylab("spatial importance distance")
54 | ggsave(fp(plt_pth,"bm_mixed_spat_importance_dist.pdf"),width=5,height=2.5)
55 |
56 | #ggplot(d,aes(x=model,y=loadings_clust_ari,fill=simulation,color=simulation))+geom_boxplot()+scale_y_log10()+ylab("Concordance of feature clusters with truth (ARI)")
57 | ```
58 |
59 | Linear regressions to get statistical significance
60 |
61 | ```{r}
62 | #NSFH vs NSF
63 | d$is_NSFH<-(d$model=="NSFH")
64 | summary(lm(dev_val_mean~is_NSFH+simulation,data=d))
65 | ```
66 |
--------------------------------------------------------------------------------
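
The `summarize_multicol` helper above re-parses columns such as `factors_pearson`, whose CSV entries are stringified Python lists. A minimal Python sketch of the same idea (similar to what `simulations/bm_sp/02_benchmark.ipy` does with `literal_eval`; it assumes `benchmark_gof` has already populated the column):

```
# Minimal sketch: summarize a stringified-list column in Python.
import numpy as np
import pandas as pd
from ast import literal_eval

d = pd.read_csv("simulations/bm_mixed/results/benchmark.csv")
parse = lambda x: np.array(literal_eval(x)) if isinstance(x, str) else np.array([np.nan])
d["factors_pearson_min"] = d["factors_pearson"].map(parse).map(np.min)
d["factors_pearson_mean"] = d["factors_pearson"].map(parse).map(np.mean)
```
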
/simulations/bm_sp/01_data_generation.ipy:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # ---
3 | # jupyter:
4 | # jupytext:
5 | # text_representation:
6 | # extension: .ipy
7 | # format_name: percent
8 | # format_version: '1.3'
9 | # jupytext_version: 1.6.0
10 | # kernelspec:
11 | # display_name: Python 3
12 | # language: python
13 | # name: python3
14 | # ---
15 |
16 | # %%
17 | from os import path
18 | from copy import deepcopy
19 | import numpy as np
20 | import pandas as pd
21 | from scanpy import read_h5ad
22 | from janitor import expand_grid
23 | from utils import misc,visualize
24 | from simulations import sim
25 | pth = "simulations/bm_sp"
26 | dpth = path.join(pth,"data")
27 | rpth = path.join(pth,"results")
28 | mpth = path.join(pth,"models")
29 | # misc.mkdir_p(dpth) # or use symlink to dropbox
30 | # misc.mkdir_p(rpth)
31 | # misc.mkdir_p(mpth)
32 |
33 | # %% define scenarios
34 | cfg = {"sim":["quilt","ggblocks","both"], "nside":36, "bkg_mean":0.2,
35 | "nb_shape":10.0, "Jsp":200, "Jmix":0, "Jns":0, "expr_mean":20.0,
36 | "seed":[1,2,3,4,5], "V":5}
37 | a = expand_grid(others=cfg)
38 | b = pd.DataFrame({"sim":["quilt","ggblocks","both"], "L":[4,4,8]})
39 | a = a.merge(b,how="left",on="sim")
40 | a["scenario"] = list(range(1,a.shape[0]+1))
41 | a.to_csv(path.join(pth,"scenarios.csv"),index=False) #store separately for data generation
42 |
43 | # %% generate the simulated datasets and store to disk
44 | a = pd.read_csv(path.join(pth,"scenarios.csv")).convert_dtypes()
45 | def sim2disk(p):
46 | p = deepcopy(p)
47 | scen = p.pop("scenario")
48 | ad = sim.sim(p["sim"], Lns=0, **p)
49 | ad.write_h5ad(path.join(dpth,"S{}.h5ad".format(scen)),compression="gzip")
50 | a.apply(sim2disk,axis=1)
51 |
52 | # %% check the hdf5 file is correct
53 | ad = read_h5ad(path.join(dpth,"S1.h5ad"))
54 | X = ad.obsm["spatial"]
55 | Y = ad.layers["counts"]
56 | Yn = ad.X
57 | visualize.heatmap(X,Y[:,0],cmap="Blues")
58 | visualize.heatmap(X,Yn[:,0],cmap="Blues")
59 | #check distribution of validation data points
60 | N = Y.shape[0]
61 | z = np.zeros(N)
62 | Ntr = round(0.95*N)
63 | z[Ntr:] = 1
64 | visualize.heatmap(X,z,cmap="Blues")
65 |
66 | # %% merge with models to make results csv for tracking model runs
67 | m = pd.read_csv(path.join(pth,"models.csv")).convert_dtypes() #this CSV was manually created
68 | d = a.merge(m,how="cross")
69 | d["key"] = d.apply(misc.params2key,axis=1)
70 | d["key"] = d.agg(lambda x: f"S{x['scenario']}/{x['key']}", axis=1)
71 | # d["scenario"].to_string()+"/"+d["key"]
72 | d["converged"] = False
73 | d.to_csv(path.join(rpth,"benchmark.csv"),index=False)
74 |
--------------------------------------------------------------------------------
/simulations/bm_sp/02_benchmark.ipy:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # ---
3 | # jupyter:
4 | # jupytext:
5 | # text_representation:
6 | # extension: .ipy
7 | # format_name: percent
8 | # format_version: '1.3'
9 | # jupytext_version: 1.6.0
10 | # kernelspec:
11 | # display_name: Python 3
12 | # language: python
13 | # name: python3
14 | # ---
15 |
16 | # %%
17 | from os import path
18 | # from copy import deepcopy
19 | import numpy as np
20 | # import pandas as pd
21 | from matplotlib import pyplot as plt
22 | from scipy.spatial.distance import pdist,cdist
23 | from scipy.stats import pearsonr,spearmanr
24 | from sklearn.cluster import KMeans
25 | from sklearn import metrics
26 | from scanpy import read_h5ad
27 | from tensorflow_probability import math as tm
28 | tfk = tm.psd_kernels
29 |
30 | from utils import training,postprocess,visualize
31 | # from simulations import sim
32 | pth = "simulations/bm_sp"
33 | dpth = path.join(pth,"data")
34 | rpth = path.join(pth,"results")
35 | mpth = path.join(pth,"models")
36 |
37 | #%% Benchmarking at command line [markdown]
38 | """
39 | To run on a local computer, use
40 | `python -m simulations.benchmark 2 simulations/bm_sp`
41 | where 2 is a row ID of benchmark.csv (minimum 2, since row 1 is the header; maximum 115).
42 |
43 | To run on the cluster, first load the anaconda environment:
44 | ```
45 | tmux
46 | interactive
47 | module load anaconda3/2021.5
48 | conda activate fwt
49 | python -m simulations.benchmark 2 simulations/bm_sp
50 | ```
51 |
52 | To run on the cluster as a job array over a subset of rows (a 10 minute time limit is recommended):
53 | ```
54 | PTH=./simulations/bm_sp
55 | sbatch --mem=16G --array=2-26,52-76 ./simulations/benchmark.slurm $PTH
56 | ```
57 |
58 | To run on the cluster as a job array over all rows of the CSV file:
59 | ```
60 | CSV=./simulations/bm_sp/results/benchmark.csv
61 | PTH=./simulations/bm_sp
62 | sbatch --mem=16G --array=1-$(wc -l < $CSV) ./simulations/benchmark.slurm $PTH
63 | ```
64 | """
65 |
66 | #%% Load dataset and set kernel and IPs
67 | ad = read_h5ad(path.join(dpth,"S12.h5ad"))
68 | #include only the training observations
69 | Ntr = round(0.95*ad.shape[0])
70 | ad = ad[:Ntr,:]
71 | X = ad.obsm["spatial"]
72 | #extract factors and loadings
73 | tru = postprocess.interpret_nonneg(ad.obsm["spfac"],ad.varm["spload"],lda_mode=False)
74 | F0 = tru["factors"]
75 | W0 = tru["loadings"]
76 | FFd0 = pdist(F0)
77 | WWd0 = pdist(W0)
78 | #set hyperparams
79 | M = 1296
80 | ker = tfk.MaternThreeHalves
81 | hmkw = {"figsize":(6,3),"s":1.5,"marker":"s","subplot_space":0,
82 | "spinecolor":"gray"}
83 | fig,axes=visualize.multiheatmap(X, F0, (2,4), cmap="Blues", **hmkw)
84 |
85 | #%% Compare inferred to true factors
86 | pp = path.join(mpth,"S12/V5/L8/poi_sz-constant/NSF_{}_M{}".format(ker.__name__, M))
87 | tro = training.ModelTrainer.from_pickle(pp)
88 | fit = tro.model
89 | insf = postprocess.interpret_nsf(fit,X)
90 | F = insf["factors"]
91 | fig,axes=visualize.multiheatmap(X, F, (2,4), cmap="Blues", **hmkw)
92 | cdist(F0.T,F.T,metric=lambda x,y: abs(pearsonr(x,y)[0])).max(axis=1)
93 | cdist(F0.T,F.T,metric=lambda x,y: abs(spearmanr(x,y)[0])).max(axis=1)
94 | plt.scatter(F0[:,0],F[:,0])
95 |
96 | FFd = pdist(insf["factors"])
97 | W = insf["loadings"]
98 | WWd = pdist(W)
99 | plt.hexbin(FFd0,FFd,gridsize=100,cmap="Greys",bins="log")
100 | plt.scatter(WWd0,WWd)
101 | pearsonr(FFd0,FFd)[0]
102 | spearmanr(FFd0,FFd)[0]
103 | pearsonr(WWd0,WWd)[0]
104 | spearmanr(WWd0,WWd)[0]
105 | nclust = W0.shape[1]
106 | km0 = KMeans(n_clusters=nclust).fit(W0).labels_
107 | km1 = KMeans(n_clusters=nclust).fit(W).labels_
108 | ari = metrics.adjusted_rand_score(km0, km1)
109 |
110 | #%% Compute goodness-of-fit metrics
111 | """
112 | ```
113 | python -m simulations.benchmark_gof simulations/bm_sp
114 | ```
115 | or
116 | ```
117 | PTH=./simulations/bm_sp
118 | sbatch --mem=16G ./simulations/benchmark_gof.slurm $PTH
119 | ```
120 | """
121 |
122 | #%% Visualize results
123 | import pandas as pd
124 | import seaborn as sns
125 | from ast import literal_eval
126 | d = pd.read_csv(path.join(rpth,"benchmark.csv"))
127 | # d = d[d["converged"]]
128 | d["factors_pearson"] = d["factors_pearson"].map(lambda x: np.array(literal_eval(x)))
129 | d["factors_pearson_min"] = d["factors_pearson"].map(min)
130 | d["factors_pearson_mean"] = d["factors_pearson"].map(np.mean)
131 | # d["factors_spearman"] = d["factors_spearman"].map(lambda x: np.array(literal_eval(x)))
132 | # d["factors_spearman_min"] = d["factors_spearman"].map(min)
133 | # d["factors_spearman_mean"] = d["factors_spearman"].map(np.mean)
134 | sns.stripplot(x="model",y="factors_pearson_min",hue="sim",dodge=True,data=d)
135 |
136 |
137 |
--------------------------------------------------------------------------------
/simulations/bm_sp/03_benchmark_viz.Rmd:
--------------------------------------------------------------------------------
1 | ---
2 | title: "Benchmarking Visualization"
3 | author: "Will Townes"
4 | output: html_document
5 | ---
6 |
7 | ```{r}
8 | library(tidyverse)
9 | theme_set(theme_bw())
10 | fp<-file.path
11 | pth<-"simulations/bm_sp"
12 | plt_pth<-fp(pth,"results/plots")
13 | if(!dir.exists(plt_pth)){
14 | dir.create(plt_pth,recursive=TRUE)
15 | }
16 | ```
17 |
18 | ```{r}
19 | d<-read.csv(fp(pth,"results/benchmark.csv"))
20 | d$converged<-as.logical(d$converged)
21 | d$model<-factor(d$model,levels=c("FA","MEFISTO","RSF","PNMF","NSF"))
22 | sims<-c("ggblocks","quilt","both")
23 | d$sim<-factor(d$sim,levels=sims)
24 | d$simulation<-plyr::mapvalues(d$sim,sims,c("blocks (4)","quilt (4)","both (8)"))
25 |
26 | summarize_multicol<-function(x,f){
27 | #x is a character column whose entries are lists of numbers "[3.02, 4.33, 7.71,...]"
28 | #expand each entry into a vector and summarize it by
29 | #a function (f) like min, max, mean
30 | #return a numeric vector with the summary stat, has same length as x
31 | x<-sub("[","c(",x,fixed=TRUE)
32 | x<-sub("]",")",x,fixed=TRUE)
33 | vapply(x,function(t){f(eval(str2lang(t)))},FUN.VALUE=1.0)
34 | }
35 |
36 | d$factors_pearson_min<-summarize_multicol(d$factors_pearson,min)
37 | d$loadings_pearson_min<-summarize_multicol(d$loadings_pearson,min)
38 | d$factors_spearman_min<-summarize_multicol(d$factors_spearman,min)
39 | d$loadings_spearman_min<-summarize_multicol(d$loadings_spearman,min)
40 | ```
41 |
42 | ```{r}
43 | ggplot(d,aes(x=model,y=dev_val_mean,color=simulation))+geom_point(size=3,position=position_jitterdodge())+scale_y_log10()+ylab("validation deviance (mean)")
44 | ggsave(fp(plt_pth,"bm_sp_dev_val_mean.pdf"),width=5,height=2.5)
45 |
46 | ggplot(d,aes(x=model,y=factors_pearson_min,color=simulation))+geom_point(size=3,position=position_jitterdodge())+scale_y_log10()+ylab("minimum factors correlation")
47 | ggsave(fp(plt_pth,"bm_sp_factors_pcor_min.pdf"),width=5,height=2.5)
48 |
49 | ggplot(d,aes(x=model,y=loadings_pearson_min,color=simulation))+geom_point(size=3,position=position_jitterdodge())+scale_y_log10()+ylab("minimum loadings correlation")
50 | ggsave(fp(plt_pth,"bm_sp_loadings_pcor_min.pdf"),width=5,height=2.5)
51 |
52 | #ggplot(d,aes(x=model,y=factors_spearman_min,fill=simulation,color=simulation))+geom_boxplot()+scale_y_log10()+ylab("Spearman correlation with true factors")
53 |
54 | #ggplot(d,aes(x=model,y=loadings_spearman_min,fill=simulation,color=simulation))+geom_boxplot()+scale_y_log10()+ylab("Spearman correlation with true loadings")
55 |
56 | #ggplot(d,aes(x=model,y=loadings_clust_ari,fill=simulation,color=simulation))+geom_boxplot()+scale_y_log10()+ylab("Concordance of feature clusters with truth (ARI)")
57 | ```
58 |
59 | Linear regressions to get statistical significance
60 |
61 | ```{r}
62 | #nonnegative vs real-valued
63 | d$nonneg<-(d$model %in% c("PNMF","NSF"))
64 | summary(lm(factors_pearson_min~nonneg+simulation,data=d))
65 | summary(lm(loadings_pearson_min~nonneg+simulation,data=d))
66 | d$sp_aware<-(d$model %in% c("RSF","NSF"))
67 | summary(lm(dev_val_mean~sp_aware+simulation,data=d))
68 | ```
69 |
--------------------------------------------------------------------------------
/simulations/bm_sp/04_quilt_exploratory.ipy:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # ---
3 | # jupyter:
4 | # jupytext:
5 | # text_representation:
6 | # extension: .ipy
7 | # format_name: percent
8 | # format_version: '1.3'
9 | # jupytext_version: 1.6.0
10 | # kernelspec:
11 | # display_name: Python 3
12 | # language: python
13 | # name: python3
14 | # ---
15 |
16 | # %%
17 | # import numpy as np
18 | from os import path
19 | from scanpy import read_h5ad
20 | from tensorflow_probability import math as tm
21 | tfk = tm.psd_kernels
22 |
23 | from utils import preprocess,training,misc,postprocess,visualize
24 |
25 | # rng = np.random.default_rng()
26 | pth = "simulations/bm_sp"
27 | dpth = path.join(pth,"data")
28 | mpth = path.join(pth,"models/S1/V5")
29 | plt_pth = path.join(pth,"results/plots")
30 | misc.mkdir_p(plt_pth)
31 |
32 | #%% Data loading
33 | ad = read_h5ad(path.join(dpth,"S1.h5ad"))
34 | N = ad.shape[0]
35 | Ntr = round(0.95*N)
36 | ad = ad[:Ntr,:]
37 | J = ad.shape[1]
38 | X = ad.obsm["spatial"]
39 | # D_n,_ = preprocess.anndata_to_train_val(ad,train_frac=1.0,flip_yaxis=False)
40 |
41 | #%% Save heatmap of true values and sampled data
42 | hmkw = {"figsize":(8,1.9), "bgcol":"gray", "subplot_space":0.1, "marker":"s",
43 | "s":2.9}
44 | Ftrue = ad.obsm["spfac"]
45 | fig,axes=visualize.multiheatmap(X, Ftrue, (1,4), cmap="Blues", **hmkw)
46 | fig.savefig(path.join(plt_pth,"quilt_true_factors.png"),bbox_inches='tight')
47 | # fig.savefig(path.join(plt_pth,"quilt_true_factors.pdf"),bbox_inches='tight')
48 | #saving directly to pdf looks weird. Try imagemagick instead
49 | #convert quilt_true_factors.png quilt_true_factors.pdf
50 |
51 | Yss = ad.layers["counts"][:,(4,0,1,2)]
52 | fig,axes=visualize.multiheatmap(X, Yss, (1,4), cmap="Blues", **hmkw)
53 | fig.savefig(path.join(plt_pth,"quilt_data.png"),bbox_inches='tight')
54 | # fig.savefig(path.join(plt_pth,"quilt_data.pdf"),bbox_inches='tight')
55 |
56 | # %% Initialize inducing points
57 | L = 4
58 | M = N #number of inducing points
59 | Z = X
60 | ker = tfk.MaternThreeHalves
61 |
62 | #%% NSF
63 | try:
64 | pp = path.join(mpth,"L{}/poi_sz-constant/NSF_{}_M{}".format(L,ker.__name__,M))
65 | tro = training.ModelTrainer.from_pickle(pp)
66 | fit = tro.model
67 | except FileNotFoundError: raise #pickled model not found; to refit, uncomment the lines below
68 | # fit = sf.SpatialFactorization(J,L,Z,psd_kernel=ker,nonneg=True,lik="poi")
69 | # fit.init_loadings(D["Y"],X=X,sz=D["sz"],shrinkage=0.3)
70 | # pp = fit.generate_pickle_path("constant",base=mpth)
71 | # tro = training.ModelTrainer(fit,pickle_path=pp)
72 | # %time tro.train_model(*Dtf) #12 mins
73 | insf = postprocess.interpret_nsf(fit,X,S=100,lda_mode=False)
74 | Fplot = insf["factors"][:,[3,1,2,0]]
75 | fig,axes=visualize.multiheatmap(X, Fplot, (1,4), cmap="Blues", **hmkw)
76 | fig.savefig(path.join(plt_pth,"quilt_nsf.png"),bbox_inches='tight')
77 |
78 | #%% PNMF
79 | try:
80 | pp = path.join(mpth,"L{}/poi_sz-constant/PNMF".format(L))
81 | tro = training.ModelTrainer.from_pickle(pp)
82 | fit = tro.model
83 | except FileNotFoundError: raise #pickled model not found; to refit, uncomment the lines below
84 | # fit = cf.CountFactorization(N,J,L,nonneg=True,lik="poi")
85 | # fit.init_loadings(D["Y"],sz=D["sz"],shrinkage=0.3)
86 | # pp = fit.generate_pickle_path("constant",base=mpth)
87 | # tro = training.ModelTrainer(fit,pickle_path=pp)
88 | # %time tro.train_model(*Dtf) #3 mins
89 | ipnmf = postprocess.interpret_pnmf(fit,S=100,lda_mode=False)
90 | Fplot = ipnmf["factors"][:,[3,1,2,0]]
91 | fig,axes=visualize.multiheatmap(X, Fplot, (1,4), cmap="Blues", **hmkw)
92 | fig.savefig(path.join(plt_pth,"quilt_pnmf.png"),bbox_inches='tight')
93 |
94 | #%% MEFISTO-Gaussian
95 | from models.mefisto import MEFISTO
96 | pp = path.join(mpth,"L{}/gau/MEFISTO_ExponentiatedQuadratic_M{}".format(L,M))
97 | try:
98 | mef = MEFISTO.from_pickle(pp)
99 | except FileNotFoundError: raise #pickled model not found; to refit, uncomment the lines below
100 | # mef = MEFISTO(D_n, L, inducing_pts=M, pickle_path=pp)
101 | # %time mef.train() #also saves to pickle file- 28min
102 | Fplot = mef.get_factors()
103 | fig,axes=visualize.multiheatmap(X, Fplot, (1,4), cmap="RdBu", **hmkw)
104 | fig.savefig(path.join(plt_pth,"quilt_mefisto.png"),bbox_inches='tight')
105 |
106 | #%% FA: Non-spatial, real-valued
107 | try:
108 | pp = path.join(mpth,"L{}/gau/FA".format(L))
109 | tro = training.ModelTrainer.from_pickle(pp)
110 | fit = tro.model
111 | except FileNotFoundError: raise #pickled model not found; to refit, uncomment the lines below
112 | # fit = cf.CountFactorization(N, J, L, nonneg=False, lik="gau",
113 | # feature_means=fmeans)
114 | # fit.init_loadings(D_c["Y"])
115 | # pp = fit.generate_pickle_path(None,base=mpth)
116 | # tro = training.ModelTrainer(fit,pickle_path=pp)
117 | # %time tro.train_model(*Dtf_c) #14sec
118 | Fplot = postprocess.interpret_fa(fit,S=100)["factors"]
119 | fig,axes=visualize.multiheatmap(X, Fplot, (1,4), cmap="RdBu", **hmkw)
120 | fig.savefig(path.join(plt_pth,"quilt_fa.png"),bbox_inches='tight')
121 |
122 | #%% RSF
123 | try:
124 | pp = path.join(mpth,"L{}/gau/RSF_{}_M{}".format(L,ker.__name__,M))
125 | tro = training.ModelTrainer.from_pickle(pp)
126 | fit = tro.model
127 | except FileNotFoundError: raise #pickled model not found; to refit, uncomment the lines below
128 | # fit = sf.SpatialFactorization(J,L,Z,psd_kernel=ker,nonneg=False,lik="gau")
129 | # fit.init_loadings(D_c["Y"],X=X)
130 | # pp = fit.generate_pickle_path(None,base=mpth)
131 | # tro = training.ModelTrainer(fit,pickle_path=pp)
132 | # %time tro.train_model(*Dtf_c) #5 mins
133 | Fplot = postprocess.interpret_rsf(fit,X,S=100)["factors"]
134 | fig,axes=visualize.multiheatmap(X, Fplot, (1,4), cmap="RdBu", **hmkw)
135 | fig.savefig(path.join(plt_pth,"quilt_rsf.png"),bbox_inches='tight')
136 |
137 | #%% Traditional scanpy analysis (unsupervised clustering)
138 | #https://scanpy-tutorials.readthedocs.io/en/latest/spatial/basic-analysis.html
139 | import pandas as pd
140 | import scanpy as sc
141 | sc.pp.pca(ad)
142 | sc.pp.neighbors(ad)
143 | sc.tl.umap(ad)
144 | sc.tl.leiden(ad, resolution=1.0, key_added="clusters")
145 | # plt.rcParams["figure.figsize"] = (4, 4)
146 | sc.pl.umap(ad, color="clusters", wspace=0.4)
147 | sc.pl.embedding(ad, "spatial", color="clusters")
148 | cl = pd.get_dummies(ad.obs["clusters"]).to_numpy()
149 | # tgnames = [str(i) for i in range(1,cl.shape[1]+1)]
150 | # hmkw = {"figsize":(7.7,5.9), "bgcol":"gray", "subplot_space":0.05, "marker":"s",
151 | # "s":2.9}
152 | hmkw = {"figsize":(8,2.5), "bgcol":"gray", "subplot_space":0.05, "marker":"s",
153 | "s":1}
154 | fig,axes=visualize.multiheatmap(X, cl, (2,6), cmap="Blues", **hmkw)
155 | # visualize.set_titles(fig, tgnames, x=0.03, y=.88, fontsize="small", c="white",
156 | # ha="left", va="top")
157 | fig.savefig(path.join(plt_pth,"quilt_scanpy_clusters.png"),
158 | bbox_inches='tight')
159 |
--------------------------------------------------------------------------------
/simulations/bm_sp/05_ggblocks_exploratory.ipy:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # ---
3 | # jupyter:
4 | # jupytext:
5 | # text_representation:
6 | # extension: .ipy
7 | # format_name: percent
8 | # format_version: '1.3'
9 | # jupytext_version: 1.6.0
10 | # kernelspec:
11 | # display_name: Python 3
12 | # language: python
13 | # name: python3
14 | # ---
15 |
16 | # %%
17 | # import numpy as np
18 | from os import path
19 | from scanpy import read_h5ad
20 | from tensorflow_probability import math as tm
21 | tfk = tm.psd_kernels
22 |
23 | from utils import preprocess,training,misc,postprocess,visualize
24 |
25 | # rng = np.random.default_rng()
26 | pth = "simulations/bm_sp"
27 | dpth = path.join(pth,"data")
28 | mpth = path.join(pth,"models/S6/V5")
29 | plt_pth = path.join(pth,"results/plots")
30 | misc.mkdir_p(plt_pth)
31 |
32 | #%% Data loading
33 | ad = read_h5ad(path.join(dpth,"S6.h5ad"))
34 | N = ad.shape[0]
35 | Ntr = round(0.95*N)
36 | ad = ad[:Ntr,:]
37 | J = ad.shape[1]
38 | X = ad.obsm["spatial"]
39 | D_n,_ = preprocess.anndata_to_train_val(ad,train_frac=1.0,flip_yaxis=False)
40 |
41 | #%% Save heatmap of true values and sampled data
42 | hmkw = {"figsize":(8,1.9), "bgcol":"gray", "subplot_space":0.1, "marker":"s",
43 | "s":2.9}
44 | Ftrue = ad.obsm["spfac"]
45 | fig,axes=visualize.multiheatmap(X, Ftrue, (1,4), cmap="Blues", **hmkw)
46 | fig.savefig(path.join(plt_pth,"ggblocks_true_factors.png"),bbox_inches='tight')
47 |
48 | Yss = ad.layers["counts"][:,(4,0,1,2)]
49 | fig,axes=visualize.multiheatmap(X, Yss, (1,4), cmap="Blues", **hmkw)
50 | fig.savefig(path.join(plt_pth,"ggblocks_data.png"),bbox_inches='tight')
51 |
52 | # %% Initialize inducing points
53 | L = 4
54 | M = N #number of inducing points
55 | Z = X
56 | ker = tfk.MaternThreeHalves
57 |
58 | #%% NSF
59 | try:
60 | pp = path.join(mpth,"L{}/poi_sz-constant/NSF_{}_M{}".format(L,ker.__name__,M))
61 | tro = training.ModelTrainer.from_pickle(pp)
62 | fit = tro.model
63 | except FileNotFoundError: raise #pickled model not found; to refit, uncomment the lines below
64 | # fit = sf.SpatialFactorization(J,L,Z,psd_kernel=ker,nonneg=True,lik="poi")
65 | # fit.init_loadings(D["Y"],X=X,sz=D["sz"],shrinkage=0.3)
66 | # pp = fit.generate_pickle_path("constant",base=mpth)
67 | # tro = training.ModelTrainer(fit,pickle_path=pp)
68 | # %time tro.train_model(*Dtf) #12 mins
69 | insf = postprocess.interpret_nsf(fit,X,S=100,lda_mode=False)
70 | Fplot = insf["factors"][:,[3,1,2,0]]
71 | fig,axes=visualize.multiheatmap(X, Fplot, (1,4), cmap="Blues", **hmkw)
72 | fig.savefig(path.join(plt_pth,"ggblocks_nsf.png"),bbox_inches='tight')
73 |
74 | #%% PNMF
75 | try:
76 | pp = path.join(mpth,"L{}/poi_sz-constant/PNMF".format(L))
77 | tro = training.ModelTrainer.from_pickle(pp)
78 | fit = tro.model
79 | except FileNotFoundError: raise #pickled model not found; to refit, uncomment the lines below
80 | # fit = cf.CountFactorization(N,J,L,nonneg=True,lik="poi")
81 | # fit.init_loadings(D["Y"],sz=D["sz"],shrinkage=0.3)
82 | # pp = fit.generate_pickle_path("constant",base=mpth)
83 | # tro = training.ModelTrainer(fit,pickle_path=pp)
84 | # %time tro.train_model(*Dtf) #3 mins
85 | ipnmf = postprocess.interpret_pnmf(fit,S=100,lda_mode=False)
86 | Fplot = ipnmf["factors"][:,[3,1,2,0]]
87 | fig,axes=visualize.multiheatmap(X, Fplot, (1,4), cmap="Blues", **hmkw)
88 | fig.savefig(path.join(plt_pth,"ggblocks_pnmf.png"),bbox_inches='tight')
89 |
90 | #%% MEFISTO-Gaussian
91 | from models.mefisto import MEFISTO
92 | pp = path.join(mpth,"L{}/gau/MEFISTO_ExponentiatedQuadratic_M{}".format(L,M))
93 | try:
94 | mef = MEFISTO.from_pickle(pp)
95 | except FileNotFoundError:
96 | mef = MEFISTO(D_n, L, inducing_pts=M, pickle_path=pp)
97 | %time mef.train() #also saves to pickle file- 9min
98 | Fplot = mef.get_factors()
99 | fig,axes=visualize.multiheatmap(X, Fplot, (1,4), cmap="RdBu", **hmkw)
100 | fig.savefig(path.join(plt_pth,"ggblocks_mefisto.png"),bbox_inches='tight')
101 |
102 | #%% FA: Non-spatial, real-valued
103 | try:
104 | pp = path.join(mpth,"L{}/gau/FA".format(L))
105 | tro = training.ModelTrainer.from_pickle(pp)
106 | fit = tro.model
107 | except FileNotFoundError: raise #pickled model not found; to refit, uncomment the lines below
108 | # fit = cf.CountFactorization(N, J, L, nonneg=False, lik="gau",
109 | # feature_means=fmeans)
110 | # fit.init_loadings(D_c["Y"])
111 | # pp = fit.generate_pickle_path(None,base=mpth)
112 | # tro = training.ModelTrainer(fit,pickle_path=pp)
113 | # %time tro.train_model(*Dtf_c) #14sec
114 | Fplot = postprocess.interpret_fa(fit,S=100)["factors"]
115 | fig,axes=visualize.multiheatmap(X, Fplot, (1,4), cmap="RdBu", **hmkw)
116 | fig.savefig(path.join(plt_pth,"ggblocks_fa.png"),bbox_inches='tight')
117 |
118 | #%% RSF
119 | try:
120 | pp = path.join(mpth,"L{}/gau/RSF_{}_M{}".format(L,ker.__name__,M))
121 | tro = training.ModelTrainer.from_pickle(pp)
122 | fit = tro.model
123 | except FileNotFoundError: raise #pickled model not found; to refit, uncomment the lines below
124 | # fit = sf.SpatialFactorization(J,L,Z,psd_kernel=ker,nonneg=False,lik="gau")
125 | # fit.init_loadings(D_c["Y"],X=X)
126 | # pp = fit.generate_pickle_path(None,base=mpth)
127 | # tro = training.ModelTrainer(fit,pickle_path=pp)
128 | # %time tro.train_model(*Dtf_c) #5 mins
129 | Fplot = postprocess.interpret_rsf(fit,X,S=100)["factors"]
130 | fig,axes=visualize.multiheatmap(X, Fplot, (1,4), cmap="RdBu", **hmkw)
131 | fig.savefig(path.join(plt_pth,"ggblocks_rsf.png"),bbox_inches='tight')
132 |
133 | #%% Traditional scanpy analysis (unsupervised clustering)
134 | #https://scanpy-tutorials.readthedocs.io/en/latest/spatial/basic-analysis.html
135 | import pandas as pd
136 | import scanpy as sc
137 | sc.pp.pca(ad)
138 | sc.pp.neighbors(ad)
139 | sc.tl.umap(ad)
140 | sc.tl.leiden(ad, resolution=1.0, key_added="clusters")
141 | # plt.rcParams["figure.figsize"] = (4, 4)
142 | sc.pl.umap(ad, color="clusters", wspace=0.4)
143 | sc.pl.embedding(ad, "spatial", color="clusters")
144 | cl = pd.get_dummies(ad.obs["clusters"]).to_numpy()
145 | # tgnames = [str(i) for i in range(1,cl.shape[1]+1)]
146 | # hmkw = {"figsize":(7.7,5.9), "bgcol":"gray", "subplot_space":0.05, "marker":"s",
147 | # "s":2.9}
148 | # hmkw = {"figsize":(6,4), "bgcol":"gray", "subplot_space":0.05, "marker":"s",
149 | # "s":3.4}
150 | hmkw = {"figsize":(8,1.6), "bgcol":"gray", "subplot_space":0.05, "marker":"s",
151 | "s":1.6}
152 | fig,axes=visualize.multiheatmap(X, cl, (1,5), cmap="Blues", **hmkw)
153 | # visualize.set_titles(fig, tgnames, x=0.03, y=.88, fontsize="small", c="white",
154 | # ha="left", va="top")
155 | fig.savefig(path.join(plt_pth,"ggblocks_scanpy_clusters.png"),
156 | bbox_inches='tight')
157 |
--------------------------------------------------------------------------------
/simulations/bm_sp/data/S1.h5ad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/willtownes/nsf-paper/0cacf8352e09d223ab8d4421025195358bbde8df/simulations/bm_sp/data/S1.h5ad
--------------------------------------------------------------------------------
/simulations/bm_sp/data/S6.h5ad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/willtownes/nsf-paper/0cacf8352e09d223ab8d4421025195358bbde8df/simulations/bm_sp/data/S6.h5ad
--------------------------------------------------------------------------------
/simulations/sim.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Fri Feb 4 09:05:19 2022
5 |
6 | @author: townesf
7 | """
8 | import random
9 | import numpy as np
10 | from pandas import get_dummies
11 | from anndata import AnnData
12 | from scanpy import pp
13 | from utils import misc,preprocess
14 | dtp = "float32"
15 | random.seed(101)
16 |
17 | def squares():
18 | A = np.zeros([12,12])
19 | A[1:5,1:5] = 1
20 | A[7:11,1:5] = 1
21 | A[1:5,7:11] = 1
22 | A[7:11,7:11] = 1
23 | return A
24 |
25 | def corners():
26 | B = np.zeros([6,6])
27 | for i in range(6):
28 | B[i,i:] = 1
29 | A = np.flip(B,axis=1)
30 | AB = np.hstack((A,B))
31 | CD = np.flip(AB,axis=0)
32 | return np.vstack((AB,CD))
33 |
34 | def scotland():
35 | A = np.eye(12)
36 | for i in range(12):
37 | A[-i-1,i] = 1
38 | return A
39 |
40 | def checkers():
41 | A = np.zeros([4,4])
42 | B = np.ones([4,4])
43 | AB = np.hstack((A,B,A))
44 | BA = np.hstack((B,A,B))
45 | return np.vstack((AB,BA,AB))
46 |
47 | def quilt():
48 | A = np.zeros([4,144])
49 | A[0,:] = squares().flatten()
50 | A[1,:] = corners().flatten()
51 | A[2,:] = scotland().flatten()
52 | A[3,:] = checkers().flatten()
53 | return A #basic block size is 12x12
54 |
55 | def ggblocks():
56 | A = np.zeros([4,36])
57 | A[0, [1,6,7,8,13]] = 1
58 | A[1, [3,4,5,9,11,15,16,17]] = 1
59 | A[2, [18,24,25,30,31,32]] = 1
60 | A[3, [21,22,23,28,34]] = 1
61 | return A #basic block size is 6x6
62 |
63 | def sqrt_int(x):
64 | z = int(round(x**.5))
65 | if x==z**2:
66 | return z
67 | else:
68 | raise ValueError("x must be a square integer")
69 |
70 | def gen_spatial_factors(scenario="quilt",nside=36):
71 | """
72 | Generate the factors matrix for either the 'quilt' or 'ggblocks' scenario
73 | There are 4 basic patterns (L=4)
74 | There are N=(nside^2) observations.
75 | Returns:
76 | factor values[Nx4] matrix
77 | """
78 | if scenario=="quilt":
79 | A = quilt()
80 | elif scenario=="ggblocks":
81 | A = ggblocks()
82 | else:
83 | raise ValueError("scenario must be 'quilt' or 'ggblocks'")
84 | unit = sqrt_int(A.shape[1]) #quilt: 12, ggblocks: 6
85 | assert nside%unit==0
86 | ncopy = nside//unit
87 | N = nside**2 #36x36=1296
88 | L = A.shape[0] #4
89 | A = A.reshape((L,unit,unit))
90 | A = np.kron(A,np.ones((1,ncopy,ncopy)))
91 | F = A.reshape((L,N)).T #NxL
92 | return F
93 |
94 | def gen_spatial_coords(N): #N is number of observations
95 | X = misc.make_grid(N)
96 | X[:,1] = -X[:,1] #make the display the same
97 | return preprocess.rescale_spatial_coords(X)
98 |
99 | def gen_nonspatial_factors(N,L=3,nzprob=0.2,seed=101):
100 | rng = np.random.default_rng(seed)
101 | return rng.binomial(1,nzprob,size=(N,L))
102 |
103 | def gen_loadings(Lsp, Lns=3, Jsp=0, Jmix=500, Jns=0, expr_mean=20.0,
104 | mix_frac_spat=0.55, seed=101, **kwargs):
105 | """
106 | generate a loadings matrix L=components, J=features
107 | kwargs currently ignored
108 | """
109 | rng = np.random.default_rng(seed)
110 | J = Jsp+Jmix+Jns #total number of features
111 | if Lsp>0:
112 | w = rng.choice(Lsp,J,replace=True) #spatial loadings
113 | W = get_dummies(w).to_numpy(dtype=dtp) #JxLsp indicator matrix
114 | else:
115 | W = np.zeros((J,0))
116 | if Lns>0:
117 | v = rng.choice(Lns,J,replace=True) #nonspatial loadings
118 | V = get_dummies(v).to_numpy(dtype=dtp) #JxLnsp indicator matrix
119 | else:
120 | V = np.zeros((J,0))
121 | #pure spatial features
122 | W[:Jsp,:]*=expr_mean
123 | V[:Jsp,:]=0
124 | #features with mixed assignment to spatial and nonspatial components
125 | W[Jsp:(Jsp+Jmix),:]*=(mix_frac_spat*expr_mean)
126 | V[Jsp:(Jsp+Jmix),:]*=((1-mix_frac_spat)*expr_mean)
127 | #pure nonspatial features
128 | W[(Jsp+Jmix):,:]=0
129 | V[(Jsp+Jmix):,:]*=expr_mean
130 | return W,V
131 |
132 | def sim2anndata(locs, outcome, spfac, spload, nsfac=None, nsload=None):
133 | """
134 | locs: spatial coordinates (Nx2); outcome: simulated counts (NxJ); spfac/spload and nsfac/nsload: the spatial and nonspatial factors and loadings used to generate them
135 | returns: an AnnData object with both raw counts and scanpy normalized counts
136 | obsm slots: spatial, spfac, nsfac (coordinates and factors)
137 | varm slots: spload, nsload (loadings)
138 | """
139 | obsm = {"spatial":locs, "spfac":spfac, "nsfac":nsfac}
140 | varm = {"spload":spload, "nsload":nsload}
141 | ad = AnnData(outcome, obsm=obsm, varm=varm)
142 | ad.layers = {"counts":ad.X.copy()} #store raw counts before normalization changes ad.X
143 | #here we do not normalize because the total counts are actually meaningful!
144 | # pp.normalize_total(ad, inplace=True, layers=None, key_added="sizefactor")
145 | pp.log1p(ad)
146 | #shuffle indices to disperse validation data throughout the field of view
147 | idx = list(range(ad.shape[0]))
148 | random.shuffle(idx)
149 | ad = ad[idx,:]
150 | return ad
151 |
152 | def sim(scenario, nside=36, nzprob_nsp=0.2, bkg_mean=0.2, nb_shape=10.0,
153 | seed=101, **kwargs):
154 | """
155 | scenario: either 'quilt'(L=4), 'ggblocks'(L=4), or 'both'(L=8)
156 | N=number of observations is nside**2
157 | nzprob_nsp: for nonspatial factors, the probability of a "one" (else zero)
158 | bkg_mean: negative binomial mean for observations that are "zero" in the factors
159 | nb_shape: shape parameter of negative binomial distribution
160 | seed: for random number generation reproducibility
161 | kwargs: passed to gen_loadings
162 | """
163 | if scenario=="both":
164 | F1 = gen_spatial_factors(nside=nside,scenario="ggblocks")
165 | F2 = gen_spatial_factors(nside=nside,scenario="quilt")
166 | F = np.hstack((F1,F2))
167 | else:
168 | F = gen_spatial_factors(scenario=scenario,nside=nside)
169 | rng = np.random.default_rng(seed)
170 | N = nside**2
171 | X = gen_spatial_coords(N)
172 | W,V = gen_loadings(F.shape[1],seed=seed, **kwargs)
173 | U = gen_nonspatial_factors(N,L=V.shape[1],nzprob=nzprob_nsp,seed=seed)
174 | Lambda = bkg_mean+F@W.T+U@V.T #NxJ
175 | r = nb_shape
176 | Y = rng.negative_binomial(r,r/(Lambda+r))
177 | return sim2anndata(X,Y,F,W,nsfac=U,nsload=V)
178 |
--------------------------------------------------------------------------------
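A quick usage sketch for sim.py (assuming the repo root is on the Python path and its dependencies are installed): sim() chains gen_spatial_factors, gen_loadings, gen_nonspatial_factors and sim2anndata into a single AnnData object. The shapes below follow from the gen_loadings defaults Lns=3, Jsp=0, Jmix=500, Jns=0.

from simulations import sim

ad = sim.sim("ggblocks", nside=36, nzprob_nsp=0.2, bkg_mean=0.2, nb_shape=10.0)
ad.shape                 # (1296, 500): N=36*36 spots, J=500 features
ad.obsm["spfac"].shape   # (1296, 4): ground-truth spatial factors
ad.varm["spload"].shape  # (500, 4): spatial loadings
# Note on the count model: with r = nb_shape and p = r/(Lambda+r),
# rng.negative_binomial(r, p) has mean r*(1-p)/p = Lambda, so the counts are
# centered on bkg_mean + F@W.T + U@V.T as intended.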
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/willtownes/nsf-paper/0cacf8352e09d223ab8d4421025195358bbde8df/utils/__init__.py
--------------------------------------------------------------------------------
/utils/benchmark_array.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=benchmark # create a short name for your job
3 | #SBATCH --nodes=1 # node count
4 | #SBATCH --ntasks=1 # total number of tasks across all nodes
5 | #SBATCH --cpus-per-task=12 # cpu-cores per task (>1 if multi-threaded tasks)
6 | #SBATCH --time=12:00:00 # total run time limit (HH:MM:SS)
7 | #SBATCH --mail-type=END,FAIL
8 | #SBATCH --mail-user=ftownes@princeton.edu
9 |
10 | #example usage, default is 4G memory per core (--mem-per-cpu), --mem is total
11 | #CSV=./scrna/visium_brain_sagittal/results/benchmark.csv
12 | #DAT=./scrna/visium_brain_sagittal/data/visium_brain_sagittal_J2000.h5ad
13 | #sbatch --mem=72G --array=1-$(wc -l < $CSV) ./utils/benchmark_array.slurm $DAT
14 |
15 | module purge
16 | module load anaconda3/2021.5
17 | conda activate fwt
18 |
19 | #first command line arg $1 is file path to dataset
20 | python -um utils.benchmark $SLURM_ARRAY_TASK_ID $1
21 |
--------------------------------------------------------------------------------
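The --array directive above launches one task per line of the benchmark CSV, and each task uses its SLURM_ARRAY_TASK_ID to pick out a single row of model settings. Below is a hedged illustration of that lookup; the helper name is hypothetical, while utils/misc.py contains the repo's own version of this skiprows logic.

import os
import pandas as pd

def read_one_benchmark_row(csv_file, row_id):
    """Fetch data row `row_id` (1-based, not counting the header) without loading the whole CSV."""
    skip = range(1, row_id)  # row_id=1 skips nothing; row_id=2 skips data row 1
    return pd.read_csv(csv_file, skiprows=skip, nrows=1).iloc[0, :]

# params = read_one_benchmark_row("scrna/visium_brain_sagittal/results/benchmark.csv",
#                                 int(os.environ["SLURM_ARRAY_TASK_ID"]))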
/utils/benchmark_gof.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Fri Sep 17 08:50:51 2021
5 |
6 | @author: townesf
7 | """
8 |
9 | from argparse import ArgumentParser
10 | from utils import benchmark
11 |
12 | def arghandler(args=None):
13 | """parses a list of arguments (default is sys.argv[1:])"""
14 | parser = ArgumentParser()
15 | parser.add_argument("dataset", type=str,
16 | help="location of scanpy H5AD data file")
17 | parser.add_argument("val_pct", type=int,
18 | help="percentage of data to be used as validation set (0-100)")
19 | args = parser.parse_args(args) #if args is None, this will automatically parse sys.argv[1:]
20 | return args
21 |
22 | if __name__=="__main__":
23 | args = arghandler()
24 | # dat = "scrna/sshippo/data/sshippo_J2000.h5ad"
25 | res = benchmark.update_results(args.dataset, val_pct=args.val_pct, todisk=True)
26 |
--------------------------------------------------------------------------------
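For interactive use, arghandler accepts an explicit argument list instead of reading sys.argv, which is handy for testing the CLI without submitting a job. A small sketch (paths are placeholders; running it requires the repo's environment, since utils.benchmark is imported at module load):

from utils.benchmark_gof import arghandler

args = arghandler(["scrna/sshippo/data/sshippo_J2000.h5ad", "5"])
print(args.dataset)  # 'scrna/sshippo/data/sshippo_J2000.h5ad'
print(args.val_pct)  # 5 (parsed as an int, i.e. a 5% validation split)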
/utils/benchmark_gof.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=benchmark_gof # create a short name for your job
3 | #SBATCH --nodes=1 # node count
4 | #SBATCH --ntasks=1 # total number of tasks across all nodes
5 | #SBATCH --cpus-per-task=12 # cpu-cores per task (>1 if multi-threaded tasks)
6 | #SBATCH --time=3:00:00 # total run time limit (HH:MM:SS)
7 | #SBATCH --mail-type=END,FAIL
8 | #SBATCH --mail-user=ftownes@princeton.edu
9 |
10 | #example usage, --mem-per-cpu default is 4G per core
11 | #DAT=./scrna/sshippo/data/sshippo_J2000.h5ad
12 | #sbatch --mem=180G ./utils/benchmark_gof.slurm $DAT
13 |
14 | module purge
15 | module load anaconda3/2021.5
16 | conda activate fwt
17 |
18 | #first command line arg $1 is file path to dataset
19 | #second command line arg $2 is an integer percentage of the data to use as the validation set (typically 5, leaving 95% of observations for training)
20 | python -um utils.benchmark_gof $1 $2
21 |
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Tue Apr 6 18:10:06 2021
5 |
6 | @author: townesf
7 | """
8 | import pathlib
9 | import numpy as np
10 | #from pickle import dump,load
11 | from math import ceil
12 | from copy import deepcopy
13 | from dill import dump,load
14 | from pandas import read_csv as pd_read_csv
15 | from tensorflow import clip_by_value
16 | from scipy.sparse import issparse
17 | from sklearn.cluster import KMeans
18 | from sklearn.utils import sparsefuncs
19 | from anndata import AnnData
20 | from squidpy.gr import spatial_neighbors,spatial_autocorr
21 |
22 | def mkdir_p(pth):
23 | pathlib.Path(pth).mkdir(parents=True,exist_ok=True)
24 |
25 | def rm_glob(pth,glob="*"):
26 | """
27 | Remove all files in the pth folder matching glob.
28 | """
29 | for i in pathlib.Path(pth).glob(glob):
30 | i.unlink(missing_ok=True)
31 |
32 | def poisson_loss(y,mu):
33 | """
34 | Equivalent to the Tensorflow Poisson loss
35 | https://www.tensorflow.org/api_docs/python/tf/keras/losses/Poisson
36 | It's the negative log-likelihood of Poisson without the log y! constant
37 | """
38 | with np.errstate(divide='ignore',invalid='ignore'):
39 | res = mu-y*np.log(mu)
40 | return np.mean(res[np.isfinite(res)])
41 |
42 | def poisson_deviance(y,mu,agg="sum",axis=None):
43 | """
44 | Equivalent to "KL divergence" between y and mu:
45 | https://scikit-learn.org/stable/modules/decomposition.html#nmf
46 | """
47 | with np.errstate(divide='ignore',invalid='ignore'):
48 | term1 = y*np.log(y/mu)
49 | if agg=="sum":
50 | aggfunc = np.sum
51 | elif agg=="mean":
52 | aggfunc = np.mean
53 | else:
54 | raise ValueError("agg must be 'sum' or 'mean'")
55 | term1 = aggfunc(term1[np.isfinite(term1)], axis=axis)
56 | return term1 + aggfunc(mu - y, axis=axis)
57 |
58 | def rmse(y,mu):
59 | return np.sqrt(((y-mu)**2).mean())
60 |
61 | def dev2ss(dev):
62 | """
63 | dev: a 1d array of deviance values (one per feature)
64 | returns: a dictionary with summary statistics (mean, argmax, max, and median)
65 | """
66 | return {"mean":dev.mean(), "argmax":dev.argmax(), "max":dev.max(), "med":np.median(dev)}
67 |
68 | def make_nonneg(x):
69 | return clip_by_value(x,0.0,np.inf)
70 |
71 | def make_grid(N,xmin=-2,xmax=2,dtype="float32"):
72 | x = np.linspace(xmin,xmax,num=int(np.sqrt(N)),dtype=dtype)
73 | return np.stack([X.ravel() for X in np.meshgrid(x,x)],axis=1)
74 |
75 | def kmeans_inducing_pts(X,M):
76 | M = int(M)
77 | Z = np.unique(X, axis=0)
78 | unique_locs = Z.shape[0]
79 | if M<unique_locs:
80 | Z = KMeans(n_clusters=M).fit(X).cluster_centers_
81 | return Z
207 | #if row_id=1, skip=[]
208 | #if row_id=2, skip=[1]
209 | #if row_id=(file length-1), skip=[1,2,3,...,file_length-2]
210 | skip = range(1,row_id)
211 | #returns a pandas Series object
212 | return pd_read_csv(csv_file,skiprows=skip,nrows=1).iloc[0,:]
213 |
214 | def dims_autocorr(factors,coords,sort=True):
215 | """
216 | factors: (num observations) x (num latent dimensions) array
217 | coords: (num observations) x (num spatial dimensions) array
218 | sort: if True (default), returns the index and I statistics in decreasing
219 | order of autocorrelation. If False, returns the index and I statistics
220 | according to the ordering of factors.
221 |
222 | returns: an integer array of length (num latent dims), "idx"
223 | and a numpy array containing the Moran's I values for each dimension
224 |
225 | indexing factors[:,idx] will sort the factors in decreasing order of spatial
226 | autocorrelation.
227 | """
228 | ad = AnnData(X=factors,obsm={"spatial":coords})
229 | spatial_neighbors(ad)
230 | df = spatial_autocorr(ad,mode="moran",copy=True)
231 | if not sort: #revert to original sort order
232 | df.sort_index(inplace=True)
233 | idx = np.array([int(i) for i in df.index])
234 | return idx,df["I"].to_numpy()
235 |
236 | def nan_to_zero(X):
237 | X = deepcopy(X)
238 | X[np.isnan(X)] = 0
239 | return X
240 |
--------------------------------------------------------------------------------
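A small sanity check for the goodness-of-fit helpers in utils/misc.py (assuming the full environment is installed, since the module imports tensorflow and squidpy at the top). poisson_deviance sums y*log(y/mu) over the finite terms and then adds sum(mu - y); rmse is the usual root mean squared error.

import numpy as np
from utils import misc

y  = np.array([0., 2., 4.])
mu = np.array([1., 2., 2.])
# finite part of y*log(y/mu): 2*log(1) + 4*log(2); then sum(mu-y) = (1+2+2)-(0+2+4) = -1
misc.poisson_deviance(y, mu)  # 4*log(2) - 1, roughly 1.77
misc.rmse(y, mu)              # sqrt((1 + 0 + 4)/3), roughly 1.29
# dims_autocorr can then rank fitted factors by Moran's I, e.g.
# idx, moran = misc.dims_autocorr(Fplot, X); Fplot_sorted = Fplot[:, idx]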
/utils/nnfu.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Utility functions for working with nonnegative factor models. See also
5 | postprocess.py for more complex functions to facilitate interpretation.
6 |
7 | Created on Thu Sep 23 13:30:10 2021
8 |
9 | @author: townesf
10 | """
11 | import numpy as np
12 | from sklearn.decomposition import NMF
13 | # from scipy.special import logsumexp
14 | from utils.misc import lnormal_approx_dirichlet
15 |
16 | def normalize_cols(W):
17 | """
18 | Rescale the columns of a matrix to sum to one
19 | """
20 | wsum = W.sum(axis=0)
21 | return W/wsum, wsum
22 |
23 | def normalize_rows(W):
24 | """
25 | Rescale the rows of a matrix to sum to one
26 | """
27 | wsum = W.sum(axis=1)
28 | return (W.T/wsum).T, wsum
29 |
30 | def shrink_factors(F,shrinkage=0.2):
31 | a = shrinkage
32 | if 0<a<1:
143 | if Ntr>=N: Dval = None #avoid returning an empty array
144 | return Dtr,Dval
145 |
146 | def center_data(Dtr_n,Dval_n=None):
147 | Dtr_c = deepcopy(Dtr_n)
148 | feature_means=Dtr_c["Y"].mean(axis=0)
149 | Dtr_c["Y"] -= feature_means
150 | if Dval_n:
151 | Dval_c = deepcopy(Dval_n)
152 | Dval_c["Y"] -= feature_means
153 | else:
154 | Dval_c = None
155 | return feature_means,Dtr_c,Dval_c
156 |
157 | def minibatch_size_adjust(num_obs,batch_size):
158 | """
159 | Calculate adjusted minibatch size that divides
160 | num_obs as evenly as possible
161 | num_obs : number of observations in full data
162 | batch_size : maximum size of a minibatch
163 | """
164 | nbatch = ceil(num_obs/float(batch_size))
165 | return int(ceil(num_obs/nbatch))
166 |
167 | def prepare_datasets_tf(Dtrain,Dval=None,shuffle=False,batch_size=None):
168 | """
169 | Dtrain and Dval are dicts containing numpy np.arrays of data.
170 | Dtrain must contain the key "Y"
171 | Returns a batched tf.data.Dataset for Dtrain, the number of training observations, and a dict of constant tensors for Dval
172 | """
173 | Ntr = Dtrain["Y"].shape[0]
174 | if batch_size is None:
175 | #ie one batch containing all observations by default
176 | batch_size = Ntr
177 | else:
178 | batch_size = minibatch_size_adjust(Ntr,batch_size)
179 | Dtrain = Dataset.from_tensor_slices(Dtrain)
180 | if shuffle:
181 | Dtrain = Dtrain.shuffle(Ntr)
182 | Dtrain = Dtrain.batch(batch_size)
183 | if Dval is not None:
184 | Dval = {i:constant(Dval[i]) for i in Dval}
185 | return Dtrain, Ntr, Dval
186 |
187 | def load_data(dataset, model=None, lik=None, train_frac=0.95, sz="constant",
188 | flip_yaxis=True):
189 | """
190 | dataset: the file path of a scanpy anndata h5ad file
191 | --OR-- the AnnData object itself
192 | model, lik: determine which versions of the data (raw counts, normalized, centered) are prepared
193 | """
194 | try:
195 | ad = read_h5ad(path.normpath(dataset))
196 | except TypeError:
197 | ad = dataset
198 | kw1 = {"nfeat":None, "train_frac":train_frac, "dtp":"float32",
199 | "flip_yaxis":flip_yaxis}
200 | kw2 = {"shuffle":False,"batch_size":None}
201 | D = {"raw":{}}
202 | Dtr,Dval = anndata_to_train_val(ad,layer="counts",sz=sz,**kw1)
203 | D["raw"]["tr"] = Dtr
204 | D["raw"]["val"] = Dval
205 | D["raw"]["tf"] = prepare_datasets_tf(Dtr,Dval=Dval,**kw2)
206 | fmeans=None
207 | if lik is None or lik=="gau":
208 | #normalized data
209 | Dtr_n,Dval_n = anndata_to_train_val(ad,layer=None,sz="constant",**kw1)
210 | D["norm"] = {}
211 | D["norm"]["tr"] = Dtr_n
212 | D["norm"]["val"] = Dval_n
213 | D["norm"]["tf"] = prepare_datasets_tf(Dtr_n,Dval=Dval_n,**kw2)
214 | if model is None or model in ("RSF","FA"):
215 | #centered features
216 | fmeans,Dtr_c,Dval_c = center_data(Dtr_n,Dval_n)
217 | D["ctr"] = {}
218 | D["ctr"]["tr"] = Dtr_c
219 | D["ctr"]["val"] = Dval_c
220 | D["ctr"]["tf"] = prepare_datasets_tf(Dtr_c,Dval=Dval_c,**kw2)
221 | return D,fmeans
222 |
223 | # def split_data_tuple(D,train_frac=0.8,shuffle=True):
224 | # """
225 | # D is list of data [X,Y,sz]
226 | # leading dimension of each element of D must be the same
227 | # train_frac: fraction of observations that should go into training data
228 | # 1-train_frac: observations for validation data
229 | # """
230 | # n = D[0].shape[0]
231 | # idx = list(range(n))
232 | # if shuffle: random.shuffle(idx)
233 | # ntr = round(train_frac*n)
234 | # itr = idx[:ntr]
235 | # ival = idx[ntr:n]
236 | # Dtr = []
237 | # Dval = []
238 | # for d in D:
239 | # shp = d.shape
240 | # if len(shp)==1:
241 | # Dtr.append(d[itr])
242 | # Dval.append(d[ival])
243 | # else:
244 | # Dtr.append(d[itr,:])
245 | # Dval.append(d[ival,:])
246 | # return tuple(Dtr),tuple(Dval)
247 |
248 | # def calc_hidden_layer_sizes(J,T,N):
249 | # """
250 | # J is input dimension, T is output dimension, N is number of hidden layers.
251 | # Returns a tuple of length (N) containing the widths of hidden layers.
252 | # The spacing forms a linear decline from J to T
253 | # """
254 | # delta = float(J-T)/(N+1) #increment size
255 | # res = np.round(J-delta*np.array(range(1,N+1)))
256 | # #res = res.astype("int32").tolist()
257 | # #return [J]+res+[T]
258 | # return res.astype("int32")
259 |
--------------------------------------------------------------------------------
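The data-loading helpers above (anndata_to_train_val, center_data, prepare_datasets_tf, load_data) build a nested dict of training/validation splits in up to three flavors: "raw" counts, "norm" (the processed expression values in ad.X) and "ctr" (feature-centered, used by RSF and FA). A minimal sketch of how they fit together, assuming these functions live in utils/preprocess (the module sim.py also imports) and using one of the simulated datasets referenced earlier:

from utils.preprocess import load_data

D, fmeans = load_data("simulations/bm_sp/data/S1.h5ad", model="RSF", lik="gau",
                      train_frac=0.95, sz="constant")
Dtrain, Ntr, Dval = D["ctr"]["tf"]  # batched tf.data.Dataset, training set size, dict of validation tensors
# minibatch_size_adjust(1000, 128) would return 125: eight nearly equal batches
# instead of seven batches of 128 plus a ragged batch of 104.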
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Functions for visualization
5 |
6 | Created on Wed May 19 09:58:59 2021
7 |
8 | @author: townesf
9 | """
10 |
11 | import numpy as np
12 | import matplotlib.pyplot as plt
13 | from contextlib import suppress
14 | from scipy.interpolate import interp1d
15 | from scipy.spatial import Delaunay
16 |
17 | from utils.misc import poisson_deviance,dev2ss,rmse
18 |
19 | def heatmap(X,y,figsize=(6,4),bgcol="gray",cmap="turbo",**kwargs):
20 | fig,ax=plt.subplots(figsize=figsize)
21 | ax.set_facecolor(bgcol)
22 | ax.scatter(X[:,0],X[:,1],c=y,cmap=cmap,**kwargs)
23 | # fig.show()
24 |
25 | def hide_spines(ax):
26 | for side in ax.spines:
27 | ax.spines[side].set_visible(False)
28 |
29 | def color_spines(ax,col="black"):
30 | # ax.spines['top'].set_color(col)
31 | # ax.spines['right'].set_color(col)
32 | # ax.spines['bottom'].set_color(col)
33 | # ax.spines['left'].set_color(col)
34 | for side in ax.spines:
35 | ax.spines[side].set_color(col)
36 |
37 | def set_titles(fig,titles,**kwargs):
38 | for i in range(len(titles)):
39 | ax = fig.axes[i]
40 | ax.set_title(titles[i],**kwargs)
41 |
42 | def hide_axes(ax):
43 | ax.tick_params(bottom=False,left=False,labelbottom=False,labelleft=False)
44 |
45 | def multiheatmap(X, Y, grid, figsize=(6,4), cmap="turbo", bgcol="gray",
46 | axhide=True, subplot_space=None, spinecolor=None,
47 | savepath=None, **kwargs):
48 | if subplot_space is not None:
49 | gridspec_kw = {'wspace':subplot_space, 'hspace':subplot_space}
50 | else:
51 | gridspec_kw = {}
52 | fig, axgrid = plt.subplots(*grid, figsize=figsize, gridspec_kw=gridspec_kw)
53 | # if subplot_space is not None:
54 | # plt.subplots_adjust(wspace=subplot_space, hspace=subplot_space)
55 | for i in range(len(fig.axes)):
56 | ax = fig.axes[i]
57 | ax.set_facecolor(bgcol)
58 | if i<Y.shape[1]:
201 |
202 | def hull_tile(Z, N):
203 | Zg = bounding_box_tile(Z,N)
204 | return Zg[in_hull(Zg,Z)]
205 |
206 |
--------------------------------------------------------------------------------
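Finally, a self-contained sketch of the multiheatmap API, mirroring how the exploratory scripts call it (the scatter keywords marker and s are forwarded through **kwargs; the coordinate grid is built the same way as misc.make_grid, and the output filename is a placeholder):

import numpy as np
from utils import visualize

g = np.linspace(-2, 2, 20, dtype="float32")
X = np.stack([a.ravel() for a in np.meshgrid(g, g)], axis=1)  # 400 x 2 grid of coordinates
F = np.random.default_rng(0).uniform(size=(X.shape[0], 4))    # four fake factors
fig, axes = visualize.multiheatmap(X, F, (2, 2), figsize=(4, 4), cmap="turbo",
                                   marker="s", s=9, subplot_space=0.05)
fig.savefig("multiheatmap_demo.png", bbox_inches="tight")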