├── .editorconfig ├── .gitattributes ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── custom.md │ └── feature_request.md └── workflows │ └── ci.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── data ├── README.md └── _scripts │ ├── 1_download_all.sh │ ├── 2_prepare.py │ ├── cosmx_convert.py │ ├── cosmx_download.sh │ ├── merscope_convert.py │ ├── merscope_download.sh │ ├── stereoseq_convert.py │ ├── stereoseq_download.sh │ ├── xenium_convert.py │ └── xenium_download.sh ├── docs ├── api │ ├── Novae.md │ ├── dataloader.md │ ├── metrics.md │ ├── modules.md │ ├── plot.md │ └── utils.md ├── assets │ ├── Figure1.png │ ├── banner.png │ ├── logo_favicon.png │ ├── logo_small_black.png │ └── logo_white.png ├── cite_us.md ├── faq.md ├── getting_started.md ├── index.md ├── javascripts │ └── mathjax.js └── tutorials │ ├── LEE_AGING_CEREBELLUM_UP.json │ ├── hallmarks_pathways.json │ ├── input_modes.md │ ├── main_usage.ipynb │ ├── mouse_hallmarks.json │ ├── proteins.ipynb │ └── resolutions.ipynb ├── mkdocs.yml ├── novae ├── __init__.py ├── _constants.py ├── _logging.py ├── _settings.py ├── data │ ├── __init__.py │ ├── convert.py │ ├── datamodule.py │ └── dataset.py ├── model.py ├── module │ ├── __init__.py │ ├── aggregate.py │ ├── augment.py │ ├── embed.py │ ├── encode.py │ └── swav.py ├── monitor │ ├── __init__.py │ ├── callback.py │ ├── eval.py │ └── log.py ├── plot │ ├── __init__.py │ ├── _bar.py │ ├── _graph.py │ ├── _heatmap.py │ ├── _spatial.py │ └── _utils.py └── utils │ ├── __init__.py │ ├── _build.py │ ├── _correct.py │ ├── _data.py │ ├── _mode.py │ ├── _preprocess.py │ ├── _utils.py │ └── _validate.py ├── poetry.lock ├── pyproject.toml ├── scripts ├── README.md ├── __init__.py ├── config.py ├── config │ ├── README.md │ ├── _example.yaml │ ├── all_16.yaml │ ├── all_17.yaml │ ├── all_brain.yaml │ ├── all_human.yaml │ ├── all_human2.yaml │ ├── all_mouse.yaml │ ├── all_new.yaml │ ├── all_ruche.yaml │ ├── all_spot.yaml │ ├── brain.yaml │ ├── brain2.yaml │ ├── breast.yaml │ ├── breast_zs.yaml │ ├── breast_zs2.yaml │ ├── colon.yaml │ ├── colon_retrain.yaml │ ├── colon_zs.yaml │ ├── colon_zs2.yaml │ ├── local_tests.yaml │ ├── lymph_node.yaml │ ├── missing.yaml │ ├── ovarian.yaml │ ├── revision.yaml │ ├── revision_tests.yaml │ ├── toy_cpu_seed0.yaml │ └── toy_missing.yaml ├── missing_domain.py ├── revision │ ├── cpu.sh │ ├── heterogeneous.py │ ├── heterogeneous_start.py │ ├── mgc.py │ ├── missing_domains.py │ ├── perturbations.py │ └── seg_robustness.py ├── ruche │ ├── README.md │ ├── agent.sh │ ├── convert.sh │ ├── cpu.sh │ ├── debug_cpu.sh │ ├── debug_gpu.sh │ ├── download.sh │ ├── gpu.sh │ ├── prepare.sh │ └── test.sh ├── sweep │ ├── README.md │ ├── cpu.yaml │ ├── debug.yaml │ ├── gpu.yaml │ ├── gpu_ruche.yaml │ ├── lung.yaml │ ├── revision.yaml │ └── toy.yaml ├── toy_missing_domain.py ├── train.py └── utils.py ├── setup.py └── tests ├── __init__.py ├── _utils.py ├── test_correction.py ├── test_dataset.py ├── test_metrics.py ├── test_misc.py ├── test_model.py ├── test_plots.py └── test_utils.py /.editorconfig: -------------------------------------------------------------------------------- 1 | max_line_length = 120 2 | 3 | [*.json] 4 | indent_style = space 5 | indent_size = 4 6 | 7 | [*.yaml] 8 | indent_size = 2 9 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Ignoring notebooks for language 
statistics on GitHub 2 | *.ipynb -linguist-detectable 3 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[Bug]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Description 11 | A clear and concise description of what the bug is. 12 | 13 | ## Reproducing the issue 14 | Steps to reproduce the behavior. 15 | 16 | ## Expected behavior 17 | A clear and concise description of what you expected to happen. 18 | 19 | ## System 20 | - OS: [e.g. Linux] 21 | - Python version: [e.g. 3.10.7] 22 | 
23 | <details> 24 | <summary>Dependencies versions</summary> 25 | Paste here what `pip list` gives you. 26 | </details>
27 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/custom.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Custom issue template 3 | about: Describe this issue template's purpose here. 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[Feature]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: test_deploy_publish 2 | on: 3 | push: 4 | tags: 5 | - v* 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | pre-commit: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v3 14 | - uses: actions/setup-python@v3 15 | - uses: pre-commit/action@v3.0.1 16 | 17 | build: 18 | needs: [pre-commit] 19 | runs-on: ubuntu-latest 20 | strategy: 21 | matrix: 22 | python-version: ["3.10", "3.11", "3.12"] 23 | steps: 24 | - uses: actions/checkout@v4 25 | - uses: actions/setup-python@v5 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | cache: "pip" 29 | - run: pip install '.[dev]' 30 | 31 | - name: Run tests 32 | run: pytest --cov 33 | 34 | - name: Deploy doc 35 | if: matrix.python-version == '3.10' && contains(github.ref, 'tags') 36 | run: mkdocs gh-deploy --force 37 | 38 | - name: Upload results to Codecov 39 | if: matrix.python-version == '3.10' 40 | uses: codecov/codecov-action@v4 41 | with: 42 | token: ${{ secrets.CODECOV_TOKEN }} 43 | 44 | publish: 45 | needs: [build] 46 | if: contains(github.ref, 'tags') 47 | runs-on: ubuntu-latest 48 | steps: 49 | - uses: actions/checkout@v3 50 | - name: Build and publish to pypi 51 | uses: JRubics/poetry-publish@v1.17 52 | with: 53 | python_version: "3.10" 54 | pypi_token: ${{ secrets.PYPI_TOKEN }} 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # OS related 2 | .DS_Store 3 | 4 | # Ruche related 5 | jobs/ 6 | 7 | # Tests related 8 | tests/* 9 | !tests/*.py 10 | 11 | # IDE related 12 | .vscode/* 13 | !.vscode/settings.json 14 | !.vscode/tasks.json 15 | !.vscode/launch.json 16 | !.vscode/extensions.json 17 | *.code-workspace 18 | **/.vscode 19 | 20 | # Documentation 21 | site 22 | 23 | # Logs 24 | outputs 25 | multirun 26 | lightning_logs 27 | novae_* 28 | checkpoints 29 | 30 | # Data files 31 | data/* 32 | !data/_scripts 33 | !data/README.md 34 | !data/*.py 35 | !data/*.sh 36 | 37 | # Results files 38 | results/* 39 | 40 | # Jupyter Notebook 41 | 
.ipynb_checkpoints 42 | *.ipynb 43 | !docs/tutorials/*.ipynb 44 | 45 | # Misc 46 | logs/ 47 | test.h5ad 48 | test.ckpt 49 | wandb/ 50 | .env 51 | .autoenv 52 | !**/.gitkeep 53 | exploration 54 | 55 | # pyenv 56 | .python-version 57 | 58 | # Byte-compiled / optimized / DLL files 59 | __pycache__/ 60 | *.py[cod] 61 | *$py.class 62 | 63 | # C extensions 64 | *.so 65 | 66 | # Distribution / packaging 67 | .Python 68 | build/ 69 | develop-eggs/ 70 | dist/ 71 | downloads/ 72 | eggs/ 73 | .eggs/ 74 | lib/ 75 | lib64/ 76 | parts/ 77 | sdist/ 78 | var/ 79 | wheels/ 80 | pip-wheel-metadata/ 81 | share/python-wheels/ 82 | *.egg-info/ 83 | .installed.cfg 84 | *.egg 85 | MANIFEST 86 | 87 | # PyInstaller 88 | # Usually these files are written by a python script from a template 89 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 90 | *.manifest 91 | *.spec 92 | 93 | # Installer logs 94 | pip-log.txt 95 | pip-delete-this-directory.txt 96 | 97 | # Unit test / coverage reports 98 | htmlcov/ 99 | .tox/ 100 | .nox/ 101 | .coverage 102 | .coverage.* 103 | .cache 104 | nosetests.xml 105 | coverage.xml 106 | *.cover 107 | *.py,cover 108 | .hypothesis/ 109 | .pytest_cache/ 110 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | fail_fast: false 2 | default_language_version: 3 | python: python3 4 | default_stages: 5 | - pre-commit 6 | - pre-push 7 | minimum_pre_commit_version: 2.16.0 8 | repos: 9 | - repo: https://github.com/astral-sh/ruff-pre-commit 10 | rev: v0.11.5 11 | hooks: 12 | - id: ruff 13 | - id: ruff-format 14 | - repo: https://github.com/pre-commit/pre-commit-hooks 15 | rev: v4.5.0 16 | hooks: 17 | - id: trailing-whitespace 18 | - id: end-of-file-fixer 19 | - id: check-yaml 20 | - id: debug-statements 21 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## [0.2.4] - 2025-03-26 2 | 3 | Hotfix (#18) 4 | 5 | ## [0.2.3] - 2025-03-21 6 | 7 | ### Added 8 | - New Visium-HD and Visium tutorials 9 | - Infer a default plot size for spatial plots (median neighbor distance) 10 | - Can filter by technology in `load_dataset` 11 | 12 | ### Fixed 13 | - Fix `model.plot_prototype_covariance` and `model.plot_prototype_weights` 14 | - Edge case: ensure that even negative max-weights are above the prototype threshold 15 | 16 | ### Changed 17 | - `novae.utils.spatial_neighbors` can now be called via `novae.spatial_neighbors` 18 | - Store distances after `spatial_neighbors` instead of `1` when `coord_type="grid"` 19 | 20 | ## [0.2.2] - 2024-12-17 21 | 22 | ### Added 23 | - `load_dataset`: add `custom_filter` and `dry_run` arguments 24 | - Added `min_prototypes_ratio` argument in `fine_tune` to run `init_slide_queue` 25 | - Added tutorials for proteins data + minor docs improvements 26 | 27 | ### Fixed 28 | - Ensure the clustering is reset when zero-shot is run multiple times (#9) 29 | 30 | ### Changed 31 | - Removed the docs formatting (better for autocompletion) 32 | - Reorder parameters in Novae `__init__` (sorted by importance) 33 | 34 | ## [0.2.1] - 2024-12-04 35 | 36 | ### Added 37 | - `novae.utils.quantile_scaling` for protein expression 38 | 39 | ### Fixed 40 | - Fix autocompletion using `__new__` to trick the Hugging Face inheritance 41 | 42 | 43 | ## [0.2.0] - 2024-12-03 44 | 45 | ### Added 46 | 47 | - `novae.plot.connectivities(...)` to show the cell neighbors
48 | - `novae.settings.auto_preprocessing = False` to enforce using your own preprocessing 49 | - Tutorials update (more plots and more details) 50 | 51 | ### Fixed 52 | 53 | - Issue with `library_id` in `novae.plot.domains` (#8) 54 | - Set `pandas>=2.0.0` in the dependencies (#5) 55 | 56 | ### Breaking changes 57 | 58 | - `novae.utils.spatial_neighbors` must now always be run explicitly, giving the user more control over it 59 | - For multi-slide mode, the `slide_key` argument should now be used in `novae.utils.spatial_neighbors` (and only there) 60 | - Dropped Python 3.9 support (it was dropped in `anndata`) 61 | 62 | ## [0.1.0] - 2024-09-11 63 | 64 | First official `novae` release. Preprint coming soon. 65 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to *novae* 2 | 3 | Contributions are welcome as we aim to continue improving `novae`. For instance, you can contribute by: 4 | 5 | - Opening an issue 6 | - Discussing the current state of the code 7 | - Making a Pull Request (PR) 8 | 9 | If you want to open a PR, follow the instructions below. 10 | 11 | ## Making a Pull Request (PR) 12 | 13 | To add some new code to **novae**, you should: 14 | 15 | 1. Fork the repository 16 | 2. Install `novae` in editable mode with the `dev` extra (see below) 17 | 3. Create your personal branch from `main` 18 | 4. Implement your changes according to the 'Coding guidelines' below 19 | 5. Create a pull request on the `main` branch of the original repository. Add explanations about the features you developed, and wait for discussion and validation of your pull request 20 | 21 | ## Installing `novae` in editable mode 22 | 23 | When contributing, installing `novae` in editable mode is recommended. We also recommend installing the `dev` extra. 24 | 25 | For that, choose between `pip` and `poetry` as below: 26 | 27 | ```sh 28 | git clone https://github.com/MICS-Lab/novae.git 29 | cd novae 30 | 31 | pip install -e '.[dev]' # pip installation 32 | poetry install -E dev # poetry installation 33 | ``` 34 | 35 | ## Coding guidelines 36 | 37 | ### Styling and formatting 38 | 39 | We use [`pre-commit`](https://pre-commit.com/) to run code quality controls before the commits. This will run `ruff` and other minor checks. 40 | 41 | 42 | You can set it up at the root of the repository like this: 43 | ```sh 44 | pre-commit install 45 | ``` 46 | 47 | The hooks will then run automatically before each commit. 48 | 49 | You can also run pre-commit manually: 50 | ```sh 51 | pre-commit run --all-files 52 | ``` 53 | 54 | Apart from this, we recommend following the standard styling conventions: 55 | - Follow the [PEP8](https://peps.python.org/pep-0008/) style guide. 56 | - Provide meaningful names to all your variables and functions. 57 | - Provide type hints to your function inputs/outputs. 58 | - Add docstrings in the Google style. 59 | - Try as much as possible to follow the same coding style as the rest of the repository. 60 | 61 | ### Testing 62 | 63 | When you create a pull request, tests run automatically. But you can also run them yourself before making the PR: run `pytest` at the root of the repository. You can also add new tests in the `./tests` directory (a minimal sketch is shown below). 
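For instance, a new test could look like the following sketch. It is only an illustration: the file name is hypothetical, and it assumes that `novae.utils.toy_dataset` (part of the documented API) returns a list of `AnnData` objects with spatial coordinates.

```python
# tests/test_sketch.py -- hypothetical file name
import novae


def test_toy_dataset_is_spatial():
    # assumption: toy_dataset() returns a list of AnnData objects
    adatas = novae.utils.toy_dataset()

    for adata in adatas:
        # Novae expects spatial coordinates in adata.obsm["spatial"]
        assert "spatial" in adata.obsm
```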
64 | 65 | To check the coverage of the tests: 66 | 67 | ```sh 68 | coverage run -m pytest 69 | coverage report # command line report 70 | coverage html # html report 71 | ``` 72 | 73 | ### Documentation 74 | 75 | You can update the documentation in the `./docs` directory. Refer to the [mkdocs-material documentation](https://squidfunk.github.io/mkdocs-material/) for more help. 76 | 77 | To serve the documentation locally: 78 | 79 | ```sh 80 | mkdocs serve 81 | ``` 82 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024, Quentin Blampey 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | *(banner image: novae_banner)* 3 | 4 | 5 | 6 | 7 | [![PyPI](https://img.shields.io/pypi/v/novae.svg)](https://pypi.org/project/novae) 8 | [![Downloads](https://static.pepy.tech/badge/novae)](https://pepy.tech/project/novae) 9 | [![Docs](https://img.shields.io/badge/docs-mkdocs-blue)](https://mics-lab.github.io/novae) 10 | ![Build](https://github.com/MICS-Lab/novae/workflows/ci/badge.svg) 11 | [![License](https://img.shields.io/pypi/l/novae.svg)](https://github.com/MICS-Lab/novae/blob/main/LICENSE) 12 | [![codecov](https://codecov.io/gh/MICS-Lab/novae/graph/badge.svg?token=FFI44M52O9)](https://codecov.io/gh/MICS-Lab/novae) 13 | [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) 14 | 15 | 16 | 17 | 18 | ### 💫 Graph-based foundation model for spatial transcriptomics data 19 | 20 | 21 | Novae is a deep learning model for spatial domain assignment of spatial transcriptomics data (at both single-cell and spot resolution). It works across multiple gene panels, tissues, and technologies. Novae offers several additional features, including: (i) native batch-effect correction, (ii) analysis of spatially variable genes and pathways, and (iii) architecture analysis of tissue slides. 22 | 23 | ## Documentation 24 | 25 | Check [Novae's documentation](https://mics-lab.github.io/novae/) to get started. It contains installation instructions, API details, and tutorials. 26 | 27 | ## Overview 28 | 29 | 
30 | *(overview figure: novae_overview)* 31 | 
32 | 33 | > **(a)** Novae was trained on a large dataset, and is shared on [Hugging Face Hub](https://huggingface.co/collections/MICS-Lab/novae-669cdf1754729d168a69f6bd). **(b)** Illustration of the main tasks and properties of Novae. **(c)** Illustration of the method behind Novae (self-supervision on graphs, adapted from [SwAV](https://arxiv.org/abs/2006.09882)). 34 | 35 | ## Installation 36 | 37 | ### PyPI 38 | 39 | `novae` can be installed from PyPI on any OS, for `python>=3.10`. 40 | 41 | ``` 42 | pip install novae 43 | ``` 44 | 45 | ### Editable mode 46 | 47 | To install `novae` in editable mode (e.g., to contribute), clone the repository and choose among the options below. 48 | 49 | ```sh 50 | pip install -e . # pip, minimal dependencies 51 | pip install -e '.[dev]' # pip, all extras 52 | poetry install # poetry, minimal dependencies 53 | poetry install --all-extras # poetry, all extras 54 | ``` 55 | 56 | ## Usage 57 | 58 | Here is a minimal usage example. For more details, refer to the [documentation](https://mics-lab.github.io/novae/). 59 | 60 | ```python 61 | import novae 62 | 63 | model = novae.Novae.from_pretrained("MICS-Lab/novae-human-0") 64 | 65 | model.compute_representations(adata, zero_shot=True) # adata: your spatial AnnData object 66 | model.assign_domains(adata) 67 | ``` 68 | 69 | ## Cite us 70 | 71 | You can cite our [preprint](https://www.biorxiv.org/content/10.1101/2024.09.09.612009v1) as below: 72 | 73 | ```txt 74 | @article{blampeyNovae2024, 75 | title = {Novae: A Graph-Based Foundation Model for Spatial Transcriptomics Data}, 76 | author = {Blampey, Quentin and Benkirane, Hakim and Bercovici, Nadege and Andre, Fabrice and Cournede, Paul-Henry}, 77 | year = {2024}, 78 | pages = {2024.09.09.612009}, 79 | publisher = {bioRxiv}, 80 | doi = {10.1101/2024.09.09.612009}, 81 | } 82 | ``` 83 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Public datasets 2 | 3 | We detail below how to download public spatial transcriptomics datasets. 4 | 5 | ## Option 1: Hugging Face Hub 6 | 7 | We store our dataset on [Hugging Face Hub](https://huggingface.co/datasets/MICS-Lab/novae). 8 | To automatically download these slides, you can use the [`novae.utils.load_dataset`](https://mics-lab.github.io/novae/api/utils/#novae.utils.load_dataset) function. 9 | 10 | NB: not all slides are uploaded on Hugging Face yet, but we are progressively adding new slides. To get the full dataset right now, use "Option 2" below. 11 | 12 | ## Option 2: Download 13 | 14 | For consistency, all the scripts below need to be executed at the root of the `data` directory (i.e., `novae/data`). 15 | 16 | ### MERSCOPE (18 samples) 17 | 18 | Requirements: the `gsutil` command-line tool should be installed (see [here](https://cloud.google.com/storage/docs/gsutil_install)), plus a Python environment with `scanpy` installed. 19 | 20 | ```sh 21 | # download all MERSCOPE datasets 22 | sh _scripts/merscope_download.sh 23 | 24 | # convert all datasets to h5ad files 25 | python _scripts/merscope_convert.py 26 | ``` 27 | 28 | ### Xenium (20+ samples) 29 | 30 | Requirements: a Python environment with `spatialdata-io` installed. 31 | 32 | ```sh 33 | # download all Xenium datasets 34 | sh _scripts/xenium_download.sh 35 | 36 | # convert all datasets to h5ad files 37 | python _scripts/xenium_convert.py 38 | ``` 39 | 40 | ### CosMX (3 samples) 41 | 42 | Requirements: a Python environment with `scanpy` installed. 
43 | 44 | ```sh 45 | # download all CosMX datasets 46 | sh _scripts/cosmx_download.sh 47 | 48 | # convert all datasets to h5ad files 49 | python _scripts/cosmx_convert.py 50 | ``` 51 | 52 | ### All datasets 53 | 54 | All of the above datasets can be downloaded with a single command. Make sure you have all the requirements listed above. 55 | 56 | ```sh 57 | sh _scripts/1_download_all.sh 58 | ``` 59 | 60 | ### Preprocess and prepare for training 61 | 62 | The script below copies all `adata.h5ad` files into a single directory, computes UMAPs, and applies minor preprocessing. See the `argparse` helper of this script for more details. 63 | 64 | ```sh 65 | python _scripts/2_prepare.py 66 | ``` 67 | 68 | ### Usage 69 | 70 | These datasets can be used during training (see the `scripts` directory at the root of the `novae` repository). 71 | 72 | ## Notes 73 | - Missing technologies: CosMX, Curio Seeker, Resolve 74 | - Public institute datasets are available on [STOmics DB](https://db.cngb.org/stomics/) 75 | - Some Xenium datasets are available outside of the main "10X Datasets" page: 76 | - https://www.10xgenomics.com/products/visium-hd-spatial-gene-expression/dataset-human-crc 77 | -------------------------------------------------------------------------------- /data/_scripts/1_download_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # download all MERSCOPE datasets 4 | sh merscope_download.sh 5 | 6 | # convert all datasets to h5ad files 7 | python merscope_convert.py 8 | 9 | # download all Xenium datasets 10 | sh xenium_download.sh 11 | 12 | # convert all datasets to h5ad files 13 | python xenium_convert.py 14 | 15 | # download all CosMX datasets 16 | sh cosmx_download.sh 17 | 18 | # convert all datasets to h5ad files 19 | python cosmx_convert.py 20 | -------------------------------------------------------------------------------- /data/_scripts/2_prepare.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import anndata 5 | import scanpy as sc 6 | from anndata import AnnData 7 | 8 | import novae 9 | 10 | MIN_CELLS = 50 11 | DELAUNAY_RADIUS = 100 12 | 13 | 14 | def preprocess(adata: AnnData, compute_umap: bool = False): 15 | sc.pp.filter_genes(adata, min_cells=MIN_CELLS) 16 | adata.layers["counts"] = adata.X.copy() 17 | sc.pp.normalize_total(adata) 18 | sc.pp.log1p(adata) 19 | 20 | if compute_umap: 21 | sc.pp.neighbors(adata) 22 | sc.tl.umap(adata) 23 | 24 | novae.utils.spatial_neighbors(adata, radius=[0, DELAUNAY_RADIUS]) 25 | 26 | 27 | def main(args): 28 | data_path: Path = Path(args.path).absolute() 29 | out_dir = data_path / args.name 30 | 31 | if not out_dir.exists(): 32 | out_dir.mkdir() 33 | 34 | for dataset in args.datasets: 35 | dataset_dir: Path = data_path / dataset 36 | for file in dataset_dir.glob("**/adata.h5ad"): 37 | print("Reading file", file) 38 | 39 | adata = anndata.read_h5ad(file) 40 | adata.obs["technology"] = dataset 41 | 42 | if "slide_id" not in adata.obs: 43 | print(" (no slide_id in obs, skipping)") 44 | continue 45 | 46 | out_file = out_dir / f"{adata.obs['slide_id'].iloc[0]}.h5ad" 47 | 48 | if out_file.exists() and not args.overwrite: 49 | print(" (already exists)") 50 | continue 51 | 52 | preprocess(adata, compute_umap=args.umap) 53 | adata.write_h5ad(out_file) 54 | 55 | 56 | if __name__ == "__main__": 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument( 59 | "-p", 60 | "--path", 61 | type=str, 62 | default=".", 
help="Path to spatial directory", 64 | ) 65 | parser.add_argument( 66 | "-n", 67 | "--name", 68 | type=str, 69 | default="all", 70 | help="Name of the resulting data directory", 71 | ) 72 | parser.add_argument( 73 | "-d", 74 | "--datasets", 75 | nargs="+", 76 | default=["xenium", "merscope", "cosmx"], 77 | help="List of dataset names to concatenate", 78 | ) 79 | parser.add_argument( 80 | "-o", 81 | "--overwrite", 82 | action="store_true", 83 | help="Overwrite existing output files", 84 | ) 85 | parser.add_argument( 86 | "-u", 87 | "--umap", 88 | action="store_true", 89 | help="Whether to compute the UMAP embedding", 90 | ) 91 | 92 | args = parser.parse_args() 93 | main(args) 94 | -------------------------------------------------------------------------------- /data/_scripts/cosmx_convert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import anndata 5 | import pandas as pd 6 | from scipy.sparse import csr_matrix 7 | 8 | 9 | def convert_to_h5ad(dataset_dir: Path): 10 | print(f"Reading {dataset_dir}") 11 | res_path = dataset_dir / "adata.h5ad" 12 | 13 | if res_path.exists(): 14 | print(f"File {res_path} already existing.") 15 | return 16 | 17 | slide_id = f"cosmx_{dataset_dir.name}" 18 | 19 | counts_files = list(dataset_dir.glob("*exprMat_file.csv")) 20 | metadata_files = list(dataset_dir.glob("*metadata_file.csv")) 21 | 22 | if len(counts_files) != 1 or len(metadata_files) != 1: 23 | print(f"Did not found both exprMat and metadata csv inside {dataset_dir}. Skipping this directory.") 24 | return 25 | 26 | data = pd.read_csv(counts_files[0], index_col=[0, 1]) 27 | obs = pd.read_csv(metadata_files[0], index_col=[0, 1]) 28 | 29 | data.index = data.index.map(lambda x: f"{x[0]}-{x[1]}") 30 | obs.index = obs.index.map(lambda x: f"{x[0]}-{x[1]}") 31 | 32 | if len(data) != len(obs): 33 | cell_ids = list(set(data.index) & set(obs.index)) 34 | data = data.loc[cell_ids] 35 | obs = obs.loc[cell_ids] 36 | 37 | obs.index = obs.index.astype(str) + f"_{slide_id}" 38 | data.index = obs.index 39 | 40 | is_gene = ~data.columns.str.lower().str.contains("SystemControl") 41 | 42 | adata = anndata.AnnData(data.loc[:, is_gene], obs=obs) 43 | 44 | adata.obsm["spatial"] = adata.obs[["CenterX_global_px", "CenterY_global_px"]].values * 0.120280945 45 | adata.obs["slide_id"] = pd.Series(slide_id, index=adata.obs_names, dtype="category") 46 | 47 | adata.X = csr_matrix(adata.X) 48 | adata.write_h5ad(res_path) 49 | 50 | print(f"Created file at path {res_path}") 51 | 52 | 53 | def main(args): 54 | path = Path(args.path).absolute() / "cosmx" 55 | 56 | print(f"Reading all datasets inside {path}") 57 | 58 | for dataset_dir in path.iterdir(): 59 | if dataset_dir.is_dir(): 60 | convert_to_h5ad(dataset_dir) 61 | 62 | 63 | if __name__ == "__main__": 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument( 66 | "-p", 67 | "--path", 68 | type=str, 69 | default=".", 70 | help="Path to spatial directory (containing the 'cosmx' directory)", 71 | ) 72 | 73 | main(parser.parse_args()) 74 | -------------------------------------------------------------------------------- /data/_scripts/cosmx_download.sh: -------------------------------------------------------------------------------- 1 | # Pancreas 2 | PANCREAS_FLAT_FILES="https://smi-public.objects.liquidweb.services/cosmx-wtx/Pancreas-CosMx-WTx-FlatFiles.zip" 3 | PANCREAS_OUTPUT_ZIP="cosmx/pancreas/Pancreas-CosMx-WTx-FlatFiles.zip" 4 | 5 | mkdir -p cosmx/pancreas 6 | 7 | if [ -f 
$PANCREAS_OUTPUT_ZIP ]; then 8 | echo "File $PANCREAS_OUTPUT_ZIP already exists." 9 | else 10 | echo "Downloading $PANCREAS_FLAT_FILES to $PANCREAS_OUTPUT_ZIP" 11 | curl $PANCREAS_FLAT_FILES -o $PANCREAS_OUTPUT_ZIP 12 | unzip $PANCREAS_OUTPUT_ZIP -d cosmx/pancreas 13 | fi 14 | 15 | # Normal Liver 16 | mkdir -p cosmx/normal_liver 17 | METADATA_FILE="https://nanostring.app.box.com/index.php?rm=box_download_shared_file&shared_name=id16si2dckxqqpilexl2zg90leo57grn&file_id=f_1392279064291" 18 | METADATA_OUTPUT="cosmx/normal_liver/metadata_file.csv" 19 | if [ -f $METADATA_OUTPUT ]; then 20 | echo "File $METADATA_OUTPUT already exists." 21 | else 22 | echo "Downloading $METADATA_FILE to $METADATA_OUTPUT" 23 | curl -L $METADATA_FILE -o $METADATA_OUTPUT 24 | fi 25 | COUNT_FILE="https://nanostring.app.box.com/index.php?rm=box_download_shared_file&shared_name=id16si2dckxqqpilexl2zg90leo57grn&file_id=f_1392318918584" 26 | COUNT_OUTPUT="cosmx/normal_liver/exprMat_file.csv" 27 | if [ -f $COUNT_OUTPUT ]; then 28 | echo "File $COUNT_OUTPUT already exists." 29 | else 30 | echo "Downloading $COUNT_FILE to $COUNT_OUTPUT" 31 | curl -L $COUNT_FILE -o $COUNT_OUTPUT 32 | fi 33 | 34 | # Cancer Liver 35 | mkdir -p cosmx/cancer_liver 36 | METADATA_FILE="https://nanostring.app.box.com/index.php?rm=box_download_shared_file&shared_name=id16si2dckxqqpilexl2zg90leo57grn&file_id=f_1392293795557" 37 | METADATA_OUTPUT="cosmx/cancer_liver/metadata_file.csv" 38 | if [ -f $METADATA_OUTPUT ]; then 39 | echo "File $METADATA_OUTPUT already exists." 40 | else 41 | echo "Downloading $METADATA_FILE to $METADATA_OUTPUT" 42 | curl -L $METADATA_FILE -o $METADATA_OUTPUT 43 | fi 44 | COUNT_FILE="https://nanostring.app.box.com/index.php?rm=box_download_shared_file&shared_name=id16si2dckxqqpilexl2zg90leo57grn&file_id=f_1392441469377" 45 | COUNT_OUTPUT="cosmx/cancer_liver/exprMat_file.csv" 46 | if [ -f $COUNT_OUTPUT ]; then 47 | echo "File $COUNT_OUTPUT already exists." 48 | else 49 | echo "Downloading $COUNT_FILE to $COUNT_OUTPUT" 50 | curl -L $COUNT_FILE -o $COUNT_OUTPUT 51 | fi 52 | -------------------------------------------------------------------------------- /data/_scripts/merscope_convert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import anndata 5 | import numpy as np 6 | import pandas as pd 7 | from scipy.sparse import csr_matrix 8 | 9 | 10 | def convert_to_h5ad(dataset_dir: Path): 11 | res_path = dataset_dir / "adata.h5ad" 12 | 13 | if res_path.exists(): 14 | print(f"File {res_path} already existing.") 15 | return 16 | 17 | region = "region_0" 18 | slide_id = f"{dataset_dir.name}_{region}" 19 | 20 | data_dir = dataset_dir / "cell_by_gene.csv" 21 | obs_dir = dataset_dir / "cell_metadata.csv" 22 | 23 | if not data_dir.exists() or not obs_dir.exists(): 24 | print(f"Did not found both csv inside {dataset_dir}. 
Skipping this directory.") 25 | return 26 | 27 | data = pd.read_csv(data_dir, index_col=0, dtype={"cell": str}) 28 | obs = pd.read_csv(obs_dir, index_col=0, dtype={"EntityID": str}) 29 | 30 | obs.index = obs.index.astype(str) + f"_{slide_id}" 31 | data.index = data.index.astype(str) + f"_{slide_id}" 32 | obs = obs.loc[data.index] 33 | 34 | is_gene = ~data.columns.str.lower().str.contains("blank") 35 | 36 | adata = anndata.AnnData(data.loc[:, is_gene], dtype=np.uint16, obs=obs) 37 | 38 | adata.obsm["spatial"] = adata.obs[["center_x", "center_y"]].values 39 | adata.obs["region"] = pd.Series(region, index=adata.obs_names, dtype="category") 40 | adata.obs["slide_id"] = pd.Series(slide_id, index=adata.obs_names, dtype="category") 41 | 42 | adata.X = csr_matrix(adata.X) 43 | adata.write_h5ad(res_path) 44 | 45 | print(f"Created file at path {res_path}") 46 | 47 | 48 | def main(args): 49 | path = Path(args.path).absolute() / "merscope" 50 | 51 | print(f"Reading all datasets inside {path}") 52 | 53 | for dataset_dir in path.iterdir(): 54 | if dataset_dir.is_dir(): 55 | convert_to_h5ad(dataset_dir) 56 | 57 | 58 | if __name__ == "__main__": 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument( 61 | "-p", 62 | "--path", 63 | type=str, 64 | default=".", 65 | help="Path to spatial directory (containing the 'merscope' directory)", 66 | ) 67 | 68 | main(parser.parse_args()) 69 | -------------------------------------------------------------------------------- /data/_scripts/merscope_download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BUCKET_NAME="vz-ffpe-showcase" 4 | OUTPUT_DIR="./merscope" 5 | 6 | for BUCKET_DIR in $(gsutil ls -d gs://$BUCKET_NAME/*); do 7 | DATASET_NAME=$(basename $BUCKET_DIR) 8 | OUTPUT_DATASET_DIR=$OUTPUT_DIR/$DATASET_NAME 9 | 10 | mkdir -p $OUTPUT_DATASET_DIR 11 | 12 | for BUCKET_FILE in ${BUCKET_DIR}{cell_by_gene,cell_metadata}.csv; do 13 | FILE_NAME=$(basename $BUCKET_FILE) 14 | 15 | if [ -f $OUTPUT_DATASET_DIR/$FILE_NAME ]; then 16 | echo "File $FILE_NAME already exists in $OUTPUT_DATASET_DIR" 17 | else 18 | echo "Copying $BUCKET_FILE to $OUTPUT_DATASET_DIR" 19 | gsutil cp "$BUCKET_FILE" $OUTPUT_DATASET_DIR 20 | echo "Copied successfully" 21 | fi 22 | done 23 | done 24 | -------------------------------------------------------------------------------- /data/_scripts/stereoseq_convert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import anndata 5 | import pandas as pd 6 | from scipy.sparse import csr_matrix 7 | 8 | 9 | def convert_to_h5ad(dataset_dir: Path): 10 | res_path = dataset_dir / "adata.h5ad" 11 | 12 | if res_path.exists(): 13 | print(f"File {res_path} already existing.") 14 | return 15 | 16 | slide_id = f"stereoseq_{dataset_dir.name}" 17 | 18 | h5ad_files = dataset_dir.glob(".h5ad") 19 | 20 | if len(h5ad_files) != 1: 21 | print(f"Found {len(h5ad_files)} h5ad file inside {dataset_dir}. 
Skipping this directory.") 22 | return 23 | 24 | adata = anndata.read_h5ad(h5ad_files[0]) 25 | adata.X = adata.layers["raw_counts"] 26 | del adata.layers["raw_counts"] 27 | 28 | adata.obsm["spatial"] = adata.obsm["spatial"].astype(float).values 29 | adata.obs["slide_id"] = pd.Series(slide_id, index=adata.obs_names, dtype="category") 30 | 31 | adata.X = csr_matrix(adata.X) 32 | adata.write_h5ad(res_path) 33 | 34 | print(f"Created file at path {res_path}") 35 | 36 | 37 | def main(args): 38 | path = Path(args.path).absolute() / "stereoseq" 39 | 40 | print(f"Reading all datasets inside {path}") 41 | 42 | for dataset_dir in path.iterdir(): 43 | if dataset_dir.is_dir(): 44 | convert_to_h5ad(dataset_dir) 45 | 46 | 47 | if __name__ == "__main__": 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument( 50 | "-p", 51 | "--path", 52 | type=str, 53 | default=".", 54 | help="Path to spatial directory (containing the 'stereoseq' directory)", 55 | ) 56 | 57 | main(parser.parse_args()) 58 | -------------------------------------------------------------------------------- /data/_scripts/stereoseq_download.sh: -------------------------------------------------------------------------------- 1 | H5AD_REMOTE_PATHS=(\ 2 | "https://ftp.cngb.org/pub/SciRAID/stomics/STDS0000062/stomics/FP200000498TL_D2_stereoseq.h5ad"\ 3 | "https://ftp.cngb.org/pub/SciRAID/stomics/STDS0000062/stomics/FP200000498TL_E4_stereoseq.h5ad"\ 4 | "https://ftp.cngb.org/pub/SciRAID/stomics/STDS0000062/stomics/FP200000498TL_E5_stereoseq.h5ad"\ 5 | ) 6 | 7 | OUTPUT_DIR="stereoseq" 8 | mkdir -p $OUTPUT_DIR 9 | 10 | for H5AD_REMOTE_PATH in "${H5AD_REMOTE_PATHS[@]}" 11 | do 12 | DATASET_NAME=$(basename $H5AD_REMOTE_PATH) 13 | OUTPUT_DATASET_DIR=${OUTPUT_DIR}/${DATASET_NAME%.h5ad} 14 | OUTPUT_DATASET=$OUTPUT_DATASET_DIR/${DATASET_NAME} 15 | 16 | mkdir -p $OUTPUT_DATASET_DIR 17 | 18 | if [ -f $OUTPUT_DATASET ]; then 19 | echo "File $OUTPUT_DATASET_DIR already exists" 20 | else 21 | echo "Downloading $H5AD_REMOTE_PATH to $OUTPUT_DATASET" 22 | # curl $H5AD_REMOTE_PATH -o $OUTPUT_DATASET 23 | echo "Successfully downloaded" 24 | fi 25 | done 26 | -------------------------------------------------------------------------------- /data/_scripts/xenium_convert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import anndata 5 | import pandas as pd 6 | from spatialdata_io.readers.xenium import _get_tables_and_circles 7 | 8 | 9 | def convert_to_h5ad(dataset_dir: Path): 10 | res_path = dataset_dir / "adata.h5ad" 11 | 12 | if res_path.exists(): 13 | print(f"File {res_path} already existing.") 14 | return 15 | 16 | adata: anndata.AnnData = _get_tables_and_circles(dataset_dir, False, {"region": "region_0"}) 17 | adata.obs["cell_id"] = adata.obs["cell_id"].apply(lambda x: x if (isinstance(x, (str, int))) else x.decode("utf-8")) 18 | 19 | slide_id = dataset_dir.name 20 | adata.obs.index = adata.obs["cell_id"].astype(str).values + f"_{slide_id}" 21 | 22 | adata.obs["slide_id"] = pd.Series(slide_id, index=adata.obs_names, dtype="category") 23 | 24 | adata.write_h5ad(res_path) 25 | 26 | print(f"Created file at path {res_path}") 27 | 28 | 29 | def main(args): 30 | path = Path(args.path).absolute() / "xenium" 31 | 32 | print(f"Reading all datasets inside {path}") 33 | 34 | for dataset_dir in path.iterdir(): 35 | if dataset_dir.is_dir(): 36 | print(f"In {dataset_dir}") 37 | try: 38 | convert_to_h5ad(dataset_dir) 39 | except: 40 | print(f"Failed to convert {dataset_dir}") 
41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument( 46 | "-p", 47 | "--path", 48 | type=str, 49 | default=".", 50 | help="Path to spatial directory (containing the 'xenium' directory)", 51 | ) 52 | 53 | main(parser.parse_args()) 54 | -------------------------------------------------------------------------------- /docs/api/Novae.md: -------------------------------------------------------------------------------- 1 | ::: novae.Novae 2 | -------------------------------------------------------------------------------- /docs/api/dataloader.md: -------------------------------------------------------------------------------- 1 | ::: novae.data.AnnDataTorch 2 | 3 | ::: novae.data.NovaeDataset 4 | 5 | ::: novae.data.NovaeDatamodule 6 | -------------------------------------------------------------------------------- /docs/api/metrics.md: -------------------------------------------------------------------------------- 1 | ::: novae.monitor.jensen_shannon_divergence 2 | 3 | ::: novae.monitor.fide_score 4 | 5 | ::: novae.monitor.mean_fide_score 6 | -------------------------------------------------------------------------------- /docs/api/modules.md: -------------------------------------------------------------------------------- 1 | ::: novae.module.AttentionAggregation 2 | 3 | ::: novae.module.CellEmbedder 4 | 5 | ::: novae.module.GraphAugmentation 6 | 7 | ::: novae.module.GraphEncoder 8 | 9 | ::: novae.module.SwavHead 10 | -------------------------------------------------------------------------------- /docs/api/plot.md: -------------------------------------------------------------------------------- 1 | ::: novae.plot.domains 2 | 3 | ::: novae.plot.domains_proportions 4 | 5 | ::: novae.plot.connectivities 6 | 7 | ::: novae.plot.pathway_scores 8 | 9 | ::: novae.plot.paga 10 | 11 | ::: novae.plot.spatially_variable_genes 12 | -------------------------------------------------------------------------------- /docs/api/utils.md: -------------------------------------------------------------------------------- 1 | ::: novae.spatial_neighbors 2 | 3 | ::: novae.batch_effect_correction 4 | 5 | ::: novae.utils.quantile_scaling 6 | 7 | ::: novae.utils.prepare_adatas 8 | 9 | ::: novae.utils.load_dataset 10 | 11 | ::: novae.utils.toy_dataset 12 | -------------------------------------------------------------------------------- /docs/assets/Figure1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MICS-Lab/novae/b641a0bd1947759317c9e7ab4d997bb7a4a00932/docs/assets/Figure1.png -------------------------------------------------------------------------------- /docs/assets/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MICS-Lab/novae/b641a0bd1947759317c9e7ab4d997bb7a4a00932/docs/assets/banner.png -------------------------------------------------------------------------------- /docs/assets/logo_favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MICS-Lab/novae/b641a0bd1947759317c9e7ab4d997bb7a4a00932/docs/assets/logo_favicon.png -------------------------------------------------------------------------------- /docs/assets/logo_small_black.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MICS-Lab/novae/b641a0bd1947759317c9e7ab4d997bb7a4a00932/docs/assets/logo_small_black.png 
-------------------------------------------------------------------------------- /docs/assets/logo_white.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MICS-Lab/novae/b641a0bd1947759317c9e7ab4d997bb7a4a00932/docs/assets/logo_white.png -------------------------------------------------------------------------------- /docs/cite_us.md: -------------------------------------------------------------------------------- 1 | You can cite our [preprint](https://www.biorxiv.org/content/10.1101/2024.09.09.612009v1) as below: 2 | 3 | ```txt 4 | @article{blampeyNovae2024, 5 | title = {Novae: A Graph-Based Foundation Model for Spatial Transcriptomics Data}, 6 | author = {Blampey, Quentin and Benkirane, Hakim and Bercovici, Nadege and Andre, Fabrice and Cournede, Paul-Henry}, 7 | year = {2024}, 8 | pages = {2024.09.09.612009}, 9 | publisher = {bioRxiv}, 10 | doi = {10.1101/2024.09.09.612009}, 11 | } 12 | ``` 13 | 14 | This library has been developed by Quentin Blampey, PhD student in biomathematics and deep learning. The following institutions funded this work: 15 | 16 | - Lab of Mathematics and Computer Science (MICS), **CentraleSupélec** (Engineering School, Paris-Saclay University). 17 | - PRISM center, **Gustave Roussy Institute** (Cancer campus, Paris-Saclay University). 18 | -------------------------------------------------------------------------------- /docs/faq.md: -------------------------------------------------------------------------------- 1 | # Frequently asked questions 2 | 3 | ### How to use the GPU? 4 | 5 | Using a GPU may significantly speed up Novae's training or inference. 6 | 7 | If you have a valid GPU for PyTorch, you can set the `accelerator` argument (e.g., one of `["cpu", "gpu", "tpu", "hpu", "mps", "auto"]`) in the following methods: [model.fit()](../api/Novae/#novae.Novae.fit), [model.fine_tune()](../api/Novae/#novae.Novae.fine_tune), [model.compute_representations()](../api/Novae/#novae.Novae.compute_representations). 8 | 9 | When using a GPU, we also highly recommend setting multiple workers to speed up the dataset `__getitem__`. For that, you'll need to set the `num_workers` argument in the previous methods, according to the number of CPUs available (`num_workers=8` is usually a good value). 10 | 11 | For more details, refer to the API of the [PyTorch Lightning Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api) and to the API of the [PyTorch DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader). 12 | 13 | ### How to load a pretrained model? 14 | 15 | We highly recommend loading a pre-trained Novae model instead of re-training from scratch. For that, choose an available Novae model name on [our HuggingFace collection](https://huggingface.co/collections/MICS-Lab/novae-669cdf1754729d168a69f6bd), and provide this name to the [model.from_pretrained()](../api/Novae/#novae.Novae.from_pretrained) method: 16 | 17 | ```python 18 | from novae import Novae 19 | 20 | model = Novae.from_pretrained("MICS-Lab/novae-human-0") # or any valid model name 21 | ``` 22 | 23 | ### How to avoid overcorrecting? 24 | 25 | By default, Novae corrects batch effects to get shared spatial domains across slides. 26 | The batch information is used only during training (`fit` or `fine_tune`), which should prevent Novae from overcorrecting in `zero_shot` mode. 
27 | 28 | If not using the `zero_shot` mode, you can provide the `min_prototypes_ratio` parameter to control batch-effect correction: either (i) in the `fine_tune` method itself, or (ii) during the model initialization (if retraining a model from scratch). 29 | 30 | For instance, if `min_prototypes_ratio=0.5`, Novae expects each slide to contain at least 50% of the prototypes (each prototype can be interpreted as an "elementary spatial domain"). Therefore, the lower `min_prototypes_ratio`, the lower the batch-effect correction. Conversely, if `min_prototypes_ratio=1`, all prototypes are expected to be found in all slides (this doesn't mean the proportions will be the same across all slides, though). 31 | 32 | ### How do I save my own model? 33 | 34 | If you have trained or fine-tuned your own Novae model, you can save it for later use. For that, use the [model.save_pretrained()](../api/Novae/#novae.Novae.save_pretrained) method as below: 35 | 36 | ```python 37 | model.save_pretrained(save_directory="./my-model-directory") 38 | ``` 39 | 40 | Then, you can load this model back via the [model.from_pretrained()](../api/Novae/#novae.Novae.from_pretrained) method: 41 | 42 | ```python 43 | from novae import Novae 44 | 45 | model = Novae.from_pretrained("./my-model-directory") 46 | ``` 47 | 48 | ### How to turn lazy loading on or off? 49 | 50 | By default, lazy loading is used only on large datasets. To enforce a specific behavior, you can do the following: 51 | 52 | ```python 53 | # never use lazy loading 54 | novae.settings.disable_lazy_loading() 55 | 56 | # always use lazy loading 57 | novae.settings.enable_lazy_loading() 58 | 59 | # use lazy loading only for AnnData objects with 1M+ cells 60 | novae.settings.enable_lazy_loading(n_obs_threshold=1_000_000) 61 | ``` 62 | 63 | ### How to update the logging level? 64 | 65 | The logging level can be updated as below: 66 | 67 | ```python 68 | import logging 69 | from novae import log 70 | 71 | log.setLevel(logging.ERROR) # or any other level, e.g. logging.DEBUG 72 | ``` 73 | 74 | ### How to disable auto-preprocessing? 75 | 76 | By default, Novae automatically runs data preprocessing for you. If you don't want that, you can run the line below. 77 | 78 | ```python 79 | novae.settings.auto_preprocessing = False 80 | ``` 81 | 82 | ### How long does it take to use Novae? 83 | 84 | The `pip` installation of Novae usually takes less than a minute on a standard laptop. The inference time depends on the number of cells, but it typically takes 5-20 minutes on a CPU, or 30 seconds to 2 minutes on a GPU (expect it to be roughly 10x faster on a GPU). 85 | 86 | ### How to contribute? 87 | 88 | If you want to contribute, check our [contributing guide](https://github.com/MICS-Lab/novae/blob/main/CONTRIBUTING.md). 89 | 90 | ### How to resolve any other issue? 91 | 92 | If you have any bugs/questions/suggestions, don't hesitate to [open a new issue](https://github.com/MICS-Lab/novae/issues). 93 | -------------------------------------------------------------------------------- /docs/getting_started.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | Novae can be installed on every OS via `pip` or [`poetry`](https://python-poetry.org/docs/), on any Python version from `3.10` to `3.12` (inclusive). By default, we recommend using `python==3.10`. 4 | 5 | 
note "Advice (optional)" 6 | 7 | We advise creating a new environment via a package manager, except if you use Poetry, which will automatically create the environment. 8 | 9 | For instance, you can create a new `conda` environment: 10 | 11 | ```bash 12 | conda create --name novae python=3.10 13 | conda activate novae 14 | ``` 15 | 16 | Choose one of the following, depending on your needs. It should take at most a few minutes. 17 | 18 | === "From PyPI" 19 | 20 | ``` bash 21 | pip install novae 22 | ``` 23 | 24 | === "pip (editable mode)" 25 | 26 | ``` bash 27 | git clone https://github.com/MICS-Lab/novae.git 28 | cd novae 29 | 30 | pip install -e . # no extra 31 | pip install -e '.[dev]' # all extras 32 | ``` 33 | 34 | === "Poetry (editable mode)" 35 | 36 | ``` bash 37 | git clone https://github.com/MICS-Lab/novae.git 38 | cd novae 39 | 40 | poetry install --all-extras 41 | ``` 42 | 43 | ## Next steps 44 | 45 | - We recommend to start with our [first tutorial](../tutorials/main_usage). 46 | - You can also read the [API](../api/Novae). 47 | - If you have questions, please check our [FAQ](../faq) or open an issue on the [GitHub repository](https://github.com/MICS-Lab/novae). 48 | - If you want to contribute, check our [contributing guide](https://github.com/MICS-Lab/novae/blob/main/CONTRIBUTING.md). 49 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 |
2 | *(logo: novae_logo)* 3 | 4 | 5 | 6 | ### 💫 Graph-based foundation model for spatial transcriptomics data 7 | 8 | 9 | Novae is a deep learning model for spatial domain assignment of spatial transcriptomics data (at both single-cell and spot resolution). It works across multiple gene panels, tissues, and technologies. Novae offers several additional features, including: (i) native batch-effect correction, (ii) analysis of spatially variable genes and pathways, and (iii) architecture analysis of tissue slides. 10 | 11 | ## Overview 12 | 13 | 14 | *(overview figure: novae_overview)* 15 | 
16 | 17 | > **(a)** Novae was trained on a large dataset, and is shared on [Hugging Face Hub](https://huggingface.co/collections/MICS-Lab/novae-669cdf1754729d168a69f6bd). **(b)** Illustration of the main tasks and properties of Novae. **(c)** Illustration of the method behind Novae (self-supervision on graphs, adapted from [SwAV](https://arxiv.org/abs/2006.09882)). 18 | 19 | 20 | ## Why use Novae 21 | 22 | - It is already pretrained on a large dataset (pan human/mouse tissues, brain, ...). Therefore, you can compute spatial domains in a zero-shot manner (i.e., without fine-tuning). 23 | - It has been developed to find consistent domains across many slides. This also works if you have different technologies (e.g., MERSCOPE/Xenium) and multiple gene panels. 24 | - You can natively correct batch effects, without using external tools. 25 | - After inference, the spatial domain assignment is very fast, allowing you to try multiple resolutions easily. 26 | - It supports many downstream tasks, all included inside one framework. 27 | -------------------------------------------------------------------------------- /docs/javascripts/mathjax.js: -------------------------------------------------------------------------------- 1 | window.MathJax = { 2 | tex: { 3 | inlineMath: [["\\(", "\\)"]], 4 | displayMath: [["\\[", "\\]"]], 5 | processEscapes: true, 6 | processEnvironments: true 7 | }, 8 | options: { 9 | ignoreHtmlClass: ".*|", 10 | processHtmlClass: "arithmatex" 11 | } 12 | }; 13 | 14 | document$.subscribe(() => { 15 | MathJax.typesetPromise() 16 | }) 17 | -------------------------------------------------------------------------------- /docs/tutorials/LEE_AGING_CEREBELLUM_UP.json: -------------------------------------------------------------------------------- 1 | { 2 | "LEE_AGING_CEREBELLUM_UP" : {"systematicName":"MM1023","pmid":"10888876","exactSource":"Table 5S","geneSymbols":["Acadvl","Agt","Amh","Apc","Apoe","Axl","B2m","Bcl2a1a","Bdnf","C1qa","C1qb","C1qc","C4b","Capn2","Ccl21b","Cd24a","Cd68","Cdk4","Cst7","Ctsd","Ctsh","Ctss","Ctsz","Dnajb2","Efs","Eif2b5","Eprs1","Eps15","F2","Fcrl2","Fos","Gbp3","Gck","Gfap","Gnb2","Gng11","H2-Ab1","Hexb","Hmox1","Hnrnph3","Hoxa4","Hoxd12","Hspa8","Iars1","Ifi27","Ifit1","Ighm","Impa1","Irf7","Irgm1","Itgb5","Lgals3","Lgals3bp","Lmnb1","Mnat1","Mpeg1","Myh8","Nfya","Nos3","Notch1","Nr4a1","Or2c1","Pglyrp1","Pigf","Psg-ps1","Ptbp2","Ptpro","Rhog","Rpsa","Selplg","Sez6","Sipa1l2","Slc11a1","Slc7a3","Snta1","Spp1","Tbc1d1","Tbx6","Tgfbr3","Thbs2","Trappc5","Trim30a","Tyms","Ube2h","Wdfy3","Zfp40"],"msigdbURL":"https://www.gsea-msigdb.org/gsea/msigdb/mouse/geneset/LEE_AGING_CEREBELLUM_UP","externalDetailsURL":[],"filteredBySimilarity":[],"externalNamesForSimilarTerms":[],"collection":"M2:CGP"} 3 | } 4 | -------------------------------------------------------------------------------- /docs/tutorials/input_modes.md: -------------------------------------------------------------------------------- 1 | Depending on your data and preferences, you can use four types of inputs. 2 | Specifically, it depends on whether (i) you have one or multiple slides, and (ii) whether you prefer to concatenate your data. 3 | 4 | !!! info 5 | In all cases, the data structure is [AnnData](https://anndata.readthedocs.io/en/latest/). We may support MuData in the future. 6 | 7 | ## 1. One slide mode 8 | 9 | This case is the easiest one. You simply have one `AnnData` object corresponding to one slide. 10 | 11 | You can follow the first section of the [main usage tutorial](../main_usage); a minimal sketch is also shown below. 
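For a single slide, the whole flow takes only a few lines. The sketch below assumes your slide is already loaded into an `AnnData` object named `adata`, and uses one of the pretrained model names from the Hugging Face collection:

```python
import novae

# build the spatial neighbors graph (no `slide_key` needed for a single slide)
novae.spatial_neighbors(adata)

# zero-shot inference with a pretrained model
model = novae.Novae.from_pretrained("MICS-Lab/novae-human-0")
model.compute_representations(adata, zero_shot=True)
model.assign_domains(adata)
```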
12 | 13 | ## 2. Multiple slides, one AnnData object 14 | 15 | If you have multiple slides with the same gene panel, you can concatenate them into one `AnnData` object. In that case, make sure you keep a column in `adata.obs` that denotes which cell corresponds to which slide. 16 | 17 | Then, keep this column in mind, and pass it to [`novae.spatial_neighbors`](../../api/utils/#novae.spatial_neighbors). 18 | 19 | !!! example 20 | For instance, you can do: 21 | ```python 22 | novae.spatial_neighbors(adata, slide_key="my-slide-id-column") 23 | ``` 24 | 25 | ## 3. Multiple slides, one AnnData object per slide 26 | 27 | If you have multiple slides, you may prefer to keep one `AnnData` object for each slide. This is also convenient if you have different gene panels and can't concatenate your data. 28 | 29 | That case is pretty easy, since most functions and methods of Novae also support a **list of `AnnData` objects** as inputs. Therefore, simply pass a list of `AnnData` objects, as below: 30 | 31 | !!! example 32 | ```python 33 | adatas = [adata_1, adata_2, ...] 34 | 35 | novae.spatial_neighbors(adatas) 36 | 37 | model.compute_representations(adatas, zero_shot=True) 38 | ``` 39 | 40 | ## 4. Multiple panels, multiple slides per AnnData object 41 | 42 | If you have multiple slides and multiple panels, instead of the above option, you could have one `AnnData` object per panel, and multiple slides inside each `AnnData` object. In that case, make sure you keep a column in `adata.obs` that denotes which cell corresponds to which slide. 43 | 44 | Then, keep this column in mind, and pass it to [`novae.spatial_neighbors`](../../api/utils/#novae.spatial_neighbors). The other functions don't need this argument. 45 | 46 | !!! example 47 | For instance, you can do: 48 | ```python 49 | adatas = [adata_1, adata_2, ...] 
50 | 51 | novae.spatial_neighbors(adatas, slide_key="my-slide-id-column") 52 | 53 | model.compute_representations(adatas, zero_shot=True) 54 | ``` 55 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Novae 2 | repo_name: MICS-Lab/novae 3 | repo_url: https://github.com/MICS-Lab/novae 4 | copyright: Copyright © 2024 Quentin Blampey 5 | theme: 6 | name: material 7 | logo: assets/logo_small_black.png 8 | favicon: assets/logo_favicon.png 9 | palette: 10 | scheme: slate 11 | primary: white 12 | nav: 13 | - Home: index.md 14 | - Getting started: getting_started.md 15 | - Tutorials: 16 | - Main usage: tutorials/main_usage.ipynb 17 | - Different input modes: tutorials/input_modes.md 18 | - Usage on proteins: tutorials/proteins.ipynb 19 | - Spot/bin technologies: tutorials/resolutions.ipynb 20 | - API: 21 | - Novae model: api/Novae.md 22 | - Utils: api/utils.md 23 | - Plotting: api/plot.md 24 | - Advanced: 25 | - Metrics: api/metrics.md 26 | - Modules: api/modules.md 27 | - Dataloader: api/dataloader.md 28 | - FAQ: faq.md 29 | - Cite us: cite_us.md 30 | 31 | plugins: 32 | - search 33 | - mkdocstrings: 34 | handlers: 35 | python: 36 | options: 37 | show_root_heading: true 38 | heading_level: 3 39 | - mkdocs-jupyter: 40 | include_source: True 41 | markdown_extensions: 42 | - admonition 43 | - attr_list 44 | - md_in_html 45 | - pymdownx.details 46 | - pymdownx.highlight: 47 | anchor_linenums: true 48 | - pymdownx.inlinehilite 49 | - pymdownx.snippets 50 | - pymdownx.superfences 51 | - pymdownx.arithmatex: 52 | generic: true 53 | - pymdownx.tabbed: 54 | alternate_style: true 55 | extra_javascript: 56 | - javascripts/mathjax.js 57 | - https://polyfill.io/v3/polyfill.min.js?features=es6 58 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js 59 | -------------------------------------------------------------------------------- /novae/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | import logging 3 | 4 | __version__ = importlib.metadata.version("novae") 5 | 6 | from ._logging import configure_logger 7 | from ._settings import settings 8 | from .model import Novae 9 | from . import utils 10 | from . import data 11 | from . import monitor 12 | from . 
import plot 13 | from .utils import spatial_neighbors, batch_effect_correction 14 | 15 | log = logging.getLogger("novae") 16 | configure_logger(log) 17 | -------------------------------------------------------------------------------- /novae/_constants.py: -------------------------------------------------------------------------------- 1 | class Keys: 2 | # obs keys 3 | LEAVES: str = "novae_leaves" 4 | DOMAINS_PREFIX: str = "novae_domains_" 5 | IS_VALID_OBS: str = "neighborhood_valid" 6 | SLIDE_ID: str = "novae_sid" 7 | 8 | # obsm keys 9 | REPR: str = "novae_latent" 10 | REPR_CORRECTED: str = "novae_latent_corrected" 11 | 12 | # obsp keys 13 | ADJ: str = "spatial_distances" 14 | ADJ_LOCAL: str = "spatial_distances_local" 15 | ADJ_PAIR: str = "spatial_distances_pair" 16 | 17 | # var keys 18 | VAR_MEAN: str = "mean" 19 | VAR_STD: str = "std" 20 | IS_KNOWN_GENE: str = "in_vocabulary" 21 | HIGHLY_VARIABLE: str = "highly_variable" 22 | USE_GENE: str = "novae_use_gene" 23 | 24 | # layer keys 25 | COUNTS_LAYER: str = "counts" 26 | 27 | # misc keys 28 | UNS_TISSUE: str = "novae_tissue" 29 | ADATA_INDEX: str = "adata_index" 30 | N_BATCHES: str = "n_batches" 31 | NOVAE_VERSION: str = "novae_version" 32 | 33 | 34 | class Nums: 35 | # training constants 36 | EPS: float = 1e-8 37 | MIN_DATASET_LENGTH: int = 50_000 38 | MAX_DATASET_LENGTH_RATIO: float = 0.02 39 | DEFAULT_SAMPLE_CELLS: int = 100_000 40 | WARMUP_EPOCHS: int = 1 41 | 42 | # distances constants and thresholds (in microns) 43 | CELLS_CHARACTERISTIC_DISTANCE: int = 20 # characteristic distance between two cells, in microns 44 | MAX_MEAN_DISTANCE_RATIO: float = 8 45 | 46 | # genes constants 47 | N_HVG_THRESHOLD: int = 500 48 | MIN_GENES_FOR_HVG: int = 100 49 | MIN_GENES: int = 20 50 | 51 | # swav head constants 52 | SWAV_EPSILON: float = 0.05 53 | SINKHORN_ITERATIONS: int = 3 54 | QUEUE_SIZE: int = 2 55 | QUEUE_WEIGHT_THRESHOLD_RATIO: float = 0.99 56 | 57 | # misc nums 58 | MEAN_NGH_TH_WARNING: float = 3.5 59 | N_OBS_THRESHOLD: int = 2_000_000 # above this number, lazy loading is used 60 | RATIO_VALID_CELLS_TH: float = 0.7 61 | -------------------------------------------------------------------------------- /novae/_logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | log = logging.getLogger(__name__) 4 | 5 | 6 | class ColorFormatter(logging.Formatter): 7 | grey = "\x1b[38;20m" 8 | blue = "\x1b[36;20m" 9 | yellow = "\x1b[33;20m" 10 | red = "\x1b[31;20m" 11 | bold_red = "\x1b[31;1m" 12 | reset = "\x1b[0m" 13 | 14 | prefix = "[%(levelname)s] (%(name)s)" 15 | suffix = "%(message)s" 16 | 17 | FORMATS = { 18 | logging.DEBUG: f"{grey}{prefix}{reset} {suffix}", 19 | logging.INFO: f"{blue}{prefix}{reset} {suffix}", 20 | logging.WARNING: f"{yellow}{prefix}{reset} {suffix}", 21 | logging.ERROR: f"{red}{prefix}{reset} {suffix}", 22 | logging.CRITICAL: f"{bold_red}{prefix}{reset} {suffix}", 23 | } 24 | 25 | def format(self, record): 26 | log_fmt = self.FORMATS.get(record.levelno) 27 | formatter = logging.Formatter(log_fmt) 28 | return formatter.format(record) 29 | 30 | 31 | def configure_logger(log: logging.Logger): 32 | log.setLevel(logging.INFO) 33 | 34 | consoleHandler = logging.StreamHandler() 35 | consoleHandler.setFormatter(ColorFormatter()) 36 | 37 | log.addHandler(consoleHandler) 38 | log.propagate = False 39 | -------------------------------------------------------------------------------- /novae/_settings.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ._constants import Nums 4 | 5 | 6 | class Settings: 7 | # misc settings 8 | auto_preprocessing: bool = True 9 | 10 | def disable_lazy_loading(self): 11 | """Disable lazy loading of subgraphs in the NovaeDataset.""" 12 | Nums.N_OBS_THRESHOLD = np.inf 13 | 14 | def enable_lazy_loading(self, n_obs_threshold: int = 0): 15 | """Enable lazy loading of subgraphs in the NovaeDataset. 16 | 17 | Args: 18 | n_obs_threshold: Lazy loading is used above this number of cells in an AnnData object. 19 | """ 20 | Nums.N_OBS_THRESHOLD = n_obs_threshold 21 | 22 | @property 23 | def warmup_epochs(self): 24 | return Nums.WARMUP_EPOCHS 25 | 26 | @warmup_epochs.setter 27 | def warmup_epochs(self, value: int): 28 | Nums.WARMUP_EPOCHS = value 29 | 30 | 31 | settings = Settings() 32 | -------------------------------------------------------------------------------- /novae/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .convert import AnnDataTorch 2 | from .dataset import NovaeDataset 3 | from .datamodule import NovaeDatamodule 4 | -------------------------------------------------------------------------------- /novae/data/convert.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from anndata import AnnData 4 | from sklearn.preprocessing import LabelEncoder 5 | from torch import Tensor 6 | 7 | from .._constants import Keys, Nums 8 | from ..module import CellEmbedder 9 | from ..utils import sparse_std 10 | 11 | 12 | class AnnDataTorch: 13 | tensors: list[Tensor] | None 14 | genes_indices_list: list[Tensor] 15 | 16 | def __init__(self, adatas: list[AnnData], cell_embedder: CellEmbedder): 17 | """Converting AnnData objects to PyTorch tensors. 18 | 19 | Args: 20 | adatas: A list of `AnnData` objects. 21 | cell_embedder: A [novae.module.CellEmbedder][] object. 
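        Example:
            A minimal usage sketch (assuming the `adatas` were already preprocessed with `novae.utils.prepare_adatas`, so that the gene-selection and slide-ID columns exist):

                >>> adata_torch = AnnDataTorch(adatas, cell_embedder)
                >>> x, genes_indices = adata_torch[(0, slice(0, 64))]  # expressions of the first 64 cells of the first AnnData object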
22 | """ 23 | super().__init__() 24 | self.adatas = adatas 25 | self.cell_embedder = cell_embedder 26 | 27 | self.genes_indices_list = [self._adata_to_genes_indices(adata) for adata in self.adatas] 28 | self.tensors = None 29 | 30 | self.means, self.stds, self.label_encoder = self._compute_means_stds() 31 | 32 | # Tensors are loaded in memory for low numbers of cells 33 | if sum(adata.n_obs for adata in self.adatas) < Nums.N_OBS_THRESHOLD: 34 | self.tensors = [self.to_tensor(adata) for adata in self.adatas] 35 | 36 | def _adata_to_genes_indices(self, adata: AnnData) -> Tensor: 37 | return self.cell_embedder.genes_to_indices(adata.var_names[self._keep_var(adata)])[None, :] 38 | 39 | def _keep_var(self, adata: AnnData):  # returns the boolean pandas Series of the genes to keep 40 | return adata.var[Keys.USE_GENE] 41 | 42 | def _compute_means_stds(self) -> tuple[Tensor, Tensor, LabelEncoder]: 43 | means, stds = {}, {} 44 | 45 | for adata in self.adatas: 46 | slide_ids = adata.obs[Keys.SLIDE_ID] 47 | for slide_id in slide_ids.cat.categories: 48 | adata_slide = adata[adata.obs[Keys.SLIDE_ID] == slide_id, self._keep_var(adata)] 49 | 50 | mean = adata_slide.X.mean(0) 51 | mean = mean.A1 if isinstance(mean, np.matrix) else mean 52 | means[slide_id] = mean.astype(np.float32) 53 | 54 | std = adata_slide.X.std(0) if isinstance(adata_slide.X, np.ndarray) else sparse_std(adata_slide.X, 0).A1 55 | stds[slide_id] = std.astype(np.float32) 56 | 57 | label_encoder = LabelEncoder() 58 | label_encoder.fit(list(means.keys())) 59 | 60 | means = [torch.tensor(means[slide_id]) for slide_id in label_encoder.classes_] 61 | stds = [torch.tensor(stds[slide_id]) for slide_id in label_encoder.classes_] 62 | 63 | return means, stds, label_encoder 64 | 65 | def to_tensor(self, adata: AnnData) -> Tensor: 66 | """Get the normalized gene expressions of the cells in the dataset. 67 | Only the genes of interest are kept (known genes and highly variable). 68 | 69 | Args: 70 | adata: An `AnnData` object. 71 | 72 | Returns: 73 | A `Tensor` containing the normalized gene expressions. 74 | """ 75 | adata = adata[:, self._keep_var(adata)] 76 | 77 | if len(np.unique(adata.obs[Keys.SLIDE_ID])) == 1: 78 | slide_id_index = self.label_encoder.transform([adata.obs.iloc[0][Keys.SLIDE_ID]])[0] 79 | mean, std = self.means[slide_id_index], self.stds[slide_id_index] 80 | else: 81 | slide_id_indices = self.label_encoder.transform(adata.obs[Keys.SLIDE_ID]) 82 | mean = torch.stack([self.means[i] for i in slide_id_indices]) # TODO: avoid stack (only if not fast enough) 83 | std = torch.stack([self.stds[i] for i in slide_id_indices]) 84 | 85 | X = adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray() 86 | X = torch.tensor(X, dtype=torch.float32) 87 | X = (X - mean) / (std + Nums.EPS) 88 | 89 | return X 90 | 91 | def __getitem__(self, item: tuple[int, slice]) -> tuple[Tensor, Tensor]: 92 | """Get the expression values for a subset of cells (corresponding to a subgraph). 93 | 94 | Args: 95 | item: A `tuple` containing the index of the `AnnData` object and the indices of the cells in the neighborhoods. 96 | 97 | Returns: 98 | A `Tensor` of normalized gene expressions and a `Tensor` of gene indices.
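        Note:
            If the tensors were preloaded in memory (i.e., less than `Nums.N_OBS_THRESHOLD` cells in total), this is a simple tensor lookup. Otherwise, the expressions are computed on the fly from a view of the corresponding `AnnData` object (lazy loading).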
99 | """ 100 | adata_index, obs_indices = item 101 | 102 | if self.tensors is not None: 103 | return self.tensors[adata_index][obs_indices], self.genes_indices_list[adata_index] 104 | 105 | adata = self.adatas[adata_index] 106 | adata_view = adata[obs_indices] 107 | 108 | return self.to_tensor(adata_view), self.genes_indices_list[adata_index] 109 | -------------------------------------------------------------------------------- /novae/data/datamodule.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | from anndata import AnnData 3 | from torch_geometric.loader import DataLoader 4 | 5 | from ..module import CellEmbedder 6 | from . import NovaeDataset 7 | 8 | 9 | class NovaeDatamodule(L.LightningDataModule): 10 | """ 11 | Datamodule used for training and inference. Small wrapper around the [novae.data.NovaeDataset][]. 12 | """ 13 | 14 | def __init__( 15 | self, 16 | adatas: list[AnnData], 17 | cell_embedder: CellEmbedder, 18 | batch_size: int, 19 | n_hops_local: int, 20 | n_hops_view: int, 21 | num_workers: int = 0, 22 | sample_cells: int | None = None, 23 | ) -> None: 24 | super().__init__() 25 | self.dataset = NovaeDataset( 26 | adatas, 27 | cell_embedder=cell_embedder, 28 | batch_size=batch_size, 29 | n_hops_local=n_hops_local, 30 | n_hops_view=n_hops_view, 31 | sample_cells=sample_cells, 32 | ) 33 | self.batch_size = batch_size 34 | self.num_workers = num_workers 35 | 36 | def train_dataloader(self) -> DataLoader: 37 | """Get a Pytorch dataloader for training. 38 | 39 | Returns: 40 | The training dataloader. 41 | """ 42 | self.dataset.training = True 43 | return DataLoader( 44 | self.dataset, 45 | batch_size=self.batch_size, 46 | shuffle=False, 47 | drop_last=True, 48 | num_workers=self.num_workers, 49 | ) 50 | 51 | def predict_dataloader(self) -> DataLoader: 52 | """Get a Pytorch dataloader for prediction or inference. 53 | 54 | Returns: 55 | The prediction dataloader. 56 | """ 57 | self.dataset.training = False 58 | return DataLoader( 59 | self.dataset, 60 | batch_size=self.batch_size, 61 | shuffle=False, 62 | drop_last=False, 63 | num_workers=self.num_workers, 64 | ) 65 | -------------------------------------------------------------------------------- /novae/module/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregate import AttentionAggregation 2 | from .embed import CellEmbedder 3 | from .augment import GraphAugmentation 4 | from .encode import GraphEncoder 5 | from .swav import SwavHead 6 | -------------------------------------------------------------------------------- /novae/module/aggregate.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | from torch import Tensor, nn 3 | from torch_geometric.nn.aggr import Aggregation 4 | from torch_geometric.nn.inits import reset 5 | from torch_geometric.utils import softmax 6 | 7 | 8 | class AttentionAggregation(Aggregation, L.LightningModule): 9 | """Aggregate the node embeddings using attention.""" 10 | 11 | def __init__(self, output_size: int): 12 | """ 13 | 14 | Args: 15 | output_size: Size of the representations, i.e. the encoder outputs (`O` in the article).
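        Example:
            A toy sketch (sizes are illustrative):

                >>> import torch
                >>> aggregation = AttentionAggregation(output_size=64)
                >>> x = torch.randn(30, 64)  # node embeddings of 3 graphs with 10 nodes each
                >>> index = torch.repeat_interleave(torch.arange(3), 10)
                >>> aggregation(x, index=index).shape
                torch.Size([3, 64])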
16 | """ 17 | super().__init__() 18 | self.attention_aggregation = ProjectionLayers(output_size) # for backward compatibility when loading models 19 | 20 | def forward( 21 | self, 22 | x: Tensor, 23 | index: Tensor | None = None, 24 | ptr: None = None, 25 | dim_size: None = None, 26 | dim: int = -2, 27 | ) -> Tensor: 28 | """Performs attention aggregation. 29 | 30 | Args: 31 | x: The nodes embeddings representing `B` total graphs. 32 | index: The Pytorch Geometric index used to know to which graph each node belongs. 33 | 34 | Returns: 35 | A tensor of shape `(B, O)` of graph embeddings. 36 | """ 37 | gate = self.attention_aggregation.gate_nn(x) 38 | x = self.attention_aggregation.nn(x) 39 | 40 | gate = softmax(gate, index, dim=dim) 41 | 42 | return self.reduce(gate * x, index, dim=dim) 43 | 44 | def reset_parameters(self): 45 | reset(self.attention_aggregation.gate_nn) 46 | reset(self.attention_aggregation.nn) 47 | 48 | def __repr__(self) -> str: 49 | return f"{self.__class__.__name__}(gate_nn={self.attention_aggregation.gate_nn}, nn={self.attention_aggregation.nn})" 50 | 51 | 52 | class ProjectionLayers(L.LightningModule): 53 | """ 54 | Small class for backward compatibility when loading models. 55 | Contains the projection layers used for the attention aggregation. 56 | """ 57 | 58 | def __init__(self, output_size): 59 | super().__init__() 60 | self.gate_nn = nn.Linear(output_size, 1) 61 | self.nn = nn.Linear(output_size, output_size) 62 | -------------------------------------------------------------------------------- /novae/module/augment.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | import torch 3 | from torch.distributions import Exponential 4 | from torch_geometric.data import Batch 5 | 6 | 7 | class GraphAugmentation(L.LightningModule): 8 | """Perform graph augmentation for Novae. It adds noise to the data and keeps a subset of the genes.""" 9 | 10 | def __init__( 11 | self, 12 | panel_subset_size: float, 13 | background_noise_lambda: float, 14 | sensitivity_noise_std: float, 15 | ): 16 | """ 17 | 18 | Args: 19 | panel_subset_size: Ratio of genes kept from the panel during augmentation. 20 | background_noise_lambda: Parameter of the exponential distribution for the noise augmentation. 21 | sensitivity_noise_std: Standard deviation of the multiplicative factor for the noise augmentation. 22 | """ 23 | super().__init__() 24 | self.panel_subset_size = panel_subset_size 25 | self.background_noise_lambda = background_noise_lambda 26 | self.sensitivity_noise_std = sensitivity_noise_std 27 | 28 | self.background_noise_distribution = Exponential(torch.tensor(float(background_noise_lambda))) 29 | 30 | def noise(self, data: Batch): 31 | """Add noise (inplace) to the data as detailed in the article. 32 | 33 | Args: 34 | data: A Pytorch Geometric `Data` object representing a batch of `B` graphs. 35 | """ 36 | sample_shape = (data.batch_size, data.x.shape[1]) 37 | 38 | additions = self.background_noise_distribution.sample(sample_shape=sample_shape).to(self.device) 39 | gaussian_noise = torch.randn(sample_shape, device=self.device) 40 | factors = (1 + gaussian_noise * self.sensitivity_noise_std).clip(0, 2) 41 | 42 | for i in range(data.batch_size): 43 | start, stop = data.ptr[i], data.ptr[i + 1] 44 | data.x[start:stop] = data.x[start:stop] * factors[i] + additions[i] 45 | 46 | def panel_subset(self, data: Batch): 47 | """ 48 | Keep a ratio of `panel_subset_size` of the input genes (inplace operation).
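        For instance, with `panel_subset_size=0.8`, a batch of 300 genes is reduced to `int(300 * 0.8) = 240` randomly chosen genes.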
49 | 50 | Args: 51 | data: A Pytorch Geometric `Data` object representing a batch of `B` graphs. 52 | """ 53 | n_total = len(data.genes_indices[0]) 54 | n_subset = int(n_total * self.panel_subset_size) 55 | 56 | gene_subset_indices = torch.randperm(n_total)[:n_subset] 57 | 58 | data.x = data.x[:, gene_subset_indices] 59 | data.genes_indices = data.genes_indices[:, gene_subset_indices] 60 | 61 | def forward(self, data: Batch) -> Batch: 62 | """Perform data augmentation (`noise` and `panel_subset`). 63 | 64 | Args: 65 | data: A Pytorch Geometric `Data` object representing a batch of `B` graphs. 66 | 67 | Returns: 68 | The augmented `Data` object. 69 | """ 70 | self.panel_subset(data) 71 | self.noise(data) 72 | return data 73 | -------------------------------------------------------------------------------- /novae/module/embed.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from pathlib import Path 4 | 5 | import lightning as L 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | import torch.nn.functional as F 10 | from anndata import AnnData 11 | from scipy.sparse import issparse 12 | from sklearn.decomposition import PCA 13 | from sklearn.neighbors import KDTree 14 | from torch import nn 15 | from torch_geometric.data import Data 16 | 17 | from .. import utils 18 | from .._constants import Keys 19 | 20 | log = logging.getLogger(__name__) 21 | 22 | 23 | class CellEmbedder(L.LightningModule): 24 | """Convert a cell into an embedding using a gene embedding matrix.""" 25 | 26 | def __init__( 27 | self, 28 | gene_names: list[str] | dict[str, int], 29 | embedding_size: int | None, 30 | embedding: torch.Tensor | None = None, 31 | ) -> None: 32 | """ 33 | 34 | Args: 35 | gene_names: Names of the genes to be used in the embedding, or a dictionary mapping gene names to their indices. 36 | embedding_size: Size of the embeddings of the genes (`E` in the article). Optional if `embedding` is provided. 37 | embedding: Optional pre-trained embedding matrix. If provided, `embedding_size` shouldn't be provided. 38 | """ 39 | super().__init__() 40 | assert (embedding_size is None) ^ (embedding is None), "Either embedding_size or embedding must be provided" 41 | 42 | if isinstance(gene_names, dict): 43 | self.gene_to_index = {gene.lower(): index for gene, index in gene_names.items()} 44 | self.gene_names = sorted(self.gene_to_index, key=self.gene_to_index.get) 45 | _check_gene_to_index(self.gene_to_index) 46 | else: 47 | self.gene_names = [gene.lower() for gene in gene_names] 48 | self.gene_to_index = {gene: i for i, gene in enumerate(self.gene_names)} 49 | 50 | self.voc_size = len(self.gene_names) 51 | 52 | if embedding is None: 53 | self.embedding_size = embedding_size 54 | self.embedding = nn.Embedding(self.voc_size, embedding_size) 55 | else: 56 | self.embedding_size = embedding.size(1) 57 | self.embedding = nn.Embedding.from_pretrained(embedding) 58 | 59 | self.linear = nn.Linear(self.embedding_size, self.embedding_size) 60 | self._init_linear() 61 | 62 | @torch.no_grad() 63 | def _init_linear(self): 64 | self.linear.weight.data.copy_(torch.eye(self.embedding_size)) 65 | self.linear.bias.data.zero_() 66 | 67 | @classmethod 68 | def from_scgpt_embedding(cls, scgpt_model_dir: str) -> "CellEmbedder": 69 | """Initialize the CellEmbedder from a scGPT pretrained model directory. 70 | 71 | Args: 72 | scgpt_model_dir: Path to a directory containing a scGPT checkpoint, i.e. a `vocab.json` and a `best_model.pt` file.
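        Example:
            A minimal sketch (the checkpoint directory path is hypothetical):

                >>> cell_embedder = CellEmbedder.from_scgpt_embedding("checkpoints/scGPT_human")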
73 | 74 | Returns: 75 | A CellEmbedder instance. 76 | """ 77 | scgpt_model_dir = Path(scgpt_model_dir) 78 | 79 | vocab_file = scgpt_model_dir / "vocab.json" 80 | 81 | with open(vocab_file) as file: 82 | gene_to_index: dict[str, int] = json.load(file) 83 | 84 | checkpoint = torch.load(scgpt_model_dir / "best_model.pt", map_location=torch.device("cpu")) 85 | embedding = checkpoint["encoder.embedding.weight"] 86 | 87 | return cls(gene_to_index, None, embedding=embedding) 88 | 89 | def genes_to_indices(self, gene_names: pd.Index | list[str], as_torch: bool = True) -> torch.Tensor | np.ndarray: 90 | """Convert gene names to their corresponding indices. 91 | 92 | Args: 93 | gene_names: Names of the genes to convert. 94 | as_torch: Whether to return a `torch` tensor or a `numpy` array. 95 | 96 | Returns: 97 | A tensor or array of gene indices. 98 | """ 99 | indices = [self.gene_to_index[gene] for gene in utils.lower_var_names(gene_names)] 100 | 101 | if as_torch: 102 | return torch.tensor(indices, dtype=torch.long) 103 | 104 | return np.array(indices, dtype=np.int16) 105 | 106 | def forward(self, data: Data) -> Data: 107 | """Embed the input data. 108 | 109 | Args: 110 | data: A Pytorch Geometric `Data` object representing a batch of `B` graphs. The number of node features is variable. 111 | 112 | Returns: 113 | data: A Pytorch Geometric `Data` object representing a batch of `B` graphs. Each node now has a size of `E`. 114 | """ 115 | genes_embeddings = self.embedding(data.genes_indices[0]) 116 | genes_embeddings = self.linear(genes_embeddings) 117 | genes_embeddings = F.normalize(genes_embeddings, dim=0, p=2) 118 | 119 | data.x = data.x @ genes_embeddings 120 | return data 121 | 122 | def pca_init(self, adatas: list[AnnData] | None): 123 | """Initialize the Novae embeddings with PCA components. 124 | 125 | Args: 126 | adatas: A list of `AnnData` objects to use for PCA initialization. 127 | """ 128 | if adatas is None: 129 | return 130 | 131 | adatas = [adata[:, adata.var[Keys.USE_GENE]] for adata in adatas] 132 | 133 | adata = max(adatas, key=lambda adata: adata.n_vars) 134 | 135 | if adata.X.shape[1] <= self.embedding_size: 136 | log.warning( 137 | f"PCA with {self.embedding_size} components cannot be run on shape {adata.X.shape}.\nTo use PCA initialization, set a lower `embedding_size` (<{adata.X.shape[1]}) in novae.Novae()."
138 | ) 139 | return 140 | 141 | X = adata.X.toarray() if issparse(adata.X) else adata.X 142 | 143 | log.info("Running PCA embedding initialization") 144 | 145 | pca = PCA(n_components=self.embedding_size) 146 | pca.fit(X.astype(np.float32)) 147 | 148 | indices = self.genes_to_indices(adata.var_names) 149 | self.embedding.weight.data[indices] = torch.tensor(pca.components_.T) 150 | 151 | known_var_names = utils.lower_var_names(adata.var_names) 152 | 153 | for other_adata in adatas: 154 | other_var_names = utils.lower_var_names(other_adata.var_names) 155 | where_in = np.isin(other_var_names, known_var_names) 156 | 157 | if where_in.all(): 158 | continue 159 | 160 | X = other_adata[:, where_in].X.toarray().T 161 | Y = other_adata[:, ~where_in].X.toarray().T 162 | 163 | tree = KDTree(X) 164 | _, ind = tree.query(Y, k=1) 165 | neighbor_indices = self.genes_to_indices(other_adata[:, where_in].var_names[ind[:, 0]]) 166 | 167 | indices = self.genes_to_indices(other_adata[:, ~where_in].var_names) 168 | self.embedding.weight.data[indices] = self.embedding.weight.data[neighbor_indices].clone() 169 | 170 | 171 | def _check_gene_to_index(gene_to_index: dict[str, int]): 172 | values = list(set(gene_to_index.values())) 173 | 174 | assert len(values) == len(gene_to_index), "gene_to_index should be a dictionary with unique values" 175 | 176 | assert min(values) == 0 and max(values) == len(values) - 1, ( 177 | "gene_to_index should be a dictionary with continuous indices starting from 0" 178 | ) 179 | -------------------------------------------------------------------------------- /novae/module/encode.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | from torch import Tensor 3 | from torch_geometric.data import Batch 4 | from torch_geometric.nn.models import GAT 5 | 6 | from . import AttentionAggregation 7 | 8 | 9 | class GraphEncoder(L.LightningModule): 10 | """Graph encoder of Novae. It uses a graph attention network.""" 11 | 12 | def __init__( 13 | self, 14 | embedding_size: int, 15 | hidden_size: int, 16 | num_layers: int, 17 | output_size: int, 18 | heads: int, 19 | ) -> None: 20 | """ 21 | Args: 22 | embedding_size: Size of the embeddings of the genes (`E` in the article). 23 | hidden_size: The size of the hidden layers in the GAT. 24 | num_layers: The number of layers in the GAT. 25 | output_size: Size of the representations, i.e. the encoder outputs (`O` in the article). 26 | heads: The number of attention heads in the GAT. 27 | """ 28 | super().__init__() 29 | self.gnn = GAT( 30 | embedding_size, 31 | hidden_channels=hidden_size, 32 | num_layers=num_layers, 33 | out_channels=output_size, 34 | edge_dim=1, 35 | v2=True, 36 | heads=heads, 37 | act="ELU", 38 | ) 39 | 40 | self.node_aggregation = AttentionAggregation(output_size) 41 | 42 | def forward(self, data: Batch) -> Tensor: 43 | """Encode the input data. 44 | 45 | Args: 46 | data: A Pytorch Geometric `Data` object representing a batch of `B` graphs. 47 | 48 | Returns: 49 | A tensor of shape `(B, O)` containing the encoded graphs.
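        Example:
            A construction sketch (the hyperparameter values are illustrative, not the Novae defaults):

                >>> encoder = GraphEncoder(embedding_size=100, hidden_size=128, num_layers=10, output_size=64, heads=4)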
50 | """ 51 | out = self.gnn(x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr) 52 | return self.node_aggregation(out, index=data.batch) 53 | -------------------------------------------------------------------------------- /novae/monitor/__init__.py: -------------------------------------------------------------------------------- 1 | from .eval import ( 2 | jensen_shannon_divergence, 3 | mean_fide_score, 4 | fide_score, 5 | entropy, 6 | heuristic, 7 | mean_normalized_entropy, 8 | ) 9 | -------------------------------------------------------------------------------- /novae/monitor/callback.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import scanpy as sc 4 | import seaborn as sns 5 | from anndata import AnnData 6 | from lightning import Trainer 7 | from lightning.pytorch.callbacks import Callback 8 | 9 | from .._constants import Keys 10 | from ..model import Novae 11 | from .eval import heuristic, mean_fide_score 12 | from .log import log_plt_figure, save_pdf_figure 13 | 14 | 15 | class LogProtoCovCallback(Callback): 16 | def on_train_epoch_end(self, trainer: Trainer, model: Novae) -> None: 17 | C = model.swav_head.prototypes.data.numpy(force=True) 18 | 19 | plt.figure(figsize=(10, 10)) 20 | sns.clustermap(np.cov(C)) 21 | log_plt_figure("prototypes_covariance") 22 | 23 | 24 | class LogTissuePrototypeWeights(Callback): 25 | def on_train_epoch_end(self, trainer: Trainer, model: Novae) -> None: 26 | if model.swav_head.queue is None: 27 | return 28 | 29 | model.plot_prototype_weights() 30 | save_pdf_figure(f"tissue_prototype_weights_e{model.current_epoch}") 31 | log_plt_figure("tissue_prototype_weights") 32 | 33 | 34 | class ValidationCallback(Callback): 35 | def __init__( 36 | self, 37 | adatas: list[AnnData] | None, 38 | accelerator: str = "cpu", 39 | num_workers: int = 0, 40 | slide_name_key: str = "slide_id", 41 | k: int = 7, 42 | ): 43 | assert adatas is None or len(adatas) == 1, "ValidationCallback only supports single slide mode for now" 44 | self.adata = adatas[0] if adatas is not None else None 45 | self.accelerator = accelerator 46 | self.num_workers = num_workers 47 | self.slide_name_key = slide_name_key 48 | self.k = k 49 | 50 | self._max_heuristic = 0.0 51 | 52 | def on_train_epoch_end(self, trainer: Trainer, model: Novae): 53 | if self.adata is None: 54 | return 55 | 56 | model.mode.trained = True # trick to avoid assert error in compute_representations 57 | 58 | model.compute_representations( 59 | self.adata, accelerator=self.accelerator, num_workers=self.num_workers, zero_shot=True 60 | ) 61 | model.swav_head.hierarchical_clustering() 62 | 63 | obs_key = model.assign_domains(self.adata, n_domains=self.k) 64 | 65 | plt.figure() 66 | sc.pl.spatial(self.adata, color=obs_key, spot_size=20, img_key=None, show=False) 67 | slide_name_key = self.slide_name_key if self.slide_name_key in self.adata.obs else Keys.SLIDE_ID 68 | log_plt_figure(f"val_{self.k}_{self.adata.obs[slide_name_key].iloc[0]}") 69 | 70 | fide = mean_fide_score(self.adata, obs_key=obs_key, n_classes=self.k) 71 | model.log("metrics/val_mean_fide_score", fide) 72 | 73 | heuristic_ = heuristic(self.adata, obs_key=obs_key, n_classes=self.k) 74 | model.log("metrics/val_heuristic", heuristic_) 75 | 76 | self._max_heuristic = max(self._max_heuristic, heuristic_) 77 | model.log("metrics/val_max_heuristic", self._max_heuristic) 78 | 79 | model.mode.zero_shot = False 80 | 
-------------------------------------------------------------------------------- /novae/monitor/eval.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | from anndata import AnnData 5 | from sklearn import metrics 6 | 7 | from .._constants import Keys, Nums 8 | 9 | log = logging.getLogger(__name__) 10 | 11 | 12 | def mean_fide_score( 13 | adatas: AnnData | list[AnnData], obs_key: str, slide_key: str | None = None, n_classes: int | None = None 14 | ) -> float: 15 | """Mean FIDE score over all slides. A high score indicates great domain continuity. 16 | 17 | Args: 18 | adatas: An `AnnData` object, or a list of `AnnData` objects. 19 | obs_key: Key of `adata.obs` containing the domains annotation. 20 | slide_key: Optional key of `adata.obs` containing the ID of each slide. Not needed if each `adata` is a slide. 21 | n_classes: Optional number of classes. This can be useful if not all classes are predicted, for a fair comparison. 22 | 23 | Returns: 24 | The FIDE score averaged over all slides. 25 | """ 26 | return np.mean([ 27 | fide_score(adata, obs_key, n_classes=n_classes) 28 | for adata in _iter_uid(adatas, slide_key=slide_key, obs_key=obs_key) 29 | ]) 30 | 31 | 32 | def fide_score(adata: AnnData, obs_key: str, n_classes: int | None = None) -> float: 33 | """F1-score of intra-domain edges (FIDE). A high score indicates great domain continuity. 34 | 35 | Note: 36 | The F1-score is computed for every class, then all F1-scores are averaged. If some classes 37 | are not predicted, the `n_classes` argument allows padding with zeros before averaging the F1-scores. 38 | 39 | Args: 40 | adata: An `AnnData` object 41 | obs_key: Key of `adata.obs` containing the domains annotation. 42 | n_classes: Optional number of classes. This can be useful if not all classes are predicted, for a fair comparison. 43 | 44 | Returns: 45 | The FIDE score. 46 | """ 47 | i_left, i_right = adata.obsp[Keys.ADJ].nonzero() 48 | classes_left, classes_right = adata.obs.iloc[i_left][obs_key].values, adata.obs.iloc[i_right][obs_key].values 49 | 50 | where_valid = ~classes_left.isna() & ~classes_right.isna() 51 | classes_left, classes_right = classes_left[where_valid], classes_right[where_valid] 52 | 53 | f1_scores = metrics.f1_score(classes_left, classes_right, average=None) 54 | 55 | if n_classes is None: 56 | return f1_scores.mean() 57 | 58 | assert n_classes >= len(f1_scores), f"Expected {n_classes=}, but found {len(f1_scores)} classes, which is greater" 59 | 60 | return np.pad(f1_scores, (0, n_classes - len(f1_scores))).mean() 61 | 62 | 63 | def jensen_shannon_divergence(adatas: AnnData | list[AnnData], obs_key: str, slide_key: str | None = None) -> float: 64 | """Jensen-Shannon divergence (JSD) over all slides. 65 | 66 | Args: 67 | adatas: One or a list of AnnData object(s) 68 | obs_key: Key of `adata.obs` containing the domains annotation. 69 | slide_key: Optional key of `adata.obs` containing the ID of each slide. Not needed if each `adata` is a slide.
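        Note:
            The JSD is computed over the distributions of domains among slides: the lower the score, the more homogeneous the domain proportions are across slides.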
70 | 71 | Returns: 72 | The Jensen-Shannon divergence score for all slides 73 | """ 74 | all_categories = set() 75 | for adata in _iter_uid(adatas, slide_key=slide_key, obs_key=obs_key): 76 | all_categories.update(adata.obs[obs_key].cat.categories) 77 | all_categories = sorted(all_categories) 78 | 79 | distributions = [] 80 | for adata in _iter_uid(adatas, slide_key=slide_key, obs_key=obs_key): 81 | value_counts = adata.obs[obs_key].value_counts(sort=False) 82 | distribution = np.zeros(len(all_categories)) 83 | 84 | for i, category in enumerate(all_categories): 85 | if category in value_counts: 86 | distribution[i] = value_counts[category] 87 | 88 | distributions.append(distribution) 89 | 90 | return _jensen_shannon_divergence(np.array(distributions)) 91 | 92 | 93 | def _jensen_shannon_divergence(distributions: np.ndarray) -> float: 94 | """Compute the Jensen-Shannon divergence (JSD) for multiple probability distributions. 95 | 96 | The lower the score, the more evenly the clusters are distributed among the different batches. 97 | 98 | Args: 99 | distributions: An array of shape (B, C), where B is the number of batches, and C is the number of clusters. For each batch, it contains the percentage of each cluster among cells. 100 | 101 | Returns: 102 | A float corresponding to the JSD 103 | """ 104 | distributions = distributions / distributions.sum(1)[:, None] 105 | mean_distribution = np.mean(distributions, 0) 106 | 107 | return entropy(mean_distribution) - np.mean([entropy(dist) for dist in distributions]) 108 | 109 | 110 | def entropy(distribution: np.ndarray) -> float: 111 | """Shannon entropy 112 | 113 | Args: 114 | distribution: An array of probabilities (should sum to one) 115 | 116 | Returns: 117 | The Shannon entropy 118 | """ 119 | return -(distribution * np.log2(distribution + Nums.EPS)).sum() 120 | 121 | 122 | def mean_normalized_entropy( 123 | adatas: AnnData | list[AnnData], n_classes: int, obs_key: str, slide_key: str | None = None 124 | ) -> float: 125 | return np.mean([ 126 | _mean_normalized_entropy(adata, obs_key, n_classes=n_classes) 127 | for adata in _iter_uid(adatas, slide_key=slide_key, obs_key=obs_key) 128 | ]) 129 | 130 | 131 | def _mean_normalized_entropy(adata: AnnData, obs_key: str, n_classes: int) -> float: 132 | distribution = adata.obs[obs_key].value_counts(normalize=True).values 133 | distribution = np.pad(distribution, (0, n_classes - len(distribution)), mode="constant") 134 | entropy_ = entropy(distribution) 135 | 136 | return entropy_ / np.log2(n_classes) 137 | 138 | 139 | def heuristic(adata: AnnData | list[AnnData], obs_key: str, n_classes: int, slide_key: str | None = None) -> float: 140 | """Heuristic score to evaluate the quality of the clustering. 141 | 142 | Args: 143 | adata: An `AnnData` object, or a list of `AnnData` objects. 144 | obs_key: The key in `adata.obs` that contains the domains. 145 | n_classes: The number of classes. 146 | slide_key: The key in `adata.obs` that contains the slide id. 147 | 148 | Returns: 149 | The heuristic score.
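        Note:
            As implemented below, the heuristic is the product of the FIDE score and the normalized entropy of the domain proportions, i.e. `fide_score * entropy / log2(n_classes)`. Higher is better.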
150 | """ 151 | return np.mean([ 152 | _heuristic(adata, obs_key, n_classes) for adata in _iter_uid(adata, slide_key=slide_key, obs_key=obs_key) 153 | ]) 154 | 155 | 156 | def _heuristic(adata: AnnData, obs_key: str, n_classes: int) -> float: 157 | fide_ = fide_score(adata, obs_key, n_classes=n_classes) 158 | 159 | distribution = adata.obs[obs_key].value_counts(normalize=True).values 160 | distribution = np.pad(distribution, (0, n_classes - len(distribution)), mode="constant") 161 | entropy_ = entropy(distribution) 162 | 163 | return fide_ * entropy_ / np.log2(n_classes) 164 | 165 | 166 | def _iter_uid(adatas: AnnData | list[AnnData], slide_key: str | None = None, obs_key: str | None = None): 167 | """Iterate over all slides, and make sure `adata.obs[obs_key]` is categorical. 168 | 169 | Args: 170 | adatas: One or a list of AnnData object(s). 171 | slide_key: The key in `adata.obs` that contains the slide id. 172 | obs_key: The key in `adata.obs` that contains the domain id. 173 | 174 | Yields: 175 | One `AnnData` per slide. 176 | """ 177 | if isinstance(adatas, AnnData): 178 | adatas = [adatas] 179 | 180 | if obs_key is not None: 181 | categories = set.union(*[set(adata.obs[obs_key].astype("category").cat.categories) for adata in adatas]) 182 | for adata in adatas: 183 | adata.obs[obs_key] = adata.obs[obs_key].astype("category").cat.set_categories(categories) 184 | 185 | for adata in adatas: 186 | if slide_key is None: 187 | yield adata 188 | continue 189 | 190 | for slide_id in adata.obs[slide_key].unique(): 191 | adata_yield = adata[adata.obs[slide_key] == slide_id] 192 | 193 | yield adata_yield 194 | -------------------------------------------------------------------------------- /novae/monitor/log.py: -------------------------------------------------------------------------------- 1 | import io 2 | from pathlib import Path 3 | 4 | import matplotlib.pyplot as plt 5 | import wandb 6 | from PIL import Image 7 | 8 | from ..utils import repository_root 9 | 10 | 11 | def log_plt_figure(name: str, dpi: int = 300) -> None: 12 | img_buf = io.BytesIO() 13 | plt.savefig(img_buf, format="png", bbox_inches="tight", dpi=dpi) 14 | wandb.log({name: wandb.Image(Image.open(img_buf))}) 15 | plt.close() 16 | 17 | 18 | def save_pdf_figure(name: str): 19 | plt.savefig(wandb_results_dir() / f"{name}.pdf", format="pdf", bbox_inches="tight") 20 | 21 | 22 | def wandb_results_dir() -> Path: 23 | res_dir: Path = repository_root() / "data" / "results" / wandb.run.name 24 | res_dir.mkdir(parents=True, exist_ok=True) 25 | return res_dir 26 | -------------------------------------------------------------------------------- /novae/plot/__init__.py: -------------------------------------------------------------------------------- 1 | from ._graph import _domains_hierarchy, paga, connectivities 2 | from ._spatial import domains, spatially_variable_genes 3 | from ._heatmap import _weights_clustermap, pathway_scores 4 | from ._bar import domains_proportions 5 | -------------------------------------------------------------------------------- /novae/plot/_bar.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import pandas as pd 3 | import seaborn as sns 4 | from anndata import AnnData 5 | 6 | from .. 
import utils 7 | from ._utils import get_categorical_color_palette 8 | 9 | 10 | def domains_proportions( 11 | adata: AnnData | list[AnnData], 12 | obs_key: str | None = None, 13 | slide_name_key: str | None = None, 14 | figsize: tuple[int, int] = (2, 5), 15 | show: bool = True, 16 | ): 17 | """Show the proportion of each domain in the slide(s). 18 | 19 | Args: 20 | adata: One `AnnData` object, or a list of `AnnData` objects. 21 | obs_key: The key in `adata.obs` that contains the Novae domains. By default, the last available domain key is shown. 22 | figsize: Matplotlib figure size. 23 | show: Whether to show the plot. 24 | """ 25 | adatas = [adata] if isinstance(adata, AnnData) else adata 26 | slide_name_key = utils.check_slide_name_key(adatas, slide_name_key) 27 | obs_key = utils.check_available_domains_key(adatas, obs_key) 28 | 29 | all_domains, colors = get_categorical_color_palette(adatas, obs_key) 30 | 31 | names, series = [], [] 32 | for adata_slide in utils.iter_slides(adatas): 33 | names.append(adata_slide.obs[slide_name_key].iloc[0]) 34 | series.append(adata_slide.obs[obs_key].value_counts(normalize=True)) 35 | 36 | df = pd.concat(series, axis=1) 37 | df.columns = names 38 | 39 | df.T.plot(kind="bar", stacked=True, figsize=figsize, color=dict(zip(all_domains, colors))) 40 | sns.despine(offset=10, trim=True) 41 | plt.legend(bbox_to_anchor=(1.04, 0.5), loc="center left", borderaxespad=0, frameon=False) 42 | plt.ylabel("Proportion") 43 | plt.xticks(rotation=90) 44 | 45 | if show: 46 | plt.show() 47 | -------------------------------------------------------------------------------- /novae/plot/_graph.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import scanpy as sc 4 | import seaborn as sns 5 | from anndata import AnnData 6 | from matplotlib.collections import LineCollection 7 | from scipy.cluster.hierarchy import dendrogram 8 | from sklearn.cluster import AgglomerativeClustering 9 | 10 | from .. 
import utils 11 | from .._constants import Keys 12 | from ._utils import _subplots_per_slide, get_categorical_color_palette 13 | 14 | 15 | def _leaves_count(clustering: AgglomerativeClustering) -> np.ndarray: 16 | counts = np.zeros(clustering.children_.shape[0]) 17 | n_samples = len(clustering.labels_) 18 | for i, merge in enumerate(clustering.children_): 19 | current_count = 0 20 | for child_idx in merge: 21 | if child_idx < n_samples: 22 | current_count += 1 # leaf node 23 | else: 24 | current_count += counts[child_idx - n_samples] 25 | counts[i] = current_count 26 | return counts 27 | 28 | 29 | def _domains_hierarchy( 30 | clustering: AgglomerativeClustering, 31 | max_level: int = 10, 32 | hline_level: int | list[int] | None = None, 33 | leaf_font_size: int = 10, 34 | **kwargs, 35 | ) -> None: 36 | assert max_level > 1 37 | 38 | size = clustering.children_.shape[0] 39 | original_ymax = max_level + 1 40 | original_ticks = np.arange(1, original_ymax) 41 | height = original_ymax + np.arange(size) - size 42 | 43 | if hline_level is not None: 44 | hline_level = [hline_level] if isinstance(hline_level, int) else hline_level 45 | for level in hline_level: 46 | plt.hlines(original_ymax - level, 0, 1e5, colors="r", linestyles="dashed") 47 | 48 | linkage_matrix = np.column_stack([clustering.children_, height.clip(0), _leaves_count(clustering)]).astype(float) 49 | 50 | ddata = dendrogram( 51 | linkage_matrix, 52 | color_threshold=-1, 53 | leaf_font_size=leaf_font_size, 54 | p=max_level + 1, 55 | truncate_mode="lastp", 56 | above_threshold_color="#ccc", 57 | **kwargs, 58 | ) 59 | 60 | for i, d in zip(ddata["icoord"][::-1], ddata["dcoord"][::-1]): 61 | x, y = 0.5 * sum(i[1:3]), d[1] 62 | plt.plot(x, y, "ko") 63 | plt.annotate( 64 | f"D{2 * size - max_level + int(y)}", 65 | (x, y), 66 | xytext=(0, -8), 67 | textcoords="offset points", 68 | va="top", 69 | ha="center", 70 | ) 71 | 72 | plt.yticks(original_ticks, original_ymax - original_ticks) 73 | 74 | plt.xlabel(None) 75 | plt.ylabel("Domains level") 76 | plt.title("Domains hierarchy") 77 | plt.xlabel("Number of domains in node (or prototype ID if no parenthesis)") 78 | sns.despine(offset=10, trim=True, bottom=True) 79 | 80 | 81 | def paga(adata: AnnData, obs_key: str | None = None, show: bool = True, **paga_plot_kwargs: int): 82 | """Plot a PAGA graph. 83 | 84 | Info: 85 | Currently, this function only supports one slide per call. 86 | 87 | Args: 88 | adata: An AnnData object. 89 | obs_key: Name of the key from `adata.obs` containing the Novae domains. By default, the last available domain key is shown. 90 | show: Whether to show the plot. 91 | **paga_plot_kwargs: Additional arguments for `sc.pl.paga`. 
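    Example:
        A minimal sketch, assuming the domains were already assigned with `model.assign_domains(adata)`:

            >>> novae.plot.paga(adata)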
92 | """ 93 | assert isinstance(adata, AnnData), f"For now, only AnnData objects are supported, received {type(adata)}" 94 | 95 | obs_key = utils.check_available_domains_key([adata], obs_key) 96 | 97 | get_categorical_color_palette([adata], obs_key) 98 | 99 | adata_clean = adata[~adata.obs[obs_key].isna()] 100 | 101 | if "paga" not in adata.uns or adata.uns["paga"]["groups"] != obs_key: 102 | sc.pp.neighbors(adata_clean, use_rep=Keys.REPR) 103 | sc.tl.paga(adata_clean, groups=obs_key) 104 | 105 | adata.uns["paga"] = adata_clean.uns["paga"] 106 | adata.uns[f"{obs_key}_sizes"] = adata_clean.uns[f"{obs_key}_sizes"] 107 | 108 | sc.pl.paga(adata_clean, title=f"PAGA graph ({obs_key})", show=False, **paga_plot_kwargs) 109 | sns.despine(offset=10, trim=True, bottom=True) 110 | 111 | if show: 112 | plt.show() 113 | 114 | 115 | def connectivities( 116 | adata: AnnData, 117 | ngh_threshold: int | None = 2, 118 | cell_size: int = 2, 119 | ncols: int = 4, 120 | fig_size_per_slide: tuple[int, int] = (5, 5), 121 | linewidths: float = 0.1, 122 | line_color: str = "#333", 123 | cmap: str = "rocket", 124 | color_isolated_cells: str = "orangered", 125 | show: bool = True, 126 | ): 127 | """Show the graph of the spatial connectivities between cells. By default, 128 | the cells that have fewer neighbors than `ngh_threshold` are shown 129 | in red. If `ngh_threshold` is `None`, the cells are colored by the number of neighbors. 130 | 131 | !!! info "Quality control" 132 | This plot is useful to check the quality of the spatial connectivities obtained via [novae.spatial_neighbors][]. 133 | Make sure few cells (e.g., less than 5%) have a number of neighbors below `ngh_threshold`. 134 | If too many cells are isolated, you may want to increase the `radius` parameter in [novae.spatial_neighbors][]. 135 | Conversely, if some cells that are really **far from each other** are still connected, you may want to decrease the `radius` parameter to **disconnect** them. 136 | 137 | Args: 138 | adata: An AnnData object. 139 | ngh_threshold: Only cells with a number of neighbors below this threshold are shown (with color `color_isolated_cells`). If `None`, cells are colored by the number of neighbors. 140 | cell_size: Size of the dots for each cell. 141 | ncols: Number of columns to be shown. 142 | fig_size_per_slide: Size of the figure for each slide. 143 | linewidths: Width of the lines/edges connecting the cells. 144 | line_color: Color of the lines/edges. 145 | cmap: Name of the colormap to use for the number of neighbors. 146 | color_isolated_cells: Color for the cells with a number of neighbors below `ngh_threshold` (if not `None`). 147 | show: Whether to show the plot.
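    Example:
        A quality-control sketch (the `radius` value is illustrative):

            >>> novae.spatial_neighbors(adata, radius=80)
            >>> novae.plot.connectivities(adata, ngh_threshold=2)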
148 | """ 149 | adatas = [adata] if isinstance(adata, AnnData) else adata 150 | 151 | fig, axes = _subplots_per_slide(adatas, ncols, fig_size_per_slide) 152 | 153 | for i, adata in enumerate(utils.iter_slides(adatas)): 154 | ax = axes[i // ncols, i % ncols] 155 | 156 | utils.check_has_spatial_adjancency(adata) 157 | 158 | X, A = adata.obsm["spatial"], adata.obsp[Keys.ADJ] 159 | 160 | ax.invert_yaxis() 161 | ax.axes.set_aspect("equal") 162 | 163 | rows, cols = A.nonzero() 164 | mask = rows < cols 165 | rows, cols = rows[mask], cols[mask] 166 | edge_segments = np.stack([X[rows], X[cols]], axis=1) 167 | edges = LineCollection(edge_segments, color=line_color, linewidths=linewidths, zorder=1) 168 | ax.add_collection(edges) 169 | 170 | n_neighbors = (A > 0).sum(1).A1 171 | 172 | if ngh_threshold is None: 173 | _ = ax.scatter(X[:, 0], X[:, 1], c=n_neighbors, s=cell_size, zorder=2, cmap=cmap) 174 | plt.colorbar(_, ax=ax) 175 | else: 176 | isolated_cells = n_neighbors < ngh_threshold 177 | ax.scatter(X[isolated_cells, 0], X[isolated_cells, 1], color=color_isolated_cells, s=cell_size, zorder=2) 178 | 179 | ax.set_title(adata.obs[Keys.SLIDE_ID].iloc[0]) 180 | 181 | [fig.delaxes(ax) for ax in axes.flatten() if not ax.has_data()] # remove unused subplots 182 | 183 | title = "Node connectivities" + (f" (threshold={ngh_threshold} neighbors)" if ngh_threshold is not None else "") 184 | 185 | if i == 0: 186 | axes[0, 0].set_title(title) 187 | else: 188 | fig.suptitle(title, fontsize=14) 189 | 190 | if show: 191 | plt.show() 192 | -------------------------------------------------------------------------------- /novae/plot/_heatmap.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | 4 | import matplotlib.patches as mpatches 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import pandas as pd 8 | import scanpy as sc 9 | import seaborn as sns 10 | from anndata import AnnData 11 | 12 | from .. 
import utils 13 | from .._constants import Keys 14 | 15 | log = logging.getLogger(__name__) 16 | 17 | 18 | def _weights_clustermap( 19 | weights: np.ndarray, 20 | adatas: list[AnnData] | None, 21 | slide_ids: list[str], 22 | show_yticklabels: bool = False, 23 | show_tissue_legend: bool = True, 24 | figsize: tuple[int] = (6, 4), 25 | vmin: float = 0, 26 | vmax: float = 1, 27 | **kwargs: int, 28 | ) -> None: 29 | row_colors = None 30 | if show_tissue_legend and adatas is not None and all(Keys.UNS_TISSUE in adata.uns for adata in adatas): 31 | tissues = list({adata.uns[Keys.UNS_TISSUE] for adata in adatas}) 32 | tissue_colors = {tissue: sns.color_palette("tab20")[i] for i, tissue in enumerate(tissues)} 33 | 34 | row_colors = [] 35 | for slide_id in slide_ids: 36 | for adata in adatas: 37 | if adata.obs[Keys.SLIDE_ID].iloc[0] == slide_id: 38 | row_colors.append(tissue_colors[adata.uns[Keys.UNS_TISSUE]]) 39 | break 40 | else: 41 | row_colors.append("gray") 42 | log.info(f"Using {row_colors=}") 43 | 44 | sns.clustermap( 45 | weights, 46 | yticklabels=slide_ids if show_yticklabels else False, 47 | xticklabels=False, 48 | vmin=vmin, 49 | vmax=vmax, 50 | figsize=figsize, 51 | row_colors=row_colors, 52 | **kwargs, 53 | ) 54 | 55 | if show_tissue_legend and row_colors is not None: 56 | handles = [mpatches.Patch(facecolor=color, label=tissue) for tissue, color in tissue_colors.items()] 57 | ax = plt.gcf().axes[3] 58 | ax.legend(handles=handles, bbox_to_anchor=(1.04, 0.5), loc="center left", borderaxespad=0, frameon=False) 59 | 60 | 61 | TEMP_KEY = "_temp" 62 | 63 | 64 | def pathway_scores( 65 | adata: AnnData, 66 | pathways: dict[str, list[str]] | str, 67 | obs_key: str | None = None, 68 | pathway_name: str | None = None, 69 | slide_name_key: str | None = None, 70 | return_df: bool = False, 71 | figsize: tuple[int, int] = (10, 5), 72 | min_pathway_size: int = 4, 73 | show: bool = True, 74 | **kwargs: int, 75 | ) -> pd.DataFrame | None: 76 | """Show a heatmap of either (i) the score of multiple pathways for each domain, or (ii) one pathway score for each domain and for each slide. 77 | To use the latter case, provide `pathway_name`, or make sure to have only one pathway in `pathways`. 78 | 79 | Info: 80 | Currently, this function only supports one AnnData object per call. 81 | 82 | Args: 83 | adata: An `AnnData` object. 84 | pathways: Either a dictionary of pathways (keys are pathway names, values are lists of gene names), or a path to a [GSEA](https://www.gsea-msigdb.org/gsea/msigdb/index.jsp) JSON file. 85 | obs_key: Key in `adata.obs` that contains the domains. By default, it will use the last available Novae domain key. 86 | pathway_name: If `None`, all pathways will be shown (first mode). If not `None`, this specific pathway will be shown, for all domains and all slides (second mode). 87 | slide_name_key: Key of `adata.obs` that contains the slide names. By default, uses the Novae unique slide ID. 88 | return_df: Whether to return the DataFrame. 89 | figsize: Matplotlib figure size. 90 | min_pathway_size: Minimum number of known genes in the pathway to be considered. 91 | show: Whether to show the plot. 92 | 93 | Returns: 94 | A DataFrame of scores per domain if `return_df` is True. 
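        Example:
            A minimal sketch, using for instance the `hallmarks_pathways.json` GSEA file from the tutorials:

                >>> novae.plot.pathway_scores(adata, pathways="hallmarks_pathways.json", figsize=(12, 6))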
95 | """ 96 | assert isinstance(adata, AnnData), f"For now, only one AnnData object is supported, received {type(adata)}" 97 | 98 | obs_key = utils.check_available_domains_key([adata], obs_key) 99 | 100 | if isinstance(pathways, str): 101 | pathways = _load_gsea_json(pathways) 102 | log.info(f"Loaded {len(pathways)} pathway(s)") 103 | 104 | if len(pathways) == 1: 105 | pathway_name = next(iter(pathways.keys())) 106 | 107 | if pathway_name is not None: 108 | gene_names = pathways[pathway_name] 109 | is_valid = _get_pathway_score(adata, gene_names, min_pathway_size) 110 | assert is_valid, f"Pathway '{pathway_name}' has less than {min_pathway_size} genes in the dataset." 111 | else: 112 | scores = {} 113 | 114 | for key, gene_names in pathways.items(): 115 | is_valid = _get_pathway_score(adata, gene_names, min_pathway_size) 116 | if is_valid: 117 | scores[key] = adata.obs[TEMP_KEY] 118 | 119 | if pathway_name is not None: 120 | log.info(f"Plot mode: {pathway_name} score per domain per slide") 121 | 122 | slide_name_key = utils.check_slide_name_key(adata, slide_name_key) 123 | 124 | df = adata.obs.groupby([obs_key, slide_name_key], observed=True)[TEMP_KEY].mean().unstack() 125 | df.columns.name = slide_name_key 126 | 127 | assert len(df) > 1, f"Found {len(df)} valid slide. Minimum 2 required." 128 | else: 129 | log.info(f"Plot mode: {len(scores)} pathways scores per domain") 130 | 131 | assert len(scores) > 1, f"Found {len(scores)} valid pathway. Minimum 2 required." 132 | 133 | df = pd.DataFrame(scores) 134 | df[obs_key] = adata.obs[obs_key] 135 | df = df.groupby(obs_key, observed=True).mean() 136 | df.columns.name = "Pathways" 137 | 138 | del adata.obs[TEMP_KEY] 139 | 140 | df = df.fillna(0) 141 | 142 | g = sns.clustermap(df, figsize=figsize, **kwargs) 143 | plt.setp(g.ax_heatmap.yaxis.get_majorticklabels(), rotation=0) 144 | 145 | if show: 146 | plt.show() 147 | 148 | if return_df: 149 | return df 150 | 151 | 152 | def _get_pathway_score(adata: AnnData, gene_names: list[str], min_pathway_size: int) -> bool: 153 | lower_var_names = adata.var_names.str.lower() 154 | 155 | gene_names = np.array([gene_name.lower() for gene_name in gene_names]) 156 | gene_names = adata.var_names[np.isin(lower_var_names, gene_names)] 157 | 158 | if len(gene_names) >= min_pathway_size: 159 | sc.tl.score_genes(adata, gene_names, score_name=TEMP_KEY) 160 | return True 161 | return False 162 | 163 | 164 | def _load_gsea_json(path: str) -> dict[str, list[str]]: 165 | with open(path) as f: 166 | content: dict = json.load(f) 167 | assert all("geneSymbols" in value for value in content.values()), ( 168 | "Missing 'geneSymbols' key in JSON file. Expected a valid GSEA JSON file." 169 | ) 170 | return {key: value["geneSymbols"] for key, value in content.items()} 171 | -------------------------------------------------------------------------------- /novae/plot/_spatial.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | import scanpy as sc 7 | import seaborn as sns 8 | from anndata import AnnData 9 | from matplotlib.lines import Line2D 10 | from scanpy._utils import sanitize_anndata 11 | 12 | from .. 
import utils 13 | from .._constants import Keys 14 | from ._utils import ( 15 | _get_default_cell_size, 16 | _subplots_per_slide, 17 | get_categorical_color_palette, 18 | ) 19 | 20 | log = logging.getLogger(__name__) 21 | 22 | 23 | def domains( 24 | adata: AnnData | list[AnnData], 25 | obs_key: str | None = None, 26 | slide_name_key: str | None = None, 27 | cell_size: int | None = None, 28 | ncols: int = 4, 29 | fig_size_per_slide: tuple[int, int] = (5, 5), 30 | na_color: str = "#ccc", 31 | show: bool = True, 32 | library_id: str | None = None, 33 | **kwargs: int, 34 | ): 35 | """Show the Novae spatial domains for all slides in the `AnnData` object. 36 | 37 | Info: 38 | Make sure your Novae domains have already been assigned to the `AnnData` object. You can use `model.assign_domains(...)` to do so. 39 | 40 | Args: 41 | adata: An `AnnData` object, or a list of `AnnData` objects. 42 | obs_key: Name of the key from `adata.obs` containing the Novae domains. By default, the last available domain key is shown. 43 | slide_name_key: Key of `adata.obs` that contains the slide names. By default, uses the Novae unique slide ID. 44 | cell_size: Size of the cells or spots. By default, it uses the median distance between neighbor cells. 45 | ncols: Number of columns to be shown. 46 | fig_size_per_slide: Size of the figure for each slide. 47 | na_color: Color for cells that do not belong to any domain (i.e. cells whose neighborhood is too small). 48 | show: Whether to show the plot. 49 | library_id: `library_id` argument for `sc.pl.spatial`. 50 | **kwargs: Additional arguments for `sc.pl.spatial`. 51 | """ 52 | if obs_key is not None: 53 | assert str(obs_key).startswith(Keys.DOMAINS_PREFIX), f"Received {obs_key=}, which is not a valid Novae obs_key" 54 | 55 | adatas = adata if isinstance(adata, list) else [adata] 56 | slide_name_key = utils.check_slide_name_key(adatas, slide_name_key) 57 | obs_key = utils.check_available_domains_key(adatas, obs_key) 58 | 59 | for adata in adatas: 60 | sanitize_anndata(adata) 61 | 62 | all_domains, colors = get_categorical_color_palette(adatas, obs_key) 63 | cell_size = cell_size or _get_default_cell_size(adata) 64 | 65 | fig, axes = _subplots_per_slide(adatas, ncols, fig_size_per_slide) 66 | 67 | for i, adata in enumerate(utils.iter_slides(adatas)): 68 | ax = axes[i // ncols, i % ncols] 69 | slide_name = adata.obs[slide_name_key].iloc[0] 70 | assert len(np.unique(adata.obs[slide_name_key])) == 1 71 | 72 | sc.pl.spatial( 73 | adata, 74 | spot_size=cell_size, 75 | color=obs_key, 76 | ax=ax, 77 | show=False, 78 | library_id=library_id, 79 | **kwargs, 80 | ) 81 | sns.despine(ax=ax, offset=10, trim=True) 82 | ax.get_legend().remove() 83 | ax.set_title(slide_name) 84 | 85 | [fig.delaxes(ax) for ax in axes.flatten() if not ax.has_data()] # remove unused subplots 86 | 87 | title = f"Novae domains ({obs_key})" 88 | 89 | if i == 0: 90 | axes[0, 0].set_title(title) 91 | else: 92 | fig.suptitle(title, fontsize=14, y=1.15) 93 | 94 | handles = [ 95 | Line2D([0], [0], marker="o", color="w", markerfacecolor=color, markersize=8, linestyle="None") 96 | for color in [*colors, na_color] 97 | ] 98 | fig.legend( 99 | handles, 100 | [*all_domains, "NA"], 101 | loc="upper center" if i > 1 else "center left", 102 | bbox_to_anchor=(0.5, 1.1) if i > 1 else (1.04, 0.5), 103 | borderaxespad=0, 104 | frameon=False, 105 | ncol=len(colors) // (3 if i > 1 else 10) + 1, 106 | ) 107 | 108 | if show: 109 | plt.show() 110 | 111 | 112 | def spatially_variable_genes( 113 | adata: AnnData, 114 | obs_key: str | 
None = None, 115 | top_k: int = 5, 116 | cell_size: int | None = None, 117 | min_positive_ratio: float = 0.05, 118 | return_list: bool = False, 119 | show: bool = True, 120 | **kwargs: int, 121 | ) -> None | list[str]: 122 | """Plot the most spatially variable genes (SVG) for a given `AnnData` object. 123 | 124 | !!! info 125 | Currently, this function only supports one slide per call. 126 | 127 | Args: 128 | adata: An `AnnData` object corresponding to one slide. 129 | obs_key: Key in `adata.obs` that contains the domains. By default, it will use the last available Novae domain key. 130 | top_k: Number of SVG to be shown. 131 | cell_size: Size of the cells or spots (`spot_size` argument of `sc.pl.spatial`). By default, it uses the median distance between neighbor cells. 132 | min_positive_ratio: Genes whose "ratio of cells expressing it" is lower than this threshold are not considered. 133 | return_list: Whether to return the list of SVG instead of plotting them. 134 | show: Whether to show the plot. 135 | **kwargs: Additional arguments for `sc.pl.spatial`. 136 | 137 | Returns: 138 | A list of SVG names if `return_list` is `True`. 139 | """ 140 | assert isinstance(adata, AnnData), f"Received adata of type {type(adata)}. Currently only AnnData is supported." 141 | 142 | obs_key = utils.check_available_domains_key([adata], obs_key) 143 | 144 | sc.tl.rank_genes_groups(adata, groupby=obs_key) 145 | df = pd.concat( 146 | [ 147 | sc.get.rank_genes_groups_df(adata, domain).set_index("names")["logfoldchanges"] 148 | for domain in adata.obs[obs_key].cat.categories 149 | ], 150 | axis=1, 151 | ) 152 | 153 | where = (adata.X > 0).mean(0) > min_positive_ratio 154 | valid_vars = adata.var_names[where.A1 if isinstance(where, np.matrix) else where] 155 | assert len(valid_vars) >= top_k, ( 156 | f"Only {len(valid_vars)} genes are available. Please decrease `top_k` or `min_positive_ratio`." 
157 | ) 158 | 159 | svg = df.std(1).loc[valid_vars].sort_values(ascending=False).head(top_k).index 160 | 161 | if return_list: 162 | return svg.tolist() 163 | 164 | sc.pl.spatial(adata, color=svg, spot_size=cell_size or _get_default_cell_size(adata), show=show, **kwargs) 165 | -------------------------------------------------------------------------------- /novae/plot/_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import seaborn as sns 4 | from anndata import AnnData 5 | 6 | from .._constants import Keys 7 | 8 | 9 | def get_categorical_color_palette(adatas: list[AnnData], obs_key: str) -> tuple[list[str], list[str]]: 10 | key_added = f"{obs_key}_colors" 11 | 12 | all_domains = sorted(set.union(*[set(adata.obs[obs_key].cat.categories) for adata in adatas])) 13 | 14 | n_colors = len(all_domains) 15 | colors = list(sns.color_palette("tab10" if n_colors <= 10 else "tab20", n_colors=n_colors).as_hex()) 16 | for adata in adatas: 17 | adata.obs[obs_key] = adata.obs[obs_key].cat.set_categories(all_domains) 18 | adata.uns[key_added] = colors 19 | 20 | return all_domains, colors 21 | 22 | 23 | def _subplots_per_slide( 24 | adatas: list[AnnData], ncols: int, fig_size_per_slide: tuple[int, int] 25 | ) -> tuple[plt.Figure, np.ndarray]: 26 | n_slides = sum(len(adata.obs[Keys.SLIDE_ID].cat.categories) for adata in adatas) 27 | ncols = n_slides if n_slides < ncols else ncols 28 | nrows = (n_slides + ncols - 1) // ncols 29 | 30 | fig, axes = plt.subplots( 31 | nrows, ncols, figsize=(ncols * fig_size_per_slide[0], nrows * fig_size_per_slide[1]), squeeze=False 32 | ) 33 | 34 | return fig, axes 35 | 36 | 37 | def _get_default_cell_size(adata: AnnData | list[AnnData]) -> float: 38 | if isinstance(adata, list): 39 | adata = max(adata, key=lambda adata: adata.n_obs) 40 | 41 | assert Keys.ADJ in adata.obsp, ( 42 | f"Expected {Keys.ADJ} in adata.obsp. Please run `novae.spatial_neighbors(...)` first." 
43 | ) 44 | 45 | return np.median(adata.obsp[Keys.ADJ].data) 46 | -------------------------------------------------------------------------------- /novae/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from ._utils import ( 2 | repository_root, 3 | tqdm, 4 | fill_invalid_indices, 5 | lower_var_names, 6 | wandb_log_dir, 7 | pretty_num_parameters, 8 | pretty_model_repr, 9 | parse_device_args, 10 | valid_indices, 11 | unique_leaves_indices, 12 | unique_obs, 13 | sparse_std, 14 | iter_slides, 15 | ) 16 | from ._build import spatial_neighbors 17 | from ._validate import check_available_domains_key, prepare_adatas, check_has_spatial_adjancency, check_slide_name_key 18 | from ._data import load_local_dataset, load_wandb_artifact, toy_dataset, load_dataset 19 | from ._mode import Mode 20 | from ._correct import batch_effect_correction 21 | from ._preprocess import quantile_scaling 22 | -------------------------------------------------------------------------------- /novae/utils/_correct.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from anndata import AnnData 4 | 5 | from .._constants import Keys 6 | 7 | 8 | def _slides_indices(adatas: list[AnnData], only_valid_obs: bool = True) -> tuple[list[int], list[np.ndarray]]: 9 | adata_indices, slides_obs_indices = [], [] 10 | 11 | for i, adata in enumerate(adatas): 12 | slide_ids = adata.obs[Keys.SLIDE_ID].cat.categories 13 | 14 | for slide_id in slide_ids: 15 | condition = adata.obs[Keys.SLIDE_ID] == slide_id 16 | if only_valid_obs: 17 | condition = condition & adata.obs[Keys.IS_VALID_OBS] 18 | 19 | adata_indices.append(i) 20 | slides_obs_indices.append(np.where(condition)[0]) 21 | 22 | return adata_indices, slides_obs_indices 23 | 24 | 25 | def _domains_counts(adata: AnnData, i: int, obs_key: str) -> pd.DataFrame: 26 | df = adata.obs[[Keys.SLIDE_ID, obs_key]].groupby(Keys.SLIDE_ID, observed=False)[obs_key].value_counts().unstack() 27 | df[Keys.ADATA_INDEX] = i 28 | return df 29 | 30 | 31 | def _domains_counts_per_slide(adatas: list[AnnData], obs_key: str) -> pd.DataFrame: 32 | return pd.concat([_domains_counts(adata, i, obs_key) for i, adata in enumerate(adatas)], axis=0) 33 | 34 | 35 | def batch_effect_correction(adatas: list[AnnData], obs_key: str) -> None: 36 | for adata in adatas: 37 | assert obs_key in adata.obs, f"Could not find `adata.obs['{obs_key}']`" 38 | assert Keys.REPR in adata.obsm, ( 39 | f"Could not find `adata.obsm['{Keys.REPR}']`. 
Please run `model.compute_representations(...)` first" 40 | ) 41 | 42 | adata_indices, slides_obs_indices = _slides_indices(adatas) 43 | 44 | domains_counts_per_slide = _domains_counts_per_slide(adatas, obs_key) 45 | domains = domains_counts_per_slide.columns[:-1] 46 | ref_slide_ids: pd.Series = domains_counts_per_slide[domains].idxmax(axis=0) 47 | 48 | def _centroid_reference(domain: str, slide_id: str, obs_key: str): 49 | adata_ref_index: int = domains_counts_per_slide[Keys.ADATA_INDEX].loc[slide_id] 50 | adata_ref = adatas[adata_ref_index] 51 | where = (adata_ref.obs[Keys.SLIDE_ID] == slide_id) & (adata_ref.obs[obs_key] == domain) 52 | return adata_ref.obsm[Keys.REPR][where].mean(0) 53 | 54 | centroids_reference = pd.DataFrame({ 55 | domain: _centroid_reference(domain, slide_id, obs_key) for domain, slide_id in ref_slide_ids.items() 56 | }) 57 | 58 | for adata in adatas: 59 | adata.obsm[Keys.REPR_CORRECTED] = adata.obsm[Keys.REPR].copy() 60 | 61 | for adata_index, obs_indices in zip(adata_indices, slides_obs_indices): 62 | adata = adatas[adata_index] 63 | 64 | for domain in domains: 65 | if adata.obs[Keys.SLIDE_ID].iloc[obs_indices[0]] == ref_slide_ids.loc[domain]: 66 | continue # this slide is the reference for this domain 67 | 68 | indices_domain = obs_indices[adata.obs.iloc[obs_indices][obs_key] == domain] 69 | if len(indices_domain) == 0: 70 | continue 71 | 72 | centroid_reference = centroids_reference[domain].values 73 | centroid = adata.obsm[Keys.REPR][indices_domain].mean(0) 74 | 75 | adata.obsm[Keys.REPR_CORRECTED][indices_domain] += centroid_reference - centroid 76 | -------------------------------------------------------------------------------- /novae/utils/_mode.py: -------------------------------------------------------------------------------- 1 | class Mode: 2 | """Novae mode class, used to store state variables related to training and inference.""" 3 | 4 | zero_shot_clustering_attrs: list[str] = ["_clustering_zero", "_clusters_levels_zero"] 5 | normal_clustering_attrs: list[str] = ["_clustering", "_clusters_levels"] 6 | all_clustering_attrs: list[str] = normal_clustering_attrs + zero_shot_clustering_attrs 7 | 8 | def __init__(self): 9 | self.zero_shot = False 10 | self.trained = False 11 | self.pretrained = False 12 | 13 | def __repr__(self) -> str: 14 | return f"Mode({dict(self.__dict__.items())})" 15 | 16 | ### Mode modifiers 17 | 18 | def from_pretrained(self): 19 | self.zero_shot = False 20 | self.trained = True 21 | self.pretrained = True 22 | 23 | def fine_tune(self): 24 | assert self.pretrained, "Fine-tuning requires a pretrained model." 25 | self.zero_shot = False 26 | 27 | def fit(self): 28 | self.zero_shot = False 29 | self.trained = False 30 | 31 | ### Mode-specific attributes 32 | 33 | @property 34 | def clustering_attr(self): 35 | return "_clustering_zero" if self.zero_shot else "_clustering" 36 | 37 | @property 38 | def clusters_levels_attr(self): 39 | return "_clusters_levels_zero" if self.zero_shot else "_clusters_levels" 40 | -------------------------------------------------------------------------------- /novae/utils/_preprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from anndata import AnnData 4 | 5 | from .._constants import Nums 6 | from . 
import iter_slides 7 | from ._validate import _check_has_slide_id 8 | 9 | 10 | def quantile_scaling( 11 | adata: AnnData | list[AnnData], 12 | multiplier: float = 5, 13 | quantile: float = 0.2, 14 | per_slide: bool = True, 15 | ) -> None: 16 | """Preprocess fluorescence data from `adata.X` using quantiles of expression. 17 | For each column `X`, we compute `asinh(X / (multiplier * Q(quantile, X)))` and store the result back in `adata.X` (with the default arguments, this is `asinh(X / (5 * Q(0.2, X)))`). 18 | 19 | Args: 20 | adata: An `AnnData` object, or a list of `AnnData` objects. 21 | multiplier: The multiplier for the quantile. 22 | quantile: The quantile to compute. 23 | per_slide: Whether to compute the quantile per slide. If `False`, the quantile is computed for each `AnnData` object. 24 | """ 25 | _check_has_slide_id(adata) 26 | 27 | if isinstance(adata, list): 28 | for adata_ in adata: 29 | quantile_scaling(adata_, multiplier, quantile, per_slide=per_slide) 30 | return 31 | 32 | if not per_slide: 33 | return _quantile_scaling(adata, multiplier, quantile) 34 | 35 | for adata_ in iter_slides(adata): 36 | _quantile_scaling(adata_, multiplier, quantile) 37 | 38 | 39 | def _quantile_scaling(adata: AnnData, multiplier: float, quantile: float): 40 | df = adata.to_df() 41 | 42 | divider = multiplier * np.quantile(df, quantile, axis=0) 43 | divider[divider == 0] = df.max(axis=0)[divider == 0] + Nums.EPS 44 | 45 | adata.X = np.arcsinh(df / divider) 46 | -------------------------------------------------------------------------------- /novae/utils/_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import logging 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from anndata import AnnData 9 | from lightning.pytorch.trainer.connectors.accelerator_connector import ( 10 | _AcceleratorConnector, 11 | ) 12 | from scipy.sparse import csr_matrix 13 | from torch import Tensor 14 | 15 | from .._constants import Keys 16 | 17 | log = logging.getLogger(__name__) 18 | 19 | 20 | def unique_obs(adata: AnnData | list[AnnData], obs_key: str) -> set: 21 | if isinstance(adata, list): 22 | return set.union(*[unique_obs(adata_, obs_key) for adata_ in adata]) 23 | return set(adata.obs[obs_key].dropna().unique()) 24 | 25 | 26 | def unique_leaves_indices(adata: AnnData | list[AnnData]) -> np.ndarray: 27 | leaves = unique_obs(adata, Keys.LEAVES) 28 | return np.array([int(x[1:]) for x in leaves]) 29 | 30 | 31 | def valid_indices(adata: AnnData) -> np.ndarray: 32 | return np.where(adata.obs[Keys.IS_VALID_OBS])[0] 33 | 34 | 35 | def lower_var_names(var_names: pd.Index | list[str]) -> pd.Index | list[str]: 36 | if isinstance(var_names, pd.Index): 37 | return var_names.str.lower() 38 | return [name.lower() for name in var_names] 39 | 40 | 41 | def sparse_std(a: csr_matrix, axis=None) -> np.matrix: 42 | a_squared = a.multiply(a) 43 | return np.sqrt(a_squared.mean(axis) - np.square(a.mean(axis))) 44 | 45 | 46 | def fill_invalid_indices( 47 | out: np.ndarray | Tensor, 48 | n_obs: int, 49 | valid_indices: list[int], 50 | fill_value: float | str = np.nan, 51 | dtype: object = None, 52 | ) -> np.ndarray: 53 | if isinstance(out, Tensor): 54 | out = out.numpy(force=True) 55 | 56 | dtype = np.float32 if dtype is None else dtype 57 | 58 | if isinstance(fill_value, str): 59 | dtype = object 60 | 61 | res = np.full((n_obs, *out.shape[1:]), fill_value, dtype=dtype) 62 | res[valid_indices] = out 63 | return res 64 | 65 | 66 | def parse_device_args(accelerator: str = "cpu") -> torch.device: 67 | """Updated from 
scvi-tools""" 68 | connector = _AcceleratorConnector(accelerator=accelerator) 69 | _accelerator = connector._accelerator_flag 70 | _devices = connector._devices_flag 71 | 72 | if _accelerator == "cpu": 73 | return torch.device("cpu") 74 | 75 | if isinstance(_devices, list): 76 | device_idx = _devices[0] 77 | elif isinstance(_devices, str) and "," in _devices: 78 | device_idx = _devices.split(",")[0] 79 | else: 80 | device_idx = _devices 81 | 82 | return torch.device(f"{_accelerator}:{device_idx}") 83 | 84 | 85 | def repository_root() -> Path: 86 | """Get the path to the root of the repository (dev-mode users only) 87 | 88 | Returns: 89 | `novae` repository path 90 | """ 91 | path = Path(__file__).parents[2] 92 | 93 | if path.name != "novae": 94 | log.warning(f"Trying to get the novae repository path, but it seems it was not installed in dev mode: {path}") 95 | 96 | return path 97 | 98 | 99 | def wandb_log_dir() -> Path: 100 | return repository_root() / "wandb" 101 | 102 | 103 | def tqdm(*args, desc="DataLoader", **kwargs): 104 | # check if ipywidgets is installed before importing tqdm.auto 105 | # to ensure it won't fail and a progress bar is displayed 106 | if importlib.util.find_spec("ipywidgets") is not None: 107 | from tqdm.auto import tqdm as _tqdm 108 | else: 109 | from tqdm import tqdm as _tqdm 110 | 111 | return _tqdm(*args, desc=desc, **kwargs) 112 | 113 | 114 | def pretty_num_parameters(model: torch.nn.Module) -> str: 115 | n_params = sum(p.numel() for p in model.parameters()) 116 | 117 | if n_params < 1_000_000: 118 | return f"{n_params / 1_000:.1f}K" 119 | 120 | return f"{n_params / 1_000_000:.1f}M" 121 | 122 | 123 | def pretty_model_repr(info_dict: dict[str, str], model_name: str = "Novae") -> str: 124 | rows = [f"{model_name} model"] + [f"{k}: {v}" for k, v in info_dict.items()] 125 | return "\n ├── ".join(rows[:-1]) + "\n └── " + rows[-1] 126 | 127 | 128 | def iter_slides(adatas: AnnData | list[AnnData]): 129 | """Iterate over all slides. 130 | 131 | Args: 132 | adatas: One or a list of AnnData object(s). 133 | 134 | Yields: 135 | One `AnnData` per slide. 
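Example (illustrative sketch):
        >>> # assumes slide IDs were assigned beforehand, e.g. with `novae.utils.spatial_neighbors`
        >>> for adata_slide in iter_slides(adatas):
        ...     print(adata_slide.n_obs)  # one AnnData view per slide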
136 | """ 137 | if isinstance(adatas, AnnData): 138 | adatas = [adatas] 139 | 140 | for adata in adatas: 141 | slide_ids = adata.obs[Keys.SLIDE_ID].unique() 142 | 143 | if len(slide_ids) == 1: 144 | yield adata 145 | continue 146 | 147 | for slide_id in slide_ids: 148 | yield adata[adata.obs[Keys.SLIDE_ID] == slide_id] 149 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "novae" 3 | version = "0.2.4" 4 | description = "Graph-based foundation model for spatial transcriptomics data" 5 | documentation = "https://mics-lab.github.io/novae/" 6 | homepage = "https://mics-lab.github.io/novae/" 7 | repository = "https://github.com/MICS-Lab/novae" 8 | authors = ["Quentin Blampey "] 9 | packages = [{ include = "novae" }] 10 | license = "BSD-3-Clause" 11 | readme = "README.md" 12 | classifiers = [ 13 | "License :: OSI Approved :: BSD License", 14 | "Operating System :: MacOS :: MacOS X", 15 | "Operating System :: POSIX :: Linux", 16 | "Operating System :: Microsoft :: Windows", 17 | "Programming Language :: Python :: 3", 18 | "Topic :: Scientific/Engineering", 19 | ] 20 | 21 | [tool.poetry.dependencies] 22 | python = ">=3.10,<3.13" 23 | scanpy = ">=1.9.8" 24 | lightning = ">=2.2.1" 25 | torch = ">=2.2.1" 26 | torch-geometric = ">=2.5.2" 27 | huggingface-hub = ">=0.24.0" 28 | safetensors = ">=0.4.3" 29 | pandas = ">=2.0.0" 30 | 31 | ruff = { version = ">=0.11.4", optional = true } 32 | mypy = { version = ">=1.15.0", optional = true } 33 | pre-commit = { version = ">=3.8.0", optional = true } 34 | pytest = { version = ">=7.1.3", optional = true } 35 | pytest-cov = { version = ">=5.0.0", optional = true } 36 | wandb = { version = ">=0.17.2", optional = true } 37 | pyyaml = { version = ">=6.0.1", optional = true } 38 | pydantic = { version = ">=2.8.2", optional = true } 39 | ipykernel = { version = ">=6.22.0", optional = true } 40 | ipywidgets = { version = ">=8.1.2", optional = true } 41 | mkdocs-material = { version = ">=8.5.6", optional = true } 42 | mkdocs-jupyter = { version = ">=0.21.0", optional = true } 43 | mkdocstrings = { version = ">=0.19.0", optional = true } 44 | mkdocstrings-python = { version = ">=0.7.1", optional = true } 45 | 46 | [tool.poetry.extras] 47 | dev = [ 48 | "ruff", 49 | "mypy", 50 | "pre-commit", 51 | "pytest", 52 | "pytest-cov", 53 | "wandb", 54 | "pyyaml", 55 | "pydantic", 56 | "ipykernel", 57 | "ipywidgets", 58 | "mkdocs-material", 59 | "mkdocs-jupyter", 60 | "mkdocstrings", 61 | "mkdocstrings-python", 62 | ] 63 | 64 | [build-system] 65 | requires = ["poetry-core>=1.0.0"] 66 | build-backend = "poetry.core.masonry.api" 67 | 68 | [tool.mypy] 69 | files = ["novae"] 70 | no_implicit_optional = true 71 | check_untyped_defs = true 72 | warn_return_any = true 73 | warn_unused_ignores = true 74 | show_error_codes = true 75 | ignore_missing_imports = true 76 | 77 | [tool.pytest.ini_options] 78 | testpaths = ["tests"] 79 | python_files = "test_*.py" 80 | 81 | 82 | [tool.ruff] 83 | target-version = "py39" 84 | line-length = 120 85 | fix = true 86 | 87 | [tool.ruff.lint] 88 | select = [ 89 | # flake8-2020 90 | "YTT", 91 | # flake8-bandit 92 | "S", 93 | # flake8-bugbear 94 | "B", 95 | # flake8-builtins 96 | "A", 97 | # flake8-comprehensions 98 | "C4", 99 | # flake8-debugger 100 | "T10", 101 | # flake8-simplify 102 | "SIM", 103 | # isort 104 | "I", 105 | # mccabe 106 | "C90", 107 | # pycodestyle 108 | "E", 109 | "W", 110 | # 
pyflakes 111 | "F", 112 | # pygrep-hooks 113 | "PGH", 114 | # pyupgrade 115 | "UP", 116 | # ruff 117 | "RUF", 118 | # tryceratops 119 | "TRY", 120 | ] 121 | ignore = [ 122 | # LineTooLong 123 | "E501", 124 | # DoNotAssignLambda 125 | "E731", 126 | # DoNotUseAssert 127 | "S101", 128 | "TRY003", 129 | "RUF012", 130 | "B904", 131 | "E722", 132 | ] 133 | 134 | [tool.ruff.lint.per-file-ignores] 135 | "tests/*" = ["S101"] 136 | "__init__.py" = ["F401", "I001"] 137 | "*.ipynb" = ["F401"] 138 | 139 | [tool.ruff.format] 140 | preview = true 141 | 142 | [tool.coverage.report] 143 | skip_empty = true 144 | 145 | [tool.coverage.run] 146 | source = ["novae"] 147 | omit = ["**/test_*.py", "novae/monitor/log.py", "novae/monitor/callback.py"] 148 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # Pre-training scripts 2 | 3 | These scripts are used to pretrain and monitor Novae with Weights & Biases. To see the actual source code of Novae, refer to the `novae` directory. 4 | 5 | ## Setup 6 | 7 | For monitoring, `novae` must be installed with the `dev` extra, for instance via pip: 8 | 9 | ```sh 10 | pip install -e '.[dev]' 11 | ``` 12 | 13 | ## Usage 14 | 15 | In the `data` directory, make sure you have `.h5ad` files. You can use the download scripts to get public data. 16 | The corresponding `AnnData` objects should contain raw counts, or counts preprocessed with `normalize_total` and `log1p`. 17 | 18 | ### Normal training 19 | 20 | Choose a config inside the `config` directory. 21 | 22 | ```sh 23 | python -m scripts.train --config <config_name>.yaml 24 | ``` 25 | 26 | ### Sweep training 27 | 28 | Choose a sweep config inside the `sweep` directory. 29 | 30 | Inside the `scripts` directory, initialize the sweep with: 31 | ```sh 32 | wandb sweep --project novae sweep/<sweep_name>.yaml 33 | ``` 34 | 35 | Run the sweep with: 36 | ```sh 37 | wandb agent <SWEEP_ID> --count 1 38 | ``` 39 | 40 | ### Slurm usage 41 | 42 | ⚠️ Warning: the scripts below are specific to one cluster (Flamingo, Gustave Roussy). You'll need to update the `.sh` scripts according to your cluster. 
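For instance, once the `#SBATCH` headers (partition, memory, GPU) are adapted to your own cluster, a submission might look like this (hypothetical invocation, shown only as a sketch):

```sh
sbatch ruche/gpu.sh all_16.yaml        # trains with scripts/config/all_16.yaml on a GPU node
sbatch ruche/cpu.sh local_tests.yaml   # CPU variant, using the small test config
```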
43 | 44 | In the `ruche` directory: 45 | - `gpu.sh` / `cpu.sh` for training 46 | - `download.sh` to download public data 47 | - `sbatch agent.sh SWEEP_ID COUNT` to run agents (where SWEEP_ID comes from `wandb sweep --project novae sweep/<sweep_name>.yaml`) 48 | 49 | E.g., on ruche: 50 | ```sh 51 | module load anaconda3/2024.06/gcc-13.2.0 && source activate novae 52 | wandb sweep --project novae sweep/gpu_ruche.yaml 53 | 54 | cd ruche 55 | sbatch agent.sh SWEEP_ID 56 | ``` 57 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MICS-Lab/novae/b641a0bd1947759317c9e7ab4d997bb7a4a00932/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class DataConfig(BaseModel): 5 | train_dataset: str = "all" 6 | val_dataset: str | None = None 7 | 8 | files_black_list: list[str] = [] 9 | 10 | 11 | class PostTrainingConfig(BaseModel): 12 | n_domains: list[int] = [7, 10] 13 | log_umap: bool = False 14 | log_metrics: bool = False 15 | log_domains: bool = False 16 | save_h5ad: bool = False 17 | delete_X: bool = False 18 | 19 | 20 | class Config(BaseModel): 21 | project: str = "novae" 22 | wandb_artefact: str | None = None 23 | zero_shot: bool = False 24 | sweep: bool = False 25 | seed: int = 0 26 | 27 | data: DataConfig = DataConfig() 28 | post_training: PostTrainingConfig = PostTrainingConfig() 29 | 30 | model_kwargs: dict = {} 31 | fit_kwargs: dict = {} 32 | wandb_init_kwargs: dict = {} 33 | -------------------------------------------------------------------------------- /scripts/config/README.md: -------------------------------------------------------------------------------- 1 | # Config 2 | 3 | These `.yaml` files are used when training with Weights & Biases via the `train.py` script. 4 | 5 | ## Description 6 | 7 | This is a minimal YAML config used to explain its structure. The dataset names are relative to the `data` directory. If a directory is given, every `.h5ad` file inside it is loaded; a single file or a file pattern can also be used. 
8 | 9 | ```yaml 10 | data: 11 | train_dataset: merscope # training dataset name 12 | val_dataset: xenium # eval dataset name 13 | 14 | model_kwargs: # Novae model kwargs 15 | heads: 4 16 | 17 | fit_kwargs: # Trainer kwargs (from Lightning) 18 | max_epochs: 3 19 | log_every_n_steps: 10 20 | accelerator: "cpu" 21 | 22 | wandb_init_kwargs: # wandb.init kwargs 23 | mode: "online" 24 | ``` 25 | -------------------------------------------------------------------------------- /scripts/config/_example.yaml: -------------------------------------------------------------------------------- 1 | project: novae 2 | wandb_artefact: novae/novae/xxx 3 | 4 | post_training: 5 | n_domains: [7, 10] 6 | log_umap: true 7 | save_h5ad: true 8 | 9 | data: 10 | train_dataset: all 11 | val_dataset: false 12 | 13 | model_kwargs: 14 | scgpt_model_dir: /gpfs/workdir/blampeyq/checkpoints/scgpt/scGPT_human 15 | n_hops_view: 3 16 | heads: 16 17 | hidden_size: 128 18 | 19 | fit_kwargs: 20 | max_epochs: 30 21 | lr: 0.0001 22 | accelerator: gpu 23 | num_workers: 8 24 | 25 | wandb_init_kwargs: 26 | disabled: true 27 | -------------------------------------------------------------------------------- /scripts/config/all_16.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_dataset: all 3 | val_dataset: igr/202305031337_hBreast-slide-B-4h-photobleach_VMSC09302 4 | 5 | model_kwargs: 6 | scgpt_model_dir: /gpfs/workdir/blampeyq/checkpoints/scgpt/scGPT_human 7 | n_hops_view: 3 8 | heads: 32 9 | hidden_size: 128 10 | temperature: 0.1 11 | num_prototypes: 256 12 | background_noise_lambda: 5 13 | panel_subset_size: 0.8 14 | min_prototypes_ratio: 0.15 15 | 16 | fit_kwargs: 17 | max_epochs: 30 18 | lr: 0.0002 19 | accelerator: "gpu" 20 | num_workers: 8 21 | patience: 5 22 | min_delta: 0.025 23 | 24 | post_training: 25 | n_domains: [15, 20] 26 | log_metrics: true 27 | save_h5ad: true 28 | log_umap: true 29 | log_domains: true 30 | delete_X: true 31 | -------------------------------------------------------------------------------- /scripts/config/all_17.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_dataset: all 3 | val_dataset: igr/202305031337_hBreast-slide-B-4h-photobleach_VMSC09302 4 | 5 | model_kwargs: 6 | scgpt_model_dir: /gpfs/workdir/blampeyq/checkpoints/scgpt/scGPT_human 7 | n_hops_view: 3 8 | heads: 16 9 | hidden_size: 128 10 | temperature: 0.1 11 | num_prototypes: 64 12 | background_noise_lambda: 5 13 | panel_subset_size: 0.8 14 | min_prototypes_ratio: 0.15 15 | 16 | fit_kwargs: 17 | max_epochs: 30 18 | lr: 0.0002 19 | accelerator: "gpu" 20 | num_workers: 8 21 | patience: 5 22 | min_delta: 0.025 23 | 24 | post_training: 25 | n_domains: [15, 20] 26 | log_metrics: true 27 | save_h5ad: false 28 | log_umap: false 29 | log_domains: true 30 | delete_X: true 31 | -------------------------------------------------------------------------------- /scripts/config/all_brain.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_dataset: /gpfs/workdir/shared/prime/spatial/brain 3 | val_dataset: igr/202305031337_hBreast-slide-B-4h-photobleach_VMSC09302 4 | 5 | model_kwargs: 6 | scgpt_model_dir: /gpfs/workdir/blampeyq/checkpoints/scgpt/scGPT_brain 7 | n_hops_view: 3 8 | heads: 16 9 | hidden_size: 128 10 | temperature: 0.1 11 | num_prototypes: 512 12 | background_noise_lambda: 5 13 | panel_subset_size: 0.8 14 | min_prototypes_ratio: 0.4 15 | 16 | fit_kwargs: 17 | max_epochs: 30 18 | 
lr: 0.0001 19 | accelerator: "gpu" 20 | num_workers: 8 21 | patience: 6 22 | min_delta: 0.025 23 | 24 | post_training: 25 | n_domains: [15, 20, 25] 26 | log_metrics: true 27 | save_h5ad: true 28 | log_umap: true 29 | log_domains: true 30 | delete_X: true 31 | -------------------------------------------------------------------------------- /scripts/config/all_human.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_dataset: /gpfs/workdir/shared/prime/spatial/human 3 | val_dataset: igr/202305031337_hBreast-slide-B-4h-photobleach_VMSC09302 4 | 5 | model_kwargs: 6 | scgpt_model_dir: /gpfs/workdir/blampeyq/checkpoints/scgpt/scGPT_human 7 | n_hops_view: 3 8 | heads: 16 9 | hidden_size: 128 10 | temperature: 0.1 11 | num_prototypes: 512 12 | background_noise_lambda: 5 13 | panel_subset_size: 0.8 14 | min_prototypes_ratio: 0.5 15 | 16 | fit_kwargs: 17 | max_epochs: 30 18 | lr: 0.0001 19 | accelerator: "gpu" 20 | num_workers: 8 21 | patience: 6 22 | min_delta: 0.025 23 | 24 | post_training: 25 | n_domains: [15, 20, 25] 26 | log_metrics: true 27 | save_h5ad: true 28 | log_umap: true 29 | log_domains: true 30 | delete_X: true 31 | -------------------------------------------------------------------------------- /scripts/config/all_human2.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_dataset: /gpfs/workdir/shared/prime/spatial/human 3 | val_dataset: igr/202305031337_hBreast-slide-B-4h-photobleach_VMSC09302 4 | 5 | model_kwargs: 6 | scgpt_model_dir: /gpfs/workdir/blampeyq/checkpoints/scgpt/scGPT_human 7 | n_hops_view: 2 8 | heads: 16 9 | hidden_size: 128 10 | temperature: 0.1 11 | num_prototypes: 512 12 | background_noise_lambda: 5 13 | panel_subset_size: 0.8 14 | min_prototypes_ratio: 0.75 15 | 16 | fit_kwargs: 17 | max_epochs: 30 18 | lr: 0.0001 19 | accelerator: "gpu" 20 | num_workers: 8 21 | patience: 6 22 | min_delta: 0.025 23 | 24 | post_training: 25 | n_domains: [15, 20, 25] 26 | log_metrics: true 27 | save_h5ad: true 28 | log_umap: true 29 | log_domains: true 30 | delete_X: true 31 | -------------------------------------------------------------------------------- /scripts/config/all_mouse.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_dataset: /gpfs/workdir/shared/prime/spatial/mouse 3 | val_dataset: igr/202305031337_hBreast-slide-B-4h-photobleach_VMSC09302 4 | 5 | model_kwargs: 6 | scgpt_model_dir: /gpfs/workdir/blampeyq/checkpoints/scgpt/scGPT_human 7 | n_hops_view: 3 8 | heads: 16 9 | hidden_size: 128 10 | temperature: 0.1 11 | num_prototypes: 512 12 | background_noise_lambda: 5 13 | panel_subset_size: 0.8 14 | min_prototypes_ratio: 0.4 15 | 16 | fit_kwargs: 17 | max_epochs: 30 18 | lr: 0.0001 19 | accelerator: "gpu" 20 | num_workers: 8 21 | patience: 6 22 | min_delta: 0.025 23 | 24 | post_training: 25 | n_domains: [15, 20, 25] 26 | log_metrics: true 27 | save_h5ad: true 28 | log_umap: true 29 | log_domains: true 30 | delete_X: true 31 | -------------------------------------------------------------------------------- /scripts/config/all_new.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_dataset: all 3 | val_dataset: igr/202305031337_hBreast-slide-B-4h-photobleach_VMSC09302 4 | 5 | model_kwargs: 6 | scgpt_model_dir: /gpfs/workdir/blampeyq/checkpoints/scgpt/scGPT_human 7 | n_hops_view: 3 8 | heads: 16 9 | hidden_size: 128 10 | temperature: 0.1 11 | 
num_prototypes: 512 12 | background_noise_lambda: 5 13 | panel_subset_size: 0.8 14 | min_prototypes_ratio: 0.15 15 | 16 | fit_kwargs: 17 | max_epochs: 30 18 | lr: 0.0001 19 | accelerator: "gpu" 20 | num_workers: 8 21 | patience: 6 22 | min_delta: 0.025 23 | 24 | post_training: 25 | n_domains: [15, 20, 25] 26 | log_metrics: true 27 | save_h5ad: true 28 | log_umap: true 29 | log_domains: true 30 | delete_X: true 31 | -------------------------------------------------------------------------------- /scripts/config/all_ruche.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_dataset: all 3 | val_dataset: igr/202305031337_hBreast-slide-B-4h-photobleach_VMSC09302 4 | 5 | model_kwargs: 6 | scgpt_model_dir: /gpfs/workdir/blampeyq/checkpoints/scgpt/scGPT_human 7 | n_hops_view: 3 8 | heads: 16 9 | num_prototypes: 512 10 | panel_subset_size: 0.8 11 | temperature: 0.1 12 | 13 | fit_kwargs: 14 | max_epochs: 50 15 | accelerator: "gpu" 16 | num_workers: 8 17 | patience: 3 18 | min_delta: 0.05 19 | -------------------------------------------------------------------------------- /scripts/config/all_spot.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_dataset: all 3 | val_dataset: igr/202305031337_hBreast-slide-B-4h-photobleach_VMSC09302 4 | 5 | model_kwargs: 6 | scgpt_model_dir: /gpfs/workdir/blampeyq/checkpoints/scgpt/scGPT_human 7 | n_hops_view: 2 8 | heads: 8 9 | hidden_size: 128 10 | temperature: 0.1 11 | n_hops_local: 1 12 | num_prototypes: 256 13 | background_noise_lambda: 5 14 | panel_subset_size: 0.8 15 | min_prototypes_ratio: 0.15 16 | 17 | fit_kwargs: 18 | max_epochs: 30 19 | lr: 0.0002 20 | accelerator: "gpu" 21 | num_workers: 8 22 | patience: 5 23 | min_delta: 0.025 24 | 25 | post_training: 26 | n_domains: [15, 20] 27 | log_metrics: true 28 | save_h5ad: true 29 | log_umap: true 30 | log_domains: true 31 | delete_X: true 32 | -------------------------------------------------------------------------------- /scripts/config/brain.yaml: -------------------------------------------------------------------------------- 1 | project: novae_eval 2 | wandb_artefact: novae/novae/model-y8i2da7y:v30 3 | 4 | data: 5 | train_dataset: brain 6 | 7 | fit_kwargs: 8 | max_epochs: 10 9 | accelerator: "gpu" 10 | num_workers: 8 11 | patience: 3 12 | min_delta: 0.05 13 | 14 | post_training: 15 | n_domains: [7, 10, 15] 16 | log_metrics: true 17 | save_h5ad: true 18 | log_umap: true 19 | log_domains: true 20 | -------------------------------------------------------------------------------- /scripts/config/brain2.yaml: -------------------------------------------------------------------------------- 1 | project: novae_eval 2 | wandb_artefact: novae/novae/model-y8i2da7y:v30 3 | 4 | data: 5 | train_dataset: brain2 6 | 7 | fit_kwargs: 8 | max_epochs: 10 9 | accelerator: "gpu" 10 | num_workers: 8 11 | patience: 3 12 | min_delta: 0.05 13 | 14 | post_training: 15 | n_domains: [7, 10, 15] 16 | log_metrics: true 17 | save_h5ad: true 18 | log_umap: true 19 | log_domains: true 20 | -------------------------------------------------------------------------------- /scripts/config/breast.yaml: -------------------------------------------------------------------------------- 1 | project: novae_eval 2 | wandb_artefact: novae/novae/model-y8i2da7y:v30 3 | 4 | data: 5 | train_dataset: breast 6 | 7 | fit_kwargs: 8 | max_epochs: 10 9 | accelerator: "gpu" 10 | num_workers: 8 11 | patience: 3 12 | min_delta: 0.05 13 | 14 | 
post_training: 15 | n_domains: [7, 10, 15] 16 | log_metrics: true 17 | save_h5ad: true 18 | log_umap: true 19 | log_domains: true 20 | -------------------------------------------------------------------------------- /scripts/config/breast_zs.yaml: -------------------------------------------------------------------------------- 1 | project: novae_eval 2 | wandb_artefact: novae/novae/model-y8i2da7y:v30 3 | zero_shot: true 4 | 5 | data: 6 | train_dataset: breast 7 | 8 | post_training: 9 | n_domains: [7, 10, 15] 10 | log_metrics: true 11 | save_h5ad: true 12 | log_umap: true 13 | log_domains: true 14 | -------------------------------------------------------------------------------- /scripts/config/breast_zs2.yaml: -------------------------------------------------------------------------------- 1 | project: novae_eval 2 | wandb_artefact: novae/novae/model-y4f53mpp:v30 3 | zero_shot: true 4 | 5 | data: 6 | train_dataset: breast 7 | 8 | post_training: 9 | n_domains: [7, 10, 15] 10 | log_metrics: true 11 | save_h5ad: true 12 | log_umap: true 13 | log_domains: true 14 | -------------------------------------------------------------------------------- /scripts/config/colon.yaml: -------------------------------------------------------------------------------- 1 | project: novae_eval 2 | wandb_artefact: novae/novae/model-y8i2da7y:v30 3 | 4 | data: 5 | train_dataset: colon 6 | 7 | fit_kwargs: 8 | max_epochs: 15 9 | accelerator: "gpu" 10 | num_workers: 8 11 | patience: 3 12 | min_delta: 0.05 13 | 14 | post_training: 15 | n_domains: [7, 10, 15] 16 | log_metrics: true 17 | save_h5ad: true 18 | log_umap: true 19 | log_domains: true 20 | delete_X: true 21 | -------------------------------------------------------------------------------- /scripts/config/colon_retrain.yaml: -------------------------------------------------------------------------------- 1 | project: novae_eval 2 | 3 | model_kwargs: 4 | scgpt_model_dir: /gpfs/workdir/blampeyq/checkpoints/scgpt/scGPT_human 5 | n_hops_view: 3 6 | heads: 16 7 | hidden_size: 128 8 | temperature: 0.1 9 | num_prototypes: 512 10 | background_noise_lambda: 5 11 | panel_subset_size: 0.8 12 | 13 | data: 14 | train_dataset: colon 15 | 16 | fit_kwargs: 17 | max_epochs: 20 18 | accelerator: "gpu" 19 | num_workers: 8 20 | patience: 3 21 | min_delta: 0.05 22 | 23 | post_training: 24 | n_domains: [7, 10, 15] 25 | log_metrics: true 26 | save_h5ad: true 27 | log_umap: true 28 | log_domains: true 29 | delete_X: true 30 | -------------------------------------------------------------------------------- /scripts/config/colon_zs.yaml: -------------------------------------------------------------------------------- 1 | project: novae_eval 2 | wandb_artefact: novae/novae/model-y8i2da7y:v30 3 | zero_shot: true 4 | 5 | data: 6 | train_dataset: colon 7 | 8 | post_training: 9 | n_domains: [7, 10, 15] 10 | log_metrics: true 11 | save_h5ad: true 12 | log_umap: true 13 | log_domains: true 14 | delete_X: true 15 | -------------------------------------------------------------------------------- /scripts/config/colon_zs2.yaml: -------------------------------------------------------------------------------- 1 | project: novae_eval 2 | wandb_artefact: novae/novae/model-y4f53mpp:v30 3 | zero_shot: true 4 | 5 | data: 6 | train_dataset: colon 7 | 8 | post_training: 9 | n_domains: [7, 10, 15] 10 | log_metrics: true 11 | save_h5ad: true 12 | log_umap: true 13 | log_domains: true 14 | delete_X: true 15 | -------------------------------------------------------------------------------- 
/scripts/config/local_tests.yaml: -------------------------------------------------------------------------------- 1 | project: novae_tests 2 | 3 | data: 4 | train_dataset: toy_train 5 | val_dataset: toy_val 6 | 7 | model_kwargs: 8 | embedding_size: 50 9 | 10 | fit_kwargs: 11 | max_epochs: 2 12 | accelerator: "cpu" 13 | 14 | post_training: 15 | n_domains: [2, 3] 16 | log_umap: true 17 | log_domains: true 18 | -------------------------------------------------------------------------------- /scripts/config/lymph_node.yaml: -------------------------------------------------------------------------------- 1 | project: novae_eval 2 | wandb_artefact: novae/novae/model-y8i2da7y:v30 3 | 4 | data: 5 | train_dataset: lymph_node 6 | 7 | fit_kwargs: 8 | max_epochs: 10 9 | accelerator: "gpu" 10 | num_workers: 8 11 | patience: 3 12 | min_delta: 0.05 13 | 14 | post_training: 15 | n_domains: [7, 10, 15] 16 | log_metrics: true 17 | save_h5ad: true 18 | log_umap: true 19 | log_domains: true 20 | -------------------------------------------------------------------------------- /scripts/config/missing.yaml: -------------------------------------------------------------------------------- 1 | project: novae 2 | seed: 0 3 | 4 | data: 5 | train_dataset: toy_ari/v2 6 | 7 | model_kwargs: 8 | n_hops_view: 2 9 | n_hops_local: 2 10 | 11 | fit_kwargs: 12 | max_epochs: 40 13 | accelerator: "gpu" 14 | num_workers: 4 15 | -------------------------------------------------------------------------------- /scripts/config/ovarian.yaml: -------------------------------------------------------------------------------- 1 | project: novae_eval 2 | wandb_artefact: novae/novae/model-y8i2da7y:v30 3 | 4 | data: 5 | train_dataset: ovarian 6 | 7 | fit_kwargs: 8 | max_epochs: 10 9 | accelerator: "gpu" 10 | num_workers: 8 11 | patience: 3 12 | min_delta: 0.05 13 | 14 | post_training: 15 | n_domains: [7, 10, 15] 16 | log_metrics: true 17 | save_h5ad: true 18 | log_umap: true 19 | log_domains: true 20 | -------------------------------------------------------------------------------- /scripts/config/revision.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_dataset: /gpfs/workdir/blampeyq/novae/data/_hyper 3 | 4 | model_kwargs: 5 | scgpt_model_dir: /gpfs/workdir/blampeyq/checkpoints/scgpt/scGPT_human 6 | n_hops_view: 2 7 | background_noise_lambda: 5 8 | panel_subset_size: 0.8 9 | min_prototypes_ratio: 0.6 10 | 11 | fit_kwargs: 12 | max_epochs: 40 13 | accelerator: "gpu" 14 | num_workers: 8 15 | patience: 5 16 | min_delta: 0.015 17 | 18 | post_training: 19 | n_domains: [8, 10, 12] 20 | log_metrics: true 21 | log_domains: true 22 | -------------------------------------------------------------------------------- /scripts/config/revision_tests.yaml: -------------------------------------------------------------------------------- 1 | project: novae_tests 2 | 3 | data: 4 | train_dataset: toy_train 5 | 6 | model_kwargs: 7 | n_hops_view: 2 8 | background_noise_lambda: 5 9 | panel_subset_size: 0.8 10 | min_prototypes_ratio: 0.6 11 | 12 | fit_kwargs: 13 | max_epochs: 1 14 | accelerator: "cpu" 15 | patience: 4 16 | min_delta: 0.025 17 | 18 | post_training: 19 | n_domains: [8, 10, 12] 20 | log_metrics: true 21 | log_domains: true 22 | -------------------------------------------------------------------------------- /scripts/config/toy_cpu_seed0.yaml: -------------------------------------------------------------------------------- 1 | project: novae_eval 2 | seed: 0 3 | 4 | data: 5 | train_dataset: 
toy_ari/v2 6 | 7 | model_kwargs: 8 | n_hops_view: 1 9 | n_hops_local: 1 10 | heads: 16 11 | hidden_size: 128 12 | temperature: 0.1 13 | num_prototypes: 256 14 | background_noise_lambda: 5 15 | panel_subset_size: 0.8 16 | min_prototypes_ratio: 1 17 | 18 | fit_kwargs: 19 | max_epochs: 30 20 | lr: 0.0005 21 | patience: 6 22 | min_delta: 0.025 23 | 24 | post_training: 25 | n_domains: [7, 10, 15] 26 | log_metrics: true 27 | save_h5ad: true 28 | log_umap: true 29 | log_domains: true 30 | delete_X: true 31 | -------------------------------------------------------------------------------- /scripts/config/toy_missing.yaml: -------------------------------------------------------------------------------- 1 | project: novae 2 | seed: 0 3 | 4 | data: 5 | train_dataset: toy_ari/v2 6 | 7 | model_kwargs: 8 | n_hops_view: 1 9 | n_hops_local: 1 10 | 11 | fit_kwargs: 12 | max_epochs: 30 13 | accelerator: "gpu" 14 | num_workers: 4 15 | -------------------------------------------------------------------------------- /scripts/missing_domain.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import scanpy as sc 5 | import wandb 6 | 7 | import novae 8 | from novae.monitor.log import log_plt_figure 9 | 10 | from .utils import init_wandb_logger, read_config 11 | 12 | 13 | def main(args: argparse.Namespace) -> None: 14 | config = read_config(args) 15 | 16 | path = Path("/gpfs/workdir/blampeyq/novae/data/_lung_robustness") 17 | adata1_split = sc.read_h5ad(path / "v1_split.h5ad") 18 | adata2_full = sc.read_h5ad(path / "v2_full.h5ad") 19 | adatas = [adata1_split, adata2_full] 20 | 21 | logger = init_wandb_logger(config) 22 | 23 | model = novae.Novae(adatas, **config.model_kwargs) 24 | model.fit(logger=logger, **config.fit_kwargs) 25 | 26 | model.compute_representations(adatas) 27 | obs_key = model.assign_domains(adatas, level=7) 28 | 29 | novae.plot.domains(adatas, obs_key=obs_key, show=False) 30 | log_plt_figure(f"domains_{obs_key}") 31 | 32 | adata2_split = adata2_full[adata2_full.obsm["spatial"][:, 0] < 5000].copy() 33 | jsd = novae.monitor.jensen_shannon_divergence([adata1_split, adata2_split], obs_key=obs_key) 34 | 35 | wandb.log({"metrics/jsd": jsd}) 36 | 37 | 38 | if __name__ == "__main__": 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument( 41 | "-c", 42 | "--config", 43 | type=str, 44 | required=True, 45 | help="Fullname of the YAML config to be used for training (see under the `config` directory)", 46 | ) 47 | parser.add_argument( 48 | "-s", 49 | "--sweep", 50 | nargs="?", 51 | default=False, 52 | const=True, 53 | help="Whether it is a sweep or not", 54 | ) 55 | 56 | main(parser.parse_args()) 57 | -------------------------------------------------------------------------------- /scripts/revision/cpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=novae 3 | #SBATCH --output=/gpfs/workdir/blampeyq/.jobs_outputs/%j 4 | #SBATCH --mem=80G 5 | #SBATCH --cpus-per-task=8 6 | #SBATCH --partition=cpu_long 7 | 8 | module purge 9 | module load anaconda3/2022.10/gcc-11.2.0 && source activate novae 10 | 11 | cd /gpfs/workdir/blampeyq/novae/scripts/revision 12 | 13 | # Execute training 14 | python -u $1 15 | -------------------------------------------------------------------------------- /scripts/revision/heterogeneous.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 
import scanpy as sc 3 | 4 | import novae 5 | from novae._constants import Nums 6 | 7 | Nums.QUEUE_WEIGHT_THRESHOLD_RATIO = 0.9999999 8 | Nums.WARMUP_EPOCHS = 4 9 | 10 | suffix = "_constants_fit_all9" 11 | 12 | dir_name = "/gpfs/workdir/blampeyq/novae/data/_heterogeneous" 13 | 14 | adatas = [ 15 | sc.read_h5ad(f"{dir_name}/Xenium_V1_Human_Colon_Cancer_P2_CRC_Add_on_FFPE_outs.h5ad"), 16 | # sc.read_h5ad(f"{dir_name}/Xenium_V1_Human_Brain_GBM_FFPE_outs.h5ad"), 17 | sc.read_h5ad(f"{dir_name}/Xenium_V1_hLymphNode_nondiseased_section_outs.h5ad"), 18 | ] 19 | 20 | adatas[0].uns["novae_tissue"] = "colon" 21 | # adatas[1].uns["novae_tissue"] = "brain" 22 | adatas[1].uns["novae_tissue"] = "lymph_node" 23 | 24 | novae.utils.spatial_neighbors(adatas, radius=80) 25 | 26 | model = novae.Novae( 27 | adatas, 28 | scgpt_model_dir="/gpfs/workdir/blampeyq/checkpoints/scgpt/scGPT_human", 29 | min_prototypes_ratio=0.3, 30 | heads=16, 31 | hidden_size=128, 32 | temperature=0.1, 33 | num_prototypes=1024, 34 | background_noise_lambda=5, 35 | panel_subset_size=0.8, 36 | # num_prototypes=512, 37 | # temperature=0.5, 38 | ) 39 | model.fit(max_epochs=30) 40 | model.compute_representations() 41 | 42 | # model = novae.Novae.from_pretrained("MICS-Lab/novae-human-0") 43 | # model.fine_tune(adatas, min_prototypes_ratio=0.25, reference="all") 44 | # model.compute_representations(adatas) 45 | 46 | for level in range(7, 15): 47 | model.assign_domains(level=level) 48 | 49 | model.plot_prototype_weights() 50 | plt.savefig(f"{dir_name}/prototype_weights{suffix}.pdf", bbox_inches="tight") 51 | model.plot_prototype_weights(assign_zeros=False) 52 | plt.savefig(f"{dir_name}/prototype_weights{suffix}_nz.pdf", bbox_inches="tight") 53 | 54 | for i, adata in enumerate(adatas): 55 | del adata.X 56 | for key in list(adata.layers.keys()): 57 | del adata.layers[key] 58 | adata.write_h5ad(f"{dir_name}/{i}_res{suffix}.h5ad") 59 | -------------------------------------------------------------------------------- /scripts/revision/heterogeneous_start.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import scanpy as sc 3 | 4 | import novae 5 | from novae._constants import Nums 6 | 7 | Nums.QUEUE_WEIGHT_THRESHOLD_RATIO = 0.99 8 | 9 | suffix = "" 10 | 11 | dir_name = "/gpfs/workdir/blampeyq/novae/data/_heterogeneous" 12 | 13 | adatas = [ 14 | sc.read_h5ad(f"{dir_name}/Xenium_V1_Human_Colon_Cancer_P2_CRC_Add_on_FFPE_outs.h5ad"), 15 | # sc.read_h5ad(f"{dir_name}/Xenium_V1_Human_Brain_GBM_FFPE_outs.h5ad"), 16 | sc.read_h5ad(f"{dir_name}/Xenium_V1_hLymphNode_nondiseased_section_outs.h5ad"), 17 | ] 18 | 19 | adatas[0].uns["novae_tissue"] = "colon" 20 | # adatas[1].uns["novae_tissue"] = "brain" 21 | adatas[1].uns["novae_tissue"] = "lymph_node" 22 | 23 | for adata in adatas: 24 | adata.obs["novae_tissue"] = adata.uns["novae_tissue"] 25 | 26 | novae.utils.spatial_neighbors(adatas, radius=80) 27 | 28 | # model = novae.Novae(adatas) 29 | # model.mode.trained = True 30 | # model.compute_representations(adatas) 31 | # model.assign_domains(adatas) 32 | 33 | # adata = sc.concat(adatas, join="inner") 34 | # adata = sc.pp.subsample(adata, n_obs=100_000, copy=True) 35 | # sc.pp.neighbors(adata, use_rep="novae_latent") 36 | # sc.tl.umap(adata) 37 | # sc.pl.umap(adata, color=["novae_domains_7", "novae_tissue"]) 38 | # plt.savefig(f"{dir_name}/umap_start{suffix}.png", bbox_inches="tight") 39 | 40 | model = novae.Novae( 41 | adatas, 42 | 
scgpt_model_dir="/gpfs/workdir/blampeyq/checkpoints/scgpt/scGPT_human", 43 | ) 44 | model.mode.trained = True 45 | model.compute_representations(adatas) 46 | model.assign_domains(adatas) 47 | 48 | adata = sc.concat(adatas, join="inner") 49 | adata = sc.pp.subsample(adata, n_obs=100_000, copy=True) 50 | sc.pp.neighbors(adata, use_rep="novae_latent") 51 | sc.tl.umap(adata) 52 | sc.pl.umap(adata, color=["novae_domains_7", "novae_tissue"]) 53 | plt.savefig(f"{dir_name}/umap_start_scgpt{suffix}.png", bbox_inches="tight") 54 | -------------------------------------------------------------------------------- /scripts/revision/mgc.py: -------------------------------------------------------------------------------- 1 | import anndata 2 | import matplotlib.pyplot as plt 3 | 4 | import novae 5 | 6 | adata = anndata.read_h5ad("/gpfs/workdir/blampeyq/novae/data/_mgc/MGC_merged_adata_clean_graph.h5ad") 7 | 8 | novae.utils.quantile_scaling(adata) 9 | 10 | model = novae.Novae(adata, embedding_size=62) # 63 proteins 11 | 12 | model.fit() 13 | 14 | model.compute_representations() 15 | 16 | model.plot_domains_hierarchy(max_level=16) 17 | plt.savefig("/gpfs/workdir/blampeyq/novae/data/_mgc/domains_hierarchy.pdf", bbox_inches="tight") 18 | 19 | for level in range(7, 15): 20 | model.assign_domains(level=level) 21 | 22 | adata.write_h5ad("/gpfs/workdir/blampeyq/novae/data/_mgc/MGC_merged_adata_clean_graph_domains.h5ad") 23 | -------------------------------------------------------------------------------- /scripts/revision/missing_domains.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import anndata 4 | import matplotlib.pyplot as plt 5 | import scanpy as sc 6 | 7 | import novae 8 | from novae._constants import Nums 9 | 10 | Nums.WARMUP_EPOCHS = 1 11 | # Nums.SWAV_EPSILON = 0.01 12 | 13 | suffix = "_numproto_5" 14 | 15 | path = Path("/gpfs/workdir/blampeyq/novae/data/_lung_robustness") 16 | 17 | adata1_split = sc.read_h5ad(path / "v1_split.h5ad") 18 | adata2_full = sc.read_h5ad(path / "v2_full.h5ad") 19 | # adata2_split = sc.read_h5ad(path / "v2_split.h5ad") 20 | 21 | adata_join = anndata.concat([adata1_split, adata2_full], join="inner") 22 | # adatas = [adata1_split, adata2_full] 23 | 24 | novae.spatial_neighbors(adata_join, slide_key="slide_id", radius=80) 25 | 26 | # shared_genes = adata1_split.var_names.intersection(adata2_full.var_names) 27 | # adata1_split = adata1_split[:, shared_genes].copy() 28 | # adata2_full = adata2_full[:, shared_genes].copy() 29 | # adatas = [adata1_split, adata2_full] 30 | 31 | model = novae.Novae( 32 | adata_join, 33 | num_prototypes=2000, 34 | heads=8, 35 | hidden_size=128, 36 | min_prototypes_ratio=0.5, 37 | ) 38 | model.fit() 39 | model.compute_representations() 40 | 41 | # model = novae.Novae.from_pretrained("MICS-Lab/novae-human-0") 42 | # model.fine_tune(adatas, min_prototypes_ratio=0.5, reference="largest") 43 | # model.compute_representations(adatas) 44 | 45 | obs_key = model.assign_domains(adata_join, resolution=1) 46 | for res in [0.3, 0.35, 0.4, 0.45, 0.5]: 47 | obs_key = model.assign_domains(adata_join, resolution=res) 48 | obs_key = model.assign_domains(adata_join, level=7) 49 | 50 | model.plot_prototype_weights() 51 | plt.savefig(path / f"prototype_weights{suffix}.pdf", bbox_inches="tight") 52 | 53 | # model.umap_prototypes() 54 | # plt.savefig(path / f"umap_prototypes{suffix}.png", bbox_inches="tight") 55 | 56 | adatas = [adata_join[adata_join.obs["novae_sid"] == ID] for ID in 
adata_join.obs["novae_sid"].unique()] 57 | names = ["v1_split", "v2_full"] 58 | 59 | ### Save 60 | 61 | for adata, name in zip(adatas, names): 62 | adata.obs[f"{obs_key}_split_ft"] = adata.obs[obs_key] 63 | del adata.X 64 | for key in list(adata.layers.keys()): 65 | del adata.layers[key] 66 | adata.write_h5ad(path / f"{name}_res2{suffix}.h5ad") 67 | -------------------------------------------------------------------------------- /scripts/revision/perturbations.py: -------------------------------------------------------------------------------- 1 | import anndata 2 | import torch 3 | from torch_geometric.utils import scatter, softmax 4 | 5 | import novae 6 | from novae._constants import Nums 7 | 8 | adatas = [ 9 | anndata.read_h5ad("/gpfs/workdir/blampeyq/novae/data/_perturbation/HumanBreastCancerPatient1_region_0.h5ad"), 10 | anndata.read_h5ad( 11 | "/gpfs/workdir/blampeyq/novae/data/_perturbation/Xenium_V1_Human_Colon_Cancer_P2_CRC_Add_on_FFPE_outs.h5ad" 12 | ), 13 | ] 14 | 15 | model = novae.Novae.from_pretrained("MICS-Lab/novae-human-0") 16 | 17 | novae.utils.spatial_neighbors(adatas, radius=80) 18 | 19 | # Zero-shot 20 | print("Zero-shot") 21 | model.compute_representations(adatas, zero_shot=True) 22 | 23 | for level in range(7, 15): 24 | model.assign_domains(adatas, level=level) 25 | 26 | for adata in adatas: 27 | adata.obsm["novae_latent_normal"] = adata.obsm["novae_latent"] 28 | 29 | # Attention heterogeneity 30 | print("Attention heterogeneity") 31 | for adata in adatas: 32 | model._datamodule = model._init_datamodule(adata) 33 | 34 | with torch.no_grad(): 35 | _entropies = torch.tensor([], dtype=torch.float32) 36 | gat = model.encoder.gnn 37 | 38 | for data_batch in model.datamodule.predict_dataloader(): 39 | averaged_attentions_list = [] 40 | 41 | data = data_batch["main"] 42 | data = model._embed_pyg_data(data) 43 | 44 | x = data.x 45 | 46 | for i, conv in enumerate(gat.convs): 47 | x, (index, attentions) = conv( 48 | x, data.edge_index, edge_attr=data.edge_attr, return_attention_weights=True 49 | ) 50 | averaged_attentions = scatter(attentions.mean(1), index[0], dim_size=len(data.x), reduce="mean") 51 | averaged_attentions_list.append(averaged_attentions) 52 | if i < gat.num_layers - 1: 53 | x = gat.act(x) 54 | 55 | attention_scores = torch.stack(averaged_attentions_list).mean(0) 56 | attention_scores = softmax(attention_scores / 0.01, data.batch, dim=0) 57 | attention_entropy = scatter(-attention_scores * (attention_scores + Nums.EPS).log2(), index=data.batch) 58 | _entropies = torch.cat([_entropies, attention_entropy]) 59 | 60 | adata.obs["attention_entropies"] = 0.0 61 | adata.obs.loc[adata.obs["neighborhood_valid"], "attention_entropies"] = _entropies.numpy(force=True) 62 | 63 | # Shuffle nodes 64 | print("Shuffle nodes") 65 | novae.settings.shuffle_nodes = True 66 | model.compute_representations(adatas) 67 | 68 | for adata in adatas: 69 | adata.obs["rs_shuffle"] = novae.utils.get_relative_sensitivity(adata, "novae_latent_normal", "novae_latent") 70 | novae.settings.shuffle_nodes = False 71 | 72 | # Edge length drop 73 | print("Edge length drop") 74 | for adata in adatas: 75 | adata.obsp["spatial_distances"].data[:] = 0.01 76 | 77 | model.compute_representations(adatas) 78 | 79 | for adata in adatas: 80 | adata.obs["rs_edge_length"] = novae.utils.get_relative_sensitivity(adata, "novae_latent_normal", "novae_latent") 81 | 82 | # Saving results 83 | 84 | for adata in adatas: 85 | del adata.X 86 | for key in list(adata.layers.keys()): 87 | del adata.layers[key] 88 | 89 | 
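# X and all layers were deleted above to keep the saved .h5ad files small; only the obs/obsm outputs are persisted below.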
adatas[0].write_h5ad( 90 | "/gpfs/workdir/blampeyq/novae/data/_perturbation/HumanBreastCancerPatient1_region_0_perturbed.h5ad" 91 | ) 92 | adatas[1].write_h5ad( 93 | "/gpfs/workdir/blampeyq/novae/data/_perturbation/Xenium_V1_Human_Colon_Cancer_P2_CRC_Add_on_FFPE_outs_perturbed.h5ad" 94 | ) 95 | -------------------------------------------------------------------------------- /scripts/revision/seg_robustness.py: -------------------------------------------------------------------------------- 1 | import anndata 2 | 3 | import novae 4 | 5 | adatas = [ 6 | anndata.read_h5ad("/gpfs/workdir/blampeyq/novae/data/_colon_seg/adata_graph.h5ad"), 7 | anndata.read_h5ad("/gpfs/workdir/blampeyq/novae/data/_colon_seg/adata_default_graph.h5ad"), 8 | ] 9 | 10 | suffix = "_2" 11 | 12 | model = novae.Novae( 13 | adatas, 14 | num_prototypes=3000, 15 | heads=8, 16 | hidden_size=128, 17 | min_prototypes_ratio=1, 18 | ) 19 | model.fit() 20 | model.compute_representations() 21 | 22 | # model = novae.Novae.from_pretrained("MICS-Lab/novae-human-0") 23 | # model.fine_tune(adatas) 24 | # model.compute_representations(adatas) 25 | 26 | for level in range(7, 15): 27 | model.assign_domains(adatas, level=level) 28 | 29 | adatas[0].write_h5ad(f"/gpfs/workdir/blampeyq/novae/data/_colon_seg/adata_graph_domains{suffix}.h5ad") 30 | adatas[1].write_h5ad(f"/gpfs/workdir/blampeyq/novae/data/_colon_seg/adata_default_graph_domains{suffix}.h5ad") 31 | -------------------------------------------------------------------------------- /scripts/ruche/README.md: -------------------------------------------------------------------------------- 1 | # Scripts for training (rûche HPC) 2 | -------------------------------------------------------------------------------- /scripts/ruche/agent.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=novae 3 | #SBATCH --output=/gpfs/workdir/blampeyq/.jobs_outputs/%j 4 | #SBATCH --time=16:00:00 5 | #SBATCH --partition=gpu 6 | #SBATCH --mem=160G 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH --gres=gpu:1 9 | 10 | module purge 11 | module load anaconda3/2022.10/gcc-11.2.0 12 | source activate novae-gpu 13 | 14 | cd /gpfs/workdir/blampeyq/novae 15 | 16 | # Get config 17 | SWEEP_ID=${1} 18 | AGENT_COUNT=${2:-1} 19 | echo "Running $AGENT_COUNT sequential agent(s) for SWEEP_ID=$SWEEP_ID" 20 | 21 | WANDB__SERVICE_WAIT=300 22 | export WANDB__SERVICE_WAIT 23 | 24 | # Run one agent 25 | wandb agent $SWEEP_ID --count $AGENT_COUNT 26 | -------------------------------------------------------------------------------- /scripts/ruche/convert.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=novae 3 | #SBATCH --output=/gpfs/workdir/blampeyq/.jobs_outputs/%j 4 | #SBATCH --mem=40G 5 | #SBATCH --cpus-per-task=8 6 | #SBATCH --partition=cpu_med 7 | 8 | module purge 9 | module load anaconda3/2022.10/gcc-11.2.0 && source activate spatial 10 | 11 | cd /gpfs/workdir/blampeyq/novae/data 12 | 13 | python -u xenium_convert.py 14 | -------------------------------------------------------------------------------- /scripts/ruche/cpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=novae 3 | #SBATCH --output=/gpfs/workdir/blampeyq/.jobs_outputs/%j 4 | #SBATCH --mem=80G 5 | #SBATCH --cpus-per-task=8 6 | #SBATCH --partition=cpu_long 7 | 8 | module purge 9 | module load anaconda3/2022.10/gcc-11.2.0 && source activate novae 10 | 11 
| cd /gpfs/workdir/blampeyq/novae 12 | 13 | # Get config 14 | DEFAULT_CONFIG=swav_cpu_0.yaml 15 | CONFIG=${1:-$DEFAULT_CONFIG} 16 | echo Running with CONFIG=$CONFIG 17 | 18 | WANDB__SERVICE_WAIT=300 19 | export WANDB__SERVICE_WAIT 20 | 21 | # Execute training 22 | python -u -m scripts.train --config $CONFIG 23 | -------------------------------------------------------------------------------- /scripts/ruche/debug_cpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=novae 3 | #SBATCH --output=/gpfs/workdir/blampeyq/.jobs_outputs/%j 4 | #SBATCH --time=00:30:00 5 | #SBATCH --partition=cpu_short 6 | #SBATCH --mem=20G 7 | #SBATCH --cpus-per-task=8 8 | 9 | module purge 10 | module load anaconda3/2022.10/gcc-11.2.0 11 | source activate novae 12 | 13 | cd /gpfs/workdir/blampeyq/novae 14 | 15 | # Get config 16 | DEFAULT_CONFIG=debug_cpu.yaml 17 | CONFIG=${1:-$DEFAULT_CONFIG} 18 | echo Running with CONFIG=$CONFIG 19 | 20 | WANDB__SERVICE_WAIT=300 21 | export WANDB__SERVICE_WAIT 22 | 23 | # Execute training (run as a module: train.py uses relative imports, so `python train.py` would fail) 24 | CUDA_LAUNCH_BLOCKING=1 python -u -m scripts.train --config $CONFIG 25 | -------------------------------------------------------------------------------- /scripts/ruche/debug_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=novae 3 | #SBATCH --output=/gpfs/workdir/blampeyq/.jobs_outputs/%j 4 | #SBATCH --time=00:30:00 5 | #SBATCH --partition=gpu 6 | #SBATCH --mem=32gb 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH --gres=gpu:1 9 | 10 | module purge 11 | module load anaconda3/2022.10/gcc-11.2.0 12 | source activate novae-gpu 13 | 14 | cd /gpfs/workdir/blampeyq/novae 15 | 16 | # Get config 17 | DEFAULT_CONFIG=debug_gpu.yaml 18 | CONFIG=${1:-$DEFAULT_CONFIG} 19 | echo Running with CONFIG=$CONFIG 20 | 21 | WANDB__SERVICE_WAIT=300 22 | export WANDB__SERVICE_WAIT 23 | 24 | # Execute training (run as a module: train.py uses relative imports, so `python train.py` would fail) 25 | CUDA_LAUNCH_BLOCKING=1 python -u -m scripts.train --config $CONFIG 26 | -------------------------------------------------------------------------------- /scripts/ruche/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=novae 3 | #SBATCH --output=/gpfs/workdir/blampeyq/.jobs_outputs/%j 4 | #SBATCH --mem=16G 5 | #SBATCH --time=01:00:00 6 | #SBATCH --cpus-per-task=1 7 | #SBATCH --partition=cpu_med 8 | 9 | module purge 10 | module load anaconda3/2022.10/gcc-11.2.0 11 | source activate novae 12 | 13 | cd /gpfs/workdir/blampeyq/novae/data 14 | 15 | # download all MERSCOPE datasets 16 | sh _scripts/merscope_download.sh 17 | 18 | # convert all datasets to h5ad files 19 | python _scripts/merscope_convert.py 20 | -------------------------------------------------------------------------------- /scripts/ruche/gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=novae 3 | #SBATCH --output=/gpfs/workdir/blampeyq/.jobs_outputs/%j 4 | #SBATCH --time=24:00:00 5 | #SBATCH --partition=gpu 6 | #SBATCH --mem=300G 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH --gres=gpu:1 9 | 10 | module purge 11 | module load anaconda3/2022.10/gcc-11.2.0 12 | source activate novae-gpu 13 | 14 | cd /gpfs/workdir/blampeyq/novae 15 | 16 | # Get config 17 | DEFAULT_CONFIG=all_6.yaml 18 | CONFIG=${1:-$DEFAULT_CONFIG} 19 | echo Running with CONFIG=$CONFIG 20 | 21 | WANDB__SERVICE_WAIT=300 22 | export WANDB__SERVICE_WAIT 23 | 24 | # Execute training 25 |
python -u -m scripts.train --config $CONFIG 26 | -------------------------------------------------------------------------------- /scripts/ruche/prepare.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=novae 3 | #SBATCH --output=/gpfs/workdir/blampeyq/.jobs_outputs/%j 4 | #SBATCH --time=16:00:00 5 | #SBATCH --partition=gpu 6 | #SBATCH --mem=160G 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH --gres=gpu:1 9 | 10 | module purge 11 | module load anaconda3/2022.10/gcc-11.2.0 && source activate novae 12 | 13 | cd /gpfs/workdir/blampeyq/novae/data 14 | 15 | python -u _scripts/2_prepare.py -n all3 -d merscope -u 16 | -------------------------------------------------------------------------------- /scripts/ruche/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=novae 3 | #SBATCH --output=/gpfs/workdir/blampeyq/.jobs_outputs/%j 4 | #SBATCH --mem=16G 5 | #SBATCH --cpus-per-task=4 6 | #SBATCH --partition=cpu_short 7 | 8 | module purge 9 | cd /gpfs/workdir/blampeyq/novae/data 10 | 11 | 12 | OUTPUT_DIR="./xenium" 13 | 14 | mkdir -p $OUTPUT_DIR 15 | 16 | /usr/bin/unzip -v 17 | -------------------------------------------------------------------------------- /scripts/sweep/README.md: -------------------------------------------------------------------------------- 1 | ## Wandb sweep configuration 2 | 3 | This is the structure of a wandb config file: 4 | 5 | ```yaml 6 | method: random # search strategy (grid/random/bayes) 7 | 8 | metric: 9 | goal: minimize 10 | name: train/loss_epoch # the metric to be optimized 11 | 12 | parameters: # parameters to hyperoptimize (kwargs arguments of Novae) 13 | heads: 14 | values: [1, 4, 8] 15 | lr: 16 | min: 0.0001 17 | max: 0.01 18 | 19 | command: 20 | - python 21 | - -m 22 | - scripts.train 23 | - --config 24 | - swav_cpu_0.yaml # name of the config under the scripts/config directory 25 | - --sweep 26 | ``` 27 | -------------------------------------------------------------------------------- /scripts/sweep/cpu.yaml: -------------------------------------------------------------------------------- 1 | method: random 2 | metric: 3 | goal: maximize 4 | name: metrics/val_mean_fide_score 5 | parameters: 6 | embedding_size: 7 | values: [32, 64, 128] 8 | heads: 9 | values: [8, 16] 10 | num_prototypes: 11 | values: [1024, 2048] 12 | batch_size: 13 | values: [512, 1024, 2048] 14 | output_size: 15 | values: [64, 128, 256] 16 | num_layers: 17 | values: [8, 12, 16] 18 | hidden_size: 19 | values: [64, 128, 256] 20 | panel_subset_size: 21 | min: 0.1 22 | max: 0.4 23 | gene_expression_dropout: 24 | min: 0.1 25 | max: 0.3 26 | background_noise_lambda: 27 | min: 2 28 | max: 8 29 | sensitivity_noise_std: 30 | min: 0.01 31 | max: 0.1 32 | lr: 33 | min: 0.0001 34 | max: 0.001 35 | 36 | command: 37 | - python 38 | - -m 39 | - scripts.train 40 | - --config 41 | - swav_cpu_0.yaml 42 | - --sweep 43 | -------------------------------------------------------------------------------- /scripts/sweep/debug.yaml: -------------------------------------------------------------------------------- 1 | method: random 2 | metric: 3 | goal: minimize 4 | name: train/loss_epoch 5 | parameters: 6 | heads: 7 | values: [1, 4, 8] 8 | lr: 9 | min: 0.0001 10 | max: 0.01 11 | 12 | command: 13 | - python 14 | - -m 15 | - scripts.train 16 | - --config 17 | - debug_online.yaml 18 | - --sweep 19 | --------------------------------------------------------------------------------
/scripts/sweep/gpu.yaml: -------------------------------------------------------------------------------- 1 | method: random 2 | metric: 3 | goal: maximize 4 | name: metrics/val_mean_fide_score 5 | parameters: 6 | n_hops_view: 7 | values: [2, 3] 8 | heads: 9 | values: [8, 16] 10 | num_prototypes: 11 | values: [256, 512, 1024] 12 | output_size: 13 | values: [128, 256, 512] 14 | num_layers: 15 | values: [8, 12, 16] 16 | hidden_size: 17 | values: [128, 256, 512] 18 | panel_subset_size: 19 | min: 0.4 20 | max: 0.9 21 | background_noise_lambda: 22 | min: 2 23 | max: 8 24 | sensitivity_noise_std: 25 | min: 0.01 26 | max: 0.1 27 | lr: 28 | min: 0.0001 29 | max: 0.001 30 | temperature: 31 | min: 0.05 32 | max: 0.2 33 | 34 | command: 35 | - python 36 | - -m 37 | - scripts.train 38 | - --config 39 | - all.yaml 40 | - --sweep 41 | -------------------------------------------------------------------------------- /scripts/sweep/gpu_ruche.yaml: -------------------------------------------------------------------------------- 1 | method: random 2 | metric: 3 | goal: maximize 4 | name: metrics/val_heuristic 5 | parameters: 6 | output_size: 7 | values: [64, 128] 8 | num_layers: 9 | values: [10, 14, 18] 10 | hidden_size: 11 | values: [64, 128] 12 | background_noise_lambda: 13 | min: 3 14 | max: 5 15 | sensitivity_noise_std: 16 | min: 0.08 17 | max: 0.16 18 | lr: 19 | values: [0.00005, 0.0001, 0.0002, 0.0004] 20 | 21 | command: 22 | - python 23 | - -m 24 | - scripts.train 25 | - --config 26 | - all_ruche.yaml 27 | - --sweep 28 | -------------------------------------------------------------------------------- /scripts/sweep/lung.yaml: -------------------------------------------------------------------------------- 1 | method: bayes 2 | metric: 3 | goal: minimize 4 | name: metrics/jsd 5 | parameters: 6 | num_prototypes: 7 | min: 1000 8 | max: 8000 9 | temperature: 10 | min: 0.05 11 | max: 0.20 12 | lr: 13 | min: 0.00005 14 | max: 0.001 15 | min_delta: 16 | values: [ 0.001, 0.005, 0.01, 0.05, 0.1 ] 17 | SWAV_EPSILON: 18 | min: 0.005 19 | max: 0.05 20 | QUEUE_SIZE: 21 | values: [ 2, 4, 8 ] 22 | WARMUP_EPOCHS: 23 | values: [ 0, 2, 4, 8 ] 24 | 25 | command: 26 | - python 27 | - -m 28 | - scripts.missing_domain 29 | - --config 30 | - missing.yaml 31 | - --sweep 32 | -------------------------------------------------------------------------------- /scripts/sweep/revision.yaml: -------------------------------------------------------------------------------- 1 | method: random 2 | metric: 3 | goal: maximize 4 | name: metrics/train_heuristic_8_domains 5 | parameters: 6 | heads: 7 | values: [2, 4, 8, 16] 8 | num_layers: 9 | values: [2, 6, 9, 12, 15] 10 | num_prototypes: 11 | values: [16, 64, 128, 256, 512, 1024] 12 | batch_size: 13 | values: [64, 128, 256, 512, 1024] 14 | temperature: 15 | values: [0.025, 0.05, 0.1, 0.2, 0.4] 16 | lr: 17 | values: [0.00005, 0.0001, 0.0002, 0.0005, 0.001] 18 | 19 | command: 20 | - python 21 | - -m 22 | - scripts.train 23 | - --config 24 | - revision.yaml 25 | - --sweep 26 | -------------------------------------------------------------------------------- /scripts/sweep/toy.yaml: -------------------------------------------------------------------------------- 1 | method: bayes 2 | metric: 3 | goal: maximize 4 | name: metrics/score 5 | parameters: 6 | num_prototypes: 7 | values: [ 1_000, 3_000, 8_000 ] 8 | temperature: 9 | min: 0.1 10 | max: 0.2 11 | lr: 12 | values: [ 0.0001, 0.0005, 0.001 ] 13 | min_delta: 14 | values: [ 0.001, 0.005, 0.01, 0.05, 0.1 ] 15 | SWAV_EPSILON: 16 | min: 0.01 17 | 
max: 0.05 18 | 19 | command: 20 | - python 21 | - -m 22 | - scripts.toy_missing_domain 23 | - --config 24 | - toy_missing.yaml 25 | - --sweep 26 | -------------------------------------------------------------------------------- /scripts/toy_missing_domain.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import wandb 4 | from sklearn.metrics.cluster import adjusted_rand_score 5 | 6 | import novae 7 | from novae.monitor.log import log_plt_figure 8 | 9 | from .utils import init_wandb_logger, read_config 10 | 11 | 12 | def main(args: argparse.Namespace) -> None: 13 | config = read_config(args) 14 | 15 | adatas = novae.utils.toy_dataset(n_panels=2, xmax=2_000, n_domains=7, n_vars=300) 16 | 17 | adatas[1] = adatas[1][adatas[1].obs["domain"] != "domain_6", :].copy() 18 | novae.spatial_neighbors(adatas) 19 | 20 | logger = init_wandb_logger(config) 21 | 22 | model = novae.Novae(adatas, **config.model_kwargs) 23 | model.fit(logger=logger, **config.fit_kwargs) 24 | 25 | model.compute_representations(adatas) 26 | obs_key = model.assign_domains(adatas, level=7) 27 | 28 | novae.plot.domains(adatas, obs_key=obs_key, show=False) 29 | log_plt_figure(f"domains_{obs_key}") 30 | 31 | ari = adjusted_rand_score(adatas[0].obs[obs_key], adatas[0].obs["domain"]) 32 | 33 | adatas[0] = adatas[0][adatas[0].obs["domain"] != "domain_6", :].copy() 34 | accuracy = (adatas[0].obs[obs_key].values.astype(str) == adatas[1].obs[obs_key].values.astype(str)).mean() 35 | 36 | wandb.log({"metrics/score": ari * accuracy, "metrics/ari": ari, "metrics/accuracy": accuracy}) 37 | 38 | 39 | if __name__ == "__main__": 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument( 42 | "-c", 43 | "--config", 44 | type=str, 45 | required=True, 46 | help="Fullname of the YAML config to be used for training (see under the `config` directory)", 47 | ) 48 | parser.add_argument( 49 | "-s", 50 | "--sweep", 51 | nargs="?", 52 | default=False, 53 | const=True, 54 | help="Whether it is a sweep or not", 55 | ) 56 | 57 | main(parser.parse_args()) 58 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Novae model training with Weights & Biases monitoring 3 | This is **not** the actual Novae source code.
Instead, see the `novae` directory 4 | """ 5 | 6 | import argparse 7 | 8 | import novae 9 | 10 | from .utils import get_callbacks, init_wandb_logger, post_training, read_config 11 | 12 | 13 | def main(args: argparse.Namespace) -> None: 14 | config = read_config(args) 15 | 16 | adatas = novae.utils.load_local_dataset(config.data.train_dataset, files_black_list=config.data.files_black_list) 17 | adatas_val = novae.utils.load_local_dataset(config.data.val_dataset) if config.data.val_dataset else None 18 | 19 | logger = init_wandb_logger(config) 20 | callbacks = get_callbacks(config, adatas_val) 21 | 22 | if config.wandb_artefact is not None: 23 | model = novae.Novae._load_wandb_artifact(config.wandb_artefact) 24 | 25 | if not config.zero_shot: 26 | model.fine_tune(adatas, logger=logger, callbacks=callbacks, **config.fit_kwargs) 27 | else: 28 | model = novae.Novae(adatas, **config.model_kwargs) 29 | model.fit(logger=logger, callbacks=callbacks, **config.fit_kwargs) 30 | 31 | post_training(model, adatas, config) 32 | 33 | 34 | if __name__ == "__main__": 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument( 37 | "-c", 38 | "--config", 39 | type=str, 40 | required=True, 41 | help="Fullname of the YAML config to be used for training (see under the `config` directory)", 42 | ) 43 | parser.add_argument( 44 | "-s", 45 | "--sweep", 46 | nargs="?", 47 | default=False, 48 | const=True, 49 | help="Whether it is a sweep or not", 50 | ) 51 | 52 | main(parser.parse_args()) 53 | -------------------------------------------------------------------------------- /scripts/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import lightning as L 4 | import numpy as np 5 | import pandas as pd 6 | import scanpy as sc 7 | import wandb 8 | import yaml 9 | from anndata import AnnData 10 | from lightning.pytorch.callbacks import ModelCheckpoint 11 | from lightning.pytorch.loggers import WandbLogger 12 | 13 | import novae 14 | from novae import log 15 | from novae._constants import Keys, Nums 16 | from novae.monitor import ( 17 | jensen_shannon_divergence, 18 | mean_fide_score, 19 | mean_normalized_entropy, 20 | ) 21 | from novae.monitor.callback import ( 22 | LogProtoCovCallback, 23 | LogTissuePrototypeWeights, 24 | ValidationCallback, 25 | ) 26 | from novae.monitor.log import log_plt_figure, wandb_results_dir 27 | 28 | from .config import Config 29 | 30 | 31 | def init_wandb_logger(config: Config) -> WandbLogger: 32 | wandb.init(project=config.project, **config.wandb_init_kwargs) 33 | 34 | if config.sweep: 35 | sweep_config = dict(wandb.config) 36 | 37 | fit_kwargs_args = ["lr", "patience", "min_delta"] 38 | 39 | for arg in fit_kwargs_args: 40 | if arg in sweep_config: 41 | config.fit_kwargs[arg] = sweep_config[arg] 42 | del sweep_config[arg] 43 | for key, value in sweep_config.items(): 44 | if hasattr(Nums, key): 45 | log.info(f"Nums.{key} = {value}") 46 | setattr(Nums, key, value) 47 | else: 48 | config.model_kwargs[key] = value 49 | 50 | log.info(f"Full config:\n{config.model_dump()}") 51 | 52 | assert "slide_key" not in config.model_kwargs, "For now, please provide one adata per file." 
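    # Illustration of the routing above (hypothetical sweep values): with wandb.config == {"lr": 1e-3, "QUEUE_SIZE": 4, "heads": 8},
    # "lr" is moved into config.fit_kwargs, "QUEUE_SIZE" overrides the corresponding constant on Nums,
    # and "heads" falls through to config.model_kwargs, i.e. a keyword argument of the Novae model.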
53 | 54 | wandb_logger = WandbLogger( 55 | save_dir=novae.utils.wandb_log_dir(), 56 | log_model="all", 57 | project=config.project, 58 | ) 59 | 60 | config_flat = pd.json_normalize(config.model_dump(), sep=".").to_dict(orient="records")[0] 61 | wandb_logger.experiment.config.update(config_flat) 62 | 63 | return wandb_logger 64 | 65 | 66 | def _get_hardware_kwargs(config: Config) -> dict: 67 | return { 68 | "num_workers": config.fit_kwargs.get("num_workers", 0), 69 | "accelerator": config.fit_kwargs.get("accelerator", "cpu"), 70 | } 71 | 72 | 73 | def get_callbacks(config: Config, adatas_val: list[AnnData] | None) -> list[L.Callback] | None: 74 | if config.wandb_init_kwargs.get("mode") == "disabled": 75 | return None 76 | 77 | validation_callback = [ValidationCallback(adatas_val, **_get_hardware_kwargs(config))] if adatas_val else [] 78 | 79 | if config.sweep: 80 | return validation_callback 81 | 82 | return [ 83 | *validation_callback, 84 | ModelCheckpoint(monitor="metrics/val_heuristic", mode="max", save_last=True, save_top_k=1), 85 | LogProtoCovCallback(), 86 | LogTissuePrototypeWeights(), 87 | ] 88 | 89 | 90 | def read_config(args: argparse.Namespace) -> Config: 91 | with open(novae.utils.repository_root() / "scripts" / "config" / args.config) as f: 92 | config = yaml.safe_load(f) 93 | config = Config(**config, sweep=args.sweep) 94 | 95 | log.info(f"Using seed {config.seed}") 96 | L.seed_everything(config.seed) 97 | 98 | return config 99 | 100 | 101 | def post_training(model: novae.Novae, adatas: list[AnnData], config: Config): # noqa: C901 102 | wandb.log({"num_parameters": sum(p.numel() for p in model.parameters())}) 103 | 104 | keys_repr = ["log_umap", "log_metrics", "log_domains"] 105 | if any(getattr(config.post_training, key) for key in keys_repr): 106 | model.compute_representations(adatas, **_get_hardware_kwargs(config), zero_shot=config.zero_shot) 107 | for n_domains in config.post_training.n_domains: 108 | model.assign_domains(adatas, n_domains=n_domains) 109 | 110 | if config.post_training.log_domains: 111 | for n_domains in config.post_training.n_domains: 112 | obs_key = model.assign_domains(adatas, n_domains=n_domains) 113 | novae.plot.domains(adatas, obs_key, show=False) 114 | log_plt_figure(f"domains_{n_domains=}") 115 | 116 | if config.post_training.log_metrics: 117 | for n_domains in config.post_training.n_domains: 118 | obs_key = model.assign_domains(adatas, n_domains=n_domains) 119 | jsd = jensen_shannon_divergence(adatas, obs_key) 120 | fide = mean_fide_score(adatas, obs_key, n_classes=n_domains) 121 | mne = mean_normalized_entropy(adatas, n_classes=n_domains, obs_key=obs_key) 122 | log.info(f"[{n_domains=}] JSD: {jsd}, FIDE: {fide}, MNE: {mne}") 123 | wandb.log({ 124 | f"metrics/jsd_{n_domains}_domains": jsd, 125 | f"metrics/fide_{n_domains}_domains": fide, 126 | f"metrics/mne_{n_domains}_domains": mne, 127 | f"metrics/train_heuristic_{n_domains}_domains": fide * mne, 128 | }) 129 | 130 | if config.post_training.log_umap: 131 | _log_umap(model, adatas, config) 132 | 133 | if config.post_training.save_h5ad: 134 | for adata in adatas: 135 | if config.post_training.delete_X: 136 | del adata.X 137 | if "counts" in adata.layers: 138 | del adata.layers["counts"] 139 | _save_h5ad(adata) 140 | 141 | 142 | def _log_umap(model: novae.Novae, adatas: list[AnnData], config: Config, n_obs_th: int = 500_000): 143 | for adata in adatas: 144 | if "novae_tissue" in adata.uns: 145 | adata.obs["tissue"] = adata.uns["novae_tissue"] 146 | 147 | for n_domains in config.post_training.n_domains:
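        # For each requested granularity below: assign domains at that level, correct slide-level batch effects,
        # then UMAP the corrected representations concatenated over all slides.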
148 | obs_key = model.assign_domains(adatas, n_domains=n_domains) 149 | model.batch_effect_correction(adatas, obs_key=obs_key) 150 | 151 | latent_conc = np.concatenate([adata.obsm[Keys.REPR_CORRECTED] for adata in adatas], axis=0) 152 | obs_conc = pd.concat([adata.obs for adata in adatas], axis=0, join="inner") 153 | adata_conc = AnnData(obsm={Keys.REPR_CORRECTED: latent_conc}, obs=obs_conc) 154 | 155 | if "cell_id" in adata_conc.obs: 156 | del adata_conc.obs["cell_id"] # can't be saved for some reason 157 | _save_h5ad(adata_conc, "adata_conc") 158 | 159 | if adata_conc.n_obs > n_obs_th: 160 | sc.pp.subsample(adata_conc, n_obs=n_obs_th) 161 | 162 | sc.pp.neighbors(adata_conc, use_rep=Keys.REPR_CORRECTED) 163 | sc.tl.umap(adata_conc) 164 | 165 | colors = [obs_key] 166 | for key in ["tissue", "technology"]: 167 | if key in adata_conc.obs: 168 | colors.append(key) 169 | 170 | sc.pl.umap(adata_conc, color=colors, show=False) 171 | log_plt_figure(f"umap_{n_domains=}") 172 | 173 | 174 | def _save_h5ad(adata: AnnData, stem: str | None = None): 175 | if stem is None: 176 | stem = adata.obs["slide_id"].iloc[0] if "slide_id" in adata.obs else str(id(adata)) 177 | 178 | out_path = wandb_results_dir() / f"{stem}.h5ad" 179 | log.info(f"Writing adata file to {out_path}: {adata}") 180 | adata.write_h5ad(out_path) 181 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | if __name__ == "__main__": 4 | setup() 5 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MICS-Lab/novae/b641a0bd1947759317c9e7ab4d997bb7a4a00932/tests/__init__.py -------------------------------------------------------------------------------- /tests/_utils.py: -------------------------------------------------------------------------------- 1 | import anndata 2 | import numpy as np 3 | import pandas as pd 4 | from anndata import AnnData 5 | 6 | import novae 7 | 8 | domains = [ 9 | ["D1", "D2", "D3", "D4", "D5"], 10 | ["D1", "D1", "D2", "D2", "D3"], 11 | ["D1", "D1", "D1", "D2", np.nan], 12 | ] 13 | 14 | spatial_coords = np.array([ 15 | [[0, 0], [0, 1], [0, 2], [3, 3], [1, 1]], 16 | [[0, 0], [0, 1], [0, 2], [3, 3], [1, 1]], 17 | [[0, 0], [0, 1], [0, 2], [3, 3], [1, 1]], 18 | ]) 19 | 20 | 21 | def _get_adata(i: int): 22 | values = domains[i] 23 | adata = AnnData(obs=pd.DataFrame({"domain": values}, index=[str(i) for i in range(len(values))])) 24 | adata.obs["slide_key"] = f"slide_{i}" 25 | adata.obsm["spatial"] = spatial_coords[i] 26 | novae.utils.spatial_neighbors(adata, radius=[0, 1.5]) 27 | return adata 28 | 29 | 30 | adatas = [_get_adata(i) for i in range(len(domains))] 31 | 32 | adata = adatas[0] 33 | 34 | adata_concat = anndata.concat(adatas) 35 | 36 | # o 37 | # 38 | # o-o 39 | # 40 | # o-o-o 41 | spatial_coords2 = np.array( 42 | [[0, 0], [0, 1], [0, 2], [2, 0], [2, 1], [4, 0]], 43 | ) 44 | 45 | adata_line = AnnData(obs=pd.DataFrame(index=[str(i) for i in range(len(spatial_coords2))])) 46 | adata_line.obsm["spatial"] = spatial_coords2 47 | novae.utils.spatial_neighbors(adata_line, radius=1.5) 48 | -------------------------------------------------------------------------------- /tests/test_correction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2
| import pandas as pd 3 | from anndata import AnnData 4 | 5 | import novae 6 | from novae._constants import Keys 7 | 8 | obs1 = pd.DataFrame( 9 | { 10 | Keys.SLIDE_ID: ["a", "a", "b", "b", "a", "b", "b"], 11 | "domains_key": ["N1", "N1", "N1", "N1", "N3", "N1", "N2"], 12 | Keys.IS_VALID_OBS: [True, True, True, True, True, True, True], 13 | }, 14 | index=[f"{i}_1" for i in range(7)], 15 | ) 16 | 17 | latent1 = np.array([ 18 | [1, 2], 19 | [3, 4], 20 | [5, 6], 21 | [7, 8], 22 | [9, 10], 23 | [12, 13], 24 | [-2, -2], 25 | ]).astype(np.float32) 26 | 27 | expected1 = np.array([ 28 | [7, 8], 29 | [9, 10], 30 | [5, 6], 31 | [7, 8], 32 | [9, 10], 33 | [12, 13], 34 | [-2 + 35, -2 + 34], 35 | ]).astype(np.float32) 36 | 37 | adata1 = AnnData(obs=obs1, obsm={Keys.REPR: latent1}) 38 | 39 | obs2 = pd.DataFrame( 40 | { 41 | Keys.SLIDE_ID: ["c", "c", "c", "c", "c"], 42 | "domains_key": ["N2", "N1", np.nan, "N2", "N2"], 43 | Keys.IS_VALID_OBS: [True, True, False, True, True], 44 | }, 45 | index=[f"{i}_2" for i in range(5)], 46 | ) 47 | 48 | latent2 = np.array([ 49 | [-1, -3], 50 | [0, -10], 51 | [0, 0], 52 | [0, -1], 53 | [100, 100], 54 | ]).astype(np.float32) 55 | 56 | expected2 = np.array([ 57 | [-1, -3], 58 | [0 + 8, -10 + 19], 59 | [0, 0], 60 | [0, -1], 61 | [100, 100], 62 | ]).astype(np.float32) 63 | 64 | adata2 = AnnData(obs=obs2, obsm={Keys.REPR: latent2}) 65 | 66 | adatas = [adata1, adata2] 67 | for adata in adatas: 68 | for key in [Keys.SLIDE_ID, "domains_key"]: 69 | adata.obs[key] = adata.obs[key].astype("category") 70 | 71 | 72 | def test_batch_effect_correction(): 73 | novae.utils.batch_effect_correction(adatas, "domains_key") 74 | assert (adata1.obsm[Keys.REPR_CORRECTED] == expected1).all() 75 | assert (adata2.obsm[Keys.REPR_CORRECTED] == expected2).all() 76 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | import novae 5 | from novae._constants import Keys 6 | 7 | N_PANELS = 2 8 | N_SLIDES_PER_PANEL = 3 9 | 10 | 11 | adatas = novae.utils.toy_dataset(n_panels=N_PANELS, n_slides_per_panel=N_SLIDES_PER_PANEL) 12 | 13 | single_adata = adatas[0] 14 | 15 | 16 | def test_raise_invalid_slide_id(): 17 | with pytest.raises(AssertionError): 18 | novae.utils.spatial_neighbors(adatas, slide_key="key_not_in_obs") 19 | 20 | 21 | def test_single_panel(): 22 | novae.utils.spatial_neighbors(single_adata) 23 | model = novae.Novae(single_adata) 24 | model._datamodule = model._init_datamodule() 25 | 26 | assert len(model.dataset.slides_metadata) == 1 27 | assert model.dataset.obs_ilocs is not None 28 | 29 | 30 | def test_single_panel_slide_key(): 31 | novae.utils.spatial_neighbors(single_adata, slide_key="slide_key") 32 | model = novae.Novae(single_adata) 33 | model._datamodule = model._init_datamodule() 34 | 35 | assert len(model.dataset.slides_metadata) == N_SLIDES_PER_PANEL 36 | assert model.dataset.obs_ilocs is not None 37 | 38 | _ensure_batch_same_slide(model) 39 | 40 | 41 | def test_multi_panel(): 42 | novae.utils.spatial_neighbors(adatas) 43 | model = novae.Novae(adatas) 44 | model._datamodule = model._init_datamodule() 45 | 46 | assert len(model.dataset.slides_metadata) == N_PANELS 47 | assert model.dataset.obs_ilocs is None 48 | 49 | _ensure_batch_same_slide(model) 50 | 51 | 52 | def test_multi_panel_slide_key(): 53 | novae.utils.spatial_neighbors(adatas, slide_key="slide_key") 54 | model = novae.Novae(adatas) 55 | 
model._datamodule = model._init_datamodule() 56 | 57 | assert len(model.dataset.slides_metadata) == N_PANELS * N_SLIDES_PER_PANEL 58 | assert model.dataset.obs_ilocs is None 59 | 60 | _ensure_batch_same_slide(model) 61 | 62 | 63 | def _ensure_batch_same_slide(model: novae.Novae): 64 | n_obs_dataset = model.dataset.shuffled_obs_ilocs.shape[0] 65 | assert n_obs_dataset % model.hparams.batch_size == 0 66 | 67 | for batch_index in range(n_obs_dataset // model.hparams.batch_size): 68 | sub_obs_ilocs = model.dataset.shuffled_obs_ilocs[ 69 | batch_index * model.hparams.batch_size : (batch_index + 1) * model.hparams.batch_size 70 | ] 71 | unique_adata_indices = np.unique(sub_obs_ilocs[:, 0]) 72 | assert len(unique_adata_indices) == 1 73 | slide_ids = adatas[unique_adata_indices[0]].obs[Keys.SLIDE_ID].iloc[sub_obs_ilocs[:, 1]] 74 | assert len(np.unique(slide_ids)) == 1 75 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pytest import approx 3 | 4 | import novae 5 | from novae.monitor.eval import _jensen_shannon_divergence, entropy 6 | 7 | from ._utils import adata_concat, adatas 8 | 9 | true_distributions = np.array([ 10 | [0.2] * 5, 11 | [0.4, 0.4, 0.2, 0, 0], 12 | [0.75, 0.25, 0, 0, 0], 13 | ]) 14 | 15 | true_jsd = _jensen_shannon_divergence(true_distributions) 16 | 17 | 18 | def test_heuristic(): 19 | novae.monitor.heuristic(adatas, "domain", n_classes=5) 20 | 21 | 22 | def test_jensen_shannon_divergence(): 23 | jsd = novae.monitor.jensen_shannon_divergence(adatas, "domain") 24 | 25 | assert jsd == approx(true_jsd) 26 | 27 | 28 | def test_jensen_shannon_divergence_concat(): 29 | jsd = novae.monitor.jensen_shannon_divergence(adata_concat, "domain", slide_key="slide_key") 30 | 31 | assert jsd == approx(true_jsd) 32 | 33 | 34 | def test_jensen_shannon_divergence_manual(): 35 | assert _jensen_shannon_divergence(np.ones((1, 5))) == approx(0.0) 36 | assert _jensen_shannon_divergence(np.ones((2, 5))) == approx(0.0) 37 | 38 | distribution = np.array([ 39 | [0.3, 0.2, 0.5], 40 | [0.1, 0.1, 0.8], 41 | [0.2, 0.3, 0.5], 42 | [0, 0, 1], 43 | ]) 44 | 45 | means = np.array([0.15, 0.15, 0.7]) 46 | 47 | entropy_means = entropy(means) 48 | 49 | assert entropy_means == approx(1.18, rel=1e-2) 50 | 51 | jsd_manual = entropy_means - 0.25 * sum(entropy(d) for d in distribution) 52 | 53 | jsd = _jensen_shannon_divergence(distribution) 54 | 55 | assert jsd == approx(jsd_manual) 56 | 57 | 58 | def test_fide_score(): 59 | fide = novae.monitor.fide_score(adatas[0], "domain", n_classes=5) 60 | 61 | assert fide == 0 62 | 63 | fide = novae.monitor.fide_score(adatas[1], "domain", n_classes=3) 64 | 65 | assert fide == approx(0.4 / 3) 66 | 67 | fide = novae.monitor.fide_score(adatas[1], "domain", n_classes=5) 68 | 69 | assert fide == approx(0.4 / 5) 70 | 71 | fide = novae.monitor.fide_score(adatas[2], "domain", n_classes=1) 72 | 73 | assert fide == approx(1) 74 | 75 | fide = novae.monitor.fide_score(adatas[2], "domain", n_classes=5) 76 | 77 | assert fide == approx(1 / 5) 78 | -------------------------------------------------------------------------------- /tests/test_misc.py: -------------------------------------------------------------------------------- 1 | import novae 2 | from novae._constants import Nums 3 | 4 | adata = novae.utils.toy_dataset(xmax=200)[0] 5 | 6 | 7 | def test_settings(): 8 | novae.settings.disable_lazy_loading() 9 | 10 | 
novae.utils.spatial_neighbors(adata) 11 | model = novae.Novae(adata) 12 | model._datamodule = model._init_datamodule() 13 | assert model.dataset.anndata_torch.tensors is not None 14 | 15 | novae.settings.enable_lazy_loading(n_obs_threshold=100) 16 | novae.utils.spatial_neighbors(adata) 17 | model = novae.Novae(adata) 18 | model._datamodule = model._init_datamodule() 19 | assert model.dataset.anndata_torch.tensors is None 20 | 21 | novae.settings.warmup_epochs = 2 22 | assert novae.settings.warmup_epochs == Nums.WARMUP_EPOCHS 23 | 24 | 25 | def test_repr(): 26 | novae.utils.spatial_neighbors(adata) 27 | model = novae.Novae(adata) 28 | 29 | repr(model) 30 | repr(model.mode) 31 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import anndata 4 | import numpy as np 5 | import pandas as pd 6 | import pytest 7 | import torch 8 | 9 | import novae 10 | from novae._constants import Keys 11 | from novae.utils._data import GENE_NAMES_SUBSET 12 | 13 | adatas = novae.utils.toy_dataset( 14 | n_panels=2, 15 | n_slides_per_panel=2, 16 | xmax=100, 17 | n_domains=2, 18 | compute_spatial_neighbors=True, 19 | ) 20 | adata = adatas[0] 21 | 22 | 23 | def _generate_fake_scgpt_inputs(): 24 | gene_names = GENE_NAMES_SUBSET[:100] 25 | indices = [7, 4, 3, 0, 1, 2, 9, 5, 6, 8, *list(range(10, 100))] 26 | 27 | vocab = dict(zip(gene_names, indices)) 28 | 29 | with open("tests/vocab.json", "w") as f: 30 | json.dump(vocab, f, indent=4) 31 | 32 | torch.save({"encoder.embedding.weight": torch.randn(len(vocab), 16)}, "tests/best_model.pt") 33 | 34 | 35 | _generate_fake_scgpt_inputs() 36 | 37 | 38 | # def test_load_wandb_artifact(): 39 | # novae.Novae._load_wandb_artifact("novae/novae/model-4i8e9g2v:v17") 40 | 41 | 42 | def test_load_huggingface_model(): 43 | model = novae.Novae.from_pretrained("MICS-Lab/novae-test") 44 | 45 | assert model.cell_embedder.embedding.weight.requires_grad is False 46 | 47 | 48 | def test_train(): 49 | adatas = novae.utils.toy_dataset( 50 | n_panels=2, 51 | n_slides_per_panel=2, 52 | xmax=60, 53 | n_domains=2, 54 | ) 55 | 56 | novae.utils.spatial_neighbors(adatas) 57 | model = novae.Novae(adatas, num_prototypes=10) 58 | 59 | with pytest.raises(AssertionError): # should raise an error because the model has not been trained 60 | model.compute_representations() 61 | 62 | model.fit(max_epochs=3) 63 | model.compute_representations() 64 | model.compute_representations(num_workers=2) 65 | 66 | # obs_key = model.assign_domains(n_domains=2) 67 | obs_key = model.assign_domains(level=2) 68 | 69 | model.batch_effect_correction() 70 | 71 | adatas[0].obs.iloc[0][obs_key] = np.nan 72 | 73 | novae.monitor.mean_fide_score(adatas, obs_key=obs_key) 74 | novae.monitor.jensen_shannon_divergence(adatas, obs_key=obs_key) 75 | 76 | novae.monitor.mean_fide_score(adatas, obs_key=obs_key, slide_key="slide_key") 77 | novae.monitor.jensen_shannon_divergence(adatas, obs_key=obs_key, slide_key="slide_key") 78 | 79 | adatas[0].write_h5ad("tests/test.h5ad") # ensures the output can be saved 80 | 81 | model.compute_representations(adatas, zero_shot=True) 82 | 83 | with pytest.raises(AssertionError): 84 | model.fine_tune(adatas, max_epochs=1) 85 | 86 | model.mode.pretrained = True 87 | 88 | model.fine_tune(adatas, max_epochs=1) 89 | 90 | 91 | @pytest.mark.parametrize("slide_key", [None, "slide_key"]) 92 | def test_representation_single_panel(slide_key: str | None): 93 | 
adata = novae.utils.toy_dataset( 94 | n_panels=1, 95 | n_slides_per_panel=2, 96 | xmax=100, 97 | n_domains=2, 98 | )[0] 99 | 100 | novae.utils.spatial_neighbors(adata, slide_key=slide_key) 101 | 102 | model = novae.Novae(adata) 103 | model._datamodule = model._init_datamodule() 104 | model.mode.trained = True 105 | 106 | model.compute_representations() 107 | 108 | domains = adata.obs[Keys.LEAVES].copy() 109 | 110 | model.compute_representations() 111 | 112 | assert domains.equals(adata.obs[Keys.LEAVES]) 113 | 114 | novae.utils.spatial_neighbors(adata, slide_key=slide_key) 115 | model.compute_representations([adata]) 116 | 117 | assert domains.equals(adata.obs[Keys.LEAVES]) 118 | 119 | if slide_key is not None: 120 | sids = adata.obs[slide_key].unique() 121 | adatas = [adata[adata.obs[slide_key] == sid] for sid in sids] 122 | 123 | novae.utils.spatial_neighbors(adatas) 124 | model.compute_representations(adatas) 125 | 126 | adata_concat = anndata.concat(adatas) 127 | 128 | assert domains.equals(adata_concat.obs[Keys.LEAVES].loc[domains.index]) 129 | 130 | novae.utils.spatial_neighbors(adatas, slide_key=slide_key) 131 | model.compute_representations(adatas) 132 | 133 | adata_concat = anndata.concat(adatas) 134 | 135 | assert domains.equals(adata_concat.obs[Keys.LEAVES].loc[domains.index]) 136 | 137 | 138 | @pytest.mark.parametrize("slide_key", [None, "slide_key"]) 139 | def test_representation_multi_panel(slide_key: str | None): 140 | adatas = novae.utils.toy_dataset( 141 | n_panels=3, 142 | n_slides_per_panel=2, 143 | xmax=100, 144 | n_domains=2, 145 | ) 146 | 147 | novae.utils.spatial_neighbors(adatas, slide_key=slide_key) 148 | model = novae.Novae(adatas) 149 | model._datamodule = model._init_datamodule() 150 | model.mode.trained = True 151 | 152 | model.compute_representations() 153 | 154 | domains_series = pd.concat([adata.obs[Keys.LEAVES].copy() for adata in adatas]) 155 | 156 | novae.utils.spatial_neighbors(adatas, slide_key=slide_key) 157 | model.compute_representations(adatas) 158 | 159 | domains_series2 = pd.concat([adata.obs[Keys.LEAVES].copy() for adata in adatas]) 160 | 161 | assert domains_series.equals(domains_series2.loc[domains_series.index]) 162 | 163 | adata_split = [ 164 | adata[adata.obs[Keys.SLIDE_ID] == sid].copy() for adata in adatas for sid in adata.obs[Keys.SLIDE_ID].unique() 165 | ] 166 | 167 | model.compute_representations(adata_split) 168 | 169 | domains_series2 = pd.concat([adata.obs[Keys.LEAVES] for adata in adata_split]) 170 | 171 | assert domains_series.equals(domains_series2.loc[domains_series.index]) 172 | 173 | 174 | @pytest.mark.parametrize("slide_key", [None, "slide_key"]) 175 | @pytest.mark.parametrize("scgpt_model_dir", [None, "tests"]) 176 | def test_saved_model_identical(slide_key: str | None, scgpt_model_dir: str | None): 177 | adata = novae.utils.toy_dataset( 178 | n_panels=1, 179 | n_slides_per_panel=2, 180 | xmax=100, 181 | n_domains=2, 182 | )[0] 183 | 184 | # using weird parameters 185 | novae.utils.spatial_neighbors(adata, slide_key=slide_key) 186 | model = novae.Novae( 187 | adata, 188 | embedding_size=67, 189 | output_size=78, 190 | n_hops_local=4, 191 | n_hops_view=3, 192 | heads=2, 193 | hidden_size=46, 194 | num_layers=3, 195 | batch_size=345, 196 | temperature=0.13, 197 | num_prototypes=212, 198 | panel_subset_size=0.62, 199 | background_noise_lambda=7.7, 200 | sensitivity_noise_std=0.042, 201 | scgpt_model_dir=scgpt_model_dir, 202 | ) 203 | 204 | assert model.cell_embedder.embedding.weight.requires_grad is (scgpt_model_dir is None) 205 
| 206 | model._datamodule = model._init_datamodule() 207 | model.mode.trained = True 208 | 209 | model.compute_representations() 210 | model.assign_domains() 211 | 212 | domains = adata.obs[Keys.LEAVES].copy() 213 | representations = adata.obsm[Keys.REPR].copy() 214 | 215 | model.save_pretrained("tests/test_model") 216 | 217 | new_model = novae.Novae.from_pretrained("tests/test_model") 218 | 219 | novae.utils.spatial_neighbors(adata, slide_key=slide_key) 220 | new_model.compute_representations(adata) 221 | new_model.assign_domains(adata) 222 | 223 | assert (adata.obsm[Keys.REPR] == representations).all() 224 | assert domains.equals(adata.obs[Keys.LEAVES]) 225 | 226 | for name, param in model.named_parameters(): 227 | assert torch.equal(param, new_model.state_dict()[name]) 228 | 229 | 230 | def test_safetensors_parameters_names(): 231 | from huggingface_hub import hf_hub_download 232 | from safetensors import safe_open 233 | 234 | local_file = hf_hub_download(repo_id="MICS-Lab/novae-human-0", filename="model.safetensors") 235 | with safe_open(local_file, framework="pt", device="cpu") as f: 236 | pretrained_model_names = f.keys() 237 | 238 | model = novae.Novae(adata) 239 | 240 | actual_names = [name for name, _ in model.named_parameters()] 241 | 242 | assert set(pretrained_model_names) == set(actual_names) 243 | 244 | 245 | def test_reset_clusters_zero_shot(): 246 | adata = novae.utils.toy_dataset()[0] 247 | 248 | novae.utils.spatial_neighbors(adata) 249 | 250 | model = novae.Novae.from_pretrained("MICS-Lab/novae-human-0") 251 | 252 | model.compute_representations(adata, zero_shot=True) 253 | clusters_levels = model.swav_head.clusters_levels.copy() 254 | 255 | adata = adata[:550].copy() 256 | 257 | model.compute_representations(adata, zero_shot=True) 258 | 259 | assert not (model.swav_head.clusters_levels == clusters_levels).all() 260 | -------------------------------------------------------------------------------- /tests/test_plots.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import novae 4 | from novae._constants import Keys 5 | 6 | adatas = novae.utils.toy_dataset(n_panels=2, xmax=200) 7 | novae.utils.spatial_neighbors(adatas) 8 | 9 | model = novae.Novae(adatas) 10 | 11 | model.mode.trained = True 12 | model.compute_representations() 13 | model.assign_domains() 14 | 15 | 16 | @pytest.mark.parametrize("hline_level", [None, 3, [2, 4]]) 17 | def test_plot_domains_hierarchy(hline_level: int | list[int] | None): 18 | model.plot_domains_hierarchy(hline_level=hline_level) 19 | 20 | 21 | def test_plot_prototype_weights(): 22 | model.plot_prototype_weights() 23 | 24 | adatas[0].uns[Keys.UNS_TISSUE] = "breast" 25 | adatas[1].uns[Keys.UNS_TISSUE] = "lung" 26 | 27 | model.plot_prototype_weights() 28 | 29 | 30 | def test_plot_domains(): 31 | novae.plot.domains(adatas, show=False) 32 | 33 | 34 | def test_plot_connectivities(): 35 | novae.plot.connectivities(adatas, ngh_threshold=2, show=False) 36 | novae.plot.connectivities(adatas, ngh_threshold=None, show=False) 37 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import anndata 2 | import numpy as np 3 | import pandas as pd 4 | import pytest 5 | from anndata import AnnData 6 | 7 | import novae 8 | from novae._constants import Keys 9 | from novae.data.dataset import _to_adjacency_local, _to_adjacency_view 10 | from novae.utils._build import 
_set_unique_slide_ids 11 | 12 | from ._utils import adata, adata_concat, adata_line 13 | 14 | true_connectivities = np.array([ 15 | [0, 1, 0, 0, 1], 16 | [1, 0, 1, 0, 1], 17 | [0, 1, 0, 0, 1], 18 | [0, 0, 0, 0, 0], 19 | [1, 1, 1, 0, 0], 20 | ]) 21 | 22 | 23 | def test_build(): 24 | connectivities = adata.obsp["spatial_connectivities"] 25 | 26 | assert connectivities.shape[0] == adata.n_obs 27 | 28 | assert (connectivities.todense() == true_connectivities).all() 29 | 30 | 31 | def test_set_unique_slide_ids(): 32 | adatas = novae.utils.toy_dataset( 33 | xmax=200, 34 | n_panels=2, 35 | n_slides_per_panel=1, 36 | n_vars=30, 37 | slide_ids_unique=False, 38 | ) 39 | 40 | _set_unique_slide_ids(adatas, slide_key="slide_key") 41 | 42 | assert adatas[0].obs[Keys.SLIDE_ID].iloc[0] == f"{id(adatas[0])}_slide_0" 43 | 44 | adatas = novae.utils.toy_dataset( 45 | xmax=200, 46 | n_panels=2, 47 | n_slides_per_panel=1, 48 | n_vars=30, 49 | slide_ids_unique=True, 50 | ) 51 | 52 | _set_unique_slide_ids(adatas, slide_key="slide_key") 53 | 54 | assert adatas[0].obs[Keys.SLIDE_ID].iloc[0] == "slide_0_0" 55 | 56 | 57 | def test_build_slide_key(): 58 | adata_ = adata_concat.copy() 59 | novae.utils.spatial_neighbors(adata_, radius=1.5, slide_key="slide_key") 60 | 61 | true_connectivities = np.array([ 62 | [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 63 | [1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 64 | [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 65 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 66 | [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 67 | [0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0], 68 | [0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0], 69 | [0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0], 70 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 71 | [0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], 72 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1], 73 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1], 74 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1], 75 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 76 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0], 77 | ]) 78 | 79 | assert (adata_.obsp["spatial_connectivities"].todense() == true_connectivities).all() 80 | 81 | 82 | def test_build_slide_key_disjoint_indices(): 83 | adata = novae.utils.toy_dataset( 84 | n_panels=1, 85 | n_slides_per_panel=1, 86 | xmax=100, 87 | n_domains=2, 88 | )[0] 89 | 90 | adata2 = adata.copy() 91 | adata2.obs["slide_key"] = "slide_key2" 92 | adata2.obs_names = adata2.obs_names + "_2" 93 | 94 | novae.utils.spatial_neighbors(adata) 95 | novae.utils.spatial_neighbors(adata2) 96 | 97 | n1, n2 = len(adata.obsp["spatial_connectivities"].data), len(adata2.obsp["spatial_connectivities"].data) 98 | assert n1 == n2 99 | 100 | adata_concat = anndata.concat([adata, adata2], axis=0).copy() 101 | 102 | novae.utils.spatial_neighbors(adata_concat, slide_key="slide_key") 103 | 104 | assert len(adata_concat.obsp["spatial_connectivities"].data) == n1 + n2 105 | 106 | 107 | def test_build_technology(): 108 | adata_cosmx = adata.copy() 109 | adata_cosmx.obs[["CenterX_global_px", "CenterY_global_px"]] = adata_cosmx.obsm["spatial"] 110 | del adata_cosmx.obsm["spatial"] 111 | novae.utils.spatial_neighbors(adata_cosmx, technology="cosmx") 112 | 113 | del adata_cosmx.obs["CenterY_global_px"] 114 | 115 | # one column is missing in obs 116 | with pytest.raises(AssertionError): 117 | novae.utils.spatial_neighbors(adata_cosmx, technology="cosmx") 118 | 119 | 120 | def test_invalid_build(): 121 | adata_invalid = anndata.AnnData(obs=pd.DataFrame(index=["0", 
"1", "2"])) 122 | 123 | with pytest.raises(AssertionError): 124 | novae.utils.spatial_neighbors(adata_invalid, radius=[0, 1.5]) 125 | 126 | adata_invalid.obsm["spatial"] = np.array([[0, 0, 0], [0, 1, 2], [0, 2, 4]]) 127 | 128 | with pytest.raises(AssertionError): 129 | novae.utils.spatial_neighbors(adata_invalid, radius=[0, 1.5]) 130 | 131 | with pytest.raises(AssertionError): 132 | novae.utils.spatial_neighbors(adata_invalid, radius=2, technology="unknown") 133 | 134 | with pytest.raises(AssertionError): 135 | novae.utils.spatial_neighbors(adata_invalid, radius=1, technology="cosmx", pixel_size=0.1) 136 | 137 | 138 | def test_to_adjacency_local(): 139 | adjancency_local = _to_adjacency_local(adata.obsp["spatial_connectivities"], 1) 140 | 141 | assert ( 142 | (adjancency_local.todense() > 0) 143 | == np.array([ 144 | [True, True, False, False, True], 145 | [True, True, True, False, True], 146 | [False, True, True, False, True], 147 | [False, False, False, True, False], 148 | [True, True, True, False, True], 149 | ]) 150 | ).all() 151 | 152 | adjancency_local = _to_adjacency_local(adata.obsp["spatial_connectivities"], 2) 153 | 154 | assert ( 155 | (adjancency_local.todense() > 0) 156 | == np.array([ 157 | [True, True, True, False, True], 158 | [True, True, True, False, True], 159 | [True, True, True, False, True], 160 | [False, False, False, False, False], # unconnected node => no local connections with n_hop >= 2 161 | [True, True, True, False, True], 162 | ]) 163 | ).all() 164 | 165 | adjancency_local = _to_adjacency_local(adata_line.obsp["spatial_connectivities"], 1) 166 | 167 | assert ( 168 | (adjancency_local.todense() > 0) 169 | == np.array([ 170 | [True, True, False, False, False, False], 171 | [True, True, True, False, False, False], 172 | [False, True, True, False, False, False], 173 | [False, False, False, True, True, False], 174 | [False, False, False, True, True, False], 175 | [False, False, False, False, False, True], 176 | ]) 177 | ).all() 178 | 179 | adjancency_local = _to_adjacency_local(adata_line.obsp["spatial_connectivities"], 2) 180 | 181 | assert ( 182 | (adjancency_local.todense() > 0) 183 | == np.array([ 184 | [True, True, True, False, False, False], 185 | [True, True, True, False, False, False], 186 | [True, True, True, False, False, False], 187 | [False, False, False, True, True, False], 188 | [False, False, False, True, True, False], 189 | [False, False, False, False, False, False], # unconnected node => no local connections with n_hop >= 2 190 | ]) 191 | ).all() 192 | 193 | 194 | def test_to_adjacency_view(): 195 | adjancency_view = _to_adjacency_view(adata.obsp["spatial_connectivities"], 2) 196 | 197 | assert ( 198 | (adjancency_view.todense() > 0) 199 | == np.array([ 200 | [False, False, True, False, False], 201 | [False, False, False, False, False], 202 | [True, False, False, False, False], 203 | [False, False, False, False, False], 204 | [False, False, False, False, False], 205 | ]) 206 | ).all() 207 | 208 | adjancency_view = _to_adjacency_view(adata.obsp["spatial_connectivities"], 3) 209 | 210 | assert adjancency_view.sum() == 0 211 | 212 | adjancency_view = _to_adjacency_view(adata_line.obsp["spatial_connectivities"], 1) 213 | 214 | assert ( 215 | (adjancency_view.todense() > 0) 216 | == np.array([ 217 | [False, True, False, False, False, False], 218 | [True, False, True, False, False, False], 219 | [False, True, False, False, False, False], 220 | [False, False, False, False, True, False], 221 | [False, False, False, True, False, False], 222 | [False, False, 
False, False, False, False], 223 | ]) 224 | ).all() 225 | 226 | adjancency_view = _to_adjacency_view(adata_line.obsp["spatial_connectivities"], 2) 227 | 228 | assert ( 229 | (adjancency_view.todense() > 0) 230 | == np.array([ 231 | [False, False, True, False, False, False], 232 | [False, False, False, False, False, False], 233 | [True, False, False, False, False, False], 234 | [False, False, False, False, False, False], 235 | [False, False, False, False, False, False], 236 | [False, False, False, False, False, False], 237 | ]) 238 | ).all() 239 | 240 | adjancency_view = _to_adjacency_view(adata_line.obsp["spatial_connectivities"], 3) 241 | 242 | assert adjancency_view.sum() == 0 243 | 244 | 245 | def test_check_slide_name_key(): 246 | obs = pd.DataFrame({Keys.SLIDE_ID: ["slide1", "slide2"], "name": ["sample1", "sample2"]}) 247 | adata = AnnData(obs=obs) 248 | 249 | assert novae.utils.check_slide_name_key(adata, None) == Keys.SLIDE_ID 250 | novae.utils.check_slide_name_key(adata, "name") 251 | 252 | obs = pd.DataFrame({Keys.SLIDE_ID: ["slide1", "slide2", "slide2"], "name": ["sample1", "sample2", "sample3"]}) 253 | adata2 = AnnData(obs=obs) 254 | 255 | assert novae.utils.check_slide_name_key(adata2, None) == Keys.SLIDE_ID 256 | with pytest.raises(AssertionError): 257 | novae.utils.check_slide_name_key(adata2, "name") 258 | 259 | del adata2.obs[Keys.SLIDE_ID] 260 | with pytest.raises(AssertionError): 261 | novae.utils.check_slide_name_key(adata2, None) 262 | --------------------------------------------------------------------------------