├── .editorconfig
├── .gitattributes
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.yml
│   │   ├── config.yml
│   │   └── feature_request.yml
│   ├── labels.yml
│   ├── pull_request_template.md
│   ├── release-drafter.yml
│   └── workflows
│       ├── build.yml
│       ├── labeler.yml
│       ├── release.yml
│       ├── release_drafter.yml
│       └── test.yml
├── .gitignore
├── .gitmodules
├── .pre-commit-config.yaml
├── .readthedocs.yml
├── LICENSE
├── README.md
├── biome.jsonc
├── codecov.yml
├── docs
│   ├── Makefile
│   ├── _ext
│   │   ├── edit_on_github.py
│   │   └── typed_returns.py
│   ├── _static
│   │   ├── SCVI_LICENSE
│   │   ├── css
│   │   │   ├── overwrite.css
│   │   │   └── sphinx_gallery.css
│   │   ├── docstring_previews
│   │   │   ├── augur_dp_scatter.png
│   │   │   ├── augur_important_features.png
│   │   │   ├── augur_lollipop.png
│   │   │   ├── augur_scatterplot.png
│   │   │   ├── de_fold_change.png
│   │   │   ├── de_multicomparison_fc.png
│   │   │   ├── de_paired_expression.png
│   │   │   ├── de_volcano.png
│   │   │   ├── dialogue_pairplot.png
│   │   │   ├── dialogue_violin.png
│   │   │   ├── enrichment_dotplot.png
│   │   │   ├── enrichment_gsea.png
│   │   │   ├── milo_da_beeswarm.png
│   │   │   ├── milo_nhood.png
│   │   │   ├── milo_nhood_graph.png
│   │   │   ├── mixscape_barplot.png
│   │   │   ├── mixscape_heatmap.png
│   │   │   ├── mixscape_lda.png
│   │   │   ├── mixscape_perturbscore.png
│   │   │   ├── mixscape_violin.png
│   │   │   ├── pseudobulk_samples.png
│   │   │   ├── sccoda_boxplots.png
│   │   │   ├── sccoda_effects_barplot.png
│   │   │   ├── sccoda_rel_abundance_dispersion_plot.png
│   │   │   ├── sccoda_stacked_barplot.png
│   │   │   ├── scgen_reg_mean.png
│   │   │   ├── tasccoda_draw_effects.png
│   │   │   ├── tasccoda_draw_tree.png
│   │   │   └── tasccoda_effects_umap.png
│   │   ├── icons
│   │   │   ├── code-24px.svg
│   │   │   ├── computer-24px.svg
│   │   │   ├── library_books-24px.svg
│   │   │   └── play_circle_outline-24px.svg
│   │   ├── pertpy_logo.png
│   │   ├── pertpy_logo.svg
│   │   ├── placeholder.png
│   │   └── tutorials
│   │       ├── augur.png
│   │       ├── cinemaot.png
│   │       ├── dge.png
│   │       ├── dialogue.png
│   │       ├── distances.png
│   │       ├── distances_tests.png
│   │       ├── enrichment.png
│   │       ├── guide_rna_assignment.png
│   │       ├── mcfarland.png
│   │       ├── metadata.png
│   │       ├── milo.png
│   │       ├── mixscape.png
│   │       ├── norman.png
│   │       ├── ontology.png
│   │       ├── perturbation_space.png
│   │       ├── placeholder.png
│   │       ├── sccoda.png
│   │       ├── sccoda_extended.png
│   │       ├── scgen_perturbation_prediction.png
│   │       ├── tasccoda.png
│   │       └── zhang.png
│   ├── _templates
│   │   └── autosummary
│   │       └── class.rst
│   ├── about
│   │   ├── background.md
│   │   └── cite.md
│   ├── api.md
│   ├── api
│   │   ├── datasets_index.md
│   │   ├── metadata_index.md
│   │   ├── preprocessing_index.md
│   │   └── tools_index.md
│   ├── changelog.md
│   ├── conf.py
│   ├── contributing.md
│   ├── index.md
│   ├── installation.md
│   ├── make.bat
│   ├── references.bib
│   ├── references.md
│   ├── tutorials.md
│   ├── tutorials
│   │   ├── metadata.md
│   │   ├── preprocessing.md
│   │   └── tools.md
│   ├── usecases.md
│   └── utils.py
├── pertpy
│   ├── __init__.py
│   ├── _doc.py
│   ├── _types.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── _dataloader.py
│   │   └── _datasets.py
│   ├── metadata
│   │   ├── __init__.py
│   │   ├── _cell_line.py
│   │   ├── _compound.py
│   │   ├── _drug.py
│   │   ├── _look_up.py
│   │   ├── _metadata.py
│   │   └── _moa.py
│   ├── plot
│   │   └── __init__.py
│   ├── preprocessing
│   │   ├── __init__.py
│   │   ├── _guide_rna.py
│   │   └── _guide_rna_mixture.py
│   ├── py.typed
│   └── tools
│       ├── __init__.py
│       ├── _augur.py
│       ├── _cinemaot.py
│       ├── _coda
│       │   ├── __init__.py
│       │   ├── _base_coda.py
│       │   ├── _sccoda.py
│       │   └── _tasccoda.py
│       ├── _dialogue.py
│       ├── _differential_gene_expression
│       │   ├── __init__.py
│       │   ├── _base.py
│       │   ├── _checks.py
│       │   ├── _dge_comparison.py
│       │   ├── _edger.py
│       │   ├── _pydeseq2.py
│       │   ├── _simple_tests.py
│       │   └── _statsmodels.py
│       ├── _distances
│       │   ├── __init__.py
│       │   ├── _distance_tests.py
│       │   └── _distances.py
│       ├── _enrichment.py
│       ├── _milo.py
│       ├── _mixscape.py
│       ├── _perturbation_space
│       │   ├── __init__.py
│       │   ├── _clustering.py
│       │   ├── _comparison.py
│       │   ├── _discriminator_classifiers.py
│       │   ├── _metrics.py
│       │   ├── _perturbation_space.py
│       │   └── _simple.py
│       ├── _scgen
│       │   ├── __init__.py
│       │   ├── _base_components.py
│       │   ├── _scgen.py
│       │   ├── _scgenvae.py
│       │   └── _utils.py
│       ├── decoupler_LICENSE
│       └── transferlearning_MMD_LICENSE
├── pyproject.toml
└── tests
    ├── conftest.py
    ├── metadata
    │   ├── test_cell_line.py
    │   ├── test_compound.py
    │   ├── test_drug.py
    │   └── test_moa.py
    ├── preprocessing
    │   └── test_grna_assignment.py
    └── tools
        ├── _coda
        │   ├── test_sccoda.py
        │   └── test_tasccoda.py
        ├── _differential_gene_expression
        │   ├── __init__.py
        │   ├── conftest.py
        │   ├── test_base.py
        │   ├── test_compare_groups.py
        │   ├── test_dge.py
        │   ├── test_edger.py
        │   ├── test_input_checks.py
        │   ├── test_pydeseq2.py
        │   ├── test_simple_tests.py
        │   └── test_statsmodels.py
        ├── _distances
        │   ├── test_distance_tests.py
        │   └── test_distances.py
        ├── _perturbation_space
        │   ├── test_comparison.py
        │   ├── test_discriminator_classifiers.py
        │   ├── test_simple_cluster_space.py
        │   └── test_simple_perturbation_space.py
        ├── test_augur.py
        ├── test_cinemaot.py
        ├── test_dialogue.py
        ├── test_enrichment.py
        ├── test_milo.py
        ├── test_mixscape.py
        └── test_scgen.py
/.editorconfig:
--------------------------------------------------------------------------------
1 | # http://editorconfig.org
2 |
3 | root = true
4 |
5 | [*]
6 | indent_style = space
7 | indent_size = 4
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 | charset = utf-8
11 | end_of_line = lf
12 |
13 | [*.bat]
14 | indent_style = tab
15 | end_of_line = crlf
16 |
17 | [LICENSE]
18 | insert_final_newline = true
19 |
20 | [Makefile]
21 | indent_style = tab
22 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | * text=auto eol=lf
2 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yml:
--------------------------------------------------------------------------------
1 | name: Bug report
2 | description: pertpy doesn't do what it should? Please help us fix it!
3 | #title: ...
4 | labels:
5 | - bug
6 | - triage
7 | #assignees: []
8 | body:
9 | - type: checkboxes
10 | id: terms
11 | attributes:
12 | label: Please make sure these conditions are met
13 | # description: ...
14 | options:
15 | - label: I have checked that this issue has not already been reported.
16 | required: true
17 | - label: I have confirmed this bug exists on the latest version of pertpy.
18 | required: true
19 | - label: (optional) I have confirmed this bug exists on the main branch.
20 | required: false
21 | - type: markdown
22 | attributes:
23 | value: |
24 | **Note**: Please read [this guide][] detailing how to provide the necessary information for us to reproduce your bug.
25 |
26 | [this guide]: https://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports
27 | - type: textarea
28 | id: Report
29 | attributes:
30 | label: Report
31 | description: |
32 | Describe the bug you encountered, and what you were trying to do. Please use [github markdown][] features for readability.
33 |
34 | [github markdown]: https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax
35 | value: |
36 | Code:
37 |
38 | ```python
39 |
40 | ```
41 |
42 | Traceback:
43 |
44 | ```pytb
45 |
46 | ```
47 | validations:
48 | required: true
49 | - type: textarea
50 | id: versions
51 | attributes:
52 | label: Versions
53 | description: |
54 | Which version of pertpy and other related software you used.
55 |
56 | Please install `session-info2`, run the following command in a notebook,
57 | click the “Copy as Markdown” button,
58 | then paste the results into the text box below.
59 |
60 | ```python
61 | In[1]: import pertpy, session_info2; session_info2.session_info(dependencies=True)
62 | ```
63 |
64 | Alternatively, run this in a console:
65 |
66 | ```python
67 | >>> import pertpy, session_info2; print(session_info2.session_info(dependencies=True)._repr_mimebundle_()["text/markdown"])
68 | ```
69 | render: markdown
70 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 | - name: Scverse Community Forum
4 | url: https://discourse.scverse.org/
5 | about: If you have questions about “How to do X”, please ask them here.
6 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.yml:
--------------------------------------------------------------------------------
1 | name: Feature request
2 | description: Propose a new feature for pertpy
3 | labels: enhancement
4 | body:
5 | - type: textarea
6 | id: description
7 | attributes:
8 | label: Description of feature
9 | description: Please describe your suggestion for a new feature. It might help to describe a problem or use case, plus any alternatives that you have considered.
10 | validations:
11 | required: true
12 |
--------------------------------------------------------------------------------
/.github/labels.yml:
--------------------------------------------------------------------------------
1 | ---
2 | # Label names are important: Release Drafter uses them to decide where to
3 | # record changes in the changelog, or whether to skip them.
4 | #
5 | # The repository labels will be automatically configured using this file and
6 | # the GitHub Action https://github.com/marketplace/actions/github-labeler.
7 | - name: breaking
8 | description: Breaking Changes
9 | color: bfd4f2
10 | - name: bug
11 | description: Something isn't working
12 | color: d73a4a
13 | - name: build
14 | description: Build System and Dependencies
15 | color: bfdadc
16 | - name: ci
17 | description: Continuous Integration
18 | color: 4a97d6
19 | - name: dependencies
20 | description: Pull requests that update a dependency file
21 | color: 0366d6
22 | - name: documentation
23 | description: Improvements or additions to documentation
24 | color: 0075ca
25 | - name: duplicate
26 | description: This issue or pull request already exists
27 | color: cfd3d7
28 | - name: enhancement
29 | description: New feature or request
30 | color: a2eeef
31 | - name: github_actions
32 | description: Pull requests that update GitHub Actions code
33 | color: "000000"
34 | - name: good first issue
35 | description: Good for newcomers
36 | color: 7057ff
37 | - name: help wanted
38 | description: Extra attention is needed
39 | color: 008672
40 | - name: invalid
41 | description: This doesn't seem right
42 | color: e4e669
43 | - name: performance
44 | description: Performance
45 | color: "016175"
46 | - name: python
47 | description: Pull requests that update Python code
48 | color: 2b67c6
49 | - name: question
50 | description: Further information is requested
51 | color: d876e3
52 | - name: refactoring
53 | description: Refactoring
54 | color: ef67c4
55 | - name: removal
56 | description: Removals and Deprecations
57 | color: 9ae7ea
58 | - name: style
59 | description: Style
60 | color: c120e5
61 | - name: testing
62 | description: Testing
63 | color: b1fc6f
64 | - name: wontfix
65 | description: This will not be worked on
66 | color: ffffff
67 | - name: skip-changelog
68 | description: Changes that should be omitted from the release notes
69 | color: ededed
70 |
--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | **PR Checklist**
4 |
5 |
6 |
7 | - [ ] Referenced issue is linked
8 | - [ ] If you've fixed a bug or added code that should be tested, add tests!
9 | - [ ] Documentation in `docs` is updated
10 |
11 | **Description of changes**
12 |
13 |
14 |
15 | **Technical details**
16 |
17 |
18 |
19 | **Additional context**
20 |
21 |
22 |
--------------------------------------------------------------------------------
/.github/release-drafter.yml:
--------------------------------------------------------------------------------
1 | name-template: "1.0.0 🌈"
2 | tag-template: 1.0.0
3 | exclude-labels:
4 | - "skip-changelog"
5 |
6 | categories:
7 | - title: "🚀 Features"
8 | labels:
9 | - feature
10 | - enhancement
11 | - title: "🐛 Bug Fixes"
12 | labels:
13 | - fix
14 | - bugfix
15 | - bug
16 | - title: "🧰 Maintenance"
17 | label: chore
18 | - title: ":package: Dependencies"
19 | labels:
20 | - dependencies
21 | - build
22 | - dependabot
23 | - DEPENDABOT
24 | version-resolver:
25 | major:
26 | labels:
27 | - major
28 | minor:
29 | labels:
30 | - minor
31 | patch:
32 | labels:
33 | - patch
34 | default: patch
35 | autolabeler:
36 | - label: chore
37 | files:
38 | - "*.md"
39 | branch:
40 | - '/docs{0,1}\/.+/'
41 | - label: bug
42 | branch:
43 | - /fix\/.+/
44 | title:
45 | - /fix/i
46 | - label: enhancement
47 | branch:
48 | - /feature\/.+/
49 | body:
50 | - "/JIRA-[0-9]{1,4}/"
51 | template: |
52 | ## Changes
53 |
54 | $CHANGES
55 |
--------------------------------------------------------------------------------
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: Check Build
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: [main]
8 |
9 | concurrency:
10 | group: ${{ github.workflow }}-${{ github.ref }}
11 | cancel-in-progress: true
12 |
13 | defaults:
14 | run:
15 | # to fail on error in multiline statements (-e), in pipes (-o pipefail), and on unset variables (-u).
16 | shell: bash -euo pipefail {0}
17 |
18 | jobs:
19 | package:
20 | runs-on: ubuntu-latest
21 | steps:
22 | - uses: actions/checkout@v4
23 | with:
24 | filter: blob:none
25 | fetch-depth: 0
26 |
27 | - name: Install uv
28 | uses: astral-sh/setup-uv@v6
29 | with:
30 | cache-dependency-glob: pyproject.toml
31 |
32 | - name: Build package
33 | run: uv build
34 |
35 | - name: Check package
36 | run: uvx twine check --strict dist/*.whl
37 |
--------------------------------------------------------------------------------
/.github/workflows/labeler.yml:
--------------------------------------------------------------------------------
1 | name: Labeler
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | - master
8 |
9 | jobs:
10 | labeler:
11 | runs-on: ubuntu-latest
12 | steps:
13 | - name: Check out the repository
14 | uses: actions/checkout@v4
15 |
16 | - name: Run Labeler
17 | uses: crazy-max/ghaction-github-labeler@v4.1.0
18 | with:
19 | skip-delete: true
20 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Release
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | defaults:
8 | run:
9 | shell: bash -euo pipefail {0}
10 |
11 | jobs:
12 | release:
13 | name: Upload release to PyPI
14 | runs-on: ubuntu-latest
15 | environment:
16 | name: pypi
17 | url: https://pypi.org/p/pertpy
18 | permissions:
19 | id-token: write
20 | steps:
21 | - uses: actions/checkout@v4
22 | with:
23 | filter: blob:none
24 | fetch-depth: 0
25 |
26 | - name: Install uv
27 | uses: astral-sh/setup-uv@v5
28 | with:
29 | cache-dependency-glob: pyproject.toml
30 |
31 | - name: Build package
32 | run: uv build
33 | - name: Publish package distributions to PyPI
34 | uses: pypa/gh-action-pypi-publish@release/v1
35 |
--------------------------------------------------------------------------------
/.github/workflows/release_drafter.yml:
--------------------------------------------------------------------------------
1 | name: Release Drafter
2 | on:
3 | push:
4 | branches:
5 | - main
6 | pull_request:
7 | branches:
8 | - main
9 | types:
10 | - opened
11 | - reopened
12 | - synchronize
13 | jobs:
14 | update_release_draft:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: release-drafter/release-drafter@v5
18 | env:
19 | GITHUB_TOKEN: "${{ secrets.GITHUB_TOKEN }}"
20 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Test
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: [main]
8 |
9 | concurrency:
10 | group: ${{ github.workflow }}-${{ github.ref }}
11 | cancel-in-progress: true
12 |
13 | jobs:
14 | test:
15 | runs-on: ${{ matrix.os }}
16 | defaults:
17 | run:
18 | # to fail on error in multiline statements (-e), in pipes (-o pipefail), and on unset variables (-u).
19 | shell: bash -euo pipefail {0}
20 |
21 | strategy:
22 | fail-fast: false
23 | matrix:
24 | include:
25 | - os: ubuntu-latest
26 | python: "3.13"
27 | - os: ubuntu-latest
28 | python: "3.13"
29 | pip-flags: "--pre"
30 |
31 | env:
32 | OS: ${{ matrix.os }}
33 | PYTHON: ${{ matrix.python }}
34 |
35 | steps:
36 | - uses: actions/checkout@v4
37 | with:
38 | filter: blob:none
39 | fetch-depth: 0
40 |
41 | - name: Cache .pertpy_cache
42 | uses: actions/cache@v4
43 | with:
44 | path: cache
45 | key: ${{ runner.os }}-pertpy-cache-${{ hashFiles('pertpy/metadata/**') }}
46 | restore-keys: |
47 | ${{ runner.os }}-pertpy-cache
48 |
49 | - name: Set up Python ${{ matrix.python }}
50 | uses: actions/setup-python@v5
51 | with:
52 | python-version: ${{ matrix.python }}
53 | - name: Install R
54 | uses: r-lib/actions/setup-r@v2
55 | with:
56 | r-version: "4.4.3"
57 |
58 | - name: Cache R packages
59 | id: r-cache
60 | uses: actions/cache@v3
61 | with:
62 | path: ${{ env.R_LIBS_USER }}
63 | key: ${{ runner.os }}-r-${{ hashFiles('**/pertpy/tools/_milo.py') }}
64 | restore-keys: ${{ runner.os }}-r-
65 |
66 | - name: Install R dependencies
67 | if: steps.r-cache.outputs.cache-hit != 'true'
68 | run: |
69 | mkdir -p ${{ env.R_LIBS_USER }}
70 | Rscript --vanilla -e "install.packages(c('BiocManager', 'statmod'), repos='https://cran.r-project.org'); BiocManager::install('edgeR', lib='${{ env.R_LIBS_USER }}')"
71 |
72 | - name: Install uv
73 | uses: astral-sh/setup-uv@v6
74 | with:
75 | enable-cache: true
76 | cache-dependency-glob: pyproject.toml
77 | - name: Install dependencies
78 | run: |
79 | uv pip install --system rpy2
80 | uv pip install --system ${{ matrix.pip-flags }} ".[dev,test,tcoda,de]"
81 |
82 | - name: Test
83 | env:
84 | MPLBACKEND: agg
85 | PLATFORM: ${{ matrix.os }}
86 | DISPLAY: :42
87 | run: |
88 | pytest_args="-m pytest -v --color=yes"
89 | coverage run $pytest_args
90 |
91 | - name: Show coverage report
92 | run: coverage report -m
93 |
94 | - name: Upload coverage
95 | uses: codecov/codecov-action@v4
96 | with:
97 | token: ${{ secrets.CODECOV_TOKEN }}
98 | fail_ci_if_error: true
99 | verbose: true
100 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | .pertpy_cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 | docs/api/data
76 | docs/api/metadata
77 | docs/api/tools
78 | docs/api/preprocessing
79 | !docs/api/api.md
80 |
81 | # PyBuilder
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | .python-version
93 |
94 |
95 | # pipenv
96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
99 | # install all needed dependencies.
100 | #Pipfile.lock
101 |
102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
103 | __pypackages__/
104 |
105 | # Celery stuff
106 | celerybeat-schedule
107 | celerybeat.pid
108 |
109 | # SageMath parsed files
110 | *.sage.py
111 |
112 | # Environments
113 | .env
114 | .venv
115 | env/
116 | venv/
117 | ENV/
118 | env.bak/
119 | venv.bak/
120 |
121 | # Spyder project settings
122 | .spyderproject
123 | .spyproject
124 |
125 | # Rope project settings
126 | .ropeproject
127 |
128 | # mkdocs documentation
129 | /site
130 |
131 | # mypy
132 | .mypy_cache/
133 | .dmypy.json
134 | dmypy.json
135 | .pytype/
136 |
137 | # Pyre type checker
138 | .pyre/
139 |
140 | # Jetbrains IDE
141 | .idea/
142 |
143 | # VSCode
144 | .vscode
145 |
146 | # Coala
147 | *.orig
148 |
149 | # Datasets
150 | *.h5ad
151 | *.h5md
152 | *.h5mu
153 |
154 | # Test cache
155 | cache
156 |
157 | # Apple stuff
158 | .DS_Store
159 |
160 | lightning_logs/*
161 | */lightning_logs/*
162 |
163 | node_modules
164 |
165 | # lamindb
166 | test.ipynb
167 | test-perturbation
168 | test-bug
169 |
170 | # uv
171 | uv.lock
172 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "docs/notebooks"]
2 | path = docs/tutorials/notebooks
3 | url = https://github.com/scverse/pertpy-tutorials/
4 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | fail_fast: false
2 | default_language_version:
3 | python: python3
4 | default_stages:
5 | - pre-commit
6 | - pre-push
7 | minimum_pre_commit_version: 2.16.0
8 | repos:
9 | - repo: https://github.com/biomejs/pre-commit
10 | rev: v1.9.4
11 | hooks:
12 | - id: biome-format
13 | - repo: https://github.com/astral-sh/ruff-pre-commit
14 | rev: v0.11.9
15 | hooks:
16 | - id: ruff
17 | args: [--fix, --exit-non-zero-on-fix, --unsafe-fixes]
18 | - id: ruff-format
19 | - repo: https://github.com/pre-commit/pre-commit-hooks
20 | rev: v5.0.0
21 | hooks:
22 | - id: detect-private-key
23 | - id: check-ast
24 | - id: end-of-file-fixer
25 | - id: mixed-line-ending
26 | args: [--fix=lf]
27 | - id: trailing-whitespace
28 | - id: check-case-conflict
29 | - id: check-added-large-files
30 | - id: check-toml
31 | - id: check-yaml
32 | - id: check-merge-conflict
33 | - id: no-commit-to-branch
34 | args: ["--branch=main"]
35 | - repo: https://github.com/pre-commit/mirrors-mypy
36 | rev: v1.15.0
37 | hooks:
38 | - id: mypy
39 | args: [--no-strict-optional, --ignore-missing-imports]
40 | additional_dependencies:
41 | ["types-setuptools", "types-requests", "types-attrs"]
42 | - repo: local
43 | hooks:
44 | - id: forbid-to-commit
45 | name: Don't commit rej files
46 | entry: |
47 | Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates.
48 | Fix the merge conflicts manually and remove the .rej files.
49 | language: fail
50 | files: '.*\.rej$'
51 |
--------------------------------------------------------------------------------
/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | build:
3 | os: ubuntu-24.04
4 | tools:
5 | python: "3.13"
6 | jobs:
7 | create_environment:
8 | - asdf plugin add uv
9 | - asdf install uv latest
10 | - asdf global uv latest
11 | - uv venv
12 | - uv pip install .[doc,tcoda,de]
13 | build:
14 | html:
15 | - uv run sphinx-build -T -W -b html docs $READTHEDOCS_OUTPUT/html
16 |
17 | submodules:
18 | include: all
19 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021, Lukas Heumos
4 | Copyright (c) 2025, scverse®
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://github.com/scverse/pertpy/actions/workflows/build.yml)
2 | [](https://codecov.io/gh/scverse/pertpy)
3 | [](https://opensource.org/licenses/MIT)
4 | [](https://pypi.org/project/pertpy/)
5 | [](https://pypi.org/project/pertpy)
6 | [](https://pertpy.readthedocs.io/)
7 | [](https://github.com/scverse/pertpy/actions/workflows/test.yml)
8 | [](https://github.com/pre-commit/pre-commit)
9 |
10 | # pertpy - Perturbation Analysis in Python
11 |
12 | Pertpy is a scverse ecosystem framework for analyzing large-scale single-cell perturbation experiments.
13 | It provides tools for harmonizing perturbation datasets, automating metadata annotation, calculating perturbation distances, and efficiently analyzing how cells respond to various stimuli like genetic modifications, drug treatments, and environmental changes.
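As a minimal taste of the API, the sketch below computes pairwise perturbation distances. It is illustrative only: the example dataset loader, the `X_pca` embedding, and the `"perturbation"` column are assumptions, so adapt them to your data.

```python
import pertpy as pt

# Sketch: pairwise E-distances between perturbation groups.
# Assumes an AnnData with a PCA embedding and a "perturbation" .obs column.
adata = pt.dt.dialogue_example()
distance = pt.tl.Distance(metric="edistance", obsm_key="X_pca")
df = distance.pairwise(adata, groupby="perturbation")
```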
14 |
15 | 
16 |
17 | ## Documentation
18 |
19 | Please read the [documentation](https://pertpy.readthedocs.io/en/latest) for installation, tutorials, use cases, and more.
20 |
21 | ## Installation
22 |
23 | We recommend installing and running pertpy on a recent version of Linux (e.g. Ubuntu 24.04 LTS).
24 | No particular hardware beyond a standard laptop is required.
25 |
26 | You can install _pertpy_ in less than a minute via [pip] from [PyPI]:
27 |
28 | ```console
29 | pip install pertpy
30 | ```
31 |
32 | or [conda-forge]:
33 |
34 | ```console
35 | conda install -c conda-forge pertpy
36 | ```
37 |
38 | ### Differential gene expression
39 |
40 | If you want to use the differential gene expression interface, please install pertpy by running:
41 |
42 | ```console
43 | pip install 'pertpy[de]'
44 | ```
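
Once installed, the interface can be used roughly as follows. This is a sketch, not a definitive recipe: `adata`, the `condition` column, and the group labels are placeholders, and the exact signatures are documented in the API reference.

```python
import pertpy as pt

# Sketch of the differential expression interface ("de" extra required).
# "condition", "control", and "treated" are placeholder names.
model = pt.tl.PyDESeq2(adata=adata, design="~condition")
model.fit()
results = model.test_contrasts(
    model.contrast(column="condition", baseline="control", group_to_compare="treated")
)
```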
45 |
46 | ### tascCODA
47 |
48 | If you want to use tascCODA, please install pertpy as follows:
49 |
50 | ```console
51 | pip install 'pertpy[tcoda]'
52 | ```
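
A typical run then follows load, prepare, and inference steps. The following is a rough sketch based on the tascCODA tutorial; the identifier columns, tree levels, and formula are assumptions about your data, and `load` and `prepare` accept further arguments in practice.

```python
import pertpy as pt

# Rough sketch of a tascCODA run; all column names and levels are placeholders.
tasccoda = pt.tl.Tasccoda()
mdata = tasccoda.load(
    adata,
    type="cell_level",
    cell_type_identifier="cell_type",
    sample_identifier="sample",
    levels_orig=["lineage", "cell_type"],
    key_added="tree",
)
mdata = tasccoda.prepare(mdata, formula="condition", tree_key="tree")
tasccoda.run_nuts(mdata)
tasccoda.summary(mdata)
```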
53 |
54 | ### milo
55 |
56 | Milo requires either the "de" extra for the "pydeseq2" solver:
57 |
58 | ```console
59 | pip install 'pertpy[de]'
60 | ```
61 |
62 | or edgeR, statmod, and rpy2 for the "edger" solver:
63 |
64 | ```R
65 | BiocManager::install("edgeR")
66 | BiocManager::install("statmod")
67 | ```
68 |
69 | ```console
70 | pip install rpy2
71 | ```
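
With either backend available, the solver is chosen when testing for differential abundance. The following is a sketch under stated assumptions: `adata` must already have a neighbors graph, and `"sample"` and `"condition"` are placeholder `.obs` columns.

```python
import pertpy as pt

# Sketch: Milo differential abundance with an explicit solver choice.
# Assumes sc.pp.neighbors() was run; "sample" and "condition" are placeholders.
milo = pt.tl.Milo()
mdata = milo.load(adata)
milo.make_nhoods(mdata["rna"])
mdata = milo.count_nhoods(mdata, sample_col="sample")
milo.da_nhoods(mdata, design="~condition", solver="pydeseq2")  # or solver="edger"
```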
72 |
73 | ## Citation
74 |
75 | ```bibtex
76 | @article {Heumos2024.08.04.606516,
77 | author = {Heumos, Lukas and Ji, Yuge and May, Lilly and Green, Tessa and Zhang, Xinyue and Wu, Xichen and Ostner, Johannes and Peidli, Stefan and Schumacher, Antonia and Hrovatin, Karin and Müller, Michaela and Chong, Faye and Sturm, Gregor and Tejada, Alejandro and Dann, Emma and Dong, Mingze and Bahrami, Mojtaba and Gold, Ilan and Rybakov, Sergei and Namsaraeva, Altana and Moinfar, Amir and Zheng, Zihe and Roellin, Eljas and Mekki, Isra and Sander, Chris and Lotfollahi, Mohammad and Schiller, Herbert B. and Theis, Fabian J.},
78 | title = {Pertpy: an end-to-end framework for perturbation analysis},
79 | elocation-id = {2024.08.04.606516},
80 | year = {2024},
81 | doi = {10.1101/2024.08.04.606516},
82 | publisher = {Cold Spring Harbor Laboratory},
83 | URL = {https://www.biorxiv.org/content/early/2024/08/07/2024.08.04.606516},
84 | eprint = {https://www.biorxiv.org/content/early/2024/08/07/2024.08.04.606516.full.pdf},
85 | journal = {bioRxiv}
86 | }
87 | ```
88 |
89 | [pip]: https://pip.pypa.io/
90 | [pypi]: https://pypi.org/
91 | [api]: https://pertpy.readthedocs.io/en/latest/api.html
92 | [conda-forge]: https://anaconda.org/conda-forge/pertpy
93 | [//]: # "numfocus-fiscal-sponsor-attribution"
94 |
95 | pertpy is part of the scverse® project ([website](https://scverse.org), [governance](https://scverse.org/about/roles)) and is fiscally sponsored by [NumFOCUS](https://numfocus.org/).
96 | If you like scverse® and want to support our mission, please consider making a tax-deductible [donation](https://numfocus.org/donate-to-scverse) to help the project pay for developer time, professional services, travel, workshops, and a variety of other needs.
97 |
98 |
106 |
--------------------------------------------------------------------------------
/biome.jsonc:
--------------------------------------------------------------------------------
1 | {
2 | "$schema": "https://biomejs.dev/schemas/1.9.4/schema.json",
3 | "formatter": { "useEditorconfig": true },
4 | "overrides": [
5 | {
6 | "include": ["./.vscode/*.json", "**/*.jsonc", "**/asv.conf.json"],
7 | "json": {
8 | "formatter": {
9 | "trailingCommas": "all",
10 | },
11 | "parser": {
12 | "allowComments": true,
13 | "allowTrailingCommas": true,
14 | },
15 | },
16 | },
17 | ],
18 | }
19 |
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | codecov:
2 | require_ci_to_pass: no
3 |
4 | coverage:
5 | status:
6 | project:
7 | default:
8 | # Require 1% coverage, i.e., succeed as long as coverage collection works
9 | target: 1
10 | patch: false
11 | changes: false
12 |
13 | comment:
14 | layout: diff, flags, files
15 | behavior: once
16 | require_base: no
17 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = python -msphinx
7 | SPHINXPROJ = pertpy
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/_ext/edit_on_github.py:
--------------------------------------------------------------------------------
1 | """Based on gist.github.com/MantasVaitkunas/7c16de233812adcb7028."""
2 |
3 | import os
4 | from typing import Any
5 |
6 | from sphinx.application import Sphinx
7 |
8 | __licence__ = "BSD (3 clause)"
9 |
10 |
11 | def get_github_repo(app: Sphinx, path: str) -> str: # noqa: D103
12 | if path.endswith(".ipynb"):
13 | return str(app.config.github_nb_repo)
14 | if "auto_examples" in path:
15 | return str(app.config.github_nb_repo)
16 | if "auto_tutorials" in path:
17 | return str(app.config.github_nb_repo)
18 | return str(app.config.github_repo)
19 |
20 |
21 | def _html_page_context(
22 | app: Sphinx, _pagename: str, templatename: str, context: dict[str, Any], doctree: Any | None
23 | ) -> None:
24 | # doctree is None - otherwise viewcode fails
25 | if templatename != "page.html" or doctree is None:
26 | return
27 |
28 | if not app.config.github_repo:
29 | return
30 |
31 | if not app.config.github_nb_repo:
32 | nb_repo = f"{app.config.github_repo}_notebooks"
33 | app.config.github_nb_repo = nb_repo
34 |
35 | path = os.path.relpath(doctree.get("source"), app.builder.srcdir)
36 | repo = get_github_repo(app, path)
37 |
38 | # For sphinx_rtd_theme.
39 | context["display_github"] = True
40 | context["github_user"] = "scverse"
41 | context["github_version"] = "master"
42 | context["github_repo"] = repo
43 | context["conf_py_path"] = "/docs/source/"
44 |
45 |
46 | def setup(app: Sphinx) -> None: # noqa: D103
47 | app.add_config_value("github_nb_repo", "", True)
48 | app.add_config_value("github_repo", "", True)
49 | app.connect("html-page-context", _html_page_context)
50 |
--------------------------------------------------------------------------------
/docs/_ext/typed_returns.py:
--------------------------------------------------------------------------------
1 | import re
2 | from collections.abc import Iterable, Iterator
3 |
4 | from sphinx.application import Sphinx
5 | from sphinx.ext.napoleon import NumpyDocstring
6 |
7 |
8 | def _process_return(lines: Iterable[str]) -> Iterator[str]:
9 |     for line in lines:
10 |         m = re.fullmatch(r"(?P<param>\w+)\s+:\s+(?P<type>[\w.]+)", line)
11 |         if m:
12 |             # Once this is in scanpydoc, we can use the fancy hover stuff
13 |             yield f"**{m['param']}** : :class:`~{m['type']}`"
14 |         else:
15 |             yield line
16 |
17 |
18 | def _parse_returns_section(self: NumpyDocstring, section: str) -> list[str]:
19 |     lines_raw = list(_process_return(self._dedent(self._consume_to_next_section())))
20 |     lines: list[str] = self._format_block(":returns: ", lines_raw)
21 |     if lines and lines[-1]:
22 |         lines.append("")
23 |     return lines
24 |
25 |
26 | def setup(app: Sphinx) -> None:  # noqa: D103
27 |     NumpyDocstring._parse_returns_section = _parse_returns_section
28 |
--------------------------------------------------------------------------------
/docs/_static/SCVI_LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2020 Romain Lopez, Adam Gayoso, Galen Xing, Yosef Lab
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/docs/_static/css/overwrite.css:
--------------------------------------------------------------------------------
1 | /* for the sphinx design cards */
2 | body {
3 | --sd-color-shadow: dimgrey;
4 | }
5 |
6 | dt:target,
7 | span.highlighted {
8 | background-color: #f0f0f0;
9 | }
10 |
11 | dl.citation > dt {
12 | float: left;
13 | margin-right: 15px;
14 | font-weight: bold;
15 | }
16 |
17 | /* Normalize parameter font size and capitalization */
18 | dl .field-list dt {
19 | font-size: var(--font-size--normal) !important;
20 | text-transform: none !important;
21 | }
22 |
23 | /* examples and headings in classes */
24 | p.rubric {
25 | font-size: var(--font-size--normal);
26 | text-transform: none;
27 | font-weight: 500;
28 | }
29 |
30 | /* adapted from https://github.com/dask/dask-sphinx-theme/blob/main/dask_sphinx_theme/static/css/nbsphinx.css */
31 |
32 | .nbinput .prompt,
33 | .nboutput .prompt {
34 | display: none;
35 | }
36 | .nboutput .stderr {
37 | display: none;
38 | }
39 |
40 | div.nblast.container {
41 | padding-bottom: 10px !important;
42 | padding-right: 0px;
43 | padding-left: 0px;
44 | }
45 |
46 | div.nbinput.container {
47 | padding-top: 10px !important;
48 | padding-right: 0px;
49 | padding-left: 0px;
50 | }
51 |
52 | div.nbinput.container div.input_area div[class*="highlight"] > pre {
53 | padding: 10px !important;
54 | margin: 0;
55 | }
56 |
57 | p.topic-title {
58 | margin-top: 0;
59 | }
60 |
61 | /* so that api methods are small in sidebar */
62 | li.toctree-l3 {
63 | font-size: 81.25% !important;
64 | }
65 | li.toctree-l4 {
66 | font-size: 75% !important;
67 | }
68 |
69 | .bd-sidebar .caption-text {
70 | color: #e63946;
71 | font-weight: 600;
72 | text-transform: uppercase;
73 | }
74 |
--------------------------------------------------------------------------------
/docs/_static/css/sphinx_gallery.css:
--------------------------------------------------------------------------------
1 | .sphx-glr-thumbcontainer {
2 | background: inherit !important;
3 | min-height: 250px !important;
4 | margin: 10px !important;
5 | }
6 |
7 | .sphx-glr-thumbcontainer .headerlink {
8 | display: none !important;
9 | }
10 |
11 | div.sphx-glr-thumbcontainer span {
12 | font-style: normal !important;
13 | }
14 |
15 | .sphx-glr-thumbcontainer a.internal {
16 | padding: 140px 10px 0 !important;
17 | }
18 |
19 | .sphx-glr-thumbcontainer .figure {
20 | width: 200px !important;
21 | }
22 |
23 | .sphx-glr-thumbcontainer .figure.align-center {
24 | text-align: center;
25 | margin-left: 0%;
26 | transform: translate(0%);
27 | }
28 |
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/augur_dp_scatter.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/augur_dp_scatter.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/augur_important_features.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/augur_important_features.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/augur_lollipop.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/augur_lollipop.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/augur_scatterplot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/augur_scatterplot.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/de_fold_change.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/de_fold_change.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/de_multicomparison_fc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/de_multicomparison_fc.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/de_paired_expression.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/de_paired_expression.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/de_volcano.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/de_volcano.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/dialogue_pairplot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/dialogue_pairplot.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/dialogue_violin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/dialogue_violin.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/enrichment_dotplot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/enrichment_dotplot.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/enrichment_gsea.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/enrichment_gsea.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/milo_da_beeswarm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/milo_da_beeswarm.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/milo_nhood.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/milo_nhood.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/milo_nhood_graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/milo_nhood_graph.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/mixscape_barplot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/mixscape_barplot.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/mixscape_heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/mixscape_heatmap.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/mixscape_lda.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/mixscape_lda.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/mixscape_perturbscore.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/mixscape_perturbscore.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/mixscape_violin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/mixscape_violin.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/pseudobulk_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/pseudobulk_samples.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/sccoda_boxplots.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/sccoda_boxplots.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/sccoda_effects_barplot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/sccoda_effects_barplot.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/sccoda_rel_abundance_dispersion_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/sccoda_rel_abundance_dispersion_plot.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/sccoda_stacked_barplot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/sccoda_stacked_barplot.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/scgen_reg_mean.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/scgen_reg_mean.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/tasccoda_draw_effects.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/tasccoda_draw_effects.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/tasccoda_draw_tree.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/tasccoda_draw_tree.png
--------------------------------------------------------------------------------
/docs/_static/docstring_previews/tasccoda_effects_umap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/docstring_previews/tasccoda_effects_umap.png
--------------------------------------------------------------------------------
/docs/_static/icons/code-24px.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/docs/_static/icons/computer-24px.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/docs/_static/icons/library_books-24px.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/docs/_static/icons/play_circle_outline-24px.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/docs/_static/pertpy_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/pertpy_logo.png
--------------------------------------------------------------------------------
/docs/_static/pertpy_logo.svg:
--------------------------------------------------------------------------------
1 |
5 |
--------------------------------------------------------------------------------
/docs/_static/placeholder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/placeholder.png
--------------------------------------------------------------------------------
/docs/_static/tutorials/augur.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/augur.png
--------------------------------------------------------------------------------
/docs/_static/tutorials/cinemaot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/cinemaot.png
--------------------------------------------------------------------------------
/docs/_static/tutorials/dge.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/dge.png
--------------------------------------------------------------------------------
/docs/_static/tutorials/dialogue.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/dialogue.png
--------------------------------------------------------------------------------
/docs/_static/tutorials/distances.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/distances.png
--------------------------------------------------------------------------------
/docs/_static/tutorials/distances_tests.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/distances_tests.png
--------------------------------------------------------------------------------
/docs/_static/tutorials/enrichment.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/enrichment.png
--------------------------------------------------------------------------------
/docs/_static/tutorials/guide_rna_assignment.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/guide_rna_assignment.png
--------------------------------------------------------------------------------
/docs/_static/tutorials/mcfarland.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/mcfarland.png
--------------------------------------------------------------------------------
/docs/_static/tutorials/metadata.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/metadata.png
--------------------------------------------------------------------------------
/docs/_static/tutorials/milo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/milo.png
--------------------------------------------------------------------------------
/docs/_static/tutorials/mixscape.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/mixscape.png
--------------------------------------------------------------------------------
/docs/_static/tutorials/norman.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/norman.png
--------------------------------------------------------------------------------
/docs/_static/tutorials/ontology.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/ontology.png
--------------------------------------------------------------------------------
/docs/_static/tutorials/perturbation_space.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/perturbation_space.png
--------------------------------------------------------------------------------
/docs/_static/tutorials/placeholder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/placeholder.png
--------------------------------------------------------------------------------
/docs/_static/tutorials/sccoda.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/sccoda.png
--------------------------------------------------------------------------------
/docs/_static/tutorials/sccoda_extended.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/sccoda_extended.png
--------------------------------------------------------------------------------
/docs/_static/tutorials/scgen_perturbation_prediction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/scgen_perturbation_prediction.png
--------------------------------------------------------------------------------
/docs/_static/tutorials/tasccoda.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/tasccoda.png
--------------------------------------------------------------------------------
/docs/_static/tutorials/zhang.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/_static/tutorials/zhang.png
--------------------------------------------------------------------------------
/docs/_templates/autosummary/class.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. add toctree option to make autodoc generate the pages
6 |
7 | .. autoclass:: {{ objname }}
8 |
9 | {% block attributes %}
10 | {% if attributes %}
11 | Attributes table
12 | ~~~~~~~~~~~~~~~~
13 |
14 | .. autosummary::
15 | {% for item in attributes %}
16 | ~{{ name }}.{{ item }}
17 | {%- endfor %}
18 | {% endif %}
19 | {% endblock %}
20 |
21 | {% block methods %}
22 | {% if methods %}
23 | Methods table
24 | ~~~~~~~~~~~~~
25 |
26 | .. autosummary::
27 | {% for item in methods %}
28 | {%- if item != '__init__' %}
29 | ~{{ name }}.{{ item }}
30 | {%- endif -%}
31 | {%- endfor %}
32 | {% endif %}
33 | {% endblock %}
34 |
35 | {% block attributes_documentation %}
36 | {% if attributes %}
37 | Attributes
38 | ~~~~~~~~~~
39 |
40 | {% for item in attributes %}
41 |
42 | .. autoattribute:: {{ [objname, item] | join(".") }}
43 | {%- endfor %}
44 |
45 | {% endif %}
46 | {% endblock %}
47 |
48 | {% block methods_documentation %}
49 | {% if methods %}
50 | Methods
51 | ~~~~~~~
52 |
53 | {% for item in methods %}
54 | {%- if item != '__init__' %}
55 |
56 | .. automethod:: {{ [objname, item] | join(".") }}
57 | {%- endif -%}
58 | {%- endfor %}
59 |
60 | {% endif %}
61 | {% endblock %}
62 |
--------------------------------------------------------------------------------
/docs/about/background.md:
--------------------------------------------------------------------------------
1 | # About Pertpy
2 |
3 | Pertpy is an end-to-end framework for the analysis of large-scale single-cell perturbation experiments.
4 | It provides access to harmonized perturbation datasets and metadata databases, along with fast, user-friendly implementations of established and novel methods, such as automatic metadata annotation and perturbation distances, for efficient analysis of perturbation data.
5 | As part of the scverse ecosystem, pertpy interoperates with existing single-cell analysis libraries and is designed to be easily extended.
6 | If you find pertpy useful for your research, please check out {doc}`cite`.
7 |
8 | ## Design principles
9 |
10 | Our framework is based on three key principles: `Modularity`, `Flexibility`, and `Scalability`.
11 |
12 | ### Modularity
13 |
14 | Pertpy includes modules for the analysis of single and combinatorial perturbations, covering diverse types of perturbation data such as genetic knockouts, drug screens, and disease states.
15 | The framework offers more than 100 composable and interoperable analysis functions, organized in modules that also ease downstream interpretation and visualization.
16 | These modules provide fundamental building blocks and methods that share functionality and can be chained into custom pipelines.
17 |
18 | A typical Pertpy workflow consists of several steps:
19 |
20 | * Initial **data transformation** such as guide RNA assignment for CRISPR screens
21 | * **Quality control** to address confounding factors and technical variation
22 | * **Metadata annotation** against ontologies and enrichment from databases
23 | * **Perturbation space analysis** to learn biologically interpretable embeddings
24 | * **Downstream analysis** including differential expression, compositional analysis, and distance calculation
25 |
26 | This modular approach yields a powerful and flexible framework, as many analysis steps can be applied independently or chained together (see the minimal example at the end of this page).
27 |
28 | ### Flexibility
29 |
30 | Pertpy is purpose-built to organize, analyze, and visualize complex perturbation datasets.
31 | It is flexible and can be applied to datasets of different assays, data types, sizes, and perturbations, thereby unifying previous data-type- or assay-specific single-problem approaches.
32 | Designed to integrate external metadata with measured data, it enables unprecedented contextualization of results through swiftly built, experiment-specific pipelines, leading to more robust outcomes.
33 |
34 | The inputs to a typical analysis with pertpy are unimodal scRNA-seq or multimodal perturbation readouts stored in AnnData or MuData objects.
35 | While pertpy is primarily designed to explore perturbations such as genetic modifications, drug treatments, exposure to pathogens, and other environmental conditions, its utility extends to various other perturbation settings, including diverse disease states where experimental perturbations have not been applied.
36 |
37 | ### Scalability
38 |
39 | Pertpy addresses a wide array of use cases and growing datasets through sparse, memory-efficient implementations that leverage JAX and Numba for parallelization and GPU acceleration, making them substantially faster than the original implementations.
40 | The framework can be applied to datasets ranging from thousands to millions of cells.
41 |
42 | For example, when analyzing CRISPR screens, Pertpy's implementation of Mixscape is optimized using PyNNDescent for nearest neighbor search during the calculation of perturbation signatures.
43 | Other methods such as scCODA and tascCODA are accelerated by replacing the Hamiltonian Monte Carlo algorithm in TensorFlow with the no-U-turn sampler from numpyro.
44 | CINEMA-OT is optimized with ott-jax to make the implementation portable across hardware, enabling GPU acceleration.
45 |
46 | ## Why is it called "Pertpy"?
47 |
48 | Pertpy is named for its core purpose: the analysis of **pert**urbations in **Py**thon.
49 | The framework unifies perturbation analysis approaches across different data types and experimental designs, providing a comprehensive solution for understanding cellular responses to various stimuli.
50 |
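51 | ## A minimal example
52 |
53 | As a small sketch of the modularity described above, a dataset loader can be chained directly with a downstream tool. The example dataset and the `perturbation` grouping column below follow the distances tutorial and are purely illustrative:
54 |
55 | ```python
56 | import pertpy as pt
57 |
58 | # Load a small example dataset shipped with pertpy
59 | adata = pt.dt.distance_example()
60 |
61 | # Compute pairwise E-distances between perturbation groups
62 | distance = pt.tl.Distance(metric="edistance")
63 | pairwise_df = distance.pairwise(adata, groupby="perturbation")
64 | ```
65 |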
--------------------------------------------------------------------------------
/docs/about/cite.md:
--------------------------------------------------------------------------------
1 | # Citing pertpy
2 |
3 | If you find pertpy useful for your research, please consider citing our work as follows:
4 |
5 | ```bibtex
6 | @article {Heumos2024.08.04.606516,
7 | author = {Heumos, Lukas and Ji, Yuge and May, Lilly and Green, Tessa and Zhang, Xinyue and Wu, Xichen and Ostner, Johannes and Peidli, Stefan and Schumacher, Antonia and Hrovatin, Karin and Müller, Michaela and Chong, Faye and Sturm, Gregor and Tejada, Alejandro and Dann, Emma and Dong, Mingze and Bahrami, Mojtaba and Gold, Ilan and Rybakov, Sergei and Namsaraeva, Altana and Moinfar, Amir and Zheng, Zihe and Roellin, Eljas and Mekki, Isra and Sander, Chris and Lotfollahi, Mohammad and Schiller, Herbert B. and Theis, Fabian J.},
8 | title = {Pertpy: an end-to-end framework for perturbation analysis},
9 | elocation-id = {2024.08.04.606516},
10 | year = {2024},
11 | doi = {10.1101/2024.08.04.606516},
12 | publisher = {Cold Spring Harbor Laboratory},
13 | URL = {https://www.biorxiv.org/content/early/2024/08/07/2024.08.04.606516},
14 | eprint = {https://www.biorxiv.org/content/early/2024/08/07/2024.08.04.606516.full.pdf},
15 | journal = {bioRxiv}
16 | }
17 | ```
18 |
19 | If you are using any previously published tool, please also cite the original publication.
20 | All tool specific references can be found here: {doc}`../references`.
21 |
--------------------------------------------------------------------------------
/docs/api.md:
--------------------------------------------------------------------------------
1 | # API
2 |
3 | Import the pertpy API as follows:
4 |
5 | ```python
6 | import pertpy as pt
7 | ```
8 |
9 | You can then access the respective modules like:
10 |
11 | ```python
12 | pt.tl.cool_fancy_tool()
13 | ```
14 |
15 | ```{toctree}
16 | :maxdepth: 1
17 |
18 | api/datasets_index
19 | api/preprocessing_index
20 | api/tools_index
21 | api/metadata_index
22 | ```
23 |
--------------------------------------------------------------------------------
/docs/api/datasets_index.md:
--------------------------------------------------------------------------------
1 | ```{eval-rst}
2 | .. currentmodule:: pertpy
3 | ```
4 |
5 | # Datasets
6 |
7 | pertpy provides access to a collection of curated single-cell datasets spanning many types of perturbations.
8 | Many of the datasets originate from [scperturb](http://projects.sanderlab.org/scperturb/) {cite}`Peidli2024`.
9 |
10 | ```{eval-rst}
11 | .. autosummary::
12 | :toctree: data
13 |
14 | data.adamson_2016_pilot
15 | data.adamson_2016_upr_epistasis
16 | data.adamson_2016_upr_perturb_seq
17 | data.aissa_2021
18 | data.bhattacherjee
19 | data.burczynski_crohn
20 | data.chang_2021
21 | data.cinemaot_example
22 | data.combosciplex
23 | data.datlinger_2017
24 | data.datlinger_2021
25 | data.dialogue_example
26 | data.distance_example
27 | data.dixit_2016
28 | data.dixit_2016_raw
29 | data.dong_2023
30 | data.frangieh_2021
31 | data.frangieh_2021_protein
32 | data.frangieh_2021_raw
33 | data.frangieh_2021_rna
34 | data.gasperini_2019_atscale
35 | data.gasperini_2019_highmoi
36 | data.gasperini_2019_lowmoi
37 | data.gehring_2019
38 | data.haber_2017_regions
39 | data.hagai_2018
40 | data.kang_2018
41 | data.mcfarland_2020
42 | data.norman_2019
43 | data.norman_2019_raw
44 | data.papalexi_2021
45 | data.replogle_2022_k562_essential
46 | data.replogle_2022_k562_gwps
47 | data.replogle_2022_rpe1
48 | data.sc_sim_augur
49 | data.schiebinger_2019_16day
50 | data.schiebinger_2019_18day
51 | data.schraivogel_2020_tap_screen_chr8
52 | data.schraivogel_2020_tap_screen_chr11
53 | data.sciplex_gxe1
54 | data.sciplex3_raw
55 | data.shifrut_2018
56 | data.smillie_2019
57 | data.srivatsan_2020_sciplex2
58 | data.srivatsan_2020_sciplex3
59 | data.srivatsan_2020_sciplex4
60 | data.stephenson_2021_subsampled
61 | data.tasccoda_example
62 | data.tian_2019_day7neuron
63 | data.tian_2019_ipsc
64 | data.tian_2021_crispra
65 | data.tian_2021_crispri
66 | data.weinreb_2020
67 | data.xie_2017
68 | data.zhang_2021
69 | data.zhao_2021
70 | ```
71 |
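72 | Each loader downloads the dataset on first use, caches it locally, and returns it ready for analysis. For example:
73 |
74 | ```python
75 | import pertpy as pt
76 |
77 | # Downloads and caches the dataset on the first call, then loads from cache
78 | adata = pt.dt.norman_2019()
79 | ```
80 |
81 | Most loaders return an `AnnData` object; multimodal datasets such as `papalexi_2021` return a `MuData` object instead.
82 |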
--------------------------------------------------------------------------------
/docs/api/metadata_index.md:
--------------------------------------------------------------------------------
1 | ```{eval-rst}
2 | .. currentmodule:: pertpy
3 | ```
4 |
5 | # Metadata
6 |
7 | The metadata module provides tooling to annotate perturbations by querying databases.
8 | Such metadata can aid with the development of biologically informed models and can be used for enrichment tests.
9 |
10 | ## Cell line
11 |
12 | This module allows for the retrieval of various types of information related to cell lines,
13 | including cell line annotation, bulk RNA and protein expression data.
14 |
15 | Available databases for cell line metadata:
16 |
17 | - [The Cancer Dependency Map Project at Broad](https://depmap.org/portal/)
18 | - [The Cancer Dependency Map Project at Sanger](https://depmap.sanger.ac.uk/)
19 | - [Genomics of Drug Sensitivity in Cancer (GDSC)](https://www.cancerrxgene.org/)
20 |
21 | ## Compound
22 |
23 | The Compound module enables the retrieval of various types of information related to compounds of interest, including the most common synonym, PubChem ID, and canonical SMILES.
24 |
25 | Available databases for compound metadata:
26 |
27 | - [PubChem](https://pubchem.ncbi.nlm.nih.gov/)
28 |
29 | ## Mechanism of Action
30 |
31 | This module retrieves mechanism of action metadata for perturbagens of interest, optionally matched against their molecular targets.
32 |
33 | Available databases for mechanism of action metadata:
34 |
35 | - [CLUE](https://clue.io/)
36 |
37 | ## Drug
38 |
39 | This module allows for the retrieval of drug target information.
40 |
41 | Available databases for drug metadata:
42 |
43 | - [ChEMBL](https://www.ebi.ac.uk/chembl/)
44 |
45 | ```{eval-rst}
46 | .. autosummary::
47 | :toctree: metadata
48 | :recursive:
49 |
50 | metadata.CellLine
51 | metadata.Compound
52 | metadata.Moa
53 | metadata.Drug
54 | metadata.LookUp
55 | ```
56 |
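57 | Example usage, as a minimal sketch: it assumes that `adata.obs` contains a `perturbation` column holding compound names, which is the default `query_id` of the annotation functions:
58 |
59 | ```python
60 | import pertpy as pt
61 |
62 | adata = pt.dt.mcfarland_2020()
63 |
64 | # Annotate each perturbagen with its mechanism of action and targets from CLUE
65 | moa = pt.md.Moa()
66 | adata = moa.annotate(adata, query_id="perturbation", copy=True)
67 | ```
68 |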
--------------------------------------------------------------------------------
/docs/api/preprocessing_index.md:
--------------------------------------------------------------------------------
1 | ```{eval-rst}
2 | .. currentmodule:: pertpy
3 | ```
4 |
5 | # Preprocessing
6 |
7 | ## Guide Assignment
8 |
9 | Guide assignment is essential for quality control in single-cell Perturb-seq data, ensuring accurate mapping of guide RNAs to cells for reliable interpretation of gene perturbation effects.
10 | pertpy provides functionality to assign guides based on a threshold or a Gaussian mixture model {cite}`Replogle2022`.
11 |
12 | ```{eval-rst}
13 | .. autosummary::
14 | :toctree: preprocessing
15 | :nosignatures:
16 |
17 | preprocessing.GuideAssignment
18 | ```
19 |
20 | Example implementation:
21 |
22 | ```python
23 | import pertpy as pt
24 | import scanpy as sc
25 |
26 | mdata = pt.dt.papalexi_2021()
27 | gdo = mdata.mod["gdo"]
28 | gdo.layers["counts"] = gdo.X.copy()
29 | sc.pp.log1p(gdo)
30 |
31 | ga = pt.pp.GuideAssignment()
32 | ga.assign_by_threshold(gdo, 5, layer="counts", output_layer="assigned_guides")
33 |
34 | ga.plot_heatmap(gdo, layer="assigned_guides")
35 | ```
36 |
37 | See [guide assignment tutorial](https://pertpy.readthedocs.io/en/latest/tutorials/notebooks/guide_rna_assignment.html).
38 |
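39 | Assignment via the mixture model follows the same pattern. A sketch, assuming the mixture-based assigner is exposed as `assign_mixture_model` (see the `GuideAssignment` API reference above for the exact name and signature):
40 |
41 | ```python
42 | ga = pt.pp.GuideAssignment()
43 | # Fit a mixture model to the guide counts of each cell and assign guides accordingly
44 | ga.assign_mixture_model(gdo)
45 | ```
46 |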
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # mypy: ignore-errors
3 |
4 | import sys
5 | from datetime import datetime
6 | from importlib.metadata import metadata
7 | from pathlib import Path
8 |
9 | HERE = Path(__file__).parent
10 | sys.path[:0] = [str(HERE.parent), str(HERE / "extensions")]
11 |
12 | needs_sphinx = "4.3"
13 |
14 | info = metadata("pertpy")
15 | project_name = info["Name"]
16 | author = info["Author"]
17 | copyright = f"{datetime.now():%Y}, {author}."
18 | version = info["Version"]
19 | urls = dict(pu.split(", ") for pu in info.get_all("Project-URL"))
20 | repository_url = urls["Source"]
21 | release = info["Version"]
22 | github_repo = "pertpy"
23 | master_doc = "index"
24 | language = "en"
25 |
26 | extensions = [
27 | "myst_nb",
28 | "sphinx.ext.autodoc",
29 | "sphinx.ext.intersphinx",
30 | "sphinx.ext.viewcode",
31 | "nbsphinx",
32 | "nbsphinx_link",
33 | "sphinx.ext.mathjax",
34 | "sphinx.ext.napoleon",
35 | "sphinx_autodoc_typehints", # needs to be after napoleon
36 | "sphinx.ext.autosummary",
37 | "sphinx_copybutton",
38 | "sphinx_gallery.load_style",
39 | "sphinx_remove_toctrees",
40 | "sphinx_design",
41 | "sphinx_issues",
42 | "sphinxcontrib.bibtex",
43 | "IPython.sphinxext.ipython_console_highlighting",
44 | ]
45 |
46 | ogp_site_url = "https://pertpy.readthedocs.io/en/latest/"
47 | ogp_image = "https://pertpy.readthedocs.io/en/latest/_static/pertpy_logo.png"
48 |
49 | # nbsphinx specific settings
50 | exclude_patterns = [
51 | "_build",
52 | "Thumbs.db",
53 | ".DS_Store",
54 | "auto_*/**.ipynb",
55 | "auto_*/**.md5",
56 | "auto_*/**.py",
57 | "**.ipynb_checkpoints",
58 | ]
59 | nbsphinx_execute = "never"
60 | pygments_style = "sphinx"
61 |
62 | templates_path = ["_templates"]
63 | bibtex_bibfiles = ["references.bib"]
64 | nitpicky = True # Warn about broken links
65 | source_suffix = {
66 | ".rst": "restructuredtext",
67 | ".ipynb": "myst-nb",
68 | ".myst": "myst-nb",
69 | }
70 |
71 | suppress_warnings = ["toc.not_included"]
72 |
73 | autosummary_generate = True
74 | autosummary_imported_members = True
75 | autodoc_member_order = "groupwise"
76 | napoleon_google_docstring = True
77 | napoleon_include_init_with_doc = False
78 | napoleon_use_rtype = True
79 | napoleon_use_param = True
80 | myst_heading_anchors = 6
81 | napoleon_custom_sections = [("Params", "Parameters")]
82 | todo_include_todos = False
83 | annotate_defaults = True
84 | myst_enable_extensions = [
85 | "amsmath",
86 | "colon_fence",
87 | "deflist",
88 | "dollarmath",
89 | "html_image",
90 | "html_admonition",
91 | ]
92 | myst_url_schemes = ("http", "https", "mailto")
93 | nb_execution_mode = "off"
94 | nb_merge_streams = True
95 | warn_as_error = True
96 |
97 | typehints_defaults = "comma"
98 |
99 | html_theme = "scanpydoc"
100 | html_title = "pertpy"
101 | html_logo = "_static/pertpy_logo.svg"
102 |
103 | html_theme_options = {}
104 |
105 | html_static_path = ["_static"]
106 | html_css_files = ["css/overwrite.css", "css/sphinx_gallery.css"]
107 | html_show_sphinx = False
108 |
109 | add_module_names = False
110 | autodoc_mock_imports = ["ete4"]
111 | intersphinx_mapping = {
112 | "anndata": ("https://anndata.readthedocs.io/en/stable/", None),
113 | "mudata": ("https://mudata.readthedocs.io/en/stable/", None),
114 | "scvi-tools": ("https://docs.scvi-tools.org/en/stable/", None),
115 | "ipython": ("https://ipython.readthedocs.io/en/stable/", None),
116 | "matplotlib": ("https://matplotlib.org/stable/", None),
117 | "numpy": ("https://numpy.org/doc/stable/", None),
118 | "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None),
119 | "python": ("https://docs.python.org/3", None),
120 | "scipy": ("https://docs.scipy.org/doc/scipy/", None),
121 | "torch": ("https://docs.pytorch.org/docs/main", None),
122 | "scanpy": ("https://scanpy.readthedocs.io/en/stable/", None),
123 | "pytorch_lightning": ("https://lightning.ai/docs/pytorch/stable/", None),
124 | "pyro": ("https://docs.pyro.ai/en/stable/", None),
125 | "pymde": ("https://pymde.org/", None),
126 | "flax": ("https://flax.readthedocs.io/en/latest/", None),
127 | "jax": ("https://docs.jax.dev/en/latest/", None),
128 | "ete": ("https://etetoolkit.org/docs/latest/", None),
129 | "arviz": ("https://python.arviz.org/en/stable/", None),
130 | "sklearn": ("https://scikit-learn.org/stable", None),
131 | "statsmodels": ("https://www.statsmodels.org/stable", None),
132 | }
133 | nitpick_ignore = [
134 | ("py:class", "ete4.core.tree.Tree"),
135 | ("py:class", "ete4.treeview.TreeStyle"),
136 | ("py:class", "pertpy.tools._distances._distances.MeanVar"),
137 | ("py:class", "The requested data as a NumPy array."),
138 | ("py:class", "The full registry saved with the model"),
139 | ("py:class", "The requested data."),
140 | ("py:class", "Model with loaded state dictionaries."),
141 | ("py:class", "pertpy.tools.lazy_import..Placeholder"),
142 | ]
143 |
144 | sphinx_gallery_conf = {"nested_sections": False}
145 | nbsphinx_thumbnails = {
146 | "tutorials/notebooks/guide_rna_assignment": "_static/tutorials/guide_rna_assignment.png",
147 | "tutorials/notebooks/mixscape": "_static/tutorials/mixscape.png",
148 | "tutorials/notebooks/augur": "_static/tutorials/augur.png",
149 | "tutorials/notebooks/sccoda": "_static/tutorials/sccoda.png",
150 | "tutorials/notebooks/sccoda_extended": "_static/tutorials/sccoda_extended.png",
151 | "tutorials/notebooks/tasccoda": "_static/tutorials/tasccoda.png",
152 | "tutorials/notebooks/milo": "_static/tutorials/milo.png",
153 | "tutorials/notebooks/dialogue": "_static/tutorials/dialogue.png",
154 | "tutorials/notebooks/enrichment": "_static/tutorials/enrichment.png",
155 | "tutorials/notebooks/distances": "_static/tutorials/distances.png",
156 | "tutorials/notebooks/distance_tests": "_static/tutorials/distances_tests.png",
157 | "tutorials/notebooks/cinemaot": "_static/tutorials/cinemaot.png",
158 | "tutorials/notebooks/scgen_perturbation_prediction": "_static/tutorials/scgen_perturbation_prediction.png",
159 | "tutorials/notebooks/perturbation_space": "_static/tutorials/perturbation_space.png",
160 | "tutorials/notebooks/differential_gene_expression": "_static/tutorials/dge.png",
161 | "tutorials/notebooks/metadata_annotation": "_static/tutorials/metadata.png",
162 | "tutorials/notebooks/ontology_mapping": "_static/tutorials/ontology.png",
163 | "tutorials/notebooks/norman_use_case": "_static/tutorials/norman.png",
164 | "tutorials/notebooks/mcfarland_use_case": "_static/tutorials/mcfarland.png",
165 | "tutorials/notebooks/zhang_use_case": "_static/tutorials/zhang.png",
166 | }
167 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # pertpy
2 |
3 | Pertpy is a scverse ecosystem framework for analyzing large-scale single-cell perturbation experiments.
4 | It provides tools for harmonizing perturbation datasets, automating metadata annotation, calculating perturbation distances, and efficiently analyzing how cells respond to various stimuli like genetic modifications, drug treatments, and environmental changes.
5 |
6 | 
7 |
8 | ```{eval-rst}
9 | .. card:: Installation :octicon:`plug;1em;`
10 | :link: installation
11 | :link-type: doc
12 |
13 | New to *pertpy*? Check out the installation guide.
14 | ```
15 |
16 | ```{eval-rst}
17 | .. card:: API reference :octicon:`book;1em;`
18 | :link: api
19 | :link-type: doc
20 |
21 | The API reference contains a detailed description of the pertpy API.
22 | ```
23 |
24 | ```{eval-rst}
25 | .. card:: Tutorials :octicon:`play;1em;`
26 | :link: tutorials
27 | :link-type: doc
28 |
29 | The tutorials walk you through real-world applications of pertpy.
30 | ```
31 |
32 | ```{eval-rst}
33 | .. card:: Discussion :octicon:`megaphone;1em;`
34 | :link: https://discourse.scverse.org/
35 |
36 | Need help? Reach out on our forum to get your questions answered!
37 |
38 | ```
39 |
40 | ```{eval-rst}
41 | .. card:: GitHub :octicon:`mark-github;1em;`
42 | :link: https://github.com/scverse/pertpy
43 |
44 | Found a bug? Interested in improving pertpy? Check out our GitHub for the latest developments.
45 |
46 | ```
47 |
48 | ```{toctree}
49 | :caption: 'General'
50 | :hidden: true
51 | :maxdepth: 2
52 |
53 | installation
54 | api
55 | contributing
56 | changelog
57 | references
58 | ```
59 |
60 | ```{toctree}
61 | :caption: 'Gallery'
62 | :hidden: true
63 | :maxdepth: 3
64 |
65 | tutorials
66 | usecases
67 | ```
68 |
69 | ```{toctree}
70 | :caption: 'About'
71 | :hidden: true
72 | :maxdepth: 2
73 |
74 | about/background
75 | about/cite
76 | GitHub <https://github.com/scverse/pertpy>
77 | Discourse <https://discourse.scverse.org/>
78 | ```
79 |
80 | ## Citation
81 |
82 | ```bibtex
83 | @article {Heumos2024.08.04.606516,
84 | author = {Heumos, Lukas and Ji, Yuge and May, Lilly and Green, Tessa and Zhang, Xinyue and Wu, Xichen and Ostner, Johannes and Peidli, Stefan and Schumacher, Antonia and Hrovatin, Karin and Müller, Michaela and Chong, Faye and Sturm, Gregor and Tejada, Alejandro and Dann, Emma and Dong, Mingze and Bahrami, Mojtaba and Gold, Ilan and Rybakov, Sergei and Namsaraeva, Altana and Moinfar, Amir and Zheng, Zihe and Roellin, Eljas and Mekki, Isra and Sander, Chris and Lotfollahi, Mohammad and Schiller, Herbert B. and Theis, Fabian J.},
85 | title = {Pertpy: an end-to-end framework for perturbation analysis},
86 | elocation-id = {2024.08.04.606516},
87 | year = {2024},
88 | doi = {10.1101/2024.08.04.606516},
89 | publisher = {Cold Spring Harbor Laboratory},
90 | URL = {https://www.biorxiv.org/content/early/2024/08/07/2024.08.04.606516},
91 | eprint = {https://www.biorxiv.org/content/early/2024/08/07/2024.08.04.606516.full.pdf},
92 | journal = {bioRxiv}
93 | }
94 | ```
95 |
96 | ## NumFOCUS
97 |
98 | [//]: # "numfocus-fiscal-sponsor-attribution"
99 |
100 | pertpy is part of the scverse® project ([website](https://scverse.org), [governance](https://scverse.org/about/roles)) and is fiscally sponsored by [NumFOCUS](https://numfocus.org/).
101 | If you like scverse® and want to support our mission, please consider making a tax-deductible [donation](https://numfocus.org/donate-to-scverse) to help the project pay for developer time, professional services, travel, workshops, and a variety of other needs.
102 |
103 |
111 |
--------------------------------------------------------------------------------
/docs/installation.md:
--------------------------------------------------------------------------------
1 | ```{highlight} shell
2 |
3 | ```
4 |
5 | # Installation
6 |
7 | ## Stable release
8 |
9 | ### PyPI
10 |
11 | To install pertpy, run this command in your terminal:
12 |
13 | ```console
14 | pip install pertpy
15 | ```
16 |
17 | This is the preferred method to install pertpy, as it will always install the most recent stable release.
18 | If you don't have [pip] installed, this [Python installation guide] can guide you through the process.
19 |
20 | ### conda-forge
21 |
22 | Alternatively, you can install pertpy from [conda-forge]:
23 |
24 | ```console
25 | conda install -c conda-forge pertpy
26 | ```
27 |
28 | ### Additional dependency groups
29 |
30 | #### Differential gene expression interface
31 |
32 | The DGE interface of pertpy requires additional dependencies that can be installed by running:
33 |
34 | ```console
35 | pip install pertpy[de]
36 | ```
37 |
38 | Note that using edgeR from pertpy requires the R package edgeR as well as rpy2 to be installed:
39 |
40 | ```R
41 | BiocManager::install("edgeR")
42 | ```
43 |
44 | ```console
45 | pip install rpy2
46 | ```
47 |
48 | #### milo
49 |
50 | milo requires either the "de" extra for the "pydeseq2" solver:
51 |
52 | ```console
53 | pip install 'pertpy[de]'
54 | ```
55 |
56 | or edgeR, statmod, and rpy2 for the "edger" solver:
57 |
58 | ```R
59 | BiocManager::install("edgeR")
60 | BiocManager::install("statmod")
61 | ```
62 |
63 | ```console
64 | pip install rpy2
65 | ```
66 |
67 | #### tascCODA
68 |
69 | TascCODA requires an additional set of dependencies (ete4, pyqt6, and toytree) that can be installed by running:
70 |
71 | ```console
72 | pip install pertpy[tcoda]
73 | ```
74 |
75 | ## From sources
76 |
77 | The sources for pertpy can be downloaded from the [GitHub repo].
78 |
79 | You can either clone the public repository:
80 |
81 | ```console
82 | git clone https://github.com/scverse/pertpy.git
83 | ```
84 |
85 | Or download the [tarball]:
86 |
87 | ```console
88 | curl -OJL https://github.com/scverse/pertpy/tarball/main
89 | ```
90 |
91 | [github repo]: https://github.com/scverse/pertpy
92 | [pip]: https://pip.pypa.io
93 | [conda-forge]: https://anaconda.org/conda-forge/pertpy
94 | [python installation guide]: http://docs.python-guide.org/en/latest/starting/installation/
95 | [tarball]: https://github.com/scverse/pertpy/tarball/main
96 | [Homebrew]: https://brew.sh/
97 |
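98 | ## Verifying the installation
99 |
100 | As a quick check that pertpy was installed correctly, import it and print its version:
101 |
102 | ```console
103 | python -c "import pertpy; print(pertpy.__version__)"
104 | ```
105 |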
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=python -msphinx
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 | set SPHINXPROJ=pertpy
13 |
14 | if "%1" == "" goto help
15 |
16 | %SPHINXBUILD% >NUL 2>NUL
17 | if errorlevel 9009 (
18 | echo.
19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed,
20 | echo.then set the SPHINXBUILD environment variable to point to the full
21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the
22 | echo.Sphinx directory to PATH.
23 | echo.
24 | echo.If you don't have Sphinx installed, grab it from
25 | echo.http://sphinx-doc.org/
26 | exit /b 1
27 | )
28 |
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30 | goto end
31 |
32 | :help
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34 |
35 | :end
36 | popd
37 |
--------------------------------------------------------------------------------
/docs/references.md:
--------------------------------------------------------------------------------
1 | # References
2 |
3 | ```{bibliography}
4 | :cited:
5 | ```
6 |
--------------------------------------------------------------------------------
/docs/tutorials.md:
--------------------------------------------------------------------------------
1 | ---
2 | orphan: true
3 | ---
4 |
5 | # Tutorials
6 |
7 | The easiest way to get familiar with pertpy is to follow along with our tutorials.
8 |
9 | :::{note}
10 | For questions about the usage of pertpy, use the [scverse discourse](https://discourse.scverse.org/).
11 | :::
12 |
13 | ```{toctree}
14 | :maxdepth: 1
15 |
16 | tutorials/preprocessing.md
17 | tutorials/tools.md
18 | tutorials/metadata.md
19 | ```
20 |
--------------------------------------------------------------------------------
/docs/tutorials/metadata.md:
--------------------------------------------------------------------------------
1 | # Metadata
2 |
3 | ```{eval-rst}
4 | .. nbgallery::
5 |
6 | notebooks/metadata_annotation
7 | notebooks/ontology_mapping
8 | ```
9 |
--------------------------------------------------------------------------------
/docs/tutorials/preprocessing.md:
--------------------------------------------------------------------------------
1 | # Preprocessing
2 |
3 | ```{eval-rst}
4 | .. nbgallery::
5 |
6 | notebooks/guide_rna_assignment
7 | ```
8 |
--------------------------------------------------------------------------------
/docs/tutorials/tools.md:
--------------------------------------------------------------------------------
1 | # Tools
2 |
3 | ## Differential gene expression
4 |
5 | ```{eval-rst}
6 | .. nbgallery::
7 |
8 | notebooks/differential_gene_expression
9 | ```
10 |
11 | ## Pooled CRISPR screens
12 |
13 | ```{eval-rst}
14 | .. nbgallery::
15 |
16 | notebooks/mixscape
17 | ```
18 |
19 | ## Compositional analysis
20 |
21 | ```{eval-rst}
22 | .. nbgallery::
23 |
24 | notebooks/sccoda
25 | notebooks/sccoda_extended
26 | notebooks/tasccoda
27 | notebooks/milo
28 | ```
29 |
30 | ## Multicellular and gene programs
31 |
32 | ```{eval-rst}
33 | .. nbgallery::
34 |
35 | notebooks/dialogue
36 | notebooks/enrichment
37 | ```
38 |
39 | ## Distances and permutation tests
40 |
41 | ```{eval-rst}
42 | .. nbgallery::
43 |
44 | notebooks/distances
45 | notebooks/distance_tests
46 | ```
47 |
48 | ## Response prediction
49 |
50 | ```{eval-rst}
51 | .. nbgallery::
52 |
53 | notebooks/augur
54 | notebooks/cinemaot
55 | notebooks/scgen_perturbation_prediction
56 | ```
57 |
58 | ## Perturbation space
59 |
60 | ```{eval-rst}
61 | .. nbgallery::
62 |
63 | notebooks/perturbation_space
64 | ```
65 |
--------------------------------------------------------------------------------
/docs/usecases.md:
--------------------------------------------------------------------------------
1 | # Use cases
2 |
3 | Each use case showcases a variety of pertpy tools applied to a single dataset.
4 | They are designed to give you a sense of how to use pertpy in a real-world scenario.
5 | The use cases featured here are those we present in the [pertpy preprint](https://www.biorxiv.org/content/10.1101/2024.08.04.606516v1).
6 |
7 | ```{eval-rst}
8 | .. nbgallery::
9 |
10 | tutorials/notebooks/norman_use_case
11 | tutorials/notebooks/mcfarland_use_case
12 | tutorials/notebooks/zhang_use_case
13 | ```
14 |
--------------------------------------------------------------------------------
/docs/utils.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/docs/utils.py
--------------------------------------------------------------------------------
/pertpy/__init__.py:
--------------------------------------------------------------------------------
1 | """Top-level package for pertpy."""
2 |
3 | __author__ = "Lukas Heumos"
4 | __email__ = "lukas.heumos@posteo.net"
5 | __version__ = "1.0.0"
6 |
7 | import warnings
8 |
9 | from anndata._core.aligned_df import ImplicitModificationWarning
10 | from matplotlib import MatplotlibDeprecationWarning
11 | from numba import NumbaDeprecationWarning
12 |
13 | warnings.filterwarnings("ignore", category=NumbaDeprecationWarning)
14 | warnings.filterwarnings("ignore", category=MatplotlibDeprecationWarning)
15 | warnings.filterwarnings("ignore", category=SyntaxWarning)
16 | warnings.filterwarnings("ignore", category=UserWarning, module="scvi._settings")
17 | warnings.filterwarnings("ignore", message="Environment variable.*redefined by R")
18 | warnings.filterwarnings("ignore", message="Transforming to str index.", category=ImplicitModificationWarning)
19 |
20 | import mudata
21 |
22 | mudata.set_options(pull_on_update=False)
23 |
24 | from . import data as dt
25 | from . import metadata as md
26 | from . import plot as pl
27 | from . import preprocessing as pp
28 | from . import tools as tl
29 |
--------------------------------------------------------------------------------
/pertpy/_doc.py:
--------------------------------------------------------------------------------
1 | from textwrap import dedent
2 |
3 |
4 | def _doc_params(**kwds): # pragma: no cover
5 | r"""Docstrings should start with "\" in the first line for proper formatting."""
6 |
7 | def dec(obj):
8 | obj.__orig_doc__ = obj.__doc__
9 | obj.__doc__ = dedent(obj.__doc__.format_map(kwds))
10 | return obj
11 |
12 | return dec
13 |
14 |
15 | doc_common_plot_args = """\
16 | return_fig: if `True`, returns the figure of the plot, which can be used for saving.\
17 | """
18 |
--------------------------------------------------------------------------------
/pertpy/_types.py:
--------------------------------------------------------------------------------
1 | from scipy import sparse
2 |
3 | CSBase = sparse.csr_matrix | sparse.csc_matrix
4 | CSRBase = sparse.csr_matrix
5 | CSCBase = sparse.csc_matrix
6 | SpBase = sparse.spmatrix
7 |
--------------------------------------------------------------------------------
/pertpy/data/__init__.py:
--------------------------------------------------------------------------------
1 | from pertpy.data._datasets import (
2 | adamson_2016_pilot,
3 | adamson_2016_upr_epistasis,
4 | adamson_2016_upr_perturb_seq,
5 | aissa_2021,
6 | bhattacherjee,
7 | burczynski_crohn,
8 | chang_2021,
9 | cinemaot_example,
10 | combosciplex,
11 | datlinger_2017,
12 | datlinger_2021,
13 | dialogue_example,
14 | distance_example,
15 | dixit_2016,
16 | dixit_2016_raw,
17 | dong_2023,
18 | frangieh_2021,
19 | frangieh_2021_protein,
20 | frangieh_2021_raw,
21 | frangieh_2021_rna,
22 | gasperini_2019_atscale,
23 | gasperini_2019_highmoi,
24 | gasperini_2019_lowmoi,
25 | gehring_2019,
26 | haber_2017_regions,
27 | hagai_2018,
28 | kang_2018,
29 | mcfarland_2020,
30 | norman_2019,
31 | norman_2019_raw,
32 | papalexi_2021,
33 | replogle_2022_k562_essential,
34 | replogle_2022_k562_gwps,
35 | replogle_2022_rpe1,
36 | sc_sim_augur,
37 | schiebinger_2019_16day,
38 | schiebinger_2019_18day,
39 | schraivogel_2020_tap_screen_chr8,
40 | schraivogel_2020_tap_screen_chr11,
41 | sciplex3_raw,
42 | sciplex_gxe1,
43 | shifrut_2018,
44 | smillie_2019,
45 | srivatsan_2020_sciplex2,
46 | srivatsan_2020_sciplex3,
47 | srivatsan_2020_sciplex4,
48 | stephenson_2021_subsampled,
49 | tasccoda_example,
50 | tian_2019_day7neuron,
51 | tian_2019_ipsc,
52 | tian_2021_crispra,
53 | tian_2021_crispri,
54 | weinreb_2020,
55 | xie_2017,
56 | zhang_2021,
57 | zhao_2021,
58 | )
59 |
60 | __all__ = [
61 | "adamson_2016_pilot",
62 | "adamson_2016_upr_epistasis",
63 | "adamson_2016_upr_perturb_seq",
64 | "aissa_2021",
65 | "bhattacherjee",
66 | "burczynski_crohn",
67 | "chang_2021",
68 | "cinemaot_example",
69 | "combosciplex",
70 | "datlinger_2017",
71 | "datlinger_2021",
72 | "dialogue_example",
73 | "distance_example",
74 | "dixit_2016",
75 | "dixit_2016_raw",
76 | "dong_2023",
77 | "frangieh_2021",
78 | "frangieh_2021_protein",
79 | "frangieh_2021_raw",
80 | "frangieh_2021_rna",
81 | "gasperini_2019_atscale",
82 | "gasperini_2019_highmoi",
83 | "gasperini_2019_lowmoi",
84 | "gehring_2019",
85 | "haber_2017_regions",
86 | "hagai_2018",
87 | "kang_2018",
88 | "mcfarland_2020",
89 | "norman_2019",
90 | "norman_2019_raw",
91 | "papalexi_2021",
92 | "replogle_2022_k562_essential",
93 | "replogle_2022_k562_gwps",
94 | "replogle_2022_rpe1",
95 | "sc_sim_augur",
96 | "schiebinger_2019_16day",
97 | "schiebinger_2019_18day",
98 | "schraivogel_2020_tap_screen_chr8",
99 | "schraivogel_2020_tap_screen_chr11",
100 | "sciplex3_raw",
101 | "sciplex_gxe1",
102 | "shifrut_2018",
103 | "smillie_2019",
104 | "srivatsan_2020_sciplex2",
105 | "srivatsan_2020_sciplex3",
106 | "srivatsan_2020_sciplex4",
107 | "stephenson_2021_subsampled",
108 | "tasccoda_example",
109 | "tian_2019_day7neuron",
110 | "tian_2019_ipsc",
111 | "tian_2021_crispra",
112 | "tian_2021_crispri",
113 | "weinreb_2020",
114 | "xie_2017",
115 | "zhao_2021",
116 | "zhang_2021",
117 | ]
118 |
--------------------------------------------------------------------------------
/pertpy/data/_dataloader.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import tempfile
3 | import time
4 | from pathlib import Path
5 | from random import choice
6 | from string import ascii_lowercase
7 | from zipfile import ZipFile
8 |
9 | import requests
10 | from filelock import FileLock
11 | from lamin_utils import logger
12 | from requests.exceptions import RequestException
13 | from rich.progress import Progress
14 |
15 |
16 | def _download( # pragma: no cover
17 | url: str,
18 | output_file_name: str | None = None,
19 | output_path: str | Path | None = None,
20 | block_size: int = 1024,
21 | overwrite: bool = False,
22 | is_zip: bool = False,
23 | timeout: int = 30,
24 | max_retries: int = 3,
25 | retry_delay: int = 5,
26 | ) -> Path:
27 | """Downloads a dataset irrespective of the format.
28 |
29 | Args:
30 | url: URL to download
31 | output_file_name: Name of the downloaded file
32 | output_path: Path to download/extract the files to.
33 | block_size: Block size for downloads in bytes.
34 | overwrite: Whether to overwrite existing files.
35 | is_zip: Whether the downloaded file needs to be unzipped.
36 | timeout: Request timeout in seconds.
37 | max_retries: Maximum number of retry attempts.
38 | retry_delay: Delay between retries in seconds.
39 | """
40 | if output_file_name is None:
41 | letters = ascii_lowercase
42 | output_file_name = f"pertpy_tmp_{''.join(choice(letters) for _ in range(10))}"
43 |
44 | if output_path is None:
45 | output_path = tempfile.gettempdir()
46 |
47 | download_to_path = Path(output_path) / output_file_name
48 |
49 | Path(output_path).mkdir(parents=True, exist_ok=True)
50 | lock_path = Path(output_path) / f"{output_file_name}.lock"
51 |
52 | with FileLock(lock_path, timeout=300):
53 | if Path(download_to_path).exists() and not overwrite:
54 | logger.warning(f"File {download_to_path} already exists!")
55 | return download_to_path
56 |
57 | temp_file_name = Path(f"{download_to_path}.part")
58 |
59 | retry_count = 0
60 | while retry_count <= max_retries:
61 | try:
62 | head_response = requests.head(url, timeout=timeout)
63 | head_response.raise_for_status()
64 | content_length = int(head_response.headers.get("content-length", 0))
65 |
66 | free_space = shutil.disk_usage(output_path).free
67 | if content_length > free_space:
68 | raise OSError(
69 | f"Insufficient disk space. Need {content_length} bytes, but only {free_space} available."
70 | )
71 |
72 | response = requests.get(url, stream=True, timeout=timeout)
73 | response.raise_for_status()
74 | total = int(response.headers.get("content-length", 0))
75 |
76 | with Progress(refresh_per_second=5) as progress:
77 | task = progress.add_task("[red]Downloading...", total=total)
78 | with Path(temp_file_name).open("wb") as file:
79 | for data in response.iter_content(block_size):
80 | file.write(data)
81 | progress.update(task, advance=len(data))
82 | progress.update(task, completed=total, refresh=True)
83 |
84 | Path(temp_file_name).replace(download_to_path)
85 |
86 | if is_zip:
87 | with ZipFile(download_to_path, "r") as zip_obj:
88 | zip_obj.extractall(path=output_path)
89 | return Path(output_path)
90 |
91 | return download_to_path
92 | except (OSError, RequestException) as e:
93 | retry_count += 1
94 | if retry_count <= max_retries:
95 | logger.warning(
96 | f"Download attempt {retry_count}/{max_retries} failed: {str(e)}. Retrying in {retry_delay} seconds..."
97 | )
98 | time.sleep(retry_delay)
99 | else:
100 | logger.error(f"Download failed after {max_retries} attempts: {str(e)}")
101 | if Path(temp_file_name).exists():
102 | Path(temp_file_name).unlink(missing_ok=True)
103 | raise
104 |
105 | except Exception as e:
106 | logger.error(f"Download failed: {str(e)}")
107 | if Path(temp_file_name).exists():
108 | Path(temp_file_name).unlink(missing_ok=True)
109 | raise
110 | finally:
111 | if Path(temp_file_name).exists():
112 | Path(temp_file_name).unlink(missing_ok=True)
113 |
114 | return Path(download_to_path)
115 |
--------------------------------------------------------------------------------
/pertpy/metadata/__init__.py:
--------------------------------------------------------------------------------
1 | from pertpy.metadata._cell_line import CellLine
2 | from pertpy.metadata._compound import Compound
3 | from pertpy.metadata._drug import Drug
4 | from pertpy.metadata._look_up import LookUp
5 | from pertpy.metadata._moa import Moa
6 |
7 | __all__ = ["CellLine", "Compound", "Drug", "Moa", "LookUp"]
8 |
--------------------------------------------------------------------------------
/pertpy/metadata/_compound.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING, Literal
4 |
5 | import pandas as pd
6 |
7 | from ._look_up import LookUp
8 | from ._metadata import MetaData
9 |
10 | if TYPE_CHECKING:
11 | from anndata import AnnData
12 |
13 |
14 | class Compound(MetaData):
15 | """Utilities to fetch metadata for compounds."""
16 |
17 | def __init__(self):
18 | super().__init__()
19 |
20 | def annotate_compounds(
21 | self,
22 | adata: AnnData,
23 | query_id: str = "perturbation",
24 | query_id_type: Literal["name", "cid"] = "name",
25 | verbosity: int | str = 5,
26 | copy: bool = False,
27 | ) -> AnnData:
28 | """Fetch compound annotation from pubchempy.
29 |
30 | Args:
31 | adata: The data object to annotate.
32 | query_id: The column of `.obs` with compound identifiers.
33 | query_id_type: The type of compound identifiers, 'name' or 'cid'.
34 | verbosity: The number of unmatched identifiers to print, can be either non-negative values or "all".
35 | copy: Determines whether a copy of the `adata` is returned.
36 |
37 | Returns:
38 | Returns an AnnData object with compound annotation.
39 | """
40 | if copy:
41 | adata = adata.copy()
42 |
43 | if query_id not in adata.obs.columns:
44 | raise ValueError(f"The requested query_id {query_id} is not in `adata.obs`.\n Please check again.")
45 |
46 | import pubchempy as pcp
47 |
48 | query_dict = {}
49 | not_matched_identifiers = []
50 | for compound in adata.obs[query_id].dropna().astype(str).unique():
51 | if query_id_type == "name":
52 | cids = pcp.get_compounds(compound, "name")
53 | if len(cids) == 0: # search did not work
54 | not_matched_identifiers.append(compound)
55 | if len(cids) >= 1:
56 | # If the name matches the first synonym offered by PubChem (up to capitalization),
57 | # it is kept as-is; otherwise, it is replaced with the first synonym.
58 | query_dict[compound] = [
59 | cids[0].synonyms[0],
60 | cids[0].cid,
61 | cids[0].canonical_smiles,
62 | ]
63 | else:
64 | try:
65 | cid = pcp.Compound.from_cid(compound)
66 | query_dict[compound] = [
67 | cid.synonyms[0],
68 | compound,
69 | cid.canonical_smiles,
70 | ]
71 | except pcp.BadRequestError:
72 | # pubchempy throws badrequest if a cid is not found
73 | not_matched_identifiers.append(compound)
74 |
75 | identifier_num_all = len(adata.obs[query_id].unique())
76 | self._warn_unmatch(
77 | total_identifiers=identifier_num_all,
78 | unmatched_identifiers=not_matched_identifiers,
79 | query_id=query_id,
80 | reference_id=query_id_type,
81 | metadata_type="compound",
82 | verbosity=verbosity,
83 | )
84 |
85 | query_df = pd.DataFrame.from_dict(query_dict, orient="index", columns=["pubchem_name", "pubchem_ID", "smiles"])
86 | # Merge and remove duplicate columns
87 | # Column is converted to float after merging due to unmatches
88 | # Convert back to integers afterwards
89 | if query_id_type == "cid":
90 | query_df.pubchem_ID = query_df.pubchem_ID.astype("Int64")
91 | adata.obs = (
92 | adata.obs.merge(
93 | query_df,
94 | left_on=query_id,
95 | right_on="pubchem_ID",
96 | how="left",
97 | suffixes=("", "_fromMeta"),
98 | )
99 | .filter(regex="^(?!.*_fromMeta)")
100 | .set_index(adata.obs.index)
101 | )
102 | else:
103 | adata.obs = (
104 | adata.obs.merge(
105 | query_df,
106 | left_on=query_id,
107 | right_index=True,
108 | how="left",
109 | suffixes=("", "_fromMeta"),
110 | )
111 | .filter(regex="^(?!.*_fromMeta)")
112 | .set_index(adata.obs.index)
113 | )
114 | adata.obs.pubchem_ID = adata.obs.pubchem_ID.astype("Int64")
115 |
116 | return adata
117 |
118 | def lookup(self) -> LookUp:
119 | """Generate LookUp object for CompoundMetaData.
120 |
121 | The LookUp object provides an overview of the metadata to annotate.
122 | Each annotate_{metadata} function has a corresponding lookup function in the LookUp object,
123 | where users can search the reference_id in the metadata and compare with the query_id in their own data.
124 |
125 | Returns:
126 | Returns a LookUp object specific for compound annotation.
127 | """
128 | return LookUp(type="compound")
129 |
--------------------------------------------------------------------------------
/pertpy/metadata/_metadata.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING, Literal
4 |
5 | from lamin_utils import logger
6 |
7 | if TYPE_CHECKING:
8 | from collections.abc import Sequence
9 |
10 |
11 | class MetaData:
12 | """Superclass for pertpy's MetaData components."""
13 |
14 | def _warn_unmatch(
15 | self,
16 | total_identifiers: int,
17 | unmatched_identifiers: Sequence[str],
18 | query_id: str,
19 | reference_id: str,
20 | metadata_type: Literal[
21 | "cell line",
22 | "protein expression",
23 | "bulk RNA",
24 | "drug response",
25 | "moa",
26 | "compound",
27 | ] = "cell line",
28 | verbosity: int | str = 5,
29 | ) -> None:
30 | """Helper function to print out the unmatched identifiers.
31 |
32 | Args:
33 | total_identifiers: The total number of identifiers in the `adata` object.
34 | unmatched_identifiers: Unmatched identifiers in the `adata` object.
35 | query_id: The column of `.obs` with cell line information.
36 | reference_id: The type of cell line identifier in the metadata.
37 | metadata_type: The type of metadata where some identifiers are not matched during annotation such as
38 | cell line, protein expression, bulk RNA expression, drug response, moa or compound.
39 | verbosity: The number of unmatched identifiers to print, can be either non-negative values or 'all'.
40 | """
41 | if isinstance(verbosity, str):
42 | if verbosity != "all":
43 | raise ValueError("Only a non-negative value or 'all' is accepted.")
44 | else:
45 | verbosity = len(unmatched_identifiers)
46 |
47 | if len(unmatched_identifiers) == total_identifiers:
48 | hint = ""
49 | if metadata_type in ["protein expression", "bulk RNA", "drug response"]:
50 | hint = "Additionally, call the `CellLineMetaData.annotate()` function to acquire more possible query IDs that can be used for cell line annotation purposes."
51 | raise ValueError(
52 | f"No matches between `{query_id}` in adata.obs and `{reference_id}` in {metadata_type} data. "
53 | f"Use `lookup()` to check compatible identifier types. {hint}"
54 | )
55 | if len(unmatched_identifiers) == 0:
56 | return
57 | if isinstance(verbosity, int) and verbosity >= 0:
58 | verbosity = min(verbosity, len(unmatched_identifiers))
59 | if verbosity > 0:
60 | logger.info(
61 | f"{total_identifiers} identifiers in `adata.obs`, {len(unmatched_identifiers)} not found in {metadata_type} data. "
62 | f"NA values present. Unmatched: {unmatched_identifiers[:verbosity]}"
63 | )
64 | else:
65 | raise ValueError("Only 'all' or a non-negative value is accepted.")
66 |
--------------------------------------------------------------------------------
/pertpy/metadata/_moa.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from typing import TYPE_CHECKING
5 |
6 | import numpy as np
7 | import pandas as pd
8 | from scanpy import settings
9 |
10 | from pertpy.data._dataloader import _download
11 |
12 | from ._look_up import LookUp
13 | from ._metadata import MetaData
14 |
15 | if TYPE_CHECKING:
16 | from anndata import AnnData
17 |
18 |
19 | class Moa(MetaData):
20 | """Utilities to fetch metadata for mechanism of action studies."""
21 |
22 | def __init__(self):
23 | self.clue = None
24 |
25 | def _download_clue(self) -> None:
26 | clue_path = Path(settings.cachedir) / "repurposing_drugs_20200324.txt"
27 | if not Path(clue_path).exists():
28 | _download(
29 | url="https://s3.amazonaws.com/data.clue.io/repurposing/downloads/repurposing_drugs_20200324.txt",
30 | output_file_name="repurposing_drugs_20200324.txt",
31 | output_path=settings.cachedir,
32 | block_size=4096,
33 | is_zip=False,
34 | )
35 | self.clue = pd.read_csv(clue_path, sep="\t", skiprows=9)
36 | self.clue = self.clue[["pert_iname", "moa", "target"]]
37 |
38 | def annotate(
39 | self,
40 | adata: AnnData,
41 | query_id: str = "perturbation",
42 | target: str | None = None,
43 | verbosity: int | str = 5,
44 | copy: bool = False,
45 | ) -> AnnData:
46 | """Annotate cells affected by perturbations by mechanism of action.
47 |
48 | For each cell, we fetch the mechanism of action and molecular targets of the compounds sourced from clue.io.
49 |
50 | Args:
51 | adata: The data object to annotate.
52 | query_id: The column of `.obs` with the name of a perturbagen.
53 | target: The column of `.obs` with target information. If set to None, all MoAs are retrieved without comparing molecular targets.
54 | verbosity: The number of unmatched identifiers to print, can be either non-negative values or 'all'.
55 | copy: Determines whether a copy of the `adata` is returned.
56 |
57 | Returns:
58 | Returns an AnnData object with MoA annotation.
59 | """
60 | if copy:
61 | adata = adata.copy()
62 |
63 | if query_id not in adata.obs.columns:
64 | raise ValueError(f"The requested query_id {query_id} is not in `adata.obs`.\nPlease check again.")
65 |
66 | if self.clue is None:
67 | self._download_clue()
68 |
69 | identifier_num_all = len(adata.obs[query_id].unique())
70 | not_matched_identifiers = list(set(adata.obs[query_id].str.lower()) - set(self.clue["pert_iname"].str.lower()))
71 | self._warn_unmatch(
72 | total_identifiers=identifier_num_all,
73 | unmatched_identifiers=not_matched_identifiers,
74 | query_id=query_id,
75 | reference_id="pert_iname",
76 | metadata_type="moa",
77 | verbosity=verbosity,
78 | )
79 |
80 | adata.obs = (
81 | adata.obs.merge(
82 | self.clue,
83 | left_on=adata.obs[query_id].str.lower(),
84 | right_on=self.clue["pert_iname"].str.lower(),
85 | how="left",
86 | suffixes=("", "_fromMeta"),
87 | )
88 | .set_index(adata.obs.index)
89 | .drop("key_0", axis=1)
90 | )
91 |
92 | # If target column is given, check whether it is one of the targets listed in the metadata
93 | # If inconsistent, treat this perturbagen as unmatched and overwrite the annotated metadata with NaN
94 | if target is not None:
95 | target_meta = "target" if target != "target" else "target_fromMeta"
96 | adata.obs[target_meta] = adata.obs[target_meta].mask(
97 | ~adata.obs.apply(lambda row: str(row[target]) in str(row[target_meta]), axis=1)
98 | )
99 | pertname_meta = "pert_iname" if query_id != "pert_iname" else "pert_iname_fromMeta"
100 | adata.obs.loc[adata.obs[target_meta].isna(), [pertname_meta, "moa"]] = np.nan
101 |
102 | # If query_id and reference_id have different names, there will be a column for each of them after merging
103 | # which is redundant as they refer to the same information.
104 | if query_id != "pert_iname":
105 | del adata.obs["pert_iname"]
106 |
107 | return adata
108 |
109 | def lookup(self) -> LookUp:
110 | """Generate LookUp object for Moa metadata.
111 |
112 | The LookUp object provides an overview of the metadata to annotate.
113 |
114 | Returns:
115 | Returns a LookUp object specific for MoA annotation.
116 | """
117 | if self.clue is None:
118 | self._download_clue()
119 |
120 | return LookUp(
121 | type="moa",
122 | transfer_metadata=[self.clue],
123 | )
124 |
--------------------------------------------------------------------------------
/pertpy/plot/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/pertpy/plot/__init__.py
--------------------------------------------------------------------------------
/pertpy/preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 | from ._guide_rna import GuideAssignment
2 |
3 | __all__ = ["GuideAssignment"]
4 |
--------------------------------------------------------------------------------
/pertpy/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/pertpy/py.typed
--------------------------------------------------------------------------------
/pertpy/tools/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib import import_module
2 |
3 |
4 | def lazy_import(module_path: str, class_name: str, extras: list[str]):
5 | try:
6 | for extra in extras:
7 | import_module(extra)
8 | module = import_module(module_path)
9 | return getattr(module, class_name)
10 | except ImportError:
11 |
12 | class Placeholder:
13 | def __init__(self, *args, **kwargs):
14 | raise ImportError(
15 | f"Extra dependencies required: {', '.join(extras)}. "
16 | f"Please install with: pip install {' '.join(extras)}"
17 | )
18 |
19 | return Placeholder
20 |
21 |
22 | from pertpy.tools._augur import Augur
23 | from pertpy.tools._cinemaot import Cinemaot
24 | from pertpy.tools._coda._sccoda import Sccoda
25 | from pertpy.tools._dialogue import Dialogue
26 | from pertpy.tools._distances._distance_tests import DistanceTest
27 | from pertpy.tools._distances._distances import Distance
28 | from pertpy.tools._enrichment import Enrichment
29 | from pertpy.tools._milo import Milo
30 | from pertpy.tools._mixscape import Mixscape
31 | from pertpy.tools._perturbation_space._clustering import ClusteringSpace
32 | from pertpy.tools._perturbation_space._comparison import PerturbationComparison
33 | from pertpy.tools._perturbation_space._discriminator_classifiers import (
34 | LRClassifierSpace,
35 | MLPClassifierSpace,
36 | )
37 | from pertpy.tools._perturbation_space._simple import (
38 | CentroidSpace,
39 | DBSCANSpace,
40 | KMeansSpace,
41 | PseudobulkSpace,
42 | )
43 | from pertpy.tools._scgen import Scgen
44 |
45 | CODA_EXTRAS = ["toytree", "ete4"] # "pyqt6" is also required, but its import name (PyQt6) differs from the package name
46 | Tasccoda = lazy_import("pertpy.tools._coda._tasccoda", "Tasccoda", CODA_EXTRAS)
47 |
48 | DE_EXTRAS = ["formulaic", "pydeseq2"]
49 | EdgeR = lazy_import("pertpy.tools._differential_gene_expression", "EdgeR", DE_EXTRAS) # edgeR will be imported via rpy2
50 | PyDESeq2 = lazy_import("pertpy.tools._differential_gene_expression", "PyDESeq2", DE_EXTRAS)
51 | Statsmodels = lazy_import("pertpy.tools._differential_gene_expression", "Statsmodels", DE_EXTRAS + ["statsmodels"])
52 | TTest = lazy_import("pertpy.tools._differential_gene_expression", "TTest", DE_EXTRAS)
53 | WilcoxonTest = lazy_import("pertpy.tools._differential_gene_expression", "WilcoxonTest", DE_EXTRAS)
54 |
55 | __all__ = [
56 | "Augur",
57 | "Cinemaot",
58 | "Sccoda",
59 | "Tasccoda",
60 | "Dialogue",
61 | "EdgeR",
62 | "PyDESeq2",
63 | "WilcoxonTest",
64 | "TTest",
65 | "Statsmodels",
66 | "DistanceTest",
67 | "Distance",
68 | "Enrichment",
69 | "Milo",
70 | "Mixscape",
71 | "ClusteringSpace",
72 | "PerturbationComparison",
73 | "LRClassifierSpace",
74 | "MLPClassifierSpace",
75 | "CentroidSpace",
76 | "DBSCANSpace",
77 | "KMeansSpace",
78 | "PseudobulkSpace",
79 | "Scgen",
80 | ]
81 |
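    | # A hedged illustration of the lazy-import fallback above: when the extras are missing,
    | # attribute access still succeeds and only instantiation raises (hypothetical session):
    | #
    | #   >>> import pertpy as pt
    | #   >>> pt.tl.Tasccoda()  # doctest: +SKIP
    | #   ImportError: Extra dependencies required: toytree, ete4. Please install with: pip install toytree ete4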
--------------------------------------------------------------------------------
/pertpy/tools/_coda/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/pertpy/tools/_coda/__init__.py
--------------------------------------------------------------------------------
/pertpy/tools/_differential_gene_expression/__init__.py:
--------------------------------------------------------------------------------
1 | from ._base import LinearModelBase, MethodBase
2 | from ._dge_comparison import DGEEVAL
3 | from ._edger import EdgeR
4 | from ._pydeseq2 import PyDESeq2
5 | from ._simple_tests import SimpleComparisonBase, TTest, WilcoxonTest
6 | from ._statsmodels import Statsmodels
7 |
8 | __all__ = [
9 | "MethodBase",
10 | "LinearModelBase",
11 | "EdgeR",
12 | "PyDESeq2",
13 | "Statsmodels",
14 | "SimpleComparisonBase",
15 | "WilcoxonTest",
16 | "TTest",
17 | ]
18 |
19 | AVAILABLE_METHODS = [Statsmodels, EdgeR, PyDESeq2, WilcoxonTest, TTest]
20 |
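    | # A hedged sketch of iterating the registry (hypothetical `adata`; assumes the
    | # `compare_groups` classmethod that `MethodBase` exposes):
    | #
    | #   for method in AVAILABLE_METHODS:
    | #       res = method.compare_groups(adata, column="condition", baseline="control", groups_to_compare="treated")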
--------------------------------------------------------------------------------
/pertpy/tools/_differential_gene_expression/_checks.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.sparse import issparse, spmatrix
3 |
4 |
5 | def check_is_numeric_matrix(array: np.ndarray | spmatrix) -> None:
6 | """Check if a matrix is numeric and only contains finite/non-NA values.
7 |
8 | Args:
9 | array: Dense or sparse matrix to check.
10 |
11 | Raises:
12 | ValueError: If the matrix is not numeric or contains NaNs or infinite values.
13 | """
14 | if not np.issubdtype(array.dtype, np.number):
15 | raise ValueError("Counts must be numeric.")
16 | if issparse(array):
17 | if np.any(~np.isfinite(array.data)):
18 | raise ValueError("Counts cannot contain negative, NaN or Inf values.")
19 | elif np.any(~np.isfinite(array)):
20 | raise ValueError("Counts cannot contain negative, NaN or Inf values.")
21 |
22 |
23 | def check_is_integer_matrix(array: np.ndarray | spmatrix, tolerance: float = 1e-6) -> None:
24 | """Check if a matrix container integers, or floats that are close to integers.
25 |
26 | Args:
27 | array: Dense or sparse matrix to check.
28 | tolerance: Values must be this close to integers.
29 |
30 | Raises:
31 | ValueError: If the matrix contains values that are not close to integers.
32 | """
33 | if issparse(array):
34 | if not array.data.dtype.kind == "i" and not np.all(np.abs(array.data - np.round(array.data)) < tolerance):
35 | raise ValueError("Non-zero elements of the matrix must be close to integer values.")
36 | elif array.dtype.kind != "i" and not np.all(np.abs(array - np.round(array)) < tolerance):
37 | raise ValueError("Matrix must be a count matrix.")
38 | if (array < 0).sum() > 0:
39 | raise ValueError("Non-zero elements of the matrix must be positive.")
40 |
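    | # Quick sanity checks (a sketch):
    | #
    | #   import numpy as np
    | #   check_is_numeric_matrix(np.array([[1.0, 2.5]]))         # passes
    | #   check_is_numeric_matrix(np.array([[np.nan, 1.0]]))      # raises ValueError
    | #   check_is_integer_matrix(np.array([[1.0, 2.0 + 1e-8]]))  # passes: within tolerance
    | #   check_is_integer_matrix(np.array([[-1.0, 2.0]]))        # raises ValueError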
--------------------------------------------------------------------------------
/pertpy/tools/_differential_gene_expression/_dge_comparison.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from anndata import AnnData
4 |
5 |
6 | class DGEEVAL:
7 | def compare(
8 | self,
9 | adata: AnnData | None = None,
10 | de_key1: str | None = None,
11 | de_key2: str | None = None,
12 | de_df1: pd.DataFrame | None = None,
13 | de_df2: pd.DataFrame | None = None,
14 | shared_top: int = 100,
15 | ) -> dict[str, float]:
16 | """Compare two differential expression analyses.
17 |
18 | Compare two sets of DE results and evaluate their similarity by the overlap of the top DEGs and
19 | the correlation of their scores and adjusted p-values.
20 |
21 | Args:
22 | adata: AnnData object containing DE results in `uns`. Required if `de_key1` and `de_key2` are used.
23 | de_key1: Key for DE results in `adata.uns`, e.g., output of `tl.rank_genes_groups`.
24 | de_key2: Another key for DE results in `adata.uns`, e.g., output of `tl.rank_genes_groups`.
25 | de_df1: DataFrame containing DE results, e.g. output from pertpy differential gene expression interface.
26 | de_df2: DataFrame containing DE results, e.g. output from pertpy differential gene expression interface.
27 | shared_top: The number of top DEG to compute the proportion of their intersection.
28 | 
   | Returns:
   |     Dictionary with the fraction of shared top DEGs ("shared_top_genes") and the Pearson/Spearman
   |     correlations of scores, adjusted p-values, and ranks.
29 | """
30 | if (de_key1 or de_key2) and (de_df1 is not None or de_df2 is not None):
31 | raise ValueError(
32 | "Please provide either both `de_key1` and `de_key2` with `adata`, or `de_df1` and `de_df2`, but not both."
33 | )
34 |
35 | if de_df1 is None and de_df2 is None: # use keys
36 | if not de_key1 or not de_key2:
37 | raise ValueError("Both `de_key1` and `de_key2` must be provided together if using `adata`.")
38 |
39 | elif de_df1 is None or de_df2 is None:
40 | raise ValueError("Both `de_df1` and `de_df2` must be provided together if using DataFrames.")
41 |
42 | if de_key1:
43 | if adata is None:
44 | raise ValueError("`adata` must be provided together with `de_key1` and `de_key2`.")
45 | assert all(k in adata.uns for k in [de_key1, de_key2]), (
46 | "Provided `de_key1` and `de_key2` must exist in `adata.uns`."
47 | )
48 | vars = adata.var_names
49 |
50 | if de_df1 is not None:
51 | for df in (de_df1, de_df2):
52 | if not {"variable", "log_fc", "adj_p_value"}.issubset(df.columns):
53 | raise ValueError("Each DataFrame must contain columns: 'variable', 'log_fc', and 'adj_p_value'.")
54 |
55 | assert set(de_df1["variable"]) == set(de_df2["variable"]), "Variables in both dataframes must match."
56 | vars = de_df1["variable"].sort_values()
57 |
58 | shared_top = min(shared_top, len(vars))
59 | vars_ranks = np.arange(1, len(vars) + 1)
60 | results = pd.DataFrame(index=vars)
61 | top_names = []
62 |
63 | if de_key1 and de_key2:
64 | for i, k in enumerate([de_key1, de_key2]):
65 | label = adata.uns[k]["names"].dtype.names[0]
66 | srt_idx = np.argsort(adata.uns[k]["names"][label])
67 | results[f"scores_{i}"] = adata.uns[k]["scores"][label][srt_idx]
68 | results[f"pvals_adj_{i}"] = adata.uns[k]["pvals_adj"][label][srt_idx]
69 | results[f"ranks_{i}"] = vars_ranks[srt_idx]
70 | top_names.append(adata.uns[k]["names"][label][:shared_top])
71 | else:
72 | for i, df in enumerate([de_df1, de_df2]):
73 | srt_idx = np.argsort(df["variable"])
74 | results[f"scores_{i}"] = df["log_fc"].values[srt_idx]
75 | results[f"pvals_adj_{i}"] = df["adj_p_value"].values[srt_idx]
76 | results[f"ranks_{i}"] = vars_ranks[srt_idx]
77 | top_names.append(df["variable"][:shared_top])
78 |
79 | metrics = {}
80 | metrics["shared_top_genes"] = len(set(top_names[0]).intersection(top_names[1])) / shared_top
81 | metrics["scores_corr"] = results["scores_0"].corr(results["scores_1"], method="pearson")
82 | metrics["pvals_adj_corr"] = results["pvals_adj_0"].corr(results["pvals_adj_1"], method="pearson")
83 | metrics["scores_ranks_corr"] = results["ranks_0"].corr(results["ranks_1"], method="spearman")
84 |
85 | return metrics
86 |
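    | # A minimal usage sketch with two toy result frames. Note that each frame is expected to
    | # be sorted by significance, since the first `shared_top` rows are taken as the top DEGs:
    | #
    | #   import pandas as pd
    | #   df = pd.DataFrame({"variable": ["g1", "g2"], "log_fc": [2.0, -1.0], "adj_p_value": [0.01, 0.2]})
    | #   metrics = DGEEVAL().compare(de_df1=df, de_df2=df.copy(), shared_top=2)
    | #   # identical inputs yield shared_top_genes == 1.0 and correlations of 1.0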
--------------------------------------------------------------------------------
/pertpy/tools/_differential_gene_expression/_edger.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Sequence
2 |
3 | import numpy as np
4 | import pandas as pd
5 | from lamin_utils import logger
6 | from scipy.sparse import issparse
7 |
8 | from ._base import LinearModelBase
9 | from ._checks import check_is_integer_matrix
10 |
11 |
12 | class EdgeR(LinearModelBase):
13 | """Differential expression test using EdgeR."""
14 |
15 | def _check_counts(self):
16 | check_is_integer_matrix(self.data)
17 |
18 | def fit(self, **kwargs): # adata, design, mask, layer
19 | """Fit model using edgeR.
20 |
21 | Note: this creates its own AnnData object for downstream processing.
22 |
23 | Args:
24 | **kwargs: Keyword arguments specific to glmQLFit()
25 | """
26 | # For running in notebook
27 | # pandas2ri.activate()
28 | # rpy2.robjects.numpy2ri.activate()
29 | try:
30 | from rpy2 import robjects as ro
31 | from rpy2.robjects import numpy2ri, pandas2ri
32 | from rpy2.robjects.conversion import get_conversion, localconverter
33 | from rpy2.robjects.packages import importr
34 |
35 | except ImportError:
36 | raise ImportError("edger requires rpy2 to be installed.") from None
37 |
38 | try:
39 | edger = importr("edgeR")
40 | except ImportError as e:
41 | raise ImportError(
42 | "edgeR requires a valid R installation with the following packages:\nedgeR, BiocParallel, RhpcBLASctl"
43 | ) from e
44 |
45 | # Convert dataframe
46 | with localconverter(get_conversion() + numpy2ri.converter):
47 | expr = self.adata.X if self.layer is None else self.adata.layers[self.layer]
48 | expr = expr.T.toarray() if issparse(expr) else expr.T
49 |
50 | with localconverter(get_conversion() + pandas2ri.converter):
51 | expr_r = ro.conversion.py2rpy(pd.DataFrame(expr, index=self.adata.var_names, columns=self.adata.obs_names))
52 | samples_r = ro.conversion.py2rpy(self.adata.obs)
53 |
54 | dge = edger.DGEList(counts=expr_r, samples=samples_r)
55 |
56 | logger.info("Calculating NormFactors")
57 | dge = edger.calcNormFactors(dge)
58 |
59 | with localconverter(get_conversion() + numpy2ri.converter):
60 | design_r = ro.conversion.py2rpy(self.design.values)
61 |
62 | logger.info("Estimating Dispersions")
63 | dge = edger.estimateDisp(dge, design=design_r)
64 |
65 | logger.info("Fitting linear model")
66 | fit = edger.glmQLFit(dge, design=design_r, **kwargs)
67 |
68 | ro.globalenv["fit"] = fit
69 | self.fit = fit
70 |
71 | def _test_single_contrast(self, contrast: Sequence[float], **kwargs) -> pd.DataFrame: # noqa: D417
72 | """Conduct test for each contrast and return a data frame.
73 |
74 | Args:
75 | contrast: Numpy array of integers specifying the contrast, e.g. [-1, 0, 1, 0, 0].
76 | """
77 | ## -- Check installations
78 | # For running in notebook
79 | # pandas2ri.activate()
80 | # rpy2.robjects.numpy2ri.activate()
81 |
82 | # ToDo:
83 | # parse **kwargs to R function
84 | # Fix mask for .fit()
85 |
86 | try:
87 | from rpy2 import robjects as ro
88 | from rpy2.robjects import numpy2ri, pandas2ri
89 | from rpy2.robjects.conversion import get_conversion, localconverter
90 | from rpy2.robjects.packages import importr
91 |
92 | except ImportError:
93 | raise ImportError("edger requires rpy2 to be installed.") from None
94 |
95 | try:
96 | importr("edgeR")
97 | except ImportError:
98 | raise ImportError(
99 | "edgeR requires a valid R installation with the following packages: edgeR, BiocParallel, RhpcBLASctl"
100 | ) from None
101 |
102 | # Convert the contrast vector to R. As with `self.design_matrix`, one category is dropped so that the intercept covers the left-out level.
103 | with localconverter(get_conversion() + numpy2ri.converter):
104 | contrast_vec_r = ro.conversion.py2rpy(np.asarray(contrast))
105 | ro.globalenv["contrast_vec"] = contrast_vec_r
106 |
107 | # Test contrast with R
108 | ro.r(
109 | """
110 | test = edgeR::glmQLFTest(fit, contrast=contrast_vec)
111 | de_res = edgeR::topTags(test, n=Inf, adjust.method="BH", sort.by="PValue")$table
112 | """
113 | )
114 |
115 | # Retrieve the `de_res` object
116 | de_res = ro.globalenv["de_res"]
117 |
118 | # If already a Pandas DataFrame, return it directly
119 | if isinstance(de_res, pd.DataFrame):
120 | de_res.index.name = "variable"
121 | return de_res.reset_index().rename(columns={"PValue": "p_value", "logFC": "log_fc", "FDR": "adj_p_value"})
122 |
123 | # Convert to Pandas DataFrame if still an R object
124 | with localconverter(get_conversion() + pandas2ri.converter):
125 | de_res = ro.conversion.rpy2py(de_res)
126 |
127 | de_res.index.name = "variable"
128 | de_res = de_res.reset_index()
129 |
130 | return de_res.rename(columns={"PValue": "p_value", "logFC": "log_fc", "FDR": "adj_p_value"})
131 |
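    | # A hedged usage sketch (hypothetical AnnData `adata` with raw counts and a two-level
    | # "condition" column; requires an R installation with edgeR plus rpy2):
    | #
    | #   import numpy as np
    | #   import pertpy as pt
    | #   edger = pt.tl.EdgeR(adata, design="~condition")
    | #   edger.fit()
    | #   res = edger.test_contrasts(np.array([0, 1]))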
--------------------------------------------------------------------------------
/pertpy/tools/_differential_gene_expression/_pydeseq2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import warnings
4 |
5 | import numpy as np
6 | import pandas as pd
7 | from anndata import AnnData
8 | from numpy import ndarray
9 | from pydeseq2.dds import DeseqDataSet
10 | from pydeseq2.default_inference import DefaultInference
11 | from pydeseq2.ds import DeseqStats
12 | from scipy.sparse import issparse
13 |
14 | from ._base import LinearModelBase
15 | from ._checks import check_is_integer_matrix
16 |
17 |
18 | class PyDESeq2(LinearModelBase):
19 | """Differential expression test using a PyDESeq2."""
20 |
21 | def __init__(
22 | self, adata: AnnData, design: str | ndarray, *, mask: str | None = None, layer: str | None = None, **kwargs
23 | ):
24 | super().__init__(adata, design, mask=mask, layer=layer, **kwargs)
25 | # work around pydeseq2 issue with sparse matrices
26 | # see also https://github.com/owkin/PyDESeq2/issues/25
27 | if issparse(self.data):
28 | if self.layer is None:
29 | self.adata.X = self.adata.X.toarray()
30 | else:
31 | self.adata.layers[self.layer] = self.adata.layers[self.layer].toarray()
32 |
33 | def _check_counts(self):
34 | check_is_integer_matrix(self.data)
35 |
36 | def fit(self, **kwargs) -> pd.DataFrame:
37 | """Fit dds model using pydeseq2.
38 |
39 | Note: this creates its own AnnData object for downstream processing.
40 |
41 | Args:
42 | **kwargs: Keyword arguments specific to DeseqDataSet(); `n_cpus` defaults to all CPUs available to the process if not passed.
43 | """
44 | try:
45 | usable_cpus = len(os.sched_getaffinity(0))
46 | except AttributeError:
47 | usable_cpus = os.cpu_count()
48 |
49 | inference = DefaultInference(n_cpus=kwargs.pop("n_cpus", usable_cpus))
50 |
51 | dds = DeseqDataSet(
52 | adata=self.adata,
53 | design=self.design, # initialize using design matrix, not formula
54 | refit_cooks=True,
55 | inference=inference,
56 | **kwargs,
57 | )
58 |
59 | dds.deseq2()
60 | self.dds = dds
61 |
62 | def _test_single_contrast(self, contrast, alpha=0.05, **kwargs) -> pd.DataFrame:
63 | """Conduct a specific test and returns a Pandas DataFrame.
64 |
65 | Args:
66 | contrast: list of three strings of the form `["variable", "tested level", "reference level"]`.
67 | alpha: P-value threshold used for controlling the FDR with independent hypothesis weighting.
68 | **kwargs: extra arguments to pass to DeseqStats()
69 | """
70 | contrast = np.array(contrast)
71 | stat_res = DeseqStats(self.dds, contrast=contrast, alpha=alpha, **kwargs)
72 | # Calling `.summary()` is required to fill the `results_df` data frame
73 | stat_res.summary()
74 | res_df = (
75 | pd.DataFrame(stat_res.results_df)
76 | .rename(columns={"pvalue": "p_value", "padj": "adj_p_value", "log2FoldChange": "log_fc"})
77 | .sort_values("p_value")
78 | )
79 | res_df.index.name = "variable"
80 | res_df = res_df.reset_index()
81 | return res_df
82 |
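    | # A hedged usage sketch (hypothetical AnnData `adata` with raw counts), following the
    | # contrast format documented above:
    | #
    | #   import pertpy as pt
    | #   pds2 = pt.tl.PyDESeq2(adata, design="~condition")
    | #   pds2.fit()
    | #   res = pds2.test_contrasts(["condition", "treated", "control"])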
--------------------------------------------------------------------------------
/pertpy/tools/_differential_gene_expression/_statsmodels.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import scanpy as sc
4 | import statsmodels
5 | import statsmodels.api as sm
6 | from tqdm.auto import tqdm
7 |
8 | from ._base import LinearModelBase
9 | from ._checks import check_is_numeric_matrix
10 |
11 |
12 | class Statsmodels(LinearModelBase):
13 | """Differential expression test using a statsmodels linear regression."""
14 |
15 | def _check_counts(self):
16 | check_is_numeric_matrix(self.data)
17 |
18 | def fit(
19 | self,
20 | regression_model: type[sm.OLS] | type[sm.GLM] = sm.OLS,
21 | **kwargs,
22 | ) -> None:
23 | """Fit the specified regression model.
24 |
25 | Args:
26 | regression_model: A statsmodels regression model class, either OLS or GLM.
27 | **kwargs: Additional arguments for fitting the specific method. In particular, this
28 | is where you can specify the family for GLM.
29 |
30 | Examples:
31 | >>> import statsmodels.api as sm
32 | >>> import pertpy as pt
33 | >>> model = pt.tl.Statsmodels(adata, design="~condition")
34 | >>> model.fit(sm.GLM, family=sm.families.NegativeBinomial(link=sm.families.links.Log()))
35 | >>> results = model.test_contrasts(np.array([0, 1]))
36 | """
37 | self.models = []
38 | for var in tqdm(self.adata.var_names):
39 | mod = regression_model(
40 | sc.get.obs_df(self.adata, keys=[var], layer=self.layer)[var],
41 | self.design,
42 | **kwargs,
43 | )
44 | mod = mod.fit()
45 | self.models.append(mod)
46 |
47 | def _test_single_contrast(self, contrast, **kwargs) -> pd.DataFrame:
48 | res = []
49 | for var, mod in zip(tqdm(self.adata.var_names), self.models, strict=False):
50 | t_test = mod.t_test(contrast)
51 | res.append(
52 | {
53 | "variable": var,
54 | "p_value": t_test.pvalue,
55 | "t_value": t_test.tvalue.item(),
56 | "sd": t_test.sd.item(),
57 | "log_fc": t_test.effect.item(),
58 | }
59 | )
60 | return (
61 | pd.DataFrame(res)
62 | .sort_values("p_value")
63 | .assign(adj_p_value=lambda x: statsmodels.stats.multitest.fdrcorrection(x["p_value"])[1])
64 | )
65 |
--------------------------------------------------------------------------------
/pertpy/tools/_distances/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/pertpy/tools/_distances/__init__.py
--------------------------------------------------------------------------------
/pertpy/tools/_perturbation_space/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/pertpy/tools/_perturbation_space/__init__.py
--------------------------------------------------------------------------------
/pertpy/tools/_perturbation_space/_clustering.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING
4 |
5 | from sklearn.metrics import pairwise_distances
6 |
7 | from pertpy.tools._perturbation_space._perturbation_space import PerturbationSpace
8 |
9 | if TYPE_CHECKING:
10 | from collections.abc import Iterable
11 |
12 | from anndata import AnnData
13 |
14 |
15 | class ClusteringSpace(PerturbationSpace):
16 | """Applies various clustering techniques to an embedding."""
17 |
18 | def __init__(self):
19 | super().__init__()
20 | self.X = None
21 |
22 | def evaluate_clustering(
23 | self,
24 | adata: AnnData,
25 | true_label_col: str,
26 | cluster_col: str,
27 | metrics: Iterable[str] | None = None,
28 | **kwargs,
29 | ):
30 | """Evaluation of previously computed clustering against ground truth labels.
31 |
32 | Args:
33 | adata: AnnData object that contains the clustered data and the cluster labels.
34 | true_label_col: Column in `.obs` containing the ground-truth labels.
35 | cluster_col: Column in `.obs` containing the computed cluster labels.
36 | metrics: Metrics to compute. If `None` it defaults to ["nmi", "ari", "asw"].
37 | **kwargs: Additional arguments to pass to the metrics. For nmi, average_method can be passed.
38 | For asw, metric, distances, sample_size, and random_state can be passed.
39 |
40 | Examples:
41 | Example usage with KMeansSpace:
42 |
43 | >>> import pertpy as pt
44 | >>> mdata = pt.dt.papalexi_2021()
45 | >>> kmeans = pt.tl.KMeansSpace()
46 | >>> kmeans_adata = kmeans.compute(mdata["rna"], n_clusters=26)
47 | >>> results = kmeans.evaluate_clustering(
48 | ... kmeans_adata, true_label_col="gene_target", cluster_col="k-means", metrics=["nmi"]
49 | ... )
50 | """
51 | if metrics is None:
52 | metrics = ["nmi", "ari", "asw"]
53 | true_labels = adata.obs[true_label_col]
54 |
55 | results = {}
56 | for metric in metrics:
57 | if metric == "nmi":
58 | from pertpy.tools._perturbation_space._metrics import nmi
59 |
60 | if "average_method" not in kwargs:
61 | kwargs["average_method"] = "arithmetic" # by default in sklearn implementation
62 |
63 | nmi_score = nmi(
64 | true_labels=true_labels,
65 | predicted_labels=adata.obs[cluster_col],
66 | average_method=kwargs["average_method"],
67 | )
68 | results["nmi"] = nmi_score
69 |
70 | if metric == "ari":
71 | from pertpy.tools._perturbation_space._metrics import ari
72 |
73 | ari_score = ari(true_labels=true_labels, predicted_labels=adata.obs[cluster_col])
74 | results["ari"] = ari_score
75 |
76 | if metric == "asw":
77 | from pertpy.tools._perturbation_space._metrics import asw
78 |
79 | if "metric" not in kwargs:
80 | kwargs["metric"] = "euclidean"
81 | if "distances" not in kwargs:
82 | distances = pairwise_distances(self.X, metric=kwargs["metric"])
83 | if "sample_size" not in kwargs:
84 | kwargs["sample_size"] = None
85 | if "random_state" not in kwargs:
86 | kwargs["random_state"] = None
87 |
88 | asw_score = asw(
89 | pairwise_distances=distances,
90 | labels=true_labels,
91 | metric=kwargs["metric"],
92 | sample_size=kwargs["sample_size"],
93 | random_state=kwargs["random_state"],
94 | )
95 |
96 | results["asw"] = asw_score
97 |
98 | return results
99 |
--------------------------------------------------------------------------------
/pertpy/tools/_perturbation_space/_comparison.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING
2 |
3 | import numpy as np
4 | from scipy.sparse import issparse
5 | from scipy.sparse import vstack as sp_vstack
6 | from sklearn.base import ClassifierMixin
7 | from sklearn.linear_model import LogisticRegression
8 |
9 | if TYPE_CHECKING:
10 | from numpy.typing import NDArray
11 |
12 |
13 | class PerturbationComparison:
14 | """Comparison between real and simulated perturbations."""
15 |
16 | def compare_classification(
17 | self,
18 | real: np.ndarray,
19 | simulated: np.ndarray,
20 | control: np.ndarray,
21 | clf: ClassifierMixin | None = None,
22 | ) -> float:
23 | """Compare classification accuracy between real and simulated perturbations.
24 |
25 | Trains a classifier on the real perturbation data + the control data and reports a normalized
26 | classification accuracy on the simulated perturbation.
27 |
28 | Args:
29 | real: Real perturbed data.
30 | simulated: Simulated perturbed data.
31 | control: Control data.
32 | clf: sklearn classifier to use, `sklearn.linear_model.LogisticRegression` if not provided.
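   | 
   | Examples:
   |     A minimal sketch with random arrays standing in for real, simulated and control data:
   | 
   |     >>> import numpy as np
   |     >>> import pertpy as pt
   |     >>> rng = np.random.default_rng(0)
   |     >>> real, simulated, control = (rng.normal(size=(50, 10)) for _ in range(3))
   |     >>> score = pt.tl.PerturbationComparison().compare_classification(real, simulated, control)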
33 | """
34 | assert real.shape[1] == simulated.shape[1] == control.shape[1]
35 | if clf is None:
36 | clf = LogisticRegression()
37 | n_x = real.shape[0]
38 | data = sp_vstack((real, control)) if issparse(real) else np.vstack((real, control))
39 | labels = np.concatenate([np.full(real.shape[0], "comp"), np.full(control.shape[0], "ctrl")])
40 |
41 | clf.fit(data, labels)
42 | norm_score = clf.score(simulated, np.full(simulated.shape[0], "comp")) / clf.score(real, labels[:n_x])
43 | norm_score = min(1.0, norm_score)
44 |
45 | return norm_score
46 |
47 | def compare_knn(
48 | self,
49 | real: np.ndarray,
50 | simulated: np.ndarray,
51 | control: np.ndarray | None = None,
52 | use_simulated_for_knn: bool = False,
53 | n_neighbors: int = 20,
54 | random_state: int = 0,
55 | n_jobs: int = 1,
56 | ) -> dict[str, float]:
57 | """Calculate proportions of real perturbed and control data points for simulated data.
58 |
59 | Computes proportions of real perturbed, control and simulated (if `use_simulated_for_knn=True`)
60 | data points for simulated data. If control (`control`) is not provided, builds the knn graph from
61 | real perturbed + simulated perturbed.
62 |
63 | Args:
64 | real: Real perturbed data.
65 | simulated: Simulated perturbed data.
66 | control: Control data.
67 | use_simulated_for_knn: Include simulated perturbed data (`simulated`) in the knn graph. Only valid when
68 | control (`control`) is provided.
69 | n_neighbors: Number of neighbors to use in k-neighbor graph.
70 | random_state: Random state used for k-neighbor graph construction.
71 | n_jobs: Number of cores to use.
72 |
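   | Examples:
   |     A minimal sketch with random arrays (builds the knn graph with `pynndescent`):
   | 
   |     >>> import numpy as np
   |     >>> import pertpy as pt
   |     >>> rng = np.random.default_rng(0)
   |     >>> real, simulated = rng.normal(size=(100, 10)), rng.normal(size=(100, 10))
   |     >>> proportions = pt.tl.PerturbationComparison().compare_knn(real, simulated)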
73 | """
74 | assert real.shape[1] == simulated.shape[1]
75 | if control is not None:
76 | assert real.shape[1] == control.shape[1]
77 |
78 | n_y = simulated.shape[0]
79 |
80 | if control is None:
81 | index_data = sp_vstack((simulated, real)) if issparse(real) else np.vstack((simulated, real))
82 | else:
83 | datas = (simulated, real, control) if use_simulated_for_knn else (real, control)
84 | index_data = sp_vstack(datas) if issparse(real) else np.vstack(datas)
85 |
86 | y_in_index = use_simulated_for_knn or control is None
87 | c_in_index = control is not None
88 | label_groups = ["comp"]
89 | labels: NDArray[np.str_] = np.full(index_data.shape[0], "comp")
90 | if y_in_index:
91 | labels[:n_y] = "siml"
92 | label_groups.append("siml")
93 | if c_in_index:
94 | labels[-control.shape[0] :] = "ctrl"
95 | label_groups.append("ctrl")
96 |
97 | from pynndescent import NNDescent
98 |
99 | index = NNDescent(
100 | index_data,
101 | n_neighbors=max(50, n_neighbors),
102 | random_state=random_state,
103 | n_jobs=n_jobs,
104 | )
105 | indices = index.query(simulated, k=n_neighbors)[0]
106 |
107 | uq, uq_counts = np.unique(labels[indices], return_counts=True)
108 | uq_counts_norm = uq_counts / uq_counts.sum()
109 | counts = dict.fromkeys(label_groups, 0.0)
110 | counts.update(dict(zip(uq, uq_counts_norm, strict=False)))
111 |
112 | return counts
113 |
--------------------------------------------------------------------------------
/pertpy/tools/_perturbation_space/_metrics.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING, Literal
4 |
5 | from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_score
6 |
7 | if TYPE_CHECKING:
8 | import numpy as np
9 | from numpy._typing import ArrayLike
10 |
11 |
12 | def nmi(
13 | true_labels: np.ndarray,
14 | predicted_labels: np.ndarray,
15 | average_method: Literal["min", "max", "geometric", "arithmetic"] = "arithmetic",
16 | ) -> float:
17 | """Calculates the normalized mutual information score between two sets of clusters.
18 |
19 | See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.normalized_mutual_info_score.html
20 |
21 | Args:
22 | true_labels: A clustering of the data into disjoint subsets.
23 | predicted_labels: A clustering of the data into disjoint subsets.
24 | average_method: How to compute the normalizer in the denominator.
25 |
26 | Returns:
27 | Score between 0.0 and 1.0 in normalized nats (based on the natural logarithm). 1.0 stands for perfectly complete labeling.
28 | """
29 | return normalized_mutual_info_score(
30 | labels_true=true_labels, labels_pred=predicted_labels, average_method=average_method
31 | )
32 |
33 |
34 | def ari(true_labels: np.ndarray, predicted_labels: np.ndarray) -> float:
35 | """Calculates the adjusted rand index for chance.
36 |
37 | See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_rand_score.html
38 |
39 | Args:
40 | true_labels: Ground truth class labels to be used as a reference.
41 | predicted_labels: Cluster labels to evaluate.
42 |
43 | Returns:
44 | Similarity score between -0.5 and 1.0. Random labelings have an ARI close to 0.0. 1.0 stands for perfect match.
45 | """
46 | return adjusted_rand_score(labels_true=true_labels, labels_pred=predicted_labels)
47 |
48 |
49 | def asw(
50 | pairwise_distances: ArrayLike,
51 | labels: ArrayLike,
52 | metric: str = "euclidean",
53 | sample_size: int | None = None,
54 | random_state: int | None = None,
55 | **kwargs,
56 | ) -> float:
57 | """Computes the average-width silhouette score.
58 |
59 | See: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.silhouette_score.html
60 |
61 | Args:
62 | pairwise_distances: An array of pairwise distances between samples, or a feature array.
63 | labels: Predicted labels for each sample.
64 | metric: The metric to use when calculating distance between instances in a feature array.
65 | If metric is a string, it must be one of the options allowed by metrics.pairwise.pairwise_distances.
66 | If X is the distance array itself, use metric="precomputed".
67 | sample_size: The size of the sample to use when computing the Silhouette Coefficient on a random subset of the data.
68 | If sample_size is None, no sampling is used.
69 | random_state: Determines random number generation for selecting a subset of samples. Used when sample_size is not None.
70 | **kwargs: Any further parameters are passed directly to the distance function. If using a scipy.spatial.distance metric, the parameters are still metric dependent.
71 |
72 | Returns:
73 | Mean Silhouette Coefficient for all samples.
74 | """
75 | return silhouette_score(
76 | X=pairwise_distances, labels=labels, metric=metric, sample_size=sample_size, random_state=random_state, **kwargs
77 | )
78 |
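   | # Quick sanity checks (a sketch; identical partitions score 1.0 for NMI and ARI):
   | #
   | #   >>> nmi(["a", "a", "b"], ["x", "x", "y"])
   | #   1.0
   | #   >>> ari(["a", "a", "b"], ["x", "x", "y"])
   | #   1.0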
--------------------------------------------------------------------------------
/pertpy/tools/_scgen/__init__.py:
--------------------------------------------------------------------------------
1 | from pertpy.tools._scgen._scgen import Scgen
2 |
--------------------------------------------------------------------------------
/pertpy/tools/_scgen/_base_components.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING
4 |
5 | import jax.numpy as jnp
6 | from flax import linen as nn
7 |
8 | if TYPE_CHECKING:
9 | import jaxlib
10 |
11 |
12 | class FlaxEncoder(nn.Module):
13 | n_latent: int = 10
14 | n_layers: int = 2
15 | n_hidden: int = 800
16 | dropout_rate: float = 0.1
17 | latent_distribution: str = "normal"
18 | use_batch_norm: bool = True
19 | use_layer_norm: bool = False
20 | activation_fn: jaxlib.xla_extension.CompiledFunction = nn.activation.leaky_relu # type: ignore
21 | training: bool | None = None
22 | var_activation: jaxlib.xla_extension.CompiledFunction = jnp.exp # type: ignore
23 | # var_eps: float=1e-4,
24 |
25 | @nn.compact
26 | def __call__(self, x: jnp.ndarray, training: bool | None = None) -> tuple[float, float]:
27 | """Forward pass.
28 |
29 | Args:
30 | x: The input data matrix.
31 | training: Whether to use running training average.
32 |
33 | Returns:
34 | Mean and variance.
35 | """
36 | training = nn.merge_param("training", self.training, training)
37 | for _ in range(self.n_layers):
38 | x = nn.Dense(self.n_hidden)(x)
39 | if self.use_batch_norm:
40 | x = nn.BatchNorm(
41 | momentum=0.99,
42 | epsilon=0.001,
43 | use_running_average=not training,
44 | )(x)
45 | x = self.activation_fn(x)
46 | if self.use_layer_norm:
47 | x = nn.LayerNorm()(x)
48 | x = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(x) if self.dropout_rate > 0 else x
49 |
50 | mean_x = nn.Dense(self.n_latent)(x)
51 | logvar_x = nn.Dense(self.n_latent)(x)
52 |
53 | return mean_x, self.var_activation(logvar_x)
54 |
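   | # A minimal sketch of initializing and applying the encoder (hypothetical input of
   | # 4 cells x 100 features; with training=False, BatchNorm uses its running averages
   | # and Dropout is deterministic):
   | #
   | #   import jax
   | #   import jax.numpy as jnp
   | #   enc = FlaxEncoder(n_latent=10, training=False)
   | #   variables = enc.init(jax.random.PRNGKey(0), jnp.ones((4, 100)))
   | #   mean, var = enc.apply(variables, jnp.ones((4, 100)))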
55 |
56 | class FlaxDecoder(nn.Module):
57 | n_output: int
58 | n_layers: int = 1
59 | n_hidden: int = 128
60 | dropout_rate: float = 0.2
61 | use_batch_norm: bool = False
62 | use_layer_norm: bool = False
63 | activation_fn: nn.activation = nn.activation.leaky_relu # type: ignore
64 | training: bool | None = None
65 |
66 | @nn.compact
67 | def __call__(self, x: jnp.ndarray, training: bool | None = None) -> jnp.ndarray: # type: ignore
68 | """Forward pass.
69 |
70 | Args:
71 | x: Input data.
72 | training: Whether to use running training average.
73 |
74 | Returns:
75 | Decoded data.
76 | """
77 | training = nn.merge_param("training", self.training, training)
78 |
79 | for _ in range(self.n_layers):
80 | x = nn.Dense(self.n_hidden)(x)
81 | if self.use_batch_norm:
82 | x = nn.BatchNorm(
83 | momentum=0.99,
84 | epsilon=0.001,
85 | use_running_average=not training,
86 | )(x)
87 | x = self.activation_fn(x) # type: ignore
88 | if self.use_layer_norm:
89 | x = nn.LayerNorm()(x)
90 | x = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(x) if self.dropout_rate > 0 else x
91 |
92 | x = nn.Dense(self.n_output)(x) # type: ignore
93 |
94 | return x
95 |
--------------------------------------------------------------------------------
/pertpy/tools/_scgen/_scgenvae.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import flax.linen as nn
4 | import jax.numpy as jnp
5 | import numpyro.distributions as dist
6 | from scvi import REGISTRY_KEYS
7 | from scvi.module.base import JaxBaseModuleClass, LossOutput, flax_configure
8 |
9 | from ._base_components import FlaxDecoder, FlaxEncoder
10 |
11 |
12 | @flax_configure
13 | class JaxSCGENVAE(JaxBaseModuleClass):
14 | n_input: int
15 | n_hidden: int = 800
16 | n_latent: int = 10
17 | n_layers: int = 2
18 | dropout_rate: float = 0.1
19 | log_variational: bool = False
20 | latent_distribution: str = "normal"
21 | use_batch_norm: str = "both"
22 | use_layer_norm: str = "none"
23 | kl_weight: float = 0.00005
24 | training: bool = True
25 |
26 | def setup(self):
27 | use_batch_norm_encoder = self.use_batch_norm in ("encoder", "both")
28 | use_layer_norm_encoder = self.use_layer_norm in ("encoder", "both")
29 |
30 | self.encoder = FlaxEncoder(
31 | n_latent=self.n_latent,
32 | n_layers=self.n_layers,
33 | n_hidden=self.n_hidden,
34 | dropout_rate=self.dropout_rate,
35 | latent_distribution=self.latent_distribution,
36 | use_batch_norm=use_batch_norm_encoder,
37 | use_layer_norm=use_layer_norm_encoder,
38 | activation_fn=nn.activation.leaky_relu,
39 | training=self.training,
40 | )
41 |
42 | self.decoder = FlaxDecoder(
43 | n_output=self.n_input,
44 | n_layers=self.n_layers,
45 | n_hidden=self.n_hidden,
46 | activation_fn=nn.activation.leaky_relu,
47 | dropout_rate=self.dropout_rate,
48 | training=self.training,
49 | )
50 |
51 | @property
52 | def required_rngs(self):
53 | return ("params", "dropout", "z")
54 |
55 | def _get_inference_input(self, tensors: dict[str, jnp.ndarray]):
56 | x = tensors[REGISTRY_KEYS.X_KEY]
57 |
58 | input_dict = {"x": x}
59 | return input_dict
60 |
61 | def inference(self, x: jnp.ndarray, n_samples: int = 1) -> dict:
62 | mean, var = self.encoder(x)
63 | stddev = jnp.sqrt(var)
64 |
65 | qz = dist.Normal(mean, stddev)
66 | z_rng = self.make_rng("z")
67 | sample_shape = () if n_samples == 1 else (n_samples,)
68 | z = qz.rsample(z_rng, sample_shape=sample_shape)
69 |
70 | return {"qz": qz, "z": z}
71 |
72 | def _get_generative_input(
73 | self,
74 | tensors: dict[str, jnp.ndarray],
75 | inference_outputs: dict[str, jnp.ndarray],
76 | ):
77 | # x = tensors[REGISTRY_KEYS.X_KEY]
78 | z = inference_outputs["z"]
79 | # batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
80 |
81 | input_dict = {
82 | # x=x,
83 | "z": z,
84 | # batch_index=batch_index,
85 | }
86 | return input_dict
87 |
88 | # def generative(self, x, z, batch_index) -> dict:
89 | def generative(self, z) -> dict:
90 | px = self.decoder(z)
91 | return {"px": px}
92 |
93 | def loss(self, tensors, inference_outputs, generative_outputs):
94 | x = tensors[REGISTRY_KEYS.X_KEY]
95 | px = generative_outputs["px"]
96 | qz = inference_outputs["qz"]
97 |
98 | kl_divergence_z = dist.kl_divergence(qz, dist.Normal(0, 1)).sum(-1)
99 | reconst_loss = self.get_reconstruction_loss(px, x)
100 |
101 | weighted_kl_local = self.kl_weight * kl_divergence_z
102 |
103 | loss = jnp.mean(0.5 * reconst_loss + 0.5 * weighted_kl_local)
104 |
105 | return LossOutput(
106 | loss=loss,
107 | reconstruction_loss=reconst_loss,
108 | kl_local=kl_divergence_z,
109 | n_obs_minibatch=x.shape[0],
110 | )
111 |
112 | def sample(
113 | self,
114 | tensors,
115 | n_samples=1,
116 | ):
117 | inference_kwargs = {"n_samples": n_samples}
118 | (
119 | inference_outputs,
120 | generative_outputs,
121 | ) = self.forward(
122 | tensors,
123 | inference_kwargs=inference_kwargs,
124 | compute_loss=False,
125 | )
126 | px = dist.Normal(generative_outputs["px"], 1).sample(self.make_rng("z"))  # numpyro's sample() requires a PRNG key
127 | return px
128 |
129 | def get_reconstruction_loss(self, x, px):
130 | return jnp.sum((x - px) ** 2)
131 |
--------------------------------------------------------------------------------
/pertpy/tools/_scgen/_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import numpy as np
4 |
5 |
6 | def extractor(
7 | data,
8 | cell_type,
9 | condition_key,
10 | cell_type_key,
11 | ctrl_key,
12 | stim_key,
13 | ):
14 | """Returns a list of `data` files while filtering for a specific `cell_type`.
15 |
16 | Args:
17 | data: `~anndata.AnnData` Annotated data matrix
18 | cell_type: Specific cell type to be extracted from `data`.
19 | condition_key: Key for `.obs` of `data` where conditions can be found.
20 | cell_type_key: Key for `.obs` of `data` where cell types can be found.
21 | ctrl_key: Key for `control` part of the `data` found in `condition_key`.
22 | stim_key: Key for `stimulated` part of the `data` found in `condition_key`.
23 |
24 | Returns:
25 | List of `data` files while filtering for a specific `cell_type`.
26 |
27 | Example:
28 | .. code-block:: python
29 |
30 | import anndata
32 |
33 | train_data = anndata.read_h5ad("./data/train.h5ad")
34 | test_data = anndata.read_h5ad("./data/test.h5ad")
35 | train_data_extracted_list = extractor(
36 | train_data, "CD4T", "conditions", "cell_type", "control", "stimulated"
37 | )
38 | """
39 | cell_with_both_condition = data[data.obs[cell_type_key] == cell_type]
40 | condition_1 = data[(data.obs[cell_type_key] == cell_type) & (data.obs[condition_key] == ctrl_key)]
41 | condition_2 = data[(data.obs[cell_type_key] == cell_type) & (data.obs[condition_key] == stim_key)]
42 | training = data[~((data.obs[cell_type_key] == cell_type) & (data.obs[condition_key] == stim_key))]
43 |
44 | return [training, condition_1, condition_2, cell_with_both_condition]
45 |
46 |
47 | def balancer(
48 | adata,
49 | cell_type_key,
50 | ):
51 | """Makes cell type populations equal.
52 |
53 | Args:
54 | adata: `~anndata.AnnData` Annotated data matrix.
55 | cell_type_key: key for `.obs` of `data` where cell types can be found.
56 |
57 | Returns:
58 | Equal cell type population Annotated data matrix.
59 |
60 | Example:
61 | .. code-block:: python
62 |
63 | import anndata
64 | 
65 | train_data = anndata.read_h5ad("./train_kang.h5ad")
66 | train_ctrl = train_data[train_data.obs["condition"] == "control", :]
67 | train_ctrl = balancer(train_ctrl, "cell_type")
69 | """
70 | class_names = np.unique(adata.obs[cell_type_key])
71 | class_pop = {}
72 | for cls in class_names:
73 | class_pop[cls] = adata[adata.obs[cell_type_key] == cls].shape[0]
74 | max_number = np.max(list(class_pop.values()))
75 | index_all = []
76 | for cls in class_names:
77 | class_index = np.array(adata.obs[cell_type_key] == cls)
78 | index_cls = np.nonzero(class_index)[0]
79 | rng = np.random.default_rng()
80 | index_cls_r = index_cls[rng.choice(len(index_cls), max_number)]
81 | index_all.append(index_cls_r)
82 |
83 | balanced_data = adata[np.concatenate(index_all)].copy()
84 |
85 | return balanced_data
86 |
--------------------------------------------------------------------------------
/pertpy/tools/transferlearning_MMD_LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Jindong Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 |
5 | @pytest.fixture
6 | def rng():
7 | return np.random.default_rng()
8 |
--------------------------------------------------------------------------------
/tests/metadata/test_cell_line.py:
--------------------------------------------------------------------------------
1 | import anndata
2 | import numpy as np
3 | import pandas as pd
4 | import pertpy as pt
5 | import pytest
6 | from anndata import AnnData
7 | from scipy import sparse
8 |
9 | NUM_CELLS = 100
10 | NUM_GENES = 100
11 | NUM_CELLS_PER_ID = NUM_CELLS // 4
12 |
13 |
14 | pt_metadata = pt.md.CellLine()
15 |
16 |
17 | @pytest.fixture
18 | def adata() -> AnnData:
19 | X = np.random.default_rng().normal(0, 1, (NUM_CELLS, NUM_GENES))
20 |
21 | obs = pd.DataFrame(
22 | {
23 | "DepMap_ID": ["ACH-000016", "ACH-000049", "ACH-001208", "ACH-000956"] * NUM_CELLS_PER_ID,
24 | "perturbation": ["Midostaurin"] * NUM_CELLS_PER_ID * 4,
25 | },
26 | index=[str(i) for i in range(NUM_CELLS)],
27 | )
28 |
29 | var_data = {"gene_name": [f"gene{i}" for i in range(1, NUM_GENES + 1)]}
30 | var = pd.DataFrame(var_data).set_index("gene_name", drop=False).rename_axis("index")
31 |
32 | X = sparse.csr_matrix(X)
33 | adata = anndata.AnnData(X=X, obs=obs, var=var)
34 |
35 | return adata
36 |
37 |
38 | def test_cell_line_annotation(adata):
39 | pt_metadata.annotate(adata=adata)
40 | assert len(adata.obs.columns) == len(pt_metadata.depmap.columns) + 1 # due to the perturbation column
41 | stripped_cell_line_name = ["SLR21", "HEKTE", "TK10", "22RV1"] * NUM_CELLS_PER_ID
42 | assert stripped_cell_line_name == list(adata.obs["StrippedCellLineName"])
43 |
44 |
45 | def test_gdsc_annotation(adata):
46 | pt_metadata.annotate(adata)
47 | pt_metadata.annotate_from_gdsc(adata, query_id="StrippedCellLineName")
48 | assert "ln_ic50_gdsc" in adata.obs
49 | assert "auc_gdsc" in adata.obs
50 |
51 |
52 | def test_prism_annotation(adata):
53 | adata.obs = pd.DataFrame(
54 | {
55 | "DepMap_ID": ["ACH-000879", "ACH-000488", "ACH-000488", "ACH-000008"] * NUM_CELLS_PER_ID,
56 | "perturbation": ["cytarabine", "cytarabine", "secnidazole", "flutamide"] * NUM_CELLS_PER_ID,
57 | },
58 | index=[str(i) for i in range(NUM_CELLS)],
59 | )
60 |
61 | pt_metadata.annotate(adata)
62 | pt_metadata.annotate_from_prism(adata, query_id="DepMap_ID")
63 | assert "ic50_prism" in adata.obs
64 | assert "ec50_prism" in adata.obs
65 | assert "auc_prism" in adata.obs
66 |
67 |
68 | def test_protein_expression_annotation(adata):
69 | pt_metadata.annotate(adata)
70 | pt_metadata.annotate_protein_expression(adata, query_id="StrippedCellLineName")
71 |
72 | assert len(adata.obsm) == 1
73 | assert adata.obsm["proteomics_protein_intensity"].shape == (
74 | NUM_CELLS,
75 | len(pt_metadata.proteomics.uniprot_id.unique()),
76 | )
77 |
78 |
79 | def test_bulk_rna_expression_annotation(adata):
80 | pt_metadata.annotate(adata)
81 | pt_metadata.annotate_bulk_rna(adata, query_id="DepMap_ID", cell_line_source="broad")
82 |
83 | assert len(adata.obsm) == 1
84 | assert adata.obsm["bulk_rna_broad"].shape == (
85 | NUM_CELLS,
86 | pt_metadata.bulk_rna_broad.shape[1],
87 | )
88 |
89 | pt_metadata.annotate_bulk_rna(adata, query_id="StrippedCellLineName")
90 |
91 | assert len(adata.obsm) == 2
92 | assert adata.obsm["bulk_rna_sanger"].shape == (
93 | NUM_CELLS,
94 | pt_metadata.bulk_rna_sanger.shape[1],
95 | )
96 |
--------------------------------------------------------------------------------
/tests/metadata/test_compound.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import anndata
4 | import numpy as np
5 | import pandas as pd
6 | import pertpy as pt
7 | import pytest
8 | from anndata import AnnData
9 | from pubchempy import PubChemHTTPError
10 | from scipy import sparse
11 |
12 | NUM_CELLS = 100
13 | NUM_GENES = 100
14 | NUM_CELLS_PER_ID = NUM_CELLS // 4
15 |
16 |
17 | pt_compound = pt.md.Compound()
18 |
19 |
20 | @pytest.fixture
21 | def adata() -> AnnData:
22 | rng = np.random.default_rng(1)
23 | X = rng.standard_normal((NUM_CELLS, NUM_GENES))
24 | X = np.where(X < 0, 0, X)
25 |
26 | obs = pd.DataFrame(
27 | {
28 | "DepMap_ID": ["ACH-000016", "ACH-000049", "ACH-001208", "ACH-000956"] * NUM_CELLS_PER_ID,
29 | "perturbation": ["AG-490", "Iniparib", "TAK-901", "Quercetin"] * NUM_CELLS_PER_ID,
30 | },
31 | index=[str(i) for i in range(NUM_CELLS)],
32 | )
33 |
34 | var_data = {"gene_name": [f"gene{i}" for i in range(1, NUM_GENES + 1)]}
35 | var = pd.DataFrame(var_data).set_index("gene_name", drop=False).rename_axis("index")
36 |
37 | X = sparse.csr_matrix(X)
38 | adata = anndata.AnnData(X=X, obs=obs, var=var)
39 |
40 | return adata
41 |
42 |
43 | def test_compound_annotation(adata):
44 | retries = 3
45 | attempt = 0
46 | while attempt < retries:
47 | try:
48 | pt_compound.annotate_compounds(adata=adata, query_id="perturbation")
49 | assert len(adata.obs.columns) == 5
50 | pubchemid = [5328779, 9796068, 16124208, 5280343] * NUM_CELLS_PER_ID
51 | assert pubchemid == list(adata.obs["pubchem_ID"])
52 | return
53 | except PubChemHTTPError:
54 | if attempt == retries - 1:
55 | pytest.fail("Max retries reached, PubChemHTTPError occurred")
56 | time.sleep(10)
57 | attempt += 1
58 |
--------------------------------------------------------------------------------
/tests/metadata/test_drug.py:
--------------------------------------------------------------------------------
1 | import anndata
2 | import numpy as np
3 | import pandas as pd
4 | import pertpy as pt
5 | import pytest
6 | from anndata import AnnData
7 |
8 | pt_drug = pt.md.Drug()
9 |
10 |
11 | @pytest.fixture
12 | def adata() -> AnnData:
13 | rng = np.random.default_rng()
14 |
15 | gene_names = ["SLC6A2", "SSTR3", "COL1A1", "RPS24", "SSTR2"]
16 | adata = anndata.AnnData(X=rng.standard_normal(size=(5, 5)), var=pd.DataFrame(index=gene_names))
17 |
18 | return adata
19 |
20 |
21 | def test_drug_chembl(adata):
22 | pt_drug.annotate(adata=adata)
23 | assert {"compounds"}.issubset(adata.var.columns)
24 | assert "CHEMBL1693" in adata.var["compounds"]["SLC6A2"]
25 |
26 |
27 | def test_drug_dgidb(adata):
28 | pt_drug.annotate(adata=adata, source="dgidb")
29 | assert {"compounds"}.issubset(adata.var.columns)
30 | assert "AMITIFADINE" in adata.var["compounds"]["SLC6A2"]
31 |
32 |
33 | def test_drug_pharmgkb(adata):
34 | pt_drug.annotate(adata=adata, source="pharmgkb")
35 | assert {"compounds"}.issubset(adata.var.columns)
36 | assert "3,4-methylenedioxymethamphetamine" in adata.var["compounds"]["SLC6A2"]
37 |
--------------------------------------------------------------------------------
/tests/metadata/test_moa.py:
--------------------------------------------------------------------------------
1 | import anndata
2 | import numpy as np
3 | import pandas as pd
4 | import pertpy as pt
5 | import pytest
6 | from anndata import AnnData
7 | from scipy import sparse
8 |
9 | NUM_CELLS = 100
10 | NUM_GENES = 100
11 | NUM_CELLS_PER_ID = NUM_CELLS // 4
12 |
13 |
14 | pt_moa = pt.md.Moa()
15 |
16 |
17 | @pytest.fixture
18 | def adata() -> AnnData:
19 | rng = np.random.default_rng(1)
20 | X = rng.standard_normal((NUM_CELLS, NUM_GENES))
21 | X = np.where(X < 0, 0, X)
22 |
23 | obs = pd.DataFrame(
24 | {
25 | "DepMap_ID": ["ACH-000016", "ACH-000049", "ACH-001208", "ACH-000956"] * NUM_CELLS_PER_ID,
26 | "perturbation": ["AG-490", "Iniparib", "TAK-901", "Quercetin"] * NUM_CELLS_PER_ID,
27 | },
28 | index=[str(i) for i in range(NUM_CELLS)],
29 | )
30 |
31 | var_data = {"gene_name": [f"gene{i}" for i in range(1, NUM_GENES + 1)]}
32 | var = pd.DataFrame(var_data).set_index("gene_name", drop=False).rename_axis("index")
33 |
34 | X = sparse.csr_matrix(X)
35 | adata = anndata.AnnData(X=X, obs=obs, var=var)
36 |
37 | return adata
38 |
39 |
40 | def test_moa_annotation(adata):
41 | pt_moa.annotate(adata=adata, query_id="perturbation")
42 | assert len(adata.obs.columns) == len(pt_moa.clue.columns) + 1 # due to the DepMap_ID column
43 | assert {"moa", "target"}.issubset(adata.obs)
44 | moa = [
45 | "EGFR inhibitor|JAK inhibitor",
46 | "PARP inhibitor",
47 | "Aurora kinase inhibitor",
48 | "polar auxin transport inhibitor",
49 | ] * NUM_CELLS_PER_ID
50 | assert moa == list(adata.obs["moa"])
51 |
--------------------------------------------------------------------------------
/tests/preprocessing/test_grna_assignment.py:
--------------------------------------------------------------------------------
1 | import anndata as ad
2 | import numpy as np
3 | import pandas as pd
4 | import pertpy as pt
5 | import pytest
6 | from scipy import sparse
7 |
8 |
9 | @pytest.fixture
10 | def adata_simple():
11 | exp_matrix = np.array(
12 | [
13 | [9, 0, 1, 0, 1, 0, 0],
14 | [1, 5, 1, 7, 0, 0, 0],
15 | [2, 0, 1, 0, 0, 8, 0],
16 | [1, 1, 1, 0, 1, 1, 1],
17 | [0, 0, 1, 0, 0, 5, 0],
18 | [9, 0, 1, 7, 0, 0, 0],
19 | [0, 0, 1, 0, 0, 0, 6],
20 | [8, 0, 1, 0, 0, 0, 0],
21 | ]
22 | ).astype(np.float32)
23 | adata = ad.AnnData(
24 | exp_matrix,
25 | obs=pd.DataFrame(index=[f"cell_{i + 1}" for i in range(exp_matrix.shape[0])]),
26 | var=pd.DataFrame(index=[f"guide_{i + 1}" for i in range(exp_matrix.shape[1])]),
27 | )
28 | return adata
29 |
30 |
31 | @pytest.fixture
32 | def tiny_dense_adata():
33 | exp_matrix = np.array([[6, 0, 2], [1, 5, 0], [0, 1, 7]]).astype(np.float32)
34 | return ad.AnnData(
35 | exp_matrix,
36 | obs=pd.DataFrame(index=[f"cell_{i + 1}" for i in range(exp_matrix.shape[0])]),
37 | var=pd.DataFrame(index=[f"guide_{i + 1}" for i in range(exp_matrix.shape[1])]),
38 | )
39 |
40 |
41 | @pytest.fixture
42 | def tiny_sparse_adata():
43 | sparse_matrix = sparse.csr_matrix(np.array([[6, 0, 2], [1, 5, 0], [0, 1, 7]]).astype(np.float32))
44 | return ad.AnnData(
45 | sparse_matrix,
46 | obs=pd.DataFrame(index=[f"cell_{i + 1}" for i in range(sparse_matrix.shape[0])]),
47 | var=pd.DataFrame(index=[f"guide_{i + 1}" for i in range(sparse_matrix.shape[1])]),
48 | )
49 |
50 |
51 | @pytest.mark.parametrize("adata_fixture", ["tiny_dense_adata", "tiny_sparse_adata"])
52 | def test_grna_threshold_assignment(request, adata_fixture):
53 | adata = request.getfixturevalue(adata_fixture)
54 | threshold = 5
55 | output_layer = "assigned_guides"
56 | assert output_layer not in adata.layers
57 |
58 | ga = pt.pp.GuideAssignment()
59 | ga.assign_by_threshold(adata, assignment_threshold=threshold, output_layer=output_layer)
60 |
61 | assert output_layer in adata.layers
62 |
63 | # Convert to dense for comparison if needed
64 | if sparse.issparse(adata.layers[output_layer]):
65 | result_matrix = adata.layers[output_layer].toarray()
66 | else:
67 | result_matrix = adata.layers[output_layer]
68 |
69 | # Convert original data to dense for comparison if needed
70 | original_matrix = adata.X.toarray() if sparse.issparse(adata.X) else adata.X
71 |
72 | assert np.all(np.logical_xor(original_matrix < threshold, result_matrix == 1))
73 |
74 |
75 | @pytest.mark.parametrize("adata_fixture", ["tiny_dense_adata", "tiny_sparse_adata"])
76 | def test_grna_max_assignment(request, adata_fixture):
77 | adata = request.getfixturevalue(adata_fixture)
78 | threshold = 6
79 | obs_key = "assigned_guide"
80 | assert obs_key not in adata.obs
81 |
82 | ga = pt.pp.GuideAssignment()
83 | ga.assign_to_max_guide(adata, assignment_threshold=threshold, obs_key=obs_key)
84 | assert obs_key in adata.obs
85 | assert tuple(adata.obs[obs_key]) == ("guide_1", "Negative", "guide_3")
86 |
87 |
88 | def test_grna_mixture_model(adata_simple):
89 | output_key = "assigned_guide"
90 | assert output_key not in adata_simple.obs
91 |
92 | ga = pt.pp.GuideAssignment()
93 | ga.assign_mixture_model(adata_simple)
94 | assert output_key in adata_simple.obs
95 | target = [f"guide_{i}" if i > 0 else "negative" for i in [1, 4, 6, 0, 6, 1, 7, 1, 0]]
96 | assert all(t in g for t, g in zip(target, adata_simple.obs[output_key], strict=False))
97 |
--------------------------------------------------------------------------------
/tests/tools/_coda/test_sccoda.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import arviz as az
4 | import numpy as np
5 | import pandas as pd
6 | import pertpy as pt
7 | import pytest
8 | import scanpy as sc
9 | from mudata import MuData
10 |
11 | CWD = Path(__file__).parent.resolve()
12 |
13 |
14 | sccoda = pt.tl.Sccoda()
15 |
16 |
17 | @pytest.fixture
18 | def adata():
19 | cells = pt.dt.haber_2017_regions()
20 | cells = sc.pp.subsample(cells, 0.1, copy=True)
21 |
22 | return cells
23 |
24 |
25 | def test_load(adata):
26 | mdata = sccoda.load(
27 | adata,
28 | type="cell_level",
29 | generate_sample_level=True,
30 | cell_type_identifier="cell_label",
31 | sample_identifier="batch",
32 | covariate_obs=["condition"],
33 | )
34 | assert isinstance(mdata, MuData)
35 | assert "rna" in mdata.mod
36 | assert "coda" in mdata.mod
37 |
38 |
39 | def test_prepare(adata):
40 | mdata = sccoda.load(
41 | adata,
42 | type="cell_level",
43 | generate_sample_level=True,
44 | cell_type_identifier="cell_label",
45 | sample_identifier="batch",
46 | covariate_obs=["condition"],
47 | )
48 | mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
49 | assert "scCODA_params" in mdata["coda"].uns
50 | assert "covariate_matrix" in mdata["coda"].obsm
51 | assert "sample_counts" in mdata["coda"].obsm
52 | assert isinstance(mdata["coda"].obsm["sample_counts"], np.ndarray)
53 | assert np.sum(mdata["coda"].obsm["covariate_matrix"]) == 6
54 |
55 |
56 | def test_run_nuts(adata):
57 | mdata = sccoda.load(
58 | adata,
59 | type="cell_level",
60 | generate_sample_level=True,
61 | cell_type_identifier="cell_label",
62 | sample_identifier="batch",
63 | covariate_obs=["condition"],
64 | )
65 | mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
66 | sccoda.run_nuts(mdata, num_samples=1000, num_warmup=100)
67 | assert "effect_df_condition[T.Hpoly.Day10]" in mdata["coda"].varm
68 | assert "effect_df_condition[T.Hpoly.Day3]" in mdata["coda"].varm
69 | assert "effect_df_condition[T.Salmonella]" in mdata["coda"].varm
70 | assert "intercept_df" in mdata["coda"].varm
71 | assert mdata["coda"].varm["effect_df_condition[T.Hpoly.Day10]"].shape == (8, 7)
72 | assert mdata["coda"].varm["effect_df_condition[T.Hpoly.Day3]"].shape == (8, 7)
73 | assert mdata["coda"].varm["effect_df_condition[T.Salmonella]"].shape == (8, 7)
74 | assert mdata["coda"].varm["intercept_df"].shape == (8, 5)
75 |
76 |
77 | def test_credible_effects(adata):
78 | adata_salm = adata[adata.obs["condition"].isin(["Control", "Salmonella"])]
79 | mdata = sccoda.load(
80 | adata_salm,
81 | type="cell_level",
82 | generate_sample_level=True,
83 | cell_type_identifier="cell_label",
84 | sample_identifier="batch",
85 | covariate_obs=["condition"],
86 | )
87 | mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Goblet")
88 | sccoda.run_nuts(mdata)
89 | assert isinstance(sccoda.credible_effects(mdata), pd.Series)
90 | assert sccoda.credible_effects(mdata)["condition[T.Salmonella]"]["Enterocyte"]
91 |
92 |
93 | def test_make_arviz(adata):
94 | adata_salm = adata[adata.obs["condition"].isin(["Control", "Salmonella"])]
95 | mdata = sccoda.load(
96 | adata_salm,
97 | type="cell_level",
98 | generate_sample_level=True,
99 | cell_type_identifier="cell_label",
100 | sample_identifier="batch",
101 | covariate_obs=["condition"],
102 | )
103 | mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Goblet")
104 | sccoda.run_nuts(mdata)
105 | arviz_data = sccoda.make_arviz(mdata, num_prior_samples=100)
106 | assert isinstance(arviz_data, az.InferenceData)
107 |
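108 | 
109 | # A minimal sketch (hypothetical helper, not part of the suite) of how the
110 | # load/prepare boilerplate repeated in every test above could be shared; it
111 | # reuses only the scCODA calls already exercised in this module:
112 | def _prepared_mdata(adata, reference_cell_type="Endocrine"):
113 |     mdata = sccoda.load(
114 |         adata,
115 |         type="cell_level",
116 |         generate_sample_level=True,
117 |         cell_type_identifier="cell_label",
118 |         sample_identifier="batch",
119 |         covariate_obs=["condition"],
120 |     )
121 |     return sccoda.prepare(mdata, formula="condition", reference_cell_type=reference_cell_type)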
--------------------------------------------------------------------------------
/tests/tools/_coda/test_tasccoda.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import numpy as np
4 | import pytest
5 | import scanpy as sc
6 | from mudata import MuData
7 |
8 | try:
9 |     import ete4  # noqa: F401
10 | except ImportError:
11 | pytest.skip("ete4 not available", allow_module_level=True)
12 |
13 | import pertpy as pt
14 |
15 | CWD = Path(__file__).parent.resolve()
16 |
17 |
18 | tasccoda = pt.tl.Tasccoda()
19 |
20 |
21 | @pytest.fixture
22 | def smillie_adata():
23 | smillie_adata = pt.dt.tasccoda_example()
24 | smillie_adata = sc.pp.subsample(smillie_adata, 0.1, copy=True)
25 |
26 | return smillie_adata
27 |
28 |
29 | def test_load(smillie_adata):
30 | mdata = tasccoda.load(
31 | smillie_adata,
32 | type="sample_level",
33 | levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
34 | key_added="lineage",
35 | add_level_name=True,
36 | )
37 | assert isinstance(mdata, MuData)
38 | assert "rna" in mdata.mod
39 | assert "coda" in mdata.mod
40 | assert "lineage" in mdata["coda"].uns
41 |
42 |
43 | def test_prepare(smillie_adata):
44 | mdata = tasccoda.load(
45 | smillie_adata,
46 | type="sample_level",
47 | levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
48 | key_added="lineage",
49 | add_level_name=True,
50 | )
51 | mdata = tasccoda.prepare(
52 | mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
53 | )
54 | assert "scCODA_params" in mdata["coda"].uns
55 | assert "covariate_matrix" in mdata["coda"].obsm
56 | assert "sample_counts" in mdata["coda"].obsm
57 | assert isinstance(mdata["coda"].obsm["sample_counts"], np.ndarray)
58 | assert np.sum(mdata["coda"].obsm["covariate_matrix"]) == 8
59 |
60 |
61 | def test_run_nuts(smillie_adata):
62 | mdata = tasccoda.load(
63 | smillie_adata,
64 | type="sample_level",
65 | levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
66 | key_added="lineage",
67 | add_level_name=True,
68 | )
69 | mdata = tasccoda.prepare(
70 | mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
71 | )
72 | tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100)
73 | assert "effect_df_Health[T.Inflamed]" in mdata["coda"].varm
74 | assert "effect_df_Health[T.Non-inflamed]" in mdata["coda"].varm
75 | assert mdata["coda"].varm["effect_df_Health[T.Inflamed]"].shape == (51, 7)
76 | assert mdata["coda"].varm["effect_df_Health[T.Non-inflamed]"].shape == (51, 7)
77 |
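78 | 
79 | # A sketch of a pytest fixture (hypothetical) that would deduplicate the
80 | # load/prepare boilerplate above, reusing only the tascCODA API already shown:
81 | @pytest.fixture
82 | def prepared_mdata(smillie_adata):
83 |     mdata = tasccoda.load(
84 |         smillie_adata,
85 |         type="sample_level",
86 |         levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
87 |         key_added="lineage",
88 |         add_level_name=True,
89 |     )
90 |     return tasccoda.prepare(
91 |         mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
92 |     )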
--------------------------------------------------------------------------------
/tests/tools/_differential_gene_expression/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scverse/pertpy/512c0e0f3eb25f2df915e5a4fcbeb7dcbddbd6e4/tests/tools/_differential_gene_expression/__init__.py
--------------------------------------------------------------------------------
/tests/tools/_differential_gene_expression/conftest.py:
--------------------------------------------------------------------------------
1 | import anndata as ad
2 | import numpy as np
3 | import pandas as pd
4 | import pytest
5 | import scipy.sparse as sp
6 | from pydeseq2.utils import load_example_data
7 |
8 |
9 | @pytest.fixture
10 | def test_counts():
11 | return load_example_data(
12 | modality="raw_counts",
13 | dataset="synthetic",
14 | debug=False,
15 | )
16 |
17 |
18 | @pytest.fixture
19 | def test_metadata():
20 | return load_example_data(
21 | modality="metadata",
22 | dataset="synthetic",
23 | debug=False,
24 | )
25 |
26 |
27 | @pytest.fixture
28 | def test_adata(test_counts, test_metadata):
29 | return ad.AnnData(X=test_counts, obs=test_metadata)
30 |
31 |
32 | @pytest.fixture(params=[np.array, sp.csr_matrix, sp.csc_matrix])
33 | def test_adata_minimal(request):
34 | matrix_format = request.param
35 | n_obs = 80
36 | n_donors = n_obs // 4
37 | rng = np.random.default_rng(9) # make tests deterministic
38 | obs = pd.DataFrame(
39 | {
40 | "condition": ["A", "B"] * (n_obs // 2),
41 | "donor": sum(([f"D{i}"] * n_donors for i in range(n_obs // n_donors)), []),
42 | "other": (["X"] * (n_obs // 4)) + (["Y"] * ((3 * n_obs) // 4)),
43 | "pairing": sum(([str(i), str(i)] for i in range(n_obs // 2)), []),
44 | "continuous": [rng.uniform(0, 1) * 4000 for _ in range(n_obs)],
45 | },
46 | )
47 | var = pd.DataFrame(index=["gene1", "gene2"])
48 | group1 = rng.negative_binomial(20, 0.1, n_obs // 2) # large mean
49 | group2 = rng.negative_binomial(5, 0.5, n_obs // 2) # small mean
50 |
51 | condition_data = np.empty((n_obs,), dtype=group1.dtype)
52 | condition_data[0::2] = group1
53 | condition_data[1::2] = group2
54 |
55 |     donor_data = np.empty((n_obs,), dtype=group1.dtype)
56 |     donor_data[0:n_donors] = group2[:n_donors]  # donor D0: low-mean counts
57 |     donor_data[n_donors : (2 * n_donors)] = group1[n_donors:]  # donor D1: high-mean counts
58 | 
59 |     donor_data[(2 * n_donors) : (3 * n_donors)] = group2[:n_donors]  # donor D2: low-mean counts
60 |     donor_data[(3 * n_donors) :] = group1[n_donors:]  # donor D3: high-mean counts
61 |
62 | X = matrix_format(np.vstack([condition_data, donor_data]).T)
63 |
64 | return ad.AnnData(X=X, obs=obs, var=var)
65 |
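66 | 
67 | # Illustrative sketch of the signal built into `test_adata_minimal`: "gene1"
68 | # alternates high/low draws with `condition`, so it carries a strong condition
69 | # effect, while "gene2" varies only between donor blocks. The two negative
70 | # binomial distributions are far apart by construction:
71 | def _demo_fixture_signal() -> None:
72 |     rng = np.random.default_rng(9)
73 |     group1 = rng.negative_binomial(20, 0.1, 40)  # mean ~ 180
74 |     group2 = rng.negative_binomial(5, 0.5, 40)  # mean ~ 5
75 |     assert group1.mean() > 10 * group2.mean()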
--------------------------------------------------------------------------------
/tests/tools/_differential_gene_expression/test_base.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Sequence
2 |
3 | import pytest
4 | from pandas.core.api import DataFrame
5 | from pertpy.tools._differential_gene_expression import LinearModelBase
6 |
7 |
8 | @pytest.fixture
9 | def MockLinearModel():
10 | class _MockLinearModel(LinearModelBase):
11 | def _check_counts(self) -> None:
12 | pass
13 |
14 | def fit(self, **kwargs) -> None:
15 | pass
16 |
17 | def _test_single_contrast(self, contrast: Sequence[float], **kwargs) -> DataFrame:
18 | pass
19 |
20 | return _MockLinearModel
21 |
22 |
23 | @pytest.mark.parametrize(
24 | "formula,cond_kwargs,expected_contrast",
25 | [
26 | # single variable
27 | ["~ condition", {}, [1, 0]],
28 | ["~ condition", {"condition": "A"}, [1, 0]],
29 | ["~ condition", {"condition": "B"}, [1, 1]],
30 |         ["~ condition", {"condition": "42"}, ValueError],  # non-existent category
31 | # no-intercept models
32 | ["~ 0 + condition", {"condition": "A"}, [1, 0]],
33 | ["~ 0 + condition", {"condition": "B"}, [0, 1]],
34 | # Different way of specifying dummy coding
35 | ["~ donor", {"donor": "D0"}, [1, 0, 0, 0]],
36 | ["~ C(donor)", {"donor": "D0"}, [1, 0, 0, 0]],
37 | ["~ C(donor, contr.treatment(base='D2'))", {"donor": "D2"}, [1, 0, 0, 0]],
38 | ["~ C(donor, contr.treatment(base='D2'))", {"donor": "D0"}, [1, 1, 0, 0]],
39 | # Handle continuous covariates
40 | ["~ donor + continuous", {"donor": "D1"}, [1, 1, 0, 0, 0]],
41 | ["~ donor + np.log1p(continuous)", {"donor": "D1"}, [1, 1, 0, 0, 0]],
42 | ["~ donor + continuous + np.log1p(continuous)", {"donor": "D0"}, [1, 0, 0, 0, 0, 0]],
43 | # Nonsense models repeating the same variable, which are nonetheless allowed by formulaic
44 | ["~ donor + C(donor)", {"donor": "D1"}, [1, 1, 0, 0, 1, 0, 0]],
45 | ["~ donor + C(donor, contr.treatment(base='D2'))", {"donor": "D0"}, [1, 0, 0, 0, 1, 0, 0]],
46 | [
47 | "~ condition + donor + C(donor, contr.treatment(base='D2'))",
48 | {"condition": "A"},
49 | ValueError,
50 | ], # donor base category can't be resolved because it's ambiguous -> ValueError
51 | # Sum2zero coding
52 | ["~ C(donor, contr.sum)", {"donor": "D0"}, [1, 1, 0, 0]],
53 | ["~ C(donor, contr.sum)", {"donor": "D3"}, [1, -1, -1, -1]],
54 | # Multiple categorical variables
55 | ["~ condition + donor", {"condition": "A"}, [1, 0, 0, 0, 0]],
56 | ["~ condition + donor", {"donor": "D2"}, [1, 0, 0, 1, 0]],
57 | ["~ condition + donor", {"condition": "B", "donor": "D2"}, [1, 1, 0, 1, 0]],
58 | ["~ 0 + condition + donor", {"donor": "D1"}, [0, 0, 1, 0, 0]],
59 | # Interaction terms
60 | ["~ condition * donor", {"condition": "A"}, [1, 0, 0, 0, 0, 0, 0, 0]],
61 | ["~ condition + donor + condition:donor", {"condition": "A"}, [1, 0, 0, 0, 0, 0, 0, 0]],
62 | ["~ condition * donor", {"condition": "B", "donor": "D2"}, [1, 1, 0, 1, 0, 0, 1, 0]],
63 | ["~ condition * C(donor, contr.treatment(base='D2'))", {"condition": "A"}, [1, 0, 0, 0, 0, 0, 0, 0]],
64 | [
65 | "~ condition * C(donor, contr.treatment(base='D2'))",
66 | {"condition": "B", "donor": "D0"},
67 | [1, 1, 1, 0, 0, 1, 0, 0],
68 | ],
69 | [
70 | "~ condition:donor",
71 | {"condition": "A"},
72 | ValueError,
73 | ], # Can't automatically resolve base category, because Formulaic builds a reduced-rank and full-rank factor internally
74 | ["~ condition:donor", {"condition": "A", "donor": "D1"}, [1, 1, 0, 0, 0, 0, 0, 0]],
75 | ["~ condition:C(donor)", {"condition": "A", "donor": "D1"}, [1, 1, 0, 0, 0, 0, 0, 0]],
76 | ],
77 | )
78 | def test_model_cond(test_adata_minimal, MockLinearModel, formula, cond_kwargs, expected_contrast):
79 | mod = MockLinearModel(test_adata_minimal, formula)
80 | if isinstance(expected_contrast, type):
81 | with pytest.raises(expected_contrast):
82 | mod.cond(**cond_kwargs)
83 | else:
84 | actual_contrast = mod.cond(**cond_kwargs)
85 | assert actual_contrast.tolist() == expected_contrast
86 | assert actual_contrast.index.tolist() == mod.design.columns.tolist()
87 |
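88 | 
89 | # A minimal sketch of how `cond` vectors compose into contrasts, assuming only
90 | # what the assertions above establish (a Series indexed by design columns).
91 | # Under "~ condition", comparing B against A gives the familiar [0, 1] vector:
92 | def _demo_contrast_composition(MockLinearModel, test_adata_minimal):
93 |     mod = MockLinearModel(test_adata_minimal, "~ condition")
94 |     delta = mod.cond(condition="B") - mod.cond(condition="A")  # [1, 1] - [1, 0]
95 |     assert delta.tolist() == [0, 1]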
--------------------------------------------------------------------------------
/tests/tools/_differential_gene_expression/test_compare_groups.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from pertpy.tools._differential_gene_expression import AVAILABLE_METHODS
4 |
5 |
6 | @pytest.mark.parametrize("method", AVAILABLE_METHODS)
7 | @pytest.mark.parametrize("paired_by", ["pairing", None])
8 | def test_unified_api_single_group(test_adata_minimal, method, paired_by):
9 | """
10 | Test that all methods implement the unified API.
11 |
12 | Here, we don't check the correctness of the results
13 | (we have the method-specific tests for that), but rather that the interface works
14 | as expected and the format of the resulting data frame is what we expect.
15 |
16 | TODO: tests for layers
17 | TODO: tests for mask
18 | """
19 | res_df = method.compare_groups(
20 | adata=test_adata_minimal, column="condition", baseline="A", groups_to_compare="B", paired_by=paired_by
21 | )
22 | assert res_df.shape[0] == test_adata_minimal.shape[1], "The result dataframe must contain a value for each var name"
23 | assert {"variable", "p_value", "log_fc", "adj_p_value"} - set(res_df.columns) == set(), (
24 | "Mandated column names not in result df"
25 | )
26 | assert np.all((res_df["p_value"] >= 0) & (res_df["p_value"] <= 1))
27 | assert np.all((res_df["adj_p_value"] >= 0) & (res_df["adj_p_value"] <= 1))
28 | assert np.all(res_df["adj_p_value"] >= res_df["p_value"])
29 |
30 |
31 | @pytest.mark.parametrize("method", AVAILABLE_METHODS)
32 | def test_unified_api_multiple_groups(test_adata_minimal, method):
33 | """
34 | Test that all methods implement the unified API.
35 |
36 | Here, we don't check the correctness of the results
37 | (we have the method-specific tests for that), but rather that the interface works
38 | as expected and the format of the resulting data frame is what we expect.
39 |
40 | TODO: tests for layers
41 | TODO: tests for mask
42 | """
43 | res_df = method.compare_groups(
44 | adata=test_adata_minimal,
45 | column="donor",
46 | baseline="D0",
47 | groups_to_compare=["D1", "D2", "D3"],
48 | paired_by=None, # No pairing possible here.
49 | )
50 | assert res_df.shape[0] == 3 * test_adata_minimal.shape[1], (
51 | "The result dataframe must contain a value for each var name"
52 | )
53 | assert {"variable", "p_value", "log_fc", "adj_p_value"} - set(res_df.columns) == set(), (
54 | "Mandated column names not in result df"
55 | )
56 | assert np.all((res_df["p_value"] >= 0) & (res_df["p_value"] <= 1))
57 | assert np.all((res_df["adj_p_value"] >= 0) & (res_df["adj_p_value"] <= 1))
58 | assert np.all(res_df["adj_p_value"] >= res_df["p_value"])
59 |
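60 | 
61 | # Sketch of a shared checker (hypothetical) for the assertion block duplicated
62 | # in both tests above; the column names are those mandated by the unified API:
63 | def _check_result_df(res_df, n_expected: int) -> None:
64 |     assert res_df.shape[0] == n_expected, "The result dataframe must contain a value for each var name"
65 |     assert {"variable", "p_value", "log_fc", "adj_p_value"} <= set(res_df.columns)
66 |     assert np.all(res_df["p_value"].between(0, 1))
67 |     assert np.all(res_df["adj_p_value"].between(0, 1))
68 |     assert np.all(res_df["adj_p_value"] >= res_df["p_value"])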
--------------------------------------------------------------------------------
/tests/tools/_differential_gene_expression/test_dge.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import pertpy as pt
4 | import pytest
5 | from anndata import AnnData
6 |
7 | pytest.skip("Disabled", allow_module_level=True)
8 |
9 |
10 | @pytest.fixture
11 | def adata(rng):
12 | adata = AnnData(rng.normal(size=(100, 10)))
13 | genes = np.rec.fromarrays(
14 | [np.array([f"gene{i}" for i in range(10)])],
15 |         names=["group1"],
16 | )
17 | adata.uns["de_key1"] = {
18 | "names": genes,
19 | "scores": {"group1": rng.random(10)},
20 | "pvals_adj": {"group1": rng.random(10)},
21 | }
22 | adata.uns["de_key2"] = {
23 | "names": genes,
24 | "scores": {"group1": rng.random(10)},
25 | "pvals_adj": {"group1": rng.random(10)},
26 | }
27 | return adata
28 |
29 |
30 | @pytest.fixture
31 | def dataframe(rng):
32 | df1 = pd.DataFrame(
33 | {
34 | "variable": ["gene" + str(i) for i in range(10)],
35 | "log_fc": rng.random(10),
36 | "adj_p_value": rng.random(10),
37 | }
38 | )
39 | df2 = pd.DataFrame(
40 | {
41 | "variable": ["gene" + str(i) for i in range(10)],
42 | "log_fc": rng.random(10),
43 | "adj_p_value": rng.random(10),
44 | }
45 | )
46 | return df1, df2
47 |
48 |
49 | def test_error_both_keys_and_dfs(adata, dataframe):
50 | with pytest.raises(ValueError):
51 | pt_DGE = pt.tl.DGEEVAL()
52 | pt_DGE.compare(adata=adata, de_key1="de_key1", de_df1=dataframe[0])
53 |
54 |
55 | def test_error_missing_adata():
56 | with pytest.raises(ValueError):
57 | pt_DGE = pt.tl.DGEEVAL()
58 | pt_DGE.compare(de_key1="de_key1", de_key2="de_key2")
59 |
60 |
61 | def test_error_missing_df(dataframe):
62 | with pytest.raises(ValueError):
63 | pt_DGE = pt.tl.DGEEVAL()
64 | pt_DGE.compare(de_df1=dataframe[0])
65 |
66 |
67 | def test_key(adata):
68 | pt_DGE = pt.tl.DGEEVAL()
69 | results = pt_DGE.compare(adata=adata, de_key1="de_key1", de_key2="de_key2", shared_top=5)
70 | assert "shared_top_genes" in results
71 | assert "scores_corr" in results
72 | assert "pvals_adj_corr" in results
73 | assert "scores_ranks_corr" in results
74 |
75 |
76 | def test_df(dataframe):
77 | pt_DGE = pt.tl.DGEEVAL()
78 | results = pt_DGE.compare(de_df1=dataframe[0], de_df2=dataframe[1], shared_top=5)
79 | assert "shared_top_genes" in results
80 | assert "scores_corr" in results
81 | assert "pvals_adj_corr" in results
82 | assert "scores_ranks_corr" in results
83 |
--------------------------------------------------------------------------------
/tests/tools/_differential_gene_expression/test_edger.py:
--------------------------------------------------------------------------------
1 | import numpy.testing as npt
2 | import pytest
3 | from pertpy.tools._differential_gene_expression import EdgeR, PyDESeq2
4 |
5 | try:
6 | from rpy2.robjects.packages import importr
7 |
8 | r_dependency = importr("edgeR")
9 | except Exception: # noqa: BLE001
10 | r_dependency = None
11 |
12 | pytestmark = pytest.mark.skipif(r_dependency is None, reason="Required R package 'edgeR' not available")
13 |
14 |
15 | def test_edger_simple(test_adata):
16 | """Check that the EdgeR method can be
17 |
18 | 1. Initialized
19 | 2. Fitted
20 |     3. and that test_contrasts returns a DataFrame with the correct number of rows.
21 | """
22 | method = EdgeR(adata=test_adata, design="~condition")
23 | method.fit()
24 | res_df = method.test_contrasts(method.contrast("condition", "A", "B"))
25 |
26 | assert len(res_df) == test_adata.n_vars
27 | # Compare against snapshot
28 | npt.assert_almost_equal(
29 | res_df.p_value.values,
30 | [
31 | 8.0000e-05,
32 | 1.8000e-04,
33 | 5.3000e-04,
34 | 1.1800e-03,
35 | 3.3800e-02,
36 | 3.3820e-02,
37 | 7.7980e-02,
38 | 1.3715e-01,
39 | 2.5052e-01,
40 | 9.2485e-01,
41 | ],
42 | decimal=4,
43 | )
44 | npt.assert_almost_equal(
45 | res_df.log_fc.values,
46 | [0.61208, -0.39374, 0.57944, 0.7343, -0.58675, 0.42575, -0.23951, -0.20761, 0.17489, 0.0247],
47 | decimal=4,
48 | )
49 |
50 |
51 | def test_edger_complex(test_adata):
52 |     """Check that the EdgeR method can be initialized with a different covariate name and fitted,
53 |     and that the test_contrasts method returns a DataFrame with the correct number of rows.
54 | """
55 | test_adata.obs["condition1"] = test_adata.obs["condition"].copy()
56 | method = EdgeR(adata=test_adata, design="~condition1+group")
57 | method.fit()
58 | res_df = method.test_contrasts(method.contrast("condition1", "A", "B"))
59 |
60 | assert len(res_df) == test_adata.n_vars
61 | # Check that the index of the result matches the var_names of the AnnData object
62 | assert set(test_adata.var_names) == set(res_df["variable"])
63 |
64 |     # Cross-check the gene ranking against a different method (PyDESeq2); ranks do not depend on design matrix handling
65 | down_gene = res_df.set_index("variable").loc["gene3", "log_fc"]
66 | up_gene = res_df.set_index("variable").loc["gene1", "log_fc"]
67 | assert down_gene < up_gene
68 |
69 | method = PyDESeq2(adata=test_adata, design="~condition1+group")
70 | method.fit()
71 | deseq_res_df = method.test_contrasts(method.contrast("condition1", "A", "B"))
72 | assert all(res_df.sort_values("log_fc")["variable"].values == deseq_res_df.sort_values("log_fc")["variable"].values)
73 |
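74 | 
75 | # A possible tightening of the cross-method comparison above (sketch only,
76 | # with an assumed threshold): besides an identical ranking, the edgeR and
77 | # PyDESeq2 effect sizes should correlate strongly on the same data.
78 | def _demo_logfc_agreement(res_df, deseq_res_df) -> None:
79 |     import numpy as np
80 | 
81 |     merged = res_df.merge(deseq_res_df, on="variable", suffixes=("_edger", "_deseq2"))
82 |     r = np.corrcoef(merged["log_fc_edger"], merged["log_fc_deseq2"])[0, 1]
83 |     assert r > 0.9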
--------------------------------------------------------------------------------
/tests/tools/_differential_gene_expression/test_input_checks.py:
--------------------------------------------------------------------------------
1 | import anndata as ad
2 | import numpy as np
3 | import pytest
4 | import scipy.sparse as sp
5 | from pertpy.tools._differential_gene_expression import Statsmodels
6 | from pertpy.tools._differential_gene_expression._checks import check_is_integer_matrix, check_is_numeric_matrix
7 |
8 |
9 | @pytest.mark.parametrize(
10 | "matrix_type,invalid_input",
11 | [
12 | [np.array, np.nan],
13 | [np.array, np.inf],
14 | [np.array, "foo"],
15 | # not possible to have a sparse matrix with 'object' dtype (e.g. "foo")
16 |         [sp.csr_matrix, np.nan],
17 |         [sp.csr_matrix, np.inf],
18 |         [sp.csc_matrix, np.nan],
19 |         [sp.csc_matrix, np.inf],
20 | ],
21 | )
22 | def test_invalid_inputs(matrix_type, invalid_input, test_counts, test_metadata):
23 | """Check that invalid inputs in MethodBase counts raise an error."""
24 | test_counts[0, 0] = invalid_input
25 | adata = ad.AnnData(X=matrix_type(test_counts), obs=test_metadata)
26 | with pytest.raises((ValueError, TypeError)):
27 | Statsmodels(adata=adata, design="~condition")
28 |
29 |
30 | @pytest.mark.parametrize("matrix_type", [np.array, sp.csr_matrix, sp.csc_matrix])
31 | @pytest.mark.parametrize(
32 | "input,expected",
33 | [
34 | pytest.param([[1, 2], [3, 4]], None, id="valid"),
35 | pytest.param([[1, -2], [3, 4]], ValueError, id="negative values"),
36 | pytest.param([[1, 2.5], [3, 4]], ValueError, id="non-integer"),
37 | pytest.param([[1, np.nan], [3, 4]], ValueError, id="nans"),
38 | ],
39 | )
40 | def test_check_is_integer_matrix(matrix_type, input, expected: type):
41 | """Test for valid integer matrix."""
42 | matrix = matrix_type(input, dtype=float)
43 |
44 | if expected is None:
45 | check_is_integer_matrix(matrix)
46 | else:
47 | with pytest.raises(expected):
48 | check_is_integer_matrix(matrix)
49 |
50 |
51 | @pytest.mark.parametrize("matrix_type", [np.array, sp.csr_matrix, sp.csc_matrix])
52 | @pytest.mark.parametrize(
53 | "input,expected",
54 | [
55 | pytest.param([[1, 2], [3, 4]], None, id="valid"),
56 | pytest.param([[1, -2], [3, 4]], None, id="negative values"),
57 | pytest.param([[1, 2.5], [3, 4]], None, id="non-integer"),
58 | pytest.param([[1, np.nan], [3, 4]], ValueError, id="nans"),
59 | ],
60 | )
61 | def test_check_is_numeric_matrix(matrix_type, input, expected: type):
62 | """Test for valid numeric matrix.
63 |
64 | This is like the integer matrix check above, except that also negative
65 | and float values are allowed.
66 | """
67 | matrix = matrix_type(input, dtype=float)
68 |
69 | if expected is None:
70 | check_is_numeric_matrix(matrix)
71 | else:
72 | with pytest.raises(expected):
73 | check_is_numeric_matrix(matrix)
74 |
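75 | 
76 | # Sketch of the relationship between the two validators exercised above: the
77 | # numeric check accepts everything the integer check accepts, while negatives
78 | # and non-integers pass only the numeric one.
79 | def _demo_check_relationship() -> None:
80 |     matrix = np.array([[1.0, -2.0], [3.0, 4.5]])
81 |     check_is_numeric_matrix(matrix)  # accepted: finite numeric values
82 |     with pytest.raises(ValueError):
83 |         check_is_integer_matrix(matrix)  # rejected: negative and non-integer entries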
--------------------------------------------------------------------------------
/tests/tools/_differential_gene_expression/test_pydeseq2.py:
--------------------------------------------------------------------------------
1 | from importlib.util import find_spec
2 |
3 | import numpy.testing as npt
4 | import pytest
5 | from pertpy.tools._differential_gene_expression import PyDESeq2
6 |
7 | if find_spec("pydeseq2") is None:
8 | pytestmark = pytest.mark.skip(reason="pydeseq2 not available")
9 |
10 |
11 | def test_pydeseq2_simple(test_adata):
12 | """Check that the pyDESeq2 method can be
13 |
14 | 1. Initialized
15 | 2. Fitted
16 |     3. and that test_contrasts returns a DataFrame with the correct number of rows.
17 | """
18 | method = PyDESeq2(adata=test_adata, design="~condition")
19 | method.fit()
20 | res_df = method.test_contrasts(method.contrast("condition", "A", "B"))
21 |
22 | assert len(res_df) == test_adata.n_vars
23 | # Compare against snapshot
24 | npt.assert_almost_equal(
25 | res_df.p_value.values,
26 | [0.00017, 0.00033, 0.00051, 0.0286, 0.03207, 0.04723, 0.11039, 0.11452, 0.3703, 0.99625],
27 | decimal=4,
28 | )
29 | npt.assert_almost_equal(
30 | res_df.log_fc.values,
31 | [0.58207, 0.53855, -0.4121, 0.63281, -0.63283, -0.27066, -0.21271, 0.38601, 0.13434, 0.00146],
32 | decimal=4,
33 | )
34 |
35 |
36 | def test_pydeseq2_complex(test_adata):
37 |     """Check that the pyDESeq2 method can be initialized with a different covariate name and fitted,
38 |     and that the test_contrasts method returns a DataFrame with the correct number of rows.
39 | """
40 | test_adata.obs["condition1"] = test_adata.obs["condition"].copy()
41 | method = PyDESeq2(adata=test_adata, design="~condition1+group")
42 | method.fit()
43 | res_df = method.test_contrasts(method.contrast("condition1", "A", "B"))
44 |
45 | assert len(res_df) == test_adata.n_vars
46 | # Check that the index of the result matches the var_names of the AnnData object
47 | assert set(test_adata.var_names) == set(res_df["variable"])
48 | # Compare against snapshot
49 | npt.assert_almost_equal(
50 | res_df.p_value.values,
51 | [7e-05, 0.00012, 0.00035, 0.01062, 0.01906, 0.03892, 0.10755, 0.11175, 0.36631, 0.94952],
52 | decimal=4,
53 | )
54 | npt.assert_almost_equal(
55 | res_df.log_fc.values,
56 | [-0.42347, 0.58802, 0.53528, 0.73147, -0.67374, -0.27158, -0.21402, 0.38953, 0.13511, -0.01949],
57 | decimal=4,
58 | )
59 |
60 |
61 | def test_pydeseq2_formula(test_adata):
62 | """Check that the pyDESeq2 method gives consistent results when specifying contrasts, regardless of the order of covariates"""
63 | model1 = PyDESeq2(adata=test_adata, design="~condition+group")
64 | model1.fit()
65 | res_1 = model1.test_contrasts(model1.contrast("condition", "A", "B"))
66 |
67 | model2 = PyDESeq2(adata=test_adata, design="~group+condition")
68 | model2.fit()
69 | res_2 = model2.test_contrasts(model2.contrast("condition", "A", "B"))
70 |
71 | npt.assert_almost_equal(res_2.log_fc.values, res_1.log_fc.values)
72 |
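73 | 
74 | # A natural extension of the invariance check above (sketch): the p-values for
75 | # the same contrast should be unaffected by covariate order as well.
76 | def _demo_pvalue_invariance(res_1, res_2) -> None:
77 |     npt.assert_almost_equal(res_2.p_value.values, res_1.p_value.values)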
--------------------------------------------------------------------------------
/tests/tools/_differential_gene_expression/test_simple_tests.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import pytest
4 | from pandas.core.api import DataFrame
5 | from pertpy.tools._differential_gene_expression import SimpleComparisonBase, TTest, WilcoxonTest
6 |
7 |
8 | @pytest.mark.parametrize(
9 | "paired_by,expected",
10 | [
11 | pytest.param(
12 | None,
13 | {"gene1": {"p_value": 1.34e-14, "log_fc": -5.14}, "gene2": {"p_value": 0.54, "log_fc": -0.016}},
14 | id="unpaired",
15 | ),
16 | pytest.param(
17 | "pairing",
18 | {"gene1": {"p_value": 3.70e-8, "log_fc": -5.14}, "gene2": {"p_value": 0.67, "log_fc": -0.016}},
19 | id="paired",
20 | ),
21 | ],
22 | )
23 | def test_wilcoxon(test_adata_minimal, paired_by, expected):
24 | """Test that wilcoxon test gives the correct values.
25 |
26 | Reference values have been computed in R using wilcox.test
27 | """
28 | res_df = WilcoxonTest.compare_groups(
29 | adata=test_adata_minimal, column="condition", baseline="A", groups_to_compare="B", paired_by=paired_by
30 | )
31 | actual = res_df.loc[:, ["variable", "p_value", "log_fc"]].set_index("variable").to_dict(orient="index")
32 | for gene in expected:
33 | assert actual[gene] == pytest.approx(expected[gene], abs=0.02)
34 |
35 |
36 | @pytest.mark.parametrize(
37 | "paired_by,expected",
38 | [
39 | pytest.param(
40 | None,
41 | {"gene1": {"p_value": 2.13e-26, "log_fc": -5.14}, "gene2": {"p_value": 0.96, "log_fc": -0.016}},
42 | id="unpaired",
43 | ),
44 | pytest.param(
45 | "pairing",
46 | {"gene1": {"p_value": 1.63e-26, "log_fc": -5.14}, "gene2": {"p_value": 0.85, "log_fc": -0.016}},
47 | id="paired",
48 | ),
49 | ],
50 | )
51 | def test_t(test_adata_minimal, paired_by, expected):
52 | """Test that t-test gives the correct values.
53 |
54 |     Reference values have been computed in R using t.test
55 | """
56 | res_df = TTest.compare_groups(
57 | adata=test_adata_minimal, column="condition", baseline="A", groups_to_compare="B", paired_by=paired_by
58 | )
59 | actual = res_df.loc[:, ["variable", "p_value", "log_fc"]].set_index("variable").to_dict(orient="index")
60 | for gene in expected:
61 | assert actual[gene] == pytest.approx(expected[gene], abs=0.02)
62 |
63 |
64 | @pytest.mark.parametrize("seed", range(10))
65 | def test_simple_comparison_pairing(test_adata_minimal, seed):
66 | """Test that paired samples are properly matched in a paired test"""
67 |
68 | class MockSimpleComparison(SimpleComparisonBase):
69 | @staticmethod
70 | def _test():
71 | return None
72 |
73 | def _compare_single_group(
74 | self, baseline_idx: np.ndarray, comparison_idx: np.ndarray, *, paired: bool = False, **kwargs
75 | ) -> DataFrame:
76 | assert paired
77 | x0 = self.adata[baseline_idx, :]
78 | x1 = self.adata[comparison_idx, :]
79 | assert np.all(x0.obs["condition"] == "A")
80 | assert np.all(x1.obs["condition"] == "B")
81 | assert np.all(x0.obs["pairing"].values == x1.obs["pairing"].values)
82 | return pd.DataFrame([{"p_value": 1}])
83 |
84 | rng = np.random.default_rng(seed)
85 | shuffle_adata_idx = rng.permutation(test_adata_minimal.obs_names)
86 | tmp_adata = test_adata_minimal[shuffle_adata_idx, :].copy()
87 |
88 | MockSimpleComparison.compare_groups(
89 | tmp_adata, column="condition", baseline="A", groups_to_compare=["B"], paired_by="pairing"
90 | )
91 |
92 |
93 | @pytest.mark.parametrize(
94 | "params",
95 | [
96 | pytest.param(
97 | {"column": "donor", "baseline": "D0", "paired_by": "pairing", "groups_to_compare": "D1"},
98 | id="pairing not subgroup of donor",
99 | ),
100 | pytest.param(
101 | {"column": "donor", "baseline": "D0", "paired_by": "condition", "groups_to_compare": "D1"},
102 | id="more than two per group (donor)",
103 | ),
104 | pytest.param(
105 | {"column": "condition", "baseline": "A", "paired_by": "donor", "groups_to_compare": "B"},
106 | id="more than two per group (condition)",
107 | ),
108 | ],
109 | )
110 | def test_invalid_pairing(test_adata_minimal, params):
111 | """Test that the SimpleComparisonBase class raises an error when paired analysis is requested with invalid configuration."""
112 | with pytest.raises(ValueError):
113 | TTest.compare_groups(test_adata_minimal, **params)
114 |
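115 | 
116 | # What "valid pairing" means in the cases above (illustrative sketch): every
117 | # pairing value must occur exactly once in the baseline group and once in the
118 | # comparison group, i.e. exactly two observations per pair.
119 | def _demo_valid_pairing() -> None:
120 |     obs = pd.DataFrame({"condition": ["A", "B", "A", "B"], "pairing": ["0", "0", "1", "1"]})
121 |     conditions_per_pair = obs.groupby("pairing")["condition"].nunique()
122 |     assert (conditions_per_pair == 2).all()  # each pair contributes one A and one B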
--------------------------------------------------------------------------------
/tests/tools/_differential_gene_expression/test_statsmodels.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import statsmodels.api as sm
4 | from pertpy.tools._differential_gene_expression import Statsmodels
5 |
6 |
7 | @pytest.mark.parametrize(
8 | "method_class,kwargs",
9 | [
10 | # OLS
11 | (Statsmodels, {}),
12 | # Negative Binomial
13 | (
14 | Statsmodels,
15 | {"regression_model": sm.GLM, "family": sm.families.NegativeBinomial()},
16 | ),
17 | ],
18 | )
19 | def test_statsmodels(test_adata, method_class, kwargs):
20 | """Check that the method can be initialized and fitted, and perform basic checks on
21 | the result of test_contrasts."""
22 | method = method_class(adata=test_adata, design="~condition") # type: ignore
23 | method.fit(**kwargs)
24 | res_df = method.test_contrasts(np.array([0, 1]))
25 | # Check that the result has the correct number of rows
26 | assert len(res_df) == test_adata.n_vars
27 |
28 |
29 | # TODO: there should be a test checking that, for a concrete example, the output p-values and effect sizes are what
30 | # we expect (-> frozen snapshot, so we also get a heads-up if something changes upstream); see the sketch below.
31 |
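32 | # A sketch of that snapshot test, following the pattern of test_pydeseq2.py and
33 | # test_edger.py; kept commented out because the expected arrays are placeholders
34 | # to be frozen from a trusted run, not real values:
35 | # def test_statsmodels_snapshot(test_adata):
36 | #     import numpy.testing as npt
37 | #
38 | #     method = Statsmodels(adata=test_adata, design="~condition")
39 | #     method.fit()
40 | #     res_df = method.test_contrasts(np.array([0, 1]))
41 | #     npt.assert_almost_equal(res_df.p_value.values, EXPECTED_P_VALUES, decimal=4)
42 | #     npt.assert_almost_equal(res_df.log_fc.values, EXPECTED_LOG_FC, decimal=4)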
--------------------------------------------------------------------------------
/tests/tools/_distances/test_distance_tests.py:
--------------------------------------------------------------------------------
1 | import pertpy as pt
2 | import pytest
3 | import scanpy as sc
4 | from pandas import DataFrame
5 |
6 | distances = [
7 | "edistance",
8 | "euclidean",
9 | "mse",
10 | "mean_absolute_error",
11 | "pearson_distance",
12 | "spearman_distance",
13 | "kendalltau_distance",
14 | "cosine_distance",
15 | "wasserstein",
16 | "mean_pairwise",
17 | "mmd",
18 | "r2_distance",
19 | "sym_kldiv",
20 | "t_test",
21 | "ks_test",
22 | "classifier_proba",
23 | # "classifier_cp",
24 | # "nbll",
25 | "mahalanobis",
26 | "mean_var_distribution",
27 | ]
28 |
29 | count_distances = ["nb_ll"]  # count-based distances, not exercised in this module
30 |
31 |
32 | @pytest.fixture
33 | def adata():
34 | adata = pt.dt.distance_example()
35 | adata = sc.pp.subsample(adata, 0.1, copy=True)
36 |
37 | return adata
38 |
39 |
40 | @pytest.mark.parametrize("distance", distances)
41 | def test_distancetest(adata, distance):
42 | etest = pt.tl.DistanceTest(distance, n_perms=10, obsm_key="X_pca", alpha=0.05, correction="holm-sidak")
43 | tab = etest(adata, groupby="perturbation", contrast="control")
44 |
45 | # Well-defined output
46 | assert tab.shape[1] == 5
47 | assert isinstance(tab, DataFrame)
48 |
49 | # p-values are in [0,1]
50 | assert tab["pvalue"].min() >= 0
51 | assert tab["pvalue"].max() <= 1
52 | assert tab["pvalue_adj"].min() >= 0
53 | assert tab["pvalue_adj"].max() <= 1
54 |
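55 | 
56 | # A further invariant worth asserting here (sketch): Holm-Sidak correction can
57 | # only make p-values larger, so the adjusted column should dominate row-wise.
58 | def _demo_adjustment_is_monotone(tab) -> None:
59 |     assert (tab["pvalue_adj"] >= tab["pvalue"]).all()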
--------------------------------------------------------------------------------
/tests/tools/_perturbation_space/test_comparison.py:
--------------------------------------------------------------------------------
1 | import pertpy as pt
2 | import pytest
3 |
4 |
5 | @pytest.fixture
6 | def test_data(rng):
7 | X = rng.normal(size=(100, 10))
8 | Y = rng.normal(size=(100, 10))
9 | C = rng.normal(size=(100, 10))
10 | return X, Y, C
11 |
12 |
13 | def test_compare_class(test_data):
14 | X, Y, C = test_data
15 | pt_comparison = pt.tl.PerturbationComparison()
16 | result = pt_comparison.compare_classification(X, Y, C)
17 | assert result <= 1
18 |
19 |
20 | def test_compare_knn(test_data):
21 | X, Y, C = test_data
22 | pt_comparison = pt.tl.PerturbationComparison()
23 | result = pt_comparison.compare_knn(X, Y, C)
24 | assert isinstance(result, dict)
25 | assert "comp" in result
26 | assert isinstance(result["comp"], float)
27 |
28 | result_no_ctrl = pt_comparison.compare_knn(X, Y)
29 | assert isinstance(result_no_ctrl, dict)
30 |
--------------------------------------------------------------------------------
/tests/tools/_perturbation_space/test_discriminator_classifiers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import pertpy as pt
4 | import pytest
5 | from anndata import AnnData
6 |
7 |
8 | @pytest.fixture
9 | def adata():
10 | X = np.zeros((20, 5), dtype=np.float32)
11 |
12 | pert_index = [
13 | "control",
14 | "target1",
15 | "target1",
16 | "target2",
17 | "target2",
18 | "target1",
19 | "target1",
20 | "target2",
21 | "target2",
22 | "target2",
23 | "control",
24 | "target1",
25 | "target1",
26 | "target2",
27 | "target2",
28 | "target1",
29 | "target1",
30 | "target2",
31 | "target2",
32 | "target2",
33 | ]
34 |
35 | for i, value in enumerate(pert_index):
36 | if value == "control":
37 | X[i, :] = 0
38 | elif value == "target1":
39 | X[i, :] = 10
40 | elif value == "target2":
41 | X[i, :] = 30
42 |
43 | obs = pd.DataFrame({"perturbations": pert_index})
44 |
45 | adata = AnnData(X, obs=obs)
46 |
47 |     # Add obs annotations to the adata
48 | adata.obs["MoA"] = ["Growth" if pert == "target1" else "Unknown" for pert in adata.obs["perturbations"]]
49 | adata.obs["Partial Annotation"] = ["Anno1" if pert == "target2" else np.nan for pert in adata.obs["perturbations"]]
50 |
51 | return adata
52 |
53 |
54 | def test_mlp_classifier_space(adata):
55 | classifier_ps = pt.tl.MLPClassifierSpace()
56 | pert_embeddings = classifier_ps.compute(adata, hidden_dim=[128], max_epochs=2)
57 |
58 |     # The embeddings should cluster into 3 perfect clusters since the perturbations are easily separable
59 | ps = pt.tl.KMeansSpace()
60 | adata = ps.compute(pert_embeddings, n_clusters=3, copy=True)
61 | results = ps.evaluate_clustering(adata, true_label_col="perturbations", cluster_col="k-means")
62 | np.testing.assert_equal(len(results), 3)
63 | np.testing.assert_allclose(results["nmi"], 0.99, rtol=0.1)
64 | np.testing.assert_allclose(results["ari"], 0.99, rtol=0.1)
65 | np.testing.assert_allclose(results["asw"], 0.99, rtol=0.1)
66 |
67 |
68 | def test_regression_classifier_space(adata):
69 | ps = pt.tl.LRClassifierSpace()
70 | pert_embeddings = ps.compute(adata)
71 |
72 | assert pert_embeddings.shape == (3, 5)
73 | assert pert_embeddings.obs[pert_embeddings.obs["perturbations"] == "target1"]["MoA"].values == "Growth"
74 |     assert "Partial Annotation" not in pert_embeddings.obs.columns
75 | # The classifier should be able to distinguish control and target2 from the respective other two classes
76 | assert np.all(
77 | pert_embeddings.obs[pert_embeddings.obs["perturbations"].isin(["control", "target2"])][
78 | "classifier_score"
79 | ].values
80 | == 1.0
81 | )
82 |
--------------------------------------------------------------------------------
/tests/tools/_perturbation_space/test_simple_cluster_space.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import pertpy as pt
4 | from anndata import AnnData
5 |
6 |
7 | def test_clustering():
8 | X = np.zeros((10, 5))
9 |
10 | pert_index = [
11 | "control",
12 | "target1",
13 | "target1",
14 | "target2",
15 | "target2",
16 | "target1",
17 | "target1",
18 | "target2",
19 | "target2",
20 | "target2",
21 | ]
22 |
23 | for i, value in enumerate(pert_index):
24 | if value == "control":
25 | X[i, :] = 0
26 | elif value == "target1":
27 | X[i, :] = 10
28 | elif value == "target2":
29 | X[i, :] = 30
30 |
31 | obs = pd.DataFrame({"perturbations": pert_index})
32 |
33 | adata = AnnData(X, obs=obs)
34 |
35 | # Compute clustering at observation level
36 | ps = pt.tl.KMeansSpace()
37 | adata = ps.compute(adata, n_clusters=3, copy=True)
38 |
39 | ps = pt.tl.DBSCANSpace()
40 | adata = ps.compute(adata, min_samples=1, copy=True)
41 |
42 | results = ps.evaluate_clustering(adata, true_label_col="perturbations", cluster_col="k-means", metric="l1")
43 | np.testing.assert_equal(len(results), 3)
44 | np.testing.assert_allclose(results["nmi"], 0.99, rtol=0.1)
45 | np.testing.assert_allclose(results["ari"], 0.99, rtol=0.1)
46 | np.testing.assert_allclose(results["asw"], 0.99, rtol=0.1)
47 |
48 | results = ps.evaluate_clustering(adata, true_label_col="perturbations", cluster_col="dbscan", metric="l1")
49 | np.testing.assert_equal(len(results), 3)
50 | np.testing.assert_allclose(results["nmi"], 0.99, rtol=0.1)
51 | np.testing.assert_allclose(results["ari"], 0.99, rtol=0.1)
52 | np.testing.assert_allclose(results["asw"], 0.99, rtol=0.1)
53 | 
--------------------------------------------------------------------------------
/tests/tools/test_augur.py:
--------------------------------------------------------------------------------
1 | from math import isclose
2 | from pathlib import Path
3 |
4 | import numpy as np
5 | import pertpy as pt
6 | import pytest
7 | import scanpy as sc
8 |
9 | CWD = Path(__file__).parent.resolve()
10 |
11 |
12 | ag_rfc = pt.tl.Augur("random_forest_classifier", random_state=42)
13 | ag_lrc = pt.tl.Augur("logistic_regression_classifier", random_state=42)
14 | ag_rfr = pt.tl.Augur("random_forest_regressor", random_state=42)
15 |
16 |
17 | @pytest.fixture
18 | def adata():
19 | adata = pt.dt.sc_sim_augur()
20 | adata = sc.pp.subsample(adata, n_obs=200, copy=True, random_state=10)
21 |
22 | return adata
23 |
24 |
25 | def test_load(adata):
26 | """Test if load function creates anndata objects."""
27 | ag = pt.tl.Augur(estimator="random_forest_classifier")
28 |
29 | loaded_adata = ag.load(adata)
30 | loaded_df = ag.load(adata.to_df(), meta=adata.obs, cell_type_col="cell_type", label_col="label")
31 |
32 |     assert loaded_adata.obs["y_"].equals(loaded_df.obs["y_"])
33 |     assert adata.to_df().equals(loaded_adata.to_df()) and adata.to_df().equals(loaded_df.to_df())
34 |
35 |
36 | def test_random_forest_classifier(adata):
37 | """Tests random forest for auc calculation."""
38 | adata = ag_rfc.load(adata)
39 | sc.pp.highly_variable_genes(adata)
40 | h_adata, results = ag_rfc.predict(
41 | adata, n_threads=4, n_subsamples=3, random_state=42, select_variance_features=False
42 | )
43 |
44 | assert results["CellTypeA"][2]["subsample_idx"] == 2
45 | assert "augur_score" in h_adata.obs.columns
46 | assert np.allclose(results["summary_metrics"].loc["mean_augur_score"].tolist(), [0.634920, 0.933484, 0.902494])
47 | assert "feature_importances" in results
48 | assert len(set(results["summary_metrics"]["CellTypeA"])) == len(results["summary_metrics"]["CellTypeA"]) - 1
49 |
50 |
51 | def test_logistic_regression_classifier(adata):
52 | """Tests logistic classifier for auc calculation."""
53 | adata = ag_rfc.load(adata)
54 | sc.pp.highly_variable_genes(adata)
55 | h_adata, results = ag_lrc.predict(
56 | adata, n_threads=4, n_subsamples=3, random_state=42, select_variance_features=False
57 | )
58 |
59 | assert "augur_score" in h_adata.obs.columns
60 | assert np.allclose(results["summary_metrics"].loc["mean_augur_score"].tolist(), [0.691232, 0.955404, 0.972789])
61 | assert "feature_importances" in results
62 |
63 |
64 | def test_random_forest_regressor(adata):
65 | """Tests random forest regressor for ccc calculation."""
66 | adata = ag_rfc.load(adata)
67 | sc.pp.highly_variable_genes(adata)
68 |
69 | with pytest.raises(ValueError):
70 | ag_rfr.predict(adata, n_threads=4, n_subsamples=3, random_state=42)
71 |
72 |
73 | def test_classifier(adata):
74 | """Test run cross validation with classifier."""
75 | adata = ag_rfc.load(adata)
76 | sc.pp.highly_variable_genes(adata)
77 | adata_subsampled = sc.pp.subsample(adata, n_obs=100, random_state=42, copy=True)
78 |
79 | cv = ag_rfc.run_cross_validation(adata_subsampled, subsample_idx=1, folds=3, random_state=42, zero_division=0)
80 | auc = 0.786412
81 |     assert isclose(cv["mean_auc"], auc, abs_tol=10**-3)
82 |
83 | cv = ag_lrc.run_cross_validation(adata, subsample_idx=1, folds=3, random_state=42, zero_division=0)
84 | auc = 0.978673
85 |     assert isclose(cv["mean_auc"], auc, abs_tol=10**-3)
86 |
87 |
88 | def test_regressor(adata):
89 | """Test run cross validation with regressor."""
90 | adata = ag_rfc.load(adata)
91 | cv = ag_rfr.run_cross_validation(adata, subsample_idx=1, folds=3, random_state=42, zero_division=0)
92 | ccc = 0.168800
93 | r2 = 0.149887
94 | assert any([isclose(cv["mean_ccc"], ccc, abs_tol=10**-5), isclose(cv["mean_r2"], r2, abs_tol=10**-5)])
95 |
96 |
97 | def test_subsample(adata):
98 | """Test default, permute and velocity subsampling process."""
99 | adata = ag_rfc.load(adata)
100 | sc.pp.highly_variable_genes(adata)
101 | categorical_subsample = ag_rfc.draw_subsample(
102 | adata=adata,
103 | augur_mode="default",
104 | subsample_size=20,
105 | feature_perc=0.3,
106 | categorical=True,
107 | random_state=42,
108 | )
109 | assert len(categorical_subsample.obs_names) == 40
110 |
111 | non_categorical_subsample = ag_rfc.draw_subsample(
112 | adata=adata,
113 | augur_mode="default",
114 | subsample_size=20,
115 | feature_perc=0.3,
116 | categorical=False,
117 | random_state=42,
118 | )
119 | assert len(non_categorical_subsample.obs_names) == 20
120 |
121 | permut_subsample = ag_rfc.draw_subsample(
122 | adata=adata,
123 | augur_mode="permute",
124 | subsample_size=20,
125 | feature_perc=0.3,
126 | categorical=True,
127 | random_state=42,
128 | )
129 | assert (adata.obs.loc[permut_subsample.obs.index, "y_"] != permut_subsample.obs["y_"]).any()
130 |
131 | velocity_subsample = ag_rfc.draw_subsample(
132 | adata=adata,
133 | augur_mode="velocity",
134 | subsample_size=20,
135 | feature_perc=0.3,
136 | categorical=True,
137 | random_state=42,
138 | )
139 | assert len(velocity_subsample.var_names) == 5505 and len(velocity_subsample.obs_names) == 40
140 |
141 |
142 | def test_select_variance(adata):
143 | """Test select variance implementation."""
144 | adata = ag_rfc.load(adata)
145 | sc.pp.highly_variable_genes(adata)
146 | adata_cell_type = adata[adata.obs["cell_type"] == "CellTypeA"].copy()
147 | ad = ag_rfc.select_variance(adata_cell_type, var_quantile=0.5, span=0.3, filter_negative_residuals=False)
148 |
149 | assert len(ad.var.index[ad.var["highly_variable"]]) == 3672
150 |
151 |
152 | def test_differential_prioritization():
153 | """Test differential prioritization run."""
154 |     # Requires a larger subsample than the other tests; with fewer cells it fails for lack of statistical power
155 | adata = pt.dt.sc_sim_augur()
156 | adata = sc.pp.subsample(adata, n_obs=500, copy=True, random_state=10)
157 | ag = pt.tl.Augur("logistic_regression_classifier", random_state=42)
158 | ag.load(adata)
159 |
160 | adata, results1 = ag.predict(adata, n_threads=4, n_subsamples=3, random_state=2)
161 | adata, results2 = ag.predict(adata, n_threads=4, n_subsamples=3, random_state=42)
162 |
163 | a, permut1 = ag.predict(adata, augur_mode="permute", n_threads=4, n_subsamples=100, random_state=2)
164 | a, permut2 = ag.predict(adata, augur_mode="permute", n_threads=4, n_subsamples=100, random_state=42)
165 | delta = ag.predict_differential_prioritization(results1, results2, permut1, permut2)
166 | assert not np.isnan(delta["z"]).any()
167 |
--------------------------------------------------------------------------------
/tests/tools/test_cinemaot.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import numpy as np
4 | import pertpy as pt
5 | import scanpy as sc
6 | import pytest
7 |
8 | CWD = Path(__file__).parent.resolve()
9 |
10 |
11 | @pytest.fixture
12 | def adata():
13 | adata = pt.dt.cinemaot_example()
14 | adata = sc.pp.subsample(adata, 0.1, copy=True)
15 |
16 | return adata
17 |
18 |
19 | def test_unweighted(adata):
20 | sc.pp.pca(adata)
21 | model = pt.tl.Cinemaot()
22 | de = model.causaleffect(
23 | adata,
24 | pert_key="perturbation",
25 | control="No stimulation",
26 | return_matching=True,
27 | thres=0.5,
28 | smoothness=1e-5,
29 | eps=1e-3,
30 | solver="Sinkhorn",
31 | preweight_label="cell_type0528",
32 | )
33 |
34 | eps = 1e-1
35 | assert "cf" in adata.obsm
36 | assert "ot" in de.obsm
37 | assert not np.isnan(np.sum(de.obsm["ot"]))
38 |     assert np.abs(np.sum(de.obsm["ot"]) - 1) <= eps
39 |
40 |
41 | def test_weighted(adata):
42 | sc.pp.pca(adata)
43 | model = pt.tl.Cinemaot()
44 | ad, de = model.causaleffect_weighted(
45 | adata,
46 | pert_key="perturbation",
47 | control="No stimulation",
48 | return_matching=True,
49 | thres=0.5,
50 | smoothness=1e-5,
51 | eps=1e-3,
52 | solver="Sinkhorn",
53 | )
54 |
55 | eps = 1e-1
56 | assert "cf" in ad.obsm
57 | assert "ot" in de.obsm
58 | assert not np.isnan(np.sum(de.obsm["ot"]))
59 |     assert np.abs(np.sum(de.obsm["ot"]) - 1) <= eps
60 |
61 |
62 | def test_pseudobulk(adata):
63 | sc.pp.pca(adata)
64 | model = pt.tl.Cinemaot()
65 | de = model.causaleffect(
66 | adata,
67 | pert_key="perturbation",
68 | control="No stimulation",
69 | return_matching=True,
70 | thres=0.5,
71 | smoothness=1e-5,
72 | eps=1e-3,
73 | solver="Sinkhorn",
74 | preweight_label="cell_type0528",
75 | )
76 | adata_pb = model.generate_pseudobulk(adata, de, pert_key="perturbation", control="No stimulation", label_list=None)
77 |
78 |     expect_num = 9
79 |     eps = 7  # wide tolerance: the pseudobulk size depends on the random subsample
80 |     assert "ptb" in adata_pb.obs
81 |     assert not np.isnan(np.sum(adata_pb.X))
82 |     assert np.abs(adata_pb.shape[0] - expect_num) < eps
84 |
--------------------------------------------------------------------------------
/tests/tools/test_dialogue.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import pertpy as pt
3 | import scanpy as sc
4 |
5 | # This is not a proper test!
6 | # We are only testing a few functions to ensure that at least these run
7 | # The pipeline is obtained from https://pertpy.readthedocs.io/en/latest/tutorials/notebooks/dialogue.html
8 |
9 |
10 | def test_dialogue_pipeline():
11 | adata = pt.dt.dialogue_example()
12 |
13 | sc.pp.pca(adata)
14 | sc.pp.neighbors(adata)
15 | sc.tl.umap(adata)
16 |
17 |     # drop a sparsely covered subtype, then recompute the cell-type x sample counts
18 |     adata = adata[adata.obs["cell.subtypes"] != "CD8+ IL17+"]
19 |     isecs = pd.crosstab(adata.obs["cell.subtypes"], adata.obs["sample"])
20 |
21 | keep_pts = list(isecs.loc[:, (isecs > 3).sum(axis=0) == isecs.shape[0]].columns.values)
22 | adata = adata[adata.obs["sample"].isin(keep_pts), :].copy()
23 |
24 | dl = pt.tl.Dialogue(
25 | sample_id="sample",
26 | celltype_key="cell.subtypes",
27 | n_counts_key="nCount_RNA",
28 | n_mpcs=3,
29 | )
30 |
31 | adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
32 |
33 | dl.test_association(adata, "path_str")
34 |
35 | dl.get_extrema_MCP_genes(ct_subs)
36 |
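37 | 
38 | # What the keep_pts filter above does, on a toy crosstab (illustrative sketch):
39 | # a sample is kept only if every cell subtype has more than 3 cells in it.
40 | def _demo_keep_pts() -> None:
41 |     isecs = pd.DataFrame({"s1": [4, 5], "s2": [4, 2]}, index=["typeA", "typeB"])
42 |     keep = list(isecs.loc[:, (isecs > 3).sum(axis=0) == isecs.shape[0]].columns.values)
43 |     assert keep == ["s1"]  # s2 is dropped: typeB has only 2 cells there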
--------------------------------------------------------------------------------
/tests/tools/test_enrichment.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pertpy as pt
3 | import pytest
4 | import scanpy as sc
5 | from anndata import AnnData
6 |
7 |
8 | @pytest.fixture
9 | def dummy_adata():
10 | n_obs = 10
11 | n_vars = 5
12 | rng = np.random.default_rng()
13 | X = rng.random((n_obs, n_vars))
14 | adata = AnnData(X)
15 | adata.var_names = [f"gene{i}" for i in range(n_vars)]
16 | adata.obs["cluster"] = ["group_1"] * 5 + ["group_2"] * 5
17 | sc.tl.rank_genes_groups(adata, groupby="cluster", method="t-test")
18 |
19 | return adata
20 |
21 |
22 | @pytest.fixture(scope="module")
23 | def enricher():
24 | return pt.tl.Enrichment()
25 |
26 |
27 | def test_score_basic(dummy_adata, enricher):
28 | targets = {"group1": ["gene1", "gene2"], "group2": ["gene3", "gene4"]}
29 | enricher.score(adata=dummy_adata, targets=targets)
30 | assert "pertpy_enrichment_score" in dummy_adata.uns
31 |
32 |
33 | def test_score_with_different_layers(dummy_adata, enricher):
34 | rng = np.random.default_rng()
35 | dummy_adata.layers["layer"] = rng.random((10, 5))
36 | targets = {"group1": ["gene1", "gene2"], "group2": ["gene3", "gene4"]}
37 | enricher.score(adata=dummy_adata, layer="layer", targets=targets)
38 | assert "pertpy_enrichment_score" in dummy_adata.uns
39 |
40 |
41 | def test_score_with_nested_targets(dummy_adata, enricher):
42 | targets = {"category1": {"group1": ["gene1", "gene2"]}, "category2": {"group2": ["gene3", "gene4"]}}
43 | enricher.score(adata=dummy_adata, targets=targets, nested=True)
44 | assert "pertpy_enrichment_score" in dummy_adata.uns
45 |
46 |
47 | def test_hypergeometric_basic(dummy_adata, enricher):
48 | targets = {"group1": ["gene1", "gene2"]}
49 | results = enricher.hypergeometric(dummy_adata, targets)
50 | assert isinstance(results, dict)
51 |
52 |
53 | def test_hypergeometric_with_nested_targets(dummy_adata, enricher):
54 | targets = {"category1": {"group1": ["gene1", "gene2"]}}
55 | results = enricher.hypergeometric(dummy_adata, targets, nested=True)
56 | assert isinstance(results, dict)
57 |
58 |
59 | @pytest.mark.parametrize("direction", ["up", "down", "both"])
60 | def test_hypergeometric_with_different_directions(dummy_adata, enricher, direction):
61 | targets = {"group1": ["gene1", "gene2"]}
62 | results = enricher.hypergeometric(dummy_adata, targets, direction=direction)
63 | assert isinstance(results, dict)
64 |
--------------------------------------------------------------------------------
/tests/tools/test_scgen.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import anndata as ad
4 | import pertpy as pt
5 | import scanpy as sc
6 | from scvi.data import synthetic_iid
7 |
8 |
9 | def test_scgen():
10 | with warnings.catch_warnings():
11 | warnings.filterwarnings("ignore", message="Observation names are not unique")
12 | adata = synthetic_iid()
13 | adata.obs_names_make_unique()
14 | pt.tl.Scgen.setup_anndata(
15 | adata,
16 | batch_key="batch",
17 | labels_key="labels",
18 | )
19 |
20 | scg = pt.tl.Scgen(adata)
21 | scg.train(max_epochs=1, batch_size=32, early_stopping=True, early_stopping_patience=25)
22 |
23 | scg.batch_removal()
24 |
25 | # predict
26 | pred, delta = scg.predict(ctrl_key="batch_0", stim_key="batch_1", celltype_to_predict="label_0")
27 | pred.obs["batch"] = "pred"
28 |
29 | # reg mean and reg var
30 | ctrl_adata = adata[((adata.obs["labels"] == "label_0") & (adata.obs["batch"] == "batch_0"))]
31 | stim_adata = adata[((adata.obs["labels"] == "label_0") & (adata.obs["batch"] == "batch_1"))]
32 | eval_adata = ad.concat([ctrl_adata, stim_adata, pred], label="concat_batches")
33 | label_0 = adata[adata.obs["labels"] == "label_0"]
34 | sc.tl.rank_genes_groups(label_0, groupby="batch", method="wilcoxon")
35 | diff_genes = label_0.uns["rank_genes_groups"]["names"]["batch_1"]
36 |
37 | scg.plot_reg_mean_plot(
38 | eval_adata,
39 | condition_key="batch",
40 | axis_keys={"x": "pred", "y": "batch_1"},
41 | gene_list=diff_genes[:10],
42 | labels={"x": "predicted", "y": "ground truth"},
43 | save=False,
44 | show=False,
45 | legend=False,
46 | )
47 |
48 | scg.plot_reg_var_plot(
49 | eval_adata,
50 | condition_key="batch",
51 | axis_keys={"x": "pred", "y": "batch_1"},
52 | gene_list=diff_genes[:10],
53 | labels={"x": "predicted", "y": "ground truth"},
54 | save=False,
55 | show=False,
56 | legend=False,
57 | )
58 |
--------------------------------------------------------------------------------