├── .basedpyright │ └── baseline.json ├── .editorconfig ├── .github │ ├── dependabot.yml │ └── workflows │ ├── autopush.yml │ └── ci.yml ├── .gitignore ├── .gitlab-ci.yml ├── .pylintrc-local.yml ├── CITATION.cff ├── LICENSE ├── README.rst ├── doc │ ├── .gitignore │ ├── Makefile │ ├── codegen.rst │ ├── conf.py │ ├── convergence.rst │ ├── graph.rst │ ├── index.rst │ ├── misc.rst │ ├── mpi.rst │ ├── obj_array.rst │ ├── persistent_dict.rst │ ├── reference.rst │ ├── tag.rst │ └── upload-docs.sh ├── pyproject.toml ├── pytools │ ├── __init__.py │ ├── batchjob.py │ ├── codegen.py │ ├── convergence.py │ ├── datatable.py │ ├── debug.py │ ├── graph.py │ ├── graphviz.py │ ├── lex.py │ ├── mpi.py │ ├── mpiwrap.py │ ├── obj_array.py │ ├── persistent_dict.py │ ├── prefork.py │ ├── py.typed │ ├── py_codegen.py │ ├── spatial_btree.py │ ├── stopwatch.py │ ├── tag.py │ ├── test │ │ ├── __init__.py │ │ ├── test_data_table.py │ │ ├── test_dataclasses.py │ │ ├── test_graph_tools.py │ │ ├── test_math_stuff.py │ │ ├── test_mpi.py │ │ ├── test_persistent_dict.py │ │ ├── test_py_codegen.py │ │ └── test_pytools.py │ └── version.py ├── run-mypy.sh └── run-pylint.sh /.editorconfig: -------------------------------------------------------------------------------- 1 | # https://editorconfig.org/ 2 | # https://github.com/editorconfig/editorconfig-vim 3 | # https://github.com/editorconfig/editorconfig-emacs 4 | 5 | root = true 6 | 7 | [*] 8 | indent_style = space 9 | end_of_line = lf 10 | charset = utf-8 11 | trim_trailing_whitespace = true 12 | insert_final_newline = true 13 | 14 | [*.py] 15 | indent_size = 4 16 | 17 | [*.rst] 18 | indent_size = 4 19 | 20 | [*.cpp] 21 | indent_size = 2 22 | 23 | [*.hpp] 24 | indent_size = 2 25 | 26 | # There may be one in doc/ 27 | [Makefile] 28 | indent_style = tab 29 | 30 | # https://github.com/microsoft/vscode/issues/1679 31 | [*.md] 32 | trim_trailing_whitespace = false -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | # Set update schedule for GitHub Actions 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: "weekly" 8 | 9 | # vim: sw=4 10 | -------------------------------------------------------------------------------- /.github/workflows/autopush.yml: -------------------------------------------------------------------------------- 1 | name: Gitlab mirror 2 | on: 3 | push: 4 | branches: 5 | - main 6 | 7 | jobs: 8 | autopush: 9 | name: Automatic push to gitlab.tiker.net 10 | if: startsWith(github.repository, 'inducer/') 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - run: | 15 | curl -L -O https://tiker.net/ci-support-v0 16 | . 
./ci-support-v0 17 | mirror_github_to_gitlab 18 | 19 | env: 20 | GITLAB_AUTOPUSH_KEY: ${{ secrets.GITLAB_AUTOPUSH_KEY }} 21 | 22 | # vim: sw=4 23 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | paths-ignore: 8 | - 'doc/*.rst' 9 | schedule: 10 | - cron: '17 3 * * 0' 11 | 12 | concurrency: 13 | group: ${{ github.head_ref || github.ref_name }} 14 | cancel-in-progress: true 15 | 16 | jobs: 17 | ruff: 18 | name: Ruff 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: actions/checkout@v4 22 | - uses: actions/setup-python@v5 23 | - name: "Main Script" 24 | run: | 25 | pip install ruff 26 | ruff check 27 | 28 | typos: 29 | name: Typos 30 | runs-on: ubuntu-latest 31 | steps: 32 | - uses: actions/checkout@v4 33 | - uses: crate-ci/typos@master 34 | 35 | validate_cff: 36 | name: Validate CITATION.cff 37 | runs-on: ubuntu-latest 38 | steps: 39 | - uses: actions/checkout@v4 40 | - uses: actions/setup-python@v5 41 | with: 42 | python-version: '3.x' 43 | - run: | 44 | pip install cffconvert 45 | cffconvert -i CITATION.cff --validate 46 | 47 | pylint: 48 | name: Pylint 49 | runs-on: ubuntu-latest 50 | steps: 51 | - uses: actions/checkout@v4 52 | - 53 | uses: actions/setup-python@v5 54 | with: 55 | python-version: '3.x' 56 | - name: "Main Script" 57 | run: | 58 | EXTRA_INSTALL="numpy pymbolic orderedsets" 59 | curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/prepare-and-run-pylint.sh 60 | . ./prepare-and-run-pylint.sh "$(basename $GITHUB_REPOSITORY)" 61 | 62 | basedpyright: 63 | runs-on: ubuntu-latest 64 | steps: 65 | - uses: actions/checkout@v4 66 | - uses: actions/setup-python@v5 67 | with: 68 | python-version: '3.x' 69 | - name: "Main Script" 70 | run: | 71 | curl -L -O https://tiker.net/ci-support-v0 72 | . ./ci-support-v0 73 | build_py_project_in_venv 74 | sudo apt update 75 | sudo apt -y install libopenmpi-dev 76 | pip install numpy attrs orderedsets pytest mpi4py matplotlib 77 | pip install basedpyright 78 | basedpyright 79 | 80 | pytest: 81 | name: Pytest on Py${{ matrix.python-version }} ${{ matrix.os }} 82 | runs-on: ${{ matrix.os }} 83 | strategy: 84 | matrix: 85 | python-version: ["3.10", "3.12", "3.x", "pypy3.10"] 86 | os: [ubuntu-latest, macos-latest] 87 | steps: 88 | - uses: actions/checkout@v4 89 | - 90 | uses: actions/setup-python@v5 91 | with: 92 | python-version: ${{ matrix.python-version }} 93 | - name: "Main Script" 94 | run: | 95 | # untested, causes import error with Pytest >= 6.2.0 96 | # AK, 2020-12-13 97 | rm pytools/mpiwrap.py 98 | 99 | EXTRA_INSTALL="numpy frozendict immutabledict orderedsets constantdict immutables pyrsistent attrs" 100 | curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project.sh 101 | . ./build-and-test-py-project.sh 102 | 103 | # Also run with optimizations turned on, since opt_frozen_dataclass 104 | # depends on the __debug__ setting. 
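# For reference: `python -c 'print(__debug__)'` prints True, while
# `python -O -c 'print(__debug__)'` prints False and assert statements
# are stripped; that optimized mode is what the next line exercises.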
105 | python -O -m pytest pytools/test/test_dataclasses.py 106 | 107 | pytest_nonumpy: 108 | name: Pytest without Numpy 109 | runs-on: ubuntu-latest 110 | steps: 111 | - uses: actions/checkout@v4 112 | - 113 | uses: actions/setup-python@v5 114 | with: 115 | python-version: '3.x' 116 | - name: "Main Script" 117 | run: | 118 | rm pytools/{convergence,spatial_btree,obj_array,mpiwrap}.py 119 | curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project.sh 120 | . ./build-and-test-py-project.sh 121 | 122 | #examples: 123 | # name: Examples Py3 124 | # runs-on: ubuntu-latest 125 | # steps: 126 | # - uses: actions/checkout@v4 127 | # - 128 | # uses: actions/setup-python@v5 129 | # with: 130 | # python-version: '3.x' 131 | # - name: "Main Script" 132 | # run: | 133 | # EXTRA_INSTALL="numpy pymbolic" 134 | # curl -L -O https://tiker.net/ci-support-v0 135 | # . ./ci-support-v0 136 | # build_py_project_in_venv 137 | # run_examples 138 | 139 | downstream_tests: 140 | strategy: 141 | matrix: 142 | downstream_project: [loopy, pytato] 143 | name: Tests for downstream project ${{ matrix.downstream_project }} 144 | runs-on: ubuntu-latest 145 | steps: 146 | - uses: actions/checkout@v4 147 | - name: "Main Script" 148 | env: 149 | DOWNSTREAM_PROJECT: ${{ matrix.downstream_project }} 150 | run: | 151 | curl -L -O https://tiker.net/ci-support-v0 152 | . ./ci-support-v0 153 | test_downstream "$DOWNSTREAM_PROJECT" 154 | 155 | docs: 156 | name: Documentation 157 | runs-on: ubuntu-latest 158 | steps: 159 | - uses: actions/checkout@v4 160 | - 161 | uses: actions/setup-python@v5 162 | with: 163 | python-version: '3.x' 164 | - name: "Main Script" 165 | run: | 166 | EXTRA_INSTALL="numpy" 167 | curl -L -O https://tiker.net/ci-support-v0 168 | . ci-support-v0 169 | build_py_project_in_venv 170 | build_docs 171 | 172 | # vim: sw=4 173 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | .*.sw[po] 3 | *~ 4 | *.pyc 5 | *.pyo 6 | *.egg-info 7 | MANIFEST 8 | dist 9 | setuptools*egg 10 | setuptools.pth 11 | setuptools*tar.gz 12 | distribute*egg 13 | distribute*tar.gz 14 | 15 | .cache 16 | .mypy_cache 17 | 18 | *.dat 19 | 20 | .pylintrc.yml 21 | .run-pylint.py 22 | -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | Pytest: 2 | script: | 3 | # untested, causes import error with Pytest >= 6.2.0 4 | # AK, 2020-12-13 5 | rm pytools/mpiwrap.py 6 | 7 | export EXTRA_INSTALL="numpy siphash24" 8 | curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project.sh 9 | . ./build-and-test-py-project.sh 10 | tags: 11 | - python3 12 | except: 13 | - tags 14 | artifacts: 15 | reports: 16 | junit: test/pytest.xml 17 | 18 | Pytest without Numpy: 19 | script: | 20 | EXTRA_INSTALL="siphash24" 21 | rm pytools/{convergence,spatial_btree,obj_array,mpiwrap}.py 22 | curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project.sh 23 | . ./build-and-test-py-project.sh 24 | tags: 25 | - python3 26 | except: 27 | - tags 28 | artifacts: 29 | reports: 30 | junit: test/pytest.xml 31 | 32 | # Examples: 33 | # script: | 34 | # EXTRA_INSTALL="numpy pymbolic" 35 | # curl -L -O https://tiker.net/ci-support-v0 36 | # . 
./ci-support-v0 37 | # build_py_project_in_venv 38 | # run_examples 39 | # tags: 40 | # - python3 41 | # except: 42 | # - tags 43 | 44 | Ruff: 45 | script: 46 | - pipx install ruff 47 | - ruff check 48 | tags: 49 | - docker-runner 50 | except: 51 | - tags 52 | 53 | Pylint: 54 | script: 55 | - EXTRA_INSTALL="numpy pymbolic orderedsets siphash24" 56 | - py_version=3 57 | - curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/prepare-and-run-pylint.sh 58 | - . ./prepare-and-run-pylint.sh "$CI_PROJECT_NAME" 59 | tags: 60 | - python3 61 | except: 62 | - tags 63 | 64 | Documentation: 65 | script: 66 | - EXTRA_INSTALL="numpy siphash24" 67 | - curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-docs.sh 68 | - ". ./build-docs.sh" 69 | tags: 70 | - python3 71 | 72 | Downstream: 73 | parallel: 74 | matrix: 75 | - DOWNSTREAM_PROJECT: [loopy, pytato] 76 | tags: 77 | - large-node 78 | - "docker-runner" 79 | script: | 80 | curl -L -O https://tiker.net/ci-support-v0 81 | . ./ci-support-v0 82 | test_downstream "$DOWNSTREAM_PROJECT" 83 | 84 | # vim: sw=2 85 | -------------------------------------------------------------------------------- /.pylintrc-local.yml: -------------------------------------------------------------------------------- 1 | - arg: ignore 2 | val: 3 | - mpiwrap.py 4 | - arg: ignored-modules 5 | val: 6 | - matplotlib 7 | - siphash24 8 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: "Kloeckner" 5 | given-names: "Andreas" 6 | orcid: "https://orcid.org/0000-0003-1228-519X" 7 | - family-names: "Christensen" 8 | given-names: "Nick" 9 | - family-names: "Wala" 10 | given-names: "Matt" 11 | - family-names: "Fikl" 12 | given-names: "Alex" 13 | - family-names: "Stevens" 14 | given-names: "James" 15 | - family-names: "Diener" 16 | given-names: "Matthias" 17 | - family-names: "Kulkarni" 18 | given-names: "Kaushik" 19 | - family-names: "Witherden" 20 | given-names: "Freddie" 21 | - family-names: "Kempf" 22 | given-names: "Dominic" 23 | - family-names: "Gibson" 24 | given-names: "Thomas H." 25 | - family-names: "Fernando" 26 | given-names: "Isuru" 27 | - family-names: "Yu" 28 | given-names: "Yichao" 29 | - family-names: "Wei" 30 | given-names: "Xiaoyu" 31 | - family-names: "Drix" 32 | given-names: "Damien" 33 | - family-names: "Hoag" 34 | given-names: "Ellis" 35 | - family-names: "Gao" 36 | given-names: "Hao" 37 | - family-names: "Dercksen" 38 | given-names: "Koen" 39 | - family-names: "Kocak" 40 | given-names: "Kubilay" 41 | - family-names: "Ofitserov" 42 | given-names: "Nikita" 43 | 44 | title: "pytools" 45 | version: 2022.1.7 46 | doi: 10.5281/zenodo.6533949 47 | date-released: 2022-05-04 48 | url: "https://github.com/inducer/pytools" 49 | license: MIT 50 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | pytools is licensed to you under the MIT/X Consortium license: 2 | 3 | Copyright (c) 2009-16 Andreas Klöckner and Contributors. 
4 | 5 | Permission is hereby granted, free of charge, to any person 6 | obtaining a copy of this software and associated documentation 7 | files (the "Software"), to deal in the Software without 8 | restriction, including without limitation the rights to use, 9 | copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the 11 | Software is furnished to do so, subject to the following 12 | conditions: 13 | 14 | The above copyright notice and this permission notice shall be 15 | included in all copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 18 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 19 | OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 20 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 21 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 22 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 23 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 24 | OTHER DEALINGS IN THE SOFTWARE. 25 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | Pytools: Lots of Little Utilities 2 | ================================= 3 | 4 | .. image:: https://gitlab.tiker.net/inducer/pytools/badges/main/pipeline.svg 5 | :alt: Gitlab Build Status 6 | :target: https://gitlab.tiker.net/inducer/pytools/commits/main 7 | .. image:: https://github.com/inducer/pytools/actions/workflows/ci.yml/badge.svg 8 | :alt: Github Build Status 9 | :target: https://github.com/inducer/pytools/actions/workflows/ci.yml 10 | .. image:: https://badge.fury.io/py/pytools.svg 11 | :alt: Python Package Index Release Page 12 | :target: https://pypi.org/project/pytools/ 13 | .. image:: https://zenodo.org/badge/1575270.svg 14 | :alt: Zenodo DOI for latest release 15 | :target: https://zenodo.org/badge/latestdoi/1575270 16 | 17 | Pytools is a big bag of things that are "missing" from the Python standard 18 | library. This is mainly a dependency of my other software packages, and is 19 | probably of little interest to you unless you use those. If you're curious 20 | nonetheless, here's what's on offer: 21 | 22 | * A ton of small tool functions such as ``len_iterable``, ``argmin``, 23 | tuple generation, permutation generation, ASCII table pretty printing, 24 | GvR's ``monkeypatch_xxx()`` hack, the elusive ``flatten``, and much more. 25 | * Batch job submission, ``pytools.batchjob``. 26 | * A lexer, ``pytools.lex``. 27 | * A persistent key-value store, ``pytools.persistent_dict``. 28 | 
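A small, illustrative taste (a sketch, not an exhaustive tour)::

    from pytools import flatten, len_iterable

    list(flatten([[1, 2], [3, 4]]))              # -> [1, 2, 3, 4]
    len_iterable(x for x in range(10) if x % 2)  # -> 5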
29 | Links: 30 | 31 | * `Documentation <https://documen.tician.de/pytools/>`__ 32 | * `Github <https://github.com/inducer/pytools/>`__ 33 | * ``pytools.log`` has been spun out into a separate project, 34 | `logpyle <https://github.com/illinois-ceesd/logpyle>`__. 35 | -------------------------------------------------------------------------------- /doc/.gitignore: -------------------------------------------------------------------------------- 1 | _build 2 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python $(shell which sphinx-build) 7 | SPHINXPROJ = pytools 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /doc/codegen.rst: -------------------------------------------------------------------------------- 1 | .. automodule:: pytools.codegen 2 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from urllib.request import urlopen 4 | 5 | 6 | _conf_url = \ 7 | "https://raw.githubusercontent.com/inducer/sphinxconfig/main/sphinxconfig.py" 8 | with urlopen(_conf_url) as _inf: 9 | exec(compile(_inf.read(), _conf_url, "exec"), globals()) 10 | 11 | copyright = "2009-21, Andreas Kloeckner" 12 | author = "Andreas Kloeckner" 13 | 14 | # The version info for the project you're documenting, acts as replacement for 15 | # |version| and |release|, also used in various other places throughout the 16 | # built documents. 17 | # 18 | # The short X.Y version. 19 | ver_dic = {} 20 | with open("../pytools/version.py") as vfile: 21 | exec(compile(vfile.read(), "../pytools/version.py", "exec"), 22 | ver_dic) 23 | 24 | version = ".".join(str(x) for x in ver_dic["VERSION"]) 25 | release = ver_dic["VERSION_TEXT"] 26 | 27 | # List of patterns, relative to source directory, that match files and 28 | # directories to ignore when looking for source files. 29 | # These patterns also affect html_static_path and html_extra_path 30 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 31 | 32 | intersphinx_mapping = { 33 | "loopy": ("https://documen.tician.de/loopy", None), 34 | "numpy": ("https://numpy.org/doc/stable", None), 35 | "pymbolic": ("https://documen.tician.de/pymbolic", None), 36 | "pytest": ("https://docs.pytest.org/en/stable", None), 37 | "setuptools": ("https://setuptools.pypa.io/en/latest", None), 38 | "python": ("https://docs.python.org/3", None), 39 | "platformdirs": ("https://platformdirs.readthedocs.io/en/latest", None), 40 | } 41 | 42 | nitpicky = True 43 | nitpick_ignore_regex = [ 44 | ["py:class", r"typing_extensions\.(.+)"], 45 | ["py:class", r"ReadableBuffer"], 46 | ] 47 | 48 | autodoc_type_aliases = { 49 | "GraphT": "pytools.graph.GraphT", 50 | "NodeT": "pytools.graph.NodeT", 51 | } 52 | -------------------------------------------------------------------------------- /doc/convergence.rst: -------------------------------------------------------------------------------- 1 | Testing convergence 2 | ------------------- 3 | 4 | .. automodule:: pytools.convergence 5 | -------------------------------------------------------------------------------- /doc/graph.rst: -------------------------------------------------------------------------------- 1 | .. automodule:: pytools.graph 2 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to pytools's documentation! 2 | =================================== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :caption: Contents: 7 | 8 | reference 9 | obj_array 10 | persistent_dict 11 | convergence 12 | graph 13 | tag 14 | codegen 15 | mpi 16 | misc 17 | 🚀 Github <https://github.com/inducer/pytools> 18 | 💾 Download Releases <https://pypi.org/project/pytools> 19 | 20 | Indices and tables 21 | ================== 22 | 23 | * :ref:`genindex` 24 | * :ref:`modindex` 25 | * :ref:`search` 26 | -------------------------------------------------------------------------------- /doc/misc.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | This command should install :mod:`pytools`:: 5 | 6 | pip install pytools 7 | 8 | You may need to run this with :command:`sudo`. 9 | If you don't already have `pip <https://pypi.org/project/pip>`_, 10 | run this beforehand:: 11 | 12 | curl -O https://bootstrap.pypa.io/get-pip.py 13 | python get-pip.py 14 | 15 | For a more manual installation, download the source, unpack it, 16 | and say:: 17 | 18 | pip install . 19 | 20 | User-visible changes 21 | ==================== 22 | 23 | Version 2020.4 24 | -------------- 25 | 26 | .. note:: 27 | 28 | This version is currently under development. You can get snapshots from 29 | Pytools's `git repository <https://github.com/inducer/pytools>`_ 30 | 31 | * :mod:`pytools.codegen` was added. 32 | 33 | Version 2020.3 34 | -------------- 35 | 36 | * Type annotations were added. 37 | * Python 2 support was dropped. 38 | 39 | .. _license: 40 | 41 | License 42 | ======= 43 | 44 | :mod:`pytools` is licensed to you under the MIT/X Consortium license: 45 | 46 | Copyright (c) 2008-17 Andreas Klöckner 47 | 48 | Permission is hereby granted, free of charge, to any person 49 | obtaining a copy of this software and associated documentation 50 | files (the "Software"), to deal in the Software without 51 | restriction, including without limitation the rights to use, 52 | copy, modify, merge, publish, distribute, sublicense, and/or sell 53 | copies of the Software, and to permit persons to whom the 54 | Software is furnished to do so, subject to the following 55 | conditions: 56 | 57 | The above copyright notice and this permission notice shall be 58 | included in all copies or substantial portions of the Software. 59 | 60 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 61 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 62 | OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 63 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 64 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 65 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 66 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 67 | OTHER DEALINGS IN THE SOFTWARE. 68 | -------------------------------------------------------------------------------- /doc/mpi.rst: -------------------------------------------------------------------------------- 1 | .. automodule:: pytools.mpi 2 | -------------------------------------------------------------------------------- /doc/obj_array.rst: -------------------------------------------------------------------------------- 1 | .. automodule:: pytools.obj_array 2 | -------------------------------------------------------------------------------- /doc/persistent_dict.rst: -------------------------------------------------------------------------------- 1 | .. 
automodule:: pytools.persistent_dict 2 | -------------------------------------------------------------------------------- /doc/reference.rst: -------------------------------------------------------------------------------- 1 | .. automodule:: pytools 2 | .. automodule:: pytools.datatable 3 | 4 | .. automodule:: pytools.graphviz 5 | -------------------------------------------------------------------------------- /doc/tag.rst: -------------------------------------------------------------------------------- 1 | .. automodule:: pytools.tag 2 | -------------------------------------------------------------------------------- /doc/upload-docs.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | 3 | rsync --verbose --archive --delete _build/html/ doc-upload:doc/pytools 4 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "pytools" 7 | version = "2025.1.6" 8 | description = "A collection of tools for Python" 9 | readme = "README.rst" 10 | license = "MIT" 11 | authors = [ 12 | { name = "Andreas Kloeckner", email = "inform@tiker.net" }, 13 | ] 14 | requires-python = ">=3.10" 15 | classifiers = [ 16 | "Development Status :: 4 - Beta", 17 | "Intended Audience :: Developers", 18 | "Intended Audience :: Other Audience", 19 | "Intended Audience :: Science/Research", 20 | "Natural Language :: English", 21 | "Programming Language :: Python", 22 | "Programming Language :: Python :: 3 :: Only", 23 | "Topic :: Scientific/Engineering", 24 | "Topic :: Scientific/Engineering :: Information Analysis", 25 | "Topic :: Scientific/Engineering :: Mathematics", 26 | "Topic :: Scientific/Engineering :: Visualization", 27 | "Topic :: Software Development :: Libraries", 28 | "Topic :: Utilities", 29 | ] 30 | dependencies = [ 31 | "platformdirs>=2.2", 32 | # for dataclass_transform with frozen_default 33 | "typing-extensions>=4.5", 34 | "siphash24>=1.6", 35 | ] 36 | 37 | [project.optional-dependencies] 38 | numpy = [ 39 | "numpy>=1.6", 40 | ] 41 | test = [ 42 | "mypy", 43 | "pytest", 44 | "ruff", 45 | ] 46 | 47 | [project.urls] 48 | Documentation = "https://documen.tician.de/pytools/" 49 | Homepage = "https://github.com/inducer/pytools/" 50 | 51 | [tool.hatch.build.targets.sdist] 52 | exclude = [ 53 | "/.git*", 54 | "/doc/_build", 55 | "/.editorconfig", 56 | "/run-*.sh", 57 | ] 58 | 59 | [tool.ruff] 60 | preview = true 61 | 62 | [tool.ruff.lint] 63 | extend-select = [ 64 | "B", # flake8-bugbear 65 | "C", # flake8-comprehensions 66 | "E", # pycodestyle 67 | "F", # pyflakes 68 | "G", # flake8-logging-format 69 | "I", # flake8-isort 70 | "N", # pep8-naming 71 | "NPY", # numpy 72 | "PGH", # pygrep-hooks 73 | "Q", # flake8-quotes 74 | "RUF", # ruff 75 | "SIM", # flake8-simplify 76 | "TC", # flake8-type-checking 77 | "UP", # pyupgrade 78 | "W", # pycodestyle 79 | ] 80 | extend-ignore = [ 81 | "C90", # McCabe complexity 82 | "E221", # multiple spaces before operator 83 | "E226", # missing whitespace around arithmetic operator 84 | "E402", # module-level import not at top of file 85 | "UP031", # use f-strings instead of % 86 | "UP032", # use f-strings instead of .format 87 | ] 88 | 89 | [tool.ruff.lint.flake8-quotes] 90 | docstring-quotes = "double" 91 | inline-quotes = "double" 92 | multiline-quotes = "double" 93 | 94 | 
[tool.ruff.lint.isort] 95 | combine-as-imports = true 96 | known-local-folder = [ 97 | "pytools", 98 | ] 99 | lines-after-imports = 2 100 | required-imports = ["from __future__ import annotations"] 101 | 102 | [tool.ruff.lint.pep8-naming] 103 | extend-ignore-names = ["update_for_*"] 104 | 105 | 106 | [tool.basedpyright] 107 | reportImplicitStringConcatenation = "none" 108 | reportUnnecessaryIsInstance = "none" 109 | reportUnusedCallResult = "none" 110 | reportExplicitAny = "none" 111 | reportUnreachable = "hint" 112 | 113 | # This reports even cycles that are qualified by 'if TYPE_CHECKING'. Not what 114 | # we care about at this moment. 115 | # https://github.com/microsoft/pyright/issues/746 116 | reportImportCycles = "none" 117 | pythonVersion = "3.10" 118 | pythonPlatform = "All" 119 | 120 | [[tool.basedpyright.executionEnvironments]] 121 | root = "pytools/test" 122 | reportUnknownArgumentType = "hint" 123 | reportPrivateUsage = "none" 124 | 125 | [tool.mypy] 126 | python_version = "3.10" 127 | ignore_missing_imports = true 128 | warn_unused_ignores = true 129 | # TODO: enable this at some point 130 | # check_untyped_defs = true 131 | 132 | [tool.typos.default] 133 | extend-ignore-re = [ 134 | "(?Rm)^.*(#|//)\\s*spellchecker:\\s*disable-line$" 135 | ] 136 | -------------------------------------------------------------------------------- /pytools/batchjob.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing_extensions import override 4 | 5 | 6 | def _cp(src, dest): 7 | from pytools import assert_not_a_file 8 | assert_not_a_file(dest) 9 | 10 | with open(src, "rb") as inf, open(dest, "wb") as outf: 11 | outf.write(inf.read()) 12 | 13 | 14 | def get_timestamp(): 15 | from datetime import datetime 16 | return datetime.now().strftime("%Y-%m-%d-%H%M%S") 17 | 18 | 19 | class BatchJob: 20 | def __init__(self, moniker, main_file, aux_files=(), timestamp=None): 21 | import os 22 | import os.path 23 | 24 | if timestamp is None: 25 | timestamp = get_timestamp() 26 | 27 | self.moniker = ( 28 | moniker 29 | .replace("/", "-") 30 | .replace("-$DATE", "") 31 | .replace("$DATE-", "") 32 | .replace("$DATE", "") 33 | ) 34 | self.subdir = moniker.replace("$DATE", timestamp) 35 | self.path = os.path.join( 36 | os.getcwd(), 37 | self.subdir) 38 | 39 | os.makedirs(self.path) 40 | 41 | with open(f"{self.path}/run.sh", "w") as runscript: 42 | import sys 43 | runscript.write(f"{sys.executable} {main_file} setup.cpy") 44 | 45 | from os.path import basename 46 | 47 | if not main_file.startswith("-m "): 48 | _cp(main_file, os.path.join(self.path, basename(main_file))) 49 | 50 | for aux_file in aux_files: 51 | _cp(aux_file, os.path.join(self.path, basename(aux_file))) 52 | 53 | def write_setup(self, lines): 54 | import os.path 55 | with open(os.path.join(self.path, "setup.cpy"), "w") as setup: 56 | setup.write("\n".join(lines)) 57 | 58 | 59 | class INHERIT: 60 | pass 61 | 62 | 63 | class GridEngineJob(BatchJob): 64 | def submit(self, env=(("LD_LIBRARY_PATH", INHERIT), ("PYTHONPATH", INHERIT)), 65 | memory_megs=None, extra_args=()): 66 | from subprocess import Popen 67 | args = [ 68 | "-N", self.moniker, 69 | "-cwd", 70 | ] 71 | 72 | from os import getenv 73 | env = dict(env) 74 | for var, value in env.items(): 75 | if value is INHERIT: 76 | value = getenv(var) 77 | 78 | args += ["-v", f"{var}={value}"] 79 | 80 | if memory_megs is not None: 81 | args.extend(["-l", f"mem={memory_megs}"]) 82 | 83 | args.extend(extra_args) 84 | 85 
| subproc = Popen(["qsub", *args, "run.sh"], cwd=self.path) 86 | if subproc.wait() != 0: 87 | raise RuntimeError(f"Process submission of {self.moniker} failed") 88 | 89 | 90 | class PBSJob(BatchJob): 91 | def submit(self, env=(("LD_LIBRARY_PATH", INHERIT), ("PYTHONPATH", INHERIT)), 92 | memory_megs=None, extra_args=()): 93 | from subprocess import Popen 94 | args = [ 95 | "-N", self.moniker, 96 | "-d", self.path, 97 | ] 98 | 99 | if memory_megs is not None: 100 | args.extend(["-l", f"pmem={memory_megs}mb"]) 101 | 102 | from os import getenv 103 | 104 | env = dict(env) 105 | for var, value in env.items(): 106 | if value is INHERIT: 107 | value = getenv(var) 108 | 109 | args += ["-v", f"{var}={value}"] 110 | 111 | args.extend(extra_args) 112 | 113 | subproc = Popen(["qsub", *args, "run.sh"], cwd=self.path) 114 | if subproc.wait() != 0: 115 | raise RuntimeError(f"Process submission of {self.moniker} failed") 116 | 117 | 118 | def guess_job_class(): 119 | from subprocess import PIPE, STDOUT, Popen 120 | qstat_helplines = Popen(["qstat", "--help"], 121 | stdout=PIPE, stderr=STDOUT, text=True).communicate()[0].split("\n") 122 | if qstat_helplines[0].startswith("GE"): 123 | return GridEngineJob 124 | return PBSJob 125 | 126 | 127 | class ConstructorPlaceholder: 128 | def __init__(self, classname, *args, **kwargs): 129 | self.classname = classname 130 | self.args = args 131 | self.kwargs = kwargs 132 | 133 | def arg(self, i): 134 | return self.args[i] 135 | 136 | def kwarg(self, name): 137 | return self.kwargs[name] 138 | 139 | @override 140 | def __str__(self): 141 | return "{}({})".format(self.classname, 142 | ",".join( 143 | [str(arg) for arg in self.args] 144 | + [f"{kw}={val!r}" 145 | for kw, val in self.kwargs.items()] 146 | ) 147 | ) 148 | __repr__ = __str__ 149 | -------------------------------------------------------------------------------- /pytools/codegen.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | 4 | __copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner" 5 | 6 | __license__ = """ 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in 15 | all copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 23 | THE SOFTWARE. 24 | """ 25 | 26 | __doc__ = """ 27 | Tools for Source Code Generation 28 | ================================ 29 | 30 | .. autoclass:: CodeGenerator 31 | .. autoclass:: Indentation 32 | .. 
autofunction:: remove_common_indentation 33 | """ 34 | 35 | from typing import Any 36 | 37 | 38 | # {{{ code generation 39 | 40 | # loosely based on 41 | # http://effbot.org/zone/python-code-generator.htm 42 | 43 | class CodeGenerator: 44 | """Language-agnostic functionality for source code generation. 45 | 46 | .. automethod:: extend 47 | .. automethod:: get 48 | .. automethod:: add_to_preamble 49 | .. automethod:: __call__ 50 | .. automethod:: indent 51 | .. automethod:: dedent 52 | """ 53 | def __init__(self) -> None: 54 | self.preamble: list[str] = [] 55 | self.code: list[str] = [] 56 | self.level = 0 57 | self.indent_amount = 4 58 | 59 | def extend(self, sub_generator: CodeGenerator) -> None: 60 | for line in sub_generator.code: 61 | self.code.append(" "*(self.indent_amount*self.level) + line) 62 | 63 | def get(self) -> str: 64 | result = "\n".join(self.code) 65 | if self.preamble: 66 | result = "\n".join(self.preamble) + "\n" + result 67 | return result 68 | 69 | def add_to_preamble(self, s: str) -> None: 70 | self.preamble.append(s) 71 | 72 | def __call__(self, s: str) -> None: 73 | if not s.strip(): 74 | self.code.append("") 75 | else: 76 | if "\n" in s: 77 | s = remove_common_indentation(s) 78 | 79 | for line in s.split("\n"): 80 | self.code.append(" "*(self.indent_amount*self.level) + line) 81 | 82 | def indent(self) -> None: 83 | self.level += 1 84 | 85 | def dedent(self) -> None: 86 | if self.level == 0: 87 | raise RuntimeError("cannot decrease indentation level") 88 | self.level -= 1 89 | 90 | 91 | class Indentation: 92 | """A context manager for indentation for use with :class:`CodeGenerator`. 93 | 94 | .. attribute:: generator 95 | .. automethod:: __enter__ 96 | .. automethod:: __exit__ 97 | """ 98 | def __init__(self, generator: CodeGenerator): 99 | self.generator = generator 100 | 101 | def __enter__(self) -> None: 102 | self.generator.indent() 103 | 104 | def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: 105 | self.generator.dedent() 106 | 107 | # }}} 108 | 109 | 110 | # {{{ remove common indentation 111 | 112 | def remove_common_indentation(code: str, require_leading_newline: bool = True): 113 | r"""Remove leading indentation from one or more lines of code. 114 | 115 | Removes an amount of indentation equal to the indentation level of the first 116 | nonempty line in *code*. 117 | 118 | :param code: Input string. 119 | :param require_leading_newline: If *True*, only remove indentation if *code* 120 | starts with ``\n``. 121 | 122 | :returns: A copy of *code* stripped of leading common indentation. 123 | """ 124 | if "\n" not in code: 125 | return code 126 | 127 | if require_leading_newline and not code.startswith("\n"): 128 | return code 129 | 130 | lines = code.split("\n") 131 | while lines[0].strip() == "": 132 | lines.pop(0) 133 | while lines[-1].strip() == "": 134 | lines.pop(-1) 135 | 136 | if lines: 137 | base_indent = 0 138 | while lines[0][base_indent] in " \t": 139 | base_indent += 1 140 | 141 | for line in lines[1:]: 142 | if line[:base_indent].strip(): 143 | raise ValueError("inconsistent indentation") 144 | 145 | return "\n".join(line[base_indent:] for line in lines) 146 | 147 | # }}} 148 | 149 | # vim: foldmethod=marker 150 | -------------------------------------------------------------------------------- /pytools/convergence.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. autofunction:: estimate_order_of_convergence 3 | .. autoclass:: EOCRecorder 4 | .. 
autofunction:: stringify_eocs 5 | .. autoclass:: PConvergenceVerifier 6 | """ 7 | 8 | from __future__ import annotations 9 | 10 | import numbers 11 | 12 | import numpy as np 13 | from typing_extensions import override 14 | 15 | 16 | # {{{ eoc estimation -------------------------------------------------------------- 17 | 18 | def estimate_order_of_convergence(abscissae, errors): 19 | r"""Assuming that abscissae and errors are connected by a law of the form 20 | 21 | .. math:: 22 | 23 | \text{Error} = \text{constant} \cdot \text{abscissa }^{\text{order}}, 24 | 25 | this function finds, in a least-squares sense, the best approximation of 26 | constant and order for the given data set. It returns a tuple (constant, order). 27 | """ 28 | assert len(abscissae) == len(errors) 29 | if len(abscissae) <= 1: 30 | raise RuntimeError("Need more than one value to guess order of convergence.") 31 | 32 | coefficients = np.polyfit(np.log10(abscissae), np.log10(errors), 1) 33 | return 10**coefficients[-1], coefficients[-2] 34 | 35 | 36 | class EOCRecorder: 37 | """ 38 | .. automethod:: add_data_point 39 | 40 | .. automethod:: estimate_order_of_convergence 41 | .. automethod:: order_estimate 42 | .. automethod:: max_error 43 | 44 | .. automethod:: pretty_print 45 | .. automethod:: write_gnuplot_file 46 | """ 47 | 48 | def __init__(self) -> None: 49 | self.history: list[tuple[float, float]] = [] 50 | 51 | def add_data_point(self, abscissa: float, error: float) -> None: 52 | if not (isinstance(abscissa, numbers.Number) 53 | or (isinstance(abscissa, np.ndarray) and abscissa.shape == ())): 54 | raise TypeError( 55 | f"'abscissa' is not a scalar: '{type(abscissa).__name__}'") 56 | 57 | if not (isinstance(error, numbers.Number) 58 | or (isinstance(error, np.ndarray) and error.shape == ())): 59 | raise TypeError(f"'error' is not a scalar: '{type(error).__name__}'") 60 | 61 | self.history.append((abscissa, error)) 62 | 63 | def estimate_order_of_convergence(self, 64 | gliding_mean: int | None = None, 65 | ) -> np.ndarray: 66 | abscissae = np.array([a for a, e in self.history]) 67 | errors = np.array([e for a, e in self.history]) 68 | 69 | # NOTE: in case any of the errors are exactly 0.0, which 70 | # can give NaNs in `estimate_order_of_convergence` 71 | emax: float = np.amax(errors) 72 | errors += (1 if emax == 0 else emax) * np.finfo(errors.dtype).eps 73 | 74 | size = len(abscissae) 75 | if gliding_mean is None: 76 | gliding_mean = size 77 | 78 | data_points = size - gliding_mean + 1 79 | result: np.ndarray = np.zeros((data_points, 2), float) 80 | for i in range(data_points): 81 | result[i, 0], result[i, 1] = estimate_order_of_convergence( 82 | abscissae[i:i+gliding_mean], errors[i:i+gliding_mean]) 83 | return result 84 | 85 | def order_estimate(self) -> float: 86 | return self.estimate_order_of_convergence()[0, 1] 87 | 88 | def max_error(self) -> float: 89 | return max(err for absc, err in self.history) 90 | 91 | def _to_table(self, *, 92 | abscissa_label="h", 93 | error_label="Error", 94 | gliding_mean=2, 95 | abscissa_format="%s", 96 | error_format="%s", 97 | eoc_format="%s"): 98 | from pytools import Table 99 | 100 | tbl = Table() 101 | tbl.add_row((abscissa_label, error_label, "Running EOC")) 102 | 103 | gm_eoc = self.estimate_order_of_convergence(gliding_mean) 104 | for i, (absc, err) in enumerate(self.history): 105 | absc_str = abscissa_format % absc 106 | err_str = error_format % err 107 | if i < gliding_mean-1: 108 | eoc_str = "" 109 | else: 110 | eoc_str = eoc_format % (gm_eoc[i - gliding_mean + 1, 1]) 
111 | 112 | tbl.add_row((absc_str, err_str, eoc_str)) 113 | 114 | if len(self.history) > 1: 115 | order = self.estimate_order_of_convergence()[0, 1] 116 | tbl.add_row(("Overall", "", eoc_format % order)) 117 | 118 | return tbl 119 | 120 | def pretty_print(self, *, 121 | abscissa_label: str = "h", 122 | error_label: str = "Error", 123 | gliding_mean: int = 2, 124 | abscissa_format: str = "%s", 125 | error_format: str = "%s", 126 | eoc_format: str = "%s", 127 | table_type: str = "markdown") -> str: 128 | tbl = self._to_table( 129 | abscissa_label=abscissa_label, error_label=error_label, 130 | abscissa_format=abscissa_format, 131 | error_format=error_format, 132 | eoc_format=eoc_format, 133 | gliding_mean=gliding_mean) 134 | 135 | if table_type == "markdown": 136 | return tbl.github_markdown() 137 | if table_type == "latex": 138 | return tbl.latex() 139 | if table_type == "ascii": 140 | return str(tbl) 141 | if table_type == "csv": 142 | return tbl.csv() 143 | raise ValueError(f"unknown table type: {table_type}") 144 | 145 | @override 146 | def __str__(self): 147 | return self.pretty_print() 148 | 149 | def write_gnuplot_file(self, filename: str) -> None: 150 | with open(filename, "w") as outfile: 151 | for absc, err in self.history: 152 | outfile.write(f"{absc:f} {err:f}\n") 153 | result = self.estimate_order_of_convergence() 154 | const = result[0, 0] 155 | order = result[0, 1] 156 | outfile.write("\n") 157 | for absc, _err in self.history: 158 | outfile.write(f"{absc:f} {const * absc**(-order):f}\n") 159 | 160 | 161 | def stringify_eocs(*eocs: EOCRecorder, 162 | names: tuple[str, ...] | None = None, 163 | abscissa_label: str = "h", 164 | error_label: str = "Error", 165 | gliding_mean: int = 2, 166 | abscissa_format: str = "%s", 167 | error_format: str = "%s", 168 | eoc_format: str = "%s", 169 | table_type: str = "markdown") -> str: 170 | """ 171 | :arg names: a :class:`tuple` of names to use for the *error_label* of each 172 | *eoc*. 
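For example, a sketch with made-up convergence data::

    eoc_u, eoc_p = EOCRecorder(), EOCRecorder()
    for h in [0.1, 0.05, 0.025]:
        eoc_u.add_data_point(h, h**2)
        eoc_p.add_data_point(h, h)
    print(stringify_eocs(eoc_u, eoc_p, names=("u", "p")))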
173 | """ 174 | if names is not None and len(names) < len(eocs): 175 | raise ValueError( 176 | f"insufficient names: got {len(names)} names for " 177 | f"{len(eocs)} EOCRecorder instances") 178 | 179 | if names is None: 180 | names = tuple(f"{error_label} {i}" for i in range(len(eocs))) 181 | 182 | from pytools import merge_tables 183 | tbl = merge_tables(*[eoc._to_table( 184 | abscissa_label=abscissa_label, error_label=name, 185 | abscissa_format=abscissa_format, 186 | error_format=error_format, 187 | eoc_format=eoc_format, 188 | gliding_mean=gliding_mean) 189 | for name, eoc in zip(names, eocs, strict=True) 190 | ], skip_columns=(0,)) 191 | 192 | if table_type == "markdown": 193 | return tbl.github_markdown() 194 | if table_type == "latex": 195 | return tbl.latex() 196 | if table_type == "ascii": 197 | return str(tbl) 198 | if table_type == "csv": 199 | return tbl.csv() 200 | raise ValueError(f"unknown table type: {table_type}") 201 | 202 | # }}} 203 | 204 | 205 | # {{{ p convergence verifier 206 | 207 | class PConvergenceVerifier: 208 | def __init__(self): 209 | self.orders = [] 210 | self.errors = [] 211 | 212 | def add_data_point(self, order, error): 213 | self.orders.append(order) 214 | self.errors.append(error) 215 | 216 | @override 217 | def __str__(self): 218 | from pytools import Table 219 | tbl = Table() 220 | tbl.add_row(("p", "error")) 221 | 222 | for p, err in zip(self.orders, self.errors, strict=True): 223 | tbl.add_row((str(p), str(err))) 224 | 225 | return str(tbl) 226 | 227 | def __call__(self): 228 | orders = np.array(self.orders, np.float64) 229 | errors = np.abs(np.array(self.errors, np.float64)) 230 | 231 | rel_change = np.diff(1e-20 + np.log10(errors)) / np.diff(orders) 232 | 233 | assert (rel_change < -0.2).all() 234 | 235 | # }}} 236 | 237 | 238 | # vim: foldmethod=marker 239 | -------------------------------------------------------------------------------- /pytools/datatable.py: -------------------------------------------------------------------------------- 1 | """ 2 | An in-memory relational database table 3 | ====================================== 4 | 5 | .. autoclass:: DataTable 6 | """ 7 | 8 | from __future__ import annotations 9 | 10 | from typing import IO, TYPE_CHECKING, Any 11 | 12 | from typing_extensions import override 13 | 14 | from pytools import Record 15 | 16 | 17 | if TYPE_CHECKING: 18 | from collections.abc import Callable, Iterator, Sequence 19 | 20 | 21 | class Row(Record): 22 | pass 23 | 24 | 25 | class DataTable: 26 | """An in-memory relational database table. 27 | 28 | .. automethod:: __init__ 29 | .. automethod:: copy 30 | .. automethod:: deep_copy 31 | .. automethod:: join 32 | """ 33 | 34 | def __init__(self, column_names: Sequence[str], 35 | column_data: list[Any] | None = None) -> None: 36 | """Construct a new table, with the given C{column_names}. 37 | 38 | :arg column_names: An indexable of column name strings. 39 | :arg column_data: None or a list of tuples of the same length as 40 | *column_names* indicating an initial set of data. 
41 | """ 42 | if column_data is None: 43 | self.data = [] 44 | else: 45 | self.data = column_data 46 | 47 | self.column_names = column_names 48 | self.column_indices = { 49 | colname: i for i, colname in enumerate(column_names)} 50 | 51 | if len(self.column_indices) != len(self.column_names): 52 | raise RuntimeError("non-unique column names encountered") 53 | 54 | def __bool__(self) -> bool: 55 | return bool(self.data) 56 | 57 | def __len__(self) -> int: 58 | return len(self.data) 59 | 60 | def __iter__(self) -> Iterator[list[Any]]: 61 | return self.data.__iter__() 62 | 63 | @override 64 | def __str__(self) -> str: 65 | """Return a pretty-printed version of the table.""" 66 | 67 | def col_width(i: int) -> int: 68 | width = len(self.column_names[i]) 69 | if self: 70 | width = max(width, max(len(str(row[i])) for row in self.data)) 71 | return width 72 | col_widths = [col_width(i) for i in range(len(self.column_names))] 73 | 74 | def format_row(row: Sequence[str]) -> str: 75 | return "|".join([str(cell).ljust(col_width) 76 | for cell, col_width in zip(row, col_widths, strict=True)]) 77 | 78 | lines = [format_row(self.column_names), 79 | "+".join("-"*col_width for col_width in col_widths)] + \ 80 | [format_row(row) for row in self.data] 81 | return "\n".join(lines) 82 | 83 | def insert(self, **kwargs: Any) -> None: 84 | values = [None for i in range(len(self.column_names))] 85 | 86 | for key, val in kwargs.items(): 87 | values[self.column_indices[key]] = val 88 | 89 | self.insert_row(tuple(values)) 90 | 91 | def insert_row(self, values: tuple[Any, ...]) -> None: 92 | assert isinstance(values, tuple) 93 | assert len(values) == len(self.column_names) 94 | self.data.append(values) 95 | 96 | def insert_rows(self, rows: Sequence[tuple[Any, ...]]) -> None: 97 | for row in rows: 98 | self.insert_row(row) 99 | 100 | def filtered(self, **kwargs: Any) -> DataTable: 101 | if not kwargs: 102 | return self 103 | 104 | criteria = tuple( 105 | (self.column_indices[key], value) 106 | for key, value in kwargs.items()) 107 | 108 | result_data = [] 109 | 110 | for row in self.data: 111 | satisfied = True 112 | for idx, val in criteria: 113 | if row[idx] != val: 114 | satisfied = False 115 | break 116 | 117 | if satisfied: 118 | result_data.append(row) 119 | 120 | return DataTable(self.column_names, result_data) 121 | 122 | def get(self, **kwargs: Any) -> Row: 123 | filtered = self.filtered(**kwargs) 124 | if not filtered: 125 | raise RuntimeError("no matching entry for get()") 126 | if len(filtered) > 1: 127 | raise RuntimeError("more than one matching entry for get()") 128 | 129 | return Row(dict(zip(self.column_names, filtered.data[0], strict=True))) 130 | 131 | def clear(self) -> None: 132 | del self.data[:] 133 | 134 | def copy(self) -> DataTable: 135 | """Make a copy of the instance, but leave individual rows untouched. 136 | 137 | If the rows are modified later, they will also be modified in the copy. 138 | """ 139 | return DataTable(self.column_names, self.data[:]) 140 | 141 | def deep_copy(self) -> DataTable: 142 | """Make a copy of the instance down to the row level. 143 | 144 | The copy's rows may be modified independently from the original. 
145 | """ 146 | return DataTable(self.column_names, [row[:] for row in self.data]) 147 | 148 | def sort(self, columns: Sequence[str], reverse: bool = False) -> None: 149 | col_indices = [self.column_indices[col] for col in columns] 150 | 151 | def mykey(row: Sequence[Any]) -> tuple[Any, ...]: 152 | return tuple( 153 | row[col_index] 154 | for col_index in col_indices) 155 | 156 | self.data.sort(reverse=reverse, key=mykey) 157 | 158 | def aggregated(self, groupby: Sequence[str], agg_column: str, 159 | aggregate_func: Callable[[Sequence[Any]], Any]) -> DataTable: 160 | gb_indices = [self.column_indices[col] for col in groupby] 161 | agg_index = self.column_indices[agg_column] 162 | 163 | first = True 164 | 165 | result_data = [] 166 | 167 | # to pacify pyflakes: 168 | last_values: tuple[Any, ...] = () 169 | agg_values: list[Row] = [] 170 | 171 | for row in self.data: 172 | this_values = tuple(row[i] for i in gb_indices) 173 | if first or this_values != last_values: 174 | if not first: 175 | result_data.append((*last_values, aggregate_func(agg_values))) 176 | 177 | agg_values = [row[agg_index]] 178 | last_values = this_values 179 | first = False 180 | else: 181 | agg_values.append(row[agg_index]) 182 | 183 | if not first and agg_values: 184 | result_data.append((*this_values, aggregate_func(agg_values))) 185 | 186 | return DataTable( 187 | [self.column_names[i] for i in gb_indices] + [agg_column], 188 | result_data) 189 | 190 | def join(self, column: str, other_column: str, other_table: DataTable, 191 | outer: bool = False) -> DataTable: 192 | """Return a table joining this and the C{other_table} on C{column}. 193 | 194 | The new table has the following columns: 195 | - C{column}, titled the same as in this table. 196 | - the columns of this table, minus C{column}. 197 | - the columns of C{other_table}, minus C{other_column}. 198 | 199 | Assumes both tables are sorted ascendingly by the column 200 | by which they are joined. 
201 | """ 202 | 203 | def without(indexable: tuple[str, ...], idx: int) -> tuple[str, ...]: 204 | return indexable[:idx] + indexable[idx+1:] 205 | 206 | this_key_idx = self.column_indices[column] 207 | other_key_idx = other_table.column_indices[other_column] 208 | 209 | this_iter = self.data.__iter__() 210 | other_iter = other_table.data.__iter__() 211 | 212 | result_columns = tuple(self.column_names[this_key_idx]) + \ 213 | without(tuple(self.column_names), this_key_idx) + \ 214 | without(tuple(other_table.column_names), other_key_idx) 215 | 216 | result_data = [] 217 | 218 | this_row = next(this_iter) 219 | other_row = next(other_iter) 220 | 221 | this_over = False 222 | other_over = False 223 | 224 | while True: 225 | this_batch = [] 226 | other_batch = [] 227 | 228 | if this_over: 229 | run_other = True 230 | elif other_over: 231 | run_this = True 232 | else: 233 | this_key = this_row[this_key_idx] 234 | other_key = other_row[other_key_idx] 235 | 236 | run_this = this_key < other_key 237 | run_other = this_key > other_key 238 | if this_key == other_key: 239 | run_this = run_other = True 240 | 241 | if run_this and not this_over: 242 | key = this_key 243 | while this_row[this_key_idx] == this_key: 244 | this_batch.append(this_row) 245 | try: 246 | this_row = next(this_iter) 247 | except StopIteration: 248 | this_over = True 249 | break 250 | elif outer: 251 | this_batch = [(None,) * len(self.column_names)] 252 | 253 | if run_other and not other_over: 254 | key = other_key 255 | while other_row[other_key_idx] == other_key: 256 | other_batch.append(other_row) 257 | try: 258 | other_row = next(other_iter) 259 | except StopIteration: 260 | other_over = True 261 | break 262 | elif outer: 263 | other_batch = [(None,) * len(other_table.column_names)] 264 | 265 | for this_batch_row in this_batch: 266 | for other_batch_row in other_batch: 267 | result_data.append(( 268 | key, 269 | *without(this_batch_row, this_key_idx), 270 | *without(other_batch_row, other_key_idx))) 271 | 272 | if outer: 273 | if this_over and other_over: 274 | break 275 | elif this_over or other_over: 276 | break 277 | 278 | return DataTable(result_columns, result_data) 279 | 280 | def restricted(self, columns: Sequence[str]) -> DataTable: 281 | col_indices = [self.column_indices[col] for col in columns] 282 | 283 | return DataTable(columns, 284 | [[row[i] for i in col_indices] for row in self.data]) 285 | 286 | def column_data(self, column: str) -> list[tuple[Any, ...]]: 287 | col_index = self.column_indices[column] 288 | return [row[col_index] for row in self.data] 289 | 290 | def write_csv(self, filelike: IO[Any], **kwargs: Any) -> None: 291 | from csv import writer 292 | csvwriter = writer(filelike, **kwargs) 293 | csvwriter.writerow(self.column_names) 294 | csvwriter.writerows(self.data) 295 | -------------------------------------------------------------------------------- /pytools/debug.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import contextlib 4 | import sys 5 | 6 | from typing_extensions import override 7 | 8 | from pytools import memoize 9 | 10 | 11 | # {{{ debug files ------------------------------------------------------------- 12 | 13 | def make_unique_filesystem_object(stem, extension="", directory="", 14 | creator=None): 15 | """ 16 | :param extension: needs a leading dot. 17 | :param directory: must not have a trailing slash. 
18 | """ 19 | import os 20 | from os.path import join 21 | 22 | if creator is None: 23 | def default_creator(name): 24 | return os.fdopen(os.open(name, 25 | os.O_CREAT | os.O_WRONLY | os.O_EXCL, 0o444), "w") 26 | creator = default_creator 27 | 28 | i = 0 29 | while True: 30 | fname = join(directory, f"{stem}-{i}{extension}") 31 | try: 32 | return creator(fname), fname 33 | except OSError: 34 | i += 1 35 | 36 | 37 | @memoize 38 | def get_run_debug_directory(): 39 | def creator(name): 40 | from os import mkdir 41 | mkdir(name) 42 | return name 43 | 44 | return make_unique_filesystem_object("run-debug", creator=creator)[0] 45 | 46 | 47 | def open_unique_debug_file(stem, extension=""): 48 | """ 49 | :param extension: needs a leading dot. 50 | """ 51 | return make_unique_filesystem_object( 52 | stem, extension, get_run_debug_directory()) 53 | 54 | # }}} 55 | 56 | 57 | # {{{ refcount debugging ------------------------------------------------------ 58 | 59 | class RefDebugQuit(Exception): # noqa: N818 60 | pass 61 | 62 | 63 | def refdebug(obj, top_level=True, exclude=()): 64 | from types import FrameType 65 | 66 | def is_excluded(o): 67 | for ex in exclude: 68 | if o is ex: 69 | return True 70 | 71 | from sys import _getframe 72 | return bool(isinstance(o, FrameType) 73 | and o.f_code.co_filename == _getframe().f_code.co_filename) 74 | 75 | if top_level: 76 | with contextlib.suppress(RefDebugQuit): 77 | refdebug(obj, top_level=False, exclude=exclude) 78 | return 79 | 80 | import gc 81 | print_head = True 82 | print("-------------->") 83 | try: 84 | reflist = [x for x in gc.get_referrers(obj) 85 | if not is_excluded(x)] 86 | 87 | idx = 0 88 | while True: 89 | if print_head: 90 | print("referring to", id(obj), type(obj), obj) 91 | print("----------------------") 92 | print_head = False 93 | r = reflist[idx] 94 | 95 | s = str(r.f_code) if isinstance(r, FrameType) else str(r) 96 | 97 | print(f"{idx}/{len(reflist)}: ", id(r), type(r), s) 98 | 99 | if isinstance(r, dict): 100 | for k, v in r.items(): 101 | if v is obj: 102 | print("...referred to from key", k) 103 | 104 | print("[d]ig, [n]ext, [p]rev, [e]val, [r]eturn, [q]uit?") 105 | 106 | response = input() 107 | 108 | if response == "d": 109 | refdebug(r, top_level=False, exclude=exclude+tuple(reflist)) 110 | print_head = True 111 | elif response == "n": 112 | if idx + 1 < len(reflist): 113 | idx += 1 114 | elif response == "p": 115 | if idx - 1 >= 0: 116 | idx -= 1 117 | elif response == "e": 118 | print("type expression, obj is your object:") 119 | expr_str = input() 120 | try: 121 | res = eval(expr_str, {"obj": r}) # pylint:disable=eval-used 122 | except Exception: # pylint:disable=broad-except 123 | from traceback import print_exc 124 | print_exc() 125 | print(res) 126 | elif response == "r": 127 | return 128 | elif response == "q": 129 | raise RefDebugQuit 130 | else: 131 | print("WHAT YOU SAY!!! 
(invalid choice)") 132 | 133 | finally: 134 | print("<--------------") 135 | 136 | # }}} 137 | 138 | 139 | # {{{ interactive shell 140 | 141 | def get_shell_hist_filename() -> str: 142 | import os 143 | 144 | return os.path.expanduser(os.path.join("~", ".pytools-debug-shell-history")) 145 | 146 | 147 | def setup_readline(): 148 | from os.path import exists 149 | hist_filename = get_shell_hist_filename() 150 | if exists(hist_filename): 151 | try: 152 | readline.read_history_file(hist_filename) 153 | except Exception: # pylint:disable=broad-except 154 | # http://docs.python.org/3/howto/pyporting.html#capturing-the-currently-raised-exception 155 | import sys 156 | e = sys.exc_info()[1] 157 | 158 | from warnings import warn 159 | warn(f"Error opening readline history file: {e}", stacklevel=2) 160 | 161 | readline.parse_and_bind("tab: complete") 162 | 163 | 164 | try: 165 | import readline 166 | import rlcompleter 167 | HAVE_READLINE = True 168 | except ImportError: 169 | HAVE_READLINE = False 170 | else: 171 | setup_readline() 172 | 173 | 174 | class SetPropagatingDict(dict): 175 | def __init__(self, source_dicts, target_dict): 176 | dict.__init__(self) 177 | for s in source_dicts[::-1]: 178 | self.update(s) 179 | 180 | self.target_dict = target_dict 181 | 182 | @override 183 | def __setitem__(self, key, value): 184 | dict.__setitem__(self, key, value) 185 | self.target_dict[key] = value 186 | 187 | @override 188 | def __delitem__(self, key): 189 | dict.__delitem__(self, key) 190 | del self.target_dict[key] 191 | 192 | 193 | def shell(locals_=None, globals_=None): 194 | from inspect import currentframe, getouterframes 195 | calling_frame = getouterframes(currentframe())[1][0] 196 | 197 | if locals_ is None: 198 | locals_ = calling_frame.f_locals 199 | if globals_ is None: 200 | globals_ = calling_frame.f_globals 201 | 202 | ns = SetPropagatingDict([locals_, globals_], locals_) 203 | 204 | if HAVE_READLINE: 205 | readline.set_completer( 206 | rlcompleter.Completer(ns).complete) 207 | 208 | from code import InteractiveConsole 209 | cons = InteractiveConsole(ns) 210 | cons.interact("") 211 | 212 | readline.write_history_file(get_shell_hist_filename()) 213 | 214 | # }}} 215 | 216 | 217 | # {{{ estimate memory usage 218 | 219 | def estimate_memory_usage(root, seen_ids=None): 220 | if seen_ids is None: 221 | seen_ids = set() 222 | 223 | id_root = id(root) 224 | if id_root in seen_ids: 225 | return 0 226 | 227 | seen_ids.add(id_root) 228 | 229 | result = sys.getsizeof(root) 230 | 231 | from gc import get_referents 232 | for ref in get_referents(root): 233 | result += estimate_memory_usage(ref, seen_ids=seen_ids) 234 | 235 | return result 236 | 237 | # }}} 238 | 239 | # vim: foldmethod=marker 240 | -------------------------------------------------------------------------------- /pytools/graph.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | 4 | __copyright__ = """ 5 | Copyright (C) 2009-2013 Andreas Kloeckner 6 | Copyright (C) 2020 Matt Wala 7 | Copyright (C) 2020 James Stevens 8 | Copyright (C) 2024 Addison Alvey-Blanco 9 | """ 10 | 11 | __license__ = """ 12 | Permission is hereby granted, free of charge, to any person obtaining a copy 13 | of this software and associated documentation files (the "Software"), to deal 14 | in the Software without restriction, including without limitation the rights 15 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 16 | copies of the Software, and to permit 
persons to whom the Software is 17 | furnished to do so, subject to the following conditions: 18 | 19 | The above copyright notice and this permission notice shall be included in 20 | all copies or substantial portions of the Software. 21 | 22 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 23 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 24 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 25 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 26 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 27 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 28 | THE SOFTWARE. 29 | """ 30 | 31 | __doc__ = """ 32 | Graph Algorithms 33 | ================ 34 | 35 | .. note:: 36 | 37 | These functions are mostly geared towards directed graphs (digraphs). 38 | 39 | .. autofunction:: reverse_graph 40 | .. autofunction:: a_star 41 | .. autofunction:: compute_sccs 42 | .. autoexception:: CycleError 43 | .. autofunction:: compute_topological_order 44 | .. autofunction:: compute_transitive_closure 45 | .. autofunction:: contains_cycle 46 | .. autofunction:: compute_induced_subgraph 47 | .. autofunction:: as_graphviz_dot 48 | .. autofunction:: validate_graph 49 | .. autofunction:: is_connected 50 | .. autofunction:: undirected_graph_from_edges 51 | .. autofunction:: get_reachable_nodes 52 | 53 | Type Variables Used 54 | ------------------- 55 | 56 | .. class:: _SupportsLT 57 | 58 | A :class:`~typing.Protocol` for `__lt__` support. 59 | 60 | .. class:: NodeT 61 | 62 | Type of a graph node, can be any hashable type. 63 | 64 | .. class:: GraphT 65 | 66 | A :class:`collections.abc.Mapping` representing a directed 67 | graph. The mapping contains one key representing each node in the 68 | graph, and this key maps to a :class:`collections.abc.Collection` of its 69 | successor nodes. Note that most functions expect that every graph node 70 | is included as a key in the graph. 71 | """ 72 | 73 | from collections.abc import ( 74 | Callable, 75 | Collection, 76 | Hashable, 77 | Iterable, 78 | Iterator, 79 | Mapping, 80 | MutableSet, 81 | ) 82 | from dataclasses import dataclass 83 | from typing import ( 84 | Any, 85 | Generic, 86 | Protocol, 87 | TypeAlias, 88 | TypeVar, 89 | ) 90 | 91 | 92 | NodeT = TypeVar("NodeT", bound=Hashable) 93 | 94 | GraphT: TypeAlias = Mapping[NodeT, Collection[NodeT]] 95 | 96 | 97 | # {{{ reverse_graph 98 | 99 | def reverse_graph(graph: GraphT[NodeT]) -> GraphT[NodeT]: 100 | """ 101 | Reverses a graph *graph*. 102 | 103 | :returns: A :class:`dict` representing *graph* with edges reversed. 
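    A small example (the integer node labels are arbitrary)::

        >>> reverse_graph({1: frozenset([2]), 2: frozenset()})
        {1: frozenset(), 2: frozenset({1})}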
104 | """ 105 | result: dict[NodeT, set[NodeT]] = {} 106 | 107 | for node_key, successor_nodes in graph.items(): 108 | # Make sure every node is in the result even if it has no successors 109 | result.setdefault(node_key, set()) 110 | 111 | for successor in successor_nodes: 112 | result.setdefault(successor, set()).add(node_key) 113 | 114 | return {k: frozenset(v) for k, v in result.items()} 115 | 116 | # }}} 117 | 118 | 119 | # {{{ a_star 120 | 121 | @dataclass(frozen=True) 122 | class _AStarNode(Generic[NodeT]): 123 | state: NodeT 124 | parent: _AStarNode[NodeT] | None 125 | path_cost: float | int 126 | 127 | 128 | def a_star( 129 | initial_state: NodeT, goal_state: NodeT, neighbor_map: GraphT[NodeT], 130 | estimate_remaining_cost: Callable[[NodeT], float] | None = None, 131 | get_step_cost: Callable[[Any, NodeT], float] = lambda x, y: 1 132 | ) -> list[NodeT]: 133 | """ 134 | With the default cost and heuristic, this amounts to Dijkstra's algorithm. 135 | """ 136 | 137 | from heapq import heappop, heappush 138 | 139 | if estimate_remaining_cost is None: 140 | def estimate_remaining_cost(x: NodeT) -> float: 141 | if x != goal_state: 142 | return 1 143 | return 0 144 | 145 | inf = float("inf") 146 | init_remcost = estimate_remaining_cost(initial_state) 147 | assert init_remcost != inf 148 | 149 | queue = [(init_remcost, _AStarNode(initial_state, parent=None, path_cost=0))] 150 | visited_states = set() 151 | 152 | while queue: 153 | _, top = heappop(queue) 154 | visited_states.add(top.state) 155 | 156 | if top.state == goal_state: 157 | result = [] 158 | it: _AStarNode[NodeT] | None = top 159 | while it is not None: 160 | result.append(it.state) 161 | it = it.parent 162 | return result[::-1] 163 | 164 | for state in neighbor_map[top.state]: 165 | if state in visited_states: 166 | continue 167 | 168 | remaining_cost = estimate_remaining_cost(state) 169 | if remaining_cost == inf: 170 | continue 171 | step_cost = get_step_cost(top, state) 172 | 173 | estimated_path_cost = top.path_cost+step_cost+remaining_cost 174 | heappush(queue, 175 | (estimated_path_cost, 176 | _AStarNode(state, top, path_cost=top.path_cost + step_cost))) 177 | 178 | raise RuntimeError("no solution") 179 | 180 | # }}} 181 | 182 | 183 | # {{{ compute SCCs with Tarjan's algorithm 184 | 185 | def compute_sccs(graph: GraphT[NodeT]) -> list[list[NodeT]]: 186 | to_search = set(graph.keys()) 187 | visit_order: dict[NodeT, int] = {} 188 | scc_root = {} 189 | sccs = [] 190 | 191 | while to_search: 192 | top = next(iter(to_search)) 193 | call_stack: list[tuple[NodeT, Iterator[NodeT], NodeT | None]] = ( 194 | [(top, iter(graph[top]), None)]) 195 | visit_stack = [] 196 | visiting = set() 197 | 198 | scc: list[NodeT] = [] 199 | 200 | while call_stack: 201 | top, children, last_popped_child = call_stack.pop() 202 | 203 | if top not in visiting: 204 | # Unvisited: mark as visited, initialize SCC root. 205 | count = len(visit_order) 206 | visit_stack.append(top) 207 | visit_order[top] = count 208 | scc_root[top] = count 209 | visiting.add(top) 210 | to_search.discard(top) 211 | 212 | # Returned from a recursion, update SCC. 213 | if last_popped_child is not None: 214 | scc_root[top] = min( 215 | scc_root[top], 216 | scc_root[last_popped_child]) 217 | 218 | for child in children: 219 | if child not in visit_order: 220 | # Recurse. 
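                    # Push the parent frame back, remembering which child
                    # is being descended into, then push the child's frame
                    # on top: this emulates the recursive Tarjan call with
                    # an explicit stack, avoiding Python's recursion limit.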
221 | call_stack.append((top, children, child)) 222 | call_stack.append((child, iter(graph[child]), None)) 223 | break 224 | if child in visiting: 225 | scc_root[top] = min( 226 | scc_root[top], 227 | visit_order[child]) 228 | else: 229 | if scc_root[top] == visit_order[top]: 230 | scc = [] 231 | while visit_stack[-1] != top: 232 | scc.append(visit_stack.pop()) 233 | scc.append(visit_stack.pop()) 234 | for item in scc: 235 | visiting.remove(item) 236 | sccs.append(scc) 237 | 238 | return sccs 239 | 240 | # }}} 241 | 242 | 243 | # {{{ compute topological order 244 | 245 | class CycleError(Exception): 246 | """ 247 | Raised when a topological ordering cannot be computed due to a cycle. 248 | 249 | :attr node: Node in a directed graph that is part of a cycle. 250 | """ 251 | def __init__(self, node: NodeT) -> None: 252 | self.node = node 253 | 254 | 255 | class _SupportsLT(Protocol): 256 | def __lt__(self, other: Any) -> bool: 257 | ... 258 | 259 | 260 | @dataclass(frozen=True) 261 | class _HeapEntry(Generic[NodeT]): 262 | """ 263 | Helper class to compare associated keys while comparing the elements in 264 | heap operations. 265 | 266 | Only needs to define :func:`pytools.graph.__lt__` according to 267 | . 268 | """ 269 | node: NodeT 270 | key: _SupportsLT 271 | 272 | def __lt__(self, other: _HeapEntry[NodeT]) -> bool: 273 | return self.key < other.key 274 | 275 | 276 | def compute_topological_order( 277 | graph: GraphT[NodeT], 278 | key: Callable[[NodeT], _SupportsLT] | None = None, 279 | ) -> list[NodeT]: 280 | """Compute a topological order of nodes in a directed graph. 281 | 282 | :arg key: A custom key function may be supplied to determine the order in 283 | break-even cases. Expects a function of one argument that is used to 284 | extract a comparison key from each node of the *graph*. 285 | 286 | :returns: A :class:`list` representing a valid topological ordering of the 287 | nodes in the directed graph. 288 | 289 | .. note:: 290 | 291 | * Requires the keys of the mapping *graph* to be hashable. 292 | * Implements `Kahn's algorithm `__. 293 | 294 | .. versionadded:: 2020.2 295 | """ 296 | # all nodes have the same keys when not provided 297 | keyfunc = key if key is not None else (lambda x: 0) 298 | 299 | from heapq import heapify, heappop, heappush 300 | 301 | order = [] 302 | 303 | # {{{ compute nodes_to_num_predecessors 304 | 305 | nodes_to_num_predecessors = dict.fromkeys(graph, 0) 306 | 307 | for node in graph: 308 | for child in graph[node]: 309 | nodes_to_num_predecessors[child] = ( 310 | nodes_to_num_predecessors.get(child, 0) + 1) 311 | 312 | # }}} 313 | 314 | total_num_nodes = len(nodes_to_num_predecessors) 315 | 316 | # heap: list of instances of HeapEntry(n) where 'n' is a node in 317 | # 'graph' with no predecessor. Nodes with no predecessors are the 318 | # schedulable candidates. 
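    # For example, for the graph {1: [2, 3], 2: [3], 3: []}, only node 1
    # starts out with zero predecessors, and the loop below then yields
    # the order [1, 2, 3].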
319 | heap = [_HeapEntry(n, keyfunc(n)) 320 | for n, num_preds in nodes_to_num_predecessors.items() 321 | if num_preds == 0] 322 | heapify(heap) 323 | 324 | while heap: 325 | # pick the node with least key 326 | node_to_be_scheduled = heappop(heap).node 327 | order.append(node_to_be_scheduled) 328 | 329 | # discard 'node_to_be_scheduled' from the predecessors of its 330 | # successors since it's been scheduled 331 | for child in graph.get(node_to_be_scheduled, ()): 332 | nodes_to_num_predecessors[child] -= 1 333 | if nodes_to_num_predecessors[child] == 0: 334 | heappush(heap, _HeapEntry(child, keyfunc(child))) 335 | 336 | if len(order) != total_num_nodes: 337 | # any node which has a predecessor left is a part of a cycle 338 | raise CycleError(next(iter(n for n, num_preds in 339 | nodes_to_num_predecessors.items() if num_preds != 0))) 340 | 341 | return order 342 | 343 | # }}} 344 | 345 | 346 | # {{{ compute transitive closure 347 | 348 | def compute_transitive_closure( 349 | graph: Mapping[NodeT, MutableSet[NodeT]]) -> GraphT[NodeT]: 350 | """Compute the transitive closure of a directed graph using Warshall's 351 | algorithm. 352 | 353 | :arg graph: A :class:`collections.abc.Mapping` representing a directed 354 | graph. The mapping contains one key representing each node in the 355 | graph, and this key maps to a :class:`collections.abc.MutableSet` of 356 | nodes that are connected to the node by outgoing edges. This graph may 357 | contain cycles. This object must be picklable. Every graph node must 358 | be included as a key in the graph. 359 | 360 | :returns: The transitive closure of the graph, represented using the same 361 | data type. 362 | 363 | .. versionadded:: 2020.2 364 | """ 365 | # Warshall's algorithm 366 | 367 | from copy import deepcopy 368 | closure = deepcopy(graph) 369 | 370 | # (assumes all graph nodes are included in keys) 371 | for k in graph: 372 | for n1 in graph: 373 | for n2 in graph: 374 | if k in closure[n1] and n2 in closure[k]: 375 | closure[n1].add(n2) 376 | 377 | return closure 378 | 379 | # }}} 380 | 381 | 382 | # {{{ check for cycle 383 | 384 | def contains_cycle(graph: GraphT[NodeT]) -> bool: 385 | """Determine whether a graph contains a cycle. 386 | 387 | :returns: A :class:`bool` indicating whether the graph contains a cycle. 388 | 389 | .. versionadded:: 2020.2 390 | """ 391 | 392 | try: 393 | compute_topological_order(graph) 394 | return False 395 | except CycleError: 396 | return True 397 | 398 | # }}} 399 | 400 | 401 | # {{{ compute induced subgraph 402 | 403 | def compute_induced_subgraph(graph: Mapping[NodeT, set[NodeT]], 404 | subgraph_nodes: set[NodeT]) -> GraphT[NodeT]: 405 | """Compute the induced subgraph formed by a subset of the vertices in a 406 | graph. 407 | 408 | :arg graph: A :class:`collections.abc.Mapping` representing a directed 409 | graph. The mapping contains one key representing each node in the 410 | graph, and this key maps to a :class:`collections.abc.Set` of nodes 411 | that are connected to the node by outgoing edges. 412 | 413 | :arg subgraph_nodes: A :class:`collections.abc.Set` containing a subset of 414 | the graph nodes in the graph. 415 | 416 | :returns: A :class:`dict` representing the induced subgraph formed by 417 | the subset of the vertices included in `subgraph_nodes`. 418 | 419 | .. 
versionadded:: 2020.2 420 | """ 421 | 422 | new_graph = {} 423 | for node, children in graph.items(): 424 | if node in subgraph_nodes: 425 | new_graph[node] = children & subgraph_nodes 426 | return new_graph 427 | 428 | # }}} 429 | 430 | 431 | # {{{ as_graphviz_dot 432 | 433 | def as_graphviz_dot(graph: GraphT[NodeT], 434 | node_labels: Callable[[NodeT], str] | None = None, 435 | edge_labels: Callable[[NodeT, NodeT], str] | None = None, 436 | ) -> str: 437 | """ 438 | Create a visualization of the graph *graph* in the 439 | `dot `__ language. 440 | 441 | :arg node_labels: An optional function that returns node labels 442 | for each node. 443 | 444 | :arg edge_labels: An optional function that returns edge labels 445 | for each pair of nodes. 446 | 447 | :returns: A string in the `dot `__ language. 448 | """ 449 | from pytools import UniqueNameGenerator 450 | id_gen = UniqueNameGenerator(forced_prefix="mynode") 451 | 452 | from pytools.graphviz import dot_escape 453 | 454 | if node_labels is None: 455 | def node_labels(x: NodeT) -> str: 456 | return str(x) 457 | 458 | if edge_labels is None: 459 | def edge_labels(x: NodeT, y: NodeT) -> str: 460 | return "" 461 | 462 | node_to_id = {} 463 | 464 | for node, targets in graph.items(): 465 | if node not in node_to_id: 466 | node_to_id[node] = id_gen() 467 | for t in targets: 468 | if t not in node_to_id: 469 | node_to_id[t] = id_gen() 470 | 471 | # Add nodes 472 | content = "\n".join( 473 | [f'{node_to_id[node]} [label="{dot_escape(node_labels(node))}"];' 474 | for node in node_to_id]) 475 | 476 | content += "\n" 477 | 478 | # Add edges 479 | content += "\n".join( 480 | [f"{node_to_id[node]} -> {node_to_id[t]} " 481 | f'[label="{dot_escape(edge_labels(node, t))}"];' 482 | for (node, targets) in graph.items() 483 | for t in targets]) 484 | 485 | return f"digraph mygraph {{\n{content}\n}}\n" 486 | 487 | # }}} 488 | 489 | 490 | # {{{ validate graph 491 | 492 | def validate_graph(graph: GraphT[NodeT]) -> None: 493 | """ 494 | Validates that all successor nodes of each node in *graph* are keys in 495 | *graph* itself. Raises a :class:`ValueError` if not. 496 | """ 497 | seen_nodes: set[NodeT] = set() 498 | 499 | for children in graph.values(): 500 | seen_nodes.update(children) 501 | 502 | if not seen_nodes <= graph.keys(): 503 | raise ValueError( 504 | f"invalid graph, missing keys: {seen_nodes-graph.keys()}") 505 | 506 | # }}} 507 | 508 | 509 | # {{{ is_connected 510 | 511 | def is_connected(graph: GraphT[NodeT]) -> bool: 512 | """ 513 | Returns whether all nodes in *graph* are connected, ignoring 514 | the edge direction. 515 | 516 | :returns: A :class:`bool` indicating whether the graph is connected. 517 | """ 518 | if not graph: 519 | # https://cs.stackexchange.com/questions/52815/is-a-graph-of-zero-nodes-vertices-connected 520 | return True 521 | 522 | visited = set() 523 | 524 | undirected_graph = {node: set(children) for node, children in graph.items()} 525 | 526 | for node, children in graph.items(): 527 | for child in children: 528 | undirected_graph[child].add(node) 529 | 530 | def dfs(node: NodeT) -> None: 531 | visited.add(node) 532 | for child in undirected_graph[node]: 533 | if child not in visited: 534 | dfs(child) 535 | 536 | dfs(next(iter(graph.keys()))) 537 | 538 | return visited == graph.keys() 539 | 540 | # }}} 541 | 542 | 543 | def undirected_graph_from_edges( 544 | edges: Iterable[tuple[NodeT, NodeT]], 545 | ) -> GraphT[NodeT]: 546 | """ 547 | Constructs an undirected graph using *edges*. 
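    For example (a small illustrative case),
    ``undirected_graph_from_edges([(1, 2), (2, 3)])`` returns
    ``{1: {2}, 2: {1, 3}, 3: {2}}``.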
548 | 549 | :arg edges: An :class:`Iterable` of pairs of related :class:`NodeT` s. 550 | 551 | :returns: A :class:`GraphT` that is the undirected graph. 552 | """ 553 | undirected_graph: dict[NodeT, set[NodeT]] = {} 554 | 555 | for lhs, rhs in edges: 556 | if lhs == rhs: 557 | raise TypeError("Found loop in edges," 558 | f" LHS, RHS = {lhs}") 559 | 560 | undirected_graph.setdefault(lhs, set()).add(rhs) 561 | undirected_graph.setdefault(rhs, set()).add(lhs) 562 | 563 | return undirected_graph 564 | 565 | 566 | def get_reachable_nodes( 567 | undirected_graph: GraphT[NodeT], 568 | source_node: NodeT, 569 | exclude_nodes: Collection[NodeT] | None = None) -> frozenset[NodeT]: 570 | """ 571 | Returns a :class:`frozenset` of all nodes in *undirected_graph* that are 572 | reachable from *source_node*. 573 | 574 | If any node from *exclude_nodes* lies on a path between *source_node* and 575 | some other node :math:`u` in *undirected_graph* and there are no other 576 | viable paths, then :math:`u` is considered not reachable from *source_node*. 577 | 578 | In the case where *source_node* is in *exclude_nodes*, then no node is 579 | reachable from *source_node*, so an empty :class:`frozenset` is returned. 580 | """ 581 | if exclude_nodes is not None and source_node in exclude_nodes: 582 | return frozenset() 583 | 584 | nodes_visited: set[NodeT] = set() 585 | reachable_nodes: set[NodeT] = set() 586 | nodes_to_visit = {source_node} 587 | 588 | if exclude_nodes is None: 589 | exclude_nodes = set() 590 | 591 | while nodes_to_visit: 592 | current_node = nodes_to_visit.pop() 593 | nodes_visited.add(current_node) 594 | 595 | reachable_nodes.add(current_node) 596 | 597 | neighbors = undirected_graph[current_node] 598 | nodes_to_visit.update({ 599 | node for node in neighbors 600 | if node not in nodes_visited and node not in exclude_nodes 601 | }) 602 | 603 | return frozenset(reachable_nodes) 604 | 605 | 606 | # vim: foldmethod=marker 607 | -------------------------------------------------------------------------------- /pytools/graphviz.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | 4 | __copyright__ = """ 5 | Copyright (C) 2013 Andreas Kloeckner 6 | Copyright (C) 2014 Matt Wala 7 | """ 8 | 9 | __license__ = """ 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in 18 | all copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 26 | THE SOFTWARE. 27 | """ 28 | 29 | __doc__ = """ 30 | Dot helper functions 31 | ==================== 32 | 33 | .. 
autofunction:: dot_escape 34 | .. autofunction:: show_dot 35 | """ 36 | 37 | import html 38 | import logging 39 | import os 40 | 41 | 42 | logger = logging.getLogger(__name__) 43 | 44 | 45 | # {{{ graphviz / dot interactive show 46 | 47 | def dot_escape(s: str) -> str: 48 | """ 49 | Escape the string *s* for compatibility with the 50 | `dot `__ language, particularly 51 | backslashes and HTML tags. 52 | 53 | :arg s: The input string to escape. 54 | 55 | :returns: *s* with special characters escaped. 56 | """ 57 | # "\" and HTML are significant in graphviz. 58 | return html.escape(s.replace("\\", "\\\\")) 59 | 60 | 61 | def show_dot(dot_code: str, output_to: str | None = None) -> str | None: 62 | """ 63 | Visualize the graph represented by *dot_code*. 64 | 65 | :arg dot_code: An instance of :class:`str` in the `dot `__ 66 | language to visualize. 67 | :arg output_to: An instance of :class:`str` that can be one of: 68 | 69 | - ``"xwindow"`` to visualize the graph as an 70 | `X window `_. 71 | - ``"browser"`` to visualize the graph as an SVG file in the 72 | system's default web-browser. 73 | - ``"svg"`` to store the dot code as an SVG file on the file system. 74 | Returns the path to the generated SVG file. 75 | 76 | Defaults to ``"xwindow"`` if X11 support is present, otherwise defaults 77 | to ``"browser"``. 78 | 79 | :returns: Depends on *output_to*. If ``"svg"``, returns the path to the 80 | generated SVG file, otherwise returns ``None``. 81 | """ 82 | 83 | import subprocess 84 | from tempfile import mkdtemp 85 | temp_dir = mkdtemp(prefix="tmp_pytools_dot") 86 | 87 | dot_file_name = "code.dot" 88 | 89 | from os.path import join 90 | with open(join(temp_dir, dot_file_name), "w") as dotf: 91 | dotf.write(dot_code) 92 | 93 | # {{{ preprocess 'output_to' 94 | 95 | if output_to is None: 96 | with subprocess.Popen(["dot", "-T?"], 97 | stdout=subprocess.PIPE, 98 | stderr=subprocess.PIPE 99 | ) as proc: 100 | assert proc.stderr, ("Could not execute the 'dot' program. 
" 101 | "Please install the 'graphviz' package and " 102 | "make sure it is in your $PATH.") 103 | supported_formats = proc.stderr.read().decode() 104 | 105 | if " x11 " in supported_formats and "DISPLAY" in os.environ: 106 | output_to = "xwindow" 107 | else: 108 | output_to = "browser" 109 | 110 | # }}} 111 | 112 | if output_to == "xwindow": 113 | subprocess.check_call(["dot", "-Tx11", dot_file_name], cwd=temp_dir) 114 | elif output_to in ["browser", "svg"]: 115 | svg_file_name = "code.svg" 116 | subprocess.check_call(["dot", "-Tsvg", "-o", svg_file_name, dot_file_name], 117 | cwd=temp_dir) 118 | 119 | full_svg_file_name = join(temp_dir, svg_file_name) 120 | logger.info("show_dot: svg written to '%s'", full_svg_file_name) 121 | 122 | if output_to == "svg": 123 | return full_svg_file_name 124 | assert output_to == "browser" 125 | 126 | from webbrowser import open as browser_open 127 | browser_open("file://" + full_svg_file_name) 128 | else: 129 | raise ValueError("`output_to` can be one of 'xwindow', 'browser', or 'svg'," 130 | f" got '{output_to}'") 131 | 132 | return None 133 | # }}} 134 | 135 | 136 | # vim: foldmethod=marker 137 | -------------------------------------------------------------------------------- /pytools/lex.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | 4 | __copyright__ = """ 5 | Copyright (C) 2009-2013 Andreas Kloeckner 6 | Copyright (C) 2013- University of Illinois Board of Trustees 7 | """ 8 | 9 | __license__ = """ 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in 18 | all copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 26 | THE SOFTWARE. 
27 | """ 28 | 29 | 30 | import re 31 | from collections.abc import Mapping, Sequence 32 | from dataclasses import dataclass 33 | from typing import Literal, NoReturn, TypeAlias, cast 34 | 35 | from typing_extensions import Self, final, overload, override 36 | 37 | 38 | @final 39 | class RuleError(RuntimeError): 40 | def __init__(self, rule: LexRule) -> None: 41 | RuntimeError.__init__(self) 42 | self.Rule = rule 43 | 44 | @override 45 | def __str__(self) -> str: 46 | return repr(self.Rule) 47 | 48 | 49 | @final 50 | class InvalidTokenError(RuntimeError): 51 | def __init__(self, s: str, str_index: int) -> None: 52 | RuntimeError.__init__(self) 53 | self.string = s 54 | self.index = str_index 55 | 56 | @override 57 | def __str__(self) -> str: 58 | return "at index {}: ...{}...".format( 59 | self.index, self.string[self.index:self.index+20]) 60 | 61 | 62 | @final 63 | class ParseError(RuntimeError): 64 | def __init__(self, msg: str, s: str, token: Lexed | None) -> None: 65 | RuntimeError.__init__(self) 66 | self.message = msg 67 | self.string = s 68 | self.Token = token 69 | 70 | @override 71 | def __str__(self) -> str: 72 | if self.Token is None: 73 | return f"{self.message} at end of input" 74 | return "{} at index {}: ...{}...".format( 75 | self.message, self.Token[2], 76 | self.string[self.Token[2]:self.Token[2]+20]) 77 | 78 | 79 | @final 80 | class RE: 81 | def __init__(self, s: str, flags: int = 0) -> None: 82 | self.Content = s 83 | self.RE = re.compile(s, flags) 84 | 85 | @override 86 | def __repr__(self) -> str: 87 | return f"RE({self.Content})" 88 | 89 | 90 | LexRule: TypeAlias = tuple[str, RE | tuple[str | RE, ...]] 91 | LexTable: TypeAlias = Sequence[LexRule] 92 | BasicLexed: TypeAlias = tuple[str, str, int] 93 | LexedWithMatch: TypeAlias = tuple[str, str, int, re.Match[str]] 94 | Lexed: TypeAlias = BasicLexed | LexedWithMatch 95 | 96 | 97 | def _matches_rule( 98 | rule: RE | str | tuple[str | RE, ...], 99 | s: str, 100 | start: int, 101 | rule_dict: Mapping[str, RE | tuple[str | RE, ...]], 102 | debug: bool = False 103 | ) -> tuple[int, re.Match[str] | None]: 104 | if debug: 105 | print("Trying", rule, "on", s[start:]) 106 | 107 | if isinstance(rule, tuple): 108 | if rule[0] == "|": 109 | for subrule in rule[1:]: 110 | length, match_obj = _matches_rule( 111 | subrule, s, start, rule_dict, debug) 112 | if not length: 113 | continue 114 | return length, match_obj 115 | else: 116 | my_match_length = 0 117 | for subrule in rule: 118 | length, _ = _matches_rule( 119 | subrule, s, start, rule_dict, debug) 120 | if not length: 121 | break 122 | my_match_length += length 123 | start += length 124 | else: 125 | return my_match_length, None 126 | return 0, None 127 | 128 | if isinstance(rule, str): 129 | return _matches_rule(rule_dict[rule], s, start, rule_dict, debug) 130 | 131 | if isinstance(rule, RE): 132 | match_obj = rule.RE.match(s, start) 133 | if match_obj: 134 | return match_obj.end()-start, match_obj 135 | return 0, None 136 | 137 | raise RuleError(rule) 138 | 139 | 140 | @overload 141 | def lex( 142 | lex_table: LexTable, 143 | s: str, 144 | *, 145 | debug: bool = ..., 146 | match_objects: Literal[False] = False, 147 | ) -> Sequence[BasicLexed]: 148 | ... 149 | 150 | 151 | @overload 152 | def lex( 153 | lex_table: LexTable, 154 | s: str, 155 | *, 156 | debug: bool = ..., 157 | match_objects: Literal[True], 158 | ) -> Sequence[LexedWithMatch]: 159 | ... 
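# A small usage sketch (the token names and regexes below are illustrative,
# not part of this module):
#
#     table: LexTable = [("word", RE(r"[a-z]+")), ("space", RE(r"\s+"))]
#     lex(table, "hello world")
#     # -> [("word", "hello", 0), ("space", " ", 5), ("word", "world", 6)]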
160 | 161 | 162 | def lex( 163 | lex_table: LexTable, 164 | s: str, 165 | *, 166 | debug: bool = False, 167 | match_objects: bool = False 168 | ) -> Sequence[Lexed]: 169 | rule_dict = dict(lex_table) 170 | result: list[Lexed] = [] 171 | i = 0 172 | while i < len(s): 173 | for name, rule in lex_table: 174 | length, match_obj = _matches_rule(rule, s, i, rule_dict, debug) 175 | if length: 176 | if match_objects: 177 | assert match_obj 178 | result.append((name, s[i:i+length], i, match_obj)) 179 | else: 180 | result.append((name, s[i:i+length], i)) 181 | i += length 182 | break 183 | else: 184 | raise InvalidTokenError(s, i) 185 | return result 186 | 187 | 188 | @dataclass 189 | class LexIterator: 190 | lexed: Sequence[tuple[str, str, int] | tuple[str, str, int, re.Match[str]]] 191 | raw_string: str 192 | index: int = 0 193 | 194 | def copy(self) -> Self: 195 | return type(self)(self.lexed, self.raw_string, self.index) 196 | 197 | def assign(self, rhs: LexIterator) -> None: 198 | assert self.lexed is rhs.lexed 199 | assert self.raw_string is rhs.raw_string 200 | 201 | self.index = rhs.index 202 | 203 | def next_tag(self, i: int = 0) -> str: 204 | return self.lexed[self.index + i][0] 205 | 206 | def next_str(self, i: int = 0) -> str: 207 | return self.lexed[self.index + i][1] 208 | 209 | def next_match_obj(self) -> re.Match[str]: 210 | _tok, _s, _i, match = cast("LexedWithMatch", self.lexed[self.index]) 211 | return match 212 | 213 | def next_str_and_advance(self) -> str: 214 | result = self.next_str() 215 | self.advance() 216 | return result 217 | 218 | def advance(self) -> None: 219 | self.index += 1 220 | 221 | def is_at_end(self, i: int = 0) -> bool: 222 | return self.index + i >= len(self.lexed) 223 | 224 | def is_next(self, tag: str, i: int = 0) -> bool: 225 | return ( 226 | self.index + i < len(self.lexed) 227 | and self.next_tag(i) is tag) 228 | 229 | def raise_parse_error(self, msg: str) -> NoReturn: 230 | if self.is_at_end(): 231 | raise ParseError(msg, self.raw_string, None) 232 | 233 | raise ParseError(msg, self.raw_string, self.lexed[self.index]) 234 | 235 | def expected(self, what_expected: str) -> NoReturn: 236 | if self.is_at_end(): 237 | self.raise_parse_error( 238 | f"{what_expected} expected, end of input found instead") 239 | else: 240 | self.raise_parse_error( 241 | f"{what_expected} expected, {self.next_tag()} found instead") 242 | 243 | def expect_not_end(self) -> None: 244 | if self.is_at_end(): 245 | self.raise_parse_error("unexpected end of input") 246 | 247 | def expect(self, tag: str) -> None: 248 | self.expect_not_end() 249 | if not self.is_next(tag): 250 | self.expected(str(tag)) 251 | -------------------------------------------------------------------------------- /pytools/mpi.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | 4 | __copyright__ = """ 5 | Copyright (C) 2009-2019 Andreas Kloeckner 6 | Copyright (C) 2022 University of Illinois Board of Trustees 7 | """ 8 | 9 | __license__ = """ 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice 
and this permission notice shall be included in 18 | all copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 26 | THE SOFTWARE. 27 | """ 28 | 29 | __doc__ = """ 30 | MPI helper functionality 31 | ======================== 32 | 33 | .. autofunction:: check_for_mpi_relaunch 34 | .. autofunction:: run_with_mpi_ranks 35 | .. autofunction:: pytest_raises_on_rank 36 | """ 37 | 38 | import contextlib 39 | from typing import TYPE_CHECKING 40 | 41 | 42 | if TYPE_CHECKING: 43 | from collections.abc import Generator 44 | 45 | 46 | def check_for_mpi_relaunch(argv): 47 | if argv[1] != "--mpi-relaunch": 48 | return 49 | 50 | from pickle import loads 51 | f, args, kwargs = loads(argv[2]) 52 | 53 | f(*args, **kwargs) 54 | import sys 55 | sys.exit() 56 | 57 | 58 | def run_with_mpi_ranks(py_script, ranks, callable_, args=(), kwargs=None): 59 | if kwargs is None: 60 | kwargs = {} 61 | 62 | import os 63 | import sys 64 | newenv = os.environ.copy() 65 | newenv["PYTOOLS_RUN_WITHIN_MPI"] = "1" 66 | 67 | from pickle import dumps 68 | callable_and_args = dumps((callable_, args, kwargs)) 69 | 70 | from subprocess import check_call 71 | check_call(["mpirun", "-np", str(ranks), 72 | sys.executable, py_script, "--mpi-relaunch", callable_and_args], 73 | env=newenv) 74 | 75 | 76 | @contextlib.contextmanager 77 | def pytest_raises_on_rank( 78 | my_rank: int, fail_rank: int, 79 | expected_exception: type[BaseException] | tuple[type[BaseException], ...], 80 | ) -> Generator[contextlib.AbstractContextManager, None, None]: 81 | """ 82 | Like :func:`pytest.raises`, but only expect an exception on rank *fail_rank*. 
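    A typical use in an MPI-parallel test, assuming *comm* is an ``mpi4py``
    communicator and ``f()`` raises :class:`ValueError` on rank 0 only::

        with pytest_raises_on_rank(comm.rank, 0, ValueError):
            f()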
83 | """ 84 | from contextlib import nullcontext 85 | 86 | import pytest 87 | 88 | if my_rank == fail_rank: 89 | cm: contextlib.AbstractContextManager = pytest.raises(expected_exception) 90 | else: 91 | cm = nullcontext() 92 | 93 | with cm as exc: 94 | yield exc 95 | -------------------------------------------------------------------------------- /pytools/mpiwrap.py: -------------------------------------------------------------------------------- 1 | """See pytools.prefork for this module's reason for being.""" 2 | from __future__ import annotations 3 | 4 | import mpi4py.rc # pylint:disable=import-error 5 | 6 | 7 | mpi4py.rc.initialize = False 8 | 9 | from mpi4py.MPI import * # noqa: F403 pylint:disable=wildcard-import,wrong-import-position 10 | 11 | import pytools.prefork # pylint:disable=wrong-import-position 12 | 13 | 14 | pytools.prefork.enable_prefork() 15 | 16 | 17 | if Is_initialized(): # type: ignore[name-defined] # noqa: F405 18 | raise RuntimeError("MPI already initialized before MPI wrapper import") 19 | 20 | 21 | def InitWithAutoFinalize(*args, **kwargs): # noqa: N802 22 | result = Init(*args, **kwargs) # noqa: F405 23 | import atexit 24 | atexit.register(Finalize) # # noqa: F405 25 | return result 26 | -------------------------------------------------------------------------------- /pytools/obj_array.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | 4 | __copyright__ = "Copyright (C) 2009-2020 Andreas Kloeckner" 5 | 6 | __license__ = """ 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in 15 | all copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 23 | THE SOFTWARE. 24 | """ 25 | 26 | from functools import partial, update_wrapper 27 | from warnings import warn 28 | 29 | import numpy as np 30 | 31 | 32 | __doc__ = """ 33 | Handling :mod:`numpy` Object Arrays 34 | =================================== 35 | 36 | Creation 37 | -------- 38 | 39 | .. autofunction:: make_obj_array 40 | .. autofunction:: flat_obj_array 41 | 42 | Mapping 43 | ------- 44 | 45 | .. autofunction:: obj_array_vectorize 46 | .. autofunction:: obj_array_vectorize_n_args 47 | 48 | Numpy workarounds 49 | ----------------- 50 | These functions work around a `long-standing, annoying numpy issue 51 | `__. 52 | 53 | .. autofunction:: obj_array_real 54 | .. autofunction:: obj_array_imag 55 | .. autofunction:: obj_array_real_copy 56 | .. 
autofunction:: obj_array_imag_copy 57 | """ 58 | 59 | 60 | def make_obj_array(res_list): 61 | """Create a one-dimensional object array from *res_list*. 62 | This differs from ``numpy.array(res_list, dtype=object)`` 63 | by whether it tries to determine its shape by descending 64 | into nested array-like objects. Consider the following example: 65 | 66 | .. doctest:: 67 | 68 | >>> import numpy as np 69 | >>> a = np.array([np.arange(5), np.arange(5)], dtype=object) 70 | >>> a 71 | array([[0, 1, 2, 3, 4], 72 | [0, 1, 2, 3, 4]], dtype=object) 73 | >>> a.shape 74 | (2, 5) 75 | >>> # meanwhile: 76 | >>> from pytools.obj_array import make_obj_array 77 | >>> b = make_obj_array([np.arange(5), np.arange(5)]) 78 | >>> b 79 | array([array([0, 1, 2, 3, 4]), array([0, 1, 2, 3, 4])], dtype=object) 80 | >>> b.shape 81 | (2,) 82 | 83 | In some settings (such as when the sub-arrays are large and/or 84 | live on a GPU), the recursive behavior of :func:`numpy.array` 85 | can be undesirable. 86 | """ 87 | result = np.empty((len(res_list),), dtype=object) 88 | 89 | # 'result[:] = res_list' may look tempting, however: 90 | # https://github.com/numpy/numpy/issues/16564 91 | for idx in range(len(res_list)): 92 | result[idx] = res_list[idx] 93 | 94 | return result 95 | 96 | 97 | def obj_array_to_hashable(f): 98 | if isinstance(f, np.ndarray) and f.dtype.char == "O": 99 | return tuple(f) 100 | return f 101 | 102 | 103 | def flat_obj_array(*args): 104 | """Return a one-dimensional flattened object array consisting of 105 | elements obtained by 'flattening' *args* as follows: 106 | 107 | - The first axis of any non-subclassed object arrays will be flattened 108 | into the result. 109 | - Instances of :class:`list` will be flattened into the result. 110 | - Any other type will appear in the list as-is. 111 | """ 112 | res_list = [] 113 | for arg in args: 114 | if isinstance(arg, list): 115 | res_list.extend(arg) 116 | 117 | # Only flatten genuine, non-subclassed object arrays. 118 | elif type(arg) is np.ndarray: 119 | res_list.extend(arg.flat) 120 | 121 | else: 122 | res_list.append(arg) 123 | 124 | return make_obj_array(res_list) 125 | 126 | 127 | def obj_array_vectorize(f, ary): 128 | """Apply the function *f* to all entries of the object array *ary*. 129 | Return an object array of the same shape consisting of the return 130 | values. 131 | If *ary* is not an object array, return ``f(ary)``. 132 | 133 | .. note :: 134 | 135 | This function exists because :class:`numpy.vectorize` suffers from the same 136 | issue described under :func:`make_obj_array`. 137 | """ 138 | 139 | if isinstance(ary, np.ndarray) and ary.dtype.char == "O": 140 | result = np.empty_like(ary) 141 | for i in np.ndindex(ary.shape): 142 | result[i] = f(ary[i]) 143 | return result 144 | return f(ary) 145 | 146 | 147 | def obj_array_vectorized(f): 148 | wrapper = partial(obj_array_vectorize, f) 149 | update_wrapper(wrapper, f) 150 | return wrapper 151 | 152 | 153 | def rec_obj_array_vectorize(f, ary): 154 | """Apply the function *f* to all entries of the object array *ary*. 155 | Return an object array of the same shape consisting of the return 156 | values. 157 | If the elements of *ary* are further object arrays, recurse 158 | until non-object-arrays are found and then apply *f* to those 159 | entries. 160 | If *ary* is not an object array, return ``f(ary)``. 161 | 162 | .. note :: 163 | 164 | This function exists because :class:`numpy.vectorize` suffers from the same 165 | issue described under :func:`make_obj_array`. 
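    .. doctest::

        >>> from pytools.obj_array import make_obj_array
        >>> from pytools.obj_array import rec_obj_array_vectorize
        >>> rec_obj_array_vectorize(lambda x: x*2, make_obj_array([1, 2]))
        array([2, 4], dtype=object)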
166 | """ 167 | if isinstance(ary, np.ndarray) and ary.dtype.char == "O": 168 | result = np.empty_like(ary) 169 | for i in np.ndindex(ary.shape): 170 | result[i] = rec_obj_array_vectorize(f, ary[i]) 171 | return result 172 | return f(ary) 173 | 174 | 175 | def rec_obj_array_vectorized(f): 176 | wrapper = partial(rec_obj_array_vectorize, f) 177 | update_wrapper(wrapper, f) 178 | return wrapper 179 | 180 | 181 | def obj_array_vectorize_n_args(f, *args): 182 | """Apply the function *f* elementwise to all entries of any 183 | object arrays in *args*. All such object arrays are expected 184 | to have the same shape (but this is not checked). 185 | Equivalent to an appropriately-looped execution of:: 186 | 187 | result[idx] = f(obj_array_arg1[idx], arg2, obj_array_arg3[idx]) 188 | 189 | Return an object array of the same shape as the arguments consisting of the 190 | return values of *f*. 191 | 192 | .. note :: 193 | 194 | This function exists because :class:`numpy.vectorize` suffers from the same 195 | issue described under :func:`make_obj_array`. 196 | """ 197 | oarray_arg_indices = [] 198 | for i, arg in enumerate(args): 199 | if isinstance(arg, np.ndarray) and arg.dtype.char == "O": 200 | oarray_arg_indices.append(i) 201 | 202 | if not oarray_arg_indices: 203 | return f(*args) 204 | 205 | leading_oa_index = oarray_arg_indices[0] 206 | 207 | template_ary = args[leading_oa_index] 208 | result = np.empty_like(template_ary) 209 | new_args = list(args) 210 | for i in np.ndindex(template_ary.shape): 211 | for arg_i in oarray_arg_indices: 212 | new_args[arg_i] = args[arg_i][i] 213 | result[i] = f(*new_args) 214 | return result 215 | 216 | 217 | def obj_array_vectorized_n_args(f): 218 | # Unfortunately, this can't use partial(), as the callable returned by it 219 | # will not be turned into a bound method upon attribute access. 220 | # This may happen here, because the decorator *could* be used 221 | # on methods, since it can "look past" the leading `self` argument. 222 | # Only exactly function objects receive this treatment. 223 | # 224 | # Spec link: 225 | # https://docs.python.org/3/reference/datamodel.html#the-standard-type-hierarchy 226 | # (under "Instance Methods", quote as of Py3.9.4) 227 | # > Also notice that this transformation only happens for user-defined functions; 228 | # > other callable objects (and all non-callable objects) are retrieved 229 | # > without transformation. 230 | 231 | def wrapper(*args): 232 | return obj_array_vectorize_n_args(f, *args) 233 | 234 | update_wrapper(wrapper, f) 235 | return wrapper 236 | 237 | 238 | # {{{ workarounds for https://github.com/numpy/numpy/issues/1740 239 | 240 | def obj_array_real(ary): 241 | return rec_obj_array_vectorize(lambda x: x.real, ary) 242 | 243 | 244 | def obj_array_imag(ary): 245 | return rec_obj_array_vectorize(lambda x: x.imag, ary) 246 | 247 | 248 | def obj_array_real_copy(ary): 249 | return rec_obj_array_vectorize(lambda x: x.real.copy(), ary) 250 | 251 | 252 | def obj_array_imag_copy(ary): 253 | return rec_obj_array_vectorize(lambda x: x.imag.copy(), ary) 254 | 255 | # }}} 256 | 257 | 258 | # {{{ deprecated junk 259 | 260 | def is_obj_array(val): 261 | warn("is_obj_array is deprecated and will go away in 2022, " 262 | "just inline the check.", DeprecationWarning, stacklevel=2) 263 | 264 | try: 265 | return isinstance(val, np.ndarray) and val.dtype.char == "O" 266 | except AttributeError: 267 | return False 268 | 269 | 270 | def log_shape(array): 271 | """Returns the "logical shape" of the array. 
272 | 273 | The "logical shape" is the shape that's left when the node-depending 274 | dimension has been eliminated. 275 | """ 276 | 277 | warn("log_shape is deprecated and will go away in 2021, " 278 | "use the actual object array shape.", 279 | DeprecationWarning, stacklevel=2) 280 | 281 | try: 282 | if array.dtype.char == "O": 283 | return array.shape 284 | return array.shape[:-1] 285 | except AttributeError: 286 | return () 287 | 288 | 289 | def join_fields(*args): 290 | warn("join_fields is deprecated and will go away in 2022, " 291 | "use flat_obj_array", DeprecationWarning, stacklevel=2) 292 | 293 | return flat_obj_array(*args) 294 | 295 | 296 | def is_equal(a, b): 297 | warn("is_equal is deprecated and will go away in 2021, " 298 | "use numpy.array_equal", DeprecationWarning, stacklevel=2) 299 | 300 | if is_obj_array(a): 301 | return is_obj_array(b) and (a.shape == b.shape) and (a == b).all() 302 | return not is_obj_array(b) and a == b 303 | 304 | 305 | is_field_equal = is_equal 306 | 307 | 308 | def gen_len(expr): 309 | if is_obj_array(expr): 310 | return len(expr) 311 | return 1 312 | 313 | 314 | def gen_slice(expr, slice_): 315 | warn("gen_slice is deprecated and will go away in 2021", 316 | DeprecationWarning, stacklevel=2) 317 | 318 | result = expr[slice_] 319 | if len(result) == 1: 320 | return result[0] 321 | return result 322 | 323 | 324 | def obj_array_equal(a, b): 325 | warn("obj_array_equal is deprecated and will go away in 2021, " 326 | "use numpy.array_equal", DeprecationWarning, stacklevel=2) 327 | 328 | a_is_oa = is_obj_array(a) 329 | assert a_is_oa == is_obj_array(b) 330 | 331 | if a_is_oa: 332 | return np.array_equal(a, b) 333 | return a == b 334 | 335 | 336 | def to_obj_array(ary): 337 | warn("to_obj_array is deprecated and will go away in 2021, " 338 | "use make_obj_array", DeprecationWarning, 339 | stacklevel=2) 340 | 341 | ls = log_shape(ary) 342 | result = np.empty(ls, dtype=object) 343 | 344 | for i in np.ndindex(ls): 345 | result[i] = ary[i] 346 | 347 | return result 348 | 349 | 350 | def setify_field(f): 351 | warn("setify_field is deprecated and will go away in 2021", 352 | DeprecationWarning, stacklevel=2) 353 | 354 | if is_obj_array(f): 355 | return set(f) 356 | return {f} 357 | 358 | 359 | def cast_field(field, dtype): 360 | warn("cast_field is deprecated and will go away in 2021", 361 | DeprecationWarning, stacklevel=2) 362 | 363 | return with_object_array_or_scalar( 364 | lambda f: f.astype(dtype), field) 365 | 366 | 367 | def with_object_array_or_scalar(f, field, obj_array_only=False): 368 | warn("with_object_array_or_scalar is deprecated and will go away in 2022, " 369 | "use obj_array_vectorize", DeprecationWarning, stacklevel=2) 370 | 371 | if obj_array_only: 372 | ls = field.shape if is_obj_array(field) else () 373 | else: 374 | ls = log_shape(field) 375 | if ls != (): 376 | result = np.zeros(ls, dtype=object) 377 | for i in np.ndindex(ls): 378 | result[i] = f(field[i]) 379 | return result 380 | return f(field) 381 | 382 | 383 | def as_oarray_func(f): 384 | wrapper = partial(with_object_array_or_scalar, f) 385 | update_wrapper(wrapper, f) 386 | return wrapper 387 | 388 | 389 | def with_object_array_or_scalar_n_args(f, *args): 390 | warn("with_object_array_or_scalar_n_args is deprecated and " 391 | "will go away in 2022, " 392 | "use obj_array_vectorize_n_args", DeprecationWarning, stacklevel=2) 393 | 394 | oarray_arg_indices = [] 395 | for i, arg in enumerate(args): 396 | if is_obj_array(arg): 397 | oarray_arg_indices.append(i) 398 | 399 | if 
not oarray_arg_indices: 400 | return f(*args) 401 | 402 | leading_oa_index = oarray_arg_indices[0] 403 | 404 | ls = log_shape(args[leading_oa_index]) 405 | if ls != (): 406 | result = np.zeros(ls, dtype=object) 407 | 408 | new_args = list(args) 409 | for i in np.ndindex(ls): 410 | for arg_i in oarray_arg_indices: 411 | new_args[arg_i] = args[arg_i][i] 412 | 413 | result[i] = f(*new_args) 414 | return result 415 | return f(*args) 416 | 417 | 418 | def as_oarray_func_n_args(f): 419 | wrapper = partial(with_object_array_or_scalar_n_args, f) 420 | update_wrapper(wrapper, f) 421 | return wrapper 422 | 423 | 424 | def oarray_real(ary): 425 | warn("oarray_real is deprecated and will go away in 2022, " 426 | "use obj_array_real", DeprecationWarning, stacklevel=2) 427 | return obj_array_real(ary) 428 | 429 | 430 | def oarray_imag(ary): 431 | warn("oarray_imag is deprecated and will go away in 2022, " 432 | "use obj_array_imag", DeprecationWarning, stacklevel=2) 433 | return obj_array_imag(ary) 434 | 435 | 436 | def oarray_real_copy(ary): 437 | warn("oarray_real_copy is deprecated and will go away in 2022, " 438 | "use obj_array_real_copy", DeprecationWarning, stacklevel=2) 439 | return obj_array_real_copy(ary) 440 | 441 | 442 | def oarray_imag_copy(ary): 443 | warn("oarray_imag_copy is deprecated and will go away in 2022, " 444 | "use obj_array_imag_copy", DeprecationWarning, stacklevel=2) 445 | return obj_array_imag_copy(ary) 446 | 447 | # }}} 448 | 449 | # vim: foldmethod=marker 450 | -------------------------------------------------------------------------------- /pytools/prefork.py: -------------------------------------------------------------------------------- 1 | """OpenMPI, once initialized, prohibits forking. This helper module 2 | allows the forking of *one* helper child process before OpenMPI 3 | initialization that can do the forking for the fork-challenged 4 | parent process. 5 | 6 | Since none of this is MPI-specific, it got parked in :mod:`pytools`. 7 | 8 | .. autoexception:: ExecError 9 | :show-inheritance: 10 | 11 | .. autoclass:: Forker 12 | .. autoclass:: DirectForker 13 | .. autoclass:: IndirectForker 14 | 15 | .. autofunction:: enable_prefork 16 | .. autofunction:: call 17 | .. autofunction:: call_async 18 | .. autofunction:: call_capture_output 19 | .. autofunction:: wait 20 | .. 
autofunction:: waitall 21 | """ 22 | from __future__ import annotations 23 | 24 | import socket 25 | from abc import ABC, abstractmethod 26 | from subprocess import Popen 27 | from typing import TYPE_CHECKING, Any 28 | 29 | from typing_extensions import override 30 | 31 | 32 | if TYPE_CHECKING: 33 | from collections.abc import Sequence 34 | 35 | 36 | class ExecError(OSError): 37 | pass 38 | 39 | 40 | class Forker(ABC): 41 | @abstractmethod 42 | def call(self, cmdline: Sequence[str], cwd: str | None = None) -> int: 43 | pass 44 | 45 | @abstractmethod 46 | def call_async(self, cmdline: Sequence[str], cwd: str | None = None) -> int: 47 | pass 48 | 49 | @abstractmethod 50 | def call_capture_output(self, 51 | cmdline: Sequence[str], 52 | cwd: str | None = None, 53 | error_on_nonzero: bool = True) -> tuple[int, bytes, bytes]: 54 | pass 55 | 56 | @abstractmethod 57 | def wait(self, aid: int) -> int: 58 | pass 59 | 60 | @abstractmethod 61 | def waitall(self) -> dict[int, int]: 62 | pass 63 | 64 | 65 | class DirectForker(Forker): 66 | def __init__(self) -> None: 67 | self.apids: dict[int, Popen[bytes]] = {} 68 | self.count: int = 0 69 | 70 | @override 71 | def call(self, cmdline: Sequence[str], cwd: str | None = None) -> int: 72 | from subprocess import call as spcall 73 | 74 | try: 75 | return spcall(cmdline, cwd=cwd) 76 | except OSError as e: 77 | raise ExecError( 78 | "error invoking '{}': {}".format(" ".join(cmdline), e)) from e 79 | 80 | @override 81 | def call_async(self, cmdline: Sequence[str], cwd: str | None = None) -> int: 82 | try: 83 | self.count += 1 84 | 85 | proc = Popen(cmdline, cwd=cwd) 86 | self.apids[self.count] = proc 87 | 88 | return self.count 89 | except OSError as e: 90 | raise ExecError( 91 | "error invoking '{}': {}".format(" ".join(cmdline), e)) from e 92 | 93 | @override 94 | def call_capture_output(self, 95 | cmdline: Sequence[str], 96 | cwd: str | None = None, 97 | error_on_nonzero: bool = True) -> tuple[int, bytes, bytes]: 98 | from subprocess import PIPE, Popen 99 | 100 | try: 101 | popen = Popen(cmdline, cwd=cwd, stdin=PIPE, stdout=PIPE, 102 | stderr=PIPE) 103 | stdout_data, stderr_data = popen.communicate() 104 | 105 | if error_on_nonzero and popen.returncode: 106 | raise ExecError("status {} invoking '{}': {}".format( 107 | popen.returncode, 108 | " ".join(cmdline), 109 | stderr_data.decode("utf-8", errors="replace"))) 110 | 111 | return popen.returncode, stdout_data, stderr_data 112 | except OSError as e: 113 | raise ExecError( 114 | "error invoking '{}': {}".format(" ".join(cmdline), e)) from e 115 | 116 | @override 117 | def wait(self, aid: int) -> int: 118 | proc = self.apids.pop(aid) 119 | retc = proc.wait() 120 | 121 | return retc 122 | 123 | @override 124 | def waitall(self) -> dict[int, int]: 125 | rets = {} 126 | 127 | for aid in self.apids: 128 | rets[aid] = self.wait(aid) 129 | 130 | return rets 131 | 132 | 133 | def _send_packet(sock: socket.socket, data: object) -> None: 134 | from pickle import dumps 135 | from struct import pack 136 | 137 | packet = dumps(data) 138 | 139 | sock.sendall(pack("I", len(packet))) 140 | sock.sendall(packet) 141 | 142 | 143 | def _recv_packet(sock: socket.socket, 144 | who: str = "Process", 145 | partner: str = "other end") -> tuple[object, ...]: 146 | from struct import calcsize, unpack 147 | size_bytes_size = calcsize("I") 148 | size_bytes = sock.recv(size_bytes_size) 149 | 150 | if len(size_bytes) < size_bytes_size: 151 | raise SystemExit 152 | 153 | size, = unpack("I", size_bytes) 154 | 155 | packet = b"" 156 | while 
len(packet) < size: 157 | packet += sock.recv(size) 158 | 159 | from pickle import loads 160 | 161 | result = loads(packet) 162 | assert isinstance(result, tuple) 163 | 164 | return result 165 | 166 | 167 | def _fork_server(sock: socket.socket) -> None: 168 | # Ignore keyboard interrupts, we'll get notified by the parent. 169 | import signal 170 | signal.signal(signal.SIGINT, signal.SIG_IGN) 171 | 172 | # Construct a local DirectForker to do the dirty work 173 | df = DirectForker() 174 | 175 | funcs = { 176 | "call": df.call, 177 | "call_async": df.call_async, 178 | "call_capture_output": df.call_capture_output, 179 | "wait": df.wait, 180 | "waitall": df.waitall 181 | } 182 | 183 | try: 184 | while True: 185 | func_name, args, kwargs = _recv_packet( 186 | sock, who="Prefork server", partner="parent" 187 | ) 188 | assert isinstance(func_name, str) 189 | 190 | if func_name == "quit": 191 | df.waitall() 192 | _send_packet(sock, ("ok", None)) 193 | break 194 | try: 195 | result = funcs[func_name](*args, **kwargs) # type: ignore[operator] 196 | # FIXME: Is catching all exceptions the right course of action? 197 | except Exception as e: # pylint:disable=broad-except 198 | _send_packet(sock, ("exception", e)) 199 | else: 200 | _send_packet(sock, ("ok", result)) 201 | finally: 202 | sock.close() 203 | 204 | import os 205 | os._exit(0) 206 | 207 | 208 | class IndirectForker(Forker): 209 | def __init__(self, server_pid: int, sock: socket.socket) -> None: 210 | self.server_pid = server_pid 211 | self.socket = sock 212 | 213 | import atexit 214 | atexit.register(self._quit) 215 | 216 | def _remote_invoke(self, name: str, *args: Any, **kwargs: Any) -> object: 217 | _send_packet(self.socket, (name, args, kwargs)) 218 | status, result = _recv_packet( 219 | self.socket, who="Prefork client", partner="prefork server" 220 | ) 221 | 222 | if status == "exception": 223 | assert isinstance(result, Exception) 224 | raise result 225 | 226 | assert status == "ok" 227 | return result 228 | 229 | def _quit(self) -> None: 230 | self._remote_invoke("quit") 231 | 232 | from os import waitpid 233 | waitpid(self.server_pid, 0) 234 | 235 | @override 236 | def call(self, cmdline: Sequence[str], cwd: str | None = None) -> int: 237 | result = self._remote_invoke("call", cmdline, cwd) 238 | 239 | assert isinstance(result, int) 240 | return result 241 | 242 | @override 243 | def call_async(self, cmdline: Sequence[str], cwd: str | None = None) -> int: 244 | result = self._remote_invoke("call_async", cmdline, cwd) 245 | 246 | assert isinstance(result, int) 247 | return result 248 | 249 | @override 250 | def call_capture_output(self, 251 | cmdline: Sequence[str], 252 | cwd: str | None = None, 253 | error_on_nonzero: bool = True, 254 | ) -> tuple[int, bytes, bytes]: 255 | return self._remote_invoke("call_capture_output", cmdline, cwd, # type: ignore[return-value] 256 | error_on_nonzero) 257 | 258 | @override 259 | def wait(self, aid: int) -> int: 260 | result = self._remote_invoke("wait", aid) 261 | 262 | assert isinstance(result, int) 263 | return result 264 | 265 | @override 266 | def waitall(self) -> dict[int, int]: 267 | result = self._remote_invoke("waitall") 268 | 269 | assert isinstance(result, dict) 270 | return result 271 | 272 | 273 | forker: Forker = DirectForker() 274 | 275 | 276 | def enable_prefork() -> None: 277 | global forker 278 | 279 | if isinstance(forker, IndirectForker): 280 | return 281 | 282 | s_parent, s_child = socket.socketpair() 283 | 284 | from os import fork 285 | fork_res = fork() 286 | 287 | # 
Child 288 | if fork_res == 0: 289 | s_parent.close() 290 | _fork_server(s_child) 291 | # Parent 292 | else: 293 | s_child.close() 294 | forker = IndirectForker(fork_res, s_parent) 295 | 296 | 297 | def call(cmdline: Sequence[str], cwd: str | None = None) -> int: 298 | return forker.call(cmdline, cwd) 299 | 300 | 301 | def call_async(cmdline: Sequence[str], cwd: str | None = None) -> int: 302 | return forker.call_async(cmdline, cwd) 303 | 304 | 305 | def call_capture_output(cmdline: Sequence[str], 306 | cwd: str | None = None, 307 | error_on_nonzero: bool = True) -> tuple[int, bytes, bytes]: 308 | return forker.call_capture_output(cmdline, cwd, error_on_nonzero) 309 | 310 | 311 | def wait(aid: int) -> int: 312 | return forker.wait(aid) 313 | 314 | 315 | def waitall() -> dict[int, int]: 316 | return forker.waitall() 317 | -------------------------------------------------------------------------------- /pytools/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inducer/pytools/2be5f3f72670830454a4687ca8bb8dff0e214b97/pytools/py.typed -------------------------------------------------------------------------------- /pytools/py_codegen.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | 4 | __copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner" 5 | 6 | __license__ = """ 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in 15 | all copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 23 | THE SOFTWARE. 
24 | """
25 | 
26 | 
27 | import marshal
28 | from dataclasses import dataclass, field
29 | from importlib.util import MAGIC_NUMBER as BYTECODE_VERSION
30 | from types import FunctionType, ModuleType
31 | from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast
32 | 
33 | 
34 | if TYPE_CHECKING:
35 |     from collections.abc import Callable, Iterable
36 | 
37 | from pytools.codegen import (
38 |     CodeGenerator as CodeGeneratorBase,
39 |     Indentation,
40 |     remove_common_indentation,
41 | )
42 | 
43 | 
44 | __all__ = (
45 |     "ExistingLineCacheWarning",
46 |     "Indentation",
47 |     "PicklableFunction",
48 |     "PicklableModule",
49 |     "PythonCodeGenerator",
50 |     "PythonFunctionGenerator",
51 |     "remove_common_indentation",
52 | )
53 | 
54 | 
55 | class ExistingLineCacheWarning(Warning):
56 |     """Warning for overwriting existing generated code in the linecache."""
57 | 
58 | 
59 | def _linecache_unique_name(name_prefix: str, source_text: str | None) -> str:
60 |     import linecache
61 | 
62 |     if source_text is not None:
63 |         from siphash24 import siphash13  # pyright: ignore[reportUnknownVariableType]
64 |         src_digest = cast("str", siphash13(source_text.encode()).hexdigest())  # pyright: ignore[reportUnknownMemberType]
65 | 
66 |         name_prefix = f"{name_prefix}-{src_digest}"
67 | 
68 |     from pytools import UniqueNameGenerator
69 |     name_gen = UniqueNameGenerator(
70 |         existing_names=linecache.cache.keys(),
71 |         forced_prefix="<pytools ")
72 |     # (The prefix string above is a reconstruction: the original literal was
73 |     # swallowed by "<...>"-tag stripping when this dump was rendered.)
74 | 
75 |     return name_gen(name_prefix)
76 | 
77 | 
78 | def _make_module(
79 |         source_text: str,
80 |         name: str | None = None,
81 |         *,
82 |         name_prefix: str | None = None) -> dict[str, Any]:
83 |     if name_prefix is None:
84 |         name_prefix = "module"
85 | 
86 |     if name is None:
87 |         name = _linecache_unique_name(name_prefix, source_text)
88 | 
89 |     result_dict: dict[str, Any] = {}
90 | 
91 |     # {{{ insert into linecache
92 | 
93 |     import linecache
94 | 
95 |     if name in linecache.cache:
96 |         from warnings import warn
97 |         warn(f"Overwriting existing generated code in linecache: '{name}'.",
98 |              ExistingLineCacheWarning,
99 |              stacklevel=2)
100 | 
101 |     linecache.cache[name] = (len(source_text), None,
102 |                              [e+"\n" for e in source_text.split("\n")], name)
103 | 
104 |     # }}}
105 | 
106 |     code_obj = compile(
107 |         source_text.rstrip()+"\n", name, "exec")
108 |     result_dict["__code__"] = code_obj
109 |     exec(code_obj, result_dict)
110 | 
111 |     return result_dict
112 | 
113 | 
114 | class PythonCodeGenerator(CodeGeneratorBase):
115 |     def get_module(self, name: str | None = None,
116 |                    *,
117 |                    name_prefix: str | None = None,
118 |                    ) -> dict[str, Any]:
119 |         return _make_module(self.get(), name=name, name_prefix=name_prefix)
120 | 
121 |     def get_picklable_module(self,
122 |                              name: str | None = None,
123 |                              name_prefix: str | None = None
124 |                              ) -> PicklableModule:
125 |         return PicklableModule(self.get_module(name=name, name_prefix=name_prefix))
126 | 
127 | 
128 | class PythonFunctionGenerator(PythonCodeGenerator):
129 |     name: str
130 | 
131 |     def __init__(self, name: str, args: Iterable[str],
132 |                  decorators: Iterable[str] = ()) -> None:
133 |         super().__init__()
134 |         self.name = name
135 | 
136 |         for decorator in decorators:
137 |             self(decorator)
138 | 
139 |         self("def {}({}):".format(name, ", ".join(args)))
140 |         self.indent()
141 | 
142 |     def get_function(self) -> Callable[..., Any]:
143 |         return self.get_module(name_prefix=self.name)[self.name]  # pyright: ignore[reportAny]
144 | 
145 |     def get_picklable_function(self) -> PicklableFunction:
146 |         return PicklableFunction(
147 |             self.get_picklable_module(name_prefix=self.name), self.name)
148 | 
149 | 
150 | # {{{ pickling of binaries for generated code
151 | 
152 | def _get_empty_module_dict(filename: str | None = None) -> dict[str, Any]:
153 |     if filename 
is None: 154 | filename = "" 155 | 156 | result_dict: dict[str, Any] = {} 157 | code_obj = compile("", filename, "exec") 158 | result_dict["__code__"] = code_obj 159 | exec(code_obj, result_dict) 160 | return result_dict 161 | 162 | 163 | _empty_module_dict = _get_empty_module_dict() 164 | 165 | 166 | _FunctionsType: TypeAlias = dict[str, tuple[str, bytes, tuple[object, ...] | None]] 167 | _ModulesType: TypeAlias = dict[str, str] 168 | 169 | 170 | @dataclass 171 | class PicklableModule: 172 | mod_globals: dict[str, Any] 173 | name_prefix: str | None = field(kw_only=True, default=None) 174 | source_code: str | None = field(kw_only=True, default=None) 175 | 176 | def __getstate__(self): 177 | functions: _FunctionsType = {} 178 | modules: _ModulesType = {} 179 | nondefault_globals: dict[str, object] = {} 180 | 181 | for k, v in self.mod_globals.items(): # pyright: ignore[reportAny] 182 | if isinstance(v, FunctionType): 183 | functions[k] = ( 184 | v.__name__, 185 | marshal.dumps(v.__code__), 186 | v.__defaults__) 187 | elif isinstance(v, ModuleType): 188 | modules[k] = v.__name__ 189 | elif k not in _empty_module_dict: 190 | nondefault_globals[k] = v 191 | 192 | return (2, BYTECODE_VERSION, functions, modules, nondefault_globals, 193 | self.name_prefix, self.source_code) 194 | 195 | def __setstate__(self, obj: ( 196 | tuple[Literal[0], bytes, _FunctionsType, dict[str, object]] 197 | | tuple[Literal[1], bytes, _FunctionsType, _ModulesType, 198 | dict[str, object]] 199 | | tuple[Literal[2], bytes, _FunctionsType, _ModulesType, 200 | dict[str, object], str | None, str | None] 201 | ) 202 | ): 203 | if obj[0] == 0: 204 | magic, functions, nondefault_globals = obj[1:] 205 | modules = {} 206 | name_prefix: str | None = None 207 | source_code: str | None = None 208 | elif obj[0] == 1: 209 | magic, functions, modules, nondefault_globals = obj[1:] 210 | name_prefix = None 211 | source_code = None 212 | elif obj[0] == 2: 213 | magic, functions, modules, nondefault_globals, name_prefix, source_code = \ 214 | obj[1:] 215 | else: 216 | raise ValueError("unknown version of PicklableModule") 217 | 218 | if magic != BYTECODE_VERSION: 219 | raise ValueError( 220 | "cannot unpickle function binary: incorrect magic value " 221 | f"(got: {magic!r}, expected: {BYTECODE_VERSION!r})") 222 | 223 | unique_filename = _linecache_unique_name( 224 | name_prefix if name_prefix else "module", source_code) 225 | mod_globals = _get_empty_module_dict(unique_filename) 226 | mod_globals.update(nondefault_globals) 227 | 228 | import linecache 229 | if source_code: 230 | linecache.cache[unique_filename] = (len(source_code), None, 231 | [e+"\n" for e in source_code.split("\n")], 232 | unique_filename) 233 | 234 | from importlib import import_module 235 | for k, mod_name in modules.items(): 236 | mod_globals[k] = import_module(mod_name) 237 | 238 | for k, (name, code_bytes, argdefs) in functions.items(): 239 | f = FunctionType( 240 | marshal.loads(code_bytes), mod_globals, name=name, 241 | argdefs=argdefs) 242 | mod_globals[k] = f 243 | 244 | self.mod_globals = mod_globals 245 | self.name_prefix = name_prefix 246 | self.source_code = source_code 247 | 248 | # }}} 249 | 250 | 251 | # {{{ picklable function 252 | 253 | class PicklableFunction: 254 | """Convenience class wrapping a function in a :class:`PicklableModule`. 
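    Instances are picklable: pickling stores only the wrapped
    :class:`PicklableModule` and the function's name; unpickling looks the
    callable up again in the reconstructed module's globals.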
255 |     """
256 | 
257 |     module: PicklableModule
258 |     name: str
259 | 
260 |     def __init__(self, module: PicklableModule, name: str) -> None:
261 |         self._initialize(module, name)
262 | 
263 |     def _initialize(self, module: PicklableModule, name: str) -> None:
264 |         self.module = module
265 |         self.name = name
266 |         self._callable = cast("FunctionType", module.mod_globals[name])  # pyright: ignore[reportUnannotatedClassAttribute]
267 | 
268 |     def __call__(self, *args: object, **kwargs: object) -> object:
269 |         return self._callable(*args, **kwargs)  # pyright: ignore[reportAny]
270 | 
271 |     def __getstate__(self):
272 |         return {"module": self.module, "name": self.name}
273 | 
274 |     def __setstate__(self, obj):
275 |         self._initialize(obj["module"], obj["name"])
276 | 
277 | # }}}
278 | 
279 | # vim: foldmethod=marker
280 | 
--------------------------------------------------------------------------------
/pytools/spatial_btree.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | import numpy as np
4 | 
5 | 
6 | def do_boxes_intersect(bl, tr):
7 |     (bl1, tr1) = bl
8 |     (bl2, tr2) = tr
9 |     (dimension,) = bl1.shape
10 |     return all(max(bl1[i], bl2[i]) <= min(tr1[i], tr2[i]) for i in range(dimension))
11 | 
12 | 
13 | def make_buckets(bottom_left, top_right, allbuckets, max_elements_per_box):
14 |     (dimensions,) = bottom_left.shape
15 | 
16 |     half = (top_right - bottom_left) / 2.
17 | 
18 |     def do(dimension, pos):
19 |         if dimension == dimensions:
20 |             origin = bottom_left + pos*half
21 |             bucket = SpatialBinaryTreeBucket(origin, origin + half,
22 |                     max_elements_per_box=max_elements_per_box)
23 |             allbuckets.append(bucket)
24 |             return bucket
25 |         pos[dimension] = 0
26 |         first = do(dimension + 1, pos)
27 |         pos[dimension] = 1
28 |         second = do(dimension + 1, pos)
29 |         return [first, second]
30 | 
31 |     return do(0, np.zeros((dimensions,), np.float64))
32 | 
33 | 
34 | class SpatialBinaryTreeBucket:
35 |     """This class represents one bucket in a spatial binary tree.
36 |     It automatically decides whether it needs to create more subdivisions
37 |     beneath itself or not.
38 | 
39 |     .. attribute:: elements
40 | 
41 |         a list of tuples *(element, bbox)* where bbox is again
42 |         a tuple *(lower_left, upper_right)* of :class:`numpy.ndarray` instances
43 |         satisfying ``(lower_left <= upper_right).all()``.
44 |     """
45 | 
46 |     def __init__(self, bottom_left, top_right, max_elements_per_box=None):
47 |         """:param bottom_left: A :mod:`numpy` array of the minimal coordinates
48 |         of the box being partitioned.
49 |         :param top_right: A :mod:`numpy` array of the maximal coordinates of
50 |         the box being partitioned."""
51 | 
52 |         self.elements = []
53 | 
54 |         self.bottom_left = bottom_left
55 |         self.top_right = top_right
56 |         self.center = (bottom_left + top_right) / 2
57 | 
58 |         # As long as buckets is None, there are no subdivisions
59 |         self.buckets = None
60 |         self.elements = []
61 | 
62 |         if max_elements_per_box is None:
63 |             dimensions, = self.bottom_left.shape
64 |             max_elements_per_box = 8 * 2**dimensions
65 | 
66 |         self.max_elements_per_box = max_elements_per_box
67 | 
68 |     def insert(self, element, bbox):
69 |         """Insert an element into the spatial tree.
70 | 
71 |         :param element: the element to be stored in the retrieval data
72 |             structure. It is treated as opaque and no assumptions are made on it.
73 | 
74 |         :param bbox: A bounding box supplied as a tuple *lower_left,
75 |             upper_right* of :mod:`numpy` vectors, such that *(lower_left <=
76 |             upper_right).all()*.
77 | 
78 |         Despite these names, the bounding box (and this entire data structure)
79 |         may be of any dimension.
80 |         """
81 | 
82 |         def insert_into_subdivision(element, bbox):
83 |             bucket_matches = [
84 |                 ibucket
85 |                 for ibucket, bucket in enumerate(self.all_buckets)
86 |                 if do_boxes_intersect((bucket.bottom_left, bucket.top_right), bbox)]
87 | 
88 |             from random import uniform
89 |             if len(bucket_matches) > len(self.all_buckets) // 2:
90 |                 # Would go into more than half of all buckets--keep it here
91 |                 self.elements.append((element, bbox))
92 |             elif len(bucket_matches) > 1 and uniform(0, 1) > 0.95:
93 |                 # Would go into more than one bucket and therefore may recurse
94 |                 # indefinitely. Keep it here with a low probability.
95 |                 self.elements.append((element, bbox))
96 |             else:
97 |                 for ibucket_match in bucket_matches:
98 |                     self.all_buckets[ibucket_match].insert(element, bbox)
99 | 
100 |         if self.buckets is None:
101 |             # No subdivisions yet.
102 |             if len(self.elements) > self.max_elements_per_box:
103 |                 # Too many elements. Need to subdivide.
104 |                 self.all_buckets = []
105 |                 self.buckets = make_buckets(
106 |                     self.bottom_left, self.top_right,
107 |                     self.all_buckets,
108 |                     max_elements_per_box=self.max_elements_per_box)
109 | 
110 |                 old_elements = self.elements
111 |                 self.elements = []
112 | 
113 |                 # Move all elements from the full bucket into the new finer ones
114 |                 for el, el_bbox in old_elements:
115 |                     insert_into_subdivision(el, el_bbox)
116 | 
117 |                 insert_into_subdivision(element, bbox)
118 |             else:
119 |                 # Simple case: still below capacity, so just store it here.
120 |                 self.elements.append((element, bbox))
121 |         else:
122 |             # Go find which subdivision to place the element in.
123 |             insert_into_subdivision(element, bbox)
124 | 
125 |     def generate_matches(self, point):
126 |         if self.buckets:
127 |             # We have subdivisions. Use them.
128 |             (dimensions,) = point.shape
129 |             bucket = self.buckets
130 |             for dim in range(dimensions):
131 |                 bucket = bucket[0] if point[dim] < self.center[dim] else bucket[1]
132 | 
133 |             yield from bucket.generate_matches(point)
134 | 
135 |         # Perform linear search.
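        # (This runs in both the subdivided and the leaf case: elements kept
        # directly on this bucket--e.g. ones straddling many child boxes--must
        # always be checked in addition to any matches from child buckets.)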
136 | for el, _ in self.elements: 137 | yield el 138 | 139 | def visualize(self, file): 140 | file.write(f"{self.bottom_left[0]:f} {self.bottom_left[1]:f}\n") 141 | file.write(f"{self.top_right[0]:f} {self.bottom_left[1]:f}\n") 142 | file.write(f"{self.top_right[0]:f} {self.top_right[1]:f}\n") 143 | file.write(f"{self.bottom_left[0]:f} {self.top_right[1]:f}\n") 144 | file.write(f"{self.bottom_left[0]:f} {self.bottom_left[1]:f}\n\n") 145 | if self.buckets: 146 | for i in self.all_buckets: 147 | i.visualize(file) 148 | 149 | def plot(self, **kwargs): 150 | import matplotlib.patches as mpatches 151 | import matplotlib.pyplot as pt 152 | from matplotlib.path import Path 153 | 154 | el = self.bottom_left 155 | eh = self.top_right 156 | pathdata = [ 157 | (Path.MOVETO, (el[0], el[1])), 158 | (Path.LINETO, (eh[0], el[1])), 159 | (Path.LINETO, (eh[0], eh[1])), 160 | (Path.LINETO, (el[0], eh[1])), 161 | (Path.CLOSEPOLY, (el[0], el[1])), 162 | ] 163 | 164 | codes, verts = zip(*pathdata, strict=True) 165 | path = Path(verts, codes) 166 | patch = mpatches.PathPatch(path, **kwargs) 167 | pt.gca().add_patch(patch) 168 | 169 | if self.buckets: 170 | for i in self.all_buckets: 171 | i.plot(**kwargs) 172 | -------------------------------------------------------------------------------- /pytools/stopwatch.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import time 4 | 5 | from pytools import DependentDictionary, Reference 6 | 7 | 8 | class StopWatch: 9 | def __init__(self) -> None: 10 | self.Elapsed = 0.0 11 | self.LastStart: float | None = None 12 | 13 | def start(self) -> StopWatch: 14 | assert self.LastStart is None 15 | 16 | self.LastStart = time.time() 17 | return self 18 | 19 | def stop(self) -> StopWatch: 20 | assert self.LastStart is not None 21 | 22 | self.Elapsed += time.time() - self.LastStart 23 | self.LastStart = None 24 | return self 25 | 26 | def elapsed(self) -> float: 27 | if self.LastStart: 28 | return time.time() - self.LastStart + self.Elapsed 29 | return self.Elapsed 30 | 31 | 32 | class Job: 33 | def __init__(self, name: str) -> None: 34 | self.Name = name 35 | self.StopWatch = StopWatch().start() 36 | 37 | if self.is_visible(): 38 | print(f"{name}...") 39 | 40 | def done(self) -> None: 41 | elapsed = self.StopWatch.elapsed() 42 | 43 | JOB_TIMES[self.Name] += elapsed 44 | if self.is_visible(): 45 | print(" " * (len(self.Name) + 2), elapsed, "seconds") 46 | 47 | def is_visible(self) -> bool: 48 | if PRINT_JOBS.get(): 49 | return self.Name not in HIDDEN_JOBS 50 | return self.Name in VISIBLE_JOBS 51 | 52 | 53 | class EtaEstimator: 54 | def __init__(self, total_steps: int) -> None: 55 | self.stopwatch = StopWatch().start() 56 | self.total_steps = total_steps 57 | assert total_steps > 0 58 | 59 | def estimate(self, done: int) -> float | None: 60 | fraction_done = done / self.total_steps 61 | time_spent = self.stopwatch.elapsed() 62 | 63 | if fraction_done > 1.0e-5: 64 | return time_spent / fraction_done - time_spent 65 | return None 66 | 67 | 68 | def print_job_summary() -> None: 69 | for key, value in JOB_TIMES.iteritems(): 70 | print(key, " " * (50 - len(key)), value) 71 | 72 | 73 | HIDDEN_JOBS: list[str] = [] 74 | VISIBLE_JOBS: list[str] = [] 75 | JOB_TIMES = DependentDictionary(lambda x: 0) 76 | PRINT_JOBS = Reference(True) 77 | -------------------------------------------------------------------------------- /pytools/tag.py: -------------------------------------------------------------------------------- 
1 | """ 2 | Tag Interface 3 | --------------- 4 | .. ``normalize_tags`` undocumented for now. (Not ready to commit.) 5 | 6 | .. autofunction:: check_tag_uniqueness 7 | .. autoclass:: Taggable 8 | .. autoclass:: Tag 9 | .. autoclass:: UniqueTag 10 | 11 | Supporting Functionality 12 | ------------------------ 13 | 14 | .. autoclass:: DottedName 15 | .. autoclass:: NonUniqueTagError 16 | 17 | 18 | Internal stuff that is only here because the documentation tool wants it 19 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 20 | 21 | .. class:: TagT 22 | 23 | A type variable with lower bound :class:`Tag`. 24 | """ 25 | 26 | from __future__ import annotations 27 | 28 | from collections.abc import Iterable 29 | from dataclasses import dataclass, field 30 | from typing import TYPE_CHECKING, Any, TypeVar 31 | from warnings import warn 32 | 33 | from typing_extensions import Self, dataclass_transform, override 34 | 35 | from pytools import memoize, memoize_method 36 | 37 | 38 | __copyright__ = """ 39 | Copyright (C) 2020 Andreas Klöckner 40 | Copyright (C) 2020 Matt Wala 41 | Copyright (C) 2020 Xiaoyu Wei 42 | Copyright (C) 2020 Nicholas Christensen 43 | """ 44 | 45 | __license__ = """ 46 | Permission is hereby granted, free of charge, to any person obtaining a copy 47 | of this software and associated documentation files (the "Software"), to deal 48 | in the Software without restriction, including without limitation the rights 49 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 50 | copies of the Software, and to permit persons to whom the Software is 51 | furnished to do so, subject to the following conditions: 52 | 53 | The above copyright notice and this permission notice shall be included in 54 | all copies or substantial portions of the Software. 55 | 56 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 57 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 58 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 59 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 60 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 61 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 62 | THE SOFTWARE. 63 | """ 64 | 65 | 66 | # {{{ dotted name 67 | 68 | class DottedName: 69 | """ 70 | .. attribute:: name_parts 71 | 72 | A tuple of strings, each of which is a valid 73 | Python identifier. No name part may start with 74 | a double underscore. 75 | 76 | The name (at least morally) exists in the 77 | name space defined by the Python module system. 78 | It need not necessarily identify an importable 79 | object. 80 | 81 | .. 
automethod:: from_class
82 |     """
83 | 
84 |     def __init__(self, name_parts: tuple[str, ...]) -> None:
85 |         if len(name_parts) == 0:
86 |             raise ValueError("empty name parts")
87 | 
88 |         for p in name_parts:
89 |             if not p.isidentifier():
90 |                 raise ValueError(f"{p} is not a Python identifier")
91 | 
92 |         self.name_parts = name_parts
93 | 
94 |     @classmethod
95 |     def from_class(cls, argcls: Any) -> DottedName:
96 |         name_parts = tuple(
97 |                 [str(part) for part in argcls.__module__.split(".")]
98 |                 + [str(argcls.__name__)])
99 |         if any(npart.startswith("__") for npart in name_parts):
100 |             raise ValueError(f"some name parts of {'.'.join(name_parts)} "
101 |                              "start with double underscores")
102 |         return cls(name_parts)
103 | 
104 |     @override
105 |     def __repr__(self) -> str:
106 |         return self.__class__.__name__ + repr(self.name_parts)
107 | 
108 |     @override
109 |     def __eq__(self, other: object) -> bool:
110 |         if isinstance(other, DottedName):
111 |             return self.name_parts == other.name_parts
112 |         return False
113 | 
114 | 
115 | # }}}
116 | 
117 | 
118 | # {{{ tag
119 | 
120 | T = TypeVar("T")
121 | 
122 | 
123 | @dataclass_transform(eq_default=True, frozen_default=True, field_specifiers=(field,))
124 | def tag_dataclass(cls: type[T]) -> type[T]:
125 |     return dataclass(init=True, frozen=True, eq=True, repr=True)(cls)
126 | 
127 | 
128 | @tag_dataclass
129 | class Tag:
130 |     """
131 |     Generic metadata, applied to, among other things,
132 |     pytato Arrays.
133 | 
134 |     .. attribute:: tag_name
135 | 
136 |         A fully qualified :class:`DottedName` that reflects
137 |         the class name of the tag.
138 | 
139 |     Instances of this type must be immutable, hashable,
140 |     picklable, and have a reasonably concise :meth:`__repr__`
141 |     of the form ``dotted.name(attr1=value1, attr2=value2)``.
142 |     Positional arguments are not allowed.
143 | 
144 |     .. automethod:: __repr__
145 |     """
146 | 
147 |     @property
148 |     def tag_name(self) -> DottedName:
149 |         return DottedName.from_class(type(self))
150 | 
151 | # }}}
152 | 
153 | 
154 | # {{{ unique tag
155 | 
156 | @tag_dataclass
157 | class UniqueTag(Tag):
158 |     """
159 |     A superclass for tags that are unique on each :class:`Taggable`.
160 | 
161 |     Each instance of :class:`Taggable` may have no more than one
162 |     instance of each subclass of :class:`UniqueTag` in its
163 |     set of `tags`. Multiple `UniqueTag` instances of
164 |     different (immediate) subclasses are allowed.
165 |     """
166 | 
167 | # }}}
168 | 
169 | 
170 | ToTagSetConvertible = Iterable[Tag] | Tag | None
171 | TagT = TypeVar("TagT", bound="Tag")
172 | 
173 | 
174 | # {{{ UniqueTag rules checking
175 | 
176 | @memoize
177 | def _immediate_unique_tag_descendants(cls: type[Tag]) -> frozenset[type[Tag]]:
178 |     if UniqueTag in cls.__bases__:
179 |         return frozenset([cls])
180 |     result: frozenset[type[Tag]] = frozenset()
181 |     for base in cls.__bases__:
182 |         result = result | _immediate_unique_tag_descendants(base)
183 |     return result
184 | 
185 | 
186 | class NonUniqueTagError(ValueError):
187 |     """
188 |     Raised when a :class:`Taggable` object is instantiated with more
189 |     than one :class:`UniqueTag` instance of the same subclass in
190 |     its set of tags.
191 |     """
192 | 
193 | 
194 | def check_tag_uniqueness(tags: frozenset[Tag]) -> frozenset[Tag]:
195 |     """Ensure that *tags* obeys the rules set forth in :class:`UniqueTag`.
196 |     If not, raise :exc:`NonUniqueTagError`. If any *tags* are not
197 |     instances of :class:`Tag`, a :exc:`TypeError` will be raised.
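
    For instance, with a hypothetical tag class ``FooAxisTag`` that directly
    subclasses :class:`UniqueTag`, two distinct instances of it may not
    coexist in one tag set::

        check_tag_uniqueness(frozenset({FooAxisTag(0), FooAxisTag(1)}))
        # raises NonUniqueTagError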
198 | 199 | :returns: *tags* 200 | """ 201 | unique_tag_descendants: set[type[Tag]] = set() 202 | for tag in tags: 203 | if not isinstance(tag, Tag): 204 | raise TypeError(f"'{tag}' is not an instance of pytools.tag.Tag") 205 | tag_unique_tag_descendants = _immediate_unique_tag_descendants( 206 | type(tag)) 207 | intersection = unique_tag_descendants & tag_unique_tag_descendants 208 | if intersection: 209 | raise NonUniqueTagError("Multiple tags are direct subclasses of " 210 | "the following UniqueTag(s): " 211 | f"{', '.join(d.__name__ for d in intersection)}") 212 | unique_tag_descendants.update(tag_unique_tag_descendants) 213 | 214 | return tags 215 | 216 | # }}} 217 | 218 | 219 | def normalize_tags(tags: ToTagSetConvertible) -> frozenset[Tag]: 220 | if isinstance(tags, Tag): 221 | tags = frozenset([tags]) 222 | elif tags is None: 223 | tags = frozenset() 224 | else: 225 | tags = frozenset(tags) 226 | return tags 227 | 228 | 229 | # {{{ taggable 230 | 231 | class Taggable: 232 | """ 233 | Parent class for objects with a `tags` attribute. 234 | 235 | .. autoattribute:: tags 236 | 237 | .. automethod:: _with_new_tags 238 | .. automethod:: tagged 239 | .. automethod:: without_tags 240 | .. automethod:: tags_of_type 241 | .. automethod:: tags_not_of_type 242 | 243 | .. versionadded:: 2021.1 244 | """ 245 | 246 | if not TYPE_CHECKING: 247 | def __init__(self, tags: frozenset[Tag] = frozenset()): 248 | warn("The Taggable constructor is deprecated. " 249 | "Subclasses must declare their own storage for .tags. " 250 | "The constructor will disappear in 2025.x.", 251 | DeprecationWarning, stacklevel=2) 252 | 253 | self.tags = tags 254 | 255 | # ReST references in docstrings must be fully qualified, as docstrings may 256 | # be inherited and appear in different contexts. 257 | 258 | # Before https://peps.python.org/pep-0767/, there isn't a good way to declare 259 | # an attribute read-only. Mypy accepts @property, but this is unsound: 260 | # https://peps.python.org/pep-0767/#clarifying-interaction-of-property-and-protocols 261 | tags: frozenset[Tag] # pyright: ignore[reportUninitializedInstanceVariable] 262 | 263 | def _with_new_tags(self, tags: frozenset[Tag]) -> Self: 264 | """ 265 | Returns a copy of *self* with the specified tags. This method 266 | should be overridden by subclasses. 267 | """ 268 | raise NotImplementedError 269 | 270 | def tagged(self, tags: ToTagSetConvertible) -> Self: 271 | """ 272 | Return a copy of *self* with the specified 273 | tag or tags added to the set of tags. If the resulting set of 274 | tags violates the rules on :class:`pytools.tag.UniqueTag`, 275 | an error is raised. 276 | 277 | :arg tags: An instance of :class:`~pytools.tag.Tag` or 278 | an iterable with instances therein. 279 | """ 280 | normalized = normalize_tags(tags) 281 | new_tags = check_tag_uniqueness(normalized | self.tags) 282 | if normalized <= self.tags: 283 | return self 284 | else: 285 | return self._with_new_tags(tags=new_tags) 286 | 287 | def without_tags(self, 288 | tags: ToTagSetConvertible, verify_existence: bool = True 289 | ) -> Self: 290 | """ 291 | Return a copy of *self* without the specified tags. 292 | 293 | :arg tags: An instance of :class:`~pytools.tag.Tag` or an iterable with 294 | instances therein. 295 | :arg verify_existence: If set to `True`, this method raises 296 | an exception if not all tags specified for removal are 297 | present in the original set of tags. Default `True`. 
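
        For example (a sketch, with a hypothetical ``FooTag``)::

            tagged_obj.without_tags(FooTag())
            # removes it; raises ValueError if FooTag() was not present

            tagged_obj.without_tags(FooTag(), verify_existence=False)
            # returns tagged_obj unchanged if FooTag() is absent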
298 | """ 299 | 300 | to_remove = normalize_tags(tags) 301 | new_tags = self.tags - to_remove 302 | if verify_existence and len(new_tags) > len(self.tags) - len(to_remove): 303 | raise ValueError("A tag specified for removal was not present.") 304 | 305 | if to_remove & self.tags: 306 | return self._with_new_tags(tags=check_tag_uniqueness(new_tags)) 307 | else: 308 | return self 309 | 310 | @memoize_method 311 | def tags_of_type(self, tag_t: type[TagT]) -> frozenset[TagT]: 312 | """ 313 | Returns *self*'s tags of type *tag_t*. 314 | """ 315 | return frozenset({tag 316 | for tag in self.tags 317 | if isinstance(tag, tag_t)}) 318 | 319 | @memoize_method 320 | def tags_not_of_type(self, tag_t: type[TagT]) -> frozenset[Tag]: 321 | """ 322 | Returns *self*'s tags that are not of type *tag_t*. 323 | """ 324 | return frozenset({tag 325 | for tag in self.tags 326 | if not isinstance(tag, tag_t)}) 327 | 328 | @override 329 | def __eq__(self, other: object) -> bool: 330 | if isinstance(other, Taggable): 331 | return self.tags == other.tags 332 | return super().__eq__(other) 333 | 334 | @override 335 | def __hash__(self) -> int: 336 | return hash(self.tags) 337 | 338 | # }}} 339 | 340 | 341 | # {{{ deprecation 342 | 343 | _depr_name_to_replacement_and_obj = { 344 | "TagsType": ( 345 | "frozenset[Tag]", 346 | frozenset[Tag], 2023), 347 | "TagOrIterableType": ( 348 | "ToTagSetConvertible", 349 | ToTagSetConvertible, 2023), 350 | "T_co": ( 351 | "Self (i.e. the self type from Python 3.11)", 352 | TypeVar("TaggableT", bound="Taggable"), 2023), 353 | } 354 | 355 | 356 | def __getattr__(name: str) -> Any: 357 | replacement_and_obj = _depr_name_to_replacement_and_obj.get(name) 358 | if replacement_and_obj is not None: 359 | replacement, obj, year = replacement_and_obj 360 | from warnings import warn 361 | warn(f"'pytools.tag.{name}' is deprecated. " 362 | f"Use '{replacement}' instead. 
" 363 | f"'pytools.tag.{name}' will continue to work until {year}.", 364 | DeprecationWarning, stacklevel=2) 365 | return obj 366 | raise AttributeError(name) 367 | 368 | # }}} 369 | 370 | 371 | # vim: foldmethod=marker 372 | -------------------------------------------------------------------------------- /pytools/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inducer/pytools/2be5f3f72670830454a4687ca8bb8dff0e214b97/pytools/test/__init__.py -------------------------------------------------------------------------------- /pytools/test/test_data_table.py: -------------------------------------------------------------------------------- 1 | # data from Wikipedia "join" article 2 | from __future__ import annotations 3 | 4 | 5 | def get_dept_table(): 6 | from pytools.datatable import DataTable 7 | dept_table = DataTable(["id", "name"]) 8 | dept_table.insert_row((31, "Sales")) 9 | dept_table.insert_row((33, "Engineering")) 10 | dept_table.insert_row((34, "Clerical")) 11 | dept_table.insert_row((35, "Marketing")) 12 | return dept_table 13 | 14 | 15 | def get_employee_table(): 16 | from pytools.datatable import DataTable 17 | employee_table = DataTable(["lastname", "dept"]) 18 | employee_table.insert_row(("Rafferty", 31)) 19 | employee_table.insert_row(("Jones", 33)) 20 | employee_table.insert_row(("Jasper", 36)) 21 | employee_table.insert_row(("Steinberg", 33)) 22 | employee_table.insert_row(("Robinson", 34)) 23 | employee_table.insert_row(("Smith", 34)) 24 | return employee_table 25 | 26 | 27 | def test_len(): 28 | et = get_employee_table() 29 | assert len(et) == 6 30 | 31 | 32 | def test_iter(): 33 | et = get_employee_table() 34 | 35 | count = 0 36 | for row in et: 37 | count += 1 38 | assert len(row) == 2 39 | 40 | assert count == 6 41 | 42 | 43 | def test_insert_and_get(): 44 | et = get_employee_table() 45 | et.insert(dept=33, lastname="Kloeckner") 46 | assert et.get(lastname="Kloeckner").dept == 33 47 | 48 | 49 | def test_filtered(): 50 | et = get_employee_table() 51 | assert len(et.filtered(dept=33)) == 2 52 | assert len(et.filtered(dept=34)) == 2 53 | 54 | 55 | def test_sort(): 56 | et = get_employee_table() 57 | et.sort(["lastname"]) 58 | assert et.column_data("dept") == [36, 33, 31, 34, 34, 33] 59 | 60 | 61 | def test_aggregate(): 62 | et = get_employee_table() 63 | et.sort(["dept"]) 64 | agg = et.aggregated(["dept"], "lastname", ",".join) 65 | assert len(agg) == 4 66 | for dept, lastnames in agg: 67 | lastnames = lastnames.split(",") 68 | for lastname in lastnames: 69 | assert et.get(lastname=lastname).dept == dept 70 | 71 | 72 | def test_aggregate_2(): 73 | from pytools.datatable import DataTable 74 | tbl = DataTable(["step", "value"], list(zip(range(20), range(20), strict=True))) 75 | agg = tbl.aggregated(["step"], "value", max) 76 | assert agg.column_data("step") == list(range(20)) 77 | assert agg.column_data("value") == list(range(20)) 78 | 79 | 80 | def test_join(): 81 | et = get_employee_table() 82 | dt = get_dept_table() 83 | 84 | et.sort(["dept"]) 85 | dt.sort(["id"]) 86 | 87 | inner_joined = et.join("dept", "id", dt) 88 | assert len(inner_joined) == len(et)-1 89 | for dept, lastname, deptname in inner_joined: 90 | dept_id = et.get(lastname=lastname).dept 91 | assert dept_id == dept 92 | assert dt.get(id=dept_id).name == deptname 93 | 94 | outer_joined = et.join("dept", "id", dt, outer=True) 95 | assert len(outer_joined) == len(et)+1 96 | 
-------------------------------------------------------------------------------- /pytools/test/test_dataclasses.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | 4 | __copyright__ = "Copyright (C) 2024 University of Illinois Board of Trustees" 5 | 6 | __license__ = """ 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in 15 | all copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 23 | THE SOFTWARE. 24 | """ 25 | 26 | 27 | import sys 28 | 29 | import pytest 30 | 31 | from pytools import opt_frozen_dataclass 32 | 33 | 34 | def test_opt_frozen_dataclass() -> None: 35 | # {{{ basic usage 36 | 37 | @opt_frozen_dataclass() 38 | class A: 39 | x: int 40 | 41 | a = A(1) 42 | assert a.x == 1 43 | 44 | # Needs to be hashable by default, not using object.__hash__ 45 | hash(a) 46 | assert hash(a) == hash(A(1)) 47 | assert a == A(1) 48 | 49 | # Needs to be frozen by default 50 | if __debug__: 51 | with pytest.raises(AttributeError): 52 | a.x = 2 # type: ignore[misc] 53 | else: 54 | a.x = 2 # type: ignore[misc] 55 | 56 | assert a.__dataclass_params__.frozen is __debug__ # type: ignore[attr-defined] # pylint: disable=no-member 57 | 58 | # }}} 59 | 60 | with pytest.raises(TypeError): 61 | # Can't specify frozen parameter 62 | @opt_frozen_dataclass(frozen=False) # type: ignore[call-arg] # pylint: disable=unexpected-keyword-arg 63 | class B: 64 | x: int 65 | 66 | # {{{ eq=False 67 | 68 | @opt_frozen_dataclass(eq=False) 69 | class C: 70 | x: int 71 | 72 | c = C(1) 73 | 74 | # Hashing still works, but uses object.__hash__ (i.e., id()) 75 | assert hash(c) != hash(C(1)) 76 | 77 | # Equality is not defined and uses id() 78 | assert c != C(1) 79 | 80 | # }}} 81 | 82 | 83 | def test_dataclass_weakref() -> None: 84 | if sys.version_info < (3, 11): 85 | pytest.skip("weakref support needs Python 3.11+") 86 | 87 | @opt_frozen_dataclass(weakref_slot=True, slots=True) 88 | class Weakref: 89 | x: int 90 | 91 | a = Weakref(1) 92 | assert a.x == 1 93 | 94 | import weakref 95 | ref = weakref.ref(a) 96 | 97 | _ = ref().x 98 | 99 | with pytest.raises(TypeError): 100 | @opt_frozen_dataclass(weakref_slot=True) # needs slots=True to work 101 | class Weakref2: 102 | x: int 103 | 104 | 105 | if __name__ == "__main__": 106 | if len(sys.argv) > 1: 107 | exec(sys.argv[1]) 108 | else: 109 | from pytest import main 110 | main([__file__]) 111 | -------------------------------------------------------------------------------- /pytools/test/test_graph_tools.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import sys 4 | 5 | import pytest 6 | 7 | 8 | def test_compute_sccs(): 9 | import random 10 | 11 | from pytools.graph import compute_sccs 12 | 13 | rng = random.Random(0) 14 | 15 | def generate_random_graph(nnodes): 16 | graph = {i: set() for i in range(nnodes)} 17 | for i in range(nnodes): 18 | for j in range(nnodes): 19 | # Edge probability 2/n: Generates decently interesting inputs. 20 | if rng.randint(0, nnodes - 1) <= 1: 21 | graph[i].add(j) 22 | return graph 23 | 24 | def verify_sccs(graph, sccs): 25 | visited = set() 26 | 27 | def visit(node): 28 | if node in visited: 29 | return [] 30 | visited.add(node) 31 | result = [] 32 | for child in graph[node]: 33 | result = result + visit(child) 34 | return [*result, node] 35 | 36 | for scc in sccs: 37 | scc = set(scc) 38 | assert not scc & visited 39 | # Check that starting from each element of the SCC results 40 | # in the same set of reachable nodes. 41 | for scc_root in scc: 42 | visited.difference_update(scc) 43 | result = visit(scc_root) 44 | assert set(result) == scc, (set(result), scc) 45 | 46 | for nnodes in range(10, 20): 47 | for _ in range(40): 48 | graph = generate_random_graph(nnodes) 49 | verify_sccs(graph, compute_sccs(graph)) 50 | 51 | 52 | def test_compute_topological_order(): 53 | from pytools.graph import CycleError, compute_topological_order 54 | 55 | empty = {} 56 | assert compute_topological_order(empty) == [] 57 | 58 | disconnected = {1: [], 2: [], 3: []} 59 | assert len(compute_topological_order(disconnected)) == 3 60 | 61 | line = list(zip(range(10), ([i] for i in range(1, 11)), strict=True)) 62 | import random 63 | random.seed(0) 64 | random.shuffle(line) 65 | expected = list(range(11)) 66 | assert compute_topological_order(dict(line)) == expected 67 | 68 | claw = {1: [2, 3], 0: [1]} 69 | assert compute_topological_order(claw)[:2] == [0, 1] 70 | 71 | repeated_edges = {1: [2, 2], 2: [0]} 72 | assert compute_topological_order(repeated_edges) == [1, 2, 0] 73 | 74 | self_cycle = {1: [1]} 75 | with pytest.raises(CycleError): 76 | compute_topological_order(self_cycle) 77 | 78 | cycle = {0: [2], 1: [2], 2: [3], 3: [4, 1]} 79 | with pytest.raises(CycleError): 80 | compute_topological_order(cycle) 81 | 82 | 83 | def test_transitive_closure(): 84 | from pytools.graph import compute_transitive_closure 85 | 86 | # simple test 87 | graph = { 88 | 1: {2}, 89 | 2: {3}, 90 | 3: {4}, 91 | 4: set(), 92 | } 93 | 94 | expected_closure = { 95 | 1: {2, 3, 4}, 96 | 2: {3, 4}, 97 | 3: {4}, 98 | 4: set(), 99 | } 100 | 101 | closure = compute_transitive_closure(graph) 102 | 103 | assert closure == expected_closure 104 | 105 | # test with branches that reconnect 106 | graph = { 107 | 1: {2}, 108 | 2: set(), 109 | 3: {1}, 110 | 4: {1}, 111 | 5: {6, 7}, 112 | 6: {7}, 113 | 7: {1}, 114 | 8: {3, 4}, 115 | } 116 | 117 | expected_closure = { 118 | 1: {2}, 119 | 2: set(), 120 | 3: {1, 2}, 121 | 4: {1, 2}, 122 | 5: {1, 2, 6, 7}, 123 | 6: {1, 2, 7}, 124 | 7: {1, 2}, 125 | 8: {1, 2, 3, 4}, 126 | } 127 | 128 | closure = compute_transitive_closure(graph) 129 | 130 | assert closure == expected_closure 131 | 132 | # test with cycles 133 | graph = { 134 | 1: {2}, 135 | 2: {3}, 136 | 3: {4}, 137 | 4: {1}, 138 | } 139 | 140 | expected_closure = { 141 | 1: {1, 2, 3, 4}, 142 | 2: {1, 2, 3, 4}, 143 | 3: {1, 2, 3, 4}, 144 | 4: {1, 2, 3, 4}, 145 | } 146 | 147 | closure = compute_transitive_closure(graph) 148 | 149 | assert closure == 
expected_closure 150 | 151 | 152 | def test_graph_cycle_finder(): 153 | 154 | from pytools.graph import contains_cycle 155 | 156 | graph = { 157 | "a": {"b", "c"}, 158 | "b": {"d", "e"}, 159 | "c": {"d", "f"}, 160 | "d": set(), 161 | "e": set(), 162 | "f": {"g"}, 163 | "g": set(), 164 | } 165 | 166 | assert not contains_cycle(graph) 167 | 168 | graph = { 169 | "a": {"b", "c"}, 170 | "b": {"d", "e"}, 171 | "c": {"d", "f"}, 172 | "d": set(), 173 | "e": set(), 174 | "f": {"g"}, 175 | "g": {"a"}, 176 | } 177 | 178 | assert contains_cycle(graph) 179 | 180 | graph = { 181 | "a": {"a", "c"}, 182 | "b": {"d", "e"}, 183 | "c": {"d", "f"}, 184 | "d": set(), 185 | "e": set(), 186 | "f": {"g"}, 187 | "g": set(), 188 | } 189 | 190 | assert contains_cycle(graph) 191 | 192 | graph = { 193 | "a": {"a"}, 194 | } 195 | 196 | assert contains_cycle(graph) 197 | 198 | 199 | def test_induced_subgraph(): 200 | 201 | from pytools.graph import compute_induced_subgraph 202 | 203 | graph = { 204 | "a": {"b", "c"}, 205 | "b": {"d", "e"}, 206 | "c": {"d", "f"}, 207 | "d": set(), 208 | "e": set(), 209 | "f": {"g"}, 210 | "g": {"h", "i", "j"}, 211 | } 212 | 213 | node_subset = {"b", "c", "e", "f", "g"} 214 | 215 | expected_subgraph = { 216 | "b": {"e"}, 217 | "c": {"f"}, 218 | "e": set(), 219 | "f": {"g"}, 220 | "g": set(), 221 | } 222 | 223 | subgraph = compute_induced_subgraph(graph, node_subset) 224 | 225 | assert subgraph == expected_subgraph 226 | 227 | 228 | def test_prioritized_topological_sort_examples(): 229 | 230 | from pytools.graph import compute_topological_order 231 | 232 | keys = {"a": 4, "b": 3, "c": 2, "e": 1, "d": 4} 233 | dag = { 234 | "a": ["b", "c"], 235 | "b": [], 236 | "c": ["d", "e"], 237 | "d": [], 238 | "e": []} 239 | 240 | assert compute_topological_order(dag, key=keys.get) == [ 241 | "a", "c", "e", "b", "d"] 242 | 243 | keys = {"a": 7, "b": 2, "c": 1, "d": 0} 244 | dag = { 245 | "d": set("c"), 246 | "b": set("a"), 247 | "a": set(), 248 | "c": set("a"), 249 | } 250 | 251 | assert compute_topological_order(dag, key=keys.get) == ["d", "c", "b", "a"] 252 | 253 | 254 | def test_prioritized_topological_sort(): 255 | 256 | import random 257 | 258 | from pytools.graph import compute_topological_order 259 | rng = random.Random(0) 260 | 261 | def generate_random_graph(nnodes): 262 | graph = {i: set() for i in range(nnodes)} 263 | for i in range(nnodes): 264 | # to avoid cycles only consider edges node_i->node_j where j > i. 265 | for j in range(i+1, nnodes): 266 | # Edge probability 4/n: Generates decently interesting inputs. 
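                # (Strictly, `randint(0, nnodes - 1) <= 2` hits three of the
                # nnodes equally likely values, i.e. probability 3/nnodes.)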
267 | if rng.randint(0, nnodes - 1) <= 2: 268 | graph[i].add(j) 269 | return graph 270 | 271 | nnodes = rng.randint(40, 100) 272 | rev_dep_graph = generate_random_graph(nnodes) 273 | dep_graph = {i: set() for i in range(nnodes)} 274 | 275 | for i in range(nnodes): 276 | for rev_dep in rev_dep_graph[i]: 277 | dep_graph[rev_dep].add(i) 278 | 279 | keys = [rng.random() for _ in range(nnodes)] 280 | topo_order = compute_topological_order(rev_dep_graph, key=keys.__getitem__) 281 | 282 | for scheduled_node in topo_order: 283 | nodes_with_no_deps = {node for node, deps in dep_graph.items() 284 | if len(deps) == 0} 285 | 286 | # check whether the order is a valid topological order 287 | assert scheduled_node in nodes_with_no_deps 288 | # check whether priorities are upheld 289 | assert keys[scheduled_node] == min( 290 | keys[node] for node in nodes_with_no_deps) 291 | 292 | # 'scheduled_node' is scheduled => no longer a dependency 293 | dep_graph.pop(scheduled_node) 294 | 295 | for deps in dep_graph.values(): 296 | deps.discard(scheduled_node) 297 | 298 | assert len(dep_graph) == 0 299 | 300 | 301 | def test_as_graphviz_dot(): 302 | graph = {"A": ["B", "C"], 303 | "B": [], 304 | "C": ["A"]} 305 | 306 | from pytools.graph import NodeT, as_graphviz_dot 307 | 308 | def edge_labels(n1: NodeT, n2: NodeT) -> str: 309 | if n1 == "A" and n2 == "B": 310 | return "foo" 311 | 312 | return "" 313 | 314 | def node_labels(node: NodeT) -> str: 315 | if node == "A": 316 | return "foonode" 317 | 318 | return str(node) 319 | 320 | res = as_graphviz_dot(graph, node_labels=node_labels, edge_labels=edge_labels) 321 | 322 | assert res == \ 323 | """digraph mygraph { 324 | mynodeid [label="foonode"]; 325 | mynodeid_0 [label="B"]; 326 | mynodeid_1 [label="C"]; 327 | mynodeid -> mynodeid_0 [label="foo"]; 328 | mynodeid -> mynodeid_1 [label=""]; 329 | mynodeid_1 -> mynodeid [label=""]; 330 | } 331 | """ 332 | 333 | 334 | def test_reverse_graph(): 335 | graph = { 336 | "a": frozenset(("b", "c")), 337 | "b": frozenset(("d", "e")), 338 | "c": frozenset(("d", "f")), 339 | "d": frozenset(), 340 | "e": frozenset(), 341 | "f": frozenset(("g",)), 342 | "g": frozenset(("h", "i", "j")), 343 | "h": frozenset(), 344 | "i": frozenset(), 345 | "j": frozenset(), 346 | } 347 | 348 | from pytools.graph import reverse_graph 349 | assert graph == reverse_graph(reverse_graph(graph)) 350 | 351 | 352 | def test_validate_graph(): 353 | from pytools.graph import validate_graph 354 | graph1 = { 355 | "d": set("c"), 356 | "b": set("a"), 357 | "a": set(), 358 | "c": set("a"), 359 | } 360 | 361 | validate_graph(graph1) 362 | 363 | graph2 = { 364 | "d": set("d"), 365 | "b": set("c"), 366 | "a": set("b"), 367 | "c": set("a"), 368 | } 369 | 370 | validate_graph(graph2) 371 | 372 | graph3 = { 373 | "a": {"b", "c"}, 374 | "b": {"d", "e"}, 375 | "c": {"d", "f"}, 376 | "d": set(), 377 | "e": set(), 378 | "f": {"g"}, 379 | "g": {"h", "i", "j"}, # h, i, j missing from keys 380 | } 381 | 382 | with pytest.raises(ValueError): 383 | validate_graph(graph3) 384 | 385 | validate_graph({}) 386 | 387 | 388 | def test_is_connected(): 389 | from pytools.graph import is_connected 390 | graph1 = { 391 | "d": set("c"), 392 | "b": set("a"), 393 | "a": set(), 394 | "c": set("a"), 395 | } 396 | 397 | assert is_connected(graph1) 398 | 399 | graph2 = { 400 | "d": set("d"), 401 | "b": set("c"), 402 | "a": set("b"), 403 | "c": set("a"), 404 | } 405 | 406 | assert not is_connected(graph2) 407 | 408 | graph3 = { 409 | "a": {"b", "c"}, 410 | "b": {"d", "e"}, 411 | "c": {"d", "f"}, 
412 |         "d": set(),
413 |         "e": set(),
414 |         "f": {"g"},
415 |         "g": set(),
416 |     }
417 | 
418 |     assert is_connected(graph3)
419 | 
420 |     graph4 = {
421 |         "a": {"c"},
422 |         "b": {"d", "e"},
423 |         "c": {"f"},
424 |         "d": set(),
425 |         "e": set(),
426 |         "f": {"g"},
427 |         "g": set(),
428 |     }
429 | 
430 |     assert not is_connected(graph4)
431 | 
432 |     assert is_connected({})
433 | 
434 | 
435 | def test_propagation_graph_tools():
436 |     from pytools.graph import (
437 |         get_reachable_nodes,
438 |         undirected_graph_from_edges,
439 |     )
440 |     vars = {"a", "b", "c", "d", "e", "f", "g"}
441 | 
442 |     constraints = {
443 |         ("a", "b"),
444 |         ("b", "c"),
445 |         ("b", "d"),
446 |         ("c", "e"),
447 |         ("d", "f"),
448 |         ("e", "g"),
449 |         ("g", "f"),
450 |         ("f", "g")
451 |     }
452 | 
453 |     all_reachable_nodes = {
454 |         "a": frozenset({"a", "b"}),
455 |         "b": frozenset({"a", "b"}),
456 |         "c": frozenset(),
457 |         "d": frozenset(),
458 |         "e": frozenset({"e", "f", "g"}),
459 |         "f": frozenset({"e", "f", "g"}),
460 |         "g": frozenset({"e", "f", "g"})
461 |     }
462 | 
463 |     exclude_nodes = {"d", "c"}
464 |     propagation_graph = undirected_graph_from_edges(constraints)
465 | 
466 |     assert all(
467 |         all_reachable_nodes[var] == get_reachable_nodes(propagation_graph, var,
468 |                                                         exclude_nodes)
469 |         for var in vars
470 |     )
471 | 
472 | 
473 | if __name__ == "__main__":
474 |     if len(sys.argv) > 1:
475 |         exec(sys.argv[1])
476 |     else:
477 |         from pytest import main
478 |         main([__file__])
479 | 
--------------------------------------------------------------------------------
/pytools/test/test_math_stuff.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | 
4 | def test_variance():
5 |     data = [4, 7, 13, 16]
6 | 
7 |     def naive_var(data):
8 |         n = len(data)
9 |         return ((
10 |             sum(di**2 for di in data)
11 |             - sum(data)**2/n)
12 |             / (n-1))
13 | 
14 |     from pytools import variance
15 |     orig_variance = variance(data, entire_pop=False)
16 | 
17 |     assert abs(naive_var(data) - orig_variance) < 1e-15
18 | 
19 |     data = [1e9 + x for x in data]
20 |     assert abs(variance(data, entire_pop=False) - orig_variance) < 1e-15
--------------------------------------------------------------------------------
/pytools/test/test_mpi.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | 
4 | __copyright__ = "Copyright (C) 2022 University of Illinois Board of Trustees"
5 | 
6 | __license__ = """
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 | 
14 | The above copyright notice and this permission notice shall be included in
15 | all copies or substantial portions of the Software.
16 | 
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 23 | THE SOFTWARE. 24 | """ 25 | 26 | import pytest 27 | 28 | 29 | def test_pytest_raises_on_rank(): 30 | from pytools.mpi import pytest_raises_on_rank 31 | 32 | def fail(my_rank: int, fail_rank: int) -> None: 33 | if my_rank == fail_rank: 34 | raise ValueError("test failure") 35 | 36 | with pytest.raises(ValueError): 37 | fail(0, 0) 38 | 39 | fail(0, 1) 40 | 41 | with pytest_raises_on_rank(0, 0, ValueError): 42 | # Generates an exception, and pytest_raises_on_rank 43 | # expects one. 44 | fail(0, 0) 45 | 46 | with pytest_raises_on_rank(0, 1, ValueError): 47 | # Generates no exception, and pytest_raises_on_rank 48 | # does not expect one. 49 | fail(0, 1) 50 | -------------------------------------------------------------------------------- /pytools/test/test_py_codegen.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pickle 4 | import sys 5 | from typing import cast 6 | 7 | import pytest 8 | 9 | import pytools 10 | import pytools.py_codegen as codegen 11 | 12 | 13 | def test_pickling_with_module_import(): 14 | cg = codegen.PythonCodeGenerator() 15 | cg("import pytools") 16 | cg("import math as m") 17 | 18 | import pickle 19 | mod = pickle.loads(pickle.dumps(cg.get_picklable_module())) 20 | 21 | assert mod.mod_globals["pytools"] is pytools 22 | 23 | import math 24 | assert mod.mod_globals["m"] is math 25 | 26 | 27 | def test_picklable_function(): 28 | cg = codegen.PythonFunctionGenerator("f", args=()) 29 | cg("return 1") 30 | 31 | import pickle 32 | f = pickle.loads(pickle.dumps(cg.get_picklable_function())) 33 | 34 | assert f() == 1 35 | 36 | 37 | def test_function_decorators(capfd): 38 | cg = codegen.PythonFunctionGenerator("f", args=(), decorators=["@staticmethod"]) 39 | cg("return 42") 40 | 41 | assert cg.get_function()() == 42 42 | 43 | cg = codegen.PythonFunctionGenerator("f", args=(), decorators=["@classmethod"]) 44 | cg("return 42") 45 | 46 | with pytest.raises(TypeError): 47 | cg.get_function()() 48 | 49 | cg = codegen.PythonFunctionGenerator("f", args=(), 50 | decorators=["@staticmethod", "@classmethod"]) 51 | cg("return 42") 52 | 53 | with pytest.raises(TypeError): 54 | cg.get_function()() 55 | 56 | cg = codegen.PythonFunctionGenerator("f", args=("x"), 57 | decorators=["from functools import lru_cache", "@lru_cache"]) 58 | cg("print('Hello World!')") 59 | cg("return 42") 60 | 61 | f = cg.get_function() 62 | 63 | assert f(0) == 42 64 | out, _err = capfd.readouterr() 65 | assert out == "Hello World!\n" 66 | 67 | assert f(0) == 42 68 | out, _err = capfd.readouterr() 69 | assert out == "" # second print is not executed due to lru_cache 70 | 71 | 72 | def test_linecache_func() -> None: 73 | cg = codegen.PythonFunctionGenerator("f", args=()) 74 | cg("return 42") 75 | 76 | func = cg.get_function() 77 | func() 78 | 79 | mod_name = func.__code__.co_filename 80 | 81 | import linecache 82 | 83 | assert linecache.getlines(mod_name) == [ 84 | "def f():\n", 85 | " return 42\n", 86 | ] 87 | 88 | assert linecache.getline(mod_name, 1) == "def f():\n" 89 | assert linecache.getline(mod_name, 2) == " return 42\n" 90 | 91 | pkl = pickle.dumps(cg.get_picklable_function()) 92 | 93 | pf = cast("codegen.PicklableFunction", pickle.loads(pkl)) 94 | 95 | 
post_pickle_mod_name = pf._callable.__code__.co_filename 96 | 97 | assert post_pickle_mod_name != mod_name 98 | assert linecache.getlines(post_pickle_mod_name) == [ 99 | "def f():\n", 100 | " return 42\n", 101 | ] 102 | 103 | 104 | def test_linecache_mod() -> None: 105 | cg2 = codegen.PythonCodeGenerator() 106 | cg2("def f():") 107 | cg2(" return 37") 108 | 109 | mod = cg2.get_module() 110 | mod["f"]() 111 | mod_name = cast("str", mod["__code__"].co_filename) 112 | 113 | assert mod_name 114 | 115 | import linecache 116 | assert linecache.getlines(mod_name) == [ 117 | "def f():\n", 118 | " return 37\n", 119 | ] 120 | 121 | 122 | if __name__ == "__main__": 123 | if len(sys.argv) > 1: 124 | exec(sys.argv[1]) 125 | else: 126 | from pytest import main 127 | main([__file__]) 128 | -------------------------------------------------------------------------------- /pytools/test/test_pytools.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | 4 | __copyright__ = "Copyright (C) 2009-2021 Andreas Kloeckner" 5 | 6 | __license__ = """ 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in 15 | all copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 23 | THE SOFTWARE. 
24 | """ 25 | 26 | 27 | import logging 28 | import sys 29 | from dataclasses import dataclass 30 | from typing import TYPE_CHECKING 31 | 32 | import pytest 33 | 34 | from pytools import Record 35 | from pytools.tag import tag_dataclass 36 | 37 | 38 | if TYPE_CHECKING: 39 | from collections.abc import Sequence 40 | 41 | 42 | logger = logging.getLogger(__name__) 43 | 44 | 45 | def test_memoize_method_clear(): 46 | from pytools import memoize_method 47 | 48 | class SomeClass: 49 | def __init__(self): 50 | self.run_count = 0 51 | 52 | @memoize_method 53 | def f(self): 54 | self.run_count += 1 55 | return 17 56 | 57 | sc = SomeClass() 58 | sc.f() 59 | sc.f() 60 | assert sc.run_count == 1 61 | 62 | sc.f.clear_cache(sc) 63 | 64 | 65 | def test_keyed_memoize_method_with_uncached(): 66 | from pytools import keyed_memoize_method 67 | 68 | class SomeClass: 69 | def __init__(self): 70 | self.run_count = 0 71 | 72 | @keyed_memoize_method(key=lambda x, y, z: x) 73 | def f(self, x, y, z): 74 | del x, y, z 75 | self.run_count += 1 76 | return 17 77 | 78 | sc = SomeClass() 79 | sc.f(17, 18, z=19) 80 | sc.f(17, 19, z=20) 81 | assert sc.run_count == 1 82 | sc.f(18, 19, z=20) 83 | assert sc.run_count == 2 84 | 85 | sc.f.clear_cache(sc) 86 | 87 | 88 | def test_memoize_in(): 89 | from pytools import memoize_in 90 | 91 | class SomeClass: 92 | def __init__(self): 93 | self.run_count = 0 94 | 95 | def f(self): 96 | 97 | @memoize_in(self, (SomeClass.f,)) 98 | def inner(x): 99 | self.run_count += 1 100 | return 2*x 101 | 102 | inner(5) 103 | inner(5) 104 | 105 | sc = SomeClass() 106 | sc.f() 107 | assert sc.run_count == 1 108 | 109 | 110 | def test_p_convergence_verifier(): 111 | pytest.importorskip("numpy") 112 | 113 | from pytools.convergence import PConvergenceVerifier 114 | 115 | pconv_verifier = PConvergenceVerifier() 116 | for order in [2, 3, 4, 5]: 117 | pconv_verifier.add_data_point(order, 0.1**order) 118 | pconv_verifier() 119 | 120 | pconv_verifier = PConvergenceVerifier() 121 | for order in [2, 3, 4, 5]: 122 | pconv_verifier.add_data_point(order, 0.5**order) 123 | pconv_verifier() 124 | 125 | pconv_verifier = PConvergenceVerifier() 126 | for order in [2, 3, 4, 5]: 127 | pconv_verifier.add_data_point(order, 2) 128 | with pytest.raises(AssertionError): 129 | pconv_verifier() 130 | 131 | 132 | def test_memoize(): 133 | from pytools import memoize 134 | count = [0] 135 | 136 | @memoize 137 | def f(i, j): 138 | count[0] += 1 139 | return i + j 140 | 141 | assert f(1, 2) == 3 142 | assert f(1, 2) == 3 143 | assert count[0] == 1 144 | 145 | 146 | def test_memoize_with_kwargs(): 147 | from pytools import memoize 148 | count = [0] 149 | 150 | @memoize(use_kwargs=True) 151 | def f(i, j=1): 152 | count[0] += 1 153 | return i + j 154 | 155 | assert f(1) == 2 156 | assert f(1, 2) == 3 157 | assert f(2, j=3) == 5 158 | assert count[0] == 3 159 | assert f(1) == 2 160 | assert f(1, 2) == 3 161 | assert f(2, j=3) == 5 162 | assert count[0] == 3 163 | 164 | 165 | def test_memoize_keyfunc(): 166 | from pytools import memoize 167 | count = [0] 168 | 169 | @memoize(key=lambda i, j=(1,): (i, len(j))) 170 | def f(i, j=(1,)): 171 | count[0] += 1 172 | return i + len(j) 173 | 174 | assert f(1) == 2 175 | assert f(1, [2]) == 2 176 | assert f(2, j=[2, 3]) == 4 177 | assert count[0] == 2 178 | assert f(1) == 2 179 | assert f(1, (2,)) == 2 180 | assert f(2, j=(2, 3)) == 4 181 | assert count[0] == 2 182 | 183 | 184 | def test_memoize_frozen() -> None: 185 | 186 | from pytools import memoize_method 187 | 188 | # {{{ check frozen 
184 | def test_memoize_frozen() -> None:
185 | 
186 |     from pytools import memoize_method
187 | 
188 |     # {{{ check frozen dataclass
189 | 
190 |     @dataclass(frozen=True)
191 |     class FrozenDataclass:
192 |         value: int
193 | 
194 |         @memoize_method
195 |         def double_value(self):
196 |             return 2 * self.value
197 | 
198 |     c0 = FrozenDataclass(10)
199 |     assert c0.double_value() == 20
200 | 
201 |     c0.double_value.clear_cache(c0)  # type: ignore[attr-defined]
202 | 
203 |     # }}}
204 | 
205 |     # {{{ check class with no setattr
206 | 
207 |     class FrozenClass:
208 |         value: int
209 | 
210 |         def __init__(self, value):
211 |             object.__setattr__(self, "value", value)
212 | 
213 |         def __setattr__(self, key, value):
214 |             raise AttributeError(f"cannot set attribute {key}")
215 | 
216 |         @memoize_method
217 |         def double_value(self):
218 |             return 2 * self.value
219 | 
220 |     c1 = FrozenClass(10)
221 |     assert c1.double_value() == 20
222 | 
223 |     c1.double_value.clear_cache(c1)  # type: ignore[attr-defined]
224 | 
225 |     # }}}
226 | 
227 | 
228 | @pytest.mark.parametrize("dims", [2, 3])
229 | def test_spatial_btree(dims, do_plot=False):
230 |     pytest.importorskip("numpy")
231 |     import numpy as np
232 | 
233 |     rng = np.random.default_rng()
234 |     nparticles = 2000
235 |     x = -1 + 2*rng.uniform(size=(dims, nparticles))
236 |     x = np.sign(x)*np.abs(x)**1.9
237 |     x = (1.4 + x) % 2 - 1
238 | 
239 |     bl = np.min(x, axis=-1)
240 |     tr = np.max(x, axis=-1)
241 |     print(bl, tr)
242 | 
243 |     from pytools.spatial_btree import SpatialBinaryTreeBucket
244 |     tree = SpatialBinaryTreeBucket(bl, tr, max_elements_per_box=10)
245 |     for i in range(nparticles):
246 |         tree.insert(i, (x[:, i], x[:, i]))
247 | 
248 |     if do_plot:
249 |         import matplotlib.pyplot as pt
250 |         pt.gca().set_aspect("equal")
251 |         pt.plot(x[0], x[1], "x")
252 |         tree.plot(fill=None)
253 |         pt.show()
254 | 
255 | 
256 | def test_generate_numbered_unique_names():
257 |     from pytools import generate_numbered_unique_names
258 | 
259 |     gen = generate_numbered_unique_names("a")
260 |     assert next(gen) == (0, "a")
261 |     assert next(gen) == (1, "a_0")
262 | 
263 |     gen = generate_numbered_unique_names("b", 6)
264 |     assert next(gen) == (7, "b_6")
265 | 
266 | 
267 | def test_cartesian_product():
268 |     from pytools import cartesian_product
269 | 
270 |     expected_outputs = [
271 |         (0, 2, 4),
272 |         (0, 2, 5),
273 |         (0, 3, 4),
274 |         (0, 3, 5),
275 |         (1, 2, 4),
276 |         (1, 2, 5),
277 |         (1, 3, 4),
278 |         (1, 3, 5),
279 |     ]
280 | 
281 |     for i, output in enumerate(cartesian_product([0, 1], [2, 3], [4, 5])):
282 |         assert output == expected_outputs[i]
283 | 
284 | 
285 | def test_find_module_git_revision():
286 |     import pytools
287 |     print(pytools.find_module_git_revision(pytools.__file__, n_levels_up=1))
288 | 
289 | 
290 | def test_reshaped_view():
291 |     import pytools
292 |     np = pytest.importorskip("numpy")
293 | 
294 |     a = np.zeros((10, 2))
295 |     b = a.T
296 |     c = pytools.reshaped_view(a, -1)
297 |     assert c.shape == (20,)
298 |     with pytest.raises(AttributeError):
299 |         pytools.reshaped_view(b, -1)
300 | 
301 | 
302 | def test_processlogger():
303 |     logging.basicConfig(level=logging.INFO)
304 | 
305 |     from pytools import ProcessLogger
306 |     plog = ProcessLogger(logger, "testing the process logger",
307 |             long_threshold_seconds=0.01)
308 | 
309 |     from time import sleep
310 |     with plog:
311 |         sleep(0.3)
312 | 
313 | 
314 | def test_table():
315 |     import math
316 | 
317 |     from pytools import Table
318 | 
319 |     tbl = Table()
320 |     tbl.add_row(("i", "i^2", "i^3", "sqrt(i)"))
321 | 
322 |     for i in range(8):
323 |         tbl.add_row((i, i ** 2, i ** 3, math.sqrt(i)))
324 | 
325 |     print(tbl)
326 |     print()
327 |     print(tbl.latex())
328 | 
329 |     # {{{ test merging
330 | 
331 |     from pytools import merge_tables
332 |     tbl = merge_tables(tbl, tbl, tbl, skip_columns=(0,))
333 |     print(tbl.github_markdown())
334 | 
335 |     # }}}
336 | 
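337 | # EOC = estimated order of convergence, computed from (abscissa, error) pairs: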
338 | def test_eoc():
339 |     np = pytest.importorskip("numpy")
340 | 
341 |     from pytools.convergence import EOCRecorder
342 |     eoc = EOCRecorder()
343 | 
344 |     # {{{ test pretty_print
345 | 
346 |     for i in range(1, 8):
347 |         eoc.add_data_point(1.0 / i, 10 ** (-i))
348 | 
349 |     p = eoc.pretty_print()
350 |     print(p)
351 |     print()
352 | 
353 |     p = eoc.pretty_print(
354 |         abscissa_format="%.5e",
355 |         error_format="%.5e",
356 |         eoc_format="%5.2f")
357 |     print(p)
358 | 
359 |     # }}}
360 | 
361 |     # {{{ test merging
362 | 
363 |     from pytools.convergence import stringify_eocs
364 |     p = stringify_eocs(eoc, eoc, eoc, names=("First", "Second", "Third"))
365 |     print(p)
366 | 
367 |     # }}}
368 | 
369 |     # {{{ test invalid inputs
370 | 
371 |     eoc = EOCRecorder()
372 | 
373 |     # scalar inputs are fine
374 |     eoc.add_data_point(1, 1)
375 |     eoc.add_data_point(1.0, 1.0)
376 |     eoc.add_data_point(np.float32(1.0), 1.0)
377 |     eoc.add_data_point(np.array(3), 1.0)
378 |     eoc.add_data_point(1.0, np.array(3))
379 | 
380 |     # non-scalar inputs are not fine though
381 |     with pytest.raises(TypeError):
382 |         eoc.add_data_point(np.array([3]), 1.0)
383 | 
384 |     with pytest.raises(TypeError):
385 |         eoc.add_data_point(1.0, np.array([3]))
386 | 
387 |     # }}}
388 | 
389 | 
390 | def test_natsorted():
391 |     from pytools import natorder, natsorted
392 | 
393 |     assert natorder("1.001") < natorder("1.01")
394 | 
395 |     assert natsorted(["x10", "x1", "x9"]) == ["x1", "x9", "x10"]
396 |     assert natsorted(map(str, range(100))) == list(map(str, range(100)))
397 |     assert natsorted(["x10", "x1", "x9"], reverse=True) == ["x10", "x9", "x1"]
398 |     assert natsorted([10, 1, 9], key=lambda d: f"x{d}") == [1, 9, 10]
399 | 
400 | 
401 | # {{{ object array iteration behavior
402 | 
403 | class FakeArray:
404 |     nopes = 0
405 | 
406 |     def __len__(self):
407 |         FakeArray.nopes += 1
408 |         return 10
409 | 
410 |     def __getitem__(self, idx):
411 |         FakeArray.nopes += 1
412 |         if idx > 10:
413 |             raise IndexError
414 | 
415 | 
416 | def test_make_obj_array_iteration():
417 |     pytest.importorskip("numpy")
418 | 
419 |     from pytools.obj_array import make_obj_array
420 |     make_obj_array([FakeArray()])
421 | 
422 |     assert FakeArray.nopes == 0, FakeArray.nopes
423 | 
424 | # }}}
425 | 
426 | 
427 | # {{{ test obj array vectorization and decorators
428 | 
429 | def test_obj_array_vectorize(c=1):
430 |     np = pytest.importorskip("numpy")
431 |     la = pytest.importorskip("numpy.linalg")
432 | 
433 |     # {{{ functions
434 | 
435 |     import pytools.obj_array as obj
436 | 
437 |     def add_one(ary):
438 |         assert ary.dtype.char != "O"
439 |         return ary + c
440 | 
441 |     def two_add_one(x, y):
442 |         assert x.dtype.char != "O" and y.dtype.char != "O"
443 |         return x * y + c
444 | 
445 |     @obj.obj_array_vectorized
446 |     def vectorized_add_one(ary):
447 |         assert ary.dtype.char != "O"
448 |         return ary + c
449 | 
450 |     @obj.obj_array_vectorized_n_args
451 |     def vectorized_two_add_one(x, y):
452 |         assert x.dtype.char != "O" and y.dtype.char != "O"
453 |         return x * y + c
454 | 
455 |     class Adder:
456 |         def __init__(self, c):
457 |             self.c = c
458 | 
459 |         def add(self, ary):
460 |             assert ary.dtype.char != "O"
461 |             return ary + self.c
462 | 
463 |         @obj.obj_array_vectorized_n_args
464 |         def vectorized_add(self, ary):
465 |             assert ary.dtype.char != "O"
466 |             return ary + self.c
467 | 
468 |     adder = Adder(c)
469 | 
470 |     # }}}
471 | 
472 |     # {{{ check
473 | 
474 |     scalar_ary = np.ones(42, dtype=np.float64)
475 |     object_ary = obj.make_obj_array([scalar_ary, scalar_ary, scalar_ary])
476 | 
477 |     for func, vectorizer, nargs in [
478 |             (add_one, obj.obj_array_vectorize, 1),
479 |             (two_add_one, obj.obj_array_vectorize_n_args, 2),
480 |             (adder.add, obj.obj_array_vectorize, 1),
481 |             ]:
482 |         input_ary = [scalar_ary] * nargs
483 |         result = vectorizer(func, *input_ary)
484 |         error = la.norm(result - c - 1)
485 |         print(error)
486 | 
487 |         input_ary = [object_ary] * nargs
488 |         result = vectorizer(func, *input_ary)
489 |         assert result.dtype.char == "O"  # object array in, object array out
490 | 
491 |     for func, nargs in [
492 |             (vectorized_add_one, 1),
493 |             (vectorized_two_add_one, 2),
494 |             (adder.vectorized_add, 1),
495 |             ]:
496 |         input_ary = [scalar_ary] * nargs
497 |         result = func(*input_ary)
498 | 
499 |         input_ary = [object_ary] * nargs
500 |         result = func(*input_ary)
501 | 
502 |     # }}}
503 | 
504 | # }}}
505 | 
506 | 
507 | def test_tag() -> None:
508 |     from pytools.tag import (
509 |         NonUniqueTagError,
510 |         Tag,
511 |         Taggable,
512 |         UniqueTag,
513 |         check_tag_uniqueness,
514 |     )
515 | 
516 |     # Need a subclass that defines the copy function in order to test.
517 |     @tag_dataclass
518 |     class TaggableWithCopy(Taggable):
519 |         tags: frozenset[Tag]
520 | 
521 |         def _with_new_tags(self, tags):
522 |             return TaggableWithCopy(tags)
523 | 
524 |     class FairRibbon(Tag):
525 |         pass
526 | 
527 |     class BlueRibbon(FairRibbon):
528 |         pass
529 | 
530 |     class RedRibbon(FairRibbon):
531 |         pass
532 | 
533 |     class ShowRibbon(FairRibbon, UniqueTag):
534 |         pass
535 | 
536 |     class BestInShowRibbon(ShowRibbon):
537 |         pass
538 | 
539 |     class ReserveBestInShowRibbon(ShowRibbon):
540 |         pass
541 | 
542 |     class BestInClassRibbon(FairRibbon, UniqueTag):
543 |         pass
544 | 
545 |     best_in_show_ribbon = BestInShowRibbon()
546 |     reserve_best_in_show_ribbon = ReserveBestInShowRibbon()
547 |     blue_ribbon = BlueRibbon()
548 |     red_ribbon = RedRibbon()
549 |     best_in_class_ribbon = BestInClassRibbon()
550 | 
551 |     # Test that input processing fails if there are multiple instances
552 |     # of the same UniqueTag subclass
553 |     with pytest.raises(NonUniqueTagError):
554 |         check_tag_uniqueness(frozenset((
555 |             best_in_show_ribbon,
556 |             reserve_best_in_show_ribbon, blue_ribbon, red_ribbon)))
557 | 
558 |     # Test that input processing fails if any of the tags are not
559 |     # a subclass of Tag
560 |     with pytest.raises(TypeError):
561 |         check_tag_uniqueness(frozenset((
562 |             "I am not a tag", best_in_show_ribbon,  # type: ignore[arg-type]
563 |             blue_ribbon, red_ribbon)))
564 | 
565 |     # Test that instantiation succeeds if there are multiple instances
566 |     # of Tag subclasses.
567 |     t1 = TaggableWithCopy(frozenset([reserve_best_in_show_ribbon, blue_ribbon,
568 |                                      red_ribbon]))
569 |     assert t1.tags == frozenset((reserve_best_in_show_ribbon, red_ribbon,
570 |                                  blue_ribbon))
571 | 
572 |     # Test that instantiation succeeds if there are multiple instances
573 |     # of UniqueTag of different subclasses.
574 |     t1 = TaggableWithCopy(frozenset([reserve_best_in_show_ribbon,
575 |                                      best_in_class_ribbon, blue_ribbon,
576 |                                      blue_ribbon]))
577 |     assert t1.tags == frozenset((reserve_best_in_show_ribbon, best_in_class_ribbon,
578 |                                  blue_ribbon))
579 | 
580 |     # Test tagged() function
581 |     t2 = t1.tagged(red_ribbon)
582 |     print(t2.tags)
583 |     assert t2.tags == frozenset((reserve_best_in_show_ribbon, best_in_class_ribbon,
584 |                                  blue_ribbon, red_ribbon))
585 | 
586 |     # Test that tagged() fails if a UniqueTag of the same subclass
587 |     # is already present
588 |     with pytest.raises(NonUniqueTagError):
589 |         t1.tagged(best_in_show_ribbon)
590 | 
591 |     # Test that tagged() fails if tags are not a FrozenSet of Tags
592 |     with pytest.raises(TypeError):
593 |         t1.tagged(tags=frozenset((1,)))  # type: ignore[arg-type]
594 | 
595 |     # Test without_tags() function
596 |     t4 = t2.without_tags(red_ribbon)
597 |     assert t4.tags == t1.tags
598 | 
599 |     # Test that without_tags() fails if the tag is not present.
600 |     with pytest.raises(ValueError):
601 |         t4.without_tags(red_ribbon)
602 | 
603 |     # Test DottedName comparison
604 |     from pytools.tag import DottedName
605 |     assert FairRibbon() == FairRibbon()
606 |     assert (FairRibbon().tag_name
607 |             == FairRibbon().tag_name
608 |             == DottedName(("pytools", "test", "test_pytools", "FairRibbon")))
609 |     assert FairRibbon() != BlueRibbon()
610 |     assert FairRibbon().tag_name != BlueRibbon().tag_name
611 | 
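612 | # unordered_hash digests must not depend on the order of the inputs, only their contents: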
613 | def test_unordered_hash():
614 |     import hashlib
615 |     import random
616 | 
617 |     # 200 random byte strings; random.randbytes is available on all
618 |     # Python versions (>= 3.9) this package still supports.
619 |     lst = [random.randbytes(20) for _ in range(200)]
620 |     lorig = lst[:]
621 |     random.shuffle(lst)
622 | 
623 |     from pytools import unordered_hash
624 |     assert (unordered_hash(hashlib.sha256(), lorig).digest()
625 |             == unordered_hash(hashlib.sha256(), lst).digest())
626 |     assert (unordered_hash(hashlib.sha256(), lorig).digest()
627 |             == unordered_hash(hashlib.sha256(), lorig).digest())
628 |     assert (unordered_hash(hashlib.sha256(), lorig).digest()
629 |             != unordered_hash(hashlib.sha256(), lorig[:-1]).digest())
630 |     lst[0] = b"aksdjfla;sdfjafd"
631 |     assert (unordered_hash(hashlib.sha256(), lorig).digest()
632 |             != unordered_hash(hashlib.sha256(), lst).digest())
633 | 
634 | 
635 | # {{{ sphere sampling
636 | 
637 | @pytest.mark.parametrize("sampling", [
638 |     "equidistant", "fibonacci", "fibonacci_min", "fibonacci_avg",
639 |     ])
640 | def test_sphere_sampling(sampling, visualize=False):
641 |     from functools import partial
642 | 
643 |     from pytools import sphere_sample_equidistant, sphere_sample_fibonacci
644 | 
645 |     npoints = 128
646 |     radius = 1.5
647 | 
648 |     if sampling == "equidistant":
649 |         sampling_func = partial(sphere_sample_equidistant, r=radius)
650 |     elif sampling == "fibonacci":
651 |         sampling_func = partial(
652 |             sphere_sample_fibonacci, r=radius, optimize=None)
653 |     elif sampling == "fibonacci_min":
654 |         sampling_func = partial(
655 |             sphere_sample_fibonacci, r=radius, optimize="minimum")
656 |     elif sampling == "fibonacci_avg":
657 |         sampling_func = partial(
658 |             sphere_sample_fibonacci, r=radius, optimize="average")
659 |     else:
660 |         raise ValueError(f"unknown sampling method: '{sampling}'")
661 | 
662 |     np = pytest.importorskip("numpy")
663 |     points = sampling_func(npoints)
664 |     assert np.all(np.linalg.norm(points, axis=0) < radius + 1.0e-15)
665 | 
666 |     if not visualize:
667 |         return
668 | 
669 |     import matplotlib.pyplot as plt
670 |     fig = plt.figure(figsize=(10, 10), dpi=300)
671 |     ax = fig.add_subplot(111, projection="3d")
672 | 
673 |     import matplotlib.tri as mtri
674 |     theta = np.arctan2(np.sqrt(points[0]**2 + points[1]**2), points[2])
675 |     phi = np.arctan2(points[1], points[0])
676 |     triangles = mtri.Triangulation(theta, phi)
677 | 
678 |     ax.plot_trisurf(points[0], points[1], points[2], triangles=triangles.triangles)
679 |     ax.set_xlim((-radius, radius))
680 |     ax.set_ylim((-radius, radius))
681 |     ax.set_zlim((-radius, radius))
682 |     ax.margins(0.05, 0.05, 0.05)
683 | 
684 |     # plt.show()
685 |     fig.savefig(f"sphere_sampling_{sampling}")
686 |     plt.close(fig)
687 | 
688 | # }}}
689 | 
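690 | # add_names() refuses names it has already handed out, unless conflicting_ok is set: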
691 | def test_unique_name_gen_conflicting_ok():
692 |     from pytools import UniqueNameGenerator
693 | 
694 |     ung = UniqueNameGenerator()
695 |     ung.add_names({"a", "b", "c"})
696 | 
697 |     with pytest.raises(ValueError):
698 |         ung.add_names({"a"})
699 | 
700 |     ung.add_names({"a", "b", "c"}, conflicting_ok=True)
701 | 
702 | 
703 | def test_strtobool():
704 |     from pytools import strtobool
705 |     assert strtobool("true") is True
706 |     assert strtobool("tRuE") is True
707 |     assert strtobool("1") is True
708 |     assert strtobool("t") is True
709 |     assert strtobool("on") is True
710 | 
711 |     assert strtobool("false") is False
712 |     assert strtobool("FaLse") is False
713 |     assert strtobool("0") is False
714 |     assert strtobool("f") is False
715 |     assert strtobool("off") is False
716 | 
717 |     # Each invalid value must raise on its own; a single raises() block
718 |     # would stop checking after the first call that raises.
719 |     for value in ["tru", "fal", "xxx", "."]:
720 |         with pytest.raises(ValueError):
721 |             strtobool(value)
722 | 
723 |     assert strtobool(None, False) is False
724 | 
725 | 
726 | def test_to_identifier() -> None:
727 |     from pytools import to_identifier
728 | 
729 |     assert to_identifier("_a_123_") == "_a_123_"
730 |     assert to_identifier("a_123") == "a_123"
731 |     assert to_identifier("a 123") == "a123"
732 |     assert to_identifier("123") == "_123"
733 |     assert to_identifier("_123") == "_123"
734 |     assert to_identifier("123A") == "_123A"
735 |     assert to_identifier("") == "_"
736 | 
737 |     assert not "a 123".isidentifier()
738 |     assert to_identifier("a 123").isidentifier()
739 |     assert to_identifier("123").isidentifier()
740 |     assert to_identifier("").isidentifier()
741 | 
742 | 
743 | def test_typedump():
744 |     from pytools import typedump
745 |     assert typedump("") == "str"
746 |     assert typedump("abcdefg") == "str"
747 |     assert typedump(5) == "int"
748 | 
749 |     assert typedump((5.0, 4)) == "tuple(float,int)"
750 |     assert typedump([5, 4]) == "list(int,int)"
751 |     assert typedump({5, 4}) == "set(int,int)"
752 |     assert typedump(frozenset((1, 2, 3))) == "frozenset(int,int,int)"
753 | 
754 |     assert typedump([5, 4, 3, 2, 1]) == "list(int,int,int,int,int)"
755 |     assert typedump([5, 4, 3, 2, 1, 0]) == "list(int,int,int,int,int,...)"
756 |     assert typedump([5, 4, 3, 2, 1, 0], max_seq=6) == "list(int,int,int,int,int,int)"
757 | 
758 |     assert typedump({5: 42, 7: 43}) == "{'5': int, '7': int}"
759 | 
760 |     class C:
761 |         class D:
762 |             pass
763 | 
764 |     assert typedump(C()) == "pytools.test.test_pytools.test_typedump.<locals>.C"
765 |     assert typedump(C.D()) == "pytools.test.test_pytools.test_typedump.<locals>.C.D"
766 |     assert typedump(C.D(), fully_qualified_name=False) == "D"
767 | 
768 |     from pytools.datatable import DataTable
769 |     t = DataTable(column_names=[])
770 | 
771 |     assert typedump(t) == "pytools.datatable.DataTable()"
772 |     assert typedump(t, special_handlers={type(t): lambda x: "foo"}) == "foo"
773 | 
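774 | # The unique* helpers must preserve first-occurrence order, unlike set-based deduplication: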
775 | def test_unique():
776 |     from pytools import unique, unique_difference, unique_intersection, unique_union
777 | 
778 |     assert list(unique([1, 2, 1])) == [1, 2]
779 |     assert tuple(unique((1, 2, 1))) == (1, 2)
780 | 
781 |     assert list(range(1000)) == list(unique(range(1000)))
782 |     assert list(unique(list(range(1000)) + list(range(1000)))) == list(range(1000))
783 | 
784 |     # Also test strings since their ordering would be thrown off by
785 |     # set-based 'unique' implementations.
786 |     assert list(unique(["a", "b", "a"])) == ["a", "b"]
787 |     assert tuple(unique(("a", "b", "a"))) == ("a", "b")
788 | 
789 |     assert list(unique_difference(["a", "b", "c"], ["b", "c", "d"])) == ["a"]
790 |     assert list(unique_difference(["a", "b", "c"], ["a", "b", "c", "d"])) == []
791 |     assert list(unique_difference(["a", "b", "c"], ["a"], ["b"], ["c"])) == []
792 | 
793 |     assert list(unique_intersection(["a", "b", "a"], ["b", "c", "a"])) == ["a", "b"]
794 |     assert list(unique_intersection(["a", "b", "a"], ["d", "c", "e"])) == []
795 | 
796 |     assert list(unique_union(["a", "b", "a"], ["b", "c", "b"])) == ["a", "b", "c"]
797 |     assert list(unique_union(
798 |         ["a", "b", "a"], ["b", "c", "b"], ["c", "d", "c"])) == ["a", "b", "c", "d"]
799 |     assert list(unique(["a", "b", "a"])) == \
800 |             list(unique_union(["a", "b", "a"])) == ["a", "b"]
801 | 
802 |     assert list(unique_intersection()) == []
803 |     assert list(unique_difference()) == []
804 |     assert list(unique_union()) == []
805 | 
806 | 
807 | # This class must be defined globally to be picklable
808 | class SimpleRecord(Record):
809 |     pass
810 | 
811 | 
812 | def test_record():
813 |     r = SimpleRecord(c=3, b=2, a=1)
814 | 
815 |     assert r.a == 1
816 |     assert r.b == 2
817 |     assert r.c == 3
818 | 
819 |     # Fields are sorted alphabetically in records
820 |     assert str(r) == "SimpleRecord(a=1, b=2, c=3)"
821 | 
822 |     # Unregistered fields are (silently) ignored for printing
823 |     r.f = 6
824 |     assert str(r) == "SimpleRecord(a=1, b=2, c=3)"
825 | 
826 |     # Registered fields are printed, but only once they have been set
827 |     r.register_fields({"d", "e"})
828 |     assert str(r) == "SimpleRecord(a=1, b=2, c=3)"
829 | 
830 |     r.d = 4
831 |     r.e = 5
832 |     assert str(r) == "SimpleRecord(a=1, b=2, c=3, d=4, e=5)"
833 | 
834 |     with pytest.raises(AttributeError):
835 |         r.ff  # noqa: B018
836 | 
837 |     # {{{ test pickling
838 |     import pickle
839 |     r_pickled = pickle.loads(pickle.dumps(r))
840 |     assert r == r_pickled
841 | 
842 |     # }}}
843 | 
844 |     # {{{ __slots__, __dict__, __weakref__ handling
845 | 
846 |     class RecordWithEmptySlots(Record):
847 |         __slots__ = []
848 | 
849 |     assert hasattr(RecordWithEmptySlots(), "__slots__")
850 |     assert not hasattr(RecordWithEmptySlots(), "__dict__")
851 |     assert not hasattr(RecordWithEmptySlots(), "__weakref__")
852 | 
853 |     class RecordWithUnsetSlots(Record):
854 |         pass
855 | 
856 |     assert hasattr(RecordWithUnsetSlots(), "__slots__")
857 |     assert hasattr(RecordWithUnsetSlots(), "__dict__")
858 |     assert hasattr(RecordWithUnsetSlots(), "__weakref__")
859 | 
860 |     from pytools import ImmutableRecord
861 | 
862 |     class ImmutableRecordWithEmptySlots(ImmutableRecord):
863 |         __slots__ = []
864 | 
865 |     assert hasattr(ImmutableRecordWithEmptySlots(), "__slots__")
866 |     assert hasattr(ImmutableRecordWithEmptySlots(), "__dict__")
867 |     assert hasattr(ImmutableRecordWithEmptySlots(), "__weakref__")
868 | 
869 |     class ImmutableRecordWithUnsetSlots(ImmutableRecord):
870 |         pass
871 | 
872 |     assert hasattr(ImmutableRecordWithUnsetSlots(), "__slots__")
873 |     assert hasattr(ImmutableRecordWithUnsetSlots(), "__dict__")
874 |     assert hasattr(ImmutableRecordWithUnsetSlots(), "__weakref__")
875 | 
876 |     # }}}
877 | 
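878 | # The multiset (1, 3, 3, 4) admits 4!/2! == 12 distinct orderings, hence the 12s below.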
879 | def test_permutations():
880 |     from math import factorial
881 | 
882 |     from pytools import generate_permutations, generate_unique_permutations
883 | 
884 |     perms = list(generate_permutations([1, 2, 3, 4]))
885 |     assert len(perms) == factorial(4)
886 |     seq: Sequence[int] = (1, 3, 3, 4)
887 |     perms = list(generate_unique_permutations(seq))
888 |     assert len(perms) == 12
889 | 
890 |     perms = list(generate_permutations("1234"))
891 |     assert len(perms) == factorial(4)
892 |     perms = list(generate_unique_permutations("1334"))
893 |     assert len(perms) == 12
894 | 
895 | 
896 | if __name__ == "__main__":
897 |     if len(sys.argv) > 1:
898 |         exec(sys.argv[1])
899 |     else:
900 |         from pytest import main
901 |         main([__file__])
902 | 
--------------------------------------------------------------------------------
/pytools/version.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | import re
4 | from importlib import metadata
5 | 
6 | 
7 | VERSION_TEXT = metadata.version("pytools")
8 | # For e.g. "2024.1.2rc1", VERSION becomes (2024, 1, 2)
9 | # and VERSION_STATUS becomes "rc1".
10 | _match = re.match(r"^([0-9.]+)([a-z0-9]*?)$", VERSION_TEXT)
11 | assert _match is not None
12 | VERSION_STATUS = _match.group(2)
13 | VERSION = tuple(int(nr) for nr in _match.group(1).split("."))
14 | 
--------------------------------------------------------------------------------
/run-mypy.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | 
3 | set -ex
4 | 
5 | python -m mypy --strict --follow-imports=silent \
6 |         pytools/datatable.py \
7 |         pytools/graph.py \
8 |         pytools/persistent_dict.py \
9 |         pytools/prefork.py \
10 |         pytools/tag.py \
11 |         pytools/lex.py
12 | python -m mypy pytools
13 | 
14 | 
--------------------------------------------------------------------------------
/run-pylint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | set -o errexit -o nounset
4 | 
5 | ci_support="https://gitlab.tiker.net/inducer/ci-support/raw/main"
6 | 
7 | if [[ ! -f .pylintrc.yml ]]; then
8 |     curl -o .pylintrc.yml "${ci_support}/.pylintrc-default.yml"
9 | fi
10 | 
11 | 
12 | if [[ ! -f .run-pylint.py ]]; then
13 |     curl -L -o .run-pylint.py "${ci_support}/run-pylint.py"
14 | fi
15 | 
16 | 
17 | PYLINT_RUNNER_ARGS="--jobs=4 --yaml-rcfile=.pylintrc.yml"
18 | 
19 | if [[ -f .pylintrc-local.yml ]]; then
20 |     PYLINT_RUNNER_ARGS+=" --yaml-rcfile=.pylintrc-local.yml"
21 | fi
22 | 
23 | PYTHONWARNINGS=ignore python .run-pylint.py $PYLINT_RUNNER_ARGS pytools examples "$@"
24 | 
--------------------------------------------------------------------------------