├── .codecov.yaml
├── .cruft.json
├── .editorconfig
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.yml
│   │   ├── config.yml
│   │   └── feature_request.yml
│   └── workflows
│       ├── build.yaml
│       ├── release.yaml
│       └── test.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── CHANGELOG.md
├── LICENSE
├── README.md
├── docs
│   ├── Makefile
│   ├── _static
│   │   ├── .gitkeep
│   │   └── css
│   │       └── custom.css
│   ├── _templates
│   │   ├── .gitkeep
│   │   └── autosummary
│   │       └── class.rst
│   ├── api.md
│   ├── changelog.md
│   ├── conf.py
│   ├── contributing.md
│   ├── extensions
│   │   └── typed_returns.py
│   ├── index.md
│   ├── notebooks
│   │   ├── cellttype_harmonize.json
│   │   ├── mil_classification.ipynb
│   │   ├── paired_integration_cite-seq.ipynb
│   │   └── trimodal_integration.ipynb
│   ├── references.bib
│   ├── references.md
│   └── tutorials.md
├── pyproject.toml
├── src
│   └── multimil
│       ├── __init__.py
│       ├── data
│       │   ├── __init__.py
│       │   └── _preprocessing.py
│       ├── dataloaders
│       │   ├── __init__.py
│       │   ├── _ann_dataloader.py
│       │   └── _data_splitting.py
│       ├── distributions
│       │   ├── __init__.py
│       │   ├── _jeffreys.py
│       │   └── _mmd.py
│       ├── model
│       │   ├── __init__.py
│       │   ├── _mil.py
│       │   ├── _multivae.py
│       │   └── _multivae_mil.py
│       ├── module
│       │   ├── __init__.py
│       │   ├── _mil_torch.py
│       │   ├── _multivae_mil_torch.py
│       │   └── _multivae_torch.py
│       ├── nn
│       │   ├── __init__.py
│       │   └── _base_components.py
│       └── utils
│           ├── __init__.py
│           └── _utils.py
└── tests
    └── test_basic.py
/.codecov.yaml:
--------------------------------------------------------------------------------
1 | # Based on pydata/xarray
2 | codecov:
3 | require_ci_to_pass: no
4 |
5 | coverage:
6 | status:
7 | project:
8 | default:
9 | # Require 1% coverage, i.e., always succeed
10 | target: 1
11 | patch: false
12 | changes: false
13 |
14 | comment:
15 | layout: diff, flags, files
16 | behavior: once
17 | require_base: no
18 |
--------------------------------------------------------------------------------
/.cruft.json:
--------------------------------------------------------------------------------
1 | {
2 | "template": "https://github.com/scverse/cookiecutter-scverse",
3 | "commit": "aa877587e59c855e2464d09ed7678af00999191a",
4 | "checkout": null,
5 | "context": {
6 | "cookiecutter": {
7 | "project_name": "multimil",
8 | "package_name": "multimil",
9 | "project_description": "Multimodal weakly supervised learning to identify disease-specific changes in single-cell atlases",
10 | "author_full_name": "Anastasia Litinetskaya",
11 | "author_email": "alitinet@gmail.com",
12 | "github_user": "alitinet",
13 | "project_repo": "https://github.com/theislab/multimil",
14 | "license": "BSD 3-Clause License",
15 | "_copy_without_render": [
16 | ".github/workflows/build.yaml",
17 | ".github/workflows/test.yaml",
18 | "docs/_templates/autosummary/**.rst"
19 | ],
20 | "_render_devdocs": false,
21 | "_jinja2_env_vars": {
22 | "lstrip_blocks": true,
23 | "trim_blocks": true
24 | },
25 | "_template": "https://github.com/scverse/cookiecutter-scverse"
26 | }
27 | },
28 | "directory": null
29 | }
30 |
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | root = true
2 |
3 | [*]
4 | indent_style = space
5 | indent_size = 4
6 | end_of_line = lf
7 | charset = utf-8
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 |
11 | [*.{yml,yaml}]
12 | indent_size = 2
13 |
14 | [.cruft.json]
15 | indent_size = 2
16 |
17 | [Makefile]
18 | indent_style = tab
19 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yml:
--------------------------------------------------------------------------------
1 | name: Bug report
2 | description: Report something that is broken or incorrect
3 | labels: bug
4 | body:
5 | - type: markdown
6 | attributes:
7 | value: |
8 | **Note**: Please read [this guide](https://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports)
9 | detailing how to provide the necessary information for us to reproduce your bug. In brief:
10 | * Please provide exact steps how to reproduce the bug in a clean Python environment.
11 | * In case it's not clear what's causing this bug, please provide the data or the data generation procedure.
12 | * Sometimes it is not possible to share the data, but usually it is possible to replicate problems on publicly
13 | available datasets or to share a subset of your data.
14 |
15 | - type: textarea
16 | id: report
17 | attributes:
18 | label: Report
19 | description: A clear and concise description of what the bug is.
20 | validations:
21 | required: true
22 |
23 | - type: textarea
24 | id: versions
25 | attributes:
26 | label: Version information
27 | description: |
28 | Please paste below the output of
29 |
30 | ```python
31 | import session_info
32 | session_info.show(html=False, dependencies=True)
33 | ```
34 | placeholder: |
35 | -----
36 | anndata 0.8.0rc2.dev27+ge524389
37 | session_info 1.0.0
38 | -----
39 | asttokens NA
40 | awkward 1.8.0
41 | backcall 0.2.0
42 | cython_runtime NA
43 | dateutil 2.8.2
44 | debugpy 1.6.0
45 | decorator 5.1.1
46 | entrypoints 0.4
47 | executing 0.8.3
48 | h5py 3.7.0
49 | ipykernel 6.15.0
50 | jedi 0.18.1
51 | mpl_toolkits NA
52 | natsort 8.1.0
53 | numpy 1.22.4
54 | packaging 21.3
55 | pandas 1.4.2
56 | parso 0.8.3
57 | pexpect 4.8.0
58 | pickleshare 0.7.5
59 | pkg_resources NA
60 | prompt_toolkit 3.0.29
61 | psutil 5.9.1
62 | ptyprocess 0.7.0
63 | pure_eval 0.2.2
64 | pydev_ipython NA
65 | pydevconsole NA
66 | pydevd 2.8.0
67 | pydevd_file_utils NA
68 | pydevd_plugins NA
69 | pydevd_tracing NA
70 | pygments 2.12.0
71 | pytz 2022.1
72 | scipy 1.8.1
73 | setuptools 62.5.0
74 | setuptools_scm NA
75 | six 1.16.0
76 | stack_data 0.3.0
77 | tornado 6.1
78 | traitlets 5.3.0
79 | wcwidth 0.2.5
80 | zmq 23.1.0
81 | -----
82 | IPython 8.4.0
83 | jupyter_client 7.3.4
84 | jupyter_core 4.10.0
85 | -----
86 | Python 3.9.13 | packaged by conda-forge | (main, May 27 2022, 16:58:50) [GCC 10.3.0]
87 | Linux-5.18.6-arch1-1-x86_64-with-glibc2.35
88 | -----
89 | Session information updated at 2022-07-07 17:55
90 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 | - name: Scverse Community Forum
4 | url: https://discourse.scverse.org/
5 | about: If you have questions about “How to do X”, please ask them here.
6 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.yml:
--------------------------------------------------------------------------------
1 | name: Feature request
2 | description: Propose a new feature for multimil
3 | labels: enhancement
4 | body:
5 | - type: textarea
6 | id: description
7 | attributes:
8 | label: Description of feature
9 | description: Please describe your suggestion for a new feature. It might help to describe a problem or use case, plus any alternatives that you have considered.
10 | validations:
11 | required: true
12 |
--------------------------------------------------------------------------------
/.github/workflows/build.yaml:
--------------------------------------------------------------------------------
1 | name: Check Build
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: [main]
8 |
9 | concurrency:
10 | group: ${{ github.workflow }}-${{ github.ref }}
11 | cancel-in-progress: true
12 |
13 | jobs:
14 | package:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: actions/checkout@v3
18 | - name: Set up Python 3.10
19 | uses: actions/setup-python@v4
20 | with:
21 | python-version: "3.10"
22 | cache: "pip"
23 | cache-dependency-path: "**/pyproject.toml"
24 | - name: Install build dependencies
25 | run: python -m pip install --upgrade pip wheel twine build
26 | - name: Build package
27 | run: python -m build
28 | - name: Check package
29 | run: twine check --strict dist/*.whl
30 |
--------------------------------------------------------------------------------
/.github/workflows/release.yaml:
--------------------------------------------------------------------------------
1 | name: Release
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | # Use "trusted publishing", see https://docs.pypi.org/trusted-publishers/
8 | jobs:
9 | release:
10 | name: Upload release to PyPI
11 | runs-on: ubuntu-latest
12 | environment:
13 | name: pypi
14 | url: https://pypi.org/p/multimil
15 | permissions:
16 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
17 | steps:
18 | - uses: actions/checkout@v4
19 | with:
20 | filter: blob:none
21 | fetch-depth: 0
22 | - uses: actions/setup-python@v4
23 | with:
24 | python-version: "3.x"
25 | cache: "pip"
26 | - run: pip install build
27 | - run: python -m build
28 | - name: Publish package distributions to PyPI
29 | uses: pypa/gh-action-pypi-publish@release/v1
30 |
--------------------------------------------------------------------------------
/.github/workflows/test.yaml:
--------------------------------------------------------------------------------
1 | name: Test
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: [main]
8 | schedule:
9 | - cron: "0 5 1,15 * *"
10 |
11 | concurrency:
12 | group: ${{ github.workflow }}-${{ github.ref }}
13 | cancel-in-progress: true
14 |
15 | jobs:
16 | test:
17 | runs-on: ${{ matrix.os }}
18 | defaults:
19 | run:
20 | shell: bash -e {0} # -e to fail on error
21 |
22 | strategy:
23 | fail-fast: false
24 | matrix:
25 | include:
26 | - os: ubuntu-latest
27 | python: "3.10"
28 | - os: ubuntu-latest
29 | python: "3.12"
30 | - os: ubuntu-latest
31 | python: "3.12"
32 | pip-flags: "--pre"
33 | name: PRE-RELEASE DEPENDENCIES
34 |
35 | name: ${{ matrix.name }} Python ${{ matrix.python }}
36 |
37 | env:
38 | OS: ${{ matrix.os }}
39 | PYTHON: ${{ matrix.python }}
40 |
41 | steps:
42 | - uses: actions/checkout@v3
43 | - name: Set up Python ${{ matrix.python }}
44 | uses: actions/setup-python@v4
45 | with:
46 | python-version: ${{ matrix.python }}
47 | cache: "pip"
48 | cache-dependency-path: "**/pyproject.toml"
49 |
50 | - name: Install test dependencies
51 | run: |
52 | python -m pip install --upgrade pip wheel
53 | - name: Install dependencies
54 | run: |
55 | pip install ${{ matrix.pip-flags }} ".[dev,test]"
56 | - name: Test
57 | env:
58 | MPLBACKEND: agg
59 | PLATFORM: ${{ matrix.os }}
60 | DISPLAY: :42
61 | run: |
62 | coverage run -m pytest -v --color=yes
63 | - name: Report coverage
64 | run: |
65 | coverage report
66 | - name: Upload coverage
67 | uses: codecov/codecov-action@v3
68 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Notebooks
2 | .ipynb_checkpoints
3 |
4 | # Logs
5 | scvi_log/
6 |
7 | # Temp files
8 | .DS_Store
9 | *~
10 | buck-out/
11 |
12 | # Compiled files
13 | .venv/
14 | __pycache__/
15 | .mypy_cache/
16 | .ruff_cache/
17 |
18 | # Distribution / packaging
19 | /build/
20 | /dist/
21 | /*.egg-info/
22 |
23 | # Tests and coverage
24 | /.pytest_cache/
25 | /.cache/
26 | /data/
27 | /node_modules/
28 |
29 | # docs
30 | /docs/generated/
31 | /docs/_build/
32 |
33 | # IDEs
34 | /.idea/
35 | /.vscode/
36 |
37 | # Data
38 | *.h5ad
39 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | fail_fast: false
2 | default_language_version:
3 | python: python3
4 | default_stages:
5 | - commit
6 | - push
7 | minimum_pre_commit_version: 2.16.0
8 | repos:
9 | - repo: https://github.com/pre-commit/mirrors-prettier
10 | rev: v4.0.0-alpha.8
11 | hooks:
12 | - id: prettier
13 | - repo: https://github.com/astral-sh/ruff-pre-commit
14 | rev: v0.5.3
15 | hooks:
16 | - id: ruff
17 | types_or: [python, pyi, jupyter]
18 | args: [--fix, --exit-non-zero-on-fix]
19 | - id: ruff-format
20 | types_or: [python, pyi, jupyter]
21 | - repo: https://github.com/pre-commit/pre-commit-hooks
22 | rev: v4.6.0
23 | hooks:
24 | - id: detect-private-key
25 | - id: check-ast
26 | - id: end-of-file-fixer
27 | - id: mixed-line-ending
28 | args: [--fix=lf]
29 | - id: trailing-whitespace
30 | - id: check-case-conflict
31 | # Check that there are no merge conflicts (could be generated by template sync)
32 | - id: check-merge-conflict
33 | args: [--assume-in-merge]
34 | - repo: local
35 | hooks:
36 | - id: forbid-to-commit
37 | name: Don't commit rej files
38 | entry: |
39 | Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates.
40 | Fix the merge conflicts manually and remove the .rej files.
41 | language: fail
42 | files: '.*\.rej$'
43 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # https://docs.readthedocs.io/en/stable/config-file/v2.html
2 | version: 2
3 | build:
4 | os: ubuntu-20.04
5 | tools:
6 | python: "3.10"
7 | sphinx:
8 | configuration: docs/conf.py
9 | # disable this for more lenient docs builds
10 | fail_on_warning: false
11 | python:
12 | install:
13 | - method: pip
14 | path: .
15 | extra_requirements:
16 | - doc
17 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | All notable changes to this project will be documented in this file.
4 |
5 | The format is based on [Keep a Changelog][],
6 | and this project adheres to [Semantic Versioning][].
7 |
8 | [keep a changelog]: https://keepachangelog.com/en/1.0.0/
9 | [semantic versioning]: https://semver.org/spec/v2.0.0.html
10 |
11 | ## [Unreleased]
12 |
13 | ### Added
14 |
15 | - Basic tool, preprocessing and plotting functions
16 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2024, Anastasia Litinetskaya
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multimodal weakly supervised learning to identify disease-specific changes in single-cell atlases
2 |
3 | [![Tests][badge-tests]][link-tests]
4 | [![Documentation][badge-docs]][link-docs]
5 |
6 | [badge-tests]: https://img.shields.io/github/actions/workflow/status/theislab/multimil/test.yaml?branch=main
7 | [link-tests]: https://github.com/theislab/multimil/actions/workflows/test.yaml
8 | [badge-docs]: https://img.shields.io/readthedocs/multimil
9 | [badge-colab]: https://colab.research.google.com/assets/colab-badge.svg
10 |
11 | ## Getting started
12 |
13 | Please refer to the [documentation][link-docs]. In particular, the
14 |
15 | - [API documentation][link-api]
16 |
17 | and the tutorials:
18 |
19 |
22 |
23 | - [Classification with MIL](https://multimil.readthedocs.io/en/latest/notebooks/mil_classification.html) [![Open In Colab][badge-colab]](https://colab.research.google.com/github/theislab/multimil/blob/main/docs/notebooks/mil_classification.ipynb)
24 |
25 | ## Installation
26 |
27 | You need to have Python 3.10 or newer installed on your system. We recommend installing [Mambaforge](https://github.com/conda-forge/miniforge#mambaforge).
28 |
29 | To create and activate a new environment:
30 |
31 | ```bash
32 | mamba create --name multimil python=3.10
33 | mamba activate multimil
34 | ```
35 |
36 | Next, there are several alternative options to install multimil:
37 |
38 | 1. Install the latest release of `multimil` from [PyPI][link-pypi]:
39 |
40 | ```bash
41 | pip install multimil
42 | ```
43 |
44 | 2. Or install the latest development version:
45 |
46 | ```bash
47 | pip install git+https://github.com/theislab/multimil.git@main
48 | ```
49 |
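To verify that the installation worked, you can import the package and print its version (a minimal sanity check; the `mtm` alias is just a convention):

```python
import multimil as mtm

# the installed version is exposed on the top-level package
print(mtm.__version__)
```
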
50 | ## Release notes
51 |
52 | See the [changelog][changelog].
53 |
54 | ## Contact
55 |
56 | If you found a bug, please use the [issue tracker][issue-tracker].
57 |
58 | ## Citation
59 |
60 | > **Multimodal Weakly Supervised Learning to Identify Disease-Specific Changes in Single-Cell Atlases**
61 | >
62 | > Anastasia Litinetskaya, Maiia Shulman, Soroor Hediyeh-zadeh, Amir Ali Moinfar, Fabiola Curion, Artur Szalata, Alireza Omidi, Mohammad Lotfollahi, and Fabian J. Theis. 2024. bioRxiv. https://doi.org/10.1101/2024.07.29.605625.
63 |
64 | ## Reproducibility
65 |
66 | Code and notebooks to reproduce the results from the paper are available at https://github.com/theislab/multimil_reproducibility.
67 |
68 | [issue-tracker]: https://github.com/theislab/multimil/issues
69 | [changelog]: https://multimil.readthedocs.io/en/latest/changelog.html
70 | [link-docs]: https://multimil.readthedocs.io
71 | [link-api]: https://multimil.readthedocs.io/en/latest/api.html
72 | [link-pypi]: https://pypi.org/project/multimil
73 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/_static/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theislab/multimil/edfb3971cc2bc0302021eb3fa6175b11428bd903/docs/_static/.gitkeep
--------------------------------------------------------------------------------
/docs/_static/css/custom.css:
--------------------------------------------------------------------------------
1 | /* Reduce the font size in data frames - See https://github.com/scverse/cookiecutter-scverse/issues/193 */
2 | div.cell_output table.dataframe {
3 | font-size: 0.8em;
4 | }
5 |
--------------------------------------------------------------------------------
/docs/_templates/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theislab/multimil/edfb3971cc2bc0302021eb3fa6175b11428bd903/docs/_templates/.gitkeep
--------------------------------------------------------------------------------
/docs/_templates/autosummary/class.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. add toctree option to make autodoc generate the pages
6 |
7 | .. autoclass:: {{ objname }}
8 |
9 | {% block attributes %}
10 | {% if attributes %}
11 | Attributes table
12 | ~~~~~~~~~~~~~~~~~~
13 |
14 | .. autosummary::
15 | {% for item in attributes %}
16 | ~{{ fullname }}.{{ item }}
17 | {%- endfor %}
18 | {% endif %}
19 | {% endblock %}
20 |
21 | {% block methods %}
22 | {% if methods %}
23 | Methods table
24 | ~~~~~~~~~~~~~
25 |
26 | .. autosummary::
27 | {% for item in methods %}
28 | {%- if item != '__init__' %}
29 | ~{{ fullname }}.{{ item }}
30 | {%- endif -%}
31 | {%- endfor %}
32 | {% endif %}
33 | {% endblock %}
34 |
35 | {% block attributes_documentation %}
36 | {% if attributes %}
37 | Attributes
38 | ~~~~~~~~~~~
39 |
40 | {% for item in attributes %}
41 |
42 | .. autoattribute:: {{ [objname, item] | join(".") }}
43 | {%- endfor %}
44 |
45 | {% endif %}
46 | {% endblock %}
47 |
48 | {% block methods_documentation %}
49 | {% if methods %}
50 | Methods
51 | ~~~~~~~
52 |
53 | {% for item in methods %}
54 | {%- if item != '__init__' %}
55 |
56 | .. automethod:: {{ [objname, item] | join(".") }}
57 | {%- endif -%}
58 | {%- endfor %}
59 |
60 | {% endif %}
61 | {% endblock %}
62 |
--------------------------------------------------------------------------------
/docs/api.md:
--------------------------------------------------------------------------------
1 | # API
2 |
3 | ## Model
4 |
5 | ```{eval-rst}
6 | .. module:: multimil.model
7 | .. currentmodule:: multimil
8 |
9 | .. autosummary::
10 | :toctree: generated
11 |
12 | model.MultiVAE
13 | model.MILClassifier
14 | model.MultiVAE_MIL
15 | ```
16 |
17 | ## Module
18 |
19 | ```{eval-rst}
20 | .. module:: multimil.module
21 | .. currentmodule:: multimil
22 |
23 | .. autosummary::
24 | :toctree: generated
25 |
26 | module.MultiVAETorch
27 | module.MILClassifierTorch
28 | module.MultiVAETorch_MIL
29 | ```
30 |
--------------------------------------------------------------------------------
/docs/changelog.md:
--------------------------------------------------------------------------------
1 | ```{include} ../CHANGELOG.md
2 |
3 | ```
4 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 |
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 | import sys
9 | from datetime import datetime
10 | from importlib.metadata import metadata
11 | from pathlib import Path
12 |
13 | HERE = Path(__file__).parent
14 | sys.path.insert(0, str(HERE / "extensions"))
15 |
16 |
17 | # -- Project information -----------------------------------------------------
18 |
19 | # NOTE: If you installed your project in editable mode, this might be stale.
20 | # If this is the case, reinstall it to refresh the metadata
21 | info = metadata("multimil")
22 | project_name = info["Name"]
23 | author = info["Author"]
24 | copyright = f"{datetime.now():%Y}, {author}."
25 | version = info["Version"]
26 | urls = dict(pu.split(", ") for pu in info.get_all("Project-URL"))
27 | repository_url = urls["Source"]
28 |
29 | # The full version, including alpha/beta/rc tags
30 | release = info["Version"]
31 |
32 | bibtex_bibfiles = ["references.bib"]
33 | templates_path = ["_templates"]
34 | nitpicky = True # Warn about broken links
35 | needs_sphinx = "4.0"
36 |
37 | html_context = {
38 | "display_github": True, # Integrate GitHub
39 | "github_user": "alitinet",
40 | "github_repo": "https://github.com/theislab/multimil",
41 | "github_version": "main",
42 | "conf_py_path": "/docs/",
43 | }
44 |
45 | # -- General configuration ---------------------------------------------------
46 |
47 | # Add any Sphinx extension module names here, as strings.
48 | # They can be extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
49 | extensions = [
50 | "myst_nb",
51 | "sphinx_copybutton",
52 | "sphinx.ext.autodoc",
53 | "sphinx.ext.intersphinx",
54 | "sphinx.ext.autosummary",
55 | "sphinx.ext.napoleon",
56 | "sphinxcontrib.bibtex",
57 | "sphinx_autodoc_typehints",
58 | "sphinx.ext.mathjax",
59 | "IPython.sphinxext.ipython_console_highlighting",
60 | "sphinxext.opengraph",
61 | *[p.stem for p in (HERE / "extensions").glob("*.py")],
62 | ]
63 |
64 | autosummary_generate = True
65 | autodoc_member_order = "groupwise"
66 | default_role = "literal"
67 | napoleon_google_docstring = False
68 | napoleon_numpy_docstring = True
69 | napoleon_include_init_with_doc = False
70 | napoleon_use_rtype = True # having a separate entry generally helps readability
71 | napoleon_use_param = True
72 | myst_heading_anchors = 6 # create anchors for h1-h6
73 | myst_enable_extensions = [
74 | "amsmath",
75 | "colon_fence",
76 | "deflist",
77 | "dollarmath",
78 | "html_image",
79 | "html_admonition",
80 | ]
81 | myst_url_schemes = ("http", "https", "mailto")
82 | nb_output_stderr = "remove"
83 | nb_execution_mode = "off"
84 | nb_merge_streams = True
85 | typehints_defaults = "braces"
86 |
87 | source_suffix = {
88 | ".rst": "restructuredtext",
89 | ".ipynb": "myst-nb",
90 | ".myst": "myst-nb",
91 | }
92 |
93 | intersphinx_mapping = {
94 | "python": ("https://docs.python.org/3", None),
95 | "anndata": ("https://anndata.readthedocs.io/en/stable/", None),
96 | "numpy": ("https://numpy.org/doc/stable/", None),
97 | }
98 |
99 | # List of patterns, relative to source directory, that match files and
100 | # directories to ignore when looking for source files.
101 | # This pattern also affects html_static_path and html_extra_path.
102 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"]
103 |
104 |
105 | # -- Options for HTML output -------------------------------------------------
106 |
107 | # The theme to use for HTML and HTML Help pages. See the documentation for
108 | # a list of builtin themes.
109 | #
110 | html_theme = "sphinx_book_theme"
111 | html_static_path = ["_static"]
112 | html_css_files = ["css/custom.css"]
113 |
114 | html_title = project_name
115 |
116 | html_theme_options = {
117 | "repository_url": repository_url,
118 | "use_repository_button": True,
119 | "path_to_docs": "docs/",
120 | "navigation_with_keys": False,
121 | }
122 |
123 | pygments_style = "default"
124 |
125 | nitpick_ignore = [
126 | # If building the documentation fails because of a missing link that is outside your control,
127 | # you can add an exception to this list.
128 | # ("py:class", "igraph.Graph"),
129 | ]
130 |
--------------------------------------------------------------------------------
/docs/contributing.md:
--------------------------------------------------------------------------------
1 | # Contributing guide
2 |
3 | Scanpy provides extensive [developer documentation][scanpy developer guide], most of which applies to this project, too.
4 | This document will not reproduce the entire content from there. Instead, it aims at summarizing the most important
5 | information to get you started on contributing.
6 |
7 | We assume that you are already familiar with git and with making pull requests on GitHub. If not, please refer
8 | to the [scanpy developer guide][].
9 |
10 | ## Installing dev dependencies
11 |
12 | In addition to the packages needed to _use_ this package, you need additional python packages to _run tests_ and _build
13 | the documentation_. It's easy to install them using `pip`:
14 |
15 | ```bash
16 | cd multimil
17 | pip install -e ".[dev,test,doc]"
18 | ```
19 |
20 | ## Code-style
21 |
22 | This package uses [pre-commit][] to enforce consistent code-styles.
23 | On every commit, pre-commit checks will either automatically fix issues with the code, or raise an error message.
24 |
25 | To enable pre-commit locally, simply run
26 |
27 | ```bash
28 | pre-commit install
29 | ```
30 |
31 | in the root of the repository. Pre-commit will automatically download all dependencies when it is run for the first time.
32 |
33 | Alternatively, you can rely on the [pre-commit.ci][] service enabled on GitHub. If you didn't run `pre-commit` before
34 | pushing changes to GitHub, it will automatically commit fixes to your pull request or show an error message.
35 |
36 | If pre-commit.ci added a commit on a branch you still have been working on locally, simply use
37 |
38 | ```bash
39 | git pull --rebase
40 | ```
41 |
42 | to integrate the changes into yours.
43 | While [pre-commit.ci][] is useful, we strongly encourage installing and running pre-commit locally first to understand its usage.
44 |
45 | Finally, most editors have an _autoformat on save_ feature. Consider enabling this option for [ruff][ruff-editors]
46 | and [prettier][prettier-editors].
47 |
48 | [ruff-editors]: https://docs.astral.sh/ruff/integrations/
49 | [prettier-editors]: https://prettier.io/docs/en/editors.html
50 |
51 | ## Writing tests
52 |
53 | ```{note}
54 | Remember to first install the package with `pip install -e '.[dev,test]'`
55 | ```
56 |
57 | This package uses [pytest][] for automated testing. Please [write tests][scanpy-test-docs] for every function added
58 | to the package.
59 |
60 | Most IDEs integrate with pytest and provide a GUI to run tests. Alternatively, you can run all tests from the
61 | command line by executing
62 |
63 | ```bash
64 | pytest
65 | ```
66 |
67 | in the root of the repository.
68 |
69 | ### Continuous integration
70 |
71 | Continuous integration will automatically run the tests on all pull requests and test
72 | against the minimum and maximum supported Python version.
73 |
74 | Additionally, there's a CI job that tests against pre-releases of all dependencies
75 | (if there are any). The purpose of this check is to detect incompatibilities
76 | of new package versions early on and to give you time to fix the issue or reach
77 | out to the developers of the dependency before the package is released to a wider audience.
78 |
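To reproduce this check locally, a rough equivalent (assuming you are in the repository root) is to reinstall the package with pre-release versions of its dependencies and rerun the test suite:

```bash
# install the package plus the dev/test extras, allowing pre-release dependency versions
pip install --pre -e ".[dev,test]"
pytest
```
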
79 | [scanpy-test-docs]: https://scanpy.readthedocs.io/en/latest/dev/testing.html#writing-tests
80 |
81 | ## Publishing a release
82 |
83 | ### Updating the version number
84 |
85 | Before making a release, you need to update the version number in the `pyproject.toml` file. Please adhere to [Semantic Versioning][semver], in brief
86 |
87 | > Given a version number MAJOR.MINOR.PATCH, increment the:
88 | >
89 | > 1. MAJOR version when you make incompatible API changes,
90 | > 2. MINOR version when you add functionality in a backwards compatible manner, and
91 | > 3. PATCH version when you make backwards compatible bug fixes.
92 | >
93 | > Additional labels for pre-release and build metadata are available as extensions to the MAJOR.MINOR.PATCH format.
94 |
95 | Once you are done, commit and push your changes and navigate to the "Releases" page of this project on GitHub.
96 | Specify `vX.X.X` as a tag name and create a release. For more information, see [managing GitHub releases][]. This will automatically create a git tag and trigger a GitHub workflow that creates a release on PyPI.
97 |
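For example, a backwards-compatible bug fix on top of the current `0.2.0` release would only bump the PATCH component (the concrete numbers below are hypothetical):

```toml
# pyproject.toml
[project]
name = "multimil"
version = "0.2.1" # was 0.2.0
```

The corresponding GitHub release would then use the tag `v0.2.1`.
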
98 | ## Writing documentation
99 |
100 | Please write documentation for new or changed features and use-cases. This project uses [sphinx][] with the following features:
101 |
102 | - the [myst][] extension allows writing documentation in markdown/Markedly Structured Text
103 | - [Numpy-style docstrings][numpydoc] (through the [napoleon][numpydoc-napoleon] extension).
104 | - Jupyter notebooks as tutorials through [myst-nb][] (See [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks))
105 | - [Sphinx autodoc typehints][], to automatically reference annotated input and output types
106 | - Citations (like {cite:p}`Virshup_2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/)
107 |
108 | See the [scanpy developer docs](https://scanpy.readthedocs.io/en/latest/dev/documentation.html) for more information
109 | on how to write documentation.
110 |
111 | ### Tutorials with myst-nb and jupyter notebooks
112 |
113 | The documentation is set up to render Jupyter notebooks stored in the `docs/notebooks` directory using [myst-nb][].
114 | Currently, only notebooks in `.ipynb` format are supported; they are included with both their input and output cells.
115 | It is your responsibility to update and re-run the notebook whenever necessary.
116 |
117 | If you are interested in automatically running notebooks as part of the continuous integration, please check
118 | out [this feature request](https://github.com/scverse/cookiecutter-scverse/issues/40) in the `cookiecutter-scverse`
119 | repository.
120 |
121 | #### Hints
122 |
123 | - If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py` (see the
124 | sketch below). Only if you do so can sphinx automatically create a link to the external documentation.
125 | - If building the documentation fails because of a missing link that is outside your control, you can add an entry to
126 | the `nitpick_ignore` list in `docs/conf.py`
127 |
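A minimal sketch of both hints, assuming you want to link to the pandas documentation and silence an unresolved `igraph.Graph` reference (both targets are just examples):

```python
# docs/conf.py
intersphinx_mapping = {
    # ... existing entries ...
    "pandas": ("https://pandas.pydata.org/docs/", None),
}

nitpick_ignore = [
    # reference type and target that sphinx should not warn about
    ("py:class", "igraph.Graph"),
]
```
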
128 | #### Building the docs locally
129 |
130 | ```bash
131 | cd docs
132 | make html
133 | open _build/html/index.html
134 | ```
135 |
136 |
137 |
138 | [scanpy developer guide]: https://scanpy.readthedocs.io/en/latest/dev/index.html
139 | [cookiecutter-scverse-instance]: https://cookiecutter-scverse-instance.readthedocs.io/en/latest/template_usage.html
140 | [github quickstart guide]: https://docs.github.com/en/get-started/quickstart/create-a-repo?tool=webui
141 | [codecov]: https://about.codecov.io/sign-up/
142 | [codecov docs]: https://docs.codecov.com/docs
143 | [codecov bot]: https://docs.codecov.com/docs/team-bot
144 | [codecov app]: https://github.com/apps/codecov
145 | [pre-commit.ci]: https://pre-commit.ci/
146 | [readthedocs.org]: https://readthedocs.org/
147 | [myst-nb]: https://myst-nb.readthedocs.io/en/latest/
148 | [jupytext]: https://jupytext.readthedocs.io/en/latest/
149 | [pre-commit]: https://pre-commit.com/
150 | [anndata]: https://github.com/scverse/anndata
151 | [mudata]: https://github.com/scverse/mudata
152 | [pytest]: https://docs.pytest.org/
153 | [semver]: https://semver.org/
154 | [sphinx]: https://www.sphinx-doc.org/en/master/
155 | [myst]: https://myst-parser.readthedocs.io/en/latest/intro.html
156 | [numpydoc-napoleon]: https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html
157 | [numpydoc]: https://numpydoc.readthedocs.io/en/latest/format.html
158 | [sphinx autodoc typehints]: https://github.com/tox-dev/sphinx-autodoc-typehints
159 | [pypi]: https://pypi.org/
160 | [managing GitHub releases]: https://docs.github.com/en/repositories/releasing-projects-on-github/managing-releases-in-a-repository
161 |
--------------------------------------------------------------------------------
/docs/extensions/typed_returns.py:
--------------------------------------------------------------------------------
1 | # code from https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py
2 | # with some minor adjustment
3 | from __future__ import annotations
4 |
5 | import re
6 | from collections.abc import Generator, Iterable
7 |
8 | from sphinx.application import Sphinx
9 | from sphinx.ext.napoleon import NumpyDocstring
10 |
11 |
12 | def _process_return(lines: Iterable[str]) -> Generator[str, None, None]:
13 | for line in lines:
14 | if m := re.fullmatch(r"(?P<param>\w+)\s+:\s+(?P<type>[\w.]+)", line):
15 | yield f'-{m["param"]} (:class:`~{m["type"]}`)'
16 | else:
17 | yield line
18 |
19 |
20 | def _parse_returns_section(self: NumpyDocstring, section: str) -> list[str]:
21 | lines_raw = self._dedent(self._consume_to_next_section())
22 | if lines_raw[0] == ":":
23 | del lines_raw[0]
24 | lines = self._format_block(":returns: ", list(_process_return(lines_raw)))
25 | if lines and lines[-1]:
26 | lines.append("")
27 | return lines
28 |
29 |
30 | def setup(app: Sphinx):
31 | """Set app."""
32 | NumpyDocstring._parse_returns_section = _parse_returns_section
33 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | ```{include} ../README.md
2 |
3 | ```
4 |
5 | ```{toctree}
6 | :hidden: true
7 | :maxdepth: 2
8 |
9 | api.md
10 | tutorials.md
11 | changelog.md
12 | contributing.md
13 | references.md
14 | ```
15 |
--------------------------------------------------------------------------------
/docs/notebooks/cellttype_harmonize.json:
--------------------------------------------------------------------------------
1 | {
2 | "cite_ct_l1_map": {
3 | "Naive CD20+ B IGKC+": "B",
4 | "Naive CD20+ B IGKC-": "B",
5 | "B1 B IGKC-": "B",
6 | "B1 B IGKC+": "B",
7 | "Transitional B": "B",
8 | "G/M prog": "Myeloid",
9 | "CD14+ Mono": "Myeloid",
10 | "pDC": "Myeloid",
11 | "cDC2": "Myeloid",
12 | "CD16+ Mono": "Myeloid",
13 | "cDC1": "Myeloid",
14 | "HSC": "HSC",
15 | "Reticulocyte": "Erythroid",
16 | "Normoblast": "Erythroid",
17 | "Erythroblast": "Erythroid",
18 | "Proerythroblast": "Erythroid",
19 | "MK/E prog": "Erythroid",
20 | "NK CD158e1+": "NK",
21 | "NK": "NK",
22 | "CD4+ T naive": "T",
23 | "CD4+ T activated": "T",
24 | "CD4+ T activated integrinB7+": "T",
25 | "CD4+ T CD314+ CD45RA+": "T",
26 | "CD8+ T naive": "T",
27 | "CD8+ T CD49f+": "T",
28 | "CD8+ T TIGIT+ CD45RO+": "T",
29 | "CD8+ T CD57+ CD45RA+": "T",
30 | "CD8+ T CD69+ CD45RO+": "T",
31 | "CD8+ T TIGIT+ CD45RA+": "T",
32 | "CD8+ T CD69+ CD45RA+": "T",
33 | "CD8+ T naive CD127+ CD26- CD101-": "T",
34 | "CD8+ T CD57+ CD45RO+": "T",
35 | "MAIT": "T",
36 | "T reg": "T",
37 | "gdT TCRVD2+": "T",
38 | "gdT CD158b+": "T",
39 | "dnT": "T",
40 | "T prog cycling": "T",
41 | "ILC1": "T",
42 | "Lymph prog": "Other Lympoid",
43 | "ILC": "Other Lympoid",
44 | "Plasmablast IGKC+": "Plasma",
45 | "Plasmablast IGKC-": "Plasma",
46 | "Plasma cell IGKC+": "Plasma",
47 | "Plasma cell IGKC-": "Plasma"
48 | },
49 | "cite_ct_l2_map": {
50 | "Naive CD20+ B IGKC+": "Naive CD20+ B",
51 | "Naive CD20+ B IGKC-": "Naive CD20+ B",
52 | "B1 B IGKC-": "B1 B",
53 | "B1 B IGKC+": "B1 B",
54 | "Transitional B": "Transitional B",
55 | "G/M prog": "G/M prog",
56 | "CD14+ Mono": "CD14+ Mono",
57 | "pDC": "pDC",
58 | "cDC2": "cDC2",
59 | "CD16+ Mono": "CD16+ Mono",
60 | "cDC1": "cDC1",
61 | "HSC": "HSC",
62 | "Reticulocyte": "Reticulocyte",
63 | "Normoblast": "Normoblast",
64 | "Erythroblast": "Erythroblast",
65 | "Proerythroblast": "Proerythroblast",
66 | "MK/E prog": "MK/E prog",
67 | "NK CD158e1+": "NK",
68 | "NK": "NK",
69 | "CD4+ T naive": "CD4+ T naive",
70 | "CD4+ T activated": "CD4+ T activated",
71 | "CD4+ T activated integrinB7+": "CD4+ T activated",
72 | "CD4+ T CD314+ CD45RA+": "CD4+ T activated",
73 | "CD8+ T naive": "CD8+ T naive",
74 | "CD8+ T naive CD127+ CD26- CD101-": "CD8+ T naive",
75 | "CD8+ T CD49f+": "CD8+ T activated",
76 | "CD8+ T TIGIT+ CD45RO+": "CD8+ T activated",
77 | "CD8+ T CD57+ CD45RA+": "CD8+ T activated",
78 | "CD8+ T CD69+ CD45RO+": "CD8+ T activated",
79 | "CD8+ T TIGIT+ CD45RA+": "CD8+ T activated",
80 | "CD8+ T CD69+ CD45RA+": "CD8+ T activated",
81 | "CD8+ T CD57+ CD45RO+": "CD8+ T activated",
82 | "MAIT": "Other T",
83 | "T reg": "Other T",
84 | "gdT TCRVD2+": "Other T",
85 | "gdT CD158b+": "Other T",
86 | "dnT": "Other T",
87 | "T prog cycling": "Other T",
88 | "ILC1": "Other T",
89 | "Lymph prog": "Early Lymphoid",
90 | "ILC": "ILC",
91 | "Plasmablast IGKC+": "Plasma",
92 | "Plasmablast IGKC-": "Plasma",
93 | "Plasma cell IGKC+": "Plasma",
94 | "Plasma cell IGKC-": "Plasma"
95 | },
96 | "multi_ct_l1_map": {
97 | "G/M prog": "Myeloid",
98 | "CD14+ Mono": "Myeloid",
99 | "CD16+ Mono": "Myeloid",
100 | "pDC": "Myeloid",
101 | "cDC2": "Myeloid",
102 | "ID2-hi myeloid prog": "Myeloid",
103 | "CD8+ T": "T",
104 | "CD8+ T naive": "T",
105 | "CD4+ T naive": "T",
106 | "CD4+ T activated": "T",
107 | "NK": "NK",
108 | "ILC": "Other Lympoid",
109 | "Lymph prog": "Other Lympoid",
110 | "Transitional B": "B",
111 | "Naive CD20+ B": "B",
112 | "B1 B": "B",
113 | "Plasma cell": "Plasma",
114 | "MK/E prog": "Erythroid",
115 | "Proerythroblast": "Erythroid",
116 | "Erythroblast": "Erythroid",
117 | "Normoblast": "Erythroid",
118 | "HSC": "HSC"
119 | },
120 | "multi_ct_l2_map": {
121 | "G/M prog": "G/M prog",
122 | "CD14+ Mono": "CD14+ Mono",
123 | "CD16+ Mono": "CD16+ Mono",
124 | "pDC": "pDC",
125 | "cDC2": "cDC2",
126 | "ID2-hi myeloid prog": "Other Myeloid",
127 | "CD8+ T": "CD8+ T activated",
128 | "CD8+ T naive": "CD8+ T naive",
129 | "CD4+ T naive": "CD4+ T naive",
130 | "CD4+ T activated": "CD4+ T activated",
131 | "NK": "NK",
132 | "ILC": "ILC",
133 | "Lymph prog": "Early Lymphoid",
134 | "Transitional B": "Transitional B",
135 | "Naive CD20+ B": "Naive CD20+ B",
136 | "B1 B": "B1 B",
137 | "Plasma cell": "Plasma",
138 | "MK/E prog": "MK/E prog",
139 | "Proerythroblast": "Proerythroblast",
140 | "Erythroblast": "Erythroblast",
141 | "Normoblast": "Normoblast",
142 | "HSC": "HSC"
143 | }
144 | }
145 |
--------------------------------------------------------------------------------
/docs/references.bib:
--------------------------------------------------------------------------------
1 | @article{Virshup_2023,
2 | doi = {10.1038/s41587-023-01733-8},
3 | url = {https://doi.org/10.1038%2Fs41587-023-01733-8},
4 | year = 2023,
5 | month = {apr},
6 | publisher = {Springer Science and Business Media {LLC}},
7 | author = {Isaac Virshup and Danila Bredikhin and Lukas Heumos and Giovanni Palla and Gregor Sturm and Adam Gayoso and Ilia Kats and Mikaela Koutrouli and Philipp Angerer and Volker Bergen and Pierre Boyeau and Maren Büttner and Gokcen Eraslan and David Fischer and Max Frank and Justin Hong and Michal Klein and Marius Lange and Romain Lopez and Mohammad Lotfollahi and Malte D. Luecken and Fidel Ramirez and Jeffrey Regier and Sergei Rybakov and Anna C. Schaar and Valeh Valiollah Pour Amiri and Philipp Weiler and Galen Xing and Bonnie Berger and Dana Pe'er and Aviv Regev and Sarah A. Teichmann and Francesca Finotello and F. Alexander Wolf and Nir Yosef and Oliver Stegle and Fabian J. Theis},
8 | title = {The scverse project provides a computational ecosystem for single-cell omics data analysis},
9 | journal = {Nature Biotechnology}
10 | }
11 |
12 | @inproceedings{Luecken2021-ct,
13 | author = {Luecken, Malte and Burkhardt, Daniel and Cannoodt, Robrecht and Lance, Christopher and Agrawal, Aditi and Aliee, Hananeh and Chen, Ann and Deconinck, Louise and Detweiler, Angela and Granados, Alejandro and Huynh, Shelly and Isacco, Laura and Kim, Yang and Klein, Dominik and DE KUMAR, BONY and Kuppasani, Sunil and Lickert, Heiko and McGeever, Aaron and Melgarejo, Joaquin and Mekonen, Honey and Morri, Maurizio and M\"{u}ller, Michaela and Neff, Norma and Paul, Sheryl and Rieck, Bastian and Schneider, Kaylie and Steelman, Scott and Sterr, Michael and Treacy, Daniel and Tong, Alexander and Villani, Alexandra-Chloe and Wang, Guilin and Yan, Jia and Zhang, Ce and Pisco, Angela and Krishnaswamy, Smita and Theis, Fabian and Bloom, Jonathan M},
14 | booktitle = {Proceedings of the Neural Information Processing Systems Track on Datasets and Benchmarks},
15 | editor = {J. Vanschoren and S. Yeung},
16 | pages = {},
17 | title = {A sandbox for prediction and integration of DNA, RNA, and proteins in single cells},
18 | url = {https://datasets-benchmarks-proceedings.neurips.cc/paper_files/paper/2021/file/158f3069a435b314a80bdcb024f8e422-Paper-round2.pdf},
19 | volume = {1},
20 | year = {2021}
21 | }
22 |
23 | @ARTICLE{Sikkema2023-oh,
24 | title = "An integrated cell atlas of the lung in health and disease",
25 | author = "Sikkema, Lisa and Ramírez-Suástegui, Ciro and Strobl, Daniel C and
26 | Gillett, Tessa E and Zappia, Luke and Madissoon, Elo and Markov,
27 | Nikolay S and Zaragosi, Laure-Emmanuelle and Ji, Yuge and Ansari,
28 | Meshal and Arguel, Marie-Jeanne and Apperloo, Leonie and Banchero,
29 | Martin and Bécavin, Christophe and Berg, Marijn and
30 | Chichelnitskiy, Evgeny and Chung, Mei-I and Collin, Antoine and
31 | Gay, Aurore C A and Gote-Schniering, Janine and Hooshiar Kashani,
32 | Baharak and Inecik, Kemal and Jain, Manu and Kapellos, Theodore S
33 | and Kole, Tessa M and Leroy, Sylvie and Mayr, Christoph H and
34 | Oliver, Amanda J and von Papen, Michael and Peter, Lance and
35 | Taylor, Chase J and Walzthoeni, Thomas and Xu, Chuan and Bui, Linh
36 | T and De Donno, Carlo and Dony, Leander and Faiz, Alen and Guo,
37 | Minzhe and Gutierrez, Austin J and Heumos, Lukas and Huang, Ni and
38 | Ibarra, Ignacio L and Jackson, Nathan D and Kadur Lakshminarasimha
39 | Murthy, Preetish and Lotfollahi, Mohammad and Tabib, Tracy and
40 | Talavera-López, Carlos and Travaglini, Kyle J and Wilbrey-Clark,
41 | Anna and Worlock, Kaylee B and Yoshida, Masahiro and {Lung
42 | Biological Network Consortium} and van den Berge, Maarten and
43 | Bossé, Yohan and Desai, Tushar J and Eickelberg, Oliver and
44 | Kaminski, Naftali and Krasnow, Mark A and Lafyatis, Robert and
45 | Nikolic, Marko Z and Powell, Joseph E and Rajagopal, Jayaraj and
46 | Rojas, Mauricio and Rozenblatt-Rosen, Orit and Seibold, Max A and
47 | Sheppard, Dean and Shepherd, Douglas P and Sin, Don D and Timens,
48 | Wim and Tsankov, Alexander M and Whitsett, Jeffrey and Xu, Yan and
49 | Banovich, Nicholas E and Barbry, Pascal and Duong, Thu Elizabeth
50 | and Falk, Christine S and Meyer, Kerstin B and Kropski, Jonathan A
51 | and Pe'er, Dana and Schiller, Herbert B and Tata, Purushothama Rao
52 | and Schultze, Joachim L and Teichmann, Sara A and Misharin,
53 | Alexander V and Nawijn, Martijn C and Luecken, Malte D and Theis,
54 | Fabian J",
55 | journal = "Nat. Med.",
56 | volume = 29,
57 | number = 6,
58 | pages = "1563--1577",
59 | abstract = "Single-cell technologies have transformed our understanding of
60 | human tissues. Yet, studies typically capture only a limited
61 | number of donors and disagree on cell type definitions.
62 | Integrating many single-cell datasets can address these
63 | limitations of individual studies and capture the variability
64 | present in the population. Here we present the integrated Human
65 | Lung Cell Atlas (HLCA), combining 49 datasets of the human
66 | respiratory system into a single atlas spanning over 2.4 million
67 | cells from 486 individuals. The HLCA presents a consensus cell
68 | type re-annotation with matching marker genes, including
69 | annotations of rare and previously undescribed cell types.
70 | Leveraging the number and diversity of individuals in the HLCA, we
71 | identify gene modules that are associated with demographic
72 | covariates such as age, sex and body mass index, as well as gene
73 | modules changing expression along the proximal-to-distal axis of
74 | the bronchial tree. Mapping new data to the HLCA enables rapid
75 | data annotation and interpretation. Using the HLCA as a reference
76 | for the study of disease, we identify shared cell states across
77 | multiple lung diseases, including SPP1+ profibrotic
78 | monocyte-derived macrophages in COVID-19, pulmonary fibrosis and
79 | lung carcinoma. Overall, the HLCA serves as an example for the
80 | development and use of large-scale, cross-dataset organ atlases
81 | within the Human Cell Atlas.",
82 | month = jun,
83 | year = 2023,
84 | language = "en"
85 | }
86 |
87 | @ARTICLE{Lotfollahi2022-jw,
88 | title = "Mapping single-cell data to reference atlases by transfer learning",
89 | author = "Lotfollahi, Mohammad and Naghipourfar, Mohsen and Luecken, Malte D
90 | and Khajavi, Matin and Büttner, Maren and Wagenstetter, Marco and
91 | Avsec, Žiga and Gayoso, Adam and Yosef, Nir and Interlandi, Marta
92 | and Rybakov, Sergei and Misharin, Alexander V and Theis, Fabian J",
93 | journal = "Nat. Biotechnol.",
94 | volume = 40,
95 | number = 1,
96 | pages = "121--130",
97 | abstract = "Large single-cell atlases are now routinely generated to serve as
98 | references for analysis of smaller-scale studies. Yet learning
99 | from reference data is complicated by batch effects between
100 | datasets, limited availability of computational resources and
101 | sharing restrictions on raw data. Here we introduce a deep
102 | learning strategy for mapping query datasets on top of a reference
103 | called single-cell architectural surgery (scArches). scArches uses
104 | transfer learning and parameter optimization to enable efficient,
105 | decentralized, iterative reference building and contextualization
106 | of new datasets with existing references without sharing raw data.
107 | Using examples from mouse brain, pancreas, immune and
108 | whole-organism atlases, we show that scArches preserves biological
109 | state information while removing batch effects, despite using four
110 | orders of magnitude fewer parameters than de novo integration.
111 | scArches generalizes to multimodal reference mapping, allowing
112 | imputation of missing modalities. Finally, scArches retains
113 | coronavirus disease 2019 (COVID-19) disease variation when mapping
114 | to a healthy reference, enabling the discovery of disease-specific
115 | cell states. scArches will facilitate collaborative projects by
116 | enabling iterative construction, updating, sharing and efficient
117 | use of reference atlases.",
118 | month = jan,
119 | year = 2022,
120 | language = "en"
121 | }
122 |
--------------------------------------------------------------------------------
/docs/references.md:
--------------------------------------------------------------------------------
1 | # References
2 |
3 | ```{bibliography}
4 | :cited:
5 | ```
6 |
--------------------------------------------------------------------------------
/docs/tutorials.md:
--------------------------------------------------------------------------------
1 | # Tutorials
2 |
3 |
11 |
12 |
13 |
14 | ```{toctree}
15 | :maxdepth: 1
16 |
17 | notebooks/mil_classification
18 | ```
19 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | build-backend = "hatchling.build"
3 | requires = ["hatchling"]
4 |
5 | [project]
6 | name = "multimil"
7 | version = "0.2.0"
8 | description = "Multimodal weakly supervised learning to identify disease-specific changes in single-cell atlases"
9 | readme = "README.md"
10 | requires-python = ">=3.10"
11 | license = {file = "LICENSE"}
12 | authors = [
13 | {name = "Anastasia Litinetskaya"},
14 | ]
15 | maintainers = [
16 | {name = "Anastasia Litinetskaya", email = "alitinet@gmail.com"},
17 | ]
18 | urls.Documentation = "https://multimil.readthedocs.io/"
19 | urls.Source = "https://github.com/theislab/multimil"
20 | urls.Home-page = "https://github.com/theislab/multimil"
21 | dependencies = [
22 | "anndata<0.11",
23 | "numpy<2.0",
24 | # for debug logging (referenced from the issue template)
25 | "session-info",
26 | # multimil specific
27 | "scanpy",
28 | "scvi-tools<1.0.0",
29 | "matplotlib",
30 | ]
31 |
32 | [project.optional-dependencies]
33 | dev = [
34 | "pre-commit",
35 | "twine>=4.0.2",
36 | ]
37 | doc = [
38 | "docutils>=0.8,!=0.18.*,!=0.19.*",
39 | "sphinx>=4",
40 | "sphinx-book-theme>=1.0.0",
41 | "myst-nb>=1.1.0",
42 | "sphinxcontrib-bibtex>=1.0.0",
43 | "setuptools", # Until pybtex >0.23.0 releases: https://bitbucket.org/pybtex-devs/pybtex/issues/169/
44 | "sphinx-autodoc-typehints",
45 | "sphinxext-opengraph",
46 | # For notebooks
47 | "ipykernel",
48 | "ipython",
49 | "sphinx-copybutton",
50 | "pandas",
51 | ]
52 | test = [
53 | "pytest",
54 | "coverage",
55 | ]
56 | tutorials = [
57 | "muon",
58 | "jupyterlab",
59 | "ipywidgets",
60 | "leidenalg",
61 | "igraph",
62 | "gdown",
63 | ]
64 |
65 | [tool.coverage.run]
66 | source = ["multimil"]
67 | omit = [
68 | "**/test_*.py",
69 | ]
70 |
71 | [tool.pytest.ini_options]
72 | testpaths = ["tests"]
73 | xfail_strict = true
74 | addopts = [
75 | "--import-mode=importlib", # allow using test files with same name
76 | ]
77 |
78 | [tool.ruff]
79 | line-length = 120
80 | src = ["src"]
81 | extend-include = ["*.ipynb"]
82 |
83 | [tool.ruff.format]
84 | docstring-code-format = true
85 |
86 | [tool.ruff.lint]
87 | select = [
88 | "F", # Errors detected by Pyflakes
89 | "E", # Error detected by Pycodestyle
90 | "W", # Warning detected by Pycodestyle
91 | "I", # isort
92 | "D", # pydocstyle
93 | "B", # flake8-bugbear
94 | "TID", # flake8-tidy-imports
95 | "C4", # flake8-comprehensions
96 | "BLE", # flake8-blind-except
97 | "UP", # pyupgrade
98 | "RUF100", # Report unused noqa directives
99 | ]
100 | ignore = [
101 | # line too long -> we accept long comment lines; formatter gets rid of long code lines
102 | "E501",
103 | # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient
104 | "E731",
105 | # allow I, O, l as variable names -> I is the identity matrix
106 | "E741",
107 | # Missing docstring in public package
108 | "D104",
109 | # Missing docstring in public module
110 | "D100",
111 | # Missing docstring in __init__
112 | "D107",
113 | # Errors from function calls in argument defaults. These are fine when the result is immutable.
114 | "B008",
115 | # __magic__ methods are often self-explanatory, allow missing docstrings
116 | "D105",
117 | # first line should end with a period [Bug: doesn't work with single-line docstrings]
118 | "D400",
119 | # First line should be in imperative mood; try rephrasing
120 | "D401",
121 | ## Disable one in each pair of mutually incompatible rules
122 | # We don’t want a blank line before a class docstring
123 | "D203",
124 | # We want docstrings to start immediately after the opening triple quote
125 | "D213",
126 | ]
127 |
128 | [tool.ruff.lint.pydocstyle]
129 | convention = "numpy"
130 |
131 | [tool.ruff.lint.per-file-ignores]
132 | "docs/*" = ["I"]
133 | "tests/*" = ["D"]
134 | "*/__init__.py" = ["F401"]
135 | "src/multimil/module/__init__.py" = ["I"]
136 |
137 | [tool.cruft]
138 | skip = [
139 | "tests",
140 | "src/**/__init__.py",
141 | "src/**/basic.py",
142 | "docs/api.md",
143 | "docs/changelog.md",
144 | "docs/references.bib",
145 | "docs/references.md",
146 | "docs/notebooks/example.ipynb",
147 | ]
148 |
--------------------------------------------------------------------------------
/src/multimil/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib.metadata import version
2 |
3 | from . import data, dataloaders, distributions, model, module, nn, utils
4 |
5 | __all__ = ["data", "dataloaders", "distributions", "model", "module", "nn", "utils"]
6 |
7 | __version__ = version("multimil")
8 |
--------------------------------------------------------------------------------
/src/multimil/data/__init__.py:
--------------------------------------------------------------------------------
1 | from ._preprocessing import organize_multimodal_anndatas
2 |
3 | __all__ = ["organize_multimodal_anndatas"]
4 |
--------------------------------------------------------------------------------
/src/multimil/data/_preprocessing.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import anndata as ad
4 | import numpy as np
5 | import pandas as pd
6 |
7 |
8 | def organize_multimodal_anndatas(
9 | adatas: list[list[ad.AnnData | None]],
10 | layers: list[list[str | None]] | None = None,
11 | ) -> ad.AnnData:
12 | """Concatenate all the input anndata objects.
13 |
14 | These anndata objects should already have been preprocessed so that all single-modality
15 | objects use a subset of the features used in the multiome object. The feature names (index of
16 | `.var`) should match between the objects for vertical integration and cell names (index of
17 | `.obs`) should match between the objects for horizontal integration.
18 |
19 | Parameters
20 | ----------
21 | adatas
22 |         List of lists of AnnData objects (or None), where each sublist corresponds to a modality and each entry within a sublist to a dataset.
23 | layers
24 |         List of lists of the same shape as `adatas` specifying which `.layers` key to use for each AnnData. Default is None, which means `.X` is used.
25 |
26 | Returns
27 | -------
28 | Concatenated AnnData object across modalities and datasets.
29 | """
30 | # TODO: add checks for layers
31 | # TODO: add check that len of modalities is the same as len of losses, etc
32 |
33 | # needed for scArches operation setup
34 | datasets_lengths = {}
35 | datasets_obs_names = {}
36 | datasets_obs = {}
37 | modality_lengths = {}
38 | modality_var_names = {}
39 |
40 | # sanity checks and preparing data for concat
41 | for mod, modality_adatas in enumerate(adatas):
42 | for i, adata in enumerate(modality_adatas):
43 | if adata is not None:
44 | # will create .obs['group'] later, so throw a warning here if the column already exists
45 | if "group" in adata.obs.columns:
46 | warnings.warn(
47 | "Column `.obs['group']` will be overwritten. Please save the original data in another column if needed.",
48 | stacklevel=2,
49 | )
50 | # check that all adatas in the same modality have the same number of features
51 | if (mod_length := modality_lengths.get(f"{mod}", None)) is None:
52 | modality_lengths[f"{mod}"] = adata.shape[1]
53 | else:
54 | if adata.shape[1] != mod_length:
55 | raise ValueError(
56 | f"Adatas have different number of features for modality {mod}, namely {mod_length} and {adata.shape[1]}."
57 | )
58 | # check that there is the same number of observations for paired data
59 | if (dataset_length := datasets_lengths.get(i, None)) is None:
60 | datasets_lengths[i] = adata.shape[0]
61 | else:
62 | if adata.shape[0] != dataset_length:
63 | raise ValueError(
64 | f"Paired adatas have different number of observations for group {i}, namely {dataset_length} and {adata.shape[0]}."
65 | )
66 | # check that .obs_names are the same for paired data
67 | if (dataset_obs_names := datasets_obs_names.get(i, None)) is None:
68 | datasets_obs_names[i] = adata.obs_names
69 | else:
70 | if np.sum(adata.obs_names != dataset_obs_names):
71 | raise ValueError(f"`.obs_names` are not the same for group {i}.")
72 | # keep all the .obs
73 | if datasets_obs.get(i, None) is None:
74 | datasets_obs[i] = adata.obs
75 | datasets_obs[i].loc[:, "group"] = i
76 | else:
77 | cols_to_use = adata.obs.columns.difference(datasets_obs[i].columns)
78 | datasets_obs[i] = datasets_obs[i].join(adata.obs[cols_to_use])
79 | modality_var_names[mod] = adata.var_names
80 |
81 | for mod, modality_adatas in enumerate(adatas):
82 | for i, adata in enumerate(modality_adatas):
83 |             if adata is None:
84 | X_zeros = np.zeros((datasets_lengths[i], modality_lengths[f"{mod}"]))
85 | adatas[mod][i] = ad.AnnData(X_zeros, dtype=X_zeros.dtype)
86 | adatas[mod][i].obs_names = datasets_obs_names[i]
87 | adatas[mod][i].var_names = modality_var_names[mod]
88 | adatas[mod][i] = adatas[mod][i].copy()
89 | if layers is not None:
90 | if layers[mod][i]:
91 | layer = layers[mod][i]
92 | adatas[mod][i] = adatas[mod][i].copy()
93 | adatas[mod][i].X = adatas[mod][i].layers[layer].copy()
94 |
95 | # concat adatas within each modality first
96 | mod_adatas = []
97 | for modality_adatas in adatas:
98 | mod_adatas.append(ad.concat(modality_adatas, join="outer"))
99 |
100 | # concat modality adatas along the feature axis
101 | multiome_anndata = ad.concat(mod_adatas, axis=1, label="modality")
102 |
103 | # add .obs back
104 | multiome_anndata.obs = pd.concat(datasets_obs.values())
105 |
106 | # we will need modality_length later for the model init
107 | multiome_anndata.uns["modality_lengths"] = modality_lengths
108 | multiome_anndata.var_names_make_unique()
109 |
110 | return multiome_anndata
111 |
--------------------------------------------------------------------------------
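
A minimal usage sketch for organize_multimodal_anndatas with made-up toy data (all names and sizes below are illustrative): the outer list runs over modalities, the inner lists over datasets, and None marks a modality that is missing for a dataset and gets zero-filled.

import anndata as ad
import numpy as np
from multimil.data import organize_multimodal_anndatas

# dataset 0 is paired RNA + protein, dataset 1 has RNA only
rna_0 = ad.AnnData(np.random.poisson(1.0, (100, 200)).astype(np.float32))
prot_0 = ad.AnnData(np.random.poisson(1.0, (100, 30)).astype(np.float32))
rna_1 = ad.AnnData(np.random.poisson(1.0, (80, 200)).astype(np.float32))

# paired objects must share cell names; objects of one modality must share feature names
rna_0.obs_names = prot_0.obs_names = [f"d0_cell{i}" for i in range(100)]
rna_1.obs_names = [f"d1_cell{i}" for i in range(80)]
rna_0.var_names = rna_1.var_names = [f"gene_{i}" for i in range(200)]
prot_0.var_names = [f"protein_{i}" for i in range(30)]

adata = organize_multimodal_anndatas(
    adatas=[
        [rna_0, rna_1],   # modality 0 (RNA), one entry per dataset
        [prot_0, None],   # modality 1 (protein); missing for dataset 1, so it is filled with zeros
    ],
)
print(adata)                          # 180 cells x 230 features
print(adata.uns["modality_lengths"])  # {'0': 200, '1': 30}, read later when initializing MultiVAE
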
/src/multimil/dataloaders/__init__.py:
--------------------------------------------------------------------------------
1 | from ._ann_dataloader import GroupAnnDataLoader
2 | from ._data_splitting import GroupDataSplitter
3 |
4 | __all__ = ["GroupAnnDataLoader", "GroupDataSplitter"]
5 |
--------------------------------------------------------------------------------
/src/multimil/dataloaders/_ann_dataloader.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import itertools
3 |
4 | import numpy as np
5 | import torch
6 | from scvi.data import AnnDataManager
7 | from scvi.dataloaders import AnnTorchDataset
8 | from torch.utils.data import DataLoader
9 |
10 |
11 | # adjusted from scvi-tools
12 | # https://github.com/YosefLab/scvi-tools/blob/ac0c3e04fcc2772fdcf7de4de819db3af9465b6b/scvi/dataloaders/_ann_dataloader.py#L15
13 | # accessed on 4 November 2021
14 | class StratifiedSampler(torch.utils.data.sampler.Sampler):
15 |     """Custom stratified sampler that samples the same number of observations from each group in each mini-batch.
16 |
17 | Parameters
18 | ----------
19 | indices
20 | List of indices to sample from.
21 | batch_size
22 | Batch size of each iteration.
23 | shuffle
24 | If ``True``, shuffles indices before sampling.
25 | drop_last
26 |         If int, drops a group's last batch if it contains fewer than ``drop_last`` observations.
27 |         If ``True``, drops the last non-full batch of each group.
28 |         If ``False``, iterates over all batches.
29 | """
30 |
31 | def __init__(
32 | self,
33 | indices: np.ndarray,
34 | group_labels: np.ndarray,
35 | batch_size: int,
36 | min_size_per_class: int,
37 | shuffle: bool = True,
38 | drop_last: bool | int = True,
39 | shuffle_classes: bool = True,
40 | ):
41 | if drop_last > batch_size:
42 | raise ValueError(
43 | "drop_last can't be greater than batch_size. "
44 | + f"drop_last is {drop_last} but batch_size is {batch_size}."
45 | )
46 |
47 | if batch_size % min_size_per_class != 0:
48 | raise ValueError(
49 |                 "min_size_per_class has to be a divisor of batch_size. "
50 | + f"min_size_per_class is {min_size_per_class} but batch_size is {batch_size}."
51 | )
52 |
53 | self.indices = indices
54 | self.group_labels = group_labels
55 | self.n_obs = len(indices)
56 | self.batch_size = batch_size
57 | self.shuffle = shuffle
58 | self.shuffle_classes = shuffle_classes
59 | self.min_size_per_class = min_size_per_class
60 | self.drop_last = drop_last
61 |
62 | from math import ceil
63 |
64 | classes = list(dict.fromkeys(self.group_labels))
65 |
66 | tmp = 0
67 | for cl in classes:
68 | idx = np.where(self.group_labels == cl)[0]
69 | cl_idx = self.indices[idx]
70 | n_obs = len(cl_idx)
71 | last_batch_len = n_obs % self.min_size_per_class
72 | if (self.drop_last is True) or (last_batch_len < self.drop_last):
73 | drop_last_n = last_batch_len
74 | elif (self.drop_last is False) or (last_batch_len >= self.drop_last):
75 | drop_last_n = 0
76 | else:
77 | raise ValueError("Invalid input for drop_last param. Must be bool or int.")
78 |
79 | if drop_last_n != 0:
80 | tmp += n_obs // self.min_size_per_class
81 | else:
82 | tmp += ceil(n_obs / self.min_size_per_class)
83 |
84 | classes_per_batch = int(self.batch_size / self.min_size_per_class)
85 | self.length = ceil(tmp / classes_per_batch)
86 |
87 | def __iter__(self):
88 | classes_per_batch = int(self.batch_size / self.min_size_per_class)
89 |
90 | classes = list(dict.fromkeys(self.group_labels))
91 | data_iter = []
92 |
93 | for cl in classes:
94 | idx = np.where(self.group_labels == cl)[0]
95 | cl_idx = self.indices[idx]
96 | n_obs = len(cl_idx)
97 |
98 | if self.shuffle is True:
99 | idx = torch.randperm(n_obs).tolist()
100 | else:
101 | idx = torch.arange(n_obs).tolist()
102 |
103 | last_batch_len = n_obs % self.min_size_per_class
104 | if (self.drop_last is True) or (last_batch_len < self.drop_last):
105 | drop_last_n = last_batch_len
106 | elif (self.drop_last is False) or (last_batch_len >= self.drop_last):
107 | drop_last_n = 0
108 | else:
109 | raise ValueError("Invalid input for drop_last param. Must be bool or int.")
110 |
111 | if drop_last_n != 0:
112 | idx = idx[:-drop_last_n]
113 |
114 | data_iter.extend(
115 | [cl_idx[idx[i : i + self.min_size_per_class]] for i in range(0, len(idx), self.min_size_per_class)]
116 | )
117 |
118 | if self.shuffle_classes:
119 | idx = torch.randperm(len(data_iter)).tolist()
120 | data_iter = [data_iter[id] for id in idx]
121 |
122 | final_data_iter = []
123 |
124 | end = len(data_iter) - len(data_iter) % classes_per_batch
125 | for i in range(0, end, classes_per_batch):
126 | batch_idx = list(itertools.chain.from_iterable(data_iter[i : i + classes_per_batch]))
127 | final_data_iter.append(batch_idx)
128 |
129 | # deal with the last manually
130 | if end != len(data_iter):
131 | batch_idx = list(itertools.chain.from_iterable(data_iter[end:]))
132 | final_data_iter.append(batch_idx)
133 |
134 | return iter(final_data_iter)
135 |
136 | def __len__(self):
137 | return self.length
138 |
139 |
140 | # adjusted from scvi-tools
141 | # https://github.com/scverse/scvi-tools/blob/0b802762869c43c9f49e69fe62b1a5a9b5c4dae6/scvi/dataloaders/_ann_dataloader.py#L89
142 | # accessed on 5 November 2022
143 | class GroupAnnDataLoader(DataLoader):
144 | """DataLoader for loading tensors from AnnData objects.
145 |
146 | Parameters
147 | ----------
148 | adata_manager
149 | :class:`~scvi.data.AnnDataManager` object with a registered AnnData object.
150 | shuffle
151 | Whether the data should be shuffled.
152 | indices
153 | The indices of the observations in the adata to load.
154 | batch_size
155 | Minibatch size to load each iteration.
156 | data_and_attributes
157 | Dictionary with keys representing keys in data registry (`adata.uns["_scvi"]`)
158 |         and values equal to the desired numpy loading type (later made into torch tensors).
159 | If `None`, defaults to all registered data.
160 | data_loader_kwargs
161 | Keyword arguments for :class:`~torch.utils.data.DataLoader`.
162 | """
163 |
164 | def __init__(
165 | self,
166 | adata_manager: AnnDataManager,
167 | group_column: str,
168 | shuffle=True,
169 | shuffle_classes=True,
170 | indices=None,
171 | batch_size=128,
172 | min_size_per_class=None,
173 | data_and_attributes: dict | None = None,
174 | drop_last: bool | int = True,
175 | sampler: torch.utils.data.sampler.Sampler | None = StratifiedSampler,
176 | **data_loader_kwargs,
177 | ):
178 | if adata_manager.adata is None:
179 | raise ValueError("Please run register_fields() on your AnnDataManager object first.")
180 |
181 | if data_and_attributes is not None:
182 | data_registry = adata_manager.data_registry
183 | for key in data_and_attributes.keys():
184 | if key not in data_registry.keys():
185 | raise ValueError(f"{key} required for model but not registered with AnnDataManager.")
186 |
187 | if group_column not in adata_manager.registry["setup_args"]["categorical_covariate_keys"]:
188 | raise ValueError(
189 |                 f"{group_column} is required for the model but is not registered; it has to be one of the registered categorical covariates = {adata_manager.registry['setup_args']['categorical_covariate_keys']}."
190 | )
191 |
192 | self.dataset = AnnTorchDataset(adata_manager, getitem_tensors=data_and_attributes)
193 |
194 | if min_size_per_class is None:
195 | min_size_per_class = batch_size // 2
196 |
197 | sampler_kwargs = {
198 | "batch_size": batch_size,
199 | "shuffle": shuffle,
200 | "drop_last": drop_last,
201 | "min_size_per_class": min_size_per_class,
202 | "shuffle_classes": shuffle_classes,
203 | }
204 |
205 | if indices is None:
206 | indices = np.arange(len(self.dataset))
207 | sampler_kwargs["indices"] = indices
208 | else:
209 | if hasattr(indices, "dtype") and indices.dtype is np.dtype("bool"):
210 | indices = np.where(indices)[0].ravel()
211 | indices = np.asarray(indices)
212 | sampler_kwargs["indices"] = indices
213 |
214 | sampler_kwargs["group_labels"] = np.array(
215 | adata_manager.adata[indices].obsm["_scvi_extra_categorical_covs"][group_column]
216 | )
217 |
218 | self.indices = indices
219 | self.sampler_kwargs = sampler_kwargs
220 |
221 | sampler = sampler(**self.sampler_kwargs)
222 | self.data_loader_kwargs = copy.copy(data_loader_kwargs)
223 | # do not touch batch size here, sampler gives batched indices
224 | self.data_loader_kwargs.update({"sampler": sampler, "batch_size": None})
225 |
226 | super().__init__(self.dataset, **self.data_loader_kwargs)
227 |
--------------------------------------------------------------------------------
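
A small, deterministic illustration of how StratifiedSampler assembles pre-batched index lists (the group labels are made up; the class is not re-exported from multimil.dataloaders, so it is imported from the private module here):

import numpy as np
from multimil.dataloaders._ann_dataloader import StratifiedSampler

indices = np.arange(12)
group_labels = np.array(["a"] * 6 + ["b"] * 6)

sampler = StratifiedSampler(
    indices=indices,
    group_labels=group_labels,
    batch_size=8,           # total observations per mini-batch
    min_size_per_class=4,   # each contributing group provides 4 observations
    shuffle=False,
    shuffle_classes=False,
)
for batch in sampler:
    print([int(i) for i in batch])
# prints [0, 1, 2, 3, 6, 7, 8, 9]: four cells from group 'a' and four from group 'b';
# the two leftover cells per group are dropped because drop_last defaults to True
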
/src/multimil/dataloaders/_data_splitting.py:
--------------------------------------------------------------------------------
1 | from scvi.data import AnnDataManager
2 | from scvi.dataloaders import DataSplitter
3 |
4 | from multimil.dataloaders._ann_dataloader import GroupAnnDataLoader
5 |
6 |
7 | # adjusted from scvi-tools
8 | # https://github.com/scverse/scvi-tools/blob/0b802762869c43c9f49e69fe62b1a5a9b5c4dae6/scvi/dataloaders/_data_splitting.py#L56
9 | # accessed on 5 November 2022
10 | class GroupDataSplitter(DataSplitter):
11 | """Creates data loaders ``train_set``, ``validation_set``, ``test_set``.
12 |
13 |     If ``train_size + validation_size < 1`` then ``test_set`` is non-empty.
14 |
15 | Parameters
16 | ----------
17 | adata_manager
18 | :class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
19 | train_size
20 | Proportion of cells to use as the train set. Float, or None (default is 0.9).
21 | validation_size
22 |         Proportion of cells to use as the validation set. Float, or None (default is None). If None, is set to 1 - ``train_size``.
23 | use_gpu
24 | Use default GPU if available (if None or True), or index of GPU to use (if int),
25 | or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False).
26 | kwargs
27 |         Keyword args for data loader. Data loader class is :class:`~multimil.dataloaders.GroupAnnDataLoader`.
28 | """
29 |
30 | def __init__(
31 | self,
32 | adata_manager: AnnDataManager,
33 | group_column: str,
34 | train_size: float = 0.9,
35 | validation_size: float | None = None,
36 | use_gpu: bool = False,
37 | **kwargs,
38 | ):
39 | self.group_column = group_column
40 | super().__init__(adata_manager, train_size, validation_size, use_gpu, **kwargs)
41 |
42 | def train_dataloader(self):
43 | """Return data loader for train AnnData."""
44 | return GroupAnnDataLoader(
45 | self.adata_manager,
46 | self.group_column,
47 | indices=self.train_idx,
48 | shuffle=True,
49 | drop_last=True,
50 | pin_memory=self.pin_memory,
51 | **self.data_loader_kwargs,
52 | )
53 |
54 | def val_dataloader(self):
55 | """Return data loader for validation AnnData."""
56 | if len(self.val_idx) > 0:
57 | return GroupAnnDataLoader(
58 | self.adata_manager,
59 | self.group_column,
60 | indices=self.val_idx,
61 | shuffle=False,
62 | drop_last=True,
63 | pin_memory=self.pin_memory,
64 | **self.data_loader_kwargs,
65 | )
66 | else:
67 | pass
68 |
69 | def test_dataloader(self):
70 | """Return data loader for test AnnData."""
71 | if len(self.test_idx) > 0:
72 | return GroupAnnDataLoader(
73 | self.adata_manager,
74 | self.group_column,
75 | indices=self.test_idx,
76 | shuffle=False,
77 | drop_last=True,
78 | pin_memory=self.pin_memory,
79 | **self.data_loader_kwargs,
80 | )
81 | else:
82 | pass
83 |
--------------------------------------------------------------------------------
/src/multimil/distributions/__init__.py:
--------------------------------------------------------------------------------
1 | from ._jeffreys import Jeffreys
2 | from ._mmd import MMD
3 |
4 | __all__ = ["MMD", "Jeffreys"]
5 |
--------------------------------------------------------------------------------
/src/multimil/distributions/_jeffreys.py:
--------------------------------------------------------------------------------
1 | # code adapted from https://github.com/scverse/scvi-tools/blob/c53efe06379c866e36e549afbb8158a120b82d14/src/scvi/module/_multivae.py#L900C1-L923C56
2 | # last accessed on 16th October 2024
3 | import torch
4 | from torch.distributions import Normal
5 | from torch.distributions import kl_divergence as kld
6 |
7 |
8 | class Jeffreys(torch.nn.Module):
9 | """Jeffreys divergence (Symmetric KL divergence) using torch distributions.
10 |
11 | Parameters
12 | ----------
13 | None
14 | """
15 |
16 | def __init__(self):
17 | super().__init__()
18 |
19 | def sym_kld(
20 | self,
21 | mu1: torch.Tensor,
22 | sigma1: torch.Tensor,
23 | mu2: torch.Tensor,
24 | sigma2: torch.Tensor,
25 | ) -> torch.Tensor:
26 | """Symmetric KL divergence between two Gaussians using torch.distributions.
27 |
28 | Parameters
29 | ----------
30 | mu1
31 | Mean of the first distribution.
32 | sigma1
33 | Variance of the first distribution (note: this will be square-rooted to get std dev).
34 | mu2
35 | Mean of the second distribution.
36 | sigma2
37 | Variance of the second distribution (note: this will be square-rooted to get std dev).
38 |
39 | Returns
40 | -------
41 | Symmetric KL divergence between the two distributions.
42 | """
43 | rv1 = Normal(mu1, sigma1.sqrt())
44 | rv2 = Normal(mu2, sigma2.sqrt())
45 |
46 | out = kld(rv1, rv2).mean() + kld(rv2, rv1).mean()
47 | return out
48 |
49 | def forward(self, params1: torch.Tensor, params2: torch.Tensor) -> torch.Tensor:
50 | """Forward computation for Jeffreys divergence.
51 |
52 | Parameters
53 | ----------
54 | params1
55 |             A tuple of the mean and variance of the first distribution.
56 | params2
57 |             A tuple of the mean and variance of the second distribution.
58 |
59 | Returns
60 | -------
61 | Jeffreys divergence between the two distributions.
62 | """
63 | mu1, sigma1 = params1
64 | mu2, sigma2 = params2
65 |
66 | # Ensure non-negative sigma (variance) values
67 | sigma1 = sigma1.clamp(min=1e-6)
68 | sigma2 = sigma2.clamp(min=1e-6)
69 |
70 | return self.sym_kld(mu1, sigma1, mu2, sigma2)
71 |
--------------------------------------------------------------------------------
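
A minimal sketch of the Jeffreys module on toy Gaussian parameters; the forward pass expects (mean, variance) tuples and square-roots the variances internally:

import torch
from multimil.distributions import Jeffreys

jeffreys = Jeffreys()

mu1, var1 = torch.zeros(4, 2), torch.ones(4, 2)
mu2, var2 = torch.ones(4, 2), 2 * torch.ones(4, 2)

# symmetric KL: mean of KL(p1 || p2) plus mean of KL(p2 || p1), a non-negative scalar tensor
d = jeffreys((mu1, var1), (mu2, var2))
print(d)
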
/src/multimil/distributions/_mmd.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class MMD(torch.nn.Module):
5 | """Maximum mean discrepancy.
6 |
7 | Parameters
8 | ----------
9 |     kernel_type
10 |         Indicates which kernel to use. One of
11 |         * ``'gaussian'`` - use a multi-scale Gaussian kernel,
12 |         * any other value - compare the means and covariances of the two groups directly instead of using a kernel.
13 | """
14 |
15 | def __init__(self, kernel_type="gaussian"):
16 | super().__init__()
17 | self.kernel_type = kernel_type
18 | # TODO: add check for gaussian kernel that shapes are same
19 |
20 | def gaussian_kernel(
21 | self,
22 | x: torch.Tensor,
23 | y: torch.Tensor,
24 | gamma: list[float] | None = None,
25 | ) -> torch.Tensor:
26 |         """Apply Gaussian kernel.
27 |
28 | Parameters
29 | ----------
30 | x
31 | Tensor from the first distribution.
32 | y
33 | Tensor from the second distribution.
34 | gamma
35 | List of gamma parameters.
36 |
37 | Returns
38 | -------
39 | Gaussian kernel between ``x`` and ``y``.
40 | """
41 | if gamma is None:
42 | gamma = [
43 | 1e-6,
44 | 1e-5,
45 | 1e-4,
46 | 1e-3,
47 | 1e-2,
48 | 1e-1,
49 | 1,
50 | 5,
51 | 10,
52 | 15,
53 | 20,
54 | 25,
55 | 30,
56 | 35,
57 | 100,
58 | 1e3,
59 | 1e4,
60 | 1e5,
61 | 1e6,
62 | ]
63 |
64 | D = torch.cdist(x, y).pow(2)
65 | K = torch.zeros_like(D)
66 |
67 | for g in gamma:
68 | K.add_(torch.exp(D.mul(-g)))
69 |
70 | return K / len(gamma)
71 |
72 | def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
73 | """Forward computation.
74 |
75 | Adapted from
76 | Title: scarches
77 |         Date: 9th October 2021
78 | Code version: 0.4.0
79 | Availability: https://github.com/theislab/scarches/blob/63a7c2b35a01e55fe7e1dd871add459a86cd27fb/scarches/models/trvae/losses.py
80 | Citation: Gretton, Arthur, et al. "A Kernel Two-Sample Test", 2012.
81 |
82 | Parameters
83 | ----------
84 | x
85 | Tensor with shape ``(batch_size, z_dim)``.
86 | y
87 | Tensor with shape ``(batch_size, z_dim)``.
88 |
89 | Returns
90 | -------
91 | MMD between ``x`` and ``y``.
92 | """
93 |         # if either group has only one observation in the batch, skip the MMD computation for this batch
94 | if len(x) == 1 or len(y) == 1:
95 | return torch.tensor(0.0)
96 |
97 | if self.kernel_type == "gaussian":
98 | Kxx = self.gaussian_kernel(x, x).mean()
99 | Kyy = self.gaussian_kernel(y, y).mean()
100 | Kxy = self.gaussian_kernel(x, y).mean()
101 | return Kxx + Kyy - 2 * Kxy
102 | else:
103 | mean_x = x.mean(0, keepdim=True)
104 | mean_y = y.mean(0, keepdim=True)
105 | cent_x = x - mean_x
106 | cent_y = y - mean_y
107 | cova_x = (cent_x.t() @ cent_x) / (len(x) - 1)
108 | cova_y = (cent_y.t() @ cent_y) / (len(y) - 1)
109 |
110 | mean_diff = (mean_x - mean_y).pow(2).mean()
111 | cova_diff = (cova_x - cova_y).pow(2).mean()
112 |
113 | return mean_diff + cova_diff
114 |
--------------------------------------------------------------------------------
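
A minimal sketch of the MMD module on toy latent codes; with the default kernel_type="gaussian" the multi-scale Gaussian kernel branch is used, while any other value falls back to the direct mean/covariance comparison:

import torch
from multimil.distributions import MMD

mmd = MMD(kernel_type="gaussian")

x = torch.randn(64, 16)        # latent codes of one group
y = torch.randn(64, 16) + 0.5  # latent codes of another group, shifted
print(mmd(x, y))               # grows as the two sets of codes become more dissimilar
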
/src/multimil/model/__init__.py:
--------------------------------------------------------------------------------
1 | from ._mil import MILClassifier
2 | from ._multivae import MultiVAE
3 | from ._multivae_mil import MultiVAE_MIL
4 |
5 | __all__ = ["MultiVAE", "MILClassifier", "MultiVAE_MIL"]
6 |
--------------------------------------------------------------------------------
/src/multimil/model/_mil.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import warnings
3 |
4 | import anndata as ad
5 | import torch
6 | from anndata import AnnData
7 | from pytorch_lightning.callbacks import ModelCheckpoint
8 | from scvi import REGISTRY_KEYS
9 | from scvi.data import AnnDataManager, fields
10 | from scvi.data._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY
11 | from scvi.model._utils import parse_use_gpu_arg
12 | from scvi.model.base import ArchesMixin, BaseModelClass
13 | from scvi.model.base._archesmixin import _get_loaded_data
14 | from scvi.model.base._utils import _initialize_model
15 | from scvi.train import AdversarialTrainingPlan, TrainRunner
16 | from scvi.train._callbacks import SaveBestState
17 |
18 | from multimil.dataloaders import GroupAnnDataLoader, GroupDataSplitter
19 | from multimil.module import MILClassifierTorch
20 | from multimil.utils import (
21 | get_bag_info,
22 | get_predictions,
23 | plt_plot_losses,
24 | prep_minibatch,
25 | save_predictions_in_adata,
26 | select_covariates,
27 | setup_ordinal_regression,
28 | )
29 |
30 | logger = logging.getLogger(__name__)
31 |
32 |
33 | class MILClassifier(BaseModelClass, ArchesMixin):
34 | """MultiMIL MIL prediction model.
35 |
36 | Parameters
37 | ----------
38 | adata
39 | AnnData object containing embeddings and covariates.
40 | sample_key
41 | Key in `adata.obs` that corresponds to the sample covariate.
42 | classification
43 | List of keys in `adata.obs` that correspond to the classification covariates.
44 | regression
45 | List of keys in `adata.obs` that correspond to the regression covariates.
46 | ordinal_regression
47 | List of keys in `adata.obs` that correspond to the ordinal regression covariates.
48 | sample_batch_size
49 |         Number of cells per sample (bag) in each mini-batch. Default is 128.
50 | normalization
51 | One of "layer" or "batch". Default is "layer".
52 | z_dim
53 | Dimensionality of the input latent space. Default is 16.
54 | dropout
55 | Dropout rate. Default is 0.2.
56 | scoring
57 | How to calculate attention scores. One of "gated_attn", "attn", "mean", "max", "sum", "MLP". Default is "gated_attn".
58 | attn_dim
59 | Dimensionality of the hidden layer in attention calculation. Default is 16.
60 | n_layers_cell_aggregator
61 | Number of layers in the cell aggregator. Default is 1.
62 | n_layers_classifier
63 | Number of layers in the classifier. Default is 2.
64 | n_layers_regressor
65 | Number of layers in the regressor. Default is 2.
66 | n_layers_mlp_attn
67 | Number of layers in the MLP attention. Only used if `scoring` = "MLP". Default is 1.
68 | n_hidden_cell_aggregator
69 | Number of hidden units in the cell aggregator. Default is 128.
70 | n_hidden_classifier
71 | Number of hidden units in the classifier. Default is 128.
72 | n_hidden_mlp_attn
73 | Number of hidden units in the MLP attention. Default is 32.
74 | n_hidden_regressor
75 | Number of hidden units in the regressor. Default is 128.
76 | class_loss_coef
77 | Coefficient for the classification loss. Default is 1.0.
78 | regression_loss_coef
79 | Coefficient for the regression loss. Default is 1.0.
80 | activation
81 | Activation function. Default is 'leaky_relu'.
82 | initialization
83 | Initialization method for the weights. Default is None.
84 | anneal_class_loss
85 | Whether to anneal the classification loss. Default is False.
86 | ignore_covariates
87 | List of covariates to ignore. Needed for query-to-reference mapping. Default is None.
88 | """
89 |
90 | def __init__(
91 | self,
92 | adata,
93 | sample_key,
94 | classification=None,
95 | regression=None,
96 | ordinal_regression=None,
97 | sample_batch_size=128,
98 | normalization="layer",
99 | z_dim=16, # TODO do we need it? can't we get it from adata?
100 | dropout=0.2,
101 | scoring="gated_attn", # TODO test if MLP is supported and if we want to keep it
102 | attn_dim=16,
103 | n_layers_cell_aggregator: int = 1,
104 | n_layers_classifier: int = 2,
105 | n_layers_regressor: int = 2,
106 | n_layers_mlp_attn: int = 1,
107 | n_hidden_cell_aggregator: int = 128,
108 | n_hidden_classifier: int = 128,
109 | n_hidden_mlp_attn: int = 32,
110 | n_hidden_regressor: int = 128,
111 | class_loss_coef=1.0,
112 | regression_loss_coef=1.0,
113 | activation="leaky_relu", # or tanh
114 | initialization=None, # xavier (tanh) or kaiming (leaky_relu)
115 | anneal_class_loss=False,
116 | ignore_covariates=None,
117 | ):
118 | super().__init__(adata)
119 |
120 | if classification is None:
121 | classification = []
122 | if regression is None:
123 | regression = []
124 | if ordinal_regression is None:
125 | ordinal_regression = []
126 |
127 | self.sample_key = sample_key
128 | self.scoring = scoring
129 |
130 | if self.sample_key not in self.adata_manager.registry["setup_args"]["categorical_covariate_keys"]:
131 | raise ValueError(
132 | f"Sample key = '{self.sample_key}' has to be one of the registered categorical covariates = {self.adata_manager.registry['setup_args']['categorical_covariate_keys']}"
133 | )
134 |
135 | if len(classification) + len(regression) + len(ordinal_regression) == 0:
136 | raise ValueError(
137 | 'At least one of "classification", "regression", "ordinal_regression" has to be specified.'
138 | )
139 |
140 | self.classification = classification
141 | self.regression = regression
142 | self.ordinal_regression = ordinal_regression
143 |
144 | # TODO check if all of the three above were registered with setup anndata
145 | # TODO add check that class is the same within a patient
146 | # TODO assert length of things is the same as number of modalities
147 | # TODO add that n_layers has to be > 0 for all
148 | # TODO warning if n_layers == 1 then n_hidden is not used for classifier and MLP attention
149 | # TODO warning if MLP attention is used but n layers and n hidden not given that using default values
150 |         # TODO check that there is at least one covariate to predict
151 | if scoring == "MLP":
152 | if not n_layers_mlp_attn:
153 | n_layers_mlp_attn = 1
154 | if not n_hidden_mlp_attn:
155 | n_hidden_mlp_attn = 16
156 |
157 | self.regression_idx = []
158 | if len(cont_covs := self.adata_manager.get_state_registry(REGISTRY_KEYS.CONT_COVS_KEY)) > 0:
159 | for key in cont_covs["columns"]:
160 | if key in self.regression:
161 | self.regression_idx.append(list(cont_covs["columns"]).index(key))
162 | else: # only can happen when using multivae_mil
163 | if ignore_covariates is not None and key not in ignore_covariates:
164 | warnings.warn(
165 | f"Registered continuous covariate '{key}' is not in regression covariates so will be ignored.",
166 | stacklevel=2,
167 | )
168 |
169 | # classification and ordinal regression together here as ordinal regression values need to be registered as categorical covariates
170 | self.class_idx, self.ord_idx = [], []
171 | self.num_classification_classes = []
172 | if len(cat_covs := self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)) > 0:
173 | for i, num_cat in enumerate(cat_covs.n_cats_per_key):
174 | cat_cov_name = cat_covs["field_keys"][i]
175 | if cat_cov_name in self.classification:
176 | self.num_classification_classes.append(num_cat)
177 | self.class_idx.append(i)
178 | elif cat_cov_name in self.ordinal_regression:
179 | self.num_classification_classes.append(num_cat)
180 | self.ord_idx.append(i)
181 | else: # the actual categorical covariate, only can happen when using multivae_mil
182 | if (
183 | ignore_covariates is not None
184 | and cat_cov_name not in ignore_covariates
185 | and cat_cov_name != self.sample_key
186 | ):
187 | warnings.warn(
188 | f"Registered categorical covariate '{cat_cov_name}' is not in classification or ordinal regression covariates and is not the sample covariate so will be ignored.",
189 | stacklevel=2,
190 | )
191 |
192 | for label in ordinal_regression:
193 | print(
194 | f"The order for {label} ordinal classes is: {adata.obs[label].cat.categories}. If you need to change the order, please rerun setup_anndata and specify the correct order with the `ordinal_regression_order` parameter."
195 | )
196 |
197 | self.class_idx = torch.tensor(self.class_idx)
198 | self.ord_idx = torch.tensor(self.ord_idx)
199 | self.regression_idx = torch.tensor(self.regression_idx)
200 |
201 | self.module = MILClassifierTorch(
202 | z_dim=z_dim,
203 | dropout=dropout,
204 | activation=activation,
205 | initialization=initialization,
206 | normalization=normalization,
207 | num_classification_classes=self.num_classification_classes,
208 | scoring=scoring,
209 | attn_dim=attn_dim,
210 | n_layers_cell_aggregator=n_layers_cell_aggregator,
211 | n_layers_classifier=n_layers_classifier,
212 | n_layers_mlp_attn=n_layers_mlp_attn,
213 | n_layers_regressor=n_layers_regressor,
214 | n_hidden_regressor=n_hidden_regressor,
215 | n_hidden_cell_aggregator=n_hidden_cell_aggregator,
216 | n_hidden_classifier=n_hidden_classifier,
217 | n_hidden_mlp_attn=n_hidden_mlp_attn,
218 | class_loss_coef=class_loss_coef,
219 | regression_loss_coef=regression_loss_coef,
220 | sample_batch_size=sample_batch_size,
221 | class_idx=self.class_idx,
222 | ord_idx=self.ord_idx,
223 | reg_idx=self.regression_idx,
224 | anneal_class_loss=anneal_class_loss,
225 | )
226 |
227 | self.init_params_ = self._get_init_params(locals())
228 |
229 | def train(
230 | self,
231 | max_epochs: int = 200,
232 | lr: float = 5e-4,
233 | use_gpu: str | int | bool | None = None,
234 | train_size: float = 0.9,
235 | validation_size: float | None = None,
236 | batch_size: int = 256,
237 | weight_decay: float = 1e-3,
238 | eps: float = 1e-08,
239 | early_stopping: bool = True,
240 | save_best: bool = True,
241 | check_val_every_n_epoch: int | None = None,
242 | n_epochs_kl_warmup: int | None = None,
243 | n_steps_kl_warmup: int | None = None,
244 | adversarial_mixing: bool = False,
245 | plan_kwargs: dict | None = None,
246 | early_stopping_monitor: str | None = "accuracy_validation",
247 | early_stopping_mode: str | None = "max",
248 | save_checkpoint_every_n_epochs: int | None = None,
249 | path_to_checkpoints: str | None = None,
250 | **kwargs,
251 | ):
252 |         """Train the MIL classifier.
253 |
254 | Parameters
255 | ----------
256 | max_epochs
257 | Number of passes through the dataset.
258 | lr
259 | Learning rate for optimization.
260 | use_gpu
261 | Use default GPU if available (if None or True), or index of GPU to use (if int),
262 | or name of GPU (if str), or use CPU (if False).
263 | train_size
264 | Size of training set in the range [0.0, 1.0].
265 | validation_size
266 |             Size of the validation set. If `None`, defaults to 1 - `train_size`. If
267 | `train_size + validation_size < 1`, the remaining cells belong to a test set.
268 | batch_size
269 | Minibatch size to use during training.
270 | weight_decay
271 | weight decay regularization term for optimization
272 | eps
273 | Optimizer eps
274 | early_stopping
275 | Whether to perform early stopping with respect to the validation set.
276 | save_best
277 | Save the best model state with respect to the validation loss, or use the final
278 | state in the training procedure
279 | check_val_every_n_epoch
280 | Check val every n train epochs. By default, val is not checked, unless `early_stopping` is `True`.
281 | If so, val is checked every epoch.
282 | n_epochs_kl_warmup
283 | Number of epochs to scale weight on KL divergences from 0 to 1.
284 | Overrides `n_steps_kl_warmup` when both are not `None`. Default is 1/3 of `max_epochs`.
285 | n_steps_kl_warmup
286 | Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1.
287 | Only activated when `n_epochs_kl_warmup` is set to None. If `None`, defaults
288 | to `floor(0.75 * adata.n_obs)`
289 | adversarial_mixing
290 | Whether to use adversarial mixing. Default is False.
291 | plan_kwargs
292 | Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to
293 | `train()` will overwrite values present in `plan_kwargs`, when appropriate.
294 | early_stopping_monitor
295 | Metric to monitor for early stopping. Default is "accuracy_validation".
296 | early_stopping_mode
297 | One of "min" or "max". Default is "max".
298 | save_checkpoint_every_n_epochs
299 | Save a checkpoint every n epochs.
300 | path_to_checkpoints
301 | Path to save checkpoints.
302 | **kwargs
303 | Other keyword args for :class:`~scvi.train.Trainer`.
304 |
305 | Returns
306 | -------
307 | Trainer object.
308 | """
309 | # TODO put in a function, return params needed for splitter, plan and runner, then can call the function from multivae_mil
310 | if len(self.regression) > 0:
311 | if early_stopping_monitor == "accuracy_validation":
312 | warnings.warn(
313 | "Setting early_stopping_monitor to 'regression_loss_validation' and early_stopping_mode to 'min' as regression is used.",
314 | stacklevel=2,
315 | )
316 | early_stopping_monitor = "regression_loss_validation"
317 | early_stopping_mode = "min"
318 | if n_epochs_kl_warmup is None:
319 | n_epochs_kl_warmup = max(max_epochs // 3, 1)
320 | update_dict = {
321 | "lr": lr,
322 | "adversarial_classifier": adversarial_mixing,
323 | "weight_decay": weight_decay,
324 | "eps": eps,
325 | "n_epochs_kl_warmup": n_epochs_kl_warmup,
326 | "n_steps_kl_warmup": n_steps_kl_warmup,
327 | "optimizer": "AdamW",
328 | "scale_adversarial_loss": 1,
329 | }
330 | if plan_kwargs is not None:
331 | plan_kwargs.update(update_dict)
332 | else:
333 | plan_kwargs = update_dict
334 |
335 | if save_best:
336 | if "callbacks" not in kwargs.keys():
337 | kwargs["callbacks"] = []
338 | kwargs["callbacks"].append(SaveBestState(monitor=early_stopping_monitor, mode=early_stopping_mode))
339 |
340 | if save_checkpoint_every_n_epochs is not None:
341 | if path_to_checkpoints is not None:
342 | kwargs["callbacks"].append(
343 | ModelCheckpoint(
344 | dirpath=path_to_checkpoints,
345 | save_top_k=-1,
346 | monitor="epoch",
347 | every_n_epochs=save_checkpoint_every_n_epochs,
348 | verbose=True,
349 | )
350 | )
351 | else:
352 | raise ValueError(
353 | f"`save_checkpoint_every_n_epochs` = {save_checkpoint_every_n_epochs} so `path_to_checkpoints` has to be not None but is {path_to_checkpoints}."
354 | )
355 | # until here
356 |
357 | data_splitter = GroupDataSplitter(
358 | self.adata_manager,
359 | group_column=self.sample_key,
360 | train_size=train_size,
361 | validation_size=validation_size,
362 | batch_size=batch_size,
363 | use_gpu=use_gpu,
364 | )
365 |
366 | training_plan = AdversarialTrainingPlan(self.module, **plan_kwargs)
367 | runner = TrainRunner(
368 | self,
369 | training_plan=training_plan,
370 | data_splitter=data_splitter,
371 | max_epochs=max_epochs,
372 | use_gpu=use_gpu,
373 | early_stopping=early_stopping,
374 | check_val_every_n_epoch=check_val_every_n_epoch,
375 | early_stopping_monitor=early_stopping_monitor,
376 | early_stopping_mode=early_stopping_mode,
377 | early_stopping_patience=50,
378 | enable_checkpointing=True,
379 | **kwargs,
380 | )
381 | return runner()
382 |
383 | @classmethod
384 | def setup_anndata(
385 | cls,
386 | adata: ad.AnnData,
387 | categorical_covariate_keys: list[str] | None = None,
388 | continuous_covariate_keys: list[str] | None = None,
389 | ordinal_regression_order: dict[str, list[str]] | None = None,
390 | **kwargs,
391 | ):
392 | """Set up :class:`~anndata.AnnData` object.
393 |
394 | A mapping will be created between data fields used by ``scvi`` to their respective locations in adata.
395 | This method will also compute the log mean and log variance per batch for the library size prior.
396 | None of the data in adata are modified. Only adds fields to adata.
397 |
398 | Parameters
399 | ----------
400 | adata
401 | AnnData object containing raw counts. Rows represent cells, columns represent features.
402 | categorical_covariate_keys
403 | Keys in `adata.obs` that correspond to categorical data.
404 | continuous_covariate_keys
405 | Keys in `adata.obs` that correspond to continuous data.
406 | ordinal_regression_order
407 | Dictionary with regression classes as keys and order of classes as values.
408 | kwargs
409 | Additional parameters to pass to register_fields() of AnnDataManager.
410 | """
411 | setup_ordinal_regression(adata, ordinal_regression_order, categorical_covariate_keys)
412 |
413 | setup_method_args = cls._get_setup_method_args(**locals())
414 |
415 | anndata_fields = [
416 | fields.LayerField(
417 | REGISTRY_KEYS.X_KEY,
418 | layer=None,
419 | ),
420 | fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, None),
421 | fields.CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys),
422 | fields.NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys),
423 | ]
424 |
425 | adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
426 | adata_manager.register_fields(adata, **kwargs)
427 | cls.register_manager(adata_manager)
428 |
429 | @torch.inference_mode()
430 | def get_model_output(
431 | self,
432 | adata=None,
433 | batch_size=256,
434 | ):
435 | """Save the attention scores and predictions in the adata object.
436 |
437 | Parameters
438 | ----------
439 | adata
440 | AnnData object to run the model on. If `None`, the model's AnnData object is used.
441 | batch_size
442 | Minibatch size to use. Default is 256.
443 |
444 | """
445 | if not self.is_trained_:
446 | raise RuntimeError("Please train the model first.")
447 |
448 | adata = self._validate_anndata(adata)
449 |
450 | scdl = self._make_data_loader(
451 | adata=adata,
452 | batch_size=batch_size,
453 |             min_size_per_class=batch_size,  # hack to ensure that non-full batches are processed properly
454 | data_loader_class=GroupAnnDataLoader,
455 | shuffle=False,
456 | shuffle_classes=False,
457 | group_column=self.sample_key,
458 | drop_last=False,
459 | )
460 |
461 | cell_level_attn, bags = [], []
462 | class_pred, ord_pred, reg_pred = {}, {}, {}
463 | (
464 | bag_class_true,
465 | bag_class_pred,
466 | bag_reg_true,
467 | bag_reg_pred,
468 | bag_ord_true,
469 | bag_ord_pred,
470 | ) = ({}, {}, {}, {}, {}, {})
471 |
472 | bag_counter = 0
473 | cell_counter = 0
474 |
475 | for tensors in scdl:
476 | cont_key = REGISTRY_KEYS.CONT_COVS_KEY
477 | cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None
478 |
479 | cat_key = REGISTRY_KEYS.CAT_COVS_KEY
480 | cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None
481 |
482 | inference_inputs = self.module._get_inference_input(tensors)
483 | outputs = self.module.inference(**inference_inputs)
484 | pred = outputs["predictions"]
485 |
486 | # get attention for each cell in the bag
487 | if self.scoring in ["gated_attn", "attn"]:
488 | cell_attn = self.module.cell_level_aggregator[-1].A.squeeze(dim=1)
489 | cell_attn = cell_attn.flatten() # in inference always one patient per batch
490 | cell_level_attn += [cell_attn.cpu()]
491 |
492 | assert outputs["z_joint"].shape[0] % pred[0].shape[0] == 0
493 | sample_size = outputs["z_joint"].shape[0] // pred[0].shape[0] # how many cells in patient_minibatch
494 | minibatch_size, n_samples_in_batch = prep_minibatch(cat_covs, self.module.sample_batch_size)
495 | regression = select_covariates(cont_covs, self.regression_idx, n_samples_in_batch)
496 | ordinal_regression = select_covariates(cat_covs, self.ord_idx, n_samples_in_batch)
497 | classification = select_covariates(cat_covs, self.class_idx, n_samples_in_batch)
498 |
499 | # calculate accuracies of predictions
500 | bag_class_pred, bag_class_true, class_pred = get_predictions(
501 | self.class_idx, pred, classification, sample_size, bag_class_pred, bag_class_true, class_pred
502 | )
503 | bag_ord_pred, bag_ord_true, ord_pred = get_predictions(
504 | self.ord_idx,
505 | pred,
506 | ordinal_regression,
507 | sample_size,
508 | bag_ord_pred,
509 | bag_ord_true,
510 | ord_pred,
511 | len(self.class_idx),
512 | )
513 | bag_reg_pred, bag_reg_true, reg_pred = get_predictions(
514 | self.regression_idx,
515 | pred,
516 | regression,
517 | sample_size,
518 | bag_reg_pred,
519 | bag_reg_true,
520 | reg_pred,
521 | len(self.class_idx) + len(self.ord_idx),
522 | )
523 |
524 | # TODO remove n_of_bags_in_batch after testing
525 | n_of_bags_in_batch = pred[0].shape[0]
526 | assert n_samples_in_batch == n_of_bags_in_batch
527 |
528 | # save bag info to be able to calculate bag predictions later
529 | bags, cell_counter, bag_counter = get_bag_info(
530 | bags, n_samples_in_batch, minibatch_size, cell_counter, bag_counter, self.module.sample_batch_size
531 | )
532 |
533 | if self.scoring in ["gated_attn", "attn"]:
534 | cell_level = torch.cat(cell_level_attn).numpy()
535 | adata.obs["cell_attn"] = cell_level
536 | flat_bags = [value for sublist in bags for value in sublist]
537 | adata.obs["bags"] = flat_bags
538 |
539 | for i in range(len(self.class_idx)):
540 | name = self.classification[i]
541 | class_names = self.adata_manager.get_state_registry("extra_categorical_covs")["mappings"][name]
542 | save_predictions_in_adata(
543 | adata,
544 | i,
545 | self.classification,
546 | bag_class_pred,
547 | bag_class_true,
548 | class_pred,
549 | class_names,
550 | name,
551 | clip="argmax",
552 | )
553 | for i in range(len(self.ord_idx)):
554 | name = self.ordinal_regression[i]
555 | class_names = self.adata_manager.get_state_registry("extra_categorical_covs")["mappings"][name]
556 | save_predictions_in_adata(
557 | adata, i, self.ordinal_regression, bag_ord_pred, bag_ord_true, ord_pred, class_names, name, clip="clip"
558 | )
559 | for i in range(len(self.regression_idx)):
560 | name = self.regression[i]
561 | reg_names = self.adata_manager.get_state_registry("extra_continuous_covs")["columns"]
562 | save_predictions_in_adata(
563 | adata, i, self.regression, bag_reg_pred, bag_reg_true, reg_pred, reg_names, name, clip=None, reg=True
564 | )
565 |
566 | def plot_losses(self, save=None):
567 | """Plot losses.
568 |
569 | Parameters
570 | ----------
571 | save
572 | If not None, save the plot to this location.
573 | """
574 | loss_names = self.module.select_losses_to_plot()
575 | plt_plot_losses(self.history, loss_names, save)
576 |
577 | # adjusted from scvi-tools
578 | # https://github.com/scverse/scvi-tools/blob/0b802762869c43c9f49e69fe62b1a5a9b5c4dae6/scvi/model/base/_archesmixin.py#L30
579 | # accessed on 7 November 2022
580 | @classmethod
581 | def load_query_data(
582 | cls,
583 | adata: AnnData,
584 | reference_model: BaseModelClass,
585 | use_gpu: str | int | bool | None = None,
586 | ) -> BaseModelClass:
587 | """Online update of a reference model with scArches algorithm # TODO cite.
588 |
589 | Parameters
590 | ----------
591 | adata
592 | AnnData organized in the same way as data used to train model.
593 | It is not necessary to run setup_anndata,
594 | as AnnData is validated against the ``registry``.
595 | reference_model
596 | Already instantiated model of the same class.
597 | use_gpu
598 | Load model on default GPU if available (if None or True),
599 | or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).
600 |
601 | Returns
602 | -------
603 | Model with updated architecture and weights.
604 | """
605 | # currently this function works only if the prediction cov is present in the .obs of the query
606 | # TODO need to allow it to be missing, maybe add a dummy column to .obs of query adata
607 |
608 | _, _, device = parse_use_gpu_arg(use_gpu)
609 |
610 | attr_dict, _, _ = _get_loaded_data(reference_model, device=device)
611 |
612 | registry = attr_dict.pop("registry_")
613 | if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__:
614 | raise ValueError("It appears you are loading a model from a different class.")
615 |
616 | if _SETUP_ARGS_KEY not in registry:
617 | raise ValueError("Saved model does not contain original setup inputs. " "Cannot load the original setup.")
618 |
619 | cls.setup_anndata(
620 | adata,
621 | source_registry=registry,
622 | extend_categories=True,
623 | allow_missing_labels=True,
624 | **registry[_SETUP_ARGS_KEY],
625 | )
626 |
627 | model = _initialize_model(cls, adata, attr_dict)
628 | model.module.load_state_dict(reference_model.module.state_dict())
629 | model.to_device(device)
630 |
631 | model.module.eval()
632 | model.is_trained_ = True
633 |
634 | return model
635 |
--------------------------------------------------------------------------------
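
A minimal end-to-end sketch of the MILClassifier API. The names are hypothetical: adata_latent is assumed to be an AnnData whose .X holds per-cell embeddings (e.g. MultiVAE output) and whose .obs has a 'sample' column (bag identifier) and a 'disease' column (sample-level label).

from multimil.model import MILClassifier

# register the embedding AnnData and the covariates used for bagging and prediction
MILClassifier.setup_anndata(
    adata_latent,
    categorical_covariate_keys=["sample", "disease"],
)

mil = MILClassifier(
    adata_latent,
    sample_key="sample",          # bags are formed per sample
    classification=["disease"],   # sample-level covariate to predict
    z_dim=adata_latent.shape[1],  # dimensionality of the input embeddings
)
mil.train(max_epochs=100)

# writes per-cell attention ('cell_attn') and bag-level predictions into adata_latent.obs
mil.get_model_output(adata_latent)
print(adata_latent.obs["cell_attn"].describe())
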
/src/multimil/model/_multivae.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Literal
3 |
4 | import anndata as ad
5 | import torch
6 | from pytorch_lightning.callbacks import ModelCheckpoint
7 | from scvi import REGISTRY_KEYS
8 | from scvi.data import AnnDataManager, fields
9 | from scvi.data._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY
10 | from scvi.dataloaders import DataSplitter
11 | from scvi.model._utils import parse_use_gpu_arg
12 | from scvi.model.base import ArchesMixin, BaseModelClass
13 | from scvi.model.base._archesmixin import _get_loaded_data
14 | from scvi.model.base._utils import _initialize_model
15 | from scvi.train import AdversarialTrainingPlan, TrainRunner
16 | from scvi.train._callbacks import SaveBestState
17 |
18 | from multimil.dataloaders import GroupDataSplitter
19 | from multimil.module import MultiVAETorch
20 | from multimil.utils import calculate_size_factor, plt_plot_losses
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 |
25 | class MultiVAE(BaseModelClass, ArchesMixin):
26 | """MultiMIL multimodal integration model.
27 |
28 | Parameters
29 | ----------
30 | adata
31 |         AnnData object that has been registered via :meth:`~multimil.model.MultiVAE.setup_anndata`.
32 | integrate_on
33 |         One of the categorical covariates registered with :meth:`~multimil.model.MultiVAE.setup_anndata` to integrate on. The latent space will then be disentangled from this covariate. If `None`, no integration is performed.
34 | condition_encoders
35 |         Whether to concatenate covariate embeddings to the first layer of the encoders. Default is `True`.
36 | condition_decoders
37 |         Whether to concatenate covariate embeddings to the first layer of the decoders. Default is `True`.
38 | normalization
39 | What normalization to use; has to be one of `batch` or `layer`. Default is `layer`.
40 | z_dim
41 | Dimensionality of the latent space. Default is 16.
42 | losses
43 | Which losses to use for each modality. Has to be the same length as the number of modalities. Default is `MSE` for all modalities.
44 | dropout
45 | Dropout rate. Default is 0.2.
46 | cond_dim
47 | Dimensionality of the covariate embeddings. Default is 16.
48 | kernel_type
49 | Type of kernel to use for the MMD loss. Default is `gaussian`.
50 | loss_coefs
51 |         Loss coefficients for the different losses in the model. Default is 1 for all.
52 | cont_cov_type
53 |         How to calculate embeddings for continuous covariates. Default is `logsigm`.
54 | n_layers_cont_embed
55 | Number of layers for the continuous covariate embedding calculation. Default is 1.
56 | n_layers_encoders
57 | Number of layers for each encoder. Default is 2 for all modalities. Has to be the same length as the number of modalities.
58 | n_layers_decoders
59 | Number of layers for each decoder. Default is 2 for all modalities. Has to be the same length as the number of modalities.
60 | n_hidden_cont_embed
61 | Number of nodes for each hidden layer in the continuous covariate embedding calculation. Default is 32.
62 | n_hidden_encoders
63 | Number of nodes for each hidden layer in the encoders. Default is 32.
64 | n_hidden_decoders
65 | Number of nodes for each hidden layer in the decoders. Default is 32.
66 | modality_alignment
67 | Whether to align the modalities, one of ['MMD', 'Jeffreys', None]. Default is `None`.
68 | alignment_type
69 | Which alignment type to use, one of ['latent', 'marginal', 'both']. Default is `latent`.
70 | activation
71 | Activation function to use. Default is `leaky_relu`.
72 | initialization
73 | Initialization method to use. Default is `None`.
74 | ignore_covariates
75 | List of covariates to ignore. Needed for query-to-reference mapping. Default is `None`.
76 | mix
77 | How to mix the distributions to get the joint, one of ['product', 'mixture']. Default is `product`.
78 | """
79 |
80 | def __init__(
81 | self,
82 | adata: ad.AnnData,
83 | integrate_on: str | None = None,
84 | condition_encoders: bool = True,
85 | condition_decoders: bool = True,
86 | normalization: Literal["layer", "batch", None] = "layer",
87 | z_dim: int = 16,
88 | losses: list[str] | None = None,
89 | dropout: float = 0.2,
90 | cond_dim: int = 16,
91 | kernel_type: Literal["gaussian", None] = "gaussian",
92 |         loss_coefs: dict[int, float] | None = None,
93 |         cont_cov_type: Literal["logsigm", "sigm", None] = "logsigm",
94 | n_layers_cont_embed: int = 1, # TODO default to None?
95 | n_layers_encoders: list[int] | None = None,
96 | n_layers_decoders: list[int] | None = None,
97 | n_hidden_cont_embed: int = 32, # TODO default to None?
98 | n_hidden_encoders: list[int] | None = None,
99 | n_hidden_decoders: list[int] | None = None,
100 | modality_alignment: Literal["MMD", "Jeffreys", None] = None,
101 | alignment_type: Literal["latent", "marginal", "both"] = "latent",
102 |         activation: str | None = "leaky_relu",  # TODO add which options are implemented
103 |         initialization: str | None = None,  # TODO add which options are implemented
104 | ignore_covariates: list[str] | None = None,
105 | mix: Literal["product", "mixture"] = "product",
106 | ):
107 | super().__init__(adata)
108 |
109 | # for the integration with the alignment loss
110 | self.group_column = integrate_on
111 |
112 | # TODO: add options for number of hidden layers, hidden layers dim and output activation functions
113 | if normalization not in ["layer", "batch", None]:
114 | raise ValueError('Normalization has to be one of ["layer", "batch", None]')
115 | # TODO: do some assertions for other parameters
116 |
117 | if ignore_covariates is None:
118 | ignore_covariates = []
119 |
120 | if (
121 |             losses is not None and ("nb" in losses or "zinb" in losses)
122 | ) and REGISTRY_KEYS.SIZE_FACTOR_KEY not in self.adata_manager.data_registry:
123 | raise ValueError(f"Have to register {REGISTRY_KEYS.SIZE_FACTOR_KEY} when using 'nb' or 'zinb' loss.")
124 |
125 | self.num_groups = 1
126 | self.integrate_on_idx = None
127 | if integrate_on is not None:
128 | if integrate_on not in self.adata_manager.registry["setup_args"]["categorical_covariate_keys"]:
129 | raise ValueError(
130 | "Cannot integrate on {!r}, has to be one of the registered categorical covariates = {}".format(
131 | integrate_on, self.adata_manager.registry["setup_args"]["categorical_covariate_keys"]
132 | )
133 | )
134 | elif integrate_on in ignore_covariates:
135 | raise ValueError(
136 | f"Specified integrate_on = {integrate_on!r} is in ignore_covariates = {ignore_covariates}."
137 | )
138 | else:
139 | self.num_groups = len(
140 | self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)["mappings"][integrate_on]
141 | )
142 | self.integrate_on_idx = self.adata_manager.registry["setup_args"]["categorical_covariate_keys"].index(
143 | integrate_on
144 | )
145 |
146 | self.modality_lengths = [
147 | adata.uns["modality_lengths"][key] for key in sorted(adata.uns["modality_lengths"].keys())
148 | ]
149 |
150 | self.cont_covs_idx = []
151 | self.cont_covariate_dims = []
152 | if len(cont_covs := self.adata_manager.get_state_registry(REGISTRY_KEYS.CONT_COVS_KEY)) > 0:
153 | for i, key in enumerate(cont_covs["columns"]):
154 | if key not in ignore_covariates:
155 | self.cont_covs_idx.append(i)
156 | self.cont_covariate_dims.append(1)
157 |
158 | self.cat_covs_idx = []
159 | self.cat_covariate_dims = []
160 | if len(cat_covs := self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)) > 0:
161 | for i, num_cat in enumerate(cat_covs.n_cats_per_key):
162 | if cat_covs["field_keys"][i] not in ignore_covariates:
163 | self.cat_covs_idx.append(i)
164 | self.cat_covariate_dims.append(num_cat)
165 |
166 | self.cat_covs_idx = torch.tensor(self.cat_covs_idx)
167 | self.cont_covs_idx = torch.tensor(self.cont_covs_idx)
168 |
169 | self.module = MultiVAETorch(
170 | modality_lengths=self.modality_lengths,
171 | condition_encoders=condition_encoders,
172 | condition_decoders=condition_decoders,
173 | normalization=normalization,
174 | z_dim=z_dim,
175 | losses=losses,
176 | dropout=dropout,
177 | cond_dim=cond_dim,
178 | kernel_type=kernel_type,
179 | loss_coefs=loss_coefs,
180 | num_groups=self.num_groups,
181 | integrate_on_idx=self.integrate_on_idx,
182 | cat_covs_idx=self.cat_covs_idx,
183 | cont_covs_idx=self.cont_covs_idx,
184 | cat_covariate_dims=self.cat_covariate_dims,
185 | cont_covariate_dims=self.cont_covariate_dims,
186 | n_layers_encoders=n_layers_encoders,
187 | n_layers_decoders=n_layers_decoders,
188 | n_hidden_encoders=n_hidden_encoders,
189 | n_hidden_decoders=n_hidden_decoders,
190 | cont_cov_type=cont_cov_type,
191 | n_layers_cont_embed=n_layers_cont_embed,
192 | n_hidden_cont_embed=n_hidden_cont_embed,
193 | modality_alignment=modality_alignment,
194 | alignment_type=alignment_type,
195 | activation=activation,
196 | initialization=initialization,
197 | mix=mix,
198 | )
199 |
200 | self.init_params_ = self._get_init_params(locals())
201 |
202 | @torch.inference_mode()
203 | def impute(self, adata=None, batch_size=256):
204 | """Impute missing values in the adata object.
205 |
206 | Parameters
207 | ----------
208 | adata
209 | AnnData object to run the model on. If `None`, the model's AnnData object is used.
210 | batch_size
211 | Minibatch size to use. Default is 256.
212 | """
213 | if not self.is_trained_:
214 | raise RuntimeError("Please train the model first.")
215 |
216 | adata = self._validate_anndata(adata)
217 |
218 | scdl = self._make_data_loader(adata=adata, batch_size=batch_size)
219 |
220 | imputed = [[] for _ in range(len(self.modality_lengths))]
221 |
222 | for tensors in scdl:
223 | inference_inputs = self.module._get_inference_input(tensors)
224 | inference_outputs = self.module.inference(**inference_inputs)
225 | generative_inputs = self.module._get_generative_input(tensors, inference_outputs)
226 | outputs = self.module.generative(**generative_inputs)
227 | for i, output in enumerate(outputs["rs"]):
228 | imputed[i] += [output.cpu()]
229 | for i in range(len(imputed)):
230 | imputed[i] = torch.cat(imputed[i]).numpy()
231 | adata.obsm[f"imputed_modality_{i}"] = imputed[i]
232 |
233 | @torch.inference_mode()
234 | def get_model_output(self, adata=None, batch_size=256):
235 | """Save the latent representation in the adata object.
236 |
237 | Parameters
238 | ----------
239 | adata
240 | AnnData object to run the model on. If `None`, the model's AnnData object is used.
241 | batch_size
242 | Minibatch size to use. Default is 256.
243 | """
244 | if not self.is_trained_:
245 | raise RuntimeError("Please train the model first.")
246 |
247 | adata = self._validate_anndata(adata)
248 |
249 | scdl = self._make_data_loader(adata=adata, batch_size=batch_size)
250 |
251 | latent = []
252 | for tensors in scdl:
253 | inference_inputs = self.module._get_inference_input(tensors)
254 | outputs = self.module.inference(**inference_inputs)
255 | z = outputs["z_joint"]
256 | latent += [z.cpu()]
257 |
258 | adata.obsm["X_multiMIL"] = torch.cat(latent).numpy()
259 |
260 | def train(
261 | self,
262 | max_epochs: int = 200,
263 | lr: float = 5e-4,
264 | use_gpu: str | int | bool | None = None,
265 | train_size: float = 0.9,
266 | validation_size: float | None = None,
267 | batch_size: int = 256,
268 | weight_decay: float = 1e-3,
269 | eps: float = 1e-08,
270 | early_stopping: bool = True,
271 | save_best: bool = True,
272 | check_val_every_n_epoch: int | None = None,
273 | n_epochs_kl_warmup: int | None = None,
274 | n_steps_kl_warmup: int | None = None,
275 |         adversarial_mixing: bool = False,  # TODO: check whether adversarial mixing is actually supported; it likely is not
276 | plan_kwargs: dict | None = None,
277 | save_checkpoint_every_n_epochs: int | None = None,
278 | path_to_checkpoints: str | None = None,
279 | **kwargs,
280 | ):
281 | """Train the model using amortized variational inference.
282 |
283 | Parameters
284 | ----------
285 | max_epochs
286 | Number of passes through the dataset.
287 | lr
288 | Learning rate for optimization.
289 | use_gpu
290 | Use default GPU if available (if None or True), or index of GPU to use (if int),
291 | or name of GPU (if str), or use CPU (if False).
292 | train_size
293 | Size of training set in the range [0.0, 1.0].
294 | validation_size
295 |             Size of the validation set. If `None`, defaults to `1 - train_size`. If
296 | `train_size + validation_size < 1`, the remaining cells belong to a test set.
297 | batch_size
298 | Minibatch size to use during training.
299 | weight_decay
300 | Weight decay regularization term for optimization.
301 | eps
302 | Optimizer eps.
303 | early_stopping
304 | Whether to perform early stopping with respect to the validation set.
305 | save_best
306 | Save the best model state with respect to the validation loss, or use the final
307 | state in the training procedure.
308 | check_val_every_n_epoch
309 |             Check the validation set every n training epochs. By default, the validation set is not checked,
310 |             unless `early_stopping` is `True`, in which case it is checked every epoch.
311 | n_epochs_kl_warmup
312 | Number of epochs to scale weight on KL divergences from 0 to 1.
313 | Overrides `n_steps_kl_warmup` when both are not `None`. Default is 1/3 of `max_epochs`.
314 | n_steps_kl_warmup
315 | Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1.
316 | Only activated when `n_epochs_kl_warmup` is set to None. If `None`, defaults
317 | to `floor(0.75 * adata.n_obs)`.
318 | adversarial_mixing
319 | Whether to use adversarial mixing. Default is `False`.
320 | plan_kwargs
321 | Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to
322 | `train()` will overwrite values present in `plan_kwargs`, when appropriate.
323 | save_checkpoint_every_n_epochs
324 | Save a checkpoint every n epochs. If `None`, no checkpoints are saved.
325 | path_to_checkpoints
326 | Path to save checkpoints. Required if `save_checkpoint_every_n_epochs` is not `None`.
327 | kwargs
328 | Additional keyword arguments for :class:`~scvi.train.TrainRunner`.
329 |
330 | Returns
331 | -------
332 | Trainer object.
333 | """
334 | if n_epochs_kl_warmup is None:
335 | n_epochs_kl_warmup = max(max_epochs // 3, 1)
336 | update_dict = {
337 | "lr": lr,
338 | "adversarial_classifier": adversarial_mixing,
339 | "weight_decay": weight_decay,
340 | "eps": eps,
341 | "n_epochs_kl_warmup": n_epochs_kl_warmup,
342 | "n_steps_kl_warmup": n_steps_kl_warmup,
343 | "optimizer": "AdamW",
344 | "scale_adversarial_loss": 1,
345 | }
346 | if plan_kwargs is not None:
347 | plan_kwargs.update(update_dict)
348 | else:
349 | plan_kwargs = update_dict
350 |
351 | if save_best:
352 | if "callbacks" not in kwargs.keys():
353 | kwargs["callbacks"] = []
354 | kwargs["callbacks"].append(SaveBestState(monitor="reconstruction_loss_validation"))
355 |
356 | if save_checkpoint_every_n_epochs is not None:
357 | if path_to_checkpoints is not None:
358 |                 kwargs.setdefault("callbacks", []).append(
359 | ModelCheckpoint(
360 | dirpath=path_to_checkpoints,
361 | save_top_k=-1,
362 | monitor="epoch",
363 | every_n_epochs=save_checkpoint_every_n_epochs,
364 | verbose=True,
365 | )
366 | )
367 | else:
368 | raise ValueError(
369 |                     f"`save_checkpoint_every_n_epochs` was set to {save_checkpoint_every_n_epochs}, so `path_to_checkpoints` must not be None, but got {path_to_checkpoints}."
370 | )
371 |
372 | if self.group_column is not None:
373 | data_splitter = GroupDataSplitter(
374 | self.adata_manager,
375 | group_column=self.group_column,
376 | train_size=train_size,
377 | validation_size=validation_size,
378 | batch_size=batch_size,
379 | use_gpu=use_gpu,
380 | )
381 | else:
382 | data_splitter = DataSplitter(
383 | self.adata_manager,
384 | train_size=train_size,
385 | validation_size=validation_size,
386 | batch_size=batch_size,
387 | use_gpu=use_gpu,
388 | )
389 | training_plan = AdversarialTrainingPlan(self.module, **plan_kwargs)
390 | runner = TrainRunner(
391 | self,
392 | training_plan=training_plan,
393 | data_splitter=data_splitter,
394 | max_epochs=max_epochs,
395 | use_gpu=use_gpu,
396 | early_stopping=early_stopping,
397 | check_val_every_n_epoch=check_val_every_n_epoch,
398 | early_stopping_monitor="reconstruction_loss_validation",
399 | early_stopping_patience=50,
400 | enable_checkpointing=True,
401 | **kwargs,
402 | )
403 | return runner()
404 |
405 | @classmethod
406 | def setup_anndata(
407 | cls,
408 | adata: ad.AnnData,
409 | size_factor_key: str | None = None,
410 | rna_indices_end: int | None = None,
411 | categorical_covariate_keys: list[str] | None = None,
412 | continuous_covariate_keys: list[str] | None = None,
413 | **kwargs,
414 | ):
415 | """Set up :class:`~anndata.AnnData` object.
416 |
417 |         A mapping will be created between data fields used by ``scvi`` and their respective locations in adata.
418 |         If ``size_factor_key`` is not provided, a size factor is computed from the RNA counts.
419 |         None of the data in adata is modified; this method only adds fields to adata.
420 |
421 | Parameters
422 | ----------
423 | adata
424 | AnnData object containing raw counts. Rows represent cells, columns represent features.
425 | size_factor_key
426 | Key in `adata.obs` containing the size factor. If `None`, will be calculated from the RNA counts.
427 | rna_indices_end
428 |             Integer to indicate where the RNA features end in the AnnData object. Used to calculate the ``library_size``.
429 | categorical_covariate_keys
430 | Keys in `adata.obs` that correspond to categorical data.
431 | continuous_covariate_keys
432 | Keys in `adata.obs` that correspond to continuous data.
433 | kwargs
434 | Additional parameters to pass to register_fields() of AnnDataManager.
435 | """
436 | setup_method_args = cls._get_setup_method_args(**locals())
437 |
438 | anndata_fields = [
439 | fields.LayerField(
440 | REGISTRY_KEYS.X_KEY,
441 | layer=None,
442 | ),
443 | fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, None), # TODO check if need to add if it's None
444 | fields.CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys),
445 | fields.NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys),
446 | ]
447 | size_factor_key = calculate_size_factor(adata, size_factor_key, rna_indices_end)
448 | anndata_fields.append(fields.NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key))
449 |
450 | adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
451 | adata_manager.register_fields(adata, **kwargs)
452 | cls.register_manager(adata_manager)
453 |
454 | def plot_losses(self, save=None):
455 | """Plot losses.
456 |
457 | Parameters
458 | ----------
459 | save
460 | If not None, save the plot to this location.
461 | """
462 | loss_names = self.module.select_losses_to_plot()
463 | plt_plot_losses(self.history, loss_names, save)
464 |
465 | # adjusted from scvi-tools
466 | # https://github.com/scverse/scvi-tools/blob/0b802762869c43c9f49e69fe62b1a5a9b5c4dae6/scvi/model/base/_archesmixin.py#L30
467 | # accessed on 5 November 2022
468 | @classmethod
469 | def load_query_data(
470 | cls,
471 | adata: ad.AnnData,
472 | reference_model: BaseModelClass,
473 | use_gpu: str | int | bool | None = None,
474 | freeze: bool = True,
475 | ignore_covariates: list[str] | None = None,
476 | ) -> BaseModelClass:
477 |         """Online update of a reference model with the scArches algorithm. # TODO cite
478 |
479 | Parameters
480 | ----------
481 | adata
482 | AnnData organized in the same way as data used to train model.
483 | It is not necessary to run setup_anndata,
484 | as AnnData is validated against the ``registry``.
485 | reference_model
486 | Already instantiated model of the same class.
487 | use_gpu
488 | Load model on default GPU if available (if None or True),
489 | or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).
490 | freeze
491 | Whether to freeze the encoders and decoders and only train the new weights.
492 | ignore_covariates
493 | List of covariates to ignore. Needed for query-to-reference mapping. Default is `None`.
494 |
495 | Returns
496 | -------
497 | Model with updated architecture and weights.
498 | """
499 | _, _, device = parse_use_gpu_arg(use_gpu)
500 |
501 | attr_dict, _, load_state_dict = _get_loaded_data(reference_model, device=device)
502 |
503 | registry = attr_dict.pop("registry_")
504 | if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__:
505 | raise ValueError("It appears you are loading a model from a different class.")
506 |
507 | if _SETUP_ARGS_KEY not in registry:
508 |             raise ValueError("Saved model does not contain original setup inputs. Cannot load the original setup.")
509 |
510 | if ignore_covariates is None:
511 | ignore_covariates = []
512 |
513 | cls.setup_anndata(
514 | adata,
515 | source_registry=registry,
516 | extend_categories=True,
517 | allow_missing_labels=True,
518 | **registry[_SETUP_ARGS_KEY],
519 | )
520 |
521 | model = _initialize_model(cls, adata, attr_dict)
522 | adata_manager = model.get_anndata_manager(adata, required=True)
523 |
524 | # TODO add an exception if need to add new categories but condition_encoders is False
525 | # model tweaking
526 | num_of_cat_to_add = [
527 | new_cat - old_cat
528 | for i, (old_cat, new_cat) in enumerate(
529 | zip(
530 | reference_model.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key,
531 | adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key,
532 | strict=False,
533 | )
534 | )
535 | if adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)["field_keys"][i] not in ignore_covariates
536 | ]
537 |
538 | model.to_device(device)
539 |
540 | new_state_dict = model.module.state_dict()
541 | for key, load_ten in load_state_dict.items(): # load_state_dict = old
542 | new_ten = new_state_dict[key]
543 | if new_ten.size() == load_ten.size():
544 | continue
545 | # new categoricals changed size
546 | else:
547 |                 new_shape = new_ten.shape
548 |                 old_shape = load_ten.shape
549 |                 if new_shape[0] == old_shape[0]:
550 | dim_diff = new_ten.size()[-1] - load_ten.size()[-1]
551 | fixed_ten = torch.cat([load_ten, new_ten[..., -dim_diff:]], dim=-1)
552 | else:
553 | dim_diff = new_ten.size()[0] - load_ten.size()[0]
554 | fixed_ten = torch.cat([load_ten, new_ten[-dim_diff:, ...]], dim=0)
555 | load_state_dict[key] = fixed_ten
556 |
557 | model.module.load_state_dict(load_state_dict)
558 |
559 | # freeze everything but the embeddings that have new categories
560 | if freeze is True:
561 | for _, par in model.module.named_parameters():
562 | par.requires_grad = False
563 | for i, embed in enumerate(model.module.cat_covariate_embeddings):
564 | if num_of_cat_to_add[i] > 0: # unfreeze the ones where categories were added
565 | embed.weight.requires_grad = True
566 | if model.module.integrate_on_idx is not None:
567 | model.module.theta.requires_grad = True
568 |
569 | model.module.eval()
570 | model.is_trained_ = False
571 |
572 | return model
573 |
--------------------------------------------------------------------------------
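A minimal usage sketch of the model class defined above. It assumes the class is exported as ``MultiVAE`` from ``multimil.model`` and follows the usual scvi-tools convention of taking the AnnData object as the first constructor argument; the file names, covariate keys and ``rna_indices_end`` value are illustrative placeholders.

import anndata as ad
from multimil.model import MultiVAE  # assumed export name

adata = ad.read_h5ad("paired_multiome.h5ad")  # modalities concatenated along var

MultiVAE.setup_anndata(
    adata,
    rna_indices_end=4000,  # hypothetical: first 4000 features are RNA
    categorical_covariate_keys=["batch"],
)
model = MultiVAE(adata)  # assumed constructor signature
model.train(max_epochs=200, lr=5e-4)
model.plot_losses()

model.get_model_output(adata)  # writes adata.obsm["X_multiMIL"]
model.impute(adata)            # writes adata.obsm["imputed_modality_<i>"] per modality

# query-to-reference mapping (scArches-style), see load_query_data above
query = ad.read_h5ad("query.h5ad")
query_model = MultiVAE.load_query_data(query, reference_model=model)
query_model.train(max_epochs=100)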
/src/multimil/module/__init__.py:
--------------------------------------------------------------------------------
1 | from ._multivae_torch import MultiVAETorch
2 | from ._mil_torch import MILClassifierTorch
3 | from ._multivae_mil_torch import MultiVAETorch_MIL
4 |
5 | __all__ = ["MultiVAETorch", "MILClassifierTorch", "MultiVAETorch_MIL"]
6 |
--------------------------------------------------------------------------------
/src/multimil/module/_mil_torch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from scvi import REGISTRY_KEYS
3 | from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | from multimil.nn import MLP, Aggregator
8 | from multimil.utils import prep_minibatch, select_covariates
9 |
10 |
11 | class MILClassifierTorch(BaseModuleClass):
12 | """MultiMIL's MIL classification module.
13 |
14 | Parameters
15 | ----------
16 | z_dim
17 | Latent dimension.
18 | dropout
19 | Dropout rate.
20 | normalization
21 | Normalization type.
22 | num_classification_classes
23 |         Number of classes for each of the classification tasks.
24 | scoring
25 | Scoring type. One of ["gated_attn", "attn", "mean", "max", "sum", "mlp"].
26 | attn_dim
27 | Hidden attention dimension.
28 | n_layers_cell_aggregator
29 | Number of layers in the cell aggregator.
30 | n_layers_classifier
31 | Number of layers in the classifier.
32 | n_layers_mlp_attn
33 | Number of layers in the MLP attention.
34 | n_layers_regressor
35 | Number of layers in the regressor.
36 | n_hidden_regressor
37 | Hidden dimension in the regressor.
38 | n_hidden_cell_aggregator
39 | Hidden dimension in the cell aggregator.
40 | n_hidden_classifier
41 | Hidden dimension in the classifier.
42 | n_hidden_mlp_attn
43 | Hidden dimension in the MLP attention.
44 | class_loss_coef
45 | Classification loss coefficient.
46 | regression_loss_coef
47 | Regression loss coefficient.
48 | sample_batch_size
49 | Sample batch size.
50 | class_idx
51 | Which indices in cat covariates to do classification on.
52 | ord_idx
53 | Which indices in cat covariates to do ordinal regression on.
54 | reg_idx
55 | Which indices in cont covariates to do regression on.
56 | activation
57 | Activation function.
58 | initialization
59 | Initialization type.
60 | anneal_class_loss
61 | Whether to anneal the classification loss.
62 | """
63 |
64 | def __init__(
65 | self,
66 | z_dim=16,
67 | dropout=0.2,
68 | normalization="layer",
69 |         num_classification_classes=None,  # number of classes for each of the classification tasks
70 | scoring="gated_attn",
71 | attn_dim=16,
72 | n_layers_cell_aggregator=1,
73 | n_layers_classifier=1,
74 | n_layers_mlp_attn=1,
75 | n_layers_regressor=1,
76 | n_hidden_regressor=16,
77 | n_hidden_cell_aggregator=16,
78 | n_hidden_classifier=16,
79 | n_hidden_mlp_attn=16,
80 | class_loss_coef=1.0,
81 | regression_loss_coef=1.0,
82 | sample_batch_size=128,
83 | class_idx=None, # which indices in cat covariates to do classification on, i.e. exclude from inference; this is a torch tensor
84 | ord_idx=None, # which indices in cat covariates to do ordinal regression on and also exclude from inference; this is a torch tensor
85 | reg_idx=None, # which indices in cont covariates to do regression on and also exclude from inference; this is a torch tensor
86 | activation="leaky_relu",
87 | initialization=None,
88 | anneal_class_loss=False,
89 | ):
90 | super().__init__()
91 |
92 | if activation == "leaky_relu":
93 | self.activation = nn.LeakyReLU
94 | elif activation == "tanh":
95 | self.activation = nn.Tanh
96 | else:
97 | raise NotImplementedError(
98 | f'activation should be one of ["leaky_relu", "tanh"], but activation={activation} was passed.'
99 | )
100 |
101 | self.class_loss_coef = class_loss_coef
102 | self.regression_loss_coef = regression_loss_coef
103 | self.sample_batch_size = sample_batch_size
104 | self.anneal_class_loss = anneal_class_loss
105 | self.num_classification_classes = num_classification_classes
106 | self.class_idx = class_idx
107 | self.ord_idx = ord_idx
108 | self.reg_idx = reg_idx
109 |
110 | self.cell_level_aggregator = nn.Sequential(
111 | MLP(
112 | z_dim,
113 | z_dim,
114 | n_layers=n_layers_cell_aggregator,
115 | n_hidden=n_hidden_cell_aggregator,
116 | dropout_rate=dropout,
117 | activation=self.activation,
118 | normalization=normalization,
119 | ),
120 | Aggregator(
121 | n_input=z_dim,
122 | scoring=scoring,
123 | attn_dim=attn_dim,
124 | sample_batch_size=sample_batch_size,
125 | scale=True,
126 | dropout=dropout,
127 | n_layers_mlp_attn=n_layers_mlp_attn,
128 | n_hidden_mlp_attn=n_hidden_mlp_attn,
129 | activation=self.activation,
130 | ),
131 | )
132 |
133 | if len(self.class_idx) > 0:
134 | self.classifiers = torch.nn.ModuleList()
135 |
136 | # classify zs directly
137 | class_input_dim = z_dim
138 |
139 | for num in self.num_classification_classes:
140 | if n_layers_classifier == 1:
141 | self.classifiers.append(nn.Linear(class_input_dim, num))
142 | else:
143 | self.classifiers.append(
144 | nn.Sequential(
145 | MLP(
146 | class_input_dim,
147 | n_hidden_classifier,
148 | n_layers=n_layers_classifier - 1,
149 | n_hidden=n_hidden_classifier,
150 | dropout_rate=dropout,
151 | activation=self.activation,
152 | ),
153 | nn.Linear(n_hidden_classifier, num),
154 | )
155 | )
156 |
157 | if len(self.ord_idx) + len(self.reg_idx) > 0:
158 | self.regressors = torch.nn.ModuleList()
159 | for _ in range(
160 | len(self.ord_idx) + len(self.reg_idx)
161 | ): # one head per standard regression and one per ordinal regression
162 | if n_layers_regressor == 1:
163 | self.regressors.append(nn.Linear(z_dim, 1))
164 | else:
165 | self.regressors.append(
166 | nn.Sequential(
167 | MLP(
168 | z_dim,
169 | n_hidden_regressor,
170 | n_layers=n_layers_regressor - 1,
171 | n_hidden=n_hidden_regressor,
172 | dropout_rate=dropout,
173 | activation=self.activation,
174 | ),
175 | nn.Linear(n_hidden_regressor, 1),
176 | )
177 | )
178 |
179 | if initialization == "xavier":
180 | for layer in self.modules():
181 | if isinstance(layer, nn.Linear):
182 | nn.init.xavier_uniform_(layer.weight, gain=nn.init.calculate_gain("leaky_relu"))
183 | elif initialization == "kaiming":
184 | for layer in self.modules():
185 | if isinstance(layer, nn.Linear):
186 | # following https://towardsdatascience.com/understand-kaiming-initialization-and-implementation-detail-in-pytorch-f7aa967e9138 (accessed 16.08.22)
187 | nn.init.kaiming_normal_(layer.weight, mode="fan_in")
188 |
189 | def _get_inference_input(self, tensors):
190 | x = tensors[REGISTRY_KEYS.X_KEY]
191 | return {"x": x}
192 |
193 | def _get_generative_input(self, tensors, inference_outputs):
194 | z_joint = inference_outputs["z_joint"]
195 | return {"z_joint": z_joint}
196 |
197 | @auto_move_data
198 | def inference(self, x) -> dict[str, torch.Tensor | list[torch.Tensor]]:
199 | """Forward pass for inference.
200 |
201 | Parameters
202 | ----------
203 | x
204 | Input.
205 |
206 | Returns
207 | -------
208 | Predictions.
209 | """
210 |         z_joint = x  # the input to the standalone MIL module is already a latent embedding
211 | inference_outputs = {"z_joint": z_joint}
212 |
213 | # MIL part
214 | batch_size = x.shape[0]
215 |
216 | idx = list(range(self.sample_batch_size, batch_size, self.sample_batch_size))
217 | if (
218 | batch_size % self.sample_batch_size != 0
219 | ): # can only happen during inference for last batches for each sample
220 | idx = []
221 | zs = torch.tensor_split(z_joint, idx, dim=0)
222 | zs = torch.stack(zs, dim=0) # num of bags x batch_size x z_dim
223 | zs_attn = self.cell_level_aggregator(zs) # num of bags x cond_dim
224 |
225 | predictions = []
226 | if len(self.class_idx) > 0:
227 | predictions.extend([classifier(zs_attn) for classifier in self.classifiers])
228 | if len(self.ord_idx) + len(self.reg_idx) > 0:
229 | predictions.extend([regressor(zs_attn) for regressor in self.regressors])
230 |
231 | inference_outputs.update(
232 | {"predictions": predictions}
233 | ) # predictions are a list as they can have different number of classes
234 |         return inference_outputs  # z_joint, predictions
235 |
236 | @auto_move_data
237 | def generative(self, z_joint) -> torch.Tensor:
238 | # TODO even if not used, make consistent with the rest, i.e. return dict
239 |         """Forward pass for the generative step; identity mapping in this module.
240 |
241 | Parameters
242 | ----------
243 | z_joint
244 | Latent embeddings.
245 |
246 | Returns
247 | -------
248 | Same as input.
249 | """
250 | return z_joint
251 |
252 | def _calculate_loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0):
253 | cont_key = REGISTRY_KEYS.CONT_COVS_KEY
254 | cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None
255 |
256 | cat_key = REGISTRY_KEYS.CAT_COVS_KEY
257 | cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None
258 |
259 | # MIL classification loss
260 | minibatch_size, n_samples_in_batch = prep_minibatch(cat_covs, self.sample_batch_size)
261 | regression = select_covariates(cont_covs, self.reg_idx.to(self.device), n_samples_in_batch)
262 | ordinal_regression = select_covariates(cat_covs, self.ord_idx.to(self.device), n_samples_in_batch)
263 | classification = select_covariates(cat_covs, self.class_idx.to(self.device), n_samples_in_batch)
264 |
265 | predictions = inference_outputs["predictions"] # list, first from classifiers, then from regressors
266 |
267 | accuracies = []
268 | classification_loss = torch.tensor(0.0).to(self.device)
269 | for i in range(len(self.class_idx)):
270 | classification_loss += F.cross_entropy(
271 | predictions[i], classification[:, i].long()
272 |             )  # the label is assumed to be identical for all cells in a bag
273 | accuracies.append(
274 | torch.sum(torch.eq(torch.argmax(predictions[i], dim=-1), classification[:, i]))
275 | / classification[:, i].shape[0]
276 | )
277 |
278 | regression_loss = torch.tensor(0.0).to(self.device)
279 | for i in range(len(self.ord_idx)):
280 | regression_loss += F.mse_loss(predictions[len(self.class_idx) + i].squeeze(-1), ordinal_regression[:, i])
281 | accuracies.append(
282 | torch.sum(
283 | torch.eq(
284 | torch.clamp(
285 | torch.round(predictions[len(self.class_idx) + i].squeeze()),
286 | min=0.0,
287 | max=self.num_classification_classes[i] - 1.0,
288 | ),
289 | ordinal_regression[:, i],
290 | )
291 | )
292 | / ordinal_regression[:, i].shape[0]
293 | )
294 |
295 | for i in range(len(self.reg_idx)):
296 | regression_loss += F.mse_loss(
297 | predictions[len(self.class_idx) + len(self.ord_idx) + i].squeeze(-1),
298 | regression[:, i],
299 | )
300 |
301 | class_loss_anneal_coef = kl_weight if self.anneal_class_loss else 1.0
302 |
303 | loss = torch.mean(
304 | self.class_loss_coef * classification_loss * class_loss_anneal_coef
305 | + self.regression_loss_coef * regression_loss
306 | )
307 |
308 | extra_metrics = {
309 | "class_loss": classification_loss,
310 | "regression_loss": regression_loss,
311 | }
312 |
313 | if len(accuracies) > 0:
314 | accuracy = torch.sum(torch.tensor(accuracies)) / len(accuracies)
315 | extra_metrics["accuracy"] = accuracy
316 |
317 |         # not needed in this model, but required by the LossOutput interface
318 | recon_loss = torch.zeros(minibatch_size)
319 | kl_loss = torch.zeros(minibatch_size)
320 |
321 | return loss, recon_loss, kl_loss, extra_metrics
322 |
323 | def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0):
324 | """Loss calculation.
325 |
326 | Parameters
327 | ----------
328 | tensors
329 | Input tensors.
330 | inference_outputs
331 | Inference outputs.
332 | generative_outputs
333 | Generative outputs.
334 | kl_weight
335 | KL weight. Default is 1.0.
336 |
337 | Returns
338 | -------
339 | Prediction loss.
340 | """
341 | loss, recon_loss, kl_loss, extra_metrics = self._calculate_loss(
342 | tensors, inference_outputs, generative_outputs, kl_weight
343 | )
344 |
345 | return LossOutput(
346 | loss=loss,
347 | reconstruction_loss=recon_loss,
348 | kl_local=kl_loss,
349 | extra_metrics=extra_metrics,
350 | )
351 |
352 | def select_losses_to_plot(self):
353 | """Select losses to plot.
354 |
355 | Returns
356 | -------
357 | Loss names.
358 | """
359 | loss_names = []
360 | if self.class_loss_coef != 0 and len(self.class_idx) > 0:
361 | loss_names.extend(["class_loss", "accuracy"])
362 | if self.regression_loss_coef != 0 and len(self.reg_idx) > 0:
363 | loss_names.append("regression_loss")
364 | if self.regression_loss_coef != 0 and len(self.ord_idx) > 0:
365 | loss_names.extend(["regression_loss", "accuracy"])
366 |         return list(dict.fromkeys(loss_names))  # drop duplicate names while preserving order
367 |
--------------------------------------------------------------------------------
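The bag construction in ``MILClassifierTorch.inference`` above splits a minibatch of cell embeddings into fixed-size bags with ``torch.tensor_split`` and then aggregates each bag. A self-contained sketch of that splitting step, with a plain mean standing in for the attention-based ``Aggregator`` (illustration only, not the package API):

import torch

z_dim, sample_batch_size = 16, 128
z_joint = torch.randn(3 * sample_batch_size, z_dim)  # embeddings for 3 bags' worth of cells

# cut points every sample_batch_size cells, as in inference()
idx = list(range(sample_batch_size, z_joint.shape[0], sample_batch_size))
zs = torch.stack(torch.tensor_split(z_joint, idx, dim=0), dim=0)  # (n_bags, sample_batch_size, z_dim)

bag_embeddings = zs.mean(dim=1)  # stand-in for the gated-attention Aggregator
print(bag_embeddings.shape)      # torch.Size([3, 16])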
/src/multimil/module/_multivae_mil_torch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from scvi import REGISTRY_KEYS
3 | from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
4 |
5 | from multimil.module import MILClassifierTorch, MultiVAETorch
6 |
7 |
8 | class MultiVAETorch_MIL(BaseModuleClass):
9 | """MultiMIL's end-to-end multimodal integration and MIL classification modules.
10 |
11 | Parameters
12 | ----------
13 | modality_lengths
14 | Number of features for each modality.
15 | condition_encoders
16 | Whether to condition the encoders on the covariates.
17 | condition_decoders
18 | Whether to condition the decoders on the covariates.
19 | normalization
20 | Normalization to use in the network.
21 | z_dim
22 | Dimensionality of the latent space.
23 | losses
24 | List of losses to use in the VAE.
25 | dropout
26 | Dropout rate.
27 | cond_dim
28 | Dimensionality of the covariate embeddings.
29 | kernel_type
30 | Type of kernel to use for the MMD loss.
31 | loss_coefs
32 | Coefficients for the different losses.
33 | num_groups
34 | Number of groups to use for the MMD loss.
35 | integrate_on_idx
36 | Indices of the covariates to integrate on.
37 | n_layers_encoders
38 | Number of layers in the encoders.
39 | n_layers_decoders
40 | Number of layers in the decoders.
41 | n_hidden_encoders
42 | Number of hidden units in the encoders.
43 | n_hidden_decoders
44 | Number of hidden units in the decoders.
45 | num_classification_classes
46 |         Number of classes for each of the classification tasks.
47 | scoring
48 | Scoring function to use for the MIL classification.
49 | attn_dim
50 |         Hidden attention dimension.
51 | cat_covariate_dims
52 | Number of categories for each of the categorical covariates.
53 | cont_covariate_dims
54 |         Dimensionality of each continuous covariate; always 1 per covariate.
55 | cat_covs_idx
56 | Indices of the categorical covariates.
57 | cont_covs_idx
58 | Indices of the continuous covariates.
59 | cont_cov_type
60 | Type of continuous covariate.
61 | n_layers_cell_aggregator
62 | Number of layers in the cell aggregator.
63 | n_layers_classifier
64 | Number of layers in the classifier.
65 | n_layers_mlp_attn
66 | Number of layers in the attention MLP.
67 | n_layers_cont_embed
68 | Number of layers in the continuous embedding calculation.
69 | n_layers_regressor
70 | Number of layers in the regressor.
71 | n_hidden_regressor
72 | Number of hidden units in the regressor.
73 | n_hidden_cell_aggregator
74 | Number of hidden units in the cell aggregator.
75 | n_hidden_classifier
76 | Number of hidden units in the classifier.
77 | n_hidden_mlp_attn
78 | Number of hidden units in the attention MLP.
79 | n_hidden_cont_embed
80 | Number of hidden units in the continuous embedding calculation.
81 | class_loss_coef
82 | Coefficient for the classification loss.
83 | regression_loss_coef
84 | Coefficient for the regression loss.
85 | sample_batch_size
86 | Bag size.
87 | class_idx
88 | Which indices in cat covariates to do classification on.
89 | ord_idx
90 | Which indices in cat covariates to do ordinal regression on.
91 | reg_idx
92 | Which indices in cont covariates to do regression on.
93 | mmd
94 | Type of MMD loss to use.
95 | activation
96 | Activation function to use.
97 | initialization
98 | Initialization method to use.
99 | anneal_class_loss
100 | Whether to anneal the classification loss.
101 | """
102 |
103 | def __init__(
104 | self,
105 | modality_lengths,
106 | condition_encoders=False,
107 | condition_decoders=True,
108 | normalization="layer",
109 | z_dim=16,
110 | losses=None,
111 | dropout=0.2,
112 | cond_dim=16,
113 | kernel_type="gaussian",
114 | loss_coefs=None,
115 | num_groups=1,
116 | integrate_on_idx=None,
117 | n_layers_encoders=None,
118 | n_layers_decoders=None,
119 | n_hidden_encoders=None,
120 | n_hidden_decoders=None,
121 |         num_classification_classes=None,  # number of classes for each of the classification tasks
122 | scoring="gated_attn",
123 | attn_dim=16,
124 | cat_covariate_dims=None,
125 | cont_covariate_dims=None,
126 | cat_covs_idx=None,
127 | cont_covs_idx=None,
128 | cont_cov_type="logsigm",
129 | n_layers_cell_aggregator=1,
130 | n_layers_classifier=1,
131 | n_layers_mlp_attn=1,
132 | n_layers_cont_embed=1,
133 | n_layers_regressor=1,
134 | n_hidden_regressor=16,
135 | n_hidden_cell_aggregator=16,
136 | n_hidden_classifier=16,
137 | n_hidden_mlp_attn=16,
138 | n_hidden_cont_embed=16,
139 | class_loss_coef=1.0,
140 | regression_loss_coef=1.0,
141 | sample_batch_size=128,
142 | class_idx=None, # which indices in cat covariates to do classification on, i.e. exclude from inference
143 | ord_idx=None, # which indices in cat covariates to do ordinal regression on and also exclude from inference
144 | reg_idx=None, # which indices in cont covariates to do regression on and also exclude from inference
145 | mmd="latent",
146 | activation="leaky_relu",
147 | initialization=None,
148 | anneal_class_loss=False,
149 | ):
150 | super().__init__()
151 |
152 | self.vae_module = MultiVAETorch(
153 | modality_lengths=modality_lengths,
154 | condition_encoders=condition_encoders,
155 | condition_decoders=condition_decoders,
156 | normalization=normalization,
157 | z_dim=z_dim,
158 | losses=losses,
159 | dropout=dropout,
160 | cond_dim=cond_dim,
161 | kernel_type=kernel_type,
162 | loss_coefs=loss_coefs,
163 | num_groups=num_groups,
164 | integrate_on_idx=integrate_on_idx,
165 | cat_covariate_dims=cat_covariate_dims, # only the actual categorical covs are considered here
166 | cont_covariate_dims=cont_covariate_dims, # only the actual cont covs are considered here
167 | cat_covs_idx=cat_covs_idx,
168 | cont_covs_idx=cont_covs_idx,
169 | cont_cov_type=cont_cov_type,
170 | n_layers_encoders=n_layers_encoders,
171 | n_layers_decoders=n_layers_decoders,
172 | n_layers_cont_embed=n_layers_cont_embed,
173 | n_hidden_encoders=n_hidden_encoders,
174 | n_hidden_decoders=n_hidden_decoders,
175 | n_hidden_cont_embed=n_hidden_cont_embed,
176 | mmd=mmd,
177 | activation=activation,
178 | initialization=initialization,
179 | )
180 | self.mil_module = MILClassifierTorch(
181 | z_dim=z_dim,
182 | dropout=dropout,
183 | n_layers_cell_aggregator=n_layers_cell_aggregator,
184 | n_layers_classifier=n_layers_classifier,
185 | n_layers_mlp_attn=n_layers_mlp_attn,
186 | n_layers_regressor=n_layers_regressor,
187 | n_hidden_regressor=n_hidden_regressor,
188 | n_hidden_cell_aggregator=n_hidden_cell_aggregator,
189 | n_hidden_classifier=n_hidden_classifier,
190 | n_hidden_mlp_attn=n_hidden_mlp_attn,
191 | class_loss_coef=class_loss_coef,
192 | regression_loss_coef=regression_loss_coef,
193 | sample_batch_size=sample_batch_size,
194 | anneal_class_loss=anneal_class_loss,
195 | num_classification_classes=num_classification_classes,
196 | class_idx=class_idx,
197 | ord_idx=ord_idx,
198 | reg_idx=reg_idx,
199 | activation=activation,
200 | initialization=initialization,
201 | normalization=normalization,
202 | scoring=scoring,
203 | attn_dim=attn_dim,
204 | )
205 |
206 | def _get_inference_input(self, tensors):
207 | x = tensors[REGISTRY_KEYS.X_KEY]
208 |
209 | cont_key = REGISTRY_KEYS.CONT_COVS_KEY
210 | cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None
211 |
212 | cat_key = REGISTRY_KEYS.CAT_COVS_KEY
213 | cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None
214 |
215 | return {"x": x, "cat_covs": cat_covs, "cont_covs": cont_covs}
216 |
217 | def _get_generative_input(self, tensors, inference_outputs):
218 | z_joint = inference_outputs["z_joint"]
219 |
220 | cont_key = REGISTRY_KEYS.CONT_COVS_KEY
221 | cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None
222 |
223 | cat_key = REGISTRY_KEYS.CAT_COVS_KEY
224 | cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None
225 |
226 | return {"z_joint": z_joint, "cat_covs": cat_covs, "cont_covs": cont_covs}
227 |
228 | @auto_move_data
229 | def inference(self, x, cat_covs, cont_covs) -> dict[str, torch.Tensor | list[torch.Tensor]]:
230 | """Forward pass for inference.
231 |
232 | Parameters
233 | ----------
234 | x
235 | Input.
236 | cat_covs
237 | Categorical covariates to condition on.
238 | cont_covs
239 | Continuous covariates to condition on.
240 |
241 | Returns
242 | -------
243 | Joint representations, marginal representations, joint mu's and logvar's and predictions.
244 | """
245 | # VAE part
246 | inference_outputs = self.vae_module.inference(x, cat_covs, cont_covs)
247 | z_joint = inference_outputs["z_joint"]
248 |
249 | # MIL part
250 | mil_inference_outputs = self.mil_module.inference(z_joint)
251 | inference_outputs.update(mil_inference_outputs)
252 | return inference_outputs # z_joint, mu, logvar, z_marginal, predictions
253 |
254 | @auto_move_data
255 | def generative(self, z_joint, cat_covs, cont_covs) -> dict[str, torch.Tensor]:
256 |         """Compute necessary generative quantities.
257 |
258 | Parameters
259 | ----------
260 | z_joint
261 | Tensor of values with shape ``(batch_size, z_dim)``.
262 | cat_covs
263 | Categorical covariates to condition on.
264 | cont_covs
265 | Continuous covariates to condition on.
266 |
267 | Returns
268 | -------
269 | Reconstructed values for each modality.
270 | """
271 | return self.vae_module.generative(z_joint, cat_covs, cont_covs)
272 |
273 | def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0):
274 |         """Calculate the (modality) reconstruction loss, Kullback-Leibler divergences and integration loss.
275 |
276 | Parameters
277 | ----------
278 | tensors
279 | Tensor of values with shape ``(batch_size, n_input_features)``.
280 | inference_outputs
281 | Dictionary with the inference output.
282 | generative_outputs
283 | Dictionary with the generative output.
284 | kl_weight
285 | Weight of the KL loss. Default is 1.0.
286 |
287 | Returns
288 | -------
289 |         Reconstruction loss, Kullback-Leibler divergences, integration loss, modality reconstruction and prediction losses.
290 | """
291 | loss_vae, recon_loss, kl_loss, extra_metrics = self.vae_module._calculate_loss(
292 | tensors, inference_outputs, generative_outputs, kl_weight
293 | )
294 | loss_mil, _, _, extra_metrics_mil = self.mil_module._calculate_loss(
295 | tensors, inference_outputs, generative_outputs, kl_weight
296 | )
297 | loss = loss_vae + loss_mil
298 | extra_metrics.update(extra_metrics_mil)
299 |
300 | return LossOutput(
301 | loss=loss,
302 | reconstruction_loss=recon_loss,
303 | kl_local=kl_loss,
304 | extra_metrics=extra_metrics,
305 | )
306 |
--------------------------------------------------------------------------------
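The end-to-end module above is a thin composition: the VAE submodule produces ``z_joint``, the MIL submodule consumes it, and ``loss()`` simply sums the two losses. A toy sketch of the same composition pattern in plain PyTorch (the toy classes are illustrative and not part of the package):

import torch
from torch import nn
from torch.nn import functional as F

class ToyVAE(nn.Module):
    def forward(self, x):
        z = torch.tanh(x)                      # stand-in for the encoder
        recon_loss = F.mse_loss(z, x)          # stand-in for the reconstruction term
        return z, recon_loss

class ToyMIL(nn.Module):
    def forward(self, z, y):
        logits = z.mean(dim=0, keepdim=True)   # one bag -> one prediction
        return F.cross_entropy(logits, y)

vae, mil = ToyVAE(), ToyMIL()
x = torch.randn(8, 3)                          # 8 cells, 3 features
y = torch.tensor([1])                          # bag-level label
z, recon_loss = vae(x)
total_loss = recon_loss + mil(z, y)            # total loss = VAE loss + MIL loss, as in loss() above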
/src/multimil/module/_multivae_torch.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from typing import Literal
3 |
4 | import torch
5 | from scvi import REGISTRY_KEYS
6 | from scvi.distributions import NegativeBinomial, ZeroInflatedNegativeBinomial
7 | from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
8 | from torch import nn
9 | from torch.distributions import Normal
10 | from torch.distributions import kl_divergence as kl
11 |
12 | from multimil.distributions import MMD, Jeffreys
13 | from multimil.nn import MLP, Decoder, GeneralizedSigmoid
14 |
15 |
16 | class MultiVAETorch(BaseModuleClass):
17 | """MultiMIL's multimodal integration module.
18 |
19 | Parameters
20 | ----------
21 | modality_lengths
22 | List with lengths of each modality.
23 | condition_encoders
24 | Boolean to indicate if to condition encoders.
25 | condition_decoders
26 | Boolean to indicate if to condition decoders.
27 | normalization
28 | One of the following
29 | * ``'layer'`` - layer normalization
30 | * ``'batch'`` - batch normalization
31 | * ``None`` - no normalization.
32 | z_dim
33 | Dimensionality of the latent space.
34 | losses
35 |         List of which losses to use. For each modality, can be one of the following:
36 |             * ``'mse'`` - mean squared error
37 |             * ``'nb'`` - negative binomial
38 |             * ``'zinb'`` - zero-inflated negative binomial
39 |             * ``'bce'`` - binary cross-entropy.
40 | dropout
41 | Dropout rate for neural networks.
42 | cond_dim
43 | Dimensionality of the covariate embeddings.
44 | kernel_type
45 | One of the following:
46 | * ``'gaussian'`` - Gaussian kernel
47 |             * ``'not gaussian'`` - non-Gaussian kernel.
48 | loss_coefs
49 | Dictionary with weights for each of the losses.
50 | num_groups
51 | Number of groups to integrate on.
52 | integrate_on_idx
53 | Indices on which to integrate on.
54 | cat_covariate_dims
55 | List with number of classes for each of the categorical covariates.
56 | cont_covariate_dims
57 |         List of 1s, one for each continuous covariate.
58 | cont_cov_type
59 |         How to transform continuous covariates before multiplying with the embedding. One of the following:
60 |             * ``'logsigm'`` - generalized sigmoid
61 | * ``'mlp'`` - MLP.
62 | n_layers_cont_embed
63 | Number of layers for the transformation of the continuous covariates before multiplying with the embedding.
64 | n_hidden_cont_embed
65 | Number of nodes in hidden layers in the network that transforms continuous covariates.
66 | n_layers_encoders
67 | Number of layers in each encoder.
68 | n_layers_decoders
69 | Number of layers in each decoder.
70 | n_hidden_encoders
71 | Number of nodes in hidden layers in encoders.
72 | n_hidden_decoders
73 | Number of nodes in hidden layers in decoders.
74 | alignment_type
75 |         How to calculate the integration loss. One of the following:
76 |             * ``'latent'`` - only on the latent representations
77 |             * ``'marginal'`` - only on the marginal representations
78 |             * ``'both'`` - the sum of the two above.
79 | """
80 |
81 | def __init__(
82 | self,
83 | modality_lengths,
84 | condition_encoders=False,
85 | condition_decoders=True,
86 | normalization: Literal["layer", "batch", None] = "layer",
87 | z_dim=16,
88 | losses=None,
89 | dropout=0.2,
90 | cond_dim=16,
91 | kernel_type="gaussian",
92 | loss_coefs=None,
93 | num_groups=1,
94 | integrate_on_idx=None,
95 | cat_covariate_dims=None,
96 | cont_covariate_dims=None,
97 | cat_covs_idx=None,
98 | cont_covs_idx=None,
99 | cont_cov_type="logsigm",
100 | n_layers_cont_embed: int = 1,
101 | n_layers_encoders=None,
102 | n_layers_decoders=None,
103 | n_hidden_cont_embed: int = 16,
104 | n_hidden_encoders=None,
105 | n_hidden_decoders=None,
106 | modality_alignment=None,
107 | alignment_type="latent",
108 | activation="leaky_relu",
109 | initialization=None,
110 | mix="product",
111 | ):
112 | super().__init__()
113 |
114 | self.input_dims = modality_lengths
115 | self.condition_encoders = condition_encoders
116 | self.condition_decoders = condition_decoders
117 | self.n_modality = len(self.input_dims)
118 | self.kernel_type = kernel_type
119 | self.integrate_on_idx = integrate_on_idx
120 | self.n_cont_cov = len(cont_covariate_dims)
121 | self.cont_cov_type = cont_cov_type
122 | self.alignment_type = alignment_type
123 | self.modality_alignment = modality_alignment
124 | self.normalization = normalization
125 | self.z_dim = z_dim
126 | self.dropout = dropout
127 | self.cond_dim = cond_dim
128 | self.kernel_type = kernel_type
129 | self.n_layers_cont_embed = n_layers_cont_embed
130 | self.n_layers_encoders = n_layers_encoders
131 | self.n_layers_decoders = n_layers_decoders
132 | self.n_hidden_cont_embed = n_hidden_cont_embed
133 | self.n_hidden_encoders = n_hidden_encoders
134 | self.n_hidden_decoders = n_hidden_decoders
135 | self.cat_covs_idx = cat_covs_idx
136 | self.cont_covs_idx = cont_covs_idx
137 | self.mix = mix
138 |
139 | if activation == "leaky_relu":
140 | self.activation = nn.LeakyReLU
141 | elif activation == "tanh":
142 | self.activation = nn.Tanh
143 | else:
144 | raise NotImplementedError(
145 | f'activation should be one of ["leaky_relu", "tanh"], but activation={activation} was passed.'
146 | )
147 |
148 | # TODO: add warnings that mse is used
149 | if losses is None:
150 | self.losses = ["mse"] * self.n_modality
151 | elif len(losses) == self.n_modality:
152 | self.losses = losses
153 | else:
154 | raise ValueError(
155 | f"losses has to be the same length as the number of modalities. number of modalities = {self.n_modality} != {len(losses)} = len(losses)"
156 | )
157 | if cat_covariate_dims is None:
158 | raise ValueError("cat_covariate_dims = None was passed.")
159 | if cont_covariate_dims is None:
160 | raise ValueError("cont_covariate_dims = None was passed.")
161 |
162 |         # TODO: add a warning that these default architecture values are being used
163 | if self.n_layers_encoders is None:
164 | self.n_layers_encoders = [2] * self.n_modality
165 | if self.n_layers_decoders is None:
166 | self.n_layers_decoders = [2] * self.n_modality
167 | if self.n_hidden_encoders is None:
168 | self.n_hidden_encoders = [128] * self.n_modality
169 | if self.n_hidden_decoders is None:
170 | self.n_hidden_decoders = [128] * self.n_modality
171 |
172 | self.loss_coefs = {
173 | "recon": 1,
174 | "kl": 1e-6,
175 | "integ": 0,
176 | }
177 | for i in range(self.n_modality):
178 | self.loss_coefs[str(i)] = 1
179 | if loss_coefs is not None:
180 | self.loss_coefs.update(loss_coefs)
181 |
182 | # assume for now that can only use nb/zinb once, i.e. for RNA-seq modality
183 | # TODO: add check for multiple nb/zinb losses given
184 | self.theta = None
185 | for i, loss in enumerate(losses):
186 | if loss in ["nb", "zinb"]:
187 | self.theta = torch.nn.Parameter(torch.randn(self.input_dims[i], num_groups))
188 | break
189 |
190 | # modality encoders
191 | cond_dim_enc = cond_dim * (len(cat_covariate_dims) + len(cont_covariate_dims)) if self.condition_encoders else 0
192 | self.encoders = [
193 | MLP(
194 | n_input=x_dim + cond_dim_enc,
195 | n_output=z_dim,
196 | n_layers=n_layers,
197 | n_hidden=n_hidden,
198 | dropout_rate=dropout,
199 | normalization=normalization,
200 | activation=self.activation,
201 | )
202 | for x_dim, n_layers, n_hidden in zip(
203 | self.input_dims, self.n_layers_encoders, self.n_hidden_encoders, strict=False
204 | )
205 | ]
206 |
207 | # modality decoders
208 | cond_dim_dec = cond_dim * (len(cat_covariate_dims) + len(cont_covariate_dims)) if self.condition_decoders else 0
209 | dec_input = z_dim
210 | self.decoders = [
211 | Decoder(
212 | n_input=dec_input + cond_dim_dec,
213 | n_output=x_dim,
214 | n_layers=n_layers,
215 | n_hidden=n_hidden,
216 | dropout_rate=dropout,
217 | normalization=normalization,
218 | activation=self.activation,
219 | loss=loss,
220 | )
221 | for x_dim, loss, n_layers, n_hidden in zip(
222 | self.input_dims, self.losses, self.n_layers_decoders, self.n_hidden_decoders, strict=False
223 | )
224 | ]
225 |
226 | self.mus = [nn.Linear(z_dim, z_dim) for _ in self.input_dims]
227 | self.logvars = [nn.Linear(z_dim, z_dim) for _ in self.input_dims]
228 |
229 | self.cat_covariate_embeddings = [nn.Embedding(dim, cond_dim) for dim in cat_covariate_dims]
230 | if self.n_cont_cov > 0:
231 | self.cont_covariate_embeddings = nn.Embedding(self.n_cont_cov, cond_dim)
232 | if self.cont_cov_type == "mlp":
233 | self.cont_covariate_curves = torch.nn.ModuleList()
234 | for _ in range(self.n_cont_cov):
235 | n_input = n_hidden_cont_embed if self.n_layers_cont_embed > 1 else 1
236 | self.cont_covariate_curves.append(
237 | nn.Sequential(
238 | MLP(
239 | n_input=1,
240 | n_output=n_hidden_cont_embed,
241 | n_layers=self.n_layers_cont_embed - 1,
242 | n_hidden=n_hidden_cont_embed,
243 | dropout_rate=dropout,
244 | normalization=normalization,
245 | activation=self.activation,
246 | ),
247 | nn.Linear(n_input, 1),
248 | )
249 | if self.n_layers_cont_embed > 1
250 | else nn.Linear(n_input, 1)
251 | )
252 | else:
253 | self.cont_covariate_curves = GeneralizedSigmoid(
254 | dim=self.n_cont_cov,
255 | nonlin=self.cont_cov_type,
256 | )
257 |
258 | # register sub-modules
259 | for i, (enc, dec, mu, logvar) in enumerate(
260 | zip(self.encoders, self.decoders, self.mus, self.logvars, strict=False)
261 | ):
262 | self.add_module(f"encoder_{i}", enc)
263 | self.add_module(f"decoder_{i}", dec)
264 | self.add_module(f"mu_{i}", mu)
265 | self.add_module(f"logvar_{i}", logvar)
266 |
267 | for i, emb in enumerate(self.cat_covariate_embeddings):
268 | self.add_module(f"cat_covariate_embedding_{i}", emb)
269 |
270 | if initialization is not None:
271 | if initialization == "xavier":
272 | if activation != "leaky_relu":
273 | warnings.warn(
274 | f"We recommend using Xavier initialization with leaky_relu, but activation={activation} was passed.",
275 | stacklevel=2,
276 | )
277 | for layer in self.modules():
278 | if isinstance(layer, nn.Linear):
279 | nn.init.xavier_uniform_(layer.weight, gain=nn.init.calculate_gain(activation))
280 | elif initialization == "kaiming":
281 | if activation != "tanh":
282 | warnings.warn(
283 | f"We recommend using Kaiming initialization with tanh, but activation={activation} was passed.",
284 | stacklevel=2,
285 | )
286 | for layer in self.modules():
287 | if isinstance(layer, nn.Linear):
288 | # following https://towardsdatascience.com/understand-kaiming-initialization-and-implementation-detail-in-pytorch-f7aa967e9138 (accessed 16.08.22)
289 | nn.init.kaiming_normal_(layer.weight, mode="fan_in")
290 |
291 | def _reparameterize(self, mu, logvar):
292 | std = torch.exp(0.5 * logvar)
293 | eps = torch.randn_like(std)
294 | return mu + eps * std
295 |
296 | def _bottleneck(self, z, i):
297 | mu = self.mus[i](z)
298 | logvar = self.logvars[i](z)
299 | z = self._reparameterize(mu, logvar)
300 | return z, mu, logvar
301 |
302 | def _x_to_h(self, x, i):
303 | return self.encoders[i](x)
304 |
305 | def _h_to_x(self, h, i):
306 | x = self.decoders[i](h)
307 | return x
308 |
309 | def _product_of_experts(self, mus, logvars, masks):
310 | vars = torch.exp(logvars)
311 | masks = masks.unsqueeze(-1).repeat(1, 1, vars.shape[-1])
312 | mus_joint = torch.sum(mus * masks / vars, dim=1)
313 |         vars_joint = torch.ones_like(mus_joint)  # start from the unit-variance standard-normal prior expert
314 | vars_joint += torch.sum(masks / vars, dim=1)
315 | vars_joint = 1.0 / vars_joint # inverse
316 | mus_joint *= vars_joint
317 | logvars_joint = torch.log(vars_joint)
318 | return mus_joint, logvars_joint
319 |
320 | def _mixture_of_experts(self, mus, logvars, masks):
321 | vars = torch.exp(logvars)
322 | masks = masks.unsqueeze(-1).repeat(1, 1, vars.shape[-1])
323 | masks = masks.float()
324 | weights = masks / torch.sum(masks, dim=1, keepdim=True) # normalize so the sum is 1
325 |         # parameters of the weighted average of the modality Gaussians, e.g. 1/2 * (X + Y) for two modalities
326 | mus_mixture = torch.sum(weights * mus, dim=1)
327 | vars_mixture = torch.sum(weights**2 * vars, dim=1)
328 | logvars_mixture = torch.log(vars_mixture)
329 | return mus_mixture, logvars_mixture
330 |
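    # Reference formulas for the two expert-combination schemes above (read off the code itself;
    # masks select the modalities observed for each cell):
    #   product of experts (with a standard-normal prior expert of precision 1):
    #       1 / var_joint = 1 + sum_m mask_m / var_m
    #       mu_joint      = var_joint * sum_m mask_m * mu_m / var_m
    #   mixture variant (weighted average of independent Gaussians, weights w_m = mask_m / sum_m mask_m):
    #       mu_mix  = sum_m w_m * mu_m
    #       var_mix = sum_m (w_m ** 2) * var_m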
331 | def _get_inference_input(self, tensors):
332 | x = tensors[REGISTRY_KEYS.X_KEY]
333 |
334 | cont_key = REGISTRY_KEYS.CONT_COVS_KEY
335 | cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None
336 |
337 | cat_key = REGISTRY_KEYS.CAT_COVS_KEY
338 | cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None
339 |
340 | return {"x": x, "cat_covs": cat_covs, "cont_covs": cont_covs}
341 |
342 | def _get_generative_input(self, tensors, inference_outputs):
343 | z_joint = inference_outputs["z_joint"]
344 |
345 | cont_key = REGISTRY_KEYS.CONT_COVS_KEY
346 | cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None
347 |
348 | cat_key = REGISTRY_KEYS.CAT_COVS_KEY
349 | cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None
350 |
351 | return {"z_joint": z_joint, "cat_covs": cat_covs, "cont_covs": cont_covs}
352 |
353 | @auto_move_data
354 | def inference(
355 | self,
356 | x: torch.Tensor,
357 | cat_covs: torch.Tensor | None = None,
358 | cont_covs: torch.Tensor | None = None,
359 | masks: list[torch.Tensor] | None = None,
360 | ) -> dict[str, torch.Tensor | list[torch.Tensor]]:
361 | """Compute necessary inference quantities.
362 |
363 | Parameters
364 | ----------
365 | x
366 | Tensor of values with shape ``(batch_size, n_input_features)``.
367 | cat_covs
368 | Categorical covariates to condition on.
369 | cont_covs
370 | Continuous covariates to condition on.
371 | masks
372 | List of binary tensors indicating which values in ``x`` belong to which modality.
373 |
374 | Returns
375 | -------
376 | Joint representations, marginal representations, joint mu's and logvar's.
377 | """
378 | # split x into modality xs
379 | if torch.is_tensor(x):
380 | xs = torch.split(
381 | x, self.input_dims, dim=-1
382 | ) # list of tensors of len = n_mod, each tensor is of shape batch_size x mod_input_dim
383 | else:
384 | xs = x
385 |
386 | # TODO: check if masks still supported
387 | if masks is None:
388 | masks = [x.sum(dim=1) > 0 for x in xs] # list of masks per modality
389 | masks = torch.stack(masks, dim=1)
390 | # if we want to condition encoders, i.e. concat covariates to the input
391 | if self.condition_encoders is True:
392 | cat_embedds = self._select_cat_covariates(cat_covs)
393 | cont_embedds = self._select_cont_covariates(cont_covs)
394 |
395 | # concatenate input with categorical and continuous covariates
396 | xs = [
397 | torch.cat([x, cat_embedds, cont_embedds], dim=-1) for x in xs
398 | ] # concat input to each modality x along the feature axis
399 |
400 | # TODO don't forward if mask is 0 for that dataset for that modality
401 | # hs = hidden state that we get after the encoder but before calculating mu and logvar for each modality
402 | hs = [self._x_to_h(x, mod) for mod, x in enumerate(xs)]
403 | # out = [zs_marginal, mus, logvars] and len(zs_marginal) = len(mus) = len(logvars) = number of modalities
404 | out = [self._bottleneck(h, mod) for mod, h in enumerate(hs)]
405 | # split out into zs_marginal, mus and logvars TODO check if easier to use split
406 | zs_marginal = [mod_out[0] for mod_out in out]
407 | z_marginal = torch.stack(zs_marginal, dim=1)
408 | mus = [mod_out[1] for mod_out in out]
409 | mu = torch.stack(mus, dim=1)
410 | logvars = [mod_out[2] for mod_out in out]
411 | logvar = torch.stack(logvars, dim=1)
412 | if self.mix == "product":
413 | mu_joint, logvar_joint = self._product_of_experts(mu, logvar, masks)
414 | elif self.mix == "mixture":
415 | mu_joint, logvar_joint = self._mixture_of_experts(mu, logvar, masks)
416 | else:
417 | raise ValueError(f"mix should be one of ['product', 'mixture'], but mix={self.mix} was passed")
418 | z_joint = self._reparameterize(mu_joint, logvar_joint)
419 | # drop mus and logvars according to masks for kl calculation
420 | # TODO here or in loss calculation? check
421 | # return mus+mus_joint
422 | return {
423 | "z_joint": z_joint,
424 | "mu": mu_joint,
425 | "logvar": logvar_joint,
426 | "z_marginal": z_marginal,
427 | "mu_marginal": mu,
428 | "logvar_marginal": logvar,
429 | }
430 |
431 | @auto_move_data
432 | def generative(
433 | self, z_joint: torch.Tensor, cat_covs: torch.Tensor | None = None, cont_covs: torch.Tensor | None = None
434 | ) -> dict[str, list[torch.Tensor]]:
435 |         """Compute necessary generative quantities.
436 |
437 | Parameters
438 | ----------
439 | z_joint
440 | Tensor of values with shape ``(batch_size, z_dim)``.
441 | cat_covs
442 | Categorical covariates to condition on.
443 | cont_covs
444 | Continuous covariates to condition on.
445 |
446 | Returns
447 | -------
448 | Reconstructed values for each modality.
449 | """
450 | z = z_joint.unsqueeze(1).repeat(1, self.n_modality, 1)
451 | zs = torch.split(z, 1, dim=1)
452 |
453 | if self.condition_decoders is True:
454 | cat_embedds = self._select_cat_covariates(cat_covs)
455 | cont_embedds = self._select_cont_covariates(cont_covs)
456 |
457 | zs = [
458 | torch.cat([z.squeeze(1), cat_embedds, cont_embedds], dim=-1) for z in zs
459 | ] # concat embedding to each modality x along the feature axis
460 |
461 | rs = [self._h_to_x(z, mod) for mod, z in enumerate(zs)]
462 | return {"rs": rs}
463 |
464 | def _select_cat_covariates(self, cat_covs):
465 | if len(self.cat_covs_idx) > 0:
466 | cat_covs = torch.index_select(cat_covs, 1, self.cat_covs_idx.to(self.device))
467 | cat_embedds = [
468 | cat_covariate_embedding(covariate.long())
469 | for cat_covariate_embedding, covariate in zip(self.cat_covariate_embeddings, cat_covs.T, strict=False)
470 | ]
471 | else:
472 | cat_embedds = []
473 |
474 | if len(cat_embedds) > 0:
475 | cat_embedds = torch.cat(cat_embedds, dim=-1) # TODO check if concatenation is needed
476 | else:
477 | cat_embedds = torch.Tensor().to(self.device)
478 | return cat_embedds
479 |
480 | def _select_cont_covariates(self, cont_covs):
481 | if len(self.cont_covs_idx) > 0:
482 | cont_covs = torch.index_select(cont_covs, 1, self.cont_covs_idx.to(self.device))
483 | if cont_covs.shape[-1] != self.n_cont_cov: # get rid of size_factors
484 | cont_covs = cont_covs[:, 0 : self.n_cont_cov]
485 | cont_embedds = self._compute_cont_cov_embeddings(cont_covs)
486 | else:
487 | cont_embedds = torch.Tensor().to(self.device)
488 | return cont_embedds
489 |
490 | def _calculate_loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0):
491 | x = tensors[REGISTRY_KEYS.X_KEY]
492 | if self.integrate_on_idx is not None:
493 | integrate_on = tensors.get(REGISTRY_KEYS.CAT_COVS_KEY)[:, self.integrate_on_idx]
494 | else:
495 | integrate_on = torch.zeros(x.shape[0], 1).to(self.device)
496 |
497 | size_factor = tensors.get(REGISTRY_KEYS.SIZE_FACTOR_KEY, None)
498 |
499 | rs = generative_outputs["rs"]
500 | mu = inference_outputs["mu"]
501 | logvar = inference_outputs["logvar"]
502 | z_joint = inference_outputs["z_joint"]
503 | z_marginal = inference_outputs["z_marginal"] # batch_size x n_modalities x latent_dim
504 | mu_marginal = inference_outputs["mu_marginal"]
505 | logvar_marginal = inference_outputs["logvar_marginal"]
506 |
507 | xs = torch.split(
508 | x, self.input_dims, dim=-1
509 | ) # list of tensors of len = n_mod, each tensor is of shape batch_size x mod_input_dim
510 | masks = [x.sum(dim=1) > 0 for x in xs] # [batch_size] * num_modalities
511 |
512 | recon_loss, modality_recon_losses = self._calc_recon_loss(
513 | xs, rs, self.losses, integrate_on, size_factor, self.loss_coefs, masks
514 | )
515 | kl_loss = kl_weight * kl(Normal(mu, torch.sqrt(torch.exp(logvar))), Normal(0, 1)).sum(dim=1)
516 |
517 | integ_loss = torch.tensor(0.0).to(self.device)
518 | if self.loss_coefs["integ"] > 0:
519 | if self.alignment_type == "latent" or self.alignment_type == "both":
520 | integ_loss += self._calc_integ_loss(z_joint, mu, logvar, integrate_on).to(self.device)
521 | if self.alignment_type == "marginal" or self.alignment_type == "both":
522 | for i in range(len(masks)):
523 | for j in range(i + 1, len(masks)):
524 | idx_where_to_calc_integ_loss = torch.eq(
525 | masks[i] == masks[j],
526 | torch.eq(masks[i], torch.ones_like(masks[i])),
527 | )
528 | if (
529 | idx_where_to_calc_integ_loss.any()
530 | ): # if need to calc integ loss for a group between modalities
531 | marginal_i = z_marginal[:, i, :][idx_where_to_calc_integ_loss]
532 | marginal_j = z_marginal[:, j, :][idx_where_to_calc_integ_loss]
533 |
534 | mu_i = mu_marginal[:, i, :][idx_where_to_calc_integ_loss]
535 | mu_j = mu_marginal[:, j, :][idx_where_to_calc_integ_loss]
536 |
537 | logvar_i = logvar_marginal[:, i, :][idx_where_to_calc_integ_loss]
538 | logvar_j = logvar_marginal[:, j, :][idx_where_to_calc_integ_loss]
539 |
540 | marginals = torch.cat([marginal_i, marginal_j])
541 | mus_marginal = torch.cat([mu_i, mu_j])
542 | logvars_marginal = torch.cat([logvar_i, logvar_j])
543 |
544 | modalities = torch.cat(
545 | [
546 | torch.Tensor([i] * marginal_i.shape[0]),
547 | torch.Tensor([j] * marginal_j.shape[0]),
548 | ]
549 | ).to(self.device)
550 |
551 | integ_loss += self._calc_integ_loss(
552 | marginals, mus_marginal, logvars_marginal, modalities
553 | ).to(self.device)
554 |
555 | for i in range(len(masks)):
556 | marginal_i = z_marginal[:, i, :]
557 | marginal_i = marginal_i[masks[i]]
558 |
559 | mu_i = mu_marginal[:, i, :]
560 | mu_i = mu_i[masks[i]]
561 |
562 | logvar_i = logvar_marginal[:, i, :]
563 | logvar_i = logvar_i[masks[i]]
564 |
565 | group_marginal = integrate_on[masks[i]]
566 | integ_loss += self._calc_integ_loss(marginal_i, mu_i, logvar_i, group_marginal).to(self.device)
567 |
568 | loss = torch.mean(
569 | self.loss_coefs["recon"] * recon_loss
570 | + self.loss_coefs["kl"] * kl_loss
571 | + self.loss_coefs["integ"] * integ_loss
572 | )
573 |
574 | modality_recon_losses = {
575 | f"modality_{i}_reconstruction_loss": modality_recon_losses[i] for i in range(len(modality_recon_losses))
576 | }
577 | extra_metrics = {"integ_loss": integ_loss}
578 | extra_metrics.update(modality_recon_losses)
579 |
580 | return loss, recon_loss, kl_loss, extra_metrics
581 |
582 | def loss(
583 | self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0
584 |     ) -> LossOutput:
593 |         """Calculate the (modality) reconstruction loss, Kullback-Leibler divergence and integration loss.
594 |
595 | Parameters
596 | ----------
597 | tensors
598 |             Dictionary of input tensors, including the data matrix of shape ``(batch_size, n_input_features)``.
599 | inference_outputs
600 | Dictionary with the inference output.
601 | generative_outputs
602 | Dictionary with the generative output.
603 | kl_weight
604 | Weight of the KL loss. Default is 1.0.
605 |
606 | Returns
607 | -------
608 |         LossOutput with the total loss, reconstruction loss, Kullback-Leibler divergence, and extra metrics (integration and per-modality reconstruction losses).
609 | """
610 | loss, recon_loss, kl_loss, extra_metrics = self._calculate_loss(
611 | tensors, inference_outputs, generative_outputs, kl_weight
612 | )
613 |
614 | return LossOutput(
615 | loss=loss,
616 | reconstruction_loss=recon_loss,
617 | kl_local=kl_loss,
618 | extra_metrics=extra_metrics,
619 | )
620 |
621 | # @torch.inference_mode()
622 | # def sample(self, tensors, n_samples=1):
623 | # """Sample from the generative model.
624 |
625 | # Parameters
626 | # ----------
627 | # tensors
628 | # Tensor of values.
629 | # n_samples
630 | # Number of samples to generate.
631 |
632 | # Returns
633 | # -------
634 | # Generative outputs.
635 | # """
636 | # inference_kwargs = {"n_samples": n_samples}
637 | # with torch.inference_mode():
638 | # (
639 | # _,
640 | # generative_outputs,
641 | # ) = self.forward(
642 | # tensors,
643 | # inference_kwargs=inference_kwargs,
644 | # compute_loss=False,
645 | # )
646 | # return generative_outputs["rs"]
647 |
648 | def _calc_recon_loss(self, xs, rs, losses, group, size_factor, loss_coefs, masks):
649 | loss = []
650 | for i, (x, r, loss_type) in enumerate(zip(xs, rs, losses, strict=False)):
651 |             if torch.is_tensor(r) and r.dim() == 3:
652 | r = r.squeeze()
653 | if loss_type == "mse":
654 | mse_loss = loss_coefs[str(i)] * torch.sum(nn.MSELoss(reduction="none")(r, x), dim=-1)
655 | loss.append(mse_loss)
656 | elif loss_type == "nb":
657 | dec_mean = r
658 |                 size_factor_view = size_factor.view(-1, 1).expand(dec_mean.size(0), dec_mean.size(1))
659 | dec_mean = dec_mean * size_factor_view
660 | dispersion = self.theta.T[group.squeeze().long()]
661 | dispersion = torch.exp(dispersion)
662 | nb_loss = torch.sum(NegativeBinomial(mu=dec_mean, theta=dispersion).log_prob(x), dim=-1)
663 | nb_loss = loss_coefs[str(i)] * nb_loss
664 | loss.append(-nb_loss)
665 | elif loss_type == "zinb":
666 | dec_mean, dec_dropout = r
667 | dec_mean = dec_mean.squeeze()
668 | dec_dropout = dec_dropout.squeeze()
669 |                 size_factor_view = size_factor.view(-1, 1).expand(dec_mean.size(0), dec_mean.size(1))
670 | dec_mean = dec_mean * size_factor_view
671 | dispersion = self.theta.T[group.squeeze().long()]
672 | dispersion = torch.exp(dispersion)
673 | zinb_loss = torch.sum(
674 | ZeroInflatedNegativeBinomial(mu=dec_mean, theta=dispersion, zi_logits=dec_dropout).log_prob(x),
675 | dim=-1,
676 | )
677 | zinb_loss = loss_coefs[str(i)] * zinb_loss
678 | loss.append(-zinb_loss)
679 | elif loss_type == "bce":
680 | bce_loss = loss_coefs[str(i)] * torch.sum(torch.nn.BCELoss(reduction="none")(r, x), dim=-1)
681 | loss.append(bce_loss)
682 |
683 | return (
684 | torch.sum(torch.stack(loss, dim=-1) * torch.stack(masks, dim=-1), dim=1),
685 | torch.sum(torch.stack(loss, dim=-1) * torch.stack(masks, dim=-1), dim=0),
686 | )
687 |
688 | def _calc_integ_loss(self, z, mu, logvar, group):
689 | loss = torch.tensor(0.0).to(self.device)
690 | unique = torch.unique(group)
691 |
692 | if len(unique) > 1:
693 | if self.modality_alignment == "MMD":
694 | zs = [z[group == i] for i in unique]
695 | for i in range(len(zs)):
696 | for j in range(i + 1, len(zs)):
697 | loss += MMD(kernel_type=self.kernel_type)(zs[i], zs[j])
698 |
699 | elif self.modality_alignment == "Jeffreys":
700 | mus_joint = [mu[group == i] for i in unique]
701 | logvars_joint = [logvar[group == i] for i in unique]
702 | for i in range(len(mus_joint)):
703 | for j in range(i + 1, len(mus_joint)):
704 | if len(mus_joint[i]) == len(mus_joint[j]):
705 | loss += Jeffreys()(
706 | (mus_joint[i], torch.exp(logvars_joint[i])), (mus_joint[j], torch.exp(logvars_joint[j]))
707 | )
708 | else:
709 | raise ValueError(
710 | f"mus_joint[i] and mus_joint[j] have different lengths: {len(mus_joint[i])} != {len(mus_joint[j])}"
711 | )
712 |
713 | return loss
714 |
715 | def _compute_cont_cov_embeddings(self, covs):
716 | """Compute embeddings for continuous covariates.
717 |
718 | Adapted from
719 | Title: CPA (c) Facebook, Inc.
720 | Date: 26.01.2022
721 | Link to the used code:
722 | https://github.com/facebookresearch/CPA/blob/382ff641c588820a453d801e5d0e5bb56642f282/compert/model.py#L342
723 |
724 | """
725 | if self.cont_cov_type == "mlp":
726 | embeddings = []
727 | for cov in range(covs.size(1)):
728 | this_cov = covs[:, cov].view(-1, 1)
729 | embeddings.append(
730 | self.cont_covariate_curves[cov](this_cov).sigmoid()
731 | ) # * this_drug.gt(0)) # TODO check what this .gt(0) is
732 | return torch.cat(embeddings, 1) @ self.cont_covariate_embeddings.weight
733 | else:
734 | return self.cont_covariate_curves(covs) @ self.cont_covariate_embeddings.weight
735 |
736 | def select_losses_to_plot(self):
737 | """Select losses to plot.
738 |
739 | Returns
740 | -------
741 | Loss names.
742 | """
743 | loss_names = ["kl_local", "elbo", "reconstruction_loss"]
744 | for i in range(self.n_modality):
745 | loss_names.append(f"modality_{i}_reconstruction_loss")
746 | if self.loss_coefs["integ"] != 0:
747 | loss_names.append("integ_loss")
748 | return loss_names
749 |
--------------------------------------------------------------------------------
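For orientation, a minimal, self-contained sketch of how the weighted per-cell terms in `_calculate_loss` above are combined into the scalar that `loss` hands to `LossOutput`. The shapes and the `loss_coefs` values are assumptions chosen for illustration, not values taken from the package.

    import torch

    batch_size, latent_dim = 8, 16
    loss_coefs = {"recon": 1.0, "kl": 1e-3, "integ": 1.0}  # assumed coefficients

    # per-cell reconstruction loss, already summed over features
    recon_loss = torch.rand(batch_size)
    # KL(N(mu, sigma^2) || N(0, 1)) summed over latent dimensions, per cell
    mu = torch.zeros(batch_size, latent_dim)
    logvar = torch.zeros(batch_size, latent_dim)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    # alignment term (MMD or Jeffreys in the module above); zero here for simplicity
    integ_loss = torch.tensor(0.0)

    loss = torch.mean(
        loss_coefs["recon"] * recon_loss
        + loss_coefs["kl"] * kl_loss
        + loss_coefs["integ"] * integ_loss
    )
    print(loss)  # a single scalar, averaged over the cells in the minibatch
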
/src/multimil/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from ._base_components import MLP, Aggregator, Decoder, GeneralizedSigmoid
2 |
3 | __all__ = ["MLP", "Decoder", "GeneralizedSigmoid", "Aggregator"]
4 |
--------------------------------------------------------------------------------
/src/multimil/nn/_base_components.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | import torch
4 | from scvi.nn import FCLayers
5 | from torch import nn
6 | from torch.nn import functional as F
7 |
8 |
9 | class MLP(nn.Module):
10 | """A helper class to build blocks of fully-connected, normalization, dropout and activation layers.
11 |
12 | Parameters
13 | ----------
14 | n_input
15 | Number of input features.
16 | n_output
17 | Number of output features.
18 | n_layers
19 | Number of hidden layers.
20 | n_hidden
21 | Number of hidden units.
22 | dropout_rate
23 | Dropout rate.
24 | normalization
25 | Type of normalization to use. Can be one of ["layer", "batch", "none"].
26 | activation
27 | Activation function to use.
28 |
29 | """
30 |
31 | def __init__(
32 | self,
33 | n_input: int,
34 | n_output: int,
35 | n_layers: int = 1,
36 | n_hidden: int = 128,
37 | dropout_rate: float = 0.1,
38 | normalization: str = "layer",
39 | activation=nn.LeakyReLU,
40 | ):
41 | super().__init__()
42 | use_layer_norm = False
43 | use_batch_norm = True
44 | if normalization == "layer":
45 | use_layer_norm = True
46 | use_batch_norm = False
47 | elif normalization == "none":
48 | use_batch_norm = False
49 |
50 | self.mlp = FCLayers(
51 | n_in=n_input,
52 | n_out=n_output,
53 | n_layers=n_layers,
54 | n_hidden=n_hidden,
55 | dropout_rate=dropout_rate,
56 | use_layer_norm=use_layer_norm,
57 | use_batch_norm=use_batch_norm,
58 | activation_fn=activation,
59 | )
60 |
61 | def forward(self, x: torch.Tensor) -> torch.Tensor:
62 | """Forward computation on ``x``.
63 |
64 | Parameters
65 | ----------
66 | x
67 | Tensor of values with shape ``(n_input,)``.
68 |
69 | Returns
70 | -------
71 | Tensor of values with shape ``(n_output,)``.
72 | """
73 | return self.mlp(x)
74 |
75 |
76 | class Decoder(nn.Module):
77 | """A helper class to build custom decoders depending on which loss was passed.
78 |
79 | Parameters
80 | ----------
81 | n_input
82 | Number of input features.
83 | n_output
84 | Number of output features.
85 | n_layers
86 | Number of hidden layers.
87 | n_hidden
88 | Number of hidden units.
89 | dropout_rate
90 | Dropout rate.
91 | normalization
92 | Type of normalization to use. Can be one of ["layer", "batch", "none"].
93 | activation
94 | Activation function to use.
95 | loss
96 | Loss function to use. Can be one of ["mse", "nb", "zinb", "bce"].
97 | """
98 |
99 | def __init__(
100 | self,
101 | n_input: int,
102 | n_output: int,
103 | n_layers: int = 1,
104 | n_hidden: int = 128,
105 | dropout_rate: float = 0.1,
106 | normalization: str = "layer",
107 | activation=nn.LeakyReLU,
108 | loss="mse",
109 | ):
110 | super().__init__()
111 |
112 | if loss not in ["mse", "nb", "zinb", "bce"]:
113 | raise NotImplementedError(f"Loss function {loss} is not implemented.")
114 | else:
115 | self.loss = loss
116 |
117 | self.decoder = MLP(
118 | n_input=n_input,
119 | n_output=n_hidden,
120 | n_layers=n_layers,
121 | n_hidden=n_hidden,
122 | dropout_rate=dropout_rate,
123 | normalization=normalization,
124 | activation=activation,
125 | )
126 | if loss == "mse":
127 | self.recon_decoder = nn.Linear(n_hidden, n_output)
128 | elif loss == "nb":
129 | self.mean_decoder = nn.Sequential(nn.Linear(n_hidden, n_output), nn.Softmax(dim=-1))
130 |
131 | elif loss == "zinb":
132 | self.mean_decoder = nn.Sequential(nn.Linear(n_hidden, n_output), nn.Softmax(dim=-1))
133 | self.dropout_decoder = nn.Linear(n_hidden, n_output)
134 |
135 | elif loss == "bce":
136 | self.recon_decoder = FCLayers(
137 | n_in=n_hidden,
138 | n_out=n_output,
139 | n_layers=0,
140 | dropout_rate=0,
141 | use_layer_norm=False,
142 | use_batch_norm=False,
143 | activation_fn=nn.Sigmoid,
144 | )
145 |
146 | def forward(self, x: torch.Tensor) -> torch.Tensor:
147 | """Forward computation on ``x``.
148 |
149 | Parameters
150 | ----------
151 | x
152 | Tensor of values with shape ``(n_input,)``.
153 |
154 | Returns
155 | -------
156 | Tensor of values with shape ``(n_output,)``.
157 | """
158 | x = self.decoder(x)
159 | if self.loss == "mse" or self.loss == "bce":
160 | return self.recon_decoder(x)
161 | elif self.loss == "nb":
162 | return self.mean_decoder(x)
163 | elif self.loss == "zinb":
164 | return self.mean_decoder(x), self.dropout_decoder(x)
165 |
166 |
167 | class GeneralizedSigmoid(nn.Module):
168 | """Sigmoid, log-sigmoid or linear functions for encoding continuous covariates.
169 |
170 | Adapted from
171 | Title: CPA (c) Facebook, Inc.
172 | Date: 26.01.2022
173 | Link to the used code:
174 | https://github.com/facebookresearch/CPA/blob/382ff641c588820a453d801e5d0e5bb56642f282/compert/model.py#L109
175 |
176 | Parameters
177 | ----------
178 | dim
179 | Number of input features.
180 | nonlin
181 | Type of non-linearity to use. Can be one of ["logsigm", "sigm"]. Default is "logsigm".
182 | """
183 |
184 | def __init__(self, dim, nonlin: Literal["logsigm", "sigm"] | None = "logsigm"):
185 | super().__init__()
186 | self.nonlin = nonlin
187 | self.beta = torch.nn.Parameter(torch.ones(1, dim), requires_grad=True)
188 | self.bias = torch.nn.Parameter(torch.zeros(1, dim), requires_grad=True)
189 |
190 | def forward(self, x) -> torch.Tensor:
191 | """Forward computation on ``x``.
192 |
193 | Parameters
194 | ----------
195 | x
196 | Tensor of values.
197 |
198 | Returns
199 | -------
200 | Tensor of values with the same shape as ``x``.
201 | """
202 | if self.nonlin == "logsigm":
203 | return (torch.log1p(x) * self.beta + self.bias).sigmoid()
204 | elif self.nonlin == "sigm":
205 | return (x * self.beta + self.bias).sigmoid()
206 | else:
207 | return x
208 |
209 |
210 | class Aggregator(nn.Module):
211 | """A helper class to build custom aggregators depending on the scoring function passed.
212 |
213 | Parameters
214 | ----------
215 | n_input
216 | Number of input features.
217 | scoring
218 | Scoring function to use. Can be one of ["attn", "gated_attn", "mean", "max", "sum", "mlp"].
219 | attn_dim
220 | Dimension of the hidden attention layer.
221 | sample_batch_size
222 | Bag batch size.
223 | scale
224 | Whether to scale the attention weights.
225 | dropout
226 | Dropout rate.
227 | n_layers_mlp_attn
228 | Number of hidden layers in the MLP attention.
229 | n_hidden_mlp_attn
230 | Number of hidden units in the MLP attention.
231 | activation
232 | Activation function to use.
233 | """
234 |
235 | def __init__(
236 | self,
237 | n_input=None,
238 | scoring="gated_attn",
239 | attn_dim=16, # D
240 | sample_batch_size=None,
241 | scale=False,
242 | dropout=0.2,
243 | n_layers_mlp_attn=1,
244 | n_hidden_mlp_attn=16,
245 | activation=nn.LeakyReLU,
246 | ):
247 | super().__init__()
248 |
249 | self.scoring = scoring
250 | self.patient_batch_size = sample_batch_size
251 | self.scale = scale
252 |
253 | if self.scoring == "attn":
254 | self.attn_dim = attn_dim # attn dim from https://arxiv.org/pdf/1802.04712.pdf
255 | self.attention = nn.Sequential(
256 | nn.Linear(n_input, self.attn_dim),
257 | nn.Tanh(),
258 | nn.Linear(self.attn_dim, 1, bias=False),
259 | )
260 | elif self.scoring == "gated_attn":
261 | self.attn_dim = attn_dim
262 | self.attention_V = nn.Sequential(
263 | nn.Linear(n_input, self.attn_dim),
264 | nn.Tanh(),
265 | )
266 |
267 | self.attention_U = nn.Sequential(
268 | nn.Linear(n_input, self.attn_dim),
269 | nn.Sigmoid(),
270 | )
271 |
272 | self.attention_weights = nn.Linear(self.attn_dim, 1, bias=False)
273 |
274 | elif self.scoring == "mlp":
275 | if n_layers_mlp_attn == 1:
276 | self.attention = nn.Linear(n_input, 1)
277 | else:
278 | self.attention = nn.Sequential(
279 | MLP(
280 | n_input,
281 | n_hidden_mlp_attn,
282 | n_layers=n_layers_mlp_attn - 1,
283 | n_hidden=n_hidden_mlp_attn,
284 | dropout_rate=dropout,
285 | activation=activation,
286 | ),
287 | nn.Linear(n_hidden_mlp_attn, 1),
288 | )
289 |
290 | def forward(self, x) -> torch.Tensor:
291 | """Forward computation on ``x``.
292 |
293 | Parameters
294 | ----------
295 | x
296 | Tensor of values with shape ``(n_input,)``.
297 |
298 | Returns
299 | -------
300 | Tensor of pooled values.
301 | """
302 | # # TODO add mean and max pooling
303 | # if self.scoring == "attn":
304 | # # from https://github.com/AMLab-Amsterdam/AttentionDeepMIL/blob/master/model.py (accessed 16.09.2021)
305 | # self.A = self.attention(x) # Nx1
306 | # self.A = torch.transpose(self.A, -1, -2) # 1xN
307 | # self.A = F.softmax(self.A, dim=-1) # softmax over N
308 |
309 | # elif self.scoring == "gated_attn":
310 | # # from https://github.com/AMLab-Amsterdam/AttentionDeepMIL/blob/master/model.py (accessed 16.09.2021)
311 | # A_V = self.attention_V(x) # NxD
312 | # A_U = self.attention_U(x) # NxD
313 | # self.A = self.attention_weights(A_V * A_U) # element wise multiplication # Nx1
314 | # self.A = torch.transpose(self.A, -1, -2) # 1xN
315 | # self.A = F.softmax(self.A, dim=-1) # softmax over N
316 |
317 | # elif self.scoring == "mlp":
318 | # self.A = self.attention(x) # N
319 | # self.A = torch.transpose(self.A, -1, -2)
320 | # self.A = F.softmax(self.A, dim=-1)
321 |
322 | # else:
323 | # raise NotImplementedError(
324 | # f'scoring = {self.scoring} is not implemented. Has to be one of ["attn", "gated_attn", "mlp"].'
325 | # )
326 |
327 | # if self.scale:
328 | # self.A = self.A * self.A.shape[-1] / self.patient_batch_size
329 |
330 | # return torch.bmm(self.A, x).squeeze(dim=1)
331 | # Apply different pooling strategies based on the scoring method
332 | if self.scoring in ["attn", "gated_attn", "mlp", "sum", "mean", "max"]:
333 | if self.scoring == "attn":
334 | # from https://github.com/AMLab-Amsterdam/AttentionDeepMIL/blob/master/model.py (accessed 16.09.2021)
335 | A = self.attention(x) # (batch_size, N, 1)
336 | A = A.transpose(1, 2) # (batch_size, 1, N)
337 | A = F.softmax(A, dim=-1)
338 | elif self.scoring == "gated_attn":
339 | # from https://github.com/AMLab-Amsterdam/AttentionDeepMIL/blob/master/model.py (accessed 16.09.2021)
340 | A_V = self.attention_V(x) # (batch_size, N, attn_dim)
341 | A_U = self.attention_U(x) # (batch_size, N, attn_dim)
342 | A = self.attention_weights(A_V * A_U) # (batch_size, N, 1)
343 | A = A.transpose(1, 2) # (batch_size, 1, N)
344 | A = F.softmax(A, dim=-1)
345 | elif self.scoring == "mlp":
346 | A = self.attention(x) # (batch_size, N, 1)
347 | A = A.transpose(1, 2) # (batch_size, 1, N)
348 | A = F.softmax(A, dim=-1)
349 |
350 | elif self.scoring == "sum":
351 | return torch.sum(x, dim=1) # (batch_size, n_input)
352 | elif self.scoring == "mean":
353 | return torch.mean(x, dim=1) # (batch_size, n_input)
354 | elif self.scoring == "max":
355 | return torch.max(x, dim=1).values # (batch_size, n_input)
356 | else:
357 | raise NotImplementedError(
358 | f'scoring = {self.scoring} is not implemented. Has to be one of ["attn", "gated_attn", "mlp", "sum", "mean", "max"].'
359 | )
360 | if self.scale:
361 | if self.patient_batch_size is None:
362 | raise ValueError("patient_batch_size must be set when scale is True.")
363 |             A = A * A.shape[-1] / self.patient_batch_size
364 |
365 | pooled = torch.bmm(A, x).squeeze(dim=1) # (batch_size, n_input)
366 | return pooled
367 |
--------------------------------------------------------------------------------
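A brief usage sketch of the building blocks defined above; the bag size, latent size and covariate dimension are made-up values, and the import path follows the `multimil.nn` exports shown earlier.

    import torch
    from multimil.nn import MLP, Aggregator, GeneralizedSigmoid

    n_cells, n_latent = 32, 16
    cell_embeddings = torch.randn(1, n_cells, n_latent)  # (batch, N, n_input)

    # default normalization="layer", so 3-D inputs are handled as-is
    encoder = MLP(n_input=n_latent, n_output=n_latent, n_layers=1, n_hidden=64)
    pooler = Aggregator(n_input=n_latent, scoring="gated_attn", attn_dim=16)

    hidden = encoder(cell_embeddings)  # (1, n_cells, n_latent)
    bag_embedding = pooler(hidden)     # (1, n_latent) after gated-attention pooling
    print(bag_embedding.shape)

    # GeneralizedSigmoid encodes continuous covariates column-wise
    gs = GeneralizedSigmoid(dim=5, nonlin="logsigm")
    print(gs(torch.rand(10, 5)).shape)  # (10, 5)
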
/src/multimil/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from ._utils import (
2 | calculate_size_factor,
3 | create_df,
4 | get_bag_info,
5 | get_predictions,
6 | plt_plot_losses,
7 | prep_minibatch,
8 | save_predictions_in_adata,
9 | select_covariates,
10 | setup_ordinal_regression,
11 | )
12 |
13 | __all__ = [
14 | "create_df",
15 | "calculate_size_factor",
16 | "setup_ordinal_regression",
17 | "select_covariates",
18 | "prep_minibatch",
19 | "get_predictions",
20 | "get_bag_info",
21 | "save_predictions_in_adata",
22 | "plt_plot_losses",
23 | ]
24 |
--------------------------------------------------------------------------------
/src/multimil/utils/_utils.py:
--------------------------------------------------------------------------------
1 | from math import ceil
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import scipy
6 | import torch
7 | from matplotlib import pyplot as plt
8 |
9 |
10 | def create_df(pred, columns=None, index=None) -> pd.DataFrame:
11 | """Create a pandas DataFrame from a list of predictions.
12 |
13 | Parameters
14 | ----------
15 | pred
16 | List of predictions.
17 | columns
18 | Column names, i.e. class_names.
19 | index
20 | Index names, i.e. obs_names.
21 |
22 | Returns
23 | -------
24 | DataFrame with predictions.
25 | """
26 | if isinstance(pred, dict):
27 | for key in pred.keys():
28 | pred[key] = torch.cat(pred[key]).squeeze().cpu().numpy()
29 | else:
30 | pred = torch.cat(pred).squeeze().cpu().numpy()
31 |
32 | df = pd.DataFrame(pred)
33 | if index is not None:
34 | df.index = index
35 | if columns is not None:
36 | df.columns = columns
37 | return df
38 |
39 |
40 | def calculate_size_factor(adata, size_factor_key, rna_indices_end) -> str:
41 | """Calculate size factors.
42 |
43 | Parameters
44 | ----------
45 | adata
46 | Annotated data object.
47 | size_factor_key
48 | Key in `adata.obs` where size factors are stored.
49 | rna_indices_end
50 | Index of the last RNA feature in the data.
51 |
52 | Returns
53 | -------
54 | Size factor key.
55 | """
56 | # TODO check that organize_multimodal_anndatas was run, i.e. that .uns['modality_lengths'] was added, needed for q2r
57 | if size_factor_key is not None and rna_indices_end is not None:
58 | raise ValueError(
59 |             "Only one of [`size_factor_key`, `rna_indices_end`] can be specified, but both were provided."
60 | )
61 | # TODO change to when both are None and data in unimodal, use all input features to calculate the size factors, add warning
62 | if size_factor_key is None and rna_indices_end is None:
63 | raise ValueError("One of [`size_factor_key`, `rna_indices_end`] has to be specified, but both are `None`.")
64 |
65 | if size_factor_key is not None:
66 | return size_factor_key
67 | if rna_indices_end is not None:
68 | adata_rna = adata[:, :rna_indices_end].copy()
69 | if scipy.sparse.issparse(adata.X):
70 | adata.obs.loc[:, "size_factors"] = adata_rna.X.toarray().sum(1).T.tolist()
71 | else:
72 | adata.obs.loc[:, "size_factors"] = adata_rna.X.sum(1).T.tolist()
73 | return "size_factors"
74 |
75 |
76 | def setup_ordinal_regression(adata, ordinal_regression_order, categorical_covariate_keys):
77 | """Setup ordinal regression.
78 |
79 | Parameters
80 | ----------
81 | adata
82 | Annotated data object.
83 | ordinal_regression_order
84 | Order of categories for ordinal regression.
85 | categorical_covariate_keys
86 | Keys of categorical covariates.
87 | """
88 | # TODO make sure not to assume categorical columns for ordinal regression -> change to np.unique if needed
89 | if ordinal_regression_order is not None:
90 | if not set(ordinal_regression_order.keys()).issubset(categorical_covariate_keys):
91 | raise ValueError(
92 |                 f"All keys {ordinal_regression_order.keys()} have to be registered as categorical covariates too, but categorical_covariate_keys = {categorical_covariate_keys}"
93 | )
94 | for key in ordinal_regression_order.keys():
95 | adata.obs[key] = adata.obs[key].astype("category")
96 | if set(adata.obs[key].cat.categories) != set(ordinal_regression_order[key]):
97 | raise ValueError(
98 | f"Categories of adata.obs[{key}]={adata.obs[key].cat.categories} are not the same as categories specified = {ordinal_regression_order[key]}"
99 | )
100 | adata.obs[key] = adata.obs[key].cat.reorder_categories(ordinal_regression_order[key])
101 |
102 |
103 | def select_covariates(covs, prediction_idx, n_samples_in_batch) -> torch.Tensor:
104 | """Select prediction covariates from all covariates.
105 |
106 | Parameters
107 | ----------
108 | covs
109 | Covariates.
110 | prediction_idx
111 | Index of predictions.
112 | n_samples_in_batch
113 | Number of samples in the batch.
114 |
115 | Returns
116 | -------
117 | Prediction covariates.
118 | """
119 | if len(prediction_idx) > 0:
120 | covs = torch.index_select(covs, 1, prediction_idx)
121 | covs = covs.view(n_samples_in_batch, -1, len(prediction_idx))[:, 0, :]
122 | else:
123 | covs = torch.tensor([])
124 | return covs
125 |
126 |
127 | def prep_minibatch(covs, sample_batch_size) -> tuple[int, int]:
128 | """Prepare minibatch.
129 |
130 | Parameters
131 | ----------
132 | covs
133 | Covariates.
134 | sample_batch_size
135 | Sample batch size.
136 |
137 | Returns
138 | -------
139 | Batch size and number of samples in the batch.
140 | """
141 | batch_size = covs.shape[0]
142 |
143 | if batch_size % sample_batch_size != 0:
144 | n_samples_in_batch = 1
145 | else:
146 | n_samples_in_batch = batch_size // sample_batch_size
147 | return batch_size, n_samples_in_batch
148 |
149 |
150 | def get_predictions(
151 | prediction_idx, pred_values, true_values, size, bag_pred, bag_true, full_pred, offset=0
152 | ) -> tuple[dict, dict, dict]:
153 | """Get predictions.
154 |
155 | Parameters
156 | ----------
157 | prediction_idx
158 | Index of predictions.
159 | pred_values
160 | Predicted values.
161 | true_values
162 | True values.
163 | size
164 | Size of the bag minibatch.
165 | bag_pred
166 | Bag predictions.
167 | bag_true
168 | Bag true values.
169 | full_pred
170 | Full predictions, i.e. on cell-level.
171 | offset
172 | Offset, needed because of several possible types of predictions.
173 |
174 | Returns
175 | -------
176 | Bag predictions, bag true values, full predictions on cell-level.
177 | """
178 | for i in range(len(prediction_idx)):
179 | bag_pred[i] = bag_pred.get(i, []) + [pred_values[offset + i].cpu()]
180 | bag_true[i] = bag_true.get(i, []) + [true_values[:, i].cpu()]
181 | # TODO in ord reg had pred[len(self.mil.class_idx) + i].repeat(1, size).flatten()
182 | # in reg had
183 | # cell level, i.e. prediction for the cell = prediction for the bag
184 | full_pred[i] = full_pred.get(i, []) + [pred_values[offset + i].unsqueeze(1).repeat(1, size, 1).flatten(0, 1)]
185 | return bag_pred, bag_true, full_pred
186 |
187 |
188 | def get_bag_info(bags, n_samples_in_batch, minibatch_size, cell_counter, bag_counter, sample_batch_size):
189 | """Get bag information.
190 |
191 | Parameters
192 | ----------
193 | bags
194 | Bags.
195 | n_samples_in_batch
196 | Number of samples in the batch.
197 | minibatch_size
198 | Minibatch size.
199 | cell_counter
200 | Cell counter.
201 | bag_counter
202 | Bag counter.
203 | sample_batch_size
204 | Sample batch size.
205 |
206 | Returns
207 | -------
208 | Updated bags, cell counter, and bag counter.
209 | """
210 | if n_samples_in_batch == 1:
211 | bags += [[bag_counter] * minibatch_size]
212 | cell_counter += minibatch_size
213 | bag_counter += 1
214 | else:
215 | bags += [[bag_counter + i] * sample_batch_size for i in range(n_samples_in_batch)]
216 | bag_counter += n_samples_in_batch
217 | cell_counter += sample_batch_size * n_samples_in_batch
218 | return bags, cell_counter, bag_counter
219 |
220 |
221 | def save_predictions_in_adata(
222 | adata, idx, predictions, bag_pred, bag_true, cell_pred, class_names, name, clip, reg=False
223 | ):
224 | """Save predictions in anndata object.
225 |
226 | Parameters
227 | ----------
228 | adata
229 | Annotated data object.
230 | idx
231 | Index, i.e. obs_names.
232 | predictions
233 | Predictions.
234 | bag_pred
235 | Bag predictions.
236 | bag_true
237 | Bag true values.
238 | cell_pred
239 | Cell predictions.
240 | class_names
241 | Class names.
242 | name
243 | Name of the prediction column.
244 | clip
245 |         How to transform the predictions. One of `clip`, `argmax`, or `none`.
246 |     reg
247 |         Whether the prediction task is a regression task.
248 |     """
249 |     # cell level predictions
250 |
251 | if clip == "clip": # ord regression
252 | df = create_df(cell_pred[idx], [name], index=adata.obs_names)
253 | adata.obsm[f"full_predictions_{name}"] = df
254 | adata.obs[f"predicted_{name}"] = np.clip(np.round(df.to_numpy()), a_min=0.0, a_max=len(class_names) - 1.0)
255 | elif clip == "argmax": # classification
256 | df = create_df(cell_pred[idx], class_names, index=adata.obs_names)
257 | adata.obsm[f"full_predictions_{name}"] = df
258 | adata.obs[f"predicted_{name}"] = df.to_numpy().argmax(axis=1)
259 | else: # regression
260 | df = create_df(cell_pred[idx], [name], index=adata.obs_names)
261 | adata.obsm[f"full_predictions_{name}"] = df
262 | adata.obs[f"predicted_{name}"] = df.to_numpy()
263 | if reg is False:
264 | adata.obs[f"predicted_{name}"] = adata.obs[f"predicted_{name}"].astype("category")
265 | adata.obs[f"predicted_{name}"] = adata.obs[f"predicted_{name}"].cat.rename_categories(
266 | dict(enumerate(class_names))
267 | )
268 |
269 | # bag level predictions
270 | adata.uns[f"bag_true_{name}"] = create_df(bag_true, predictions)
271 | if clip == "clip": # ordinal regression
272 | df_bag = create_df(bag_pred[idx], [name])
273 | adata.uns[f"bag_full_predictions_{name}"] = np.clip(
274 | np.round(df_bag.to_numpy()), a_min=0.0, a_max=len(class_names) - 1.0
275 | )
276 | elif clip == "argmax": # classification
277 | df_bag = create_df(bag_pred[idx], class_names)
278 | adata.uns[f"bag_full_predictions_{name}"] = df_bag.to_numpy().argmax(axis=1)
279 | else: # regression
280 | df_bag = create_df(bag_pred[idx], [name])
281 | adata.uns[f"bag_full_predictions_{name}"] = df_bag.to_numpy()
282 |
283 |
284 | def plt_plot_losses(history, loss_names, save):
285 | """Plot losses.
286 |
287 | Parameters
288 | ----------
289 | history
290 | History of losses.
291 | loss_names
292 | Loss names to plot.
293 | save
294 | Path to save the plot.
295 | """
296 | df = pd.concat(history, axis=1)
297 | df.columns = df.columns.droplevel(-1)
298 | df["epoch"] = df.index
299 |
300 | nrows = ceil(len(loss_names) / 2)
301 |
302 | plt.figure(figsize=(15, 5 * nrows))
303 |
304 | for i, name in enumerate(loss_names):
305 | plt.subplot(nrows, 2, i + 1)
306 | plt.plot(df["epoch"], df[name + "_train"], ".-", label=name + "_train")
307 | plt.plot(df["epoch"], df[name + "_validation"], ".-", label=name + "_validation")
308 | plt.xlabel("epoch")
309 | plt.legend()
310 | if save is not None:
311 | plt.savefig(save, bbox_inches="tight")
312 |
--------------------------------------------------------------------------------
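A small usage sketch of the helpers above; the class names, sizes and cell names are invented for illustration.

    import torch
    from multimil.utils import create_df, prep_minibatch

    # collect two minibatches of class probabilities and label the DataFrame
    pred_batches = [torch.rand(4, 3), torch.rand(4, 3)]
    df = create_df(
        pred_batches,
        columns=["healthy", "mild", "severe"],
        index=[f"cell_{i}" for i in range(8)],
    )
    print(df.shape)  # (8, 3)

    # derive how many bags (samples) fit into a covariate minibatch
    covs = torch.zeros(128, 5)  # 128 cells, 5 covariates
    batch_size, n_samples_in_batch = prep_minibatch(covs, sample_batch_size=64)
    print(batch_size, n_samples_in_batch)  # 128 2
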
/tests/test_basic.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import multimil
4 |
5 |
6 | def test_package_has_version():
7 | assert multimil.__version__ is not None
8 |
9 |
10 | @pytest.mark.skip(reason="This decorator should be removed when test passes.")
11 | def test_example():
12 | assert 1 == 0 # This test is designed to fail.
13 |
--------------------------------------------------------------------------------