├── .gitattributes
├── .github
│   ├── rubric.png
│   ├── workflows
│   │   ├── join-release.yaml
│   │   ├── tag-release.yaml
│   │   ├── ci-dev.yaml
│   │   ├── add-to-project-ci.yml
│   │   └── update-copyright-headers.yaml
│   ├── ISSUE_TEMPLATE
│   │   ├── 2-dev-need-help.md
│   │   ├── 1-user-need-help.md
│   │   ├── 5-make-new.md
│   │   ├── 4-make-better.md
│   │   ├── 3-found-bug.md
│   │   └── 6-where-to-go.md
│   ├── pull_request_template.md
│   ├── CONTRIBUTING.md
│   └── SUPPORT.md
├── q2_feature_classifier
│   ├── tests
│   │   ├── data
│   │   │   ├── class_weight.biom
│   │   │   ├── blast6-format.tsv
│   │   │   ├── dna-sequences.fasta
│   │   │   ├── dna-sequences-mixed.fasta
│   │   │   ├── dna-sequences-reverse.fasta
│   │   │   ├── dna_sequence_both_test.fasta
│   │   │   ├── dna-sequences-degenerate-primers.fasta
│   │   │   └── query-seqs.fasta
│   │   ├── __init__.py
│   │   ├── test_custom.py
│   │   ├── test_taxonomic_classifier.py
│   │   ├── test_cutter.py
│   │   ├── test_classifier.py
│   │   └── test_consensus_assignment.py
│   ├── types
│   │   ├── __init__.py
│   │   ├── _type.py
│   │   └── _format.py
│   ├── __init__.py
│   ├── plugin_setup.py
│   ├── citations.bib
│   ├── _taxonomic_classifier.py
│   ├── custom.py
│   ├── _skl.py
│   ├── _consensus_assignment.py
│   ├── _cutter.py
│   ├── _blast.py
│   ├── classifier.py
│   └── _vsearch.py
├── README.md
├── .coveragerc
├── Makefile
├── .copier-answers.yml
├── .gitignore
├── conda-recipe
│   └── meta.yaml
├── pyproject.toml
└── LICENSE

/.gitattributes:
--------------------------------------------------------------------------------
1 | pyproject.toml export-subst
2 | 
--------------------------------------------------------------------------------
/.github/rubric.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiime2/q2-feature-classifier/HEAD/.github/rubric.png
--------------------------------------------------------------------------------
/q2_feature_classifier/tests/data/class_weight.biom:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiime2/q2-feature-classifier/HEAD/q2_feature_classifier/tests/data/class_weight.biom
--------------------------------------------------------------------------------
/.github/workflows/join-release.yaml:
--------------------------------------------------------------------------------
1 | name: join-release
2 | on:
3 |   workflow_dispatch: {}
4 | jobs:
5 |   release:
6 |     uses: qiime2/distributions/.github/workflows/lib-join-release.yaml@dev
--------------------------------------------------------------------------------
/.github/workflows/tag-release.yaml:
--------------------------------------------------------------------------------
1 | name: tag-release
2 | on:
3 |   push:
4 |     branches: ["Release-*"]
5 | jobs:
6 |   tag:
7 |     uses: qiime2/distributions/.github/workflows/lib-tag-release.yaml@dev
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # q2-feature-classifier
2 | 
3 | ![](https://github.com/qiime2/q2-feature-classifier/workflows/ci-dev/badge.svg)
4 | 
5 | This is a QIIME 2 plugin. For details on QIIME 2, see https://qiime2.org.
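
A minimal sketch of driving the plugin from Python via the QIIME 2 Artifact API, mirroring how this repository's own tests invoke the actions (the `.qza` file names here are hypothetical placeholders):

```python
from qiime2.sdk import Artifact
from qiime2.plugins import feature_classifier

# Hypothetical inputs: any FeatureData[Sequence] and TaxonomicClassifier
# artifacts will do.
reads = Artifact.load('rep-seqs.qza')
classifier = Artifact.load('classifier.qza')

# classify_sklearn assigns taxonomy to each sequence using the trained
# scikit-learn pipeline wrapped by the classifier artifact.
result = feature_classifier.methods.classify_sklearn(reads, classifier)
result.classification.save('taxonomy.qza')
```

The same actions are also available on the command line under `qiime feature-classifier`.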
-------------------------------------------------------------------------------- /q2_feature_classifier/tests/data/blast6-format.tsv: -------------------------------------------------------------------------------- 1 | 1111561 1111561 100.000 75 0 0 1 75 44 118 7.50e-34 139 2 | 1111561 574274 92.308 78 2 4 1 75 44 120 2.11e-24 108 3 | 835097 835097 100.000 80 0 0 1 80 52 131 1.36e-36 148 4 | -------------------------------------------------------------------------------- /q2_feature_classifier/tests/data/dna-sequences.fasta: -------------------------------------------------------------------------------- 1 | >Sequence1 2 | AGAGAACGTGCAGG 3 | >Sequence2 4 | AGAGAAAGTGCAGG 5 | >Sequence3 6 | AGAGAACCTGCAGG 7 | >Sequence4 8 | AGAGAACGGGCAGG 9 | >Sequence5 10 | AGAGAACTTGCAGG 11 | -------------------------------------------------------------------------------- /q2_feature_classifier/tests/data/dna-sequences-mixed.fasta: -------------------------------------------------------------------------------- 1 | >Sequence1 2 | AGAGAACGTGCAGG 3 | >Sequence2 4 | CCTGCACTTTCTCT 5 | >Sequence3 6 | AGAGAACCTGCAGG 7 | >Sequence4 8 | AGAGAACGGGCAGG 9 | >Sequence5 10 | CCTGCAAGTTCTCT 11 | -------------------------------------------------------------------------------- /q2_feature_classifier/tests/data/dna-sequences-reverse.fasta: -------------------------------------------------------------------------------- 1 | >Sequence1 2 | CCTGCACGTTCTCT 3 | >Sequence2 4 | CCTGCACTTTCTCT 5 | >Sequence3 6 | CCTGCAGGTTCTCT 7 | >Sequence4 8 | CCTGCCCGTTCTCT 9 | >Sequence5 10 | CCTGCAAGTTCTCT 11 | -------------------------------------------------------------------------------- /q2_feature_classifier/tests/data/dna_sequence_both_test.fasta: -------------------------------------------------------------------------------- 1 | >DNA_SEQUENCE_1 2 | AGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAG 3 | >DNA_SEQUENCE_2 4 | CTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCT 5 | 6 | -------------------------------------------------------------------------------- /q2_feature_classifier/tests/data/dna-sequences-degenerate-primers.fasta: -------------------------------------------------------------------------------- 1 | >Sequence1 2 | ATATAACGTGCCGG 3 | >Sequence2 4 | ATATAAAGTGCCGG 5 | >Sequence3 6 | ATATAACCTGCCGG 7 | >Sequence4 8 | ATATAACGGGCCGG 9 | >Sequence5 10 | ATATAACTTGCCGG 11 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | omit = 4 | */tests* 5 | */__init__.py 6 | q2_feature_classifier/_version.py 7 | versioneer.py 8 | 9 | [report] 10 | omit = 11 | */tests* 12 | */__init__.py 13 | q2_feature_classifier/_version.py 14 | versioneer.py 15 | -------------------------------------------------------------------------------- /q2_feature_classifier/tests/data/query-seqs.fasta: -------------------------------------------------------------------------------- 1 | >1111561 2 | ACACATGCAAGTCGAACGGCAGCGGGGGAAAGCTTGCTTTCCTGCCGGCGAGTGGCGGACGGGTGAGTAATGCGT 3 | >835097 4 | AAGTCGAGCGAAAGACCCCGGGCTTGCCCGGGTGATTTAGCGGCGGACGGCTGAGTAACACGTGAGAAACTTGCCCTTAG 5 | >junk 6 | AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA 7 | -------------------------------------------------------------------------------- /.github/workflows/ci-dev.yaml: -------------------------------------------------------------------------------- 1 
| # Example of workflow trigger for calling workflow (the client). 2 | name: ci-dev 3 | on: 4 | pull_request: 5 | branches: ["dev"] 6 | push: 7 | branches: ["dev"] 8 | jobs: 9 | ci: 10 | uses: qiime2/distributions/.github/workflows/lib-ci-dev.yaml@dev 11 | with: 12 | distro: amplicon 13 | recipe-path: 'conda-recipe' 14 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all lint test test-cov install dev clean distclean 2 | 3 | PYTHON ?= python 4 | 5 | all: ; 6 | 7 | lint: 8 | q2lint 9 | flake8 10 | 11 | test: all 12 | py.test 13 | 14 | test-cov: all 15 | py.test --cov=q2_feature_classifier 16 | 17 | install: all 18 | $(PYTHON) -m pip install -v . 19 | 20 | dev: all 21 | pip install -e . 22 | 23 | clean: distclean 24 | 25 | distclean: ; 26 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/2-dev-need-help.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: I am a developer and I need help with QIIME 2... 3 | about: I am developing a QIIME 2 plugin or interface and have a question or a problem 4 | 5 | --- 6 | 7 | Have you had a chance to check out the developer docs? 8 | https://dev.qiime2.org 9 | There are many tutorials, walkthroughs, and guides available. 10 | 11 | If you still need help, please visit: 12 | https://forum.qiime2.org/c/dev-discussion 13 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | Brief summary of the Pull Request, including any issues it may fix using the GitHub closing syntax: 2 | 3 | https://help.github.com/articles/closing-issues-using-keywords/ 4 | 5 | Also, include any co-authors or contributors using the GitHub coauthor tag: 6 | 7 | https://help.github.com/articles/creating-a-commit-with-multiple-authors/ 8 | 9 | --- 10 | 11 | Include any questions for reviewers, screenshots, sample outputs, etc. 12 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/1-user-need-help.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: I am a user and I need help with QIIME 2... 3 | about: I am using QIIME 2 and have a question or am experiencing a problem 4 | 5 | --- 6 | 7 | Have you had a chance to check out the docs? 8 | https://docs.qiime2.org 9 | There are many tutorials, walkthroughs, and guides available. 10 | 11 | If you still need help, please visit: 12 | https://forum.qiime2.org/c/user-support 13 | 14 | Help requests filed here will not be answered. 
15 | -------------------------------------------------------------------------------- /.copier-answers.yml: -------------------------------------------------------------------------------- 1 | # Changes here will be overwritten by Copier; NEVER EDIT MANUALLY 2 | _commit: dfb0404 3 | _src_path: https://github.com/qiime2/q2-setup-template.git 4 | module_name: feature_classifier 5 | plugin_name: q2_feature_classifier 6 | plugin_scripts: null 7 | project_author_email: kaehler@gmail.com 8 | project_author_name: Ben Kaehler 9 | project_description: Functionality for taxonomic classification 10 | project_name: q2-feature-classifier 11 | project_urls_homepage: https://qiime2.org 12 | project_urls_repository: https://github.com/qiime2/q2-feature-classifier 13 | -------------------------------------------------------------------------------- /q2_feature_classifier/types/__init__.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright (c) 2016-2025, QIIME 2 development team. 3 | # 4 | # Distributed under the terms of the Modified BSD License. 5 | # 6 | # The full license is in the file LICENSE, distributed with this software. 7 | # ---------------------------------------------------------------------------- 8 | 9 | from ._format import BLASTDBFileFmtV5, BLASTDBDirFmtV5 10 | from ._type import BLASTDB 11 | 12 | 13 | __all__ = ['BLASTDBFileFmtV5', 'BLASTDBDirFmtV5', 'BLASTDB'] 14 | -------------------------------------------------------------------------------- /.github/workflows/add-to-project-ci.yml: -------------------------------------------------------------------------------- 1 | name: Add new issues and PRs to triage project board 2 | 3 | on: 4 | issues: 5 | types: 6 | - opened 7 | pull_request_target: 8 | types: 9 | - opened 10 | 11 | jobs: 12 | add-to-project: 13 | name: Add issue to project 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/add-to-project@v0.3.0 17 | with: 18 | project-url: https://github.com/orgs/qiime2/projects/36 19 | github-token: ${{ secrets.ADD_TO_PROJECT_PAT }} 20 | labeled: skip-triage 21 | label-operator: NOT 22 | -------------------------------------------------------------------------------- /q2_feature_classifier/types/_type.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright (c) 2016-2025, QIIME 2 development team. 3 | # 4 | # Distributed under the terms of the Modified BSD License. 5 | # 6 | # The full license is in the file LICENSE, distributed with this software. 7 | # ---------------------------------------------------------------------------- 8 | 9 | from qiime2.plugin import SemanticType 10 | from . import BLASTDBDirFmtV5 11 | from ..plugin_setup import plugin 12 | 13 | 14 | BLASTDB = SemanticType('BLASTDB') 15 | 16 | plugin.register_semantic_types(BLASTDB) 17 | plugin.register_artifact_class(BLASTDB, BLASTDBDirFmtV5) 18 | -------------------------------------------------------------------------------- /q2_feature_classifier/__init__.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright (c) 2016-2025, QIIME 2 development team. 3 | # 4 | # Distributed under the terms of the Modified BSD License. 
5 | # 6 | # The full license is in the file LICENSE, distributed with this software. 7 | # ---------------------------------------------------------------------------- 8 | 9 | import importlib 10 | 11 | try: 12 | from ._version import __version__ 13 | except ModuleNotFoundError: 14 | __version__ = '0.0.0+notfound' 15 | 16 | importlib.import_module('q2_feature_classifier.types') 17 | importlib.import_module('q2_feature_classifier.classifier') 18 | importlib.import_module('q2_feature_classifier._cutter') 19 | importlib.import_module('q2_feature_classifier._blast') 20 | importlib.import_module('q2_feature_classifier._vsearch') 21 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to this project 2 | 3 | Thanks for thinking of us :heart: :tada: - we would love a helping hand! 4 | 5 | ## I just have a question 6 | 7 | > Note: Please don't file an issue to ask a question. You'll get faster results 8 | > by using the resources below. 9 | 10 | ### QIIME 2 Users 11 | 12 | Check out the [User Docs](https://docs.qiime2.org) - there are many tutorials, 13 | walkthroughs, and guides available. If you still need help, please visit us at 14 | the [QIIME 2 Forum](https://forum.qiime2.org/c/user-support). 15 | 16 | ### QIIME 2 Developers 17 | 18 | Check out the [Developer Docs](https://dev.qiime2.org) - there are many 19 | tutorials, walkthroughs, and guides available. If you still need help, please 20 | visit us at the [QIIME 2 Forum](https://forum.qiime2.org/c/dev-discussion). 21 | 22 | This document is based heavily on the following: 23 | https://github.com/atom/atom/blob/master/CONTRIBUTING.md 24 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/5-make-new.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: I am a developer and I have an idea for a new feature... 3 | about: I am a developer and I have an idea for new functionality 4 | 5 | --- 6 | 7 | **Addition Description** 8 | A clear and concise description of what the addition is. 9 | 10 | **Current Behavior** 11 | Please provide a brief description of the current behavior, if applicable. 12 | 13 | **Proposed Behavior** 14 | Please provide a brief description of the proposed behavior. 15 | 16 | **Questions** 17 | 1. An enumerated list of questions related to the proposal. 18 | 2. If not applicable, please delete this section. 19 | 20 | **Comments** 21 | 1. An enumerated list of comments related to the proposal that don't fit anywhere else. 22 | 2. If not applicable, please delete this section. 23 | 24 | **References** 25 | 1. An enumerated list of links to relevant references, including forum posts, stack overflow, etc. 26 | 2. If not applicable, please delete this section. 27 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/4-make-better.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: I am a developer and I have an idea for an improvement... 3 | about: I am a developer and I have an idea for an improvement to existing functionality 4 | 5 | --- 6 | 7 | **Improvement Description** 8 | A clear and concise description of what the improvement is. 9 | 10 | **Current Behavior** 11 | Please provide a brief description of the current behavior. 
12 | 13 | **Proposed Behavior** 14 | Please provide a brief description of the proposed behavior. 15 | 16 | **Questions** 17 | 1. An enumerated list of questions related to the proposal. 18 | 2. If not applicable, please delete this section. 19 | 20 | **Comments** 21 | 1. An enumerated list of comments related to the proposal that don't fit anywhere else. 22 | 2. If not applicable, please delete this section. 23 | 24 | **References** 25 | 1. An enumerated list of links to relevant references, including forum posts, stack overflow, etc. 26 | 2. If not applicable, please delete this section. 27 | -------------------------------------------------------------------------------- /q2_feature_classifier/plugin_setup.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright (c) 2016-2025, QIIME 2 development team. 3 | # 4 | # Distributed under the terms of the Modified BSD License. 5 | # 6 | # The full license is in the file LICENSE, distributed with this software. 7 | # ---------------------------------------------------------------------------- 8 | 9 | from qiime2.plugin import Plugin, Citations 10 | 11 | import q2_feature_classifier 12 | 13 | citations = Citations.load('citations.bib', package='q2_feature_classifier') 14 | plugin = Plugin( 15 | name='feature-classifier', 16 | version=q2_feature_classifier.__version__, 17 | website='https://github.com/qiime2/q2-feature-classifier', 18 | package='q2_feature_classifier', 19 | description=('This QIIME 2 plugin supports taxonomic ' 20 | 'classification of features using a variety ' 21 | 'of methods, including Naive Bayes, vsearch, ' 22 | 'and BLAST+.'), 23 | short_description='Plugin for taxonomic classification.', 24 | citations=[citations['bokulich2018optimizing']] 25 | ) 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | 55 | # Sphinx documentation 56 | docs/_build/ 57 | 58 | # PyBuilder 59 | target/ 60 | 61 | #Ipython Notebook 62 | .ipynb_checkpoints 63 | 64 | # vi 65 | .*.swp 66 | 67 | .DS_Store 68 | 69 | # Version file from versioningit 70 | _version.py 71 | 72 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/3-found-bug.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: I am a developer and I found a bug... 
3 | about: I am a developer and I found a bug that I can describe 4 | 5 | --- 6 | 7 | **Bug Description** 8 | A clear and concise description of what the bug is. 9 | 10 | **Steps to reproduce the behavior** 11 | 1. Go to '...' 12 | 2. Click on '....' 13 | 3. Scroll down to '....' 14 | 4. See error 15 | 16 | **Expected behavior** 17 | A clear and concise description of what you expected to happen. 18 | 19 | **Screenshots** 20 | If applicable, add screenshots to help explain your problem. 21 | 22 | **Computation Environment** 23 | - OS: [e.g. macOS High Sierra] 24 | - QIIME 2 Release [e.g. 2018.6] 25 | 26 | **Questions** 27 | 1. An enumerated list with any questions about the problem here. 28 | 2. If not applicable, please delete this section. 29 | 30 | **Comments** 31 | 1. An enumerated list with any other context or comments about the problem here. 32 | 2. If not applicable, please delete this section. 33 | 34 | **References** 35 | 1. An enumerated list of links to relevant references, including forum posts, stack overflow, etc. 36 | 2. If not applicable, please delete this section. 37 | -------------------------------------------------------------------------------- /conda-recipe/meta.yaml: -------------------------------------------------------------------------------- 1 | package: 2 | name: q2-feature-classifier 3 | version: {{ PLUGIN_VERSION }} 4 | source: 5 | path: .. 6 | build: 7 | script: make install 8 | requirements: 9 | host: 10 | - python {{ python }} 11 | - setuptools 12 | - versioningit 13 | - wheel 14 | run: 15 | - python {{ python }} 16 | - scikit-learn {{ scikit_learn }} 17 | - joblib 18 | - scikit-bio {{ scikit_bio }} 19 | - biom-format {{ biom_format }} 20 | - blast {{ blast }} 21 | - vsearch {{ vsearch }} 22 | - qiime2 >={{ qiime2 }} 23 | - q2-types >={{ q2_types }} 24 | - q2-quality-control >={{ q2_quality_control }} 25 | - q2-taxa >={{ q2_taxa }} 26 | - q2-feature-table >={{ q2_feature_table }} 27 | build: 28 | - python {{ python }} 29 | - setuptools 30 | - versioningit 31 | test: 32 | requires: 33 | - qiime2 >={{ qiime2 }} 34 | - q2-types >={{ q2_types }} 35 | - q2-quality-control >={{ q2_quality_control }} 36 | - q2-taxa >={{ q2_taxa }} 37 | - q2-feature-table >={{ q2_feature_table }} 38 | - pytest 39 | imports: 40 | - q2_feature_classifier 41 | - qiime2.plugins.feature_classifier 42 | commands: 43 | - py.test --pyargs q2_feature_classifier 44 | about: 45 | home: https://qiime2.org 46 | license: BSD-3-Clause 47 | license_family: BSD 48 | -------------------------------------------------------------------------------- /q2_feature_classifier/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright (c) 2016-2025, QIIME 2 development team. 3 | # 4 | # Distributed under the terms of the Modified BSD License. 5 | # 6 | # The full license is in the file LICENSE, distributed with this software. 
7 | # ---------------------------------------------------------------------------- 8 | 9 | import tempfile 10 | import shutil 11 | from warnings import filterwarnings 12 | 13 | from qiime2.plugin.testing import TestPluginBase 14 | 15 | 16 | class FeatureClassifierTestPluginBase(TestPluginBase): 17 | package = 'q2_feature_classifier.tests' 18 | 19 | def setUp(self): 20 | try: 21 | from q2_feature_classifier.plugin_setup import plugin 22 | except ImportError: 23 | self.fail("Could not import plugin object.") 24 | 25 | self.plugin = plugin 26 | 27 | self.temp_dir = tempfile.TemporaryDirectory( 28 | prefix='q2-feature-classifier-test-temp-') 29 | 30 | filterwarnings('ignore', 'The TaxonomicClassifier ', UserWarning) 31 | 32 | def _setup_dir(self, filenames, dirfmt): 33 | for filename in filenames: 34 | filepath = self.get_data_path(filename) 35 | shutil.copy(filepath, self.temp_dir.name) 36 | 37 | return dirfmt(self.temp_dir.name, mode='r') 38 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "q2-feature-classifier" 3 | authors = [ 4 | { name = "Ben Kaehler", email = "kaehler@gmail.com" } 5 | ] 6 | description = "Functionality for taxonomic classification" 7 | readme = {file = "README.md", content-type = "text/markdown"} 8 | license = {file = "LICENSE"} 9 | dynamic = ["version"] 10 | 11 | [project.urls] 12 | Homepage = "https://qiime2.org" 13 | Repository = "https://github.com/qiime2/q2-feature-classifier" 14 | 15 | [project.entry-points.'qiime2.plugins'] 16 | "q2-feature-classifier" = "q2_feature_classifier.plugin_setup:plugin" 17 | 18 | [build-system] 19 | requires = [ 20 | "setuptools", 21 | "versioningit", 22 | "wheel" 23 | ] 24 | build-backend = "setuptools.build_meta" 25 | 26 | [tool.versioningit.vcs] 27 | method = "git-archive" 28 | describe-subst = "2026.1.0.dev0" 29 | default-tag = "0.0.1" 30 | 31 | [tool.versioningit.next-version] 32 | method = "minor" 33 | 34 | [tool.versioningit.format] 35 | distance = "{base_version}+{distance}.{vcs}{rev}" 36 | dirty = "{base_version}+{distance}.{vcs}{rev}.dirty" 37 | distance-dirty = "{base_version}+{distance}.{vcs}{rev}.dirty" 38 | 39 | [tool.versioningit.write] 40 | file = "q2_feature_classifier/_version.py" 41 | 42 | [tool.setuptools] 43 | include-package-data = true 44 | 45 | [tool.setuptools.packages.find] 46 | where = ["."] 47 | include = ["q2_feature_classifier*"] 48 | 49 | [tool.setuptools.package-data] 50 | q2_feature_classifier = ["**/*"] 51 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2016-2025, QIIME 2 development team. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 
15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /q2_feature_classifier/citations.bib: -------------------------------------------------------------------------------- 1 | @article{bokulich2018optimizing, 2 | author={Bokulich, Nicholas A. and Kaehler, Benjamin D. and Rideout, Jai Ram and Dillon, Matthew and Bolyen, Evan and Knight, Rob and Huttley, Gavin A. and Caporaso, J. Gregory}, 3 | title={Optimizing taxonomic classification of marker-gene amplicon sequences with QIIME 2's q2-feature-classifier plugin}, 4 | journal={Microbiome}, 5 | year={2018}, 6 | volume={6}, 7 | number={1}, 8 | pages={90}, 9 | doi={10.1186/s40168-018-0470-z}, 10 | url={https://doi.org/10.1186/s40168-018-0470-z} 11 | } 12 | 13 | @article{pedregosa2011scikit, 14 | title={Scikit-learn: Machine learning in Python}, 15 | author={Pedregosa, Fabian and Varoquaux, Ga{\"e}l and Gramfort, Alexandre and Michel, Vincent and Thirion, Bertrand and Grisel, Olivier and Blondel, Mathieu and Prettenhofer, Peter and Weiss, Ron and Dubourg, Vincent and Vanderplas, Jake and Passos, Alexandre and Cournapeau, David and Brucher, Matthieu and Perrot, Matthieu and Duchesnay, {\'E}douard}, 16 | journal={Journal of machine learning research}, 17 | volume={12}, 18 | number={Oct}, 19 | pages={2825--2830}, 20 | year={2011} 21 | } 22 | 23 | @article{camacho2009blast+, 24 | title={BLAST+: architecture and applications}, 25 | author={Camacho, Christiam and Coulouris, George and Avagyan, Vahram and Ma, Ning and Papadopoulos, Jason and Bealer, Kevin and Madden, Thomas L}, 26 | journal={BMC bioinformatics}, 27 | volume={10}, 28 | number={1}, 29 | pages={421}, 30 | year={2009}, 31 | publisher={BioMed Central}, 32 | doi={10.1186/1471-2105-10-421} 33 | } 34 | 35 | @article{rognes2016vsearch, 36 | title={VSEARCH: a versatile open source tool for metagenomics}, 37 | author={Rognes, Torbj{\o}rn and Flouri, Tom{\'a}{\v{s}} and Nichols, Ben and Quince, Christopher and Mah{\'e}, Fr{\'e}d{\'e}ric}, 38 | journal={PeerJ}, 39 | volume={4}, 40 | pages={e2584}, 41 | year={2016}, 42 | publisher={PeerJ Inc.}, 43 | doi={10.7717/peerj.2584} 44 | } 45 | -------------------------------------------------------------------------------- /.github/workflows/update-copyright-headers.yaml: -------------------------------------------------------------------------------- 1 | name: Update Copyright Headers 2 | 3 | on: 4 | # Runs at 00:00 UTC on Jan 4th or via manual trigger 5 | schedule: 6 | - cron: '0 0 4 1 *' 7 | workflow_dispatch: 8 | 
    inputs:
9 |       newYear:
10 |         description: "Desired year to update to (e.g., 2025). If not provided, will auto-detect."
11 |         required: false
12 | 
13 | jobs:
14 |   update-headers:
15 |     runs-on: ubuntu-latest
16 |     steps:
17 |       - name: Check out repository
18 |         uses: actions/checkout@v3
19 | 
20 |       - name: Determine years for update
21 |         id: determine-years
22 |         run: |
23 |           INPUT_YEAR="${{ github.event.inputs.newYear }}"
24 | 
25 |           if [ -z "$INPUT_YEAR" ]; then
26 |             CURRENT_YEAR="$(date +'%Y')"
27 |             echo "No 'newYear' input. Using current year: ${CURRENT_YEAR}"
28 |           else
29 |             CURRENT_YEAR="$INPUT_YEAR"
30 |             echo "Received user input. Updating to: ${CURRENT_YEAR}"
31 |           fi
32 | 
33 |           echo "CURRENT_YEAR=$CURRENT_YEAR" >> $GITHUB_ENV
34 | 
35 |       - name: Bump ending year in QIIME 2 headers
36 |         run: |
37 |           source $GITHUB_ENV
38 | 
39 |           echo "Will update any QIIME 2 header years in [2015..$((CURRENT_YEAR-1))] to $CURRENT_YEAR"
40 | 
41 |           for OLD_YEAR in $(seq 2015 $((CURRENT_YEAR - 1))); do
42 |             find . -type f -exec \
43 |               sed -i -E "s/(Copyright \(c\) [0-9]{4})-${OLD_YEAR}, QIIME 2/\1-${CURRENT_YEAR}, QIIME 2/g" {} +
44 |           done
45 | 
46 |       - name: Commit and push changes
47 |         run: |
48 |           git config --global user.name "q2d2"
49 |           git config --global user.email "q2d2.noreply@gmail.com"
50 | 
51 |           if [ -n "$(git status --porcelain)" ]; then
52 |             git add .
53 |             git commit -m "Auto-update copyright year to $CURRENT_YEAR"
54 |             git push
55 |           else
56 |             echo "No changes to commit."
57 |             exit 0
58 |           fi
59 | 
--------------------------------------------------------------------------------
/q2_feature_classifier/types/_format.py:
--------------------------------------------------------------------------------
1 | # ----------------------------------------------------------------------------
2 | # Copyright (c) 2016-2025, QIIME 2 development team.
3 | #
4 | # Distributed under the terms of the Modified BSD License.
5 | #
6 | # The full license is in the file LICENSE, distributed with this software.
7 | # ----------------------------------------------------------------------------
8 | 
9 | import os
10 | import itertools
11 | from qiime2.plugin import model
12 | from ..plugin_setup import plugin, citations
13 | 
14 | 
15 | class BLASTDBFileFmtV5(model.BinaryFileFormat):
16 |     # We do not have a good way to validate the individual blastdb files.
17 |     # TODO: could wire up `blastdbcheck` to do a deep check when level=max
18 |     # but this must be done on the directory, not individual files.
19 |     # For now validation will be done at the DirFmt level on file extensions.
20 |     def _validate_(self, level):
21 |         pass
22 | 
23 | 
24 | class BLASTDBDirFmtV5(model.DirectoryFormat):
25 |     # TODO: is there a more robust way to do this/make some files optional?
26 |     # Some file extensions were introduced with more recent versions of
27 |     # blast, but are not actually needed for our purposes. Making these
28 |     # optional would allow more flexibility in blast versions, avoiding
29 |     # possible dependency conflicts.
30 |     # NOTE that the .n?? extensions are also nucleotide database specific.
31 |     # should we rather call the type/formats BLASTNucDB*?
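    # Illustration (hypothetical database name): building a database with
    # `makeblastdb -dbtype nucl -out ref-db ...` should yield ref-db.ndb,
    # ref-db.nhr, ..., ref-db.nto (plus ref-db.njs on newer blast), one file
    # per extension pattern declared below. get_basename() then recovers
    # 'ref-db': _get_prefix() finds the common prefix 'ref-db.n' and
    # os.path.splitext() strips the leftover '.n' tail.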
32 |     idx1 = model.File(r'.+\.ndb', format=BLASTDBFileFmtV5)
33 |     idx2 = model.File(r'.+\.nhr', format=BLASTDBFileFmtV5)
34 |     idx3 = model.File(r'.+\.nin', format=BLASTDBFileFmtV5)
35 |     idx4 = model.File(r'.+\.not', format=BLASTDBFileFmtV5)
36 |     idx5 = model.File(r'.+\.nsq', format=BLASTDBFileFmtV5)
37 |     idx6 = model.File(r'.+\.ntf', format=BLASTDBFileFmtV5)
38 |     idx7 = model.File(r'.+\.nto', format=BLASTDBFileFmtV5)
39 |     # introduced in blast 2.13.0
40 |     # https://ncbiinsights.ncbi.nlm.nih.gov/2022/03/29/blast-2-13-0/
41 |     idx8 = model.File(r'.+\.njs', format=BLASTDBFileFmtV5)
42 | 
43 |     # borrowed from q2-types
44 |     def get_basename(self):
45 |         paths = [str(x.relative_to(self.path)) for x in self.path.iterdir()]
46 |         prefix = os.path.splitext(_get_prefix(paths))[0]
47 |         return prefix
48 | 
49 | 
50 | # SO: https://stackoverflow.com/a/6718380/579416
51 | def _get_prefix(strings):
52 |     def all_same(x):
53 |         return all(x[0] == y for y in x)
54 | 
55 |     char_tuples = zip(*strings)
56 |     prefix_tuples = itertools.takewhile(all_same, char_tuples)
57 |     return ''.join(x[0] for x in prefix_tuples)
58 | 
59 | 
60 | plugin.register_views(BLASTDBDirFmtV5,
61 |                       citations=[citations['camacho2009blast+']])
62 | 
--------------------------------------------------------------------------------
/q2_feature_classifier/tests/test_custom.py:
--------------------------------------------------------------------------------
1 | # ----------------------------------------------------------------------------
2 | # Copyright (c) 2016-2025, QIIME 2 development team.
3 | #
4 | # Distributed under the terms of the Modified BSD License.
5 | #
6 | # The full license is in the file LICENSE, distributed with this software.
7 | # ----------------------------------------------------------------------------
8 | 
9 | import json
10 | 
11 | from qiime2.sdk import Artifact
12 | from qiime2.plugins import feature_classifier
13 | from q2_types.feature_data import DNAIterator
14 | from q2_feature_classifier.custom import ChunkedHashingVectorizer
15 | from q2_feature_classifier._skl import _extract_reads
16 | from sklearn.feature_extraction.text import HashingVectorizer
17 | import pandas as pd
18 | 
19 | from .
import FeatureClassifierTestPluginBase 20 | 21 | 22 | class CustomTests(FeatureClassifierTestPluginBase): 23 | package = 'q2_feature_classifier.tests' 24 | 25 | def setUp(self): 26 | super().setUp() 27 | self.taxonomy = Artifact.import_data( 28 | 'FeatureData[Taxonomy]', self.get_data_path('taxonomy.tsv')) 29 | 30 | def test_low_memory_multinomial_nb(self): 31 | # results should not depend on chunk size 32 | fitter = feature_classifier.methods.fit_classifier_sklearn 33 | classify = feature_classifier.methods.classify_sklearn 34 | reads = Artifact.import_data( 35 | 'FeatureData[Sequence]', 36 | self.get_data_path('se-dna-sequences.fasta')) 37 | 38 | spec = [['feat_ext', 39 | {'__type__': 'feature_extraction.text.HashingVectorizer', 40 | 'analyzer': 'char', 41 | 'n_features': 8192, 42 | 'ngram_range': (8, 8), 43 | 'alternate_sign': False}], 44 | ['classify', 45 | {'__type__': 'custom.LowMemoryMultinomialNB', 46 | 'alpha': 0.01, 47 | 'chunk_size': 20000}]] 48 | 49 | classifier_spec = json.dumps(spec) 50 | result = fitter(reads, self.taxonomy, classifier_spec) 51 | result = classify(reads, result.classifier) 52 | gc = result.classification.view(pd.Series).to_dict() 53 | 54 | spec[1][1]['chunk_size'] = 20 55 | classifier_spec = json.dumps(spec) 56 | result = fitter(reads, self.taxonomy, classifier_spec) 57 | result = classify(reads, result.classifier) 58 | sc = result.classification.view(pd.Series).to_dict() 59 | 60 | for taxon in gc: 61 | self.assertEqual(gc[taxon], sc[taxon]) 62 | 63 | def test_chunked_hashing_vectorizer(self): 64 | # results should not depend on chunk size 65 | _, X = _extract_reads(Artifact.import_data( 66 | 'FeatureData[Sequence]', 67 | self.get_data_path('se-dna-sequences.fasta')).view(DNAIterator)) 68 | 69 | params = {'analyzer': 'char', 70 | 'n_features': 8192, 71 | 'ngram_range': (8, 8), 72 | 'alternate_sign': False} 73 | hv = HashingVectorizer(**params) 74 | unchunked = hv.fit_transform(X) 75 | 76 | for chunk_size in (-1, 3, 13): 77 | chv = ChunkedHashingVectorizer(chunk_size=chunk_size, **params) 78 | chunked = chv.fit_transform(X) 79 | for x1, x2 in zip(chunked, unchunked): 80 | self.assertTrue((x1.todense() == x2.todense()).all()) 81 | -------------------------------------------------------------------------------- /q2_feature_classifier/_taxonomic_classifier.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright (c) 2016-2025, QIIME 2 development team. 3 | # 4 | # Distributed under the terms of the Modified BSD License. 5 | # 6 | # The full license is in the file LICENSE, distributed with this software. 
7 | # ---------------------------------------------------------------------------- 8 | 9 | import json 10 | import tarfile 11 | import os 12 | 13 | import sklearn 14 | import joblib 15 | from sklearn.pipeline import Pipeline 16 | import qiime2.plugin 17 | import qiime2.plugin.model as model 18 | 19 | from .plugin_setup import plugin 20 | 21 | 22 | # Semantic Types 23 | TaxonomicClassifier = qiime2.plugin.SemanticType('TaxonomicClassifier') 24 | 25 | 26 | # Formats 27 | class PickleFormat(model.BinaryFileFormat): 28 | def sniff(self): 29 | return tarfile.is_tarfile(str(self)) 30 | 31 | 32 | # https://github.com/qiime2/q2-types/issues/49 33 | class JSONFormat(model.TextFileFormat): 34 | def sniff(self): 35 | with self.open() as fh: 36 | try: 37 | json.load(fh) 38 | return True 39 | except json.JSONDecodeError: 40 | pass 41 | return False 42 | 43 | 44 | class TaxonomicClassifierDirFmt(model.DirectoryFormat): 45 | preprocess_params = model.File('preprocess_params.json', format=JSONFormat) 46 | sklearn_pipeline = model.File('sklearn_pipeline.tar', format=PickleFormat) 47 | 48 | 49 | class TaxonomicClassiferTemporaryPickleDirFmt(model.DirectoryFormat): 50 | version_info = model.File('sklearn_version.json', format=JSONFormat) 51 | sklearn_pipeline = model.File('sklearn_pipeline.tar', format=PickleFormat) 52 | 53 | 54 | # Transformers 55 | @plugin.register_transformer 56 | def _1(dirfmt: TaxonomicClassiferTemporaryPickleDirFmt) -> Pipeline: 57 | sklearn_version = dirfmt.version_info.view(dict)['sklearn-version'] 58 | if sklearn_version != sklearn.__version__: 59 | raise ValueError('The scikit-learn version (%s) used to generate this' 60 | ' artifact does not match the current version' 61 | ' of scikit-learn installed (%s). Please retrain your' 62 | ' classifier for your current deployment to prevent' 63 | ' data-corruption errors.' 
64 | % (sklearn_version, sklearn.__version__)) 65 | 66 | sklearn_pipeline = dirfmt.sklearn_pipeline.view(PickleFormat) 67 | 68 | with tarfile.open(str(sklearn_pipeline)) as tar: 69 | tmpdir = model.DirectoryFormat() 70 | dirname = str(tmpdir) 71 | tar.extractall(dirname) 72 | pipeline = joblib.load(os.path.join(dirname, 'sklearn_pipeline.pkl')) 73 | for fn in tar.getnames(): 74 | os.unlink(os.path.join(dirname, fn)) 75 | 76 | return pipeline 77 | 78 | 79 | @plugin.register_transformer 80 | def _2(data: Pipeline) -> TaxonomicClassiferTemporaryPickleDirFmt: 81 | sklearn_pipeline = PickleFormat() 82 | with tarfile.open(str(sklearn_pipeline), 'w') as tar: 83 | tmpdir = model.DirectoryFormat() 84 | pf = os.path.join(str(tmpdir), 'sklearn_pipeline.pkl') 85 | for fn in joblib.dump(data, pf): 86 | tar.add(fn, os.path.basename(fn)) 87 | os.unlink(fn) 88 | 89 | dirfmt = TaxonomicClassiferTemporaryPickleDirFmt() 90 | dirfmt.version_info.write_data( 91 | {'sklearn-version': sklearn.__version__}, dict) 92 | dirfmt.sklearn_pipeline.write_data(sklearn_pipeline, PickleFormat) 93 | 94 | return dirfmt 95 | 96 | 97 | @plugin.register_transformer 98 | def _3(dirfmt: TaxonomicClassifierDirFmt) -> Pipeline: 99 | raise ValueError('The scikit-learn version could not be determined for' 100 | ' this artifact, please retrain your classifier for your' 101 | ' current deployment to prevent data-corruption errors.') 102 | 103 | 104 | @plugin.register_transformer 105 | def _4(fmt: JSONFormat) -> dict: 106 | with fmt.open() as fh: 107 | return json.load(fh) 108 | 109 | 110 | @plugin.register_transformer 111 | def _5(data: dict) -> JSONFormat: 112 | result = JSONFormat() 113 | with result.open() as fh: 114 | json.dump(data, fh) 115 | return result 116 | 117 | 118 | # Registrations 119 | plugin.register_semantic_types(TaxonomicClassifier) 120 | plugin.register_formats(TaxonomicClassifierDirFmt, 121 | TaxonomicClassiferTemporaryPickleDirFmt) 122 | plugin.register_semantic_type_to_format( 123 | TaxonomicClassifier, 124 | artifact_format=TaxonomicClassiferTemporaryPickleDirFmt) 125 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/6-where-to-go.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: I don't know where to file my issue... 3 | about: I am a developer and I don't know which repo to file this in 4 | 5 | --- 6 | 7 | The repos within the QIIME 2 GitHub Organization are listed below, with a brief description about the repo. 8 | 9 | Sorted alphabetically by repo name. 
10 | 11 | - The CI automation engine that builds and distributes QIIME 2 12 | https://github.com/qiime2/busywork/issues 13 | 14 | - A Concourse resource for working with conda 15 | https://github.com/qiime2/conda-channel-resource/issues 16 | 17 | - Web app for vanity URLs for QIIME 2 data assets 18 | https://github.com/qiime2/data.qiime2.org/issues 19 | 20 | - The Developer Documentation 21 | https://github.com/qiime2/dev-docs/issues 22 | 23 | - A discourse plugin for handling queued/unqueued topics 24 | https://github.com/qiime2/discourse-unhandled-tagger/issues 25 | 26 | - The User Documentation 27 | https://github.com/qiime2/docs/issues 28 | 29 | - Rendered QIIME 2 environment files for conda 30 | https://github.com/qiime2/environment-files/issues 31 | 32 | - Google Sheets Add-On for validating tabular data 33 | https://github.com/qiime2/Keemei/issues 34 | 35 | - A docker image for linux-based busywork workers 36 | https://github.com/qiime2/linux-worker-docker/issues 37 | 38 | - Official project logos 39 | https://github.com/qiime2/logos/issues 40 | 41 | - The q2-alignment plugin 42 | https://github.com/qiime2/q2-alignment/issues 43 | 44 | - The q2-composition plugin 45 | https://github.com/qiime2/q2-composition/issues 46 | 47 | - The q2-cutadapt plugin 48 | https://github.com/qiime2/q2-cutadapt/issues 49 | 50 | - The q2-dada2 plugin 51 | https://github.com/qiime2/q2-dada2/issues 52 | 53 | - The q2-deblur plugin 54 | https://github.com/qiime2/q2-deblur/issues 55 | 56 | - The q2-demux plugin 57 | https://github.com/qiime2/q2-demux/issues 58 | 59 | - The q2-diversity plugin 60 | https://github.com/qiime2/q2-diversity/issues 61 | 62 | - The q2-diversity-lib plugin 63 | https://github.com/qiime2/q2-diversity-lib/issues 64 | 65 | - The q2-emperor plugin 66 | https://github.com/qiime2/q2-emperor/issues 67 | 68 | - The q2-feature-classifier plugin 69 | https://github.com/qiime2/q2-feature-classifier/issues 70 | 71 | - The q2-feature-table plugin 72 | https://github.com/qiime2/q2-feature-table/issues 73 | 74 | - The q2-fragment-insertion plugin 75 | https://github.com/qiime2/q2-fragment-insertion/issues 76 | 77 | - The q2-gneiss plugin 78 | https://github.com/qiime2/q2-gneiss/issues 79 | 80 | - The q2-longitudinal plugin 81 | https://github.com/qiime2/q2-longitudinal/issues 82 | 83 | - The q2-metadata plugin 84 | https://github.com/qiime2/q2-metadata/issues 85 | 86 | - The q2-phylogeny plugin 87 | https://github.com/qiime2/q2-phylogeny/issues 88 | 89 | - The q2-quality-control plugin 90 | https://github.com/qiime2/q2-quality-control/issues 91 | 92 | - The q2-quality-filter plugin 93 | https://github.com/qiime2/q2-quality-filter/issues 94 | 95 | - The q2-sample-classifier plugin 96 | https://github.com/qiime2/q2-sample-classifier/issues 97 | 98 | - The q2-shogun plugin 99 | https://github.com/qiime2/q2-shogun/issues 100 | 101 | - The q2-taxa plugin 102 | https://github.com/qiime2/q2-taxa/issues 103 | 104 | - The q2-types plugin 105 | https://github.com/qiime2/q2-types/issues 106 | 107 | - The q2-vsearch plugin 108 | https://github.com/qiime2/q2-vsearch/issues 109 | 110 | - The CLI interface 111 | https://github.com/qiime2/q2cli/issues 112 | 113 | - The prototype CWL interface 114 | https://github.com/qiime2/q2cwl/issues 115 | 116 | - The prototype Galaxy interface 117 | https://github.com/qiime2/q2galaxy/issues 118 | 119 | - An internal tool for ensuring header text and copyrights are present 120 | https://github.com/qiime2/q2lint/issues 121 | 122 | - The prototype GUI interface 123 | 
https://github.com/qiime2/q2studio/issues 124 | 125 | - A base template for use in official QIIME 2 plugins 126 | https://github.com/qiime2/q2templates/issues 127 | 128 | - The read-only web interface at view.qiime2.org 129 | https://github.com/qiime2/q2view/issues 130 | 131 | - The QIIME 2 homepage at qiime2.org 132 | https://github.com/qiime2/qiime2.github.io/issues 133 | 134 | - The QIIME 2 framework 135 | https://github.com/qiime2/qiime2/issues 136 | 137 | - Centralized templates for repo assets 138 | https://github.com/qiime2/template-repo/issues 139 | 140 | - Scripts for building QIIME 2 VMs 141 | https://github.com/qiime2/vm-playbooks/issues 142 | 143 | - Scripts for building QIIME 2 workshop clusters 144 | https://github.com/qiime2/workshop-playbooks/issues 145 | 146 | - The web app that runs workshops.qiime2.org 147 | https://github.com/qiime2/workshops.qiime2.org/issues 148 | -------------------------------------------------------------------------------- /q2_feature_classifier/custom.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright (c) 2016-2025, QIIME 2 development team. 3 | # 4 | # Distributed under the terms of the Modified BSD License. 5 | # 6 | # The full license is in the file LICENSE, distributed with this software. 7 | # ---------------------------------------------------------------------------- 8 | 9 | from itertools import islice 10 | 11 | import numpy 12 | from scipy.sparse import vstack 13 | from sklearn.base import BaseEstimator, ClassifierMixin, clone # noqa 14 | from sklearn.utils.validation import check_X_y, check_array, check_is_fitted # noqa 15 | from sklearn.naive_bayes import MultinomialNB 16 | from sklearn.preprocessing import LabelEncoder 17 | from sklearn.feature_extraction.text import HashingVectorizer 18 | 19 | 20 | class LowMemoryMultinomialNB(MultinomialNB): 21 | def __init__(self, alpha=1.0, fit_prior=True, class_prior=None, 22 | chunk_size=20000): 23 | self.chunk_size = chunk_size 24 | super().__init__(alpha=alpha, fit_prior=fit_prior, 25 | class_prior=class_prior) 26 | 27 | def fit(self, X, y, sample_weight=None): 28 | if self.chunk_size <= 0: 29 | return super().fit(X, y, sample_weight=sample_weight) 30 | 31 | classes = numpy.unique(y) 32 | for i in range(0, X.shape[0], self.chunk_size): 33 | upper = min(i+self.chunk_size, X.shape[0]) 34 | cX = X[i:upper] 35 | cy = y[i:upper] 36 | if sample_weight is None: 37 | csample_weight = None 38 | else: 39 | csample_weight = sample_weight[i:upper] 40 | self.partial_fit(cX, cy, sample_weight=csample_weight, 41 | classes=classes) 42 | 43 | return self 44 | 45 | 46 | class ChunkedHashingVectorizer(HashingVectorizer): 47 | # This class is a kludge to get around 48 | # https://github.com/scikit-learn/scikit-learn/issues/8941 49 | def __init__(self, input='content', encoding='utf-8', 50 | decode_error='strict', strip_accents=None, 51 | lowercase=True, preprocessor=None, tokenizer=None, 52 | stop_words=None, token_pattern=r"(?u)\b\w\w+\b", 53 | ngram_range=(1, 1), analyzer='word', n_features=(2 ** 20), 54 | binary=False, norm='l2', alternate_sign=True, 55 | dtype=numpy.float64, chunk_size=20000): 56 | self.chunk_size = chunk_size 57 | super().__init__( 58 | input=input, encoding=encoding, decode_error=decode_error, 59 | strip_accents=strip_accents, lowercase=lowercase, 60 | preprocessor=preprocessor, tokenizer=tokenizer, 61 | stop_words=stop_words, token_pattern=token_pattern, 62 | 
ngram_range=ngram_range, analyzer=analyzer, n_features=n_features, 63 | binary=binary, norm=norm, alternate_sign=alternate_sign, 64 | dtype=dtype) 65 | 66 | def transform(self, X): 67 | if self.chunk_size <= 0: 68 | return super().transform(X) 69 | 70 | returnX = None 71 | X = iter(X) 72 | while True: 73 | cX = list(islice(X, self.chunk_size)) 74 | if len(cX) == 0: 75 | break 76 | cX = super().transform(cX) 77 | if returnX is None: 78 | returnX = cX 79 | else: 80 | returnX = vstack([returnX, cX]) 81 | 82 | return returnX 83 | 84 | fit_transform = transform 85 | 86 | 87 | # Experimental feature. USE WITH CAUTION 88 | class _MultioutputClassifier(BaseEstimator, ClassifierMixin): 89 | # This is a hack because it looks like multioutput classifiers can't 90 | # handle non-numeric labels like regular classifiers. 91 | # TODO: raise issue linked to 92 | # https://github.com/scikit-learn/scikit-learn/issues/556 93 | 94 | def __init__(self, base_estimator=None, separator=';'): 95 | self.base_estimator = base_estimator 96 | self.separator = separator 97 | 98 | def fit(self, X, y, **fit_params): 99 | y = list(zip(*[label.split(self.separator) for label in y])) 100 | self.encoders_ = [LabelEncoder() for _ in y] 101 | y = [e.fit_transform(l) for e, l in zip(self.encoders_, y)] 102 | self.base_estimator.fit(X, list(zip(*y)), **fit_params) 103 | return self 104 | 105 | @property 106 | def classes_(self): 107 | classes = [e.inverse_transform(l) for e, l in 108 | zip(self.encoders_, zip(*self.base_estimator.classes_))] 109 | return [self.separator.join(label) for label in zip(*classes)] 110 | 111 | def predict(self, X): 112 | y = self.base_estimator.predict(X).astype(int) 113 | y = [e.inverse_transform(l) for e, l in zip(self.encoders_, y.T)] 114 | return [self.separator.join(label) for label in zip(*y)] 115 | 116 | def predict_proba(self, X): 117 | return self.base_estimator.predict_proba(X) 118 | -------------------------------------------------------------------------------- /q2_feature_classifier/tests/test_taxonomic_classifier.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright (c) 2016-2025, QIIME 2 development team. 3 | # 4 | # Distributed under the terms of the Modified BSD License. 5 | # 6 | # The full license is in the file LICENSE, distributed with this software. 7 | # ---------------------------------------------------------------------------- 8 | 9 | import unittest 10 | import json 11 | import tempfile 12 | import tarfile 13 | import os 14 | import shutil 15 | 16 | import sklearn 17 | import joblib 18 | from sklearn.pipeline import Pipeline 19 | from qiime2.sdk import Artifact 20 | from qiime2.plugins.feature_classifier.methods import \ 21 | fit_classifier_naive_bayes 22 | 23 | from .._taxonomic_classifier import ( 24 | TaxonomicClassifierDirFmt, TaxonomicClassifier, 25 | TaxonomicClassiferTemporaryPickleDirFmt, PickleFormat) 26 | from . 
import FeatureClassifierTestPluginBase 27 | 28 | 29 | class TaxonomicClassifierTestBase(FeatureClassifierTestPluginBase): 30 | package = 'q2_feature_classifier.tests' 31 | 32 | def setUp(self): 33 | super().setUp() 34 | 35 | reads = Artifact.import_data( 36 | 'FeatureData[Sequence]', 37 | self.get_data_path('se-dna-sequences.fasta')) 38 | taxonomy = Artifact.import_data( 39 | 'FeatureData[Taxonomy]', self.get_data_path('taxonomy.tsv')) 40 | classifier = fit_classifier_naive_bayes(reads, taxonomy) 41 | pipeline = classifier.classifier.view(Pipeline) 42 | transformer = self.get_transformer( 43 | Pipeline, TaxonomicClassiferTemporaryPickleDirFmt) 44 | self._sklp = transformer(pipeline) 45 | sklearn_pipeline = self._sklp.sklearn_pipeline.view(PickleFormat) 46 | self.sklearn_pipeline = str(sklearn_pipeline) 47 | 48 | def _custom_setup(self, version): 49 | with open(os.path.join(self.temp_dir.name, 50 | 'sklearn_version.json'), 'w') as fh: 51 | fh.write(json.dumps({'sklearn-version': version})) 52 | shutil.copy(self.sklearn_pipeline, self.temp_dir.name) 53 | return TaxonomicClassiferTemporaryPickleDirFmt( 54 | self.temp_dir.name, mode='r') 55 | 56 | 57 | class TestTypes(FeatureClassifierTestPluginBase): 58 | def test_taxonomic_classifier_semantic_type_registration(self): 59 | self.assertRegisteredSemanticType(TaxonomicClassifier) 60 | 61 | def test_taxonomic_classifier_semantic_type_to_format_registration(self): 62 | self.assertSemanticTypeRegisteredToFormat( 63 | TaxonomicClassifier, TaxonomicClassiferTemporaryPickleDirFmt) 64 | 65 | 66 | class TestFormats(TaxonomicClassifierTestBase): 67 | def test_taxonomic_classifier_dir_fmt(self): 68 | format = self._custom_setup(sklearn.__version__) 69 | 70 | # Should not error 71 | format.validate() 72 | 73 | 74 | class TestTransformers(TaxonomicClassifierTestBase): 75 | def test_old_sklearn_version(self): 76 | transformer = self.get_transformer( 77 | TaxonomicClassiferTemporaryPickleDirFmt, Pipeline) 78 | input = self._custom_setup('a very old version') 79 | with self.assertRaises(ValueError): 80 | transformer(input) 81 | 82 | def test_old_dirfmt(self): 83 | transformer = self.get_transformer(TaxonomicClassifierDirFmt, Pipeline) 84 | with open(os.path.join(self.temp_dir.name, 85 | 'preprocess_params.json'), 'w') as fh: 86 | fh.write(json.dumps([])) 87 | shutil.copy(self.sklearn_pipeline, self.temp_dir.name) 88 | input = TaxonomicClassifierDirFmt( 89 | self.temp_dir.name, mode='r') 90 | with self.assertRaises(ValueError): 91 | transformer(input) 92 | 93 | def test_taxo_class_dir_fmt_to_taxo_class_result(self): 94 | input = self._custom_setup(sklearn.__version__) 95 | 96 | transformer = self.get_transformer( 97 | TaxonomicClassiferTemporaryPickleDirFmt, Pipeline) 98 | obs = transformer(input) 99 | 100 | self.assertTrue(obs) 101 | 102 | def test_taxo_class_result_to_taxo_class_dir_fmt(self): 103 | def read_pipeline(pipeline_filepath): 104 | with tarfile.open(pipeline_filepath) as tar: 105 | dirname = tempfile.mkdtemp() 106 | tar.extractall(dirname) 107 | pipeline = joblib.load(os.path.join(dirname, 108 | 'sklearn_pipeline.pkl')) 109 | for fn in tar.getnames(): 110 | os.unlink(os.path.join(dirname, fn)) 111 | os.rmdir(dirname) 112 | return pipeline 113 | 114 | exp = read_pipeline(self.sklearn_pipeline) 115 | transformer = self.get_transformer( 116 | Pipeline, TaxonomicClassiferTemporaryPickleDirFmt) 117 | obs = transformer(exp) 118 | sklearn_pipeline = obs.sklearn_pipeline.view(PickleFormat) 119 | obs_pipeline = read_pipeline(str(sklearn_pipeline)) 120 | obs 
= obs_pipeline 121 | self.assertTrue(obs) 122 | 123 | 124 | if __name__ == "__main__": 125 | unittest.main() 126 | -------------------------------------------------------------------------------- /q2_feature_classifier/_skl.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright (c) 2016-2025, QIIME 2 development team. 3 | # 4 | # Distributed under the terms of the Modified BSD License. 5 | # 6 | # The full license is in the file LICENSE, distributed with this software. 7 | # ---------------------------------------------------------------------------- 8 | 9 | from dataclasses import dataclass, field 10 | from functools import cached_property 11 | from itertools import islice, repeat 12 | from typing import Dict, List 13 | 14 | from joblib import Parallel, delayed 15 | 16 | 17 | @dataclass 18 | class _TaxonNode: 19 | # The _TaxonNode is used to build a hierarchy from a list of sorted class 20 | # labels. It allows one to quickly find class label indices of taxonomy 21 | # labels that satisfy a given taxonomy hierarchy. For example, given the 22 | # 'k__Bacteria' taxon, the _TaxonNode.range property will yield all class 23 | # label indices where 'k__Bacteria' is a prefix. 24 | 25 | name: str 26 | offset_index: int 27 | children: Dict[str, "_TaxonNode"] = field( 28 | default_factory=dict, 29 | repr=False) 30 | 31 | @classmethod 32 | def create_tree(cls, classes: List[str], separator: str): 33 | if not all(a <= b for a, b in zip(classes, classes[1:])): 34 | raise Exception("classes must be in sorted order") 35 | root = cls("Unassigned", 0) 36 | for class_start_index, label in enumerate(classes): 37 | taxons = label.split(separator) 38 | node = root 39 | for name in taxons: 40 | if name not in node.children: 41 | node.children[name] = cls(name, class_start_index) 42 | node = node.children[name] 43 | return root 44 | 45 | @property 46 | def range(self) -> range: 47 | return range( 48 | self.offset_index, 49 | self.offset_index + self.num_leaf_nodes) 50 | 51 | @cached_property 52 | def num_leaf_nodes(self) -> int: 53 | if len(self.children) == 0: 54 | return 1 55 | return sum(c.num_leaf_nodes for c in self.children.values()) 56 | 57 | 58 | _specific_fitters = [ 59 | ['naive_bayes', 60 | [['feat_ext', 61 | {'__type__': 'feature_extraction.text.HashingVectorizer', 62 | 'analyzer': 'char_wb', 63 | 'n_features': 8192, 64 | 'ngram_range': (7, 7), 65 | 'alternate_sign': False}], 66 | ['classify', 67 | {'__type__': 'custom.LowMemoryMultinomialNB', 68 | 'alpha': 0.001, 69 | 'fit_prior': False}]]]] 70 | 71 | 72 | def fit_pipeline(reads, taxonomy, pipeline): 73 | seq_ids, X = _extract_reads(reads) 74 | data = [(taxonomy[s], x) for s, x in zip(seq_ids, X) if s in taxonomy] 75 | y, X = list(zip(*data)) 76 | pipeline.fit(X, y) 77 | return pipeline 78 | 79 | 80 | def _extract_reads(reads): 81 | return zip(*[(r.metadata['id'], r._string) for r in reads]) 82 | 83 | 84 | def predict(reads, pipeline, separator=';', chunk_size=262144, n_jobs=1, 85 | pre_dispatch='2*n_jobs', confidence='disable'): 86 | jobs = ( 87 | delayed(_predict_chunk)(pipeline, separator, confidence, chunk) 88 | for chunk in _chunks(reads, chunk_size)) 89 | workers = Parallel(n_jobs=n_jobs, batch_size=1, pre_dispatch=pre_dispatch) 90 | for calculated in workers(jobs): 91 | yield from calculated 92 | 93 | 94 | def _predict_chunk(pipeline, separator, confidence, chunk): 95 | if confidence == 'disable': 96 | return 
_predict_chunk_without_conf(pipeline, chunk) 97 | else: 98 | return _predict_chunk_with_conf(pipeline, separator, confidence, chunk) 99 | 100 | 101 | def _predict_chunk_without_conf(pipeline, chunk): 102 | seq_ids, X = _extract_reads(chunk) 103 | y = pipeline.predict(X) 104 | return zip(seq_ids, y, repeat(-1.)) 105 | 106 | 107 | def _predict_chunk_with_conf(pipeline, separator, confidence, chunk): 108 | seq_ids, X = _extract_reads(chunk) 109 | 110 | if not hasattr(pipeline, "predict_proba"): 111 | raise ValueError('this classifier does not support confidence values') 112 | prob_pos = pipeline.predict_proba(X) 113 | if prob_pos.shape != (len(X), len(pipeline.classes_)): 114 | raise ValueError('this classifier does not support confidence values') 115 | 116 | y = pipeline.classes_[prob_pos.argmax(axis=1)] 117 | 118 | taxonomy_tree = _TaxonNode.create_tree(pipeline.classes_, separator) 119 | 120 | results = [] 121 | for seq_id, taxon, class_probs in zip(seq_ids, y, prob_pos): 122 | split_taxon = taxon.split(separator) 123 | accepted_cum_prob = 0.0 124 | cum_prob = 0.0 125 | result = [] 126 | current = taxonomy_tree 127 | for rank in split_taxon: 128 | current = current.children[rank] 129 | cum_prob = class_probs[current.range].sum() 130 | if cum_prob < confidence: 131 | break 132 | accepted_cum_prob = cum_prob 133 | result.append(rank) 134 | if len(result) == 0: 135 | results.append((seq_id, "Unassigned", 1.0 - cum_prob)) 136 | else: 137 | results.append((seq_id, separator.join(result), accepted_cum_prob)) 138 | return results 139 | 140 | 141 | def _chunks(reads, chunk_size): 142 | reads = iter(reads) 143 | while True: 144 | chunk = list(islice(reads, chunk_size)) 145 | if len(chunk) == 0: 146 | break 147 | yield chunk 148 | -------------------------------------------------------------------------------- /.github/SUPPORT.md: -------------------------------------------------------------------------------- 1 | # QIIME 2 Users 2 | 3 | Check out the [User Docs](https://docs.qiime2.org) - there are many tutorials, 4 | walkthroughs, and guides available. If you still need help, please visit us at 5 | the [QIIME 2 Forum](https://forum.qiime2.org/c/user-support). 6 | 7 | # QIIME 2 Developers 8 | 9 | Check out the [Developer Docs](https://dev.qiime2.org) - there are many 10 | tutorials, walkthroughs, and guides available. If you still need help, please 11 | visit us at the [QIIME 2 Forum](https://forum.qiime2.org/c/dev-discussion). 12 | 13 | # General Bug/Issue Triage Discussion 14 | 15 | ![rubric](./rubric.png?raw=true) 16 | 17 | # Projects/Repositories in the QIIME 2 GitHub Organization 18 | 19 | Sorted alphabetically by repo name. 
20 | 21 | - [busywork](https://github.com/qiime2/busywork/issues) 22 | | The CI automation engine that builds and distributes QIIME 2 23 | - [conda-channel-resource](https://github.com/qiime2/conda-channel-resource/issues) 24 | | A Concourse resource for working with conda 25 | - [data.qiime2.org](https://github.com/qiime2/data.qiime2.org/issues) 26 | | Web app for vanity URLs for QIIME 2 data assets 27 | - [dev-docs](https://github.com/qiime2/dev-docs/issues) 28 | | The Developer Documentation 29 | - [discourse-unhandled-tagger](https://github.com/qiime2/discourse-unhandled-tagger/issues) 30 | | A discourse plugin for handling queued/unqueued topics 31 | - [docs](https://github.com/qiime2/docs/issues) 32 | | The User Documentation 33 | - [environment-files](https://github.com/qiime2/environment-files/issues) 34 | | Rendered QIIME 2 environment files for conda 35 | - [Keemei](https://github.com/qiime2/Keemei/issues) 36 | | Google Sheets Add-On for validating tabular data 37 | - [linux-worker-docker](https://github.com/qiime2/linux-worker-docker/issues) 38 | | A docker image for linux-based busywork workers 39 | - [logos](https://github.com/qiime2/logos/issues) 40 | | Official project logos 41 | - [q2-alignment](https://github.com/qiime2/q2-alignment/issues) 42 | | The q2-alignment plugin 43 | - [q2-composition](https://github.com/qiime2/q2-composition/issues) 44 | | The q2-composition plugin 45 | - [q2-cutadapt](https://github.com/qiime2/q2-cutadapt/issues) 46 | | The q2-cutadapt plugin 47 | - [q2-dada2](https://github.com/qiime2/q2-dada2/issues) 48 | | The q2-dada2 plugin 49 | - [q2-deblur](https://github.com/qiime2/q2-deblur/issues) 50 | | The q2-deblur plugin 51 | - [q2-demux](https://github.com/qiime2/q2-demux/issues) 52 | | The q2-demux plugin 53 | - [q2-diversity](https://github.com/qiime2/q2-diversity/issues) 54 | | The q2-diversity plugin 55 | - [q2-diversity-lib](https://github.com/qiime2/q2-diversity-lib/issues) 56 | | The q2-diversity-lib plugin 57 | - [q2-emperor](https://github.com/qiime2/q2-emperor/issues) 58 | | The q2-emperor plugin 59 | - [q2-feature-classifier](https://github.com/qiime2/q2-feature-classifier/issues) 60 | | The q2-feature-classifier plugin 61 | - [q2-feature-table](https://github.com/qiime2/q2-feature-table/issues) 62 | | The q2-feature-table plugin 63 | - [q2-fragment-insertion](https://github.com/qiime2/q2-fragment-insertion/issues) 64 | | The q2-fragment-insertion plugin 65 | - [q2-gneiss](https://github.com/qiime2/q2-gneiss/issues) 66 | | The q2-gneiss plugin 67 | - [q2-longitudinal](https://github.com/qiime2/q2-longitudinal/issues) 68 | | The q2-longitudinal plugin 69 | - [q2-metadata](https://github.com/qiime2/q2-metadata/issues) 70 | | The q2-metadata plugin 71 | - [q2-phylogeny](https://github.com/qiime2/q2-phylogeny/issues) 72 | | The q2-phylogeny plugin 73 | - [q2-quality-control](https://github.com/qiime2/q2-quality-control/issues) 74 | | The q2-quality-control plugin 75 | - [q2-quality-filter](https://github.com/qiime2/q2-quality-filter/issues) 76 | | The q2-quality-filter plugin 77 | - [q2-sample-classifier](https://github.com/qiime2/q2-sample-classifier/issues) 78 | | The q2-sample-classifier plugin 79 | - [q2-shogun](https://github.com/qiime2/q2-shogun/issues) 80 | | The q2-shogun plugin 81 | - [q2-taxa](https://github.com/qiime2/q2-taxa/issues) 82 | | The q2-taxa plugin 83 | - [q2-types](https://github.com/qiime2/q2-types/issues) 84 | | The q2-types plugin 85 | - [q2-vsearch](https://github.com/qiime2/q2-vsearch/issues) 86 | | The 
q2-vsearch plugin 87 | - [q2cli](https://github.com/qiime2/q2cli/issues) 88 | | The CLI interface 89 | - [q2cwl](https://github.com/qiime2/q2cwl/issues) 90 | | The prototype CWL interface 91 | - [q2galaxy](https://github.com/qiime2/q2galaxy/issues) 92 | | The prototype Galaxy interface 93 | - [q2lint](https://github.com/qiime2/q2lint/issues) 94 | | An internal tool for ensuring header text and copyrights are present 95 | - [q2studio](https://github.com/qiime2/q2studio/issues) 96 | | The prototype GUI interface 97 | - [q2templates](https://github.com/qiime2/q2templates/issues) 98 | | A base template for use in official QIIME 2 plugins 99 | - [q2view](https://github.com/qiime2/q2view/issues) 100 | | The read-only web interface at view.qiime2.org 101 | - [qiime2.github.io](https://github.com/qiime2/qiime2.github.io/issues) 102 | | The QIIME 2 homepage at qiime2.org 103 | - [qiime2](https://github.com/qiime2/qiime2/issues) 104 | | The QIIME 2 framework 105 | - [template-repo](https://github.com/qiime2/template-repo/issues) 106 | | Centralized templates for repo assets 107 | - [vm-playbooks](https://github.com/qiime2/vm-playbooks/issues) 108 | | Scripts for building QIIME 2 VMs 109 | - [workshop-playbooks](https://github.com/qiime2/workshop-playbooks/issues) 110 | | Scripts for building QIIME 2 workshop clusters 111 | - [workshops.qiime2.org](https://github.com/qiime2/workshops.qiime2.org/issues) 112 | | The web app that runs workshops.qiime2.org 113 | -------------------------------------------------------------------------------- /q2_feature_classifier/tests/test_cutter.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright (c) 2016-2025, QIIME 2 development team. 3 | # 4 | # Distributed under the terms of the Modified BSD License. 5 | # 6 | # The full license is in the file LICENSE, distributed with this software. 7 | # ---------------------------------------------------------------------------- 8 | 9 | import skbio 10 | 11 | from qiime2.sdk import Artifact 12 | from qiime2.plugins.feature_classifier.actions import extract_reads 13 | from q2_types.feature_data import DNAFASTAFormat 14 | 15 | from . 
import FeatureClassifierTestPluginBase 16 | 17 | 18 | class CutterTests(FeatureClassifierTestPluginBase): 19 | package = 'q2_feature_classifier.tests' 20 | 21 | def setUp(self): 22 | super().setUp() 23 | self.sequences = Artifact.import_data( 24 | 'FeatureData[Sequence]', 25 | self.get_data_path('dna-sequences.fasta')) 26 | 27 | self.mixed_sequences = Artifact.import_data( 28 | 'FeatureData[Sequence]', 29 | self.get_data_path('dna-sequences-mixed.fasta')) 30 | 31 | self.f_primer = 'AGAGA' 32 | self.r_primer = 'GCTGC' 33 | 34 | self.amplicons = ['ACGT', 'AAGT', 'ACCT', 'ACGG', 'ACTT'] 35 | 36 | def _test_results(self, results): 37 | for i, result in enumerate( 38 | skbio.io.read(str(results.reads.view(DNAFASTAFormat)), 39 | format='fasta')): 40 | self.assertEqual(str(result), self.amplicons[i]) 41 | 42 | def test_extract_reads_expected(self): 43 | results = extract_reads( 44 | self.sequences, f_primer=self.f_primer, r_primer=self.r_primer, 45 | min_length=4) 46 | 47 | self._test_results(results) 48 | 49 | def test_extract_reads_expected_forward(self): 50 | results = extract_reads( 51 | self.sequences, f_primer=self.f_primer, r_primer=self.r_primer, 52 | min_length=4, read_orientation='forward') 53 | 54 | self._test_results(results) 55 | 56 | def test_extract_mixed(self): 57 | results = extract_reads( 58 | self.mixed_sequences, f_primer=self.f_primer, 59 | r_primer=self.r_primer, min_length=4) 60 | 61 | self._test_results(results) 62 | 63 | def test_extract_reads_expected_reverse(self): 64 | reverse_sequences = Artifact.import_data( 65 | 'FeatureData[Sequence]', 66 | self.get_data_path('dna-sequences-reverse.fasta')) 67 | 68 | results = extract_reads( 69 | reverse_sequences, f_primer=self.f_primer, r_primer=self.r_primer, 70 | min_length=4, read_orientation='reverse') 71 | 72 | self._test_results(results) 73 | 74 | def test_extract_reads_manual_batch_size(self): 75 | results = extract_reads( 76 | self.sequences, f_primer=self.f_primer, r_primer=self.r_primer, 77 | min_length=4, batch_size=10) 78 | 79 | self._test_results(results) 80 | 81 | def test_extract_reads_two_jobs(self): 82 | results = extract_reads( 83 | self.sequences, f_primer=self.f_primer, r_primer=self.r_primer, 84 | min_length=4, n_jobs=2) 85 | 86 | self._test_results(results) 87 | 88 | def test_extract_reads_expected_degenerate_primers(self): 89 | degenerate_f_primer = 'WWWWW' 90 | degenerate_r_primer = 'SSSSS' 91 | 92 | degenerate_sequences = Artifact.import_data( 93 | 'FeatureData[Sequence]', 94 | self.get_data_path('dna-sequences-degenerate-primers.fasta')) 95 | 96 | results = extract_reads( 97 | degenerate_sequences, f_primer=degenerate_f_primer, 98 | r_primer=degenerate_r_primer, min_length=4) 99 | 100 | self._test_results(results) 101 | 102 | def test_extract_reads_expected_trim_right(self): 103 | """Tests expected behavior of trim_right option""" 104 | results = extract_reads( 105 | self.sequences, f_primer=self.f_primer, r_primer=self.r_primer, 106 | min_length=3, trim_right=1) 107 | 108 | for i, result in enumerate( 109 | skbio.io.read(str(results.reads.view(DNAFASTAFormat)), 110 | format='fasta')): 111 | self.assertEqual(str(result), self.amplicons[i][:-1]) 112 | 113 | def test_extract_reads_fail_identity(self): 114 | with self.assertRaisesRegex(RuntimeError, "No matches found"): 115 | extract_reads( 116 | self.sequences, f_primer=self.f_primer, r_primer=self.r_primer, 117 | min_length=4, identity=1) 118 | 119 | def test_extract_reads_fail_min_length(self): 120 | with self.assertRaisesRegex(RuntimeError, "No matches 
found"): 121 | extract_reads( 122 | self.sequences, f_primer=self.f_primer, r_primer=self.r_primer, 123 | min_length=5) 124 | 125 | def test_extract_reads_fail_max_length(self): 126 | with self.assertRaisesRegex(RuntimeError, "No matches found"): 127 | extract_reads( 128 | self.sequences, f_primer=self.f_primer, r_primer=self.r_primer, 129 | max_length=1) 130 | 131 | def test_extract_reads_fail_trim_left_entire_read(self): 132 | with self.assertRaisesRegex(RuntimeError, "No matches found"): 133 | extract_reads( 134 | self.sequences, f_primer=self.f_primer, r_primer=self.r_primer, 135 | trim_left=4) 136 | 137 | def test_extract_reads_fail_trim_right_entire_read(self): 138 | with self.assertRaisesRegex(RuntimeError, "No matches found"): 139 | extract_reads( 140 | self.sequences, f_primer=self.f_primer, r_primer=self.r_primer, 141 | trim_right=4) 142 | 143 | def test_extract_reads_fail_trim_both_entire_read(self): 144 | with self.assertRaisesRegex(RuntimeError, "No matches found"): 145 | extract_reads( 146 | self.sequences, f_primer=self.f_primer, r_primer=self.r_primer, 147 | trim_left=2, trim_right=2) 148 | 149 | def test_extract_reads_fail_min_len_greater_than_trunc_len(self): 150 | with self.assertRaisesRegex(ValueError, "minimum length setting"): 151 | extract_reads( 152 | self.sequences, f_primer=self.f_primer, r_primer=self.r_primer, 153 | trunc_len=1) 154 | -------------------------------------------------------------------------------- /q2_feature_classifier/_consensus_assignment.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright (c) 2016-2025, QIIME 2 development team. 3 | # 4 | # Distributed under the terms of the Modified BSD License. 5 | # 6 | # The full license is in the file LICENSE, distributed with this software. 7 | # ---------------------------------------------------------------------------- 8 | 9 | from collections import Counter 10 | from math import ceil 11 | 12 | import pandas as pd 13 | 14 | from qiime2.plugin import Str, Float, Range 15 | from .plugin_setup import plugin 16 | from q2_types.feature_data import FeatureData, Taxonomy, BLAST6 17 | 18 | 19 | min_consensus_param = {'min_consensus': Float % Range( 20 | 0.5, 1.0, inclusive_end=True, inclusive_start=False)} 21 | 22 | min_consensus_param_description = { 23 | 'min_consensus': 'Minimum fraction of assignments must match top ' 24 | 'hit to be accepted as consensus assignment.'} 25 | 26 | DEFAULTUNASSIGNABLELABEL = "Unassigned" 27 | 28 | 29 | def find_consensus_annotation(search_results: pd.DataFrame, 30 | reference_taxonomy: pd.Series, 31 | min_consensus: int = 0.51, 32 | unassignable_label: str = 33 | DEFAULTUNASSIGNABLELABEL 34 | ) -> pd.DataFrame: 35 | """Find consensus taxonomy from BLAST6Format alignment summary. 36 | 37 | search_results: pd.dataframe 38 | BLAST6Format search results with canonical headers attached. 39 | reference_taxonomy: pd.Series 40 | Annotations of reference database used for original search. 41 | min_consensus : float 42 | The minimum fraction of the annotations that a specific annotation 43 | must be present in for that annotation to be accepted. Current 44 | lower boundary is 0.51. 45 | unassignable_label : str 46 | The label to apply if no acceptable annotations are identified. 
47 | """ 48 | # load and convert blast6format results to dict of taxa hits 49 | obs_taxa = _blast6format_df_to_series_of_lists( 50 | search_results, reference_taxonomy, 51 | unassignable_label=unassignable_label) 52 | # TODO: is it worth allowing early stopping if maxaccepts==1? 53 | # compute consensus annotations 54 | result = _compute_consensus_annotations( 55 | obs_taxa, min_consensus=min_consensus, 56 | unassignable_label=unassignable_label) 57 | result.index.name = 'Feature ID' 58 | return result 59 | 60 | 61 | plugin.methods.register_function( 62 | function=find_consensus_annotation, 63 | inputs={'search_results': FeatureData[BLAST6], 64 | 'reference_taxonomy': FeatureData[Taxonomy]}, 65 | parameters={ 66 | **min_consensus_param, 67 | 'unassignable_label': Str}, 68 | outputs=[('consensus_taxonomy', FeatureData[Taxonomy])], 69 | input_descriptions={ 70 | 'search_results': 'Search results in BLAST6 output format', 71 | 'reference_taxonomy': 'reference taxonomy labels.'}, 72 | parameter_descriptions={ 73 | **min_consensus_param_description, 74 | 'unassignable_label': 'Annotation given when no consensus is found.' 75 | }, 76 | output_descriptions={ 77 | 'consensus_taxonomy': 'Consensus taxonomy and scores.'}, 78 | name='Find consensus among multiple annotations.', 79 | description=('Find consensus annotation for each query searched against ' 80 | 'a reference database, by finding the least common ancestor ' 81 | 'among one or more semicolon-delimited hierarchical ' 82 | 'annotations. Note that the annotation hierarchy is assumed ' 83 | 'to have an even number of ranks.'), 84 | ) 85 | 86 | 87 | def _blast6format_df_to_series_of_lists( 88 | assignments: pd.DataFrame, 89 | ref_taxa: pd.Series, 90 | unassignable_label: str = DEFAULTUNASSIGNABLELABEL 91 | ) -> pd.Series: 92 | """import observed assignments in blast6 format to series of lists. 93 | 94 | assignments: pd.DataFrame 95 | Taxonomy observation map in blast format 6. Each line consists of 96 | taxonomy assignments of a query sequence in tab-delimited format: 97 | <...other columns are ignored> 98 | 99 | ref_taxa: pd.Series 100 | Reference taxonomies in tab-delimited format: 101 | Annotation 102 | The accession IDs in this taxonomy should match the subject-seq-ids in 103 | the "assignment" input. 104 | """ 105 | # validate that assignments are present in reference taxonomy 106 | # (i.e., that the correct reference taxonomy was used). 107 | # Note that we drop unassigned labels from this set. 108 | missing_ids = \ 109 | set(assignments['sseqid'].values) - set(ref_taxa.index) - {'*', ''} 110 | if len(missing_ids) > 0: 111 | raise KeyError('Reference taxonomy and search results do not match. ' 112 | 'The following identifiers were reported in the search ' 113 | 'results but are not present in the reference taxonomy:' 114 | ' {0}'.format(', '.join(str(i) for i in missing_ids))) 115 | 116 | # if vsearch fails to find assignment, it reports '*' as the 117 | # accession ID, so we will add this mapping to the reference taxonomy. 
118 | ref_taxa['*'] = unassignable_label 119 | assignments_copy = assignments.copy(deep=True) 120 | for index, value in assignments_copy.iterrows(): 121 | sseqid = assignments_copy.iloc[index]['sseqid'] 122 | assignments_copy.at[index, 'sseqid'] = ref_taxa.at[sseqid] 123 | # convert to a series of {query_id: [annotations]} 124 | taxa_hits: pd.Series = assignments_copy.set_index('qseqid')['sseqid'] 125 | taxa_hits = taxa_hits.groupby(taxa_hits.index).apply(list) 126 | 127 | return taxa_hits 128 | 129 | 130 | def _compute_consensus_annotations( 131 | query_annotations, min_consensus, 132 | unassignable_label=DEFAULTUNASSIGNABLELABEL): 133 | """ 134 | Parameters 135 | ---------- 136 | query_annotations : pd.Series of lists 137 | Indices are query identifiers, and values are lists of all 138 | taxonomic annotations associated with that identifier. 139 | Returns 140 | ------- 141 | pd.DataFrame 142 | Indices are query identifiers, and values are the consensus of the 143 | input taxonomic annotations, and the consensus score. 144 | """ 145 | # define function to apply to each list of taxa hits 146 | # Note: I am setting this up to open the possibility to define other 147 | # functions later (e.g., not simple threshold consensus) 148 | def _apply_consensus_function(taxa, min_consensus=min_consensus, 149 | unassignable_label=unassignable_label, 150 | _consensus_function=_lca_consensus): 151 | # if there is only a single annotation, skip consensus calculation 152 | if len(taxa) == 1: 153 | taxa, score = taxa.pop(), 1. 154 | else: 155 | taxa = _taxa_to_cumulative_ranks(taxa) 156 | # apply and score consensus 157 | taxa, score = _consensus_function( 158 | taxa, min_consensus, unassignable_label) 159 | # return as a series so that the outer apply returns a dataframe 160 | # (i.e., consensus scores get inserted as an additional column) 161 | return pd.Series([taxa, score], index=['Taxon', 'Consensus']) 162 | 163 | # If func returns a Series object the result will be a DataFrame. 164 | return query_annotations.apply(_apply_consensus_function) 165 | 166 | 167 | # first split semicolon-delimited taxonomies by rank 168 | # and iteratively join ranks, so that: ['a;b;c', 'a;b;d', 'a;g;g'] --> 169 | # [['a', 'a;b', 'a;b;c'], ['a', 'a;b', 'a;b;d'], ['a', 'a;g', 'a;g;g']] 170 | # this is to avoid issues where e.g., the same species name may occur 171 | # in different taxonomic lineages. 172 | def _taxa_to_cumulative_ranks(taxa): 173 | """ 174 | Parameters 175 | ---------- 176 | taxa : list of str 177 | List of semicolon-delimited taxonomic labels. 178 | e.g., ['a;b;c', 'a;b;d'] 179 | Returns 180 | ------- 181 | list of lists of str 182 | Lists of cumulative taxonomic ranks for each input str 183 | e.g., [['a', 'a;b', 'a;b;c'], ['a', 'a;b', 'a;b;d']] 184 | """ 185 | return [[';'.join(t.split(';')[:n + 1]) 186 | for n in range(t.count(';') + 1)] 187 | for t in taxa] 188 | 189 | 190 | # Find the LCA by consensus threshold. Return label and the consensus score. 191 | def _lca_consensus(annotations, min_consensus, unassignable_label): 192 | """ Compute the consensus of a collection of annotations. 193 | Parameters 194 | ---------- 195 | annotations : list of lists 196 | Taxonomic annotations to form consensus. 197 | min_consensus : float 198 | The minimum fraction of the annotations that a specific annotation 199 | must be present in for that annotation to be accepted. Current 200 | lower boundary is 0.51. 201 | unassignable_label : str 202 | The label to apply if no acceptable annotations are identified.
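    For example (illustrative): annotations expanded to cumulative ranks
    [['a', 'a;b', 'a;b;c'], ['a', 'a;b', 'a;b;c'], ['a', 'a;b', 'a;b;d']]
    with min_consensus=0.51 give threshold = ceil(3 * 0.51) = 2; the deepest
    rank whose most common label meets that threshold is 'a;b;c' (2 of 3
    annotations), so the result is ('a;b;c', 0.667).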
203 | Returns 204 | ------- 205 | consensus_annotation: str 206 | The consensus assignment. 207 | consensus_fraction: float 208 | Fraction of input annotations that agreed at the deepest 209 | level of assignment. 210 | """ 211 | # count total number of labels to get consensus threshold 212 | n_annotations = len(annotations) 213 | threshold = ceil(n_annotations * min_consensus) 214 | # zip together ranks and count frequency of each unique label. 215 | # This assumes that a hierarchical taxonomy with even numbers of 216 | # ranks was used. 217 | taxa_comparison = [Counter(rank) for rank in zip(*annotations)] 218 | # iterate rank comparisons in reverse 219 | # to find the deepest rank with consensus count >= threshold 220 | for rank in taxa_comparison[::-1]: 221 | # grab most common label and its count 222 | label, count = rank.most_common(1)[0] 223 | # TODO: this assumes that min_consensus >= 0.51 (current lower bound) 224 | # but could fail to find ties if we allow lower min_consensus scores 225 | if count >= threshold: 226 | return label, round(count / n_annotations, 3) 227 | # if we reach this point, no consensus was ever found at any rank 228 | return unassignable_label, 0.0 229 | -------------------------------------------------------------------------------- /q2_feature_classifier/_cutter.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright (c) 2016-2025, QIIME 2 development team. 3 | # 4 | # Distributed under the terms of the Modified BSD License. 5 | # 6 | # The full license is in the file LICENSE, distributed with this software. 7 | # ---------------------------------------------------------------------------- 8 | 9 | import skbio 10 | import os 11 | from joblib import Parallel, delayed, effective_n_jobs 12 | 13 | from qiime2.plugin import Int, Str, Float, Range, Choices 14 | from q2_types.feature_data import (FeatureData, Sequence, DNAIterator, 15 | DNASequencesDirectoryFormat, DNAFASTAFormat) 16 | from q2_feature_classifier._skl import _chunks 17 | from q2_feature_classifier.classifier import _autotune_reads_per_batch 18 | 19 | from .plugin_setup import plugin 20 | 21 | 22 | def _seq_to_regex(seq): 23 | """Build a regex out of an IUPAC consensus sequence.""" 24 | result = [] 25 | for base in str(seq): 26 | if base in skbio.DNA.degenerate_chars: 27 | result.append('[{0}]'.format( 28 | ''.join(sorted(skbio.DNA.degenerate_map[base])))) 29 | else: 30 | result.append(base) 31 | 32 | return ''.join(result) 33 | 34 | 35 | def _primers_to_regex(f_primer, r_primer): 36 | return '({0}.*{1})'.format(_seq_to_regex(f_primer), 37 | _seq_to_regex(r_primer.reverse_complement())) 38 | 39 | 40 | def _local_aln(primer, sequence): 41 | best_score = None 42 | for one_primer in sorted([str(s) for s in primer.expand_degenerates()]): 43 | # `sequence` may contain degenerates. These will usually be N 44 | # characters, which SSW will score as zero. Although undocumented, SSW 45 | # will treat other degenerate characters as a mismatch. We acknowledge 46 | # that this approach is a heuristic for finding an optimal alignment 47 | # and may be revisited in the future if there's an aligner that 48 | # explicitly handles degenerates.
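        # For example (illustrative primer): DNA('ACWG') expands to the
        # non-degenerate candidates 'ACAG' and 'ACTG' (W = A or T); each is
        # aligned in turn and the highest-scoring alignment is kept.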
49 | this_aln = \ 50 | skbio.alignment.local_pairwise_align_ssw(skbio.DNA(one_primer), 51 | sequence) 52 | score = this_aln[1] 53 | if best_score is None or score > best_score: 54 | best_score = score 55 | best_aln = this_aln 56 | return best_aln 57 | 58 | 59 | def _semisemiglobal(primer, sequence, reverse=False): 60 | if reverse: 61 | primer = primer.reverse_complement() 62 | 63 | # locally align the primer 64 | (aln_prim, aln_seq), score, (prim_pos, seq_pos) = \ 65 | _local_aln(primer, sequence) 66 | amplicon_pos = seq_pos[1]+len(primer)-prim_pos[1] 67 | 68 | # naively extend the alignment to be semi-global 69 | bits = [primer[:prim_pos[0]], aln_prim, primer[prim_pos[1]+1:]] 70 | aln_prim = ''.join(map(str, bits)) 71 | bits = ['-'*(prim_pos[0]-seq_pos[0]), 72 | sequence[max(seq_pos[0]-prim_pos[0], 0):seq_pos[0]], 73 | aln_seq, 74 | sequence[seq_pos[1]+1:amplicon_pos], 75 | '-'*(amplicon_pos-len(sequence))] 76 | aln_seq = ''.join(map(str, bits)) 77 | 78 | # count the matches 79 | matches = sum(s in skbio.DNA.degenerate_map.get(p, {p}) 80 | for p, s in zip(aln_prim, aln_seq)) 81 | 82 | if reverse: 83 | amplicon_pos = max(seq_pos[0]-prim_pos[0], 0) 84 | 85 | return amplicon_pos, matches, len(aln_prim) 86 | 87 | 88 | def _exact_match(seq, f_primer, r_primer): 89 | try: 90 | regex = _primers_to_regex(f_primer, r_primer) 91 | match = next(seq.find_with_regex(regex)) 92 | beg, end = match.start + len(f_primer), match.stop - len(r_primer) 93 | return seq[beg:end] 94 | except StopIteration: 95 | return None 96 | 97 | 98 | def _approx_match(seq, f_primer, r_primer, identity): 99 | beg, b_matches, b_length = _semisemiglobal(f_primer, seq) 100 | end, e_matches, e_length = _semisemiglobal(r_primer, seq, reverse=True) 101 | if (b_matches + e_matches) / (b_length + e_length) >= identity: 102 | return seq[beg:end] 103 | return None 104 | 105 | 106 | def _gen_reads(sequence, f_primer, r_primer, trim_right, trunc_len, trim_left, 107 | identity, min_length, max_length, read_orientation): 108 | f_primer = skbio.DNA(f_primer) 109 | r_primer = skbio.DNA(r_primer) 110 | amp = None 111 | if read_orientation in ['forward', 'both']: 112 | amp = _exact_match(sequence, f_primer, r_primer) 113 | if not amp and read_orientation in ['reverse', 'both']: 114 | amp = _exact_match(sequence.reverse_complement(), f_primer, r_primer) 115 | if not amp and read_orientation in ['forward', 'both']: 116 | amp = _approx_match(sequence, f_primer, r_primer, identity) 117 | if not amp and read_orientation in ['reverse', 'both']: 118 | amp = _approx_match( 119 | sequence.reverse_complement(), f_primer, r_primer, identity) 120 | if not amp: 121 | return 122 | # we want to filter by max length before trimming 123 | if max_length > 0 and len(amp) > max_length: 124 | return 125 | if trim_right > 0: 126 | amp = amp[:-trim_right] 127 | if trunc_len > 0: 128 | amp = amp[:trunc_len] 129 | if trim_left > 0: 130 | amp = amp[trim_left:] 131 | if min_length > 0 and len(amp) < min_length: 132 | return 133 | if not amp: 134 | return 135 | return amp 136 | 137 | 138 | def extract_reads(sequences: DNASequencesDirectoryFormat, f_primer: str, 139 | r_primer: str, trim_right: int = 0, 140 | trunc_len: int = 0, trim_left: int = 0, 141 | identity: float = 0.8, min_length: int = 50, 142 | max_length: int = 0, n_jobs: int = 1, 143 | batch_size: int = 'auto', read_orientation: str = 'both') \ 144 | -> DNAFASTAFormat: 145 | """Extract the read selected by a primer or primer pair. 
Only sequences 146 | which match the primers at greater than the specified identity are returned. 147 | 148 | Parameters 149 | ---------- 150 | sequences : DNASequencesDirectoryFormat 151 | A collection of skbio.sequence.DNA query sequences. 152 | f_primer : str 153 | Forward primer sequence. 154 | r_primer : str 155 | Reverse primer sequence. 156 | trim_right : int, optional 157 | `trim_right` nucleotides are removed from the 3' end if trim_right is 158 | positive. Applied before trunc_len. 159 | trunc_len : int, optional 160 | Read is cut to trunc_len if trunc_len is positive. Applied after 161 | trim_right. 162 | trim_left : int, optional 163 | `trim_left` nucleotides are removed from the 5' end if trim_left is 164 | positive. Applied after trim_right and trunc_len. 165 | identity : float, optional 166 | Minimum combined primer match identity threshold. Default: 0.8 167 | min_length: int, optional 168 | Minimum amplicon length. Shorter amplicons are discarded. Default: 50 169 | max_length: int, optional 170 | Maximum amplicon length. Longer amplicons are discarded. 171 | n_jobs: int, optional 172 | Number of separate processes to break the task into. 173 | batch_size: int, optional 174 | Number of sequences to be processed in one batch. 175 | read_orientation: str, optional 176 | Orientation of primers relative to the sequences: "forward" searches 177 | for primer hits in the forward direction, "reverse" searches the 178 | reverse-complement, and "both" searches both directions. 179 | Returns 180 | ------- 181 | q2_types.DNAFASTAFormat 182 | File containing the extracted reads. 183 | """ 184 | if min_length > trunc_len - (trim_left + trim_right) and trunc_len > 0: 185 | raise ValueError('The minimum length setting is greater than the ' 186 | 'length of the truncated sequences. This will cause ' 187 | 'all sequences to be removed from the dataset.
To ' 188 | 'proceed, set ' 189 | 'min_length ≤ trunc_len - (trim_left + ' 190 | 'trim_right).') 191 | 192 | n_jobs = effective_n_jobs(n_jobs) 193 | if batch_size == 'auto': 194 | batch_size = _autotune_reads_per_batch( 195 | sequences.file.view(DNAFASTAFormat), n_jobs) 196 | sequences = sequences.file.view(DNAIterator) 197 | ff = DNAFASTAFormat() 198 | with open(str(ff), 'a') as fh: 199 | with Parallel(n_jobs) as parallel: 200 | for chunk in _chunks(sequences, batch_size): 201 | amplicons = parallel(delayed(_gen_reads)(sequence, f_primer, 202 | r_primer, 203 | trim_right, 204 | trunc_len, 205 | trim_left, 206 | identity, 207 | min_length, 208 | max_length, 209 | read_orientation) 210 | for sequence in chunk) 211 | for amplicon in amplicons: 212 | if amplicon is not None: 213 | skbio.write(amplicon, format='fasta', into=fh) 214 | if os.stat(str(ff)).st_size == 0: 215 | raise RuntimeError("No matches found") 216 | return ff 217 | 218 | 219 | plugin.methods.register_function( 220 | function=extract_reads, 221 | inputs={'sequences': FeatureData[Sequence]}, 222 | parameters={'trunc_len': Int, 223 | 'trim_left': Int, 224 | 'trim_right': Int, 225 | 'f_primer': Str, 226 | 'r_primer': Str, 227 | 'identity': Float, 228 | 'min_length': Int % Range(0, None), 229 | 'max_length': Int % Range(0, None), 230 | 'n_jobs': Int % Range(1, None), 231 | 'batch_size': Int % Range(1, None) | Str % Choices(['auto']), 232 | 'read_orientation': Str % Choices(['both', 'forward', 233 | 'reverse'])}, 234 | outputs=[('reads', FeatureData[Sequence])], 235 | name='Extract reads from reference sequences.', 236 | description='Extract simulated amplicon reads from a reference database. ' 237 | 'Performs in-silico PCR to extract simulated amplicons from ' 238 | 'reference sequences that match the input primer sequences ' 239 | '(within the mismatch threshold specified by `identity`). ' 240 | 'Both primer sequences must be in the 5\' -> 3\' orientation. ' 241 | 'Sequences that fail to match both primers will be excluded. ' 242 | 'Reads are extracted, trimmed, and filtered in the following ' 243 | 'order: 1. reads are extracted in specified orientation; 2. ' 244 | 'primers are removed; 3. reads longer than `max_length` are ' 245 | 'removed; 4. reads are trimmed with `trim_right`; 5. reads ' 246 | 'are truncated to `trunc_len`; 6. reads are trimmed with ' 247 | '`trim_left`; 7. reads shorter than `min_length` are removed.', 248 | parameter_descriptions={ 249 | 'f_primer': 'forward primer sequence (5\' -> 3\').', 250 | 'r_primer': 'reverse primer sequence (5\' -> 3\'). Do not use reverse-' 251 | 'complemented primer sequence.', 252 | 'trim_right': 'trim_right nucleotides are removed from the 3\' end if ' 253 | 'trim_right is positive. Applied before trunc_len and ' 254 | 'trim_left.', 255 | 'trunc_len': 'read is cut to trunc_len if trunc_len is positive. ' 256 | 'Applied after trim_right but before trim_left.', 257 | 'trim_left': 'trim_left nucleotides are removed from the 5\' end if ' 258 | 'trim_left is positive. Applied after trim_right and ' 259 | 'trunc_len.', 260 | 'identity': 'minimum combined primer match identity threshold.', 261 | 'min_length': 'Minimum amplicon length. Shorter amplicons are ' 262 | 'discarded. Applied after trimming and truncation, so ' 263 | 'be aware that trimming may impact sequence retention. ' 264 | 'Set to zero to disable min length filtering.', 265 | 'max_length': 'Maximum amplicon length. Longer amplicons are ' 266 | 'discarded. 
Applied before trimming and truncation, ' 267 | 'so plan accordingly. Set to zero (default) to disable ' 268 | 'max length filtering.', 269 | 'n_jobs': 'Number of separate processes to run.', 270 | 'batch_size': 'Number of sequences to process in a batch. The `auto` ' 271 | 'option is calculated from the number of sequences and ' 272 | 'number of jobs specified.', 273 | 'read_orientation': 'Orientation of primers relative to the ' 274 | 'sequences: "forward" searches for primer hits in ' 275 | 'the forward direction, "reverse" searches ' 276 | 'reverse-complement, and "both" searches both ' 277 | 'directions.'} 278 | ) 279 | -------------------------------------------------------------------------------- /q2_feature_classifier/_blast.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright (c) 2016-2025, QIIME 2 development team. 3 | # 4 | # Distributed under the terms of the Modified BSD License. 5 | # 6 | # The full license is in the file LICENSE, distributed with this software. 7 | # ---------------------------------------------------------------------------- 8 | 9 | import os 10 | import warnings 11 | import subprocess 12 | import pandas as pd 13 | from q2_types.feature_data import ( 14 | FeatureData, Taxonomy, Sequence, DNAFASTAFormat, DNAIterator, BLAST6, 15 | BLAST6Format) 16 | from .types import BLASTDBDirFmtV5, BLASTDB 17 | from qiime2.plugin import ( 18 | Int, Str, Float, Choices, Range, Bool, Threads, get_available_cores 19 | ) 20 | from .plugin_setup import plugin, citations 21 | from ._consensus_assignment import ( 22 | min_consensus_param, min_consensus_param_description, 23 | DEFAULTUNASSIGNABLELABEL) 24 | 25 | # --------------------------------------------------------------- 26 | # Reason for num_thread not being exposed. 27 | # BLAST doesn't allow threading when a subject is provided (as of 2/19/20). 28 | # num_thread was removed to prevent a warning that stated: 29 | # "'Num_thread' is currently ignored when 'subject' is specified" (issue #77). 30 | # Seen here: https://github.com/qiime2/q2-feature-classifier/issues/77. 31 | # A '-subject' input is required in this function. 32 | # Therefore num_thread is not exposable. 33 | # --------------------------------------------------------------- 34 | 35 | 36 | # Specify default settings for various functions 37 | DEFAULTMAXACCEPTS = 10 38 | DEFAULTPERCENTID = 0.8 39 | DEFAULTQUERYCOV = 0.8 40 | DEFAULTSTRAND = 'both' 41 | DEFAULTEVALUE = 0.001 42 | DEFAULTMINCONSENSUS = 0.51 43 | DEFAULTOUTPUTNOHITS = True 44 | DEFAULTNUMTHREADS = 1 45 | 46 | 47 | # NOTE FOR THE FUTURE: should this be called blastn? would it be possible to 48 | # eventually generalize to e.g., blastp or blastx? or will this be too 49 | # challenging, e.g., to detect the format internally? A `mode` parameter could 50 | # be added and TypeMapped to the input type, a bit cumbersome but one way to 51 | # accomplish this without code bloat. But the question is: would we want to 52 | # expose different parameters etc? My feeling is let's call this `blast` for 53 | # now and then cross that bridge when we come to it.
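# A sketch of the assembled command (for illustration only; file paths are
# placeholders managed by QIIME 2): with the defaults above and
# reference_reads supplied as the subject, the command built below resembles:
#   blastn -query query.fasta -evalue 0.001 -strand both -outfmt 6
#          -perc_identity 80.0 -qcov_hsp_perc 80.0 -num_threads 1
#          -max_target_seqs 10 -out results.tsv -subject ref.fasta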
54 | def blast(query: DNAFASTAFormat, 55 | reference_reads: DNAFASTAFormat = None, 56 | blastdb: BLASTDBDirFmtV5 = None, 57 | maxaccepts: int = DEFAULTMAXACCEPTS, 58 | perc_identity: float = DEFAULTPERCENTID, 59 | query_cov: float = DEFAULTQUERYCOV, 60 | strand: str = DEFAULTSTRAND, 61 | evalue: float = DEFAULTEVALUE, 62 | output_no_hits: bool = DEFAULTOUTPUTNOHITS, 63 | num_threads: int = DEFAULTNUMTHREADS) -> pd.DataFrame: 64 | if num_threads == 0: 65 | num_threads = get_available_cores() 66 | 67 | if reference_reads and blastdb: 68 | raise ValueError('Only one of reference_reads or blastdb can be ' 69 | 'provided as input. Choose one and try ' 70 | 'again.') 71 | perc_identity = perc_identity * 100 72 | query_cov = query_cov * 100 73 | seqs_fp = str(query) 74 | # TODO: generalize to support other blast types? 75 | output = BLAST6Format() 76 | cmd = ['blastn', '-query', seqs_fp, '-evalue', str(evalue), '-strand', 77 | strand, '-outfmt', '6', '-perc_identity', str(perc_identity), 78 | '-qcov_hsp_perc', str(query_cov), '-num_threads', str(num_threads), 79 | '-max_target_seqs', str(maxaccepts), '-out', str(output)] 80 | if reference_reads: 81 | cmd.extend(['-subject', str(reference_reads)]) 82 | if num_threads > 1: 83 | warnings.warn('The num_threads parameter is only compatible ' 84 | 'with a pre-indexed blastdb. num_threads is ' 85 | 'ignored when reference_reads are provided as ' 86 | 'input.', UserWarning) 87 | elif blastdb: 88 | cmd.extend(['-db', os.path.join(blastdb.path, blastdb.get_basename())]) 89 | else: 90 | raise ValueError('Either reference_reads or a blastdb must be ' 91 | 'provided as input.') 92 | _run_command(cmd) 93 | # load as dataframe to quickly validate (note: will fail now if empty) 94 | result = output.view(pd.DataFrame) 95 | 96 | # blast will not report reads with no hits, so we add them back here 97 | # to ensure that non-matching queries are explicitly reported to the user. 98 | if output_no_hits: 99 | ids_with_hit = set(result['qseqid'].unique()) 100 | query_ids = {seq.metadata['id'] for seq in query.view(DNAIterator)} 101 | missing_ids = query_ids - ids_with_hit 102 | if len(missing_ids) > 0: 103 | # we will mirror vsearch behavior and annotate unassigneds as '*' 104 | # and fill other columns with 0 values (np.nan makes format error).
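            # e.g. (illustrative ID): a query 'q7' with no hits becomes the
            # row {'qseqid': 'q7', 'sseqid': '*'} with the remaining BLAST6
            # columns filled with 0.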
105 | missed = pd.DataFrame( 106 | [{'qseqid': i, 'sseqid': '*'} for i in missing_ids], 107 | columns=result.columns).fillna(0) 108 | result = pd.concat([result, missed], axis=0) 109 | return result 110 | 111 | 112 | def classify_consensus_blast(ctx, 113 | query, 114 | reference_taxonomy, 115 | blastdb=None, 116 | reference_reads=None, 117 | maxaccepts=DEFAULTMAXACCEPTS, 118 | perc_identity=DEFAULTPERCENTID, 119 | query_cov=DEFAULTQUERYCOV, 120 | strand=DEFAULTSTRAND, 121 | evalue=DEFAULTEVALUE, 122 | output_no_hits=DEFAULTOUTPUTNOHITS, 123 | min_consensus=DEFAULTMINCONSENSUS, 124 | unassignable_label=DEFAULTUNASSIGNABLELABEL, 125 | num_threads=DEFAULTNUMTHREADS): 126 | if num_threads == 0: 127 | num_threads = get_available_cores() 128 | 129 | search_db = ctx.get_action('feature_classifier', 'blast') 130 | lca = ctx.get_action('feature_classifier', 'find_consensus_annotation') 131 | result, = search_db(query=query, blastdb=blastdb, 132 | reference_reads=reference_reads, 133 | maxaccepts=maxaccepts, perc_identity=perc_identity, 134 | query_cov=query_cov, strand=strand, evalue=evalue, 135 | output_no_hits=output_no_hits, num_threads=num_threads) 136 | consensus, = lca(search_results=result, 137 | reference_taxonomy=reference_taxonomy, 138 | min_consensus=min_consensus, 139 | unassignable_label=unassignable_label) 140 | # New: add BLAST6Format result as an output. This could just as well be a 141 | # visualizer generated from these results (using q2-metadata tabulate). 142 | # Would that be more useful to the user than the QZA? 143 | return consensus, result 144 | 145 | 146 | def makeblastdb(sequences: DNAFASTAFormat) -> BLASTDBDirFmtV5: 147 | database = BLASTDBDirFmtV5() 148 | build_cmd = ['makeblastdb', '-blastdb_version', '5', '-dbtype', 'nucl', 149 | '-title', 'blastdb', '-in', str(sequences), 150 | '-out', os.path.join(str(database.path), 'blastdb')] 151 | _run_command(build_cmd) 152 | return database 153 | 154 | 155 | # Replace this function with QIIME2 API for wrapping commands/binaries, 156 | # pending https://github.com/qiime2/qiime2/issues/224 157 | def _run_command(cmd, verbose=True): 158 | if verbose: 159 | print("Running external command line application. This may print " 160 | "messages to stdout and/or stderr.") 161 | print("The command being run is below. This command cannot " 162 | "be manually re-run as it will depend on temporary files that " 163 | "no longer exist.") 164 | print("\nCommand:", end=' ') 165 | print(" ".join(cmd), end='\n\n') 166 | subprocess.run(cmd, check=True) 167 | 168 | 169 | inputs = {'query': FeatureData[Sequence], 170 | 'blastdb': BLASTDB, 171 | 'reference_reads': FeatureData[Sequence]} 172 | 173 | input_descriptions = {'query': 'Query sequences.', 174 | 'blastdb': 'BLAST indexed database. Incompatible with ' 175 | 'reference_reads.', 176 | 'reference_reads': 'Reference sequences.
Incompatible ' 177 | 'with blastdb.'} 178 | 179 | classification_output = ('classification', FeatureData[Taxonomy]) 180 | 181 | classification_output_description = { 182 | 'classification': 'Taxonomy classifications of query sequences.'} 183 | 184 | parameters = {'evalue': Float, 185 | 'maxaccepts': Int % Range(1, None), 186 | 'perc_identity': Float % Range(0.0, 1.0, inclusive_end=True), 187 | 'query_cov': Float % Range(0.0, 1.0, inclusive_end=True), 188 | 'strand': Str % Choices(['both', 'plus', 'minus']), 189 | 'output_no_hits': Bool, 190 | 'num_threads': Threads, 191 | } 192 | 193 | parameter_descriptions = { 194 | 'evalue': 'BLAST expectation value (E) threshold for saving hits.', 195 | 'strand': ('Align against reference sequences in forward ("plus"), ' 196 | 'reverse ("minus"), or both directions ("both").'), 197 | 'maxaccepts': ('Maximum number of hits to keep for each query. BLAST will ' 198 | 'choose the first N hits in the reference database that ' 199 | 'exceed perc_identity similarity to query. NOTE: the ' 200 | 'database is not sorted by similarity to query, so these ' 201 | 'are the first N hits that pass the threshold, not ' 202 | 'necessarily the top N hits.'), 203 | 'perc_identity': ('Reject match if percent identity to query is lower.'), 204 | 'query_cov': 'Reject match if query alignment coverage per high-' 205 | 'scoring pair is lower. Note: this uses blastn\'s ' 206 | 'qcov_hsp_perc parameter, and may not behave identically ' 207 | 'to the query_cov parameter used by classify-consensus-' 208 | 'vsearch.', 209 | 'output_no_hits': 'Report both matching and non-matching queries. ' 210 | 'WARNING: always use the default setting for this ' 211 | 'option unless you know what you are doing! If ' 212 | 'you set this option to False, your sequences and ' 213 | 'feature table will need to be filtered to exclude ' 214 | 'unclassified sequences, otherwise you may run into ' 215 | 'errors downstream from missing feature IDs. Set to ' 216 | 'FALSE to mirror default BLAST search.', 217 | 'num_threads': 'Number of threads (CPUs) to use in the BLAST search. ' 218 | 'Pass 0 to use all available CPUs.', 219 | } 220 | 221 | blast6_output = ('search_results', FeatureData[BLAST6]) 222 | 223 | blast6_output_description = {'search_results': 'Top hits for each query.'} 224 | 225 | 226 | plugin.methods.register_function( 227 | function=makeblastdb, 228 | inputs={'sequences': FeatureData[Sequence]}, 229 | parameters={}, 230 | outputs=[('database', BLASTDB)], 231 | input_descriptions={'sequences': 'Input reference sequences.'}, 232 | parameter_descriptions={}, 233 | output_descriptions={'database': 'Output BLAST database.'}, 234 | name='Make BLAST database.', 235 | description=('Make BLAST database from custom sequence collection.'), 236 | citations=[citations['camacho2009blast+']] 237 | ) 238 | 239 | 240 | # Note: name should be changed to blastn if we do NOT generalize this function 241 | plugin.methods.register_function( 242 | function=blast, 243 | inputs=inputs, 244 | parameters=parameters, 245 | outputs=[blast6_output], 246 | input_descriptions=input_descriptions, 247 | parameter_descriptions=parameter_descriptions, 248 | output_descriptions=blast6_output_description, 249 | name='BLAST+ local alignment search.', 250 | description=('Search for top hits in a reference database via local ' 251 | 'alignment between the query sequences and reference ' 252 | 'database sequences using BLAST+.
Returns a report ' 253 | 'of the top M hits for each query (where M=maxaccepts).'), 254 | citations=[citations['camacho2009blast+']] 255 | ) 256 | 257 | 258 | plugin.pipelines.register_function( 259 | function=classify_consensus_blast, 260 | inputs={**inputs, 261 | 'reference_taxonomy': FeatureData[Taxonomy]}, 262 | parameters={**parameters, 263 | **min_consensus_param, 264 | 'unassignable_label': Str}, 265 | outputs=[classification_output, blast6_output], 266 | input_descriptions={**input_descriptions, 267 | 'reference_taxonomy': 'reference taxonomy labels.'}, 268 | parameter_descriptions={ 269 | **parameter_descriptions, 270 | **min_consensus_param_description, 271 | 'unassignable_label': 'Annotation given to sequences without any hits.' 272 | }, 273 | output_descriptions={**classification_output_description, 274 | **blast6_output_description}, 275 | name='BLAST+ consensus taxonomy classifier', 276 | description=('Assign taxonomy to query sequences using BLAST+. Performs ' 277 | 'BLAST+ local alignment between query and reference_reads, ' 278 | 'then assigns consensus taxonomy to each query sequence from ' 279 | 'among maxaccepts hits, min_consensus of which share ' 280 | 'that taxonomic assignment. Note that maxaccepts selects the ' 281 | 'first N hits with > perc_identity similarity to query, ' 282 | 'not the top N matches. For top N hits, use ' 283 | 'classify-consensus-vsearch.'), 284 | citations=[citations['camacho2009blast+']] 285 | ) 286 | -------------------------------------------------------------------------------- /q2_feature_classifier/tests/test_classifier.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright (c) 2016-2025, QIIME 2 development team. 3 | # 4 | # Distributed under the terms of the Modified BSD License. 5 | # 6 | # The full license is in the file LICENSE, distributed with this software. 7 | # ---------------------------------------------------------------------------- 8 | 9 | import json 10 | import os 11 | 12 | from unittest.mock import patch 13 | from qiime2.sdk import Artifact 14 | from q2_types.feature_data import DNAIterator 15 | from qiime2.plugins import feature_classifier 16 | import pandas as pd 17 | import skbio 18 | import biom 19 | 20 | from q2_feature_classifier._skl import _specific_fitters, _TaxonNode 21 | from q2_feature_classifier.classifier import spec_from_pipeline, \ 22 | pipeline_from_spec, populate_class_weight, _autotune_reads_per_batch 23 | from . 
import FeatureClassifierTestPluginBase 24 | 25 | 26 | class ClassifierTests(FeatureClassifierTestPluginBase): 27 | package = 'q2_feature_classifier.tests' 28 | 29 | def setUp(self): 30 | super().setUp() 31 | self.taxonomy = Artifact.import_data( 32 | 'FeatureData[Taxonomy]', self.get_data_path('taxonomy.tsv')) 33 | 34 | self.seq_path = self.get_data_path('se-dna-sequences.fasta') 35 | reads = Artifact.import_data('FeatureData[Sequence]', self.seq_path) 36 | fitter_name = _specific_fitters[0][0] 37 | fitter = getattr(feature_classifier.methods, 38 | 'fit_classifier_' + fitter_name) 39 | self.classifier = fitter(reads, self.taxonomy).classifier 40 | 41 | def test_fit_classifier(self): 42 | # fit_classifier should generate a working taxonomic_classifier 43 | reads = Artifact.import_data( 44 | 'FeatureData[Sequence]', 45 | self.get_data_path('se-dna-sequences.fasta')) 46 | 47 | classify = feature_classifier.methods.classify_sklearn 48 | result = classify(reads, self.classifier) 49 | 50 | ref = self.taxonomy.view(pd.Series).to_dict() 51 | classified = result.classification.view(pd.Series).to_dict() 52 | 53 | right = 0. 54 | for taxon in classified: 55 | right += ref[taxon].startswith(classified[taxon]) 56 | self.assertGreater(right/len(classified), 0.95) 57 | 58 | def test_populate_class_weight(self): 59 | # should populate the class weight of a pipeline 60 | weights = Artifact.import_data( 61 | 'FeatureTable[RelativeFrequency]', 62 | self.get_data_path('class_weight.biom')) 63 | table = weights.view(biom.Table) 64 | 65 | svc_spec = [['feat_ext', 66 | {'__type__': 'feature_extraction.text.HashingVectorizer', 67 | 'analyzer': 'char_wb', 68 | 'n_features': 8192, 69 | 'ngram_range': (8, 8), 70 | 'alternate_sign': False}], 71 | ['classify', 72 | {'__type__': 'naive_bayes.GaussianNB'}]] 73 | pipeline1 = pipeline_from_spec(svc_spec) 74 | populate_class_weight(pipeline1, table) 75 | 76 | classes = table.ids('observation') 77 | class_weights = [] 78 | for wts in table.iter_data(): 79 | class_weights.append(zip(classes, wts)) 80 | svc_spec[1][1]['priors'] = list(zip(*sorted(class_weights[0])))[1] 81 | pipeline2 = pipeline_from_spec(svc_spec) 82 | 83 | for a, b in zip(pipeline1.get_params()['classify__priors'], 84 | pipeline2.get_params()['classify__priors']): 85 | self.assertAlmostEqual(a, b) 86 | 87 | def test_class_weight(self): 88 | # we should be able to input class_weight to fit_classifier 89 | weights = Artifact.import_data( 90 | 'FeatureTable[RelativeFrequency]', 91 | self.get_data_path('class_weight.biom')) 92 | reads = Artifact.import_data( 93 | 'FeatureData[Sequence]', 94 | self.get_data_path('se-dna-sequences.fasta')) 95 | 96 | fitter = feature_classifier.methods.fit_classifier_naive_bayes 97 | classifier1 = fitter(reads, self.taxonomy, class_weight=weights) 98 | classifier1 = classifier1.classifier 99 | 100 | class_weight = weights.view(biom.Table) 101 | classes = class_weight.ids('observation') 102 | class_weights = [] 103 | for wts in class_weight.iter_data(): 104 | class_weights.append(zip(classes, wts)) 105 | priors = json.dumps(list(zip(*sorted(class_weights[0])))[1]) 106 | classifier2 = fitter(reads, self.taxonomy, 107 | classify__class_prior=priors).classifier 108 | 109 | classify = feature_classifier.methods.classify_sklearn 110 | result1 = classify(reads, classifier1) 111 | result1 = result1.classification.view(pd.Series).to_dict() 112 | result2 = classify(reads, classifier2) 113 | result2 = result2.classification.view(pd.Series).to_dict() 114 | self.assertEqual(result1, result2) 
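        # The same check is repeated below with the generic sklearn fitter:
        # passing class_weight should be equivalent to embedding the weights
        # directly in the classifier specification.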
115 | 116 | svc_spec = [['feat_ext', 117 | {'__type__': 'feature_extraction.text.HashingVectorizer', 118 | 'analyzer': 'char_wb', 119 | 'n_features': 8192, 120 | 'ngram_range': (8, 8), 121 | 'alternate_sign': False}], 122 | ['classify', 123 | {'__type__': 'linear_model.LogisticRegression'}]] 124 | classifier_spec = json.dumps(svc_spec) 125 | gen_fitter = feature_classifier.methods.fit_classifier_sklearn 126 | classifier1 = gen_fitter(reads, self.taxonomy, classifier_spec, 127 | class_weight=weights).classifier 128 | 129 | svc_spec[1][1]['class_weight'] = dict(class_weights[0]) 130 | classifier_spec = json.dumps(svc_spec) 131 | gen_fitter = feature_classifier.methods.fit_classifier_sklearn 132 | classifier2 = gen_fitter(reads, self.taxonomy, classifier_spec 133 | ).classifier 134 | 135 | result1 = classify(reads, classifier1) 136 | result1 = result1.classification.view(pd.Series).to_dict() 137 | result2 = classify(reads, classifier2) 138 | result2 = result2.classification.view(pd.Series).to_dict() 139 | self.assertEqual(set(result1.keys()), set(result2.keys())) 140 | for k in result1: 141 | self.assertEqual(result1[k], result2[k]) 142 | 143 | def test_fit_specific_classifiers(self): 144 | # specific and general classifiers should produce the same results 145 | gen_fitter = feature_classifier.methods.fit_classifier_sklearn 146 | classify = feature_classifier.methods.classify_sklearn 147 | reads = Artifact.import_data( 148 | 'FeatureData[Sequence]', 149 | self.get_data_path('se-dna-sequences.fasta')) 150 | 151 | for name, spec in _specific_fitters: 152 | classifier_spec = json.dumps(spec) 153 | result = gen_fitter(reads, self.taxonomy, classifier_spec) 154 | result = classify(reads, result.classifier) 155 | gc = result.classification.view(pd.Series).to_dict() 156 | spec_fitter = getattr(feature_classifier.methods, 157 | 'fit_classifier_' + name) 158 | result = spec_fitter(reads, self.taxonomy) 159 | result = classify(reads, result.classifier) 160 | sc = result.classification.view(pd.Series).to_dict() 161 | for taxon in gc: 162 | self.assertEqual(gc[taxon], sc[taxon]) 163 | 164 | def test_pipeline_serialisation(self): 165 | # pipeline inflation and deflation should be inverse operations 166 | for name, spec in _specific_fitters: 167 | pipeline = pipeline_from_spec(spec) 168 | spec_one = spec_from_pipeline(pipeline) 169 | pipeline = pipeline_from_spec(spec_one) 170 | spec_two = spec_from_pipeline(pipeline) 171 | self.assertEqual(spec_one, spec_two) 172 | 173 | def test_classify(self): 174 | # test read direction detection and parallel classification 175 | classify = feature_classifier.methods.classify_sklearn 176 | seq_path = self.get_data_path('se-dna-sequences.fasta') 177 | reads = Artifact.import_data('FeatureData[Sequence]', seq_path) 178 | raw_reads = skbio.io.read( 179 | seq_path, format='fasta', constructor=skbio.DNA) 180 | rev_path = os.path.join(self.temp_dir.name, 'rev-dna-sequences.fasta') 181 | skbio.io.write((s.reverse_complement() for s in raw_reads), 182 | 'fasta', rev_path) 183 | rev_reads = Artifact.import_data('FeatureData[Sequence]', rev_path) 184 | 185 | result = classify(reads, self.classifier, 186 | read_orientation='auto') 187 | fc = result.classification.view(pd.Series).to_dict() 188 | result = classify(rev_reads, self.classifier, 189 | read_orientation='auto') 190 | rc = result.classification.view(pd.Series).to_dict() 191 | 192 | for taxon in fc: 193 | self.assertEqual(rc[taxon], fc[taxon]) 194 | 195 | result = classify(reads, self.classifier, read_orientation='same') 
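        # ('same' and 'reverse-complement' skip orientation auto-detection;
        # classifying the reads and their reverse complements this way should
        # produce identical labels.)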
196 | fc = result.classification.view(pd.Series).to_dict() 197 | result = classify(rev_reads, self.classifier, 198 | read_orientation='reverse-complement') 199 | rc = result.classification.view(pd.Series).to_dict() 200 | 201 | for taxon in fc: 202 | self.assertEqual(fc[taxon], rc[taxon]) 203 | 204 | result = classify(reads, self.classifier, reads_per_batch=100, 205 | n_jobs=2, read_orientation='auto') 206 | cc = result.classification.view(pd.Series).to_dict() 207 | 208 | for taxon in fc: 209 | self.assertEqual(fc[taxon], cc[taxon]) 210 | 211 | def test_unassigned_taxa(self): 212 | # classifications that don't meet the threshold should be "Unassigned" 213 | classify = feature_classifier.methods.classify_sklearn 214 | seq_path = self.get_data_path('se-dna-sequences.fasta') 215 | reads = Artifact.import_data('FeatureData[Sequence]', seq_path) 216 | result = classify(reads, self.classifier, confidence=1.) 217 | 218 | ref = self.taxonomy.view(pd.Series).to_dict() 219 | classified = result.classification.view(pd.Series).to_dict() 220 | 221 | assert 'Unassigned' in classified.values() 222 | for seq in reads.view(DNAIterator): 223 | id_ = seq.metadata['id'] 224 | assert ref[id_].startswith(classified[id_]) or \ 225 | classified[id_] == 'Unassigned' 226 | 227 | def test_autotune_reads_per_batch(self): 228 | self.assertEqual( 229 | _autotune_reads_per_batch(self.seq_path, n_jobs=4), 276) 230 | 231 | def test_autotune_reads_per_batch_disable_if_single_job(self): 232 | self.assertEqual( 233 | _autotune_reads_per_batch(self.seq_path, n_jobs=1), 20000) 234 | 235 | def test_autotune_reads_per_batch_zero_jobs(self): 236 | with self.assertRaisesRegex( 237 | ValueError, "Value other than zero must be specified"): 238 | _autotune_reads_per_batch(self.seq_path, n_jobs=0) 239 | 240 | def test_autotune_reads_per_batch_ceil(self): 241 | self.assertEqual( 242 | _autotune_reads_per_batch(self.seq_path, n_jobs=5), 221) 243 | 244 | def test_autotune_reads_per_batch_more_jobs_than_reads(self): 245 | self.assertEqual( 246 | _autotune_reads_per_batch(self.seq_path, n_jobs=1105), 1) 247 | 248 | def test_TaxonNode_create_tree(self): 249 | classes = ['a;b;c', 'a;b;d', 'a;e;f', 'a;e;g'] 250 | separator = ';' 251 | tree = _TaxonNode.create_tree(classes, separator) 252 | self.assertEqual( 253 | tree.children['a'].children['b'].children['c'].name, 'c') 254 | self.assertEqual( 255 | tree.children['a'].children['b'].children['d'].name, 'd') 256 | self.assertEqual( 257 | tree.children['a'].children['e'].children['f'].name, 'f') 258 | self.assertEqual( 259 | tree.children['a'].children['e'].children['g'].name, 'g') 260 | 261 | def test_TaxonNode_range(self): 262 | classes = ['a;b;c', 'a;b;d', 'a;e;f', 'a;e;g'] 263 | separator = ';' 264 | tree = _TaxonNode.create_tree(classes, separator) 265 | self.assertEqual( 266 | tree.children['a'].children['b'].children['c'].range, range(0, 1)) 267 | self.assertEqual( 268 | tree.children['a'].children['b'].children['d'].range, range(1, 2)) 269 | self.assertEqual( 270 | tree.children['a'].children['e'].children['f'].range, range(2, 3)) 271 | self.assertEqual( 272 | tree.children['a'].children['e'].children['g'].range, range(3, 4)) 273 | self.assertEqual( 274 | tree.children['a'].children['b'].range, range(0, 2)) 275 | self.assertEqual( 276 | tree.children['a'].children['e'].range, range(2, 4)) 277 | 278 | def test_TaxonNode_num_leaf_nodes(self): 279 | classes = ['a;b;c', 'a;b;d', 'a;e;f', 'a;e;g'] 280 | separator = ';' 281 | tree = _TaxonNode.create_tree(classes, separator) 282 | 
self.assertEqual(tree.num_leaf_nodes, 4) 283 | self.assertEqual(tree.children['a'].num_leaf_nodes, 4) 284 | self.assertEqual(tree.children['a'].children['b'].num_leaf_nodes, 2) 285 | self.assertEqual(tree.children['a'].children['e'].num_leaf_nodes, 2) 286 | 287 | def test_both_orientations_patched_data(self): 288 | """ 289 | This function tests the functionality of the `both` orientation 290 | option for `classify_sklearn` by using patched data and asserting that 291 | the `both` data frame always contains the classifications with higher 292 | confidence. 293 | """ 294 | with patch('q2_feature_classifier.classifier.predict') as mock_predict: 295 | mock_predict.side_effect = [ 296 | [('DNA_SEQUENCE_1', 'k__Bacteria, p__A', 0.6), 297 | ('DNA_SEQUENCE_2', 'k__Bacteria, p__B', 0.9)], 298 | 299 | [('DNA_SEQUENCE_1', 'k__Bacteria, p__C', 0.9), 300 | ('DNA_SEQUENCE_2', 'k__Bacteria, p__D', 0.6)], 301 | 302 | [('DNA_SEQUENCE_1', 'k__Bacteria, p__C', 0.9), 303 | ('DNA_SEQUENCE_2', 'k__Bacteria, p__D', 0.6)], 304 | 305 | [('DNA_SEQUENCE_1', 'k__Bacteria, p__A', 0.6), 306 | ('DNA_SEQUENCE_2', 'k__Bacteria, p__B', 0.9)] 307 | ] 308 | 309 | classify = feature_classifier.methods.classify_sklearn 310 | seq_path = self.get_data_path('dna_sequence_both_test.fasta') 311 | reads = Artifact.import_data('FeatureData[Sequence]', seq_path) 312 | class_fwd = classify(reads, self.classifier, 313 | read_orientation='same') 314 | class_rev = classify(reads, self.classifier, 315 | read_orientation='reverse-complement') 316 | class_both = classify(reads, self.classifier, 317 | read_orientation='both') 318 | 319 | fc_df = class_fwd.classification.view(pd.DataFrame) 320 | rc_df = class_rev.classification.view(pd.DataFrame) 321 | bc_df = class_both.classification.view(pd.DataFrame) 322 | conf_fwd = float(fc_df.loc['DNA_SEQUENCE_1', 'Confidence']) 323 | conf_rev = float(rc_df.loc['DNA_SEQUENCE_1', 'Confidence']) 324 | self.assertGreater(conf_rev, conf_fwd) 325 | 326 | bc_tax_1 = bc_df.loc['DNA_SEQUENCE_1', 'Taxon'] 327 | rc_tax_1 = rc_df.loc['DNA_SEQUENCE_1', 'Taxon'] 328 | fc_tax_1 = fc_df.loc['DNA_SEQUENCE_1', 'Taxon'] 329 | self.assertNotEqual(fc_tax_1, rc_tax_1) 330 | self.assertEqual(bc_tax_1, rc_tax_1) 331 | 332 | conf_fwd_2 = float(fc_df.loc['DNA_SEQUENCE_2', 'Confidence']) 333 | conf_rev_2 = float(rc_df.loc['DNA_SEQUENCE_2', 'Confidence']) 334 | self.assertGreater(conf_fwd_2, conf_rev_2) 335 | 336 | bc_tax_2 = bc_df.loc['DNA_SEQUENCE_2', 'Taxon'] 337 | rc_tax_2 = rc_df.loc['DNA_SEQUENCE_2', 'Taxon'] 338 | fc_tax_2 = fc_df.loc['DNA_SEQUENCE_2', 'Taxon'] 339 | self.assertNotEqual(rc_tax_2, fc_tax_2) 340 | self.assertEqual(bc_tax_2, fc_tax_2) 341 | 342 | def test_both_orientations_real_data(self): 343 | """ 344 | This tests the functionality of the `both` orientation option for 345 | `classify_sklearn` by asserting that the `both` data frame always 346 | contains the classification with higher confidence. 
347 | """ 348 | classify = feature_classifier.methods.classify_sklearn 349 | sequence_path = self.get_data_path('moving-pictures-rep-seqs.fasta') 350 | reads = Artifact.import_data('FeatureData[Sequence]', sequence_path) 351 | 352 | class_fwd = classify(reads, self.classifier, read_orientation='same') 353 | class_rev = classify( 354 | reads, self.classifier, read_orientation='reverse-complement' 355 | ) 356 | class_both = classify(reads, self.classifier, read_orientation='both') 357 | 358 | fwd_df = class_fwd.classification.view(pd.DataFrame) 359 | rev_df = class_rev.classification.view(pd.DataFrame) 360 | both_df = class_both.classification.view(pd.DataFrame) 361 | 362 | for feature in both_df.index: 363 | if ( 364 | fwd_df.loc[feature, 'Confidence'] >= 365 | rev_df.loc[feature, 'Confidence'] 366 | ): 367 | higher_df = fwd_df 368 | else: 369 | higher_df = rev_df 370 | 371 | self.assertEqual( 372 | both_df.loc[feature, 'Taxon'], higher_df.loc[feature, 'Taxon'] 373 | ) 374 | self.assertEqual( 375 | both_df.loc[feature, 'Confidence'], 376 | higher_df.loc[feature, 'Confidence'] 377 | ) 378 | -------------------------------------------------------------------------------- /q2_feature_classifier/classifier.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright (c) 2016-2025, QIIME 2 development team. 3 | # 4 | # Distributed under the terms of the Modified BSD License. 5 | # 6 | # The full license is in the file LICENSE, distributed with this software. 7 | # ---------------------------------------------------------------------------- 8 | 9 | import json 10 | import importlib 11 | import inspect 12 | import warnings 13 | from itertools import chain, islice 14 | import subprocess 15 | 16 | import pandas as pd 17 | from qiime2.plugin import ( 18 | Int, Str, Float, Bool, Choices, Range, Threads, get_available_cores 19 | ) 20 | from q2_types.feature_data import ( 21 | FeatureData, Taxonomy, Sequence, DNAIterator, DNAFASTAFormat) 22 | from q2_types.feature_table import FeatureTable, RelativeFrequency 23 | from sklearn.pipeline import Pipeline 24 | import sklearn 25 | from numpy import median, array, ceil 26 | import biom 27 | import skbio 28 | import joblib 29 | 30 | from ._skl import fit_pipeline, predict, _specific_fitters 31 | from ._taxonomic_classifier import TaxonomicClassifier 32 | from .plugin_setup import plugin, citations 33 | 34 | 35 | def _load_class(classname): 36 | err_message = classname + ' is not a recognised class' 37 | if '.' 
not in classname: 38 | raise ValueError(err_message) 39 | module, klass = classname.rsplit('.', 1) 40 | if module == 'custom': 41 | module = importlib.import_module('.custom', 'q2_feature_classifier') 42 | elif importlib.util.find_spec('.'+module, 'sklearn') is not None: 43 | module = importlib.import_module('.'+module, 'sklearn') 44 | else: 45 | raise ValueError(err_message) 46 | if not hasattr(module, klass): 47 | raise ValueError(err_message) 48 | klass = getattr(module, klass) 49 | if not issubclass(klass, sklearn.base.BaseEstimator): 50 | raise ValueError(err_message) 51 | return klass 52 | 53 | 54 | def spec_from_pipeline(pipeline): 55 | class StepsEncoder(json.JSONEncoder): 56 | def default(self, obj): 57 | if hasattr(obj, 'get_params'): 58 | encoded = {} 59 | params = obj.get_params() 60 | subobjs = [] 61 | for key, value in params.items(): 62 | if hasattr(value, 'get_params'): 63 | subobjs.append(key + '__') 64 | 65 | for key, value in params.items(): 66 | for so in subobjs: 67 | if key.startswith(so): 68 | break 69 | else: 70 | if hasattr(value, 'get_params'): 71 | encoded[key] = self.default(value) 72 | try: 73 | json.dumps(value, cls=StepsEncoder) 74 | encoded[key] = value 75 | except TypeError: 76 | pass 77 | 78 | module = obj.__module__ 79 | type = module + '.' + obj.__class__.__name__ 80 | encoded['__type__'] = type.split('.', 1)[1] 81 | return encoded 82 | return json.JSONEncoder.default(self, obj) 83 | steps = pipeline.get_params()['steps'] 84 | return json.loads(json.dumps(steps, cls=StepsEncoder)) 85 | 86 | 87 | def pipeline_from_spec(spec): 88 | def as_steps(obj): 89 | if 'ngram_range' in obj: 90 | obj['ngram_range'] = tuple(obj['ngram_range']) 91 | if '__type__' in obj: 92 | klass = _load_class(obj['__type__']) 93 | return klass(**{k: v for k, v in obj.items() if k != '__type__'}) 94 | return obj 95 | 96 | steps = json.loads(json.dumps(spec), object_hook=as_steps) 97 | return Pipeline(steps) 98 | 99 | 100 | def warn_about_sklearn(): 101 | warning = ( 102 | 'The TaxonomicClassifier artifact that results from this method was ' 103 | 'trained using scikit-learn version %s. It cannot be used with other ' 104 | 'versions of scikit-learn. 
(While the classifier may complete ' 105 | 'successfully, the results will be unreliable.)' % sklearn.__version__) 106 | warnings.warn(warning, UserWarning) 107 | 108 | 109 | def populate_class_weight(pipeline, class_weight): 110 | classes = class_weight.ids('observation') 111 | class_weights = [] 112 | for weights in class_weight.iter_data(): 113 | class_weights.append(zip(classes, weights)) 114 | step, classifier = pipeline.steps[-1] 115 | for param in classifier.get_params(): 116 | if param == 'class_weight': 117 | class_weights = list(map(dict, class_weights)) 118 | if len(class_weights) == 1: 119 | class_weights = class_weights[0] 120 | pipeline.set_params(**{'__'.join([step, param]): class_weights}) 121 | elif param in ('priors', 'class_prior'): 122 | if len(class_weights) != 1: 123 | raise ValueError('naive_bayes classifiers do not support ' 124 | 'multilabel classification') 125 | priors = list(zip(*sorted(class_weights[0])))[1] 126 | pipeline.set_params(**{'__'.join([step, param]): priors}) 127 | return pipeline 128 | 129 | 130 | def fit_classifier_sklearn(reference_reads: DNAIterator, 131 | reference_taxonomy: pd.Series, 132 | classifier_specification: str, 133 | class_weight: biom.Table = None) -> Pipeline: 134 | warn_about_sklearn() 135 | spec = json.loads(classifier_specification) 136 | pipeline = pipeline_from_spec(spec) 137 | if class_weight is not None: 138 | pipeline = populate_class_weight(pipeline, class_weight) 139 | pipeline = fit_pipeline(reference_reads, reference_taxonomy, pipeline) 140 | return pipeline 141 | 142 | 143 | plugin.methods.register_function( 144 | function=fit_classifier_sklearn, 145 | inputs={'reference_reads': FeatureData[Sequence], 146 | 'reference_taxonomy': FeatureData[Taxonomy], 147 | 'class_weight': FeatureTable[RelativeFrequency]}, 148 | parameters={'classifier_specification': Str}, 149 | outputs=[('classifier', TaxonomicClassifier)], 150 | name='Train an almost arbitrary scikit-learn classifier', 151 | description='Train a scikit-learn classifier to classify reads.', 152 | citations=[citations['pedregosa2011scikit']] 153 | ) 154 | 155 | 156 | def _autodetect_orientation(reads, classifier, n=100, 157 | read_orientation=None): 158 | reads = iter(reads) 159 | try: 160 | read = next(reads) 161 | except StopIteration: 162 | raise ValueError('empty reads input') 163 | if not hasattr(classifier, "predict_proba"): 164 | warnings.warn("this classifier does not support confidence values, " 165 | "so read orientation autodetection is disabled", 166 | UserWarning) 167 | return reads 168 | reads = chain([read], reads) 169 | if read_orientation == 'same': 170 | return reads 171 | if read_orientation == 'reverse-complement': 172 | return (r.reverse_complement() for r in reads) 173 | if read_orientation == 'both': 174 | return reads 175 | first_n_reads = list(islice(reads, n)) 176 | result = list(zip(*predict(first_n_reads, classifier, confidence=0.))) 177 | _, _, same_confidence = result 178 | reversed_n_reads = [r.reverse_complement() for r in first_n_reads] 179 | result = list(zip(*predict(reversed_n_reads, classifier, confidence=0.))) 180 | _, _, reverse_confidence = result 181 | if median(array(same_confidence) - array(reverse_confidence)) > 0.: 182 | return chain(first_n_reads, reads) 183 | return chain(reversed_n_reads, (r.reverse_complement() for r in reads)) 184 | 185 | 186 | def _autotune_reads_per_batch(reads, n_jobs): 187 | # detect effective jobs. 
Will raise error if n_jobs == 0 188 | if n_jobs == 0: 189 | raise ValueError("Value other than zero must be specified as number " 190 | "of jobs to run.") 191 | else: 192 | n_jobs = joblib.effective_n_jobs(n_jobs) 193 | 194 | # we really only want to calculate this if running in parallel 195 | if n_jobs != 1: 196 | seq_count = subprocess.run( 197 | ['grep', '-c', '^>', str(reads)], check=True, 198 | stdout=subprocess.PIPE) 199 | # set a max value to avoid blowing up memory 200 | return min(int(ceil(int(seq_count.stdout.decode('utf-8')) / n_jobs)), 201 | 20000) 202 | # otherwise reads_per_batch = 20000, which has a modest memory overhead 203 | else: 204 | return 20000 205 | 206 | 207 | def classify_sklearn(reads: DNAFASTAFormat, classifier: Pipeline, 208 | reads_per_batch: int = 'auto', n_jobs: int = 1, 209 | pre_dispatch: str = '2*n_jobs', confidence: float = 0.7, 210 | read_orientation: str = 'auto' 211 | ) -> pd.DataFrame: 212 | 213 | if n_jobs == 0: 214 | n_jobs = get_available_cores() 215 | 216 | try: 217 | # autotune reads per batch 218 | if reads_per_batch == 'auto': 219 | reads_per_batch = _autotune_reads_per_batch(reads, n_jobs) 220 | 221 | # transform reads to DNAIterator 222 | reads_iter = DNAIterator( 223 | skbio.read(str(reads), format='fasta', constructor=skbio.DNA)) 224 | reads_iter = _autodetect_orientation( 225 | reads_iter, classifier, read_orientation=read_orientation) 226 | 227 | if read_orientation == 'both': 228 | same_predict = predict( 229 | reads_iter, 230 | classifier, 231 | chunk_size=reads_per_batch, 232 | n_jobs=n_jobs, 233 | pre_dispatch=pre_dispatch, 234 | confidence=confidence 235 | ) 236 | reads_reverse_iter = DNAIterator( 237 | skbio.read(str(reads), format='fasta', constructor=skbio.DNA)) 238 | reverse_comp_predict = predict( 239 | (r.reverse_complement() for r in reads_reverse_iter), 240 | classifier, 241 | chunk_size=reads_per_batch, 242 | n_jobs=n_jobs, 243 | pre_dispatch=pre_dispatch, 244 | confidence=confidence 245 | ) 246 | seq_ids_same, taxonomy_same, confidence_same = list(zip( 247 | *same_predict)) 248 | seq_ids_rc, taxonomy_rc, confidence_rc = list(zip( 249 | *reverse_comp_predict)) 250 | 251 | data_frame_forward = pd.DataFrame( 252 | {'Forward Taxon': taxonomy_same, 253 | 'Forward Confidence': confidence_same, 254 | 'Feature ID': seq_ids_same} 255 | ) 256 | 257 | data_frame_rc = pd.DataFrame( 258 | {'Reverse Taxon': taxonomy_rc, 259 | 'Reverse Confidence': confidence_rc, 260 | 'Feature ID': seq_ids_rc} 261 | ) 262 | 263 | result = pd.merge( 264 | data_frame_forward, data_frame_rc, on='Feature ID' 265 | ) 266 | 267 | def choose_confidence(row): 268 | if row['Forward Confidence'] >= row['Reverse Confidence']: 269 | return row['Forward Confidence'] 270 | else: 271 | return row['Reverse Confidence'] 272 | 273 | def choose_taxonomy(row): 274 | if row['Forward Confidence'] >= row['Reverse Confidence']: 275 | return row['Forward Taxon'] 276 | else: 277 | return row['Reverse Taxon'] 278 | 279 | result["Confidence Final"] = result.apply( 280 | choose_confidence, axis=1 281 | ) 282 | result['Taxon Final'] = result.apply(choose_taxonomy, axis=1) 283 | 284 | result.rename( 285 | columns={ 286 | 'Taxon Final': 'Taxon', 'Confidence Final': 'Confidence' 287 | }, 288 | inplace=True 289 | ) 290 | result = result[['Taxon', 'Confidence', 'Feature ID']] 291 | result.set_index('Feature ID', inplace=True) 292 | result.index.name = 'Feature ID' 293 | 294 | return result 295 | 296 | predictions = predict( 297 | reads_iter, 298 | classifier, 299 | 
chunk_size=reads_per_batch, 300 | n_jobs=n_jobs, 301 | pre_dispatch=pre_dispatch, 302 | confidence=confidence 303 | ) 304 | seq_ids, taxonomy, confidence = list(zip(*predictions)) 305 | 306 | result = pd.DataFrame({'Taxon': taxonomy, 'Confidence': confidence}, 307 | index=seq_ids, columns=['Taxon', 'Confidence']) 308 | result.index.name = 'Feature ID' 309 | return result 310 | except MemoryError: 311 | raise MemoryError("The operation has run out of available memory. " 312 | "To correct this error:\n" 313 | "1. Reduce the reads per batch\n" 314 | "2. Reduce the number of jobs (n_jobs)\n" 315 | "3. Use a more powerful machine or allocate " 316 | "more resources.") 317 | 318 | 319 | _classify_parameters = { 320 | 'reads_per_batch': Int % Range(1, None) | Str % Choices(['auto']), 321 | 'n_jobs': Threads, 322 | 'pre_dispatch': Str, 323 | 'confidence': Float % Range( 324 | 0, 1, inclusive_start=True, inclusive_end=True) | Str % Choices( 325 | ['disable']), 326 | 'read_orientation': Str % Choices(['same', 'reverse-complement', 'auto', 327 | 'both'])} 328 | 329 | _parameter_descriptions = { 330 | 'confidence': 'Confidence threshold for limiting ' 331 | 'taxonomic depth. Set to "disable" to disable ' 332 | 'confidence calculation, or 0 to calculate ' 333 | 'confidence but not apply it to limit the ' 334 | 'taxonomic depth of the assignments.', 335 | 'read_orientation': 'Direction of reads with ' 336 | 'respect to reference sequences. "same" will cause ' 337 | 'reads to be classified unchanged; "reverse-' 338 | 'complement" will cause reads to be reversed ' 339 | 'and complemented prior to classification. ' 340 | '"both" will classify sequences unchanged and in ' 341 | 'reverse-complement and retain the ' 342 | 'classification with higher confidence. ' 343 | '"auto" will autodetect orientation based on the ' 344 | 'confidence estimates for the first 100 reads.', 345 | 'reads_per_batch': 'Number of reads to process in each batch. If "auto", ' 346 | 'this parameter is autoscaled to ' 347 | 'min(number of query sequences / n_jobs, 20000).', 348 | 'n_jobs': 'The maximum number of concurrent worker processes. If 0, ' 349 | 'all CPUs are used. If 1 is given, no parallel computing ' 350 | 'code is used at all, which is useful for debugging.', 351 | 'pre_dispatch': '"all" or expression, as in "3*n_jobs". The number of ' 352 | 'batches (of tasks) to be pre-dispatched.' 353 | } 354 | 355 | plugin.methods.register_function( 356 | function=classify_sklearn, 357 | inputs={'reads': FeatureData[Sequence], 358 | 'classifier': TaxonomicClassifier}, 359 | parameters=_classify_parameters, 360 | outputs=[('classification', FeatureData[Taxonomy])], 361 | name='Pre-fitted sklearn-based taxonomy classifier', 362 | description='Classify reads by taxon using a fitted classifier.', 363 | input_descriptions={ 364 | 'reads': 'The feature data to be classified.', 365 | 'classifier': 'The taxonomic classifier for classifying the reads.'
366 | }, 367 | parameter_descriptions={**_parameter_descriptions}, 368 | citations=[citations['pedregosa2011scikit']] 369 | ) 370 | 371 | 372 | def _pipeline_signature(spec): 373 | type_map = {int: Int, float: Float, bool: Bool, str: Str} 374 | parameters = {} 375 | signature_params = [] 376 | pipeline = pipeline_from_spec(spec) 377 | params = pipeline.get_params() 378 | for param, default in sorted(params.items()): 379 | # weed out pesky memory parameter from skl 380 | # https://github.com/qiime2/q2-feature-classifier/issues/101 381 | if param == 'memory': 382 | continue 383 | try: 384 | json.dumps(default) 385 | except TypeError: 386 | continue 387 | kind = inspect.Parameter.POSITIONAL_OR_KEYWORD 388 | if type(default) in type_map: 389 | annotation = type(default) 390 | else: 391 | annotation = str 392 | default = json.dumps(default) 393 | new_param = inspect.Parameter(param, kind, default=default, 394 | annotation=annotation) 395 | signature_params.append(new_param) 396 | parameters[param] = type_map.get(annotation, Str) 397 | return parameters, signature_params 398 | 399 | 400 | def _register_fitter(name, spec): 401 | parameters, signature_params = _pipeline_signature(spec) 402 | 403 | def generic_fitter(reference_reads: DNAIterator, 404 | reference_taxonomy: pd.Series, 405 | class_weight: biom.Table = None, **kwargs) -> Pipeline: 406 | warn_about_sklearn() 407 | for param in kwargs: 408 | try: 409 | kwargs[param] = json.loads(kwargs[param]) 410 | except (json.JSONDecodeError, TypeError): 411 | pass 412 | if param == 'feat_ext__ngram_range': 413 | kwargs[param] = tuple(kwargs[param]) 414 | pipeline = pipeline_from_spec(spec) 415 | pipeline.set_params(**kwargs) 416 | if class_weight is not None: 417 | pipeline = populate_class_weight(pipeline, class_weight) 418 | pipeline = fit_pipeline(reference_reads, reference_taxonomy, 419 | pipeline) 420 | return pipeline 421 | 422 | generic_signature = inspect.signature(generic_fitter) 423 | new_params = list(generic_signature.parameters.values())[:-1] 424 | new_params.extend(signature_params) 425 | return_annotation = generic_signature.return_annotation 426 | new_signature = inspect.Signature(parameters=new_params, 427 | return_annotation=return_annotation) 428 | generic_fitter.__signature__ = new_signature 429 | generic_fitter.__name__ = 'fit_classifier_' + name 430 | plugin.methods.register_function( 431 | function=generic_fitter, 432 | inputs={'reference_reads': FeatureData[Sequence], 433 | 'reference_taxonomy': FeatureData[Taxonomy], 434 | 'class_weight': FeatureTable[RelativeFrequency]}, 435 | parameters=parameters, 436 | outputs=[('classifier', TaxonomicClassifier)], 437 | name='Train the ' + name + ' classifier', 438 | description='Create a scikit-learn ' + name + ' classifier for reads', 439 | citations=[citations['pedregosa2011scikit']] 440 | ) 441 | 442 | 443 | for name, pipeline in _specific_fitters: 444 | _register_fitter(name, pipeline) 445 | -------------------------------------------------------------------------------- /q2_feature_classifier/_vsearch.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright (c) 2016-2025, QIIME 2 development team. 3 | # 4 | # Distributed under the terms of the Modified BSD License. 5 | # 6 | # The full license is in the file LICENSE, distributed with this software. 
7 | # ---------------------------------------------------------------------------- 8 | 9 | import tempfile 10 | import qiime2 11 | import pandas as pd 12 | 13 | from q2_types.feature_data import ( 14 | FeatureData, Taxonomy, Sequence, DNAFASTAFormat, BLAST6, BLAST6Format) 15 | from .plugin_setup import plugin, citations 16 | from qiime2.plugin import Int, Str, Float, Choices, Range, Bool, Threads 17 | from ._blast import (_run_command) 18 | from ._consensus_assignment import (DEFAULTUNASSIGNABLELABEL, 19 | min_consensus_param, 20 | min_consensus_param_description) 21 | from ._taxonomic_classifier import TaxonomicClassifier 22 | from .classifier import _classify_parameters, _parameter_descriptions 23 | 24 | # Specify default settings for various functions 25 | DEFAULTMAXACCEPTS = 10 26 | DEFAULTPERCENTID = 0.8 27 | DEFAULTQUERYCOV = 0.8 28 | DEFAULTSTRAND = 'both' 29 | DEFAULTSEARCHEXACT = False 30 | DEFAULTTOPHITS = False 31 | DEFAULTMAXHITS = 'all' 32 | DEFAULTMAXREJECTS = 'all' 33 | DEFAULTOUTPUTNOHITS = True 34 | DEFAULTWEAKID = 0. 35 | DEFAULTTHREADS = 1 36 | DEFAULTMINCONSENSUS = 0.51 37 | 38 | 39 | def vsearch_global(query: DNAFASTAFormat, 40 | reference_reads: DNAFASTAFormat, 41 | maxaccepts: int = DEFAULTMAXACCEPTS, 42 | perc_identity: float = DEFAULTPERCENTID, 43 | query_cov: float = DEFAULTQUERYCOV, 44 | strand: str = DEFAULTSTRAND, 45 | search_exact: bool = DEFAULTSEARCHEXACT, 46 | top_hits_only: bool = DEFAULTTOPHITS, 47 | maxhits: int = DEFAULTMAXHITS, 48 | maxrejects: int = DEFAULTMAXREJECTS, 49 | output_no_hits: bool = DEFAULTOUTPUTNOHITS, 50 | weak_id: float = DEFAULTWEAKID, 51 | threads: str = DEFAULTTHREADS) -> BLAST6Format: 52 | seqs_fp = str(query) 53 | ref_fp = str(reference_reads) 54 | if maxaccepts == 'all': 55 | maxaccepts = 0 56 | if maxrejects == 'all': 57 | maxrejects = 0 58 | 59 | if search_exact: 60 | cmd = [ 61 | 'vsearch', 62 | '--search_exact', seqs_fp, 63 | '--strand', strand, 64 | '--db', ref_fp, 65 | '--threads', str(threads), 66 | ] 67 | else: 68 | cmd = [ 69 | 'vsearch', 70 | '--usearch_global', seqs_fp, 71 | '--id', str(perc_identity), 72 | '--query_cov', str(query_cov), 73 | '--strand', strand, 74 | '--maxaccepts', str(maxaccepts), 75 | '--maxrejects', str(maxrejects), 76 | '--db', ref_fp, 77 | '--threads', str(threads), 78 | ] 79 | 80 | if top_hits_only: 81 | cmd.append('--top_hits_only') 82 | if output_no_hits: 83 | cmd.append('--output_no_hits') 84 | if weak_id > 0 and weak_id < perc_identity: 85 | cmd.extend(['--weak_id', str(weak_id)]) 86 | if maxhits != 'all': 87 | cmd.extend(['--maxhits', str(maxhits)]) 88 | output = BLAST6Format() 89 | cmd.extend(['--blast6out', str(output)]) 90 | _run_command(cmd) 91 | return output 92 | 93 | 94 | def classify_consensus_vsearch(ctx, 95 | query, 96 | reference_reads, 97 | reference_taxonomy, 98 | maxaccepts=DEFAULTMAXACCEPTS, 99 | perc_identity=DEFAULTPERCENTID, 100 | query_cov=DEFAULTQUERYCOV, 101 | strand=DEFAULTSTRAND, 102 | search_exact=DEFAULTSEARCHEXACT, 103 | top_hits_only=DEFAULTTOPHITS, 104 | maxhits=DEFAULTMAXHITS, 105 | maxrejects=DEFAULTMAXREJECTS, 106 | output_no_hits=DEFAULTOUTPUTNOHITS, 107 | weak_id=DEFAULTWEAKID, 108 | threads=DEFAULTTHREADS, 109 | min_consensus=DEFAULTMINCONSENSUS, 110 | unassignable_label=DEFAULTUNASSIGNABLELABEL): 111 | search_db = ctx.get_action('feature_classifier', 'vsearch_global') 112 | lca = ctx.get_action('feature_classifier', 'find_consensus_annotation') 113 | result, = search_db(query=query, reference_reads=reference_reads, 114 | maxaccepts=maxaccepts, 
perc_identity=perc_identity, 115 | query_cov=query_cov, strand=strand, 116 | search_exact=search_exact, top_hits_only=top_hits_only, 117 | maxhits=maxhits, maxrejects=maxrejects, 118 | output_no_hits=output_no_hits, weak_id=weak_id, 119 | threads=threads) 120 | consensus, = lca(search_results=result, 121 | reference_taxonomy=reference_taxonomy, 122 | min_consensus=min_consensus, 123 | unassignable_label=unassignable_label) 124 | # New: add BLAST6Format result as an output. This could just as well be a 125 | # visualizer generated from these results (using q2-metadata tabulate). 126 | # Would that be more useful to the user than the QZA? 127 | return consensus, result 128 | 129 | 130 | def _annotate_method(taxa, method): 131 | taxa = taxa.view(pd.DataFrame) 132 | taxa['Method'] = method 133 | return qiime2.Artifact.import_data('FeatureData[Taxonomy]', taxa) 134 | 135 | 136 | def classify_hybrid_vsearch_sklearn(ctx, 137 | query, 138 | reference_reads, 139 | reference_taxonomy, 140 | classifier, 141 | maxaccepts=DEFAULTMAXACCEPTS, 142 | perc_identity=0.5, 143 | query_cov=DEFAULTQUERYCOV, 144 | strand=DEFAULTSTRAND, 145 | min_consensus=DEFAULTMINCONSENSUS, 146 | maxhits=DEFAULTMAXHITS, 147 | maxrejects=DEFAULTMAXREJECTS, 148 | reads_per_batch='auto', 149 | confidence=0.7, 150 | read_orientation='auto', 151 | threads=DEFAULTTHREADS, 152 | prefilter=True, 153 | sample_size=1000, 154 | randseed=0): 155 | exclude = ctx.get_action('quality_control', 'exclude_seqs') 156 | ccv = ctx.get_action('feature_classifier', 'classify_consensus_vsearch') 157 | cs = ctx.get_action('feature_classifier', 'classify_sklearn') 158 | filter_seqs = ctx.get_action('taxa', 'filter_seqs') 159 | merge = ctx.get_action('feature_table', 'merge_taxa') 160 | 161 | # randomly subsample reference sequences for rough positive filter 162 | if prefilter: 163 | ref = str(reference_reads.view(DNAFASTAFormat)) 164 | with tempfile.NamedTemporaryFile() as output: 165 | cmd = ['vsearch', '--fastx_subsample', ref, '--sample_size', 166 | str(sample_size), '--randseed', str(randseed), 167 | '--fastaout', output.name] 168 | _run_command(cmd) 169 | sparse_reference = qiime2.Artifact.import_data( 170 | 'FeatureData[Sequence]', output.name) 171 | 172 | # perform rough positive filter on query sequences 173 | query, misses, = exclude( 174 | query_sequences=query, reference_sequences=sparse_reference, 175 | method='vsearch', perc_identity=perc_identity, 176 | perc_query_aligned=query_cov, threads=threads) 177 | 178 | # find exact matches, perform LCA consensus classification 179 | # note: we only keep the taxonomic assignments, not the search report 180 | taxa1, _, = ccv( 181 | query=query, reference_reads=reference_reads, 182 | reference_taxonomy=reference_taxonomy, maxaccepts=maxaccepts, 183 | strand=strand, min_consensus=min_consensus, search_exact=True, 184 | threads=threads, maxhits=maxhits, maxrejects=maxrejects, 185 | output_no_hits=True) 186 | 187 | # Annotate taxonomic assignments with classification method 188 | taxa1 = _annotate_method(taxa1, 'VSEARCH') 189 | 190 | # perform second-pass classification on unassigned taxa 191 | # filter out unassigned seqs 192 | try: 193 | query, = filter_seqs(sequences=query, taxonomy=taxa1, 194 | include=DEFAULTUNASSIGNABLELABEL) 195 | except ValueError: 196 | # get ValueError if all sequences are filtered out.
197 | # so if no sequences are unassigned, return exact match results 198 | return taxa1 199 | 200 | # classify with sklearn classifier 201 | taxa2, = cs(reads=query, classifier=classifier, 202 | reads_per_batch=reads_per_batch, n_jobs=threads, 203 | confidence=confidence, read_orientation=read_orientation) 204 | 205 | # Annotate taxonomic assignments with classification method 206 | taxa2 = _annotate_method(taxa2, 'sklearn') 207 | 208 | # merge into one big happy result 209 | taxa, = merge(data=[taxa2, taxa1]) 210 | return taxa 211 | 212 | 213 | parameters = {'maxaccepts': Int % Range(1, None) | Str % Choices(['all']), 214 | 'perc_identity': Float % Range(0.0, 1.0, inclusive_end=True), 215 | 'query_cov': Float % Range(0.0, 1.0, inclusive_end=True), 216 | 'strand': Str % Choices(['both', 'plus']), 217 | 'threads': Threads, 218 | 'maxhits': Int % Range(1, None) | Str % Choices(['all']), 219 | 'maxrejects': Int % Range(1, None) | Str % Choices(['all'])} 220 | 221 | extra_params = {'search_exact': Bool, 222 | 'top_hits_only': Bool, 223 | 'output_no_hits': Bool, 224 | 'weak_id': Float % Range(0.0, 1.0, inclusive_end=True)} 225 | 226 | inputs = {'query': FeatureData[Sequence], 227 | 'reference_reads': FeatureData[Sequence]} 228 | 229 | input_descriptions = {'query': 'Query sequences.', 230 | 'reference_reads': 'Reference sequences.'} 231 | 232 | parameter_descriptions = { 233 | 'strand': 'Align against reference sequences in forward ("plus") ' 234 | 'or both directions ("both").', 235 | 'maxaccepts': 'Maximum number of hits to keep for each query. Set to ' 236 | '"all" to keep all hits > perc_identity similarity. Note ' 237 | 'that if strand=both, maxaccepts will keep N hits for each ' 238 | 'direction (if searches in the opposite direction yield ' 239 | 'results that exceed the minimum perc_identity). In those ' 240 | 'cases use maxhits to control the total number of hits ' 241 | 'returned. This option works in tandem with maxrejects. ' 242 | 'The search process sorts target sequences by decreasing ' 243 | 'number of k-mers they have in common with the query ' 244 | 'sequence, using that information as a proxy for sequence ' 245 | 'similarity. After pairwise alignments, if the first target ' 246 | 'sequence passes the acceptance criteria, it is accepted ' 247 | 'as the best hit and the search stops for that query. ' 248 | 'If maxaccepts is set to a higher value, more hits are ' 249 | 'accepted. If maxaccepts and maxrejects are both set to ' 250 | '"all", the complete database is searched.', 251 | 'perc_identity': 'Reject match if percent identity to query is ' 252 | 'lower.', 253 | 'query_cov': 'Reject match if query alignment coverage per high-' 254 | 'scoring pair is lower.', 255 | 'threads': 'Number of threads to use for job parallelization. Pass 0 to ' 256 | 'use one per available CPU.', 257 | 'maxhits': 'Maximum number of hits to show once the search is terminated.', 258 | 'maxrejects': 'Maximum number of non-matching target sequences to ' 259 | 'consider before stopping the search. This option works in ' 260 | 'tandem with maxaccepts (see the maxaccepts description for ' 261 | 'details).'} 262 | 263 | extra_param_descriptions = { 264 | 'search_exact': 'Search for exact full-length matches to the query ' 265 | 'sequences. Only 100% exact matches are reported and this ' 266 | 'command is much faster than the default. If True, the ' 267 | 'perc_identity, query_cov, maxaccepts, and maxrejects ' 268 | 'settings are ignored. 
Note: query and reference reads ' 269 | 'must be trimmed to the exact same DNA locus (e.g., ' 270 | 'primer site) because only exact matches will be ' 271 | 'reported.', 272 | 'top_hits_only': 'Only the top hits between the query and reference ' 273 | 'sequence sets are reported. For each query, the top ' 274 | 'hit is the one with the highest percentage of ' 275 | 'identity. Multiple equally scored top hits will be ' 276 | 'used for consensus taxonomic assignment if ' 277 | 'maxaccepts is greater than 1.', 278 | 'output_no_hits': 'Report both matching and non-matching queries. ' 279 | 'WARNING: always use the default setting for this ' 280 | 'option unless you know what you are doing! If ' 281 | 'you set this option to False, your sequences and ' 282 | 'feature table will need to be filtered to exclude ' 283 | 'unclassified sequences, otherwise you may run into ' 284 | 'errors downstream from missing feature IDs.', 285 | 'weak_id': 'Show hits with percentage of identity of at least N, ' 286 | 'without terminating the search. A normal search stops as ' 287 | 'soon as enough hits are found (as defined by maxaccepts, ' 288 | 'maxrejects, and perc_identity). As weak_id reports weak ' 289 | 'hits that are not deducted from maxaccepts, high ' 290 | 'perc_identity values can be used, hence preserving both ' 291 | 'speed and sensitivity. Logically, weak_id must be smaller ' 292 | 'than the value indicated by perc_identity, otherwise this ' 293 | 'option will be ignored.', 294 | } 295 | 296 | classification_output = ('classification', FeatureData[Taxonomy]) 297 | 298 | classification_output_description = { 299 | 'classification': 'Taxonomy classifications of query sequences.'} 300 | 301 | blast6_output = ('search_results', FeatureData[BLAST6]) 302 | 303 | blast6_output_description = {'search_results': 'Top hits for each query.'} 304 | 305 | ignore_prefilter = ' This parameter is ignored if `prefilter` is disabled.' 306 | 307 | 308 | plugin.methods.register_function( 309 | function=vsearch_global, 310 | inputs=inputs, 311 | parameters={**parameters, 312 | **extra_params}, 313 | outputs=[blast6_output], 314 | input_descriptions=input_descriptions, 315 | parameter_descriptions={ 316 | **parameter_descriptions, 317 | **extra_param_descriptions, 318 | }, 319 | output_descriptions=blast6_output_description, 320 | name='VSEARCH global alignment search', 321 | description=('Search for top hits in a reference database via global ' 322 | 'alignment between the query sequences and reference ' 323 | 'database sequences using VSEARCH. Returns a report of the ' 324 | 'top M hits for each query (where M=maxaccepts or maxhits).'), 325 | citations=[citations['rognes2016vsearch']] 326 | ) 327 | 328 | 329 | plugin.pipelines.register_function( 330 | function=classify_consensus_vsearch, 331 | inputs={**inputs, 332 | 'reference_taxonomy': FeatureData[Taxonomy]}, 333 | parameters={**parameters, 334 | **extra_params, 335 | **min_consensus_param, 336 | 'unassignable_label': Str, 337 | }, 338 | outputs=[classification_output, blast6_output], 339 | input_descriptions={**input_descriptions, 340 | 'reference_taxonomy': 'Reference taxonomy labels.'}, 341 | parameter_descriptions={ 342 | **parameter_descriptions, 343 | **extra_param_descriptions, 344 | **min_consensus_param_description, 345 | 'unassignable_label': 'Annotation given to sequences without any hits.'
346 | }, 347 | output_descriptions={**classification_output_description, 348 | **blast6_output_description}, 349 | name='VSEARCH-based consensus taxonomy classifier', 350 | description=('Assign taxonomy to query sequences using VSEARCH. Performs ' 351 | 'VSEARCH global alignment between query and reference_reads, ' 352 | 'then assigns consensus taxonomy to each query sequence from ' 353 | 'among maxaccepts top hits, min_consensus of which share ' 354 | 'that taxonomic assignment. Unlike classify-consensus-blast, ' 355 | 'this method searches the entire reference database before ' 356 | 'choosing the top N hits, not the first N hits.'), 357 | citations=[citations['rognes2016vsearch']] 358 | ) 359 | 360 | 361 | plugin.pipelines.register_function( 362 | function=classify_hybrid_vsearch_sklearn, 363 | inputs={**inputs, 364 | 'reference_taxonomy': FeatureData[Taxonomy], 365 | 'classifier': TaxonomicClassifier}, 366 | parameters={**parameters, 367 | **min_consensus_param, 368 | 'reads_per_batch': _classify_parameters['reads_per_batch'], 369 | 'confidence': _classify_parameters['confidence'], 370 | 'read_orientation': _classify_parameters['read_orientation'], 371 | 'prefilter': Bool, 372 | 'sample_size': Int % Range(1, None), 373 | 'randseed': Int % Range(0, None)}, 374 | outputs=[classification_output], 375 | input_descriptions={**input_descriptions, 376 | 'reference_taxonomy': 'Reference taxonomy labels.', 377 | 'classifier': 'Pre-trained sklearn taxonomic ' 378 | 'classifier for classifying the reads.'}, 379 | parameter_descriptions={ 380 | **{k: parameter_descriptions[k] for k in [ 381 | 'strand', 'maxaccepts', 'threads']}, 382 | **min_consensus_param_description, 383 | 'perc_identity': 'Percent sequence similarity to use for PREFILTER. ' + 384 | parameter_descriptions['perc_identity'] + ' Set to a ' 385 | 'lower value to perform a rough pre-filter.' + 386 | ignore_prefilter, 387 | 'query_cov': 'Query coverage threshold to use for PREFILTER. ' + 388 | parameter_descriptions['query_cov'] + ' Set to a ' 389 | 'lower value to perform a rough pre-filter.' + 390 | ignore_prefilter, 391 | 'confidence': _parameter_descriptions['confidence'], 392 | 'read_orientation': 'Direction of reads with respect to reference ' 393 | 'sequences in the pre-trained sklearn classifier. ' 394 | '"same" will cause reads to be classified ' 395 | 'unchanged; "reverse-complement" will cause reads ' 396 | 'to be reversed and complemented prior to ' 397 | 'classification. "both" will classify sequences ' 398 | 'unchanged and in ' 399 | 'reverse-complement and retain the ' 400 | 'classification with higher confidence. ' 401 | '"auto" will autodetect ' 402 | 'orientation based on the confidence estimates ' 403 | 'for the first 100 reads.', 404 | 'reads_per_batch': 'Number of reads to process in each batch for ' 405 | 'sklearn classification. If "auto", this parameter ' 406 | 'is autoscaled to min(number of query sequences / ' 407 | 'threads, 20000).', 408 | 'prefilter': 'Toggle positive filter of query sequences on or off.', 409 | 'sample_size': 'Randomly extract the given number of sequences from ' 410 | 'the reference database to use for prefiltering.' + 411 | ignore_prefilter, 412 | 'randseed': 'Use the given integer as a seed for the pseudo-random ' 413 | 'generator used during prefiltering. A given seed always ' 414 | 'produces the same output, which is useful for ' 415 | 'replicability. Set to 0 to use a pseudo-random seed.' 
+ ignore_prefilter, 416 | }, 417 | output_descriptions=classification_output_description, 418 | name='ALPHA Hybrid classifier: VSEARCH exact match + sklearn classifier', 419 | description=('NOTE: THIS PIPELINE IS AN ALPHA RELEASE. Please report bugs ' 420 | 'to https://forum.qiime2.org!\n' 421 | 'Assign taxonomy to query sequences using a hybrid classifier. ' 422 | 'First performs a rough positive filter to remove artifacts ' 423 | 'and low-coverage sequences (use the "prefilter" parameter to ' 424 | 'toggle this step on or off). Second, performs VSEARCH exact ' 425 | 'match between query and reference_reads to find exact matches, ' 426 | 'followed by least common ancestor consensus taxonomy ' 427 | 'assignment from among maxaccepts top hits, min_consensus of ' 428 | 'which share that taxonomic assignment. Query sequences ' 429 | 'without an exact match are then classified with a pre-' 430 | 'trained sklearn taxonomy classifier to predict the most ' 431 | 'likely taxonomic lineage.'), 432 | ) 433 | -------------------------------------------------------------------------------- /q2_feature_classifier/tests/test_consensus_assignment.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright (c) 2016-2025, QIIME 2 development team. 3 | # 4 | # Distributed under the terms of the Modified BSD License. 5 | # 6 | # The full license is in the file LICENSE, distributed with this software. 7 | # ---------------------------------------------------------------------------- 8 | 9 | import pandas as pd 10 | import pandas.testing as pdt 11 | 12 | from qiime2.sdk import Artifact 13 | from q2_feature_classifier._skl import _specific_fitters 14 | from q2_feature_classifier._consensus_assignment import ( 15 | _lca_consensus, 16 | _compute_consensus_annotations, 17 | _blast6format_df_to_series_of_lists, 18 | _taxa_to_cumulative_ranks) 19 | from q2_types.feature_data import DNAFASTAFormat 20 | from . import FeatureClassifierTestPluginBase 21 | from qiime2.plugins import feature_classifier as qfc 22 | 23 | 24 | class SequenceSearchTests(FeatureClassifierTestPluginBase): 25 | 26 | def setUp(self): 27 | super().setUp() 28 | self.query = Artifact.import_data( 29 | 'FeatureData[Sequence]', self.get_data_path('query-seqs.fasta')) 30 | self.ref = Artifact.import_data( 31 | 'FeatureData[Sequence]', 32 | self.get_data_path('se-dna-sequences.fasta')) 33 | 34 | # The blastdb format is not documented in enough detail to validate, 35 | # so for now we just run it together with blastn to validate.
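# For reference, the twelve BLAST6 (outfmt 6) columns asserted in the
# expected frames below are: qseqid, sseqid, pident, length, mismatch,
# gapopen, qstart, qend, sstart, send, evalue, bitscore.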
36 | def test_makeblastdb_and_blast(self): 37 | db, = qfc.actions.makeblastdb(self.ref) 38 | print(db) 39 | result1, = qfc.actions.blast(self.query, blastdb=db) 40 | result2, = qfc.actions.blast(self.query, self.ref) 41 | pdt.assert_frame_equal(result1.view(pd.DataFrame), 42 | result2.view(pd.DataFrame)) 43 | with self.assertRaisesRegex(ValueError, "Only one.*can be provided"): 44 | qfc.actions.blast(self.query, reference_reads=self.ref, blastdb=db) 45 | with self.assertRaisesRegex(ValueError, "Either.*must be provided"): 46 | qfc.actions.blast(self.query) 47 | 48 | def test_blast(self): 49 | result, = qfc.actions.blast( 50 | self.query, self.ref, maxaccepts=3, perc_identity=0.9) 51 | exp = pd.DataFrame({ 52 | 'qseqid': {0: '1111561', 1: '1111561', 2: '1111561', 3: '835097', 53 | 4: 'junk'}, 54 | 'sseqid': {0: '1111561', 1: '574274', 2: '149351', 3: '835097', 55 | 4: '*'}, 56 | 'pident': {0: 100.0, 1: 92.308, 2: 91.781, 3: 100.0, 4: 0.0}, 57 | 'length': {0: 75.0, 1: 78.0, 2: 73.0, 3: 80.0, 4: 0.0}, 58 | 'mismatch': {0: 0.0, 1: 2.0, 2: 4.0, 3: 0.0, 4: 0.0}, 59 | 'gapopen': {0: 0.0, 1: 4.0, 2: 2.0, 3: 0.0, 4: 0.0}, 60 | 'qstart': {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 0.0}, 61 | 'qend': {0: 75.0, 1: 75.0, 2: 71.0, 3: 80.0, 4: 0.0}, 62 | 'sstart': {0: 24.0, 1: 24.0, 2: 24.0, 3: 32.0, 4: 0.0}, 63 | 'send': {0: 98.0, 1: 100.0, 2: 96.0, 3: 111.0, 4: 0.0}, 64 | 'evalue': {0: 8.35e-36, 1: 2.36e-26, 2: 3.94e-24, 65 | 3: 1.5000000000000002e-38, 4: 0.0}, 66 | 'bitscore': {0: 139.0, 1: 108.0, 2: 100.0, 3: 148.0, 4: 0.0}}) 67 | pdt.assert_frame_equal(result.view(pd.DataFrame), exp) 68 | 69 | def test_blast_no_output_no_hits(self): 70 | result, = qfc.actions.blast( 71 | self.query, self.ref, maxaccepts=3, perc_identity=0.9, 72 | output_no_hits=False) 73 | exp = pd.DataFrame({ 74 | 'qseqid': {0: '1111561', 1: '1111561', 2: '1111561', 3: '835097'}, 75 | 'sseqid': {0: '1111561', 1: '574274', 2: '149351', 3: '835097'}, 76 | 'pident': {0: 100.0, 1: 92.308, 2: 91.781, 3: 100.0}, 77 | 'length': {0: 75.0, 1: 78.0, 2: 73.0, 3: 80.0}, 78 | 'mismatch': {0: 0.0, 1: 2.0, 2: 4.0, 3: 0.0}, 79 | 'gapopen': {0: 0.0, 1: 4.0, 2: 2.0, 3: 0.0}, 80 | 'qstart': {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0}, 81 | 'qend': {0: 75.0, 1: 75.0, 2: 71.0, 3: 80.0}, 82 | 'sstart': {0: 24.0, 1: 24.0, 2: 24.0, 3: 32.0}, 83 | 'send': {0: 98.0, 1: 100.0, 2: 96.0, 3: 111.0}, 84 | 'evalue': {0: 8.35e-36, 1: 2.36e-26, 2: 3.94e-24, 85 | 3: 1.5000000000000002e-38}, 86 | 'bitscore': {0: 139.0, 1: 108.0, 2: 100.0, 3: 148.0}}) 87 | pdt.assert_frame_equal(result.view(pd.DataFrame), exp) 88 | 89 | def test_vsearch_global(self): 90 | result, = qfc.actions.vsearch_global( 91 | self.query, self.ref, maxaccepts=3, perc_identity=0.9) 92 | exp = pd.DataFrame({ 93 | 'qseqid': {0: '1111561', 1: '835097', 2: 'junk'}, 94 | 'sseqid': {0: '1111561', 1: '835097', 2: '*'}, 95 | 'pident': {0: 100.0, 1: 100.0, 2: 0.0}, 96 | 'length': {0: 75.0, 1: 80.0, 2: 0.0}, 97 | 'mismatch': {0: 0.0, 1: 0.0, 2: 0.0}, 98 | 'gapopen': {0: 0.0, 1: 0.0, 2: 0.0}, 99 | 'qstart': {0: 1.0, 1: 1.0, 2: 0.0}, 100 | 'qend': {0: 75.0, 1: 80.0, 2: 0.0}, 101 | 'sstart': {0: 1.0, 1: 1.0, 2: 0.0}, 102 | 'send': {0: 150.0, 1: 150.0, 2: 0.0}, 103 | 'evalue': {0: -1.0, 1: -1.0, 2: -1.0}, 104 | 'bitscore': {0: 0.0, 1: 0.0, 2: 0.0}}) 105 | pdt.assert_frame_equal( 106 | result.view(pd.DataFrame), exp, check_names=False) 107 | 108 | def test_vsearch_global_no_output_no_hits(self): 109 | result, = qfc.actions.vsearch_global( 110 | self.query, self.ref, maxaccepts=3, perc_identity=0.9, 111 | output_no_hits=False) 112 | 
exp = pd.DataFrame({ 113 | 'qseqid': {0: '1111561', 1: '835097'}, 114 | 'sseqid': {0: '1111561', 1: '835097'}, 115 | 'pident': {0: 100.0, 1: 100.0}, 116 | 'length': {0: 75.0, 1: 80.0}, 117 | 'mismatch': {0: 0.0, 1: 0.0}, 118 | 'gapopen': {0: 0.0, 1: 0.0}, 119 | 'qstart': {0: 1.0, 1: 1.0}, 120 | 'qend': {0: 75.0, 1: 80.0}, 121 | 'sstart': {0: 1.0, 1: 1.0}, 122 | 'send': {0: 150.0, 1: 150.0}, 123 | 'evalue': {0: -1.0, 1: -1.0}, 124 | 'bitscore': {0: 0.0, 1: 0.0}}) 125 | pdt.assert_frame_equal( 126 | result.view(pd.DataFrame), exp, check_names=False) 127 | 128 | def test_vsearch_global_permissive(self): 129 | result, = qfc.actions.vsearch_global( 130 | self.query, self.ref, maxaccepts=1, perc_identity=0.8, 131 | query_cov=0.2) 132 | exp = pd.DataFrame({ 133 | 'qseqid': {0: '1111561', 1: '835097', 2: 'junk'}, 134 | 'sseqid': {0: '1111561', 1: '835097', 2: '4314518'}, 135 | 'pident': {0: 100.0, 1: 100.0, 2: 90.0}, 136 | 'length': {0: 75.0, 1: 80.0, 2: 20.0}, 137 | 'mismatch': {0: 0.0, 1: 0.0, 2: 2.0}, 138 | 'gapopen': {0: 0.0, 1: 0.0, 2: 0.0}, 139 | 'qstart': {0: 1.0, 1: 1.0, 2: 1.0}, 140 | 'qend': {0: 75.0, 1: 80.0, 2: 100.0}, 141 | 'sstart': {0: 1.0, 1: 1.0, 2: 1.0}, 142 | 'send': {0: 150.0, 1: 150.0, 2: 95.0}, 143 | 'evalue': {0: -1.0, 1: -1.0, 2: -1.0}, 144 | 'bitscore': {0: 0.0, 1: 0.0, 2: 0.0}}) 145 | pdt.assert_frame_equal( 146 | result.view(pd.DataFrame), exp, check_names=False) 147 | 148 | 149 | # utility function for comparing taxonomy series in the tests below 150 | def series_is_subset(expected, observed): 151 | # join observed and expected results to compare 152 | joined = pd.concat([expected, observed], axis=1, join='inner') 153 | # check that each observed result is a prefix of the expected result 154 | # (this should usually be the case, unless consensus classification 155 | # did very badly, e.g., resulting in unclassified) 156 | compared = joined.apply(lambda x: x[0].startswith(x[1]), axis=1) 157 | # in the original tests we set a threshold of 50% for subsets... most 158 | # should be, but in some cases misclassification could occur, or dodgy 159 | # annotations that screw up the LCA. So just check that we have at least 160 | # as many True as False. 161 | return len(compared[compared]) >= len(compared[~compared]) 162 | 163 | 164 | class ConsensusAssignmentsTests(FeatureClassifierTestPluginBase): 165 | 166 | def setUp(self): 167 | super().setUp() 168 | self.taxonomy = Artifact.import_data( 169 | 'FeatureData[Taxonomy]', self.get_data_path('taxonomy.tsv')) 170 | self.reads = Artifact.import_data( 171 | 'FeatureData[Sequence]', 172 | self.get_data_path('se-dna-sequences.fasta')) 173 | self.exp = self.taxonomy.view(pd.Series) 174 | 175 | # Make sure blast and vsearch produce expected outputs 176 | # but there is no "right" taxonomy assignment. 177 | # TODO: the results should be deterministic, so we should check expected 178 | # search and/or taxonomy classification outputs.
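# Example of the series_is_subset contract, with hypothetical data: for
# expected pd.Series({'f1': 'a;b;c'}) and observed pd.Series({'f1': 'a;b'}),
# the observed 'a;b' is a prefix of the expected 'a;b;c' and counts as a
# match; the helper passes when matches are at least as numerous as misses.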
179 | def test_classify_consensus_blast(self): 180 | result, _, = qfc.actions.classify_consensus_blast( 181 | query=self.reads, reference_reads=self.reads, 182 | reference_taxonomy=self.taxonomy) 183 | self.assertTrue(series_is_subset(self.exp, result.view(pd.Series))) 184 | 185 | def test_classify_consensus_vsearch(self): 186 | result, _, = qfc.actions.classify_consensus_vsearch( 187 | self.reads, self.reads, self.taxonomy) 188 | self.assertTrue(series_is_subset(self.exp, result.view(pd.Series))) 189 | 190 | # search_exact with all other exposed params to confirm compatibility 191 | # in future releases of vsearch 192 | def test_classify_consensus_vsearch_search_exact(self): 193 | result, _, = qfc.actions.classify_consensus_vsearch( 194 | self.reads, self.reads, self.taxonomy, search_exact=True, 195 | top_hits_only=True, output_no_hits=True, weak_id=0.9, maxhits=10) 196 | self.assertTrue(series_is_subset(self.exp, result.view(pd.Series))) 197 | 198 | def test_classify_consensus_vsearch_top_hits_only(self): 199 | result, _, = qfc.actions.classify_consensus_vsearch( 200 | self.reads, self.reads, self.taxonomy, top_hits_only=True) 201 | self.assertTrue(series_is_subset(self.exp, result.view(pd.Series))) 202 | 203 | # make sure weak_id and other parameters do not conflict with each other. 204 | # This test just makes sure the command runs okay with all options. 205 | # We are not in the business of debugging VSEARCH, but want to have this 206 | # test as a canary in the coal mine. 207 | def test_classify_consensus_vsearch_the_works(self): 208 | result, _, = qfc.actions.classify_consensus_vsearch( 209 | self.reads, self.reads, self.taxonomy, top_hits_only=True, 210 | maxhits=1, maxrejects=10, weak_id=0.8, perc_identity=0.99, 211 | output_no_hits=False) 212 | self.assertTrue(series_is_subset(self.exp, result.view(pd.Series))) 213 | 214 | 215 | class HybridClassifierTests(FeatureClassifierTestPluginBase): 216 | 217 | def setUp(self): 218 | super().setUp() 219 | taxonomy = Artifact.import_data( 220 | 'FeatureData[Taxonomy]', self.get_data_path('taxonomy.tsv')) 221 | self.taxonomy = taxonomy.view(pd.Series) 222 | self.taxartifact = taxonomy 223 | # TODO: use `Artifact.import_data` here once we have a transformer 224 | # for DNASequencesDirectoryFormat -> DNAFASTAFormat 225 | reads_fp = self.get_data_path('se-dna-sequences.fasta') 226 | reads = DNAFASTAFormat(reads_fp, mode='r') 227 | self.reads = Artifact.import_data('FeatureData[Sequence]', reads) 228 | 229 | fitter = getattr(qfc.methods, 230 | 'fit_classifier_' + _specific_fitters[0][0]) 231 | self.classifier = fitter(self.reads, self.taxartifact).classifier 232 | 233 | self.query = Artifact.import_data('FeatureData[Sequence]', pd.Series( 234 | {'A': 'GCCTAACACATGCAAGTCGAACGGCAGCGGGGGAAAGCTTGCTTTCCTGCCGGCGA', 235 | 'B': 'TAACACATGCAAGTCAACGATGCTTATGTAGCAATATGTAAGTAGAGTGGCGCACG', 236 | 'C': 'ATACATGCAAGTCGTACGGTATTCCGGTTTCGGCCGGGAGAGAGTGGCGGATGGGT', 237 | 'D': 'GACGAACGCTGGCGACGTGCTTAACACATGCAAGTCGTGCGAGGACGGGCGGTGCT' 238 | 'TGCACTGCTCGAGCCGAGCGGCGGACGGGTGAGTAACACGTGAGCAACCTATCTCC' 239 | 'GTGCGGGGGACAACCCGGGGAAACCCGGGCTAATACCG'})) 240 | 241 | def test_classify_hybrid_vsearch_sklearn_all_exact_match(self): 242 | 243 | result, = qfc.actions.classify_hybrid_vsearch_sklearn( 244 | query=self.reads, reference_reads=self.reads, 245 | reference_taxonomy=self.taxartifact, classifier=self.classifier, 246 | prefilter=False) 247 | result, = qfc.actions.classify_hybrid_vsearch_sklearn( 248 | query=self.reads, reference_reads=self.reads, 249 | 
reference_taxonomy=self.taxartifact, classifier=self.classifier) 250 | result = result.view(pd.DataFrame) 251 | res = result.Taxon.to_dict() 252 | tax = self.taxonomy.to_dict() 253 | right = 0. 254 | for taxon in res: 255 | right += tax[taxon].startswith(res[taxon]) 256 | self.assertGreater(right/len(res), 0.5) 257 | 258 | def test_classify_hybrid_vsearch_sklearn_mixed_query(self): 259 | 260 | result, = qfc.actions.classify_hybrid_vsearch_sklearn( 261 | query=self.query, reference_reads=self.reads, 262 | reference_taxonomy=self.taxartifact, classifier=self.classifier, 263 | prefilter=True, read_orientation='same', randseed=1001) 264 | result = result.view(pd.DataFrame) 265 | obs = result.Taxon.to_dict() 266 | exp = {'A': 'k__Bacteria; p__Proteobacteria; c__Gammaproteobacteria; ' 267 | 'o__Legionellales; f__; g__; s__', 268 | 'B': 'k__Bacteria; p__Chlorobi; c__; o__; f__; g__; s__', 269 | 'C': 'k__Bacteria; p__Bacteroidetes; c__Cytophagia; ' 270 | 'o__Cytophagales; f__Cyclobacteriaceae; g__; s__', 271 | 'D': 'k__Bacteria; p__Gemmatimonadetes; c__Gemm-5; o__; f__; ' 272 | 'g__; s__'} 273 | self.assertDictEqual(obs, exp) 274 | 275 | 276 | class ImportBlastAssignmentTests(FeatureClassifierTestPluginBase): 277 | 278 | def setUp(self): 279 | super().setUp() 280 | result = Artifact.import_data( 281 | 'FeatureData[BLAST6]', self.get_data_path('blast6-format.tsv')) 282 | self.result = result.view(pd.DataFrame) 283 | taxonomy = Artifact.import_data( 284 | 'FeatureData[Taxonomy]', self.get_data_path('taxonomy.tsv')) 285 | self.taxonomy = taxonomy.view(pd.Series) 286 | 287 | def test_blast6format_df_to_series_of_lists(self): 288 | # add in a query without any hits, to check that it is parsed 289 | self.result.loc[3] = ['junk', '*'] + [''] * 10 290 | obs = _blast6format_df_to_series_of_lists(self.result, self.taxonomy) 291 | exp = pd.Series( 292 | {'1111561': [ 293 | 'k__Bacteria; p__Proteobacteria; c__Gammaproteobacteria; ' 294 | 'o__Legionellales; f__; g__; s__', 295 | 'k__Bacteria; p__Proteobacteria; c__Gammaproteobacteria; ' 296 | 'o__Legionellales; f__Coxiellaceae; g__; s__'], 297 | '835097': [ 298 | 'k__Bacteria; p__Chloroflexi; c__SAR202; o__; f__; g__; s__'], 299 | 'junk': ['Unassigned']}, 300 | name='sseqid') 301 | exp.index.name = 'qseqid' 302 | pdt.assert_series_equal(exp, obs) 303 | 304 | # should fail when hit IDs are missing from the reference taxonomy 305 | # in this case the added 'lost-id' hit is missing 306 | def test_blast6format_df_to_series_of_lists_fail_on_missing_ids(self): 307 | # add a hit with a bad ID 308 | self.result.loc[3] = ['junk', 'lost-id'] + [''] * 10 309 | with self.assertRaisesRegex(KeyError, "results do not match.*lost-id"): 310 | _blast6format_df_to_series_of_lists(self.result, self.taxonomy) 311 | 312 | 313 | class ConsensusAnnotationTests(FeatureClassifierTestPluginBase): 314 | 315 | def test_taxa_to_cumulative_ranks(self): 316 | taxa = ['a;b;c', 'a;b;d', 'a;g;g'] 317 | exp = [['a', 'a;b', 'a;b;c'], ['a', 'a;b', 'a;b;d'], 318 | ['a', 'a;g', 'a;g;g']] 319 | self.assertEqual(_taxa_to_cumulative_ranks(taxa), exp) 320 | 321 | def test_taxa_to_cumulative_ranks_with_uneven_ranks(self): 322 | taxa = ['a;b;c', 'a;b;d', 'a;g;g;somemoregarbage'] 323 | exp = [['a', 'a;b', 'a;b;c'], ['a', 'a;b', 'a;b;d'], 324 | ['a', 'a;g', 'a;g;g', 'a;g;g;somemoregarbage']] 325 | self.assertEqual(_taxa_to_cumulative_ranks(taxa), exp) 326 | 327 | def test_taxa_to_cumulative_ranks_with_one_entry(self): 328 | taxa = ['a;b;c'] 329 | exp = [['a', 'a;b', 'a;b;c']] 330 | self.assertEqual(_taxa_to_cumulative_ranks(taxa), 
exp) 331 | 332 | def test_taxa_to_cumulative_ranks_with_empty_list(self): 333 | taxa = [''] 334 | exp = [['']] 335 | self.assertEqual(_taxa_to_cumulative_ranks(taxa), exp) 336 | 337 | def test_varied_min_fraction(self): 338 | in_ = [['Ab', 'Ab;Bc', 'Ab;Bc;De'], 339 | ['Ab', 'Ab;Bc', 'Ab;Bc;Fg', 'Ab;Bc;Fg;Hi'], 340 | ['Ab', 'Ab;Bc', 'Ab;Bc;Fg', 'Ab;Bc;Fg;Jk']] 341 | 342 | actual = _lca_consensus(in_, 0.51, "Unassigned") 343 | expected = ('Ab;Bc;Fg', 0.667) 344 | self.assertEqual(actual, expected) 345 | 346 | # increased min_consensus_fraction yields decreased specificity 347 | actual = _lca_consensus(in_, 0.99, "Unassigned") 348 | expected = ('Ab;Bc', 1.0) 349 | self.assertEqual(actual, expected) 350 | 351 | def test_single_annotation(self): 352 | in_ = [['Ab', 'Ab;Bc', 'Ab;Bc;De']] 353 | 354 | actual = _lca_consensus(in_, 1.0, "Unassigned") 355 | expected = ('Ab;Bc;De', 1.0) 356 | self.assertEqual(actual, expected) 357 | 358 | actual = _lca_consensus(in_, 0.501, "Unassigned") 359 | expected = ('Ab;Bc;De', 1.0) 360 | self.assertEqual(actual, expected) 361 | 362 | def test_no_consensus(self): 363 | in_ = [['Ab', 'Ab;Bc', 'Ab;Bc;De'], 364 | ['Cd', 'Cd;Bc', 'Cd;Bc;Fg', 'Cd;Bc;Fg;Hi'], 365 | ['Ef', 'Ef;Bc', 'Ef;Bc;Fg', 'Ef;Bc;Fg;Jk']] 366 | 367 | actual = _lca_consensus(in_, 0.51, "Unassigned") 368 | expected = ('Unassigned', 0.) 369 | self.assertEqual(actual, expected) 370 | 371 | actual = _lca_consensus( 372 | in_, 0.51, unassignable_label="Hello world!") 373 | expected = ('Hello world!', 0.) 374 | self.assertEqual(actual, expected) 375 | 376 | def test_overlapping_names(self): 377 | # here the third level is different, but the fourth level is the same 378 | # across the three assignments. this can happen in practice if 379 | # three different genera are assigned, and under each there is 380 | # an unnamed species 381 | # (e.g., f__x;g__A;s__, f__x;g__B;s__, f__x;g__C;s__) 382 | # in this case, the assignment should be f__x. 383 | in_ = [['Ab', 'Ab;Bc', 'Ab;Bc;De', 'Ab;Bc;De;Jk'], 384 | ['Ab', 'Ab;Bc', 'Ab;Bc;Fg', 'Ab;Bc;Fg;Jk'], 385 | ['Ab', 'Ab;Bc', 'Ab;Bc;Hi', 'Ab;Bc;Hi;Jk']] 386 | actual = _lca_consensus(in_, 0.51, "Unassigned") 387 | expected = ('Ab;Bc', 1.) 388 | self.assertEqual(actual, expected) 389 | 390 | # here the third level is the same in 4/5 of the 391 | # assignments, but one of them (z, y, c) refers to a 392 | # different taxon since the higher levels are different. 393 | # the consensus value should be 3/5, not 4/5, to 394 | # reflect that. 395 | in_ = [['a', 'a;b', 'a;b;c'], 396 | ['a', 'a;d', 'a;d;e'], 397 | ['a', 'a;b', 'a;b;c'], 398 | ['a', 'a;b', 'a;b;c'], 399 | ['z', 'z;y', 'z;y;c']] 400 | actual = _lca_consensus(in_, 0.51, "Unassigned") 401 | expected = ('a;b;c', 0.6) 402 | self.assertEqual(actual, expected) 403 | 404 | def test_adjusts_resolution(self): 405 | # max result depth is that of the shallowest assignment 406 | # Reading this test now, I am not entirely sure that this is how 407 | # such cases should be handled. Technically such a case should not 408 | # arise (as the dbs should have even ranks), so we can leave this for 409 | # now, and it is arguable, but in this case I think that majority 410 | # should rule. We use `zip` but might want to consider `zip_longest`. 
411 | in_ = [['Ab', 'Ab;Bc', 'Ab;Bc;Fg'], 412 | ['Ab', 'Ab;Bc', 'Ab;Bc;Fg', 'Ab;Bc;Fg;Hi'], 413 | ['Ab', 'Ab;Bc', 'Ab;Bc;Fg', 'Ab;Bc;Fg;Hi'], 414 | ['Ab', 'Ab;Bc', 'Ab;Bc;Fg', 'Ab;Bc;Fg;Hi'], 415 | ['Ab', 'Ab;Bc', 'Ab;Bc;Fg', 'Ab;Bc;Fg;Hi', 'Ab;Bc;Fg;Hi;Jk']] 416 | 417 | actual = _lca_consensus(in_, 0.51, "Unassigned") 418 | expected = ('Ab;Bc;Fg', 1.0) 419 | self.assertEqual(actual, expected) 420 | 421 | in_ = [['Ab', 'Ab;Bc', 'Ab;Bc;Fg'], 422 | ['Ab', 'Ab;Bc', 'Ab;Bc;Fg', 'Ab;Bc;Fg;Hi', 'Ab;Bc;Fg;Hi;Jk'], 423 | ['Ab', 'Ab;Bc', 'Ab;Bc;Fg', 'Ab;Bc;Fg;Hi', 'Ab;Bc;Fg;Hi;Jk'], 424 | ['Ab', 'Ab;Bc', 'Ab;Bc;Fg', 'Ab;Bc;Fg;Hi', 'Ab;Bc;Fg;Hi;Jk'], 425 | ['Ab', 'Ab;Bc', 'Ab;Bc;Fg', 'Ab;Bc;Fg;Hi', 'Ab;Bc;Fg;Hi;Jk']] 426 | 427 | actual = _lca_consensus(in_, 0.51, "Unassigned") 428 | expected = ('Ab;Bc;Fg', 1.0) 429 | self.assertEqual(actual, expected) 430 | 431 | 432 | # More edge cases are tested for the internals above, so the tests here are 433 | # made slim to just test the overarching functions. 434 | class ConsensusAnnotationsTests(FeatureClassifierTestPluginBase): 435 | 436 | def test_varied_fraction(self): 437 | 438 | in_ = pd.Series({'q1': ['A;B;C;D', 'A;B;C;E'], 439 | 'q2': ['A;H;I;J', 'A;H;K;L;M', 'A;H;I;J'], 440 | 'q3': ['A', 'A', 'B'], 441 | 'q4': ['A', 'B'], 442 | 'q5': []}) 443 | expected = pd.DataFrame({ 444 | 'Taxon': {'q1': 'A;B;C', 'q2': 'A;H;I;J', 'q3': 'A', 445 | 'q4': 'Unassigned', 'q5': 'Unassigned'}, 446 | 'Consensus': { 447 | 'q1': 1.0, 'q2': 0.667, 'q3': 0.667, 'q4': 0.0, 'q5': 0.0}}) 448 | actual = _compute_consensus_annotations(in_, 0.51, 'Unassigned') 449 | pdt.assert_frame_equal(actual, expected, check_names=False) 450 | 451 | expected = pd.DataFrame({ 452 | 'Taxon': {'q1': 'A;B;C', 'q2': 'A;H', 'q3': 'Unassigned', 453 | 'q4': 'Unassigned', 'q5': 'Unassigned'}, 454 | 'Consensus': { 455 | 'q1': 1.0, 'q2': 1.0, 'q3': 0.0, 'q4': 0.0, 'q5': 0.0}}) 456 | actual = _compute_consensus_annotations(in_, 0.99, 'Unassigned') 457 | pdt.assert_frame_equal(actual, expected, check_names=False) 458 | 459 | def test_find_consensus_annotation(self): 460 | 461 | result = Artifact.import_data( 462 | 'FeatureData[BLAST6]', self.get_data_path('blast6-format.tsv')) 463 | taxonomy = Artifact.import_data( 464 | 'FeatureData[Taxonomy]', self.get_data_path('taxonomy.tsv')) 465 | consensus, = qfc.actions.find_consensus_annotation(result, taxonomy) 466 | obs = consensus.view(pd.DataFrame) 467 | exp = pd.DataFrame( 468 | {'Taxon': { 469 | '1111561': 'k__Bacteria; p__Proteobacteria; ' 470 | 'c__Gammaproteobacteria; o__Legionellales', 471 | '835097': 'k__Bacteria; p__Chloroflexi; c__SAR202; o__; f__; ' 472 | 'g__; s__'}, 473 | 'Consensus': {'1111561': '1.0', '835097': '1.0'}}) 474 | pdt.assert_frame_equal(exp, obs, check_names=False) 475 | --------------------------------------------------------------------------------
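A minimal sketch (not part of the repository) of the JSON classifier-specification round trip that test_pipeline_serialisation exercises: pipeline_from_spec inflates a spec into an sklearn Pipeline, and spec_from_pipeline deflates it back. It assumes a QIIME 2 environment with q2-feature-classifier and scikit-learn installed; the spec below is illustrative, following the same shape as the package's default HashingVectorizer + MultinomialNB fitter.

import json

from q2_feature_classifier.classifier import (
    pipeline_from_spec, spec_from_pipeline)

# A spec is a JSON-serialisable list of [step_name, params] pairs; the
# '__type__' key names an estimator class, resolved under sklearn (or
# under q2_feature_classifier.custom for the 'custom' prefix).
spec = [['feat_ext',
         {'__type__': 'feature_extraction.text.HashingVectorizer',
          'analyzer': 'char_wb',
          'n_features': 8192,
          'ngram_range': [7, 7],   # inflation converts this list to a tuple
          'alternate_sign': False}],
        ['classify',
         {'__type__': 'naive_bayes.MultinomialNB'}]]

pipeline = pipeline_from_spec(spec)       # inflate: spec -> sklearn Pipeline
spec_one = spec_from_pipeline(pipeline)   # deflate: Pipeline -> spec
spec_two = spec_from_pipeline(pipeline_from_spec(spec_one))
assert spec_one == spec_two               # the round trip is stable, as tested
print(json.dumps(spec_one)[:72] + '...')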