├── .github └── workflows │ └── ci.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE.md ├── README.md ├── docs ├── override │ └── main.html ├── source │ ├── .pages │ ├── CNAME │ ├── benchmark │ │ ├── diff │ │ │ ├── models.py │ │ │ ├── simulated.py │ │ │ └── timing.py │ │ └── index.md │ ├── index.md │ ├── javascripts │ │ └── reference.js │ ├── quickstart │ │ ├── 0_install.py │ │ ├── 1_data.py │ │ ├── 2_diff.py │ │ └── 3_pred.py │ ├── reference │ │ ├── .pages │ │ ├── data │ │ │ ├── clustering.md │ │ │ ├── folds.md │ │ │ ├── fragments.md │ │ │ ├── motifscan.md │ │ │ ├── regions.md │ │ │ └── transcriptome.md │ │ ├── index.md │ │ ├── loaders │ │ │ └── fragments.md │ │ └── models │ │ │ ├── diff │ │ │ ├── interpret.md │ │ │ ├── model.md │ │ │ └── plot.md │ │ │ └── pred │ │ │ ├── interpret.md │ │ │ ├── model.md │ │ │ └── plot.md │ ├── static │ │ ├── comparison.gif │ │ ├── eu.png │ │ ├── favicon.ai │ │ ├── favicon.png │ │ ├── logo.ai │ │ ├── logo.png │ │ └── models │ │ │ ├── diff │ │ │ ├── 1x │ │ │ │ └── logo.png │ │ │ └── logo.pdf │ │ │ ├── dime │ │ │ ├── logo.pdf │ │ │ └── logo.png │ │ │ ├── pred │ │ │ ├── 1x │ │ │ │ └── logo.png │ │ │ └── logo.pdf │ │ │ └── time │ │ │ ├── logo.pdf │ │ │ └── logo.png │ └── stylesheets │ │ ├── extra-reference.css │ │ └── extra.css └── tex │ └── overview.tex ├── mkdocs.yml ├── pyproject.toml ├── scripts ├── benchmark │ └── datasets │ │ └── pbmc10k │ │ ├── 1-download.py │ │ ├── 2-process_all.py │ │ ├── 3-process_large.py │ │ ├── 4-process_tiny.py │ │ └── 5-process_wide.py ├── cythonize.sh ├── dist.sh ├── docs.sh ├── followup │ └── miff │ │ ├── fit_negbinom.py │ │ ├── layers_linear_spline.py │ │ ├── quadratic.py │ │ └── truncated_normal.py ├── install.sh └── test.py ├── setup.py ├── src └── chromatinhd │ ├── __init__.py │ ├── biomart │ ├── __init__.py │ ├── cache.py │ ├── dataset.py │ ├── homology.py │ └── tss.py │ ├── data │ ├── __init__.py │ ├── associations │ │ ├── __init__.py │ │ ├── associations.py │ │ └── plot.py │ ├── clustering │ │ ├── __init__.py │ │ └── clustering.py │ ├── examples │ │ └── pbmc10ktiny │ │ │ ├── fragments.tsv.gz │ │ │ ├── fragments.tsv.gz.REMOVED.git-id │ │ │ ├── fragments.tsv.gz.tbi │ │ │ └── transcriptome.h5ad │ ├── folds │ │ ├── __init__.py │ │ └── folds.py │ ├── fragments │ │ ├── __init__.py │ │ ├── fragments.py │ │ └── view.py │ ├── genotype │ │ ├── __init__.py │ │ └── genotype.py │ ├── gradient │ │ ├── __init__.py │ │ └── gradient.py │ ├── motifscan │ │ ├── __init__.py │ │ ├── download.py │ │ ├── motifcount.py │ │ ├── motifscan.py │ │ ├── motiftrack.py │ │ ├── plot.py │ │ ├── plot_genome.py │ │ ├── scan_helpers.html │ │ ├── scan_helpers.pyx │ │ └── view.py │ ├── peakcounts │ │ ├── __init__.py │ │ ├── peakcounts.py │ │ └── plot.py │ ├── regions.py │ └── transcriptome │ │ ├── __init__.py │ │ ├── timetranscriptome.py │ │ └── transcriptome.py │ ├── device.py │ ├── embedding.py │ ├── flow │ ├── __init__.py │ ├── flow.py │ ├── flow_template.jinja2 │ ├── linked.py │ ├── objects.py │ ├── sparse.py │ ├── tensorstore.py │ └── tipyte.py │ ├── loaders │ ├── __init__.py │ ├── clustering.py │ ├── clustering_fragments.py │ ├── extraction │ │ ├── fragments.pyx │ │ └── motifs.pyx │ ├── fragments.py │ ├── fragments2.py │ ├── fragments_helpers.html │ ├── fragments_helpers.pyx │ ├── minibatches.py │ ├── peakcounts.py │ ├── pool.py │ ├── transcriptome.py │ ├── transcriptome_fragments.py │ ├── transcriptome_fragments2.py │ └── transcriptome_fragments_time.py │ ├── models │ ├── __init__.py │ ├── diff │ │ ├── __init__.py │ │ ├── interpret │ │ │ ├── 
__init__.py │ │ │ ├── differential.py │ │ │ ├── enrichment │ │ │ │ ├── __init__.py │ │ │ │ ├── enrichment.py │ │ │ │ └── group.py │ │ │ ├── performance.py │ │ │ ├── regionpositional.py │ │ │ └── slices.py │ │ ├── loader │ │ │ ├── __init__.py │ │ │ ├── clustering_cuts.py │ │ │ └── clustering_fragments.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── binary.py │ │ │ ├── cutnf.py │ │ │ ├── encoders.py │ │ │ ├── playground.py │ │ │ ├── spline.py │ │ │ ├── splines │ │ │ │ ├── __init__.py │ │ │ │ ├── cubic.py │ │ │ │ ├── linear.py │ │ │ │ └── quadratic.py │ │ │ └── truncated_normal.py │ │ ├── plot │ │ │ ├── __init__.py │ │ │ ├── differential.py │ │ │ └── differential_expression.py │ │ └── trainer │ │ │ ├── __init__.py │ │ │ └── trainer.py │ ├── miff │ │ ├── loader │ │ │ ├── __init__.py │ │ │ ├── clustering.py │ │ │ ├── combinations.py │ │ │ ├── minibatches.py │ │ │ ├── motifcount.py │ │ │ ├── motifs.py │ │ │ └── motifs_fragments.py │ │ └── model │ │ │ ├── __init__.py │ │ │ ├── clustering │ │ │ ├── __init__.py │ │ │ ├── distributions.py │ │ │ └── model.py │ │ │ ├── clustering2 │ │ │ ├── __init__.py │ │ │ ├── distributions.py │ │ │ └── model.py │ │ │ ├── clustering3 │ │ │ ├── __init__.py │ │ │ ├── count.py │ │ │ ├── model.py │ │ │ └── position.py │ │ │ ├── global_norm │ │ │ ├── __init__.py │ │ │ ├── distributions.py │ │ │ └── model.py │ │ │ ├── local_norm │ │ │ ├── __init__.py │ │ │ ├── distributions.py │ │ │ └── model.py │ │ │ ├── site_embedder.py │ │ │ └── zoom.py │ ├── model.py │ ├── pred │ │ ├── __init__.py │ │ ├── interpret │ │ │ ├── __init__.py │ │ │ ├── censorers.py │ │ │ ├── performance.py │ │ │ ├── regionmultiwindow.py │ │ │ ├── regionpairwindow.py │ │ │ ├── regionsizewindow.py │ │ │ └── size.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── better.py │ │ │ ├── encoders.py │ │ │ ├── loss.py │ │ │ ├── multilinear.py │ │ │ ├── multiscale.py │ │ │ ├── peakcounts.py │ │ │ ├── peakcounts_test.py │ │ │ └── shared.py │ │ ├── plot │ │ │ ├── __init__.py │ │ │ ├── copredictivity.py │ │ │ ├── effect.py │ │ │ └── predictivity.py │ │ └── trainer │ │ │ ├── __init__.py │ │ │ └── trainer.py │ └── pret │ │ ├── model │ │ ├── __init__.py │ │ ├── better.py │ │ ├── loss.py │ │ └── peakcounts.py │ │ └── trainer │ │ ├── __init__.py │ │ └── trainer.py │ ├── optim.py │ ├── plot │ ├── __init__.py │ ├── genome │ │ ├── __init__.py │ │ └── genes.py │ ├── matshow45.py │ ├── patch.py │ ├── quasirandom.py │ ├── tickers.py │ └── tracks │ │ ├── __init__.py │ │ └── tracks.py │ ├── scoring │ └── prediction │ │ └── filterers.py │ ├── simulation │ └── simulate.py │ ├── sparse.py │ ├── train.py │ └── utils │ ├── __init__.py │ ├── ecdf.py │ ├── interleave.py │ ├── intervals.py │ ├── numpy.py │ ├── scanpy.py │ ├── testing.py │ ├── timing.py │ └── torch.py └── tests ├── _test_sparse.py ├── biomart ├── conftest.py └── tss_test.py ├── conftest.py ├── data └── motifscan │ └── test_motifscan.py ├── loaders ├── test_fragment_helpers.py ├── test_fragments.py └── test_transcriptome_fragments.py ├── models ├── diff │ ├── loader │ │ └── test_clustering_cuts.py │ └── model │ │ ├── _test_spline.py │ │ └── test_diff_model.py ├── miff │ └── _test_zoom.py └── pred │ └── model │ └── test_pred_model.py └── utils └── test_interleave.py /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: test package 2 | concurrency: 3 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 4 | cancel-in-progress: true 5 | on: 6 | push: 7 | branches: 8 | - main 9 | - 
devel 10 | permissions: 11 | contents: write 12 | jobs: 13 | test: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | - uses: actions/setup-python@v5 18 | with: 19 | python-version: 3.9 20 | cache: 'pip' 21 | - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV 22 | - run: pip install torch==2.3.0 # because torch_scatter contains import torch in the setup.py, we have to install torch first 23 | - run: pip install torch-scatter -f https://data.pyg.org/whl/torch-2.3.0+cpu.html # it seems a wheel is necessary to install torch-scatter 24 | - name: regular install 25 | run: pip install .[full] 26 | - name: attempt import 27 | run: python -c "import chromatinhd" 28 | - name: test install 29 | run: pip install -e .[test] 30 | # - uses: jpetrucciani/ruff-check@main 31 | - name: test 32 | run: pytest 33 | # - run: mkdocs gh-deploy --force -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ./data/ 2 | output 3 | software 4 | *.c 5 | 6 | 7 | *.ipynb 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | cover/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | .pybuilder/ 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | # For a library or package, you might want to ignore these files since the code is 95 | # intended to run in multiple environments; otherwise, check them in: 96 | # .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # poetry 106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 107 | # This is especially recommended for binary packages to ensure reproducibility, and is more 108 | # commonly ignored for libraries. 
109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 110 | #poetry.lock 111 | 112 | # pdm 113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 114 | #pdm.lock 115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 116 | # in version control. 117 | # https://pdm.fming.dev/#use-with-ide 118 | .pdm.toml 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | 171 | # manually added 172 | *.pkl 173 | *.tsv 174 | *trace.png 175 | tmp/ 176 | software/ 177 | *.parquet 178 | 179 | 180 | .vscode/ 181 | 182 | # ignore example dataset 183 | docs/source/quickstart/example/* 184 | docs/source/quickstart/restats 185 | *.code-workspace 186 | 187 | 188 | *.zarr 189 | *.zarr/* 190 | *.hdf5 191 | restats 192 | docs/tex/overview.aux 193 | docs/tex/overview.fdb_latexmk 194 | docs/tex/overview.fls 195 | docs/tex/overview.pdf 196 | docs/tex/overview.synctex.gz 197 | 198 | output/* 199 | 200 | 201 | src/*.c -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # # Ruff version. 
4 | # rev: v0.0.282 5 | # hooks: 6 | # - id: ruff 7 | # args: [., --fix, --exit-non-zero-on-fix ] 8 | # exclude: '(^scripts|^docs)/.*' 9 | - repo: https://github.com/psf/black 10 | rev: 23.3.0 11 | hooks: 12 | - id: black 13 | language_version: python3 14 | exclude: '(^scripts|^docs)/.*' -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Wouter Saelens 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 | 4 | 5 | 6 | 7 | 8 |

9 | 10 | ChromatinHD analyzes single-cell ATAC+RNA data using the raw fragments as input, 11 | by automatically adapting the scale at which 12 | relevant chromatin changes occur, on a per-position, per-cell, and per-gene basis. 13 | This enables the identification of functional chromatin changes 14 | regardless of whether they occur in a narrow or broad region. 15 | 16 | As we show in [our paper](https://www.biorxiv.org/content/10.1101/2023.07.21.549899v1): 17 | - Compared to the typical approach (peak calling + statistical analysis), ChromatinHD models are better able to capture functional chromatin changes. This is because there are extensive functional accessibility changes both outside and within peaks ([Figure 3](https://www.biorxiv.org/content/10.1101/2023.07.21.549899v1)). 18 | - ChromatinHD models can capture long-range interactions by considering fragments co-occurring within the same cell ([Figure 4](https://www.biorxiv.org/content/10.1101/2023.07.21.549899v1)). 19 | - ChromatinHD models can also capture changes in fragment size that are related to gene expression changes, likely driven by dense direct and indirect binding of transcription factors ([Figure 5](https://www.biorxiv.org/content/10.1101/2023.07.21.549899v1)). 20 | 21 | [📜 Manuscript](https://www.biorxiv.org/content/10.1101/2023.07.21.549899v1) 22 | 23 | [❔ Documentation](https://chromatinhd.org) 24 | 25 | [▶️ Quick start](https://chromatinhd.org/quickstart/0_install) 26 | -------------------------------------------------------------------------------- /docs/override/main.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block outdated %} 4 | You're not viewing the latest version. 5 | 6 | Click here to go to latest. 
7 | 8 | {% endblock %} 9 | 10 | {% block footer %} 11 | {#- 12 | This file was automatically generated - do not edit 13 | -#} 14 | 77 | {% endblock %} -------------------------------------------------------------------------------- /docs/source/.pages: -------------------------------------------------------------------------------- 1 | nav: 2 | - quickstart 3 | - reference 4 | - benchmark 5 | -------------------------------------------------------------------------------- /docs/source/CNAME: -------------------------------------------------------------------------------- 1 | chromatinhd.org -------------------------------------------------------------------------------- /docs/source/benchmark/diff/timing.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # formats: ipynb,py:percent 5 | # text_representation: 6 | # extension: .py 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.14.7 10 | # kernelspec: 11 | # display_name: Python 3 (ipykernel) 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | # %% 17 | import polyptich as pp 18 | pp.setup_ipython() 19 | 20 | import numpy as np 21 | import pandas as pd 22 | 23 | import matplotlib.pyplot as plt 24 | import matplotlib as mpl 25 | 26 | import seaborn as sns 27 | 28 | sns.set_style("ticks") 29 | # %config InlineBackend.figure_format='retina' 30 | 31 | import tqdm.auto as tqdm 32 | import torch 33 | import os 34 | import time 35 | 36 | # %% 37 | import chromatinhd as chd 38 | 39 | chd.set_default_device("cuda:1") 40 | 41 | # %% 42 | dataset_folder_original = chd.get_output() / "datasets" / "pbmc10k" 43 | transcriptome_original = chd.data.Transcriptome(dataset_folder_original / "transcriptome") 44 | fragments_original = chd.data.Fragments(dataset_folder_original / "fragments" / "10k10k") 45 | 46 | # %% 47 | genes_oi = transcriptome_original.var.sort_values("dispersions_norm", ascending=False).head(30).index 48 | regions = fragments_original.regions.filter_genes(genes_oi) 49 | fragments = fragments_original.filter_genes(regions) 50 | fragments.create_cellxgene_indptr() 51 | transcriptome = transcriptome_original.filter_genes(regions.coordinates.index) 52 | 53 | # %% 54 | folds = chd.data.folds.Folds() 55 | folds.sample_cells(fragments, 5) 56 | 57 | # %% 58 | clustering = chd.data.Clustering.from_labels(transcriptome_original.obs["celltype"]) 59 | 60 | # %% 61 | fold = folds[0] 62 | 63 | # %% 64 | models = {} 65 | scores = [] 66 | 67 | # %% 68 | import logging 69 | 70 | logger = chd.models.diff.trainer.trainer.logger 71 | logger.setLevel(logging.DEBUG) 72 | logger.handlers = [] 73 | # logger.handlers = [logging.StreamHandler()] 74 | 75 | # %% 76 | devices = pd.DataFrame({"device": ["cuda:0", "cuda:1", "cpu"]}).set_index("device") 77 | for device in devices.index: 78 | if device != "cpu": 79 | devices.loc[device, "label"] = torch.cuda.get_device_properties(device).name 80 | else: 81 | devices.loc[device, "label"] = os.popen("lscpu").read().split("\n")[13].split(": ")[-1].lstrip() 82 | 83 | # %% 84 | scores = pd.DataFrame({"device": devices.index}).set_index("device") 85 | 86 | # %% 87 | for device in devices.index: 88 | start = time.time() 89 | model = chd.models.diff.model.cutnf.Model( 90 | fragments, 91 | clustering, 92 | ) 93 | model.train_model(fragments, clustering, fold, n_epochs=10, device=device) 94 | models[device] = model 95 | end = time.time() 96 | scores.loc[device, "train"] = end - start 97 | 98 | # %% 99 | for 
device in devices.index: 100 | genepositional = chd.models.diff.interpret.genepositional.GenePositional( 101 | path=chd.get_output() / "interpret" / "genepositional" 102 | ) 103 | 104 | start = time.time() 105 | genepositional.score(fragments, clustering, [models[device]], force=True, device=device) 106 | end = time.time() 107 | scores.loc[device, "inference"] = end - start 108 | 109 | # %% 110 | fig = pp.grid.Figure(pp.grid.Wrap(padding_width=0.1)) 111 | height = len(scores) * 0.2 112 | 113 | plotdata = scores.copy().loc[devices.index] 114 | 115 | panel, ax = fig.main.add(pp.grid.Panel((1, height))) 116 | ax.barh(plotdata.index, plotdata["train"]) 117 | ax.set_yticks(np.arange(len(devices))) 118 | ax.set_yticklabels(devices.label) 119 | ax.axvline(0, color="black", linestyle="--", lw=1) 120 | ax.set_title("Training") 121 | ax.set_xlabel("seconds") 122 | 123 | panel, ax = fig.main.add(pp.grid.Panel((1, height))) 124 | ax.barh(plotdata.index, plotdata["inference"]) 125 | ax.axvline(0, color="black", linestyle="--", lw=1) 126 | ax.set_title("Inference") 127 | ax.set_yticks([]) 128 | fig.plot() 129 | -------------------------------------------------------------------------------- /docs/source/benchmark/index.md: -------------------------------------------------------------------------------- 1 | # Benchmark 2 | 3 | Lightweight benchmarking of the various models in terms of function and scalability. More comprehensive benchmarking against the state of the art was performed in [our paper](https://www.biorxiv.org/content/10.1101/2023.07.21.549899v1). 4 | -------------------------------------------------------------------------------- /docs/source/index.md: -------------------------------------------------------------------------------- 1 | ChromatinHD analyzes single-cell ATAC+RNA data using the raw fragments as input, 2 | by automatically adapting the scale at which 3 | relevant chromatin changes occur, on a per-position, per-cell, and per-gene basis. 4 | This enables the identification of functional chromatin changes 5 | regardless of whether they occur in a narrow or broad region. 6 | 7 | As we show in [our paper](https://www.biorxiv.org/content/10.1101/2023.07.21.549899v1), ChromatinHD models are better able to capture functional chromatin changes than the typical approach, i.e., peak calling + statistical analysis. This is because there are extensive functional accessibility changes both outside and within peaks. 8 | 9 | ChromatinHD models can capture long-range interactions by considering fragments co-occurring within the same cell, as we highlight in [Figure 5 of our paper](https://www.biorxiv.org/content/10.1101/2023.07.21.549899v1). 10 | 11 | ChromatinHD models can also capture changes in fragment size that are related to gene expression changes, likely driven by dense direct and indirect binding of transcription factors, as we highlight in [Figure 6 of our paper](https://www.biorxiv.org/content/10.1101/2023.07.21.549899v1). 12 | 13 | Currently, the following models are supported (a minimal usage sketch follows the overview below): 14 | 15 | 58 | 59 | 
60 | 61 |
62 |
63 |

Pred

64 |
65 |
66 |

To learn where and how accessibility is predictive for gene expression

67 | ChromatinHD-pred 68 |
69 | 70 |
71 |
72 | 73 |
74 |
75 |

Diff

76 |
77 |
78 |

To understand the differences in accessibility between cell types/states

79 | ChromatinHD-diff 80 |
81 |
82 |
83 |
84 |
85 |

Time

86 |
87 |
88 |

To learn where and how accessibility is predictive over (pseudo)time

89 | ChromatinHD-time 90 |
91 |
92 |
93 |
94 |

Dime

95 |
96 |
97 |

To learn the differences in accessibility over (pseudo)time

98 | ChromatinHD-dime 99 |
100 |
101 |
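As a rough orientation, here is a minimal sketch of how such a model is set up and trained, mirroring the differential (ChromatinHD-diff) workflow used in this repository's benchmark script (`docs/source/benchmark/diff/timing.py`); the dataset path and the number of epochs are illustrative, and the quickstart walks through the full setup:

```python
import chromatinhd as chd

# Load a previously prepared dataset (the quickstart shows how to create these objects)
dataset_folder = chd.get_output() / "datasets" / "pbmc10k"  # illustrative path
transcriptome = chd.data.Transcriptome(dataset_folder / "transcriptome")
fragments = chd.data.Fragments(dataset_folder / "fragments" / "10k10k")
clustering = chd.data.Clustering.from_labels(transcriptome.obs["celltype"])

# Sample train/validation folds of cells
folds = chd.data.folds.Folds()
folds.sample_cells(fragments, 5)

# Train a ChromatinHD-diff model on the first fold
model = chd.models.diff.model.cutnf.Model(fragments, clustering)
model.train_model(fragments, clustering, folds[0], n_epochs=10)
```

The other model families follow the same pattern: construct the data objects once, then fit and interpret the model of interest.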
102 | -------------------------------------------------------------------------------- /docs/source/javascripts/reference.js: -------------------------------------------------------------------------------- 1 | document.querySelectorAll('.doc-attribute>:is(h1, h2, h3, h4, h5, h6)').forEach(function (el) { 2 | // add a tag to the attribute at the beginning 3 | var tag = document.createElement('span'); 4 | tag.className = 'doc-attribute-tag'; 5 | tag.textContent = "attribute" 6 | 7 | el.prepend(tag) 8 | 9 | }) 10 | 11 | document.querySelectorAll('.doc-function>:is(h1, h2, h3, h4, h5, h6)').forEach(function (el) { 12 | // add a tag to the function at the beginning 13 | var tag = document.createElement('span'); 14 | tag.className = 'doc-function-tag'; 15 | // check if it's a method (inside a class) or a standalone function 16 | 17 | if (el.closest('.doc-class')) { 18 | tag.textContent = "method" 19 | } else { 20 | tag.textContent = "function" 21 | } 22 | 23 | el.prepend(tag) 24 | 25 | }) 26 | 27 | 28 | document.querySelectorAll('.doc-class>:is(h1, h2, h3, h4, h5, h6)').forEach(function (el) { 29 | // add a tag to the class at the beginning 30 | var tag = document.createElement('span'); 31 | tag.className = 'doc-class-tag'; 32 | tag.textContent = "class" 33 | 34 | el.prepend(tag) 35 | 36 | }) -------------------------------------------------------------------------------- /docs/source/quickstart/0_install.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # text_representation: 5 | # extension: .py 6 | # format_name: percent 7 | # format_version: '1.3' 8 | # jupytext_version: 1.16.2 9 | # kernelspec: 10 | # display_name: chromatinhd 11 | # language: python 12 | # name: python3 13 | # --- 14 | 15 | # %% [markdown] 16 | # # Installation 17 | 18 | # %% tags=["hide_code"] 19 | # autoreload 20 | import IPython 21 | 22 | if IPython.get_ipython() is not None: 23 | IPython.get_ipython().run_line_magic("load_ext", "autoreload") 24 | IPython.get_ipython().run_line_magic("autoreload", "2") 25 | 26 | # %% [markdown] 27 | # 
28 | # # using pip
29 | # pip install chromatinhd
30 | #
31 | # # (soon) using conda
32 | # conda install -c bioconda chromatinhd
33 | #
34 | # # from github
35 | # pip install git+https://github.com/DeplanckeLab/ChromatinHD
36 | # 
37 | 38 | # %% [markdown] 39 | # 40 | # To use the GPU, ensure that a PyTorch version was installed with CUDA enabled: 41 | # 42 | 43 | # %% 44 | import torch 45 | 46 | print(torch.cuda.is_available()) # should print True 47 | print(torch.cuda.device_count()) # should be >= 1 48 | 49 | # %% [markdown] 50 | # 51 | # If not, follow the instructions at https://pytorch.org/get-started/locally/. You may have to re-install PyTorch. 52 | 53 | # %% tags=["hide_output"] 54 | import chromatinhd as chd 55 | 56 | # %% [markdown] 57 | # ## Frequently asked questions 58 | 59 | # %% [markdown] 60 | # 61 | -------------------------------------------------------------------------------- /docs/source/reference/.pages: -------------------------------------------------------------------------------- 1 | nav: 2 | - data 3 | - models 4 | - loaders 5 | -------------------------------------------------------------------------------- /docs/source/reference/data/clustering.md: -------------------------------------------------------------------------------- 1 | ::: chromatinhd.data.clustering.Clustering -------------------------------------------------------------------------------- /docs/source/reference/data/folds.md: -------------------------------------------------------------------------------- 1 | ::: chromatinhd.data.folds.Folds 2 | -------------------------------------------------------------------------------- /docs/source/reference/data/fragments.md: -------------------------------------------------------------------------------- 1 | ::: chromatinhd.data.fragments.Fragments 2 | ::: chromatinhd.data.fragments.FragmentsView 3 | -------------------------------------------------------------------------------- /docs/source/reference/data/motifscan.md: -------------------------------------------------------------------------------- 1 | ::: chromatinhd.data.motifscan.Motifscan 2 | -------------------------------------------------------------------------------- /docs/source/reference/data/regions.md: -------------------------------------------------------------------------------- 1 | ::: chromatinhd.data.Regions 2 | 3 | ::: chromatinhd.data.regions.select_tss_from_fragments 4 | 5 | -------------------------------------------------------------------------------- /docs/source/reference/data/transcriptome.md: -------------------------------------------------------------------------------- 1 | ::: chromatinhd.data.Transcriptome -------------------------------------------------------------------------------- /docs/source/reference/index.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/reference/index.md -------------------------------------------------------------------------------- /docs/source/reference/loaders/fragments.md: -------------------------------------------------------------------------------- 1 | ::: chromatinhd.loaders.fragments.Fragments 2 | ::: chromatinhd.loaders.fragments.Cuts 3 | 4 | ::: chromatinhd.loaders.fragments.FragmentsResult 5 | ::: chromatinhd.loaders.fragments.CutsResult 6 | -------------------------------------------------------------------------------- /docs/source/reference/models/diff/interpret.md: -------------------------------------------------------------------------------- 1 | ::: chromatinhd.models.diff.interpret.RegionPositional 2 | -------------------------------------------------------------------------------- 
/docs/source/reference/models/diff/model.md: -------------------------------------------------------------------------------- 1 | 2 | ## CutNF 3 | 4 | The basic differential model that only looks at cut sites individually, regardless of the fragment's and cell's other cut sites 5 | 6 | ::: chromatinhd.models.diff.model.cutnf.Model 7 | options: 8 | heading_level: 3 9 | 10 | ::: chromatinhd.models.diff.model.cutnf.Models 11 | options: 12 | heading_level: 3 13 | -------------------------------------------------------------------------------- /docs/source/reference/models/diff/plot.md: -------------------------------------------------------------------------------- 1 | 2 | ::: chromatinhd.models.diff.plot 3 | 4 | 5 | -------------------------------------------------------------------------------- /docs/source/reference/models/pred/interpret.md: -------------------------------------------------------------------------------- 1 | ::: chromatinhd.models.pred.interpret.RegionMultiWindow 2 | ::: chromatinhd.models.pred.interpret.RegionPairWindow 3 | 4 | -------------------------------------------------------------------------------- /docs/source/reference/models/pred/model.md: -------------------------------------------------------------------------------- 1 | 2 | ## Additive 3 | 4 | ::: chromatinhd.models.pred.model.multiscale.Model 5 | options: 6 | heading_level: 3 7 | 8 | ::: chromatinhd.models.pred.model.multiscale.Models 9 | options: 10 | heading_level: 3 11 | -------------------------------------------------------------------------------- /docs/source/reference/models/pred/plot.md: -------------------------------------------------------------------------------- 1 | 2 | ::: chromatinhd.models.pred.plot 3 | 4 | -------------------------------------------------------------------------------- /docs/source/static/comparison.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/comparison.gif -------------------------------------------------------------------------------- /docs/source/static/eu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/eu.png -------------------------------------------------------------------------------- /docs/source/static/favicon.ai: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/favicon.ai -------------------------------------------------------------------------------- /docs/source/static/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/favicon.png -------------------------------------------------------------------------------- /docs/source/static/logo.ai: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/logo.ai -------------------------------------------------------------------------------- /docs/source/static/logo.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/logo.png -------------------------------------------------------------------------------- /docs/source/static/models/diff/1x/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/models/diff/1x/logo.png -------------------------------------------------------------------------------- /docs/source/static/models/diff/logo.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/models/diff/logo.pdf -------------------------------------------------------------------------------- /docs/source/static/models/dime/logo.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/models/dime/logo.pdf -------------------------------------------------------------------------------- /docs/source/static/models/dime/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/models/dime/logo.png -------------------------------------------------------------------------------- /docs/source/static/models/pred/1x/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/models/pred/1x/logo.png -------------------------------------------------------------------------------- /docs/source/static/models/pred/logo.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/models/pred/logo.pdf -------------------------------------------------------------------------------- /docs/source/static/models/time/logo.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/models/time/logo.pdf -------------------------------------------------------------------------------- /docs/source/static/models/time/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/docs/source/static/models/time/logo.png -------------------------------------------------------------------------------- /docs/source/stylesheets/extra-reference.css: -------------------------------------------------------------------------------- 1 | .doc-labels { 2 | display: none; 3 | } 4 | 5 | .doc-class>.doc-heading { 6 | margin-top: 0.5em; 7 | border-top: 3px #33333333 solid; 8 | background-color: var(--md-code-bg-color); 9 | /* padding-top: 0.5em; */ 10 | } 11 | 12 | .doc-class { 13 | margin-bottom: 1em; 14 | border-bottom: 1px #33333333 solid; 15 | } 16 | 17 | 18 | .doc-function>.doc-heading { 19 | margin-top: 0.5em; 20 | border-left: 3px #33333333 solid; 21 | background-color: 
var(--md-code-bg-color); 22 | /* padding-top: 0.5em; */ 23 | } 24 | 25 | 26 | .doc-attribute>h3, 27 | .doc-function>h3 { 28 | font-size: 1.0rem; 29 | } 30 | 31 | .doc-contents>details.quote>summary { 32 | background-color: #FFFFFF; 33 | opacity: 0.6; 34 | } 35 | 36 | .doc-contents>details.quote { 37 | border-color: #33333333; 38 | } 39 | 40 | /* attributes, functions, ... */ 41 | 42 | .doc-attribute-tag { 43 | margin-left: 5px; 44 | font-size: 0.6em; 45 | font-style: italic; 46 | padding: 0.1em; 47 | color: green; 48 | background-color: rgb(220, 251, 220); 49 | } 50 | 51 | .doc-function-tag { 52 | margin-left: 5px; 53 | font-size: 0.6em; 54 | font-style: italic; 55 | padding: 0.1em; 56 | 57 | color: coral; 58 | background-color: rgb(251, 235, 220); 59 | } 60 | 61 | .doc-class-tag { 62 | margin-left: 5px; 63 | font-size: 0.6em; 64 | font-style: italic; 65 | padding: 0.1em; 66 | color: red; 67 | background-color: rgb(251, 220, 220); 68 | } 69 | 70 | /* bold name of attribute, function, ... */ 71 | .doc-attribute code:first-child { 72 | font-weight: bold; 73 | } 74 | 75 | .doc-function code:first-child { 76 | font-weight: bold; 77 | } 78 | 79 | .doc .md-typeset__table, 80 | .doc .md-typeset__table table { 81 | display: table !important; 82 | width: inherit; 83 | } 84 | 85 | .md-typeset table:not([class]) th { 86 | padding-top: 0; 87 | padding-bottom: 0; 88 | } -------------------------------------------------------------------------------- /docs/source/stylesheets/extra.css: -------------------------------------------------------------------------------- 1 | .jp-InputArea-prompt, 2 | .jp-OutputArea-prompt { 3 | display: none !important; 4 | } 5 | 6 | .jp-OutputArea-output { 7 | font-size: 1.2em; 8 | padding: 0.5em; 9 | border-left: 2px #33333333 solid !important; 10 | } 11 | 12 | .jupyter-wrapper, 13 | .jp-Cell:not(.jp-mod-noOutputs), 14 | .jp-Cell-outputWrapper { 15 | margin-top: 0; 16 | } 17 | 18 | .jp-Cell-outputWrapper img { 19 | box-shadow: rgba(99, 99, 99, 0.2) 0px 2px 8px 0px; 20 | } 21 | 22 | .jp-mod-noOutputs.celltag_hide_code { 23 | display: none; 24 | } 25 | 26 | .md-typeset h1, 27 | .md-typeset h2, 28 | .md-typeset h3, 29 | .md-typeset h4 { 30 | margin: 0; 31 | } 32 | 33 | p { 34 | margin-top: 0.1em; 35 | } 36 | 37 | .md-grid { 38 | max-width: 90rem; 39 | } 40 | 41 | clipboard-copy:hover { 42 | background-color: #33333333; 43 | } 44 | 45 | .jupyter-wrapper .jp-InputArea-editor { 46 | border-left: 2px green solid !important; 47 | } 48 | 49 | /* remove the cell toolbar with ugly tags */ 50 | .jupyter-wrapper .celltoolbar { 51 | display: none !important; 52 | } 53 | 54 | /* nav bar */ 55 | .md-tabs__item { 56 | height: inherit; 57 | padding-left: 0; 58 | padding-right: 0; 59 | } 60 | 61 | .md-tabs__item>a { 62 | padding-top: 0.8em; 63 | padding-bottom: 0.8em; 64 | padding-left: 0.8em; 65 | padding-right: 0.8em; 66 | margin-top: 0; 67 | transition: 0.1s; 68 | } 69 | 70 | .md-tabs__item>a:hover { 71 | background-color: var(--md-code-bg-color); 72 | } 73 | 74 | .md-tabs__link--active { 75 | background-color: var(--md-code-bg-color); 76 | } 77 | 78 | .md-nav--secondary { 79 | border-left: 2px solid #33333333 !important; 80 | } 81 | 82 | /* side bar */ 83 | .md-sidebar { 84 | width: 15rem; 85 | } 86 | 87 | .md-sidebar__inner { 88 | padding-right: inherit !important; 89 | } 90 | 91 | .md-typeset details > p{ 92 | margin-top:0.5em; 93 | } -------------------------------------------------------------------------------- /docs/tex/overview.tex: 
-------------------------------------------------------------------------------- 1 | \documentclass{article} 2 | \usepackage{graphicx} 3 | \usepackage{amsmath} 4 | 5 | \title{ChromatinHD} 6 | \author{Wouter Saelens} 7 | \date{August 2023} 8 | 9 | \begin{document} 10 | 11 | \maketitle 12 | 13 | \begin{equation} 14 | p(x_1,x_2) = p(x_1)p(x_2|x_1) 15 | \end{equation} 16 | 17 | \begin{equation} 18 | p(x_1) = p(x_{1,l},x_{1,r}) = p(x_{1,l})p(x_{1,r}|x_{1,l}) 19 | \end{equation} 20 | 21 | \begin{equation} 22 | p(x_{1,l}) = \prod_{i=0}^{d} p(x_{1,l}|x_{1}\in w_i) 23 | \end{equation} 24 | 25 | \begin{equation} 26 | p(x_{1,l}|x_{1}\in w_a) = f(motifs) 27 | \end{equation} 28 | 29 | \end{document} 30 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: High-definition modeling of chromatin + transcriptomics data 2 | docs_dir: docs/source 3 | theme: 4 | name: material 5 | palette: 6 | primary: white 7 | logo: static/logo.png 8 | favicon: static/favicon.png 9 | custom_dir: docs/override 10 | features: 11 | - navigation.tracking 12 | - navigation.tabs 13 | - navigation.footer 14 | - navigation.sections 15 | - toc.follow 16 | - toc.integrate 17 | 18 | site_url: "http://chromatinhd.org/" 19 | repo_url: https://github.com/DeplanckeLab/ChromatinHD 20 | 21 | plugins: 22 | - mkdocstrings: 23 | handlers: 24 | python: 25 | import: 26 | - https://docs.python-requests.org/en/master/objects.inv 27 | - https://installer.readthedocs.io/en/latest/objects.inv 28 | options: 29 | docstring_style: google 30 | heading_level: 2 31 | inheritance_diagram: True 32 | show_root_heading: True 33 | show_symbol_type_heading: True 34 | docstring_section_style: table 35 | - mike 36 | - search 37 | - mkdocs-jupyter: 38 | include: ["*.ipynb"] # Default: ["*.py", "*.ipynb"] 39 | remove_tag_config: 40 | remove_input_tags: 41 | - hide_code 42 | remove_all_outputs_tags: 43 | - hide_output 44 | - social 45 | - awesome-pages 46 | extra: 47 | version: 48 | provider: mike 49 | analytics: 50 | provider: google 51 | property: G-2EDCBPY71H 52 | social: 53 | - icon: fontawesome/brands/github 54 | link: https://github.com/DeplanckeLab/ChromatinHD 55 | markdown_extensions: 56 | - admonition 57 | - pymdownx.details 58 | - pymdownx.superfences 59 | - attr_list 60 | - md_in_html 61 | - pymdownx.details 62 | extra_css: 63 | - stylesheets/extra.css 64 | - stylesheets/extra-reference.css 65 | extra_javascript: 66 | - javascripts/reference.js 67 | copyright: Copyright © 2022 - 2024 Wouter Saelens -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=41", "wheel", "setuptools_scm[toml]>=6.2", "numpy", "Cython"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.setuptools-git-versioning] 6 | enabled = true 7 | 8 | [project] 9 | name = "chromatinhd" 10 | authors = [ 11 | {name = "Wouter Saelens", email = "wouter.saelens@gmail.com"}, 12 | ] 13 | description = "High-definition modeling of (single-cell) chromatin + transcriptomics data" 14 | requires-python = ">=3.8" 15 | keywords = ["bioinformatics", "chromatin accessibility", "transcriptomics"] 16 | classifiers = [ 17 | "Programming Language :: Python :: 3", 18 | ] 19 | dependencies = [ 20 | "torch", # --extra-index-url https://download.pytorch.org/whl/cu113 21 | "scanpy", 22 | "matplotlib", 23 
| "numpy", 24 | "seaborn", 25 | "diskcache", 26 | "appdirs", 27 | "xarray", 28 | "requests", 29 | "zarr", 30 | "pathlibfs", 31 | ] 32 | dynamic = ["version", "readme"] 33 | license = "MIT AND (Apache-2.0 OR BSD-2-Clause)" 34 | 35 | [project.urls] 36 | "Homepage" = "https://github.com/DeplanckeLab/ChromatinHD" 37 | "Bug Tracker" = "https://github.com/DeplanckeLab/ChromatinHD/issues" 38 | 39 | [tool.setuptools.dynamic] 40 | readme = {file = "README.md", content-type = "text/markdown"} 41 | 42 | [project.optional-dependencies] 43 | sam = [ 44 | "pysam", 45 | ] 46 | dev = [ 47 | "pre-commit", 48 | "pytest", 49 | "coverage", 50 | "polyptich", 51 | "black", 52 | "pylint", 53 | "jupytext", 54 | "mkdocs", 55 | "mkdocs-material", 56 | "mkdocstrings[python]", 57 | "mkdocs-jupyter", 58 | "mike", 59 | "cairosvg", # for mkdocs social 60 | "pillow", # for mkdocs social 61 | "mkdocs-awesome-pages-plugin", 62 | "setuptools_scm", 63 | "Cython", 64 | ] 65 | test = [ 66 | "pytest", 67 | "ruff", 68 | ] 69 | full = ["chromatinhd[sam]"] 70 | 71 | [tool.setuptools_scm] 72 | 73 | [tool.pytest.ini_options] 74 | filterwarnings = [ 75 | "ignore", 76 | ] 77 | 78 | [tool.pylint.'MESSAGES CONTROL'] 79 | max-line-length = 120 80 | disable = [ 81 | "too-many-arguments", 82 | "not-callable", 83 | "redefined-builtin", 84 | "redefined-outer-name", 85 | ] 86 | 87 | [tool.ruff] 88 | line-length = 90 89 | include = ['src/**/*.py'] 90 | exclude = ['scripts/*'] 91 | 92 | 93 | [tool.jupytext] 94 | formats = "ipynb,py:percent" -------------------------------------------------------------------------------- /scripts/benchmark/datasets/pbmc10k/2-process_all.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # formats: ipynb,py:percent 5 | # text_representation: 6 | # extension: .py 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.14.7 10 | # kernelspec: 11 | # display_name: Python 3 (ipykernel) 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | # %% [markdown] 17 | # # Preprocess 18 | 19 | # %% 20 | import polyptich as pp 21 | pp.setup_ipython() 22 | 23 | import numpy as np 24 | import pandas as pd 25 | 26 | import matplotlib.pyplot as plt 27 | import matplotlib as mpl 28 | 29 | import seaborn as sns 30 | 31 | sns.set_style("ticks") 32 | 33 | import torch 34 | 35 | import pickle 36 | 37 | import scanpy as sc 38 | 39 | import tqdm.auto as tqdm 40 | import io 41 | 42 | # %% 43 | import chromatinhd as chd 44 | 45 | # %% 46 | folder_root = chd.get_output() 47 | folder_data = folder_root / "data" 48 | 49 | dataset_name = "pbmc10k" 50 | genome = "GRCh38" 51 | 52 | folder_data_preproc = folder_data / dataset_name 53 | folder_data_preproc.mkdir(exist_ok=True, parents=True) 54 | 55 | # %% 56 | dataset_folder = chd.get_output() / "datasets" / "pbmc10k" 57 | dataset_folder.mkdir(exist_ok=True, parents=True) 58 | 59 | # %% 60 | adata = pickle.load((folder_data_preproc / "adata_annotated.pkl").open("rb")) 61 | 62 | # %% 63 | transcriptome = chd.data.transcriptome.Transcriptome.from_adata(adata, path=dataset_folder / "transcriptome") 64 | 65 | # %% 66 | import genomepy 67 | 68 | # genomepy.install_genome("GRCh38", genomes_dir="/data/genome/") 69 | 70 | sizes_file = "/data/genome/GRCh38/GRCh38.fa.sizes" 71 | 72 | # %% 73 | regions = chd.data.regions.Regions.from_chromosomes_file(sizes_file, path = dataset_folder / "regions" / "all") 74 | 75 | # %% 76 | fragments_file = folder_data_preproc / "fragments.tsv.gz" 77 | 78 | # %% 
79 | fragments = chd.data.Fragments.from_fragments_tsv(fragments_file, regions = regions, obs = transcriptome.obs, path = dataset_folder / "fragments" / "all") 80 | 81 | # %% 82 | fragments.create_regionxcell_indptr() 83 | -------------------------------------------------------------------------------- /scripts/benchmark/datasets/pbmc10k/3-process_large.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # formats: ipynb,py:percent 5 | # text_representation: 6 | # extension: .py 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.14.7 10 | # kernelspec: 11 | # display_name: Python 3 (ipykernel) 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | # %% [markdown] 17 | # # Preprocess 18 | 19 | # %% 20 | import polyptich as pp 21 | pp.setup_ipython() 22 | 23 | import numpy as np 24 | import pandas as pd 25 | 26 | import matplotlib.pyplot as plt 27 | import matplotlib as mpl 28 | 29 | import seaborn as sns 30 | 31 | sns.set_style("ticks") 32 | 33 | import torch 34 | 35 | import pickle 36 | 37 | import scanpy as sc 38 | 39 | import tqdm.auto as tqdm 40 | import io 41 | 42 | # %% 43 | import chromatinhd as chd 44 | 45 | # %% 46 | folder_root = chd.get_output() 47 | folder_data = folder_root / "data" 48 | 49 | dataset_name = "pbmc10k" 50 | genome = "GRCh38" 51 | 52 | folder_data_preproc = folder_data / dataset_name 53 | folder_data_preproc.mkdir(exist_ok=True, parents=True) 54 | 55 | # %% 56 | dataset_folder = chd.get_output() / "datasets" / "pbmc10k" 57 | dataset_folder.mkdir(exist_ok=True, parents=True) 58 | 59 | # %% 60 | adata = pickle.load((folder_data_preproc / "adata_annotated.pkl").open("rb")) 61 | 62 | # %% 63 | transcriptome = chd.data.transcriptome.Transcriptome.from_adata(adata, path=dataset_folder / "transcriptome") 64 | 65 | # %% 66 | selected_transcripts = pickle.load((folder_data_preproc / "selected_transcripts.pkl").open("rb")) 67 | regions = chd.data.regions.Regions.from_transcripts( 68 | selected_transcripts, [-10000, 10000], dataset_folder / "regions" / "10k10k" 69 | ) 70 | 71 | # %% 72 | fragments_file = folder_data_preproc / "fragments.tsv.gz" 73 | fragments = chd.data.fragments.Fragments(dataset_folder / "fragments" / "10k10k") 74 | fragments.regions = regions 75 | fragments = chd.data.fragments.Fragments.from_fragments_tsv( 76 | fragments_file=fragments_file, 77 | regions=regions, 78 | obs=transcriptome.obs, 79 | path=fragments.path, 80 | ) 81 | 82 | # %% 83 | -------------------------------------------------------------------------------- /scripts/benchmark/datasets/pbmc10k/4-process_tiny.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # formats: ipynb,py:percent 5 | # text_representation: 6 | # extension: .py 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.14.7 10 | # kernelspec: 11 | # display_name: Python 3 (ipykernel) 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | # %% [markdown] 17 | # # Preprocess 18 | 19 | # %% 20 | import polyptich as pp 21 | pp.setup_ipython() 22 | 23 | import numpy as np 24 | import pandas as pd 25 | 26 | import matplotlib.pyplot as plt 27 | import matplotlib as mpl 28 | 29 | import seaborn as sns 30 | 31 | sns.set_style("ticks") 32 | 33 | import torch 34 | 35 | import pickle 36 | 37 | import scanpy as sc 38 | 39 | import tqdm.auto as tqdm 40 | import io 41 | 42 | # %% 43 | import chromatinhd as 
chd 44 | 45 | # %% 46 | folder_root = chd.get_output() 47 | folder_data = folder_root / "data" 48 | 49 | dataset_name = "pbmc10k" 50 | main_url = "https://cf.10xgenomics.com/samples/cell-arc/2.0.0/pbmc_granulocyte_sorted_10k/pbmc_granulocyte_sorted_10k" 51 | genome = "GRCh38" 52 | 53 | folder_data_preproc = folder_data / dataset_name 54 | folder_data_preproc.mkdir(exist_ok=True, parents=True) 55 | 56 | # %% [markdown] 57 | # ## Tiny dataset 58 | 59 | # %% 60 | dataset_folder = chd.get_output() / "datasets" / "pbmc10ktiny" 61 | folder_dataset_publish = chd.get_git_root() / "src" / "chromatinhd" / "data" / "examples" / "pbmc10ktiny" 62 | folder_dataset_publish.mkdir(exist_ok=True, parents=True) 63 | 64 | # %% 65 | adata = pickle.load((folder_data_preproc / "adata_annotated.pkl").open("rb")) 66 | genes_oi = adata.var.sort_values("dispersions_norm", ascending=False).index[:50] 67 | adata = adata[:, genes_oi] 68 | adata_tiny = sc.AnnData(X=adata[:, genes_oi].layers["magic"], obs=adata.obs, var=adata.var.loc[genes_oi]) 69 | 70 | # %% 71 | transcriptome = chd.data.transcriptome.Transcriptome.from_adata(adata, path=dataset_folder / "transcriptome") 72 | adata_tiny.write(folder_dataset_publish / "transcriptome.h5ad", compression="gzip") 73 | 74 | # %% 75 | selected_transcripts = pickle.load((folder_data_preproc / "selected_transcripts.pkl").open("rb")) 76 | selected_transcripts = selected_transcripts.loc[genes_oi] 77 | regions = chd.data.regions.Regions.from_transcripts( 78 | selected_transcripts, [-10000, 10000], path=dataset_folder / "regions" / "10k10k" 79 | ) 80 | 81 | # %% 82 | import pysam 83 | 84 | fragments_tabix = pysam.TabixFile(str(folder_data_preproc / "fragments.tsv.gz")) 85 | coordinates = regions.coordinates 86 | 87 | fragments_new = [] 88 | for i, (gene, promoter_info) in tqdm.tqdm(enumerate(coordinates.iterrows()), total=coordinates.shape[0]): 89 | start = max(0, promoter_info["start"]) 90 | 91 | fragments_promoter = fragments_tabix.fetch(promoter_info["chrom"], start, promoter_info["end"]) 92 | fragments_new.extend(list(fragments_promoter)) 93 | 94 | fragments_new = pd.DataFrame( 95 | [x.split("\t") for x in fragments_new], columns=["chrom", "start", "end", "cell", "nreads"] 96 | ) 97 | fragments_new["start"] = fragments_new["start"].astype(int) 98 | fragments_new = fragments_new.sort_values(["chrom", "start", "cell"]) 99 | 100 | # %% 101 | fragments_new.to_csv(folder_dataset_publish / "fragments.tsv", sep="\t", index=False, header=False) 102 | pysam.tabix_compress(folder_dataset_publish / "fragments.tsv", folder_dataset_publish / "fragments.tsv.gz", force=True) 103 | 104 | # %% 105 | # !ls -lh {folder_dataset_publish} 106 | # !tabix -p bed {folder_dataset_publish / "fragments.tsv.gz"} 107 | folder_dataset_publish 108 | 109 | # %% 110 | transcriptome.var.index 111 | 112 | # %% 113 | coordinates.index 114 | -------------------------------------------------------------------------------- /scripts/cythonize.sh: -------------------------------------------------------------------------------- 1 | cython src/chromatinhd/loaders/fragments_helpers.pyx --embed -a 2 | cython src/chromatinhd/data/motifscan/scan_helpers.pyx --embed -a -------------------------------------------------------------------------------- /scripts/dist.sh: -------------------------------------------------------------------------------- 1 | python -m setuptools_git_versioning 2 | 3 | version="0.4.3" 4 | 5 | git add . 
6 | git commit -m "version v${version}" 7 | 8 | git tag -a v${version} -m "v${version}" 9 | 10 | python -m build 11 | 12 | # twine upload --repository testpypi dist/chromatinhd-${version}.tar.gz --verbose 13 | 14 | git push --tags 15 | 16 | gh release create v${version} -t "v${version}" -n "v${version}" dist/chromatinhd-${version}.tar.gz 17 | 18 | twine upload dist/chromatinhd-${version}.tar.gz --verbose 19 | 20 | python -m build --wheel -------------------------------------------------------------------------------- /scripts/docs.sh: -------------------------------------------------------------------------------- 1 | jupytext --sync docs/source/*/*.py 2 | 3 | 4 | mkdocs build -------------------------------------------------------------------------------- /scripts/followup/miff/fit_negbinom.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ## fit_nbinom 4 | # Copyright (C) 2014 Gokcen Eraslan 5 | # 6 | # This program is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with this program. If not, see <http://www.gnu.org/licenses/>. 18 | 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | from scipy.special import gammaln 23 | from scipy.special import psi 24 | from scipy.special import factorial 25 | from scipy.optimize import fmin_l_bfgs_b as optim 26 | 27 | import sys 28 | 29 | 30 | # X is a numpy array representing the data 31 | # initial params is a numpy array representing the initial values of 32 | # size and prob parameters 33 | def fit_nbinom(X, initial_params=None): 34 | infinitesimal = np.finfo(float).eps 35 | 36 | def log_likelihood(params, *args): 37 | r, p = params 38 | X = args[0] 39 | N = X.size 40 | 41 | # MLE estimate based on the formula on Wikipedia: 42 | # http://en.wikipedia.org/wiki/Negative_binomial_distribution#Maximum_likelihood_estimation 43 | result = ( 44 | np.sum(gammaln(X + r)) 45 | - np.sum(np.log(factorial(X))) 46 | - N * (gammaln(r)) 47 | + N * r * np.log(p) 48 | + np.sum(X * np.log(1 - (p if p < 1 else 1 - infinitesimal))) 49 | ) 50 | 51 | return -result 52 | 53 | def log_likelihood_deriv(params, *args): 54 | r, p = params 55 | X = args[0] 56 | N = X.size 57 | 58 | pderiv = (N * r) / p - np.sum(X) / (1 - (p if p < 1 else 1 - infinitesimal)) 59 | rderiv = np.sum(psi(X + r)) - N * psi(r) + N * np.log(p) 60 | 61 | return np.array([-rderiv, -pderiv]) 62 | 63 | if initial_params is None: 64 | # reasonable initial values (from fitdistr function in R) 65 | m = np.mean(X) 66 | v = np.var(X) 67 | size = (m**2) / (v - m) if v > m else 10 68 | 69 | # convert mu/size parameterization to prob/size 70 | p0 = size / ((size + m) if size + m != 0 else 1) 71 | r0 = size 72 | initial_params = np.array([r0, p0]) 73 | 74 | bounds = [(infinitesimal, None), (infinitesimal, 1)] 75 | optimres = optim( 76 | log_likelihood, 77 | x0=initial_params, 78 | # fprime=log_likelihood_deriv, 79 | args=(X,), 80 | approx_grad=1, 81 | bounds=bounds, 82 | ) 83 | 84 | params = optimres[0] 
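    # fmin_l_bfgs_b returns (x, f, d); optimres[0] is the fitted (size, prob) vector, under which the distribution's mean is size * (1 - prob) / prob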
85 | return {"size": params[0], "prob": params[1]} 86 | 87 | 88 | if __name__ == "__main__": 89 | if len(sys.argv) != 3: 90 | print("Usage: %s size_param prob_param" % sys.argv[0]) 91 | exit() 92 | 93 | testset = np.random.negative_binomial(float(sys.argv[1]), float(sys.argv[2]), 1000) 94 | print(fit_nbinom(testset)) 95 | -------------------------------------------------------------------------------- /scripts/followup/miff/layers_linear_spline.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # text_representation: 5 | # extension: .py 6 | # format_name: percent 7 | # format_version: '1.3' 8 | # jupytext_version: 1.15.1 9 | # kernelspec: 10 | # display_name: Python 3 11 | # language: python 12 | # name: python3 13 | # --- 14 | 15 | # %% 16 | import polyptich as pp 17 | pp.setup_ipython() 18 | 19 | import numpy as np 20 | import pandas as pd 21 | 22 | import matplotlib.pyplot as plt 23 | import matplotlib as mpl 24 | 25 | import seaborn as sns 26 | 27 | sns.set_style("ticks") 28 | # %config InlineBackend.figure_format='retina' 29 | 30 | import tqdm.auto as tqdm 31 | 32 | # %% 33 | import torch 34 | import chromatinhd as chd 35 | chd.set_default_device("cuda:0") 36 | 37 | import tempfile 38 | import pathlib 39 | import pickle 40 | 41 | # %% 42 | import quadratic 43 | 44 | # %% 45 | width = 1024 46 | positions = torch.tensor([100, 500, 1000]) 47 | nbins = [8, 8, 8] 48 | 49 | unnormalized_heights_bins = [ 50 | torch.tensor([[1] * 8]*3, dtype = torch.float), 51 | torch.tensor([[1] * 8]*3, dtype = torch.float), 52 | torch.tensor([[1] * 8]*3, dtype = torch.float), 53 | # torch.tensor([[1] * 2]*3, dtype = torch.float), 54 | ] 55 | unnormalized_heights_bins[1][1, 3] = 10 56 | 57 | # %% 58 | unnormalized_heights_all = [] 59 | cur_total_n = 1 60 | for n in nbins: 61 | cur_total_n *= n 62 | unnormalized_heights_all.append(torch.zeros(cur_total_n).reshape(-1, n)) 63 | unnormalized_heights_all[0][0, 3] = -1 64 | unnormalized_heights_all[1][2, 2:4] = 1 65 | unnormalized_heights_all[2][0, 2:4] = 1 66 | 67 | # %% 68 | import math 69 | 70 | def transform_linear_spline(positions, n, width, unnormalized_heights): 71 | binsize = width//n 72 | 73 | normalized_heights = torch.nn.functional.log_softmax(unnormalized_heights, -1) 74 | if normalized_heights.ndim == positions.ndim: 75 | normalized_heights = normalized_heights.unsqueeze(0) 76 | 77 | binixs = torch.div(positions, binsize, rounding_mode = "trunc") 78 | 79 | logprob = torch.gather(normalized_heights, 1, binixs.unsqueeze(1)).squeeze(1) 80 | 81 | positions = positions - binixs * binsize 82 | width = binsize 83 | 84 | return logprob, positions, width 85 | 86 | def calculate_logprob(positions, nbins, width, unnormalized_heights_bins): 87 | assert len(nbins) == len(unnormalized_heights_bins) 88 | 89 | curpositions = positions 90 | curwidth = width 91 | logprob = torch.zeros_like(positions, dtype = torch.float) 92 | for i, n in enumerate(nbins): 93 | assert (curwidth % n) == 0 94 | logprob_layer, curpositions, curwidth = transform_linear_spline(curpositions, n, curwidth, unnormalized_heights_bins[i]) 95 | logprob += logprob_layer 96 | logprob = logprob - math.log(curwidth) 97 | return logprob 98 | 99 | 100 | # %% 101 | x = torch.arange(width) 102 | 103 | # %% 104 | totalnbins = np.cumprod(nbins) 105 | totalbinwidths = torch.tensor(width//totalnbins) 106 | totalbinixs = torch.div(x[:, None], totalbinwidths, rounding_mode="floor") 107 | totalbinsectors = torch.div(totalbinixs, 
110 | # %%
111 | torch.tensor(width//np.array(nbins))
112 |
113 | # %%
114 | totalbinwidths
115 |
116 | # %%
117 | logprob = calculate_logprob(x, nbins, width, unnormalized_heights_bins)
118 | fig, ax = plt.subplots()
119 | ax.plot(x, torch.exp(logprob))
--------------------------------------------------------------------------------
/scripts/followup/miff/quadratic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F  # needed for F.pad in calculate_bin_left_cdf
3 |
4 | DEFAULT_MIN_BIN_HEIGHT = 1e-5
5 |
6 |
7 | def calculate_heights(unnormalized_heights, widths, min_bin_height=DEFAULT_MIN_BIN_HEIGHT, local=True):
8 |     unnorm_heights_exp = torch.exp(unnormalized_heights)
9 |
10 |     min_bin_height = 1e-10  # note: this overrides the min_bin_height argument
11 |
12 |     if local:
13 |         # per feature normalization
14 |         unnormalized_area = torch.sum(
15 |             ((unnorm_heights_exp[..., :-1] + unnorm_heights_exp[..., 1:]) / 2) * widths,
16 |             dim=-1,
17 |             keepdim=True,
18 |         )
19 |         heights = unnorm_heights_exp / unnormalized_area
20 |         heights = min_bin_height + (1 - min_bin_height) * heights
21 |     else:
22 |         # global normalization
23 |         unnormalized_area = torch.sum(
24 |             ((unnorm_heights_exp[..., :-1] + unnorm_heights_exp[..., 1:]) / 2) * widths,
25 |         )
26 |         heights = unnorm_heights_exp * unnorm_heights_exp.shape[-2] / unnormalized_area
27 |         heights = min_bin_height + (1 - min_bin_height) * heights
28 |
29 |     # to check
30 |     # normalized_area = torch.sum(
31 |     #     ((heights[..., :-1] + heights[..., 1:]) / 2) * widths,
32 |     #     dim=-1,
33 |     #     keepdim=True,
34 |     # )
35 |     # print(normalized_area.sum())
36 |
37 |     return heights
38 |
39 |
40 | def calculate_bin_left_cdf(heights, widths):
41 |     bin_left_cdf = torch.cumsum(((heights[..., :-1] + heights[..., 1:]) / 2) * widths, dim=-1)
42 |     bin_left_cdf[..., -1] = 1.0
43 |     bin_left_cdf = F.pad(bin_left_cdf, pad=(1, 0), mode="constant", value=0.0)
44 |     return bin_left_cdf
--------------------------------------------------------------------------------
/scripts/install.sh:
--------------------------------------------------------------------------------
1 | pip uninstall torch
2 | pip cache purge
3 |
4 | pip uninstall torch
5 | pip cache purge
6 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
7 | pip install torch-scatter -f https://data.pyg.org/whl/torch-2.0.1+cu118.html
--------------------------------------------------------------------------------
/scripts/test.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import requests
3 |
4 | # Send a GET request to the BioMart REST API
5 | response = requests.get('http://www.ensembl.org/biomart/martservice?type=registry')
6 |
7 | # The response is an XML string, so parse it using ElementTree
8 | import xml.etree.ElementTree as ET
9 | root = ET.fromstring(response.text)
10 |
11 | # Iterate over all the MartURLLocation elements (these represent the datasets)
12 | for dataset in root.iter('MartURLLocation'):
13 |     # The species is stored in the 'name' attribute
14 |     species = dataset.get('name')
15 |
16 |     # If the species is human, print the details of the dataset
17 |     if 'hsapiens' in species:
18 |         print(ET.tostring(dataset, encoding='utf8').decode('utf8'))
19 |
20 | # %%
21 | response.text
22 | # %%
23 | import pandas as pd
24 | import io
25 |
26 | def get_datasets():
"ENSEMBL_MART_ENSEMBL" 28 | baseurl = "http://www.ensembl.org/biomart/martservice?" 29 | url = "{baseurl}type=datasets&requestid=biomaRt&mart={mart}" 30 | response = requests.get(url.format(baseurl=baseurl, mart=mart)) 31 | root = pd.read_table(io.StringIO(response.text), sep="\t", header=None, names=["_", "dataset", "description", "version", "assembly", "__", "___", "____", "last_update"]) 32 | print(root) 33 | # %% 34 | get_biomart_datasets() 35 | # %% 36 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, Extension 2 | from Cython.Build import cythonize 3 | from Cython.Compiler import Options 4 | import numpy 5 | 6 | # These are optional 7 | Options.docstrings = True 8 | Options.annotate = False 9 | 10 | # Modules to be compiled and include_dirs when necessary 11 | extensions = [ 12 | # Extension( 13 | # "pyctmctree.inpyranoid_c", 14 | # ["src/pyctmctree/inpyranoid_c.pyx"], 15 | # ), 16 | Extension( 17 | "chromatinhd.loaders.fragments_helpers", 18 | ["src/chromatinhd/loaders/fragments_helpers.pyx"], 19 | include_dirs=[numpy.get_include()], 20 | py_limited_api=True, 21 | ), 22 | Extension( 23 | "chromatinhd.data.motifscan.scan_helpers", 24 | ["src/chromatinhd/data/motifscan/scan_helpers.pyx"], 25 | include_dirs=[numpy.get_include()], 26 | py_limited_api=True, 27 | ), 28 | ] 29 | 30 | 31 | # This is the function that is executed 32 | setup( 33 | name='chromatinhd', # Required 34 | 35 | # A list of compiler Directives is available at 36 | # https://cython.readthedocs.io/en/latest/src/userguide/source_files_and_compilation.html#compiler-directives 37 | 38 | # external to be compiled 39 | ext_modules = cythonize(extensions, compiler_directives={"language_level": 3, "profile": False}), 40 | ) -------------------------------------------------------------------------------- /src/chromatinhd/__init__.py: -------------------------------------------------------------------------------- 1 | from .device import get_default_device, set_default_device 2 | from .utils import get_git_root, get_output, get_code, save, Unpickler, load 3 | from . import sparse 4 | from . import utils 5 | from . import flow 6 | from . import plot 7 | from . import data 8 | from . import train 9 | from . import embedding 10 | from . import optim 11 | from . import biomart 12 | from . import models 13 | from polyptich import grid 14 | 15 | __all__ = [ 16 | "get_git_root", 17 | "get_output", 18 | "get_code", 19 | "save", 20 | "Unpickler", 21 | "load", 22 | "sparse", 23 | "utils", 24 | "flow", 25 | "data", 26 | "train", 27 | "embedding", 28 | "optim", 29 | "biomart", 30 | "models", 31 | "plot", 32 | "grid", 33 | ] 34 | -------------------------------------------------------------------------------- /src/chromatinhd/biomart/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset 2 | from .tss import get_canonical_transcripts, get_exons, get_transcripts, map_symbols, get_genes 3 | from . 
import tss 4 | from .homology import get_orthologs 5 | 6 | __all__ = ["Dataset", "get_canonical_transcripts", "get_exons", "get_transcripts", "tss"] 7 | -------------------------------------------------------------------------------- /src/chromatinhd/biomart/cache.py: -------------------------------------------------------------------------------- 1 | import diskcache 2 | import appdirs 3 | 4 | cache = diskcache.Cache(appdirs.user_cache_dir('biomart')) 5 | -------------------------------------------------------------------------------- /src/chromatinhd/biomart/homology.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset 2 | import numpy as np 3 | 4 | 5 | def get_orthologs(biomart_dataset: Dataset, gene_ids, organism="mmusculus"): 6 | """ 7 | Map ensembl gene ids to orthologs in another organism 8 | """ 9 | 10 | gene_ids_to_map = np.unique(gene_ids) 11 | mapping = biomart_dataset.get_batched( 12 | [ 13 | biomart_dataset.attribute("ensembl_gene_id"), 14 | biomart_dataset.attribute("external_gene_name"), 15 | biomart_dataset.attribute(f"{organism}_homolog_ensembl_gene"), 16 | biomart_dataset.attribute(f"{organism}_homolog_associated_gene_name"), 17 | ], 18 | filters=[ 19 | biomart_dataset.filter("ensembl_gene_id", value=gene_ids_to_map), 20 | ], 21 | ) 22 | mapping = mapping.groupby("ensembl_gene_id").first() 23 | 24 | return mapping[f"{organism}_homolog_ensembl_gene"].reindex(gene_ids).values 25 | -------------------------------------------------------------------------------- /src/chromatinhd/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import fragments 2 | from . import transcriptome 3 | from . import motifscan 4 | from .fragments import Fragments 5 | from .transcriptome import Transcriptome 6 | from .genotype import Genotype 7 | from .clustering import Clustering 8 | from .motifscan import Motifscan, Motiftrack 9 | from .regions import Regions 10 | from . import regions 11 | from . import folds 12 | 13 | __all__ = ["Fragments", "Transcriptome", "Regions", "folds", "motifscan", "regions"] 14 | -------------------------------------------------------------------------------- /src/chromatinhd/data/associations/__init__.py: -------------------------------------------------------------------------------- 1 | from .associations import Associations 2 | -------------------------------------------------------------------------------- /src/chromatinhd/data/associations/associations.py: -------------------------------------------------------------------------------- 1 | from chromatinhd.data.motifscan import Motifscan 2 | from chromatinhd.flow import StoredDataFrame 3 | 4 | from . 
import plot
5 |
6 |
7 | class Associations(Motifscan):
8 |     association = StoredDataFrame()
--------------------------------------------------------------------------------
/src/chromatinhd/data/clustering/__init__.py:
--------------------------------------------------------------------------------
1 | from .clustering import Clustering
--------------------------------------------------------------------------------
/src/chromatinhd/data/clustering/clustering.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import pandas as pd
3 | import numpy as np
4 |
5 | from chromatinhd.flow import Flow, Stored, StoredDataFrame, PathLike
6 | from chromatinhd.flow.tensorstore import Tensorstore
7 |
8 |
9 | class Clustering(Flow):
10 |     labels: pd.Series = Stored()
11 |     "Labels for each cell."
12 |
13 |     indices: np.array = Tensorstore(dtype="<i8")
14 |     "Cluster index for each cell."
15 |
16 |     var: pd.DataFrame = StoredDataFrame()
17 |     "Information about each cluster."
18 |
19 |     @classmethod
20 |     def from_labels(
21 |         cls,
22 |         labels: pd.Series,
23 |         var: pd.DataFrame = None,
24 |         path: PathLike = None,
25 |         overwrite: bool = False,
26 |     ) -> Clustering:
27 |         """
28 |         Create a Clustering object from a series of labels.
29 |
30 |         Parameters:
31 |             labels:
32 |                 Series of labels for each cell, with index corresponding to cell
33 |                 names.
34 |             path:
35 |                 Folder where the clustering information will be stored.
36 |             overwrite:
37 |                 Whether to overwrite the clustering information if it already
38 |                 exists.
39 |
40 |         Returns:
41 |             Clustering object.
42 |
43 |         """
44 |         self = cls(path, reset=overwrite)
45 |
46 |         if not overwrite and self.o.labels.exists(self):
47 |             return self
48 |
49 |         if not isinstance(labels, pd.Series):
50 |             labels = pd.Series(labels).astype("category")
51 |         elif not labels.dtype.name == "category":
52 |             labels = labels.astype("category")
53 |         self.labels = labels
54 |         self.indices = labels.cat.codes.values
55 |
56 |         if var is None:
57 |             var = (
58 |                 pd.DataFrame(
59 |                     {
60 |                         "cluster": labels.cat.categories,
61 |                         "label": labels.cat.categories,
62 |                     }
63 |                 )
64 |                 .set_index("cluster")
65 |                 .loc[labels.cat.categories]
66 |             )
67 |             var["n_cells"] = labels.value_counts()
68 |         else:
69 |             var = var.reindex(labels.cat.categories)
70 |             var["label"] = labels.cat.categories
71 |         self.var = var
72 |         return self
73 |
74 |     @property
75 |     def n_clusters(self):
76 |         return len(self.labels.cat.categories)
77 |
78 |     # temporarily link cluster_info to var
79 |     @property
80 |     def cluster_info(self):
81 |         return self.var
82 |
83 |     @cluster_info.setter
84 |     def cluster_info(self, cluster_info):
85 |         self.var = cluster_info
--------------------------------------------------------------------------------
/src/chromatinhd/data/examples/pbmc10ktiny/fragments.tsv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/src/chromatinhd/data/examples/pbmc10ktiny/fragments.tsv.gz
--------------------------------------------------------------------------------
/src/chromatinhd/data/examples/pbmc10ktiny/fragments.tsv.gz.REMOVED.git-id:
--------------------------------------------------------------------------------
1 | 9bbd83c6bcf3f8c42728621ecb8554170dc7caea
--------------------------------------------------------------------------------
/src/chromatinhd/data/examples/pbmc10ktiny/fragments.tsv.gz.tbi:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/src/chromatinhd/data/examples/pbmc10ktiny/fragments.tsv.gz.tbi
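A sketch of the intended use of the `Clustering.from_labels` constructor above (assuming an AnnData-style `adata.obs["celltype"]` categorical column; the names are illustrative):

import chromatinhd as chd

clustering = chd.data.Clustering.from_labels(
    adata.obs["celltype"],  # one label per cell
    path="./clustering",
)
print(clustering.n_clusters)
print(clustering.var)  # per-cluster label and n_cells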
--------------------------------------------------------------------------------
/src/chromatinhd/data/examples/pbmc10ktiny/transcriptome.h5ad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/3b089ba409b452a8275ed020555d8bcad50f3299/src/chromatinhd/data/examples/pbmc10ktiny/transcriptome.h5ad
--------------------------------------------------------------------------------
/src/chromatinhd/data/folds/__init__.py:
--------------------------------------------------------------------------------
1 | from .folds import Folds
--------------------------------------------------------------------------------
/src/chromatinhd/data/fragments/__init__.py:
--------------------------------------------------------------------------------
1 | from .fragments import Fragments
2 | from .view import FragmentsView
--------------------------------------------------------------------------------
/src/chromatinhd/data/genotype/__init__.py:
--------------------------------------------------------------------------------
1 | from .genotype import Genotype
--------------------------------------------------------------------------------
/src/chromatinhd/data/genotype/genotype.py:
--------------------------------------------------------------------------------
1 | from chromatinhd.flow import Flow, Stored
2 |
3 |
4 | class Genotype(Flow):
5 |     genotypes = Stored()
6 |     variants_info = Stored()
--------------------------------------------------------------------------------
/src/chromatinhd/data/gradient/__init__.py:
--------------------------------------------------------------------------------
1 | from .gradient import Gradient
--------------------------------------------------------------------------------
/src/chromatinhd/data/gradient/gradient.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import pandas as pd
3 | import numpy as np
4 |
5 | from chromatinhd.flow import Flow, Stored, StoredDataFrame, PathLike
6 | from chromatinhd.flow.tensorstore import Tensorstore
7 |
8 |
9 | class Gradient(Flow):
10 |     values: np.array = Tensorstore(dtype="<f4")
11 |     "Gradient values for each cell, as a [cells, gradients] array."
12 |
13 |     var: pd.DataFrame = StoredDataFrame()
14 |     "Information about each gradient."
15 |
16 |     @classmethod
17 |     def from_values(cls, values: pd.Series, path: PathLike = None) -> Gradient:
18 |         """
19 |         Create a Gradient object from a series of values.
20 |
21 |         Parameters:
22 |             values:
23 |                 Series of values for each cell, with index corresponding to cell
24 |                 names.
25 |             path:
26 |                 Folder where the gradient information will be stored.
27 |
28 |         Returns:
29 |             Gradient object.
30 |
31 |         """
32 |         gradient = cls(path)
33 |         if isinstance(values, pd.Series):
34 |             values = pd.DataFrame({"gradient": values}).astype(float)
35 |         elif not isinstance(values, pd.DataFrame):
36 |             values = pd.DataFrame(values).astype(float)
37 |         gradient.values = values.values
38 |         gradient.var = pd.DataFrame(
39 |             {
40 |                 "gradient": values.columns,
41 |                 "label": values.columns,
42 |             }
43 |         ).set_index("gradient")
44 |         return gradient
45 |
46 |     @property
47 |     def n_gradients(self):
48 |         return self.values.shape[1]
49 |
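A usage sketch for `Gradient.from_values` (the class attribute block and signature above were reconstructed from the docstring and method body, so treat the exact names as assumptions):

import pandas as pd
from chromatinhd.data.gradient import Gradient

# pseudotime: one float per cell, indexed by cell name (illustrative data)
pseudotime = pd.Series([0.1, 0.5, 0.9], index=["cell1", "cell2", "cell3"])
gradient = Gradient.from_values(pseudotime, path="./gradient")
print(gradient.n_gradients)  # 1: the series becomes a single "gradient" column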
--------------------------------------------------------------------------------
/src/chromatinhd/data/motifscan/__init__.py:
--------------------------------------------------------------------------------
1 | from .motifscan import Motifscan, read_pwms
2 | from .motiftrack import Motiftrack
3 | from . import plot
4 | from .view import MotifscanView
5 | from . import download
6 | from . import plot_genome
--------------------------------------------------------------------------------
/src/chromatinhd/data/motifscan/download.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import urllib.request
3 | import pathlib
4 | import chromatinhd.data.motifscan
5 | import json
6 |
7 |
8 | def get_hocomoco_11(path, organism="human", variant="core", overwrite=False):
9 |     """
10 |     Download HOCOMOCO v11 position weight matrices, cutoffs and annotations (human or mouse)
11 |
12 |     Parameters:
13 |         path:
14 |             the path to download to
15 |         organism:
16 |             the organism to download for, either "human" or "mouse"
17 |         variant:
18 |             the variant to download, either "full" or "core"
19 |         overwrite:
20 |             whether to overwrite existing files
21 |     """
22 |     path = pathlib.Path(path)
23 |     path.mkdir(parents=True, exist_ok=True)
24 |
25 |     if organism == "human":
26 |         organism = "HUMAN"
27 |     elif organism == "mouse":
28 |         organism = "MOUSE"
29 |     else:
30 |         raise ValueError(f"Unknown organism: {organism}")
31 |
32 |     # download cutoffs, pwms and annotations
33 |     if overwrite or (not (path / "pwm_cutoffs.txt").exists()):
34 |         urllib.request.urlretrieve(
35 |             f"https://hocomoco11.autosome.org/final_bundle/hocomoco11/{variant}/{organism}/mono/HOCOMOCOv11_{variant}_standard_thresholds_{organism}_mono.txt",
36 |             path / "pwm_cutoffs.txt",
37 |         )
38 |         urllib.request.urlretrieve(
39 |             f"https://hocomoco11.autosome.org/final_bundle/hocomoco11/{variant}/{organism}/mono/HOCOMOCOv11_{variant}_pwms_{organism}_mono.txt",
40 |             path / "pwms.txt",
41 |         )
42 |         urllib.request.urlretrieve(
43 |             f"https://hocomoco11.autosome.org/final_bundle/hocomoco11/{variant}/{organism}/mono/HOCOMOCOv11_{variant}_annotation_{organism}_mono.tsv",
44 |             path / "annot.txt",
45 |         )
46 |
47 |     pwms = chromatinhd.data.motifscan.read_pwms(path / "pwms.txt")
48 |
49 |     motifs = pd.DataFrame({"motif": pwms.keys()}).set_index("motif")
50 |     motif_cutoffs = pd.read_table(
51 |         path / "pwm_cutoffs.txt",
52 |         names=["motif", "cutoff_001", "cutoff_0005", "cutoff_0001"],
53 |         skiprows=1,
54 |     ).set_index("motif")
55 |     motifs = motifs.join(motif_cutoffs)
56 |     annot = (
57 |         pd.read_table(path / "annot.txt")
58 |         .rename(columns={"Model": "motif", "Transcription factor": "gene_label"})
59 |         .set_index("motif")
60 |     )
61 |     motifs = motifs.join(annot)
62 |
63 |     return pwms, motifs
64 |
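# A sketch of the intended call pattern (both downloaders return the PWM dict
# keyed by motif name plus an annotated motif table; nothing here scans the
# genome yet — that happens in Motifscan):
#
#     pwms, motifs = get_hocomoco_11("./motifs_v11", organism="human", variant="core")
#     pwms, motifs = get_hocomoco("./motifs_v12", organism="hs", variant="CORE")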
65 |
66 | def get_hocomoco(path, organism="hs", variant="CORE", overwrite=False):
67 |     """
68 |     Download HOCOMOCO v12 position weight matrices, thresholds and annotations (human or mouse)
69 |
70 |     Parameters:
71 |         path:
72 |             the path to download to
73 |         organism:
74 |             the organism to download for, either "hs" or "mm"
75 |         variant:
76 |             the variant to download, either "INVIVO" or "CORE"
77 |         overwrite:
78 |             whether to overwrite existing files
79 |     """
80 |     path = pathlib.Path(path)
81 |     path.mkdir(parents=True, exist_ok=True)
82 |
83 |     # download cutoffs, pwms and annotations
84 |     if overwrite or (not (path / "pwms.tar.gz").exists()):
85 |         urllib.request.urlretrieve(
86 |             f"https://hocomoco12.autosome.org/final_bundle/hocomoco12/H12{variant}/H12{variant}_annotation.jsonl",
87 |             path / "annotation.jsonl",
88 |         )
89 |         urllib.request.urlretrieve(
90 |             f"https://hocomoco12.autosome.org/final_bundle/hocomoco12/H12{variant}/H12{variant}_pwm.tar.gz",
91 |             path / "pwms.tar.gz",
92 |         )
93 |         urllib.request.urlretrieve(
94 |             f"https://hocomoco12.autosome.org/final_bundle/hocomoco12/H12{variant}/H12{variant}_thresholds.tar.gz",
95 |             path / "thresholds.tar.gz",
96 |         )
97 |
98 |     pwms = chromatinhd.data.motifscan.read_pwms(path / "pwms.tar.gz")
99 |     motifs = [json.loads(line) for line in open(path / "annotation.jsonl").readlines()]
100 |     motifs = pd.DataFrame(motifs).set_index("name")
101 |     motifs.index.name = "motif"
102 |
103 |     for thresh in motifs["standard_thresholds"].iloc[0].keys():
104 |         motifs["cutoff_" + thresh] = [thresholds[thresh] for _, thresholds in motifs["standard_thresholds"].items()]
105 |     for species in ["HUMAN", "MOUSE"]:
106 |         motifs[species + "_gene_symbol"] = [
107 |             masterlist_info["species"][species]["gene_symbol"] if species in masterlist_info["species"] else None
108 |             for _, masterlist_info in motifs["masterlist_info"].items()
109 |         ]
110 |     motifs["symbol"] = motifs["HUMAN_gene_symbol"] if organism == "hs" else motifs["MOUSE_gene_symbol"]
111 |
112 |     return pwms, motifs
--------------------------------------------------------------------------------
/src/chromatinhd/data/motifscan/motifcount.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 | from .motifscan import Motifscan
3 | from .view import MotifscanView
4 | import numpy as np
5 | import tqdm.auto as tqdm
6 |
7 |
8 | class BinnedMotifCounts:  # NOTE: superseded by the single-binsize redefinition below
9 |     """
10 |     Provides binned motif counts per fragment
11 |     """
12 |
13 |     def __init__(
14 |         self,
15 |         motifscan: Motifscan,
16 |         motif_binsizes: List[int],
17 |         fragment_binsizes: List[int],
18 |     ):
19 |         self.motifscan = motifscan
20 |         self.width = motifscan.regions.width
21 |         assert self.width is not None, "Regions must have a width"
22 |         self.motif_binsizes = motif_binsizes
23 |         self.fragment_binsizes = fragment_binsizes
24 |         self.n_motifs = motifscan.motifs.shape[0]
25 |
26 |         n_genes = len(motifscan.regions.coordinates)
27 |
28 |         self.fragment_widths = []
29 |         self.motifcount_sizes = []
30 |         self.motif_widths = []
31 |         self.fragmentprob_sizes = []
32 |         width = self.width
33 |         for motif_binsize, fragment_binsize in zip(self.motif_binsizes, self.fragment_binsizes):
34 |             assert fragment_binsize % motif_binsize == 0, (
35 |                 "fragment_binsize must be a multiple of motif_binsize",
36 |                 motif_binsize,
37 |                 fragment_binsize,
38 |             )
39 |             self.motifcount_sizes.append(width // motif_binsize)
40 |             self.motif_widths.append(self.width // motif_binsize)
41 |             self.fragmentprob_sizes.append(width // fragment_binsize)
42 |             self.fragment_widths.append(self.width // fragment_binsize)
43 |             width = fragment_binsize
44 |
45 |         precomputed = []
46 |         for motif_binsize, motif_width, fragment_binsize, fragment_width in zip(
47 |             self.motif_binsizes,
48 |             self.motif_widths,
49 |             self.fragment_binsizes,
50 |             [1, *self.fragment_widths[:-1]],
51 |         ):
52 |             precomputed.append(
53 |                 np.bincount(motifscan.positions // motif_binsize, minlength=(n_genes * motif_width)).reshape(
54 |                     (n_genes * fragment_width, -1)
55 |                 )
56 |             )
57 |         self.precomputed = precomputed
58 |
59 |
60 | class BinnedMotifCounts:
61 |     """
62 |     Provides binned motif counts per fragment
63 |     """
64 |
65 |     def __init__(
66 |         self,
67 |         motifscan: Union[Motifscan, MotifscanView],
68 |         binsize: int,
69 |     ):
70 |         self.motifscan = motifscan
71 |         self.width = motifscan.regions.width
72 |         assert self.width is not None, "Regions must have a width"
73 |         self.binsize = binsize
74 |         self.n_motifs = motifscan.motifs.shape[0]
75 |
76 |         n_regions = motifscan.regions.n_regions
77 |         motif_width = motifscan.regions.width // binsize
78 |
79 |         precomputed = np.zeros((n_regions, motif_width), dtype=np.int32)
80 |
81 |         for region_ix in tqdm.tqdm(range(n_regions)):
82 |             if isinstance(motifscan, Motifscan):
83 
| indptr_start = motifscan.region_indptr[region_ix] 84 | indptr_end = motifscan.region_indptr[region_ix + 1] 85 | coordinates = motifscan.coordinates[indptr_start:indptr_end] 86 | else: 87 | indptr_start, indptr_end = motifscan.region_indptr[region_ix] 88 | coordinates = ( 89 | motifscan.coordinates[indptr_start:indptr_end] - motifscan.regions.region_starts[region_ix] 90 | ) 91 | precomputed[region_ix] = np.bincount( 92 | coordinates // binsize, 93 | minlength=(motif_width), 94 | ) 95 | self.precomputed = precomputed 96 | 97 | # self.fragment_binsizes = fragment_binsizes 98 | # self.n_motifs = motifscan.motifs.shape[0] 99 | 100 | # self.fragment_widths = [] 101 | # self.fragmentprob_sizes = [] 102 | # width = self.region_width 103 | # for fragment_binsize in zip(self.fragment_binsizes): 104 | # self.fragmentprob_sizes.append(width // fragment_binsize) 105 | # self.fragment_widths.append(self.region_width // fragment_binsize) 106 | # width = fragment_binsize 107 | -------------------------------------------------------------------------------- /src/chromatinhd/data/motifscan/motiftrack.py: -------------------------------------------------------------------------------- 1 | from chromatinhd.flow import Flow, CompressedNumpyFloat64 2 | import pandas as pd 3 | 4 | 5 | class Motiftrack(Flow): 6 | scores = CompressedNumpyFloat64("scores") 7 | 8 | _motifs = None 9 | 10 | @property 11 | def motifs(self): 12 | if self._motifs is None: 13 | self._motifs = pd.read_pickle(self.path / "motifs.pkl") 14 | return self._motifs 15 | 16 | @motifs.setter 17 | def motifs(self, value): 18 | value.index.name = "gene" 19 | value.to_pickle(self.path / "motifs.pkl") 20 | self._motifs = value 21 | -------------------------------------------------------------------------------- /src/chromatinhd/data/motifscan/scan_helpers.pyx: -------------------------------------------------------------------------------- 1 | #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION 2 | #cython: language_level=3 3 | cimport cython 4 | import numpy as np 5 | cimport numpy as np 6 | np.import_array() 7 | 8 | INT8 = np.int8 9 | ctypedef np.int8_t INT8_t 10 | 11 | @cython.boundscheck(False) # turn off bounds-checking for entire function 12 | @cython.wraparound(False) # turn off negative index wrapping for entire function 13 | @cython.cdivision 14 | def seq_to_onehot(bytes s, INT8_t [:,::1] out_onehot): 15 | cdef char* cstr = s 16 | cdef int i, length = len(s) 17 | 18 | for i in range(length): 19 | if cstr[i] == b'A': 20 | out_onehot[i, 0] = 1 21 | elif cstr[i] == b'C': 22 | out_onehot[i, 1] = 1 23 | elif cstr[i] == b'G': 24 | out_onehot[i, 2] = 1 25 | elif cstr[i] == b'T': 26 | out_onehot[i, 3] = 1 27 | 28 | return out_onehot -------------------------------------------------------------------------------- /src/chromatinhd/data/peakcounts/__init__.py: -------------------------------------------------------------------------------- 1 | from .peakcounts import PeakCounts, Windows 2 | from . 
import plot 3 | -------------------------------------------------------------------------------- /src/chromatinhd/data/transcriptome/__init__.py: -------------------------------------------------------------------------------- 1 | from .transcriptome import Transcriptome -------------------------------------------------------------------------------- /src/chromatinhd/data/transcriptome/timetranscriptome.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pathlib 6 | from typing import Union 7 | 8 | from chromatinhd.flow import Flow, Stored, StoredDict, TSV 9 | from chromatinhd.flow.tensorstore import Tensorstore 10 | from chromatinhd import sparse 11 | from typing import TYPE_CHECKING 12 | 13 | if TYPE_CHECKING: 14 | import scanpy as sc 15 | 16 | 17 | class TimeTranscriptome(Flow): 18 | """ 19 | A transcriptome during pseudotime 20 | """ 21 | 22 | var: pd.DataFrame = TSV(index_name="gene") 23 | obs: pd.DataFrame = TSV(index_name="pseudocell") 24 | 25 | @classmethod 26 | def from_transcriptome( 27 | cls, 28 | gradient, 29 | transcriptome, 30 | path: Union[pathlib.Path, str] = None, 31 | overwrite=False, 32 | ): 33 | """ 34 | Create a TimeTranscriptome object from a Transcriptome object. 35 | """ 36 | 37 | raise NotImplementedError() 38 | 39 | layers = StoredDict(Tensorstore, kwargs=dict(dtype=" str: 38 | s = "{num_embeddings}, {embedding_dims}" 39 | if self.padding_idx is not None: 40 | s += ", padding_idx={padding_idx}" 41 | if self.max_norm is not None: 42 | s += ", max_norm={max_norm}" 43 | if self.norm_type != 2: 44 | s += ", norm_type={norm_type}" 45 | if self.scale_grad_by_freq is not False: 46 | s += ", scale_grad_by_freq={scale_grad_by_freq}" 47 | if self.sparse is not False: 48 | s += ", sparse=True" 49 | return s.format(**self.__dict__) 50 | 51 | def get_full_weight(self): 52 | return self.weight.view((self.weight.shape[0], *self.embedding_dims)) 53 | 54 | @property 55 | def data(self): 56 | """ 57 | The data of the parameter in dimensions [num_embeddings, *embedding_dims] 58 | """ 59 | return self.get_full_weight().data 60 | 61 | @data.setter 62 | def data(self, value): 63 | if value.ndim == 2: 64 | self.weight.data = value 65 | else: 66 | self.weight.data = value.reshape(self.weight.data.shape) 67 | 68 | @property 69 | def shape(self): 70 | """ 71 | The shape of the parameter, i.e. 
[num_embeddings, *embedding_dims] 72 | """ 73 | return (self.weight.shape[0], *self.embedding_dims) 74 | 75 | def __getitem__(self, k): 76 | return self.forward(k) 77 | 78 | @classmethod 79 | def from_pretrained(cls, pretrained: EmbeddingTensor): 80 | self = cls(pretrained.num_embeddings, pretrained.embedding_dims) 81 | self.data = pretrained.data 82 | self.weight.requires_grad = False 83 | 84 | return self 85 | 86 | 87 | class FeatureParameter(torch.nn.Module): 88 | _params = tuple() 89 | 90 | def __init__(self, num_embeddings, embedding_dims, constructor=torch.zeros, *args, **kwargs): 91 | super().__init__() 92 | params = [] 93 | for i in range(num_embeddings): 94 | params.append(torch.nn.Parameter(constructor(embedding_dims, *args, **kwargs))) 95 | self.register_parameter(str(i), params[-1]) 96 | self._params = tuple(params) 97 | 98 | def __getitem__(self, k): 99 | return self._params[k] 100 | 101 | def __setitem__(self, k, v): 102 | self._params = list(self._params) 103 | self._params[k] = v 104 | self._params = tuple(self._params) 105 | 106 | def __call__(self, ks): 107 | return torch.stack([self._params[k] for k in ks], 0) 108 | -------------------------------------------------------------------------------- /src/chromatinhd/flow/__init__.py: -------------------------------------------------------------------------------- 1 | from .objects import ( 2 | Linked, 3 | Stored, 4 | StoredDataFrame, 5 | StoredTensor, 6 | StoredNumpyInt64, 7 | CompressedNumpy, 8 | CompressedNumpyFloat64, 9 | CompressedNumpyInt64, 10 | TSV, 11 | StoredDict, 12 | DataArray, 13 | Dataset, 14 | ) 15 | from .flow import ( 16 | Flow, 17 | PathLike, 18 | ) 19 | from . import tensorstore 20 | from .linked import LinkedDict 21 | from .sparse import SparseDataset 22 | -------------------------------------------------------------------------------- /src/chromatinhd/flow/flow_template.jinja2: -------------------------------------------------------------------------------- 1 | 2 | 37 | 38 | 39 | 40 | {{ html }} 41 | -------------------------------------------------------------------------------- /src/chromatinhd/flow/linked.py: -------------------------------------------------------------------------------- 1 | from .objects import Obj, Instance, Linked 2 | 3 | 4 | class LinkedDict(Obj): 5 | def __init__(self, cls=Linked, name=None, kwargs=None): 6 | super().__init__(name=name) 7 | if kwargs is None: 8 | kwargs = {} 9 | self.kwargs = kwargs 10 | self.cls = cls 11 | 12 | def get_path(self, folder): 13 | return folder / self.name 14 | 15 | def __get__(self, obj, type=None): 16 | if obj is not None: 17 | name = "_" + str(self.name) 18 | if not hasattr(obj, name): 19 | x = LinkedDictInstance(self.name, self.get_path(obj.path), self.cls, obj, self.kwargs) 20 | setattr(obj, name, x) 21 | return getattr(obj, name) 22 | 23 | def __set__(self, obj, value, folder=None): 24 | instance = self.__get__(obj) 25 | instance.__set__(obj, value) 26 | 27 | def _repr_html_(self, obj=None): 28 | instance = self.__get__(obj) 29 | return instance._repr_html_() 30 | 31 | def exists(self, obj): 32 | return True 33 | 34 | 35 | class LinkedDictInstance(Instance): 36 | def __init__(self, name, path, cls, obj, kwargs): 37 | super().__init__(name=name, path=path, obj=obj) 38 | self.dict = {} 39 | self.cls = cls 40 | self.obj = obj 41 | self.name = name 42 | self.path = path 43 | self.kwargs = kwargs 44 | if not self.path.exists(): 45 | self.path.mkdir(parents=True) 46 | for file in self.path.iterdir(): 47 | key = file.name 48 | self.dict[key] = 
self.cls(name=key, **self.kwargs) 49 | 50 | def __getitem__(self, key): 51 | return self.dict[key].__get__(self) 52 | 53 | def __setitem__(self, key, value): 54 | if key not in self.dict: 55 | self.dict[key] = self.cls(name=key, **self.kwargs) 56 | self.dict[key].__set__(self, value) 57 | 58 | def __set__(self, obj, value, folder=None): 59 | for k, v in value.items(): 60 | self[k] = v 61 | 62 | def __contains__(self, key): 63 | return key in self.dict 64 | 65 | def __len__(self): 66 | return len(self.dict) 67 | 68 | def items(self): 69 | for k in self.dict: 70 | yield k, self[k] 71 | 72 | def keys(self): 73 | return self.dict.keys() 74 | 75 | def exists(self): 76 | return True 77 | 78 | def _repr_html_(self): 79 | # return f" {self.name} ({', '.join([getattr(self.obj, k)._repr_html_() for k in self.keys()])})" 80 | items = [] 81 | for i, k in zip(range(3), self.keys()): 82 | items.append(self.dict[k]._repr_html_(self)) 83 | if len(self.keys()) > 3: 84 | items.append("...") 85 | return f" {self.name} ({', '.join(items)})" 86 | -------------------------------------------------------------------------------- /src/chromatinhd/flow/sparse.py: -------------------------------------------------------------------------------- 1 | from .objects import Obj, format_size, get_size 2 | from chromatinhd.sparse import SparseDataset as SparseDataset_ 3 | 4 | 5 | class SparseDataset(Obj): 6 | def __init__(self, name=None): 7 | super().__init__(name=name) 8 | 9 | def get_path(self, folder): 10 | return folder / (self.name) 11 | 12 | def __get__(self, obj, type=None): 13 | if obj is not None: 14 | if self.name is None: 15 | raise ValueError(obj) 16 | name = "_" + str(self.name) 17 | if not hasattr(obj, name): 18 | path = self.get_path(obj.path) 19 | if not path.exists(): 20 | raise FileNotFoundError(f"File {path} does not exist") 21 | setattr(obj, name, SparseDataset_.open(self.get_path(obj.path))) 22 | return getattr(obj, name) 23 | 24 | def __set__(self, obj, value): 25 | name = "_" + str(self.name) 26 | setattr(obj, name, value) 27 | 28 | def exists(self, obj): 29 | return self.get_path(obj.path).exists() 30 | 31 | def _repr_html_(self, obj): 32 | self.__get__(obj) 33 | if not str(self.get_path(obj.path)).startswith("memory"): 34 | size = format_size(get_size(self.get_path(obj.path))) 35 | else: 36 | size = "" 37 | return f" {self.name} {size}" 38 | -------------------------------------------------------------------------------- /src/chromatinhd/loaders/__init__.py: -------------------------------------------------------------------------------- 1 | from .pool import LoaderPool 2 | from . import minibatches 3 | 4 | from .fragments import Fragments, FragmentsRegional, Cuts 5 | from .transcriptome import Transcriptome 6 | from .transcriptome_fragments import TranscriptomeFragments 7 | from .clustering_fragments import ClusteringCuts 8 | -------------------------------------------------------------------------------- /src/chromatinhd/loaders/clustering.py: -------------------------------------------------------------------------------- 1 | import chromatinhd.data.clustering 2 | import dataclasses 3 | import torch 4 | 5 | 6 | @dataclasses.dataclass 7 | class Result: 8 | # onehot: torch.Tensor 9 | indices: torch.Tensor 10 | 11 | def to(self, device): 12 | self.indices = self.indices.to(device) 13 | # self.onehot = self.onehot.to(device) 14 | return self 15 | 16 | 17 | class Clustering: 18 | """ 19 | Provides clustering data for a minibatch. 
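    Cluster labels are one-hot encoded once at construction; for each
    minibatch, `load` returns the integer cluster index of every requested cell.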
20 | """ 21 | 22 | def __init__( 23 | self, 24 | clustering: chromatinhd.data.clustering.Clustering, 25 | ): 26 | assert (clustering.labels.cat.categories == clustering.cluster_info.index).all(), ( 27 | clustering.labels.cat.categories, 28 | clustering.cluster_info.index, 29 | ) 30 | self.onehot = torch.nn.functional.one_hot( 31 | torch.from_numpy(clustering.labels.cat.codes.values.copy()).to(torch.int64), 32 | clustering.n_clusters, 33 | ).to(torch.float) 34 | 35 | def load(self, minibatch): 36 | # onehot = self.onehot[minibatch.cells_oi, :] 37 | indices = torch.argmax(self.onehot[minibatch.cells_oi, :], dim=1) 38 | return Result(indices=indices) 39 | -------------------------------------------------------------------------------- /src/chromatinhd/loaders/clustering_fragments.py: -------------------------------------------------------------------------------- 1 | import chromatinhd.data.clustering 2 | import dataclasses 3 | 4 | from .fragments import Cuts, CutsRegional, CutsResult 5 | from chromatinhd.loaders.minibatches import Minibatch 6 | from .clustering import Clustering 7 | 8 | 9 | @dataclasses.dataclass 10 | class Result: 11 | clustering: Clustering 12 | cuts: CutsResult 13 | minibatch: Minibatch 14 | 15 | def to(self, device): 16 | self.clustering.to(device) 17 | self.cuts.to(device) 18 | self.minibatch.to(device) 19 | return self 20 | 21 | 22 | class ClusteringCuts: 23 | def __init__( 24 | self, 25 | fragments: chromatinhd.data.fragments.Fragments, 26 | clustering: chromatinhd.data.clustering.Clustering, 27 | cellxregion_batch_size: int, 28 | layer: str = None, 29 | region_oi=None, 30 | ): 31 | # ensure that clustering and fragments have the same obs 32 | # if not all(clustering.obs.index == fragments.obs.index): 33 | # raise ValueError("Clustering and fragments should have the same obs index. 
") 34 | 35 | if region_oi is None: 36 | self.cuts = Cuts(fragments, cellxregion_batch_size=cellxregion_batch_size) 37 | self.clustering = Clustering(clustering) 38 | else: 39 | self.cuts = CutsRegional(fragments, cellxregion_batch_size=cellxregion_batch_size, region_oi=region_oi) 40 | self.clustering = Clustering(clustering) 41 | 42 | def load(self, minibatch): 43 | return Result( 44 | clustering=self.clustering.load(minibatch), 45 | cuts=self.cuts.load(minibatch), 46 | minibatch=minibatch, 47 | ) 48 | -------------------------------------------------------------------------------- /src/chromatinhd/loaders/fragments_helpers.pyx: -------------------------------------------------------------------------------- 1 | #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION 2 | #cython: language_level=3 3 | cimport cython 4 | import numpy as np 5 | cimport numpy as np 6 | np.import_array() 7 | 8 | INT64 = np.int64 9 | ctypedef np.int64_t INT64_t 10 | FLOAT64 = np.float64 11 | ctypedef np.float64_t FLOAT64_t 12 | 13 | @cython.boundscheck(False) # turn off bounds-checking for entire function 14 | @cython.wraparound(False) # turn off negative index wrapping for entire function 15 | @cython.cdivision 16 | def extract_fragments( 17 | INT64_t [::1] cellxgene_oi, 18 | INT64_t [::1] cellxgene_indptr, 19 | INT64_t [:,::1] coordinates, 20 | INT64_t [::1] genemapping, 21 | INT64_t [:,::1] out_coordinates, 22 | INT64_t [::1] out_genemapping, 23 | INT64_t [::1] out_local_cellxgene_ix, 24 | ): 25 | cdef INT64_t out_ix, local_cellxgene_ix, cellxgene_ix, position 26 | out_ix = 0 # will store where in the output array we are currently 27 | local_cellxgene_ix = 0 # will store the current fragment counting from 0 28 | 29 | with nogil: 30 | for local_cellxgene_ix in range(cellxgene_oi.shape[0]): 31 | cellxgene_ix = cellxgene_oi[local_cellxgene_ix] 32 | for position in range(cellxgene_indptr[cellxgene_ix], cellxgene_indptr[cellxgene_ix+1]): 33 | out_coordinates[out_ix] = coordinates[position] 34 | out_genemapping[out_ix] = genemapping[position] 35 | out_local_cellxgene_ix[out_ix] = local_cellxgene_ix 36 | 37 | out_ix += 1 38 | 39 | return out_ix 40 | 41 | 42 | @cython.boundscheck(False) # turn off bounds-checking for entire function 43 | @cython.wraparound(False) # turn off negative index wrapping for entire function 44 | @cython.cdivision 45 | def multiple_arange( 46 | INT64_t [::1] a, 47 | INT64_t [::1] b, 48 | INT64_t [::1] ixs, 49 | INT64_t [::1] local_cellxregion_ix, 50 | ): 51 | cdef INT64_t out_ix, pair_ix, position 52 | out_ix = 0 # will store where in the output array we are currently 53 | pair_ix = 0 # will store the current a, b pair index 54 | 55 | with nogil: 56 | for pair_ix in range(a.shape[0]): 57 | for position in range(a[pair_ix], b[pair_ix]): 58 | ixs[out_ix] = position 59 | local_cellxregion_ix[out_ix] = pair_ix 60 | out_ix += 1 61 | pair_ix += 1 62 | 63 | return out_ix 64 | 65 | 66 | 67 | @cython.boundscheck(False) # turn off bounds-checking for entire function 68 | @cython.wraparound(False) # turn off negative index wrapping for entire function 69 | @cython.cdivision 70 | def multiple_arange( 71 | INT64_t [::1] a, 72 | INT64_t [::1] b, 73 | INT64_t [::1] ixs, 74 | INT64_t [::1] local_cellxregion_ix, 75 | ): 76 | cdef INT64_t out_ix, pair_ix, position 77 | out_ix = 0 # will store where in the output array we are currently 78 | pair_ix = 0 # will store the current a, b pair index 79 | 80 | with nogil: 81 | for pair_ix in range(a.shape[0]): 82 | for position in range(a[pair_ix], b[pair_ix]): 83 | 
65 |
66 | @cython.boundscheck(False) # turn off bounds-checking for entire function
67 | @cython.wraparound(False) # turn off negative index wrapping for entire function
68 | @cython.cdivision
69 | def count_A_cython(bytes s, INT64_t [:,::1] out_onehot):
70 |     # despite the legacy name, this one-hot encodes the sequence:
71 |     # columns 0-3 are A/C/G/T, column 4 catches any other character
72 |     cdef char* cstr = s
73 |     cdef int i, length = len(s)
74 |
75 |     for i in range(length):
76 |         if cstr[i] == b'A':
77 |             out_onehot[i, 0] = 1
78 |         elif cstr[i] == b'C':
79 |             out_onehot[i, 1] = 1
80 |         elif cstr[i] == b'G':
81 |             out_onehot[i, 2] = 1
82 |         elif cstr[i] == b'T':
83 |             out_onehot[i, 3] = 1
84 |         else:
85 |             out_onehot[i, 4] = 1
86 |
87 |     return out_onehot
--------------------------------------------------------------------------------
/src/chromatinhd/loaders/peakcounts.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import dataclasses
4 |
5 |
6 | @dataclasses.dataclass
7 | class Result:
8 |     counts: np.ndarray
9 |     cells_oi: np.ndarray
10 |     genes_oi: np.ndarray
11 |
12 |     @property
13 |     def n_cells(self):
14 |         return len(self.cells_oi)
15 |
16 |     @property
17 |     def n_genes(self):
18 |         return len(self.genes_oi)
19 |
20 |     def to(self, device):
21 |         for field_name, field in self.__dataclass_fields__.items():
22 |             if field.type is torch.Tensor:
23 |                 self.__setattr__(field_name, self.__getattribute__(field_name).to(device))
24 |         return self
25 |
26 |     @property
27 |     def genes_oi_torch(self):
28 |         return torch.from_numpy(self.genes_oi)
29 |
30 |     @property
31 |     def cells_oi_torch(self):
32 |         return torch.from_numpy(self.cells_oi)
33 |
34 |
35 | class PeakcountsResult(Result):
36 |     pass
37 |
38 |
39 | class Peakcounts:
40 |     def __init__(self, fragments, peakcounts):
41 |         self.peakcounts = peakcounts
42 |         assert "gene_ix" in peakcounts.peaks.columns
43 |         var = peakcounts.var
44 |         var["ix"] = np.arange(peakcounts.var.shape[0])
45 |         peakcounts.var = var
46 |         assert "ix" in peakcounts.var.columns
47 |
48 |         assert peakcounts.counts.shape[1] == peakcounts.var.shape[0]
49 |
50 |         self.gene_peak_mapping = []
51 |         cur_gene_ix = -1
52 |         for peak_ix, gene_ix in zip(peakcounts.var["ix"][peakcounts.peaks.index].values, peakcounts.peaks["gene_ix"]):
53 |             while gene_ix != cur_gene_ix:
54 |                 self.gene_peak_mapping.append([])
55 |                 cur_gene_ix += 1
56 |             self.gene_peak_mapping[-1].append(peak_ix)
57 |
58 |     def load(self, minibatch):
59 |         peak_ixs = np.concatenate([self.gene_peak_mapping[gene_ix] for gene_ix in minibatch.genes_oi])
60 |         counts = self.peakcounts.counts[minibatch.cells_oi, :][:, peak_ixs]
61 |
62 |         return PeakcountsResult(counts=counts, **minibatch.items())
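How `gene_peak_mapping` ends up structured (a toy trace with hypothetical values): the loop assumes `peaks` is sorted by `gene_ix`, and genes without any peak get an empty list so that the per-minibatch concatenation still lines up.

# peaks' gene_ix column: [0, 0, 2]   corresponding peak ixs: [7, 8, 9]
# -> gene_peak_mapping == [[7, 8], [], [9]]
# loading a minibatch with genes_oi = [0, 2] then selects peak columns [7, 8, 9]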
--------------------------------------------------------------------------------
/src/chromatinhd/loaders/pool.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import threading
3 | import numpy as np
4 | from typing import List
5 | import time
6 |
7 |
8 | class ThreadWithResult(threading.Thread):
9 |     result = None
10 |
11 |     def __init__(
12 |         self,
13 |         group=None,
14 |         target=None,
15 |         name=None,
16 |         loader=None,
17 |         args=(),
18 |         kwargs=None,
19 |         *,
20 |         daemon=None,
21 |     ):
22 |         assert target is not None
23 |         if kwargs is None:  # avoid the shared-mutable-default pitfall
24 |             kwargs = {}
25 |         self.loader = loader
26 |
27 |         def function():
28 |             self.result = target(*args, **kwargs)
29 |
30 |         super().__init__(group=group, target=function, name=name, daemon=daemon)
31 |
32 |
33 | def benchmark(loaders, minibatcher, n=100):
34 |     """
35 |     Benchmarks a pool of loaders
36 |     """
37 |
38 |     loaders.initialize(minibatcher)
39 |     waits = []
40 |     import time
41 |     import tqdm.auto as tqdm
42 |
43 |     start = time.time()
44 |     for i, data in zip(range(n), tqdm.tqdm(loaders)):
45 |         waits.append(time.time() - start)
46 |         start = time.time()
47 |     print(sum(waits))
48 |
49 |     import matplotlib.pyplot as plt
50 |
51 |     plt.plot(waits)
52 |
53 |
54 | class LoaderPool:
55 |     loaders_running: list
56 |     loaders_available: list
57 |     counter: int = 0
58 |
59 |     def __init__(
60 |         self,
61 |         loader_cls,
62 |         loader_kwargs=None,
63 |         n_workers=3,
64 |         loader=None,
65 |     ):
66 |         self.loaders_running = []
67 |
68 |         if loader_kwargs is None:
69 |             loader_kwargs = {}
70 |         self.loader_cls = loader_cls
71 |         self.loader_kwargs = loader_kwargs
72 |
73 |         self.n_workers = n_workers
74 |
75 |         if loader is not None:
76 |             self.loaders = [loader.copy() for i in range(n_workers)]
77 |         else:
78 |             self.loaders = [loader_cls(**loader_kwargs) for i in range(n_workers)]
79 |         for loader in self.loaders:
80 |             loader.running = False
81 |         self.wait = []
82 |
83 |     def initialize(self, tasker, *args, **kwargs):
84 |         self.tasker = tasker
85 |
86 |         self.args = args
87 |         self.kwargs = kwargs
88 |
89 |         self.start()  # start() takes no arguments; the stored args/kwargs are used by submit_next
90 |
91 |     def start(self):
92 |         # join all still running threads
93 |         for thread in self.loaders_running:
94 |             thread.join()
95 |
96 |         for loader in self.loaders:
97 |             loader.running = False
98 |
99 |         self.loaders_available = copy.copy(self.loaders)
100 |         self.loaders_running = []
101 |
102 |         self.wait = []
103 |
104 |         self.tasker_iter = iter(self.tasker)
105 |
106 |         for i in range(min(len(self.tasker), self.n_workers - 1)):
107 |             self.submit_next()
108 |
109 |     def __iter__(self):
110 |         self.counter = 0
111 |         return self
112 |
113 |     def __len__(self):
114 |         return len(self.tasker)
115 |
116 |     def __next__(self):
117 |         self.counter += 1
118 |         if self.counter > len(self.tasker):
119 |             raise StopIteration
120 |         result = self.pull()
121 |         self.submit_next()
122 |         return result
123 |
124 |     def submit_next(self):
125 |         try:
126 |             task = next(self.tasker_iter)
127 |         except StopIteration:
128 |             self.tasker_iter = iter(self.tasker)
129 |             task = next(self.tasker_iter)
130 |         self.submit(task, *self.args, **self.kwargs)
131 |
132 |     def submit(self, *args, **kwargs):
133 |         if self.loaders_available is None:
134 |             raise ValueError("Pool was not initialized")
135 |         if len(self.loaders_available) == 0:
136 |             raise ValueError("No loaders available")
137 |
138 |         loader = self.loaders_available.pop(0)
139 |         if loader.running:
140 |             raise ValueError
141 |         loader.running = True
142 |         thread = ThreadWithResult(target=loader.load, loader=loader, args=args, kwargs=kwargs)
143 |         self.loaders_running.append(thread)
144 |         thread.start()
145 |
146 |     def pull(self):
147 |         thread = self.loaders_running.pop(0)
148 |
149 |         start = time.time()
150 |         thread.join()
151 |         wait = time.time() - start
152 |         self.wait.append(wait)
153 |         thread.loader.running = False
154 |         result = thread.result
155 |         self.loaders_available.append(thread.loader)
156 |         return result
157 |
158 |     def terminate(self):
159 |         for thread in self.loaders_running:
160 |             thread.join()
161 |         self.loaders_running = []
162 |         self.loaders = None
163 |
164 |     def __del__(self):
165 |         self.terminate()
--------------------------------------------------------------------------------
/src/chromatinhd/loaders/transcriptome.py:
-------------------------------------------------------------------------------- 1 | import chromatinhd.data.transcriptome 2 | import dataclasses 3 | import torch 4 | import chromatinhd.sparse 5 | from chromatinhd.flow.tensorstore import TensorstoreInstance 6 | import numpy as np 7 | 8 | 9 | @dataclasses.dataclass 10 | class Result: 11 | value: torch.Tensor 12 | 13 | def to(self, device): 14 | self.value = self.value.to(device) 15 | return self 16 | 17 | 18 | class Transcriptome: 19 | def __init__( 20 | self, 21 | transcriptome: chromatinhd.data.transcriptome.Transcriptome, 22 | layer: str = None, 23 | ): 24 | if layer is None: 25 | layer = list(transcriptome.layers.keys())[0] 26 | 27 | X = transcriptome.layers[layer] 28 | if chromatinhd.sparse.is_sparse(X): 29 | self.X = X.dense() 30 | elif torch.is_tensor(X): 31 | self.X = X.numpy() 32 | elif isinstance(X, TensorstoreInstance): 33 | # self.X = X 34 | self.X = X.oindex # open a tensorstore reader with orthogonal indexing 35 | else: 36 | self.X = X 37 | 38 | def load(self, minibatch): 39 | X = torch.from_numpy(self.X[minibatch.cells_oi, minibatch.genes_oi].astype(np.float32)) 40 | if X.ndim == 1: 41 | X = X.unsqueeze(1) 42 | return Result(value=X) 43 | 44 | 45 | class TranscriptomeGene: 46 | def __init__( 47 | self, 48 | transcriptome: chromatinhd.data.transcriptome.Transcriptome, 49 | gene_oi, 50 | layer: str = None, 51 | ): 52 | if layer is None: 53 | layer = list(transcriptome.layers.keys())[0] 54 | 55 | gene_ix = transcriptome.var.index.get_loc(gene_oi) 56 | 57 | X = transcriptome.layers[layer] 58 | if chromatinhd.sparse.is_sparse(X): 59 | self.X = X[:, gene_ix].dense()[:, 0] 60 | elif torch.is_tensor(X): 61 | self.X = X[:, gene_ix].numpy() 62 | elif isinstance(X, TensorstoreInstance): 63 | # self.X = X 64 | self.X = X.oindex[:, gene_ix] # open a tensorstore reader with orthogonal indexing 65 | else: 66 | self.X = X[:, gene_ix] 67 | 68 | def load(self, minibatch): 69 | X = torch.from_numpy(self.X[minibatch.cells_oi].astype(np.float32)) 70 | if X.ndim == 1: 71 | X = X.unsqueeze(1) 72 | return Result(value=X) 73 | -------------------------------------------------------------------------------- /src/chromatinhd/loaders/transcriptome_fragments.py: -------------------------------------------------------------------------------- 1 | import chromatinhd.data.transcriptome 2 | import dataclasses 3 | 4 | from .fragments import Fragments, FragmentsRegional 5 | from chromatinhd.loaders.minibatches import Minibatch 6 | from .transcriptome import Transcriptome, TranscriptomeGene 7 | 8 | 9 | @dataclasses.dataclass 10 | class Result: 11 | transcriptome: Transcriptome 12 | fragments: Fragments 13 | minibatch: Minibatch 14 | 15 | def to(self, device): 16 | self.transcriptome.to(device) 17 | self.fragments.to(device) 18 | self.minibatch.to(device) 19 | return self 20 | 21 | 22 | class TranscriptomeFragments: 23 | def __init__( 24 | self, 25 | fragments: chromatinhd.data.fragments.Fragments, 26 | transcriptome: chromatinhd.data.transcriptome.Transcriptome, 27 | cellxregion_batch_size: int, 28 | layer: str = None, 29 | region_oi=None, 30 | ): 31 | # ensure that transcriptome and fragments have the same var 32 | if not all(transcriptome.var.index == fragments.var.index): 33 | raise ValueError("Transcriptome and fragments should have the same var index. 
") 34 | 35 | if region_oi is None: 36 | self.fragments = Fragments( 37 | fragments, 38 | cellxregion_batch_size=cellxregion_batch_size, 39 | provide_multiplets=False, 40 | provide_libsize=True, 41 | ) 42 | self.transcriptome = Transcriptome(transcriptome, layer=layer) 43 | else: 44 | self.fragments = FragmentsRegional( 45 | fragments, 46 | cellxregion_batch_size=cellxregion_batch_size, 47 | region_oi=region_oi, 48 | provide_libsize=True, 49 | ) 50 | self.transcriptome = TranscriptomeGene(transcriptome, gene_oi=region_oi, layer=layer) 51 | 52 | def load(self, minibatch): 53 | return Result( 54 | transcriptome=self.transcriptome.load(minibatch), 55 | fragments=self.fragments.load(minibatch), 56 | minibatch=minibatch, 57 | ) 58 | -------------------------------------------------------------------------------- /src/chromatinhd/loaders/transcriptome_fragments2.py: -------------------------------------------------------------------------------- 1 | import chromatinhd.data.transcriptome 2 | import dataclasses 3 | 4 | from .fragments2 import Fragments 5 | from chromatinhd.loaders.minibatches import Minibatch 6 | from .transcriptome import Transcriptome, TranscriptomeGene 7 | 8 | 9 | @dataclasses.dataclass 10 | class Result: 11 | transcriptome: Transcriptome 12 | fragments: Fragments 13 | minibatch: Minibatch 14 | 15 | def to(self, device): 16 | self.transcriptome.to(device) 17 | self.fragments.to(device) 18 | self.minibatch.to(device) 19 | return self 20 | 21 | 22 | class TranscriptomeFragments: 23 | def __init__( 24 | self, 25 | fragments: chromatinhd.data.fragments.Fragments, 26 | transcriptome: chromatinhd.data.transcriptome.Transcriptome, 27 | regionxcell_batch_size: int, 28 | layer: str = None, 29 | region_oi=None, 30 | ): 31 | # ensure that transcriptome and fragments have the same var 32 | if not all(transcriptome.var.index == fragments.var.index): 33 | raise ValueError("Transcriptome and fragments should have the same var index. 
") 34 | 35 | self.fragments = Fragments(fragments, regionxcell_batch_size=regionxcell_batch_size) 36 | self.transcriptome = Transcriptome(transcriptome, layer=layer) 37 | 38 | def load(self, minibatch): 39 | return Result( 40 | transcriptome=self.transcriptome.load(minibatch), 41 | fragments=self.fragments.load(minibatch), 42 | minibatch=minibatch, 43 | ) 44 | -------------------------------------------------------------------------------- /src/chromatinhd/loaders/transcriptome_fragments_time.py: -------------------------------------------------------------------------------- 1 | import chromatinhd.data.transcriptome 2 | import chromatinhd.data.gradient 3 | import dataclasses 4 | import numpy as np 5 | import torch 6 | 7 | from .fragments import Fragments 8 | from chromatinhd.loaders.minibatches import Minibatch 9 | from .transcriptome import Result as TranscriptomeResult 10 | 11 | 12 | @dataclasses.dataclass 13 | class Result: 14 | transcriptome: TranscriptomeResult 15 | fragments: Fragments 16 | minibatch: Minibatch 17 | 18 | def to(self, device): 19 | self.transcriptome.to(device) 20 | self.fragments.to(device) 21 | self.minibatch.to(device) 22 | return self 23 | 24 | 25 | class TranscriptomeTime: 26 | def __init__( 27 | self, 28 | transcriptome: chromatinhd.data.transcriptome.Transcriptome, 29 | gradient: chromatinhd.data.gradient.Gradient, 30 | layer: str = None, 31 | delta_time=0.25, 32 | n_bins=20, 33 | delta_expression=False, 34 | ): 35 | bins = self.bins = np.array([0] + list(np.linspace(0.1, 0.9, n_bins - 1)) + [1]) 36 | 37 | x = self.x = gradient.values[:, 0] 38 | if layer is None: 39 | layer = list(transcriptome.layers.keys())[0] 40 | y = transcriptome.layers[layer][:] 41 | 42 | x_binned = self.x_binned = np.clip(np.searchsorted(bins, x) - 1, 0, bins.size - 2) 43 | x_onehot = np.zeros((x_binned.size, x_binned.max() + 1)) 44 | x_onehot[np.arange(x_binned.size), x_binned] = 1 45 | y_binned = (x_onehot.T @ y) / x_onehot.sum(axis=0)[:, None] 46 | self.y_binned = (y_binned - y_binned.min(0)) / (y_binned.max(0) - y_binned.min(0)) 47 | 48 | self.delta_time = delta_time 49 | self.delta_expression = delta_expression 50 | 51 | def load(self, minibatch): 52 | x = self.x[minibatch.cells_oi] 53 | x_desired = x + self.delta_time 54 | x_desired_bin = np.clip(np.searchsorted(self.bins, x_desired) - 1, 0, self.bins.size - 2) 55 | if self.delta_expression: 56 | x_desired_bin2 = np.clip(np.searchsorted(self.bins, x_desired + self.delta_time) - 1, 0, self.bins.size - 2) 57 | y_desired = ( 58 | self.y_binned[x_desired_bin2, :][:, minibatch.genes_oi] 59 | - self.y_binned[x_desired_bin, :][:, minibatch.genes_oi] 60 | ) 61 | else: 62 | y_desired = self.y_binned[x_desired_bin, :][:, minibatch.genes_oi] 63 | 64 | return TranscriptomeResult(value=torch.from_numpy(y_desired)) 65 | 66 | 67 | class TranscriptomeFragmentsTime: 68 | def __init__( 69 | self, 70 | fragments: chromatinhd.data.fragments.Fragments, 71 | transcriptome: chromatinhd.data.transcriptome.Transcriptome, 72 | gradient: chromatinhd.data.gradient.Gradient, 73 | cellxregion_batch_size: int, 74 | layer: str = None, 75 | delta_time=0.25, 76 | n_bins=20, 77 | delta_expression=False, 78 | ): 79 | # ensure that transcriptome and fragments have the same var 80 | if not all(transcriptome.var.index == fragments.var.index): 81 | raise ValueError("Transcriptome and fragments should have the same var index.") 82 | 83 | self.fragments = Fragments(fragments, cellxregion_batch_size=cellxregion_batch_size) 84 | self.transcriptome = TranscriptomeTime( 85 | 
transcriptome, 86 | gradient, 87 | layer=layer, 88 | delta_expression=delta_expression, 89 | delta_time=delta_time, 90 | n_bins=n_bins, 91 | ) 92 | 93 | def load(self, minibatch): 94 | return Result( 95 | transcriptome=self.transcriptome.load(minibatch), 96 | fragments=self.fragments.load(minibatch), 97 | minibatch=minibatch, 98 | ) 99 | -------------------------------------------------------------------------------- /src/chromatinhd/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import HybridModel, FlowModel 2 | from . import pred 3 | from . import diff 4 | -------------------------------------------------------------------------------- /src/chromatinhd/models/diff/__init__.py: -------------------------------------------------------------------------------- 1 | # from .differential import DifferentialSlices 2 | from . import plot 3 | from . import loader 4 | from . import model 5 | from . import interpret 6 | from . import trainer 7 | -------------------------------------------------------------------------------- /src/chromatinhd/models/diff/interpret/__init__.py: -------------------------------------------------------------------------------- 1 | from .regionpositional import RegionPositional 2 | from .performance import Performance 3 | from . import enrichment 4 | -------------------------------------------------------------------------------- /src/chromatinhd/models/diff/interpret/enrichment/__init__.py: -------------------------------------------------------------------------------- 1 | from .enrichment import enrichment_foreground_vs_background, enrichment_cluster_vs_clusters 2 | from .group import group_enrichment -------------------------------------------------------------------------------- /src/chromatinhd/models/diff/interpret/enrichment/enrichment.py: -------------------------------------------------------------------------------- 1 | import scipy.stats 2 | import tqdm.auto as tqdm 3 | import numpy as np 4 | import pandas as pd 5 | from chromatinhd.utils import fdr 6 | 7 | 8 | def enrichment_foreground_vs_background(slicescores_foreground, slicescores_background, slicecounts, motifs=None, expected = None): 9 | if motifs is None: 10 | motifs = slicecounts.columns 11 | 12 | if "length" not in slicescores_foreground.columns: 13 | slicescores_foreground["length"] = slicescores_foreground["end"] - slicescores_foreground["start"] 14 | if "length" not in slicescores_background.columns: 15 | slicescores_background["length"] = slicescores_background["end"] - slicescores_background["start"] 16 | 17 | x_foreground = slicecounts.loc[slicescores_foreground.index, motifs].sum(0) 18 | x_background = slicecounts.loc[slicescores_background.index, motifs].sum(0) 19 | 20 | n_foreground = slicescores_foreground["length"].sum() 21 | n_background = slicescores_background["length"].sum() 22 | 23 | contingencies = ( 24 | np.stack( 25 | [ 26 | n_background - x_background, 27 | x_background, 28 | n_foreground - x_foreground, 29 | x_foreground, 30 | ], 31 | axis=1, 32 | ) 33 | .reshape(-1, 2, 2) 34 | .astype(np.int64) 35 | ) 36 | 37 | odds = (contingencies[:, 1, 1] * contingencies[:, 0, 0] + 1) / (contingencies[:, 1, 0] * contingencies[:, 0, 1] + 1) 38 | 39 | if expected is not None: 40 | x_foreground = expected.loc[slicescores_foreground.index, motifs].sum(0) 41 | x_background = expected.loc[slicescores_background.index, motifs].sum(0) 42 | n_foreground = slicescores_foreground["length"].sum() 43 | n_background = 
slicescores_background["length"].sum() 44 | 45 | contingencies_expected = ( 46 | np.stack( 47 | [ 48 | n_background - x_background, 49 | x_background, 50 | n_foreground - x_foreground, 51 | x_foreground, 52 | ], 53 | axis=1, 54 | ) 55 | .reshape(-1, 2, 2) 56 | .astype(np.int64) 57 | ) 58 | odds_expected = (contingencies_expected[:, 1, 1] * contingencies_expected[:, 0, 0] + 1) / (contingencies_expected[:, 1, 0] * contingencies_expected[:, 0, 1] + 1) 59 | 60 | odds = odds / odds_expected 61 | 62 | contingencies[:, 0, 1] = contingencies[:, 0, 1] * odds_expected 63 | 64 | 65 | p_values = np.array( 66 | [ 67 | scipy.stats.chi2_contingency(c).pvalue if (c > 5).all() else scipy.stats.fisher_exact(c).pvalue 68 | for c in contingencies 69 | ] 70 | ) 71 | q_values = fdr(p_values) 72 | 73 | return pd.DataFrame( 74 | { 75 | "odds": odds, 76 | "p_value": p_values, 77 | "q_value": q_values, 78 | "motif": motifs, 79 | "contingency": [c for c in contingencies], 80 | } 81 | ).set_index("motif") 82 | 83 | 84 | def enrichment_cluster_vs_clusters(slicescores, slicecounts, clusters=None, motifs=None, pbar=True, expected=None): 85 | if clusters is None: 86 | if "cluster" not in slicescores.columns: 87 | raise ValueError("No cluster information in slicescores") 88 | elif slicescores["cluster"].dtype.name != "category": 89 | raise ValueError("Cluster column should be categorical") 90 | clusters = slicescores["cluster"].cat.categories 91 | enrichment = [] 92 | 93 | progress = clusters 94 | if pbar: 95 | progress = tqdm.tqdm(progress) 96 | for cluster in progress: 97 | selected_slices = slicescores["cluster"] == cluster 98 | slicescores_foreground = slicescores.loc[selected_slices] 99 | slicescores_background = slicescores.loc[~selected_slices] 100 | 101 | enrichment.append( 102 | enrichment_foreground_vs_background( 103 | slicescores_foreground, 104 | slicescores_background, 105 | slicecounts, 106 | motifs=motifs, 107 | expected=expected, 108 | ) 109 | .assign(cluster=cluster) 110 | .reset_index() 111 | ) 112 | 113 | enrichment = pd.concat(enrichment, axis=0).set_index(["cluster", "motif"]) 114 | enrichment["log_odds"] = np.log(enrichment["odds"]) 115 | return enrichment 116 |
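A minimal usage sketch for the enrichment API above. The slice scores and motif counts here are made-up toy inputs (three slices, two clusters, two hypothetical motifs), not data shipped with the package:

import numpy as np
import pandas as pd
from chromatinhd.models.diff.interpret.enrichment import enrichment_cluster_vs_clusters

# toy slices: genomic intervals assigned to a categorical cluster
slicescores = pd.DataFrame(
    {
        "start": [100, 500, 900],
        "end": [200, 700, 1000],
        "cluster": pd.Categorical(["B", "T", "B"]),
    },
    index=["slice1", "slice2", "slice3"],
)
# toy motif hit counts per slice (rows = slices, columns = motifs)
slicecounts = pd.DataFrame(
    {"MOTIF_A": [5, 0, 3], "MOTIF_B": [0, 8, 1]},
    index=slicescores.index,
)

enrichment = enrichment_cluster_vs_clusters(slicescores, slicecounts, pbar=False)
print(enrichment[["odds", "q_value", "log_odds"]])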
-------------------------------------------------------------------------------- /src/chromatinhd/models/diff/interpret/enrichment/group.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | 5 | def group_enrichment( 6 | enrichment, 7 | slicecounts, 8 | cluster_info, 9 | merge_cutoff=0.2, 10 | q_value_cutoff=0.01, 11 | odds_cutoff=1.1, 12 | min_found=100, 13 | ): 14 | """ 15 | Group motifs based on correlation of their slice counts across clusters 16 | 17 | Parameters 18 | ---------- 19 | enrichment : pd.DataFrame 20 | DataFrame with columns "odds", "q_value", "contingency" 21 | slicecounts : pd.DataFrame 22 | DataFrame with slice counts 23 | cluster_info : pd.DataFrame 24 | DataFrame with cluster information 25 | merge_cutoff : float 26 | Correlation cutoff for merging motifs 27 | q_value_cutoff : float 28 | q-value cutoff for enrichment 29 | odds_cutoff : float 30 | Odds ratio cutoff for enrichment 31 | min_found : int 32 | Minimum number of found sites for enrichment 33 | """ 34 | 35 | enrichment_grouped = [] 36 | slicecors = pd.DataFrame( 37 | np.corrcoef((slicecounts.T > 0) + np.random.normal(0, 1e-6, slicecounts.shape).T), 38 | index=slicecounts.columns, 39 | columns=slicecounts.columns, 40 | ) 41 | 42 | for cluster_oi in cluster_info.index: 43 | for direction in ["up", "down"]: 44 | if direction == "up": 45 | enrichment["found"] = enrichment["contingency"].map(lambda x: x[1, 1].sum()) 46 | enrichment_oi = ( 47 | enrichment.loc[cluster_oi] 48 | .query("q_value < @q_value_cutoff") 49 | .query("odds > @odds_cutoff") 50 | .sort_values("odds", ascending=False) 51 | .query("found > @min_found") 52 | ) 53 | else: 54 | enrichment["found"] = enrichment["contingency"].map(lambda x: x[0, 1].sum()) 55 | enrichment_oi = ( 56 | enrichment.loc[cluster_oi] 57 | .query("q_value < @q_value_cutoff") 58 | .query("1/odds > @odds_cutoff") 59 | .sort_values("odds", ascending=True) 60 | .query("found > @min_found") 61 | ) 62 | 63 | # enrichment_oi = enrichment_oi.loc[(~enrichment_oi.index.get_level_values("motif").str.contains("ZNF")) & (~enrichment_oi.index.get_level_values("motif").str.startswith("ZN")) & (~enrichment_oi.index.get_level_values("motif").str.contains("KLF")) & (~enrichment_oi.index.get_level_values("motif").str.contains("WT"))] 64 | 65 | motif_grouping = {} 66 | for motif_id in enrichment_oi.index: 67 | slicecors_oi = slicecors.loc[motif_id, list(motif_grouping.keys())] 68 | if (slicecors_oi < merge_cutoff).all(): 69 | motif_grouping[motif_id] = [motif_id] 70 | enrichment_oi.loc[motif_id, "group"] = motif_id 71 | else: 72 | group = slicecors_oi.sort_values(ascending=False).index[0] 73 | motif_grouping[group].append(motif_id) 74 | enrichment_oi.loc[motif_id, "group"] = group 75 | enrichment_group = enrichment_oi.sort_values("odds", ascending=False).loc[ 76 | list(motif_grouping.keys()) 77 | ] 78 | enrichment_group["members"] = [ 79 | motif_grouping[group] for group in enrichment_group.index 80 | ] 81 | enrichment_group["direction"] = direction 82 | enrichment_grouped.append(enrichment_group.assign(cluster=cluster_oi)) 83 | enrichment_grouped = ( 84 | pd.concat(enrichment_grouped).reset_index().set_index(["cluster", "motif"]) 85 | ) 86 | enrichment_grouped = enrichment_grouped.sort_values("q_value", ascending=True) 87 | 88 | return enrichment_grouped 89 | -------------------------------------------------------------------------------- /src/chromatinhd/models/diff/interpret/performance.py: -------------------------------------------------------------------------------- 1 | import chromatinhd as chd 2 | import numpy as np 3 | import pandas as pd 4 | import xarray as xr 5 | import pickle 6 | import scipy.stats 7 | import tqdm.auto as tqdm 8 | from chromatinhd import get_default_device 9 | 10 | from chromatinhd.flow.objects import Stored, Dataset, StoredDict 11 | 12 | from itertools import product 13 | 14 | 15 | class Performance(chd.flow.Flow): 16 | """ 17 | The train/validation/test performance of a (set of) models.
18 | """ 19 | 20 | scores = chd.flow.SparseDataset() 21 | 22 | folds = None 23 | 24 | @classmethod 25 | def create(cls, folds, fragments, phases=None, overwrite=False, path=None): 26 | self = cls(path=path, reset=overwrite) 27 | 28 | self.folds = folds 29 | self.fragments = fragments 30 | 31 | if self.o.scores.exists(self) and not overwrite: 32 | assert self.scores.coords_pointed["region"].equals(fragments.var.index) 33 | 34 | return self 35 | 36 | if phases is None: 37 | phases = ["train", "validation", "test"] 38 | 39 | coords_pointed = { 40 | "region": fragments.regions.var.index, 41 | "fold": pd.Index(range(len(folds)), name="fold"), 42 | "phase": pd.Index(phases, name="phase"), 43 | } 44 | coords_fixed = {} 45 | self.scores = chd.sparse.SparseDataset.create( 46 | self.path / "scores", 47 | variables={ 48 | "likelihood": { 49 | "dimensions": ("region", "fold", "phase"), 50 | "dtype": np.float32, 51 | "sparse": False, 52 | }, 53 | "scored": { 54 | "dimensions": ("region", "fold"), 55 | "dtype": np.bool, 56 | "sparse": False, 57 | }, 58 | }, 59 | coords_pointed=coords_pointed, 60 | coords_fixed=coords_fixed, 61 | ) 62 | 63 | return self 64 | 65 | def score( 66 | self, 67 | models, 68 | force=False, 69 | device=None, 70 | regions=None, 71 | pbar=True, 72 | ): 73 | fragments = self.fragments 74 | folds = self.folds 75 | 76 | phases = self.scores.coords["phase"].values 77 | 78 | region_name = fragments.var.index.name 79 | 80 | design = self.scores["scored"].sel_xr().to_dataframe(name="scored") 81 | design = design.loc[~design["scored"]] 82 | design = design.reset_index()[[region_name, "fold"]] 83 | 84 | if regions is not None: 85 | design = design.loc[design[region_name].isin(regions)] 86 | 87 | progress = design.groupby("fold") 88 | if pbar is True: 89 | progress = tqdm.tqdm(progress, total=len(progress), leave=False) 90 | 91 | for fold_ix, subdesign in progress: 92 | if models.fitted(fold_ix): 93 | for phase in phases: 94 | fold = folds[fold_ix] 95 | cells_oi = fold[f"cells_{phase}"] 96 | 97 | likelihood = models.get_prediction( 98 | fold_ix=fold_ix, cell_ixs=cells_oi, regions=subdesign[region_name], return_raw=True 99 | ) 100 | 101 | for region_ix, region_oi in enumerate(subdesign[region_name].values): 102 | self.scores["likelihood"][ 103 | region_oi, 104 | fold_ix, 105 | phase, 106 | ] = likelihood[:, region_ix].mean() 107 | 108 | for region_oi in subdesign[region_name]: 109 | self.scores["scored"][region_oi, fold_ix] = True 110 | 111 | return self 112 | -------------------------------------------------------------------------------- /src/chromatinhd/models/diff/interpret/slices.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import torch 4 | import xarray as xr 5 | import tqdm 6 | 7 | 8 | def filter_slices_probs(prob_cutoff=0.0): 9 | prob_cutoff = 0.0 10 | # prob_cutoff = -1. 11 | # prob_cutoff = -4. 
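A sketch of how this class is meant to be driven, assuming an already-prepared `folds`/`fragments` pair and a trained `models` collection exposing the `fitted`/`get_prediction` interface used by `score` above. All names here are placeholders:

from chromatinhd.models.diff.interpret import Performance

performance = Performance.create(
    folds=folds,            # placeholder: the folds used during training
    fragments=fragments,    # placeholder: the fragments dataset
    path=output_path / "performance",  # placeholder output location
)
performance.score(models)   # placeholder: trained models with .fitted()/.get_prediction()
likelihood = performance.scores["likelihood"].sel_xr()  # region x fold x phase
print(likelihood.sel(phase="test").mean("fold"))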
-------------------------------------------------------------------------------- /src/chromatinhd/models/diff/interpret/slices.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import tqdm 4 | 5 | import chromatinhd as chd 6 | 7 | 8 | def filter_slices_probs(regionpositional, fragments, prob_cutoff=0.0): 9 | # note: regionpositional and fragments are passed in explicitly; previously they 10 | # were picked up from the enclosing scope, which made this function unusable on its own 11 | start_position_ixs = [] 12 | end_position_ixs = [] 13 | data = [] 14 | region_ixs = [] 15 | for region, probs in tqdm.tqdm(regionpositional.probs.items()): 16 | region_ix = fragments.var.index.get_loc(region) 17 | desired_x = np.arange(*fragments.regions.window) - fragments.regions.window[0] 18 | x = probs.coords["coord"].values - fragments.regions.window[0] 19 | y = probs.values 20 | 21 | y_interpolated = chd.utils.interpolate_1d( 22 | torch.from_numpy(desired_x), torch.from_numpy(x), torch.from_numpy(y) 23 | ).numpy() 24 | 25 | # from y_interpolated, determine start and end positions of the relevant slices 26 | start_position_ixs_region, end_position_ixs_region, data_region = extract_slices(y_interpolated, prob_cutoff) 27 | start_position_ixs.append(start_position_ixs_region + fragments.regions.window[0]) 28 | end_position_ixs.append(end_position_ixs_region + fragments.regions.window[0]) 29 | data.append(data_region) 30 | region_ixs.append(np.ones(len(start_position_ixs_region), dtype=int) * region_ix) 31 | data = np.concatenate(data, axis=0) 32 | start_position_ixs = np.concatenate(start_position_ixs, axis=0) 33 | end_position_ixs = np.concatenate(end_position_ixs, axis=0) 34 | region_ixs = np.concatenate(region_ixs, axis=0) 35 | 36 | # Slices is assumed to be the slice container defined elsewhere in this subpackage 37 | slices = Slices(region_ixs, start_position_ixs, end_position_ixs, data, fragments.n_regions) 38 | return slices 39 | 40 | 41 | def extract_slices(x, cutoff=0.0): 42 | selected = (x > cutoff).any(0).astype(int) 43 | selected_padded = np.pad(selected, ((1, 1))) 44 | (start_position_indices,) = np.where(np.diff(selected_padded, axis=-1) == 1) 45 | (end_position_indices,) = np.where(np.diff(selected_padded, axis=-1) == -1) 46 | start_position_indices = start_position_indices + 1 47 | end_position_indices = end_position_indices + 1 - 1 48 | 49 | data = [] 50 | for start_ix, end_ix in zip(start_position_indices, end_position_indices): 51 | data.append(x[:, start_ix:end_ix].transpose(1, 0)) 52 | if len(data) == 0: 53 | data = np.zeros((0, x.shape[0])) 54 | else: 55 | data = np.concatenate(data, axis=0) 56 | 57 | return start_position_indices, end_position_indices, data 58 | -------------------------------------------------------------------------------- /src/chromatinhd/models/diff/loader/__init__.py: -------------------------------------------------------------------------------- 1 | from chromatinhd.loaders.minibatches import Minibatcher, Minibatch 2 | from .clustering_cuts import ClusteringCuts 3 | -------------------------------------------------------------------------------- /src/chromatinhd/models/diff/loader/clustering_cuts.py: -------------------------------------------------------------------------------- 1 | import chromatinhd.data.fragments 2 | import dataclasses 3 | 4 | from chromatinhd.loaders.clustering import Clustering 5 | from chromatinhd.loaders.minibatches import Minibatch 6 | from chromatinhd.loaders.fragments import Cuts 7 | 8 | 9 | @dataclasses.dataclass 10 | class Result: 11 | cuts: Cuts 12 | clustering: Clustering 13 | minibatch: Minibatch 14 | 15 | def to(self, device): 16 | self.cuts.to(device) 17 | self.clustering.to(device) 18 | self.minibatch.to(device) 19 | return self 20 | 21 | 22 | class ClusteringCuts: 23 | """ 24 | Provides both clustering and cuts data for a minibatch.
25 | """ 26 | 27 | def __init__( 28 | self, 29 | clustering: chromatinhd.data.clustering.Clustering, 30 | fragments: chromatinhd.data.fragments.Fragments, 31 | cellxregion_batch_size: int, 32 | ): 33 | # ensure that the order of clustering and fragment.obs is the same 34 | if not all(clustering.labels.index == fragments.obs.index): 35 | raise ValueError("Clustering and fragments should have the same obs index. ") 36 | self.clustering = Clustering(clustering) 37 | self.cuts = Cuts(fragments, cellxregion_batch_size=cellxregion_batch_size) 38 | 39 | def load(self, minibatch): 40 | return Result( 41 | cuts=self.cuts.load(minibatch), 42 | clustering=self.clustering.load(minibatch), 43 | minibatch=minibatch, 44 | ) 45 | -------------------------------------------------------------------------------- /src/chromatinhd/models/diff/loader/clustering_fragments.py: -------------------------------------------------------------------------------- 1 | import chromatinhd.data.fragments 2 | import dataclasses 3 | 4 | from chromatinhd.loaders.fragments import Fragments 5 | from chromatinhd.loaders.clustering import Clustering 6 | from chromatinhd.loaders.minibatches import Minibatch 7 | 8 | 9 | @dataclasses.dataclass 10 | class Result: 11 | fragments: Fragments 12 | clustering: Clustering 13 | minibatch: Minibatch 14 | 15 | def to(self, device): 16 | self.fragments.to(device) 17 | self.clustering.to(device) 18 | self.minibatch.to(device) 19 | return self 20 | 21 | 22 | class ClusteringFragments: 23 | """ 24 | Provides both clustering and fragments data for a minibatch. 25 | """ 26 | 27 | def __init__( 28 | self, 29 | clustering: chromatinhd.data.clustering.Clustering, 30 | fragments: chromatinhd.data.fragments.Fragments, 31 | cellxregion_batch_size: int, 32 | ): 33 | # ensure that the order of clustering and fragment.obs is the same 34 | if not all(clustering.labels.index == fragments.obs.index): 35 | raise ValueError("Clustering and fragments should have the same obs index. ") 36 | self.clustering = Clustering(clustering) 37 | self.fragments = Fragments(fragments, cellxregion_batch_size=cellxregion_batch_size) 38 | 39 | def load(self, minibatch): 40 | return Result( 41 | fragments=self.fragments.load(minibatch), 42 | clustering=self.clustering.load(minibatch), 43 | minibatch=minibatch, 44 | ) 45 | -------------------------------------------------------------------------------- /src/chromatinhd/models/diff/model/__init__.py: -------------------------------------------------------------------------------- 1 | from . import cutnf 2 | from . import binary 3 | -------------------------------------------------------------------------------- /src/chromatinhd/models/diff/model/splines/__init__.py: -------------------------------------------------------------------------------- 1 | from . import linear 2 | from . 
import quadratic 3 | -------------------------------------------------------------------------------- /src/chromatinhd/models/diff/model/splines/linear.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | import numpy as np 6 | 7 | 8 | def unconstrained_linear_spline( 9 | inputs, unnormalized_pdf, inverse=False, tail_bound=1.0, tails="linear" 10 | ): 11 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 12 | outside_interval_mask = ~inside_interval_mask 13 | 14 | outputs = torch.zeros_like(inputs) 15 | logabsdet = torch.zeros_like(inputs) 16 | 17 | if tails == "linear": 18 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 19 | logabsdet[outside_interval_mask] = 0 20 | else: 21 | raise RuntimeError("{} tails are not implemented.".format(tails)) 22 | 23 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = linear_spline( 24 | inputs=inputs[inside_interval_mask], 25 | unnormalized_pdf=unnormalized_pdf[inside_interval_mask, :], 26 | inverse=inverse, 27 | left=-tail_bound, 28 | right=tail_bound, 29 | bottom=-tail_bound, 30 | top=tail_bound, 31 | ) 32 | 33 | return outputs, logabsdet 34 | 35 | 36 | def linear_spline( 37 | inputs, unnormalized_pdf, inverse=False, left=0.0, right=1.0, bottom=0.0, top=1.0 38 | ): 39 | """ 40 | Reference: 41 | > Müller et al., Neural Importance Sampling, arXiv:1808.03856, 2018. 42 | """ 43 | if not inverse and (torch.min(inputs) < left or torch.max(inputs) > right): 44 | raise ValueError("Inputs outside domain") 45 | elif inverse and (torch.min(inputs) < bottom or torch.max(inputs) > top): 46 | raise ValueError("Inputs outside domain") 47 | 48 | if inverse: 49 | inputs = (inputs - bottom) / (top - bottom) 50 | else: 51 | inputs = (inputs - left) / (right - left) 52 | 53 | num_bins = unnormalized_pdf.size(-1) 54 | 55 | pdf = F.softmax(unnormalized_pdf, dim=-1) 56 | 57 | cdf = torch.cumsum(pdf, dim=-1) 58 | cdf[..., -1] = 1.01 59 | cdf = F.pad(cdf, pad=(1, 0), mode="constant", value=0.0) 60 | 61 | if inverse: 62 | inv_bin_idx = torch.searchsorted(cdf, inputs, right=False) 63 | 64 | bin_boundaries = ( 65 | torch.linspace(0, 1, num_bins + 1) 66 | .view([1] * inputs.dim() + [-1]) 67 | .expand(*inputs.shape, -1) 68 | ) 69 | 70 | slopes = (cdf[..., 1:] - cdf[..., :-1]) / ( 71 | bin_boundaries[..., 1:] - bin_boundaries[..., :-1] 72 | ) 73 | offsets = cdf[..., 1:] - slopes * bin_boundaries[..., 1:] 74 | 75 | inv_bin_idx = inv_bin_idx.unsqueeze(-1) 76 | input_slopes = slopes.gather(-1, inv_bin_idx)[..., 0] 77 | input_offsets = offsets.gather(-1, inv_bin_idx)[..., 0] 78 | 79 | outputs = (inputs - input_offsets) / input_slopes 80 | outputs = torch.clamp(outputs, 0, 1) 81 | 82 | logabsdet = -torch.log(input_slopes) 83 | else: 84 | bin_pos = inputs * num_bins 85 | 86 | bin_idx = torch.floor(bin_pos).long() 87 | bin_idx[bin_idx >= num_bins] = num_bins - 1 88 | 89 | alpha = bin_pos - bin_idx.float() 90 | 91 | if bin_idx.ndim < pdf.ndim: 92 | bin_idx = bin_idx.unsqueeze(-1) 93 | 94 | input_pdfs = pdf.gather(-1, bin_idx) # [..., 0] 95 | 96 | outputs = cdf.gather(-1, bin_idx) # [..., 0] 97 | outputs += alpha * input_pdfs 98 | outputs = torch.clamp(outputs, 0, 1) 99 | 100 | bin_width = 1.0 / num_bins 101 | logabsdet = torch.log(input_pdfs) - np.log(bin_width) 102 | 103 | if inverse: 104 | outputs = outputs * (right - left) + left 105 | logabsdet = logabsdet - math.log(top - bottom) + math.log(right - left) 106 | else: 107 | 
outputs = outputs * (top - bottom) + bottom 108 | logabsdet = logabsdet + math.log(top - bottom) - math.log(right - left) 109 | 110 | return outputs, logabsdet 111 | -------------------------------------------------------------------------------- /src/chromatinhd/models/diff/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer, TrainerPerFeature 2 | -------------------------------------------------------------------------------- /src/chromatinhd/models/miff/loader/__init__.py: -------------------------------------------------------------------------------- 1 | from chromatinhd.loaders.minibatches import Minibatcher, Minibatch 2 | -------------------------------------------------------------------------------- /src/chromatinhd/models/miff/loader/clustering.py: -------------------------------------------------------------------------------- 1 | import chromatinhd.data.clustering 2 | import dataclasses 3 | import torch 4 | 5 | 6 | @dataclasses.dataclass 7 | class Result: 8 | labels: torch.Tensor 9 | 10 | def to(self, device): 11 | self.labels = self.labels.to(device) 12 | return self 13 | 14 | 15 | class Clustering: 16 | """ 17 | Provides clustering data for a minibatch. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | clustering: chromatinhd.data.clustering.Clustering, 23 | ): 24 | assert (clustering.labels.cat.categories == clustering.cluster_info.index).all(), ( 25 | clustering.labels.cat.categories, 26 | clustering.cluster_info.index, 27 | ) 28 | self.labels = torch.from_numpy(clustering.labels.cat.codes.values.copy()).to(torch.int64) 29 | 30 | def load(self, minibatch): 31 | return Result(labels=self.labels[minibatch.cells_oi]) 32 | -------------------------------------------------------------------------------- /src/chromatinhd/models/miff/loader/combinations.py: -------------------------------------------------------------------------------- 1 | import chromatinhd.data.fragments 2 | import dataclasses 3 | import numpy as np 4 | import torch 5 | 6 | from chromatinhd.loaders.fragments import Fragments as FragmentsLoader 7 | from chromatinhd.loaders.minibatches import Minibatch 8 | from .motifcount import BinnedMotifCounts as BinnedMotifCountsLoader 9 | from .clustering import Clustering as ClusteringLoader 10 | from typing import List 11 | 12 | 13 | @dataclasses.dataclass 14 | class MotifCountsFragmentsClusteringResult: 15 | fragments: any 16 | motifcounts: any 17 | clustering: any 18 | minibatch: any 19 | 20 | def to(self, device): 21 | self.fragments.to(device) 22 | self.motifcounts.to(device) 23 | self.clustering.to(device) 24 | self.minibatch.to(device) 25 | return self 26 | 27 | 28 | class MotifCountsFragmentsClustering: 29 | """ 30 | Provides both motifcounts and fragments data for a minibatch. 
31 | """ 32 | 33 | def __init__( 34 | self, 35 | motifcounts, 36 | fragments: chromatinhd.data.fragments.Fragments, 37 | clustering: chromatinhd.data.clustering.Clustering, 38 | cellxregion_batch_size: int, 39 | ): 40 | # ensure that the order of motifs and fragment.obs is the same 41 | self.fragments = FragmentsLoader(fragments, cellxregion_batch_size=cellxregion_batch_size) 42 | self.motifcounts = BinnedMotifCountsLoader(motifcounts) 43 | self.clustering = ClusteringLoader(clustering) 44 | 45 | def load(self, minibatch): 46 | fragments = self.fragments.load(minibatch) 47 | return MotifCountsFragmentsClusteringResult( 48 | fragments=fragments, 49 | motifcounts=self.motifcounts.load(minibatch, fragments), 50 | clustering=self.clustering.load(minibatch), 51 | minibatch=minibatch, 52 | ) 53 | 54 | 55 | @dataclasses.dataclass 56 | class FragmentsClusteringResult: 57 | fragments: any 58 | clustering: any 59 | minibatch: any 60 | 61 | def to(self, device): 62 | self.fragments.to(device) 63 | self.clustering.to(device) 64 | self.minibatch.to(device) 65 | return self 66 | 67 | 68 | class FragmentsClustering: 69 | """ 70 | Provides both fragments and clustering data for a minibatch. 71 | """ 72 | 73 | def __init__( 74 | self, 75 | fragments: chromatinhd.data.fragments.Fragments, 76 | clustering: chromatinhd.data.clustering.Clustering, 77 | cellxregion_batch_size: int, 78 | ): 79 | # ensure that the order of motifs and fragment.obs is the same 80 | self.fragments = FragmentsLoader(fragments, cellxregion_batch_size=cellxregion_batch_size) 81 | self.clustering = ClusteringLoader(clustering) 82 | 83 | def load(self, minibatch): 84 | fragments = self.fragments.load(minibatch) 85 | return FragmentsClusteringResult( 86 | fragments=fragments, 87 | clustering=self.clustering.load(minibatch), 88 | minibatch=minibatch, 89 | ) 90 | -------------------------------------------------------------------------------- /src/chromatinhd/models/miff/loader/minibatches.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import dataclasses 3 | import itertools 4 | import math 5 | import torch 6 | 7 | 8 | @dataclasses.dataclass 9 | class Minibatch: 10 | cells_oi: np.ndarray 11 | genes_oi: np.ndarray 12 | phase: str = "train" 13 | device: str = "cpu" 14 | 15 | def items(self): 16 | return {"cells_oi": self.cells_oi, "genes_oi": self.genes_oi} 17 | 18 | def filter_genes(self, genes): 19 | genes_oi = self.genes_oi[genes[self.genes_oi]] 20 | 21 | return Minibatch( 22 | cells_oi=self.cells_oi, 23 | genes_oi=genes_oi, 24 | phase=self.phase, 25 | ) 26 | 27 | def to(self, device): 28 | self.device = device 29 | 30 | @property 31 | def genes_oi_torch(self): 32 | return torch.from_numpy(self.genes_oi).to(self.device) 33 | 34 | @property 35 | def cells_oi_torch(self): 36 | return torch.from_numpy(self.genes_oi).to(self.device) 37 | 38 | @property 39 | def n_cells(self): 40 | return len(self.cells_oi) 41 | 42 | @property 43 | def n_genes(self): 44 | return len(self.genes_oi) 45 | 46 | 47 | class Minibatcher: 48 | def __init__( 49 | self, 50 | cells, 51 | genes, 52 | n_cells_step, 53 | n_regions_step, 54 | use_all_cells=False, 55 | use_all_regions=True, 56 | permute_cells=True, 57 | permute_regions=False, 58 | ): 59 | self.cells = cells 60 | if not isinstance(genes, np.ndarray): 61 | genes = np.array(genes) 62 | self.genes = genes 63 | self.n_genes = len(genes) 64 | self.n_cells_step = n_cells_step 65 | self.n_regions_step = n_regions_step 66 | 67 | self.permute_cells = 
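A self-contained sketch of iterating the `Minibatcher` defined above, with toy cell and gene indices:

import numpy as np
from chromatinhd.models.miff.loader.minibatches import Minibatcher

# 1000 toy cells and 50 toy genes, in batches of 200 cells x 10 genes
minibatcher = Minibatcher(
    cells=np.arange(1000),
    genes=np.arange(50),
    n_cells_step=200,
    n_regions_step=10,
)
for minibatch in minibatcher:
    print(minibatch.n_cells, minibatch.n_genes)  # 200 10
    break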
-------------------------------------------------------------------------------- /src/chromatinhd/models/miff/loader/motifs.py: -------------------------------------------------------------------------------- 1 | import chromatinhd.data.motifscan 2 | import dataclasses 3 | import torch 4 | from chromatinhd.utils.numpy import indptr_to_indices 5 | import numpy as np 6 | 7 | 8 | @dataclasses.dataclass 9 | class Result: 10 | indices: torch.Tensor 11 | positions: torch.Tensor 12 | scores: torch.Tensor 13 | local_genexmotif_ix: torch.Tensor 14 | local_gene_ix: torch.Tensor 15 | n_genes: int 16 | 17 | def to(self, device): 18 | for field_name, field in self.__dataclass_fields__.items(): 19 | if field.type is torch.Tensor: 20 | self.__setattr__(field_name, self.__getattribute__(field_name).to(device)) 21 | return self 22 | 23 | 24 | class Motifs: 25 | """ 26 | Provides motifscan data for a minibatch.
27 | """ 28 | 29 | def __init__( 30 | self, 31 | motifscan: chromatinhd.data.motifscan.Motifscan, 32 | ): 33 | self.motifscan = motifscan 34 | self.region_width = motifscan.regions.window[1] - motifscan.regions.window[0] 35 | self.n_motifs = motifscan.motifs.shape[0] 36 | 37 | def load(self, minibatch): 38 | local_genexmotif_ix = [] 39 | scores = [] 40 | positions = [] 41 | indices = [] 42 | local_gene_ix = [] 43 | for i, gene_ix in enumerate(minibatch.genes_oi): 44 | indptr_start = gene_ix * self.region_width 45 | indptr_end = (gene_ix + 1) * self.region_width 46 | indices_gene = self.motifscan.indices[ 47 | self.motifscan.indptr[indptr_start] : self.motifscan.indptr[indptr_end] 48 | ] 49 | positions_gene = ( 50 | indptr_to_indices(self.motifscan.indptr[indptr_start : indptr_end + 1]) 51 | + self.motifscan.regions.window[0] 52 | ) 53 | scores_gene = self.motifscan.scores[self.motifscan.indptr[indptr_start] : self.motifscan.indptr[indptr_end]] 54 | 55 | indices.append(indices_gene) 56 | positions.append(positions_gene) 57 | scores.append(scores_gene) 58 | 59 | local_genexmotif_ix_gene = np.ones_like(indices_gene, dtype=np.int32) * i * self.n_motifs + indices_gene 60 | local_genexmotif_ix.append(local_genexmotif_ix_gene) 61 | local_gene_ix_gene = np.ones_like(indices_gene, dtype=np.int32) * i 62 | local_gene_ix.append(local_gene_ix_gene) 63 | 64 | indices = torch.from_numpy(np.concatenate(indices)).contiguous() 65 | positions = torch.from_numpy(np.concatenate(positions)).contiguous() 66 | scores = torch.from_numpy(np.concatenate(scores)).contiguous() 67 | local_genexmotif_ix = torch.from_numpy(np.concatenate(local_genexmotif_ix)).contiguous() 68 | local_gene_ix = torch.from_numpy(np.concatenate(local_gene_ix)).contiguous() 69 | return Result(indices, positions, scores, local_genexmotif_ix, local_gene_ix, n_genes=minibatch.n_genes) 70 | -------------------------------------------------------------------------------- /src/chromatinhd/models/miff/loader/motifs_fragments.py: -------------------------------------------------------------------------------- 1 | import chromatinhd.data.fragments 2 | import dataclasses 3 | import numpy as np 4 | 5 | from chromatinhd.loaders.fragments import Fragments 6 | from chromatinhd.loaders.minibatches import Minibatch 7 | from .motifs import Motifs 8 | from typing import List 9 | 10 | 11 | @dataclasses.dataclass 12 | class Result: 13 | fragments: Fragments 14 | motifs: Motifs 15 | minibatch: Minibatch 16 | 17 | def to(self, device): 18 | self.fragments.to(device) 19 | self.motifs.to(device) 20 | self.minibatch.to(device) 21 | return self 22 | 23 | 24 | class MotifsFragments: 25 | """ 26 | Provides both motifs and fragments data for a minibatch. 27 | """ 28 | 29 | def __init__( 30 | self, 31 | motifscan: chromatinhd.data.motifscan.Motifscan, 32 | fragments: chromatinhd.data.fragments.Fragments, 33 | cellxregion_batch_size: int, 34 | ): 35 | self.motifs = Motifs(motifscan) 36 | self.fragments = Fragments(fragments, cellxregion_batch_size=cellxregion_batch_size) 37 | 38 | def load(self, minibatch): 39 | return Result( 40 | fragments=self.fragments.load(minibatch), 41 | motifs=self.motifs.load(minibatch), 42 | minibatch=minibatch, 43 | ) 44 | -------------------------------------------------------------------------------- /src/chromatinhd/models/miff/model/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import global_norm 2 | -------------------------------------------------------------------------------- /src/chromatinhd/models/miff/model/clustering/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Model 2 | -------------------------------------------------------------------------------- /src/chromatinhd/models/miff/model/clustering2/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Model 2 | -------------------------------------------------------------------------------- /src/chromatinhd/models/miff/model/clustering3/__init__.py: -------------------------------------------------------------------------------- 1 | from . import count, position 2 | from .model import Model 3 | -------------------------------------------------------------------------------- /src/chromatinhd/models/miff/model/clustering3/count.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import math 4 | from typing import Union 5 | from chromatinhd.embedding import EmbeddingTensor 6 | from chromatinhd.data.fragments import Fragments, FragmentsView 7 | 8 | 9 | class FragmentCountDistribution(torch.nn.Module): 10 | pass 11 | 12 | 13 | def count_fragments(data): 14 | count = torch.bincount( 15 | data.fragments.local_cellxregion_ix, minlength=data.minibatch.n_cells * data.minibatch.n_regions 16 | ).reshape((data.minibatch.n_cells, data.minibatch.n_regions)) 17 | return count 18 | 19 | 20 | class Baseline(FragmentCountDistribution): 21 | def __init__(self, fragments, clustering, baseline=None): 22 | super().__init__() 23 | self.logit = torch.nn.Parameter(torch.zeros((1,))) 24 | if baseline is not None: 25 | self.baseline = baseline 26 | else: 27 | self.baseline = EmbeddingTensor(fragments.n_regions, (1,), sparse=True) 28 | init = torch.from_numpy(fragments.regionxcell_counts.sum(1).astype(np.float64) / fragments.n_cells) 29 | init = torch.log(init) 30 | self.baseline.weight.data[:] = init.unsqueeze(-1) 31 | 32 | lib = torch.from_numpy(fragments.regionxcell_counts.sum(0).astype(np.float64) / fragments.n_regions) 33 | lib = torch.log(lib) 34 | self.register_buffer("lib", lib) 35 | 36 | def log_prob(self, data): 37 | count = count_fragments(data) 38 | 39 | if data.fragments.lib is not None: 40 | lib = data.fragments.lib 41 | else: 42 | lib = self.lib[data.minibatch.cells_oi] 43 | logits = self.baseline(data.minibatch.regions_oi_torch).squeeze(1).unsqueeze(0) + lib.unsqueeze(1) 44 | likelihood_count = torch.distributions.Poisson(rate=torch.exp(logits)).log_prob(count) 45 | return likelihood_count 46 | 47 | def parameters_sparse(self): 48 | yield self.baseline.weight 49 | 50 | 51 | class FragmentCountDistribution1(FragmentCountDistribution): 52 | def __init__(self, fragments, clustering, baseline=None, delta_logit=None): 53 | super().__init__() 54 | if baseline is not None: 55 | self.baseline = baseline 56 | else: 57 | self.baseline = EmbeddingTensor(fragments.n_regions, (1,), sparse=True) 58 | init = torch.from_numpy(fragments.regionxcell_counts.sum(1).astype(np.float64) / fragments.n_cells) 59 | init = torch.log(init) 60 | self.baseline.weight.data[:] = init.unsqueeze(-1) 61 | 62 | if delta_logit is not None: 63 | self.delta_logit = delta_logit 64 | else: 65 | self.delta_logit = EmbeddingTensor(fragments.n_regions, (clustering.n_clusters,), sparse=True) 66 | self.delta_logit.weight.data[:] = 0 67 | 68 | lib = torch.from_numpy(fragments.regionxcell_counts.sum(0).astype(np.float64) / fragments.n_regions) 69 | lib = torch.log(lib) 70 | self.register_buffer("lib", lib) 71 | 72 | def log_prob(self, data): 73 | count = count_fragments(data) 74 | 75 | if data.fragments.lib is not None: 76 | lib = data.fragments.lib 77 | else: 78 | lib = self.lib[data.minibatch.cells_oi] 79 | logits = ( 80 | +self.baseline(data.minibatch.regions_oi_torch).squeeze(1).unsqueeze(0) 81 | + lib.unsqueeze(1) 82 | + self.delta_logit(data.minibatch.regions_oi_torch)[:, data.clustering.labels].transpose(0, 1) 83 | ) 84 | 85 | likelihood_count = torch.distributions.Poisson(rate=torch.exp(logits)).log_prob(count) 86 | return likelihood_count 87 | 88 | def parameters_sparse(self): 89 | yield self.baseline.weight 90 | yield self.delta_logit.weight 91 | 92 | @classmethod 93 | def from_baseline(cls, fragments, clustering, count_reference): 94 | baseline = torch.nn.Embedding.from_pretrained(count_reference.baseline.weight.data) 95 | delta_logit = EmbeddingTensor.from_pretrained(count_reference.delta_logit) 96 | 97 | return cls(fragments=fragments, clustering=clustering, delta_logit=delta_logit, baseline=baseline) 98 |
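A toy illustration of the Poisson count likelihood computed in `log_prob` above: the rate is `exp(baseline[region] + lib[cell] (+ cluster delta))`, broadcast over cells x regions. All numbers here are invented:

import torch

baseline = torch.tensor([[-2.0], [-1.0]])  # (2 regions, 1)
lib = torch.tensor([0.5, 0.0, -0.5])       # (3 cells,)
logits = baseline.squeeze(1).unsqueeze(0) + lib.unsqueeze(1)  # (3 cells, 2 regions)
counts = torch.tensor([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
log_prob = torch.distributions.Poisson(rate=torch.exp(logits)).log_prob(counts)
print(log_prob.shape)  # torch.Size([3, 2])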
-------------------------------------------------------------------------------- /src/chromatinhd/models/miff/model/global_norm/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Model 2 | -------------------------------------------------------------------------------- /src/chromatinhd/models/miff/model/local_norm/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Model 2 | -------------------------------------------------------------------------------- /src/chromatinhd/models/miff/model/zoom.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | def transform_linear_spline(positions, n, width, unnormalized_heights): 6 | binsize = torch.div(width, n, rounding_mode="floor") 7 | 8 | normalized_heights = torch.nn.functional.log_softmax(unnormalized_heights, -1) 9 | if normalized_heights.ndim == positions.ndim: 10 | normalized_heights = normalized_heights.unsqueeze(0) 11 | 12 | binixs = torch.div(positions, binsize, rounding_mode="trunc") 13 | 14 | logprob = torch.gather(normalized_heights, 1, binixs.unsqueeze(1)).squeeze(1) 15 | 16 | positions = positions - binixs * binsize 17 | width = binsize 18 | 19 | return logprob, positions, width 20 | 21 | 22 | def calculate_logprob(positions, nbins, width, unnormalized_heights_zooms): 23 | """ 24 | Calculate the zoomed log probability per position given a set of unnormalized_heights_zooms 25 | """ 26 | assert len(nbins) == len(unnormalized_heights_zooms) 27 | 28 | curpositions = positions 29 | curwidth = width 30 | logprob = torch.zeros_like(positions, dtype=torch.float) 31 | for i, (n, unnormalized_heights_zoom) in enumerate(zip(nbins, unnormalized_heights_zooms)): 32 | assert (curwidth % n) == 0 33 | logprob_layer, curpositions, curwidth = transform_linear_spline( 34 | curpositions, 35 | n, 36 | curwidth, 37 | unnormalized_heights_zoom, 38 | ) 39 | logprob += logprob_layer 40 | logprob = logprob - math.log( 41 | curwidth 42 | ) # if any width is left, we need to divide by the remaining number of possibilities to get a properly normalized probability 43 | return logprob 44 |
45 | 46 | def extract_unnormalized_heights(positions, totalbinwidths, unnormalized_heights_all): 47 | """ 48 | Extracts the unnormalized heights per zoom level from the global unnormalized heights tensor with size (totaln, n). 49 | You typically do not want to use this function directly, as the benefits of a zoomed likelihood are lost in this way. 50 | This function is mainly useful for debugging, inference or testing purposes. 51 | """ 52 | totalbinixs = torch.div(positions[:, None], totalbinwidths, rounding_mode="floor") 53 | totalbinsectors = torch.nn.functional.pad(totalbinixs[..., :-1], (1, 0)) 54 | # totalbinsectors = torch.div(totalbinixs, self.nbins[None, :], rounding_mode="floor") 55 | unnormalized_heights_zooms = [ 56 | torch.index_select( 57 | unnormalized_heights_all[i], 0, totalbinsector 58 | ) # index_select is much faster than direct indexing 59 | for i, totalbinsector in enumerate(totalbinsectors.T) 60 | ] 61 | 62 | return unnormalized_heights_zooms 63 |
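A small worked example of the zoomed spline above: a total width of 64 factorized into 4 coarse x 16 fine bins. With all-zero (uniform) unnormalized heights at both levels, every position gets probability 1/64. The numbers are toy values:

import torch
from chromatinhd.models.miff.model.zoom import calculate_logprob

positions = torch.tensor([3, 17, 40])
unnormalized_heights_zooms = [
    torch.zeros(3, 4),   # coarse level: one row of heights per position
    torch.zeros(3, 16),  # fine level
]
logprob = calculate_logprob(positions, nbins=[4, 16], width=64,
                            unnormalized_heights_zooms=unnormalized_heights_zooms)
print(torch.exp(logprob))  # ~0.0156 each, i.e. 1/64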
-------------------------------------------------------------------------------- /src/chromatinhd/models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from chromatinhd.flow import Flow, Stored 3 | 4 | 5 | def get_sparse_parameters(module, parameters): 6 | for module in module._modules.values(): 7 | for p in module.parameters_sparse() if hasattr(module, "parameters_sparse") else []: 8 | parameters.add(p) 9 | parameters = get_sparse_parameters(module, parameters) 10 | return parameters 11 | 12 | 13 | class HybridModel: 14 | def parameters_dense(self, autoextend=True): 15 | """ 16 | Get all dense parameters of the model 17 | """ 18 | parameters = [ 19 | parameter 20 | for module in self._modules.values() 21 | for parameter in (module.parameters_dense() if hasattr(module, "parameters_dense") else []) 22 | ] 23 | 24 | # extend with any left over parameters that were not specified in parameters_dense or parameters_sparse 25 | def contains(x, y): 26 | return any([x is y_ for y_ in y]) 27 | 28 | parameters_sparse = set(self.parameters_sparse()) 29 | 30 | if autoextend: 31 | for p in self.parameters(): 32 | if p not in parameters_sparse: 33 | parameters.append(p) 34 | parameters = [p for p in parameters if p.requires_grad] 35 | return parameters 36 | 37 | def parameters_sparse(self, autoextend=True): 38 | """ 39 | Get all sparse parameters in a model 40 | """ 41 | parameters = set() 42 | 43 | parameters = get_sparse_parameters(self, parameters) 44 | parameters = [p for p in parameters if p.requires_grad] 45 | return parameters 46 | 47 | 48 | class FlowModel(torch.nn.Module, HybridModel, Flow): 49 | state = Stored() 50 | 51 | def __init__(self, path=None, reset=False, **kwargs): 52 | torch.nn.Module.__init__(self) 53 | Flow.__init__(self, path=path, reset=reset, **kwargs) 54 | 55 | if self.o.state.exists(self): 56 | if reset: 57 | raise ValueError("Cannot reset a model that has already been initialized") 58 | self.restore_state() 59 | 60 | def save_state(self): 61 | from collections import OrderedDict 62 | 63 | state = OrderedDict() 64 | for k, v in self.__dict__.items(): 65 | if k.lstrip("_") in self._obj_map: 66 | continue 67 | if k == "path": 68 | continue 69 | state[k] = v 70 | self.state = state 71 | 72 | @classmethod 73 | def restore(cls, path): 74 | self = cls.__new__(cls) 75 | Flow.__init__(self, path=path) 76 | self.restore_state() 77 | return self 78 | 79 | def restore_state(self): 80 | state = self.state 81 | for k, v in state.items(): 82 | # if k.lstrip("_") in self._obj_map: 83 | # continue 84 | self.__dict__[k] = v 85 | return self 86 | -------------------------------------------------------------------------------- /src/chromatinhd/models/pred/__init__.py: -------------------------------------------------------------------------------- 1 | from . import model 2 | from . import interpret 3 | from . import plot 4 | -------------------------------------------------------------------------------- /src/chromatinhd/models/pred/interpret/__init__.py: -------------------------------------------------------------------------------- 1 | from .censorers import WindowCensorer, MultiWindowCensorer, SizeCensorer, WindowSizeCensorer 2 | from .regionmultiwindow import RegionMultiWindow 3 | from .regionpairwindow import RegionPairWindow 4 | from .regionsizewindow import RegionSizeWindow 5 | from .performance import Performance 6 | from .size import Size 7 | -------------------------------------------------------------------------------- /src/chromatinhd/models/pred/interpret/performance.py: -------------------------------------------------------------------------------- 1 | import chromatinhd as chd 2 | import numpy as np 3 | import pandas as pd 4 | import xarray as xr 5 | import pickle 6 | import scipy.stats 7 | import tqdm.auto as tqdm 8 | from chromatinhd import get_default_device 9 | 10 | from chromatinhd.flow.objects import Stored, Dataset, StoredDict 11 | 12 | from itertools import product 13 | 14 | 15 | class Performance(chd.flow.Flow): 16 | """ 17 | The train/validation/test performance of a (set of) models. 18 | """ 19 | 20 | scores = chd.flow.SparseDataset() 21 | 22 | folds = None 23 | 24 | @classmethod 25 | def create(cls, folds, transcriptome, fragments, phases=None, overwrite=False, path=None): 26 | self = cls(path=path, reset=overwrite) 27 | 28 | self.folds = folds 29 | self.transcriptome = transcriptome 30 | self.fragments = fragments 31 | 32 | if self.o.scores.exists(self) and not overwrite: 33 | assert self.scores.coords_pointed["region"].equals(fragments.var.index) 34 | 35 | return self 36 | 37 | if phases is None: 38 | phases = ["train", "validation", "test"] 39 | 40 | coords_pointed = { 41 | "region": fragments.regions.var.index, 42 | "fold": pd.Index(range(len(folds)), name="fold"), 43 | "phase": pd.Index(phases, name="phase"), 44 | } 45 | coords_fixed = {} 46 | self.scores = chd.sparse.SparseDataset.create( 47 | self.path / "scores", 48 | variables={ 49 | "cor": { 50 | "dimensions": ("region", "fold", "phase"), 51 | "dtype": np.float32, 52 | "sparse": False, 53 | }, 54 | "scored": { 55 | "dimensions": ("region", "fold"), 56 | "dtype": np.bool_, 57 | "sparse": False, 58 | }, 59 | }, 60 | coords_pointed=coords_pointed, 61 | coords_fixed=coords_fixed, 62 | ) 63 | 64 | return self 65 | 66 | def score( 67 | self, 68 | models, 69 | force=False, 70 | device=None, 71 | regions=None, 72 | pbar=True, 73 | ): 74 | fragments = self.fragments 75 | folds = self.folds 76 | 77 | phases = self.scores.coords["phase"].values 78 | 79 | if regions is None: 80 | regions_oi = fragments.var.index.tolist() if models.regions_oi is None else models.regions_oi 81 | if isinstance(regions_oi, pd.Series): 82 | regions_oi = regions_oi.tolist() 83 | else: 84 | regions_oi = regions 85 | 86 | design = ( 87 | self.scores["scored"] 88 | .sel_xr() 89 | .sel({fragments.var.index.name: regions_oi, "fold": range(len(folds))}) 90 | .to_dataframe(name="scored") 91 | ) 92 |
design["force"] = (~design["scored"]) | force 93 | 94 | design = design.groupby("gene").any() 95 | 96 | regions_oi = design.index[design["force"]] 97 | 98 | if len(regions_oi) == 0: 99 | return self 100 | 101 | for fold_ix in range(len(folds)): 102 | for phase in phases: 103 | fold = folds[fold_ix] 104 | cells_oi = fold[f"cells_{phase}"] 105 | 106 | for region_ix, region_oi in enumerate(regions_oi): 107 | predicted, expected, n_fragments = models.get_prediction( 108 | region=region_oi, 109 | fold_ix=fold_ix, 110 | cell_ixs=cells_oi, 111 | return_raw=True, 112 | fragments=self.fragments, 113 | transcriptome=self.transcriptome, 114 | ) 115 | 116 | cor = chd.utils.paircor(predicted, expected) 117 | 118 | self.scores["cor"][ 119 | region_oi, 120 | fold_ix, 121 | phase, 122 | ] = cor[0] 123 | self.scores["scored"][region_oi, fold_ix] = True 124 | 125 | return self 126 | 127 | @property 128 | def scored(self): 129 | return self.o.scores.exists(self) 130 | -------------------------------------------------------------------------------- /src/chromatinhd/models/pred/model/__init__.py: -------------------------------------------------------------------------------- 1 | from . import better 2 | from . import multiscale 3 | -------------------------------------------------------------------------------- /src/chromatinhd/models/pred/model/loss.py: -------------------------------------------------------------------------------- 1 | def paircor(x, y, dim=0, eps=1e-3): 2 | divisor = (y.std(dim) * x.std(dim)) + eps 3 | cor = ((x - x.mean(dim, keepdims=True)) * (y - y.mean(dim, keepdims=True))).mean(dim) / divisor 4 | return cor 5 | 6 | 7 | def paircor_loss(x, y): 8 | return -paircor(x, y).sum() * 100 9 | 10 | 11 | def region_paircor_loss(x, y): 12 | return -paircor(x, y) * 100 13 | 14 | 15 | def pairzmse(x, y, dim=0, eps=1e-3): 16 | y = (y - y.mean(dim, keepdims=True)) / (y.std(dim, keepdims=True) + eps) 17 | return (y - x).pow(2).mean(dim) 18 | 19 | 20 | def pairzmse_loss(x, y): 21 | return pairzmse(x, y).mean() * 0.1 22 | 23 | 24 | def region_pairzmse_loss(x, y): 25 | return pairzmse(x, y) * 0.1 26 | -------------------------------------------------------------------------------- /src/chromatinhd/models/pred/model/multilinear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from chromatinhd.embedding import FeatureParameter 3 | import math 4 | 5 | 6 | class MultiLinear(torch.nn.Module): 7 | def __init__( 8 | self, 9 | in_features: int, 10 | out_features: int, 11 | n_heads: int, 12 | bias: bool = True, 13 | device=None, 14 | dtype=None, 15 | weight_constructor=None, 16 | bias_constructor=None, 17 | ): 18 | super().__init__() 19 | 20 | self.out_features = out_features 21 | 22 | if bias: 23 | if bias_constructor is None: 24 | 25 | def bias_constructor(shape): 26 | stdv = 1.0 / math.sqrt(shape[-1]) 27 | return torch.empty(shape, device=device, dtype=dtype).uniform_(-stdv, stdv) 28 | 29 | bias = FeatureParameter(n_heads, (out_features,), constructor=bias_constructor) 30 | self.register_module("bias", bias) 31 | else: 32 | self.bias = None 33 | 34 | if weight_constructor is None: 35 | 36 | def weight_constructor(shape): 37 | stdv = 1.0 / math.sqrt(shape[-1]) 38 | return torch.empty(shape, device=device, dtype=dtype).uniform_(-stdv, stdv) 39 | 40 | torch.nn.Linear 41 | 42 | self.weight = FeatureParameter( 43 | n_heads, 44 | ( 45 | in_features, 46 | out_features, 47 | ), 48 | constructor=weight_constructor, 49 | ) 50 | 51 | def forward(self, input: 
torch.Tensor, indptr, regions_oi): 52 | outputs = [] 53 | 54 | if self.bias is not None: 55 | for ix, start, end in zip(regions_oi, indptr[:-1], indptr[1:]): 56 | outputs.append(torch.matmul(input[start:end], self.weight[ix]) + self.bias[ix]) 57 | else: 58 | for ix, start, end in zip(regions_oi, indptr[:-1], indptr[1:]): 59 | outputs.append(torch.matmul(input[start:end], self.weight[ix])) 60 | output = torch.cat(outputs, dim=0) 61 | return output 62 | -------------------------------------------------------------------------------- /src/chromatinhd/models/pred/model/multiscale.py: -------------------------------------------------------------------------------- 1 | from .better import * 2 | -------------------------------------------------------------------------------- /src/chromatinhd/models/pred/plot/__init__.py: -------------------------------------------------------------------------------- 1 | from .predictivity import Predictivity, Pileup, PredictivityBroken, PileupBroken 2 | from .effect import Effect 3 | from .copredictivity import Copredictivity, CopredictivityBroken 4 | 5 | __all__ = ["Predictivity", "Pileup", "Effect", "Copredictivity"] 6 | -------------------------------------------------------------------------------- /src/chromatinhd/models/pred/plot/effect.py: -------------------------------------------------------------------------------- 1 | import polyptich.grid 2 | import numpy as np 3 | 4 | 5 | class Effect(polyptich.grid.Panel): 6 | def __init__(self, plotdata, window, width, show_accessibility=False): 7 | super().__init__((width, 0.5)) 8 | if "position" not in plotdata.columns: 9 | plotdata = plotdata.reset_index() 10 | 11 | ax = self.ax 12 | ax.set_xlim(*window) 13 | 14 | ax.plot( 15 | plotdata["position"], 16 | plotdata["effect"], 17 | color="#333", 18 | lw=1, 19 | ) 20 | ax.fill_between( 21 | plotdata["position"], 22 | plotdata["effect"], 23 | 0, 24 | color="#333", 25 | alpha=0.2, 26 | lw=0, 27 | ) 28 | 29 | ax.set_ylabel( 30 | "Effect", 31 | rotation=0, 32 | ha="right", 33 | va="center", 34 | ) 35 | 36 | ax.set_xticks([]) 37 | ax.invert_yaxis() 38 | ymax = plotdata["effect"].abs().max() 39 | ax.set_ylim(-ymax, ymax) 40 | 41 | # change vertical alignment of last y tick to bottom 42 | ax.set_yticks([0, ax.get_ylim()[1]]) 43 | ax.get_yticklabels()[-1].set_verticalalignment("top") 44 | ax.get_yticklabels()[0].set_verticalalignment("bottom") 45 | 46 | # vline at tss 47 | ax.axvline(0, color="#888888", lw=0.5, zorder=-1, dashes=(2, 2)) 48 | 49 | @classmethod 50 | def from_RegionMultiWindow(cls, RegionMultiWindow, gene, width, show_accessibility=False): 51 | plotdata = RegionMultiWindow.get_plotdata(gene).reset_index() 52 | window = np.array([plotdata["position"].min(), plotdata["position"].max()]) 53 | return cls(plotdata, window, width, show_accessibility=show_accessibility) 54 | -------------------------------------------------------------------------------- /src/chromatinhd/models/pred/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer, SharedTrainer, TrainerPerFeature 2 | -------------------------------------------------------------------------------- /src/chromatinhd/models/pret/model/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import better 2 | -------------------------------------------------------------------------------- /src/chromatinhd/models/pret/model/loss.py: -------------------------------------------------------------------------------- 1 | def paircor(x, y, dim=0, eps=1e-3): 2 | divisor = (y.std(dim) * x.std(dim)) + eps 3 | cor = ((x - x.mean(dim, keepdims=True)) * (y - y.mean(dim, keepdims=True))).mean(dim) / divisor 4 | return cor 5 | 6 | 7 | def paircor_loss(x, y): 8 | return -paircor(x, y).mean() * 100 9 | 10 | 11 | def region_paircor_loss(x, y): 12 | return -paircor(x, y) * 100 13 | 14 | 15 | def pairzmse(x, y, dim=0, eps=1e-3): 16 | y = (y - y.mean(dim, keepdims=True)) / (y.std(dim, keepdims=True) + eps) 17 | return (y - x).pow(2).mean(dim) 18 | 19 | 20 | def pairzmse_loss(x, y): 21 | return pairzmse(x, y).mean() * 0.1 22 | 23 | 24 | def region_pairzmse_loss(x, y): 25 | return pairzmse(x, y) * 0.1 26 | -------------------------------------------------------------------------------- /src/chromatinhd/models/pret/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer 2 | -------------------------------------------------------------------------------- /src/chromatinhd/optim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import itertools 3 | 4 | 5 | class SparseDenseAdam(torch.optim.Optimizer): 6 | """ 7 | Optimize both sparse and dense parameters using Adam 8 | """ 9 | 10 | def __init__(self, parameters_sparse, parameters_dense, lr=1e-3, weight_decay=0.0, **kwargs): 11 | if len(parameters_sparse) == 0: 12 | self.optimizer_sparse = None 13 | else: 14 | self.optimizer_sparse = torch.optim.SparseAdam(parameters_sparse, lr=lr, **kwargs) 15 | if len(parameters_dense) == 0: 16 | self.optimizer_dense = None 17 | else: 18 | self.optimizer_dense = torch.optim.Adam(parameters_dense, lr=lr, weight_decay=weight_decay, **kwargs) 19 | 20 | def zero_grad(self): 21 | if self.optimizer_sparse is not None: 22 | self.optimizer_sparse.zero_grad() 23 | if self.optimizer_dense is not None: 24 | self.optimizer_dense.zero_grad() 25 | 26 | def step(self): 27 | if self.optimizer_sparse is not None: 28 | self.optimizer_sparse.step() 29 | if self.optimizer_dense is not None: 30 | self.optimizer_dense.step() 31 | 32 | @property 33 | def param_groups(self): 34 | return itertools.chain( 35 | self.optimizer_dense.param_groups if self.optimizer_dense is not None else [], 36 | self.optimizer_sparse.param_groups if self.optimizer_sparse is not None else [], 37 | ) 38 | 39 | 40 | class AdamPerFeature(torch.optim.Optimizer): 41 | def __init__(self, parameters, n_features, lr=1e-3, weight_decay=0.0, **kwargs): 42 | self.n_features = n_features 43 | 44 | self.adams = [] 45 | for i in range(n_features): 46 | parameters_features = [p[i] for p in parameters] 47 | self.adams.append(torch.optim.Adam(parameters_features, lr=lr, weight_decay=weight_decay, **kwargs)) 48 | 49 | def zero_grad(self, feature_ixs): 50 | for i in feature_ixs: 51 | self.adams[i].zero_grad() 52 | 53 | def step(self, feature_ixs): 54 | for i in feature_ixs: 55 | self.adams[i].step() 56 | 57 | @property 58 | def param_groups(self): 59 | return itertools.chain(*[adam.param_groups for adam in self.adams]) 60 | 61 | 62 | class SGDPerFeature(torch.optim.Optimizer): 63 | def __init__(self, parameters, n_features, lr=1e-3, weight_decay=0.0, **kwargs): 64 | self.n_features = n_features 65 | 66 | self.optimizers = [] 67 | for i in
range(n_features): 68 | parameters_features = [p[i] for p in parameters] 69 | self.optimizers.append( 70 | torch.optim.SGD(parameters_features, lr=lr, weight_decay=weight_decay, momentum=0.5, **kwargs) 71 | ) 72 | 73 | def zero_grad(self, feature_ixs): 74 | for i in feature_ixs: 75 | self.optimizers[i].zero_grad() 76 | 77 | def step(self, feature_ixs): 78 | for i in feature_ixs: 79 | self.optimizers[i].step() 80 | 81 | @property 82 | def param_groups(self): 83 | return itertools.chain(*[sgd.param_groups for sgd in self.optimizers]) 84 | -------------------------------------------------------------------------------- /src/chromatinhd/plot/__init__.py: -------------------------------------------------------------------------------- 1 | from .tickers import distance_ticker, gene_ticker, DistanceFormatter, format_distance, round_significant 2 | from . import genome 3 | 4 | from .patch import replace_patch 5 | from .matshow45 import matshow45 6 | from . import tracks 7 | 8 | 9 | -------------------------------------------------------------------------------- /src/chromatinhd/plot/genome/__init__.py: -------------------------------------------------------------------------------- 1 | from .genes import Genes 2 | -------------------------------------------------------------------------------- /src/chromatinhd/plot/matshow45.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib as mpl 3 | 4 | 5 | def matshow45(ax, series, radius=None, cmap=None, norm=None): 6 | """ 7 | fig, ax = plt.subplots() 8 | plotdata = pd.DataFrame( 9 | np.random.rand(10, 10), 10 | index=np.arange(10), 11 | columns=np.arange(10), 12 | ).stack() 13 | ax.set_aspect(1) 14 | matshow45(ax, plotdata) 15 | """ 16 | offsets = [] 17 | colors = [] 18 | 19 | assert len(series.index.names) == 2 20 | x = series.index.get_level_values(0) 21 | y = series.index.get_level_values(1) 22 | 23 | if radius is None: 24 | radius = np.diff(y)[0] / 2 25 | 26 | centerxs = x + (y - x) / 2 27 | centerys = (y - x) / 2 28 | 29 | xlim = [ 30 | centerxs.min() - radius, 31 | centerxs.max() + radius, 32 | ] 33 | ax.set_xlim(xlim) 34 | 35 | ylim = [ 36 | centerys.unique().min(), 37 | centerys.unique().max(), 38 | ] 39 | ax.set_ylim(ylim) 40 | 41 | if norm is None: 42 | norm = mpl.colors.Normalize(vmin=series.min(), vmax=series.max()) 43 | 44 | if cmap is None: 45 | cmap = mpl.cm.get_cmap() 46 | 47 | vertices = [] 48 | 49 | for centerx, centery, value in zip(centerxs, centerys, series.values): 50 | center = np.array([centerx, centery]) 51 | offsets.append(center) 52 | colors.append(cmap(norm(value))) 53 | 54 | vertices.append( 55 | [ 56 | center + np.array([radius * 1.1, 0]), 57 | center + np.array([0, radius * 1.1]), 58 | center + np.array([-radius * 1.1, 0]), 59 | center + np.array([0, -radius * 1.1]), 60 | ] 61 | ) 62 | vertices = np.array(vertices) 63 | collection = mpl.collections.PolyCollection( 64 | vertices, 65 | ec=None, 66 | lw=0, 67 | fc=colors, 68 | ) 69 | 70 | ax.add_collection(collection) 71 | 72 | for x in [xlim[1]]: 73 | x2 = x 74 | x1 = x2 + (xlim[0] - x2) / 2 75 | y2 = 0 76 | y1 = x2 - x1 77 | 78 | # draw the border lines along the two upper edges 79 | color = "black" 80 | lw = 0.8 81 | zorder = 10 82 | ax.plot( 83 | [x1 + radius * 1.2, x2 + radius * 1.2], 84 | [y1, y2], 85 | zorder=zorder, 86 | color=color, 87 | lw=lw, 88 | ) 89 | 90 | x1, x2 = (xlim[1] - xlim[0]) / 2 - x2, (xlim[1] - xlim[0]) / 2 - x1 91 | ax.plot( 92 | [-x1 - radius * 
1.2], 93 | [y1, y2], 94 | zorder=zorder, 95 | color=color, 96 | lw=lw, 97 | ) 98 | ax.axis("off") 99 | -------------------------------------------------------------------------------- /src/chromatinhd/plot/patch.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | 4 | phases = pd.DataFrame({"phase": ["train", "validation"], "color": ["#888888", "tomato"]}).set_index("phase") 5 | 6 | 7 | colors = ["#0074D9", "#FF4136", "#FF851B", "#2ECC40", "#39CCCC", "#85144b"] 8 | 9 | 10 | def replace_patch(ax, patch, points=10, ha="center", va="center"): 11 | """ 12 | Replaces a patch, often an axis label, with a new rectangular axes whose size is given in figure points. Example usage: 13 | 14 | 15 | fig, ax = plt.subplots() 16 | ax.set_title("100%, (0.5,1-0.3,.3,.3)") 17 | ax.plot([0, 2], [0, 2]) 18 | 19 | for l in ax.get_xticklabels(): 20 | ax1 = replace_patch(ax, l) 21 | ax1.plot([0, 1], [0, 1]) 22 | """ 23 | fig = ax.get_figure() 24 | 25 | fig.draw_without_rendering() 26 | 27 | w, h = fig.transFigure.inverted().transform([[1, 1]]).ravel() * points 28 | bbox = fig.transFigure.inverted().transform(patch.get_window_extent()) 29 | dw = bbox[1, 0] - bbox[0, 0] 30 | dh = bbox[1, 1] - bbox[0, 1] 31 | x = bbox[0, 0] + dw / 2 32 | y = bbox[0, 1] + dh / 2 33 | 34 | if ha == "left": 35 | x += 0 36 | elif ha == "right": 37 | x -= w - dw / 2 38 | elif ha == "center": 39 | x -= w / 2 40 | 41 | if va == "bottom": 42 | y += 0 43 | elif va == "top": 44 | y -= h - dh / 2 45 | elif va == "center": 46 | y -= h / 2 47 | 48 | ax1 = fig.add_axes([x, y, w, h]) 49 | # ax1.axis("off") 50 | 51 | return ax1 52 | -------------------------------------------------------------------------------- /src/chromatinhd/plot/quasirandom.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | def digits2number(digits, base=2, fractional=False): 6 | powers = np.arange(len(digits)) 7 | out = (digits * base**powers).sum() 8 | if fractional: 9 | out = out / base ** (len(digits)) 10 | return out 11 | 12 | 13 | def number2digits(n, base=2): 14 | if n == 0: 15 | return np.array([]) 16 | nDigits = np.ceil(np.log(n + 1) / np.log(base)) 17 | powers = base ** (np.arange(nDigits + 1)) 18 | out = np.diff(n % powers) / powers[:-1] 19 | return out 20 | 21 | 22 | def vanDerCorput(n, base=2, start=1): 23 | out = np.array([digits2number(number2digits(ii, base)[::-1], base, True) for ii in range(1, n + start)]) 24 | return out 25 | 26 | 27 | def offset( 28 | y, 29 | maxLength=None, 30 | method="quasirandom", 31 | nbins=20, 32 | adjust=1, 33 | bw_method="scott", 34 | max_density=None, 35 | ): 36 | if len(y) == 1: 37 | return [0] 38 | 39 | if isinstance(y, pd.Series): 40 | y = y.values 41 | 42 | if nbins is None: 43 | if method in ["pseudorandom", "quasirandom"]: 44 | nbins = 2**10 45 | else: 46 | nbins = int(max(2, np.ceil(len(y) / 5))) 47 | 48 | if maxLength is None: 49 | subgroup_width = 1 50 | else: 51 | subgroup_width = np.sqrt(len(y) / maxLength) 52 | 53 | # from sklearn.neighbors import KernelDensity 54 | # kde = KernelDensity(kernel='gaussian').fit(y.reshape(-1, 1)) 55 | # pointDensities = np.exp(np.array(kde.score_samples(y.reshape(-1, 1)))) 56 | 57 | import scipy.stats 58 | 59 | kernel = scipy.stats.gaussian_kde(y, bw_method=bw_method) 60 | pointDensities = kernel(y) 61 | 62 | if max_density is None: 63 | max_density = pointDensities.max() 64 | 65 | pointDensities = pointDensities / max_density 66 | 
if method == "quasirandom": 68 | offset = np.array(vanDerCorput(len(y)))[np.argsort(y)] 69 | elif method == "pseudorandom": 70 | offset = np.random.uniform(size=len(y)) 71 | 72 | out = (offset - 0.5) * 2 * pointDensities * subgroup_width * 0.5 73 | 74 | return out 75 | 76 | 77 | def offsetr(y, **kwargs): 78 | from rpy2.robjects import pandas2ri 79 | import rpy2.robjects as ro 80 | from rpy2.robjects.packages import importr 81 | from rpy2.robjects import numpy2ri 82 | from rpy2.robjects import default_converter 83 | 84 | ro.numpy2ri.activate() 85 | vipor = ro.packages.importr("vipor") 86 | return vipor.offsetSingleGroup(y, method="quasi", **kwargs) * 0.5 87 | -------------------------------------------------------------------------------- /src/chromatinhd/plot/tickers.py: -------------------------------------------------------------------------------- 1 | import matplotlib.ticker as ticker 2 | 3 | import decimal 4 | 5 | import functools 6 | import math 7 | 8 | 9 | def count_zeros(value): 10 | decimal_value = str(decimal.Decimal((value))) 11 | if "." in decimal_value: 12 | count = len(decimal_value.split(".")[1]) 13 | else: 14 | count = 0 15 | return count 16 | 17 | 18 | def round_significant(value, significant=2): 19 | """ 20 | Round to a significant number of digits, including trailing zeros. 21 | For example round_significant(15400, 2) = 15000 22 | """ 23 | 24 | if value == 0: 25 | return 0 26 | 27 | n = math.log10(abs(value)) 28 | return round(value, significant - int(n) - 1) 29 | 30 | 31 | def format_distance(value, tick_pos=None, add_sign=True): 32 | abs_value = int(abs(value)) 33 | if abs_value >= 1000000: 34 | # zeros = 0 35 | zeros = len(str(abs_value).rstrip("0")) - 1 36 | abs_value = abs_value / 1000000 37 | suffix = "mb" 38 | elif abs_value >= 1000: 39 | zeros = 0 40 | # zeros = len(str(abs_value).rstrip("0")) - 1 41 | abs_value = abs_value / 1000 42 | suffix = "kb" 43 | elif abs_value == 0: 44 | # zeros = 0 45 | return "TSS" 46 | else: 47 | zeros = 0 48 | suffix = "b" 49 | 50 | formatted_value = ("{abs_value:." + str(zeros) + "f}{suffix}").format(abs_value=abs_value, suffix=suffix) 51 | 52 | if not add_sign: 53 | return formatted_value 54 | return f"-{formatted_value}" if value < 0 else f"+{formatted_value}" 55 | 56 | 57 | gene_ticker = ticker.FuncFormatter(format_distance) 58 | 59 | 60 | def custom_formatter(value, tick_pos, base=1): 61 | if base != 1: 62 | value = value / base 63 | 64 | abs_value = abs(value) 65 | if abs_value >= 1000000: 66 | abs_value = abs_value / 1000000 67 | suffix = "mb" 68 | elif abs_value >= 1000: 69 | abs_value = abs_value / 1000 70 | suffix = "kb" 71 | elif abs_value == 0: 72 | return "0" 73 | else: 74 | suffix = "b" 75 | 76 | zeros = count_zeros(abs_value) 77 | 78 | formatted_value = ("{abs_value:." + str(zeros) + "f}{suffix}").format(abs_value=abs_value, suffix=suffix) 79 | return formatted_value 80 | 81 | 82 | distance_ticker = ticker.FuncFormatter(custom_formatter) 83 | 84 | 85 | def DistanceFormatter(base=1): 86 | return ticker.FuncFormatter(functools.partial(custom_formatter, base=base)) 87 | 88 | 89 | def lighten_color(color, amount=0.5): 90 | """ 91 | Lightens the given color by multiplying (1-luminosity) by the given amount. 92 | Input can be matplotlib color string, hex string, or RGB tuple. 
93 | 94 | Examples: 95 | >>> lighten_color('g', 0.3) 96 | >>> lighten_color('#F034A3', 0.6) 97 | >>> lighten_color((.3,.55,.1), 0.5) 98 | """ 99 | import matplotlib.colors as mc 100 | import colorsys 101 | 102 | try: 103 | c = mc.cnames[color] 104 | except KeyError: 105 | c = color 106 | c = colorsys.rgb_to_hls(*mc.to_rgb(c)) 107 | return colorsys.hls_to_rgb(c[0], 1 - amount * (1 - c[1]), c[2]) 108 | -------------------------------------------------------------------------------- /src/chromatinhd/plot/tracks/__init__.py: -------------------------------------------------------------------------------- 1 | from .tracks import TracksBroken 2 | -------------------------------------------------------------------------------- /src/chromatinhd/utils/ecdf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def ecdf(data): 5 | sorted_data = np.sort(data) 6 | y = np.arange(1, len(sorted_data) + 1) / len(sorted_data) 7 | 8 | sorted_data = np.vstack([sorted_data - 1e-5, sorted_data]).T.flatten() 9 | y = np.vstack([np.hstack([[0], y[:-1]]), y]).T.flatten() 10 | return sorted_data, y 11 | 12 | 13 | def weighted_ecdf(data, weights): 14 | sorting = np.argsort(data) 15 | sorted_data = data[sorting] 16 | y = np.cumsum(weights[sorting] / weights.sum()) 17 | 18 | sorted_data = np.vstack([sorted_data - 1e-5, sorted_data]).T.flatten() 19 | y = np.vstack([np.hstack([[0], y[:-1]]), y]).T.flatten() 20 | return sorted_data, y 21 | 22 | 23 | def area_under_ecdf(list1): 24 | x1, y1 = ecdf(list1) 25 | area = np.trapz(y1, x1) 26 | return area 27 | 28 | 29 | def area_between_ecdfs(list1, list2): 30 | x1, y1 = ecdf(list1) 31 | x2, y2 = ecdf(list2) 32 | 33 | combined_x = np.concatenate([[-0.001], np.sort(np.unique(np.concatenate((x1, x2))))]) 34 | 35 | y1_interp = np.interp(combined_x, x1, y1, left=0, right=1) 36 | y2_interp = np.interp(combined_x, x2, y2, left=0, right=1) 37 | 38 | ecdf_diff = y1_interp - y2_interp 39 | area = np.trapz(ecdf_diff, combined_x) 40 | return area 41 | 42 | 43 | def relative_area_between_ecdfs(list1, list2): 44 | x1, y1 = ecdf(list1) 45 | area = area_between_ecdfs(list1, list2) / np.trapz(y1, x1) 46 | return area 47 | -------------------------------------------------------------------------------- /src/chromatinhd/utils/interleave.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def interleave(x, repeats=np.array([1, 2, 4, 8, 16])): 6 | assert isinstance(repeats, np.ndarray) 7 | out_shape = int(x.shape[-1] // (1 / repeats).sum()) 8 | out = torch.zeros((*x.shape[:-1], out_shape), dtype=x.dtype, device=x.device) 9 | ks = out_shape // repeats 10 | i = 0 11 | for k, r in zip(ks, repeats): 12 | out += torch.repeat_interleave(x[..., i : i + k], r) 13 | i += k 14 | return out 15 | 16 | 17 | def deinterleave(y, repeats=np.array([1, 2, 4, 8, 16])): 18 | assert isinstance(repeats, np.ndarray) 19 | x = [] 20 | for r in repeats[::-1]: 21 | x_ = y.reshape((*y.shape[:-1], -1, r)).mean(dim=-1) 22 | y = y - torch.repeat_interleave(x_, r) 23 | x.append(x_) 24 | return torch.cat(x[::-1], -1) 25 | -------------------------------------------------------------------------------- /src/chromatinhd/utils/intervals.py: -------------------------------------------------------------------------------- 1 | def interval_contains_inclusive(x, y): 2 | """ 3 | Determines, for each interval in x, whether it overlaps (inclusively) with any interval in y 4 | """ 5 | contained = ~((y[:, 1] < x[:, 
0][:, None]) | (y[:, 0] > x[:, 1][:, None])) 6 | return contained.any(1) 7 | -------------------------------------------------------------------------------- /src/chromatinhd/utils/numpy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def indices_to_indptr(x, n, dtype=np.int32): 5 | return np.pad(np.cumsum(np.bincount(x, minlength=n), 0, dtype=dtype), (1, 0)) 6 | 7 | 8 | ind2ptr = indices_to_indptr 9 | 10 | 11 | def indptr_to_indices(x): 12 | n = len(x) - 1 13 | return np.repeat(np.arange(n), np.diff(x)) 14 | 15 | 16 | ptr2ind = indptr_to_indices 17 | 18 | 19 | def indices_to_indptr_chunked(x, n, dtype=np.int32, batch_size=10e3): 20 | counts = np.zeros(n + 1, dtype=dtype) 21 | cur_value = 0 22 | for a, b in zip( 23 | np.arange(0, len(x), batch_size, dtype=int), np.arange(batch_size, len(x) + batch_size, batch_size, dtype=int) 24 | ): 25 | x_ = x[a:b] 26 | assert (x_ >= cur_value).all() 27 | bincount = np.bincount(x_ - cur_value) 28 | counts[(cur_value + 1) : (cur_value + len(bincount) + 1)] += bincount 29 | cur_value = x_[-1] 30 | indptr = np.cumsum(counts, dtype=dtype) 31 | return indptr 32 | 33 | 34 | 35 | def interpolate_1d(x: np.ndarray, xp: np.ndarray, fp: np.ndarray) -> np.ndarray: 36 | a = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1]) 37 | b = fp[:-1] - (a * xp[:-1]) 38 | 39 | indices = np.searchsorted(xp, x, side="left") - 1 40 | indices = np.clip(indices, 0, a.shape[0] - 1) 41 | 42 | slope = a[indices] 43 | intercept = b[indices] 44 | return x * slope + intercept 45 | -------------------------------------------------------------------------------- /src/chromatinhd/utils/scanpy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scanpy as sc 3 | import pandas as pd 4 | 5 | 6 | # Define cluster score for all markers 7 | def evaluate_partition(anndata, marker_dict, gene_symbol_key=None, partition_key="louvain_r1"): 8 | # Inputs: 9 | # anndata - An AnnData object containing the data set and a partition 10 | # marker_dict - A dictionary with cell-type markers. The markers should be stored as anndata.var_names or 11 | # an anndata.var field with the key given by the gene_symbol_key input 12 | # gene_symbol_key - The key for the anndata.var field with gene IDs or names that correspond to the marker 13 | # genes 14 | # partition_key - The key for the anndata.obs field where the cluster IDs are stored. The default is 15 | # 'louvain_r1' 16 | 17 | # Test inputs 18 | if partition_key not in anndata.obs.columns.values: 19 | print("KeyError: The partition key was not found in the passed AnnData object.") 20 | print(" Have you done the clustering?
If so, please pass the cluster IDs with the AnnData object!") 21 | raise KeyError(partition_key) 22 | 23 | if (gene_symbol_key is not None) and (gene_symbol_key not in anndata.var.columns.values): 24 | print("KeyError: The provided gene symbol key was not found in the passed AnnData object.") 25 | print(" Check that your cell type markers are given in a format that your anndata object knows!") 26 | raise KeyError(gene_symbol_key) 27 | 28 | if gene_symbol_key: 29 | gene_ids = anndata.var[gene_symbol_key] 30 | else: 31 | gene_ids = anndata.var_names 32 | 33 | clusters = np.unique(anndata.obs[partition_key]) 34 | n_clust = len(clusters) 35 | n_groups = len(marker_dict) 36 | 37 | marker_res = np.zeros((n_groups, n_clust)) 38 | z_scores = sc.pp.scale(anndata, copy=True) 39 | 40 | i = 0 41 | for group in marker_dict: 42 | # Find the corresponding columns and get their mean expression in the cluster 43 | j = 0 44 | for clust in clusters: 45 | cluster_cells = np.in1d(z_scores.obs[partition_key], clust) 46 | marker_genes = np.in1d(gene_ids, marker_dict[group]) 47 | marker_res[i, j] = z_scores.X[np.ix_(cluster_cells, marker_genes)].mean() 48 | j += 1 49 | i += 1 50 | 51 | variances = np.nanvar(marker_res, axis=0) 52 | if np.all(np.isnan(variances)): 53 | print("No variances could be computed, check if your cell markers are in the data set.") 54 | print("Maybe the cell marker IDs do not correspond to your gene_symbol_key input or the var_names") 55 | raise ValueError("no variances could be computed") 56 | 57 | marker_res_df = pd.DataFrame(marker_res, columns=clusters, index=marker_dict.keys()) 58 | 59 | # Return the mean marker expression per cluster 60 | return marker_res_df 61 | -------------------------------------------------------------------------------- /src/chromatinhd/utils/testing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | 4 | 5 | def fdr(p_vals): 6 | from scipy.stats import rankdata 7 | 8 | ranked_p_values = rankdata(p_vals) 9 | fdr = p_vals * len(p_vals) / ranked_p_values 10 | fdr[fdr > 1] = 1 11 | 12 | return fdr 13 | 14 | 15 | def autocorrelation(x): 16 | """ 17 | Calculate the lag-1 autocorrelation of a sequence. 18 | 19 | Args: 20 | x (list): A sequence of numbers. 21 | 22 | Returns: 23 | float: Autocorrelation value. 24 | """ 25 | n = len(x) 26 | x_bar = np.mean(x) 27 | numerator = np.sum([(x[i] - x_bar) * (x[i - 1] - x_bar) for i in range(1, n)]) 28 | denominator = np.sum([(x[i] - x_bar) ** 2 for i in range(n)]) 29 | 30 | return numerator / denominator 31 | 32 | 33 | # def repeated_kfold_corrected_t_test(performance_A, performance_B, k, num_repeats, alpha=0.05): 34 | # """ 35 | # Perform the corrected t-test between two learning algorithms A and B for repeated K-fold cross-validation. 36 | 37 | # Args: 38 | # performance_A (list): A list of performance scores for algorithm A. 39 | # performance_B (list): A list of performance scores for algorithm B. 40 | # k (int): Number of folds in the cross-validation. 41 | # num_repeats (int): Number of times the K-fold cross-validation was repeated. 42 | # alpha (float): Significance level, default is 0.05. 43 | 44 | # Returns: 45 | # bool: True if there is a significant difference, False otherwise. 46 | # float: t-statistic value. 47 | # float: p-value.
48 | # """ 49 | 50 | # if len(performance_A) != len(performance_B) or len(performance_A) != k * num_repeats: 51 | # raise ValueError("Performance scores for each algorithm 52 | # should have the same length and match k * num_repeats.") 53 | 54 | # n = k * num_repeats 55 | 56 | # d = [performance_A[i] - performance_B[i] for i in range(n)] 57 | # d_bar = np.mean(d) 58 | # s_d = np.std(d, ddof=1) 59 | 60 | # rho = autocorrelation(d) 61 | # effective_sample_size = n * (1 - rho) / (1 + rho) 62 | 63 | # t_statistic = d_bar / (s_d / np.sqrt(effective_sample_size)) 64 | # degrees_of_freedom = effective_sample_size - 1 65 | # p_value = 2 * (1 - stats.t.cdf(abs(t_statistic), degrees_of_freedom)) 66 | 67 | # reject_null_hypothesis = p_value < alpha 68 | 69 | # return reject_null_hypothesis, t_statistic, p_value 70 | 71 | 72 | def repeated_kfold_corrected_t_test(diff, r, k, n_train, n_test): 73 | diff_corrected = 1 / (k * r) * diff.sum() 74 | variance = 1 / (k * r - 1) * ((diff - diff_corrected) ** 2).sum() 75 | t = 1 / (k * r) * diff_corrected / np.sqrt((1 / (k * r) + n_test / n_train) * variance) 76 | 77 | scipy.stats.t.cdf(t, r * k - 1) 78 | return t 79 | -------------------------------------------------------------------------------- /src/chromatinhd/utils/timing.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import time 3 | 4 | 5 | class catchtime(object): 6 | def __init__(self, dict, name): 7 | self.name = name 8 | self.dict = dict 9 | 10 | def __enter__(self): 11 | self.t = time.time() 12 | return self 13 | 14 | def __exit__(self, type, value, traceback): 15 | self.t = time.time() - self.t 16 | self.dict[self.name] += self.t 17 | 18 | 19 | class timer(object): 20 | def __init__(self): 21 | self.times = collections.defaultdict(float) 22 | 23 | def catch(self, name): 24 | return catchtime(self.times, name) 25 | -------------------------------------------------------------------------------- /src/chromatinhd/utils/torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def interpolate_1d(x: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor) -> torch.Tensor: 6 | a = (fp[..., 1:] - fp[..., :-1]) / (xp[..., 1:] - xp[..., :-1]) 7 | b = fp[..., :-1] - (a.mul(xp[..., :-1])) 8 | 9 | indices = torch.searchsorted(xp.contiguous(), x.contiguous(), right=False) - 1 10 | indices = torch.clamp(indices, 0, a.shape[-1] - 1) 11 | 12 | slope = a.index_select(a.ndim - 1, indices) 13 | intercept = b.index_select(a.ndim - 1, indices) 14 | return x * slope + intercept 15 | 16 | def interpolate_0d(x: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor) -> torch.Tensor: 17 | a = (fp[..., 1:] - fp[..., :-1]) / (xp[..., 1:] - xp[..., :-1]) 18 | b = fp[..., :-1] - (a.mul(xp[..., :-1])) 19 | 20 | indices = torch.searchsorted(xp.contiguous(), x.contiguous(), right=False) - 1 21 | indices = torch.clamp(indices, 0, a.shape[-1] - 1) 22 | 23 | return b.index_select(a.ndim - 1, indices) 24 | 25 | 26 | def indices_to_indptr(x: torch.Tensor, n: int) -> torch.Tensor: 27 | return torch.nn.functional.pad(torch.cumsum(torch.bincount(x, minlength=n), 0), (1, 0)) 28 | 29 | 30 | ind2ptr = indices_to_indptr 31 | 32 | 33 | def indptr_to_indices(x: torch.Tensor) -> torch.Tensor: 34 | return torch.repeat_interleave(torch.arange(len(x) - 1), torch.diff(x)) 35 | 36 | 37 | ptr2ind = indptr_to_indices 38 | -------------------------------------------------------------------------------- 
/tests/_test_sparse.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import chromatinhd as chd 3 | import numpy as np 4 | 5 | 6 | class TestSparse: 7 | def test_simple(self): 8 | path = "./hello" 9 | coords_pointed = {"a": pd.Index(["a", "b", "c"]), "b": pd.Index(["A", "B"])} 10 | coords_fixed = {"c": pd.Index(["1", "2", "3"]), "d": pd.Index(["100", "200"])} 11 | variables = {"x": ("a", "b", "c"), "y": ("a", "c", "d")} 12 | 13 | dataset = chd.sparse.SparseDataset2.create(path, variables, coords_pointed, coords_fixed) 14 | 15 | dataset["x"]["a", "B"] = np.array([1, 2, 3]) 16 | dataset["x"]["c", "B"] = np.array([1, 0, 4]) 17 | dataset["y"]["a"] = np.array([[1, 0], [2, 3], [1, 2]]) 18 | 19 | dataset["x"]._read_data()[:] 20 | 21 | assert (dataset["x"]["a", "B"] == np.array([1.0, 2.0, 3.0])).all() 22 | assert (dataset["x"]["a", "A"] == np.array([0.0, 0.0, 0.0])).all() 23 | assert (dataset["x"]["b", "B"] == np.array([0.0, 0.0, 0.0])).all() 24 | assert (dataset["x"]["c", "B"] == np.array([1, 0, 4])).all() 25 | assert (dataset["y"]["a"] == np.array([[1, 0], [2, 3], [1, 2]])).all() 26 | -------------------------------------------------------------------------------- /tests/biomart/conftest.py: -------------------------------------------------------------------------------- 1 | import chromatinhd as chd 2 | import pytest 3 | 4 | import pathlib 5 | 6 | 7 | @pytest.fixture(scope="module") 8 | def dataset_grch38(): 9 | return chd.biomart.Dataset.from_genome("GRCh38") 10 | -------------------------------------------------------------------------------- /tests/biomart/tss_test.py: -------------------------------------------------------------------------------- 1 | import chromatinhd as chd 2 | 3 | 4 | class TestGetCanonicalTranscripts: 5 | def test_simple(self, dataset_grch38): 6 | pass 7 | # transcripts = chd.biomart.tss.get_canonical_transcripts(biomart_dataset = dataset_grch38) 8 | 9 | # assert transcripts.shape[0] > 1000 10 | 11 | 12 | class TestGetExons: 13 | def test_simple(self, dataset_grch38): 14 | pass 15 | # exons = chd.biomart.tss.get_exons(biomart_dataset = dataset_grch38, chrom = "chr1", start = 1000, end = 2000) 16 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pathlib 3 | import chromatinhd as chd 4 | 5 | 6 | @pytest.fixture(scope="session") 7 | def example_dataset_folder(tmp_path_factory): 8 | example_dataset_folder = tmp_path_factory.mktemp("example") 9 | 10 | import pkg_resources 11 | import shutil 12 | 13 | DATA_PATH = pathlib.Path(pkg_resources.resource_filename("chromatinhd", "data/examples/pbmc10ktiny/")) 14 | 15 | # copy all files from data path to dataset folder 16 | for file in DATA_PATH.iterdir(): 17 | shutil.copy(file, example_dataset_folder / file.name) 18 | return example_dataset_folder 19 | 20 | 21 | @pytest.fixture(scope="session") 22 | def example_transcriptome(example_dataset_folder): 23 | import scanpy as sc 24 | 25 | adata = sc.read(example_dataset_folder / "transcriptome.h5ad") 26 | transcriptome = chd.data.Transcriptome.from_adata(adata, path=example_dataset_folder / "transcriptome") 27 | return transcriptome 28 | 29 | 30 | @pytest.fixture(scope="session") 31 | def example_clustering(example_dataset_folder, example_transcriptome): 32 | clustering = chd.data.Clustering.from_labels( 33 | example_transcriptome.adata.obs["celltype"], 34 | 
path=example_dataset_folder / "clustering", 35 | ) 36 | return clustering 37 | 38 | 39 | @pytest.fixture(scope="session") 40 | def example_regions(example_dataset_folder, example_transcriptome): 41 | biomart_dataset = chd.biomart.Dataset.from_genome("GRCh38") 42 | canonical_transcripts = chd.biomart.get_canonical_transcripts(biomart_dataset, example_transcriptome.var.index) 43 | regions = chd.data.Regions.from_transcripts( 44 | canonical_transcripts, 45 | path=example_dataset_folder / "regions", 46 | window=[-10000, 10000], 47 | ) 48 | return regions 49 | 50 | 51 | @pytest.fixture(scope="session") 52 | def example_fragments(example_dataset_folder, example_transcriptome, example_regions): 53 | fragments = chd.data.Fragments.from_fragments_tsv( 54 | example_dataset_folder / "fragments.tsv.gz", 55 | example_regions, 56 | obs=example_transcriptome.obs, 57 | path=example_dataset_folder / "fragments", 58 | ) 59 | fragments.create_regionxcell_indptr() 60 | return fragments 61 | 62 | 63 | @pytest.fixture(scope="session") 64 | def example_folds(example_dataset_folder, example_fragments): 65 | folds = chd.data.folds.Folds(example_dataset_folder / "random_5fold") 66 | folds.sample_cells(example_fragments, n_folds=5, n_repeats=1) 67 | return folds 68 | -------------------------------------------------------------------------------- /tests/loaders/test_fragment_helpers.py: -------------------------------------------------------------------------------- 1 | import chromatinhd.loaders.fragments_helpers 2 | import numpy as np 3 | 4 | 5 | def test_multiple_arange(): 6 | a = np.array([10, 50, 60], dtype=np.int64) 7 | b = np.array([20, 52, 62], dtype=np.int64) 8 | ix = np.zeros(100, dtype=np.int64) 9 | local_cellxregion_ix = np.zeros(100, dtype=np.int64) 10 | 11 | n_fragments = chromatinhd.loaders.fragments_helpers.multiple_arange(a, b, ix, local_cellxregion_ix) 12 | ix.resize(n_fragments, refcheck=False) 13 | local_cellxregion_ix.resize(n_fragments, refcheck=False) 14 | 15 | assert ix.shape == (14,) 16 | assert local_cellxregion_ix.shape == (14,) 17 | assert local_cellxregion_ix.max() < 3 18 | assert np.all(ix == np.array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 50, 51, 60, 61], dtype=np.int64)) 19 | assert np.all(local_cellxregion_ix == np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 2], dtype=np.int64)) 20 | -------------------------------------------------------------------------------- /tests/loaders/test_transcriptome_fragments.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import chromatinhd as chd 3 | 4 | 5 | class TestTranscriptomeFragments: 6 | def test_example(self, example_fragments, example_transcriptome): 7 | loader = chd.loaders.TranscriptomeFragments( 8 | fragments=example_fragments, 9 | transcriptome=example_transcriptome, 10 | cellxregion_batch_size=10000, 11 | ) 12 | 13 | minibatch = chd.loaders.minibatches.Minibatch(cells_oi=np.arange(20), regions_oi=np.arange(5), phase="train") 14 | loader.load(minibatch) 15 | -------------------------------------------------------------------------------- /tests/models/diff/loader/test_clustering_cuts.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import chromatinhd as chd 3 | 4 | 5 | class TestClusteringCuts: 6 | def test_example(self, example_fragments, example_clustering): 7 | loader = chd.models.diff.loader.ClusteringCuts( 8 | fragments=example_fragments, 9 | clustering=example_clustering, 10 | 
cellxregion_batch_size=10000, 11 | ) 12 | 13 | minibatch = chd.models.diff.loader.Minibatch(cells_oi=np.arange(20), regions_oi=np.arange(5), phase="train") 14 | loader.load(minibatch) 15 | -------------------------------------------------------------------------------- /tests/models/diff/model/_test_spline.py: -------------------------------------------------------------------------------- 1 | import chromatinhd.models.diff.model.spline 2 | import torch 3 | import numpy as np 4 | 5 | 6 | class TestDifferentialQuadraticSplineStack: 7 | def test_basic(self): 8 | transform = chromatinhd.models.diff.model.spline.DifferentialQuadraticSplineStack(nbins=(128,), n_regions=1) 9 | 10 | x = torch.linspace(0, 1, 100) 11 | genes_oi = torch.tensor([0]) 12 | local_gene_ix = torch.zeros(len(x), dtype=torch.int) 13 | 14 | delta = torch.zeros((len(x), np.sum(transform.split_deltas))) 15 | delta[:, :30] = 0 16 | output, logabsdet = transform.transform_forward(x, genes_oi, local_gene_ix, delta) 17 | 18 | assert np.isclose(np.trapz(torch.exp(logabsdet).detach().numpy(), x), 1, atol=5e-2) 19 | 20 | delta[:, :30] = 1 21 | output, logabsdet = transform.transform_forward(x, genes_oi, local_gene_ix, delta) 22 | 23 | assert np.isclose(np.trapz(torch.exp(logabsdet).detach().numpy(), x), 1, atol=5e-2) 24 | 25 | delta[:, :] = torch.from_numpy(np.random.normal(size=(1, delta.shape[1]))) 26 | output, logabsdet = transform.transform_forward(x, genes_oi, local_gene_ix, delta) 27 | 28 | assert np.isclose(np.trapz(torch.exp(logabsdet).detach().numpy(), x), 1, atol=5e-2) 29 | -------------------------------------------------------------------------------- /tests/models/diff/model/test_diff_model.py: -------------------------------------------------------------------------------- 1 | import chromatinhd as chd 2 | import time 3 | 4 | 5 | class TestDiff: 6 | def test_example_single_model(self, example_fragments, example_clustering, example_folds): 7 | fold = example_folds[0] 8 | model = chd.models.diff.model.binary.Model.create( 9 | fragments=example_fragments, clustering=example_clustering, fold=fold 10 | ) 11 | 12 | start = time.time() 13 | model.train_model(n_epochs=10) 14 | 15 | delta = time.time() - start 16 | 17 | assert delta < 20 18 | 19 | def test_example_multiple_models(self, example_fragments, example_clustering, example_folds): 20 | model = chd.models.diff.model.binary.Models.create( 21 | fragments=example_fragments, 22 | clustering=example_clustering, 23 | folds=example_folds, 24 | ) 25 | 26 | start = time.time() 27 | model.train_models(n_epochs=2, regions_oi=example_fragments.var.index[:2]) 28 | 29 | delta = time.time() - start 30 | 31 | assert delta < 40 32 | -------------------------------------------------------------------------------- /tests/models/miff/_test_zoom.py: -------------------------------------------------------------------------------- 1 | import chromatinhd.models.miff.model.zoom 2 | import torch 3 | import numpy as np 4 | 5 | 6 | class TestZoom: 7 | def test_simple(self): 8 | width = 16 9 | nbins = [4, 2, 2] 10 | 11 | totalnbins = np.cumprod(nbins) 12 | totalbinwidths = torch.tensor(width // totalnbins) 13 | 14 | positions = torch.arange(16) 15 | 16 | def test_probability(unnormalized_heights_all, expected_logprob): 17 | totalnbins = np.cumprod(nbins) 18 | totalbinwidths = torch.tensor(width // totalnbins) 19 | 20 | unnormalized_heights_zooms = chromatinhd.models.miff.model.zoom.extract_unnormalized_heights( 21 | positions, totalbinwidths, unnormalized_heights_all 22 | ) 23 | logprob = 
chromatinhd.models.miff.model.zoom.calculate_logprob( 24 | positions, nbins, width, unnormalized_heights_zooms 25 | ) 26 | 27 | # print(1 / np.exp((logprob))) 28 | assert np.isclose(1 / np.exp((logprob)), expected_logprob).all() 29 | 30 | ## 31 | unnormalized_heights_all = [] 32 | cur_total_n = 1 33 | for n in nbins: 34 | cur_total_n *= n 35 | unnormalized_heights_all.append(torch.zeros(cur_total_n).reshape(-1, n)) 36 | unnormalized_heights_all[1][2, 1] = np.log(2) 37 | expected_logprob = np.array([16, 16, 16, 16, 16, 16, 16, 16, 24, 24, 12, 12, 16, 16, 16, 16]) 38 | 39 | test_probability(unnormalized_heights_all, expected_logprob) 40 | 41 | ## 42 | unnormalized_heights_all = [] 43 | cur_total_n = 1 44 | for n in nbins: 45 | cur_total_n *= n 46 | unnormalized_heights_all.append(torch.zeros(cur_total_n).reshape(-1, n)) 47 | unnormalized_heights_all[0][0, 1] = np.log(2) 48 | expected_logprob = np.array([20, 20, 20, 20, 10, 10, 10, 10, 20, 20, 20, 20, 20, 20, 20, 20]) 49 | 50 | test_probability(unnormalized_heights_all, expected_logprob) 51 | 52 | ## 53 | unnormalized_heights_all = [] 54 | cur_total_n = 1 55 | for n in nbins: 56 | cur_total_n *= n 57 | unnormalized_heights_all.append(torch.zeros(cur_total_n).reshape(-1, n)) 58 | unnormalized_heights_all[2][2, 0] = np.log(2) 59 | expected_logprob = np.array([16, 16, 16, 16, 12, 24, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16]) 60 | 61 | test_probability(unnormalized_heights_all, expected_logprob) 62 | -------------------------------------------------------------------------------- /tests/models/pred/model/test_pred_model.py: -------------------------------------------------------------------------------- 1 | import chromatinhd as chd 2 | import time 3 | 4 | 5 | class TestPred: 6 | def test_example_single_model(self, example_fragments, example_transcriptome, example_folds): 7 | fold = example_folds[0] 8 | model = chd.models.pred.model.better.Model.create( 9 | fragments=example_fragments, 10 | transcriptome=example_transcriptome, 11 | region_oi=example_fragments.var.index[0], 12 | fold=fold, 13 | ) 14 | 15 | start = time.time() 16 | model.train_model(fold=example_folds[0], n_epochs=10) 17 | 18 | delta = time.time() - start 19 | 20 | assert delta < 20 21 | result = model.get_prediction(cell_ixs=fold["cells_test"]) 22 | 23 | assert result["predicted"].shape[0] == len(fold["cells_test"]) 24 | 25 | def test_example_multiple_models(self, example_fragments, example_transcriptome, example_folds): 26 | models = chd.models.pred.model.better.Models.create( 27 | fragments=example_fragments, transcriptome=example_transcriptome 28 | ) 29 | 30 | start = time.time() 31 | models.train_models(folds=example_folds, n_epochs=1, regions_oi=example_fragments.var.index[:2]) 32 | 33 | delta = time.time() - start 34 | 35 | assert delta < 60 36 | result = models.get_prediction(region=example_fragments.var.index[0], fold_ix=0) 37 | -------------------------------------------------------------------------------- /tests/utils/test_interleave.py: -------------------------------------------------------------------------------- 1 | import chromatinhd as chd 2 | import torch 3 | import numpy as np 4 | 5 | 6 | def test_interleave(): 7 | y = chd.utils.interleave.interleave(torch.tensor([1, 2, 3, 4, 2, 3], dtype=float), repeats=np.array([1, 2])) 8 | assert torch.allclose(y, torch.tensor([3.0, 4.0, 6.0, 7.0], dtype=float)) 9 | 10 | x = chd.utils.interleave.deinterleave(y, repeats=np.array([1, 2])) 11 | assert torch.allclose(x, torch.tensor([-0.5000, 0.5000, -0.5000, 0.5000, 3.5000, 
6.5000], dtype=torch.float64)) 12 | 13 | y2 = chd.utils.interleave.interleave(x, repeats=np.array([1, 2])) 14 | assert torch.allclose(y, y2) 15 | 16 | 17 | def test_interleave2(): 18 | y = torch.tensor([1, 2, 3, 4, 2, 3], dtype=float) 19 | x = chd.utils.interleave.deinterleave(y, repeats=np.array([1, 2, 3])) 20 | y2 = chd.utils.interleave.interleave(x, repeats=np.array([1, 2, 3])) 21 | assert torch.allclose(y, y2) 22 | --------------------------------------------------------------------------------
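For orientation, a minimal usage sketch of the paircor loss defined in src/chromatinhd/models/pred/model/loss.py above; this sketch is not part of the repository, and the tensor shapes are illustrative assumptions (rows as cells, columns as regions):

import torch
from chromatinhd.models.pred.model.loss import paircor, paircor_loss

predicted = torch.randn(100, 5)  # predicted expression: cells x regions (assumed shapes)
observed = torch.randn(100, 5)   # observed expression: cells x regions

cor = paircor(predicted, observed)        # one correlation per region, shape (5,)
loss = paircor_loss(predicted, observed)  # -cor.sum() * 100, minimized during training
assert cor.shape == (5,)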