├── .github └── workflows │ ├── docker.yaml │ ├── pages.yaml │ ├── pypi-publish.yaml │ └── pypi-test.yaml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── NEWS.rst ├── README.rst ├── docker ├── Dockerfile └── start-notebook.sh ├── docs ├── Makefile ├── _static │ ├── advanced_thumbnail.png │ ├── annotation_model.png │ ├── annotation_thumbnail.png │ ├── custom.css │ ├── environment.yaml │ ├── model_training.png │ ├── query_model.png │ ├── search_engine.png │ ├── search_thumbnail_1.png │ ├── search_thumbnail_2.png │ ├── search_thumbnail_3.png │ └── training_thumbnail.png ├── about.rst ├── api.rst ├── conf.py ├── index.rst ├── install.rst ├── modules │ ├── anndata_data_models.rst │ ├── cell_annotation.rst │ ├── cell_embedding.rst │ ├── cell_query.rst │ ├── cell_search_knn.rst │ ├── interpreter.rst │ ├── nn_models.rst │ ├── ontologies.rst │ ├── tiledb_data_models.rst │ ├── training_models.rst │ ├── triplet_selector.rst │ ├── utils.rst │ ├── visualizations.rst │ ├── zarr_data_models.rst │ └── zarr_dataset.rst ├── news.rst ├── notebooks │ ├── advanced_tutorial.ipynb │ ├── cell_annotation_tutorial.ipynb │ ├── cell_search_tutorial_1.ipynb │ ├── cell_search_tutorial_2.ipynb │ ├── cell_search_tutorial_3.ipynb │ └── training_tutorial.ipynb ├── requirements.txt └── tutorials.rst ├── pyproject.toml ├── scripts ├── build_annotation_knn.py ├── build_cellsearch_knn.py ├── build_cellsearch_metadata.py ├── build_embeddings.py └── train.py ├── setup.cfg ├── setup.py ├── src └── scimilarity │ ├── __init__.py │ ├── anndata_data_models.py │ ├── cell_annotation.py │ ├── cell_embedding.py │ ├── cell_query.py │ ├── cell_search_knn.py │ ├── interpreter.py │ ├── nn_models.py │ ├── ontologies.py │ ├── tiledb_data_models.py │ ├── training_models.py │ ├── triplet_selector.py │ ├── utils.py │ ├── visualizations.py │ ├── zarr_data_models.py │ └── zarr_dataset.py └── tox.ini /.github/workflows/docker.yaml: -------------------------------------------------------------------------------- 1 | # Build docker container and push to GitHub registry 2 | name: Docker 3 | on: 4 | workflow_dispatch: 5 | 6 | # Sets permissions of the GITHUB_TOKEN 7 | permissions: 8 | contents: read 9 | packages: write 10 | 11 | env: 12 | REGISTRY: ghcr.io 13 | IMAGE_NAME: ${{ github.repository }} 14 | 15 | # Build and push 16 | jobs: 17 | build-and-push-image: 18 | runs-on: ubuntu-latest 19 | steps: 20 | - name: Checkout repository 21 | uses: actions/checkout@v3 22 | - name: Log in to the Container registry 23 | uses: docker/login-action@v2 24 | with: 25 | registry: ${{ env.REGISTRY }} 26 | username: ${{ github.actor }} 27 | password: ${{ secrets.GITHUB_TOKEN }} 28 | - name: Extract Docker metadata 29 | id: meta 30 | uses: docker/metadata-action@v4 31 | with: 32 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 33 | tags: type=raw,value=latest,enable={{is_default_branch}} 34 | - name: Build and push Docker image 35 | uses: docker/build-push-action@v4 36 | with: 37 | context: ./docker 38 | push: true 39 | tags: ${{ steps.meta.outputs.tags }} 40 | labels: ${{ steps.meta.outputs.labels }} 41 | -------------------------------------------------------------------------------- /.github/workflows/pages.yaml: -------------------------------------------------------------------------------- 1 | # Build and deploy Sphinx docs to GitHub Pages 2 | name: Pages 3 | on: 4 | push: 5 | branches: ['main'] 6 | workflow_dispatch: 7 | 8 | # Sets permissions of the GITHUB_TOKEN 9 | permissions: 10 | contents: read 11 | pages: write 12 | id-token: write 13 | 14 | # Allow only 
one concurrent deployment, skipping runs queued between the run in-progress and latest queued. 15 | # However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. 16 | concurrency: 17 | group: "pages" 18 | cancel-in-progress: false 19 | 20 | jobs: 21 | # Build job 22 | build: 23 | runs-on: ubuntu-latest 24 | container: 25 | image: ghcr.io/genentech/scimilarity:latest 26 | steps: 27 | - uses: actions/checkout@v3 28 | - name: Install dependencies 29 | run: | 30 | pip install -r docs/requirements.txt 31 | - name: Setup Pages 32 | id: pages 33 | uses: actions/configure-pages@v3 34 | - name: Sphinx build 35 | run: | 36 | sphinx-build -b html docs _build 37 | - name: Upload artifact 38 | uses: actions/upload-pages-artifact@v3 39 | with: 40 | path: ./_build 41 | # Deployment job 42 | deploy: 43 | environment: 44 | name: github-pages 45 | url: ${{ steps.deployment.outputs.page_url }} 46 | runs-on: ubuntu-latest 47 | needs: build 48 | steps: 49 | - name: Deploy to GitHub Pages 50 | id: deployment 51 | uses: actions/deploy-pages@v4 52 | -------------------------------------------------------------------------------- /.github/workflows/pypi-publish.yaml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Publish to PyPI 5 | 6 | on: 7 | workflow_dispatch 8 | # push: 9 | # tags: '^[0-9]+.[0-9]+.[0-9]+' 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 3.9 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: 3.9 21 | 22 | - name: Install system dependencies 23 | run: | 24 | sudo apt install pandoc 25 | 26 | - name: Install python dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install flake8 pytest tox 30 | 31 | - name: Build package 32 | run: | 33 | python -m tox -e clean,build 34 | 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@release/v1 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_TOKEN }} 40 | -------------------------------------------------------------------------------- /.github/workflows/pypi-test.yaml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Test the library 5 | 6 | on: 7 | workflow_dispatch 8 | # push: 9 | # branches: [main] 10 | # pull_request: 11 | # branches: [main] 12 | 13 | jobs: 14 | build: 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] 19 | 20 | name: Python ${{ matrix.python-version }} 21 | steps: 22 | - uses: actions/checkout@v2 23 | - name: Setup Python 24 | uses: actions/setup-python@v4 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | cache: "pip" 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install flake8 pytest tox 32 | 33 | # - name: Lint with flake8 34 | # run: | 35 | # # stop the build if there are Python syntax errors or undefined names 36 | # flake8 . 
--count --select=E9,F63,F7,F82 --show-source --statistics 37 | # # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 38 | # # flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 39 | 40 | - name: Test with tox 41 | run: | 42 | tox 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Temporary and binary files 2 | *~ 3 | *.py[cod] 4 | *.so 5 | *.cfg 6 | !.isort.cfg 7 | !setup.cfg 8 | *.orig 9 | *.log 10 | *.pot 11 | __pycache__/* 12 | .cache/* 13 | .*.swp 14 | */.ipynb_checkpoints/* 15 | .DS_Store 16 | 17 | # Project files 18 | .ropeproject 19 | .project 20 | .pydevproject 21 | .settings 22 | .idea 23 | .vscode 24 | tags 25 | 26 | # Package files 27 | *.egg 28 | *.eggs/ 29 | .installed.cfg 30 | *.egg-info 31 | 32 | # Unittest and coverage 33 | htmlcov/* 34 | .coverage 35 | .coverage.* 36 | .tox 37 | junit*.xml 38 | coverage.xml 39 | .pytest_cache/ 40 | 41 | # Build and docs folder/files 42 | build/* 43 | dist/* 44 | sdist/* 45 | docs/_rst/* 46 | docs/_build/* 47 | cover/* 48 | MANIFEST 49 | 50 | # Per-project virtualenvs 51 | .venv*/ 52 | .conda*/ 53 | .python-version 54 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023 Genentech, Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | exclude .gitignore 2 | recursive-exclude .github * 3 | recursive-exclude docker * 4 | recursive-exclude scripts * 5 | -------------------------------------------------------------------------------- /NEWS.rst: -------------------------------------------------------------------------------- 1 | Release Notes 2 | ================================================================================ 3 | 4 | Version 0.4.0: May 05, 2025 5 | -------------------------------------------------------------------------------- 6 | 7 | General: 8 | + A new training tutorial has been added which describes the new training 9 | workflow. This includes data preparation, training, and post-training data 10 | structures using the new scripts. 11 | 12 | Training: 13 | + A new training workflow has been added to use CellArr, a TileDB based 14 | framework, as the data store to streamline the end-to-end process. This 15 | replaces the old Zarr based workflows. 16 | + New data loaders and samplers for CellArr data have been added in the 17 | ``tiledb_data_models`` module. 18 | + A example training script has been added to show how to train models as 19 | ``scripts/train.py``. 20 | + New scripts for creating all post-training data structures have been added 21 | in the folder ``scripts``. 
22 | + New utility methods that make use of the CellArr store: 23 | ``utils.query_tiledb_df`` to query a tiledb dataframe, 24 | ``utils.adata_from_tiledb`` to extract cells from the tiledb stores based on 25 | index, including raw counts. 26 | 27 | Version 0.3.0: November 19, 2024 28 | -------------------------------------------------------------------------------- 29 | 30 | General: 31 | + Various changes to utility functions to improve efficiency and flexibility. 32 | + Simplification of many class constructor parameters. 33 | + Tutorials have been updated with new download links and analyses. 34 | 35 | Exhaustive queries: 36 | + Functionality to perform exhaustive queries has been added as new methods 37 | ``cell_query.search_exhaustive``, ``cell_query.search_centroid_exhaustive``, 38 | and ``cell_query.search_cluster_centroids_exhaustive``. 39 | + The kNN query method ``cell_query.search`` has been renamed to 40 | ``cell_query.search_nearest``. 41 | 42 | Query result filtering and interpretation: 43 | + The ``cell_query.compile_sample_metadata`` method has been expanded to 44 | allow grouping by tissue and disease (in addition to study and sample). 45 | + The methods ``utils.subset_by_unique_values``, 46 | ``utils.subset_by_frequency``, and ``utils.categorize_and_sort_by_score`` 47 | have been added to provide tools for filtering, sorting and summarizing 48 | query results. 49 | + The "query_stability" quality control metric has been renamed to 50 | "query_coherence" and is now deterministic (by setting a random seed). 51 | + Results from exhaustive queries can be constrained to specific 52 | metadata criteria (e.g., tissue, disease, in vitro vs in vivo, etc). 53 | using the ``metadata_filter`` argument to ``cell_query.search_exhaustive``. 54 | + Results from exhaustive queries can be constrained by distance-to-query 55 | using the ``max_dist`` argument to ``cell_query.search_exhaustive``. 56 | + The mappings in ``utils.clean_tissues`` and ``utils.clean_diseases`` have 57 | been expanded. 58 | 59 | Optimizations to training: 60 | + The ASW and NMSE training evaluation metrics were added. 61 | + The ``triplet_selector.get_asw`` method was added to calculate ASW. 62 | + The ``ontologies.find_most_viable_parent`` method was added to help coarse 63 | grain cell type ontology labels. 64 | + Optimized sampling weights of study and cell type. 65 | 66 | Version 0.2.0: March 22, 2024 67 | -------------------------------------------------------------------------------- 68 | 69 | + Updated version requirements for multiple dependencies and removed 70 | the ``pegasuspy`` dependency. 71 | + Expanded API documentation and tutorials. 72 | + Simplified model file structure and model loader methods. 73 | + Added ``search_centroid`` method to class ``CellQuery`` for cell 74 | queries using custom centroids which provides quality control 75 | statistics to assess the consistency of query results. 76 | 77 | 78 | Version 0.1.0: August 13, 2023 79 | -------------------------------------------------------------------------------- 80 | 81 | + Initial public release. 
82 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | SCimilarity 2 | ================================================================================ 3 | 4 | **SCimilarity** is a unifying representation of single cell expression profiles 5 | that quantifies similarity between expression states and generalizes to 6 | represent new studies without additional training. 7 | 8 | This enables a novel cell search capability, which sifts through millions of 9 | profiles to find cells similar to a query cell state and allows researchers to 10 | quickly and systematically leverage massive public scRNA-seq atlases to learn 11 | about a cell state of interest. 12 | 13 | Documentation 14 | -------------------------------------------------------------------------------- 15 | 16 | Tutorials and API documentation can be found at: 17 | https://genentech.github.io/scimilarity/index.html 18 | 19 | Download & Install 20 | -------------------------------------------------------------------------------- 21 | 22 | The latest API release can be installed from PyPI:: 23 | 24 | pip install scimilarity 25 | 26 | Pretrained model weights, embeddings, kNN graphs, a single-cell metadata 27 | can be downloaded from: 28 | https://zenodo.org/records/10685499 29 | 30 | A docker container with SCimilarity preinstalled can be pulled from: 31 | https://ghcr.io/genentech/scimilarity 32 | 33 | Citation 34 | -------------------------------------------------------------------------------- 35 | 36 | To cite SCimilarity in publications please use: 37 | 38 | **A cell atlas foundation model for scalable search of similar human cells.** 39 | *Graham Heimberg\*, Tony Kuo\*, Daryle J. DePianto, Tobias Heigl, 40 | Nathaniel Diamant, Omar Salem, Gabriele Scalia, Tommaso Biancalani, 41 | Jason R. Rock, Shannon J. Turley, Héctor Corrada Bravo, Josh Kaminker\*\*, 42 | Jason A. Vander Heiden\*\*, Aviv Regev\*\*.* 43 | Nature (2024). 
https://doi.org/10.1038/s41586-024-08411-y -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM fedora:42 2 | LABEL maintainer="Jason Anthony Vander Heiden [vandej27@gene.com]" \ 3 | org.opencontainers.image.description="SCimilarity" \ 4 | org.opencontainers.image.source="https://github.com/genentech/scimilarity" 5 | 6 | # Bind points 7 | VOLUME /data 8 | VOLUME /models 9 | VOLUME /workspace 10 | VOLUME /scratch 11 | 12 | # Tools 13 | COPY start-notebook.sh /usr/local/bin/start-notebook 14 | 15 | # Environment 16 | ENV SCDATA_HOME=/data 17 | ENV SCMODEL_HOME=/models 18 | 19 | # Update and install required packages 20 | RUN dnf -y update && dnf install -y \ 21 | bzip2 \ 22 | cmake \ 23 | igraph \ 24 | gcc-c++ \ 25 | git \ 26 | lz4 \ 27 | pandoc \ 28 | python3 \ 29 | python3-aiohttp \ 30 | python3-asciitree \ 31 | python3-biopython \ 32 | python3-cloudpickle \ 33 | python3-Cython \ 34 | python3-numcodecs \ 35 | python3-dask \ 36 | python3-dask+array \ 37 | python3-devel \ 38 | python3-fasteners \ 39 | python3-GitPython \ 40 | python3-h5py \ 41 | python3-igraph \ 42 | python3-jupyter-client \ 43 | python3-jupyterlab_pygments \ 44 | python3-matplotlib \ 45 | python3-matplotlib-scalebar \ 46 | python3-natsort \ 47 | python3-nbconvert \ 48 | python3-nbsphinx \ 49 | python3-notebook \ 50 | python3-numpy \ 51 | python3-pandas \ 52 | python3-pip \ 53 | python3-pydantic \ 54 | python3-pydata-sphinx-theme \ 55 | python3-PyYAML \ 56 | python3-seaborn \ 57 | python3-setuptools \ 58 | python3-scipy \ 59 | python3-stdlib-list \ 60 | python3-texttable \ 61 | python3-toolz \ 62 | python3-tqdm \ 63 | python3-wrapt \ 64 | python3-zarr \ 65 | sudo \ 66 | tar \ 67 | wget \ 68 | zstd \ 69 | && dnf clean all 70 | 71 | # Install python dependencies 72 | RUN pip3 install \ 73 | scikit-learn \ 74 | scikit-misc \ 75 | numba \ 76 | tiledb \ 77 | tiledb-cloud \ 78 | tiledb-vector-search \ 79 | leidenalg \ 80 | louvain \ 81 | umap-learn \ 82 | hnswlib \ 83 | obonet \ 84 | circlify \ 85 | captum \ 86 | torch \ 87 | pytorch-lightning \ 88 | scanpy \ 89 | scrublet 90 | 91 | # Install SCimilarity API 92 | RUN git clone https://github.com/Genentech/scimilarity.git /tmp/scimilarity \ 93 | && cd /tmp/scimilarity \ 94 | && pip install . 95 | 96 | # Entry points 97 | CMD ["start-notebook"] 98 | -------------------------------------------------------------------------------- /docker/start-notebook.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | jupyter notebook --ip=0.0.0.0 --port=8888 --no-browser \ 3 | --allow-root --notebook-dir /workspace 4 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = docs 9 | BUILDDIR = docs/_build 10 | AUTODOCDIR = modules 11 | 12 | # User-friendly check for sphinx-build 13 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $?), 1) 14 | $(error "The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. 
Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from https://sphinx-doc.org/") 15 | endif 16 | 17 | .PHONY: help clean Makefile 18 | 19 | # Put it first so that "make" without argument is like "make help". 20 | help: 21 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 22 | 23 | clean: 24 | rm -rf $(BUILDDIR)/* $(AUTODOCDIR) 25 | 26 | # Catch-all target: route all unknown targets to Sphinx using the new 27 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 28 | %: Makefile 29 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 30 | -------------------------------------------------------------------------------- /docs/_static/advanced_thumbnail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/scimilarity/52824e59ab90fc19d1bf1bfc3be46f418cb68ba4/docs/_static/advanced_thumbnail.png -------------------------------------------------------------------------------- /docs/_static/annotation_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/scimilarity/52824e59ab90fc19d1bf1bfc3be46f418cb68ba4/docs/_static/annotation_model.png -------------------------------------------------------------------------------- /docs/_static/annotation_thumbnail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/scimilarity/52824e59ab90fc19d1bf1bfc3be46f418cb68ba4/docs/_static/annotation_thumbnail.png -------------------------------------------------------------------------------- /docs/_static/custom.css: -------------------------------------------------------------------------------- 1 | img { 2 | width: 90%; 3 | height: auto; 4 | } -------------------------------------------------------------------------------- /docs/_static/environment.yaml: -------------------------------------------------------------------------------- 1 | name: scimilarity 2 | channels: 3 | - nvidia 4 | - pytorch 5 | - bioconda 6 | - conda-forge 7 | dependencies: 8 | - python=3.10 9 | - ipykernel 10 | - ipython 11 | - ipywidgets 12 | - leidenalg 13 | - pip -------------------------------------------------------------------------------- /docs/_static/model_training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/scimilarity/52824e59ab90fc19d1bf1bfc3be46f418cb68ba4/docs/_static/model_training.png -------------------------------------------------------------------------------- /docs/_static/query_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/scimilarity/52824e59ab90fc19d1bf1bfc3be46f418cb68ba4/docs/_static/query_model.png -------------------------------------------------------------------------------- /docs/_static/search_engine.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/scimilarity/52824e59ab90fc19d1bf1bfc3be46f418cb68ba4/docs/_static/search_engine.png -------------------------------------------------------------------------------- /docs/_static/search_thumbnail_1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Genentech/scimilarity/52824e59ab90fc19d1bf1bfc3be46f418cb68ba4/docs/_static/search_thumbnail_1.png -------------------------------------------------------------------------------- /docs/_static/search_thumbnail_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/scimilarity/52824e59ab90fc19d1bf1bfc3be46f418cb68ba4/docs/_static/search_thumbnail_2.png -------------------------------------------------------------------------------- /docs/_static/search_thumbnail_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/scimilarity/52824e59ab90fc19d1bf1bfc3be46f418cb68ba4/docs/_static/search_thumbnail_3.png -------------------------------------------------------------------------------- /docs/_static/training_thumbnail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/scimilarity/52824e59ab90fc19d1bf1bfc3be46f418cb68ba4/docs/_static/training_thumbnail.png -------------------------------------------------------------------------------- /docs/about.rst: -------------------------------------------------------------------------------- 1 | About 2 | ================================================================================ 3 | 4 | Support 5 | -------------------------------------------------------------------------------- 6 | 7 | You can report issues with SCimilarity using our GitHub 8 | `issue tracker `__. 9 | 10 | .. _Authors: 11 | 12 | Authors 13 | -------------------------------------------------------------------------------- 14 | 15 | Graham Heimberg, Tony Kuo, Nathaniel Diamant, Omar Salem, 16 | Héctor Corrada Bravo, Jason A. Vander Heiden 17 | 18 | .. _Cite: 19 | 20 | How to Cite 21 | -------------------------------------------------------------------------------- 22 | 23 | To cite SCimilarity in publications please use: 24 | 25 | **A cell atlas foundation model for scalable search of similar human cells.** 26 | *Graham Heimberg\*, Tony Kuo\*, Daryle J. DePianto, Tobias Heigl, 27 | Nathaniel Diamant, Omar Salem, Gabriele Scalia, Tommaso Biancalani, 28 | Jason R. Rock, Shannon J. Turley, Héctor Corrada Bravo, Josh Kaminker\*\*, 29 | Jason A. Vander Heiden\*\*, Aviv Regev\*\*.* 30 | Nature (2024). https://doi.org/10.1038/s41586-024-08411-y 31 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | .. _API: 2 | 3 | API Reference 4 | ================================================================================ 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | :hidden: 9 | 10 | API Reference 11 | 12 | .. toctree:: 13 | :maxdepth: 2 14 | :caption: Core Functionality 15 | :hidden: 16 | 17 | modules/cell_annotation 18 | modules/cell_embedding 19 | modules/cell_query 20 | modules/cell_search_knn 21 | modules/interpreter 22 | 23 | .. toctree:: 24 | :maxdepth: 2 25 | :caption: Model Training 26 | :hidden: 27 | 28 | modules/anndata_data_models 29 | modules/nn_models 30 | modules/tiledb_data_models 31 | modules/training_models 32 | modules/triplet_selector 33 | modules/zarr_data_models 34 | 35 | .. 
toctree:: 36 | :maxdepth: 2 37 | :caption: Utilities 38 | :hidden: 39 | 40 | modules/ontologies 41 | modules/utils 42 | modules/visualizations 43 | modules/zarr_dataset 44 | 45 | Core Functionality 46 | -------------------------------------------------------------------------------- 47 | 48 | These modules provide functionality for utilizing SCimilarity embeddings for a 49 | variety of tasks, including cell type annotation, cell queries, and gene 50 | attribution scoring. 51 | 52 | * :mod:`scimilarity.cell_annotation` 53 | * :mod:`scimilarity.cell_embedding` 54 | * :mod:`scimilarity.cell_query` 55 | * :mod:`scimilarity.cell_search_knn` 56 | * :mod:`scimilarity.interpreter` 57 | 58 | Model Training 59 | -------------------------------------------------------------------------------- 60 | 61 | Training new SCimilarity models requires aggregated and curated training data. 62 | This relies on specialized data loaders that are optimized for random cell access 63 | across datasets, specialized variations of metric learning loss functions, and 64 | procedures for cell ontology aware triplet mining. The following modules include 65 | support for these training tasks. 66 | 67 | * :mod:`scimilarity.anndata_data_models` 68 | * :mod:`scimilarity.nn_models` 69 | * :mod:`scimilarity.tiledb_data_models` 70 | * :mod:`scimilarity.training_models` 71 | * :mod:`scimilarity.triplet_selector` 72 | * :mod:`scimilarity.zarr_data_models` 73 | * :mod:`scimilarity.zarr_dataset` 74 | 75 | Utilities 76 | -------------------------------------------------------------------------------- 77 | 78 | SCimilarity uses specific visualizations, ontology interfaces, and data 79 | preprocessing steps. These modules provide functionality useful for model 80 | training as well as a variety of SCimilarity analyses. 81 | 82 | * :mod:`scimilarity.ontologies` 83 | * :mod:`scimilarity.utils` 84 | * :mod:`scimilarity.visualizations` 85 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # This file is execfile()d with the current directory set to its containing dir. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | # 7 | # All configuration values have a default; values that are commented out 8 | # serve to show the default. 9 | 10 | import os 11 | import sys 12 | 13 | # -- Path setup -------------------------------------------------------------- 14 | 15 | __location__ = os.path.dirname(__file__) 16 | 17 | # If extensions (or modules to document with autodoc) are in another directory, 18 | # add these directories to sys.path here. If the directory is relative to the 19 | # documentation root, use os.path.abspath to make it absolute, like shown here. 20 | sys.path.insert(0, os.path.join(__location__, "../src")) 21 | 22 | 23 | # -- General configuration --------------------------------------------------- 24 | 25 | # If your documentation needs a minimal Sphinx version, state it here. 26 | # needs_sphinx = '1.0' 27 | 28 | # Add any Sphinx extension module names here, as strings. They can be extensions 29 | # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 
30 | extensions = [ 31 | "sphinx.ext.autodoc", 32 | "sphinx.ext.intersphinx", 33 | "sphinx.ext.todo", 34 | "sphinx.ext.autosummary", 35 | "sphinx.ext.viewcode", 36 | "sphinx.ext.coverage", 37 | "sphinx.ext.doctest", 38 | "sphinx.ext.ifconfig", 39 | "sphinx.ext.mathjax", 40 | "sphinx.ext.napoleon", 41 | "nbsphinx", 42 | "sphinx_gallery.load_style", 43 | ] 44 | 45 | # autodoc configuration 46 | autodoc_typehints = "description" 47 | autodoc_mock_imports = ["anndata", 48 | "hnswlib", 49 | "captum", 50 | "circlify", 51 | "matplotlib", 52 | "networkx", 53 | "numba", 54 | "numcodecs", 55 | "numpy", 56 | "obonet", 57 | "pandas", 58 | "pegasusio", 59 | "pytorch_lightning", 60 | "scanpy", 61 | "scipy", 62 | "seaborn", 63 | "tiledb", 64 | "tqdm", 65 | "torch", 66 | "zarr"] 67 | 68 | # todo configuration 69 | todo_include_todos = True 70 | 71 | # nbsphinx configuration 72 | nbsphinx_thumbnails = { 73 | 'notebooks/cell_search_tutorial_1': '_static/search_thumbnail_1.png', 74 | 'notebooks/cell_search_tutorial_2': '_static/search_thumbnail_2.png', 75 | 'notebooks/cell_search_tutorial_3': '_static/search_thumbnail_3.png', 76 | 'notebooks/cell_annotation_tutorial': '_static/annotation_thumbnail.png', 77 | 'notebooks/advanced_tutorial': '_static/advanced_thumbnail.png', 78 | 'notebooks/training_tutorial': '_static/training_thumbnail.png' 79 | } 80 | 81 | # Add any paths that contain templates here, relative to this directory. 82 | templates_path = ["_templates"] 83 | 84 | # The suffix of source filenames. 85 | source_suffix = ".rst" 86 | 87 | # The encoding of source files. 88 | # source_encoding = 'utf-8-sig' 89 | 90 | # The master toctree document. 91 | master_doc = "index" 92 | 93 | # General information about the project. 94 | project = "scimilarity" 95 | copyright = "2023, Genentech, Inc." 96 | 97 | # The version info for the project you're documenting, acts as replacement for 98 | # |version| and |release|, also used in various other places throughout the 99 | # built documents. 100 | # 101 | # version: The short X.Y version. 102 | # release: The full version, including alpha/beta/rc tags. 103 | # If you don’t need the separation provided between version and release, 104 | # just set them both to the same value. 105 | try: 106 | from scimilarity import __version__ as version 107 | except ImportError: 108 | version = "" 109 | 110 | if not version or version.lower() == "unknown": 111 | version = os.getenv("READTHEDOCS_VERSION", "unknown") # automatically set by RTD 112 | 113 | release = version 114 | 115 | # The language for content autogenerated by Sphinx. Refer to documentation 116 | # for a list of supported languages. 117 | # language = None 118 | 119 | # There are two options for replacing |today|: either, you set today to some 120 | # non-false value, then it is used: 121 | # today = '' 122 | # Else, today_fmt is used as the format for a strftime call. 123 | # today_fmt = '%B %d, %Y' 124 | 125 | # List of patterns, relative to source directory, that match files and 126 | # directories to ignore when looking for source files. 127 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", ".venv"] 128 | 129 | # The reST default role (used for this markup: `text`) to use for all documents. 130 | # default_role = None 131 | 132 | # If true, '()' will be appended to :func: etc. cross-reference text. 133 | # add_function_parentheses = True 134 | 135 | # If true, the current module name will be prepended to all description 136 | # unit titles (such as .. function::). 
137 | # add_module_names = True 138 | 139 | # If true, sectionauthor and moduleauthor directives will be shown in the 140 | # output. They are ignored by default. 141 | # show_authors = False 142 | 143 | # The name of the Pygments (syntax highlighting) style to use. 144 | pygments_style = "sphinx" 145 | 146 | # A list of ignored prefixes for module index sorting. 147 | # modindex_common_prefix = [] 148 | 149 | # If true, keep warnings as "system message" paragraphs in the built documents. 150 | # keep_warnings = False 151 | 152 | # If this is True, todo emits a warning for each TODO entries. The default is False. 153 | todo_emit_warnings = True 154 | 155 | 156 | # -- Options for HTML output ------------------------------------------------- 157 | 158 | # The theme to use for HTML and HTML Help pages. See the documentation for 159 | # a list of builtin themes. 160 | html_theme = "pydata_sphinx_theme" 161 | 162 | # Theme options are theme-specific and customize the look and feel of a theme 163 | # further. For a list of options available for each theme, see the 164 | # documentation. 165 | # html_theme_options = { 166 | # "sidebar_width": "300px", 167 | # "page_width": "1200px" 168 | # } 169 | 170 | # Add any paths that contain custom themes here, relative to this directory. 171 | # html_theme_path = [] 172 | 173 | # The name for this set of Sphinx documents. If None, it defaults to 174 | # " v documentation". 175 | # html_title = None 176 | 177 | # A shorter title for the navigation bar. Default is the same as html_title. 178 | # html_short_title = None 179 | 180 | # The name of an image file (relative to this directory) to place at the top 181 | # of the sidebar. 182 | # html_logo = "" 183 | 184 | # The name of an image file (within the static path) to use as favicon of the 185 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 186 | # pixels large. 187 | # html_favicon = None 188 | 189 | # Add any paths that contain custom static files (such as style sheets) here, 190 | # relative to this directory. They are copied after the builtin static files, 191 | # so a file named "default.css" will overwrite the builtin "default.css". 192 | html_static_path = ["_static"] 193 | 194 | # Custom CSS 195 | html_css_files = ['custom.css'] 196 | 197 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 198 | # using the given strftime format. 199 | # html_last_updated_fmt = '%b %d, %Y' 200 | 201 | # If true, SmartyPants will be used to convert quotes and dashes to 202 | # typographically correct entities. 203 | # html_use_smartypants = True 204 | 205 | # Custom sidebar templates, maps document names to template names. 206 | # html_sidebars = {} 207 | 208 | # Additional templates that should be rendered to pages, maps page names to 209 | # template names. 210 | # html_additional_pages = {} 211 | 212 | # If false, no module index is generated. 213 | # html_domain_indices = True 214 | 215 | # If false, no index is generated. 216 | # html_use_index = True 217 | 218 | # If true, the index is split into individual pages for each letter. 219 | # html_split_index = False 220 | 221 | # If true, links to the reST sources are added to the pages. 222 | # html_show_sourcelink = True 223 | 224 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 225 | # html_show_sphinx = True 226 | 227 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 
228 | # html_show_copyright = True 229 | 230 | # If true, an OpenSearch description file will be output, and all pages will 231 | # contain a tag referring to it. The value of this option must be the 232 | # base URL from which the finished HTML is served. 233 | # html_use_opensearch = '' 234 | 235 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 236 | # html_file_suffix = None 237 | 238 | # Output file base name for HTML help builder. 239 | htmlhelp_basename = "scimilarity-doc" 240 | 241 | 242 | # -- Options for LaTeX output ------------------------------------------------ 243 | 244 | latex_elements = { 245 | # The paper size ("letterpaper" or "a4paper"). 246 | # "papersize": "letterpaper", 247 | # The font size ("10pt", "11pt" or "12pt"). 248 | # "pointsize": "10pt", 249 | # Additional stuff for the LaTeX preamble. 250 | # "preamble": "", 251 | } 252 | 253 | # Grouping the document tree into LaTeX files. List of tuples 254 | # (source start file, target name, title, author, documentclass [howto/manual]). 255 | latex_documents = [ 256 | ("index", "user_guide.tex", "scimilarity Documentation", "Graham Heimberg", "manual") 257 | ] 258 | 259 | # The name of an image file (relative to this directory) to place at the top of 260 | # the title page. 261 | # latex_logo = "" 262 | 263 | # For "manual" documents, if this is true, then toplevel headings are parts, 264 | # not chapters. 265 | # latex_use_parts = False 266 | 267 | # If true, show page references after internal links. 268 | # latex_show_pagerefs = False 269 | 270 | # If true, show URL addresses after external links. 271 | # latex_show_urls = False 272 | 273 | # Documents to append as an appendix to all manuals. 274 | # latex_appendices = [] 275 | 276 | # If false, no module index is generated. 277 | # latex_domain_indices = True 278 | 279 | # -- External mapping -------------------------------------------------------- 280 | python_version = ".".join(map(str, sys.version_info[0:2])) 281 | # intersphinx_mapping = { 282 | # "sphinx": ("https://www.sphinx-doc.org/en/master", None), 283 | # "python": ("https://docs.python.org/" + python_version, None), 284 | # "matplotlib": ("https://matplotlib.org", None), 285 | # "numpy": ("https://numpy.org/doc/stable", None), 286 | # "sklearn": ("https://scikit-learn.org/stable", None), 287 | # "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None), 288 | # "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), 289 | # "setuptools": ("https://setuptools.pypa.io/en/stable/", None), 290 | # "pyscaffold": ("https://pyscaffold.org/en/stable", None), 291 | # } 292 | 293 | print(f"loading configurations for {project} {version} ...", file=sys.stderr) 294 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | SCimilarity 2 | ================================================================================ 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | :hidden: 7 | 8 | Overview 9 | Installation 10 | Tutorials 11 | API Reference 12 | Release Notes 13 | About 14 | 15 | A search engine for cells 16 | -------------------------------------------------------------------------------- 17 | 18 | SCimilarity is a unifying representation of single cell expression profiles that 19 | quantifies similarity between expression states and generalizes to represent new 20 | studies without additional training. 
21 | 22 | This enables a novel cell search capability, which sifts through millions of 23 | profiles to find cells similar to a query cell state and allows researchers to 24 | quickly and systematically leverage massive public scRNA-seq atlases to learn 25 | about a cell state of interest. 26 | 27 | .. image:: _static/search_engine.png 28 | :width: 75% 29 | :align: center 30 | 31 | Capabilities 32 | -------------------------------------------------------------------------------- 33 | 34 | Cell query 35 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 36 | 37 | :mod:`scimilarity.cell_query` provides tools to search for similar cells 38 | across a large reference. Input a gene expression profile of a cell state of 39 | interest and search across tens of millions of cells to find cells that resemble 40 | your query cell state. This does not require special preprocessing such as batch 41 | correction or highly variable gene selection. 42 | 43 | .. image:: _static/query_model.png 44 | :width: 75% 45 | :align: center 46 | 47 | Cell type classification 48 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 49 | 50 | :mod:`scimilarity.cell_annotation` classifies cells by finding the most 51 | similar cells in a reference catalog of 7M author annotated cells. This does not 52 | require a user imputed reference. Users can subset the reference to cell types 53 | of interest to increase annotation accuracy. 54 | 55 | .. image:: _static/annotation_model.png 56 | :width: 75% 57 | :align: center 58 | 59 | Model training 60 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 61 | 62 | :mod:`scimilarity.training_models` provides functionality to train new 63 | SCimilarity models. User provided training datasets can be used as input to 64 | Cell Ontology filtering, triplet sampling, and model training. 65 | 66 | .. image:: _static/model_training.png 67 | :width: 75% 68 | :align: center 69 | 70 | .. note:: 71 | Model training tutorials are under construction. 72 | 73 | And more! 74 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 75 | Use SCimilarity to derive new cell state signatures, compare in vivo and in 76 | vitro conditions, or simply visualize datasets in a pre-computed common space. 77 | With a pan-body measure of cell similarity, there are plenty of new 78 | opportunities. 79 | 80 | Indices 81 | -------------------------------------------------------------------------------- 82 | 83 | | :ref:`genindex` 84 | | :ref:`modindex` 85 | -------------------------------------------------------------------------------- /docs/install.rst: -------------------------------------------------------------------------------- 1 | .. _Installation: 2 | 3 | Installation and Setup 4 | ================================================================================ 5 | 6 | Installing the SCimilarity API 7 | -------------------------------------------------------------------------------- 8 | 9 | The latest API release can be installed quickly using ``pip`` in the 10 | usual manner: 11 | 12 | :: 13 | 14 | pip install scimilarity 15 | 16 | The SCimilarity API is under activate development. The latest development API 17 | can be downloaded from `GitHub `__ 18 | and installed as follows: 19 | 20 | :: 21 | 22 | git clone https://github.com/genentech/scimilarity.git 23 | cd scimilarity 24 | pip install -e . 25 | 26 | .. 
warning:: 27 | 28 | To enable rapid searches across tens of millions of cells, SCimilarity has very 29 | high memory requirements. To make queries, you will need at least 64 GB of 30 | system RAM. 31 | 32 | .. warning:: 33 | 34 | If your environment has sufficient memory but loading the model or making 35 | kNN queries crashes, that may be due to older versions of dependencies such 36 | as hnswlib or numpy. We recommend using either using the Docker container 37 | or Conda environment described below. 38 | 39 | .. note:: 40 | 41 | A GPU is not necessary for most applications, but model training will 42 | require GPU resources. 43 | 44 | Downloading the pretrained models 45 | -------------------------------------------------------------------------------- 46 | 47 | You can download the following pretrained models for use with SCimilarity from 48 | Zenodo: 49 | https://zenodo.org/records/10685499 50 | 51 | 52 | Conda environment setup 53 | -------------------------------------------------------------------------------- 54 | 55 | To install the SCimilarity API in a [Conda](https://docs.conda.io) environment 56 | we recommend this environment setup: 57 | 58 | :download:`Download environment file <_static/environment.yaml>` 59 | 60 | .. literalinclude:: _static/environment.yaml 61 | :language: YAML 62 | 63 | Followed by installing the ``scimilarity`` package via ``pip``, as above. 64 | 65 | Using the SCimilarity Docker container 66 | -------------------------------------------------------------------------------- 67 | 68 | A Docker container that includes the SCimilarity API is available from the 69 | `GitHub Container Registry `__, which can 70 | be pulled via: 71 | 72 | :: 73 | 74 | docker pull ghcr.io/genentech/scimilarity:latest 75 | 76 | Models are not included in the Docker container and must be downloaded separately. 77 | 78 | There are four preset bind points in the container: 79 | 80 | * ``/models`` 81 | * ``/data`` 82 | * ``/workspace`` 83 | * ``/scratch`` 84 | 85 | We require binding ``/models`` to your local path storing SCimilarity models, 86 | ``/data`` to your repository of scRNA-seq data, and ``/workspace`` to your 87 | notebook path. 88 | 89 | You can initiate a Jupyter Notebook session rooted in ``/workspace`` using the 90 | ``start-notebook`` command as follows: 91 | 92 | :: 93 | 94 | docker run -it --platform linux/amd64 -p 8888:8888 \ 95 | -v /path/to/workspace:/workspace \ 96 | -v /path/to/data:/data \ 97 | -v /path/to/models:/models \ 98 | ghcr.io/genentech/scimilarity:latest start-notebook 99 | -------------------------------------------------------------------------------- /docs/modules/anndata_data_models.rst: -------------------------------------------------------------------------------- 1 | scimilarity.anndata_data_models 2 | -------------------------------------------------------------------------------- 3 | 4 | .. automodule:: scimilarity.anndata_data_models 5 | :members: 6 | :show-inheritance: 7 | -------------------------------------------------------------------------------- /docs/modules/cell_annotation.rst: -------------------------------------------------------------------------------- 1 | scimilarity.cell_annotation 2 | -------------------------------------------------------------------------------- 3 | 4 | .. 
automodule:: scimilarity.cell_annotation 5 | :members: 6 | :show-inheritance: -------------------------------------------------------------------------------- /docs/modules/cell_embedding.rst: -------------------------------------------------------------------------------- 1 | scimilarity.cell_embedding 2 | -------------------------------------------------------------------------------- 3 | 4 | .. automodule:: scimilarity.cell_embedding 5 | :members: 6 | :show-inheritance: -------------------------------------------------------------------------------- /docs/modules/cell_query.rst: -------------------------------------------------------------------------------- 1 | scimilarity.cell_query 2 | -------------------------------------------------------------------------------- 3 | 4 | .. automodule:: scimilarity.cell_query 5 | :members: 6 | :show-inheritance: -------------------------------------------------------------------------------- /docs/modules/cell_search_knn.rst: -------------------------------------------------------------------------------- 1 | scimilarity.cell_search_knn 2 | -------------------------------------------------------------------------------- 3 | 4 | .. automodule:: scimilarity.cell_search_knn 5 | :members: 6 | :show-inheritance: 7 | -------------------------------------------------------------------------------- /docs/modules/interpreter.rst: -------------------------------------------------------------------------------- 1 | scimilarity.interpreter 2 | -------------------------------------------------------------------------------- 3 | 4 | .. automodule:: scimilarity.interpreter 5 | :members: 6 | :show-inheritance: -------------------------------------------------------------------------------- /docs/modules/nn_models.rst: -------------------------------------------------------------------------------- 1 | scimilarity.nn_models 2 | -------------------------------------------------------------------------------- 3 | 4 | .. automodule:: scimilarity.nn_models 5 | :members: 6 | :show-inheritance: -------------------------------------------------------------------------------- /docs/modules/ontologies.rst: -------------------------------------------------------------------------------- 1 | scimilarity.ontologies 2 | -------------------------------------------------------------------------------- 3 | 4 | .. automodule:: scimilarity.ontologies 5 | :members: 6 | :show-inheritance: -------------------------------------------------------------------------------- /docs/modules/tiledb_data_models.rst: -------------------------------------------------------------------------------- 1 | scimilarity.tiledb_data_models 2 | -------------------------------------------------------------------------------- 3 | 4 | .. automodule:: scimilarity.tiledb_data_models 5 | :members: 6 | :show-inheritance: -------------------------------------------------------------------------------- /docs/modules/training_models.rst: -------------------------------------------------------------------------------- 1 | scimilarity.training_models 2 | -------------------------------------------------------------------------------- 3 | 4 | .. 
automodule:: scimilarity.training_models 5 | :members: 6 | :show-inheritance: -------------------------------------------------------------------------------- /docs/modules/triplet_selector.rst: -------------------------------------------------------------------------------- 1 | scimilarity.triplet_selector 2 | -------------------------------------------------------------------------------- 3 | 4 | .. automodule:: scimilarity.triplet_selector 5 | :members: 6 | :show-inheritance: -------------------------------------------------------------------------------- /docs/modules/utils.rst: -------------------------------------------------------------------------------- 1 | scimilarity.utils 2 | -------------------------------------------------------------------------------- 3 | 4 | .. automodule:: scimilarity.utils 5 | :members: 6 | :show-inheritance: -------------------------------------------------------------------------------- /docs/modules/visualizations.rst: -------------------------------------------------------------------------------- 1 | scimilarity.visualizations 2 | -------------------------------------------------------------------------------- 3 | 4 | .. automodule:: scimilarity.visualizations 5 | :members: 6 | :show-inheritance: -------------------------------------------------------------------------------- /docs/modules/zarr_data_models.rst: -------------------------------------------------------------------------------- 1 | scimilarity.zarr_data_models 2 | -------------------------------------------------------------------------------- 3 | 4 | .. automodule:: scimilarity.zarr_data_models 5 | :members: 6 | :show-inheritance: -------------------------------------------------------------------------------- /docs/modules/zarr_dataset.rst: -------------------------------------------------------------------------------- 1 | scimilarity.zarr_dataset 2 | -------------------------------------------------------------------------------- 3 | 4 | .. automodule:: scimilarity.zarr_dataset 5 | :members: 6 | :show-inheritance: -------------------------------------------------------------------------------- /docs/news.rst: -------------------------------------------------------------------------------- 1 | .. _news: 2 | .. include:: ../NEWS.rst 3 | -------------------------------------------------------------------------------- /docs/notebooks/training_tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "8943365d-1db3-4b0f-9ba0-18ed16e124fd", 6 | "metadata": {}, 7 | "source": [ 8 | "# Data preparation, model training, and building post-training structures\n", 9 | "\n", 10 | "We have transitioned the SCimilarity codebase to the use of [TileDB](https://tiledb.com/) instead of the previously used [Zarr](https://zarr.dev/) to streamline dataloading, training, and building of post-training data structures. \n", 11 | "\n", 12 | "This tutorial is a high level workflow for training SCimilarity, including data structure requirement, dataloaders, and scripts to build post-training data structures.\n", 13 | "\n", 14 | "The workflow is divided into the following sections:\n", 15 | "\n", 16 | " 1. Process the data into CellArr TileDB format.\n", 17 | " 2. Train using the TileDB dataloader.\n", 18 | " 3. Build all post-training data structures, such as kNN indices and TileDB arrays." 
19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "id": "086e39f2-6d87-40bd-bb84-5a8f06011d35", 24 | "metadata": {}, 25 | "source": [ 26 | "## 1. Construct a TileDB store in CellArr format\n", 27 | "\n", 28 | "To facilate random access for a cell corpus that is too large to fit into memory, we currently recommend a TileDB structure as used in [CellArr](https://github.com/CellArr/cellarr). The construction requires a collection of AnnData H5AD files which will be organized into TileDB arrays and dataframes by CellArr.\n", 29 | "\n", 30 | "For more details on how to construct CellArr TileDB structures from AnnData H5AD files, please see the [CellArr documentation](https://cellarr.github.io/cellarr/) and the [tutorial](https://cellarr.github.io/cellarr/tutorial_cellxgene.html) to create CellArr objects from CELLxGENE datasets.\n", 31 | "\n", 32 | "**Before you start creating a CellArr object, read the below sections on requirements and processing the AnnData.**" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "id": "90bbe463-55f5-42c6-8165-b66638213df2", 38 | "metadata": {}, 39 | "source": [ 40 | "### 1.1 Processing the AnnData\n", 41 | "\n", 42 | "We recommend you process the AnnData before constructing the CellArr. There are required `obs` columns that are described in the next section.\n", 43 | "\n", 44 | "Some important requirements:\n", 45 | "\n", 46 | " 1. `var.index` must be in gene symbols and not Ensemble IDs, you will need to map Ensemble ID to gene symbol.\n", 47 | " 2. If you map Ensemble ID to gene symbol, you will need to consolidate counts in cases where multiple Ensemble IDs map to one gene symbol. We provide a utility function for this as `utils.consolidate_duplicate_symbols`.\n", 48 | " 3. Please ensure that `var.index` contains only gene symbols as sometimes it can contain a mix of symbols and Ensemble IDs. \n", 49 | "\n", 50 | "An example of this process, assuming you have mapping tables, is shown below. Please note that this may need to be modified depending on the structure of your mapping table." 
51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "id": "57f0a94b-9f2c-4eb2-8f92-763074c9a6ca", 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "def get_id2name(mapping_table_file: str):\n", 61 | " import pandas as pd\n", 62 | "\n", 63 | " id2name = pd.read_csv(mapping_table_file, delimiter=\"\\t\", index_col=0, dtype=\"unicode\")[\"Gene name\"]\n", 64 | " id2name = id2name.reset_index().drop_duplicates().set_index(\"Gene stable ID\")[\"Gene name\"]\n", 65 | " id2name = pd.concat([id2name, pd.Series(id2name.values, index=id2name.values).drop_duplicates()])\n", 66 | " return id2name\n", 67 | "\n", 68 | "def convert_ids_to_names(adata):\n", 69 | " id2name = get_id2name()\n", 70 | " if not any(adata.var.index.isin(id2name.keys())) and \"symbol\" in adata.var.columns:\n", 71 | " adata.var = adata.var.set_index(\"symbol\", drop=False)\n", 72 | " adata.var.index = adata.var.index.str.replace(\"'\", \"\")\n", 73 | " adata.var.index = adata.var.index.str.replace('\"', \"\")\n", 74 | " adata = adata[:, (adata.var.index.isin(id2name.keys())) & ~(adata.var.index.isnull())].copy()\n", 75 | " adata.var.index = id2name[adata.var.index]\n", 76 | " return adata\n", 77 | "\n", 78 | "def clean_var(adata):\n", 79 | " from scimilarity.utils import consolidate_duplicate_symbols\n", 80 | " \n", 81 | " adata = convert_ids_to_names(adata)\n", 82 | " adata = consolidate_duplicate_symbols(adata)\n", 83 | " adata.var.index.name = \"symbol\"\n", 84 | " return adata" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "id": "613ba0ca-6ac6-4cf2-abfc-0cded8fea50e", 90 | "metadata": {}, 91 | "source": [ 92 | "### 1.2 Required columns\n", 93 | "\n", 94 | "We require specific columns to be present in the metadata by default, many of which are used for filtering cells. If you would like to change the default filter and/or what columns are required, you can do so with the `filter_condition` parameter in the dataloader class `tiledb_data_models.CellMultisetDataModule`.\n", 95 | "\n", 96 | "You may include any columns you wish in the metadata but the following columns are required by default:\n", 97 | "\n", 98 | " 1. A string column for the study identifier, e.g., `datasetID`.\n", 99 | " 2. A string column for sample identifier, e.g., `sampleID`.\n", 100 | " 3. A string column for the cell type label identifier that contains the **cell ontology term identifier**, e.g., `cellTypeOntologyID`.\n", 101 | " 4. A string column named `tissue` that contains the tissue annotation.\n", 102 | " 5. A string column named `disease` that contains the disease annotation.\n", 103 | " 6. An int column named `n_genes_by_counts` that contains the number of genes expressed, as produced by `scanpy.pp.calculate_qc_metrics`.\n", 104 | " 7. An int column named `total_counts` that contains the number of UMIs, as produced by `scanpy.pp.calculate_qc_metrics`.\n", 105 | " 8. An int column named `total_counts_mt` that contains the number of mitochondrial UMIs, as produced by `scanpy.pp.calculate_qc_metrics`.\n", 106 | " 9. A float column named `pct_counts_mt` that contains the percentage of mitochondrial UMIs, as produced by `scanpy.pp.calculate_qc_metrics`.\n", 107 | " 10. An int column named `predicted_doublets` that is `1` for doublet and `0` for not doublet, as produced by [scrublet](https://github.com/swolock/scrublet) or your choice of doublet detection tool.\n", 108 | "\n", 109 | "**NOTE**: Not every column needs to have values. 
For example, some cells do not have cell type annotations, in which case you may leave that column blank.\n", 110 | "\n", 111 | "An example of this process is shown below. Please note that you may need to customize this based on the columns that exist in your data." 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "id": "cda21c09-4c8e-4bc4-86da-afa10bf5304f", 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "# datasetID and sampleID are explicit as it is not trivial to infer them from obs column names\n", 122 | "# Feel free to add cell type ontology ID, tissue, or disease if you wish to make those explicit per AnnData\n", 123 | "def clean_obs(adata, datasetID, sampleID):\n", 124 | "    # If you wish to use a doublet prediction tool other than scrublet, feel free to do so\n", 125 | "    # We will still require the \"predicted_doublets\" column as described previously\n", 126 | "    import scrublet as scr\n", 127 | "    import scanpy as sc\n", 128 | "    import numpy as np\n", 129 | "    # Determine between \"MT-\" and \"mt-\"\n", 130 | "    mito_prefix = \"MT-\"\n", 131 | "    if any(adata.var.index.str.startswith(\"mt-\")):\n", 132 | "        mito_prefix = \"mt-\"\n", 133 | "\n", 134 | "    # QC stats\n", 135 | "    adata.var[\"mt\"] = adata.var_names.str.startswith(mito_prefix)\n", 136 | "    sc.pp.calculate_qc_metrics(\n", 137 | "        adata,\n", 138 | "        qc_vars=[\"mt\"],\n", 139 | "        percent_top=None,\n", 140 | "        log1p=False,\n", 141 | "        inplace=True,\n", 142 | "        layer=\"counts\",\n", 143 | "    )\n", 144 | "    obs = adata.obs.copy()\n", 145 | "\n", 146 | "    obs.columns = [x[0].lower() + x[1:] for x in obs.columns]  # lowercase first letter\n", 147 | "\n", 148 | "    obs[\"datasetID\"] = datasetID\n", 149 | "    obs[\"sampleID\"] = sampleID\n", 150 | "\n", 151 | "    # Scrublet can fail, so we default predicted_doublets to 0 in that case\n", 152 | "    try:\n", 153 | "        scrub = scr.Scrublet(adata.layers['counts'])\n", 154 | "        doublet_scores, predicted_doublets = scrub.scrub_doublets(verbose=False)\n", 155 | "        obs[\"doublet_score\"] = doublet_scores\n", 156 | "        obs[\"predicted_doublets\"] = predicted_doublets.astype(int)\n", 157 | "    except Exception:\n", 158 | "        obs[\"doublet_score\"] = np.nan\n", 159 | "        obs[\"predicted_doublets\"] = 0\n", 160 | "        pass\n", 161 | "\n", 162 | "    # You will need to go through all columns that might contain cell type ontology\n", 163 | "    # or convert cell type name to ontology (see: scimilarity.ontologies for helper functions)\n", 164 | "    if \"cellTypeOntologyID\" not in obs.columns:\n", 165 | "        obs[\"cellTypeOntologyID\"] = \"\"\n", 166 | "\n", 167 | "    # You will need to go through all columns that might contain tissue\n", 168 | "    if \"tissue\" not in obs.columns and \"meta_tissue\" not in obs.columns:\n", 169 | "        obs[\"tissue\"] = \"\"\n", 170 | "    elif \"tissue\" not in obs.columns and \"meta_tissue\" in obs.columns:\n", 171 | "        obs = obs.rename(columns={\"meta_tissue\": \"tissue\"})\n", 172 | "\n", 173 | "    # You will need to go through all columns that might contain disease\n", 174 | "    if \"disease\" not in obs.columns and \"meta_disease\" not in obs.columns:\n", 175 | "        obs[\"disease\"] = \"\"\n", 176 | "    elif \"disease\" not in obs.columns and \"meta_disease\" in obs.columns:\n", 177 | "        obs = obs.rename(columns={\"meta_disease\": \"disease\"})\n", 178 | "\n", 179 | "    columns = [\n", 180 | "        \"datasetID\", \"sampleID\", \"cellTypeOntologyID\", \"tissue\", \"disease\",\n", 181 | "        \"n_genes_by_counts\", \"total_counts\", \"total_counts_mt\", \"pct_counts_mt\",\n", 182 | "        \"doublet_score\", 
\"predicted_doublets\",\n", 183 | " ]\n", 184 | " obs = obs[columns].copy()\n", 185 | "\n", 186 | " convert_dict = {\n", 187 | " \"datasetID\": str,\n", 188 | " \"sampleID\": str,\n", 189 | " \"cellTypeOntologyID\": str,\n", 190 | " \"tissue\": str,\n", 191 | " \"disease\": str,\n", 192 | " \"n_genes_by_counts\": int,\n", 193 | " \"total_counts\": int,\n", 194 | " \"total_counts_mt\": int,\n", 195 | " \"pct_counts_mt\": float,\n", 196 | " \"doublet_score\": float,\n", 197 | " \"predicted_doublets\": int,\n", 198 | " }\n", 199 | " adata.obs = obs.astype(convert_dict)\n", 200 | " return adata" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "id": "84451603-8e9d-4990-b87d-50c031ae5582", 206 | "metadata": {}, 207 | "source": [ 208 | "### 1.3 Reason for required columns\n", 209 | "\n", 210 | "First, the cell type ontology identifier is required because we utilize the cell ontology to determine ancestors or descendents during triplet mining. It is important that the cell type annotation is a controlled vocabulary that conforms to cell ontology.\n", 211 | "\n", 212 | "Second, the CellArr cell corpus is a database of **all** cells, not just those which will be used in training. Cells not used in training are still used in cell search. The default behavior for the dataloader is to select for cells that can be used in training, as follows:\n", 213 | "\n", 214 | " - If the cell type annotation exists and is an ontology identifier\n", 215 | " - `total_counts` > 1000\n", 216 | " - `n_genes_by_counts` > 500\n", 217 | " - `pct_counts_mt` < 20\n", 218 | " - `predicted_doublets` == 0\n", 219 | "\n", 220 | "These criteria can be customized in the dataloader, but the above are the defaults.\n", 221 | "\n", 222 | "Lastly, this standardization of datasets makes it easier to create a CellArr object and provides expected columns for the dataloader. Feel free to add additional columns as you wish." 223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "id": "9d1b3213-4bf4-43c2-ba3f-1fd7f19ddc0a", 228 | "metadata": {}, 229 | "source": [ 230 | "### 1.4 Select a gene space\n", 231 | "\n", 232 | "CellArr will create tiledb structures that cover the union set of all genes in all your datasets. This is usually not efficient for the model, as only a subset of genes are well represented across studies, so we filter the gene space for model training.\n", 233 | "\n", 234 | "An example of this process is shown below, utilizing the sample metadata to look at the gene vs study count distribution and select well represented genes. 
**You should save the gene order file so that it can be used for training.**" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "id": "8c794227-4d6a-4991-bc66-261c0df2310e", 241 | "metadata": {}, 242 | "outputs": [], 243 | "source": [ 244 | "def get_genes_vs_studies(cellarr_path):\n", 245 | "    import os\n", 246 | "    import tiledb\n", 247 | "    import pandas as pd\n", 248 | "    import seaborn as sns\n", 249 | "    from matplotlib import pyplot as plt\n", 250 | "    plt.rcParams[\"pdf.fonttype\"] = 42\n", 251 | "\n", 252 | "    SAMPLEURI = \"sample_metadata\"\n", 253 | "\n", 254 | "    # Get studies, one per row of a dataframe\n", 255 | "    sample_tdb = tiledb.open(os.path.join(cellarr_path, SAMPLEURI), \"r\")\n", 256 | "    sample_df = sample_tdb.query(attrs=[\"datasetID\", \"cellarr_cell_counts\"]).df[:]\n", 257 | "    sample_tdb.close()\n", 258 | "\n", 259 | "    df = sample_df.groupby(\"datasetID\", observed=True)[\"cellarr_cell_counts\"].max().sort_values(ascending=False).reset_index()\n", 260 | "    df = df[df.index.isin(df[\"datasetID\"].drop_duplicates().index.values)]\n", 261 | "\n", 262 | "    selected = []\n", 263 | "    for idx, row in df.iterrows():\n", 264 | "        datasetID = row[\"datasetID\"]\n", 265 | "        selected.append(sample_df[sample_df[\"datasetID\"]==datasetID].index[0])\n", 266 | "    print(len(selected))\n", 267 | "\n", 268 | "    sample_tdb = tiledb.open(os.path.join(cellarr_path, SAMPLEURI), \"r\")\n", 269 | "    genes_df = sample_tdb.df[sorted(selected)]\n", 270 | "    sample_tdb.close()\n", 271 | "\n", 272 | "    # Get all original genes from each study\n", 273 | "    gene_counts = {}\n", 274 | "    for idx, row in genes_df.iterrows():\n", 275 | "        genes = row['cellarr_original_gene_set'].split(',')\n", 276 | "        for g in genes:\n", 277 | "            if g not in gene_counts:\n", 278 | "                gene_counts[g] = 0\n", 279 | "            gene_counts[g] += 1\n", 280 | "    gene_counts = {k: v for k, v in sorted(gene_counts.items(), key=lambda item: item[1], reverse=False)}\n", 281 | "\n", 282 | "    # Construct gene vs count dataframe and visualize\n", 283 | "    gene = []\n", 284 | "    count = []\n", 285 | "    for k, v in gene_counts.items():\n", 286 | "        gene.append(k)\n", 287 | "        count.append(v)\n", 288 | "    gene_counts_df = pd.DataFrame({\"gene\": gene, \"count\": count})\n", 289 | "    gene_counts_df = gene_counts_df.sort_values(by=\"count\", ascending=False).reset_index(drop=True)\n", 290 | "\n", 291 | "    fig, ax = plt.subplots(1)\n", 292 | "    sns.barplot(ax=ax, x=\"gene\", y=\"count\", data=gene_counts_df)\n", 293 | "    plt.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)\n", 294 | "    plt.tight_layout()\n", 295 | "\n", 296 | "    return gene_counts_df\n", 297 | "\n", 298 | "def select_genes(gene_counts_df, min_studies, filename=\"scimilarity_gene_order.tsv\"):\n", 299 | "    gene_order = gene_counts_df[gene_counts_df[\"count\"]>=min_studies][\"gene\"].values.tolist()\n", 300 | "    gene_order = sorted(gene_order)\n", 301 | "    with open(filename, \"w\") as f:\n", 302 | "        f.write(\"\\n\".join(gene_order))" 303 | ] 304 | }, 305 | { 306 | "cell_type": "markdown", 307 | "id": "ce12b4c0-c63b-4d76-9f04-0f1d9d2b3cba", 308 | "metadata": {}, 309 | "source": [ 310 | "## 2. 
Training\n", 311 | "\n", 312 | "We provide an example training script at `scripts/train.py` which you can customize to your specifications.\n", 313 | "\n", 314 | "Training follows the `pytorch-lightning` paradigm:\n", 315 | " - Data module class `CellMultisetDataModule` from `tiledb_data_models`\n", 316 | " - Training module class `MetricLearning` from `training_models`\n", 317 | "\n", 318 | "Training can be resumed if interrupted by using the checkpoints logged by `pytorch-lightning`. The training script already implements this by specifying the log directory. By default, we use `tensorboard` as our logger, but you can change this to your preferred logger.\n", 319 | "\n", 320 | "Once training is complete, all files will be saved to the model directory given by the user. This includes csv files that contain the indices and metadata for cells that were used in training and validation." 321 | ] 322 | }, 323 | { 324 | "cell_type": "markdown", 325 | "id": "2a3afff8-8b1a-4c6b-ba26-a66fba324449", 326 | "metadata": {}, 327 | "source": [ 328 | "## 3. Post-training structures" 329 | ] 330 | }, 331 | { 332 | "cell_type": "markdown", 333 | "id": "c71986f0-af76-487b-b66c-d1e8a6f3c7a6", 334 | "metadata": {}, 335 | "source": [ 336 | "### 3.1 Embeddings tiledb\n", 337 | "\n", 338 | "The first post-training structure to create is the embeddings TileDB, which will be used to create kNN indices.\n", 339 | "\n", 340 | "We provide an example script at `scripts/build_embeddings.py` which embeds all cells in your CellArr objects and saves a TileDB array to `{model_path}/cellsearch/cell_embedding`." 341 | ] 342 | }, 343 | { 344 | "cell_type": "markdown", 345 | "id": "4c76e0b2-cba5-4848-98d6-00410df104f1", 346 | "metadata": {}, 347 | "source": [ 348 | "### 3.2 Annotation kNN\n", 349 | "\n", 350 | "The annotation kNN uses the embeddings TileDB and the training cells csv (`train_cells.csv.gz`) to create a `hnswlib` kNN index.\n", 351 | "\n", 352 | "We provide an example script at `scripts/build_annotation_knn.py`. By default, the kNN index file will be saved to `{model_path}/annotation/labelled_kNN.bin` and the reference labels will be saved to `{model_path}/annotation/reference_labels.tsv`." 353 | ] 354 | }, 355 | { 356 | "cell_type": "markdown", 357 | "id": "ae102661-f637-47b4-8de9-ffbe35587041", 358 | "metadata": {}, 359 | "source": [ 360 | "### 3.3 Cell search kNN\n", 361 | "\n", 362 | "The cell search kNN uses the embeddings TileDB to create a `tiledb_vector_search` kNN index for all cells in your CellArr object.\n", 363 | "\n", 364 | "We provide an example script at `scripts/build_cellsearch_knn.py`. By default, the kNN index will be saved to `{model_path}/cellsearch/full_kNN.bin`.\n", 365 | "\n", 366 | "**NOTE**: Previously, we used `hnswlib` as the cell search kNN index. But as the data size grew, we transitioned to an on-disk kNN. 
The `CellQuery` class still defaults to an `hnswlib` index, but if you followed the process described in this tutorial, you should initialize `CellQuery` with:\n", 367 | "\n", 368 | "`cq = CellQuery(model_path, knn_type=\"tiledb_vector_search\")`" 369 | ] 370 | }, 371 | { 372 | "cell_type": "markdown", 373 | "id": "6a2e1255-7b5b-4d54-a4c2-28133246b12f", 374 | "metadata": {}, 375 | "source": [ 376 | "### 3.4 Cell search metadata\n", 377 | "\n", 378 | "The `CellQuery` class includes a TileDB dataframe with cell metadata containing precomputed cell type predictions for all cells, as well as other metadata from the CellArr object.\n", 379 | "\n", 380 | "We provide an example script at `scripts/build_cellsearch_metadata.py` that uses the CellArr object, embeddings TileDB, and annotation kNN to construct the TileDB dataframe containing the cell metadata.\n", 381 | "\n", 382 | "This process will copy over cell metadata from the CellArr object into the model's cell metadata. Though this is somewhat inefficient in storage, we prefer the separation of the \"raw\" data contained in the CellArr object and the \"processed\" data in the model structures, as you may wish to train completely different models using the CellArr object.\n", 383 | "\n", 384 | "**NOTE**: We often apply a safelist for cell type predictions for the cell metadata to remove cell types that are too general or too specific. The script has an option to use a safelist, given as a file with one cell type name per line. It is up to the user to decide whether they wish to use this and select the cell types they wish to safelist." 385 | ] 386 | }, 387 | { 388 | "cell_type": "markdown", 389 | "id": "ad48d52c-f9b9-4aa4-8adc-59fdd2dd7f64", 390 | "metadata": {}, 391 | "source": [ 392 | "## Conclusion\n", 393 | "\n", 394 | "At this point you will have all the structures you need to run SCimilarity.\n", 395 | "\n", 396 | "One advantage of using the CellArr structure is that you can retrieve the original count data using `utils.adata_from_tiledb` based on indices in the cell metadata of the model, which is aligned with the CellArr index. This includes results from kNN cell search or from filters you may want to apply to the cell metadata (e.g., \"get all regulatory T cells\").\n", 397 | "\n", 398 | "An example of how to search and then retrieve count data is shown below.\n",
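"\n", "In addition to the kNN search example below, a minimal sketch of the metadata filter route (assuming the cell metadata TileDB built in section 3.4, and with hypothetical `model_path` and `cellarr_path` variables pointing to your model and CellArr object) could look like:\n", "\n", "```python\n", "import os, tiledb\n", "from scimilarity.utils import adata_from_tiledb\n", "\n", "# open the model's cell metadata TileDB dataframe (built in section 3.4)\n", "cell_metadata_tdb = tiledb.open(os.path.join(model_path, \"cellsearch\", \"cell_metadata\"), \"r\")\n", "cell_metadata = cell_metadata_tdb.df[:]\n", "cell_metadata_tdb.close()\n", "\n", "# e.g., all cells predicted as regulatory T cells\n", "tregs = cell_metadata[cell_metadata[\"prediction\"] == \"regulatory T cell\"]\n", "\n", "# retrieve the original counts from the CellArr object for those cells\n", "tregs_adata = adata_from_tiledb(tregs[\"index\"].values, tiledb_base_path=cellarr_path)\n", "```"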
399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": null, 404 | "id": "4a18ada6-b8e3-4edb-957f-72806227681b", 405 | "metadata": {}, 406 | "outputs": [], 407 | "source": [ 408 | "def search_then_gather_cells(model_path, cellarr_path, adata, centroid_column_name):\n", 409 | " from scimilarity import CellQuery\n", 410 | " from scimilarity.utils import adata_from_tiledb \n", 411 | "\n", 412 | " cq = CellQuery(model_path, knn_type=\"tiledb_vector_search\")\n", 413 | " centroid_embedding, nn_idxs, nn_dists, results_metadata, qc_stats = cq.search_centroid_nearest(adata, centroid_column_name)\n", 414 | " \n", 415 | " results_adata = adata_from_tiledb(results_metadata[\"index\"].values, tiledb_base_path=cellarr_path)\n", 416 | "\n", 417 | " #The above results are from the CellArr raw data, so augment the obs with `results_metadata` \n", 418 | " for c in results_metadata.columns:\n", 419 | " if c not in results_adata.obs:\n", 420 | " results_adata.obs[c] = results_metadata[c].values\n", 421 | " return results_adata" 422 | ] 423 | } 424 | ], 425 | "metadata": { 426 | "kernelspec": { 427 | "display_name": "Python 3 (ipykernel)", 428 | "language": "python", 429 | "name": "python3" 430 | }, 431 | "language_info": { 432 | "codemirror_mode": { 433 | "name": "ipython", 434 | "version": 3 435 | }, 436 | "file_extension": ".py", 437 | "mimetype": "text/x-python", 438 | "name": "python", 439 | "nbconvert_exporter": "python", 440 | "pygments_lexer": "ipython3", 441 | "version": "3.10.14" 442 | } 443 | }, 444 | "nbformat": 4, 445 | "nbformat_minor": 5 446 | } 447 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # Requirements file for building the documentation 2 | sphinx>=5.3.0 3 | pydata-sphinx-theme 4 | nbsphinx 5 | sphinx-gallery==0.10 6 | -------------------------------------------------------------------------------- /docs/tutorials.rst: -------------------------------------------------------------------------------- 1 | .. _Tutorials: 2 | 3 | Tutorials 4 | ================================================================================ 5 | 6 | .. nbgallery:: 7 | 8 | Individual cell search 9 | Gene signature based search 10 | Cluster-based search 11 | Cell type annotation 12 | Gene attribution and comparison 13 | Model training 14 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | # AVOID CHANGING REQUIRES: IT WILL BE UPDATED BY PYSCAFFOLD! 
3 | requires = ["setuptools>=46.1.0", "setuptools_scm[toml]>=5", "wheel"] 4 | build-backend = "setuptools.build_meta" 5 | 6 | [tool.setuptools_scm] 7 | # For smarter version schemes and other configuration options, 8 | # check out https://github.com/pypa/setuptools_scm 9 | version_scheme = "no-guess-dev" 10 | -------------------------------------------------------------------------------- /scripts/build_annotation_knn.py: -------------------------------------------------------------------------------- 1 | import anndata 2 | import argparse 3 | import hnswlib 4 | import os, sys 5 | import numpy as np 6 | import pandas as pd 7 | import tiledb 8 | from tqdm import tqdm 9 | 10 | import warnings 11 | warnings.filterwarnings("ignore") 12 | 13 | 14 | def main(): 15 | parser = argparse.ArgumentParser(description="Build annotation knn from precomputed embeddings") 16 | parser.add_argument("-m", type=str, help="model path") 17 | parser.add_argument("-b", type=int, default=500000, help="cell buffer size") 18 | parser.add_argument("--label_column_name", type=str, default="cellTypeName", help="label column name in metadata") 19 | parser.add_argument("--study_column_name", type=str, default="datasetID", help="study column name in metadata") 20 | parser.add_argument("--knn", type=str, default="labelled_kNN.bin", help="knn filename") 21 | parser.add_argument("--labels", type=str, default="reference_labels.tsv", help="labels filename") 22 | parser.add_argument("--safelist_file", type=str, default=None, help="optional cell type safelist filename") 23 | parser.add_argument("--ef_construction", type=int, default=1000, help="hnswlib ef construction parameter") 24 | parser.add_argument("--M_construction", type=int, default=80, help="hnswlib M construction parameter") 25 | args = parser.parse_args() 26 | print(args) 27 | 28 | model_path = args.m 29 | buffer_size = args.b 30 | label_column_name = args.label_column_name 31 | study_column_name = args.study_column_name 32 | knn_filename = args.knn 33 | label_filename = args.labels 34 | ef_construction = args.ef_construction 35 | M = args.M_construction 36 | 37 | # tileDB config 38 | cfg = tiledb.Config() 39 | cfg["sm.mem.total_budget"] = 50000000000 # 50G 40 | 41 | # training data 42 | dataframe_path = os.path.join(model_path, "train_cells.csv.gz") 43 | reference_df = pd.read_csv(dataframe_path, index_col=0) 44 | 45 | if args.safelist_file is not None: 46 | with open(args.safelist_file, "r") as fh: 47 | safelist = [line.strip() for line in fh] 48 | reference_df = reference_df[reference_df[label_column_name].isin(safelist)].copy() 49 | assert reference_df.shape[0] > 0, "No valid safelist entries in data" 50 | 51 | # precomputed embeddings 52 | embedding_tdb_uri = os.path.join(model_path, "cellsearch", "cell_embedding") 53 | 54 | embeddings = [] 55 | labels = [] 56 | studies = [] 57 | for i in tqdm(range(0, reference_df.shape[0], buffer_size)): 58 | n = min(i + buffer_size, reference_df.shape[0]) 59 | df = reference_df.iloc[range(i, n)].copy() 60 | 61 | embedding_tdb = tiledb.open(embedding_tdb_uri, "r", config=cfg) 62 | cell_idx = df.index.tolist() 63 | attr = embedding_tdb.schema.attr(0).name 64 | embedding = embedding_tdb.query(attrs=[attr], coords=True).multi_index[cell_idx][attr] 65 | embedding_tdb.close() 66 | embeddings.append(embedding) 67 | labels.extend(df[label_column_name].tolist()) 68 | studies.extend(df[study_column_name].tolist()) 69 | embeddings = np.vstack(embeddings) 70 | print("embeddings", embeddings.shape) 71 | 72 | annotation_path = 
os.path.join(model_path, "annotation") 73 | os.makedirs(annotation_path, exist_ok=True) 74 | 75 | # save labels 76 | labels_fullpath = os.path.join(annotation_path, label_filename) 77 | if os.path.isfile(labels_fullpath): # backup existing 78 | os.rename(labels_fullpath, labels_fullpath + ".bak") 79 | with open(labels_fullpath, "w") as f: 80 | for i in range(len(labels)): 81 | f.write(f"{labels[i]}\t{studies[i]}\n") 82 | 83 | # build knn 84 | n_cells, n_dims = embeddings.shape 85 | knn = hnswlib.Index(space="cosine", dim=n_dims) 86 | knn.init_index(max_elements=n_cells, ef_construction=ef_construction, M=M) 87 | knn.set_ef(ef_construction) 88 | knn.add_items(embeddings, range(len(embeddings))) 89 | 90 | # save knn 91 | knn_fullpath = os.path.join(annotation_path, knn_filename) 92 | if os.path.isfile(knn_fullpath): # backup existing 93 | os.rename(knn_fullpath, knn_fullpath + ".bak") 94 | knn.save_index(knn_fullpath) 95 | 96 | 97 | if __name__ == "__main__": 98 | main() -------------------------------------------------------------------------------- /scripts/build_cellsearch_knn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import tiledb 5 | import tiledb.vector_search as vs 6 | from tiledb.vector_search import _tiledbvspy as vspy 7 | 8 | cfg = tiledb.Config() 9 | cfg["sm.mem.total_budget"] = 50000000000 # 50G 10 | 11 | def main(): 12 | parser = argparse.ArgumentParser(description="Build cellsearch knn from embeddings tiledb") 13 | parser.add_argument("-m", type=str, help="model path") 14 | parser.add_argument("--knn_filename", type=str, default="full_kNN.bin", help="knn filename") 15 | parser.add_argument("--knn_type", type=str, default="tiledb_vector_search", help="Type of knn: ['hnswlib', 'tiledb_vector_search']") 16 | parser.add_argument("--ef_construction", type=int, default=1000, help="hnswlib ef construction parameter") 17 | parser.add_argument("--M_construction", type=int, default=80, help="hnswlib M construction parameter") 18 | args = parser.parse_args() 19 | print(args) 20 | 21 | model_path = args.m 22 | knn_filename = args.knn_filename 23 | knn_type = args.knn_type 24 | ef_construction = args.ef_construction 25 | M = args.M_construction 26 | 27 | # embeddings 28 | cellsearch_path = os.path.join(model_path, "cellsearch") 29 | embedding_tdb = tiledb.open(os.path.join(cellsearch_path, "cell_embedding"), "r", config=cfg) 30 | attr = embedding_tdb.schema.attr(0).name 31 | embeddings = embedding_tdb[:][attr] 32 | embedding_tdb.close() 33 | 34 | # build knn 35 | knn_fullpath = os.path.join(cellsearch_path, knn_filename) 36 | if knn_type == "hnswlib": 37 | # build knn 38 | n_cells, n_dims = embeddings.shape 39 | knn = hnswlib.Index(space="cosine", dim=n_dims) 40 | knn.init_index(max_elements=n_cells, ef_construction=ef_construction, M=M) 41 | knn.set_ef(ef_construction) 42 | knn.add_items(embeddings, range(len(embeddings))) 43 | knn.save_index(os.path.join(cellsearch_path, knn_filename)) 44 | elif knn_type == "tiledb_vector_search": 45 | knn = vs.ingest( 46 | index_type="IVF_FLAT", 47 | index_uri=os.path.join(cellsearch_path, knn_filename), 48 | input_vectors=embeddings, 49 | distance_metric=vspy.DistanceMetric.COSINE, 50 | normalized=True, 51 | filters=tiledb.FilterList([tiledb.LZ4Filter()]) 52 | ) 53 | knn.vacuum() 54 | 55 | print("Vector array URI:", knn.db_uri, "\n") 56 | A = tiledb.open(knn.db_uri) 57 | print("Vector array schema:\n") 58 | print(A.schema) 59 | print(A.nonempty_domain()) 60 | A.close() 
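        # The resulting on-disk index is what scimilarity.CellQuery loads when it is
        # initialized with knn_type="tiledb_vector_search" (see the training tutorial).
        # NOTE: the hnswlib branch above additionally assumes `import hnswlib` has been
        # added to this script's imports.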
61 | 62 | if __name__ == "__main__": 63 | main() -------------------------------------------------------------------------------- /scripts/build_cellsearch_metadata.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import os, sys 4 | import pandas as pd 5 | import tiledb 6 | from tqdm import tqdm 7 | 8 | from scimilarity import CellAnnotation 9 | from scimilarity.ontologies import import_cell_ontology, get_id_mapper 10 | 11 | import warnings 12 | warnings.filterwarnings("ignore") 13 | 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser(description="Build cellsearch metadata") 17 | parser.add_argument("-t", type=str, help="tiledb base path") 18 | parser.add_argument("-m", type=str, help="model path") 19 | parser.add_argument("-b", type=int, default=100000, help="cell buffer size") 20 | parser.add_argument("--study_column", type=str, default="datasetID", help="study column in tiledb, will be renamed to 'study'") 21 | parser.add_argument("--sample_column", type=str, default="sampleID", help="sample column in tiledb, will be renamed to 'sample'") 22 | parser.add_argument("--tissue_column", type=str, default="tissue", help="tissue column in tiledb, will be renamed to 'tissue'") 23 | parser.add_argument("--disease_column", type=str,default="disease", help="disease column in tiledb, will be renamed to 'disease'") 24 | parser.add_argument("--knn", type=str, default="labelled_kNN.bin", help="knn filename") 25 | parser.add_argument("--labels", type=str, default="reference_labels.tsv", help="labels filename") 26 | parser.add_argument("--safelist_file", type=str,default=None, help="An optional file for a safelist of cell type names in cell type prediction, one per line") 27 | args = parser.parse_args() 28 | print(args) 29 | 30 | tiledb_base_path = args.t 31 | model_path = args.m 32 | buffer_size = args.b 33 | study_column = args.study_column 34 | sample_column = args.sample_column 35 | tissue_column = args.tissue_column 36 | disease_column = args.disease_column 37 | 38 | # paths 39 | CELLURI = "cell_metadata" 40 | cellsearch_path = os.path.join(model_path, "cellsearch") 41 | os.makedirs(cellsearch_path, exist_ok=True) 42 | 43 | # tileDB config 44 | cfg = tiledb.Config() 45 | cfg["sm.mem.total_budget"] = 50000000000 # 50G 46 | 47 | # cell metadata 48 | cell_tdb = tiledb.open(os.path.join(tiledb_base_path, CELLURI), "r", config=cfg) 49 | cell_metadata = cell_tdb.df[:] 50 | cell_tdb.close() 51 | cell_metadata = cell_metadata.reset_index(drop=False, names="index") 52 | cell_metadata = cell_metadata.rename( 53 | columns={study_column: "study", sample_column: "sample", tissue_column: "tissue", disease_column: "disease"} 54 | ) 55 | print(f"cellarr metadata: {cell_metadata.shape}") 56 | print(f"cellarr metadata: {cell_metadata.columns}") 57 | 58 | # map cell type names 59 | onto = import_cell_ontology() 60 | id2name = get_id_mapper(onto) 61 | cell_metadata["author_label"] = cell_metadata["cellTypeOntologyID"].map(id2name).astype(str) 62 | 63 | # training and validation metadata 64 | train_df = pd.read_csv(os.path.join(model_path, "train_cells.csv.gz"), index_col=0) 65 | val_df = pd.read_csv(os.path.join(model_path, "val_cells.csv.gz"), index_col=0) 66 | print(f"training: {train_df.shape}") 67 | print(f"validation: {val_df.shape}") 68 | 69 | # annotation model 70 | filenames = { 71 | "knn": args.knn, 72 | "celltype_labels": args.labels, 73 | } 74 | model = CellAnnotation(model_path=model_path, filenames=filenames) 75 | 76 | if 
args.safelist_file is not None: 77 | with open(args.safelist_file, "r") as fh: 78 | safelist = [line.strip() for line in fh] 79 | model.safelist_celltypes(safelist) 80 | 81 | embedding_tdb = tiledb.open(os.path.join(cellsearch_path, "cell_embedding"), "r", config=cfg) 82 | prediction_list = [] 83 | prediction_nn_dist_list = [] 84 | data_type_list = [] 85 | for i in tqdm(range(0, cell_metadata.shape[0], buffer_size)): 86 | n = min(i + buffer_size, cell_metadata.shape[0]) 87 | df = cell_metadata.iloc[range(i, n)].copy() 88 | cell_idx = df.index.tolist() 89 | attr = embedding_tdb.schema.attr(0).name 90 | embedding = embedding_tdb.query(attrs=[attr], coords=True).multi_index[cell_idx][attr] 91 | 92 | predictions, _, distances, _ = model.get_predictions_knn(embedding, disable_progress=True) 93 | nn_dist = distances.min(axis=1) 94 | prediction_list.extend(predictions.values.tolist()) 95 | prediction_nn_dist_list.extend(nn_dist.tolist()) 96 | 97 | in_train = [x in train_df.index for x in cell_idx] 98 | in_val = [x in val_df.index for x in cell_idx] 99 | for j in range(len(in_train)): 100 | if in_train[j]: 101 | data_type_list.append("train") 102 | elif in_val[j]: 103 | data_type_list.append("test") 104 | else: 105 | data_type_list.append("NA") 106 | embedding_tdb.close() 107 | 108 | cell_metadata["data_type"] = data_type_list 109 | cell_metadata["prediction"] = prediction_list 110 | cell_metadata["prediction_nn_dist"] = prediction_nn_dist_list 111 | 112 | cell_metadata_tdb_uri = os.path.join(cellsearch_path, "cell_metadata") 113 | tiledb.from_pandas(cell_metadata_tdb_uri, cell_metadata) 114 | 115 | cell_metadata_tdb = tiledb.open(cell_metadata_tdb_uri, "r", config=cfg) 116 | print(cell_metadata_tdb.shape) 117 | print(cell_metadata_tdb.schema) 118 | cell_metadata_tdb.close() 119 | 120 | if __name__ == "__main__": 121 | main() -------------------------------------------------------------------------------- /scripts/build_embeddings.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import sys 5 | from tqdm import tqdm 6 | 7 | import tiledb 8 | import tiledb.vector_search as vs 9 | from tiledb.vector_search import _tiledbvspy as vspy 10 | import numpy as np 11 | from scipy.sparse import coo_matrix, diags 12 | 13 | from scimilarity import CellEmbedding 14 | from scimilarity.utils import write_array_to_tiledb, optimize_tiledb_array 15 | 16 | cfg = tiledb.Config() 17 | cfg["sm.mem.total_budget"] = 50000000000 # 50G 18 | 19 | def get_expression(matrix_tdb, matrix_shape, cell_idx, gene_indices, target_sum=1e4): 20 | results = matrix_tdb[cell_idx, :] 21 | counts = coo_matrix((results["data"], (results["cell_index"], results["gene_index"])), shape=matrix_shape).tocsr() 22 | counts = counts[cell_idx, :] 23 | counts = counts[:, gene_indices] 24 | 25 | X = counts.astype(np.float32) 26 | 27 | # normalize to target sum 28 | row_sums = np.ravel(X.sum(axis=1)) # row sums as a 1D array 29 | # avoid division by zero by setting zero sums to one (they will remain zero after normalization) 30 | row_sums[row_sums == 0] = 1 31 | # create a sparse diagonal matrix with the inverse of the row sums 32 | inv_row_sums = diags(1 / row_sums).tocsr() 33 | # normalize the rows to sum to 1 34 | normalized_matrix = inv_row_sums.dot(X) 35 | # scale the rows sum to target_sum 36 | X = normalized_matrix.multiply(target_sum) 37 | X = X.log1p() 38 | 39 | return X 40 | 41 | def main(): 42 | parser = argparse.ArgumentParser(description="Build 
embeddings tiledb") 43 | parser.add_argument("-t", type=str, help="CellArr base path") 44 | parser.add_argument("-m", type=str, help="model path") 45 | parser.add_argument("-b", type=int, default=100000, help="batch size") 46 | args = parser.parse_args() 47 | print(args) 48 | 49 | model_path = args.m 50 | batch_size = args.b 51 | 52 | # model 53 | ce = CellEmbedding(model_path) 54 | cellsearch_path = os.path.join(model_path, "cellsearch") 55 | embedding_tdb_uri = os.path.join(cellsearch_path, "cell_embedding") 56 | 57 | # cellarr 58 | tiledb_base_path = args.t 59 | GENEURI = "gene_annotation" 60 | COUNTSURI = "counts" 61 | 62 | # gene space alignment 63 | gene_tdb = tiledb.open(os.path.join(tiledb_base_path, GENEURI), "r", config=cfg) 64 | genes = gene_tdb.query(attrs=["cellarr_gene_index"]).df[:]["cellarr_gene_index"].tolist() 65 | gene_tdb.close() 66 | gene_indices = [genes.index(x) for x in ce.gene_order] 67 | 68 | # counts matrix 69 | matrix_tdb_uri = os.path.join(tiledb_base_path, COUNTSURI) 70 | matrix_tdb = tiledb.open(os.path.join(tiledb_base_path, COUNTSURI), "r", config=cfg) 71 | matrix_shape = (matrix_tdb.nonempty_domain()[0][1] + 1, matrix_tdb.nonempty_domain()[1][1] + 1) 72 | print("Cell counts:", matrix_shape) 73 | 74 | # array schema 75 | xdimtype = np.uint32 76 | ydimtype = np.uint32 77 | value_type = np.float32 78 | 79 | xdim = tiledb.Dim(name="x", domain=(0, matrix_shape[0] - 1), tile=10000, dtype=xdimtype) 80 | ydim = tiledb.Dim(name="y", domain=(0, ce.latent_dim - 1), tile=ce.latent_dim, dtype=ydimtype) 81 | dom = tiledb.Domain(xdim, ydim) 82 | 83 | attr = tiledb.Attr( 84 | name="data", 85 | dtype=value_type, 86 | filters=tiledb.FilterList([tiledb.LZ4Filter()]), 87 | ) 88 | 89 | schema = tiledb.ArraySchema( 90 | domain=dom, 91 | sparse=False, 92 | cell_order="row-major", 93 | tile_order="row-major", 94 | attrs=[attr], 95 | ) 96 | 97 | if os.path.exists(embedding_tdb_uri): 98 | shutil.rmtree(embedding_tdb_uri) 99 | tiledb.Array.create(embedding_tdb_uri, schema) 100 | 101 | # write to array 102 | embeddings = [] 103 | embedding_tdb = tiledb.open(embedding_tdb_uri, "w", config=cfg) 104 | for i in tqdm(range(0, matrix_shape[0], batch_size)): 105 | j = min(i + batch_size, matrix_shape[0]) 106 | cell_idx = slice(i, j) 107 | X = get_expression(matrix_tdb, matrix_shape, cell_idx, gene_indices) 108 | embedding = ce.get_embeddings(X).astype(value_type) 109 | embedding_tdb[cell_idx] = embedding 110 | embeddings.append(embedding) 111 | matrix_tdb.close() 112 | embedding_tdb.close() 113 | 114 | embedding_tdb = tiledb.open(embedding_tdb_uri, "r", config=cfg) 115 | print("Embeddings tiledb:", embedding_tdb.nonempty_domain()) 116 | embeddings = np.vstack(embeddings) 117 | print("Embeddings numpy:", embeddings.shape) 118 | 119 | optimize_tiledb_array(embedding_tdb_uri) 120 | 121 | if __name__ == "__main__": 122 | main() -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | import json 4 | import tiledb 5 | 6 | import pytorch_lightning as pl 7 | from scimilarity.tiledb_data_models import CellMultisetDataModule 8 | from scimilarity.training_models import MetricLearning 9 | 10 | 11 | # Studies to hold out for validation 12 | val_studies = [ 13 | "DS000012424", 14 | "DS000015749", 15 | "DS000012615", 16 | "DS000012500", 17 | "DS000012585", 18 | "DS000016072", 19 | "DS000012588", 20 | "DS000015711", 21 | "DS000015743", 22 | 
"DS000015752", 23 | "DS000010060", 24 | "DS000010475", 25 | "DS000011735", 26 | "DS000012592", 27 | "DS000012595", 28 | ] 29 | 30 | # Studies to exclude from training 31 | exclude_studies = [ 32 | "DS000010493", 33 | "DS000011454", 34 | "DS000015993", 35 | "DS000011376", 36 | "DS000015862", 37 | "DS000012588", 38 | "DS000016559", 39 | "DS000012617", 40 | "DS000016065", 41 | "DS000010411", 42 | "DS000012503", 43 | "DS000011665", 44 | "DS000010661", 45 | "DS000017926", 46 | "DS000010632", 47 | "DS000017592", 48 | "DS000015896", 49 | "DS000010633", 50 | "DS000016065", 51 | "DS000016521", 52 | "DS000013506", 53 | "DS000017568", 54 | "DS000018537", 55 | "DS000016407", 56 | "DS000012183", 57 | "DS000014907", 58 | "DS000015699", 59 | "DS000010642", 60 | "DS000016526", 61 | ] 62 | 63 | 64 | def train(args): 65 | tiledb_base_path = args.tiledb 66 | hidden_dim = args.hidden_dim 67 | latent_dim = args.latent_dim 68 | gene_order = args.gene_order 69 | margin = args.m 70 | negative_selection = args.t 71 | triplet_loss_weight = args.w 72 | lr = args.l 73 | batch_size = args.b 74 | n_batches = args.n 75 | max_epochs = args.e 76 | cosine_annealing_tmax = args.cosine_annealing_tmax 77 | suffix = args.suffix 78 | 79 | if cosine_annealing_tmax == 0: 80 | cosine_annealing_tmax = max_epochs 81 | 82 | model_name = ( 83 | f"model_{batch_size}_{margin}_{latent_dim}_{len(hidden_dim)}_{triplet_loss_weight}_{suffix}" 84 | ) 85 | 86 | model_folder = args.model_folder 87 | os.makedirs(model_folder, exist_ok=True) 88 | log_folder = args.log_folder 89 | os.makedirs(log_folder, exist_ok=True) 90 | result_folder = os.path.join(args.result_folder, model_name) 91 | os.makedirs(result_folder, exist_ok=True) 92 | 93 | print(model_name) 94 | print(args) 95 | if os.path.isdir(os.path.join(model_folder, model_name)): 96 | sys.exit(0) 97 | 98 | # Let's filter out cancer datasets based on disease annotation 99 | sample_tdb = tiledb.open(os.path.join(tiledb_base_path, "sample_metadata"), "r") 100 | sample_df = sample_tdb.query(attrs=["datasetID", "sampleID", "disease"]).df[:] 101 | sample_tdb.close() 102 | 103 | cancer_keywords = [ 104 | "cancer", 105 | "carcinoma", 106 | "leukemia", 107 | "myeloma", 108 | "glioma", 109 | "tumor", 110 | "metastati", 111 | "melanoma", 112 | "blastoma", 113 | "sarcoma", 114 | "cytoma", 115 | "lymphoma", 116 | "adenoma", 117 | "endothelioma", 118 | "teratoma", 119 | "lipoma", 120 | "leiomyoma", 121 | "meningioma", 122 | "ependymoma", 123 | ] 124 | mask = sample_df["disease"].apply(lambda x: any([k in x for k in cancer_keywords])) 125 | sample_df = sample_df[mask].set_index(["datasetID", "sampleID"]) 126 | 127 | # Exclude specific samples based on the above cancer keyword search 128 | exclude_samples = {} 129 | for study, sample in sample_df.index: 130 | if study not in exclude_samples: 131 | exclude_samples[study] = [] 132 | exclude_samples[study].append(sample) 133 | 134 | # Set a filter condition for training cells based on columns in the CellArr cell metadata 135 | filter_condition = f"cellTypeOntologyID!='nan' and total_counts>1000 and n_genes_by_counts>500 and pct_counts_mt<20 and predicted_doublets==0 and cellTypeOntologyID!='CL:0009010'" 136 | 137 | datamodule = CellMultisetDataModule( 138 | dataset_path=tiledb_base_path, 139 | gene_order=gene_order, 140 | val_studies=val_studies, 141 | exclude_studies=exclude_studies, 142 | exclude_samples=exclude_samples, 143 | label_id_column="cellTypeOntologyID", 144 | study_column="datasetID", 145 | sample_column="sampleID", 146 | 
filter_condition=filter_condition, 147 | batch_size=batch_size, 148 | n_batches=n_batches, 149 | num_workers=args.num_workers, 150 | sparse=False, 151 | remove_singleton_classes=True, 152 | persistent_workers=True, 153 | ) 154 | print(f"Training data size: {datamodule.train_df.shape}") 155 | print(f"Validation data size: {datamodule.val_df.shape}") 156 | 157 | model = MetricLearning( 158 | datamodule.n_genes, 159 | latent_dim=latent_dim, 160 | hidden_dim=hidden_dim, 161 | dropout=args.dropout, 162 | input_dropout=args.input_dropout, 163 | margin=margin, 164 | negative_selection=negative_selection, 165 | sample_across_studies=(args.cross == 1), 166 | perturb_labels=(args.perturb == 1), 167 | perturb_labels_fraction=args.perturb_fraction, 168 | lr=lr, 169 | triplet_loss_weight=triplet_loss_weight, 170 | l1=args.l1, 171 | l2=args.l2, 172 | max_epochs=max_epochs, 173 | cosine_annealing_tmax=cosine_annealing_tmax, 174 | #track_triplets=result_folder, # uncomment this to track triplet compositions per step 175 | ) 176 | 177 | # Use tensorboard to log training. Modify this based on your preferred logger. 178 | from pytorch_lightning.loggers import TensorBoardLogger 179 | 180 | logger = TensorBoardLogger( 181 | log_folder, 182 | name=model_name, 183 | default_hp_metric=False, 184 | flush_secs=1, 185 | version=suffix, 186 | ) 187 | 188 | gpu_idx = args.g 189 | 190 | from pytorch_lightning.callbacks import LearningRateMonitor 191 | 192 | lr_monitor = LearningRateMonitor(logging_interval="step") 193 | 194 | params = { 195 | "max_epochs": max_epochs, 196 | "logger": True, 197 | "logger": logger, 198 | "accelerator": "gpu", 199 | "callbacks": [lr_monitor], 200 | "log_every_n_steps": 1, 201 | "limit_train_batches": n_batches, 202 | "limit_val_batches": 10, 203 | "limit_test_batches": 10, 204 | } 205 | 206 | trainer = pl.Trainer(**params) 207 | 208 | ckpt_path = os.path.join(log_folder, model_name, suffix, "checkpoints") 209 | if os.path.isdir(ckpt_path): # resume training if checkpoints exist 210 | ckpt_files = sorted( 211 | [x for x in os.listdir(ckpt_path) if x.endswith(".ckpt")], 212 | key=lambda x: int(x.replace(".ckpt", "").split("=")[-1]), 213 | ) 214 | trainer.fit( 215 | model, 216 | datamodule=datamodule, 217 | ckpt_path=os.path.join(ckpt_path, ckpt_files[-1]), 218 | ) 219 | else: 220 | trainer.fit(model, datamodule=datamodule) 221 | 222 | model.save_all(model_path=os.path.join(model_folder, model_name)) 223 | 224 | if result_folder is not None: 225 | test_results = trainer.test(model, datamodule=datamodule) 226 | if test_results: 227 | with open(os.path.join(result_folder, f"{model_name}.test.json"), "w+") as fh: 228 | fh.write(json.dumps(test_results[0])) 229 | print(model_name) 230 | 231 | 232 | def main(): 233 | parser = argparse.ArgumentParser(description="Train SCimilarity model") 234 | parser.add_argument("--tiledb", type=str, help="CellArr tiledb base path") 235 | parser.add_argument("-g", type=int, default=0, help="gpu index") 236 | parser.add_argument("-m", type=float, default=0.05, help="triplet loss margin") 237 | parser.add_argument("-w", type=float, default=0.001, help="triplet loss weight") 238 | parser.add_argument("-t", type=str, default="semihard", help="negative selection type: [semihard, random, hardest]") 239 | parser.add_argument("-b", type=int, default=1000, help="batch size, number of cells") 240 | parser.add_argument("-e", type=int, default=500, help="max epochs") 241 | parser.add_argument("-n", type=int, default=100, help="number of batches per epoch") 242 | 
parser.add_argument("-l", type=float, default=0.005, help="learning rate") 243 | parser.add_argument("--latent_dim", type=int, default=128, help="latent space dim") 244 | parser.add_argument("--hidden_dim", nargs="+", type=int, default=[1024, 1024, 1024], help="list of hidden layers and sizes") 245 | parser.add_argument("--gene_order", type=str, default="/home/kuot/scratch/scimilarity_gene_order.tsv", help="gene order tsv") 246 | parser.add_argument("--input_dropout", type=float, default=0.4, help="input layer dropout p") 247 | parser.add_argument("--dropout", type=float, default=0.5, help="hidden layer dropout p") 248 | parser.add_argument("--l1", type=float, default=1e-4, help="l1 regularization lambda") 249 | parser.add_argument("--l2", type=float, default=0.01, help="l2 regularization lambda") 250 | parser.add_argument("--cross", type=int, default=1, help="sample across studies, 0: off, 1: on") 251 | parser.add_argument("--perturb", type=int, default=0, help="perturb labels with parent cell type (if parent exists in training data), 0: off, 1: on") 252 | parser.add_argument("--perturb_fraction", type=float, default=0.5, help="fraction of labels to attempt to perturb") 253 | parser.add_argument("--suffix", type=str, default="version_0", help="model name suffix") 254 | parser.add_argument("--model_folder", type=str, help="where to save model") 255 | parser.add_argument("--result_folder", type=str, default=None, help="where to save results") 256 | parser.add_argument("--log_folder", type=str, default="lightning_logs", help="where to save lightning logs") 257 | parser.add_argument("--num_workers", type=int, default=8, help="number of workers") 258 | parser.add_argument("--cosine_annealing_tmax", type=int, default=0, help="T max for cosine LR annealing, use max epochs if 0") 259 | args = parser.parse_args() 260 | 261 | train(args) 262 | 263 | 264 | if __name__ == "__main__": 265 | main() 266 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # This file is used to configure your project. 2 | # Read more about the various options under: 3 | # https://setuptools.pypa.io/en/latest/userguide/declarative_config.html 4 | # https://setuptools.pypa.io/en/latest/references/keywords.html 5 | 6 | [metadata] 7 | name = scimilarity 8 | version = 0.4.0 9 | description = Single cell embedding into latent space that quantifies similarity between expression states. 10 | author = Graham Heimberg, Tony Kuo, Nathaniel Diamant, Omar Salem, Héctor Corrada Bravo, Jason A. 
Vander Heiden 11 | author_email = heimberg@gene.com 12 | keywords = single-cell embedding/retrieval 13 | long_description = file: README.rst 14 | long_description_content_type = text/x-rst 15 | url = https://github.com/genentech/scimilarity 16 | project_urls = 17 | Documentation = https://genentech.github.io/scimilarity 18 | Source = https://github.com/genentech/scimilarity 19 | Tracker = https://github.com/genentech/scimilarity/issues 20 | 21 | # Change if running only on Windows, Mac or Linux (comma-separated) 22 | platforms = any 23 | 24 | # Add here all kinds of additional classifiers as defined under 25 | # https://pypi.org/classifiers/ 26 | classifiers = 27 | Intended Audience :: Science/Research 28 | License :: OSI Approved :: Apache Software License 29 | Natural Language :: English 30 | Operating System :: MacOS :: MacOS X 31 | Operating System :: POSIX :: Linux 32 | Programming Language :: Python :: 3 33 | Programming Language :: Python :: 3.10 34 | 35 | [options] 36 | zip_safe = False 37 | packages = find_namespace: 38 | include_package_data = True 39 | package_dir = 40 | =src 41 | 42 | # Require a min/specific Python version (comma-separated conditions) 43 | python_requires = >=3.10 44 | 45 | # Add here dependencies of your project (line-separated), e.g. requests>=2.2,<3.0. 46 | # Version specifiers like >=2.2,<3.0 avoid problems due to API changes in 47 | # new major versions. This works if the required packages follow Semantic Versioning. 48 | # For more information, check out https://semver.org/. 49 | install_requires = 50 | captum>=0.5.0 51 | circlify>=0.14.0 52 | hnswlib>=0.8.0 53 | obonet>=1.0.0 54 | pyarrow>=15.0.0 55 | pytorch-lightning>=2.0.0 56 | scanpy>=1.9.2 57 | tiledb>=0.18.2 58 | tiledb-vector-search>=0.11.0 59 | torch>=1.10.1 60 | tqdm 61 | zarr>=2.6.1 62 | importlib-metadata; python_version<"3.8" 63 | 64 | [options.packages.find] 65 | where = src 66 | exclude = 67 | tests 68 | 69 | [options.extras_require] 70 | # Add here test requirements (semicolon/line-separated) 71 | testing = 72 | setuptools 73 | pytest 74 | pytest-cov 75 | 76 | [options.entry_points] 77 | # Add here console scripts like: 78 | # console_scripts = 79 | # script_name = scimilarity.module:function 80 | # For example: 81 | # console_scripts = 82 | # fibonacci = scimilarity.skeleton:run 83 | # And any other entry points, for example: 84 | # pyscaffold.cli = 85 | # awesome = pyscaffoldext.awesome.extension:AwesomeExtension 86 | 87 | [tool:pytest] 88 | # Specify command line options as you would do when invoking pytest directly. 89 | # e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml 90 | # in order to write a coverage file that can be read by Jenkins. 91 | # CAUTION: --cov flags may prohibit setting breakpoints while debugging. 92 | # Comment those flags to avoid this pytest issue. 
93 | addopts = 94 | --cov scimilarity --cov-report term-missing 95 | --verbose 96 | norecursedirs = 97 | dist 98 | build 99 | .tox 100 | testpaths = tests 101 | # Use pytest markers to select/deselect specific tests 102 | # markers = 103 | # slow: mark tests as slow (deselect with '-m "not slow"') 104 | # system: mark end-to-end system tests 105 | 106 | [devpi:upload] 107 | # Options for the devpi: PyPI server and packaging tool 108 | # VCS export must be deactivated since we are using setuptools-scm 109 | no_vcs = 1 110 | formats = bdist_wheel 111 | 112 | [flake8] 113 | # Some sane defaults for the code style checker flake8 114 | max_line_length = 88 115 | extend_ignore = E203, W503 116 | # ^ Black-compatible 117 | # E203 and W503 have edge cases handled by black 118 | exclude = 119 | .tox 120 | build 121 | dist 122 | .eggs 123 | docs/conf.py 124 | 125 | [pyscaffold] 126 | # PyScaffold's parameters when the project was created. 127 | # This will be used when updating. Do not change! 128 | version = 4.5 129 | package = scimilarity 130 | extensions = 131 | markdown 132 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Setup file for scimilarity. 3 | Use setup.cfg to configure your project. 4 | 5 | This file was generated with PyScaffold 4.5. 6 | PyScaffold helps you to put up the scaffold of your new Python project. 7 | Learn more under: https://pyscaffold.org/ 8 | """ 9 | from setuptools import setup 10 | 11 | if __name__ == "__main__": 12 | try: 13 | setup(use_scm_version={"version_scheme": "no-guess-dev"}) 14 | except: # noqa 15 | print( 16 | "\n\nAn error occurred while building the project, " 17 | "please ensure you have the most updated version of setuptools, " 18 | "setuptools_scm and wheel with:\n" 19 | " pip install -U setuptools setuptools_scm wheel\n\n" 20 | ) 21 | raise 22 | -------------------------------------------------------------------------------- /src/scimilarity/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | if sys.version_info[:2] >= (3, 8): 4 | # TODO: Import directly (no need for conditional) when `python_requires = >= 3.8` 5 | from importlib.metadata import PackageNotFoundError, version # pragma: no cover 6 | else: 7 | from importlib_metadata import PackageNotFoundError, version # pragma: no cover 8 | 9 | try: 10 | # Change here if project is renamed and does not equal the package name 11 | dist_name = __name__ 12 | __version__ = version(dist_name) 13 | except PackageNotFoundError: # pragma: no cover 14 | __version__ = "unknown" 15 | finally: 16 | del version, PackageNotFoundError 17 | 18 | from .cell_embedding import CellEmbedding 19 | from .cell_annotation import CellAnnotation 20 | from .cell_query import CellQuery 21 | from .interpreter import Interpreter 22 | from .utils import align_dataset, lognorm_counts 23 | -------------------------------------------------------------------------------- /src/scimilarity/anndata_data_models.py: -------------------------------------------------------------------------------- 1 | import anndata 2 | from collections import Counter 3 | import numpy as np 4 | import pytorch_lightning as pl 5 | import torch 6 | from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler 7 | from typing import Optional 8 | 9 | from .utils import align_dataset 10 | from .ontologies import import_cell_ontology, get_id_mapper 11 | 12 | 
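# Data modules for training directly from AnnData H5AD files. For an end-to-end example
# of a data module used with training_models.MetricLearning and a pytorch_lightning
# Trainer, see scripts/train.py (which uses tiledb_data_models.CellMultisetDataModule).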
13 | class scDataset(Dataset): 14 | """A class that represents a single cell dataset. 15 | 16 | Parameters 17 | ---------- 18 | X: numpy.ndarray 19 | Gene expression vectors for every cell. 20 | Y: numpy.ndarray 21 | Text labels for every cell. 22 | study: numpy.ndarray 23 | The study identifier for every cell. 24 | """ 25 | 26 | def __init__(self, X, Y, study=None): 27 | self.X = X 28 | self.Y = Y 29 | self.study = study 30 | 31 | def __len__(self): 32 | return len(self.Y) 33 | 34 | def __getitem__(self, idx): 35 | # data, label, study 36 | return self.X[idx].A, self.Y[idx], self.study[idx] 37 | 38 | 39 | class MetricLearningDataModule(pl.LightningDataModule): 40 | """A class to encapsulate the anndata needed to train the model. 41 | 42 | Parameters 43 | ---------- 44 | train_path: str 45 | Path to the training h5ad file. 46 | val_path: str, optional, default: None 47 | Path to the validataion h5ad file. 48 | obs_field: str, default: "celltype_name" 49 | The obs key name containing celltype labels. 50 | batch_size: int, default: 1000 51 | Batch size. 52 | num_workers: int, default: 1 53 | The number of worker threads for dataloaders. 54 | gene_order_file: str, optional 55 | Use a given gene order as described in the specified file rather than using the 56 | training dataset's gene order. One gene symbol per line. 57 | 58 | Examples 59 | -------- 60 | >>> datamodule = MetricLearningDataModule( 61 | batch_size=1000, 62 | num_workers=1, 63 | obs_field="celltype_name", 64 | train_path="train.h5ad", 65 | ) 66 | """ 67 | 68 | def __init__( 69 | self, 70 | train_path: str, 71 | val_path: Optional[str] = None, 72 | obs_field: str = "celltype_name", 73 | batch_size: int = 500, 74 | num_workers: int = 1, 75 | gene_order_file: Optional[str] = None, 76 | ): 77 | super().__init__() 78 | self.train_path = train_path 79 | self.val_path = val_path 80 | self.obs_field = obs_field 81 | self.batch_size = batch_size 82 | self.num_workers = num_workers 83 | 84 | # read in ontology terms 85 | self.name2id = { 86 | value: key for key, value in get_id_mapper(import_cell_ontology()).items() 87 | } 88 | 89 | # read in training dataset 90 | train_data = anndata.read_h5ad(self.train_path) 91 | 92 | # keep cells whose celltype labels have valid ontology id 93 | train_data = self.subset_valid_terms(train_data) 94 | 95 | if ( 96 | gene_order_file is not None 97 | ): # gene space needs be aligned to the given gene order 98 | with open(gene_order_file, "r") as fh: 99 | self.gene_order = [line.strip() for line in fh] 100 | train_data = align_dataset(train_data, self.gene_order) 101 | else: # training dataset gene space is the gene order 102 | self.gene_order = train_data.var.index.tolist() 103 | 104 | self.n_genes = train_data.shape[1] # used when creating training model 105 | 106 | # map training labels to ints 107 | self.class_names = set(train_data.obs[obs_field]) 108 | self.label2int = {label: i for i, label in enumerate(self.class_names)} 109 | self.int2label = { 110 | value: key for key, value in self.label2int.items() 111 | } # used during training 112 | 113 | train_study = train_data.obs["study"] # studies 114 | self.train_Y = train_data.obs[obs_field].values # text labels 115 | self.train_dataset = scDataset(train_data.X, self.train_Y, study=train_study) 116 | 117 | self.val_dataset = None 118 | if val_path is not None: 119 | val_data = anndata.read_h5ad(self.val_path) 120 | val_data = align_dataset( 121 | self.subset_valid_terms(val_data), self.gene_order 122 | ) # gene space needs to match training set 123 | 
val_data = val_data[ 124 | val_data.obs[self.obs_field].isin(self.class_names) 125 | ] # labels need to be subsetted to training labels 126 | 127 | if val_data.shape[0] == 0: 128 | raise RuntimeError("No celltype labels have a valid ontology id.") 129 | val_study = val_data.obs["study"] # studies 130 | val_Y = val_data.obs[obs_field].values # text labels 131 | self.val_dataset = scDataset(val_data.X, val_Y, study=val_study) 132 | 133 | def subset_valid_terms(self, data: anndata.AnnData) -> anndata.AnnData: 134 | """Keep cells whose celltype labels have valid ontology id. 135 | 136 | Parameters 137 | ---------- 138 | data: anndata.AnnData 139 | Annotated data to subset by valid ontology id. 140 | 141 | Returns 142 | ------- 143 | anndata.AnnData 144 | An object containing the data whose celltype labels have 145 | valid ontology id. 146 | """ 147 | 148 | valid_terms_idx = data.obs[self.obs_field].isin(self.name2id.keys()) 149 | if valid_terms_idx.any(): 150 | return data[valid_terms_idx] 151 | raise RuntimeError("No celltype labels have a valid ontology id.") 152 | 153 | def get_sampler_weights( 154 | self, labels: list, studies: Optional[list] = None 155 | ) -> WeightedRandomSampler: 156 | """Get weighted random sampler. 157 | 158 | Parameters 159 | ---------- 160 | dataset: scDataset 161 | Single cell dataset. 162 | 163 | Returns 164 | ------- 165 | WeightedRandomSampler 166 | A WeightedRandomSampler object. 167 | """ 168 | 169 | if studies is None: 170 | class_sample_count = Counter(labels) 171 | sample_weights = torch.Tensor([1.0 / class_sample_count[t] for t in labels]) 172 | else: 173 | class_sample_count = Counter(labels) 174 | study_sample_count = Counter(studies) 175 | class_sample_count = { 176 | x: np.log1p(class_sample_count[x] / 1e4) for x in class_sample_count 177 | } 178 | study_sample_count = { 179 | x: np.log1p(study_sample_count[x] / 1e5) for x in study_sample_count 180 | } 181 | sample_weights = torch.Tensor( 182 | [ 183 | 1.0 / class_sample_count[labels[i]] / study_sample_count[studies[i]] 184 | for i in range(len(labels)) 185 | ] 186 | ) 187 | return WeightedRandomSampler(sample_weights, len(sample_weights)) 188 | 189 | def collate(self, batch): 190 | """Collate tensors. 191 | 192 | Parameters 193 | ---------- 194 | batch: 195 | Batch to collate. 196 | 197 | Returns 198 | ------- 199 | tuple 200 | A Tuple[torch.Tensor, torch.Tensor, list] containing information 201 | on the collated tensors. 202 | """ 203 | 204 | profiles, labels, studies = tuple( 205 | map(list, zip(*batch)) 206 | ) # tuple([list(t) for t in zip(*batch)]) 207 | return ( 208 | torch.squeeze(torch.Tensor(np.vstack(profiles))), 209 | torch.Tensor([self.label2int[l] for l in labels]), 210 | studies, 211 | ) 212 | 213 | def train_dataloader(self) -> DataLoader: 214 | """Load the training dataset. 215 | 216 | Returns 217 | ------- 218 | DataLoader 219 | A DataLoader object containing the training dataset. 220 | """ 221 | 222 | return DataLoader( 223 | self.train_dataset, 224 | batch_size=self.batch_size, 225 | num_workers=self.num_workers, 226 | pin_memory=True, 227 | drop_last=True, 228 | sampler=self.get_sampler_weights(self.train_dataset), 229 | collate_fn=self.collate, 230 | ) 231 | 232 | def val_dataloader(self) -> DataLoader: 233 | """Load the validation dataset. 234 | 235 | Returns 236 | ------- 237 | DataLoader 238 | A DataLoader object containing the validation dataset. 
239 |         """
240 | 
241 |         if self.val_dataset is None:
242 |             return None
243 |         return DataLoader(
244 |             self.val_dataset,
245 |             batch_size=self.batch_size,
246 |             num_workers=self.num_workers,
247 |             pin_memory=True,
248 |             drop_last=True,
249 |             sampler=self.get_sampler_weights(self.val_dataset),
250 |             collate_fn=self.collate,
251 |         )
252 | 
253 |     def test_dataloader(self) -> DataLoader:
254 |         """Load the test dataset.
255 | 
256 |         Returns
257 |         -------
258 |         DataLoader
259 |             A DataLoader object containing the test dataset.
260 |         """
261 | 
262 |         return self.val_dataloader()
263 | 
--------------------------------------------------------------------------------
/src/scimilarity/cell_annotation.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union, List, Set, Tuple
2 | 
3 | from .cell_search_knn import CellSearchKNN
4 | 
5 | 
6 | class CellAnnotation(CellSearchKNN):
7 |     """A class that annotates cells using a cell embedding followed by kNN search.
8 | 
9 |     Parameters
10 |     ----------
11 |     model_path: str
12 |         Path to the directory containing model files.
13 |     use_gpu: bool, default: False
14 |         Use GPU instead of CPU.
15 |     filenames: dict, optional, default: None
16 |         Use a dictionary of custom filenames for files instead of the defaults.
17 | 
18 |     Examples
19 |     --------
20 |     >>> ca = CellAnnotation(model_path="/opt/data/model")
21 |     """
22 | 
23 |     def __init__(
24 |         self,
25 |         model_path: str,
26 |         use_gpu: bool = False,
27 |         filenames: Optional[dict] = None,
28 |     ):
29 |         import os
30 | 
31 |         super().__init__(
32 |             model_path=model_path,
33 |             use_gpu=use_gpu,
34 |             knn_type="hnswlib",
35 |         )
36 | 
37 |         self.annotation_path = os.path.join(model_path, "annotation")
38 |         os.makedirs(self.annotation_path, exist_ok=True)
39 | 
40 |         if filenames is None:
41 |             filenames = {}
42 | 
43 |         self.filenames["knn"] = os.path.join(
44 |             self.annotation_path, filenames.get("knn", "labelled_kNN.bin")
45 |         )
46 |         self.filenames["celltype_labels"] = os.path.join(
47 |             self.annotation_path,
48 |             filenames.get("celltype_labels", "reference_labels.tsv"),
49 |         )
50 | 
51 |         # get knn
52 |         self.load_knn_index(self.filenames["knn"])
53 | 
54 |         # get idx2label and idx2study
55 |         self.idx2label = {}
56 |         self.idx2study = {}
57 |         if self.knn is not None:
58 |             with open(self.filenames["celltype_labels"], "r") as fh:
59 |                 for i, line in enumerate(fh):
60 |                     token = line.strip().split("\t")
61 |                     self.idx2label[i] = token[0]
62 |                     if len(token) > 1:
63 |                         self.idx2study[i] = token[1]
64 | 
65 |         self.safelist = None
66 |         self.blocklist = None
67 | 
68 |     @property
69 |     def classes(self) -> set:
70 |         """Get the set of all viable prediction classes."""
71 | 
72 |         return set(self.label2int.keys())
73 | 
74 |     def reset_knn(self):
75 |         """Reset the knn such that nothing is marked deleted.
76 | 
77 |         Examples
78 |         --------
79 |         >>> ca.reset_knn()
80 |         """
81 | 
82 |         self.blocklist = None
83 |         self.safelist = None
84 | 
85 |         # hnswlib does not have a marked status, so we need to unmark all
86 |         for i in self.idx2label:
87 |             try:  # throws an exception if not already marked
88 |                 self.knn.unmark_deleted(i)
89 |             except:
90 |                 pass
91 | 
92 |     def blocklist_celltypes(self, labels: Union[List[str], Set[str]]):
93 |         """Blocklist celltypes.
94 | 
95 |         Parameters
96 |         ----------
97 |         labels: List[str], Set[str]
98 |             A list or set containing blocklist labels.
99 | 
100 |         Notes
101 |         -----
102 |         Blocking a celltype will persist for this instance of the class and subsequent predictions will have this blocklist.
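# A hedged sketch of overriding the default annotation filenames; the keys
# ("knn", "celltype_labels") mirror the lookups in __init__ above, while the
# file names themselves are hypothetical.
ca = CellAnnotation(
    model_path="/opt/data/model",
    filenames={"knn": "my_kNN.bin", "celltype_labels": "my_labels.tsv"},
)
# Each line of the labels TSV is parsed as "<celltype label>\t<study>", filling
# idx2label from column 1 and, when a second column is present, idx2study.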
103 | Blocklists and safelists are mutually exclusive, setting one will clear the other. 104 | 105 | Examples 106 | -------- 107 | >>> ca.blocklist_celltypes(["T cell"]) 108 | """ 109 | 110 | self.reset_knn() 111 | self.blocklist = set(labels) 112 | self.safelist = None 113 | 114 | for i, celltype_name in self.idx2label.items(): 115 | if celltype_name in self.blocklist: 116 | self.knn.mark_deleted(i) # mark blocklist 117 | 118 | def safelist_celltypes(self, labels: Union[List[str], Set[str]]): 119 | """Safelist celltypes. 120 | 121 | Parameters 122 | ---------- 123 | labels: List[str], Set[str] 124 | A list or set containing safelist labels. 125 | 126 | Notes 127 | ----- 128 | Safelisting a celltype will persist for this instance of the class and subsequent predictions will have this safelist. 129 | Blocklists and safelists are mutually exclusive, setting one will clear the other. 130 | 131 | Examples 132 | -------- 133 | >>> ca.safelist_celltypes(["CD4-positive, alpha-beta T cell"]) 134 | """ 135 | 136 | self.blocklist = None 137 | self.safelist = set(labels) 138 | 139 | for i in self.idx2label: # mark all 140 | try: # throws an exception if already marked 141 | self.knn.mark_deleted(i) 142 | except: 143 | pass 144 | for i, celltype_name in self.idx2label.items(): 145 | if celltype_name in self.safelist: 146 | self.knn.unmark_deleted(i) # unmark safelist 147 | 148 | def get_predictions_knn( 149 | self, 150 | embeddings: "numpy.ndarray", 151 | k: int = 50, 152 | ef: int = 100, 153 | weighting: bool = False, 154 | disable_progress: bool = False, 155 | ) -> Tuple["numpy.ndarray", "numpy.ndarray", "numpy.ndarray", "pandas.DataFrame"]: 156 | """Get predictions from knn search results. 157 | 158 | Parameters 159 | ---------- 160 | embeddings: numpy.ndarray 161 | Embeddings as a numpy array. 162 | k: int, default: 50 163 | The number of nearest neighbors. 164 | ef: int, default: 100 165 | The size of the dynamic list for the nearest neighbors. 166 | See https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md 167 | weighting: bool, default: False 168 | Use distance weighting when getting the consensus prediction. 169 | disable_progress: bool, default: False 170 | Disable tqdm progress bar 171 | 172 | Returns 173 | ------- 174 | predictions: pandas.Series 175 | A pandas series containing celltype label predictions. 176 | nn_idxs: numpy.ndarray 177 | A 2D numpy array of nearest neighbor indices [num_cells x k]. 178 | nn_dists: numpy.ndarray 179 | A 2D numpy array of nearest neighbor distances [num_cells x k]. 180 | stats: pandas.DataFrame 181 | Prediction statistics dataframe with columns: 182 | "hits" is a json string with the count for every class in k cells. 183 | "min_dist" is the minimum distance. 184 | "max_dist" is the maximum distance 185 | "vs2nd" is sum(best) / sum(best + 2nd best). 186 | "vsAll" is sum(best) / sum(all hits). 187 | "hits_weighted" is a json string with the weighted count for every class in k cells. 188 | "vs2nd_weighted" is weighted sum(best) / sum(best + 2nd best). 189 | "vsAll_weighted" is weighted sum(best) / sum(all hits). 
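# A toy illustration of the "vs2nd" and "vsAll" statistics described above,
# assuming the k=50 neighbors of one cell produced these per-class hit counts
# (values are made up for the example).
hits = {"T cell": 35, "NK cell": 10, "B cell": 5}
counts = sorted(hits.values(), reverse=True)
vs2nd = counts[0] / (counts[0] + counts[1])  # 35 / 45 ~ 0.78
vsAll = counts[0] / sum(counts)              # 35 / 50 = 0.70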
190 | 191 | Examples 192 | -------- 193 | >>> ca = CellAnnotation(model_path="/opt/data/model") 194 | >>> embeddings = ca.get_embeddings(align_dataset(data, ca.gene_order).X) 195 | >>> predictions, nn_idxs, nn_dists, stats = ca.get_predictions_knn(embeddings) 196 | """ 197 | 198 | from collections import defaultdict 199 | import json 200 | import operator 201 | import numpy as np 202 | import pandas as pd 203 | import time 204 | from tqdm import tqdm 205 | 206 | start_time = time.time() 207 | nn_idxs, nn_dists = self.get_nearest_neighbors( 208 | embeddings=embeddings, k=k, ef=ef 209 | ) 210 | end_time = time.time() 211 | if not disable_progress: 212 | print( 213 | f"Get nearest neighbors finished in: {float(end_time - start_time) / 60} min" 214 | ) 215 | stats = { 216 | "hits": [], 217 | "hits_weighted": [], 218 | "min_dist": [], 219 | "max_dist": [], 220 | "vs2nd": [], 221 | "vsAll": [], 222 | "vs2nd_weighted": [], 223 | "vsAll_weighted": [], 224 | } 225 | if k == 1: 226 | predictions = pd.Series(nn_idxs.flatten()).map(self.idx2label) 227 | else: 228 | predictions = [] 229 | for nns, d_nns in tqdm( 230 | zip(nn_idxs, nn_dists), total=nn_idxs.shape[0], disable=disable_progress 231 | ): 232 | # count celltype in nearest neighbors (optionally with distance weights) 233 | celltype = defaultdict(float) 234 | celltype_weighted = defaultdict(float) 235 | for neighbor, dist in zip(nns, d_nns): 236 | celltype[self.idx2label[neighbor]] += 1.0 237 | celltype_weighted[self.idx2label[neighbor]] += 1.0 / float( 238 | max(dist, 1e-6) 239 | ) 240 | # predict based on consensus max occurrence 241 | if weighting: 242 | predictions.append( 243 | max(celltype_weighted.items(), key=operator.itemgetter(1))[0] 244 | ) 245 | else: 246 | predictions.append( 247 | max(celltype.items(), key=operator.itemgetter(1))[0] 248 | ) 249 | # compute prediction stats 250 | stats["hits"].append(json.dumps(celltype)) 251 | stats["hits_weighted"].append(json.dumps(celltype_weighted)) 252 | stats["min_dist"].append(np.min(d_nns)) 253 | stats["max_dist"].append(np.max(d_nns)) 254 | 255 | hits = sorted(celltype.values(), reverse=True) 256 | hits_weighted = [ 257 | max(x, 1e-6) 258 | for x in sorted(celltype_weighted.values(), reverse=True) 259 | ] 260 | if len(hits) > 1: 261 | stats["vs2nd"].append(hits[0] / (hits[0] + hits[1])) 262 | stats["vsAll"].append(hits[0] / sum(hits)) 263 | stats["vs2nd_weighted"].append( 264 | hits_weighted[0] / (hits_weighted[0] + hits_weighted[1]) 265 | ) 266 | stats["vsAll_weighted"].append( 267 | hits_weighted[0] / sum(hits_weighted) 268 | ) 269 | else: 270 | stats["vs2nd"].append(1.0) 271 | stats["vsAll"].append(1.0) 272 | stats["vs2nd_weighted"].append(1.0) 273 | stats["vsAll_weighted"].append(1.0) 274 | return ( 275 | pd.Series(predictions), 276 | nn_idxs, 277 | nn_dists, 278 | pd.DataFrame(stats), 279 | ) 280 | 281 | def annotate_dataset( 282 | self, 283 | data: "anndata.AnnData", 284 | ) -> "anndata.AnnData": 285 | """Annotate dataset with celltype predictions. 286 | 287 | Parameters 288 | ---------- 289 | data: anndata.AnnData 290 | The annotated data matrix with rows for cells and columns for genes. 291 | This function assumes the data has been log normalized (i.e. via lognorm_counts) accordingly. 292 | 293 | Returns 294 | ------- 295 | anndata.AnnData 296 | A data object where: 297 | - celltype predictions are in obs["celltype_hint"] 298 | - embeddings are in obs["X_scimilarity"]. 
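# A possible post-processing step, not part of the package: mask predictions
# whose consensus statistic is weak, using the stats columns that
# annotate_dataset writes (see below). The 0.5 cutoff is an arbitrary example;
# `ca` and `data` are assumed to be set up as in the docstring examples.
data = ca.annotate_dataset(data)
low_conf = data.obs["celltype_hint_stat"] < 0.5  # vsAll consensus below 50%
data.obs.loc[low_conf, "celltype_hint"] = "unsure"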
299 | 300 | Examples 301 | -------- 302 | >>> ca = CellAnnotation(model_path="/opt/data/model") 303 | >>> data = annotate_dataset(data) 304 | """ 305 | 306 | from .utils import align_dataset 307 | 308 | embeddings = self.get_embeddings(align_dataset(data, self.gene_order).X) 309 | data.obsm["X_scimilarity"] = embeddings 310 | 311 | predictions, _, _, nn_stats = self.get_predictions_knn(embeddings) 312 | data.obs["celltype_hint"] = predictions.values 313 | data.obs["min_dist"] = nn_stats["min_dist"].values 314 | data.obs["celltype_hits"] = nn_stats["hits"].values 315 | data.obs["celltype_hits_weighted"] = nn_stats["hits_weighted"].values 316 | data.obs["celltype_hint_stat"] = nn_stats["vsAll"].values 317 | data.obs["celltype_hint_weighted_stat"] = nn_stats["vsAll_weighted"].values 318 | 319 | return data 320 | -------------------------------------------------------------------------------- /src/scimilarity/cell_embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | 4 | class CellEmbedding: 5 | """A class that embeds cell gene expression data using an ML model. 6 | 7 | Parameters 8 | ---------- 9 | model_path: str 10 | Path to the directory containing model files. 11 | use_gpu: bool, default: False 12 | Use GPU instead of CPU. 13 | 14 | Examples 15 | -------- 16 | >>> ce = CellEmbedding(model_path="/opt/data/model") 17 | """ 18 | 19 | def __init__( 20 | self, 21 | model_path: str, 22 | use_gpu: bool = False, 23 | ): 24 | import json 25 | import os 26 | import pandas as pd 27 | from .nn_models import Encoder 28 | 29 | self.model_path = model_path 30 | self.use_gpu = use_gpu 31 | 32 | self.filenames = { 33 | "model": os.path.join(self.model_path, "encoder.ckpt"), 34 | "gene_order": os.path.join(self.model_path, "gene_order.tsv"), 35 | } 36 | 37 | # get gene order 38 | with open(self.filenames["gene_order"], "r") as fh: 39 | self.gene_order = [line.strip() for line in fh] 40 | 41 | # get neural network model and infer network size 42 | with open(os.path.join(self.model_path, "layer_sizes.json"), "r") as fh: 43 | layer_sizes = json.load(fh) 44 | # keys: network.1.weight, network.2.weight, ..., network.n.weight 45 | layers = [ 46 | (key, layer_sizes[key]) 47 | for key in sorted(list(layer_sizes.keys())) 48 | if "weight" in key and len(layer_sizes[key]) > 1 49 | ] 50 | parameters = { 51 | "latent_dim": layers[-1][1][0], # last 52 | "hidden_dim": [layer[1][0] for layer in layers][0:-1], # all but last 53 | } 54 | 55 | self.n_genes = len(self.gene_order) 56 | self.latent_dim = parameters["latent_dim"] 57 | self.model = Encoder( 58 | n_genes=self.n_genes, 59 | latent_dim=parameters["latent_dim"], 60 | hidden_dim=parameters["hidden_dim"], 61 | ) 62 | if self.use_gpu is True: 63 | self.model.cuda() 64 | self.model.load_state(self.filenames["model"]) 65 | self.model.eval() 66 | 67 | self.int2label = pd.read_csv( 68 | os.path.join(self.model_path, "label_ints.csv"), index_col=0 69 | )["0"].to_dict() 70 | self.label2int = {value: key for key, value in self.int2label.items()} 71 | 72 | def get_embeddings( 73 | self, 74 | X: Union["scipy.sparse.csr_matrix", "scipy.sparse.csc_matrix", "numpy.ndarray"], 75 | num_cells: int = -1, 76 | buffer_size: int = 10000, 77 | ) -> "numpy.ndarray": 78 | """Calculate embeddings for lognormed gene expression matrix. 
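# Illustration of how __init__ above infers the network sizes from
# layer_sizes.json; the JSON content here is a hypothetical example. One-dim
# entries (biases, BatchNorm parameters) are filtered out by the
# len(...) > 1 check, and the sorted weight shapes give the layer widths.
layer_sizes = {
    "network.1.weight": [1024, 28231],  # gene space -> first hidden layer
    "network.2.weight": [1024, 1024],   # second hidden layer
    "network.3.weight": [128, 1024],    # final projection to the latent space
}
# After filtering and sorting: hidden_dim == [1024, 1024], latent_dim == 128.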
79 | 80 | Parameters 81 | ---------- 82 | X: scipy.sparse.csr_matrix, scipy.sparse.csc_matrix, numpy.ndarray 83 | Gene space aligned and log normalized (tp10k) gene expression matrix. 84 | num_cells: int, default: -1 85 | The number of cells to embed, starting from index 0. 86 | A value of -1 will embed all cells. 87 | buffer_size: int, default: 10000 88 | The number of cells to embed in one batch. 89 | 90 | Returns 91 | ------- 92 | numpy.ndarray 93 | A 2D numpy array of embeddings [num_cells x latent_space_dimensions]. 94 | 95 | Examples 96 | -------- 97 | >>> from scimilarity.utils import align_dataset, lognorm_counts 98 | >>> ce = CellEmbedding(model_path="/opt/data/model") 99 | >>> data = align_dataset(data, ce.gene_order) 100 | >>> data = lognorm_counts(data) 101 | >>> embeddings = ce.get_embeddings(data.X) 102 | """ 103 | 104 | import numpy as np 105 | from scipy.sparse import csr_matrix, csc_matrix 106 | import torch 107 | import zarr 108 | 109 | if num_cells == -1: 110 | num_cells = X.shape[0] 111 | 112 | if ( 113 | (isinstance(X, csr_matrix) or isinstance(X, csc_matrix)) 114 | and ( 115 | isinstance(X.data, zarr.core.Array) 116 | or isinstance(X.indices, zarr.core.Array) 117 | or isinstance(X.indptr, zarr.core.Array) 118 | ) 119 | and num_cells <= buffer_size 120 | ): 121 | X.data = X.data[...] 122 | X.indices = X.indices[...] 123 | X.indptr = X.indptr[...] 124 | 125 | embedding_parts = [] 126 | with torch.inference_mode(): # disable gradients, not needed for inference 127 | for i in range(0, num_cells, buffer_size): 128 | profiles = None 129 | if isinstance(X, np.ndarray): 130 | profiles = torch.Tensor(X[i : i + buffer_size]) 131 | elif isinstance(X, torch.Tensor): 132 | profiles = X[i : i + buffer_size] 133 | elif isinstance(X, csr_matrix) or isinstance(X, csc_matrix): 134 | profiles = torch.Tensor(X[i : i + buffer_size].toarray()) 135 | 136 | if profiles is None: 137 | raise RuntimeError(f"Unknown data type {type(X)}.") 138 | 139 | if self.use_gpu is True: 140 | profiles = profiles.cuda() 141 | embedding_parts.append(self.model(profiles)) 142 | 143 | if not embedding_parts: 144 | raise RuntimeError(f"No valid cells detected.") 145 | 146 | if self.use_gpu: 147 | # detach, move from gpu into cpu, return as numpy array 148 | embedding = torch.vstack(embedding_parts).detach().cpu().numpy() 149 | else: 150 | # detach, return as numpy array 151 | embedding = torch.vstack(embedding_parts).detach().numpy() 152 | 153 | if np.isnan(embedding).any(): 154 | raise RuntimeError(f"NaN detected in embeddings.") 155 | 156 | return embedding 157 | -------------------------------------------------------------------------------- /src/scimilarity/cell_search_knn.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | from .cell_embedding import CellEmbedding 4 | 5 | 6 | class CellSearchKNN(CellEmbedding): 7 | """A class for searching similar cells using cell embeddings kNN. 8 | 9 | Parameters 10 | ---------- 11 | model_path: str 12 | Path to the directory containing model files. 13 | knn_type: str, default: "hnswlib" 14 | What type of knn to use, options are ["hnswlib", "tiledb_vector_search"] 15 | use_gpu: bool, default: False 16 | Use GPU instead of CPU. 
17 | 18 | Examples 19 | -------- 20 | >>> cs = CellSearchKNN(model_path="/opt/data/model") 21 | """ 22 | 23 | def __init__( 24 | self, 25 | model_path: str, 26 | knn_type: str, 27 | use_gpu: bool = False, 28 | ): 29 | super().__init__( 30 | model_path=model_path, 31 | use_gpu=use_gpu, 32 | ) 33 | 34 | self.knn = None 35 | self.knn_type = knn_type 36 | assert self.knn_type in ["hnswlib", "tiledb_vector_search"] 37 | self.safelist = None 38 | self.blocklist = None 39 | 40 | def load_knn_index(self, knn_file: str, memory_budget: int = 50000000): 41 | """Load the kNN index file 42 | 43 | Parameters 44 | ---------- 45 | knn_file: str 46 | Filename of the kNN index. 47 | memory_budget: int, default: 50000000 48 | Memory budget for tiledb vector search. 49 | """ 50 | 51 | import hnswlib 52 | import os 53 | import tiledb.vector_search as vs 54 | 55 | if os.path.isfile(knn_file) and self.knn_type == "hnswlib": 56 | self.knn = hnswlib.Index(space="cosine", dim=self.model.latent_dim) 57 | self.knn.load_index(knn_file) 58 | elif os.path.isdir(knn_file) and self.knn_type == "tiledb_vector_search": 59 | self.knn = vs.IVFFlatIndex(knn_file, memory_budget=memory_budget) 60 | else: 61 | print(f"Warning: No KNN index found at {knn_file}") 62 | self.knn = None 63 | 64 | def get_nearest_neighbors( 65 | self, embeddings: "numpy.ndarray", k: int = 50, ef: int = 100 66 | ) -> Tuple["numpy.ndarray", "numpy.ndarray"]: 67 | """Get nearest neighbors. 68 | Used by classes that inherit from CellEmbedding and have an instantiated kNN. 69 | 70 | Parameters 71 | ---------- 72 | embeddings: numpy.ndarray 73 | Embeddings as a 2D numpy array. 74 | k: int, default: 50 75 | The number of nearest neighbors. 76 | ef: int, default: 100 77 | The size of the dynamic list for the nearest neighbors for hnswlib. 78 | See https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md 79 | 80 | Returns 81 | ------- 82 | nn_idxs: numpy.ndarray 83 | A 2D numpy array of nearest neighbor indices [num_embeddings x k]. 84 | nn_dists: numpy.ndarray 85 | A 2D numpy array of nearest neighbor distances [num_embeddings x k]. 86 | 87 | Examples 88 | -------- 89 | >>> nn_idxs, nn_dists = get_nearest_neighbors(embeddings) 90 | """ 91 | 92 | if self.knn is None: 93 | raise RuntimeError("kNN is not initialized.") 94 | if self.knn_type == "hnswlib": 95 | self.knn.set_ef(ef) 96 | return self.knn.knn_query(embeddings, k=k) 97 | elif self.knn_type == "tiledb_vector_search": 98 | import math 99 | 100 | nn_dists, nn_idxs = self.knn.query( 101 | embeddings, k=k, nprobe=int(math.sqrt(self.knn.partitions)) 102 | ) 103 | return (nn_idxs, nn_dists) 104 | -------------------------------------------------------------------------------- /src/scimilarity/interpreter.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from typing import Optional, Union 3 | 4 | 5 | class SimpleDist(nn.Module): 6 | """Calculates the distance between representations 7 | 8 | Parameters 9 | ---------- 10 | encoder: torch.nn.Module 11 | The encoder pytorch object. 12 | """ 13 | 14 | def __init__(self, encoder: "torch.nn.Module"): 15 | super().__init__() 16 | self.encoder = encoder 17 | 18 | def forward( 19 | self, 20 | anchors: "torch.Tensor", 21 | negatives: "torch.Tensor", 22 | ): 23 | """Forward. 24 | 25 | Parameters 26 | ---------- 27 | anchors: torch.Tensor 28 | Tensor for anchor or positive cells. 29 | negatives: torch.Tensor 30 | Tensor for negative cells. 
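# A hedged sketch of building an hnswlib index that load_knn_index above can
# read; the embeddings, save path, and construction parameters are
# placeholders, not values prescribed by the package.
import hnswlib
import numpy as np

embeddings = np.random.rand(1000, 128).astype(np.float32)  # e.g. latent_dim=128
index = hnswlib.Index(space="cosine", dim=embeddings.shape[1])
index.init_index(max_elements=embeddings.shape[0], ef_construction=200, M=16)
index.add_items(embeddings, np.arange(embeddings.shape[0]))
index.save_index("labelled_kNN.bin")
# Item ids follow row order, which should line up with the rows of the
# celltype labels TSV so that idx2label resolves each neighbor correctly.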
31 | 32 | Returns 33 | ------- 34 | float 35 | Sum of squares distance for the encoded tensors. 36 | """ 37 | 38 | f_anc = self.encoder(anchors) 39 | f_neg = self.encoder(negatives) 40 | return ((f_neg - f_anc) ** 2).sum(dim=1) 41 | 42 | 43 | class Interpreter: 44 | """A class that interprets significant genes. 45 | 46 | Parameters 47 | ---------- 48 | encoder: torch.nn.Module 49 | The encoder pytorch object. 50 | gene_order: list 51 | The list of genes. 52 | 53 | Examples 54 | -------- 55 | >>> interpreter = Interpreter(CellEmbedding("/opt/data/model").model) 56 | """ 57 | 58 | def __init__( 59 | self, 60 | encoder: "torch.nn.Module", 61 | gene_order: list, 62 | ): 63 | from captum.attr import IntegratedGradients 64 | 65 | self.encoder = encoder 66 | self.dist_ig = IntegratedGradients(SimpleDist(self.encoder)) 67 | self.gene_order = gene_order 68 | 69 | def get_attributions( 70 | self, 71 | anchors: Union["torch.Tensor", "numpy.ndarray", "scipy.sparse.csr_matrix"], 72 | negatives: Union["torch.Tensor", "numpy.ndarray", "scipy.sparse.csr_matrix"], 73 | ) -> "numpy.ndarray": 74 | """Returns attributions, which can later be aggregated. 75 | High attributions for genes that are expressed more highly in the anchor 76 | and that affect the distance between anchors and negatives strongly. 77 | 78 | Parameters 79 | ---------- 80 | anchors: numpy.ndarray, scipy.sparse.csr_matrix, torch.Tensor 81 | Tensor for anchor or positive cells. 82 | negatives: numpy.ndarray, scipy.sparse.csr_matrix, torch.Tensor 83 | Tensor for negative cells. 84 | 85 | Returns 86 | ------- 87 | numpy.ndarray 88 | A 2D numpy array of attributions [num_cells x num_genes]. 89 | 90 | Examples 91 | -------- 92 | >>> attr = interpreter.get_attributions(anchors, negatives) 93 | """ 94 | 95 | import numpy as np 96 | from scipy.sparse import csr_matrix 97 | import torch 98 | 99 | assert anchors.shape == negatives.shape 100 | 101 | if isinstance(anchors, np.ndarray): 102 | anc = torch.Tensor(anchors) 103 | elif isinstance(anchors, csr_matrix): 104 | anc = torch.Tensor(anchors.todense()) 105 | else: 106 | anc = anchors 107 | 108 | if isinstance(negatives, np.ndarray): 109 | neg = torch.Tensor(negatives) 110 | elif isinstance(negatives, csr_matrix): 111 | neg = torch.Tensor(negatives.todense()) 112 | else: 113 | neg = negatives 114 | 115 | # Check if model is on gpu device 116 | if next(self.encoder.parameters()).is_cuda: 117 | anc = anc.cuda() 118 | neg = neg.cuda() 119 | 120 | # attribute l2_dist(anchors, negatives) 121 | attr = self.dist_ig.attribute( 122 | anc, 123 | baselines=neg, # integrate from negatives to anchors 124 | additional_forward_args=neg, 125 | ) 126 | attr *= anc > neg 127 | attr = +attr.abs() # signs unreliable, so use absolute value of attributions 128 | 129 | if next(self.encoder.parameters()).is_cuda: 130 | return attr.detach().cpu().numpy() 131 | return attr.detach().numpy() 132 | 133 | def get_ranked_genes(self, attrs: "numpy.ndarray") -> "pandas.DataFrame": 134 | """Get the ranked gene list based on highest attributions. 135 | 136 | Parameters 137 | ---------- 138 | attr: numpy.ndarray 139 | Attributions matrix. 
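# A short end-to-end sketch tying get_attributions to get_ranked_genes and
# plot_ranked_genes below. The anchor/negative matrices are random
# placeholders; in practice they would be gene-space-aligned expression
# profiles for two cell populations to contrast.
import numpy as np

ce = CellEmbedding(model_path="/opt/data/model")
interpreter = Interpreter(ce.model, ce.gene_order)
anchors = np.random.rand(16, len(ce.gene_order)).astype(np.float32)
negatives = np.random.rand(16, len(ce.gene_order)).astype(np.float32)
attrs = interpreter.get_attributions(anchors, negatives)
attrs_df = interpreter.get_ranked_genes(attrs)
interpreter.plot_ranked_genes(attrs_df, n_plot=15)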
140 | 141 | Returns 142 | ------- 143 | pandas.DataFrame 144 | A pandas dataframe containing the ranked attributions for each gene 145 | 146 | Examples 147 | -------- 148 | >>> attrs_df = interpreter.get_ranked_genes(attrs) 149 | """ 150 | 151 | import numpy as np 152 | import pandas as pd 153 | 154 | mean_attrs = attrs.mean(axis=0) 155 | idx = mean_attrs.argsort()[::-1] 156 | df = { 157 | "gene": np.array(self.gene_order)[idx], 158 | "gene_idx": idx, 159 | "attribution": mean_attrs[idx], 160 | "attribution_std": attrs.std(axis=0)[idx], 161 | "cells": attrs.shape[0], 162 | } 163 | return pd.DataFrame(df) 164 | 165 | def plot_ranked_genes( 166 | self, 167 | attrs_df: "pandas.DataFrame", 168 | n_plot: int = 15, 169 | filename: Optional[str] = None, 170 | ): 171 | """Plot the ranked gene attributions. 172 | 173 | Parameters 174 | ---------- 175 | attrs_df: pandas.DataFrame 176 | Dataframe of ranked attributions. 177 | n_plot: int 178 | The number of top genes to plot. 179 | filename: str, optional 180 | The filename to save to plot as. 181 | 182 | Examples 183 | -------- 184 | >>> interpreter.plot_ranked_genes(attrs_df) 185 | """ 186 | 187 | import matplotlib.pyplot as plt 188 | import matplotlib as mpl 189 | import numpy as np 190 | import seaborn as sns 191 | 192 | mpl.rcParams["pdf.fonttype"] = 42 193 | 194 | df = attrs_df.head(n_plot) 195 | ci = 1.96 * df["attribution_std"] / np.sqrt(df["cells"]) 196 | 197 | fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5, 2), dpi=200) 198 | sns.barplot(ax=ax, data=df, x="gene", y="attribution", hue="gene", dodge=False) 199 | ax.set_yticks([]) 200 | plt.tick_params(axis="x", which="major", labelsize=8, labelrotation=90) 201 | 202 | ax.errorbar( 203 | df["gene"].values, 204 | df["attribution"].values, 205 | yerr=ci, 206 | ecolor="black", 207 | fmt="none", 208 | ) 209 | if ax.get_legend() is not None: 210 | ax.get_legend().remove() 211 | 212 | if filename: # save the figure 213 | fig.savefig(filename, bbox_inches="tight") 214 | -------------------------------------------------------------------------------- /src/scimilarity/nn_models.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains the neural network architectures. 3 | These are all you need for inference. 4 | """ 5 | 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | from typing import List 10 | 11 | 12 | class Encoder(nn.Module): 13 | """A class that encapsulates the encoder. 14 | 15 | Parameters 16 | ---------- 17 | n_genes: int 18 | The number of genes in the gene space, representing the input dimensions. 19 | latent_dim: int, default: 128 20 | The latent space dimensions 21 | hidden_dim: List[int], default: [1024, 1024] 22 | A list of hidden layer dimensions, describing the number of layers and their dimensions. 23 | Hidden layers are constructed in the order of the list for the encoder and in reverse 24 | for the decoder. 
25 | dropout: float, default: 0.5 26 | The dropout rate for hidden layers 27 | input_dropout: float, default: 0.4 28 | The dropout rate for the input layer 29 | """ 30 | 31 | def __init__( 32 | self, 33 | n_genes: int, 34 | latent_dim: int = 128, 35 | hidden_dim: List[int] = [1024, 1024], 36 | dropout: float = 0.5, 37 | input_dropout: float = 0.4, 38 | ): 39 | super().__init__() 40 | self.latent_dim = latent_dim 41 | self.network = nn.ModuleList() 42 | for i in range(len(hidden_dim)): 43 | if i == 0: # input layer 44 | self.network.append( 45 | nn.Sequential( 46 | nn.Dropout(p=input_dropout), 47 | nn.Linear(n_genes, hidden_dim[i]), 48 | nn.BatchNorm1d(hidden_dim[i]), 49 | nn.PReLU(), 50 | ) 51 | ) 52 | else: # hidden layers 53 | self.network.append( 54 | nn.Sequential( 55 | nn.Dropout(p=dropout), 56 | nn.Linear(hidden_dim[i - 1], hidden_dim[i]), 57 | nn.BatchNorm1d(hidden_dim[i]), 58 | nn.PReLU(), 59 | ) 60 | ) 61 | # output layer 62 | self.network.append(nn.Linear(hidden_dim[-1], latent_dim)) 63 | 64 | def forward(self, x) -> torch.Tensor: 65 | """Forward. 66 | 67 | Parameters 68 | ---------- 69 | x: torch.Tensor 70 | Input tensor corresponding to input layer. 71 | 72 | Returns 73 | ------- 74 | torch.Tensor 75 | Output tensor corresponding to output layer. 76 | """ 77 | 78 | for i, layer in enumerate(self.network): 79 | x = layer(x) 80 | return F.normalize(x, p=2, dim=1) 81 | 82 | def save_state(self, filename: str): 83 | """Save model state. 84 | 85 | Parameters 86 | ---------- 87 | filename: str 88 | Filename to save the model state. 89 | """ 90 | 91 | torch.save({"state_dict": self.state_dict()}, filename) 92 | 93 | def load_state(self, filename: str, use_gpu: bool = False): 94 | """Load model state. 95 | 96 | Parameters 97 | ---------- 98 | filename: str 99 | Filename containing the model state. 100 | use_gpu: bool, default: False 101 | Boolean indicating whether or not to use GPUs. 102 | """ 103 | 104 | if not use_gpu: 105 | ckpt = torch.load( 106 | filename, map_location=torch.device("cpu"), weights_only=False 107 | ) 108 | else: 109 | ckpt = torch.load(filename, weights_only=False) 110 | self.load_state_dict(ckpt["state_dict"]) 111 | 112 | 113 | class Decoder(nn.Module): 114 | """A class that encapsulates the decoder. 115 | 116 | Parameters 117 | ---------- 118 | n_genes: int 119 | The number of genes in the gene space, representing the input dimensions. 120 | latent_dim: int, default: 128 121 | The latent space dimensions 122 | hidden_dim: List[int], default: [1024, 1024] 123 | A list of hidden layer dimensions, describing the number of layers and their dimensions. 124 | Hidden layers are constructed in the order of the list for the encoder and in reverse 125 | for the decoder. 
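# Quick sanity sketch for the Encoder above (the sizes are arbitrary): forward
# ends with F.normalize, so every embedding row should have unit L2 norm.
import torch

enc = Encoder(n_genes=28231, latent_dim=128, hidden_dim=[1024, 1024])
enc.eval()  # BatchNorm/Dropout in eval mode
with torch.inference_mode():
    z = enc(torch.rand(4, 28231))
print(z.shape)                      # torch.Size([4, 128])
print(torch.linalg.norm(z, dim=1))  # ~1.0 for every row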
126 | dropout: float, default: 0.5 127 | The dropout rate for hidden layers 128 | """ 129 | 130 | def __init__( 131 | self, 132 | n_genes: int, 133 | latent_dim: int = 128, 134 | hidden_dim: List[int] = [1024, 1024], 135 | dropout: float = 0.5, 136 | ): 137 | super().__init__() 138 | self.latent_dim = latent_dim 139 | self.network = nn.ModuleList() 140 | for i in range(len(hidden_dim)): 141 | if i == 0: # first hidden layer 142 | self.network.append( 143 | nn.Sequential( 144 | nn.Linear(latent_dim, hidden_dim[i]), 145 | nn.BatchNorm1d(hidden_dim[i]), 146 | nn.PReLU(), 147 | ) 148 | ) 149 | else: # other hidden layers 150 | self.network.append( 151 | nn.Sequential( 152 | nn.Dropout(p=dropout), 153 | nn.Linear(hidden_dim[i - 1], hidden_dim[i]), 154 | nn.BatchNorm1d(hidden_dim[i]), 155 | nn.PReLU(), 156 | ) 157 | ) 158 | # reconstruction layer 159 | self.network.append(nn.Linear(hidden_dim[-1], n_genes)) 160 | 161 | def forward(self, x) -> torch.Tensor: 162 | """Forward. 163 | 164 | Parameters 165 | ---------- 166 | x: torch.Tensor 167 | Input tensor corresponding to input layer. 168 | 169 | Returns 170 | ------- 171 | torch.Tensor 172 | Output tensor corresponding to output layer. 173 | """ 174 | for i, layer in enumerate(self.network): 175 | x = layer(x) 176 | return x 177 | 178 | def save_state(self, filename: str): 179 | """Save model state. 180 | 181 | Parameters 182 | ---------- 183 | filename: str 184 | Filename to save the model state. 185 | """ 186 | 187 | torch.save({"state_dict": self.state_dict()}, filename) 188 | 189 | def load_state(self, filename: str, use_gpu: bool = False): 190 | """Load model state. 191 | 192 | Parameters 193 | ---------- 194 | filename: str 195 | Filename containing the model state. 196 | use_gpu: bool, default: False 197 | Boolean indicating whether or not to use GPUs. 198 | """ 199 | 200 | if not use_gpu: 201 | ckpt = torch.load( 202 | filename, map_location=torch.device("cpu"), weights_only=False 203 | ) 204 | else: 205 | ckpt = torch.load(filename, weights_only=False) 206 | self.load_state_dict(ckpt["state_dict"]) 207 | -------------------------------------------------------------------------------- /src/scimilarity/ontologies.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import obonet 3 | import pandas as pd 4 | from typing import Union, Tuple, List 5 | 6 | 7 | def subset_nodes_to_set(nodes, restricted_set: Union[list, set]) -> nx.DiGraph: 8 | """Restrict nodes to a given set. 9 | 10 | Parameters 11 | ---------- 12 | nodes: networkx.DiGraph 13 | Node graph. 14 | restricted_set: list, set 15 | Restricted node list. 16 | 17 | Returns 18 | ------- 19 | networkx.DiGraph 20 | Node graph of restricted set. 21 | 22 | Examples 23 | -------- 24 | >>> subset_nodes_to_set(nodes, node_list) 25 | """ 26 | 27 | return {node for node in nodes if node in restricted_set} 28 | 29 | 30 | def import_cell_ontology( 31 | url="http://purl.obolibrary.org/obo/cl/cl-basic.obo", 32 | ) -> nx.DiGraph: 33 | """Import taxrank cell ontology. 34 | 35 | Parameters 36 | ---------- 37 | url: str, default: "http://purl.obolibrary.org/obo/cl/cl-basic.obo" 38 | The url of the ontology obo file. 39 | 40 | Returns 41 | ------- 42 | networkx.DiGraph 43 | Node graph of ontology. 
44 | 45 | Examples 46 | -------- 47 | >>> onto = import_cell_ontology() 48 | """ 49 | 50 | graph = obonet.read_obo(url).reverse() # flip for intuitiveness 51 | return nx.DiGraph(graph) # return as graph 52 | 53 | 54 | def import_uberon_ontology( 55 | url="http://purl.obolibrary.org/obo/uberon/basic.obo", 56 | ) -> nx.DiGraph: 57 | """Import uberon tissue ontology. 58 | 59 | Parameters 60 | ---------- 61 | url: str, default: "http://purl.obolibrary.org/obo/uberon/basic.obo" 62 | The url of the ontology obo file. 63 | 64 | Returns 65 | ------- 66 | networkx.DiGraph 67 | Node graph of ontology. 68 | 69 | Examples 70 | -------- 71 | >>> onto = import_uberon_ontology() 72 | """ 73 | 74 | graph = obonet.read_obo(url).reverse() # flip for intuitiveness 75 | return nx.DiGraph(graph) # return as graph 76 | 77 | 78 | def import_doid_ontology( 79 | url="http://purl.obolibrary.org/obo/doid.obo", 80 | ) -> nx.DiGraph: 81 | """Import doid disease ontology. 82 | 83 | Parameters 84 | ---------- 85 | url: str, default: "http://purl.obolibrary.org/obo/doid.obo" 86 | The url of the ontology obo file. 87 | 88 | Returns 89 | ------- 90 | networkx.DiGraph 91 | Node graph of ontology. 92 | 93 | Examples 94 | -------- 95 | >>> onto = import_doid_ontology() 96 | """ 97 | 98 | graph = obonet.read_obo(url).reverse() # flip for intuitiveness 99 | return nx.DiGraph(graph) # return as graph 100 | 101 | 102 | def import_mondo_ontology( 103 | url="http://purl.obolibrary.org/obo/mondo.obo", 104 | ) -> nx.DiGraph: 105 | """Import mondo disease ontology. 106 | 107 | Parameters 108 | ---------- 109 | url: str, default: "http://purl.obolibrary.org/obo/mondo.obo" 110 | The url of the ontology obo file. 111 | 112 | Returns 113 | ------- 114 | networkx.DiGraph 115 | Node graph of ontology. 116 | 117 | Examples 118 | -------- 119 | >>> onto = import_mondo_ontology() 120 | """ 121 | 122 | graph = obonet.read_obo(url).reverse() # flip for intuitiveness 123 | return nx.DiGraph(graph) # return as graph 124 | 125 | 126 | def get_id_mapper(graph) -> dict: 127 | """Mapping from term ID to name. 128 | 129 | Parameters 130 | ---------- 131 | graph: networkx.DiGraph 132 | Node graph. 133 | 134 | Returns 135 | ------- 136 | dict 137 | The id to name mapping dictionary. 138 | 139 | Examples 140 | -------- 141 | >>> id2name = get_id_mapper(onto) 142 | """ 143 | 144 | return {id_: data.get("name") for id_, data in graph.nodes(data=True)} 145 | 146 | 147 | def get_children(graph, node, node_list=None) -> nx.DiGraph: 148 | """Get children nodes of a given node. 149 | 150 | Parameters 151 | ---------- 152 | graph: networkx.DiGraph 153 | Node graph. 154 | node: str 155 | ID of given node. 156 | node_list: list, set, optional, default: None 157 | A restricted node list for filtering. 158 | 159 | Returns 160 | ------- 161 | networkx.DiGraph 162 | Node graph of children. 163 | 164 | Examples 165 | -------- 166 | >>> children = get_children(onto, id) 167 | """ 168 | 169 | children = {item[1] for item in graph.out_edges(node)} 170 | if node_list is None: 171 | return children 172 | return subset_nodes_to_set(children, node_list) 173 | 174 | 175 | def get_parents(graph, node, node_list=None) -> nx.DiGraph: 176 | """Get parent nodes of a given node. 177 | 178 | Parameters 179 | ---------- 180 | graph: networkx.DiGraph 181 | Node graph. 182 | node: str 183 | ID of given node. 184 | node_list: list, set, optional, default: None 185 | A restricted node list for filtering. 186 | 187 | Returns 188 | ------- 189 | networkx.DiGraph 190 | Node graph of parents. 
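# A small usage sketch for the ontology helpers in this module; the celltype
# name is only an example and must exist in the loaded ontology.
onto = import_cell_ontology()
id2name = get_id_mapper(onto)
name2id = {name: term_id for term_id, name in id2name.items()}
term_id = name2id["T cell"]
print(get_parents(onto, term_id))   # immediate parent term IDs
print(get_children(onto, term_id))  # immediate child term IDs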
191 | 192 | Examples 193 | -------- 194 | >>> parents = get_parents(onto, id) 195 | """ 196 | 197 | parents = {item[0] for item in graph.in_edges(node)} 198 | if node_list is None: 199 | return parents 200 | return subset_nodes_to_set(parents, node_list) 201 | 202 | 203 | def get_siblings(graph, node, node_list=None) -> nx.DiGraph: 204 | """Get sibling nodes of a given node. 205 | 206 | Parameters 207 | ---------- 208 | graph: networkx.DiGraph 209 | Node graph. 210 | node: str 211 | ID of given node. 212 | node_list: list, set, optional, default: None 213 | A restricted node list for filtering. 214 | 215 | Returns 216 | ------- 217 | networkx.DiGraph 218 | Node graph of siblings. 219 | 220 | Examples 221 | -------- 222 | >>> siblings = get_siblings(onto, id) 223 | """ 224 | 225 | parents = get_parents(graph, node) 226 | siblings = set.union( 227 | *[set(get_children(graph, parent)) for parent in parents] 228 | ) - set([node]) 229 | if node_list is None: 230 | return siblings 231 | return subset_nodes_to_set(siblings, node_list) 232 | 233 | 234 | def get_all_ancestors(graph, node, node_list=None, inclusive=False) -> nx.DiGraph: 235 | """Get all ancestor nodes of a given node. 236 | 237 | Parameters 238 | ---------- 239 | graph: networkx.DiGraph 240 | Node graph. 241 | node: str 242 | ID of given node. 243 | node_list: list, set, optional, default: None 244 | A restricted node list for filtering. 245 | inclusive: bool, default: False 246 | Whether to include the given node in the results. 247 | 248 | Returns 249 | ------- 250 | networkx.DiGraph 251 | Node graph of ancestors. 252 | 253 | Examples 254 | -------- 255 | >>> ancestors = get_all_ancestors(onto, id) 256 | """ 257 | 258 | ancestors = nx.ancestors(graph, node) 259 | if inclusive: 260 | ancestors = ancestors | {node} 261 | 262 | if node_list is None: 263 | return ancestors 264 | return subset_nodes_to_set(ancestors, node_list) 265 | 266 | 267 | def get_all_descendants(graph, nodes, node_list=None, inclusive=False) -> nx.DiGraph: 268 | """Get all descendant nodes of given node(s). 269 | 270 | Parameters 271 | ---------- 272 | graph: networkx.DiGraph 273 | Node graph. 274 | nodes: str, list 275 | ID of given node or a list of node IDs. 276 | node_list: list, set, optional, default: None 277 | A restricted node list for filtering. 278 | inclusive: bool, default: False 279 | Whether to include the given node in the results. 280 | 281 | Returns 282 | ------- 283 | networkx.DiGraph 284 | Node graph of descendants. 285 | 286 | Examples 287 | -------- 288 | >>> descendants = get_all_descendants(onto, id) 289 | """ 290 | 291 | if isinstance(nodes, str): # one term id 292 | descendants = nx.descendants(graph, nodes) 293 | else: # list of term ids 294 | descendants = set.union(*[nx.descendants(graph, node) for node in nodes]) 295 | 296 | if inclusive: 297 | descendants = descendants | {nodes} 298 | 299 | if node_list is None: 300 | return descendants 301 | return subset_nodes_to_set(descendants, node_list) 302 | 303 | 304 | def get_lowest_common_ancestor(graph, node1, node2) -> nx.DiGraph: 305 | """Get the lowest common ancestor of two nodes. 306 | 307 | Parameters 308 | ---------- 309 | graph: networkx.DiGraph 310 | Node graph. 311 | node1: str 312 | ID of node1. 313 | node2: str 314 | ID of node2. 315 | 316 | Returns 317 | ------- 318 | networkx.DiGraph 319 | Node graph of descendants. 
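# Sketch of the `inclusive` flag on the ancestor/descendant helpers above;
# term_id is an arbitrary node chosen only for illustration.
onto = import_cell_ontology()
term_id = next(iter(onto.nodes))
ancestors = get_all_ancestors(onto, term_id, inclusive=True)   # includes term_id
descendants = get_all_descendants(onto, term_id)               # excludes term_id
print(term_id in ancestors, term_id in descendants)            # True False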
320 | 321 | Examples 322 | -------- 323 | >>> common_ancestor = get_lowest_common_ancestor(onto, id1, id2) 324 | """ 325 | 326 | return nx.algorithms.lowest_common_ancestors.lowest_common_ancestor( 327 | graph, node1, node2 328 | ) 329 | 330 | 331 | def find_most_viable_parent(graph, node, node_list): 332 | """Get most viable parent of a given node among the node_list. 333 | 334 | Parameters 335 | ---------- 336 | graph: networkx.DiGraph 337 | Node graph. 338 | node: str 339 | ID of given node. 340 | node_list: list, set, optional, default: None 341 | A restricted node list for filtering. 342 | 343 | Returns 344 | ------- 345 | networkx.DiGraph 346 | Node graph of parents. 347 | 348 | Examples 349 | -------- 350 | >>> coarse_grained = find_most_viable_parent(onto, id, celltype_list) 351 | """ 352 | 353 | parents = get_parents(graph, node, node_list=node_list) 354 | if len(parents) == 0: 355 | coarse_grained = None 356 | all_parents = list(get_parents(graph, node)) 357 | if len(all_parents) == 1: 358 | grandparents = get_parents(graph, all_parents[0], node_list=node_list) 359 | if len(grandparents) == 1: 360 | (coarse_grained,) = grandparents 361 | elif len(parents) == 1: 362 | (coarse_grained,) = parents 363 | else: 364 | for parent in list(parents): 365 | coarse_grained = None 366 | if get_all_ancestors(graph, parent, node_list=pd.Index(parents)): 367 | coarse_grained = parent 368 | break 369 | return coarse_grained 370 | 371 | 372 | def ontology_similarity(graph, node1, node2, restricted_set=None) -> int: 373 | """Get the ontology similarity of two terms based on the number of common ancestors. 374 | 375 | Parameters 376 | ---------- 377 | graph: networkx.DiGraph 378 | Node graph. 379 | node1: str 380 | ID of node1. 381 | node2: str 382 | ID of node2. 383 | restricted_set: set 384 | Set of restricted nodes to remove from their common ancestors. 385 | 386 | Returns 387 | ------- 388 | int 389 | Number of common ancestors. 390 | 391 | Examples 392 | -------- 393 | >>> onto_sim = ontology_similarity(onto, id1, id2) 394 | """ 395 | 396 | common_ancestors = get_all_ancestors(graph, node1).intersection( 397 | get_all_ancestors(graph, node2) 398 | ) 399 | if restricted_set is not None: 400 | common_ancestors -= restricted_set 401 | return len(common_ancestors) 402 | 403 | 404 | def all_pair_similarities(graph, nodes, restricted_set=None) -> "pandas.DataFrame": 405 | """Get the ontology similarity of all pairs in a node list. 406 | 407 | Parameters 408 | ---------- 409 | graph: networkx.DiGraph 410 | Node graph. 411 | nodes: list, set 412 | List of nodes. 413 | restricted_set: set 414 | Set of restricted nodes to remove from their common ancestors. 415 | 416 | Returns 417 | ------- 418 | pandas.DataFrame 419 | A pandas dataframe showing similarity for all node pairs. 
420 | 421 | Examples 422 | -------- 423 | >>> onto_sim = all_pair_similarities(onto, id1, id2) 424 | """ 425 | 426 | import itertools 427 | import pandas as pd 428 | 429 | node_pairs = itertools.combinations(nodes, 2) 430 | similarity_df = pd.DataFrame(0, index=nodes, columns=nodes) 431 | for node1, node2 in node_pairs: 432 | s = ontology_similarity( 433 | graph, node1, node2, restricted_set=restricted_set 434 | ) # too slow, cause recomputes each ancestor 435 | similarity_df.at[node1, node2] = s 436 | return similarity_df + similarity_df.T 437 | 438 | 439 | def ontology_silhouette_width( 440 | embeddings: "numpy.ndarray", 441 | labels: List[str], 442 | onto: nx.DiGraph, 443 | name2id: dict, 444 | metric: str = "cosine", 445 | ) -> Tuple[float, "pandas.DataFrame"]: 446 | """Get the average silhouette width of celltypes, being aware of cell ontology such that 447 | ancestors are not considered inter-cluster and descendants are considered intra-cluster. 448 | 449 | Parameters 450 | ---------- 451 | embeddings: numpy.ndarray 452 | Cell embeddings. 453 | labels: List[str] 454 | Celltype names. 455 | onto: 456 | Cell ontology graph object. 457 | name2id: dict 458 | A mapping dictionary of celltype name to id 459 | metric: str, default: "cosine" 460 | The distance metric to use for scipy.spatial.distance.cdist(). 461 | 462 | Returns 463 | ------- 464 | asw: float 465 | The average silhouette width. 466 | asw_df: pandas.DataFrame 467 | A dataframe containing silhouette width as well as 468 | inter and intra cluster distances for all cell types. 469 | 470 | Examples 471 | -------- 472 | >>> asw, asw_df = ontology_silhouette_width( 473 | embeddings, labels, onto, name2id, metric="cosine" 474 | ) 475 | """ 476 | 477 | import numpy as np 478 | import pandas as pd 479 | from scipy.spatial.distance import cdist 480 | 481 | data = {"label": [], "intra": [], "inter": [], "sw": []} 482 | for i, name1 in enumerate(labels): 483 | term_id1 = name2id[name1] 484 | ancestors = get_all_ancestors(onto, term_id1) 485 | descendants = get_all_descendants(onto, term_id1) 486 | distances = cdist( 487 | embeddings[i].reshape(1, -1), embeddings, metric=metric 488 | ).flatten() 489 | 490 | a_i = [] 491 | b_i = {} 492 | for j, name2 in enumerate(labels): 493 | if i == j: 494 | continue 495 | 496 | term_id2 = name2id[name2] 497 | if term_id2 == term_id1 or term_id2 in descendants: # intra-cluster 498 | a_i.append(distances[j]) 499 | elif term_id2 != term_id1 and term_id2 not in ancestors: # inter-cluster 500 | if term_id2 not in b_i: 501 | b_i[term_id2] = [] 502 | b_i[term_id2].append(distances[j]) 503 | 504 | if len(a_i) <= 1 or not b_i: 505 | continue 506 | a_i = np.sum(a_i) / (len(a_i) - 1) 507 | b_i = np.min( 508 | [np.sum(values) / len(values) for values in b_i.values() if len(values) > 1] 509 | ) 510 | 511 | s_i = (b_i - a_i) / np.max([a_i, b_i]) 512 | 513 | data["label"].append(name1) 514 | data["intra"].append(a_i) 515 | data["inter"].append(b_i) 516 | data["sw"].append(s_i) 517 | return np.mean(data["sw"]), pd.DataFrame(data) 518 | -------------------------------------------------------------------------------- /src/scimilarity/triplet_selector.py: -------------------------------------------------------------------------------- 1 | from itertools import combinations 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.nn.functional as F 6 | from typing import List, Union, Optional 7 | 8 | from .ontologies import ( 9 | import_cell_ontology, 10 | get_id_mapper, 11 | get_all_ancestors, 12 | 
get_all_descendants, 13 | get_parents, 14 | ) 15 | 16 | 17 | class TripletSelector: 18 | """For each anchor-positive pair, mine negative samples to create a triplet. 19 | 20 | Parameters 21 | ---------- 22 | margin: float 23 | Triplet loss margin. 24 | negative_selection: str, default: "semihard" 25 | Method for negative selection: {"semihard", "hardest", "random"}. 26 | perturb_labels: bool, default: False 27 | Whether to perturb the ontology labels by coarse graining one level up. 28 | perturb_labels_fraction: float, default: 0.5 29 | The fraction of labels to perturb. 30 | 31 | Examples 32 | -------- 33 | >>> triplet_selector = TripletSelector(margin=0.05, negative_selection="semihard") 34 | """ 35 | 36 | def __init__( 37 | self, 38 | margin: float, 39 | negative_selection: str = "semihard", 40 | perturb_labels: bool = False, 41 | perturb_labels_fraction: float = 0.5, 42 | ): 43 | self.margin = margin 44 | self.negative_selection = negative_selection 45 | 46 | self.onto = import_cell_ontology() 47 | self.id2name = get_id_mapper(self.onto) 48 | self.name2id = {value: key for key, value in self.id2name.items()} 49 | 50 | self.perturb_labels = perturb_labels 51 | self.perturb_labels_fraction = perturb_labels_fraction 52 | 53 | def get_triplets_idx( 54 | self, 55 | embeddings: Union[np.ndarray, torch.Tensor], 56 | labels: Union[np.ndarray, torch.Tensor], 57 | int2label: dict, 58 | studies: Optional[Union[np.ndarray, torch.Tensor, list]] = None, 59 | ): 60 | """Get triplets as anchor, positive, and negative cell indices. 61 | 62 | Parameters 63 | ---------- 64 | embeddings: numpy.ndarray, torch.Tensor 65 | Cell embeddings. 66 | labels: numpy.ndarray, torch.Tensor 67 | Cell labels in integer form. 68 | int2label: dict 69 | Dictionary to map labels in integer form to string 70 | studies: numpy.ndarray, torch.Tensor, optional, default: None 71 | Studies metadata for each cell. 72 | 73 | Returns 74 | ------- 75 | triplets: Tuple[List, List, List] 76 | A tuple of lists containing anchor, positive, and negative cell indices. 77 | num_hard_triplets: int 78 | Number of hard triplets. 79 | num_viable_triplets: int 80 | Number of viable triplets. 81 | ) 82 | """ 83 | 84 | if isinstance(embeddings, torch.Tensor): 85 | distance_matrix = self.pdist(embeddings.detach().cpu().numpy()) 86 | else: 87 | distance_matrix = self.pdist(embeddings) 88 | 89 | if isinstance(labels, torch.Tensor): 90 | labels = labels.detach().cpu().numpy() 91 | 92 | if studies is not None and isinstance(studies, torch.Tensor): 93 | studies = studies.detach().cpu().numpy() 94 | 95 | labels_ids = np.array([self.name2id[int2label[label]] for label in labels]) 96 | labels_set = set(labels) 97 | 98 | if self.perturb_labels: 99 | labels_ids_set = set(labels_ids.tolist()) 100 | label2int = {value: key for key, value in int2label.items()} 101 | perturb_list = random.choices( 102 | np.arange(len(labels)), 103 | k=int(len(labels) * self.perturb_labels_fraction), 104 | ) 105 | for i in perturb_list: # cells chosen for perturbation of labels 106 | term_id = self.name2id[int2label[labels[i]]] 107 | parents = set() 108 | # Max ancestor levels: 1=parents, 2=grandparents, ... 
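                # Walk up the ontology one level per iteration (parents first,
                # then grandparents, ...), bounded by max_ancestors. As soon as
                # any ancestor at the current level is itself one of the labels
                # present in this batch, relabel cell i with a randomly chosen
                # such ancestor and stop.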
109 | max_ancestors = 1 110 | ancestor_level = 0 111 | while ancestor_level < max_ancestors: 112 | ancestor_level += 1 113 | if not parents: 114 | parents = get_parents(self.onto, term_id) 115 | else: 116 | current = set() 117 | for p in parents: 118 | current = current | get_parents(self.onto, p) 119 | parents = current 120 | found = any((parent in labels_ids_set for parent in parents)) 121 | if found is True: 122 | parents = list(parents) 123 | np.random.shuffle(parents) 124 | p = next( 125 | parent for parent in parents if parent in labels_ids_set 126 | ) 127 | labels[i] = label2int[self.id2name[p]] 128 | break # label perturbed, skip the rest of the ancestors 129 | 130 | triplets = [] 131 | num_hard_triplets = 0 132 | num_viable_triplets = 0 133 | for label in labels_set: 134 | term_id = self.name2id[int2label[label]] 135 | ancestors = get_all_ancestors(self.onto, term_id) 136 | descendants = get_all_descendants(self.onto, term_id) 137 | violating_terms = ancestors.union(descendants) 138 | 139 | label_mask = labels == label 140 | label_indices = np.where(label_mask)[0] 141 | if len(label_indices) < 2: 142 | continue 143 | 144 | # negatives are labels that are not the same as current and not in violating terms 145 | negative_indices = np.where( 146 | np.logical_not(label_mask | np.isin(labels_ids, list(violating_terms))) 147 | )[0] 148 | 149 | # compute all pairs of anchor-positives 150 | anchor_positives = list(combinations(label_indices, 2)) 151 | 152 | # enforce anchor and positive coming from different studies 153 | if studies is not None: 154 | anchor_positives = [ 155 | (anchor, positive) 156 | for anchor, positive in anchor_positives 157 | if studies[anchor] != studies[positive] 158 | ] 159 | 160 | for anchor_positive in anchor_positives: 161 | loss_values = ( 162 | distance_matrix[anchor_positive[0], anchor_positive[1]] 163 | - distance_matrix[[anchor_positive[0]], negative_indices] 164 | + self.margin 165 | ) 166 | num_hard_triplets += (loss_values > 0).sum() 167 | num_viable_triplets += loss_values.size 168 | 169 | # select one negative for anchor positive pair based on selection function 170 | if self.negative_selection == "semihard": 171 | hard_negative = self.semihard_negative(loss_values) 172 | elif self.negative_selection == "hardest": 173 | hard_negative = self.hardest_negative(loss_values) 174 | elif self.negative_selection == "random": 175 | hard_negative = self.random_negative(loss_values) 176 | else: 177 | hard_negative = None 178 | 179 | if hard_negative is not None: 180 | hard_negative = negative_indices[hard_negative] 181 | triplets.append( 182 | [anchor_positive[0], anchor_positive[1], hard_negative] 183 | ) 184 | 185 | if len(triplets) == 0: 186 | triplets.append([0, 0, 0]) 187 | 188 | anchor_idx, positive_idx, negative_idx = tuple( 189 | map(list, zip(*triplets)) 190 | ) # tuple([list(t) for t in zip(*triplets)]) 191 | return ( 192 | ( 193 | anchor_idx, 194 | positive_idx, 195 | negative_idx, 196 | ), 197 | num_hard_triplets, 198 | num_viable_triplets, 199 | ) 200 | 201 | def pdist(self, vectors: np.ndarray): 202 | """Get pair-wise distance between all cell embeddings. 203 | 204 | Parameters 205 | ---------- 206 | vectors: numpy.ndarray 207 | Cell embeddings. 208 | 209 | Returns 210 | ------- 211 | numpy.ndarray 212 | Distance matrix of cell embeddings. 
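# Toy numbers for the mining criterion in get_triplets_idx above: a triplet
# counts as hard/violating when d(anchor, positive) - d(anchor, negative)
# + margin > 0.
margin = 0.05
d_ap, d_an = 0.30, 0.32
loss_value = d_ap - d_an + margin  # 0.03 > 0, and < margin, so it is semihard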
213 | """ 214 | 215 | vectors_squared_sum = (vectors**2).sum(axis=1) 216 | distance_matrix = ( 217 | -2 * np.matmul(vectors, np.matrix.transpose(vectors)) 218 | + vectors_squared_sum.reshape(1, -1) 219 | + vectors_squared_sum.reshape(-1, 1) 220 | ) 221 | return distance_matrix 222 | 223 | def hardest_negative(self, loss_values): 224 | """Get hardest negative. 225 | 226 | Parameters 227 | ---------- 228 | loss_values: numpy.ndarray 229 | Triplet loss of all negatives for given anchor positive pair. 230 | 231 | Returns 232 | ------- 233 | int 234 | Index of selection. 235 | """ 236 | 237 | hard_negative = np.argmax(loss_values) 238 | return hard_negative if loss_values[hard_negative] > 0 else None 239 | 240 | def random_negative(self, loss_values): 241 | """Get random negative. 242 | 243 | Parameters 244 | ---------- 245 | loss_values: numpy.ndarray 246 | Triplet loss of all negatives for given anchor positive pair. 247 | 248 | Returns 249 | ------- 250 | int 251 | Index of selection. 252 | """ 253 | 254 | hard_negatives = np.where(loss_values > 0)[0] 255 | return np.random.choice(hard_negatives) if len(hard_negatives) > 0 else None 256 | 257 | def semihard_negative(self, loss_values): 258 | """Get a random semihard negative. 259 | 260 | Parameters 261 | ---------- 262 | loss_values: numpy.ndarray 263 | Triplet loss of all negatives for given anchor positive pair. 264 | 265 | Returns 266 | ------- 267 | int 268 | Index of selection. 269 | """ 270 | 271 | semihard_negatives = np.where( 272 | np.logical_and(loss_values < self.margin, loss_values > 0) 273 | )[0] 274 | return ( 275 | np.random.choice(semihard_negatives) 276 | if len(semihard_negatives) > 0 277 | else None 278 | ) 279 | 280 | def get_asw( 281 | self, 282 | embeddings: Union[np.ndarray, torch.Tensor], 283 | labels: List[str], 284 | int2label: dict, 285 | metric: str = "cosine", 286 | ) -> float: 287 | """Get the average silhouette width of celltypes, being aware of cell ontology such that 288 | ancestors are not considered inter-cluster and descendants are considered intra-cluster. 289 | 290 | Parameters 291 | ---------- 292 | embeddings: numpy.ndarray, torch.Tensor 293 | Cell embeddings. 294 | labels: List[str] 295 | Celltype names. 296 | int2label: dict 297 | Dictionary to map labels in integer form to string 298 | metric: str, default: "cosine" 299 | The distance metric to use for scipy.spatial.distance.cdist(). 300 | 301 | Returns 302 | ------- 303 | asw: float 304 | The average silhouette width. 
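# Toy comparison of the three negative-selection strategies above on made-up
# loss values (note: constructing TripletSelector loads the cell ontology).
import numpy as np

selector = TripletSelector(margin=0.05, negative_selection="semihard")
loss_values = np.array([-0.02, 0.01, 0.04, 0.20])
print(selector.hardest_negative(loss_values))   # 3 (largest positive loss)
print(selector.semihard_negative(loss_values))  # 1 or 2 (0 < loss < margin)
print(selector.random_negative(loss_values))    # any of 1, 2, 3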
305 | 306 | Examples 307 | -------- 308 | >>> asw = ontology_silhouette_width(embeddings, labels, metric="cosine") 309 | """ 310 | 311 | if isinstance(embeddings, torch.Tensor): 312 | distance_matrix = self.pdist(embeddings.detach().cpu().numpy()) 313 | else: 314 | distance_matrix = self.pdist(embeddings) 315 | 316 | if isinstance(labels, torch.Tensor): 317 | labels = labels.detach().cpu().numpy() 318 | 319 | sw = [] 320 | for i, label1 in enumerate(labels): 321 | term_id1 = self.name2id[int2label[label1]] 322 | ancestors = get_all_ancestors(self.onto, term_id1) 323 | descendants = get_all_descendants(self.onto, term_id1) 324 | 325 | a_i = [] 326 | b_i = {} 327 | for j, label2 in enumerate(labels): 328 | if i == j: 329 | continue 330 | 331 | term_id2 = self.name2id[int2label[label2]] 332 | if term_id2 == term_id1 or term_id2 in descendants: # intra-cluster 333 | a_i.append(distance_matrix[i, j]) 334 | elif ( 335 | term_id2 != term_id1 and term_id2 not in ancestors 336 | ): # inter-cluster 337 | if term_id2 not in b_i: 338 | b_i[term_id2] = [] 339 | b_i[term_id2].append(distance_matrix[i, j]) 340 | 341 | if len(a_i) <= 1 or not b_i: 342 | continue 343 | a_i = np.sum(a_i) / (len(a_i) - 1) 344 | b_i = np.min( 345 | [ 346 | np.sum(values) / len(values) 347 | for values in b_i.values() 348 | if len(values) > 1 349 | ] 350 | ) 351 | 352 | s_i = (b_i - a_i) / np.max([a_i, b_i]) 353 | sw.append(s_i) 354 | return np.mean(sw) 355 | 356 | 357 | class TripletLoss(torch.nn.TripletMarginLoss): 358 | """ 359 | Wrapper for pytorch TripletMarginLoss. 360 | Triplets are generated using TripletSelector object which take embeddings and labels 361 | then return triplets. 362 | 363 | Parameters 364 | ---------- 365 | margin: float 366 | Triplet loss margin. 367 | sample_across_studies: bool, default: True 368 | Whether to enforce anchor-positive pairs being from different studies. 369 | negative_selection: str 370 | Method for negative selection: {"semihard", "hardest", "random"} 371 | perturb_labels: bool, default: False 372 | Whether to perturb the ontology labels by coarse graining one level up. 
373 | perturb_labels_fraction: float, default: 0.5 374 | The fraction of labels to perturb 375 | 376 | Examples 377 | -------- 378 | >>> triplet_loss = TripletLoss(margin=0.05) 379 | """ 380 | 381 | def __init__( 382 | self, 383 | margin: float, 384 | sample_across_studies: bool = True, 385 | negative_selection: str = "semihard", 386 | perturb_labels: bool = False, 387 | perturb_labels_fraction: float = 0.5, 388 | ): 389 | super().__init__() 390 | self.margin = margin 391 | self.sample_across_studies = sample_across_studies 392 | self.triplet_selector = TripletSelector( 393 | margin=margin, 394 | negative_selection=negative_selection, 395 | perturb_labels=perturb_labels, 396 | perturb_labels_fraction=perturb_labels_fraction, 397 | ) 398 | 399 | def forward( 400 | self, 401 | embeddings: torch.Tensor, 402 | labels: torch.Tensor, 403 | int2label: dict, 404 | studies: torch.Tensor, 405 | ): 406 | if self.sample_across_studies is False: 407 | studies = None 408 | 409 | ( 410 | triplets_idx, 411 | num_violating_triplets, 412 | num_viable_triplets, 413 | ) = self.triplet_selector.get_triplets_idx( 414 | embeddings, labels, int2label, studies 415 | ) 416 | 417 | anchor_idx, positive_idx, negative_idx = triplets_idx 418 | anchor = embeddings[anchor_idx] 419 | positive = embeddings[positive_idx] 420 | negative = embeddings[negative_idx] 421 | 422 | return ( 423 | F.triplet_margin_loss( 424 | anchor, 425 | positive, 426 | negative, 427 | margin=self.margin, 428 | p=self.p, 429 | eps=self.eps, 430 | swap=self.swap, 431 | reduction="none", 432 | ), 433 | torch.tensor(num_violating_triplets, dtype=torch.float), 434 | torch.tensor(num_viable_triplets, dtype=torch.float), 435 | triplets_idx, 436 | ) 437 | -------------------------------------------------------------------------------- /src/scimilarity/visualizations.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Tuple, Optional 2 | 3 | 4 | def aggregate_counts(data: "pandas.DataFrame", levels: List[str]) -> dict: 5 | """Aggregates cell counts on sample metadata and compiles it into circlify format. 6 | 7 | Parameters 8 | ---------- 9 | data: pandas.DataFrame 10 | A pandas dataframe containing sample metadata. 11 | levels: List[str] 12 | Specify the groupby columns for grouping the sample metadata. 13 | 14 | Returns 15 | ------- 16 | dict 17 | A circlify format dictionary containing grouped sample metadata. 
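# A toy input/output pair for aggregate_counts, defined in this module; the
# metadata values are illustrative and each row stands for one sample.
import pandas as pd

sample_metadata = pd.DataFrame(
    {"tissue": ["lung", "lung", "blood"], "disease": ["healthy", "cancer", "healthy"]}
)
circ_dict = aggregate_counts(sample_metadata, ["tissue", "disease"])
# circ_dict ~ {"lung": {"datum": 2, "children": {"cancer": {"datum": 1},
#              "healthy": {"datum": 1}}}, "blood": {"datum": 1, "children": {...}}}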
18 | 19 | Examples 20 | -------- 21 | >>> circ_dict = aggregate_counts(sample_metadata, ["tissue", "disease"]) 22 | """ 23 | 24 | data_dict = {} 25 | for n in range(len(levels)): 26 | # construct a groupby dataframe to obtain counts 27 | columns = levels[0 : (n + 1)] 28 | df = ( 29 | data.groupby(columns, observed=True)[columns[0]] 30 | .count() 31 | .reset_index(name="count") 32 | ) 33 | 34 | # construct a nested dict to handle children levels 35 | for r in df.index: 36 | if n == 0: # top level 37 | data_dict[df.iloc[r, 0]] = {"datum": df.loc[r, "count"]} 38 | else: 39 | entry = data_dict[df.iloc[r, 0]] 40 | for c in range( 41 | 1, len(columns) 42 | ): # go through nested levels to find the deepest 43 | if ( 44 | "children" not in entry 45 | ): # create a child dict if it does not exist 46 | entry["children"] = {} 47 | entry = entry["children"] # go into child dict 48 | if df.iloc[r, c] in entry: # go into child dict entry if it exists 49 | entry = entry[df.iloc[r, c]] 50 | entry[df.iloc[r, c]] = { 51 | "datum": df.loc[r, "count"] 52 | } # create child entry 53 | return data_dict 54 | 55 | 56 | def assign_size( 57 | data_dict: dict, 58 | data: "pandas.DataFrame", 59 | levels: List[str], 60 | size_column: str, 61 | name_column: str, 62 | ) -> dict: 63 | """Assigns circle sizes to a circlify format dictionary. 64 | 65 | Parameters 66 | ---------- 67 | data_dict: dict 68 | A circlify format dictionary. 69 | data: pandas.DataFrame 70 | A pandas dataframe containing sample metadata. 71 | levels: List[str] 72 | Specify the groupby columns for grouping the sample metadata. 73 | size_column: str 74 | The name of the column that will be used for circle size. 75 | name_column: str 76 | The name of the column that will be used for circle name. 77 | 78 | Returns 79 | ------- 80 | dict 81 | A circlify format dictionary. 82 | 83 | Examples 84 | -------- 85 | >>> circ_dict = assign_size(circ_dict, sample_metadata, ["tissue", "disease"], size_column="cells", name_column="study") 86 | """ 87 | 88 | df = data[levels + [size_column, name_column]] 89 | df = ( 90 | df.groupby(levels + [name_column], observed=True)[size_column] 91 | .sum() 92 | .reset_index(name="count") 93 | ) 94 | for ( 95 | r 96 | ) in ( 97 | df.index 98 | ): # find the deepest levels in data_dict and create an entry with (name, size) 99 | entry = data_dict[df.iloc[r, 0]] 100 | for c in range(1, len(levels)): 101 | entry = entry["children"][df.iloc[r, c]] 102 | if "children" not in entry: 103 | entry["children"] = {} 104 | entry["children"][df.loc[r, name_column]] = {"datum": df.loc[r, "count"]} 105 | return data_dict 106 | 107 | 108 | def assign_suffix( 109 | data_dict: dict, 110 | data: "pandas.DataFrame", 111 | levels: List[str], 112 | suffix_column: str, 113 | name_column: str, 114 | ) -> dict: 115 | """Assigns circle name and suffix to a circlify format dictionary. 116 | 117 | Parameters 118 | ---------- 119 | data_dict: dict 120 | A circlify format dictionary. 121 | data: pandas.DataFrame 122 | A pandas dataframe containing sample metadata. 123 | levels: List[str] 124 | Specify the groupby columns for grouping the sample metadata. 125 | suffix_column: str 126 | The name of the column that will be used for the circle name suffix. 127 | name_column: str 128 | The name of the column that will be used for circle name. 129 | 130 | Returns 131 | ------- 132 | dict 133 | A circlify format dictionary. 
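        The deepest-level circle entries are renamed in place from "<name>" to
        "<name>_<suffix>" using the values found in suffix_column.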
134 | 135 | Examples 136 | -------- 137 | >>> circ_dict = assign_suffix(circ_dict, sample_metadata, ["tissue", "disease"], suffix_column="cells", name_column="study") 138 | """ 139 | 140 | df = data[levels + [suffix_column, name_column]] 141 | for r in df.index: # find the deepest levels in data_dict and rename with suffix 142 | entry = data_dict[df.iloc[r, 0]] 143 | for c in range(1, len(levels)): 144 | entry = entry["children"][df.iloc[r, c]] 145 | if df.loc[r, name_column] in entry["children"]: 146 | entry["children"][ 147 | f"{df.loc[r, name_column]}_{df.loc[r, suffix_column]}" 148 | ] = entry["children"].pop(df.loc[r, name_column]) 149 | return data_dict 150 | 151 | 152 | def assign_colors( 153 | data_dict: dict, 154 | data: "pandas.DataFrame", 155 | levels: List[str], 156 | color_column: str, 157 | name_column: str, 158 | ) -> dict: 159 | """Assigns circle name and color to a circlify format dictionary. 160 | 161 | Parameters 162 | ---------- 163 | data_dict: dict 164 | A circlify format dictionary. 165 | data: pandas.DataFrame 166 | A pandas dataframe containing sample metadata. 167 | levels: List[str] 168 | Specify the groupby columns for grouping the sample metadata. 169 | color_column: str 170 | The name of the column that will be used for the circle color. 171 | name_column: str 172 | The name of the column that will be used for circle name. 173 | 174 | Returns 175 | ------- 176 | dict 177 | A circlify format dictionary. 178 | 179 | Examples 180 | -------- 181 | >>> circ_dict = assign_colors(circ_dict, sample_metadata, ["tissue", "disease"], color_column="cells", name_column="study") 182 | """ 183 | 184 | df = data[levels + [color_column, name_column]] 185 | for r in df.index: # find the deepest levels in data_dict and rename with color 186 | entry = data_dict[df.iloc[r, 0]] 187 | for c in range(1, len(levels)): 188 | entry = entry["children"][df.iloc[r, c]] 189 | if df.loc[r, name_column] in entry["children"]: 190 | entry["children"][df.loc[r, color_column]] = entry["children"].pop( 191 | df.loc[r, name_column] 192 | ) 193 | return data_dict 194 | 195 | 196 | def get_children_data(data_dict: dict) -> List[dict]: 197 | """Recursively get all children data for a given circle. 198 | 199 | Parameters 200 | ---------- 201 | data_dict: dict 202 | A circlify format dictionary 203 | 204 | Returns 205 | ------- 206 | List[dict] 207 | A list of children data. 208 | 209 | Examples 210 | -------- 211 | >>> children = get_children_data(circ_dict[i]["children"]) 212 | """ 213 | 214 | child_data = [] 215 | for i in data_dict: # recursively get all children data 216 | entry = {"id": i, "datum": data_dict[i]["datum"]} 217 | if "children" in data_dict[i]: 218 | children = get_children_data(data_dict[i]["children"]) 219 | entry["children"] = children 220 | child_data.append(entry) 221 | return child_data 222 | 223 | 224 | def circ_dict2data(circ_dict: dict) -> List[dict]: 225 | """Convert a circlify format dictionary to the list format expected by circlify. 226 | 227 | Parameters 228 | ---------- 229 | data_dict: dict 230 | A circlify format dictionary 231 | 232 | Returns 233 | ------- 234 | List[dict] 235 | A list of circle data. 
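        Each entry has the form {"id": <name>, "datum": <count>} with an optional
        "children" list of the same form, which is the input layout expected by
        circlify.circlify.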
236 | 237 | Examples 238 | -------- 239 | >>> circ_data = circ_dict2data(circ_dict) 240 | """ 241 | 242 | circ_data = [] 243 | for i in circ_dict: # convert dict to circlify list data 244 | entry = {"id": i, "datum": circ_dict[i]["datum"]} 245 | if "children" in circ_dict[i]: 246 | children = get_children_data(circ_dict[i]["children"]) 247 | entry["children"] = children 248 | circ_data.append(entry) 249 | return circ_data 250 | 251 | 252 | def draw_circles( 253 | circ_data: List[dict], 254 | title: str = "", 255 | figsize: Tuple[int, int] = (10, 10), 256 | filename: Optional[str] = None, 257 | use_colormap: Optional[str] = None, 258 | use_suffix: Optional[dict] = None, 259 | use_suffix_as_color: bool = False, 260 | ): 261 | """Draw the circlify plot. 262 | 263 | Parameters 264 | ---------- 265 | circ_data: List[dict] 266 | A circlify format list. 267 | title: str, default: "" 268 | The figure title. 269 | figsize: Tuple[int, int], default: (10, 10) 270 | The figure size in inches. 271 | filename: str, optional, default: None 272 | Filename to save the figure. 273 | use_colormap: str, optional, default: None 274 | The colormap identifier. 275 | use_suffix: dict, optional, default: None 276 | A mapping of suffix to color using a dictionary in the form {suffix: float} 277 | use_suffix_as_color: bool, default: False 278 | Use the suffix as the color. This expects the suffix to be a float. 279 | 280 | Examples 281 | -------- 282 | >>> draw_circles(circ_data) 283 | """ 284 | 285 | try: 286 | import circlify as circ 287 | except: 288 | raise ImportError( 289 | "Package 'circlify' not found. Please install with 'pip install circlify'." 290 | ) 291 | 292 | import matplotlib.pyplot as plt 293 | import matplotlib as mpl 294 | 295 | mpl.rcParams["pdf.fonttype"] = 42 296 | 297 | circles = circ.circlify(circ_data, show_enclosure=True) 298 | 299 | fig, ax = plt.subplots(figsize=figsize) 300 | if use_colormap: 301 | cmap = mpl.cm.get_cmap(use_colormap) 302 | 303 | ax.set_title(title) # title 304 | ax.axis("off") # remove axes 305 | 306 | # find axis boundaries 307 | lim = max( 308 | max(abs(circle.x) + circle.r, abs(circle.y) + circle.r) for circle in circles 309 | ) 310 | plt.xlim(-lim, lim) 311 | plt.ylim(-lim, lim) 312 | 313 | # 1st level: 314 | for circle in circles: 315 | if circle.level != 1: 316 | continue 317 | x, y, r = circle 318 | ax.add_patch( 319 | plt.Circle( 320 | (x, y), 321 | r, 322 | alpha=0.5, 323 | linewidth=1, 324 | facecolor="lightblue", 325 | edgecolor="black", 326 | ) 327 | ) 328 | 329 | # 2nd level: 330 | for circle in circles: 331 | if circle.level != 2: 332 | continue 333 | x, y, r = circle 334 | plt.annotate(circle.ex["id"], (x, y), ha="center", color="black") 335 | ax.add_patch( 336 | plt.Circle( 337 | (x, y), 338 | r, 339 | alpha=0.5, 340 | linewidth=1, 341 | facecolor="#69b3a2", 342 | edgecolor="black", 343 | ) 344 | ) 345 | 346 | # 3rd level: 347 | for circle in circles: 348 | if circle.level != 3: 349 | continue 350 | x, y, r = circle 351 | 352 | if use_colormap: 353 | if use_suffix: 354 | suffix = circle.ex["id"].split("_")[-1] 355 | color_fraction = use_suffix[suffix] 356 | elif use_suffix_as_color: 357 | suffix = circle.ex["id"].split("_")[-1] 358 | color_fraction = float(suffix) 359 | else: 360 | color_fraction = circle.ex["id"] 361 | ax.add_patch( 362 | plt.Circle( 363 | (x, y), 364 | r, 365 | alpha=1, 366 | linewidth=1, 367 | facecolor=cmap(color_fraction), 368 | edgecolor="white", 369 | ) 370 | ) 371 | else: 372 | ax.add_patch( 373 | plt.Circle( 374 | (x, y), 375 | 
r, 376 | alpha=0.5, 377 | linewidth=1, 378 | facecolor="red", 379 | edgecolor="white", 380 | ) 381 | ) 382 | 383 | # 1st level labels: 384 | for circle in circles: 385 | if circle.level != 1: 386 | continue 387 | x, y, r = circle 388 | label = circle.ex["id"] 389 | plt.annotate( 390 | label, 391 | (x, y), 392 | va="center", 393 | ha="center", 394 | bbox=dict(facecolor="white", edgecolor="black", boxstyle="round", pad=0.5), 395 | ) 396 | 397 | if filename: # save the figure 398 | fig.savefig(filename, bbox_inches="tight") 399 | 400 | 401 | def hits_circles( 402 | metadata: "pandas.DataFrame", 403 | levels: list = ["tissue", "disease"], 404 | figsize: Tuple[int, int] = (10, 10), 405 | filename: Optional[str] = None, 406 | ): 407 | """Visualize sample metadata as circle plots for tissue and disease. 408 | 409 | Parameters 410 | ---------- 411 | metadata: pandas.DataFrame 412 | A pandas dataframe containing sample metadata for nearest neighbors 413 | with at least columns: ["study", "cells"], that represent the number 414 | of circles and circle size respectively. 415 | levels: list, default: ["tissue", "disease"] 416 | The columns to uses as group levels in the circles hierarchy. 417 | figsize: Tuple[int, int], default: (10, 10) 418 | Figure size, width x height 419 | filename: str, optional 420 | Filename to save the figure. 421 | 422 | Examples 423 | -------- 424 | >>> hits_circles(metadata) 425 | """ 426 | 427 | circ_dict = aggregate_counts(metadata, levels) 428 | circ_dict = assign_size( 429 | circ_dict, metadata, levels, size_column="cells", name_column="study" 430 | ) 431 | circ_data = circ_dict2data(circ_dict) 432 | draw_circles(circ_data, figsize=figsize, filename=filename) 433 | 434 | 435 | def hits_heatmap( 436 | sample_metadata: Dict[str, "pandas.DataFrame"], 437 | x: str, 438 | y: str, 439 | count_type: str = "cells", 440 | figsize: Tuple[int, int] = (10, 10), 441 | filename: Optional[str] = None, 442 | ): 443 | """Visualize a list of sample metadata objects as a heatmap. 444 | 445 | Parameters 446 | ---------- 447 | sample_metadata: Dict[str, pandas.DataFrame] 448 | A dict where keys are cluster names and values are pandas dataframes containing 449 | sample metadata for each cluster centroid with columns: ["tissue", "disease", "study", "sample"]. 450 | x: str 451 | x-axis label key. This corresponds to cluster name values. 452 | y: str 453 | y-axis label key. This corresponds to the dataframe column to visualize. 454 | count_type: {"cells", "fraction"}, default: "cells" 455 | Count type to color in the heatmap. 456 | figsize: Tuple[int, int], default: (10, 10) 457 | Figure size, width x height 458 | filename: str, optional 459 | Filename to save the figure. 460 | 461 | Examples 462 | -------- 463 | >>> hits_heatmap(sample_metadata, "time", "disease") 464 | """ 465 | 466 | import matplotlib.pyplot as plt 467 | import matplotlib as mpl 468 | import numpy as np 469 | import pandas as pd 470 | import seaborn as sns 471 | 472 | mpl.rcParams["pdf.fonttype"] = 42 473 | 474 | valid_count_types = {"cells", "fraction"} 475 | if count_type not in valid_count_types: 476 | raise ValueError( 477 | f"Unknown count_type {count_type}. Options are {valid_count_types}." 
478 | ) 479 | 480 | for k in sample_metadata: 481 | sample_metadata[k][x] = k 482 | df = pd.concat(sample_metadata).reset_index(drop=True) 483 | 484 | if count_type == "cells": 485 | df_m = ( 486 | df.groupby([x, y], observed=True)["cells"].sum().unstack(level=0).fillna(0) 487 | ) 488 | else: 489 | df_m = ( 490 | df.groupby([x, y], observed=True)["fraction"] 491 | .mean() 492 | .unstack(level=0) 493 | .fillna(0) 494 | ) 495 | 496 | fig, ax = plt.subplots(figsize=figsize) 497 | sns.heatmap( 498 | ax=ax, 499 | data=df_m, 500 | xticklabels=True, 501 | yticklabels=True, 502 | square=True, 503 | cmap="Blues", 504 | linewidth=0.01, 505 | linecolor="gray", 506 | cbar_kws={"shrink": 0.5}, 507 | ) 508 | plt.tick_params(axis="both", labelsize=8, grid_alpha=0.0) 509 | 510 | # xticks 511 | ax.xaxis.tick_top() 512 | plt.xticks(np.arange(len(sample_metadata)) + 0.5, rotation=90) 513 | # axis labels 514 | plt.xlabel("") 515 | plt.ylabel("") 516 | # cbar font 517 | cbar = ax.collections[0].colorbar 518 | cbar.ax.tick_params(labelsize=6) 519 | 520 | if filename: # save the figure 521 | fig.savefig(filename, bbox_inches="tight") 522 | -------------------------------------------------------------------------------- /src/scimilarity/zarr_data_models.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import numpy as np 3 | import os 4 | import pytorch_lightning as pl 5 | import torch 6 | from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler 7 | from tqdm import tqdm 8 | from typing import Optional 9 | 10 | from .zarr_dataset import ZarrDataset 11 | 12 | 13 | class scDataset(Dataset): 14 | """A class that represent a collection of single cell datasets in zarr format. 15 | 16 | Parameters 17 | ---------- 18 | data_list: list 19 | List of single-cell datasets. 20 | obs_celltype: str, default: "celltype_name" 21 | Cell type name. 22 | obs_study: str, default: "study" 23 | Study name. 24 | """ 25 | 26 | def __init__(self, data_list, obs_celltype="celltype_name", obs_study="study"): 27 | self.data_list = data_list 28 | self.ncells_list = [data.shape[0] for data in data_list] 29 | self.ncells = sum(self.ncells_list) 30 | self.obs_celltype = obs_celltype 31 | self.obs_study = obs_study 32 | 33 | self.data_idx = [ 34 | n for n in range(len(self.ncells_list)) for i in range(self.ncells_list[n]) 35 | ] 36 | self.cell_idx = [ 37 | i for n in range(len(self.ncells_list)) for i in range(self.ncells_list[n]) 38 | ] 39 | 40 | def __len__(self): 41 | return self.ncells 42 | 43 | def __getitem__(self, idx): 44 | # data, label, study 45 | data_idx = self.data_idx[idx] 46 | cell_idx = self.cell_idx[idx] 47 | return ( 48 | self.data_list[data_idx].get_cell(cell_idx).A, 49 | self.data_list[data_idx].get_obs(self.obs_celltype)[cell_idx], 50 | self.data_list[data_idx].get_obs(self.obs_study)[cell_idx], 51 | ) 52 | 53 | 54 | class MetricLearningDataModule(pl.LightningDataModule): 55 | """A class to encapsulate a collection of zarr datasets to train the model. 56 | 57 | Parameters 58 | ---------- 59 | train_path: str 60 | Path to folder containing all training datasets. 61 | All datasets should be in zarr format, aligned to a known gene space, and 62 | cleaned to only contain valid cell ontology terms. 63 | gene_order: str 64 | Use a given gene order as described in the specified file. One gene 65 | symbol per line. 66 | IMPORTANT: the zarr datasets should already be in this gene order 67 | after preprocessing. 
68 | val_path: str, optional, default: None 69 | Path to folder containing all validation datasets. 70 | obs_field: str, default: "celltype_name" 71 | The obs key name containing celltype labels. 72 | batch_size: int, default: 1000 73 | Batch size. 74 | num_workers: int, default: 1 75 | The number of worker threads for dataloaders 76 | 77 | Examples 78 | -------- 79 | >>> datamodule = MetricLearningZarrDataModule( 80 | batch_size=1000, 81 | num_workers=1, 82 | obs_field="celltype_name", 83 | train_path="train", 84 | gene_order="gene_order.tsv", 85 | ) 86 | """ 87 | 88 | def __init__( 89 | self, 90 | train_path: str, 91 | gene_order: str, 92 | val_path: Optional[str] = None, 93 | obs_field: str = "celltype_name", 94 | batch_size: int = 1000, 95 | num_workers: int = 1, 96 | ): 97 | super().__init__() 98 | self.train_path = train_path 99 | self.val_path = val_path 100 | self.batch_size = batch_size 101 | self.num_workers = num_workers 102 | 103 | # gene space needs be aligned to the given gene order 104 | with open(gene_order, "r") as fh: 105 | self.gene_order = [line.strip() for line in fh] 106 | 107 | self.n_genes = len(self.gene_order) # used when creating training model 108 | 109 | train_data_list = [] 110 | self.train_Y = [] # text labels 111 | self.train_study = [] # text studies 112 | 113 | if self.train_path[-1] != os.sep: 114 | self.train_path += os.sep 115 | 116 | self.train_file_list = [ 117 | ( 118 | root.replace(self.train_path, "").split(os.sep)[0], 119 | dirs[0].replace(".aligned.zarr", ""), 120 | ) 121 | for root, dirs, files in os.walk(self.train_path) 122 | if dirs and dirs[0].endswith(".aligned.zarr") 123 | ] 124 | 125 | for study, sample in tqdm(self.train_file_list): 126 | data_path = os.path.join( 127 | self.train_path, study, sample, sample + ".aligned.zarr" 128 | ) 129 | if os.path.isdir(data_path): 130 | zarr_data = ZarrDataset(data_path) 131 | train_data_list.append(zarr_data) 132 | self.train_Y.extend(zarr_data.get_obs(obs_field).astype(str).tolist()) 133 | self.train_study.extend(zarr_data.get_obs("study").astype(str).tolist()) 134 | 135 | # Lazy load training data from list of zarr datasets 136 | self.train_dataset = scDataset(train_data_list) 137 | 138 | self.class_names = set(self.train_Y) 139 | self.label2int = {label: i for i, label in enumerate(self.class_names)} 140 | self.int2label = {value: key for key, value in self.label2int.items()} 141 | 142 | self.val_dataset = None 143 | if self.val_path is not None: 144 | val_data_list = [] 145 | self.val_Y = [] 146 | self.val_study = [] 147 | 148 | if self.val_path[-1] != os.sep: 149 | self.val_path += os.sep 150 | 151 | self.val_file_list = [ 152 | ( 153 | root.replace(self.val_path, "").split(os.sep)[0], 154 | dirs[0].replace(".aligned.zarr", ""), 155 | ) 156 | for root, dirs, files in os.walk(self.val_path) 157 | if dirs and dirs[0].endswith(".aligned.zarr") 158 | ] 159 | 160 | for study, sample in tqdm(self.val_file_list): 161 | data_path = os.path.join( 162 | self.val_path, study, sample, sample + ".aligned.zarr" 163 | ) 164 | if os.path.isdir(data_path): 165 | zarr_data = ZarrDataset(data_path) 166 | val_data_list.append(zarr_data) 167 | self.val_Y.extend(zarr_data.get_obs(obs_field).astype(str).tolist()) 168 | self.val_study.extend( 169 | zarr_data.get_obs("study").astype(str).tolist() 170 | ) 171 | 172 | # Lazy load val data from list of zarr datasets 173 | self.val_dataset = scDataset(val_data_list) 174 | 175 | def get_sampler_weights( 176 | self, labels: list, studies: Optional[list] = None 177 | ) -> 
WeightedRandomSampler: 178 | """Get weighted random sampler. 179 | 180 | Parameters 181 | ---------- 182 | dataset: scDataset 183 | Single cell dataset. 184 | 185 | Returns 186 | ------- 187 | WeightedRandomSampler 188 | A WeightedRandomSampler object. 189 | """ 190 | 191 | if studies is None: 192 | class_sample_count = Counter(labels) 193 | sample_weights = torch.Tensor([1.0 / class_sample_count[t] for t in labels]) 194 | else: 195 | class_sample_count = Counter(labels) 196 | study_sample_count = Counter(studies) 197 | class_sample_count = { 198 | x: np.log1p(class_sample_count[x] / 1e4) for x in class_sample_count 199 | } 200 | study_sample_count = { 201 | x: np.log1p(study_sample_count[x] / 1e5) for x in study_sample_count 202 | } 203 | sample_weights = torch.Tensor( 204 | [ 205 | 1.0 / class_sample_count[labels[i]] / study_sample_count[studies[i]] 206 | for i in range(len(labels)) 207 | ] 208 | ) 209 | return WeightedRandomSampler(sample_weights, len(sample_weights)) 210 | 211 | def collate(self, batch): 212 | """Collate tensors. 213 | 214 | Parameters 215 | ---------- 216 | batch: 217 | Batch to collate. 218 | 219 | Returns 220 | ------- 221 | tuple 222 | A Tuple[torch.Tensor, torch.Tensor, list] containing information 223 | on the collated tensors. 224 | """ 225 | 226 | profiles, labels, studies = tuple( 227 | map(list, zip(*batch)) 228 | ) # tuple([list(t) for t in zip(*batch)]) 229 | return ( 230 | torch.squeeze(torch.Tensor(np.vstack(profiles))), 231 | torch.Tensor([self.label2int[l] for l in labels]), # text to int labels 232 | studies, 233 | ) 234 | 235 | def train_dataloader(self) -> DataLoader: 236 | """Load the training dataset. 237 | 238 | Returns 239 | ------- 240 | DataLoader 241 | A DataLoader object containing the training dataset. 242 | """ 243 | 244 | return DataLoader( 245 | self.train_dataset, 246 | batch_size=self.batch_size, 247 | num_workers=self.num_workers, 248 | pin_memory=True, 249 | drop_last=True, 250 | sampler=self.get_sampler_weights(self.train_Y, self.train_study), 251 | collate_fn=self.collate, 252 | ) 253 | 254 | def val_dataloader(self) -> DataLoader: 255 | """Load the validation dataset. 256 | 257 | Returns 258 | ------- 259 | DataLoader 260 | A DataLoader object containing the validation dataset. 261 | """ 262 | 263 | if self.val_dataset is None: 264 | return None 265 | return DataLoader( 266 | self.val_dataset, 267 | batch_size=self.batch_size, 268 | num_workers=self.num_workers, 269 | pin_memory=True, 270 | drop_last=True, 271 | sampler=self.get_sampler_weights(self.val_Y, self.val_study), 272 | collate_fn=self.collate, 273 | ) 274 | 275 | def test_dataloader(self) -> DataLoader: 276 | """Load the test dataset. 277 | 278 | Returns 279 | ------- 280 | DataLoader 281 | A DataLoader object containing the test dataset. 282 | """ 283 | 284 | return self.val_dataloader() 285 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # Tox configuration file 2 | # Read more under https://tox.wiki/ 3 | # THIS SCRIPT IS SUPPOSED TO BE AN EXAMPLE. MODIFY IT ACCORDING TO YOUR NEEDS! 
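# Illustrative invocations for the environments defined below
# (assumes tox >= 3.24 is installed, per `minversion`):
#   tox                # run the default env: pytest test suite
#   tox -e docs        # build the Sphinx documentation
#   tox -e doctests    # run the Sphinx doctests
#   tox -e build       # build sdist and wheel (PEP 517)
#   tox -e clean       # remove old build artifacts
#   tox -e publish     # twine check + upload (testpypi by default)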
4 | 5 | [tox] 6 | minversion = 3.24 7 | envlist = default 8 | isolated_build = True 9 | 10 | 11 | [testenv] 12 | description = Invoke pytest to run automated tests 13 | setenv = 14 | TOXINIDIR = {toxinidir} 15 | passenv = 16 | HOME 17 | SETUPTOOLS_* 18 | extras = 19 | testing 20 | commands = 21 | pytest {posargs} 22 | 23 | 24 | # # To run `tox -e lint` you need to make sure you have a 25 | # # `.pre-commit-config.yaml` file. See https://pre-commit.com 26 | # [testenv:lint] 27 | # description = Perform static analysis and style checks 28 | # skip_install = True 29 | # deps = pre-commit 30 | # passenv = 31 | # HOMEPATH 32 | # PROGRAMDATA 33 | # SETUPTOOLS_* 34 | # commands = 35 | # pre-commit run --all-files {posargs:--show-diff-on-failure} 36 | 37 | 38 | [testenv:{build,clean}] 39 | description = 40 | build: Build the package in isolation according to PEP517, see https://github.com/pypa/build 41 | clean: Remove old distribution files and temporary build artifacts (./build and ./dist) 42 | # https://setuptools.pypa.io/en/stable/build_meta.html#how-to-use-it 43 | skip_install = True 44 | changedir = {toxinidir} 45 | deps = 46 | build: build[virtualenv] 47 | passenv = 48 | SETUPTOOLS_* 49 | commands = 50 | clean: python -c 'import shutil; [shutil.rmtree(p, True) for p in ("build", "dist", "docs/_build")]' 51 | clean: python -c 'import pathlib, shutil; [shutil.rmtree(p, True) for p in pathlib.Path("src").glob("*.egg-info")]' 52 | build: python -m build {posargs} 53 | # By default, both `sdist` and `wheel` are built. If your sdist is too big or you don't want 54 | # to make it available, consider running: `tox -e build -- --wheel` 55 | 56 | 57 | [testenv:{docs,doctests,linkcheck}] 58 | description = 59 | docs: Invoke sphinx-build to build the docs 60 | doctests: Invoke sphinx-build to run doctests 61 | linkcheck: Check for broken links in the documentation 62 | passenv = 63 | SETUPTOOLS_* 64 | setenv = 65 | DOCSDIR = {toxinidir}/docs 66 | BUILDDIR = {toxinidir}/docs/_build 67 | docs: BUILD = html 68 | doctests: BUILD = doctest 69 | linkcheck: BUILD = linkcheck 70 | deps = 71 | -r {toxinidir}/docs/requirements.txt 72 | # ^ requirements.txt shared with Read The Docs 73 | commands = 74 | sphinx-build --color -b {env:BUILD} -d "{env:BUILDDIR}/doctrees" "{env:DOCSDIR}" "{env:BUILDDIR}/{env:BUILD}" {posargs} 75 | 76 | 77 | [testenv:publish] 78 | description = 79 | Publish the package you have been developing to a package index server. 80 | By default, it uses testpypi. If you really want to publish your package 81 | to be publicly accessible in PyPI, use the `-- --repository pypi` option. 82 | skip_install = True 83 | changedir = {toxinidir} 84 | passenv = 85 | # See: https://twine.readthedocs.io/en/latest/ 86 | TWINE_USERNAME 87 | TWINE_PASSWORD 88 | TWINE_REPOSITORY 89 | TWINE_REPOSITORY_URL 90 | deps = twine 91 | commands = 92 | python -m twine check dist/* 93 | python -m twine upload {posargs:--repository {env:TWINE_REPOSITORY:testpypi}} dist/* 94 | --------------------------------------------------------------------------------
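A minimal usage sketch for the circle-plot helpers in src/scimilarity/visualizations.py above. The dataframe contents are illustrative placeholders; per the hits_circles docstring, the metadata only needs the level columns plus "study" and "cells". Drawing also requires the optional circlify package.

import pandas as pd

from scimilarity.visualizations import hits_circles

# Toy nearest-neighbor sample metadata (values are made up for illustration).
metadata = pd.DataFrame(
    {
        "tissue": ["blood", "blood", "lung"],
        "disease": ["healthy", "COVID-19", "healthy"],
        "study": ["studyA", "studyB", "studyC"],
        "cells": [120, 80, 45],
    }
)

# Equivalent to aggregate_counts -> assign_size -> circ_dict2data -> draw_circles.
hits_circles(metadata, levels=["tissue", "disease"], filename="hits_circles.png")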