├── .flake8 ├── .github └── workflows │ ├── lint.yml │ ├── publish.yml │ ├── publishable.yml │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── pyproject.toml ├── pytest.ini ├── pytest_pytorch ├── __init__.py └── plugin.py ├── requirements-dev.txt ├── setup.cfg ├── tests ├── __init__.py ├── assets │ ├── _spy.py │ ├── test_device.py │ ├── test_disabled.py │ ├── test_dtype.py │ ├── test_nested_names.py │ └── test_op_infos.py ├── conftest.py ├── test_cli.py ├── test_plugin.py ├── test_smoke.py └── utils.py └── tox.ini /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # See link below for available options 3 | # https://flake8.pycqa.org/en/latest/user/options.html#options-and-their-descriptions 4 | # Move this to pyproject.toml as soon as it is supported. 5 | # See https://gitlab.com/pycqa/flake8/issues/428 6 | 7 | exclude = 8 | .git, 9 | .github, 10 | .venv, 11 | .eggs, 12 | .mypy_cache, 13 | .pytest_cache, 14 | .tox, 15 | __pycache__, 16 | *.pyc, 17 | ignore = E203, E501, W503 18 | max-line-length = 88 19 | max-doc-length = 88 20 | per-file-ignores = 21 | __init__.py: F401, F403, F405 22 | conftest.py: F401, F403, F405 23 | tests/*: D1 24 | show_source = True 25 | statistics = True 26 | doctests = True 27 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | - releases/* 8 | 9 | pull_request: 10 | 11 | jobs: 12 | style: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - name: Set up python 17 | uses: actions/setup-python@v2 18 | with: 19 | python-version: "3.6" 20 | 21 | - name: Upgrade pip 22 | run: python -m pip install --upgrade pip 23 | 24 | - name: Checkout repository 25 | uses: actions/checkout@v2 26 | with: 27 | fetch-depth: 0 
28 | 29 | - name: Install dev requirements 30 | run: pip install -r requirements-dev.txt 31 | 32 | - name: Create environment 33 | run: | 34 | tox -e lint --notest 35 | pre-commit install-hooks 36 | 37 | - name: Run lint 38 | run: tox -e lint 39 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: publish 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | pypi: 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - name: Set up python 13 | uses: actions/setup-python@v2 14 | with: 15 | python-version: "3.6" 16 | 17 | - name: Upgrade pip 18 | run: python -m pip install --upgrade pip 19 | 20 | - name: Checkout repository 21 | uses: actions/checkout@v2 22 | with: 23 | fetch-depth: 0 24 | 25 | - name: Install build and twine 26 | run: pip install build twine 27 | 28 | - name: Build source and binary 29 | run: python -m build --sdist --wheel . 
30 | 31 | - name: Upload to PyPI 32 | env: 33 | TWINE_REPOSITORY: pypi 34 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 35 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 36 | run: twine upload dist/* 37 | -------------------------------------------------------------------------------- /.github/workflows/publishable.yml: -------------------------------------------------------------------------------- 1 | name: publishable 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | - releases/* 8 | 9 | pull_request: 10 | paths: 11 | - "pytest_pytorch/**" 12 | - "CONTRIBUTING.md" 13 | - "LICENSE" 14 | - "MANIFEST.in" 15 | - "pyproject.toml" 16 | - "README.md" 17 | - "setup.cfg" 18 | - "tox.ini" 19 | - ".github/workflows/publishable.yml" 20 | 21 | jobs: 22 | pypi: 23 | runs-on: ubuntu-latest 24 | 25 | steps: 26 | - name: Set up python 27 | uses: actions/setup-python@v2 28 | with: 29 | python-version: "3.6" 30 | 31 | - name: Upgrade pip 32 | run: python -m pip install --upgrade pip 33 | 34 | - name: Checkout repository 35 | uses: actions/checkout@v2 36 | with: 37 | fetch-depth: 0 38 | 39 | - name: Install dev requirements 40 | run: pip install -r requirements-dev.txt 41 | 42 | - name: Create environment 43 | run: tox -e publishable --notest 44 | 45 | - name: Test if publishable 46 | run: tox -e publishable 47 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | - releases/* 8 | 9 | pull_request: 10 | paths: 11 | - "pytest_pytorch/**" 12 | - "tests/**" 13 | - "pyproject.toml" 14 | - "pytest.ini" 15 | - "requirements-dev.txt" 16 | - "setup.cfg" 17 | - "tox.ini" 18 | - ".github/workflows/tests.yml" 19 | 20 | schedule: 21 | - cron: "0 4 * * *" 22 | 23 | jobs: 24 | integration: 25 | strategy: 26 | matrix: 27 | os: [ubuntu-latest] 28 | python: ['3.6', '3.7', '3.8', '3.9'] 29 
| include: 30 | - os: windows-latest 31 | python: "3.6" 32 | - os: macos-latest 33 | python: "3.6" 34 | 35 | runs-on: ${{ matrix.os }} 36 | env: 37 | OS: ${{ matrix.os }} 38 | PYTHON: ${{ matrix.python }} 39 | 40 | steps: 41 | - name: Set up python 42 | uses: actions/setup-python@v2 43 | with: 44 | python-version: ${{ matrix.python }} 45 | 46 | - name: Upgrade pip 47 | run: python -m pip install --upgrade pip 48 | 49 | - name: Checkout repository 50 | uses: actions/checkout@v2 51 | with: 52 | fetch-depth: 0 53 | 54 | - name: Install dev requirements 55 | run: pip install -r requirements-dev.txt 56 | 57 | - name: Create environment 58 | run: tox -e tests --notest 59 | 60 | - name: Run tests 61 | run: tox -e tests 62 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | pytest_pytorch/_version.py 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # PyCharm project settings 123 | .idea 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/PyCQA/isort 3 | rev: "5.8.0" 4 | hooks: 5 | - id: isort 6 | args: ["--settings-path=pyproject.toml"] 7 | - repo: https://github.com/psf/black 8 | rev: "20.8b1" 9 | hooks: 10 | - id: black 11 | args: ["--config=pyproject.toml"] 12 | - repo: https://gitlab.com/pycqa/flake8 13 | rev: "3.9.0" 14 | hooks: 15 | - id: flake8 16 | args: ["--config=.flake8"] 17 | - repo: https://github.com/pre-commit/pre-commit-hooks 18 | rev: "v3.4.0" 19 | hooks: 20 | - id: check-added-large-files 21 | - id: check-case-conflict 22 | - id: check-merge-conflict 23 | - id: check-symlinks 24 | - id: check-toml 25 | - id: check-vcs-permalinks 26 | - id: debug-statements 27 | - id: destroyed-symlinks 28 | - id: end-of-file-fixer 29 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | First and foremost: Thank you for your interest in `pytest-pytorch`'s development! We appreciate all contributions be it code or something else. 
4 | 5 | If you are contributing bug-fixes or documentation improvements, you can open a [pull request (PR)](https://github.com/Quansight/pytest-pytorch/pulls) without further discussion. If on the other hand you are planning to contribute new features, please open an [issue](https://github.com/Quansight/pytest-pytorch/issues) and discuss the feature with us first. 6 | 7 | Every PR is subjected to multiple automatic checks (continuous integration, CI) as well as a manual code review that it has to pass before it can be merged. The automatic checks are performed by [tox](https://tox.readthedocs.io/en/latest/). You can find details and instructions on how to run the checks locally below. 8 | 9 | ## Guide lines 10 | 11 | `pytest-pytorch` uses the [GitHub workflow](https://guides.github.com/introduction/flow/). Below is a small guide on how to make your first contribution. 12 | 13 | The following guide assumes that [git](https://git-scm.com/), [python](https://www.python.org/), and [pip](https://pypi.org/project/pip/) are available on your system. If that is not the case, follow the official installation instructions. 14 | 15 | `pytest-pytorch` officially supports Python `3.6` to `3.9`. To ensure backward compatibility, the development should happen on the minimum Python version, i.e. `3.6`. 16 | 17 | 1. Fork `pytest-pytorch` on GitHub 18 | 19 | Navigate to [Quansight/pytest-pytorch](https://github.com/Quansight/pytest-pytorch) on GitHub and click the **Fork** button in the top right corner. 20 | 21 | 2. Clone your fork to your local file system 22 | 23 | Use `git clone` to get a local copy of `pytest-pytorch`'s repository that you can work on: 24 | 25 | ``` 26 | $ PYTEST_PYTORCH_ROOT="pytest-pytorch" 27 | $ git clone "https://github.com/Quansight/pytest-pytorch.git" $PYTEST_PYTORCH_ROOT 28 | ``` 29 | 30 | 3. 
Set up your development environment 31 | 32 | ``` 33 | $ cd $PYTEST_PYTORCH_ROOT 34 | $ virtualenv .venv --prompt="(pytest-pytorch) " 35 | $ source .venv/bin/activate 36 | $ pip install -r requirements-dev.txt 37 | $ pre-commit install 38 | ``` 39 | 40 | While `pytest-pytorch`'s development requirements are fairly lightweight, it is still recommended to install them in a virtual environment rather than system wide. If you do not have `virtualenv` installed, you can install it by running `pip install --user virtualenv`. 41 | 42 | 4. Create a branch for local development 43 | 44 | Use `git checkout` to create a local branch with a descriptive name: 45 | 46 | ``` 47 | $ PYTEST_PYTORCH_BRANCH="my-awesome-feature-or-bug-fix" 48 | $ git checkout -b $PYTEST_PYTORCH_BRANCH 49 | ``` 50 | 51 | Now make your changes. Happy Coding! 52 | 53 | 5. Use `tox` to run various checks 54 | 55 | ``` 56 | $ tox 57 | ``` 58 | 59 | This is equivalent to running 60 | 61 | ``` 62 | $ tox -e lint 63 | $ tox -e tests 64 | ``` 65 | 66 | You can find details on what the individual commands do later in this guide. 67 | 68 | 6. Commit and push your changes 69 | 70 | If all checks are passing you can commit your changes and push them to your fork: 71 | 72 | ``` 73 | $ git add . 74 | $ git commit -m "Descriptive message of the changes made" 75 | $ git push -u origin $PYTEST_PYTORCH_BRANCH 76 | ``` 77 | 78 | For larger changes, it is good practice to split them into multiple small commits rather than one large one. If you do that, make sure to run the test suite before every commit. Furthermore, use `git push` without any parameters for consecutive pushes. 79 | 80 | 7. Open a Pull request (PR) 81 | 82 | 1. Navigate to [Quansight/pytest-pytorch/pulls](https://github.com/Quansight/pytest-pytorch/pulls) on GitHub and click on the green button "New pull request". 83 | 2. Click on "compare across forks" below the "Compare changes" headline. 84 | 3. 
Select your fork for "head repository" and your branch for "compare" in the drop-down menus. 85 | 4. Click the green button "Create pull request". 86 | 87 | If the time between the branch being pushed and the PR being opened is not too long, GitHub will offer you a yellow box after step 1. If you click the button, you can skip steps 2. and 3. 88 | 89 | Steps 1. to 3. only have to be performed once. If you want to continue contributing, make sure to branch from the current `master` branch. You can use `git pull` to bring your local `master` up to date first: 90 | 91 | ``` 92 | $ git checkout master 93 | $ git pull origin 94 | $ git checkout -b "my-second-awesome-feature-or-bug-fix" 95 | ``` 96 | 97 | If you forgot to do that, or if many commits have been made to the `master` branch since you created your branch, simply rebase your branch on top of it. 98 | 99 | ``` 100 | $ git checkout master 101 | $ git pull origin 102 | $ git checkout "my-second-awesome-feature-or-bug-fix" 103 | $ git rebase master 104 | ``` 105 | 106 | ## Code format and linting 107 | 108 | `pytest-pytorch` uses [`isort`](https://github.com/PyCQA/isort) to sort the imports, [black](https://black.readthedocs.io/en/stable/) to format the code, and [flake8](https://flake8.pycqa.org/en/latest/) to enforce [PEP8](https://www.python.org/dev/peps/pep-0008/) compliance. To format and check the code style, run 109 | 110 | ``` 111 | $ cd $PYTEST_PYTORCH_ROOT 112 | $ source .venv/bin/activate 113 | $ tox -e lint 114 | ``` 115 | 116 | Instead of running the checks manually, you can install them as pre-commit hooks: 117 | 118 | ``` 119 | $ cd $PYTEST_PYTORCH_ROOT 120 | $ source .venv/bin/activate 121 | $ pre-commit install 122 | ``` 123 | 124 | Now, amongst others, the above checks are run automatically every time you add a commit. 125 | 126 | ## Testing 127 | 128 | `pytest-pytorch` uses [`pytest`](https://docs.pytest.org/en/stable/) to run the test suite. 
You can run it locally with 129 | 130 | ``` 131 | cd $PYTEST_PYTORCH_ROOT 132 | source .venv/bin/activate 133 | tox -e tests 134 | ``` 135 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, Philip Meier 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-exclude .github * 2 | recursive-exclude tests * 3 | 4 | exclude .flake8 5 | exclude .gitignore 6 | exclude .pre-commit-config.yaml 7 | exclude MANIFEST.in 8 | exclude pytest.ini 9 | exclude requirements-dev.txt 10 | exclude tox.ini 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `pytest-pytorch` 2 | 3 | [![license](https://img.shields.io/badge/License-BSD%203--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause) [![repo status](https://www.repostatus.org/badges/latest/wip.svg)](https://www.repostatus.org/#wip) [![isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) [![black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![tests status](https://github.com/Quansight/pytest-pytorch/workflows/tests/badge.svg?branch=master)](https://github.com/Quansight/pytest-pytorch/actions?query=workflow%3Atests+branch%3Amaster) 4 | 5 | ## What is it? 6 | 7 | `pytest-pytorch` is a lightweight [`pytest`]-plugin that enhances the developer experience when working with the [PyTorch] test suite if you come from a [`pytest`] background. 8 | 9 | ## Why do I need it? 10 | 11 | Some testcases in the PyTorch test suite are only used as templates and will be instantiated at runtime. Unfortunately, [PyTorch]'s naming scheme for parametrizations differs from [`pytest`]'s. As a consequence, these tests cannot be selected by their names as written and one has to remember [PyTorch]'s naming scheme. 
This can be especially disruptive to your workflow if your IDE ([PyCharm](https://www.jetbrains.com/help/pycharm/pytest.html#run-pytest-test), [VSCode](https://code.visualstudio.com/docs/python/testing#_run-tests)) relies on [`pytest`]'s default selection syntax. 12 | 13 | If this has ever been a source of frustration for you, worry no longer. `pytest-pytorch` was made for you. 14 | 15 | ## How do I install it? 16 | 17 | You can install `pytest-pytorch` with `pip` 18 | 19 | ```shell 20 | $ pip install pytest-pytorch 21 | ``` 22 | 23 | or with `conda`: 24 | 25 | ```shell 26 | $ conda install -c conda-forge pytest-pytorch 27 | ``` 28 | 29 | ## How do I use it? 30 | 31 | With `pytest-pytorch` installed you can select test cases and tests by their names as written: 32 | 33 | | Use case | Command | 34 | |-------------------------------------|-----------------------------------------| 35 | | Run a test case against all devices | `pytest test_foo.py::TestBar` | 36 | | Run a test against all devices | `pytest test_foo.py::TestBar::test_baz` | 37 | 38 | Similar to a parametrization by [`@pytest.mark.parametrize`](https://docs.pytest.org/en/stable/example/parametrize.html#different-options-for-test-ids) you can use the [`-k` flag](https://docs.pytest.org/en/stable/reference.html#command-line-flags) to select a specific set of parameters: 39 | 40 | | Use case | Command | 41 | |------------------------------------|------------------------------------------------------| 42 | | Run a test case against one device | `pytest test_foo.py::TestBar -k "$DEVICE"` | 43 | | Run a test against one device | `pytest test_foo.py::TestBar::test_baz -k "$DEVICE"` | 44 | 45 | ## Can I have a little more background? 46 | 47 | Sure, we have written a [blog post about `pytest-pytorch`](https://labs.quansight.org/blog/2021/06/pytest-pytorch/) that goes into detail. 48 | 49 | ## How do I contribute? 50 | 51 | First and foremost: Thank you for your interest in `pytest-pytorch`'s development! 
We appreciate all contributions be it code or something else. Check out our [contribution guide lines](CONTRIBUTING.md) for details. 52 | 53 | [PyTorch]: https://pytorch.org 54 | [`pytest`]: https://docs.pytest.org/en/stable/ 55 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=45", 4 | "setuptools_scm[toml]>=6.0", 5 | "wheel", 6 | ] 7 | build-backend = "setuptools.build_meta" 8 | 9 | [tool.setuptools_scm] 10 | # See link below for available options 11 | # https://github.com/pypa/setuptools_scm/#configuration-parameters 12 | 13 | write_to = "pytest_pytorch/_version.py" 14 | version_scheme = "release-branch-semver" 15 | local_scheme = "node-and-timestamp" 16 | 17 | [tool.isort] 18 | # See link below for available options 19 | # https://pycqa.github.io/isort/docs/configuration/options/ 20 | 21 | profile = "black" 22 | line_length = 88 23 | 24 | skip_gitignore = true 25 | float_to_top = true 26 | color_output = true 27 | order_by_type = true 28 | combine_star = true 29 | filter_files = true 30 | 31 | known_third_party = ["pytest", "_pytest"] 32 | known_first_party = ["torch", "pytest_pytorch"] 33 | known_local_folder = ["tests"] 34 | 35 | [tool.black] 36 | # See link below for available options 37 | # https://github.com/psf/black#configuration-format 38 | 39 | line-length = 88 40 | target-version = ['py36'] 41 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | ;See link below for available options 3 | ;https://docs.pytest.org/en/latest/reference/reference.html#configuration-options 4 | 5 | testpaths = tests/ 6 | pytester_example_dir = tests/assets/ 7 | addopts = 8 | -ra 9 | 10 | --tb=short 11 | # enable all warnings 12 | -Wd 13 | 
--ignore=tests/assets/ 14 | -------------------------------------------------------------------------------- /pytest_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from ._version import version as __version__ # type: ignore[import] 3 | except ImportError: 4 | __version__ = "UNKNOWN" 5 | -------------------------------------------------------------------------------- /pytest_pytorch/plugin.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import re 3 | import unittest.mock 4 | import warnings 5 | 6 | from _pytest.unittest import TestCaseFunction, UnitTestCase 7 | 8 | try: 9 | from torch.testing._internal.common_utils import TestCase as TestCaseTemplate 10 | 11 | TORCH_AVAILABLE = True 12 | except ImportError: 13 | TORCH_AVAILABLE = False 14 | 15 | warnings.warn( 16 | "Disabling the `pytest-pytorch` plugin, because 'torch' could not be imported." 17 | ) 18 | 19 | 20 | class TemplatedName(str): 21 | def __new__(cls, name, template_name): 22 | self = super().__new__(cls, name) 23 | self._template_name = template_name 24 | return self 25 | 26 | def __eq__(self, other): 27 | exact_match = str.__eq__(self, other) 28 | if exact_match: 29 | return True 30 | 31 | if not self._template_name: 32 | return False 33 | 34 | return str.__eq__(self._template_name, other) 35 | 36 | def __hash__(self): 37 | return super().__hash__() 38 | 39 | 40 | class TemplatedTestCaseFunction(TestCaseFunction): 41 | _TEMPLATE_NAME_PATTERN = re.compile(r"def (?P<template_name>test_\w+)\(") 42 | 43 | @classmethod 44 | def _extract_template_name(cls, callobj): 45 | if not callobj: 46 | return None 47 | 48 | match = cls._TEMPLATE_NAME_PATTERN.search(inspect.getsource(callobj)) 49 | if not match: 50 | return None 51 | 52 | return match.group("template_name") 53 | 54 | @classmethod 55 | def from_parent(cls, parent, *, name, callobj, **kw): 56 | return super().from_parent( 57 | parent, 
name=TemplatedName(name, cls._extract_template_name(callobj)), **kw 58 | ) 59 | 60 | 61 | class TemplatedTestCase(UnitTestCase): 62 | @classmethod 63 | def _extract_template_name(cls, name, obj): 64 | if not obj: 65 | return None 66 | 67 | if not hasattr(obj, "device_type"): 68 | return None 69 | 70 | return name[: -len(obj.device_type)] 71 | 72 | @classmethod 73 | def from_parent(cls, parent, *, name, obj=None): 74 | return super().from_parent( 75 | parent, 76 | name=TemplatedName(name, cls._extract_template_name(name, obj)), 77 | obj=obj, 78 | ) 79 | 80 | def collect(self): 81 | # Yes, this is a bad practice. Unfortunately, there is no other option to 82 | # inject our custom 'TestCaseFunction' without duplicating everything in 83 | # 'UnitTestCase.collect()' 84 | with unittest.mock.patch( 85 | "_pytest.unittest.TestCaseFunction", new=TemplatedTestCaseFunction 86 | ): 87 | yield from super().collect() 88 | 89 | 90 | def pytest_addoption(parser, pluginmanager): 91 | parser.addoption( 92 | "--disable-pytest-pytorch", 93 | action="store_true", 94 | help="Disable the `pytest-pytorch` plugin", 95 | ) 96 | return None 97 | 98 | 99 | def pytest_pycollect_makeitem(collector, name, obj): 100 | if not TORCH_AVAILABLE: 101 | return None 102 | 103 | if collector.config.getoption("disable_pytest_pytorch"): 104 | return None 105 | 106 | try: 107 | if not issubclass(obj, TestCaseTemplate) or obj is TestCaseTemplate: 108 | return None 109 | except Exception: 110 | return None 111 | 112 | return TemplatedTestCase.from_parent(collector, name=name, obj=obj) 113 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | tox >= 3.2 2 | pre-commit 3 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = 
pytest_pytorch 3 | platforms = any 4 | description = pytest plugin for a better developer experience when working with the PyTorch test suite 5 | long_description = file: README.md 6 | long_description_content_type = text/markdown 7 | keywords = pytest, plugin, pytorch, torch 8 | url = https://github.com/Quansight/pytest-pytorch 9 | author = Philip Meier 10 | author_email = pmeier@quansight.com 11 | license = BSD-3-Clause 12 | classifiers = 13 | Development Status :: 4 - Beta 14 | Framework :: Pytest 15 | Intended Audience :: Developers 16 | License :: OSI Approved :: BSD License 17 | Operating System :: OS Independent 18 | Programming Language :: Python :: 3.6 19 | Programming Language :: Python :: 3.7 20 | Programming Language :: Python :: 3.8 21 | Programming Language :: Python :: 3.9 22 | Topic :: Software Development :: Testing 23 | project_urls = 24 | Source = https://github.com/Quansight/pytest-pytorch 25 | Documentation = https://github.com/Quansight/pytest-pytorch 26 | Tracker = https://github.com/Quansight/pytest-pytorch/issues 27 | 28 | [options] 29 | packages = find: 30 | include_package_data = True 31 | python_requires = >=3.6 32 | install_requires = 33 | pytest 34 | 35 | [options.packages.find] 36 | exclude = 37 | tests 38 | tests.* 39 | 40 | [options.entry_points] 41 | pytest11 = 42 | pytest_pytorch=pytest_pytorch.plugin 43 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Quansight/pytest-pytorch/d1a41f8f2572141da64fa9dfbff4dc7dfa2cb54f/tests/__init__.py -------------------------------------------------------------------------------- /tests/assets/_spy.py: -------------------------------------------------------------------------------- 1 | from typing import Collection, Set, Tuple 2 | 3 | from torch.testing._internal.common_utils import TestCase as PyTorchTestCase 4 | 5 | 6 | class 
TestCase: 7 | def __init__( 8 | self, 9 | template_test_case: PyTorchTestCase, 10 | instantiated_test_cases: Collection[PyTorchTestCase], 11 | ): 12 | self.template_test_case_name = template_test_case.__name__ 13 | 14 | functions = [] 15 | for template_function_name in dir(template_test_case): 16 | if not template_function_name.startswith("test_"): 17 | continue 18 | 19 | instantiated_test_cases_with_functions_names = [ 20 | (test_case.__name__, function_name) 21 | for test_case in instantiated_test_cases 22 | for function_name in dir(test_case) 23 | if function_name.startswith(template_function_name) 24 | ] 25 | 26 | functions.append( 27 | TestCaseFunction( 28 | template_test_case_name=self.template_test_case_name, 29 | template_function_name=template_function_name, 30 | instantiated_test_cases_with_functions_names=instantiated_test_cases_with_functions_names, 31 | ) 32 | ) 33 | self.functions = functions 34 | 35 | @property 36 | def new_cmds(self) -> str: 37 | return f"::{self.template_test_case_name}" 38 | 39 | @property 40 | def legacy_cmds(self) -> Tuple[str, ...]: 41 | return ("-k", f"{self.template_test_case_name}") 42 | 43 | def collect(self) -> Set[str]: 44 | collection = set() 45 | for function in self.functions: 46 | collection.update(function.collect()) 47 | return collection 48 | 49 | 50 | class TestCaseFunction: 51 | def __init__( 52 | self, 53 | *, 54 | template_test_case_name: str, 55 | template_function_name: str, 56 | instantiated_test_cases_with_functions_names: Collection[Tuple[str, str]], 57 | ) -> None: 58 | self._template_test_case_name = template_test_case_name 59 | self._template_function_name = template_function_name 60 | self._instantiated_test_cases_with_functions_names = ( 61 | instantiated_test_cases_with_functions_names 62 | ) 63 | 64 | @property 65 | def new_cmds(self) -> str: 66 | return f"::{self._template_test_case_name}::{self._template_function_name}" 67 | 68 | @property 69 | def legacy_cmds(self) -> Tuple[str, ...]: 70 | return 
( 71 | "-k", 72 | f"{self._template_test_case_name} and {self._template_function_name}", 73 | ) 74 | 75 | def collect(self) -> Set[str]: 76 | return { 77 | f"::{test_case_name}::{function_name}" 78 | for test_case_name, function_name in self._instantiated_test_cases_with_functions_names 79 | } 80 | 81 | 82 | class Spy: 83 | def __init__(self): 84 | self.test_cases = [] 85 | 86 | def __call__(self, instantiate_device_type_tests): 87 | def wrapper(template_test_case, globals, *args, **kwargs): 88 | before = set(globals.keys()) 89 | instantiate_device_type_tests(template_test_case, globals, *args, **kwargs) 90 | after = set(globals.keys()) 91 | instantiated_test_cases = [globals[name] for name in (after - before)] 92 | 93 | self.test_cases.append( 94 | TestCase(template_test_case, instantiated_test_cases) 95 | ) 96 | 97 | return wrapper 98 | -------------------------------------------------------------------------------- /tests/assets/test_device.py: -------------------------------------------------------------------------------- 1 | from torch.testing._internal.common_device_type import ( 2 | instantiate_device_type_tests, 3 | onlyCPU, 4 | onlyOn, 5 | ) 6 | from torch.testing._internal.common_utils import TestCase 7 | 8 | # ====================================================================================== 9 | # This block is necessary to autogenerate the parametrization for 10 | # tests/test_plugin.py::test_standard_collection. 11 | # It needs to be placed **after** the import of 'instantiate_device_type_tests' and 12 | # **before** its first usage. 
# Template with two plain tests, instantiated for the "cpu" and "meta" device
# types. tests/test_plugin.py::test_devices expects TestFooCPU and TestFooMETA,
# each with a device-suffixed variant of both tests.
class TestFoo(TestCase):
    def test_bar(self, device):
        pass

    def test_baz(self, device):
        pass


instantiate_device_type_tests(TestFoo, globals(), only_for=["cpu", "meta"])


# Template whose tests carry device restrictions. The restrictions do not change
# what is collected: tests/test_plugin.py::test_devices expects both tests for
# both TestSpamCPU and TestSpamMETA.
class TestSpam(TestCase):
    @onlyOn("meta")
    def test_ham(self, device):
        pass

    @onlyCPU
    def test_eggs(self, device):
        pass


instantiate_device_type_tests(TestSpam, globals(), only_for=["cpu", "meta"])


# Template that is instantiated for "cpu" only, so no "meta" variant is expected
# by tests/test_plugin.py::test_devices.
class TestQux(TestCase):
    def test_quux(self, device):
        pass


instantiate_device_type_tests(TestQux, globals(), only_for=["cpu"])
# The names in this file are deliberately prefixes of one another.
# tests/test_plugin.py::test_nested_names uses them to check that selecting
# "::TestFoo" or "::TestSpam::test_ham" does not also select the longer names.
class TestFoo(TestCase):
    def test_baz(self):
        pass


# Shares the "TestFoo" prefix with the test case above.
class TestFooBar(TestCase):
    def test_baz(self):
        pass


class TestSpam(TestCase):
    def test_ham(self):
        pass

    # Shares the "test_ham" prefix with the test above.
    def test_ham_eggs(self):
        pass
@pytest.fixture
def collect_tests(testdir):
    """Provide a helper that reports what pytest collects for an example file.

    The returned callable copies ``file`` into the ``testdir`` sandbox, runs
    ``pytest --quiet --collect-only`` with the given extra command line
    arguments, and returns the collected lines as a set.
    """
    # Imported locally so the annotations below do not depend on the import
    # block of this file.
    from typing import Sequence, Set, Union

    def collect_tests_(file: str, cmds: Union[str, Sequence[str]]) -> Set[str]:
        """Collect ``file`` with the extra arguments ``cmds``.

        Args:
            file: Name of the example file to copy into the sandbox.
            cmds: Extra command line arguments. A single string is treated as
                one argument.

        Returns:
            The output lines up to the first empty line, or an empty set if
            pytest reports that no tests were collected.

        Raises:
            AssertionError: If pytest exits with a code other than ``OK`` although
                tests were collected. The message contains the captured output.
        """
        # Without this a bare string would be unpacked into single characters.
        if isinstance(cmds, str):
            cmds = (cmds,)

        testdir.copy_example(file)
        result = testdir.runpytest("--quiet", "--collect-only", *cmds)

        outlines = result.outlines
        # Guard against empty output so that the exit code assertion below can
        # report the failure instead of an IndexError.
        if outlines and outlines[-1].startswith("no tests collected"):
            return set()

        assert result.ret == pytest.ExitCode.OK, "\n".join(outlines)

        collection = set()
        for line in outlines:
            # The first empty line presumably separates the collected ids from
            # the summary; everything after it is ignored.
            if not line:
                break

            collection.add(line)

        return collection

    return collect_tests_
def make_standard_collection_parametrization():
    """Build the parametrization of ``test_standard_collection`` from the assets.

    Every ``test*`` entry in ``tests/assets`` is imported. Modules that expose a
    ``__spy__`` attribute contribute configurations for the complete file, for
    every recorded template test case, and for every template test function.
    Entries that cannot be imported or that have no ``__spy__`` are skipped.

    The import side effects are undone afterwards: the assets directory is
    removed from ``sys.path`` again and the modules loaded from it are dropped
    from ``sys.modules``.

    Returns:
        The ``pytest.mark.parametrize`` decorator for ``Config.PARAM_NAMES``.
    """

    def make_file_config(test_cases):
        # Selecting the file without further arguments yields every test.
        selection = set()
        for test_case in test_cases:
            selection.update(test_case.collect())
        return Config(
            selection=selection,
        )

    def make_test_case_configs(test_cases):
        return [
            Config(
                new_cmds=test_case.new_cmds,
                legacy_cmds=test_case.legacy_cmds,
                selection=test_case.collect(),
            )
            for test_case in test_cases
        ]

    def make_test_case_functions_configs(test_cases):
        return [
            Config(
                new_cmds=test_case_function.new_cmds,
                legacy_cmds=test_case_function.legacy_cmds,
                selection=test_case_function.collect(),
            )
            for test_case in test_cases
            for test_case_function in test_case.functions
        ]

    params = []
    assets = pathlib.Path(__file__).parent / "assets"
    assets_entry = str(assets)
    assets_dir = assets.resolve()

    def is_asset_module(module):
        module_file = getattr(module, "__file__", None)
        if not module_file:
            return False
        return pathlib.Path(module_file).resolve().parent == assets_dir

    sys.path.insert(0, assets_entry)
    known_modules = set(sys.modules)
    try:
        # Sorted to keep the order of the generated parameters deterministic.
        for test_file in sorted(assets.iterdir()):
            if not test_file.name.startswith("test"):
                continue

            try:
                test_module = importlib.import_module(test_file.stem)
                spy = test_module.__spy__
            except Exception:
                # Deliberately broad: assets without a spy or with unavailable
                # imports simply do not take part in the parametrization.
                continue

            params.extend(
                make_params(
                    make_file_config(spy.test_cases),
                    *make_test_case_configs(spy.test_cases),
                    *make_test_case_functions_configs(spy.test_cases),
                    file=test_file.name,
                )
            )
    finally:
        sys.path.remove(assets_entry)
        # 'sys.modules' has to be cleaned up in place. Rebinding the attribute to
        # a copy leaves the interpreter's own module cache untouched, so the asset
        # modules would stay importable under their bare names. Only modules that
        # were loaded from the assets directory are dropped; other modules that
        # were imported as a side effect may not survive a second import.
        for name in set(sys.modules) - known_modules:
            if is_asset_module(sys.modules[name]):
                del sys.modules[name]

    return pytest.mark.parametrize(Config.PARAM_NAMES, params)
"::TestFooCPU::test_bar_add_with_alpha_cpu_float32", 132 | ), 133 | ), 134 | Config( 135 | new_cmds=("-k", "sub"), 136 | selection=( 137 | "::TestFooCPU::test_bar_sub_cpu_float32", 138 | "::TestFooCPU::test_bar_sub_with_alpha_cpu_float32", 139 | ), 140 | ), 141 | file="test_op_infos.py", 142 | ) 143 | def test_op_infos(collect_tests, file, cmds, selection): 144 | collection = collect_tests(file, cmds) 145 | assert collection == selection 146 | 147 | 148 | @make_parametrization( 149 | Config( 150 | selection=( 151 | "::TestFoo::test_baz", 152 | "::TestFooBar::test_baz", 153 | "::TestSpam::test_ham", 154 | "::TestSpam::test_ham_eggs", 155 | ), 156 | ), 157 | Config( 158 | new_cmds="::TestFoo", 159 | legacy_cmds=("-k", "TestFoo and not TestFooBar"), 160 | selection=("::TestFoo::test_baz",), 161 | ), 162 | Config( 163 | new_cmds="::TestFooBar", 164 | legacy_cmds=("-k", "TestFooBar"), 165 | selection=("::TestFooBar::test_baz",), 166 | ), 167 | Config( 168 | new_cmds="::TestSpam::test_ham", 169 | legacy_cmds=("-k", "TestSpam and test_ham and not test_ham_eggs"), 170 | selection=("::TestSpam::test_ham",), 171 | ), 172 | Config( 173 | new_cmds="::TestSpam::test_ham_eggs", 174 | legacy_cmds=("-k", "TestSpam and test_ham_eggs"), 175 | selection=("::TestSpam::test_ham_eggs",), 176 | ), 177 | file="test_nested_names.py", 178 | ) 179 | def test_nested_names(collect_tests, file, cmds, selection): 180 | collection = collect_tests(file, cmds) 181 | assert collection == selection 182 | 183 | 184 | @make_parametrization( 185 | Config( 186 | selection=( 187 | "::TestFooCPU::test_bar_cpu", 188 | "::TestSpam::test_ham", 189 | ), 190 | ), 191 | Config( 192 | new_cmds="::TestFoo", 193 | selection=(), 194 | ), 195 | Config( 196 | new_cmds="::TestFoo::test_bar", 197 | selection=(), 198 | ), 199 | Config( 200 | new_cmds="::TestFooCPU", 201 | legacy_cmds=("-k", "TestFoo"), 202 | selection=("::TestFooCPU::test_bar_cpu",), 203 | ), 204 | Config( 205 | new_cmds="::TestFooCPU::test_bar_cpu", 206 
def collect_modules():
    """List the dotted names of all public packages and modules of the project.

    The tree below ``PACKAGE_ROOT`` is walked. Directories that are private
    (leading underscore) or that lack an ``__init__.py`` are not descended into.
    Names are built relative to ``PROJECT_ROOT``.
    """

    def is_private(path):
        final_component = pathlib.Path(path).name
        return final_component.startswith("_")

    def path_to_module(path):
        without_suffix = pathlib.Path(path).with_suffix("")
        return str(without_suffix).replace(os.sep, ".")

    found = []
    for directory, subdirectories, filenames in os.walk(PACKAGE_ROOT):
        is_package = "__init__.py" in filenames
        if not is_package or is_private(directory):
            # Emptying the list in place stops os.walk from descending further.
            subdirectories.clear()
            continue

        package_path = pathlib.Path(directory).relative_to(PROJECT_ROOT)
        found.append(path_to_module(package_path))

        found.extend(
            path_to_module(package_path / filename)
            for filename in filenames
            if filename.endswith(".py") and not is_private(filename)
        )

    return found
def patch_imports(
    mocker,
    *names,
    retain_condition=None,
    import_error_condition=None,
):
    """Make selected imports fail and force selected modules to be re-imported.

    Two patches are installed through ``mocker``:

    1. ``builtins.__import__`` is replaced by a function that raises
       ``ImportError`` whenever ``import_error_condition`` is true for the
       requested import and delegates to the real ``__import__`` otherwise.
    2. ``sys.modules`` is reduced to the entries for which ``retain_condition``
       is true.

    Args:
        mocker: Object that provides ``patch.object`` and ``patch.dict``,
            presumably the ``mocker`` fixture of pytest-mock.
        *names: Module names the default conditions are built from.
        retain_condition: Callable ``(name) -> bool`` that decides which entries
            of ``sys.modules`` are kept. Defaults to keeping every module whose
            name does not start with one of ``names``.
        import_error_condition: Callable
            ``(name, globals, locals, fromlist, level) -> bool`` that decides
            which imports fail. Defaults to failing if ``name`` or any item of
            ``fromlist`` is one of ``names``.
    """
    if retain_condition is None:

        def retain_condition(name):
            # 'names' is a tuple, which str.startswith accepts directly.
            return not name.startswith(names)

    if import_error_condition is None:

        def import_error_condition(name, globals, locals, fromlist, level):
            if name in names:
                return True
            return any(from_ in names for from_ in fromlist or ())

    original_import = builtins.__import__

    # The defaults mirror the signature of builtins.__import__, so that calls like
    # __import__("module") keep working while the patch is active.
    def patched_import(name, globals=None, locals=None, fromlist=(), level=0):
        if import_error_condition(name, globals, locals, fromlist, level):
            raise ImportError(
                f"Import of '{name}' is blocked by patch_imports()", name=name
            )

        return original_import(name, globals, locals, fromlist, level)

    mocker.patch.object(builtins, "__import__", new=patched_import)

    values = {
        name: module for name, module in sys.modules.items() if retain_condition(name)
    }
    mocker.patch.dict(sys.modules, clear=True, values=values)
def make_params(*configs: Config, file: Optional[str] = None) -> List[ParameterSet]:
    """Concatenate the parameter sets of all ``configs`` in order.

    ``file`` is handed to ``make_params`` of every config.
    """
    params = []
    for config in configs:
        params.extend(config.make_params(file))
    return params


def make_parametrization(*configs, file: Optional[str] = None) -> MarkDecorator:
    """Return the ``parametrize`` mark for the parameter sets of ``configs``."""
    params = make_params(*configs, file=file)
    return pytest.mark.parametrize(Config.PARAM_NAMES, params)
skip_install = True 37 | deps = 38 | check-wheel-contents 39 | build 40 | twine 41 | commands = 42 | # TODO: Make this work on Windows 43 | rm -rf build dist pytest_pytorch.egg-info 44 | python -m build --sdist --wheel . 45 | twine check --strict dist/* 46 | check-wheel-contents dist 47 | --------------------------------------------------------------------------------