├── .editorconfig ├── .gitattributes ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yml │ ├── config.yml │ └── feature_request.yml ├── labels.yml ├── pull_request_template.md ├── release-drafter.yml └── workflows │ ├── build.yml │ ├── labeler.yml │ ├── release.yml │ ├── release_drafter.yml │ └── test.yml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── LICENSE ├── README.md ├── biome.jsonc ├── codecov.yml ├── docs ├── Makefile ├── _ext │ ├── edit_on_github.py │ └── typed_returns.py ├── _static │ ├── SCVI_LICENSE │ ├── css │ │ ├── overwrite.css │ │ └── sphinx_gallery.css │ ├── docstring_previews │ │ ├── augur_dp_scatter.png │ │ ├── augur_important_features.png │ │ ├── augur_lollipop.png │ │ ├── augur_scatterplot.png │ │ ├── de_fold_change.png │ │ ├── de_multicomparison_fc.png │ │ ├── de_paired_expression.png │ │ ├── de_volcano.png │ │ ├── dialogue_pairplot.png │ │ ├── dialogue_violin.png │ │ ├── enrichment_dotplot.png │ │ ├── enrichment_gsea.png │ │ ├── milo_da_beeswarm.png │ │ ├── milo_nhood.png │ │ ├── milo_nhood_graph.png │ │ ├── mixscape_barplot.png │ │ ├── mixscape_heatmap.png │ │ ├── mixscape_lda.png │ │ ├── mixscape_perturbscore.png │ │ ├── mixscape_violin.png │ │ ├── pseudobulk_samples.png │ │ ├── sccoda_boxplots.png │ │ ├── sccoda_effects_barplot.png │ │ ├── sccoda_rel_abundance_dispersion_plot.png │ │ ├── sccoda_stacked_barplot.png │ │ ├── scgen_reg_mean.png │ │ ├── tasccoda_draw_effects.png │ │ ├── tasccoda_draw_tree.png │ │ └── tasccoda_effects_umap.png │ ├── icons │ │ ├── code-24px.svg │ │ ├── computer-24px.svg │ │ ├── library_books-24px.svg │ │ └── play_circle_outline-24px.svg │ ├── pertpy_logo.png │ ├── pertpy_logo.svg │ ├── placeholder.png │ └── tutorials │ │ ├── augur.png │ │ ├── cinemaot.png │ │ ├── dge.png │ │ ├── dialogue.png │ │ ├── distances.png │ │ ├── distances_tests.png │ │ ├── enrichment.png │ │ ├── guide_rna_assignment.png │ │ ├── mcfarland.png │ │ ├── metadata.png │ │ ├── milo.png │ │ ├── mixscape.png │ │ ├── norman.png │ │ ├── ontology.png │ │ ├── perturbation_space.png │ │ ├── placeholder.png │ │ ├── sccoda.png │ │ ├── sccoda_extended.png │ │ ├── scgen_perturbation_prediction.png │ │ ├── tasccoda.png │ │ └── zhang.png ├── _templates │ └── autosummary │ │ └── class.rst ├── about │ ├── background.md │ └── cite.md ├── api.md ├── api │ ├── datasets_index.md │ ├── metadata_index.md │ ├── preprocessing_index.md │ └── tools_index.md ├── changelog.md ├── conf.py ├── contributing.md ├── index.md ├── installation.md ├── make.bat ├── references.bib ├── references.md ├── tutorials.md ├── tutorials │ ├── metadata.md │ ├── preprocessing.md │ └── tools.md ├── usecases.md └── utils.py ├── pertpy ├── __init__.py ├── _doc.py ├── _types.py ├── data │ ├── __init__.py │ ├── _dataloader.py │ └── _datasets.py ├── metadata │ ├── __init__.py │ ├── _cell_line.py │ ├── _compound.py │ ├── _drug.py │ ├── _look_up.py │ ├── _metadata.py │ └── _moa.py ├── plot │ └── __init__.py ├── preprocessing │ ├── __init__.py │ ├── _guide_rna.py │ └── _guide_rna_mixture.py ├── py.typed └── tools │ ├── __init__.py │ ├── _augur.py │ ├── _cinemaot.py │ ├── _coda │ ├── __init__.py │ ├── _base_coda.py │ ├── _sccoda.py │ └── _tasccoda.py │ ├── _dialogue.py │ ├── _differential_gene_expression │ ├── __init__.py │ ├── _base.py │ ├── _checks.py │ ├── _dge_comparison.py │ ├── _edger.py │ ├── _pydeseq2.py │ ├── _simple_tests.py │ └── _statsmodels.py │ ├── _distances │ ├── __init__.py │ ├── _distance_tests.py │ └── _distances.py │ ├── _enrichment.py │ ├── _milo.py │ ├── _mixscape.py 
│ ├── _perturbation_space │ ├── __init__.py │ ├── _clustering.py │ ├── _comparison.py │ ├── _discriminator_classifiers.py │ ├── _metrics.py │ ├── _perturbation_space.py │ └── _simple.py │ ├── _scgen │ ├── __init__.py │ ├── _base_components.py │ ├── _scgen.py │ ├── _scgenvae.py │ └── _utils.py │ ├── decoupler_LICENSE │ └── transferlearning_MMD_LICENSE ├── pyproject.toml └── tests ├── conftest.py ├── metadata ├── test_cell_line.py ├── test_compound.py ├── test_drug.py └── test_moa.py ├── preprocessing └── test_grna_assignment.py └── tools ├── _coda ├── test_sccoda.py └── test_tasccoda.py ├── _differential_gene_expression ├── __init__.py ├── conftest.py ├── test_base.py ├── test_compare_groups.py ├── test_dge.py ├── test_edger.py ├── test_input_checks.py ├── test_pydeseq2.py ├── test_simple_tests.py └── test_statsmodels.py ├── _distances ├── test_distance_tests.py └── test_distances.py ├── _perturbation_space ├── test_comparison.py ├── test_discriminator_classifiers.py ├── test_simple_cluster_space.py └── test_simple_perturbation_space.py ├── test_augur.py ├── test_cinemaot.py ├── test_dialogue.py ├── test_enrichment.py ├── test_milo.py ├── test_mixscape.py └── test_scgen.py /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.bat] 14 | indent_style = tab 15 | end_of_line = crlf 16 | 17 | [LICENSE] 18 | insert_final_newline = true 19 | 20 | [Makefile] 21 | indent_style = tab 22 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto eol=lf 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug report 2 | description: pertpy doesn't do what it should? Please help us fix it! 3 | #title: ... 4 | labels: 5 | - bug 6 | - triage 7 | #assignees: [] 8 | body: 9 | - type: checkboxes 10 | id: terms 11 | attributes: 12 | label: Please make sure these conditions are met 13 | # description: ... 14 | options: 15 | - label: I have checked that this issue has not already been reported. 16 | required: true 17 | - label: I have confirmed this bug exists on the latest version of pertpy. 18 | required: true 19 | - label: (optional) I have confirmed this bug exists on the main branch. 20 | required: false 21 | - type: markdown 22 | attributes: 23 | value: | 24 | **Note**: Please read [this guide][] detailing how to provide the necessary information for us to reproduce your bug. 25 | 26 | [this guide]: https://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports 27 | - type: textarea 28 | id: Report 29 | attributes: 30 | label: Report 31 | description: | 32 | Describe the bug you encountered, and what you were trying to do. Please use [github markdown][] features for readability. 
33 | 34 | [github markdown]: https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax 35 | value: | 36 | Code: 37 | 38 | ```python 39 | 40 | ``` 41 | 42 | Traceback: 43 | 44 | ```pytb 45 | 46 | ``` 47 | validations: 48 | required: true 49 | - type: textarea 50 | id: versions 51 | attributes: 52 | label: Versions 53 | description: | 54 | Which versions of pertpy and other related software you used. 55 | 56 | Please install `session-info2`, run the following command in a notebook, 57 | click the “Copy as Markdown” button, 58 | then paste the results into the text box below. 59 | 60 | ```python 61 | In[1]: import pertpy, session_info2; session_info2.session_info(dependencies=True) 62 | ``` 63 | 64 | Alternatively, run this in a console: 65 | 66 | ```python 67 | >>> import pertpy, session_info2; print(session_info2.session_info(dependencies=True)._repr_mimebundle_()["text/markdown"]) 68 | ``` 69 | render: markdown 70 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Scverse Community Forum 4 | url: https://discourse.scverse.org/ 5 | about: If you have questions about “How to do X”, please ask them here. 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: Feature request 2 | description: Propose a new feature for pertpy 3 | labels: enhancement 4 | body: 5 | - type: textarea 6 | id: description 7 | attributes: 8 | label: Description of feature 9 | description: Please describe your suggestion for a new feature. It might help to describe a problem or use case, plus any alternatives that you have considered. 10 | validations: 11 | required: true 12 | -------------------------------------------------------------------------------- /.github/labels.yml: -------------------------------------------------------------------------------- 1 | --- 2 | # Label names are important, as Release Drafter uses them to decide 3 | # where to record changes in the changelog, or whether to skip them. 4 | # 5 | # The repository labels will be automatically configured using this file and 6 | # the GitHub Action https://github.com/marketplace/actions/github-labeler.
7 | - name: breaking 8 | description: Breaking Changes 9 | color: bfd4f2 10 | - name: bug 11 | description: Something isn't working 12 | color: d73a4a 13 | - name: build 14 | description: Build System and Dependencies 15 | color: bfdadc 16 | - name: ci 17 | description: Continuous Integration 18 | color: 4a97d6 19 | - name: dependencies 20 | description: Pull requests that update a dependency file 21 | color: 0366d6 22 | - name: documentation 23 | description: Improvements or additions to documentation 24 | color: 0075ca 25 | - name: duplicate 26 | description: This issue or pull request already exists 27 | color: cfd3d7 28 | - name: enhancement 29 | description: New feature or request 30 | color: a2eeef 31 | - name: github_actions 32 | description: Pull requests that update Github_actions code 33 | color: "000000" 34 | - name: good first issue 35 | description: Good for newcomers 36 | color: 7057ff 37 | - name: help wanted 38 | description: Extra attention is needed 39 | color: 008672 40 | - name: invalid 41 | description: This doesn't seem right 42 | color: e4e669 43 | - name: performance 44 | description: Performance 45 | color: "016175" 46 | - name: python 47 | description: Pull requests that update Python code 48 | color: 2b67c6 49 | - name: question 50 | description: Further information is requested 51 | color: d876e3 52 | - name: refactoring 53 | description: Refactoring 54 | color: ef67c4 55 | - name: removal 56 | description: Removals and Deprecations 57 | color: 9ae7ea 58 | - name: style 59 | description: Style 60 | color: c120e5 61 | - name: testing 62 | description: Testing 63 | color: b1fc6f 64 | - name: wontfix 65 | description: This will not be worked on 66 | color: ffffff 67 | - name: skip-changelog 68 | description: Changes that should be omitted from the release notes 69 | color: ededed 70 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | **PR Checklist** 4 | 5 | 6 | 7 | - [ ] Referenced issue is linked 8 | - [ ] If you've fixed a bug or added code that should be tested, add tests! 
9 | - [ ] Documentation in `docs` is updated 10 | 11 | **Description of changes** 12 | 13 | 14 | 15 | **Technical details** 16 | 17 | 18 | 19 | **Additional context** 20 | 21 | 22 | -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name-template: "1.0.0 🌈" 2 | tag-template: 1.0.0 3 | exclude-labels: 4 | - "skip-changelog" 5 | 6 | categories: 7 | - title: "🚀 Features" 8 | labels: 9 | - feature 10 | - enhancement 11 | - title: "🐛 Bug Fixes" 12 | labels: 13 | - fix 14 | - bugfix 15 | - bug 16 | - title: "🧰 Maintenance" 17 | label: chore 18 | - title: ":package: Dependencies" 19 | labels: 20 | - dependencies 21 | - build 22 | - dependabot 23 | - DEPENDABOT 24 | version-resolver: 25 | major: 26 | labels: 27 | - major 28 | minor: 29 | labels: 30 | - minor 31 | patch: 32 | labels: 33 | - patch 34 | default: patch 35 | autolabeler: 36 | - label: chore 37 | files: 38 | - "*.md" 39 | branch: 40 | - '/docs{0,1}\/.+/' 41 | - label: bug 42 | branch: 43 | - /fix\/.+/ 44 | title: 45 | - /fix/i 46 | - label: enhancement 47 | branch: 48 | - /feature\/.+/ 49 | body: 50 | - "/JIRA-[0-9]{1,4}/" 51 | template: | 52 | ## Changes 53 | 54 | $CHANGES 55 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Check Build 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }} 11 | cancel-in-progress: true 12 | 13 | defaults: 14 | run: 15 | # to fail on error in multiline statements (-e), in pipes (-o pipefail), and on unset variables (-u). 
16 | shell: bash -euo pipefail {0} 17 | 18 | jobs: 19 | package: 20 | runs-on: ubuntu-latest 21 | steps: 22 | - uses: actions/checkout@v4 23 | with: 24 | filter: blob:none 25 | fetch-depth: 0 26 | 27 | - name: Install uv 28 | uses: astral-sh/setup-uv@v6 29 | with: 30 | cache-dependency-glob: pyproject.toml 31 | 32 | - name: Build package 33 | run: uv build 34 | 35 | - name: Check package 36 | run: uvx twine check --strict dist/*.whl 37 | -------------------------------------------------------------------------------- /.github/workflows/labeler.yml: -------------------------------------------------------------------------------- 1 | name: Labeler 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | 9 | jobs: 10 | labeler: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Check out the repository 14 | uses: actions/checkout@v4 15 | 16 | - name: Run Labeler 17 | uses: crazy-max/ghaction-github-labeler@v4.1.0 18 | with: 19 | skip-delete: true 20 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | defaults: 8 | run: 9 | shell: bash -euo pipefail {0} 10 | 11 | jobs: 12 | release: 13 | name: Upload release to PyPI 14 | runs-on: ubuntu-latest 15 | environment: 16 | name: pypi 17 | url: https://pypi.org/p/pertpy 18 | permissions: 19 | id-token: write 20 | steps: 21 | - uses: actions/checkout@v4 22 | with: 23 | filter: blob:none 24 | fetch-depth: 0 25 | 26 | - name: Install uv 27 | uses: astral-sh/setup-uv@v5 28 | with: 29 | cache-dependency-glob: pyproject.toml 30 | 31 | - name: Build package 32 | run: uv build 33 | - name: Publish package distributions to PyPI 34 | uses: pypa/gh-action-pypi-publish@release/v1 35 | -------------------------------------------------------------------------------- /.github/workflows/release_drafter.yml: -------------------------------------------------------------------------------- 1 | name: Release Drafter 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | branches: 8 | - main 9 | types: 10 | - opened 11 | - reopened 12 | - synchronize 13 | jobs: 14 | update_release_draft: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: release-drafter/release-drafter@v5 18 | env: 19 | GITHUB_TOKEN: "${{ secrets.GITHUB_TOKEN }}" 20 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | test: 15 | runs-on: ${{ matrix.os }} 16 | defaults: 17 | run: 18 | # to fail on error in multiline statements (-e), in pipes (-o pipefail), and on unset variables (-u). 
19 | shell: bash -euo pipefail {0} 20 | 21 | strategy: 22 | fail-fast: false 23 | matrix: 24 | include: 25 | - os: ubuntu-latest 26 | python: "3.13" 27 | - os: ubuntu-latest 28 | python: "3.13" 29 | pip-flags: "--pre" 30 | 31 | env: 32 | OS: ${{ matrix.os }} 33 | PYTHON: ${{ matrix.python }} 34 | 35 | steps: 36 | - uses: actions/checkout@v4 37 | with: 38 | filter: blob:none 39 | fetch-depth: 0 40 | 41 | - name: Cache .pertpy_cache 42 | uses: actions/cache@v4 43 | with: 44 | path: cache 45 | key: ${{ runner.os }}-pertpy-cache-${{ hashFiles('pertpy/metadata/**') }} 46 | restore-keys: | 47 | ${{ runner.os }}-pertpy-cache 48 | 49 | - name: Set up Python ${{ matrix.python }} 50 | uses: actions/setup-python@v5 51 | with: 52 | python-version: ${{ matrix.python }} 53 | - name: Install R 54 | uses: r-lib/actions/setup-r@v2 55 | with: 56 | r-version: "4.4.3" 57 | 58 | - name: Cache R packages 59 | id: r-cache 60 | uses: actions/cache@v3 61 | with: 62 | path: ${{ env.R_LIBS_USER }} 63 | key: ${{ runner.os }}-r-${{ hashFiles('**/pertpy/tools/_milo.py') }} 64 | restore-keys: ${{ runner.os }}-r- 65 | 66 | - name: Install R dependencies 67 | if: steps.r-cache.outputs.cache-hit != 'true' 68 | run: | 69 | mkdir -p ${{ env.R_LIBS_USER }} 70 | Rscript --vanilla -e "install.packages(c('BiocManager', 'statmod'), repos='https://cran.r-project.org'); BiocManager::install('edgeR', lib='${{ env.R_LIBS_USER }}')" 71 | 72 | - name: Install uv 73 | uses: astral-sh/setup-uv@v6 74 | with: 75 | enable-cache: true 76 | cache-dependency-glob: pyproject.toml 77 | - name: Install dependencies 78 | run: | 79 | uv pip install --system rpy2 80 | uv pip install --system ${{ matrix.pip-flags }} ".[dev,test,tcoda,de]" 81 | 82 | - name: Test 83 | env: 84 | MPLBACKEND: agg 85 | PLATFORM: ${{ matrix.os }} 86 | DISPLAY: :42 87 | run: | 88 | pytest_args="-m pytest -v --color=yes" 89 | coverage run $pytest_args 90 | 91 | - name: Show coverage report 92 | run: coverage report -m 93 | 94 | - name: Upload coverage 95 | uses: codecov/codecov-action@v4 96 | with: 97 | token: ${{ secrets.CODECOV_TOKEN }} 98 | fail_ci_if_error: true 99 | verbose: true 100 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | .pertpy_cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | docs/api/data 76 | docs/api/metadata 77 | docs/api/tools 78 | docs/api/preprocessing 79 | !docs/api/api.md 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | .pytype/ 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | 140 | # Jetbrains IDE 141 | .idea/ 142 | 143 | # VSCode 144 | .vscode 145 | 146 | # Coala 147 | *.orig 148 | 149 | # Datasets 150 | *.h5ad 151 | *.h5md 152 | *.h5mu 153 | 154 | # Test cache 155 | cache 156 | 157 | # Apple stuff 158 | .DS_Store 159 | 160 | lightning_logs/* 161 | */lightning_logs/* 162 | 163 | node_modules 164 | 165 | # lamindb 166 | test.ipynb 167 | test-perturbation 168 | test-bug 169 | 170 | # uv 171 | uv.lock 172 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "docs/notebooks"] 2 | path = docs/tutorials/notebooks 3 | url = https://github.com/scverse/pertpy-tutorials/ 4 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | fail_fast: false 2 | default_language_version: 3 | python: python3 4 | default_stages: 5 | - pre-commit 6 | - pre-push 7 | minimum_pre_commit_version: 2.16.0 8 | repos: 9 | - repo: https://github.com/biomejs/pre-commit 10 | rev: v1.9.4 11 | hooks: 12 | - id: biome-format 13 | - repo: https://github.com/astral-sh/ruff-pre-commit 14 | rev: v0.11.9 15 | hooks: 16 | - id: ruff 17 | args: [--fix, --exit-non-zero-on-fix, --unsafe-fixes] 18 | - id: ruff-format 19 | - repo: https://github.com/pre-commit/pre-commit-hooks 20 | rev: v5.0.0 21 | hooks: 22 | - id: detect-private-key 23 | - id: check-ast 24 | - id: end-of-file-fixer 25 | - id: 
mixed-line-ending 26 | args: [--fix=lf] 27 | - id: trailing-whitespace 28 | - id: check-case-conflict 29 | - id: check-added-large-files 30 | - id: check-toml 31 | - id: check-yaml 32 | - id: check-merge-conflict 33 | - id: no-commit-to-branch 34 | args: ["--branch=main"] 35 | - repo: https://github.com/pre-commit/mirrors-mypy 36 | rev: v1.15.0 37 | hooks: 38 | - id: mypy 39 | args: [--no-strict-optional, --ignore-missing-imports] 40 | additional_dependencies: 41 | ["types-setuptools", "types-requests", "types-attrs"] 42 | - repo: local 43 | hooks: 44 | - id: forbid-to-commit 45 | name: Don't commit rej files 46 | entry: | 47 | Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates. 48 | Fix the merge conflicts manually and remove the .rej files. 49 | language: fail 50 | files: '.*\.rej$' 51 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | build: 3 | os: ubuntu-24.04 4 | tools: 5 | python: "3.13" 6 | jobs: 7 | create_environment: 8 | - asdf plugin add uv 9 | - asdf install uv latest 10 | - asdf global uv latest 11 | - uv venv 12 | - uv pip install .[doc,tcoda,de] 13 | build: 14 | html: 15 | - uv run sphinx-build -T -W -b html docs $READTHEDOCS_OUTPUT/html 16 | 17 | submodules: 18 | include: all 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021, Lukas Heumos 4 | Copyright (c) 2025, scverse® 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build](https://github.com/scverse/pertpy/actions/workflows/build.yml/badge.svg)](https://github.com/scverse/pertpy/actions/workflows/build.yml) 2 | [![codecov](https://codecov.io/gh/scverse/pertpy/graph/badge.svg?token=1dTpIPBShv)](https://codecov.io/gh/scverse/pertpy) 3 | [![License](https://img.shields.io/github/license/scverse/pertpy)](https://opensource.org/licenses/MIT) 4 | [![PyPI](https://img.shields.io/pypi/v/pertpy.svg)](https://pypi.org/project/pertpy/) 5 | [![Python Version](https://img.shields.io/pypi/pyversions/pertpy)](https://pypi.org/project/pertpy) 6 | [![Read the Docs](https://img.shields.io/readthedocs/pertpy/latest.svg?label=Read%20the%20Docs)](https://pertpy.readthedocs.io/) 7 | [![Test](https://github.com/scverse/pertpy/actions/workflows/test.yml/badge.svg)](https://github.com/scverse/pertpy/actions/workflows/test.yml) 8 | [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit) 9 | 10 | # pertpy - Perturbation Analysis in Python 11 | 12 | Pertpy is a scverse ecosystem framework for analyzing large-scale single-cell perturbation experiments. 13 | It provides tools for harmonizing perturbation datasets, automating metadata annotation, calculating perturbation distances, and efficiently analyzing how cells respond to various stimuli like genetic modifications, drug treatments, and environmental changes. 14 | 15 | ![fig1](https://github.com/user-attachments/assets/d2e32d69-b767-4be3-a938-77a9dce45d3f) 16 | 17 | ## Documentation 18 | 19 | Please read the [documentation](https://pertpy.readthedocs.io/en/latest) for installation, tutorials, use cases, and more. 20 | 21 | ## Installation 22 | 23 | We recommend installing and running pertpy on a recent version of Linux (e.g. Ubuntu 24.04 LTS). 24 | No particular hardware beyond a standard laptop is required.
25 | 26 | You can install _pertpy_ in less than a minute via [pip] from [PyPI]: 27 | 28 | ```console 29 | pip install pertpy 30 | ``` 31 | 32 | or [conda-forge]: 33 | 34 | ```console 35 | conda install -c conda-forge pertpy 36 | ``` 37 | 38 | ### Differential gene expression 39 | 40 | If you want to use the differential gene expression interface, please install pertpy by running: 41 | 42 | ```console 43 | pip install 'pertpy[de]' 44 | ``` 45 | 46 | ### tascCODA 47 | 48 | If you want to use tascCODA, please install pertpy as follows: 49 | 50 | ```console 51 | pip install 'pertpy[tcoda]' 52 | ``` 53 | 54 | ### milo 55 | 56 | Milo requires either the "de" extra for the "pydeseq2" solver: 57 | 58 | ```console 59 | pip install 'pertpy[de]' 60 | ``` 61 | 62 | or edgeR, statmod, and rpy2 for the "edger" solver: 63 | 64 | ```R 65 | BiocManager::install("edgeR") 66 | BiocManager::install("statmod") 67 | ``` 68 | 69 | ```console 70 | pip install rpy2 71 | ``` 72 | 73 | ## Citation 74 | 75 | ```bibtex 76 | @article {Heumos2024.08.04.606516, 77 | author = {Heumos, Lukas and Ji, Yuge and May, Lilly and Green, Tessa and Zhang, Xinyue and Wu, Xichen and Ostner, Johannes and Peidli, Stefan and Schumacher, Antonia and Hrovatin, Karin and Müller, Michaela and Chong, Faye and Sturm, Gregor and Tejada, Alejandro and Dann, Emma and Dong, Mingze and Bahrami, Mojtaba and Gold, Ilan and Rybakov, Sergei and Namsaraeva, Altana and Moinfar, Amir and Zheng, Zihe and Roellin, Eljas and Mekki, Isra and Sander, Chris and Lotfollahi, Mohammad and Schiller, Herbert B. and Theis, Fabian J.}, 78 | title = {Pertpy: an end-to-end framework for perturbation analysis}, 79 | elocation-id = {2024.08.04.606516}, 80 | year = {2024}, 81 | doi = {10.1101/2024.08.04.606516}, 82 | publisher = {Cold Spring Harbor Laboratory}, 83 | URL = {https://www.biorxiv.org/content/early/2024/08/07/2024.08.04.606516}, 84 | eprint = {https://www.biorxiv.org/content/early/2024/08/07/2024.08.04.606516.full.pdf}, 85 | journal = {bioRxiv} 86 | } 87 | ``` 88 | 89 | [pip]: https://pip.pypa.io/ 90 | [pypi]: https://pypi.org/ 91 | [api]: https://pertpy.readthedocs.io/en/latest/api.html 92 | [conda-forge]: https://anaconda.org/conda-forge/pertpy 93 | [//]: # "numfocus-fiscal-sponsor-attribution" 94 | 95 | pertpy is part of the scverse® project ([website](https://scverse.org), [governance](https://scverse.org/about/roles)) and is fiscally sponsored by [NumFOCUS](https://numfocus.org/). 96 | If you like scverse® and want to support our mission, please consider making a tax-deductible [donation](https://numfocus.org/donate-to-scverse) to help the project pay for developer time, professional services, travel, workshops, and a variety of other needs. 97 | 98 |
106 | -------------------------------------------------------------------------------- /biome.jsonc: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://biomejs.dev/schemas/1.9.4/schema.json", 3 | "formatter": { "useEditorconfig": true }, 4 | "overrides": [ 5 | { 6 | "include": ["./.vscode/*.json", "**/*.jsonc", "**/asv.conf.json"], 7 | "json": { 8 | "formatter": { 9 | "trailingCommas": "all", 10 | }, 11 | "parser": { 12 | "allowComments": true, 13 | "allowTrailingCommas": true, 14 | }, 15 | }, 16 | }, 17 | ], 18 | } 19 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | require_ci_to_pass: no 3 | 4 | coverage: 5 | status: 6 | project: 7 | default: 8 | # Require 1% coverage, i.e., succeed as long as coverage collection works 9 | target: 1 10 | patch: false 11 | changes: false 12 | 13 | comment: 14 | layout: diff, flags, files 15 | behavior: once 16 | require_base: no 17 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = pertpy 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_ext/edit_on_github.py: -------------------------------------------------------------------------------- 1 | """Based on gist.github.com/MantasVaitkunas/7c16de233812adcb7028.""" 2 | 3 | import os 4 | from typing import Any 5 | 6 | from sphinx.application import Sphinx 7 | 8 | __licence__ = "BSD (3 clause)" 9 | 10 | 11 | def get_github_repo(app: Sphinx, path: str) -> str: # noqa: D103 12 | if path.endswith(".ipynb"): 13 | return str(app.config.github_nb_repo) 14 | if "auto_examples" in path: 15 | return str(app.config.github_nb_repo) 16 | if "auto_tutorials" in path: 17 | return str(app.config.github_nb_repo) 18 | return str(app.config.github_repo) 19 | 20 | 21 | def _html_page_context( 22 | app: Sphinx, _pagename: str, templatename: str, context: dict[str, Any], doctree: Any | None 23 | ) -> None: 24 | # doctree is None - otherwise viewcode fails 25 | if templatename != "page.html" or doctree is None: 26 | return 27 | 28 | if not app.config.github_repo: 29 | return 30 | 31 | if not app.config.github_nb_repo: 32 | nb_repo = f"{app.config.github_repo}_notebooks" 33 | app.config.github_nb_repo = nb_repo 34 | 35 | path = os.path.relpath(doctree.get("source"), app.builder.srcdir) 36 | repo = get_github_repo(app, path) 37 | 38 | # For sphinx_rtd_theme. 
39 | context["display_github"] = True 40 | context["github_user"] = "scverse" 41 | context["github_version"] = "master" 42 | context["github_repo"] = repo 43 | context["conf_py_path"] = "/docs/source/" 44 | 45 | 46 | def setup(app: Sphinx) -> None: # noqa: D103 47 | app.add_config_value("github_nb_repo", "", True) 48 | app.add_config_value("github_repo", "", True) 49 | app.connect("html-page-context", _html_page_context) 50 | -------------------------------------------------------------------------------- /docs/_ext/typed_returns.py: -------------------------------------------------------------------------------- 1 | import re 2 | from collections.abc import Iterable, Iterator 3 | 4 | from sphinx.application import Sphinx 5 | from sphinx.ext.napoleon import NumpyDocstring 6 | 7 | 8 | def _process_return(lines: Iterable[str]) -> Iterator[str]: 9 | for line in lines: 10 | m = re.fullmatch(r"(?P\w+)\s+:\s+(?P[\w.]+)", line) 11 | if m: 12 | # Once this is in scanpydoc, we can use the fancy hover stuff 13 | yield f"**{m['param']}** : :class:`~{m['type']}`" 14 | else: 15 | yield line 16 | 17 | 18 | def _parse_returns_section(self: NumpyDocstring, section: str) -> list[str]: 19 | lines_raw = list(_process_return(self._dedent(self._consume_to_next_section()))) 20 | lines: list[str] = self._format_block(":returns: ", lines_raw) 21 | if lines and lines[-1]: 22 | lines.append("") 23 | return lines 24 | 25 | 26 | def setup(app: Sphinx) -> None: # noqa: D103 27 | NumpyDocstring._parse_returns_section = _parse_returns_section 28 | -------------------------------------------------------------------------------- /docs/_static/SCVI_LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020 Romain Lopez, Adam Gayoso, Galen Xing, Yosef Lab 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /docs/_static/css/overwrite.css: -------------------------------------------------------------------------------- 1 | /* for the sphinx design cards */ 2 | body { 3 | --sd-color-shadow: dimgrey; 4 | } 5 | 6 | dt:target, 7 | span.highlighted { 8 | background-color: #f0f0f0; 9 | } 10 | 11 | dl.citation > dt { 12 | float: left; 13 | margin-right: 15px; 14 | font-weight: bold; 15 | } 16 | 17 | /* Parameters normalize size and captialized, */ 18 | dl .field-list dt { 19 | font-size: var(--font-size--normal) !important; 20 | text-transform: none !important; 21 | } 22 | 23 | /* examples and headings in classes */ 24 | p.rubric { 25 | font-size: var(--font-size--normal); 26 | text-transform: none; 27 | font-weight: 500; 28 | } 29 | 30 | /* adapted from https://github.com/dask/dask-sphinx-theme/blob/main/dask_sphinx_theme/static/css/nbsphinx.css */ 31 | 32 | .nbinput .prompt, 33 | .nboutput .prompt { 34 | display: none; 35 | } 36 | .nboutput .stderr { 37 | display: none; 38 | } 39 | 40 | div.nblast.container { 41 | padding-bottom: 10px !important; 42 | padding-right: 0px; 43 | padding-left: 0px; 44 | } 45 | 46 | div.nbinput.container { 47 | padding-top: 10px !important; 48 | padding-right: 0px; 49 | padding-left: 0px; 50 | } 51 | 52 | div.nbinput.container div.input_area div[class*="highlight"] > pre { 53 | padding: 10px !important; 54 | margin: 0; 55 | } 56 | 57 | p.topic-title { 58 | margin-top: 0; 59 | } 60 | 61 | /* so that api methods are small in sidebar */ 62 | li.toctree-l3 { 63 | font-size: 81.25% !important; 64 | } 65 | li.toctree-l4 { 66 | font-size: 75% !important; 67 | } 68 | 69 | .bd-sidebar .caption-text { 70 | color: #e63946; 71 | font-weight: 600; 72 | text-transform: uppercase; 73 | } 74 | -------------------------------------------------------------------------------- /docs/_static/css/sphinx_gallery.css: -------------------------------------------------------------------------------- 1 | .sphx-glr-thumbcontainer { 2 | background: inherit !important; 3 | min-height: 250px !important; 4 | margin: 10px !important; 5 | } 6 | 7 | .sphx-glr-thumbcontainer .headerlink { 8 | display: none !important; 9 | } 10 | 11 | div.sphx-glr-thumbcontainer span { 12 | font-style: normal !important; 13 | } 14 | 15 | .sphx-glr-thumbcontainer a.internal { 16 | padding: 140px 10px 0 !important; 17 | } 18 | 19 | .sphx-glr-thumbcontainer .figure { 20 | width: 200px !important; 21 | } 22 | 23 | .sphx-glr-thumbcontainer .figure.align-center { 24 | text-align: center; 25 | margin-left: 0%; 26 | transform: translate(0%); 27 | } 28 | -------------------------------------------------------------------------------- /docs/_static/docstring_previews/augur_dp_scatter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/augur_dp_scatter.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/augur_important_features.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/augur_important_features.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/augur_lollipop.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/augur_lollipop.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/augur_scatterplot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/augur_scatterplot.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/de_fold_change.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/de_fold_change.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/de_multicomparison_fc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/de_multicomparison_fc.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/de_paired_expression.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/de_paired_expression.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/de_volcano.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/de_volcano.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/dialogue_pairplot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/dialogue_pairplot.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/dialogue_violin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/dialogue_violin.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/enrichment_dotplot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/enrichment_dotplot.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/enrichment_gsea.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/enrichment_gsea.png -------------------------------------------------------------------------------- 
/docs/_static/docstring_previews/milo_da_beeswarm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/milo_da_beeswarm.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/milo_nhood.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/milo_nhood.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/milo_nhood_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/milo_nhood_graph.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/mixscape_barplot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/mixscape_barplot.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/mixscape_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/mixscape_heatmap.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/mixscape_lda.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/mixscape_lda.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/mixscape_perturbscore.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/mixscape_perturbscore.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/mixscape_violin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/mixscape_violin.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/pseudobulk_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/pseudobulk_samples.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/sccoda_boxplots.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/sccoda_boxplots.png 
-------------------------------------------------------------------------------- /docs/_static/docstring_previews/sccoda_effects_barplot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/sccoda_effects_barplot.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/sccoda_rel_abundance_dispersion_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/sccoda_rel_abundance_dispersion_plot.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/sccoda_stacked_barplot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/sccoda_stacked_barplot.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/scgen_reg_mean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/scgen_reg_mean.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/tasccoda_draw_effects.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/tasccoda_draw_effects.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/tasccoda_draw_tree.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/tasccoda_draw_tree.png -------------------------------------------------------------------------------- /docs/_static/docstring_previews/tasccoda_effects_umap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/tasccoda_effects_umap.png -------------------------------------------------------------------------------- /docs/_static/icons/code-24px.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/_static/icons/computer-24px.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/_static/icons/library_books-24px.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/_static/icons/play_circle_outline-24px.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- 
/docs/_static/pertpy_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/pertpy_logo.png -------------------------------------------------------------------------------- /docs/_static/pertpy_logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /docs/_static/placeholder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/placeholder.png -------------------------------------------------------------------------------- /docs/_static/tutorials/augur.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/augur.png -------------------------------------------------------------------------------- /docs/_static/tutorials/cinemaot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/cinemaot.png -------------------------------------------------------------------------------- /docs/_static/tutorials/dge.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/dge.png -------------------------------------------------------------------------------- /docs/_static/tutorials/dialogue.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/dialogue.png -------------------------------------------------------------------------------- /docs/_static/tutorials/distances.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/distances.png -------------------------------------------------------------------------------- /docs/_static/tutorials/distances_tests.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/distances_tests.png -------------------------------------------------------------------------------- /docs/_static/tutorials/enrichment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/enrichment.png -------------------------------------------------------------------------------- /docs/_static/tutorials/guide_rna_assignment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/guide_rna_assignment.png -------------------------------------------------------------------------------- /docs/_static/tutorials/mcfarland.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/mcfarland.png -------------------------------------------------------------------------------- /docs/_static/tutorials/metadata.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/metadata.png -------------------------------------------------------------------------------- /docs/_static/tutorials/milo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/milo.png -------------------------------------------------------------------------------- /docs/_static/tutorials/mixscape.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/mixscape.png -------------------------------------------------------------------------------- /docs/_static/tutorials/norman.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/norman.png -------------------------------------------------------------------------------- /docs/_static/tutorials/ontology.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/ontology.png -------------------------------------------------------------------------------- /docs/_static/tutorials/perturbation_space.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/perturbation_space.png -------------------------------------------------------------------------------- /docs/_static/tutorials/placeholder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/placeholder.png -------------------------------------------------------------------------------- /docs/_static/tutorials/sccoda.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/sccoda.png -------------------------------------------------------------------------------- /docs/_static/tutorials/sccoda_extended.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/sccoda_extended.png -------------------------------------------------------------------------------- /docs/_static/tutorials/scgen_perturbation_prediction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/scgen_perturbation_prediction.png 
-------------------------------------------------------------------------------- /docs/_static/tutorials/tasccoda.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/tasccoda.png -------------------------------------------------------------------------------- /docs/_static/tutorials/zhang.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/zhang.png -------------------------------------------------------------------------------- /docs/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. add toctree option to make autodoc generate the pages 6 | 7 | .. autoclass:: {{ objname }} 8 | 9 | {% block attributes %} 10 | {% if attributes %} 11 | Attributes table 12 | ~~~~~~~~~~~~~~~~ 13 | 14 | .. autosummary:: 15 | {% for item in attributes %} 16 | ~{{ name }}.{{ item }} 17 | {%- endfor %} 18 | {% endif %} 19 | {% endblock %} 20 | 21 | {% block methods %} 22 | {% if methods %} 23 | Methods table 24 | ~~~~~~~~~~~~~ 25 | 26 | .. autosummary:: 27 | {% for item in methods %} 28 | {%- if item != '__init__' %} 29 | ~{{ name }}.{{ item }} 30 | {%- endif -%} 31 | {%- endfor %} 32 | {% endif %} 33 | {% endblock %} 34 | 35 | {% block attributes_documentation %} 36 | {% if attributes %} 37 | Attributes 38 | ~~~~~~~~~~ 39 | 40 | {% for item in attributes %} 41 | 42 | .. autoattribute:: {{ [objname, item] | join(".") }} 43 | {%- endfor %} 44 | 45 | {% endif %} 46 | {% endblock %} 47 | 48 | {% block methods_documentation %} 49 | {% if methods %} 50 | Methods 51 | ~~~~~~~ 52 | 53 | {% for item in methods %} 54 | {%- if item != '__init__' %} 55 | 56 | .. automethod:: {{ [objname, item] | join(".") }} 57 | {%- endif -%} 58 | {%- endfor %} 59 | 60 | {% endif %} 61 | {% endblock %} 62 | -------------------------------------------------------------------------------- /docs/about/background.md: -------------------------------------------------------------------------------- 1 | # About Pertpy 2 | 3 | Pertpy is an end-to-end framework for the analysis of large-scale single-cell perturbation experiments. 4 | It provides access to harmonized perturbation datasets and metadata databases along with numerous fast and user-friendly implementations of both established and novel methods such as automatic metadata annotation or perturbation distances to efficiently analyze perturbation data. 5 | As part of the scverse ecosystem, pertpy interoperates with existing single-cell analysis libraries and is designed to be easily extended. 6 | If you find pertpy useful for your research, please check out {doc}`cite`. 7 | 8 | ## Design principles 9 | 10 | Our framework is based on three key principles: `Modularity`, `Flexibility`, and `Scalability`. 11 | 12 | ### Modularity 13 | 14 | Pertpy includes modules for analysis of single and combinatorial perturbations covering diverse types of perturbation data including genetic knockouts, drug screens, and disease states. 15 | The framework is designed for flexibility, offering more than 100 composable and interoperable analysis functions organized in modules which further ease downstream interpretation and visualization. 
16 | These modules host fundamental building blocks that are shared across methods and can be chained into custom pipelines. 17 | 18 | A typical pertpy workflow consists of several steps: 19 | 20 | * Initial **data transformation** such as guide RNA assignment for CRISPR screens 21 | * **Quality control** to address confounding factors and technical variation 22 | * **Metadata annotation** against ontologies and enrichment from databases 23 | * **Perturbation space analysis** to learn biologically interpretable embeddings 24 | * **Downstream analysis** including differential expression, compositional analysis, and distance calculation 25 | 26 | This modular approach yields a powerful and flexible framework, as many analysis steps can be applied independently or chained together (see the code sketch at the end of this page). 27 | 28 | ### Flexibility 29 | 30 | Pertpy is purpose-built to organize, analyze, and visualize complex perturbation datasets. 31 | It is flexible and can be applied to datasets of different assays, data types, sizes, and perturbations, thereby unifying previous data-type- or assay-specific single-problem approaches. 32 | Designed to integrate external metadata with measured data, it enables unprecedented contextualization of results through swiftly built, experiment-specific pipelines, leading to more robust outcomes. 33 | 34 | The inputs to a typical analysis with pertpy are unimodal scRNA-seq or multimodal perturbation readouts stored in AnnData or MuData objects. 35 | While pertpy is primarily designed to explore perturbations such as genetic modifications, drug treatments, exposure to pathogens, and other environmental conditions, its utility extends to various other perturbation settings, including diverse disease states where experimental perturbations have not been applied. 36 | 37 | ### Scalability 38 | 39 | Pertpy addresses a wide array of use cases and growing datasets of different types through its sparse and memory-efficient implementations, which leverage the parallelization and GPU-acceleration libraries JAX and Numba, making them substantially faster than the original implementations. 40 | The framework can be applied to datasets ranging from thousands to millions of cells. 41 | 42 | For example, when analyzing CRISPR screens, pertpy's implementation of Mixscape is optimized using PyNNDescent for the nearest-neighbor search during the calculation of perturbation signatures. 43 | Other methods such as scCODA and tascCODA are accelerated by replacing the Hamiltonian Monte Carlo algorithm in TensorFlow with the No-U-Turn Sampler from NumPyro. 44 | CINEMA-OT is optimized with ott-jax to make the implementation portable across hardware, enabling GPU acceleration. 45 | 46 | ## Why is it called "Pertpy"? 47 | 48 | Pertpy is named for its core purpose: the analysis of **pert**urbations in **Py**thon. 49 | The framework unifies perturbation analysis approaches across different data types and experimental designs, providing a comprehensive solution for understanding cellular responses to various stimuli.
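The sketch below illustrates such a chained pipeline. It is a minimal, illustrative example: the dataset choice, the modality and `.obs` column names, and the exact arguments are assumptions rather than prescriptions, and every step can also be run on its own.

```python
import pertpy as pt
import scanpy as sc

# Data transformation: assign guide RNAs of a Perturb-seq screen by a count threshold.
mdata = pt.dt.papalexi_2021()  # example dataset; any comparable MuData object works
gdo = mdata.mod["gdo"]
gdo.layers["counts"] = gdo.X.copy()
ga = pt.pp.GuideAssignment()
ga.assign_by_threshold(gdo, 5, layer="counts", output_layer="assigned_guides")

# Downstream analysis: quantify how strongly each perturbation shifts cells
# away from the others with a perturbation distance (here: the energy distance).
rna = mdata.mod["rna"]
sc.pp.normalize_total(rna)
sc.pp.log1p(rna)
sc.pp.pca(rna)
distance = pt.tl.Distance(metric="edistance", obsm_key="X_pca")
pairwise = distance.pairwise(rna, groupby="perturbation")
```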
50 | -------------------------------------------------------------------------------- /docs/about/cite.md: -------------------------------------------------------------------------------- 1 | # Citing pertpy 2 | 3 | If you find pertpy useful for your research, please consider citing our work as follows: 4 | 5 | ```bibtex 6 | @article {Heumos2024.08.04.606516, 7 | author = {Heumos, Lukas and Ji, Yuge and May, Lilly and Green, Tessa and Zhang, Xinyue and Wu, Xichen and Ostner, Johannes and Peidli, Stefan and Schumacher, Antonia and Hrovatin, Karin and Müller, Michaela and Chong, Faye and Sturm, Gregor and Tejada, Alejandro and Dann, Emma and Dong, Mingze and Bahrami, Mojtaba and Gold, Ilan and Rybakov, Sergei and Namsaraeva, Altana and Moinfar, Amir and Zheng, Zihe and Roellin, Eljas and Mekki, Isra and Sander, Chris and Lotfollahi, Mohammad and Schiller, Herbert B. and Theis, Fabian J.}, 8 | title = {Pertpy: an end-to-end framework for perturbation analysis}, 9 | elocation-id = {2024.08.04.606516}, 10 | year = {2024}, 11 | doi = {10.1101/2024.08.04.606516}, 12 | publisher = {Cold Spring Harbor Laboratory}, 13 | URL = {https://www.biorxiv.org/content/early/2024/08/07/2024.08.04.606516}, 14 | eprint = {https://www.biorxiv.org/content/early/2024/08/07/2024.08.04.606516.full.pdf}, 15 | journal = {bioRxiv} 16 | } 17 | ``` 18 | 19 | If you are using any previously published tool, please also cite the original publication. 20 | All tool-specific references can be found here: {doc}`../references`. 21 | -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | # API 2 | 3 | Import the pertpy API as follows: 4 | 5 | ```python 6 | import pertpy as pt 7 | ``` 8 | 9 | You can then access the individual modules, for example: 10 | 11 | ```python 12 | pt.tl.cool_fancy_tool() 13 | ``` 14 | 15 | ```{toctree} 16 | :maxdepth: 1 17 | 18 | api/datasets_index 19 | api/preprocessing_index 20 | api/tools_index 21 | api/metadata_index 22 | ``` 23 | -------------------------------------------------------------------------------- /docs/api/datasets_index.md: -------------------------------------------------------------------------------- 1 | ```{eval-rst} 2 | .. currentmodule:: pertpy 3 | ``` 4 | 5 | # Datasets 6 | 7 | pertpy provides access to a collection of curated single-cell datasets spanning many types of perturbations. 8 | Many of the datasets originate from [scperturb](http://projects.sanderlab.org/scperturb/) {cite}`Peidli2024`. 9 | 10 | ```{eval-rst} 11 | ..
autosummary:: 12 | :toctree: data 13 | 14 | data.adamson_2016_pilot 15 | data.adamson_2016_upr_epistasis 16 | data.adamson_2016_upr_perturb_seq 17 | data.aissa_2021 18 | data.bhattacherjee 19 | data.burczynski_crohn 20 | data.chang_2021 21 | data.combosciplex 22 | data.cinemaot_example 23 | data.datlinger_2017 24 | data.datlinger_2021 25 | data.dialogue_example 26 | data.distance_example 27 | data.dixit_2016 28 | data.dixit_2016_raw 29 | data.dong_2023 30 | data.frangieh_2021 31 | data.frangieh_2021_protein 32 | data.frangieh_2021_raw 33 | data.frangieh_2021_rna 34 | data.gasperini_2019_atscale 35 | data.gasperini_2019_highmoi 36 | data.gasperini_2019_lowmoi 37 | data.gehring_2019 38 | data.haber_2017_regions 39 | data.hagai_2018 40 | data.kang_2018 41 | data.mcfarland_2020 42 | data.norman_2019 43 | data.norman_2019_raw 44 | data.papalexi_2021 45 | data.replogle_2022_k562_essential 46 | data.replogle_2022_k562_gwps 47 | data.replogle_2022_rpe1 48 | data.sc_sim_augur 49 | data.schiebinger_2019_16day 50 | data.schiebinger_2019_18day 51 | data.schraivogel_2020_tap_screen_chr8 52 | data.schraivogel_2020_tap_screen_chr11 53 | data.sciplex_gxe1 54 | data.sciplex3_raw 55 | data.shifrut_2018 56 | data.smillie_2019 57 | data.srivatsan_2020_sciplex2 58 | data.srivatsan_2020_sciplex3 59 | data.srivatsan_2020_sciplex4 60 | data.stephenson_2021_subsampled 61 | data.tasccoda_example 62 | data.tian_2019_day7neuron 63 | data.tian_2019_ipsc 64 | data.tian_2021_crispra 65 | data.tian_2021_crispri 66 | data.weinreb_2020 67 | data.xie_2017 68 | data.zhao_2021 69 | data.zhang_2021 70 | ``` 71 | -------------------------------------------------------------------------------- /docs/api/metadata_index.md: -------------------------------------------------------------------------------- 1 | ```{eval-rst} 2 | .. currentmodule:: pertpy 3 | ``` 4 | 5 | # Metadata 6 | 7 | The metadata module provides tooling to annotate perturbations by querying databases. 8 | Such metadata can aid with the development of biologically informed models and can be used for enrichment tests. 9 | 10 | ## Cell line 11 | 12 | This module allows for the retrieval of various types of cell line information, 13 | including cell line annotations, bulk RNA expression, and protein expression data. 14 | 15 | Available databases for cell line metadata: 16 | 17 | - [The Cancer Dependency Map Project at Broad](https://depmap.org/portal/) 18 | - [The Cancer Dependency Map Project at Sanger](https://depmap.sanger.ac.uk/) 19 | - [Genomics of Drug Sensitivity in Cancer (GDSC)](https://www.cancerrxgene.org/) 20 | 21 | ## Compound 22 | 23 | The Compound module enables the retrieval of various types of information related to compounds of interest, including the most common synonym, PubChem ID, and canonical SMILES. 24 | 25 | Available databases for compound metadata: 26 | 27 | - [PubChem](https://pubchem.ncbi.nlm.nih.gov/) 28 | 29 | ## Mechanism of Action 30 | 31 | This module retrieves mechanism of action metadata for perturbagens of interest, optionally taking their molecular targets into account. 32 | 33 | Available databases for mechanism of action metadata: 34 | 35 | - [CLUE](https://clue.io/) 36 | 37 | ## Drug 38 | 39 | This module allows for the retrieval of drug target information. 40 | 41 | Available databases for drug metadata: 42 | 43 | - [ChEMBL](https://www.ebi.ac.uk/chembl/) 44 | 45 | ```{eval-rst} 46 | ..
autosummary:: 47 | :toctree: metadata 48 | :recursive: 49 | 50 | metadata.CellLine 51 | metadata.Compound 52 | metadata.Moa 53 | metadata.Drug 54 | metadata.LookUp 55 | ``` 56 | -------------------------------------------------------------------------------- /docs/api/preprocessing_index.md: -------------------------------------------------------------------------------- 1 | ```{eval-rst} 2 | .. currentmodule:: pertpy 3 | ``` 4 | 5 | # Preprocessing 6 | 7 | ## Guide Assignment 8 | 9 | Guide assignment is essential for quality control in single-cell Perturb-seq data, ensuring accurate mapping of guide RNAs to cells for reliable interpretation of gene perturbation effects. 10 | pertpy provides a simple function to assign guides based on thresholds and a Gaussian mixture model {cite}`Replogle2022`. 11 | 12 | ```{eval-rst} 13 | .. autosummary:: 14 | :toctree: preprocessing 15 | :nosignatures: 16 | 17 | preprocessing.GuideAssignment 18 | ``` 19 | 20 | Example implementation: 21 | 22 | ```python 23 | import pertpy as pt 24 | import scanpy as sc 25 | 26 | mdata = pt.dt.papalexi_2021() 27 | gdo = mdata.mod["gdo"] 28 | gdo.layers["counts"] = gdo.X.copy() 29 | sc.pp.log1p(gdo) 30 | 31 | ga = pt.pp.GuideAssignment() 32 | ga.assign_by_threshold(gdo, 5, layer="counts", output_layer="assigned_guides") 33 | 34 | ga.plot_heatmap(gdo, layer="assigned_guides") 35 | ``` 36 | 37 | See [guide assignment tutorial](https://pertpy.readthedocs.io/en/latest/tutorials/notebooks/guide_rna_assignment.html). 38 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # mypy: ignore-errors 3 | 4 | import sys 5 | from datetime import datetime 6 | from importlib.metadata import metadata 7 | from pathlib import Path 8 | 9 | HERE = Path(__file__).parent 10 | sys.path[:0] = [str(HERE.parent), str(HERE / "extensions")] 11 | 12 | needs_sphinx = "4.3" 13 | 14 | info = metadata("pertpy") 15 | project_name = info["Name"] 16 | author = info["Author"] 17 | copyright = f"{datetime.now():%Y}, {author}." 
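# The project name, author, version, and repository URL above and below are all
# derived from the installed package metadata via importlib.metadata, so the
# rendered docs always match the installed pertpy release.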
18 | version = info["Version"] 19 | urls = dict(pu.split(", ") for pu in info.get_all("Project-URL")) 20 | repository_url = urls["Source"] 21 | release = info["Version"] 22 | github_repo = "pertpy" 23 | master_doc = "index" 24 | language = "en" 25 | 26 | extensions = [ 27 | "myst_nb", 28 | "sphinx.ext.autodoc", 29 | "sphinx.ext.intersphinx", 30 | "sphinx.ext.viewcode", 31 | "nbsphinx", 32 | "nbsphinx_link", 33 | "sphinx.ext.mathjax", 34 | "sphinx.ext.napoleon", 35 | "sphinx_autodoc_typehints", # needs to be after napoleon 36 | "sphinx.ext.autosummary", 37 | "sphinx_copybutton", 38 | "sphinx_gallery.load_style", 39 | "sphinx_remove_toctrees", 40 | "sphinx_design", 41 | "sphinx_issues", 42 | "sphinxcontrib.bibtex", 43 | "IPython.sphinxext.ipython_console_highlighting", 44 | ] 45 | 46 | ogp_site_url = "https://pertpy.readthedocs.io/en/latest/" 47 | ogp_image = "https://pertpy.readthedocs.io/en/latest/_static/pertpy_logo.png" 48 | 49 | # nbsphinx specific settings 50 | exclude_patterns = [ 51 | "_build", 52 | "Thumbs.db", 53 | ".DS_Store", 54 | "auto_*/**.ipynb", 55 | "auto_*/**.md5", 56 | "auto_*/**.py", 57 | "**.ipynb_checkpoints", 58 | ] 59 | nbsphinx_execute = "never" 60 | pygments_style = "sphinx" 61 | 62 | templates_path = ["_templates"] 63 | bibtex_bibfiles = ["references.bib"] 64 | nitpicky = True # Warn about broken links 65 | source_suffix = { 66 | ".rst": "restructuredtext", 67 | ".ipynb": "myst-nb", 68 | ".myst": "myst-nb", 69 | } 70 | 71 | suppress_warnings = ["toc.not_included"] 72 | 73 | autosummary_generate = True 74 | autosummary_imported_members = True 75 | autodoc_member_order = "groupwise" 76 | napoleon_google_docstring = True 77 | napoleon_include_init_with_doc = False 78 | napoleon_use_rtype = True 79 | napoleon_use_param = True 80 | myst_heading_anchors = 6 81 | napoleon_custom_sections = [("Params", "Parameters")] 82 | todo_include_todos = False 83 | annotate_defaults = True 84 | myst_enable_extensions = [ 85 | "amsmath", 86 | "colon_fence", 87 | "deflist", 88 | "dollarmath", 89 | "html_image", 90 | "html_admonition", 91 | ] 92 | myst_url_schemes = ("http", "https", "mailto") 93 | nb_execution_mode = "off" 94 | nb_merge_streams = True 95 | warn_as_error = True 96 | 97 | typehints_defaults = "comma" 98 | 99 | html_theme = "scanpydoc" 100 | html_title = "pertpy" 101 | html_logo = "_static/pertpy_logo.svg" 102 | 103 | html_theme_options = {} 104 | 105 | html_static_path = ["_static"] 106 | html_css_files = ["css/overwrite.css", "css/sphinx_gallery.css"] 107 | html_show_sphinx = False 108 | 109 | add_module_names = False 110 | autodoc_mock_imports = ["ete4"] 111 | intersphinx_mapping = { 112 | "anndata": ("https://anndata.readthedocs.io/en/stable/", None), 113 | "mudata": ("https://mudata.readthedocs.io/en/stable/", None), 114 | "scvi-tools": ("https://docs.scvi-tools.org/en/stable/", None), 115 | "ipython": ("https://ipython.readthedocs.io/en/stable/", None), 116 | "matplotlib": ("https://matplotlib.org/stable/", None), 117 | "numpy": ("https://numpy.org/doc/stable/", None), 118 | "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None), 119 | "python": ("https://docs.python.org/3", None), 120 | "scipy": ("https://docs.scipy.org/doc/scipy/", None), 121 | "torch": ("https://docs.pytorch.org/docs/main", None), 122 | "scanpy": ("https://scanpy.readthedocs.io/en/stable/", None), 123 | "pytorch_lightning": ("https://lightning.ai/docs/pytorch/stable/", None), 124 | "pyro": ("https://docs.pyro.ai/en/stable/", None), 125 | "pymde": ("https://pymde.org/", None), 126 | 
"flax": ("https://flax.readthedocs.io/en/latest/", None), 127 | "jax": ("https://docs.jax.dev/en/latest/", None), 128 | "ete": ("https://etetoolkit.org/docs/latest/", None), 129 | "arviz": ("https://python.arviz.org/en/stable/", None), 130 | "sklearn": ("https://scikit-learn.org/stable", None), 131 | "statsmodels": ("https://www.statsmodels.org/stable", None), 132 | } 133 | nitpick_ignore = [ 134 | ("py:class", "ete4.core.tree.Tree"), 135 | ("py:class", "ete4.treeview.TreeStyle"), 136 | ("py:class", "pertpy.tools._distances._distances.MeanVar"), 137 | ("py:class", "The requested data as a NumPy array."), 138 | ("py:class", "The full registry saved with the model"), 139 | ("py:class", "The requested data."), 140 | ("py:class", "Model with loaded state dictionaries."), 141 | ("py:class", "pertpy.tools.lazy_import..Placeholder"), 142 | ] 143 | 144 | sphinx_gallery_conf = {"nested_sections=": False} 145 | nbsphinx_thumbnails = { 146 | "tutorials/notebooks/guide_rna_assignment": "_static/tutorials/guide_rna_assignment.png", 147 | "tutorials/notebooks/mixscape": "_static/tutorials/mixscape.png", 148 | "tutorials/notebooks/augur": "_static/tutorials/augur.png", 149 | "tutorials/notebooks/sccoda": "_static/tutorials/sccoda.png", 150 | "tutorials/notebooks/sccoda_extended": "_static/tutorials/sccoda_extended.png", 151 | "tutorials/notebooks/tasccoda": "_static/tutorials/tasccoda.png", 152 | "tutorials/notebooks/milo": "_static/tutorials/milo.png", 153 | "tutorials/notebooks/dialogue": "_static/tutorials/dialogue.png", 154 | "tutorials/notebooks/enrichment": "_static/tutorials/enrichment.png", 155 | "tutorials/notebooks/distances": "_static/tutorials/distances.png", 156 | "tutorials/notebooks/distance_tests": "_static/tutorials/distances_tests.png", 157 | "tutorials/notebooks/cinemaot": "_static/tutorials/cinemaot.png", 158 | "tutorials/notebooks/scgen_perturbation_prediction": "_static/tutorials/scgen_perturbation_prediction.png", 159 | "tutorials/notebooks/perturbation_space": "_static/tutorials/perturbation_space.png", 160 | "tutorials/notebooks/differential_gene_expression": "_static/tutorials/dge.png", 161 | "tutorials/notebooks/metadata_annotation": "_static/tutorials/metadata.png", 162 | "tutorials/notebooks/ontology_mapping": "_static/tutorials/ontology.png", 163 | "tutorials/notebooks/norman_use_case": "_static/tutorials/norman.png", 164 | "tutorials/notebooks/mcfarland_use_case": "_static/tutorials/mcfarland.png", 165 | "tutorials/notebooks/zhang_use_case": "_static/tutorials/zhang.png", 166 | } 167 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # pertpy 2 | 3 | Pertpy is a scverse ecosystem framework for analyzing large-scale single-cell perturbation experiments. 4 | It provides tools for harmonizing perturbation datasets, automating metadata annotation, calculating perturbation distances, and efficiently analyzing how cells respond to various stimuli like genetic modifications, drug treatments, and environmental changes. 5 | 6 | ![overview](https://github.com/user-attachments/assets/d2e32d69-b767-4be3-a938-77a9dce45d3f) 7 | 8 | ```{eval-rst} 9 | .. card:: Installation :octicon:`plug;1em;` 10 | :link: installation 11 | :link-type: doc 12 | 13 | New to *pertpy*? Check out the installation guide. 14 | ``` 15 | 16 | ```{eval-rst} 17 | .. 
card:: API reference :octicon:`book;1em;` 18 | :link: api 19 | :link-type: doc 20 | 21 | The API reference contains a detailed description of the pertpy API. 22 | ``` 23 | 24 | ```{eval-rst} 25 | .. card:: Tutorials :octicon:`play;1em;` 26 | :link: tutorials 27 | :link-type: doc 28 | 29 | The tutorials walk you through real-world applications of pertpy. 30 | ``` 31 | 32 | ```{eval-rst} 33 | .. card:: Discussion :octicon:`megaphone;1em;` 34 | :link: https://discourse.scverse.org/ 35 | 36 | Need help? Reach out on our forum to get your questions answered! 37 | 38 | ``` 39 | 40 | ```{eval-rst} 41 | .. card:: GitHub :octicon:`mark-github;1em;` 42 | :link: https://github.com/scverse/pertpy 43 | 44 | Found a bug? Interested in improving pertpy? Check out our GitHub for the latest developments. 45 | 46 | ``` 47 | 48 | ```{toctree} 49 | :caption: 'General' 50 | :hidden: true 51 | :maxdepth: 2 52 | 53 | installation 54 | api 55 | contributing 56 | changelog 57 | references 58 | ``` 59 | 60 | ```{toctree} 61 | :caption: 'Gallery' 62 | :hidden: true 63 | :maxdepth: 3 64 | 65 | tutorials 66 | usecases 67 | ``` 68 | 69 | ```{toctree} 70 | :caption: 'About' 71 | :hidden: true 72 | :maxdepth: 2 73 | 74 | about/background 75 | about/cite 76 | GitHub <https://github.com/scverse/pertpy> 77 | Discourse <https://discourse.scverse.org/> 78 | ``` 79 | 80 | ## Citation 81 | 82 | ```bibtex 83 | @article {Heumos2024.08.04.606516, 84 | author = {Heumos, Lukas and Ji, Yuge and May, Lilly and Green, Tessa and Zhang, Xinyue and Wu, Xichen and Ostner, Johannes and Peidli, Stefan and Schumacher, Antonia and Hrovatin, Karin and Müller, Michaela and Chong, Faye and Sturm, Gregor and Tejada, Alejandro and Dann, Emma and Dong, Mingze and Bahrami, Mojtaba and Gold, Ilan and Rybakov, Sergei and Namsaraeva, Altana and Moinfar, Amir and Zheng, Zihe and Roellin, Eljas and Mekki, Isra and Sander, Chris and Lotfollahi, Mohammad and Schiller, Herbert B. and Theis, Fabian J.}, 85 | title = {Pertpy: an end-to-end framework for perturbation analysis}, 86 | elocation-id = {2024.08.04.606516}, 87 | year = {2024}, 88 | doi = {10.1101/2024.08.04.606516}, 89 | publisher = {Cold Spring Harbor Laboratory}, 90 | URL = {https://www.biorxiv.org/content/early/2024/08/07/2024.08.04.606516}, 91 | eprint = {https://www.biorxiv.org/content/early/2024/08/07/2024.08.04.606516.full.pdf}, 92 | journal = {bioRxiv} 93 | } 94 | ``` 95 | 96 | ## NumFOCUS 97 | 98 | [//]: # "numfocus-fiscal-sponsor-attribution" 99 | 100 | pertpy is part of the scverse® project ([website](https://scverse.org), [governance](https://scverse.org/about/roles)) and is fiscally sponsored by [NumFOCUS](https://numfocus.org/). 101 | If you like scverse® and want to support our mission, please consider making a tax-deductible [donation](https://numfocus.org/donate-to-scverse) to help the project pay for developer time, professional services, travel, workshops, and a variety of other needs. 102 |
111 | -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | ```{highlight} shell 2 | 3 | ``` 4 | 5 | # Installation 6 | 7 | ## Stable release 8 | 9 | ### PyPI 10 | 11 | To install pertpy, run this command in your terminal: 12 | 13 | ```console 14 | pip install pertpy 15 | ``` 16 | 17 | This is the preferred method to install pertpy, as it will always install the most recent stable release. 18 | If you don't have [pip] installed, this [Python installation guide] can guide you through the process. 19 | 20 | ### conda-forge 21 | 22 | Alternatively, you can install pertpy from [conda-forge]: 23 | 24 | ```console 25 | conda install -c conda-forge pertpy 26 | ``` 27 | 28 | ### Additional dependency groups 29 | 30 | #### Differential gene expression interface 31 | 32 | The DGE interface of pertpy requires additional dependencies that can be installed by running: 33 | 34 | ```console 35 | pip install pertpy[de] 36 | ``` 37 | 38 | Note that the edgeR functionality in pertpy additionally requires the edgeR R package and rpy2 to be installed: 39 | 40 | ```R 41 | BiocManager::install("edgeR") 42 | ``` 43 | 44 | ```console 45 | pip install rpy2 46 | ``` 47 | 48 | #### milo 49 | 50 | milo requires either the "de" extra for the "pydeseq2" solver: 51 | 52 | ```console 53 | pip install 'pertpy[de]' 54 | ``` 55 | 56 | or edgeR, statmod, and rpy2 for the "edger" solver: 57 | 58 | ```R 59 | BiocManager::install("edgeR") 60 | BiocManager::install("statmod") 61 | ``` 62 | 63 | ```console 64 | pip install rpy2 65 | ``` 66 | 67 | #### tascCODA 68 | 69 | TascCODA requires an additional set of dependencies (ete4, pyqt6, and toytree) that can be installed by running: 70 | 71 | ```console 72 | pip install pertpy[tcoda] 73 | ``` 74 | 75 | ## From sources 76 | 77 | The sources for pertpy can be downloaded from the [Github repo]. 78 | 79 | You can either clone the public repository: 80 | 81 | ```console 82 | $ git clone git://github.com/scverse/pertpy 83 | ``` 84 | 85 | Or download the [tarball]: 86 | 87 | ```console 88 | $ curl -OJL https://github.com/scverse/pertpy/tarball/master 89 | ``` 90 | 91 | [github repo]: https://github.com/scverse/pertpy 92 | [pip]: https://pip.pypa.io 93 | [conda-forge]: https://anaconda.org/conda-forge/pertpy 94 | [python installation guide]: http://docs.python-guide.org/en/latest/starting/installation/ 95 | [tarball]: https://github.com/scverse/pertpy/tarball/master 96 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=pertpy 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 20 | echo.then set the SPHINXBUILD environment variable to point to the full 21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 22 | echo.Sphinx directory to PATH. 23 | echo.
24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/references.md: -------------------------------------------------------------------------------- 1 | # References 2 | 3 | ```{bibliography} 4 | :cited: 5 | ``` 6 | -------------------------------------------------------------------------------- /docs/tutorials.md: -------------------------------------------------------------------------------- 1 | --- 2 | orphan: true 3 | --- 4 | 5 | # Tutorials 6 | 7 | The easiest way to get familiar with pertpy is to follow along with our tutorials. 8 | 9 | :::{note} 10 | For questions about the usage of pertpy, use the [scverse discourse](https://discourse.scverse.org/). 11 | ::: 12 | 13 | ```{toctree} 14 | :maxdepth: 1 15 | 16 | tutorials/preprocessing.md 17 | tutorials/tools.md 18 | tutorials/metadata.md 19 | ``` 20 | -------------------------------------------------------------------------------- /docs/tutorials/metadata.md: -------------------------------------------------------------------------------- 1 | # Metadata 2 | 3 | ```{eval-rst} 4 | .. nbgallery:: 5 | 6 | notebooks/metadata_annotation 7 | notebooks/ontology_mapping 8 | ``` 9 | -------------------------------------------------------------------------------- /docs/tutorials/preprocessing.md: -------------------------------------------------------------------------------- 1 | # Preprocessing 2 | 3 | ```{eval-rst} 4 | .. nbgallery:: 5 | 6 | notebooks/guide_rna_assignment 7 | ``` 8 | -------------------------------------------------------------------------------- /docs/tutorials/tools.md: -------------------------------------------------------------------------------- 1 | # Tools 2 | 3 | ## Differential gene expression 4 | 5 | ```{eval-rst} 6 | .. nbgallery:: 7 | 8 | notebooks/differential_gene_expression 9 | ``` 10 | 11 | ## Pooled CRISPR screens 12 | 13 | ```{eval-rst} 14 | .. nbgallery:: 15 | 16 | notebooks/mixscape 17 | ``` 18 | 19 | ## Compositional analysis 20 | 21 | ```{eval-rst} 22 | .. nbgallery:: 23 | 24 | notebooks/sccoda 25 | notebooks/sccoda_extended 26 | notebooks/tasccoda 27 | notebooks/milo 28 | ``` 29 | 30 | ## Multicellular and gene programs 31 | 32 | ```{eval-rst} 33 | .. nbgallery:: 34 | 35 | notebooks/dialogue 36 | notebooks/enrichment 37 | ``` 38 | 39 | ## Distances and permutation tests 40 | 41 | ```{eval-rst} 42 | .. nbgallery:: 43 | 44 | notebooks/distances 45 | notebooks/distance_tests 46 | ``` 47 | 48 | ## Response prediction 49 | 50 | ```{eval-rst} 51 | .. nbgallery:: 52 | 53 | notebooks/augur 54 | notebooks/cinemaot 55 | notebooks/scgen_perturbation_prediction 56 | ``` 57 | 58 | ## Perturbation space 59 | 60 | ```{eval-rst} 61 | .. nbgallery:: 62 | 63 | notebooks/perturbation_space 64 | ``` 65 | -------------------------------------------------------------------------------- /docs/usecases.md: -------------------------------------------------------------------------------- 1 | # Use cases 2 | 3 | Each use case showcases a variety of pertpy tools applied to a single dataset. 4 | They are designed to give you a sense of how to use pertpy in a real-world scenario.
5 | The use cases featured here are those we present in the [pertpy preprint](https://www.biorxiv.org/content/10.1101/2024.08.04.606516v1). 6 | 7 | ```{eval-rst} 8 | .. nbgallery:: 9 | 10 | tutorials/notebooks/norman_use_case 11 | tutorials/notebooks/mcfarland_use_case 12 | tutorials/notebooks/zhang_use_case 13 | ``` 14 | -------------------------------------------------------------------------------- /docs/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/utils.py -------------------------------------------------------------------------------- /pertpy/__init__.py: -------------------------------------------------------------------------------- 1 | """Top-level package for pertpy.""" 2 | 3 | __author__ = "Lukas Heumos" 4 | __email__ = "lukas.heumos@posteo.net" 5 | __version__ = "1.0.0" 6 | 7 | import warnings 8 | 9 | from anndata._core.aligned_df import ImplicitModificationWarning 10 | from matplotlib import MatplotlibDeprecationWarning 11 | from numba import NumbaDeprecationWarning 12 | 13 | warnings.filterwarnings("ignore", category=NumbaDeprecationWarning) 14 | warnings.filterwarnings("ignore", category=MatplotlibDeprecationWarning) 15 | warnings.filterwarnings("ignore", category=SyntaxWarning) 16 | warnings.filterwarnings("ignore", category=UserWarning, module="scvi._settings") 17 | warnings.filterwarnings("ignore", message="Environment variable.*redefined by R") 18 | warnings.filterwarnings("ignore", message="Transforming to str index.", category=ImplicitModificationWarning) 19 | 20 | import mudata 21 | 22 | mudata.set_options(pull_on_update=False) 23 | 24 | from . import data as dt 25 | from . import metadata as md 26 | from . import plot as pl 27 | from . import preprocessing as pp 28 | from . 
import tools as tl 29 | -------------------------------------------------------------------------------- /pertpy/_doc.py: -------------------------------------------------------------------------------- 1 | from textwrap import dedent 2 | 3 | 4 | def _doc_params(**kwds): # pragma: no cover 5 | r"""Docstrings should start with "\" in the first line for proper formatting.""" 6 | 7 | def dec(obj): 8 | obj.__orig_doc__ = obj.__doc__ 9 | obj.__doc__ = dedent(obj.__doc__.format_map(kwds)) 10 | return obj 11 | 12 | return dec 13 | 14 | 15 | doc_common_plot_args = """\ 16 | return_fig: if `True`, returns figure of the plot, that can be used for saving.\ 17 | """ 18 | -------------------------------------------------------------------------------- /pertpy/_types.py: -------------------------------------------------------------------------------- 1 | from scipy import sparse 2 | 3 | CSBase = sparse.csr_matrix | sparse.csc_matrix 4 | CSRBase = sparse.csr_matrix 5 | CSCBase = sparse.csc_matrix 6 | SpBase = sparse.spmatrix 7 | -------------------------------------------------------------------------------- /pertpy/data/__init__.py: -------------------------------------------------------------------------------- 1 | from pertpy.data._datasets import ( 2 | adamson_2016_pilot, 3 | adamson_2016_upr_epistasis, 4 | adamson_2016_upr_perturb_seq, 5 | aissa_2021, 6 | bhattacherjee, 7 | burczynski_crohn, 8 | chang_2021, 9 | cinemaot_example, 10 | combosciplex, 11 | datlinger_2017, 12 | datlinger_2021, 13 | dialogue_example, 14 | distance_example, 15 | dixit_2016, 16 | dixit_2016_raw, 17 | dong_2023, 18 | frangieh_2021, 19 | frangieh_2021_protein, 20 | frangieh_2021_raw, 21 | frangieh_2021_rna, 22 | gasperini_2019_atscale, 23 | gasperini_2019_highmoi, 24 | gasperini_2019_lowmoi, 25 | gehring_2019, 26 | haber_2017_regions, 27 | hagai_2018, 28 | kang_2018, 29 | mcfarland_2020, 30 | norman_2019, 31 | norman_2019_raw, 32 | papalexi_2021, 33 | replogle_2022_k562_essential, 34 | replogle_2022_k562_gwps, 35 | replogle_2022_rpe1, 36 | sc_sim_augur, 37 | schiebinger_2019_16day, 38 | schiebinger_2019_18day, 39 | schraivogel_2020_tap_screen_chr8, 40 | schraivogel_2020_tap_screen_chr11, 41 | sciplex3_raw, 42 | sciplex_gxe1, 43 | shifrut_2018, 44 | smillie_2019, 45 | srivatsan_2020_sciplex2, 46 | srivatsan_2020_sciplex3, 47 | srivatsan_2020_sciplex4, 48 | stephenson_2021_subsampled, 49 | tasccoda_example, 50 | tian_2019_day7neuron, 51 | tian_2019_ipsc, 52 | tian_2021_crispra, 53 | tian_2021_crispri, 54 | weinreb_2020, 55 | xie_2017, 56 | zhang_2021, 57 | zhao_2021, 58 | ) 59 | 60 | __all__ = [ 61 | "adamson_2016_pilot", 62 | "adamson_2016_upr_epistasis", 63 | "adamson_2016_upr_perturb_seq", 64 | "aissa_2021", 65 | "bhattacherjee", 66 | "burczynski_crohn", 67 | "chang_2021", 68 | "cinemaot_example", 69 | "combosciplex", 70 | "datlinger_2017", 71 | "datlinger_2021", 72 | "dialogue_example", 73 | "distance_example", 74 | "dixit_2016", 75 | "dixit_2016_raw", 76 | "dong_2023", 77 | "frangieh_2021", 78 | "frangieh_2021_protein", 79 | "frangieh_2021_raw", 80 | "frangieh_2021_rna", 81 | "gasperini_2019_atscale", 82 | "gasperini_2019_highmoi", 83 | "gasperini_2019_lowmoi", 84 | "gehring_2019", 85 | "haber_2017_regions", 86 | "hagai_2018", 87 | "kang_2018", 88 | "mcfarland_2020", 89 | "norman_2019", 90 | "norman_2019_raw", 91 | "papalexi_2021", 92 | "replogle_2022_k562_essential", 93 | "replogle_2022_k562_gwps", 94 | "replogle_2022_rpe1", 95 | "sc_sim_augur", 96 | "schiebinger_2019_16day", 97 | "schiebinger_2019_18day", 98 | 
"schraivogel_2020_tap_screen_chr8", 99 | "schraivogel_2020_tap_screen_chr11", 100 | "sciplex3_raw", 101 | "sciplex_gxe1", 102 | "shifrut_2018", 103 | "smillie_2019", 104 | "srivatsan_2020_sciplex2", 105 | "srivatsan_2020_sciplex3", 106 | "srivatsan_2020_sciplex4", 107 | "stephenson_2021_subsampled", 108 | "tasccoda_example", 109 | "tian_2019_day7neuron", 110 | "tian_2019_ipsc", 111 | "tian_2021_crispra", 112 | "tian_2021_crispri", 113 | "weinreb_2020", 114 | "xie_2017", 115 | "zhao_2021", 116 | "zhang_2021", 117 | ] 118 | -------------------------------------------------------------------------------- /pertpy/data/_dataloader.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import tempfile 3 | import time 4 | from pathlib import Path 5 | from random import choice 6 | from string import ascii_lowercase 7 | from zipfile import ZipFile 8 | 9 | import requests 10 | from filelock import FileLock 11 | from lamin_utils import logger 12 | from requests.exceptions import RequestException 13 | from rich.progress import Progress 14 | 15 | 16 | def _download( # pragma: no cover 17 | url: str, 18 | output_file_name: str | None = None, 19 | output_path: str | Path | None = None, 20 | block_size: int = 1024, 21 | overwrite: bool = False, 22 | is_zip: bool = False, 23 | timeout: int = 30, 24 | max_retries: int = 3, 25 | retry_delay: int = 5, 26 | ) -> Path: 27 | """Downloads a dataset irrespective of the format. 28 | 29 | Args: 30 | url: URL to download 31 | output_file_name: Name of the downloaded file 32 | output_path: Path to download/extract the files to. 33 | block_size: Block size for downloads in bytes. 34 | overwrite: Whether to overwrite existing files. 35 | is_zip: Whether the downloaded file needs to be unzipped. 36 | timeout: Request timeout in seconds. 37 | max_retries: Maximum number of retry attempts. 38 | retry_delay: Delay between retries in seconds. 39 | """ 40 | if output_file_name is None: 41 | letters = ascii_lowercase 42 | output_file_name = f"pertpy_tmp_{''.join(choice(letters) for _ in range(10))}" 43 | 44 | if output_path is None: 45 | output_path = tempfile.gettempdir() 46 | 47 | download_to_path = Path(output_path) / output_file_name 48 | 49 | Path(output_path).mkdir(parents=True, exist_ok=True) 50 | lock_path = Path(output_path) / f"{output_file_name}.lock" 51 | 52 | with FileLock(lock_path, timeout=300): 53 | if Path(download_to_path).exists() and not overwrite: 54 | logger.warning(f"File {download_to_path} already exists!") 55 | return download_to_path 56 | 57 | temp_file_name = Path(f"{download_to_path}.part") 58 | 59 | retry_count = 0 60 | while retry_count <= max_retries: 61 | try: 62 | head_response = requests.head(url, timeout=timeout) 63 | head_response.raise_for_status() 64 | content_length = int(head_response.headers.get("content-length", 0)) 65 | 66 | free_space = shutil.disk_usage(output_path).free 67 | if content_length > free_space: 68 | raise OSError( 69 | f"Insufficient disk space. Need {content_length} bytes, but only {free_space} available." 
70 | ) 71 | 72 | response = requests.get(url, stream=True, timeout=timeout) 73 | response.raise_for_status() 74 | total = int(response.headers.get("content-length", 0)) 75 | 76 | with Progress(refresh_per_second=5) as progress: 77 | task = progress.add_task("[red]Downloading...", total=total) 78 | with Path(temp_file_name).open("wb") as file: 79 | for data in response.iter_content(block_size): 80 | file.write(data) 81 | progress.update(task, advance=len(data)) 82 | progress.update(task, completed=total, refresh=True) 83 | 84 | Path(temp_file_name).replace(download_to_path) 85 | 86 | if is_zip: 87 | with ZipFile(download_to_path, "r") as zip_obj: 88 | zip_obj.extractall(path=output_path) 89 | return Path(output_path) 90 | 91 | return download_to_path 92 | except (OSError, RequestException) as e: 93 | retry_count += 1 94 | if retry_count <= max_retries: 95 | logger.warning( 96 | f"Download attempt {retry_count}/{max_retries} failed: {str(e)}. Retrying in {retry_delay} seconds..." 97 | ) 98 | time.sleep(retry_delay) 99 | else: 100 | logger.error(f"Download failed after {max_retries} attempts: {str(e)}") 101 | if Path(temp_file_name).exists(): 102 | Path(temp_file_name).unlink(missing_ok=True) 103 | raise 104 | 105 | except Exception as e: 106 | logger.error(f"Download failed: {str(e)}") 107 | if Path(temp_file_name).exists(): 108 | Path(temp_file_name).unlink(missing_ok=True) 109 | raise 110 | finally: 111 | if Path(temp_file_name).exists(): 112 | Path(temp_file_name).unlink(missing_ok=True) 113 | 114 | return Path(download_to_path) 115 | -------------------------------------------------------------------------------- /pertpy/metadata/__init__.py: -------------------------------------------------------------------------------- 1 | from pertpy.metadata._cell_line import CellLine 2 | from pertpy.metadata._compound import Compound 3 | from pertpy.metadata._drug import Drug 4 | from pertpy.metadata._look_up import LookUp 5 | from pertpy.metadata._moa import Moa 6 | 7 | __all__ = ["CellLine", "Compound", "Drug", "Moa", "LookUp"] 8 | -------------------------------------------------------------------------------- /pertpy/metadata/_compound.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Literal 4 | 5 | import pandas as pd 6 | 7 | from ._look_up import LookUp 8 | from ._metadata import MetaData 9 | 10 | if TYPE_CHECKING: 11 | from anndata import AnnData 12 | 13 | 14 | class Compound(MetaData): 15 | """Utilities to fetch metadata for compounds.""" 16 | 17 | def __init__(self): 18 | super().__init__() 19 | 20 | def annotate_compounds( 21 | self, 22 | adata: AnnData, 23 | query_id: str = "perturbation", 24 | query_id_type: Literal["name", "cid"] = "name", 25 | verbosity: int | str = 5, 26 | copy: bool = False, 27 | ) -> AnnData: 28 | """Fetch compound annotation from pubchempy. 29 | 30 | Args: 31 | adata: The data object to annotate. 32 | query_id: The column of `.obs` with compound identifiers. 33 | query_id_type: The type of compound identifiers, 'name' or 'cid'. 34 | verbosity: The number of unmatched identifiers to print, can be either non-negative values or "all". 35 | copy: Determines whether a copy of the `adata` is returned. 36 | 37 | Returns: 38 | Returns an AnnData object with compound annotation.
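        Examples:
            A minimal sketch (the dataset and `.obs` column used here are illustrative assumptions):

            >>> import pertpy as pt
            >>> adata = pt.dt.mcfarland_2020()
            >>> compound_meta = pt.md.Compound()
            >>> adata = compound_meta.annotate_compounds(adata, query_id="perturbation")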
39 | """ 40 | if copy: 41 | adata = adata.copy() 42 | 43 | if query_id not in adata.obs.columns: 44 | raise ValueError(f"The requested query_id {query_id} is not in `adata.obs`.\n Please check again.") 45 | 46 | import pubchempy as pcp 47 | 48 | query_dict = {} 49 | not_matched_identifiers = [] 50 | for compound in adata.obs[query_id].dropna().astype(str).unique(): 51 | if query_id_type == "name": 52 | cids = pcp.get_compounds(compound, "name") 53 | if len(cids) == 0: # search did not work 54 | not_matched_identifiers.append(compound) 55 | if len(cids) >= 1: 56 | # If the name matches the first synonym offered by PubChem (outside of capitalization), 57 | # it is not changed (outside of capitalization). Otherwise, it is replaced with the first synonym. 58 | query_dict[compound] = [ 59 | cids[0].synonyms[0], 60 | cids[0].cid, 61 | cids[0].canonical_smiles, 62 | ] 63 | else: 64 | try: 65 | cid = pcp.Compound.from_cid(compound) 66 | query_dict[compound] = [ 67 | cid.synonyms[0], 68 | compound, 69 | cid.canonical_smiles, 70 | ] 71 | except pcp.BadRequestError: 72 | # pubchempy throws badrequest if a cid is not found 73 | not_matched_identifiers.append(compound) 74 | 75 | identifier_num_all = len(adata.obs[query_id].unique()) 76 | self._warn_unmatch( 77 | total_identifiers=identifier_num_all, 78 | unmatched_identifiers=not_matched_identifiers, 79 | query_id=query_id, 80 | reference_id=query_id_type, 81 | metadata_type="compound", 82 | verbosity=verbosity, 83 | ) 84 | 85 | query_df = pd.DataFrame.from_dict(query_dict, orient="index", columns=["pubchem_name", "pubchem_ID", "smiles"]) 86 | # Merge and remove duplicate columns 87 | # Column is converted to float after merging due to unmatches 88 | # Convert back to integers afterwards 89 | if query_id_type == "cid": 90 | query_df.pubchem_ID = query_df.pubchem_ID.astype("Int64") 91 | adata.obs = ( 92 | adata.obs.merge( 93 | query_df, 94 | left_on=query_id, 95 | right_on="pubchem_ID", 96 | how="left", 97 | suffixes=("", "_fromMeta"), 98 | ) 99 | .filter(regex="^(?!.*_fromMeta)") 100 | .set_index(adata.obs.index) 101 | ) 102 | else: 103 | adata.obs = ( 104 | adata.obs.merge( 105 | query_df, 106 | left_on=query_id, 107 | right_index=True, 108 | how="left", 109 | suffixes=("", "_fromMeta"), 110 | ) 111 | .filter(regex="^(?!.*_fromMeta)") 112 | .set_index(adata.obs.index) 113 | ) 114 | adata.obs.pubchem_ID = adata.obs.pubchem_ID.astype("Int64") 115 | 116 | return adata 117 | 118 | def lookup(self) -> LookUp: 119 | """Generate LookUp object for CompoundMetaData. 120 | 121 | The LookUp object provides an overview of the metadata to annotate. 122 | Each annotate_{metadata} function has a corresponding lookup function in the LookUp object, 123 | where users can search the reference_id in the metadata and compare with the query_id in their own data. 124 | 125 | Returns: 126 | Returns a LookUp object specific for compound annotation. 
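        Examples:
            >>> import pertpy as pt
            >>> compound_lookup = pt.md.Compound().lookup()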
127 | """ 128 | return LookUp(type="compound") 129 | -------------------------------------------------------------------------------- /pertpy/metadata/_metadata.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Literal 4 | 5 | from lamin_utils import logger 6 | 7 | if TYPE_CHECKING: 8 | from collections.abc import Sequence 9 | 10 | 11 | class MetaData: 12 | """Superclass for pertpy's MetaData components.""" 13 | 14 | def _warn_unmatch( 15 | self, 16 | total_identifiers: int, 17 | unmatched_identifiers: Sequence[str], 18 | query_id: str, 19 | reference_id: str, 20 | metadata_type: Literal[ 21 | "cell line", 22 | "protein expression", 23 | "bulk RNA", 24 | "drug response", 25 | "moa", 26 | "compound", 27 | ] = "cell line", 28 | verbosity: int | str = 5, 29 | ) -> None: 30 | """Helper function to print out the unmatched identifiers. 31 | 32 | Args: 33 | total_identifiers: The total number of identifiers in the `adata` object. 34 | unmatched_identifiers: Unmatched identifiers in the `adata` object. 35 | query_id: The column of `.obs` with cell line information. 36 | reference_id: The type of cell line identifier in the metadata. 37 | metadata_type: The type of metadata where some identifiers are not matched during annotation such as 38 | cell line, protein expression, bulk RNA expression, drug response, moa or compound. 39 | verbosity: The number of unmatched identifiers to print, can be either non-negative values or 'all'. 40 | """ 41 | if isinstance(verbosity, str): 42 | if verbosity != "all": 43 | raise ValueError("Only a non-negative value or 'all' is accepted.") 44 | else: 45 | verbosity = len(unmatched_identifiers) 46 | 47 | if len(unmatched_identifiers) == total_identifiers: 48 | hint = "" 49 | if metadata_type in ["protein expression", "bulk RNA", "drug response"]: 50 | hint = "Additionally, call the `CellLineMetaData.annotate()` function to acquire more possible query IDs that can be used for cell line annotation purposes." 51 | raise ValueError( 52 | f"No matches between `{query_id}` in adata.obs and `{reference_id}` in {metadata_type} data. " 53 | f"Use `lookup()` to check compatible identifier types. {hint}" 54 | ) 55 | if len(unmatched_identifiers) == 0: 56 | return 57 | if isinstance(verbosity, int) and verbosity >= 0: 58 | verbosity = min(verbosity, len(unmatched_identifiers)) 59 | if verbosity > 0: 60 | logger.info( 61 | f"{total_identifiers} identifiers in `adata.obs`, {len(unmatched_identifiers)} not found in {metadata_type} data. " 62 | f"NA values present. 
Unmatched: {unmatched_identifiers[:verbosity]}" 63 | ) 64 | else: 65 | raise ValueError("Only 'all' or a non-negative value is accepted.") 66 | -------------------------------------------------------------------------------- /pertpy/metadata/_moa.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from typing import TYPE_CHECKING 5 | 6 | import numpy as np 7 | import pandas as pd 8 | from scanpy import settings 9 | 10 | from pertpy.data._dataloader import _download 11 | 12 | from ._look_up import LookUp 13 | from ._metadata import MetaData 14 | 15 | if TYPE_CHECKING: 16 | from anndata import AnnData 17 | 18 | 19 | class Moa(MetaData): 20 | """Utilities to fetch metadata for mechanism of action studies.""" 21 | 22 | def __init__(self): 23 | self.clue = None 24 | 25 | def _download_clue(self) -> None: 26 | clue_path = Path(settings.cachedir) / "repurposing_drugs_20200324.txt" 27 | if not Path(clue_path).exists(): 28 | _download( 29 | url="https://s3.amazonaws.com/data.clue.io/repurposing/downloads/repurposing_drugs_20200324.txt", 30 | output_file_name="repurposing_drugs_20200324.txt", 31 | output_path=settings.cachedir, 32 | block_size=4096, 33 | is_zip=False, 34 | ) 35 | self.clue = pd.read_csv(clue_path, sep="\t", skiprows=9)  # the CLUE repurposing file is tab-separated 36 | self.clue = self.clue[["pert_iname", "moa", "target"]] 37 | 38 | def annotate( 39 | self, 40 | adata: AnnData, 41 | query_id: str = "perturbation", 42 | target: str | None = None, 43 | verbosity: int | str = 5, 44 | copy: bool = False, 45 | ) -> AnnData: 46 | """Annotate cells affected by perturbations by mechanism of action. 47 | 48 | For each cell, we fetch the mechanism of action and molecular targets of the compounds sourced from clue.io. 49 | 50 | Args: 51 | adata: The data object to annotate. 52 | query_id: The column of `.obs` with the name of a perturbagen. 53 | target: The column of `.obs` with target information. If set to None, all MoAs are retrieved without comparing molecular targets. 54 | verbosity: The number of unmatched identifiers to print, can be either non-negative values or 'all'. 55 | copy: Determines whether a copy of the `adata` is returned. 56 | 57 | Returns: 58 | Returns an AnnData object with MoA annotation.
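        Examples:
            A minimal sketch (the dataset and `.obs` column used here are illustrative assumptions):

            >>> import pertpy as pt
            >>> adata = pt.dt.mcfarland_2020()
            >>> moa_meta = pt.md.Moa()
            >>> adata = moa_meta.annotate(adata, query_id="perturbation")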
59 | """ 60 | if copy: 61 | adata = adata.copy() 62 | 63 | if query_id not in adata.obs.columns: 64 | raise ValueError(f"The requested query_id {query_id} is not in `adata.obs`.\nPlease check again.") 65 | 66 | if self.clue is None: 67 | self._download_clue() 68 | 69 | identifier_num_all = len(adata.obs[query_id].unique()) 70 | not_matched_identifiers = list(set(adata.obs[query_id].str.lower()) - set(self.clue["pert_iname"].str.lower())) 71 | self._warn_unmatch( 72 | total_identifiers=identifier_num_all, 73 | unmatched_identifiers=not_matched_identifiers, 74 | query_id=query_id, 75 | reference_id="pert_iname", 76 | metadata_type="moa", 77 | verbosity=verbosity, 78 | ) 79 | 80 | adata.obs = ( 81 | adata.obs.merge( 82 | self.clue, 83 | left_on=adata.obs[query_id].str.lower(), 84 | right_on=self.clue["pert_iname"].str.lower(), 85 | how="left", 86 | suffixes=("", "_fromMeta"), 87 | ) 88 | .set_index(adata.obs.index) 89 | .drop("key_0", axis=1) 90 | ) 91 | 92 | # If target column is given, check whether it is one of the targets listed in the metadata 93 | # If inconsistent, treat this perturbagen as unmatched and overwrite the annotated metadata with NaN 94 | if target is not None: 95 | target_meta = "target" if target != "target" else "target_fromMeta" 96 | adata.obs[target_meta] = adata.obs[target_meta].mask( 97 | ~adata.obs.apply(lambda row: str(row[target]) in str(row[target_meta]), axis=1) 98 | ) 99 | pertname_meta = "pert_iname" if query_id != "pert_iname" else "pert_iname_fromMeta" 100 | adata.obs.loc[adata.obs[target_meta].isna(), [pertname_meta, "moa"]] = np.nan 101 | 102 | # If query_id and reference_id have different names, there will be a column for each of them after merging 103 | # which is redundant as they refer to the same information. 104 | if query_id != "pert_iname": 105 | del adata.obs["pert_iname"] 106 | 107 | return adata 108 | 109 | def lookup(self) -> LookUp: 110 | """Generate LookUp object for Moa metadata. 111 | 112 | The LookUp object provides an overview of the metadata to annotate. 113 | 114 | Returns: 115 | Returns a LookUp object specific for MoA annotation. 
116 | """ 117 | if self.clue is None: 118 | self._download_clue() 119 | 120 | return LookUp( 121 | type="moa", 122 | transfer_metadata=[self.clue], 123 | ) 124 | -------------------------------------------------------------------------------- /pertpy/plot/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/pertpy/plot/__init__.py -------------------------------------------------------------------------------- /pertpy/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from ._guide_rna import GuideAssignment 2 | 3 | __all__ = ["GuideAssignment"] 4 | -------------------------------------------------------------------------------- /pertpy/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/pertpy/py.typed -------------------------------------------------------------------------------- /pertpy/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | 3 | 4 | def lazy_import(module_path: str, class_name: str, extras: list[str]): 5 | try: 6 | for extra in extras: 7 | import_module(extra) 8 | module = import_module(module_path) 9 | return getattr(module, class_name) 10 | except ImportError: 11 | 12 | class Placeholder: 13 | def __init__(self, *args, **kwargs): 14 | raise ImportError( 15 | f"Extra dependencies required: {', '.join(extras)}. " 16 | f"Please install with: pip install {' '.join(extras)}" 17 | ) 18 | 19 | return Placeholder 20 | 21 | 22 | from pertpy.tools._augur import Augur 23 | from pertpy.tools._cinemaot import Cinemaot 24 | from pertpy.tools._coda._sccoda import Sccoda 25 | from pertpy.tools._dialogue import Dialogue 26 | from pertpy.tools._distances._distance_tests import DistanceTest 27 | from pertpy.tools._distances._distances import Distance 28 | from pertpy.tools._enrichment import Enrichment 29 | from pertpy.tools._milo import Milo 30 | from pertpy.tools._mixscape import Mixscape 31 | from pertpy.tools._perturbation_space._clustering import ClusteringSpace 32 | from pertpy.tools._perturbation_space._comparison import PerturbationComparison 33 | from pertpy.tools._perturbation_space._discriminator_classifiers import ( 34 | LRClassifierSpace, 35 | MLPClassifierSpace, 36 | ) 37 | from pertpy.tools._perturbation_space._simple import ( 38 | CentroidSpace, 39 | DBSCANSpace, 40 | KMeansSpace, 41 | PseudobulkSpace, 42 | ) 43 | from pertpy.tools._scgen import Scgen 44 | 45 | CODA_EXTRAS = ["toytree", "ete4"] # also "pyqt6" but it cannot be imported 46 | Tasccoda = lazy_import("pertpy.tools._coda._tasccoda", "Tasccoda", CODA_EXTRAS) 47 | 48 | DE_EXTRAS = ["formulaic", "pydeseq2"] 49 | EdgeR = lazy_import("pertpy.tools._differential_gene_expression", "EdgeR", DE_EXTRAS) # edgeR will be imported via rpy2 50 | PyDESeq2 = lazy_import("pertpy.tools._differential_gene_expression", "PyDESeq2", DE_EXTRAS) 51 | Statsmodels = lazy_import("pertpy.tools._differential_gene_expression", "Statsmodels", DE_EXTRAS + ["statsmodels"]) 52 | TTest = lazy_import("pertpy.tools._differential_gene_expression", "TTest", DE_EXTRAS) 53 | WilcoxonTest = lazy_import("pertpy.tools._differential_gene_expression", "WilcoxonTest", DE_EXTRAS) 54 | 55 | __all__ = [ 56 | "Augur", 57 | "Cinemaot", 58 | "Sccoda", 59 | 
"Tasccoda", 60 | "Dialogue", 61 | "EdgeR", 62 | "PyDESeq2", 63 | "WilcoxonTest", 64 | "TTest", 65 | "Statsmodels", 66 | "DistanceTest", 67 | "Distance", 68 | "Enrichment", 69 | "Milo", 70 | "Mixscape", 71 | "ClusteringSpace", 72 | "PerturbationComparison", 73 | "LRClassifierSpace", 74 | "MLPClassifierSpace", 75 | "CentroidSpace", 76 | "DBSCANSpace", 77 | "KMeansSpace", 78 | "PseudobulkSpace", 79 | "Scgen", 80 | ] 81 | -------------------------------------------------------------------------------- /pertpy/tools/_coda/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/pertpy/tools/_coda/__init__.py -------------------------------------------------------------------------------- /pertpy/tools/_differential_gene_expression/__init__.py: -------------------------------------------------------------------------------- 1 | from ._base import LinearModelBase, MethodBase 2 | from ._dge_comparison import DGEEVAL 3 | from ._edger import EdgeR 4 | from ._pydeseq2 import PyDESeq2 5 | from ._simple_tests import SimpleComparisonBase, TTest, WilcoxonTest 6 | from ._statsmodels import Statsmodels 7 | 8 | __all__ = [ 9 | "MethodBase", 10 | "LinearModelBase", 11 | "EdgeR", 12 | "PyDESeq2", 13 | "Statsmodels", 14 | "SimpleComparisonBase", 15 | "WilcoxonTest", 16 | "TTest", 17 | ] 18 | 19 | AVAILABLE_METHODS = [Statsmodels, EdgeR, PyDESeq2, WilcoxonTest, TTest] 20 | -------------------------------------------------------------------------------- /pertpy/tools/_differential_gene_expression/_checks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.sparse import issparse, spmatrix 3 | 4 | 5 | def check_is_numeric_matrix(array: np.ndarray | spmatrix) -> None: 6 | """Check if a matrix is numeric and only contains finite/non-NA values. 7 | 8 | Args: 9 | array: Dense or sparse matrix to check. 10 | 11 | Raises: 12 | ValueError: If the matrix is not numeric or contains NaNs or infinite values. 13 | """ 14 | if not np.issubdtype(array.dtype, np.number): 15 | raise ValueError("Counts must be numeric.") 16 | if issparse(array): 17 | if np.any(~np.isfinite(array.data)): 18 | raise ValueError("Counts cannot contain negative, NaN or Inf values.") 19 | elif np.any(~np.isfinite(array)): 20 | raise ValueError("Counts cannot contain negative, NaN or Inf values.") 21 | 22 | 23 | def check_is_integer_matrix(array: np.ndarray | spmatrix, tolerance: float = 1e-6) -> None: 24 | """Check if a matrix container integers, or floats that are close to integers. 25 | 26 | Args: 27 | array: Dense or sparse matrix to check. 28 | tolerance: Values must be this close to integers. 29 | 30 | Raises: 31 | ValueError: If the matrix contains values that are not close to integers. 
32 | """ 33 | if issparse(array): 34 | if not array.data.dtype.kind == "i" and not np.all(np.abs(array.data - np.round(array.data)) < tolerance): 35 | raise ValueError("Non-zero elements of the matrix must be close to integer values.") 36 | elif array.dtype.kind != "i" and not np.all(np.abs(array - np.round(array)) < tolerance): 37 | raise ValueError("Matrix must be a count matrix.") 38 | if (array < 0).sum() > 0: 39 | raise ValueError("Non-zero elements of the matrix must be positive.") 40 | -------------------------------------------------------------------------------- /pertpy/tools/_differential_gene_expression/_dge_comparison.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from anndata import AnnData 4 | 5 | 6 | class DGEEVAL: 7 | def compare( 8 | self, 9 | adata: AnnData | None = None, 10 | de_key1: str = None, 11 | de_key2: str = None, 12 | de_df1: pd.DataFrame | None = None, 13 | de_df2: pd.DataFrame | None = None, 14 | shared_top: int = 100, 15 | ) -> dict[str, float]: 16 | """Compare two differential expression analyses. 17 | 18 | Compare two sets of DE results and evaluate the similarity by the overlap of top DEG and 19 | the correlation of their scores and adjusted p-values. 20 | 21 | Args: 22 | adata: AnnData object containing DE results in `uns`. Required if `de_key1` and `de_key2` are used. 23 | de_key1: Key for DE results in `adata.uns`, e.g., output of `tl.rank_genes_groups`. 24 | de_key2: Another key for DE results in `adata.uns`, e.g., output of `tl.rank_genes_groups`. 25 | de_df1: DataFrame containing DE results, e.g. output from pertpy differential gene expression interface. 26 | de_df2: DataFrame containing DE results, e.g. output from pertpy differential gene expression interface. 27 | shared_top: The number of top DEG to compute the proportion of their intersection. 28 | 29 | """ 30 | if (de_key1 or de_key2) and (de_df1 is not None or de_df2 is not None): 31 | raise ValueError( 32 | "Please provide either both `de_key1` and `de_key2` with `adata`, or `de_df1` and `de_df2`, but not both." 33 | ) 34 | 35 | if de_df1 is None and de_df2 is None: # use keys 36 | if not de_key1 or not de_key2: 37 | raise ValueError("Both `de_key1` and `de_key2` must be provided together if using `adata`.") 38 | 39 | elif de_df1 is None or de_df2 is None: 40 | raise ValueError("Both `de_df1` and `de_df2` must be provided together if using DataFrames.") 41 | 42 | if de_key1: 43 | if not adata: 44 | raise ValueError("`adata` should be provided with `de_key1` and `de_key2`. ") 45 | assert all(k in adata.uns for k in [de_key1, de_key2]), ( 46 | "Provided `de_key1` and `de_key2` must exist in `adata.uns`." 47 | ) 48 | vars = adata.var_names 49 | 50 | if de_df1 is not None: 51 | for df in (de_df1, de_df2): 52 | if not {"variable", "log_fc", "adj_p_value"}.issubset(df.columns): 53 | raise ValueError("Each DataFrame must contain columns: 'variable', 'log_fc', and 'adj_p_value'.") 54 | 55 | assert set(de_df1["variable"]) == set(de_df2["variable"]), "Variables in both dataframes must match." 
56 |             vars = de_df1["variable"].sort_values()
57 | 
58 |         shared_top = min(shared_top, len(vars))
59 |         vars_ranks = np.arange(1, len(vars) + 1)
60 |         results = pd.DataFrame(index=vars)
61 |         top_names = []
62 | 
63 |         if de_key1 and de_key2:
64 |             for i, k in enumerate([de_key1, de_key2]):
65 |                 label = adata.uns[k]["names"].dtype.names[0]
66 |                 srt_idx = np.argsort(adata.uns[k]["names"][label])
67 |                 results[f"scores_{i}"] = adata.uns[k]["scores"][label][srt_idx]
68 |                 results[f"pvals_adj_{i}"] = adata.uns[k]["pvals_adj"][label][srt_idx]
69 |                 results[f"ranks_{i}"] = vars_ranks[srt_idx]
70 |                 top_names.append(adata.uns[k]["names"][label][:shared_top])
71 |         else:
72 |             for i, df in enumerate([de_df1, de_df2]):
73 |                 srt_idx = np.argsort(df["variable"])
74 |                 results[f"scores_{i}"] = df["log_fc"].values[srt_idx]
75 |                 results[f"pvals_adj_{i}"] = df["adj_p_value"].values[srt_idx]
76 |                 results[f"ranks_{i}"] = vars_ranks[srt_idx]
77 |                 top_names.append(df["variable"][:shared_top])
78 | 
79 |         metrics = {}
80 |         metrics["shared_top_genes"] = len(set(top_names[0]).intersection(top_names[1])) / shared_top
81 |         metrics["scores_corr"] = results["scores_0"].corr(results["scores_1"], method="pearson")
82 |         metrics["pvals_adj_corr"] = results["pvals_adj_0"].corr(results["pvals_adj_1"], method="pearson")
83 |         metrics["scores_ranks_corr"] = results["ranks_0"].corr(results["ranks_1"], method="spearman")
84 | 
85 |         return metrics
86 | 
--------------------------------------------------------------------------------
/pertpy/tools/_differential_gene_expression/_edger.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Sequence
2 | 
3 | import numpy as np
4 | import pandas as pd
5 | from lamin_utils import logger
6 | from scipy.sparse import issparse
7 | 
8 | from ._base import LinearModelBase
9 | from ._checks import check_is_integer_matrix
10 | 
11 | 
12 | class EdgeR(LinearModelBase):
13 |     """Differential expression test using EdgeR."""
14 | 
15 |     def _check_counts(self):
16 |         check_is_integer_matrix(self.data)
17 | 
18 |     def fit(self, **kwargs):  # adata, design, mask, layer
19 |         """Fit model using edgeR.
20 | 
21 |         Note: this creates its own AnnData object for downstream processing.
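
        Example (an illustrative sketch -- assumes an AnnData `adata` with raw counts, a
        categorical "condition" column in `.obs`, and a working R installation with edgeR):

        >>> import pertpy as pt
        >>> edger = pt.tl.EdgeR(adata, design="~condition")
        >>> edger.fit()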
22 | 
23 |         Args:
24 |             **kwargs: Keyword arguments specific to glmQLFit().
25 |         """
26 |         # For running in notebook
27 |         # pandas2ri.activate()
28 |         # rpy2.robjects.numpy2ri.activate()
29 |         try:
30 |             from rpy2 import robjects as ro
31 |             from rpy2.robjects import numpy2ri, pandas2ri
32 |             from rpy2.robjects.conversion import get_conversion, localconverter
33 |             from rpy2.robjects.packages import importr
34 | 
35 |         except ImportError:
36 |             raise ImportError("edgeR requires rpy2 to be installed.") from None
37 | 
38 |         try:
39 |             edger = importr("edgeR")
40 |         except ImportError as e:
41 |             raise ImportError(
42 |                 "edgeR requires a valid R installation with the following packages:\nedgeR, BiocParallel, RhpcBLASctl"
43 |             ) from e
44 | 
45 |         # Convert dataframe
46 |         with localconverter(get_conversion() + numpy2ri.converter):
47 |             expr = self.adata.X if self.layer is None else self.adata.layers[self.layer]
48 |             expr = expr.T.toarray() if issparse(expr) else expr.T
49 | 
50 |         with localconverter(get_conversion() + pandas2ri.converter):
51 |             expr_r = ro.conversion.py2rpy(pd.DataFrame(expr, index=self.adata.var_names, columns=self.adata.obs_names))
52 |             samples_r = ro.conversion.py2rpy(self.adata.obs)
53 | 
54 |         dge = edger.DGEList(counts=expr_r, samples=samples_r)
55 | 
56 |         logger.info("Calculating NormFactors")
57 |         dge = edger.calcNormFactors(dge)
58 | 
59 |         with localconverter(get_conversion() + numpy2ri.converter):
60 |             design_r = ro.conversion.py2rpy(self.design.values)
61 | 
62 |         logger.info("Estimating Dispersions")
63 |         dge = edger.estimateDisp(dge, design=design_r)
64 | 
65 |         logger.info("Fitting linear model")
66 |         fit = edger.glmQLFit(dge, design=design_r, **kwargs)
67 | 
68 |         ro.globalenv["fit"] = fit
69 |         self.fit = fit
70 | 
71 |     def _test_single_contrast(self, contrast: Sequence[float], **kwargs) -> pd.DataFrame:  # noqa: D417
72 |         """Conduct test for each contrast and return a data frame.
73 | 
74 |         Args:
75 |             contrast: NumPy array of integers specifying the contrast, e.g. [-1, 0, 1, 0, 0].
76 |         """
77 |         ## -- Check installations
78 |         # For running in notebook
79 |         # pandas2ri.activate()
80 |         # rpy2.robjects.numpy2ri.activate()
81 | 
82 |         # ToDo:
83 |         # parse **kwargs to R function
84 |         # Fix mask for .fit()
85 | 
86 |         try:
87 |             from rpy2 import robjects as ro
88 |             from rpy2.robjects import numpy2ri, pandas2ri
89 |             from rpy2.robjects.conversion import get_conversion, localconverter
90 |             from rpy2.robjects.packages import importr
91 | 
92 |         except ImportError:
93 |             raise ImportError("edgeR requires rpy2 to be installed.") from None
94 | 
95 |         try:
96 |             importr("edgeR")
97 |         except ImportError:
98 |             raise ImportError(
99 |                 "edgeR requires a valid R installation with the following packages: edgeR, BiocParallel, RhpcBLASctl"
100 |             ) from None
101 | 
102 |         # Convert the contrast vector to R. As in `self.design_matrix`, one category per factor
103 |         # is dropped from the encoding and absorbed into the intercept.
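        # Illustrative (hypothetical column names): for a design matrix with columns
        # [Intercept, condition[T.B]], contrast = [0, 1] tests the condition[T.B]
        # coefficient against zero.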
103 |         with localconverter(get_conversion() + numpy2ri.converter):
104 |             contrast_vec_r = ro.conversion.py2rpy(np.asarray(contrast))
105 |             ro.globalenv["contrast_vec"] = contrast_vec_r
106 | 
107 |         # Test contrast with R
108 |         ro.r(
109 |             """
110 |             test = edgeR::glmQLFTest(fit, contrast=contrast_vec)
111 |             de_res = edgeR::topTags(test, n=Inf, adjust.method="BH", sort.by="PValue")$table
112 |             """
113 |         )
114 | 
115 |         # Retrieve the `de_res` object
116 |         de_res = ro.globalenv["de_res"]
117 | 
118 |         # If already a Pandas DataFrame, return it directly
119 |         if isinstance(de_res, pd.DataFrame):
120 |             de_res.index.name = "variable"
121 |             return de_res.reset_index().rename(columns={"PValue": "p_value", "logFC": "log_fc", "FDR": "adj_p_value"})
122 | 
123 |         # Convert to Pandas DataFrame if still an R object
124 |         with localconverter(get_conversion() + pandas2ri.converter):
125 |             de_res = ro.conversion.rpy2py(de_res)
126 | 
127 |         de_res.index.name = "variable"
128 |         de_res = de_res.reset_index()
129 | 
130 |         return de_res.rename(columns={"PValue": "p_value", "logFC": "log_fc", "FDR": "adj_p_value"})
131 | 
--------------------------------------------------------------------------------
/pertpy/tools/_differential_gene_expression/_pydeseq2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import warnings
4 | 
5 | import numpy as np
6 | import pandas as pd
7 | from anndata import AnnData
8 | from numpy import ndarray
9 | from pydeseq2.dds import DeseqDataSet
10 | from pydeseq2.default_inference import DefaultInference
11 | from pydeseq2.ds import DeseqStats
12 | from scipy.sparse import issparse
13 | 
14 | from ._base import LinearModelBase
15 | from ._checks import check_is_integer_matrix
16 | 
17 | 
18 | class PyDESeq2(LinearModelBase):
19 |     """Differential expression test using PyDESeq2."""
20 | 
21 |     def __init__(
22 |         self, adata: AnnData, design: str | ndarray, *, mask: str | None = None, layer: str | None = None, **kwargs
23 |     ):
24 |         super().__init__(adata, design, mask=mask, layer=layer, **kwargs)
25 |         # work around pydeseq2 issue with sparse matrices
26 |         # see also https://github.com/owkin/PyDESeq2/issues/25
27 |         if issparse(self.data):
28 |             if self.layer is None:
29 |                 self.adata.X = self.adata.X.toarray()
30 |             else:
31 |                 self.adata.layers[self.layer] = self.adata.layers[self.layer].toarray()
32 | 
33 |     def _check_counts(self):
34 |         check_is_integer_matrix(self.data)
35 | 
36 |     def fit(self, **kwargs) -> None:
37 |         """Fit dds model using pydeseq2.
38 | 
39 |         Note: this creates its own AnnData object for downstream processing.
40 | 
41 |         Args:
42 |             **kwargs: Keyword arguments specific to DeseqDataSet(), except for `n_cpus`, which defaults to all CPUs available to the process if not passed.
43 |         """
44 |         try:
45 |             usable_cpus = len(os.sched_getaffinity(0))
46 |         except AttributeError:
47 |             usable_cpus = os.cpu_count()
48 | 
49 |         inference = DefaultInference(n_cpus=kwargs.pop("n_cpus", usable_cpus))
50 | 
51 |         dds = DeseqDataSet(
52 |             adata=self.adata,
53 |             design=self.design,  # initialize using design matrix, not formula
54 |             refit_cooks=True,
55 |             inference=inference,
56 |             **kwargs,
57 |         )
58 | 
59 |         dds.deseq2()
60 |         self.dds = dds
61 | 
62 |     def _test_single_contrast(self, contrast, alpha=0.05, **kwargs) -> pd.DataFrame:
63 |         """Conduct a specific test and return a Pandas DataFrame.
64 | 
65 |         Args:
66 |             contrast: list of three strings of the form `["variable", "tested level", "reference level"]`.
67 |             alpha: p-value threshold used for controlling FDR with independent hypothesis weighting.
68 |             **kwargs: Extra arguments to pass to DeseqStats().
69 |         """
70 |         contrast = np.array(contrast)
71 |         stat_res = DeseqStats(self.dds, contrast=contrast, alpha=alpha, **kwargs)
72 |         # Calling `.summary()` is required to fill the `results_df` data frame
73 |         stat_res.summary()
74 |         res_df = (
75 |             pd.DataFrame(stat_res.results_df)
76 |             .rename(columns={"pvalue": "p_value", "padj": "adj_p_value", "log2FoldChange": "log_fc"})
77 |             .sort_values("p_value")
78 |         )
79 |         res_df.index.name = "variable"
80 |         res_df = res_df.reset_index()
81 |         return res_df
82 | 
--------------------------------------------------------------------------------
/pertpy/tools/_differential_gene_expression/_statsmodels.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import scanpy as sc
4 | import statsmodels
5 | import statsmodels.api as sm
6 | from tqdm.auto import tqdm
7 | 
8 | from ._base import LinearModelBase
9 | from ._checks import check_is_numeric_matrix
10 | 
11 | 
12 | class Statsmodels(LinearModelBase):
13 |     """Differential expression test using a statsmodels linear regression."""
14 | 
15 |     def _check_counts(self):
16 |         check_is_numeric_matrix(self.data)
17 | 
18 |     def fit(
19 |         self,
20 |         regression_model: type[sm.OLS] | type[sm.GLM] = sm.OLS,
21 |         **kwargs,
22 |     ) -> None:
23 |         """Fit the specified regression model.
24 | 
25 |         Args:
26 |             regression_model: A statsmodels regression model class, either OLS or GLM.
27 |             **kwargs: Additional arguments for fitting the specific method. In particular, this
28 |                 is where you can specify the family for GLM.
29 | 
30 |         Examples:
31 |             >>> import statsmodels.api as sm
32 |             >>> import pertpy as pt
33 |             >>> model = pt.tl.Statsmodels(adata, design="~condition")
34 |             >>> model.fit(sm.GLM, family=sm.families.NegativeBinomial(link=sm.families.links.Log()))
35 |             >>> results = model.test_contrasts(np.array([0, 1]))
36 |         """
37 |         self.models = []
38 |         for var in tqdm(self.adata.var_names):
39 |             mod = regression_model(
40 |                 sc.get.obs_df(self.adata, keys=[var], layer=self.layer)[var],
41 |                 self.design,
42 |                 **kwargs,
43 |             )
44 |             mod = mod.fit()
45 |             self.models.append(mod)
46 | 
47 |     def _test_single_contrast(self, contrast, **kwargs) -> pd.DataFrame:
48 |         res = []
49 |         for var, mod in zip(tqdm(self.adata.var_names), self.models, strict=False):
50 |             t_test = mod.t_test(contrast)
51 |             res.append(
52 |                 {
53 |                     "variable": var,
54 |                     "p_value": t_test.pvalue,
55 |                     "t_value": t_test.tvalue.item(),
56 |                     "sd": t_test.sd.item(),
57 |                     "log_fc": t_test.effect.item(),
58 |                 }
59 |             )
60 |         return (
61 |             pd.DataFrame(res)
62 |             .sort_values("p_value")
63 |             .assign(adj_p_value=lambda x: statsmodels.stats.multitest.fdrcorrection(x["p_value"])[1])
64 |         )
65 | 
--------------------------------------------------------------------------------
/pertpy/tools/_distances/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/pertpy/tools/_distances/__init__.py
--------------------------------------------------------------------------------
/pertpy/tools/_perturbation_space/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/pertpy/tools/_perturbation_space/__init__.py
--------------------------------------------------------------------------------
/pertpy/tools/_perturbation_space/_clustering.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | from typing import TYPE_CHECKING
4 | 
5 | from sklearn.metrics import pairwise_distances
6 | 
7 | from pertpy.tools._perturbation_space._perturbation_space import PerturbationSpace
8 | 
9 | if TYPE_CHECKING:
10 |     from collections.abc import Iterable
11 | 
12 |     from anndata import AnnData
13 | 
14 | 
15 | class ClusteringSpace(PerturbationSpace):
16 |     """Applies various clustering techniques to an embedding."""
17 | 
18 |     def __init__(self):
19 |         super().__init__()
20 |         self.X = None
21 | 
22 |     def evaluate_clustering(
23 |         self,
24 |         adata: AnnData,
25 |         true_label_col: str,
26 |         cluster_col: str,
27 |         metrics: Iterable[str] | None = None,
28 |         **kwargs,
29 |     ):
30 |         """Evaluate a previously computed clustering against ground truth labels.
31 | 
32 |         Args:
33 |             adata: AnnData object that contains the clustered data and the cluster labels.
34 |             true_label_col: Column in `adata.obs` containing the ground truth labels.
35 |             cluster_col: Column in `adata.obs` containing the computed cluster labels.
36 |             metrics: Metrics to compute. If `None` it defaults to ["nmi", "ari", "asw"].
37 |             **kwargs: Additional arguments to pass to the metrics. For nmi, average_method can be passed.
38 |                 For asw, metric, distances, sample_size, and random_state can be passed.

        Returns:
            Dictionary mapping each computed metric name ("nmi", "ari", "asw") to its score.
39 | 
40 |         Examples:
41 |             Example usage with KMeansSpace:
42 | 
43 |             >>> import pertpy as pt
44 |             >>> mdata = pt.dt.papalexi_2021()
45 |             >>> kmeans = pt.tl.KMeansSpace()
46 |             >>> kmeans_adata = kmeans.compute(mdata["rna"], n_clusters=26)
47 |             >>> results = kmeans.evaluate_clustering(
48 |             ...     kmeans_adata, true_label_col="gene_target", cluster_col="k-means", metrics=["nmi"]
49 |             ... )
50 |         """
51 |         if metrics is None:
52 |             metrics = ["nmi", "ari", "asw"]
53 |         true_labels = adata.obs[true_label_col]
54 | 
55 |         results = {}
56 |         for metric in metrics:
57 |             if metric == "nmi":
58 |                 from pertpy.tools._perturbation_space._metrics import nmi
59 | 
60 |                 if "average_method" not in kwargs:
61 |                     kwargs["average_method"] = "arithmetic"  # by default in sklearn implementation
62 | 
63 |                 nmi_score = nmi(
64 |                     true_labels=true_labels,
65 |                     predicted_labels=adata.obs[cluster_col],
66 |                     average_method=kwargs["average_method"],
67 |                 )
68 |                 results["nmi"] = nmi_score
69 | 
70 |             if metric == "ari":
71 |                 from pertpy.tools._perturbation_space._metrics import ari
72 | 
73 |                 ari_score = ari(true_labels=true_labels, predicted_labels=adata.obs[cluster_col])
74 |                 results["ari"] = ari_score
75 | 
76 |             if metric == "asw":
77 |                 from pertpy.tools._perturbation_space._metrics import asw
78 | 
79 |                 if "metric" not in kwargs:
80 |                     kwargs["metric"] = "euclidean"
81 |                 if "distances" not in kwargs:
82 |                     kwargs["distances"] = pairwise_distances(self.X, metric=kwargs["metric"])
83 |                 if "sample_size" not in kwargs:
84 |                     kwargs["sample_size"] = None
85 |                 if "random_state" not in kwargs:
86 |                     kwargs["random_state"] = None
87 | 
88 |                 asw_score = asw(
89 |                     pairwise_distances=kwargs["distances"],
90 |                     labels=true_labels,
91 |                     metric=kwargs["metric"],
92 |                     sample_size=kwargs["sample_size"],
93 |                     random_state=kwargs["random_state"],
94 |                 )
95 | 
96 |                 results["asw"] = asw_score
97 | 
98 |         return results
99 | 
--------------------------------------------------------------------------------
/pertpy/tools/_perturbation_space/_comparison.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING
2 | 
3 | import numpy as np
4 | from scipy.sparse import issparse
5 | from scipy.sparse import vstack as sp_vstack
6 | from sklearn.base import ClassifierMixin
7 | from sklearn.linear_model import LogisticRegression
8 | 
9 | if TYPE_CHECKING:
10 |     from numpy.typing import NDArray
11 | 
12 | 
13 | class PerturbationComparison:
14 |     """Comparison between real and simulated perturbations."""
15 | 
16 |     def compare_classification(
17 |         self,
18 |         real: np.ndarray,
19 |         simulated: np.ndarray,
20 |         control: np.ndarray,
21 |         clf: ClassifierMixin | None = None,
22 |     ) -> float:
23 |         """Compare classification accuracy between real and simulated perturbations.
24 | 
25 |         Trains a classifier on the real perturbation data + the control data and reports a normalized
26 |         classification accuracy on the simulated perturbation.
27 | 
28 |         Args:
29 |             real: Real perturbed data.
30 |             simulated: Simulated perturbed data.
31 |             control: Control data.
32 |             clf: sklearn classifier to use, `sklearn.linear_model.LogisticRegression` if not provided.
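
        Returns:
            Accuracy of the classifier on the simulated data, normalized by its accuracy on the
            real data and capped at 1.0.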
33 | """ 34 | assert real.shape[1] == simulated.shape[1] == control.shape[1] 35 | if clf is None: 36 | clf = LogisticRegression() 37 | n_x = real.shape[0] 38 | data = sp_vstack((real, control)) if issparse(real) else np.vstack((real, control)) 39 | labels = np.concatenate([np.full(real.shape[0], "comp"), np.full(control.shape[0], "ctrl")]) 40 | 41 | clf.fit(data, labels) 42 | norm_score = clf.score(simulated, np.full(simulated.shape[0], "comp")) / clf.score(real, labels[:n_x]) 43 | norm_score = min(1.0, norm_score) 44 | 45 | return norm_score 46 | 47 | def compare_knn( 48 | self, 49 | real: np.ndarray, 50 | simulated: np.ndarray, 51 | control: np.ndarray | None = None, 52 | use_simulated_for_knn: bool = False, 53 | n_neighbors: int = 20, 54 | random_state: int = 0, 55 | n_jobs: int = 1, 56 | ) -> dict[str, float]: 57 | """Calculate proportions of real perturbed and control data points for simulated data. 58 | 59 | Computes proportions of real perturbed, control and simulated (if `use_simulated_for_knn=True`) 60 | data points for simulated data. If control (`C`) is not provided, builds the knn graph from 61 | real perturbed + simulated perturbed. 62 | 63 | Args: 64 | real: Real perturbed data. 65 | simulated: Simulated perturbed data. 66 | control: Control data 67 | use_simulated_for_knn: Include simulted perturbed data (`simulated`) into the knn graph. Only valid when 68 | control (`control`) is provided. 69 | n_neighbors: Number of neighbors to use in k-neighbor graph. 70 | random_state: Random state used for k-neighbor graph construction. 71 | n_jobs: Number of cores to use. Defaults to -1 (all). 72 | 73 | """ 74 | assert real.shape[1] == simulated.shape[1] 75 | if control is not None: 76 | assert real.shape[1] == control.shape[1] 77 | 78 | n_y = simulated.shape[0] 79 | 80 | if control is None: 81 | index_data = sp_vstack((simulated, real)) if issparse(real) else np.vstack((simulated, real)) 82 | else: 83 | datas = (simulated, real, control) if use_simulated_for_knn else (real, control) 84 | index_data = sp_vstack(datas) if issparse(real) else np.vstack(datas) 85 | 86 | y_in_index = use_simulated_for_knn or control is None 87 | c_in_index = control is not None 88 | label_groups = ["comp"] 89 | labels: NDArray[np.str_] = np.full(index_data.shape[0], "comp") 90 | if y_in_index: 91 | labels[:n_y] = "siml" 92 | label_groups.append("siml") 93 | if c_in_index: 94 | labels[-control.shape[0] :] = "ctrl" 95 | label_groups.append("ctrl") 96 | 97 | from pynndescent import NNDescent 98 | 99 | index = NNDescent( 100 | index_data, 101 | n_neighbors=max(50, n_neighbors), 102 | random_state=random_state, 103 | n_jobs=n_jobs, 104 | ) 105 | indices = index.query(simulated, k=n_neighbors)[0] 106 | 107 | uq, uq_counts = np.unique(labels[indices], return_counts=True) 108 | uq_counts_norm = uq_counts / uq_counts.sum() 109 | counts = dict(zip(label_groups, [0.0] * len(label_groups), strict=False)) 110 | counts = dict(zip(uq, uq_counts_norm, strict=False)) 111 | 112 | return counts 113 | -------------------------------------------------------------------------------- /pertpy/tools/_perturbation_space/_metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Literal 4 | 5 | from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_score 6 | 7 | if TYPE_CHECKING: 8 | import numpy as np 9 | from numpy._typing import ArrayLike 10 | 11 | 12 | def nmi( 13 | true_labels: 
np.ndarray, 14 | predicted_labels: np.ndarray, 15 | average_method: Literal["min", "max", "geometric", "arithmetic"] = "arithmetic", 16 | ) -> float: 17 | """Calculates the normalized mutual information score between two sets of clusters. 18 | 19 | See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.normalized_mutual_info_score.html 20 | 21 | Args: 22 | true_labels: A clustering of the data into disjoint subsets. 23 | predicted_labels: A clustering of the data into disjoint subsets. 24 | average_method: How to compute the normalizer in the denominator. 25 | 26 | Returns: 27 | Score between 0.0 and 1.0 in normalized nats (based on the natural logarithm). 1.0 stands for perfectly complete labeling. 28 | """ 29 | return normalized_mutual_info_score( 30 | labels_true=true_labels, labels_pred=predicted_labels, average_method=average_method 31 | ) 32 | 33 | 34 | def ari(true_labels: np.ndarray, predicted_labels: np.ndarray) -> float: 35 | """Calculates the adjusted rand index for chance. 36 | 37 | See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_rand_score.html 38 | 39 | Args: 40 | true_labels: Ground truth class labels to be used as a reference. 41 | predicted_labels: Cluster labels to evaluate. 42 | 43 | Returns: 44 | Similarity score between -0.5 and 1.0. Random labelings have an ARI close to 0.0. 1.0 stands for perfect match. 45 | """ 46 | return adjusted_rand_score(labels_true=true_labels, labels_pred=predicted_labels) 47 | 48 | 49 | def asw( 50 | pairwise_distances: ArrayLike, 51 | labels: ArrayLike, 52 | metric: str = "euclidean", 53 | sample_size: int = None, 54 | random_state: int = None, 55 | **kwargs, 56 | ) -> float: 57 | """Computes the average-width silhouette score. 58 | 59 | See: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.silhouette_score.html 60 | 61 | Args: 62 | pairwise_distances: An array of pairwise distances between samples, or a feature array. 63 | labels: Predicted labels for each sample. 64 | metric: The metric to use when calculating distance between instances in a feature array. 65 | If metric is a string, it must be one of the options allowed by metrics.pairwise.pairwise_distances. 66 | If X is the distance array itself, use metric="precomputed". 67 | sample_size: The size of the sample to use when computing the Silhouette Coefficient on a random subset of the data. 68 | If sample_size is None, no sampling is used. 69 | random_state: Determines random number generation for selecting a subset of samples. Used when sample_size is not None. 70 | **kwargs: Any further parameters are passed directly to the distance function. If using a scipy.spatial.distance metric, the parameters are still metric dependent. 71 | 72 | Returns: 73 | Mean Silhouette Coefficient for all samples. 
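
    Example:
        An illustrative call on a toy precomputed distance matrix (an editorial sketch):

        >>> import numpy as np
        >>> from sklearn.metrics import pairwise_distances
        >>> X = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
        >>> asw(pairwise_distances(X), labels=[0, 0, 1, 1], metric="precomputed")  # close to 1.0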
74 | """ 75 | return silhouette_score( 76 | X=pairwise_distances, labels=labels, metric=metric, sample_size=sample_size, random_state=random_state, **kwargs 77 | ) 78 | -------------------------------------------------------------------------------- /pertpy/tools/_scgen/__init__.py: -------------------------------------------------------------------------------- 1 | from pertpy.tools._scgen._scgen import Scgen 2 | -------------------------------------------------------------------------------- /pertpy/tools/_scgen/_base_components.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import jax.numpy as jnp 6 | from flax import linen as nn 7 | 8 | if TYPE_CHECKING: 9 | import jaxlib 10 | 11 | 12 | class FlaxEncoder(nn.Module): 13 | n_latent: int = 10 14 | n_layers: int = 2 15 | n_hidden: int = 800 16 | dropout_rate: float = 0.1 17 | latent_distribution: str = "normal" 18 | use_batch_norm: bool = True 19 | use_layer_norm: bool = False 20 | activation_fn: jaxlib.xla_extension.CompiledFunction = nn.activation.leaky_relu # type: ignore 21 | training: bool | None = None 22 | var_activation: jaxlib.xla_extension.CompiledFunction = jnp.exp # type: ignore 23 | # var_eps: float=1e-4, 24 | 25 | @nn.compact 26 | def __call__(self, x: jnp.ndarray, training: bool | None = None) -> tuple[float, float]: 27 | """Forward pass. 28 | 29 | Args: 30 | x: The input data matrix. 31 | training: Whether to use running training average. 32 | 33 | Returns: 34 | Mean and variance. 35 | """ 36 | training = nn.merge_param("training", self.training, training) 37 | for _ in range(self.n_layers): 38 | x = nn.Dense(self.n_hidden)(x) 39 | if self.use_batch_norm: 40 | x = nn.BatchNorm( 41 | momentum=0.99, 42 | epsilon=0.001, 43 | use_running_average=not training, 44 | )(x) 45 | x = self.activation_fn(x) 46 | if self.use_layer_norm: 47 | x = nn.LayerNorm(x) # type: ignore 48 | x = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(x) if self.dropout_rate > 0 else x 49 | 50 | mean_x = nn.Dense(self.n_latent)(x) 51 | logvar_x = nn.Dense(self.n_latent)(x) 52 | 53 | return mean_x, self.var_activation(logvar_x) 54 | 55 | 56 | class FlaxDecoder(nn.Module): 57 | n_output: int 58 | n_layers: int = 1 59 | n_hidden: int = 128 60 | dropout_rate: float = 0.2 61 | use_batch_norm: bool = False 62 | use_layer_norm: bool = False 63 | activation_fn: nn.activation = nn.activation.leaky_relu # type: ignore 64 | training: bool | None = None 65 | 66 | @nn.compact 67 | def __call__(self, x: jnp.ndarray, training: bool | None = None) -> jnp.ndarray: # type: ignore 68 | """Forward pass. 69 | 70 | Args: 71 | x: Input data. 72 | training: Whether to use running training average. 73 | 74 | Returns: 75 | Decoded data. 
76 | """ 77 | training = nn.merge_param("training", self.training, training) 78 | 79 | for _ in range(self.n_layers): 80 | x = nn.Dense(self.n_hidden)(x) 81 | if self.use_batch_norm: 82 | x = nn.BatchNorm( 83 | momentum=0.99, 84 | epsilon=0.001, 85 | use_running_average=not training, 86 | )(x) 87 | x = self.activation_fn(x) # type: ignore 88 | if self.use_layer_norm: 89 | x = nn.LayerNorm(x) # type: ignore 90 | x = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(x) if self.dropout_rate > 0 else x 91 | 92 | x = nn.Dense(self.n_output)(x) # type: ignore 93 | 94 | return x 95 | -------------------------------------------------------------------------------- /pertpy/tools/_scgen/_scgenvae.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import flax.linen as nn 4 | import jax.numpy as jnp 5 | import numpyro.distributions as dist 6 | from scvi import REGISTRY_KEYS 7 | from scvi.module.base import JaxBaseModuleClass, LossOutput, flax_configure 8 | 9 | from ._base_components import FlaxDecoder, FlaxEncoder 10 | 11 | 12 | @flax_configure 13 | class JaxSCGENVAE(JaxBaseModuleClass): 14 | n_input: int 15 | n_hidden: int = 800 16 | n_latent: int = 10 17 | n_layers: int = 2 18 | dropout_rate: float = 0.1 19 | log_variational: bool = False 20 | latent_distribution: str = "normal" 21 | use_batch_norm: str = "both" 22 | use_layer_norm: str = "none" 23 | kl_weight: float = 0.00005 24 | training: bool = True 25 | 26 | def setup(self): 27 | use_batch_norm_encoder = self.use_batch_norm in ("encoder", "both") 28 | use_layer_norm_encoder = self.use_layer_norm in ("encoder", "both") 29 | 30 | self.encoder = FlaxEncoder( 31 | n_latent=self.n_latent, 32 | n_layers=self.n_layers, 33 | n_hidden=self.n_hidden, 34 | dropout_rate=self.dropout_rate, 35 | latent_distribution=self.latent_distribution, 36 | use_batch_norm=use_batch_norm_encoder, 37 | use_layer_norm=use_layer_norm_encoder, 38 | activation_fn=nn.activation.leaky_relu, 39 | training=self.training, 40 | ) 41 | 42 | self.decoder = FlaxDecoder( 43 | n_output=self.n_input, 44 | n_layers=self.n_layers, 45 | n_hidden=self.n_hidden, 46 | activation_fn=nn.activation.leaky_relu, 47 | dropout_rate=self.dropout_rate, 48 | training=self.training, 49 | ) 50 | 51 | @property 52 | def required_rngs(self): 53 | return ("params", "dropout", "z") 54 | 55 | def _get_inference_input(self, tensors: dict[str, jnp.ndarray]): 56 | x = tensors[REGISTRY_KEYS.X_KEY] 57 | 58 | input_dict = {"x": x} 59 | return input_dict 60 | 61 | def inference(self, x: jnp.ndarray, n_samples: int = 1) -> dict: 62 | mean, var = self.encoder(x) 63 | stddev = jnp.sqrt(var) 64 | 65 | qz = dist.Normal(mean, stddev) 66 | z_rng = self.make_rng("z") 67 | sample_shape = () if n_samples == 1 else (n_samples,) 68 | z = qz.rsample(z_rng, sample_shape=sample_shape) 69 | 70 | return {"qz": qz, "z": z} 71 | 72 | def _get_generative_input( 73 | self, 74 | tensors: dict[str, jnp.ndarray], 75 | inference_outputs: dict[str, jnp.ndarray], 76 | ): 77 | # x = tensors[REGISTRY_KEYS.X_KEY] 78 | z = inference_outputs["z"] 79 | # batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] 80 | 81 | input_dict = { 82 | # x=x, 83 | "z": z, 84 | # batch_index=batch_index, 85 | } 86 | return input_dict 87 | 88 | # def generative(self, x, z, batch_index) -> dict: 89 | def generative(self, z) -> dict: 90 | px = self.decoder(z) 91 | return {"px": px} 92 | 93 | def loss(self, tensors, inference_outputs, generative_outputs): 94 | x = 
tensors[REGISTRY_KEYS.X_KEY] 95 | px = generative_outputs["px"] 96 | qz = inference_outputs["qz"] 97 | 98 | kl_divergence_z = dist.kl_divergence(qz, dist.Normal(0, 1)).sum(-1) 99 | reconst_loss = self.get_reconstruction_loss(px, x) 100 | 101 | weighted_kl_local = self.kl_weight * kl_divergence_z 102 | 103 | loss = jnp.mean(0.5 * reconst_loss + 0.5 * weighted_kl_local) 104 | 105 | return LossOutput( 106 | loss=loss, 107 | reconstruction_loss=reconst_loss, 108 | kl_local=kl_divergence_z, 109 | n_obs_minibatch=x.shape[0], 110 | ) 111 | 112 | def sample( 113 | self, 114 | tensors, 115 | n_samples=1, 116 | ): 117 | inference_kwargs = {"n_samples": n_samples} 118 | ( 119 | inference_outputs, 120 | generative_outputs, 121 | ) = self.forward( 122 | tensors, 123 | inference_kwargs=inference_kwargs, 124 | compute_loss=False, 125 | ) 126 | px = dist.Normal(generative_outputs["px"], 1).sample() 127 | return px 128 | 129 | def get_reconstruction_loss(self, x, px): 130 | return jnp.sum((x - px) ** 2) 131 | -------------------------------------------------------------------------------- /pertpy/tools/_scgen/_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | 5 | 6 | def extractor( 7 | data, 8 | cell_type, 9 | condition_key, 10 | cell_type_key, 11 | ctrl_key, 12 | stim_key, 13 | ): 14 | """Returns a list of `data` files while filtering for a specific `cell_type`. 15 | 16 | Args: 17 | data: `~anndata.AnnData` Annotated data matrix 18 | cell_type: Specific cell type to be extracted from `data`. 19 | condition_key: Key for `.obs` of `data` where conditions can be found. 20 | cell_type_key: Key for `.obs` of `data` where cell types can be found. 21 | ctrl_key: Key for `control` part of the `data` found in `condition_key`. 22 | stim_key: Key for `stimulated` part of the `data` found in `condition_key`. 23 | 24 | Returns: 25 | List of `data` files while filtering for a specific `cell_type`. 26 | 27 | Example: 28 | .. code-block:: python 29 | 30 | import Scgen 31 | import anndata 32 | 33 | train_data = anndata.read("./data/train.h5ad") 34 | test_data = anndata.read("./data/test.h5ad") 35 | train_data_extracted_list = extractor( 36 | train_data, "CD4T", "conditions", "cell_type", "control", "stimulated" 37 | ) 38 | """ 39 | cell_with_both_condition = data[data.obs[cell_type_key] == cell_type] 40 | condition_1 = data[(data.obs[cell_type_key] == cell_type) & (data.obs[condition_key] == ctrl_key)] 41 | condition_2 = data[(data.obs[cell_type_key] == cell_type) & (data.obs[condition_key] == stim_key)] 42 | training = data[~((data.obs[cell_type_key] == cell_type) & (data.obs[condition_key] == stim_key))] 43 | 44 | return [training, condition_1, condition_2, cell_with_both_condition] 45 | 46 | 47 | def balancer( 48 | adata, 49 | cell_type_key, 50 | ): 51 | """Makes cell type populations equal. 52 | 53 | Args: 54 | adata: `~anndata.AnnData` Annotated data matrix. 55 | cell_type_key: key for `.obs` of `data` where cell types can be found. 56 | 57 | Returns: 58 | Equal cell type population Annotated data matrix. 59 | 60 | Example: 61 | .. 
code-block:: python 62 | 63 | import Scgen 64 | import anndata 65 | 66 | train_data = anndata.read("./train_kang.h5ad") 67 | train_ctrl = train_data[train_data.obs["condition"] == "control", :] 68 | train_ctrl = balancer(train_ctrl, "conditions", "cell_type") 69 | """ 70 | class_names = np.unique(adata.obs[cell_type_key]) 71 | class_pop = {} 72 | for cls in class_names: 73 | class_pop[cls] = adata[adata.obs[cell_type_key] == cls].shape[0] 74 | max_number = np.max(list(class_pop.values())) 75 | index_all = [] 76 | for cls in class_names: 77 | class_index = np.array(adata.obs[cell_type_key] == cls) 78 | index_cls = np.nonzero(class_index)[0] 79 | rng = np.random.default_rng() 80 | index_cls_r = index_cls[rng.choice(len(index_cls), max_number)] 81 | index_all.append(index_cls_r) 82 | 83 | balanced_data = adata[np.concatenate(index_all)].copy() 84 | 85 | return balanced_data 86 | -------------------------------------------------------------------------------- /pertpy/tools/transferlearning_MMD_LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Jindong Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | 5 | @pytest.fixture 6 | def rng(): 7 | return np.random.default_rng() 8 | -------------------------------------------------------------------------------- /tests/metadata/test_cell_line.py: -------------------------------------------------------------------------------- 1 | import anndata 2 | import numpy as np 3 | import pandas as pd 4 | import pertpy as pt 5 | import pytest 6 | from anndata import AnnData 7 | from scipy import sparse 8 | 9 | NUM_CELLS = 100 10 | NUM_GENES = 100 11 | NUM_CELLS_PER_ID = NUM_CELLS // 4 12 | 13 | 14 | pt_metadata = pt.md.CellLine() 15 | 16 | 17 | @pytest.fixture 18 | def adata() -> AnnData: 19 | X = np.random.default_rng().normal(0, 1, (NUM_CELLS, NUM_GENES)) 20 | 21 | obs = pd.DataFrame( 22 | { 23 | "DepMap_ID": ["ACH-000016", "ACH-000049", "ACH-001208", "ACH-000956"] * NUM_CELLS_PER_ID, 24 | "perturbation": ["Midostaurin"] * NUM_CELLS_PER_ID * 4, 25 | }, 26 | index=[str(i) for i in range(NUM_CELLS)], 27 | ) 28 | 29 | var_data = {"gene_name": [f"gene{i}" for i in range(1, NUM_GENES + 1)]} 30 | var = pd.DataFrame(var_data).set_index("gene_name", drop=False).rename_axis("index") 31 | 32 | X = sparse.csr_matrix(X) 33 | adata = anndata.AnnData(X=X, obs=obs, var=var) 34 | 35 | return adata 36 | 37 | 38 | def test_cell_line_annotation(adata): 39 | pt_metadata.annotate(adata=adata) 40 | assert len(adata.obs.columns) == len(pt_metadata.depmap.columns) + 1 # due to the perturbation column 41 | stripped_cell_line_name = ["SLR21", "HEKTE", "TK10", "22RV1"] * NUM_CELLS_PER_ID 42 | assert stripped_cell_line_name == list(adata.obs["StrippedCellLineName"]) 43 | 44 | 45 | def test_gdsc_annotation(adata): 46 | pt_metadata.annotate(adata) 47 | pt_metadata.annotate_from_gdsc(adata, query_id="StrippedCellLineName") 48 | assert "ln_ic50_gdsc" in adata.obs 49 | assert "auc_gdsc" in adata.obs 50 | 51 | 52 | def test_prism_annotation(adata): 53 | adata.obs = pd.DataFrame( 54 | { 55 | "DepMap_ID": ["ACH-000879", "ACH-000488", "ACH-000488", "ACH-000008"] * NUM_CELLS_PER_ID, 56 | "perturbation": ["cytarabine", "cytarabine", "secnidazole", "flutamide"] * NUM_CELLS_PER_ID, 57 | }, 58 | index=[str(i) for i in range(NUM_CELLS)], 59 | ) 60 | 61 | pt_metadata.annotate(adata) 62 | pt_metadata.annotate_from_prism(adata, query_id="DepMap_ID") 63 | assert "ic50_prism" in adata.obs 64 | assert "ec50_prism" in adata.obs 65 | assert "auc_prism" in adata.obs 66 | 67 | 68 | def test_protein_expression_annotation(adata): 69 | pt_metadata.annotate(adata) 70 | pt_metadata.annotate_protein_expression(adata, query_id="StrippedCellLineName") 71 | 72 | assert len(adata.obsm) == 1 73 | assert adata.obsm["proteomics_protein_intensity"].shape == ( 74 | NUM_GENES, 75 | len(pt_metadata.proteomics.uniprot_id.unique()), 76 | ) 77 | 78 | 79 | def test_bulk_rna_expression_annotation(adata): 80 | pt_metadata.annotate(adata) 81 | pt_metadata.annotate_bulk_rna(adata, query_id="DepMap_ID", cell_line_source="broad") 82 | 83 | assert len(adata.obsm) == 1 84 | assert adata.obsm["bulk_rna_broad"].shape == ( 85 | NUM_GENES, 86 | pt_metadata.bulk_rna_broad.shape[1], 87 | ) 88 | 89 | pt_metadata.annotate_bulk_rna(adata, query_id="StrippedCellLineName") 90 | 91 | assert len(adata.obsm) == 2 92 | assert adata.obsm["bulk_rna_sanger"].shape == ( 93 | NUM_GENES, 94 | 
pt_metadata.bulk_rna_sanger.shape[1], 95 | ) 96 | -------------------------------------------------------------------------------- /tests/metadata/test_compound.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import anndata 4 | import numpy as np 5 | import pandas as pd 6 | import pertpy as pt 7 | import pytest 8 | from anndata import AnnData 9 | from pubchempy import PubChemHTTPError 10 | from scipy import sparse 11 | 12 | NUM_CELLS = 100 13 | NUM_GENES = 100 14 | NUM_CELLS_PER_ID = NUM_CELLS // 4 15 | 16 | 17 | pt_compound = pt.md.Compound() 18 | 19 | 20 | @pytest.fixture 21 | def adata() -> AnnData: 22 | rng = np.random.default_rng(1) 23 | X = rng.standard_normal((NUM_CELLS, NUM_GENES)) 24 | X = np.where(X < 0, 0, X) 25 | 26 | obs = pd.DataFrame( 27 | { 28 | "DepMap_ID": ["ACH-000016", "ACH-000049", "ACH-001208", "ACH-000956"] * NUM_CELLS_PER_ID, 29 | "perturbation": ["AG-490", "Iniparib", "TAK-901", "Quercetin"] * NUM_CELLS_PER_ID, 30 | }, 31 | index=[str(i) for i in range(NUM_GENES)], 32 | ) 33 | 34 | var_data = {"gene_name": [f"gene{i}" for i in range(1, NUM_GENES + 1)]} 35 | var = pd.DataFrame(var_data).set_index("gene_name", drop=False).rename_axis("index") 36 | 37 | X = sparse.csr_matrix(X) 38 | adata = anndata.AnnData(X=X, obs=obs, var=var) 39 | 40 | return adata 41 | 42 | 43 | def test_compound_annotation(adata): 44 | retries = 3 45 | attempt = 0 46 | while attempt < retries: 47 | try: 48 | pt_compound.annotate_compounds(adata=adata, query_id="perturbation") 49 | assert len(adata.obs.columns) == 5 50 | pubchemid = [5328779, 9796068, 16124208, 5280343] * NUM_CELLS_PER_ID 51 | assert pubchemid == list(adata.obs["pubchem_ID"]) 52 | return 53 | except PubChemHTTPError: 54 | if attempt == retries - 1: 55 | pytest.fail("Max retries reached, PubChemHTTPError occurred") 56 | time.sleep(10) 57 | attempt += 1 58 | -------------------------------------------------------------------------------- /tests/metadata/test_drug.py: -------------------------------------------------------------------------------- 1 | import anndata 2 | import numpy as np 3 | import pandas as pd 4 | import pertpy as pt 5 | import pytest 6 | from anndata import AnnData 7 | 8 | pt_drug = pt.md.Drug() 9 | 10 | 11 | @pytest.fixture 12 | def adata() -> AnnData: 13 | rng = np.random.default_rng() 14 | 15 | gene_names = ["SLC6A2", "SSTR3", "COL1A1", "RPS24", "SSTR2"] 16 | adata = anndata.AnnData(X=rng.standard_normal(size=(5, 5)), var=pd.DataFrame(index=gene_names)) 17 | 18 | return adata 19 | 20 | 21 | def test_drug_chembl(adata): 22 | pt_drug.annotate(adata=adata) 23 | assert {"compounds"}.issubset(adata.var.columns) 24 | assert "CHEMBL1693" in adata.var["compounds"]["SLC6A2"] 25 | 26 | 27 | def test_drug_dgidb(adata): 28 | pt_drug.annotate(adata=adata, source="dgidb") 29 | assert {"compounds"}.issubset(adata.var.columns) 30 | assert "AMITIFADINE" in adata.var["compounds"]["SLC6A2"] 31 | 32 | 33 | def test_drug_pharmgkb(adata): 34 | pt_drug.annotate(adata=adata, source="pharmgkb") 35 | assert {"compounds"}.issubset(adata.var.columns) 36 | assert "3,4-methylenedioxymethamphetamine" in adata.var["compounds"]["SLC6A2"] 37 | -------------------------------------------------------------------------------- /tests/metadata/test_moa.py: -------------------------------------------------------------------------------- 1 | import anndata 2 | import numpy as np 3 | import pandas as pd 4 | import pertpy as pt 5 | import pytest 6 | from anndata import AnnData 7 | from scipy 
import sparse 8 | 9 | NUM_CELLS = 100 10 | NUM_GENES = 100 11 | NUM_CELLS_PER_ID = NUM_CELLS // 4 12 | 13 | 14 | pt_moa = pt.md.Moa() 15 | 16 | 17 | @pytest.fixture 18 | def adata() -> AnnData: 19 | rng = np.random.default_rng(1) 20 | X = rng.standard_normal((NUM_CELLS, NUM_GENES)) 21 | X = np.where(X < 0, 0, X) 22 | 23 | obs = pd.DataFrame( 24 | { 25 | "DepMap_ID": ["ACH-000016", "ACH-000049", "ACH-001208", "ACH-000956"] * NUM_CELLS_PER_ID, 26 | "perturbation": ["AG-490", "Iniparib", "TAK-901", "Quercetin"] * NUM_CELLS_PER_ID, 27 | }, 28 | index=[str(i) for i in range(NUM_GENES)], 29 | ) 30 | 31 | var_data = {"gene_name": [f"gene{i}" for i in range(1, NUM_GENES + 1)]} 32 | var = pd.DataFrame(var_data).set_index("gene_name", drop=False).rename_axis("index") 33 | 34 | X = sparse.csr_matrix(X) 35 | adata = anndata.AnnData(X=X, obs=obs, var=var) 36 | 37 | return adata 38 | 39 | 40 | def test_moa_annotation(adata): 41 | pt_moa.annotate(adata=adata, query_id="perturbation") 42 | assert len(adata.obs.columns) == len(pt_moa.clue.columns) + 1 # due to the DepMap_ID column 43 | assert {"moa", "target"}.issubset(adata.obs) 44 | moa = [ 45 | "EGFR inhibitor|JAK inhibitor", 46 | "PARP inhibitor", 47 | "Aurora kinase inhibitor", 48 | "polar auxin transport inhibitor", 49 | ] * NUM_CELLS_PER_ID 50 | assert moa == list(adata.obs["moa"]) 51 | -------------------------------------------------------------------------------- /tests/preprocessing/test_grna_assignment.py: -------------------------------------------------------------------------------- 1 | import anndata as ad 2 | import numpy as np 3 | import pandas as pd 4 | import pertpy as pt 5 | import pytest 6 | from scipy import sparse 7 | 8 | 9 | @pytest.fixture 10 | def adata_simple(): 11 | exp_matrix = np.array( 12 | [ 13 | [9, 0, 1, 0, 1, 0, 0], 14 | [1, 5, 1, 7, 0, 0, 0], 15 | [2, 0, 1, 0, 0, 8, 0], 16 | [1, 1, 1, 0, 1, 1, 1], 17 | [0, 0, 1, 0, 0, 5, 0], 18 | [9, 0, 1, 7, 0, 0, 0], 19 | [0, 0, 1, 0, 0, 0, 6], 20 | [8, 0, 1, 0, 0, 0, 0], 21 | ] 22 | ).astype(np.float32) 23 | adata = ad.AnnData( 24 | exp_matrix, 25 | obs=pd.DataFrame(index=[f"cell_{i + 1}" for i in range(exp_matrix.shape[0])]), 26 | var=pd.DataFrame(index=[f"guide_{i + 1}" for i in range(exp_matrix.shape[1])]), 27 | ) 28 | return adata 29 | 30 | 31 | @pytest.fixture 32 | def tiny_dense_adata(): 33 | exp_matrix = np.array([[6, 0, 2], [1, 5, 0], [0, 1, 7]]).astype(np.float32) 34 | return ad.AnnData( 35 | exp_matrix, 36 | obs=pd.DataFrame(index=[f"cell_{i + 1}" for i in range(exp_matrix.shape[0])]), 37 | var=pd.DataFrame(index=[f"guide_{i + 1}" for i in range(exp_matrix.shape[1])]), 38 | ) 39 | 40 | 41 | @pytest.fixture 42 | def tiny_sparse_adata(): 43 | sparse_matrix = sparse.csr_matrix(np.array([[6, 0, 2], [1, 5, 0], [0, 1, 7]]).astype(np.float32)) 44 | return ad.AnnData( 45 | sparse_matrix, 46 | obs=pd.DataFrame(index=[f"cell_{i + 1}" for i in range(sparse_matrix.shape[0])]), 47 | var=pd.DataFrame(index=[f"guide_{i + 1}" for i in range(sparse_matrix.shape[1])]), 48 | ) 49 | 50 | 51 | @pytest.mark.parametrize("adata_fixture", ["tiny_dense_adata", "tiny_sparse_adata"]) 52 | def test_grna_threshold_assignment(request, adata_fixture): 53 | adata = request.getfixturevalue(adata_fixture) 54 | threshold = 5 55 | output_layer = "assigned_guides" 56 | assert output_layer not in adata.layers 57 | 58 | ga = pt.pp.GuideAssignment() 59 | ga.assign_by_threshold(adata, assignment_threshold=threshold, output_layer=output_layer) 60 | 61 | assert output_layer in adata.layers 62 | 63 | # Convert to 
dense for comparison if needed 64 | if sparse.issparse(adata.layers[output_layer]): 65 | result_matrix = adata.layers[output_layer].toarray() 66 | else: 67 | result_matrix = adata.layers[output_layer] 68 | 69 | # Convert original data to dense for comparison if needed 70 | original_matrix = adata.X.toarray() if sparse.issparse(adata.X) else adata.X 71 | 72 | assert np.all(np.logical_xor(original_matrix < threshold, result_matrix == 1)) 73 | 74 | 75 | @pytest.mark.parametrize("adata_fixture", ["tiny_dense_adata", "tiny_sparse_adata"]) 76 | def test_grna_max_assignment(request, adata_fixture): 77 | adata = request.getfixturevalue(adata_fixture) 78 | threshold = 6 79 | obs_key = "assigned_guide" 80 | assert obs_key not in adata.obs 81 | 82 | ga = pt.pp.GuideAssignment() 83 | ga.assign_to_max_guide(adata, assignment_threshold=threshold, obs_key=obs_key) 84 | assert obs_key in adata.obs 85 | assert tuple(adata.obs[obs_key]) == ("guide_1", "Negative", "guide_3") 86 | 87 | 88 | def test_grna_mixture_model(adata_simple): 89 | output_key = "assigned_guide" 90 | assert output_key not in adata_simple.obs 91 | 92 | ga = pt.pp.GuideAssignment() 93 | ga.assign_mixture_model(adata_simple) 94 | assert output_key in adata_simple.obs 95 | target = [f"guide_{i}" if i > 0 else "negative" for i in [1, 4, 6, 0, 6, 1, 7, 1, 0]] 96 | assert all(t in g for t, g in zip(target, adata_simple.obs[output_key], strict=False)) 97 | -------------------------------------------------------------------------------- /tests/tools/_coda/test_sccoda.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import arviz as az 4 | import numpy as np 5 | import pandas as pd 6 | import pertpy as pt 7 | import pytest 8 | import scanpy as sc 9 | from mudata import MuData 10 | 11 | CWD = Path(__file__).parent.resolve() 12 | 13 | 14 | sccoda = pt.tl.Sccoda() 15 | 16 | 17 | @pytest.fixture 18 | def adata(): 19 | cells = pt.dt.haber_2017_regions() 20 | cells = sc.pp.subsample(cells, 0.1, copy=True) 21 | 22 | return cells 23 | 24 | 25 | def test_load(adata): 26 | mdata = sccoda.load( 27 | adata, 28 | type="cell_level", 29 | generate_sample_level=True, 30 | cell_type_identifier="cell_label", 31 | sample_identifier="batch", 32 | covariate_obs=["condition"], 33 | ) 34 | assert isinstance(mdata, MuData) 35 | assert "rna" in mdata.mod 36 | assert "coda" in mdata.mod 37 | 38 | 39 | def test_prepare(adata): 40 | mdata = sccoda.load( 41 | adata, 42 | type="cell_level", 43 | generate_sample_level=True, 44 | cell_type_identifier="cell_label", 45 | sample_identifier="batch", 46 | covariate_obs=["condition"], 47 | ) 48 | mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine") 49 | assert "scCODA_params" in mdata["coda"].uns 50 | assert "covariate_matrix" in mdata["coda"].obsm 51 | assert "sample_counts" in mdata["coda"].obsm 52 | assert isinstance(mdata["coda"].obsm["sample_counts"], np.ndarray) 53 | assert np.sum(mdata["coda"].obsm["covariate_matrix"]) == 6 54 | 55 | 56 | def test_run_nuts(adata): 57 | mdata = sccoda.load( 58 | adata, 59 | type="cell_level", 60 | generate_sample_level=True, 61 | cell_type_identifier="cell_label", 62 | sample_identifier="batch", 63 | covariate_obs=["condition"], 64 | ) 65 | mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine") 66 | sccoda.run_nuts(mdata, num_samples=1000, num_warmup=100) 67 | assert "effect_df_condition[T.Hpoly.Day10]" in mdata["coda"].varm 68 | assert 
"effect_df_condition[T.Hpoly.Day3]" in mdata["coda"].varm 69 | assert "effect_df_condition[T.Salmonella]" in mdata["coda"].varm 70 | assert "intercept_df" in mdata["coda"].varm 71 | assert mdata["coda"].varm["effect_df_condition[T.Hpoly.Day10]"].shape == (8, 7) 72 | assert mdata["coda"].varm["effect_df_condition[T.Hpoly.Day3]"].shape == (8, 7) 73 | assert mdata["coda"].varm["effect_df_condition[T.Salmonella]"].shape == (8, 7) 74 | assert mdata["coda"].varm["intercept_df"].shape == (8, 5) 75 | 76 | 77 | def test_credible_effects(adata): 78 | adata_salm = adata[adata.obs["condition"].isin(["Control", "Salmonella"])] 79 | mdata = sccoda.load( 80 | adata_salm, 81 | type="cell_level", 82 | generate_sample_level=True, 83 | cell_type_identifier="cell_label", 84 | sample_identifier="batch", 85 | covariate_obs=["condition"], 86 | ) 87 | mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Goblet") 88 | sccoda.run_nuts(mdata) 89 | assert isinstance(sccoda.credible_effects(mdata), pd.Series) 90 | assert sccoda.credible_effects(mdata)["condition[T.Salmonella]"]["Enterocyte"] 91 | 92 | 93 | def test_make_arviz(adata): 94 | adata_salm = adata[adata.obs["condition"].isin(["Control", "Salmonella"])] 95 | mdata = sccoda.load( 96 | adata_salm, 97 | type="cell_level", 98 | generate_sample_level=True, 99 | cell_type_identifier="cell_label", 100 | sample_identifier="batch", 101 | covariate_obs=["condition"], 102 | ) 103 | mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Goblet") 104 | sccoda.run_nuts(mdata) 105 | arviz_data = sccoda.make_arviz(mdata, num_prior_samples=100) 106 | assert isinstance(arviz_data, az.InferenceData) 107 | -------------------------------------------------------------------------------- /tests/tools/_coda/test_tasccoda.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pytest 5 | import scanpy as sc 6 | from mudata import MuData 7 | 8 | try: 9 | import ete4 10 | except ImportError: 11 | pytest.skip("ete4 not available", allow_module_level=True) 12 | 13 | import pertpy as pt 14 | 15 | CWD = Path(__file__).parent.resolve() 16 | 17 | 18 | tasccoda = pt.tl.Tasccoda() 19 | 20 | 21 | @pytest.fixture 22 | def smillie_adata(): 23 | smillie_adata = pt.dt.tasccoda_example() 24 | smillie_adata = sc.pp.subsample(smillie_adata, 0.1, copy=True) 25 | 26 | return smillie_adata 27 | 28 | 29 | def test_load(smillie_adata): 30 | mdata = tasccoda.load( 31 | smillie_adata, 32 | type="sample_level", 33 | levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"], 34 | key_added="lineage", 35 | add_level_name=True, 36 | ) 37 | assert isinstance(mdata, MuData) 38 | assert "rna" in mdata.mod 39 | assert "coda" in mdata.mod 40 | assert "lineage" in mdata["coda"].uns 41 | 42 | 43 | def test_prepare(smillie_adata): 44 | mdata = tasccoda.load( 45 | smillie_adata, 46 | type="sample_level", 47 | levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"], 48 | key_added="lineage", 49 | add_level_name=True, 50 | ) 51 | mdata = tasccoda.prepare( 52 | mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0} 53 | ) 54 | assert "scCODA_params" in mdata["coda"].uns 55 | assert "covariate_matrix" in mdata["coda"].obsm 56 | assert "sample_counts" in mdata["coda"].obsm 57 | assert isinstance(mdata["coda"].obsm["sample_counts"], np.ndarray) 58 | assert np.sum(mdata["coda"].obsm["covariate_matrix"]) == 8 59 | 60 | 61 | def 
test_run_nuts(smillie_adata): 62 | mdata = tasccoda.load( 63 | smillie_adata, 64 | type="sample_level", 65 | levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"], 66 | key_added="lineage", 67 | add_level_name=True, 68 | ) 69 | mdata = tasccoda.prepare( 70 | mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0} 71 | ) 72 | tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100) 73 | assert "effect_df_Health[T.Inflamed]" in mdata["coda"].varm 74 | assert "effect_df_Health[T.Non-inflamed]" in mdata["coda"].varm 75 | assert mdata["coda"].varm["effect_df_Health[T.Inflamed]"].shape == (51, 7) 76 | assert mdata["coda"].varm["effect_df_Health[T.Non-inflamed]"].shape == (51, 7) 77 | -------------------------------------------------------------------------------- /tests/tools/_differential_gene_expression/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/tests/tools/_differential_gene_expression/__init__.py -------------------------------------------------------------------------------- /tests/tools/_differential_gene_expression/conftest.py: -------------------------------------------------------------------------------- 1 | import anndata as ad 2 | import numpy as np 3 | import pandas as pd 4 | import pytest 5 | import scipy.sparse as sp 6 | from pydeseq2.utils import load_example_data 7 | 8 | 9 | @pytest.fixture 10 | def test_counts(): 11 | return load_example_data( 12 | modality="raw_counts", 13 | dataset="synthetic", 14 | debug=False, 15 | ) 16 | 17 | 18 | @pytest.fixture 19 | def test_metadata(): 20 | return load_example_data( 21 | modality="metadata", 22 | dataset="synthetic", 23 | debug=False, 24 | ) 25 | 26 | 27 | @pytest.fixture 28 | def test_adata(test_counts, test_metadata): 29 | return ad.AnnData(X=test_counts, obs=test_metadata) 30 | 31 | 32 | @pytest.fixture(params=[np.array, sp.csr_matrix, sp.csc_matrix]) 33 | def test_adata_minimal(request): 34 | matrix_format = request.param 35 | n_obs = 80 36 | n_donors = n_obs // 4 37 | rng = np.random.default_rng(9) # make tests deterministic 38 | obs = pd.DataFrame( 39 | { 40 | "condition": ["A", "B"] * (n_obs // 2), 41 | "donor": sum(([f"D{i}"] * n_donors for i in range(n_obs // n_donors)), []), 42 | "other": (["X"] * (n_obs // 4)) + (["Y"] * ((3 * n_obs) // 4)), 43 | "pairing": sum(([str(i), str(i)] for i in range(n_obs // 2)), []), 44 | "continuous": [rng.uniform(0, 1) * 4000 for _ in range(n_obs)], 45 | }, 46 | ) 47 | var = pd.DataFrame(index=["gene1", "gene2"]) 48 | group1 = rng.negative_binomial(20, 0.1, n_obs // 2) # large mean 49 | group2 = rng.negative_binomial(5, 0.5, n_obs // 2) # small mean 50 | 51 | condition_data = np.empty((n_obs,), dtype=group1.dtype) 52 | condition_data[0::2] = group1 53 | condition_data[1::2] = group2 54 | 55 | donor_data = np.empty((n_obs,), dtype=group1.dtype) 56 | donor_data[0:n_donors] = group2[:n_donors] 57 | donor_data[n_donors : (2 * n_donors)] = group1[n_donors:] 58 | 59 | donor_data[(2 * n_donors) : (3 * n_donors)] = group2[:n_donors] 60 | donor_data[(3 * n_donors) :] = group1[n_donors:] 61 | 62 | X = matrix_format(np.vstack([condition_data, donor_data]).T) 63 | 64 | return ad.AnnData(X=X, obs=obs, var=var) 65 | -------------------------------------------------------------------------------- /tests/tools/_differential_gene_expression/test_base.py: 
--------------------------------------------------------------------------------
/tests/tools/_differential_gene_expression/test_base.py:
--------------------------------------------------------------------------------
from collections.abc import Sequence

import pytest
from pandas.core.api import DataFrame
from pertpy.tools._differential_gene_expression import LinearModelBase


@pytest.fixture
def MockLinearModel():
    class _MockLinearModel(LinearModelBase):
        def _check_counts(self) -> None:
            pass

        def fit(self, **kwargs) -> None:
            pass

        def _test_single_contrast(self, contrast: Sequence[float], **kwargs) -> DataFrame:
            pass

    return _MockLinearModel
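
# --- Illustrative sketch (not part of the original tests) --------------------
# Background for the parametrized cases below: `cond(**kwargs)` resolves to a
# row of the design matrix that formulaic builds from the formula. A rough,
# standalone look at the default treatment coding (formulaic's `model_matrix`
# is assumed here; pertpy uses formulaic for its design matrices):
def _show_treatment_coding():
    import pandas as pd
    from formulaic import model_matrix

    df = pd.DataFrame({"condition": ["A", "B", "A", "B"]})
    # Columns come out as ["Intercept", "condition[T.B]"], so the expected
    # contrast [1, 1] for condition="B" below is exactly such a design row.
    return model_matrix("~ condition", df)
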
@pytest.mark.parametrize(
    "formula,cond_kwargs,expected_contrast",
    [
        # single variable
        ["~ condition", {}, [1, 0]],
        ["~ condition", {"condition": "A"}, [1, 0]],
        ["~ condition", {"condition": "B"}, [1, 1]],
        ["~ condition", {"condition": "42"}, ValueError],  # non-existent category
        # no-intercept models
        ["~ 0 + condition", {"condition": "A"}, [1, 0]],
        ["~ 0 + condition", {"condition": "B"}, [0, 1]],
        # Different ways of specifying dummy coding
        ["~ donor", {"donor": "D0"}, [1, 0, 0, 0]],
        ["~ C(donor)", {"donor": "D0"}, [1, 0, 0, 0]],
        ["~ C(donor, contr.treatment(base='D2'))", {"donor": "D2"}, [1, 0, 0, 0]],
        ["~ C(donor, contr.treatment(base='D2'))", {"donor": "D0"}, [1, 1, 0, 0]],
        # Handle continuous covariates
        ["~ donor + continuous", {"donor": "D1"}, [1, 1, 0, 0, 0]],
        ["~ donor + np.log1p(continuous)", {"donor": "D1"}, [1, 1, 0, 0, 0]],
        ["~ donor + continuous + np.log1p(continuous)", {"donor": "D0"}, [1, 0, 0, 0, 0, 0]],
        # Nonsense models repeating the same variable, which are nonetheless allowed by formulaic
        ["~ donor + C(donor)", {"donor": "D1"}, [1, 1, 0, 0, 1, 0, 0]],
        ["~ donor + C(donor, contr.treatment(base='D2'))", {"donor": "D0"}, [1, 0, 0, 0, 1, 0, 0]],
        [
            "~ condition + donor + C(donor, contr.treatment(base='D2'))",
            {"condition": "A"},
            ValueError,
        ],  # the donor base category can't be resolved because it's ambiguous -> ValueError
        # Sum-to-zero coding
        ["~ C(donor, contr.sum)", {"donor": "D0"}, [1, 1, 0, 0]],
        ["~ C(donor, contr.sum)", {"donor": "D3"}, [1, -1, -1, -1]],
        # Multiple categorical variables
        ["~ condition + donor", {"condition": "A"}, [1, 0, 0, 0, 0]],
        ["~ condition + donor", {"donor": "D2"}, [1, 0, 0, 1, 0]],
        ["~ condition + donor", {"condition": "B", "donor": "D2"}, [1, 1, 0, 1, 0]],
        ["~ 0 + condition + donor", {"donor": "D1"}, [0, 0, 1, 0, 0]],
        # Interaction terms
        ["~ condition * donor", {"condition": "A"}, [1, 0, 0, 0, 0, 0, 0, 0]],
        ["~ condition + donor + condition:donor", {"condition": "A"}, [1, 0, 0, 0, 0, 0, 0, 0]],
        ["~ condition * donor", {"condition": "B", "donor": "D2"}, [1, 1, 0, 1, 0, 0, 1, 0]],
        ["~ condition * C(donor, contr.treatment(base='D2'))", {"condition": "A"}, [1, 0, 0, 0, 0, 0, 0, 0]],
        [
            "~ condition * C(donor, contr.treatment(base='D2'))",
            {"condition": "B", "donor": "D0"},
            [1, 1, 1, 0, 0, 1, 0, 0],
        ],
        [
            "~ condition:donor",
            {"condition": "A"},
            ValueError,
        ],  # Can't automatically resolve the base category, because formulaic builds a reduced-rank and a full-rank factor internally
        ["~ condition:donor", {"condition": "A", "donor": "D1"}, [1, 1, 0, 0, 0, 0, 0, 0]],
        ["~ condition:C(donor)", {"condition": "A", "donor": "D1"}, [1, 1, 0, 0, 0, 0, 0, 0]],
    ],
)
def test_model_cond(test_adata_minimal, MockLinearModel, formula, cond_kwargs, expected_contrast):
    mod = MockLinearModel(test_adata_minimal, formula)
    if isinstance(expected_contrast, type):
        with pytest.raises(expected_contrast):
            mod.cond(**cond_kwargs)
    else:
        actual_contrast = mod.cond(**cond_kwargs)
        assert actual_contrast.tolist() == expected_contrast
        assert actual_contrast.index.tolist() == mod.design.columns.tolist()
--------------------------------------------------------------------------------
/tests/tools/_differential_gene_expression/test_compare_groups.py:
--------------------------------------------------------------------------------
import numpy as np
import pytest
from pertpy.tools._differential_gene_expression import AVAILABLE_METHODS


@pytest.mark.parametrize("method", AVAILABLE_METHODS)
@pytest.mark.parametrize("paired_by", ["pairing", None])
def test_unified_api_single_group(test_adata_minimal, method, paired_by):
    """
    Test that all methods implement the unified API.

    Here, we don't check the correctness of the results
    (we have the method-specific tests for that), but rather that the interface works
    as expected and that the format of the resulting data frame is what we expect.

    TODO: tests for layers
    TODO: tests for mask
    """
    res_df = method.compare_groups(
        adata=test_adata_minimal, column="condition", baseline="A", groups_to_compare="B", paired_by=paired_by
    )
    assert res_df.shape[0] == test_adata_minimal.shape[1], "The result dataframe must contain a value for each var name"
    assert {"variable", "p_value", "log_fc", "adj_p_value"} - set(res_df.columns) == set(), (
        "Mandated column names not in result df"
    )
    assert np.all((res_df["p_value"] >= 0) & (res_df["p_value"] <= 1))
    assert np.all((res_df["adj_p_value"] >= 0) & (res_df["adj_p_value"] <= 1))
    assert np.all(res_df["adj_p_value"] >= res_df["p_value"])


@pytest.mark.parametrize("method", AVAILABLE_METHODS)
def test_unified_api_multiple_groups(test_adata_minimal, method):
    """
    Test that all methods implement the unified API.

    Here, we don't check the correctness of the results
    (we have the method-specific tests for that), but rather that the interface works
    as expected and that the format of the resulting data frame is what we expect.

    TODO: tests for layers
    TODO: tests for mask
    """
    res_df = method.compare_groups(
        adata=test_adata_minimal,
        column="donor",
        baseline="D0",
        groups_to_compare=["D1", "D2", "D3"],
        paired_by=None,  # No pairing possible here.
    )
    assert res_df.shape[0] == 3 * test_adata_minimal.shape[1], (
        "The result dataframe must contain a value for each var name in each compared group"
    )
    assert {"variable", "p_value", "log_fc", "adj_p_value"} - set(res_df.columns) == set(), (
        "Mandated column names not in result df"
    )
    assert np.all((res_df["p_value"] >= 0) & (res_df["p_value"] <= 1))
    assert np.all((res_df["adj_p_value"] >= 0) & (res_df["adj_p_value"] <= 1))
    assert np.all(res_df["adj_p_value"] >= res_df["p_value"])
--------------------------------------------------------------------------------
/tests/tools/_differential_gene_expression/test_dge.py:
--------------------------------------------------------------------------------
import numpy as np
import pandas as pd
import pertpy as pt
import pytest
from anndata import AnnData

pytest.skip("Disabled", allow_module_level=True)


@pytest.fixture
def adata(rng):
    adata = AnnData(rng.normal(size=(100, 10)))
    genes = np.rec.fromarrays(
        [np.array([f"gene{i}" for i in range(10)])],
        names=["group1", "O"],
    )
    adata.uns["de_key1"] = {
        "names": genes,
        "scores": {"group1": rng.random(10)},
        "pvals_adj": {"group1": rng.random(10)},
    }
    adata.uns["de_key2"] = {
        "names": genes,
        "scores": {"group1": rng.random(10)},
        "pvals_adj": {"group1": rng.random(10)},
    }
    return adata


@pytest.fixture
def dataframe(rng):
    df1 = pd.DataFrame(
        {
            "variable": ["gene" + str(i) for i in range(10)],
            "log_fc": rng.random(10),
            "adj_p_value": rng.random(10),
        }
    )
    df2 = pd.DataFrame(
        {
            "variable": ["gene" + str(i) for i in range(10)],
            "log_fc": rng.random(10),
            "adj_p_value": rng.random(10),
        }
    )
    return df1, df2


def test_error_both_keys_and_dfs(adata, dataframe):
    with pytest.raises(ValueError):
        pt_DGE = pt.tl.DGEEVAL()
        pt_DGE.compare(adata=adata, de_key1="de_key1", de_df1=dataframe[0])


def test_error_missing_adata():
    with pytest.raises(ValueError):
        pt_DGE = pt.tl.DGEEVAL()
        pt_DGE.compare(de_key1="de_key1", de_key2="de_key2")


def test_error_missing_df(dataframe):
    with pytest.raises(ValueError):
        pt_DGE = pt.tl.DGEEVAL()
        pt_DGE.compare(de_df1=dataframe[0])


def test_key(adata):
    pt_DGE = pt.tl.DGEEVAL()
    results = pt_DGE.compare(adata=adata, de_key1="de_key1", de_key2="de_key2", shared_top=5)
    assert "shared_top_genes" in results
    assert "scores_corr" in results
    assert "pvals_adj_corr" in results
    assert "scores_ranks_corr" in results


def test_df(dataframe):
    pt_DGE = pt.tl.DGEEVAL()
    results = pt_DGE.compare(de_df1=dataframe[0], de_df2=dataframe[1], shared_top=5)
    assert "shared_top_genes" in results
    assert "scores_corr" in results
    assert "pvals_adj_corr" in results
    assert "scores_ranks_corr" in results
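
# --- Illustrative sketch (not part of the original, currently skipped tests) --
# The correlations that `DGEEVAL.compare` reports above boil down to agreement
# between two DE result tables. A generic, pertpy-independent version of that
# idea (scipy's spearmanr is assumed here purely for illustration):
def _rank_agreement(de_df1, de_df2):
    from scipy.stats import spearmanr

    merged = de_df1.merge(de_df2, on="variable", suffixes=("_1", "_2"))
    corr, _ = spearmanr(merged["log_fc_1"], merged["log_fc_2"])
    return corr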
--------------------------------------------------------------------------------
/tests/tools/_differential_gene_expression/test_edger.py:
--------------------------------------------------------------------------------
import numpy.testing as npt
import pytest
from pertpy.tools._differential_gene_expression import EdgeR, PyDESeq2

try:
    from rpy2.robjects.packages import importr

    r_dependency = importr("edgeR")
except Exception:  # noqa: BLE001
    r_dependency = None

pytestmark = pytest.mark.skipif(r_dependency is None, reason="Required R package 'edgeR' not available")


def test_edger_simple(test_adata):
    """Check that the EdgeR method can be

    1. Initialized
    2. Fitted
    3. and that test_contrasts returns a DataFrame with the correct number of rows.
    """
    method = EdgeR(adata=test_adata, design="~condition")
    method.fit()
    res_df = method.test_contrasts(method.contrast("condition", "A", "B"))

    assert len(res_df) == test_adata.n_vars
    # Compare against snapshot
    npt.assert_almost_equal(
        res_df.p_value.values,
        [
            8.0000e-05,
            1.8000e-04,
            5.3000e-04,
            1.1800e-03,
            3.3800e-02,
            3.3820e-02,
            7.7980e-02,
            1.3715e-01,
            2.5052e-01,
            9.2485e-01,
        ],
        decimal=4,
    )
    npt.assert_almost_equal(
        res_df.log_fc.values,
        [0.61208, -0.39374, 0.57944, 0.7343, -0.58675, 0.42575, -0.23951, -0.20761, 0.17489, 0.0247],
        decimal=4,
    )


def test_edger_complex(test_adata):
    """Check that the EdgeR method can be initialized with a different covariate name and fitted,
    and that the test_contrasts method returns a dataframe with the correct number of rows.
    """
    test_adata.obs["condition1"] = test_adata.obs["condition"].copy()
    method = EdgeR(adata=test_adata, design="~condition1+group")
    method.fit()
    res_df = method.test_contrasts(method.contrast("condition1", "A", "B"))

    assert len(res_df) == test_adata.n_vars
    # Check that the index of the result matches the var_names of the AnnData object
    assert set(test_adata.var_names) == set(res_df["variable"])

    # Compare the ranking of genes from a different method (without design matrix handling)
    down_gene = res_df.set_index("variable").loc["gene3", "log_fc"]
    up_gene = res_df.set_index("variable").loc["gene1", "log_fc"]
    assert down_gene < up_gene

    method = PyDESeq2(adata=test_adata, design="~condition1+group")
    method.fit()
    deseq_res_df = method.test_contrasts(method.contrast("condition1", "A", "B"))
    assert all(res_df.sort_values("log_fc")["variable"].values == deseq_res_df.sort_values("log_fc")["variable"].values)
"foo") 16 | [sp.csr_matrix, np.nan], 17 | [sp.csr_matrix, np.nan], 18 | [sp.csc_matrix, np.inf], 19 | [sp.csc_matrix, np.inf], 20 | ], 21 | ) 22 | def test_invalid_inputs(matrix_type, invalid_input, test_counts, test_metadata): 23 | """Check that invalid inputs in MethodBase counts raise an error.""" 24 | test_counts[0, 0] = invalid_input 25 | adata = ad.AnnData(X=matrix_type(test_counts), obs=test_metadata) 26 | with pytest.raises((ValueError, TypeError)): 27 | Statsmodels(adata=adata, design="~condition") 28 | 29 | 30 | @pytest.mark.parametrize("matrix_type", [np.array, sp.csr_matrix, sp.csc_matrix]) 31 | @pytest.mark.parametrize( 32 | "input,expected", 33 | [ 34 | pytest.param([[1, 2], [3, 4]], None, id="valid"), 35 | pytest.param([[1, -2], [3, 4]], ValueError, id="negative values"), 36 | pytest.param([[1, 2.5], [3, 4]], ValueError, id="non-integer"), 37 | pytest.param([[1, np.nan], [3, 4]], ValueError, id="nans"), 38 | ], 39 | ) 40 | def test_check_is_integer_matrix(matrix_type, input, expected: type): 41 | """Test for valid integer matrix.""" 42 | matrix = matrix_type(input, dtype=float) 43 | 44 | if expected is None: 45 | check_is_integer_matrix(matrix) 46 | else: 47 | with pytest.raises(expected): 48 | check_is_integer_matrix(matrix) 49 | 50 | 51 | @pytest.mark.parametrize("matrix_type", [np.array, sp.csr_matrix, sp.csc_matrix]) 52 | @pytest.mark.parametrize( 53 | "input,expected", 54 | [ 55 | pytest.param([[1, 2], [3, 4]], None, id="valid"), 56 | pytest.param([[1, -2], [3, 4]], None, id="negative values"), 57 | pytest.param([[1, 2.5], [3, 4]], None, id="non-integer"), 58 | pytest.param([[1, np.nan], [3, 4]], ValueError, id="nans"), 59 | ], 60 | ) 61 | def test_check_is_numeric_matrix(matrix_type, input, expected: type): 62 | """Test for valid numeric matrix. 63 | 64 | This is like the integer matrix check above, except that also negative 65 | and float values are allowed. 66 | """ 67 | matrix = matrix_type(input, dtype=float) 68 | 69 | if expected is None: 70 | check_is_numeric_matrix(matrix) 71 | else: 72 | with pytest.raises(expected): 73 | check_is_numeric_matrix(matrix) 74 | -------------------------------------------------------------------------------- /tests/tools/_differential_gene_expression/test_pydeseq2.py: -------------------------------------------------------------------------------- 1 | from importlib.util import find_spec 2 | 3 | import numpy.testing as npt 4 | import pytest 5 | from pertpy.tools._differential_gene_expression import PyDESeq2 6 | 7 | if find_spec("pydeseq2") is None: 8 | pytestmark = pytest.mark.skip(reason="pydeseq2 not available") 9 | 10 | 11 | def test_pydeseq2_simple(test_adata): 12 | """Check that the pyDESeq2 method can be 13 | 14 | 1. Initialized 15 | 2. Fitted 16 | 3. and that test_contrast returns a DataFrame with the correct number of rows. 
17 | """ 18 | method = PyDESeq2(adata=test_adata, design="~condition") 19 | method.fit() 20 | res_df = method.test_contrasts(method.contrast("condition", "A", "B")) 21 | 22 | assert len(res_df) == test_adata.n_vars 23 | # Compare against snapshot 24 | npt.assert_almost_equal( 25 | res_df.p_value.values, 26 | [0.00017, 0.00033, 0.00051, 0.0286, 0.03207, 0.04723, 0.11039, 0.11452, 0.3703, 0.99625], 27 | decimal=4, 28 | ) 29 | npt.assert_almost_equal( 30 | res_df.log_fc.values, 31 | [0.58207, 0.53855, -0.4121, 0.63281, -0.63283, -0.27066, -0.21271, 0.38601, 0.13434, 0.00146], 32 | decimal=4, 33 | ) 34 | 35 | 36 | def test_pydeseq2_complex(test_adata): 37 | """Check that the pyDESeq2 method can be initialized with a different covariate name and fitted and that the test_contrast 38 | method returns a dataframe with the correct number of rows. 39 | """ 40 | test_adata.obs["condition1"] = test_adata.obs["condition"].copy() 41 | method = PyDESeq2(adata=test_adata, design="~condition1+group") 42 | method.fit() 43 | res_df = method.test_contrasts(method.contrast("condition1", "A", "B")) 44 | 45 | assert len(res_df) == test_adata.n_vars 46 | # Check that the index of the result matches the var_names of the AnnData object 47 | assert set(test_adata.var_names) == set(res_df["variable"]) 48 | # Compare against snapshot 49 | npt.assert_almost_equal( 50 | res_df.p_value.values, 51 | [7e-05, 0.00012, 0.00035, 0.01062, 0.01906, 0.03892, 0.10755, 0.11175, 0.36631, 0.94952], 52 | decimal=4, 53 | ) 54 | npt.assert_almost_equal( 55 | res_df.log_fc.values, 56 | [-0.42347, 0.58802, 0.53528, 0.73147, -0.67374, -0.27158, -0.21402, 0.38953, 0.13511, -0.01949], 57 | decimal=4, 58 | ) 59 | 60 | 61 | def test_pydeseq2_formula(test_adata): 62 | """Check that the pyDESeq2 method gives consistent results when specifying contrasts, regardless of the order of covariates""" 63 | model1 = PyDESeq2(adata=test_adata, design="~condition+group") 64 | model1.fit() 65 | res_1 = model1.test_contrasts(model1.contrast("condition", "A", "B")) 66 | 67 | model2 = PyDESeq2(adata=test_adata, design="~group+condition") 68 | model2.fit() 69 | res_2 = model2.test_contrasts(model2.contrast("condition", "A", "B")) 70 | 71 | npt.assert_almost_equal(res_2.log_fc.values, res_1.log_fc.values) 72 | -------------------------------------------------------------------------------- /tests/tools/_differential_gene_expression/test_simple_tests.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | from pandas.core.api import DataFrame 5 | from pertpy.tools._differential_gene_expression import SimpleComparisonBase, TTest, WilcoxonTest 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "paired_by,expected", 10 | [ 11 | pytest.param( 12 | None, 13 | {"gene1": {"p_value": 1.34e-14, "log_fc": -5.14}, "gene2": {"p_value": 0.54, "log_fc": -0.016}}, 14 | id="unpaired", 15 | ), 16 | pytest.param( 17 | "pairing", 18 | {"gene1": {"p_value": 3.70e-8, "log_fc": -5.14}, "gene2": {"p_value": 0.67, "log_fc": -0.016}}, 19 | id="paired", 20 | ), 21 | ], 22 | ) 23 | def test_wilcoxon(test_adata_minimal, paired_by, expected): 24 | """Test that wilcoxon test gives the correct values. 
--------------------------------------------------------------------------------
/tests/tools/_differential_gene_expression/test_simple_tests.py:
--------------------------------------------------------------------------------
import numpy as np
import pandas as pd
import pytest
from pandas.core.api import DataFrame
from pertpy.tools._differential_gene_expression import SimpleComparisonBase, TTest, WilcoxonTest


@pytest.mark.parametrize(
    "paired_by,expected",
    [
        pytest.param(
            None,
            {"gene1": {"p_value": 1.34e-14, "log_fc": -5.14}, "gene2": {"p_value": 0.54, "log_fc": -0.016}},
            id="unpaired",
        ),
        pytest.param(
            "pairing",
            {"gene1": {"p_value": 3.70e-8, "log_fc": -5.14}, "gene2": {"p_value": 0.67, "log_fc": -0.016}},
            id="paired",
        ),
    ],
)
def test_wilcoxon(test_adata_minimal, paired_by, expected):
    """Test that the Wilcoxon test gives the correct values.

    Reference values have been computed in R using wilcox.test.
    """
    res_df = WilcoxonTest.compare_groups(
        adata=test_adata_minimal, column="condition", baseline="A", groups_to_compare="B", paired_by=paired_by
    )
    actual = res_df.loc[:, ["variable", "p_value", "log_fc"]].set_index("variable").to_dict(orient="index")
    for gene in expected:
        assert actual[gene] == pytest.approx(expected[gene], abs=0.02)


@pytest.mark.parametrize(
    "paired_by,expected",
    [
        pytest.param(
            None,
            {"gene1": {"p_value": 2.13e-26, "log_fc": -5.14}, "gene2": {"p_value": 0.96, "log_fc": -0.016}},
            id="unpaired",
        ),
        pytest.param(
            "pairing",
            {"gene1": {"p_value": 1.63e-26, "log_fc": -5.14}, "gene2": {"p_value": 0.85, "log_fc": -0.016}},
            id="paired",
        ),
    ],
)
def test_t(test_adata_minimal, paired_by, expected):
    """Test that the t-test gives the correct values.

    Reference values have been computed in R using t.test.
    """
    res_df = TTest.compare_groups(
        adata=test_adata_minimal, column="condition", baseline="A", groups_to_compare="B", paired_by=paired_by
    )
    actual = res_df.loc[:, ["variable", "p_value", "log_fc"]].set_index("variable").to_dict(orient="index")
    for gene in expected:
        assert actual[gene] == pytest.approx(expected[gene], abs=0.02)


@pytest.mark.parametrize("seed", range(10))
def test_simple_comparison_pairing(test_adata_minimal, seed):
    """Test that paired samples are properly matched in a paired test."""

    class MockSimpleComparison(SimpleComparisonBase):
        @staticmethod
        def _test():
            return None

        def _compare_single_group(
            self, baseline_idx: np.ndarray, comparison_idx: np.ndarray, *, paired: bool = False, **kwargs
        ) -> DataFrame:
            assert paired
            x0 = self.adata[baseline_idx, :]
            x1 = self.adata[comparison_idx, :]
            assert np.all(x0.obs["condition"] == "A")
            assert np.all(x1.obs["condition"] == "B")
            assert np.all(x0.obs["pairing"].values == x1.obs["pairing"].values)
            return pd.DataFrame([{"p_value": 1}])

    rng = np.random.default_rng(seed)
    shuffle_adata_idx = rng.permutation(test_adata_minimal.obs_names)
    tmp_adata = test_adata_minimal[shuffle_adata_idx, :].copy()

    MockSimpleComparison.compare_groups(
        tmp_adata, column="condition", baseline="A", groups_to_compare=["B"], paired_by="pairing"
    )


@pytest.mark.parametrize(
    "params",
    [
        pytest.param(
            {"column": "donor", "baseline": "D0", "paired_by": "pairing", "groups_to_compare": "D1"},
            id="pairing not subgroup of donor",
        ),
        pytest.param(
            {"column": "donor", "baseline": "D0", "paired_by": "condition", "groups_to_compare": "D1"},
            id="more than two per group (donor)",
        ),
        pytest.param(
            {"column": "condition", "baseline": "A", "paired_by": "donor", "groups_to_compare": "B"},
            id="more than two per group (condition)",
        ),
    ],
)
def test_invalid_pairing(test_adata_minimal, params):
    """Test that the SimpleComparisonBase class raises an error when a paired analysis is requested with an invalid configuration."""
    with pytest.raises(ValueError):
        TTest.compare_groups(test_adata_minimal, **params)
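
# --- Illustrative sketch (not part of the original tests) --------------------
# What `paired_by` changes under the hood, in plain scipy terms: an unpaired
# comparison treats the two groups as independent samples, while a paired one
# tests the per-pair differences. (scipy is assumed here only for illustration;
# pertpy's implementation may differ in details.)
def _paired_vs_unpaired(x_baseline: np.ndarray, x_comparison: np.ndarray):
    from scipy import stats

    unpaired = stats.ttest_ind(x_baseline, x_comparison)
    paired = stats.ttest_rel(x_baseline, x_comparison)  # requires matched order
    return unpaired.pvalue, paired.pvalue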
--------------------------------------------------------------------------------
/tests/tools/_differential_gene_expression/test_statsmodels.py:
--------------------------------------------------------------------------------
import numpy as np
import pytest
import statsmodels.api as sm
from pertpy.tools._differential_gene_expression import Statsmodels


@pytest.mark.parametrize(
    "method_class,kwargs",
    [
        # OLS
        (Statsmodels, {}),
        # Negative Binomial
        (
            Statsmodels,
            {"regression_model": sm.GLM, "family": sm.families.NegativeBinomial()},
        ),
    ],
)
def test_statsmodels(test_adata, method_class, kwargs):
    """Check that the method can be initialized and fitted, and perform basic checks on
    the result of test_contrasts."""
    method = method_class(adata=test_adata, design="~condition")  # type: ignore
    method.fit(**kwargs)
    res_df = method.test_contrasts(np.array([0, 1]))
    # Check that the result has the correct number of rows
    assert len(res_df) == test_adata.n_vars


# TODO: there should be a test checking if, for a concrete example, the output p-values and effect sizes are what
# we expect (-> frozen snapshot, that way we also get a heads-up if something changes upstream)
--------------------------------------------------------------------------------
/tests/tools/_distances/test_distance_tests.py:
--------------------------------------------------------------------------------
import pertpy as pt
import pytest
import scanpy as sc
from pandas import DataFrame

distances = [
    "edistance",
    "euclidean",
    "mse",
    "mean_absolute_error",
    "pearson_distance",
    "spearman_distance",
    "kendalltau_distance",
    "cosine_distance",
    "wasserstein",
    "mean_pairwise",
    "mmd",
    "r2_distance",
    "sym_kldiv",
    "t_test",
    "ks_test",
    "classifier_proba",
    # "classifier_cp",
    # "nbll",
    "mahalanobis",
    "mean_var_distribution",
]

count_distances = ["nb_ll"]


@pytest.fixture
def adata():
    adata = pt.dt.distance_example()
    adata = sc.pp.subsample(adata, 0.1, copy=True)

    return adata


@pytest.mark.parametrize("distance", distances)
def test_distancetest(adata, distance):
    etest = pt.tl.DistanceTest(distance, n_perms=10, obsm_key="X_pca", alpha=0.05, correction="holm-sidak")
    tab = etest(adata, groupby="perturbation", contrast="control")

    # Well-defined output
    assert tab.shape[1] == 5
    assert isinstance(tab, DataFrame)

    # p-values are in [0, 1]
    assert tab["pvalue"].min() >= 0
    assert tab["pvalue"].max() <= 1
    assert tab["pvalue_adj"].min() >= 0
    assert tab["pvalue_adj"].max() <= 1
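
# --- Illustrative sketch (not part of the original tests) --------------------
# The permutation scheme behind `DistanceTest`, stripped to its core: compute
# the observed distance between two groups, then re-compute it `n_perms` times
# with shuffled group assignments; the p-value is the fraction of permuted
# distances that reach the observed one. This generic version takes any
# distance callable and is only meant to illustrate the test logic, not
# pertpy's exact implementation.
def _permutation_pvalue(x, y, distance_fn, n_perms=10, seed=0):
    import numpy as np

    rng = np.random.default_rng(seed)
    observed = distance_fn(x, y)
    pooled = np.vstack([x, y])
    hits = 0
    for _ in range(n_perms):
        perm = rng.permutation(len(pooled))
        x_perm, y_perm = pooled[perm[: len(x)]], pooled[perm[len(x) :]]
        hits += distance_fn(x_perm, y_perm) >= observed
    return (hits + 1) / (n_perms + 1)  # add-one smoothing keeps p > 0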
--------------------------------------------------------------------------------
/tests/tools/_perturbation_space/test_comparison.py:
--------------------------------------------------------------------------------
import pertpy as pt
import pytest


@pytest.fixture
def test_data(rng):
    X = rng.normal(size=(100, 10))
    Y = rng.normal(size=(100, 10))
    C = rng.normal(size=(100, 10))
    return X, Y, C


def test_compare_class(test_data):
    X, Y, C = test_data
    pt_comparison = pt.tl.PerturbationComparison()
    result = pt_comparison.compare_classification(X, Y, C)
    assert result <= 1


def test_compare_knn(test_data):
    X, Y, C = test_data
    pt_comparison = pt.tl.PerturbationComparison()
    result = pt_comparison.compare_knn(X, Y, C)
    assert isinstance(result, dict)
    assert "comp" in result
    assert isinstance(result["comp"], float)

    result_no_ctrl = pt_comparison.compare_knn(X, Y)
    assert isinstance(result_no_ctrl, dict)
--------------------------------------------------------------------------------
/tests/tools/_perturbation_space/test_discriminator_classifiers.py:
--------------------------------------------------------------------------------
import numpy as np
import pandas as pd
import pertpy as pt
import pytest
from anndata import AnnData


@pytest.fixture
def adata():
    X = np.zeros((20, 5), dtype=np.float32)

    pert_index = [
        "control",
        "target1",
        "target1",
        "target2",
        "target2",
        "target1",
        "target1",
        "target2",
        "target2",
        "target2",
        "control",
        "target1",
        "target1",
        "target2",
        "target2",
        "target1",
        "target1",
        "target2",
        "target2",
        "target2",
    ]

    for i, value in enumerate(pert_index):
        if value == "control":
            X[i, :] = 0
        elif value == "target1":
            X[i, :] = 10
        elif value == "target2":
            X[i, :] = 30

    obs = pd.DataFrame({"perturbations": pert_index})

    adata = AnnData(X, obs=obs)

    # Add obs annotations to the adata
    adata.obs["MoA"] = ["Growth" if pert == "target1" else "Unknown" for pert in adata.obs["perturbations"]]
    adata.obs["Partial Annotation"] = ["Anno1" if pert == "target2" else np.nan for pert in adata.obs["perturbations"]]

    return adata


def test_mlp_classifier_space(adata):
    classifier_ps = pt.tl.MLPClassifierSpace()
    pert_embeddings = classifier_ps.compute(adata, hidden_dim=[128], max_epochs=2)

    # The embeddings should form 3 perfect clusters since the perturbations are easily separable
    ps = pt.tl.KMeansSpace()
    adata = ps.compute(pert_embeddings, n_clusters=3, copy=True)
    results = ps.evaluate_clustering(adata, true_label_col="perturbations", cluster_col="k-means")
    np.testing.assert_equal(len(results), 3)
    np.testing.assert_allclose(results["nmi"], 0.99, rtol=0.1)
    np.testing.assert_allclose(results["ari"], 0.99, rtol=0.1)
    np.testing.assert_allclose(results["asw"], 0.99, rtol=0.1)


def test_regression_classifier_space(adata):
    ps = pt.tl.LRClassifierSpace()
    pert_embeddings = ps.compute(adata)

    assert pert_embeddings.shape == (3, 5)
    assert pert_embeddings.obs[pert_embeddings.obs["perturbations"] == "target1"]["MoA"].values == "Growth"
    assert "Partial Annotation" not in pert_embeddings.obs_names
    # The classifier should be able to distinguish control and target2 from the respective other two classes
    assert np.all(
        pert_embeddings.obs[pert_embeddings.obs["perturbations"].isin(["control", "target2"])][
            "classifier_score"
        ].values
        == 1.0
    )
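
# --- Illustrative sketch (not part of the original tests) --------------------
# The three clustering scores asserted above, computed directly with
# scikit-learn (assumed here only for illustration): NMI and ARI compare
# predicted cluster labels against the true perturbation labels, while the
# silhouette width (ASW) judges cluster separation from the data itself.
def _clustering_scores(X, true_labels, cluster_labels):
    from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_score

    return {
        "nmi": normalized_mutual_info_score(true_labels, cluster_labels),
        "ari": adjusted_rand_score(true_labels, cluster_labels),
        "asw": silhouette_score(X, cluster_labels),
    }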
--------------------------------------------------------------------------------
/tests/tools/_perturbation_space/test_simple_cluster_space.py:
--------------------------------------------------------------------------------
import numpy as np
import pandas as pd
import pertpy as pt
from anndata import AnnData


def test_clustering():
    X = np.zeros((10, 5))

    pert_index = [
        "control",
        "target1",
        "target1",
        "target2",
        "target2",
        "target1",
        "target1",
        "target2",
        "target2",
        "target2",
    ]

    for i, value in enumerate(pert_index):
        if value == "control":
            X[i, :] = 0
        elif value == "target1":
            X[i, :] = 10
        elif value == "target2":
            X[i, :] = 30

    obs = pd.DataFrame({"perturbations": pert_index})

    adata = AnnData(X, obs=obs)

    # Compute clustering at observation level
    ps = pt.tl.KMeansSpace()
    adata = ps.compute(adata, n_clusters=3, copy=True)

    ps = pt.tl.DBSCANSpace()
    adata = ps.compute(adata, min_samples=1, copy=True)

    results = ps.evaluate_clustering(adata, true_label_col="perturbations", cluster_col="k-means", metric="l1")
    np.testing.assert_equal(len(results), 3)
    np.testing.assert_allclose(results["nmi"], 0.99, rtol=0.1)
    np.testing.assert_allclose(results["ari"], 0.99, rtol=0.1)
    np.testing.assert_allclose(results["asw"], 0.99, rtol=0.1)

    results = ps.evaluate_clustering(adata, true_label_col="perturbations", cluster_col="dbscan", metric="l1")
    np.testing.assert_equal(len(results), 3)
    np.testing.assert_allclose(results["nmi"], 0.99, rtol=0.1)
    np.testing.assert_allclose(results["ari"], 0.99, rtol=0.1)
    np.testing.assert_allclose(results["asw"], 0.99, rtol=0.1)
--------------------------------------------------------------------------------
/tests/tools/test_augur.py:
--------------------------------------------------------------------------------
from math import isclose
from pathlib import Path

import numpy as np
import pertpy as pt
import pytest
import scanpy as sc

CWD = Path(__file__).parent.resolve()


ag_rfc = pt.tl.Augur("random_forest_classifier", random_state=42)
ag_lrc = pt.tl.Augur("logistic_regression_classifier", random_state=42)
ag_rfr = pt.tl.Augur("random_forest_regressor", random_state=42)


@pytest.fixture
def adata():
    adata = pt.dt.sc_sim_augur()
    adata = sc.pp.subsample(adata, n_obs=200, copy=True, random_state=10)

    return adata


def test_load(adata):
    """Test if the load function creates AnnData objects."""
    ag = pt.tl.Augur(estimator="random_forest_classifier")

    loaded_adata = ag.load(adata)
    loaded_df = ag.load(adata.to_df(), meta=adata.obs, cell_type_col="cell_type", label_col="label")

    assert loaded_adata.obs["y_"].equals(loaded_df.obs["y_"]) is True
    assert adata.to_df().equals(loaded_adata.to_df()) is True and adata.to_df().equals(loaded_df.to_df())


def test_random_forest_classifier(adata):
    """Tests the random forest classifier for AUC calculation."""
    adata = ag_rfc.load(adata)
    sc.pp.highly_variable_genes(adata)
    h_adata, results = ag_rfc.predict(
        adata, n_threads=4, n_subsamples=3, random_state=42, select_variance_features=False
    )

    assert results["CellTypeA"][2]["subsample_idx"] == 2
    assert "augur_score" in h_adata.obs.columns
    assert np.allclose(results["summary_metrics"].loc["mean_augur_score"].tolist(), [0.634920, 0.933484, 0.902494])
    assert "feature_importances" in results
    assert len(set(results["summary_metrics"]["CellTypeA"])) == len(results["summary_metrics"]["CellTypeA"]) - 1
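
# --- Illustrative sketch (not part of the original tests) --------------------
# Augur's core idea in miniature: a cell type whose cells are easy to classify
# into conditions gets a high cross-validated AUC. A generic scikit-learn
# version of that signal (assumed here only for illustration; Augur itself adds
# subsampling, feature selection, and more):
def _condition_separability(X, labels, seed=42):
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import cross_val_score

    clf = RandomForestClassifier(random_state=seed)
    return cross_val_score(clf, X, labels, cv=3, scoring="roc_auc").mean()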
def test_logistic_regression_classifier(adata):
    """Tests the logistic regression classifier for AUC calculation."""
    adata = ag_rfc.load(adata)
    sc.pp.highly_variable_genes(adata)
    h_adata, results = ag_lrc.predict(
        adata, n_threads=4, n_subsamples=3, random_state=42, select_variance_features=False
    )

    assert "augur_score" in h_adata.obs.columns
    assert np.allclose(results["summary_metrics"].loc["mean_augur_score"].tolist(), [0.691232, 0.955404, 0.972789])
    assert "feature_importances" in results


def test_random_forest_regressor(adata):
    """Tests the random forest regressor for CCC calculation."""
    adata = ag_rfc.load(adata)
    sc.pp.highly_variable_genes(adata)

    with pytest.raises(ValueError):
        ag_rfr.predict(adata, n_threads=4, n_subsamples=3, random_state=42)


def test_classifier(adata):
    """Test run_cross_validation with a classifier."""
    adata = ag_rfc.load(adata)
    sc.pp.highly_variable_genes(adata)
    adata_subsampled = sc.pp.subsample(adata, n_obs=100, random_state=42, copy=True)

    cv = ag_rfc.run_cross_validation(adata_subsampled, subsample_idx=1, folds=3, random_state=42, zero_division=0)
    auc = 0.786412
    assert isclose(cv["mean_auc"], auc, abs_tol=10**-3)

    cv = ag_lrc.run_cross_validation(adata, subsample_idx=1, folds=3, random_state=42, zero_division=0)
    auc = 0.978673
    assert isclose(cv["mean_auc"], auc, abs_tol=10**-3)


def test_regressor(adata):
    """Test run_cross_validation with a regressor."""
    adata = ag_rfc.load(adata)
    cv = ag_rfr.run_cross_validation(adata, subsample_idx=1, folds=3, random_state=42, zero_division=0)
    ccc = 0.168800
    r2 = 0.149887
    assert any([isclose(cv["mean_ccc"], ccc, abs_tol=10**-5), isclose(cv["mean_r2"], r2, abs_tol=10**-5)])


def test_subsample(adata):
    """Test default, permute and velocity subsampling process."""
    adata = ag_rfc.load(adata)
    sc.pp.highly_variable_genes(adata)
    categorical_subsample = ag_rfc.draw_subsample(
        adata=adata,
        augur_mode="default",
        subsample_size=20,
        feature_perc=0.3,
        categorical=True,
        random_state=42,
    )
    assert len(categorical_subsample.obs_names) == 40

    non_categorical_subsample = ag_rfc.draw_subsample(
        adata=adata,
        augur_mode="default",
        subsample_size=20,
        feature_perc=0.3,
        categorical=False,
        random_state=42,
    )
    assert len(non_categorical_subsample.obs_names) == 20

    permut_subsample = ag_rfc.draw_subsample(
        adata=adata,
        augur_mode="permute",
        subsample_size=20,
        feature_perc=0.3,
        categorical=True,
        random_state=42,
    )
    assert (adata.obs.loc[permut_subsample.obs.index, "y_"] != permut_subsample.obs["y_"]).any()

    velocity_subsample = ag_rfc.draw_subsample(
        adata=adata,
        augur_mode="velocity",
        subsample_size=20,
        feature_perc=0.3,
        categorical=True,
        random_state=42,
    )
    assert len(velocity_subsample.var_names) == 5505 and len(velocity_subsample.obs_names) == 40


def test_select_variance(adata):
    """Test the select_variance implementation."""
    adata = ag_rfc.load(adata)
    sc.pp.highly_variable_genes(adata)
    adata_cell_type = adata[adata.obs["cell_type"] == "CellTypeA"].copy()
    ad = ag_rfc.select_variance(adata_cell_type, var_quantile=0.5, span=0.3, filter_negative_residuals=False)

    assert len(ad.var.index[ad.var["highly_variable"]]) == 3672


def test_differential_prioritization():
"""Test differential prioritization run.""" 154 | # Requires the full dataset or it fails because of a lack of statistical power 155 | adata = pt.dt.sc_sim_augur() 156 | adata = sc.pp.subsample(adata, n_obs=500, copy=True, random_state=10) 157 | ag = pt.tl.Augur("logistic_regression_classifier", random_state=42) 158 | ag.load(adata) 159 | 160 | adata, results1 = ag.predict(adata, n_threads=4, n_subsamples=3, random_state=2) 161 | adata, results2 = ag.predict(adata, n_threads=4, n_subsamples=3, random_state=42) 162 | 163 | a, permut1 = ag.predict(adata, augur_mode="permute", n_threads=4, n_subsamples=100, random_state=2) 164 | a, permut2 = ag.predict(adata, augur_mode="permute", n_threads=4, n_subsamples=100, random_state=42) 165 | delta = ag.predict_differential_prioritization(results1, results2, permut1, permut2) 166 | assert not np.isnan(delta["z"]).any() 167 | -------------------------------------------------------------------------------- /tests/tools/test_cinemaot.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pertpy as pt 5 | import scanpy as sc 6 | from _pytest.fixtures import fixture 7 | 8 | CWD = Path(__file__).parent.resolve() 9 | 10 | 11 | @fixture 12 | def adata(): 13 | adata = pt.dt.cinemaot_example() 14 | adata = sc.pp.subsample(adata, 0.1, copy=True) 15 | 16 | return adata 17 | 18 | 19 | def test_unweighted(adata): 20 | sc.pp.pca(adata) 21 | model = pt.tl.Cinemaot() 22 | de = model.causaleffect( 23 | adata, 24 | pert_key="perturbation", 25 | control="No stimulation", 26 | return_matching=True, 27 | thres=0.5, 28 | smoothness=1e-5, 29 | eps=1e-3, 30 | solver="Sinkhorn", 31 | preweight_label="cell_type0528", 32 | ) 33 | 34 | eps = 1e-1 35 | assert "cf" in adata.obsm 36 | assert "ot" in de.obsm 37 | assert not np.isnan(np.sum(de.obsm["ot"])) 38 | assert not np.abs(np.sum(de.obsm["ot"]) - 1) > eps 39 | 40 | 41 | def test_weighted(adata): 42 | sc.pp.pca(adata) 43 | model = pt.tl.Cinemaot() 44 | ad, de = model.causaleffect_weighted( 45 | adata, 46 | pert_key="perturbation", 47 | control="No stimulation", 48 | return_matching=True, 49 | thres=0.5, 50 | smoothness=1e-5, 51 | eps=1e-3, 52 | solver="Sinkhorn", 53 | ) 54 | 55 | eps = 1e-1 56 | assert "cf" in ad.obsm 57 | assert "ot" in de.obsm 58 | assert not np.isnan(np.sum(de.obsm["ot"])) 59 | assert not np.abs(np.sum(de.obsm["ot"]) - 1) > eps 60 | 61 | 62 | def test_pseudobulk(adata): 63 | sc.pp.pca(adata) 64 | model = pt.tl.Cinemaot() 65 | de = model.causaleffect( 66 | adata, 67 | pert_key="perturbation", 68 | control="No stimulation", 69 | return_matching=True, 70 | thres=0.5, 71 | smoothness=1e-5, 72 | eps=1e-3, 73 | solver="Sinkhorn", 74 | preweight_label="cell_type0528", 75 | ) 76 | adata_pb = model.generate_pseudobulk(adata, de, pert_key="perturbation", control="No stimulation", label_list=None) 77 | 78 | expect_num = 9 79 | eps = 7 80 | assert "ptb" in adata_pb.obs 81 | assert not np.isnan(np.sum(adata_pb.X)) 82 | print(adata_pb.shape) 83 | assert not np.abs(adata_pb.shape[0] - expect_num) >= eps 84 | -------------------------------------------------------------------------------- /tests/tools/test_dialogue.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pertpy as pt 3 | import scanpy as sc 4 | 5 | # This is not a proper test! 
--------------------------------------------------------------------------------
/tests/tools/test_dialogue.py:
--------------------------------------------------------------------------------
import pandas as pd
import pertpy as pt
import scanpy as sc

# This is not a proper test!
# We are only testing a few functions to ensure that at least these run.
# The pipeline is obtained from https://pertpy.readthedocs.io/en/latest/tutorials/notebooks/dialogue.html


def test_dialogue_pipeline():
    adata = pt.dt.dialogue_example()

    sc.pp.pca(adata)
    sc.pp.neighbors(adata)
    sc.tl.umap(adata)

    adata = adata[adata.obs["cell.subtypes"] != "CD8+ IL17+"]
    isecs = pd.crosstab(adata.obs["cell.subtypes"], adata.obs["sample"])

    keep_pts = list(isecs.loc[:, (isecs > 3).sum(axis=0) == isecs.shape[0]].columns.values)
    adata = adata[adata.obs["sample"].isin(keep_pts), :].copy()

    dl = pt.tl.Dialogue(
        sample_id="sample",
        celltype_key="cell.subtypes",
        n_counts_key="nCount_RNA",
        n_mpcs=3,
    )

    adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)

    dl.test_association(adata, "path_str")

    dl.get_extrema_MCP_genes(ct_subs)
--------------------------------------------------------------------------------
/tests/tools/test_enrichment.py:
--------------------------------------------------------------------------------
import numpy as np
import pertpy as pt
import pytest
import scanpy as sc
from anndata import AnnData


@pytest.fixture
def dummy_adata():
    n_obs = 10
    n_vars = 5
    rng = np.random.default_rng()
    X = rng.random((n_obs, n_vars))
    adata = AnnData(X)
    adata.var_names = [f"gene{i}" for i in range(n_vars)]
    adata.obs["cluster"] = ["group_1"] * 5 + ["group_2"] * 5
    sc.tl.rank_genes_groups(adata, groupby="cluster", method="t-test")

    return adata


@pytest.fixture(scope="module")
def enricher():
    return pt.tl.Enrichment()


def test_score_basic(dummy_adata, enricher):
    targets = {"group1": ["gene1", "gene2"], "group2": ["gene3", "gene4"]}
    enricher.score(adata=dummy_adata, targets=targets)
    assert "pertpy_enrichment_score" in dummy_adata.uns


def test_score_with_different_layers(dummy_adata, enricher):
    rng = np.random.default_rng()
    dummy_adata.layers["layer"] = rng.random((10, 5))
    targets = {"group1": ["gene1", "gene2"], "group2": ["gene3", "gene4"]}
    enricher.score(adata=dummy_adata, layer="layer", targets=targets)
    assert "pertpy_enrichment_score" in dummy_adata.uns


def test_score_with_nested_targets(dummy_adata, enricher):
    targets = {"category1": {"group1": ["gene1", "gene2"]}, "category2": {"group2": ["gene3", "gene4"]}}
    enricher.score(adata=dummy_adata, targets=targets, nested=True)
    assert "pertpy_enrichment_score" in dummy_adata.uns


def test_hypergeometric_basic(dummy_adata, enricher):
    targets = {"group1": ["gene1", "gene2"]}
    results = enricher.hypergeometric(dummy_adata, targets)
    assert isinstance(results, dict)


def test_hypergeometric_with_nested_targets(dummy_adata, enricher):
    targets = {"category1": {"group1": ["gene1", "gene2"]}}
    results = enricher.hypergeometric(dummy_adata, targets, nested=True)
    assert isinstance(results, dict)


@pytest.mark.parametrize("direction", ["up", "down", "both"])
def test_hypergeometric_with_different_directions(dummy_adata, enricher, direction):
    targets = {"group1": ["gene1", "gene2"]}
    results = enricher.hypergeometric(dummy_adata, targets, direction=direction)
    assert isinstance(results, dict)
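
# --- Illustrative sketch (not part of the original tests) --------------------
# The hypergeometric enrichment test above, in plain scipy terms: given a
# universe of `n_universe` genes, `n_targets` of which belong to the gene set,
# the p-value for drawing `k_hits` set members among `n_drawn` DE genes is an
# upper-tail hypergeometric probability. (scipy is assumed here only for
# illustration; pertpy's implementation may differ in details.)
def _hypergeometric_pvalue(n_universe, n_targets, n_drawn, k_hits):
    from scipy.stats import hypergeom

    # P(X >= k_hits); sf(k - 1) gives the inclusive upper tail.
    return hypergeom.sf(k_hits - 1, n_universe, n_targets, n_drawn)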
--------------------------------------------------------------------------------
/tests/tools/test_scgen.py:
--------------------------------------------------------------------------------
import warnings

import anndata as ad
import pertpy as pt
import scanpy as sc
from scvi.data import synthetic_iid


def test_scgen():
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="Observation names are not unique")
        adata = synthetic_iid()
        adata.obs_names_make_unique()
    pt.tl.Scgen.setup_anndata(
        adata,
        batch_key="batch",
        labels_key="labels",
    )

    scg = pt.tl.Scgen(adata)
    scg.train(max_epochs=1, batch_size=32, early_stopping=True, early_stopping_patience=25)

    scg.batch_removal()

    # predict
    pred, delta = scg.predict(ctrl_key="batch_0", stim_key="batch_1", celltype_to_predict="label_0")
    pred.obs["batch"] = "pred"

    # reg mean and reg var
    ctrl_adata = adata[((adata.obs["labels"] == "label_0") & (adata.obs["batch"] == "batch_0"))]
    stim_adata = adata[((adata.obs["labels"] == "label_0") & (adata.obs["batch"] == "batch_1"))]
    eval_adata = ad.concat([ctrl_adata, stim_adata, pred], label="concat_batches")
    label_0 = adata[adata.obs["labels"] == "label_0"]
    sc.tl.rank_genes_groups(label_0, groupby="batch", method="wilcoxon")
    diff_genes = label_0.uns["rank_genes_groups"]["names"]["batch_1"]

    scg.plot_reg_mean_plot(
        eval_adata,
        condition_key="batch",
        axis_keys={"x": "pred", "y": "batch_1"},
        gene_list=diff_genes[:10],
        labels={"x": "predicted", "y": "ground truth"},
        save=False,
        show=False,
        legend=False,
    )

    scg.plot_reg_var_plot(
        eval_adata,
        condition_key="batch",
        axis_keys={"x": "pred", "y": "batch_1"},
        gene_list=diff_genes[:10],
        labels={"x": "predicted", "y": "ground truth"},
        save=False,
        show=False,
        legend=False,
    )
--------------------------------------------------------------------------------