├── .codecov.yaml ├── .cruft.json ├── .editorconfig ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yml │ ├── config.yml │ └── feature_request.yml └── workflows │ ├── build.yaml │ ├── release.yaml │ └── test.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CHANGELOG.md ├── LICENSE ├── README.md ├── docs ├── Makefile ├── _static │ ├── .gitkeep │ ├── css │ │ └── custom.css │ └── images │ │ └── equation_schematic.png ├── _templates │ ├── .gitkeep │ └── autosummary │ │ └── class.rst ├── api.md ├── changelog.md ├── conf.py ├── extensions │ └── typed_returns.py ├── index.md ├── notebooks │ └── Tutorial.myst ├── references.bib └── references.md ├── notebooks ├── check_implementation.qmd ├── devel_experiments.qmd ├── kang_analysis_in_R.qmd ├── lemur_model_experiments.qmd ├── scanpy_experiments.qmd └── test.ipynb ├── pyproject.toml ├── src └── pylemur │ ├── __init__.py │ ├── pl │ ├── __init__.py │ └── basic.py │ ├── pp │ ├── __init__.py │ └── basic.py │ └── tl │ ├── __init__.py │ ├── _design_matrix_utils.py │ ├── _grassmann.py │ ├── _grassmann_lm.py │ ├── _lin_alg_wrappers.py │ ├── alignment.py │ └── lemur.py └── tests ├── test_grasmann_lm.py ├── test_grassmann.py ├── test_lemur.py └── test_lin_alg_utils.py /.codecov.yaml: -------------------------------------------------------------------------------- 1 | # Based on pydata/xarray 2 | codecov: 3 | require_ci_to_pass: no 4 | 5 | coverage: 6 | status: 7 | project: 8 | default: 9 | # Require 1% coverage, i.e., always succeed 10 | target: 1 11 | patch: false 12 | changes: false 13 | 14 | comment: 15 | layout: diff, flags, files 16 | behavior: once 17 | require_base: no 18 | -------------------------------------------------------------------------------- /.cruft.json: -------------------------------------------------------------------------------- 1 | { 2 | "template": "https://github.com/scverse/cookiecutter-scverse", 3 | "commit": "87a407a65408d75a949c0b54b19fd287475a56f8", 4 | "checkout": "v0.4.0", 5 
| "context": { 6 | "cookiecutter": { 7 | "project_name": "pyLemur", 8 | "package_name": "pylemur", 9 | "project_description": "A Python implementation of the LEMUR algorithm for analyzing multi-condition single-cell RNA-seq data.", 10 | "author_full_name": "Your Name", 11 | "author_email": "artjom31415@googlemail.com", 12 | "github_user": "const-ae", 13 | "project_repo": "https://github.com/const-ae/pyLemur", 14 | "license": "MIT License", 15 | "_copy_without_render": [ 16 | ".github/workflows/build.yaml", 17 | ".github/workflows/test.yaml", 18 | "docs/_templates/autosummary/**.rst" 19 | ], 20 | "_render_devdocs": false, 21 | "_jinja2_env_vars": { 22 | "lstrip_blocks": true, 23 | "trim_blocks": true 24 | }, 25 | "_template": "https://github.com/scverse/cookiecutter-scverse" 26 | } 27 | }, 28 | "directory": null 29 | } 30 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | indent_style = space 5 | indent_size = 4 6 | end_of_line = lf 7 | charset = utf-8 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | 11 | [*.{yml,yaml}] 12 | indent_size = 2 13 | 14 | [.cruft.json] 15 | indent_size = 2 16 | 17 | [Makefile] 18 | indent_style = tab 19 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug report 2 | description: Report something that is broken or incorrect 3 | labels: bug 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | **Note**: Please read [this guide](https://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports) 9 | detailing how to provide the necessary information for us to reproduce your bug. In brief: 10 | * Please provide exact steps how to reproduce the bug in a clean Python environment. 
11 | * In case it's not clear what's causing this bug, please provide the data or the data generation procedure. 12 | * Sometimes it is not possible to share the data, but usually it is possible to replicate problems on publicly 13 | available datasets or to share a subset of your data. 14 | 15 | - type: textarea 16 | id: report 17 | attributes: 18 | label: Report 19 | description: A clear and concise description of what the bug is. 20 | validations: 21 | required: true 22 | 23 | - type: textarea 24 | id: versions 25 | attributes: 26 | label: Version information 27 | description: | 28 | Please paste below the output of 29 | 30 | ```python 31 | import session_info 32 | session_info.show(html=False, dependencies=True) 33 | ``` 34 | placeholder: | 35 | ----- 36 | anndata 0.8.0rc2.dev27+ge524389 37 | session_info 1.0.0 38 | ----- 39 | asttokens NA 40 | awkward 1.8.0 41 | backcall 0.2.0 42 | cython_runtime NA 43 | dateutil 2.8.2 44 | debugpy 1.6.0 45 | decorator 5.1.1 46 | entrypoints 0.4 47 | executing 0.8.3 48 | h5py 3.7.0 49 | ipykernel 6.15.0 50 | jedi 0.18.1 51 | mpl_toolkits NA 52 | natsort 8.1.0 53 | numpy 1.22.4 54 | packaging 21.3 55 | pandas 1.4.2 56 | parso 0.8.3 57 | pexpect 4.8.0 58 | pickleshare 0.7.5 59 | pkg_resources NA 60 | prompt_toolkit 3.0.29 61 | psutil 5.9.1 62 | ptyprocess 0.7.0 63 | pure_eval 0.2.2 64 | pydev_ipython NA 65 | pydevconsole NA 66 | pydevd 2.8.0 67 | pydevd_file_utils NA 68 | pydevd_plugins NA 69 | pydevd_tracing NA 70 | pygments 2.12.0 71 | pytz 2022.1 72 | scipy 1.8.1 73 | setuptools 62.5.0 74 | setuptools_scm NA 75 | six 1.16.0 76 | stack_data 0.3.0 77 | tornado 6.1 78 | traitlets 5.3.0 79 | wcwidth 0.2.5 80 | zmq 23.1.0 81 | ----- 82 | IPython 8.4.0 83 | jupyter_client 7.3.4 84 | jupyter_core 4.10.0 85 | ----- 86 | Python 3.9.13 | packaged by conda-forge | (main, May 27 2022, 16:58:50) [GCC 10.3.0] 87 | Linux-5.18.6-arch1-1-x86_64-with-glibc2.35 88 | ----- 89 | Session information updated at 2022-07-07 17:55 90 | 
-------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Scverse Community Forum 4 | url: https://discourse.scverse.org/ 5 | about: If you have questions about “How to do X”, please ask them here. 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: Feature request 2 | description: Propose a new feature for pyLemur 3 | labels: enhancement 4 | body: 5 | - type: textarea 6 | id: description 7 | attributes: 8 | label: Description of feature 9 | description: Please describe your suggestion for a new feature. It might help to describe a problem or use case, plus any alternatives that you have considered. 10 | validations: 11 | required: true 12 | -------------------------------------------------------------------------------- /.github/workflows/build.yaml: -------------------------------------------------------------------------------- 1 | name: Check Build 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | package: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v3 18 | - name: Set up Python 3.10 19 | uses: actions/setup-python@v4 20 | with: 21 | python-version: "3.10" 22 | cache: "pip" 23 | cache-dependency-path: "**/pyproject.toml" 24 | - name: Install build dependencies 25 | run: python -m pip install --upgrade pip wheel twine build 26 | - name: Build package 27 | run: python -m build 28 | - name: Check package 29 | run: twine check --strict dist/*.whl 30 | 
-------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | # Use "trusted publishing", see https://docs.pypi.org/trusted-publishers/ 8 | jobs: 9 | release: 10 | name: Upload release to PyPI 11 | runs-on: ubuntu-latest 12 | environment: 13 | name: pypi 14 | url: https://pypi.org/p/pylemur 15 | permissions: 16 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing 17 | steps: 18 | - uses: actions/checkout@v4 19 | with: 20 | filter: blob:none 21 | fetch-depth: 0 22 | - uses: actions/setup-python@v4 23 | with: 24 | python-version: "3.x" 25 | cache: "pip" 26 | - run: pip install build 27 | - run: python -m build 28 | - name: Publish package distributions to PyPI 29 | uses: pypa/gh-action-pypi-publish@release/v1 30 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | schedule: 9 | - cron: "0 5 1,15 * *" 10 | 11 | concurrency: 12 | group: ${{ github.workflow }}-${{ github.ref }} 13 | cancel-in-progress: true 14 | 15 | jobs: 16 | test: 17 | runs-on: ${{ matrix.os }} 18 | defaults: 19 | run: 20 | shell: bash -e {0} # -e to fail on error 21 | 22 | strategy: 23 | fail-fast: false 24 | matrix: 25 | include: 26 | - os: ubuntu-latest 27 | python: "3.10" 28 | - os: ubuntu-latest 29 | python: "3.12" 30 | - os: ubuntu-latest 31 | python: "3.12" 32 | pip-flags: "--pre" 33 | name: PRE-RELEASE DEPENDENCIES 34 | 35 | name: ${{ matrix.name }} Python ${{ matrix.python }} 36 | 37 | env: 38 | OS: ${{ matrix.os }} 39 | PYTHON: ${{ matrix.python }} 40 | 41 | steps: 42 | - uses: actions/checkout@v3 43 | - name: Set up 
Python ${{ matrix.python }} 44 | uses: actions/setup-python@v4 45 | with: 46 | python-version: ${{ matrix.python }} 47 | cache: "pip" 48 | cache-dependency-path: "**/pyproject.toml" 49 | 50 | - name: Install test dependencies 51 | run: | 52 | python -m pip install --upgrade pip wheel 53 | - name: Install dependencies 54 | run: | 55 | pip install ${{ matrix.pip-flags }} ".[dev,test]" 56 | - name: Test 57 | env: 58 | MPLBACKEND: agg 59 | PLATFORM: ${{ matrix.os }} 60 | DISPLAY: :42 61 | run: | 62 | coverage run -m pytest -v --color=yes 63 | - name: Report coverage 64 | run: | 65 | coverage report 66 | - name: Upload coverage 67 | uses: codecov/codecov-action@v3 68 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by pytest automatically. 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | bin/ 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | eggs/ 15 | lib/ 16 | lib64/ 17 | parts/ 18 | sdist/ 19 | var/ 20 | *.egg-info/ 21 | .installed.cfg 22 | *.egg 23 | 24 | # Installer logs 25 | pip-log.txt 26 | pip-delete-this-directory.txt 27 | 28 | # Unit test / coverage reports 29 | .tox/ 30 | .coverage 31 | .cache 32 | nosetests.xml 33 | coverage.xml 34 | 35 | .pytest_cache/ 36 | 37 | # Temp files 38 | .DS_Store 39 | *~ 40 | buck-out/ 41 | 42 | # Compiled files 43 | .venv/ 44 | .venv_pre_release/ 45 | __pycache__/ 46 | .mypy_cache/ 47 | .ruff_cache/ 48 | 49 | # Distribution / packaging 50 | /build/ 51 | /dist/ 52 | /*.egg-info/ 53 | 54 | # Tests and coverage 55 | /.pytest_cache/ 56 | /.cache/ 57 | /data/ 58 | /node_modules/ 59 | 60 | # docs 61 | /docs/generated/ 62 | /docs/_build/ 63 | 64 | # IDEs 65 | /.idea/ 66 | /.vscode/ 67 | 68 | # Data files from tutorial 69 | docs/notebooks/data/* 70 | 
-------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | fail_fast: false 2 | default_language_version: 3 | python: python3 4 | default_stages: 5 | - commit 6 | - push 7 | minimum_pre_commit_version: 2.16.0 8 | repos: 9 | - repo: https://github.com/pre-commit/mirrors-prettier 10 | rev: v4.0.0-alpha.8 11 | hooks: 12 | - id: prettier 13 | - repo: https://github.com/astral-sh/ruff-pre-commit 14 | rev: v0.3.2 15 | hooks: 16 | - id: ruff 17 | types_or: [python, pyi, jupyter] 18 | args: [--fix, --exit-non-zero-on-fix] 19 | - id: ruff-format 20 | types_or: [python, pyi, jupyter] 21 | - repo: https://github.com/pre-commit/pre-commit-hooks 22 | rev: v4.5.0 23 | hooks: 24 | - id: detect-private-key 25 | - id: check-ast 26 | - id: end-of-file-fixer 27 | - id: mixed-line-ending 28 | args: [--fix=lf] 29 | - id: trailing-whitespace 30 | - id: check-case-conflict 31 | # Check that there are no merge conflicts (could be generated by template sync) 32 | - id: check-merge-conflict 33 | args: [--assume-in-merge] 34 | - repo: local 35 | hooks: 36 | - id: forbid-to-commit 37 | name: Don't commit rej files 38 | entry: | 39 | Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates. 40 | Fix the merge conflicts manually and remove the .rej files. 41 | language: fail 42 | files: '.*\.rej$' 43 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # https://docs.readthedocs.io/en/stable/config-file/v2.html 2 | version: 2 3 | build: 4 | os: ubuntu-20.04 5 | tools: 6 | python: "3.10" 7 | sphinx: 8 | configuration: docs/conf.py 9 | # disable this for more lenient docs builds 10 | fail_on_warning: false 11 | python: 12 | install: 13 | - method: pip 14 | path: . 
15 | extra_requirements: 16 | - doc 17 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog][], 6 | and this project adheres to [Semantic Versioning][]. 7 | 8 | [keep a changelog]: https://keepachangelog.com/en/1.0.0/ 9 | [semantic versioning]: https://semver.org/spec/v2.0.0.html 10 | 11 | ## [Unreleased] 12 | 13 | # [0.3.1] 14 | 15 | - Fix documentation of `cond()` return type and handle pd.Series in 16 | `model.predict()` (#5, thanks Mark Keller) 17 | 18 | ## [0.3.0] 19 | 20 | - Depend on `formulaic_contrast` package 21 | - Refactor `cond()` implementation to use `formulaic_contrast` implementation. 22 | 23 | ## [0.2.2] 24 | 25 | - Sync with cookiecutter-template update (version 0.4) 26 | - Bump required Python version to `3.10` 27 | - Allow data frames as design matrices 28 | - Allow matrices as input to LEMUR() 29 | 30 | ## [0.2.1] 31 | 32 | - Change example gene to one with clearer differential expression pattern 33 | - Remove error output in `align_harmony 34 | 35 | ## [0.2.0] 36 | 37 | Major rewrite of the API. Instead of adding coefficients as custom fields 38 | to the input `AnnData` object, the API now follows an object-oriented style 39 | similar to scikit-learn or `SCVI`. This change was motivated by the feedback 40 | during the submission to the `scverse` ecosystem. 41 | ([Thanks](<(https://github.com/scverse/ecosystem-packages/pull/156#issuecomment-2014676654)>) Gregor). 42 | 43 | ### Changed 44 | 45 | - Instead of calling `fit = pylemur.tl.lemur(adata, ...)`, you now create a LEMUR model 46 | (`model = pylemur.tl.LEMUR(adata, ...)`) and subsequently call `model.fit()`, `model.align_with_harmony()`, 47 | and `model.predict()`. 
48 | 49 | ## [0.1.0] - 2024-03-21 50 | 51 | - Initial beta release of `pyLemur` 52 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024, Your Name 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pyLemur 2 | 3 | [![Tests][badge-tests]][link-tests] 4 | [![Documentation][badge-docs]][link-docs] 5 | 6 | [badge-tests]: https://img.shields.io/github/actions/workflow/status/const-ae/pyLemur/test.yaml?branch=main 7 | [link-tests]: https://github.com/const-ae/pyLemur/actions/workflows/test.yaml 8 | [link-docs]: https://pyLemur.readthedocs.io 9 | [badge-docs]: http://readthedocs.org/projects/pylemur/badge 10 | 11 | The Python implementation of the LEMUR method to analyze multi-condition single-cell data. For the more complete version in R, see [github.com/const-ae/lemur](https://github.com/const-ae/lemur). To learn more check-out the [function documentation](https://pylemur.readthedocs.io/page/api.html) and the [tutorial](https://pylemur.readthedocs.io/page/notebooks/Tutorial.html) at [pylemur.readthedocs.io](https://pylemur.readthedocs.io). To check-out the source code or submit an issue go to [github.com/const-ae/pyLemur](https://github.com/const-ae/pyLemur) 12 | 13 | ## Citation 14 | 15 | > Ahlmann-Eltze C, Huber W (2025). 16 | > “Analysis of multi-condition single-cell data with latent embedding multivariate regression.” Nature Genetics (2025). 17 | > [doi:10.1038/s41588-024-01996-0](https://doi.org/10.1038/s41588-024-01996-0). 18 | 19 | # Getting started 20 | 21 | ## Installation 22 | 23 | You need to have Python 3.10 or newer installed on your system. 
24 | There are several alternative options to install pyLemur: 25 | 26 | Install the latest release of `pyLemur` from [PyPI](https://pypi.org/project/pyLemur/): 27 | 28 | ```bash 29 | pip install pyLemur 30 | ``` 31 | 32 | Alternatively, install the latest development version directly from Github: 33 | 34 | ```bash 35 | pip install git+https://github.com/const-ae/pyLemur.git@main 36 | ``` 37 | 38 | ## Documentation 39 | 40 | For more information on the functions see the [API docs](https://pyLemur.readthedocs.io/page/api.html) and the [tutorial](https://pylemur.readthedocs.io/page/notebooks/Tutorial.html). 41 | 42 | ## Contact 43 | 44 | For questions and help requests, you can reach out in the [scverse discourse][scverse-discourse]. 45 | If you found a bug, please use the [issue tracker][issue-tracker]. 46 | 47 | [scverse-discourse]: https://discourse.scverse.org/ 48 | [issue-tracker]: https://github.com/const-ae/pyLemur/issues 49 | 50 | ## Building 51 | 52 | Install the package in editable mode: 53 | 54 | ``` 55 | pip install ".[dev,doc,test]" 56 | ``` 57 | 58 | Build the documentation locally 59 | 60 | ``` 61 | cd docs 62 | make html 63 | open _build/html/index.html 64 | ``` 65 | 66 | Run the unit tests 67 | 68 | ``` 69 | pytest 70 | ``` 71 | 72 | Run pre-commit hooks manually 73 | 74 | ``` 75 | pre-commit run --all-files 76 | ``` 77 | 78 | or individually 79 | 80 | ``` 81 | ruff check 82 | ``` 83 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | livehtml: 18 | sphinx-autobuild "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 19 | 20 | # Catch-all target: route all unknown targets to Sphinx using the new 21 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 22 | %: Makefile 23 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 24 | -------------------------------------------------------------------------------- /docs/_static/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/const-ae/pyLemur/64b8301e2ec14d9eb948b74c6e3eabcff9b2d63a/docs/_static/.gitkeep -------------------------------------------------------------------------------- /docs/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | /* Reduce the font size in data frames - See https://github.com/scverse/cookiecutter-scverse/issues/193 */ 2 | div.cell_output table.dataframe { 3 | font-size: 0.8em; 4 | } 5 | -------------------------------------------------------------------------------- /docs/_static/images/equation_schematic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/const-ae/pyLemur/64b8301e2ec14d9eb948b74c6e3eabcff9b2d63a/docs/_static/images/equation_schematic.png -------------------------------------------------------------------------------- /docs/_templates/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/const-ae/pyLemur/64b8301e2ec14d9eb948b74c6e3eabcff9b2d63a/docs/_templates/.gitkeep -------------------------------------------------------------------------------- /docs/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname 
| escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. add toctree option to make autodoc generate the pages 6 | 7 | .. autoclass:: {{ objname }} 8 | 9 | {% block attributes %} 10 | {% if attributes %} 11 | Attributes table 12 | ~~~~~~~~~~~~~~~~~~ 13 | 14 | .. autosummary:: 15 | {% for item in attributes %} 16 | ~{{ fullname }}.{{ item }} 17 | {%- endfor %} 18 | {% endif %} 19 | {% endblock %} 20 | 21 | {% block methods %} 22 | {% if methods %} 23 | Methods table 24 | ~~~~~~~~~~~~~ 25 | 26 | .. autosummary:: 27 | {% for item in methods %} 28 | {%- if item != '__init__' %} 29 | ~{{ fullname }}.{{ item }} 30 | {%- endif -%} 31 | {%- endfor %} 32 | {% endif %} 33 | {% endblock %} 34 | 35 | {% block attributes_documentation %} 36 | {% if attributes %} 37 | Attributes 38 | ~~~~~~~~~~~ 39 | 40 | {% for item in attributes %} 41 | 42 | .. autoattribute:: {{ [objname, item] | join(".") }} 43 | {%- endfor %} 44 | 45 | {% endif %} 46 | {% endblock %} 47 | 48 | {% block methods_documentation %} 49 | {% if methods %} 50 | Methods 51 | ~~~~~~~ 52 | 53 | {% for item in methods %} 54 | {%- if item != '__init__' %} 55 | 56 | .. automethod:: {{ [objname, item] | join(".") }} 57 | {%- endif -%} 58 | {%- endfor %} 59 | 60 | {% endif %} 61 | {% endblock %} 62 | -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | # API 2 | 3 | ## Tools 4 | 5 | To create the LEMUR object that provides functionality to fit, align, predict, and transform input data: 6 | 7 | ```{eval-rst} 8 | .. module:: pylemur.tl 9 | .. module:: pylemur.pp 10 | 11 | .. currentmodule:: pylemur 12 | 13 | .. 
autosummary:: 14 | :toctree: generated 15 | 16 | tl.LEMUR 17 | pp.shifted_log_transform 18 | ``` 19 | -------------------------------------------------------------------------------- /docs/changelog.md: -------------------------------------------------------------------------------- 1 | ```{include} ../CHANGELOG.md 2 | 3 | ``` 4 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | import sys 9 | from datetime import datetime 10 | from importlib.metadata import metadata 11 | from pathlib import Path 12 | 13 | HERE = Path(__file__).parent 14 | sys.path.insert(0, str(HERE / "extensions")) 15 | 16 | 17 | # -- Project information ----------------------------------------------------- 18 | 19 | # NOTE: If you installed your project in editable mode, this might be stale. 20 | # If this is the case, reinstall it to refresh the metadata 21 | info = metadata("pyLemur") 22 | project_name = info["Name"] 23 | author = info["Author"] 24 | copyright = f"{datetime.now():%Y}, {author}." 
25 | version = info["Version"] 26 | urls = dict(pu.split(", ") for pu in info.get_all("Project-URL")) 27 | repository_url = urls["Source"] 28 | 29 | # The full version, including alpha/beta/rc tags 30 | release = info["Version"] 31 | 32 | bibtex_bibfiles = ["references.bib"] 33 | bibtex_default_style = "unsrt" 34 | 35 | templates_path = ["_templates"] 36 | nitpicky = True # Warn about broken links 37 | needs_sphinx = "4.0" 38 | 39 | html_context = { 40 | "display_github": True, # Integrate GitHub 41 | "github_user": "const-ae", 42 | "github_repo": "https://github.com/const-ae/pyLemur", 43 | "github_version": "main", 44 | "conf_py_path": "/docs/", 45 | } 46 | 47 | # -- General configuration --------------------------------------------------- 48 | 49 | # Add any Sphinx extension module names here, as strings. 50 | # They can be extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 51 | extensions = [ 52 | "myst_nb", 53 | "sphinx_copybutton", 54 | "sphinx.ext.autodoc", 55 | "sphinx.ext.intersphinx", 56 | "sphinx.ext.autosummary", 57 | "sphinx.ext.napoleon", 58 | "sphinx_autodoc_typehints", 59 | "sphinxcontrib.bibtex", 60 | "sphinx_autodoc_typehints", 61 | "sphinx.ext.mathjax", 62 | "sphinx_design", 63 | "IPython.sphinxext.ipython_console_highlighting", 64 | "sphinxext.opengraph", 65 | *[p.stem for p in (HERE / "extensions").glob("*.py")], 66 | ] 67 | 68 | autosummary_generate = True 69 | autodoc_member_order = "groupwise" 70 | default_role = "literal" 71 | napoleon_google_docstring = False 72 | napoleon_numpy_docstring = True 73 | napoleon_include_init_with_doc = False 74 | napoleon_use_rtype = True # having a separate entry generally helps readability 75 | napoleon_use_param = True 76 | napoleon_use_ivar = True 77 | 78 | 79 | myst_heading_anchors = 6 # create anchors for h1-h6 80 | myst_enable_extensions = [ 81 | "amsmath", 82 | "colon_fence", 83 | "deflist", 84 | "dollarmath", 85 | "html_image", 86 | "html_admonition", 87 | ] 88 | 
myst_url_schemes = ("http", "https", "mailto") 89 | nb_output_stderr = "remove" 90 | nb_execution_mode = "cache" 91 | nb_merge_streams = True 92 | typehints_defaults = "braces" 93 | 94 | source_suffix = { 95 | ".rst": "restructuredtext", 96 | ".ipynb": "myst-nb", 97 | ".myst": "myst-nb", 98 | } 99 | 100 | intersphinx_mapping = { 101 | "python": ("https://docs.python.org/3", None), 102 | "anndata": ("https://anndata.readthedocs.io/en/stable/", None), 103 | "numpy": ("https://numpy.org/doc/stable/", None), 104 | "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None), 105 | "scipy": ("https://docs.scipy.org/doc/scipy/", None), 106 | "sklearn": ("https://scikit-learn.org/stable/", None), 107 | # "formulaic": ("https://matthewwardrop.github.io/formulaic/", None) # Doesn't work. 108 | } 109 | 110 | 111 | # List of patterns, relative to source directory, that match files and 112 | # directories to ignore when looking for source files. 113 | # This pattern also affects html_static_path and html_extra_path. 114 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"] 115 | 116 | 117 | # -- Options for HTML output ------------------------------------------------- 118 | 119 | # The theme to use for HTML and HTML Help pages. See the documentation for 120 | # a list of builtin themes. 121 | # 122 | html_theme = "sphinx_book_theme" 123 | html_static_path = ["_static"] 124 | html_css_files = ["css/custom.css"] 125 | 126 | html_title = project_name 127 | 128 | html_theme_options = { 129 | "repository_url": repository_url, 130 | "use_repository_button": True, 131 | "path_to_docs": "docs/", 132 | "navigation_with_keys": False, 133 | } 134 | 135 | pygments_style = "default" 136 | 137 | nitpick_ignore = [ 138 | # If building the documentation fails because of a missing link that is outside your control, 139 | # you can add an exception to this list. 
140 | # ("py:class", "igraph.Graph"), 141 | ] 142 | -------------------------------------------------------------------------------- /docs/extensions/typed_returns.py: -------------------------------------------------------------------------------- 1 | # code from https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py 2 | # with some minor adjustment 3 | from __future__ import annotations 4 | 5 | import re 6 | from collections.abc import Generator, Iterable 7 | 8 | from sphinx.application import Sphinx 9 | from sphinx.ext.napoleon import NumpyDocstring 10 | 11 | 12 | def _process_return(lines: Iterable[str]) -> Generator[str, None, None]: 13 | for line in lines: 14 | if m := re.fullmatch(r"(?P\w+)\s+:\s+(?P[\w.]+)", line): 15 | yield f'-{m["param"]} (:class:`~{m["type"]}`)' 16 | else: 17 | yield line 18 | 19 | 20 | def _parse_returns_section(self: NumpyDocstring, section: str) -> list[str]: 21 | lines_raw = self._dedent(self._consume_to_next_section()) 22 | if lines_raw[0] == ":": 23 | del lines_raw[0] 24 | lines = self._format_block(":returns: ", list(_process_return(lines_raw))) 25 | if lines and lines[-1]: 26 | lines.append("") 27 | return lines 28 | 29 | 30 | def setup(app: Sphinx): 31 | """Set app.""" 32 | NumpyDocstring._parse_returns_section = _parse_returns_section 33 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | ```{include} ../README.md 2 | 3 | ``` 4 | 5 | ```{toctree} 6 | :hidden: true 7 | :maxdepth: 1 8 | 9 | api.md 10 | notebooks/Tutorial.myst 11 | changelog.md 12 | references.md 13 | 14 | ``` 15 | -------------------------------------------------------------------------------- /docs/notebooks/Tutorial.myst: -------------------------------------------------------------------------------- 1 | --- 2 | file_format: mystnb 3 | mystnb: 4 | execution_timeout: 600 5 | kernelspec: 6 | name: python3 7 
| --- 8 | 9 | # pyLemur Walkthrough 10 | 11 | 12 | The goal of `pyLemur` is to simplify the analysis of multi-condition single-cell data. If you have collected a single-cell RNA-seq dataset with more than one condition, LEMUR predicts for each cell and gene how much the expression would change if the cell had been in the other condition. 13 | 14 | `pyLemur` is a Python implementation of the LEMUR model; there is also an `R` package called [lemur](https://bioconductor.org/packages/lemur/), which provides additional functionality: identifying neighborhoods of cells that show consistent differential expression values and a pseudo-bulk test to validate the findings. 15 | 16 | `pyLemur` implements a novel framework to disentangle the effects of known covariates, latent cell states, and their interactions. At the core is a combination of matrix factorization and regression analysis implemented as geodesic regression on Grassmann manifolds. We call this latent embedding multivariate regression (LEMUR). For more details, see our [preprint](https://www.biorxiv.org/content/10.1101/2023.03.06.531268) {cite:p}`Ahlmann-Eltze2024`. 17 | 18 | Schematic of the matrix decomposition at the core of LEMUR 19 | 20 | 21 | ## Data 22 | 23 | For demonstration, I will use a dataset of interferon-$\beta$ stimulated blood cells from {cite:t}`kang2018`. 24 | 25 | ```{code-cell} ipython3 26 | --- 27 | output_stderr: remove 28 | --- 29 | # Standard imports 30 | import numpy as np 31 | import scanpy as sc 32 | # pertpy is needed to download the Kang data 33 | import pertpy 34 | 35 | # This will download the data to ./data/kang_2018.h5ad 36 | adata = pertpy.data.kang_2018() 37 | # Store counts separately in the layers 38 | adata.layers["counts"] = adata.X.copy() 39 | ``` 40 | 41 | The data consists of $24\,673$ cells and $15\,706$ genes. The cells were measured in two conditions (`label="ctrl"` and `label="stim"`). 
The authors have annotated the cell type for each cell, which will be useful to analyze LEMUR's results; however, note that the cell type labels are not used (and not needed) to fit the LEMUR model. 42 | 43 | ```{code-cell} ipython3 44 | :tags: ["remove-cell"] 45 | import pandas as pd 46 | pd.options.display.width = 200 47 | pd.options.display.max_colwidth = 20 48 | ``` 49 | 50 | ```{code-cell} ipython3 51 | print(adata) 52 | print(adata.obs) 53 | ``` 54 | 55 | ## Preprocessing 56 | 57 | LEMUR expects that the input has been variance-stabilized. Here, I will use the log-transformation as a simple, yet effective approach. 58 | In addition, I will only work on the $1\,000$ most variable genes to make the results easier to manage. 59 | ```{code-cell} ipython3 60 | # This follows the standard recommendation from scanpy 61 | sc.pp.normalize_total(adata, target_sum = 1e4, inplace=True) 62 | sc.pp.log1p(adata) 63 | adata.layers["logcounts"] = adata.X.copy() 64 | sc.pp.highly_variable_genes(adata, n_top_genes=1000, flavor="cell_ranger") 65 | adata = adata[:, adata.var.highly_variable] 66 | adata 67 | ``` 68 | 69 | If we make a 2D plot of the data using UMAP, we see that the cell types separate by treatment status. 70 | ```{code-cell} ipython3 71 | sc.tl.pca(adata) 72 | sc.pp.neighbors(adata) 73 | sc.tl.umap(adata) 74 | sc.pl.umap(adata, color=["label", "cell_type"]) 75 | ``` 76 | 77 | 78 | ## LEMUR 79 | 80 | First, we import `pyLemur`; then, we fit the LEMUR model by providing the `AnnData` object, a specification of the experimental design, and the number of latent dimensions. 81 | 82 | ```{code-cell} ipython3 83 | import pylemur 84 | model = pylemur.tl.LEMUR(adata, design = "~ label", n_embedding=15) 85 | model.fit() 86 | model.align_with_harmony() 87 | print(model) 88 | ``` 89 | 90 | To assess if the model was fit successfully, we plot a UMAP representation of the 15-dimensional embedding calculated by LEMUR. 
We want to see that the two conditions are well mixed in the embedding space because that means that LEMUR was able to disentangle the treatment effect from the cell type effect and that the residual variation is driven by the cell states. 91 | ```{code-cell} ipython3 92 | # Recalculate the UMAP on the embedding calculated by LEMUR 93 | adata.obsm["embedding"] = model.embedding 94 | sc.pp.neighbors(adata, use_rep="embedding") 95 | sc.tl.umap(adata) 96 | sc.pl.umap(adata, color=["label", "cell_type"]) 97 | ``` 98 | 99 | The LEMUR model is fully parametric, which means that we can predict for each cell what its expression would have been in any condition (i.e., for a cell observed in the control condition, we can predict its expression under treatment) as a function of its low-dimensional embedding. 100 | 101 | ```{code-cell} ipython3 102 | # The model.cond(**kwargs) call specifies the condition for the prediction 103 | ctrl_pred = model.predict(new_condition=model.cond(label="ctrl")) 104 | stim_pred = model.predict(new_condition=model.cond(label="stim")) 105 | ``` 106 | 107 | We can now check the predicted differential expression against the underlying observed expression patterns for individual genes. Here, I chose _TSC22D3_ as an example. The blue cells in the first plot are in neighborhoods with higher expression in the control condition than in the stimulated condition. The two other plots show the underlying gene expression for the control and stimulated cells and confirm LEMUR's inference. 
108 | ```{code-cell} ipython3 109 | import matplotlib.pyplot as plt 110 | adata.layers["diff"] = stim_pred - ctrl_pred 111 | # Also try CXCL10, IL8, and FBXO40 112 | sel_gene = "TSC22D3" 113 | 114 | fsize = plt.rcParams['figure.figsize'] 115 | fig = plt.figure(figsize=(fsize[0] * 3, fsize[1])) 116 | axs = [fig.add_subplot(1, 3, i+1) for i in range(3)] 117 | for ax in axs: 118 | ax.set_aspect('equal') 119 | sc.pl.umap(adata, layer="diff", color=[sel_gene], cmap = plt.get_cmap("seismic"), vcenter=0, 120 | vmin=-4, vmax=4, title="Pred diff (stim - ctrl)", ax=axs[0], show=False) 121 | sc.pl.umap(adata[adata.obs["label"]=="ctrl"], layer="logcounts", color=[sel_gene], vmin = 0, vmax =4, 122 | title="Ctrl expr", ax=axs[1], show=False) 123 | sc.pl.umap(adata[adata.obs["label"]=="stim"], layer="logcounts", color=[sel_gene], vmin = 0, vmax =4, 124 | title="Stim expr", ax=axs[2]) 125 | ``` 126 | 127 | To assess the overall accuracy of LEMUR's predictions, I will compare the average observed and predicted expression per cell type between conditions. The next plot simply shows the observed expression values. 
def rowMeans_per_group(X, group):
    """Compute the per-group mean of each column of ``X``.

    Rows of ``X`` are grouped by the parallel label vector ``group``;
    the result has one row per unique label (in ``np.unique`` order)
    and the same number of columns as ``X``.
    """
    labels = np.unique(group)
    out = np.zeros((len(labels), X.shape[1]))
    for idx, lab in enumerate(labels):
        mask = group == lab
        out[idx, :] = X[mask, :].sum(axis=0) / mask.sum()
    return out
I will demonstrate this by training the same LEMUR model on 5% of the original data, then `transform` the full data, and finally compare the first three dimensions of the embedding against the embedding from the model trained on the full data.
174 | 175 | ```{code-cell} ipython3 176 | adata_subset = adata[np.random.choice(np.arange(adata.shape[0]), size = round(adata.shape[0] * 0.05)),] 177 | model_small = pylemur.tl.LEMUR(adata_subset, design = "~ label", n_embedding=15) 178 | model_small.fit().align_with_harmony() 179 | emb_proj = model_small.transform(adata) 180 | plt.scatter(emb_proj[:,0:3], model.embedding[:,0:3], s = 0.1) 181 | plt.axline((0, 0), (1, 1), linewidth=1, color='black') 182 | plt.axline((0, 0), (-1, 1), linewidth=1, color='black') 183 | ``` 184 | 185 | We see that the small model still captures most of the relevant variation. 186 | ```{code-cell} ipython3 187 | adata.obsm["embedding_from_small_fit"] = emb_proj 188 | sc.pp.neighbors(adata, use_rep="embedding_from_small_fit") 189 | sc.tl.umap(adata) 190 | sc.pl.umap(adata, color=["label", "cell_type"]) 191 | ``` 192 | 193 | ### Session Info 194 | 195 | ```{code-cell} ipython3 196 | import session_info 197 | session_info.show() 198 | ``` 199 | 200 | 201 | ### References 202 | 203 | ```{bibliography} 204 | :style: plain 205 | :filter: docname in docnames 206 | ``` 207 | -------------------------------------------------------------------------------- /docs/references.bib: -------------------------------------------------------------------------------- 1 | @article{Virshup_2023, 2 | doi = {10.1038/s41587-023-01733-8}, 3 | url = {https://doi.org/10.1038%2Fs41587-023-01733-8}, 4 | year = 2023, 5 | month = {apr}, 6 | publisher = {Springer Science and Business Media {LLC}}, 7 | author = {Isaac Virshup and Danila Bredikhin and Lukas Heumos and Giovanni Palla and Gregor Sturm and Adam Gayoso and Ilia Kats and Mikaela Koutrouli and Philipp Angerer and Volker Bergen and Pierre Boyeau and Maren Büttner and Gokcen Eraslan and David Fischer and Max Frank and Justin Hong and Michal Klein and Marius Lange and Romain Lopez and Mohammad Lotfollahi and Malte D. Luecken and Fidel Ramirez and Jeffrey Regier and Sergei Rybakov and Anna C. 
Schaar and Valeh Valiollah Pour Amiri and Philipp Weiler and Galen Xing and Bonnie Berger and Dana Pe'er and Aviv Regev and Sarah A. Teichmann and Francesca Finotello and F. Alexander Wolf and Nir Yosef and Oliver Stegle and Fabian J. Theis and}, 8 | title = {The scverse project provides a computational ecosystem for single-cell omics data analysis}, 9 | journal = {Nature Biotechnology} 10 | } 11 | @article{Ahlmann-Eltze2024, 12 | title = {Analysis of multi-condition single-cell data with latent embedding multivariate regression}, 13 | author = {Ahlmann-Eltze, Constantin and Huber, Wolfgang}, 14 | year = {2024}, 15 | month = {02}, 16 | url = {http://dx.doi.org/10.1101/2023.03.06.531268}, 17 | journal = {bioRxiv}, 18 | } 19 | @article{kang2018, 20 | title = {Multiplexed droplet single-cell RNA-sequencing using natural genetic variation}, 21 | author = {Kang, Hyun Min and Subramaniam, Meena and Targ, Sasha and Nguyen, Michelle and Maliskova, Lenka and McCarthy, Elizabeth and Wan, Eunice and Wong, Simon and Byrnes, Lauren and Lanata, Cristina M and Gate, Rachel E and Mostafavi, Sara and Marson, Alexander and Zaitlen, Noah and Criswell, Lindsey A and Ye, Chun Jimmie}, 22 | year = {2018}, 23 | month = {01}, 24 | date = {2018-01}, 25 | journal = {Nature Biotechnology}, 26 | pages = {89--94}, 27 | volume = {36}, 28 | number = {1}, 29 | doi = {10.1038/nbt.4042}, 30 | url = {http://www.nature.com/articles/nbt.4042}, 31 | langid = {en} 32 | } 33 | -------------------------------------------------------------------------------- /docs/references.md: -------------------------------------------------------------------------------- 1 | # References 2 | 3 | ```{bibliography} 4 | :cited: 5 | ``` 6 | -------------------------------------------------------------------------------- /notebooks/check_implementation.qmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Compare implementation directly against R" 3 | author: Constantin 
Ahlmann-Eltze 4 | format: 5 | html: 6 | code-fold: false 7 | embed-resources: true 8 | highlight-style: github 9 | toc: true 10 | code-line-numbers: true 11 | execute: 12 | keep-ipynb: true 13 | jupyter: python3 14 | --- 15 | 16 | 17 | ```{python} 18 | %load_ext autoreload 19 | %autoreload 2 20 | ``` 21 | 22 | 23 | ```{python} 24 | import debugpy 25 | debugpy.listen(5678) 26 | print("Waiting for debugger attach") 27 | debugpy.wait_for_client() 28 | ``` 29 | 30 | 31 | ```{python} 32 | import numpy as np 33 | import scanpy as sc 34 | import pertpy 35 | 36 | adata = pertpy.data.kang_2018() 37 | adata.layers["counts"] = adata.X.copy() 38 | ``` 39 | 40 | 41 | ```{python} 42 | sc.pp.normalize_total(adata, target_sum = 1e4, inplace=True) 43 | sc.pp.log1p(adata) 44 | adata.layers["logcounts"] = adata.X.copy() 45 | sc.pp.highly_variable_genes(adata, n_top_genes=1000, flavor="cell_ranger") 46 | adata = adata[:, adata.var.highly_variable] 47 | adata 48 | ``` 49 | 50 | 51 | ```{python} 52 | import pylemur 53 | model = pylemur.tl.LEMUR(adata, design = "~ label", n_embedding=15, layer = "logcounts") 54 | model.fit() 55 | model.align_with_harmony() 56 | print(model) 57 | ``` 58 | 59 | 60 | ```{python} 61 | ctrl_pred = model.predict(new_condition=model.cond(label="ctrl")) 62 | model.cond(label = "stim") - model.cond(label = "ctrl") 63 | ``` 64 | 65 | ```{python} 66 | model.embedding.shape 67 | model.adata 68 | ``` 69 | 70 | ```{python} 71 | # groups = np.array([np.nan] * fit.shape[0]) 72 | # groups[fit.obs["cell_type"] == "CD4 T cells"] = 1 73 | # groups[fit.obs["cell_type"] == "NK cells"] = 2 74 | # groups[fit.obs["cell_type"] == "Dendritic cells"] = 3 75 | import pandas as pd 76 | groups = fit.obs["cell_type"] 77 | groups[np.array([0,9,99])] = pd.NA 78 | 79 | np.unique(groups.to_numpy()) 80 | ``` 81 | 82 | ```{python} 83 | fit2 = fit.copy() 84 | fit2 = pylemur.tl.align_with_grouping(fit2, groups) 85 | ``` 86 | 87 | 88 | ```{python} 89 | fit2.obsm["embedding"][0:3, 0:3].T 90 | 
fit2.obsm["embedding"][24600:24603,0:3].T 91 | ``` 92 | 93 | ```{python} 94 | pred_ctrl = pylemur.tl.predict(fit2, new_condition=pylemur.tl.cond(fit2, label = "ctrl")) 95 | pred_ctrl[0:3,0:3].T 96 | ``` 97 | 98 | 99 | ```{python} 100 | np.savetxt("/var/folders/dc/tppjxs9x6ll378lq88lz1fm40000gq/T//Rtmpm5PTRh/pred_ctrl.tsv", pred_ctrl, delimiter="\t") 101 | np.savetxt("/var/folders/dc/tppjxs9x6ll378lq88lz1fm40000gq/T//Rtmpm5PTRh/embedding.tsv", fit2.obsm["embedding"], delimiter="\t") 102 | ``` 103 | -------------------------------------------------------------------------------- /notebooks/devel_experiments.qmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Quick check if new functions work as expected" 3 | author: Constantin Ahlmann-Eltze 4 | format: 5 | html: 6 | code-fold: false 7 | embed-resources: true 8 | highlight-style: github 9 | toc: true 10 | code-line-numbers: true 11 | execute: 12 | keep-ipynb: true 13 | jupyter: python3 14 | --- 15 | 16 | 17 | ```{python} 18 | %load_ext autoreload 19 | %autoreload 2 20 | ``` 21 | 22 | 23 | ```{python} 24 | import debugpy 25 | debugpy.listen(5678) 26 | print("Waiting for debugger attach") 27 | debugpy.wait_for_client() 28 | ``` 29 | 30 | Convert formula to design matrix in Python 31 | 32 | ```{python} 33 | import numpy as np 34 | from patsy import dmatrices, dmatrix, demo_data 35 | ``` 36 | 37 | 38 | ```{python} 39 | data = demo_data("a", "b", "x1", "x2", "y", "z column", min_rows = 400) 40 | data 41 | ``` 42 | 43 | ```{python} 44 | form = dmatrix("~ a + x1", data) 45 | ``` 46 | 47 | 48 | ```{python} 49 | from pyLemur.handle_design import * 50 | from pyLemur.row_groups import * 51 | ``` 52 | 53 | ```{python} 54 | des, form = convert_formula_to_design_matrix("~ a + x1", data) 55 | ``` 56 | 57 | 58 | ```{python} 59 | row_groups(des) 60 | ``` 61 | 62 | 63 | ```{python} 64 | # 400 cells, 30 features 65 | Y = np.random.rand(400, 30) 66 | ``` 67 | 68 | 69 | ```{python} 70 
| import sklearn.decomposition as skd 71 | 72 | pca_fit = PCA(n_components=2) 73 | emb = pca_fit.fit_transform(Y) 74 | pca_fit.score(Y) 75 | pca_fit.mean_ 76 | ``` 77 | 78 | 79 | ```{python} 80 | svd = skd.TruncatedSVD(n_components=2) 81 | 82 | ``` 83 | 84 | ```{python} 85 | from pyLemur.pca import * 86 | 87 | pca = fit_pca(Y, 2, center=False) 88 | pca 89 | 90 | ``` 91 | 92 | 93 | ```{python} 94 | from pyLemur.lin_alg_wrappers import * 95 | from pyLemur.design_matrix_utils import * 96 | 97 | data = demo_data("a", "b", "x1", "x2", "y", "z column", min_rows = 400) 98 | des, form = convert_formula_to_design_matrix("~ a", data) 99 | Y = np.random.rand(400, 30) 100 | beta = ridge_regression(Y, des, 0) 101 | beta 102 | ``` 103 | 104 | ```{python} 105 | np.unique(des, return_counts=True)[1] 106 | ``` 107 | 108 | ```{python} 109 | from pyLemur.grassmann_lm import * 110 | 111 | base_point = fit_pca(Y, n = 3, center = False).coord_system 112 | V = grassmann_lm(Y, des, base_point) 113 | V.shape 114 | 115 | ``` 116 | 117 | ```{python} 118 | np.hstack([[1,2,3], [4,5,6]]).shape 119 | ``` 120 | 121 | 122 | ```{python} 123 | # The python equivalent of R's seq(-3, 4, length.out = 18) 124 | # Y = np.linspace(-3, 4, 18).reshape((6,3)) 125 | # des = np.hstack([np.ones((6,1)), np.array([1,0,1,0,1,0]).reshape((6,1))]) 126 | # Read csv file into numpy array 127 | import pandas as pd 128 | Y = pd.read_csv("/var/folders/dc/tppjxs9x6ll378lq88lz1fm40000gq/T//Rtmp89jWYn/file128665ac82720").to_numpy().T 129 | des = pd.read_csv("/var/folders/dc/tppjxs9x6ll378lq88lz1fm40000gq/T//Rtmp89jWYn/file12866a80d470").to_numpy() 130 | base_point = fit_pca(Y, n = 3, center = False).coord_system 131 | V = grassmann_lm(Y, des, base_point) 132 | V 133 | ``` 134 | -------------------------------------------------------------------------------- /notebooks/kang_analysis_in_R.qmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Kang analysis with R" 3 | format: 
html 4 | --- 5 | 6 | 7 | ```{r} 8 | library(tidyverse) 9 | library(SingleCellExperiment) 10 | ``` 11 | 12 | 13 | ```{r} 14 | sce <- zellkonverter::readH5AD("data/kang_2018.h5ad", X_name = "counts") 15 | ``` 16 | 17 | ```{r} 18 | logcounts(sce) <- transformGamPoi::shifted_log_transform(counts(sce)) 19 | hvgs <- order(-rowVars(logcounts(sce))) 20 | sce <- sce[hvgs[1:1000],] 21 | ``` 22 | 23 | ```{r} 24 | system.time({ 25 | fit <- lemur::lemur(sce, design = ~ label, n_embedding = 15, verbose = TRUE, test_fraction = 0) 26 | }) 27 | 28 | fit$embedding[1:3, 1:3] 29 | fit$embedding[1:3, 24601:24603] 30 | 31 | 32 | # fit <- lemur::align_harmony(fit) 33 | 34 | # groups <- rep(NA, ncol(fit)) 35 | # groups[fit$colData$cell_type == "CD4 T cells"] <- 1 36 | # groups[fit$colData$cell_type == "NK cells"] <- 2 37 | # groups[fit$colData$cell_type == "Dendritic cells"] <- 3 38 | groups <- fit$colData$cell_type 39 | groups[c(1,10,100)] <- NA 40 | 41 | fit <- lemur::align_by_grouping(fit, groups) 42 | 43 | fit$embedding[1:3, 1:3] 44 | fit$embedding[1:3, 24601:24603] 45 | 46 | 47 | pred_ctrl <- predict(fit, newcondition = cond(label = "ctrl")) 48 | pred_ctrl[1:3, 1:3] 49 | ``` 50 | 51 | 52 | ```{r} 53 | py_pred <- t(as.matrix(readr::read_tsv("/var/folders/dc/tppjxs9x6ll378lq88lz1fm40000gq/T//Rtmpm5PTRh/pred_ctrl.tsv", col_names = FALSE))) 54 | dimnames(py_pred) <- dimnames(fit) 55 | plot(py_pred[1,], pred_ctrl[1,]) 56 | ``` 57 | 58 | 59 | ```{r} 60 | py_emb <- t(as.matrix(readr::read_tsv("/var/folders/dc/tppjxs9x6ll378lq88lz1fm40000gq/T//Rtmpm5PTRh/embedding.tsv", col_names = FALSE))) 61 | dimnames(py_emb) <- dimnames(fit$embedding) 62 | plot(py_emb[1,], fit$embedding[1,]) 63 | ``` 64 | 65 | ```{r} 66 | as_tibble(colData(fit)) %>% 67 | mutate(r_emb = t(fit$embedding), 68 | py_emb = t(py_emb)) %>% 69 | pivot_longer(ends_with("emb"), names_sep = "_", names_to = c("origin", ".value")) %>% 70 | ggplot(aes(x = emb[,4], y = emb[,15])) + 71 | geom_point(aes(color = label)) + 72 | 
facet_wrap(vars(origin)) 73 | 74 | ``` 75 | -------------------------------------------------------------------------------- /notebooks/lemur_model_experiments.qmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Play with scanpy" 3 | author: Constantin Ahlmann-Eltze 4 | format: 5 | html: 6 | code-fold: false 7 | embed-resources: true 8 | highlight-style: github 9 | toc: true 10 | code-line-numbers: true 11 | execute: 12 | keep-ipynb: true 13 | jupyter: python3 14 | --- 15 | 16 | 17 | ```{python} 18 | %load_ext autoreload 19 | %autoreload 2 20 | ``` 21 | 22 | 23 | ```{python} 24 | import debugpy 25 | debugpy.listen(5678) 26 | print("Waiting for debugger attach") 27 | debugpy.wait_for_client() 28 | ``` 29 | 30 | ```{python} 31 | import numpy as np 32 | import scanpy as sc 33 | import pertpy 34 | import scanpy.preprocessing._simple as simple 35 | 36 | adata = pertpy.data.kang_2018() 37 | adata.layers["counts"] = adata.X.copy() 38 | sf = np.array(adata.layers["counts"].sum(axis=1)) 39 | sf = sf / np.median(sf) 40 | adata.layers["logcounts"] = sc.pp.log1p(adata.layers["counts"] / sf) 41 | var = simple._get_mean_var(adata.layers["logcounts"])[1] 42 | hvgs = var.argpartition(-1000)[-1000:] 43 | adata = adata[:, hvgs] 44 | ``` 45 | 46 | ```{python} 47 | import numpy as np 48 | n_cells = 400 49 | n_genes = 100 50 | mat = np.arange(n_cells * n_genes).reshape((n_cells, n_genes)) 51 | log_mat = shifted_log_transform(mat) 52 | ``` 53 | 54 | 55 | ```{python} 56 | # shifted log transformation ala transformGamPoi 57 | 58 | ``` 59 | 60 | 61 | ```{python} 62 | adata.X = adata.layers["logcounts"] 63 | sc.pp.pca(adata) 64 | sc.pp.neighbors(adata) 65 | sc.tl.umap(adata) 66 | sc.pl.umap(adata, color=["label", "cell_type"]) 67 | ``` 68 | 69 | ```{python} 70 | from pylemur.tl import lemur 71 | 72 | fit = lemur(adata, design = ["label"]) 73 | ``` 74 | 75 | ```{python} 76 | sc.pp.neighbors(fit, use_rep="embedding") 77 | sc.tl.umap(fit) 
78 | sc.pl.umap(fit, color=["label", "cell_type"]) 79 | ``` 80 | 81 | 82 | ```{python} 83 | from pylemur.tl import align_with_harmony 84 | align_with_harmony(fit, ridge_penalty = 0.01) 85 | ``` 86 | 87 | 88 | ```{python} 89 | nei = sc.pp.neighbors(fit, use_rep="embedding") 90 | sc.tl.umap(fit) 91 | sc.pl.umap(fit, color=["label", "cell_type"]) 92 | ``` 93 | 94 | 95 | ```{python} 96 | import matplotlib.pyplot as plt 97 | plt.scatter(fit.obsm["new_embedding"], fit.obsm["embedding"]) 98 | ``` 99 | 100 | ```{python} 101 | from pylemur.tl import predict, cond 102 | pred_ctrl = predict(fit, new_condition = cond(fit, label = "ctrl")) 103 | pred_stim = predict(fit, new_condition = cond(fit, label = "stim")) 104 | delta2 = pred_stim - pred_ctrl 105 | ``` 106 | 107 | ```{python} 108 | import matplotlib.pyplot as plt 109 | plt.scatter(delta[:,0], delta2[:,0]) 110 | ``` 111 | 112 | ```{python} 113 | gene = 0 114 | fit.obs["delta"] = delta[:,gene] 115 | fit.obs["delta2"] = delta2[:,gene] 116 | fit.obs["expr"] = fit.layers["logcounts"][:,gene].toarray() 117 | 118 | import matplotlib.pyplot as plt 119 | sc.pl.umap(fit, color=["delta", "delta2"], cmap = plt.get_cmap("seismic"), vcenter=0) 120 | 121 | sc.pl.umap(fit[fit.obs["label"] == "ctrl"], color="expr", vmax = 1) 122 | sc.pl.umap(fit[fit.obs["label"] == "stim"], color="expr", vmax = 1) 123 | ``` 124 | 125 | ```{python} 126 | # fit.obs["plot_label"] = fit.obs["label"].astype(str) + "-" + fit.obs["cell_type"].astype(str) 127 | fit.obs["plot_label"] = fit.obs["label"].astype(str) + "-" + fit.obs["cell_type"].astype(str) 128 | sc.pl.violin(fit, groupby="plot_label", keys="expr") 129 | ``` 130 | -------------------------------------------------------------------------------- /notebooks/scanpy_experiments.qmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Play with scanpy" 3 | author: Constantin Ahlmann-Eltze 4 | format: 5 | html: 6 | code-fold: false 7 | embed-resources: true 8 
| highlight-style: github 9 | toc: true 10 | code-line-numbers: true 11 | execute: 12 | keep-ipynb: true 13 | jupyter: python3 14 | --- 15 | 16 | 17 | ```{python} 18 | %load_ext autoreload 19 | %autoreload 2 20 | ``` 21 | 22 | 23 | ```{python} 24 | import debugpy 25 | debugpy.listen(5678) 26 | print("Waiting for debugger attach") 27 | debugpy.wait_for_client() 28 | ``` 29 | 30 | ```{python} 31 | import numpy as np 32 | import scanpy as sc 33 | import pertpy 34 | 35 | adata = pertpy.data.kang_2018() 36 | adata 37 | ``` 38 | 39 | ```{python} 40 | adata.obs 41 | adata.var 42 | ``` 43 | 44 | ```{python} 45 | np.unique(adata.X.sum(axis=0) > 5, axis = 1, return_counts=True) 46 | ``` 47 | 48 | ```{python} 49 | adata.X[0:100,0:10] 50 | ``` 51 | 52 | 53 | ```{python} 54 | adata.layers["counts"] = adata.X.copy() 55 | sf = np.array(adata.layers["counts"].sum(axis=1)) 56 | sf = sf / np.median(sf) 57 | adata.layers["logcounts"] = sc.pp.log1p(adata.layers["counts"] / sf) 58 | ``` 59 | 60 | ```{python} 61 | import scanpy.preprocessing._simple as simple 62 | var = simple._get_mean_var(adata.layers["logcounts"])[1] 63 | hvgs = var.argpartition(-1000)[-1000:] 64 | adata = adata[:, hvgs] 65 | ``` 66 | 67 | 68 | ```{python} 69 | adata.X = adata.layers["logcounts"] 70 | sc.pp.pca(adata) 71 | sc.pp.neighbors(adata) 72 | sc.tl.umap(adata) 73 | sc.pl.umap(adata, color=["label", "cell_type"]) 74 | ``` 75 | 76 | ```{python} 77 | # Shuffle the rows of adata 78 | adata = adata[np.random.permutation(adata.obs.index), :] 79 | ``` 80 | 81 | ```{python} 82 | import formulaic 83 | import pandas as pd 84 | df = pd.DataFrame(adata.obs) 85 | 86 | des = formulaic.model_matrix("~ label", df) 87 | form = formulaic.Formula("~ label") 88 | ``` 89 | 90 | ```{python} 91 | import patsy 92 | des2 = patsy.dmatrix("~ label", df) 93 | des2 94 | ``` 95 | 96 | ```{python} 97 | from pyLemur.lemur import lemur 98 | 99 | extra_data = {"test": np.random.randint(3, size=adata.shape[0]), 100 | "cat": ["ABC"[x] for x in 
np.random.randint(3, size=adata.shape[0])]} 101 | fit = lemur(adata, design = ["label"], obs_data=extra_data) 102 | ``` 103 | 104 | ```{python} 105 | sc.pp.neighbors(fit, use_rep="embedding") 106 | sc.tl.umap(fit) 107 | sc.pl.umap(fit, color=["label", "cat", "cell_type"]) 108 | ``` 109 | 110 | ```{python} 111 | coord = fit_pca(adata.layers["logcounts"].toarray(), n = 15, center = False).coord_system 112 | coord2 = fit_pca(adata.layers["logcounts"].toarray(), n = 15, center = False).coord_system 113 | print(grassmann_angle_from_point(coord2.T, coord.T)) 114 | ``` 115 | 116 | 117 | ```{python} 118 | from pyLemur.lin_alg_wrappers import * 119 | from pyLemur.grassmann import * 120 | 121 | fit = lemur(adata, design = "~ label", linear_coefficient_estimator="zero") 122 | V_slice1 = fit.uns["lemur"]["coefficients"][:,:,0] 123 | V_slice2 = fit.uns["lemur"]["coefficients"][:,:,1] 124 | 125 | coord_all = fit_pca(adata.X.toarray(), n = 15, center = False).coord_system 126 | coord_ctrl = fit_pca(adata.X[adata.obs["label"] == "ctrl", :].toarray(), n = 15, center = False).coord_system 127 | coord_stim = fit_pca(adata.X[adata.obs["label"] == "stim", :].toarray(), n = 15, center = False).coord_system 128 | 129 | print(grassmann_angle_from_point(fit.uns["lemur"]["base_point"].T, coord_all.T)) 130 | print(grassmann_angle_from_point(grassmann_map(V_slice1.T, fit.uns["lemur"]["base_point"].T), coord_ctrl.T)) 131 | print(grassmann_angle_from_point(grassmann_map((V_slice1 + V_slice2).T, fit.uns["lemur"]["base_point"].T), coord_stim.T)) 132 | ``` 133 | 134 | 135 | 136 | ```{python} 137 | from pyLemur.predict import * 138 | 139 | fit.uns["lemur"]["design_matrix"] 140 | 141 | ``` 142 | 143 | ```{python} 144 | pred_stim = predict(fit, new_condition = cond(fit, label = "stim", cat = "B")) 145 | pred_ctrl = predict(fit, new_condition = cond(fit, label = "ctrl", cat = "B")) 146 | delta = pred_stim - pred_ctrl 147 | ``` 148 | 149 | ```{python} 150 | gene = 3 151 | fit.obs["delta"] = 
delta[:,gene] 152 | fit.obs["expr"] = fit.layers["logcounts"][:,gene].toarray() 153 | 154 | import matplotlib.pyplot as plt 155 | sc.pl.umap(fit, color="delta", cmap = plt.get_cmap("seismic"), vcenter=0) 156 | 157 | sc.pl.umap(fit[fit.obs["label"] == "ctrl"], color="expr", vmax = 2) 158 | sc.pl.umap(fit[fit.obs["label"] == "stim"], color="expr", vmax = 2) 159 | ``` 160 | 161 | ```{python} 162 | ``` 163 | 164 | ```{python} 165 | import harmonypy 166 | ho = harmonypy.run_harmony(fit.obsm["embedding"], fit.obs, "label") 167 | fit.obsm["harmony"] = ho.Z_corr.T 168 | ho.cl 169 | nei = sc.pp.neighbors(fit, use_rep="harmony") 170 | sc.tl.umap(fit) 171 | sc.pl.umap(fit, color=["label", "delta", "cell_type"]) 172 | ``` 173 | 174 | 175 | ```{python} 176 | from pyLemur.alignment import * 177 | align_with_harmony(fit) 178 | ``` 179 | 180 | ```{python} 181 | ho = harmonypy.run_harmony(fit.obsm["embedding"], fit.obs, "label") 182 | 183 | from pyLemur.alignment import * 184 | 185 | ho = init_harmony(fit.obsm["embedding"], fit.uns["lemur"]["design_matrix"]) 186 | ho.cluster() 187 | ``` 188 | 189 | ```{python} 190 | A = np.hstack([np.ones((5, 2)), np.zeros((5, 6))]) 191 | B = np.arange(8).reshape((8, 1)) 192 | multiply_along_axis(A, B.T, axis = 1) 193 | ``` 194 | 195 | ```{python} 196 | nei = sc.pp.neighbors(fit, use_rep="new_embedding") 197 | sc.tl.umap(fit) 198 | sc.pl.umap(fit, color=["label", "delta", "cell_type"]) 199 | ``` 200 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "hatchling.build" 3 | requires = ["hatchling"] 4 | 5 | [project] 6 | name = "pyLemur" 7 | version = "0.3.1" 8 | description = "A Python implementation of the LEMUR algorithm for analyzing multi-condition single-cell RNA-seq data." 
9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | license = {file = "LICENSE"} 12 | authors = [ 13 | {name = "Constantin Ahlmann-Eltze"}, 14 | ] 15 | maintainers = [ 16 | {name = "Constantin Ahlmann-Eltze", email = "artjom31415@googlemail.com"}, 17 | ] 18 | urls.Documentation = "https://pyLemur.readthedocs.io/" 19 | urls.Source = "https://github.com/const-ae/pyLemur" 20 | urls.Home-page = "https://github.com/const-ae/pyLemur" 21 | dependencies = [ 22 | "anndata", 23 | "formulaic", 24 | "formulaic_contrasts", 25 | "harmonypy", 26 | "pandas", 27 | "numpy", 28 | "scikit-learn", 29 | # for debug logging (referenced from the issue template) 30 | "session-info", 31 | ] 32 | 33 | [project.optional-dependencies] 34 | dev = [ 35 | "pre-commit", 36 | "twine>=4.0.2", 37 | ] 38 | doc = [ 39 | "docutils>=0.8,!=0.18.*,!=0.19.*", 40 | "sphinx>=4", 41 | "sphinx-book-theme>=1.0.0", 42 | "myst-nb>=1.1.0", 43 | "sphinxcontrib-bibtex>=1.0.0", 44 | "sphinx-autodoc-typehints", 45 | "sphinxext-opengraph", 46 | "sphinx-design", 47 | # For notebooks 48 | "ipykernel", 49 | "ipython", 50 | "sphinx-copybutton", 51 | "pandas", 52 | "scanpy", 53 | "pertpy", 54 | "matplotlib", 55 | "session-info", 56 | ] 57 | test = [ 58 | "pytest", 59 | "coverage", 60 | ] 61 | 62 | [tool.coverage.run] 63 | source = ["pylemur"] 64 | omit = [ 65 | "**/test_*.py", 66 | ] 67 | 68 | [tool.pytest.ini_options] 69 | testpaths = ["tests"] 70 | xfail_strict = true 71 | addopts = [ 72 | "--import-mode=importlib", # allow using test files with same name 73 | ] 74 | 75 | [tool.ruff] 76 | line-length = 120 77 | src = ["src"] 78 | extend-include = ["*.ipynb"] 79 | 80 | [tool.ruff.format] 81 | docstring-code-format = true 82 | 83 | [tool.ruff.lint] 84 | select = [ 85 | "F", # Errors detected by Pyflakes 86 | "E", # Error detected by Pycodestyle 87 | "W", # Warning detected by Pycodestyle 88 | "I", # isort 89 | "D", # pydocstyle 90 | "B", # flake8-bugbear 91 | "TID", # flake8-tidy-imports 92 | "C4", # 
flake8-comprehensions 93 | "BLE", # flake8-blind-except 94 | "UP", # pyupgrade 95 | "RUF100", # Report unused noqa directives 96 | ] 97 | ignore = [ 98 | # line too long -> we accept long comment lines; formatter gets rid of long code lines 99 | "E501", 100 | # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient 101 | "E731", 102 | # allow I, O, l as variable names -> I is the identity matrix 103 | "E741", 104 | # Missing docstring in public package 105 | "D104", 106 | # Missing docstring in public module 107 | "D100", 108 | # Missing docstring in __init__ 109 | "D107", 110 | # Errors from function calls in argument defaults. These are fine when the result is immutable. 111 | "B008", 112 | # __magic__ methods are often self-explanatory, allow missing docstrings 113 | "D105", 114 | # first line should end with a period [Bug: doesn't work with single-line docstrings] 115 | "D400", 116 | # First line should be in imperative mood; try rephrasing 117 | "D401", 118 | ## Disable one in each pair of mutually incompatible rules 119 | # We don’t want a blank line before a class docstring 120 | "D203", 121 | # We want docstrings to start immediately after the opening triple quote 122 | "D213", 123 | ] 124 | 125 | [tool.ruff.lint.pydocstyle] 126 | convention = "numpy" 127 | 128 | [tool.ruff.lint.per-file-ignores] 129 | "docs/*" = ["I"] 130 | "tests/*" = [ 131 | # Pydoc doesn't mattter in tests 132 | "D", 133 | # Ignore star imports in tests 134 | "F403", "F405" 135 | ] 136 | "*/__init__.py" = ["F401"] 137 | 138 | [tool.cruft] 139 | skip = [ 140 | "tests", 141 | "src/**/__init__.py", 142 | "src/**/basic.py", 143 | "docs/api.md", 144 | "docs/changelog.md", 145 | "docs/references.bib", 146 | "docs/references.md", 147 | "docs/notebooks/example.ipynb", 148 | ] 149 | -------------------------------------------------------------------------------- /src/pylemur/__init__.py: 
-------------------------------------------------------------------------------- 1 | from importlib.metadata import version 2 | 3 | from . import pl, pp, tl 4 | 5 | __all__ = ["pl", "pp", "tl"] 6 | 7 | __version__ = version("pyLemur") 8 | -------------------------------------------------------------------------------- /src/pylemur/pl/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic import BasicClass, basic_plot 2 | -------------------------------------------------------------------------------- /src/pylemur/pl/basic.py: -------------------------------------------------------------------------------- 1 | from anndata import AnnData 2 | 3 | 4 | def basic_plot(adata: AnnData) -> int: 5 | """Generate a basic plot for an AnnData object. 6 | 7 | Parameters 8 | ---------- 9 | adata 10 | The AnnData object to preprocess. 11 | 12 | Returns 13 | ------- 14 | Some integer value. 15 | """ 16 | print("Import matplotlib and implement a plotting function here.") 17 | return 0 18 | 19 | 20 | class BasicClass: 21 | """A basic class. 22 | 23 | Parameters 24 | ---------- 25 | adata 26 | The AnnData object to preprocess. 27 | """ 28 | 29 | my_attribute: str = "Some attribute." 30 | my_other_attribute: int = 0 31 | 32 | def __init__(self, adata: AnnData): 33 | print("Implement a class here.") 34 | 35 | def my_method(self, param: int) -> int: 36 | """A basic method. 37 | 38 | Parameters 39 | ---------- 40 | param 41 | A parameter. 42 | 43 | Returns 44 | ------- 45 | Some integer value. 46 | """ 47 | print("Implement a method here.") 48 | return 0 49 | 50 | def my_other_method(self, param: str) -> str: 51 | """Another basic method. 52 | 53 | Parameters 54 | ---------- 55 | param 56 | A parameter. 57 | 58 | Returns 59 | ------- 60 | Some integer value. 
61 | """ 62 | print("Implement a method here.") 63 | return "" 64 | -------------------------------------------------------------------------------- /src/pylemur/pp/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic import shifted_log_transform 2 | -------------------------------------------------------------------------------- /src/pylemur/pp/basic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | 4 | 5 | def shifted_log_transform(counts, overdispersion=0.05, pseudo_count=None, minimum_overdispersion=0.001): 6 | r"""Apply log transformation to count data 7 | 8 | The transformation is proportional to :math:`\log(y/s + c)`, where :math:`y` are the counts, :math:`s` is the size factor, 9 | and :math:`c` is the pseudo-count. 10 | 11 | The actual transformation is :math:`a \, \log(y/(sc) + 1)`, where using :math:`+1` ensures that the output remains 12 | sparse for :math:`y=0` and the scaling with :math:`a=\sqrt{4c}` ensures that the transformation approximates the :math:`\operatorname{acosh}` 13 | transformation. Using :math:`\log(y/(sc) + 1)` instead of :math:`\log(y/s+c)` 14 | only changes the results by a constant offset, as :math:`\log(y+c) = \log(y/c + 1) - \log(1/c)`. Importantly, neither 15 | scaling nor offsetting by a constant changes the variance-stabilizing quality of the transformation. 16 | 17 | The size factors are calculated as normalized sum per cell:: 18 | 19 | size_factors = counts.sum(axis=1) 20 | size_factors = size_factors / np.exp(np.mean(np.log(size_factors))) 21 | 22 | In case `y` is not a matrix, the `size_factors` are fixed to 1. 23 | 24 | Parameters 25 | ---------- 26 | counts 27 | The count matrix which can be sparse. 28 | overdispersion,pseudo_count 29 | Specification how much variation is expected for a homogeneous sample. 
The `overdispersion` and 30 | `pseudo_count` are related by `overdispersion = 1 / (4 * pseudo_count)`. 31 | minimum_overdispersion 32 | Avoid overdispersion values smaller than `minimum_overdispersion`. 33 | 34 | Returns 35 | ------- 36 | A matrix of variance-stabilized values. 37 | """ 38 | if pseudo_count is None: 39 | pseudo_count = 1 / (4 * overdispersion) 40 | 41 | n_cells = counts.shape[0] 42 | size_factors = counts.sum(axis=1) 43 | size_factors = size_factors / np.exp(np.mean(np.log(size_factors))) 44 | norm_mat = counts / size_factors.reshape((n_cells, 1)) 45 | overdispersion = 1 / (4 * pseudo_count) 46 | if overdispersion < minimum_overdispersion: 47 | overdispersion = minimum_overdispersion 48 | res = 1 / np.sqrt(overdispersion) * np.log1p(4 * overdispersion * norm_mat) 49 | if scipy.sparse.issparse(counts): 50 | res = scipy.sparse.csr_matrix(res) 51 | return res 52 | -------------------------------------------------------------------------------- /src/pylemur/tl/__init__.py: -------------------------------------------------------------------------------- 1 | from .lemur import LEMUR 2 | -------------------------------------------------------------------------------- /src/pylemur/tl/_design_matrix_utils.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | # import patsy 7 | from formulaic import model_matrix 8 | from numpy.lib import NumpyVersion 9 | 10 | 11 | def handle_data(data, layer): 12 | Y = data.X if layer is None else data.layers[layer] 13 | if not isinstance(Y, np.ndarray): 14 | Y = Y.toarray() 15 | return Y 16 | 17 | 18 | def handle_design_parameter(design, obs_data): 19 | if isinstance(design, np.ndarray): 20 | if design.ndim == 1: 21 | # Throw error 22 | raise ValueError("design specified as a 1d array is not supported yet") 23 | elif design.ndim == 2: 24 | design_matrix = pd.DataFrame(design) 25 | design_formula 
= None 26 | else: 27 | raise ValueError("design must be a 2d array") 28 | elif isinstance(design, pd.DataFrame): 29 | design_matrix = design 30 | design_formula = None 31 | elif isinstance(design, list): 32 | return handle_design_parameter(" * ".join(design), obs_data) 33 | elif isinstance(design, str): 34 | # Check if design starts with a ~ 35 | if design[0] != "~": 36 | design = "~" + design + " - 1" 37 | design_matrix, design_formula = convert_formula_to_design_matrix(design, obs_data) 38 | else: 39 | raise ValueError("design must be a 2d array or string") 40 | 41 | return design_matrix, design_formula 42 | 43 | 44 | def handle_obs_data(adata, obs_data): 45 | a = make_data_frame(adata.obs) 46 | b = make_data_frame(obs_data, preferred_index=a.index if a is not None else None) 47 | if a is None and b is None: 48 | return pd.DataFrame(index=pd.RangeIndex(0, adata.shape[0])) 49 | elif a is None: 50 | return b 51 | elif b is None: 52 | return a 53 | else: 54 | return pd.concat([a, b], axis=1) 55 | 56 | 57 | def make_data_frame(data, preferred_index=None): 58 | if data is None: 59 | return None 60 | if isinstance(data, pd.DataFrame) and preferred_index is None: 61 | return data 62 | if isinstance(data, pd.DataFrame) and preferred_index is not None: 63 | if preferred_index.equals(data.index) or preferred_index.equals(data.index.map(str)): 64 | data.index = preferred_index 65 | return data 66 | else: 67 | raise ValueError("The index of adata.obs and obsData do not match") 68 | elif isinstance(data, Mapping): 69 | return pd.DataFrame(data, index=preferred_index) 70 | else: 71 | raise ValueError("data must be None, a pandas DataFrame or a Mapping object") 72 | 73 | 74 | def convert_formula_to_design_matrix(formula, obs_data): 75 | # Check if formula is string 76 | if isinstance(formula, str): 77 | # Convert formula to design matrix 78 | # design_matrix = patsy.dmatrix(formula, obs_data) 79 | design_matrix = model_matrix(formula, obs_data) 80 | return design_matrix, 
formula 81 | else: 82 | raise ValueError("formula must be a string") 83 | 84 | 85 | def row_groups(matrix, return_reduced_matrix=False, return_group_ids=False): 86 | reduced_matrix, inv = np.unique(matrix, axis=0, return_inverse=True) 87 | if NumpyVersion(np.__version__) >= "2.0.0rc": 88 | inv = np.squeeze(inv) 89 | group_ids = np.unique(inv) 90 | if return_reduced_matrix and return_group_ids: 91 | return inv, reduced_matrix, group_ids 92 | elif return_reduced_matrix: 93 | return inv, reduced_matrix 94 | elif return_group_ids: 95 | return inv, group_ids 96 | else: 97 | return inv 98 | -------------------------------------------------------------------------------- /src/pylemur/tl/_grassmann.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def grassmann_map(x, base_point): 5 | if base_point.shape[0] == 0 or base_point.shape[1] == 0: 6 | return base_point 7 | elif np.isnan(x).any(): 8 | # Return an object with the same shape as x filled with nan 9 | return np.full(x.shape, np.nan) 10 | else: 11 | u, s, vt = np.linalg.svd(x, full_matrices=False) 12 | return (base_point @ vt.T) @ np.diag(np.cos(s)) @ vt + u @ np.diag(np.sin(s)) @ vt 13 | 14 | 15 | def grassmann_log(p, q): 16 | n = p.shape[0] 17 | k = p.shape[1] 18 | 19 | if n == 0 or k == 0: 20 | return p 21 | else: 22 | z = q.T @ p 23 | At = q.T - z @ p.T 24 | # Translate `lm.fit(z, At)$coefficients` to python 25 | Bt = np.linalg.lstsq(z, At, rcond=None)[0] 26 | u, s, vt = np.linalg.svd(Bt.T, full_matrices=True) 27 | u = u[:, :k] 28 | s = s[:k] 29 | vt = vt[:k, :] 30 | return u @ np.diag(np.arctan(s)) @ vt 31 | 32 | 33 | def grassmann_project(x): 34 | return np.linalg.qr(x)[0] 35 | 36 | 37 | def grassmann_project_tangent(x, base_point): 38 | return x - base_point @ base_point.T @ x 39 | 40 | 41 | def grassmann_random_point(n, k): 42 | x = np.random.randn(n, k) 43 | return grassmann_project(x) 44 | 45 | 46 | def grassmann_random_tangent(base_point): 
47 | x = np.random.randn(*base_point.shape) 48 | return grassmann_project_tangent(x, base_point) 49 | 50 | 51 | def grassmann_angle_from_tangent(x, normalized=True): 52 | thetas = np.linalg.svd(x, full_matrices=True, compute_uv=False) / np.pi * 180 53 | if normalized: 54 | return np.minimum(thetas, 180 - thetas).max() 55 | else: 56 | return thetas[0] 57 | 58 | 59 | def grassmann_angle_from_point(x, y): 60 | return grassmann_angle_from_tangent(grassmann_log(y, x)) 61 | -------------------------------------------------------------------------------- /src/pylemur/tl/_grassmann_lm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pylemur.tl._design_matrix_utils import row_groups 4 | from pylemur.tl._grassmann import grassmann_log, grassmann_map 5 | from pylemur.tl._lin_alg_wrappers import fit_pca, ridge_regression 6 | 7 | 8 | def grassmann_geodesic_regression(coord_systems, design, base_point, weights=None): 9 | """ 10 | Fit geodesic regression on Grassmann manifolds 11 | 12 | Solve Sum_j d(U_j, Exp_p(Sum_k V_k:: * X_jk)) for V, where 13 | d(U, V) = ||Log(U, V)|| is the inverse of the exponential map on the Grassmann manifold. 14 | 15 | Parameters 16 | ---------- 17 | coord_systems : a list of orthonormal 2D matrices (length n_groups) 18 | design : design matrix, shape (n_groups, n_coef) 19 | base_point : array-like, shape (n_emb, n_features) 20 | The base point. 
21 | weights : array-like, shape (n_groups,) 22 | 23 | Returns 24 | ------- 25 | beta: array-like, shape (n_emb, n_features, n_coef) 26 | """ 27 | n_obs = design.shape[0] 28 | n_coef = design.shape[1] 29 | n_emb = base_point.shape[0] 30 | n_features = base_point.shape[1] 31 | 32 | assert len(coord_systems) == n_obs 33 | for i in range(n_obs): 34 | assert coord_systems[i].shape == (n_emb, n_features) 35 | if weights is None: 36 | weights = np.ones(n_obs) 37 | 38 | tangent_vecs = [grassmann_log(base_point.T, coord_systems[i].T).T.reshape(n_emb * n_features) for i in range(n_obs)] 39 | tangent_vecs = np.vstack(tangent_vecs) 40 | if tangent_vecs.shape[0] == 0: 41 | tangent_fit = np.zeros((0, n_coef)) 42 | else: 43 | tangent_fit = ridge_regression(tangent_vecs, design, weights=weights) 44 | 45 | tangent_fit = tangent_fit.reshape((n_coef, n_emb, n_features)).transpose((1, 2, 0)) 46 | return tangent_fit 47 | 48 | 49 | def grassmann_lm(Y, design_matrix, base_point): 50 | """ 51 | Solve Sum_i||Y_i: - Y_i: Proj(Exp_p(Sum_k V_k:: * X_ik))||^2 for V. 52 | 53 | Parameters 54 | ---------- 55 | Y : array-like, shape (n_samples, n_features) 56 | The input data matrix. 57 | design_matrix : array-like, shape (n_samples, n_coef) 58 | The design matrix. 59 | base_point : array-like, shape (n_emb, n_features) 60 | The base point. 
61 | 62 | Returns 63 | ------- 64 | beta: array-like, shape (n_emb, n_features, n_coef) 65 | """ 66 | n_emb = base_point.shape[0] 67 | 68 | des_row_groups, reduced_design_matrix, des_row_group_ids = row_groups( 69 | design_matrix, return_reduced_matrix=True, return_group_ids=True 70 | ) 71 | if np.min(np.unique(des_row_groups, return_counts=True)[1]) < n_emb: 72 | raise ValueError("Too few dataset points in some design matrix group.") 73 | group_planes = [fit_pca(Y[des_row_groups == i, :], n_emb, center=False).coord_system for i in des_row_group_ids] 74 | group_sizes = [np.sum(des_row_groups == i) for i in des_row_group_ids] 75 | 76 | coef = grassmann_geodesic_regression(group_planes, reduced_design_matrix, base_point, weights=group_sizes) 77 | return coef 78 | 79 | 80 | def project_diffemb_into_data_space(embedding, design_matrix, coefficients, base_point): 81 | n_features = base_point.shape[1] 82 | res = np.zeros((design_matrix.shape[0], n_features)) 83 | des_row_groups, reduced_design_matrix, des_row_group_ids = row_groups( 84 | design_matrix, return_reduced_matrix=True, return_group_ids=True 85 | ) 86 | for id in des_row_group_ids: 87 | covar = reduced_design_matrix[id, :] 88 | subspace = grassmann_map(np.dot(coefficients, covar).T, base_point.T) 89 | res[des_row_groups == id, :] = embedding[des_row_groups == id, :] @ subspace.T 90 | return res 91 | 92 | 93 | def project_data_on_diffemb(Y, design_matrix, coefficients, base_point): 94 | n_emb = base_point.shape[0] 95 | n_obs = Y.shape[0] 96 | res = np.zeros((n_obs, n_emb)) 97 | des_row_groups, reduced_design_matrix, des_row_group_ids = row_groups( 98 | design_matrix, return_reduced_matrix=True, return_group_ids=True 99 | ) 100 | for id in des_row_group_ids: 101 | Y_subset = Y[des_row_groups == id, :] 102 | covar = reduced_design_matrix[id, :] 103 | subspace = grassmann_map(np.dot(coefficients, covar).T, base_point.T) 104 | res[des_row_groups == id, :] = Y_subset @ subspace 105 | return res 106 | 
-------------------------------------------------------------------------------- /src/pylemur/tl/_lin_alg_wrappers.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | import numpy as np 4 | import sklearn.decomposition as skd 5 | 6 | 7 | class PCA(NamedTuple): 8 | embedding: np.ndarray 9 | coord_system: np.ndarray 10 | offset: np.ndarray 11 | 12 | 13 | def fit_pca(Y, n, center=True): 14 | """ 15 | Calculate the PCA of a given data matrix Y. 16 | 17 | Parameters 18 | ---------- 19 | Y : array-like, shape (n_samples, n_features) 20 | The input data matrix. 21 | n : int 22 | The number of principal components to return. 23 | center : bool, default=True 24 | If True, the data will be centered before computing the covariance matrix. 25 | 26 | Returns 27 | ------- 28 | pca : PCA 29 | A ``PCA`` named tuple ``(embedding, coord_system, offset)``; ``offset`` is the column means when ``center=True`` and zeros otherwise. 30 | """ 31 | if center: 32 | pca = skd.PCA(n_components=n) 33 | emb = pca.fit_transform(Y) 34 | coord_system = pca.components_ 35 | mean = pca.mean_ 36 | else: 37 | svd = skd.TruncatedSVD(n_components=n, algorithm="arpack") 38 | emb = svd.fit_transform(Y) 39 | coord_system = svd.components_ 40 | mean = np.zeros(Y.shape[1]) 41 | return PCA(emb, coord_system, mean) 42 | 43 | 44 | def ridge_regression(Y, X, ridge_penalty=0, weights=None): 45 | """ 46 | Calculate the ridge regression of a given data matrix Y. 47 | 48 | Parameters 49 | ---------- 50 | Y : array-like, shape (n_samples, n_features) 51 | The input data matrix. 52 | X : array-like, shape (n_samples, n_coef) 53 | The input data matrix. 54 | ridge_penalty : float, default=0 55 | The ridge penalty. 56 | weights : array-like, shape (n_samples,) 57 | The weights to apply to each sample (observation row).
58 | 59 | Returns 60 | ------- 61 | ridge: array-like, shape (n_coef, n_features) 62 | """ 63 | n_coef = X.shape[1] 64 | n_samples = X.shape[0] 65 | n_feat = Y.shape[1] 66 | assert Y.shape[0] == n_samples 67 | if weights is None: 68 | weights = np.ones(n_samples) 69 | assert len(weights) == n_samples 70 | 71 | if np.ndim(ridge_penalty) == 0 or len(ridge_penalty) == 1: 72 | ridge_penalty = np.eye(n_coef) * ridge_penalty 73 | elif np.ndim(ridge_penalty) == 1: 74 | assert len(ridge_penalty) == n_coef 75 | ridge_penalty = np.diag(ridge_penalty) 76 | elif np.ndim(ridge_penalty) == 1: 77 | assert ridge_penalty.shape == (n_coef, n_coef) 78 | pass 79 | else: 80 | raise ValueError("ridge_penalty must be a scalar, 1d array, or 2d array") 81 | 82 | ridge_penalty_sq = np.sqrt(np.sum(weights)) * (ridge_penalty.T @ ridge_penalty) 83 | weights_sqrt = np.sqrt(weights) 84 | X_ext = np.vstack([multiply_along_axis(X, weights_sqrt, 0), ridge_penalty_sq]) 85 | Y_ext = np.vstack([multiply_along_axis(Y, weights_sqrt, 0), np.zeros((n_coef, n_feat))]) 86 | 87 | ridge = np.linalg.lstsq(X_ext, Y_ext, rcond=None)[0] 88 | return ridge 89 | 90 | 91 | def multiply_along_axis(A, B, axis): 92 | # Copied from https://stackoverflow.com/a/71750176/604854 93 | return np.swapaxes(np.swapaxes(A, axis, -1) * B, -1, axis) 94 | -------------------------------------------------------------------------------- /src/pylemur/tl/alignment.py: -------------------------------------------------------------------------------- 1 | import harmonypy 2 | import numpy as np 3 | 4 | from pylemur.tl._design_matrix_utils import row_groups 5 | from pylemur.tl._lin_alg_wrappers import ridge_regression 6 | 7 | 8 | def _align_impl( 9 | embedding, 10 | grouping, 11 | design_matrix, 12 | ridge_penalty=0.01, 13 | preserve_position_of_NAs=False, 14 | calculate_new_embedding=True, 15 | verbose=True, 16 | ): 17 | if grouping.ndim == 1: 18 | uniq_elem, fct_levels = np.unique(grouping, return_inverse=True) 19 | I = 
np.eye(len(uniq_elem)) 20 | I[:, np.isnan(uniq_elem)] = 0 21 | grouping_matrix = I[:, fct_levels] 22 | else: 23 | assert grouping.shape[1] == embedding.shape[0] 24 | assert np.all(grouping[~np.isnan(grouping)] >= 0) 25 | col_sums = grouping.sum(axis=0) 26 | col_sums[col_sums == 0] = 1 27 | grouping_matrix = grouping / col_sums 28 | 29 | n_groups = grouping_matrix.shape[0] 30 | n_emb = embedding.shape[1] 31 | K = design_matrix.shape[1] 32 | 33 | # NA's are converted to zero columns ensuring that `diff %*% grouping_matrix = 0` 34 | grouping_matrix[:, np.isnan(grouping_matrix.sum(axis=0))] = 0 35 | if not preserve_position_of_NAs: 36 | not_all_zero_column = grouping_matrix.sum(axis=0) != 0 37 | grouping_matrix = grouping_matrix[:, not_all_zero_column] 38 | embedding = embedding[not_all_zero_column, :] 39 | design_matrix = design_matrix[not_all_zero_column] 40 | 41 | des_row_groups, des_row_group_ids = row_groups(design_matrix, return_group_ids=True) 42 | n_conditions = des_row_group_ids.shape[0] 43 | cond_ct_means = [np.zeros((n_emb, n_groups)) for _ in des_row_group_ids] 44 | for id in des_row_group_ids: 45 | sel = des_row_groups == id 46 | for idx in range(n_groups): 47 | if grouping_matrix[idx, sel].sum() > 0: 48 | cond_ct_means[id][:, idx] = np.average(embedding[sel, :], axis=0, weights=grouping_matrix[idx, sel]) 49 | else: 50 | cond_ct_means[id][:, idx] = np.nan 51 | 52 | target = np.zeros((n_emb, n_groups)) * np.nan 53 | for idx in range(n_groups): 54 | tmp = np.zeros((n_conditions, n_emb)) 55 | for id in des_row_group_ids: 56 | tmp[id, :] = cond_ct_means[id][:, idx] 57 | if np.all(np.isnan(tmp)): 58 | target[:, idx] = np.nan 59 | else: 60 | target[:, idx] = np.average(tmp[:, ~np.isnan(tmp.sum(axis=0))], axis=0) 61 | 62 | new_pos = embedding.copy() 63 | for id in des_row_group_ids: 64 | sel = des_row_groups == id 65 | diff = target - cond_ct_means[id] 66 | # NA's are converted to zero so that they don't propagate. 
67 | diff[np.isnan(diff)] = 0 68 | new_pos[sel, :] = new_pos[sel, :] + (diff @ grouping_matrix[:, sel]).T 69 | 70 | intercept_emb = np.hstack([np.ones((embedding.shape[0], 1)), embedding]) 71 | interact_design_matrix = np.repeat(design_matrix, n_emb + 1, axis=1) * np.hstack([intercept_emb] * K) 72 | alignment_coefs = ridge_regression(new_pos - embedding, interact_design_matrix, ridge_penalty) 73 | ## The alignment error is weird, as it doesn't necessarily go down. Better not to show it 74 | # if verbose: 75 | # print(f"Alignment error: {np.linalg.norm((new_pos - embedding) - interact_design_matrix @ alignment_coefs)}") 76 | alignment_coefs = alignment_coefs.reshape((K, n_emb + 1, n_emb)).transpose((2, 1, 0)) 77 | if calculate_new_embedding: 78 | new_embedding = _apply_linear_transformation(embedding, alignment_coefs, design_matrix) 79 | return alignment_coefs, new_embedding 80 | else: 81 | return alignment_coefs 82 | 83 | 84 | def _apply_linear_transformation(embedding, alignment_coefs, design_matrix): 85 | des_row_groups, reduced_design_matrix, des_row_group_ids = row_groups( 86 | design_matrix, return_reduced_matrix=True, return_group_ids=True 87 | ) 88 | embedding = embedding.copy() 89 | for id in des_row_group_ids: 90 | sel = des_row_groups == id 91 | embedding[sel, :] = ( 92 | np.hstack([np.ones((np.sum(sel), 1)), embedding[sel, :]]) 93 | @ _forward_linear_transformation(alignment_coefs, reduced_design_matrix[id, :]).T 94 | ) 95 | return embedding 96 | 97 | 98 | def _forward_linear_transformation(alignment_coef, design_vector): 99 | n_emb = alignment_coef.shape[0] 100 | if n_emb == 0: 101 | return np.zeros((0, 0)) 102 | else: 103 | return np.hstack([np.zeros((n_emb, 1)), np.eye(n_emb)]) + np.dot(alignment_coef, design_vector) 104 | 105 | 106 | def _reverse_linear_transformation(alignment_coef, design_vector): 107 | n_emb = alignment_coef.shape[0] 108 | if n_emb == 0: 109 | return np.zeros((0, 0)) 110 | else: 111 | return np.linalg.inv(np.eye(n_emb) + 
np.dot(alignment_coef[:, 1:, :], design_vector)) 112 | 113 | 114 | def _init_harmony( 115 | embedding, 116 | design_matrix, 117 | theta=2, 118 | lamb=1, 119 | sigma=0.1, 120 | nclust=None, 121 | tau=0, 122 | block_size=0.05, 123 | max_iter_kmeans=20, 124 | epsilon_cluster=1e-5, 125 | epsilon_harmony=1e-4, 126 | verbose=True, 127 | ): 128 | n_obs = embedding.shape[0] 129 | des_row_groups, des_row_group_ids = row_groups(design_matrix, return_group_ids=True) 130 | n_groups = len(des_row_group_ids) 131 | if nclust is None: 132 | nclust = np.min([np.round(n_obs / 30.0), 100]).astype(int) 133 | 134 | phi = np.eye(n_groups)[:, des_row_groups] 135 | # phi_n = np.ones(n_groups) 136 | 137 | N_b = phi.sum(axis=1) 138 | Pr_b = N_b / n_obs 139 | sigma = np.repeat(sigma, nclust) 140 | 141 | theta = np.repeat(theta, n_groups) 142 | if tau != 0: 143 | theta = theta * (1 - np.exp(-((N_b / (nclust * tau)) ** 2))) 144 | 145 | lamb = np.repeat(lamb, n_groups) 146 | lamb_mat = np.diag(np.insert(lamb, 0, 0)) 147 | phi_moe = np.vstack((np.repeat(1, n_obs), phi)) 148 | 149 | max_iter_harmony = 0 150 | ho = harmonypy.Harmony( 151 | embedding.T, 152 | phi, 153 | phi_moe, 154 | Pr_b, 155 | sigma, 156 | theta, 157 | max_iter_harmony, 158 | max_iter_kmeans, 159 | epsilon_cluster, 160 | epsilon_harmony, 161 | nclust, 162 | block_size, 163 | lamb_mat, 164 | verbose, 165 | ) 166 | return ho 167 | -------------------------------------------------------------------------------- /src/pylemur/tl/lemur.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections.abc import Iterable, Mapping 3 | from typing import Any, Literal 4 | import copy 5 | 6 | import anndata as ad 7 | import formulaic_contrasts 8 | import numpy as np 9 | import pandas as pd 10 | from sklearn.exceptions import NotFittedError 11 | 12 | from pylemur.tl._design_matrix_utils import handle_data, handle_design_parameter, handle_obs_data, row_groups 13 | from 
pylemur.tl._grassmann import grassmann_map 14 | from pylemur.tl._grassmann_lm import grassmann_lm, project_data_on_diffemb 15 | from pylemur.tl._lin_alg_wrappers import fit_pca, multiply_along_axis, ridge_regression 16 | from pylemur.tl.alignment import ( 17 | _align_impl, 18 | _apply_linear_transformation, 19 | _init_harmony, 20 | _reverse_linear_transformation, 21 | ) 22 | 23 | 24 | class LEMUR: 25 | r"""Fit the LEMUR model 26 | 27 | A python implementation of the LEMUR algorithm. For more details please refer 28 | to Ahlmann-Eltze (2024). 29 | 30 | Parameters 31 | ---------- 32 | data 33 | The AnnData object (or a different matrix container) with the variance stabilized data and the 34 | cell-wise annotations in `data.obs`. 35 | design 36 | A specification of the experimental design. This can be a string, 37 | which is then parsed using `formulaic`. Alternatively, it can be a 38 | a list of strings, which are assumed to refer to the columns in 39 | `data.obs`. Finally, it can be a numpy array, representing a 40 | design matrix of size `n_cells` x `n_covariates`. If not provided, 41 | a constant design is used. 42 | obs_data 43 | A pandas DataFrame or a dictionary of iterables containing the 44 | cell-wise annotations. It is used in combination with the 45 | information in `data.obs`. 46 | n_embedding 47 | The number of dimensions to use for the shared embedding space. 48 | linear_coefficient_estimator 49 | The method to use for estimating the linear coefficients. If `"linear"`, 50 | the linear coefficients are estimated using ridge regression. If `"zero"`, 51 | the linear coefficients are set to zero. 52 | layer 53 | The name of the layer to use in `data`. If `None`, the `X` slot is used. 54 | copy 55 | Whether to make a copy of `data`. 
56 | 57 | Attributes 58 | ---------- 59 | embedding : :class:`~numpy.ndarray` (:math:`C \times P`) 60 | Low-dimensional representation of each cell 61 | adata : :class:`~anndata.AnnData` 62 | A reference to (potentially a copy of) the input data. 63 | data_matrix : :class:`~numpy.ndarray` (:math:`C \times G`) 64 | A reference to the data matrix from the `adata` object. 65 | n_embedding : int 66 | The number of latent dimensions 67 | design_matrix : `ModelMatrix` (:math:`C \times K`) 68 | The design matrix that is used for the fit. 69 | formula : str 70 | The design formula specification. 71 | coefficients : :class:`~numpy.ndarray` (:math:`P \times G \times K`) 72 | The 3D array of coefficients for the Grassmann regression. 73 | alignment_coefficients : :class:`~numpy.ndarray` (:math:`P \times (P+1) \times K`) 74 | The 3D array of coefficients for the affine alignment. 75 | linear_coefficients : :class:`~numpy.ndarray` (:math:`K\times G`) 76 | The 2D array of coefficients for the linear offset per condition. 77 | linear_coefficient_estimator : str 78 | The linear coefficient estimation specification. 79 | base_point : :class:`~numpy.ndarray` (:math:`(P \times G`)) 80 | The 2D array representing the reference subspace. 
81 | 82 | Examples 83 | -------- 84 | >>> model = pylemur.tl.LEMUR(adata, design="~ label + batch_cov", n_embedding=15) 85 | >>> model.fit() 86 | >>> model.align_with_harmony() 87 | >>> pred_expr = model.predict(new_condition=model.cond(label="treated")) 88 | >>> emb_proj = model_small.transform(adata) 89 | """ 90 | 91 | def __init__( 92 | self, 93 | adata: ad.AnnData | Any, 94 | design: str | list[str] | np.ndarray = "~ 1", 95 | obs_data: pd.DataFrame | Mapping[str, Iterable[Any]] | None = None, 96 | n_embedding: int = 15, 97 | linear_coefficient_estimator: Literal["linear", "zero"] = "linear", 98 | layer: str | None = None, 99 | copy: bool = True, 100 | ): 101 | adata = _handle_data_arg(adata) 102 | if copy: 103 | adata = adata.copy() 104 | self.adata = adata 105 | 106 | adata.obs = handle_obs_data(adata, obs_data) 107 | design_matrix, formula = handle_design_parameter(design, adata.obs) 108 | 109 | if formula: 110 | self.contrast_builder = formulaic_contrasts.FormulaicContrasts(adata.obs, formula) 111 | assert design_matrix.equals(self.contrast_builder.design_matrix) 112 | else: 113 | self.contrast_builder = None 114 | 115 | self.design_matrix = design_matrix 116 | self.formula = formula 117 | if design_matrix.shape[0] != adata.shape[0]: 118 | raise ValueError("number of rows in design matrix must be equal to number of samples in data") 119 | self.data_matrix = handle_data(adata, layer) 120 | self.linear_coefficient_estimator = linear_coefficient_estimator 121 | self.n_embedding = n_embedding 122 | self.embedding = None 123 | self.coefficients = None 124 | self.linear_coefficients = None 125 | self.base_point = None 126 | self.alignment_coefficients = None 127 | 128 | def fit(self, verbose: bool = True): 129 | """Fit the LEMUR model 130 | 131 | Parameters 132 | ---------- 133 | verbose 134 | Whether to print progress to the console. 135 | 136 | Returns 137 | ------- 138 | `self` 139 | The fitted LEMUR model. 
140 | """ 141 | Y = self.data_matrix 142 | design_matrix = self.design_matrix 143 | n_embedding = self.n_embedding 144 | 145 | if self.linear_coefficient_estimator == "linear": 146 | if verbose: 147 | print("Centering the data using linear regression.") 148 | lin_coef = ridge_regression(Y, design_matrix.to_numpy()) 149 | Y = Y - design_matrix.to_numpy() @ lin_coef 150 | else: # linear_coefficient_estimator == "zero" 151 | lin_coef = np.zeros((design_matrix.shape[1], Y.shape[1])) 152 | 153 | if verbose: 154 | print("Find base point") 155 | base_point = fit_pca(Y, n_embedding, center=False).coord_system 156 | if verbose: 157 | print("Fit regression on latent spaces") 158 | coefficients = grassmann_lm(Y, design_matrix.to_numpy(), base_point) 159 | if verbose: 160 | print("Find shared embedding coordinates") 161 | embedding = project_data_on_diffemb(Y, design_matrix.to_numpy(), coefficients, base_point) 162 | 163 | embedding, coefficients, base_point = _order_axis_by_variance(embedding, coefficients, base_point) 164 | 165 | self.embedding = embedding 166 | self.alignment_coefficients = np.zeros((n_embedding, n_embedding + 1, design_matrix.shape[1])) 167 | self.coefficients = coefficients 168 | self.base_point = base_point 169 | self.linear_coefficients = lin_coef 170 | 171 | return self 172 | 173 | def align_with_harmony( 174 | self, ridge_penalty: float | list[float] | np.ndarray = 0.01, max_iter: int = 10, verbose: bool = True 175 | ): 176 | """Fine-tune the embedding with a parametric version of Harmony. 177 | 178 | Parameters 179 | ---------- 180 | ridge_penalty 181 | The penalty controlling the flexibility of the alignment. Smaller 182 | values mean more flexible alignments. 183 | max_iter 184 | The maximum number of iterations to perform. 185 | verbose 186 | Whether to print progress to the console. 
187 | 188 | 189 | Returns 190 | ------- 191 | `self` 192 | The fitted LEMUR model with the updated embedding space stored in 193 | `model.embedding` attribute and an the updated alignment coefficients 194 | stored in `model.alignment_coefficients`. 195 | """ 196 | if self.embedding is None: 197 | raise NotFittedError( 198 | "self.embedding is None. Make sure to call 'model.fit()' " 199 | + "before calling 'model.align_with_harmony()'." 200 | ) 201 | 202 | embedding = self.embedding.copy() 203 | design_matrix = self.design_matrix 204 | # Init harmony 205 | harm_obj = _init_harmony(embedding, design_matrix, verbose=verbose) 206 | for idx in range(max_iter): 207 | if verbose: 208 | print(f"Alignment iteration {idx}") 209 | # Update harmony 210 | harm_obj.cluster() 211 | # alignment <- align_impl(training_fit$embedding, harm_obj$R, act_design_matrix, ridge_penalty = ridge_penalty) 212 | al_coef, new_emb = _align_impl( 213 | embedding, 214 | harm_obj.R, 215 | design_matrix, 216 | ridge_penalty=ridge_penalty, 217 | calculate_new_embedding=True, 218 | verbose=verbose, 219 | ) 220 | harm_obj.Z_corr = new_emb.T 221 | harm_obj.Z_cos = multiply_along_axis( 222 | new_emb, 1 / np.linalg.norm(new_emb, axis=1).reshape((new_emb.shape[0], 1)), axis=1 223 | ).T 224 | 225 | if harm_obj.check_convergence(1): 226 | if verbose: 227 | print("Converged") 228 | break 229 | self.alignment_coefficients = al_coef 230 | self.embedding = _apply_linear_transformation(embedding, al_coef, design_matrix) 231 | return self 232 | 233 | def align_with_grouping( 234 | self, 235 | grouping: list | np.ndarray | pd.Series, 236 | ridge_penalty: float | list[float] | np.ndarray = 0.01, 237 | preserve_position_of_NAs: bool = False, 238 | verbose: bool = True, 239 | ): 240 | """Fine-tune the embedding using annotated groups of cells. 241 | 242 | Parameters 243 | ---------- 244 | grouping 245 | A list, :class:`~numpy.ndarray`, or pandas :class:`pandas.Series` specifying the group of cells. 
246 | The groups span different conditions and can for example be cell types. 247 | ridge_penalty 248 | The penalty controlling the flexibility of the alignment. 249 | preserve_position_of_NAs 250 | `True` means that `NA`'s in the `grouping` indicate that these cells should stay 251 | where they are (if possible). `False` means that they are free to move around. 252 | verbose 253 | Whether to print progress to the console. 254 | 255 | Returns 256 | ------- 257 | `self` 258 | The fitted LEMUR model with the updated embedding space stored in 259 | `model.embedding` attribute and an the updated alignment coefficients 260 | stored in `model.alignment_coefficients`. 261 | """ 262 | if self.embedding is None: 263 | raise NotFittedError( 264 | "self.embedding is None. Make sure to call 'model.fit()' " 265 | + "before calling 'model.align_with_grouping()'." 266 | ) 267 | embedding = self.embedding.copy() 268 | design_matrix = self.design_matrix 269 | if isinstance(grouping, list): 270 | grouping = pd.Series(grouping) 271 | 272 | if isinstance(grouping, pd.Series): 273 | grouping = grouping.factorize()[0] * 1.0 274 | grouping[grouping == -1] = np.nan 275 | 276 | al_coef = _align_impl( 277 | embedding, 278 | grouping, 279 | design_matrix, 280 | ridge_penalty=ridge_penalty, 281 | calculate_new_embedding=False, 282 | verbose=verbose, 283 | ) 284 | self.alignment_coefficients = al_coef 285 | self.embedding = _apply_linear_transformation(embedding, al_coef, design_matrix) 286 | return self 287 | 288 | def transform( 289 | self, 290 | adata: ad.AnnData, 291 | layer: str | None = None, 292 | obs_data: pd.DataFrame | Mapping[str, Iterable[Any]] | None = None, 293 | return_type: Literal["embedding", "LEMUR"] = "embedding", 294 | ): 295 | """Transform data using the fitted LEMUR model 296 | 297 | Parameters 298 | ---------- 299 | adata 300 | The AnnData object to transform. 301 | obs_data 302 | Optional set of annotations for each cell (same as `obs_data` in the 303 | constructor). 
304 | return_type 305 | Flag that decides if the function returns a full `LEMUR` object or 306 | only the embedding. 307 | 308 | Returns 309 | ------- 310 | :class:`~pylemur.tl.LEMUR` 311 | (if `return_type = "LEMUR"`) A new `LEMUR` object object with the embedding 312 | calculated for the input `adata`. 313 | 314 | :class:`~numpy.ndarray` 315 | (if `return_type = "embedding"`) A 2D numpy array of the embedding matrix 316 | calculated for the input `adata` (with cells in the rows and latent dimensions 317 | in the columns). 318 | """ 319 | Y = handle_data(adata, layer) 320 | adata.obs = handle_obs_data(adata, obs_data) 321 | design_matrix, _ = handle_design_parameter(self.formula, adata.obs) 322 | dm = design_matrix.to_numpy() 323 | Y_clean = Y - dm @ self.linear_coefficients 324 | embedding = project_data_on_diffemb( 325 | Y_clean, design_matrix=dm, coefficients=self.coefficients, base_point=self.base_point 326 | ) 327 | embedding = _apply_linear_transformation(embedding, self.alignment_coefficients, dm) 328 | if return_type == "embedding": 329 | return embedding 330 | elif return_type == "LEMUR": 331 | fit = LEMUR.copy() 332 | fit.adata = adata 333 | fit.design_matrix = design_matrix 334 | fit.embedding = embedding 335 | return fit 336 | 337 | def predict( 338 | self, 339 | embedding: np.ndarray | None = None, 340 | new_design: str | list[str] | np.ndarray | None = None, 341 | new_condition: np.ndarray | pd.DataFrame | None = None, 342 | obs_data: pd.DataFrame | Mapping[str, Iterable[Any]] | None = None, 343 | new_adata_layer: None | str = None, 344 | ): 345 | """Predict the expression of cells in a specific condition 346 | 347 | Parameters 348 | ---------- 349 | embedding 350 | The coordinates of the cells in the shared embedding space. If None, 351 | the coordinates stored in `model.embedding` are used. 352 | new_design 353 | Either a design formula parsed using `model.adata.obs` and `obs_data` or 354 | a design matrix defining the condition for each cell. 
If both `new_design` 355 | and `new_condition` are None, the original design matrix 356 | (`model.design_matrix`) is used. 357 | new_condition 358 | A specification of the new condition that is applied to all cells. Typically, 359 | this is generated by `cond(...)`. 360 | obs_data 361 | A DataFrame-like object containing cell-wise annotations. It is only used if `new_design` 362 | contains a formulaic formula string. 363 | new_adata_layer 364 | If not `None`, the function returns `self` and stores the prediction in 365 | `model.adata["new_adata_layer"]`. 366 | 367 | Returns 368 | ------- 369 | array-like, shape (n_cells, n_genes) 370 | The predicted expression of the cells in the new condition. 371 | """ 372 | if embedding is None: 373 | if self.embedding is None: 374 | raise NotFittedError("The model has not been fitted yet.") 375 | embedding = self.embedding 376 | 377 | if new_condition is not None: 378 | if new_design is not None: 379 | warnings.warn("new_design is ignored if new_condition is provided.", stacklevel=1) 380 | 381 | if isinstance(new_condition, pd.DataFrame): 382 | new_design = new_condition.to_numpy() 383 | elif isinstance(new_condition, pd.Series): 384 | new_design = np.expand_dims(new_condition.to_numpy(), axis=0) 385 | elif isinstance(new_condition, np.ndarray): 386 | new_design = new_condition 387 | else: 388 | raise ValueError("new_condition must be a created using 'cond(...)' or a numpy array.") 389 | if new_design.shape[0] != 1: 390 | raise ValueError("new_condition must only have one row") 391 | # Repeat values row-wise 392 | new_design = np.ones((embedding.shape[0], 1)) @ new_design 393 | elif new_design is None: 394 | new_design = self.design_matrix.to_numpy() 395 | else: 396 | new_design = handle_design_parameter(new_design, handle_obs_data(self.adata, obs_data))[0].to_numpy() 397 | 398 | # Make prediciton 399 | approx = new_design @ self.linear_coefficients 400 | 401 | coef = self.coefficients 402 | al_coefs = 
self.alignment_coefficients 403 | des_row_groups, reduced_design_matrix, des_row_group_ids = row_groups( 404 | new_design, return_reduced_matrix=True, return_group_ids=True 405 | ) 406 | for id in des_row_group_ids: 407 | covars = reduced_design_matrix[id, :] 408 | subspace = grassmann_map(np.dot(coef, covars).T, self.base_point.T) 409 | alignment = _reverse_linear_transformation(al_coefs, covars) 410 | offset = np.dot(al_coefs[:, 0, :], covars) 411 | approx[des_row_groups == id, :] += ( 412 | (embedding[des_row_groups == id, :] - offset) @ alignment.T 413 | ) @ subspace.T 414 | if new_adata_layer is not None: 415 | self.adata.layers[new_adata_layer] = approx 416 | return self 417 | else: 418 | return approx 419 | 420 | def cond(self, **kwargs): 421 | """Define a condition for the `predict` function. 422 | 423 | Parameters 424 | ---------- 425 | kwargs 426 | Named arguments specifying the levels of the covariates from the 427 | design formula. If a covariate is not specified, the first level is 428 | used. 429 | 430 | Returns 431 | ------- 432 | `pd.Series` 433 | A contrast vector that aligns to the columns of the design matrix. 434 | 435 | 436 | Notes 437 | ----- 438 | Subtracting two `cond(...)` calls, produces a contrast vector; these are 439 | commonly used in `R` to test for differences in a regression model. 440 | This pattern is inspired by the `R` package `glmGamPoi `__. 441 | """ 442 | if self.contrast_builder: 443 | return self.contrast_builder.cond(**kwargs) 444 | else: 445 | raise ValueError("The design was not specified as a formula. 
Cannot automatically construct contrast.") 446 | 447 | def copy(self, copy_adata = True): 448 | cp = copy.deepcopy(self) 449 | if copy_adata: 450 | cp.adata = self.adata.copy() 451 | return cp 452 | 453 | def __deepcopy__(self, memo): 454 | cp = LEMUR(self.adata, copy=False) 455 | cp.contrast_builder = self.contrast_builder 456 | cp.design_matrix = self.design_matrix 457 | cp.formula = self.formula 458 | cp.data_matrix = self.data_matrix 459 | cp.linear_coefficient_estimator = self.linear_coefficient_estimator 460 | cp.n_embedding = self.n_embedding 461 | cp.embedding = self.embedding 462 | cp.coefficients = self.coefficients 463 | cp.linear_coefficients = self.linear_coefficients 464 | cp.base_point = self.base_point 465 | cp.alignment_coefficients = self.alignment_coefficients 466 | return cp 467 | 468 | 469 | def __eq__(self, other): 470 | """Equality testing""" 471 | raise NotImplementedError( 472 | "Equality comparisons are not supported for AnnData objects, " 473 | "instead compare the desired attributes." 
474 | ) 475 | 476 | def __str__(self): 477 | if self.embedding is None: 478 | return f"LEMUR model (not fitted yet) with {self.n_embedding} dimensions" 479 | else: 480 | return f"LEMUR model with {self.n_embedding} dimensions" 481 | 482 | 483 | def _handle_data_arg(data): 484 | if isinstance(data, ad.AnnData): 485 | return data 486 | else: 487 | return ad.AnnData(data) 488 | 489 | 490 | def _order_axis_by_variance(embedding, coefficients, base_point): 491 | U, d, Vt = np.linalg.svd(embedding, full_matrices=False) 492 | base_point = Vt @ base_point 493 | coefficients = np.einsum("pq,qij->pij", Vt, coefficients) 494 | embedding = U @ np.diag(d) 495 | return embedding, coefficients, base_point 496 | -------------------------------------------------------------------------------- /tests/test_grasmann_lm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pylemur.tl._grassmann import grassmann_angle_from_point, grassmann_map, grassmann_project 4 | from pylemur.tl._grassmann_lm import ( 5 | grassmann_geodesic_regression, 6 | grassmann_lm, 7 | project_data_on_diffemb, 8 | project_diffemb_into_data_space, 9 | ) 10 | from pylemur.tl._lin_alg_wrappers import fit_pca 11 | 12 | 13 | def test_geodesic_regression(): 14 | n_feat = 17 15 | base_point = grassmann_project(np.random.randn(n_feat, 3)).T 16 | assert np.allclose(base_point @ base_point.T, np.eye(3)) 17 | coord_systems = [grassmann_project(np.random.randn(n_feat, 3)).T for _ in range(10)] 18 | x = np.arange(10) 19 | design = np.vstack([np.ones(10), x]).T 20 | 21 | fit = grassmann_geodesic_regression(coord_systems, design, base_point) 22 | assert fit.shape == (3, n_feat, 2) 23 | proj = grassmann_map(fit[:, :, 0].T, base_point.T) 24 | assert np.allclose(proj.T @ proj, np.eye(3)) 25 | proj = grassmann_map(fit[:, :, 1].T, base_point.T) 26 | assert np.allclose(proj.T @ proj, np.eye(3)) 27 | 28 | 29 | def test_grassmann_lm(): 30 | n_obs = 100 31 | base_point 
= grassmann_project(np.random.randn(5, 2)).T 32 | data = np.random.randn(n_obs, 5) 33 | plane_all = fit_pca(data, 2, center=False).coord_system 34 | des = np.ones((n_obs, 1)) 35 | fit = grassmann_lm(data, des, base_point) 36 | assert np.allclose(grassmann_angle_from_point(grassmann_map(fit[:, :, 0].T, base_point.T), plane_all.T), 0) 37 | 38 | # Make a design matrix of three groups (with an intercept) 39 | x = np.random.randint(3, size=n_obs) 40 | des = np.eye(3)[x, :] 41 | des = np.hstack([np.ones((n_obs, 1)), des[:, 1:3]]) 42 | fit = grassmann_lm(data, des, base_point) 43 | 44 | plane_a = fit_pca(data[x == 0], 2, center=False).coord_system 45 | plane_b = fit_pca(data[x == 1], 2, center=False).coord_system 46 | plane_c = fit_pca(data[x == 2], 2, center=False).coord_system 47 | 48 | assert np.allclose(grassmann_angle_from_point(grassmann_map(fit[:, :, 0].T, base_point.T), plane_a.T), 0) 49 | assert np.allclose( 50 | grassmann_angle_from_point(grassmann_map((fit[:, :, 0] + fit[:, :, 1]).T, base_point.T), plane_b.T), 0 51 | ) 52 | assert np.allclose( 53 | grassmann_angle_from_point(grassmann_map((fit[:, :, 0] + fit[:, :, 2]).T, base_point.T), plane_c.T), 0 54 | ) 55 | 56 | 57 | def test_project_data_on_diffemb(): 58 | n_obs = 100 59 | base_point = grassmann_project(np.random.randn(5, 2)).T 60 | data = np.random.randn(n_obs, 5) 61 | des = np.ones((n_obs, 1)) 62 | fit = grassmann_lm(data, des, base_point) 63 | pca = fit_pca(data, 2, center=False) 64 | angle = grassmann_angle_from_point(grassmann_map(fit[:, :, 0].T, base_point.T), pca.coord_system.T) 65 | assert np.allclose(angle, 0) 66 | 67 | proj = project_data_on_diffemb(data, des, fit, base_point) 68 | # The projection and the embedding are rotated to each other 69 | # Remove rotation effect using orthogonal procrustes 70 | U, _, Vt = np.linalg.svd(proj.T @ pca.embedding, full_matrices=False) 71 | rot = U @ Vt 72 | assert np.allclose(proj @ rot, pca.embedding) 73 | 74 | 75 | def test_project_data_on_diffemb2(): 76 | 
n_obs = 100 77 | base_point = grassmann_project(np.random.randn(5, 2)).T 78 | data = np.random.randn(n_obs, 5) 79 | des = np.ones((n_obs, 1)) 80 | fit = grassmann_lm(data, des, base_point) 81 | pca = fit_pca(data, 2, center=False) 82 | angle = grassmann_angle_from_point(grassmann_map(fit[:, :, 0].T, base_point.T), pca.coord_system.T) 83 | assert np.allclose(angle, 0) 84 | 85 | proj = project_data_on_diffemb(data, des, fit, base_point) 86 | data_hat1 = pca.embedding @ pca.coord_system 87 | data_hat2 = project_diffemb_into_data_space(proj, des, fit, base_point) 88 | assert np.allclose(data_hat1, data_hat2) 89 | -------------------------------------------------------------------------------- /tests/test_grassmann.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pylemur.tl._grassmann import * 4 | 5 | 6 | def test_grassmann_map(): 7 | # Test case 1: empty base point 8 | x = np.array([[1, 2], [3, 4], [5, 6]]) 9 | p = np.array([]) 10 | assert np.array_equal(grassmann_map(x, p), p) 11 | 12 | # Test case 2: x contains NaN values 13 | x = np.array([[1, 2], [3, np.nan], [5, 6]]) 14 | p = np.array([[1, 0], [0, 1], [0, 0]]) 15 | assert np.isnan(grassmann_map(x, p)).all() 16 | 17 | p = grassmann_random_point(5, 2) 18 | assert np.allclose(p.T @ p, np.eye(2)) 19 | v = grassmann_random_tangent(p) * 10 20 | assert np.linalg.svd(v, compute_uv=False)[0] / np.pi * 180 > 90 21 | 22 | assert np.allclose(p.T @ v + v.T @ p, np.zeros((2, 2))) 23 | p2 = grassmann_map(v, p) 24 | valt = grassmann_log(p, p2) 25 | p3 = grassmann_map(valt, p) 26 | 27 | assert grassmann_angle_from_point(p3, p2) < 1e-10 28 | # assert np.linalg.matrix_rank(np.hstack([p3, p2])) == 2 29 | assert (valt**2).sum() < (v**2).sum() 30 | 31 | p4 = grassmann_random_point(5, 2) 32 | p5 = grassmann_random_point(5, 2) 33 | v45 = grassmann_log(p4, p5) 34 | assert np.allclose(p4.T @ v45 + v45.T @ p4, np.zeros((2, 2))) 35 | # This failed randomly on the github 
runner 36 | # assert np.linalg.matrix_rank(np.hstack([grassmann_map(v45, p4), p5])) == 2 37 | assert np.allclose(grassmann_log(p4, grassmann_map(v45, p4)), v45) 38 | 39 | 40 | def test_experiment(): 41 | assert 1 == 3 - 2 42 | -------------------------------------------------------------------------------- /tests/test_lemur.py: -------------------------------------------------------------------------------- 1 | from io import StringIO 2 | 3 | import anndata as ad 4 | import formulaic 5 | import numpy as np 6 | import pandas as pd 7 | 8 | import pylemur.tl._grassmann 9 | import pylemur.tl.alignment 10 | import pylemur.tl.lemur 11 | 12 | 13 | def test_design_specification_works(): 14 | Y = np.random.randn(500, 30) 15 | dat = pd.DataFrame({"condition": np.random.choice(["trt", "ctrl"], size=500)}) 16 | model = pylemur.tl.LEMUR(Y, design="~ condition", obs_data=dat) 17 | assert model.design_matrix.equals(formulaic.model_matrix("~ condition", dat)) 18 | 19 | adata = ad.AnnData(Y) 20 | model = pylemur.tl.LEMUR(adata, design="~ condition", obs_data=dat) 21 | assert model.design_matrix.equals(formulaic.model_matrix("~ condition", dat)) 22 | 23 | adata = ad.AnnData(Y, obs=dat) 24 | model = pylemur.tl.LEMUR(adata, design="~ condition") 25 | assert model.design_matrix.equals(formulaic.model_matrix("~ condition", dat)) 26 | 27 | 28 | def test_numpy_design_matrix_works(): 29 | Y = np.random.randn(500, 30) 30 | dat = pd.DataFrame({"condition": np.random.choice(["trt", "ctrl"], size=500)}) 31 | design_mat = formulaic.model_matrix("~ condition", dat).to_numpy() 32 | grouping = np.random.choice(2, size=500) 33 | 34 | ref_model = pylemur.tl.LEMUR(Y, design="~ condition", obs_data=dat).fit().align_with_grouping(grouping) 35 | 36 | model = pylemur.tl.LEMUR(Y, design=design_mat).fit().align_with_grouping(grouping) 37 | assert np.allclose(model.coefficients, ref_model.coefficients) 38 | 39 | model = pylemur.tl.LEMUR(ad.AnnData(Y), design=design_mat).fit().align_with_grouping(grouping) 
40 | assert np.allclose(model.coefficients, ref_model.coefficients) 41 | 42 | design_df = pd.DataFrame(design_mat, columns=["Intercept", "Covar1"]) 43 | model = pylemur.tl.LEMUR(Y, design=design_df).fit().align_with_grouping(grouping) 44 | assert np.allclose(model.coefficients, ref_model.coefficients) 45 | assert np.allclose(model.alignment_coefficients, ref_model.alignment_coefficients) 46 | 47 | 48 | def test_pandas_design_matrix_works(): 49 | Y = np.random.randn(500, 30) 50 | dat = pd.DataFrame({"condition": np.random.choice(["trt", "ctrl"], size=500)}) 51 | design_mat = formulaic.model_matrix("~ condition", dat) 52 | 53 | ref_model = pylemur.tl.LEMUR(Y, design="~ condition", obs_data=dat).fit() 54 | 55 | model = pylemur.tl.LEMUR(Y, design=design_mat).fit() 56 | assert np.allclose(model.coefficients, ref_model.coefficients) 57 | 58 | model = pylemur.tl.LEMUR(ad.AnnData(Y), design=design_mat).fit() 59 | assert np.allclose(model.coefficients, ref_model.coefficients) 60 | 61 | 62 | def test_copy_works(): 63 | Y = np.random.randn(500, 30) 64 | dat = pd.DataFrame({"condition": np.random.choice(["trt", "ctrl"], size=500)}) 65 | model = pylemur.tl.LEMUR(Y, design="~ condition", obs_data=dat) 66 | cp = model.copy(copy_adata=False) 67 | cp2 = model.copy(copy_adata=True) 68 | assert id(model.adata) == id(cp.adata) 69 | assert id(model.adata) != id(cp2.adata) 70 | _assert_lemur_model_equal(model, cp) 71 | _assert_lemur_model_equal(model, cp2, adata_id_equal=False) 72 | 73 | 74 | def test_predict(): 75 | ## Make sure I get the same results as in R 76 | # save_func <- function(obj){ 77 | # readr::format_csv(as.data.frame(obj), col_names = FALSE) 78 | # } 79 | # randn <- function(n, m, ...){ 80 | # matrix(rnorm(n * m, ...), nrow = n, ncol = m) 81 | # } 82 | # 83 | # set.seed(1) 84 | # Y <- round(randn(5, 30), 2) 85 | # design <- cbind(1, rep(1:2, each = 15)) 86 | # fit <- lemur(Y, design = design, n_embedding = 3, test_fraction = 0) 87 | # save_func(Y) 88 | # 
save_func(design) 89 | # save_func(predict(fit)) 90 | Y = np.genfromtxt( 91 | StringIO( 92 | "-0.63,-0.82,1.51,-0.04,0.92,-0.06,1.36,-0.41,-0.16,-0.71,0.4,1.98,2.4,0.19,0.48,0.29,-0.57,0.33,-0.54,0.56,-0.62,1.77,-0.64,-0.39,-0.51,0.71,0.06,-1.54,-1.91,-0.75\n0.18,0.49,0.39,-0.02,0.78,-0.16,-0.1,-0.39,-0.25,0.36,-0.61,-0.37,-0.04,-1.8,-0.71,-0.44,-0.14,1.06,1.21,-1.28,0.04,0.72,-0.46,-0.32,1.34,-0.07,-0.59,-0.3,1.18,2.09\n-0.84,0.74,-0.62,0.94,0.07,-1.47,0.39,-0.06,0.7,0.77,0.34,-1.04,0.69,1.47,0.61,0,1.18,-0.3,1.16,-0.57,-0.91,0.91,1.43,-0.28,-0.21,-0.04,0.53,-0.53,-1.66,0.02\n1.6,0.58,-2.21,0.82,-1.99,-0.48,-0.05,1.1,0.56,-0.11,-1.13,0.57,0.03,0.15,-0.93,0.07,-1.52,0.37,0.7,-1.22,0.16,0.38,-0.65,0.49,-0.18,-0.68,-1.52,-0.65,-0.46,-1.29\n0.33,-0.31,1.12,0.59,0.62,0.42,-1.38,0.76,-0.69,0.88,1.43,-0.14,-0.74,2.17,-1.25,-0.59,0.59,0.27,1.59,-0.47,-0.65,1.68,-0.21,-0.18,-0.1,-0.32,0.31,-0.06,-1.12,-1.64\n" 93 | ), 94 | delimiter=",", 95 | ).T 96 | design = np.genfromtxt( 97 | StringIO( 98 | "1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n" 99 | ), 100 | delimiter=",", 101 | ) 102 | results_init = np.genfromtxt( 103 | StringIO( 104 | 
"-0.8554888214077533,-0.2720977014782501,1.467533504177673,-0.22767956411506196,1.2785521085014837,0.14379806862053318,1.4750333827944937,-0.6570301309706671,0.2066831996074998,-0.353750513692208,0.40174695240184183,1.0849159925675056,1.8230740377826868,-0.3796132967907797,1.27432278200101,0.25734634440017745,-0.48925340880447665,0.15787257755416356,0.04938969573830959,0.3192312181063839,-0.3542994869485033,1.401448533272853,-0.3938541882847783,0.08741246846567274,-0.6510475608713524,0.22923361101684514,-0.06804635141148654,-0.8510364064773515,-1.6943910912740754,-1.7500059544823854\n0.22583899533086937,-0.23735468442585927,0.19733713734339808,-0.4421331667319863,0.03586724830483476,0.4198858825154666,-0.08304003977298911,-0.16513486823394496,-0.30042983957286806,-0.3460343445737336,-0.3558867742284205,0.3011982462662531,-0.20850076449582816,-1.0161941132088552,-0.27541891451634165,-0.38785291893731433,-0.12147153601775817,0.9332086847171666,1.6012153818454613,-1.4764027185416573,0.23315085154197657,0.4422452475355182,-0.20026735579063953,0.04360380473394784,1.2311912129895743,-0.3862511983506538,-0.7251300187432477,0.11080434946587961,1.2741080839827195,1.467848129569024\n-0.7940511553704984,0.36872650465867485,-0.696190082245086,0.7842394355421791,-0.2859846254450281,-1.2496520777731421,0.3835639668947345,0.06394998322061843,0.6354423384795286,0.43034530046034825,0.446899273641681,-0.6513965946225246,0.6871089757451906,1.8676308579518373,0.6993678988614832,-0.35853838671959254,1.3529850324988306,-0.246355558204097,1.1840747073603424,-0.40278806288570534,-0.980382841146775,1.0552239017849732,0.9738387917168745,-0.4865414099512409,-0.14227709749759898,-0.07349599449524213,0.7673735611128623,-0.27683325130457614,-1.4140680282113947,-0.2222153640576589\n1.418455323621421,1.003633642928876,-2.249906497292914,0.6558257816955055,-1.7203847145600766,-0.29827588824455953,0.043760865569417634,0.9060713852440867,0.8559072353487625,0.15882996592743862,-1.1213675611187175,-0.1
3675398777029954,-0.4426075110785659,-0.28963584470503034,-0.27355219556533633,-0.22900366687718224,-1.4215471940245057,0.5395008912619365,0.30620330206252744,-0.9001187449182614,-0.09048745003886999,0.7701978486901692,-1.2331890579233868,-0.0319982155355501,-0.019760633637726566,-0.37122588017711394,-1.2165536309067462,-0.9082217868619842,-0.3910853125055948,-0.8027104686077078\n0.16860548847823306,0.019631704090379926,1.069169569969778,0.4089409239787705,0.8084926808056435,0.6289431224230143,-1.2935687373997728,0.6009167939510944,-0.42509727026387234,1.0706646727858744,1.4570831601240841,-0.7309939926404594,-1.1819487005772338,1.8301112490396232,-0.6209506647651512,-0.1631607424281398,0.3494484690478879,0.3004091825038984,1.248612413250847,-0.5327275335431212,-0.7111287487922291,1.7104365790742557,0.179842226080069,-0.19838902997632976,-0.10222070429937023,-0.02570345245587259,0.10711511924305978,-0.7160399504157087,-1.5156808495142373,-0.8308129777750086\n" 105 | ), 106 | delimiter=",", 107 | ).T 108 | 109 | model = pylemur.tl.LEMUR(Y, design=design, n_embedding=3) 110 | model.fit(verbose=False) 111 | pred_init = model.predict() 112 | assert np.allclose(pred_init, results_init) 113 | 114 | ## Here I have to override all coefficients to make sure that the results match 115 | # set.seed(1) 116 | # wide_alignment_coef <- round(cbind(randn(3,4), randn(3,4)), 2) 117 | # fit@metadata$alignment_coefficients <- array(wide_alignment_coef, dim = c(3,4,2)) 118 | # res <- predict(fit) 119 | # save_func(fit$base_point) 120 | # save_func(fit$embedding) 121 | # save_func(wide_alignment_coef) 122 | # save_func(cbind(fit$coefficients[,,1], fit$coefficients[,,2])) 123 | # save_func(res) 124 | basepoint = np.genfromtxt( 125 | StringIO( 126 | 
"-0.38925206703492543,0.5213219572509142,-0.7231411670117558\n-0.014297729446464128,-0.3059067922758592,0.056315487696737755\n0.2455091078121989,0.48054559390052237,0.133975572527343\n0.7927765077921876,-0.20187676099522883,-0.5743821409539152\n0.39938589098238725,0.602466726782849,0.3550086203707195\n" 127 | ), 128 | delimiter=",", 129 | ).T 130 | assert np.allclose(pylemur.tl._grassmann.grassmann_angle_from_point(basepoint.T, model.base_point.T), 0) 131 | emb = np.genfromtxt( 132 | StringIO( 133 | "1.307607290530331,1.204126637770359,-2.334941421246343,1.202469151007156,-1.7184261449703435,-0.5734035327253721,-0.6805385978545121,1.3185253662996623,0.8741076652495349,0.8858962317444363,-0.34911335825714623,-0.9635203043565856,-1.052446567304073,1.396858005348387,-0.5172004212354832,-0.5028122904697048,-0.3322559893690832,1.050654556399191,2.1375501746096512,-1.7864259568934735,-0.38553520332697033,1.7579862599667835,-0.44888734172910766,-0.14472710389159946,0.8047342904549903,-0.4509018220855433,-0.8826656042763416,-0.7455036156569727,-0.28188072558002397,0.2106703718482004\n-1.8164798196642507,-0.477295729296259,0.4790240845360194,0.1574495019414363,0.5418966595889152,-1.218733839936566,-0.030245840455864665,-0.6590134762699248,-0.17773495609983386,0.16681479907923033,0.9577303872796369,-0.8649693035413422,0.5516152492211109,1.872319735730998,0.5176225478866773,0.1933219672043183,1.1600064890134691,-0.2055279125365953,0.8140862711968238,0.5522995484090838,-1.0056198219223997,2.0135021529189063,0.9037871448321165,-0.1835087018953936,-0.7687706326544156,0.4350873523204601,1.1121743457677276,-0.6639780098588698,-2.6747181225665626,-1.6821420702286731\n0.18948314499743618,-0.35557961492948087,1.2558867706759906,-0.06200523822352812,0.8150276978150715,0.9017861009351356,-1.7593369587152685,0.38440670001846194,-0.9259345355859308,0.7952852407963301,1.2519868131937648,-0.7881479698006301,-1.7254602945657311,1.1468879075695446,-1.1242857641811612,-0.777687817264553,1.2908
968434310275,-0.7487113052869242,0.3633562246175412,-0.7434161046620548,-0.6564971373410212,-1.062751302635433,0.9113503972724165,-0.742431234492101,0.1738627261607038,-0.5327010359755776,0.46611269007910494,0.42937368327065845,0.34776880301511215,1.281474569811115\n" 134 | ), 135 | delimiter=",", 136 | ).T 137 | wide_alignment_coef = np.genfromtxt( 138 | StringIO( 139 | "0,-0.63,1.6,0.49,0,-0.31,-0.62,-0.04\n0,0.18,0.33,0.74,0,1.51,-2.21,-0.02\n0,-0.84,-0.82,0.58,0,0.39,1.12,0.94\n" 140 | ), 141 | delimiter=",", 142 | ) 143 | alignment_coef = np.stack([wide_alignment_coef[:, 0:4], wide_alignment_coef[:, 4:8]], axis=2) 144 | coef = np.genfromtxt( 145 | StringIO( 146 | "-0.39126157680942947,-0.13893603024573095,0.22217190036192688,0.2997192813183081,0.045415132726990666,-0.11388588356030183\n-1.4118852084954074,-0.12010846331372559,-0.042620766619525983,1.1685923275506187,-0.009795610498605375,0.1481261731296855\n0.4301922175223224,0.8308632945684709,-1.7460499853532059,-0.17472584078151635,-0.3588445576311046,1.1196336953167267\n0.010836267191959998,-0.020245706071854806,0.04720615470254291,-0.013801681502086402,0.009718078909566252,-0.032178386499295816\n-0.7178345111625161,-0.6102684315610786,1.1946309212817439,0.46875216448007323,0.2452095180077213,-0.7300769880872995\n" 147 | ), 148 | delimiter=",", 149 | ) 150 | coef = np.stack([coef[:, 0:3], coef[:, 3:6]], axis=2) 151 | coef = np.einsum("ijk->jik", coef) 152 | results = np.genfromtxt( 153 | StringIO( 154 | 
"0.9835142473010876,0.5282279297874248,0.3734623735376733,0.29148494672406977,0.3233549213466831,0.8738371099709996,0.5407962262440337,0.5618064038652619,0.4615777825112788,0.28201578105225994,0.06759705564456442,0.8278171090396453,0.3564197483842603,-0.37494408404208857,0.3130324486328697,4.43109412669912,4.485260023784498,-13.556871509080235,-25.740741431276582,19.3507385539798,4.267028891777953,-24.31631072038289,5.55415174991481,0.5751412151679098,-8.850525840651931,3.9506221824433108,9.87464782407815,9.75898404222322,5.923544047383523,0.5432368439394216\n-0.5884851968278264,-0.6980003631236811,1.0095730179286158,-0.706603653513592,0.6775999178228775,0.2960753981255735,-0.04883524121983143,-0.6502161106433406,-0.6374945628932893,-0.46524145497379454,0.0868361701396683,0.24887091370316872,0.08165271296897453,-0.7703009490648383,-0.08543059842869558,1.8195448028133048,2.2357171298099003,-4.317401830679442,-8.088317047084285,6.921096710379453,1.5046074415988786,-7.462954899818863,2.5044021915523067,0.44085644478623487,-2.7813085924525587,1.7337832588534812,3.965667679124676,3.523130209648968,1.7727989396531905,0.2683775618147773\n1.2241353640082717,1.439811782612444,-2.4939338263214896,1.413000126505777,-1.732752410802792,-0.8109483715136666,0.060386636787661846,1.297799915381223,1.325781352884555,0.8252158435825825,-0.46440880845174237,-0.6215041984220009,-0.2577507799745221,1.405770288579147,0.07939708514456326,3.63765958748972,4.311718634450114,-10.958511305614007,-20.4685709755674,15.761881934367485,3.4712598990276082,-19.643794665294983,5.058176520273988,0.500553412576863,-6.917971274367058,3.329924337359561,8.448268016100178,8.2448015916862,5.000276047604282,0.9543282399075053\n-0.7921997533676974,-0.22060767244684792,-0.13854023872636448,-0.07315533443899382,-0.04074837994335717,-0.6868254379104765,0.30181730830558207,-0.4486065557871184,0.016532124694023168,-0.2487768721959298,-0.043942589618246174,-0.18230527640412342,0.5013051052792722,0.22675447837288173
,0.33929909418738713,3.3366520193532914,3.389761379749944,-10.811207239740792,-20.286351865125106,15.07164187230232,3.0442924834976672,-18.933270905633723,4.214979326150896,0.264093656883327,-7.2491426213752,2.978123336697904,7.6524353015805255,7.331256553093012,4.067336259849709,-0.0705995572837037\n0.429231215968942,0.4526675323865768,-0.0583697984419273,0.6074892682578291,0.02405662404573697,0.13577913690270105,-0.3608880085761539,0.6353597043295498,0.26375995622833437,0.7194765847659087,0.5782845037350453,-0.326907098443994,-0.38119677146460307,1.1810728627919298,-0.08981571248586934,5.740966466337848,6.297668960348173,-17.056604679816925,-32.1536933270833,24.663672997961317,5.422318375098499,-30.48048219956338,7.551698596077222,0.8269489489661501,-10.99498861966501,5.205042841546413,12.959282592098615,12.608602675876007,7.517277260484411,0.9922891113340572\n" 155 | ), 156 | delimiter=",", 157 | ).T 158 | 159 | model.alignment_coefficients = alignment_coef 160 | model.base_point = basepoint 161 | model.embedding = emb 162 | model.coefficients = coef 163 | pred_new = model.predict() 164 | assert np.allclose(pred_new, results) 165 | 166 | 167 | def test_align_with_grouping(): 168 | ## Make sure I get the same results as in R 169 | ## Same code as above 170 | Y = np.genfromtxt( 171 | StringIO( 172 | 
"-0.63,-0.82,1.51,-0.04,0.92,-0.06,1.36,-0.41,-0.16,-0.71,0.4,1.98,2.4,0.19,0.48,0.29,-0.57,0.33,-0.54,0.56,-0.62,1.77,-0.64,-0.39,-0.51,0.71,0.06,-1.54,-1.91,-0.75\n0.18,0.49,0.39,-0.02,0.78,-0.16,-0.1,-0.39,-0.25,0.36,-0.61,-0.37,-0.04,-1.8,-0.71,-0.44,-0.14,1.06,1.21,-1.28,0.04,0.72,-0.46,-0.32,1.34,-0.07,-0.59,-0.3,1.18,2.09\n-0.84,0.74,-0.62,0.94,0.07,-1.47,0.39,-0.06,0.7,0.77,0.34,-1.04,0.69,1.47,0.61,0,1.18,-0.3,1.16,-0.57,-0.91,0.91,1.43,-0.28,-0.21,-0.04,0.53,-0.53,-1.66,0.02\n1.6,0.58,-2.21,0.82,-1.99,-0.48,-0.05,1.1,0.56,-0.11,-1.13,0.57,0.03,0.15,-0.93,0.07,-1.52,0.37,0.7,-1.22,0.16,0.38,-0.65,0.49,-0.18,-0.68,-1.52,-0.65,-0.46,-1.29\n0.33,-0.31,1.12,0.59,0.62,0.42,-1.38,0.76,-0.69,0.88,1.43,-0.14,-0.74,2.17,-1.25,-0.59,0.59,0.27,1.59,-0.47,-0.65,1.68,-0.21,-0.18,-0.1,-0.32,0.31,-0.06,-1.12,-1.64\n" 173 | ), 174 | delimiter=",", 175 | ).T 176 | design = np.genfromtxt( 177 | StringIO( 178 | "1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,1\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n1,2\n" 179 | ), 180 | delimiter=",", 181 | ) 182 | results_init = np.genfromtxt( 183 | StringIO( 184 | 
"-0.8554888214077533,-0.2720977014782501,1.467533504177673,-0.22767956411506196,1.2785521085014837,0.14379806862053318,1.4750333827944937,-0.6570301309706671,0.2066831996074998,-0.353750513692208,0.40174695240184183,1.0849159925675056,1.8230740377826868,-0.3796132967907797,1.27432278200101,0.25734634440017745,-0.48925340880447665,0.15787257755416356,0.04938969573830959,0.3192312181063839,-0.3542994869485033,1.401448533272853,-0.3938541882847783,0.08741246846567274,-0.6510475608713524,0.22923361101684514,-0.06804635141148654,-0.8510364064773515,-1.6943910912740754,-1.7500059544823854\n0.22583899533086937,-0.23735468442585927,0.19733713734339808,-0.4421331667319863,0.03586724830483476,0.4198858825154666,-0.08304003977298911,-0.16513486823394496,-0.30042983957286806,-0.3460343445737336,-0.3558867742284205,0.3011982462662531,-0.20850076449582816,-1.0161941132088552,-0.27541891451634165,-0.38785291893731433,-0.12147153601775817,0.9332086847171666,1.6012153818454613,-1.4764027185416573,0.23315085154197657,0.4422452475355182,-0.20026735579063953,0.04360380473394784,1.2311912129895743,-0.3862511983506538,-0.7251300187432477,0.11080434946587961,1.2741080839827195,1.467848129569024\n-0.7940511553704984,0.36872650465867485,-0.696190082245086,0.7842394355421791,-0.2859846254450281,-1.2496520777731421,0.3835639668947345,0.06394998322061843,0.6354423384795286,0.43034530046034825,0.446899273641681,-0.6513965946225246,0.6871089757451906,1.8676308579518373,0.6993678988614832,-0.35853838671959254,1.3529850324988306,-0.246355558204097,1.1840747073603424,-0.40278806288570534,-0.980382841146775,1.0552239017849732,0.9738387917168745,-0.4865414099512409,-0.14227709749759898,-0.07349599449524213,0.7673735611128623,-0.27683325130457614,-1.4140680282113947,-0.2222153640576589\n1.418455323621421,1.003633642928876,-2.249906497292914,0.6558257816955055,-1.7203847145600766,-0.29827588824455953,0.043760865569417634,0.9060713852440867,0.8559072353487625,0.15882996592743862,-1.1213675611187175,-0.1
3675398777029954,-0.4426075110785659,-0.28963584470503034,-0.27355219556533633,-0.22900366687718224,-1.4215471940245057,0.5395008912619365,0.30620330206252744,-0.9001187449182614,-0.09048745003886999,0.7701978486901692,-1.2331890579233868,-0.0319982155355501,-0.019760633637726566,-0.37122588017711394,-1.2165536309067462,-0.9082217868619842,-0.3910853125055948,-0.8027104686077078\n0.16860548847823306,0.019631704090379926,1.069169569969778,0.4089409239787705,0.8084926808056435,0.6289431224230143,-1.2935687373997728,0.6009167939510944,-0.42509727026387234,1.0706646727858744,1.4570831601240841,-0.7309939926404594,-1.1819487005772338,1.8301112490396232,-0.6209506647651512,-0.1631607424281398,0.3494484690478879,0.3004091825038984,1.248612413250847,-0.5327275335431212,-0.7111287487922291,1.7104365790742557,0.179842226080069,-0.19838902997632976,-0.10222070429937023,-0.02570345245587259,0.10711511924305978,-0.7160399504157087,-1.5156808495142373,-0.8308129777750086\n" 185 | ), 186 | delimiter=",", 187 | ).T 188 | 189 | model = pylemur.tl.LEMUR(Y, design=design, n_embedding=3) 190 | model.fit(verbose=False) 191 | pred_init = model.predict() 192 | assert np.allclose(pred_init, results_init) 193 | 194 | ## Here is a modified version of the approach from the previous test 195 | # set.seed(1) 196 | # grouping <- sample(c(1,2), size = 30, replace = TRUE) 197 | # save_func(grouping) 198 | # fit <- align_by_grouping(fit, grouping = as.character(grouping), design = fit$design_matrix) 199 | # save_func(cbind(fit$alignment_coefficients[,,1], fit$alignment_coefficients[,,2])) 200 | # ... 
201 | grouping = np.genfromtxt( 202 | StringIO("1\n2\n1\n1\n2\n1\n1\n1\n2\n2\n1\n1\n1\n1\n1\n2\n2\n2\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n") 203 | ) 204 | 205 | basepoint = np.genfromtxt( 206 | StringIO( 207 | "-0.38925206703492543,0.5213219572509142,-0.7231411670117558\n-0.014297729446464128,-0.3059067922758592,0.056315487696737755\n0.2455091078121989,0.48054559390052237,0.133975572527343\n0.7927765077921876,-0.20187676099522883,-0.5743821409539152\n0.39938589098238725,0.602466726782849,0.3550086203707195\n" 208 | ), 209 | delimiter=",", 210 | ).T 211 | assert np.allclose(pylemur.tl._grassmann.grassmann_angle_from_point(basepoint.T, model.base_point.T), 0) 212 | emb = np.genfromtxt( 213 | StringIO( 214 | "1.307607290530331,1.204126637770359,-2.334941421246343,1.202469151007156,-1.7184261449703435,-0.5734035327253721,-0.6805385978545121,1.3185253662996623,0.8741076652495349,0.8858962317444363,-0.34911335825714623,-0.9635203043565856,-1.052446567304073,1.396858005348387,-0.5172004212354832,-0.5028122904697048,-0.3322559893690832,1.050654556399191,2.1375501746096512,-1.7864259568934735,-0.38553520332697033,1.7579862599667835,-0.44888734172910766,-0.14472710389159946,0.8047342904549903,-0.4509018220855433,-0.8826656042763416,-0.7455036156569727,-0.28188072558002397,0.2106703718482004\n-1.8164798196642507,-0.477295729296259,0.4790240845360194,0.1574495019414363,0.5418966595889152,-1.218733839936566,-0.030245840455864665,-0.6590134762699248,-0.17773495609983386,0.16681479907923033,0.9577303872796369,-0.8649693035413422,0.5516152492211109,1.872319735730998,0.5176225478866773,0.1933219672043183,1.1600064890134691,-0.2055279125365953,0.8140862711968238,0.5522995484090838,-1.0056198219223997,2.0135021529189063,0.9037871448321165,-0.1835087018953936,-0.7687706326544156,0.4350873523204601,1.1121743457677276,-0.6639780098588698,-2.6747181225665626,-1.6821420702286731\n0.18948314499743618,-0.35557961492948087,1.2558867706759906,-0.06200523822352812,0.8150276978150715,0.901786100
9351356,-1.7593369587152685,0.38440670001846194,-0.9259345355859308,0.7952852407963301,1.2519868131937648,-0.7881479698006301,-1.7254602945657311,1.1468879075695446,-1.1242857641811612,-0.777687817264553,1.2908968434310275,-0.7487113052869242,0.3633562246175412,-0.7434161046620548,-0.6564971373410212,-1.062751302635433,0.9113503972724165,-0.742431234492101,0.1738627261607038,-0.5327010359755776,0.46611269007910494,0.42937368327065845,0.34776880301511215,1.281474569811115\n" 215 | ), 216 | delimiter=",", 217 | ).T 218 | coef = np.genfromtxt( 219 | StringIO( 220 | "-0.39126157680942947,-0.13893603024573095,0.22217190036192688,0.2997192813183081,0.045415132726990666,-0.11388588356030183\n-1.4118852084954074,-0.12010846331372559,-0.042620766619525983,1.1685923275506187,-0.009795610498605375,0.1481261731296855\n0.4301922175223224,0.8308632945684709,-1.7460499853532059,-0.17472584078151635,-0.3588445576311046,1.1196336953167267\n0.010836267191959998,-0.020245706071854806,0.04720615470254291,-0.013801681502086402,0.009718078909566252,-0.032178386499295816\n-0.7178345111625161,-0.6102684315610786,1.1946309212817439,0.46875216448007323,0.2452095180077213,-0.7300769880872995\n" 221 | ), 222 | delimiter=",", 223 | ) 224 | coef = np.stack([coef[:, 0:3], coef[:, 3:6]], axis=2) 225 | coef = np.einsum("ijk->jik", coef) 226 | 227 | wide_alignment_coef = np.genfromtxt( 228 | StringIO( 229 | "-0.09057594036322775,0.004854216024969793,0.002327313009887152,0.005758995615529929,0.05944370038362205,-0.0035991050730759615,-0.0020095175466780125,-0.005428755213860675\n-0.05277256958166519,0.042258290722341184,0.020260381752233707,0.05013483325379276,0.02699629278032626,-0.03133194479519331,-0.017493819035318634,-0.04725993134785861\n-0.07693347488863567,0.04690711963321551,0.02248922363923753,0.0556501595551681,0.04220308611767923,-0.03477876785187711,-0.019418311727825953,-0.052458990075042476\n" 230 | ), 231 | delimiter=",", 232 | ) 233 | alignment_coef = 
np.stack([wide_alignment_coef[:, 0:4], wide_alignment_coef[:, 4:8]], axis=2) 234 | results = np.genfromtxt( 235 | StringIO( 236 | "-0.8554888214077531,-0.27209770147824963,1.4675335041776725,-0.22767956411506185,1.2785521085014835,0.14379806862053301,1.4750333827944933,-0.6570301309706669,0.20668319960749992,-0.3537505136922079,0.4017469524018417,1.0849159925675056,1.8230740377826868,-0.3796132967907798,1.27432278200101,0.25734634440017756,-0.4892534088044769,0.15787257755416356,0.049389695738309536,0.3192312181063839,-0.3542994869485032,1.4014485332728528,-0.3938541882847784,0.0874124684656728,-0.6510475608713524,0.22923361101684514,-0.0680463514114866,-0.8510364064773513,-1.6943910912740756,-1.7500059544823856\n0.22583899533086926,-0.2373546844258592,0.19733713734339786,-0.4421331667319863,0.03586724830483465,0.4198858825154665,-0.08304003977298917,-0.16513486823394496,-0.300429839572868,-0.34603434457373355,-0.3558867742284205,0.3011982462662531,-0.20850076449582813,-1.0161941132088552,-0.2754189145163416,-0.38785291893731433,-0.12147153601775795,0.9332086847171664,1.601215381845461,-1.476402718541657,0.23315085154197665,0.44224524753551797,-0.20026735579063937,0.043603804733947815,1.2311912129895743,-0.3862511983506537,-0.7251300187432477,0.11080434946587975,1.27410808398272,1.467848129569024\n-0.7940511553704983,0.36872650465867474,-0.6961900822450855,0.784239435542179,-0.2859846254450279,-1.249652077773142,0.3835639668947346,0.0639499832206183,0.6354423384795286,0.43034530046034813,0.446899273641681,-0.6513965946225245,0.6871089757451905,1.8676308579518373,0.6993678988614832,-0.3585383867195926,1.352985032498831,-0.24635555820409705,1.1840747073603421,-0.40278806288570546,-0.9803828411467752,1.0552239017849732,0.9738387917168745,-0.486541409951241,-0.1422770974975989,-0.07349599449524222,0.7673735611128623,-0.27683325130457614,-1.4140680282113949,-0.22221536405765868\n1.4184553236214208,1.0036336429288755,-2.2499064972929133,0.6558257816955054,-1.7203847145600
764,-0.2982758882445596,0.043760865569417745,0.9060713852440865,0.8559072353487623,0.15882996592743856,-1.1213675611187173,-0.13675398777029935,-0.4426075110785658,-0.28963584470503045,-0.2735521955653362,-0.22900366687718213,-1.4215471940245057,0.5395008912619363,0.3062033020625272,-0.9001187449182614,-0.09048745003886988,0.7701978486901689,-1.2331890579233868,-0.031998215535550045,-0.019760633637726732,-0.37122588017711383,-1.2165536309067464,-0.9082217868619842,-0.3910853125055947,-0.802710468607708\n0.1686054884782331,0.019631704090379815,1.0691695699697783,0.40894092397877047,0.8084926808056437,0.6289431224230144,-1.2935687373997728,0.6009167939510944,-0.42509727026387234,1.0706646727858744,1.4570831601240841,-0.7309939926404595,-1.1819487005772338,1.8301112490396232,-0.6209506647651512,-0.16316074242813974,0.34944846904788796,0.3004091825038983,1.248612413250847,-0.5327275335431213,-0.7111287487922291,1.7104365790742553,0.179842226080069,-0.19838902997632976,-0.10222070429937029,-0.02570345245587259,0.10711511924305983,-0.7160399504157086,-1.5156808495142378,-0.8308129777750086\n" 237 | ), 238 | delimiter=",", 239 | ).T 240 | 241 | model.base_point = basepoint 242 | model.embedding = emb 243 | model.coefficients = coef 244 | model.align_with_grouping(grouping) 245 | assert np.allclose(model.alignment_coefficients, alignment_coef) 246 | 247 | pred_new = model.predict() 248 | assert np.allclose(pred_new, results) 249 | assert np.allclose(pred_new, pred_init) 250 | 251 | 252 | def test_align_works(): 253 | nelem = 200 254 | Y = np.random.randn(nelem, 5) 255 | design = np.stack([np.ones(nelem), np.random.choice(2, size=nelem)], axis=1) 256 | model = pylemur.tl.LEMUR(Y, design=design, n_embedding=3) 257 | model.fit(verbose=False) 258 | model_2 = model.copy() 259 | model_3 = model.copy() 260 | 261 | grouping = np.random.choice(2, nelem) 262 | model_2.align_with_grouping(grouping) 263 | model_3.align_with_harmony() 264 | assert np.allclose(model_2.predict(), 
model.predict()) 265 | assert np.allclose(model_3.predict(), model.predict()) 266 | 267 | 268 | def _assert_lemur_model_equal(m1, m2, adata_id_equal=True): 269 | for k in m1.__dict__.keys(): 270 | if k == "adata" and adata_id_equal: 271 | assert id(m1.adata) == id(m2.adata) 272 | elif k == "adata" and not adata_id_equal: 273 | assert id(m1.adata) != id(m2.adata) 274 | elif isinstance(m1.__dict__[k], pd.DataFrame): 275 | pd.testing.assert_frame_equal(m1.__dict__[k], m2.__dict__[k]) 276 | elif isinstance(m1.__dict__[k], np.ndarray): 277 | assert np.array_equal(m1.__dict__[k], m2.__dict__[k]) 278 | else: 279 | assert m1.__dict__[k] == m2.__dict__[k] 280 | -------------------------------------------------------------------------------- /tests/test_lin_alg_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.linear_model import LinearRegression, Ridge 3 | 4 | import pylemur.tl.alignment as pylemur_al 5 | from pylemur.tl._lin_alg_wrappers import * 6 | 7 | 8 | def test_fit_pca(): 9 | # Make example data 10 | Y = np.random.randn(30, 400) 11 | pca = fit_pca(Y, 3, center=True) 12 | assert np.allclose(pca.offset, Y.mean(axis=0)) 13 | assert np.allclose(pca.coord_system @ pca.coord_system.T, np.eye(3)) 14 | assert np.allclose(pca.embedding, (Y - pca.offset) @ pca.coord_system.T) 15 | 16 | pca2 = fit_pca(Y, 3, center=False) 17 | assert np.allclose(pca2.offset, np.zeros(Y.shape[1])) 18 | assert np.allclose(pca2.coord_system @ pca2.coord_system.T, np.eye(3)) 19 | assert np.allclose(pca2.embedding, (Y - pca2.offset) @ pca2.coord_system.T) 20 | assert np.allclose(Y @ pca2.coord_system.T, pca2.embedding) 21 | 22 | 23 | def test_ridge_regression(): 24 | # Regular least squares 25 | Y = np.random.randn(400, 30) 26 | X = np.random.randn(400, 3) 27 | beta = ridge_regression(Y, X) 28 | assert beta.shape == (3, 30) 29 | assert np.allclose(beta, np.linalg.inv(X.T @ X) @ X.T @ Y) 30 | reg = 
LinearRegression(fit_intercept=False).fit(X, Y) 31 | assert np.allclose(beta, reg.coef_.T) 32 | 33 | # Check with weights 34 | weights = np.random.rand(400) 35 | beta = ridge_regression(Y, X, weights=weights) 36 | reg = LinearRegression(fit_intercept=False).fit(X, Y, sample_weight=weights) 37 | assert np.allclose(beta, reg.coef_.T) 38 | 39 | # Check with ridge penalty 40 | pen = 0.3 41 | beta = ridge_regression(Y, X, ridge_penalty=pen) 42 | gamma = np.sqrt(400) * pen**2 * np.eye(3) 43 | beta2 = np.linalg.inv(X.T @ X + gamma.T @ gamma) @ X.T @ Y 44 | assert np.allclose(beta, beta2) 45 | 46 | reg = Ridge(alpha=pen**4 * 400, fit_intercept=False).fit(X, Y) 47 | assert np.allclose(beta, reg.coef_.T) 48 | 49 | 50 | def test_forward_reverse_transformations(): 51 | coef = np.random.randn(5, 6, 2) 52 | vec = np.random.randn(2) 53 | forward = pylemur_al._forward_linear_transformation(coef, vec) 54 | reverse = pylemur_al._reverse_linear_transformation(coef, vec) 55 | assert np.allclose(reverse @ forward[:, 1:], np.diag(np.ones(5))) 56 | --------------------------------------------------------------------------------