├── .dockerignore ├── .github └── workflows │ └── main.yml ├── .gitignore ├── Dockerfile ├── LICENSE.txt ├── README.md ├── docs ├── Makefile ├── build │ └── .gitkeep └── source │ ├── _static │ └── css │ │ └── custom.css │ ├── conf.py │ ├── index.rst │ ├── starling.rst │ ├── tutorial │ ├── getting-started.ipynb │ ├── log │ │ └── .gitkeep │ └── sample_input.h5ad │ ├── tutorials.md │ └── utility.rst ├── poetry.lock ├── pyproject.toml ├── requirements.txt ├── starling-schematic600x.png ├── starling.png ├── starling ├── __init__.py ├── starling.py └── utility.py └── tests ├── conftest.py ├── fixtures └── sample_input.h5ad ├── test_sanity.py ├── test_starling.py └── test_utility.py /.dockerignore: -------------------------------------------------------------------------------- 1 | .venv 2 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: Starling 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: [ '3.9', '3.10', '3.11', '3.12' ] 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: Install Poetry 18 | with: 19 | virtualenvs-create: true 20 | virtualenvs-in-project: true 21 | uses: snok/install-poetry@v1 22 | - uses: actions/setup-python@v4 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | cache: 'poetry' 26 | - name: Install project 27 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 28 | run: poetry install --with docs,dev --no-interaction 29 | - uses: isort/isort-action@v1 30 | - name: Black 31 | run: | 32 | poetry run black --check --verbose ./starling 33 | - name: Pytest 34 | run: | 35 | poetry run pytest 36 | 37 | docs: 38 | runs-on: ubuntu-latest 39 | steps: 40 | - uses: actions/checkout@v3 41 | - name: Install Poetry 42 | with: 43 | virtualenvs-create: true 44 | virtualenvs-in-project: true 45 | uses: snok/install-poetry@v1 46 | - uses: actions/setup-python@v4 47 | with: 48 | python-version: '3.12' 49 | cache: 'poetry' 50 | - name: Install project 51 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 52 | run: poetry install --with docs,dev --no-interaction 53 | - name: Build documentation 54 | run: | 55 | mkdir gh-pages 56 | touch gh-pages/.nojekyll 57 | cd docs/source 58 | poetry run sphinx-build -b html . 
../build 59 | cp -r ../build/* ../../gh-pages/ 60 | - name: Deploy documentation 61 | if: ${{ github.event_name == 'push' }} 62 | uses: JamesIves/github-pages-deploy-action@4.1.4 63 | with: 64 | branch: gh-pages 65 | folder: gh-pages 66 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .venv 3 | lightning_logs 4 | pyvenv.cfg 5 | **/log/** 6 | !**/log/.gitkeep 7 | docs/build/** 8 | dist 9 | !docs/build/.gitkeep 10 | model.pt 11 | .pytest_cache 12 | .vscode 13 | jupyter_execute 14 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.12-slim 2 | 3 | WORKDIR /code 4 | 5 | ARG USERNAME=starling 6 | ARG USER_UID=1000 7 | ARG USER_GID=$USER_UID 8 | 9 | RUN apt-get update && \ 10 | apt-get install -y build-essential \ 11 | libssl-dev \ 12 | libffi-dev \ 13 | python3-dev \ 14 | git && \ 15 | rm -rf /var/lib/apt/lists/* 16 | 17 | RUN groupadd --gid $USER_GID $USERNAME \ 18 | && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \ 19 | && chown -R ${USER_UID}:${USER_GID} /code 20 | 21 | # Install Poetry 22 | # https://github.com/python-poetry/poetry/issues/6397#issuecomment-1236327500 23 | 24 | ENV POETRY_HOME=/opt/poetry 25 | 26 | # install poetry into its own venv 27 | RUN python3 -m venv $POETRY_HOME && \ 28 | $POETRY_HOME/bin/pip install poetry==1.8.0 29 | 30 | ENV VIRTUAL_ENV=/poetry-env \ 31 | PATH="/poetry-env/bin:$POETRY_HOME/bin:$PATH" 32 | 33 | RUN python3 -m venv $VIRTUAL_ENV && \ 34 | chown -R $USER_UID:$USER_GID $POETRY_HOME /poetry-env 35 | 36 | USER $USERNAME 37 | 38 | # prevent full rebuilds every time code changes 39 | COPY --chown=${USER_UID}:${USER_GID} pyproject.toml poetry.lock README.md /code/ 40 | COPY --chown=${USER_UID}:${USER_GID} starling/__init__.py /code/starling/__init__.py 41 | 42 | RUN poetry install --with docs,dev && poetry self add poetry-plugin-export 43 | 44 | COPY --chown=${USER_UID}:${USER_GID} . . 45 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright Notice 2 | 3 | Copyright ©2024 Sinai Health System, Toronto, Canada. All Rights Reserved. 4 | License 5 | 6 | This software is dual licensed under the Apache License, Version 2.0 (the "Apache License") or a Commercial Use License. 7 | 8 | APACHE LICENSE - FOR RESEARCH OR PERSONAL USE ONLY 9 | Under the terms of the Apache License, Version 2.0, you may obtain a copy of the license at http://www.apache.org/licenses/LICENSE-2.0. 10 | Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | See the License for the specific language governing permissions and limitations under the License. 12 | 13 | COMMERCIAL USE LICENSE - IF SOFTWARE IS INCORPORATED INTO PRODUCTS, SERVICES OR OTHER COMMERCIAL ENDEAVORS 14 | For users who wish to incorporate this software into commercial products and services for sale and/or further distribution, a separate Commercial Use License is required. 
For more information about a Commercial Use license, please contact: 15 | Director, Technology Transfer, 16 | Office of Technology Transfer & Industrial Liaison 17 | homonko@lunenfeld.ca -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SegmentaTion AwaRe cLusterING (STARLING) 2 | 3 | ![build](https://github.com/camlab-bioml/starling/actions/workflows/main.yml/badge.svg) 4 | ![](https://img.shields.io/badge/Python-3.9-blue) 5 | ![](https://img.shields.io/badge/Python-3.10-blue) 6 | ![](https://img.shields.io/badge/Python-3.11-blue) 7 | ![](https://img.shields.io/badge/Python-3.12-blue) 8 | 9 | 10 | STARLING is a probabilistic model for clustering cells measured with spatial expression assays (e.g. IMC, MIBI) while accounting for segmentation errors. 11 | 12 | It outputs: 13 | 1. Clusters that account for segmentation errors in the data (i.e. should no longer show implausible marker co-expression) 14 | 2. Assignments for every cell in the dataset to those clusters 15 | 3. A segmentation error probability for each cell 16 | 17 | The paper describing the method and introducing evaluation metrics and gold standard data is available: [Lee et al. Nature Communications (2025) _Segmentation aware probabilistic phenotyping of single-cell spatial protein expression data_](https://www.nature.com/articles/s41467-024-55214-w) 18 | 19 | A **tutorial** outlining basic usage is available [here][tutorial]. 20 | 21 | ![Model](https://github.com/camlab-bioml/starling/raw/main/starling-schematic600x.png) 22 | 23 | ## Requirements 24 | 25 | Python 3.9–3.12 is required to run starling. If your current version of Python is not one of these, we recommend using [pyenv](https://github.com/pyenv/pyenv) to install a compatible version alongside your current one. Alternatively, you could use the Docker configuration described below. 26 | 27 | ## Installation 28 | 29 | ### Install with pip 30 | 31 | `pip install biostarling`, then import the module: `from starling import starling` 32 | 33 | ### Building from source 34 | 35 | Starling can be cloned and installed locally (typically <10 minutes) via the GitHub repository: 36 | 37 | ``` 38 | git clone https://github.com/camlab-bioml/starling.git && cd starling 39 | ``` 40 | 41 | After cloning the repository, the next step is to install the required dependencies. There are three recommended methods: 42 | 43 | ### 1. Use `requirements.txt` and your own virtual environment: 44 | 45 | We use virtualenvwrapper (4.8.4) to create and activate a standalone virtual environment for _starling_: 46 | 47 | ``` 48 | pip install virtualenvwrapper==4.8.4 49 | mkvirtualenv starling 50 | ``` 51 | 52 | For convenience, one can then install the pinned package versions from our tested environment: 53 | 54 | ``` 55 | pip install -r requirements.txt 56 | ``` 57 | 58 | The virtual environment can be activated and deactivated subsequently: 59 | 60 | ``` 61 | workon starling 62 | deactivate 63 | ``` 64 | 65 | ### 2. Use Poetry and `pyproject.toml` 66 | 67 | [Poetry](https://python-poetry.org/) is a packaging and dependency management tool that can simplify code development and deployment. If you do not have Poetry installed, you can find instructions [here](https://python-poetry.org/docs/). 68 | 69 | Once poetry is installed, navigate to the `starling` directory and run `poetry install`. 
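(To additionally install the optional docs and dev dependency groups defined in `pyproject.toml`, run `poetry install --with docs,dev` instead; this is the same invocation the CI workflow uses.) 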
This will download the required packages into a virtual environment and install Starling in development mode. The location and state of the virtual environment may depend on your system. For more details, see [the documentation](https://python-poetry.org/docs/managing-environments/). 70 | 71 | 72 | ### 3. Use Docker 73 | 74 | If you have Docker installed on your system, you can run `docker build -t starling .` from the project root in order to build the image locally. You can then open a shell within the image with a command like `docker run --rm -it starling bash`. 75 | 76 | ## Getting started 77 | 78 | With starling installed, please proceed to the [online documentation][docs] or launch the [interactive notebook tutorial][tutorial] to learn more about the package's features. 79 | 80 | ## Authors 81 | 82 | This software is authored by: Jett (Yuju) Lee, Conor Klamann, Kieran R Campbell 83 | 84 | Lunenfeld-Tanenbaum Research Institute & University of Toronto 85 | 86 | 87 | 88 | [tutorial]: https://colab.research.google.com/github/camlab-bioml/starling/blob/main/docs/source/tutorial/getting-started.ipynb 89 | [license]: https://github.com/camlab-bioml/starling/blob/main/LICENSE 90 | [docs]: https://camlab-bioml.github.io/starling/ 91 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/build/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camlab-bioml/starling/64f7294972eb14e09fe0a642c64fd26844cfbe8a/docs/build/.gitkeep -------------------------------------------------------------------------------- /docs/source/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | /* Prevent wide tables from overflowing */ 2 | .cell.docutils.container { 3 | overflow-x: auto; 4 | } 5 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | project = "Starling" 2 | copyright = "2024, contribs" 3 | author = "contribs" 4 | 5 | extensions = [ 6 | "autodocsumm", 7 | "myst_nb", 8 | "sphinx_autodoc_typehints", 9 | "sphinx_rtd_theme", 10 | "sphinx.ext.autodoc", 11 | "sphinx.ext.intersphinx", 12 | "sphinx.ext.mathjax", 13 | "sphinx.ext.viewcode", 14 | ] 15 | 16 | nb_custom_formats = { 17 | ".md": ["jupytext.reads", {"fmt": "mystnb"}], 18 | } 19 | 20 | source_suffix = [".rst", ".md"] 21 | 22 | nb_execution_timeout = -1 23 | 24 | autodoc_typehints = "description" 25 | 26 | html_theme = "sphinx_rtd_theme" 27 | html_static_path = ["_static"] 28 | html_css_files = [ 29 | "css/custom.css", 30 | ] 31 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | 2 | .. toctree:: 3 | :caption: Getting Started 4 | :hidden: 5 | 6 | tutorials 7 | 8 | .. toctree:: 9 | :caption: Package Reference 10 | :hidden: 11 | 12 | starling 13 | utility 14 | 15 | 16 | .. include:: ../../README.md 17 | :parser: myst_parser.sphinx_ 18 | -------------------------------------------------------------------------------- /docs/source/starling.rst: -------------------------------------------------------------------------------- 1 | Starling Class (ST) 2 | =================== 3 | 4 | .. 
autoclass:: starling.starling.ST 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | -------------------------------------------------------------------------------- /docs/source/tutorial/getting-started.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b3e79fa9", 6 | "metadata": {}, 7 | "source": [ 8 | "# Getting started with Starling (ST)\n" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "2a06e71b", 15 | "metadata": { 16 | "tags": [ 17 | "hide-output" 18 | ] 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "%pip install biostarling\n", 23 | "%pip install lightning_lite\n", 24 | "\n", 25 | "import anndata as ad\n", 26 | "import pandas as pd\n", 27 | "import torch\n", 28 | "from starling import starling, utility\n", 29 | "from lightning_lite import seed_everything\n", 30 | "import pytorch_lightning as pl\n", 31 | "\n" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "id": "b615eb39", 37 | "metadata": {}, 38 | "source": [ 39 | "## Setting seed for everything\n" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "id": "e83f4cce", 46 | "metadata": { 47 | "tags": [ 48 | "hide-output" 49 | ] 50 | }, 51 | "outputs": [], 52 | "source": [ 53 | "seed_everything(10, workers=True)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "id": "d1f5142d", 59 | "metadata": {}, 60 | "source": [ 61 | "## Loading AnnData objects\n" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "id": "69ef8c1d", 67 | "metadata": {}, 68 | "source": [ 69 | "The example below runs KMeans with 10 clusters on data read from the \"sample_input.h5ad\" object." 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "id": "8c4a203f", 76 | "metadata": { 77 | "tags": [ 78 | "hide-output" 79 | ] 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "!wget https://github.com/camlab-bioml/starling/raw/main/docs/source/tutorial/sample_input.h5ad\n", 84 | "\n", 85 | "adata = utility.init_clustering(\"KM\", ad.read_h5ad(\"sample_input.h5ad\"), k=10)\n" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "id": "52d3d9fb", 91 | "metadata": {}, 92 | "source": [ 93 | "- The input anndata object should contain a cell-by-protein matrix of segmented single-cell expression profiles in the `.X` position. Optionally, cell size information can also be provided as a column of the `.obs` DataFrame. 
In this case `model_cell_size` should be set to `True` and the column specified in the `cell_size_col_name` argument.\n", 94 | "- Users may want to arcsinh-transform the protein expression values in the \*.h5ad file (for example, `sample_input.h5ad`).\n", 95 | "- The `utility.py` module provides an easy setup of GMM (Gaussian Mixture Model), KM (KMeans) or PG (PhenoGraph).\n", 96 | "- Default settings are applied to each method.\n", 97 | "- k can be omitted when PG is used.\n" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "id": "7fd11c15", 103 | "metadata": {}, 104 | "source": [ 105 | "## Setting initializations\n" 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "id": "6effd2b9", 111 | "metadata": {}, 112 | "source": [ 113 | "The example below uses default parameter settings based on benchmarking results (more details in the manuscript).\n" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "id": "eff9a063", 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "st = starling.ST(adata)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "id": "923d2e71", 129 | "metadata": {}, 130 | "source": [ 131 | "A list of parameters is shown:\n", 132 | "\n", 133 | "- adata: AnnData object of the sample\n", 134 | "- dist_option (default: 'T'): T for Student-T (df=2) and N for Normal (Gaussian)\n", 135 | "- singlet_prop (default: 0.6): the proportion of anticipated segmentation-error-free cells\n", 136 | "- model_cell_size (default: True): True to incorporate cell size in the model and False otherwise\n", 137 | "- cell_size_col_name (default: 'area'): the name of the column in the anndata.obs dataframe holding cell size\n", 138 | "- model_zplane_overlap (default: True): True to model z-plane overlap when cell size is modelled and False otherwise\n", 139 | " Note: if the user sets model_cell_size = False, then model_zplane_overlap is ignored\n", 140 | "- model_regularizer (default: 1): Regularizer term imposed on the synthetic doublet loss (BCE)\n", 141 | "- learning_rate (default: 1e-3): The learning rate of the ADAM optimizer for STARLING\n", 142 | "\n", 143 | "Equivalent to the above example:\n", 144 | "```python\n", 145 | "st = starling.ST(adata, 'T', 0.6, True, 'area', True, 1, 1e-3)\n", 146 | "```\n" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "id": "63939215", 152 | "metadata": {}, 153 | "source": [ 154 | "## Setting training log\n" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "id": "d721258f", 160 | "metadata": {}, 161 | "source": [ 162 | "Once training starts, a new directory 'log' will be created." 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "id": "a217070c", 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "## log training results via tensorboard\n", 173 | "log_tb = pl.loggers.TensorBoardLogger(save_dir=\"log\")" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "id": "ae8e46ea", 179 | "metadata": {}, 180 | "source": [ 181 | "One could view the training information via tensorboard. 
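For example, after training one could run `tensorboard --logdir log` from the notebook's directory (or use the notebook extension: `%load_ext tensorboard` followed by `%tensorboard --logdir log`) to browse the logged losses. 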
Please refer to PyTorch Lightning (https://lightning.ai/docs/pytorch/stable/api_references.html#profiler) for other possible loggers.\n" 182 | ] 183 | }, 184 | { 185 | "cell_type": "markdown", 186 | "id": "914bcd5c", 187 | "metadata": {}, 188 | "source": [ 189 | "## Setting early stopping criterion\n" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": null, 195 | "id": "90877a9c", 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "## set early stopping criterion\n", 200 | "cb_early_stopping = pl.callbacks.EarlyStopping(monitor=\"train_loss\", mode=\"min\", verbose=False)" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "id": "ac4c7459", 206 | "metadata": {}, 207 | "source": [ 208 | "Training loss is monitored.\n" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "id": "bb32a46b", 214 | "metadata": {}, 215 | "source": [ 216 | "## Training Starling\n" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "id": "8f49c63c", 223 | "metadata": { 224 | "tags": [ 225 | "hide-output" 226 | ] 227 | }, 228 | "outputs": [], 229 | "source": [ 230 | "## train ST\n", 231 | "st.train_and_fit(\n", 232 | " callbacks=[cb_early_stopping],\n", 233 | " logger=[log_tb],\n", 234 | ")" 235 | ] 236 | }, 237 | { 238 | "cell_type": "markdown", 239 | "id": "3ba887b2", 240 | "metadata": {}, 241 | "source": [ 242 | "## Appending STARLING results to the AnnData object\n" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "id": "3082c69a", 249 | "metadata": { 250 | "scrolled": true 251 | }, 252 | "outputs": [], 253 | "source": [ 254 | "## retrieve starling results\n", 255 | "result = st.result()" 256 | ] 257 | }, 258 | { 259 | "cell_type": "markdown", 260 | "id": "a705d895", 261 | "metadata": {}, 262 | "source": [ 263 | "## The following information can be retrieved from the AnnData object:\n", 264 | "\n", 265 | "- st.adata.varm['init_exp_centroids'] -- initial expression cluster centroids (P x C matrix)\n", 266 | "- st.adata.varm['st_exp_centroids'] -- ST expression cluster centroids (P x C matrix)\n", 267 | "- st.adata.uns['init_cell_size_centroids'] -- initial cell size centroids if STARLING models cell size\n", 268 | "- st.adata.uns['st_cell_size_centroids'] -- ST cell size centroids if ST models cell size\n", 269 | "- st.adata.obsm['assignment_prob_matrix'] -- cell assignment probabilities (N x C matrix)\n", 270 | "- st.adata.obsm['gamma_assignment_prob_matrix'] -- gamma probability of two cells (N x C x C matrix)\n", 271 | "- st.adata.obs['doublet'] -- doublet indicator\n", 272 | "- st.adata.obs['doublet_prob'] -- doublet probabilities\n", 273 | "- st.adata.obs['init_label'] -- initial assignments\n", 274 | "- st.adata.obs['st_label'] -- ST assignments\n", 275 | "- st.adata.obs['max_assign_prob'] -- ST max probabilities of assignments\n", 276 | "\n", 277 | "_N: # of cells; C: # of clusters; P: # of proteins_\n" 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "id": "4ab8cb0a", 283 | "metadata": {}, 284 | "source": [ 285 | "## Saving the model\n" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": null, 291 | "id": "204cad47", 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [ 295 | "## st object can be saved\n", 296 | "torch.save(st, \"model.pt\")" 297 | ] 298 | }, 299 | { 300 | "cell_type": "markdown", 301 | "id": "980dad28", 302 | "metadata": {}, 303 | "source": [ 304 | "model.pt will be saved in the same directory as this 
notebook.\n" 305 | ] 306 | }, 307 | { 308 | "cell_type": "markdown", 309 | "id": "ad7e5fc0", 310 | "metadata": {}, 311 | "source": [ 312 | "## Showing STARLING results\n" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": null, 318 | "id": "c7e67d1d", 319 | "metadata": { 320 | "tags": [ 321 | "scroll-output" 322 | ] 323 | }, 324 | "outputs": [], 325 | "source": [ 326 | "display(result)" 327 | ] 328 | }, 329 | { 330 | "cell_type": "markdown", 331 | "id": "53e32d26", 332 | "metadata": {}, 333 | "source": [ 334 | "One could easily perform further analysis such as co-occurrence, enrichment analysis, etc.\n" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": null, 340 | "id": "b601be72", 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "result.obs" 345 | ] 346 | }, 347 | { 348 | "cell_type": "markdown", 349 | "id": "af541283", 350 | "metadata": {}, 351 | "source": [ 352 | "For each cell, Starling provides a doublet probability and the cluster assignment it would receive were it a singlet.\n" 353 | ] 354 | }, 355 | { 356 | "cell_type": "markdown", 357 | "id": "80e61208", 358 | "metadata": {}, 359 | "source": [ 360 | "## Showing initial expression centroids:\n" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": null, 366 | "id": "a2be0fcc", 367 | "metadata": {}, 368 | "outputs": [], 369 | "source": [ 370 | "## initial expression centroids (p x c) matrix\n", 371 | "pd.DataFrame(result.varm[\"init_exp_centroids\"], index=result.var_names)" 372 | ] 373 | }, 374 | { 375 | "cell_type": "markdown", 376 | "id": "03424211", 377 | "metadata": {}, 378 | "source": [ 379 | "There are 10 centroids since we set k = 10 for KMeans (KM) earlier.\n" 380 | ] 381 | }, 382 | { 383 | "cell_type": "markdown", 384 | "id": "f0bc41a8", 385 | "metadata": {}, 386 | "source": [ 387 | "## Showing Starling expression centroids:\n" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": null, 393 | "id": "a11a5334", 394 | "metadata": {}, 395 | "outputs": [], 396 | "source": [ 397 | "## starling expression centroids (p x c) matrix\n", 398 | "pd.DataFrame(result.varm[\"st_exp_centroids\"], index=result.var_names)" 399 | ] 400 | }, 401 | { 402 | "cell_type": "markdown", 403 | "id": "a2cccf9d", 404 | "metadata": {}, 405 | "source": [ 406 | "From here one could easily annotate cluster centroids with cell types.\n" 407 | ] 408 | }, 409 | { 410 | "cell_type": "markdown", 411 | "id": "993eb08b", 412 | "metadata": {}, 413 | "source": [ 414 | "## Showing Assignment Distributions:\n" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": null, 420 | "id": "75f8b562", 421 | "metadata": {}, 422 | "outputs": [], 423 | "source": [ 424 | "## assignment distributions (n x c matrix)\n", 425 | "pd.DataFrame(result.obsm[\"assignment_prob_matrix\"], index=result.obs.index)" 426 | ] 427 | }, 428 | { 429 | "cell_type": "markdown", 430 | "id": "b203933c", 431 | "metadata": {}, 432 | "source": [ 433 | "Currently, we assign a cell label based on the maximum probability among all possible clusters. However, cells could be mislabeled when the maximum and second-highest probabilities are very close."
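, "\n", "\n", "For example, one could flag such borderline cells by the margin between the two largest assignment probabilities (a sketch rather than part of STARLING; the 0.05 cutoff is an arbitrary illustration):\n", "\n", "```python\n", "import numpy as np\n", "\n", "probs = result.obsm[\"assignment_prob_matrix\"]\n", "top2 = np.sort(probs, axis=1)[:, -2:]  # two largest probabilities per cell\n", "margin = top2[:, 1] - top2[:, 0]  # small margin = ambiguous label\n", "ambiguous_cells = result.obs.index[margin < 0.05]\n", "```\n"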
434 | ] 435 | } 436 | ], 437 | "metadata": { 438 | "kernelspec": { 439 | "display_name": "Python 3", 440 | "language": "python", 441 | "name": "python3" 442 | }, 443 | "language_info": { 444 | "codemirror_mode": { 445 | "name": "ipython", 446 | "version": 3 447 | }, 448 | "file_extension": ".py", 449 | "mimetype": "text/x-python", 450 | "name": "python", 451 | "nbconvert_exporter": "python", 452 | "pygments_lexer": "ipython3", 453 | "version": "3.9.15" 454 | } 455 | }, 456 | "nbformat": 4, 457 | "nbformat_minor": 5 458 | } 459 | -------------------------------------------------------------------------------- /docs/source/tutorial/log/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camlab-bioml/starling/64f7294972eb14e09fe0a642c64fd26844cfbe8a/docs/source/tutorial/log/.gitkeep -------------------------------------------------------------------------------- /docs/source/tutorial/sample_input.h5ad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camlab-bioml/starling/64f7294972eb14e09fe0a642c64fd26844cfbe8a/docs/source/tutorial/sample_input.h5ad -------------------------------------------------------------------------------- /docs/source/tutorials.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Tutorials 4 | 5 | ```{toctree} 6 | :maxdepth: 2 7 | :hidden: 8 | 9 | tutorial/getting-started.ipynb 10 | ``` 11 | 12 | 13 | - [Getting started with Starling (ST)](tutorial/getting-started.ipynb) 14 | -------------------------------------------------------------------------------- /docs/source/utility.rst: -------------------------------------------------------------------------------- 1 | starling.utility 2 | ================ 3 | 4 | Module contents 5 | --------------- 6 | 7 | .. 
automodule:: starling.utility 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "biostarling" 3 | packages = [{ include = "starling" }] 4 | version = "0.1.4" 5 | description = "Segmentation error aware clustering single-cell spatial expression data" 6 | repository = "https://github.com/camlab-bioml/starling" 7 | authors = ["Jett (Yuju) Lee "] 8 | readme = "README.md" 9 | keywords = ["imaging cytometry classifier single-cell"] 10 | classifiers = [ 11 | "Intended Audience :: Science/Research", 12 | "Programming Language :: Python :: 3.9", 13 | "Programming Language :: Python :: 3.10", 14 | "Programming Language :: Python :: 3.11", 15 | "Programming Language :: Python :: 3.12", 16 | ] 17 | license = "See License.txt" 18 | 19 | [tool.poetry.dependencies] 20 | python = ">= 3.9, < 3.13" 21 | phenograph = "^1.5.7" 22 | flowsom = "^0.1.1" 23 | numpy = "^1.26" 24 | pandas = ">= 0.23.0" 25 | pytorch-lightning = "^2.3.3" 26 | scanpy = "^1.10.2" 27 | torch = "^2.4.0" 28 | 29 | [tool.poetry.group.dev] 30 | optional = true 31 | 32 | [tool.poetry.group.dev.dependencies] 33 | black = "^23.12.1" 34 | pytest = "^7.4.4" 35 | isort = "^5.13.2" 36 | twine = "^5.1.0" 37 | build = "^1.2.1" 38 | 39 | [tool.poetry.group.docs] 40 | optional = true 41 | 42 | [tool.poetry.group.docs.dependencies] 43 | tensorboard = "^2.15.1" 44 | lightning-lite = "^1.8.6" 45 | ipykernel = "^6.29.0" 46 | sphinx = "^7.2.6" 47 | sphinx-autodoc-typehints = "^2.0.0" 48 | sphinx-rtd-theme = "^2.0.0" 49 | autodocsumm = "^0.2.12" 50 | docutils = "^0.20.1" 51 | myst-nb = "^1.0.0" 52 | jupytext = "^1.16.1" 53 | autodoc = "^0.5.0" 54 | 55 | 56 | [tool.pytest.ini_options] 57 | filterwarnings = [ 58 | "ignore::DeprecationWarning", 59 | "ignore::FutureWarning", 60 | "ignore::UserWarning", 61 | ] 62 | 63 | [tool.isort] 64 | profile = "black" 65 | 66 | [build-system] 67 | requires = ["poetry-core"] 68 | build-backend = "poetry.core.masonry.api" 69 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohappyeyeballs==2.3.5 ; python_version >= "3.9" and python_version < "3.13" 2 | aiohttp==3.10.3 ; python_version >= "3.9" and python_version < "3.13" 3 | aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13" 4 | anndata==0.10.8 ; python_version >= "3.9" and python_version < "3.13" 5 | array-api-compat==1.8 ; python_version >= "3.9" and python_version < "3.13" 6 | async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.11" 7 | attrs==24.2.0 ; python_version >= "3.9" and python_version < "3.13" 8 | colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Windows" 9 | contourpy==1.2.1 ; python_version >= "3.9" and python_version < "3.13" 10 | cycler==0.12.1 ; python_version >= "3.9" and python_version < "3.13" 11 | decorator==5.1.1 ; python_version >= "3.9" and python_version < "3.13" 12 | exceptiongroup==1.2.2 ; python_version >= "3.9" and python_version < "3.11" 13 | fcsparser==0.2.8 ; python_version >= "3.9" and python_version < "3.13" 14 | filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13" 15 | flowcytometrytools==0.5.1 ; python_version >= "3.9" and python_version < "3.13" 16 | flowsom==0.1.1 ; python_version 
>= "3.9" and python_version < "3.13" 17 | fonttools==4.53.1 ; python_version >= "3.9" and python_version < "3.13" 18 | frozenlist==1.4.1 ; python_version >= "3.9" and python_version < "3.13" 19 | fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13" 20 | fsspec[http]==2024.6.1 ; python_version >= "3.9" and python_version < "3.13" 21 | get-annotations==0.1.2 ; python_version >= "3.9" and python_version < "3.10" 22 | h5py==3.11.0 ; python_version >= "3.9" and python_version < "3.13" 23 | idna==3.7 ; python_version >= "3.9" and python_version < "3.13" 24 | igraph==0.11.6 ; python_version >= "3.9" and python_version < "3.13" 25 | importlib-resources==6.4.0 ; python_version >= "3.9" and python_version < "3.10" 26 | jinja2==3.1.4 ; python_version >= "3.9" and python_version < "3.13" 27 | joblib==1.4.2 ; python_version >= "3.9" and python_version < "3.13" 28 | kiwisolver==1.4.5 ; python_version >= "3.9" and python_version < "3.13" 29 | legacy-api-wrap==1.4 ; python_version >= "3.9" and python_version < "3.13" 30 | leidenalg==0.10.2 ; python_version >= "3.9" and python_version < "3.13" 31 | lightning-utilities==0.11.6 ; python_version >= "3.9" and python_version < "3.13" 32 | llvmlite==0.43.0 ; python_version >= "3.9" and python_version < "3.13" 33 | markupsafe==2.1.5 ; python_version >= "3.9" and python_version < "3.13" 34 | matplotlib==3.9.1.post1 ; python_version >= "3.9" and python_version < "3.13" 35 | minisom==2.3.2 ; python_version >= "3.9" and python_version < "3.13" 36 | mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13" 37 | multidict==6.0.5 ; python_version >= "3.9" and python_version < "3.13" 38 | natsort==8.4.0 ; python_version >= "3.9" and python_version < "3.13" 39 | networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13" 40 | numba==0.60.0 ; python_version >= "3.9" and python_version < "3.13" 41 | numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" 42 | nvidia-cublas-cu12==12.1.3.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13" 43 | nvidia-cuda-cupti-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13" 44 | nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13" 45 | nvidia-cuda-runtime-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13" 46 | nvidia-cudnn-cu12==9.1.0.70 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13" 47 | nvidia-cufft-cu12==11.0.2.54 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13" 48 | nvidia-curand-cu12==10.3.2.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13" 49 | nvidia-cusolver-cu12==11.4.5.107 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13" 50 | nvidia-cusparse-cu12==12.1.0.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13" 51 | nvidia-nccl-cu12==2.20.5 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13" 52 | nvidia-nvjitlink-cu12==12.6.20 ; platform_system == 
"Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13" 53 | nvidia-nvtx-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13" 54 | packaging==24.1 ; python_version >= "3.9" and python_version < "3.13" 55 | pandas==2.2.2 ; python_version >= "3.9" and python_version < "3.13" 56 | patsy==0.5.6 ; python_version >= "3.9" and python_version < "3.13" 57 | phenograph==1.5.7 ; python_version >= "3.9" and python_version < "3.13" 58 | pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" 59 | psutil==6.0.0 ; python_version >= "3.9" and python_version < "3.13" 60 | pynndescent==0.5.13 ; python_version >= "3.9" and python_version < "3.13" 61 | pyparsing==3.1.2 ; python_version >= "3.9" and python_version < "3.13" 62 | python-dateutil==2.9.0.post0 ; python_version >= "3.9" and python_version < "3.13" 63 | pytorch-lightning==2.4.0 ; python_version >= "3.9" and python_version < "3.13" 64 | pytz==2024.1 ; python_version >= "3.9" and python_version < "3.13" 65 | pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13" 66 | scanpy==1.10.2 ; python_version >= "3.9" and python_version < "3.13" 67 | scikit-learn==1.5.1 ; python_version >= "3.9" and python_version < "3.13" 68 | scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" 69 | seaborn==0.13.2 ; python_version >= "3.9" and python_version < "3.13" 70 | session-info==1.0.0 ; python_version >= "3.9" and python_version < "3.13" 71 | setuptools==72.1.0 ; python_version >= "3.9" and python_version < "3.13" 72 | six==1.16.0 ; python_version >= "3.9" and python_version < "3.13" 73 | statsmodels==0.14.2 ; python_version >= "3.9" and python_version < "3.13" 74 | stdlib-list==0.10.0 ; python_version >= "3.9" and python_version < "3.13" 75 | sympy==1.13.2 ; python_version >= "3.9" and python_version < "3.13" 76 | texttable==1.7.0 ; python_version >= "3.9" and python_version < "3.13" 77 | threadpoolctl==3.5.0 ; python_version >= "3.9" and python_version < "3.13" 78 | torch==2.4.0 ; python_version >= "3.9" and python_version < "3.13" 79 | torchmetrics==1.4.1 ; python_version >= "3.9" and python_version < "3.13" 80 | tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13" 81 | triton==3.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.13" and python_version >= "3.9" 82 | typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" 83 | tzdata==2024.1 ; python_version >= "3.9" and python_version < "3.13" 84 | umap-learn==0.5.6 ; python_version >= "3.9" and python_version < "3.13" 85 | yarl==1.9.4 ; python_version >= "3.9" and python_version < "3.13" 86 | zipp==3.20.0 ; python_version >= "3.9" and python_version < "3.10" 87 | -------------------------------------------------------------------------------- /starling-schematic600x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camlab-bioml/starling/64f7294972eb14e09fe0a642c64fd26844cfbe8a/starling-schematic600x.png -------------------------------------------------------------------------------- /starling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camlab-bioml/starling/64f7294972eb14e09fe0a642c64fd26844cfbe8a/starling.png -------------------------------------------------------------------------------- /starling/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/camlab-bioml/starling/64f7294972eb14e09fe0a642c64fd26844cfbe8a/starling/__init__.py -------------------------------------------------------------------------------- /starling/starling.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | from typing import Dict, Iterable, List, Optional, Union 3 | 4 | import numpy as np 5 | import pytorch_lightning as pl 6 | import torch 7 | from anndata import AnnData 8 | from lightning_fabric.connector import _PRECISION_INPUT 9 | from lightning_fabric.utilities.types import _PATH 10 | from pytorch_lightning.accelerators.accelerator import Accelerator 11 | from pytorch_lightning.callbacks import Callback 12 | from pytorch_lightning.loggers import Logger 13 | from pytorch_lightning.plugins import _PLUGIN_INPUT 14 | from pytorch_lightning.profilers import Profiler 15 | from pytorch_lightning.strategies.strategy import Strategy 16 | from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN 17 | from torch.utils.data import DataLoader 18 | 19 | from starling import utility 20 | 21 | BATCH_SIZE = 512 22 | AVAIL_GPUS = min(1, torch.cuda.device_count()) 23 | DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 24 | 25 | 26 | class ST(pl.LightningModule): 27 | """The STARLING module 28 | 29 | :param adata: The sample to be analyzed, with clusters and annotations from :py:func:`starling.utility.init_clustering` 30 | :param dist_option: The distribution to use, one of 'T' for Student-T (df=2) or 'N' for Normal (Gaussian), defaults to T 31 | :param singlet_prop: The proportion of anticipated segmentation-error-free cells 32 | :param model_cell_size: Whether STARLING should incorporate cell size in the model 33 | :param cell_size_col_name: The column name in ``AnnData`` (anndata.obs). Required only if ``model_cell_size`` is ``True``, 34 | otherwise ignored. 
35 | :param model_zplane_overlap: If cell size is modelled, should STARLING model z-plane overlap 36 | :param model_regularizer: Regularizer term imposed on the synthetic doublet loss (BCE) 37 | :param learning_rate: Learning rate of the ADAM optimizer for STARLING 38 | 39 | """ 40 | 41 | def __init__( 42 | self, 43 | adata: AnnData, 44 | dist_option: str = "T", 45 | singlet_prop: float = 0.6, 46 | model_cell_size: bool = True, 47 | cell_size_col_name: str = "area", 48 | model_zplane_overlap: bool = True, 49 | model_regularizer: float = 1.0, 50 | learning_rate: float = 1e-3, 51 | ): 52 | super().__init__() 53 | 54 | # self.save_hyperparameters() 55 | 56 | utility.validate_starling_arguments( 57 | adata, 58 | dist_option, 59 | singlet_prop, 60 | model_cell_size, 61 | cell_size_col_name, 62 | model_zplane_overlap, 63 | model_regularizer, 64 | learning_rate, 65 | ) 66 | 67 | self.adata = adata 68 | self.dist_option = dist_option 69 | self.singlet_prop = singlet_prop 70 | self.model_cell_size = model_cell_size 71 | self.cell_size_col_name = cell_size_col_name 72 | self.model_zplane_overlap = model_zplane_overlap 73 | self.model_regularizer = model_regularizer 74 | self.learning_rate = learning_rate 75 | 76 | self.X = torch.tensor(self.adata.X) 77 | self.S = ( 78 | torch.tensor(self.adata.obs[self.cell_size_col_name]) 79 | if self.model_cell_size 80 | else None 81 | ) 82 | 83 | def forward( 84 | self, batch: list[torch.Tensor] 85 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 86 | """The module's forward pass 87 | 88 | :param batch: A list of tensors 89 | 90 | :returns: Negative log-likelihood, binary cross-entropy loss, singlet probability 91 | """ 92 | if self.model_cell_size: 93 | y, s, fy, fs, fl = batch 94 | _, _, model_nll, _ = utility.compute_posteriors( 95 | y, s, self.model_params, self.dist_option, self.model_zplane_overlap 96 | ) 97 | _, _, _, p_fake_singlet = utility.compute_posteriors( 98 | fy, fs, self.model_params, self.dist_option, self.model_zplane_overlap 99 | ) 100 | else: 101 | y, fy, fl = batch 102 | _, _, model_nll, _ = utility.compute_posteriors( 103 | y, None, self.model_params, self.dist_option, self.model_zplane_overlap 104 | ) 105 | _, _, _, p_fake_singlet = utility.compute_posteriors( 106 | fy, None, self.model_params, self.dist_option, self.model_zplane_overlap 107 | ) 108 | 109 | fake_loss = torch.nn.BCELoss()(p_fake_singlet, fl.to(torch.double)) 110 | 111 | return model_nll, fake_loss, p_fake_singlet 112 | 113 | def training_step(self, batch: List[torch.Tensor]) -> torch.Tensor: 114 | """Compute and return the training loss 115 | 116 | :param batch: A list of tensors of size m x n 117 | 118 | :returns: Total loss 119 | """ 120 | # y, s, fy, fs, fl = batch 121 | model_nll, fake_loss, p_fake_singlet = self(batch) 122 | 123 | # total loss 124 | loss = model_nll + self.model_regularizer * fake_loss 125 | 126 | self.log("train_nll", model_nll) 127 | self.log("train_bce", fake_loss) 128 | self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True) 129 | 130 | return loss 131 | 132 | def configure_optimizers(self) -> torch.optim.Adam: 133 | """Configure the Adam optimizer. 
134 | 135 | :returns: the optimizer 136 | """ 137 | optimizer = torch.optim.Adam(self.model_params.values(), lr=self.learning_rate) 138 | return optimizer 139 | 140 | def prepare_data(self) -> None: 141 | """Create training dataset and set model parameters""" 142 | tr_fy, tr_fs, tr_fl = utility.simulate_data( 143 | self.X, self.S, self.model_zplane_overlap 144 | ) 145 | 146 | ## simulate data 147 | if self.S is not None and tr_fs is not None: 148 | self.train_df = utility.ConcatDataset([self.X, self.S, tr_fy, tr_fs, tr_fl]) 149 | ## get cell size average/variance 150 | init_s = [] 151 | init_sv = [] 152 | for ii in range(len(np.unique(self.adata.obs["init_label"]))): 153 | tmp = self.adata[np.where(self.adata.obs["init_label"] == ii)[0]].obs[ 154 | self.cell_size_col_name 155 | ] 156 | init_s.append(np.mean(tmp)) 157 | init_sv.append(np.var(tmp)) 158 | # self.init_cell_size_centroids = np.array(init_s); self.init_cell_size_variances = np.array(init_sv) 159 | self.adata.uns["init_cell_size_centroids"] = np.array(init_s) 160 | self.adata.uns["init_cell_size_variances"] = np.array(init_sv) 161 | else: 162 | # init_cell_size_centroids = None; init_cell_size_variances = None 163 | self.adata.uns["init_cell_size_centroids"] = None 164 | self.adata.uns["init_cell_size_variances"] = None 165 | self.train_df = utility.ConcatDataset([self.X, tr_fy, tr_fl]) 166 | 167 | # model_params = utility.model_paramters(self.init_e, self.init_v, self.init_s, self.init_sv) 168 | model_params = utility.model_parameters(self.adata, self.singlet_prop) 169 | self.model_params = { 170 | k: torch.from_numpy(val).to(DEVICE).requires_grad_(True) 171 | for (k, val) in model_params.items() 172 | } 173 | 174 | def train_dataloader(self) -> DataLoader: 175 | """Create the training DataLoader 176 | 177 | :returns: the training DataLoader 178 | """ 179 | return DataLoader( 180 | self.train_df, batch_size=BATCH_SIZE, shuffle=True, num_workers=8 181 | ) 182 | 183 | def train_and_fit( 184 | self, 185 | *, 186 | accelerator: Union[str, Accelerator] = "auto", 187 | strategy: Union[str, Strategy] = "auto", 188 | devices: Union[List[int], str, int] = "auto", 189 | num_nodes: int = 1, 190 | precision: Optional[_PRECISION_INPUT] = None, 191 | logger: Optional[Union[Logger, Iterable[Logger], bool]] = True, 192 | callbacks: Optional[Union[List[Callback], Callback]] = None, 193 | fast_dev_run: Union[int, bool] = False, 194 | max_epochs: Optional[int] = 100, 195 | min_epochs: Optional[int] = None, 196 | max_steps: int = -1, 197 | min_steps: Optional[int] = None, 198 | max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None, 199 | limit_train_batches: Optional[Union[int, float]] = None, 200 | limit_val_batches: Optional[Union[int, float]] = None, 201 | limit_test_batches: Optional[Union[int, float]] = None, 202 | limit_predict_batches: Optional[Union[int, float]] = None, 203 | overfit_batches: Union[int, float] = 0.0, 204 | val_check_interval: Optional[Union[int, float]] = None, 205 | check_val_every_n_epoch: Optional[int] = 1, 206 | num_sanity_val_steps: Optional[int] = None, 207 | log_every_n_steps: Optional[int] = None, 208 | enable_checkpointing: Optional[bool] = None, 209 | enable_progress_bar: Optional[bool] = None, 210 | enable_model_summary: Optional[bool] = None, 211 | accumulate_grad_batches: int = 1, 212 | gradient_clip_val: Optional[Union[int, float]] = None, 213 | gradient_clip_algorithm: Optional[str] = None, 214 | deterministic: Optional[Union[bool, _LITERAL_WARN]] = True, 215 | benchmark: Optional[bool] = None, 216 
| inference_mode: bool = True, 217 | use_distributed_sampler: bool = True, 218 | profiler: Optional[Union[Profiler, str]] = None, 219 | detect_anomaly: bool = False, 220 | barebones: bool = False, 221 | plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None, 222 | sync_batchnorm: bool = False, 223 | reload_dataloaders_every_n_epochs: int = 0, 224 | default_root_dir: Optional[_PATH] = None, 225 | ) -> None: 226 | """Train the model using lightning's trainer. 227 | Param annotations (with defaults altered as needed) taken from https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/trainer/trainer.html#Trainer.__init__ 228 | 229 | :param accelerator: Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto") 230 | as well as custom accelerator instances. Defaults to ``"auto"``. 231 | :param strategy: Supports different training strategies with aliases as well as custom strategies. 232 | Defaults to ``"auto"``. 233 | :param devices: The devices to use. Can be set to a positive number (int or str), a sequence of device indices 234 | (list or str), the value ``-1`` to indicate all available devices should be used, or ``"auto"`` for 235 | automatic selection based on the chosen accelerator. Defaults to ``"auto"``. 236 | :param num_nodes: Number of GPU nodes for distributed training. 237 | Defaults to ``1``. 238 | :param precision: Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'), 239 | 16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed'). 240 | Can be used on CPU, GPU, TPUs, HPUs or IPUs. 241 | Defaults to ``'32-true'``. 242 | :param logger: Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses 243 | the default ``TensorBoardLogger`` if it is installed, otherwise ``CSVLogger``. 244 | ``False`` will disable logging. If multiple loggers are provided, local files 245 | (checkpoints, profiler traces, etc.) are saved in the ``log_dir`` of the first logger. 246 | Defaults to ``True``. 247 | :param callbacks: Add a callback or list of callbacks. 248 | Defaults to ``None``. 249 | :param fast_dev_run: Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es) 250 | of train, val and test to find any bugs (i.e. a sort of unit test). 251 | Defaults to ``False``. 252 | :param max_epochs: Stop training once this number of epochs is reached. Disabled by default (None). 253 | If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 100``. 254 | To enable infinite training, set ``max_epochs = -1``. 255 | :param min_epochs: Force training for at least this many epochs. Disabled by default (None). 256 | :param max_steps: Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1`` 257 | and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set 258 | ``max_epochs`` to ``-1``. 259 | :param min_steps: Force training for at least this number of steps. Disabled by default (``None``). 260 | :param max_time: Stop training after this amount of time has passed. Disabled by default (``None``). 261 | The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes, seconds), as a 262 | :class:`datetime.timedelta`, or a dictionary with keys that will be passed to 263 | :class:`datetime.timedelta`. 264 | :param limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches). 265 | Defaults to ``1.0``. 
266 | :param limit_val_batches: How much of validation dataset to check (float = fraction, int = num_batches). 267 | Defaults to ``1.0``. 268 | :param limit_test_batches: How much of test dataset to check (float = fraction, int = num_batches). 269 | Defaults to ``1.0``. 270 | :param limit_predict_batches: How much of prediction dataset to check (float = fraction, int = num_batches). 271 | Defaults to ``1.0``. 272 | :param overfit_batches: Overfit a fraction of training/validation data (float) or a set number of batches (int). 273 | Defaults to ``0.0``. 274 | :param val_check_interval: How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check 275 | after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training 276 | batches. An ``int`` value can only be higher than the number of training batches when 277 | ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches 278 | across epochs or during iteration-based training. 279 | Defaults to ``1.0``. 280 | :param check_val_every_n_epoch: Perform a validation loop after every `N` training epochs. If ``None``, 281 | validation will be done solely based on the number of training batches, requiring ``val_check_interval`` 282 | to be an integer value. 283 | Defaults to ``1``. 284 | :param num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine. 285 | Set it to `-1` to run all batches in all validation dataloaders. 286 | Defaults to ``2``. 287 | :param log_every_n_steps: How often to log within steps. 288 | Defaults to ``50``. 289 | :param enable_checkpointing: If ``True``, enable checkpointing. 290 | It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in 291 | `lightning.pytorch.trainer.trainer.Trainer.callbacks`. 292 | Defaults to ``True``. 293 | :param enable_progress_bar: Whether to enable the progress bar by default. 294 | Defaults to ``True``. 295 | :param enable_model_summary: Whether to enable model summarization by default. 296 | Defaults to ``True``. 297 | :param accumulate_grad_batches: Accumulates gradients over k batches before stepping the optimizer. 298 | Defaults to 1. 299 | :param gradient_clip_val: The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables 300 | gradient clipping. If using Automatic Mixed Precision (AMP), the gradients will be unscaled beforehand. 301 | Defaults to ``None``. 302 | :param gradient_clip_algorithm: The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"`` 303 | to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. By default it will 304 | be set to ``"norm"``. 305 | :param deterministic: If ``True``, sets whether PyTorch operations must use deterministic algorithms. 306 | Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations 307 | that don't support deterministic mode. If not set, defaults to ``False``. Defaults to ``True``. 308 | :param benchmark: The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to. 309 | The value for ``torch.backends.cudnn.benchmark`` set in the current session will be used 310 | (``False`` if not manually set). If `deterministic` 311 | is set to ``True``, this will default to ``False``. Override to manually set a different value. 312 | Defaults to ``None``. 
313 | :param inference_mode: Whether to use `torch.inference_mode` or `torch.no_grad` during 314 | evaluation (``validate``/``test``/``predict``). 315 | :param use_distributed_sampler: Whether to wrap the DataLoader's sampler with 316 | :class:`torch.utils.data.DistributedSampler`. If not specified this is toggled automatically for 317 | strategies that require it. By default, it will add ``shuffle=True`` for the train sampler and 318 | ``shuffle=False`` for validation/test/predict samplers. If you want to disable this logic, you can pass 319 | ``False`` and add your own distributed sampler in the dataloader hooks. If ``True`` and a distributed 320 | sampler was already added, Lightning will not replace the existing one. For iterable-style datasets, 321 | we don't do this automatically. 322 | :param profiler: To profile individual steps during training and assist in identifying bottlenecks. 323 | Defaults to ``None``. 324 | :param detect_anomaly: Enable anomaly detection for the autograd engine. 325 | Defaults to ``False``. 326 | :param barebones: Whether to run in "barebones mode", where all features that may impact raw speed are 327 | disabled. This is meant for analyzing the Trainer overhead and is discouraged during regular training 328 | runs. 329 | :param plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins. 330 | Defaults to ``None``. 331 | :param sync_batchnorm: Synchronize batch norm layers between process groups/whole world. 332 | Defaults to ``False``. 333 | :param reload_dataloaders_every_n_epochs: Set to a positive integer to reload dataloaders every n epochs. 334 | Defaults to ``0``. 335 | :param default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed. 336 | Defaults to ``os.getcwd()``. 337 | Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/' 338 | 339 | Raises: 340 | TypeError: 341 | If ``gradient_clip_val`` is not an int or float. 342 | 343 | MisconfigurationException: 344 | If ``gradient_clip_algorithm`` is invalid. 
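        Example (a minimal usage sketch, assuming ``adata`` was prepared with
        :py:func:`starling.utility.init_clustering` as in the tutorial)::

            st = ST(adata)
            st.train_and_fit(max_epochs=100)
            annotated_adata = st.result()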
345 | """ 346 | 347 | _locals = locals() 348 | 349 | _locals.pop("self") 350 | 351 | trainer = pl.Trainer(**_locals) 352 | 353 | trainer.fit(self) 354 | 355 | def result(self, threshold: float = 0.5) -> AnnData: 356 | """Retrieve the results and add them to ``self.adata`` 357 | 358 | :param threshold: minimum threshold for singlet probability 359 | """ 360 | if self.S is not None: 361 | model_pred_loader = DataLoader( 362 | utility.ConcatDataset([self.X, self.S]), batch_size=1000, shuffle=False 363 | ) 364 | else: 365 | model_pred_loader = DataLoader(self.X, batch_size=1000, shuffle=False) 366 | 367 | singlet_prob, singlet_assig_prob, gamma_assig_prob = utility.predict( 368 | model_pred_loader, 369 | self.model_params, 370 | self.dist_option, 371 | self.model_cell_size, 372 | self.model_zplane_overlap, 373 | threshold, 374 | ) 375 | 376 | self.adata.obs["st_label"] = np.array( 377 | singlet_assig_prob.max(1).indices 378 | ) ##p(z=c|d=1) 379 | self.adata.obs["doublet_prob"] = 1 - np.array(singlet_prob) 380 | self.adata.obs["doublet"] = 0 381 | self.adata.obs.loc[self.adata.obs["doublet_prob"] > 0.5, "doublet"] = 1 382 | self.adata.obs["max_assign_prob"] = np.array(singlet_assig_prob.max(1).values) 383 | 384 | self.adata.obsm["assignment_prob_matrix"] = np.array(singlet_assig_prob) 385 | self.adata.obsm["gamma_assignment_prob_matrix"] = np.array(gamma_assig_prob) 386 | 387 | # st_label = singlet_assig_label.numpy().astype('str') 388 | # st_label[st_label == '-1'] = 'doublet' 389 | # self.adata.obs['st_label'] = st_label 390 | 391 | # if self.model_cell_size == 'Y': 392 | # pretty_printing = np.hstack((self.adata.var_names, self.cell_size_col_name)) 393 | # c = torch.hstack([self.model_params['log_mu'], self.model_params['log_psi'].reshape(-1,1)]).detach().exp().cpu().numpy() 394 | # v = torch.hstack([self.model_params['log_sigma'], self.model_params['log_omega'].reshape(-1,1)]).detach().exp().cpu().numpy() 395 | # self.adata.uns['st_exp_centroids'] = pd.DataFrame(c, columns=pretty_printing) 396 | 397 | # else: 398 | c = self.model_params["log_mu"].detach().exp().cpu().numpy() 399 | # v = self.model_params['log_sigma'].cpu().detach().exp().cpu().numpy() 400 | self.adata.varm[ 401 | "st_exp_centroids" 402 | ] = c.T # pd.DataFrame(c, columns=self.adata.var_names) 403 | 404 | if self.model_cell_size: 405 | self.adata.uns["st_cell_size_centroids"] = ( 406 | self.model_params["log_psi"] 407 | .reshape(-1, 1) 408 | .detach() 409 | .exp() 410 | .cpu() 411 | .numpy() 412 | .T 413 | ) 414 | 415 | # self.adata.varm['init_exp_centroids'] = pd.DataFrame(self.adata.varm['init_exp_centroids'], columns = self.adata.var_names) #.to_csv(code_dir + "/output/init_centroids.csv") 416 | # self.adata.varm['init_exp_variances'] = pd.DataFrame(self.adata.varm['init_exp_variances'], columns = self.adata.var_names) #.to_csv(code_dir + "/output/init_centroids.csv") 417 | 418 | return self.adata 419 | -------------------------------------------------------------------------------- /starling/utility.py: -------------------------------------------------------------------------------- 1 | # https://github.com/eyurtsev/FlowCytometryTools/issues/44 2 | import collections 3 | from collections import abc 4 | from numbers import Number 5 | from typing import Dict, Literal, Optional, Tuple, Union 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import scanpy.external as sce 10 | import torch 11 | 12 | # patch outdated flowsom dependencies 13 | pd.DataFrame.as_matrix = pd.DataFrame.to_numpy 14 | collections.MutableMapping = 
/starling/utility.py:
--------------------------------------------------------------------------------
# https://github.com/eyurtsev/FlowCytometryTools/issues/44
import collections
from collections import abc
from numbers import Number
from typing import Dict, Literal, Optional, Tuple, Union

import numpy as np
import pandas as pd
import scanpy.external as sce
import torch

# patch outdated flowsom dependencies
pd.DataFrame.as_matrix = pd.DataFrame.to_numpy
collections.MutableMapping = abc.MutableMapping

from flowsom import flowsom
from scanpy import AnnData
from sklearn.cluster import AgglomerativeClustering, KMeans
from sklearn.mixture import GaussianMixture
from torch.utils.data import DataLoader, Dataset

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class ConcatDataset(Dataset):
    """A dataset composed of datasets

    :param datasets: the datasets to concatenate, each with the same length
        ``m`` along its first dimension (one entry per cell)
    """

    def __init__(self, datasets: list[torch.Tensor]):
        self.datasets = datasets

    def __getitem__(self, i):
        return tuple(d[i] for d in self.datasets)

    def __len__(self):
        return min(len(d) for d in self.datasets)
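

# Usage sketch (illustrative, not part of the module): ConcatDataset zips
# equally sized tensors so a DataLoader yields aligned batches, mirroring how
# expression and cell size are batched together in ``ST.result``:
#
#     Y = torch.rand(100, 24)  # expressions: 100 cells x 24 proteins
#     S = torch.rand(100)      # cell sizes: one per cell
#
#     loader = DataLoader(ConcatDataset([Y, S]), batch_size=32, shuffle=False)
#     for y_batch, s_batch in loader:
#         assert y_batch.shape[0] == s_batch.shape[0]  # rows stay aligned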


def init_clustering(
    initial_clustering_method: Literal["User", "KM", "GMM", "FS", "PG"],
    adata: AnnData,
    k: Union[int, None] = None,
    labels: Optional[np.ndarray] = None,
) -> AnnData:
    """Compute initial cluster centroids, variances & labels

    :param adata: the initial data to be analyzed
    :param initial_clustering_method: the method for computing the initial clusters,
        one of ``KM`` (KMeans), ``GMM`` (Gaussian Mixture Model), ``FS`` (FlowSOM),
        ``PG`` (PhenoGraph), or ``User`` (user-provided labels)
    :param k: the number of clusters; required for ``KM``, ``GMM``, and ``FS``,
        and unused for ``PG`` and ``User``, where the cluster count is inferred
        from the resulting (or provided) labels
    :param labels: cluster labels, required when ``initial_clustering_method``
        is ``User``

    :raises ValueError: when a required argument is missing or invalid

    :return: the annotated data with initial labels, centroids, and variances
    """

    if initial_clustering_method not in ["KM", "GMM", "FS", "PG", "User"]:
        raise ValueError(
            'initial_clustering_method must be one of "KM", "GMM", "FS", "PG", or "User"'
        )

    if initial_clustering_method in ["KM", "GMM", "FS"] and k is None:
        raise ValueError(
            "k cannot be omitted for KMeans, Gaussian Mixture, or FlowSOM"
        )

    if initial_clustering_method == "User" and labels is None:
        raise ValueError(
            "labels must be provided when initial_clustering_method is set to 'User'"
        )

    if initial_clustering_method == "KM":
        kms = KMeans(k).fit(adata.X)
        init_l = kms.labels_
        init_label_class = np.unique(init_l)

        init_e = kms.cluster_centers_
        init_ev = np.array(
            [np.array(adata.X)[init_l == c, :].var(0) for c in init_label_class]
        )

    elif initial_clustering_method == "GMM":
        gmm = GaussianMixture(n_components=k, covariance_type="diag").fit(adata.X)
        init_l = gmm.predict(adata.X)

        init_e = gmm.means_
        init_ev = gmm.covariances_

    elif initial_clustering_method == "User" or initial_clustering_method == "PG":
        if initial_clustering_method == "PG":
            init_l, _, _ = sce.tl.phenograph(adata.X)
        else:
            init_l = labels

        classes = np.unique(init_l)
        k = len(classes)
        init_e = np.zeros((k, adata.X.shape[1]))
        init_ev = np.zeros((k, adata.X.shape[1]))
        for i, c in enumerate(classes):
            init_e[i, :] = adata.X[init_l == c].mean(0)
            init_ev[i, :] = adata.X[init_l == c].var(0)

    elif initial_clustering_method == "FS":
        ## FlowSOM needs its input written to csv first
        pd.DataFrame(adata.X).to_csv("fs.csv")
        fsom = flowsom("fs.csv", if_fcs=False, if_drop=True, drop_col=["Unnamed: 0"])

        fsom.som_mapping(
            50,  # x_n: the x dimension of the expected map
            50,  # y_n: the y dimension of the expected map
            fsom.df.shape[1],
            1,  # sigma: the standard deviation of the initialized weights
            0.5,  # lr: the learning rate
            1000,  # batch_size: the number of iterations
            tf_str=None,  # the transform algorithm, e.g. 'hlog' or None
            if_fcs=False,  # whether the input is an fcs file; here it is a csv
            # seed=10,  # for reproducibility
        )

        ## meta-cluster the SOM nodes, raising the requested cluster count
        ## until at least k distinct clusters come back
        start = k
        fsom_num_cluster = 0
        while fsom_num_cluster < k:
            fsom.meta_clustering(
                AgglomerativeClustering,
                min_n=start,
                max_n=start,
                verbose=False,
                iter_n=10,
            )

            fsom.labeling()
            fsom_class = np.unique(fsom.df["category"])
            fsom_num_cluster = len(fsom_class)
            start += 1

        fsom_labels = np.array(fsom.df["category"])

        i = 0
        init_l = np.zeros(fsom.df.shape[0], dtype=int)
        init_e = np.zeros((len(fsom_class), fsom.df.shape[1]))
        init_ev = np.zeros((len(fsom_class), fsom.df.shape[1]))
        for row in fsom_class:
            init_l[fsom_labels == row] = i
            init_e[i, :] = fsom.df[fsom_labels == row].mean(0)
            init_ev[i, :] = fsom.df[fsom_labels == row].var(0)
            i += 1

        ## drop the trailing "category" column appended by flowsom
        init_e = init_e[:, :-1]
        init_ev = init_ev[:, :-1]

    adata.obs["init_label"] = init_l
    ## P x C expression centroid matrix from the initial clustering (e.g. KMeans)
    adata.varm["init_exp_centroids"] = init_e.T
    ## P x C expression variance (diagonal) matrix from the initial clustering
    adata.varm["init_exp_variances"] = init_ev.T

    return adata
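

# Usage sketch for init_clustering, following the pattern exercised in
# tests/test_sanity.py; the input path is illustrative:
#
#     import anndata as ad
#     from starling import utility
#
#     adata = ad.read_h5ad("sample_input.h5ad")  # hypothetical input file
#     adata = utility.init_clustering("KM", adata, k=10)
#
#     adata.obs["init_label"]           # per-cell initial cluster labels
#     adata.varm["init_exp_centroids"]  # P x C initial centroids
#     adata.varm["init_exp_variances"]  # P x C initial variances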


def is_non_negative_float(arg: float):
    return isinstance(arg, Number) and arg >= 0


def validate_starling_arguments(
    adata: AnnData,
    dist_option: str,
    singlet_prop: float,
    model_cell_size: bool,
    cell_size_col_name: str,
    model_zplane_overlap: bool,
    model_regularizer: float,
    learning_rate: float,
):
    if not isinstance(adata, AnnData):
        raise ValueError(
            f"Argument `adata` must be of type AnnData, received {type(adata)}."
        )

    if adata.shape[0] < 10 or adata.shape[1] < 10:
        raise ValueError(
            f"Argument `adata` shape must be at least (10,10), received {adata.shape}."
        )

    if not isinstance(dist_option, str) or dist_option not in ["T", "N"]:
        raise ValueError(
            f"Argument `dist_option` must be either 'T' or 'N', received {dist_option}"
        )

    if not isinstance(singlet_prop, Number) or not 0 <= singlet_prop <= 1:
        raise ValueError(
            f"Argument `singlet_prop` must be a number between 0 and 1, received {singlet_prop}"
        )

    if not isinstance(model_cell_size, bool):
        raise ValueError(
            f"Argument `model_cell_size` must be boolean, received {type(model_cell_size)}"
        )

    if model_cell_size and cell_size_col_name not in adata.obs:
        raise ValueError(
            f"Argument `cell_size_col_name` must be a valid column in `adata.obs`, received {cell_size_col_name}"
        )

    if not isinstance(model_zplane_overlap, bool):
        raise ValueError(
            f"Argument `model_zplane_overlap` must be boolean, received {type(model_zplane_overlap)}"
        )

    if not is_non_negative_float(model_regularizer):
        raise ValueError(
            f"Argument `model_regularizer` must be a non-negative number, received {model_regularizer}"
        )

    if not is_non_negative_float(learning_rate):
        raise ValueError(
            f"Argument `learning_rate` must be a non-negative number, received {learning_rate}"
        )


def model_parameters(adata: AnnData, singlet_prop: float) -> Dict[str, np.ndarray]:
    """Return initial model parameters

    :param adata: the sample to be analyzed, with clusters and annotations from :py:func:`init_clustering`
    :param singlet_prop: the anticipated proportion of segmentation-error-free cells

    :return: the model parameters
    """

    init_e = adata.varm["init_exp_centroids"].T
    init_v = adata.varm["init_exp_variances"].T
    ## the cell size entries are absent when cell size is not modelled
    init_s = adata.uns.get("init_cell_size_centroids")
    init_sv = adata.uns.get("init_cell_size_variances")

    nc = init_e.shape[0]
    pi = np.ones(nc) / nc
    tau = np.ones((nc, nc))
    tau = tau / tau.sum()

    model_params = {
        "is_pi": np.log(pi + 1e-6),
        "is_tau": np.log(tau + 1e-6),
        "is_delta": np.log([1 - singlet_prop, singlet_prop]),
    }

    model_params["log_mu"] = np.log(init_e + 1e-6)

    init_v[np.isnan(init_v)] = 1e-6
    model_params["log_sigma"] = np.log(init_v + 1e-6)

    if init_s is not None:
        model_params["log_psi"] = np.log(init_s.astype(float) + 1e-6)

        init_sv[np.isnan(init_sv)] = 1e-6
        model_params["log_omega"] = np.log(init_sv.astype(float) ** 0.5 + 1e-6)

    return model_params
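

# Note: the location and scale parameters returned above are stored on the log
# scale, so unconstrained gradient updates keep them positive; natural-scale
# values are recovered with an exponential, as the compute_p_* functions below
# do. A tiny illustration (the `params` variable is local to this example):
#
#     params = model_parameters(adata, singlet_prop=0.6)  # adata from init_clustering
#     mu = np.exp(params["log_mu"])        # C x P centroid means
#     sigma = np.exp(params["log_sigma"])  # C x P per-protein scales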


def simulate_data(
    Y: torch.Tensor, S: Union[torch.Tensor, None] = None, model_overlap: bool = True
) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor]:
    """Use real data to simulate singlets/doublets in equal proportions.
    Returns the same number of cells as in Y/S; half are singlets and half are doublets.

    :param Y: expression data matrix of shape m x n
    :param S: cell size vector of shape m
    :param model_overlap: if cell size is modelled, whether STARLING models z-plane overlap

    :return: the simulated expressions, cell sizes (or ``None``), and labels
        (singlet == 1, doublet == 0)
    """

    sample_size = int(Y.shape[0] / 2)
    idx_singlet = np.random.choice(Y.shape[0], size=sample_size, replace=True)
    Y_singlet = Y[idx_singlet, :]  ## singlet expressions

    ## a doublet is simulated as the average of two randomly chosen cells
    idx_doublet = [
        np.random.choice(Y.shape[0], size=sample_size),
        np.random.choice(Y.shape[0], size=sample_size),
    ]
    Y_doublet = (Y[idx_doublet[0], :] + Y[idx_doublet[1], :]) / 2.0

    fake_Y = torch.vstack([Y_singlet, Y_doublet])
    fake_label = torch.concat(
        [
            torch.ones(sample_size, dtype=torch.int),
            torch.zeros(sample_size, dtype=torch.int),
        ]
    )

    if S is None:
        return fake_Y, None, fake_label
    else:
        S_singlet = S[idx_singlet]
        if model_overlap:
            ## with z-plane overlap, a doublet's size lies between the larger
            ## of the two cells' sizes and their sum
            dmax = torch.vstack([S[idx_doublet[0]], S[idx_doublet[1]]]).max(0).values
            dsum = S[idx_doublet[0]] + S[idx_doublet[1]]
            rr_dist = torch.distributions.Uniform(
                dmax.type(torch.float64), dsum.type(torch.float64)
            )
            S_doublet = rr_dist.sample()
        else:
            S_doublet = S[idx_doublet[0]] + S[idx_doublet[1]]
        fake_S = torch.hstack([S_singlet, S_doublet])
        return fake_Y, fake_S, fake_label  ## singlet == 1, doublet == 0
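

# In math terms, under the overlap model a simulated doublet built from two
# cells of sizes s and s' draws its size uniformly between the larger of the
# two (complete z-plane overlap) and their sum (no overlap):
#
#     s_doublet ~ Uniform(max(s, s'), s + s')
#
# Without overlap modelling, the two sizes simply add.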


def compute_p_y_given_z(Y, Theta, dist_option):  ## singlet case given expressions
    """:return: # of obs x # of clusters matrix - p(y_n | z_n = c)"""

    mu = torch.clamp(torch.exp(torch.clamp(Theta["log_mu"], min=-12, max=14)), min=0)
    sigma = torch.clamp(
        torch.exp(torch.clamp(Theta["log_sigma"], min=-12, max=14)), min=0
    )

    if dist_option == "N":
        dist_Y = torch.distributions.Normal(loc=mu, scale=sigma)
    else:
        dist_Y = torch.distributions.StudentT(df=2, loc=mu, scale=sigma)

    return dist_Y.log_prob(Y.reshape(-1, 1, Y.shape[1])).sum(
        2
    )  # <- sum because IID over G


def compute_p_s_given_z(S, Theta, dist_option):  ## singlet case given cell sizes
    """:return: # of obs x # of clusters matrix - p(s_n | z_n = c)"""

    psi = torch.clamp(torch.exp(torch.clamp(Theta["log_psi"], min=-12, max=14)), min=0)
    omega = torch.clamp(
        torch.exp(torch.clamp(Theta["log_omega"], min=-12, max=14)), min=0
    )

    if dist_option == "N":
        dist_S = torch.distributions.Normal(loc=psi, scale=omega)
    else:
        dist_S = torch.distributions.StudentT(df=2, loc=psi, scale=omega)

    return dist_S.log_prob(S.reshape(-1, 1))


def compute_p_y_given_gamma(Y, Theta, dist_option):  ## doublet case given expressions
    """:return: # of obs x # of clusters x # of clusters matrix - p(y_n | gamma_n = [c,c'])"""

    mu = torch.clamp(torch.exp(torch.clamp(Theta["log_mu"], min=-12, max=14)), min=0)
    sigma = torch.clamp(
        torch.exp(torch.clamp(Theta["log_sigma"], min=-12, max=14)), min=0
    )

    mu2 = mu.reshape(1, mu.shape[0], mu.shape[1])
    mu2 = (mu2 + mu2.permute(1, 0, 2)) / 2.0  # C x C x G matrix of pairwise centroid means

    sigma2 = sigma.reshape(1, mu.shape[0], mu.shape[1])
    sigma2 = (sigma2 + sigma2.permute(1, 0, 2)) / 2.0

    if dist_option == "N":
        dist_Y2 = torch.distributions.Normal(loc=mu2, scale=sigma2)
    else:
        dist_Y2 = torch.distributions.StudentT(df=2, loc=mu2, scale=sigma2)

    return dist_Y2.log_prob(Y.reshape(-1, 1, 1, mu.shape[1])).sum(
        3
    )  # <- sum because IID over G


def compute_p_s_given_gamma(S, Theta, dist_option):  ## doublet case given cell sizes
    """:return: # of obs x # of clusters x # of clusters matrix - p(s_n | gamma_n = [c,c'])"""

    psi = torch.clamp(torch.exp(torch.clamp(Theta["log_psi"], min=-12, max=14)), min=0)
    omega = torch.clamp(
        torch.exp(torch.clamp(Theta["log_omega"], min=-12, max=14)), min=0
    )

    psi2 = psi.reshape(-1, 1)
    psi2 = psi2 + psi2.T  ## a doublet's expected size is the sum of both cells' sizes

    omega2 = omega.reshape(-1, 1)
    omega2 = omega2 + omega2.T

    if dist_option == "N":
        dist_S2 = torch.distributions.Normal(loc=psi2, scale=omega2)
    else:
        dist_S2 = torch.distributions.StudentT(df=2, loc=psi2, scale=omega2)
    return dist_S2.log_prob(S.reshape(-1, 1, 1))
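

# In compute_p_y_given_gamma above, a doublet of clusters c and c' is modelled
# with averaged parameters, so for G proteins (shown for the Normal option):
#
#     p(y_n | gamma_n = [c, c']) = prod_{g=1..G} N(y_ng | (mu_cg + mu_c'g) / 2,
#                                                         (sigma_cg + sigma_c'g) / 2)
#
# By contrast, compute_p_s_given_gamma sums rather than averages the cluster
# parameters, since a doublet's expected size is the total of its two cells.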


def compute_p_s_given_gamma_model_overlap(S, Theta):
    """:return: # of obs x # of clusters x # of clusters matrix - p(s_n | gamma_n = [c,c'])"""

    psi = torch.clamp(torch.exp(torch.clamp(Theta["log_psi"], min=-12, max=14)), min=0)
    omega = torch.clamp(
        torch.exp(torch.clamp(Theta["log_omega"], min=-12, max=14)), min=0
    )

    psi2 = psi.reshape(-1, 1)
    psi2 = psi2 + psi2.T  ## upper bound of a doublet's size: the sum of both cells

    omega2 = omega.reshape(-1, 1)
    omega2 = omega2 + omega2.T

    ## lower bound of a doublet's size: the larger of the two cells (complete overlap)
    ccmax = torch.combinations(psi).max(1).values
    mat = torch.zeros(len(psi), len(psi), dtype=torch.float64).to(DEVICE)
    mat[np.triu_indices(len(psi), 1)] = ccmax
    mat += mat.clone().T
    mat += torch.eye(len(psi)).to(DEVICE) * psi

    ## density of a size drawn uniformly between the two bounds and blurred with
    ## Gaussian noise of scale omega2, written in closed form via the error function
    c = 1 / (np.sqrt(2) * omega2)
    q = psi2 - S.reshape(-1, 1, 1)
    p = mat - S.reshape(-1, 1, 1)

    const = 1 / (2 * (psi2 - mat))
    lbp = torch.special.erf(p * c)
    ubp = torch.special.erf(q * c)
    prob = torch.clamp(const * (ubp - lbp), min=1e-6, max=1.0)

    return prob.log()


def compute_posteriors(Y, S, Theta, dist_option, model_overlap):
    ## priors
    log_pi = torch.nn.functional.log_softmax(Theta["is_pi"], dim=0)  ## C
    log_tau = torch.nn.functional.log_softmax(
        Theta["is_tau"].reshape(-1), dim=0
    ).reshape(
        log_pi.shape[0], log_pi.shape[0]
    )  ## C x C
    log_delta = torch.nn.functional.log_softmax(Theta["is_delta"], dim=0)  ## 2

    prob_y_given_z = compute_p_y_given_z(
        Y, Theta, dist_option
    )  ## log p(y_n|z=c) -> NxC
    prob_data_given_z_d0 = (
        prob_y_given_z + log_pi
    )  ## log p(y_n|z=c) + log p(z=c) -> NxC + C -> NxC

    if S is not None:
        prob_s_given_z = compute_p_s_given_z(
            S, Theta, dist_option
        )  ## log p(s_n|z=c) -> NxC
        prob_data_given_z_d0 += (
            prob_s_given_z  ## log p(y_n|z=c) + log p(s_n|z=c) -> NxC
        )

    prob_y_given_gamma = compute_p_y_given_gamma(
        Y, Theta, dist_option
    )  ## log p(y_n|g=[c,c']) -> NxCxC
    prob_data_given_gamma_d1 = (
        prob_y_given_gamma + log_tau
    )  ## log p(y_n|g=[c,c']) + log p(g=[c,c']) -> NxCxC

    if S is not None:
        if model_overlap:
            prob_s_given_gamma = compute_p_s_given_gamma_model_overlap(
                S, Theta
            )  ## log p(s_n|g=[c,c']) -> NxCxC
        else:
            prob_s_given_gamma = compute_p_s_given_gamma(
                S, Theta, dist_option
            )  ## log p(s_n|g=[c,c']) -> NxCxC

        prob_data_given_gamma_d1 += (
            prob_s_given_gamma  ## log p(y_n|g=[c,c']) + log p(s_n|g=[c,c']) -> NxCxC
        )

    prob_data = torch.hstack(
        [
            prob_data_given_z_d0 + log_delta[0],
            prob_data_given_gamma_d1.reshape(Y.shape[0], -1) + log_delta[1],
        ]
    )
    prob_data = torch.logsumexp(prob_data, dim=1)  ## N
    ## log p(data_n) marginalizes over both cases:
    ##   singlet: log p(y_n, s_n|z=c) + log p(z=c) + log p(d_n=0)
    ##   doublet: log p(y_n, s_n|g=[c,c']) + log p(g=[c,c']) + log p(d_n=1)

    ## average negative log-likelihood score
    cost = -prob_data.mean()  ## a scalar

    ## integrate out z
    prob_data_given_d0 = torch.logsumexp(
        prob_data_given_z_d0, dim=1
    )  ## p(data_n|d=0)_N
    prob_singlet = torch.clamp(
        torch.exp(prob_data_given_d0 + log_delta[0] - prob_data), min=0.0, max=1.0
    )

    ## assignments (log posteriors)
    r = prob_data_given_z_d0.T + log_delta[0] - prob_data  ## log p(d=0,z=c|data)
    v = (
        prob_data_given_gamma_d1.T + log_delta[1] - prob_data
    )  ## log p(d=1,gamma=[c,c']|data)

    return r.T, v.T, cost, prob_singlet


def predict(
    dataLoader: DataLoader,
    model_params: Dict[str, torch.Tensor],
    dist_option: str,
    model_cell_size: bool,
    model_zplane_overlap: bool,
    threshold: float = 0.5,
):
    """Return the singlet probabilities, the singlet cluster assignment
    probability matrix, and the doublet (gamma) assignment probability matrix

    :param dataLoader: the dataloader
    :param model_params: the model parameters
    :param dist_option: one of 'T' for Student-T (df=2) or 'N' for Normal (Gaussian)
    :param model_cell_size: whether cell size is modelled
    :param model_zplane_overlap: whether z-plane overlap is modelled
    :param threshold: singlet probability threshold for hard labels; currently
        unused here (``ST.result`` applies it to the returned probabilities)
    :return: (singlet_prob, singlet_assig_prob, gamma_assig_prob)
    """

    singlet_prob_list = []
    gamma_assig_prob_list = []
    singlet_assig_prob_list = []

    with torch.no_grad():
        for bat in dataLoader:
            if model_cell_size:
                (
                    singlet_assig_prob,
                    gamma_assig_prob,
                    _,
                    singlet_prob,
                ) = compute_posteriors(
                    bat[0].to(DEVICE),
                    bat[1].to(DEVICE),
                    model_params,
                    dist_option,
                    model_zplane_overlap,
                )
            else:
                (
                    singlet_assig_prob,
                    gamma_assig_prob,
                    _,
                    singlet_prob,
                ) = compute_posteriors(
                    bat.to(DEVICE),
                    None,
                    model_params,
                    dist_option,
                    model_zplane_overlap,
                )

            singlet_prob_list.append(singlet_prob.cpu())
            gamma_assig_prob_list.append(gamma_assig_prob.exp().cpu())
            singlet_assig_prob_list.append(singlet_assig_prob.exp().cpu())

    singlet_prob = torch.cat(singlet_prob_list)
    gamma_assig_prob = torch.cat(gamma_assig_prob_list)
    singlet_assig_prob = torch.cat(singlet_assig_prob_list)

    return singlet_prob, singlet_assig_prob, gamma_assig_prob
--------------------------------------------------------------------------------
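The quantity assembled in ``compute_posteriors`` is a two-level mixture over the singlet/doublet indicator and the cluster assignment, matching the inline comments above:

    $$ p(\mathrm{data}_n) = \delta_0 \sum_{c} \pi_c \, p(y_n, s_n \mid z_n = c) + \delta_1 \sum_{c, c'} \tau_{cc'} \, p(y_n, s_n \mid \gamma_n = [c, c']) $$

The reported singlet probability is the posterior $p(d_n = 0 \mid \mathrm{data}_n)$; everything is computed in log space with logsumexp for numerical stability.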
/tests/conftest.py:
--------------------------------------------------------------------------------
import random

import numpy as np
import pytest
from anndata import AnnData

from starling.utility import init_clustering


@pytest.fixture
def simple_adata():
    ## 32 cells x 16 "protein" features, no cell size information
    return AnnData(np.arange(512).reshape(32, 16))


@pytest.fixture
def simple_adata_with_size(simple_adata):
    simple_adata.obs["area"] = [
        random.randint(1, 5) for i in range(simple_adata.shape[0])
    ]
    return simple_adata


@pytest.fixture
def simple_adata_km_initialized(simple_adata):
    k = 3
    return init_clustering("KM", simple_adata, k)


@pytest.fixture
def simple_adata_km_initialized_with_size(simple_adata_with_size):
    k = 3
    return init_clustering("KM", simple_adata_with_size, k)
--------------------------------------------------------------------------------
/tests/fixtures/sample_input.h5ad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/camlab-bioml/starling/64f7294972eb14e09fe0a642c64fd26844cfbe8a/tests/fixtures/sample_input.h5ad
--------------------------------------------------------------------------------
/tests/test_sanity.py:
--------------------------------------------------------------------------------
from os.path import dirname, join

import anndata as ad
import pandas as pd
from lightning_lite import seed_everything
from pytorch_lightning.callbacks import EarlyStopping

from starling import starling, utility


def test_can_run_km(tmpdir):
    """Test that we can run with the KM setting in init_clustering"""
    seed_everything(10, workers=True)

    raw_adata = ad.read_h5ad(join(dirname(__file__), "fixtures", "sample_input.h5ad"))

    adata = utility.init_clustering(
        "KM",
        raw_adata,
        k=10,
    )
    st = starling.ST(adata)
    cb_early_stopping = EarlyStopping(monitor="train_loss", mode="min", verbose=False)

    ## train ST
    st.train_and_fit(
        max_epochs=2,
        callbacks=[cb_early_stopping],
        default_root_dir=tmpdir,
    )

    result = st.result()

    ## initial expression centroids (p x c matrix)
    init_cent = pd.DataFrame(result.varm["init_exp_centroids"], index=result.var_names)

    assert init_cent.shape == (24, 10)

    ## starling expression centroids (p x c matrix)
    exp_cent = pd.DataFrame(result.varm["st_exp_centroids"], index=result.var_names)

    assert exp_cent.shape == (24, 10)

    ## assignment distributions (n x c matrix)
    prom_mat = pd.DataFrame(
        result.obsm["assignment_prob_matrix"], index=result.obs.index
    )

    assert prom_mat.shape == (13685, 10)


def test_can_run_gmm(tmpdir):
    """Test that we can run with the GMM setting in init_clustering"""
    seed_everything(10, workers=True)
    adata = utility.init_clustering(
        "GMM",
        ad.read_h5ad(join(dirname(__file__), "fixtures", "sample_input.h5ad")),
        k=10,
    )
    st = starling.ST(adata)
    cb_early_stopping = EarlyStopping(monitor="train_loss", mode="min", verbose=False)

    ## train ST
    st.train_and_fit(
        max_epochs=2,
        callbacks=[cb_early_stopping],
        default_root_dir=tmpdir,
    )

    result = st.result()

    ## initial expression centroids (p x c matrix)
    init_cent = pd.DataFrame(result.varm["init_exp_centroids"], index=result.var_names)

    assert init_cent.shape == (24, 10)

    ## starling expression centroids (p x c matrix)
    exp_cent = pd.DataFrame(result.varm["st_exp_centroids"], index=result.var_names)

    assert exp_cent.shape == (24, 10)

    ## assignment distributions (n x c matrix)
    prom_mat = pd.DataFrame(
        result.obsm["assignment_prob_matrix"], index=result.obs.index
    )

    assert prom_mat.shape == (13685, 10)


def test_can_run_pg(tmpdir):
    """Test that we can run with the PG setting in init_clustering"""
    seed_everything(10, workers=True)
    adata = utility.init_clustering(
        "PG",
        ad.read_h5ad(join(dirname(__file__), "fixtures", "sample_input.h5ad")),
        k=10,
    )
    st = starling.ST(adata)
    cb_early_stopping = EarlyStopping(monitor="train_loss", mode="min", verbose=False)

    ## train ST
    st.train_and_fit(
        max_epochs=2,
        callbacks=[cb_early_stopping],
        default_root_dir=tmpdir,
    )

    result = st.result()

    ## initial expression centroids (p x c matrix)
    init_cent = pd.DataFrame(result.varm["init_exp_centroids"], index=result.var_names)

    assert init_cent.shape[0] == 24

    ## starling expression centroids (p x c matrix)
    exp_cent = pd.DataFrame(result.varm["st_exp_centroids"], index=result.var_names)

    assert exp_cent.shape[0] == 24

    ## assignment distributions (n x c matrix)
    prom_mat = pd.DataFrame(
        result.obsm["assignment_prob_matrix"], index=result.obs.index
    )

    assert prom_mat.shape[0] == 13685


def test_can_run_pg_without_cell_size(tmpdir):
    """Test that we can run the model with model_cell_size=False in ST"""
    seed_everything(10, workers=True)
    adata = utility.init_clustering(
        "PG",
        ad.read_h5ad(join(dirname(__file__), "fixtures", "sample_input.h5ad")),
        k=10,
    )
    st = starling.ST(adata, model_cell_size=False)
    cb_early_stopping = EarlyStopping(monitor="train_loss", mode="min", verbose=False)

    ## train ST
    st.train_and_fit(
        max_epochs=2,
        callbacks=[cb_early_stopping],
        default_root_dir=tmpdir,
    )

    result = st.result()

    exp_cent = pd.DataFrame(result.varm["st_exp_centroids"], index=result.var_names)

    assert exp_cent.shape[0] == 24
--------------------------------------------------------------------------------
/tests/test_starling.py:
--------------------------------------------------------------------------------
import torch

from starling.starling import ST


def test_can_instantiate(simple_adata_with_size):
    st = ST(simple_adata_with_size)
    assert isinstance(st.X, torch.Tensor)
    assert isinstance(st.S, torch.Tensor)


def test_can_instantiate_without_size(simple_adata):
    st = ST(simple_adata, model_cell_size=False)
    assert isinstance(st.X, torch.Tensor)
    assert st.S is None


def test_prepare_data(simple_adata_km_initialized_with_size):
    st = ST(simple_adata_km_initialized_with_size, model_cell_size=True)
    assert getattr(st, "model_params", None) is None
    st.prepare_data()
    assert st.model_params is not None
--------------------------------------------------------------------------------
/tests/test_utility.py:
--------------------------------------------------------------------------------
import numpy as np
from anndata import AnnData

from starling.utility import init_clustering, validate_starling_arguments


def assert_annotated(adata: AnnData, k):
    assert "init_exp_centroids" in adata.varm
    assert adata.varm["init_exp_centroids"].shape == (adata.X.shape[1], k)
    assert not np.any(np.isnan(adata.varm["init_exp_centroids"]))

    assert "init_exp_variances" in adata.varm
    assert adata.varm["init_exp_variances"].shape == (adata.X.shape[1], k)
    assert not np.any(np.isnan(adata.varm["init_exp_variances"]))

    assert "init_label" in adata.obs
    assert adata.obs["init_label"].shape == (adata.X.shape[0],)


def test_init_clustering_km(simple_adata):
    k = 3
    initialized = init_clustering("KM", simple_adata, k)
    assert_annotated(initialized, k)


def test_init_clustering_gmm(simple_adata):
    k = 3
    initialized = init_clustering("GMM", simple_adata, k)
    assert_annotated(initialized, k)


def test_init_clustering_pg(simple_adata):
    k = 2
    initialized = init_clustering("PG", simple_adata, k)
    assert_annotated(initialized, k)


def test_init_clustering_fs(simple_adata):
    k = 2
    initialized = init_clustering("FS", simple_adata, k)
    assert_annotated(initialized, k)


def test_init_clustering_user(simple_adata):
    k = 3
    initialized = init_clustering(
        "User", simple_adata, labels=np.random.randint(k, size=32)
    )
    assert_annotated(initialized, k)


def test_init_clustering_user_string(simple_adata):
    k = 3
    initialized = init_clustering(
        "User",
        simple_adata,
        labels=np.random.choice(np.array(["a", "b", "c"]), size=32),
    )

    assert_annotated(initialized, k)


def test_validation_passes_with_no_size(simple_adata):
    validate_starling_arguments(
        adata=simple_adata,
        cell_size_col_name="nonexistent",
        dist_option="T",
        singlet_prop=0.5,
        model_cell_size=False,
        model_zplane_overlap=False,
        model_regularizer=0.1,
        learning_rate=1e-3,
    )
--------------------------------------------------------------------------------
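Taken together, the tests above exercise the full STARLING workflow. A condensed end-to-end sketch (the input path is illustrative, and ``max_epochs=2`` is only a quick smoke run, as in the tests):

    import anndata as ad
    from starling import starling, utility

    adata = utility.init_clustering("KM", ad.read_h5ad("sample_input.h5ad"), k=10)
    st = starling.ST(adata)
    st.train_and_fit(max_epochs=2)
    result = st.result()

    result.obs[["st_label", "doublet_prob", "doublet"]]  # per-cell calls
    result.varm["st_exp_centroids"]                      # P x C refined centroids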