├── .codecov.yaml
├── .cruft.json
├── .editorconfig
├── .github
├── ISSUE_TEMPLATE
│ ├── bug_report.yml
│ ├── config.yml
│ └── feature_request.yml
└── workflows
│ ├── build.yaml
│ ├── release.yaml
│ └── test.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── CHANGELOG.md
├── LICENSE
├── README.md
├── docs
├── Makefile
├── _static
│ ├── .gitkeep
│ ├── css
│ │ └── custom.css
│ └── images
│ │ └── equation_schematic.png
├── _templates
│ ├── .gitkeep
│ └── autosummary
│ │ └── class.rst
├── api.md
├── changelog.md
├── conf.py
├── extensions
│ └── typed_returns.py
├── index.md
├── notebooks
│ └── Tutorial.myst
├── references.bib
└── references.md
├── notebooks
├── check_implementation.qmd
├── devel_experiments.qmd
├── kang_analysis_in_R.qmd
├── lemur_model_experiments.qmd
├── scanpy_experiments.qmd
└── test.ipynb
├── pyproject.toml
├── src
└── pylemur
│ ├── __init__.py
│ ├── pl
│ ├── __init__.py
│ └── basic.py
│ ├── pp
│ ├── __init__.py
│ └── basic.py
│ └── tl
│ ├── __init__.py
│ ├── _design_matrix_utils.py
│ ├── _grassmann.py
│ ├── _grassmann_lm.py
│ ├── _lin_alg_wrappers.py
│ ├── alignment.py
│ └── lemur.py
└── tests
├── test_grasmann_lm.py
├── test_grassmann.py
├── test_lemur.py
└── test_lin_alg_utils.py
/.codecov.yaml:
--------------------------------------------------------------------------------
1 | # Based on pydata/xarray
2 | codecov:
3 | require_ci_to_pass: no
4 |
5 | coverage:
6 | status:
7 | project:
8 | default:
9 | # Require 1% coverage, i.e., always succeed
10 | target: 1
11 | patch: false
12 | changes: false
13 |
14 | comment:
15 | layout: diff, flags, files
16 | behavior: once
17 | require_base: no
18 |
--------------------------------------------------------------------------------
/.cruft.json:
--------------------------------------------------------------------------------
1 | {
2 | "template": "https://github.com/scverse/cookiecutter-scverse",
3 | "commit": "87a407a65408d75a949c0b54b19fd287475a56f8",
4 | "checkout": "v0.4.0",
5 | "context": {
6 | "cookiecutter": {
7 | "project_name": "pyLemur",
8 | "package_name": "pylemur",
9 | "project_description": "A Python implementation of the LEMUR algorithm for analyzing multi-condition single-cell RNA-seq data.",
10 | "author_full_name": "Your Name",
11 | "author_email": "artjom31415@googlemail.com",
12 | "github_user": "const-ae",
13 | "project_repo": "https://github.com/const-ae/pyLemur",
14 | "license": "MIT License",
15 | "_copy_without_render": [
16 | ".github/workflows/build.yaml",
17 | ".github/workflows/test.yaml",
18 | "docs/_templates/autosummary/**.rst"
19 | ],
20 | "_render_devdocs": false,
21 | "_jinja2_env_vars": {
22 | "lstrip_blocks": true,
23 | "trim_blocks": true
24 | },
25 | "_template": "https://github.com/scverse/cookiecutter-scverse"
26 | }
27 | },
28 | "directory": null
29 | }
30 |
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | root = true
2 |
3 | [*]
4 | indent_style = space
5 | indent_size = 4
6 | end_of_line = lf
7 | charset = utf-8
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 |
11 | [*.{yml,yaml}]
12 | indent_size = 2
13 |
14 | [.cruft.json]
15 | indent_size = 2
16 |
17 | [Makefile]
18 | indent_style = tab
19 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yml:
--------------------------------------------------------------------------------
1 | name: Bug report
2 | description: Report something that is broken or incorrect
3 | labels: bug
4 | body:
5 | - type: markdown
6 | attributes:
7 | value: |
8 | **Note**: Please read [this guide](https://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports)
9 | detailing how to provide the necessary information for us to reproduce your bug. In brief:
10 | * Please provide exact steps how to reproduce the bug in a clean Python environment.
11 | * In case it's not clear what's causing this bug, please provide the data or the data generation procedure.
12 | * Sometimes it is not possible to share the data, but usually it is possible to replicate problems on publicly
13 | available datasets or to share a subset of your data.
14 |
15 | - type: textarea
16 | id: report
17 | attributes:
18 | label: Report
19 | description: A clear and concise description of what the bug is.
20 | validations:
21 | required: true
22 |
23 | - type: textarea
24 | id: versions
25 | attributes:
26 | label: Version information
27 | description: |
28 | Please paste below the output of
29 |
30 | ```python
31 | import session_info
32 | session_info.show(html=False, dependencies=True)
33 | ```
34 | placeholder: |
35 | -----
36 | anndata 0.8.0rc2.dev27+ge524389
37 | session_info 1.0.0
38 | -----
39 | asttokens NA
40 | awkward 1.8.0
41 | backcall 0.2.0
42 | cython_runtime NA
43 | dateutil 2.8.2
44 | debugpy 1.6.0
45 | decorator 5.1.1
46 | entrypoints 0.4
47 | executing 0.8.3
48 | h5py 3.7.0
49 | ipykernel 6.15.0
50 | jedi 0.18.1
51 | mpl_toolkits NA
52 | natsort 8.1.0
53 | numpy 1.22.4
54 | packaging 21.3
55 | pandas 1.4.2
56 | parso 0.8.3
57 | pexpect 4.8.0
58 | pickleshare 0.7.5
59 | pkg_resources NA
60 | prompt_toolkit 3.0.29
61 | psutil 5.9.1
62 | ptyprocess 0.7.0
63 | pure_eval 0.2.2
64 | pydev_ipython NA
65 | pydevconsole NA
66 | pydevd 2.8.0
67 | pydevd_file_utils NA
68 | pydevd_plugins NA
69 | pydevd_tracing NA
70 | pygments 2.12.0
71 | pytz 2022.1
72 | scipy 1.8.1
73 | setuptools 62.5.0
74 | setuptools_scm NA
75 | six 1.16.0
76 | stack_data 0.3.0
77 | tornado 6.1
78 | traitlets 5.3.0
79 | wcwidth 0.2.5
80 | zmq 23.1.0
81 | -----
82 | IPython 8.4.0
83 | jupyter_client 7.3.4
84 | jupyter_core 4.10.0
85 | -----
86 | Python 3.9.13 | packaged by conda-forge | (main, May 27 2022, 16:58:50) [GCC 10.3.0]
87 | Linux-5.18.6-arch1-1-x86_64-with-glibc2.35
88 | -----
89 | Session information updated at 2022-07-07 17:55
90 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 | - name: Scverse Community Forum
4 | url: https://discourse.scverse.org/
5 | about: If you have questions about “How to do X”, please ask them here.
6 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.yml:
--------------------------------------------------------------------------------
1 | name: Feature request
2 | description: Propose a new feature for pyLemur
3 | labels: enhancement
4 | body:
5 | - type: textarea
6 | id: description
7 | attributes:
8 | label: Description of feature
9 | description: Please describe your suggestion for a new feature. It might help to describe a problem or use case, plus any alternatives that you have considered.
10 | validations:
11 | required: true
12 |
--------------------------------------------------------------------------------
/.github/workflows/build.yaml:
--------------------------------------------------------------------------------
1 | name: Check Build
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: [main]
8 |
9 | concurrency:
10 | group: ${{ github.workflow }}-${{ github.ref }}
11 | cancel-in-progress: true
12 |
13 | jobs:
14 | package:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: actions/checkout@v3
18 | - name: Set up Python 3.10
19 | uses: actions/setup-python@v4
20 | with:
21 | python-version: "3.10"
22 | cache: "pip"
23 | cache-dependency-path: "**/pyproject.toml"
24 | - name: Install build dependencies
25 | run: python -m pip install --upgrade pip wheel twine build
26 | - name: Build package
27 | run: python -m build
28 | - name: Check package
29 | run: twine check --strict dist/*.whl
30 |
--------------------------------------------------------------------------------
/.github/workflows/release.yaml:
--------------------------------------------------------------------------------
1 | name: Release
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | # Use "trusted publishing", see https://docs.pypi.org/trusted-publishers/
8 | jobs:
9 | release:
10 | name: Upload release to PyPI
11 | runs-on: ubuntu-latest
12 | environment:
13 | name: pypi
14 | url: https://pypi.org/p/pylemur
15 | permissions:
16 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
17 | steps:
18 | - uses: actions/checkout@v4
19 | with:
20 | filter: blob:none
21 | fetch-depth: 0
22 | - uses: actions/setup-python@v4
23 | with:
24 | python-version: "3.x"
25 | cache: "pip"
26 | - run: pip install build
27 | - run: python -m build
28 | - name: Publish package distributions to PyPI
29 | uses: pypa/gh-action-pypi-publish@release/v1
30 |
--------------------------------------------------------------------------------
/.github/workflows/test.yaml:
--------------------------------------------------------------------------------
1 | name: Test
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: [main]
8 | schedule:
9 | - cron: "0 5 1,15 * *"
10 |
11 | concurrency:
12 | group: ${{ github.workflow }}-${{ github.ref }}
13 | cancel-in-progress: true
14 |
15 | jobs:
16 | test:
17 | runs-on: ${{ matrix.os }}
18 | defaults:
19 | run:
20 | shell: bash -e {0} # -e to fail on error
21 |
22 | strategy:
23 | fail-fast: false
24 | matrix:
25 | include:
26 | - os: ubuntu-latest
27 | python: "3.10"
28 | - os: ubuntu-latest
29 | python: "3.12"
30 | - os: ubuntu-latest
31 | python: "3.12"
32 | pip-flags: "--pre"
33 | name: PRE-RELEASE DEPENDENCIES
34 |
35 | name: ${{ matrix.name }} Python ${{ matrix.python }}
36 |
37 | env:
38 | OS: ${{ matrix.os }}
39 | PYTHON: ${{ matrix.python }}
40 |
41 | steps:
42 | - uses: actions/checkout@v3
43 | - name: Set up Python ${{ matrix.python }}
44 | uses: actions/setup-python@v4
45 | with:
46 | python-version: ${{ matrix.python }}
47 | cache: "pip"
48 | cache-dependency-path: "**/pyproject.toml"
49 |
50 | - name: Install test dependencies
51 | run: |
52 | python -m pip install --upgrade pip wheel
53 | - name: Install dependencies
54 | run: |
55 | pip install ${{ matrix.pip-flags }} ".[dev,test]"
56 | - name: Test
57 | env:
58 | MPLBACKEND: agg
59 | PLATFORM: ${{ matrix.os }}
60 | DISPLAY: :42
61 | run: |
62 | coverage run -m pytest -v --color=yes
63 | - name: Report coverage
64 | run: |
65 | coverage report
66 | - name: Upload coverage
67 | uses: codecov/codecov-action@v3
68 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by pytest automatically.
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | bin/
11 | build/
12 | develop-eggs/
13 | dist/
14 | eggs/
15 | lib/
16 | lib64/
17 | parts/
18 | sdist/
19 | var/
20 | *.egg-info/
21 | .installed.cfg
22 | *.egg
23 |
24 | # Installer logs
25 | pip-log.txt
26 | pip-delete-this-directory.txt
27 |
28 | # Unit test / coverage reports
29 | .tox/
30 | .coverage
31 | .cache
32 | nosetests.xml
33 | coverage.xml
34 |
35 | .pytest_cache/
36 |
37 | # Temp files
38 | .DS_Store
39 | *~
40 | buck-out/
41 |
42 | # Compiled files
43 | .venv/
44 | .venv_pre_release/
45 | __pycache__/
46 | .mypy_cache/
47 | .ruff_cache/
48 |
49 | # Distribution / packaging
50 | /build/
51 | /dist/
52 | /*.egg-info/
53 |
54 | # Tests and coverage
55 | /.pytest_cache/
56 | /.cache/
57 | /data/
58 | /node_modules/
59 |
60 | # docs
61 | /docs/generated/
62 | /docs/_build/
63 |
64 | # IDEs
65 | /.idea/
66 | /.vscode/
67 |
68 | # Data files from tutorial
69 | docs/notebooks/data/*
70 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | fail_fast: false
2 | default_language_version:
3 | python: python3
4 | default_stages:
5 | - commit
6 | - push
7 | minimum_pre_commit_version: 2.16.0
8 | repos:
9 | - repo: https://github.com/pre-commit/mirrors-prettier
10 | rev: v4.0.0-alpha.8
11 | hooks:
12 | - id: prettier
13 | - repo: https://github.com/astral-sh/ruff-pre-commit
14 | rev: v0.3.2
15 | hooks:
16 | - id: ruff
17 | types_or: [python, pyi, jupyter]
18 | args: [--fix, --exit-non-zero-on-fix]
19 | - id: ruff-format
20 | types_or: [python, pyi, jupyter]
21 | - repo: https://github.com/pre-commit/pre-commit-hooks
22 | rev: v4.5.0
23 | hooks:
24 | - id: detect-private-key
25 | - id: check-ast
26 | - id: end-of-file-fixer
27 | - id: mixed-line-ending
28 | args: [--fix=lf]
29 | - id: trailing-whitespace
30 | - id: check-case-conflict
31 | # Check that there are no merge conflicts (could be generated by template sync)
32 | - id: check-merge-conflict
33 | args: [--assume-in-merge]
34 | - repo: local
35 | hooks:
36 | - id: forbid-to-commit
37 | name: Don't commit rej files
38 | entry: |
39 | Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates.
40 | Fix the merge conflicts manually and remove the .rej files.
41 | language: fail
42 | files: '.*\.rej$'
43 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # https://docs.readthedocs.io/en/stable/config-file/v2.html
2 | version: 2
3 | build:
4 | os: ubuntu-20.04
5 | tools:
6 | python: "3.10"
7 | sphinx:
8 | configuration: docs/conf.py
9 | # disable this for more lenient docs builds
10 | fail_on_warning: false
11 | python:
12 | install:
13 | - method: pip
14 | path: .
15 | extra_requirements:
16 | - doc
17 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | All notable changes to this project will be documented in this file.
4 |
5 | The format is based on [Keep a Changelog][],
6 | and this project adheres to [Semantic Versioning][].
7 |
8 | [keep a changelog]: https://keepachangelog.com/en/1.0.0/
9 | [semantic versioning]: https://semver.org/spec/v2.0.0.html
10 |
11 | ## [Unreleased]
12 |
13 | ## [0.3.1]
14 |
15 | - Fix documentation of `cond()` return type and handle pd.Series in
16 | `model.predict()` (#5, thanks Mark Keller)
17 |
18 | ## [0.3.0]
19 |
20 | - Depend on `formulaic_contrast` package
21 | - Refactor `cond()` implementation to use `formulaic_contrast` implementation.
22 |
23 | ## [0.2.2]
24 |
25 | - Sync with cookiecutter-template update (version 0.4)
26 | - Bump required Python version to `3.10`
27 | - Allow data frames as design matrices
28 | - Allow matrices as input to LEMUR()
29 |
30 | ## [0.2.1]
31 |
32 | - Change example gene to one with clearer differential expression pattern
33 | - Remove error output in `align_harmony`
34 |
35 | ## [0.2.0]
36 |
37 | Major rewrite of the API. Instead of adding coefficients as custom fields
38 | to the input `AnnData` object, the API now follows an object-oriented style
39 | similar to scikit-learn or `SCVI`. This change was motivated by the feedback
40 | during the submission to the `scverse` ecosystem.
41 | ([Thanks](https://github.com/scverse/ecosystem-packages/pull/156#issuecomment-2014676654) Gregor).
42 |
43 | ### Changed
44 |
45 | - Instead of calling `fit = pylemur.tl.lemur(adata, ...)`, you now create a LEMUR model
46 | (`model = pylemur.tl.LEMUR(adata, ...)`) and subsequently call `model.fit()`, `model.align_with_harmony()`,
47 | and `model.predict()`.
48 |
49 | ## [0.1.0] - 2024-03-21
50 |
51 | - Initial beta release of `pyLemur`
52 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024, Your Name
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pyLemur
2 |
3 | [![Tests][badge-tests]][link-tests]
4 | [![Documentation][badge-docs]][link-docs]
5 |
6 | [badge-tests]: https://img.shields.io/github/actions/workflow/status/const-ae/pyLemur/test.yaml?branch=main
7 | [link-tests]: https://github.com/const-ae/pyLemur/actions/workflows/test.yaml
8 | [link-docs]: https://pyLemur.readthedocs.io
9 | [badge-docs]: http://readthedocs.org/projects/pylemur/badge
10 |
11 | The Python implementation of the LEMUR method to analyze multi-condition single-cell data. For the more complete version in R, see [github.com/const-ae/lemur](https://github.com/const-ae/lemur). To learn more, check out the [function documentation](https://pylemur.readthedocs.io/page/api.html) and the [tutorial](https://pylemur.readthedocs.io/page/notebooks/Tutorial.html) at [pylemur.readthedocs.io](https://pylemur.readthedocs.io). To check out the source code or submit an issue, go to [github.com/const-ae/pyLemur](https://github.com/const-ae/pyLemur).
12 |
13 | ## Citation
14 |
15 | > Ahlmann-Eltze C, Huber W (2025).
16 | > “Analysis of multi-condition single-cell data with latent embedding multivariate regression.” Nature Genetics (2025).
17 | > [doi:10.1038/s41588-024-01996-0](https://doi.org/10.1038/s41588-024-01996-0).
18 |
19 | # Getting started
20 |
21 | ## Installation
22 |
23 | You need to have Python 3.10 or newer installed on your system.
24 | There are several alternative options to install pyLemur:
25 |
26 | Install the latest release of `pyLemur` from [PyPI](https://pypi.org/project/pyLemur/):
27 |
28 | ```bash
29 | pip install pyLemur
30 | ```
31 |
32 | Alternatively, install the latest development version directly from Github:
33 |
34 | ```bash
35 | pip install git+https://github.com/const-ae/pyLemur.git@main
36 | ```
37 |
38 | ## Documentation
39 |
40 | For more information on the functions see the [API docs](https://pyLemur.readthedocs.io/page/api.html) and the [tutorial](https://pylemur.readthedocs.io/page/notebooks/Tutorial.html).
41 |
42 | ## Contact
43 |
44 | For questions and help requests, you can reach out in the [scverse discourse][scverse-discourse].
45 | If you found a bug, please use the [issue tracker][issue-tracker].
46 |
47 | [scverse-discourse]: https://discourse.scverse.org/
48 | [issue-tracker]: https://github.com/const-ae/pyLemur/issues
49 |
50 | ## Building
51 |
52 | Install the package in editable mode:
53 |
54 | ```
55 | pip install ".[dev,doc,test]"
56 | ```
57 |
58 | Build the documentation locally
59 |
60 | ```
61 | cd docs
62 | make html
63 | open _build/html/index.html
64 | ```
65 |
66 | Run the unit tests
67 |
68 | ```
69 | pytest
70 | ```
71 |
72 | Run pre-commit hooks manually
73 |
74 | ```
75 | pre-commit run --all-files
76 | ```
77 |
78 | or individually
79 |
80 | ```
81 | ruff check
82 | ```
83 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | livehtml:
18 | sphinx-autobuild "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
19 |
20 | # Catch-all target: route all unknown targets to Sphinx using the new
21 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
22 | %: Makefile
23 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
24 |
--------------------------------------------------------------------------------
/docs/_static/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/const-ae/pyLemur/64b8301e2ec14d9eb948b74c6e3eabcff9b2d63a/docs/_static/.gitkeep
--------------------------------------------------------------------------------
/docs/_static/css/custom.css:
--------------------------------------------------------------------------------
1 | /* Reduce the font size in data frames - See https://github.com/scverse/cookiecutter-scverse/issues/193 */
2 | div.cell_output table.dataframe {
3 | font-size: 0.8em;
4 | }
5 |
--------------------------------------------------------------------------------
/docs/_static/images/equation_schematic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/const-ae/pyLemur/64b8301e2ec14d9eb948b74c6e3eabcff9b2d63a/docs/_static/images/equation_schematic.png
--------------------------------------------------------------------------------
/docs/_templates/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/const-ae/pyLemur/64b8301e2ec14d9eb948b74c6e3eabcff9b2d63a/docs/_templates/.gitkeep
--------------------------------------------------------------------------------
/docs/_templates/autosummary/class.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. add toctree option to make autodoc generate the pages
6 |
7 | .. autoclass:: {{ objname }}
8 |
9 | {% block attributes %}
10 | {% if attributes %}
11 | Attributes table
12 | ~~~~~~~~~~~~~~~~~~
13 |
14 | .. autosummary::
15 | {% for item in attributes %}
16 | ~{{ fullname }}.{{ item }}
17 | {%- endfor %}
18 | {% endif %}
19 | {% endblock %}
20 |
21 | {% block methods %}
22 | {% if methods %}
23 | Methods table
24 | ~~~~~~~~~~~~~
25 |
26 | .. autosummary::
27 | {% for item in methods %}
28 | {%- if item != '__init__' %}
29 | ~{{ fullname }}.{{ item }}
30 | {%- endif -%}
31 | {%- endfor %}
32 | {% endif %}
33 | {% endblock %}
34 |
35 | {% block attributes_documentation %}
36 | {% if attributes %}
37 | Attributes
38 | ~~~~~~~~~~~
39 |
40 | {% for item in attributes %}
41 |
42 | .. autoattribute:: {{ [objname, item] | join(".") }}
43 | {%- endfor %}
44 |
45 | {% endif %}
46 | {% endblock %}
47 |
48 | {% block methods_documentation %}
49 | {% if methods %}
50 | Methods
51 | ~~~~~~~
52 |
53 | {% for item in methods %}
54 | {%- if item != '__init__' %}
55 |
56 | .. automethod:: {{ [objname, item] | join(".") }}
57 | {%- endif -%}
58 | {%- endfor %}
59 |
60 | {% endif %}
61 | {% endblock %}
62 |
--------------------------------------------------------------------------------
/docs/api.md:
--------------------------------------------------------------------------------
1 | # API
2 |
3 | ## Tools
4 |
5 | To create the LEMUR object that provides functionality to fit, align, predict, and transform input data:
6 |
7 | ```{eval-rst}
8 | .. module:: pylemur.tl
9 | .. module:: pylemur.pp
10 |
11 | .. currentmodule:: pylemur
12 |
13 | .. autosummary::
14 | :toctree: generated
15 |
16 | tl.LEMUR
17 | pp.shifted_log_transform
18 | ```
19 |
--------------------------------------------------------------------------------
/docs/changelog.md:
--------------------------------------------------------------------------------
1 | ```{include} ../CHANGELOG.md
2 |
3 | ```
4 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 |
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 | import sys
9 | from datetime import datetime
10 | from importlib.metadata import metadata
11 | from pathlib import Path
12 |
13 | HERE = Path(__file__).parent
14 | sys.path.insert(0, str(HERE / "extensions"))
15 |
16 |
17 | # -- Project information -----------------------------------------------------
18 |
19 | # NOTE: If you installed your project in editable mode, this might be stale.
20 | # If this is the case, reinstall it to refresh the metadata
21 | info = metadata("pyLemur")
22 | project_name = info["Name"]
23 | author = info["Author"]
24 | copyright = f"{datetime.now():%Y}, {author}."
25 | version = info["Version"]
26 | urls = dict(pu.split(", ") for pu in info.get_all("Project-URL"))
27 | repository_url = urls["Source"]
28 |
29 | # The full version, including alpha/beta/rc tags
30 | release = info["Version"]
31 |
32 | bibtex_bibfiles = ["references.bib"]
33 | bibtex_default_style = "unsrt"
34 |
35 | templates_path = ["_templates"]
36 | nitpicky = True # Warn about broken links
37 | needs_sphinx = "4.0"
38 |
39 | html_context = {
40 | "display_github": True, # Integrate GitHub
41 | "github_user": "const-ae",
42 | "github_repo": "https://github.com/const-ae/pyLemur",
43 | "github_version": "main",
44 | "conf_py_path": "/docs/",
45 | }
46 |
47 | # -- General configuration ---------------------------------------------------
48 |
49 | # Add any Sphinx extension module names here, as strings.
50 | # They can be extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
# Sphinx extension modules: bundled sphinx.ext.* extensions, third-party
# extensions already used by this project, plus every local extension
# found in docs/extensions/ (e.g. typed_returns).
extensions = [
    "myst_nb",
    "sphinx_copybutton",
    "sphinx.ext.autodoc",
    "sphinx.ext.intersphinx",
    "sphinx.ext.autosummary",
    "sphinx.ext.napoleon",
    "sphinx_autodoc_typehints",  # was listed twice; Sphinx only needs it once
    "sphinxcontrib.bibtex",
    "sphinx.ext.mathjax",
    "sphinx_design",
    "IPython.sphinxext.ipython_console_highlighting",
    "sphinxext.opengraph",
    *[p.stem for p in (HERE / "extensions").glob("*.py")],
]
67 |
68 | autosummary_generate = True
69 | autodoc_member_order = "groupwise"
70 | default_role = "literal"
71 | napoleon_google_docstring = False
72 | napoleon_numpy_docstring = True
73 | napoleon_include_init_with_doc = False
74 | napoleon_use_rtype = True # having a separate entry generally helps readability
75 | napoleon_use_param = True
76 | napoleon_use_ivar = True
77 |
78 |
79 | myst_heading_anchors = 6 # create anchors for h1-h6
80 | myst_enable_extensions = [
81 | "amsmath",
82 | "colon_fence",
83 | "deflist",
84 | "dollarmath",
85 | "html_image",
86 | "html_admonition",
87 | ]
88 | myst_url_schemes = ("http", "https", "mailto")
89 | nb_output_stderr = "remove"
90 | nb_execution_mode = "cache"
91 | nb_merge_streams = True
92 | typehints_defaults = "braces"
93 |
94 | source_suffix = {
95 | ".rst": "restructuredtext",
96 | ".ipynb": "myst-nb",
97 | ".myst": "myst-nb",
98 | }
99 |
100 | intersphinx_mapping = {
101 | "python": ("https://docs.python.org/3", None),
102 | "anndata": ("https://anndata.readthedocs.io/en/stable/", None),
103 | "numpy": ("https://numpy.org/doc/stable/", None),
104 | "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None),
105 | "scipy": ("https://docs.scipy.org/doc/scipy/", None),
106 | "sklearn": ("https://scikit-learn.org/stable/", None),
107 | # "formulaic": ("https://matthewwardrop.github.io/formulaic/", None) # Doesn't work.
108 | }
109 |
110 |
111 | # List of patterns, relative to source directory, that match files and
112 | # directories to ignore when looking for source files.
113 | # This pattern also affects html_static_path and html_extra_path.
114 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"]
115 |
116 |
117 | # -- Options for HTML output -------------------------------------------------
118 |
119 | # The theme to use for HTML and HTML Help pages. See the documentation for
120 | # a list of builtin themes.
121 | #
122 | html_theme = "sphinx_book_theme"
123 | html_static_path = ["_static"]
124 | html_css_files = ["css/custom.css"]
125 |
126 | html_title = project_name
127 |
128 | html_theme_options = {
129 | "repository_url": repository_url,
130 | "use_repository_button": True,
131 | "path_to_docs": "docs/",
132 | "navigation_with_keys": False,
133 | }
134 |
135 | pygments_style = "default"
136 |
137 | nitpick_ignore = [
138 | # If building the documentation fails because of a missing link that is outside your control,
139 | # you can add an exception to this list.
140 | # ("py:class", "igraph.Graph"),
141 | ]
142 |
--------------------------------------------------------------------------------
/docs/extensions/typed_returns.py:
--------------------------------------------------------------------------------
1 | # code from https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py
2 | # with some minor adjustment
3 | from __future__ import annotations
4 |
5 | import re
6 | from collections.abc import Generator, Iterable
7 |
8 | from sphinx.application import Sphinx
9 | from sphinx.ext.napoleon import NumpyDocstring
10 |
11 |
12 | def _process_return(lines: Iterable[str]) -> Generator[str, None, None]:
13 | for line in lines:
14 | if m := re.fullmatch(r"(?P\w+)\s+:\s+(?P[\w.]+)", line):
15 | yield f'-{m["param"]} (:class:`~{m["type"]}`)'
16 | else:
17 | yield line
18 |
19 |
def _parse_returns_section(self: NumpyDocstring, section: str) -> list[str]:
    """Parse a numpydoc ``Returns`` section into ``:returns:`` field lines.

    Replacement for :meth:`NumpyDocstring._parse_returns_section` that runs
    each line through ``_process_return`` so return types are rendered as
    class cross-references.
    """
    raw = self._dedent(self._consume_to_next_section())
    # Drop a leading bare ":" left over from the section header.
    if raw[0] == ":":
        del raw[0]
    formatted = self._format_block(":returns: ", [*_process_return(raw)])
    # Ensure the block is terminated by a blank line.
    if formatted and formatted[-1]:
        formatted.append("")
    return formatted
28 |
29 |
def setup(app: Sphinx):
    """Sphinx entry point: install the custom Returns-section parser.

    Monkey-patches napoleon's :class:`NumpyDocstring` so every parsed
    docstring uses :func:`_parse_returns_section`.
    """
    NumpyDocstring._parse_returns_section = _parse_returns_section
33 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | ```{include} ../README.md
2 |
3 | ```
4 |
5 | ```{toctree}
6 | :hidden: true
7 | :maxdepth: 1
8 |
9 | api.md
10 | notebooks/Tutorial.myst
11 | changelog.md
12 | references.md
13 |
14 | ```
15 |
--------------------------------------------------------------------------------
/docs/notebooks/Tutorial.myst:
--------------------------------------------------------------------------------
1 | ---
2 | file_format: mystnb
3 | mystnb:
4 | execution_timeout: 600
5 | kernelspec:
6 | name: python3
7 | ---
8 |
9 | # pyLemur Walkthrough
10 |
11 |
12 | The goal of `pyLemur` is to simplify the analysis of multi-condition single-cell data. If you have collected a single-cell RNA-seq dataset with more than one condition, LEMUR predicts for each cell and gene how much the expression would change if the cell had been in the other condition.
13 |
14 | `pyLemur` is a Python implementation of the LEMUR model; there is also an `R` package called [lemur](https://bioconductor.org/packages/lemur/), which provides additional functionality: identifying neighborhoods of cells that show consistent differential expression values and a pseudo-bulk test to validate the findings.
15 |
16 | `pyLemur` implements a novel framework to disentangle the effects of known covariates, latent cell states, and their interactions. At the core is a combination of matrix factorization and regression analysis implemented as geodesic regression on Grassmann manifolds. We call this latent embedding multivariate regression (LEMUR). For more details, see our [preprint](https://www.biorxiv.org/content/10.1101/2023.03.06.531268) {cite:p}`Ahlmann-Eltze2024`.
17 |
18 |
19 |
20 |
21 | ## Data
22 |
23 | For demonstration, I will use a dataset of interferon-$\beta$ stimulated blood cells from {cite:t}`kang2018`.
24 |
25 | ```{code-cell} ipython3
26 | ---
27 | output_stderr: remove
28 | ---
29 | # Standard imports
30 | import numpy as np
31 | import scanpy as sc
32 | # pertpy is needed to download the Kang data
33 | import pertpy
34 |
35 | # This will download the data to ./data/kang_2018.h5ad
36 | adata = pertpy.data.kang_2018()
37 | # Store counts separately in the layers
38 | adata.layers["counts"] = adata.X.copy()
39 | ```
40 |
41 | The data consists of $24\,673$ cells and $15\,706$ genes. The cells were measured in two conditions (`label="ctrl"` and `label="stim"`). The authors have annotated the cell type for each cell, which will be useful to analyze LEMUR's results; however, note that the cell type labels are not used (and not needed) to fit the LEMUR model.
42 |
43 | ```{code-cell} ipython3
44 | :tags: ["remove-cell"]
45 | import pandas as pd
46 | pd.options.display.width = 200
47 | pd.options.display.max_colwidth = 20
48 | ```
49 |
50 | ```{code-cell} ipython3
51 | print(adata)
52 | print(adata.obs)
53 | ```
54 |
55 | ## Preprocessing
56 |
57 | LEMUR expects that the input has been variance-stabilized. Here, I will use the log-transformation as a simple, yet effective approach.
58 | In addition, I will only work on the $1\,000$ most variable genes to make the results easier to manage.
59 | ```{code-cell} ipython3
60 | # This follows the standard recommendation from scanpy
61 | sc.pp.normalize_total(adata, target_sum = 1e4, inplace=True)
62 | sc.pp.log1p(adata)
63 | adata.layers["logcounts"] = adata.X.copy()
64 | sc.pp.highly_variable_genes(adata, n_top_genes=1000, flavor="cell_ranger")
65 | adata = adata[:, adata.var.highly_variable]
66 | adata
67 | ```
68 |
69 | If we make a 2D plot of the data using UMAP, we see that the cell types separate by treatment status.
70 | ```{code-cell} ipython3
71 | sc.tl.pca(adata)
72 | sc.pp.neighbors(adata)
73 | sc.tl.umap(adata)
74 | sc.pl.umap(adata, color=["label", "cell_type"])
75 | ```
76 |
77 |
78 | ## LEMUR
79 |
80 | First, we import `pyLemur`; then, we fit the LEMUR model by providing the `AnnData` object, a specification of the experimental design, and the number of latent dimensions.
81 |
82 | ```{code-cell} ipython3
83 | import pylemur
84 | model = pylemur.tl.LEMUR(adata, design = "~ label", n_embedding=15)
85 | model.fit()
86 | model.align_with_harmony()
87 | print(model)
88 | ```
89 |
90 | To assess if the model was fit successfully, we plot a UMAP representation of the 15-dimensional embedding calculated by LEMUR. We want to see that the two conditions are well mixed in the embedding space because that means that LEMUR was able to disentangle the treatment effect from the cell type effect and that the residual variation is driven by the cell states.
91 | ```{code-cell} ipython3
92 | # Recalculate the UMAP on the embedding calculated by LEMUR
93 | adata.obsm["embedding"] = model.embedding
94 | sc.pp.neighbors(adata, use_rep="embedding")
95 | sc.tl.umap(adata)
96 | sc.pl.umap(adata, color=["label", "cell_type"])
97 | ```
98 |
99 | The LEMUR model is fully parametric, which means that we can predict for each cell what its expression would have been in any condition (i.e., for a cell observed in the control condition, we can predict its expression under treatment) as a function of its low-dimensional embedding.
100 |
101 | ```{code-cell} ipython3
102 | # The model.cond(**kwargs) call specifies the condition for the prediction
103 | ctrl_pred = model.predict(new_condition=model.cond(label="ctrl"))
104 | stim_pred = model.predict(new_condition=model.cond(label="stim"))
105 | ```
106 |
107 | We can now check the predicted differential expression against the underlying observed expression patterns for individual genes. Here, I chose _TSC22D3_ as an example. The blue cells in the first plot are in neighborhoods with higher expression in the control condition than in the stimulated condition. The two other plots show the underlying gene expression for the control and stimulated cells and confirm LEMUR's inference.
108 | ```{code-cell} ipython3
109 | import matplotlib.pyplot as plt
110 | adata.layers["diff"] = stim_pred - ctrl_pred
111 | # Also try CXCL10, IL8, and FBXO40
112 | sel_gene = "TSC22D3"
113 |
114 | fsize = plt.rcParams['figure.figsize']
115 | fig = plt.figure(figsize=(fsize[0] * 3, fsize[1]))
116 | axs = [fig.add_subplot(1, 3, i+1) for i in range(3)]
117 | for ax in axs:
118 | ax.set_aspect('equal')
119 | sc.pl.umap(adata, layer="diff", color=[sel_gene], cmap = plt.get_cmap("seismic"), vcenter=0,
120 | vmin=-4, vmax=4, title="Pred diff (stim - ctrl)", ax=axs[0], show=False)
121 | sc.pl.umap(adata[adata.obs["label"]=="ctrl"], layer="logcounts", color=[sel_gene], vmin = 0, vmax =4,
122 | title="Ctrl expr", ax=axs[1], show=False)
123 | sc.pl.umap(adata[adata.obs["label"]=="stim"], layer="logcounts", color=[sel_gene], vmin = 0, vmax =4,
124 | title="Stim expr", ax=axs[2])
125 | ```
126 |
127 | To assess the overall accuracy of LEMUR's predictions, I will compare the average observed and predicted expression per cell type between conditions. The next plot simply shows the observed expression values. Genes on the diagonal don't change expression much between conditions within a cell type, whereas all off-diagonal genes are differentially expressed:
128 | ```{code-cell} ipython3
def rowMeans_per_group(X, group):
    """Return a (n_groups, n_cols) array of column means of X per group label.

    Rows of the result follow np.unique's sorted label order. Means are computed
    as sum / count so the code also works on scipy sparse matrices.
    """
    labels = np.unique(group)
    out = np.zeros((len(labels), X.shape[1]))
    for idx, lab in enumerate(labels):
        mask = group == lab
        out[idx, :] = X[mask, :].sum(axis=0) / sum(mask)
    return out
135 |
136 | adata_ctrl = adata[adata.obs["label"] == "ctrl",:]
137 | adata_stim = adata[adata.obs["label"] == "stim",:]
138 | ctrl_expr_per_cell_type = rowMeans_per_group(adata_ctrl.layers["logcounts"], adata_ctrl.obs["cell_type"])
139 | stim_expr_per_cell_type = rowMeans_per_group(adata_stim.layers["logcounts"], adata_stim.obs["cell_type"])
140 | obs_diff = stim_expr_per_cell_type - ctrl_expr_per_cell_type
141 | plt.scatter(ctrl_expr_per_cell_type, stim_expr_per_cell_type, c = obs_diff,
142 | cmap = plt.get_cmap("seismic"), vmin=-5, vmax=5, marker="o",edgecolors= "black")
143 | plt.colorbar()
144 | plt.title( "Inf-b stim. increases gene expression for many genes")
145 | plt.axline((0, 0), (1, 1), linewidth=1, color='black')
146 | ```
147 |
148 | To demonstrate that LEMUR learned the underlying expression relations, I predict what the expression of cells from the control condition would have been had they been stimulated and compare the results against the observed expression in the stimulated condition. The closer the points are to the diagonal, the better the predictions.
149 | ```{code-cell} ipython3
150 | stim_pred_per_cell_type = rowMeans_per_group(stim_pred[adata.obs["label"]=="ctrl"], adata_ctrl.obs["cell_type"])
151 |
152 | plt.scatter(stim_expr_per_cell_type, stim_pred_per_cell_type, c = obs_diff,
153 | cmap = plt.get_cmap("seismic"), vmin=-5, vmax=5, marker="o",edgecolors= "black")
154 | plt.colorbar()
155 | plt.title( "LEMUR's expression predictions are accurate")
156 | plt.axline((0, 0), (1, 1), linewidth=1, color='black')
157 | ```
158 |
159 | Lastly, I directly compare the average predicted differential expression against the average observed differential expression per cell type. Again, the closer the points are to the diagonal, the better the predictions.
160 |
161 | ```{code-cell} ipython3
162 | pred_diff = rowMeans_per_group(adata.layers["diff"], adata.obs["cell_type"])
163 |
164 | plt.scatter(obs_diff, pred_diff, c = obs_diff,
165 | cmap = plt.get_cmap("seismic"), vmin=-5, vmax=5, marker="o",edgecolors= "black")
166 | plt.colorbar()
167 | plt.title( "LEMUR's DE predictions are accurate")
168 | plt.axline((0, 0), (1, 1), linewidth=1, color='black')
169 | ```
170 |
171 | Another advantage of LEMUR's parametricity is that you could train the model on a subset of the data and then apply it to the full data.
172 |
173 | I will demonstrate this by training the same LEMUR model on 5% of the original data, then `transform` the full data, and finally compare the first three dimensions of the embedding against the embedding from the model trained on the full data.
174 |
175 | ```{code-cell} ipython3
176 | adata_subset = adata[np.random.choice(np.arange(adata.shape[0]), size = round(adata.shape[0] * 0.05)),]
177 | model_small = pylemur.tl.LEMUR(adata_subset, design = "~ label", n_embedding=15)
178 | model_small.fit().align_with_harmony()
179 | emb_proj = model_small.transform(adata)
180 | plt.scatter(emb_proj[:,0:3], model.embedding[:,0:3], s = 0.1)
181 | plt.axline((0, 0), (1, 1), linewidth=1, color='black')
182 | plt.axline((0, 0), (-1, 1), linewidth=1, color='black')
183 | ```
184 |
185 | We see that the small model still captures most of the relevant variation.
186 | ```{code-cell} ipython3
187 | adata.obsm["embedding_from_small_fit"] = emb_proj
188 | sc.pp.neighbors(adata, use_rep="embedding_from_small_fit")
189 | sc.tl.umap(adata)
190 | sc.pl.umap(adata, color=["label", "cell_type"])
191 | ```
192 |
193 | ### Session Info
194 |
195 | ```{code-cell} ipython3
196 | import session_info
197 | session_info.show()
198 | ```
199 |
200 |
201 | ### References
202 |
203 | ```{bibliography}
204 | :style: plain
205 | :filter: docname in docnames
206 | ```
207 |
--------------------------------------------------------------------------------
/docs/references.bib:
--------------------------------------------------------------------------------
1 | @article{Virshup_2023,
2 | doi = {10.1038/s41587-023-01733-8},
3 | url = {https://doi.org/10.1038%2Fs41587-023-01733-8},
4 | year = 2023,
5 | month = {apr},
6 | publisher = {Springer Science and Business Media {LLC}},
7 | author = {Isaac Virshup and Danila Bredikhin and Lukas Heumos and Giovanni Palla and Gregor Sturm and Adam Gayoso and Ilia Kats and Mikaela Koutrouli and Philipp Angerer and Volker Bergen and Pierre Boyeau and Maren Büttner and Gokcen Eraslan and David Fischer and Max Frank and Justin Hong and Michal Klein and Marius Lange and Romain Lopez and Mohammad Lotfollahi and Malte D. Luecken and Fidel Ramirez and Jeffrey Regier and Sergei Rybakov and Anna C. Schaar and Valeh Valiollah Pour Amiri and Philipp Weiler and Galen Xing and Bonnie Berger and Dana Pe'er and Aviv Regev and Sarah A. Teichmann and Francesca Finotello and F. Alexander Wolf and Nir Yosef and Oliver Stegle and Fabian J. Theis},
8 | title = {The scverse project provides a computational ecosystem for single-cell omics data analysis},
9 | journal = {Nature Biotechnology}
10 | }
11 | @article{Ahlmann-Eltze2024,
12 | title = {Analysis of multi-condition single-cell data with latent embedding multivariate regression},
13 | author = {Ahlmann-Eltze, Constantin and Huber, Wolfgang},
14 | year = {2024},
15 | month = {02},
16 | url = {http://dx.doi.org/10.1101/2023.03.06.531268},
17 | journal = {bioRxiv},
18 | }
19 | @article{kang2018,
20 | title = {Multiplexed droplet single-cell RNA-sequencing using natural genetic variation},
21 | author = {Kang, Hyun Min and Subramaniam, Meena and Targ, Sasha and Nguyen, Michelle and Maliskova, Lenka and McCarthy, Elizabeth and Wan, Eunice and Wong, Simon and Byrnes, Lauren and Lanata, Cristina M and Gate, Rachel E and Mostafavi, Sara and Marson, Alexander and Zaitlen, Noah and Criswell, Lindsey A and Ye, Chun Jimmie},
22 | year = {2018},
23 | month = {01},
24 | date = {2018-01},
25 | journal = {Nature Biotechnology},
26 | pages = {89--94},
27 | volume = {36},
28 | number = {1},
29 | doi = {10.1038/nbt.4042},
30 | url = {http://www.nature.com/articles/nbt.4042},
31 | langid = {en}
32 | }
33 |
--------------------------------------------------------------------------------
/docs/references.md:
--------------------------------------------------------------------------------
1 | # References
2 |
3 | ```{bibliography}
4 | :cited:
5 | ```
6 |
--------------------------------------------------------------------------------
/notebooks/check_implementation.qmd:
--------------------------------------------------------------------------------
1 | ---
2 | title: "Compare implementation directly against R"
3 | author: Constantin Ahlmann-Eltze
4 | format:
5 | html:
6 | code-fold: false
7 | embed-resources: true
8 | highlight-style: github
9 | toc: true
10 | code-line-numbers: true
11 | execute:
12 | keep-ipynb: true
13 | jupyter: python3
14 | ---
15 |
16 |
17 | ```{python}
18 | %load_ext autoreload
19 | %autoreload 2
20 | ```
21 |
22 |
23 | ```{python}
24 | import debugpy
25 | debugpy.listen(5678)
26 | print("Waiting for debugger attach")
27 | debugpy.wait_for_client()
28 | ```
29 |
30 |
31 | ```{python}
32 | import numpy as np
33 | import scanpy as sc
34 | import pertpy
35 |
36 | adata = pertpy.data.kang_2018()
37 | adata.layers["counts"] = adata.X.copy()
38 | ```
39 |
40 |
41 | ```{python}
42 | sc.pp.normalize_total(adata, target_sum = 1e4, inplace=True)
43 | sc.pp.log1p(adata)
44 | adata.layers["logcounts"] = adata.X.copy()
45 | sc.pp.highly_variable_genes(adata, n_top_genes=1000, flavor="cell_ranger")
46 | adata = adata[:, adata.var.highly_variable]
47 | adata
48 | ```
49 |
50 |
51 | ```{python}
52 | import pylemur
53 | model = pylemur.tl.LEMUR(adata, design = "~ label", n_embedding=15, layer = "logcounts")
54 | model.fit()
55 | model.align_with_harmony()
56 | print(model)
57 | ```
58 |
59 |
60 | ```{python}
61 | ctrl_pred = model.predict(new_condition=model.cond(label="ctrl"))
62 | model.cond(label = "stim") - model.cond(label = "ctrl")
63 | ```
64 |
65 | ```{python}
66 | model.embedding.shape
67 | model.adata
68 | ```
69 |
70 | ```{python}
71 | # groups = np.array([np.nan] * fit.shape[0])
72 | # groups[fit.obs["cell_type"] == "CD4 T cells"] = 1
73 | # groups[fit.obs["cell_type"] == "NK cells"] = 2
74 | # groups[fit.obs["cell_type"] == "Dendritic cells"] = 3
75 | import pandas as pd
76 | groups = fit.obs["cell_type"]
77 | groups[np.array([0,9,99])] = pd.NA
78 |
79 | np.unique(groups.to_numpy())
80 | ```
81 |
82 | ```{python}
83 | fit2 = fit.copy()
84 | fit2 = pylemur.tl.align_with_grouping(fit2, groups)
85 | ```
86 |
87 |
88 | ```{python}
89 | fit2.obsm["embedding"][0:3, 0:3].T
90 | fit2.obsm["embedding"][24600:24603,0:3].T
91 | ```
92 |
93 | ```{python}
94 | pred_ctrl = pylemur.tl.predict(fit2, new_condition=pylemur.tl.cond(fit2, label = "ctrl"))
95 | pred_ctrl[0:3,0:3].T
96 | ```
97 |
98 |
99 | ```{python}
100 | np.savetxt("/var/folders/dc/tppjxs9x6ll378lq88lz1fm40000gq/T//Rtmpm5PTRh/pred_ctrl.tsv", pred_ctrl, delimiter="\t")
101 | np.savetxt("/var/folders/dc/tppjxs9x6ll378lq88lz1fm40000gq/T//Rtmpm5PTRh/embedding.tsv", fit2.obsm["embedding"], delimiter="\t")
102 | ```
103 |
--------------------------------------------------------------------------------
/notebooks/devel_experiments.qmd:
--------------------------------------------------------------------------------
1 | ---
2 | title: "Quick check if new functions work as expected"
3 | author: Constantin Ahlmann-Eltze
4 | format:
5 | html:
6 | code-fold: false
7 | embed-resources: true
8 | highlight-style: github
9 | toc: true
10 | code-line-numbers: true
11 | execute:
12 | keep-ipynb: true
13 | jupyter: python3
14 | ---
15 |
16 |
17 | ```{python}
18 | %load_ext autoreload
19 | %autoreload 2
20 | ```
21 |
22 |
23 | ```{python}
24 | import debugpy
25 | debugpy.listen(5678)
26 | print("Waiting for debugger attach")
27 | debugpy.wait_for_client()
28 | ```
29 |
30 | Convert formula to design matrix in Python
31 |
32 | ```{python}
33 | import numpy as np
34 | from patsy import dmatrices, dmatrix, demo_data
35 | ```
36 |
37 |
38 | ```{python}
39 | data = demo_data("a", "b", "x1", "x2", "y", "z column", min_rows = 400)
40 | data
41 | ```
42 |
43 | ```{python}
44 | form = dmatrix("~ a + x1", data)
45 | ```
46 |
47 |
48 | ```{python}
49 | from pyLemur.handle_design import *
50 | from pyLemur.row_groups import *
51 | ```
52 |
53 | ```{python}
54 | des, form = convert_formula_to_design_matrix("~ a + x1", data)
55 | ```
56 |
57 |
58 | ```{python}
59 | row_groups(des)
60 | ```
61 |
62 |
63 | ```{python}
64 | # 400 cells, 30 features
65 | Y = np.random.rand(400, 30)
66 | ```
67 |
68 |
69 | ```{python}
70 | import sklearn.decomposition as skd
71 |
72 | pca_fit = PCA(n_components=2)
73 | emb = pca_fit.fit_transform(Y)
74 | pca_fit.score(Y)
75 | pca_fit.mean_
76 | ```
77 |
78 |
79 | ```{python}
80 | svd = skd.TruncatedSVD(n_components=2)
81 |
82 | ```
83 |
84 | ```{python}
85 | from pyLemur.pca import *
86 |
87 | pca = fit_pca(Y, 2, center=False)
88 | pca
89 |
90 | ```
91 |
92 |
93 | ```{python}
94 | from pyLemur.lin_alg_wrappers import *
95 | from pyLemur.design_matrix_utils import *
96 |
97 | data = demo_data("a", "b", "x1", "x2", "y", "z column", min_rows = 400)
98 | des, form = convert_formula_to_design_matrix("~ a", data)
99 | Y = np.random.rand(400, 30)
100 | beta = ridge_regression(Y, des, 0)
101 | beta
102 | ```
103 |
104 | ```{python}
105 | np.unique(des, return_counts=True)[1]
106 | ```
107 |
108 | ```{python}
109 | from pyLemur.grassmann_lm import *
110 |
111 | base_point = fit_pca(Y, n = 3, center = False).coord_system
112 | V = grassmann_lm(Y, des, base_point)
113 | V.shape
114 |
115 | ```
116 |
117 | ```{python}
118 | np.hstack([[1,2,3], [4,5,6]]).shape
119 | ```
120 |
121 |
122 | ```{python}
123 | # The python equivalent of R's seq(-3, 4, length.out = 18)
124 | # Y = np.linspace(-3, 4, 18).reshape((6,3))
125 | # des = np.hstack([np.ones((6,1)), np.array([1,0,1,0,1,0]).reshape((6,1))])
126 | # Read csv file into numpy array
127 | import pandas as pd
128 | Y = pd.read_csv("/var/folders/dc/tppjxs9x6ll378lq88lz1fm40000gq/T//Rtmp89jWYn/file128665ac82720").to_numpy().T
129 | des = pd.read_csv("/var/folders/dc/tppjxs9x6ll378lq88lz1fm40000gq/T//Rtmp89jWYn/file12866a80d470").to_numpy()
130 | base_point = fit_pca(Y, n = 3, center = False).coord_system
131 | V = grassmann_lm(Y, des, base_point)
132 | V
133 | ```
134 |
--------------------------------------------------------------------------------
/notebooks/kang_analysis_in_R.qmd:
--------------------------------------------------------------------------------
1 | ---
2 | title: "Kang analysis with R"
3 | format: html
4 | ---
5 |
6 |
7 | ```{r}
8 | library(tidyverse)
9 | library(SingleCellExperiment)
10 | ```
11 |
12 |
13 | ```{r}
14 | sce <- zellkonverter::readH5AD("data/kang_2018.h5ad", X_name = "counts")
15 | ```
16 |
17 | ```{r}
18 | logcounts(sce) <- transformGamPoi::shifted_log_transform(counts(sce))
19 | hvgs <- order(-rowVars(logcounts(sce)))
20 | sce <- sce[hvgs[1:1000],]
21 | ```
22 |
23 | ```{r}
24 | system.time({
25 | fit <- lemur::lemur(sce, design = ~ label, n_embedding = 15, verbose = TRUE, test_fraction = 0)
26 | })
27 |
28 | fit$embedding[1:3, 1:3]
29 | fit$embedding[1:3, 24601:24603]
30 |
31 |
32 | # fit <- lemur::align_harmony(fit)
33 |
34 | # groups <- rep(NA, ncol(fit))
35 | # groups[fit$colData$cell_type == "CD4 T cells"] <- 1
36 | # groups[fit$colData$cell_type == "NK cells"] <- 2
37 | # groups[fit$colData$cell_type == "Dendritic cells"] <- 3
38 | groups <- fit$colData$cell_type
39 | groups[c(1,10,100)] <- NA
40 |
41 | fit <- lemur::align_by_grouping(fit, groups)
42 |
43 | fit$embedding[1:3, 1:3]
44 | fit$embedding[1:3, 24601:24603]
45 |
46 |
47 | pred_ctrl <- predict(fit, newcondition = cond(label = "ctrl"))
48 | pred_ctrl[1:3, 1:3]
49 | ```
50 |
51 |
52 | ```{r}
53 | py_pred <- t(as.matrix(readr::read_tsv("/var/folders/dc/tppjxs9x6ll378lq88lz1fm40000gq/T//Rtmpm5PTRh/pred_ctrl.tsv", col_names = FALSE)))
54 | dimnames(py_pred) <- dimnames(fit)
55 | plot(py_pred[1,], pred_ctrl[1,])
56 | ```
57 |
58 |
59 | ```{r}
60 | py_emb <- t(as.matrix(readr::read_tsv("/var/folders/dc/tppjxs9x6ll378lq88lz1fm40000gq/T//Rtmpm5PTRh/embedding.tsv", col_names = FALSE)))
61 | dimnames(py_emb) <- dimnames(fit$embedding)
62 | plot(py_emb[1,], fit$embedding[1,])
63 | ```
64 |
65 | ```{r}
66 | as_tibble(colData(fit)) %>%
67 | mutate(r_emb = t(fit$embedding),
68 | py_emb = t(py_emb)) %>%
69 | pivot_longer(ends_with("emb"), names_sep = "_", names_to = c("origin", ".value")) %>%
70 | ggplot(aes(x = emb[,4], y = emb[,15])) +
71 | geom_point(aes(color = label)) +
72 | facet_wrap(vars(origin))
73 |
74 | ```
75 |
--------------------------------------------------------------------------------
/notebooks/lemur_model_experiments.qmd:
--------------------------------------------------------------------------------
1 | ---
2 | title: "Play with scanpy"
3 | author: Constantin Ahlmann-Eltze
4 | format:
5 | html:
6 | code-fold: false
7 | embed-resources: true
8 | highlight-style: github
9 | toc: true
10 | code-line-numbers: true
11 | execute:
12 | keep-ipynb: true
13 | jupyter: python3
14 | ---
15 |
16 |
17 | ```{python}
18 | %load_ext autoreload
19 | %autoreload 2
20 | ```
21 |
22 |
23 | ```{python}
24 | import debugpy
25 | debugpy.listen(5678)
26 | print("Waiting for debugger attach")
27 | debugpy.wait_for_client()
28 | ```
29 |
30 | ```{python}
31 | import numpy as np
32 | import scanpy as sc
33 | import pertpy
34 | import scanpy.preprocessing._simple as simple
35 |
36 | adata = pertpy.data.kang_2018()
37 | adata.layers["counts"] = adata.X.copy()
38 | sf = np.array(adata.layers["counts"].sum(axis=1))
39 | sf = sf / np.median(sf)
40 | adata.layers["logcounts"] = sc.pp.log1p(adata.layers["counts"] / sf)
41 | var = simple._get_mean_var(adata.layers["logcounts"])[1]
42 | hvgs = var.argpartition(-1000)[-1000:]
43 | adata = adata[:, hvgs]
44 | ```
45 |
46 | ```{python}
47 | import numpy as np
48 | n_cells = 400
49 | n_genes = 100
50 | mat = np.arange(n_cells * n_genes).reshape((n_cells, n_genes))
51 | log_mat = shifted_log_transform(mat)
52 | ```
53 |
54 |
55 | ```{python}
56 | # shifted log transformation ala transformGamPoi
57 |
58 | ```
59 |
60 |
61 | ```{python}
62 | adata.X = adata.layers["logcounts"]
63 | sc.pp.pca(adata)
64 | sc.pp.neighbors(adata)
65 | sc.tl.umap(adata)
66 | sc.pl.umap(adata, color=["label", "cell_type"])
67 | ```
68 |
69 | ```{python}
70 | from pylemur.tl import lemur
71 |
72 | fit = lemur(adata, design = ["label"])
73 | ```
74 |
75 | ```{python}
76 | sc.pp.neighbors(fit, use_rep="embedding")
77 | sc.tl.umap(fit)
78 | sc.pl.umap(fit, color=["label", "cell_type"])
79 | ```
80 |
81 |
82 | ```{python}
83 | from pylemur.tl import align_with_harmony
84 | align_with_harmony(fit, ridge_penalty = 0.01)
85 | ```
86 |
87 |
88 | ```{python}
89 | nei = sc.pp.neighbors(fit, use_rep="embedding")
90 | sc.tl.umap(fit)
91 | sc.pl.umap(fit, color=["label", "cell_type"])
92 | ```
93 |
94 |
95 | ```{python}
96 | import matplotlib.pyplot as plt
97 | plt.scatter(fit.obsm["new_embedding"], fit.obsm["embedding"])
98 | ```
99 |
100 | ```{python}
101 | from pylemur.tl import predict, cond
102 | pred_ctrl = predict(fit, new_condition = cond(fit, label = "ctrl"))
103 | pred_stim = predict(fit, new_condition = cond(fit, label = "stim"))
104 | delta2 = pred_stim - pred_ctrl
105 | ```
106 |
107 | ```{python}
108 | import matplotlib.pyplot as plt
109 | plt.scatter(delta[:,0], delta2[:,0])
110 | ```
111 |
112 | ```{python}
113 | gene = 0
114 | fit.obs["delta"] = delta[:,gene]
115 | fit.obs["delta2"] = delta2[:,gene]
116 | fit.obs["expr"] = fit.layers["logcounts"][:,gene].toarray()
117 |
118 | import matplotlib.pyplot as plt
119 | sc.pl.umap(fit, color=["delta", "delta2"], cmap = plt.get_cmap("seismic"), vcenter=0)
120 |
121 | sc.pl.umap(fit[fit.obs["label"] == "ctrl"], color="expr", vmax = 1)
122 | sc.pl.umap(fit[fit.obs["label"] == "stim"], color="expr", vmax = 1)
123 | ```
124 |
125 | ```{python}
126 | # fit.obs["plot_label"] = fit.obs["label"].astype(str) + "-" + fit.obs["cell_type"].astype(str)
127 | fit.obs["plot_label"] = fit.obs["label"].astype(str) + "-" + fit.obs["cell_type"].astype(str)
128 | sc.pl.violin(fit, groupby="plot_label", keys="expr")
129 | ```
130 |
--------------------------------------------------------------------------------
/notebooks/scanpy_experiments.qmd:
--------------------------------------------------------------------------------
1 | ---
2 | title: "Play with scanpy"
3 | author: Constantin Ahlmann-Eltze
4 | format:
5 | html:
6 | code-fold: false
7 | embed-resources: true
8 | highlight-style: github
9 | toc: true
10 | code-line-numbers: true
11 | execute:
12 | keep-ipynb: true
13 | jupyter: python3
14 | ---
15 |
16 |
17 | ```{python}
18 | %load_ext autoreload
19 | %autoreload 2
20 | ```
21 |
22 |
23 | ```{python}
24 | import debugpy
25 | debugpy.listen(5678)
26 | print("Waiting for debugger attach")
27 | debugpy.wait_for_client()
28 | ```
29 |
30 | ```{python}
31 | import numpy as np
32 | import scanpy as sc
33 | import pertpy
34 |
35 | adata = pertpy.data.kang_2018()
36 | adata
37 | ```
38 |
39 | ```{python}
40 | adata.obs
41 | adata.var
42 | ```
43 |
44 | ```{python}
45 | np.unique(adata.X.sum(axis=0) > 5, axis = 1, return_counts=True)
46 | ```
47 |
48 | ```{python}
49 | adata.X[0:100,0:10]
50 | ```
51 |
52 |
53 | ```{python}
54 | adata.layers["counts"] = adata.X.copy()
55 | sf = np.array(adata.layers["counts"].sum(axis=1))
56 | sf = sf / np.median(sf)
57 | adata.layers["logcounts"] = sc.pp.log1p(adata.layers["counts"] / sf)
58 | ```
59 |
60 | ```{python}
61 | import scanpy.preprocessing._simple as simple
62 | var = simple._get_mean_var(adata.layers["logcounts"])[1]
63 | hvgs = var.argpartition(-1000)[-1000:]
64 | adata = adata[:, hvgs]
65 | ```
66 |
67 |
68 | ```{python}
69 | adata.X = adata.layers["logcounts"]
70 | sc.pp.pca(adata)
71 | sc.pp.neighbors(adata)
72 | sc.tl.umap(adata)
73 | sc.pl.umap(adata, color=["label", "cell_type"])
74 | ```
75 |
76 | ```{python}
77 | # Shuffle the rows of adata
78 | adata = adata[np.random.permutation(adata.obs.index), :]
79 | ```
80 |
81 | ```{python}
82 | import formulaic
83 | import pandas as pd
84 | df = pd.DataFrame(adata.obs)
85 |
86 | des = formulaic.model_matrix("~ label", df)
87 | form = formulaic.Formula("~ label")
88 | ```
89 |
90 | ```{python}
91 | import patsy
92 | des2 = patsy.dmatrix("~ label", df)
93 | des2
94 | ```
95 |
96 | ```{python}
97 | from pyLemur.lemur import lemur
98 |
99 | extra_data = {"test": np.random.randint(3, size=adata.shape[0]),
100 | "cat": ["ABC"[x] for x in np.random.randint(3, size=adata.shape[0])]}
101 | fit = lemur(adata, design = ["label"], obs_data=extra_data)
102 | ```
103 |
104 | ```{python}
105 | sc.pp.neighbors(fit, use_rep="embedding")
106 | sc.tl.umap(fit)
107 | sc.pl.umap(fit, color=["label", "cat", "cell_type"])
108 | ```
109 |
110 | ```{python}
111 | coord = fit_pca(adata.layers["logcounts"].toarray(), n = 15, center = False).coord_system
112 | coord2 = fit_pca(adata.layers["logcounts"].toarray(), n = 15, center = False).coord_system
113 | print(grassmann_angle_from_point(coord2.T, coord.T))
114 | ```
115 |
116 |
117 | ```{python}
118 | from pyLemur.lin_alg_wrappers import *
119 | from pyLemur.grassmann import *
120 |
121 | fit = lemur(adata, design = "~ label", linear_coefficient_estimator="zero")
122 | V_slice1 = fit.uns["lemur"]["coefficients"][:,:,0]
123 | V_slice2 = fit.uns["lemur"]["coefficients"][:,:,1]
124 |
125 | coord_all = fit_pca(adata.X.toarray(), n = 15, center = False).coord_system
126 | coord_ctrl = fit_pca(adata.X[adata.obs["label"] == "ctrl", :].toarray(), n = 15, center = False).coord_system
127 | coord_stim = fit_pca(adata.X[adata.obs["label"] == "stim", :].toarray(), n = 15, center = False).coord_system
128 |
129 | print(grassmann_angle_from_point(fit.uns["lemur"]["base_point"].T, coord_all.T))
130 | print(grassmann_angle_from_point(grassmann_map(V_slice1.T, fit.uns["lemur"]["base_point"].T), coord_ctrl.T))
131 | print(grassmann_angle_from_point(grassmann_map((V_slice1 + V_slice2).T, fit.uns["lemur"]["base_point"].T), coord_stim.T))
132 | ```
133 |
134 |
135 |
136 | ```{python}
137 | from pyLemur.predict import *
138 |
139 | fit.uns["lemur"]["design_matrix"]
140 |
141 | ```
142 |
143 | ```{python}
144 | pred_stim = predict(fit, new_condition = cond(fit, label = "stim", cat = "B"))
145 | pred_ctrl = predict(fit, new_condition = cond(fit, label = "ctrl", cat = "B"))
146 | delta = pred_stim - pred_ctrl
147 | ```
148 |
149 | ```{python}
150 | gene = 3
151 | fit.obs["delta"] = delta[:,gene]
152 | fit.obs["expr"] = fit.layers["logcounts"][:,gene].toarray()
153 |
154 | import matplotlib.pyplot as plt
155 | sc.pl.umap(fit, color="delta", cmap = plt.get_cmap("seismic"), vcenter=0)
156 |
157 | sc.pl.umap(fit[fit.obs["label"] == "ctrl"], color="expr", vmax = 2)
158 | sc.pl.umap(fit[fit.obs["label"] == "stim"], color="expr", vmax = 2)
159 | ```
160 |
161 | ```{python}
162 | ```
163 |
164 | ```{python}
165 | import harmonypy
166 | ho = harmonypy.run_harmony(fit.obsm["embedding"], fit.obs, "label")
167 | fit.obsm["harmony"] = ho.Z_corr.T
168 | ho.cl
169 | nei = sc.pp.neighbors(fit, use_rep="harmony")
170 | sc.tl.umap(fit)
171 | sc.pl.umap(fit, color=["label", "delta", "cell_type"])
172 | ```
173 |
174 |
175 | ```{python}
176 | from pyLemur.alignment import *
177 | align_with_harmony(fit)
178 | ```
179 |
180 | ```{python}
181 | ho = harmonypy.run_harmony(fit.obsm["embedding"], fit.obs, "label")
182 |
183 | from pyLemur.alignment import *
184 |
185 | ho = init_harmony(fit.obsm["embedding"], fit.uns["lemur"]["design_matrix"])
186 | ho.cluster()
187 | ```
188 |
189 | ```{python}
190 | A = np.hstack([np.ones((5, 2)), np.zeros((5, 6))])
191 | B = np.arange(8).reshape((8, 1))
192 | multiply_along_axis(A, B.T, axis = 1)
193 | ```
194 |
195 | ```{python}
196 | nei = sc.pp.neighbors(fit, use_rep="new_embedding")
197 | sc.tl.umap(fit)
198 | sc.pl.umap(fit, color=["label", "delta", "cell_type"])
199 | ```
200 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | build-backend = "hatchling.build"
3 | requires = ["hatchling"]
4 |
5 | [project]
6 | name = "pyLemur"
7 | version = "0.3.1"
8 | description = "A Python implementation of the LEMUR algorithm for analyzing multi-condition single-cell RNA-seq data."
9 | readme = "README.md"
10 | requires-python = ">=3.10"
11 | license = {file = "LICENSE"}
12 | authors = [
13 | {name = "Constantin Ahlmann-Eltze"},
14 | ]
15 | maintainers = [
16 | {name = "Constantin Ahlmann-Eltze", email = "artjom31415@googlemail.com"},
17 | ]
18 | urls.Documentation = "https://pyLemur.readthedocs.io/"
19 | urls.Source = "https://github.com/const-ae/pyLemur"
20 | urls.Home-page = "https://github.com/const-ae/pyLemur"
21 | dependencies = [
22 | "anndata",
23 | "formulaic",
24 | "formulaic_contrasts",
25 | "harmonypy",
26 | "pandas",
27 | "numpy",
28 | "scikit-learn",
29 | # for debug logging (referenced from the issue template)
30 | "session-info",
31 | ]
32 |
33 | [project.optional-dependencies]
34 | dev = [
35 | "pre-commit",
36 | "twine>=4.0.2",
37 | ]
38 | doc = [
39 | "docutils>=0.8,!=0.18.*,!=0.19.*",
40 | "sphinx>=4",
41 | "sphinx-book-theme>=1.0.0",
42 | "myst-nb>=1.1.0",
43 | "sphinxcontrib-bibtex>=1.0.0",
44 | "sphinx-autodoc-typehints",
45 | "sphinxext-opengraph",
46 | "sphinx-design",
47 | # For notebooks
48 | "ipykernel",
49 | "ipython",
50 | "sphinx-copybutton",
51 | "pandas",
52 | "scanpy",
53 | "pertpy",
54 | "matplotlib",
55 | "session-info",
56 | ]
57 | test = [
58 | "pytest",
59 | "coverage",
60 | ]
61 |
62 | [tool.coverage.run]
63 | source = ["pylemur"]
64 | omit = [
65 | "**/test_*.py",
66 | ]
67 |
68 | [tool.pytest.ini_options]
69 | testpaths = ["tests"]
70 | xfail_strict = true
71 | addopts = [
72 | "--import-mode=importlib", # allow using test files with same name
73 | ]
74 |
75 | [tool.ruff]
76 | line-length = 120
77 | src = ["src"]
78 | extend-include = ["*.ipynb"]
79 |
80 | [tool.ruff.format]
81 | docstring-code-format = true
82 |
83 | [tool.ruff.lint]
84 | select = [
85 | "F", # Errors detected by Pyflakes
86 | "E", # Error detected by Pycodestyle
87 | "W", # Warning detected by Pycodestyle
88 | "I", # isort
89 | "D", # pydocstyle
90 | "B", # flake8-bugbear
91 | "TID", # flake8-tidy-imports
92 | "C4", # flake8-comprehensions
93 | "BLE", # flake8-blind-except
94 | "UP", # pyupgrade
95 | "RUF100", # Report unused noqa directives
96 | ]
97 | ignore = [
98 | # line too long -> we accept long comment lines; formatter gets rid of long code lines
99 | "E501",
100 | # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient
101 | "E731",
102 | # allow I, O, l as variable names -> I is the identity matrix
103 | "E741",
104 | # Missing docstring in public package
105 | "D104",
106 | # Missing docstring in public module
107 | "D100",
108 | # Missing docstring in __init__
109 | "D107",
110 | # Errors from function calls in argument defaults. These are fine when the result is immutable.
111 | "B008",
112 | # __magic__ methods are often self-explanatory, allow missing docstrings
113 | "D105",
114 | # first line should end with a period [Bug: doesn't work with single-line docstrings]
115 | "D400",
116 | # First line should be in imperative mood; try rephrasing
117 | "D401",
118 | ## Disable one in each pair of mutually incompatible rules
119 | # We don’t want a blank line before a class docstring
120 | "D203",
121 | # We want docstrings to start immediately after the opening triple quote
122 | "D213",
123 | ]
124 |
125 | [tool.ruff.lint.pydocstyle]
126 | convention = "numpy"
127 |
128 | [tool.ruff.lint.per-file-ignores]
129 | "docs/*" = ["I"]
130 | "tests/*" = [
131 |     # Pydoc doesn't matter in tests
132 | "D",
133 | # Ignore star imports in tests
134 | "F403", "F405"
135 | ]
136 | "*/__init__.py" = ["F401"]
137 |
138 | [tool.cruft]
139 | skip = [
140 | "tests",
141 | "src/**/__init__.py",
142 | "src/**/basic.py",
143 | "docs/api.md",
144 | "docs/changelog.md",
145 | "docs/references.bib",
146 | "docs/references.md",
147 | "docs/notebooks/example.ipynb",
148 | ]
149 |
--------------------------------------------------------------------------------
/src/pylemur/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib.metadata import version
2 |
3 | from . import pl, pp, tl
4 |
5 | __all__ = ["pl", "pp", "tl"]
6 |
7 | __version__ = version("pyLemur")
8 |
--------------------------------------------------------------------------------
/src/pylemur/pl/__init__.py:
--------------------------------------------------------------------------------
1 | from .basic import BasicClass, basic_plot
2 |
--------------------------------------------------------------------------------
/src/pylemur/pl/basic.py:
--------------------------------------------------------------------------------
1 | from anndata import AnnData
2 |
3 |
def basic_plot(adata: AnnData) -> int:
    """Generate a basic plot for an AnnData object.

    This is a cookiecutter placeholder; no plotting is implemented yet.

    Parameters
    ----------
    adata
        The AnnData object to preprocess.

    Returns
    -------
    Some integer value.
    """
    message = "Import matplotlib and implement a plotting function here."
    print(message)
    return 0
18 |
19 |
class BasicClass:
    """A basic class.

    This is a cookiecutter placeholder class; the methods only print a
    reminder and return a constant.

    Parameters
    ----------
    adata
        The AnnData object to preprocess.
    """

    my_attribute: str = "Some attribute."
    my_other_attribute: int = 0

    def __init__(self, adata: AnnData):
        print("Implement a class here.")

    def my_method(self, param: int) -> int:
        """A basic method.

        Parameters
        ----------
        param
            A parameter.

        Returns
        -------
        Some integer value.
        """
        print("Implement a method here.")
        return 0

    def my_other_method(self, param: str) -> str:
        """Another basic method.

        Parameters
        ----------
        param
            A parameter.

        Returns
        -------
        Some string value.
        """
        # Fixed docstring: this method returns a string, not an integer.
        print("Implement a method here.")
        return ""
64 |
--------------------------------------------------------------------------------
/src/pylemur/pp/__init__.py:
--------------------------------------------------------------------------------
1 | from .basic import shifted_log_transform
2 |
--------------------------------------------------------------------------------
/src/pylemur/pp/basic.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy
3 |
4 |
def shifted_log_transform(counts, overdispersion=0.05, pseudo_count=None, minimum_overdispersion=0.001):
    r"""Apply a variance-stabilizing log transformation to count data.

    The result is proportional to :math:`\log(y/s + c)` where :math:`y` are the
    counts, :math:`s` the size factor, and :math:`c` the pseudo-count. Concretely,
    :math:`a \, \log(y/(sc) + 1)` is computed with :math:`a = \sqrt{4c}`: the
    :math:`+1` keeps zeros at zero (sparsity-preserving) and the scaling makes the
    transformation approximate the :math:`\operatorname{acosh}` transformation.
    Compared to :math:`\log(y/s+c)` this only differs by a constant offset, which
    does not affect the variance stabilization.

    Size factors are the row sums, normalized so their log-mean is zero::

        size_factors = counts.sum(axis=1)
        size_factors = size_factors / np.exp(np.mean(np.log(size_factors)))

    Parameters
    ----------
    counts
        The count matrix (cells x genes), dense or sparse.
    overdispersion,pseudo_count
        Two equivalent ways to specify the expected variation of a homogeneous
        sample, related by `overdispersion = 1 / (4 * pseudo_count)`. If
        `pseudo_count` is given it takes precedence.
    minimum_overdispersion
        Lower bound on the effective overdispersion.

    Returns
    -------
    A matrix of variance-stabilized values (sparse if `counts` was sparse).
    """
    if pseudo_count is None:
        pseudo_count = 1 / (4 * overdispersion)

    # Row sums normalized so that the geometric mean of the size factors is 1.
    size_factors = counts.sum(axis=1)
    size_factors = size_factors / np.exp(np.mean(np.log(size_factors)))
    normalized = counts / size_factors.reshape((counts.shape[0], 1))

    # Re-derive the overdispersion from the (possibly user-supplied) pseudo-count.
    effective_overdispersion = max(1 / (4 * pseudo_count), minimum_overdispersion)
    transformed = np.log1p(4 * effective_overdispersion * normalized) / np.sqrt(effective_overdispersion)
    if scipy.sparse.issparse(counts):
        transformed = scipy.sparse.csr_matrix(transformed)
    return transformed
52 |
--------------------------------------------------------------------------------
/src/pylemur/tl/__init__.py:
--------------------------------------------------------------------------------
1 | from .lemur import LEMUR
2 |
--------------------------------------------------------------------------------
/src/pylemur/tl/_design_matrix_utils.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Mapping
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | # import patsy
7 | from formulaic import model_matrix
8 | from numpy.lib import NumpyVersion
9 |
10 |
def handle_data(data, layer):
    """Extract the expression matrix from `.X` or a named layer as a dense ndarray."""
    matrix = data.X if layer is None else data.layers[layer]
    if isinstance(matrix, np.ndarray):
        return matrix
    # Sparse matrices (and similar array-likes) are densified.
    return matrix.toarray()
16 |
17 |
def handle_design_parameter(design, obs_data):
    """Normalize a user-provided design (array, DataFrame, list, or formula string).

    Returns a `(design_matrix, design_formula)` pair; the formula is `None`
    unless the design was specified symbolically.
    """
    if isinstance(design, np.ndarray):
        if design.ndim == 1:
            raise ValueError("design specified as a 1d array is not supported yet")
        if design.ndim != 2:
            raise ValueError("design must be a 2d array")
        return pd.DataFrame(design), None
    if isinstance(design, pd.DataFrame):
        # An explicit design matrix carries no formula.
        return design, None
    if isinstance(design, list):
        # A list of column names is treated as a full factorial design.
        return handle_design_parameter(" * ".join(design), obs_data)
    if isinstance(design, str):
        if design[0] != "~":
            # Formulas without a tilde get one-hot encoding (no intercept).
            design = "~" + design + " - 1"
        return convert_formula_to_design_matrix(design, obs_data)
    raise ValueError("design must be a 2d array or string")
42 |
43 |
def handle_obs_data(adata, obs_data):
    """Merge annotations from `adata.obs` with user-supplied `obs_data`."""
    primary = make_data_frame(adata.obs)
    extra = make_data_frame(obs_data, preferred_index=primary.index if primary is not None else None)
    if primary is None and extra is None:
        # No annotations anywhere: empty frame with one row per observation.
        return pd.DataFrame(index=pd.RangeIndex(0, adata.shape[0]))
    if primary is None:
        return extra
    if extra is None:
        return primary
    return pd.concat([primary, extra], axis=1)
55 |
56 |
def make_data_frame(data, preferred_index=None):
    """Coerce `data` (None, DataFrame, or Mapping) into a DataFrame.

    When `preferred_index` is given, an existing DataFrame must have a matching
    index (exactly or after string conversion); its index is then replaced
    in place with `preferred_index`.
    """
    if data is None:
        return None
    if isinstance(data, pd.DataFrame):
        if preferred_index is None:
            return data
        index_matches = preferred_index.equals(data.index) or preferred_index.equals(data.index.map(str))
        if not index_matches:
            raise ValueError("The index of adata.obs and obsData do not match")
        # NOTE(review): this mutates the caller's DataFrame index in place.
        data.index = preferred_index
        return data
    if isinstance(data, Mapping):
        return pd.DataFrame(data, index=preferred_index)
    raise ValueError("data must be None, a pandas DataFrame or a Mapping object")
72 |
73 |
def convert_formula_to_design_matrix(formula, obs_data):
    """Build a design matrix from a formulaic formula string.

    Returns the `(design_matrix, formula)` pair; raises ``ValueError`` if
    `formula` is not a string.
    """
    if not isinstance(formula, str):
        raise ValueError("formula must be a string")
    # formulaic takes care of categorical encoding and interaction terms.
    return model_matrix(formula, obs_data), formula
83 |
84 |
def row_groups(matrix, return_reduced_matrix=False, return_group_ids=False):
    """Assign a group index to every row, where identical rows share a group.

    Returns the per-row group index, optionally followed by the matrix of
    unique rows and/or the array of distinct group ids.
    """
    reduced_matrix, inv = np.unique(matrix, axis=0, return_inverse=True)
    # numpy >= 2.0 returns the inverse with an extra axis; flatten it.
    if NumpyVersion(np.__version__) >= "2.0.0rc":
        inv = np.squeeze(inv)
    extras = []
    if return_reduced_matrix:
        extras.append(reduced_matrix)
    if return_group_ids:
        extras.append(np.unique(inv))
    if extras:
        return (inv, *extras)
    return inv
98 |
--------------------------------------------------------------------------------
/src/pylemur/tl/_grassmann.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
def grassmann_map(x, base_point):
    """Exponential map on the Grassmann manifold: walk from `base_point` along tangent `x`."""
    if 0 in base_point.shape:
        # Degenerate (empty) manifold: nothing to map.
        return base_point
    if np.isnan(x).any():
        # Propagate NaN tangents as an all-NaN result of the same shape.
        return np.full(x.shape, np.nan)
    u, s, vt = np.linalg.svd(x, full_matrices=False)
    rotated_base = (base_point @ vt.T) @ np.diag(np.cos(s)) @ vt
    return rotated_base + u @ np.diag(np.sin(s)) @ vt
13 |
14 |
def grassmann_log(p, q):
    """Logarithmic map: tangent vector at `p` pointing towards `q` (inverse of the exp map)."""
    n, k = p.shape
    if n == 0 or k == 0:
        return p
    z = q.T @ p
    residual = q.T - z @ p.T
    # Least-squares solve; equivalent to R's `lm.fit(z, residual)$coefficients`.
    coef = np.linalg.lstsq(z, residual, rcond=None)[0]
    u, s, vt = np.linalg.svd(coef.T, full_matrices=True)
    return u[:, :k] @ np.diag(np.arctan(s[:k])) @ vt[:k, :]
31 |
32 |
def grassmann_project(x):
    """Orthonormal basis for the column space of `x` (the Q factor of its QR decomposition)."""
    q, _ = np.linalg.qr(x)
    return q
35 |
36 |
def grassmann_project_tangent(x, base_point):
    """Project `x` onto the tangent space at `base_point` by removing its in-plane component."""
    in_plane = base_point @ (base_point.T @ x)
    return x - in_plane
39 |
40 |
def grassmann_random_point(n, k):
    """Draw a random point on the Grassmann manifold by orthonormalizing a Gaussian matrix."""
    gaussian = np.random.randn(n, k)
    return grassmann_project(gaussian)
44 |
45 |
def grassmann_random_tangent(base_point):
    """Draw a random tangent vector at `base_point` by projecting a Gaussian matrix."""
    gaussian = np.random.randn(*base_point.shape)
    return grassmann_project_tangent(gaussian, base_point)
49 |
50 |
def grassmann_angle_from_tangent(x, normalized=True):
    """Largest principal angle (in degrees) encoded by the tangent vector `x`."""
    thetas = np.linalg.svd(x, full_matrices=True, compute_uv=False) / np.pi * 180
    if not normalized:
        # Raw leading angle (singular values are returned sorted descending).
        return thetas[0]
    # Subspace angles are only defined up to reflection; fold before taking the max.
    return np.minimum(thetas, 180 - thetas).max()
57 |
58 |
def grassmann_angle_from_point(x, y):
    """Angle (in degrees) between the subspaces spanned by `x` and `y`."""
    tangent = grassmann_log(y, x)
    return grassmann_angle_from_tangent(tangent)
61 |
--------------------------------------------------------------------------------
/src/pylemur/tl/_grassmann_lm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from pylemur.tl._design_matrix_utils import row_groups
4 | from pylemur.tl._grassmann import grassmann_log, grassmann_map
5 | from pylemur.tl._lin_alg_wrappers import fit_pca, ridge_regression
6 |
7 |
def grassmann_geodesic_regression(coord_systems, design, base_point, weights=None):
    """
    Fit geodesic regression on Grassmann manifolds.

    Solve Sum_j d(U_j, Exp_p(Sum_k V_k:: * X_jk)) for V, where
    d(U, V) = ||Log(U, V)|| is the inverse of the exponential map on the
    Grassmann manifold.

    Parameters
    ----------
    coord_systems : a list of orthonormal 2D matrices (length n_groups)
    design : design matrix, shape (n_groups, n_coef)
    base_point : array-like, shape (n_emb, n_features)
        The base point.
    weights : array-like, shape (n_groups,)

    Returns
    -------
    beta: array-like, shape (n_emb, n_features, n_coef)
    """
    n_obs, n_coef = design.shape
    n_emb, n_features = base_point.shape

    assert len(coord_systems) == n_obs
    for coord_system in coord_systems:
        assert coord_system.shape == (n_emb, n_features)
    if weights is None:
        weights = np.ones(n_obs)

    # Map each coordinate system into the tangent space at the base point
    # and flatten it to a single row vector.
    tangent_list = [
        grassmann_log(base_point.T, coord_system.T).T.reshape(n_emb * n_features)
        for coord_system in coord_systems
    ]
    tangent_vecs = np.vstack(tangent_list)
    if tangent_vecs.shape[0] == 0:
        tangent_fit = np.zeros((0, n_coef))
    else:
        tangent_fit = ridge_regression(tangent_vecs, design, weights=weights)

    # Unflatten: one (n_emb, n_features) slab per design coefficient.
    return tangent_fit.reshape((n_coef, n_emb, n_features)).transpose((1, 2, 0))
47 |
48 |
def grassmann_lm(Y, design_matrix, base_point):
    """
    Solve Sum_i||Y_i: - Y_i: Proj(Exp_p(Sum_k V_k:: * X_ik))||^2 for V.

    Parameters
    ----------
    Y : array-like, shape (n_samples, n_features)
        The input data matrix.
    design_matrix : array-like, shape (n_samples, n_coef)
        The design matrix.
    base_point : array-like, shape (n_emb, n_features)
        The base point.

    Returns
    -------
    beta: array-like, shape (n_emb, n_features, n_coef)
    """
    n_emb = base_point.shape[0]

    groups, reduced_design, group_ids = row_groups(
        design_matrix, return_reduced_matrix=True, return_group_ids=True
    )
    # Each group needs at least n_emb cells for its PCA to be well defined.
    _, counts = np.unique(groups, return_counts=True)
    if counts.min() < n_emb:
        raise ValueError("Too few dataset points in some design matrix group.")

    group_planes = []
    group_sizes = []
    for group_id in group_ids:
        mask = groups == group_id
        group_planes.append(fit_pca(Y[mask, :], n_emb, center=False).coord_system)
        group_sizes.append(np.sum(mask))

    # Regress the per-group subspaces on the reduced design, weighting by group size.
    return grassmann_geodesic_regression(group_planes, reduced_design, base_point, weights=group_sizes)
78 |
79 |
def project_diffemb_into_data_space(embedding, design_matrix, coefficients, base_point):
    """Map shared-embedding coordinates back into the feature space, per design group."""
    n_features = base_point.shape[1]
    result = np.zeros((design_matrix.shape[0], n_features))
    groups, reduced_design, group_ids = row_groups(
        design_matrix, return_reduced_matrix=True, return_group_ids=True
    )
    for group_id in group_ids:
        mask = groups == group_id
        # Condition-specific subspace for this set of identical design rows.
        subspace = grassmann_map(np.dot(coefficients, reduced_design[group_id, :]).T, base_point.T)
        result[mask, :] = embedding[mask, :] @ subspace.T
    return result
91 |
92 |
def project_data_on_diffemb(Y, design_matrix, coefficients, base_point):
    """Project data onto the condition-specific subspaces to get shared embedding coordinates."""
    n_emb = base_point.shape[0]
    result = np.zeros((Y.shape[0], n_emb))
    groups, reduced_design, group_ids = row_groups(
        design_matrix, return_reduced_matrix=True, return_group_ids=True
    )
    for group_id in group_ids:
        mask = groups == group_id
        # Condition-specific subspace for this set of identical design rows.
        subspace = grassmann_map(np.dot(coefficients, reduced_design[group_id, :]).T, base_point.T)
        result[mask, :] = Y[mask, :] @ subspace
    return result
106 |
--------------------------------------------------------------------------------
/src/pylemur/tl/_lin_alg_wrappers.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple
2 |
3 | import numpy as np
4 | import sklearn.decomposition as skd
5 |
6 |
class PCA(NamedTuple):
    """Result of a (truncated) PCA fit returned by `fit_pca`."""

    # Coordinates of each observation in the reduced space, shape (n_samples, n).
    embedding: np.ndarray
    # Principal axes (components), shape (n, n_features).
    coord_system: np.ndarray
    # Column means subtracted before the fit; zeros when the fit was uncentered.
    offset: np.ndarray
11 |
12 |
def fit_pca(Y, n, center=True):
    """
    Calculate the PCA of a given data matrix Y.

    Parameters
    ----------
    Y : array-like, shape (n_samples, n_features)
        The input data matrix.
    n : int
        The number of principal components to return.
    center : bool, default=True
        If True, the column means are subtracted before computing the
        decomposition; if False, an uncentered truncated SVD is used instead.

    Returns
    -------
    pca : PCA
        A `PCA` named tuple with the embedding (n_samples, n), the coordinate
        system (n, n_features), and the offset (the column means, or zeros
        when `center=False`).
    """
    # Fixed docstring: the function returns the local `PCA` NamedTuple,
    # not a sklearn.decomposition.PCA object.
    if center:
        pca = skd.PCA(n_components=n)
        emb = pca.fit_transform(Y)
        coord_system = pca.components_
        mean = pca.mean_
    else:
        # TruncatedSVD does not center the data, giving an uncentered PCA.
        svd = skd.TruncatedSVD(n_components=n, algorithm="arpack")
        emb = svd.fit_transform(Y)
        coord_system = svd.components_
        mean = np.zeros(Y.shape[1])
    return PCA(emb, coord_system, mean)
42 |
43 |
def ridge_regression(Y, X, ridge_penalty=0, weights=None):
    """
    Calculate the (weighted) ridge regression of Y on X.

    Parameters
    ----------
    Y : array-like, shape (n_samples, n_features)
        The response matrix.
    X : array-like, shape (n_samples, n_coef)
        The design matrix.
    ridge_penalty : float, 1d or 2d array-like, default=0
        The ridge penalty. A scalar applies the same penalty to every
        coefficient, a 1d array (length n_coef) penalizes each coefficient
        individually, and a 2d array of shape (n_coef, n_coef) is used as-is.
    weights : array-like, shape (n_samples,)
        Observation weights; defaults to equal weights.

    Returns
    -------
    ridge: array-like, shape (n_coef, n_features)
        The estimated coefficients.
    """
    n_coef = X.shape[1]
    n_samples = X.shape[0]
    n_feat = Y.shape[1]
    assert Y.shape[0] == n_samples
    if weights is None:
        weights = np.ones(n_samples)
    assert len(weights) == n_samples

    if np.ndim(ridge_penalty) == 0 or len(ridge_penalty) == 1:
        ridge_penalty = np.eye(n_coef) * ridge_penalty
    elif np.ndim(ridge_penalty) == 1:
        assert len(ridge_penalty) == n_coef
        ridge_penalty = np.diag(ridge_penalty)
    elif np.ndim(ridge_penalty) == 2:
        # BUG FIX: this branch previously re-checked `ndim == 1`, so valid 2d
        # penalty matrices fell through to the ValueError below.
        assert ridge_penalty.shape == (n_coef, n_coef)
    else:
        raise ValueError("ridge_penalty must be a scalar, 1d array, or 2d array")

    # Solve the ridge problem as an augmented least-squares system:
    # sqrt(weight)-scaled data rows plus penalty rows that map to zero.
    ridge_penalty_sq = np.sqrt(np.sum(weights)) * (ridge_penalty.T @ ridge_penalty)
    weights_sqrt = np.sqrt(weights)
    X_ext = np.vstack([multiply_along_axis(X, weights_sqrt, 0), ridge_penalty_sq])
    Y_ext = np.vstack([multiply_along_axis(Y, weights_sqrt, 0), np.zeros((n_coef, n_feat))])

    ridge = np.linalg.lstsq(X_ext, Y_ext, rcond=None)[0]
    return ridge


def multiply_along_axis(A, B, axis):
    """Multiply array `A` element-wise by vector `B`, broadcast along `axis`."""
    # Copied from https://stackoverflow.com/a/71750176/604854
    return np.swapaxes(np.swapaxes(A, axis, -1) * B, -1, axis)
94 |
--------------------------------------------------------------------------------
/src/pylemur/tl/alignment.py:
--------------------------------------------------------------------------------
1 | import harmonypy
2 | import numpy as np
3 |
4 | from pylemur.tl._design_matrix_utils import row_groups
5 | from pylemur.tl._lin_alg_wrappers import ridge_regression
6 |
7 |
def _align_impl(
    embedding,
    grouping,
    design_matrix,
    ridge_penalty=0.01,
    preserve_position_of_NAs=False,
    calculate_new_embedding=True,
    verbose=True,
):
    """Align the embedding across conditions using a cell grouping.

    Moves each cell towards the across-condition mean of its group and fits a
    ridge-regularized affine transformation (per design coefficient) that
    reproduces that shift. Returns the alignment coefficients, plus the new
    embedding if `calculate_new_embedding` is True.
    """
    if grouping.ndim == 1:
        # 1d grouping: one hard label per cell -> one-hot indicator matrix.
        # NOTE(review): np.isnan on the labels assumes a float-valued grouping;
        # NaN labels get an all-zero column so they contribute nothing.
        uniq_elem, fct_levels = np.unique(grouping, return_inverse=True)
        I = np.eye(len(uniq_elem))
        I[:, np.isnan(uniq_elem)] = 0
        grouping_matrix = I[:, fct_levels]
    else:
        # 2d grouping: soft assignments, shape (n_groups, n_cells).
        assert grouping.shape[1] == embedding.shape[0]
        assert np.all(grouping[~np.isnan(grouping)] >= 0)
        # Normalize each cell's assignment weights to sum to one
        # (all-zero columns are left untouched by dividing by 1).
        col_sums = grouping.sum(axis=0)
        col_sums[col_sums == 0] = 1
        grouping_matrix = grouping / col_sums

    n_groups = grouping_matrix.shape[0]
    n_emb = embedding.shape[1]
    K = design_matrix.shape[1]

    # NA's are converted to zero columns ensuring that `diff %*% grouping_matrix = 0`
    grouping_matrix[:, np.isnan(grouping_matrix.sum(axis=0))] = 0
    if not preserve_position_of_NAs:
        # Drop cells without any group assignment from the fit entirely.
        not_all_zero_column = grouping_matrix.sum(axis=0) != 0
        grouping_matrix = grouping_matrix[:, not_all_zero_column]
        embedding = embedding[not_all_zero_column, :]
        design_matrix = design_matrix[not_all_zero_column]

    des_row_groups, des_row_group_ids = row_groups(design_matrix, return_group_ids=True)
    n_conditions = des_row_group_ids.shape[0]
    # Per-condition, per-group mean position in the embedding (NaN if a group
    # is absent from a condition).
    cond_ct_means = [np.zeros((n_emb, n_groups)) for _ in des_row_group_ids]
    for id in des_row_groups_ids if False else des_row_group_ids:  # noqa: RUF100
        sel = des_row_groups == id
        for idx in range(n_groups):
            if grouping_matrix[idx, sel].sum() > 0:
                cond_ct_means[id][:, idx] = np.average(embedding[sel, :], axis=0, weights=grouping_matrix[idx, sel])
            else:
                cond_ct_means[id][:, idx] = np.nan

    # Alignment target per group: average of the condition means, ignoring
    # conditions where the group is missing.
    target = np.zeros((n_emb, n_groups)) * np.nan
    for idx in range(n_groups):
        tmp = np.zeros((n_conditions, n_emb))
        for id in des_row_group_ids:
            tmp[id, :] = cond_ct_means[id][:, idx]
        if np.all(np.isnan(tmp)):
            target[:, idx] = np.nan
        else:
            target[:, idx] = np.average(tmp[:, ~np.isnan(tmp.sum(axis=0))], axis=0)

    # Shift every cell towards its group target, weighted by group membership.
    new_pos = embedding.copy()
    for id in des_row_group_ids:
        sel = des_row_groups == id
        diff = target - cond_ct_means[id]
        # NA's are converted to zero so that they don't propagate.
        diff[np.isnan(diff)] = 0
        new_pos[sel, :] = new_pos[sel, :] + (diff @ grouping_matrix[:, sel]).T

    # Fit the shift as an affine function of the embedding (intercept + linear
    # term) interacted with each design covariate.
    intercept_emb = np.hstack([np.ones((embedding.shape[0], 1)), embedding])
    interact_design_matrix = np.repeat(design_matrix, n_emb + 1, axis=1) * np.hstack([intercept_emb] * K)
    alignment_coefs = ridge_regression(new_pos - embedding, interact_design_matrix, ridge_penalty)
    ## The alignment error is weird, as it doesn't necessarily go down. Better not to show it
    # if verbose:
    #     print(f"Alignment error: {np.linalg.norm((new_pos - embedding) - interact_design_matrix @ alignment_coefs)}")
    # Reshape to (n_emb, n_emb + 1, K): output dim x (intercept + input dim) x covariate.
    alignment_coefs = alignment_coefs.reshape((K, n_emb + 1, n_emb)).transpose((2, 1, 0))
    if calculate_new_embedding:
        new_embedding = _apply_linear_transformation(embedding, alignment_coefs, design_matrix)
        return alignment_coefs, new_embedding
    else:
        return alignment_coefs
82 |
83 |
def _apply_linear_transformation(embedding, alignment_coefs, design_matrix):
    """Apply the condition-specific affine alignment to each cell's embedding."""
    groups, reduced_design, group_ids = row_groups(
        design_matrix, return_reduced_matrix=True, return_group_ids=True
    )
    result = embedding.copy()
    for group_id in group_ids:
        mask = groups == group_id
        transform = _forward_linear_transformation(alignment_coefs, reduced_design[group_id, :])
        # Prepend an intercept column so the affine offset is applied too.
        with_intercept = np.hstack([np.ones((np.sum(mask), 1)), result[mask, :]])
        result[mask, :] = with_intercept @ transform.T
    return result
96 |
97 |
98 | def _forward_linear_transformation(alignment_coef, design_vector):
99 | n_emb = alignment_coef.shape[0]
100 | if n_emb == 0:
101 | return np.zeros((0, 0))
102 | else:
103 | return np.hstack([np.zeros((n_emb, 1)), np.eye(n_emb)]) + np.dot(alignment_coef, design_vector)
104 |
105 |
106 | def _reverse_linear_transformation(alignment_coef, design_vector):
107 | n_emb = alignment_coef.shape[0]
108 | if n_emb == 0:
109 | return np.zeros((0, 0))
110 | else:
111 | return np.linalg.inv(np.eye(n_emb) + np.dot(alignment_coef[:, 1:, :], design_vector))
112 |
113 |
def _init_harmony(
    embedding,
    design_matrix,
    theta=2,
    lamb=1,
    sigma=0.1,
    nclust=None,
    tau=0,
    block_size=0.05,
    max_iter_kmeans=20,
    epsilon_cluster=1e-5,
    epsilon_harmony=1e-4,
    verbose=True,
):
    """Build a `harmonypy.Harmony` object for the embedding without running correction.

    The design matrix rows are collapsed into discrete batches (one per unique
    row) and `max_iter_harmony` is fixed to 0, so the constructor only performs
    the initial clustering; the caller drives any further steps.
    """
    n_obs = embedding.shape[0]
    des_row_groups, des_row_group_ids = row_groups(design_matrix, return_group_ids=True)
    n_groups = len(des_row_group_ids)
    if nclust is None:
        # Heuristic default: roughly one cluster per 30 cells, capped at 100.
        nclust = np.min([np.round(n_obs / 30.0), 100]).astype(int)

    # One-hot batch membership matrix, shape (n_groups, n_obs).
    phi = np.eye(n_groups)[:, des_row_groups]
    # phi_n = np.ones(n_groups)

    # Batch sizes and batch proportions.
    N_b = phi.sum(axis=1)
    Pr_b = N_b / n_obs
    sigma = np.repeat(sigma, nclust)

    theta = np.repeat(theta, n_groups)
    if tau != 0:
        # Discounting of theta for small batches (harmonypy's tau mechanism).
        theta = theta * (1 - np.exp(-((N_b / (nclust * tau)) ** 2)))

    lamb = np.repeat(lamb, n_groups)
    # Ridge penalty matrix with an unpenalized intercept row.
    lamb_mat = np.diag(np.insert(lamb, 0, 0))
    # Batch matrix with an intercept row prepended, as harmonypy expects.
    phi_moe = np.vstack((np.repeat(1, n_obs), phi))

    max_iter_harmony = 0
    # NOTE(review): harmonypy.Harmony takes these arguments positionally;
    # the order below must match the installed harmonypy version's signature.
    ho = harmonypy.Harmony(
        embedding.T,
        phi,
        phi_moe,
        Pr_b,
        sigma,
        theta,
        max_iter_harmony,
        max_iter_kmeans,
        epsilon_cluster,
        epsilon_harmony,
        nclust,
        block_size,
        lamb_mat,
        verbose,
    )
    return ho
167 |
--------------------------------------------------------------------------------
/src/pylemur/tl/lemur.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from collections.abc import Iterable, Mapping
3 | from typing import Any, Literal
4 | import copy
5 |
6 | import anndata as ad
7 | import formulaic_contrasts
8 | import numpy as np
9 | import pandas as pd
10 | from sklearn.exceptions import NotFittedError
11 |
12 | from pylemur.tl._design_matrix_utils import handle_data, handle_design_parameter, handle_obs_data, row_groups
13 | from pylemur.tl._grassmann import grassmann_map
14 | from pylemur.tl._grassmann_lm import grassmann_lm, project_data_on_diffemb
15 | from pylemur.tl._lin_alg_wrappers import fit_pca, multiply_along_axis, ridge_regression
16 | from pylemur.tl.alignment import (
17 | _align_impl,
18 | _apply_linear_transformation,
19 | _init_harmony,
20 | _reverse_linear_transformation,
21 | )
22 |
23 |
24 | class LEMUR:
25 | r"""Fit the LEMUR model
26 |
27 | A python implementation of the LEMUR algorithm. For more details please refer
28 | to Ahlmann-Eltze (2024).
29 |
30 | Parameters
31 | ----------
32 | data
33 | The AnnData object (or a different matrix container) with the variance stabilized data and the
34 | cell-wise annotations in `data.obs`.
35 | design
36 | A specification of the experimental design. This can be a string,
37 | which is then parsed using `formulaic`. Alternatively, it can be a
38 | a list of strings, which are assumed to refer to the columns in
39 | `data.obs`. Finally, it can be a numpy array, representing a
40 | design matrix of size `n_cells` x `n_covariates`. If not provided,
41 | a constant design is used.
42 | obs_data
43 | A pandas DataFrame or a dictionary of iterables containing the
44 | cell-wise annotations. It is used in combination with the
45 | information in `data.obs`.
46 | n_embedding
47 | The number of dimensions to use for the shared embedding space.
48 | linear_coefficient_estimator
49 | The method to use for estimating the linear coefficients. If `"linear"`,
50 | the linear coefficients are estimated using ridge regression. If `"zero"`,
51 | the linear coefficients are set to zero.
52 | layer
53 | The name of the layer to use in `data`. If `None`, the `X` slot is used.
54 | copy
55 | Whether to make a copy of `data`.
56 |
57 | Attributes
58 | ----------
59 | embedding : :class:`~numpy.ndarray` (:math:`C \times P`)
60 | Low-dimensional representation of each cell
61 | adata : :class:`~anndata.AnnData`
62 | A reference to (potentially a copy of) the input data.
63 | data_matrix : :class:`~numpy.ndarray` (:math:`C \times G`)
64 | A reference to the data matrix from the `adata` object.
65 | n_embedding : int
66 | The number of latent dimensions
67 | design_matrix : `ModelMatrix` (:math:`C \times K`)
68 | The design matrix that is used for the fit.
69 | formula : str
70 | The design formula specification.
71 | coefficients : :class:`~numpy.ndarray` (:math:`P \times G \times K`)
72 | The 3D array of coefficients for the Grassmann regression.
73 | alignment_coefficients : :class:`~numpy.ndarray` (:math:`P \times (P+1) \times K`)
74 | The 3D array of coefficients for the affine alignment.
75 | linear_coefficients : :class:`~numpy.ndarray` (:math:`K\times G`)
76 | The 2D array of coefficients for the linear offset per condition.
77 | linear_coefficient_estimator : str
78 | The linear coefficient estimation specification.
79 | base_point : :class:`~numpy.ndarray` (:math:`(P \times G`))
80 | The 2D array representing the reference subspace.
81 |
82 | Examples
83 | --------
84 | >>> model = pylemur.tl.LEMUR(adata, design="~ label + batch_cov", n_embedding=15)
85 | >>> model.fit()
86 | >>> model.align_with_harmony()
87 | >>> pred_expr = model.predict(new_condition=model.cond(label="treated"))
88 | >>> emb_proj = model_small.transform(adata)
89 | """
90 |
    def __init__(
        self,
        adata: ad.AnnData | Any,
        design: str | list[str] | np.ndarray = "~ 1",
        obs_data: pd.DataFrame | Mapping[str, Iterable[Any]] | None = None,
        n_embedding: int = 15,
        linear_coefficient_estimator: Literal["linear", "zero"] = "linear",
        layer: str | None = None,
        copy: bool = True,
    ):
        # Coerce non-AnnData containers into AnnData; optionally work on a copy
        # so the caller's object is not modified.
        adata = _handle_data_arg(adata)
        if copy:
            adata = adata.copy()
        self.adata = adata

        # Merge user-supplied annotations into adata.obs, then build the design.
        adata.obs = handle_obs_data(adata, obs_data)
        design_matrix, formula = handle_design_parameter(design, adata.obs)

        if formula:
            # With a formula we can build contrasts later (via `self.cond(...)`);
            # sanity-check that both code paths produce the same design matrix.
            self.contrast_builder = formulaic_contrasts.FormulaicContrasts(adata.obs, formula)
            assert design_matrix.equals(self.contrast_builder.design_matrix)
        else:
            self.contrast_builder = None

        self.design_matrix = design_matrix
        self.formula = formula
        if design_matrix.shape[0] != adata.shape[0]:
            raise ValueError("number of rows in design matrix must be equal to number of samples in data")
        self.data_matrix = handle_data(adata, layer)
        self.linear_coefficient_estimator = linear_coefficient_estimator
        self.n_embedding = n_embedding
        # Fit results; populated by `fit()` and the alignment methods.
        self.embedding = None
        self.coefficients = None
        self.linear_coefficients = None
        self.base_point = None
        self.alignment_coefficients = None
127 |
128 | def fit(self, verbose: bool = True):
129 | """Fit the LEMUR model
130 |
131 | Parameters
132 | ----------
133 | verbose
134 | Whether to print progress to the console.
135 |
136 | Returns
137 | -------
138 | `self`
139 | The fitted LEMUR model.
140 | """
141 | Y = self.data_matrix
142 | design_matrix = self.design_matrix
143 | n_embedding = self.n_embedding
144 |
145 | if self.linear_coefficient_estimator == "linear":
146 | if verbose:
147 | print("Centering the data using linear regression.")
148 | lin_coef = ridge_regression(Y, design_matrix.to_numpy())
149 | Y = Y - design_matrix.to_numpy() @ lin_coef
150 | else: # linear_coefficient_estimator == "zero"
151 | lin_coef = np.zeros((design_matrix.shape[1], Y.shape[1]))
152 |
153 | if verbose:
154 | print("Find base point")
155 | base_point = fit_pca(Y, n_embedding, center=False).coord_system
156 | if verbose:
157 | print("Fit regression on latent spaces")
158 | coefficients = grassmann_lm(Y, design_matrix.to_numpy(), base_point)
159 | if verbose:
160 | print("Find shared embedding coordinates")
161 | embedding = project_data_on_diffemb(Y, design_matrix.to_numpy(), coefficients, base_point)
162 |
163 | embedding, coefficients, base_point = _order_axis_by_variance(embedding, coefficients, base_point)
164 |
165 | self.embedding = embedding
166 | self.alignment_coefficients = np.zeros((n_embedding, n_embedding + 1, design_matrix.shape[1]))
167 | self.coefficients = coefficients
168 | self.base_point = base_point
169 | self.linear_coefficients = lin_coef
170 |
171 | return self
172 |
    def align_with_harmony(
        self, ridge_penalty: float | list[float] | np.ndarray = 0.01, max_iter: int = 10, verbose: bool = True
    ):
        """Fine-tune the embedding with a parametric version of Harmony.

        Parameters
        ----------
        ridge_penalty
            The penalty controlling the flexibility of the alignment. Smaller
            values mean more flexible alignments.
        max_iter
            The maximum number of iterations to perform.
        verbose
            Whether to print progress to the console.


        Returns
        -------
        `self`
            The fitted LEMUR model with the updated embedding space stored in
            the `model.embedding` attribute and the updated alignment coefficients
            stored in `model.alignment_coefficients`.
        """
        if self.embedding is None:
            raise NotFittedError(
                "self.embedding is None. Make sure to call 'model.fit()' "
                + "before calling 'model.align_with_harmony()'."
            )

        embedding = self.embedding.copy()
        design_matrix = self.design_matrix
        # Init harmony
        harm_obj = _init_harmony(embedding, design_matrix, verbose=verbose)
        # NOTE(review): if `max_iter < 1` the loop body never runs and
        # `al_coef` below is undefined (NameError).
        for idx in range(max_iter):
            if verbose:
                print(f"Alignment iteration {idx}")
            # Update harmony
            harm_obj.cluster()
            # Mirrors the R implementation:
            # alignment <- align_impl(training_fit$embedding, harm_obj$R, act_design_matrix, ridge_penalty = ridge_penalty)
            al_coef, new_emb = _align_impl(
                embedding,
                harm_obj.R,
                design_matrix,
                ridge_penalty=ridge_penalty,
                calculate_new_embedding=True,
                verbose=verbose,
            )
            # Feed the corrected embedding back into harmony. Z_corr/Z_cos are
            # stored transposed; Z_cos is additionally row-normalized before
            # the transpose.
            harm_obj.Z_corr = new_emb.T
            harm_obj.Z_cos = multiply_along_axis(
                new_emb, 1 / np.linalg.norm(new_emb, axis=1).reshape((new_emb.shape[0], 1)), axis=1
            ).T

            if harm_obj.check_convergence(1):
                if verbose:
                    print("Converged")
                break
        # The final coefficients are applied to the *original* embedding copy.
        self.alignment_coefficients = al_coef
        self.embedding = _apply_linear_transformation(embedding, al_coef, design_matrix)
        return self
232 |
    def align_with_grouping(
        self,
        grouping: list | np.ndarray | pd.Series,
        ridge_penalty: float | list[float] | np.ndarray = 0.01,
        preserve_position_of_NAs: bool = False,
        verbose: bool = True,
    ):
        """Fine-tune the embedding using annotated groups of cells.

        Parameters
        ----------
        grouping
            A list, :class:`~numpy.ndarray`, or pandas :class:`pandas.Series` specifying the group of cells.
            The groups span different conditions and can for example be cell types.
        ridge_penalty
            The penalty controlling the flexibility of the alignment.
        preserve_position_of_NAs
            `True` means that `NA`'s in the `grouping` indicate that these cells should stay
            where they are (if possible). `False` means that they are free to move around.
        verbose
            Whether to print progress to the console.

        Returns
        -------
        `self`
            The fitted LEMUR model with the updated embedding space stored in
            the `model.embedding` attribute and the updated alignment coefficients
            stored in `model.alignment_coefficients`.
        """
        # NOTE(review): `preserve_position_of_NAs` is accepted and documented
        # but never used in this body — confirm whether it should be forwarded
        # to `_align_impl`.
        if self.embedding is None:
            raise NotFittedError(
                "self.embedding is None. Make sure to call 'model.fit()' "
                + "before calling 'model.align_with_grouping()'."
            )
        embedding = self.embedding.copy()
        design_matrix = self.design_matrix
        # Normalize the grouping to a float vector of group codes.
        if isinstance(grouping, list):
            grouping = pd.Series(grouping)

        if isinstance(grouping, pd.Series):
            # `factorize` encodes missing values as -1; convert those to NaN.
            grouping = grouping.factorize()[0] * 1.0
            grouping[grouping == -1] = np.nan

        al_coef = _align_impl(
            embedding,
            grouping,
            design_matrix,
            ridge_penalty=ridge_penalty,
            calculate_new_embedding=False,
            verbose=verbose,
        )
        self.alignment_coefficients = al_coef
        self.embedding = _apply_linear_transformation(embedding, al_coef, design_matrix)
        return self
287 |
288 | def transform(
289 | self,
290 | adata: ad.AnnData,
291 | layer: str | None = None,
292 | obs_data: pd.DataFrame | Mapping[str, Iterable[Any]] | None = None,
293 | return_type: Literal["embedding", "LEMUR"] = "embedding",
294 | ):
295 | """Transform data using the fitted LEMUR model
296 |
297 | Parameters
298 | ----------
299 | adata
300 | The AnnData object to transform.
301 | obs_data
302 | Optional set of annotations for each cell (same as `obs_data` in the
303 | constructor).
304 | return_type
305 | Flag that decides if the function returns a full `LEMUR` object or
306 | only the embedding.
307 |
308 | Returns
309 | -------
310 | :class:`~pylemur.tl.LEMUR`
311 | (if `return_type = "LEMUR"`) A new `LEMUR` object object with the embedding
312 | calculated for the input `adata`.
313 |
314 | :class:`~numpy.ndarray`
315 | (if `return_type = "embedding"`) A 2D numpy array of the embedding matrix
316 | calculated for the input `adata` (with cells in the rows and latent dimensions
317 | in the columns).
318 | """
319 | Y = handle_data(adata, layer)
320 | adata.obs = handle_obs_data(adata, obs_data)
321 | design_matrix, _ = handle_design_parameter(self.formula, adata.obs)
322 | dm = design_matrix.to_numpy()
323 | Y_clean = Y - dm @ self.linear_coefficients
324 | embedding = project_data_on_diffemb(
325 | Y_clean, design_matrix=dm, coefficients=self.coefficients, base_point=self.base_point
326 | )
327 | embedding = _apply_linear_transformation(embedding, self.alignment_coefficients, dm)
328 | if return_type == "embedding":
329 | return embedding
330 | elif return_type == "LEMUR":
331 | fit = LEMUR.copy()
332 | fit.adata = adata
333 | fit.design_matrix = design_matrix
334 | fit.embedding = embedding
335 | return fit
336 |
337 | def predict(
338 | self,
339 | embedding: np.ndarray | None = None,
340 | new_design: str | list[str] | np.ndarray | None = None,
341 | new_condition: np.ndarray | pd.DataFrame | None = None,
342 | obs_data: pd.DataFrame | Mapping[str, Iterable[Any]] | None = None,
343 | new_adata_layer: None | str = None,
344 | ):
345 | """Predict the expression of cells in a specific condition
346 |
347 | Parameters
348 | ----------
349 | embedding
350 | The coordinates of the cells in the shared embedding space. If None,
351 | the coordinates stored in `model.embedding` are used.
352 | new_design
353 | Either a design formula parsed using `model.adata.obs` and `obs_data` or
354 | a design matrix defining the condition for each cell. If both `new_design`
355 | and `new_condition` are None, the original design matrix
356 | (`model.design_matrix`) is used.
357 | new_condition
358 | A specification of the new condition that is applied to all cells. Typically,
359 | this is generated by `cond(...)`.
360 | obs_data
361 | A DataFrame-like object containing cell-wise annotations. It is only used if `new_design`
362 | contains a formulaic formula string.
363 | new_adata_layer
364 | If not `None`, the function returns `self` and stores the prediction in
365 | `model.adata["new_adata_layer"]`.
366 |
367 | Returns
368 | -------
369 | array-like, shape (n_cells, n_genes)
370 | The predicted expression of the cells in the new condition.
371 | """
372 | if embedding is None:
373 | if self.embedding is None:
374 | raise NotFittedError("The model has not been fitted yet.")
375 | embedding = self.embedding
376 |
377 | if new_condition is not None:
378 | if new_design is not None:
379 | warnings.warn("new_design is ignored if new_condition is provided.", stacklevel=1)
380 |
381 | if isinstance(new_condition, pd.DataFrame):
382 | new_design = new_condition.to_numpy()
383 | elif isinstance(new_condition, pd.Series):
384 | new_design = np.expand_dims(new_condition.to_numpy(), axis=0)
385 | elif isinstance(new_condition, np.ndarray):
386 | new_design = new_condition
387 | else:
388 | raise ValueError("new_condition must be a created using 'cond(...)' or a numpy array.")
389 | if new_design.shape[0] != 1:
390 | raise ValueError("new_condition must only have one row")
391 | # Repeat values row-wise
392 | new_design = np.ones((embedding.shape[0], 1)) @ new_design
393 | elif new_design is None:
394 | new_design = self.design_matrix.to_numpy()
395 | else:
396 | new_design = handle_design_parameter(new_design, handle_obs_data(self.adata, obs_data))[0].to_numpy()
397 |
398 | # Make prediciton
399 | approx = new_design @ self.linear_coefficients
400 |
401 | coef = self.coefficients
402 | al_coefs = self.alignment_coefficients
403 | des_row_groups, reduced_design_matrix, des_row_group_ids = row_groups(
404 | new_design, return_reduced_matrix=True, return_group_ids=True
405 | )
406 | for id in des_row_group_ids:
407 | covars = reduced_design_matrix[id, :]
408 | subspace = grassmann_map(np.dot(coef, covars).T, self.base_point.T)
409 | alignment = _reverse_linear_transformation(al_coefs, covars)
410 | offset = np.dot(al_coefs[:, 0, :], covars)
411 | approx[des_row_groups == id, :] += (
412 | (embedding[des_row_groups == id, :] - offset) @ alignment.T
413 | ) @ subspace.T
414 | if new_adata_layer is not None:
415 | self.adata.layers[new_adata_layer] = approx
416 | return self
417 | else:
418 | return approx
419 |
    def cond(self, **kwargs):
        """Define a condition for the `predict` function.

        Parameters
        ----------
        kwargs
            Named arguments specifying the levels of the covariates from the
            design formula. If a covariate is not specified, the first level is
            used.

        Returns
        -------
        `pd.Series`
            A contrast vector that aligns to the columns of the design matrix.

        Raises
        ------
        ValueError
            If the model was constructed from an explicit design matrix instead
            of a formula (no contrast builder available).

        Notes
        -----
        Subtracting two `cond(...)` calls produces a contrast vector; these are
        commonly used in `R` to test for differences in a regression model.
        This pattern is inspired by the `R` package *glmGamPoi*.
        """
        if self.contrast_builder:
            return self.contrast_builder.cond(**kwargs)
        else:
            raise ValueError("The design was not specified as a formula. Cannot automatically construct contrast.")
446 |
447 | def copy(self, copy_adata = True):
448 | cp = copy.deepcopy(self)
449 | if copy_adata:
450 | cp.adata = self.adata.copy()
451 | return cp
452 |
453 | def __deepcopy__(self, memo):
454 | cp = LEMUR(self.adata, copy=False)
455 | cp.contrast_builder = self.contrast_builder
456 | cp.design_matrix = self.design_matrix
457 | cp.formula = self.formula
458 | cp.data_matrix = self.data_matrix
459 | cp.linear_coefficient_estimator = self.linear_coefficient_estimator
460 | cp.n_embedding = self.n_embedding
461 | cp.embedding = self.embedding
462 | cp.coefficients = self.coefficients
463 | cp.linear_coefficients = self.linear_coefficients
464 | cp.base_point = self.base_point
465 | cp.alignment_coefficients = self.alignment_coefficients
466 | return cp
467 |
468 |
469 | def __eq__(self, other):
470 | """Equality testing"""
471 | raise NotImplementedError(
472 | "Equality comparisons are not supported for AnnData objects, "
473 | "instead compare the desired attributes."
474 | )
475 |
476 | def __str__(self):
477 | if self.embedding is None:
478 | return f"LEMUR model (not fitted yet) with {self.n_embedding} dimensions"
479 | else:
480 | return f"LEMUR model with {self.n_embedding} dimensions"
481 |
482 |
def _handle_data_arg(data):
    """Coerce `data` to an AnnData object; raw arrays get wrapped."""
    return data if isinstance(data, ad.AnnData) else ad.AnnData(data)
488 |
489 |
490 | def _order_axis_by_variance(embedding, coefficients, base_point):
491 | U, d, Vt = np.linalg.svd(embedding, full_matrices=False)
492 | base_point = Vt @ base_point
493 | coefficients = np.einsum("pq,qij->pij", Vt, coefficients)
494 | embedding = U @ np.diag(d)
495 | return embedding, coefficients, base_point
496 |
--------------------------------------------------------------------------------
/tests/test_grasmann_lm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from pylemur.tl._grassmann import grassmann_angle_from_point, grassmann_map, grassmann_project
4 | from pylemur.tl._grassmann_lm import (
5 | grassmann_geodesic_regression,
6 | grassmann_lm,
7 | project_data_on_diffemb,
8 | project_diffemb_into_data_space,
9 | )
10 | from pylemur.tl._lin_alg_wrappers import fit_pca
11 |
12 |
def test_geodesic_regression():
    """Regression coefficients must map back onto orthonormal subspaces."""
    n_features = 17
    n_points = 10
    ref_plane = grassmann_project(np.random.randn(n_features, 3)).T
    assert np.allclose(ref_plane @ ref_plane.T, np.eye(3))
    planes = [grassmann_project(np.random.randn(n_features, 3)).T for _ in range(n_points)]
    covariate = np.arange(n_points)
    design = np.vstack([np.ones(n_points), covariate]).T

    coefs = grassmann_geodesic_regression(planes, design, ref_plane)
    assert coefs.shape == (3, n_features, 2)
    # Intercept and slope coefficients both define valid points on the manifold.
    for k in range(2):
        mapped = grassmann_map(coefs[:, :, k].T, ref_plane.T)
        assert np.allclose(mapped.T @ mapped, np.eye(3))
27 |
28 |
def test_grassmann_lm():
    """grassmann_lm recovers the (per-group) PCA planes of the data."""
    n_cells = 100
    ref_plane = grassmann_project(np.random.randn(5, 2)).T
    Y = np.random.randn(n_cells, 5)
    plane_all = fit_pca(Y, 2, center=False).coord_system
    intercept_only = np.ones((n_cells, 1))
    coefs = grassmann_lm(Y, intercept_only, ref_plane)
    assert np.allclose(grassmann_angle_from_point(grassmann_map(coefs[:, :, 0].T, ref_plane.T), plane_all.T), 0)

    # Make a design matrix of three groups (with an intercept)
    group = np.random.randint(3, size=n_cells)
    onehot = np.eye(3)[group, :]
    design = np.hstack([np.ones((n_cells, 1)), onehot[:, 1:3]])
    coefs = grassmann_lm(Y, design, ref_plane)

    group_planes = [fit_pca(Y[group == g], 2, center=False).coord_system for g in range(3)]

    # Intercept matches group 0; intercept + group coefficient matches groups 1/2.
    assert np.allclose(
        grassmann_angle_from_point(grassmann_map(coefs[:, :, 0].T, ref_plane.T), group_planes[0].T), 0
    )
    assert np.allclose(
        grassmann_angle_from_point(grassmann_map((coefs[:, :, 0] + coefs[:, :, 1]).T, ref_plane.T), group_planes[1].T), 0
    )
    assert np.allclose(
        grassmann_angle_from_point(grassmann_map((coefs[:, :, 0] + coefs[:, :, 2]).T, ref_plane.T), group_planes[2].T), 0
    )
55 |
56 |
def test_project_data_on_diffemb():
    """The projected embedding equals the PCA embedding up to a rotation."""
    n_cells = 100
    ref_plane = grassmann_project(np.random.randn(5, 2)).T
    Y = np.random.randn(n_cells, 5)
    design = np.ones((n_cells, 1))
    coefs = grassmann_lm(Y, design, ref_plane)
    pca = fit_pca(Y, 2, center=False)
    assert np.allclose(
        grassmann_angle_from_point(grassmann_map(coefs[:, :, 0].T, ref_plane.T), pca.coord_system.T), 0
    )

    emb = project_data_on_diffemb(Y, design, coefs, ref_plane)
    # The projection and the PCA embedding are rotated relative to each other;
    # remove the rotation via orthogonal Procrustes before comparing.
    u_mat, _, vt_mat = np.linalg.svd(emb.T @ pca.embedding, full_matrices=False)
    rotation = u_mat @ vt_mat
    assert np.allclose(emb @ rotation, pca.embedding)
73 |
74 |
def test_project_data_on_diffemb2():
    """Projecting into the embedding and back matches the PCA reconstruction."""
    n_cells = 100
    ref_plane = grassmann_project(np.random.randn(5, 2)).T
    Y = np.random.randn(n_cells, 5)
    design = np.ones((n_cells, 1))
    coefs = grassmann_lm(Y, design, ref_plane)
    pca = fit_pca(Y, 2, center=False)
    assert np.allclose(
        grassmann_angle_from_point(grassmann_map(coefs[:, :, 0].T, ref_plane.T), pca.coord_system.T), 0
    )

    emb = project_data_on_diffemb(Y, design, coefs, ref_plane)
    reconstruction_pca = pca.embedding @ pca.coord_system
    reconstruction_lemur = project_diffemb_into_data_space(emb, design, coefs, ref_plane)
    assert np.allclose(reconstruction_pca, reconstruction_lemur)
89 |
--------------------------------------------------------------------------------
/tests/test_grassmann.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from pylemur.tl._grassmann import *
4 |
5 |
def test_grassmann_map():
    """Exercise the exponential/log maps and tangent helpers of the Grassmann manifold."""
    # Test case 1: empty base point
    x = np.array([[1, 2], [3, 4], [5, 6]])
    p = np.array([])
    assert np.array_equal(grassmann_map(x, p), p)

    # Test case 2: x contains NaN values
    x = np.array([[1, 2], [3, np.nan], [5, 6]])
    p = np.array([[1, 0], [0, 1], [0, 0]])
    assert np.isnan(grassmann_map(x, p)).all()

    # A random point is an orthonormal basis of a 2D subspace of R^5.
    p = grassmann_random_point(5, 2)
    assert np.allclose(p.T @ p, np.eye(2))
    v = grassmann_random_tangent(p) * 10
    # The scaled tangent's largest singular value exceeds 90 degrees.
    assert np.linalg.svd(v, compute_uv=False)[0] / np.pi * 180 > 90

    # Tangent vectors satisfy p^T v + v^T p = 0.
    assert np.allclose(p.T @ v + v.T @ p, np.zeros((2, 2)))
    p2 = grassmann_map(v, p)
    valt = grassmann_log(p, p2)
    p3 = grassmann_map(valt, p)

    # Mapping the log back lands on the same subspace ...
    assert grassmann_angle_from_point(p3, p2) < 1e-10
    # assert np.linalg.matrix_rank(np.hstack([p3, p2])) == 2
    # ... and the log has a smaller norm than the original (wrapped) tangent.
    assert (valt**2).sum() < (v**2).sum()

    p4 = grassmann_random_point(5, 2)
    p5 = grassmann_random_point(5, 2)
    v45 = grassmann_log(p4, p5)
    assert np.allclose(p4.T @ v45 + v45.T @ p4, np.zeros((2, 2)))
    # This failed randomly on the github runner
    # assert np.linalg.matrix_rank(np.hstack([grassmann_map(v45, p4), p5])) == 2
    # log and map are inverse to each other.
    assert np.allclose(grassmann_log(p4, grassmann_map(v45, p4)), v45)
38 |
39 |
def test_experiment():
    # Trivial placeholder: sanity check that the test suite runs at all.
    assert 3 - 2 == 1
42 |
--------------------------------------------------------------------------------
/tests/test_lemur.py:
--------------------------------------------------------------------------------
1 | from io import StringIO
2 |
3 | import anndata as ad
4 | import formulaic
5 | import numpy as np
6 | import pandas as pd
7 |
8 | import pylemur.tl._grassmann
9 | import pylemur.tl.alignment
10 | import pylemur.tl.lemur
11 |
12 |
def test_design_specification_works():
    """The design matrix is identical no matter how data and annotations are supplied."""
    counts = np.random.randn(500, 30)
    annotations = pd.DataFrame({"condition": np.random.choice(["trt", "ctrl"], size=500)})
    expected = formulaic.model_matrix("~ condition", annotations)

    # Raw numpy matrix plus separate obs_data.
    model = pylemur.tl.LEMUR(counts, design="~ condition", obs_data=annotations)
    assert model.design_matrix.equals(expected)

    # AnnData without obs; annotations supplied via obs_data.
    model = pylemur.tl.LEMUR(ad.AnnData(counts), design="~ condition", obs_data=annotations)
    assert model.design_matrix.equals(expected)

    # AnnData that already carries the annotations in .obs.
    model = pylemur.tl.LEMUR(ad.AnnData(counts, obs=annotations), design="~ condition")
    assert model.design_matrix.equals(expected)
26 |
27 |
def test_numpy_design_matrix_works():
    """Passing a plain numpy design matrix reproduces the formula-based fit."""
    counts = np.random.randn(500, 30)
    annotations = pd.DataFrame({"condition": np.random.choice(["trt", "ctrl"], size=500)})
    design_mat = formulaic.model_matrix("~ condition", annotations).to_numpy()
    grouping = np.random.choice(2, size=500)

    reference = pylemur.tl.LEMUR(counts, design="~ condition", obs_data=annotations).fit().align_with_grouping(grouping)

    # Same fit whether the data arrives as a raw matrix or wrapped in AnnData.
    for data in (counts, ad.AnnData(counts)):
        fitted = pylemur.tl.LEMUR(data, design=design_mat).fit().align_with_grouping(grouping)
        assert np.allclose(fitted.coefficients, reference.coefficients)

    design_df = pd.DataFrame(design_mat, columns=["Intercept", "Covar1"])
    fitted = pylemur.tl.LEMUR(counts, design=design_df).fit().align_with_grouping(grouping)
    assert np.allclose(fitted.coefficients, reference.coefficients)
    assert np.allclose(fitted.alignment_coefficients, reference.alignment_coefficients)
46 |
47 |
def test_pandas_design_matrix_works():
    """A pandas design matrix gives the same coefficients as the formula interface."""
    counts = np.random.randn(500, 30)
    annotations = pd.DataFrame({"condition": np.random.choice(["trt", "ctrl"], size=500)})
    design_df = formulaic.model_matrix("~ condition", annotations)

    reference = pylemur.tl.LEMUR(counts, design="~ condition", obs_data=annotations).fit()

    # Same fit for a raw matrix and for an AnnData wrapper.
    for data in (counts, ad.AnnData(counts)):
        fitted = pylemur.tl.LEMUR(data, design=design_df).fit()
        assert np.allclose(fitted.coefficients, reference.coefficients)
60 |
61 |
def test_copy_works():
    """`copy(copy_adata=...)` controls whether the AnnData object is shared."""
    counts = np.random.randn(500, 30)
    annotations = pd.DataFrame({"condition": np.random.choice(["trt", "ctrl"], size=500)})
    model = pylemur.tl.LEMUR(counts, design="~ condition", obs_data=annotations)
    shallow = model.copy(copy_adata=False)
    deep = model.copy(copy_adata=True)
    # Idiomatic identity comparison (was `id(a) == id(b)` / `id(a) != id(b)`).
    assert model.adata is shallow.adata
    assert model.adata is not deep.adata
    _assert_lemur_model_equal(model, shallow)
    _assert_lemur_model_equal(model, deep, adata_id_equal=False)
72 |
73 |
def test_predict():
    """`LEMUR.predict` must reproduce reference values computed with the R lemur package."""
    ## Make sure I get the same results as in R
    # save_func <- function(obj){
    #   readr::format_csv(as.data.frame(obj), col_names = FALSE)
    # }
    # randn <- function(n, m, ...){
    #   matrix(rnorm(n * m, ...), nrow = n, ncol = m)
    # }
    #
    # set.seed(1)
    # Y <- round(randn(5, 30), 2)
    # design <- cbind(1, rep(1:2, each = 15))
    # fit <- lemur(Y, design = design, n_embedding = 3, test_fraction = 0)
    # save_func(Y)
    # save_func(design)
    # save_func(predict(fit))
    Y = np.genfromtxt(
        StringIO(
            "-0.63,-0.82,1.51,-0.04,0.92,-0.06,1.36,-0.41,-0.16,-0.71,0.4,1.98,2.4,0.19,0.48,0.29,-0.57,0.33,-0.54,0.56,-0.62,1.77,-0.64,-0.39,-0.51,0.71,0.06,-1.54,-1.91,-0.75\n0.18,0.49,0.39,-0.02,0.78,-0.16,-0.1,-0.39,-0.25,0.36,-0.61,-0.37,-0.04,-1.8,-0.71,-0.44,-0.14,1.06,1.21,-1.28,0.04,0.72,-0.46,-0.32,1.34,-0.07,-0.59,-0.3,1.18,2.09\n-0.84,0.74,-0.62,0.94,0.07,-1.47,0.39,-0.06,0.7,0.77,0.34,-1.04,0.69,1.47,0.61,0,1.18,-0.3,1.16,-0.57,-0.91,0.91,1.43,-0.28,-0.21,-0.04,0.53,-0.53,-1.66,0.02\n1.6,0.58,-2.21,0.82,-1.99,-0.48,-0.05,1.1,0.56,-0.11,-1.13,0.57,0.03,0.15,-0.93,0.07,-1.52,0.37,0.7,-1.22,0.16,0.38,-0.65,0.49,-0.18,-0.68,-1.52,-0.65,-0.46,-1.29\n0.33,-0.31,1.12,0.59,0.62,0.42,-1.38,0.76,-0.69,0.88,1.43,-0.14,-0.74,2.17,-1.25,-0.59,0.59,0.27,1.59,-0.47,-0.65,1.68,-0.21,-0.18,-0.1,-0.32,0.31,-0.06,-1.12,-1.64\n"
        ),
        delimiter=",",
    ).T
    design = np.genfromtxt(
        StringIO(
            "1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n"
        ),
        delimiter=",",
    )
    results_init = np.genfromtxt(
        StringIO(
            "-0.8554888214077533,-0.2720977014782501,1.467533504177673,-0.22767956411506196,1.2785521085014837,0.14379806862053318,1.4750333827944937,-0.6570301309706671,0.2066831996074998,-0.353750513692208,0.40174695240184183,1.0849159925675056,1.8230740377826868,-0.3796132967907797,1.27432278200101,0.25734634440017745,-0.48925340880447665,0.15787257755416356,0.04938969573830959,0.3192312181063839,-0.3542994869485033,1.401448533272853,-0.3938541882847783,0.08741246846567274,-0.6510475608713524,0.22923361101684514,-0.06804635141148654,-0.8510364064773515,-1.6943910912740754,-1.7500059544823854\n0.22583899533086937,-0.23735468442585927,0.19733713734339808,-0.4421331667319863,0.03586724830483476,0.4198858825154666,-0.08304003977298911,-0.16513486823394496,-0.30042983957286806,-0.3460343445737336,-0.3558867742284205,0.3011982462662531,-0.20850076449582816,-1.0161941132088552,-0.27541891451634165,-0.38785291893731433,-0.12147153601775817,0.9332086847171666,1.6012153818454613,-1.4764027185416573,0.23315085154197657,0.4422452475355182,-0.20026735579063953,0.04360380473394784,1.2311912129895743,-0.3862511983506538,-0.7251300187432477,0.11080434946587961,1.2741080839827195,1.467848129569024\n-0.7940511553704984,0.36872650465867485,-0.696190082245086,0.7842394355421791,-0.2859846254450281,-1.2496520777731421,0.3835639668947345,0.06394998322061843,0.6354423384795286,0.43034530046034825,0.446899273641681,-0.6513965946225246,0.6871089757451906,1.8676308579518373,0.6993678988614832,-0.35853838671959254,1.3529850324988306,-0.246355558204097,1.1840747073603424,-0.40278806288570534,-0.980382841146775,1.0552239017849732,0.9738387917168745,-0.4865414099512409,-0.14227709749759898,-0.07349599449524213,0.7673735611128623,-0.27683325130457614,-1.4140680282113947,-0.2222153640576589\n1.418455323621421,1.003633642928876,-2.249906497292914,0.6558257816955055,-1.7203847145600766,-0.29827588824455953,0.043760865569417634,0.9060713852440867,0.8559072353487625,0.15882996592743862,-1.1213675611187175,-0.13675398777029954,-0.4426075110785659,-0.28963584470503034,-0.27355219556533633,-0.22900366687718224,-1.4215471940245057,0.5395008912619365,0.30620330206252744,-0.9001187449182614,-0.09048745003886999,0.7701978486901692,-1.2331890579233868,-0.0319982155355501,-0.019760633637726566,-0.37122588017711394,-1.2165536309067462,-0.9082217868619842,-0.3910853125055948,-0.8027104686077078\n0.16860548847823306,0.019631704090379926,1.069169569969778,0.4089409239787705,0.8084926808056435,0.6289431224230143,-1.2935687373997728,0.6009167939510944,-0.42509727026387234,1.0706646727858744,1.4570831601240841,-0.7309939926404594,-1.1819487005772338,1.8301112490396232,-0.6209506647651512,-0.1631607424281398,0.3494484690478879,0.3004091825038984,1.248612413250847,-0.5327275335431212,-0.7111287487922291,1.7104365790742557,0.179842226080069,-0.19838902997632976,-0.10222070429937023,-0.02570345245587259,0.10711511924305978,-0.7160399504157087,-1.5156808495142373,-0.8308129777750086\n"
        ),
        delimiter=",",
    ).T

    # Fit the model on the same data/design as the R reference and compare.
    model = pylemur.tl.LEMUR(Y, design=design, n_embedding=3)
    model.fit(verbose=False)
    pred_init = model.predict()
    assert np.allclose(pred_init, results_init)

    ## Here I have to override all coefficients to make sure that the results match
    # set.seed(1)
    # wide_alignment_coef <- round(cbind(randn(3,4), randn(3,4)), 2)
    # fit@metadata$alignment_coefficients <- array(wide_alignment_coef, dim = c(3,4,2))
    # res <- predict(fit)
    # save_func(fit$base_point)
    # save_func(fit$embedding)
    # save_func(wide_alignment_coef)
    # save_func(cbind(fit$coefficients[,,1], fit$coefficients[,,2]))
    # save_func(res)
    basepoint = np.genfromtxt(
        StringIO(
            "-0.38925206703492543,0.5213219572509142,-0.7231411670117558\n-0.014297729446464128,-0.3059067922758592,0.056315487696737755\n0.2455091078121989,0.48054559390052237,0.133975572527343\n0.7927765077921876,-0.20187676099522883,-0.5743821409539152\n0.39938589098238725,0.602466726782849,0.3550086203707195\n"
        ),
        delimiter=",",
    ).T
    assert np.allclose(pylemur.tl._grassmann.grassmann_angle_from_point(basepoint.T, model.base_point.T), 0)
    emb = np.genfromtxt(
        StringIO(
            "1.307607290530331,1.204126637770359,-2.334941421246343,1.202469151007156,-1.7184261449703435,-0.5734035327253721,-0.6805385978545121,1.3185253662996623,0.8741076652495349,0.8858962317444363,-0.34911335825714623,-0.9635203043565856,-1.052446567304073,1.396858005348387,-0.5172004212354832,-0.5028122904697048,-0.3322559893690832,1.050654556399191,2.1375501746096512,-1.7864259568934735,-0.38553520332697033,1.7579862599667835,-0.44888734172910766,-0.14472710389159946,0.8047342904549903,-0.4509018220855433,-0.8826656042763416,-0.7455036156569727,-0.28188072558002397,0.2106703718482004\n-1.8164798196642507,-0.477295729296259,0.4790240845360194,0.1574495019414363,0.5418966595889152,-1.218733839936566,-0.030245840455864665,-0.6590134762699248,-0.17773495609983386,0.16681479907923033,0.9577303872796369,-0.8649693035413422,0.5516152492211109,1.872319735730998,0.5176225478866773,0.1933219672043183,1.1600064890134691,-0.2055279125365953,0.8140862711968238,0.5522995484090838,-1.0056198219223997,2.0135021529189063,0.9037871448321165,-0.1835087018953936,-0.7687706326544156,0.4350873523204601,1.1121743457677276,-0.6639780098588698,-2.6747181225665626,-1.6821420702286731\n0.18948314499743618,-0.35557961492948087,1.2558867706759906,-0.06200523822352812,0.8150276978150715,0.9017861009351356,-1.7593369587152685,0.38440670001846194,-0.9259345355859308,0.7952852407963301,1.2519868131937648,-0.7881479698006301,-1.7254602945657311,1.1468879075695446,-1.1242857641811612,-0.777687817264553,1.2908968434310275,-0.7487113052869242,0.3633562246175412,-0.7434161046620548,-0.6564971373410212,-1.062751302635433,0.9113503972724165,-0.742431234492101,0.1738627261607038,-0.5327010359755776,0.46611269007910494,0.42937368327065845,0.34776880301511215,1.281474569811115\n"
        ),
        delimiter=",",
    ).T
    wide_alignment_coef = np.genfromtxt(
        StringIO(
            "0,-0.63,1.6,0.49,0,-0.31,-0.62,-0.04\n0,0.18,0.33,0.74,0,1.51,-2.21,-0.02\n0,-0.84,-0.82,0.58,0,0.39,1.12,0.94\n"
        ),
        delimiter=",",
    )
    # Split the wide matrix into the two design-column slices of the 3D array.
    alignment_coef = np.stack([wide_alignment_coef[:, 0:4], wide_alignment_coef[:, 4:8]], axis=2)
    coef = np.genfromtxt(
        StringIO(
            "-0.39126157680942947,-0.13893603024573095,0.22217190036192688,0.2997192813183081,0.045415132726990666,-0.11388588356030183\n-1.4118852084954074,-0.12010846331372559,-0.042620766619525983,1.1685923275506187,-0.009795610498605375,0.1481261731296855\n0.4301922175223224,0.8308632945684709,-1.7460499853532059,-0.17472584078151635,-0.3588445576311046,1.1196336953167267\n0.010836267191959998,-0.020245706071854806,0.04720615470254291,-0.013801681502086402,0.009718078909566252,-0.032178386499295816\n-0.7178345111625161,-0.6102684315610786,1.1946309212817439,0.46875216448007323,0.2452095180077213,-0.7300769880872995\n"
        ),
        delimiter=",",
    )
    coef = np.stack([coef[:, 0:3], coef[:, 3:6]], axis=2)
    coef = np.einsum("ijk->jik", coef)
    results = np.genfromtxt(
        StringIO(
            "0.9835142473010876,0.5282279297874248,0.3734623735376733,0.29148494672406977,0.3233549213466831,0.8738371099709996,0.5407962262440337,0.5618064038652619,0.4615777825112788,0.28201578105225994,0.06759705564456442,0.8278171090396453,0.3564197483842603,-0.37494408404208857,0.3130324486328697,4.43109412669912,4.485260023784498,-13.556871509080235,-25.740741431276582,19.3507385539798,4.267028891777953,-24.31631072038289,5.55415174991481,0.5751412151679098,-8.850525840651931,3.9506221824433108,9.87464782407815,9.75898404222322,5.923544047383523,0.5432368439394216\n-0.5884851968278264,-0.6980003631236811,1.0095730179286158,-0.706603653513592,0.6775999178228775,0.2960753981255735,-0.04883524121983143,-0.6502161106433406,-0.6374945628932893,-0.46524145497379454,0.0868361701396683,0.24887091370316872,0.08165271296897453,-0.7703009490648383,-0.08543059842869558,1.8195448028133048,2.2357171298099003,-4.317401830679442,-8.088317047084285,6.921096710379453,1.5046074415988786,-7.462954899818863,2.5044021915523067,0.44085644478623487,-2.7813085924525587,1.7337832588534812,3.965667679124676,3.523130209648968,1.7727989396531905,0.2683775618147773\n1.2241353640082717,1.439811782612444,-2.4939338263214896,1.413000126505777,-1.732752410802792,-0.8109483715136666,0.060386636787661846,1.297799915381223,1.325781352884555,0.8252158435825825,-0.46440880845174237,-0.6215041984220009,-0.2577507799745221,1.405770288579147,0.07939708514456326,3.63765958748972,4.311718634450114,-10.958511305614007,-20.4685709755674,15.761881934367485,3.4712598990276082,-19.643794665294983,5.058176520273988,0.500553412576863,-6.917971274367058,3.329924337359561,8.448268016100178,8.2448015916862,5.000276047604282,0.9543282399075053\n-0.7921997533676974,-0.22060767244684792,-0.13854023872636448,-0.07315533443899382,-0.04074837994335717,-0.6868254379104765,0.30181730830558207,-0.4486065557871184,0.016532124694023168,-0.2487768721959298,-0.043942589618246174,-0.18230527640412342,0.5013051052792722,0.22675447837288173,0.33929909418738713,3.3366520193532914,3.389761379749944,-10.811207239740792,-20.286351865125106,15.07164187230232,3.0442924834976672,-18.933270905633723,4.214979326150896,0.264093656883327,-7.2491426213752,2.978123336697904,7.6524353015805255,7.331256553093012,4.067336259849709,-0.0705995572837037\n0.429231215968942,0.4526675323865768,-0.0583697984419273,0.6074892682578291,0.02405662404573697,0.13577913690270105,-0.3608880085761539,0.6353597043295498,0.26375995622833437,0.7194765847659087,0.5782845037350453,-0.326907098443994,-0.38119677146460307,1.1810728627919298,-0.08981571248586934,5.740966466337848,6.297668960348173,-17.056604679816925,-32.1536933270833,24.663672997961317,5.422318375098499,-30.48048219956338,7.551698596077222,0.8269489489661501,-10.99498861966501,5.205042841546413,12.959282592098615,12.608602675876007,7.517277260484411,0.9922891113340572\n"
        ),
        delimiter=",",
    ).T

    # Override the fitted parameters with the R reference values, then compare.
    model.alignment_coefficients = alignment_coef
    model.base_point = basepoint
    model.embedding = emb
    model.coefficients = coef
    pred_new = model.predict()
    assert np.allclose(pred_new, results)
165 |
166 |
def test_align_with_grouping():
    """Check that ``align_with_grouping`` reproduces reference values exported from R.

    The CSV blobs below are verbatim dumps produced by the R ``lemur`` package
    (see the commented R snippet further down); the test asserts that the
    Python implementation matches them numerically.
    """
    ## Make sure I get the same results as in R
    ## Same code as above
    # Input data: 5 x 30 CSV dump, transposed to a 30 x 5 matrix (cells x features).
    Y = np.genfromtxt(
        StringIO(
            "-0.63,-0.82,1.51,-0.04,0.92,-0.06,1.36,-0.41,-0.16,-0.71,0.4,1.98,2.4,0.19,0.48,0.29,-0.57,0.33,-0.54,0.56,-0.62,1.77,-0.64,-0.39,-0.51,0.71,0.06,-1.54,-1.91,-0.75\n0.18,0.49,0.39,-0.02,0.78,-0.16,-0.1,-0.39,-0.25,0.36,-0.61,-0.37,-0.04,-1.8,-0.71,-0.44,-0.14,1.06,1.21,-1.28,0.04,0.72,-0.46,-0.32,1.34,-0.07,-0.59,-0.3,1.18,2.09\n-0.84,0.74,-0.62,0.94,0.07,-1.47,0.39,-0.06,0.7,0.77,0.34,-1.04,0.69,1.47,0.61,0,1.18,-0.3,1.16,-0.57,-0.91,0.91,1.43,-0.28,-0.21,-0.04,0.53,-0.53,-1.66,0.02\n1.6,0.58,-2.21,0.82,-1.99,-0.48,-0.05,1.1,0.56,-0.11,-1.13,0.57,0.03,0.15,-0.93,0.07,-1.52,0.37,0.7,-1.22,0.16,0.38,-0.65,0.49,-0.18,-0.68,-1.52,-0.65,-0.46,-1.29\n0.33,-0.31,1.12,0.59,0.62,0.42,-1.38,0.76,-0.69,0.88,1.43,-0.14,-0.74,2.17,-1.25,-0.59,0.59,0.27,1.59,-0.47,-0.65,1.68,-0.21,-0.18,-0.1,-0.32,0.31,-0.06,-1.12,-1.64\n"
        ),
        delimiter=",",
    ).T
    # Design matrix: intercept column plus a two-level condition (15 rows each of "1" and "2").
    design = np.genfromtxt(
        StringIO(
            "1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n"
        ),
        delimiter=",",
    )
    # Reference predictions of the freshly fitted model, before any alignment.
    results_init = np.genfromtxt(
        StringIO(
            "-0.8554888214077533,-0.2720977014782501,1.467533504177673,-0.22767956411506196,1.2785521085014837,0.14379806862053318,1.4750333827944937,-0.6570301309706671,0.2066831996074998,-0.353750513692208,0.40174695240184183,1.0849159925675056,1.8230740377826868,-0.3796132967907797,1.27432278200101,0.25734634440017745,-0.48925340880447665,0.15787257755416356,0.04938969573830959,0.3192312181063839,-0.3542994869485033,1.401448533272853,-0.3938541882847783,0.08741246846567274,-0.6510475608713524,0.22923361101684514,-0.06804635141148654,-0.8510364064773515,-1.6943910912740754,-1.7500059544823854\n0.22583899533086937,-0.23735468442585927,0.19733713734339808,-0.4421331667319863,0.03586724830483476,0.4198858825154666,-0.08304003977298911,-0.16513486823394496,-0.30042983957286806,-0.3460343445737336,-0.3558867742284205,0.3011982462662531,-0.20850076449582816,-1.0161941132088552,-0.27541891451634165,-0.38785291893731433,-0.12147153601775817,0.9332086847171666,1.6012153818454613,-1.4764027185416573,0.23315085154197657,0.4422452475355182,-0.20026735579063953,0.04360380473394784,1.2311912129895743,-0.3862511983506538,-0.7251300187432477,0.11080434946587961,1.2741080839827195,1.467848129569024\n-0.7940511553704984,0.36872650465867485,-0.696190082245086,0.7842394355421791,-0.2859846254450281,-1.2496520777731421,0.3835639668947345,0.06394998322061843,0.6354423384795286,0.43034530046034825,0.446899273641681,-0.6513965946225246,0.6871089757451906,1.8676308579518373,0.6993678988614832,-0.35853838671959254,1.3529850324988306,-0.246355558204097,1.1840747073603424,-0.40278806288570534,-0.980382841146775,1.0552239017849732,0.9738387917168745,-0.4865414099512409,-0.14227709749759898,-0.07349599449524213,0.7673735611128623,-0.27683325130457614,-1.4140680282113947,-0.2222153640576589\n1.418455323621421,1.003633642928876,-2.249906497292914,0.6558257816955055,-1.7203847145600766,-0.29827588824455953,0.043760865569417634,0.9060713852440867,0.8559072353487625,0.15882996592743862,-1.1213675611187175,-0.13675398777029954,-0.4426075110785659,-0.28963584470503034,-0.27355219556533633,-0.22900366687718224,-1.4215471940245057,0.5395008912619365,0.30620330206252744,-0.9001187449182614,-0.09048745003886999,0.7701978486901692,-1.2331890579233868,-0.0319982155355501,-0.019760633637726566,-0.37122588017711394,-1.2165536309067462,-0.9082217868619842,-0.3910853125055948,-0.8027104686077078\n0.16860548847823306,0.019631704090379926,1.069169569969778,0.4089409239787705,0.8084926808056435,0.6289431224230143,-1.2935687373997728,0.6009167939510944,-0.42509727026387234,1.0706646727858744,1.4570831601240841,-0.7309939926404594,-1.1819487005772338,1.8301112490396232,-0.6209506647651512,-0.1631607424281398,0.3494484690478879,0.3004091825038984,1.248612413250847,-0.5327275335431212,-0.7111287487922291,1.7104365790742557,0.179842226080069,-0.19838902997632976,-0.10222070429937023,-0.02570345245587259,0.10711511924305978,-0.7160399504157087,-1.5156808495142373,-0.8308129777750086\n"
        ),
        delimiter=",",
    ).T

    model = pylemur.tl.LEMUR(Y, design=design, n_embedding=3)
    model.fit(verbose=False)
    pred_init = model.predict()
    # Sanity check: the Python fit already matches the R reference predictions.
    assert np.allclose(pred_init, results_init)

    ## Here is a modified version of the approach from the previous test
    # set.seed(1)
    # grouping <- sample(c(1,2), size = 30, replace = TRUE)
    # save_func(grouping)
    # fit <- align_by_grouping(fit, grouping = as.character(grouping), design = fit$design_matrix)
    # save_func(cbind(fit$alignment_coefficients[,,1], fit$alignment_coefficients[,,2]))
    # ...
    # Random two-group assignment generated in R with set.seed(1) (see snippet above).
    grouping = np.genfromtxt(
        StringIO("1\n2\n1\n1\n2\n1\n1\n1\n2\n2\n1\n1\n1\n1\n1\n2\n2\n2\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n")
    )

    # Reference base point (5 x 3 dump, transposed to 3 x 5).
    basepoint = np.genfromtxt(
        StringIO(
            "-0.38925206703492543,0.5213219572509142,-0.7231411670117558\n-0.014297729446464128,-0.3059067922758592,0.056315487696737755\n0.2455091078121989,0.48054559390052237,0.133975572527343\n0.7927765077921876,-0.20187676099522883,-0.5743821409539152\n0.39938589098238725,0.602466726782849,0.3550086203707195\n"
        ),
        delimiter=",",
    ).T
    # The R and Python base points may differ by a basis rotation; a Grassmann
    # angle of zero means they span the same subspace.
    assert np.allclose(pylemur.tl._grassmann.grassmann_angle_from_point(basepoint.T, model.base_point.T), 0)
    # Reference 3-dimensional embedding (3 x 30 dump, transposed to 30 x 3).
    emb = np.genfromtxt(
        StringIO(
            "1.307607290530331,1.204126637770359,-2.334941421246343,1.202469151007156,-1.7184261449703435,-0.5734035327253721,-0.6805385978545121,1.3185253662996623,0.8741076652495349,0.8858962317444363,-0.34911335825714623,-0.9635203043565856,-1.052446567304073,1.396858005348387,-0.5172004212354832,-0.5028122904697048,-0.3322559893690832,1.050654556399191,2.1375501746096512,-1.7864259568934735,-0.38553520332697033,1.7579862599667835,-0.44888734172910766,-0.14472710389159946,0.8047342904549903,-0.4509018220855433,-0.8826656042763416,-0.7455036156569727,-0.28188072558002397,0.2106703718482004\n-1.8164798196642507,-0.477295729296259,0.4790240845360194,0.1574495019414363,0.5418966595889152,-1.218733839936566,-0.030245840455864665,-0.6590134762699248,-0.17773495609983386,0.16681479907923033,0.9577303872796369,-0.8649693035413422,0.5516152492211109,1.872319735730998,0.5176225478866773,0.1933219672043183,1.1600064890134691,-0.2055279125365953,0.8140862711968238,0.5522995484090838,-1.0056198219223997,2.0135021529189063,0.9037871448321165,-0.1835087018953936,-0.7687706326544156,0.4350873523204601,1.1121743457677276,-0.6639780098588698,-2.6747181225665626,-1.6821420702286731\n0.18948314499743618,-0.35557961492948087,1.2558867706759906,-0.06200523822352812,0.8150276978150715,0.9017861009351356,-1.7593369587152685,0.38440670001846194,-0.9259345355859308,0.7952852407963301,1.2519868131937648,-0.7881479698006301,-1.7254602945657311,1.1468879075695446,-1.1242857641811612,-0.777687817264553,1.2908968434310275,-0.7487113052869242,0.3633562246175412,-0.7434161046620548,-0.6564971373410212,-1.062751302635433,0.9113503972724165,-0.742431234492101,0.1738627261607038,-0.5327010359755776,0.46611269007910494,0.42937368327065845,0.34776880301511215,1.281474569811115\n"
        ),
        delimiter=",",
    ).T
    # Reference coefficients, dumped as a 5 x 6 matrix (two 5 x 3 slices side by side).
    coef = np.genfromtxt(
        StringIO(
            "-0.39126157680942947,-0.13893603024573095,0.22217190036192688,0.2997192813183081,0.045415132726990666,-0.11388588356030183\n-1.4118852084954074,-0.12010846331372559,-0.042620766619525983,1.1685923275506187,-0.009795610498605375,0.1481261731296855\n0.4301922175223224,0.8308632945684709,-1.7460499853532059,-0.17472584078151635,-0.3588445576311046,1.1196336953167267\n0.010836267191959998,-0.020245706071854806,0.04720615470254291,-0.013801681502086402,0.009718078909566252,-0.032178386499295816\n-0.7178345111625161,-0.6102684315610786,1.1946309212817439,0.46875216448007323,0.2452095180077213,-0.7300769880872995\n"
        ),
        delimiter=",",
    )
    # Rebuild the (5, 3, 2) tensor from the wide dump, then swap the first two
    # axes to (3, 5, 2) so it matches the layout of model.coefficients.
    coef = np.stack([coef[:, 0:3], coef[:, 3:6]], axis=2)
    coef = np.einsum("ijk->jik", coef)

    # Reference alignment coefficients: 3 x 8 dump holding two 3 x 4 slices.
    wide_alignment_coef = np.genfromtxt(
        StringIO(
            "-0.09057594036322775,0.004854216024969793,0.002327313009887152,0.005758995615529929,0.05944370038362205,-0.0035991050730759615,-0.0020095175466780125,-0.005428755213860675\n-0.05277256958166519,0.042258290722341184,0.020260381752233707,0.05013483325379276,0.02699629278032626,-0.03133194479519331,-0.017493819035318634,-0.04725993134785861\n-0.07693347488863567,0.04690711963321551,0.02248922363923753,0.0556501595551681,0.04220308611767923,-0.03477876785187711,-0.019418311727825953,-0.052458990075042476\n"
        ),
        delimiter=",",
    )
    alignment_coef = np.stack([wide_alignment_coef[:, 0:4], wide_alignment_coef[:, 4:8]], axis=2)
    # Reference predictions after alignment (should equal the pre-alignment ones).
    results = np.genfromtxt(
        StringIO(
            "-0.8554888214077531,-0.27209770147824963,1.4675335041776725,-0.22767956411506185,1.2785521085014835,0.14379806862053301,1.4750333827944933,-0.6570301309706669,0.20668319960749992,-0.3537505136922079,0.4017469524018417,1.0849159925675056,1.8230740377826868,-0.3796132967907798,1.27432278200101,0.25734634440017756,-0.4892534088044769,0.15787257755416356,0.049389695738309536,0.3192312181063839,-0.3542994869485032,1.4014485332728528,-0.3938541882847784,0.0874124684656728,-0.6510475608713524,0.22923361101684514,-0.0680463514114866,-0.8510364064773513,-1.6943910912740756,-1.7500059544823856\n0.22583899533086926,-0.2373546844258592,0.19733713734339786,-0.4421331667319863,0.03586724830483465,0.4198858825154665,-0.08304003977298917,-0.16513486823394496,-0.300429839572868,-0.34603434457373355,-0.3558867742284205,0.3011982462662531,-0.20850076449582813,-1.0161941132088552,-0.2754189145163416,-0.38785291893731433,-0.12147153601775795,0.9332086847171664,1.601215381845461,-1.476402718541657,0.23315085154197665,0.44224524753551797,-0.20026735579063937,0.043603804733947815,1.2311912129895743,-0.3862511983506537,-0.7251300187432477,0.11080434946587975,1.27410808398272,1.467848129569024\n-0.7940511553704983,0.36872650465867474,-0.6961900822450855,0.784239435542179,-0.2859846254450279,-1.249652077773142,0.3835639668947346,0.0639499832206183,0.6354423384795286,0.43034530046034813,0.446899273641681,-0.6513965946225245,0.6871089757451905,1.8676308579518373,0.6993678988614832,-0.3585383867195926,1.352985032498831,-0.24635555820409705,1.1840747073603421,-0.40278806288570546,-0.9803828411467752,1.0552239017849732,0.9738387917168745,-0.486541409951241,-0.1422770974975989,-0.07349599449524222,0.7673735611128623,-0.27683325130457614,-1.4140680282113949,-0.22221536405765868\n1.4184553236214208,1.0036336429288755,-2.2499064972929133,0.6558257816955054,-1.7203847145600764,-0.2982758882445596,0.043760865569417745,0.9060713852440865,0.8559072353487623,0.15882996592743856,-1.1213675611187173,-0.13675398777029935,-0.4426075110785658,-0.28963584470503045,-0.2735521955653362,-0.22900366687718213,-1.4215471940245057,0.5395008912619363,0.3062033020625272,-0.9001187449182614,-0.09048745003886988,0.7701978486901689,-1.2331890579233868,-0.031998215535550045,-0.019760633637726732,-0.37122588017711383,-1.2165536309067464,-0.9082217868619842,-0.3910853125055947,-0.802710468607708\n0.1686054884782331,0.019631704090379815,1.0691695699697783,0.40894092397877047,0.8084926808056437,0.6289431224230144,-1.2935687373997728,0.6009167939510944,-0.42509727026387234,1.0706646727858744,1.4570831601240841,-0.7309939926404595,-1.1819487005772338,1.8301112490396232,-0.6209506647651512,-0.16316074242813974,0.34944846904788796,0.3004091825038983,1.248612413250847,-0.5327275335431213,-0.7111287487922291,1.7104365790742553,0.179842226080069,-0.19838902997632976,-0.10222070429937029,-0.02570345245587259,0.10711511924305983,-0.7160399504157086,-1.5156808495142378,-0.8308129777750086\n"
        ),
        delimiter=",",
    ).T

    # Replace the fitted parameters with the R reference values, then align.
    model.base_point = basepoint
    model.embedding = emb
    model.coefficients = coef
    model.align_with_grouping(grouping)
    assert np.allclose(model.alignment_coefficients, alignment_coef)

    # Alignment must reproduce the R predictions and leave predict() unchanged.
    pred_new = model.predict()
    assert np.allclose(pred_new, results)
    assert np.allclose(pred_new, pred_init)
250 |
251 |
def test_align_works():
    """Alignment (by grouping or by harmony) must not change model predictions.

    Uses a seeded generator so the test is reproducible: previously the inputs
    came from the unseeded global RNG, so a failure could not be replayed.
    """
    rng = np.random.default_rng(0)
    nelem = 200
    Y = rng.standard_normal((nelem, 5))
    # Design: intercept plus a random binary condition column.
    design = np.stack([np.ones(nelem), rng.choice(2, size=nelem)], axis=1)
    model = pylemur.tl.LEMUR(Y, design=design, n_embedding=3)
    model.fit(verbose=False)
    model_2 = model.copy()
    model_3 = model.copy()

    grouping = rng.choice(2, nelem)
    model_2.align_with_grouping(grouping)
    model_3.align_with_harmony()
    # Alignment re-parameterizes the latent space but must leave the
    # reconstruction returned by predict() unchanged.
    assert np.allclose(model_2.predict(), model.predict())
    assert np.allclose(model_3.predict(), model.predict())
266 |
267 |
268 | def _assert_lemur_model_equal(m1, m2, adata_id_equal=True):
269 | for k in m1.__dict__.keys():
270 | if k == "adata" and adata_id_equal:
271 | assert id(m1.adata) == id(m2.adata)
272 | elif k == "adata" and not adata_id_equal:
273 | assert id(m1.adata) != id(m2.adata)
274 | elif isinstance(m1.__dict__[k], pd.DataFrame):
275 | pd.testing.assert_frame_equal(m1.__dict__[k], m2.__dict__[k])
276 | elif isinstance(m1.__dict__[k], np.ndarray):
277 | assert np.array_equal(m1.__dict__[k], m2.__dict__[k])
278 | else:
279 | assert m1.__dict__[k] == m2.__dict__[k]
280 |
--------------------------------------------------------------------------------
/tests/test_lin_alg_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.linear_model import LinearRegression, Ridge
3 |
4 | import pylemur.tl.alignment as pylemur_al
5 | from pylemur.tl._lin_alg_wrappers import *
6 |
7 |
def test_fit_pca():
    """fit_pca: offset, orthonormal coordinate system, and consistent embedding."""
    data = np.random.randn(30, 400)

    centered = fit_pca(data, 3, center=True)
    # With centering the offset equals the column means.
    assert np.allclose(centered.offset, data.mean(axis=0))
    # Rows of the coordinate system form an orthonormal basis.
    assert np.allclose(centered.coord_system @ centered.coord_system.T, np.eye(3))
    # Embedding is the centered data projected onto that basis.
    assert np.allclose(centered.embedding, (data - centered.offset) @ centered.coord_system.T)

    uncentered = fit_pca(data, 3, center=False)
    # Without centering the offset is zero, so projection acts on the raw data.
    assert np.allclose(uncentered.offset, np.zeros(data.shape[1]))
    assert np.allclose(uncentered.coord_system @ uncentered.coord_system.T, np.eye(3))
    assert np.allclose(uncentered.embedding, (data - uncentered.offset) @ uncentered.coord_system.T)
    assert np.allclose(data @ uncentered.coord_system.T, uncentered.embedding)
21 |
22 |
def test_ridge_regression():
    """ridge_regression matches closed-form OLS and sklearn for plain, weighted, and penalized fits."""
    response = np.random.randn(400, 30)
    covariates = np.random.randn(400, 3)

    # Plain least squares: compare against the normal equations and sklearn.
    coef = ridge_regression(response, covariates)
    assert coef.shape == (3, 30)
    normal_eq = np.linalg.inv(covariates.T @ covariates) @ covariates.T @ response
    assert np.allclose(coef, normal_eq)
    sk_fit = LinearRegression(fit_intercept=False).fit(covariates, response)
    assert np.allclose(coef, sk_fit.coef_.T)

    # Observation weights must match sklearn's sample_weight handling.
    weights = np.random.rand(400)
    coef = ridge_regression(response, covariates, weights=weights)
    sk_fit = LinearRegression(fit_intercept=False).fit(covariates, response, sample_weight=weights)
    assert np.allclose(coef, sk_fit.coef_.T)

    # Ridge penalty: closed form with the Tikhonov matrix gamma, then sklearn's
    # Ridge with the equivalent alpha (= pen**4 * n_obs).
    pen = 0.3
    coef = ridge_regression(response, covariates, ridge_penalty=pen)
    gamma = np.sqrt(400) * pen**2 * np.eye(3)
    penalized = np.linalg.inv(covariates.T @ covariates + gamma.T @ gamma) @ covariates.T @ response
    assert np.allclose(coef, penalized)

    sk_ridge = Ridge(alpha=pen**4 * 400, fit_intercept=False).fit(covariates, response)
    assert np.allclose(coef, sk_ridge.coef_.T)
48 |
49 |
def test_forward_reverse_transformations():
    """The reverse linear transformation inverts the non-intercept part of the forward one."""
    coefficients = np.random.randn(5, 6, 2)
    covariate = np.random.randn(2)
    fwd = pylemur_al._forward_linear_transformation(coefficients, covariate)
    rev = pylemur_al._reverse_linear_transformation(coefficients, covariate)
    # Dropping the first column of the forward map, reverse @ forward is the 5x5 identity.
    assert np.allclose(rev @ fwd[:, 1:], np.diag(np.ones(5)))
56 |
--------------------------------------------------------------------------------