├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── LICENSE ├── README.md ├── docs ├── Makefile ├── _key_contributors.rst ├── _static │ └── seqdata_xr.png ├── api.md ├── conf.py ├── contributing.md ├── index.md ├── installation.md ├── references.bib ├── references.md ├── release-notes.md ├── tutorials │ ├── 1_Reading_Flat_Files.ipynb │ ├── 2_Reading_Region_Files.ipynb │ ├── 3_Reading_Tracks.ipynb │ ├── 4_Zarr_And_XArray.ipynb │ ├── 5_PyTorch_Dataloading.ipynb │ └── 6_Complex_Transforms.ipynb └── usage-principles.md ├── poetry.lock ├── pyproject.toml ├── seqdata ├── __init__.py ├── _io │ ├── __init__.py │ ├── bed_ops.py │ ├── read.py │ ├── readers.py │ ├── readers │ │ ├── __init__.py │ │ ├── bam.py │ │ ├── fasta.py │ │ ├── table.py │ │ ├── vcf.py │ │ └── wig.py │ └── utils.py ├── datasets.py ├── torch.py ├── types.py └── xarray │ ├── seqdata.py │ └── utils.py ├── setup.py └── tests ├── _test_vcf.py ├── data ├── README.md ├── fixed.bed ├── fixed.chrom.sizes ├── fixed.fa ├── fixed.fa.fai ├── fixed.tsv ├── simulated1.bam ├── simulated1.bam.bai ├── simulated1.bw ├── simulated2.bam ├── simulated2.bam.bai ├── simulated2.bw ├── simulated3.bam ├── simulated3.bam.bai ├── simulated3.bw ├── simulated4.bam ├── simulated4.bam.bai ├── simulated4.bw ├── simulated5.bam ├── simulated5.bam.bai ├── simulated5.bw ├── variable.bed ├── variable.bedcov.pkl ├── variable.chrom.sizes ├── variable.fa ├── variable.fa.fai └── variable.tsv ├── notebooks ├── generate_data.ipynb ├── test_datasets.ipynb ├── test_flat_files.ipynb └── test_regions_files.ipynb ├── readers ├── test_bam.py ├── test_bigwig.py ├── test_flat_fasta.py ├── test_genome_fasta.py └── test_table.py ├── test_bed_ops.py ├── test_max_jitter.py ├── test_open_zarr.py ├── test_to_zarr.py └── torch └── test_torch.py /.gitignore: -------------------------------------------------------------------------------- 1 | archive/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | # pytype static type analyzer 137 | .pytype/ 138 | 139 | # Cython debug symbols 140 | cython_debug/ 141 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v3.2.0 4 | hooks: 5 | - id: check-merge-conflict 6 | - id: debug-statements 7 | - id: mixed-line-ending 8 | - id: check-case-conflict 9 | - id: check-yaml 10 | - repo: https://github.com/astral-sh/ruff-pre-commit 11 | # Ruff version. 12 | rev: v0.7.2 13 | hooks: 14 | # Run the linter. 15 | - id: ruff 16 | types_or: [ python, pyi ] 17 | args: [ --fix ] 18 | # Run the formatter. 
      - id: ruff-format
        types_or: [ python, pyi ]
-------------------------------------------------------------------------------- /.readthedocs.yml: --------------------------------------------------------------------------------
version: 2
build:
  os: "ubuntu-22.04"
  tools:
    python: "3.9"
  jobs:
    post_create_environment:
      # Install poetry
      # https://python-poetry.org/docs/#installing-manually
      - pip install poetry
      # Tell poetry to not use a virtual environment
      - poetry config virtualenvs.create false
    post_install:
      # Install dependencies with 'docs' extras
      # https://python-poetry.org/docs/pyproject/#extras
      - poetry install --extras docs

sphinx:
  configuration: docs/conf.py
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Adam Klie

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
[![PyPI version](https://badge.fury.io/py/seqdata.svg)](https://badge.fury.io/py/seqdata)
![PyPI - Downloads](https://img.shields.io/pypi/dm/seqdata)

![seqdata xr](docs/_static/seqdata_xr.png)

# SeqData (Annotated sequence data)

[[documentation](https://seqdata.readthedocs.io/en/latest/)][[tutorials]()]

SeqData is a Python package for preparing ML-ready genomic sequence datasets. Some of the key features of SeqData include:

- Keeps multi-dimensional data in one object (e.g. sequence, coverage, metadata, etc.)
- Efficiently and flexibly loads track-based data from BigWig or BAM
- Fully compatible with PyTorch dataloading
- Offers out-of-core dataloading from disk to CPU to GPU

> [!NOTE]
> SeqData is under active development. The API has largely been decided on, but may change slightly across versions until the first major release.

## Installation

`pip install seqdata`

## Roadmap

Although my focus will largely follow my research projects and the feedback I receive from the community, here is a roadmap for what I currently plan to focus on in the next few releases.

- v0.1.0: ✔️ Initial API for reading BAM, FASTA, BigWig and Tabular data and building PyTorch dataloaders
- v0.2.0: (WIP) Bug fixes, improved documentation, tutorials, and examples
- v0.3.0: Improved out-of-core functionality, robust BED classification datasets
- v0.4.0: Interoperability with AnnData and SnapATAC2

## Usage

### Loading data from "flat" files
The simplest way to store genomic sequence data is in a table or in a "flat" FASTA file. Though this can easily be accomplished using something like `pandas.read_csv`, the SeqData interface keeps the resulting on-disk and in-memory objects standardized with the rest of the SeqData and larger ML4GLand API.

```python
import seqdata as sd
sdata = sd.read_table(
    name="seq",  # name of resulting xarray variable containing sequences
    out="sdata.zarr",  # output file
    tables=["sequences.tsv"],  # list of tabular files
    seq_col="seq_col",  # column containing sequences
    fixed_length=False,  # whether all sequences are the same length
    batch_size=1000,  # number of sequences to load at once
    overwrite=True,  # overwrite the output file if it exists
)
```

This will generate a `sdata.zarr` store containing the sequences in the `seq_col` column of `sequences.tsv`. The resulting `sdata` object can then be used for downstream analysis.

### Loading sequences from genomic coordinates
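Sequences can also be pulled out of a reference genome using coordinates from a BED file. Below is a minimal sketch based on the signature of `sd.read_genome_fasta` documented in `seqdata/_io/read.py`; the file paths and the `fixed_length` value are placeholders to adapt to your own data.

```python
import seqdata as sd
sdata = sd.read_genome_fasta(
    name="seq",  # name of resulting xarray variable containing sequences
    out="sdata.zarr",  # output file
    fasta="genome.fa",  # reference genome FASTA (placeholder path)
    bed="regions.bed",  # BED file defining the regions to extract (placeholder path)
    fixed_length=1000,  # uniform region length, or False to keep variable lengths
    batch_size=1000,  # number of regions to load at once
    overwrite=True,  # overwrite the output file if it exists
)
```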
### Loading data from BAM files
Reading from BAM files allows one to choose custom counting strategies (often necessary with ATAC-seq data).

```python
import seqdata as sd
sdata = sd.read_bam(
    seq_name="seq",  # name of resulting xarray variable containing sequences
    cov_name="cov",  # name of resulting xarray variable containing coverage
    out="sdata.zarr",  # output file
    fasta="genome.fa",  # reference genome FASTA
    bams=["data.bam"],  # list of BAM files
    samples=["sample1"],  # sample names, one per BAM
    bed="regions.bed",  # regions to read sequences and coverage for
    fixed_length=1000,  # uniform region length, or False to keep variable lengths
    batch_size=1000,  # number of regions to load at once
    overwrite=True,  # overwrite the output file if it exists
)
```

### Loading data from BigWig files
[BigWig files](https://genome.ucsc.edu/goldenpath/help/bigWig.html) are a common way to store track-based data and the workhorse of modern genomic sequence-based ML. ...

```python
import seqdata as sd
sdata = sd.read_bigwig(
    seq_name="seq",  # name of resulting xarray variable containing sequences
    cov_name="cov",  # name of resulting xarray variable containing coverage
    out="sdata.zarr",  # output file
    fasta="genome.fa",  # reference genome FASTA
    bigwigs=["data.bw"],  # list of BigWig files
    samples=["sample1"],  # sample names, one per BigWig
    bed="regions.bed",  # regions to read sequences and coverage for
    fixed_length=1000,  # uniform region length, or False to keep variable lengths
    batch_size=1000,  # number of regions to load at once
    overwrite=True,  # overwrite the output file if it exists
)
```

### Working with Zarr stores and XArray objects
The SeqData API is built to convert data from common formats to Zarr stores on disk. The Zarr store... When coupled with XArray and Dask, we also have the ability to lazily load data and work with data that is too large to fit in memory.

```python
```

Admittedly, working with XArray can take some getting used to...

### Building a dataloader
The main goal of SeqData is to allow a seamless flow from data on disk to PyTorch dataloaders (see `sd.get_torch_dataloader`).

## Contributing
This section was modified from https://github.com/pachterlab/kallisto.

All contributions, including bug reports, documentation improvements, and enhancement suggestions are welcome. Everyone within the community is expected to abide by our [code of conduct](https://github.com/ML4GLand/EUGENe/blob/main/CODE_OF_CONDUCT.md).

As we work towards a stable v1.0.0 release, we typically develop on branches. These are merged into `dev` once sufficiently tested. `dev` is the latest, stable, development branch.

`main` is used only for official releases and is considered to be stable. If you submit a pull request, please make sure to request to merge into `dev` and NOT `main`.
-------------------------------------------------------------------------------- /docs/Makefile: --------------------------------------------------------------------------------
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

livehtml:
	sphinx-autobuild "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
-------------------------------------------------------------------------------- /docs/_key_contributors.rst: --------------------------------------------------------------------------------
.. sidebar:: Key Contributors

   * `David Laub `_: lead developer ☀
   * `Adam Klie `_: developer, diverse contributions ☀
-------------------------------------------------------------------------------- /docs/_static/seqdata_xr.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/ML4GLand/SeqData/10afa85d0ec954b9d28955b072a55274c3c18dfd/docs/_static/seqdata_xr.png
-------------------------------------------------------------------------------- /docs/api.md: --------------------------------------------------------------------------------
# API Reference

```{eval-rst}
.. automodule:: seqdata
   :members:
   :undoc-members:
   :show-inheritance:
```
-------------------------------------------------------------------------------- /docs/conf.py: --------------------------------------------------------------------------------
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

import seqdata as sd

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

project = "SeqData"
copyright = "2023, Adam Klie"
author = "Adam Klie"
release = sd.__version__
# short X.Y version
version = ".".join(release.split(".")[:2])


# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.napoleon",
    "sphinx.ext.intersphinx",
    "sphinx.ext.viewcode",
    "sphinxcontrib.bibtex",
    "myst_parser",
    "nbsphinx",
]

bibtex_bibfiles = ['references.bib']

napoleon_use_param = True
napoleon_type_aliases = {
    "ArrayLike": ":term:`array_like`",
    "NDArray": ":ref:`NDArray `",
}
napoleon_use_rtype = True

autodoc_typehints = "both"
autodoc_type_aliases = {"ArrayLike": "ArrayLike"}
autodoc_default_options = {"private-members": False}
autodoc_member_order = "bysource"

myst_enable_extensions = ["colon_fence"]

templates_path = ["_templates"]
exclude_patterns = []

intersphinx_mapping = {
    "python": ("https://docs.python.org/3", None),
    "numpy": ("https://numpy.org/doc/stable/", None),
    "numba": ("https://numba.readthedocs.io/en/stable/", None),
    "polars": ("https://docs.pola.rs/py-polars/html", None),
    "torch": ("https://pytorch.org/docs/stable/", None),
}

# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output

html_theme = "sphinx_book_theme"
html_static_path = ["_static"]
html_theme_options = {
    "home_page_in_toc": True,
    "repository_url": "https://github.com/ML4GLand/SeqData",
    "use_repository_button": True,
    "pygments_light_style": "tango",
    "pygments_dark_style": "material",
    "show_navbar_depth": 2,  # Ensures dropdown levels are visible
}
html_logo = "_static/seqdata_xr.png"
html_title = f"SeqData v{version}"
html_sidebars = {
    "**": [
        "navbar-logo.html",
        "icon-links.html",
        "search-button-field.html",
        "sbt-sidebar-nav.html",
    ]
}
-------------------------------------------------------------------------------- /docs/contributing.md: --------------------------------------------------------------------------------
# Contributors

```{eval-rst}
.. include:: _key_contributors.rst
```

## Current developers

## Other roles

## Former developers
-------------------------------------------------------------------------------- /docs/index.md: --------------------------------------------------------------------------------
```{toctree}
:hidden: true
:caption: Contents
:maxdepth: 2

installation
usage-principles
api
release-notes
contributing
references
```

```{toctree}
:hidden: true
:caption: Tutorials
:maxdepth: 2

tutorials/1_Reading_Flat_Files
tutorials/2_Reading_Region_Files
tutorials/3_Reading_Tracks
tutorials/4_Zarr_And_XArray
tutorials/5_PyTorch_Dataloading
tutorials/6_Complex_Transforms
```


# SeqData -- Annotated biological sequence data
```{image} https://badge.fury.io/py/SeqData.svg
:alt: PyPI version
:target: https://badge.fury.io/py/SeqData
:class: inline-link
```

```{image} https://readthedocs.org/projects/SeqData/badge/?version=latest
:alt: Documentation Status
:target: https://SeqData.readthedocs.io/en/latest/index.html
:class: inline-link
```

```{image} https://img.shields.io/pypi/dm/SeqData
:alt: PyPI - Downloads
:class: inline-link
```

SeqData is a Python package for preparing ML-ready genomic sequence datasets. Some of the key features of SeqData include:

- Keeps multi-dimensional data in one object (e.g. sequence, coverage, metadata, etc.)
- Efficiently and flexibly loads track-based data from BigWig or BAM
- Fully compatible with PyTorch dataloading
- Offers out-of-core dataloading from disk to CPU to GPU

SeqData is designed to be used via its Python API.

# Getting started
* {doc}`Install SeqData <installation>`
* Browse the main {doc}`API <api>`

# Contributing
SeqData is an open-source project and we welcome {doc}`contributions <contributing>`.
-------------------------------------------------------------------------------- /docs/installation.md: --------------------------------------------------------------------------------
# Installation
You must have Python version 3.9 or higher installed to use SeqData. SeqData can be installed using `pip`:

```bash
pip install seqdata
```

## Development installation
To work with the latest version [on GitHub](https://github.com/ML4GLand/SeqData), clone the repository and `cd` into its root directory.

```bash
git clone https://github.com/ML4GLand/SeqData.git
cd SeqData
```

Then, install the package in development mode:

```bash
pip install -e .
```

## Optional dependencies
If you plan on building PyTorch dataloaders from SeqData objects, you will need to install SeqData with PyTorch:

```bash
pip install seqdata[torch]
```

## Troubleshooting
If you have any issues installing, please [open an issue](https://github.com/ML4GLand/SeqData/issues) on GitHub.
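To sanity-check an installation, a short snippet like the following should run without errors; it only relies on the top-level exports defined in `seqdata/__init__.py`:

```python
import seqdata as sd

# The package exposes its version string at the top level.
print(sd.__version__)

# Readers such as read_table, read_flat_fasta, read_genome_fasta, read_bam,
# read_bigwig, and read_vcf are available directly from the top-level namespace.
print([name for name in dir(sd) if name.startswith("read_")])
```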
-------------------------------------------------------------------------------- /docs/references.bib: --------------------------------------------------------------------------------
@article{example2024,
  author = {John Doe and Jane Smith},
  title = {A Comprehensive Guide to Genomics},
  journal = {Journal of Genomics},
  year = {2024},
  volume = {42},
  number = {5},
  pages = {123-456},
}
-------------------------------------------------------------------------------- /docs/references.md: --------------------------------------------------------------------------------
# References

```{bibliography}
:style: plain
```
-------------------------------------------------------------------------------- /docs/release-notes.md: --------------------------------------------------------------------------------
# Release notes

## Release 0.8.0 (released January XX, 2024)

### Features:
-

### Bug fixes:

### Other changes:
-------------------------------------------------------------------------------- /docs/usage-principles.md: --------------------------------------------------------------------------------
# Usage Principles
SeqData is designed to facilitate reading common biological sequence data formats into analysis and machine learning pipelines. The following principles guide its design and usage:

## Unified Interface for Reading Data
SeqData provides a consistent interface for reading two major types of data:

1. **Genomic Sequences:**
   - **TSV/CSV:** Explicitly defined sequences in tabular format.
   - **FASTA:** Explicitly defined sequences in FASTA format.
   - **BED:** Implicitly defined sequences corresponding to genomic start and end coordinates.

2. **Read Alignment/Coverage Data:**
   - **BAM:** Summarizes read alignments overlapping genomic regions.
   - **BigWig:** Summarizes coverage data overlapping genomic regions.

## Handling Diverse Experimental Data
SeqData accommodates various experimental data types by combining file formats based on the dataset and analysis goals. Common examples in regulatory genomics include:

1. **Massively Parallel Reporter Assays (MPRAs):**
   - Use simple TSV or "flat" FASTA files to store information about regulatory activity. These formats can be read without a reference genome.
   - See [tutorials](tutorials/1_Reading_Flat_Files.ipynb) for details.

2. **ATAC-seq or ChIP-seq Data:**
   - Typically stored in BAM or BigWig files. Combined with a reference genome and coordinates in a BED file, SeqData enables reading DNA sequences and associated read coverage data.
   - See [tutorials](tutorials/3_Reading_Tracks.ipynb).

## Building XArray Datasets Backed by Zarr Stores
SeqData transforms these file formats into Zarr stores that can be read as XArray datasets. XArray provides N-dimensional labeled arrays, similar to NumPy arrays, with the following benefits:

- **Lazy Loading:** Using Dask-backed Zarr stores, SeqData loads and processes only the required subsets of data, making it suitable for large datasets.
- **Efficiency:** Aligns sequences, coverage, and metadata in a unified structure.

See [tutorials](tutorials/4_Zarr_And_XArray.ipynb) for implementation details.
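As a rough sketch of what this looks like in practice (this assumes `sd.open_zarr` takes the path of a previously written store and that the sequence variable was named `"seq"` when the data were read in; see the API reference for the exact signature):

```python
import seqdata as sd

# Open a previously written store lazily; variables stay on disk as
# Dask-backed arrays until their values are actually requested.
sdata = sd.open_zarr("sdata.zarr")

# Slice the first 100 sequences along the sequence dimension, then pull only
# that subset into memory.
subset = sdata.isel(_sequence=slice(0, 100))
seqs = subset["seq"].values
```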
## Standards Added to XArray Datasets
SeqData enhances XArray datasets with additional standards to better support genomic sequence data:

### Standardized Dimensions:
- `_sequence`: Number of sequences in the dataset.
- `_length`: Length of sequences (exists only for fixed-length sequences).
- `_sample_cov`: Number of coverage tracks (samples).
- `_ohe`: Alphabet size for one-hot encoding.

### Attributes:
- `max_jitter`: Stores maximum jitter information for sequences.

### Coordinate Naming Conventions for BED Files:
- `chrom`: Chromosome name.
- `chromStart`: Start coordinate in the reference genome.
- `chromEnd`: End coordinate in the reference genome.

### I/O Terminology:
- **Fixed-Length Sequences:**
  - For narrowPeak-like BED files, the fixed length is calculated from the summit (9th column).
  - For other BED files, it is calculated from the midpoint.

## Everything CAN Be Stored in a Single XArray Dataset
SeqData enables storing all relevant data in one XArray dataset, ensuring alignment and accessibility. This unified dataset can include sequences, coverage, metadata, sequence attribution, prediction tracks, and more. With lazy loading, users can selectively access only the data needed for specific analyses, reducing memory overhead.

## Conversion to PyTorch Dataloaders
SeqData simplifies the transition to machine learning workflows by supporting the conversion of XArray datasets to PyTorch dataloaders via the `sd.get_torch_dataloader()` function. This function flexibly handles genomic datasets for sequence-to-function modeling.

See [tutorials](tutorials/5_PyTorch_Dataloading.ipynb) for usage examples.

-------------------------------------------------------------------------------- /pyproject.toml: --------------------------------------------------------------------------------
[tool.poetry]
name = "seqdata"
# managed by poetry-dynamic-versioning
version = "0.0.0"
description = "Annotated sequence data"
authors = ["David Laub ", "Adam Klie "]

[tool.poetry.dependencies]
python = ">=3.9"
pyranges = "^0.0.120"
xarray = ">=2023.10.0"
zarr = "^2.14.2"
dask = "^2023.3.2"
pandera = { version = "^0.22.0", extras = ["polars"] }
cyvcf2 = "^0.30.18"
pyBigWig = "^0.3.22"
polars = "^1.0.0"
more-itertools = "^9.1.0"
pybedtools = "^0.9.0"
pysam = "^0.21.0"
joblib = "^1.1.0"
natsort = "^8.3.1"
numpy = "^1.26"
pandas = "^1.5.2"
numcodecs = "^0.11.0"
typing-extensions = "^4.5.0"
tqdm = "^4.65.0"
seqpro = "^0.1.1"
# optional; installed via the "torch" extra declared below
torch = { version = ">=2", optional = true }
pyarrow = "^17.0.0"
pooch = "^1.8.2"

[tool.poetry.extras]
torch = ["torch"]

[tool.poetry.group.dev.dependencies]
pytest-cases = "^3.8.6"
pytest = "^8.3.3"
sphinx = ">=6.2.1"
sphinx-autobuild = "2021.3.14"
sphinx-autodoc-typehints = ">=1.23.4"
sphinxcontrib-apidoc = "^0.3.0"
sphinx-rtd-theme = "^1.2.2"
myst-parser = "^2.0.0"
nbsphinx = "^0.9.2"
pandoc = "^2.3"
icecream = "^2.1.3"

[tool.poetry-dynamic-versioning]
enable = true

[build-system]
requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"]
build-backend = "poetry_dynamic_versioning.backend"

[tool.isort]
profile = "black"

[tool.pyright]
include =
['seqdata', 'notebooks', 'tests'] 61 | -------------------------------------------------------------------------------- /seqdata/__init__.py: -------------------------------------------------------------------------------- 1 | """Annotated sequence data""" 2 | 3 | __version__ = "0.0.0" # managed by poetry-dynamic-versioning 4 | 5 | from . import datasets 6 | from ._io.bed_ops import add_bed_to_sdata, label_overlapping_regions, read_bedlike 7 | from ._io.read import ( 8 | read_bam, 9 | read_bigwig, 10 | read_flat_fasta, 11 | read_genome_fasta, 12 | read_table, 13 | read_vcf, 14 | ) 15 | from ._io.readers import BAM, VCF, BigWig, FlatFASTA, GenomeFASTA, Table 16 | from .xarray.seqdata import ( 17 | from_flat_files, 18 | from_region_files, 19 | merge_obs, 20 | open_zarr, 21 | to_zarr, 22 | ) 23 | 24 | try: 25 | from .torch import XArrayDataLoader, get_torch_dataloader 26 | 27 | TORCH_AVAILABLE = True 28 | except ImportError: 29 | TORCH_AVAILABLE = False 30 | 31 | def no_torch(): 32 | raise ImportError( 33 | "Install PyTorch to use functionality from SeqData's torch submodule." 34 | ) 35 | 36 | get_torch_dataloader = no_torch 37 | 38 | 39 | __all__ = [ 40 | "from_flat_files", 41 | "from_region_files", 42 | "open_zarr", 43 | "to_zarr", 44 | "get_torch_dataloader", 45 | "read_bedlike", 46 | "read_bam", 47 | "read_bigwig", 48 | "read_flat_fasta", 49 | "read_genome_fasta", 50 | "read_table", 51 | "read_vcf", 52 | "BAM", 53 | "VCF", 54 | "BigWig", 55 | "FlatFASTA", 56 | "GenomeFASTA", 57 | "Table", 58 | "add_bed_to_sdata", 59 | "label_overlapping_regions", 60 | "merge_obs", 61 | "XArrayDataLoader", 62 | "datasets", 63 | ] 64 | -------------------------------------------------------------------------------- /seqdata/_io/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML4GLand/SeqData/10afa85d0ec954b9d28955b072a55274c3c18dfd/seqdata/_io/__init__.py -------------------------------------------------------------------------------- /seqdata/_io/bed_ops.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | from typing import List, Literal, Optional, Union, cast 4 | 5 | import pandera.polars as pa 6 | import polars as pl 7 | import xarray as xr 8 | import zarr 9 | from pybedtools import BedTool 10 | 11 | from seqdata._io.utils import _df_to_xr_zarr 12 | from seqdata.types import PathType 13 | 14 | BED_COLS = [ 15 | "chrom", 16 | "chromStart", 17 | "chromEnd", 18 | "name", 19 | "score", 20 | "strand", 21 | "thickStart", 22 | "thickEnd", 23 | "itemRgb", 24 | "blockCount", 25 | "blockSizes", 26 | "blockStarts", 27 | ] 28 | 29 | 30 | def _set_uniform_length_around_center(bed: pl.DataFrame, length: int) -> pl.DataFrame: 31 | if "peak" in bed: 32 | center = pl.col("chromStart") + pl.col("peak") 33 | else: 34 | center = (pl.col("chromStart") + pl.col("chromEnd")) / 2 35 | return bed.with_columns( 36 | chromStart=(center - length / 2).round().cast(pl.Int64), 37 | chromEnd=(center + length / 2).round().cast(pl.Int64), 38 | ) 39 | 40 | 41 | def _expand_regions(bed: pl.DataFrame, expansion_length: int) -> pl.DataFrame: 42 | return bed.with_columns( 43 | chromStart=(pl.col("chromStart") - expansion_length), 44 | chromEnd=(pl.col("chromEnd") + expansion_length), 45 | ) 46 | 47 | 48 | def _bed_to_zarr(bed: pl.DataFrame, root: zarr.Group, dim: str, **kwargs): 49 | bed = bed.with_columns(pl.col(pl.Utf8).fill_null(".")) 50 | _df_to_xr_zarr(bed, root, dim, 
**kwargs) 51 | 52 | 53 | def add_bed_to_sdata( 54 | sdata: xr.Dataset, 55 | bed: pl.DataFrame, 56 | col_prefix: Optional[str] = None, 57 | sequence_dim: Optional[str] = None, 58 | ): 59 | """Warning: This function is experimental and may change in the future. 60 | Add a BED-like DataFrame to a Dataset. 61 | 62 | Parameters 63 | ---------- 64 | sdata : xr.Dataset 65 | bed : pl.DataFrame 66 | col_prefix : str, optional 67 | Prefix to add to the column names of the DataFrame before merging. 68 | sequence_dim : str, optional 69 | Name of the sequence dimension in the resulting Dataset. 70 | """ 71 | bed_ = bed.to_pandas() 72 | if col_prefix is not None: 73 | bed_.columns = [col_prefix + c for c in bed_.columns] 74 | if sequence_dim is not None: 75 | bed_.index.name = sequence_dim 76 | return sdata.merge(bed_.to_xarray()) 77 | 78 | 79 | def label_overlapping_regions( 80 | sdata: xr.Dataset, 81 | targets: Union[PathType, pl.DataFrame, List[str]], 82 | mode: Literal["binary", "multitask"], 83 | label_dim: Optional[str] = None, 84 | fraction_overlap: Optional[float] = None, 85 | ) -> xr.DataArray: 86 | """Warning: This function is experimental and may change in the future. 87 | 88 | Label regions for binary or multitask classification based on whether they 89 | overlap with another set of regions. 90 | 91 | Parameters 92 | ---------- 93 | sdata : xr.Dataset 94 | targets : Union[str, Path, pl.DataFrame, List[str]] 95 | Either a DataFrame (or path to one) with (for binary classification) at least 96 | columns ['chrom', 'chromStart', 'chromEnd'], or a list of variable names in 97 | `sdata` to use that correspond to the ['chrom', 'chromStart', 'chromEnd'] 98 | columns, in that order. This is useful if, for example, another set of regions 99 | is already in the `sdata` object under a different set of column names. For 100 | multitask classification, the 'name' column is also required (i.e. binary 101 | requires BED3 format, multitask requires BED4). 102 | mode : Literal["binary", "multitask"] 103 | Whether to mark regions for binary (intersects with any of the target regions) 104 | or multitask classification (which target region does it intersect with?). 105 | label_dim : str, optional 106 | Name of the label dimension. Only needed for multitask classification. 107 | fraction_overlap: float, optional 108 | Fraction of the length that must be overlapping to be considered an 109 | overlap. This is the "reciprocal minimal overlap fraction" as described in the 110 | [bedtools documentation](https://bedtools.readthedocs.io/en/latest/content/tools/intersect.html#r-and-f-requiring-reciprocal-minimal-overlap-fraction). 
111 | """ 112 | bed1 = BedTool.from_dataframe( 113 | sdata[["chrom", "chromStart", "chromEnd", "strand"]].to_dataframe() 114 | ) 115 | 116 | if isinstance(targets, (str, Path)): 117 | bed2 = BedTool(targets) 118 | elif isinstance(targets, pl.DataFrame): 119 | bed2 = BedTool.from_dataframe(targets) 120 | elif isinstance(targets, list): 121 | bed2 = BedTool.from_dataframe(sdata[targets].to_dataframe()) 122 | 123 | if fraction_overlap is not None and (fraction_overlap < 0 or fraction_overlap > 1): 124 | raise ValueError("Fraction overlap must be between 0 and 1 (inclusive).") 125 | 126 | if mode == "binary": 127 | if label_dim is not None: 128 | warnings.warn("Ignoring `label_dim` for binary classification.") 129 | if fraction_overlap is None: 130 | res = bed1.intersect(bed2, c=True) # type: ignore 131 | else: 132 | res = bed1.intersect(bed2, c=True, f=fraction_overlap, r=True) # type: ignore 133 | with open(res.fn) as f: 134 | n_cols = len(f.readline().split("\t")) 135 | labels = ( 136 | pl.read_csv( 137 | res.fn, 138 | separator="\t", 139 | has_header=False, 140 | columns=[0, 1, 2, n_cols - 1], 141 | new_columns=["chrom", "chromStart", "chromEnd", "label"], 142 | ) 143 | .with_columns((pl.col("label") > 0).cast(pl.UInt8))["label"] 144 | .to_numpy() 145 | ) 146 | return xr.DataArray(labels, dims=sdata.attrs["sequence_dim"]) 147 | elif mode == "multitask": 148 | if label_dim is None: 149 | raise ValueError( 150 | """Need a name for the label dimension when generating labels for 151 | multitask classification.""" 152 | ) 153 | if fraction_overlap is None: 154 | res = bed1.intersect(bed2, loj=True) # type: ignore 155 | else: 156 | res = bed1.intersect(bed2, loj=True, f=fraction_overlap, r=True) # type: ignore 157 | labels = ( 158 | pl.read_csv( 159 | res.fn, 160 | separator="\t", 161 | has_header=False, 162 | columns=[0, 1, 2, 7], 163 | new_columns=["chrom", "chromStart", "chromEnd", "label"], 164 | ) 165 | .to_dummies("label") 166 | .select(pl.exclude(r"^label_\.$")) 167 | .group_by("chrom", "chromStart", "chromEnd", maintain_order=True) 168 | .agg(pl.exclude(r"^label.*$").first(), pl.col(r"^label.*$").max()) 169 | .select(r"^label.*$") # (sequences labels) 170 | ) 171 | label_names = xr.DataArray( 172 | [c.split("_", 1)[1] for c in labels.columns], dims=label_dim 173 | ) 174 | return xr.DataArray( 175 | labels.to_numpy(), 176 | coords={label_dim: label_names}, 177 | dims=[sdata.attrs["sequence_dim"], label_dim], 178 | ) 179 | 180 | 181 | def read_bedlike(path: PathType) -> pl.DataFrame: 182 | """Reads a bed-like (BED3+) file as a pandas DataFrame. The file type is inferred 183 | from the file extension. 184 | 185 | Parameters 186 | ---------- 187 | path : PathType 188 | 189 | Returns 190 | ------- 191 | pandas.DataFrame 192 | """ 193 | path = Path(path) 194 | if ".bed" in path.suffixes: 195 | return _read_bed(path) 196 | elif ".narrowPeak" in path.suffixes: 197 | return _read_narrowpeak(path) 198 | elif ".broadPeak" in path.suffixes: 199 | return _read_broadpeak(path) 200 | else: 201 | raise ValueError( 202 | f"""Unrecognized file extension: {''.join(path.suffixes)}. 
Expected one of 203 | .bed, .narrowPeak, or .broadPeak""" 204 | ) 205 | 206 | 207 | BEDSchema = pa.DataFrameSchema( 208 | { 209 | "chrom": pa.Column(str), 210 | "chromStart": pa.Column(int), 211 | "chromEnd": pa.Column(int), 212 | "name": pa.Column(str, nullable=True, required=False), 213 | "score": pa.Column(float, nullable=True, required=False), 214 | "strand": pa.Column( 215 | str, nullable=True, checks=pa.Check.isin(["+", "-", "."]), required=False 216 | ), 217 | "thickStart": pa.Column(int, nullable=True, required=False), 218 | "thickEnd": pa.Column(int, nullable=True, required=False), 219 | "itemRgb": pa.Column(str, nullable=True, required=False), 220 | "blockCount": pa.Column(pl.UInt64, nullable=True, required=False), 221 | "blockSizes": pa.Column(str, nullable=True, required=False), 222 | "blockStarts": pa.Column(str, nullable=True, required=False), 223 | }, 224 | coerce=True, 225 | ) 226 | 227 | 228 | def _read_bed(bed_path: PathType): 229 | with open(bed_path) as f: 230 | skip_rows = 0 231 | while (line := f.readline()).startswith(("track", "browser")): 232 | skip_rows += 1 233 | n_cols = line.count("\t") + 1 234 | bed = pl.read_csv( 235 | bed_path, 236 | separator="\t", 237 | has_header=False, 238 | skip_rows=skip_rows, 239 | new_columns=BED_COLS[:n_cols], 240 | schema_overrides={"chrom": pl.Utf8, "name": pl.Utf8, "strand": pl.Utf8}, 241 | null_values=".", 242 | ).pipe(BEDSchema.validate) 243 | bed = cast(pl.DataFrame, bed) 244 | return bed 245 | 246 | 247 | NarrowPeakSchema = pa.DataFrameSchema( 248 | { 249 | "chrom": pa.Column(str), 250 | "chromStart": pa.Column(int), 251 | "chromEnd": pa.Column(int), 252 | "name": pa.Column(str, nullable=True, required=False), 253 | "score": pa.Column(float, nullable=True, required=False), 254 | "strand": pa.Column( 255 | str, nullable=True, checks=pa.Check.isin(["+", "-", "."]), required=False 256 | ), 257 | "signalValue": pa.Column(float, nullable=True, required=False), 258 | "pValue": pa.Column(float, nullable=True, required=False), 259 | "qValue": pa.Column(float, nullable=True, required=False), 260 | "peak": pa.Column(int, nullable=True, required=False), 261 | }, 262 | coerce=True, 263 | ) 264 | 265 | 266 | def _read_narrowpeak(narrowpeak_path: PathType) -> pl.DataFrame: 267 | with open(narrowpeak_path) as f: 268 | skip_rows = 0 269 | while f.readline().startswith(("track", "browser")): 270 | skip_rows += 1 271 | narrowpeaks = pl.read_csv( 272 | narrowpeak_path, 273 | separator="\t", 274 | has_header=False, 275 | skip_rows=skip_rows, 276 | new_columns=[ 277 | "chrom", 278 | "chromStart", 279 | "chromEnd", 280 | "name", 281 | "score", 282 | "strand", 283 | "signalValue", 284 | "pValue", 285 | "qValue", 286 | "peak", 287 | ], 288 | schema_overrides={"chrom": pl.Utf8, "name": pl.Utf8, "strand": pl.Utf8}, 289 | null_values=".", 290 | ).pipe(NarrowPeakSchema.validate) 291 | narrowpeaks = cast(pl.DataFrame, narrowpeaks) 292 | return narrowpeaks 293 | 294 | 295 | BroadPeakSchema = pa.DataFrameSchema( 296 | { 297 | "chrom": pa.Column(str), 298 | "chromStart": pa.Column(int), 299 | "chromEnd": pa.Column(int), 300 | "name": pa.Column(str, nullable=True, required=False), 301 | "score": pa.Column(float, nullable=True, required=False), 302 | "strand": pa.Column( 303 | str, nullable=True, checks=pa.Check.isin(["+", "-", "."]), required=False 304 | ), 305 | "signalValue": pa.Column(float, nullable=True, required=False), 306 | "pValue": pa.Column(float, nullable=True, required=False), 307 | "qValue": pa.Column(float, nullable=True, required=False), 308 
| }, 309 | coerce=True, 310 | ) 311 | 312 | 313 | def _read_broadpeak(broadpeak_path: PathType): 314 | with open(broadpeak_path) as f: 315 | skip_rows = 0 316 | while f.readline().startswith(("track", "browser")): 317 | skip_rows += 1 318 | broadpeaks = pl.read_csv( 319 | broadpeak_path, 320 | separator="\t", 321 | has_header=False, 322 | skip_rows=skip_rows, 323 | new_columns=[ 324 | "chrom", 325 | "chromStart", 326 | "chromEnd", 327 | "name", 328 | "score", 329 | "strand", 330 | "signalValue", 331 | "pValue", 332 | "qValue", 333 | ], 334 | schema_overrides={"chrom": pl.Utf8, "name": pl.Utf8, "strand": pl.Utf8}, 335 | null_values=".", 336 | ).pipe(BroadPeakSchema.validate) 337 | broadpeaks = cast(pl.DataFrame, broadpeaks) 338 | return broadpeaks 339 | -------------------------------------------------------------------------------- /seqdata/_io/read.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, List, Optional, Type, Union 2 | 3 | import numpy as np 4 | import seqpro as sp 5 | 6 | from seqdata._io.readers import BAM, VCF, BigWig, FlatFASTA, GenomeFASTA, Table 7 | from seqdata.types import ListPathType, PathType 8 | from seqdata.xarray.seqdata import from_flat_files, from_region_files 9 | 10 | if TYPE_CHECKING: 11 | import pandas as pd 12 | import polars as pl 13 | import xarray as xr 14 | 15 | 16 | def read_table( 17 | name: str, 18 | out: PathType, 19 | tables: Union[PathType, ListPathType], 20 | seq_col: str, 21 | batch_size: int, 22 | fixed_length: bool, 23 | overwrite=False, 24 | **kwargs, 25 | ) -> "xr.Dataset": 26 | """Reads sequences and metadata from tabular files (e.g. CSV, TSV, etc.) into xarray. 27 | 28 | Uses polars under the hood to read the table files. 29 | 30 | Parameters 31 | ---------- 32 | name : str 33 | Name of the sequence variable in the output dataset. 34 | out : PathType 35 | Path to the output Zarr store where the data will be saved. 36 | Usually something like `/path/to/dataset_name.zarr`. 37 | tables : Union[PathType, ListPathType] 38 | Path to the input table file(s). Can be a single file or a list of files. 39 | seq_col : str 40 | Name of the column in the table that contains the sequence. 41 | batch_size : int 42 | Number of sequences to read at a time. Use as many as you can fit in memory. 43 | fixed_length : bool 44 | Whether your sequences have a fixed length or not. If they do, the data will be 45 | stored in a 2D array as bytes, otherwise it will be stored as unicode strings. 46 | overwrite : bool 47 | Whether to overwrite the output Zarr store if it already exists. 48 | **kwargs 49 | Additional keyword arguments to pass to the polars `read_csv` function. 50 | 51 | Returns 52 | ------- 53 | xr.Dataset 54 | The output dataset. 55 | """ 56 | sdata = from_flat_files( 57 | Table( 58 | name=name, tables=tables, seq_col=seq_col, batch_size=batch_size, **kwargs 59 | ), 60 | path=out, 61 | fixed_length=fixed_length, 62 | overwrite=overwrite, 63 | ) 64 | return sdata 65 | 66 | 67 | def read_flat_fasta( 68 | name: str, 69 | out: PathType, 70 | fasta: PathType, 71 | batch_size: int, 72 | fixed_length: bool, 73 | n_threads=1, 74 | overwrite=False, 75 | ) -> "xr.Dataset": 76 | """Reads sequences from a "flat" FASTA file into xarray. 77 | 78 | We differentiate between "flat" and "genome" FASTA files. A flat FASTA file is one 79 | where each contig in the FASTA file is a sequence in our dataset. A genome FASTA file 80 | is one where we may pull out multiple subsequences from a given contig. 
81 | 82 | Parameters 83 | ---------- 84 | name : str 85 | Name of the sequence variable in the output dataset. 86 | out : PathType 87 | Path to the output Zarr store where the data will be saved. 88 | Usually something like `/path/to/dataset_name.zarr`. 89 | fasta : PathType 90 | Path to the input FASTA file. 91 | batch_size : int 92 | Number of sequences to read at a time. Use as many as you can fit in memory. 93 | fixed_length : bool 94 | Whether your sequences have a fixed length or not. If they do, the data will be 95 | stored in a 2D array as bytes, otherwise it will be stored as unicode strings. 96 | n_threads : int 97 | Number of threads to use for reading the FASTA file. 98 | overwrite : bool 99 | Whether to overwrite the output Zarr store if it already exists. 100 | 101 | Returns 102 | ------- 103 | xr.Dataset 104 | The output dataset. 105 | """ 106 | sdata = from_flat_files( 107 | FlatFASTA(name=name, fasta=fasta, batch_size=batch_size, n_threads=n_threads), 108 | path=out, 109 | fixed_length=fixed_length, 110 | overwrite=overwrite, 111 | ) 112 | return sdata 113 | 114 | 115 | def read_genome_fasta( 116 | name: str, 117 | out: PathType, 118 | fasta: PathType, 119 | bed: Union[PathType, "pd.DataFrame", "pl.DataFrame"], 120 | batch_size: int, 121 | fixed_length: Union[int, bool], 122 | n_threads=1, 123 | alphabet: Optional[Union[str, sp.NucleotideAlphabet]] = None, 124 | max_jitter=0, 125 | overwrite=False, 126 | ) -> "xr.Dataset": 127 | """Reads sequences from a "genome" FASTA file into xarray. 128 | 129 | We differentiate between "flat" and "genome" FASTA files. A flat FASTA file is one 130 | where each contig in the FASTA file is a sequence in our dataset. A genome FASTA file 131 | is one where we may pull out multiple subsequences from a given contig. 132 | 133 | Parameters 134 | ---------- 135 | name : str 136 | Name of the sequence variable in the output dataset. 137 | out : PathType 138 | Path to the output Zarr store where the data will be saved. 139 | Usually something like `/path/to/dataset_name.zarr`. 140 | fasta : PathType 141 | Path to the input FASTA file. 142 | bed : Union[PathType, pd.DataFrame] 143 | Path to the input BED file or a pandas DataFrame with the BED data. Used to 144 | define the regions of the genome to pull out. TODO: what does the BED 145 | have to have? 146 | batch_size : int 147 | Number of sequences to read at a time. Use as many as you can fit in memory. 148 | fixed_length : Union[int, bool] 149 | Whether your sequences have a fixed length or not. If they do, the data will be 150 | stored in a 2D array as bytes, otherwise it will be stored as unicode strings. 151 | n_threads : int 152 | Number of threads to use for reading the FASTA file. 153 | alphabet : Optional[Union[str, sp.NucleotideAlphabet]] 154 | Alphabet to use for reading sequences 155 | max_jitter : int 156 | Maximum amount of jitter anticipated. This will read in max_jitter/2 extra sequence 157 | on either side of the region defined by the BED file. This is useful for training 158 | models on coverage data 159 | overwrite : bool 160 | Whether to overwrite the output Zarr store if it already exists. 
161 | """ 162 | sdata = from_region_files( 163 | GenomeFASTA( 164 | name=name, 165 | fasta=fasta, 166 | batch_size=batch_size, 167 | n_threads=n_threads, 168 | alphabet=alphabet, 169 | ), 170 | path=out, 171 | fixed_length=fixed_length, 172 | bed=bed, 173 | max_jitter=max_jitter, 174 | overwrite=overwrite, 175 | ) 176 | return sdata 177 | 178 | 179 | def read_bam( 180 | seq_name: str, 181 | cov_name: str, 182 | out: PathType, 183 | fasta: PathType, 184 | bams: ListPathType, 185 | samples: List[str], 186 | bed: Union[PathType, "pd.DataFrame", "pl.DataFrame"], 187 | batch_size: int, 188 | fixed_length: Union[int, bool], 189 | n_jobs=1, 190 | threads_per_job=1, 191 | alphabet: Optional[Union[str, sp.NucleotideAlphabet]] = None, 192 | dtype: Union[str, Type[np.number]] = np.uint16, 193 | max_jitter=0, 194 | overwrite=False, 195 | ) -> "xr.Dataset": 196 | """ 197 | Read in sequences with coverage from a BAM file. 198 | 199 | Parameters 200 | ---------- 201 | seq_name : str 202 | Name of the sequence variable in the output dataset. 203 | cov_name : str 204 | Name of the coverage variable in the output dataset. 205 | out : PathType 206 | Path to the output Zarr store where the data will be saved. 207 | Usually something like `/path/to/dataset_name.zarr`. 208 | fasta : PathType 209 | Path to the reference genome. 210 | bams : ListPathType 211 | List of paths to BAM files. 212 | Can be a single file or a list of files. 213 | samples : List[str] 214 | List of sample names to include. 215 | Should be the same length as `bams`. 216 | bed : Union[PathType, pd.DataFrame] 217 | Path to a BED file or a DataFrame with columns "chrom", "start", and "end". 218 | batch_size : int 219 | Number of regions to read at once. Use as many as you can fit in memory. 220 | fixed_length : Union[int, bool] 221 | Whether your sequences have a fixed length or not. If they do, the data will be 222 | stored in a 2D array as bytes, otherwise it will be stored as unicode strings. 223 | n_jobs : int 224 | Number of parallel jobs. Use if you have multiple BAM files. 225 | threads_per_job : int 226 | Number of threads per job. 227 | alphabet : Optional[Union[str, sp.NucleotideAlphabet]] 228 | Alphabet the sequences have. 229 | dtype : Union[str, Type[np.number]] 230 | Data type to use for coverage. 231 | max_jitter : int 232 | Maximum jitter to use for sampling regions. This will read in max_jitter/2 extra sequence 233 | on either side of the region defined by the BED file. This is useful for training 234 | models on coverage data 235 | overwrite : bool 236 | Whether to overwrite an existing dataset. 237 | 238 | Returns 239 | ------- 240 | xr.Dataset 241 | Dataset with dimensions "_sequence" TODO: what are the dimensions? 
242 | """ 243 | sdata = from_region_files( 244 | GenomeFASTA( 245 | name=seq_name, 246 | fasta=fasta, 247 | batch_size=batch_size, 248 | n_threads=n_jobs * threads_per_job, 249 | alphabet=alphabet, 250 | ), 251 | BAM( 252 | name=cov_name, 253 | bams=bams, 254 | samples=samples, 255 | batch_size=batch_size, 256 | n_jobs=n_jobs, 257 | threads_per_job=threads_per_job, 258 | dtype=dtype, 259 | ), 260 | path=out, 261 | fixed_length=fixed_length, 262 | bed=bed, 263 | max_jitter=max_jitter, 264 | overwrite=overwrite, 265 | ) 266 | return sdata 267 | 268 | 269 | def read_bigwig( 270 | seq_name: str, 271 | cov_name: str, 272 | out: PathType, 273 | fasta: PathType, 274 | bigwigs: ListPathType, 275 | samples: List[str], 276 | bed: Union[PathType, "pd.DataFrame", "pl.DataFrame"], 277 | batch_size: int, 278 | fixed_length: Union[int, bool], 279 | n_jobs=1, 280 | threads_per_job=1, 281 | alphabet: Optional[Union[str, sp.NucleotideAlphabet]] = None, 282 | max_jitter=0, 283 | overwrite=False, 284 | ) -> "xr.Dataset": 285 | """ 286 | Read a bigWig file and return a Dataset. 287 | 288 | Parameters 289 | ---------- 290 | seq_name : str 291 | Name of the sequence variable in the output dataset. 292 | cov_name : str 293 | Name of the coverage variable in the output dataset. 294 | out : PathType 295 | Path to the output Zarr store where the data will be saved. 296 | Usually something like `/path/to/dataset_name.zarr`. 297 | fasta : PathType 298 | Path to the reference genome. 299 | bigwigs : ListPathType 300 | List of paths to bigWig files. 301 | Can be a single file or a list of files. 302 | samples : List[str] 303 | List of sample names to include. 304 | Should be the same length as `bigwigs`. 305 | bed : Union[PathType, pd.DataFrame] 306 | Path to a BED file or a DataFrame with columns "chrom", "start", and "end". 307 | batch_size : int 308 | Number of regions to read at once. Use as many as you can fit in memory. 309 | fixed_length : Union[int, bool] 310 | Whether your sequences have a fixed length or not. If they do, the data will be 311 | stored in a 2D array as bytes, otherwise it will be stored as unicode strings. 312 | n_jobs : int 313 | Number of parallel jobs. Use if you have multiple bigWig files. 314 | threads_per_job : int 315 | Number of threads per job. 316 | alphabet : Optional[Union[str, sp.NucleotideAlphabet]] 317 | Alphabet the sequences have. 318 | dtype : Union[str, Type[np.number]] 319 | Data type to use for coverage. 320 | max_jitter : int 321 | Maximum jitter to use for sampling regions. 322 | overwrite : bool 323 | Whether to overwrite an existing dataset. 324 | 325 | Returns 326 | ------- 327 | xr.Dataset 328 | Dataset with dimensions "_sequence" TODO: what are the dimensions? 
329 | """ 330 | sdata = from_region_files( 331 | GenomeFASTA( 332 | name=seq_name, 333 | fasta=fasta, 334 | batch_size=batch_size, 335 | n_threads=n_jobs * threads_per_job, 336 | alphabet=alphabet, 337 | ), 338 | BigWig( 339 | name=cov_name, 340 | bigwigs=bigwigs, 341 | samples=samples, 342 | batch_size=batch_size, 343 | n_jobs=n_jobs, 344 | threads_per_job=threads_per_job, 345 | ), 346 | path=out, 347 | fixed_length=fixed_length, 348 | bed=bed, 349 | max_jitter=max_jitter, 350 | overwrite=overwrite, 351 | ) 352 | return sdata 353 | 354 | 355 | def read_vcf( 356 | name: str, 357 | out: PathType, 358 | vcf: PathType, 359 | fasta: PathType, 360 | samples: List[str], 361 | bed: Union[PathType, "pd.DataFrame", "pl.DataFrame"], 362 | batch_size: int, 363 | fixed_length: Union[int, bool], 364 | n_threads=1, 365 | samples_per_chunk=10, 366 | alphabet: Optional[Union[str, sp.NucleotideAlphabet]] = None, 367 | max_jitter=0, 368 | overwrite=False, 369 | splice=False, 370 | ) -> "xr.Dataset": 371 | """ 372 | Read a VCF file and return a Dataset. 373 | 374 | Parameters 375 | ---------- 376 | name : str 377 | Name of the sequence variable in the output dataset. 378 | out : PathType 379 | Path to the output Zarr store where the data will be saved. 380 | Usually something like `/path/to/dataset_name.zarr`. 381 | vcf : PathType 382 | Path to the VCF file. 383 | fasta : PathType 384 | Path to the reference genome. 385 | samples : List[str] 386 | List of sample names to include. 387 | bed : Union[PathType, pd.DataFrame] 388 | Path to a BED file or a DataFrame with columns "chrom", "start", and "end". 389 | batch_size : int 390 | Number of regions to read at once. Use as many as you can fit in memory. 391 | fixed_length : Union[int, bool] 392 | Whether your sequences have a fixed length or not. If they do, the data will be 393 | stored in a 2D array as bytes, otherwise it will be stored as unicode strings. 394 | n_threads : int 395 | Number of threads to use for reading the VCF file. 396 | samples_per_chunk : int 397 | Number of samples to read at a time. 398 | alphabet : Optional[Union[str, sp.NucleotideAlphabet]] 399 | Alphabet the sequences have. 400 | max_jitter : int 401 | Maximum jitter to use for sampling regions. 402 | overwrite : bool 403 | Whether to overwrite an existing dataset. 
404 | splice : bool 405 | TODO 406 | Returns 407 | ------- 408 | xr.Dataset 409 | xarray dataset 410 | """ 411 | sdata = from_region_files( 412 | VCF( 413 | name=name, 414 | vcf=vcf, 415 | fasta=fasta, 416 | samples=samples, 417 | batch_size=batch_size, 418 | n_threads=n_threads, 419 | samples_per_chunk=samples_per_chunk, 420 | alphabet=alphabet, 421 | ), 422 | path=out, 423 | fixed_length=fixed_length, 424 | bed=bed, 425 | max_jitter=max_jitter, 426 | overwrite=overwrite, 427 | splice=splice, 428 | ) 429 | return sdata 430 | -------------------------------------------------------------------------------- /seqdata/_io/readers/__init__.py: -------------------------------------------------------------------------------- 1 | from .bam import BAM 2 | from .fasta import FlatFASTA, GenomeFASTA 3 | from .table import Table 4 | from .vcf import VCF 5 | from .wig import BigWig 6 | 7 | __all__ = ["BAM", "FlatFASTA", "GenomeFASTA", "VCF", "Table", "BigWig"] 8 | -------------------------------------------------------------------------------- /seqdata/_io/readers/bam.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from pathlib import Path 3 | from typing import Any, Dict, Generic, List, Literal, Optional, Type, Union, cast 4 | 5 | import joblib 6 | import numpy as np 7 | import polars as pl 8 | import pysam 9 | import zarr 10 | from more_itertools import split_when 11 | from numcodecs import ( 12 | Blosc, 13 | Delta, 14 | VLenArray, 15 | VLenUTF8, 16 | blosc, # type: ignore 17 | ) 18 | from numpy.typing import NDArray 19 | from tqdm import tqdm 20 | 21 | from seqdata._io.utils import _get_row_batcher 22 | from seqdata.types import DTYPE, PathType, RegionReader 23 | 24 | ### pysam implementation NOTE ### 25 | 26 | # pysam.AlignmentFile.count_coverage 27 | # contig not found => raises KeyError 28 | # start < 0 => raises ValueError 29 | # end > reference length => truncates interval 30 | 31 | 32 | class CountMethod(str, Enum): 33 | DEPTH = "depth-only" 34 | TN5_CUTSITE = "tn5-cutsite" 35 | TN5_FRAGMENT = "tn5-fragment" 36 | 37 | 38 | class BAM(RegionReader, Generic[DTYPE]): 39 | def __init__( 40 | self, 41 | name: str, 42 | bams: Union[str, Path, List[str], List[Path]], 43 | samples: Union[str, List[str]], 44 | batch_size: int, 45 | n_jobs=1, 46 | threads_per_job=1, 47 | dtype: Union[str, Type[np.number]] = np.uint16, 48 | sample_dim: Optional[str] = None, 49 | offset_tn5=False, 50 | count_method: Union[ 51 | CountMethod, Literal["depth-only", "tn5-cutsite", "tn5-fragment"] 52 | ] = "depth-only", 53 | ) -> None: 54 | """Reader for BAM files. 55 | 56 | Parameters 57 | ---------- 58 | name : str 59 | Name of the array this reader will write. 60 | bams : Union[str, Path, List[str], List[Path]] 61 | Path or a list of paths to BAM(s). 62 | samples : Union[str, List[str]] 63 | Sample names for each BAM. 64 | batch_size : int 65 | Number of sequences to write at a time. Note this also sets the chunksize 66 | along the sequence dimension. 67 | n_jobs : int, optional 68 | Number of BAMs to process in parallel, by default 1, which disables 69 | multiprocessing. Don't set this higher than the number of BAMs or number of 70 | cores available. 71 | threads_per_job : int, optional 72 | Threads to use per job, by default 1. Make sure the number of available 73 | cores is >= n_jobs * threads_per_job. 74 | dtype : Union[str, Type[np.number]], optional 75 | Data type to write the coverage as, by default np.uint16. 
76 | sample_dim : Optional[str], optional 77 | Name of the sample dimension, by default None 78 | offset_tn5 : bool, optional 79 | Whether to adjust read lengths to account for Tn5 binding, by default False 80 | count_method : Union[CountMethod, Literal["depth-only", "tn5-cutsite", "tn5-fragment"]] 81 | Count method, by default "depth-only" 82 | """ 83 | if isinstance(bams, str): 84 | bams = [bams] 85 | elif isinstance(bams, Path): 86 | bams = [bams] 87 | if isinstance(samples, str): 88 | samples = [samples] 89 | 90 | self.name = name 91 | self.total_reads_name = f"total_reads_{name}" 92 | self.bams = bams 93 | self.samples = samples 94 | self.batch_size = batch_size 95 | self.n_jobs = n_jobs 96 | self.threads_per_job = threads_per_job 97 | self.dtype = np.dtype(dtype) 98 | self.sample_dim = f"{name}_sample" if sample_dim is None else sample_dim 99 | self.offset_tn5 = offset_tn5 100 | self.count_method = CountMethod(count_method) 101 | 102 | def _write( 103 | self, 104 | out: PathType, 105 | bed: pl.DataFrame, 106 | fixed_length: Union[int, Literal[False]], 107 | sequence_dim: str, 108 | length_dim: Optional[str] = None, 109 | splice=False, 110 | overwrite=False, 111 | ) -> None: 112 | if fixed_length is False: 113 | self._write_variable_length(out, bed, sequence_dim, overwrite, splice) 114 | else: 115 | assert length_dim is not None 116 | self._write_fixed_length( 117 | out, bed, fixed_length, sequence_dim, length_dim, overwrite, splice 118 | ) 119 | 120 | def _write_fixed_length( 121 | self, 122 | out: PathType, 123 | bed: pl.DataFrame, 124 | fixed_length: int, 125 | sequence_dim: str, 126 | length_dim: str, 127 | overwrite: bool, 128 | splice: bool, 129 | ): 130 | compressor = Blosc("zstd", clevel=7, shuffle=-1) 131 | 132 | batch_size = min(len(bed), self.batch_size) 133 | root = zarr.open_group(out) 134 | 135 | arr = root.array( 136 | self.sample_dim, 137 | data=np.array(self.samples, object), 138 | compressor=compressor, 139 | overwrite=overwrite, 140 | object_codec=VLenUTF8(), 141 | ) 142 | arr.attrs["_ARRAY_DIMENSIONS"] = [self.sample_dim] 143 | 144 | coverage = root.zeros( 145 | self.name, 146 | shape=(len(bed), len(self.samples), fixed_length), 147 | dtype=self.dtype, 148 | chunks=(batch_size, 1, None), 149 | overwrite=overwrite, 150 | compressor=compressor, 151 | filters=[Delta(self.dtype)], 152 | ) 153 | coverage.attrs["_ARRAY_DIMENSIONS"] = [ 154 | sequence_dim, 155 | self.sample_dim, 156 | length_dim, 157 | ] 158 | 159 | total_reads = root.zeros( 160 | self.total_reads_name, 161 | shape=len(self.samples), 162 | dtype=np.uint64, 163 | chunks=None, 164 | overwrite=overwrite, 165 | compressor=compressor, 166 | ) 167 | total_reads.attrs["_ARRAY_DIMENSIONS"] = [self.sample_dim] 168 | 169 | sample_idxs = np.arange(len(self.samples)) 170 | tasks = [ 171 | joblib.delayed(self._read_bam_fixed_length)( 172 | root, 173 | bam, 174 | bed, 175 | batch_size, 176 | sample_idx, 177 | self.threads_per_job, 178 | fixed_length=fixed_length, 179 | splice=splice, 180 | ) 181 | for bam, sample_idx in zip(self.bams, sample_idxs) 182 | ] 183 | with joblib.parallel_backend( 184 | "loky", n_jobs=self.n_jobs, inner_max_num_threads=self.threads_per_job 185 | ): 186 | joblib.Parallel()(tasks) 187 | 188 | def _write_variable_length( 189 | self, 190 | out: PathType, 191 | bed: pl.DataFrame, 192 | sequence_dim: str, 193 | overwrite: bool, 194 | splice: bool, 195 | ): 196 | compressor = Blosc("zstd", clevel=7, shuffle=-1) 197 | 198 | batch_size = min(len(bed), self.batch_size) 199 | root = zarr.open_group(out) 
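# (Editor's annotation, not part of the original source.) _write_variable_length stores
# ragged per-region coverage: below, a sample-name coordinate array is written with a
# VLenUTF8 codec, the coverage array is created with object dtype and shape
# (n_regions, n_samples) using a VLenArray codec so each element holds one region's 1-D
# coverage, a per-sample total_reads array is added, and each BAM is then processed in
# its own joblib/loky worker.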
200 | 201 | arr = root.array( 202 | self.sample_dim, 203 | data=np.array(self.samples, object), 204 | compressor=compressor, 205 | overwrite=overwrite, 206 | object_codec=VLenUTF8(), 207 | ) 208 | arr.attrs["_ARRAY_DIMENSIONS"] = [self.sample_dim] 209 | 210 | coverage = root.empty( 211 | self.name, 212 | shape=(len(bed), len(self.samples)), 213 | dtype=object, 214 | chunks=(batch_size, 1), 215 | overwrite=overwrite, 216 | compressor=compressor, 217 | filters=[Delta(self.dtype)], 218 | object_codec=VLenArray(self.dtype), 219 | ) 220 | coverage.attrs["_ARRAY_DIMENSIONS"] = [ 221 | sequence_dim, 222 | self.sample_dim, 223 | ] 224 | 225 | total_reads = root.zeros( 226 | self.total_reads_name, 227 | shape=len(self.samples), 228 | dtype=np.uint64, 229 | chunks=None, 230 | overwrite=overwrite, 231 | compressor=compressor, 232 | ) 233 | total_reads.attrs["_ARRAY_DIMENSIONS"] = [self.sample_dim] 234 | 235 | sample_idxs = np.arange(len(self.samples)) 236 | tasks = [ 237 | joblib.delayed(self._read_bam_variable_length)( 238 | root, 239 | bam, 240 | bed, 241 | batch_size, 242 | sample_idx, 243 | self.threads_per_job, 244 | splice=splice, 245 | ) 246 | for bam, sample_idx in zip(self.bams, sample_idxs) 247 | ] 248 | with joblib.parallel_backend( 249 | "loky", n_jobs=self.n_jobs, inner_max_num_threads=self.threads_per_job 250 | ): 251 | joblib.Parallel()(tasks) 252 | 253 | def _read_bam_fixed_length( 254 | self, 255 | root: zarr.Group, 256 | bam: PathType, 257 | bed: pl.DataFrame, 258 | batch_size: int, 259 | sample_idx: int, 260 | n_threads: int, 261 | fixed_length: int, 262 | splice: bool, 263 | ): 264 | blosc.set_nthreads(n_threads) 265 | to_rc = cast(NDArray[np.bool_], bed["strand"].eq_missing("-").to_numpy()) 266 | 267 | batch = np.zeros((batch_size, fixed_length), self.dtype) 268 | 269 | with pysam.AlignmentFile(str(bam), threads=n_threads) as f: 270 | 271 | def read_cb(x: pysam.AlignedSegment): 272 | return x.is_proper_pair and not x.is_secondary 273 | 274 | total_reads = sum([f.count(c, read_callback=read_cb) for c in f.references]) 275 | root[self.total_reads_name][sample_idx] = total_reads 276 | 277 | reader = self._spliced_reader if splice else self._reader 278 | row_batcher = _get_row_batcher(reader(bed, f), batch_size) 279 | for is_last_row, is_last_in_batch, out, idx, start in row_batcher: 280 | batch[idx] = out 281 | if is_last_in_batch or is_last_row: 282 | _batch = batch[: idx + 1] 283 | to_rc_mask = to_rc[start : start + idx + 1] 284 | _batch[to_rc_mask] = _batch[to_rc_mask, ::-1] 285 | root[self.name][start : start + idx + 1, sample_idx] = _batch 286 | 287 | def _read_bam_variable_length( 288 | self, 289 | root: zarr.Group, 290 | bam: PathType, 291 | bed: pl.DataFrame, 292 | batch_size: int, 293 | sample_idx: int, 294 | n_threads: int, 295 | splice: bool, 296 | ): 297 | blosc.set_nthreads(n_threads) 298 | to_rc = cast(NDArray[np.bool_], bed["strand"].eq_missing("-").to_numpy()) 299 | 300 | batch = np.empty(batch_size, object) 301 | 302 | with pysam.AlignmentFile(str(bam), threads=n_threads) as f: 303 | 304 | def read_cb(x: pysam.AlignedSegment): 305 | return x.is_proper_pair and not x.is_secondary 306 | 307 | total_reads = sum([f.count(c, read_callback=read_cb) for c in f.references]) 308 | root[self.total_reads_name][sample_idx] = total_reads 309 | 310 | reader = self._spliced_reader if splice else self._reader 311 | row_batcher = _get_row_batcher(reader(bed, f), batch_size) 312 | for is_last_row, is_last_in_batch, out, idx, start in row_batcher: 313 | if to_rc[idx]: 314 | out = 
out[::-1] 315 | batch[idx] = out 316 | if is_last_in_batch or is_last_row: 317 | root[self.name][start : start + idx + 1, sample_idx] = batch[ 318 | : idx + 1 319 | ] 320 | 321 | def _reader(self, bed: pl.DataFrame, f: pysam.AlignmentFile): 322 | for row in tqdm(bed.iter_rows(), total=len(bed)): 323 | contig, start, end = row[:3] 324 | if self.count_method is CountMethod.DEPTH: 325 | coverage = self._count_depth_only(f, contig, start, end) 326 | else: 327 | coverage = self._count_tn5(f, contig, start, end) 328 | yield coverage 329 | 330 | def _spliced_reader(self, bed: pl.DataFrame, f: pysam.AlignmentFile): 331 | pbar = tqdm(total=len(bed)) 332 | for rows in split_when( 333 | bed.iter_rows(), 334 | lambda x, y: x[3] != y[3], # 4th column is "name" 335 | ): 336 | unspliced: List[NDArray[Any]] = [] 337 | for row in rows: 338 | pbar.update() 339 | contig, start, end = row[:3] 340 | if self.count_method is CountMethod.DEPTH: 341 | coverage = self._count_depth_only(f, contig, start, end) 342 | else: 343 | coverage = self._count_tn5(f, contig, start, end) 344 | unspliced.append(coverage) 345 | yield cast(NDArray[DTYPE], np.concatenate(unspliced)) # type: ignore 346 | 347 | def _count_depth_only( 348 | self, f: pysam.AlignmentFile, contig: str, start: int, end: int 349 | ): 350 | a, c, g, t = f.count_coverage( 351 | contig, 352 | max(start, 0), 353 | end, 354 | read_callback=lambda x: x.is_proper_pair and not x.is_secondary, 355 | ) 356 | coverage = np.vstack([a, c, g, t]).sum(0).astype(self.dtype) 357 | if (pad_len := end - start - len(coverage)) > 0: 358 | pad_arr = np.zeros(pad_len, dtype=self.dtype) 359 | pad_left = start < 0 360 | if pad_left: 361 | coverage = np.concatenate([pad_arr, coverage]) 362 | else: 363 | coverage = np.concatenate([coverage, pad_arr]) 364 | return coverage 365 | 366 | def _count_tn5(self, f: pysam.AlignmentFile, contig: str, start: int, end: int): 367 | length = end - start 368 | out_array = np.zeros(length, dtype=self.dtype) 369 | 370 | read_cache: Dict[str, pysam.AlignedSegment] = {} 371 | 372 | for i, read in enumerate(f.fetch(contig, max(0, start), end)): 373 | if not read.is_proper_pair or read.is_secondary: 374 | continue 375 | 376 | if read.query_name not in read_cache: 377 | read_cache[read.query_name] = read # type: ignore 378 | continue 379 | 380 | # Forward and Reverse w/o r1 and r2 381 | if read.is_reverse: 382 | forward_read = read_cache.pop(read.query_name) 383 | reverse_read = read 384 | else: 385 | forward_read = read 386 | reverse_read = read_cache.pop(read.query_name) 387 | 388 | rel_start = forward_read.reference_start - start 389 | # 0-based, 1 past aligned 390 | # e.g.
start:end == 0:2 == [0, 1] so position of end == 1 391 | rel_end = cast(int, reverse_read.reference_end) - start 392 | 393 | # Shift read if accounting for offset 394 | if self.offset_tn5: 395 | rel_start += 4 396 | rel_end -= 5 397 | 398 | # Check count method 399 | if self.count_method is CountMethod.TN5_CUTSITE: 400 | # Add cut sites to out_array 401 | if rel_start >= 0 and rel_start < length: 402 | out_array[rel_start] += 1 403 | if rel_end >= 0 and rel_end <= length: 404 | out_array[rel_end - 1] += 1 405 | elif self.count_method is CountMethod.TN5_FRAGMENT: 406 | # Add range to out array 407 | out_array[rel_start:rel_end] += 1 408 | 409 | # if any reads are still in the cache, then their mate isn't in the region 410 | for read in read_cache.values(): 411 | # for reverse reads, their mate is in the 5' <- direction 412 | if read.is_reverse: 413 | rel_end = cast(int, read.reference_end) - start 414 | if self.offset_tn5: 415 | rel_end -= 5 416 | if rel_end < 0 or rel_end > length: 417 | continue 418 | if self.count_method is CountMethod.TN5_CUTSITE: 419 | out_array[rel_end - 1] += 1 420 | elif self.count_method is CountMethod.TN5_FRAGMENT: 421 | out_array[:rel_end] += 1 422 | # for forward reads, their mate is in the 3' -> direction 423 | else: 424 | rel_start = read.reference_start - start 425 | if self.offset_tn5: 426 | rel_start += 4 427 | if rel_start < 0 or rel_start >= length: 428 | continue 429 | if self.count_method is CountMethod.TN5_CUTSITE: 430 | out_array[rel_start] += 1 431 | elif self.count_method is CountMethod.TN5_FRAGMENT: 432 | out_array[rel_start:] += 1 433 | 434 | return out_array 435 | -------------------------------------------------------------------------------- /seqdata/_io/readers/fasta.py: -------------------------------------------------------------------------------- 1 | from typing import List, Literal, Optional, Tuple, Union, cast 2 | 3 | import numpy as np 4 | import polars as pl 5 | import pysam 6 | import seqpro as sp 7 | import zarr 8 | from more_itertools import split_when 9 | from numcodecs import ( 10 | Blosc, 11 | VLenBytes, 12 | VLenUTF8, 13 | blosc, # type: ignore 14 | ) 15 | from numpy.typing import NDArray 16 | from tqdm import tqdm 17 | 18 | from seqdata._io.utils import _get_row_batcher 19 | from seqdata.types import FlatReader, PathType, RegionReader 20 | 21 | ### pysam and cyvcf2 implementation NOTE ### 22 | 23 | # pysam.FastaFile.fetch 24 | # contig not found => raises KeyError 25 | # if start < 0 => raises ValueError 26 | # if end > reference length => truncates interval 27 | 28 | 29 | class FlatFASTA(FlatReader): 30 | def __init__( 31 | self, 32 | name: str, 33 | fasta: PathType, 34 | batch_size: int, 35 | n_threads: int = 1, 36 | ) -> None: 37 | """Reader for flat FASTA files. 38 | 39 | Parameters 40 | ---------- 41 | name : str 42 | Name of the sequence array in resulting SeqData Zarr. 43 | fasta : str, Path 44 | Path to FASTA file. 45 | batch_size : int 46 | Number of sequences to read at once. 47 | n_threads : int, default 1 48 | Number of threads to use for reading the FASTA file. 
49 | 50 | Returns 51 | ------- 52 | None 53 | """ 54 | self.name = name 55 | self.fasta = fasta 56 | self.batch_size = batch_size 57 | self.n_threads = n_threads 58 | with pysam.FastaFile(str(self.fasta)) as f: 59 | self.n_seqs = len(f.references) 60 | 61 | def _reader(self, f: pysam.FastaFile): 62 | for seq_name in f.references: 63 | seq = f.fetch(seq_name).encode("ascii") 64 | yield seq 65 | 66 | def _write( 67 | self, 68 | out: PathType, 69 | fixed_length: bool, 70 | sequence_dim: str, 71 | length_dim: Optional[str] = None, 72 | overwrite=False, 73 | ) -> None: 74 | if self.name in (sequence_dim, length_dim): 75 | raise ValueError("Name cannot be equal to sequence_dim or length_dim.") 76 | 77 | blosc.set_nthreads(self.n_threads) 78 | compressor = Blosc("zstd", clevel=7, shuffle=-1) 79 | 80 | z = zarr.open_group(out) 81 | with pysam.FastaFile(str(self.fasta)) as f: 82 | seq_names = f.references 83 | length = f.get_reference_length(seq_names[0]) 84 | 85 | arr = z.array( 86 | sequence_dim, 87 | data=np.array(list(seq_names), object), 88 | overwrite=overwrite, 89 | object_codec=VLenUTF8(), 90 | ) 91 | arr.attrs["_ARRAY_DIMENSIONS"] = [sequence_dim] 92 | 93 | n_seqs = len(seq_names) 94 | batch_size = min(n_seqs, self.batch_size) 95 | 96 | if fixed_length: 97 | shape = (n_seqs, length) 98 | dtype = "|S1" 99 | chunks = (batch_size, None) 100 | object_codec = None 101 | seq_dims = [sequence_dim, length_dim] 102 | batch = np.empty((batch_size, length), dtype="|S1") 103 | else: 104 | shape = n_seqs 105 | dtype = object 106 | chunks = batch_size 107 | object_codec = VLenBytes() 108 | seq_dims = [sequence_dim] 109 | batch = np.empty(batch_size, dtype=object) 110 | 111 | seqs = z.empty( 112 | self.name, 113 | shape=shape, 114 | dtype=dtype, 115 | chunks=chunks, 116 | overwrite=overwrite, 117 | compressor=compressor, 118 | object_codec=object_codec, 119 | ) 120 | seqs.attrs["_ARRAY_DIMENSIONS"] = seq_dims 121 | 122 | row_batcher = _get_row_batcher(self._reader(f), batch_size) 123 | for last_row, last_in_batch, seq, batch_idx, start_idx in tqdm( 124 | row_batcher, total=n_seqs 125 | ): 126 | if fixed_length and len(seq) != length: 127 | raise RuntimeError( 128 | """ 129 | Fixed length FlatFASTA reader got sequences with different 130 | lengths. 
131 | """ 132 | ) 133 | if fixed_length: 134 | seq = np.frombuffer(seq, "|S1") 135 | batch[batch_idx] = seq 136 | if last_in_batch or last_row: 137 | seqs[start_idx : start_idx + batch_idx + 1] = batch[: batch_idx + 1] 138 | 139 | 140 | class GenomeFASTA(RegionReader): 141 | def __init__( 142 | self, 143 | name: str, 144 | fasta: PathType, 145 | batch_size: int, 146 | n_threads: int = 1, 147 | alphabet: Optional[Union[str, sp.NucleotideAlphabet]] = None, 148 | ) -> None: 149 | self.name = name 150 | self.fasta = fasta 151 | self.batch_size = batch_size 152 | self.n_threads = n_threads 153 | if alphabet is None: 154 | self.alphabet = sp.alphabets.DNA 155 | elif isinstance(alphabet, str): 156 | self.alphabet = getattr(sp.alphabets, alphabet) 157 | else: 158 | self.alphabet = alphabet 159 | 160 | def _reader(self, bed: pl.DataFrame, f: pysam.FastaFile): 161 | for row in tqdm(bed.iter_rows(), total=len(bed)): 162 | contig, start, end = cast(Tuple[str, int, int], row[:3]) 163 | seq = f.fetch(contig, max(0, start), end).encode("ascii") 164 | if (pad_len := end - start - len(seq)) > 0: 165 | pad_left = start < 0 166 | if pad_left: 167 | seq = (b"N" * pad_len) + seq 168 | else: 169 | seq += b"N" * pad_len 170 | yield seq 171 | 172 | def _spliced_reader(self, bed: pl.DataFrame, f: pysam.FastaFile): 173 | pbar = tqdm(total=len(bed)) 174 | for rows in split_when( 175 | bed.iter_rows(), 176 | lambda x, y: x[3] != y[3], # 4th column is "name" 177 | ): 178 | unspliced: List[bytes] = [] 179 | for row in rows: 180 | pbar.update() 181 | contig, start, end = cast(Tuple[str, int, int], row[:3]) 182 | seq = f.fetch(contig, max(0, start), end).encode("ascii") 183 | if (pad_len := end - start - len(seq)) > 0: 184 | pad_left = start < 0 185 | if pad_left: 186 | seq = (b"N" * pad_len) + seq 187 | else: 188 | seq += b"N" * pad_len 189 | unspliced.append(seq) 190 | spliced = b"".join(unspliced) 191 | yield spliced 192 | 193 | def _write( 194 | self, 195 | out: PathType, 196 | bed: pl.DataFrame, 197 | fixed_length: Union[int, Literal[False]], 198 | sequence_dim: str, 199 | length_dim: Optional[str] = None, 200 | splice=False, 201 | overwrite=False, 202 | ) -> None: 203 | if self.name in (sequence_dim, length_dim): 204 | raise ValueError("Name cannot be equal to sequence_dim or length_dim.") 205 | if fixed_length is False: 206 | self._write_variable_length( 207 | out=out, 208 | bed=bed, 209 | sequence_dim=sequence_dim, 210 | overwrite=overwrite, 211 | splice=splice, 212 | ) 213 | else: 214 | assert length_dim is not None 215 | self._write_fixed_length( 216 | out=out, 217 | bed=bed, 218 | fixed_length=fixed_length, 219 | sequence_dim=sequence_dim, 220 | length_dim=length_dim, 221 | overwrite=overwrite, 222 | splice=splice, 223 | ) 224 | 225 | def _write_fixed_length( 226 | self, 227 | out: PathType, 228 | bed: pl.DataFrame, 229 | fixed_length: int, 230 | sequence_dim: str, 231 | length_dim: str, 232 | overwrite: bool, 233 | splice: bool, 234 | ): 235 | blosc.set_nthreads(self.n_threads) 236 | compressor = Blosc("zstd", clevel=7, shuffle=-1) 237 | 238 | if splice: 239 | n_seqs = bed["name"].n_unique() 240 | else: 241 | n_seqs = len(bed) 242 | batch_size = min(n_seqs, self.batch_size) 243 | to_rc = cast(NDArray[np.bool_], bed["strand"].eq_missing("-").to_numpy()) 244 | 245 | root = zarr.open_group(out) 246 | 247 | seqs = root.empty( 248 | self.name, 249 | shape=(n_seqs, fixed_length), 250 | dtype="|S1", 251 | chunks=(batch_size, None), 252 | overwrite=overwrite, 253 | compressor=compressor, 254 | ) 255 | 
seqs.attrs["_ARRAY_DIMENSIONS"] = [sequence_dim, length_dim] 256 | 257 | batch = cast( 258 | NDArray[np.bytes_], np.empty((batch_size, fixed_length), dtype="|S1") 259 | ) 260 | 261 | with pysam.FastaFile(str(self.fasta)) as f: 262 | if splice: 263 | row_batcher = _get_row_batcher(self._spliced_reader(bed, f), batch_size) 264 | else: 265 | row_batcher = _get_row_batcher(self._reader(bed, f), batch_size) 266 | for is_last_row, is_last_in_batch, seq, idx, start in row_batcher: 267 | seq = np.frombuffer(seq, "|S1") 268 | batch[idx] = seq 269 | if is_last_in_batch or is_last_row: 270 | _batch = batch[: idx + 1] 271 | to_rc_mask = to_rc[start : start + idx + 1] 272 | _batch[to_rc_mask] = self.alphabet.rev_comp_byte( 273 | _batch[to_rc_mask], length_axis=-1 274 | ) 275 | seqs[start : start + idx + 1] = _batch 276 | 277 | def _write_variable_length( 278 | self, 279 | out: PathType, 280 | bed: pl.DataFrame, 281 | sequence_dim: str, 282 | overwrite: bool, 283 | splice: bool, 284 | ): 285 | blosc.set_nthreads(self.n_threads) 286 | compressor = Blosc("zstd", clevel=7, shuffle=-1) 287 | 288 | n_seqs = len(bed) 289 | batch_size = min(n_seqs, self.batch_size) 290 | to_rc = cast(NDArray[np.bool_], bed["strand"].eq_missing("-").to_numpy()) 291 | 292 | root = zarr.open_group(out) 293 | 294 | seqs = root.empty( 295 | self.name, 296 | shape=n_seqs, 297 | dtype=object, 298 | chunks=batch_size, 299 | overwrite=overwrite, 300 | compressor=compressor, 301 | object_codec=VLenBytes(), 302 | ) 303 | seqs.attrs["_ARRAY_DIMENSIONS"] = [sequence_dim] 304 | 305 | batch = cast(NDArray[np.object_], np.empty(batch_size, dtype=object)) 306 | 307 | with pysam.FastaFile(str(self.fasta)) as f: 308 | if splice: 309 | row_batcher = _get_row_batcher(self._spliced_reader(bed, f), batch_size) 310 | else: 311 | row_batcher = _get_row_batcher(self._reader(bed, f), batch_size) 312 | for is_last_row, is_last_in_batch, seq, idx, start in row_batcher: 313 | if to_rc[start + idx]: 314 | batch[idx] = self.alphabet.rev_comp_bstring(seq) 315 | else: 316 | batch[idx] = seq 317 | if is_last_in_batch or is_last_row: 318 | seqs[start : start + idx + 1] = batch[: idx + 1] 319 | -------------------------------------------------------------------------------- /seqdata/_io/readers/table.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from pathlib import Path 3 | from textwrap import dedent 4 | from typing import List, Optional, Union 5 | 6 | import numpy as np 7 | import polars as pl 8 | import zarr 9 | from numcodecs import Blosc, VLenBytes 10 | from tqdm import tqdm 11 | 12 | from seqdata._io.utils import _df_to_xr_zarr 13 | from seqdata.types import FlatReader, ListPathType, PathType 14 | 15 | 16 | class Table(FlatReader): 17 | def __init__( 18 | self, 19 | name: str, 20 | tables: Union[PathType, ListPathType], 21 | seq_col: str, 22 | batch_size: int, 23 | **kwargs, 24 | ) -> None: 25 | """Reader for tabular data. 26 | 27 | Parameters 28 | ---------- 29 | name : str 30 | Name of the data. 31 | tables : str, Path, List[str], List[Path] 32 | Path to table or list of paths to tables. 33 | seq_col : str 34 | Name of the column containing the sequences. 35 | batch_size : int 36 | Number of rows to read at a time. 37 | **kwargs 38 | Additional keyword arguments to pass to `pd.read_csv`. 
39 | 40 | Returns 41 | ------- 42 | None 43 | """ 44 | self.name = name 45 | if not isinstance(tables, list): 46 | tables = [Path(tables)] 47 | self.tables = list(map(Path, tables)) 48 | self.seq_col = seq_col 49 | self.batch_size = batch_size 50 | self.kwargs = kwargs 51 | 52 | def _get_reader(self, table: Path): 53 | if ".csv" in table.suffixes: 54 | sep = "," 55 | elif ".tsv" in table.suffixes: 56 | sep = "\t" 57 | elif ".txt" in table.suffixes: 58 | sep = "\t" 59 | else: 60 | sep = None 61 | if sep is None: 62 | return pl.read_csv_batched(table, batch_size=self.batch_size, **self.kwargs) 63 | else: 64 | return pl.read_csv_batched( 65 | table, separator=sep, batch_size=self.batch_size, **self.kwargs 66 | ) 67 | 68 | def _write_first_variable_length( 69 | self, 70 | batch: pl.DataFrame, 71 | root: zarr.Group, 72 | compressor, 73 | sequence_dim: str, 74 | overwrite: bool, 75 | ): 76 | seqs = batch[self.seq_col].cast(pl.Binary).to_numpy() 77 | obs = batch.drop(self.seq_col) 78 | arr = root.array( 79 | self.name, 80 | data=seqs, 81 | chunks=self.batch_size, 82 | compressor=compressor, 83 | overwrite=overwrite, 84 | object_codec=VLenBytes(), 85 | ) 86 | arr.attrs["_ARRAY_DIMENSIONS"] = [sequence_dim] 87 | _df_to_xr_zarr( 88 | obs, 89 | root, 90 | sequence_dim, 91 | chunks=self.batch_size, 92 | compressor=compressor, 93 | overwrite=overwrite, 94 | ) 95 | first_cols = obs.columns 96 | return first_cols 97 | 98 | def _write_variable_length( 99 | self, batch: pl.DataFrame, root: zarr.Group, first_cols: List, table: Path 100 | ): 101 | seqs = batch[self.seq_col].cast(pl.Binary).to_numpy() 102 | obs = batch.drop(self.seq_col) 103 | if ( 104 | np.isin(obs.columns, first_cols, invert=True).any() 105 | or np.isin(first_cols, obs.columns, invert=True).any() 106 | ): 107 | raise RuntimeError( 108 | dedent( 109 | f"""Mismatching columns. 110 | First table {self.tables[0]} has columns {first_cols} 111 | Mismatched table {table} has columns {obs.columns} 112 | """ 113 | ).strip() 114 | ) 115 | root[self.name].append(seqs) # type: ignore 116 | for series in obs: 117 | root[series.name].append(series.to_numpy()) # type: ignore 118 | 119 | def _write_first_fixed_length( 120 | self, 121 | batch: pl.DataFrame, 122 | root: zarr.Group, 123 | compressor, 124 | sequence_dim: str, 125 | length_dim: str, 126 | overwrite: bool, 127 | ): 128 | seqs = ( 129 | batch[self.seq_col] 130 | .cast(pl.Binary) 131 | .to_numpy() 132 | .astype("S")[..., None] 133 | .view("S1") 134 | ) 135 | obs = batch.drop(self.seq_col) 136 | arr = root.array( 137 | self.name, 138 | data=seqs, 139 | chunks=(self.batch_size, None), 140 | compressor=compressor, 141 | overwrite=overwrite, 142 | ) 143 | arr.attrs["_ARRAY_DIMENSIONS"] = [sequence_dim, length_dim] 144 | _df_to_xr_zarr( 145 | obs, 146 | root, 147 | sequence_dim, 148 | chunks=self.batch_size, 149 | compressor=compressor, 150 | overwrite=overwrite, 151 | ) 152 | first_cols = obs.columns 153 | return first_cols 154 | 155 | def _write_fixed_length( 156 | self, batch: pl.DataFrame, root: zarr.Group, first_cols: List, table: Path 157 | ): 158 | seqs = ( 159 | batch[self.seq_col] 160 | .cast(pl.Binary) 161 | .to_numpy() 162 | .astype("S")[..., None] 163 | .view("S1") 164 | ) 165 | obs = batch.drop(self.seq_col) 166 | if ( 167 | np.isin(obs.columns, first_cols, invert=True).any() 168 | or np.isin(first_cols, obs.columns, invert=True).any() 169 | ): 170 | raise RuntimeError( 171 | dedent( 172 | f"""Mismatching columns. 
173 | First table {self.tables[0]} has columns {first_cols} 174 | Mismatched table {table} has columns {obs.columns} 175 | """ 176 | ).strip() 177 | ) 178 | root[self.name].append(seqs) # type: ignore 179 | for series in obs: 180 | root[series.name].append(series.to_numpy()) # type: ignore 181 | 182 | def _write( 183 | self, 184 | out: PathType, 185 | fixed_length: bool, 186 | sequence_dim: str, 187 | length_dim: Optional[str] = None, 188 | overwrite=False, 189 | ) -> None: 190 | compressor = Blosc("zstd", clevel=7, shuffle=-1) 191 | root = zarr.open_group(out) 192 | 193 | if fixed_length: 194 | assert length_dim is not None 195 | write_first = partial( 196 | self._write_first_fixed_length, 197 | sequence_dim=sequence_dim, 198 | length_dim=length_dim, 199 | ) 200 | write_batch = self._write_fixed_length 201 | else: 202 | write_first = partial( 203 | self._write_first_variable_length, sequence_dim=sequence_dim 204 | ) 205 | write_batch = self._write_variable_length 206 | 207 | pbar = tqdm() 208 | first_batch = True 209 | for table in self.tables: 210 | reader = self._get_reader(table) 211 | while batch := reader.next_batches(1): 212 | batch = batch[0] 213 | if first_batch: 214 | first_cols = write_first( 215 | batch=batch, 216 | root=root, 217 | compressor=compressor, 218 | overwrite=overwrite, 219 | ) 220 | first_batch = False 221 | else: 222 | write_batch( 223 | batch, 224 | root, 225 | first_cols, # type: ignore guaranteed to be set during first batch 226 | table, 227 | ) 228 | pbar.update(len(batch)) 229 | -------------------------------------------------------------------------------- /seqdata/_io/readers/vcf.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | from typing import List, Literal, Optional, Set, Tuple, Union, cast 4 | 5 | import cyvcf2 6 | import numpy as np 7 | import polars as pl 8 | import pysam 9 | import seqpro as sp 10 | import zarr 11 | from more_itertools import split_when 12 | from natsort import natsorted 13 | from numcodecs import ( 14 | Blosc, 15 | VLenBytes, 16 | VLenUTF8, 17 | blosc, # type: ignore 18 | ) 19 | from numpy.typing import NDArray 20 | from tqdm import tqdm 21 | 22 | from seqdata._io.utils import _get_row_batcher 23 | from seqdata.types import PathType, RegionReader 24 | 25 | N_HAPLOTYPES = 2 26 | 27 | 28 | ### pysam and cyvcf2 implementation NOTE ### 29 | 30 | # pysam.FastaFile.fetch 31 | # contig not found => raises KeyError 32 | # if start < 0 => raises ValueError 33 | # if end > reference length => truncates interval 34 | 35 | # cyvcf2.VCF 36 | # Contig not found => warning 37 | # start < 0 => warning 38 | # start = 0 (despite being 1-indexed) => nothing 39 | # end > contig length => treats end = contig length 40 | 41 | 42 | class VCF(RegionReader): 43 | name: str 44 | vcf: Path 45 | fasta: Path 46 | contigs: List[str] 47 | 48 | def __init__( 49 | self, 50 | name: str, 51 | vcf: PathType, 52 | fasta: PathType, 53 | batch_size: int, 54 | samples: Optional[List[str]] = None, 55 | n_threads=1, 56 | samples_per_chunk=10, 57 | alphabet: Optional[Union[str, sp.NucleotideAlphabet]] = None, 58 | sample_dim: Optional[str] = None, 59 | haplotype_dim: Optional[str] = None, 60 | ) -> None: 61 | """Warning: This reader is experimental and may change in the future. 62 | For a more comprehensive VCF reader, see https://github.com/mcvickerlab/GenVarLoader. 63 | 64 | Reader for variant call format (VCF) data. 
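        A minimal construction sketch (the paths and sample names are illustrative, not from the original source);
        in practice the reader is handed to ``from_region_files`` together with a BED file and an output Zarr path,
        which is exactly how ``read_vcf`` uses it:

            reader = VCF(
                name="variant_seqs",
                vcf="variants.vcf.gz",
                fasta="genome.fa",
                batch_size=128,
                samples=["sample1", "sample2"],
            )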
65 | 66 | Parameters 67 | ---------- 68 | name : str 69 | Name of the data. 70 | vcf : str, Path 71 | Path to the VCF file. 72 | fasta : str, Path 73 | Path to the FASTA file. 74 | batch_size : int 75 | Number of sequences to read at a time. 76 | samples : List[str], optional 77 | List of sample names to read from the VCF. 78 | n_threads : int, optional 79 | Number of threads to use when reading the VCF. 80 | samples_per_chunk : int, optional 81 | Number of samples to read at a time. 82 | alphabet : str, NucleotideAlphabet, optional 83 | Alphabet to use for the sequences. 84 | sample_dim : str, optional 85 | Name of the sample dimension. 86 | haplotype_dim : str, optional 87 | Name of the haplotype dimension. 88 | 89 | Returns 90 | ------- 91 | None 92 | """ 93 | self.name = name 94 | self.vcf = Path(vcf) 95 | self.fasta = Path(fasta) 96 | self.batch_size = batch_size 97 | self.n_threads = n_threads 98 | self.samples_per_chunk = samples_per_chunk 99 | if alphabet is None: 100 | self.alphabet = sp.alphabets.DNA 101 | elif isinstance(alphabet, str): 102 | self.alphabet = getattr(sp.alphabets, alphabet) 103 | else: 104 | self.alphabet = alphabet 105 | self.sample_dim = f"{name}_sample" if sample_dim is None else sample_dim 106 | self.haplotype_dim = ( 107 | f"{name}_haplotype" if haplotype_dim is None else haplotype_dim 108 | ) 109 | 110 | with pysam.FastaFile(str(fasta)) as f: 111 | fasta_contigs = set(f.references) 112 | _vcf = cyvcf2.VCF(str(vcf), samples=samples) 113 | self.samples = _vcf.samples if samples is None else samples 114 | try: 115 | vcf_contigs = cast(Set[str], set(_vcf.seqnames)) 116 | except AttributeError: 117 | warnings.warn("VCF header has no contig annotations.") 118 | vcf_contigs: Set[str] = set() 119 | _vcf.close() 120 | 121 | self.contigs = cast(List[str], natsorted(fasta_contigs | vcf_contigs)) 122 | if len(self.contigs) == 0: 123 | raise RuntimeError("FASTA has no contigs.") 124 | # * don't check for contigs exclusive to FASTA because VCF is not guaranteed to 125 | # * have any contigs listed 126 | contigs_exclusive_to_vcf = natsorted(vcf_contigs - fasta_contigs) 127 | if contigs_exclusive_to_vcf: 128 | warnings.warn( 129 | f"""VCF has contigs not found in FASTA that may indicate variant calling 130 | was against a different reference than what was given to SeqData. 
131 | Contigs that are exclusive to the VCF are: {contigs_exclusive_to_vcf}""" 132 | ) 133 | 134 | def _get_pos_alleles(self, v) -> Tuple[int, NDArray[np.bytes_]]: 135 | # change to bytes and extract alleles 136 | # (samples haplotypes) 137 | alleles = v.gt_bases.astype("S").reshape(-1, 1).view("S1")[:, [0, 2]] 138 | # change unknown to reference 139 | alleles[alleles == b"."] = bytes(v.REF, "ascii") 140 | # make position 0-indexed 141 | return v.POS - 1, alleles 142 | 143 | def _reader( 144 | self, 145 | bed: pl.DataFrame, 146 | f: pysam.FastaFile, 147 | vcf: cyvcf2.VCF, 148 | sample_order: NDArray[np.intp], 149 | ): 150 | for row in tqdm(bed.iter_rows(), total=len(bed)): 151 | contig, start, end = row[:3] 152 | start, end = cast(int, start), cast(int, end) 153 | seq_bytes = f.fetch(contig, max(start, 0), end).encode("ascii") 154 | pad_left = -min(start, 0) 155 | pad_right = end - start - len(seq_bytes) - pad_left 156 | seq_bytes = b"N" * pad_left + seq_bytes + b"N" * pad_right 157 | seq = cast(NDArray[np.bytes_], np.array([seq_bytes], "S").view("S1")) 158 | # (samples haplotypes length) 159 | tiled_seq = np.tile(seq, (len(self.samples), N_HAPLOTYPES, 1)) 160 | 161 | region = f"{contig}:{max(start, 0)+1}-{end}" 162 | positions_alleles = [ 163 | self._get_pos_alleles(v) for v in vcf(region) if v.is_snp 164 | ] 165 | 166 | # no variants in region 167 | if len(positions_alleles) == 0: 168 | yield tiled_seq 169 | continue 170 | 171 | positions_ls, alleles_ls = zip(*positions_alleles) 172 | # (variants) 173 | relative_positions = cast(NDArray[np.int64], np.array(positions_ls)) - start 174 | # (samples haplotypes variants) 175 | alleles = cast(NDArray[np.bytes_], np.stack(alleles_ls, -1)[sample_order]) 176 | # (samples haplotypes variants) = (samples haplotypes variants) 177 | tiled_seq[..., relative_positions] = alleles 178 | # (samples haplotypes length) 179 | yield tiled_seq 180 | 181 | def _spliced_reader( 182 | self, 183 | bed: pl.DataFrame, 184 | f: pysam.FastaFile, 185 | vcf: cyvcf2.VCF, 186 | sample_order: NDArray[np.intp], 187 | ): 188 | pbar = tqdm(total=len(bed)) 189 | for rows in split_when( 190 | bed.iter_rows(), 191 | lambda x, y: x[3] != y[3], # 4th column is "name" 192 | ): 193 | unspliced: List[NDArray[np.bytes_]] = [] 194 | for row in rows: 195 | pbar.update() 196 | contig, start, end = row[:3] 197 | start, end = cast(int, start), cast(int, end) 198 | seq_bytes = f.fetch(contig, max(start, 0), end).encode("ascii") 199 | pad_left = -min(start, 0) 200 | pad_right = end - start - len(seq_bytes) - pad_left 201 | seq_bytes = b"N" * pad_left + seq_bytes + b"N" * pad_right 202 | seq = cast(NDArray[np.bytes_], np.frombuffer(seq_bytes, "|S1")) 203 | # (samples haplotypes length) 204 | tiled_seq = np.tile(seq, (len(self.samples), 2, 1)) 205 | 206 | region = f"{contig}:{max(start, 0)+1}-{end}" 207 | positions_alleles = [ 208 | self._get_pos_alleles(v) for v in vcf(region) if v.is_snp 209 | ] 210 | # no variants in region 211 | if len(positions_alleles) == 0: 212 | unspliced.append(tiled_seq) 213 | continue 214 | 215 | positions_ls, alleles_ls = zip(*positions_alleles) 216 | # (variants) 217 | relative_positions = ( 218 | cast(NDArray[np.int64], np.array(positions_ls)) - start 219 | ) 220 | # (samples haplotypes variants) 221 | alleles = cast( 222 | NDArray[np.bytes_], np.stack(alleles_ls, -1)[sample_order] 223 | ) 224 | # (samples haplotypes variants) = (samples haplotypes variants) 225 | tiled_seq[..., relative_positions] = alleles 226 | unspliced.append(tiled_seq) 227 | # list of 
(samples haplotypes length) 228 | yield np.concatenate(unspliced, -1) 229 | 230 | def _write( 231 | self, 232 | out: PathType, 233 | bed: pl.DataFrame, 234 | fixed_length: Union[int, Literal[False]], 235 | sequence_dim: str, 236 | length_dim: Optional[str] = None, 237 | splice=False, 238 | overwrite=False, 239 | ) -> None: 240 | if self.name in (sequence_dim, self.sample_dim, self.haplotype_dim, length_dim): 241 | raise ValueError( 242 | """Name cannot be equal to sequence_dim, sample_dim, haplotype_dim, or 243 | length_dim.""" 244 | ) 245 | 246 | if fixed_length is False: 247 | self._write_variable_length( 248 | out=out, 249 | bed=bed, 250 | sequence_dim=sequence_dim, 251 | overwrite=overwrite, 252 | splice=splice, 253 | ) 254 | else: 255 | assert length_dim is not None 256 | self._write_fixed_length( 257 | out=out, 258 | bed=bed, 259 | fixed_length=fixed_length, 260 | sequence_dim=sequence_dim, 261 | length_dim=length_dim, 262 | overwrite=overwrite, 263 | splice=splice, 264 | ) 265 | 266 | def _write_fixed_length( 267 | self, 268 | out: PathType, 269 | bed: pl.DataFrame, 270 | fixed_length: int, 271 | sequence_dim: str, 272 | length_dim: str, 273 | overwrite: bool, 274 | splice: bool, 275 | ): 276 | blosc.set_nthreads(self.n_threads) 277 | compressor = Blosc("zstd", clevel=7, shuffle=-1) 278 | 279 | n_seqs = bed["name"].n_unique() if splice else len(bed) 280 | batch_size = min(n_seqs, self.batch_size) 281 | 282 | z = zarr.open_group(out) 283 | 284 | seqs = z.empty( 285 | self.name, 286 | shape=(n_seqs, len(self.samples), N_HAPLOTYPES, fixed_length), 287 | dtype="|S1", 288 | chunks=(batch_size, self.samples_per_chunk, 1, None), 289 | overwrite=overwrite, 290 | compressor=compressor, 291 | ) 292 | seqs.attrs["_ARRAY_DIMENSIONS"] = [ 293 | sequence_dim, 294 | self.sample_dim, 295 | self.haplotype_dim, 296 | length_dim, 297 | ] 298 | 299 | arr = z.array( 300 | self.sample_dim, 301 | np.array(self.samples, object), 302 | compressor=compressor, 303 | overwrite=overwrite, 304 | object_codec=VLenUTF8(), 305 | ) 306 | arr.attrs["_ARRAY_DIMENSIONS"] = [self.sample_dim] 307 | 308 | to_rc = cast(NDArray[np.bool_], bed["strand"].eq_missing("-").to_numpy()) 309 | 310 | _vcf = cyvcf2.VCF( 311 | self.vcf, lazy=True, samples=self.samples, threads=self.n_threads 312 | ) 313 | *_, sample_order = np.intersect1d( 314 | _vcf.samples, self.samples, assume_unique=True, return_indices=True 315 | ) 316 | 317 | # (batch samples haplotypes length) 318 | batch = cast( 319 | NDArray[np.bytes_], np.empty((batch_size, *seqs.shape[1:]), dtype="|S1") 320 | ) 321 | 322 | with pysam.FastaFile(str(self.fasta)) as f: 323 | if splice: 324 | reader = self._spliced_reader 325 | else: 326 | reader = self._reader 327 | row_batcher = _get_row_batcher( 328 | reader(bed, f, _vcf, sample_order), batch_size 329 | ) 330 | for is_last_row, is_last_in_batch, seq, idx, start in row_batcher: 331 | # (samples haplotypes length) 332 | batch[idx] = seq 333 | if is_last_in_batch or is_last_row: 334 | _batch = batch[: idx + 1] 335 | to_rc_mask = to_rc[start : start + idx + 1] 336 | _batch[to_rc_mask] = self.alphabet.rev_comp_byte( 337 | _batch[to_rc_mask], length_axis=-1 338 | ) 339 | seqs[start : start + idx + 1] = _batch[: idx + 1] 340 | 341 | _vcf.close() 342 | 343 | def _write_variable_length( 344 | self, 345 | out: PathType, 346 | bed: pl.DataFrame, 347 | sequence_dim: str, 348 | overwrite: bool, 349 | splice: bool, 350 | ): 351 | blosc.set_nthreads(self.n_threads) 352 | compressor = Blosc("zstd", clevel=7, shuffle=-1) 353 | 354 | n_seqs = 
bed["name"].n_unique() if splice else len(bed) 355 | batch_size = min(n_seqs, self.batch_size) 356 | 357 | z = zarr.open_group(out) 358 | 359 | seqs = z.empty( 360 | self.name, 361 | shape=(n_seqs, len(self.samples), N_HAPLOTYPES), 362 | dtype=object, 363 | chunks=(batch_size, self.samples_per_chunk, 1), 364 | overwrite=overwrite, 365 | compressor=compressor, 366 | object_codec=VLenBytes(), 367 | ) 368 | seqs.attrs["_ARRAY_DIMENSIONS"] = [ 369 | sequence_dim, 370 | self.sample_dim, 371 | self.haplotype_dim, 372 | ] 373 | 374 | arr = z.array( 375 | self.sample_dim, 376 | np.array(self.samples, object), 377 | compressor=compressor, 378 | overwrite=overwrite, 379 | object_codec=VLenUTF8(), 380 | ) 381 | arr.attrs["_ARRAY_DIMENSIONS"] = [self.sample_dim] 382 | 383 | to_rc = cast(NDArray[np.bool_], bed["strand"].eq_missing("-").to_numpy()) 384 | 385 | _vcf = cyvcf2.VCF( 386 | self.vcf, lazy=True, samples=self.samples, threads=self.n_threads 387 | ) 388 | *_, sample_order = np.intersect1d( 389 | _vcf.samples, self.samples, assume_unique=True, return_indices=True 390 | ) 391 | 392 | # (batch samples haplotypes) 393 | batch = cast( 394 | NDArray[np.object_], np.empty((batch_size, *seqs.shape[1:]), dtype=object) 395 | ) 396 | 397 | with pysam.FastaFile(str(self.fasta)) as f: 398 | if splice: 399 | reader = self._spliced_reader 400 | else: 401 | reader = self._reader 402 | row_batcher = _get_row_batcher( 403 | reader(bed, f, _vcf, sample_order), batch_size 404 | ) 405 | for is_last_row, is_last_in_batch, seq, idx, start in row_batcher: 406 | # (samples haplotypes length) 407 | if to_rc[idx]: 408 | seq = self.alphabet.rev_comp_byte(seq, length_axis=-1) 409 | # (samples haplotypes) 410 | batch[idx] = seq.view(f"|S{seq.shape[-1]}").squeeze().astype(object) 411 | if is_last_in_batch or is_last_row: 412 | seqs[start : start + idx + 1] = batch[: idx + 1] 413 | 414 | _vcf.close() 415 | 416 | def _sequence_generator(self, bed: pl.DataFrame, splice=False): 417 | to_rc = cast(NDArray[np.bool_], bed["strand"].eq_missing("-").to_numpy()) 418 | 419 | _vcf = cyvcf2.VCF( 420 | self.vcf, lazy=True, samples=self.samples, threads=self.n_threads 421 | ) 422 | *_, sample_order = np.intersect1d( 423 | _vcf.samples, self.samples, assume_unique=True, return_indices=True 424 | ) 425 | 426 | with pysam.FastaFile(str(self.fasta)) as f: 427 | if splice: 428 | reader = self._spliced_reader 429 | else: 430 | reader = self._reader 431 | # (samples haplotypes length) 432 | for i, seqs in enumerate(reader(bed, f, _vcf, sample_order)): 433 | if to_rc[i]: 434 | seqs = self.alphabet.rev_comp_byte(seqs, length_axis=-1) 435 | seqs = seqs.view(f"|S{seqs.shape[-1]}") 436 | for seq in seqs.ravel(): 437 | yield seq 438 | -------------------------------------------------------------------------------- /seqdata/_io/readers/wig.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Dict, List, Literal, Optional, Union, cast 3 | 4 | import joblib 5 | import numpy as np 6 | import polars as pl 7 | import pyBigWig 8 | import zarr 9 | from more_itertools import split_when 10 | from numcodecs import ( 11 | Blosc, 12 | Delta, 13 | VLenArray, 14 | VLenUTF8, 15 | blosc, # type: ignore 16 | ) 17 | from numpy.typing import NDArray 18 | from tqdm import tqdm 19 | 20 | from seqdata._io.utils import _get_row_batcher 21 | from seqdata.types import ListPathType, PathType, RegionReader 22 | 23 | 24 | class BigWig(RegionReader): 25 | DTYPE = np.float32 # BigWig only 
supports float32 26 | 27 | def __init__( 28 | self, 29 | name: str, 30 | bigwigs: ListPathType, 31 | samples: List[str], 32 | batch_size: int, 33 | n_jobs=1, 34 | threads_per_job=1, 35 | sample_dim: Optional[str] = None, 36 | ) -> None: 37 | """Reader for BigWig files. 38 | 39 | Parameters 40 | ---------- 41 | name : str 42 | Name of sequence variable in resulting Zarr. 43 | bigwigs : List[str], List[Path] 44 | Paths to BigWig files. 45 | samples : List[str] 46 | Names of samples corresponding to BigWig files. 47 | batch_size : int 48 | Number of regions to read at a time. 49 | n_jobs : int, default 1 50 | Number of jobs to run in parallel. 51 | threads_per_job : int, default 1 52 | Number of threads per job. 53 | sample_dim : str, default None 54 | Name of sample dimension. 55 | 56 | Returns 57 | ------- 58 | None 59 | """ 60 | self.name = name 61 | self.bigwigs = list(map(Path, bigwigs)) 62 | self.samples = samples 63 | self.batch_size = batch_size 64 | self.n_jobs = n_jobs 65 | self.threads_per_job = threads_per_job 66 | self.sample_dim = f"{name}_sample" if sample_dim is None else sample_dim 67 | 68 | def _write( 69 | self, 70 | out: PathType, 71 | bed: pl.DataFrame, 72 | fixed_length: Union[int, Literal[False]], 73 | sequence_dim: str, 74 | length_dim: Optional[str] = None, 75 | splice=False, 76 | overwrite=False, 77 | ) -> None: 78 | if self.name in (sequence_dim, self.sample_dim, length_dim): 79 | raise ValueError( 80 | "Name cannot be equal to sequence_dim, sample_dim, or length_dim." 81 | ) 82 | if fixed_length is False: 83 | self._write_variable_length( 84 | out=out, 85 | bed=bed, 86 | sequence_dim=sequence_dim, 87 | overwrite=overwrite, 88 | splice=splice, 89 | ) 90 | else: 91 | assert length_dim is not None 92 | self._write_fixed_length( 93 | out=out, 94 | bed=bed, 95 | fixed_length=fixed_length, 96 | sequence_dim=sequence_dim, 97 | length_dim=length_dim, 98 | overwrite=overwrite, 99 | splice=splice, 100 | ) 101 | 102 | def _write_fixed_length( 103 | self, 104 | out: PathType, 105 | bed: pl.DataFrame, 106 | fixed_length: int, 107 | sequence_dim: str, 108 | length_dim: str, 109 | overwrite: bool, 110 | splice: bool, 111 | ): 112 | compressor = Blosc("zstd", clevel=7, shuffle=-1) 113 | 114 | n_seqs = bed["name"].n_unique() if splice else len(bed) 115 | batch_size = min(n_seqs, self.batch_size) 116 | z = zarr.open_group(out) 117 | 118 | arr = z.array( 119 | self.sample_dim, 120 | data=np.array(self.samples, object), 121 | compressor=compressor, 122 | overwrite=overwrite, 123 | object_codec=VLenUTF8(), 124 | ) 125 | arr.attrs["_ARRAY_DIMENSIONS"] = [self.sample_dim] 126 | 127 | coverage = z.zeros( 128 | self.name, 129 | shape=(n_seqs, len(self.samples), fixed_length), 130 | dtype=self.DTYPE, 131 | chunks=(batch_size, 1, None), 132 | overwrite=overwrite, 133 | compressor=compressor, 134 | filters=[Delta(self.DTYPE)], 135 | ) 136 | coverage.attrs["_ARRAY_DIMENSIONS"] = [ 137 | sequence_dim, 138 | self.sample_dim, 139 | length_dim, 140 | ] 141 | 142 | sample_idxs = np.arange(len(self.samples)) 143 | tasks = [ 144 | joblib.delayed(self._read_bigwig_fixed_length)( 145 | coverage, 146 | bigwig, 147 | bed, 148 | batch_size, 149 | sample_idx, 150 | self.threads_per_job, 151 | fixed_length=fixed_length, 152 | splice=splice, 153 | ) 154 | for bigwig, sample_idx in zip(self.bigwigs, sample_idxs) 155 | ] 156 | with joblib.parallel_backend( 157 | "loky", n_jobs=self.n_jobs, inner_max_num_threads=self.threads_per_job 158 | ): 159 | joblib.Parallel()(tasks) 160 | 161 | def _write_variable_length( 
162 | self, 163 | out: PathType, 164 | bed: pl.DataFrame, 165 | sequence_dim: str, 166 | overwrite: bool, 167 | splice: bool, 168 | ): 169 | compressor = Blosc("zstd", clevel=7, shuffle=-1) 170 | 171 | n_seqs = bed["name"].n_unique() if splice else len(bed) 172 | batch_size = min(n_seqs, self.batch_size) 173 | z = zarr.open_group(out) 174 | 175 | arr = z.array( 176 | self.sample_dim, 177 | data=np.array(self.samples, object), 178 | compressor=compressor, 179 | overwrite=overwrite, 180 | object_codec=VLenUTF8(), 181 | ) 182 | arr.attrs["_ARRAY_DIMENSIONS"] = [self.sample_dim] 183 | 184 | coverage = z.empty( 185 | self.name, 186 | shape=(n_seqs, len(self.samples)), 187 | dtype=object, 188 | chunks=(batch_size, 1), 189 | overwrite=overwrite, 190 | compressor=compressor, 191 | filters=[Delta(self.DTYPE)], 192 | object_codec=VLenArray(self.DTYPE), 193 | ) 194 | coverage.attrs["_ARRAY_DIMENSIONS"] = [ 195 | sequence_dim, 196 | self.sample_dim, 197 | ] 198 | 199 | sample_idxs = np.arange(len(self.samples)) 200 | tasks = [ 201 | joblib.delayed(self._read_bigwig_variable_length)( 202 | coverage, 203 | bigwig, 204 | bed, 205 | batch_size, 206 | sample_idx, 207 | self.threads_per_job, 208 | splice=splice, 209 | ) 210 | for bigwig, sample_idx in zip(self.bigwigs, sample_idxs) 211 | ] 212 | with joblib.parallel_backend( 213 | "loky", n_jobs=self.n_jobs, inner_max_num_threads=self.threads_per_job 214 | ): 215 | joblib.Parallel()(tasks) 216 | 217 | def _reader(self, bed: pl.DataFrame, f, contig_lengths: Dict[str, int]): 218 | for row in tqdm(bed.iter_rows(), total=len(bed)): 219 | contig, start, end = row[:3] 220 | pad_left = max(-start, 0) 221 | pad_right = max(end - contig_lengths[contig], 0) 222 | pad_right_idx = end - start - pad_right 223 | out = np.empty(end - start, dtype=self.DTYPE) 224 | out[:pad_left] = 0 225 | out[pad_right_idx:] = 0 226 | values = cast( 227 | NDArray, 228 | f.values( 229 | contig, max(0, start), min(contig_lengths[contig], end), numpy=True 230 | ), 231 | ) 232 | np.nan_to_num(values, copy=False) 233 | out[pad_left:pad_right_idx] = values 234 | yield out 235 | 236 | def _spliced_reader(self, bed: pl.DataFrame, f, contig_lengths: Dict[str, int]): 237 | pbar = tqdm(total=len(bed)) 238 | for rows in split_when( 239 | bed.iter_rows(), 240 | lambda x, y: x[3] != y[3], # 4th column is "name" 241 | ): 242 | unspliced: List[NDArray[Any]] = [] 243 | for row in rows: 244 | pbar.update() 245 | contig, start, end = row[:3] 246 | values = np.empty(end - start, dtype=self.DTYPE) 247 | pad_left = max(-start, 0) 248 | pad_right = max(end - contig_lengths[contig], 0) 249 | pad_right_idx = end - start - pad_right 250 | values[:pad_left] = 0 251 | values[pad_right_idx:] = 0 252 | _values = cast(NDArray, f.values(contig, start, end, numpy=True)) 253 | np.nan_to_num(_values, copy=False) 254 | values[pad_left:pad_right_idx] = _values 255 | unspliced.append(values) 256 | yield np.concatenate(unspliced) 257 | 258 | def _read_bigwig_fixed_length( 259 | self, 260 | coverage: zarr.Array, 261 | bigwig: PathType, 262 | bed: pl.DataFrame, 263 | batch_size: int, 264 | sample_idx: int, 265 | n_threads: int, 266 | fixed_length: int, 267 | splice: bool, 268 | ): 269 | blosc.set_nthreads(n_threads) 270 | to_rc = cast(NDArray[np.bool_], bed["strand"].eq_missing("-").to_numpy()) 271 | 272 | batch = np.empty((batch_size, fixed_length), dtype=self.DTYPE) 273 | 274 | with pyBigWig.open(str(bigwig)) as f: 275 | if splice: 276 | reader = self._spliced_reader 277 | else: 278 | reader = self._reader 279 | contig_lengths 
= f.chroms() 280 | row_batcher = _get_row_batcher(reader(bed, f, contig_lengths), batch_size) 281 | for is_last_row, is_last_in_batch, values, idx, start in row_batcher: 282 | batch[idx] = values 283 | if is_last_row or is_last_in_batch: 284 | _batch = batch[: idx + 1] 285 | to_rc_mask = to_rc[start : start + idx + 1] 286 | _batch[to_rc_mask] = _batch[to_rc_mask, ::-1] 287 | coverage[start : start + idx + 1, sample_idx] = _batch 288 | 289 | def _read_bigwig_variable_length( 290 | self, 291 | coverage: zarr.Array, 292 | bigwig: PathType, 293 | bed: pl.DataFrame, 294 | batch_size: int, 295 | sample_idx: int, 296 | n_threads: int, 297 | splice: bool, 298 | ): 299 | blosc.set_nthreads(n_threads) 300 | to_rc = cast(NDArray[np.bool_], bed["strand"].eq_missing("-").to_numpy()) 301 | 302 | batch = np.empty(batch_size, object) 303 | 304 | with pyBigWig.open(str(bigwig)) as f: 305 | if splice: 306 | reader = self._spliced_reader 307 | else: 308 | reader = self._reader 309 | contig_lengths = f.chroms() 310 | row_batcher = _get_row_batcher(reader(bed, f, contig_lengths), batch_size) 311 | for is_last_row, is_last_in_batch, values, idx, start in row_batcher: 312 | if to_rc[idx]: 313 | batch[idx] = values[::-1] 314 | else: 315 | batch[idx] = values 316 | if is_last_in_batch or is_last_row: 317 | coverage[start : start + idx + 1, sample_idx] = batch[: idx + 1] 318 | -------------------------------------------------------------------------------- /seqdata/_io/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from itertools import count, cycle 3 | from subprocess import CalledProcessError, run 4 | from textwrap import dedent 5 | from typing import Generator, Tuple 6 | 7 | import numpy as np 8 | import polars as pl 9 | import zarr 10 | from more_itertools import mark_ends, repeat_each 11 | from numcodecs import VLenArray, VLenBytes, VLenUTF8 12 | 13 | from seqdata.types import T 14 | 15 | 16 | def _df_to_xr_zarr(df: pl.DataFrame, root: zarr.Group, dim: str, **kwargs): 17 | for series in df: 18 | data = series.to_numpy() 19 | if data.dtype.type == np.object_: 20 | if isinstance(data[0], np.ndarray): 21 | object_codec = VLenArray(data[0].dtype) 22 | elif isinstance(data[0], str): 23 | object_codec = VLenUTF8() 24 | elif isinstance(data[0], bytes): 25 | object_codec = VLenBytes() 26 | else: 27 | raise ValueError("Got column in dataframe that isn't serializable.") 28 | else: 29 | object_codec = None 30 | arr = root.array(series.name, data, object_codec=object_codec, **kwargs) 31 | arr.attrs["_ARRAY_DIMENSIONS"] = [dim] 32 | 33 | 34 | def _get_row_batcher( 35 | reader: Generator[T, None, None], batch_size: int 36 | ) -> Generator[Tuple[bool, bool, T, int, int], None, None]: 37 | batch_idxs = cycle(mark_ends(range(batch_size))) 38 | start_idxs = repeat_each(count(0, batch_size), batch_size) 39 | for row_info, batch_info, start_idx in zip( 40 | mark_ends(reader), batch_idxs, start_idxs 41 | ): 42 | first_row, last_row, row = row_info 43 | first_in_batch, last_in_batch, batch_idx = batch_info 44 | yield last_row, last_in_batch, row, batch_idx, start_idx 45 | 46 | 47 | def run_shell(cmd: str, logger: logging.Logger, **kwargs): 48 | try: 49 | status = run(dedent(cmd).strip(), check=True, shell=True, **kwargs) 50 | except CalledProcessError as e: 51 | logger.error(e.stdout) 52 | logger.error(e.stderr) 53 | raise e 54 | return status 55 | -------------------------------------------------------------------------------- /seqdata/datasets.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing_extensions import Literal 3 | 4 | from pathlib import Path 5 | import pooch 6 | import tarfile 7 | 8 | 9 | # This is a global variable used to store all datasets. It is initialized only once 10 | # when the data is requested. 11 | _datasets = None 12 | 13 | def datasets(): 14 | global _datasets 15 | if _datasets is None: 16 | _datasets = pooch.create( 17 | path=pooch.os_cache("seqdata"), 18 | base_url="https://zenodo.org/records/11415225/files/", 19 | env="SEQDATA_DATA_DIR", # The user can overwrite the storage path by setting this environment variable. 20 | registry={ 21 | 22 | # K562-HepG2-SKNSH MPRA 23 | 24 | # K562 ATAC-seq 25 | "K562_ATAC-seq.zarr.tar.gz": "sha256:da601746f933a623fc0465c172f0338425690d480ae4aa7c6d645f02d32a7504", 26 | "signal.bw": "sha256:df4b2af6ad7612207dcb4f6acce41e8f731b08d2d84c00263f280325c9be8f53", 27 | 28 | # K562 CTCF ChIP-seq 29 | "K562_CTCF-ChIP-seq.zarr.tar.gz": "sha256:c0098fce7464459e88c8b1ef30cad84c1931b818273c07f80005eeb9037e8276", 30 | "plus.bw": "sha256:005ba907136c477754c287113b3479a68121c47368455fef9f19f593e2623462", 31 | "minus.bw": "sha256:2ff74b44bea80b1c854a265a1f759a3e1aa7baec10ba20139e39d78d7ea5e1ed", 32 | 33 | # BICCN mouse cortex snATAC-seq 34 | 35 | }, 36 | urls={ 37 | 38 | # K562 ATAC-seq 39 | "K562_ATAC-seq.zarr": "https://zenodo.org/records/11415225/files/K562_ATAC-seq.zarr", 40 | "signal.bw": "https://zenodo.org/records/11415225/files/signal.bw", 41 | 42 | # K562 CTCF ChIP-seq 43 | "K562_CTCF-ChIP-seq.zarr": "https://zenodo.org/records/11415225/files/K562_CTCF-ChIP-seq.zarr", 44 | "plus.bw": "https://zenodo.org/records/11415225/files/plus.bw", 45 | "minus.bw": "https://zenodo.org/records/11415225/files/minus.bw", 46 | 47 | }, 48 | ) 49 | return _datasets 50 | 51 | 52 | def K562_ATAC_seq(type: Literal["seqdata", "bigwig"]="seqdata") -> Path: 53 | if type == "seqdata": 54 | path = Path(datasets().fetch("K562_ATAC-seq.zarr.tar.gz")) 55 | with tarfile.open(path, "r:gz") as tar: 56 | tar.extractall(path.parent) 57 | path.unlink() # Remove the tar.gz file after extraction 58 | extracted_path = path.parent / "K562_ATAC-seq.zarr" 59 | return extracted_path 60 | elif type == "bigwig": 61 | return Path(datasets().fetch("signal.bw")) 62 | 63 | 64 | def K562_CTCF_ChIP_seq(type: Literal["seqdata", "bigwig"]="seqdata") -> Path: 65 | if type == "seqdata": 66 | path = Path(datasets().fetch("K562_CTCF-ChIP-seq.zarr.tar.gz")) 67 | with tarfile.open(path, "r:gz") as tar: 68 | tar.extractall(path.parent) 69 | path.unlink() # Remove the tar.gz file after extraction 70 | extracted_path = path.parent / "K562_CTCF-ChIP-seq.zarr" 71 | return extracted_path 72 | elif type == "bigwig": 73 | return Path(datasets().fetch("plus.bw")), Path(datasets().fetch("minus.bw")) 74 | 75 | -------------------------------------------------------------------------------- /seqdata/torch.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from itertools import accumulate, chain, repeat 3 | from typing import ( 4 | Callable, 5 | Dict, 6 | Hashable, 7 | Iterable, 8 | List, 9 | Literal, 10 | Optional, 11 | Sequence, 12 | Set, 13 | Tuple, 14 | Union, 15 | cast, 16 | overload, 17 | ) 18 | 19 | import numpy as np 20 | import torch 21 | import xarray as xr 22 | from numpy.typing import NDArray 23 | from torch.utils.data import DataLoader, Sampler 24 | 25 | 26 | def _cartesian_product(arrays: 
Sequence[NDArray]) -> NDArray: 27 | """Get the cartesian product of multiple arrays such that each entry corresponds to 28 | a unique combination of the input arrays' values. 29 | """ 30 | # https://stackoverflow.com/a/49445693 31 | la = len(arrays) 32 | shape = *map(len, arrays), la 33 | dtype = np.result_type(*arrays) 34 | arr = np.empty(shape, dtype=dtype) 35 | arrs = (*accumulate(chain((arr,), repeat(0, la - 1)), np.ndarray.__getitem__),) 36 | idx = slice(None), *repeat(None, la - 1) 37 | for i in range(la - 1, 0, -1): 38 | arrs[i][..., i] = arrays[i][idx[: la - i]] 39 | arrs[i - 1][1:] = arrs[i] 40 | arr[..., 0] = arrays[0][idx] 41 | return arr.reshape(-1, la) 42 | 43 | 44 | @overload 45 | def get_torch_dataloader( 46 | sdata: xr.Dataset, 47 | sample_dims: Union[str, List[str]], 48 | variables: Union[str, List[str]], 49 | transform: Optional[Callable[[Dict[str, NDArray]], Dict[str, NDArray]]] = None, 50 | dtypes: Union[torch.dtype, Dict[str, torch.dtype]] = torch.float32, 51 | *, 52 | return_tuples: Literal[False], 53 | batch_size: Optional[int] = 1, 54 | shuffle: bool = False, 55 | sampler: Optional[Union["Sampler", Iterable]] = None, 56 | batch_sampler: Optional[Union["Sampler[List]", Iterable[List]]] = None, 57 | num_workers: int = 0, 58 | pin_memory: bool = False, 59 | drop_last: bool = False, 60 | timeout: float = 0, 61 | worker_init_fn=None, 62 | multiprocessing_context=None, 63 | generator=None, 64 | prefetch_factor: Optional[int] = None, 65 | persistent_workers: bool = False, 66 | ) -> "DataLoader[Dict[str, torch.Tensor]]": ... 67 | 68 | 69 | @overload 70 | def get_torch_dataloader( 71 | sdata: xr.Dataset, 72 | sample_dims: Union[str, List[str]], 73 | variables: Union[str, List[str]], 74 | transform: Optional[Callable[[Dict[str, NDArray]], Dict[str, NDArray]]] = None, 75 | dtypes: Union[torch.dtype, Dict[str, torch.dtype]] = torch.float32, 76 | *, 77 | return_tuples: Literal[True], 78 | batch_size: Optional[int] = 1, 79 | shuffle: bool = False, 80 | sampler: Optional[Union["Sampler", Iterable]] = None, 81 | batch_sampler: Optional[Union["Sampler[List]", Iterable[List]]] = None, 82 | num_workers: int = 0, 83 | pin_memory: bool = False, 84 | drop_last: bool = False, 85 | timeout: float = 0, 86 | worker_init_fn=None, 87 | multiprocessing_context=None, 88 | generator=None, 89 | prefetch_factor: Optional[int] = None, 90 | persistent_workers: bool = False, 91 | ) -> "DataLoader[Tuple[torch.Tensor, ...]]": ... 92 | 93 | 94 | @overload 95 | def get_torch_dataloader( 96 | sdata: xr.Dataset, 97 | sample_dims: Union[str, List[str]], 98 | variables: Union[str, List[str]], 99 | transform: Optional[Callable[[Dict[str, NDArray]], Dict[str, NDArray]]] = None, 100 | dtypes: Union[torch.dtype, Dict[str, torch.dtype]] = torch.float32, 101 | *, 102 | return_tuples=False, 103 | batch_size: Optional[int] = 1, 104 | shuffle: bool = False, 105 | sampler: Optional[Union["Sampler", Iterable]] = None, 106 | batch_sampler: Optional[Union["Sampler[List]", Iterable[List]]] = None, 107 | num_workers: int = 0, 108 | pin_memory: bool = False, 109 | drop_last: bool = False, 110 | timeout: float = 0, 111 | worker_init_fn=None, 112 | multiprocessing_context=None, 113 | generator=None, 114 | prefetch_factor: Optional[int] = None, 115 | persistent_workers: bool = False, 116 | ) -> "DataLoader[Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, ...]]]": ... 
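# (Editor's note, not part of the original source.) The three @overload stubs above
# exist only for static typing: with return_tuples=False the loader is typed as
# yielding Dict[str, torch.Tensor] batches, with return_tuples=True it is typed as
# yielding Tuple[torch.Tensor, ...], and the final stub covers a non-literal bool.
# The runtime behavior is implemented once, in the definition that follows.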
117 | 
118 | 
119 | def get_torch_dataloader(
120 |     sdata: xr.Dataset,
121 |     sample_dims: Union[str, List[str]],
122 |     variables: Union[str, List[str]],
123 |     transform: Optional[Callable[[Dict[str, NDArray]], Dict[str, NDArray]]] = None,
124 |     dtypes: Union[torch.dtype, Dict[str, torch.dtype]] = torch.float32,
125 |     *,
126 |     return_tuples=False,
127 |     batch_size: Optional[int] = 1,
128 |     shuffle=False,
129 |     sampler: Optional[Union["Sampler", Iterable]] = None,
130 |     batch_sampler: Optional[Union["Sampler[List]", Iterable[List]]] = None,
131 |     num_workers=0,
132 |     pin_memory=False,
133 |     drop_last=False,
134 |     timeout=0.0,
135 |     worker_init_fn=None,
136 |     multiprocessing_context=None,
137 |     generator=None,
138 |     prefetch_factor: Optional[int] = None,
139 |     persistent_workers: bool = False,
140 | ) -> "DataLoader[Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, ...]]]":
141 |     """Get a PyTorch DataLoader for this SeqData.
142 | 
143 |     Parameters
144 |     ----------
145 |     sample_dims : str or list[str]
146 |         Sample dimensions that will be indexed over when fetching batches. For
147 |         example, if `sample_dims = ['_sequence', 'sample']` for a variable with
148 |         dimensions `['_sequence', 'length', 'sample']` then a batch of data will
149 |         have dimensions `['batch', 'length']`.
150 |     variables : list[str]
151 |         Which variables to sample from.
152 |     transform : Callable[[Dict[str, NDArray]], Dict[str, NDArray]], optional
153 |         Function applied to each batch after it is loaded into memory. It receives a
154 |         dictionary of NumPy arrays keyed by variable name and must return a dictionary
155 |         with the same keys. No transform is applied by default.
156 |     dtypes : torch.dtype, Dict[str, torch.dtype]
157 |         Data type to convert each variable to after applying all transforms.
158 | 
159 |     For other parameters, see the documentation for [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)
160 | 
161 |     Returns
162 |     -------
163 |     DataLoader that returns dictionaries or tuples of tensors.
164 | """ 165 | 166 | if isinstance(sample_dims, str): 167 | sample_dims = [sample_dims] 168 | if isinstance(variables, str): 169 | variables = [variables] 170 | 171 | variables_not_in_ds = set(variables) - set(sdata.data_vars.keys()) 172 | if variables_not_in_ds: 173 | raise ValueError( 174 | f"Got variables that are not in the SeqData: {variables_not_in_ds}" 175 | ) 176 | 177 | if isinstance(dtypes, torch.dtype): 178 | dtypes = {k: dtypes for k in variables} 179 | dim_sizes = [sdata.sizes[d] for d in sample_dims] 180 | ravel_indices = cast( 181 | NDArray[np.intp], 182 | np.arange(np.prod(dim_sizes, dtype=int), dtype=np.intp), # type: ignore 183 | ) 184 | data: Dict[Hashable, NDArray] = { 185 | var: arr.to_numpy() 186 | for var, arr in sdata[variables].transpose(*sample_dims, ...).items() 187 | } 188 | 189 | def collate_fn(indices: List[np.intp]): 190 | # improve performance by sorted indexing 191 | # note: assumes order within batch is irrelevant (true for ML) 192 | indices.sort() 193 | _idxs = np.unravel_index(indices, dim_sizes) 194 | 195 | # select data 196 | out = { 197 | var: dat[ 198 | tuple(_idxs[i] for i, d in enumerate(sample_dims) if d in arr.dims) 199 | ] 200 | for dat, (var, arr) in zip( 201 | data.values(), sdata[variables].data_vars.items() 202 | ) 203 | } 204 | out = cast(Dict[str, NDArray], out) 205 | 206 | # apply transforms 207 | if transform is not None: 208 | out = transform(out) 209 | 210 | # convert to torch 211 | for k in out: 212 | out[k] = torch.as_tensor(out[k], dtype=dtypes[k]) # type: ignore 213 | out = cast(Dict[str, torch.Tensor], out) 214 | 215 | # convert to a tuple if desired 216 | if return_tuples: 217 | out = tuple(out.values()) 218 | 219 | return out 220 | 221 | return DataLoader( 222 | ravel_indices, # type: ignore 223 | batch_size=batch_size, 224 | shuffle=shuffle, 225 | sampler=sampler, 226 | batch_sampler=batch_sampler, 227 | num_workers=num_workers, 228 | collate_fn=collate_fn, 229 | pin_memory=pin_memory, 230 | drop_last=drop_last, 231 | timeout=timeout, 232 | worker_init_fn=worker_init_fn, 233 | multiprocessing_context=multiprocessing_context, 234 | generator=generator, 235 | prefetch_factor=prefetch_factor, 236 | persistent_workers=persistent_workers, 237 | ) 238 | 239 | 240 | # TODO: allow in-memory sdata 241 | # TODO: add parameters for `sampler`, `pin_memory`, `drop_last` 242 | class XArrayDataLoader: 243 | def __init__( 244 | self, 245 | sdata: xr.Dataset, 246 | sample_dims: Union[str, List[str]], 247 | variables: Union[str, List[str]], 248 | transform: Optional[Callable[[Dict[str, NDArray]], Dict[str, NDArray]]] = None, 249 | dtypes: Union[torch.dtype, Dict[str, torch.dtype]] = torch.float32, 250 | batch_size: int = 1, 251 | prefetch_factor: int = 2, 252 | shuffle: bool = False, 253 | seed: Optional[int] = None, 254 | return_tuples: bool = False, 255 | ) -> None: 256 | """Get an XArray DataLoader that supports substantially faster out-of-core 257 | dataloading from chunked storage formats than a PyTorch DataLoader. Note the 258 | absence of concurrency parameters. This is intentional: concurrent I/O is 259 | enabled by instantiating a `dask.distributed.Client` before iteration. 260 | 261 | Parameters 262 | ---------- 263 | sdata : xr.Dataset 264 | sample_dims : Union[str, List[str]] 265 | Dimensions to sample over (i.e. what dimensions would you index over to get 266 | a single instance?) 267 | variables : Union[str, List[str]] 268 | What variables to load data from. 
269 | transform : Optional[Callable[[Dict[str, NDArray]], Dict[str, NDArray]]] 270 | A function to transform batches after loading them into memory. Should take 271 | a dictionary of numpy arrays, each corresponding to a `variable`, transform 272 | them, and return the result as a dictionary with the same keys. By default 273 | no transforms will be applied. 274 | dtypes : Union[torch.dtype, Dict[str, torch.dtype]], optional 275 | What dtype to convert each batch array to. Either a single dtype to convert 276 | all variables or a dictionary mapping variables to dtypes. By default 277 | `torch.float32`. 278 | batch_size : int, optional 279 | How many instances per batch, by default 1 280 | prefetch_factor : int, optional 281 | What multiple of chunks to prefetch, by default 2. Tune this and the Zarr 282 | chunk sizes appropriately to control peak memory usage and balance speed and 283 | memory usage. A higher prefetch factor improves speed but uses more memory. 284 | shuffle : bool, optional 285 | Whether to randomly shuffle the dataset on each epoch, by default False 286 | seed : Optional[int], optional 287 | Seed for random shuffling, by default None 288 | return_tuples : bool, optional 289 | Whether to yield tuples (or dictionaries). By default False. 290 | 291 | Raises 292 | ------ 293 | ValueError 294 | When `variables` specifies variables that aren't in the Dataset. 295 | ValueError 296 | When variables have different chunk sizes in any of the sample dimensions. 297 | 298 | Notes 299 | ----- 300 | **Data flow** 301 | 302 | 1. Load contiguous chunks of data from the dataset into buffers that are larger 303 | than the batch size. 304 | 2. Yield batches from the buffer until the buffer is empty, then repeat. 305 | 306 | **Random shuffling** 307 | 308 | We implement random shuffling by prefetching random chunks and then randomly 309 | sampling data from within those chunks. It is possible (although unlikely) that 310 | the data may have structure that isn't randomized due to the lack of fully 311 | random sampling. 312 | """ 313 | if isinstance(sample_dims, str): 314 | sample_dims = [sample_dims] 315 | if isinstance(variables, str): 316 | variables = [variables] 317 | 318 | variables_not_in_ds = set(variables) - set(sdata.data_vars.keys()) 319 | if variables_not_in_ds: 320 | raise ValueError( 321 | f"Got variables that are not in the dataset: {variables_not_in_ds}" 322 | ) 323 | 324 | if isinstance(dtypes, torch.dtype): 325 | self.dtypes = {k: dtypes for k in variables} 326 | else: 327 | self.dtypes = dtypes 328 | 329 | self.sdata = sdata 330 | self.variables = variables 331 | # mapping from dimension name to chunksize 332 | self.chunksizes = self.get_chunksizes(sdata, sample_dims, variables) 333 | self.sample_dims = sample_dims 334 | 335 | self.instances_per_chunk = np.prod(list(self.chunksizes.values()), dtype=int) 336 | chunks_per_batch = -(-batch_size // self.instances_per_chunk) 337 | self.n_prefetch_chunks = prefetch_factor * chunks_per_batch 338 | self.n_instances = np.prod([sdata.sizes[d] for d in sample_dims], dtype=int) 339 | if batch_size > self.n_instances: 340 | warnings.warn( 341 | f"""Batch size {batch_size} is larger than the number of instances in 342 | the dataset {self.n_instances}. 
Reducing batch size to maximum number of 343 | instances.""" 344 | ) 345 | self.batch_size = self.n_instances 346 | else: 347 | self.batch_size = batch_size 348 | self.max_batches = -(-self.n_instances // self.batch_size) 349 | 350 | self.rng = np.random.default_rng(seed) 351 | self.shuffle = shuffle 352 | self.transform = transform 353 | self.return_tuples = return_tuples 354 | 355 | chunk_start_idx: Dict[str, NDArray[np.int64]] = {} 356 | for dim in self.chunksizes: 357 | length = sdata.sizes[dim] 358 | chunksize = self.chunksizes[dim] 359 | chunk_start_idx[dim] = np.arange(0, length, chunksize, dtype=np.int64) 360 | self.chunk_idxs = _cartesian_product(list(chunk_start_idx.values())) 361 | 362 | def get_chunksizes( 363 | self, sdata: xr.Dataset, sample_dims: List[str], variables: List[str] 364 | ): 365 | chunksizes: Dict[str, Set[int]] = {} 366 | for dim in sample_dims: 367 | dim_chunk_sizes = set() 368 | for v in sdata[variables].data_vars.values(): 369 | if dim in v.dims: 370 | dim_chunk_sizes.add(v.data.chunksize[v.get_axis_num(dim)]) 371 | chunksizes[dim] = dim_chunk_sizes 372 | discrepant_chunk_sizes = {k: v for k, v in chunksizes.items() if len(v) > 1} 373 | if len(discrepant_chunk_sizes) > 1: 374 | raise ValueError( 375 | f"""Variables have different chunksizes in the sample dimensions.\n 376 | Dimensions with discrepant chunksizes: {list(discrepant_chunk_sizes.keys())}.\n 377 | Rechunk the variables in the sample dimensions so they are the same. 378 | """ 379 | ) 380 | return {k: v.pop() for k, v in chunksizes.items()} 381 | 382 | def __len__(self): 383 | return self.max_batches 384 | 385 | def __iter__(self): 386 | # which slice of chunks is going into the buffer 387 | self.chunk_slice = slice(0, self.n_prefetch_chunks) 388 | # which slice of the buffer is going into the batch 389 | self.buffer_slice = slice(0, self.batch_size) 390 | # which slice of the batch is getting pulled & processed 391 | # i.e. batch[self.batch_slice] = self.buffer[self.buffer_slice] 392 | self.batch_slice = slice(0, self.batch_size) 393 | self.current_batch = 0 394 | if self.shuffle: 395 | self.chunk_idxs = self.rng.permutation(self.chunk_idxs, axis=0) 396 | self._flush_and_fill_buffers() 397 | return self 398 | 399 | def _flush_and_fill_buffers(self): 400 | """Flush buffers and fill them with new data.""" 401 | # Each buffer in buffers will have shape (self.buffer_size, ...) 402 | self.buffers: Dict[str, NDArray] = {} 403 | shuffler = None 404 | # (n_chunks, n_dim) 405 | chunk_idx = self.chunk_idxs[self.chunk_slice] 406 | self.chunk_slice = slice( 407 | self.chunk_slice.start, self.chunk_slice.start + self.n_prefetch_chunks 408 | ) 409 | for var in self.variables: 410 | var_dims = [d for d in self.sdata[var].dims if d in self.sample_dims] 411 | buffer = [] 412 | for chunk in chunk_idx: 413 | selector = { 414 | d: slice(start, start + self.chunksizes[d]) 415 | for start, d in zip(chunk, self.sample_dims) 416 | } 417 | buffer.append( 418 | self.sdata[var] 419 | .isel(selector, missing_dims="ignore") 420 | .stack(batch=var_dims) 421 | .transpose("batch", ...) 
422 | .to_numpy() 423 | ) 424 | buffer = np.concatenate(buffer) 425 | if shuffler is None: 426 | shuffler = self.rng.permutation(len(buffer)) 427 | if self.shuffle: 428 | buffer = buffer[shuffler] 429 | self.buffers[var] = buffer 430 | 431 | def __next__(self): 432 | if self.current_batch == self.max_batches: 433 | raise StopIteration 434 | 435 | # init empty batch arrays 436 | batch: Dict[str, NDArray] = { 437 | k: np.empty_like(v.data, shape=(self.batch_size, *v.shape[1:])) 438 | for k, v in self.buffers.items() 439 | } 440 | 441 | overshoot = self.buffer_slice.stop - len(self.buffers[self.variables[0]]) 442 | 443 | # buffers don't have enough data to fill the batch 444 | if overshoot > 0: 445 | # grab what they do have 446 | self.batch_slice = slice(0, self.batch_size - overshoot) 447 | for var, buffer in self.buffers.items(): 448 | batch[var][self.batch_slice] = buffer[self.buffer_slice] 449 | 450 | # fetch more data 451 | self._flush_and_fill_buffers() 452 | 453 | # setup to fill the rest of the batch 454 | self.buffer_slice = slice(0, overshoot) 455 | self.batch_slice = slice(self.batch_slice.stop, self.batch_size) 456 | 457 | for var, buffer in self.buffers.items(): 458 | batch[var][self.batch_slice] = buffer[self.buffer_slice] 459 | 460 | # setup for next batch 461 | self.buffer_slice = slice( 462 | self.buffer_slice.stop, self.buffer_slice.stop + self.batch_size 463 | ) 464 | self.batch_slice = slice(0, self.batch_size) 465 | self.current_batch += 1 466 | 467 | # apply transforms, if any 468 | if self.transform is not None: 469 | batch = self.transform(batch) 470 | 471 | out = self._apply_dtypes(batch) 472 | 473 | if self.return_tuples: 474 | return tuple(out.values()) 475 | 476 | return out 477 | 478 | def _apply_dtypes(self, batch: Dict[str, NDArray]): 479 | out = { 480 | k: torch.as_tensor(v, dtype=dtype) 481 | for (k, dtype), v in zip(self.dtypes.items(), batch.values()) 482 | } 483 | return out 484 | -------------------------------------------------------------------------------- /seqdata/types.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from pathlib import Path 3 | from typing import TYPE_CHECKING, List, Literal, Optional, TypeVar, Union 4 | 5 | import numpy as np 6 | 7 | if TYPE_CHECKING: 8 | import polars as pl 9 | 10 | PathType = Union[str, Path] 11 | ListPathType = Union[List[str], List[Path]] 12 | T = TypeVar("T") 13 | DTYPE = TypeVar("DTYPE", bound=np.generic, covariant=True) 14 | 15 | 16 | class FlatReader(ABC): 17 | name: str 18 | 19 | @abstractmethod 20 | def _write( 21 | self, 22 | out: PathType, 23 | fixed_length: bool, 24 | sequence_dim: str, 25 | length_dim: Optional[str] = None, 26 | overwrite=False, 27 | ) -> None: 28 | """Write data from the reader to a SeqData Zarr on disk. 29 | 30 | Parameters 31 | ---------- 32 | out : str, Path 33 | Output file, should be a `.zarr` file. 34 | fixed_length : bool 35 | `int`: length of sequences. `False`: write variable length sequences. 36 | sequence_dim : str 37 | Name of sequence dimension. 38 | length_dim : str 39 | Name of length dimension. 40 | overwrite : bool, default False 41 | Whether to overwrite existing output file. 42 | """ 43 | ... 
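`FlatReader` is the interface that flat-file readers (e.g. the table and flat FASTA readers) implement: `_write` streams the reader's data into the target Zarr store using the dimension names it is given. The following is only a rough sketch of a custom reader under stated assumptions (in-memory, equal-length sequences; `_ARRAY_DIMENSIONS` is the Zarr attribute xarray uses to recover dimension names); the shipped readers in `seqdata._io.readers` additionally handle batching, variable lengths, and compression:

```python
from typing import List, Optional

import numpy as np
import zarr

from seqdata.types import FlatReader, PathType


class InMemorySequences(FlatReader):
    """Toy FlatReader that writes a list of equal-length sequences (sketch only)."""

    def __init__(self, name: str, seqs: List[str]) -> None:
        self.name = name
        self.seqs = seqs

    def _write(
        self,
        out: PathType,
        fixed_length: bool,
        sequence_dim: str,
        length_dim: Optional[str] = None,
        overwrite=False,
    ) -> None:
        # Assumes fixed_length=True; a real reader must also support ragged output.
        seqs = np.array([list(s) for s in self.seqs], dtype="S1")  # (sequence, length)
        root = zarr.open_group(str(out), mode="a")
        arr = root.array(self.name, seqs, overwrite=overwrite)
        # Record dimension names so xarray can open the store as labeled arrays.
        arr.attrs["_ARRAY_DIMENSIONS"] = [sequence_dim, length_dim]
```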
44 | 45 | 46 | class RegionReader(ABC): 47 | name: str 48 | 49 | @abstractmethod 50 | def _write( 51 | self, 52 | out: PathType, 53 | bed: "pl.DataFrame", 54 | fixed_length: Union[int, Literal[False]], 55 | sequence_dim: str, 56 | length_dim: Optional[str] = None, 57 | splice=False, 58 | overwrite=False, 59 | ) -> None: 60 | """Write data in regions specified from a BED file. 61 | 62 | Parameters 63 | ---------- 64 | out : str, Path 65 | Output file, should be a `.zarr` file. 66 | bed : pl.DataFrame 67 | DataFrame corresponding to a BED file. 68 | fixed_length : int, bool 69 | `int`: length of sequences. `False`: write variable length sequences. 70 | sequence_dim : str 71 | Name of sequence dimension. 72 | length_dim : str, optional 73 | Name of length dimension. Ignored if fixed_length = False. 74 | splice : bool, default False 75 | Whether to splice together regions with the same `name` (i.e. the 4th BED 76 | column). For example, to splice together exons from transcripts or coding 77 | sequences of proteins. 78 | overwrite : bool, default False 79 | Whether to overwrite existing output file. 80 | """ 81 | ... 82 | -------------------------------------------------------------------------------- /seqdata/xarray/seqdata.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | from typing import ( 4 | Any, 5 | Dict, 6 | Hashable, 7 | Iterable, 8 | Literal, 9 | Mapping, 10 | MutableMapping, 11 | Optional, 12 | Tuple, 13 | Union, 14 | cast, 15 | ) 16 | 17 | import numpy as np 18 | import pandas as pd 19 | import polars as pl 20 | import xarray as xr 21 | import zarr 22 | from numcodecs import Blosc 23 | 24 | from seqdata._io.bed_ops import ( 25 | _bed_to_zarr, 26 | _expand_regions, 27 | _set_uniform_length_around_center, 28 | read_bedlike, 29 | ) 30 | from seqdata.types import FlatReader, PathType, RegionReader 31 | 32 | from .utils import _filter_by_exact_dims, _filter_layers, _filter_uns 33 | 34 | 35 | def open_zarr( 36 | store: PathType, 37 | group: Optional[str] = None, 38 | synchronizer=None, 39 | chunks: Optional[ 40 | Union[Literal["auto"], int, Mapping[str, int], Tuple[int, ...]] 41 | ] = "auto", 42 | decode_cf=True, 43 | mask_and_scale=False, 44 | decode_times=True, 45 | concat_characters=False, 46 | decode_coords=True, 47 | drop_variables: Optional[Union[str, Iterable[str]]] = None, 48 | consolidated: Optional[bool] = None, 49 | overwrite_encoded_chunks=False, 50 | chunk_store: Optional[Union[MutableMapping, PathType]] = None, 51 | storage_options: Optional[Dict[str, str]] = None, 52 | decode_timedelta: Optional[bool] = None, 53 | use_cftime: Optional[bool] = None, 54 | zarr_version: Optional[int] = None, 55 | **kwargs, 56 | ): 57 | """Open a SeqData object from disk. 58 | 59 | Parameters 60 | ---------- 61 | store : str, Path 62 | Path to the SeqData object. 
63 | group : str, optional 64 | Name of the group to open, by default None 65 | synchronizer : None, optional 66 | Synchronizer to use, by default None 67 | chunks : {None, True, False, int, dict, tuple}, optional 68 | Chunking scheme to use, by default "auto" 69 | decode_cf : bool, optional 70 | Whether to decode CF conventions, by default True 71 | mask_and_scale : bool, optional 72 | Whether to mask and scale data, by default False 73 | decode_times : bool, optional 74 | Whether to decode times, by default True 75 | concat_characters : bool, optional 76 | Whether to concatenate characters, by default False 77 | decode_coords : bool, optional 78 | Whether to decode coordinates, by default True 79 | drop_variables : {None, str, iterable}, optional 80 | Variables to drop, by default None 81 | consolidated : bool, optional 82 | Whether to consolidate metadata, by default None 83 | overwrite_encoded_chunks : bool, optional 84 | Whether to overwrite encoded chunks, by default False 85 | chunk_store : {None, MutableMapping, str, Path}, optional 86 | Chunk store to use, by default None 87 | storage_options : dict, optional 88 | Storage options to use, by default None 89 | decode_timedelta : bool, optional 90 | Whether to decode timedeltas, by default None 91 | use_cftime : bool, optional 92 | Whether to use cftime, by default None 93 | zarr_version : int, optional 94 | Zarr version to use, by default None 95 | 96 | Returns 97 | ------- 98 | xr.Dataset 99 | SeqData object 100 | """ 101 | ds = xr.open_zarr( 102 | store=store, 103 | group=group, 104 | synchronizer=synchronizer, 105 | chunks=chunks, # type: ignore 106 | decode_cf=decode_cf, 107 | mask_and_scale=mask_and_scale, 108 | decode_times=decode_times, 109 | concat_characters=concat_characters, 110 | decode_coords=decode_coords, 111 | drop_variables=drop_variables, 112 | consolidated=consolidated, 113 | overwrite_encoded_chunks=overwrite_encoded_chunks, 114 | chunk_store=chunk_store, 115 | storage_options=storage_options, 116 | decode_timedelta=decode_timedelta, 117 | use_cftime=use_cftime, 118 | zarr_version=zarr_version, 119 | **kwargs, 120 | ) 121 | return ds 122 | 123 | 124 | def to_zarr( 125 | sdata: Union[xr.DataArray, xr.Dataset], 126 | store: PathType, 127 | chunk_store: Optional[Union[MutableMapping, PathType]] = None, 128 | mode: Optional[Literal["w", "w-", "a", "r+"]] = None, 129 | synchronizer: Optional[Any] = None, 130 | group: Optional[str] = None, 131 | encoding: Optional[Dict] = None, 132 | compute=True, 133 | consolidated: Optional[bool] = None, 134 | append_dim: Optional[Hashable] = None, 135 | region: Optional[Dict] = None, 136 | safe_chunks=True, 137 | storage_options: Optional[Dict] = None, 138 | zarr_version: Optional[int] = None, 139 | ): 140 | """Write a xarray object to disk as a Zarr store. 141 | 142 | Makes use of the `to_zarr` method of xarray objects, but modifies 143 | the encoding for cases where the chunking is not uniform. 144 | 145 | Parameters 146 | ---------- 147 | sdata : xr.Dataset 148 | SeqData object to write to disk. 149 | store : str, Path 150 | Path to the SeqData object. 
151 | chunk_store : {None, MutableMapping, str, Path}, optional 152 | Chunk store to use, by default None 153 | mode : {None, "w", "w-", "a", "r+"}, optional 154 | Mode to use, by default None 155 | synchronizer : None, optional 156 | Synchronizer to use, by default None 157 | group : str, optional 158 | Name of the group to open, by default None 159 | encoding : dict, optional 160 | Encoding to use, by default None 161 | compute : bool, optional 162 | Whether to compute, by default True 163 | consolidated : bool, optional 164 | Whether to consolidate metadata, by default None 165 | append_dim : {None, str}, optional 166 | Name of the append dimension, by default None 167 | region : dict, optional 168 | Region to use, by default None 169 | safe_chunks : bool, optional 170 | Whether to use safe chunks, by default True 171 | storage_options : dict, optional 172 | Storage options to use, by default None 173 | zarr_version : int, optional 174 | Zarr version to use, by default None 175 | 176 | Returns 177 | ------- 178 | None 179 | """ 180 | sdata = sdata.drop_encoding() 181 | 182 | if isinstance(sdata, xr.Dataset): 183 | for coord in sdata.coords.values(): 184 | if "_FillValue" in coord.attrs: 185 | del coord.attrs["_FillValue"] 186 | 187 | for arr in sdata.data_vars: 188 | sdata[arr] = _uniform_chunking(sdata[arr]) 189 | else: 190 | sdata = _uniform_chunking(sdata) 191 | 192 | sdata.to_zarr( 193 | store=store, 194 | chunk_store=chunk_store, 195 | mode=mode, 196 | synchronizer=synchronizer, 197 | group=group, 198 | encoding=encoding, 199 | compute=compute, # type: ignore 200 | consolidated=consolidated, 201 | append_dim=append_dim, 202 | region=region, 203 | safe_chunks=safe_chunks, 204 | storage_options=storage_options, 205 | zarr_version=zarr_version, 206 | ) 207 | 208 | 209 | def _uniform_chunking(arr: xr.DataArray): 210 | # rechunk if write requirements are broken. namely: 211 | # - all chunks except the last are the same size 212 | # - the final chunk is <= the size of the rest 213 | # Use chunk size that is: 214 | # 1. most frequent 215 | # 2. to break ties, largest 216 | if arr.chunksizes is not None: 217 | new_chunks = {} 218 | for dim, chunk in arr.chunksizes.items(): 219 | # > 1 chunk and either the last chunk is different from the rest 220 | # or the second to last chunk is larger than the last 221 | chunks, counts = np.unique(chunk, return_counts=True) 222 | chunk_size = int(chunks[counts == counts.max()].max()) 223 | new_chunks[dim] = chunk_size 224 | if new_chunks != arr.chunksizes: 225 | arr = arr.chunk(new_chunks) 226 | 227 | if "_FillValue" in arr.attrs: 228 | del arr.attrs["_FillValue"] 229 | return arr 230 | 231 | 232 | def from_flat_files( 233 | *readers: FlatReader, 234 | path: PathType, 235 | fixed_length: bool, 236 | sequence_dim: Optional[str] = None, 237 | length_dim: Optional[str] = None, 238 | overwrite=False, 239 | ) -> xr.Dataset: 240 | """Composable function to create a SeqData object from flat files. 241 | 242 | Saves a SeqData to disk and open it (without loading it into memory). 243 | TODO: Tutorials coming soon. 244 | 245 | Parameters 246 | ---------- 247 | *readers : FlatReader 248 | Readers to use to create the SeqData object. 249 | path : str, Path 250 | Path to save this SeqData to. 251 | fixed_length : bool 252 | `True`: assume the all sequences have the same length and will infer it 253 | from the first sequence. 254 | `False`: write variable length sequences. 
255 |     overwrite : bool, optional
256 |         Whether to overwrite existing arrays of the SeqData at `path`, by default False
257 | 
258 |     Returns
259 |     -------
260 |     xr.Dataset
261 |     """
262 |     sequence_dim = "_sequence" if sequence_dim is None else sequence_dim
263 |     if not fixed_length and length_dim is not None:
264 |         warnings.warn("Treating sequences as variable length, ignoring `length_dim`.")
265 |     elif fixed_length:
266 |         length_dim = "_length" if length_dim is None else length_dim
267 | 
268 |     for reader in readers:
269 |         reader._write(
270 |             out=path,
271 |             fixed_length=fixed_length,
272 |             overwrite=overwrite,
273 |             sequence_dim=sequence_dim,
274 |             length_dim=length_dim,
275 |         )
276 | 
277 |     zarr.consolidate_metadata(path)  # type: ignore
278 | 
279 |     ds = open_zarr(path)
280 |     return ds
281 | 
282 | 
283 | def from_region_files(
284 |     *readers: RegionReader,
285 |     path: PathType,
286 |     fixed_length: Union[int, bool],
287 |     bed: Union[PathType, pl.DataFrame, pd.DataFrame],
288 |     max_jitter=0,
289 |     sequence_dim: Optional[str] = None,
290 |     length_dim: Optional[str] = None,
291 |     splice=False,
292 |     overwrite=False,
293 | ) -> xr.Dataset:
294 |     """Composable function to create a SeqData object from region-based files.
295 | 
296 |     Saves a SeqData to disk and opens it (without loading it into memory).
297 |     TODO: Tutorials coming soon.
298 | 
299 |     Parameters
300 |     ----------
301 |     *readers : RegionReader
302 |         Readers to use to create the SeqData object.
303 |     path : str, Path
304 |         Path to save this SeqData to.
305 |     fixed_length : int, bool, optional
306 |         `int`: use regions of this length centered around those in the BED file.
307 | 
308 |         `True`: assume all sequences have the same length and will try to infer it
309 |         from the data.
310 | 
311 |         `False`: write variable length sequences.
312 |     bed : str, Path, pl.DataFrame, optional
313 |         BED file or DataFrame matching the BED3+ specification describing what regions
314 |         to write.
315 |     max_jitter : int, optional
316 |         How much jitter to allow for the SeqData object by writing additional
317 |         flanking sequences, by default 0
318 |     sequence_dim : str, optional
319 |         Name of sequence dimension. Defaults to "_sequence".
320 |     length_dim : str, optional
321 |         Name of length dimension. Defaults to "_length".
322 | splice : bool, optional 323 | Whether to splice together regions that have the same `name` in the BED file, by 324 | default False 325 | overwrite : bool, optional 326 | Whether to overwrite existing arrays of the SeqData at `path`, by default False 327 | 328 | Returns 329 | ------- 330 | xr.Dataset 331 | """ 332 | sequence_dim = "_sequence" if sequence_dim is None else sequence_dim 333 | if not fixed_length and length_dim is not None: 334 | warnings.warn("Treating sequences as variable length, ignoring `length_dim`.") 335 | elif fixed_length: 336 | length_dim = "_length" if length_dim is None else length_dim 337 | 338 | root = zarr.open_group(path) 339 | root.attrs["max_jitter"] = max_jitter 340 | root.attrs["sequence_dim"] = sequence_dim 341 | root.attrs["length_dim"] = length_dim 342 | 343 | if isinstance(bed, (str, Path)): 344 | _bed = read_bedlike(bed) 345 | elif isinstance(bed, pd.DataFrame): 346 | _bed = pl.from_pandas(bed) 347 | else: 348 | _bed = bed 349 | 350 | if "strand" not in _bed: 351 | _bed = _bed.with_columns(strand=pl.lit("+")) 352 | 353 | if not splice: 354 | if fixed_length is False: 355 | _bed = _expand_regions(_bed, max_jitter) 356 | else: 357 | if fixed_length is True: 358 | fixed_length = cast( 359 | int, 360 | _bed.item(0, "chromEnd") - _bed.item(0, "chromStart"), 361 | ) 362 | fixed_length += 2 * max_jitter 363 | _bed = _set_uniform_length_around_center(_bed, fixed_length) 364 | _bed_to_zarr( 365 | _bed, 366 | root, 367 | sequence_dim, 368 | compressor=Blosc("zstd", clevel=7, shuffle=-1), 369 | overwrite=overwrite, 370 | ) 371 | else: 372 | if max_jitter > 0: 373 | _bed = _bed.with_columns( 374 | pl.when(pl.col("chromStart") == pl.col("chromStart").min().over("name")) 375 | .then(pl.col("chromStart").min().over("name") - max_jitter) 376 | .otherwise(pl.col("chromStart")) 377 | .alias("chromStart"), 378 | pl.when(pl.col("chromEnd") == pl.col("chromEnd").max().over("name")) 379 | .then(pl.col("chromEnd").max().over("name") + max_jitter) 380 | .otherwise(pl.col("chromEnd")) 381 | .alias("chromEnd"), 382 | ) 383 | bed_to_write = _bed.group_by("name").agg( 384 | pl.col(pl.Utf8).first(), pl.exclude(pl.Utf8) 385 | ) 386 | _bed_to_zarr( 387 | bed_to_write, 388 | root, 389 | sequence_dim, 390 | compressor=Blosc("zstd", clevel=7, shuffle=-1), 391 | overwrite=overwrite, 392 | ) 393 | 394 | for reader in readers: 395 | reader._write( 396 | out=path, 397 | bed=_bed, 398 | fixed_length=fixed_length, 399 | sequence_dim=sequence_dim, 400 | length_dim=length_dim, 401 | overwrite=overwrite, 402 | splice=splice, 403 | ) 404 | 405 | zarr.consolidate_metadata(path) # type: ignore 406 | 407 | ds = open_zarr(path) 408 | return ds 409 | 410 | 411 | @xr.register_dataset_accessor("sd") 412 | class SeqDataAccessor: 413 | def __init__(self, ds: xr.Dataset) -> None: 414 | self._ds = ds 415 | 416 | @property 417 | def obs(self): 418 | return _filter_by_exact_dims(self._ds, self._ds.attrs["sequence_dim"]) 419 | 420 | @property 421 | def layers(self): 422 | return _filter_layers(self._ds) 423 | 424 | @property 425 | def obsp(self): 426 | return _filter_by_exact_dims( 427 | self._ds, (self._ds.attrs["sequence_dim"], self._ds.attrs["sequence_dim"]) 428 | ) 429 | 430 | @property 431 | def uns(self): 432 | return _filter_uns(self._ds) 433 | 434 | def __repr__(self) -> str: 435 | return "SeqData accessor." 
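Putting the pieces above together, here is a sketch of a typical `from_region_files` call followed by the `sd` accessor. File names are illustrative; the `GenomeFASTA` and `BigWig` reader arguments mirror those used in the test suite below:

```python
import seqdata as sd

ds = sd.from_region_files(
    sd.GenomeFASTA(name="seq", fasta="hg38.fa", batch_size=512),
    sd.BigWig(
        name="cov",
        bigwigs=["plus.bw", "minus.bw"],
        samples=["plus", "minus"],
        batch_size=512,
    ),
    path="peaks.zarr",
    bed="peaks.bed",
    fixed_length=1000,
    max_jitter=128,
    overwrite=True,
)

ds.sd.obs     # variables indexed only by the sequence dimension (e.g. BED columns)
ds.sd.layers  # variables whose first dimension is the sequence dimension plus others
ds.sd.uns     # variables that do not involve the sequence dimension
```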
436 | 437 | 438 | def merge_obs( 439 | sdata: xr.Dataset, 440 | obs: Union[xr.Dataset, pl.DataFrame], 441 | on: Optional[str] = None, 442 | left_on: Optional[str] = None, 443 | right_on: Optional[str] = None, 444 | how: Literal["inner", "left", "right", "outer", "exact"] = "inner", 445 | ): 446 | """Warning: This function is experimental and may change in the future. 447 | Merge observations into a SeqData object along sequence axis. 448 | 449 | Parameters 450 | ---------- 451 | sdata : xr.Dataset 452 | SeqData object. 453 | obs : xr.Dataset, pd.DataFrame 454 | Observations to merge. 455 | on : str, optional 456 | Column to merge on, by default None 457 | left_on : str, optional 458 | Column to merge on from the left dataset, by default None 459 | right_on : str, optional 460 | Column to merge on from the right dataset, by default None 461 | how : {"inner", "left", "right", "outer", "exact"}, optional 462 | Type of merge to perform, by default "inner" 463 | 464 | Returns 465 | ------- 466 | xr.Dataset 467 | Merged SeqData object. 468 | """ 469 | if on is None and (left_on is None or right_on is None): 470 | raise ValueError("Must provide either `on` or both `left_on` and `right_on`.") 471 | if on is not None and (left_on is not None or right_on is not None): 472 | raise ValueError("Cannot provide both `on` and `left_on` or `right_on`.") 473 | 474 | if on is None: 475 | assert left_on is not None 476 | assert right_on is not None 477 | else: 478 | left_on = on 479 | right_on = on 480 | 481 | if left_on not in sdata.data_vars: 482 | sdata = sdata.assign({left_on: np.arange(sdata.sizes[left_on])}) 483 | if left_on not in sdata.xindexes: 484 | sdata = sdata.set_coords(left_on).set_xindex(left_on) 485 | 486 | if isinstance(obs, pl.DataFrame): 487 | obs_ = obs.to_pandas() 488 | if obs_.index.name != right_on: 489 | obs_ = obs_.set_index(right_on) 490 | obs_.index.name = left_on 491 | obs_ = obs_.to_xarray() 492 | sdata_dim = sdata[left_on].dims[0] 493 | obs_dim = obs_[left_on].dims[0] 494 | if sdata_dim != obs_dim: 495 | obs_[left_on].rename({obs_dim: sdata_dim}) 496 | sdata = sdata.merge(obs, join=how) # type: ignore 497 | elif isinstance(obs, xr.Dataset): 498 | if right_on not in obs.data_vars: 499 | obs = obs.assign({right_on: np.arange(sdata.sizes[right_on])}) 500 | if right_on not in obs.xindexes: 501 | obs = ( 502 | obs.rename({right_on: left_on}).set_coords(left_on).set_xindex(left_on) 503 | ) 504 | sdata = sdata.merge(obs, join=how) 505 | 506 | return sdata 507 | 508 | 509 | def add_layers_from_files( 510 | sdata: xr.Dataset, 511 | *readers: Union[FlatReader, RegionReader], 512 | path: PathType, 513 | overwrite=False, 514 | ): 515 | raise NotImplementedError 516 | # if any(map(lambda r: isinstance(r, RegionReader), readers)): 517 | # bed = sdata[["chrom", "chromStart", "chromEnd", "strand"]].to_dataframe() 518 | 519 | # for reader in readers: 520 | # if isinstance(reader, FlatReader): 521 | # if reader.n_seqs is not None and reader.n_seqs != sdata.sizes["_sequence"]: 522 | # raise ValueError( 523 | # f"""Reader "{reader.name}" has a different number of sequences 524 | # than this SeqData.""" 525 | # ) 526 | # _fixed_length = fixed_length is not False 527 | # reader._write(out=path, fixed_length=_fixed_length, overwrite=overwrite) 528 | # elif isinstance(reader, RegionReader): 529 | # reader._write( 530 | # out=path, 531 | # bed=bed, # type: ignore 532 | # overwrite=overwrite, 533 | # ) 534 | 535 | # ds = xr.open_zarr(path, mask_and_scale=False, concat_characters=False) 536 | # return ds 
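A sketch of a round trip with `open_zarr`/`to_zarr`: derive a new variable in memory and write the result back, letting `to_zarr` normalize any non-uniform chunking. Paths and the variable/dimension names (`cov`, `_length`) are illustrative, and it assumes `to_zarr` is re-exported at the package root (otherwise import it from `seqdata.xarray.seqdata`):

```python
import seqdata as sd

ds = sd.open_zarr("peaks.zarr")  # hypothetical store written by from_region_files

# Summarize per-region coverage and persist it alongside the existing variables.
ds["total_cov"] = ds["cov"].sum(dim="_length")
sd.to_zarr(ds, "peaks_with_totals.zarr", mode="w")
```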
537 | -------------------------------------------------------------------------------- /seqdata/xarray/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | 3 | import xarray as xr 4 | 5 | 6 | def _filter_by_exact_dims(ds: xr.Dataset, dims: Union[str, Tuple[str, ...]]): 7 | if isinstance(dims, str): 8 | dims = (dims,) 9 | else: 10 | dims = tuple(dims) 11 | selector = [] 12 | for name, arr in ds.data_vars.items(): 13 | if arr.dims == dims: 14 | selector.append(name) 15 | return ds[selector] 16 | 17 | 18 | def _filter_layers(ds: xr.Dataset): 19 | selector = [] 20 | for name, arr in ds.data_vars.items(): 21 | if ( 22 | len(arr.dims) > 1 23 | and arr.dims[0] == ds.attrs["sequence_dim"] 24 | and arr.dims[1] != ds.attrs["sequence_dim"] 25 | ): 26 | selector.append(name) 27 | return ds[selector] 28 | 29 | 30 | def _filter_uns(ds: xr.Dataset): 31 | selector = [] 32 | for name, arr in ds.data_vars.items(): 33 | if ds.attrs["sequence_dim"] not in arr.dims: 34 | selector.append(name) 35 | return ds[selector] 36 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup(name="seqdata", packages=find_packages()) 4 | 5 | with open("README.md", "r") as readme_file: 6 | readme = readme_file.read() 7 | 8 | requirements = [] 9 | 10 | setup( 11 | name="seqdata", 12 | version="0.1.2", 13 | author="Adam Klie", 14 | author_email="aklie@ucsd.edu", 15 | description="Annotated sequence data", 16 | long_description=readme, 17 | long_description_content_type="text/markdown", 18 | url="https://github.com/adamklie/SeqData", 19 | packages=find_packages(), 20 | install_requires=requirements, 21 | classifiers=[ 22 | "Programming Language :: Python :: 3.7", 23 | "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", 24 | ], 25 | ) 26 | -------------------------------------------------------------------------------- /tests/_test_vcf.py: -------------------------------------------------------------------------------- 1 | import pytest # noqa 2 | from pytest import fixture, parametrize_with_cases 3 | 4 | from seqdata import read_vcf 5 | 6 | 7 | @fixture 8 | def vcf(): 9 | raise NotImplementedError 10 | 11 | 12 | @fixture 13 | def reference(): 14 | raise NotImplementedError 15 | 16 | 17 | def consensus(sample): 18 | raise NotImplementedError 19 | 20 | 21 | def bed_no_variants(): 22 | raise NotImplementedError 23 | 24 | 25 | def bed_variants(): 26 | raise NotImplementedError 27 | 28 | 29 | def length_variable(): 30 | return None 31 | 32 | 33 | def length_600(): 34 | return 600 35 | 36 | 37 | def samples_one(): 38 | raise NotImplementedError 39 | 40 | 41 | def samples_two(): 42 | raise NotImplementedError 43 | 44 | 45 | @parametrize_with_cases("bed", cases=".", prefix="bed_") 46 | @parametrize_with_cases("samples", cases=".", prefix="samples_") 47 | @parametrize_with_cases("length", cases=".", prefix="length_") 48 | def test_fixed_length(vcf, reference, samples, bed, length): 49 | sdata = read_vcf( 50 | "vcf", "foo", vcf, reference, samples, bed, 1024, length, overwrite=True 51 | ) # noqa 52 | for region in sdata.obs[["contig", "start", "end"]].itertuples(): 53 | for i, sample in enumerate(sdata.ds.coords["vcf_samples"]): 54 | pass 55 | # consensus_path = 56 | raise NotImplementedError 57 | 58 | 59 | def test_variable_length(): 60 | raise NotImplementedError 61 | 
62 | 63 | def test_spliced_fixed_length(): 64 | raise NotImplementedError 65 | 66 | 67 | def test_spliced_variable_length(): 68 | raise NotImplementedError 69 | -------------------------------------------------------------------------------- /tests/data/README.md: -------------------------------------------------------------------------------- 1 | # Simulated 2 | 3 | ## Tabular 4 | 5 | - `variable.tsv` ✅ 6 | - `fixed.tsv` ✅ 7 | 8 | ## Fastas 9 | 10 | - `variable.fa` — variable length .fa file ✅ 11 | 12 | ``` 13 | > chr1 14 | CGACTACTACCGACTAACTGACTGATGATGATGCATGCTGATGCTGAACTGACTAGCACTGCATGACTGATGACTGACTG 15 | TACTCCTACCATGACTATCCTAGTGCTGACCTGACTGATGCTGACTGACTGCATATGCACTGACTGACTCTACATGACTG 16 | ACTCACTCATCTGACATATCCATGCTGCATACTCATGATCATGCATGCATCATACTCATGCATGACTGACTCATGATGCA 17 | CATACTACTGCAGTCTGCATCATGCATGCATGCATGCACATCAT 18 | 19 | > chr2 20 | CTATCATCTCTGATGACTGATGCATATTCTATCTACTACTGCTATACTCATATCTACTACTACTACTCATACTATCACTA 21 | TCACTATGCATCATCATCATGCATGCATGCATGCATGCATACTTACTATCATGACTGACTGTGACTGCATGCTGATGATT 22 | TTTTATCTGCATACTCATATCATACTCATCATCATACTACTCATGACTGCA 23 | 24 | > chr3 25 | ATCTACTGacggagcacATCCATCTACGcacaCTACTACTACTCAACGTTGCATGATGCTGACTACTACAaaTCATAcaa 26 | CATCATACacacACTACTCATGGTACGTCATGCATATGCAcagtaa 27 | 28 | > chr4 29 | NNNNNCGATCATCATCACTACTACACTGGCACGTTACATCCTAGCTGACTGACTACTGACTGCATGCATGACTCATCATA 30 | CTATCTAACNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNCTACTCAT 31 | CTATCAAGCCTATCNNNNNNNNNNNNNNNNNNNNNNNNNACTCTACTACATGCTGAGTCATGCATGCATGACTACTATAT 32 | 33 | > chr5 34 | CATCTCGACTATACGACATATACGTZCTACACTGCAATCTGATGACGCAACTCAGCANNNNNNNNCacatactatctaca 35 | CTCATGCTGACGCATGCTGAcaatatctagaCATACTACTCAGTTGTATATATTACTACACCATGACTGCTACTCATGAC 36 | 37 | > chr6 38 | CATCATCATCATCATCATCATCATCATCATCATCATCATCATCATCATCATCATCATCATCATCATCATCATCATCATCA 39 | 40 | > chr7 41 | CTATCAGTACTGATCTGACTGACTGACTGATGCACGTATGACTGATGATGCAGTCAGTCATGTGACATGCAGCGTAACTA 42 | CTACTGCTACTGACTGCATGACTGCATGACTGACTGCATGCATGCATGCATGCATGACTGCATGGTCATGCTGATGACTG 43 | CATCTCTATGCTCGAACTGCATGATGCATGACTGCTGATGCTGTGACTGACTGGTACGTCAGTTGCTGCAGTCATGAAAT 44 | TATCAACTACTACTGCTGACTGCGTGTACGTATGCATCTCATCATTCTGAGCATGCGTATGTGTGCATGGTACTGATTCA 45 | ACCCTATCAGCATGTACTGCATGCAAAAACCTATGCCTATGTGACTCATCTGCATACTCATCATGACTGACTGCATCATT 46 | ACTTCACTGATGCAGTCAGTACGTACTGACTTGCATGACTGACTGTTGCGTAGGGGGACGTACGTGGGTACTGGTCAGTA 47 | ACTCATACTGGCGTCATGCATGACTGCATGACTGACTGCATGATGCACATCAAAAAACTACTGCATGCATGACACCACAC 48 | CATACTGCATGCATTACGCATGCTGATGCATGACTACTTATTACTGGTGGGTACTTTACTCTCCCTCATCATCACTCATA 49 | CATCTACTGACTATGCATACTCATCATACTCATCATCATCATCATTCACTATACCATCATCATCATACTACTACTTACAC 50 | CATACTCATACTACTTCATCATGCATCGAAGTACTGATCATATATATTAAATCATCAGTACTTGGTTGACTGTGTGACAT 51 | ATCTACTGCGGTCTGCTCCCCCTCATACCCATTGACTCATGTTACTGACTGTACGTTGTGTGTGCATTATCTTTCATTAA 52 | TATCTCCATCATCATCATCATCTCAGACGTCATGCATGCATGCATGCATGCATGCTGATGCAGTGTATGATGTATGATAT 53 | ACTCTACATCATCTGATCATCATACTCATGCATCATCACATGCATGTGCATGCATGCATGACTGCATCATGTGTGATGCT 54 | ATATATATATATATATATATATATATTATATTATATATTACTCATACTCTCATCATACTCATTACTACTCATTACTATTA 55 | ACTCTCGGCGCTATATCCTTGGGGGCATCATGCATCACATTACTCATTACTATATATATTATTACTCAGTATATATATAT 56 | ATGCGCGATATTATGTGTGTTATTGATATATGATATACTATATTAGAATATATATAAATATATCTCAATTATAATTATAT 57 | ATACCATCATGTGATATTATATATATATCTGTGATATTAATATATTACTGCAACTGACCATACTACTACTACTCATTACC 58 | ACTTCAACTCTAGTCAGTACTACTACTCATTATTATATTACTCTCATTGGGACTACTACTATCTATATCGACGTCTACTA 59 | ATTCTGCACTGCATGCATCATGACTGACTATATATGTACGCGCGCCGCGCGCGGCGCTCATCAGGCGCGCGCTAGCGTCT 60 | CGCGCGCTAACGCGCGCGGCGCGCATGACTGACTCATGCGCGCGCGCGCCTTCAGTCAGACGACGCGTTGCAGCATGTAC 61 | 
TACTGCAGTCAGCAGACGCGTTATGTGTGACGTCGCAGCGCGCGCGCGCAGTACTGCAACTCATGGTACGCAGCAATATG 62 | ATGCATGCAGCAGCGCGACCATGATCTACTACTATATATACTCGCAGTACGTACGACGACGACGCAGACCATACTACGAC 63 | ACTACTGAGTCGTCATACGTGCAGTACGTCATGACTGGCAGACGACGCGCGGCGCGCATGCTCTCTCTCACATCTACTGC 64 | CGTCGCGCGCGCGCGCGCGGGCGGCGGCGGCGCGGCGCGGCGCGCGCGCGCATACTGCACATGCCGCGCGCCTGCCCCCC 65 | CGCATTCACACGTCCGGCGCCGGCCTATCTACTGCACATGCATGCATGCTGAGACGCGCAGACGCATCTCGCGCGCGGCC 66 | ``` 67 | 68 | - `fixed.fa` — first 50bp of `variable.fa` ✅ 69 | 70 | ## BEDs (remember these are 0-based) 71 | 72 | - `variable.bed` — variable length regions ✅ 73 | - `fixed.bed` — first 20bp of each region in `variable.bed` ✅ 74 | 75 | ## BAM 76 | 77 | - `simulated.{1-5}.bam` — Simulate a BAM file from `variable.bed` ✅ 78 | 79 | ## BigWig 80 | 81 | - `simulated.{1-5}.bw` — write out coverage to BigWig ✅ 82 | 83 | ## Fragments 84 | 85 | - `simulated.{1-5}.bed.gz` 86 | - Convert bam to fragments 87 | 88 | # K562 ATAC-seq chr22 89 | 90 | - `hg38.chr22.fa` ✅ 91 | - `hg38.chr22.chromsizes` ✅ 92 | - `ENCSR868FGK.chr22.bed` ✅ 93 | - `ENCSR868FGK.chr22.bam` ✅ 94 | - `ENCSR868FGK.chr22.chrombpnet.bw` — Using chrombpnet `reads_to_bigwig.py` ✅ 95 | 96 | # deBoer et al sample data 97 | 98 | https://zenodo.org/records/10633252/files/filtered_test_data_with_MAUDE_expression.txt?download=1 99 | 100 | - Add header: seq\texp -------------------------------------------------------------------------------- /tests/data/fixed.bed: -------------------------------------------------------------------------------- 1 | chr1 4 24 2 | chr1 47 67 3 | chr2 46 66 4 | chr2 174 194 5 | chr3 18 38 6 | chr3 78 98 7 | chr4 35 55 8 | chr4 87 107 9 | chr5 40 60 10 | chr5 156 176 11 | chr6 19 39 12 | chr6 61 81 13 | chr7 12 32 14 | chr7 153 173 15 | -------------------------------------------------------------------------------- /tests/data/fixed.chrom.sizes: -------------------------------------------------------------------------------- 1 | chr1 120 2 | chr2 400 3 | chr3 110 4 | chr4 150 5 | chr5 300 6 | chr6 100 7 | chr7 200 8 | -------------------------------------------------------------------------------- /tests/data/fixed.fa: -------------------------------------------------------------------------------- 1 | >chr1 2 | ACTACGATCATCACAACTTGGCCTAAGCGATAAACTCCAGTGAGGGCGTTAACATAGTAG 3 | GTAACACAACCAATGGGTCT 4 | 5 | >chr2 6 | TcAACTACGTGGATacTCGGCCATCTGAAAcTgGGATGACGATGTGAACAGCTCCTTAGc 7 | CTGTcAgGGGGGAGCGCcTC 8 | 9 | >chr3 10 | AGAGTACCTGCCTTAAAATATTTTACACAGCCGCAAGTCTAGGAGTGTCACAGGGATGCT 11 | AGTTGGTGTTAAGTTGAGTA 12 | 13 | >chr4 14 | CAGACTAAGTGCCAGAAGGCTTCTATGCAACGATGGTTGGCAACCCGTGTCATGAGGTAG 15 | TCTACTTGGTTCTTATATCT 16 | 17 | >chr5 18 | TAGCTACCACGCTGAGGTATTACGTTTCATTTAAAATCGTATAAGCGGCCGGTTATGACT 19 | AACCCCAAGCCAGCTAGAGC 20 | 21 | >chr6 22 | GCGCTGGGCGAGTGTATCTATGGCTTTCGACTCGAGATTTATGGACGTAGTATAATCAGA 23 | GCACCGAGCTCAGAATACTA 24 | 25 | >chr7 26 | GCGATCTGGGAAGTAAGCACAGCATTCTACACATTCGACCCTCTGCCAGCGCTTGGCCGC 27 | GACACGACCTATAGCTCATG 28 | 29 | -------------------------------------------------------------------------------- /tests/data/fixed.fa.fai: -------------------------------------------------------------------------------- 1 | chr1 80 6 60 61 2 | chr2 80 95 60 61 3 | chr3 80 184 60 61 4 | chr4 80 273 60 61 5 | chr5 80 362 60 61 6 | chr6 80 451 60 61 7 | chr7 80 540 60 61 8 | -------------------------------------------------------------------------------- /tests/data/fixed.tsv: -------------------------------------------------------------------------------- 1 | seq target 2 | 
ACTACGATCATCACAACTTG 1.195480844083972 3 | TcAACTACGTGGATacTCGG -0.6929163050128713 4 | AGAGTACCTGCCTTAAAATA -0.3066981740025167 5 | CAGACTAAGTGCCAGAAGGC 0.3260554187453464 6 | TAGCTACCACGCTGAGGTAT 0.21557499827573193 7 | GCGCTGGGCGAGTGTATCTA -0.6086421901135827 8 | GCGATCTGGGAAGTAAGCAC 0.549338320319528 9 | -------------------------------------------------------------------------------- /tests/data/simulated1.bam: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML4GLand/SeqData/10afa85d0ec954b9d28955b072a55274c3c18dfd/tests/data/simulated1.bam -------------------------------------------------------------------------------- /tests/data/simulated1.bam.bai: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML4GLand/SeqData/10afa85d0ec954b9d28955b072a55274c3c18dfd/tests/data/simulated1.bam.bai -------------------------------------------------------------------------------- /tests/data/simulated1.bw: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML4GLand/SeqData/10afa85d0ec954b9d28955b072a55274c3c18dfd/tests/data/simulated1.bw -------------------------------------------------------------------------------- /tests/data/simulated2.bam: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML4GLand/SeqData/10afa85d0ec954b9d28955b072a55274c3c18dfd/tests/data/simulated2.bam -------------------------------------------------------------------------------- /tests/data/simulated2.bam.bai: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML4GLand/SeqData/10afa85d0ec954b9d28955b072a55274c3c18dfd/tests/data/simulated2.bam.bai -------------------------------------------------------------------------------- /tests/data/simulated2.bw: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML4GLand/SeqData/10afa85d0ec954b9d28955b072a55274c3c18dfd/tests/data/simulated2.bw -------------------------------------------------------------------------------- /tests/data/simulated3.bam: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML4GLand/SeqData/10afa85d0ec954b9d28955b072a55274c3c18dfd/tests/data/simulated3.bam -------------------------------------------------------------------------------- /tests/data/simulated3.bam.bai: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML4GLand/SeqData/10afa85d0ec954b9d28955b072a55274c3c18dfd/tests/data/simulated3.bam.bai -------------------------------------------------------------------------------- /tests/data/simulated3.bw: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML4GLand/SeqData/10afa85d0ec954b9d28955b072a55274c3c18dfd/tests/data/simulated3.bw -------------------------------------------------------------------------------- /tests/data/simulated4.bam: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML4GLand/SeqData/10afa85d0ec954b9d28955b072a55274c3c18dfd/tests/data/simulated4.bam -------------------------------------------------------------------------------- /tests/data/simulated4.bam.bai: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML4GLand/SeqData/10afa85d0ec954b9d28955b072a55274c3c18dfd/tests/data/simulated4.bam.bai -------------------------------------------------------------------------------- /tests/data/simulated4.bw: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML4GLand/SeqData/10afa85d0ec954b9d28955b072a55274c3c18dfd/tests/data/simulated4.bw -------------------------------------------------------------------------------- /tests/data/simulated5.bam: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML4GLand/SeqData/10afa85d0ec954b9d28955b072a55274c3c18dfd/tests/data/simulated5.bam -------------------------------------------------------------------------------- /tests/data/simulated5.bam.bai: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML4GLand/SeqData/10afa85d0ec954b9d28955b072a55274c3c18dfd/tests/data/simulated5.bam.bai -------------------------------------------------------------------------------- /tests/data/simulated5.bw: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML4GLand/SeqData/10afa85d0ec954b9d28955b072a55274c3c18dfd/tests/data/simulated5.bw -------------------------------------------------------------------------------- /tests/data/variable.bed: -------------------------------------------------------------------------------- 1 | chr1 4 28 2 | chr1 47 76 3 | chr2 46 72 4 | chr2 174 197 5 | chr3 18 43 6 | chr3 78 106 7 | chr4 35 60 8 | chr4 87 111 9 | chr5 40 62 10 | chr5 156 181 11 | chr6 19 49 12 | chr6 61 85 13 | chr7 12 34 14 | chr7 153 174 15 | -------------------------------------------------------------------------------- /tests/data/variable.bedcov.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML4GLand/SeqData/10afa85d0ec954b9d28955b072a55274c3c18dfd/tests/data/variable.bedcov.pkl -------------------------------------------------------------------------------- /tests/data/variable.chrom.sizes: -------------------------------------------------------------------------------- 1 | chr1 120 2 | chr2 400 3 | chr3 110 4 | chr4 150 5 | chr5 300 6 | chr6 100 7 | chr7 200 8 | -------------------------------------------------------------------------------- /tests/data/variable.fa: -------------------------------------------------------------------------------- 1 | >chr1 2 | ACTACGATCATCACAACTTGGCCTAAGCGATAAACTCCAGTGAGGGCGTTAACATAGTAG 3 | GTAACACAACCAATGGGTCTCGGTTGAGCACGTCGGATTCTTTCATATGATAAGCATTTC 4 | 5 | >chr2 6 | TcAACTACGTGGATacTCGGCCATCTGAAAcTgGGATGACGATGTGAACAGCTCCTTAGc 7 | CTGTcAgGGGGGAGCGCcTCAATAACCGAGAACAGTCTCtGCAAAATTCCGTCAATTtGA 8 | TAAAATTCGGAcCAGtCtGCagACCTACGGTGGGgAAATCCCTTACgcGAGGTGTTatAG 9 | CCtAAGCcTAACTTCATGTGGCCACcCAAGATCGAcaGaGTGTGTGGTGAGGAGcCcGGA 10 | CTtAGGGTGCACTGAAAATAGGCTTAAACCCtTTTCTATAAACTTCGAGGTTAtACGTGT 11 | tAGTTCTtGGCGCTTGTgGACTTAGTaTTTCAGGATGGaGcCGGGATTATTCAATCCCGA 12 | CTcAATGACCCGTCTtGtGGgCTGTGaATTGTtGTtGCaG 13 | 14 | >chr3 15 | AGAGTACCTGCCTTAAAATATTTTACACAGCCGCAAGTCTAGGAGTGTCACAGGGATGCT 16 | AGTTGGTGTTAAGTTGAGTACTCACCCGTTAAATTTGCCGAGGTGCCTCG 17 | 18 | >chr4 19 | CAGACTAAGTGCCAGAAGGCTTCTATGCAACGATGGTTGGCAACCCGTGTCATGAGGTAG 20 | TCTACTTGGTTCTTATATCTGCCCTCGGTTTATGTAGGCGGGATTCTGATATTCCGTACC 21 | 
TGTCACTCTTTCAGTAGACGCAGCCTGTAG 22 | 23 | >chr5 24 | TAGCTACCACGCTGAGGTATTACGTTTCATTTAAAATCGTATAAGCGGCCGGTTATGACT 25 | AACCCCAAGCCAGCTAGAGCACGGCAAGACCTTTAGGATCTCGGATAGAACGCCGTTAAA 26 | GTGTGGGGCCAGAAGAGGGTGCGATTAAGTTGATACCATAGTGAGTAAGAACGATCGCTC 27 | GGAATTCAACAACTACACCCTGAGACTACTACATTCCGTCTTCGGTGGACCCGTCGCGAA 28 | TATACTGGACGGCTAGGCTTCGTCAACTTGTGGGTAGTAGGTTTACGTGCCCGACTTTGT 29 | 30 | >chr6 31 | GCGCTGGGCGAGTGTATCTATGGCTTTCGACTCGAGATTTATGGACGTAGTATAATCAGA 32 | GCACCGAGCTCAGAATACTAAAAGCACCTTATGTTTTCAA 33 | 34 | >chr7 35 | GCGATCTGGGAAGTAAGCACAGCATTCTACACATTCGACCCTCTGCCAGCGCTTGGCCGC 36 | GACACGACCTATAGCTCATGTAAAGCGCGTGGCCATGGCCGCCTAATCAATGACGTATGC 37 | GCTTTTAAAATCAGTCTTGCCACTGTTATTAGCGAGGTAAACCTGACATAGTATAACGTG 38 | CTAATTAAATTTGTCCTTCC 39 | 40 | -------------------------------------------------------------------------------- /tests/data/variable.fa.fai: -------------------------------------------------------------------------------- 1 | chr1 120 6 60 61 2 | chr2 400 135 60 61 3 | chr3 110 549 60 61 4 | chr4 150 668 60 61 5 | chr5 300 828 60 61 6 | chr6 100 1140 60 61 7 | chr7 200 1249 60 61 8 | -------------------------------------------------------------------------------- /tests/data/variable.tsv: -------------------------------------------------------------------------------- 1 | seq target 2 | ACTACGATCATCACAACTTGGCCTAAGCGATAAACTCCAGTGAGGGCGTTAACATAGTAGGTAACACAACCAATGGGTCTCGGTTGAGCACGTCGGATTCTTTCATATGATAAGCATTTC 1.195480844083972 3 | TcAACTACGTGGATacTCGGCCATCTGAAAcTgGGATGACGATGTGAACAGCTCCTTAGcCTGTcAgGGGGGAGCGCcTCAATAACCGAGAACAGTCTCtGCAAAATTCCGTCAATTtGATAAAATTCGGAcCAGtCtGCagACCTACGGTGGGgAAATCCCTTACgcGAGGTGTTatAGCCtAAGCcTAACTTCATGTGGCCACcCAAGATCGAcaGaGTGTGTGGTGAGGAGcCcGGACTtAGGGTGCACTGAAAATAGGCTTAAACCCtTTTCTATAAACTTCGAGGTTAtACGTGTtAGTTCTtGGCGCTTGTgGACTTAGTaTTTCAGGATGGaGcCGGGATTATTCAATCCCGACTcAATGACCCGTCTtGtGGgCTGTGaATTGTtGTtGCaG -0.6929163050128713 4 | AGAGTACCTGCCTTAAAATATTTTACACAGCCGCAAGTCTAGGAGTGTCACAGGGATGCTAGTTGGTGTTAAGTTGAGTACTCACCCGTTAAATTTGCCGAGGTGCCTCG -0.3066981740025167 5 | CAGACTAAGTGCCAGAAGGCTTCTATGCAACGATGGTTGGCAACCCGTGTCATGAGGTAGTCTACTTGGTTCTTATATCTGCCCTCGGTTTATGTAGGCGGGATTCTGATATTCCGTACCTGTCACTCTTTCAGTAGACGCAGCCTGTAG 0.3260554187453464 6 | TAGCTACCACGCTGAGGTATTACGTTTCATTTAAAATCGTATAAGCGGCCGGTTATGACTAACCCCAAGCCAGCTAGAGCACGGCAAGACCTTTAGGATCTCGGATAGAACGCCGTTAAAGTGTGGGGCCAGAAGAGGGTGCGATTAAGTTGATACCATAGTGAGTAAGAACGATCGCTCGGAATTCAACAACTACACCCTGAGACTACTACATTCCGTCTTCGGTGGACCCGTCGCGAATATACTGGACGGCTAGGCTTCGTCAACTTGTGGGTAGTAGGTTTACGTGCCCGACTTTGT 0.21557499827573193 7 | GCGCTGGGCGAGTGTATCTATGGCTTTCGACTCGAGATTTATGGACGTAGTATAATCAGAGCACCGAGCTCAGAATACTAAAAGCACCTTATGTTTTCAA -0.6086421901135827 8 | GCGATCTGGGAAGTAAGCACAGCATTCTACACATTCGACCCTCTGCCAGCGCTTGGCCGCGACACGACCTATAGCTCATGTAAAGCGCGTGGCCATGGCCGCCTAATCAATGACGTATGCGCTTTTAAAATCAGTCTTGCCACTGTTATTAGCGAGGTAAACCTGACATAGTATAACGTGCTAATTAAATTTGTCCTTCC 0.549338320319528 9 | -------------------------------------------------------------------------------- /tests/readers/test_bam.py: -------------------------------------------------------------------------------- 1 | from tempfile import TemporaryDirectory 2 | 3 | import numpy as np 4 | import polars as pl 5 | import pandas as pd 6 | import seqpro as sp 7 | import xarray as xr 8 | from pytest import fixture 9 | import pickle 10 | 11 | import seqdata as sd 12 | import zarr 13 | 14 | 15 | def read_fasta(file_path): 16 | """ 17 | Reads a FASTA file and returns a dictionary of sequences. 18 | 19 | Parameters: 20 | file_path (str): Path to the FASTA file. 
21 | 22 | Returns: 23 | dict: A dictionary where keys are sequence IDs and values are sequences. 24 | """ 25 | sequences = {} 26 | with open(file_path, 'r') as file: 27 | sequence_id = None 28 | sequence_lines = [] 29 | for line in file: 30 | line = line.strip() 31 | if line.startswith(">"): 32 | # Save the previous sequence (if any) before starting a new one 33 | if sequence_id: 34 | sequences[sequence_id] = ''.join(sequence_lines) 35 | # Start a new sequence 36 | sequence_id = line[1:] # Remove the '>' 37 | sequence_lines = [] 38 | else: 39 | # Append sequence lines 40 | sequence_lines.append(line) 41 | # Save the last sequence 42 | if sequence_id: 43 | sequences[sequence_id] = ''.join(sequence_lines) 44 | return sequences 45 | 46 | 47 | @fixture 48 | def fasta(): 49 | fasta = "tests/data/variable.fa" 50 | return read_fasta(fasta) 51 | 52 | 53 | @fixture 54 | def fasta_reader(): 55 | fasta = "tests/data/variable.fa" 56 | reader = sd.GenomeFASTA( 57 | name="seq", 58 | fasta=fasta, 59 | batch_size=50 60 | ) 61 | return reader 62 | 63 | 64 | @fixture 65 | def variable_bed(): 66 | variable_bed = "tests/data/variable.bed" 67 | variable_bed = pd.read_csv(variable_bed, sep="\t", header=None) 68 | return variable_bed 69 | 70 | 71 | @fixture 72 | def fixed_bed(): 73 | fixed_bed = "tests/data/fixed.bed" 74 | fixed_bed = pd.read_csv(fixed_bed, sep="\t", header=None) 75 | return fixed_bed 76 | 77 | @fixture 78 | def variable_coverage(): 79 | variable_coverage = "tests/data/variable.bedcov.pkl" 80 | variable_coverage = pickle.load(open(variable_coverage, 'rb')) 81 | return variable_coverage 82 | 83 | 84 | @fixture 85 | def fixed_coverage(variable_coverage, fixed_bed): 86 | fixed_coverage = {} 87 | bams = variable_coverage.keys() 88 | for bam in bams: 89 | fixed_coverage[bam] = {} 90 | for i, (region, coverage) in enumerate(variable_coverage[bam].items()): 91 | coverage_interval = region.split(":")[1] 92 | coverage_start, coverage_end = map(int, coverage_interval.split("-")) 93 | start_offset = coverage_start - fixed_bed[1].values[i] 94 | end_offset = fixed_bed[2].values[i] - coverage_end 95 | new_region = f"{fixed_bed[0].values[i]}:{fixed_bed[1].values[i]}-{fixed_bed[2].values[i]}" 96 | if end_offset == 0: 97 | fixed_coverage[bam][new_region] = coverage[start_offset:] 98 | else: 99 | fixed_coverage[bam][new_region] = coverage[start_offset:end_offset] 100 | return fixed_coverage 101 | 102 | 103 | @fixture 104 | def single_reader(): 105 | bam = "tests/data/simulated1.bam" 106 | single_reader = sd.BAM( 107 | name="cov", 108 | bams=bam, 109 | samples=["simulated1.bam"], 110 | batch_size=50 111 | ) 112 | return single_reader 113 | 114 | 115 | @fixture 116 | def multi_reader(): 117 | bams = [f"tests/data/simulated{i}.bam" for i in range(1, 6)] 118 | multi_reader = sd.BAM( 119 | name="cov", 120 | bams=bams, 121 | samples=[f"simulated{i}.bam" for i in range(1, 6)], 122 | batch_size=50 123 | ) 124 | return multi_reader 125 | 126 | 127 | def test_single_bam_write( 128 | single_reader, 129 | variable_bed, 130 | variable_coverage, 131 | ): 132 | with ( 133 | TemporaryDirectory(suffix=".zarr") as out, 134 | ): 135 | variable_bed["strand"] = "+" 136 | single_reader._write( 137 | out=out, 138 | bed=variable_bed, 139 | fixed_length=False, 140 | sequence_dim="_sequence", 141 | overwrite=True, 142 | ) 143 | zarr.consolidate_metadata(out) 144 | ds = sd.open_zarr(out) 145 | cov = ds.sel(cov_sample="simulated1.bam").cov.values 146 | for i in range(len(cov)): 147 | np.testing.assert_array_equal(cov[i], 
list(variable_coverage["simulated1.bam"].values())[i]) 148 | 149 | 150 | def test_multi_bam_write( 151 | multi_reader, 152 | fixed_bed, 153 | fixed_coverage, 154 | ): 155 | with ( 156 | TemporaryDirectory(suffix=".zarr") as out, 157 | ): 158 | fixed_bed["strand"] = "+" 159 | multi_reader._write( 160 | out=out, 161 | bed=fixed_bed, 162 | fixed_length=20, 163 | sequence_dim="_sequence", 164 | length_dim="_length", 165 | overwrite=True, 166 | ) 167 | zarr.consolidate_metadata(out) 168 | ds = sd.open_zarr(out) 169 | for i in range(1, 6): 170 | cov = ds.sel(cov_sample=f"simulated{i}.bam").cov.values 171 | for j in range(len(cov)): 172 | np.testing.assert_array_equal(cov[j], list(fixed_coverage[f"simulated{i}.bam"].values())[j]) 173 | 174 | 175 | def test_from_region_files( 176 | fasta_reader, 177 | multi_reader, 178 | fasta, 179 | fixed_coverage, 180 | ): 181 | fixed_bed = "tests/data/fixed.bed" 182 | with ( 183 | TemporaryDirectory(suffix=".zarr") as out, 184 | ): 185 | ds = sd.from_region_files( 186 | fasta_reader, 187 | multi_reader, 188 | path=out, 189 | bed=fixed_bed, 190 | fixed_length=20, 191 | sequence_dim="_sequence", 192 | overwrite=True 193 | ) 194 | bed = pd.read_csv(fixed_bed, sep="\t", header=None) 195 | zarr.consolidate_metadata(out) 196 | seqs = [''.join(row.astype(str)) for row in ds["seq"].values] 197 | true_seqs = [fasta[chrom][start:end] for chrom, start, end in bed.values] 198 | np.testing.assert_array_equal(seqs, true_seqs) 199 | for i in range(1, 6): 200 | cov = ds.sel(cov_sample=f"simulated{i}.bam").cov.values 201 | for j in range(len(cov)): 202 | np.testing.assert_array_equal(cov[j], list(fixed_coverage[f"simulated{i}.bam"].values())[j]) 203 | 204 | 205 | def test_read_bam(fixed_coverage): 206 | fasta = "tests/data/variable.fa" 207 | fixed_bed = "tests/data/fixed.bed" 208 | bams = [f"tests/data/simulated{i}.bam" for i in range(1, 6)] 209 | with ( 210 | TemporaryDirectory(suffix=".zarr") as out, 211 | ): 212 | ds = sd.read_bam( 213 | seq_name="seq", 214 | cov_name="cov", 215 | out=out, 216 | fasta=fasta, 217 | bams=bams, 218 | samples=[f"simulated{i}.bam" for i in range(1, 6)], 219 | bed=fixed_bed, 220 | batch_size=50, 221 | fixed_length=20, 222 | overwrite=True 223 | ) 224 | bed = pd.read_csv(fixed_bed, sep="\t", header=None) 225 | zarr.consolidate_metadata(out) 226 | seqs = [''.join(row.astype(str)) for row in ds["seq"].values] 227 | fasta = read_fasta(fasta) 228 | true_seqs = [fasta[chrom][start:end] for chrom, start, end in bed.values] 229 | np.testing.assert_array_equal(seqs, true_seqs) 230 | for i in range(1, 6): 231 | cov = ds.sel(cov_sample=f"simulated{i}.bam").cov.values 232 | for j in range(len(cov)): 233 | np.testing.assert_array_equal(cov[j], list(fixed_coverage[f"simulated{i}.bam"].values())[j]) 234 | -------------------------------------------------------------------------------- /tests/readers/test_bigwig.py: -------------------------------------------------------------------------------- 1 | from tempfile import TemporaryDirectory 2 | 3 | import numpy as np 4 | import polars as pl 5 | import pandas as pd 6 | import seqpro as sp 7 | import xarray as xr 8 | from pytest import fixture 9 | import pickle 10 | 11 | import seqdata as sd 12 | import zarr 13 | 14 | 15 | def read_fasta(file_path): 16 | """ 17 | Reads a FASTA file and returns a dictionary of sequences. 18 | 19 | Parameters: 20 | file_path (str): Path to the FASTA file. 21 | 22 | Returns: 23 | dict: A dictionary where keys are sequence IDs and values are sequences. 
24 | """ 25 | sequences = {} 26 | with open(file_path, 'r') as file: 27 | sequence_id = None 28 | sequence_lines = [] 29 | for line in file: 30 | line = line.strip() 31 | if line.startswith(">"): 32 | # Save the previous sequence (if any) before starting a new one 33 | if sequence_id: 34 | sequences[sequence_id] = ''.join(sequence_lines) 35 | # Start a new sequence 36 | sequence_id = line[1:] # Remove the '>' 37 | sequence_lines = [] 38 | else: 39 | # Append sequence lines 40 | sequence_lines.append(line) 41 | # Save the last sequence 42 | if sequence_id: 43 | sequences[sequence_id] = ''.join(sequence_lines) 44 | return sequences 45 | 46 | 47 | @fixture 48 | def fasta(): 49 | fasta = "tests/data/variable.fa" 50 | return read_fasta(fasta) 51 | 52 | 53 | @fixture 54 | def fasta_reader(): 55 | fasta = "tests/data/variable.fa" 56 | reader = sd.GenomeFASTA( 57 | name="seq", 58 | fasta=fasta, 59 | batch_size=50 60 | ) 61 | return reader 62 | 63 | 64 | @fixture 65 | def variable_bed(): 66 | variable_bed = "tests/data/variable.bed" 67 | variable_bed = pd.read_csv(variable_bed, sep="\t", header=None) 68 | return variable_bed 69 | 70 | 71 | @fixture 72 | def fixed_bed(): 73 | fixed_bed = "tests/data/fixed.bed" 74 | fixed_bed = pd.read_csv(fixed_bed, sep="\t", header=None) 75 | return fixed_bed 76 | 77 | @fixture 78 | def variable_coverage(): 79 | variable_coverage = "tests/data/variable.bedcov.pkl" 80 | variable_coverage = pickle.load(open(variable_coverage, 'rb')) 81 | variable_coverage = {k.replace(".bam", ".bw"): v for k, v in variable_coverage.items()} 82 | return variable_coverage 83 | 84 | 85 | @fixture 86 | def fixed_coverage(variable_coverage, fixed_bed): 87 | fixed_coverage = {} 88 | bws = variable_coverage.keys() 89 | for bw in bws: 90 | fixed_coverage[bw] = {} 91 | for i, (region, coverage) in enumerate(variable_coverage[bw].items()): 92 | coverage_interval = region.split(":")[1] 93 | coverage_start, coverage_end = map(int, coverage_interval.split("-")) 94 | start_offset = coverage_start - fixed_bed[1].values[i] 95 | end_offset = fixed_bed[2].values[i] - coverage_end 96 | new_region = f"{fixed_bed[0].values[i]}:{fixed_bed[1].values[i]}-{fixed_bed[2].values[i]}" 97 | if end_offset == 0: 98 | fixed_coverage[bw][new_region] = coverage[start_offset:] 99 | else: 100 | fixed_coverage[bw][new_region] = coverage[start_offset:end_offset] 101 | return fixed_coverage 102 | 103 | 104 | @fixture 105 | def single_reader(): 106 | bw = "tests/data/simulated1.bw" 107 | single_reader = sd.BigWig( 108 | name="cov", 109 | bigwigs=[bw], 110 | samples=["simulated1.bw"], 111 | batch_size=50 112 | ) 113 | return single_reader 114 | 115 | 116 | @fixture 117 | def multi_reader(): 118 | bws = [f"tests/data/simulated{i}.bw" for i in range(1, 6)] 119 | multi_reader = sd.BigWig( 120 | name="cov", 121 | bigwigs=bws, 122 | samples=[f"simulated{i}.bw" for i in range(1, 6)], 123 | batch_size=50 124 | ) 125 | return multi_reader 126 | 127 | 128 | def test_single_bigwig_write( 129 | single_reader, 130 | fixed_bed, 131 | fixed_coverage, 132 | ): 133 | with ( 134 | TemporaryDirectory(suffix=".zarr") as out, 135 | ): 136 | fixed_bed["strand"] = "+" 137 | single_reader._write( 138 | out=out, 139 | bed=fixed_bed, 140 | fixed_length=20, 141 | sequence_dim="_sequence", 142 | length_dim="_length", 143 | overwrite=True, 144 | ) 145 | zarr.consolidate_metadata(out) 146 | ds = sd.open_zarr(out) 147 | cov = ds.sel(cov_sample="simulated1.bw").cov.values 148 | for i in range(len(cov)): 149 | np.testing.assert_array_equal(cov[i], 
list(fixed_coverage["simulated1.bw"].values())[i]) 150 | 151 | 152 | def test_multi_bigwig_write( 153 | multi_reader, 154 | fixed_bed, 155 | fixed_coverage, 156 | ): 157 | with ( 158 | TemporaryDirectory(suffix=".zarr") as out, 159 | ): 160 | fixed_bed["strand"] = "+" 161 | multi_reader._write( 162 | out=out, 163 | bed=fixed_bed, 164 | fixed_length=20, 165 | sequence_dim="_sequence", 166 | length_dim="_length", 167 | overwrite=True, 168 | ) 169 | zarr.consolidate_metadata(out) 170 | ds = sd.open_zarr(out) 171 | for i in range(1, 6): 172 | cov = ds.sel(cov_sample=f"simulated{i}.bw").cov.values 173 | for j in range(len(cov)): 174 | np.testing.assert_array_equal(cov[j], list(fixed_coverage[f"simulated{i}.bw"].values())[j]) 175 | 176 | 177 | def test_from_region_files( 178 | fasta_reader, 179 | multi_reader, 180 | fasta, 181 | fixed_coverage, 182 | ): 183 | fixed_bed = "tests/data/fixed.bed" 184 | with ( 185 | TemporaryDirectory(suffix=".zarr") as out, 186 | ): 187 | ds = sd.from_region_files( 188 | fasta_reader, 189 | multi_reader, 190 | path=out, 191 | bed=fixed_bed, 192 | fixed_length=20, 193 | sequence_dim="_sequence", 194 | overwrite=True 195 | ) 196 | bed = pd.read_csv(fixed_bed, sep="\t", header=None) 197 | zarr.consolidate_metadata(out) 198 | seqs = [''.join(row.astype(str)) for row in ds["seq"].values] 199 | true_seqs = [fasta[chrom][start:end] for chrom, start, end in bed.values] 200 | np.testing.assert_array_equal(seqs, true_seqs) 201 | for i in range(1, 6): 202 | cov = ds.sel(cov_sample=f"simulated{i}.bw").cov.values 203 | for j in range(len(cov)): 204 | np.testing.assert_array_equal(cov[j], list(fixed_coverage[f"simulated{i}.bw"].values())[j]) 205 | 206 | 207 | def test_read_bigwig(fixed_coverage): 208 | fasta = "tests/data/variable.fa" 209 | fixed_bed = "tests/data/fixed.bed" 210 | bws = [f"tests/data/simulated{i}.bw" for i in range(1, 6)] 211 | with ( 212 | TemporaryDirectory(suffix=".zarr") as out, 213 | ): 214 | ds = sd.read_bigwig( 215 | seq_name="seq", 216 | cov_name="cov", 217 | out=out, 218 | fasta=fasta, 219 | bigwigs=bws, 220 | samples=[f"simulated{i}.bw" for i in range(1, 6)], 221 | bed=fixed_bed, 222 | batch_size=50, 223 | fixed_length=20, 224 | overwrite=True 225 | ) 226 | bed = pd.read_csv(fixed_bed, sep="\t", header=None) 227 | zarr.consolidate_metadata(out) 228 | seqs = [''.join(row.astype(str)) for row in ds["seq"].values] 229 | fasta = read_fasta(fasta) 230 | true_seqs = [fasta[chrom][start:end] for chrom, start, end in bed.values] 231 | np.testing.assert_array_equal(seqs, true_seqs) 232 | for i in range(1, 6): 233 | cov = ds.sel(cov_sample=f"simulated{i}.bw").cov.values 234 | for j in range(len(cov)): 235 | np.testing.assert_array_equal(cov[j], list(fixed_coverage[f"simulated{i}.bw"].values())[j]) 236 | -------------------------------------------------------------------------------- /tests/readers/test_flat_fasta.py: -------------------------------------------------------------------------------- 1 | from tempfile import TemporaryDirectory 2 | 3 | import numpy as np 4 | import polars as pl 5 | import pandas as pd 6 | import seqpro as sp 7 | import xarray as xr 8 | from pytest import fixture 9 | 10 | import seqdata as sd 11 | import zarr 12 | 13 | 14 | def read_fasta(file_path): 15 | """ 16 | Reads a FASTA file and returns a dictionary of sequences. 17 | 18 | Parameters: 19 | file_path (str): Path to the FASTA file. 20 | 21 | Returns: 22 | dict: A dictionary where keys are sequence IDs and values are sequences. 
23 | """ 24 | sequences = {} 25 | with open(file_path, 'r') as file: 26 | sequence_id = None 27 | sequence_lines = [] 28 | for line in file: 29 | line = line.strip() 30 | if line.startswith(">"): 31 | # Save the previous sequence (if any) before starting a new one 32 | if sequence_id: 33 | sequences[sequence_id] = ''.join(sequence_lines) 34 | # Start a new sequence 35 | sequence_id = line[1:] # Remove the '>' 36 | sequence_lines = [] 37 | else: 38 | # Append sequence lines 39 | sequence_lines.append(line) 40 | # Save the last sequence 41 | if sequence_id: 42 | sequences[sequence_id] = ''.join(sequence_lines) 43 | return sequences 44 | 45 | 46 | @fixture 47 | def variable_fasta(): 48 | variable_fasta = "tests/data/variable.fa" 49 | return read_fasta(variable_fasta) 50 | 51 | 52 | @fixture 53 | def fixed_fasta(): 54 | fixed_fasta = "tests/data/fixed.fa" 55 | return read_fasta(fixed_fasta) 56 | 57 | @fixture 58 | def variable_reader(): 59 | variable_fasta = "tests/data/variable.fa" 60 | variable_reader = sd.FlatFASTA( 61 | name="variable_seq", 62 | fasta=variable_fasta, 63 | batch_size=50 64 | ) 65 | return variable_reader 66 | 67 | 68 | @fixture 69 | def fixed_reader(): 70 | fixed_fasta = "tests/data/fixed.fa" 71 | fixed_reader = sd.FlatFASTA( 72 | name="fixed_seq", 73 | fasta=fixed_fasta, 74 | batch_size=50 75 | ) 76 | return fixed_reader 77 | 78 | 79 | def test_variable_write( 80 | variable_reader, 81 | variable_fasta, 82 | ): 83 | with ( 84 | TemporaryDirectory(suffix=".zarr") as out, 85 | ): 86 | variable_reader._write( 87 | out=out, 88 | fixed_length=False, 89 | sequence_dim="_sequence", 90 | overwrite=True, 91 | ) 92 | zarr.consolidate_metadata(out) 93 | ds = sd.open_zarr(out) 94 | seqs = ds["variable_seq"].values.astype(str) 95 | np.testing.assert_array_equal(seqs, list(variable_fasta.values())) 96 | 97 | 98 | def test_fixed_write( 99 | fixed_reader, 100 | fixed_fasta, 101 | ): 102 | with ( 103 | TemporaryDirectory(suffix=".zarr") as out, 104 | ): 105 | fixed_reader._write( 106 | out=out, 107 | fixed_length=20, 108 | sequence_dim="_sequence", 109 | length_dim="_length", 110 | overwrite=True, 111 | ) 112 | zarr.consolidate_metadata(out) 113 | ds = sd.open_zarr(out) 114 | seqs = [''.join(row.astype(str)) for row in ds["fixed_seq"].values] 115 | np.testing.assert_array_equal(seqs, list(fixed_fasta.values())) 116 | 117 | 118 | def test_from_flat_files( 119 | variable_reader, 120 | fixed_reader, 121 | variable_fasta, 122 | fixed_fasta, 123 | ): 124 | with ( 125 | TemporaryDirectory(suffix=".zarr") as out, 126 | ): 127 | ds = sd.from_flat_files( 128 | variable_reader, 129 | fixed_reader, 130 | path=out, 131 | fixed_length=False, 132 | sequence_dim="_sequence", 133 | overwrite=True 134 | ) 135 | 136 | variable_seqs = ds["variable_seq"].values.astype(str) 137 | fixed_seqs = ds["fixed_seq"].values.astype(str) 138 | np.testing.assert_array_equal(variable_seqs, list(variable_fasta.values())) 139 | np.testing.assert_array_equal(fixed_seqs, list(fixed_fasta.values())) 140 | 141 | 142 | def test_read_flat_fasta( 143 | variable_fasta, 144 | ): 145 | variable_fasta = "tests/data/variable.fa" 146 | with ( 147 | TemporaryDirectory(suffix=".zarr") as out, 148 | ): 149 | ds = sd.read_flat_fasta( 150 | fasta=variable_fasta, 151 | out=out, 152 | name="seq", 153 | fixed_length=False, 154 | batch_size=50, 155 | overwrite=True 156 | ) 157 | seqs = ds["seq"].values.astype(str) 158 | np.testing.assert_array_equal(seqs, list(read_fasta(variable_fasta).values())) 159 | 
-------------------------------------------------------------------------------- /tests/readers/test_genome_fasta.py: -------------------------------------------------------------------------------- 1 | from tempfile import TemporaryDirectory 2 | 3 | import numpy as np 4 | import polars as pl 5 | import pandas as pd 6 | import seqpro as sp 7 | import xarray as xr 8 | from pytest import fixture 9 | 10 | import seqdata as sd 11 | import zarr 12 | 13 | 14 | def read_fasta(file_path): 15 | """ 16 | Reads a FASTA file and returns a dictionary of sequences. 17 | 18 | Parameters: 19 | file_path (str): Path to the FASTA file. 20 | 21 | Returns: 22 | dict: A dictionary where keys are sequence IDs and values are sequences. 23 | """ 24 | sequences = {} 25 | with open(file_path, 'r') as file: 26 | sequence_id = None 27 | sequence_lines = [] 28 | for line in file: 29 | line = line.strip() 30 | if line.startswith(">"): 31 | # Save the previous sequence (if any) before starting a new one 32 | if sequence_id: 33 | sequences[sequence_id] = ''.join(sequence_lines) 34 | # Start a new sequence 35 | sequence_id = line[1:] # Remove the '>' 36 | sequence_lines = [] 37 | else: 38 | # Append sequence lines 39 | sequence_lines.append(line) 40 | # Save the last sequence 41 | if sequence_id: 42 | sequences[sequence_id] = ''.join(sequence_lines) 43 | return sequences 44 | 45 | 46 | @fixture 47 | def fasta(): 48 | fasta = "tests/data/variable.fa" 49 | return read_fasta(fasta) 50 | 51 | 52 | @fixture 53 | def fasta_reader(): 54 | fasta = "tests/data/variable.fa" 55 | reader = sd.GenomeFASTA( 56 | name="seq", 57 | fasta=fasta, 58 | batch_size=50 59 | ) 60 | return reader 61 | 62 | 63 | @fixture 64 | def variable_bed(): 65 | variable_bed = "tests/data/variable.bed" 66 | variable_bed = pd.read_csv(variable_bed, sep="\t", header=None, names=["chrom", "chromStart", "chromEnd"]) 67 | return variable_bed 68 | 69 | 70 | @fixture 71 | def fixed_bed(): 72 | fixed_bed = "tests/data/fixed.bed" 73 | fixed_bed = pd.read_csv(fixed_bed, sep="\t", header=None, names=["chrom", "chromStart", "chromEnd"]) 74 | return fixed_bed 75 | 76 | 77 | def test_variable_write( 78 | fasta_reader, 79 | fasta, 80 | variable_bed, 81 | ): 82 | with ( 83 | TemporaryDirectory(suffix=".zarr") as out, 84 | ): 85 | true_seqs = [fasta[chrom][start:end] for chrom, start, end in variable_bed.values] 86 | variable_bed["strand"] = "+" 87 | fasta_reader._write( 88 | out=out, 89 | bed=variable_bed, 90 | fixed_length=False, 91 | sequence_dim="_sequence", 92 | overwrite=True 93 | ) 94 | zarr.consolidate_metadata(out) 95 | ds = sd.open_zarr(out) 96 | seqs = ds["seq"].values.astype(str) 97 | np.testing.assert_array_equal(seqs, true_seqs) 98 | 99 | 100 | def test_fixed_write( 101 | fasta_reader, 102 | fasta, 103 | fixed_bed, 104 | ): 105 | with ( 106 | TemporaryDirectory(suffix=".zarr") as out, 107 | ): 108 | true_seqs = [fasta[chrom][start:end] for chrom, start, end in fixed_bed.values] 109 | fixed_bed["strand"] = "+" 110 | fasta_reader._write( 111 | out=out, 112 | bed=fixed_bed, 113 | fixed_length=20, 114 | sequence_dim="_sequence", 115 | length_dim="_length", 116 | overwrite=True 117 | ) 118 | zarr.consolidate_metadata(out) 119 | ds = sd.open_zarr(out) 120 | seqs = [''.join(row.astype(str)) for row in ds["seq"].values] 121 | np.testing.assert_array_equal(seqs, true_seqs) 122 | 123 | 124 | def test_from_region_files( 125 | fasta_reader, 126 | fasta, 127 | ): 128 | variable_bed = "tests/data/variable.bed" 129 | with ( 130 | TemporaryDirectory(suffix=".zarr") as out, 131 
| ): 132 | ds = sd.from_region_files( 133 | fasta_reader, 134 | bed=variable_bed, 135 | path=out, 136 | fixed_length=False, 137 | sequence_dim="_sequence", 138 | overwrite=True 139 | ) 140 | bed = pd.read_csv(variable_bed, sep="\t", header=None, names=["chrom", "chromStart", "chromEnd"]) 141 | true_seqs = [fasta[chrom][start:end] for chrom, start, end in bed.values] 142 | seqs = ds["seq"].values.astype(str) 143 | np.testing.assert_array_equal(seqs, true_seqs) 144 | 145 | 146 | def test_read_genome_fasta(): 147 | fasta = "tests/data/variable.fa" 148 | variable_bed = "tests/data/variable.bed" 149 | with ( 150 | TemporaryDirectory(suffix=".zarr") as out, 151 | ): 152 | ds = sd.read_genome_fasta( 153 | name="seq", 154 | out=out, 155 | fasta=fasta, 156 | bed=variable_bed, 157 | fixed_length=False, 158 | batch_size=50, 159 | overwrite=True 160 | ) 161 | 162 | bed = pd.read_csv(variable_bed, sep="\t", header=None, names=["chrom", "chromStart", "chromEnd"]) 163 | true_seqs = [read_fasta(fasta)[chrom][start:end] for chrom, start, end in bed.values] 164 | seqs = ds["seq"].values.astype(str) 165 | np.testing.assert_array_equal(seqs, true_seqs) 166 | -------------------------------------------------------------------------------- /tests/readers/test_table.py: -------------------------------------------------------------------------------- 1 | from tempfile import NamedTemporaryFile, TemporaryDirectory 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import polars as pl 6 | import seqpro as sp 7 | import xarray as xr 8 | from pytest_cases import fixture 9 | import zarr 10 | 11 | import seqdata as sd 12 | 13 | 14 | @fixture # type: ignore 15 | def gen_table(): 16 | """Dummy dataset with AA, CC, GG, TT sequences.""" 17 | 18 | ds = xr.Dataset( 19 | { 20 | "seq": xr.DataArray( 21 | sp.random_seqs((2, 5), sp.DNA, 0), dims=["_sequence", "_length"] 22 | ), 23 | "target": xr.DataArray([5, 11.2], dims=["_sequence"]), 24 | } 25 | ) 26 | return ds 27 | 28 | 29 | def write_table(csv: str, ds: xr.Dataset): 30 | ( 31 | ds.assign( 32 | seq=xr.DataArray( 33 | ds["seq"].values.view("S5").astype(str).squeeze(), dims=["_sequence"] 34 | ) 35 | ) 36 | .to_pandas() 37 | .reset_index(drop=True) 38 | .to_csv(csv, index=False) 39 | ) 40 | 41 | 42 | def test_write_table(gen_table): 43 | with NamedTemporaryFile(suffix=".csv") as csv: 44 | write_table(csv.name, gen_table) 45 | 46 | df = pl.read_csv(csv.name) 47 | for name, da in gen_table.items(): 48 | df_data = df[name].to_numpy() 49 | if da.dtype.char == "S": 50 | df_data = df_data.astype("S")[..., None].view("S1") 51 | np.testing.assert_array_equal(da.values, df_data) 52 | 53 | 54 | def test_table(gen_table): 55 | with ( 56 | NamedTemporaryFile(suffix=".csv") as csv, 57 | TemporaryDirectory(suffix=".zarr") as out, 58 | ): 59 | write_table(csv.name, gen_table) 60 | 61 | ds = sd.from_flat_files( 62 | sd.Table("seq", csv.name, "seq", batch_size=1), 63 | path=out, 64 | fixed_length=True, 65 | ) 66 | 67 | for name, da in gen_table.items(): 68 | np.testing.assert_array_equal(da, ds[name]) 69 | 70 | 71 | @fixture 72 | def variable_table(): 73 | variable_tsv = "tests/data/variable.tsv" 74 | variable_table = pd.read_csv(variable_tsv, sep="\t") 75 | return variable_table 76 | 77 | 78 | @fixture 79 | def variable_reader(): 80 | variable_tsv = "tests/data/variable.tsv" 81 | variable_reader = sd.Table( 82 | name="variable_seq", 83 | tables=variable_tsv, 84 | seq_col="seq", 85 | batch_size=50 86 | ) 87 | return variable_reader 88 | 89 | @fixture 90 | def fixed_table(): 91 | fixed_tsv 
= "tests/data/fixed.tsv" 92 | fixed_table = pd.read_csv(fixed_tsv, sep="\t") 93 | return fixed_table 94 | 95 | 96 | @fixture 97 | def fixed_reader(): 98 | fixed_tsv = "tests/data/fixed.tsv" 99 | fixed_reader = sd.Table( 100 | name="fixed_seq", 101 | tables=fixed_tsv, 102 | seq_col="seq", 103 | batch_size=50 104 | ) 105 | return fixed_reader 106 | 107 | 108 | @fixture 109 | def combo_table(): 110 | tsvs = ["tests/data/variable.tsv", "tests/data/fixed.tsv"] 111 | combo_table = pd.concat([pd.read_csv(tsv, sep="\t") for tsv in tsvs]) 112 | return combo_table 113 | 114 | 115 | @fixture 116 | def combo_reader(): 117 | tsvs = ["tests/data/variable.tsv", "tests/data/fixed.tsv"] 118 | combo_reader = sd.Table( 119 | name="seq", 120 | tables=tsvs, 121 | seq_col="seq", 122 | batch_size=50 123 | ) 124 | return combo_reader 125 | 126 | 127 | def test_variable_write( 128 | variable_reader, 129 | variable_table, 130 | ): 131 | with ( 132 | TemporaryDirectory(suffix=".zarr") as out, 133 | ): 134 | variable_reader._write( 135 | out=out, 136 | fixed_length=False, 137 | sequence_dim="_sequence", 138 | overwrite=True 139 | ) 140 | zarr.consolidate_metadata(out) 141 | ds = sd.open_zarr(out) 142 | seqs = ds["variable_seq"].values.astype(str) 143 | targets = ds["target"].values 144 | np.testing.assert_array_equal(seqs, variable_table["seq"]) 145 | np.testing.assert_almost_equal(targets, variable_table["target"]) 146 | 147 | 148 | def test_fixed_write( 149 | fixed_reader, 150 | fixed_table 151 | ): 152 | with ( 153 | TemporaryDirectory(suffix=".zarr") as out, 154 | ): 155 | fixed_reader._write( 156 | out=out, 157 | fixed_length=20, 158 | sequence_dim="_sequence", 159 | length_dim="_length", 160 | overwrite=True 161 | ) 162 | zarr.consolidate_metadata(out) 163 | ds = sd.open_zarr(out) 164 | seqs = [''.join(row.astype(str)) for row in ds["fixed_seq"].values] 165 | targets = ds["target"].values 166 | np.testing.assert_array_equal(seqs, fixed_table["seq"]) 167 | np.testing.assert_almost_equal(targets, fixed_table["target"]) 168 | 169 | 170 | def test_combo_write( 171 | combo_reader, 172 | combo_table 173 | ): 174 | with ( 175 | TemporaryDirectory(suffix=".zarr") as out, 176 | ): 177 | combo_reader._write( 178 | out=out, 179 | fixed_length=False, 180 | sequence_dim="_sequence", 181 | overwrite=True 182 | ) 183 | zarr.consolidate_metadata(out) 184 | ds = sd.open_zarr(out) 185 | seqs = ds["seq"].values.astype(str) 186 | targets = ds["target"].values 187 | np.testing.assert_array_equal(seqs, combo_table["seq"]) 188 | np.testing.assert_almost_equal(targets, combo_table["target"]) 189 | 190 | 191 | def test_from_flat_files( 192 | variable_reader, 193 | fixed_reader, 194 | variable_table, 195 | fixed_table, 196 | ): 197 | with ( 198 | TemporaryDirectory(suffix=".zarr") as out, 199 | ): 200 | ds = sd.from_flat_files( 201 | variable_reader, 202 | fixed_reader, 203 | path=out, 204 | fixed_length=False, 205 | sequence_dim="_sequence", 206 | overwrite=True 207 | ) 208 | 209 | variable_seqs = ds["variable_seq"].values.astype(str) 210 | fixed_seqs = ds["fixed_seq"].values.astype(str) 211 | np.testing.assert_array_equal(variable_seqs, variable_table["seq"]) 212 | np.testing.assert_array_equal(fixed_seqs, fixed_table["seq"]) 213 | 214 | 215 | def test_read_table( 216 | combo_table, 217 | ): 218 | tsvs = ["tests/data/variable.tsv", "tests/data/fixed.tsv"] 219 | with ( 220 | TemporaryDirectory(suffix=".zarr") as out, 221 | ): 222 | ds = sd.read_table( 223 | tables=tsvs, 224 | seq_col="seq", 225 | name="seq", 226 | out=out, 227 | 
fixed_length=False, 228 | batch_size=50, 229 | overwrite=True 230 | ) 231 | seqs = ds["seq"].values.astype(str) 232 | targets = ds["target"].values 233 | np.testing.assert_array_equal(seqs, combo_table["seq"]) 234 | np.testing.assert_almost_equal(targets, combo_table["target"]) 235 | -------------------------------------------------------------------------------- /tests/test_bed_ops.py: -------------------------------------------------------------------------------- 1 | from tempfile import NamedTemporaryFile 2 | 3 | import polars as pl 4 | import polars.testing as pl_testing 5 | from pytest_cases import parametrize_with_cases 6 | 7 | import seqdata as sd 8 | 9 | 10 | def bed_bed3(): 11 | bed = pl.DataFrame( 12 | { 13 | "chrom": ["chr1", "chr1", "chr1"], 14 | "chromStart": [1, 2, 3], 15 | "chromEnd": [2, 3, 4], 16 | } 17 | ) 18 | return bed 19 | 20 | 21 | def bed_bed4(): 22 | bed = pl.DataFrame( 23 | { 24 | "chrom": ["chr1", "chr1", "chr1"], 25 | "chromStart": [1, 2, 3], 26 | "chromEnd": [2, 3, 4], 27 | "name": ["a", "b", "c"], 28 | } 29 | ) 30 | return bed 31 | 32 | 33 | def bed_bed5(): 34 | bed = pl.DataFrame( 35 | { 36 | "chrom": ["chr1", "chr1", "chr1"], 37 | "chromStart": [1, 2, 3], 38 | "chromEnd": [2, 3, 4], 39 | "name": ["a", "b", "c"], 40 | "score": [1.1, 2, 3], 41 | } 42 | ) 43 | return bed 44 | 45 | 46 | def narrowpeak_simple(): 47 | bed = pl.DataFrame( 48 | { 49 | "chrom": ["chr1", "chr1", "chr1"], 50 | "chromStart": [1, 2, 3], 51 | "chromEnd": [2, 3, 4], 52 | "name": ["a", "b", "c"], 53 | "score": [1.1, 2, 3], 54 | "strand": ["+", "-", None], 55 | "signalValue": [1.1, 2, 3], 56 | "pValue": [1.1, 2, 3], 57 | "qValue": [1.1, 2, 3], 58 | "peak": [1, 2, 3], 59 | } 60 | ) 61 | return bed 62 | 63 | 64 | def broadpeak_simple(): 65 | bed = pl.DataFrame( 66 | { 67 | "chrom": ["chr1", "chr1", "chr1"], 68 | "chromStart": [1, 2, 3], 69 | "chromEnd": [2, 3, 4], 70 | "name": ["a", "b", "c"], 71 | "score": [1.1, 2, 3], 72 | "strand": ["+", "-", None], 73 | "signalValue": [1.1, 2, 3], 74 | "pValue": [1.1, 2, 3], 75 | "qValue": [1.1, 2, 3], 76 | } 77 | ) 78 | return bed 79 | 80 | 81 | @parametrize_with_cases("bed", cases=".", prefix="bed_") 82 | def test_read_bed(bed: pl.DataFrame): 83 | with NamedTemporaryFile(suffix=".bed") as f: 84 | bed.write_csv(f.name, include_header=False, separator="\t", null_value=".") 85 | bed2 = sd.read_bedlike(f.name) 86 | pl_testing.assert_frame_equal(bed, bed2) 87 | 88 | 89 | @parametrize_with_cases("narrowpeak", cases=".", prefix="narrowpeak_") 90 | def test_read_narrowpeak(narrowpeak: pl.DataFrame): 91 | with NamedTemporaryFile(suffix=".narrowPeak") as f: 92 | narrowpeak.write_csv( 93 | f.name, include_header=False, separator="\t", null_value="." 94 | ) 95 | narrowpeak2 = sd.read_bedlike(f.name) 96 | pl_testing.assert_frame_equal(narrowpeak, narrowpeak2) 97 | 98 | 99 | @parametrize_with_cases("broadpeak", cases=".", prefix="broadpeak_") 100 | def test_read_broadpeak(broadpeak: pl.DataFrame): 101 | with NamedTemporaryFile(suffix=".broadPeak") as f: 102 | broadpeak.write_csv( 103 | f.name, include_header=False, separator="\t", null_value="." 
104 | ) 105 | broadpeak2 = sd.read_bedlike(f.name) 106 | pl_testing.assert_frame_equal(broadpeak, broadpeak2) 107 | -------------------------------------------------------------------------------- /tests/test_max_jitter.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML4GLand/SeqData/10afa85d0ec954b9d28955b072a55274c3c18dfd/tests/test_max_jitter.py -------------------------------------------------------------------------------- /tests/test_open_zarr.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML4GLand/SeqData/10afa85d0ec954b9d28955b072a55274c3c18dfd/tests/test_open_zarr.py -------------------------------------------------------------------------------- /tests/test_to_zarr.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import numpy as np 4 | import xarray as xr 5 | from pytest import fixture 6 | 7 | import seqdata as sd 8 | 9 | 10 | @fixture 11 | def sdata(): 12 | """Dummy dataset with AA, CC, GG, TT sequences.""" 13 | seqs = np.array([[b"A", "A"], [b"C", "C"], [b"G", "G"], [b"T", b"T"]]) 14 | return xr.Dataset( 15 | { 16 | "seqs": xr.DataArray(seqs, dims=["_sequence", "_length"]), 17 | } 18 | ) 19 | 20 | 21 | def test_to_zarr_non_uniform_chunks(sdata: xr.Dataset): 22 | # set chunks to violate write requirements 23 | # - uniform except last 24 | # - last <= in size than the rest 25 | sdata = sdata.chunk({"_sequence": (1, 3), "_length": -1}) 26 | 27 | with tempfile.TemporaryDirectory() as tmpdir: 28 | sd.to_zarr(sdata, tmpdir, mode="w") 29 | 30 | after = sd.open_zarr(tmpdir) 31 | # chunks should satisfy write requirements 32 | assert after.chunksizes == {"_sequence": (3, 1), "_length": (2,)} 33 | -------------------------------------------------------------------------------- /tests/torch/test_torch.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import numpy as np 4 | import pytest 5 | import seqpro as sp 6 | import torch 7 | import xarray as xr 8 | from numpy.typing import NDArray 9 | from pytest import fixture 10 | 11 | import seqdata as sd 12 | 13 | 14 | @fixture 15 | def dummy_dataset(): 16 | """Dummy dataset with 3 sequences of length 5.""" 17 | seqs = sp.random_seqs((3, 5), sp.DNA, 0) 18 | return xr.Dataset( 19 | { 20 | "seqs": xr.DataArray(seqs, dims=["_sequence", "_length"]), 21 | } 22 | ) 23 | 24 | 25 | def test_no_transforms(dummy_dataset: xr.Dataset): 26 | dl = sd.get_torch_dataloader( 27 | dummy_dataset, sample_dims="_sequence", variables=["seqs"] 28 | ) 29 | # this should raise a TypeError: can't convert np.bytes_ to Tensor 30 | with pytest.raises(TypeError): 31 | next(iter(dl)) 32 | 33 | 34 | def test_ohe_transform(dummy_dataset: xr.Dataset): 35 | def transform(batch: Dict[str, NDArray]): 36 | batch["seqs"] = sp.DNA.ohe(batch["seqs"]) 37 | return batch 38 | 39 | dl = sd.get_torch_dataloader( 40 | dummy_dataset, 41 | sample_dims="_sequence", 42 | variables=["seqs"], 43 | transform=transform, 44 | batch_size=2, 45 | ) 46 | batch: Dict[str, torch.Tensor] = next(iter(dl)) 47 | seqs: NDArray = batch["seqs"].numpy() 48 | ds_seqs = sp.DNA.ohe(dummy_dataset["seqs"].values) 49 | np.testing.assert_array_equal(seqs, ds_seqs[:2]) 50 | 51 | 52 | @fixture 53 | def multi_var_dataset(): 54 | """Dataset with multiple variables.""" 55 | seqs = sp.random_seqs((3, 5), sp.DNA, 0) 56 | scores = np.random.rand(3, 5) 57 
| return xr.Dataset( 58 | { 59 | "seqs": xr.DataArray(seqs, dims=["_sequence", "_length"]), 60 | "scores": xr.DataArray(scores, dims=["_sequence", "_length"]), 61 | } 62 | ) 63 | 64 | 65 | def test_multi_variable(multi_var_dataset: xr.Dataset): 66 | dl = sd.get_torch_dataloader( 67 | multi_var_dataset, 68 | sample_dims="_sequence", 69 | variables=["seqs", "scores"], 70 | batch_size=2, 71 | ) 72 | batch = next(iter(dl)) 73 | assert "seqs" in batch, "seqs variable missing in batch" 74 | assert "scores" in batch, "scores variable missing in batch" 75 | assert batch["seqs"].shape == (2, 5), "Shape mismatch for seqs" 76 | assert batch["scores"].shape == (2, 5), "Shape mismatch for scores" 77 | 78 | 79 | def test_shuffling(dummy_dataset: xr.Dataset): 80 | dl_shuffled = sd.get_torch_dataloader( 81 | dummy_dataset, 82 | sample_dims="_sequence", 83 | variables=["seqs"], 84 | shuffle=True, 85 | batch_size=3, 86 | seed=42, # Ensure deterministic shuffling 87 | ) 88 | dl_unshuffled = sd.get_torch_dataloader( 89 | dummy_dataset, 90 | sample_dims="_sequence", 91 | variables=["seqs"], 92 | shuffle=False, 93 | batch_size=3, 94 | ) 95 | 96 | batch_shuffled = next(iter(dl_shuffled)) 97 | batch_unshuffled = next(iter(dl_unshuffled)) 98 | assert not np.array_equal( 99 | batch_shuffled["seqs"].numpy(), batch_unshuffled["seqs"].numpy() 100 | ), "Shuffled data should not match unshuffled data" 101 | 102 | 103 | def test_return_tuples(dummy_dataset: xr.Dataset): 104 | dl = sd.get_torch_dataloader( 105 | dummy_dataset, 106 | sample_dims="_sequence", 107 | variables=["seqs"], 108 | return_tuples=True, 109 | batch_size=2, 110 | ) 111 | batch = next(iter(dl)) 112 | assert isinstance(batch, tuple), "Batch should be a tuple" 113 | assert len(batch) == 1, "Tuple should contain one item" 114 | assert isinstance(batch[0], torch.Tensor), "Tuple item should be a tensor" 115 | --------------------------------------------------------------------------------
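For quick reference, the sketch below mirrors the dataloading pattern exercised in tests/torch/test_torch.py: random DNA sequences, a one-hot transform, and sd.get_torch_dataloader. It is a minimal illustration rather than repository code; the dataset sizes, the `one_hot` helper name, and the batch size are arbitrary, while the sp.random_seqs, sp.DNA.ohe, and sd.get_torch_dataloader calls follow the usage shown in the tests above.

from typing import Dict

import seqpro as sp
import xarray as xr
from numpy.typing import NDArray

import seqdata as sd

# Small in-memory dataset: 8 random DNA sequences of length 10 (illustrative sizes).
ds = xr.Dataset(
    {
        "seqs": xr.DataArray(
            sp.random_seqs((8, 10), sp.DNA, 0), dims=["_sequence", "_length"]
        )
    }
)


def one_hot(batch: Dict[str, NDArray]) -> Dict[str, NDArray]:
    # One-hot encode the raw byte sequences so they can be collated into tensors.
    batch["seqs"] = sp.DNA.ohe(batch["seqs"])
    return batch


dl = sd.get_torch_dataloader(
    ds,
    sample_dims="_sequence",
    variables=["seqs"],
    transform=one_hot,
    batch_size=4,
)
batch = next(iter(dl))  # batch["seqs"]: one-hot encoded tensor for 4 sequences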