├── docs ├── _static │ ├── .gitkeep │ └── css │ │ └── custom.css ├── _templates │ ├── .gitkeep │ └── autosummary │ │ └── class.rst ├── changelog.md ├── references.md ├── tutorials.md ├── index.md ├── api.md ├── Makefile ├── extensions │ └── typed_returns.py ├── conf.py ├── contributing.md └── references.bib ├── src └── multimil │ ├── model │ ├── __init__.py │ └── _mil.py │ ├── nn │ ├── __init__.py │ └── _base_components.py │ ├── module │ ├── __init__.py │ └── _mil_torch.py │ ├── dataloaders │ ├── __init__.py │ ├── _data_splitting.py │ └── _ann_dataloader.py │ ├── __init__.py │ └── utils │ ├── __init__.py │ └── _utils.py ├── .github ├── ISSUE_TEMPLATE │ ├── config.yml │ ├── feature_request.yml │ └── bug_report.yml └── workflows │ ├── build.yaml │ ├── release.yaml │ └── test.yaml ├── tests └── test_basic.py ├── .editorconfig ├── .codecov.yaml ├── .readthedocs.yaml ├── .gitignore ├── .cruft.json ├── .pre-commit-config.yaml ├── LICENSE ├── CHANGELOG.md ├── README.md └── pyproject.toml /docs/_static/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/_templates/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/changelog.md: -------------------------------------------------------------------------------- 1 | ```{include} ../CHANGELOG.md 2 | 3 | ``` 4 | -------------------------------------------------------------------------------- /docs/references.md: -------------------------------------------------------------------------------- 1 | # References 2 | 3 | ```{bibliography} 4 | :cited: 5 | ``` 6 | -------------------------------------------------------------------------------- /src/multimil/model/__init__.py: -------------------------------------------------------------------------------- 1 | from ._mil import MILClassifier 2 | 3 | __all__ = ["MILClassifier"] 4 | -------------------------------------------------------------------------------- /src/multimil/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from ._base_components import MLP, Aggregator 2 | 3 | __all__ = ["MLP", "Aggregator"] 4 | -------------------------------------------------------------------------------- /src/multimil/module/__init__.py: -------------------------------------------------------------------------------- 1 | from ._mil_torch import MILClassifierTorch 2 | 3 | __all__ = ["MILClassifierTorch"] 4 | -------------------------------------------------------------------------------- /docs/tutorials.md: -------------------------------------------------------------------------------- 1 | # Tutorials 2 | 3 | ## Prediction 4 | 5 | ```{toctree} 6 | :maxdepth: 1 7 | 8 | notebooks/mil_classification 9 | notebooks/mil_regression 10 | ``` 11 | -------------------------------------------------------------------------------- /src/multimil/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from ._ann_dataloader import GroupAnnDataLoader 2 | from ._data_splitting import GroupDataSplitter 3 | 4 | __all__ = ["GroupAnnDataLoader", "GroupDataSplitter"] 5 | -------------------------------------------------------------------------------- /docs/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | /* Reduce the font size in 
data frames - See https://github.com/scverse/cookiecutter-scverse/issues/193 */ 2 | div.cell_output table.dataframe { 3 | font-size: 0.8em; 4 | } 5 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | ```{include} ../README.md 2 | 3 | ``` 4 | 5 | ```{toctree} 6 | :hidden: true 7 | :maxdepth: 2 8 | 9 | api.md 10 | tutorials.md 11 | changelog.md 12 | contributing.md 13 | references.md 14 | ``` 15 | -------------------------------------------------------------------------------- /src/multimil/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib.metadata import version 2 | 3 | from . import dataloaders, model, module, nn, utils 4 | 5 | __all__ = ["dataloaders", "model", "module", "nn", "utils"] 6 | 7 | __version__ = version("multimil") 8 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Scverse Community Forum 4 | url: https://discourse.scverse.org/ 5 | about: If you have questions about “How to do X”, please ask them here. 6 | -------------------------------------------------------------------------------- /tests/test_basic.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import multimil 4 | 5 | 6 | def test_package_has_version(): 7 | assert multimil.__version__ is not None 8 | 9 | 10 | @pytest.mark.skip(reason="This decorator should be removed when test passes.") 11 | def test_example(): 12 | assert 1 == 0 # This test is designed to fail. 13 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | indent_style = space 5 | indent_size = 4 6 | end_of_line = lf 7 | charset = utf-8 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | 11 | [*.{yml,yaml}] 12 | indent_size = 2 13 | 14 | [.cruft.json] 15 | indent_size = 2 16 | 17 | [Makefile] 18 | indent_style = tab 19 | -------------------------------------------------------------------------------- /.codecov.yaml: -------------------------------------------------------------------------------- 1 | # Based on pydata/xarray 2 | codecov: 3 | require_ci_to_pass: no 4 | 5 | coverage: 6 | status: 7 | project: 8 | default: 9 | # Require 1% coverage, i.e., always succeed 10 | target: 1 11 | patch: false 12 | changes: false 13 | 14 | comment: 15 | layout: diff, flags, files 16 | behavior: once 17 | require_base: no 18 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # https://docs.readthedocs.io/en/stable/config-file/v2.html 2 | version: 2 3 | build: 4 | os: ubuntu-20.04 5 | tools: 6 | python: "3.10" 7 | sphinx: 8 | configuration: docs/conf.py 9 | # disable this for more lenient docs builds 10 | fail_on_warning: false 11 | python: 12 | install: 13 | - method: pip 14 | path: . 
15 | extra_requirements: 16 | - doc 17 | -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | # API 2 | 3 | ## Model 4 | 5 | ```{eval-rst} 6 | .. module:: multimil.model 7 | .. currentmodule:: multimil 8 | 9 | .. autosummary:: 10 | :toctree: generated 11 | 12 | model.MILClassifier 13 | ``` 14 | 15 | ## Module 16 | 17 | ```{eval-rst} 18 | .. module:: multimil.module 19 | .. currentmodule:: multimil 20 | 21 | .. autosummary:: 22 | :toctree: generated 23 | 24 | module.MILClassifierTorch 25 | ``` 26 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: Feature request 2 | description: Propose a new feature for multimil 3 | labels: enhancement 4 | body: 5 | - type: textarea 6 | id: description 7 | attributes: 8 | label: Description of feature 9 | description: Please describe your suggestion for a new feature. It might help to describe a problem or use case, plus any alternatives that you have considered. 10 | validations: 11 | required: true 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Notebooks 2 | .ipynb_checkpoints 3 | 4 | # Logs 5 | scvi_log/ 6 | 7 | # Temp files 8 | .DS_Store 9 | *~ 10 | buck-out/ 11 | 12 | # Compiled files 13 | .venv/ 14 | __pycache__/ 15 | .mypy_cache/ 16 | .ruff_cache/ 17 | 18 | # Distribution / packaging 19 | /build/ 20 | /dist/ 21 | /*.egg-info/ 22 | 23 | # Tests and coverage 24 | /.pytest_cache/ 25 | /.cache/ 26 | /data/ 27 | /node_modules/ 28 | 29 | # docs 30 | /docs/generated/ 31 | /docs/_build/ 32 | 33 | # IDEs 34 | /.idea/ 35 | /.vscode/ 36 | 37 | # Data 38 | *.h5ad 39 | 40 | # Private notebooks 41 | _*.ipynb 42 | -------------------------------------------------------------------------------- /src/multimil/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from ._utils import ( 2 | create_df, 3 | get_bag_info, 4 | get_predictions, 5 | get_sample_representations, 6 | plt_plot_losses, 7 | prep_minibatch, 8 | save_predictions_in_adata, 9 | score_top_cells, 10 | select_covariates, 11 | setup_ordinal_regression, 12 | ) 13 | 14 | __all__ = [ 15 | "create_df", 16 | "setup_ordinal_regression", 17 | "select_covariates", 18 | "prep_minibatch", 19 | "get_predictions", 20 | "get_bag_info", 21 | "save_predictions_in_adata", 22 | "plt_plot_losses", 23 | "get_sample_representations", 24 | "score_top_cells", 25 | ] 26 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /.github/workflows/build.yaml: -------------------------------------------------------------------------------- 1 | name: Check Build 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | package: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v3 18 | - name: Set up Python 3.10 19 | uses: actions/setup-python@v4 20 | with: 21 | python-version: "3.10" 22 | cache: "pip" 23 | cache-dependency-path: "**/pyproject.toml" 24 | - name: Install build dependencies 25 | run: python -m pip install --upgrade pip wheel twine build 26 | - name: Build package 27 | run: python -m build 28 | - name: Check package 29 | run: twine check --strict dist/*.whl 30 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | # Use "trusted publishing", see https://docs.pypi.org/trusted-publishers/ 8 | jobs: 9 | release: 10 | name: Upload release to PyPI 11 | runs-on: ubuntu-latest 12 | environment: 13 | name: pypi 14 | url: https://pypi.org/p/multimil 15 | permissions: 16 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing 17 | steps: 18 | - uses: actions/checkout@v4 19 | with: 20 | filter: blob:none 21 | fetch-depth: 0 22 | - uses: actions/setup-python@v4 23 | with: 24 | python-version: "3.x" 25 | cache: "pip" 26 | - run: pip install build 27 | - run: python -m build 28 | - name: Publish package distributions to PyPI 29 | uses: pypa/gh-action-pypi-publish@release/v1 30 | -------------------------------------------------------------------------------- /.cruft.json: -------------------------------------------------------------------------------- 1 | { 2 | "template": "https://github.com/scverse/cookiecutter-scverse", 3 | "commit": "aa877587e59c855e2464d09ed7678af00999191a", 4 | "checkout": null, 5 | "context": { 6 | "cookiecutter": { 7 | "project_name": "multimil", 8 | "package_name": "multimil", 9 | "project_description": "Multimodal weakly supervised learning to identify disease-specific changes in single-cell atlases", 10 | "author_full_name": "Anastasia Litinetskaya", 11 | "author_email": "alitinet@gmail.com", 12 | "github_user": "alitinet", 13 | "project_repo": "https://github.com/theislab/multimil", 14 | "license": "BSD 3-Clause License", 15 | "_copy_without_render": [ 16 | ".github/workflows/build.yaml", 17 | ".github/workflows/test.yaml", 18 | "docs/_templates/autosummary/**.rst" 19 | ], 20 | "_render_devdocs": false, 21 | "_jinja2_env_vars": { 22 | "lstrip_blocks": true, 23 | "trim_blocks": true 24 | }, 25 | "_template": "https://github.com/scverse/cookiecutter-scverse" 26 | } 27 | }, 28 | "directory": null 29 | } 30 | -------------------------------------------------------------------------------- /docs/extensions/typed_returns.py: -------------------------------------------------------------------------------- 1 | # code from https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py 2 | # with some minor adjustment 3 | from __future__ import annotations 4 | 5 | import re 6 | from collections.abc 
import Generator, Iterable

from sphinx.application import Sphinx
from sphinx.ext.napoleon import NumpyDocstring


def _process_return(lines: Iterable[str]) -> Generator[str, None, None]:
    # Rewrite "name : type" lines so the type is rendered as a cross-referenced class.
    for line in lines:
        if m := re.fullmatch(r"(?P<param>\w+)\s+:\s+(?P<type>[\w.]+)", line):
            yield f'-{m["param"]} (:class:`~{m["type"]}`)'
        else:
            yield line


def _parse_returns_section(self: NumpyDocstring, section: str) -> list[str]:
    lines_raw = self._dedent(self._consume_to_next_section())
    if lines_raw[0] == ":":
        del lines_raw[0]
    lines = self._format_block(":returns: ", list(_process_return(lines_raw)))
    if lines and lines[-1]:
        lines.append("")
    return lines


def setup(app: Sphinx):
    """Patch NumpyDocstring so that returns sections are rendered with typed entries."""
    NumpyDocstring._parse_returns_section = _parse_returns_section
--------------------------------------------------------------------------------
/docs/_templates/autosummary/class.rst:
--------------------------------------------------------------------------------
{{ fullname | escape | underline}}

.. currentmodule:: {{ module }}

.. add toctree option to make autodoc generate the pages

.. autoclass:: {{ objname }}

{% block attributes %}
{% if attributes %}
Attributes table
~~~~~~~~~~~~~~~~~~

.. autosummary::
{% for item in attributes %}
    ~{{ fullname }}.{{ item }}
{%- endfor %}
{% endif %}
{% endblock %}

{% block methods %}
{% if methods %}
Methods table
~~~~~~~~~~~~~

.. autosummary::
{% for item in methods %}
    {%- if item != '__init__' %}
    ~{{ fullname }}.{{ item }}
    {%- endif -%}
{%- endfor %}
{% endif %}
{% endblock %}

{% block attributes_documentation %}
{% if attributes %}
Attributes
~~~~~~~~~~~

{% for item in attributes %}

.. autoattribute:: {{ [objname, item] | join(".") }}
{%- endfor %}

{% endif %}
{% endblock %}

{% block methods_documentation %}
{% if methods %}
Methods
~~~~~~~

{% for item in methods %}
{%- if item != '__init__' %}

..
automethod:: {{ [objname, item] | join(".") }} 57 | {%- endif -%} 58 | {%- endfor %} 59 | 60 | {% endif %} 61 | {% endblock %} 62 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | fail_fast: false 2 | default_language_version: 3 | python: python3 4 | default_stages: 5 | - pre-commit 6 | - pre-push 7 | minimum_pre_commit_version: 2.16.0 8 | repos: 9 | - repo: https://github.com/pre-commit/mirrors-prettier 10 | rev: v4.0.0-alpha.8 11 | hooks: 12 | - id: prettier 13 | - repo: https://github.com/astral-sh/ruff-pre-commit 14 | rev: v0.5.3 15 | hooks: 16 | - id: ruff 17 | types_or: [python, pyi, jupyter] 18 | args: [--fix, --exit-non-zero-on-fix] 19 | - id: ruff-format 20 | types_or: [python, pyi, jupyter] 21 | - repo: https://github.com/pre-commit/pre-commit-hooks 22 | rev: v5.0.0 23 | hooks: 24 | - id: detect-private-key 25 | - id: check-ast 26 | - id: end-of-file-fixer 27 | - id: mixed-line-ending 28 | args: [--fix=lf] 29 | - id: trailing-whitespace 30 | - id: check-case-conflict 31 | # Check that there are no merge conflicts (could be generated by template sync) 32 | - id: check-merge-conflict 33 | args: [--assume-in-merge] 34 | - repo: local 35 | hooks: 36 | - id: forbid-to-commit 37 | name: Don't commit rej files 38 | entry: | 39 | Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates. 40 | Fix the merge conflicts manually and remove the .rej files. 41 | language: fail 42 | files: '.*\.rej$' 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024, Anastasia Litinetskaya 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | schedule: 9 | - cron: "0 5 1,15 * *" 10 | 11 | concurrency: 12 | group: ${{ github.workflow }}-${{ github.ref }} 13 | cancel-in-progress: true 14 | 15 | jobs: 16 | test: 17 | runs-on: ${{ matrix.os }} 18 | defaults: 19 | run: 20 | shell: bash -e {0} # -e to fail on error 21 | 22 | strategy: 23 | fail-fast: false 24 | matrix: 25 | include: 26 | - os: ubuntu-latest 27 | python: "3.10" 28 | - os: ubuntu-latest 29 | python: "3.12" 30 | - os: ubuntu-latest 31 | python: "3.12" 32 | pip-flags: "--pre" 33 | name: PRE-RELEASE DEPENDENCIES 34 | 35 | name: ${{ matrix.name }} Python ${{ matrix.python }} 36 | 37 | env: 38 | OS: ${{ matrix.os }} 39 | PYTHON: ${{ matrix.python }} 40 | 41 | steps: 42 | - uses: actions/checkout@v3 43 | - name: Set up Python ${{ matrix.python }} 44 | uses: actions/setup-python@v4 45 | with: 46 | python-version: ${{ matrix.python }} 47 | cache: "pip" 48 | cache-dependency-path: "**/pyproject.toml" 49 | 50 | - name: Install test dependencies 51 | run: | 52 | python -m pip install --upgrade pip wheel 53 | - name: Install dependencies 54 | run: | 55 | pip install ${{ matrix.pip-flags }} ".[dev,test]" 56 | - name: Test 57 | env: 58 | MPLBACKEND: agg 59 | PLATFORM: ${{ matrix.os }} 60 | DISPLAY: :42 61 | run: | 62 | coverage run -m pytest -v --color=yes 63 | - name: Report coverage 64 | run: | 65 | coverage report 66 | - name: Upload coverage 67 | uses: codecov/codecov-action@v3 68 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog][], 6 | and this project adheres to [Semantic Versioning][]. 7 | 8 | [keep a changelog]: https://keepachangelog.com/en/1.0.0/ 9 | [semantic versioning]: https://semver.org/spec/v2.0.0.html 10 | 11 | ## [0.3.2] - 2025-12-16 12 | 13 | ### Added 14 | 15 | - tutorial for regression 16 | - support for sample-covariate embeddings, experimental, only one-hot 17 | 18 | ### Changed 19 | 20 | - removed `muon` from dependencies for tutorials 21 | - changed the classification tutorial to include training across 3 CV folds 22 | 23 | ## Fixed 24 | 25 | - fixed a bug in regression which was caused by a wrong key when accessing the continuous covariates registered with scvi-tools 26 | 27 | ## [0.3.1] - 2025-07-14 28 | 29 | ### Fixed 30 | 31 | - Fixed a bug in `score_top_cells` that didn't set the specified `key_added` column in .obs to True. 32 | - Fixed a bug in `score_top_cells` that set the `key_added` to be categorical, which resulted in wrong indexing in `get_sample_representations`. 
33 | 34 | ## [0.3.0] - 2025-07-13 35 | 36 | ### Added 37 | 38 | - **Utility functions**: Added `score_top_cells` and `get_sample_representations` to utils module 39 | - `score_top_cells`: Function to identify and score top cells based on attention weights 40 | - `get_sample_representations`: Function to aggregate cell-level data to sample-level representations 41 | 42 | ### Changed 43 | 44 | - **Major refactoring**: Removed MultiVAE and MultiVAE_MIL models, keeping only MIL classifier 45 | - **Code cleanup**: Removed MLP attention weight learning from Aggregator class 46 | - **Parameter consistency**: Fixed default values between model and module classes 47 | - **Dynamic z_dim**: Automatically infer z_dim from input data shape instead of hardcoded value 48 | 49 | ### Fixed 50 | 51 | - **Make categorical covariates categorical**: Ensure the correct type in `setup_anndata`. 52 | - **Improved error handling**: If the prediction covariate hasn't been registered with `setup_anndata`, throw an error. 53 | - **Dead links to API and changelog**: Fixed in README. 54 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Weakly supervised learning uncovers phenotypic signatures in single-cell data 2 | 3 | [![Tests][badge-tests]][link-tests] 4 | [![Documentation][badge-docs]][link-docs] 5 | 6 | [badge-tests]: https://img.shields.io/github/actions/workflow/status/theislab/multimil/test.yaml?branch=main 7 | [link-tests]: https://github.com/theislab/multimil/actions/workflows/test.yml 8 | [badge-docs]: https://img.shields.io/readthedocs/multimil 9 | [badge-colab]: https://colab.research.google.com/assets/colab-badge.svg 10 | 11 | ## Getting started 12 | 13 | Please refer to the [documentation][link-docs]. In particular, the 14 | 15 | - [API documentation][link-api] 16 | 17 | and the tutorials: 18 | 19 | - [Classification with MultiMIL](https://multimil.readthedocs.io/en/latest/notebooks/mil_classification.html) [![Open In Colab][badge-colab]](https://colab.research.google.com/github/theislab/multimil/blob/main/docs/notebooks/mil_classification.ipynb) 20 | - [Regression with MultiMIL](https://multimil.readthedocs.io/en/latest/notebooks/mil_regression.html) [![Open In Colab][badge-colab]](https://colab.research.google.com/github/theislab/multimil/blob/main/docs/notebooks/mil_regression.ipynb) 21 | 22 | Please also check out our [sample prediction pipeline](https://github.com/theislab/sample-prediction-pipeline), which contains MultiMIL and several other baselines. 23 | 24 | ## Installation 25 | 26 | You need to have Python 3.10 or newer installed on your system. We recommend installing [Mambaforge](https://github.com/conda-forge/miniforge#mambaforge). 27 | 28 | To create and activate a new environment: 29 | 30 | ```bash 31 | mamba create --name multimil python=3.10 32 | mamba activate multimil 33 | ``` 34 | 35 | Next, there are several alternative options to install multimil: 36 | 37 | 1. Install the latest release of `multimil` from [PyPI][link-pypi]: 38 | 39 | ```bash 40 | pip install multimil 41 | ``` 42 | 43 | 2. Or install the latest development version: 44 | 45 | ```bash 46 | pip install git+https://github.com/theislab/multimil.git@main 47 | ``` 48 | 49 | ## Release notes 50 | 51 | See the [changelog][changelog]. 52 | 53 | ## Contact 54 | 55 | If you found a bug, please use the [issue tracker][issue-tracker]. 
## Citation

Weakly supervised learning uncovers phenotypic signatures in single-cell data

Anastasia Litinetskaya, Soroor Hediyeh-zadeh, Amir Ali Moinfar, Mohammad Lotfollahi, Fabian J. Theis

bioRxiv 2024.07.29.605625; doi: https://doi.org/10.1101/2024.07.29.605625

## Reproducibility

Code and notebooks to reproduce the results from the paper are available at [theislab/multimil_reproducibility](https://github.com/theislab/multimil_reproducibility).

[issue-tracker]: https://github.com/theislab/multimil/issues
[changelog]: https://multimil.readthedocs.io/en/latest/changelog.html
[link-docs]: https://multimil.readthedocs.io
[link-api]: https://multimil.readthedocs.io/en/latest/api.html
[link-pypi]: https://pypi.org/project/multimil
--------------------------------------------------------------------------------
/src/multimil/dataloaders/_data_splitting.py:
--------------------------------------------------------------------------------
from scvi.data import AnnDataManager
from scvi.dataloaders import DataSplitter

from multimil.dataloaders._ann_dataloader import GroupAnnDataLoader


# adjusted from scvi-tools
# https://github.com/scverse/scvi-tools/blob/0b802762869c43c9f49e69fe62b1a5a9b5c4dae6/scvi/dataloaders/_data_splitting.py#L56
# accessed on 5 November 2022
class GroupDataSplitter(DataSplitter):
    """Creates data loaders ``train_set``, ``validation_set``, ``test_set``.

    If ``train_size + validation_size < 1`` then ``test_set`` is non-empty.

    Parameters
    ----------
    adata_manager
        :class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
    group_column
        Column in ``adata.obs`` with the group (bag) labels, e.g. sample or patient IDs, used by :class:`~multimil.dataloaders.GroupAnnDataLoader` to build mini-batches.
    train_size
        Proportion of cells to use as the train set. Float, or None (default is 0.9).
    validation_size
        Proportion of cells to use as the validation set. Float, or None (default is None). If None, is set to 1 - ``train_size``.
    use_gpu
        Use default GPU if available (if None or True), or index of GPU to use (if int), or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False).
    kwargs
        Keyword args for data loader. Data loader class is :class:`~multimil.dataloaders.GroupAnnDataLoader`.
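    Examples
    --------
    A minimal sketch of standalone use (during normal training the splitter is typically created internally). Here ``adata_manager`` is assumed to be an :class:`~scvi.data.AnnDataManager` obtained from a model after ``setup_anndata``, and ``"sample"`` is an illustrative ``.obs`` column holding the bag labels.

    >>> splitter = GroupDataSplitter(adata_manager, group_column="sample", train_size=0.8)
    >>> splitter.setup()
    >>> train_dl = splitter.train_dataloader()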
28 | """ 29 | 30 | def __init__( 31 | self, 32 | adata_manager: AnnDataManager, 33 | group_column: str, 34 | train_size: float = 0.9, 35 | validation_size: float | None = None, 36 | use_gpu: bool = False, 37 | **kwargs, 38 | ): 39 | self.group_column = group_column 40 | super().__init__(adata_manager, train_size, validation_size, use_gpu, **kwargs) 41 | 42 | def train_dataloader(self): 43 | """Return data loader for train AnnData.""" 44 | return GroupAnnDataLoader( 45 | self.adata_manager, 46 | self.group_column, 47 | indices=self.train_idx, 48 | shuffle=True, 49 | drop_last=True, 50 | pin_memory=self.pin_memory, 51 | **self.data_loader_kwargs, 52 | ) 53 | 54 | def val_dataloader(self): 55 | """Return data loader for validation AnnData.""" 56 | if len(self.val_idx) > 0: 57 | return GroupAnnDataLoader( 58 | self.adata_manager, 59 | self.group_column, 60 | indices=self.val_idx, 61 | shuffle=False, 62 | drop_last=True, 63 | pin_memory=self.pin_memory, 64 | **self.data_loader_kwargs, 65 | ) 66 | else: 67 | pass 68 | 69 | def test_dataloader(self): 70 | """Return data loader for test AnnData.""" 71 | if len(self.test_idx) > 0: 72 | return GroupAnnDataLoader( 73 | self.adata_manager, 74 | self.group_column, 75 | indices=self.test_idx, 76 | shuffle=False, 77 | drop_last=True, 78 | pin_memory=self.pin_memory, 79 | **self.data_loader_kwargs, 80 | ) 81 | else: 82 | pass 83 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug report 2 | description: Report something that is broken or incorrect 3 | labels: bug 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | **Note**: Please read [this guide](https://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports) 9 | detailing how to provide the necessary information for us to reproduce your bug. In brief: 10 | * Please provide exact steps how to reproduce the bug in a clean Python environment. 11 | * In case it's not clear what's causing this bug, please provide the data or the data generation procedure. 12 | * Sometimes it is not possible to share the data, but usually it is possible to replicate problems on publicly 13 | available datasets or to share a subset of your data. 14 | 15 | - type: textarea 16 | id: report 17 | attributes: 18 | label: Report 19 | description: A clear and concise description of what the bug is. 
20 | validations: 21 | required: true 22 | 23 | - type: textarea 24 | id: versions 25 | attributes: 26 | label: Version information 27 | description: | 28 | Please paste below the output of 29 | 30 | ```python 31 | import session_info 32 | session_info.show(html=False, dependencies=True) 33 | ``` 34 | placeholder: | 35 | ----- 36 | anndata 0.8.0rc2.dev27+ge524389 37 | session_info 1.0.0 38 | ----- 39 | asttokens NA 40 | awkward 1.8.0 41 | backcall 0.2.0 42 | cython_runtime NA 43 | dateutil 2.8.2 44 | debugpy 1.6.0 45 | decorator 5.1.1 46 | entrypoints 0.4 47 | executing 0.8.3 48 | h5py 3.7.0 49 | ipykernel 6.15.0 50 | jedi 0.18.1 51 | mpl_toolkits NA 52 | natsort 8.1.0 53 | numpy 1.22.4 54 | packaging 21.3 55 | pandas 1.4.2 56 | parso 0.8.3 57 | pexpect 4.8.0 58 | pickleshare 0.7.5 59 | pkg_resources NA 60 | prompt_toolkit 3.0.29 61 | psutil 5.9.1 62 | ptyprocess 0.7.0 63 | pure_eval 0.2.2 64 | pydev_ipython NA 65 | pydevconsole NA 66 | pydevd 2.8.0 67 | pydevd_file_utils NA 68 | pydevd_plugins NA 69 | pydevd_tracing NA 70 | pygments 2.12.0 71 | pytz 2022.1 72 | scipy 1.8.1 73 | setuptools 62.5.0 74 | setuptools_scm NA 75 | six 1.16.0 76 | stack_data 0.3.0 77 | tornado 6.1 78 | traitlets 5.3.0 79 | wcwidth 0.2.5 80 | zmq 23.1.0 81 | ----- 82 | IPython 8.4.0 83 | jupyter_client 7.3.4 84 | jupyter_core 4.10.0 85 | ----- 86 | Python 3.9.13 | packaged by conda-forge | (main, May 27 2022, 16:58:50) [GCC 10.3.0] 87 | Linux-5.18.6-arch1-1-x86_64-with-glibc2.35 88 | ----- 89 | Session information updated at 2022-07-07 17:55 90 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "hatchling.build" 3 | requires = ["hatchling"] 4 | 5 | [project] 6 | name = "multimil" 7 | version = "0.3.2" 8 | description = "Multimodal weakly supervised learning to identify disease-specific changes in single-cell atlases" 9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | license = {file = "LICENSE"} 12 | authors = [ 13 | {name = "Anastasia Litinetskaya"}, 14 | ] 15 | maintainers = [ 16 | {name = "Anastasia Litinetskaya", email = "alitinet@gmail.com"}, 17 | ] 18 | urls.Documentation = "https://multimil.readthedocs.io/" 19 | urls.Source = "https://github.com/theislab/multimil" 20 | urls.Home-page = "https://github.com/theislab/multimil" 21 | dependencies = [ 22 | "anndata<0.11", 23 | "numpy<2.0", 24 | # for debug logging (referenced from the issue template) 25 | "session-info", 26 | # multimil specific 27 | "scanpy", 28 | "scvi-tools<1.0.0", 29 | "jax<0.6", 30 | "requests", 31 | "matplotlib", 32 | ] 33 | 34 | [project.optional-dependencies] 35 | dev = [ 36 | "pre-commit", 37 | "twine>=4.0.2", 38 | ] 39 | doc = [ 40 | "docutils>=0.8,!=0.18.*,!=0.19.*", 41 | "sphinx>=4", 42 | "sphinx-book-theme>=1.0.0", 43 | "myst-nb>=1.1.0", 44 | "sphinxcontrib-bibtex>=1.0.0", 45 | "setuptools", # Until pybtex >0.23.0 releases: https://bitbucket.org/pybtex-devs/pybtex/issues/169/ 46 | "sphinx-autodoc-typehints", 47 | "sphinxext-opengraph", 48 | # For notebooks 49 | "ipykernel", 50 | "ipython", 51 | "sphinx-copybutton", 52 | "pandas", 53 | ] 54 | test = [ 55 | "pytest", 56 | "coverage", 57 | ] 58 | tutorials = [ 59 | "jupyterlab", 60 | "ipywidgets", 61 | "leidenalg", 62 | "igraph", 63 | "gdown", 64 | ] 65 | 66 | [tool.coverage.run] 67 | source = ["multimil"] 68 | omit = [ 69 | "**/test_*.py", 70 | ] 71 | 72 | [tool.pytest.ini_options] 73 | testpaths = ["tests"] 
74 | xfail_strict = true 75 | addopts = [ 76 | "--import-mode=importlib", # allow using test files with same name 77 | ] 78 | 79 | [tool.ruff] 80 | line-length = 120 81 | src = ["src"] 82 | extend-include = ["*.ipynb"] 83 | 84 | [tool.ruff.format] 85 | docstring-code-format = true 86 | 87 | [tool.ruff.lint] 88 | select = [ 89 | "F", # Errors detected by Pyflakes 90 | "E", # Error detected by Pycodestyle 91 | "W", # Warning detected by Pycodestyle 92 | "I", # isort 93 | "D", # pydocstyle 94 | "B", # flake8-bugbear 95 | "TID", # flake8-tidy-imports 96 | "C4", # flake8-comprehensions 97 | "BLE", # flake8-blind-except 98 | "UP", # pyupgrade 99 | "RUF100", # Report unused noqa directives 100 | ] 101 | ignore = [ 102 | # line too long -> we accept long comment lines; formatter gets rid of long code lines 103 | "E501", 104 | # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient 105 | "E731", 106 | # allow I, O, l as variable names -> I is the identity matrix 107 | "E741", 108 | # Missing docstring in public package 109 | "D104", 110 | # Missing docstring in public module 111 | "D100", 112 | # Missing docstring in __init__ 113 | "D107", 114 | # Errors from function calls in argument defaults. These are fine when the result is immutable. 115 | "B008", 116 | # __magic__ methods are often self-explanatory, allow missing docstrings 117 | "D105", 118 | # first line should end with a period [Bug: doesn't work with single-line docstrings] 119 | "D400", 120 | # First line should be in imperative mood; try rephrasing 121 | "D401", 122 | ## Disable one in each pair of mutually incompatible rules 123 | # We don’t want a blank line before a class docstring 124 | "D203", 125 | # We want docstrings to start immediately after the opening triple quote 126 | "D213", 127 | ] 128 | 129 | [tool.ruff.lint.pydocstyle] 130 | convention = "numpy" 131 | 132 | [tool.ruff.lint.per-file-ignores] 133 | "docs/*" = ["I"] 134 | "tests/*" = ["D"] 135 | "*/__init__.py" = ["F401"] 136 | "src/multimil/module/__init__.py" = ["I"] 137 | 138 | [tool.cruft] 139 | skip = [ 140 | "tests", 141 | "src/**/__init__.py", 142 | "src/**/basic.py", 143 | "docs/api.md", 144 | "docs/changelog.md", 145 | "docs/references.bib", 146 | "docs/references.md", 147 | "docs/notebooks/example.ipynb", 148 | ] 149 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | import sys 9 | from datetime import datetime 10 | from importlib.metadata import metadata 11 | from pathlib import Path 12 | 13 | HERE = Path(__file__).parent 14 | sys.path.insert(0, str(HERE / "extensions")) 15 | 16 | 17 | # -- Project information ----------------------------------------------------- 18 | 19 | # NOTE: If you installed your project in editable mode, this might be stale. 20 | # If this is the case, reinstall it to refresh the metadata 21 | info = metadata("multimil") 22 | project_name = info["Name"] 23 | author = info["Author"] 24 | copyright = f"{datetime.now():%Y}, {author}." 
25 | version = info["Version"] 26 | urls = dict(pu.split(", ") for pu in info.get_all("Project-URL")) 27 | repository_url = urls["Source"] 28 | 29 | # The full version, including alpha/beta/rc tags 30 | release = info["Version"] 31 | 32 | bibtex_bibfiles = ["references.bib"] 33 | templates_path = ["_templates"] 34 | nitpicky = True # Warn about broken links 35 | needs_sphinx = "4.0" 36 | 37 | html_context = { 38 | "display_github": True, # Integrate GitHub 39 | "github_user": "alitinet", 40 | "github_repo": "https://github.com/theislab/multimil", 41 | "github_version": "main", 42 | "conf_py_path": "/docs/", 43 | } 44 | 45 | # -- General configuration --------------------------------------------------- 46 | 47 | # Add any Sphinx extension module names here, as strings. 48 | # They can be extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 49 | extensions = [ 50 | "myst_nb", 51 | "sphinx_copybutton", 52 | "sphinx.ext.autodoc", 53 | "sphinx.ext.intersphinx", 54 | "sphinx.ext.autosummary", 55 | "sphinx.ext.napoleon", 56 | "sphinxcontrib.bibtex", 57 | "sphinx_autodoc_typehints", 58 | "sphinx.ext.mathjax", 59 | "IPython.sphinxext.ipython_console_highlighting", 60 | "sphinxext.opengraph", 61 | *[p.stem for p in (HERE / "extensions").glob("*.py")], 62 | ] 63 | 64 | autosummary_generate = True 65 | autodoc_member_order = "groupwise" 66 | default_role = "literal" 67 | napoleon_google_docstring = False 68 | napoleon_numpy_docstring = True 69 | napoleon_include_init_with_doc = False 70 | napoleon_use_rtype = True # having a separate entry generally helps readability 71 | napoleon_use_param = True 72 | myst_heading_anchors = 6 # create anchors for h1-h6 73 | myst_enable_extensions = [ 74 | "amsmath", 75 | "colon_fence", 76 | "deflist", 77 | "dollarmath", 78 | "html_image", 79 | "html_admonition", 80 | ] 81 | myst_url_schemes = ("http", "https", "mailto") 82 | nb_output_stderr = "remove" 83 | nb_execution_mode = "off" 84 | nb_merge_streams = True 85 | typehints_defaults = "braces" 86 | 87 | source_suffix = { 88 | ".rst": "restructuredtext", 89 | ".ipynb": "myst-nb", 90 | ".myst": "myst-nb", 91 | } 92 | 93 | intersphinx_mapping = { 94 | "python": ("https://docs.python.org/3", None), 95 | "anndata": ("https://anndata.readthedocs.io/en/stable/", None), 96 | "numpy": ("https://numpy.org/doc/stable/", None), 97 | } 98 | 99 | # List of patterns, relative to source directory, that match files and 100 | # directories to ignore when looking for source files. 101 | # This pattern also affects html_static_path and html_extra_path. 102 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"] 103 | 104 | 105 | # -- Options for HTML output ------------------------------------------------- 106 | 107 | # The theme to use for HTML and HTML Help pages. See the documentation for 108 | # a list of builtin themes. 109 | # 110 | html_theme = "sphinx_book_theme" 111 | html_static_path = ["_static"] 112 | html_css_files = ["css/custom.css"] 113 | 114 | html_title = project_name 115 | 116 | html_theme_options = { 117 | "repository_url": repository_url, 118 | "use_repository_button": True, 119 | "path_to_docs": "docs/", 120 | "navigation_with_keys": False, 121 | } 122 | 123 | pygments_style = "default" 124 | 125 | nitpick_ignore = [ 126 | # If building the documentation fails because of a missing link that is outside your control, 127 | # you can add an exception to this list. 
128 | # ("py:class", "igraph.Graph"), 129 | ] 130 | -------------------------------------------------------------------------------- /src/multimil/nn/_base_components.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scvi.nn import FCLayers 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | 7 | class MLP(nn.Module): 8 | """A helper class to build blocks of fully-connected, normalization, dropout and activation layers. 9 | 10 | Parameters 11 | ---------- 12 | n_input 13 | Number of input features. 14 | n_output 15 | Number of output features. 16 | n_layers 17 | Number of hidden layers. 18 | n_hidden 19 | Number of hidden units. 20 | dropout_rate 21 | Dropout rate. 22 | normalization 23 | Type of normalization to use. Can be one of ["layer", "batch", "none"]. 24 | activation 25 | Activation function to use. 26 | 27 | """ 28 | 29 | def __init__( 30 | self, 31 | n_input: int, 32 | n_output: int, 33 | n_layers: int = 1, 34 | n_hidden: int = 128, 35 | dropout_rate: float = 0.1, 36 | normalization: str = "layer", 37 | activation=nn.LeakyReLU, 38 | ): 39 | super().__init__() 40 | use_layer_norm = False 41 | use_batch_norm = True 42 | if normalization == "layer": 43 | use_layer_norm = True 44 | use_batch_norm = False 45 | elif normalization == "none": 46 | use_batch_norm = False 47 | 48 | self.mlp = FCLayers( 49 | n_in=n_input, 50 | n_out=n_output, 51 | n_layers=n_layers, 52 | n_hidden=n_hidden, 53 | dropout_rate=dropout_rate, 54 | use_layer_norm=use_layer_norm, 55 | use_batch_norm=use_batch_norm, 56 | activation_fn=activation, 57 | ) 58 | 59 | def forward(self, x: torch.Tensor) -> torch.Tensor: 60 | """Forward computation on ``x``. 61 | 62 | Parameters 63 | ---------- 64 | x 65 | Tensor of values with shape ``(n_input,)``. 66 | 67 | Returns 68 | ------- 69 | Tensor of values with shape ``(n_output,)``. 70 | """ 71 | return self.mlp(x) 72 | 73 | 74 | class Aggregator(nn.Module): 75 | """A helper class to build custom aggregators depending on the scoring function passed. 76 | 77 | Parameters 78 | ---------- 79 | n_input 80 | Number of input features. 81 | scoring 82 | Scoring function to use. Can be one of ["attn", "gated_attn", "mean", "max", "sum"]. 83 | attn_dim 84 | Dimension of the hidden attention layer. 85 | sample_batch_size 86 | Bag batch size. 87 | scale 88 | Whether to scale the attention weights. 89 | dropout 90 | Dropout rate. 91 | activation 92 | Activation function to use. 
93 | """ 94 | 95 | def __init__( 96 | self, 97 | n_input: int | None = None, 98 | scoring="gated_attn", 99 | attn_dim=16, # D 100 | sample_batch_size=None, 101 | scale=False, 102 | dropout=0.2, 103 | activation=nn.LeakyReLU, 104 | ): 105 | super().__init__() 106 | 107 | self.scoring = scoring 108 | self.patient_batch_size = sample_batch_size 109 | self.scale = scale 110 | 111 | if self.scoring == "attn": 112 | if n_input is None: 113 | raise ValueError("n_input must be provided for attn scoring") 114 | self.attn_dim = attn_dim # attn dim from https://arxiv.org/pdf/1802.04712.pdf 115 | self.attention = nn.Sequential( 116 | nn.Linear(n_input, self.attn_dim), 117 | nn.Tanh(), 118 | nn.Linear(self.attn_dim, 1, bias=False), 119 | ) 120 | elif self.scoring == "gated_attn": 121 | if n_input is None: 122 | raise ValueError("n_input must be provided for gated_attn scoring") 123 | self.attn_dim = attn_dim 124 | self.attention_V = nn.Sequential( 125 | nn.Linear(n_input, self.attn_dim), 126 | nn.Tanh(), 127 | ) 128 | 129 | self.attention_U = nn.Sequential( 130 | nn.Linear(n_input, self.attn_dim), 131 | nn.Sigmoid(), 132 | ) 133 | 134 | self.attention_weights = nn.Linear(self.attn_dim, 1, bias=False) 135 | 136 | def forward(self, x) -> torch.Tensor: 137 | """Forward computation on ``x``. 138 | 139 | Parameters 140 | ---------- 141 | x 142 | Tensor of values with shape ``(batch_size, N, n_input)``. 143 | 144 | Returns 145 | ------- 146 | Tensor of pooled values with shape ``(batch_size, n_input)``. 147 | """ 148 | if self.scoring == "attn": 149 | # from https://github.com/AMLab-Amsterdam/AttentionDeepMIL/blob/master/model.py (accessed 16.09.2021) 150 | A = self.attention(x) # (batch_size, N, 1) 151 | A = A.transpose(1, 2) # (batch_size, 1, N) 152 | self.A = F.softmax(A, dim=-1) 153 | elif self.scoring == "gated_attn": 154 | # from https://github.com/AMLab-Amsterdam/AttentionDeepMIL/blob/master/model.py (accessed 16.09.2021) 155 | A_V = self.attention_V(x) # (batch_size, N, attn_dim) 156 | A_U = self.attention_U(x) # (batch_size, N, attn_dim) 157 | A = self.attention_weights(A_V * A_U) # (batch_size, N, 1) 158 | A = A.transpose(1, 2) # (batch_size, 1, N) 159 | self.A = F.softmax(A, dim=-1) 160 | elif self.scoring == "sum": 161 | return torch.sum(x, dim=1) # (batch_size, n_input) 162 | elif self.scoring == "mean": 163 | return torch.mean(x, dim=1) # (batch_size, n_input) 164 | elif self.scoring == "max": 165 | return torch.max(x, dim=1).values # (batch_size, n_input) 166 | else: 167 | raise NotImplementedError( 168 | f'scoring = {self.scoring} is not implemented. Has to be one of ["attn", "gated_attn", "sum", "mean", "max"].' 169 | ) 170 | 171 | if self.scale: 172 | if self.patient_batch_size is None: 173 | raise ValueError("patient_batch_size must be set when scale is True.") 174 | self.A = self.A * self.A.shape[-1] / self.patient_batch_size 175 | 176 | pooled = torch.bmm(self.A, x).squeeze(dim=1) # (batch_size, n_input) 177 | return pooled 178 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing guide 2 | 3 | Scanpy provides extensive [developer documentation][scanpy developer guide], most of which applies to this project, too. 4 | This document will not reproduce the entire content from there. Instead, it aims at summarizing the most important 5 | information to get you started on contributing. 
6 | 7 | We assume that you are already familiar with git and with making pull requests on GitHub. If not, please refer 8 | to the [scanpy developer guide][]. 9 | 10 | ## Installing dev dependencies 11 | 12 | In addition to the packages needed to _use_ this package, you need additional python packages to _run tests_ and _build 13 | the documentation_. It's easy to install them using `pip`: 14 | 15 | ```bash 16 | cd multimil 17 | pip install -e ".[dev,test,doc]" 18 | ``` 19 | 20 | ## Code-style 21 | 22 | This package uses [pre-commit][] to enforce consistent code-styles. 23 | On every commit, pre-commit checks will either automatically fix issues with the code, or raise an error message. 24 | 25 | To enable pre-commit locally, simply run 26 | 27 | ```bash 28 | pre-commit install 29 | ``` 30 | 31 | in the root of the repository. Pre-commit will automatically download all dependencies when it is run for the first time. 32 | 33 | Alternatively, you can rely on the [pre-commit.ci][] service enabled on GitHub. If you didn't run `pre-commit` before 34 | pushing changes to GitHub it will automatically commit fixes to your pull request, or show an error message. 35 | 36 | If pre-commit.ci added a commit on a branch you still have been working on locally, simply use 37 | 38 | ```bash 39 | git pull --rebase 40 | ``` 41 | 42 | to integrate the changes into yours. 43 | While the [pre-commit.ci][] is useful, we strongly encourage installing and running pre-commit locally first to understand its usage. 44 | 45 | Finally, most editors have an _autoformat on save_ feature. Consider enabling this option for [ruff][ruff-editors] 46 | and [prettier][prettier-editors]. 47 | 48 | [ruff-editors]: https://docs.astral.sh/ruff/integrations/ 49 | [prettier-editors]: https://prettier.io/docs/en/editors.html 50 | 51 | ## Writing tests 52 | 53 | ```{note} 54 | Remember to first install the package with `pip install -e '.[dev,test]'` 55 | ``` 56 | 57 | This package uses the [pytest][] for automated testing. Please [write tests][scanpy-test-docs] for every function added 58 | to the package. 59 | 60 | Most IDEs integrate with pytest and provide a GUI to run tests. Alternatively, you can run all tests from the 61 | command line by executing 62 | 63 | ```bash 64 | pytest 65 | ``` 66 | 67 | in the root of the repository. 68 | 69 | ### Continuous integration 70 | 71 | Continuous integration will automatically run the tests on all pull requests and test 72 | against the minimum and maximum supported Python version. 73 | 74 | Additionally, there's a CI job that tests against pre-releases of all dependencies 75 | (if there are any). The purpose of this check is to detect incompatibilities 76 | of new package versions early on and gives you time to fix the issue or reach 77 | out to the developers of the dependency before the package is released to a wider audience. 78 | 79 | [scanpy-test-docs]: https://scanpy.readthedocs.io/en/latest/dev/testing.html#writing-tests 80 | 81 | ## Publishing a release 82 | 83 | ### Updating the version number 84 | 85 | Before making a release, you need to update the version number in the `pyproject.toml` file. Please adhere to [Semantic Versioning][semver], in brief 86 | 87 | > Given a version number MAJOR.MINOR.PATCH, increment the: 88 | > 89 | > 1. MAJOR version when you make incompatible API changes, 90 | > 2. MINOR version when you add functionality in a backwards compatible manner, and 91 | > 3. PATCH version when you make backwards compatible bug fixes. 
92 | > 93 | > Additional labels for pre-release and build metadata are available as extensions to the MAJOR.MINOR.PATCH format. 94 | 95 | Once you are done, commit and push your changes and navigate to the "Releases" page of this project on GitHub. 96 | Specify `vX.X.X` as a tag name and create a release. For more information, see [managing GitHub releases][]. This will automatically create a git tag and trigger a Github workflow that creates a release on PyPI. 97 | 98 | ## Writing documentation 99 | 100 | Please write documentation for new or changed features and use-cases. This project uses [sphinx][] with the following features: 101 | 102 | - the [myst][] extension allows to write documentation in markdown/Markedly Structured Text 103 | - [Numpy-style docstrings][numpydoc] (through the [napoloen][numpydoc-napoleon] extension). 104 | - Jupyter notebooks as tutorials through [myst-nb][] (See [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks)) 105 | - [Sphinx autodoc typehints][], to automatically reference annotated input and output types 106 | - Citations (like {cite:p}`Virshup_2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/) 107 | 108 | See the [scanpy developer docs](https://scanpy.readthedocs.io/en/latest/dev/documentation.html) for more information 109 | on how to write documentation. 110 | 111 | ### Tutorials with myst-nb and jupyter notebooks 112 | 113 | The documentation is set-up to render jupyter notebooks stored in the `docs/notebooks` directory using [myst-nb][]. 114 | Currently, only notebooks in `.ipynb` format are supported that will be included with both their input and output cells. 115 | It is your responsibility to update and re-run the notebook whenever necessary. 116 | 117 | If you are interested in automatically running notebooks as part of the continuous integration, please check 118 | out [this feature request](https://github.com/scverse/cookiecutter-scverse/issues/40) in the `cookiecutter-scverse` 119 | repository. 120 | 121 | #### Hints 122 | 123 | - If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py`. Only 124 | if you do so can sphinx automatically create a link to the external documentation. 
125 | - If building the documentation fails because of a missing link that is outside your control, you can add an entry to 126 | the `nitpick_ignore` list in `docs/conf.py` 127 | 128 | #### Building the docs locally 129 | 130 | ```bash 131 | cd docs 132 | make html 133 | open _build/html/index.html 134 | ``` 135 | 136 | 137 | 138 | [scanpy developer guide]: https://scanpy.readthedocs.io/en/latest/dev/index.html 139 | [cookiecutter-scverse-instance]: https://cookiecutter-scverse-instance.readthedocs.io/en/latest/template_usage.html 140 | [github quickstart guide]: https://docs.github.com/en/get-started/quickstart/create-a-repo?tool=webui 141 | [codecov]: https://about.codecov.io/sign-up/ 142 | [codecov docs]: https://docs.codecov.com/docs 143 | [codecov bot]: https://docs.codecov.com/docs/team-bot 144 | [codecov app]: https://github.com/apps/codecov 145 | [pre-commit.ci]: https://pre-commit.ci/ 146 | [readthedocs.org]: https://readthedocs.org/ 147 | [myst-nb]: https://myst-nb.readthedocs.io/en/latest/ 148 | [jupytext]: https://jupytext.readthedocs.io/en/latest/ 149 | [pre-commit]: https://pre-commit.com/ 150 | [anndata]: https://github.com/scverse/anndata 151 | [mudata]: https://github.com/scverse/mudata 152 | [pytest]: https://docs.pytest.org/ 153 | [semver]: https://semver.org/ 154 | [sphinx]: https://www.sphinx-doc.org/en/master/ 155 | [myst]: https://myst-parser.readthedocs.io/en/latest/intro.html 156 | [numpydoc-napoleon]: https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html 157 | [numpydoc]: https://numpydoc.readthedocs.io/en/latest/format.html 158 | [sphinx autodoc typehints]: https://github.com/tox-dev/sphinx-autodoc-typehints 159 | [pypi]: https://pypi.org/ 160 | [managing GitHub releases]: https://docs.github.com/en/repositories/releasing-projects-on-github/managing-releases-in-a-repository 161 | -------------------------------------------------------------------------------- /docs/references.bib: -------------------------------------------------------------------------------- 1 | @article{Virshup_2023, 2 | doi = {10.1038/s41587-023-01733-8}, 3 | url = {https://doi.org/10.1038%2Fs41587-023-01733-8}, 4 | year = 2023, 5 | month = {apr}, 6 | publisher = {Springer Science and Business Media {LLC}}, 7 | author = {Isaac Virshup and Danila Bredikhin and Lukas Heumos and Giovanni Palla and Gregor Sturm and Adam Gayoso and Ilia Kats and Mikaela Koutrouli and Philipp Angerer and Volker Bergen and Pierre Boyeau and Maren Büttner and Gokcen Eraslan and David Fischer and Max Frank and Justin Hong and Michal Klein and Marius Lange and Romain Lopez and Mohammad Lotfollahi and Malte D. Luecken and Fidel Ramirez and Jeffrey Regier and Sergei Rybakov and Anna C. Schaar and Valeh Valiollah Pour Amiri and Philipp Weiler and Galen Xing and Bonnie Berger and Dana Pe'er and Aviv Regev and Sarah A. Teichmann and Francesca Finotello and F. Alexander Wolf and Nir Yosef and Oliver Stegle and Fabian J. 
Theis and}, 8 | title = {The scverse project provides a computational ecosystem for single-cell omics data analysis}, 9 | journal = {Nature Biotechnology} 10 | } 11 | 12 | @inproceedings{Luecken2021-ct, 13 | author = {Luecken, Malte and Burkhardt, Daniel and Cannoodt, Robrecht and Lance, Christopher and Agrawal, Aditi and Aliee, Hananeh and Chen, Ann and Deconinck, Louise and Detweiler, Angela and Granados, Alejandro and Huynh, Shelly and Isacco, Laura and Kim, Yang and Klein, Dominik and DE KUMAR, BONY and Kuppasani, Sunil and Lickert, Heiko and McGeever, Aaron and Melgarejo, Joaquin and Mekonen, Honey and Morri, Maurizio and M\"{u}ller, Michaela and Neff, Norma and Paul, Sheryl and Rieck, Bastian and Schneider, Kaylie and Steelman, Scott and Sterr, Michael and Treacy, Daniel and Tong, Alexander and Villani, Alexandra-Chloe and Wang, Guilin and Yan, Jia and Zhang, Ce and Pisco, Angela and Krishnaswamy, Smita and Theis, Fabian and Bloom, Jonathan M}, 14 | booktitle = {Proceedings of the Neural Information Processing Systems Track on Datasets and Benchmarks}, 15 | editor = {J. Vanschoren and S. Yeung}, 16 | pages = {}, 17 | title = {A sandbox for prediction and integration of DNA, RNA, and proteins in single cells}, 18 | url = {https://datasets-benchmarks-proceedings.neurips.cc/paper_files/paper/2021/file/158f3069a435b314a80bdcb024f8e422-Paper-round2.pdf}, 19 | volume = {1}, 20 | year = {2021} 21 | } 22 | 23 | @ARTICLE{Sikkema2023-oh, 24 | title = "An integrated cell atlas of the lung in health and disease", 25 | author = "Sikkema, Lisa and Ramírez-Suástegui, Ciro and Strobl, Daniel C and 26 | Gillett, Tessa E and Zappia, Luke and Madissoon, Elo and Markov, 27 | Nikolay S and Zaragosi, Laure-Emmanuelle and Ji, Yuge and Ansari, 28 | Meshal and Arguel, Marie-Jeanne and Apperloo, Leonie and Banchero, 29 | Martin and Bécavin, Christophe and Berg, Marijn and 30 | Chichelnitskiy, Evgeny and Chung, Mei-I and Collin, Antoine and 31 | Gay, Aurore C A and Gote-Schniering, Janine and Hooshiar Kashani, 32 | Baharak and Inecik, Kemal and Jain, Manu and Kapellos, Theodore S 33 | and Kole, Tessa M and Leroy, Sylvie and Mayr, Christoph H and 34 | Oliver, Amanda J and von Papen, Michael and Peter, Lance and 35 | Taylor, Chase J and Walzthoeni, Thomas and Xu, Chuan and Bui, Linh 36 | T and De Donno, Carlo and Dony, Leander and Faiz, Alen and Guo, 37 | Minzhe and Gutierrez, Austin J and Heumos, Lukas and Huang, Ni and 38 | Ibarra, Ignacio L and Jackson, Nathan D and Kadur Lakshminarasimha 39 | Murthy, Preetish and Lotfollahi, Mohammad and Tabib, Tracy and 40 | Talavera-López, Carlos and Travaglini, Kyle J and Wilbrey-Clark, 41 | Anna and Worlock, Kaylee B and Yoshida, Masahiro and {Lung 42 | Biological Network Consortium} and van den Berge, Maarten and 43 | Bossé, Yohan and Desai, Tushar J and Eickelberg, Oliver and 44 | Kaminski, Naftali and Krasnow, Mark A and Lafyatis, Robert and 45 | Nikolic, Marko Z and Powell, Joseph E and Rajagopal, Jayaraj and 46 | Rojas, Mauricio and Rozenblatt-Rosen, Orit and Seibold, Max A and 47 | Sheppard, Dean and Shepherd, Douglas P and Sin, Don D and Timens, 48 | Wim and Tsankov, Alexander M and Whitsett, Jeffrey and Xu, Yan and 49 | Banovich, Nicholas E and Barbry, Pascal and Duong, Thu Elizabeth 50 | and Falk, Christine S and Meyer, Kerstin B and Kropski, Jonathan A 51 | and Pe'er, Dana and Schiller, Herbert B and Tata, Purushothama Rao 52 | and Schultze, Joachim L and Teichmann, Sara A and Misharin, 53 | Alexander V and Nawijn, Martijn C and Luecken, Malte D and Theis, 54 | 
Fabian J", 55 | journal = "Nat. Med.", 56 | volume = 29, 57 | number = 6, 58 | pages = "1563--1577", 59 | abstract = "Single-cell technologies have transformed our understanding of 60 | human tissues. Yet, studies typically capture only a limited 61 | number of donors and disagree on cell type definitions. 62 | Integrating many single-cell datasets can address these 63 | limitations of individual studies and capture the variability 64 | present in the population. Here we present the integrated Human 65 | Lung Cell Atlas (HLCA), combining 49 datasets of the human 66 | respiratory system into a single atlas spanning over 2.4 million 67 | cells from 486 individuals. The HLCA presents a consensus cell 68 | type re-annotation with matching marker genes, including 69 | annotations of rare and previously undescribed cell types. 70 | Leveraging the number and diversity of individuals in the HLCA, we 71 | identify gene modules that are associated with demographic 72 | covariates such as age, sex and body mass index, as well as gene 73 | modules changing expression along the proximal-to-distal axis of 74 | the bronchial tree. Mapping new data to the HLCA enables rapid 75 | data annotation and interpretation. Using the HLCA as a reference 76 | for the study of disease, we identify shared cell states across 77 | multiple lung diseases, including SPP1+ profibrotic 78 | monocyte-derived macrophages in COVID-19, pulmonary fibrosis and 79 | lung carcinoma. Overall, the HLCA serves as an example for the 80 | development and use of large-scale, cross-dataset organ atlases 81 | within the Human Cell Atlas.", 82 | month = jun, 83 | year = 2023, 84 | language = "en" 85 | } 86 | 87 | @ARTICLE{Lotfollahi2022-jw, 88 | title = "Mapping single-cell data to reference atlases by transfer learning", 89 | author = "Lotfollahi, Mohammad and Naghipourfar, Mohsen and Luecken, Malte D 90 | and Khajavi, Matin and Büttner, Maren and Wagenstetter, Marco and 91 | Avsec, Žiga and Gayoso, Adam and Yosef, Nir and Interlandi, Marta 92 | and Rybakov, Sergei and Misharin, Alexander V and Theis, Fabian J", 93 | journal = "Nat. Biotechnol.", 94 | volume = 40, 95 | number = 1, 96 | pages = "121--130", 97 | abstract = "Large single-cell atlases are now routinely generated to serve as 98 | references for analysis of smaller-scale studies. Yet learning 99 | from reference data is complicated by batch effects between 100 | datasets, limited availability of computational resources and 101 | sharing restrictions on raw data. Here we introduce a deep 102 | learning strategy for mapping query datasets on top of a reference 103 | called single-cell architectural surgery (scArches). scArches uses 104 | transfer learning and parameter optimization to enable efficient, 105 | decentralized, iterative reference building and contextualization 106 | of new datasets with existing references without sharing raw data. 107 | Using examples from mouse brain, pancreas, immune and 108 | whole-organism atlases, we show that scArches preserves biological 109 | state information while removing batch effects, despite using four 110 | orders of magnitude fewer parameters than de novo integration. 111 | scArches generalizes to multimodal reference mapping, allowing 112 | imputation of missing modalities. Finally, scArches retains 113 | coronavirus disease 2019 (COVID-19) disease variation when mapping 114 | to a healthy reference, enabling the discovery of disease-specific 115 | cell states. 
scArches will facilitate collaborative projects by 116 | enabling iterative construction, updating, sharing and efficient 117 | use of reference atlases.", 118 | month = jan, 119 | year = 2022, 120 | language = "en" 121 | } 122 | -------------------------------------------------------------------------------- /src/multimil/dataloaders/_ann_dataloader.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import itertools 3 | 4 | import numpy as np 5 | import torch 6 | from scvi.data import AnnDataManager 7 | from scvi.dataloaders import AnnTorchDataset 8 | from torch.utils.data import DataLoader 9 | 10 | 11 | # adjusted from scvi-tools 12 | # https://github.com/YosefLab/scvi-tools/blob/ac0c3e04fcc2772fdcf7de4de819db3af9465b6b/scvi/dataloaders/_ann_dataloader.py#L15 13 | # accessed on 4 November 2021 14 | class StratifiedSampler(torch.utils.data.sampler.Sampler): 15 | """Custom stratified sampler class which enables sampling the same number of observation from each group in each mini-batch. 16 | 17 | Parameters 18 | ---------- 19 | indices 20 | List of indices to sample from. 21 | batch_size 22 | Batch size of each iteration. 23 | shuffle 24 | If ``True``, shuffles indices before sampling. 25 | drop_last 26 | If int, drops the last batch if its length is less than drop_last. 27 | If drop_last == True, drops last non-full batch. 28 | If drop_last == False, iterate over all batches. 29 | """ 30 | 31 | def __init__( 32 | self, 33 | indices: np.ndarray, 34 | group_labels: np.ndarray, 35 | batch_size: int, 36 | min_size_per_class: int, 37 | shuffle: bool = True, 38 | drop_last: bool | int = True, 39 | shuffle_classes: bool = True, 40 | ): 41 | if drop_last > batch_size: 42 | raise ValueError( 43 | "drop_last can't be greater than batch_size. " 44 | + f"drop_last is {drop_last} but batch_size is {batch_size}." 45 | ) 46 | 47 | if batch_size % min_size_per_class != 0: 48 | raise ValueError( 49 | "min_size_per_class has to be a divisor of batch_size." 50 | + f"min_size_per_class is {min_size_per_class} but batch_size is {batch_size}." 51 | ) 52 | 53 | self.indices = indices 54 | self.group_labels = group_labels 55 | self.n_obs = len(indices) 56 | self.batch_size = batch_size 57 | self.shuffle = shuffle 58 | self.shuffle_classes = shuffle_classes 59 | self.min_size_per_class = min_size_per_class 60 | self.drop_last = drop_last 61 | 62 | from math import ceil 63 | 64 | classes = list(dict.fromkeys(self.group_labels)) 65 | 66 | tmp = 0 67 | for cl in classes: 68 | idx = np.where(self.group_labels == cl)[0] 69 | cl_idx = self.indices[idx] 70 | n_obs = len(cl_idx) 71 | last_batch_len = n_obs % self.min_size_per_class 72 | if (self.drop_last is True) or (last_batch_len < self.drop_last): 73 | drop_last_n = last_batch_len 74 | elif (self.drop_last is False) or (last_batch_len >= self.drop_last): 75 | drop_last_n = 0 76 | else: 77 | raise ValueError("Invalid input for drop_last param. 
Must be bool or int.") 78 | 79 | if drop_last_n != 0: 80 | tmp += n_obs // self.min_size_per_class 81 | else: 82 | tmp += ceil(n_obs / self.min_size_per_class) 83 | 84 | classes_per_batch = int(self.batch_size / self.min_size_per_class) 85 | self.length = ceil(tmp / classes_per_batch) 86 | 87 | def __iter__(self): 88 | classes_per_batch = int(self.batch_size / self.min_size_per_class) 89 | 90 | classes = list(dict.fromkeys(self.group_labels)) 91 | data_iter = [] 92 | 93 | for cl in classes: 94 | idx = np.where(self.group_labels == cl)[0] 95 | cl_idx = self.indices[idx] 96 | n_obs = len(cl_idx) 97 | 98 | if self.shuffle is True: 99 | idx = torch.randperm(n_obs).tolist() 100 | else: 101 | idx = torch.arange(n_obs).tolist() 102 | 103 | last_batch_len = n_obs % self.min_size_per_class 104 | if (self.drop_last is True) or (last_batch_len < self.drop_last): 105 | drop_last_n = last_batch_len 106 | elif (self.drop_last is False) or (last_batch_len >= self.drop_last): 107 | drop_last_n = 0 108 | else: 109 | raise ValueError("Invalid input for drop_last param. Must be bool or int.") 110 | 111 | if drop_last_n != 0: 112 | idx = idx[:-drop_last_n] 113 | 114 | data_iter.extend( 115 | [cl_idx[idx[i : i + self.min_size_per_class]] for i in range(0, len(idx), self.min_size_per_class)] 116 | ) 117 | 118 | if self.shuffle_classes: 119 | idx = torch.randperm(len(data_iter)).tolist() 120 | data_iter = [data_iter[id] for id in idx] 121 | 122 | final_data_iter = [] 123 | 124 | end = len(data_iter) - len(data_iter) % classes_per_batch 125 | for i in range(0, end, classes_per_batch): 126 | batch_idx = list(itertools.chain.from_iterable(data_iter[i : i + classes_per_batch])) 127 | final_data_iter.append(batch_idx) 128 | 129 | # deal with the last manually 130 | if end != len(data_iter): 131 | batch_idx = list(itertools.chain.from_iterable(data_iter[end:])) 132 | final_data_iter.append(batch_idx) 133 | 134 | return iter(final_data_iter) 135 | 136 | def __len__(self): 137 | return self.length 138 | 139 | 140 | # adjusted from scvi-tools 141 | # https://github.com/scverse/scvi-tools/blob/0b802762869c43c9f49e69fe62b1a5a9b5c4dae6/scvi/dataloaders/_ann_dataloader.py#L89 142 | # accessed on 5 November 2022 143 | class GroupAnnDataLoader(DataLoader): 144 | """DataLoader for loading tensors from AnnData objects. 145 | 146 | Parameters 147 | ---------- 148 | adata_manager 149 | :class:`~scvi.data.AnnDataManager` object with a registered AnnData object. 150 | shuffle 151 | Whether the data should be shuffled. 152 | indices 153 | The indices of the observations in the adata to load. 154 | batch_size 155 | Minibatch size to load each iteration. 156 | data_and_attributes 157 | Dictionary with keys representing keys in data registry (`adata.uns["_scvi"]`) 158 | and value equal to desired numpy loading type (later made into torch tensor). 159 | If `None`, defaults to all registered data. 160 | data_loader_kwargs 161 | Keyword arguments for :class:`~torch.utils.data.DataLoader`. 
162 | """ 163 | 164 | def __init__( 165 | self, 166 | adata_manager: AnnDataManager, 167 | group_column: str, 168 | shuffle=True, 169 | shuffle_classes=True, 170 | indices=None, 171 | batch_size=128, 172 | min_size_per_class=None, 173 | data_and_attributes: dict | None = None, 174 | drop_last: bool | int = True, 175 | sampler: torch.utils.data.sampler.Sampler | None = StratifiedSampler, 176 | **data_loader_kwargs, 177 | ): 178 | if adata_manager.adata is None: 179 | raise ValueError("Please run register_fields() on your AnnDataManager object first.") 180 | 181 | if data_and_attributes is not None: 182 | data_registry = adata_manager.data_registry 183 | for key in data_and_attributes.keys(): 184 | if key not in data_registry.keys(): 185 | raise ValueError(f"{key} required for model but not registered with AnnDataManager.") 186 | 187 | if group_column not in adata_manager.registry["setup_args"]["categorical_covariate_keys"]: 188 | raise ValueError( 189 | f"{group_column} required for model but not in categorical covariates has to be one of the registered categorical covariates = {adata_manager.registry['setup_args']['categorical_covariate_keys']}." 190 | ) 191 | 192 | self.dataset = AnnTorchDataset(adata_manager, getitem_tensors=data_and_attributes) 193 | 194 | if min_size_per_class is None: 195 | min_size_per_class = batch_size // 2 196 | 197 | sampler_kwargs = { 198 | "batch_size": batch_size, 199 | "shuffle": shuffle, 200 | "drop_last": drop_last, 201 | "min_size_per_class": min_size_per_class, 202 | "shuffle_classes": shuffle_classes, 203 | } 204 | 205 | if indices is None: 206 | indices = np.arange(len(self.dataset)) 207 | sampler_kwargs["indices"] = indices 208 | else: 209 | if hasattr(indices, "dtype") and indices.dtype is np.dtype("bool"): 210 | indices = np.where(indices)[0].ravel() 211 | indices = np.asarray(indices) 212 | sampler_kwargs["indices"] = indices 213 | 214 | sampler_kwargs["group_labels"] = np.array( 215 | adata_manager.adata[indices].obsm["_scvi_extra_categorical_covs"][group_column] 216 | ) 217 | 218 | self.indices = indices 219 | self.sampler_kwargs = sampler_kwargs 220 | 221 | sampler = sampler(**self.sampler_kwargs) 222 | self.data_loader_kwargs = copy.copy(data_loader_kwargs) 223 | # do not touch batch size here, sampler gives batched indices 224 | self.data_loader_kwargs.update({"sampler": sampler, "batch_size": None}) 225 | 226 | super().__init__(self.dataset, **self.data_loader_kwargs) 227 | -------------------------------------------------------------------------------- /src/multimil/utils/_utils.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | 3 | import anndata as ad 4 | import numpy as np 5 | import pandas as pd 6 | import scanpy as sc 7 | import torch 8 | from matplotlib import pyplot as plt 9 | 10 | 11 | def create_df(pred, columns=None, index=None) -> pd.DataFrame: 12 | """Create a pandas DataFrame from a list of predictions. 13 | 14 | Parameters 15 | ---------- 16 | pred 17 | List of predictions. 18 | columns 19 | Column names, i.e. class_names. 20 | index 21 | Index names, i.e. obs_names. 22 | 23 | Returns 24 | ------- 25 | DataFrame with predictions. 
26 | """ 27 | if isinstance(pred, dict): 28 | for key in pred.keys(): 29 | pred[key] = torch.cat(pred[key]).squeeze().cpu().numpy() 30 | else: 31 | pred = torch.cat(pred).squeeze().cpu().numpy() 32 | 33 | df = pd.DataFrame(pred) 34 | if index is not None: 35 | df.index = index 36 | if columns is not None: 37 | df.columns = columns 38 | return df 39 | 40 | 41 | def setup_ordinal_regression(adata, ordinal_regression_order, categorical_covariate_keys): 42 | """Setup ordinal regression. 43 | 44 | Parameters 45 | ---------- 46 | adata 47 | Annotated data object. 48 | ordinal_regression_order 49 | Order of categories for ordinal regression. 50 | categorical_covariate_keys 51 | Keys of categorical covariates. 52 | """ 53 | if ordinal_regression_order is not None: 54 | if not set(ordinal_regression_order.keys()).issubset(categorical_covariate_keys): 55 | raise ValueError( 56 | f"All keys {ordinal_regression_order.keys()} has to be registered as categorical covariates too, but categorical_covariate_keys = {categorical_covariate_keys}" 57 | ) 58 | for key in ordinal_regression_order.keys(): 59 | # Get unique values from the column without assuming it's categorical 60 | unique_values = np.unique(adata.obs[key].values) 61 | if set(unique_values) != set(ordinal_regression_order[key]): 62 | raise ValueError( 63 | f"Unique values in adata.obs[{key}]={unique_values} are not the same as categories specified = {ordinal_regression_order[key]}" 64 | ) 65 | adata.obs[key] = adata.obs[key].cat.reorder_categories(ordinal_regression_order[key]) 66 | 67 | 68 | def select_covariates(covs, idx, n_samples_in_batch) -> torch.Tensor: 69 | """Select covariates from all covariates. 70 | 71 | Parameters 72 | ---------- 73 | covs 74 | Covariates. 75 | idx 76 | Index of covariates. 77 | n_samples_in_batch 78 | Number of samples in the batch. 79 | 80 | Returns 81 | ------- 82 | Prediction covariates. 83 | """ 84 | if len(idx) > 0: 85 | covs = torch.index_select(covs, 1, idx) 86 | covs = covs.view(n_samples_in_batch, -1, len(idx))[:, 0, :] 87 | else: 88 | covs = torch.tensor([]) 89 | return covs 90 | 91 | 92 | def prep_minibatch(covs, sample_batch_size) -> tuple[int, int]: 93 | """Prepare minibatch. 94 | 95 | Parameters 96 | ---------- 97 | covs 98 | Covariates. 99 | sample_batch_size 100 | Sample batch size. 101 | 102 | Returns 103 | ------- 104 | Batch size and number of samples in the batch. 105 | """ 106 | batch_size = covs.shape[0] 107 | 108 | if batch_size % sample_batch_size != 0: 109 | n_samples_in_batch = 1 110 | else: 111 | n_samples_in_batch = batch_size // sample_batch_size 112 | return batch_size, n_samples_in_batch 113 | 114 | 115 | def get_predictions( 116 | prediction_idx, pred_values, true_values, size, bag_pred, bag_true, full_pred, offset=0 117 | ) -> tuple[dict, dict, dict]: 118 | """Get predictions. 119 | 120 | Parameters 121 | ---------- 122 | prediction_idx 123 | Index of predictions. 124 | pred_values 125 | Predicted values. 126 | true_values 127 | True values. 128 | size 129 | Size of the bag minibatch. 130 | bag_pred 131 | Bag predictions. 132 | bag_true 133 | Bag true values. 134 | full_pred 135 | Full predictions, i.e. on cell-level. 136 | offset 137 | Offset, needed because of several possible types of predictions. 138 | 139 | Returns 140 | ------- 141 | Bag predictions, bag true values, full predictions on cell-level. 
142 | """ 143 | for i in range(len(prediction_idx)): 144 | bag_pred[i] = bag_pred.get(i, []) + [pred_values[offset + i].cpu()] 145 | bag_true[i] = bag_true.get(i, []) + [true_values[:, i].cpu()] 146 | # TODO in ord reg had pred[len(self.mil.class_idx) + i].repeat(1, size).flatten() 147 | # in reg had 148 | # cell level, i.e. prediction for the cell = prediction for the bag 149 | full_pred[i] = full_pred.get(i, []) + [pred_values[offset + i].unsqueeze(1).repeat(1, size, 1).flatten(0, 1)] 150 | return bag_pred, bag_true, full_pred 151 | 152 | 153 | def get_bag_info(bags, n_samples_in_batch, minibatch_size, cell_counter, bag_counter, sample_batch_size): 154 | """Get bag information. 155 | 156 | Parameters 157 | ---------- 158 | bags 159 | Bags. 160 | n_samples_in_batch 161 | Number of samples in the batch. 162 | minibatch_size 163 | Minibatch size. 164 | cell_counter 165 | Cell counter. 166 | bag_counter 167 | Bag counter. 168 | sample_batch_size 169 | Sample batch size. 170 | 171 | Returns 172 | ------- 173 | Updated bags, cell counter, and bag counter. 174 | """ 175 | if n_samples_in_batch == 1: 176 | bags += [[bag_counter] * minibatch_size] 177 | cell_counter += minibatch_size 178 | bag_counter += 1 179 | else: 180 | bags += [[bag_counter + i] * sample_batch_size for i in range(n_samples_in_batch)] 181 | bag_counter += n_samples_in_batch 182 | cell_counter += sample_batch_size * n_samples_in_batch 183 | return bags, cell_counter, bag_counter 184 | 185 | 186 | def save_predictions_in_adata( 187 | adata, idx, predictions, bag_pred, bag_true, cell_pred, class_names, name, clip, reg=False 188 | ): 189 | """Save predictions in anndata object. 190 | 191 | Parameters 192 | ---------- 193 | adata 194 | Annotated data object. 195 | idx 196 | Index, i.e. obs_names. 197 | predictions 198 | Predictions. 199 | bag_pred 200 | Bag predictions. 201 | bag_true 202 | Bag true values. 203 | cell_pred 204 | Cell predictions. 205 | class_names 206 | Class names. 207 | name 208 | Name of the prediction column. 209 | clip 210 | Whether to transofrm the predictions. One of `clip`, `argmax`, or `none`. 211 | reg 212 | Whether the rediciton task is a regression task. 
213 | """ 214 | # cell level predictions) 215 | 216 | if clip == "clip": # ord regression 217 | df = create_df(cell_pred[idx], [name], index=adata.obs_names) 218 | adata.obsm[f"full_predictions_{name}"] = df 219 | adata.obs[f"predicted_{name}"] = np.clip(np.round(df.to_numpy()), a_min=0.0, a_max=len(class_names) - 1.0) 220 | elif clip == "argmax": # classification 221 | df = create_df(cell_pred[idx], class_names, index=adata.obs_names) 222 | adata.obsm[f"full_predictions_{name}"] = df 223 | adata.obs[f"predicted_{name}"] = df.to_numpy().argmax(axis=1) 224 | else: # regression 225 | df = create_df(cell_pred[idx], [name], index=adata.obs_names) 226 | adata.obsm[f"full_predictions_{name}"] = df 227 | adata.obs[f"predicted_{name}"] = df.to_numpy() 228 | if reg is False: 229 | adata.obs[f"predicted_{name}"] = adata.obs[f"predicted_{name}"].astype("category") 230 | adata.obs[f"predicted_{name}"] = adata.obs[f"predicted_{name}"].cat.rename_categories( 231 | dict(enumerate(class_names)) 232 | ) 233 | 234 | # bag level predictions 235 | adata.uns[f"bag_true_{name}"] = create_df(bag_true, predictions) 236 | if clip == "clip": # ordinal regression 237 | df_bag = create_df(bag_pred[idx], [name]) 238 | adata.uns[f"bag_full_predictions_{name}"] = np.clip( 239 | np.round(df_bag.to_numpy()), a_min=0.0, a_max=len(class_names) - 1.0 240 | ) 241 | elif clip == "argmax": # classification 242 | df_bag = create_df(bag_pred[idx], class_names) 243 | adata.uns[f"bag_full_predictions_{name}"] = df_bag.to_numpy().argmax(axis=1) 244 | else: # regression 245 | df_bag = create_df(bag_pred[idx], [name]) 246 | adata.uns[f"bag_full_predictions_{name}"] = df_bag.to_numpy() 247 | 248 | 249 | def plt_plot_losses(history, loss_names, save): 250 | """Plot losses. 251 | 252 | Parameters 253 | ---------- 254 | history 255 | History of losses. 256 | loss_names 257 | Loss names to plot. 258 | save 259 | Path to save the plot. 260 | """ 261 | df = pd.concat(history, axis=1) 262 | df.columns = df.columns.droplevel(-1) 263 | df["epoch"] = df.index 264 | 265 | nrows = ceil(len(loss_names) / 2) 266 | 267 | plt.figure(figsize=(15, 5 * nrows)) 268 | 269 | for i, name in enumerate(loss_names): 270 | plt.subplot(nrows, 2, i + 1) 271 | plt.plot(df["epoch"], df[name + "_train"], ".-", label=name + "_train") 272 | plt.plot(df["epoch"], df[name + "_validation"], ".-", label=name + "_validation") 273 | plt.xlabel("epoch") 274 | plt.legend() 275 | if save is not None: 276 | plt.savefig(save, bbox_inches="tight") 277 | 278 | 279 | def get_sample_representations( 280 | adata, 281 | sample_key, 282 | use_rep="X", 283 | aggregation="weighted", 284 | cell_attn_key="cell_attn", 285 | covs_to_keep=None, 286 | top_fraction=None, 287 | ) -> ad.AnnData: 288 | """Get sample representations from cell-level representations. 289 | 290 | Parameters 291 | ---------- 292 | adata 293 | Annotated data object with cell-level representations. 294 | sample_key 295 | Key in `adata.obs` that identifies samples. 296 | use_rep 297 | Key in `adata.obsm` to use for sample representations or '.X' (default is 'X'). 298 | aggregation 299 | Method to aggregate cell-level representations to sample-level. Options are 'weighted' or 'mean'. 300 | cell_attn_key 301 | Key in `adata.obs` that contains cell-level attention weights (if aggregation is 'weighted'). 302 | covs_to_keep 303 | List of sample-level covariate keys to keep in the final sample representation. 304 | top_fraction 305 | Fraction of top cells to select based on attention weights. If None, uses all cells. 
306 | If provided, will first score top cells and then use only those for sample representation. 307 | 308 | Returns 309 | ------- 310 | ad.AnnData 311 | Annotated data object with sample-level representations. 312 | """ 313 | if use_rep == "X": 314 | tmp = adata.copy() 315 | else: 316 | if use_rep not in adata.obsm.keys(): 317 | raise ValueError(f"Key '{use_rep}' not found in adata.obsm. Available keys: {adata.obsm.keys()}") 318 | tmp = sc.AnnData(adata.obsm[use_rep], obs=adata.obs.copy()) 319 | tmp.obs[sample_key] = tmp.obs[sample_key].astype(str) 320 | 321 | # If top_fraction is provided, first score top cells and filter 322 | if top_fraction is not None: 323 | if cell_attn_key not in tmp.obs.columns: 324 | raise ValueError(f"Key '{cell_attn_key}' not found in adata.obs. Required for top cell selection.") 325 | 326 | # Use the existing score_top_cells function 327 | score_top_cells(tmp, top_fraction=top_fraction, sample_key=sample_key, key_added="_top_cell_flag") 328 | 329 | # Filter to only top cells 330 | tmp = tmp[tmp.obs["_top_cell_flag"]].copy() 331 | tmp.obs = tmp.obs.drop("_top_cell_flag", axis=1) 332 | 333 | for i in range(tmp.X.shape[1]): 334 | if aggregation == "weighted": 335 | tmp.obs[f"latent{i}"] = tmp.X[:, i] * tmp.obs[cell_attn_key] 336 | elif aggregation == "mean": 337 | tmp.obs[f"latent{i}"] = tmp.X[:, i].copy() 338 | else: 339 | raise ValueError(f"Aggregation method {aggregation} is not supported. Use 'weighted' or 'mean'.") 340 | 341 | if covs_to_keep is not None: 342 | # check that covariates are sample-level i.e. have the same value for all cells in a sample and print warning if not and which ones are not 343 | for cov in covs_to_keep: 344 | if cov not in tmp.obs.columns: 345 | raise ValueError(f"Covariate '{cov}' not found in adata.obs. Available keys: {tmp.obs.columns}") 346 | # check that value is the same for all cells in each sample 347 | if tmp.obs.groupby(sample_key)[cov].nunique().max() > 1: 348 | raise ValueError( 349 | f"Covariate '{cov}' has different values for different cells in a sample. " 350 | "Please pass only sample-level covariates." 351 | ) 352 | if sample_key not in covs_to_keep: 353 | covs_to_keep = [sample_key] + covs_to_keep 354 | else: 355 | covs_to_keep = [sample_key] 356 | 357 | if aggregation == "weighted": 358 | df = ( 359 | tmp.obs[[f"latent{i}" for i in range(tmp.X.shape[1])] + [sample_key]].groupby(sample_key).agg("sum") 360 | ) # because already multiplied by normalized weights 361 | elif aggregation == "mean": 362 | df = tmp.obs[[f"latent{i}" for i in range(tmp.X.shape[1])] + [sample_key]].groupby(sample_key).agg("mean") 363 | df = df.join(tmp.obs[covs_to_keep].groupby(sample_key).agg("first")) 364 | 365 | final_covs = [cov for cov in covs_to_keep if cov != sample_key] 366 | pb = sc.AnnData(df.drop(final_covs, axis=1).values) 367 | pb.obs = df[final_covs].copy() 368 | 369 | return pb 370 | 371 | 372 | def score_top_cells(adata, top_fraction=0.1, sample_key=None, key_added="top_cell_attn"): 373 | """Score top cells based on cell attention weights. 374 | 375 | Parameters 376 | ---------- 377 | adata 378 | Annotated data object with cell attention weights in `adata.obs['cell_attn']`. 379 | top_fraction 380 | Fraction of top cells to select based on attention weights (default is 0.1). 381 | sample_key 382 | Key in `adata.obs` that identifies samples. If None, will calculate across all cells. 383 | key_added 384 | Key in `adata.obs` to store the top cell attention scores (default is 'top_cell_attn'). 
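Examples
--------
A minimal sketch (assumes ``adata.obs`` already contains ``"cell_attn"``, e.g. written by
:meth:`~multimil.model.MILClassifier.get_model_output`, and a sample column ``"sample"``):

>>> score_top_cells(adata, top_fraction=0.1, sample_key="sample", key_added="top_cell_attn")
>>> flagged = adata.obs["top_cell_attn"]  # boolean flag; roughly 10% of cells per sample are True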
385 | """ 386 | if "cell_attn" not in adata.obs.columns: 387 | raise ValueError("adata.obs must contain 'cell_attn' column with cell attention weights.") 388 | 389 | adata.obs[key_added] = False 390 | if sample_key is None: 391 | sample_key = "_tmp_sample" 392 | adata.obs[sample_key] = "_tmp_sample" 393 | for sample in np.unique(adata.obs[sample_key]): 394 | adata_sample = adata[adata.obs[sample_key] == sample].copy() 395 | threshold_idx = int(len(adata_sample) * (1 - top_fraction)) 396 | threshold_value = sorted(adata_sample.obs["cell_attn"])[threshold_idx] 397 | top_idx = adata_sample[adata_sample.obs["cell_attn"] >= threshold_value].obs_names 398 | adata.obs.loc[top_idx, key_added] = True 399 | if sample_key == "_tmp_sample": 400 | del adata.obs[sample_key] 401 | -------------------------------------------------------------------------------- /src/multimil/module/_mil_torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scvi import REGISTRY_KEYS 3 | from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from multimil.nn import MLP, Aggregator 8 | from multimil.utils import prep_minibatch, select_covariates 9 | 10 | 11 | class MILClassifierTorch(BaseModuleClass): 12 | """MultiMIL's MIL classification module. 13 | 14 | Parameters 15 | ---------- 16 | z_dim 17 | Latent dimension. 18 | dropout 19 | Dropout rate. 20 | normalization 21 | Normalization type. 22 | num_classification_classes 23 | Number of classes for each of the classification task. 24 | scoring 25 | Scoring type. One of ["gated_attn", "attn", "mean", "max", "sum"]. 26 | attn_dim 27 | Hidden attention dimension. 28 | n_layers_cell_aggregator 29 | Number of layers in the cell aggregator. 30 | n_layers_classifier 31 | Number of layers in the classifier. 32 | n_layers_regressor 33 | Number of layers in the regressor. 34 | n_hidden_regressor 35 | Hidden dimension in the regressor. 36 | n_hidden_cell_aggregator 37 | Hidden dimension in the cell aggregator. 38 | n_hidden_classifier 39 | Hidden dimension in the classifier. 40 | class_loss_coef 41 | Classification loss coefficient. 42 | regression_loss_coef 43 | Regression loss coefficient. 44 | sample_batch_size 45 | Sample batch size. 46 | class_idx 47 | Which indices in cat covariates to do classification on. 48 | ord_idx 49 | Which indices in cat covariates to do ordinal regression on. 50 | reg_idx 51 | Which indices in cont covariates to do regression on. 52 | activation 53 | Activation function. 54 | initialization 55 | Initialization type. 56 | anneal_class_loss 57 | Whether to anneal the classification loss. 58 | """ 59 | 60 | def __init__( 61 | self, 62 | z_dim=16, 63 | dropout=0.2, 64 | normalization="layer", 65 | num_classification_classes=None, # number of classes for each of the classification task 66 | scoring="gated_attn", 67 | attn_dim=16, 68 | n_layers_cell_aggregator=1, 69 | n_layers_classifier=2, 70 | n_layers_regressor=2, 71 | n_hidden_regressor=128, 72 | n_hidden_cell_aggregator=128, 73 | n_hidden_classifier=128, 74 | class_loss_coef=1.0, 75 | regression_loss_coef=1.0, 76 | sample_batch_size=128, 77 | class_idx=None, # which indices in cat covariates to do classification on, i.e. 
exclude from inference; this is a torch tensor 78 | ord_idx=None, # which indices in cat covariates to do ordinal regression on and also exclude from inference; this is a torch tensor 79 | reg_idx=None, # which indices in cont covariates to do regression on and also exclude from inference; this is a torch tensor 80 | activation="leaky_relu", 81 | initialization=None, 82 | anneal_class_loss=False, 83 | use_sample_cov=False, 84 | cat_sample_idx=None, 85 | cont_sample_idx=None, 86 | num_cat_cov_classes=None, 87 | min_cont_sample_cov=None, 88 | max_cont_sample_cov=None, 89 | ): 90 | super().__init__() 91 | 92 | if activation == "leaky_relu": 93 | self.activation = nn.LeakyReLU 94 | elif activation == "tanh": 95 | self.activation = nn.Tanh 96 | else: 97 | raise NotImplementedError( 98 | f'activation should be one of ["leaky_relu", "tanh"], but activation={activation} was passed.' 99 | ) 100 | 101 | self.class_loss_coef = class_loss_coef 102 | self.regression_loss_coef = regression_loss_coef 103 | self.sample_batch_size = sample_batch_size 104 | self.anneal_class_loss = anneal_class_loss 105 | self.num_classification_classes = num_classification_classes 106 | self.class_idx = class_idx 107 | self.ord_idx = ord_idx 108 | self.reg_idx = reg_idx 109 | self.cat_sample_idx = cat_sample_idx 110 | self.cont_sample_idx = cont_sample_idx 111 | self.num_cat_cov_classes = num_cat_cov_classes 112 | self.use_sample_cov = use_sample_cov 113 | self.min_cont_sample_cov = min_cont_sample_cov 114 | self.max_cont_sample_cov = max_cont_sample_cov 115 | 116 | self.cell_level_aggregator = nn.Sequential( 117 | MLP( 118 | z_dim, 119 | z_dim, 120 | n_layers=n_layers_cell_aggregator, 121 | n_hidden=n_hidden_cell_aggregator, 122 | dropout_rate=dropout, 123 | activation=self.activation, 124 | normalization=normalization, 125 | ), 126 | Aggregator( 127 | n_input=z_dim, 128 | scoring=scoring, 129 | attn_dim=attn_dim, 130 | sample_batch_size=sample_batch_size, 131 | scale=True, 132 | dropout=dropout, 133 | activation=self.activation, 134 | ), 135 | ) 136 | 137 | if use_sample_cov is True: 138 | class_input_dim = z_dim 139 | for n in num_cat_cov_classes: 140 | class_input_dim += n 141 | for _ in range(len(self.cont_sample_idx)): 142 | class_input_dim += 1 143 | else: 144 | class_input_dim = z_dim 145 | 146 | if len(self.class_idx) > 0: 147 | self.classifiers = torch.nn.ModuleList() 148 | 149 | for num in self.num_classification_classes: 150 | if n_layers_classifier == 1: 151 | self.classifiers.append(nn.Linear(class_input_dim, num)) 152 | else: 153 | self.classifiers.append( 154 | nn.Sequential( 155 | MLP( 156 | class_input_dim, 157 | n_hidden_classifier, 158 | n_layers=n_layers_classifier - 1, 159 | n_hidden=n_hidden_classifier, 160 | dropout_rate=dropout, 161 | activation=self.activation, 162 | ), 163 | nn.Linear(n_hidden_classifier, num), 164 | ) 165 | ) 166 | 167 | if len(self.ord_idx) + len(self.reg_idx) > 0: 168 | self.regressors = torch.nn.ModuleList() 169 | for _ in range( 170 | len(self.ord_idx) + len(self.reg_idx) 171 | ): # one head per standard regression and one per ordinal regression 172 | if n_layers_regressor == 1: 173 | self.regressors.append(nn.Linear(z_dim, 1)) 174 | else: 175 | self.regressors.append( 176 | nn.Sequential( 177 | MLP( 178 | class_input_dim, 179 | n_hidden_regressor, 180 | n_layers=n_layers_regressor - 1, 181 | n_hidden=n_hidden_regressor, 182 | dropout_rate=dropout, 183 | activation=self.activation, 184 | ), 185 | nn.Linear(n_hidden_regressor, 1), 186 | ) 187 | ) 188 | 189 | if initialization 
== "xavier": 190 | for layer in self.modules(): 191 | if isinstance(layer, nn.Linear): 192 | nn.init.xavier_uniform_(layer.weight, gain=nn.init.calculate_gain("leaky_relu")) 193 | elif initialization == "kaiming": 194 | for layer in self.modules(): 195 | if isinstance(layer, nn.Linear): 196 | # following https://towardsdatascience.com/understand-kaiming-initialization-and-implementation-detail-in-pytorch-f7aa967e9138 (accessed 16.08.22) 197 | nn.init.kaiming_normal_(layer.weight, mode="fan_in") 198 | 199 | def _get_inference_input(self, tensors): 200 | x = tensors[REGISTRY_KEYS.X_KEY] 201 | if self.use_sample_cov is True: 202 | cont_key = REGISTRY_KEYS.CONT_COVS_KEY 203 | cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None 204 | 205 | cat_key = REGISTRY_KEYS.CAT_COVS_KEY 206 | cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None 207 | 208 | _, n_samples_in_batch = prep_minibatch(cat_covs, self.sample_batch_size) 209 | if self.cat_sample_idx.shape[0] > 0: 210 | cat_sample_covs = select_covariates( 211 | cat_covs.to(self.device), self.cat_sample_idx.to(self.device), n_samples_in_batch 212 | ) 213 | else: 214 | cat_sample_covs = None 215 | if self.cont_sample_idx.shape[0] > 0: 216 | cont_sample_covs = select_covariates( 217 | cont_covs.to(self.device), self.cont_sample_idx.to(self.device), n_samples_in_batch 218 | ) 219 | else: 220 | cont_sample_covs = None 221 | else: 222 | cat_sample_covs = None 223 | cont_sample_covs = None 224 | 225 | return {"x": x, "cat_sample_covs": cat_sample_covs, "cont_sample_covs": cont_sample_covs} 226 | 227 | def _get_generative_input(self, tensors, inference_outputs): 228 | z = inference_outputs["z"] 229 | return {"z": z} 230 | 231 | @auto_move_data 232 | def inference(self, x, cat_sample_covs, cont_sample_covs) -> dict[str, torch.Tensor | list[torch.Tensor]]: 233 | """Forward pass for inference. 234 | 235 | Parameters 236 | ---------- 237 | x 238 | Input. 239 | cat_sample_covs 240 | Categorical sample covariates. 241 | cont_sample_covs 242 | Continuous sample covariates. 243 | 244 | Returns 245 | ------- 246 | Predictions. 
247 | """ 248 | z = x 249 | inference_outputs = {"z": z} 250 | 251 | # MIL part 252 | batch_size = x.shape[0] 253 | 254 | idx = list(range(self.sample_batch_size, batch_size, self.sample_batch_size)) 255 | if ( 256 | batch_size % self.sample_batch_size != 0 257 | ): # can only happen during inference for last batches for each sample 258 | idx = [] 259 | zs = torch.tensor_split(z, idx, dim=0) 260 | zs = torch.stack(zs, dim=0) # num of bags x batch_size x z_dim 261 | zs_attn = self.cell_level_aggregator(zs) # num of bags x cond_dim 262 | 263 | if cat_sample_covs is not None: 264 | one_hot_cat_sample_covs = [] 265 | for i, num_classes in zip(range(cat_sample_covs.shape[1]), self.num_cat_cov_classes, strict=False): 266 | one_hot_sample = F.one_hot(cat_sample_covs.long()[:, i], num_classes=num_classes) 267 | one_hot_cat_sample_covs.append(one_hot_sample) 268 | 269 | cat_sample_covs = torch.cat(one_hot_cat_sample_covs, dim=1) 270 | zs_attn = torch.cat([zs_attn, cat_sample_covs], dim=1) 271 | 272 | if cont_sample_covs is not None: 273 | # min max scale continuous sample covariates 274 | for i in range(cont_sample_covs.shape[1]): 275 | cont_sample_covs[:, i] = (cont_sample_covs[:, i] - self.min_cont_sample_cov[i]) / ( 276 | self.max_cont_sample_cov[i] - self.min_cont_sample_cov[i] 277 | ) 278 | zs_attn = torch.cat([zs_attn, cont_sample_covs], dim=1) 279 | 280 | predictions = [] 281 | if len(self.class_idx) > 0: 282 | predictions.extend([classifier(zs_attn) for classifier in self.classifiers]) 283 | if len(self.ord_idx) + len(self.reg_idx) > 0: 284 | predictions.extend([regressor(zs_attn) for regressor in self.regressors]) 285 | 286 | inference_outputs.update( 287 | {"predictions": predictions} 288 | ) # predictions are a list as they can have different number of classes 289 | return inference_outputs 290 | 291 | @auto_move_data 292 | def generative(self, z) -> dict[str, torch.Tensor | list[torch.Tensor]]: 293 | """Forward pass for generative. 294 | 295 | Parameters 296 | ---------- 297 | z 298 | Latent embeddings. 299 | 300 | Returns 301 | ------- 302 | Tensor of same shape as input. 
303 | """ 304 | return {"z": z} 305 | 306 | def _calculate_loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0): 307 | cont_key = REGISTRY_KEYS.CONT_COVS_KEY 308 | cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None 309 | 310 | cat_key = REGISTRY_KEYS.CAT_COVS_KEY 311 | cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None 312 | 313 | # MIL classification loss 314 | minibatch_size, n_samples_in_batch = prep_minibatch(cat_covs, self.sample_batch_size) 315 | regression = select_covariates(cont_covs, self.reg_idx.to(self.device), n_samples_in_batch) 316 | ordinal_regression = select_covariates(cat_covs, self.ord_idx.to(self.device), n_samples_in_batch) 317 | classification = select_covariates(cat_covs, self.class_idx.to(self.device), n_samples_in_batch) 318 | 319 | predictions = inference_outputs["predictions"] # list, first from classifiers, then from regressors 320 | 321 | accuracies = [] 322 | classification_loss = torch.tensor(0.0).to(self.device) 323 | for i in range(len(self.class_idx)): 324 | classification_loss += F.cross_entropy( 325 | predictions[i], classification[:, i].long() 326 | ) # assume same in the batch 327 | accuracies.append( 328 | torch.sum(torch.eq(torch.argmax(predictions[i], dim=-1), classification[:, i])) 329 | / classification[:, i].shape[0] 330 | ) 331 | 332 | regression_loss = torch.tensor(0.0).to(self.device) 333 | for i in range(len(self.ord_idx)): 334 | regression_loss += F.mse_loss(predictions[len(self.class_idx) + i].squeeze(-1), ordinal_regression[:, i]) 335 | accuracies.append( 336 | torch.sum( 337 | torch.eq( 338 | torch.clamp( 339 | torch.round(predictions[len(self.class_idx) + i].squeeze()), 340 | min=0.0, 341 | max=self.num_classification_classes[i] - 1.0, 342 | ), 343 | ordinal_regression[:, i], 344 | ) 345 | ) 346 | / ordinal_regression[:, i].shape[0] 347 | ) 348 | 349 | for i in range(len(self.reg_idx)): 350 | regression_loss += F.mse_loss( 351 | predictions[len(self.class_idx) + len(self.ord_idx) + i].squeeze(-1), 352 | regression[:, i], 353 | ) 354 | 355 | class_loss_anneal_coef = kl_weight if self.anneal_class_loss else 1.0 356 | 357 | loss = torch.mean( 358 | self.class_loss_coef * classification_loss * class_loss_anneal_coef 359 | + self.regression_loss_coef * regression_loss 360 | ) 361 | 362 | extra_metrics = { 363 | "class_loss": classification_loss, 364 | "regression_loss": regression_loss, 365 | } 366 | 367 | if len(accuracies) > 0: 368 | accuracy = torch.sum(torch.tensor(accuracies)) / len(accuracies) 369 | extra_metrics["accuracy"] = accuracy 370 | 371 | # don't need in this model but have to return 372 | recon_loss = torch.zeros(minibatch_size) 373 | kl_loss = torch.zeros(minibatch_size) 374 | 375 | return loss, recon_loss, kl_loss, extra_metrics 376 | 377 | def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0): 378 | """Loss calculation. 379 | 380 | Parameters 381 | ---------- 382 | tensors 383 | Input tensors. 384 | inference_outputs 385 | Inference outputs. 386 | generative_outputs 387 | Generative outputs. 388 | kl_weight 389 | KL weight. Default is 1.0. 390 | 391 | Returns 392 | ------- 393 | Prediction loss. 
394 | """ 395 | loss, recon_loss, kl_loss, extra_metrics = self._calculate_loss( 396 | tensors, inference_outputs, generative_outputs, kl_weight 397 | ) 398 | 399 | return LossOutput( 400 | loss=loss, 401 | reconstruction_loss=recon_loss, 402 | kl_local=kl_loss, 403 | extra_metrics=extra_metrics, 404 | ) 405 | 406 | def select_losses_to_plot(self): 407 | """Select losses to plot. 408 | 409 | Returns 410 | ------- 411 | Loss names. 412 | """ 413 | loss_names = [] 414 | if self.class_loss_coef != 0 and len(self.class_idx) > 0: 415 | loss_names.extend(["class_loss", "accuracy"]) 416 | if self.regression_loss_coef != 0 and len(self.reg_idx) > 0: 417 | loss_names.append("regression_loss") 418 | if self.regression_loss_coef != 0 and len(self.ord_idx) > 0: 419 | loss_names.extend(["regression_loss", "accuracy"]) 420 | return loss_names 421 | -------------------------------------------------------------------------------- /src/multimil/model/_mil.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import warnings 3 | 4 | import anndata as ad 5 | import torch 6 | from anndata import AnnData 7 | from pytorch_lightning.callbacks import ModelCheckpoint 8 | from scvi import REGISTRY_KEYS 9 | from scvi.data import AnnDataManager, fields 10 | from scvi.data._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY 11 | from scvi.model._utils import parse_use_gpu_arg 12 | from scvi.model.base import ArchesMixin, BaseModelClass 13 | from scvi.model.base._archesmixin import _get_loaded_data 14 | from scvi.model.base._utils import _initialize_model 15 | from scvi.train import AdversarialTrainingPlan, TrainRunner 16 | from scvi.train._callbacks import SaveBestState 17 | 18 | from multimil.dataloaders import GroupAnnDataLoader, GroupDataSplitter 19 | from multimil.module import MILClassifierTorch 20 | from multimil.utils import ( 21 | get_bag_info, 22 | get_predictions, 23 | plt_plot_losses, 24 | prep_minibatch, 25 | save_predictions_in_adata, 26 | select_covariates, 27 | setup_ordinal_regression, 28 | ) 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | 33 | class MILClassifier(BaseModelClass, ArchesMixin): 34 | """MultiMIL MIL prediction model. 35 | 36 | Parameters 37 | ---------- 38 | adata 39 | AnnData object containing embeddings and covariates. 40 | sample_key 41 | Key in `adata.obs` that corresponds to the sample covariate. 42 | classification 43 | List of keys in `adata.obs` that correspond to the classification covariates. 44 | regression 45 | List of keys in `adata.obs` that correspond to the regression covariates. 46 | ordinal_regression 47 | List of keys in `adata.obs` that correspond to the ordinal regression covariates. 48 | sample_batch_size 49 | Number of samples per bag, i.e. sample. Default is 128. 50 | normalization 51 | One of "layer" or "batch". Default is "layer". 52 | z_dim 53 | Dimensionality of the input latent space. Default is 16. 54 | dropout 55 | Dropout rate. Default is 0.2. 56 | scoring 57 | How to calculate attention scores. One of "gated_attn", "attn", "mean", "max", "sum". Default is "gated_attn". 58 | attn_dim 59 | Dimensionality of the hidden layer in attention calculation. Default is 16. 60 | n_layers_cell_aggregator 61 | Number of layers in the cell aggregator. Default is 1. 62 | n_layers_classifier 63 | Number of layers in the classifier. Default is 2. 64 | n_layers_regressor 65 | Number of layers in the regressor. Default is 2. 66 | n_hidden_cell_aggregator 67 | Number of hidden units in the cell aggregator. 
Default is 128. 68 | n_hidden_classifier 69 | Number of hidden units in the classifier. Default is 128. 70 | n_hidden_regressor 71 | Number of hidden units in the regressor. Default is 128. 72 | class_loss_coef 73 | Coefficient for the classification loss. Default is 1.0. 74 | regression_loss_coef 75 | Coefficient for the regression loss. Default is 1.0. 76 | activation 77 | Activation function. Default is 'leaky_relu'. 78 | initialization 79 | Initialization method for the weights. Default is None. 80 | anneal_class_loss 81 | Whether to anneal the classification loss. Default is False. 82 | 83 | """ 84 | 85 | def __init__( 86 | self, 87 | adata, 88 | sample_key, 89 | classification=None, 90 | regression=None, 91 | ordinal_regression=None, 92 | sample_batch_size=128, 93 | normalization="layer", 94 | dropout=0.2, 95 | scoring="gated_attn", # How to calculate attention scores 96 | attn_dim=16, 97 | n_layers_cell_aggregator: int = 1, 98 | n_layers_classifier: int = 2, 99 | n_layers_regressor: int = 2, 100 | n_hidden_cell_aggregator: int = 128, 101 | n_hidden_classifier: int = 128, 102 | n_hidden_regressor: int = 128, 103 | class_loss_coef=1.0, 104 | regression_loss_coef=1.0, 105 | activation="leaky_relu", # or tanh 106 | initialization=None, # xavier (tanh) or kaiming (leaky_relu) 107 | anneal_class_loss=False, 108 | ): 109 | super().__init__(adata) 110 | 111 | z_dim = adata.X.shape[1] 112 | 113 | if classification is None: 114 | classification = [] 115 | if regression is None: 116 | regression = [] 117 | if ordinal_regression is None: 118 | ordinal_regression = [] 119 | 120 | self.sample_key = sample_key 121 | self.scoring = scoring 122 | 123 | if self.sample_key not in self.adata_manager.registry["setup_args"]["categorical_covariate_keys"]: 124 | raise ValueError( 125 | f"Sample key = '{self.sample_key}' has to be one of the registered categorical covariates = {self.adata_manager.registry['setup_args']['categorical_covariate_keys']}" 126 | ) 127 | 128 | if len(classification) + len(regression) + len(ordinal_regression) == 0: 129 | raise ValueError( 130 | 'At least one of "classification", "regression", "ordinal_regression" has to be specified.' 
131 | ) 132 | 133 | self.classification = classification 134 | self.regression = regression 135 | self.ordinal_regression = ordinal_regression 136 | 137 | # check if all of the three above were registered with setup anndata 138 | for key in classification + ordinal_regression: 139 | if key not in self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)["field_keys"]: 140 | raise ValueError(f"Key '{key}' is not registered as categorical covariates.") 141 | 142 | for key in regression: 143 | if key not in self.adata_manager.get_state_registry(REGISTRY_KEYS.CONT_COVS_KEY)["columns"]: 144 | raise ValueError(f"Key '{key}' is not registered as continuous covariates.") 145 | 146 | # use the rest of the categoricalcovariates for sample covariates 147 | self.cat_sample_idx, self.cont_sample_idx = [], [] 148 | self.num_cat_cov_classes = [] 149 | cat_sample_covs, cont_sample_covs = [], [] 150 | 151 | if len(cat_covs := self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)) > 0: 152 | for i, num_cat in enumerate(cat_covs.n_cats_per_key): 153 | cat_cov_name = cat_covs["field_keys"][i] 154 | 155 | if ( 156 | cat_cov_name not in self.classification 157 | and cat_cov_name not in self.ordinal_regression 158 | and cat_cov_name != self.sample_key 159 | ): 160 | self.cat_sample_idx.append(i) 161 | cat_sample_covs.append(cat_cov_name) 162 | self.num_cat_cov_classes.append(num_cat) 163 | 164 | min_cont_sample_cov = [] 165 | max_cont_sample_cov = [] 166 | if len(cont_covs := self.adata_manager.get_state_registry(REGISTRY_KEYS.CONT_COVS_KEY)) > 0: 167 | for key in cont_covs["columns"]: 168 | if key not in self.regression: 169 | self.cont_sample_idx.append(list(cont_covs["columns"]).index(key)) 170 | cont_sample_covs.append(key) 171 | # store the min and max of the continuous sample covariates 172 | min_cont_sample_cov.append(adata.obsm["_scvi_extra_continuous_covs"][key].min()) 173 | max_cont_sample_cov.append(adata.obsm["_scvi_extra_continuous_covs"][key].max()) 174 | 175 | use_sample_cov = False 176 | if len(cat_sample_covs) + len(cont_sample_covs) > 0: 177 | use_sample_cov = True 178 | if len(cat_sample_covs) > 0: 179 | print(f"Using {cat_sample_covs} as categorical sample covariates.") 180 | if len(cont_sample_covs) > 0: 181 | print(f"Using {cont_sample_covs} as continuous sample covariates.") 182 | 183 | # TODO add check that class is the same within a patient 184 | # TODO add that n_layers has to be > 0 for all 185 | # TODO warning if n_layers == 1 then n_hidden is not used for classifier 186 | # TODO check that there is at least one covariate to predict 187 | 188 | self.regression_idx = [] 189 | if len(cont_covs := self.adata_manager.get_state_registry(REGISTRY_KEYS.CONT_COVS_KEY)) > 0: 190 | for key in cont_covs["columns"]: 191 | if key in self.regression: 192 | self.regression_idx.append(list(cont_covs["columns"]).index(key)) 193 | 194 | # classification and ordinal regression together here as ordinal regression values need to be registered as categorical covariates 195 | self.class_idx, self.ord_idx = [], [] 196 | self.num_classification_classes = [] 197 | if len(cat_covs := self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)) > 0: 198 | for i, num_cat in enumerate(cat_covs.n_cats_per_key): 199 | cat_cov_name = cat_covs["field_keys"][i] 200 | if cat_cov_name in self.classification: 201 | self.num_classification_classes.append(num_cat) 202 | self.class_idx.append(i) 203 | elif cat_cov_name in self.ordinal_regression: 204 | self.num_classification_classes.append(num_cat) 
205 | self.ord_idx.append(i) 206 | 207 | for label in ordinal_regression: 208 | print( 209 | f"The order for {label} ordinal classes is: {adata.obs[label].cat.categories}. If you need to change the order, please rerun setup_anndata and specify the correct order with the `ordinal_regression_order` parameter." 210 | ) 211 | 212 | self.class_idx = torch.tensor(self.class_idx) 213 | self.ord_idx = torch.tensor(self.ord_idx) 214 | self.regression_idx = torch.tensor(self.regression_idx) 215 | self.cat_sample_idx = torch.tensor(self.cat_sample_idx) 216 | self.cont_sample_idx = torch.tensor(self.cont_sample_idx) 217 | 218 | self.module = MILClassifierTorch( 219 | z_dim=z_dim, 220 | dropout=dropout, 221 | activation=activation, 222 | initialization=initialization, 223 | normalization=normalization, 224 | num_classification_classes=self.num_classification_classes, 225 | scoring=scoring, 226 | attn_dim=attn_dim, 227 | n_layers_cell_aggregator=n_layers_cell_aggregator, 228 | n_layers_classifier=n_layers_classifier, 229 | n_layers_regressor=n_layers_regressor, 230 | n_hidden_regressor=n_hidden_regressor, 231 | n_hidden_cell_aggregator=n_hidden_cell_aggregator, 232 | n_hidden_classifier=n_hidden_classifier, 233 | class_loss_coef=class_loss_coef, 234 | regression_loss_coef=regression_loss_coef, 235 | sample_batch_size=sample_batch_size, 236 | class_idx=self.class_idx, 237 | ord_idx=self.ord_idx, 238 | reg_idx=self.regression_idx, 239 | anneal_class_loss=anneal_class_loss, 240 | cat_sample_idx=self.cat_sample_idx, 241 | cont_sample_idx=self.cont_sample_idx, 242 | num_cat_cov_classes=self.num_cat_cov_classes, 243 | use_sample_cov=use_sample_cov, 244 | min_cont_sample_cov=min_cont_sample_cov, 245 | max_cont_sample_cov=max_cont_sample_cov, 246 | ) 247 | 248 | self.init_params_ = self._get_init_params(locals()) 249 | 250 | def train( 251 | self, 252 | max_epochs: int = 200, 253 | lr: float = 5e-4, 254 | use_gpu: str | int | bool | None = None, 255 | train_size: float = 0.9, 256 | validation_size: float | None = None, 257 | batch_size: int = 256, 258 | weight_decay: float = 1e-3, 259 | eps: float = 1e-08, 260 | early_stopping: bool = True, 261 | save_best: bool = True, 262 | check_val_every_n_epoch: int | None = None, 263 | n_epochs_kl_warmup: int | None = None, 264 | n_steps_kl_warmup: int | None = None, 265 | adversarial_mixing: bool = False, 266 | plan_kwargs: dict | None = None, 267 | early_stopping_monitor: str | None = "accuracy_validation", 268 | early_stopping_mode: str | None = "max", 269 | save_checkpoint_every_n_epochs: int | None = None, 270 | path_to_checkpoints: str | None = None, 271 | **kwargs, 272 | ): 273 | """Trains the model using amortized variational inference. 274 | 275 | Parameters 276 | ---------- 277 | max_epochs 278 | Number of passes through the dataset. 279 | lr 280 | Learning rate for optimization. 281 | use_gpu 282 | Use default GPU if available (if None or True), or index of GPU to use (if int), 283 | or name of GPU (if str), or use CPU (if False). 284 | train_size 285 | Size of training set in the range [0.0, 1.0]. 286 | validation_size 287 | Size of the test set. If `None`, defaults to 1 - `train_size`. If 288 | `train_size + validation_size < 1`, the remaining cells belong to a test set. 289 | batch_size 290 | Minibatch size to use during training. 291 | weight_decay 292 | weight decay regularization term for optimization 293 | eps 294 | Optimizer eps 295 | early_stopping 296 | Whether to perform early stopping with respect to the validation set. 
297 | save_best 298 | Save the best model state with respect to the validation loss, or use the final 299 | state in the training procedure 300 | check_val_every_n_epoch 301 | Check val every n train epochs. By default, val is not checked, unless `early_stopping` is `True`. 302 | If so, val is checked every epoch. 303 | n_epochs_kl_warmup 304 | Number of epochs to scale weight on KL divergences from 0 to 1. 305 | Overrides `n_steps_kl_warmup` when both are not `None`. Default is 1/3 of `max_epochs`. 306 | n_steps_kl_warmup 307 | Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1. 308 | Only activated when `n_epochs_kl_warmup` is set to None. If `None`, defaults 309 | to `floor(0.75 * adata.n_obs)` 310 | adversarial_mixing 311 | Whether to use adversarial mixing. Default is False. 312 | plan_kwargs 313 | Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to 314 | `train()` will overwrite values present in `plan_kwargs`, when appropriate. 315 | early_stopping_monitor 316 | Metric to monitor for early stopping. Default is "accuracy_validation". 317 | early_stopping_mode 318 | One of "min" or "max". Default is "max". 319 | save_checkpoint_every_n_epochs 320 | Save a checkpoint every n epochs. 321 | path_to_checkpoints 322 | Path to save checkpoints. 323 | **kwargs 324 | Other keyword args for :class:`~scvi.train.Trainer`. 325 | 326 | Returns 327 | ------- 328 | Trainer object. 329 | """ 330 | if len(self.regression) > 0: 331 | if early_stopping_monitor == "accuracy_validation": 332 | warnings.warn( 333 | "Setting early_stopping_monitor to 'regression_loss_validation' and early_stopping_mode to 'min' as regression is used.", 334 | stacklevel=2, 335 | ) 336 | early_stopping_monitor = "regression_loss_validation" 337 | early_stopping_mode = "min" 338 | if n_epochs_kl_warmup is None: 339 | n_epochs_kl_warmup = max(max_epochs // 3, 1) 340 | update_dict = { 341 | "lr": lr, 342 | "adversarial_classifier": adversarial_mixing, 343 | "weight_decay": weight_decay, 344 | "eps": eps, 345 | "n_epochs_kl_warmup": n_epochs_kl_warmup, 346 | "n_steps_kl_warmup": n_steps_kl_warmup, 347 | "optimizer": "AdamW", 348 | "scale_adversarial_loss": 1, 349 | } 350 | if plan_kwargs is not None: 351 | plan_kwargs.update(update_dict) 352 | else: 353 | plan_kwargs = update_dict 354 | 355 | if save_best: 356 | if "callbacks" not in kwargs.keys(): 357 | kwargs["callbacks"] = [] 358 | kwargs["callbacks"].append(SaveBestState(monitor=early_stopping_monitor, mode=early_stopping_mode)) 359 | 360 | if save_checkpoint_every_n_epochs is not None: 361 | if path_to_checkpoints is not None: 362 | kwargs["callbacks"].append( 363 | ModelCheckpoint( 364 | dirpath=path_to_checkpoints, 365 | save_top_k=-1, 366 | monitor="epoch", 367 | every_n_epochs=save_checkpoint_every_n_epochs, 368 | verbose=True, 369 | ) 370 | ) 371 | else: 372 | raise ValueError( 373 | f"`save_checkpoint_every_n_epochs` = {save_checkpoint_every_n_epochs} so `path_to_checkpoints` has to be not None but is {path_to_checkpoints}." 
 374 | ) 375 | # until here 376 | 377 | data_splitter = GroupDataSplitter( 378 | self.adata_manager, 379 | group_column=self.sample_key, 380 | train_size=train_size, 381 | validation_size=validation_size, 382 | batch_size=batch_size, 383 | use_gpu=use_gpu, 384 | ) 385 | 386 | training_plan = AdversarialTrainingPlan(self.module, **plan_kwargs) 387 | runner = TrainRunner( 388 | self, 389 | training_plan=training_plan, 390 | data_splitter=data_splitter, 391 | max_epochs=max_epochs, 392 | use_gpu=use_gpu, 393 | early_stopping=early_stopping, 394 | check_val_every_n_epoch=check_val_every_n_epoch, 395 | early_stopping_monitor=early_stopping_monitor, 396 | early_stopping_mode=early_stopping_mode, 397 | early_stopping_patience=50, 398 | enable_checkpointing=True, 399 | **kwargs, 400 | ) 401 | return runner() 402 | 403 | @classmethod 404 | def setup_anndata( 405 | cls, 406 | adata: ad.AnnData, 407 | categorical_covariate_keys: list[str] | None = None, 408 | continuous_covariate_keys: list[str] | None = None, 409 | ordinal_regression_order: dict[str, list[str]] | None = None, 410 | **kwargs, 411 | ): 412 | """Set up :class:`~anndata.AnnData` object. 413 | 414 | A mapping will be created between data fields used by ``scvi`` and their respective locations in adata. 415 | The sample key and all covariates used for prediction have to be passed as categorical or continuous covariate keys here. 416 | None of the data in adata are modified. Only adds fields to adata. 417 | 418 | Parameters 419 | ---------- 420 | adata 421 | AnnData object containing cell-level representations (e.g. a precomputed latent embedding) in `adata.X`. Rows represent cells, columns represent features. 422 | categorical_covariate_keys 423 | Keys in `adata.obs` that correspond to categorical data. 424 | continuous_covariate_keys 425 | Keys in `adata.obs` that correspond to continuous data. 426 | ordinal_regression_order 427 | Dictionary with ordinal regression covariate names as keys and the ordered lists of their classes as values. 428 | kwargs 429 | Additional parameters to pass to register_fields() of AnnDataManager. 430 | """ 431 | if categorical_covariate_keys is not None: 432 | for key in categorical_covariate_keys: 433 | adata.obs[key] = adata.obs[key].astype("category") 434 | 435 | setup_ordinal_regression(adata, ordinal_regression_order, categorical_covariate_keys) 436 | 437 | setup_method_args = cls._get_setup_method_args(**locals()) 438 | 439 | anndata_fields = [ 440 | fields.LayerField( 441 | REGISTRY_KEYS.X_KEY, 442 | layer=None, 443 | ), 444 | fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, None), 445 | fields.CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), 446 | fields.NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), 447 | ] 448 | 449 | adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) 450 | adata_manager.register_fields(adata, **kwargs) 451 | cls.register_manager(adata_manager) 452 | 453 | @torch.inference_mode() 454 | def get_model_output( 455 | self, 456 | adata=None, 457 | batch_size=256, 458 | ): 459 | """Save the attention scores and predictions in the adata object. 460 | 461 | Parameters 462 | ---------- 463 | adata 464 | AnnData object to run the model on. If `None`, the model's AnnData object is used. 465 | batch_size 466 | Minibatch size to use. Default is 256. 
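Examples
--------
A sketch of the typical call (assumes a trained model that was instantiated with
``classification=["condition"]``, where ``"condition"`` is an illustrative covariate name;
result keys follow the ``predicted_<name>`` / ``bag_full_predictions_<name>`` naming used below):

>>> model.get_model_output(adata, batch_size=256)
>>> cell_attn = adata.obs["cell_attn"]  # per-cell attention scores (attention-based scoring only)
>>> cell_level = adata.obs["predicted_condition"]  # per-cell predicted class
>>> bag_level = adata.uns["bag_full_predictions_condition"]  # per-bag predictions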
467 |
468 |         """
469 |         if not self.is_trained_:
470 |             raise RuntimeError("Please train the model first.")
471 |
472 |         adata = self._validate_anndata(adata)
473 |
474 |         scdl = self._make_data_loader(
475 |             adata=adata,
476 |             batch_size=batch_size,
477 |             min_size_per_class=batch_size,  # workaround to ensure that incomplete batches are processed properly
478 |             data_loader_class=GroupAnnDataLoader,
479 |             shuffle=False,
480 |             shuffle_classes=False,
481 |             group_column=self.sample_key,
482 |             drop_last=False,
483 |         )
484 |
485 |         cell_level_attn, bags = [], []
486 |         class_pred, ord_pred, reg_pred = {}, {}, {}
487 |         (
488 |             bag_class_true,
489 |             bag_class_pred,
490 |             bag_reg_true,
491 |             bag_reg_pred,
492 |             bag_ord_true,
493 |             bag_ord_pred,
494 |         ) = ({}, {}, {}, {}, {}, {})
495 |
496 |         bag_counter = 0
497 |         cell_counter = 0
498 |
499 |         for tensors in scdl:
500 |             cont_key = REGISTRY_KEYS.CONT_COVS_KEY
501 |             cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None
502 |
503 |             cat_key = REGISTRY_KEYS.CAT_COVS_KEY
504 |             cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None
505 |
506 |             inference_inputs = self.module._get_inference_input(tensors)
507 |             outputs = self.module.inference(**inference_inputs)
508 |             pred = outputs["predictions"]
509 |
510 |             # get attention for each cell in the bag
511 |             if self.scoring in ["gated_attn", "attn"]:
512 |                 cell_attn = self.module.cell_level_aggregator[-1].A.squeeze(dim=1)
513 |                 cell_attn = cell_attn.flatten()  # during inference there is always one patient per batch
514 |                 cell_level_attn += [cell_attn.cpu()]
515 |
516 |             assert outputs["z"].shape[0] % pred[0].shape[0] == 0
517 |             sample_size = outputs["z"].shape[0] // pred[0].shape[0]  # how many cells are in the patient minibatch
518 |             minibatch_size, n_samples_in_batch = prep_minibatch(cat_covs, self.module.sample_batch_size)
519 |             regression = select_covariates(cont_covs, self.regression_idx, n_samples_in_batch)
520 |             ordinal_regression = select_covariates(cat_covs, self.ord_idx, n_samples_in_batch)
521 |             classification = select_covariates(cat_covs, self.class_idx, n_samples_in_batch)
522 |
523 |             # calculate accuracies of predictions
524 |             bag_class_pred, bag_class_true, class_pred = get_predictions(
525 |                 self.class_idx, pred, classification, sample_size, bag_class_pred, bag_class_true, class_pred
526 |             )
527 |             bag_ord_pred, bag_ord_true, ord_pred = get_predictions(
528 |                 self.ord_idx,
529 |                 pred,
530 |                 ordinal_regression,
531 |                 sample_size,
532 |                 bag_ord_pred,
533 |                 bag_ord_true,
534 |                 ord_pred,
535 |                 len(self.class_idx),
536 |             )
537 |             bag_reg_pred, bag_reg_true, reg_pred = get_predictions(
538 |                 self.regression_idx,
539 |                 pred,
540 |                 regression,
541 |                 sample_size,
542 |                 bag_reg_pred,
543 |                 bag_reg_true,
544 |                 reg_pred,
545 |                 len(self.class_idx) + len(self.ord_idx),
546 |             )
547 |
548 |             # save bag info to be able to calculate bag predictions later
549 |             bags, cell_counter, bag_counter = get_bag_info(
550 |                 bags, n_samples_in_batch, minibatch_size, cell_counter, bag_counter, self.module.sample_batch_size
551 |             )
552 |
553 |         if self.scoring in ["gated_attn", "attn"]:
554 |             cell_level = torch.cat(cell_level_attn).numpy()
555 |             adata.obs["cell_attn"] = cell_level
556 |         flat_bags = [value for sublist in bags for value in sublist]
557 |         adata.obs["bags"] = flat_bags
558 |
559 |         for i in range(len(self.class_idx)):
560 |             name = self.classification[i]
561 |             class_names = self.adata_manager.get_state_registry("extra_categorical_covs")["mappings"][name]
562 |             save_predictions_in_adata(
563 |                 adata,
564 |                 i,
565 |                 self.classification,
566 |                 bag_class_pred,
567 |                 bag_class_true,
568 |                 class_pred,
569 |                 class_names,
570 |                 name,
571 |                 clip="argmax",
572 |             )
573 |         for i in range(len(self.ord_idx)):
574 |             name = self.ordinal_regression[i]
575 |             class_names = self.adata_manager.get_state_registry("extra_categorical_covs")["mappings"][name]
576 |             save_predictions_in_adata(
577 |                 adata, i, self.ordinal_regression, bag_ord_pred, bag_ord_true, ord_pred, class_names, name, clip="clip"
578 |             )
579 |         for i in range(len(self.regression_idx)):
580 |             name = self.regression[i]
581 |             reg_names = self.adata_manager.get_state_registry("extra_continuous_covs")["columns"]
582 |             save_predictions_in_adata(
583 |                 adata, i, self.regression, bag_reg_pred, bag_reg_true, reg_pred, reg_names, name, clip=None, reg=True
584 |             )
585 |
586 |     def plot_losses(self, save=None):
587 |         """Plot losses.
588 |
589 |         Parameters
590 |         ----------
591 |         save
592 |             If not None, save the plot to this location.
593 |         """
594 |         loss_names = self.module.select_losses_to_plot()
595 |         plt_plot_losses(self.history, loss_names, save)
596 |
597 |     # adjusted from scvi-tools
598 |     # https://github.com/scverse/scvi-tools/blob/0b802762869c43c9f49e69fe62b1a5a9b5c4dae6/scvi/model/base/_archesmixin.py#L30
599 |     # accessed on 7 November 2022
600 |     @classmethod
601 |     def load_query_data(
602 |         cls,
603 |         adata: AnnData,
604 |         reference_model: BaseModelClass,
605 |         use_gpu: str | int | bool | None = None,
606 |     ) -> BaseModelClass:
607 |         """Online update of a reference model with scArches algorithm # TODO cite.
608 |
609 |         Parameters
610 |         ----------
611 |         adata
612 |             AnnData organized in the same way as the data used to train the model.
613 |             It is not necessary to run ``setup_anndata``,
614 |             as the AnnData is validated against the ``registry``.
615 |         reference_model
616 |             Already instantiated model of the same class.
617 |         use_gpu
618 |             Load model on default GPU if available (if None or True),
619 |             or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).
620 |
621 |         Returns
622 |         -------
623 |         Model with updated architecture and weights.
624 |         """
625 |         # currently this function only works if the prediction covariate is present in the .obs of the query
626 |         # TODO need to allow it to be missing, maybe add a dummy column to .obs of query adata
627 |
628 |         _, _, device = parse_use_gpu_arg(use_gpu)
629 |
630 |         attr_dict, _, _ = _get_loaded_data(reference_model, device=device)
631 |
632 |         registry = attr_dict.pop("registry_")
633 |         if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__:
634 |             raise ValueError("It appears you are loading a model from a different class.")
635 |
636 |         if _SETUP_ARGS_KEY not in registry:
637 |             raise ValueError("Saved model does not contain original setup inputs. " "Cannot load the original setup.")
638 |
639 |         cls.setup_anndata(
640 |             adata,
641 |             source_registry=registry,
642 |             extend_categories=True,
643 |             allow_missing_labels=True,
644 |             **registry[_SETUP_ARGS_KEY],
645 |         )
646 |
647 |         model = _initialize_model(cls, adata, attr_dict)
648 |         model.module.load_state_dict(reference_model.module.state_dict())
649 |         model.to_device(device)
650 |
651 |         model.module.eval()
652 |         model.is_trained_ = True
653 |
654 |         return model
655 |
--------------------------------------------------------------------------------
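A minimal end-to-end usage sketch for `MILClassifier`, based on the `setup_anndata()` and `train()` signatures shown above. The constructor arguments and the `.obs` column names are illustrative assumptions only, since `__init__` is not part of this excerpt.

```python
import anndata as ad

from multimil.model import MILClassifier

adata = ad.read_h5ad("cohort.h5ad")  # hypothetical input file

# register covariates; categorical columns are cast to the "category" dtype by setup_anndata()
MILClassifier.setup_anndata(
    adata,
    categorical_covariate_keys=["sample_id", "disease_status"],  # assumed column names
    continuous_covariate_keys=["age"],
)

# NOTE: the constructor arguments below are assumptions for illustration;
# the MILClassifier __init__ signature does not appear in this file excerpt
model = MILClassifier(adata, classification=["disease_status"], sample_key="sample_id")

model.train(
    max_epochs=100,
    early_stopping=True,                 # validation is then checked every epoch
    save_checkpoint_every_n_epochs=10,   # requires path_to_checkpoints as well
    path_to_checkpoints="checkpoints/",
)
```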
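A short sketch of how the results written by `get_model_output()` could be inspected. The `cell_attn` and `bags` columns are written directly in the code above; the exact names of the per-covariate prediction columns are created by `save_predictions_in_adata()` and are not shown in this file.

```python
# write attention scores and per-covariate predictions into adata.obs
model.get_model_output(adata, batch_size=256)

if "cell_attn" in adata.obs:
    # e.g. inspect the most-attended cells and the bag they belong to
    print(adata.obs.sort_values("cell_attn", ascending=False)[["bags", "cell_attn"]].head(20))
```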
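A hedged sketch of mapping a query dataset onto a trained reference model with `load_query_data()`, following the classmethod signature above. The file name is hypothetical and, per the TODO in the code, the prediction covariate currently still has to be present in the query `.obs`.

```python
import anndata as ad

query = ad.read_h5ad("query.h5ad")  # hypothetical query file, set up like the reference data
query_model = MILClassifier.load_query_data(query, reference_model=model, use_gpu=True)

# the returned model is flagged as trained, so inference can be run directly
query_model.get_model_output(query)
```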
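The `GroupDataSplitter` and `GroupAnnDataLoader` used above group cells by `sample_key`, so train/validation splits are drawn at the sample (bag) level rather than per cell. A simplified illustration of that idea, not the actual implementation:

```python
import numpy as np
import pandas as pd

# cells (rows) annotated with the sample they come from
obs = pd.DataFrame({"sample_id": ["s1"] * 3 + ["s2"] * 2 + ["s3"] * 4})

# draw the train split at the sample level, then take all cells of those samples
rng = np.random.default_rng(0)
samples = obs["sample_id"].unique()
train_samples = rng.choice(samples, size=2, replace=False)
train_mask = obs["sample_id"].isin(train_samples)

print(obs[train_mask])   # every cell of the chosen samples, no partial bags
print(obs[~train_mask])  # cells of held-out samples form the validation set
```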