├── .coveragerc ├── .github └── workflows │ ├── publish-pypi.yml │ └── run-tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── AUTHORS.md ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE.txt ├── README.md ├── assets └── fig1.png ├── docs ├── Makefile ├── _static │ └── .gitignore ├── authors.md ├── changelog.md ├── conf.py ├── contributing.md ├── index.md ├── license.md ├── readme.md └── requirements.txt ├── pyproject.toml ├── scripts ├── finetune.py └── predict_genes.py ├── setup.cfg ├── setup.py ├── src └── decima │ ├── __init__.py │ ├── decima_model.py │ ├── evaluate.py │ ├── interpret.py │ ├── lightning.py │ ├── loss.py │ ├── metrics.py │ ├── preprocess.py │ ├── read_hdf5.py │ ├── variant.py │ ├── visualize.py │ └── write_hdf5.py ├── tests ├── conftest.py └── test_package.py ├── tox.ini └── tutorials ├── tutorial.ipynb └── variants.tsv /.coveragerc: -------------------------------------------------------------------------------- 1 | # .coveragerc to control coverage.py 2 | [run] 3 | branch = True 4 | source = decima 5 | # omit = bad_file.py 6 | 7 | [paths] 8 | source = 9 | src/ 10 | */site-packages/ 11 | 12 | [report] 13 | # Regexes for lines to exclude from consideration 14 | exclude_lines = 15 | # Have to re-enable the standard pragma 16 | pragma: no cover 17 | 18 | # Don't complain about missing debug-only code: 19 | def __repr__ 20 | if self\.debug 21 | 22 | # Don't complain if tests don't hit defensive assertion code: 23 | raise AssertionError 24 | raise NotImplementedError 25 | 26 | # Don't complain if non-runnable code isn't run: 27 | if 0: 28 | if __name__ == .__main__.: 29 | -------------------------------------------------------------------------------- /.github/workflows/publish-pypi.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Publish to PyPI 5 | 6 | on: 7 | push: 8 | tags: "*" 9 | 10 | jobs: 11 | build: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v4 16 | 17 | - name: Set up Python 3.11 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: 3.11 21 | 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install tox 26 | 27 | - name: Test with tox 28 | run: | 29 | tox 30 | 31 | - name: Build docs 32 | run: | 33 | tox -e docs 34 | 35 | - run: touch ./docs/_build/html/.nojekyll 36 | 37 | - name: GH Pages Deployment 38 | uses: JamesIves/github-pages-deploy-action@v4 39 | with: 40 | branch: gh-pages # The branch the action should deploy to. 
41 | folder: ./docs/_build/html 42 | clean: true # Automatically remove deleted files from the deploy branch 43 | 44 | - name: Build Project and Publish 45 | run: | 46 | python -m tox -e clean,build 47 | 48 | - name: Publish package 49 | uses: pypa/gh-action-pypi-publish@v1.12.2 50 | with: 51 | user: __token__ 52 | password: ${{ secrets.PYPI_API_TOKEN }} 53 | -------------------------------------------------------------------------------- /.github/workflows/run-tests.yml: -------------------------------------------------------------------------------- 1 | name: Run tests 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 14 | 15 | name: Python ${{ matrix.python-version }} 16 | steps: 17 | - uses: actions/checkout@v4 18 | 19 | - name: Setup Python 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | cache: "pip" 24 | 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install tox 29 | 30 | - name: Test with tox 31 | run: | 32 | tox 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Temporary and binary files 2 | *~ 3 | *.py[cod] 4 | *.so 5 | *.cfg 6 | !.isort.cfg 7 | !setup.cfg 8 | *.orig 9 | *.log 10 | *.pot 11 | __pycache__/* 12 | .cache/* 13 | .*.swp 14 | */.ipynb_checkpoints/* 15 | .DS_Store 16 | 17 | # Project files 18 | .ropeproject 19 | .project 20 | .pydevproject 21 | .settings 22 | .idea 23 | .vscode 24 | tags 25 | 26 | # Package files 27 | *.egg 28 | *.eggs/ 29 | .installed.cfg 30 | *.egg-info 31 | 32 | # Unittest and coverage 33 | htmlcov/* 34 | .coverage 35 | .coverage.* 36 | .tox 37 | junit*.xml 38 | coverage.xml 39 | .pytest_cache/ 40 | 41 | # Build and docs folder/files 42 | build/* 43 | dist/* 44 | sdist/* 45 | docs/api/* 46 | docs/_rst/* 47 | docs/_build/* 48 | cover/* 49 | MANIFEST 50 | 51 | # Per-project virtualenvs 52 | .venv*/ 53 | .conda*/ 54 | .python-version 55 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: '^docs/conf.py' 2 | 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v5.0.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: check-added-large-files 9 | - id: check-ast 10 | - id: check-json 11 | - id: check-merge-conflict 12 | - id: check-xml 13 | - id: check-yaml 14 | - id: debug-statements 15 | - id: end-of-file-fixer 16 | - id: requirements-txt-fixer 17 | - id: mixed-line-ending 18 | args: ['--fix=auto'] # replace 'auto' with 'lf' to enforce Linux/Mac line endings or 'crlf' for Windows 19 | 20 | - repo: https://github.com/astral-sh/ruff-pre-commit 21 | # Ruff version. 
22 | rev: v0.8.2 23 | hooks: 24 | - id: ruff 25 | args: [--fix, --exit-non-zero-on-fix] 26 | - id: ruff-format 27 | 28 | ## Check for misspells in documentation files: 29 | # - repo: https://github.com/codespell-project/codespell 30 | # rev: v2.2.5 31 | # hooks: 32 | # - id: codespell 33 | -------------------------------------------------------------------------------- /AUTHORS.md: -------------------------------------------------------------------------------- 1 | # Contributors 2 | 3 | * Avantika Lal 4 | * Alexander Karollus 5 | * Laura Gunsalus 6 | * Gokcen Eraslan 7 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## Version 0.1 (development) 4 | 5 | - Feature A added 6 | - FIX: nasty bug #1729 fixed 7 | - add your changes here! 8 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ```{todo} THIS IS SUPPOSED TO BE AN EXAMPLE. MODIFY IT ACCORDING TO YOUR NEEDS! 2 | 3 | The document assumes you are using a source repository service that promotes a 4 | contribution model similar to [GitHub's fork and pull request workflow]. 5 | While this is true for the majority of services (like GitHub, GitLab, 6 | BitBucket), it might not be the case for private repositories (e.g., when 7 | using Gerrit). 8 | 9 | Also notice that the code examples might refer to GitHub URLs or the text 10 | might use GitHub specific terminology (e.g., *Pull Request* instead of *Merge 11 | Request*). 12 | 13 | Please make sure to check the document with these assumptions in mind 14 | and update things accordingly. 15 | ``` 16 | 17 | ```{todo} Provide the correct links/replacements at the bottom of the document. 18 | ``` 19 | 20 | ```{todo} You might want to have a look at [PyScaffold's contributor's guide], 21 | 22 | especially if your project is open source. The text should be very similar to 23 | this template, but there are a few extra things that you might decide to 24 | also include, like mentioning labels of your issue tracker or automated 25 | releases. 26 | ``` 27 | 28 | # Contributing 29 | 30 | Welcome to the `decima` contributor's guide. 31 | 32 | This document focuses on getting any potential contributor familiarized with 33 | the development processes, but [other kinds of contributions] are also appreciated. 34 | 35 | If you are new to using [git] or have never collaborated on a project previously, 36 | please have a look at [contribution-guide.org]. Other resources are also 37 | listed in the excellent [guide created by FreeCodeCamp] [^contrib1]. 38 | 39 | Please note that all users and contributors are expected to be **open, 40 | considerate, reasonable, and respectful**. When in doubt, 41 | [Python Software Foundation's Code of Conduct] is a good reference in terms of 42 | behavior guidelines. 43 | 44 | ## Issue Reports 45 | 46 | If you experience bugs or general issues with `decima`, please have a look 47 | at the [issue tracker]. 48 | If you don't see anything useful there, please feel free to file an issue report. 49 | 50 | :::{tip} 51 | Please don't forget to include the closed issues in your search. 52 | Sometimes a solution was already reported, and the problem is considered 53 | **solved**.
54 | ::: 55 | 56 | New issue reports should include information about your programming environment 57 | (e.g., operating system, Python version) and steps to reproduce the problem. 58 | Please also try to simplify the reproduction steps to a very minimal example 59 | that still illustrates the problem you are facing. By removing other factors, 60 | you help us to identify the root cause of the issue. 61 | 62 | ## Documentation Improvements 63 | 64 | You can help improve the `decima` docs by making them more readable and coherent, or 65 | by adding missing information and correcting mistakes. 66 | 67 | The `decima` documentation uses [Sphinx] as its main documentation compiler. 68 | This means that the docs are kept in the same repository as the project code, and 69 | that any documentation update is done in the same way as a code contribution. 70 | 71 | ```{todo} Don't forget to mention which markup language you are using. 72 | 73 | e.g., [reStructuredText] or [CommonMark] with [MyST] extensions. 74 | ``` 75 | 76 | ```{todo} If your project is hosted on GitHub, you can also mention the following tip: 77 | 78 | :::{tip} 79 | Please notice that the [GitHub web interface] provides a quick way to 80 | propose changes to `decima`'s files. While this mechanism can 81 | be tricky for normal code contributions, it works perfectly fine for 82 | contributing to the docs, and can be quite handy. 83 | 84 | If you are interested in trying this method out, please navigate to 85 | the `docs` folder in the source [repository], find the file you 86 | would like to propose changes to and click on the little pencil icon at the 87 | top to open [GitHub's code editor]. Once you finish editing the file, 88 | please write a message in the form at the bottom of the page describing 89 | which changes you have made and what the motivations behind them are, and 90 | submit your proposal. 91 | ::: 92 | ``` 93 | 94 | When working on documentation changes on your local machine, you can 95 | compile them using [tox]: 96 | 97 | ``` 98 | tox -e docs 99 | ``` 100 | 101 | and use Python's built-in web server for a preview in your web browser 102 | (`http://localhost:8000`): 103 | 104 | ``` 105 | python3 -m http.server --directory 'docs/_build/html' 106 | ``` 107 | 108 | ## Code Contributions 109 | 110 | ```{todo} Please include a reference or explanation about the internals of the project. 111 | 112 | An architecture description, design principles or at least a summary of the 113 | main concepts will make it easy for potential contributors to get started 114 | quickly. 115 | ``` 116 | 117 | ### Submit an issue 118 | 119 | Before you work on any non-trivial code contribution, it's best to first create 120 | a report in the [issue tracker] to start a discussion on the subject. 121 | This often provides additional considerations and avoids unnecessary work. 122 | 123 | ### Create an environment 124 | 125 | Before you start coding, we recommend creating an isolated [virtual environment] 126 | to avoid any problems with your installed Python packages. 127 | This can easily be done via either [virtualenv]: 128 | 129 | ``` 130 | virtualenv <PATH TO VENV> 131 | source <PATH TO VENV>/bin/activate 132 | ``` 133 | 134 | or [Miniconda]: 135 | 136 | ``` 137 | conda create -n decima python=3 six virtualenv pytest pytest-cov 138 | conda activate decima 139 | ``` 140 | 141 | ### Clone the repository 142 | 143 | 1. Create a user account on GitHub if you do not already have one. 144 | 145 | 2.
Fork the project [repository]: click on the *Fork* button near the top of the 146 | page. This creates a copy of the code under your account on GitHub. 147 | 148 | 3. Clone this copy to your local disk: 149 | 150 | ``` 151 | git clone git@github.com:YourLogin/decima.git 152 | cd decima 153 | ``` 154 | 155 | 4. You should run: 156 | 157 | ``` 158 | pip install -U pip setuptools -e . 159 | ``` 160 | 161 | to be able to import the package under development in the Python REPL. 162 | 163 | ```{todo} if you are not using pre-commit, please remove the following item: 164 | ``` 165 | 166 | 5. Install [pre-commit]: 167 | 168 | ``` 169 | pip install pre-commit 170 | pre-commit install 171 | ``` 172 | 173 | `decima` comes with a lot of hooks configured to automatically help the 174 | developer check the code being written. 175 | 176 | ### Implement your changes 177 | 178 | 1. Create a branch to hold your changes: 179 | 180 | ``` 181 | git checkout -b my-feature 182 | ``` 183 | 184 | and start making changes. Never work on the main branch! 185 | 186 | 2. Start your work on this branch. Don't forget to add [docstrings] to new 187 | functions, modules and classes, especially if they are part of public APIs. 188 | 189 | 3. Add yourself to the list of contributors in `AUTHORS.md`. 190 | 191 | 4. When you’re done editing, do: 192 | 193 | ``` 194 | git add <MODIFIED FILES> 195 | git commit 196 | ``` 197 | 198 | to record your changes in [git]. 199 | 200 | ```{todo} if you are not using pre-commit, please remove the following item: 201 | ``` 202 | 203 | Please make sure to see the validation messages from [pre-commit] and fix 204 | any reported issues. 205 | This should automatically run the configured hooks (currently `ruff`, which replaces [flake8]/[black]) to check/fix the code style 206 | in a way that is compatible with the project. 207 | 208 | :::{important} 209 | Don't forget to add unit tests and documentation in case your 210 | contribution adds an additional feature and is not just a bugfix. 211 | 212 | Moreover, writing a [descriptive commit message] is highly recommended. 213 | In case of doubt, you can check the commit history with: 214 | 215 | ``` 216 | git log --graph --decorate --pretty=oneline --abbrev-commit --all 217 | ``` 218 | 219 | to look for recurring communication patterns. 220 | ::: 221 | 222 | 5. Please check that your changes don't break any unit tests with: 223 | 224 | ``` 225 | tox 226 | ``` 227 | 228 | (after having installed [tox] with `pip install tox` or `pipx`). 229 | 230 | You can also use [tox] to run several other pre-configured tasks in the 231 | repository. Try `tox -av` to see a list of the available checks. 232 | 233 | ### Submit your contribution 234 | 235 | 1. If everything works fine, push your local branch to the remote server with: 236 | 237 | ``` 238 | git push -u origin my-feature 239 | ``` 240 | 241 | 2. Go to the web page of your fork and click "Create pull request" 242 | to send your changes for review. 243 | 244 | ```{todo} if you are using GitHub, you can uncomment the following paragraph 245 | 246 | Find more detailed information in [creating a PR]. You might also want to open 247 | the PR as a draft first and mark it as ready for review after the feedback 248 | from the continuous integration (CI) system or any required fixes. 249 | 250 | ``` 251 | 252 | ### Troubleshooting 253 | 254 | The following tips can be used when facing problems building or testing the 255 | package: 256 | 257 | 1. Make sure to fetch all the tags from the upstream [repository].
258 | The command `git describe --abbrev=0 --tags` should return the version you 259 | are expecting. If you are trying to run CI scripts in a fork repository, 260 | make sure to push all the tags. 261 | You can also try to remove all the egg files or the complete egg folder, i.e., 262 | `.eggs`, as well as the `*.egg-info` folders in the `src` folder or 263 | potentially in the root of your project. 264 | 265 | 2. Sometimes [tox] misses out when new dependencies are added, especially to 266 | `setup.cfg` and `docs/requirements.txt`. If you find any problems with 267 | missing dependencies when running a command with [tox], try to recreate the 268 | `tox` environment using the `-r` flag. For example, instead of: 269 | 270 | ``` 271 | tox -e docs 272 | ``` 273 | 274 | Try running: 275 | 276 | ``` 277 | tox -r -e docs 278 | ``` 279 | 280 | 3. Make sure to have a reliable [tox] installation that uses the correct 281 | Python version (e.g., 3.9+). When in doubt you can run: 282 | 283 | ``` 284 | tox --version 285 | # OR 286 | which tox 287 | ``` 288 | 289 | If you have trouble and are seeing weird errors upon running [tox], you can 290 | also try to create a dedicated [virtual environment] with a [tox] binary 291 | freshly installed. For example: 292 | 293 | ``` 294 | virtualenv .venv 295 | source .venv/bin/activate 296 | .venv/bin/pip install tox 297 | .venv/bin/tox -e all 298 | ``` 299 | 300 | 4. [Pytest can drop you] in an interactive session in case an error occurs. 301 | In order to do that you need to pass a `--pdb` option (for example by 302 | running `tox -- -k <NAME OF THE FAILING TEST> --pdb`). 303 | You can also set up breakpoints manually instead of using the `--pdb` option. 304 | 305 | ## Maintainer tasks 306 | 307 | ### Releases 308 | 309 | ```{todo} This section assumes you are using PyPI to publicly release your package. 310 | 311 | If instead you are using a different/private package index, please update 312 | the instructions accordingly. 313 | ``` 314 | 315 | If you are part of the group of maintainers and have correct user permissions 316 | on [PyPI], the following steps can be used to release a new version of 317 | `decima` (a consolidated sketch of the commands is shown after this list): 318 | 319 | 1. Make sure all unit tests are successful. 320 | 2. Tag the current commit on the main branch with a release tag, e.g., `v1.2.3`. 321 | 3. Push the new tag to the upstream [repository], 322 | e.g., `git push upstream v1.2.3` 323 | 4. Clean up the `dist` and `build` folders with `tox -e clean` 324 | (or `rm -rf dist build`) 325 | to avoid confusion with old builds and Sphinx docs. 326 | 5. Run `tox -e build` and check that the files in `dist` have 327 | the correct version (no `.dirty` or [git] hash) according to the [git] tag. 328 | Also check the sizes of the distributions; if they are too big (e.g., > 329 | 500KB), unwanted clutter may have been accidentally included. 330 | 6. Run `tox -e publish -- --repository pypi` and check that everything was 331 | uploaded to [PyPI] correctly. 332 | 333 | [^contrib1]: Even though these resources focus on open source projects and 334 | communities, the general ideas behind collaborating with other developers 335 | to collectively create software can be applied to all sorts 336 | of environments, including private companies and proprietary code bases.
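As a quick recap of the release steps above, the whole flow boils down to roughly the following commands. This is only a sketch: it assumes you have maintainer permissions on [PyPI], an `upstream` remote pointing at the main [repository], and credentials (e.g., a PyPI API token) configured for `tox -e publish`.

```
git tag v1.2.3                       # tag the release commit on the main branch
git push upstream v1.2.3             # push the tag to the upstream repository
tox -e clean                         # remove old builds and Sphinx docs
tox -e build                         # build the sdist/wheel into dist/
tox -e publish -- --repository pypi  # upload the checked files to PyPI
```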
337 | 338 | 339 | [black]: https://pypi.org/project/black/ 340 | [commonmark]: https://commonmark.org/ 341 | [contribution-guide.org]: http://www.contribution-guide.org/ 342 | [creating a pr]: https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request 343 | [descriptive commit message]: https://chris.beams.io/posts/git-commit 344 | [docstrings]: https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html 345 | [first-contributions tutorial]: https://github.com/firstcontributions/first-contributions 346 | [flake8]: https://flake8.pycqa.org/en/stable/ 347 | [git]: https://git-scm.com 348 | [github web interface]: https://docs.github.com/en/github/managing-files-in-a-repository/managing-files-on-github/editing-files-in-your-repository 349 | [github's code editor]: https://docs.github.com/en/github/managing-files-in-a-repository/managing-files-on-github/editing-files-in-your-repository 350 | [github's fork and pull request workflow]: https://guides.github.com/activities/forking/ 351 | [guide created by freecodecamp]: https://github.com/freecodecamp/how-to-contribute-to-open-source 352 | [miniconda]: https://docs.conda.io/en/latest/miniconda.html 353 | [myst]: https://myst-parser.readthedocs.io/en/latest/syntax/syntax.html 354 | [other kinds of contributions]: https://opensource.guide/how-to-contribute 355 | [pre-commit]: https://pre-commit.com/ 356 | [pypi]: https://pypi.org/ 357 | [pyscaffold's contributor's guide]: https://pyscaffold.org/en/stable/contributing.html 358 | [pytest can drop you]: https://docs.pytest.org/en/stable/usage.html#dropping-to-pdb-python-debugger-at-the-start-of-a-test 359 | [python software foundation's code of conduct]: https://www.python.org/psf/conduct/ 360 | [restructuredtext]: https://www.sphinx-doc.org/en/master/usage/restructuredtext/ 361 | [sphinx]: https://www.sphinx-doc.org/en/master/ 362 | [tox]: https://tox.readthedocs.io/en/stable/ 363 | [virtual environment]: https://realpython.com/python-virtual-environments-a-primer/ 364 | [virtualenv]: https://virtualenv.pypa.io/en/stable/ 365 | 366 | 367 | ```{todo} Please review and change the following definitions: 368 | ``` 369 | 370 | [repository]: https://github.com/genentech/decima 371 | [issue tracker]: https://github.com/genentech/decima/issues 372 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright Genentech, Inc., 2024. 2 | 3 | Licensed under the Genentech Non-Commercial Software License Version 1.0, September 2022 (the "License"); you may not use the associated software except in compliance with the License. 4 | 5 | Note the Genentech Non-Commercial Software License reuses most of the text of the Apache License, Version 2.0, except that it includes a restriction against commercial use. The full text of the Apache License, Version 2.0 can be found at https://www.apache.org/licenses/LICENSE-2.0. 6 | 7 | Genentech Non-Commercial Software License Version 1.0, September 2022 8 | 9 | Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. 10 | 11 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 12 | 13 | 1. Definitions. 14 | 15 | "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
16 | 17 | "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 18 | 19 | "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 20 | 21 | "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 22 | 23 | "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 24 | 25 | "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 26 | 27 | "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work. 28 | 29 | "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 30 | 31 | "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 32 | 33 | "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 
34 | 35 | "Commercial Use" shall mean (i) the sale of a product embodying or incorporating or generated through the use of the Work (ii) the provision to a third party of a service for commercial purposes using the Work or (iii) granting commercial licenses and/or assigning such commercial rights to the Work to a third party, (iv) use of the Work in the discovery, research, pre-clinical or clinical development, or related manufacturing of diagnostic, prognostic, prophylactic, or therapeutic treatments, (v) use of the Work to identify patient characteristics (e.g., genomic sequences or phenotypic traits) to monitor, target, or use in or with such therapies in the course of such discovery, research as well as preclinical or clinical development activities. 36 | 37 | 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form for any legal purpose that is not a Commercial Use. 38 | 39 | 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, import, and otherwise transfer the Work for any legal purpose that is not a Commercial Use, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 40 | 41 | 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: 42 | 43 | You must give any other recipients of the Work or Derivative Works a copy of this License; and 44 | You must cause any modified files to carry prominent notices stating that You changed the files; and 45 | You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and 46 | If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. 
The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. 47 | 48 | You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 49 | 50 | 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 51 | 52 | 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 53 | 54 | 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 55 | 56 | 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 57 | 58 | 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 
59 | 60 | END OF TERMS AND CONDITIONS -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![PyPI-Server](https://img.shields.io/pypi/v/decima.svg)](https://pypi.org/project/decima/) 2 | ![Unit tests](https://github.com/genentech/decima/actions/workflows/run-tests.yml/badge.svg) 3 | [![DOI](https://zenodo.org/badge/870361048.svg)](https://doi.org/10.5281/zenodo.15319897) 4 | 5 | # Decima 6 | 7 | ## Introduction 8 | Decima is a Python library to train sequence models on single-cell RNA-seq data. 9 | 10 | ![Figure](assets/fig1.png) 11 | 12 | ## Weights 13 | Weights of the trained Decima models (4 replicates) are now available at https://zenodo.org/records/15092691. See the tutorial for how to load and use these. 14 | 15 | ## Preprint 16 | Please cite https://www.biorxiv.org/content/10.1101/2024.10.09.617507v3. Also see https://github.com/Genentech/decima-applications for all the code used to train and apply models in this preprint. 17 | 18 | ## Installation 19 | 20 | Install the package from PyPI, 21 | 22 | ```sh 23 | pip install decima 24 | ``` 25 | 26 | Or if you want to be on the cutting edge, 27 | 28 | ```sh 29 | pip install git+https://github.com/genentech/decima.git@main 30 | ``` 31 | 32 | 33 | 34 | ## Note 35 | 36 | This project has been set up using [BiocSetup](https://github.com/biocpy/biocsetup) 37 | and [PyScaffold](https://pyscaffold.org/). 38 | -------------------------------------------------------------------------------- /assets/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/decima/24c22bb3551f02d57dd3c6c58dc6969190fc7fd2/assets/fig1.png -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | AUTODOCDIR = api 11 | 12 | # User-friendly check for sphinx-build 13 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $?), 1) 14 | $(error "The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from https://sphinx-doc.org/") 15 | endif 16 | 17 | .PHONY: help clean Makefile 18 | 19 | # Put it first so that "make" without argument is like "make help". 20 | help: 21 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 22 | 23 | clean: 24 | rm -rf $(BUILDDIR)/* $(AUTODOCDIR) 25 | 26 | # Catch-all target: route all unknown targets to Sphinx using the new 27 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
28 | %: Makefile 29 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 30 | -------------------------------------------------------------------------------- /docs/_static/.gitignore: -------------------------------------------------------------------------------- 1 | # Empty directory 2 | -------------------------------------------------------------------------------- /docs/authors.md: -------------------------------------------------------------------------------- 1 | ```{include} ../AUTHORS.md 2 | :relative-docs: docs/ 3 | :relative-images: 4 | ``` 5 | -------------------------------------------------------------------------------- /docs/changelog.md: -------------------------------------------------------------------------------- 1 | ```{include} ../CHANGELOG.md 2 | :relative-docs: docs/ 3 | :relative-images: 4 | ``` 5 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # This file is execfile()d with the current directory set to its containing dir. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | # 7 | # All configuration values have a default; values that are commented out 8 | # serve to show the default. 9 | 10 | import os 11 | import sys 12 | import shutil 13 | 14 | # -- Path setup -------------------------------------------------------------- 15 | 16 | __location__ = os.path.dirname(__file__) 17 | 18 | # If extensions (or modules to document with autodoc) are in another directory, 19 | # add these directories to sys.path here. If the directory is relative to the 20 | # documentation root, use os.path.abspath to make it absolute, like shown here. 21 | sys.path.insert(0, os.path.join(__location__, "../src")) 22 | 23 | # -- Run sphinx-apidoc ------------------------------------------------------- 24 | # This hack is necessary since RTD does not issue `sphinx-apidoc` before running 25 | # `sphinx-build -b html . _build/html`. See Issue: 26 | # https://github.com/readthedocs/readthedocs.org/issues/1139 27 | # DON'T FORGET: Check the box "Install your project inside a virtualenv using 28 | # setup.py install" in the RTD Advanced Settings. 29 | # Additionally it helps us to avoid running apidoc manually 30 | 31 | try: # for Sphinx >= 1.7 32 | from sphinx.ext import apidoc 33 | except ImportError: 34 | from sphinx import apidoc 35 | 36 | output_dir = os.path.join(__location__, "api") 37 | module_dir = os.path.join(__location__, "../src/decima") 38 | try: 39 | shutil.rmtree(output_dir) 40 | except FileNotFoundError: 41 | pass 42 | 43 | try: 44 | import sphinx 45 | 46 | cmd_line = f"sphinx-apidoc --implicit-namespaces -f -o {output_dir} {module_dir}" 47 | 48 | args = cmd_line.split(" ") 49 | if tuple(sphinx.__version__.split(".")) >= ("1", "7"): 50 | # This is a rudimentary parse_version to avoid external dependencies 51 | args = args[1:] 52 | 53 | apidoc.main(args) 54 | except Exception as e: 55 | print("Running `sphinx-apidoc` failed!\n{}".format(e)) 56 | 57 | # -- General configuration --------------------------------------------------- 58 | 59 | # If your documentation needs a minimal Sphinx version, state it here. 60 | # needs_sphinx = '1.0' 61 | 62 | # Add any Sphinx extension module names here, as strings. 
They can be extensions 63 | # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 64 | extensions = [ 65 | "sphinx.ext.autodoc", 66 | "sphinx.ext.intersphinx", 67 | "sphinx.ext.todo", 68 | "sphinx.ext.autosummary", 69 | "sphinx.ext.viewcode", 70 | "sphinx.ext.coverage", 71 | "sphinx.ext.doctest", 72 | "sphinx.ext.ifconfig", 73 | "sphinx.ext.mathjax", 74 | "sphinx.ext.napoleon", 75 | ] 76 | 77 | # Add any paths that contain templates here, relative to this directory. 78 | templates_path = ["_templates"] 79 | 80 | 81 | # Enable markdown 82 | extensions.append("myst_parser") 83 | 84 | # Configure MyST-Parser 85 | myst_enable_extensions = [ 86 | "amsmath", 87 | "colon_fence", 88 | "deflist", 89 | "dollarmath", 90 | "html_image", 91 | "linkify", 92 | "replacements", 93 | "smartquotes", 94 | "substitution", 95 | "tasklist", 96 | ] 97 | 98 | # The suffix of source filenames. 99 | source_suffix = [".rst", ".md"] 100 | 101 | # The encoding of source files. 102 | # source_encoding = 'utf-8-sig' 103 | 104 | # The master toctree document. 105 | master_doc = "index" 106 | 107 | # General information about the project. 108 | project = "decima" 109 | copyright = "2024, Gokcen Eraslan" 110 | 111 | # The version info for the project you're documenting, acts as replacement for 112 | # |version| and |release|, also used in various other places throughout the 113 | # built documents. 114 | # 115 | # version: The short X.Y version. 116 | # release: The full version, including alpha/beta/rc tags. 117 | # If you don’t need the separation provided between version and release, 118 | # just set them both to the same value. 119 | try: 120 | from decima import __version__ as version 121 | except ImportError: 122 | version = "" 123 | 124 | if not version or version.lower() == "unknown": 125 | version = os.getenv("READTHEDOCS_VERSION", "unknown") # automatically set by RTD 126 | 127 | release = version 128 | 129 | # The language for content autogenerated by Sphinx. Refer to documentation 130 | # for a list of supported languages. 131 | # language = None 132 | 133 | # There are two options for replacing |today|: either, you set today to some 134 | # non-false value, then it is used: 135 | # today = '' 136 | # Else, today_fmt is used as the format for a strftime call. 137 | # today_fmt = '%B %d, %Y' 138 | 139 | # List of patterns, relative to source directory, that match files and 140 | # directories to ignore when looking for source files. 141 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", ".venv"] 142 | 143 | # The reST default role (used for this markup: `text`) to use for all documents. 144 | # default_role = None 145 | 146 | # If true, '()' will be appended to :func: etc. cross-reference text. 147 | # add_function_parentheses = True 148 | 149 | # If true, the current module name will be prepended to all description 150 | # unit titles (such as .. function::). 151 | # add_module_names = True 152 | 153 | # If true, sectionauthor and moduleauthor directives will be shown in the 154 | # output. They are ignored by default. 155 | # show_authors = False 156 | 157 | # The name of the Pygments (syntax highlighting) style to use. 158 | pygments_style = "sphinx" 159 | 160 | # A list of ignored prefixes for module index sorting. 161 | # modindex_common_prefix = [] 162 | 163 | # If true, keep warnings as "system message" paragraphs in the built documents. 164 | # keep_warnings = False 165 | 166 | # If this is True, todo emits a warning for each TODO entries. The default is False. 
167 | todo_emit_warnings = True 168 | 169 | 170 | # -- Options for HTML output ------------------------------------------------- 171 | 172 | # The theme to use for HTML and HTML Help pages. See the documentation for 173 | # a list of builtin themes. 174 | html_theme = "alabaster" 175 | 176 | # Theme options are theme-specific and customize the look and feel of a theme 177 | # further. For a list of options available for each theme, see the 178 | # documentation. 179 | html_theme_options = { 180 | "sidebar_width": "300px", 181 | "page_width": "1200px" 182 | } 183 | 184 | # Add any paths that contain custom themes here, relative to this directory. 185 | # html_theme_path = [] 186 | 187 | # The name for this set of Sphinx documents. If None, it defaults to 188 | # " v documentation". 189 | # html_title = None 190 | 191 | # A shorter title for the navigation bar. Default is the same as html_title. 192 | # html_short_title = None 193 | 194 | # The name of an image file (relative to this directory) to place at the top 195 | # of the sidebar. 196 | # html_logo = "" 197 | 198 | # The name of an image file (within the static path) to use as favicon of the 199 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 200 | # pixels large. 201 | # html_favicon = None 202 | 203 | # Add any paths that contain custom static files (such as style sheets) here, 204 | # relative to this directory. They are copied after the builtin static files, 205 | # so a file named "default.css" will overwrite the builtin "default.css". 206 | html_static_path = ["_static"] 207 | 208 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 209 | # using the given strftime format. 210 | # html_last_updated_fmt = '%b %d, %Y' 211 | 212 | # If true, SmartyPants will be used to convert quotes and dashes to 213 | # typographically correct entities. 214 | # html_use_smartypants = True 215 | 216 | # Custom sidebar templates, maps document names to template names. 217 | # html_sidebars = {} 218 | 219 | # Additional templates that should be rendered to pages, maps page names to 220 | # template names. 221 | # html_additional_pages = {} 222 | 223 | # If false, no module index is generated. 224 | # html_domain_indices = True 225 | 226 | # If false, no index is generated. 227 | # html_use_index = True 228 | 229 | # If true, the index is split into individual pages for each letter. 230 | # html_split_index = False 231 | 232 | # If true, links to the reST sources are added to the pages. 233 | # html_show_sourcelink = True 234 | 235 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 236 | # html_show_sphinx = True 237 | 238 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 239 | # html_show_copyright = True 240 | 241 | # If true, an OpenSearch description file will be output, and all pages will 242 | # contain a tag referring to it. The value of this option must be the 243 | # base URL from which the finished HTML is served. 244 | # html_use_opensearch = '' 245 | 246 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 247 | # html_file_suffix = None 248 | 249 | # Output file base name for HTML help builder. 250 | htmlhelp_basename = "decima-doc" 251 | 252 | 253 | # -- Options for LaTeX output ------------------------------------------------ 254 | 255 | latex_elements = { 256 | # The paper size ("letterpaper" or "a4paper"). 257 | # "papersize": "letterpaper", 258 | # The font size ("10pt", "11pt" or "12pt"). 
259 | # "pointsize": "10pt", 260 | # Additional stuff for the LaTeX preamble. 261 | # "preamble": "", 262 | } 263 | 264 | # Grouping the document tree into LaTeX files. List of tuples 265 | # (source start file, target name, title, author, documentclass [howto/manual]). 266 | latex_documents = [ 267 | ("index", "user_guide.tex", "decima Documentation", "Gokcen Eraslan", "manual") 268 | ] 269 | 270 | # The name of an image file (relative to this directory) to place at the top of 271 | # the title page. 272 | # latex_logo = "" 273 | 274 | # For "manual" documents, if this is true, then toplevel headings are parts, 275 | # not chapters. 276 | # latex_use_parts = False 277 | 278 | # If true, show page references after internal links. 279 | # latex_show_pagerefs = False 280 | 281 | # If true, show URL addresses after external links. 282 | # latex_show_urls = False 283 | 284 | # Documents to append as an appendix to all manuals. 285 | # latex_appendices = [] 286 | 287 | # If false, no module index is generated. 288 | # latex_domain_indices = True 289 | 290 | # -- External mapping -------------------------------------------------------- 291 | python_version = ".".join(map(str, sys.version_info[0:2])) 292 | intersphinx_mapping = { 293 | "sphinx": ("https://www.sphinx-doc.org/en/master", None), 294 | "python": ("https://docs.python.org/" + python_version, None), 295 | "matplotlib": ("https://matplotlib.org", None), 296 | "numpy": ("https://numpy.org/doc/stable", None), 297 | "sklearn": ("https://scikit-learn.org/stable", None), 298 | "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None), 299 | "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), 300 | "setuptools": ("https://setuptools.pypa.io/en/stable/", None), 301 | "pyscaffold": ("https://pyscaffold.org/en/stable", None), 302 | } 303 | 304 | print(f"loading configurations for {project} {version} ...", file=sys.stderr) 305 | 306 | # -- Biocsetup configuration ------------------------------------------------- 307 | 308 | # Enable execution of code chunks in markdown 309 | extensions.remove('myst_parser') 310 | extensions.append('myst_nb') 311 | 312 | # Less verbose api documentation 313 | extensions.append('sphinx_autodoc_typehints') 314 | 315 | autodoc_default_options = { 316 | "special-members": True, 317 | "undoc-members": True, 318 | "exclude-members": "__weakref__, __dict__, __str__, __module__", 319 | } 320 | 321 | autosummary_generate = True 322 | autosummary_imported_members = True 323 | 324 | html_theme = "furo" 325 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | ```{include} ../CONTRIBUTING.md 2 | :relative-docs: docs/ 3 | :relative-images: 4 | ``` 5 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # decima 2 | 3 | ## Contents 4 | 5 | ```{toctree} 6 | :maxdepth: 2 7 | 8 | Overview 9 | Contributions & Help 10 | License 11 | Authors 12 | Changelog 13 | Module Reference 14 | ``` 15 | 16 | ## Indices and tables 17 | 18 | * {ref}`genindex` 19 | * {ref}`modindex` 20 | * {ref}`search` 21 | 22 | [Sphinx]: http://www.sphinx-doc.org/ 23 | [Markdown]: https://daringfireball.net/projects/markdown/ 24 | [reStructuredText]: http://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html 25 | [MyST]: https://myst-parser.readthedocs.io/en/latest/ 26 | 
-------------------------------------------------------------------------------- /docs/license.md: -------------------------------------------------------------------------------- 1 | # License 2 | 3 | ```{literalinclude} ../LICENSE.txt 4 | :language: text 5 | ``` 6 | -------------------------------------------------------------------------------- /docs/readme.md: -------------------------------------------------------------------------------- 1 | ```{include} ../README.md 2 | :relative-docs: docs/ 3 | :relative-images: 4 | ``` 5 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # Requirements file for ReadTheDocs, check .readthedocs.yml. 2 | # To build the module reference correctly, make sure every external package 3 | # under `install_requires` in `setup.cfg` is also listed here! 4 | # sphinx_rtd_theme 5 | myst-parser[linkify] 6 | sphinx>=3.2.1 7 | myst-nb 8 | furo 9 | sphinx-autodoc-typehints 10 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | # AVOID CHANGING REQUIRES: IT WILL BE UPDATED BY PYSCAFFOLD! 3 | requires = ["setuptools>=46.1.0", "setuptools_scm[toml]>=5"] 4 | build-backend = "setuptools.build_meta" 5 | 6 | [tool.setuptools_scm] 7 | # For smarter version schemes and other configuration options, 8 | # check out https://github.com/pypa/setuptools_scm 9 | version_scheme = "no-guess-dev" 10 | 11 | [tool.ruff] 12 | line-length = 120 13 | src = ["src"] 14 | exclude = ["tests"] 15 | extend-ignore = ["F821"] 16 | 17 | [tool.ruff.pydocstyle] 18 | convention = "google" 19 | 20 | [tool.ruff.format] 21 | docstring-code-format = true 22 | docstring-code-line-length = 20 23 | 24 | [tool.ruff.per-file-ignores] 25 | "__init__.py" = ["E402", "F401"] 26 | -------------------------------------------------------------------------------- /scripts/finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | import anndata 6 | import wandb 7 | 8 | src_dir = f"{os.path.dirname(__file__)}/../src/decima/" 9 | sys.path.append(src_dir) 10 | from lightning import LightningModel 11 | from read_hdf5 import HDF5Dataset 12 | 13 | # Parse arguments 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--name", type=str) 16 | parser.add_argument("--dir", type=str) 17 | parser.add_argument("--lr", type=float) 18 | parser.add_argument("--weight", type=float) 19 | parser.add_argument("--grad", type=int) 20 | parser.add_argument("--replicate", type=int, default=0) 21 | parser.add_argument("--bs", type=int, default=4) 22 | args = parser.parse_args() 23 | 24 | 25 | def main(): 26 | wandb.login(host="https://genentech.wandb.io") 27 | run = wandb.init(project="decima", dir=args.name, name=args.name) 28 | 29 | # Get paths 30 | data_dir = args.dir 31 | matrix_file = os.path.join(data_dir, "aggregated.h5ad") 32 | h5_file = os.path.join(data_dir, "data.h5") 33 | print(f"Data paths: {matrix_file}, {h5_file}") 34 | 35 | # Load data 36 | print("Reading anndata") 37 | ad = anndata.read_h5ad(matrix_file) 38 | 39 | # Make datasets 40 | print("Making dataset objects") 41 | train_dataset = HDF5Dataset( 42 | h5_file=h5_file, 43 | ad=ad, 44 | key="train", 45 | max_seq_shift=5000, 46 | augment_mode="random", 47 | seed=0, 48 | ) 49 | 
val_dataset = HDF5Dataset(h5_file=h5_file, ad=ad, key="val", max_seq_shift=0) 50 | 51 | # Make param dicts 52 | train_params = { 53 | "optimizer": "adam", 54 | "batch_size": args.bs, 55 | "num_workers": 16, 56 | "devices": 0, 57 | "logger": "wandb", 58 | "save_dir": data_dir, 59 | "max_epochs": 15, 60 | "lr": args.lr, 61 | "total_weight": args.weight, 62 | "accumulate_grad_batches": args.grad, 63 | "loss": "poisson_multinomial", 64 | "pairs": ad.uns["disease_pairs"].values, 65 | } 66 | model_params = { 67 | "n_tasks": ad.shape[0], 68 | "replicate": args.replicate, 69 | } 70 | 71 | print(f"train_params: {train_params}") 72 | print(f"model_params: {model_params}") 73 | 74 | # Make model 75 | print("Initializing model") 76 | model = LightningModel(model_params=model_params, train_params=train_params) 77 | 78 | # Fine-tune model 79 | print("Training") 80 | model.train_on_dataset(train_dataset, val_dataset) 81 | 82 | train_dataset.close() 83 | val_dataset.close() 84 | run.finish() 85 | 86 | 87 | if __name__ == "__main__": 88 | main() 89 | -------------------------------------------------------------------------------- /scripts/predict_genes.py: -------------------------------------------------------------------------------- 1 | # Given an hdf5 file created by write_hdf5.py, make predictions for all the genes 2 | 3 | import argparse 4 | import os 5 | import sys 6 | 7 | import anndata 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | from tqdm import tqdm 12 | 13 | src_dir = f"{os.path.dirname(__file__)}/../src/decima/" 14 | sys.path.append(src_dir) 15 | 16 | from lightning import LightningModel 17 | from read_hdf5 import HDF5Dataset, list_genes 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--device", help="which gpu to use", type=int) 21 | parser.add_argument("--ckpts", help="Path to the model checkpoint", nargs="+") 22 | parser.add_argument("--h5_file", help="Path to h5 file indexed by genes") 23 | parser.add_argument( 24 | "--matrix_file", help="Path to h5ad file containing genes to predict" 25 | ) 26 | parser.add_argument("--out_file", help="Output file path") 27 | parser.add_argument( 28 | "--max_seq_shift", help="Maximum jitter for augmentation", default=0, type=int 29 | ) 30 | 31 | args = parser.parse_args() 32 | 33 | 34 | torch.set_float32_matmul_precision("medium") 35 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device) 36 | device = torch.device(0) 37 | 38 | print("Loading anndata") 39 | ad = anndata.read_h5ad(args.matrix_file) 40 | assert np.all(list_genes(args.h5_file, key=None) == ad.var_names.tolist()) 41 | 42 | print("Making dataset") 43 | ds = HDF5Dataset( 44 | key=None, 45 | h5_file=args.h5_file, 46 | ad=ad, 47 | seq_len=524288, 48 | max_seq_shift=args.max_seq_shift, 49 | ) 50 | 51 | print("Loading models from checkpoint") 52 | models = [LightningModel.load_from_checkpoint(f).eval() for f in args.ckpts] 53 | 54 | print("Computing predictions") 55 | preds = ( 56 | np.stack( 57 | [ 58 | model.predict_on_dataset(ds, devices=0, batch_size=6, num_workers=16) 59 | for model in models 60 | ] 61 | ) 62 | .mean(0) 63 | .T 64 | ) 65 | ad.layers["preds"] = preds 66 | 67 | print("Computing correlations per gene") 68 | ad.var["pearson"] = [ 69 | np.corrcoef(ad.X[:, i], ad.layers["preds"][:, i])[0, 1] for i in range(ad.shape[1]) 70 | ] 71 | ad.var["size_factor_pearson"] = [ 72 | np.corrcoef(ad.X[:, i], ad.obs["size_factor"])[0, 1] for i in range(ad.shape[1]) 73 | ] 74 | print( 75 | f"Mean Pearson Correlation per gene: True: 
{ad.var.pearson.mean().round(2)} Size Factor: {ad.var.size_factor_pearson.mean().round(2)}" 76 | ) 77 | 78 | print("Computing correlation per track") 79 | for dataset in ad.var.dataset.unique(): 80 | key = f"{dataset}_pearson" 81 | ad.obs[key] = [ 82 | np.corrcoef( 83 | ad[i, ad.var.dataset == dataset].X, 84 | ad[i, ad.var.dataset == dataset].layers["preds"], 85 | )[0, 1] 86 | for i in range(ad.shape[0]) 87 | ] 88 | print( 89 | f"Mean Pearson Correlation per pseudobulk over {dataset} genes: {ad.obs[key].mean().round(2)}" 90 | ) 91 | 92 | 93 | print("Saved") 94 | ad.write_h5ad(args.out_file) 95 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # This file is used to configure your project. 2 | # Read more about the various options under: 3 | # https://setuptools.pypa.io/en/latest/userguide/declarative_config.html 4 | # https://setuptools.pypa.io/en/latest/references/keywords.html 5 | 6 | [metadata] 7 | name = decima 8 | description = Add a short description here! 9 | author = Avantika Lal 10 | author_email = lal.avantika@gene.com 11 | license = Genentech Non-Commercial Software License Version 1.0 12 | license_files = LICENSE.txt 13 | long_description = file: README.md 14 | long_description_content_type = text/markdown; charset=UTF-8; variant=GFM 15 | url = https://github.com/genentech/decima 16 | # Add here related links, for example: 17 | project_urls = 18 | Documentation = https://github.com/genentech/decima 19 | # Source = https://github.com/pyscaffold/pyscaffold/ 20 | # Changelog = https://pyscaffold.org/en/latest/changelog.html 21 | # Tracker = https://github.com/pyscaffold/pyscaffold/issues 22 | # Conda-Forge = https://anaconda.org/conda-forge/pyscaffold 23 | # Download = https://pypi.org/project/PyScaffold/#files 24 | # Twitter = https://twitter.com/PyScaffold 25 | 26 | # Change if running only on Windows, Mac or Linux (comma-separated) 27 | platforms = any 28 | 29 | # Add here all kinds of additional classifiers as defined under 30 | # https://pypi.org/classifiers/ 31 | classifiers = 32 | Development Status :: 4 - Beta 33 | Programming Language :: Python 34 | 35 | 36 | [options] 37 | zip_safe = False 38 | packages = find_namespace: 39 | include_package_data = True 40 | package_dir = 41 | =src 42 | 43 | # Require a min/specific Python version (comma-separated conditions) 44 | python_requires = >=3.9 45 | 46 | # Add here dependencies of your project (line-separated), e.g. requests>=2.2,<3.0. 47 | # Version specifiers like >=2.2,<3.0 avoid problems due to API changes in 48 | # new major versions. This works if the required packages follow Semantic Versioning. 49 | # For more information, check out https://semver.org/. 
50 | install_requires = 51 | importlib-metadata; python_version<"3.8" 52 | wandb 53 | numpy 54 | torch 55 | grelu 56 | lightning 57 | torchmetrics 58 | bioframe 59 | scipy 60 | scikit-learn 61 | pandas 62 | tqdm 63 | captum 64 | anndata 65 | h5py 66 | 67 | [options.packages.find] 68 | where = src 69 | exclude = 70 | tests 71 | 72 | [options.extras_require] 73 | # Add here additional requirements for extra features, to install with: 74 | # `pip install decima[PDF]` like: 75 | # PDF = ReportLab; RXP 76 | 77 | # Add here test requirements (semicolon/line-separated) 78 | testing = 79 | setuptools 80 | pytest 81 | pytest-cov 82 | 83 | [options.entry_points] 84 | # Add here console scripts like: 85 | # console_scripts = 86 | # script_name = decima.module:function 87 | # For example: 88 | # console_scripts = 89 | # fibonacci = decima.skeleton:run 90 | # And any other entry points, for example: 91 | # pyscaffold.cli = 92 | # awesome = pyscaffoldext.awesome.extension:AwesomeExtension 93 | 94 | [tool:pytest] 95 | # Specify command line options as you would do when invoking pytest directly. 96 | # e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml 97 | # in order to write a coverage file that can be read by Jenkins. 98 | # CAUTION: --cov flags may prohibit setting breakpoints while debugging. 99 | # Comment those flags to avoid this pytest issue. 100 | addopts = 101 | --cov decima --cov-report term-missing 102 | --verbose 103 | norecursedirs = 104 | dist 105 | build 106 | .tox 107 | testpaths = tests 108 | # Use pytest markers to select/deselect specific tests 109 | # markers = 110 | # slow: mark tests as slow (deselect with '-m "not slow"') 111 | # system: mark end-to-end system tests 112 | 113 | [devpi:upload] 114 | # Options for the devpi: PyPI server and packaging tool 115 | # VCS export must be deactivated since we are using setuptools-scm 116 | no_vcs = 1 117 | formats = bdist_wheel 118 | 119 | [flake8] 120 | # Some sane defaults for the code style checker flake8 121 | max_line_length = 88 122 | extend_ignore = E203, W503 123 | # ^ Black-compatible 124 | # E203 and W503 have edge cases handled by black 125 | exclude = 126 | .tox 127 | build 128 | dist 129 | .eggs 130 | docs/conf.py 131 | 132 | [pyscaffold] 133 | # PyScaffold's parameters when the project was created. 134 | # This will be used when updating. Do not change! 135 | version = 4.6 136 | package = decima 137 | extensions = 138 | github_actions 139 | no_skeleton 140 | pre_commit 141 | markdown 142 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Setup file for decima. 3 | Use setup.cfg to configure your project. 4 | 5 | This file was generated with PyScaffold 4.5. 6 | PyScaffold helps you to put up the scaffold of your new Python project. 
7 | Learn more under: https://pyscaffold.org/ 8 | """ 9 | 10 | from setuptools import setup 11 | 12 | if __name__ == "__main__": 13 | try: 14 | setup(use_scm_version={"version_scheme": "no-guess-dev"}) 15 | except: # noqa 16 | print( 17 | "\n\nAn error occurred while building the project, " 18 | "please ensure you have the most updated version of setuptools, " 19 | "setuptools_scm and wheel with:\n" 20 | " pip install -U setuptools setuptools_scm wheel\n\n" 21 | ) 22 | raise 23 | -------------------------------------------------------------------------------- /src/decima/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | if sys.version_info[:2] >= (3, 8): 4 | # TODO: Import directly (no need for conditional) when `python_requires = >= 3.8` 5 | from importlib.metadata import PackageNotFoundError, version # pragma: no cover 6 | else: 7 | from importlib_metadata import PackageNotFoundError, version # pragma: no cover 8 | 9 | try: 10 | # Change here if project is renamed and does not equal the package name 11 | dist_name = __name__ 12 | __version__ = version(dist_name) 13 | except PackageNotFoundError: # pragma: no cover 14 | __version__ = "unknown" 15 | finally: 16 | del version, PackageNotFoundError 17 | -------------------------------------------------------------------------------- /src/decima/decima_model.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tempfile import TemporaryDirectory 3 | 4 | import torch 5 | import wandb 6 | from grelu.model.heads import ConvHead 7 | from grelu.model.models import BaseModel, BorzoiModel 8 | from torch import nn 9 | 10 | 11 | class DecimaModel(BaseModel): 12 | def __init__(self, n_tasks: int, replicate: int = 0, mask=True, init_borzoi=True): 13 | self.mask = mask 14 | model = BorzoiModel( 15 | crop_len=5120, 16 | n_tasks=7611, 17 | stem_channels=512, 18 | stem_kernel_size=15, 19 | init_channels=608, 20 | n_conv=7, 21 | kernel_size=5, 22 | n_transformers=8, 23 | key_len=64, 24 | value_len=192, 25 | pos_dropout=0.0, 26 | attn_dropout=0.0, 27 | n_heads=8, 28 | n_pos_features=32, 29 | final_act_func=None, 30 | final_pool_func=None, 31 | ) 32 | 33 | if init_borzoi: 34 | # Load state dict 35 | wandb.login(host="https://api.wandb.ai/", anonymous="must") 36 | api = wandb.Api(overrides={"base_url": "https://api.wandb.ai/"}) 37 | art = api.artifact(f"grelu/borzoi/human_state_dict_fold{replicate}:latest") 38 | with TemporaryDirectory() as d: 39 | art.download(d) 40 | state_dict = torch.load(Path(d) / f"fold{replicate}.h5") 41 | model.load_state_dict(state_dict) 42 | 43 | # Change head 44 | head = ConvHead(n_tasks=n_tasks, in_channels=1920, pool_func="avg") 45 | 46 | super().__init__(embedding=model.embedding, head=head) 47 | 48 | # Add a channel for the gene mask 49 | if self.mask: 50 | weight = self.embedding.conv_tower.blocks[0].conv.weight 51 | new_layer = nn.Conv1d(5, 512, kernel_size=(15,), stride=(1,), padding="same") 52 | new_weight = nn.Parameter(torch.cat([weight, new_layer.weight[:, [-1], :]], axis=1)) 53 | self.embedding.conv_tower.blocks[0].conv.weight = new_weight 54 | -------------------------------------------------------------------------------- /src/decima/evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from scipy.stats import pearsonr, zscore 4 | from sklearn.metrics import average_precision_score, roc_auc_score 
5 | from tqdm.auto import tqdm 6 | 7 | 8 | def match_criteria(df, filter_df): 9 |     for col in filter_df.columns: 10 |         all_values = df[col].unique().tolist() 11 |         isna = filter_df[col].isna() 12 |         if isna.sum() > 0: 13 |             filter_df.loc[isna, col] = [all_values] * isna.sum() 14 |         filter_df = filter_df.explode(col) 15 | 16 |     sel_idxs = df.reset_index().merge(filter_df, how="inner")["index"] 17 |     return df.index.isin(sel_idxs) 18 | 19 | 20 | def marker_zscores(ad, key="cell_type", layer=None): 21 |     E = ad.X if layer is None else ad.layers[layer] 22 |     z = zscore(E, axis=0) 23 | 24 |     dfs = [] 25 |     for group in tqdm(ad.obs[key].unique()): 26 |         scores = z[ad.obs[key] == group, :].mean(0).squeeze() 27 |         df = pd.DataFrame({"gene": ad.var_names, "score": scores, key: group}) 28 |         dfs.append(df) 29 | 30 |     return pd.concat(dfs, axis=0).reset_index(drop=True) 31 | 32 | 33 | def compare_marker_zscores(ad, key="cell_type"): 34 |     marker_df_obs = marker_zscores(ad, key) 35 |     marker_df_pred = marker_zscores(ad, key, layer="preds") 36 |     marker_df = marker_df_pred.merge(marker_df_obs, on=["gene", key], suffixes=("_pred", "_obs")) 37 |     return marker_df 38 | 39 | 40 | def compute_marker_metrics(marker_df, key="cell_type", tp_cutoff=1): 41 |     df_list = [] 42 |     for k in set(marker_df[key]): 43 |         # get celltype data 44 |         curr_marker_df = marker_df[marker_df[key] == k].copy() 45 | 46 |         # compute corrs 47 |         corr = pearsonr(curr_marker_df["score_obs"], curr_marker_df["score_pred"])[0] 48 | 49 |         # compute binary labels 50 |         labels = curr_marker_df["score_obs"] > tp_cutoff 51 |         n_positive = np.sum(labels) 52 | 53 |         if n_positive == len(labels) or n_positive == 0: 54 |             auprc = auroc = np.nan 55 |         else: 56 |             auprc = average_precision_score(labels, curr_marker_df["score_pred"]) 57 |             auroc = roc_auc_score(labels, curr_marker_df["score_pred"]) 58 | 59 |         # append df 60 |         curr_df = pd.DataFrame( 61 |             { 62 |                 key: [k], 63 |                 "n_positive": [n_positive], 64 |                 "pearson": [corr], 65 |                 "auprc": [auprc], 66 |                 "auroc": [auroc], 67 |             } 68 |         ) 69 |         df_list.append(curr_df) 70 | 71 |     return pd.concat(df_list) 72 | -------------------------------------------------------------------------------- /src/decima/interpret.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch 4 | from captum.attr import InputXGradient 5 | from grelu.interpret.motifs import scan_sequences 6 | from grelu.sequence.format import convert_input_type 7 | from grelu.transforms.prediction_transforms import Aggregate, Specificity 8 | from scipy.signal import find_peaks 9 | 10 | from .read_hdf5 import extract_gene_data 11 | 12 | 13 | def attributions( 14 |     gene, 15 |     tasks, 16 |     model, 17 |     device=None, 18 |     h5_file=None, 19 |     inputs=None, 20 |     off_tasks=None, 21 |     transform="specificity", 22 |     method=InputXGradient, 23 |     **kwargs, 24 | ): 25 |     if inputs is None: 26 |         assert h5_file is not None 27 |         inputs = extract_gene_data(h5_file, gene, merge=True) 28 | 29 |     tss_pos = np.where(inputs[-1] == 1)[0][0] 30 |     if transform == "specificity": 31 |         model.add_transform( 32 |             Specificity( 33 |                 on_tasks=tasks, 34 |                 off_tasks=off_tasks, 35 |                 model=model, 36 |                 compare_func="subtract", 37 |             ) 38 |         ) 39 |     elif transform == "aggregate": 40 |         model.add_transform(Aggregate(tasks=tasks, task_aggfunc="mean", model=model)) 41 | 42 |     model = model.eval() 43 |     if device is None: 44 |         device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 45 |     else: 46 |         device = torch.device(device) 47 | 48
| attributer = method(model.to(device)) 49 | with torch.no_grad(): 50 | attr = attributer.attribute(inputs.to(device), **kwargs).cpu().numpy()[:4] 51 | 52 | model.reset_transform() 53 | return attr, tss_pos 54 | 55 | 56 | def find_attr_peaks(attr, tss_pos=None, n=5, min_dist=6): 57 | peaks, heights = find_peaks(attr.sum(0), height=0.1, distance=min_dist) 58 | peaks = pd.DataFrame({"peak": peaks, "height": heights["peak_heights"]}) 59 | if tss_pos is not None: 60 | peaks["from_tss"] = peaks["peak"] - tss_pos 61 | peaks = peaks.sort_values("height", ascending=False).head(n) 62 | return peaks.reset_index(drop=True) 63 | 64 | 65 | def scan_attributions(seq, attr, motifs, peaks, names=None, pthresh=1e-3, rc=True, window=18): 66 | # Attributions and sequences 67 | peak_attrs = np.stack([attr[:, peak - window // 2 : peak + window // 2] for peak in peaks.peak]) 68 | peak_seqs = torch.stack([seq[:, peak - window // 2 : peak + window // 2] for peak in peaks.peak]) 69 | 70 | # Scan 71 | results = scan_sequences( 72 | seqs=convert_input_type(peak_seqs, "strings"), 73 | motifs=motifs, 74 | names=names, 75 | pthresh=pthresh, 76 | rc=rc, 77 | attrs=peak_attrs, 78 | ) 79 | results.sequence = results.sequence.astype(int) 80 | return results.merge(peaks.reset_index(drop=True), left_on="sequence", right_index=True) 81 | -------------------------------------------------------------------------------- /src/decima/lightning.py: -------------------------------------------------------------------------------- 1 | """ 2 | The LightningModel class. 3 | """ 4 | 5 | from datetime import datetime 6 | from typing import Callable, List, Optional, Tuple, Union 7 | 8 | import numpy as np 9 | import pytorch_lightning as pl 10 | import torch 11 | from einops import rearrange 12 | from grelu.lightning.metrics import MSE, PearsonCorrCoef 13 | from grelu.utils import make_list 14 | from pytorch_lightning.callbacks import ModelCheckpoint 15 | from pytorch_lightning.loggers import CSVLogger, WandbLogger 16 | from torch import Tensor, nn, optim 17 | from torch.utils.data import DataLoader 18 | from torchmetrics import MetricCollection 19 | 20 | from .decima_model import DecimaModel 21 | from .loss import TaskWisePoissonMultinomialLoss 22 | from .metrics import DiseaseLfcMSE 23 | 24 | default_train_params = { 25 | "lr": 4e-5, 26 | "batch_size": 4, 27 | "num_workers": 1, 28 | "devices": 0, 29 | "logger": "csv", 30 | "save_dir": ".", 31 | "max_epochs": 1, 32 | "accumulate_grad_batches": 1, 33 | "total_weight": 1e-4, 34 | "disease_weight": 1e-2, 35 | } 36 | 37 | 38 | class LightningModel(pl.LightningModule): 39 | """ 40 | Wrapper for predictive sequence models 41 | 42 | Args: 43 | model_params: Dictionary of parameters specifying model architecture 44 | train_params: Dictionary specifying training parameters 45 | data_params: Dictionary specifying parameters of the training data. 46 | This is empty by default and will be filled at the time of 47 | training. 
48 | """ 49 | 50 | def __init__(self, model_params: dict, train_params: dict = {}, data_params: dict = {}) -> None: 51 | super().__init__() 52 | 53 | self.save_hyperparameters(ignore=["model"]) 54 | 55 | # Add default training parameters 56 | for key in default_train_params.keys(): 57 | if key not in train_params: 58 | train_params[key] = default_train_params[key] 59 | 60 | # Save params 61 | self.model_params = model_params 62 | self.train_params = train_params 63 | self.data_params = data_params 64 | 65 | # Build model 66 | self.model = DecimaModel(**{k: v for k, v in self.model_params.items()}) 67 | 68 | # Set up loss function 69 | self.loss = TaskWisePoissonMultinomialLoss(total_weight=self.train_params["total_weight"], debug=True) 70 | self.val_losses = [] 71 | self.test_losses = [] 72 | 73 | # Set up activation function 74 | self.activation = torch.exp 75 | 76 | # Inititalize metrics 77 | metrics = MetricCollection( 78 | { 79 | "mse": MSE(num_outputs=self.model.head.n_tasks, average=False), 80 | "pearson": PearsonCorrCoef(num_outputs=self.model.head.n_tasks, average=False), 81 | "disease_lfc_mse": DiseaseLfcMSE(pairs=self.train_params["pairs"], average=False), 82 | } 83 | ) 84 | self.val_metrics = metrics.clone(prefix="val_") 85 | self.test_metrics = metrics.clone(prefix="test_") 86 | 87 | # Initialize prediction transform 88 | self.reset_transform() 89 | 90 | def format_input(self, x: Union[Tuple[Tensor, Tensor], Tensor]) -> Tensor: 91 | """ 92 | Extract the one-hot encoded sequence from the input 93 | """ 94 | # if x is a tuple of sequence, label, return the sequence 95 | if isinstance(x, Tensor): 96 | if x.ndim == 3: 97 | return x 98 | else: 99 | return x.unsqueeze(0) 100 | elif isinstance(x, Tuple): 101 | return x[0] 102 | else: 103 | raise Exception("Cannot perform forward pass on the given input format.") 104 | 105 | def forward( 106 | self, 107 | x: Union[Tuple[Tensor, Tensor], Tensor, str, List[str]], 108 | logits: bool = False, 109 | ) -> Tensor: 110 | """ 111 | Forward pass 112 | """ 113 | # Format the input as a one-hot encoded tensor 114 | x = self.format_input(x) 115 | 116 | # Run the model 117 | x = self.model(x) 118 | 119 | # forward() produces prediction (e.g. 
post-activation) 120 | # unless logits=True, which is used in loss functions 121 | if not logits: 122 | x = self.activation(x) 123 | 124 | # Apply transform 125 | x = self.transform(x) 126 | return x 127 | 128 | def training_step(self, batch: Tensor, batch_idx: int) -> Tensor: 129 | x, y = batch 130 | logits = self.forward(x, logits=True) 131 | loss = self.loss(logits, y) 132 | self.log("train_loss", loss, logger=True, on_step=True, on_epoch=True, prog_bar=True) 133 | return loss 134 | 135 | def validation_step(self, batch: Tensor, batch_idx: int) -> Tensor: 136 | x, y = batch 137 | logits = self.forward(x, logits=True) 138 | loss = self.loss(logits, y) 139 | y_hat = self.activation(logits) 140 | self.log("val_loss", loss, logger=True, on_step=False, on_epoch=True) 141 | self.val_metrics.update(y_hat, y) 142 | self.val_losses.append(loss) 143 | return loss 144 | 145 | def on_validation_epoch_end(self): 146 | """ 147 | Calculate metrics for entire validation set 148 | """ 149 | # Compute metrics 150 | val_metrics = self.val_metrics.compute() 151 | mean_val_metrics = {k: v.mean() for k, v in val_metrics.items()} 152 | # Compute loss 153 | losses = torch.stack(self.val_losses) 154 | mean_losses = torch.mean(losses) 155 | # Log 156 | self.log_dict(mean_val_metrics) 157 | self.log("val_loss", mean_losses) 158 | 159 | self.val_metrics.reset() 160 | self.val_losses = [] 161 | 162 | def test_step(self, batch: Tensor, batch_idx: int) -> Tensor: 163 | """ 164 | Calculate metrics after a single test step 165 | """ 166 | x, y = batch 167 | logits = self.forward(x, logits=True) 168 | loss = self.loss(logits, y) 169 | y_hat = self.activation(logits) 170 | self.log("test_loss", loss, logger=True, on_step=False, on_epoch=True) 171 | self.test_metrics.update(y_hat, y) 172 | self.test_losses.append(loss) 173 | return loss 174 | 175 | def on_test_epoch_end(self) -> None: 176 | """ 177 | Calculate metrics for entire test set 178 | """ 179 | self.computed_test_metrics = self.test_metrics.compute() 180 | self.log_dict({k: v.mean() for k, v in self.computed_test_metrics.items()}) 181 | losses = torch.stack(self.test_losses) 182 | self.log("test_loss", torch.mean(losses)) 183 | self.test_metrics.reset() 184 | self.test_losses = [] 185 | 186 | def configure_optimizers(self) -> None: 187 | """ 188 | Configure oprimizer for training 189 | """ 190 | return optim.Adam(self.parameters(), lr=self.train_params["lr"]) 191 | 192 | def count_params(self) -> int: 193 | """ 194 | Number of gradient enabled parameters in the model 195 | """ 196 | return sum(p.numel() for p in self.model.parameters() if p.requires_grad) 197 | 198 | def parse_logger(self) -> str: 199 | """ 200 | Parses the name of the logger supplied in train_params. 
201 | """ 202 | if "name" not in self.train_params: 203 | self.train_params["name"] = datetime.now().strftime("%Y_%d_%m_%H_%M") 204 | if self.train_params["logger"] == "wandb": 205 | logger = WandbLogger( 206 | name=self.train_params["name"], 207 | log_model=True, 208 | save_dir=self.train_params["save_dir"], 209 | ) 210 | elif self.train_params["logger"] == "csv": 211 | logger = CSVLogger(name=self.train_params["name"], save_dir=self.train_params["save_dir"]) 212 | else: 213 | raise NotImplementedError 214 | return logger 215 | 216 | def add_transform(self, prediction_transform: Callable) -> None: 217 | """ 218 | Add a prediction transform 219 | """ 220 | if prediction_transform is not None: 221 | self.transform = prediction_transform 222 | 223 | def reset_transform(self) -> None: 224 | """ 225 | Remove a prediction transform 226 | """ 227 | self.transform = nn.Identity() 228 | 229 | def make_train_loader( 230 | self, 231 | dataset: Callable, 232 | batch_size: Optional[int] = None, 233 | num_workers: Optional[int] = None, 234 | ) -> Callable: 235 | """ 236 | Make dataloader for training 237 | """ 238 | return DataLoader( 239 | dataset, 240 | batch_size=batch_size or self.train_params["batch_size"], 241 | shuffle=True, 242 | num_workers=num_workers or self.train_params["num_workers"], 243 | ) 244 | 245 | def make_test_loader( 246 | self, 247 | dataset: Callable, 248 | batch_size: Optional[int] = None, 249 | num_workers: Optional[int] = None, 250 | ) -> Callable: 251 | """ 252 | Make dataloader for validation and testing 253 | """ 254 | return DataLoader( 255 | dataset, 256 | batch_size=batch_size or self.train_params["batch_size"], 257 | shuffle=False, 258 | num_workers=num_workers or self.train_params["num_workers"], 259 | ) 260 | 261 | def make_predict_loader( 262 | self, 263 | dataset: Callable, 264 | batch_size: Optional[int] = None, 265 | num_workers: Optional[int] = None, 266 | ) -> Callable: 267 | """ 268 | Make dataloader for prediction 269 | """ 270 | dataset.predict = True 271 | return DataLoader( 272 | dataset, 273 | batch_size=batch_size or self.train_params["batch_size"], 274 | shuffle=False, 275 | num_workers=num_workers or self.train_params["num_workers"], 276 | ) 277 | 278 | def train_on_dataset( 279 | self, 280 | train_dataset: Callable, 281 | val_dataset: Callable, 282 | checkpoint_path: Optional[str] = None, 283 | ): 284 | """ 285 | Train model and optionally log metrics to wandb. 286 | 287 | Args: 288 | train_dataset (Dataset): Dataset object that yields training examples 289 | val_dataset (Dataset) : Dataset object that yields training examples 290 | checkpoint_path (str): Path to model checkpoint from which to resume training. 291 | The optimizer will be set to its checkpointed state. 
292 | 293 | Returns: 294 | PyTorch Lightning Trainer 295 | """ 296 | torch.set_float32_matmul_precision("medium") 297 | 298 | # Set up logging 299 | logger = self.parse_logger() 300 | 301 | # Set up trainer 302 | trainer = pl.Trainer( 303 | max_epochs=self.train_params["max_epochs"], 304 | accelerator="gpu", 305 | devices=make_list(self.train_params["devices"]), 306 | logger=logger, 307 | callbacks=[ModelCheckpoint(monitor="val_loss", mode="min", save_last=True)], 308 | default_root_dir=self.train_params["save_dir"], 309 | accumulate_grad_batches=self.train_params["accumulate_grad_batches"], 310 | precision="16-mixed", 311 | ) 312 | 313 | # Make dataloaders 314 | train_dataloader = self.make_train_loader(train_dataset) 315 | val_dataloader = self.make_test_loader(val_dataset) 316 | 317 | if checkpoint_path is None: 318 | # First validation pass 319 | trainer.validate(model=self, dataloaders=val_dataloader) 320 | self.val_metrics.reset() 321 | 322 | # Add data parameters 323 | self.data_params["tasks"] = train_dataset.tasks.reset_index(names="name").to_dict(orient="list") 324 | 325 | for attr, value in self._get_dataset_attrs(train_dataset): 326 | self.data_params["train_" + attr] = value 327 | 328 | for attr, value in self._get_dataset_attrs(val_dataset): 329 | self.data_params["val_" + attr] = value 330 | 331 | # Training 332 | trainer.fit( 333 | model=self, 334 | train_dataloaders=train_dataloader, 335 | val_dataloaders=val_dataloader, 336 | ckpt_path=checkpoint_path, 337 | ) 338 | return trainer 339 | 340 | def _get_dataset_attrs(self, dataset: Callable) -> None: 341 | """ 342 | Read data parameters from a dataset object 343 | """ 344 | for attr in dir(dataset): 345 | if not attr.startswith("_") and not attr.isupper(): 346 | value = getattr(dataset, attr) 347 | if ( 348 | (isinstance(value, str)) 349 | or (isinstance(value, int)) 350 | or (isinstance(value, float)) 351 | or (value is None) 352 | ): 353 | yield attr, value 354 | 355 | def on_save_checkpoint(self, checkpoint: dict) -> None: 356 | checkpoint["hyper_parameters"]["data_params"] = self.data_params 357 | 358 | def predict_on_dataset( 359 | self, 360 | dataset: Callable, 361 | devices: Optional[int] = None, 362 | num_workers: int = 1, 363 | batch_size: int = 6, 364 | augment_aggfunc: Union[str, Callable] = "mean", 365 | compare_func: Optional[Union[str, Callable]] = None, 366 | ): 367 | """ 368 | Predict for a dataset of sequences or variants 369 | 370 | Args: 371 | dataset: Dataset object that yields one-hot encoded sequences 372 | 373 | devices: Number of devices to use, 374 | e.g. 
machine has 4 gpu's but only want to use 2 for predictions 375 | 376 | num_workers: Number of workers for data loader 377 | 378 | batch_size: Batch size for data loader 379 | 380 | Returns: 381 | Model predictions as a numpy array or dataframe 382 | """ 383 | torch.set_float32_matmul_precision("medium") 384 | dataloader = self.make_predict_loader( 385 | dataset, 386 | num_workers=num_workers, 387 | batch_size=batch_size, 388 | ) 389 | accelerator = "auto" 390 | if devices is None: 391 | devices = "auto" # use all devices 392 | # device = "cuda" if torch.cuda.is_available() else "cpu" 393 | accelerator = "gpu" if torch.cuda.is_available() else "auto" 394 | 395 | if accelerator == "auto": 396 | trainer = pl.Trainer(accelerator=accelerator, logger=None) 397 | else: 398 | trainer = pl.Trainer(accelerator=accelerator, devices=devices, logger=None) 399 | 400 | # Predict 401 | preds = torch.concat(trainer.predict(self, dataloader)).squeeze(-1) 402 | 403 | # Reshape predictions 404 | preds = rearrange( 405 | preds, 406 | "(b n a) t -> b n a t", 407 | n=dataset.n_augmented, 408 | a=dataset.n_alleles, 409 | ) 410 | 411 | # Convert predictions to numpy array 412 | preds = preds.detach().cpu().numpy() 413 | 414 | if dataset.n_alleles == 2: 415 | preds = preds[:, :, 1, :] - preds[:, :, 0, :] # BNT 416 | else: 417 | preds = preds.squeeze(2) # B N T 418 | 419 | preds = np.mean(preds, axis=1) # B T 420 | return preds 421 | 422 | def get_task_idxs( 423 | self, 424 | tasks: Union[int, str, List[int], List[str]], 425 | key: str = "name", 426 | invert: bool = False, 427 | ) -> Union[int, List[int]]: 428 | """ 429 | Given a task name or metadata entry, get the task index 430 | If integers are provided, return them unchanged 431 | 432 | Args: 433 | tasks: A string corresponding to a task name or metadata entry, 434 | or an integer indicating the index of a task, or a list of strings/integers 435 | key: key to model.data_params["tasks"] in which the relevant task data is 436 | stored. "name" will be used by default. 437 | invert: Get indices for all tasks except those listed in tasks 438 | 439 | Returns: 440 | The index or indices of the corresponding task(s) in the model's 441 | output. 
442 |         """ 443 |         # If a string is provided, extract the index 444 |         if isinstance(tasks, str): 445 |             idxs = self.data_params["tasks"][key].index(tasks) 446 |         # If an integer is provided, use it as the index 447 |         elif isinstance(tasks, int): 448 |             idxs = tasks 449 |         # If a list is provided, get the index for each element 450 |         elif isinstance(tasks, list): 451 |             idxs = [self.get_task_idxs(task, key=key) for task in tasks] 452 |         else: 453 |             raise TypeError("Input must be a list, string or integer") 454 |         # Optionally invert the selection to return all other task indices 455 |         return [i for i in range(self.model_params["n_tasks"]) if i not in make_list(idxs)] if invert else idxs 456 | -------------------------------------------------------------------------------- /src/decima/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import Tensor, nn 4 | 5 | 6 | class TaskWisePoissonMultinomialLoss(nn.Module): 7 |     def __init__( 8 |         self, 9 |         total_weight: float = 1, 10 |         eps: float = 1e-7, 11 |         debug=False, 12 |     ) -> None: 13 |         super().__init__() 14 |         self.eps = eps 15 |         self.total_weight = total_weight 16 |         self.debug = debug 17 | 18 |     def forward(self, input: Tensor, target: Tensor) -> Tensor: 19 |         input = torch.exp(input).squeeze(-1)  # B, T 20 |         target = target.squeeze(-1)  # B, T 21 | 22 |         total_target = target.sum(axis=-1)  # B, 23 |         total_input = input.sum(axis=-1)  # B, 24 | 25 |         # total count poisson loss, mean across targets 26 |         poisson_term = F.poisson_nll_loss(total_input, total_target, log_input=False, reduction="mean")  # scalar 27 |         poisson_term = self.total_weight * poisson_term  # scalar 28 | 29 |         # Get multinomial probabilities 30 |         p_input = input / total_input.unsqueeze(1)  # B, T 31 |         log_p_input = torch.log(p_input)  # B, T 32 | 33 |         # multinomial loss 34 |         multinomial_dot = -torch.multiply(target, log_p_input)  # B x T 35 |         multinomial_term = multinomial_dot.mean() 36 | 37 |         # Combine 38 |         loss = multinomial_term + poisson_term 39 |         if self.debug: 40 |             print(f"Multinomial: {multinomial_term}, Poisson: {poisson_term}") 41 |         return loss 42 | -------------------------------------------------------------------------------- /src/decima/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from grelu.lightning.metrics import MSE 3 | from torchmetrics import Metric 4 | 5 | 6 | class DiseaseLfcMSE(Metric): 7 |     def __init__(self, pairs, average: bool = True) -> None: 8 |         super().__init__() 9 |         self.mse = MSE(num_outputs=1, average=False) 10 |         self.disease = pairs[:, 0] 11 |         self.healthy = pairs[:, 1] 12 |         self.average = average 13 | 14 |     def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: 15 |         pred_lfcs = preds[:, self.disease, 0] - preds[:, self.healthy, 0]  # B, T 16 |         target_lfcs = target[:, self.disease, 0] - target[:, self.healthy, 0]  # B, T 17 |         self.mse.update(pred_lfcs, target_lfcs) 18 | 19 |     def compute(self) -> torch.Tensor: 20 |         output = self.mse.compute() 21 |         if self.average: 22 |             return output.mean() 23 |         else: 24 |             return output 25 | 26 |     def reset(self) -> None: 27 |         self.mse.reset() 28 | -------------------------------------------------------------------------------- /src/decima/preprocess.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import anndata 4 | import bioframe as bf 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from grelu.io.genome import read_sizes 9 | from grelu.sequence.format import
intervals_to_strings, strings_to_one_hot 10 | from grelu.sequence.utils import get_unique_length 11 | from tqdm import tqdm 12 | 13 | 14 | def merge_transcripts(gtf): 15 | # Get gene-level columns 16 | genes = gtf[["chrom", "start", "end", "strand", "gene_id", "gene_type", "gene_name"]].copy() 17 | 18 | # Aggregate all features from the same gene 19 | genes = genes.groupby("gene_name").agg(lambda x: list(set(x))) 20 | 21 | # Get minimum start and maximum end 22 | genes["start"] = genes["start"].apply(min) 23 | genes["end"] = genes["end"].apply(max) 24 | 25 | # Merge all other columns 26 | for col in ["chrom", "strand", "gene_id", "gene_type"]: 27 | genes[col] = genes[col].apply(lambda x: x[0]) 28 | 29 | return genes 30 | 31 | 32 | def var_to_intervals(ad, chr_end_pad=10000, genome="hg38", seq_len=524288, crop_coords=163840): 33 | sizes = read_sizes(genome) 34 | 35 | # Calculate interval size 36 | print( 37 | f"The interval size is {seq_len} bases. Of these, {crop_coords} will be upstream of the gene start and {seq_len - crop_coords} will be downstream of the gene start." 38 | ) 39 | 40 | # Create intervals around + strand genes 41 | ad.var.loc[ad.var.strand == "+", "start"] = ad.var.loc[ad.var.strand == "+", "gene_start"] - crop_coords 42 | ad.var.loc[ad.var.strand == "+", "end"] = ad.var.loc[ad.var.strand == "+", "start"] + seq_len 43 | 44 | # Create interval around - strand genes 45 | ad.var.loc[ad.var.strand == "-", "end"] = ad.var.loc[ad.var.strand == "-", "gene_end"] + crop_coords 46 | ad.var.loc[ad.var.strand == "-", "start"] = ad.var.loc[ad.var.strand == "-", "end"] - seq_len 47 | 48 | # shift sequences with start < 0 49 | crossing_start = ad.var.start < chr_end_pad 50 | ad.var.loc[crossing_start, "start"] = chr_end_pad 51 | ad.var.loc[crossing_start, "end"] = ad.var.loc[crossing_start, "start"] + seq_len 52 | print(f"{np.sum(crossing_start)} intervals extended beyond the chromosome start and have been shifted") 53 | 54 | # shift sequences with end > chromosome size 55 | crossing_end = 0 56 | for chrom, size in sizes.values: 57 | max_end = size - chr_end_pad 58 | drop = (ad.var.chrom == chrom) & (ad.var.end > max_end) 59 | crossing_end += drop.sum() 60 | ad.var.loc[drop, "end"] = max_end 61 | ad.var.loc[drop, "start"] = ad.var.loc[drop, "end"] - seq_len 62 | print(f"{crossing_end} intervals extended beyond the chromosome end and have been shifted") 63 | 64 | # gene start position on output sequence 65 | ad.var.loc[ad.var.strand == "+", "gene_mask_start"] = ( 66 | ad.var.loc[ad.var.strand == "+", "gene_start"] - ad.var.loc[ad.var.strand == "+", "start"] 67 | ) 68 | ad.var.loc[ad.var.strand == "-", "gene_mask_start"] = ( 69 | ad.var.loc[ad.var.strand == "-", "end"] - ad.var.loc[ad.var.strand == "-", "gene_end"] 70 | ) 71 | ad.var.gene_mask_start = ad.var.gene_mask_start.astype(int) 72 | ad.var.gene_length = ad.var.gene_length.astype(int) 73 | 74 | # Get gene end position on sequence 75 | ad.var["gene_mask_end"] = (ad.var.gene_mask_start + ad.var.gene_length).apply(lambda x: min(seq_len, x)) 76 | ad.var.gene_mask_end = ad.var.gene_mask_end.astype(int) 77 | 78 | # Drop intervals with less than crop_coords upstream bases 79 | drop = ad.var.gene_mask_start < crop_coords 80 | ad = ad[:, ~drop] 81 | print(f"{np.sum(drop)} intervals did not extend far enough upstream of the TSS and have been dropped") 82 | 83 | # Check length 84 | assert get_unique_length(ad.var) == seq_len 85 | return ad 86 | 87 | 88 | def assign_borzoi_folds(ad, splits): 89 | # Extract gene intervals 90 | genes = 
ad.var.reset_index().rename(columns={"index": "gene_name"}) 91 | 92 | # Overlap with Borzoi splits 93 | overlaps = bf.overlap(genes, splits, how="left") 94 | overlaps = overlaps[["gene_id", "fold_"]].drop_duplicates() 95 | overlaps.columns = ["gene_id", "fold"] 96 | 97 | # List all overlapping folds for each gene 98 | overlaps = ( 99 | overlaps.groupby("gene_id") 100 | .fold.apply(list) 101 | .apply(lambda x: x if x[0] is None else ",".join([f[-1] for f in x])) 102 | ) 103 | overlaps = overlaps.reset_index() 104 | overlaps.loc[overlaps.fold.apply(lambda x: x[0] is None), "fold"] = "none" 105 | 106 | # Add back to AnnData 107 | ind = ad.var.index 108 | ad.var = ad.var.merge(overlaps, on="gene_id", how="left") 109 | ad.var.index = ind 110 | return ad 111 | 112 | 113 | def aggregate_anndata( 114 | ad, 115 | by_cols=[ 116 | "cell_type", 117 | "tissue", 118 | "organ", 119 | "disease", 120 | "study", 121 | "dataset", 122 | "region", 123 | "subregion", 124 | "celltype_coarse", 125 | ], 126 | sum_cols=["n_cells"], 127 | ): 128 | # Get column names 129 | obs_cols = by_cols + sum_cols 130 | gene_names = ad.var_names.tolist() 131 | 132 | # Format obs 133 | print("Creating new obs matrix") 134 | obs = ad.obs[obs_cols].copy() 135 | for col in by_cols: 136 | obs[col] = obs[col].astype(str) 137 | for col in sum_cols: 138 | obs[col] = obs[col].astype(int) 139 | 140 | # Create X 141 | X = pd.DataFrame(ad.X, index=obs.index.tolist(), columns=gene_names) 142 | X = pd.concat( 143 | [ 144 | obs, 145 | X, 146 | ], 147 | axis=1, 148 | ) 149 | 150 | # Aggregate X 151 | print("Aggregating") 152 | X = X.groupby(by_cols).sum().reset_index() 153 | 154 | # Split off the obs again 155 | obs = X[obs_cols] 156 | obs.index = [f"agg_{i}" for i in range(len(obs))] 157 | X = X[gene_names] 158 | 159 | # Create new anndata 160 | print("Creating new anndata") 161 | new_ad = anndata.AnnData(X=np.array(X), obs=obs, var=ad.var.copy()) 162 | return new_ad 163 | 164 | 165 | def change_values(df, col, value_dict): 166 | df[col] = df[col].astype(str) 167 | for k, v in value_dict.items(): 168 | df.loc[df[col] == k, col] = v 169 | return df 170 | 171 | 172 | def get_frac_N(interval, genome="hg38"): 173 | seq = intervals_to_strings(interval, genome=genome) 174 | return seq.count("N") / len(seq) 175 | 176 | 177 | def match_cellranger_2024(ad, genes24): 178 | matched = 0 179 | unmatched_genes = ad.var.index[ad.var.gene_id.isna()].tolist() 180 | print(f"{len(unmatched_genes)} genes unmatched.") 181 | for gene in tqdm(unmatched_genes): 182 | if gene in genes24.index.tolist(): 183 | gene_id = genes24.gene_id[genes24.index == gene].values[0] 184 | if gene_id not in ad.var.gene_id.tolist(): 185 | for col in ad.var.columns: 186 | ad.var.loc[gene, col] = genes24.loc[genes24.index == gene, col].values[0] 187 | matched += 1 188 | 189 | print(f"{matched} genes matched.") 190 | 191 | 192 | def match_ref_ad(ad, ref_ad): 193 | matched = 0 194 | unmatched_genes = ad.var.index[ad.var.gene_id.isna()].tolist() 195 | print(f"{len(unmatched_genes)} genes unmatched.") 196 | 197 | for gene in tqdm(unmatched_genes): 198 | if gene in ref_ad.var.index.tolist(): 199 | gene_id = ref_ad.var.gene_id[ref_ad.var.index == gene].values[0] 200 | if gene_id not in ad.var.gene_id.tolist(): 201 | for col in ad.var.columns: 202 | ad.var.loc[gene, col] = ref_ad.var.loc[ref_ad.var.index == gene, col].values[0] 203 | matched += 1 204 | 205 | print(f"{matched} genes matched.") 206 | 207 | 208 | def load_ncbi_string(string): 209 | out = [] 210 | reports = json.loads(string[0]) 211 
| 212 | # Check the total count 213 | if reports == {"total_count": 0}: 214 | pass 215 | else: 216 | for i, r in enumerate(reports["reports"]): 217 | try: 218 | curr_dict = {} 219 | if "query" in r: 220 | curr_dict["gene_name"] = r["query"][0] 221 | else: 222 | curr_dict["gene_name"] = r["gene"]["symbol"] 223 | r = r["gene"] 224 | curr_dict["symbol"] = r["symbol"] 225 | curr_dict["chrom"] = "chr" + r["chromosomes"][0] 226 | curr_dict["gene_type"] = r["type"] 227 | 228 | if "ensembl_gene_ids" in r: 229 | eids = r["ensembl_gene_ids"] 230 | if len(eids) > 1: 231 | pass 232 | else: 233 | curr_dict["gene_id"] = eids[0] 234 | else: 235 | curr_dict["gene_id"] = r["symbol"] 236 | 237 | for annot in r["annotations"]: 238 | if "assembly_name" in annot: 239 | if annot["assembly_name"] == "GRCh38.p14": 240 | curr_dict["start"] = annot["genomic_locations"][0]["genomic_range"]["begin"] 241 | curr_dict["end"] = annot["genomic_locations"][0]["genomic_range"]["end"] 242 | curr_dict["strand"] = ( 243 | "-" if annot["genomic_locations"][0]["genomic_range"]["orientation"] == "minus" else "+" 244 | ) 245 | 246 | out.append(curr_dict) 247 | except: 248 | print(i) 249 | 250 | out = pd.DataFrame(out) 251 | out = out[out.gene_name.isin(out.gene_name.value_counts()[out.gene_name.value_counts() == 1].index)] 252 | return out 253 | 254 | 255 | def match_ncbi(ad, ncbi): 256 | matched = 0 257 | unmatched_genes = ad.var.index[ad.var.gene_id.isna()].tolist() 258 | print(f"{len(unmatched_genes)} genes unmatched.") 259 | for gene in tqdm(unmatched_genes): 260 | if gene in ncbi.gene_name.tolist(): 261 | for col in ad.var.columns: 262 | ad.var.loc[gene, col] = ncbi.loc[ncbi.gene_name == gene, col].values[0] 263 | matched += 1 264 | 265 | print(f"{matched} genes matched.") 266 | 267 | 268 | def make_inputs(gene, ad): 269 | assert gene in ad.var_names, f"{gene} is not in the anndata object" 270 | row = ad.var.loc[gene] 271 | 272 | print("One-hot encoding sequence") 273 | seq = intervals_to_strings(row, genome="hg38") 274 | seq = strings_to_one_hot(seq) 275 | 276 | print("Making gene mask") 277 | mask = np.zeros(shape=(1, 524288)) 278 | mask[0, row.gene_mask_start : row.gene_mask_end] += 1 279 | mask = torch.Tensor(mask) 280 | 281 | return seq, mask 282 | -------------------------------------------------------------------------------- /src/decima/read_hdf5.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import pandas as pd 4 | import torch 5 | from grelu.data.augment import Augmenter, _split_overall_idx 6 | from grelu.sequence.format import BASE_TO_INDEX_HASH, indices_to_one_hot 7 | from torch.utils.data import Dataset 8 | 9 | from .preprocess import make_inputs 10 | 11 | 12 | def count_genes(h5_file, key=None): 13 | with h5py.File(h5_file, "r") as f: 14 | genes = np.array(f["genes"]).astype(str) 15 | if key is None: 16 | return genes.shape[0] 17 | else: 18 | return np.sum(genes[:, 1] == key) 19 | 20 | 21 | def index_genes(h5_file, key=None): 22 | with h5py.File(h5_file, "r") as f: 23 | genes = np.array(f["genes"]).astype(str) 24 | if key is None: 25 | return np.array(range(len(genes))) 26 | else: 27 | return np.where(genes[:, 1] == key)[0] 28 | 29 | 30 | def list_genes(h5_file, key=None): 31 | with h5py.File(h5_file, "r") as f: 32 | genes = np.array(f["genes"]).astype(str) 33 | if key is None: 34 | return genes[:, 0] 35 | else: 36 | return genes[genes[:, 1] == key, 0] 37 | 38 | 39 | def get_gene_idx(h5_file, gene, key=None): 40 | gene_ord = 
list_genes(h5_file, key=None) 41 | return np.where(gene_ord == gene)[0][0] 42 | 43 | 44 | def _extract_center(x, seq_len, shift=0): 45 | start = (x.shape[-1] - seq_len) // 2 46 | start -= shift 47 | return x[..., start : start + seq_len] 48 | 49 | 50 | def extract_gene_data(h5_file, gene, seq_len=524288, merge=True): 51 | gene_idx = get_gene_idx(h5_file, key=None, gene=gene) 52 | 53 | with h5py.File(h5_file, "r") as f: 54 | seq = np.array(f["sequences"][gene_idx]) 55 | seq = indices_to_one_hot(seq) 56 | mask = torch.Tensor(np.array(f["masks"][[gene_idx]])) 57 | 58 | seq = _extract_center(seq, seq_len=seq_len) 59 | mask = _extract_center(mask, seq_len=seq_len) 60 | 61 | if merge: 62 | return torch.vstack([seq, mask]) 63 | else: 64 | return seq, mask 65 | 66 | 67 | def mutate(seq, allele, pos): 68 | idx = BASE_TO_INDEX_HASH[allele] 69 | seq[:4, pos] = 0 70 | seq[idx, pos] = 1 71 | return seq 72 | 73 | 74 | class HDF5Dataset(Dataset): 75 | def __init__( 76 | self, 77 | key, 78 | h5_file, 79 | ad=None, 80 | seq_len=524288, 81 | max_seq_shift=0, 82 | seed=0, 83 | augment_mode="random", 84 | ): 85 | super().__init__() 86 | 87 | # Save data params 88 | self.h5_file = h5_file 89 | self.seq_len = seq_len 90 | self.key = key 91 | 92 | # Save augmentation params 93 | self.max_seq_shift = max_seq_shift 94 | self.augmenter = Augmenter( 95 | rc=False, 96 | max_seq_shift=self.max_seq_shift, 97 | max_pair_shift=0, 98 | seq_len=self.seq_len, 99 | label_len=None, 100 | seed=seed, 101 | mode=augment_mode, 102 | ) 103 | self.n_augmented = len(self.augmenter) 104 | self.padded_seq_len = self.seq_len + (2 * self.max_seq_shift) 105 | 106 | # Index genes 107 | self.gene_index = index_genes(self.h5_file, key=self.key) 108 | self.n_seqs = len(self.gene_index) 109 | 110 | # Setup 111 | self.dataset = h5py.File(self.h5_file, "r") 112 | self.extract_tasks(ad) 113 | self.predict = False 114 | self.n_alleles = 1 115 | 116 | def __len__(self): 117 | return self.n_seqs * self.n_augmented 118 | 119 | def close(self): 120 | self.dataset.close() 121 | 122 | def extract_tasks(self, ad=None): 123 | tasks = np.array(self.dataset["tasks"]).astype(str) 124 | if ad is not None: 125 | assert np.all(tasks == ad.obs_names) 126 | self.tasks = ad.obs 127 | else: 128 | self.tasks = pd.DataFrame(index=tasks) 129 | 130 | def extract_seq(self, idx): 131 | seq = self.dataset["sequences"][idx] 132 | seq = indices_to_one_hot(seq) # 4, L 133 | mask = self.dataset["masks"][[idx]] # 1, L 134 | seq = np.concatenate([seq, mask]) # 5, L 135 | seq = _extract_center(seq, seq_len=self.padded_seq_len) 136 | return torch.Tensor(seq) 137 | 138 | def extract_label(self, idx): 139 | return torch.Tensor(self.dataset["labels"][idx]) 140 | 141 | def __getitem__(self, idx): 142 | # Augment 143 | seq_idx, augment_idx = _split_overall_idx(idx, (self.n_seqs, self.n_augmented)) 144 | 145 | # Extract the sequence 146 | gene_idx = self.gene_index[seq_idx] 147 | seq = self.extract_seq(gene_idx) 148 | 149 | # Augment the sequence 150 | seq = self.augmenter(seq=seq, idx=augment_idx) 151 | 152 | if self.predict: 153 | return seq 154 | 155 | else: 156 | label = self.extract_label(gene_idx) 157 | return seq, label 158 | 159 | 160 | class VariantDataset(Dataset): 161 | def __init__( 162 | self, 163 | variants, 164 | h5_file=None, 165 | ad=None, 166 | seq_len=524288, 167 | max_seq_shift=0, 168 | test_ref=False, 169 | ): 170 | super().__init__() 171 | 172 | # Save data params 173 | self.seq_len = seq_len 174 | self.h5_file = h5_file 175 | 176 | # Save variant params 177 | 
self.variants = variants[["gene", "rel_pos", "ref_tx", "alt_tx"]].copy() 178 | self.n_seqs = len(self.variants) 179 | self.n_alleles = 2 180 | self.test_ref = test_ref 181 | 182 | # Save augmentation params 183 | self.max_seq_shift = max_seq_shift 184 | self.augmenter = Augmenter( 185 | rc=False, 186 | max_seq_shift=self.max_seq_shift, 187 | max_pair_shift=0, 188 | seq_len=self.seq_len, 189 | label_len=None, 190 | mode="serial", 191 | ) 192 | self.n_augmented = len(self.augmenter) 193 | self.padded_seq_len = self.seq_len + (2 * self.max_seq_shift) 194 | 195 | # Map each variant to the corresponding gene in the h5 file 196 | if ad is None: 197 | assert self.h5_file is not None 198 | gene_map = {gene: get_gene_idx(self.h5_file, gene) for gene in self.variants.gene.unique()} 199 | self.dataset = h5py.File(self.h5_file, "r") 200 | self.pad = self.dataset["pad"] 201 | else: 202 | gene_map = {gene: i for i, gene in enumerate(self.variants.gene.unique())} 203 | seqs, masks = list(zip(*[make_inputs(gene, ad) for gene in gene_map.keys()])) 204 | self.dataset = { 205 | "sequences": torch.stack(seqs), 206 | "masks": torch.vstack(masks), 207 | } 208 | self.pad = 0 209 | 210 | # Setup 211 | self.variants["gene_idx"] = self.variants.gene.map(gene_map) 212 | 213 | def __len__(self): 214 | return self.n_seqs * self.n_augmented * self.n_alleles 215 | 216 | def close(self): 217 | self.dataset.close() 218 | 219 | def extract_seq(self, idx): 220 | seq = self.dataset["sequences"][idx] 221 | if self.h5_file is not None: 222 | seq = indices_to_one_hot(seq) # 4, L 223 | mask = self.dataset["masks"][[idx]] # 1, L 224 | seq = np.concatenate([seq, mask]) # 5, L 225 | return torch.Tensor(seq) 226 | 227 | def __getitem__(self, idx): 228 | # Get indices 229 | seq_idx, augment_idx, allele_idx = _split_overall_idx(idx, (self.n_seqs, self.n_augmented, self.n_alleles)) 230 | 231 | # Extract the sequence 232 | variant = self.variants.iloc[seq_idx] 233 | seq = self.extract_seq(int(variant.gene_idx)) 234 | 235 | if self.test_ref: # check that ref is actually present 236 | assert ["A", "C", "G", "T"][seq[:4, variant.rel_pos + self.pad].argmax()] == variant.ref_tx, ( 237 | variant.ref_tx + "_vs_" + seq[:4, variant.rel_pos + self.pad] + "__" + str(seq_idx) 238 | ) 239 | 240 | # Insert the allele 241 | if allele_idx: 242 | seq = mutate(seq, variant.alt_tx, variant.rel_pos + self.pad) 243 | else: 244 | seq = mutate(seq, variant.ref_tx, variant.rel_pos + self.pad) 245 | 246 | # Augment the sequence 247 | seq = _extract_center(seq, seq_len=self.padded_seq_len) 248 | seq = self.augmenter(seq=seq, idx=augment_idx) 249 | return seq 250 | -------------------------------------------------------------------------------- /src/decima/variant.py: -------------------------------------------------------------------------------- 1 | from grelu.sequence.utils import get_unique_length, reverse_complement 2 | 3 | 4 | def process_variants(variants, ad, min_from_end=0): 5 | # Match to gene intervals 6 | orig_len = len(variants) 7 | variants = variants[variants.gene.isin(ad.var_names)] 8 | print(f"dropped {orig_len - len(variants)} variants because the gene was not found in ad.var") 9 | variants = variants.merge( 10 | ad.var[["start", "end", "strand", "gene_mask_start"]], 11 | left_on="gene", 12 | right_index=True, 13 | how="left", 14 | ) 15 | 16 | # Get relative position 17 | variants["rel_pos"] = variants.apply( 18 | lambda row: row.pos - row.start if row.strand == "+" else row.end - row.pos, 19 | axis=1, 20 | ) 21 | 22 | # Filter by relative 
position 23 | orig_len = len(variants) 24 | interval_len = get_unique_length(ad.var) 25 | variants = variants[(variants.rel_pos > min_from_end) & (variants.rel_pos < interval_len - min_from_end)] 26 | print(f"dropped {orig_len - len(variants)} variants because the variant did not fit in the interval") 27 | 28 | # Reverse complement the alleles for - strand genes 29 | variants["ref_tx"] = variants.apply( 30 | lambda row: row.ref if row.strand == "+" else reverse_complement(row.ref), 31 | axis=1, 32 | ) 33 | variants["alt_tx"] = variants.apply( 34 | lambda row: row.alt if row.strand == "+" else reverse_complement(row.alt), 35 | axis=1, 36 | ) 37 | 38 | # Get distance from TSS 39 | variants["tss_dist"] = variants.rel_pos - variants.gene_mask_start 40 | 41 | return variants 42 | -------------------------------------------------------------------------------- /src/decima/visualize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import plotnine as p9 4 | from grelu.visualize import plot_attributions 5 | 6 | from .evaluate import match_criteria 7 | 8 | 9 | def plot_logo(motif, rc=False, figsize=(2, 1)): 10 | m = np.array(motif.frequencies).T 11 | ic = 2 + (m * np.log2(m)).sum(0) 12 | m = m * np.expand_dims(ic, 0) 13 | if rc: 14 | m = np.flip(m, (0, 1)) 15 | return plot_attributions(m, figsize=figsize) 16 | 17 | 18 | def plot_gene_scatter( 19 | ad, 20 | gene, 21 | show_corr=True, 22 | show_abline=False, 23 | size=0.1, 24 | alpha=0.5, 25 | figure_size=(3, 2.5), 26 | corr_col="pearson", 27 | corrx=None, 28 | corry=None, 29 | ): 30 | # Extract data 31 | df = pd.DataFrame( 32 | { 33 | "true": ad[:, gene].X.squeeze(), 34 | "pred": ad[:, gene].layers["preds"].squeeze(), 35 | }, 36 | index=ad.obs_names, 37 | ) 38 | 39 | # Make plot 40 | g = ( 41 | p9.ggplot(df, p9.aes(x="true", y="pred")) 42 | + p9.geom_pointdensity(size=size, alpha=alpha) 43 | + p9.theme_classic() 44 | + p9.theme(figure_size=figure_size) 45 | + p9.ggtitle(gene) 46 | + p9.xlab("True expression") 47 | + p9.ylab("Predicted expression") 48 | + p9.theme(plot_title=p9.element_text(face="italic")) 49 | ) 50 | 51 | # Compute correlation 52 | if show_corr: 53 | corr = np.round(ad.var.loc[gene, corr_col], 2) 54 | corrx = corrx or 1 55 | corry = corry or df.pred.max() - 0.5 56 | g = g + p9.geom_text(x=corrx, y=corry, label=f"rho={corr}") 57 | 58 | if show_abline: 59 | # Add diagonal 60 | g = g + p9.geom_abline(intercept=0, slope=1) 61 | 62 | return g 63 | 64 | 65 | def plot_track_scatter( 66 | ad, 67 | track, 68 | off_track=None, 69 | show_corr=True, 70 | show_abline=False, 71 | size=0.1, 72 | alpha=0.5, 73 | figure_size=(3, 2.5), 74 | corrx=None, 75 | corry=None, 76 | ): 77 | # Extract data 78 | df = pd.DataFrame( 79 | { 80 | "true": ad[track, :].X.squeeze(), 81 | "pred": ad[track, :].layers["preds"].squeeze(), 82 | }, 83 | index=ad.var_names, 84 | ) 85 | if off_track is not None: 86 | df["true"] = df["true"] - ad[off_track, :].X.squeeze() 87 | df["pred"] = df["pred"] - ad[off_track, :].layers["preds"].squeeze() 88 | 89 | # Make plot 90 | g = ( 91 | p9.ggplot(df, p9.aes(x="true", y="pred")) 92 | + p9.geom_pointdensity(size=size, alpha=alpha) 93 | + p9.theme_classic() 94 | + p9.theme(figure_size=figure_size) 95 | ) 96 | 97 | if off_track is None: 98 | g = g + p9.xlab("True expression") + p9.ylab("Predicted expression") 99 | else: 100 | g = g + p9.xlab("True log FC") + p9.ylab("Predicted log FC") 101 | 102 | # Compute correlation 103 | if show_corr: 104 | corr = 
np.round(df.corr().iloc[0, 1], 2) 105 | corrx = corrx or 1 106 | corry = corry or df.pred.max() - 0.5 107 | g = g + p9.geom_text(x=corrx, y=corry, label=f"rho={corr}") 108 | 109 | if show_abline: 110 | # Add diagonal 111 | g = g + p9.geom_abline(intercept=0, slope=1) 112 | 113 | return g 114 | 115 | 116 | def plot_marker_box( 117 | gene, 118 | ad, 119 | marker_features, 120 | label_name="label", 121 | split_col=None, 122 | split_values=None, 123 | order=None, 124 | include_preds=True, 125 | fill=True, 126 | ): 127 | # Get criteria to filter 128 | if isinstance(marker_features, list): 129 | marker_features = ad.var.loc[gene, marker_features].to_dict() 130 | filter_df = pd.DataFrame(marker_features) 131 | 132 | # Collect observations for this gene 133 | to_plot = pd.DataFrame( 134 | { 135 | "True": ad[:, gene].X.squeeze(), 136 | "Predicted": ad[:, gene].layers["preds"].squeeze(), 137 | label_name: "Other", 138 | }, 139 | index=ad.obs.index, 140 | ) 141 | 142 | # Get matching observations 143 | labels = filter_df.apply(lambda row: "_".join(row.dropna()), axis=1).tolist() 144 | for i in range(len(filter_df)): 145 | match = match_criteria(df=ad.obs, filter_df=filter_df.iloc[[i]].copy()) 146 | to_plot.loc[match, label_name] = labels[i] 147 | 148 | # Choose background organs 149 | if split_col is not None: 150 | to_plot[split_col] = ad.obs[split_col].tolist() 151 | split_values = split_values or to_plot[split_col].unique() 152 | for spl in split_values: 153 | to_plot.loc[ 154 | (to_plot[split_col] == spl) & (to_plot[label_name] == "Other"), 155 | label_name, 156 | ] = f"Other {spl}" 157 | to_plot = to_plot.iloc[:, :3] 158 | 159 | if include_preds: 160 | # Reorder the factor levels based on the median value 161 | to_plot = to_plot.melt(id_vars=label_name) 162 | if order is None: 163 | order = to_plot.groupby(label_name)["value"].median().sort_values(ascending=False).index.tolist() 164 | to_plot[label_name] = pd.Categorical(to_plot[label_name], categories=order) 165 | 166 | # Plot 167 | to_plot.variable = pd.Categorical(to_plot.variable, categories=["True", "Predicted"]) 168 | g = ( 169 | p9.ggplot(to_plot, p9.aes(x="variable", y="value", fill=label_name)) 170 | + p9.geom_boxplot(outlier_size=0.1) 171 | + p9.theme_classic() 172 | + p9.theme(figure_size=(3, 2.5)) 173 | + p9.ggtitle(gene) 174 | + p9.ylab("Expression") 175 | + p9.theme(plot_title=p9.element_text(face="italic")) 176 | + p9.theme(axis_title_x=p9.element_blank()) 177 | ) 178 | 179 | else: 180 | # Reorder the factor levels based on the median value 181 | if order is None: 182 | order = to_plot.groupby(label_name)["True"].median().sort_values(ascending=False).index.tolist() 183 | to_plot[label_name] = pd.Categorical(to_plot[label_name], categories=order) 184 | 185 | # Plot 186 | if fill: 187 | g = p9.ggplot(to_plot, p9.aes(x=label_name, y="True", fill=label_name)) 188 | else: 189 | g = p9.ggplot(to_plot, p9.aes(x=label_name, y="True")) 190 | g = ( 191 | g 192 | + p9.geom_boxplot(outlier_size=0.1) 193 | + p9.theme_classic() 194 | + p9.theme(figure_size=(3, 2.5)) 195 | + p9.ylab("Measured Expression") 196 | + p9.ggtitle(gene) 197 | + p9.theme(plot_title=p9.element_text(face="italic")) 198 | + p9.theme(axis_title_x=p9.element_blank()) 199 | + p9.theme(axis_text_x=p9.element_text(rotation=60, hjust=1)) 200 | ) 201 | return g 202 | 203 | 204 | def plot_attribution_peaks(attr, tss_pos): 205 | to_plot = pd.DataFrame( 206 | { 207 | "distance from TSS": [x - tss_pos for x in range(attr.shape[1])], 208 | "attribution": attr.mean(0), 209 | } 210 | ) 
211 | g = ( 212 | p9.ggplot(to_plot, p9.aes(x="distance from TSS", y="attribution")) 213 | + p9.geom_line() 214 | + p9.theme_classic() 215 | + p9.theme(figure_size=(6, 2)) 216 | ) 217 | 218 | return g 219 | -------------------------------------------------------------------------------- /src/decima/write_hdf5.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | from grelu.sequence.format import convert_input_type 4 | from grelu.sequence.utils import get_unique_length 5 | 6 | 7 | def write_hdf5(file, ad, pad=0): 8 | # Calculate seq_len 9 | seq_len = get_unique_length(ad.var) 10 | 11 | with h5py.File(file, "w") as f: 12 | # Metadata 13 | print("Writing metadata") 14 | f.create_dataset("pad", shape=(), data=pad) 15 | f.create_dataset("seq_len", shape=(), data=seq_len) 16 | padded_seq_len = seq_len + 2 * pad 17 | f.create_dataset("padded_seq_len", shape=(), data=padded_seq_len) 18 | 19 | # Tasks 20 | print("Writing tasks") 21 | tasks = np.array(ad.obs.index) 22 | f.create_dataset("tasks", shape=tasks.shape, data=tasks) 23 | 24 | # Genes 25 | arr = np.array(ad.var[["dataset"]].reset_index()) 26 | print(f"Writing genes array of shape: {arr.shape}") 27 | f.create_dataset("genes", shape=arr.shape, data=arr) 28 | 29 | # Labels 30 | arr = np.expand_dims(ad.X.T.astype(np.float32), 2) 31 | print(f"Writing labels array of shape: {arr.shape}") 32 | f.create_dataset("labels", shape=arr.shape, dtype=np.float32, data=arr) 33 | 34 | # Gene masks 35 | print("Making gene masks") 36 | shape = (ad.var.shape[0], padded_seq_len) 37 | arr = np.zeros(shape=shape) 38 | for i, row in enumerate(ad.var.itertuples()): 39 | arr[i, row.gene_mask_start + pad : row.gene_mask_end + pad] += 1 40 | print(f"Writing mask array of shape: {arr.shape}") 41 | f.create_dataset("masks", shape=shape, dtype=np.float32, data=arr) 42 | 43 | # Sequences 44 | print("Encoding sequences") 45 | arr = ad.var[["chrom", "start", "end", "strand"]].copy() 46 | arr.start = arr.start - pad 47 | arr.end = arr.end + pad 48 | arr = convert_input_type(arr, "indices", genome="hg38") 49 | print(f"Writing sequence array of shape: {arr.shape}") 50 | f.create_dataset("sequences", shape=arr.shape, dtype=np.int8, data=arr) 51 | 52 | print("Done!") 53 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dummy conftest.py for decima. 3 | 4 | If you don't know what this is for, just leave it empty. 5 | Read more about conftest.py under: 6 | - https://docs.pytest.org/en/stable/fixture.html 7 | - https://docs.pytest.org/en/stable/writing_plugins.html 8 | """ 9 | 10 | # import pytest 11 | -------------------------------------------------------------------------------- /tests/test_package.py: -------------------------------------------------------------------------------- 1 | import decima 2 | 3 | def test_package_load(): 4 | pass -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # Tox configuration file 2 | # Read more under https://tox.wiki/ 3 | # THIS SCRIPT IS SUPPOSED TO BE AN EXAMPLE. MODIFY IT ACCORDING TO YOUR NEEDS! 
4 | 5 | [tox] 6 | minversion = 3.24 7 | envlist = default 8 | isolated_build = True 9 | 10 | 11 | [testenv] 12 | description = Invoke pytest to run automated tests 13 | setenv = 14 | TOXINIDIR = {toxinidir} 15 | passenv = 16 | HOME 17 | SETUPTOOLS_* 18 | extras = 19 | testing 20 | commands = 21 | pytest {posargs} 22 | 23 | 24 | # # To run `tox -e lint` you need to make sure you have a 25 | # # `.pre-commit-config.yaml` file. See https://pre-commit.com 26 | # [testenv:lint] 27 | # description = Perform static analysis and style checks 28 | # skip_install = True 29 | # deps = pre-commit 30 | # passenv = 31 | # HOMEPATH 32 | # PROGRAMDATA 33 | # SETUPTOOLS_* 34 | # commands = 35 | # pre-commit run --all-files {posargs:--show-diff-on-failure} 36 | 37 | 38 | [testenv:{build,clean}] 39 | description = 40 | build: Build the package in isolation according to PEP517, see https://github.com/pypa/build 41 | clean: Remove old distribution files and temporary build artifacts (./build and ./dist) 42 | # https://setuptools.pypa.io/en/stable/build_meta.html#how-to-use-it 43 | skip_install = True 44 | changedir = {toxinidir} 45 | deps = 46 | build: build[virtualenv] 47 | passenv = 48 | SETUPTOOLS_* 49 | commands = 50 | clean: python -c 'import shutil; [shutil.rmtree(p, True) for p in ("build", "dist", "docs/_build")]' 51 | clean: python -c 'import pathlib, shutil; [shutil.rmtree(p, True) for p in pathlib.Path("src").glob("*.egg-info")]' 52 | build: python -m build {posargs} 53 | # By default, both `sdist` and `wheel` are built. If your sdist is too big or you don't want 54 | # to make it available, consider running: `tox -e build -- --wheel` 55 | 56 | 57 | [testenv:{docs,doctests,linkcheck}] 58 | description = 59 | docs: Invoke sphinx-build to build the docs 60 | doctests: Invoke sphinx-build to run doctests 61 | linkcheck: Check for broken links in the documentation 62 | passenv = 63 | SETUPTOOLS_* 64 | setenv = 65 | DOCSDIR = {toxinidir}/docs 66 | BUILDDIR = {toxinidir}/docs/_build 67 | docs: BUILD = html 68 | doctests: BUILD = doctest 69 | linkcheck: BUILD = linkcheck 70 | deps = 71 | -r {toxinidir}/docs/requirements.txt 72 | # ^ requirements.txt shared with Read The Docs 73 | commands = 74 | sphinx-build --color -b {env:BUILD} -d "{env:BUILDDIR}/doctrees" "{env:DOCSDIR}" "{env:BUILDDIR}/{env:BUILD}" {posargs} 75 | 76 | 77 | [testenv:publish] 78 | description = 79 | Publish the package you have been developing to a package index server. 80 | By default, it uses testpypi. If you really want to publish your package 81 | to be publicly accessible in PyPI, use the `-- --repository pypi` option. 82 | skip_install = True 83 | changedir = {toxinidir} 84 | passenv = 85 | # See: https://twine.readthedocs.io/en/latest/ 86 | TWINE_USERNAME 87 | TWINE_PASSWORD 88 | TWINE_REPOSITORY 89 | TWINE_REPOSITORY_URL 90 | deps = twine 91 | commands = 92 | python -m twine check dist/* 93 | python -m twine upload {posargs:--repository {env:TWINE_REPOSITORY:testpypi}} dist/* 94 | -------------------------------------------------------------------------------- /tutorials/variants.tsv: -------------------------------------------------------------------------------- 1 | chrom pos ref alt gene rsid 2 | chr1 1000018 G A ISG15 rs146254088 3 | chr1 1002308 T C ISG15 rs2489000 4 | chr1 109727471 A C GSTM3 rs11101994 5 | chr1 109728286 T G GSTM3 rs4540683 6 | chr1 109728807 T G GSTM3 rs4970775 7 | --------------------------------------------------------------------------------
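Usage sketch (appended for illustration; not a file in the repository): the snippet below shows how the modules above can fit together to score the tutorial variants with a trained checkpoint. It is a minimal sketch under stated assumptions — the checkpoint path ("decima.ckpt"), the preprocessed "aggregated.h5ad" and "data.h5" files (as produced by the preprocessing and write_hdf5 steps), and the presence of task metadata in the checkpoint's data_params are assumptions, not guarantees.

```python
import anndata
import pandas as pd

from decima.lightning import LightningModel
from decima.read_hdf5 import VariantDataset
from decima.variant import process_variants

# Load the pseudobulk AnnData and the tutorial variant table
# (columns: chrom, pos, ref, alt, gene, rsid).
ad = anndata.read_h5ad("aggregated.h5ad")  # assumed path
variants = pd.read_csv("tutorials/variants.tsv", sep="\t")

# Map each variant onto its gene's input interval; this adds rel_pos,
# ref_tx, alt_tx and tss_dist columns and drops variants that fall
# outside the modeled window.
variants = process_variants(variants, ad)

# Load a trained model and build a dataset yielding ref/alt sequence pairs.
model = LightningModel.load_from_checkpoint("decima.ckpt").eval()  # assumed checkpoint
ds = VariantDataset(variants, h5_file="data.h5", max_seq_shift=0)

# predict_on_dataset returns (alt - ref) predictions: one row per variant,
# one column per pseudobulk track.
effects = model.predict_on_dataset(ds, devices=0, batch_size=4, num_workers=4)
ds.close()

# Label the columns with track names; assumes the checkpoint was written by
# train_on_dataset, which stores task metadata in data_params.
tracks = model.data_params["tasks"]["name"]
effects = pd.DataFrame(effects, index=variants.rsid, columns=tracks)
print(effects.head())
```

The same model and dataset classes are used by scripts/predict_genes.py for per-gene predictions; the variant path differs only in that VariantDataset substitutes the reference and alternate alleles into the gene's sequence before prediction.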