├── .github
│   └── workflows
│       └── ci.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE.md
├── README.md
├── docs
│   ├── override
│   │   └── main.html
│   ├── source
│   │   ├── .pages
│   │   ├── CNAME
│   │   ├── benchmark
│   │   │   ├── diff
│   │   │   │   ├── models.py
│   │   │   │   ├── simulated.py
│   │   │   │   └── timing.py
│   │   │   └── index.md
│   │   ├── index.md
│   │   ├── javascripts
│   │   │   └── reference.js
│   │   ├── quickstart
│   │   │   ├── 0_install.py
│   │   │   ├── 1_data.py
│   │   │   ├── 2_diff.py
│   │   │   └── 3_pred.py
│   │   ├── reference
│   │   │   ├── .pages
│   │   │   ├── data
│   │   │   │   ├── clustering.md
│   │   │   │   ├── folds.md
│   │   │   │   ├── fragments.md
│   │   │   │   ├── motifscan.md
│   │   │   │   ├── regions.md
│   │   │   │   └── transcriptome.md
│   │   │   ├── index.md
│   │   │   ├── loaders
│   │   │   │   └── fragments.md
│   │   │   └── models
│   │   │       ├── diff
│   │   │       │   ├── interpret.md
│   │   │       │   ├── model.md
│   │   │       │   └── plot.md
│   │   │       └── pred
│   │   │           ├── interpret.md
│   │   │           ├── model.md
│   │   │           └── plot.md
│   │   ├── static
│   │   │   ├── comparison.gif
│   │   │   ├── eu.png
│   │   │   ├── favicon.ai
│   │   │   ├── favicon.png
│   │   │   ├── logo.ai
│   │   │   ├── logo.png
│   │   │   └── models
│   │   │       ├── diff
│   │   │       │   ├── 1x
│   │   │       │   │   └── logo.png
│   │   │       │   └── logo.pdf
│   │   │       ├── dime
│   │   │       │   ├── logo.pdf
│   │   │       │   └── logo.png
│   │   │       ├── pred
│   │   │       │   ├── 1x
│   │   │       │   │   └── logo.png
│   │   │       │   └── logo.pdf
│   │   │       └── time
│   │   │           ├── logo.pdf
│   │   │           └── logo.png
│   │   └── stylesheets
│   │       ├── extra-reference.css
│   │       └── extra.css
│   └── tex
│       └── overview.tex
├── mkdocs.yml
├── pyproject.toml
├── scripts
│   ├── benchmark
│   │   └── datasets
│   │       └── pbmc10k
│   │           ├── 1-download.py
│   │           ├── 2-process_all.py
│   │           ├── 3-process_large.py
│   │           ├── 4-process_tiny.py
│   │           └── 5-process_wide.py
│   ├── cythonize.sh
│   ├── dist.sh
│   ├── docs.sh
│   ├── followup
│   │   └── miff
│   │       ├── fit_negbinom.py
│   │       ├── layers_linear_spline.py
│   │       ├── quadratic.py
│   │       └── truncated_normal.py
│   ├── install.sh
│   └── test.py
├── setup.py
├── src
│   └── chromatinhd
│       ├── __init__.py
│       ├── biomart
│       │   ├── __init__.py
│       │   ├── cache.py
│       │   ├── dataset.py
│       │   ├── homology.py
│       │   └── tss.py
│       ├── data
│       │   ├── __init__.py
│       │   ├── associations
│       │   │   ├── __init__.py
│       │   │   ├── associations.py
│       │   │   └── plot.py
│       │   ├── clustering
│       │   │   ├── __init__.py
│       │   │   └── clustering.py
│       │   ├── examples
│       │   │   └── pbmc10ktiny
│       │   │       ├── fragments.tsv.gz
│       │   │       ├── fragments.tsv.gz.REMOVED.git-id
│       │   │       ├── fragments.tsv.gz.tbi
│       │   │       └── transcriptome.h5ad
│       │   ├── folds
│       │   │   ├── __init__.py
│       │   │   └── folds.py
│       │   ├── fragments
│       │   │   ├── __init__.py
│       │   │   ├── fragments.py
│       │   │   └── view.py
│       │   ├── genotype
│       │   │   ├── __init__.py
│       │   │   └── genotype.py
│       │   ├── gradient
│       │   │   ├── __init__.py
│       │   │   └── gradient.py
│       │   ├── motifscan
│       │   │   ├── __init__.py
│       │   │   ├── download.py
│       │   │   ├── motifcount.py
│       │   │   ├── motifscan.py
│       │   │   ├── motiftrack.py
│       │   │   ├── plot.py
│       │   │   ├── plot_genome.py
│       │   │   ├── scan_helpers.html
│       │   │   ├── scan_helpers.pyx
│       │   │   └── view.py
│       │   ├── peakcounts
│       │   │   ├── __init__.py
│       │   │   ├── peakcounts.py
│       │   │   └── plot.py
│       │   ├── regions.py
│       │   └── transcriptome
│       │       ├── __init__.py
│       │       ├── timetranscriptome.py
│       │       └── transcriptome.py
│       ├── device.py
│       ├── embedding.py
│       ├── flow
│       │   ├── __init__.py
│       │   ├── flow.py
│       │   ├── flow_template.jinja2
│       │   ├── linked.py
│       │   ├── objects.py
│       │   ├── sparse.py
│       │   ├── tensorstore.py
│       │   └── tipyte.py
│       ├── loaders
│       │   ├── __init__.py
│       │   ├── clustering.py
│       │   ├── clustering_fragments.py
│       │   ├── extraction
│       │   │   ├── fragments.pyx
│       │   │   └── motifs.pyx
│       │   ├── fragments.py
│       │   ├── fragments2.py
│       │   ├── fragments_helpers.html
│       │   ├── fragments_helpers.pyx
│       │   ├── minibatches.py
│       │   ├── peakcounts.py
│       │   ├── pool.py
│       │   ├── transcriptome.py
│       │   ├── transcriptome_fragments.py
│       │   ├── transcriptome_fragments2.py
│       │   └── transcriptome_fragments_time.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── diff
│       │   │   ├── __init__.py
│       │   │   ├── interpret
│       │   │   │   ├── __init__.py
│       │   │   │   ├── differential.py
│       │   │   │   ├── enrichment
│       │   │   │   │   ├── __init__.py
│       │   │   │   │   ├── enrichment.py
│       │   │   │   │   └── group.py
│       │   │   │   ├── performance.py
│       │   │   │   ├── regionpositional.py
│       │   │   │   └── slices.py
│       │   │   ├── loader
│       │   │   │   ├── __init__.py
│       │   │   │   ├── clustering_cuts.py
│       │   │   │   └── clustering_fragments.py
│       │   │   ├── model
│       │   │   │   ├── __init__.py
│       │   │   │   ├── binary.py
│       │   │   │   ├── cutnf.py
│       │   │   │   ├── encoders.py
│       │   │   │   ├── playground.py
│       │   │   │   ├── spline.py
│       │   │   │   ├── splines
│       │   │   │   │   ├── __init__.py
│       │   │   │   │   ├── cubic.py
│       │   │   │   │   ├── linear.py
│       │   │   │   │   └── quadratic.py
│       │   │   │   └── truncated_normal.py
│       │   │   ├── plot
│       │   │   │   ├── __init__.py
│       │   │   │   ├── differential.py
│       │   │   │   └── differential_expression.py
│       │   │   └── trainer
│       │   │       ├── __init__.py
│       │   │       └── trainer.py
│       │   ├── miff
│       │   │   ├── loader
│       │   │   │   ├── __init__.py
│       │   │   │   ├── clustering.py
│       │   │   │   ├── combinations.py
│       │   │   │   ├── minibatches.py
│       │   │   │   ├── motifcount.py
│       │   │   │   ├── motifs.py
│       │   │   │   └── motifs_fragments.py
│       │   │   └── model
│       │   │       ├── __init__.py
│       │   │       ├── clustering
│       │   │       │   ├── __init__.py
│       │   │       │   ├── distributions.py
│       │   │       │   └── model.py
│       │   │       ├── clustering2
│       │   │       │   ├── __init__.py
│       │   │       │   ├── distributions.py
│       │   │       │   └── model.py
│       │   │       ├── clustering3
│       │   │       │   ├── __init__.py
│       │   │       │   ├── count.py
│       │   │       │   ├── model.py
│       │   │       │   └── position.py
│       │   │       ├── global_norm
│       │   │       │   ├── __init__.py
│       │   │       │   ├── distributions.py
│       │   │       │   └── model.py
│       │   │       ├── local_norm
│       │   │       │   ├── __init__.py
│       │   │       │   ├── distributions.py
│       │   │       │   └── model.py
│       │   │       ├── site_embedder.py
│       │   │       └── zoom.py
│       │   ├── model.py
│       │   ├── pred
│       │   │   ├── __init__.py
│       │   │   ├── interpret
│       │   │   │   ├── __init__.py
│       │   │   │   ├── censorers.py
│       │   │   │   ├── performance.py
│       │   │   │   ├── regionmultiwindow.py
│       │   │   │   ├── regionpairwindow.py
│       │   │   │   ├── regionsizewindow.py
│       │   │   │   └── size.py
│       │   │   ├── model
│       │   │   │   ├── __init__.py
│       │   │   │   ├── better.py
│       │   │   │   ├── encoders.py
│       │   │   │   ├── loss.py
│       │   │   │   ├── multilinear.py
│       │   │   │   ├── multiscale.py
│       │   │   │   ├── peakcounts.py
│       │   │   │   ├── peakcounts_test.py
│       │   │   │   └── shared.py
│       │   │   ├── plot
│       │   │   │   ├── __init__.py
│       │   │   │   ├── copredictivity.py
│       │   │   │   ├── effect.py
│       │   │   │   └── predictivity.py
│       │   │   └── trainer
│       │   │       ├── __init__.py
│       │   │       └── trainer.py
│       │   └── pret
│       │       ├── model
│       │       │   ├── __init__.py
│       │       │   ├── better.py
│       │       │   ├── loss.py
│       │       │   └── peakcounts.py
│       │       └── trainer
│       │           ├── __init__.py
│       │           └── trainer.py
│       ├── optim.py
│       ├── plot
│       │   ├── __init__.py
│       │   ├── genome
│       │   │   ├── __init__.py
│       │   │   └── genes.py
│       │   ├── matshow45.py
│       │   ├── patch.py
│       │   ├── quasirandom.py
│       │   ├── tickers.py
│       │   └── tracks
│       │       ├── __init__.py
│       │       └── tracks.py
│       ├── scoring
│       │   └── prediction
│       │       └── filterers.py
│       ├── simulation
│       │   └── simulate.py
│       ├── sparse.py
│       ├── train.py
│       └── utils
│           ├── __init__.py
│           ├── ecdf.py
│           ├── interleave.py
│           ├── intervals.py
│           ├── numpy.py
│           ├── scanpy.py
│           ├── testing.py
│           ├── timing.py
│           └── torch.py
└── tests
    ├── _test_sparse.py
    ├── biomart
    │   ├── conftest.py
    │   └── tss_test.py
    ├── conftest.py
    ├── data
    │   └── motifscan
    │       └── test_motifscan.py
    ├── loaders
    │   ├── test_fragment_helpers.py
    │   ├── test_fragments.py
    │   └── test_transcriptome_fragments.py
    ├── models
    │   ├── diff
    │   │   ├── loader
    │   │   │   └── test_clustering_cuts.py
    │   │   └── model
    │   │       ├── _test_spline.py
    │   │       └── test_diff_model.py
    │   ├── miff
    │   │   └── _test_zoom.py
    │   └── pred
    │       └── model
    │           └── test_pred_model.py
    └── utils
        └── test_interleave.py
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: test package
2 | concurrency:
3 |   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
4 |   cancel-in-progress: true
5 | on:
6 |   push:
7 |     branches:
8 |       - main
9 |       - devel
10 | permissions:
11 |   contents: write
12 | jobs:
13 |   test:
14 |     runs-on: ubuntu-latest
15 |     steps:
16 |       - uses: actions/checkout@v4
17 |       - uses: actions/setup-python@v5
18 |         with:
19 |           python-version: 3.9
20 |           cache: 'pip'
21 |       - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
22 |       - run: pip install torch==2.3.0 # torch_scatter imports torch in its setup.py, so torch has to be installed first
23 |       - run: pip install torch-scatter -f https://data.pyg.org/whl/torch-2.3.0+cpu.html # torch-scatter needs a prebuilt wheel to install
24 |       - name: regular install
25 |         run: pip install .[full]
26 |       - name: attempt import
27 |         run: python -c "import chromatinhd"
28 |       - name: test install
29 |         run: pip install -e .[test]
30 |       # - uses: jpetrucciani/ruff-check@main
31 |       - name: test
32 |         run: pytest
33 |       # - run: mkdocs gh-deploy --force
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /data/
2 | output
3 | software
4 | *.c
5 |
6 |
7 | *.ipynb
8 |
9 | # Byte-compiled / optimized / DLL files
10 | __pycache__/
11 | *.py[cod]
12 | *$py.class
13 |
14 | # C extensions
15 | *.so
16 |
17 | # Distribution / packaging
18 | .Python
19 | build/
20 | develop-eggs/
21 | dist/
22 | downloads/
23 | eggs/
24 | .eggs/
25 | lib/
26 | lib64/
27 | parts/
28 | sdist/
29 | var/
30 | wheels/
31 | share/python-wheels/
32 | *.egg-info/
33 | .installed.cfg
34 | *.egg
35 | MANIFEST
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .nox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *.cover
57 | *.py,cover
58 | .hypothesis/
59 | .pytest_cache/
60 | cover/
61 |
62 | # Translations
63 | *.mo
64 | *.pot
65 |
66 | # Django stuff:
67 | *.log
68 | local_settings.py
69 | db.sqlite3
70 | db.sqlite3-journal
71 |
72 | # Flask stuff:
73 | instance/
74 | .webassets-cache
75 |
76 | # Scrapy stuff:
77 | .scrapy
78 |
79 | # Sphinx documentation
80 | docs/_build/
81 |
82 | # PyBuilder
83 | .pybuilder/
84 | target/
85 |
86 | # Jupyter Notebook
87 | .ipynb_checkpoints
88 |
89 | # IPython
90 | profile_default/
91 | ipython_config.py
92 |
93 | # pyenv
94 | # For a library or package, you might want to ignore these files since the code is
95 | # intended to run in multiple environments; otherwise, check them in:
96 | # .python-version
97 |
98 | # pipenv
99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
102 | # install all needed dependencies.
103 | #Pipfile.lock
104 |
105 | # poetry
106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
107 | # This is especially recommended for binary packages to ensure reproducibility, and is more
108 | # commonly ignored for libraries.
109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
110 | #poetry.lock
111 |
112 | # pdm
113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
114 | #pdm.lock
115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
116 | # in version control.
117 | # https://pdm.fming.dev/#use-with-ide
118 | .pdm.toml
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 |
171 | # manually added
172 | *.pkl
173 | *.tsv
174 | *trace.png
175 | tmp/
176 | software/
177 | *.parquet
178 |
179 |
180 | .vscode/
181 |
182 | # ignore example dataset
183 | docs/source/quickstart/example/*
184 | docs/source/quickstart/restats
185 | *.code-workspace
186 |
187 |
188 | *.zarr
189 | *.zarr/*
190 | *.hdf5
191 | restats
192 | docs/tex/overview.aux
193 | docs/tex/overview.fdb_latexmk
194 | docs/tex/overview.fls
195 | docs/tex/overview.pdf
196 | docs/tex/overview.synctex.gz
197 |
198 | output/*
199 |
200 |
201 | src/*.c
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   # - repo: https://github.com/astral-sh/ruff-pre-commit
3 |   #   # Ruff version.
4 |   #   rev: v0.0.282
5 |   #   hooks:
6 |   #     - id: ruff
7 |   #       args: [., --fix, --exit-non-zero-on-fix]
8 |   #       exclude: '(^scripts|^docs)/.*'
9 |   - repo: https://github.com/psf/black
10 |     rev: 23.3.0
11 |     hooks:
12 |       - id: black
13 |         language_version: python3
14 |         exclude: '(^scripts|^docs)/.*'
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Wouter Saelens
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 | ChromatinHD analyzes single-cell ATAC+RNA data using the raw fragments as input,
11 | by automatically adapting the scale at which
12 | relevant chromatin changes occur, on a per-position, per-cell, and per-gene basis.
13 | This enables the identification of functional chromatin changes
14 | regardless of whether they occur in a narrow or broad region.
15 |
16 | As we show in [our paper](https://www.biorxiv.org/content/10.1101/2023.07.21.549899v1):
17 | - Compared to the typical approach (peak calling + statistical analysis), ChromatinHD models are better able to capture functional chromatin changes. This is because there are extensive functional accessibility changes both outside and within peaks ([Figure 3](https://www.biorxiv.org/content/10.1101/2023.07.21.549899v1)).
18 | - ChromatinHD models can capture long-range interactions by considering fragments co-occurring within the same cell ([Figure 4](https://www.biorxiv.org/content/10.1101/2023.07.21.549899v1)).
19 | - ChromatinHD models can also capture changes in fragment size that are related to gene expression changes, likely driven by dense direct and indirect binding of transcription factors ([Figure 5](https://www.biorxiv.org/content/10.1101/2023.07.21.549899v1)).
20 |
21 | [📜 Manuscript](https://www.biorxiv.org/content/10.1101/2023.07.21.549899v1)
22 |
23 | [❔ Documentation](https://chromatinhd.org)
24 |
25 | [▶️ Quick start](https://chromatinhd.org/quickstart/0_install)
26 |
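
A minimal sketch of the fragments-based workflow, assembled from the example scripts in this repository; `selected_transcripts`, `obs`, and the file paths are placeholders, and the quick start linked above is the authoritative walkthrough:

```python
import chromatinhd as chd

# regions of interest around each TSS (here ±10kb, as in the benchmark scripts)
regions = chd.data.regions.Regions.from_transcripts(
    selected_transcripts, [-10000, 10000], "example/regions"  # selected_transcripts is a placeholder
)

# raw ATAC fragments linked to per-cell metadata (obs is a placeholder DataFrame)
fragments = chd.data.Fragments.from_fragments_tsv(
    "fragments.tsv.gz", regions=regions, obs=obs, path="example/fragments"
)

# cell groups to contrast, e.g. annotated cell types
clustering = chd.data.Clustering.from_labels(obs["celltype"])

# sample train/validation folds and fit the differential model on raw cut sites
folds = chd.data.folds.Folds()
folds.sample_cells(fragments, 5)
model = chd.models.diff.model.cutnf.Model(fragments, clustering)
model.train_model(fragments, clustering, folds[0], n_epochs=10)
```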
--------------------------------------------------------------------------------
/docs/override/main.html:
--------------------------------------------------------------------------------
1 | {% extends "base.html" %}
2 |
3 | {% block outdated %}
4 | You're not viewing the latest version.
5 |
6 | Click here to go to latest.
7 |
8 | {% endblock %}
9 |
10 | {% block footer %}
11 | {#-
12 | This file was automatically generated - do not edit
13 | -#}
14 |
77 | {% endblock %}
--------------------------------------------------------------------------------
/docs/source/.pages:
--------------------------------------------------------------------------------
1 | nav:
2 | - quickstart
3 | - reference
4 | - benchmark
5 |
--------------------------------------------------------------------------------
/docs/source/CNAME:
--------------------------------------------------------------------------------
1 | chromatinhd.org
--------------------------------------------------------------------------------
/docs/source/benchmark/diff/timing.py:
--------------------------------------------------------------------------------
1 | # ---
2 | # jupyter:
3 | # jupytext:
4 | # formats: ipynb,py:percent
5 | # text_representation:
6 | # extension: .py
7 | # format_name: percent
8 | # format_version: '1.3'
9 | # jupytext_version: 1.14.7
10 | # kernelspec:
11 | # display_name: Python 3 (ipykernel)
12 | # language: python
13 | # name: python3
14 | # ---
15 |
16 | # %%
17 | import polyptich as pp
18 | pp.setup_ipython()
19 |
20 | import numpy as np
21 | import pandas as pd
22 |
23 | import matplotlib.pyplot as plt
24 | import matplotlib as mpl
25 |
26 | import seaborn as sns
27 |
28 | sns.set_style("ticks")
29 | # %config InlineBackend.figure_format='retina'
30 |
31 | import tqdm.auto as tqdm
32 | import torch
33 | import os
34 | import time
35 |
36 | # %%
37 | import chromatinhd as chd
38 |
39 | chd.set_default_device("cuda:1")
40 |
41 | # %%
42 | dataset_folder_original = chd.get_output() / "datasets" / "pbmc10k"
43 | transcriptome_original = chd.data.Transcriptome(dataset_folder_original / "transcriptome")
44 | fragments_original = chd.data.Fragments(dataset_folder_original / "fragments" / "10k10k")
45 |
46 | # %%
47 | genes_oi = transcriptome_original.var.sort_values("dispersions_norm", ascending=False).head(30).index
48 | regions = fragments_original.regions.filter_genes(genes_oi)
49 | fragments = fragments_original.filter_genes(regions)
50 | fragments.create_cellxgene_indptr()
51 | transcriptome = transcriptome_original.filter_genes(regions.coordinates.index)
52 |
53 | # %%
54 | folds = chd.data.folds.Folds()
55 | folds.sample_cells(fragments, 5)
56 |
57 | # %%
58 | clustering = chd.data.Clustering.from_labels(transcriptome_original.obs["celltype"])
59 |
60 | # %%
61 | fold = folds[0]
62 |
63 | # %%
64 | models = {}
65 | scores = []
66 |
67 | # %%
68 | import logging
69 |
70 | logger = chd.models.diff.trainer.trainer.logger
71 | logger.setLevel(logging.DEBUG)
72 | logger.handlers = []
73 | # logger.handlers = [logging.StreamHandler()]
74 |
75 | # %%
76 | devices = pd.DataFrame({"device": ["cuda:0", "cuda:1", "cpu"]}).set_index("device")
77 | for device in devices.index:
78 |     if device != "cpu":
79 |         devices.loc[device, "label"] = torch.cuda.get_device_properties(device).name
80 |     else:
81 |         devices.loc[device, "label"] = os.popen("lscpu").read().split("\n")[13].split(": ")[-1].lstrip()  # brittle: assumes "Model name" is on line 14 of `lscpu` output
82 |
83 | # %%
84 | scores = pd.DataFrame({"device": devices.index}).set_index("device")
85 |
86 | # %%
87 | for device in devices.index:
88 |     start = time.time()
89 |     model = chd.models.diff.model.cutnf.Model(
90 |         fragments,
91 |         clustering,
92 |     )
93 |     model.train_model(fragments, clustering, fold, n_epochs=10, device=device)
94 |     models[device] = model
95 |     end = time.time()
96 |     scores.loc[device, "train"] = end - start
97 |
98 | # %%
99 | for device in devices.index:
100 |     genepositional = chd.models.diff.interpret.genepositional.GenePositional(
101 |         path=chd.get_output() / "interpret" / "genepositional"
102 |     )
103 |
104 |     start = time.time()
105 |     genepositional.score(fragments, clustering, [models[device]], force=True, device=device)
106 |     end = time.time()
107 |     scores.loc[device, "inference"] = end - start
108 |
109 | # %%
110 | fig = pp.grid.Figure(pp.grid.Wrap(padding_width=0.1))
111 | height = len(scores) * 0.2
112 |
113 | plotdata = scores.copy().loc[devices.index]
114 |
115 | panel, ax = fig.main.add(pp.grid.Panel((1, height)))
116 | ax.barh(plotdata.index, plotdata["train"])
117 | ax.set_yticks(np.arange(len(devices)))
118 | ax.set_yticklabels(devices.label)
119 | ax.axvline(0, color="black", linestyle="--", lw=1)
120 | ax.set_title("Training")
121 | ax.set_xlabel("seconds")
122 |
123 | panel, ax = fig.main.add(pp.grid.Panel((1, height)))
124 | ax.barh(plotdata.index, plotdata["inference"])
125 | ax.axvline(0, color="black", linestyle="--", lw=1)
126 | ax.set_title("Inference")
127 | ax.set_yticks([])
128 | fig.plot()
129 |
--------------------------------------------------------------------------------
/docs/source/benchmark/index.md:
--------------------------------------------------------------------------------
1 | # Benchmark
2 |
3 | Lightweight benchmarking of the various models in terms of functionality and scalability. A more comprehensive benchmark against the state of the art was performed in [our paper](https://www.biorxiv.org/content/10.1101/2023.07.21.549899v1).
4 |
--------------------------------------------------------------------------------
/docs/source/index.md:
--------------------------------------------------------------------------------
1 | ChromatinHD analyzes single-cell ATAC+RNA data using the raw fragments as input,
2 | by automatically adapting the scale at which
3 | relevant chromatin changes occur, on a per-position, per-cell, and per-gene basis.
4 | This enables identification of functional chromatin changes
5 | regardless of whether they occur in a narrow or broad region.
6 |
7 | As we show in [our paper](https://www.biorxiv.org/content/10.1101/2023.07.21.549899v1), ChromatinHD models are better able to capture functional chromatin changes than the typical approach, i.e., peak calling + statistical analysis. This is because there are extensive functional accessibility changes both outside and within peaks.
8 |
9 | ChromatinHD models can capture long-range interactions by considering fragments co-occurring within the same cell, as we highlight in [Figure 5 of our paper](https://www.biorxiv.org/content/10.1101/2023.07.21.549899v1).
10 |
11 | ChromatinHD models can also capture changes in fragment size that are related to gene expression changes, likely driven by dense direct and indirect binding of transcription factors, as we highlight in [Figure 6 of our paper](https://www.biorxiv.org/content/10.1101/2023.07.21.549899v1).
12 |
13 | Currently, the following models are supported:
14 |
15 |
58 |
59 |
102 |
--------------------------------------------------------------------------------
/docs/source/javascripts/reference.js:
--------------------------------------------------------------------------------
1 | document.querySelectorAll('.doc-attribute>:is(h1, h2, h3, h4, h5, h6)').forEach(function (el) {
2 |     // add a tag to the attribute at the beginning
3 |     var tag = document.createElement('span');
4 |     tag.className = 'doc-attribute-tag';
5 |     tag.textContent = "attribute";
6 |
7 |     el.prepend(tag);
8 | })
9 |
10 | document.querySelectorAll('.doc-function>:is(h1, h2, h3, h4, h5, h6)').forEach(function (el) {
11 |     // add a tag to the function at the beginning
12 |     var tag = document.createElement('span');
13 |     tag.className = 'doc-function-tag';
14 |     // check whether it's a method (inside a class) or a free function
15 |     if (el.closest('.doc-class')) {
16 |         tag.textContent = "method";
17 |     } else {
18 |         tag.textContent = "function";
19 |     }
20 |
21 |     el.prepend(tag);
22 | })
23 |
24 | document.querySelectorAll('.doc-class>:is(h1, h2, h3, h4, h5, h6)').forEach(function (el) {
25 |     // add a tag to the class at the beginning
26 |     var tag = document.createElement('span');
27 |     tag.className = 'doc-class-tag';
28 |     tag.textContent = "class";
29 |
30 |     el.prepend(tag);
31 | })
--------------------------------------------------------------------------------
/docs/source/quickstart/0_install.py:
--------------------------------------------------------------------------------
1 | # ---
2 | # jupyter:
3 | # jupytext:
4 | # text_representation:
5 | # extension: .py
6 | # format_name: percent
7 | # format_version: '1.3'
8 | # jupytext_version: 1.16.2
9 | # kernelspec:
10 | # display_name: chromatinhd
11 | # language: python
12 | # name: python3
13 | # ---
14 |
15 | # %% [markdown]
16 | # # Installation
17 |
18 | # %% tags=["hide_code"]
19 | # autoreload
20 | import IPython
21 |
22 | if IPython.get_ipython() is not None:
23 |     IPython.get_ipython().run_line_magic("load_ext", "autoreload")
24 |     IPython.get_ipython().run_line_magic("autoreload", "2")
25 |
26 | # %% [markdown]
27 | # ```
28 | # # using pip
29 | # pip install chromatinhd
30 | #
31 | # # (soon) using conda
32 | # conda install -c bioconda chromatinhd
33 | #
34 | # # from github
35 | # pip install git+https://github.com/DeplanckeLab/ChromatinHD
36 | # ```
37 |
38 | # %% [markdown]
39 | #
40 | # To use the GPU, ensure that PyTorch was installed with CUDA enabled:
41 | #
42 |
43 | # %%
44 | import torch
45 |
46 | torch.cuda.is_available() # should return True
47 | torch.cuda.device_count() # should be >= 1
48 |
49 | # %% [markdown]
50 | #
51 | # If not, follow the instructions at https://pytorch.org/get-started/locally/. You may have to re-install PyTorch.
52 |
53 | # %% tags=["hide_output"]
54 | import chromatinhd as chd
55 |
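# %% [markdown]
# With the import working, you can optionally make the GPU the default device for
# ChromatinHD, as done in the benchmark scripts (the device string below is an
# assumption; adjust it to your setup):

# %%
chd.set_default_device("cuda:0")  # only if torch.cuda.is_available() returned True
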
56 | # %% [markdown]
57 | # ## Frequently asked questions
58 |
59 | # %% [markdown]
60 | #
61 |
--------------------------------------------------------------------------------
/docs/source/reference/.pages:
--------------------------------------------------------------------------------
1 | nav:
2 | - data
3 | - models
4 | - loaders
5 |
--------------------------------------------------------------------------------
/docs/source/reference/data/clustering.md:
--------------------------------------------------------------------------------
1 | ::: chromatinhd.data.clustering.Clustering
--------------------------------------------------------------------------------
/docs/source/reference/data/folds.md:
--------------------------------------------------------------------------------
1 | ::: chromatinhd.data.folds.Folds
2 |
--------------------------------------------------------------------------------
/docs/source/reference/data/fragments.md:
--------------------------------------------------------------------------------
1 | ::: chromatinhd.data.fragments.Fragments
2 | ::: chromatinhd.data.fragments.FragmentsView
3 |
--------------------------------------------------------------------------------
/docs/source/reference/data/motifscan.md:
--------------------------------------------------------------------------------
1 | ::: chromatinhd.data.motifscan.Motifscan
2 |
--------------------------------------------------------------------------------
/docs/source/reference/data/regions.md:
--------------------------------------------------------------------------------
1 | ::: chromatinhd.data.Regions
2 |
3 | ::: chromatinhd.data.regions.select_tss_from_fragments
4 |
5 |
--------------------------------------------------------------------------------
/docs/source/reference/data/transcriptome.md:
--------------------------------------------------------------------------------
1 | ::: chromatinhd.data.Transcriptome
--------------------------------------------------------------------------------
/docs/source/reference/index.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/reference/index.md
--------------------------------------------------------------------------------
/docs/source/reference/loaders/fragments.md:
--------------------------------------------------------------------------------
1 | ::: chromatinhd.loaders.fragments.Fragments
2 | ::: chromatinhd.loaders.fragments.Cuts
3 |
4 | ::: chromatinhd.loaders.fragments.FragmentsResult
5 | ::: chromatinhd.loaders.fragments.CutsResult
6 |
--------------------------------------------------------------------------------
/docs/source/reference/models/diff/interpret.md:
--------------------------------------------------------------------------------
1 | ::: chromatinhd.models.diff.interpret.RegionPositional
2 |
--------------------------------------------------------------------------------
/docs/source/reference/models/diff/model.md:
--------------------------------------------------------------------------------
1 |
2 | ## CutNF
3 |
4 | The basic differential model, which looks at each cut site individually, regardless of the other cut sites in the same fragment or cell.
5 |
6 | ::: chromatinhd.models.diff.model.cutnf.Model
7 |     options:
8 |       heading_level: 3
9 |
10 | ::: chromatinhd.models.diff.model.cutnf.Models
11 |     options:
12 |       heading_level: 3
13 |
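A minimal usage sketch, following `docs/source/benchmark/diff/timing.py`; the `fragments`, `clustering`, and `fold` objects are assumed to have been prepared as in the quickstart:

```python
import chromatinhd as chd

# fit the differential accessibility model on raw cut sites for the given clustering
model = chd.models.diff.model.cutnf.Model(fragments, clustering)
model.train_model(fragments, clustering, fold, n_epochs=10, device="cpu")
```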
--------------------------------------------------------------------------------
/docs/source/reference/models/diff/plot.md:
--------------------------------------------------------------------------------
1 |
2 | ::: chromatinhd.models.diff.plot
3 |
4 |
5 |
--------------------------------------------------------------------------------
/docs/source/reference/models/pred/interpret.md:
--------------------------------------------------------------------------------
1 | ::: chromatinhd.models.pred.interpret.RegionMultiWindow
2 | ::: chromatinhd.models.pred.interpret.RegionPairWindow
3 |
4 |
--------------------------------------------------------------------------------
/docs/source/reference/models/pred/model.md:
--------------------------------------------------------------------------------
1 |
2 | ## Additive
3 |
4 | ::: chromatinhd.models.pred.model.multiscale.Model
5 |     options:
6 |       heading_level: 3
7 |
8 | ::: chromatinhd.models.pred.model.multiscale.Models
9 |     options:
10 |       heading_level: 3
11 |
--------------------------------------------------------------------------------
/docs/source/reference/models/pred/plot.md:
--------------------------------------------------------------------------------
1 |
2 | ::: chromatinhd.models.pred.plot
3 |
4 |
--------------------------------------------------------------------------------
/docs/source/static/comparison.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/comparison.gif
--------------------------------------------------------------------------------
/docs/source/static/eu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/eu.png
--------------------------------------------------------------------------------
/docs/source/static/favicon.ai:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/favicon.ai
--------------------------------------------------------------------------------
/docs/source/static/favicon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/favicon.png
--------------------------------------------------------------------------------
/docs/source/static/logo.ai:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/logo.ai
--------------------------------------------------------------------------------
/docs/source/static/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/logo.png
--------------------------------------------------------------------------------
/docs/source/static/models/diff/1x/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/models/diff/1x/logo.png
--------------------------------------------------------------------------------
/docs/source/static/models/diff/logo.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/models/diff/logo.pdf
--------------------------------------------------------------------------------
/docs/source/static/models/dime/logo.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/models/dime/logo.pdf
--------------------------------------------------------------------------------
/docs/source/static/models/dime/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/models/dime/logo.png
--------------------------------------------------------------------------------
/docs/source/static/models/pred/1x/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/models/pred/1x/logo.png
--------------------------------------------------------------------------------
/docs/source/static/models/pred/logo.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/models/pred/logo.pdf
--------------------------------------------------------------------------------
/docs/source/static/models/time/logo.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/models/time/logo.pdf
--------------------------------------------------------------------------------
/docs/source/static/models/time/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/models/time/logo.png
--------------------------------------------------------------------------------
/docs/source/stylesheets/extra-reference.css:
--------------------------------------------------------------------------------
1 | .doc-labels {
2 | display: none;
3 | }
4 |
5 | .doc-class>.doc-heading {
6 | margin-top: 0.5em;
7 | border-top: 3px #33333333 solid;
8 | background-color: var(--md-code-bg-color);
9 | /* padding-top: 0.5em; */
10 | }
11 |
12 | .doc-class {
13 | margin-bottom: 1em;
14 | border-bottom: 1px #33333333 solid;
15 | }
16 |
17 |
18 | .doc-function>.doc-heading {
19 | margin-top: 0.5em;
20 | border-left: 3px #33333333 solid;
21 | background-color: var(--md-code-bg-color);
22 | /* padding-top: 0.5em; */
23 | }
24 |
25 |
26 | .doc-attribute>h3,
27 | .doc-function>h3 {
28 | font-size: 1.0rem;
29 | }
30 |
31 | .doc-contents>details.quote>summary {
32 | background-color: #FFFFFF;
33 | opacity: 0.6;
34 | }
35 |
36 | .doc-contents>details.quote {
37 | border-color: #33333333;
38 | }
39 |
40 | /* attributes, functions, ... */
41 |
42 | .doc-attribute-tag {
43 | margin-left: 5px;
44 | font-size: 0.6em;
45 | font-style: italic;
46 | padding: 0.1em;
47 | color: green;
48 | background-color: rgb(220, 251, 220);
49 | }
50 |
51 | .doc-function-tag {
52 | margin-left: 5px;
53 | font-size: 0.6em;
54 | font-style: italic;
55 | padding: 0.1em;
56 |
57 | color: coral;
58 | background-color: rgb(251, 235, 220);
59 | }
60 |
61 | .doc-class-tag {
62 | margin-left: 5px;
63 | font-size: 0.6em;
64 | font-style: italic;
65 | padding: 0.1em;
66 | color: red;
67 | background-color: rgb(251, 220, 220);
68 | }
69 |
70 | /* bold name of attribute, function, ... */
71 | .doc-attribute code:first-child {
72 | font-weight: bold;
73 | }
74 |
75 | .doc-function code:first-child {
76 | font-weight: bold;
77 | }
78 |
79 | .doc .md-typeset__table,
80 | .doc .md-typeset__table table {
81 | display: table !important;
82 | width: inherit;
83 | }
84 |
85 | .md-typeset table:not([class]) th {
86 | padding-top: 0;
87 | padding-bottom: 0;
88 | }
--------------------------------------------------------------------------------
/docs/source/stylesheets/extra.css:
--------------------------------------------------------------------------------
1 | .jp-InputArea-prompt,
2 | .jp-OutputArea-prompt {
3 | display: none !important;
4 | }
5 |
6 | .jp-OutputArea-output {
7 | font-size: 1.2em;
8 | padding: 0.5em;
9 | border-left: 2px #33333333 solid !important;
10 | }
11 |
12 | .jupyter-wrapper,
13 | .jp-Cell:not(.jp-mod-noOutputs),
14 | .jp-Cell-outputWrapper {
15 | margin-top: 0;
16 | }
17 |
18 | .jp-Cell-outputWrapper img {
19 | box-shadow: rgba(99, 99, 99, 0.2) 0px 2px 8px 0px;
20 | }
21 |
22 | .jp-mod-noOutputs.celltag_hide_code {
23 | display: none;
24 | }
25 |
26 | .md-typeset h1,
27 | .md-typeset h2,
28 | .md-typeset h3,
29 | .md-typeset h4 {
30 | margin: 0;
31 | }
32 |
33 | p {
34 | margin-top: 0.1em;
35 | }
36 |
37 | .md-grid {
38 | max-width: 90rem;
39 | }
40 |
41 | clipboard-copy:hover {
42 | background-color: #33333333;
43 | }
44 |
45 | .jupyter-wrapper .jp-InputArea-editor {
46 | border-left: 2px green solid !important;
47 | }
48 |
49 | /* remove the cell toolbar with ugly tags */
50 | .jupyter-wrapper .celltoolbar {
51 | display: none !important;
52 | }
53 |
54 | /* nav bar */
55 | .md-tabs__item {
56 | height: inherit;
57 | padding-left: 0;
58 | padding-right: 0;
59 | }
60 |
61 | .md-tabs__item>a {
62 | padding-top: 0.8em;
63 | padding-bottom: 0.8em;
64 | padding-left: 0.8em;
65 | padding-right: 0.8em;
66 | margin-top: 0;
67 | transition: 0.1s;
68 | }
69 |
70 | .md-tabs__item>a:hover {
71 | background-color: var(--md-code-bg-color);
72 | }
73 |
74 | .md-tabs__link--active {
75 | background-color: var(--md-code-bg-color);
76 | }
77 |
78 | .md-nav--secondary {
79 | border-left: 2px solid #33333333 !important;
80 | }
81 |
82 | /* side bar */
83 | .md-sidebar {
84 | width: 15rem;
85 | }
86 |
87 | .md-sidebar__inner {
88 | padding-right: inherit !important;
89 | }
90 |
91 | .md-typeset details > p{
92 | margin-top:0.5em;
93 | }
--------------------------------------------------------------------------------
/docs/tex/overview.tex:
--------------------------------------------------------------------------------
1 | \documentclass{article}
2 | \usepackage{graphicx}
3 | \usepackage{amsmath}
4 |
5 | \title{ChromatinHD}
6 | \author{Wouter Saelens}
7 | \date{August 2023}
8 |
9 | \begin{document}
10 |
11 | \maketitle
12 |
13 | \begin{equation}
14 | p(x_1,x_2) = p(x_1)p(x_2|x_1)
15 | \end{equation}
16 |
17 | \begin{equation}
18 | p(x_1) = p(x_{1,l},x_{1,r}) = p(x_{1,l})p(x_{1,r}|x_{1,l})
19 | \end{equation}
20 |
21 | \begin{equation}
22 | p(x_{1,l}) = \prod_{i=0}^{d} p(x_{1,l}|x_{1}\in w_i)
23 | \end{equation}
24 |
25 | \begin{equation}
26 | p(x_{1,l}|x_{1}\in w_a) = f(\mathrm{motifs})
27 | \end{equation}
28 |
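% Equivalently, on the log scale the recursive decomposition sums per-level
% log-probabilities (one possible reading, matching the layered linear splines in
% scripts/followup/miff/layers_linear_spline.py; here w_i(x) denotes the bin at
% zoom level i that contains x, a notation introduced for this note):
\begin{equation}
\log p(x_{1,l}) = \sum_{i=1}^{d} \log p\left(w_i(x_{1,l}) \mid w_{i-1}(x_{1,l})\right)
\end{equation}
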
29 | \end{document}
30 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: High-definition modeling of chromatin + transcriptomics data
2 | docs_dir: docs/source
3 | theme:
4 |   name: material
5 |   palette:
6 |     primary: white
7 |   logo: static/logo.png
8 |   favicon: static/favicon.png
9 |   custom_dir: docs/override
10 |   features:
11 |     - navigation.tracking
12 |     - navigation.tabs
13 |     - navigation.footer
14 |     - navigation.sections
15 |     - toc.follow
16 |     - toc.integrate
17 |
18 | site_url: "http://chromatinhd.org/"
19 | repo_url: https://github.com/DeplanckeLab/ChromatinHD
20 |
21 | plugins:
22 |   - mkdocstrings:
23 |       handlers:
24 |         python:
25 |           import:
26 |             - https://docs.python-requests.org/en/master/objects.inv
27 |             - https://installer.readthedocs.io/en/latest/objects.inv
28 |           options:
29 |             docstring_style: google
30 |             heading_level: 2
31 |             inheritance_diagram: True
32 |             show_root_heading: True
33 |             show_symbol_type_heading: True
34 |             docstring_section_style: table
35 |   - mike
36 |   - search
37 |   - mkdocs-jupyter:
38 |       include: ["*.ipynb"] # Default: ["*.py", "*.ipynb"]
39 |       remove_tag_config:
40 |         remove_input_tags:
41 |           - hide_code
42 |         remove_all_outputs_tags:
43 |           - hide_output
44 |   - social
45 |   - awesome-pages
46 | extra:
47 |   version:
48 |     provider: mike
49 |   analytics:
50 |     provider: google
51 |     property: G-2EDCBPY71H
52 |   social:
53 |     - icon: fontawesome/brands/github
54 |       link: https://github.com/DeplanckeLab/ChromatinHD
55 | markdown_extensions:
56 |   - admonition
57 |   - pymdownx.details
58 |   - pymdownx.superfences
59 |   - attr_list
60 |   - md_in_html
61 | extra_css:
62 |   - stylesheets/extra.css
63 |   - stylesheets/extra-reference.css
64 | extra_javascript:
65 |   - javascripts/reference.js
66 | copyright: Copyright © 2022 - 2024 Wouter Saelens
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=41", "wheel", "setuptools_scm[toml]>=6.2", "numpy", "Cython"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [tool.setuptools-git-versioning]
6 | enabled = true
7 |
8 | [project]
9 | name = "chromatinhd"
10 | authors = [
11 | {name = "Wouter Saelens", email = "wouter.saelens@gmail.com"},
12 | ]
13 | description = "High-definition modeling of (single-cell) chromatin + transcriptomics data"
14 | requires-python = ">=3.8"
15 | keywords = ["bioinformatics", "chromatin accessibility", "transcriptomics"]
16 | classifiers = [
17 | "Programming Language :: Python :: 3",
18 | ]
19 | dependencies = [
20 | "torch", # --extra-index-url https://download.pytorch.org/whl/cu113
21 | "scanpy",
22 | "matplotlib",
23 | "numpy",
24 | "seaborn",
25 | "diskcache",
26 | "appdirs",
27 | "xarray",
28 | "requests",
29 | "zarr",
30 | "pathlibfs",
31 | ]
32 | dynamic = ["version", "readme"]
33 | license = "MIT AND (Apache-2.0 OR BSD-2-Clause)"
34 |
35 | [project.urls]
36 | "Homepage" = "https://github.com/DeplanckeLab/ChromatinHD"
37 | "Bug Tracker" = "https://github.com/DeplanckeLab/ChromatinHD/issues"
38 |
39 | [tool.setuptools.dynamic]
40 | readme = {file = "README.md", content-type = "text/markdown"}
41 |
42 | [project.optional-dependencies]
43 | sam = [
44 | "pysam",
45 | ]
46 | dev = [
47 | "pre-commit",
48 | "pytest",
49 | "coverage",
50 | "polyptich",
51 | "black",
52 | "pylint",
53 | "jupytext",
54 | "mkdocs",
55 | "mkdocs-material",
56 | "mkdocstrings[python]",
57 | "mkdocs-jupyter",
58 | "mike",
59 | "cairosvg", # for mkdocs social
60 | "pillow", # for mkdocs social
61 | "mkdocs-awesome-pages-plugin",
62 | "setuptools_scm",
63 | "Cython",
64 | ]
65 | test = [
66 | "pytest",
67 | "ruff",
68 | ]
69 | full = ["chromatinhd[sam]"]
70 |
71 | [tool.setuptools_scm]
72 |
73 | [tool.pytest.ini_options]
74 | filterwarnings = [
75 | "ignore",
76 | ]
77 |
78 | [tool.pylint.'MESSAGES CONTROL']
79 | max-line-length = 120
80 | disable = [
81 | "too-many-arguments",
82 | "not-callable",
83 | "redefined-builtin",
84 | "redefined-outer-name",
85 | ]
86 |
87 | [tool.ruff]
88 | line-length = 90
89 | include = ['src/**/*.py']
90 | exclude = ['scripts/*']
91 |
92 |
93 | [tool.jupytext]
94 | formats = "ipynb,py:percent"
--------------------------------------------------------------------------------
/scripts/benchmark/datasets/pbmc10k/2-process_all.py:
--------------------------------------------------------------------------------
1 | # ---
2 | # jupyter:
3 | # jupytext:
4 | # formats: ipynb,py:percent
5 | # text_representation:
6 | # extension: .py
7 | # format_name: percent
8 | # format_version: '1.3'
9 | # jupytext_version: 1.14.7
10 | # kernelspec:
11 | # display_name: Python 3 (ipykernel)
12 | # language: python
13 | # name: python3
14 | # ---
15 |
16 | # %% [markdown]
17 | # # Preprocess
18 |
19 | # %%
20 | import polyptich as pp
21 | pp.setup_ipython()
22 |
23 | import numpy as np
24 | import pandas as pd
25 |
26 | import matplotlib.pyplot as plt
27 | import matplotlib as mpl
28 |
29 | import seaborn as sns
30 |
31 | sns.set_style("ticks")
32 |
33 | import torch
34 |
35 | import pickle
36 |
37 | import scanpy as sc
38 |
39 | import tqdm.auto as tqdm
40 | import io
41 |
42 | # %%
43 | import chromatinhd as chd
44 |
45 | # %%
46 | folder_root = chd.get_output()
47 | folder_data = folder_root / "data"
48 |
49 | dataset_name = "pbmc10k"
50 | genome = "GRCh38"
51 |
52 | folder_data_preproc = folder_data / dataset_name
53 | folder_data_preproc.mkdir(exist_ok=True, parents=True)
54 |
55 | # %%
56 | dataset_folder = chd.get_output() / "datasets" / "pbmc10k"
57 | dataset_folder.mkdir(exist_ok=True, parents=True)
58 |
59 | # %%
60 | adata = pickle.load((folder_data_preproc / "adata_annotated.pkl").open("rb"))
61 |
62 | # %%
63 | transcriptome = chd.data.transcriptome.Transcriptome.from_adata(adata, path=dataset_folder / "transcriptome")
64 |
65 | # %%
66 | import genomepy
67 |
68 | # genomepy.install_genome("GRCh38", genomes_dir="/data/genome/")
69 |
70 | sizes_file = "/data/genome/GRCh38/GRCh38.fa.sizes"
71 |
72 | # %%
73 | regions = chd.data.regions.Regions.from_chromosomes_file(sizes_file, path = dataset_folder / "regions" / "all")
74 |
75 | # %%
76 | fragments_file = folder_data_preproc / "fragments.tsv.gz"
77 |
78 | # %%
79 | fragments = chd.data.Fragments.from_fragments_tsv(fragments_file, regions = regions, obs = transcriptome.obs, path = dataset_folder / "fragments" / "all")
80 |
81 | # %%
82 | fragments.create_regionxcell_indptr()
83 |
--------------------------------------------------------------------------------
/scripts/benchmark/datasets/pbmc10k/3-process_large.py:
--------------------------------------------------------------------------------
1 | # ---
2 | # jupyter:
3 | # jupytext:
4 | # formats: ipynb,py:percent
5 | # text_representation:
6 | # extension: .py
7 | # format_name: percent
8 | # format_version: '1.3'
9 | # jupytext_version: 1.14.7
10 | # kernelspec:
11 | # display_name: Python 3 (ipykernel)
12 | # language: python
13 | # name: python3
14 | # ---
15 |
16 | # %% [markdown]
17 | # # Preprocess
18 |
19 | # %%
20 | import polyptich as pp
21 | pp.setup_ipython()
22 |
23 | import numpy as np
24 | import pandas as pd
25 |
26 | import matplotlib.pyplot as plt
27 | import matplotlib as mpl
28 |
29 | import seaborn as sns
30 |
31 | sns.set_style("ticks")
32 |
33 | import torch
34 |
35 | import pickle
36 |
37 | import scanpy as sc
38 |
39 | import tqdm.auto as tqdm
40 | import io
41 |
42 | # %%
43 | import chromatinhd as chd
44 |
45 | # %%
46 | folder_root = chd.get_output()
47 | folder_data = folder_root / "data"
48 |
49 | dataset_name = "pbmc10k"
50 | genome = "GRCh38"
51 |
52 | folder_data_preproc = folder_data / dataset_name
53 | folder_data_preproc.mkdir(exist_ok=True, parents=True)
54 |
55 | # %%
56 | dataset_folder = chd.get_output() / "datasets" / "pbmc10k"
57 | dataset_folder.mkdir(exist_ok=True, parents=True)
58 |
59 | # %%
60 | adata = pickle.load((folder_data_preproc / "adata_annotated.pkl").open("rb"))
61 |
62 | # %%
63 | transcriptome = chd.data.transcriptome.Transcriptome.from_adata(adata, path=dataset_folder / "transcriptome")
64 |
65 | # %%
66 | selected_transcripts = pickle.load((folder_data_preproc / "selected_transcripts.pkl").open("rb"))
67 | regions = chd.data.regions.Regions.from_transcripts(
68 | selected_transcripts, [-10000, 10000], dataset_folder / "regions" / "10k10k"
69 | )
70 |
71 | # %%
72 | fragments_file = folder_data_preproc / "fragments.tsv.gz"
73 | fragments = chd.data.fragments.Fragments(dataset_folder / "fragments" / "10k10k")
74 | fragments.regions = regions
75 | fragments = chd.data.fragments.Fragments.from_fragments_tsv(
76 | fragments_file=fragments_file,
77 | regions=regions,
78 | obs=transcriptome.obs,
79 | path=fragments.path,
80 | )
81 |
82 | # %%
83 |
--------------------------------------------------------------------------------
/scripts/benchmark/datasets/pbmc10k/4-process_tiny.py:
--------------------------------------------------------------------------------
1 | # ---
2 | # jupyter:
3 | # jupytext:
4 | # formats: ipynb,py:percent
5 | # text_representation:
6 | # extension: .py
7 | # format_name: percent
8 | # format_version: '1.3'
9 | # jupytext_version: 1.14.7
10 | # kernelspec:
11 | # display_name: Python 3 (ipykernel)
12 | # language: python
13 | # name: python3
14 | # ---
15 |
16 | # %% [markdown]
17 | # # Preprocess
18 |
19 | # %%
20 | import polyptich as pp
21 | pp.setup_ipython()
22 |
23 | import numpy as np
24 | import pandas as pd
25 |
26 | import matplotlib.pyplot as plt
27 | import matplotlib as mpl
28 |
29 | import seaborn as sns
30 |
31 | sns.set_style("ticks")
32 |
33 | import torch
34 |
35 | import pickle
36 |
37 | import scanpy as sc
38 |
39 | import tqdm.auto as tqdm
40 | import io
41 |
42 | # %%
43 | import chromatinhd as chd
44 |
45 | # %%
46 | folder_root = chd.get_output()
47 | folder_data = folder_root / "data"
48 |
49 | dataset_name = "pbmc10k"
50 | main_url = "https://cf.10xgenomics.com/samples/cell-arc/2.0.0/pbmc_granulocyte_sorted_10k/pbmc_granulocyte_sorted_10k"
51 | genome = "GRCh38"
52 |
53 | folder_data_preproc = folder_data / dataset_name
54 | folder_data_preproc.mkdir(exist_ok=True, parents=True)
55 |
56 | # %% [markdown]
57 | # ## Tiny dataset
58 |
59 | # %%
60 | dataset_folder = chd.get_output() / "datasets" / "pbmc10ktiny"
61 | folder_dataset_publish = chd.get_git_root() / "src" / "chromatinhd" / "data" / "examples" / "pbmc10ktiny"
62 | folder_dataset_publish.mkdir(exist_ok=True, parents=True)
63 |
64 | # %%
65 | adata = pickle.load((folder_data_preproc / "adata_annotated.pkl").open("rb"))
66 | genes_oi = adata.var.sort_values("dispersions_norm", ascending=False).index[:50]
67 | adata = adata[:, genes_oi]
68 | adata_tiny = sc.AnnData(X=adata[:, genes_oi].layers["magic"], obs=adata.obs, var=adata.var.loc[genes_oi])
69 |
70 | # %%
71 | transcriptome = chd.data.transcriptome.Transcriptome.from_adata(adata, path=dataset_folder / "transcriptome")
72 | adata_tiny.write(folder_dataset_publish / "transcriptome.h5ad", compression="gzip")
73 |
74 | # %%
75 | selected_transcripts = pickle.load((folder_data_preproc / "selected_transcripts.pkl").open("rb"))
76 | selected_transcripts = selected_transcripts.loc[genes_oi]
77 | regions = chd.data.regions.Regions.from_transcripts(
78 | selected_transcripts, [-10000, 10000], path=dataset_folder / "regions" / "10k10k"
79 | )
80 |
81 | # %%
82 | import pysam
83 |
84 | fragments_tabix = pysam.TabixFile(str(folder_data_preproc / "fragments.tsv.gz"))
85 | coordinates = regions.coordinates
86 |
87 | fragments_new = []
88 | for i, (gene, promoter_info) in tqdm.tqdm(enumerate(coordinates.iterrows()), total=coordinates.shape[0]):
89 |     start = max(0, promoter_info["start"])
90 |
91 |     fragments_promoter = fragments_tabix.fetch(promoter_info["chrom"], start, promoter_info["end"])
92 |     fragments_new.extend(list(fragments_promoter))
93 |
94 | fragments_new = pd.DataFrame(
95 | [x.split("\t") for x in fragments_new], columns=["chrom", "start", "end", "cell", "nreads"]
96 | )
97 | fragments_new["start"] = fragments_new["start"].astype(int)
98 | fragments_new = fragments_new.sort_values(["chrom", "start", "cell"])
99 |
100 | # %%
101 | fragments_new.to_csv(folder_dataset_publish / "fragments.tsv", sep="\t", index=False, header=False)
102 | pysam.tabix_compress(folder_dataset_publish / "fragments.tsv", folder_dataset_publish / "fragments.tsv.gz", force=True)
103 |
104 | # %%
105 | # !ls -lh {folder_dataset_publish}
106 | # !tabix -p bed {folder_dataset_publish / "fragments.tsv.gz"}
107 | folder_dataset_publish
108 |
109 | # %%
110 | transcriptome.var.index
111 |
112 | # %%
113 | coordinates.index
114 |
--------------------------------------------------------------------------------
/scripts/cythonize.sh:
--------------------------------------------------------------------------------
1 | cython src/chromatinhd/loaders/fragments_helpers.pyx --embed -a
2 | cython src/chromatinhd/data/motifscan/scan_helpers.pyx --embed -a
--------------------------------------------------------------------------------
/scripts/dist.sh:
--------------------------------------------------------------------------------
1 | python -m setuptools_git_versioning
2 |
3 | version="0.4.3"
4 |
5 | git add .
6 | git commit -m "version v${version}"
7 |
8 | git tag -a v${version} -m "v${version}"
9 |
10 | python -m build
11 |
12 | # twine upload --repository testpypi dist/chromatinhd-${version}.tar.gz --verbose
13 |
14 | git push --tags
15 |
16 | gh release create v${version} -t "v${version}" -n "v${version}" dist/chromatinhd-${version}.tar.gz
17 |
18 | twine upload dist/chromatinhd-${version}.tar.gz --verbose
19 |
20 | python -m build --wheel
--------------------------------------------------------------------------------
/scripts/docs.sh:
--------------------------------------------------------------------------------
1 | jupytext --sync docs/source/*/*.py
2 |
3 |
4 | mkdocs build
--------------------------------------------------------------------------------
/scripts/followup/miff/fit_negbinom.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | ## fit_nbinom
4 | # Copyright (C) 2014 Gokcen Eraslan
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
18 |
19 | from __future__ import print_function
20 |
21 | import numpy as np
22 | from scipy.special import gammaln
23 | from scipy.special import psi
24 | from scipy.special import factorial
25 | from scipy.optimize import fmin_l_bfgs_b as optim
26 |
27 | import sys
28 |
29 |
30 | # X is a numpy array representing the data
31 | # initial params is a numpy array representing the initial values of
32 | # size and prob parameters
33 | def fit_nbinom(X, initial_params=None):
34 |     infinitesimal = np.finfo(np.float64).eps  # np.float was removed in NumPy 1.24
35 |
36 |     def log_likelihood(params, *args):
37 |         r, p = params
38 |         X = args[0]
39 |         N = X.size
40 |
41 |         # MLE estimate based on the formula on Wikipedia:
42 |         # http://en.wikipedia.org/wiki/Negative_binomial_distribution#Maximum_likelihood_estimation
43 |         result = (
44 |             np.sum(gammaln(X + r))
45 |             - np.sum(np.log(factorial(X)))
46 |             - N * (gammaln(r))
47 |             + N * r * np.log(p)
48 |             + np.sum(X * np.log(1 - (p if p < 1 else 1 - infinitesimal)))
49 |         )
50 |
51 |         return -result
52 |
53 |     def log_likelihood_deriv(params, *args):
54 |         r, p = params
55 |         X = args[0]
56 |         N = X.size
57 |
58 |         pderiv = (N * r) / p - np.sum(X) / (1 - (p if p < 1 else 1 - infinitesimal))
59 |         rderiv = np.sum(psi(X + r)) - N * psi(r) + N * np.log(p)
60 |
61 |         return np.array([-rderiv, -pderiv])
62 |
63 |     if initial_params is None:
64 |         # reasonable initial values (from fitdistr function in R)
65 |         m = np.mean(X)
66 |         v = np.var(X)
67 |         size = (m**2) / (v - m) if v > m else 10
68 |
69 |         # convert mu/size parameterization to prob/size
70 |         p0 = size / ((size + m) if size + m != 0 else 1)
71 |         r0 = size
72 |         initial_params = np.array([r0, p0])
73 |
74 |     bounds = [(infinitesimal, None), (infinitesimal, 1)]
75 |     optimres = optim(
76 |         log_likelihood,
77 |         x0=initial_params,
78 |         # fprime=log_likelihood_deriv,
79 |         args=(X,),
80 |         approx_grad=1,
81 |         bounds=bounds,
82 |     )
83 |
84 |     params = optimres[0]
85 |     return {"size": params[0], "prob": params[1]}
86 |
87 |
88 | if __name__ == "__main__":
89 | if len(sys.argv) != 3:
90 | print("Usage: %s size_param prob_param" % sys.argv[0])
91 | exit()
92 |
93 | testset = np.random.negative_binomial(float(sys.argv[1]), float(sys.argv[2]), 1000)
94 | print(fit_nbinom(testset))
95 |
--------------------------------------------------------------------------------
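A quick sanity check for fit_nbinom above (a sketch; assumes scipy is installed and the script is importable as a module named fit_negbinom). Under this size/prob parameterization the distribution mean is size * (1 - prob) / prob, so the fitted mean should track the sample mean:

    import numpy as np
    from scipy import stats

    from fit_negbinom import fit_nbinom  # import path assumed

    rng = np.random.default_rng(0)
    sample = rng.negative_binomial(5.0, 0.3, size=10_000)
    fit = fit_nbinom(sample)  # {"size": ..., "prob": ...}
    # the fitted mean should be close to the empirical mean
    print(stats.nbinom(fit["size"], fit["prob"]).mean(), sample.mean())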
/scripts/followup/miff/layers_linear_spline.py:
--------------------------------------------------------------------------------
1 | # ---
2 | # jupyter:
3 | # jupytext:
4 | # text_representation:
5 | # extension: .py
6 | # format_name: percent
7 | # format_version: '1.3'
8 | # jupytext_version: 1.15.1
9 | # kernelspec:
10 | # display_name: Python 3
11 | # language: python
12 | # name: python3
13 | # ---
14 |
15 | # %%
16 | import polyptich as pp
17 | pp.setup_ipython()
18 |
19 | import numpy as np
20 | import pandas as pd
21 |
22 | import matplotlib.pyplot as plt
23 | import matplotlib as mpl
24 |
25 | import seaborn as sns
26 |
27 | sns.set_style("ticks")
28 | # %config InlineBackend.figure_format='retina'
29 |
30 | import tqdm.auto as tqdm
31 |
32 | # %%
33 | import torch
34 | import chromatinhd as chd
35 | chd.set_default_device("cuda:0")
36 |
37 | import tempfile
38 | import pathlib
39 | import pickle
40 |
41 | # %%
42 | import quadratic
43 |
44 | # %%
45 | width = 1024
46 | positions = torch.tensor([100, 500, 1000])
47 | nbins = [8, 8, 8]
48 |
49 | unnormalized_heights_bins = [
50 | torch.tensor([[1] * 8]*3, dtype = torch.float),
51 | torch.tensor([[1] * 8]*3, dtype = torch.float),
52 | torch.tensor([[1] * 8]*3, dtype = torch.float),
53 | # torch.tensor([[1] * 2]*3, dtype = torch.float),
54 | ]
55 | unnormalized_heights_bins[1][1, 3] = 10
56 |
57 | # %%
58 | unnormalized_heights_all = []
59 | cur_total_n = 1
60 | for n in nbins:
61 | cur_total_n *= n
62 | unnormalized_heights_all.append(torch.zeros(cur_total_n).reshape(-1, n))
63 | unnormalized_heights_all[0][0, 3] = -1
64 | unnormalized_heights_all[1][2, 2:4] = 1
65 | unnormalized_heights_all[2][0, 2:4] = 1
66 |
67 | # %%
68 | import math
69 |
70 | def transform_linear_spline(positions, n, width, unnormalized_heights):
71 | binsize = width//n
72 |
73 | normalized_heights = torch.nn.functional.log_softmax(unnormalized_heights, -1)
74 | if normalized_heights.ndim == positions.ndim:
75 | normalized_heights = normalized_heights.unsqueeze(0)
76 |
77 | binixs = torch.div(positions, binsize, rounding_mode = "trunc")
78 |
79 | logprob = torch.gather(normalized_heights, 1, binixs.unsqueeze(1)).squeeze(1)
80 |
81 | positions = positions - binixs * binsize
82 | width = binsize
83 |
84 | return logprob, positions, width
85 |
86 | def calculate_logprob(positions, nbins, width, unnormalized_heights_bins):
87 | assert len(nbins) == len(unnormalized_heights_bins)
88 |
89 | curpositions = positions
90 | curwidth = width
91 | logprob = torch.zeros_like(positions, dtype = torch.float)
92 | for i, n in enumerate(nbins):
93 | assert (curwidth % n) == 0
94 | logprob_layer, curpositions, curwidth = transform_linear_spline(curpositions, n, curwidth, unnormalized_heights_bins[i])
95 | logprob += logprob_layer
96 | logprob = logprob - math.log(curwidth)
97 | return logprob
98 |
99 |
100 | # %%
101 | x = torch.arange(width)
102 |
103 | # %%
104 | totalnbins = np.cumprod(nbins)
105 | totalbinwidths = torch.tensor(width//totalnbins)
106 | totalbinixs = torch.div(x[:, None], totalbinwidths, rounding_mode="floor")
107 | totalbinsectors = torch.div(totalbinixs, torch.tensor(nbins)[None, :], rounding_mode="trunc")
108 | unnormalized_heights_bins = [unnormalized_heights_all[i][totalbinsector] for i, totalbinsector in enumerate(totalbinsectors.numpy().T)]
109 |
110 | # %%
111 | torch.tensor(width//np.array(nbins))
112 |
113 | # %%
 114 | totalbinwidths
115 |
116 | # %%
117 | logprob = calculate_logprob(x, nbins, width, unnormalized_heights_bins)
118 | fig, ax = plt.subplots()
119 | ax.plot(x, torch.exp(logprob))
120 |
--------------------------------------------------------------------------------
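A minimal check of the hierarchical spline above (a sketch; assumes the notebook's definitions are in scope). Each layer's heights are log-softmax normalized and the final bin width is divided out, so exponentiating the per-position log-probabilities and summing over all positions should give approximately 1:

    density = torch.exp(calculate_logprob(x, nbins, width, unnormalized_heights_bins))
    print(density.sum())  # ~1.0: the layered spline defines a normalized density over the region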
/scripts/followup/miff/quadratic.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | DEFAULT_MIN_BIN_HEIGHT = 1e-5
4 |
5 |
6 | def calculate_heights(unnormalized_heights, widths, min_bin_height=DEFAULT_MIN_BIN_HEIGHT, local=True):
7 | unnorm_heights_exp = torch.exp(unnormalized_heights)
8 |
 9 |     min_bin_height = 1e-10  # hard-coded floor; this overrides the min_bin_height argument above
10 |
11 | if local:
12 | # per feature normalization
13 | unnormalized_area = torch.sum(
14 | ((unnorm_heights_exp[..., :-1] + unnorm_heights_exp[..., 1:]) / 2) * widths,
15 | dim=-1,
16 | keepdim=True,
17 | )
18 | heights = unnorm_heights_exp / unnormalized_area
19 | heights = min_bin_height + (1 - min_bin_height) * heights
20 | else:
21 | # global normalization
22 | unnormalized_area = torch.sum(
23 | ((unnorm_heights_exp[..., :-1] + unnorm_heights_exp[..., 1:]) / 2) * widths,
24 | )
25 | heights = unnorm_heights_exp * unnorm_heights_exp.shape[-2] / unnormalized_area
26 | heights = min_bin_height + (1 - min_bin_height) * heights
27 |
28 | # to check
29 | # normalized_area = torch.sum(
30 | # ((heights[..., :-1] + heights[..., 1:]) / 2) * widths,
31 | # dim=-1,
32 | # keepdim=True,
33 | # )
34 | # print(normalized_area.sum())
35 |
36 | return heights
37 |
38 |
39 | def calculate_bin_left_cdf(heights, widths):
40 | bin_left_cdf = torch.cumsum(((heights[..., :-1] + heights[..., 1:]) / 2) * widths, dim=-1)
41 | bin_left_cdf[..., -1] = 1.0
 42 |     bin_left_cdf = torch.nn.functional.pad(bin_left_cdf, pad=(1, 0), mode="constant", value=0.0)
43 | return bin_left_cdf
44 |
--------------------------------------------------------------------------------
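A quick check of calculate_heights above (a sketch): with bin widths summing to 1, the trapezoid area under the normalized heights is exactly 1 per feature, even after the min_bin_height floor is applied:

    import torch
    from quadratic import calculate_heights  # module name as imported in the notebook above

    widths = torch.full((4,), 0.25)   # 5 knots -> 4 equal bins spanning [0, 1]
    unnormalized = torch.randn(2, 5)  # two features, 5 knot heights each
    heights = calculate_heights(unnormalized, widths)
    area = (((heights[..., :-1] + heights[..., 1:]) / 2) * widths).sum(-1)
    print(area)  # ~1.0 for each feature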
/scripts/install.sh:
--------------------------------------------------------------------------------
 1 | pip uninstall torch
 2 | pip cache purge
 3 |
 4 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
 5 | pip install torch-scatter -f https://data.pyg.org/whl/torch-2.0.1+cu118.html
 6 |
--------------------------------------------------------------------------------
/scripts/test.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import requests
3 |
4 | # Send a GET request to the BioMart REST API
5 | response = requests.get('http://www.ensembl.org/biomart/martservice?type=registry')
6 |
7 | # The response is an XML string, so parse it using ElementTree
8 | import xml.etree.ElementTree as ET
9 | root = ET.fromstring(response.text)
10 |
11 | # Iterate over all the MartURLLocation elements (these represent the datasets)
12 | for dataset in root.iter('MartURLLocation'):
13 | # The species is stored in the 'name' attribute
14 | species = dataset.get('name')
15 |
16 | # If the species is human, print the details of the dataset
17 | if 'hsapiens' in species:
18 | print(ET.tostring(dataset, encoding='utf8').decode('utf8'))
19 |
20 | # %%
21 | response.text
22 | # %%
23 | import pandas as pd
24 | import io
25 |
26 | def get_datasets():
27 | mart = "ENSEMBL_MART_ENSEMBL"
28 | baseurl = "http://www.ensembl.org/biomart/martservice?"
29 | url = "{baseurl}type=datasets&requestid=biomaRt&mart={mart}"
30 | response = requests.get(url.format(baseurl=baseurl, mart=mart))
31 | root = pd.read_table(io.StringIO(response.text), sep="\t", header=None, names=["_", "dataset", "description", "version", "assembly", "__", "___", "____", "last_update"])
 32 |     return root
33 | # %%
 34 | get_datasets()
35 | # %%
36 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, Extension
2 | from Cython.Build import cythonize
3 | from Cython.Compiler import Options
4 | import numpy
5 |
6 | # These are optional
7 | Options.docstrings = True
8 | Options.annotate = False
9 |
10 | # Modules to be compiled and include_dirs when necessary
11 | extensions = [
12 | # Extension(
13 | # "pyctmctree.inpyranoid_c",
14 | # ["src/pyctmctree/inpyranoid_c.pyx"],
15 | # ),
16 | Extension(
17 | "chromatinhd.loaders.fragments_helpers",
18 | ["src/chromatinhd/loaders/fragments_helpers.pyx"],
19 | include_dirs=[numpy.get_include()],
20 | py_limited_api=True,
21 | ),
22 | Extension(
23 | "chromatinhd.data.motifscan.scan_helpers",
24 | ["src/chromatinhd/data/motifscan/scan_helpers.pyx"],
25 | include_dirs=[numpy.get_include()],
26 | py_limited_api=True,
27 | ),
28 | ]
29 |
30 |
31 | # This is the function that is executed
32 | setup(
33 | name='chromatinhd', # Required
34 |
35 | # A list of compiler Directives is available at
36 | # https://cython.readthedocs.io/en/latest/src/userguide/source_files_and_compilation.html#compiler-directives
37 |
38 | # external to be compiled
39 | ext_modules = cythonize(extensions, compiler_directives={"language_level": 3, "profile": False}),
40 | )
--------------------------------------------------------------------------------
/src/chromatinhd/__init__.py:
--------------------------------------------------------------------------------
1 | from .device import get_default_device, set_default_device
2 | from .utils import get_git_root, get_output, get_code, save, Unpickler, load
3 | from . import sparse
4 | from . import utils
5 | from . import flow
6 | from . import plot
7 | from . import data
8 | from . import train
9 | from . import embedding
10 | from . import optim
11 | from . import biomart
12 | from . import models
13 | from polyptich import grid
14 |
15 | __all__ = [
16 | "get_git_root",
17 | "get_output",
18 | "get_code",
19 | "save",
20 | "Unpickler",
21 | "load",
22 | "sparse",
23 | "utils",
24 | "flow",
25 | "data",
26 | "train",
27 | "embedding",
28 | "optim",
29 | "biomart",
30 | "models",
31 | "plot",
32 | "grid",
33 | ]
34 |
--------------------------------------------------------------------------------
/src/chromatinhd/biomart/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import Dataset
2 | from .tss import get_canonical_transcripts, get_exons, get_transcripts, map_symbols, get_genes
3 | from . import tss
4 | from .homology import get_orthologs
5 |
 6 | __all__ = ["Dataset", "get_canonical_transcripts", "get_exons", "get_transcripts", "map_symbols", "get_genes", "tss", "get_orthologs"]
7 |
--------------------------------------------------------------------------------
/src/chromatinhd/biomart/cache.py:
--------------------------------------------------------------------------------
1 | import diskcache
2 | import appdirs
3 |
4 | cache = diskcache.Cache(appdirs.user_cache_dir('biomart'))
5 |
--------------------------------------------------------------------------------
/src/chromatinhd/biomart/homology.py:
--------------------------------------------------------------------------------
1 | from .dataset import Dataset
2 | import numpy as np
3 |
4 |
5 | def get_orthologs(biomart_dataset: Dataset, gene_ids, organism="mmusculus"):
6 | """
7 | Map ensembl gene ids to orthologs in another organism
8 | """
9 |
10 | gene_ids_to_map = np.unique(gene_ids)
11 | mapping = biomart_dataset.get_batched(
12 | [
13 | biomart_dataset.attribute("ensembl_gene_id"),
14 | biomart_dataset.attribute("external_gene_name"),
15 | biomart_dataset.attribute(f"{organism}_homolog_ensembl_gene"),
16 | biomart_dataset.attribute(f"{organism}_homolog_associated_gene_name"),
17 | ],
18 | filters=[
19 | biomart_dataset.filter("ensembl_gene_id", value=gene_ids_to_map),
20 | ],
21 | )
22 | mapping = mapping.groupby("ensembl_gene_id").first()
23 |
24 | return mapping[f"{organism}_homolog_ensembl_gene"].reindex(gene_ids).values
25 |
--------------------------------------------------------------------------------
/src/chromatinhd/data/__init__.py:
--------------------------------------------------------------------------------
1 | from . import fragments
2 | from . import transcriptome
3 | from . import motifscan
4 | from .fragments import Fragments
5 | from .transcriptome import Transcriptome
6 | from .genotype import Genotype
7 | from .clustering import Clustering
8 | from .motifscan import Motifscan, Motiftrack
9 | from .regions import Regions
10 | from . import regions
11 | from . import folds
12 |
 13 | __all__ = ["Fragments", "Transcriptome", "Genotype", "Clustering", "Motifscan", "Motiftrack", "Regions", "fragments", "transcriptome", "motifscan", "regions", "folds"]
14 |
--------------------------------------------------------------------------------
/src/chromatinhd/data/associations/__init__.py:
--------------------------------------------------------------------------------
1 | from .associations import Associations
2 |
--------------------------------------------------------------------------------
/src/chromatinhd/data/associations/associations.py:
--------------------------------------------------------------------------------
1 | from chromatinhd.data.motifscan import Motifscan
2 | from chromatinhd.flow import StoredDataFrame
3 |
4 | from . import plot
5 |
6 |
7 | class Associations(Motifscan):
8 | association = StoredDataFrame()
9 |
--------------------------------------------------------------------------------
/src/chromatinhd/data/clustering/__init__.py:
--------------------------------------------------------------------------------
1 | from .clustering import Clustering
2 |
--------------------------------------------------------------------------------
/src/chromatinhd/data/clustering/clustering.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import pandas as pd
3 | import numpy as np
4 |
5 | from chromatinhd.flow import Flow, Stored, StoredDataFrame, PathLike
6 | from chromatinhd.flow.tensorstore import Tensorstore
7 |
8 |
9 | class Clustering(Flow):
10 | labels: pd.DataFrame = Stored()
11 | "Labels for each cell."
12 |
 13 |     indices: np.array = Tensorstore(dtype="<i8")
 14 |     "Cluster index for each cell."
 15 |
 16 |     var: pd.DataFrame = TSV()
 17 |     "Information for each cluster."
 18 |
 19 |     @classmethod
 20 |     def from_labels(
 21 |         cls,
 22 |         labels: pd.Series,
 23 |         var: pd.DataFrame = None,
 24 |         path: PathLike = None,
 25 |         overwrite: bool = False,
 26 |     ) -> Clustering:
27 | """
28 | Create a Clustering object from a series of labels.
29 |
30 | Parameters:
31 | labels:
32 | Series of labels for each cell, with index corresponding to cell
33 | names.
34 | path:
35 | Folder where the clustering information will be stored.
36 | overwrite:
37 | Whether to overwrite the clustering information if it already
38 | exists.
39 |
40 | Returns:
41 | Clustering object.
42 |
43 | """
44 | self = cls(path, reset=overwrite)
45 |
46 | if not overwrite and self.o.labels.exists(self):
47 | return self
48 |
49 | if not isinstance(labels, pd.Series):
50 | labels = pd.Series(labels).astype("category")
51 | elif not labels.dtype.name == "category":
52 | labels = labels.astype("category")
53 | self.labels = labels
54 | self.indices = labels.cat.codes.values
55 |
56 | if var is None:
57 | var = (
58 | pd.DataFrame(
59 | {
60 | "cluster": labels.cat.categories,
61 | "label": labels.cat.categories,
62 | }
63 | )
64 | .set_index("cluster")
65 | .loc[labels.cat.categories]
66 | )
67 | var["n_cells"] = labels.value_counts()
68 | else:
69 | var = var.reindex(labels.cat.categories)
70 | var["label"] = labels.cat.categories
71 | self.var = var
72 | return self
73 |
74 | @property
75 | def n_clusters(self):
76 | return len(self.labels.cat.categories)
77 |
78 | # temporarily link cluster_info to var
79 | @property
80 | def cluster_info(self):
81 | return self.var
82 |
83 | @cluster_info.setter
84 | def cluster_info(self, cluster_info):
85 | self.var = cluster_info
86 |
--------------------------------------------------------------------------------
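A minimal usage sketch for Clustering.from_labels above (labels and path are made up; the signature follows the reconstruction above):

    import pandas as pd
    from chromatinhd.data.clustering import Clustering

    labels = pd.Series(["B", "T", "T", "NK"], index=["cell1", "cell2", "cell3", "cell4"])
    clustering = Clustering.from_labels(labels, path="./clustering")
    print(clustering.n_clusters)      # 3
    print(clustering.var["n_cells"])  # cells per cluster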
/src/chromatinhd/data/examples/pbmc10ktiny/fragments.tsv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/src/chromatinhd/data/examples/pbmc10ktiny/fragments.tsv.gz
--------------------------------------------------------------------------------
/src/chromatinhd/data/examples/pbmc10ktiny/fragments.tsv.gz.REMOVED.git-id:
--------------------------------------------------------------------------------
1 | 9bbd83c6bcf3f8c42728621ecb8554170dc7caea
--------------------------------------------------------------------------------
/src/chromatinhd/data/examples/pbmc10ktiny/fragments.tsv.gz.tbi:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/src/chromatinhd/data/examples/pbmc10ktiny/fragments.tsv.gz.tbi
--------------------------------------------------------------------------------
/src/chromatinhd/data/examples/pbmc10ktiny/transcriptome.h5ad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/src/chromatinhd/data/examples/pbmc10ktiny/transcriptome.h5ad
--------------------------------------------------------------------------------
/src/chromatinhd/data/folds/__init__.py:
--------------------------------------------------------------------------------
1 | from .folds import Folds
2 |
--------------------------------------------------------------------------------
/src/chromatinhd/data/fragments/__init__.py:
--------------------------------------------------------------------------------
1 | from .fragments import Fragments
2 | from .view import FragmentsView
3 |
--------------------------------------------------------------------------------
/src/chromatinhd/data/genotype/__init__.py:
--------------------------------------------------------------------------------
1 | from .genotype import Genotype
2 |
--------------------------------------------------------------------------------
/src/chromatinhd/data/genotype/genotype.py:
--------------------------------------------------------------------------------
1 | from chromatinhd.flow import Flow, Stored
2 |
3 |
4 | class Genotype(Flow):
5 | genotypes = Stored()
6 | variants_info = Stored()
7 |
--------------------------------------------------------------------------------
/src/chromatinhd/data/gradient/__init__.py:
--------------------------------------------------------------------------------
1 | from .gradient import Gradient
2 |
--------------------------------------------------------------------------------
/src/chromatinhd/data/gradient/gradient.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import pandas as pd
3 | import numpy as np
4 |
5 | from chromatinhd.flow import Flow, Stored, StoredDataFrame, PathLike
6 | from chromatinhd.flow.tensorstore import Tensorstore
7 |
8 |
9 | class Gradient(Flow):
 10 |     values: np.array = Tensorstore(dtype="<f4")
 11 |     "Gradient values for each cell."
 12 |
 13 |     var: pd.DataFrame = TSV()
 14 |     "Information for each gradient."
 15 |
 16 |     @classmethod
 17 |     def from_values(cls, values: pd.Series, path: PathLike = None) -> Gradient:
18 | """
19 | Create a Gradient object from a series of values.
20 |
21 | Parameters:
22 | values:
23 | Series of values for each cell, with index corresponding to cell
24 | names.
25 | path:
26 | Folder where the gradient information will be stored.
27 |
28 | Returns:
29 | Gradient object.
30 |
31 | """
32 | gradient = cls(path)
33 | if isinstance(values, pd.Series):
34 | values = pd.DataFrame({"gradient": values}).astype(float)
35 | elif not isinstance(values, pd.DataFrame):
36 | values = pd.DataFrame(values).astype(float)
37 | gradient.values = values.values
38 | gradient.var = pd.DataFrame(
39 | {
40 | "gradient": values.columns,
41 | "label": values.columns,
42 | }
43 | ).set_index("gradient")
44 | return gradient
45 |
46 | @property
47 | def n_gradients(self):
 48 |         return self.values.shape[1]
49 |
--------------------------------------------------------------------------------
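A minimal usage sketch for Gradient.from_values above (values and path are made up):

    import pandas as pd
    from chromatinhd.data.gradient import Gradient

    values = pd.Series([0.1, 0.5, 0.9], index=["cell1", "cell2", "cell3"])
    gradient = Gradient.from_values(values, path="./gradient")
    print(gradient.n_gradients)  # 1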
/src/chromatinhd/data/motifscan/__init__.py:
--------------------------------------------------------------------------------
1 | from .motifscan import Motifscan, read_pwms
2 | from .motiftrack import Motiftrack
3 | from . import plot
4 | from .view import MotifscanView
5 | from . import download
6 | from . import plot_genome
--------------------------------------------------------------------------------
/src/chromatinhd/data/motifscan/download.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import urllib
3 | import pathlib
4 | import chromatinhd.data.motifscan
5 | import json
6 |
7 |
8 | def get_hocomoco_11(path, organism="human", variant="core", overwrite=False):
9 | """
 10 |     Download HOCOMOCO v11 motif data (human or mouse)
11 |
12 | Parameters:
13 | path:
14 | the path to download to
15 | organism:
16 | the organism to download for, either "human" or "mouse"
17 | variant:
18 | the variant to download, either "full" or "core"
19 | overwrite:
20 | whether to overwrite existing files
21 | """
22 | path = pathlib.Path(path)
23 | path.mkdir(parents=True, exist_ok=True)
24 |
25 | if organism == "human":
26 | organism = "HUMAN"
27 | elif organism == "mouse":
28 | organism = "MOUSE"
29 | else:
30 | raise ValueError(f"Unknown organism: {organism}")
31 |
32 | # download cutoffs, pwms and annotations
33 | if overwrite or (not (path / "pwm_cutoffs.txt").exists()):
34 | urllib.request.urlretrieve(
35 | f"https://hocomoco11.autosome.org/final_bundle/hocomoco11/{variant}/{organism}/mono/HOCOMOCOv11_{variant}_standard_thresholds_{organism}_mono.txt",
36 | path / "pwm_cutoffs.txt",
37 | )
38 | urllib.request.urlretrieve(
39 | f"https://hocomoco11.autosome.org/final_bundle/hocomoco11/{variant}/{organism}/mono/HOCOMOCOv11_{variant}_pwms_{organism}_mono.txt",
40 | path / "pwms.txt",
41 | )
42 | urllib.request.urlretrieve(
43 | f"https://hocomoco11.autosome.org/final_bundle/hocomoco11/{variant}/{organism}/mono/HOCOMOCOv11_{variant}_annotation_{organism}_mono.tsv",
44 | path / "annot.txt",
45 | )
46 |
47 | pwms = chromatinhd.data.motifscan.read_pwms(path / "pwms.txt")
48 |
49 | motifs = pd.DataFrame({"motif": pwms.keys()}).set_index("motif")
50 | motif_cutoffs = pd.read_table(
51 | path / "pwm_cutoffs.txt",
52 | names=["motif", "cutoff_001", "cutoff_0005", "cutoff_0001"],
53 | skiprows=1,
54 | ).set_index("motif")
55 | motifs = motifs.join(motif_cutoffs)
56 | annot = (
57 | pd.read_table(path / "annot.txt")
58 | .rename(columns={"Model": "motif", "Transcription factor": "gene_label"})
59 | .set_index("motif")
60 | )
61 | motifs = motifs.join(annot)
62 |
63 | return pwms, motifs
64 |
65 |
66 | def get_hocomoco(path, organism="hs", variant="CORE", overwrite=False):
67 | """
 68 |     Download HOCOMOCO v12 motif data (human or mouse)
69 |
70 | Parameters:
71 | path:
72 | the path to download to
73 | organism:
74 | the organism to download for, either "hs" or "mm"
75 | variant:
76 | the variant to download, either "INVIVO" or "CORE"
77 | overwrite:
78 | whether to overwrite existing files
79 | """
80 | path = pathlib.Path(path)
81 | path.mkdir(parents=True, exist_ok=True)
82 |
83 | # download cutoffs, pwms and annotations
84 | if overwrite or (not (path / "pwms.tar.gz").exists()):
85 | urllib.request.urlretrieve(
86 | f"https://hocomoco12.autosome.org/final_bundle/hocomoco12/H12{variant}/H12{variant}_annotation.jsonl",
87 | path / "annotation.jsonl",
88 | )
89 | urllib.request.urlretrieve(
90 | f"https://hocomoco12.autosome.org/final_bundle/hocomoco12/H12{variant}/H12{variant}_pwm.tar.gz",
91 | path / "pwms.tar.gz",
92 | )
93 | urllib.request.urlretrieve(
94 | f"https://hocomoco12.autosome.org/final_bundle/hocomoco12/H12{variant}/H12{variant}_thresholds.tar.gz",
95 | path / "thresholds.tar.gz",
96 | )
97 |
98 | pwms = chromatinhd.data.motifscan.read_pwms(path / "pwms.tar.gz")
99 | motifs = [json.loads(line) for line in open(path / "annotation.jsonl").readlines()]
100 | motifs = pd.DataFrame(motifs).set_index("name")
101 | motifs.index.name = "motif"
102 |
103 | for thresh in motifs["standard_thresholds"].iloc[0].keys():
104 | motifs["cutoff_" + thresh] = [thresholds[thresh] for _, thresholds in motifs["standard_thresholds"].items()]
105 | for species in ["HUMAN", "MOUSE"]:
106 | motifs[species + "_gene_symbol"] = [
107 | masterlist_info["species"][species]["gene_symbol"] if species in masterlist_info["species"] else None
108 | for _, masterlist_info in motifs["masterlist_info"].items()
109 | ]
110 | motifs["symbol"] = motifs["HUMAN_gene_symbol"] if organism == "hs" else motifs["MOUSE_gene_symbol"]
111 |
112 | return pwms, motifs
113 |
--------------------------------------------------------------------------------
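A minimal usage sketch for the HOCOMOCO downloaders above (local path assumed; files are fetched on the first call only):

    from chromatinhd.data.motifscan.download import get_hocomoco

    pwms, motifs = get_hocomoco("./hocomoco12", organism="hs", variant="CORE")
    print(len(pwms), motifs.shape)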
/src/chromatinhd/data/motifscan/motifcount.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 | from .motifscan import Motifscan
3 | from .view import MotifscanView
4 | import numpy as np
5 | import tqdm.auto as tqdm
6 |
7 |
8 | class BinnedMotifCounts:
9 | """
10 | Provides binned motif counts per fragment
11 | """
12 |
13 | def __init__(
14 | self,
15 | motifscan: Motifscan,
16 | motif_binsizes: List[int],
17 | fragment_binsizes: List[int],
18 | ):
19 | self.motifscan = motifscan
20 | self.width = motifscan.regions.width
21 | assert self.width is not None, "Regions must have a width"
22 | self.motif_binsizes = motif_binsizes
23 | self.fragment_binsizes = fragment_binsizes
24 | self.n_motifs = motifscan.motifs.shape[0]
25 |
26 | n_genes = len(motifscan.regions.coordinates)
27 |
28 | self.fragment_widths = []
29 | self.motifcount_sizes = []
30 | self.motif_widths = []
31 | self.fragmentprob_sizes = []
 32 |         width = self.width
 33 |         for motif_binsize, fragment_binsize in zip(self.motif_binsizes, self.fragment_binsizes):
 34 |             assert fragment_binsize % motif_binsize == 0, (
 35 |                 "fragment_binsize must be a multiple of motif_binsize",
 36 |                 motif_binsize,
 37 |                 fragment_binsize,
 38 |             )
 39 |             self.motifcount_sizes.append(width // motif_binsize)
 40 |             self.motif_widths.append(self.width // motif_binsize)
 41 |             self.fragmentprob_sizes.append(width // fragment_binsize)
 42 |             self.fragment_widths.append(self.width // fragment_binsize)
43 | width = fragment_binsize
44 |
45 | precomputed = []
46 | for motif_binsize, motif_width, fragment_binsize, fragment_width in zip(
47 | self.motif_binsizes,
48 | self.motif_widths,
49 | self.fragment_binsizes,
50 | [1, *self.fragment_widths[:-1]],
51 | ):
52 | precomputed.append(
53 | np.bincount(motifscan.positions // motif_binsize, minlength=(n_genes * motif_width)).reshape(
54 | (n_genes * fragment_width, -1)
55 | )
56 | )
57 | self.precomputed = precomputed
58 |
59 |
 60 | class BinnedMotifCounts:  # note: redefines the class above, which is shadowed and unused
61 | """
62 | Provides binned motif counts per fragment
63 | """
64 |
65 | def __init__(
66 | self,
67 | motifscan: Union[Motifscan, MotifscanView],
68 | binsize: int,
69 | ):
70 | self.motifscan = motifscan
71 | self.width = motifscan.regions.width
72 | assert self.width is not None, "Regions must have a width"
73 | self.binsize = binsize
74 | self.n_motifs = motifscan.motifs.shape[0]
75 |
76 | n_regions = motifscan.regions.n_regions
77 | motif_width = motifscan.regions.width // binsize
78 |
79 | precomputed = np.zeros((n_regions, motif_width), dtype=np.int32)
80 |
81 | for region_ix in tqdm.tqdm(range(n_regions)):
82 | if isinstance(motifscan, Motifscan):
83 | indptr_start = motifscan.region_indptr[region_ix]
84 | indptr_end = motifscan.region_indptr[region_ix + 1]
85 | coordinates = motifscan.coordinates[indptr_start:indptr_end]
86 | else:
87 | indptr_start, indptr_end = motifscan.region_indptr[region_ix]
88 | coordinates = (
89 | motifscan.coordinates[indptr_start:indptr_end] - motifscan.regions.region_starts[region_ix]
90 | )
91 | precomputed[region_ix] = np.bincount(
92 | coordinates // binsize,
93 | minlength=(motif_width),
94 | )
95 | self.precomputed = precomputed
96 |
97 | # self.fragment_binsizes = fragment_binsizes
98 | # self.n_motifs = motifscan.motifs.shape[0]
99 |
100 | # self.fragment_widths = []
101 | # self.fragmentprob_sizes = []
102 | # width = self.region_width
103 | # for fragment_binsize in zip(self.fragment_binsizes):
104 | # self.fragmentprob_sizes.append(width // fragment_binsize)
105 | # self.fragment_widths.append(self.region_width // fragment_binsize)
106 | # width = fragment_binsize
107 |
--------------------------------------------------------------------------------
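The core binning trick used by both BinnedMotifCounts variants above, in isolation (hypothetical positions): integer-divide positions by the bin size and count per bin with np.bincount:

    import numpy as np

    positions = np.array([5, 12, 13, 40])  # motif positions within one region
    binsize, width = 10, 50
    counts = np.bincount(positions // binsize, minlength=width // binsize)
    print(counts)  # [1 2 0 0 1]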
/src/chromatinhd/data/motifscan/motiftrack.py:
--------------------------------------------------------------------------------
1 | from chromatinhd.flow import Flow, CompressedNumpyFloat64
2 | import pandas as pd
3 |
4 |
5 | class Motiftrack(Flow):
6 | scores = CompressedNumpyFloat64("scores")
7 |
8 | _motifs = None
9 |
10 | @property
11 | def motifs(self):
12 | if self._motifs is None:
13 | self._motifs = pd.read_pickle(self.path / "motifs.pkl")
14 | return self._motifs
15 |
16 | @motifs.setter
17 | def motifs(self, value):
18 | value.index.name = "gene"
19 | value.to_pickle(self.path / "motifs.pkl")
20 | self._motifs = value
21 |
--------------------------------------------------------------------------------
/src/chromatinhd/data/motifscan/scan_helpers.pyx:
--------------------------------------------------------------------------------
1 | #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
2 | #cython: language_level=3
3 | cimport cython
4 | import numpy as np
5 | cimport numpy as np
6 | np.import_array()
7 |
8 | INT8 = np.int8
9 | ctypedef np.int8_t INT8_t
10 |
11 | @cython.boundscheck(False) # turn off bounds-checking for entire function
12 | @cython.wraparound(False) # turn off negative index wrapping for entire function
13 | @cython.cdivision
14 | def seq_to_onehot(bytes s, INT8_t [:,::1] out_onehot):
15 | cdef char* cstr = s
16 | cdef int i, length = len(s)
17 |
18 | for i in range(length):
19 | if cstr[i] == b'A':
20 | out_onehot[i, 0] = 1
21 | elif cstr[i] == b'C':
22 | out_onehot[i, 1] = 1
23 | elif cstr[i] == b'G':
24 | out_onehot[i, 2] = 1
25 | elif cstr[i] == b'T':
26 | out_onehot[i, 3] = 1
27 |
28 | return out_onehot
--------------------------------------------------------------------------------
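A plain-NumPy reference for seq_to_onehot above (same semantics, slower; useful for testing the compiled version):

    import numpy as np

    def seq_to_onehot_np(s: bytes) -> np.ndarray:
        lookup = {ord("A"): 0, ord("C"): 1, ord("G"): 2, ord("T"): 3}
        out = np.zeros((len(s), 4), dtype=np.int8)
        for i, c in enumerate(s):  # iterating over bytes yields ints
            if c in lookup:
                out[i, lookup[c]] = 1  # unknown bases (e.g. N) stay all-zero
        return out

    assert seq_to_onehot_np(b"ACGN")[3].sum() == 0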
/src/chromatinhd/data/peakcounts/__init__.py:
--------------------------------------------------------------------------------
1 | from .peakcounts import PeakCounts, Windows
2 | from . import plot
3 |
--------------------------------------------------------------------------------
/src/chromatinhd/data/transcriptome/__init__.py:
--------------------------------------------------------------------------------
1 | from .transcriptome import Transcriptome
--------------------------------------------------------------------------------
/src/chromatinhd/data/transcriptome/timetranscriptome.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import pathlib
6 | from typing import Union
7 |
8 | from chromatinhd.flow import Flow, Stored, StoredDict, TSV
9 | from chromatinhd.flow.tensorstore import Tensorstore
10 | from chromatinhd import sparse
11 | from typing import TYPE_CHECKING
12 |
13 | if TYPE_CHECKING:
14 | import scanpy as sc
15 |
16 |
17 | class TimeTranscriptome(Flow):
18 | """
19 | A transcriptome during pseudotime
20 | """
21 |
22 | var: pd.DataFrame = TSV(index_name="gene")
23 | obs: pd.DataFrame = TSV(index_name="pseudocell")
24 |
25 | @classmethod
26 | def from_transcriptome(
27 | cls,
28 | gradient,
29 | transcriptome,
30 | path: Union[pathlib.Path, str] = None,
31 | overwrite=False,
32 | ):
33 | """
34 | Create a TimeTranscriptome object from a Transcriptome object.
35 | """
36 |
37 | raise NotImplementedError()
38 |
 39 |     layers = StoredDict(Tensorstore, kwargs=dict(dtype="<f4"))
--------------------------------------------------------------------------------
/src/chromatinhd/embedding.py:
--------------------------------------------------------------------------------
 37 |     def extra_repr(self) -> str:
38 | s = "{num_embeddings}, {embedding_dims}"
39 | if self.padding_idx is not None:
40 | s += ", padding_idx={padding_idx}"
41 | if self.max_norm is not None:
42 | s += ", max_norm={max_norm}"
43 | if self.norm_type != 2:
44 | s += ", norm_type={norm_type}"
45 | if self.scale_grad_by_freq is not False:
46 | s += ", scale_grad_by_freq={scale_grad_by_freq}"
47 | if self.sparse is not False:
48 | s += ", sparse=True"
49 | return s.format(**self.__dict__)
50 |
51 | def get_full_weight(self):
52 | return self.weight.view((self.weight.shape[0], *self.embedding_dims))
53 |
54 | @property
55 | def data(self):
56 | """
57 | The data of the parameter in dimensions [num_embeddings, *embedding_dims]
58 | """
59 | return self.get_full_weight().data
60 |
61 | @data.setter
62 | def data(self, value):
63 | if value.ndim == 2:
64 | self.weight.data = value
65 | else:
66 | self.weight.data = value.reshape(self.weight.data.shape)
67 |
68 | @property
69 | def shape(self):
70 | """
71 | The shape of the parameter, i.e. [num_embeddings, *embedding_dims]
72 | """
73 | return (self.weight.shape[0], *self.embedding_dims)
74 |
75 | def __getitem__(self, k):
76 | return self.forward(k)
77 |
78 | @classmethod
79 | def from_pretrained(cls, pretrained: EmbeddingTensor):
80 | self = cls(pretrained.num_embeddings, pretrained.embedding_dims)
81 | self.data = pretrained.data
82 | self.weight.requires_grad = False
83 |
84 | return self
85 |
86 |
87 | class FeatureParameter(torch.nn.Module):
88 | _params = tuple()
89 |
90 | def __init__(self, num_embeddings, embedding_dims, constructor=torch.zeros, *args, **kwargs):
91 | super().__init__()
92 | params = []
93 | for i in range(num_embeddings):
94 | params.append(torch.nn.Parameter(constructor(embedding_dims, *args, **kwargs)))
95 | self.register_parameter(str(i), params[-1])
96 | self._params = tuple(params)
97 |
98 | def __getitem__(self, k):
99 | return self._params[k]
100 |
101 | def __setitem__(self, k, v):
102 | self._params = list(self._params)
103 | self._params[k] = v
104 | self._params = tuple(self._params)
105 |
106 | def __call__(self, ks):
107 | return torch.stack([self._params[k] for k in ks], 0)
108 |
--------------------------------------------------------------------------------
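A minimal sketch of EmbeddingTensor above (the positional constructor signature is inferred from from_pretrained; treat it as an assumption):

    import torch

    emb = EmbeddingTensor(10, (4, 3))   # 10 embeddings, each a (4, 3) tensor
    print(emb.shape)                    # (10, 4, 3), per the `shape` property
    emb.data = torch.zeros(10, 4, 3)    # the setter reshapes to the flat weight
    print(emb.get_full_weight().shape)  # torch.Size([10, 4, 3])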
/src/chromatinhd/flow/__init__.py:
--------------------------------------------------------------------------------
1 | from .objects import (
2 | Linked,
3 | Stored,
4 | StoredDataFrame,
5 | StoredTensor,
6 | StoredNumpyInt64,
7 | CompressedNumpy,
8 | CompressedNumpyFloat64,
9 | CompressedNumpyInt64,
10 | TSV,
11 | StoredDict,
12 | DataArray,
13 | Dataset,
14 | )
15 | from .flow import (
16 | Flow,
17 | PathLike,
18 | )
19 | from . import tensorstore
20 | from .linked import LinkedDict
21 | from .sparse import SparseDataset
22 |
--------------------------------------------------------------------------------
/src/chromatinhd/flow/flow_template.jinja2:
--------------------------------------------------------------------------------
1 |
2 |
37 |
38 |
39 |
40 | {{ html }}
41 |
--------------------------------------------------------------------------------
/src/chromatinhd/flow/linked.py:
--------------------------------------------------------------------------------
1 | from .objects import Obj, Instance, Linked
2 |
3 |
4 | class LinkedDict(Obj):
5 | def __init__(self, cls=Linked, name=None, kwargs=None):
6 | super().__init__(name=name)
7 | if kwargs is None:
8 | kwargs = {}
9 | self.kwargs = kwargs
10 | self.cls = cls
11 |
12 | def get_path(self, folder):
13 | return folder / self.name
14 |
15 | def __get__(self, obj, type=None):
16 | if obj is not None:
17 | name = "_" + str(self.name)
18 | if not hasattr(obj, name):
19 | x = LinkedDictInstance(self.name, self.get_path(obj.path), self.cls, obj, self.kwargs)
20 | setattr(obj, name, x)
21 | return getattr(obj, name)
22 |
23 | def __set__(self, obj, value, folder=None):
24 | instance = self.__get__(obj)
25 | instance.__set__(obj, value)
26 |
27 | def _repr_html_(self, obj=None):
28 | instance = self.__get__(obj)
29 | return instance._repr_html_()
30 |
31 | def exists(self, obj):
32 | return True
33 |
34 |
35 | class LinkedDictInstance(Instance):
36 | def __init__(self, name, path, cls, obj, kwargs):
37 | super().__init__(name=name, path=path, obj=obj)
38 | self.dict = {}
39 | self.cls = cls
40 | self.obj = obj
41 | self.name = name
42 | self.path = path
43 | self.kwargs = kwargs
44 | if not self.path.exists():
45 | self.path.mkdir(parents=True)
46 | for file in self.path.iterdir():
47 | key = file.name
48 | self.dict[key] = self.cls(name=key, **self.kwargs)
49 |
50 | def __getitem__(self, key):
51 | return self.dict[key].__get__(self)
52 |
53 | def __setitem__(self, key, value):
54 | if key not in self.dict:
55 | self.dict[key] = self.cls(name=key, **self.kwargs)
56 | self.dict[key].__set__(self, value)
57 |
58 | def __set__(self, obj, value, folder=None):
59 | for k, v in value.items():
60 | self[k] = v
61 |
62 | def __contains__(self, key):
63 | return key in self.dict
64 |
65 | def __len__(self):
66 | return len(self.dict)
67 |
68 | def items(self):
69 | for k in self.dict:
70 | yield k, self[k]
71 |
72 | def keys(self):
73 | return self.dict.keys()
74 |
75 | def exists(self):
76 | return True
77 |
78 | def _repr_html_(self):
79 | # return f" {self.name} ({', '.join([getattr(self.obj, k)._repr_html_() for k in self.keys()])})"
80 | items = []
81 | for i, k in zip(range(3), self.keys()):
82 | items.append(self.dict[k]._repr_html_(self))
83 | if len(self.keys()) > 3:
84 | items.append("...")
85 | return f" {self.name} ({', '.join(items)})"
86 |
--------------------------------------------------------------------------------
/src/chromatinhd/flow/sparse.py:
--------------------------------------------------------------------------------
1 | from .objects import Obj, format_size, get_size
2 | from chromatinhd.sparse import SparseDataset as SparseDataset_
3 |
4 |
5 | class SparseDataset(Obj):
6 | def __init__(self, name=None):
7 | super().__init__(name=name)
8 |
9 | def get_path(self, folder):
10 | return folder / (self.name)
11 |
12 | def __get__(self, obj, type=None):
13 | if obj is not None:
14 | if self.name is None:
15 | raise ValueError(obj)
16 | name = "_" + str(self.name)
17 | if not hasattr(obj, name):
18 | path = self.get_path(obj.path)
19 | if not path.exists():
20 | raise FileNotFoundError(f"File {path} does not exist")
21 | setattr(obj, name, SparseDataset_.open(self.get_path(obj.path)))
22 | return getattr(obj, name)
23 |
24 | def __set__(self, obj, value):
25 | name = "_" + str(self.name)
26 | setattr(obj, name, value)
27 |
28 | def exists(self, obj):
29 | return self.get_path(obj.path).exists()
30 |
31 | def _repr_html_(self, obj):
32 | self.__get__(obj)
33 | if not str(self.get_path(obj.path)).startswith("memory"):
34 | size = format_size(get_size(self.get_path(obj.path)))
35 | else:
36 | size = ""
37 | return f" {self.name} {size}"
38 |
--------------------------------------------------------------------------------
/src/chromatinhd/loaders/__init__.py:
--------------------------------------------------------------------------------
1 | from .pool import LoaderPool
2 | from . import minibatches
3 |
4 | from .fragments import Fragments, FragmentsRegional, Cuts
5 | from .transcriptome import Transcriptome
6 | from .transcriptome_fragments import TranscriptomeFragments
7 | from .clustering_fragments import ClusteringCuts
8 |
--------------------------------------------------------------------------------
/src/chromatinhd/loaders/clustering.py:
--------------------------------------------------------------------------------
1 | import chromatinhd.data.clustering
2 | import dataclasses
3 | import torch
4 |
5 |
6 | @dataclasses.dataclass
7 | class Result:
8 | # onehot: torch.Tensor
9 | indices: torch.Tensor
10 |
11 | def to(self, device):
12 | self.indices = self.indices.to(device)
13 | # self.onehot = self.onehot.to(device)
14 | return self
15 |
16 |
17 | class Clustering:
18 | """
19 | Provides clustering data for a minibatch.
20 | """
21 |
22 | def __init__(
23 | self,
24 | clustering: chromatinhd.data.clustering.Clustering,
25 | ):
26 | assert (clustering.labels.cat.categories == clustering.cluster_info.index).all(), (
27 | clustering.labels.cat.categories,
28 | clustering.cluster_info.index,
29 | )
30 | self.onehot = torch.nn.functional.one_hot(
31 | torch.from_numpy(clustering.labels.cat.codes.values.copy()).to(torch.int64),
32 | clustering.n_clusters,
33 | ).to(torch.float)
34 |
35 | def load(self, minibatch):
36 | # onehot = self.onehot[minibatch.cells_oi, :]
37 | indices = torch.argmax(self.onehot[minibatch.cells_oi, :], dim=1)
38 | return Result(indices=indices)
39 |
--------------------------------------------------------------------------------
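What the Clustering loader's load() does, in miniature: one-hot the labels once at construction, then recover per-minibatch cluster indices with argmax:

    import torch

    codes = torch.tensor([0, 2, 1, 2])                      # per-cell cluster codes
    onehot = torch.nn.functional.one_hot(codes, 3).float()  # built once
    cells_oi = [1, 3]                                       # a minibatch of cells
    print(torch.argmax(onehot[cells_oi], dim=1))            # tensor([2, 2])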
/src/chromatinhd/loaders/clustering_fragments.py:
--------------------------------------------------------------------------------
1 | import chromatinhd.data.clustering
2 | import dataclasses
3 |
4 | from .fragments import Cuts, CutsRegional, CutsResult
5 | from chromatinhd.loaders.minibatches import Minibatch
6 | from .clustering import Clustering
7 |
8 |
9 | @dataclasses.dataclass
10 | class Result:
11 | clustering: Clustering
12 | cuts: CutsResult
13 | minibatch: Minibatch
14 |
15 | def to(self, device):
16 | self.clustering.to(device)
17 | self.cuts.to(device)
18 | self.minibatch.to(device)
19 | return self
20 |
21 |
22 | class ClusteringCuts:
23 | def __init__(
24 | self,
25 | fragments: chromatinhd.data.fragments.Fragments,
26 | clustering: chromatinhd.data.clustering.Clustering,
27 | cellxregion_batch_size: int,
28 | layer: str = None,
29 | region_oi=None,
30 | ):
31 | # ensure that clustering and fragments have the same obs
32 | # if not all(clustering.obs.index == fragments.obs.index):
33 | # raise ValueError("Clustering and fragments should have the same obs index. ")
34 |
35 | if region_oi is None:
36 | self.cuts = Cuts(fragments, cellxregion_batch_size=cellxregion_batch_size)
37 | self.clustering = Clustering(clustering)
38 | else:
39 | self.cuts = CutsRegional(fragments, cellxregion_batch_size=cellxregion_batch_size, region_oi=region_oi)
40 | self.clustering = Clustering(clustering)
41 |
42 | def load(self, minibatch):
43 | return Result(
44 | clustering=self.clustering.load(minibatch),
45 | cuts=self.cuts.load(minibatch),
46 | minibatch=minibatch,
47 | )
48 |
--------------------------------------------------------------------------------
/src/chromatinhd/loaders/fragments_helpers.pyx:
--------------------------------------------------------------------------------
1 | #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
2 | #cython: language_level=3
3 | cimport cython
4 | import numpy as np
5 | cimport numpy as np
6 | np.import_array()
7 |
8 | INT64 = np.int64
9 | ctypedef np.int64_t INT64_t
10 | FLOAT64 = np.float64
11 | ctypedef np.float64_t FLOAT64_t
12 |
13 | @cython.boundscheck(False) # turn off bounds-checking for entire function
14 | @cython.wraparound(False) # turn off negative index wrapping for entire function
15 | @cython.cdivision
16 | def extract_fragments(
17 | INT64_t [::1] cellxgene_oi,
18 | INT64_t [::1] cellxgene_indptr,
19 | INT64_t [:,::1] coordinates,
20 | INT64_t [::1] genemapping,
21 | INT64_t [:,::1] out_coordinates,
22 | INT64_t [::1] out_genemapping,
23 | INT64_t [::1] out_local_cellxgene_ix,
24 | ):
25 | cdef INT64_t out_ix, local_cellxgene_ix, cellxgene_ix, position
26 | out_ix = 0 # will store where in the output array we are currently
27 | local_cellxgene_ix = 0 # will store the current fragment counting from 0
28 |
29 | with nogil:
30 | for local_cellxgene_ix in range(cellxgene_oi.shape[0]):
31 | cellxgene_ix = cellxgene_oi[local_cellxgene_ix]
32 | for position in range(cellxgene_indptr[cellxgene_ix], cellxgene_indptr[cellxgene_ix+1]):
33 | out_coordinates[out_ix] = coordinates[position]
34 | out_genemapping[out_ix] = genemapping[position]
35 | out_local_cellxgene_ix[out_ix] = local_cellxgene_ix
36 |
37 | out_ix += 1
38 |
39 | return out_ix
40 |
41 |
42 | @cython.boundscheck(False) # turn off bounds-checking for entire function
43 | @cython.wraparound(False) # turn off negative index wrapping for entire function
44 | @cython.cdivision
45 | def multiple_arange(
46 | INT64_t [::1] a,
47 | INT64_t [::1] b,
48 | INT64_t [::1] ixs,
49 | INT64_t [::1] local_cellxregion_ix,
50 | ):
51 | cdef INT64_t out_ix, pair_ix, position
52 | out_ix = 0 # will store where in the output array we are currently
53 | pair_ix = 0 # will store the current a, b pair index
54 |
55 | with nogil:
56 | for pair_ix in range(a.shape[0]):
57 | for position in range(a[pair_ix], b[pair_ix]):
58 | ixs[out_ix] = position
59 | local_cellxregion_ix[out_ix] = pair_ix
60 | out_ix += 1
61 | pair_ix += 1
62 |
63 | return out_ix
64 |
 65 |
 66 | @cython.boundscheck(False) # turn off bounds-checking for entire function
 67 | @cython.wraparound(False) # turn off negative index wrapping for entire function
 68 | @cython.cdivision
 69 | def count_A_cython(bytes s, INT64_t [:,::1] out_onehot):
 70 |     cdef char* cstr = s
 71 |     cdef int i, length = len(s)
 72 |
 73 |     for i in range(length):
 74 |         if cstr[i] == b'A':
 75 |             out_onehot[i, 0] = 1
 76 |         elif cstr[i] == b'C':
 77 |             out_onehot[i, 1] = 1
 78 |         elif cstr[i] == b'G':
 79 |             out_onehot[i, 2] = 1
 80 |         elif cstr[i] == b'T':
 81 |             out_onehot[i, 3] = 1
 82 |         else:
 83 |             out_onehot[i, 4] = 1
 84 |
 85 |     return out_onehot
--------------------------------------------------------------------------------
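A plain-NumPy reference for multiple_arange above (the compiled version fills preallocated output buffers and returns the count; this reference returns fresh arrays):

    import numpy as np

    def multiple_arange_np(a: np.ndarray, b: np.ndarray):
        ixs = np.concatenate([np.arange(ai, bi) for ai, bi in zip(a, b)])
        local = np.repeat(np.arange(len(a)), b - a)  # which (a, b) pair each index came from
        return ixs, local

    ixs, local = multiple_arange_np(np.array([0, 5]), np.array([2, 7]))
    # ixs   -> [0 1 5 6]
    # local -> [0 0 1 1]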
/src/chromatinhd/loaders/peakcounts.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import dataclasses
4 |
5 |
6 | @dataclasses.dataclass
7 | class Result:
8 | counts: np.ndarray
9 | cells_oi: np.ndarray
10 | genes_oi: np.ndarray
11 |
12 | @property
13 | def n_cells(self):
14 | return len(self.cells_oi)
15 |
16 | @property
17 | def n_genes(self):
18 | return len(self.genes_oi)
19 |
20 | def to(self, device):
21 | for field_name, field in self.__dataclass_fields__.items():
22 | if field.type is torch.Tensor:
23 | self.__setattr__(field_name, self.__getattribute__(field_name).to(device))
24 | return self
25 |
26 | @property
27 | def genes_oi_torch(self):
28 | return torch.from_numpy(self.genes_oi).to(self.coordinates.device)
29 |
30 | @property
31 | def cells_oi_torch(self):
32 | return torch.from_numpy(self.cells_oi).to(self.coordinates.device)
33 |
34 |
35 | class PeakcountsResult(Result):
36 | pass
37 |
38 |
39 | class Peakcounts:
40 | def __init__(self, fragments, peakcounts):
41 | self.peakcounts = peakcounts
42 | assert "gene_ix" in peakcounts.peaks.columns
43 | var = peakcounts.var
44 | var["ix"] = np.arange(peakcounts.var.shape[0])
45 | peakcounts.var = var
46 | assert "ix" in peakcounts.var.columns
47 |
48 | assert peakcounts.counts.shape[1] == peakcounts.var.shape[0]
49 |
50 | self.gene_peak_mapping = []
51 | cur_gene_ix = -1
52 | for peak_ix, gene_ix in zip(peakcounts.var["ix"][peakcounts.peaks.index].values, peakcounts.peaks["gene_ix"]):
53 | while gene_ix != cur_gene_ix:
54 | self.gene_peak_mapping.append([])
55 | cur_gene_ix += 1
56 | self.gene_peak_mapping[-1].append(peak_ix)
57 |
58 | def load(self, minibatch):
59 | peak_ixs = np.concatenate([self.gene_peak_mapping[gene_ix] for gene_ix in minibatch.genes_oi])
60 | counts = self.peakcounts.counts[minibatch.cells_oi, :][:, peak_ixs]
61 |
62 | return PeakcountsResult(counts=counts, **minibatch.items())
63 |
--------------------------------------------------------------------------------
/src/chromatinhd/loaders/pool.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import threading
3 | import numpy as np
4 | from typing import List
5 | import time
6 |
7 |
8 | class ThreadWithResult(threading.Thread):
9 | result = None
10 |
11 | def __init__(
12 | self,
13 | group=None,
14 | target=None,
15 | name=None,
16 | loader=None,
17 | args=(),
18 | kwargs={},
19 | *,
20 | daemon=None,
21 | ):
22 | assert target is not None
23 | self.loader = loader
24 |
25 | def function():
26 | self.result = target(*args, **kwargs)
27 |
28 | super().__init__(group=group, target=function, name=name, daemon=daemon)
29 |
30 |
31 | def benchmark(loaders, minibatcher, n=100):
32 | """
33 | Benchmarks a pool of loaders
34 | """
35 |
36 | loaders.initialize(minibatcher)
37 | waits = []
38 | import time
39 | import tqdm.auto as tqdm
40 |
41 | start = time.time()
42 | for i, data in zip(range(n), tqdm.tqdm(loaders)):
43 | waits.append(time.time() - start)
44 | start = time.time()
45 | print(sum(waits))
46 |
47 | import matplotlib.pyplot as plt
48 |
49 | plt.plot(waits)
50 |
51 |
52 | class LoaderPool:
53 | loaders_running: list
54 | loaders_available: list
 55 |     counter: int = 0
56 |
57 | def __init__(
58 | self,
59 | loader_cls,
60 | loader_kwargs=None,
61 | n_workers=3,
62 | loader=None,
63 | ):
64 | self.loaders_running = []
65 |
66 | if loader_kwargs is None:
67 | loader_kwargs = {}
68 | self.loader_cls = loader_cls
69 | self.loader_kwargs = loader_kwargs
70 |
71 | self.n_workers = n_workers
72 |
73 | if loader is not None:
74 | self.loaders = [loader.copy() for i in range(n_workers)]
75 | else:
76 | self.loaders = [loader_cls(**loader_kwargs) for i in range(n_workers)]
77 | for loader in self.loaders:
78 | loader.running = False
79 | self.wait = []
80 |
81 | def initialize(self, tasker, *args, **kwargs):
82 | self.tasker = tasker
83 |
84 | self.args = args
85 | self.kwargs = kwargs
86 |
 87 |         self.start()
88 |
89 | def start(self):
90 | # join all still running threads
91 | for thread in self.loaders_running:
92 | thread.join()
93 |
94 | for loader in self.loaders:
95 | loader.running = False
96 |
97 | self.loaders_available = copy.copy(self.loaders)
98 | self.loaders_running = []
99 |
100 | self.wait = []
101 |
102 | self.tasker_iter = iter(self.tasker)
103 |
104 | for i in range(min(len(self.tasker), self.n_workers - 1)):
105 | self.submit_next()
106 |
107 | def __iter__(self):
108 | self.counter = 0
109 | return self
110 |
111 | def __len__(self):
112 | return len(self.tasker)
113 |
114 | def __next__(self):
115 | self.counter += 1
116 | if self.counter > len(self.tasker):
117 | raise StopIteration
118 | result = self.pull()
119 | self.submit_next()
120 | return result
121 |
122 | def submit_next(self):
123 | try:
124 | task = next(self.tasker_iter)
125 | except StopIteration:
126 | self.tasker_iter = iter(self.tasker)
127 | task = next(self.tasker_iter)
128 | self.submit(task, *self.args, **self.kwargs)
129 |
130 | def submit(self, *args, **kwargs):
131 | if self.loaders_available is None:
132 | raise ValueError("Pool was not initialized")
133 | if len(self.loaders_available) == 0:
134 | raise ValueError("No loaders available")
135 |
136 | loader = self.loaders_available.pop(0)
137 | if loader.running:
138 | raise ValueError
139 | loader.running = True
140 | thread = ThreadWithResult(target=loader.load, loader=loader, args=args, kwargs=kwargs)
141 | self.loaders_running.append(thread)
142 | thread.start()
143 |
144 | def pull(self):
145 | thread = self.loaders_running.pop(0)
146 |
147 | start = time.time()
148 | thread.join()
149 | wait = time.time() - start
150 | self.wait.append(wait)
151 | thread.loader.running = False
152 | result = thread.result
153 | self.loaders_available.append(thread.loader)
154 | return result
155 |
156 | def terminate(self):
157 | for thread in self.loaders_running:
158 | thread.join()
159 | self.loaders_running = []
160 | self.loaders = None
161 |
162 | def __del__(self):
163 | self.terminate()
164 |
--------------------------------------------------------------------------------
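A minimal usage sketch for LoaderPool above (EchoLoader is hypothetical; any object exposing .load(task) works, since the pool only sets a .running flag and calls .load in a background thread):

    from chromatinhd.loaders.pool import LoaderPool

    class EchoLoader:
        def load(self, task):
            return task

    pool = LoaderPool(EchoLoader, n_workers=2)
    pool.initialize(tasker=[1, 2, 3])  # anything with len() and iter()
    for result in pool:
        print(result)                  # 1, 2, 3, loaded ahead of time in threads
    pool.terminate()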
/src/chromatinhd/loaders/transcriptome.py:
--------------------------------------------------------------------------------
1 | import chromatinhd.data.transcriptome
2 | import dataclasses
3 | import torch
4 | import chromatinhd.sparse
5 | from chromatinhd.flow.tensorstore import TensorstoreInstance
6 | import numpy as np
7 |
8 |
9 | @dataclasses.dataclass
10 | class Result:
11 | value: torch.Tensor
12 |
13 | def to(self, device):
14 | self.value = self.value.to(device)
15 | return self
16 |
17 |
18 | class Transcriptome:
19 | def __init__(
20 | self,
21 | transcriptome: chromatinhd.data.transcriptome.Transcriptome,
22 | layer: str = None,
23 | ):
24 | if layer is None:
25 | layer = list(transcriptome.layers.keys())[0]
26 |
27 | X = transcriptome.layers[layer]
28 | if chromatinhd.sparse.is_sparse(X):
29 | self.X = X.dense()
30 | elif torch.is_tensor(X):
31 | self.X = X.numpy()
32 | elif isinstance(X, TensorstoreInstance):
33 | # self.X = X
34 | self.X = X.oindex # open a tensorstore reader with orthogonal indexing
35 | else:
36 | self.X = X
37 |
38 | def load(self, minibatch):
39 | X = torch.from_numpy(self.X[minibatch.cells_oi, minibatch.genes_oi].astype(np.float32))
40 | if X.ndim == 1:
41 | X = X.unsqueeze(1)
42 | return Result(value=X)
43 |
44 |
45 | class TranscriptomeGene:
46 | def __init__(
47 | self,
48 | transcriptome: chromatinhd.data.transcriptome.Transcriptome,
49 | gene_oi,
50 | layer: str = None,
51 | ):
52 | if layer is None:
53 | layer = list(transcriptome.layers.keys())[0]
54 |
55 | gene_ix = transcriptome.var.index.get_loc(gene_oi)
56 |
57 | X = transcriptome.layers[layer]
58 | if chromatinhd.sparse.is_sparse(X):
59 | self.X = X[:, gene_ix].dense()[:, 0]
60 | elif torch.is_tensor(X):
61 | self.X = X[:, gene_ix].numpy()
62 | elif isinstance(X, TensorstoreInstance):
63 | # self.X = X
64 | self.X = X.oindex[:, gene_ix] # open a tensorstore reader with orthogonal indexing
65 | else:
66 | self.X = X[:, gene_ix]
67 |
68 | def load(self, minibatch):
69 | X = torch.from_numpy(self.X[minibatch.cells_oi].astype(np.float32))
70 | if X.ndim == 1:
71 | X = X.unsqueeze(1)
72 | return Result(value=X)
73 |
--------------------------------------------------------------------------------
/src/chromatinhd/loaders/transcriptome_fragments.py:
--------------------------------------------------------------------------------
1 | import chromatinhd.data.transcriptome
2 | import dataclasses
3 |
4 | from .fragments import Fragments, FragmentsRegional
5 | from chromatinhd.loaders.minibatches import Minibatch
6 | from .transcriptome import Transcriptome, TranscriptomeGene
7 |
8 |
9 | @dataclasses.dataclass
10 | class Result:
11 | transcriptome: Transcriptome
12 | fragments: Fragments
13 | minibatch: Minibatch
14 |
15 | def to(self, device):
16 | self.transcriptome.to(device)
17 | self.fragments.to(device)
18 | self.minibatch.to(device)
19 | return self
20 |
21 |
22 | class TranscriptomeFragments:
23 | def __init__(
24 | self,
25 | fragments: chromatinhd.data.fragments.Fragments,
26 | transcriptome: chromatinhd.data.transcriptome.Transcriptome,
27 | cellxregion_batch_size: int,
28 | layer: str = None,
29 | region_oi=None,
30 | ):
31 | # ensure that transcriptome and fragments have the same var
32 | if not all(transcriptome.var.index == fragments.var.index):
33 | raise ValueError("Transcriptome and fragments should have the same var index. ")
34 |
35 | if region_oi is None:
36 | self.fragments = Fragments(
37 | fragments,
38 | cellxregion_batch_size=cellxregion_batch_size,
39 | provide_multiplets=False,
40 | provide_libsize=True,
41 | )
42 | self.transcriptome = Transcriptome(transcriptome, layer=layer)
43 | else:
44 | self.fragments = FragmentsRegional(
45 | fragments,
46 | cellxregion_batch_size=cellxregion_batch_size,
47 | region_oi=region_oi,
48 | provide_libsize=True,
49 | )
50 | self.transcriptome = TranscriptomeGene(transcriptome, gene_oi=region_oi, layer=layer)
51 |
52 | def load(self, minibatch):
53 | return Result(
54 | transcriptome=self.transcriptome.load(minibatch),
55 | fragments=self.fragments.load(minibatch),
56 | minibatch=minibatch,
57 | )
58 |
--------------------------------------------------------------------------------
/src/chromatinhd/loaders/transcriptome_fragments2.py:
--------------------------------------------------------------------------------
1 | import chromatinhd.data.transcriptome
2 | import dataclasses
3 |
4 | from .fragments2 import Fragments
5 | from chromatinhd.loaders.minibatches import Minibatch
6 | from .transcriptome import Transcriptome, TranscriptomeGene
7 |
8 |
9 | @dataclasses.dataclass
10 | class Result:
11 | transcriptome: Transcriptome
12 | fragments: Fragments
13 | minibatch: Minibatch
14 |
15 | def to(self, device):
16 | self.transcriptome.to(device)
17 | self.fragments.to(device)
18 | self.minibatch.to(device)
19 | return self
20 |
21 |
22 | class TranscriptomeFragments:
23 | def __init__(
24 | self,
25 | fragments: chromatinhd.data.fragments.Fragments,
26 | transcriptome: chromatinhd.data.transcriptome.Transcriptome,
27 | regionxcell_batch_size: int,
28 | layer: str = None,
29 | region_oi=None,
30 | ):
31 | # ensure that transcriptome and fragments have the same var
32 | if not all(transcriptome.var.index == fragments.var.index):
33 |             raise ValueError("Transcriptome and fragments should have the same var index.")
34 |
35 | self.fragments = Fragments(fragments, regionxcell_batch_size=regionxcell_batch_size)
36 | self.transcriptome = Transcriptome(transcriptome, layer=layer)
37 |
38 | def load(self, minibatch):
39 | return Result(
40 | transcriptome=self.transcriptome.load(minibatch),
41 | fragments=self.fragments.load(minibatch),
42 | minibatch=minibatch,
43 | )
44 |
--------------------------------------------------------------------------------
/src/chromatinhd/loaders/transcriptome_fragments_time.py:
--------------------------------------------------------------------------------
1 | import chromatinhd.data.transcriptome
2 | import chromatinhd.data.gradient
3 | import dataclasses
4 | import numpy as np
5 | import torch
6 |
7 | from .fragments import Fragments
8 | from chromatinhd.loaders.minibatches import Minibatch
9 | from .transcriptome import Result as TranscriptomeResult
10 |
11 |
12 | @dataclasses.dataclass
13 | class Result:
14 | transcriptome: TranscriptomeResult
15 | fragments: Fragments
16 | minibatch: Minibatch
17 |
18 | def to(self, device):
19 | self.transcriptome.to(device)
20 | self.fragments.to(device)
21 | self.minibatch.to(device)
22 | return self
23 |
24 |
25 | class TranscriptomeTime:
26 | def __init__(
27 | self,
28 | transcriptome: chromatinhd.data.transcriptome.Transcriptome,
29 | gradient: chromatinhd.data.gradient.Gradient,
30 | layer: str = None,
31 | delta_time=0.25,
32 | n_bins=20,
33 | delta_expression=False,
34 | ):
35 | bins = self.bins = np.array([0] + list(np.linspace(0.1, 0.9, n_bins - 1)) + [1])
36 |
37 | x = self.x = gradient.values[:, 0]
38 | if layer is None:
39 | layer = list(transcriptome.layers.keys())[0]
40 | y = transcriptome.layers[layer][:]
41 |
42 | x_binned = self.x_binned = np.clip(np.searchsorted(bins, x) - 1, 0, bins.size - 2)
43 |         x_onehot = np.zeros((x_binned.size, bins.size - 1))
44 | x_onehot[np.arange(x_binned.size), x_binned] = 1
45 | y_binned = (x_onehot.T @ y) / x_onehot.sum(axis=0)[:, None]
46 | self.y_binned = (y_binned - y_binned.min(0)) / (y_binned.max(0) - y_binned.min(0))
47 |
48 | self.delta_time = delta_time
49 | self.delta_expression = delta_expression
50 |
51 | def load(self, minibatch):
52 | x = self.x[minibatch.cells_oi]
53 | x_desired = x + self.delta_time
54 | x_desired_bin = np.clip(np.searchsorted(self.bins, x_desired) - 1, 0, self.bins.size - 2)
55 | if self.delta_expression:
56 | x_desired_bin2 = np.clip(np.searchsorted(self.bins, x_desired + self.delta_time) - 1, 0, self.bins.size - 2)
57 | y_desired = (
58 | self.y_binned[x_desired_bin2, :][:, minibatch.genes_oi]
59 | - self.y_binned[x_desired_bin, :][:, minibatch.genes_oi]
60 | )
61 | else:
62 | y_desired = self.y_binned[x_desired_bin, :][:, minibatch.genes_oi]
63 |
64 | return TranscriptomeResult(value=torch.from_numpy(y_desired))
65 |
66 |
67 | class TranscriptomeFragmentsTime:
68 | def __init__(
69 | self,
70 | fragments: chromatinhd.data.fragments.Fragments,
71 | transcriptome: chromatinhd.data.transcriptome.Transcriptome,
72 | gradient: chromatinhd.data.gradient.Gradient,
73 | cellxregion_batch_size: int,
74 | layer: str = None,
75 | delta_time=0.25,
76 | n_bins=20,
77 | delta_expression=False,
78 | ):
79 | # ensure that transcriptome and fragments have the same var
80 | if not all(transcriptome.var.index == fragments.var.index):
81 | raise ValueError("Transcriptome and fragments should have the same var index.")
82 |
83 | self.fragments = Fragments(fragments, cellxregion_batch_size=cellxregion_batch_size)
84 | self.transcriptome = TranscriptomeTime(
85 | transcriptome,
86 | gradient,
87 | layer=layer,
88 | delta_expression=delta_expression,
89 | delta_time=delta_time,
90 | n_bins=n_bins,
91 | )
92 |
93 | def load(self, minibatch):
94 | return Result(
95 | transcriptome=self.transcriptome.load(minibatch),
96 | fragments=self.fragments.load(minibatch),
97 | minibatch=minibatch,
98 | )
99 |
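A self-contained sketch (toy data) of the binning trick TranscriptomeTime uses: one-hot encoding the bin assignment turns per-bin averaging of expression into a single matrix product:

import numpy as np

x = np.array([0.05, 0.2, 0.4, 0.6, 0.8, 0.95])  # pseudotime per cell
y = np.array([[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]])  # expression, cells x genes
bins = np.array([0.0, 0.25, 0.5, 0.75, 1.0])

x_binned = np.clip(np.searchsorted(bins, x) - 1, 0, bins.size - 2)
x_onehot = np.zeros((x_binned.size, bins.size - 1))
x_onehot[np.arange(x_binned.size), x_binned] = 1

# per-bin mean expression in one matmul: (bins x cells) @ (cells x genes)
y_binned = (x_onehot.T @ y) / x_onehot.sum(axis=0)[:, None]
print(y_binned.ravel())  # [0.5, 2.0, 3.0, 4.5]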
--------------------------------------------------------------------------------
/src/chromatinhd/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import HybridModel, FlowModel
2 | from . import pred
3 | from . import diff
4 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/diff/__init__.py:
--------------------------------------------------------------------------------
1 | # from .differential import DifferentialSlices
2 | from . import plot
3 | from . import loader
4 | from . import model
5 | from . import interpret
6 | from . import trainer
7 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/diff/interpret/__init__.py:
--------------------------------------------------------------------------------
1 | from .regionpositional import RegionPositional
2 | from .performance import Performance
3 | from . import enrichment
4 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/diff/interpret/enrichment/__init__.py:
--------------------------------------------------------------------------------
1 | from .enrichment import enrichment_foreground_vs_background, enrichment_cluster_vs_clusters
2 | from .group import group_enrichment
--------------------------------------------------------------------------------
/src/chromatinhd/models/diff/interpret/enrichment/enrichment.py:
--------------------------------------------------------------------------------
1 | import scipy.stats
2 | import tqdm.auto as tqdm
3 | import numpy as np
4 | import pandas as pd
5 | from chromatinhd.utils import fdr
6 |
7 |
8 | def enrichment_foreground_vs_background(slicescores_foreground, slicescores_background, slicecounts, motifs=None, expected=None):
9 | if motifs is None:
10 | motifs = slicecounts.columns
11 |
12 | if "length" not in slicescores_foreground.columns:
13 | slicescores_foreground["length"] = slicescores_foreground["end"] - slicescores_foreground["start"]
14 | if "length" not in slicescores_background.columns:
15 | slicescores_background["length"] = slicescores_background["end"] - slicescores_background["start"]
16 |
17 | x_foreground = slicecounts.loc[slicescores_foreground.index, motifs].sum(0)
18 | x_background = slicecounts.loc[slicescores_background.index, motifs].sum(0)
19 |
20 | n_foreground = slicescores_foreground["length"].sum()
21 | n_background = slicescores_background["length"].sum()
22 |
23 | contingencies = (
24 | np.stack(
25 | [
26 | n_background - x_background,
27 | x_background,
28 | n_foreground - x_foreground,
29 | x_foreground,
30 | ],
31 | axis=1,
32 | )
33 | .reshape(-1, 2, 2)
34 | .astype(np.int64)
35 | )
36 |
37 | odds = (contingencies[:, 1, 1] * contingencies[:, 0, 0] + 1) / (contingencies[:, 1, 0] * contingencies[:, 0, 1] + 1)
38 |
39 | if expected is not None:
40 | x_foreground = expected.loc[slicescores_foreground.index, motifs].sum(0)
41 | x_background = expected.loc[slicescores_background.index, motifs].sum(0)
42 | n_foreground = slicescores_foreground["length"].sum()
43 | n_background = slicescores_background["length"].sum()
44 |
45 | contingencies_expected = (
46 | np.stack(
47 | [
48 | n_background - x_background,
49 | x_background,
50 | n_foreground - x_foreground,
51 | x_foreground,
52 | ],
53 | axis=1,
54 | )
55 | .reshape(-1, 2, 2)
56 | .astype(np.int64)
57 | )
58 | odds_expected = (contingencies_expected[:, 1, 1] * contingencies_expected[:, 0, 0] + 1) / (contingencies_expected[:, 1, 0] * contingencies_expected[:, 0, 1] + 1)
59 |
60 | odds = odds / odds_expected
61 |
62 | contingencies[:, 0, 1] = contingencies[:, 0, 1] * odds_expected
63 |
64 |
65 | p_values = np.array(
66 | [
67 | scipy.stats.chi2_contingency(c).pvalue if (c > 5).all() else scipy.stats.fisher_exact(c).pvalue
68 | for c in contingencies
69 | ]
70 | )
71 | q_values = fdr(p_values)
72 |
73 | return pd.DataFrame(
74 | {
75 | "odds": odds,
76 | "p_value": p_values,
77 | "q_value": q_values,
78 | "motif": motifs,
79 | "contingency": [c for c in contingencies],
80 | }
81 | ).set_index("motif")
82 |
83 |
84 | def enrichment_cluster_vs_clusters(slicescores, slicecounts, clusters=None, motifs=None, pbar=True, expected=None):
85 |     if clusters is None:
86 |         if "cluster" not in slicescores.columns:
87 | raise ValueError("No cluster information in slicescores")
88 | elif not slicescores["cluster"].dtype.name == "category":
89 | raise ValueError("Cluster column should be categorical")
90 | clusters = slicescores["cluster"].cat.categories
91 | enrichment = []
92 |
93 | progress = clusters
94 | if pbar:
95 | progress = tqdm.tqdm(progress)
96 | for cluster in progress:
97 | selected_slices = slicescores["cluster"] == cluster
98 | slicescores_foreground = slicescores.loc[selected_slices]
99 | slicescores_background = slicescores.loc[~selected_slices]
100 |
101 | enrichment.append(
102 | enrichment_foreground_vs_background(
103 | slicescores_foreground,
104 | slicescores_background,
105 | slicecounts,
106 | motifs=motifs,
107 |                 expected=expected,
108 | )
109 | .assign(cluster=cluster)
110 | .reset_index()
111 | )
112 |
113 | enrichment = pd.concat(enrichment, axis=0).set_index(["cluster", "motif"])
114 | enrichment["log_odds"] = np.log(enrichment["odds"])
115 | return enrichment
116 |
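For a single motif, the test above reduces to a 2x2 contingency table of motif-hit versus non-hit basepairs in foreground and background. A toy example (made-up counts) mirroring the smoothed odds ratio and chi-squared test used above:

import numpy as np
import scipy.stats

# foreground: 10,000 bp containing 40 hits; background: 90,000 bp containing 90 hits
n_fg, x_fg = 10_000, 40
n_bg, x_bg = 90_000, 90
contingency = np.array([[n_bg - x_bg, x_bg], [n_fg - x_fg, x_fg]])

# +1 smoothing avoids division by zero for empty cells
odds = (contingency[1, 1] * contingency[0, 0] + 1) / (contingency[1, 0] * contingency[0, 1] + 1)
p_value = scipy.stats.chi2_contingency(contingency).pvalue  # all cells > 5 here
print(round(odds, 2), p_value)  # ~4.01, far below any q-value cutoff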
--------------------------------------------------------------------------------
/src/chromatinhd/models/diff/interpret/enrichment/group.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 |
4 |
5 | def group_enrichment(
6 | enrichment,
7 | slicecounts,
8 | cluster_info,
9 | merge_cutoff=0.2,
10 | q_value_cutoff=0.01,
11 | odds_cutoff=1.1,
12 | min_found=100,
13 | ):
14 | """
15 | Group motifs based on correlation of their slice counts across clusters
16 |
17 | Parameters
18 | ----------
19 | enrichment : pd.DataFrame
20 | DataFrame with columns "odds", "q_value", "contingency"
21 | slicecounts : pd.DataFrame
22 | DataFrame with slice counts
23 | cluster_info : pd.DataFrame
24 | DataFrame with cluster information
25 | merge_cutoff : float
26 | Correlation cutoff for merging motifs
27 | q_value_cutoff : float
28 | q-value cutoff for enrichment
29 | odds_cutoff : float
30 | Odds ratio cutoff for enrichment
31 | min_found : int
32 | Minimum number of found sites for enrichment
33 | """
34 |
35 | enrichment_grouped = []
36 | slicecors = pd.DataFrame(
37 | np.corrcoef((slicecounts.T > 0) + np.random.normal(0, 1e-6, slicecounts.shape).T),
38 | index=slicecounts.columns,
39 | columns=slicecounts.columns,
40 | )
41 |
42 | for cluster_oi in cluster_info.index:
43 | for direction in ["up", "down"]:
44 | if direction == "up":
45 | enrichment["found"] = enrichment["contingency"].map(lambda x: x[1, 1].sum())
46 | enrichment_oi = (
47 | enrichment.loc[cluster_oi]
48 | .query("q_value < @q_value_cutoff")
49 | .query("odds > @odds_cutoff")
50 | .sort_values("odds", ascending=False)
51 | .query("found > @min_found")
52 | )
53 | else:
54 | enrichment["found"] = enrichment["contingency"].map(lambda x: x[0, 1].sum())
55 |
56 | enrichment_oi = (
57 | enrichment.loc[cluster_oi]
58 | .query("q_value < @q_value_cutoff")
59 | .query("1/odds > @odds_cutoff")
60 | .sort_values("odds", ascending=True)
61 | .query("found > @min_found")
62 | )
63 |
64 |
65 |
66 | motif_grouping = {}
67 | for motif_id in enrichment_oi.index:
68 | slicecors_oi = slicecors.loc[motif_id, list(motif_grouping.keys())]
69 | if (slicecors_oi < merge_cutoff).all():
70 | motif_grouping[motif_id] = [motif_id]
71 | enrichment_oi.loc[motif_id, "group"] = motif_id
72 | else:
73 | group = slicecors_oi.sort_values(ascending=False).index[0]
74 | motif_grouping[group].append(motif_id)
75 | enrichment_oi.loc[motif_id, "group"] = group
76 | enrichment_group = enrichment_oi.sort_values("odds", ascending=False).loc[
77 | list(motif_grouping.keys())
78 | ]
79 | enrichment_group["members"] = [
80 | motif_grouping[group] for group in enrichment_group.index
81 | ]
82 | enrichment_group["direction"] = direction
83 | enrichment_grouped.append(enrichment_group.assign(cluster=cluster_oi))
84 | enrichment_grouped = (
85 | pd.concat(enrichment_grouped).reset_index().set_index(["cluster", "motif"])
86 | )
87 | enrichment_grouped = enrichment_grouped.sort_values("q_value", ascending=True)
88 |
89 | return enrichment_grouped
90 |
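The grouping loop above is greedy: motifs are visited in order of decreasing odds, and each motif either founds a new group or joins the already-founded group whose slice counts it correlates with most. A toy run (hypothetical correlations):

import pandas as pd

slicecors = pd.DataFrame(
    [[1.0, 0.8, 0.0], [0.8, 1.0, 0.1], [0.0, 0.1, 1.0]],
    index=["A", "B", "C"],
    columns=["A", "B", "C"],
)
merge_cutoff = 0.2

motif_grouping = {}
for motif in ["A", "B", "C"]:  # already sorted by decreasing odds
    cors = slicecors.loc[motif, list(motif_grouping)]
    if (cors < merge_cutoff).all():
        motif_grouping[motif] = [motif]  # founds a new group
    else:
        motif_grouping[cors.idxmax()].append(motif)  # joins the best-correlated group
print(motif_grouping)  # {'A': ['A', 'B'], 'C': ['C']}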
--------------------------------------------------------------------------------
/src/chromatinhd/models/diff/interpret/performance.py:
--------------------------------------------------------------------------------
1 | import chromatinhd as chd
2 | import numpy as np
3 | import pandas as pd
4 | import xarray as xr
5 | import pickle
6 | import scipy.stats
7 | import tqdm.auto as tqdm
8 | from chromatinhd import get_default_device
9 |
10 | from chromatinhd.flow.objects import Stored, Dataset, StoredDict
11 |
12 | from itertools import product
13 |
14 |
15 | class Performance(chd.flow.Flow):
16 | """
17 | The train/validation/test performance of a (set of) models.
18 | """
19 |
20 | scores = chd.flow.SparseDataset()
21 |
22 | folds = None
23 |
24 | @classmethod
25 | def create(cls, folds, fragments, phases=None, overwrite=False, path=None):
26 | self = cls(path=path, reset=overwrite)
27 |
28 | self.folds = folds
29 | self.fragments = fragments
30 |
31 | if self.o.scores.exists(self) and not overwrite:
32 | assert self.scores.coords_pointed["region"].equals(fragments.var.index)
33 |
34 | return self
35 |
36 | if phases is None:
37 | phases = ["train", "validation", "test"]
38 |
39 | coords_pointed = {
40 | "region": fragments.regions.var.index,
41 | "fold": pd.Index(range(len(folds)), name="fold"),
42 | "phase": pd.Index(phases, name="phase"),
43 | }
44 | coords_fixed = {}
45 | self.scores = chd.sparse.SparseDataset.create(
46 | self.path / "scores",
47 | variables={
48 | "likelihood": {
49 | "dimensions": ("region", "fold", "phase"),
50 | "dtype": np.float32,
51 | "sparse": False,
52 | },
53 | "scored": {
54 | "dimensions": ("region", "fold"),
55 | "dtype": np.bool,
56 | "sparse": False,
57 | },
58 | },
59 | coords_pointed=coords_pointed,
60 | coords_fixed=coords_fixed,
61 | )
62 |
63 | return self
64 |
65 | def score(
66 | self,
67 | models,
68 | force=False,
69 | device=None,
70 | regions=None,
71 | pbar=True,
72 | ):
73 | fragments = self.fragments
74 | folds = self.folds
75 |
76 | phases = self.scores.coords["phase"].values
77 |
78 | region_name = fragments.var.index.name
79 |
80 | design = self.scores["scored"].sel_xr().to_dataframe(name="scored")
81 |         design = design.loc[(~design["scored"]) | force]
82 | design = design.reset_index()[[region_name, "fold"]]
83 |
84 | if regions is not None:
85 | design = design.loc[design[region_name].isin(regions)]
86 |
87 | progress = design.groupby("fold")
88 | if pbar is True:
89 | progress = tqdm.tqdm(progress, total=len(progress), leave=False)
90 |
91 | for fold_ix, subdesign in progress:
92 | if models.fitted(fold_ix):
93 | for phase in phases:
94 | fold = folds[fold_ix]
95 | cells_oi = fold[f"cells_{phase}"]
96 |
97 | likelihood = models.get_prediction(
98 | fold_ix=fold_ix, cell_ixs=cells_oi, regions=subdesign[region_name], return_raw=True
99 | )
100 |
101 | for region_ix, region_oi in enumerate(subdesign[region_name].values):
102 | self.scores["likelihood"][
103 | region_oi,
104 | fold_ix,
105 | phase,
106 | ] = likelihood[:, region_ix].mean()
107 |
108 | for region_oi in subdesign[region_name]:
109 | self.scores["scored"][region_oi, fold_ix] = True
110 |
111 | return self
112 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/diff/interpret/slices.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import tqdm
4 |
5 | import chromatinhd as chd
6 |
7 |
8 | def filter_slices_probs(regionpositional, fragments, prob_cutoff=0.0):
9 |     start_position_ixs = []
10 |     end_position_ixs = []
11 |     data = []
12 |     region_ixs = []
13 |     for region, probs in tqdm.tqdm(regionpositional.probs.items()):
14 |         region_ix = fragments.var.index.get_loc(region)
15 |         desired_x = np.arange(*fragments.regions.window) - fragments.regions.window[0]
16 |         x = probs.coords["coord"].values - fragments.regions.window[0]
17 |         y = probs.values
18 |
19 |         y_interpolated = chd.utils.interpolate_1d(
20 |             torch.from_numpy(desired_x), torch.from_numpy(x), torch.from_numpy(y)
21 |         ).numpy()
22 |
23 |         # from y_interpolated, determine start and end positions of the relevant slices
24 |         start_position_ixs_region, end_position_ixs_region, data_region = extract_slices(y_interpolated, prob_cutoff)
25 |         start_position_ixs.append(start_position_ixs_region + fragments.regions.window[0])
26 |         end_position_ixs.append(end_position_ixs_region + fragments.regions.window[0])
27 |         data.append(data_region)
28 |         region_ixs.append(np.ones(len(start_position_ixs_region), dtype=int) * region_ix)
29 |     data = np.concatenate(data, axis=0)
30 |     start_position_ixs = np.concatenate(start_position_ixs, axis=0)
31 |     end_position_ixs = np.concatenate(end_position_ixs, axis=0)
32 |     region_ixs = np.concatenate(region_ixs, axis=0)
33 |
34 |     # Slices is assumed to live in the sibling regionpositional module
35 |     from .regionpositional import Slices
36 |
37 |     return Slices(region_ixs, start_position_ixs, end_position_ixs, data, fragments.n_regions)
38 |
39 |
40 | def extract_slices(x, cutoff=0.0):
41 |     # positions where at least one row (e.g. one cluster) exceeds the cutoff
42 |     selected = (x > cutoff).any(0).astype(int)
43 |     selected_padded = np.pad(selected, (1, 1))
44 |     # half-open [start, end) runs recovered from the padded difference
45 |     (start_position_indices,) = np.where(np.diff(selected_padded, axis=-1) == 1)
46 |     (end_position_indices,) = np.where(np.diff(selected_padded, axis=-1) == -1)
47 |
48 |     data = []
49 |     for start_ix, end_ix in zip(start_position_indices, end_position_indices):
50 |         data.append(x[:, start_ix:end_ix].transpose(1, 0))
51 |     if len(data) == 0:
52 |         data = np.zeros((0, x.shape[0]))
53 |     else:
54 |         data = np.concatenate(data, axis=0)
55 |
56 |     return start_position_indices, end_position_indices, data
57 |
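A quick check of extract_slices (import path taken from the file location above): positions where any row exceeds the cutoff are merged into half-open [start, end) runs:

import numpy as np
from chromatinhd.models.diff.interpret.slices import extract_slices

# two clusters x eight positions; positions 1-2 and 5 exceed the cutoff in some cluster
x = np.array(
    [
        [-1.0, 0.5, 0.8, -1.0, -1.0, 0.2, -1.0, -1.0],
        [-1.0, -1.0, 0.1, -1.0, -1.0, -1.0, -1.0, -1.0],
    ]
)
starts, ends, data = extract_slices(x, cutoff=0.0)
print(starts, ends)  # [1 5] [3 6]: the runs [1, 3) and [5, 6)
print(data.shape)    # (3, 2): stacked positions x clusters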
--------------------------------------------------------------------------------
/src/chromatinhd/models/diff/loader/__init__.py:
--------------------------------------------------------------------------------
1 | from chromatinhd.loaders.minibatches import Minibatcher, Minibatch
2 | from .clustering_cuts import ClusteringCuts
3 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/diff/loader/clustering_cuts.py:
--------------------------------------------------------------------------------
1 | import chromatinhd.data.fragments
2 | import dataclasses
3 |
4 | from chromatinhd.loaders.clustering import Clustering
5 | from chromatinhd.loaders.minibatches import Minibatch
6 | from chromatinhd.loaders.fragments import Cuts
7 |
8 |
9 | @dataclasses.dataclass
10 | class Result:
11 | cuts: Cuts
12 | clustering: Clustering
13 | minibatch: Minibatch
14 |
15 | def to(self, device):
16 | self.cuts.to(device)
17 | self.clustering.to(device)
18 | self.minibatch.to(device)
19 | return self
20 |
21 |
22 | class ClusteringCuts:
23 | """
24 | Provides both clustering and cuts data for a minibatch.
25 | """
26 |
27 | def __init__(
28 | self,
29 | clustering: chromatinhd.data.clustering.Clustering,
30 | fragments: chromatinhd.data.fragments.Fragments,
31 | cellxregion_batch_size: int,
32 | ):
33 | # ensure that the order of clustering and fragment.obs is the same
34 | if not all(clustering.labels.index == fragments.obs.index):
35 |             raise ValueError("Clustering and fragments should have the same obs index.")
36 | self.clustering = Clustering(clustering)
37 | self.cuts = Cuts(fragments, cellxregion_batch_size=cellxregion_batch_size)
38 |
39 | def load(self, minibatch):
40 | return Result(
41 | cuts=self.cuts.load(minibatch),
42 | clustering=self.clustering.load(minibatch),
43 | minibatch=minibatch,
44 | )
45 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/diff/loader/clustering_fragments.py:
--------------------------------------------------------------------------------
1 | import chromatinhd.data.fragments
2 | import dataclasses
3 |
4 | from chromatinhd.loaders.fragments import Fragments
5 | from chromatinhd.loaders.clustering import Clustering
6 | from chromatinhd.loaders.minibatches import Minibatch
7 |
8 |
9 | @dataclasses.dataclass
10 | class Result:
11 | fragments: Fragments
12 | clustering: Clustering
13 | minibatch: Minibatch
14 |
15 | def to(self, device):
16 | self.fragments.to(device)
17 | self.clustering.to(device)
18 | self.minibatch.to(device)
19 | return self
20 |
21 |
22 | class ClusteringFragments:
23 | """
24 | Provides both clustering and fragments data for a minibatch.
25 | """
26 |
27 | def __init__(
28 | self,
29 | clustering: chromatinhd.data.clustering.Clustering,
30 | fragments: chromatinhd.data.fragments.Fragments,
31 | cellxregion_batch_size: int,
32 | ):
33 | # ensure that the order of clustering and fragment.obs is the same
34 | if not all(clustering.labels.index == fragments.obs.index):
35 |             raise ValueError("Clustering and fragments should have the same obs index.")
36 | self.clustering = Clustering(clustering)
37 | self.fragments = Fragments(fragments, cellxregion_batch_size=cellxregion_batch_size)
38 |
39 | def load(self, minibatch):
40 | return Result(
41 | fragments=self.fragments.load(minibatch),
42 | clustering=self.clustering.load(minibatch),
43 | minibatch=minibatch,
44 | )
45 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/diff/model/__init__.py:
--------------------------------------------------------------------------------
1 | from . import cutnf
2 | from . import binary
3 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/diff/model/splines/__init__.py:
--------------------------------------------------------------------------------
1 | from . import linear
2 | from . import quadratic
3 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/diff/model/splines/linear.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch.nn import functional as F
5 | import numpy as np
6 |
7 |
8 | def unconstrained_linear_spline(
9 | inputs, unnormalized_pdf, inverse=False, tail_bound=1.0, tails="linear"
10 | ):
11 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
12 | outside_interval_mask = ~inside_interval_mask
13 |
14 | outputs = torch.zeros_like(inputs)
15 | logabsdet = torch.zeros_like(inputs)
16 |
17 | if tails == "linear":
18 | outputs[outside_interval_mask] = inputs[outside_interval_mask]
19 | logabsdet[outside_interval_mask] = 0
20 | else:
21 | raise RuntimeError("{} tails are not implemented.".format(tails))
22 |
23 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = linear_spline(
24 | inputs=inputs[inside_interval_mask],
25 | unnormalized_pdf=unnormalized_pdf[inside_interval_mask, :],
26 | inverse=inverse,
27 | left=-tail_bound,
28 | right=tail_bound,
29 | bottom=-tail_bound,
30 | top=tail_bound,
31 | )
32 |
33 | return outputs, logabsdet
34 |
35 |
36 | def linear_spline(
37 | inputs, unnormalized_pdf, inverse=False, left=0.0, right=1.0, bottom=0.0, top=1.0
38 | ):
39 | """
40 | Reference:
41 | > Müller et al., Neural Importance Sampling, arXiv:1808.03856, 2018.
42 | """
43 | if not inverse and (torch.min(inputs) < left or torch.max(inputs) > right):
44 | raise ValueError("Inputs outside domain")
45 | elif inverse and (torch.min(inputs) < bottom or torch.max(inputs) > top):
46 | raise ValueError("Inputs outside domain")
47 |
48 | if inverse:
49 | inputs = (inputs - bottom) / (top - bottom)
50 | else:
51 | inputs = (inputs - left) / (right - left)
52 |
53 | num_bins = unnormalized_pdf.size(-1)
54 |
55 | pdf = F.softmax(unnormalized_pdf, dim=-1)
56 |
57 | cdf = torch.cumsum(pdf, dim=-1)
58 | cdf[..., -1] = 1.01
59 | cdf = F.pad(cdf, pad=(1, 0), mode="constant", value=0.0)
60 |
61 | if inverse:
62 | inv_bin_idx = torch.searchsorted(cdf, inputs, right=False)
63 |
64 | bin_boundaries = (
65 | torch.linspace(0, 1, num_bins + 1)
66 | .view([1] * inputs.dim() + [-1])
67 | .expand(*inputs.shape, -1)
68 | )
69 |
70 | slopes = (cdf[..., 1:] - cdf[..., :-1]) / (
71 | bin_boundaries[..., 1:] - bin_boundaries[..., :-1]
72 | )
73 | offsets = cdf[..., 1:] - slopes * bin_boundaries[..., 1:]
74 |
75 | inv_bin_idx = inv_bin_idx.unsqueeze(-1)
76 | input_slopes = slopes.gather(-1, inv_bin_idx)[..., 0]
77 | input_offsets = offsets.gather(-1, inv_bin_idx)[..., 0]
78 |
79 | outputs = (inputs - input_offsets) / input_slopes
80 | outputs = torch.clamp(outputs, 0, 1)
81 |
82 | logabsdet = -torch.log(input_slopes)
83 | else:
84 | bin_pos = inputs * num_bins
85 |
86 | bin_idx = torch.floor(bin_pos).long()
87 | bin_idx[bin_idx >= num_bins] = num_bins - 1
88 |
89 | alpha = bin_pos - bin_idx.float()
90 |
91 | if bin_idx.ndim < pdf.ndim:
92 | bin_idx = bin_idx.unsqueeze(-1)
93 |
94 | input_pdfs = pdf.gather(-1, bin_idx) # [..., 0]
95 |
96 | outputs = cdf.gather(-1, bin_idx) # [..., 0]
97 | outputs += alpha * input_pdfs
98 | outputs = torch.clamp(outputs, 0, 1)
99 |
100 | bin_width = 1.0 / num_bins
101 | logabsdet = torch.log(input_pdfs) - np.log(bin_width)
102 |
103 | if inverse:
104 | outputs = outputs * (right - left) + left
105 | logabsdet = logabsdet - math.log(top - bottom) + math.log(right - left)
106 | else:
107 | outputs = outputs * (top - bottom) + bottom
108 | logabsdet = logabsdet + math.log(top - bottom) - math.log(right - left)
109 |
110 | return outputs, logabsdet
111 |
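Independent of the implementation above, a tiny numeric illustration of the idea behind the forward pass: the transform pushes inputs through the piecewise-linear CDF of a softmax-normalized histogram, and the log-Jacobian is the log of the bin density:

import torch

unnormalized_pdf = torch.tensor([0.0, 1.0, 0.0])  # 3 bins on [0, 1]
pdf = torch.softmax(unnormalized_pdf, dim=-1)
cdf = torch.nn.functional.pad(torch.cumsum(pdf, dim=-1), (1, 0))  # [0, c1, c2, 1]

x = torch.tensor(0.5)                    # falls in bin 1
bin_idx = int(x * 3)                     # = 1
alpha = x * 3 - bin_idx                  # position within the bin
y = cdf[bin_idx] + alpha * pdf[bin_idx]  # piecewise-linear CDF value
logabsdet = torch.log(pdf[bin_idx]) - torch.log(torch.tensor(1.0 / 3))
print(float(y), float(logabsdet))        # CDF(0.5) = 0.5 by symmetry, log density at 0.5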
--------------------------------------------------------------------------------
/src/chromatinhd/models/diff/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainer import Trainer, TrainerPerFeature
2 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/miff/loader/__init__.py:
--------------------------------------------------------------------------------
1 | from chromatinhd.loaders.minibatches import Minibatcher, Minibatch
2 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/miff/loader/clustering.py:
--------------------------------------------------------------------------------
1 | import chromatinhd.data.clustering
2 | import dataclasses
3 | import torch
4 |
5 |
6 | @dataclasses.dataclass
7 | class Result:
8 | labels: torch.Tensor
9 |
10 | def to(self, device):
11 | self.labels = self.labels.to(device)
12 | return self
13 |
14 |
15 | class Clustering:
16 | """
17 | Provides clustering data for a minibatch.
18 | """
19 |
20 | def __init__(
21 | self,
22 | clustering: chromatinhd.data.clustering.Clustering,
23 | ):
24 | assert (clustering.labels.cat.categories == clustering.cluster_info.index).all(), (
25 | clustering.labels.cat.categories,
26 | clustering.cluster_info.index,
27 | )
28 | self.labels = torch.from_numpy(clustering.labels.cat.codes.values.copy()).to(torch.int64)
29 |
30 | def load(self, minibatch):
31 | return Result(labels=self.labels[minibatch.cells_oi])
32 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/miff/loader/combinations.py:
--------------------------------------------------------------------------------
1 | import chromatinhd.data.fragments
2 | import dataclasses
3 | import numpy as np
4 | import torch
5 |
6 | from chromatinhd.loaders.fragments import Fragments as FragmentsLoader
7 | from chromatinhd.loaders.minibatches import Minibatch
8 | from .motifcount import BinnedMotifCounts as BinnedMotifCountsLoader
9 | from .clustering import Clustering as ClusteringLoader
10 | from typing import Any
11 |
12 |
13 | @dataclasses.dataclass
14 | class MotifCountsFragmentsClusteringResult:
15 |     fragments: Any
16 |     motifcounts: Any
17 |     clustering: Any
18 |     minibatch: Any
19 |
20 | def to(self, device):
21 | self.fragments.to(device)
22 | self.motifcounts.to(device)
23 | self.clustering.to(device)
24 | self.minibatch.to(device)
25 | return self
26 |
27 |
28 | class MotifCountsFragmentsClustering:
29 | """
30 | Provides both motifcounts and fragments data for a minibatch.
31 | """
32 |
33 | def __init__(
34 | self,
35 | motifcounts,
36 | fragments: chromatinhd.data.fragments.Fragments,
37 | clustering: chromatinhd.data.clustering.Clustering,
38 | cellxregion_batch_size: int,
39 | ):
40 |         # note: assumes that the order of motifcounts, clustering and fragments.obs is the same
41 | self.fragments = FragmentsLoader(fragments, cellxregion_batch_size=cellxregion_batch_size)
42 | self.motifcounts = BinnedMotifCountsLoader(motifcounts)
43 | self.clustering = ClusteringLoader(clustering)
44 |
45 | def load(self, minibatch):
46 | fragments = self.fragments.load(minibatch)
47 | return MotifCountsFragmentsClusteringResult(
48 | fragments=fragments,
49 | motifcounts=self.motifcounts.load(minibatch, fragments),
50 | clustering=self.clustering.load(minibatch),
51 | minibatch=minibatch,
52 | )
53 |
54 |
55 | @dataclasses.dataclass
56 | class FragmentsClusteringResult:
57 |     fragments: Any
58 |     clustering: Any
59 |     minibatch: Any
60 |
61 | def to(self, device):
62 | self.fragments.to(device)
63 | self.clustering.to(device)
64 | self.minibatch.to(device)
65 | return self
66 |
67 |
68 | class FragmentsClustering:
69 | """
70 | Provides both fragments and clustering data for a minibatch.
71 | """
72 |
73 | def __init__(
74 | self,
75 | fragments: chromatinhd.data.fragments.Fragments,
76 | clustering: chromatinhd.data.clustering.Clustering,
77 | cellxregion_batch_size: int,
78 | ):
79 |         # note: assumes that the order of clustering and fragments.obs is the same
80 | self.fragments = FragmentsLoader(fragments, cellxregion_batch_size=cellxregion_batch_size)
81 | self.clustering = ClusteringLoader(clustering)
82 |
83 | def load(self, minibatch):
84 | fragments = self.fragments.load(minibatch)
85 | return FragmentsClusteringResult(
86 | fragments=fragments,
87 | clustering=self.clustering.load(minibatch),
88 | minibatch=minibatch,
89 | )
90 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/miff/loader/minibatches.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import dataclasses
3 | import itertools
4 | import math
5 | import torch
6 |
7 |
8 | @dataclasses.dataclass
9 | class Minibatch:
10 | cells_oi: np.ndarray
11 | genes_oi: np.ndarray
12 | phase: str = "train"
13 | device: str = "cpu"
14 |
15 | def items(self):
16 | return {"cells_oi": self.cells_oi, "genes_oi": self.genes_oi}
17 |
18 | def filter_genes(self, genes):
19 | genes_oi = self.genes_oi[genes[self.genes_oi]]
20 |
21 | return Minibatch(
22 | cells_oi=self.cells_oi,
23 | genes_oi=genes_oi,
24 | phase=self.phase,
25 | )
26 |
27 | def to(self, device):
28 | self.device = device
29 |
30 | @property
31 | def genes_oi_torch(self):
32 | return torch.from_numpy(self.genes_oi).to(self.device)
33 |
34 | @property
35 | def cells_oi_torch(self):
36 |         return torch.from_numpy(self.cells_oi).to(self.device)
37 |
38 | @property
39 | def n_cells(self):
40 | return len(self.cells_oi)
41 |
42 | @property
43 | def n_genes(self):
44 | return len(self.genes_oi)
45 |
46 |
47 | class Minibatcher:
48 | def __init__(
49 | self,
50 | cells,
51 | genes,
52 | n_cells_step,
53 | n_regions_step,
54 | use_all_cells=False,
55 | use_all_regions=True,
56 | permute_cells=True,
57 | permute_regions=False,
58 | ):
59 | self.cells = cells
60 | if not isinstance(genes, np.ndarray):
61 | genes = np.array(genes)
62 | self.genes = genes
63 | self.n_genes = len(genes)
64 | self.n_cells_step = n_cells_step
65 | self.n_regions_step = n_regions_step
66 |
67 | self.permute_cells = permute_cells
68 | self.permute_regions = permute_regions
69 |
70 | self.use_all_cells = use_all_cells or len(cells) < n_cells_step
71 | self.use_all_regions = use_all_regions or len(genes) < n_regions_step
72 |
73 | self.cellxregion_batch_size = n_cells_step * n_regions_step
74 |
75 | # calculate length
76 | n_cells = len(cells)
77 | n_genes = len(genes)
78 | if self.use_all_cells:
79 | n_cell_bins = math.ceil(n_cells / n_cells_step)
80 | else:
81 | n_cell_bins = math.floor(n_cells / n_cells_step)
82 | if self.use_all_regions:
83 | n_gene_bins = math.ceil(n_genes / n_regions_step)
84 | else:
85 | n_gene_bins = math.floor(n_genes / n_regions_step)
86 | self.length = n_cell_bins * n_gene_bins
87 |
88 | self.i = 0
89 |
90 | def __len__(self):
91 | return self.length
92 |
93 | def __iter__(self):
94 | self.rg = np.random.RandomState(self.i)
95 |
96 | if self.permute_cells:
97 | cells = self.rg.permutation(self.cells)
98 | else:
99 | cells = self.cells
100 | if self.permute_regions:
101 | genes = self.rg.permutation(self.genes)
102 | else:
103 | genes = self.genes
104 |
105 | gene_cuts = [*np.arange(0, len(genes), step=self.n_regions_step)]
106 | if self.use_all_regions:
107 | gene_cuts.append(len(genes))
108 | gene_bins = [genes[a:b] for a, b in zip(gene_cuts[:-1], gene_cuts[1:])]
109 |
110 | cell_cuts = [*np.arange(0, len(cells), step=self.n_cells_step)]
111 | if self.use_all_cells:
112 | cell_cuts.append(len(cells))
113 | cell_bins = [cells[a:b] for a, b in zip(cell_cuts[:-1], cell_cuts[1:])]
114 |
115 | for cells_oi, genes_oi in itertools.product(cell_bins, gene_bins):
116 | yield Minibatch(cells_oi=cells_oi, genes_oi=genes_oi)
117 |
118 | self.i += 1
119 |
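A small usage sketch (import path taken from the file location above): with these settings an epoch yields floor(10/4) x ceil(4/2) = 4 minibatches, dropping the two leftover cells but covering all genes:

import numpy as np
from chromatinhd.models.miff.loader.minibatches import Minibatcher

minibatcher = Minibatcher(
    cells=np.arange(10), genes=np.arange(4), n_cells_step=4, n_regions_step=2
)
print(len(minibatcher))  # 4
for minibatch in minibatcher:
    print(minibatch.n_cells, minibatch.n_genes)  # 4 2, four times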
--------------------------------------------------------------------------------
/src/chromatinhd/models/miff/loader/motifs.py:
--------------------------------------------------------------------------------
1 | import chromatinhd.data.motifscan
2 | import dataclasses
3 | import torch
4 | from chromatinhd.utils.numpy import indptr_to_indices
5 | import numpy as np
6 |
7 |
8 | @dataclasses.dataclass
9 | class Result:
10 | indices: torch.Tensor
11 | positions: torch.Tensor
12 | scores: torch.Tensor
13 | local_genexmotif_ix: torch.Tensor
14 | local_gene_ix: torch.Tensor
15 | n_genes: int
16 |
17 | def to(self, device):
18 | for field_name, field in self.__dataclass_fields__.items():
19 | if field.type is torch.Tensor:
20 | self.__setattr__(field_name, self.__getattribute__(field_name).to(device))
21 | return self
22 |
23 |
24 | class Motifs:
25 | """
26 | Provides motifscan data for a minibatch.
27 | """
28 |
29 | def __init__(
30 | self,
31 | motifscan: chromatinhd.data.motifscan.Motifscan,
32 | ):
33 | self.motifscan = motifscan
34 | self.region_width = motifscan.regions.window[1] - motifscan.regions.window[0]
35 | self.n_motifs = motifscan.motifs.shape[0]
36 |
37 | def load(self, minibatch):
38 | local_genexmotif_ix = []
39 | scores = []
40 | positions = []
41 | indices = []
42 | local_gene_ix = []
43 | for i, gene_ix in enumerate(minibatch.genes_oi):
44 | indptr_start = gene_ix * self.region_width
45 | indptr_end = (gene_ix + 1) * self.region_width
46 | indices_gene = self.motifscan.indices[
47 | self.motifscan.indptr[indptr_start] : self.motifscan.indptr[indptr_end]
48 | ]
49 | positions_gene = (
50 | indptr_to_indices(self.motifscan.indptr[indptr_start : indptr_end + 1])
51 | + self.motifscan.regions.window[0]
52 | )
53 | scores_gene = self.motifscan.scores[self.motifscan.indptr[indptr_start] : self.motifscan.indptr[indptr_end]]
54 |
55 | indices.append(indices_gene)
56 | positions.append(positions_gene)
57 | scores.append(scores_gene)
58 |
59 | local_genexmotif_ix_gene = np.ones_like(indices_gene, dtype=np.int32) * i * self.n_motifs + indices_gene
60 | local_genexmotif_ix.append(local_genexmotif_ix_gene)
61 | local_gene_ix_gene = np.ones_like(indices_gene, dtype=np.int32) * i
62 | local_gene_ix.append(local_gene_ix_gene)
63 |
64 | indices = torch.from_numpy(np.concatenate(indices)).contiguous()
65 | positions = torch.from_numpy(np.concatenate(positions)).contiguous()
66 | scores = torch.from_numpy(np.concatenate(scores)).contiguous()
67 | local_genexmotif_ix = torch.from_numpy(np.concatenate(local_genexmotif_ix)).contiguous()
68 | local_gene_ix = torch.from_numpy(np.concatenate(local_gene_ix)).contiguous()
69 | return Result(indices, positions, scores, local_genexmotif_ix, local_gene_ix, n_genes=minibatch.n_genes)
70 |
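The indptr arithmetic above follows the usual CSR convention; indptr_to_indices presumably expands an indptr into one position index per stored hit, equivalent to this sketch (only the diffs matter, so a non-zero-based indptr slice works too):

import numpy as np

indptr = np.array([0, 2, 2, 5])  # 2 hits at position 0, none at 1, 3 at position 2
indices = np.repeat(np.arange(len(indptr) - 1), np.diff(indptr))
print(indices)  # [0 0 2 2 2]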
--------------------------------------------------------------------------------
/src/chromatinhd/models/miff/loader/motifs_fragments.py:
--------------------------------------------------------------------------------
1 | import chromatinhd.data.fragments
2 | import dataclasses
3 | import numpy as np
4 |
5 | from chromatinhd.loaders.fragments import Fragments
6 | from chromatinhd.loaders.minibatches import Minibatch
7 | from .motifs import Motifs
8 | from typing import List
9 |
10 |
11 | @dataclasses.dataclass
12 | class Result:
13 | fragments: Fragments
14 | motifs: Motifs
15 | minibatch: Minibatch
16 |
17 | def to(self, device):
18 | self.fragments.to(device)
19 | self.motifs.to(device)
20 | self.minibatch.to(device)
21 | return self
22 |
23 |
24 | class MotifsFragments:
25 | """
26 | Provides both motifs and fragments data for a minibatch.
27 | """
28 |
29 | def __init__(
30 | self,
31 | motifscan: chromatinhd.data.motifscan.Motifscan,
32 | fragments: chromatinhd.data.fragments.Fragments,
33 | cellxregion_batch_size: int,
34 | ):
35 | self.motifs = Motifs(motifscan)
36 | self.fragments = Fragments(fragments, cellxregion_batch_size=cellxregion_batch_size)
37 |
38 | def load(self, minibatch):
39 | return Result(
40 | fragments=self.fragments.load(minibatch),
41 | motifs=self.motifs.load(minibatch),
42 | minibatch=minibatch,
43 | )
44 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/miff/model/__init__.py:
--------------------------------------------------------------------------------
1 | from . import global_norm
2 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/miff/model/clustering/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import Model
2 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/miff/model/clustering2/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import Model
2 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/miff/model/clustering3/__init__.py:
--------------------------------------------------------------------------------
1 | from . import count, position
2 | from .model import Model
3 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/miff/model/clustering3/count.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import math
4 | from typing import Union
5 | from chromatinhd.embedding import EmbeddingTensor
6 | from chromatinhd.data.fragments import Fragments, FragmentsView
7 |
8 |
9 | class FragmentCountDistribution(torch.nn.Module):
10 | pass
11 |
12 |
13 | def count_fragments(data):
14 | count = torch.bincount(
15 | data.fragments.local_cellxregion_ix, minlength=data.minibatch.n_cells * data.minibatch.n_regions
16 | ).reshape((data.minibatch.n_cells, data.minibatch.n_regions))
17 | return count
18 |
19 |
20 | class Baseline(FragmentCountDistribution):
21 | def __init__(self, fragments, clustering, baseline=None):
22 | super().__init__()
23 | self.logit = torch.nn.Parameter(torch.zeros((1,)))
24 | if baseline is not None:
25 | self.baseline = baseline
26 | else:
27 | self.baseline = EmbeddingTensor(fragments.n_regions, (1,), sparse=True)
28 |             init = torch.from_numpy(fragments.regionxcell_counts.sum(1).astype(np.float64) / fragments.n_cells)
29 | init = torch.log(init)
30 | self.baseline.weight.data[:] = init.unsqueeze(-1)
31 |
32 |         lib = torch.from_numpy(fragments.regionxcell_counts.sum(0).astype(np.float64) / fragments.n_regions)
33 | lib = torch.log(lib)
34 | self.register_buffer("lib", lib)
35 |
36 | def log_prob(self, data):
37 | count = count_fragments(data)
38 |
39 | if data.fragments.lib is not None:
40 | lib = data.fragments.lib
41 | else:
42 | lib = self.lib[data.minibatch.cells_oi]
43 | logits = self.baseline(data.minibatch.regions_oi_torch).squeeze(1).unsqueeze(0) + lib.unsqueeze(1)
44 | likelihood_count = torch.distributions.Poisson(rate=torch.exp(logits)).log_prob(count)
45 | return likelihood_count
46 |
47 | def parameters_sparse(self):
48 | yield self.baseline.weight
49 |
50 |
51 | class FragmentCountDistribution1(FragmentCountDistribution):
52 | def __init__(self, fragments, clustering, baseline=None, delta_logit=None):
53 | super().__init__()
54 | if baseline is not None:
55 | self.baseline = baseline
56 | else:
57 | self.baseline = EmbeddingTensor(fragments.n_regions, (1,), sparse=True)
58 |             init = torch.from_numpy(fragments.regionxcell_counts.sum(1).astype(np.float64) / fragments.n_cells)
59 | init = torch.log(init)
60 | self.baseline.weight.data[:] = init.unsqueeze(-1)
61 |
62 | if delta_logit is not None:
63 | self.delta_logit = delta_logit
64 | else:
65 | self.delta_logit = EmbeddingTensor(fragments.n_regions, (clustering.n_clusters,), sparse=True)
66 | self.delta_logit.weight.data[:] = 0
67 |
68 |         lib = torch.from_numpy(fragments.regionxcell_counts.sum(0).astype(np.float64) / fragments.n_regions)
69 | lib = torch.log(lib)
70 | self.register_buffer("lib", lib)
71 |
72 | def log_prob(self, data):
73 | count = count_fragments(data)
74 |
75 | if data.fragments.lib is not None:
76 | lib = data.fragments.lib
77 | else:
78 | lib = self.lib[data.minibatch.cells_oi]
79 | logits = (
80 | +self.baseline(data.minibatch.regions_oi_torch).squeeze(1).unsqueeze(0)
81 | + lib.unsqueeze(1)
82 | + self.delta_logit(data.minibatch.regions_oi_torch)[:, data.clustering.labels].transpose(0, 1)
83 | )
84 |
85 |         # count was already computed above via count_fragments; no need to
86 |         # recompute the bincount here
87 |
88 |         likelihood_count = torch.distributions.Poisson(rate=torch.exp(logits)).log_prob(count)
89 | return likelihood_count
90 |
91 | def parameters_sparse(self):
92 | yield self.baseline.weight
93 | yield self.delta_logit.weight
94 |
95 | @classmethod
96 | def from_baseline(cls, fragments, clustering, count_reference):
97 | baseline = torch.nn.Embedding.from_pretrained(count_reference.baseline.weight.data)
98 | delta_logit = EmbeddingTensor.from_pretrained(count_reference.delta_logit)
99 |
100 | return cls(fragments=fragments, clustering=clustering, delta_logit=delta_logit, baseline=baseline)
101 |
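A self-contained sketch of what count_fragments computes: local_cellxregion_ix encodes (cell, region) as cell * n_regions + region, so a bincount plus reshape yields the per-cell, per-region fragment count matrix scored by the Poisson likelihood:

import torch

n_cells, n_regions = 2, 2
local_cellxregion_ix = torch.tensor([0, 0, 3])  # two fragments in (0, 0), one in (1, 1)

count = torch.bincount(local_cellxregion_ix, minlength=n_cells * n_regions).reshape(
    (n_cells, n_regions)
)
print(count)  # [[2, 0], [0, 1]]

rate = torch.full((n_cells, n_regions), 1.0)  # a Poisson rate per cell x region
print(torch.distributions.Poisson(rate=rate).log_prob(count.float()))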
--------------------------------------------------------------------------------
/src/chromatinhd/models/miff/model/global_norm/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import Model
2 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/miff/model/local_norm/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import Model
2 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/miff/model/zoom.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 |
4 |
5 | def transform_linear_spline(positions, n, width, unnormalized_heights):
6 | binsize = torch.div(width, n, rounding_mode="floor")
7 |
8 | normalized_heights = torch.nn.functional.log_softmax(unnormalized_heights, -1)
9 | if normalized_heights.ndim == positions.ndim:
10 | normalized_heights = normalized_heights.unsqueeze(0)
11 |
12 | binixs = torch.div(positions, binsize, rounding_mode="trunc")
13 |
14 | logprob = torch.gather(normalized_heights, 1, binixs.unsqueeze(1)).squeeze(1)
15 |
16 | positions = positions - binixs * binsize
17 | width = binsize
18 |
19 | return logprob, positions, width
20 |
21 |
22 | def calculate_logprob(positions, nbins, width, unnormalized_heights_zooms):
23 | """
24 | Calculate the zoomed log probability per position given a set of unnormalized_heights_zooms
25 | """
26 | assert len(nbins) == len(unnormalized_heights_zooms)
27 |
28 | curpositions = positions
29 | curwidth = width
30 | logprob = torch.zeros_like(positions, dtype=torch.float)
31 | for i, (n, unnormalized_heights_zoom) in enumerate(zip(nbins, unnormalized_heights_zooms)):
32 | assert (curwidth % n) == 0
33 | logprob_layer, curpositions, curwidth = transform_linear_spline(
34 | curpositions,
35 | n,
36 | curwidth,
37 | unnormalized_heights_zoom,
38 | )
39 | logprob += logprob_layer
40 | logprob = logprob - math.log(
41 | curwidth
42 | ) # if any width is left, we need to divide by the remaining number of possibilities to get a properly normalized probability
43 | return logprob
44 |
45 |
46 | def extract_unnormalized_heights(positions, totalbinwidths, unnormalized_heights_all):
47 | """
48 | Extracts the unnormalized heights per zoom level from the global unnormalized heights tensor with size (totaln, n)
49 |     You typically do not want to use this function directly, as the benefits of a zoomed likelihood are lost in this way.
50 |     This function is mainly useful for debugging, inference, or testing purposes.
51 | """
52 | totalbinixs = torch.div(positions[:, None], totalbinwidths, rounding_mode="floor")
53 | totalbinsectors = torch.nn.functional.pad(totalbinixs[..., :-1], (1, 0))
54 | # totalbinsectors = torch.div(totalbinixs, self.nbins[None, :], rounding_mode="floor")
55 | unnormalized_heights_zooms = [
56 | torch.index_select(
57 | unnormalized_heights_all[i], 0, totalbinsector
58 |         )  # index_select is much faster than direct indexing
59 | for i, totalbinsector in enumerate(totalbinsectors.T)
60 | ]
61 |
62 | return unnormalized_heights_zooms
63 |
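A sanity check (import path taken from the file location above): with flat heights at every zoom level, a width-8 window split into two levels of 2 bins assigns each position probability (1/2) x (1/2) x (1/2) = 1/8 (the last factor from the leftover width of 2), and the probabilities sum to one:

import torch
from chromatinhd.models.miff.model.zoom import calculate_logprob

positions = torch.arange(8)
heights = [torch.zeros((8, 2)), torch.zeros((8, 2))]  # flat heights per position
logprob = calculate_logprob(positions, nbins=[2, 2], width=8, unnormalized_heights_zooms=heights)
print(torch.exp(logprob))        # tensor of 0.125
print(torch.exp(logprob).sum())  # ~1.0: properly normalized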
--------------------------------------------------------------------------------
/src/chromatinhd/models/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from chromatinhd.flow import Flow, Stored
3 |
4 |
5 | def get_sparse_parameters(module, parameters):
6 |     for submodule in module._modules.values():
7 |         for p in submodule.parameters_sparse() if hasattr(submodule, "parameters_sparse") else []:
8 |             parameters.add(p)
9 |         parameters = get_sparse_parameters(submodule, parameters)
10 |     return parameters
11 |
12 |
13 | class HybridModel:
14 | def parameters_dense(self, autoextend=True):
15 | """
16 | Get all dense parameters of the model
17 | """
18 | parameters = [
19 | parameter
20 | for module in self._modules.values()
21 | for parameter in (module.parameters_dense() if hasattr(module, "parameters_dense") else [])
22 | ]
23 |
24 | # extend with any left over parameters that were not specified in parameters_dense or parameters_sparse
25 | def contains(x, y):
26 | return any([x is y_ for y_ in y])
27 |
28 | parameters_sparse = set(self.parameters_sparse())
29 |
30 | if autoextend:
31 | for p in self.parameters():
32 |             if not contains(p, parameters_sparse):
33 | parameters.append(p)
34 | parameters = [p for p in parameters if p.requires_grad]
35 | return parameters
36 |
37 | def parameters_sparse(self, autoextend=True):
38 | """
39 | Get all sparse parameters in a model
40 | """
41 | parameters = set()
42 |
43 | parameters = get_sparse_parameters(self, parameters)
44 | parameters = [p for p in parameters if p.requires_grad]
45 | return parameters
46 |
47 |
48 | class FlowModel(torch.nn.Module, HybridModel, Flow):
49 | state = Stored()
50 |
51 | def __init__(self, path=None, reset=False, **kwargs):
52 | torch.nn.Module.__init__(self)
53 | Flow.__init__(self, path=path, reset=reset, **kwargs)
54 |
55 | if self.o.state.exists(self):
56 | if reset:
57 | raise ValueError("Cannot reset a model that has already been initialized")
58 | self.restore_state()
59 |
60 | def save_state(self):
61 | from collections import OrderedDict
62 |
63 | state = OrderedDict()
64 | for k, v in self.__dict__.items():
65 | if k.lstrip("_") in self._obj_map:
66 | continue
67 | if k == "path":
68 | continue
69 | state[k] = v
70 | self.state = state
71 |
72 | @classmethod
73 | def restore(cls, path):
74 | self = cls.__new__(cls)
75 | Flow.__init__(self, path=path)
76 | self.restore_state()
77 | return self
78 |
79 | def restore_state(self):
80 | state = self.state
81 | for k, v in state.items():
82 | # if k.lstrip("_") in self._obj_map:
83 | # continue
84 | self.__dict__[k] = v
85 | return self
86 |
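Why the dense/sparse split matters: sparse embedding gradients need a sparse-aware optimizer, so the two parameter sets are handed to different optimizers. A self-contained sketch of the convention HybridModel relies on, with a hypothetical toy module whose parameters_sparse yields its sparse weights:

import torch

class ToyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(1000, 16, sparse=True)  # sparse gradients
        self.lin = torch.nn.Linear(16, 1)

    def parameters_sparse(self):
        yield self.emb.weight

model = ToyModule()
sparse_params = list(model.parameters_sparse())
dense_params = [p for p in model.parameters() if not any(p is s for s in sparse_params)]

optim_sparse = torch.optim.SparseAdam(sparse_params)
optim_dense = torch.optim.Adam(dense_params)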
--------------------------------------------------------------------------------
/src/chromatinhd/models/pred/__init__.py:
--------------------------------------------------------------------------------
1 | from . import model
2 | from . import interpret
3 | from . import plot
4 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/pred/interpret/__init__.py:
--------------------------------------------------------------------------------
1 | from .censorers import WindowCensorer, MultiWindowCensorer, SizeCensorer, WindowSizeCensorer
2 | from .regionmultiwindow import RegionMultiWindow
3 | from .regionpairwindow import RegionPairWindow
4 | from .regionsizewindow import RegionSizeWindow
5 | from .performance import Performance
6 | from .size import Size
7 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/pred/interpret/performance.py:
--------------------------------------------------------------------------------
1 | import chromatinhd as chd
2 | import numpy as np
3 | import pandas as pd
4 | import xarray as xr
5 | import pickle
6 | import scipy.stats
7 | import tqdm.auto as tqdm
8 | from chromatinhd import get_default_device
9 |
10 | from chromatinhd.flow.objects import Stored, Dataset, StoredDict
11 |
12 | from itertools import product
13 |
14 |
15 | class Performance(chd.flow.Flow):
16 | """
17 | The train/validation/test performance of a (set of) models.
18 | """
19 |
20 | scores = chd.flow.SparseDataset()
21 |
22 | folds = None
23 |
24 | @classmethod
25 | def create(cls, folds, transcriptome, fragments, phases=None, overwrite=False, path=None):
26 | self = cls(path=path, reset=overwrite)
27 |
28 | self.folds = folds
29 | self.transcriptome = transcriptome
30 | self.fragments = fragments
31 |
32 | if self.o.scores.exists(self) and not overwrite:
33 | assert self.scores.coords_pointed["region"].equals(fragments.var.index)
34 |
35 | return self
36 |
37 | if phases is None:
38 | phases = ["train", "validation", "test"]
39 |
40 | coords_pointed = {
41 | "region": fragments.regions.var.index,
42 | "fold": pd.Index(range(len(folds)), name="fold"),
43 | "phase": pd.Index(phases, name="phase"),
44 | }
45 | coords_fixed = {}
46 | self.scores = chd.sparse.SparseDataset.create(
47 | self.path / "scores",
48 | variables={
49 | "cor": {
50 | "dimensions": ("region", "fold", "phase"),
51 | "dtype": np.float32,
52 | "sparse": False,
53 | },
54 | "scored": {
55 | "dimensions": ("region", "fold"),
56 | "dtype": np.bool,
57 | "sparse": False,
58 | },
59 | },
60 | coords_pointed=coords_pointed,
61 | coords_fixed=coords_fixed,
62 | )
63 |
64 | return self
65 |
66 | def score(
67 | self,
68 | models,
69 | force=False,
70 | device=None,
71 | regions=None,
72 | pbar=True,
73 | ):
74 | fragments = self.fragments
75 | folds = self.folds
76 |
77 | phases = self.scores.coords["phase"].values
78 |
79 | if regions is None:
80 | regions_oi = fragments.var.index.tolist() if models.regions_oi is None else models.regions_oi
81 | if isinstance(regions_oi, pd.Series):
82 | regions_oi = regions_oi.tolist()
83 | else:
84 | regions_oi = regions
85 |
86 | design = (
87 | self.scores["scored"]
88 | .sel_xr()
89 | .sel({fragments.var.index.name: regions_oi, "fold": range(len(folds))})
90 | .to_dataframe(name="scored")
91 | )
92 | design["force"] = (~design["scored"]) | force
93 |
94 |         design = design.groupby(fragments.var.index.name).any()
95 |
96 | regions_oi = design.index[design["force"]]
97 |
98 | if len(regions_oi) == 0:
99 | return self
100 |
101 | for fold_ix in range(len(folds)):
102 | for phase in phases:
103 | fold = folds[fold_ix]
104 | cells_oi = fold[f"cells_{phase}"]
105 |
106 | for region_ix, region_oi in enumerate(regions_oi):
107 | predicted, expected, n_fragments = models.get_prediction(
108 | region=region_oi,
109 | fold_ix=fold_ix,
110 | cell_ixs=cells_oi,
111 | return_raw=True,
112 | fragments=self.fragments,
113 | transcriptome=self.transcriptome,
114 | )
115 |
116 | cor = chd.utils.paircor(predicted, expected)
117 |
118 | self.scores["cor"][
119 | region_oi,
120 | fold_ix,
121 | phase,
122 | ] = cor[0]
123 | self.scores["scored"][region_oi, fold_ix] = True
124 |
125 | return self
126 |
127 | @property
128 | def scored(self):
129 | return self.o.scores.exists(self)
130 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/pred/model/__init__.py:
--------------------------------------------------------------------------------
1 | from . import better
2 | from . import multiscale
3 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/pred/model/loss.py:
--------------------------------------------------------------------------------
1 | def paircor(x, y, dim=0, eps=1e-3):
2 | divisor = (y.std(dim) * x.std(dim)) + eps
3 | cor = ((x - x.mean(dim, keepdims=True)) * (y - y.mean(dim, keepdims=True))).mean(dim) / divisor
4 | return cor
5 |
6 |
7 | def paircor_loss(x, y):
8 | return -paircor(x, y).sum() * 100
9 |
10 |
11 | def region_paircor_loss(x, y):
12 | return -paircor(x, y) * 100
13 |
14 |
15 | def pairzmse(x, y, dim=0, eps=1e-3):
16 | y = (y - y.mean(dim, keepdims=True)) / (y.std(dim, keepdims=True) + eps)
17 | return (y - x).pow(2).mean(dim)
18 |
19 |
20 | def pairzmse_loss(x, y):
21 | return pairzmse(x, y).mean() * 0.1
22 |
23 |
24 | def region_pairzmse_loss(x, y):
25 | return pairzmse(x, y) * 0.1
26 |
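paircor computes one correlation per column (e.g. per region) rather than a single pooled value; a quick check with numpy inputs (import path taken from the file location above):

import numpy as np
from chromatinhd.models.pred.model.loss import paircor

rng = np.random.default_rng(0)
x = rng.normal(size=(100, 5))                  # predictions: cells x regions
y = 2.0 * x + 0.1 * rng.normal(size=(100, 5))  # well-correlated targets
cors = paircor(x, y)
print(cors.shape, cors.min())  # (5,), close to 1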
--------------------------------------------------------------------------------
/src/chromatinhd/models/pred/model/multilinear.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from chromatinhd.embedding import FeatureParameter
3 | import math
4 |
5 |
6 | class MultiLinear(torch.nn.Module):
7 | def __init__(
8 | self,
9 | in_features: int,
10 | out_features: int,
11 | n_heads: int,
12 | bias: bool = True,
13 | device=None,
14 | dtype=None,
15 | weight_constructor=None,
16 | bias_constructor=None,
17 | ):
18 | super().__init__()
19 |
20 | self.out_features = out_features
21 |
22 | if bias:
23 | if bias_constructor is None:
24 |
25 | def bias_constructor(shape):
26 | stdv = 1.0 / math.sqrt(shape[-1])
27 | return torch.empty(shape, device=device, dtype=dtype).uniform_(-stdv, stdv)
28 |
29 | bias = FeatureParameter(n_heads, (out_features,), constructor=bias_constructor)
30 | self.register_module("bias", bias)
31 | else:
32 | self.bias = None
33 |
34 | if weight_constructor is None:
35 |
36 | def weight_constructor(shape):
37 | stdv = 1.0 / math.sqrt(shape[-1])
38 | return torch.empty(shape, device=device, dtype=dtype).uniform_(-stdv, stdv)
39 |
40 |
41 |
42 | self.weight = FeatureParameter(
43 | n_heads,
44 | (
45 | in_features,
46 | out_features,
47 | ),
48 | constructor=weight_constructor,
49 | )
50 |
51 | def forward(self, input: torch.Tensor, indptr, regions_oi):
52 | outputs = []
53 |
54 | if self.bias is not None:
55 | for ix, start, end in zip(regions_oi, indptr[:-1], indptr[1:]):
56 | outputs.append(torch.matmul(input[start:end], self.weight[ix]) + self.bias[ix])
57 | else:
58 | for ix, start, end in zip(regions_oi, indptr[:-1], indptr[1:]):
59 | outputs.append(torch.matmul(input[start:end], self.weight[ix]))
60 | output = torch.cat(outputs, dim=0)
61 | return output
62 |
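# A minimal usage sketch (not part of the file above); all names below are
# illustrative, and it assumes FeatureParameter supports per-head indexing:
# rows of `input` are grouped CSR-style by `indptr`, and each group is
# multiplied by the weight (and bias) of the head listed in `regions_oi`.
layer = MultiLinear(in_features=5, out_features=2, n_heads=3)
input = torch.randn(10, 5)                 # 10 rows in total
indptr = torch.tensor([0, 4, 7, 10])       # rows 0:4, 4:7 and 7:10 form three groups
regions_oi = torch.tensor([0, 1, 2])       # head used by each group
output = layer(input, indptr, regions_oi)  # shape (10, 2)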
--------------------------------------------------------------------------------
/src/chromatinhd/models/pred/model/multiscale.py:
--------------------------------------------------------------------------------
1 | from .better import *
2 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/pred/plot/__init__.py:
--------------------------------------------------------------------------------
1 | from .predictivity import Predictivity, Pileup, PredictivityBroken, PileupBroken
2 | from .effect import Effect
3 | from .copredictivity import Copredictivity, CopredictivityBroken
4 |
5 | __all__ = ["Predictivity", "Pileup", "PredictivityBroken", "PileupBroken", "Effect", "Copredictivity", "CopredictivityBroken"]
6 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/pred/plot/effect.py:
--------------------------------------------------------------------------------
1 | import polyptich.grid
2 | import numpy as np
3 |
4 |
5 | class Effect(polyptich.grid.Panel):
6 | def __init__(self, plotdata, window, width, show_accessibility=False):
7 | super().__init__((width, 0.5))
8 | if "position" not in plotdata.columns:
9 | plotdata = plotdata.reset_index()
10 |
11 | ax = self.ax
12 | ax.set_xlim(*window)
13 |
14 | ax.plot(
15 | plotdata["position"],
16 | plotdata["effect"],
17 | color="#333",
18 | lw=1,
19 | )
20 | ax.fill_between(
21 | plotdata["position"],
22 | plotdata["effect"],
23 | 0,
24 | color="#333",
25 | alpha=0.2,
26 | lw=0,
27 | )
28 |
29 | ax.set_ylabel(
30 | "Effect",
31 | rotation=0,
32 | ha="right",
33 | va="center",
34 | )
35 |
36 | ax.set_xticks([])
37 | ax.invert_yaxis()
38 | ymax = plotdata["effect"].abs().max()
39 | ax.set_ylim(-ymax, ymax)
40 |
41 |         # align the outermost y tick labels inward
42 | ax.set_yticks([0, ax.get_ylim()[1]])
43 | ax.get_yticklabels()[-1].set_verticalalignment("top")
44 | ax.get_yticklabels()[0].set_verticalalignment("bottom")
45 |
46 | # vline at tss
47 | ax.axvline(0, color="#888888", lw=0.5, zorder=-1, dashes=(2, 2))
48 |
49 | @classmethod
50 | def from_RegionMultiWindow(cls, RegionMultiWindow, gene, width, show_accessibility=False):
51 | plotdata = RegionMultiWindow.get_plotdata(gene).reset_index()
52 | window = np.array([plotdata["position"].min(), plotdata["position"].max()])
53 | return cls(plotdata, window, width, show_accessibility=show_accessibility)
54 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/pred/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainer import Trainer, SharedTrainer, TrainerPerFeature
2 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/pret/model/__init__.py:
--------------------------------------------------------------------------------
1 | from . import better
2 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/pret/model/loss.py:
--------------------------------------------------------------------------------
1 | def paircor(x, y, dim=0, eps=1e-3):
2 | divisor = (y.std(dim) * x.std(dim)) + eps
3 | cor = ((x - x.mean(dim, keepdims=True)) * (y - y.mean(dim, keepdims=True))).mean(dim) / divisor
4 | return cor
5 |
6 |
7 | def paircor_loss(x, y):
8 | return -paircor(x, y).mean() * 100
9 |
10 |
11 | def region_paircor_loss(x, y):
12 | return -paircor(x, y) * 100
13 |
14 |
15 | def pairzmse(x, y, dim=0, eps=1e-3):
16 | y = (y - y.mean(dim, keepdims=True)) / (y.std(dim, keepdims=True) + eps)
17 | return (y - x).pow(2).mean(dim)
18 |
19 |
20 | def pairzmse_loss(x, y):
21 | return pairzmse(x, y).mean() * 0.1
22 |
23 |
24 | def region_pairzmse_loss(x, y):
25 | return pairzmse(x, y) * 0.1
26 |
--------------------------------------------------------------------------------
/src/chromatinhd/models/pret/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainer import Trainer
2 |
--------------------------------------------------------------------------------
/src/chromatinhd/optim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import itertools
3 |
4 |
5 | class SparseDenseAdam(torch.optim.Optimizer):
6 | """
7 |     Optimize both sparse and dense parameters using Adam
8 | """
9 |
10 | def __init__(self, parameters_sparse, parameters_dense, lr=1e-3, weight_decay=0.0, **kwargs):
11 | if len(parameters_sparse) == 0:
12 | self.optimizer_sparse = None
13 | else:
14 | self.optimizer_sparse = torch.optim.SparseAdam(parameters_sparse, lr=lr, **kwargs)
15 | if len(parameters_dense) == 0:
16 | self.optimizer_dense = None
17 | else:
18 | self.optimizer_dense = torch.optim.Adam(parameters_dense, lr=lr, weight_decay=weight_decay, **kwargs)
19 |
20 | def zero_grad(self):
21 | if self.optimizer_sparse is not None:
22 | self.optimizer_sparse.zero_grad()
23 | if self.optimizer_dense is not None:
24 | self.optimizer_dense.zero_grad()
25 |
26 | def step(self):
27 | if self.optimizer_sparse is not None:
28 | self.optimizer_sparse.step()
29 | if self.optimizer_dense is not None:
30 | self.optimizer_dense.step()
31 |
32 | @property
33 | def param_groups(self):
34 | return itertools.chain(
35 | self.optimizer_dense.param_groups if self.optimizer_dense is not None else [],
36 | self.optimizer_sparse.param_groups if self.optimizer_sparse is not None else [],
37 | )
38 |
39 |
40 | class AdamPerFeature(torch.optim.Optimizer):
41 | def __init__(self, parameters, n_features, lr=1e-3, weight_decay=0.0, **kwargs):
42 | self.n_features = n_features
43 |
44 | self.adams = []
45 | for i in range(n_features):
46 | parameters_features = [p[i] for p in parameters]
47 | self.adams.append(torch.optim.Adam(parameters_features, lr=lr, weight_decay=weight_decay, **kwargs))
48 |
49 | def zero_grad(self, feature_ixs):
50 | for i in feature_ixs:
51 | self.adams[i].zero_grad()
52 |
53 | def step(self, feature_ixs):
54 | for i in feature_ixs:
55 | self.adams[i].step()
56 |
57 | @property
58 | def param_groups(self):
59 | return itertools.chain(*[adam.param_groups for adam in self.adams])
60 |
61 |
62 | class SGDPerFeature(torch.optim.Optimizer):
63 | def __init__(self, parameters, n_features, lr=1e-3, weight_decay=0.0, **kwargs):
64 | self.n_features = n_features
65 |
66 |         self.optimizers = []
67 | for i in range(n_features):
68 | parameters_features = [p[i] for p in parameters]
69 |             self.optimizers.append(
70 | torch.optim.SGD(parameters_features, lr=lr, weight_decay=weight_decay, momentum=0.5, **kwargs)
71 | )
72 |
73 | def zero_grad(self, feature_ixs):
74 | for i in feature_ixs:
75 |             self.optimizers[i].zero_grad()
76 |
77 | def step(self, feature_ixs):
78 | for i in feature_ixs:
79 |             self.optimizers[i].step()
80 |
81 | @property
82 | def param_groups(self):
83 |         return itertools.chain(*[optimizer.param_groups for optimizer in self.optimizers])
84 |
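# A minimal usage sketch (not part of the file above): parameters with sparse
# gradients (e.g. from a sparse embedding) go to SparseAdam, everything else
# to Adam, and both are stepped together through the combined optimizer.
embedding = torch.nn.Embedding(100, 16, sparse=True)
linear = torch.nn.Linear(16, 1)
optim = SparseDenseAdam(
    parameters_sparse=list(embedding.parameters()),
    parameters_dense=list(linear.parameters()),
    lr=1e-3,
)
optim.zero_grad()
loss = linear(embedding(torch.tensor([1, 2, 3]))).sum()
loss.backward()
optim.step()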
--------------------------------------------------------------------------------
/src/chromatinhd/plot/__init__.py:
--------------------------------------------------------------------------------
1 | from .tickers import distance_ticker, gene_ticker, DistanceFormatter, format_distance, round_significant
2 | from . import genome
3 |
4 | from .patch import replace_patch
5 | from .matshow45 import matshow45
6 | from . import tracks
7 | 
--------------------------------------------------------------------------------
/src/chromatinhd/plot/genome/__init__.py:
--------------------------------------------------------------------------------
1 | from .genes import Genes
2 |
--------------------------------------------------------------------------------
/src/chromatinhd/plot/matshow45.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib as mpl
3 |
4 |
5 | def matshow45(ax, series, radius=None, cmap=None, norm=None):
6 |     """Plot a series with a two-level index as a matrix rotated by 45 degrees. Example:
7 | fig, ax = plt.subplots()
8 | plotdata = pd.DataFrame(
9 | np.random.rand(10, 10),
10 | index=np.arange(10),
11 | columns=np.arange(10),
12 | ).stack()
13 | ax.set_aspect(1)
14 | matshow45(ax, plotdata)
15 | """
16 | offsets = []
17 | colors = []
18 |
19 | assert len(series.index.names) == 2
20 | x = series.index.get_level_values(0)
21 | y = series.index.get_level_values(1)
22 |
23 | if radius is None:
24 | radius = np.diff(y)[0] / 2
25 |
26 | centerxs = x + (y - x) / 2
27 | centerys = (y - x) / 2
28 |
29 | xlim = [
30 | centerxs.min() - radius,
31 | centerxs.max() + radius,
32 | ]
33 | ax.set_xlim(xlim)
34 |
35 | ylim = [
36 | centerys.unique().min(),
37 | centerys.unique().max(),
38 | ]
39 | ax.set_ylim(ylim)
40 |
41 | if norm is None:
42 | norm = mpl.colors.Normalize(vmin=series.min(), vmax=series.max())
43 |
44 | if cmap is None:
45 |         cmap = mpl.colormaps[mpl.rcParams["image.cmap"]]
46 |
47 | vertices = []
48 |
49 | for centerx, centery, value in zip(centerxs, centerys, series.values):
50 | center = np.array([centerx, centery])
51 | offsets.append(center)
52 | colors.append(cmap(norm(value)))
53 |
54 | vertices.append(
55 | [
56 | center + np.array([radius * 1.1, 0]),
57 | center + np.array([0, radius * 1.1]),
58 | center + np.array([-radius * 1.1, 0]),
59 | center + np.array([0, -radius * 1.1]),
60 | ]
61 | )
62 | vertices = np.array(vertices)
63 | collection = mpl.collections.PolyCollection(
64 | vertices,
65 | ec=None,
66 | lw=0,
67 | fc=colors,
68 | )
69 |
70 | ax.add_collection(collection)
71 |
72 | for x in [xlim[1]]:
73 | x2 = x
74 | x1 = x2 + (xlim[0] - x2) / 2
75 | y2 = 0
76 | y1 = x2 - x1
77 |
78 |         # style of the outer edges of the rotated matrix
79 |         color = "black"
80 |         lw = 0.8
81 |         zorder = 10
82 |         ax.plot(
83 |             [x1 + radius * 1.2, x2 + radius * 1.2],
84 |             [y1, y2],
85 |             zorder=zorder,
86 |             color=color,
87 |             lw=lw,
88 |         )
89 | 
90 |         x1, x2 = (xlim[1] - xlim[0]) / 2 - x2, (xlim[1] - xlim[0]) / 2 - x1
91 |         ax.plot(
92 |             [-x1 - radius * 1.2, -x2 - radius * 1.2],
93 |             [y1, y2],
94 |             zorder=zorder,
95 |             color=color,
96 |             lw=lw,
97 |         )
98 |     ax.axis("off")
99 | 
--------------------------------------------------------------------------------
/src/chromatinhd/plot/patch.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 |
4 | phases = pd.DataFrame({"phase": ["train", "validation"], "color": ["#888888", "tomato"]}).set_index("phase")
5 |
6 |
7 | colors = ["#0074D9", "#FF4136", "#FF851B", "#2ECC40", "#39CCCC", "#85144b"]
8 |
9 |
10 | def replace_patch(ax, patch, points=10, ha="center", va="center"):
11 | """
12 |     Replaces a patch, often an axis label, with a new rectangular axis whose size is given in figure points. Example usage:
13 |
14 |
15 | fig, ax = plt.subplots()
16 | ax.set_title("100%, (0.5,1-0.3,.3,.3)")
17 |     ax.plot([0, 2], [0, 2])
18 |
19 | for l in ax.get_xticklabels():
20 | ax1 = replace_patch(ax, l)
21 | ax1.plot([0, 1], [0, 1])
22 | """
23 | fig = ax.get_figure()
24 |
25 | fig.draw_without_rendering()
26 |
27 | w, h = fig.transFigure.inverted().transform([[1, 1]]).ravel() * points
28 | bbox = fig.transFigure.inverted().transform(patch.get_window_extent())
29 | dw = bbox[1, 0] - bbox[0, 0]
30 | dh = bbox[1, 1] - bbox[0, 1]
31 | x = bbox[0, 0] + dw / 2
32 | y = bbox[0, 1] + dh / 2
33 |
34 | if ha == "left":
35 | x += 0
36 | elif ha == "right":
37 | x -= w - dw / 2
38 | elif ha == "center":
39 | x -= w / 2
40 |
41 | if va == "bottom":
42 | y += 0
43 | elif va == "top":
44 | y -= h - dh / 2
45 | elif va == "center":
46 | y -= h / 2
47 |
48 | ax1 = fig.add_axes([x, y, w, h])
49 | # ax1.axis("off")
50 |
51 | return ax1
52 |
--------------------------------------------------------------------------------
/src/chromatinhd/plot/quasirandom.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 |
5 | def digits2number(digits, base=2, fractional=False):
6 | powers = np.arange(len(digits))
7 | out = (digits * base**powers).sum()
8 | if fractional:
9 | out = out / base ** (len(digits))
10 | return out
11 |
12 |
13 | def number2digits(n, base=2):
14 | if n == 0:
15 |         return np.array([])
16 | nDigits = np.ceil(np.log(n + 1) / np.log(base))
17 | powers = base ** (np.arange(nDigits + 1))
18 | out = np.diff(n % powers) / powers[:-1]
19 | return out
20 |
21 |
22 | def vanDerCorput(n, base=2, start=1):
23 | out = np.array([digits2number(number2digits(ii, base)[::-1], base, True) for ii in range(1, n + start)])
24 | return out
25 |
26 |
27 | def offset(
28 | y,
29 | maxLength=None,
30 | method="quasirandom",
31 | nbins=20,
32 | adjust=1,
33 | bw_method="scott",
34 | max_density=None,
35 | ):
36 | if len(y) == 1:
37 | return [0]
38 |
39 | if isinstance(y, pd.Series):
40 | y = y.values
41 |
42 | if nbins is None:
43 | if method in ["pseudorandom", "quasirandom"]:
44 | nbins = 2**10
45 | else:
46 | nbins = int(max(2, np.ceil(len(y) / 5)))
47 |
48 | if maxLength is None:
49 | subgroup_width = 1
50 | else:
51 | subgroup_width = np.sqrt(len(y) / maxLength)
52 |
53 | # from sklearn.neighbors import KernelDensity
54 | # kde = KernelDensity(kernel='gaussian').fit(y.reshape(-1, 1))
55 | # pointDensities = np.exp(np.array(kde.score_samples(y.reshape(-1, 1))))
56 |
57 | import scipy.stats
58 |
59 | kernel = scipy.stats.gaussian_kde(y, bw_method=bw_method)
60 | pointDensities = kernel(y)
61 |
62 | if max_density is None:
63 | max_density = pointDensities.max()
64 |
65 | pointDensities = pointDensities / max_density
66 |
67 | if method == "quasirandom":
68 | offset = np.array(vanDerCorput(len(y)))[np.argsort(y)]
69 | elif method == "pseudorandom":
70 | offset = np.random.uniform(size=len(y))
71 |
72 | out = (offset - 0.5) * 2 * pointDensities * subgroup_width * 0.5
73 |
74 | return out
75 |
76 |
77 | def offsetr(y, **kwargs):
78 | from rpy2.robjects import pandas2ri
79 | import rpy2.robjects as ro
80 | from rpy2.robjects.packages import importr
81 | from rpy2.robjects import numpy2ri
82 | from rpy2.robjects import default_converter
83 |
84 |     numpy2ri.activate()
85 |     vipor = importr("vipor")
86 | return vipor.offsetSingleGroup(y, method="quasi", **kwargs) * 0.5
87 |
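# A minimal usage sketch (not part of the file above): `offset` returns
# beeswarm-style horizontal offsets in [-0.5, 0.5] for a group of y values,
# scaled by the local kernel density estimate so that dense regions spread
# out the most.
values = np.random.normal(size=200)
offsets = offset(values, method="quasirandom")
assert len(offsets) == 200
assert np.abs(offsets).max() <= 0.5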
--------------------------------------------------------------------------------
/src/chromatinhd/plot/tickers.py:
--------------------------------------------------------------------------------
1 | import matplotlib.ticker as ticker
2 |
3 | import decimal
4 |
5 | import functools
6 | import math
7 |
8 |
9 | def count_zeros(value):
10 |     decimal_value = str(decimal.Decimal(value))
11 | if "." in decimal_value:
12 | count = len(decimal_value.split(".")[1])
13 | else:
14 | count = 0
15 | return count
16 |
17 |
18 | def round_significant(value, significant=2):
19 | """
20 | Round to a significant number of digits, including trailing zeros.
21 | For example round_significant(15400, 2) = 15000
22 | """
23 |
24 | if value == 0:
25 | return 0
26 |
27 | n = math.log10(abs(value))
28 | return round(value, significant - int(n) - 1)
29 |
30 |
31 | def format_distance(value, tick_pos=None, add_sign=True):
32 | abs_value = int(abs(value))
33 | if abs_value >= 1000000:
34 | # zeros = 0
35 | zeros = len(str(abs_value).rstrip("0")) - 1
36 | abs_value = abs_value / 1000000
37 | suffix = "mb"
38 | elif abs_value >= 1000:
39 | zeros = 0
40 | # zeros = len(str(abs_value).rstrip("0")) - 1
41 | abs_value = abs_value / 1000
42 | suffix = "kb"
43 | elif abs_value == 0:
44 | # zeros = 0
45 | return "TSS"
46 | else:
47 | zeros = 0
48 | suffix = "b"
49 |
50 | formatted_value = ("{abs_value:." + str(zeros) + "f}{suffix}").format(abs_value=abs_value, suffix=suffix)
51 |
52 | if not add_sign:
53 | return formatted_value
54 | return f"-{formatted_value}" if value < 0 else f"+{formatted_value}"
55 |
56 |
57 | gene_ticker = ticker.FuncFormatter(format_distance)
58 |
59 |
60 | def custom_formatter(value, tick_pos, base=1):
61 | if base != 1:
62 | value = value / base
63 |
64 | abs_value = abs(value)
65 | if abs_value >= 1000000:
66 | abs_value = abs_value / 1000000
67 | suffix = "mb"
68 | elif abs_value >= 1000:
69 | abs_value = abs_value / 1000
70 | suffix = "kb"
71 | elif abs_value == 0:
72 | return "0"
73 | else:
74 | suffix = "b"
75 |
76 | zeros = count_zeros(abs_value)
77 |
78 | formatted_value = ("{abs_value:." + str(zeros) + "f}{suffix}").format(abs_value=abs_value, suffix=suffix)
79 | return formatted_value
80 |
81 |
82 | distance_ticker = ticker.FuncFormatter(custom_formatter)
83 |
84 |
85 | def DistanceFormatter(base=1):
86 | return ticker.FuncFormatter(functools.partial(custom_formatter, base=base))
87 |
88 |
89 | def lighten_color(color, amount=0.5):
90 | """
91 | Lightens the given color by multiplying (1-luminosity) by the given amount.
92 | Input can be matplotlib color string, hex string, or RGB tuple.
93 |
94 | Examples:
95 | >> lighten_color('g', 0.3)
96 | >> lighten_color('#F034A3', 0.6)
97 | >> lighten_color((.3,.55,.1), 0.5)
98 | """
99 | import matplotlib.colors as mc
100 | import colorsys
101 |
102 | try:
103 | c = mc.cnames[color]
104 | except KeyError:
105 | c = color
106 | c = colorsys.rgb_to_hls(*mc.to_rgb(c))
107 | return colorsys.hls_to_rgb(c[0], 1 - amount * (1 - c[1]), c[2])
108 |
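# A minimal usage sketch (not part of the file above): format_distance renders
# genomic offsets relative to the TSS with an explicit sign and b/kb/mb suffixes.
assert format_distance(0) == "TSS"
assert format_distance(500) == "+500b"
assert format_distance(-12000) == "-12kb"
assert format_distance(1500000) == "+1.5mb"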
--------------------------------------------------------------------------------
/src/chromatinhd/plot/tracks/__init__.py:
--------------------------------------------------------------------------------
1 | from .tracks import TracksBroken
2 |
--------------------------------------------------------------------------------
/src/chromatinhd/utils/ecdf.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def ecdf(data):
5 | sorted_data = np.sort(data)
6 | y = np.arange(1, len(sorted_data) + 1) / len(sorted_data)
7 |
8 | sorted_data = np.vstack([sorted_data - 1e-5, sorted_data]).T.flatten()
9 | y = np.vstack([np.hstack([[0], y[:-1]]), y]).T.flatten()
10 | return sorted_data, y
11 |
12 |
13 | def weighted_ecdf(data, weights):
14 | sorting = np.argsort(data)
15 | sorted_data = data[sorting]
16 | y = np.cumsum(weights[sorting] / weights.sum())
17 |
18 | sorted_data = np.vstack([sorted_data - 1e-5, sorted_data]).T.flatten()
19 | y = np.vstack([np.hstack([[0], y[:-1]]), y]).T.flatten()
20 | return sorted_data, y
21 |
22 |
23 | def area_under_ecdf(list1):
24 | x1, y1 = ecdf(list1)
25 | area = np.trapz(y1, x1)
26 | return area
27 |
28 |
29 | def area_between_ecdfs(list1, list2):
30 | x1, y1 = ecdf(list1)
31 | x2, y2 = ecdf(list2)
32 |
33 | combined_x = np.concatenate([[-0.001], np.sort(np.unique(np.concatenate((x1, x2))))])
34 |
35 | y1_interp = np.interp(combined_x, x1, y1, left=0, right=1)
36 | y2_interp = np.interp(combined_x, x2, y2, left=0, right=1)
37 |
38 | ecdf_diff = y1_interp - y2_interp
39 | area = np.trapz(ecdf_diff, combined_x)
40 | return area
41 |
42 |
43 | def relative_area_between_ecdfs(list1, list2):
44 | x1, y1 = ecdf(list1)
45 | area = area_between_ecdfs(list1, list2) / np.trapz(y1, x1)
46 | return area
47 |
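# A minimal usage sketch (not part of the file above): ecdf duplicates each
# sorted value with a tiny offset so the curve can be drawn with vertical
# risers, running from 0 to 1.
values = np.array([1.0, 2.0, 2.0, 3.0])
xs, ys = ecdf(values)
assert ys[0] == 0.0 and ys[-1] == 1.0
assert len(xs) == 2 * len(values)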
--------------------------------------------------------------------------------
/src/chromatinhd/utils/interleave.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def interleave(x, repeats=np.array([1, 2, 4, 8, 16])):
6 | assert isinstance(repeats, np.ndarray)
7 | out_shape = int(x.shape[-1] // (1 / repeats).sum())
8 | out = torch.zeros((*x.shape[:-1], out_shape), dtype=x.dtype, device=x.device)
9 |     ks = out_shape // repeats
10 |     i = 0
11 |     for k, r in zip(ks, repeats):
12 | out += torch.repeat_interleave(x[..., i : i + k], r)
13 | i += k
14 | return out
15 |
16 |
17 | def deinterleave(y, repeats=np.array([1, 2, 4, 8, 16])):
18 | assert isinstance(repeats, np.ndarray)
19 | x = []
20 | for r in repeats[::-1]:
21 | x_ = y.reshape((*y.shape[:-1], -1, r)).mean(dim=-1)
22 | y = y - torch.repeat_interleave(x_, r)
23 | x.append(x_)
24 | return torch.cat(x[::-1], -1)
25 |
--------------------------------------------------------------------------------
/src/chromatinhd/utils/intervals.py:
--------------------------------------------------------------------------------
1 | def interval_contains_inclusive(x, y):
2 | """
3 | Determines whether the intervals in x are contained in any interval of y
4 | """
5 | contained = ~((y[:, 1] < x[:, 0][:, None]) | (y[:, 0] > x[:, 1][:, None]))
6 | return contained.any(1)
7 |
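# A minimal usage sketch (not part of the file above): x and y are (n, 2)
# arrays of [start, end] pairs; the result flags, for each interval in x,
# whether it touches any interval in y.
import numpy as np

x = np.array([[5, 10], [100, 110]])
y = np.array([[0, 6], [20, 30]])
assert (interval_contains_inclusive(x, y) == np.array([True, False])).all()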
--------------------------------------------------------------------------------
/src/chromatinhd/utils/numpy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def indices_to_indptr(x, n, dtype=np.int32):
5 | return np.pad(np.cumsum(np.bincount(x, minlength=n), 0, dtype=dtype), (1, 0))
6 |
7 |
8 | ind2ptr = indices_to_indptr
9 |
10 |
11 | def indptr_to_indices(x):
12 | n = len(x) - 1
13 | return np.repeat(np.arange(n), np.diff(x))
14 |
15 |
16 | ptr2ind = indptr_to_indices
17 |
18 |
19 | def indices_to_indptr_chunked(x, n, dtype=np.int32, batch_size=10_000):
20 | counts = np.zeros(n + 1, dtype=dtype)
21 | cur_value = 0
22 | for a, b in zip(
23 | np.arange(0, len(x), batch_size, dtype=int), np.arange(batch_size, len(x) + batch_size, batch_size, dtype=int)
24 | ):
25 | x_ = x[a:b]
26 | assert (x_ >= cur_value).all()
27 | bincount = np.bincount(x_ - cur_value)
28 | counts[(cur_value + 1) : (cur_value + len(bincount) + 1)] += bincount
29 | cur_value = x_[-1]
30 | indptr = np.cumsum(counts, dtype=dtype)
31 | return indptr
32 |
33 |
34 |
35 | def interpolate_1d(x: np.ndarray, xp: np.ndarray, fp: np.ndarray) -> np.ndarray:
36 | a = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1])
37 | b = fp[:-1] - (a * xp[:-1])
38 |
39 | indices = np.searchsorted(xp, x, side="left") - 1
40 | indices = np.clip(indices, 0, a.shape[0] - 1)
41 |
42 | slope = a[indices]
43 | intercept = b[indices]
44 | return x * slope + intercept
45 |
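# A minimal usage sketch (not part of the file above): indices_to_indptr turns
# sorted element indices into a CSR-style index pointer, and indptr_to_indices
# inverts it.
indices = np.array([0, 0, 1, 3])
indptr = indices_to_indptr(indices, 4)
assert (indptr == np.array([0, 2, 3, 3, 4])).all()
assert (indptr_to_indices(indptr) == indices).all()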
--------------------------------------------------------------------------------
/src/chromatinhd/utils/scanpy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scanpy as sc
3 | import pandas as pd
4 |
5 |
6 | # Define cluster score for all markers
7 | def evaluate_partition(anndata, marker_dict, gene_symbol_key=None, partition_key="louvain_r1"):
8 | # Inputs:
9 | # anndata - An AnnData object containing the data set and a partition
10 |     #    marker_dict - A dictionary with cell-type markers. The markers should be stored as anndata.var_names or
11 | # an anndata.var field with the key given by the gene_symbol_key input
12 | # gene_symbol_key - The key for the anndata.var field with gene IDs or names that correspond to the marker
13 | # genes
14 | # partition_key - The key for the anndata.obs field where the cluster IDs are stored. The default is
15 | # 'louvain_r1'
16 |
17 | # Test inputs
18 | if partition_key not in anndata.obs.columns.values:
19 |         raise KeyError(
20 |             "The partition key was not found in the passed AnnData object. Have you done the clustering? If so, please pass the cluster IDs with the AnnData object!"
21 |         )
22 |
23 | if (gene_symbol_key is not None) and (gene_symbol_key not in anndata.var.columns.values):
24 |         raise KeyError(
25 |             "The provided gene symbol key was not found in the passed AnnData object. Check that your cell type markers are given in a format that your anndata object knows!"
26 |         )
27 |
28 | if gene_symbol_key:
29 | gene_ids = anndata.var[gene_symbol_key]
30 | else:
31 | gene_ids = anndata.var_names
32 |
33 | clusters = np.unique(anndata.obs[partition_key])
34 | n_clust = len(clusters)
35 | n_groups = len(marker_dict)
36 |
37 | marker_res = np.zeros((n_groups, n_clust))
38 | z_scores = sc.pp.scale(anndata, copy=True)
39 |
40 | i = 0
41 | for group in marker_dict:
42 | # Find the corresponding columns and get their mean expression in the cluster
43 | j = 0
44 | for clust in clusters:
45 | cluster_cells = np.in1d(z_scores.obs[partition_key], clust)
46 | marker_genes = np.in1d(gene_ids, marker_dict[group])
47 | marker_res[i, j] = z_scores.X[np.ix_(cluster_cells, marker_genes)].mean()
48 | j += 1
49 | i += 1
50 |
51 | variances = np.nanvar(marker_res, axis=0)
52 | if np.all(np.isnan(variances)):
53 |         raise ValueError(
54 |             "No variances could be computed; check that your cell markers are in the data set and that the marker IDs correspond to your gene_symbol_key input or the var_names."
55 |         )
56 |
57 | marker_res_df = pd.DataFrame(marker_res, columns=clusters, index=marker_dict.keys())
58 |
59 |     # Return the mean marker z-scores per group and cluster
60 | return marker_res_df
61 |
--------------------------------------------------------------------------------
/src/chromatinhd/utils/testing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy
3 |
4 |
5 | def fdr(p_vals):
6 | from scipy.stats import rankdata
7 |
8 | ranked_p_values = rankdata(p_vals)
9 | fdr = p_vals * len(p_vals) / ranked_p_values
10 | fdr[fdr > 1] = 1
11 |
12 | return fdr
13 |
14 |
15 | def autocorrelation(x):
16 | """
17 | Calculate the autocorrelation of a sequence.
18 |
19 | Args:
20 | x (list): A sequence of numbers.
21 |
22 | Returns:
23 | float: Autocorrelation value.
24 | """
25 | n = len(x)
26 | x_bar = np.mean(x)
27 | numerator = np.sum([(x[i] - x_bar) * (x[i - 1] - x_bar) for i in range(1, n)])
28 | denominator = np.sum([(x[i] - x_bar) ** 2 for i in range(n)])
29 |
30 | return numerator / denominator
31 |
32 |
33 | # def repeated_kfold_corrected_t_test(performance_A, performance_B, k, num_repeats, alpha=0.05):
34 | # """
35 | # Perform the corrected t-test between two learning algorithms A and B for repeated K-fold cross-validation.
36 |
37 | # Args:
38 | # performance_A (list): A list of performance scores for algorithm A.
39 | # performance_B (list): A list of performance scores for algorithm B.
40 | # k (int): Number of folds in the cross-validation.
41 | # num_repeats (int): Number of times the K-fold cross-validation was repeated.
42 | # alpha (float): Significance level, default is 0.05.
43 |
44 | # Returns:
45 | # bool: True if there is a significant difference, False otherwise.
46 | # float: t-statistic value.
47 | # float: p-value.
48 | # """
49 |
50 | # if len(performance_A) != len(performance_B) or len(performance_A) != k * num_repeats:
51 | # raise ValueError("Performance scores for each algorithm
52 | # should have the same length and match k * num_repeats.")
53 |
54 | # n = k * num_repeats
55 |
56 | # d = [performance_A[i] - performance_B[i] for i in range(n)]
57 | # d_bar = np.mean(d)
58 | # s_d = np.std(d, ddof=1)
59 |
60 | # rho = autocorrelation(d)
61 | # effective_sample_size = n * (1 - rho) / (1 + rho)
62 |
63 | # t_statistic = d_bar / (s_d / np.sqrt(effective_sample_size))
64 | # degrees_of_freedom = effective_sample_size - 1
65 | # p_value = 2 * (1 - stats.t.cdf(abs(t_statistic), degrees_of_freedom))
66 |
67 | # reject_null_hypothesis = p_value < alpha
68 |
69 | # return reject_null_hypothesis, t_statistic, p_value
70 |
71 |
72 | def repeated_kfold_corrected_t_test(diff, r, k, n_train, n_test):
73 |     diff_mean = 1 / (k * r) * diff.sum()
74 |     variance = 1 / (k * r - 1) * ((diff - diff_mean) ** 2).sum()
75 |     t = diff_mean / np.sqrt((1 / (k * r) + n_test / n_train) * variance)
76 | 
77 |     p_value = 2 * (1 - scipy.stats.t.cdf(abs(t), r * k - 1))
78 |     return t, p_value
79 |
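# A minimal usage sketch (not part of the file above): fdr applies a
# Benjamini-Hochberg-style correction, so adjusted values are never smaller
# than the raw p-values and are capped at 1.
p_values = np.array([0.01, 0.04, 0.03, 0.2])
adjusted = fdr(p_values)
assert (adjusted >= p_values).all() and (adjusted <= 1).all()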
--------------------------------------------------------------------------------
/src/chromatinhd/utils/timing.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import time
3 |
4 |
5 | class catchtime(object):
6 | def __init__(self, dict, name):
7 | self.name = name
8 | self.dict = dict
9 |
10 | def __enter__(self):
11 | self.t = time.time()
12 | return self
13 |
14 | def __exit__(self, type, value, traceback):
15 | self.t = time.time() - self.t
16 | self.dict[self.name] += self.t
17 |
18 |
19 | class timer(object):
20 | def __init__(self):
21 | self.times = collections.defaultdict(float)
22 |
23 | def catch(self, name):
24 | return catchtime(self.times, name)
25 |
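# A minimal usage sketch (not part of the file above): a timer accumulates
# wall-clock seconds per named section across repeated context-manager entries.
t = timer()
for _ in range(3):
    with t.catch("work"):
        sum(range(10000))
assert "work" in t.times  # t.times["work"] holds the total elapsed seconds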
--------------------------------------------------------------------------------
/src/chromatinhd/utils/torch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def interpolate_1d(x: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor) -> torch.Tensor:
6 | a = (fp[..., 1:] - fp[..., :-1]) / (xp[..., 1:] - xp[..., :-1])
7 | b = fp[..., :-1] - (a.mul(xp[..., :-1]))
8 |
9 | indices = torch.searchsorted(xp.contiguous(), x.contiguous(), right=False) - 1
10 | indices = torch.clamp(indices, 0, a.shape[-1] - 1)
11 |
12 | slope = a.index_select(a.ndim - 1, indices)
13 | intercept = b.index_select(a.ndim - 1, indices)
14 | return x * slope + intercept
15 |
16 | def interpolate_0d(x: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor) -> torch.Tensor:
17 |     # zero-order (step) interpolation: return fp at the left-hand knot of each
18 |     # query point rather than evaluating the linear segment
19 | 
20 |     indices = torch.searchsorted(xp.contiguous(), x.contiguous(), right=False) - 1
21 |     indices = torch.clamp(indices, 0, fp.shape[-1] - 2)
22 | 
23 |     return fp.index_select(fp.ndim - 1, indices)
24 |
25 |
26 | def indices_to_indptr(x: torch.Tensor, n: int) -> torch.Tensor:
27 | return torch.nn.functional.pad(torch.cumsum(torch.bincount(x, minlength=n), 0), (1, 0))
28 |
29 |
30 | ind2ptr = indices_to_indptr
31 |
32 |
33 | def indptr_to_indices(x: torch.Tensor) -> torch.Tensor:
34 | return torch.repeat_interleave(torch.arange(len(x) - 1), torch.diff(x))
35 |
36 |
37 | ptr2ind = indptr_to_indices
38 |
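# A minimal usage sketch (not part of the file above): interpolate_1d is
# analogous to np.interp along the last dimension, evaluating the local
# piecewise-linear segment at each query point.
xp = torch.tensor([0.0, 1.0, 2.0])
fp = torch.tensor([0.0, 10.0, 20.0])
x = torch.tensor([0.5, 1.5])
assert torch.allclose(interpolate_1d(x, xp, fp), torch.tensor([5.0, 15.0]))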
--------------------------------------------------------------------------------
/tests/_test_sparse.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import chromatinhd as chd
3 | import numpy as np
4 |
5 |
6 | class TestSparse:
7 | def test_simple(self):
8 | path = "./hello"
9 | coords_pointed = {"a": pd.Index(["a", "b", "c"]), "b": pd.Index(["A", "B"])}
10 | coords_fixed = {"c": pd.Index(["1", "2", "3"]), "d": pd.Index(["100", "200"])}
11 | variables = {"x": ("a", "b", "c"), "y": ("a", "c", "d")}
12 |
13 | dataset = chd.sparse.SparseDataset2.create(path, variables, coords_pointed, coords_fixed)
14 |
15 | dataset["x"]["a", "B"] = np.array([1, 2, 3])
16 | dataset["x"]["c", "B"] = np.array([1, 0, 4])
17 | dataset["y"]["a"] = np.array([[1, 0], [2, 3], [1, 2]])
18 |
19 | dataset["x"]._read_data()[:]
20 |
21 | assert (dataset["x"]["a", "B"] == np.array([1.0, 2.0, 3.0])).all()
22 | assert (dataset["x"]["a", "A"] == np.array([0.0, 0.0, 0.0])).all()
23 | assert (dataset["x"]["b", "B"] == np.array([0.0, 0.0, 0.0])).all()
24 | assert (dataset["x"]["c", "B"] == np.array([1, 0, 4])).all()
25 | assert (dataset["y"]["a"] == np.array([[1, 0], [2, 3], [1, 2]])).all()
26 |
--------------------------------------------------------------------------------
/tests/biomart/conftest.py:
--------------------------------------------------------------------------------
1 | import chromatinhd as chd
2 | import pytest
3 |
4 | import pathlib
5 |
6 |
7 | @pytest.fixture(scope="module")
8 | def dataset_grch38():
9 | return chd.biomart.Dataset.from_genome("GRCh38")
10 |
--------------------------------------------------------------------------------
/tests/biomart/tss_test.py:
--------------------------------------------------------------------------------
1 | import chromatinhd as chd
2 |
3 |
4 | class TestGetCanonicalTranscripts:
5 | def test_simple(self, dataset_grch38):
6 | pass
7 | # transcripts = chd.biomart.tss.get_canonical_transcripts(biomart_dataset = dataset_grch38)
8 |
9 | # assert transcripts.shape[0] > 1000
10 |
11 |
12 | class TestGetExons:
13 | def test_simple(self, dataset_grch38):
14 | pass
15 | # exons = chd.biomart.tss.get_exons(biomart_dataset = dataset_grch38, chrom = "chr1", start = 1000, end = 2000)
16 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import pathlib
3 | import chromatinhd as chd
4 |
5 |
6 | @pytest.fixture(scope="session")
7 | def example_dataset_folder(tmp_path_factory):
8 | example_dataset_folder = tmp_path_factory.mktemp("example")
9 |
10 | import pkg_resources
11 | import shutil
12 |
13 | DATA_PATH = pathlib.Path(pkg_resources.resource_filename("chromatinhd", "data/examples/pbmc10ktiny/"))
14 |
15 | # copy all files from data path to dataset folder
16 | for file in DATA_PATH.iterdir():
17 | shutil.copy(file, example_dataset_folder / file.name)
18 | return example_dataset_folder
19 |
20 |
21 | @pytest.fixture(scope="session")
22 | def example_transcriptome(example_dataset_folder):
23 | import scanpy as sc
24 |
25 | adata = sc.read(example_dataset_folder / "transcriptome.h5ad")
26 | transcriptome = chd.data.Transcriptome.from_adata(adata, path=example_dataset_folder / "transcriptome")
27 | return transcriptome
28 |
29 |
30 | @pytest.fixture(scope="session")
31 | def example_clustering(example_dataset_folder, example_transcriptome):
32 | clustering = chd.data.Clustering.from_labels(
33 | example_transcriptome.adata.obs["celltype"],
34 | path=example_dataset_folder / "clustering",
35 | )
36 | return clustering
37 |
38 |
39 | @pytest.fixture(scope="session")
40 | def example_regions(example_dataset_folder, example_transcriptome):
41 | biomart_dataset = chd.biomart.Dataset.from_genome("GRCh38")
42 | canonical_transcripts = chd.biomart.get_canonical_transcripts(biomart_dataset, example_transcriptome.var.index)
43 | regions = chd.data.Regions.from_transcripts(
44 | canonical_transcripts,
45 | path=example_dataset_folder / "regions",
46 | window=[-10000, 10000],
47 | )
48 | return regions
49 |
50 |
51 | @pytest.fixture(scope="session")
52 | def example_fragments(example_dataset_folder, example_transcriptome, example_regions):
53 | fragments = chd.data.Fragments.from_fragments_tsv(
54 | example_dataset_folder / "fragments.tsv.gz",
55 | example_regions,
56 | obs=example_transcriptome.obs,
57 | path=example_dataset_folder / "fragments",
58 | )
59 | fragments.create_regionxcell_indptr()
60 | return fragments
61 |
62 |
63 | @pytest.fixture(scope="session")
64 | def example_folds(example_dataset_folder, example_fragments):
65 | folds = chd.data.folds.Folds(example_dataset_folder / "random_5fold")
66 | folds.sample_cells(example_fragments, n_folds=5, n_repeats=1)
67 | return folds
68 |
--------------------------------------------------------------------------------
/tests/loaders/test_fragment_helpers.py:
--------------------------------------------------------------------------------
1 | import chromatinhd.loaders.fragments_helpers
2 | import numpy as np
3 |
4 |
5 | def test_multiple_arange():
6 | a = np.array([10, 50, 60], dtype=np.int64)
7 | b = np.array([20, 52, 62], dtype=np.int64)
8 | ix = np.zeros(100, dtype=np.int64)
9 | local_cellxregion_ix = np.zeros(100, dtype=np.int64)
10 |
11 | n_fragments = chromatinhd.loaders.fragments_helpers.multiple_arange(a, b, ix, local_cellxregion_ix)
12 | ix.resize(n_fragments, refcheck=False)
13 | local_cellxregion_ix.resize(n_fragments, refcheck=False)
14 |
15 | assert ix.shape == (14,)
16 | assert local_cellxregion_ix.shape == (14,)
17 | assert local_cellxregion_ix.max() < 3
18 | assert np.all(ix == np.array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 50, 51, 60, 61], dtype=np.int64))
19 | assert np.all(local_cellxregion_ix == np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 2], dtype=np.int64))
20 |
--------------------------------------------------------------------------------
/tests/loaders/test_transcriptome_fragments.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import chromatinhd as chd
3 |
4 |
5 | class TestTranscriptomeFragments:
6 | def test_example(self, example_fragments, example_transcriptome):
7 | loader = chd.loaders.TranscriptomeFragments(
8 | fragments=example_fragments,
9 | transcriptome=example_transcriptome,
10 | cellxregion_batch_size=10000,
11 | )
12 |
13 | minibatch = chd.loaders.minibatches.Minibatch(cells_oi=np.arange(20), regions_oi=np.arange(5), phase="train")
14 | loader.load(minibatch)
15 |
--------------------------------------------------------------------------------
/tests/models/diff/loader/test_clustering_cuts.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import chromatinhd as chd
3 |
4 |
5 | class TestClusteringCuts:
6 | def test_example(self, example_fragments, example_clustering):
7 | loader = chd.models.diff.loader.ClusteringCuts(
8 | fragments=example_fragments,
9 | clustering=example_clustering,
10 | cellxregion_batch_size=10000,
11 | )
12 |
13 | minibatch = chd.models.diff.loader.Minibatch(cells_oi=np.arange(20), regions_oi=np.arange(5), phase="train")
14 | loader.load(minibatch)
15 |
--------------------------------------------------------------------------------
/tests/models/diff/model/_test_spline.py:
--------------------------------------------------------------------------------
1 | import chromatinhd.models.diff.model.spline
2 | import torch
3 | import numpy as np
4 |
5 |
6 | class TestDifferentialQuadraticSplineStack:
7 | def test_basic(self):
8 | transform = chromatinhd.models.diff.model.spline.DifferentialQuadraticSplineStack(nbins=(128,), n_regions=1)
9 |
10 | x = torch.linspace(0, 1, 100)
11 | genes_oi = torch.tensor([0])
12 | local_gene_ix = torch.zeros(len(x), dtype=torch.int)
13 |
14 | delta = torch.zeros((len(x), np.sum(transform.split_deltas)))
15 | delta[:, :30] = 0
16 |         output, logabsdet = transform.transform_forward(x, genes_oi, local_gene_ix, delta)
17 |
18 | assert np.isclose(np.trapz(torch.exp(logabsdet).detach().numpy(), x), 1, atol=5e-2)
19 |
20 | delta[:, :30] = 1
21 |         output, logabsdet = transform.transform_forward(x, genes_oi, local_gene_ix, delta)
22 |
23 | assert np.isclose(np.trapz(torch.exp(logabsdet).detach().numpy(), x), 1, atol=5e-2)
24 |
25 | delta[:, :] = torch.from_numpy(np.random.normal(size=(1, delta.shape[1])))
26 |         output, logabsdet = transform.transform_forward(x, genes_oi, local_gene_ix, delta)
27 |
28 | assert np.isclose(np.trapz(torch.exp(logabsdet).detach().numpy(), x), 1, atol=5e-2)
29 |
--------------------------------------------------------------------------------
/tests/models/diff/model/test_diff_model.py:
--------------------------------------------------------------------------------
1 | import chromatinhd as chd
2 | import time
3 |
4 |
5 | class TestDiff:
6 | def test_example_single_model(self, example_fragments, example_clustering, example_folds):
7 | fold = example_folds[0]
8 | model = chd.models.diff.model.binary.Model.create(
9 | fragments=example_fragments, clustering=example_clustering, fold=fold
10 | )
11 |
12 | start = time.time()
13 | model.train_model(n_epochs=10)
14 |
15 | delta = time.time() - start
16 |
17 | assert delta < 20
18 |
19 | def test_example_multiple_models(self, example_fragments, example_clustering, example_folds):
20 | model = chd.models.diff.model.binary.Models.create(
21 | fragments=example_fragments,
22 | clustering=example_clustering,
23 | folds=example_folds,
24 | )
25 |
26 | start = time.time()
27 | model.train_models(n_epochs=2, regions_oi=example_fragments.var.index[:2])
28 |
29 | delta = time.time() - start
30 |
31 | assert delta < 40
32 |
--------------------------------------------------------------------------------
/tests/models/miff/_test_zoom.py:
--------------------------------------------------------------------------------
1 | import chromatinhd.models.miff.model.zoom
2 | import torch
3 | import numpy as np
4 |
5 |
6 | class TestZoom:
7 | def test_simple(self):
8 | width = 16
9 | nbins = [4, 2, 2]
10 |
11 | totalnbins = np.cumprod(nbins)
12 | totalbinwidths = torch.tensor(width // totalnbins)
13 |
14 | positions = torch.arange(16)
15 |
16 |         def test_probability(unnormalized_heights_all, expected_inv_prob):
17 | totalnbins = np.cumprod(nbins)
18 | totalbinwidths = torch.tensor(width // totalnbins)
19 |
20 | unnormalized_heights_zooms = chromatinhd.models.miff.model.zoom.extract_unnormalized_heights(
21 | positions, totalbinwidths, unnormalized_heights_all
22 | )
23 | logprob = chromatinhd.models.miff.model.zoom.calculate_logprob(
24 | positions, nbins, width, unnormalized_heights_zooms
25 | )
26 |
27 |             # each position should be assigned probability 1 / expected_inv_prob
28 |             assert np.isclose(1 / np.exp(logprob), expected_inv_prob).all()
29 |
30 | ##
31 | unnormalized_heights_all = []
32 | cur_total_n = 1
33 | for n in nbins:
34 | cur_total_n *= n
35 | unnormalized_heights_all.append(torch.zeros(cur_total_n).reshape(-1, n))
36 | unnormalized_heights_all[1][2, 1] = np.log(2)
37 |         expected_inv_prob = np.array([16, 16, 16, 16, 16, 16, 16, 16, 24, 24, 12, 12, 16, 16, 16, 16])
38 |
39 |         test_probability(unnormalized_heights_all, expected_inv_prob)
40 |
41 | ##
42 | unnormalized_heights_all = []
43 | cur_total_n = 1
44 | for n in nbins:
45 | cur_total_n *= n
46 | unnormalized_heights_all.append(torch.zeros(cur_total_n).reshape(-1, n))
47 | unnormalized_heights_all[0][0, 1] = np.log(2)
48 |         expected_inv_prob = np.array([20, 20, 20, 20, 10, 10, 10, 10, 20, 20, 20, 20, 20, 20, 20, 20])
49 |
50 |         test_probability(unnormalized_heights_all, expected_inv_prob)
51 |
52 | ##
53 | unnormalized_heights_all = []
54 | cur_total_n = 1
55 | for n in nbins:
56 | cur_total_n *= n
57 | unnormalized_heights_all.append(torch.zeros(cur_total_n).reshape(-1, n))
58 | unnormalized_heights_all[2][2, 0] = np.log(2)
59 |         expected_inv_prob = np.array([16, 16, 16, 16, 12, 24, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16])
60 |
61 |         test_probability(unnormalized_heights_all, expected_inv_prob)
62 |
--------------------------------------------------------------------------------
/tests/models/pred/model/test_pred_model.py:
--------------------------------------------------------------------------------
1 | import chromatinhd as chd
2 | import time
3 |
4 |
5 | class TestPred:
6 | def test_example_single_model(self, example_fragments, example_transcriptome, example_folds):
7 | fold = example_folds[0]
8 | model = chd.models.pred.model.better.Model.create(
9 | fragments=example_fragments,
10 | transcriptome=example_transcriptome,
11 | region_oi=example_fragments.var.index[0],
12 | fold=fold,
13 | )
14 |
15 | start = time.time()
16 | model.train_model(fold=example_folds[0], n_epochs=10)
17 |
18 | delta = time.time() - start
19 |
20 | assert delta < 20
21 | result = model.get_prediction(cell_ixs=fold["cells_test"])
22 |
23 | assert result["predicted"].shape[0] == len(fold["cells_test"])
24 |
25 | def test_example_multiple_models(self, example_fragments, example_transcriptome, example_folds):
26 | models = chd.models.pred.model.better.Models.create(
27 | fragments=example_fragments, transcriptome=example_transcriptome
28 | )
29 |
30 | start = time.time()
31 | models.train_models(folds=example_folds, n_epochs=1, regions_oi=example_fragments.var.index[:2])
32 |
33 | delta = time.time() - start
34 |
35 | assert delta < 60
36 | result = models.get_prediction(region=example_fragments.var.index[0], fold_ix=0)
37 |
--------------------------------------------------------------------------------
/tests/utils/test_interleave.py:
--------------------------------------------------------------------------------
1 | import chromatinhd as chd
2 | import torch
3 | import numpy as np
4 |
5 |
6 | def test_interleave():
7 | y = chd.utils.interleave.interleave(torch.tensor([1, 2, 3, 4, 2, 3], dtype=float), repeats=np.array([1, 2]))
8 | assert torch.allclose(y, torch.tensor([3.0, 4.0, 6.0, 7.0], dtype=float))
9 |
10 | x = chd.utils.interleave.deinterleave(y, repeats=np.array([1, 2]))
11 | assert torch.allclose(x, torch.tensor([-0.5000, 0.5000, -0.5000, 0.5000, 3.5000, 6.5000], dtype=torch.float64))
12 |
13 | y2 = chd.utils.interleave.interleave(x, repeats=np.array([1, 2]))
14 | assert torch.allclose(y, y2)
15 |
16 |
17 | def test_interleave2():
18 | y = torch.tensor([1, 2, 3, 4, 2, 3], dtype=float)
19 | x = chd.utils.interleave.deinterleave(y, repeats=np.array([1, 2, 3]))
20 | y2 = chd.utils.interleave.interleave(x, repeats=np.array([1, 2, 3]))
21 | assert torch.allclose(y, y2)
22 |
--------------------------------------------------------------------------------