├── .codecov.yaml ├── .cruft.json ├── .editorconfig ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yml │ ├── config.yml │ └── feature_request.yml └── workflows │ ├── build.yaml │ ├── release.yaml │ └── test.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CHANGELOG.md ├── LICENSE ├── README.md ├── docs ├── Makefile ├── _static │ ├── .gitkeep │ └── css │ │ └── custom.css ├── _templates │ ├── .gitkeep │ └── autosummary │ │ └── class.rst ├── api.md ├── changelog.md ├── conf.py ├── contributing.md ├── extensions │ └── typed_returns.py ├── index.md ├── notebooks │ ├── cellttype_harmonize.json │ ├── mil_classification.ipynb │ ├── paired_integration_cite-seq.ipynb │ └── trimodal_integration.ipynb ├── references.bib ├── references.md └── tutorials.md ├── pyproject.toml ├── src └── multimil │ ├── __init__.py │ ├── data │ ├── __init__.py │ └── _preprocessing.py │ ├── dataloaders │ ├── __init__.py │ ├── _ann_dataloader.py │ └── _data_splitting.py │ ├── distributions │ ├── __init__.py │ ├── _jeffreys.py │ └── _mmd.py │ ├── model │ ├── __init__.py │ ├── _mil.py │ ├── _multivae.py │ └── _multivae_mil.py │ ├── module │ ├── __init__.py │ ├── _mil_torch.py │ ├── _multivae_mil_torch.py │ └── _multivae_torch.py │ ├── nn │ ├── __init__.py │ └── _base_components.py │ └── utils │ ├── __init__.py │ └── _utils.py └── tests └── test_basic.py /.codecov.yaml: -------------------------------------------------------------------------------- 1 | # Based on pydata/xarray 2 | codecov: 3 | require_ci_to_pass: no 4 | 5 | coverage: 6 | status: 7 | project: 8 | default: 9 | # Require 1% coverage, i.e., always succeed 10 | target: 1 11 | patch: false 12 | changes: false 13 | 14 | comment: 15 | layout: diff, flags, files 16 | behavior: once 17 | require_base: no 18 | -------------------------------------------------------------------------------- /.cruft.json: -------------------------------------------------------------------------------- 1 | { 2 | "template": "https://github.com/scverse/cookiecutter-scverse", 3 | "commit": "aa877587e59c855e2464d09ed7678af00999191a", 4 | "checkout": null, 5 | "context": { 6 | "cookiecutter": { 7 | "project_name": "multimil", 8 | "package_name": "multimil", 9 | "project_description": "Multimodal weakly supervised learning to identify disease-specific changes in single-cell atlases", 10 | "author_full_name": "Anastasia Litinetskaya", 11 | "author_email": "alitinet@gmail.com", 12 | "github_user": "alitinet", 13 | "project_repo": "https://github.com/theislab/multimil", 14 | "license": "BSD 3-Clause License", 15 | "_copy_without_render": [ 16 | ".github/workflows/build.yaml", 17 | ".github/workflows/test.yaml", 18 | "docs/_templates/autosummary/**.rst" 19 | ], 20 | "_render_devdocs": false, 21 | "_jinja2_env_vars": { 22 | "lstrip_blocks": true, 23 | "trim_blocks": true 24 | }, 25 | "_template": "https://github.com/scverse/cookiecutter-scverse" 26 | } 27 | }, 28 | "directory": null 29 | } 30 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | indent_style = space 5 | indent_size = 4 6 | end_of_line = lf 7 | charset = utf-8 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | 11 | [*.{yml,yaml}] 12 | indent_size = 2 13 | 14 | [.cruft.json] 15 | indent_size = 2 16 | 17 | [Makefile] 18 | indent_style = tab 19 | -------------------------------------------------------------------------------- 
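The `src/multimil` tree above maps directly onto the package's public import surface: each subpackage re-exports its public names in its `__init__.py` (see `src/multimil/__init__.py`, the subpackage `__init__.py` files and `docs/api.md` further below). A minimal sketch of that mapping, assuming `multimil` and its dependencies are installed; every name used here appears in the tree, in the `__init__.py` files or in `docs/api.md` listed below:

```python
# Minimal sketch: the src/ layout above, seen from the Python side.
# Assumes multimil is installed; names are taken from src/multimil/__init__.py,
# the subpackage __init__.py files and docs/api.md included later in this listing.
import multimil as mtm

print(mtm.__version__)                        # resolved via importlib.metadata in __init__.py

# Subpackages re-exported by src/multimil/__init__.py
print(mtm.data.organize_multimodal_anndatas)  # src/multimil/data/_preprocessing.py
print(mtm.dataloaders.GroupAnnDataLoader)     # src/multimil/dataloaders/_ann_dataloader.py
print(mtm.dataloaders.GroupDataSplitter)      # src/multimil/dataloaders/_data_splitting.py
print(mtm.model.MultiVAE)                     # documented in docs/api.md
print(mtm.model.MILClassifier)                # documented in docs/api.md
```

The model-facing classes under `model/` wrap the torch modules under `module/` (`MultiVAETorch`, `MILClassifierTorch`, `MultiVAETorch_MIL`), mirroring the model/module split used by scvi-tools, which the package depends on.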
/.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug report 2 | description: Report something that is broken or incorrect 3 | labels: bug 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | **Note**: Please read [this guide](https://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports) 9 | detailing how to provide the necessary information for us to reproduce your bug. In brief: 10 | * Please provide exact steps how to reproduce the bug in a clean Python environment. 11 | * In case it's not clear what's causing this bug, please provide the data or the data generation procedure. 12 | * Sometimes it is not possible to share the data, but usually it is possible to replicate problems on publicly 13 | available datasets or to share a subset of your data. 14 | 15 | - type: textarea 16 | id: report 17 | attributes: 18 | label: Report 19 | description: A clear and concise description of what the bug is. 20 | validations: 21 | required: true 22 | 23 | - type: textarea 24 | id: versions 25 | attributes: 26 | label: Version information 27 | description: | 28 | Please paste below the output of 29 | 30 | ```python 31 | import session_info 32 | session_info.show(html=False, dependencies=True) 33 | ``` 34 | placeholder: | 35 | ----- 36 | anndata 0.8.0rc2.dev27+ge524389 37 | session_info 1.0.0 38 | ----- 39 | asttokens NA 40 | awkward 1.8.0 41 | backcall 0.2.0 42 | cython_runtime NA 43 | dateutil 2.8.2 44 | debugpy 1.6.0 45 | decorator 5.1.1 46 | entrypoints 0.4 47 | executing 0.8.3 48 | h5py 3.7.0 49 | ipykernel 6.15.0 50 | jedi 0.18.1 51 | mpl_toolkits NA 52 | natsort 8.1.0 53 | numpy 1.22.4 54 | packaging 21.3 55 | pandas 1.4.2 56 | parso 0.8.3 57 | pexpect 4.8.0 58 | pickleshare 0.7.5 59 | pkg_resources NA 60 | prompt_toolkit 3.0.29 61 | psutil 5.9.1 62 | ptyprocess 0.7.0 63 | pure_eval 0.2.2 64 | pydev_ipython NA 65 | pydevconsole NA 66 | pydevd 2.8.0 67 | pydevd_file_utils NA 68 | pydevd_plugins NA 69 | pydevd_tracing NA 70 | pygments 2.12.0 71 | pytz 2022.1 72 | scipy 1.8.1 73 | setuptools 62.5.0 74 | setuptools_scm NA 75 | six 1.16.0 76 | stack_data 0.3.0 77 | tornado 6.1 78 | traitlets 5.3.0 79 | wcwidth 0.2.5 80 | zmq 23.1.0 81 | ----- 82 | IPython 8.4.0 83 | jupyter_client 7.3.4 84 | jupyter_core 4.10.0 85 | ----- 86 | Python 3.9.13 | packaged by conda-forge | (main, May 27 2022, 16:58:50) [GCC 10.3.0] 87 | Linux-5.18.6-arch1-1-x86_64-with-glibc2.35 88 | ----- 89 | Session information updated at 2022-07-07 17:55 90 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Scverse Community Forum 4 | url: https://discourse.scverse.org/ 5 | about: If you have questions about “How to do X”, please ask them here. 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: Feature request 2 | description: Propose a new feature for multimil 3 | labels: enhancement 4 | body: 5 | - type: textarea 6 | id: description 7 | attributes: 8 | label: Description of feature 9 | description: Please describe your suggestion for a new feature. It might help to describe a problem or use case, plus any alternatives that you have considered. 
10 | validations: 11 | required: true 12 | -------------------------------------------------------------------------------- /.github/workflows/build.yaml: -------------------------------------------------------------------------------- 1 | name: Check Build 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | package: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v3 18 | - name: Set up Python 3.10 19 | uses: actions/setup-python@v4 20 | with: 21 | python-version: "3.10" 22 | cache: "pip" 23 | cache-dependency-path: "**/pyproject.toml" 24 | - name: Install build dependencies 25 | run: python -m pip install --upgrade pip wheel twine build 26 | - name: Build package 27 | run: python -m build 28 | - name: Check package 29 | run: twine check --strict dist/*.whl 30 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | # Use "trusted publishing", see https://docs.pypi.org/trusted-publishers/ 8 | jobs: 9 | release: 10 | name: Upload release to PyPI 11 | runs-on: ubuntu-latest 12 | environment: 13 | name: pypi 14 | url: https://pypi.org/p/multimil 15 | permissions: 16 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing 17 | steps: 18 | - uses: actions/checkout@v4 19 | with: 20 | filter: blob:none 21 | fetch-depth: 0 22 | - uses: actions/setup-python@v4 23 | with: 24 | python-version: "3.x" 25 | cache: "pip" 26 | - run: pip install build 27 | - run: python -m build 28 | - name: Publish package distributions to PyPI 29 | uses: pypa/gh-action-pypi-publish@release/v1 30 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | schedule: 9 | - cron: "0 5 1,15 * *" 10 | 11 | concurrency: 12 | group: ${{ github.workflow }}-${{ github.ref }} 13 | cancel-in-progress: true 14 | 15 | jobs: 16 | test: 17 | runs-on: ${{ matrix.os }} 18 | defaults: 19 | run: 20 | shell: bash -e {0} # -e to fail on error 21 | 22 | strategy: 23 | fail-fast: false 24 | matrix: 25 | include: 26 | - os: ubuntu-latest 27 | python: "3.10" 28 | - os: ubuntu-latest 29 | python: "3.12" 30 | - os: ubuntu-latest 31 | python: "3.12" 32 | pip-flags: "--pre" 33 | name: PRE-RELEASE DEPENDENCIES 34 | 35 | name: ${{ matrix.name }} Python ${{ matrix.python }} 36 | 37 | env: 38 | OS: ${{ matrix.os }} 39 | PYTHON: ${{ matrix.python }} 40 | 41 | steps: 42 | - uses: actions/checkout@v3 43 | - name: Set up Python ${{ matrix.python }} 44 | uses: actions/setup-python@v4 45 | with: 46 | python-version: ${{ matrix.python }} 47 | cache: "pip" 48 | cache-dependency-path: "**/pyproject.toml" 49 | 50 | - name: Install test dependencies 51 | run: | 52 | python -m pip install --upgrade pip wheel 53 | - name: Install dependencies 54 | run: | 55 | pip install ${{ matrix.pip-flags }} ".[dev,test]" 56 | - name: Test 57 | env: 58 | MPLBACKEND: agg 59 | PLATFORM: ${{ matrix.os }} 60 | DISPLAY: :42 61 | run: | 62 | coverage run -m pytest -v --color=yes 63 | - name: Report coverage 64 | run: | 65 | coverage report 66 | 
- name: Upload coverage 67 | uses: codecov/codecov-action@v3 68 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Notebooks 2 | .ipynb_checkpoints 3 | 4 | # Logs 5 | scvi_log/ 6 | 7 | # Temp files 8 | .DS_Store 9 | *~ 10 | buck-out/ 11 | 12 | # Compiled files 13 | .venv/ 14 | __pycache__/ 15 | .mypy_cache/ 16 | .ruff_cache/ 17 | 18 | # Distribution / packaging 19 | /build/ 20 | /dist/ 21 | /*.egg-info/ 22 | 23 | # Tests and coverage 24 | /.pytest_cache/ 25 | /.cache/ 26 | /data/ 27 | /node_modules/ 28 | 29 | # docs 30 | /docs/generated/ 31 | /docs/_build/ 32 | 33 | # IDEs 34 | /.idea/ 35 | /.vscode/ 36 | 37 | # Data 38 | *.h5ad 39 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | fail_fast: false 2 | default_language_version: 3 | python: python3 4 | default_stages: 5 | - commit 6 | - push 7 | minimum_pre_commit_version: 2.16.0 8 | repos: 9 | - repo: https://github.com/pre-commit/mirrors-prettier 10 | rev: v4.0.0-alpha.8 11 | hooks: 12 | - id: prettier 13 | - repo: https://github.com/astral-sh/ruff-pre-commit 14 | rev: v0.5.3 15 | hooks: 16 | - id: ruff 17 | types_or: [python, pyi, jupyter] 18 | args: [--fix, --exit-non-zero-on-fix] 19 | - id: ruff-format 20 | types_or: [python, pyi, jupyter] 21 | - repo: https://github.com/pre-commit/pre-commit-hooks 22 | rev: v4.6.0 23 | hooks: 24 | - id: detect-private-key 25 | - id: check-ast 26 | - id: end-of-file-fixer 27 | - id: mixed-line-ending 28 | args: [--fix=lf] 29 | - id: trailing-whitespace 30 | - id: check-case-conflict 31 | # Check that there are no merge conflicts (could be generated by template sync) 32 | - id: check-merge-conflict 33 | args: [--assume-in-merge] 34 | - repo: local 35 | hooks: 36 | - id: forbid-to-commit 37 | name: Don't commit rej files 38 | entry: | 39 | Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates. 40 | Fix the merge conflicts manually and remove the .rej files. 41 | language: fail 42 | files: '.*\.rej$' 43 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # https://docs.readthedocs.io/en/stable/config-file/v2.html 2 | version: 2 3 | build: 4 | os: ubuntu-20.04 5 | tools: 6 | python: "3.10" 7 | sphinx: 8 | configuration: docs/conf.py 9 | # disable this for more lenient docs builds 10 | fail_on_warning: false 11 | python: 12 | install: 13 | - method: pip 14 | path: . 15 | extra_requirements: 16 | - doc 17 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog][], 6 | and this project adheres to [Semantic Versioning][]. 
7 | 8 | [keep a changelog]: https://keepachangelog.com/en/1.0.0/ 9 | [semantic versioning]: https://semver.org/spec/v2.0.0.html 10 | 11 | ## [Unreleased] 12 | 13 | ### Added 14 | 15 | - Basic tool, preprocessing and plotting functions 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024, Anastasia Litinetskaya 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multimodal weakly supervised learning to identify disease-specific changes in single-cell atlases 2 | 3 | [![Tests][badge-tests]][link-tests] 4 | [![Documentation][badge-docs]][link-docs] 5 | 6 | [badge-tests]: https://img.shields.io/github/actions/workflow/status/theislab/multimil/test.yaml?branch=main 7 | [link-tests]: https://github.com/theislab/multimil/actions/workflows/test.yaml 8 | [badge-docs]: https://img.shields.io/readthedocs/multimil 9 | [badge-colab]: https://colab.research.google.com/assets/colab-badge.svg 10 | 11 | ## Getting started 12 | 13 | Please refer to the [documentation][link-docs]. In particular, the 14 | 15 | - [API documentation][link-api] 16 | 17 | and the tutorials: 18 | 19 | 22 | 23 | - [Classification with MIL](https://multimil.readthedocs.io/en/latest/notebooks/mil_classification.html) [![Open In Colab][badge-colab]](https://colab.research.google.com/github/theislab/multimil/blob/main/docs/notebooks/mil_classification.ipynb) 24 | 25 | ## Installation 26 | 27 | You need to have Python 3.10 or newer installed on your system. We recommend installing [Mambaforge](https://github.com/conda-forge/miniforge#mambaforge).
28 | 29 | To create and activate a new environment: 30 | 31 | ```bash 32 | mamba create --name multimil python=3.10 33 | mamba activate multimil 34 | ``` 35 | 36 | Next, there are several alternative options to install multimil: 37 | 38 | 1. Install the latest release of `multimil` from [PyPI][link-pypi]: 39 | 40 | ```bash 41 | pip install multimil 42 | ``` 43 | 44 | 2. Or install the latest development version: 45 | 46 | ```bash 47 | pip install git+https://github.com/theislab/multimil.git@main 48 | ``` 49 | 50 | ## Release notes 51 | 52 | See the [changelog][changelog]. 53 | 54 | ## Contact 55 | 56 | If you found a bug, please use the [issue tracker][issue-tracker]. 57 | 58 | ## Citation 59 | 60 | > **Multimodal Weakly Supervised Learning to Identify Disease-Specific Changes in Single-Cell Atlases** 61 | > 62 | > Anastasia Litinetskaya, Maiia Shulman, Soroor Hediyeh-zadeh, Amir Ali Moinfar, Fabiola Curion, Artur Szalata, Alireza Omidi, Mohammad Lotfollahi, and Fabian J. Theis. 2024. bioRxiv. https://doi.org/10.1101/2024.07.29.605625. 63 | 64 | ## Reproducibility 65 | 66 | Code and notebooks to reproduce the results from the paper are available at https://github.com/theislab/multimil_reproducibility. 67 | 68 | [issue-tracker]: https://github.com/theislab/multimil/issues 69 | [changelog]: https://multimil.readthedocs.io/latest/changelog.html 70 | [link-docs]: https://multimil.readthedocs.io 71 | [link-api]: https://multimil.readthedocs.io/latest/api.html 72 | [link-pypi]: https://pypi.org/project/multimil 73 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/multimil/edfb3971cc2bc0302021eb3fa6175b11428bd903/docs/_static/.gitkeep -------------------------------------------------------------------------------- /docs/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | /* Reduce the font size in data frames - See https://github.com/scverse/cookiecutter-scverse/issues/193 */ 2 | div.cell_output table.dataframe { 3 | font-size: 0.8em; 4 | } 5 | -------------------------------------------------------------------------------- /docs/_templates/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/multimil/edfb3971cc2bc0302021eb3fa6175b11428bd903/docs/_templates/.gitkeep -------------------------------------------------------------------------------- /docs/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. add toctree option to make autodoc generate the pages 6 | 7 | .. autoclass:: {{ objname }} 8 | 9 | {% block attributes %} 10 | {% if attributes %} 11 | Attributes table 12 | ~~~~~~~~~~~~~~~~~~ 13 | 14 | .. autosummary:: 15 | {% for item in attributes %} 16 | ~{{ fullname }}.{{ item }} 17 | {%- endfor %} 18 | {% endif %} 19 | {% endblock %} 20 | 21 | {% block methods %} 22 | {% if methods %} 23 | Methods table 24 | ~~~~~~~~~~~~~ 25 | 26 | .. autosummary:: 27 | {% for item in methods %} 28 | {%- if item != '__init__' %} 29 | ~{{ fullname }}.{{ item }} 30 | {%- endif -%} 31 | {%- endfor %} 32 | {% endif %} 33 | {% endblock %} 34 | 35 | {% block attributes_documentation %} 36 | {% if attributes %} 37 | Attributes 38 | ~~~~~~~~~~~ 39 | 40 | {% for item in attributes %} 41 | 42 | .. autoattribute:: {{ [objname, item] | join(".") }} 43 | {%- endfor %} 44 | 45 | {% endif %} 46 | {% endblock %} 47 | 48 | {% block methods_documentation %} 49 | {% if methods %} 50 | Methods 51 | ~~~~~~~ 52 | 53 | {% for item in methods %} 54 | {%- if item != '__init__' %} 55 | 56 | .. automethod:: {{ [objname, item] | join(".") }} 57 | {%- endif -%} 58 | {%- endfor %} 59 | 60 | {% endif %} 61 | {% endblock %} 62 | -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | # API 2 | 3 | ## Model 4 | 5 | ```{eval-rst} 6 | .. module:: multimil.model 7 | .. currentmodule:: multimil 8 | 9 | .. autosummary:: 10 | :toctree: generated 11 | 12 | model.MultiVAE 13 | model.MILClassifier 14 | model.MultiVAE_MIL 15 | ``` 16 | 17 | ## Module 18 | 19 | ```{eval-rst} 20 | .. module:: multimil.module 21 | .. currentmodule:: multimil 22 | 23 | .. 
autosummary:: 24 | :toctree: generated 25 | 26 | module.MultiVAETorch 27 | module.MILClassifierTorch 28 | module.MultiVAETorch_MIL 29 | ``` 30 | -------------------------------------------------------------------------------- /docs/changelog.md: -------------------------------------------------------------------------------- 1 | ```{include} ../CHANGELOG.md 2 | 3 | ``` 4 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | import sys 9 | from datetime import datetime 10 | from importlib.metadata import metadata 11 | from pathlib import Path 12 | 13 | HERE = Path(__file__).parent 14 | sys.path.insert(0, str(HERE / "extensions")) 15 | 16 | 17 | # -- Project information ----------------------------------------------------- 18 | 19 | # NOTE: If you installed your project in editable mode, this might be stale. 20 | # If this is the case, reinstall it to refresh the metadata 21 | info = metadata("multimil") 22 | project_name = info["Name"] 23 | author = info["Author"] 24 | copyright = f"{datetime.now():%Y}, {author}." 25 | version = info["Version"] 26 | urls = dict(pu.split(", ") for pu in info.get_all("Project-URL")) 27 | repository_url = urls["Source"] 28 | 29 | # The full version, including alpha/beta/rc tags 30 | release = info["Version"] 31 | 32 | bibtex_bibfiles = ["references.bib"] 33 | templates_path = ["_templates"] 34 | nitpicky = True # Warn about broken links 35 | needs_sphinx = "4.0" 36 | 37 | html_context = { 38 | "display_github": True, # Integrate GitHub 39 | "github_user": "alitinet", 40 | "github_repo": "https://github.com/theislab/multimil", 41 | "github_version": "main", 42 | "conf_py_path": "/docs/", 43 | } 44 | 45 | # -- General configuration --------------------------------------------------- 46 | 47 | # Add any Sphinx extension module names here, as strings. 48 | # They can be extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 
49 | extensions = [ 50 | "myst_nb", 51 | "sphinx_copybutton", 52 | "sphinx.ext.autodoc", 53 | "sphinx.ext.intersphinx", 54 | "sphinx.ext.autosummary", 55 | "sphinx.ext.napoleon", 56 | "sphinxcontrib.bibtex", 57 | "sphinx_autodoc_typehints", 58 | "sphinx.ext.mathjax", 59 | "IPython.sphinxext.ipython_console_highlighting", 60 | "sphinxext.opengraph", 61 | *[p.stem for p in (HERE / "extensions").glob("*.py")], 62 | ] 63 | 64 | autosummary_generate = True 65 | autodoc_member_order = "groupwise" 66 | default_role = "literal" 67 | napoleon_google_docstring = False 68 | napoleon_numpy_docstring = True 69 | napoleon_include_init_with_doc = False 70 | napoleon_use_rtype = True # having a separate entry generally helps readability 71 | napoleon_use_param = True 72 | myst_heading_anchors = 6 # create anchors for h1-h6 73 | myst_enable_extensions = [ 74 | "amsmath", 75 | "colon_fence", 76 | "deflist", 77 | "dollarmath", 78 | "html_image", 79 | "html_admonition", 80 | ] 81 | myst_url_schemes = ("http", "https", "mailto") 82 | nb_output_stderr = "remove" 83 | nb_execution_mode = "off" 84 | nb_merge_streams = True 85 | typehints_defaults = "braces" 86 | 87 | source_suffix = { 88 | ".rst": "restructuredtext", 89 | ".ipynb": "myst-nb", 90 | ".myst": "myst-nb", 91 | } 92 | 93 | intersphinx_mapping = { 94 | "python": ("https://docs.python.org/3", None), 95 | "anndata": ("https://anndata.readthedocs.io/en/stable/", None), 96 | "numpy": ("https://numpy.org/doc/stable/", None), 97 | } 98 | 99 | # List of patterns, relative to source directory, that match files and 100 | # directories to ignore when looking for source files. 101 | # This pattern also affects html_static_path and html_extra_path. 102 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"] 103 | 104 | 105 | # -- Options for HTML output ------------------------------------------------- 106 | 107 | # The theme to use for HTML and HTML Help pages. See the documentation for 108 | # a list of builtin themes. 109 | # 110 | html_theme = "sphinx_book_theme" 111 | html_static_path = ["_static"] 112 | html_css_files = ["css/custom.css"] 113 | 114 | html_title = project_name 115 | 116 | html_theme_options = { 117 | "repository_url": repository_url, 118 | "use_repository_button": True, 119 | "path_to_docs": "docs/", 120 | "navigation_with_keys": False, 121 | } 122 | 123 | pygments_style = "default" 124 | 125 | nitpick_ignore = [ 126 | # If building the documentation fails because of a missing link that is outside your control, 127 | # you can add an exception to this list. 128 | # ("py:class", "igraph.Graph"), 129 | ] 130 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing guide 2 | 3 | Scanpy provides extensive [developer documentation][scanpy developer guide], most of which applies to this project, too. 4 | This document will not reproduce the entire content from there. Instead, it aims at summarizing the most important 5 | information to get you started on contributing. 6 | 7 | We assume that you are already familiar with git and with making pull requests on GitHub. If not, please refer 8 | to the [scanpy developer guide][]. 9 | 10 | ## Installing dev dependencies 11 | 12 | In addition to the packages needed to _use_ this package, you need additional python packages to _run tests_ and _build 13 | the documentation_. 
It's easy to install them using `pip`: 14 | 15 | ```bash 16 | cd multimil 17 | pip install -e ".[dev,test,doc]" 18 | ``` 19 | 20 | ## Code-style 21 | 22 | This package uses [pre-commit][] to enforce consistent code-styles. 23 | On every commit, pre-commit checks will either automatically fix issues with the code, or raise an error message. 24 | 25 | To enable pre-commit locally, simply run 26 | 27 | ```bash 28 | pre-commit install 29 | ``` 30 | 31 | in the root of the repository. Pre-commit will automatically download all dependencies when it is run for the first time. 32 | 33 | Alternatively, you can rely on the [pre-commit.ci][] service enabled on GitHub. If you didn't run `pre-commit` before 34 | pushing changes to GitHub it will automatically commit fixes to your pull request, or show an error message. 35 | 36 | If pre-commit.ci added a commit on a branch you still have been working on locally, simply use 37 | 38 | ```bash 39 | git pull --rebase 40 | ``` 41 | 42 | to integrate the changes into yours. 43 | While the [pre-commit.ci][] is useful, we strongly encourage installing and running pre-commit locally first to understand its usage. 44 | 45 | Finally, most editors have an _autoformat on save_ feature. Consider enabling this option for [ruff][ruff-editors] 46 | and [prettier][prettier-editors]. 47 | 48 | [ruff-editors]: https://docs.astral.sh/ruff/integrations/ 49 | [prettier-editors]: https://prettier.io/docs/en/editors.html 50 | 51 | ## Writing tests 52 | 53 | ```{note} 54 | Remember to first install the package with `pip install -e '.[dev,test]'` 55 | ``` 56 | 57 | This package uses the [pytest][] for automated testing. Please [write tests][scanpy-test-docs] for every function added 58 | to the package. 59 | 60 | Most IDEs integrate with pytest and provide a GUI to run tests. Alternatively, you can run all tests from the 61 | command line by executing 62 | 63 | ```bash 64 | pytest 65 | ``` 66 | 67 | in the root of the repository. 68 | 69 | ### Continuous integration 70 | 71 | Continuous integration will automatically run the tests on all pull requests and test 72 | against the minimum and maximum supported Python version. 73 | 74 | Additionally, there's a CI job that tests against pre-releases of all dependencies 75 | (if there are any). The purpose of this check is to detect incompatibilities 76 | of new package versions early on and gives you time to fix the issue or reach 77 | out to the developers of the dependency before the package is released to a wider audience. 78 | 79 | [scanpy-test-docs]: https://scanpy.readthedocs.io/en/latest/dev/testing.html#writing-tests 80 | 81 | ## Publishing a release 82 | 83 | ### Updating the version number 84 | 85 | Before making a release, you need to update the version number in the `pyproject.toml` file. Please adhere to [Semantic Versioning][semver], in brief 86 | 87 | > Given a version number MAJOR.MINOR.PATCH, increment the: 88 | > 89 | > 1. MAJOR version when you make incompatible API changes, 90 | > 2. MINOR version when you add functionality in a backwards compatible manner, and 91 | > 3. PATCH version when you make backwards compatible bug fixes. 92 | > 93 | > Additional labels for pre-release and build metadata are available as extensions to the MAJOR.MINOR.PATCH format. 94 | 95 | Once you are done, commit and push your changes and navigate to the "Releases" page of this project on GitHub. 96 | Specify `vX.X.X` as a tag name and create a release. For more information, see [managing GitHub releases][]. 
This will automatically create a git tag and trigger a GitHub workflow that creates a release on PyPI. 97 | 98 | ## Writing documentation 99 | 100 | Please write documentation for new or changed features and use cases. This project uses [sphinx][] with the following features: 101 | 102 | - the [myst][] extension allows you to write documentation in markdown/Markedly Structured Text 103 | - [Numpy-style docstrings][numpydoc] (through the [napoleon][numpydoc-napoleon] extension). 104 | - Jupyter notebooks as tutorials through [myst-nb][] (see [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks)) 105 | - [Sphinx autodoc typehints][], to automatically reference annotated input and output types 106 | - Citations (like {cite:p}`Virshup_2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/) 107 | 108 | See the [scanpy developer docs](https://scanpy.readthedocs.io/en/latest/dev/documentation.html) for more information 109 | on how to write documentation. 110 | 111 | ### Tutorials with myst-nb and jupyter notebooks 112 | 113 | The documentation is set up to render Jupyter notebooks stored in the `docs/notebooks` directory using [myst-nb][]. 114 | Currently, only notebooks in `.ipynb` format are supported; they will be included with both their input and output cells. 115 | It is your responsibility to update and re-run the notebook whenever necessary. 116 | 117 | If you are interested in automatically running notebooks as part of the continuous integration, please check 118 | out [this feature request](https://github.com/scverse/cookiecutter-scverse/issues/40) in the `cookiecutter-scverse` 119 | repository. 120 | 121 | #### Hints 122 | 123 | - If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py`. Only 124 | if you do so can sphinx automatically create a link to the external documentation.
125 | - If building the documentation fails because of a missing link that is outside your control, you can add an entry to 126 | the `nitpick_ignore` list in `docs/conf.py` 127 | 128 | #### Building the docs locally 129 | 130 | ```bash 131 | cd docs 132 | make html 133 | open _build/html/index.html 134 | ``` 135 | 136 | 137 | 138 | [scanpy developer guide]: https://scanpy.readthedocs.io/en/latest/dev/index.html 139 | [cookiecutter-scverse-instance]: https://cookiecutter-scverse-instance.readthedocs.io/en/latest/template_usage.html 140 | [github quickstart guide]: https://docs.github.com/en/get-started/quickstart/create-a-repo?tool=webui 141 | [codecov]: https://about.codecov.io/sign-up/ 142 | [codecov docs]: https://docs.codecov.com/docs 143 | [codecov bot]: https://docs.codecov.com/docs/team-bot 144 | [codecov app]: https://github.com/apps/codecov 145 | [pre-commit.ci]: https://pre-commit.ci/ 146 | [readthedocs.org]: https://readthedocs.org/ 147 | [myst-nb]: https://myst-nb.readthedocs.io/en/latest/ 148 | [jupytext]: https://jupytext.readthedocs.io/en/latest/ 149 | [pre-commit]: https://pre-commit.com/ 150 | [anndata]: https://github.com/scverse/anndata 151 | [mudata]: https://github.com/scverse/mudata 152 | [pytest]: https://docs.pytest.org/ 153 | [semver]: https://semver.org/ 154 | [sphinx]: https://www.sphinx-doc.org/en/master/ 155 | [myst]: https://myst-parser.readthedocs.io/en/latest/intro.html 156 | [numpydoc-napoleon]: https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html 157 | [numpydoc]: https://numpydoc.readthedocs.io/en/latest/format.html 158 | [sphinx autodoc typehints]: https://github.com/tox-dev/sphinx-autodoc-typehints 159 | [pypi]: https://pypi.org/ 160 | [managing GitHub releases]: https://docs.github.com/en/repositories/releasing-projects-on-github/managing-releases-in-a-repository 161 | -------------------------------------------------------------------------------- /docs/extensions/typed_returns.py: -------------------------------------------------------------------------------- 1 | # code from https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py 2 | # with some minor adjustment 3 | from __future__ import annotations 4 | 5 | import re 6 | from collections.abc import Generator, Iterable 7 | 8 | from sphinx.application import Sphinx 9 | from sphinx.ext.napoleon import NumpyDocstring 10 | 11 | 12 | def _process_return(lines: Iterable[str]) -> Generator[str, None, None]: 13 | for line in lines: 14 | if m := re.fullmatch(r"(?P\w+)\s+:\s+(?P[\w.]+)", line): 15 | yield f'-{m["param"]} (:class:`~{m["type"]}`)' 16 | else: 17 | yield line 18 | 19 | 20 | def _parse_returns_section(self: NumpyDocstring, section: str) -> list[str]: 21 | lines_raw = self._dedent(self._consume_to_next_section()) 22 | if lines_raw[0] == ":": 23 | del lines_raw[0] 24 | lines = self._format_block(":returns: ", list(_process_return(lines_raw))) 25 | if lines and lines[-1]: 26 | lines.append("") 27 | return lines 28 | 29 | 30 | def setup(app: Sphinx): 31 | """Set app.""" 32 | NumpyDocstring._parse_returns_section = _parse_returns_section 33 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | ```{include} ../README.md 2 | 3 | ``` 4 | 5 | ```{toctree} 6 | :hidden: true 7 | :maxdepth: 2 8 | 9 | api.md 10 | tutorials.md 11 | changelog.md 12 | contributing.md 13 | references.md 14 | ``` 15 | 
-------------------------------------------------------------------------------- /docs/notebooks/cellttype_harmonize.json: -------------------------------------------------------------------------------- 1 | { 2 | "cite_ct_l1_map": { 3 | "Naive CD20+ B IGKC+": "B", 4 | "Naive CD20+ B IGKC-": "B", 5 | "B1 B IGKC-": "B", 6 | "B1 B IGKC+": "B", 7 | "Transitional B": "B", 8 | "G/M prog": "Myeloid", 9 | "CD14+ Mono": "Myeloid", 10 | "pDC": "Myeloid", 11 | "cDC2": "Myeloid", 12 | "CD16+ Mono": "Myeloid", 13 | "cDC1": "Myeloid", 14 | "HSC": "HSC", 15 | "Reticulocyte": "Erythroid", 16 | "Normoblast": "Erythroid", 17 | "Erythroblast": "Erythroid", 18 | "Proerythroblast": "Erythroid", 19 | "MK/E prog": "Erythroid", 20 | "NK CD158e1+": "NK", 21 | "NK": "NK", 22 | "CD4+ T naive": "T", 23 | "CD4+ T activated": "T", 24 | "CD4+ T activated integrinB7+": "T", 25 | "CD4+ T CD314+ CD45RA+": "T", 26 | "CD8+ T naive": "T", 27 | "CD8+ T CD49f+": "T", 28 | "CD8+ T TIGIT+ CD45RO+": "T", 29 | "CD8+ T CD57+ CD45RA+": "T", 30 | "CD8+ T CD69+ CD45RO+": "T", 31 | "CD8+ T TIGIT+ CD45RA+": "T", 32 | "CD8+ T CD69+ CD45RA+": "T", 33 | "CD8+ T naive CD127+ CD26- CD101-": "T", 34 | "CD8+ T CD57+ CD45RO+": "T", 35 | "MAIT": "T", 36 | "T reg": "T", 37 | "gdT TCRVD2+": "T", 38 | "gdT CD158b+": "T", 39 | "dnT": "T", 40 | "T prog cycling": "T", 41 | "ILC1": "T", 42 | "Lymph prog": "Other Lympoid", 43 | "ILC": "Other Lympoid", 44 | "Plasmablast IGKC+": "Plasma", 45 | "Plasmablast IGKC-": "Plasma", 46 | "Plasma cell IGKC+": "Plasma", 47 | "Plasma cell IGKC-": "Plasma" 48 | }, 49 | "cite_ct_l2_map": { 50 | "Naive CD20+ B IGKC+": "Naive CD20+ B", 51 | "Naive CD20+ B IGKC-": "Naive CD20+ B", 52 | "B1 B IGKC-": "B1 B", 53 | "B1 B IGKC+": "B1 B", 54 | "Transitional B": "Transitional B", 55 | "G/M prog": "G/M prog", 56 | "CD14+ Mono": "CD14+ Mono", 57 | "pDC": "pDC", 58 | "cDC2": "cDC2", 59 | "CD16+ Mono": "CD16+ Mono", 60 | "cDC1": "cDC1", 61 | "HSC": "HSC", 62 | "Reticulocyte": "Reticulocyte", 63 | "Normoblast": "Normoblast", 64 | "Erythroblast": "Erythroblast", 65 | "Proerythroblast": "Proerythroblast", 66 | "MK/E prog": "MK/E prog", 67 | "NK CD158e1+": "NK", 68 | "NK": "NK", 69 | "CD4+ T naive": "CD4+ T naive", 70 | "CD4+ T activated": "CD4+ T activated", 71 | "CD4+ T activated integrinB7+": "CD4+ T activated", 72 | "CD4+ T CD314+ CD45RA+": "CD4+ T activated", 73 | "CD8+ T naive": "CD8+ T naive", 74 | "CD8+ T naive CD127+ CD26- CD101-": "CD8+ T naive", 75 | "CD8+ T CD49f+": "CD8+ T activated", 76 | "CD8+ T TIGIT+ CD45RO+": "CD8+ T activated", 77 | "CD8+ T CD57+ CD45RA+": "CD8+ T activated", 78 | "CD8+ T CD69+ CD45RO+": "CD8+ T activated", 79 | "CD8+ T TIGIT+ CD45RA+": "CD8+ T activated", 80 | "CD8+ T CD69+ CD45RA+": "CD8+ T activated", 81 | "CD8+ T CD57+ CD45RO+": "CD8+ T activated", 82 | "MAIT": "Other T", 83 | "T reg": "Other T", 84 | "gdT TCRVD2+": "Other T", 85 | "gdT CD158b+": "Other T", 86 | "dnT": "Other T", 87 | "T prog cycling": "Other T", 88 | "ILC1": "Other T", 89 | "Lymph prog": "Early Lymphoid", 90 | "ILC": "ILC", 91 | "Plasmablast IGKC+": "Plasma", 92 | "Plasmablast IGKC-": "Plasma", 93 | "Plasma cell IGKC+": "Plasma", 94 | "Plasma cell IGKC-": "Plasma" 95 | }, 96 | "multi_ct_l1_map": { 97 | "G/M prog": "Myeloid", 98 | "CD14+ Mono": "Myeloid", 99 | "CD16+ Mono": "Myeloid", 100 | "pDC": "Myeloid", 101 | "cDC2": "Myeloid", 102 | "ID2-hi myeloid prog": "Myeloid", 103 | "CD8+ T": "T", 104 | "CD8+ T naive": "T", 105 | "CD4+ T naive": "T", 106 | "CD4+ T activated": "T", 107 | "NK": "NK", 108 | "ILC": "Other Lympoid", 109 | 
"Lymph prog": "Other Lympoid", 110 | "Transitional B": "B", 111 | "Naive CD20+ B": "B", 112 | "B1 B": "B", 113 | "Plasma cell": "Plasma", 114 | "MK/E prog": "Erythroid", 115 | "Proerythroblast": "Erythroid", 116 | "Erythroblast": "Erythroid", 117 | "Normoblast": "Erythroid", 118 | "HSC": "HSC" 119 | }, 120 | "multi_ct_l2_map": { 121 | "G/M prog": "G/M prog", 122 | "CD14+ Mono": "CD14+ Mono", 123 | "CD16+ Mono": "CD16+ Mono", 124 | "pDC": "pDC", 125 | "cDC2": "cDC2", 126 | "ID2-hi myeloid prog": "Other Myeloid", 127 | "CD8+ T": "CD8+ T activated", 128 | "CD8+ T naive": "CD8+ T naive", 129 | "CD4+ T naive": "CD4+ T naive", 130 | "CD4+ T activated": "CD4+ T activated", 131 | "NK": "NK", 132 | "ILC": "ILC", 133 | "Lymph prog": "Early Lymphoid", 134 | "Transitional B": "Transitional B", 135 | "Naive CD20+ B": "Naive CD20+ B", 136 | "B1 B": "B1 B", 137 | "Plasma cell": "Plasma", 138 | "MK/E prog": "MK/E prog", 139 | "Proerythroblast": "Proerythroblast", 140 | "Erythroblast": "Erythroblast", 141 | "Normoblast": "Normoblast", 142 | "HSC": "HSC" 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /docs/references.bib: -------------------------------------------------------------------------------- 1 | @article{Virshup_2023, 2 | doi = {10.1038/s41587-023-01733-8}, 3 | url = {https://doi.org/10.1038%2Fs41587-023-01733-8}, 4 | year = 2023, 5 | month = {apr}, 6 | publisher = {Springer Science and Business Media {LLC}}, 7 | author = {Isaac Virshup and Danila Bredikhin and Lukas Heumos and Giovanni Palla and Gregor Sturm and Adam Gayoso and Ilia Kats and Mikaela Koutrouli and Philipp Angerer and Volker Bergen and Pierre Boyeau and Maren Büttner and Gokcen Eraslan and David Fischer and Max Frank and Justin Hong and Michal Klein and Marius Lange and Romain Lopez and Mohammad Lotfollahi and Malte D. Luecken and Fidel Ramirez and Jeffrey Regier and Sergei Rybakov and Anna C. Schaar and Valeh Valiollah Pour Amiri and Philipp Weiler and Galen Xing and Bonnie Berger and Dana Pe'er and Aviv Regev and Sarah A. Teichmann and Francesca Finotello and F. Alexander Wolf and Nir Yosef and Oliver Stegle and Fabian J. Theis and}, 8 | title = {The scverse project provides a computational ecosystem for single-cell omics data analysis}, 9 | journal = {Nature Biotechnology} 10 | } 11 | 12 | @inproceedings{Luecken2021-ct, 13 | author = {Luecken, Malte and Burkhardt, Daniel and Cannoodt, Robrecht and Lance, Christopher and Agrawal, Aditi and Aliee, Hananeh and Chen, Ann and Deconinck, Louise and Detweiler, Angela and Granados, Alejandro and Huynh, Shelly and Isacco, Laura and Kim, Yang and Klein, Dominik and DE KUMAR, BONY and Kuppasani, Sunil and Lickert, Heiko and McGeever, Aaron and Melgarejo, Joaquin and Mekonen, Honey and Morri, Maurizio and M\"{u}ller, Michaela and Neff, Norma and Paul, Sheryl and Rieck, Bastian and Schneider, Kaylie and Steelman, Scott and Sterr, Michael and Treacy, Daniel and Tong, Alexander and Villani, Alexandra-Chloe and Wang, Guilin and Yan, Jia and Zhang, Ce and Pisco, Angela and Krishnaswamy, Smita and Theis, Fabian and Bloom, Jonathan M}, 14 | booktitle = {Proceedings of the Neural Information Processing Systems Track on Datasets and Benchmarks}, 15 | editor = {J. Vanschoren and S. 
Yeung}, 16 | pages = {}, 17 | title = {A sandbox for prediction and integration of DNA, RNA, and proteins in single cells}, 18 | url = {https://datasets-benchmarks-proceedings.neurips.cc/paper_files/paper/2021/file/158f3069a435b314a80bdcb024f8e422-Paper-round2.pdf}, 19 | volume = {1}, 20 | year = {2021} 21 | } 22 | 23 | @ARTICLE{Sikkema2023-oh, 24 | title = "An integrated cell atlas of the lung in health and disease", 25 | author = "Sikkema, Lisa and Ramírez-Suástegui, Ciro and Strobl, Daniel C and 26 | Gillett, Tessa E and Zappia, Luke and Madissoon, Elo and Markov, 27 | Nikolay S and Zaragosi, Laure-Emmanuelle and Ji, Yuge and Ansari, 28 | Meshal and Arguel, Marie-Jeanne and Apperloo, Leonie and Banchero, 29 | Martin and Bécavin, Christophe and Berg, Marijn and 30 | Chichelnitskiy, Evgeny and Chung, Mei-I and Collin, Antoine and 31 | Gay, Aurore C A and Gote-Schniering, Janine and Hooshiar Kashani, 32 | Baharak and Inecik, Kemal and Jain, Manu and Kapellos, Theodore S 33 | and Kole, Tessa M and Leroy, Sylvie and Mayr, Christoph H and 34 | Oliver, Amanda J and von Papen, Michael and Peter, Lance and 35 | Taylor, Chase J and Walzthoeni, Thomas and Xu, Chuan and Bui, Linh 36 | T and De Donno, Carlo and Dony, Leander and Faiz, Alen and Guo, 37 | Minzhe and Gutierrez, Austin J and Heumos, Lukas and Huang, Ni and 38 | Ibarra, Ignacio L and Jackson, Nathan D and Kadur Lakshminarasimha 39 | Murthy, Preetish and Lotfollahi, Mohammad and Tabib, Tracy and 40 | Talavera-López, Carlos and Travaglini, Kyle J and Wilbrey-Clark, 41 | Anna and Worlock, Kaylee B and Yoshida, Masahiro and {Lung 42 | Biological Network Consortium} and van den Berge, Maarten and 43 | Bossé, Yohan and Desai, Tushar J and Eickelberg, Oliver and 44 | Kaminski, Naftali and Krasnow, Mark A and Lafyatis, Robert and 45 | Nikolic, Marko Z and Powell, Joseph E and Rajagopal, Jayaraj and 46 | Rojas, Mauricio and Rozenblatt-Rosen, Orit and Seibold, Max A and 47 | Sheppard, Dean and Shepherd, Douglas P and Sin, Don D and Timens, 48 | Wim and Tsankov, Alexander M and Whitsett, Jeffrey and Xu, Yan and 49 | Banovich, Nicholas E and Barbry, Pascal and Duong, Thu Elizabeth 50 | and Falk, Christine S and Meyer, Kerstin B and Kropski, Jonathan A 51 | and Pe'er, Dana and Schiller, Herbert B and Tata, Purushothama Rao 52 | and Schultze, Joachim L and Teichmann, Sara A and Misharin, 53 | Alexander V and Nawijn, Martijn C and Luecken, Malte D and Theis, 54 | Fabian J", 55 | journal = "Nat. Med.", 56 | volume = 29, 57 | number = 6, 58 | pages = "1563--1577", 59 | abstract = "Single-cell technologies have transformed our understanding of 60 | human tissues. Yet, studies typically capture only a limited 61 | number of donors and disagree on cell type definitions. 62 | Integrating many single-cell datasets can address these 63 | limitations of individual studies and capture the variability 64 | present in the population. Here we present the integrated Human 65 | Lung Cell Atlas (HLCA), combining 49 datasets of the human 66 | respiratory system into a single atlas spanning over 2.4 million 67 | cells from 486 individuals. The HLCA presents a consensus cell 68 | type re-annotation with matching marker genes, including 69 | annotations of rare and previously undescribed cell types. 
70 | Leveraging the number and diversity of individuals in the HLCA, we 71 | identify gene modules that are associated with demographic 72 | covariates such as age, sex and body mass index, as well as gene 73 | modules changing expression along the proximal-to-distal axis of 74 | the bronchial tree. Mapping new data to the HLCA enables rapid 75 | data annotation and interpretation. Using the HLCA as a reference 76 | for the study of disease, we identify shared cell states across 77 | multiple lung diseases, including SPP1+ profibrotic 78 | monocyte-derived macrophages in COVID-19, pulmonary fibrosis and 79 | lung carcinoma. Overall, the HLCA serves as an example for the 80 | development and use of large-scale, cross-dataset organ atlases 81 | within the Human Cell Atlas.", 82 | month = jun, 83 | year = 2023, 84 | language = "en" 85 | } 86 | 87 | @ARTICLE{Lotfollahi2022-jw, 88 | title = "Mapping single-cell data to reference atlases by transfer learning", 89 | author = "Lotfollahi, Mohammad and Naghipourfar, Mohsen and Luecken, Malte D 90 | and Khajavi, Matin and Büttner, Maren and Wagenstetter, Marco and 91 | Avsec, Žiga and Gayoso, Adam and Yosef, Nir and Interlandi, Marta 92 | and Rybakov, Sergei and Misharin, Alexander V and Theis, Fabian J", 93 | journal = "Nat. Biotechnol.", 94 | volume = 40, 95 | number = 1, 96 | pages = "121--130", 97 | abstract = "Large single-cell atlases are now routinely generated to serve as 98 | references for analysis of smaller-scale studies. Yet learning 99 | from reference data is complicated by batch effects between 100 | datasets, limited availability of computational resources and 101 | sharing restrictions on raw data. Here we introduce a deep 102 | learning strategy for mapping query datasets on top of a reference 103 | called single-cell architectural surgery (scArches). scArches uses 104 | transfer learning and parameter optimization to enable efficient, 105 | decentralized, iterative reference building and contextualization 106 | of new datasets with existing references without sharing raw data. 107 | Using examples from mouse brain, pancreas, immune and 108 | whole-organism atlases, we show that scArches preserves biological 109 | state information while removing batch effects, despite using four 110 | orders of magnitude fewer parameters than de novo integration. 111 | scArches generalizes to multimodal reference mapping, allowing 112 | imputation of missing modalities. Finally, scArches retains 113 | coronavirus disease 2019 (COVID-19) disease variation when mapping 114 | to a healthy reference, enabling the discovery of disease-specific 115 | cell states. 
scArches will facilitate collaborative projects by 116 | enabling iterative construction, updating, sharing and efficient 117 | use of reference atlases.", 118 | month = jan, 119 | year = 2022, 120 | language = "en" 121 | } 122 | -------------------------------------------------------------------------------- /docs/references.md: -------------------------------------------------------------------------------- 1 | # References 2 | 3 | ```{bibliography} 4 | :cited: 5 | ``` 6 | -------------------------------------------------------------------------------- /docs/tutorials.md: -------------------------------------------------------------------------------- 1 | # Tutorials 2 | 3 | 11 | 12 | 13 | 14 | ```{toctree} 15 | :maxdepth: 1 16 | 17 | notebooks/mil_classification 18 | ``` 19 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "hatchling.build" 3 | requires = ["hatchling"] 4 | 5 | [project] 6 | name = "multimil" 7 | version = "0.2.0" 8 | description = "Multimodal weakly supervised learning to identify disease-specific changes in single-cell atlases" 9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | license = {file = "LICENSE"} 12 | authors = [ 13 | {name = "Anastasia Litinetskaya"}, 14 | ] 15 | maintainers = [ 16 | {name = "Anastasia Litinetskaya", email = "alitinet@gmail.com"}, 17 | ] 18 | urls.Documentation = "https://multimil.readthedocs.io/" 19 | urls.Source = "https://github.com/theislab/multimil" 20 | urls.Home-page = "https://github.com/theislab/multimil" 21 | dependencies = [ 22 | "anndata<0.11", 23 | "numpy<2.0", 24 | # for debug logging (referenced from the issue template) 25 | "session-info", 26 | # multimil specific 27 | "scanpy", 28 | "scvi-tools<1.0.0", 29 | "matplotlib", 30 | ] 31 | 32 | [project.optional-dependencies] 33 | dev = [ 34 | "pre-commit", 35 | "twine>=4.0.2", 36 | ] 37 | doc = [ 38 | "docutils>=0.8,!=0.18.*,!=0.19.*", 39 | "sphinx>=4", 40 | "sphinx-book-theme>=1.0.0", 41 | "myst-nb>=1.1.0", 42 | "sphinxcontrib-bibtex>=1.0.0", 43 | "setuptools", # Until pybtex >0.23.0 releases: https://bitbucket.org/pybtex-devs/pybtex/issues/169/ 44 | "sphinx-autodoc-typehints", 45 | "sphinxext-opengraph", 46 | # For notebooks 47 | "ipykernel", 48 | "ipython", 49 | "sphinx-copybutton", 50 | "pandas", 51 | ] 52 | test = [ 53 | "pytest", 54 | "coverage", 55 | ] 56 | tutorials = [ 57 | "muon", 58 | "jupyterlab", 59 | "ipywidgets", 60 | "leidenalg", 61 | "igraph", 62 | "gdown", 63 | ] 64 | 65 | [tool.coverage.run] 66 | source = ["multimil"] 67 | omit = [ 68 | "**/test_*.py", 69 | ] 70 | 71 | [tool.pytest.ini_options] 72 | testpaths = ["tests"] 73 | xfail_strict = true 74 | addopts = [ 75 | "--import-mode=importlib", # allow using test files with same name 76 | ] 77 | 78 | [tool.ruff] 79 | line-length = 120 80 | src = ["src"] 81 | extend-include = ["*.ipynb"] 82 | 83 | [tool.ruff.format] 84 | docstring-code-format = true 85 | 86 | [tool.ruff.lint] 87 | select = [ 88 | "F", # Errors detected by Pyflakes 89 | "E", # Error detected by Pycodestyle 90 | "W", # Warning detected by Pycodestyle 91 | "I", # isort 92 | "D", # pydocstyle 93 | "B", # flake8-bugbear 94 | "TID", # flake8-tidy-imports 95 | "C4", # flake8-comprehensions 96 | "BLE", # flake8-blind-except 97 | "UP", # pyupgrade 98 | "RUF100", # Report unused noqa directives 99 | ] 100 | ignore = [ 101 | # line too long -> we accept long comment lines; 
formatter gets rid of long code lines 102 | "E501", 103 | # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient 104 | "E731", 105 | # allow I, O, l as variable names -> I is the identity matrix 106 | "E741", 107 | # Missing docstring in public package 108 | "D104", 109 | # Missing docstring in public module 110 | "D100", 111 | # Missing docstring in __init__ 112 | "D107", 113 | # Errors from function calls in argument defaults. These are fine when the result is immutable. 114 | "B008", 115 | # __magic__ methods are often self-explanatory, allow missing docstrings 116 | "D105", 117 | # first line should end with a period [Bug: doesn't work with single-line docstrings] 118 | "D400", 119 | # First line should be in imperative mood; try rephrasing 120 | "D401", 121 | ## Disable one in each pair of mutually incompatible rules 122 | # We don’t want a blank line before a class docstring 123 | "D203", 124 | # We want docstrings to start immediately after the opening triple quote 125 | "D213", 126 | ] 127 | 128 | [tool.ruff.lint.pydocstyle] 129 | convention = "numpy" 130 | 131 | [tool.ruff.lint.per-file-ignores] 132 | "docs/*" = ["I"] 133 | "tests/*" = ["D"] 134 | "*/__init__.py" = ["F401"] 135 | "src/multimil/module/__init__.py" = ["I"] 136 | 137 | [tool.cruft] 138 | skip = [ 139 | "tests", 140 | "src/**/__init__.py", 141 | "src/**/basic.py", 142 | "docs/api.md", 143 | "docs/changelog.md", 144 | "docs/references.bib", 145 | "docs/references.md", 146 | "docs/notebooks/example.ipynb", 147 | ] 148 | -------------------------------------------------------------------------------- /src/multimil/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib.metadata import version 2 | 3 | from . import data, dataloaders, distributions, model, module, nn, utils 4 | 5 | __all__ = ["data", "dataloaders", "distributions", "model", "module", "nn", "utils"] 6 | 7 | __version__ = version("multimil") 8 | -------------------------------------------------------------------------------- /src/multimil/data/__init__.py: -------------------------------------------------------------------------------- 1 | from ._preprocessing import organize_multimodal_anndatas 2 | 3 | __all__ = ["organize_multimodal_anndatas"] 4 | -------------------------------------------------------------------------------- /src/multimil/data/_preprocessing.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import anndata as ad 4 | import numpy as np 5 | import pandas as pd 6 | 7 | 8 | def organize_multimodal_anndatas( 9 | adatas: list[list[ad.AnnData | None]], 10 | layers: list[list[str | None]] | None = None, 11 | ) -> ad.AnnData: 12 | """Concatenate all the input anndata objects. 13 | 14 | These anndata objects should already have been preprocessed so that all single-modality 15 | objects use a subset of the features used in the multiome object. The feature names (index of 16 | `.var`) should match between the objects for vertical integration and cell names (index of 17 | `.obs`) should match between the objects for horizontal integration. 18 | 19 | Parameters 20 | ---------- 21 | adatas 22 | List of Lists with AnnData objects or None where each sublist corresponds to a modality. 23 | layers 24 | List of Lists of the same lengths as `adatas` specifying which `.layer` to use for each AnnData. Default is None which means using `.X`. 
25 | 26 | Returns 27 | ------- 28 | Concatenated AnnData object across modalities and datasets. 29 | """ 30 | # TODO: add checks for layers 31 | # TODO: add check that len of modalities is the same as len of losses, etc 32 | 33 | # needed for scArches operation setup 34 | datasets_lengths = {} 35 | datasets_obs_names = {} 36 | datasets_obs = {} 37 | modality_lengths = {} 38 | modality_var_names = {} 39 | 40 | # sanity checks and preparing data for concat 41 | for mod, modality_adatas in enumerate(adatas): 42 | for i, adata in enumerate(modality_adatas): 43 | if adata is not None: 44 | # will create .obs['group'] later, so throw a warning here if the column already exists 45 | if "group" in adata.obs.columns: 46 | warnings.warn( 47 | "Column `.obs['group']` will be overwritten. Please save the original data in another column if needed.", 48 | stacklevel=2, 49 | ) 50 | # check that all adatas in the same modality have the same number of features 51 | if (mod_length := modality_lengths.get(f"{mod}", None)) is None: 52 | modality_lengths[f"{mod}"] = adata.shape[1] 53 | else: 54 | if adata.shape[1] != mod_length: 55 | raise ValueError( 56 | f"Adatas have different number of features for modality {mod}, namely {mod_length} and {adata.shape[1]}." 57 | ) 58 | # check that there is the same number of observations for paired data 59 | if (dataset_length := datasets_lengths.get(i, None)) is None: 60 | datasets_lengths[i] = adata.shape[0] 61 | else: 62 | if adata.shape[0] != dataset_length: 63 | raise ValueError( 64 | f"Paired adatas have different number of observations for group {i}, namely {dataset_length} and {adata.shape[0]}." 65 | ) 66 | # check that .obs_names are the same for paired data 67 | if (dataset_obs_names := datasets_obs_names.get(i, None)) is None: 68 | datasets_obs_names[i] = adata.obs_names 69 | else: 70 | if np.sum(adata.obs_names != dataset_obs_names): 71 | raise ValueError(f"`.obs_names` are not the same for group {i}.") 72 | # keep all the .obs 73 | if datasets_obs.get(i, None) is None: 74 | datasets_obs[i] = adata.obs 75 | datasets_obs[i].loc[:, "group"] = i 76 | else: 77 | cols_to_use = adata.obs.columns.difference(datasets_obs[i].columns) 78 | datasets_obs[i] = datasets_obs[i].join(adata.obs[cols_to_use]) 79 | modality_var_names[mod] = adata.var_names 80 | 81 | for mod, modality_adatas in enumerate(adatas): 82 | for i, adata in enumerate(modality_adatas): 83 | if not isinstance(adata, ad.AnnData) and adata is None: 84 | X_zeros = np.zeros((datasets_lengths[i], modality_lengths[f"{mod}"])) 85 | adatas[mod][i] = ad.AnnData(X_zeros, dtype=X_zeros.dtype) 86 | adatas[mod][i].obs_names = datasets_obs_names[i] 87 | adatas[mod][i].var_names = modality_var_names[mod] 88 | adatas[mod][i] = adatas[mod][i].copy() 89 | if layers is not None: 90 | if layers[mod][i]: 91 | layer = layers[mod][i] 92 | adatas[mod][i] = adatas[mod][i].copy() 93 | adatas[mod][i].X = adatas[mod][i].layers[layer].copy() 94 | 95 | # concat adatas within each modality first 96 | mod_adatas = [] 97 | for modality_adatas in adatas: 98 | mod_adatas.append(ad.concat(modality_adatas, join="outer")) 99 | 100 | # concat modality adatas along the feature axis 101 | multiome_anndata = ad.concat(mod_adatas, axis=1, label="modality") 102 | 103 | # add .obs back 104 | multiome_anndata.obs = pd.concat(datasets_obs.values()) 105 | 106 | # we will need modality_length later for the model init 107 | multiome_anndata.uns["modality_lengths"] = modality_lengths 108 | multiome_anndata.var_names_make_unique() 109 | 110 | return 
multiome_anndata 111 | -------------------------------------------------------------------------------- /src/multimil/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from ._ann_dataloader import GroupAnnDataLoader 2 | from ._data_splitting import GroupDataSplitter 3 | 4 | __all__ = ["GroupAnnDataLoader", "GroupDataSplitter"] 5 | -------------------------------------------------------------------------------- /src/multimil/dataloaders/_ann_dataloader.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import itertools 3 | 4 | import numpy as np 5 | import torch 6 | from scvi.data import AnnDataManager 7 | from scvi.dataloaders import AnnTorchDataset 8 | from torch.utils.data import DataLoader 9 | 10 | 11 | # adjusted from scvi-tools 12 | # https://github.com/YosefLab/scvi-tools/blob/ac0c3e04fcc2772fdcf7de4de819db3af9465b6b/scvi/dataloaders/_ann_dataloader.py#L15 13 | # accessed on 4 November 2021 14 | class StratifiedSampler(torch.utils.data.sampler.Sampler): 15 | """Custom stratified sampler class which enables sampling the same number of observation from each group in each mini-batch. 16 | 17 | Parameters 18 | ---------- 19 | indices 20 | List of indices to sample from. 21 | batch_size 22 | Batch size of each iteration. 23 | shuffle 24 | If ``True``, shuffles indices before sampling. 25 | drop_last 26 | If int, drops the last batch if its length is less than drop_last. 27 | If drop_last == True, drops last non-full batch. 28 | If drop_last == False, iterate over all batches. 29 | """ 30 | 31 | def __init__( 32 | self, 33 | indices: np.ndarray, 34 | group_labels: np.ndarray, 35 | batch_size: int, 36 | min_size_per_class: int, 37 | shuffle: bool = True, 38 | drop_last: bool | int = True, 39 | shuffle_classes: bool = True, 40 | ): 41 | if drop_last > batch_size: 42 | raise ValueError( 43 | "drop_last can't be greater than batch_size. " 44 | + f"drop_last is {drop_last} but batch_size is {batch_size}." 45 | ) 46 | 47 | if batch_size % min_size_per_class != 0: 48 | raise ValueError( 49 | "min_size_per_class has to be a divisor of batch_size." 50 | + f"min_size_per_class is {min_size_per_class} but batch_size is {batch_size}." 51 | ) 52 | 53 | self.indices = indices 54 | self.group_labels = group_labels 55 | self.n_obs = len(indices) 56 | self.batch_size = batch_size 57 | self.shuffle = shuffle 58 | self.shuffle_classes = shuffle_classes 59 | self.min_size_per_class = min_size_per_class 60 | self.drop_last = drop_last 61 | 62 | from math import ceil 63 | 64 | classes = list(dict.fromkeys(self.group_labels)) 65 | 66 | tmp = 0 67 | for cl in classes: 68 | idx = np.where(self.group_labels == cl)[0] 69 | cl_idx = self.indices[idx] 70 | n_obs = len(cl_idx) 71 | last_batch_len = n_obs % self.min_size_per_class 72 | if (self.drop_last is True) or (last_batch_len < self.drop_last): 73 | drop_last_n = last_batch_len 74 | elif (self.drop_last is False) or (last_batch_len >= self.drop_last): 75 | drop_last_n = 0 76 | else: 77 | raise ValueError("Invalid input for drop_last param. 
Must be bool or int.") 78 | 79 | if drop_last_n != 0: 80 | tmp += n_obs // self.min_size_per_class 81 | else: 82 | tmp += ceil(n_obs / self.min_size_per_class) 83 | 84 | classes_per_batch = int(self.batch_size / self.min_size_per_class) 85 | self.length = ceil(tmp / classes_per_batch) 86 | 87 | def __iter__(self): 88 | classes_per_batch = int(self.batch_size / self.min_size_per_class) 89 | 90 | classes = list(dict.fromkeys(self.group_labels)) 91 | data_iter = [] 92 | 93 | for cl in classes: 94 | idx = np.where(self.group_labels == cl)[0] 95 | cl_idx = self.indices[idx] 96 | n_obs = len(cl_idx) 97 | 98 | if self.shuffle is True: 99 | idx = torch.randperm(n_obs).tolist() 100 | else: 101 | idx = torch.arange(n_obs).tolist() 102 | 103 | last_batch_len = n_obs % self.min_size_per_class 104 | if (self.drop_last is True) or (last_batch_len < self.drop_last): 105 | drop_last_n = last_batch_len 106 | elif (self.drop_last is False) or (last_batch_len >= self.drop_last): 107 | drop_last_n = 0 108 | else: 109 | raise ValueError("Invalid input for drop_last param. Must be bool or int.") 110 | 111 | if drop_last_n != 0: 112 | idx = idx[:-drop_last_n] 113 | 114 | data_iter.extend( 115 | [cl_idx[idx[i : i + self.min_size_per_class]] for i in range(0, len(idx), self.min_size_per_class)] 116 | ) 117 | 118 | if self.shuffle_classes: 119 | idx = torch.randperm(len(data_iter)).tolist() 120 | data_iter = [data_iter[id] for id in idx] 121 | 122 | final_data_iter = [] 123 | 124 | end = len(data_iter) - len(data_iter) % classes_per_batch 125 | for i in range(0, end, classes_per_batch): 126 | batch_idx = list(itertools.chain.from_iterable(data_iter[i : i + classes_per_batch])) 127 | final_data_iter.append(batch_idx) 128 | 129 | # deal with the last manually 130 | if end != len(data_iter): 131 | batch_idx = list(itertools.chain.from_iterable(data_iter[end:])) 132 | final_data_iter.append(batch_idx) 133 | 134 | return iter(final_data_iter) 135 | 136 | def __len__(self): 137 | return self.length 138 | 139 | 140 | # adjusted from scvi-tools 141 | # https://github.com/scverse/scvi-tools/blob/0b802762869c43c9f49e69fe62b1a5a9b5c4dae6/scvi/dataloaders/_ann_dataloader.py#L89 142 | # accessed on 5 November 2022 143 | class GroupAnnDataLoader(DataLoader): 144 | """DataLoader for loading tensors from AnnData objects. 145 | 146 | Parameters 147 | ---------- 148 | adata_manager 149 | :class:`~scvi.data.AnnDataManager` object with a registered AnnData object. 150 | shuffle 151 | Whether the data should be shuffled. 152 | indices 153 | The indices of the observations in the adata to load. 154 | batch_size 155 | Minibatch size to load each iteration. 156 | data_and_attributes 157 | Dictionary with keys representing keys in data registry (`adata.uns["_scvi"]`) 158 | and value equal to desired numpy loading type (later made into torch tensor). 159 | If `None`, defaults to all registered data. 160 | data_loader_kwargs 161 | Keyword arguments for :class:`~torch.utils.data.DataLoader`. 
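Examples
--------
A minimal sketch; ``adata_manager`` is assumed to be an :class:`~scvi.data.AnnDataManager` whose setup registered ``"sample_id"`` as a categorical covariate (the covariate name is illustrative).

>>> loader = GroupAnnDataLoader(
...     adata_manager,
...     group_column="sample_id",
...     batch_size=128,
...     min_size_per_class=64,  # must be a divisor of batch_size
... )
>>> batch = next(iter(loader))  # tensors for one stratified mini-batch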
162 | """ 163 | 164 | def __init__( 165 | self, 166 | adata_manager: AnnDataManager, 167 | group_column: str, 168 | shuffle=True, 169 | shuffle_classes=True, 170 | indices=None, 171 | batch_size=128, 172 | min_size_per_class=None, 173 | data_and_attributes: dict | None = None, 174 | drop_last: bool | int = True, 175 | sampler: torch.utils.data.sampler.Sampler | None = StratifiedSampler, 176 | **data_loader_kwargs, 177 | ): 178 | if adata_manager.adata is None: 179 | raise ValueError("Please run register_fields() on your AnnDataManager object first.") 180 | 181 | if data_and_attributes is not None: 182 | data_registry = adata_manager.data_registry 183 | for key in data_and_attributes.keys(): 184 | if key not in data_registry.keys(): 185 | raise ValueError(f"{key} required for model but not registered with AnnDataManager.") 186 | 187 | if group_column not in adata_manager.registry["setup_args"]["categorical_covariate_keys"]: 188 | raise ValueError( 189 | f"{group_column} required for model but not in categorical covariates has to be one of the registered categorical covariates = {adata_manager.registry['setup_args']['categorical_covariate_keys']}." 190 | ) 191 | 192 | self.dataset = AnnTorchDataset(adata_manager, getitem_tensors=data_and_attributes) 193 | 194 | if min_size_per_class is None: 195 | min_size_per_class = batch_size // 2 196 | 197 | sampler_kwargs = { 198 | "batch_size": batch_size, 199 | "shuffle": shuffle, 200 | "drop_last": drop_last, 201 | "min_size_per_class": min_size_per_class, 202 | "shuffle_classes": shuffle_classes, 203 | } 204 | 205 | if indices is None: 206 | indices = np.arange(len(self.dataset)) 207 | sampler_kwargs["indices"] = indices 208 | else: 209 | if hasattr(indices, "dtype") and indices.dtype is np.dtype("bool"): 210 | indices = np.where(indices)[0].ravel() 211 | indices = np.asarray(indices) 212 | sampler_kwargs["indices"] = indices 213 | 214 | sampler_kwargs["group_labels"] = np.array( 215 | adata_manager.adata[indices].obsm["_scvi_extra_categorical_covs"][group_column] 216 | ) 217 | 218 | self.indices = indices 219 | self.sampler_kwargs = sampler_kwargs 220 | 221 | sampler = sampler(**self.sampler_kwargs) 222 | self.data_loader_kwargs = copy.copy(data_loader_kwargs) 223 | # do not touch batch size here, sampler gives batched indices 224 | self.data_loader_kwargs.update({"sampler": sampler, "batch_size": None}) 225 | 226 | super().__init__(self.dataset, **self.data_loader_kwargs) 227 | -------------------------------------------------------------------------------- /src/multimil/dataloaders/_data_splitting.py: -------------------------------------------------------------------------------- 1 | from scvi.data import AnnDataManager 2 | from scvi.dataloaders import DataSplitter 3 | 4 | from multimil.dataloaders._ann_dataloader import GroupAnnDataLoader 5 | 6 | 7 | # adjusted from scvi-tools 8 | # https://github.com/scverse/scvi-tools/blob/0b802762869c43c9f49e69fe62b1a5a9b5c4dae6/scvi/dataloaders/_data_splitting.py#L56 9 | # accessed on 5 November 2022 10 | class GroupDataSplitter(DataSplitter): 11 | """Creates data loaders ``train_set``, ``validation_set``, ``test_set``. 12 | 13 | If ``train_size + validation_set < 1`` then ``test_set`` is non-empty. 14 | 15 | Parameters 16 | ---------- 17 | adata_manager 18 | :class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``. 19 | train_size 20 | Proportion of cells to use as the train set. Float, or None (default is 0.9). 
21 | validation_size 22 | Proportion of cell to use as the valisation set. Float, or None (default is None). If None, is set to 1 - ``train_size``. 23 | use_gpu 24 | Use default GPU if available (if None or True), or index of GPU to use (if int), 25 | or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False). 26 | kwargs 27 | Keyword args for data loader. Data loader class is :class:`~mtg.dataloaders.GroupAnnDataLoader`. 28 | """ 29 | 30 | def __init__( 31 | self, 32 | adata_manager: AnnDataManager, 33 | group_column: str, 34 | train_size: float = 0.9, 35 | validation_size: float | None = None, 36 | use_gpu: bool = False, 37 | **kwargs, 38 | ): 39 | self.group_column = group_column 40 | super().__init__(adata_manager, train_size, validation_size, use_gpu, **kwargs) 41 | 42 | def train_dataloader(self): 43 | """Return data loader for train AnnData.""" 44 | return GroupAnnDataLoader( 45 | self.adata_manager, 46 | self.group_column, 47 | indices=self.train_idx, 48 | shuffle=True, 49 | drop_last=True, 50 | pin_memory=self.pin_memory, 51 | **self.data_loader_kwargs, 52 | ) 53 | 54 | def val_dataloader(self): 55 | """Return data loader for validation AnnData.""" 56 | if len(self.val_idx) > 0: 57 | return GroupAnnDataLoader( 58 | self.adata_manager, 59 | self.group_column, 60 | indices=self.val_idx, 61 | shuffle=False, 62 | drop_last=True, 63 | pin_memory=self.pin_memory, 64 | **self.data_loader_kwargs, 65 | ) 66 | else: 67 | pass 68 | 69 | def test_dataloader(self): 70 | """Return data loader for test AnnData.""" 71 | if len(self.test_idx) > 0: 72 | return GroupAnnDataLoader( 73 | self.adata_manager, 74 | self.group_column, 75 | indices=self.test_idx, 76 | shuffle=False, 77 | drop_last=True, 78 | pin_memory=self.pin_memory, 79 | **self.data_loader_kwargs, 80 | ) 81 | else: 82 | pass 83 | -------------------------------------------------------------------------------- /src/multimil/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | from ._jeffreys import Jeffreys 2 | from ._mmd import MMD 3 | 4 | __all__ = ["MMD", "Jeffreys"] 5 | -------------------------------------------------------------------------------- /src/multimil/distributions/_jeffreys.py: -------------------------------------------------------------------------------- 1 | # code adapted from https://github.com/scverse/scvi-tools/blob/c53efe06379c866e36e549afbb8158a120b82d14/src/scvi/module/_multivae.py#L900C1-L923C56 2 | # last accessed on 16th October 2024 3 | import torch 4 | from torch.distributions import Normal 5 | from torch.distributions import kl_divergence as kld 6 | 7 | 8 | class Jeffreys(torch.nn.Module): 9 | """Jeffreys divergence (Symmetric KL divergence) using torch distributions. 10 | 11 | Parameters 12 | ---------- 13 | None 14 | """ 15 | 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def sym_kld( 20 | self, 21 | mu1: torch.Tensor, 22 | sigma1: torch.Tensor, 23 | mu2: torch.Tensor, 24 | sigma2: torch.Tensor, 25 | ) -> torch.Tensor: 26 | """Symmetric KL divergence between two Gaussians using torch.distributions. 27 | 28 | Parameters 29 | ---------- 30 | mu1 31 | Mean of the first distribution. 32 | sigma1 33 | Variance of the first distribution (note: this will be square-rooted to get std dev). 34 | mu2 35 | Mean of the second distribution. 36 | sigma2 37 | Variance of the second distribution (note: this will be square-rooted to get std dev). 38 | 39 | Returns 40 | ------- 41 | Symmetric KL divergence between the two distributions. 
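Examples
--------
A small numerical sketch: two identical Gaussians give zero divergence (the ``sigma`` arguments are variances, as described above).

>>> import torch
>>> j = Jeffreys()
>>> mu, var = torch.zeros(2, 3), torch.ones(2, 3)
>>> j.sym_kld(mu, var, mu, var).item()
0.0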
42 | """ 43 | rv1 = Normal(mu1, sigma1.sqrt()) 44 | rv2 = Normal(mu2, sigma2.sqrt()) 45 | 46 | out = kld(rv1, rv2).mean() + kld(rv2, rv1).mean() 47 | return out 48 | 49 | def forward(self, params1: torch.Tensor, params2: torch.Tensor) -> torch.Tensor: 50 | """Forward computation for Jeffreys divergence. 51 | 52 | Parameters 53 | ---------- 54 | params1 55 | A tuple of mean and variance (or standard deviation) of the first distribution. 56 | params2 57 | A tuple of mean and variance (or standard deviation) of the second distribution. 58 | 59 | Returns 60 | ------- 61 | Jeffreys divergence between the two distributions. 62 | """ 63 | mu1, sigma1 = params1 64 | mu2, sigma2 = params2 65 | 66 | # Ensure non-negative sigma (variance) values 67 | sigma1 = sigma1.clamp(min=1e-6) 68 | sigma2 = sigma2.clamp(min=1e-6) 69 | 70 | return self.sym_kld(mu1, sigma1, mu2, sigma2) 71 | -------------------------------------------------------------------------------- /src/multimil/distributions/_mmd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class MMD(torch.nn.Module): 5 | """Maximum mean discrepancy. 6 | 7 | Parameters 8 | ---------- 9 | kernel_type 10 | Indicates if to use Gaussian kernel. One of 11 | * ``'gaussian'`` - use Gaussian kernel 12 | * ``'not gaussian'`` - do not use Gaussian kernel. 13 | """ 14 | 15 | def __init__(self, kernel_type="gaussian"): 16 | super().__init__() 17 | self.kernel_type = kernel_type 18 | # TODO: add check for gaussian kernel that shapes are same 19 | 20 | def gaussian_kernel( 21 | self, 22 | x: torch.Tensor, 23 | y: torch.Tensor, 24 | gamma: list[float] | None = None, 25 | ) -> torch.Tensor: 26 | """Apply Guassian kernel. 27 | 28 | Parameters 29 | ---------- 30 | x 31 | Tensor from the first distribution. 32 | y 33 | Tensor from the second distribution. 34 | gamma 35 | List of gamma parameters. 36 | 37 | Returns 38 | ------- 39 | Gaussian kernel between ``x`` and ``y``. 40 | """ 41 | if gamma is None: 42 | gamma = [ 43 | 1e-6, 44 | 1e-5, 45 | 1e-4, 46 | 1e-3, 47 | 1e-2, 48 | 1e-1, 49 | 1, 50 | 5, 51 | 10, 52 | 15, 53 | 20, 54 | 25, 55 | 30, 56 | 35, 57 | 100, 58 | 1e3, 59 | 1e4, 60 | 1e5, 61 | 1e6, 62 | ] 63 | 64 | D = torch.cdist(x, y).pow(2) 65 | K = torch.zeros_like(D) 66 | 67 | for g in gamma: 68 | K.add_(torch.exp(D.mul(-g))) 69 | 70 | return K / len(gamma) 71 | 72 | def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 73 | """Forward computation. 74 | 75 | Adapted from 76 | Title: scarches 77 | Date: 9th Octover 2021 78 | Code version: 0.4.0 79 | Availability: https://github.com/theislab/scarches/blob/63a7c2b35a01e55fe7e1dd871add459a86cd27fb/scarches/models/trvae/losses.py 80 | Citation: Gretton, Arthur, et al. "A Kernel Two-Sample Test", 2012. 81 | 82 | Parameters 83 | ---------- 84 | x 85 | Tensor with shape ``(batch_size, z_dim)``. 86 | y 87 | Tensor with shape ``(batch_size, z_dim)``. 88 | 89 | Returns 90 | ------- 91 | MMD between ``x`` and ``y``. 
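Examples
--------
A small usage sketch with toy inputs (the random tensors are illustrative only):

>>> import torch
>>> mmd = MMD(kernel_type="gaussian")
>>> x, y = torch.randn(64, 16), torch.randn(64, 16) + 1.0
>>> loss = mmd(x, y)  # scalar tensor; larger for more dissimilar batches
>>> mmd(x, x).item()  # identical inputs give zero discrepancy
0.0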
92 | """ 93 | # in case there is only one sample in a batch belonging to one of the groups, then skip the batch 94 | if len(x) == 1 or len(y) == 1: 95 | return torch.tensor(0.0) 96 | 97 | if self.kernel_type == "gaussian": 98 | Kxx = self.gaussian_kernel(x, x).mean() 99 | Kyy = self.gaussian_kernel(y, y).mean() 100 | Kxy = self.gaussian_kernel(x, y).mean() 101 | return Kxx + Kyy - 2 * Kxy 102 | else: 103 | mean_x = x.mean(0, keepdim=True) 104 | mean_y = y.mean(0, keepdim=True) 105 | cent_x = x - mean_x 106 | cent_y = y - mean_y 107 | cova_x = (cent_x.t() @ cent_x) / (len(x) - 1) 108 | cova_y = (cent_y.t() @ cent_y) / (len(y) - 1) 109 | 110 | mean_diff = (mean_x - mean_y).pow(2).mean() 111 | cova_diff = (cova_x - cova_y).pow(2).mean() 112 | 113 | return mean_diff + cova_diff 114 | -------------------------------------------------------------------------------- /src/multimil/model/__init__.py: -------------------------------------------------------------------------------- 1 | from ._mil import MILClassifier 2 | from ._multivae import MultiVAE 3 | from ._multivae_mil import MultiVAE_MIL 4 | 5 | __all__ = ["MultiVAE", "MILClassifier", "MultiVAE_MIL"] 6 | -------------------------------------------------------------------------------- /src/multimil/model/_mil.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import warnings 3 | 4 | import anndata as ad 5 | import torch 6 | from anndata import AnnData 7 | from pytorch_lightning.callbacks import ModelCheckpoint 8 | from scvi import REGISTRY_KEYS 9 | from scvi.data import AnnDataManager, fields 10 | from scvi.data._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY 11 | from scvi.model._utils import parse_use_gpu_arg 12 | from scvi.model.base import ArchesMixin, BaseModelClass 13 | from scvi.model.base._archesmixin import _get_loaded_data 14 | from scvi.model.base._utils import _initialize_model 15 | from scvi.train import AdversarialTrainingPlan, TrainRunner 16 | from scvi.train._callbacks import SaveBestState 17 | 18 | from multimil.dataloaders import GroupAnnDataLoader, GroupDataSplitter 19 | from multimil.module import MILClassifierTorch 20 | from multimil.utils import ( 21 | get_bag_info, 22 | get_predictions, 23 | plt_plot_losses, 24 | prep_minibatch, 25 | save_predictions_in_adata, 26 | select_covariates, 27 | setup_ordinal_regression, 28 | ) 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | 33 | class MILClassifier(BaseModelClass, ArchesMixin): 34 | """MultiMIL MIL prediction model. 35 | 36 | Parameters 37 | ---------- 38 | adata 39 | AnnData object containing embeddings and covariates. 40 | sample_key 41 | Key in `adata.obs` that corresponds to the sample covariate. 42 | classification 43 | List of keys in `adata.obs` that correspond to the classification covariates. 44 | regression 45 | List of keys in `adata.obs` that correspond to the regression covariates. 46 | ordinal_regression 47 | List of keys in `adata.obs` that correspond to the ordinal regression covariates. 48 | sample_batch_size 49 | Number of samples per bag, i.e. sample. Default is 128. 50 | normalization 51 | One of "layer" or "batch". Default is "layer". 52 | z_dim 53 | Dimensionality of the input latent space. Default is 16. 54 | dropout 55 | Dropout rate. Default is 0.2. 56 | scoring 57 | How to calculate attention scores. One of "gated_attn", "attn", "mean", "max", "sum", "MLP". Default is "gated_attn". 58 | attn_dim 59 | Dimensionality of the hidden layer in attention calculation. Default is 16. 
60 | n_layers_cell_aggregator 61 | Number of layers in the cell aggregator. Default is 1. 62 | n_layers_classifier 63 | Number of layers in the classifier. Default is 2. 64 | n_layers_regressor 65 | Number of layers in the regressor. Default is 2. 66 | n_layers_mlp_attn 67 | Number of layers in the MLP attention. Only used if `scoring` = "MLP". Default is 1. 68 | n_hidden_cell_aggregator 69 | Number of hidden units in the cell aggregator. Default is 128. 70 | n_hidden_classifier 71 | Number of hidden units in the classifier. Default is 128. 72 | n_hidden_mlp_attn 73 | Number of hidden units in the MLP attention. Default is 32. 74 | n_hidden_regressor 75 | Number of hidden units in the regressor. Default is 128. 76 | class_loss_coef 77 | Coefficient for the classification loss. Default is 1.0. 78 | regression_loss_coef 79 | Coefficient for the regression loss. Default is 1.0. 80 | activation 81 | Activation function. Default is 'leaky_relu'. 82 | initialization 83 | Initialization method for the weights. Default is None. 84 | anneal_class_loss 85 | Whether to anneal the classification loss. Default is False. 86 | ignore_covariates 87 | List of covariates to ignore. Needed for query-to-reference mapping. Default is None. 88 | """ 89 | 90 | def __init__( 91 | self, 92 | adata, 93 | sample_key, 94 | classification=None, 95 | regression=None, 96 | ordinal_regression=None, 97 | sample_batch_size=128, 98 | normalization="layer", 99 | z_dim=16, # TODO do we need it? can't we get it from adata? 100 | dropout=0.2, 101 | scoring="gated_attn", # TODO test if MLP is supported and if we want to keep it 102 | attn_dim=16, 103 | n_layers_cell_aggregator: int = 1, 104 | n_layers_classifier: int = 2, 105 | n_layers_regressor: int = 2, 106 | n_layers_mlp_attn: int = 1, 107 | n_hidden_cell_aggregator: int = 128, 108 | n_hidden_classifier: int = 128, 109 | n_hidden_mlp_attn: int = 32, 110 | n_hidden_regressor: int = 128, 111 | class_loss_coef=1.0, 112 | regression_loss_coef=1.0, 113 | activation="leaky_relu", # or tanh 114 | initialization=None, # xavier (tanh) or kaiming (leaky_relu) 115 | anneal_class_loss=False, 116 | ignore_covariates=None, 117 | ): 118 | super().__init__(adata) 119 | 120 | if classification is None: 121 | classification = [] 122 | if regression is None: 123 | regression = [] 124 | if ordinal_regression is None: 125 | ordinal_regression = [] 126 | 127 | self.sample_key = sample_key 128 | self.scoring = scoring 129 | 130 | if self.sample_key not in self.adata_manager.registry["setup_args"]["categorical_covariate_keys"]: 131 | raise ValueError( 132 | f"Sample key = '{self.sample_key}' has to be one of the registered categorical covariates = {self.adata_manager.registry['setup_args']['categorical_covariate_keys']}" 133 | ) 134 | 135 | if len(classification) + len(regression) + len(ordinal_regression) == 0: 136 | raise ValueError( 137 | 'At least one of "classification", "regression", "ordinal_regression" has to be specified.' 
138 | ) 139 | 140 | self.classification = classification 141 | self.regression = regression 142 | self.ordinal_regression = ordinal_regression 143 | 144 | # TODO check if all of the three above were registered with setup anndata 145 | # TODO add check that class is the same within a patient 146 | # TODO assert length of things is the same as number of modalities 147 | # TODO add that n_layers has to be > 0 for all 148 | # TODO warning if n_layers == 1 then n_hidden is not used for classifier and MLP attention 149 | # TODO warning if MLP attention is used but n layers and n hidden not given that using default values 150 | # TODO check that there is at least on ecovariate to predict 151 | if scoring == "MLP": 152 | if not n_layers_mlp_attn: 153 | n_layers_mlp_attn = 1 154 | if not n_hidden_mlp_attn: 155 | n_hidden_mlp_attn = 16 156 | 157 | self.regression_idx = [] 158 | if len(cont_covs := self.adata_manager.get_state_registry(REGISTRY_KEYS.CONT_COVS_KEY)) > 0: 159 | for key in cont_covs["columns"]: 160 | if key in self.regression: 161 | self.regression_idx.append(list(cont_covs["columns"]).index(key)) 162 | else: # only can happen when using multivae_mil 163 | if ignore_covariates is not None and key not in ignore_covariates: 164 | warnings.warn( 165 | f"Registered continuous covariate '{key}' is not in regression covariates so will be ignored.", 166 | stacklevel=2, 167 | ) 168 | 169 | # classification and ordinal regression together here as ordinal regression values need to be registered as categorical covariates 170 | self.class_idx, self.ord_idx = [], [] 171 | self.num_classification_classes = [] 172 | if len(cat_covs := self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)) > 0: 173 | for i, num_cat in enumerate(cat_covs.n_cats_per_key): 174 | cat_cov_name = cat_covs["field_keys"][i] 175 | if cat_cov_name in self.classification: 176 | self.num_classification_classes.append(num_cat) 177 | self.class_idx.append(i) 178 | elif cat_cov_name in self.ordinal_regression: 179 | self.num_classification_classes.append(num_cat) 180 | self.ord_idx.append(i) 181 | else: # the actual categorical covariate, only can happen when using multivae_mil 182 | if ( 183 | ignore_covariates is not None 184 | and cat_cov_name not in ignore_covariates 185 | and cat_cov_name != self.sample_key 186 | ): 187 | warnings.warn( 188 | f"Registered categorical covariate '{cat_cov_name}' is not in classification or ordinal regression covariates and is not the sample covariate so will be ignored.", 189 | stacklevel=2, 190 | ) 191 | 192 | for label in ordinal_regression: 193 | print( 194 | f"The order for {label} ordinal classes is: {adata.obs[label].cat.categories}. If you need to change the order, please rerun setup_anndata and specify the correct order with the `ordinal_regression_order` parameter." 
195 | ) 196 | 197 | self.class_idx = torch.tensor(self.class_idx) 198 | self.ord_idx = torch.tensor(self.ord_idx) 199 | self.regression_idx = torch.tensor(self.regression_idx) 200 | 201 | self.module = MILClassifierTorch( 202 | z_dim=z_dim, 203 | dropout=dropout, 204 | activation=activation, 205 | initialization=initialization, 206 | normalization=normalization, 207 | num_classification_classes=self.num_classification_classes, 208 | scoring=scoring, 209 | attn_dim=attn_dim, 210 | n_layers_cell_aggregator=n_layers_cell_aggregator, 211 | n_layers_classifier=n_layers_classifier, 212 | n_layers_mlp_attn=n_layers_mlp_attn, 213 | n_layers_regressor=n_layers_regressor, 214 | n_hidden_regressor=n_hidden_regressor, 215 | n_hidden_cell_aggregator=n_hidden_cell_aggregator, 216 | n_hidden_classifier=n_hidden_classifier, 217 | n_hidden_mlp_attn=n_hidden_mlp_attn, 218 | class_loss_coef=class_loss_coef, 219 | regression_loss_coef=regression_loss_coef, 220 | sample_batch_size=sample_batch_size, 221 | class_idx=self.class_idx, 222 | ord_idx=self.ord_idx, 223 | reg_idx=self.regression_idx, 224 | anneal_class_loss=anneal_class_loss, 225 | ) 226 | 227 | self.init_params_ = self._get_init_params(locals()) 228 | 229 | def train( 230 | self, 231 | max_epochs: int = 200, 232 | lr: float = 5e-4, 233 | use_gpu: str | int | bool | None = None, 234 | train_size: float = 0.9, 235 | validation_size: float | None = None, 236 | batch_size: int = 256, 237 | weight_decay: float = 1e-3, 238 | eps: float = 1e-08, 239 | early_stopping: bool = True, 240 | save_best: bool = True, 241 | check_val_every_n_epoch: int | None = None, 242 | n_epochs_kl_warmup: int | None = None, 243 | n_steps_kl_warmup: int | None = None, 244 | adversarial_mixing: bool = False, 245 | plan_kwargs: dict | None = None, 246 | early_stopping_monitor: str | None = "accuracy_validation", 247 | early_stopping_mode: str | None = "max", 248 | save_checkpoint_every_n_epochs: int | None = None, 249 | path_to_checkpoints: str | None = None, 250 | **kwargs, 251 | ): 252 | """Trains the model using amortized variational inference. 253 | 254 | Parameters 255 | ---------- 256 | max_epochs 257 | Number of passes through the dataset. 258 | lr 259 | Learning rate for optimization. 260 | use_gpu 261 | Use default GPU if available (if None or True), or index of GPU to use (if int), 262 | or name of GPU (if str), or use CPU (if False). 263 | train_size 264 | Size of training set in the range [0.0, 1.0]. 265 | validation_size 266 | Size of the test set. If `None`, defaults to 1 - `train_size`. If 267 | `train_size + validation_size < 1`, the remaining cells belong to a test set. 268 | batch_size 269 | Minibatch size to use during training. 270 | weight_decay 271 | weight decay regularization term for optimization 272 | eps 273 | Optimizer eps 274 | early_stopping 275 | Whether to perform early stopping with respect to the validation set. 276 | save_best 277 | Save the best model state with respect to the validation loss, or use the final 278 | state in the training procedure 279 | check_val_every_n_epoch 280 | Check val every n train epochs. By default, val is not checked, unless `early_stopping` is `True`. 281 | If so, val is checked every epoch. 282 | n_epochs_kl_warmup 283 | Number of epochs to scale weight on KL divergences from 0 to 1. 284 | Overrides `n_steps_kl_warmup` when both are not `None`. Default is 1/3 of `max_epochs`. 285 | n_steps_kl_warmup 286 | Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1. 
287 | Only activated when `n_epochs_kl_warmup` is set to None. If `None`, defaults 288 | to `floor(0.75 * adata.n_obs)` 289 | adversarial_mixing 290 | Whether to use adversarial mixing. Default is False. 291 | plan_kwargs 292 | Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to 293 | `train()` will overwrite values present in `plan_kwargs`, when appropriate. 294 | early_stopping_monitor 295 | Metric to monitor for early stopping. Default is "accuracy_validation". 296 | early_stopping_mode 297 | One of "min" or "max". Default is "max". 298 | save_checkpoint_every_n_epochs 299 | Save a checkpoint every n epochs. 300 | path_to_checkpoints 301 | Path to save checkpoints. 302 | **kwargs 303 | Other keyword args for :class:`~scvi.train.Trainer`. 304 | 305 | Returns 306 | ------- 307 | Trainer object. 308 | """ 309 | # TODO put in a function, return params needed for splitter, plan and runner, then can call the function from multivae_mil 310 | if len(self.regression) > 0: 311 | if early_stopping_monitor == "accuracy_validation": 312 | warnings.warn( 313 | "Setting early_stopping_monitor to 'regression_loss_validation' and early_stopping_mode to 'min' as regression is used.", 314 | stacklevel=2, 315 | ) 316 | early_stopping_monitor = "regression_loss_validation" 317 | early_stopping_mode = "min" 318 | if n_epochs_kl_warmup is None: 319 | n_epochs_kl_warmup = max(max_epochs // 3, 1) 320 | update_dict = { 321 | "lr": lr, 322 | "adversarial_classifier": adversarial_mixing, 323 | "weight_decay": weight_decay, 324 | "eps": eps, 325 | "n_epochs_kl_warmup": n_epochs_kl_warmup, 326 | "n_steps_kl_warmup": n_steps_kl_warmup, 327 | "optimizer": "AdamW", 328 | "scale_adversarial_loss": 1, 329 | } 330 | if plan_kwargs is not None: 331 | plan_kwargs.update(update_dict) 332 | else: 333 | plan_kwargs = update_dict 334 | 335 | if save_best: 336 | if "callbacks" not in kwargs.keys(): 337 | kwargs["callbacks"] = [] 338 | kwargs["callbacks"].append(SaveBestState(monitor=early_stopping_monitor, mode=early_stopping_mode)) 339 | 340 | if save_checkpoint_every_n_epochs is not None: 341 | if path_to_checkpoints is not None: 342 | kwargs["callbacks"].append( 343 | ModelCheckpoint( 344 | dirpath=path_to_checkpoints, 345 | save_top_k=-1, 346 | monitor="epoch", 347 | every_n_epochs=save_checkpoint_every_n_epochs, 348 | verbose=True, 349 | ) 350 | ) 351 | else: 352 | raise ValueError( 353 | f"`save_checkpoint_every_n_epochs` = {save_checkpoint_every_n_epochs} so `path_to_checkpoints` has to be not None but is {path_to_checkpoints}." 
354 | ) 355 | # until here 356 | 357 | data_splitter = GroupDataSplitter( 358 | self.adata_manager, 359 | group_column=self.sample_key, 360 | train_size=train_size, 361 | validation_size=validation_size, 362 | batch_size=batch_size, 363 | use_gpu=use_gpu, 364 | ) 365 | 366 | training_plan = AdversarialTrainingPlan(self.module, **plan_kwargs) 367 | runner = TrainRunner( 368 | self, 369 | training_plan=training_plan, 370 | data_splitter=data_splitter, 371 | max_epochs=max_epochs, 372 | use_gpu=use_gpu, 373 | early_stopping=early_stopping, 374 | check_val_every_n_epoch=check_val_every_n_epoch, 375 | early_stopping_monitor=early_stopping_monitor, 376 | early_stopping_mode=early_stopping_mode, 377 | early_stopping_patience=50, 378 | enable_checkpointing=True, 379 | **kwargs, 380 | ) 381 | return runner() 382 | 383 | @classmethod 384 | def setup_anndata( 385 | cls, 386 | adata: ad.AnnData, 387 | categorical_covariate_keys: list[str] | None = None, 388 | continuous_covariate_keys: list[str] | None = None, 389 | ordinal_regression_order: dict[str, list[str]] | None = None, 390 | **kwargs, 391 | ): 392 | """Set up :class:`~anndata.AnnData` object. 393 | 394 | A mapping will be created between data fields used by ``scvi`` to their respective locations in adata. 395 | This method will also compute the log mean and log variance per batch for the library size prior. 396 | None of the data in adata are modified. Only adds fields to adata. 397 | 398 | Parameters 399 | ---------- 400 | adata 401 | AnnData object containing raw counts. Rows represent cells, columns represent features. 402 | categorical_covariate_keys 403 | Keys in `adata.obs` that correspond to categorical data. 404 | continuous_covariate_keys 405 | Keys in `adata.obs` that correspond to continuous data. 406 | ordinal_regression_order 407 | Dictionary with regression classes as keys and order of classes as values. 408 | kwargs 409 | Additional parameters to pass to register_fields() of AnnDataManager. 410 | """ 411 | setup_ordinal_regression(adata, ordinal_regression_order, categorical_covariate_keys) 412 | 413 | setup_method_args = cls._get_setup_method_args(**locals()) 414 | 415 | anndata_fields = [ 416 | fields.LayerField( 417 | REGISTRY_KEYS.X_KEY, 418 | layer=None, 419 | ), 420 | fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, None), 421 | fields.CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), 422 | fields.NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), 423 | ] 424 | 425 | adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) 426 | adata_manager.register_fields(adata, **kwargs) 427 | cls.register_manager(adata_manager) 428 | 429 | @torch.inference_mode() 430 | def get_model_output( 431 | self, 432 | adata=None, 433 | batch_size=256, 434 | ): 435 | """Save the attention scores and predictions in the adata object. 436 | 437 | Parameters 438 | ---------- 439 | adata 440 | AnnData object to run the model on. If `None`, the model's AnnData object is used. 441 | batch_size 442 | Minibatch size to use. Default is 256. 
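Examples
--------
A typical end-to-end sketch. The covariate names (``"sample_id"``, ``"disease"``) and the assumption that ``adata.X`` already holds a cell embedding of dimension ``z_dim`` (with one label per sample) are illustrative, not fixed by the package.

>>> MILClassifier.setup_anndata(adata, categorical_covariate_keys=["sample_id", "disease"])
>>> model = MILClassifier(adata, sample_key="sample_id", classification=["disease"], z_dim=adata.n_vars)
>>> model.train(max_epochs=50)
>>> model.get_model_output(adata)
>>> adata.obs["cell_attn"]  # per-cell attention scores (for "gated_attn" or "attn" scoring)

Bag-level predictions for the classification, ordinal regression and regression covariates are written to ``adata.obs`` as well.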
443 | 444 | """ 445 | if not self.is_trained_: 446 | raise RuntimeError("Please train the model first.") 447 | 448 | adata = self._validate_anndata(adata) 449 | 450 | scdl = self._make_data_loader( 451 | adata=adata, 452 | batch_size=batch_size, 453 | min_size_per_class=batch_size, # hack to ensure that not full batches are processed properly 454 | data_loader_class=GroupAnnDataLoader, 455 | shuffle=False, 456 | shuffle_classes=False, 457 | group_column=self.sample_key, 458 | drop_last=False, 459 | ) 460 | 461 | cell_level_attn, bags = [], [] 462 | class_pred, ord_pred, reg_pred = {}, {}, {} 463 | ( 464 | bag_class_true, 465 | bag_class_pred, 466 | bag_reg_true, 467 | bag_reg_pred, 468 | bag_ord_true, 469 | bag_ord_pred, 470 | ) = ({}, {}, {}, {}, {}, {}) 471 | 472 | bag_counter = 0 473 | cell_counter = 0 474 | 475 | for tensors in scdl: 476 | cont_key = REGISTRY_KEYS.CONT_COVS_KEY 477 | cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None 478 | 479 | cat_key = REGISTRY_KEYS.CAT_COVS_KEY 480 | cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None 481 | 482 | inference_inputs = self.module._get_inference_input(tensors) 483 | outputs = self.module.inference(**inference_inputs) 484 | pred = outputs["predictions"] 485 | 486 | # get attention for each cell in the bag 487 | if self.scoring in ["gated_attn", "attn"]: 488 | cell_attn = self.module.cell_level_aggregator[-1].A.squeeze(dim=1) 489 | cell_attn = cell_attn.flatten() # in inference always one patient per batch 490 | cell_level_attn += [cell_attn.cpu()] 491 | 492 | assert outputs["z_joint"].shape[0] % pred[0].shape[0] == 0 493 | sample_size = outputs["z_joint"].shape[0] // pred[0].shape[0] # how many cells in patient_minibatch 494 | minibatch_size, n_samples_in_batch = prep_minibatch(cat_covs, self.module.sample_batch_size) 495 | regression = select_covariates(cont_covs, self.regression_idx, n_samples_in_batch) 496 | ordinal_regression = select_covariates(cat_covs, self.ord_idx, n_samples_in_batch) 497 | classification = select_covariates(cat_covs, self.class_idx, n_samples_in_batch) 498 | 499 | # calculate accuracies of predictions 500 | bag_class_pred, bag_class_true, class_pred = get_predictions( 501 | self.class_idx, pred, classification, sample_size, bag_class_pred, bag_class_true, class_pred 502 | ) 503 | bag_ord_pred, bag_ord_true, ord_pred = get_predictions( 504 | self.ord_idx, 505 | pred, 506 | ordinal_regression, 507 | sample_size, 508 | bag_ord_pred, 509 | bag_ord_true, 510 | ord_pred, 511 | len(self.class_idx), 512 | ) 513 | bag_reg_pred, bag_reg_true, reg_pred = get_predictions( 514 | self.regression_idx, 515 | pred, 516 | regression, 517 | sample_size, 518 | bag_reg_pred, 519 | bag_reg_true, 520 | reg_pred, 521 | len(self.class_idx) + len(self.ord_idx), 522 | ) 523 | 524 | # TODO remove n_of_bags_in_batch after testing 525 | n_of_bags_in_batch = pred[0].shape[0] 526 | assert n_samples_in_batch == n_of_bags_in_batch 527 | 528 | # save bag info to be able to calculate bag predictions later 529 | bags, cell_counter, bag_counter = get_bag_info( 530 | bags, n_samples_in_batch, minibatch_size, cell_counter, bag_counter, self.module.sample_batch_size 531 | ) 532 | 533 | if self.scoring in ["gated_attn", "attn"]: 534 | cell_level = torch.cat(cell_level_attn).numpy() 535 | adata.obs["cell_attn"] = cell_level 536 | flat_bags = [value for sublist in bags for value in sublist] 537 | adata.obs["bags"] = flat_bags 538 | 539 | for i in range(len(self.class_idx)): 540 | name = self.classification[i] 541 | 
class_names = self.adata_manager.get_state_registry("extra_categorical_covs")["mappings"][name] 542 | save_predictions_in_adata( 543 | adata, 544 | i, 545 | self.classification, 546 | bag_class_pred, 547 | bag_class_true, 548 | class_pred, 549 | class_names, 550 | name, 551 | clip="argmax", 552 | ) 553 | for i in range(len(self.ord_idx)): 554 | name = self.ordinal_regression[i] 555 | class_names = self.adata_manager.get_state_registry("extra_categorical_covs")["mappings"][name] 556 | save_predictions_in_adata( 557 | adata, i, self.ordinal_regression, bag_ord_pred, bag_ord_true, ord_pred, class_names, name, clip="clip" 558 | ) 559 | for i in range(len(self.regression_idx)): 560 | name = self.regression[i] 561 | reg_names = self.adata_manager.get_state_registry("extra_continuous_covs")["columns"] 562 | save_predictions_in_adata( 563 | adata, i, self.regression, bag_reg_pred, bag_reg_true, reg_pred, reg_names, name, clip=None, reg=True 564 | ) 565 | 566 | def plot_losses(self, save=None): 567 | """Plot losses. 568 | 569 | Parameters 570 | ---------- 571 | save 572 | If not None, save the plot to this location. 573 | """ 574 | loss_names = self.module.select_losses_to_plot() 575 | plt_plot_losses(self.history, loss_names, save) 576 | 577 | # adjusted from scvi-tools 578 | # https://github.com/scverse/scvi-tools/blob/0b802762869c43c9f49e69fe62b1a5a9b5c4dae6/scvi/model/base/_archesmixin.py#L30 579 | # accessed on 7 November 2022 580 | @classmethod 581 | def load_query_data( 582 | cls, 583 | adata: AnnData, 584 | reference_model: BaseModelClass, 585 | use_gpu: str | int | bool | None = None, 586 | ) -> BaseModelClass: 587 | """Online update of a reference model with scArches algorithm # TODO cite. 588 | 589 | Parameters 590 | ---------- 591 | adata 592 | AnnData organized in the same way as data used to train model. 593 | It is not necessary to run setup_anndata, 594 | as AnnData is validated against the ``registry``. 595 | reference_model 596 | Already instantiated model of the same class. 597 | use_gpu 598 | Load model on default GPU if available (if None or True), 599 | or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False). 600 | 601 | Returns 602 | ------- 603 | Model with updated architecture and weights. 604 | """ 605 | # currently this function works only if the prediction cov is present in the .obs of the query 606 | # TODO need to allow it to be missing, maybe add a dummy column to .obs of query adata 607 | 608 | _, _, device = parse_use_gpu_arg(use_gpu) 609 | 610 | attr_dict, _, _ = _get_loaded_data(reference_model, device=device) 611 | 612 | registry = attr_dict.pop("registry_") 613 | if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__: 614 | raise ValueError("It appears you are loading a model from a different class.") 615 | 616 | if _SETUP_ARGS_KEY not in registry: 617 | raise ValueError("Saved model does not contain original setup inputs. 
" "Cannot load the original setup.") 618 | 619 | cls.setup_anndata( 620 | adata, 621 | source_registry=registry, 622 | extend_categories=True, 623 | allow_missing_labels=True, 624 | **registry[_SETUP_ARGS_KEY], 625 | ) 626 | 627 | model = _initialize_model(cls, adata, attr_dict) 628 | model.module.load_state_dict(reference_model.module.state_dict()) 629 | model.to_device(device) 630 | 631 | model.module.eval() 632 | model.is_trained_ = True 633 | 634 | return model 635 | -------------------------------------------------------------------------------- /src/multimil/model/_multivae.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Literal 3 | 4 | import anndata as ad 5 | import torch 6 | from pytorch_lightning.callbacks import ModelCheckpoint 7 | from scvi import REGISTRY_KEYS 8 | from scvi.data import AnnDataManager, fields 9 | from scvi.data._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY 10 | from scvi.dataloaders import DataSplitter 11 | from scvi.model._utils import parse_use_gpu_arg 12 | from scvi.model.base import ArchesMixin, BaseModelClass 13 | from scvi.model.base._archesmixin import _get_loaded_data 14 | from scvi.model.base._utils import _initialize_model 15 | from scvi.train import AdversarialTrainingPlan, TrainRunner 16 | from scvi.train._callbacks import SaveBestState 17 | 18 | from multimil.dataloaders import GroupDataSplitter 19 | from multimil.module import MultiVAETorch 20 | from multimil.utils import calculate_size_factor, plt_plot_losses 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class MultiVAE(BaseModelClass, ArchesMixin): 26 | """MultiMIL multimodal integration model. 27 | 28 | Parameters 29 | ---------- 30 | adata 31 | AnnData object that has been registered via :meth:`~multigrate.model.MultiVAE.setup_anndata`. 32 | integrate_on 33 | One of the categorical covariates refistered with :math:`~multigrate.model.MultiVAE.setup_anndata` to integrate on. The latent space then will be disentangled from this covariate. If `None`, no integration is performed. 34 | condition_encoders 35 | Whether to concatentate covariate embeddings to the first layer of the encoders. Default is `False`. 36 | condition_decoders 37 | Whether to concatentate covariate embeddings to the first layer of the decoders. Default is `True`. 38 | normalization 39 | What normalization to use; has to be one of `batch` or `layer`. Default is `layer`. 40 | z_dim 41 | Dimensionality of the latent space. Default is 16. 42 | losses 43 | Which losses to use for each modality. Has to be the same length as the number of modalities. Default is `MSE` for all modalities. 44 | dropout 45 | Dropout rate. Default is 0.2. 46 | cond_dim 47 | Dimensionality of the covariate embeddings. Default is 16. 48 | kernel_type 49 | Type of kernel to use for the MMD loss. Default is `gaussian`. 50 | loss_coefs 51 | Loss coeficients for the different losses in the model. Default is 1 for all. 52 | cont_cov_type 53 | How to calculate embeddings for continuous covariates. Default is `logsim`. 54 | n_layers_cont_embed 55 | Number of layers for the continuous covariate embedding calculation. Default is 1. 56 | n_layers_encoders 57 | Number of layers for each encoder. Default is 2 for all modalities. Has to be the same length as the number of modalities. 58 | n_layers_decoders 59 | Number of layers for each decoder. Default is 2 for all modalities. Has to be the same length as the number of modalities. 
60 | n_hidden_cont_embed 61 | Number of nodes for each hidden layer in the continuous covariate embedding calculation. Default is 32. 62 | n_hidden_encoders 63 | Number of nodes for each hidden layer in the encoders. Default is 32. 64 | n_hidden_decoders 65 | Number of nodes for each hidden layer in the decoders. Default is 32. 66 | modality_alignment 67 | Whether to align the modalities, one of ['MMD', 'Jeffreys', None]. Default is `None`. 68 | alignment_type 69 | Which alignment type to use, one of ['latent', 'marginal', 'both']. Default is `latent`. 70 | activation 71 | Activation function to use. Default is `leaky_relu`. 72 | initialization 73 | Initialization method to use. Default is `None`. 74 | ignore_covariates 75 | List of covariates to ignore. Needed for query-to-reference mapping. Default is `None`. 76 | mix 77 | How to mix the distributions to get the joint, one of ['product', 'mixture']. Default is `product`. 78 | """ 79 | 80 | def __init__( 81 | self, 82 | adata: ad.AnnData, 83 | integrate_on: str | None = None, 84 | condition_encoders: bool = True, 85 | condition_decoders: bool = True, 86 | normalization: Literal["layer", "batch", None] = "layer", 87 | z_dim: int = 16, 88 | losses: list[str] | None = None, 89 | dropout: float = 0.2, 90 | cond_dim: int = 16, 91 | kernel_type: Literal["gaussian", None] = "gaussian", 92 | loss_coefs: dict[int, float] | None = None, 93 | cont_cov_type: Literal["logsigm", "sigm", None] = "logsigm", 94 | n_layers_cont_embed: int = 1, # TODO default to None? 95 | n_layers_encoders: list[int] | None = None, 96 | n_layers_decoders: list[int] | None = None, 97 | n_hidden_cont_embed: int = 32, # TODO default to None? 98 | n_hidden_encoders: list[int] | None = None, 99 | n_hidden_decoders: list[int] | None = None, 100 | modality_alignment: Literal["MMD", "Jeffreys", None] = None, 101 | alignment_type: Literal["latent", "marginal", "both"] = "latent", 102 | activation: str | None = "leaky_relu", # TODO add which options are implemented 103 | initialization: str | None = None, # TODO add which options are implemented 104 | ignore_covariates: list[str] | None = None, 105 | mix: Literal["product", "mixture"] = "product", 106 | ): 107 | super().__init__(adata) 108 | 109 | # for the integration with the alignment loss 110 | self.group_column = integrate_on 111 | 112 | # TODO: add options for number of hidden layers, hidden layers dim and output activation functions 113 | if normalization not in ["layer", "batch", None]: 114 | raise ValueError('Normalization has to be one of ["layer", "batch", None]') 115 | # TODO: do some assertions for other parameters 116 | 117 | if ignore_covariates is None: 118 | ignore_covariates = [] 119 | 120 | if losses is not None and ( 121 | "nb" in losses or "zinb" in losses 122 | ) and REGISTRY_KEYS.SIZE_FACTOR_KEY not in self.adata_manager.data_registry: 123 | raise ValueError(f"Have to register {REGISTRY_KEYS.SIZE_FACTOR_KEY} when using 'nb' or 'zinb' loss.") 124 | 125 | self.num_groups = 1 126 | self.integrate_on_idx = None 127 | if integrate_on is not None: 128 | if integrate_on not in self.adata_manager.registry["setup_args"]["categorical_covariate_keys"]: 129 | raise ValueError( 130 | "Cannot integrate on {!r}, has to be one of the registered categorical covariates = {}".format( 131 | integrate_on, self.adata_manager.registry["setup_args"]["categorical_covariate_keys"] 132 | ) 133 | ) 134 | elif integrate_on in ignore_covariates: 135 | raise ValueError( 136 | f"Specified integrate_on = {integrate_on!r} is in ignore_covariates = {ignore_covariates}."
137 | ) 138 | else: 139 | self.num_groups = len( 140 | self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)["mappings"][integrate_on] 141 | ) 142 | self.integrate_on_idx = self.adata_manager.registry["setup_args"]["categorical_covariate_keys"].index( 143 | integrate_on 144 | ) 145 | 146 | self.modality_lengths = [ 147 | adata.uns["modality_lengths"][key] for key in sorted(adata.uns["modality_lengths"].keys()) 148 | ] 149 | 150 | self.cont_covs_idx = [] 151 | self.cont_covariate_dims = [] 152 | if len(cont_covs := self.adata_manager.get_state_registry(REGISTRY_KEYS.CONT_COVS_KEY)) > 0: 153 | for i, key in enumerate(cont_covs["columns"]): 154 | if key not in ignore_covariates: 155 | self.cont_covs_idx.append(i) 156 | self.cont_covariate_dims.append(1) 157 | 158 | self.cat_covs_idx = [] 159 | self.cat_covariate_dims = [] 160 | if len(cat_covs := self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)) > 0: 161 | for i, num_cat in enumerate(cat_covs.n_cats_per_key): 162 | if cat_covs["field_keys"][i] not in ignore_covariates: 163 | self.cat_covs_idx.append(i) 164 | self.cat_covariate_dims.append(num_cat) 165 | 166 | self.cat_covs_idx = torch.tensor(self.cat_covs_idx) 167 | self.cont_covs_idx = torch.tensor(self.cont_covs_idx) 168 | 169 | self.module = MultiVAETorch( 170 | modality_lengths=self.modality_lengths, 171 | condition_encoders=condition_encoders, 172 | condition_decoders=condition_decoders, 173 | normalization=normalization, 174 | z_dim=z_dim, 175 | losses=losses, 176 | dropout=dropout, 177 | cond_dim=cond_dim, 178 | kernel_type=kernel_type, 179 | loss_coefs=loss_coefs, 180 | num_groups=self.num_groups, 181 | integrate_on_idx=self.integrate_on_idx, 182 | cat_covs_idx=self.cat_covs_idx, 183 | cont_covs_idx=self.cont_covs_idx, 184 | cat_covariate_dims=self.cat_covariate_dims, 185 | cont_covariate_dims=self.cont_covariate_dims, 186 | n_layers_encoders=n_layers_encoders, 187 | n_layers_decoders=n_layers_decoders, 188 | n_hidden_encoders=n_hidden_encoders, 189 | n_hidden_decoders=n_hidden_decoders, 190 | cont_cov_type=cont_cov_type, 191 | n_layers_cont_embed=n_layers_cont_embed, 192 | n_hidden_cont_embed=n_hidden_cont_embed, 193 | modality_alignment=modality_alignment, 194 | alignment_type=alignment_type, 195 | activation=activation, 196 | initialization=initialization, 197 | mix=mix, 198 | ) 199 | 200 | self.init_params_ = self._get_init_params(locals()) 201 | 202 | @torch.inference_mode() 203 | def impute(self, adata=None, batch_size=256): 204 | """Impute missing values in the adata object. 205 | 206 | Parameters 207 | ---------- 208 | adata 209 | AnnData object to run the model on. If `None`, the model's AnnData object is used. 210 | batch_size 211 | Minibatch size to use. Default is 256. 
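Examples
--------
A short usage sketch, assuming ``model`` is an already trained :class:`MultiVAE` and ``adata`` was prepared with :func:`~multimil.data.organize_multimodal_anndatas`.

>>> model.impute(adata, batch_size=256)
>>> adata.obsm["imputed_modality_0"]  # decoded (imputed) feature matrix of the first modality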
212 | """ 213 | if not self.is_trained_: 214 | raise RuntimeError("Please train the model first.") 215 | 216 | adata = self._validate_anndata(adata) 217 | 218 | scdl = self._make_data_loader(adata=adata, batch_size=batch_size) 219 | 220 | imputed = [[] for _ in range(len(self.modality_lengths))] 221 | 222 | for tensors in scdl: 223 | inference_inputs = self.module._get_inference_input(tensors) 224 | inference_outputs = self.module.inference(**inference_inputs) 225 | generative_inputs = self.module._get_generative_input(tensors, inference_outputs) 226 | outputs = self.module.generative(**generative_inputs) 227 | for i, output in enumerate(outputs["rs"]): 228 | imputed[i] += [output.cpu()] 229 | for i in range(len(imputed)): 230 | imputed[i] = torch.cat(imputed[i]).numpy() 231 | adata.obsm[f"imputed_modality_{i}"] = imputed[i] 232 | 233 | @torch.inference_mode() 234 | def get_model_output(self, adata=None, batch_size=256): 235 | """Save the latent representation in the adata object. 236 | 237 | Parameters 238 | ---------- 239 | adata 240 | AnnData object to run the model on. If `None`, the model's AnnData object is used. 241 | batch_size 242 | Minibatch size to use. Default is 256. 243 | """ 244 | if not self.is_trained_: 245 | raise RuntimeError("Please train the model first.") 246 | 247 | adata = self._validate_anndata(adata) 248 | 249 | scdl = self._make_data_loader(adata=adata, batch_size=batch_size) 250 | 251 | latent = [] 252 | for tensors in scdl: 253 | inference_inputs = self.module._get_inference_input(tensors) 254 | outputs = self.module.inference(**inference_inputs) 255 | z = outputs["z_joint"] 256 | latent += [z.cpu()] 257 | 258 | adata.obsm["X_multiMIL"] = torch.cat(latent).numpy() 259 | 260 | def train( 261 | self, 262 | max_epochs: int = 200, 263 | lr: float = 5e-4, 264 | use_gpu: str | int | bool | None = None, 265 | train_size: float = 0.9, 266 | validation_size: float | None = None, 267 | batch_size: int = 256, 268 | weight_decay: float = 1e-3, 269 | eps: float = 1e-08, 270 | early_stopping: bool = True, 271 | save_best: bool = True, 272 | check_val_every_n_epoch: int | None = None, 273 | n_epochs_kl_warmup: int | None = None, 274 | n_steps_kl_warmup: int | None = None, 275 | adversarial_mixing: bool = False, # TODO check if suppored by us, i don't think it is 276 | plan_kwargs: dict | None = None, 277 | save_checkpoint_every_n_epochs: int | None = None, 278 | path_to_checkpoints: str | None = None, 279 | **kwargs, 280 | ): 281 | """Train the model using amortized variational inference. 282 | 283 | Parameters 284 | ---------- 285 | max_epochs 286 | Number of passes through the dataset. 287 | lr 288 | Learning rate for optimization. 289 | use_gpu 290 | Use default GPU if available (if None or True), or index of GPU to use (if int), 291 | or name of GPU (if str), or use CPU (if False). 292 | train_size 293 | Size of training set in the range [0.0, 1.0]. 294 | validation_size 295 | Size of the test set. If `None`, defaults to 1 - `train_size`. If 296 | `train_size + validation_size < 1`, the remaining cells belong to a test set. 297 | batch_size 298 | Minibatch size to use during training. 299 | weight_decay 300 | Weight decay regularization term for optimization. 301 | eps 302 | Optimizer eps. 303 | early_stopping 304 | Whether to perform early stopping with respect to the validation set. 305 | save_best 306 | Save the best model state with respect to the validation loss, or use the final 307 | state in the training procedure. 
308 |         check_val_every_n_epoch
309 |             Check validation metrics every n training epochs. By default, validation is not checked,
310 |             unless `early_stopping` is `True`, in which case it is checked every epoch.
311 |         n_epochs_kl_warmup
312 |             Number of epochs to scale weight on KL divergences from 0 to 1.
313 |             Overrides `n_steps_kl_warmup` when both are not `None`. Default is 1/3 of `max_epochs`.
314 |         n_steps_kl_warmup
315 |             Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1.
316 |             Only activated when `n_epochs_kl_warmup` is set to `None`. If `None`, defaults
317 |             to `floor(0.75 * adata.n_obs)`.
318 |         adversarial_mixing
319 |             Whether to use adversarial mixing. Default is `False`.
320 |         plan_kwargs
321 |             Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to
322 |             `train()` will overwrite values present in `plan_kwargs`, when appropriate.
323 |         save_checkpoint_every_n_epochs
324 |             Save a checkpoint every n epochs. If `None`, no checkpoints are saved.
325 |         path_to_checkpoints
326 |             Path to save checkpoints. Required if `save_checkpoint_every_n_epochs` is not `None`.
327 |         kwargs
328 |             Additional keyword arguments for :class:`~scvi.train.TrainRunner`.
329 | 
330 |         Returns
331 |         -------
332 |         Trainer object.
333 |         """
334 |         if n_epochs_kl_warmup is None:
335 |             n_epochs_kl_warmup = max(max_epochs // 3, 1)
336 |         update_dict = {
337 |             "lr": lr,
338 |             "adversarial_classifier": adversarial_mixing,
339 |             "weight_decay": weight_decay,
340 |             "eps": eps,
341 |             "n_epochs_kl_warmup": n_epochs_kl_warmup,
342 |             "n_steps_kl_warmup": n_steps_kl_warmup,
343 |             "optimizer": "AdamW",
344 |             "scale_adversarial_loss": 1,
345 |         }
346 |         if plan_kwargs is not None:
347 |             plan_kwargs.update(update_dict)
348 |         else:
349 |             plan_kwargs = update_dict
350 | 
351 |         if save_best:
352 |             if "callbacks" not in kwargs.keys():
353 |                 kwargs["callbacks"] = []
354 |             kwargs["callbacks"].append(SaveBestState(monitor="reconstruction_loss_validation"))
355 | 
356 |         if save_checkpoint_every_n_epochs is not None:
357 |             if path_to_checkpoints is not None:
358 |                 kwargs.setdefault("callbacks", []).append(
359 |                     ModelCheckpoint(
360 |                         dirpath=path_to_checkpoints,
361 |                         save_top_k=-1,
362 |                         monitor="epoch",
363 |                         every_n_epochs=save_checkpoint_every_n_epochs,
364 |                         verbose=True,
365 |                     )
366 |                 )
367 |             else:
368 |                 raise ValueError(
369 |                     f"`path_to_checkpoints` must be provided when `save_checkpoint_every_n_epochs`={save_checkpoint_every_n_epochs} is set, but `path_to_checkpoints`={path_to_checkpoints} was passed."
370 | ) 371 | 372 | if self.group_column is not None: 373 | data_splitter = GroupDataSplitter( 374 | self.adata_manager, 375 | group_column=self.group_column, 376 | train_size=train_size, 377 | validation_size=validation_size, 378 | batch_size=batch_size, 379 | use_gpu=use_gpu, 380 | ) 381 | else: 382 | data_splitter = DataSplitter( 383 | self.adata_manager, 384 | train_size=train_size, 385 | validation_size=validation_size, 386 | batch_size=batch_size, 387 | use_gpu=use_gpu, 388 | ) 389 | training_plan = AdversarialTrainingPlan(self.module, **plan_kwargs) 390 | runner = TrainRunner( 391 | self, 392 | training_plan=training_plan, 393 | data_splitter=data_splitter, 394 | max_epochs=max_epochs, 395 | use_gpu=use_gpu, 396 | early_stopping=early_stopping, 397 | check_val_every_n_epoch=check_val_every_n_epoch, 398 | early_stopping_monitor="reconstruction_loss_validation", 399 | early_stopping_patience=50, 400 | enable_checkpointing=True, 401 | **kwargs, 402 | ) 403 | return runner() 404 | 405 | @classmethod 406 | def setup_anndata( 407 | cls, 408 | adata: ad.AnnData, 409 | size_factor_key: str | None = None, 410 | rna_indices_end: int | None = None, 411 | categorical_covariate_keys: list[str] | None = None, 412 | continuous_covariate_keys: list[str] | None = None, 413 | **kwargs, 414 | ): 415 | """Set up :class:`~anndata.AnnData` object. 416 | 417 | A mapping will be created between data fields used by ``scvi`` to their respective locations in adata. 418 | This method will also compute the log mean and log variance per batch for the library size prior. 419 | None of the data in adata are modified. Only adds fields to adata. 420 | 421 | Parameters 422 | ---------- 423 | adata 424 | AnnData object containing raw counts. Rows represent cells, columns represent features. 425 | size_factor_key 426 | Key in `adata.obs` containing the size factor. If `None`, will be calculated from the RNA counts. 427 | rna_indices_end 428 | Integer to indicate where RNA feature end in the AnnData object. Is used to calculate ``libary_size``. 429 | categorical_covariate_keys 430 | Keys in `adata.obs` that correspond to categorical data. 431 | continuous_covariate_keys 432 | Keys in `adata.obs` that correspond to continuous data. 433 | kwargs 434 | Additional parameters to pass to register_fields() of AnnDataManager. 435 | """ 436 | setup_method_args = cls._get_setup_method_args(**locals()) 437 | 438 | anndata_fields = [ 439 | fields.LayerField( 440 | REGISTRY_KEYS.X_KEY, 441 | layer=None, 442 | ), 443 | fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, None), # TODO check if need to add if it's None 444 | fields.CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), 445 | fields.NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), 446 | ] 447 | size_factor_key = calculate_size_factor(adata, size_factor_key, rna_indices_end) 448 | anndata_fields.append(fields.NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key)) 449 | 450 | adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) 451 | adata_manager.register_fields(adata, **kwargs) 452 | cls.register_manager(adata_manager) 453 | 454 | def plot_losses(self, save=None): 455 | """Plot losses. 456 | 457 | Parameters 458 | ---------- 459 | save 460 | If not None, save the plot to this location. 
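        Examples
        --------
        A minimal usage sketch (assumes the model has already been trained so that its
        training history is populated; the file name is only illustrative):

        >>> model.plot_losses()  # display the training curves
        >>> model.plot_losses(save="losses.png")  # or save the figure to disk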
461 | """ 462 | loss_names = self.module.select_losses_to_plot() 463 | plt_plot_losses(self.history, loss_names, save) 464 | 465 | # adjusted from scvi-tools 466 | # https://github.com/scverse/scvi-tools/blob/0b802762869c43c9f49e69fe62b1a5a9b5c4dae6/scvi/model/base/_archesmixin.py#L30 467 | # accessed on 5 November 2022 468 | @classmethod 469 | def load_query_data( 470 | cls, 471 | adata: ad.AnnData, 472 | reference_model: BaseModelClass, 473 | use_gpu: str | int | bool | None = None, 474 | freeze: bool = True, 475 | ignore_covariates: list[str] | None = None, 476 | ) -> BaseModelClass: 477 | """Online update of a reference model with scArches algorithm # TODO cite. 478 | 479 | Parameters 480 | ---------- 481 | adata 482 | AnnData organized in the same way as data used to train model. 483 | It is not necessary to run setup_anndata, 484 | as AnnData is validated against the ``registry``. 485 | reference_model 486 | Already instantiated model of the same class. 487 | use_gpu 488 | Load model on default GPU if available (if None or True), 489 | or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False). 490 | freeze 491 | Whether to freeze the encoders and decoders and only train the new weights. 492 | ignore_covariates 493 | List of covariates to ignore. Needed for query-to-reference mapping. Default is `None`. 494 | 495 | Returns 496 | ------- 497 | Model with updated architecture and weights. 498 | """ 499 | _, _, device = parse_use_gpu_arg(use_gpu) 500 | 501 | attr_dict, _, load_state_dict = _get_loaded_data(reference_model, device=device) 502 | 503 | registry = attr_dict.pop("registry_") 504 | if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__: 505 | raise ValueError("It appears you are loading a model from a different class.") 506 | 507 | if _SETUP_ARGS_KEY not in registry: 508 | raise ValueError("Saved model does not contain original setup inputs. 
" "Cannot load the original setup.") 509 | 510 | if ignore_covariates is None: 511 | ignore_covariates = [] 512 | 513 | cls.setup_anndata( 514 | adata, 515 | source_registry=registry, 516 | extend_categories=True, 517 | allow_missing_labels=True, 518 | **registry[_SETUP_ARGS_KEY], 519 | ) 520 | 521 | model = _initialize_model(cls, adata, attr_dict) 522 | adata_manager = model.get_anndata_manager(adata, required=True) 523 | 524 | # TODO add an exception if need to add new categories but condition_encoders is False 525 | # model tweaking 526 | num_of_cat_to_add = [ 527 | new_cat - old_cat 528 | for i, (old_cat, new_cat) in enumerate( 529 | zip( 530 | reference_model.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key, 531 | adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key, 532 | strict=False, 533 | ) 534 | ) 535 | if adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)["field_keys"][i] not in ignore_covariates 536 | ] 537 | 538 | model.to_device(device) 539 | 540 | new_state_dict = model.module.state_dict() 541 | for key, load_ten in load_state_dict.items(): # load_state_dict = old 542 | new_ten = new_state_dict[key] 543 | if new_ten.size() == load_ten.size(): 544 | continue 545 | # new categoricals changed size 546 | else: 547 | old_shape = new_ten.shape 548 | new_shape = load_ten.shape 549 | if old_shape[0] == new_shape[0]: 550 | dim_diff = new_ten.size()[-1] - load_ten.size()[-1] 551 | fixed_ten = torch.cat([load_ten, new_ten[..., -dim_diff:]], dim=-1) 552 | else: 553 | dim_diff = new_ten.size()[0] - load_ten.size()[0] 554 | fixed_ten = torch.cat([load_ten, new_ten[-dim_diff:, ...]], dim=0) 555 | load_state_dict[key] = fixed_ten 556 | 557 | model.module.load_state_dict(load_state_dict) 558 | 559 | # freeze everything but the embeddings that have new categories 560 | if freeze is True: 561 | for _, par in model.module.named_parameters(): 562 | par.requires_grad = False 563 | for i, embed in enumerate(model.module.cat_covariate_embeddings): 564 | if num_of_cat_to_add[i] > 0: # unfreeze the ones where categories were added 565 | embed.weight.requires_grad = True 566 | if model.module.integrate_on_idx is not None: 567 | model.module.theta.requires_grad = True 568 | 569 | model.module.eval() 570 | model.is_trained_ = False 571 | 572 | return model 573 | -------------------------------------------------------------------------------- /src/multimil/module/__init__.py: -------------------------------------------------------------------------------- 1 | from ._multivae_torch import MultiVAETorch 2 | from ._mil_torch import MILClassifierTorch 3 | from ._multivae_mil_torch import MultiVAETorch_MIL 4 | 5 | __all__ = ["MultiVAETorch", "MILClassifierTorch", "MultiVAETorch_MIL"] 6 | -------------------------------------------------------------------------------- /src/multimil/module/_mil_torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scvi import REGISTRY_KEYS 3 | from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from multimil.nn import MLP, Aggregator 8 | from multimil.utils import prep_minibatch, select_covariates 9 | 10 | 11 | class MILClassifierTorch(BaseModuleClass): 12 | """MultiMIL's MIL classification module. 13 | 14 | Parameters 15 | ---------- 16 | z_dim 17 | Latent dimension. 18 | dropout 19 | Dropout rate. 20 | normalization 21 | Normalization type. 
22 | num_classification_classes 23 | Number of classes for each of the classification task. 24 | scoring 25 | Scoring type. One of ["gated_attn", "attn", "mean", "max", "sum", "mlp"]. 26 | attn_dim 27 | Hidden attention dimension. 28 | n_layers_cell_aggregator 29 | Number of layers in the cell aggregator. 30 | n_layers_classifier 31 | Number of layers in the classifier. 32 | n_layers_mlp_attn 33 | Number of layers in the MLP attention. 34 | n_layers_regressor 35 | Number of layers in the regressor. 36 | n_hidden_regressor 37 | Hidden dimension in the regressor. 38 | n_hidden_cell_aggregator 39 | Hidden dimension in the cell aggregator. 40 | n_hidden_classifier 41 | Hidden dimension in the classifier. 42 | n_hidden_mlp_attn 43 | Hidden dimension in the MLP attention. 44 | class_loss_coef 45 | Classification loss coefficient. 46 | regression_loss_coef 47 | Regression loss coefficient. 48 | sample_batch_size 49 | Sample batch size. 50 | class_idx 51 | Which indices in cat covariates to do classification on. 52 | ord_idx 53 | Which indices in cat covariates to do ordinal regression on. 54 | reg_idx 55 | Which indices in cont covariates to do regression on. 56 | activation 57 | Activation function. 58 | initialization 59 | Initialization type. 60 | anneal_class_loss 61 | Whether to anneal the classification loss. 62 | """ 63 | 64 | def __init__( 65 | self, 66 | z_dim=16, 67 | dropout=0.2, 68 | normalization="layer", 69 | num_classification_classes=None, # number of classes for each of the classification task 70 | scoring="gated_attn", 71 | attn_dim=16, 72 | n_layers_cell_aggregator=1, 73 | n_layers_classifier=1, 74 | n_layers_mlp_attn=1, 75 | n_layers_regressor=1, 76 | n_hidden_regressor=16, 77 | n_hidden_cell_aggregator=16, 78 | n_hidden_classifier=16, 79 | n_hidden_mlp_attn=16, 80 | class_loss_coef=1.0, 81 | regression_loss_coef=1.0, 82 | sample_batch_size=128, 83 | class_idx=None, # which indices in cat covariates to do classification on, i.e. exclude from inference; this is a torch tensor 84 | ord_idx=None, # which indices in cat covariates to do ordinal regression on and also exclude from inference; this is a torch tensor 85 | reg_idx=None, # which indices in cont covariates to do regression on and also exclude from inference; this is a torch tensor 86 | activation="leaky_relu", 87 | initialization=None, 88 | anneal_class_loss=False, 89 | ): 90 | super().__init__() 91 | 92 | if activation == "leaky_relu": 93 | self.activation = nn.LeakyReLU 94 | elif activation == "tanh": 95 | self.activation = nn.Tanh 96 | else: 97 | raise NotImplementedError( 98 | f'activation should be one of ["leaky_relu", "tanh"], but activation={activation} was passed.' 
99 | ) 100 | 101 | self.class_loss_coef = class_loss_coef 102 | self.regression_loss_coef = regression_loss_coef 103 | self.sample_batch_size = sample_batch_size 104 | self.anneal_class_loss = anneal_class_loss 105 | self.num_classification_classes = num_classification_classes 106 | self.class_idx = class_idx 107 | self.ord_idx = ord_idx 108 | self.reg_idx = reg_idx 109 | 110 | self.cell_level_aggregator = nn.Sequential( 111 | MLP( 112 | z_dim, 113 | z_dim, 114 | n_layers=n_layers_cell_aggregator, 115 | n_hidden=n_hidden_cell_aggregator, 116 | dropout_rate=dropout, 117 | activation=self.activation, 118 | normalization=normalization, 119 | ), 120 | Aggregator( 121 | n_input=z_dim, 122 | scoring=scoring, 123 | attn_dim=attn_dim, 124 | sample_batch_size=sample_batch_size, 125 | scale=True, 126 | dropout=dropout, 127 | n_layers_mlp_attn=n_layers_mlp_attn, 128 | n_hidden_mlp_attn=n_hidden_mlp_attn, 129 | activation=self.activation, 130 | ), 131 | ) 132 | 133 | if len(self.class_idx) > 0: 134 | self.classifiers = torch.nn.ModuleList() 135 | 136 | # classify zs directly 137 | class_input_dim = z_dim 138 | 139 | for num in self.num_classification_classes: 140 | if n_layers_classifier == 1: 141 | self.classifiers.append(nn.Linear(class_input_dim, num)) 142 | else: 143 | self.classifiers.append( 144 | nn.Sequential( 145 | MLP( 146 | class_input_dim, 147 | n_hidden_classifier, 148 | n_layers=n_layers_classifier - 1, 149 | n_hidden=n_hidden_classifier, 150 | dropout_rate=dropout, 151 | activation=self.activation, 152 | ), 153 | nn.Linear(n_hidden_classifier, num), 154 | ) 155 | ) 156 | 157 | if len(self.ord_idx) + len(self.reg_idx) > 0: 158 | self.regressors = torch.nn.ModuleList() 159 | for _ in range( 160 | len(self.ord_idx) + len(self.reg_idx) 161 | ): # one head per standard regression and one per ordinal regression 162 | if n_layers_regressor == 1: 163 | self.regressors.append(nn.Linear(z_dim, 1)) 164 | else: 165 | self.regressors.append( 166 | nn.Sequential( 167 | MLP( 168 | z_dim, 169 | n_hidden_regressor, 170 | n_layers=n_layers_regressor - 1, 171 | n_hidden=n_hidden_regressor, 172 | dropout_rate=dropout, 173 | activation=self.activation, 174 | ), 175 | nn.Linear(n_hidden_regressor, 1), 176 | ) 177 | ) 178 | 179 | if initialization == "xavier": 180 | for layer in self.modules(): 181 | if isinstance(layer, nn.Linear): 182 | nn.init.xavier_uniform_(layer.weight, gain=nn.init.calculate_gain("leaky_relu")) 183 | elif initialization == "kaiming": 184 | for layer in self.modules(): 185 | if isinstance(layer, nn.Linear): 186 | # following https://towardsdatascience.com/understand-kaiming-initialization-and-implementation-detail-in-pytorch-f7aa967e9138 (accessed 16.08.22) 187 | nn.init.kaiming_normal_(layer.weight, mode="fan_in") 188 | 189 | def _get_inference_input(self, tensors): 190 | x = tensors[REGISTRY_KEYS.X_KEY] 191 | return {"x": x} 192 | 193 | def _get_generative_input(self, tensors, inference_outputs): 194 | z_joint = inference_outputs["z_joint"] 195 | return {"z_joint": z_joint} 196 | 197 | @auto_move_data 198 | def inference(self, x) -> dict[str, torch.Tensor | list[torch.Tensor]]: 199 | """Forward pass for inference. 200 | 201 | Parameters 202 | ---------- 203 | x 204 | Input. 205 | 206 | Returns 207 | ------- 208 | Predictions. 
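        Notes
        -----
        Cells are split into bags of ``sample_batch_size`` cells, each bag is aggregated
        with the attention-based cell-level aggregator, and the resulting bag embeddings
        are passed to the classification and regression heads. Schematically (shapes are
        illustrative only):

        >>> # x:           (n_cells, z_dim) latent cell representations
        >>> # bags:        (n_bags, sample_batch_size, z_dim) after splitting
        >>> # bag_z:       (n_bags, z_dim) after the cell-level aggregator
        >>> # predictions: one tensor per classification / ordinal / regression task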
209 | """ 210 | z_joint = x 211 | inference_outputs = {"z_joint": z_joint} 212 | 213 | # MIL part 214 | batch_size = x.shape[0] 215 | 216 | idx = list(range(self.sample_batch_size, batch_size, self.sample_batch_size)) 217 | if ( 218 | batch_size % self.sample_batch_size != 0 219 | ): # can only happen during inference for last batches for each sample 220 | idx = [] 221 | zs = torch.tensor_split(z_joint, idx, dim=0) 222 | zs = torch.stack(zs, dim=0) # num of bags x batch_size x z_dim 223 | zs_attn = self.cell_level_aggregator(zs) # num of bags x cond_dim 224 | 225 | predictions = [] 226 | if len(self.class_idx) > 0: 227 | predictions.extend([classifier(zs_attn) for classifier in self.classifiers]) 228 | if len(self.ord_idx) + len(self.reg_idx) > 0: 229 | predictions.extend([regressor(zs_attn) for regressor in self.regressors]) 230 | 231 | inference_outputs.update( 232 | {"predictions": predictions} 233 | ) # predictions are a list as they can have different number of classes 234 | return inference_outputs # z_joint, mu, logvar, predictions 235 | 236 | @auto_move_data 237 | def generative(self, z_joint) -> torch.Tensor: 238 | # TODO even if not used, make consistent with the rest, i.e. return dict 239 | """Forward pass for generative. 240 | 241 | Parameters 242 | ---------- 243 | z_joint 244 | Latent embeddings. 245 | 246 | Returns 247 | ------- 248 | Same as input. 249 | """ 250 | return z_joint 251 | 252 | def _calculate_loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0): 253 | cont_key = REGISTRY_KEYS.CONT_COVS_KEY 254 | cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None 255 | 256 | cat_key = REGISTRY_KEYS.CAT_COVS_KEY 257 | cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None 258 | 259 | # MIL classification loss 260 | minibatch_size, n_samples_in_batch = prep_minibatch(cat_covs, self.sample_batch_size) 261 | regression = select_covariates(cont_covs, self.reg_idx.to(self.device), n_samples_in_batch) 262 | ordinal_regression = select_covariates(cat_covs, self.ord_idx.to(self.device), n_samples_in_batch) 263 | classification = select_covariates(cat_covs, self.class_idx.to(self.device), n_samples_in_batch) 264 | 265 | predictions = inference_outputs["predictions"] # list, first from classifiers, then from regressors 266 | 267 | accuracies = [] 268 | classification_loss = torch.tensor(0.0).to(self.device) 269 | for i in range(len(self.class_idx)): 270 | classification_loss += F.cross_entropy( 271 | predictions[i], classification[:, i].long() 272 | ) # assume same in the batch 273 | accuracies.append( 274 | torch.sum(torch.eq(torch.argmax(predictions[i], dim=-1), classification[:, i])) 275 | / classification[:, i].shape[0] 276 | ) 277 | 278 | regression_loss = torch.tensor(0.0).to(self.device) 279 | for i in range(len(self.ord_idx)): 280 | regression_loss += F.mse_loss(predictions[len(self.class_idx) + i].squeeze(-1), ordinal_regression[:, i]) 281 | accuracies.append( 282 | torch.sum( 283 | torch.eq( 284 | torch.clamp( 285 | torch.round(predictions[len(self.class_idx) + i].squeeze()), 286 | min=0.0, 287 | max=self.num_classification_classes[i] - 1.0, 288 | ), 289 | ordinal_regression[:, i], 290 | ) 291 | ) 292 | / ordinal_regression[:, i].shape[0] 293 | ) 294 | 295 | for i in range(len(self.reg_idx)): 296 | regression_loss += F.mse_loss( 297 | predictions[len(self.class_idx) + len(self.ord_idx) + i].squeeze(-1), 298 | regression[:, i], 299 | ) 300 | 301 | class_loss_anneal_coef = kl_weight if self.anneal_class_loss else 1.0 302 
| 303 | loss = torch.mean( 304 | self.class_loss_coef * classification_loss * class_loss_anneal_coef 305 | + self.regression_loss_coef * regression_loss 306 | ) 307 | 308 | extra_metrics = { 309 | "class_loss": classification_loss, 310 | "regression_loss": regression_loss, 311 | } 312 | 313 | if len(accuracies) > 0: 314 | accuracy = torch.sum(torch.tensor(accuracies)) / len(accuracies) 315 | extra_metrics["accuracy"] = accuracy 316 | 317 | # don't need in this model but have to return 318 | recon_loss = torch.zeros(minibatch_size) 319 | kl_loss = torch.zeros(minibatch_size) 320 | 321 | return loss, recon_loss, kl_loss, extra_metrics 322 | 323 | def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0): 324 | """Loss calculation. 325 | 326 | Parameters 327 | ---------- 328 | tensors 329 | Input tensors. 330 | inference_outputs 331 | Inference outputs. 332 | generative_outputs 333 | Generative outputs. 334 | kl_weight 335 | KL weight. Default is 1.0. 336 | 337 | Returns 338 | ------- 339 | Prediction loss. 340 | """ 341 | loss, recon_loss, kl_loss, extra_metrics = self._calculate_loss( 342 | tensors, inference_outputs, generative_outputs, kl_weight 343 | ) 344 | 345 | return LossOutput( 346 | loss=loss, 347 | reconstruction_loss=recon_loss, 348 | kl_local=kl_loss, 349 | extra_metrics=extra_metrics, 350 | ) 351 | 352 | def select_losses_to_plot(self): 353 | """Select losses to plot. 354 | 355 | Returns 356 | ------- 357 | Loss names. 358 | """ 359 | loss_names = [] 360 | if self.class_loss_coef != 0 and len(self.class_idx) > 0: 361 | loss_names.extend(["class_loss", "accuracy"]) 362 | if self.regression_loss_coef != 0 and len(self.reg_idx) > 0: 363 | loss_names.append("regression_loss") 364 | if self.regression_loss_coef != 0 and len(self.ord_idx) > 0: 365 | loss_names.extend(["regression_loss", "accuracy"]) 366 | return loss_names 367 | -------------------------------------------------------------------------------- /src/multimil/module/_multivae_mil_torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scvi import REGISTRY_KEYS 3 | from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data 4 | 5 | from multimil.module import MILClassifierTorch, MultiVAETorch 6 | 7 | 8 | class MultiVAETorch_MIL(BaseModuleClass): 9 | """MultiMIL's end-to-end multimodal integration and MIL classification modules. 10 | 11 | Parameters 12 | ---------- 13 | modality_lengths 14 | Number of features for each modality. 15 | condition_encoders 16 | Whether to condition the encoders on the covariates. 17 | condition_decoders 18 | Whether to condition the decoders on the covariates. 19 | normalization 20 | Normalization to use in the network. 21 | z_dim 22 | Dimensionality of the latent space. 23 | losses 24 | List of losses to use in the VAE. 25 | dropout 26 | Dropout rate. 27 | cond_dim 28 | Dimensionality of the covariate embeddings. 29 | kernel_type 30 | Type of kernel to use for the MMD loss. 31 | loss_coefs 32 | Coefficients for the different losses. 33 | num_groups 34 | Number of groups to use for the MMD loss. 35 | integrate_on_idx 36 | Indices of the covariates to integrate on. 37 | n_layers_encoders 38 | Number of layers in the encoders. 39 | n_layers_decoders 40 | Number of layers in the decoders. 41 | n_hidden_encoders 42 | Number of hidden units in the encoders. 43 | n_hidden_decoders 44 | Number of hidden units in the decoders. 
45 | num_classification_classes 46 | Number of classes for each of the classification task. 47 | scoring 48 | Scoring function to use for the MIL classification. 49 | attn_dim 50 | Dimensionality of the hidden attention dimension. 51 | cat_covariate_dims 52 | Number of categories for each of the categorical covariates. 53 | cont_covariate_dims 54 | Number of categories for each of the continuous covariates. Always 1. 55 | cat_covs_idx 56 | Indices of the categorical covariates. 57 | cont_covs_idx 58 | Indices of the continuous covariates. 59 | cont_cov_type 60 | Type of continuous covariate. 61 | n_layers_cell_aggregator 62 | Number of layers in the cell aggregator. 63 | n_layers_classifier 64 | Number of layers in the classifier. 65 | n_layers_mlp_attn 66 | Number of layers in the attention MLP. 67 | n_layers_cont_embed 68 | Number of layers in the continuous embedding calculation. 69 | n_layers_regressor 70 | Number of layers in the regressor. 71 | n_hidden_regressor 72 | Number of hidden units in the regressor. 73 | n_hidden_cell_aggregator 74 | Number of hidden units in the cell aggregator. 75 | n_hidden_classifier 76 | Number of hidden units in the classifier. 77 | n_hidden_mlp_attn 78 | Number of hidden units in the attention MLP. 79 | n_hidden_cont_embed 80 | Number of hidden units in the continuous embedding calculation. 81 | class_loss_coef 82 | Coefficient for the classification loss. 83 | regression_loss_coef 84 | Coefficient for the regression loss. 85 | sample_batch_size 86 | Bag size. 87 | class_idx 88 | Which indices in cat covariates to do classification on. 89 | ord_idx 90 | Which indices in cat covariates to do ordinal regression on. 91 | reg_idx 92 | Which indices in cont covariates to do regression on. 93 | mmd 94 | Type of MMD loss to use. 95 | activation 96 | Activation function to use. 97 | initialization 98 | Initialization method to use. 99 | anneal_class_loss 100 | Whether to anneal the classification loss. 101 | """ 102 | 103 | def __init__( 104 | self, 105 | modality_lengths, 106 | condition_encoders=False, 107 | condition_decoders=True, 108 | normalization="layer", 109 | z_dim=16, 110 | losses=None, 111 | dropout=0.2, 112 | cond_dim=16, 113 | kernel_type="gaussian", 114 | loss_coefs=None, 115 | num_groups=1, 116 | integrate_on_idx=None, 117 | n_layers_encoders=None, 118 | n_layers_decoders=None, 119 | n_hidden_encoders=None, 120 | n_hidden_decoders=None, 121 | num_classification_classes=None, # number of classes for each of the classification task 122 | scoring="gated_attn", 123 | attn_dim=16, 124 | cat_covariate_dims=None, 125 | cont_covariate_dims=None, 126 | cat_covs_idx=None, 127 | cont_covs_idx=None, 128 | cont_cov_type="logsigm", 129 | n_layers_cell_aggregator=1, 130 | n_layers_classifier=1, 131 | n_layers_mlp_attn=1, 132 | n_layers_cont_embed=1, 133 | n_layers_regressor=1, 134 | n_hidden_regressor=16, 135 | n_hidden_cell_aggregator=16, 136 | n_hidden_classifier=16, 137 | n_hidden_mlp_attn=16, 138 | n_hidden_cont_embed=16, 139 | class_loss_coef=1.0, 140 | regression_loss_coef=1.0, 141 | sample_batch_size=128, 142 | class_idx=None, # which indices in cat covariates to do classification on, i.e. 
exclude from inference 143 | ord_idx=None, # which indices in cat covariates to do ordinal regression on and also exclude from inference 144 | reg_idx=None, # which indices in cont covariates to do regression on and also exclude from inference 145 | mmd="latent", 146 | activation="leaky_relu", 147 | initialization=None, 148 | anneal_class_loss=False, 149 | ): 150 | super().__init__() 151 | 152 | self.vae_module = MultiVAETorch( 153 | modality_lengths=modality_lengths, 154 | condition_encoders=condition_encoders, 155 | condition_decoders=condition_decoders, 156 | normalization=normalization, 157 | z_dim=z_dim, 158 | losses=losses, 159 | dropout=dropout, 160 | cond_dim=cond_dim, 161 | kernel_type=kernel_type, 162 | loss_coefs=loss_coefs, 163 | num_groups=num_groups, 164 | integrate_on_idx=integrate_on_idx, 165 | cat_covariate_dims=cat_covariate_dims, # only the actual categorical covs are considered here 166 | cont_covariate_dims=cont_covariate_dims, # only the actual cont covs are considered here 167 | cat_covs_idx=cat_covs_idx, 168 | cont_covs_idx=cont_covs_idx, 169 | cont_cov_type=cont_cov_type, 170 | n_layers_encoders=n_layers_encoders, 171 | n_layers_decoders=n_layers_decoders, 172 | n_layers_cont_embed=n_layers_cont_embed, 173 | n_hidden_encoders=n_hidden_encoders, 174 | n_hidden_decoders=n_hidden_decoders, 175 | n_hidden_cont_embed=n_hidden_cont_embed, 176 | mmd=mmd, 177 | activation=activation, 178 | initialization=initialization, 179 | ) 180 | self.mil_module = MILClassifierTorch( 181 | z_dim=z_dim, 182 | dropout=dropout, 183 | n_layers_cell_aggregator=n_layers_cell_aggregator, 184 | n_layers_classifier=n_layers_classifier, 185 | n_layers_mlp_attn=n_layers_mlp_attn, 186 | n_layers_regressor=n_layers_regressor, 187 | n_hidden_regressor=n_hidden_regressor, 188 | n_hidden_cell_aggregator=n_hidden_cell_aggregator, 189 | n_hidden_classifier=n_hidden_classifier, 190 | n_hidden_mlp_attn=n_hidden_mlp_attn, 191 | class_loss_coef=class_loss_coef, 192 | regression_loss_coef=regression_loss_coef, 193 | sample_batch_size=sample_batch_size, 194 | anneal_class_loss=anneal_class_loss, 195 | num_classification_classes=num_classification_classes, 196 | class_idx=class_idx, 197 | ord_idx=ord_idx, 198 | reg_idx=reg_idx, 199 | activation=activation, 200 | initialization=initialization, 201 | normalization=normalization, 202 | scoring=scoring, 203 | attn_dim=attn_dim, 204 | ) 205 | 206 | def _get_inference_input(self, tensors): 207 | x = tensors[REGISTRY_KEYS.X_KEY] 208 | 209 | cont_key = REGISTRY_KEYS.CONT_COVS_KEY 210 | cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None 211 | 212 | cat_key = REGISTRY_KEYS.CAT_COVS_KEY 213 | cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None 214 | 215 | return {"x": x, "cat_covs": cat_covs, "cont_covs": cont_covs} 216 | 217 | def _get_generative_input(self, tensors, inference_outputs): 218 | z_joint = inference_outputs["z_joint"] 219 | 220 | cont_key = REGISTRY_KEYS.CONT_COVS_KEY 221 | cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None 222 | 223 | cat_key = REGISTRY_KEYS.CAT_COVS_KEY 224 | cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None 225 | 226 | return {"z_joint": z_joint, "cat_covs": cat_covs, "cont_covs": cont_covs} 227 | 228 | @auto_move_data 229 | def inference(self, x, cat_covs, cont_covs) -> dict[str, torch.Tensor | list[torch.Tensor]]: 230 | """Forward pass for inference. 231 | 232 | Parameters 233 | ---------- 234 | x 235 | Input. 
236 | cat_covs 237 | Categorical covariates to condition on. 238 | cont_covs 239 | Continuous covariates to condition on. 240 | 241 | Returns 242 | ------- 243 | Joint representations, marginal representations, joint mu's and logvar's and predictions. 244 | """ 245 | # VAE part 246 | inference_outputs = self.vae_module.inference(x, cat_covs, cont_covs) 247 | z_joint = inference_outputs["z_joint"] 248 | 249 | # MIL part 250 | mil_inference_outputs = self.mil_module.inference(z_joint) 251 | inference_outputs.update(mil_inference_outputs) 252 | return inference_outputs # z_joint, mu, logvar, z_marginal, predictions 253 | 254 | @auto_move_data 255 | def generative(self, z_joint, cat_covs, cont_covs) -> dict[str, torch.Tensor]: 256 | """Compute necessary inference quantities. 257 | 258 | Parameters 259 | ---------- 260 | z_joint 261 | Tensor of values with shape ``(batch_size, z_dim)``. 262 | cat_covs 263 | Categorical covariates to condition on. 264 | cont_covs 265 | Continuous covariates to condition on. 266 | 267 | Returns 268 | ------- 269 | Reconstructed values for each modality. 270 | """ 271 | return self.vae_module.generative(z_joint, cat_covs, cont_covs) 272 | 273 | def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0): 274 | """Calculate the (modality) reconstruction loss, Kullback divergences and integration loss. 275 | 276 | Parameters 277 | ---------- 278 | tensors 279 | Tensor of values with shape ``(batch_size, n_input_features)``. 280 | inference_outputs 281 | Dictionary with the inference output. 282 | generative_outputs 283 | Dictionary with the generative output. 284 | kl_weight 285 | Weight of the KL loss. Default is 1.0. 286 | 287 | Returns 288 | ------- 289 | Reconstruction loss, Kullback divergences, integration loss, modality reconstruction and prediction losses. 290 | """ 291 | loss_vae, recon_loss, kl_loss, extra_metrics = self.vae_module._calculate_loss( 292 | tensors, inference_outputs, generative_outputs, kl_weight 293 | ) 294 | loss_mil, _, _, extra_metrics_mil = self.mil_module._calculate_loss( 295 | tensors, inference_outputs, generative_outputs, kl_weight 296 | ) 297 | loss = loss_vae + loss_mil 298 | extra_metrics.update(extra_metrics_mil) 299 | 300 | return LossOutput( 301 | loss=loss, 302 | reconstruction_loss=recon_loss, 303 | kl_local=kl_loss, 304 | extra_metrics=extra_metrics, 305 | ) 306 | -------------------------------------------------------------------------------- /src/multimil/module/_multivae_torch.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Literal 3 | 4 | import torch 5 | from scvi import REGISTRY_KEYS 6 | from scvi.distributions import NegativeBinomial, ZeroInflatedNegativeBinomial 7 | from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data 8 | from torch import nn 9 | from torch.distributions import Normal 10 | from torch.distributions import kl_divergence as kl 11 | 12 | from multimil.distributions import MMD, Jeffreys 13 | from multimil.nn import MLP, Decoder, GeneralizedSigmoid 14 | 15 | 16 | class MultiVAETorch(BaseModuleClass): 17 | """MultiMIL's multimodal integration module. 18 | 19 | Parameters 20 | ---------- 21 | modality_lengths 22 | List with lengths of each modality. 23 | condition_encoders 24 | Boolean to indicate if to condition encoders. 25 | condition_decoders 26 | Boolean to indicate if to condition decoders. 
27 |     normalization
28 |         One of the following:
29 |         * ``'layer'`` - layer normalization
30 |         * ``'batch'`` - batch normalization
31 |         * ``None`` - no normalization.
32 |     z_dim
33 |         Dimensionality of the latent space.
34 |     losses
35 |         List of which losses to use. For each modality can be one of the following:
36 |         * ``'mse'`` - mean squared error
37 |         * ``'nb'`` - negative binomial
38 |         * ``'zinb'`` - zero-inflated negative binomial
39 |         * ``'bce'`` - binary cross-entropy.
40 |     dropout
41 |         Dropout rate for neural networks.
42 |     cond_dim
43 |         Dimensionality of the covariate embeddings.
44 |     kernel_type
45 |         One of the following:
46 |         * ``'gaussian'`` - Gaussian kernel
47 |         * ``'not gaussian'`` - not Gaussian kernel.
48 |     loss_coefs
49 |         Dictionary with weights for each of the losses.
50 |     num_groups
51 |         Number of groups to integrate on.
52 |     integrate_on_idx
53 |         Indices of the covariates to integrate on.
54 |     cat_covariate_dims
55 |         List with the number of classes for each of the categorical covariates.
56 |     cont_covariate_dims
57 |         List of 1's, one for each continuous covariate.
58 |     cont_cov_type
59 |         How to transform continuous covariates before multiplying with the embedding. One of the following:
60 |         * ``'logsigm'`` - generalized sigmoid
61 |         * ``'mlp'`` - MLP.
62 |     n_layers_cont_embed
63 |         Number of layers for the transformation of the continuous covariates before multiplying with the embedding.
64 |     n_hidden_cont_embed
65 |         Number of nodes in hidden layers in the network that transforms continuous covariates.
66 |     n_layers_encoders
67 |         Number of layers in each encoder.
68 |     n_layers_decoders
69 |         Number of layers in each decoder.
70 |     n_hidden_encoders
71 |         Number of nodes in hidden layers in encoders.
72 |     n_hidden_decoders
73 |         Number of nodes in hidden layers in decoders.
74 |     alignment_type
75 |         How to calculate the integration loss. One of the following:
76 |         * ``'latent'`` - only on the latent representations
77 |         * ``'marginal'`` - only on the marginal representations
78 |         * ``'both'`` - the sum of the two above.
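    cat_covs_idx
        Indices of the categorical covariates to use for conditioning.
    cont_covs_idx
        Indices of the continuous covariates to use for conditioning.
    modality_alignment
        Which loss to use to align modalities. One of the following:
        * ``'MMD'`` - maximum mean discrepancy
        * ``'Jeffreys'`` - Jeffreys divergence
        * ``None`` - no alignment.
    activation
        Activation function to use. One of ``'leaky_relu'`` or ``'tanh'``.
    initialization
        Weight initialization to use. One of ``'xavier'``, ``'kaiming'`` or ``None``.
    mix
        How to combine the per-modality posteriors. One of the following:
        * ``'product'`` - product of experts
        * ``'mixture'`` - mixture of experts.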
79 | """ 80 | 81 | def __init__( 82 | self, 83 | modality_lengths, 84 | condition_encoders=False, 85 | condition_decoders=True, 86 | normalization: Literal["layer", "batch", None] = "layer", 87 | z_dim=16, 88 | losses=None, 89 | dropout=0.2, 90 | cond_dim=16, 91 | kernel_type="gaussian", 92 | loss_coefs=None, 93 | num_groups=1, 94 | integrate_on_idx=None, 95 | cat_covariate_dims=None, 96 | cont_covariate_dims=None, 97 | cat_covs_idx=None, 98 | cont_covs_idx=None, 99 | cont_cov_type="logsigm", 100 | n_layers_cont_embed: int = 1, 101 | n_layers_encoders=None, 102 | n_layers_decoders=None, 103 | n_hidden_cont_embed: int = 16, 104 | n_hidden_encoders=None, 105 | n_hidden_decoders=None, 106 | modality_alignment=None, 107 | alignment_type="latent", 108 | activation="leaky_relu", 109 | initialization=None, 110 | mix="product", 111 | ): 112 | super().__init__() 113 | 114 | self.input_dims = modality_lengths 115 | self.condition_encoders = condition_encoders 116 | self.condition_decoders = condition_decoders 117 | self.n_modality = len(self.input_dims) 118 | self.kernel_type = kernel_type 119 | self.integrate_on_idx = integrate_on_idx 120 | self.n_cont_cov = len(cont_covariate_dims) 121 | self.cont_cov_type = cont_cov_type 122 | self.alignment_type = alignment_type 123 | self.modality_alignment = modality_alignment 124 | self.normalization = normalization 125 | self.z_dim = z_dim 126 | self.dropout = dropout 127 | self.cond_dim = cond_dim 128 | self.kernel_type = kernel_type 129 | self.n_layers_cont_embed = n_layers_cont_embed 130 | self.n_layers_encoders = n_layers_encoders 131 | self.n_layers_decoders = n_layers_decoders 132 | self.n_hidden_cont_embed = n_hidden_cont_embed 133 | self.n_hidden_encoders = n_hidden_encoders 134 | self.n_hidden_decoders = n_hidden_decoders 135 | self.cat_covs_idx = cat_covs_idx 136 | self.cont_covs_idx = cont_covs_idx 137 | self.mix = mix 138 | 139 | if activation == "leaky_relu": 140 | self.activation = nn.LeakyReLU 141 | elif activation == "tanh": 142 | self.activation = nn.Tanh 143 | else: 144 | raise NotImplementedError( 145 | f'activation should be one of ["leaky_relu", "tanh"], but activation={activation} was passed.' 146 | ) 147 | 148 | # TODO: add warnings that mse is used 149 | if losses is None: 150 | self.losses = ["mse"] * self.n_modality 151 | elif len(losses) == self.n_modality: 152 | self.losses = losses 153 | else: 154 | raise ValueError( 155 | f"losses has to be the same length as the number of modalities. number of modalities = {self.n_modality} != {len(losses)} = len(losses)" 156 | ) 157 | if cat_covariate_dims is None: 158 | raise ValueError("cat_covariate_dims = None was passed.") 159 | if cont_covariate_dims is None: 160 | raise ValueError("cont_covariate_dims = None was passed.") 161 | 162 | # TODO: add warning that using these 163 | if self.n_layers_encoders is None: 164 | self.n_layers_encoders = [2] * self.n_modality 165 | if self.n_layers_decoders is None: 166 | self.n_layers_decoders = [2] * self.n_modality 167 | if self.n_hidden_encoders is None: 168 | self.n_hidden_encoders = [128] * self.n_modality 169 | if self.n_hidden_decoders is None: 170 | self.n_hidden_decoders = [128] * self.n_modality 171 | 172 | self.loss_coefs = { 173 | "recon": 1, 174 | "kl": 1e-6, 175 | "integ": 0, 176 | } 177 | for i in range(self.n_modality): 178 | self.loss_coefs[str(i)] = 1 179 | if loss_coefs is not None: 180 | self.loss_coefs.update(loss_coefs) 181 | 182 | # assume for now that can only use nb/zinb once, i.e. 
for RNA-seq modality 183 | # TODO: add check for multiple nb/zinb losses given 184 | self.theta = None 185 | for i, loss in enumerate(losses): 186 | if loss in ["nb", "zinb"]: 187 | self.theta = torch.nn.Parameter(torch.randn(self.input_dims[i], num_groups)) 188 | break 189 | 190 | # modality encoders 191 | cond_dim_enc = cond_dim * (len(cat_covariate_dims) + len(cont_covariate_dims)) if self.condition_encoders else 0 192 | self.encoders = [ 193 | MLP( 194 | n_input=x_dim + cond_dim_enc, 195 | n_output=z_dim, 196 | n_layers=n_layers, 197 | n_hidden=n_hidden, 198 | dropout_rate=dropout, 199 | normalization=normalization, 200 | activation=self.activation, 201 | ) 202 | for x_dim, n_layers, n_hidden in zip( 203 | self.input_dims, self.n_layers_encoders, self.n_hidden_encoders, strict=False 204 | ) 205 | ] 206 | 207 | # modality decoders 208 | cond_dim_dec = cond_dim * (len(cat_covariate_dims) + len(cont_covariate_dims)) if self.condition_decoders else 0 209 | dec_input = z_dim 210 | self.decoders = [ 211 | Decoder( 212 | n_input=dec_input + cond_dim_dec, 213 | n_output=x_dim, 214 | n_layers=n_layers, 215 | n_hidden=n_hidden, 216 | dropout_rate=dropout, 217 | normalization=normalization, 218 | activation=self.activation, 219 | loss=loss, 220 | ) 221 | for x_dim, loss, n_layers, n_hidden in zip( 222 | self.input_dims, self.losses, self.n_layers_decoders, self.n_hidden_decoders, strict=False 223 | ) 224 | ] 225 | 226 | self.mus = [nn.Linear(z_dim, z_dim) for _ in self.input_dims] 227 | self.logvars = [nn.Linear(z_dim, z_dim) for _ in self.input_dims] 228 | 229 | self.cat_covariate_embeddings = [nn.Embedding(dim, cond_dim) for dim in cat_covariate_dims] 230 | if self.n_cont_cov > 0: 231 | self.cont_covariate_embeddings = nn.Embedding(self.n_cont_cov, cond_dim) 232 | if self.cont_cov_type == "mlp": 233 | self.cont_covariate_curves = torch.nn.ModuleList() 234 | for _ in range(self.n_cont_cov): 235 | n_input = n_hidden_cont_embed if self.n_layers_cont_embed > 1 else 1 236 | self.cont_covariate_curves.append( 237 | nn.Sequential( 238 | MLP( 239 | n_input=1, 240 | n_output=n_hidden_cont_embed, 241 | n_layers=self.n_layers_cont_embed - 1, 242 | n_hidden=n_hidden_cont_embed, 243 | dropout_rate=dropout, 244 | normalization=normalization, 245 | activation=self.activation, 246 | ), 247 | nn.Linear(n_input, 1), 248 | ) 249 | if self.n_layers_cont_embed > 1 250 | else nn.Linear(n_input, 1) 251 | ) 252 | else: 253 | self.cont_covariate_curves = GeneralizedSigmoid( 254 | dim=self.n_cont_cov, 255 | nonlin=self.cont_cov_type, 256 | ) 257 | 258 | # register sub-modules 259 | for i, (enc, dec, mu, logvar) in enumerate( 260 | zip(self.encoders, self.decoders, self.mus, self.logvars, strict=False) 261 | ): 262 | self.add_module(f"encoder_{i}", enc) 263 | self.add_module(f"decoder_{i}", dec) 264 | self.add_module(f"mu_{i}", mu) 265 | self.add_module(f"logvar_{i}", logvar) 266 | 267 | for i, emb in enumerate(self.cat_covariate_embeddings): 268 | self.add_module(f"cat_covariate_embedding_{i}", emb) 269 | 270 | if initialization is not None: 271 | if initialization == "xavier": 272 | if activation != "leaky_relu": 273 | warnings.warn( 274 | f"We recommend using Xavier initialization with leaky_relu, but activation={activation} was passed.", 275 | stacklevel=2, 276 | ) 277 | for layer in self.modules(): 278 | if isinstance(layer, nn.Linear): 279 | nn.init.xavier_uniform_(layer.weight, gain=nn.init.calculate_gain(activation)) 280 | elif initialization == "kaiming": 281 | if activation != "tanh": 282 | warnings.warn( 283 | 
f"We recommend using Kaiming initialization with tanh, but activation={activation} was passed.", 284 | stacklevel=2, 285 | ) 286 | for layer in self.modules(): 287 | if isinstance(layer, nn.Linear): 288 | # following https://towardsdatascience.com/understand-kaiming-initialization-and-implementation-detail-in-pytorch-f7aa967e9138 (accessed 16.08.22) 289 | nn.init.kaiming_normal_(layer.weight, mode="fan_in") 290 | 291 | def _reparameterize(self, mu, logvar): 292 | std = torch.exp(0.5 * logvar) 293 | eps = torch.randn_like(std) 294 | return mu + eps * std 295 | 296 | def _bottleneck(self, z, i): 297 | mu = self.mus[i](z) 298 | logvar = self.logvars[i](z) 299 | z = self._reparameterize(mu, logvar) 300 | return z, mu, logvar 301 | 302 | def _x_to_h(self, x, i): 303 | return self.encoders[i](x) 304 | 305 | def _h_to_x(self, h, i): 306 | x = self.decoders[i](h) 307 | return x 308 | 309 | def _product_of_experts(self, mus, logvars, masks): 310 | vars = torch.exp(logvars) 311 | masks = masks.unsqueeze(-1).repeat(1, 1, vars.shape[-1]) 312 | mus_joint = torch.sum(mus * masks / vars, dim=1) 313 | vars_joint = torch.ones_like(mus_joint) # batch size 314 | vars_joint += torch.sum(masks / vars, dim=1) 315 | vars_joint = 1.0 / vars_joint # inverse 316 | mus_joint *= vars_joint 317 | logvars_joint = torch.log(vars_joint) 318 | return mus_joint, logvars_joint 319 | 320 | def _mixture_of_experts(self, mus, logvars, masks): 321 | vars = torch.exp(logvars) 322 | masks = masks.unsqueeze(-1).repeat(1, 1, vars.shape[-1]) 323 | masks = masks.float() 324 | weights = masks / torch.sum(masks, dim=1, keepdim=True) # normalize so the sum is 1 325 | # params of 1/2 * (X + Y) 326 | mus_mixture = torch.sum(weights * mus, dim=1) 327 | vars_mixture = torch.sum(weights**2 * vars, dim=1) 328 | logvars_mixture = torch.log(vars_mixture) 329 | return mus_mixture, logvars_mixture 330 | 331 | def _get_inference_input(self, tensors): 332 | x = tensors[REGISTRY_KEYS.X_KEY] 333 | 334 | cont_key = REGISTRY_KEYS.CONT_COVS_KEY 335 | cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None 336 | 337 | cat_key = REGISTRY_KEYS.CAT_COVS_KEY 338 | cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None 339 | 340 | return {"x": x, "cat_covs": cat_covs, "cont_covs": cont_covs} 341 | 342 | def _get_generative_input(self, tensors, inference_outputs): 343 | z_joint = inference_outputs["z_joint"] 344 | 345 | cont_key = REGISTRY_KEYS.CONT_COVS_KEY 346 | cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None 347 | 348 | cat_key = REGISTRY_KEYS.CAT_COVS_KEY 349 | cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None 350 | 351 | return {"z_joint": z_joint, "cat_covs": cat_covs, "cont_covs": cont_covs} 352 | 353 | @auto_move_data 354 | def inference( 355 | self, 356 | x: torch.Tensor, 357 | cat_covs: torch.Tensor | None = None, 358 | cont_covs: torch.Tensor | None = None, 359 | masks: list[torch.Tensor] | None = None, 360 | ) -> dict[str, torch.Tensor | list[torch.Tensor]]: 361 | """Compute necessary inference quantities. 362 | 363 | Parameters 364 | ---------- 365 | x 366 | Tensor of values with shape ``(batch_size, n_input_features)``. 367 | cat_covs 368 | Categorical covariates to condition on. 369 | cont_covs 370 | Continuous covariates to condition on. 371 | masks 372 | List of binary tensors indicating which values in ``x`` belong to which modality. 373 | 374 | Returns 375 | ------- 376 | Joint representations, marginal representations, joint mu's and logvar's. 
377 | """ 378 | # split x into modality xs 379 | if torch.is_tensor(x): 380 | xs = torch.split( 381 | x, self.input_dims, dim=-1 382 | ) # list of tensors of len = n_mod, each tensor is of shape batch_size x mod_input_dim 383 | else: 384 | xs = x 385 | 386 | # TODO: check if masks still supported 387 | if masks is None: 388 | masks = [x.sum(dim=1) > 0 for x in xs] # list of masks per modality 389 | masks = torch.stack(masks, dim=1) 390 | # if we want to condition encoders, i.e. concat covariates to the input 391 | if self.condition_encoders is True: 392 | cat_embedds = self._select_cat_covariates(cat_covs) 393 | cont_embedds = self._select_cont_covariates(cont_covs) 394 | 395 | # concatenate input with categorical and continuous covariates 396 | xs = [ 397 | torch.cat([x, cat_embedds, cont_embedds], dim=-1) for x in xs 398 | ] # concat input to each modality x along the feature axis 399 | 400 | # TODO don't forward if mask is 0 for that dataset for that modality 401 | # hs = hidden state that we get after the encoder but before calculating mu and logvar for each modality 402 | hs = [self._x_to_h(x, mod) for mod, x in enumerate(xs)] 403 | # out = [zs_marginal, mus, logvars] and len(zs_marginal) = len(mus) = len(logvars) = number of modalities 404 | out = [self._bottleneck(h, mod) for mod, h in enumerate(hs)] 405 | # split out into zs_marginal, mus and logvars TODO check if easier to use split 406 | zs_marginal = [mod_out[0] for mod_out in out] 407 | z_marginal = torch.stack(zs_marginal, dim=1) 408 | mus = [mod_out[1] for mod_out in out] 409 | mu = torch.stack(mus, dim=1) 410 | logvars = [mod_out[2] for mod_out in out] 411 | logvar = torch.stack(logvars, dim=1) 412 | if self.mix == "product": 413 | mu_joint, logvar_joint = self._product_of_experts(mu, logvar, masks) 414 | elif self.mix == "mixture": 415 | mu_joint, logvar_joint = self._mixture_of_experts(mu, logvar, masks) 416 | else: 417 | raise ValueError(f"mix should be one of ['product', 'mixture'], but mix={self.mix} was passed") 418 | z_joint = self._reparameterize(mu_joint, logvar_joint) 419 | # drop mus and logvars according to masks for kl calculation 420 | # TODO here or in loss calculation? check 421 | # return mus+mus_joint 422 | return { 423 | "z_joint": z_joint, 424 | "mu": mu_joint, 425 | "logvar": logvar_joint, 426 | "z_marginal": z_marginal, 427 | "mu_marginal": mu, 428 | "logvar_marginal": logvar, 429 | } 430 | 431 | @auto_move_data 432 | def generative( 433 | self, z_joint: torch.Tensor, cat_covs: torch.Tensor | None = None, cont_covs: torch.Tensor | None = None 434 | ) -> dict[str, list[torch.Tensor]]: 435 | """Compute necessary inference quantities. 436 | 437 | Parameters 438 | ---------- 439 | z_joint 440 | Tensor of values with shape ``(batch_size, z_dim)``. 441 | cat_covs 442 | Categorical covariates to condition on. 443 | cont_covs 444 | Continuous covariates to condition on. 445 | 446 | Returns 447 | ------- 448 | Reconstructed values for each modality. 
449 | """ 450 | z = z_joint.unsqueeze(1).repeat(1, self.n_modality, 1) 451 | zs = torch.split(z, 1, dim=1) 452 | 453 | if self.condition_decoders is True: 454 | cat_embedds = self._select_cat_covariates(cat_covs) 455 | cont_embedds = self._select_cont_covariates(cont_covs) 456 | 457 | zs = [ 458 | torch.cat([z.squeeze(1), cat_embedds, cont_embedds], dim=-1) for z in zs 459 | ] # concat embedding to each modality x along the feature axis 460 | 461 | rs = [self._h_to_x(z, mod) for mod, z in enumerate(zs)] 462 | return {"rs": rs} 463 | 464 | def _select_cat_covariates(self, cat_covs): 465 | if len(self.cat_covs_idx) > 0: 466 | cat_covs = torch.index_select(cat_covs, 1, self.cat_covs_idx.to(self.device)) 467 | cat_embedds = [ 468 | cat_covariate_embedding(covariate.long()) 469 | for cat_covariate_embedding, covariate in zip(self.cat_covariate_embeddings, cat_covs.T, strict=False) 470 | ] 471 | else: 472 | cat_embedds = [] 473 | 474 | if len(cat_embedds) > 0: 475 | cat_embedds = torch.cat(cat_embedds, dim=-1) # TODO check if concatenation is needed 476 | else: 477 | cat_embedds = torch.Tensor().to(self.device) 478 | return cat_embedds 479 | 480 | def _select_cont_covariates(self, cont_covs): 481 | if len(self.cont_covs_idx) > 0: 482 | cont_covs = torch.index_select(cont_covs, 1, self.cont_covs_idx.to(self.device)) 483 | if cont_covs.shape[-1] != self.n_cont_cov: # get rid of size_factors 484 | cont_covs = cont_covs[:, 0 : self.n_cont_cov] 485 | cont_embedds = self._compute_cont_cov_embeddings(cont_covs) 486 | else: 487 | cont_embedds = torch.Tensor().to(self.device) 488 | return cont_embedds 489 | 490 | def _calculate_loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0): 491 | x = tensors[REGISTRY_KEYS.X_KEY] 492 | if self.integrate_on_idx is not None: 493 | integrate_on = tensors.get(REGISTRY_KEYS.CAT_COVS_KEY)[:, self.integrate_on_idx] 494 | else: 495 | integrate_on = torch.zeros(x.shape[0], 1).to(self.device) 496 | 497 | size_factor = tensors.get(REGISTRY_KEYS.SIZE_FACTOR_KEY, None) 498 | 499 | rs = generative_outputs["rs"] 500 | mu = inference_outputs["mu"] 501 | logvar = inference_outputs["logvar"] 502 | z_joint = inference_outputs["z_joint"] 503 | z_marginal = inference_outputs["z_marginal"] # batch_size x n_modalities x latent_dim 504 | mu_marginal = inference_outputs["mu_marginal"] 505 | logvar_marginal = inference_outputs["logvar_marginal"] 506 | 507 | xs = torch.split( 508 | x, self.input_dims, dim=-1 509 | ) # list of tensors of len = n_mod, each tensor is of shape batch_size x mod_input_dim 510 | masks = [x.sum(dim=1) > 0 for x in xs] # [batch_size] * num_modalities 511 | 512 | recon_loss, modality_recon_losses = self._calc_recon_loss( 513 | xs, rs, self.losses, integrate_on, size_factor, self.loss_coefs, masks 514 | ) 515 | kl_loss = kl_weight * kl(Normal(mu, torch.sqrt(torch.exp(logvar))), Normal(0, 1)).sum(dim=1) 516 | 517 | integ_loss = torch.tensor(0.0).to(self.device) 518 | if self.loss_coefs["integ"] > 0: 519 | if self.alignment_type == "latent" or self.alignment_type == "both": 520 | integ_loss += self._calc_integ_loss(z_joint, mu, logvar, integrate_on).to(self.device) 521 | if self.alignment_type == "marginal" or self.alignment_type == "both": 522 | for i in range(len(masks)): 523 | for j in range(i + 1, len(masks)): 524 | idx_where_to_calc_integ_loss = torch.eq( 525 | masks[i] == masks[j], 526 | torch.eq(masks[i], torch.ones_like(masks[i])), 527 | ) 528 | if ( 529 | idx_where_to_calc_integ_loss.any() 530 | ): # if need to calc integ loss for a 
group between modalities 531 | marginal_i = z_marginal[:, i, :][idx_where_to_calc_integ_loss] 532 | marginal_j = z_marginal[:, j, :][idx_where_to_calc_integ_loss] 533 | 534 | mu_i = mu_marginal[:, i, :][idx_where_to_calc_integ_loss] 535 | mu_j = mu_marginal[:, j, :][idx_where_to_calc_integ_loss] 536 | 537 | logvar_i = logvar_marginal[:, i, :][idx_where_to_calc_integ_loss] 538 | logvar_j = logvar_marginal[:, j, :][idx_where_to_calc_integ_loss] 539 | 540 | marginals = torch.cat([marginal_i, marginal_j]) 541 | mus_marginal = torch.cat([mu_i, mu_j]) 542 | logvars_marginal = torch.cat([logvar_i, logvar_j]) 543 | 544 | modalities = torch.cat( 545 | [ 546 | torch.Tensor([i] * marginal_i.shape[0]), 547 | torch.Tensor([j] * marginal_j.shape[0]), 548 | ] 549 | ).to(self.device) 550 | 551 | integ_loss += self._calc_integ_loss( 552 | marginals, mus_marginal, logvars_marginal, modalities 553 | ).to(self.device) 554 | 555 | for i in range(len(masks)): 556 | marginal_i = z_marginal[:, i, :] 557 | marginal_i = marginal_i[masks[i]] 558 | 559 | mu_i = mu_marginal[:, i, :] 560 | mu_i = mu_i[masks[i]] 561 | 562 | logvar_i = logvar_marginal[:, i, :] 563 | logvar_i = logvar_i[masks[i]] 564 | 565 | group_marginal = integrate_on[masks[i]] 566 | integ_loss += self._calc_integ_loss(marginal_i, mu_i, logvar_i, group_marginal).to(self.device) 567 | 568 | loss = torch.mean( 569 | self.loss_coefs["recon"] * recon_loss 570 | + self.loss_coefs["kl"] * kl_loss 571 | + self.loss_coefs["integ"] * integ_loss 572 | ) 573 | 574 | modality_recon_losses = { 575 | f"modality_{i}_reconstruction_loss": modality_recon_losses[i] for i in range(len(modality_recon_losses)) 576 | } 577 | extra_metrics = {"integ_loss": integ_loss} 578 | extra_metrics.update(modality_recon_losses) 579 | 580 | return loss, recon_loss, kl_loss, extra_metrics 581 | 582 | def loss( 583 | self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0 584 | ) -> tuple[ 585 | torch.FloatTensor, 586 | dict[str, torch.FloatTensor], 587 | torch.FloatTensor, 588 | torch.FloatTensor, 589 | torch.FloatTensor, 590 | torch.FloatTensor, 591 | dict[str, torch.FloatTensor], 592 | ]: 593 | """Calculate the (modality) reconstruction loss, Kullback divergences and integration loss. 594 | 595 | Parameters 596 | ---------- 597 | tensors 598 | Tensor of values with shape ``(batch_size, n_input_features)``. 599 | inference_outputs 600 | Dictionary with the inference output. 601 | generative_outputs 602 | Dictionary with the generative output. 603 | kl_weight 604 | Weight of the KL loss. Default is 1.0. 605 | 606 | Returns 607 | ------- 608 | Reconstruction loss, Kullback divergences, integration loss and modality reconstruction losses. 609 | """ 610 | loss, recon_loss, kl_loss, extra_metrics = self._calculate_loss( 611 | tensors, inference_outputs, generative_outputs, kl_weight 612 | ) 613 | 614 | return LossOutput( 615 | loss=loss, 616 | reconstruction_loss=recon_loss, 617 | kl_local=kl_loss, 618 | extra_metrics=extra_metrics, 619 | ) 620 | 621 | # @torch.inference_mode() 622 | # def sample(self, tensors, n_samples=1): 623 | # """Sample from the generative model. 624 | 625 | # Parameters 626 | # ---------- 627 | # tensors 628 | # Tensor of values. 629 | # n_samples 630 | # Number of samples to generate. 631 | 632 | # Returns 633 | # ------- 634 | # Generative outputs. 
635 | # """ 636 | # inference_kwargs = {"n_samples": n_samples} 637 | # with torch.inference_mode(): 638 | # ( 639 | # _, 640 | # generative_outputs, 641 | # ) = self.forward( 642 | # tensors, 643 | # inference_kwargs=inference_kwargs, 644 | # compute_loss=False, 645 | # ) 646 | # return generative_outputs["rs"] 647 | 648 | def _calc_recon_loss(self, xs, rs, losses, group, size_factor, loss_coefs, masks): 649 | loss = [] 650 | for i, (x, r, loss_type) in enumerate(zip(xs, rs, losses, strict=False)): 651 | if len(r) != 2 and len(r.shape) == 3: 652 | r = r.squeeze() 653 | if loss_type == "mse": 654 | mse_loss = loss_coefs[str(i)] * torch.sum(nn.MSELoss(reduction="none")(r, x), dim=-1) 655 | loss.append(mse_loss) 656 | elif loss_type == "nb": 657 | dec_mean = r 658 | size_factor_view = size_factor.expand(dec_mean.size(0), dec_mean.size(1)) 659 | dec_mean = dec_mean * size_factor_view 660 | dispersion = self.theta.T[group.squeeze().long()] 661 | dispersion = torch.exp(dispersion) 662 | nb_loss = torch.sum(NegativeBinomial(mu=dec_mean, theta=dispersion).log_prob(x), dim=-1) 663 | nb_loss = loss_coefs[str(i)] * nb_loss 664 | loss.append(-nb_loss) 665 | elif loss_type == "zinb": 666 | dec_mean, dec_dropout = r 667 | dec_mean = dec_mean.squeeze() 668 | dec_dropout = dec_dropout.squeeze() 669 | size_factor_view = size_factor.unsqueeze(1).expand(dec_mean.size(0), dec_mean.size(1)) 670 | dec_mean = dec_mean * size_factor_view 671 | dispersion = self.theta.T[group.squeeze().long()] 672 | dispersion = torch.exp(dispersion) 673 | zinb_loss = torch.sum( 674 | ZeroInflatedNegativeBinomial(mu=dec_mean, theta=dispersion, zi_logits=dec_dropout).log_prob(x), 675 | dim=-1, 676 | ) 677 | zinb_loss = loss_coefs[str(i)] * zinb_loss 678 | loss.append(-zinb_loss) 679 | elif loss_type == "bce": 680 | bce_loss = loss_coefs[str(i)] * torch.sum(torch.nn.BCELoss(reduction="none")(r, x), dim=-1) 681 | loss.append(bce_loss) 682 | 683 | return ( 684 | torch.sum(torch.stack(loss, dim=-1) * torch.stack(masks, dim=-1), dim=1), 685 | torch.sum(torch.stack(loss, dim=-1) * torch.stack(masks, dim=-1), dim=0), 686 | ) 687 | 688 | def _calc_integ_loss(self, z, mu, logvar, group): 689 | loss = torch.tensor(0.0).to(self.device) 690 | unique = torch.unique(group) 691 | 692 | if len(unique) > 1: 693 | if self.modality_alignment == "MMD": 694 | zs = [z[group == i] for i in unique] 695 | for i in range(len(zs)): 696 | for j in range(i + 1, len(zs)): 697 | loss += MMD(kernel_type=self.kernel_type)(zs[i], zs[j]) 698 | 699 | elif self.modality_alignment == "Jeffreys": 700 | mus_joint = [mu[group == i] for i in unique] 701 | logvars_joint = [logvar[group == i] for i in unique] 702 | for i in range(len(mus_joint)): 703 | for j in range(i + 1, len(mus_joint)): 704 | if len(mus_joint[i]) == len(mus_joint[j]): 705 | loss += Jeffreys()( 706 | (mus_joint[i], torch.exp(logvars_joint[i])), (mus_joint[j], torch.exp(logvars_joint[j])) 707 | ) 708 | else: 709 | raise ValueError( 710 | f"mus_joint[i] and mus_joint[j] have different lengths: {len(mus_joint[i])} != {len(mus_joint[j])}" 711 | ) 712 | 713 | return loss 714 | 715 | def _compute_cont_cov_embeddings(self, covs): 716 | """Compute embeddings for continuous covariates. 717 | 718 | Adapted from 719 | Title: CPA (c) Facebook, Inc. 
720 | Date: 26.01.2022 721 | Link to the used code: 722 | https://github.com/facebookresearch/CPA/blob/382ff641c588820a453d801e5d0e5bb56642f282/compert/model.py#L342 723 | 724 | """ 725 | if self.cont_cov_type == "mlp": 726 | embeddings = [] 727 | for cov in range(covs.size(1)): 728 | this_cov = covs[:, cov].view(-1, 1) 729 | embeddings.append( 730 | self.cont_covariate_curves[cov](this_cov).sigmoid() 731 | ) # * this_drug.gt(0)) # TODO check what this .gt(0) is 732 | return torch.cat(embeddings, 1) @ self.cont_covariate_embeddings.weight 733 | else: 734 | return self.cont_covariate_curves(covs) @ self.cont_covariate_embeddings.weight 735 | 736 | def select_losses_to_plot(self): 737 | """Select losses to plot. 738 | 739 | Returns 740 | ------- 741 | Loss names. 742 | """ 743 | loss_names = ["kl_local", "elbo", "reconstruction_loss"] 744 | for i in range(self.n_modality): 745 | loss_names.append(f"modality_{i}_reconstruction_loss") 746 | if self.loss_coefs["integ"] != 0: 747 | loss_names.append("integ_loss") 748 | return loss_names 749 | -------------------------------------------------------------------------------- /src/multimil/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from ._base_components import MLP, Aggregator, Decoder, GeneralizedSigmoid 2 | 3 | __all__ = ["MLP", "Decoder", "GeneralizedSigmoid", "Aggregator"] 4 | -------------------------------------------------------------------------------- /src/multimil/nn/_base_components.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | import torch 4 | from scvi.nn import FCLayers 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | 9 | class MLP(nn.Module): 10 | """A helper class to build blocks of fully-connected, normalization, dropout and activation layers. 11 | 12 | Parameters 13 | ---------- 14 | n_input 15 | Number of input features. 16 | n_output 17 | Number of output features. 18 | n_layers 19 | Number of hidden layers. 20 | n_hidden 21 | Number of hidden units. 22 | dropout_rate 23 | Dropout rate. 24 | normalization 25 | Type of normalization to use. Can be one of ["layer", "batch", "none"]. 26 | activation 27 | Activation function to use. 28 | 29 | """ 30 | 31 | def __init__( 32 | self, 33 | n_input: int, 34 | n_output: int, 35 | n_layers: int = 1, 36 | n_hidden: int = 128, 37 | dropout_rate: float = 0.1, 38 | normalization: str = "layer", 39 | activation=nn.LeakyReLU, 40 | ): 41 | super().__init__() 42 | use_layer_norm = False 43 | use_batch_norm = True 44 | if normalization == "layer": 45 | use_layer_norm = True 46 | use_batch_norm = False 47 | elif normalization == "none": 48 | use_batch_norm = False 49 | 50 | self.mlp = FCLayers( 51 | n_in=n_input, 52 | n_out=n_output, 53 | n_layers=n_layers, 54 | n_hidden=n_hidden, 55 | dropout_rate=dropout_rate, 56 | use_layer_norm=use_layer_norm, 57 | use_batch_norm=use_batch_norm, 58 | activation_fn=activation, 59 | ) 60 | 61 | def forward(self, x: torch.Tensor) -> torch.Tensor: 62 | """Forward computation on ``x``. 63 | 64 | Parameters 65 | ---------- 66 | x 67 | Tensor of values with shape ``(n_input,)``. 68 | 69 | Returns 70 | ------- 71 | Tensor of values with shape ``(n_output,)``. 72 | """ 73 | return self.mlp(x) 74 | 75 | 76 | class Decoder(nn.Module): 77 | """A helper class to build custom decoders depending on which loss was passed. 78 | 79 | Parameters 80 | ---------- 81 | n_input 82 | Number of input features. 
83 | n_output 84 | Number of output features. 85 | n_layers 86 | Number of hidden layers. 87 | n_hidden 88 | Number of hidden units. 89 | dropout_rate 90 | Dropout rate. 91 | normalization 92 | Type of normalization to use. Can be one of ["layer", "batch", "none"]. 93 | activation 94 | Activation function to use. 95 | loss 96 | Loss function to use. Can be one of ["mse", "nb", "zinb", "bce"]. 97 | """ 98 | 99 | def __init__( 100 | self, 101 | n_input: int, 102 | n_output: int, 103 | n_layers: int = 1, 104 | n_hidden: int = 128, 105 | dropout_rate: float = 0.1, 106 | normalization: str = "layer", 107 | activation=nn.LeakyReLU, 108 | loss="mse", 109 | ): 110 | super().__init__() 111 | 112 | if loss not in ["mse", "nb", "zinb", "bce"]: 113 | raise NotImplementedError(f"Loss function {loss} is not implemented.") 114 | else: 115 | self.loss = loss 116 | 117 | self.decoder = MLP( 118 | n_input=n_input, 119 | n_output=n_hidden, 120 | n_layers=n_layers, 121 | n_hidden=n_hidden, 122 | dropout_rate=dropout_rate, 123 | normalization=normalization, 124 | activation=activation, 125 | ) 126 | if loss == "mse": 127 | self.recon_decoder = nn.Linear(n_hidden, n_output) 128 | elif loss == "nb": 129 | self.mean_decoder = nn.Sequential(nn.Linear(n_hidden, n_output), nn.Softmax(dim=-1)) 130 | 131 | elif loss == "zinb": 132 | self.mean_decoder = nn.Sequential(nn.Linear(n_hidden, n_output), nn.Softmax(dim=-1)) 133 | self.dropout_decoder = nn.Linear(n_hidden, n_output) 134 | 135 | elif loss == "bce": 136 | self.recon_decoder = FCLayers( 137 | n_in=n_hidden, 138 | n_out=n_output, 139 | n_layers=0, 140 | dropout_rate=0, 141 | use_layer_norm=False, 142 | use_batch_norm=False, 143 | activation_fn=nn.Sigmoid, 144 | ) 145 | 146 | def forward(self, x: torch.Tensor) -> torch.Tensor: 147 | """Forward computation on ``x``. 148 | 149 | Parameters 150 | ---------- 151 | x 152 | Tensor of values with shape ``(n_input,)``. 153 | 154 | Returns 155 | ------- 156 | Tensor of values with shape ``(n_output,)``. 157 | """ 158 | x = self.decoder(x) 159 | if self.loss == "mse" or self.loss == "bce": 160 | return self.recon_decoder(x) 161 | elif self.loss == "nb": 162 | return self.mean_decoder(x) 163 | elif self.loss == "zinb": 164 | return self.mean_decoder(x), self.dropout_decoder(x) 165 | 166 | 167 | class GeneralizedSigmoid(nn.Module): 168 | """Sigmoid, log-sigmoid or linear functions for encoding continuous covariates. 169 | 170 | Adapted from 171 | Title: CPA (c) Facebook, Inc. 172 | Date: 26.01.2022 173 | Link to the used code: 174 | https://github.com/facebookresearch/CPA/blob/382ff641c588820a453d801e5d0e5bb56642f282/compert/model.py#L109 175 | 176 | Parameters 177 | ---------- 178 | dim 179 | Number of input features. 180 | nonlin 181 | Type of non-linearity to use. Can be one of ["logsigm", "sigm"]. Default is "logsigm". 182 | """ 183 | 184 | def __init__(self, dim, nonlin: Literal["logsigm", "sigm"] | None = "logsigm"): 185 | super().__init__() 186 | self.nonlin = nonlin 187 | self.beta = torch.nn.Parameter(torch.ones(1, dim), requires_grad=True) 188 | self.bias = torch.nn.Parameter(torch.zeros(1, dim), requires_grad=True) 189 | 190 | def forward(self, x) -> torch.Tensor: 191 | """Forward computation on ``x``. 192 | 193 | Parameters 194 | ---------- 195 | x 196 | Tensor of values. 197 | 198 | Returns 199 | ------- 200 | Tensor of values with the same shape as ``x``. 
201 | """ 202 | if self.nonlin == "logsigm": 203 | return (torch.log1p(x) * self.beta + self.bias).sigmoid() 204 | elif self.nonlin == "sigm": 205 | return (x * self.beta + self.bias).sigmoid() 206 | else: 207 | return x 208 | 209 | 210 | class Aggregator(nn.Module): 211 | """A helper class to build custom aggregators depending on the scoring function passed. 212 | 213 | Parameters 214 | ---------- 215 | n_input 216 | Number of input features. 217 | scoring 218 | Scoring function to use. Can be one of ["attn", "gated_attn", "mean", "max", "sum", "mlp"]. 219 | attn_dim 220 | Dimension of the hidden attention layer. 221 | sample_batch_size 222 | Bag batch size. 223 | scale 224 | Whether to scale the attention weights. 225 | dropout 226 | Dropout rate. 227 | n_layers_mlp_attn 228 | Number of hidden layers in the MLP attention. 229 | n_hidden_mlp_attn 230 | Number of hidden units in the MLP attention. 231 | activation 232 | Activation function to use. 233 | """ 234 | 235 | def __init__( 236 | self, 237 | n_input=None, 238 | scoring="gated_attn", 239 | attn_dim=16, # D 240 | sample_batch_size=None, 241 | scale=False, 242 | dropout=0.2, 243 | n_layers_mlp_attn=1, 244 | n_hidden_mlp_attn=16, 245 | activation=nn.LeakyReLU, 246 | ): 247 | super().__init__() 248 | 249 | self.scoring = scoring 250 | self.patient_batch_size = sample_batch_size 251 | self.scale = scale 252 | 253 | if self.scoring == "attn": 254 | self.attn_dim = attn_dim # attn dim from https://arxiv.org/pdf/1802.04712.pdf 255 | self.attention = nn.Sequential( 256 | nn.Linear(n_input, self.attn_dim), 257 | nn.Tanh(), 258 | nn.Linear(self.attn_dim, 1, bias=False), 259 | ) 260 | elif self.scoring == "gated_attn": 261 | self.attn_dim = attn_dim 262 | self.attention_V = nn.Sequential( 263 | nn.Linear(n_input, self.attn_dim), 264 | nn.Tanh(), 265 | ) 266 | 267 | self.attention_U = nn.Sequential( 268 | nn.Linear(n_input, self.attn_dim), 269 | nn.Sigmoid(), 270 | ) 271 | 272 | self.attention_weights = nn.Linear(self.attn_dim, 1, bias=False) 273 | 274 | elif self.scoring == "mlp": 275 | if n_layers_mlp_attn == 1: 276 | self.attention = nn.Linear(n_input, 1) 277 | else: 278 | self.attention = nn.Sequential( 279 | MLP( 280 | n_input, 281 | n_hidden_mlp_attn, 282 | n_layers=n_layers_mlp_attn - 1, 283 | n_hidden=n_hidden_mlp_attn, 284 | dropout_rate=dropout, 285 | activation=activation, 286 | ), 287 | nn.Linear(n_hidden_mlp_attn, 1), 288 | ) 289 | 290 | def forward(self, x) -> torch.Tensor: 291 | """Forward computation on ``x``. 292 | 293 | Parameters 294 | ---------- 295 | x 296 | Tensor of values with shape ``(n_input,)``. 297 | 298 | Returns 299 | ------- 300 | Tensor of pooled values. 
301 | """ 302 | # # TODO add mean and max pooling 303 | # if self.scoring == "attn": 304 | # # from https://github.com/AMLab-Amsterdam/AttentionDeepMIL/blob/master/model.py (accessed 16.09.2021) 305 | # self.A = self.attention(x) # Nx1 306 | # self.A = torch.transpose(self.A, -1, -2) # 1xN 307 | # self.A = F.softmax(self.A, dim=-1) # softmax over N 308 | 309 | # elif self.scoring == "gated_attn": 310 | # # from https://github.com/AMLab-Amsterdam/AttentionDeepMIL/blob/master/model.py (accessed 16.09.2021) 311 | # A_V = self.attention_V(x) # NxD 312 | # A_U = self.attention_U(x) # NxD 313 | # self.A = self.attention_weights(A_V * A_U) # element wise multiplication # Nx1 314 | # self.A = torch.transpose(self.A, -1, -2) # 1xN 315 | # self.A = F.softmax(self.A, dim=-1) # softmax over N 316 | 317 | # elif self.scoring == "mlp": 318 | # self.A = self.attention(x) # N 319 | # self.A = torch.transpose(self.A, -1, -2) 320 | # self.A = F.softmax(self.A, dim=-1) 321 | 322 | # else: 323 | # raise NotImplementedError( 324 | # f'scoring = {self.scoring} is not implemented. Has to be one of ["attn", "gated_attn", "mlp"].' 325 | # ) 326 | 327 | # if self.scale: 328 | # self.A = self.A * self.A.shape[-1] / self.patient_batch_size 329 | 330 | # return torch.bmm(self.A, x).squeeze(dim=1) 331 | # Apply different pooling strategies based on the scoring method 332 | if self.scoring in ["attn", "gated_attn", "mlp", "sum", "mean", "max"]: 333 | if self.scoring == "attn": 334 | # from https://github.com/AMLab-Amsterdam/AttentionDeepMIL/blob/master/model.py (accessed 16.09.2021) 335 | A = self.attention(x) # (batch_size, N, 1) 336 | A = A.transpose(1, 2) # (batch_size, 1, N) 337 | A = F.softmax(A, dim=-1) 338 | elif self.scoring == "gated_attn": 339 | # from https://github.com/AMLab-Amsterdam/AttentionDeepMIL/blob/master/model.py (accessed 16.09.2021) 340 | A_V = self.attention_V(x) # (batch_size, N, attn_dim) 341 | A_U = self.attention_U(x) # (batch_size, N, attn_dim) 342 | A = self.attention_weights(A_V * A_U) # (batch_size, N, 1) 343 | A = A.transpose(1, 2) # (batch_size, 1, N) 344 | A = F.softmax(A, dim=-1) 345 | elif self.scoring == "mlp": 346 | A = self.attention(x) # (batch_size, N, 1) 347 | A = A.transpose(1, 2) # (batch_size, 1, N) 348 | A = F.softmax(A, dim=-1) 349 | 350 | elif self.scoring == "sum": 351 | return torch.sum(x, dim=1) # (batch_size, n_input) 352 | elif self.scoring == "mean": 353 | return torch.mean(x, dim=1) # (batch_size, n_input) 354 | elif self.scoring == "max": 355 | return torch.max(x, dim=1).values # (batch_size, n_input) 356 | else: 357 | raise NotImplementedError( 358 | f'scoring = {self.scoring} is not implemented. Has to be one of ["attn", "gated_attn", "mlp", "sum", "mean", "max"].' 
359 | ) 360 | if self.scale: 361 | if self.patient_batch_size is None: 362 | raise ValueError("patient_batch_size must be set when scale is True.") 363 | A = A * A.shape[-1] / self.patient_batch_size 364 | 365 | pooled = torch.bmm(A, x).squeeze(dim=1) # (batch_size, n_input) 366 | return pooled 367 | -------------------------------------------------------------------------------- /src/multimil/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from ._utils import ( 2 | calculate_size_factor, 3 | create_df, 4 | get_bag_info, 5 | get_predictions, 6 | plt_plot_losses, 7 | prep_minibatch, 8 | save_predictions_in_adata, 9 | select_covariates, 10 | setup_ordinal_regression, 11 | ) 12 | 13 | __all__ = [ 14 | "create_df", 15 | "calculate_size_factor", 16 | "setup_ordinal_regression", 17 | "select_covariates", 18 | "prep_minibatch", 19 | "get_predictions", 20 | "get_bag_info", 21 | "save_predictions_in_adata", 22 | "plt_plot_losses", 23 | ] 24 | -------------------------------------------------------------------------------- /src/multimil/utils/_utils.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import scipy 6 | import torch 7 | from matplotlib import pyplot as plt 8 | 9 | 10 | def create_df(pred, columns=None, index=None) -> pd.DataFrame: 11 | """Create a pandas DataFrame from a list of predictions. 12 | 13 | Parameters 14 | ---------- 15 | pred 16 | List of predictions. 17 | columns 18 | Column names, i.e. class_names. 19 | index 20 | Index names, i.e. obs_names. 21 | 22 | Returns 23 | ------- 24 | DataFrame with predictions. 25 | """ 26 | if isinstance(pred, dict): 27 | for key in pred.keys(): 28 | pred[key] = torch.cat(pred[key]).squeeze().cpu().numpy() 29 | else: 30 | pred = torch.cat(pred).squeeze().cpu().numpy() 31 | 32 | df = pd.DataFrame(pred) 33 | if index is not None: 34 | df.index = index 35 | if columns is not None: 36 | df.columns = columns 37 | return df 38 | 39 | 40 | def calculate_size_factor(adata, size_factor_key, rna_indices_end) -> str: 41 | """Calculate size factors. 42 | 43 | Parameters 44 | ---------- 45 | adata 46 | Annotated data object. 47 | size_factor_key 48 | Key in `adata.obs` where size factors are stored. 49 | rna_indices_end 50 | Index of the last RNA feature in the data. 51 | 52 | Returns 53 | ------- 54 | Size factor key. 55 | """ 56 | # TODO check that organize_multimodal_anndatas was run, i.e. that .uns['modality_lengths'] was added, needed for q2r 57 | if size_factor_key is not None and rna_indices_end is not None: 58 | raise ValueError( 59 | "Only one of [`size_factor_key`, `rna_indices_end`] can be specified, but both were provided." 
60 | ) 61 | # TODO change to when both are None and data in unimodal, use all input features to calculate the size factors, add warning 62 | if size_factor_key is None and rna_indices_end is None: 63 | raise ValueError("One of [`size_factor_key`, `rna_indices_end`] has to be specified, but both are `None`.") 64 | 65 | if size_factor_key is not None: 66 | return size_factor_key 67 | if rna_indices_end is not None: 68 | adata_rna = adata[:, :rna_indices_end].copy() 69 | if scipy.sparse.issparse(adata.X): 70 | adata.obs.loc[:, "size_factors"] = adata_rna.X.toarray().sum(1).T.tolist() 71 | else: 72 | adata.obs.loc[:, "size_factors"] = adata_rna.X.sum(1).T.tolist() 73 | return "size_factors" 74 | 75 | 76 | def setup_ordinal_regression(adata, ordinal_regression_order, categorical_covariate_keys): 77 | """Setup ordinal regression. 78 | 79 | Parameters 80 | ---------- 81 | adata 82 | Annotated data object. 83 | ordinal_regression_order 84 | Order of categories for ordinal regression. 85 | categorical_covariate_keys 86 | Keys of categorical covariates. 87 | """ 88 | # TODO make sure not to assume categorical columns for ordinal regression -> change to np.unique if needed 89 | if ordinal_regression_order is not None: 90 | if not set(ordinal_regression_order.keys()).issubset(categorical_covariate_keys): 91 | raise ValueError( 92 | f"All keys {ordinal_regression_order.keys()} has to be registered as categorical covariates too, but categorical_covariate_keys = {categorical_covariate_keys}" 93 | ) 94 | for key in ordinal_regression_order.keys(): 95 | adata.obs[key] = adata.obs[key].astype("category") 96 | if set(adata.obs[key].cat.categories) != set(ordinal_regression_order[key]): 97 | raise ValueError( 98 | f"Categories of adata.obs[{key}]={adata.obs[key].cat.categories} are not the same as categories specified = {ordinal_regression_order[key]}" 99 | ) 100 | adata.obs[key] = adata.obs[key].cat.reorder_categories(ordinal_regression_order[key]) 101 | 102 | 103 | def select_covariates(covs, prediction_idx, n_samples_in_batch) -> torch.Tensor: 104 | """Select prediction covariates from all covariates. 105 | 106 | Parameters 107 | ---------- 108 | covs 109 | Covariates. 110 | prediction_idx 111 | Index of predictions. 112 | n_samples_in_batch 113 | Number of samples in the batch. 114 | 115 | Returns 116 | ------- 117 | Prediction covariates. 118 | """ 119 | if len(prediction_idx) > 0: 120 | covs = torch.index_select(covs, 1, prediction_idx) 121 | covs = covs.view(n_samples_in_batch, -1, len(prediction_idx))[:, 0, :] 122 | else: 123 | covs = torch.tensor([]) 124 | return covs 125 | 126 | 127 | def prep_minibatch(covs, sample_batch_size) -> tuple[int, int]: 128 | """Prepare minibatch. 129 | 130 | Parameters 131 | ---------- 132 | covs 133 | Covariates. 134 | sample_batch_size 135 | Sample batch size. 136 | 137 | Returns 138 | ------- 139 | Batch size and number of samples in the batch. 140 | """ 141 | batch_size = covs.shape[0] 142 | 143 | if batch_size % sample_batch_size != 0: 144 | n_samples_in_batch = 1 145 | else: 146 | n_samples_in_batch = batch_size // sample_batch_size 147 | return batch_size, n_samples_in_batch 148 | 149 | 150 | def get_predictions( 151 | prediction_idx, pred_values, true_values, size, bag_pred, bag_true, full_pred, offset=0 152 | ) -> tuple[dict, dict, dict]: 153 | """Get predictions. 154 | 155 | Parameters 156 | ---------- 157 | prediction_idx 158 | Index of predictions. 159 | pred_values 160 | Predicted values. 161 | true_values 162 | True values. 
163 | size 164 | Size of the bag minibatch. 165 | bag_pred 166 | Bag predictions. 167 | bag_true 168 | Bag true values. 169 | full_pred 170 | Full predictions, i.e. on cell-level. 171 | offset 172 | Offset, needed because of several possible types of predictions. 173 | 174 | Returns 175 | ------- 176 | Bag predictions, bag true values, full predictions on cell-level. 177 | """ 178 | for i in range(len(prediction_idx)): 179 | bag_pred[i] = bag_pred.get(i, []) + [pred_values[offset + i].cpu()] 180 | bag_true[i] = bag_true.get(i, []) + [true_values[:, i].cpu()] 181 | # TODO in ord reg had pred[len(self.mil.class_idx) + i].repeat(1, size).flatten() 182 | # in reg had 183 | # cell level, i.e. prediction for the cell = prediction for the bag 184 | full_pred[i] = full_pred.get(i, []) + [pred_values[offset + i].unsqueeze(1).repeat(1, size, 1).flatten(0, 1)] 185 | return bag_pred, bag_true, full_pred 186 | 187 | 188 | def get_bag_info(bags, n_samples_in_batch, minibatch_size, cell_counter, bag_counter, sample_batch_size): 189 | """Get bag information. 190 | 191 | Parameters 192 | ---------- 193 | bags 194 | Bags. 195 | n_samples_in_batch 196 | Number of samples in the batch. 197 | minibatch_size 198 | Minibatch size. 199 | cell_counter 200 | Cell counter. 201 | bag_counter 202 | Bag counter. 203 | sample_batch_size 204 | Sample batch size. 205 | 206 | Returns 207 | ------- 208 | Updated bags, cell counter, and bag counter. 209 | """ 210 | if n_samples_in_batch == 1: 211 | bags += [[bag_counter] * minibatch_size] 212 | cell_counter += minibatch_size 213 | bag_counter += 1 214 | else: 215 | bags += [[bag_counter + i] * sample_batch_size for i in range(n_samples_in_batch)] 216 | bag_counter += n_samples_in_batch 217 | cell_counter += sample_batch_size * n_samples_in_batch 218 | return bags, cell_counter, bag_counter 219 | 220 | 221 | def save_predictions_in_adata( 222 | adata, idx, predictions, bag_pred, bag_true, cell_pred, class_names, name, clip, reg=False 223 | ): 224 | """Save predictions in anndata object. 225 | 226 | Parameters 227 | ---------- 228 | adata 229 | Annotated data object. 230 | idx 231 | Index, i.e. obs_names. 232 | predictions 233 | Predictions. 234 | bag_pred 235 | Bag predictions. 236 | bag_true 237 | Bag true values. 238 | cell_pred 239 | Cell predictions. 240 | class_names 241 | Class names. 242 | name 243 | Name of the prediction column. 244 | clip 245 | Whether to transform the predictions. One of `clip`, `argmax`, or `none`. 246 | reg 247 | Whether the prediction task is a regression task. 
248 | """ 249 | # cell level predictions) 250 | 251 | if clip == "clip": # ord regression 252 | df = create_df(cell_pred[idx], [name], index=adata.obs_names) 253 | adata.obsm[f"full_predictions_{name}"] = df 254 | adata.obs[f"predicted_{name}"] = np.clip(np.round(df.to_numpy()), a_min=0.0, a_max=len(class_names) - 1.0) 255 | elif clip == "argmax": # classification 256 | df = create_df(cell_pred[idx], class_names, index=adata.obs_names) 257 | adata.obsm[f"full_predictions_{name}"] = df 258 | adata.obs[f"predicted_{name}"] = df.to_numpy().argmax(axis=1) 259 | else: # regression 260 | df = create_df(cell_pred[idx], [name], index=adata.obs_names) 261 | adata.obsm[f"full_predictions_{name}"] = df 262 | adata.obs[f"predicted_{name}"] = df.to_numpy() 263 | if reg is False: 264 | adata.obs[f"predicted_{name}"] = adata.obs[f"predicted_{name}"].astype("category") 265 | adata.obs[f"predicted_{name}"] = adata.obs[f"predicted_{name}"].cat.rename_categories( 266 | dict(enumerate(class_names)) 267 | ) 268 | 269 | # bag level predictions 270 | adata.uns[f"bag_true_{name}"] = create_df(bag_true, predictions) 271 | if clip == "clip": # ordinal regression 272 | df_bag = create_df(bag_pred[idx], [name]) 273 | adata.uns[f"bag_full_predictions_{name}"] = np.clip( 274 | np.round(df_bag.to_numpy()), a_min=0.0, a_max=len(class_names) - 1.0 275 | ) 276 | elif clip == "argmax": # classification 277 | df_bag = create_df(bag_pred[idx], class_names) 278 | adata.uns[f"bag_full_predictions_{name}"] = df_bag.to_numpy().argmax(axis=1) 279 | else: # regression 280 | df_bag = create_df(bag_pred[idx], [name]) 281 | adata.uns[f"bag_full_predictions_{name}"] = df_bag.to_numpy() 282 | 283 | 284 | def plt_plot_losses(history, loss_names, save): 285 | """Plot losses. 286 | 287 | Parameters 288 | ---------- 289 | history 290 | History of losses. 291 | loss_names 292 | Loss names to plot. 293 | save 294 | Path to save the plot. 295 | """ 296 | df = pd.concat(history, axis=1) 297 | df.columns = df.columns.droplevel(-1) 298 | df["epoch"] = df.index 299 | 300 | nrows = ceil(len(loss_names) / 2) 301 | 302 | plt.figure(figsize=(15, 5 * nrows)) 303 | 304 | for i, name in enumerate(loss_names): 305 | plt.subplot(nrows, 2, i + 1) 306 | plt.plot(df["epoch"], df[name + "_train"], ".-", label=name + "_train") 307 | plt.plot(df["epoch"], df[name + "_validation"], ".-", label=name + "_validation") 308 | plt.xlabel("epoch") 309 | plt.legend() 310 | if save is not None: 311 | plt.savefig(save, bbox_inches="tight") 312 | -------------------------------------------------------------------------------- /tests/test_basic.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import multimil 4 | 5 | 6 | def test_package_has_version(): 7 | assert multimil.__version__ is not None 8 | 9 | 10 | @pytest.mark.skip(reason="This decorator should be removed when test passes.") 11 | def test_example(): 12 | assert 1 == 0 # This test is designed to fail. 13 | --------------------------------------------------------------------------------
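As a quick orientation to the building blocks listed above, the sketch below shows how the exported `MLP` and `Aggregator` classes from `multimil.nn` compose into the attention-based pooling step that the MIL module builds on. It is an illustrative example only: the import path and constructor arguments come from `src/multimil/nn/__init__.py` and `_base_components.py`, while the tensor sizes (8 bags of 30 cells with 20 input features) and hyperparameter values are assumptions made for the demonstration, not values used by the package.

```python
import torch

from multimil.nn import MLP, Aggregator

n_bags, n_cells, n_features, latent_dim = 8, 30, 20, 10  # illustrative sizes

# Per-cell encoder: flatten bags to cells, embed, then restore the bag axis.
mlp = MLP(n_input=n_features, n_output=latent_dim, n_layers=2, n_hidden=64)
cells = torch.randn(n_bags, n_cells, n_features)
embedded = mlp(cells.reshape(-1, n_features)).reshape(n_bags, n_cells, latent_dim)

# Gated-attention pooling over the cell axis (arXiv:1802.04712, referenced in
# _base_components.py): attention weights are softmax-normalized within each bag.
aggregator = Aggregator(n_input=latent_dim, scoring="gated_attn", attn_dim=16)
bag_embedding = aggregator(embedded)  # shape: (n_bags, latent_dim)
```

With `scoring="mean"`, `"max"`, or `"sum"` the same call reduces the cell axis without learned attention weights, which can serve as a simple baseline when attention is not needed.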