├── .coveragerc ├── .editorconfig ├── .flake8 ├── .github └── workflows │ ├── docs.yml │ ├── lint.yml │ ├── publish.yml │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.rst ├── debian ├── changelog ├── compat ├── control ├── copyright ├── python3-pgtoolkit.pydist ├── rules └── source │ ├── format │ └── options ├── docs ├── .gitignore ├── Makefile ├── conf.py ├── conf.rst ├── contents.rst ├── ctl.rst ├── hba.rst ├── index.rst ├── log.rst ├── pgpass.rst └── service.rst ├── mypy.ini ├── pgtoolkit ├── __init__.py ├── _helpers.py ├── conf.py ├── ctl.py ├── errors.py ├── hba.py ├── log │ ├── __init__.py │ ├── __main__.py │ └── parser.py ├── pgpass.py ├── py.typed └── service.py ├── pyproject.toml ├── pytest.ini ├── rpm ├── Makefile ├── README.md ├── build ├── docker-compose.yml └── python-pgtoolkit.spec ├── scripts └── profile-log ├── tests ├── data │ ├── conf.d │ │ ├── .hidden.conf │ │ ├── .includeme │ │ ├── listen.conf │ │ └── with-include.conf │ ├── pg_hba.conf │ ├── pg_hba_bad.conf │ ├── pg_service.conf │ ├── pg_service_bad.conf │ ├── pgpass │ ├── pgpass_bad │ ├── postgres-my-my.conf │ ├── postgres-my.conf │ ├── postgres-mymymy.conf │ ├── postgres.conf │ └── postgresql.log ├── datatests.sh ├── test_conf.py ├── test_ctl.py ├── test_ctl_func.py ├── test_hba.py ├── test_helpers.py ├── test_log.py ├── test_pass.py └── test_service.py └── tox.ini /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source = pgtoolkit 3 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | charset = utf-8 5 | end_of_line = lf 6 | indent_size = 8 7 | insert_final_newline = true 8 | trim_trailing_whitespace = true 9 | max_line_length = 79 10 | 11 | [*.py] 12 | indent_size = 4 13 | 
indent_style = space 14 | max_line_length = 88 15 | 16 | [*.yml] 17 | indent_size = 2 18 | indent_style = space 19 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | doctests = True 3 | select = B,C,E,F,W,T4,B9 4 | ignore = 5 | # whitespace before ':' 6 | E203, 7 | # line too long 8 | E501, 9 | # missing whitespace around arithmetic operator 10 | E226, 11 | # multiple statements on one line (def) 12 | E704, 13 | # line break before binary operator 14 | W503, 15 | exclude = 16 | .tox, 17 | .venv 18 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Documentation 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | docs: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v4 10 | with: 11 | fetch-depth: 1 12 | - uses: actions/setup-python@v5 13 | with: 14 | python-version: '3.13' 15 | - name: Install dependencies 16 | run: python -m pip install -e ".[doc]" 17 | - name: Check documentation 18 | run: | 19 | rst2html --strict README.rst > /dev/null 20 | make -C docs clean html 21 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | lint: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v4 10 | - uses: actions/setup-python@v5 11 | with: 12 | python-version: '3.13' 13 | - name: Install dependencies 14 | run: python -m pip install tox 15 | - name: Style 16 | run: tox -v -e lint 17 | - name: Typing 18 | run: tox -v -e typing 19 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: 
-------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | push: 5 | tags: 6 | - '*' 7 | 8 | jobs: 9 | publish: 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - uses: actions/checkout@v4 14 | with: 15 | fetch-depth: 1 16 | - uses: actions/setup-python@v5 17 | with: 18 | python-version: '3.13' 19 | - name: Install 20 | run: python -m pip install build twine 21 | - name: Build 22 | run: | 23 | python -m build 24 | python -m twine check dist/* 25 | - name: Publish 26 | run: python -m twine upload dist/* 27 | env: 28 | TWINE_USERNAME: __token__ 29 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} 30 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | tests: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: 12 | - "3.9" 13 | - "3.13" 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | - name: Install dependencies 22 | run: python -m pip install tox 23 | - name: Test with pytest 24 | run: | 25 | tox -v -e tests-ci 26 | - name: Upload coverage to Codecov 27 | if: ${{ matrix.python-version == '3.13' }} 28 | uses: codecov/codecov-action@v4 29 | with: 30 | token: ${{ secrets.CODECOV_TOKEN }} 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info/ 2 | *.pyc 3 | coverage.xml 4 | .eggs 5 | .mypy_cache/ 6 | .pybuild 7 | .pytest_cache/ 8 | __pycache__/ 9 | build/ 10 | debian/.debhelper/ 11 | dist/ 12 | -------------------------------------------------------------------------------- 
/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: local 3 | hooks: 4 | - id: pyupgrade 5 | name: pyupgrade 6 | entry: pyupgrade --exit-zero-even-if-changed 7 | language: system 8 | types: [python] 9 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: "3" 7 | 8 | python: 9 | install: 10 | - method: pip 11 | path: . 12 | extra_requirements: 13 | - doc 14 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to pgtoolkit 2 | 3 | You're welcome to contribute to pgtoolkit with code and more like issue, review, 4 | documentation and spreading the word! 5 | 6 | pgtoolkit home for contribution is its [GitHub project 7 | page](https://github.com/dalibo/pgtoolkit). Use issue, PR and comments to get in 8 | touch with us! 9 | 10 | 11 | ## Releasing a new version 12 | 13 | To release a new version you'll need read-write access to GitHub project 14 | https://github.com/dalibo/pgtoolkit 15 | 16 | Then, follow the next steps: 17 | 18 | - Create an annotated (and optionally signed) tag 19 | `git tag -a [-s] -m "pgtoolkit " ` 20 | - Push the new tag 21 | `git push --follow-tags` 22 | - Then the new release will be available at 23 | [PyPI](https://pypi.org/project/pgtoolkit/) 24 | (after GitHub publish workflow is finished) 25 | - Follow instructions to [build rpm](./rpm) and upload to [Dalibo 26 | Labs](https://yum.dalibo.org/labs/). 
27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | PostgreSQL Licence 2 | 3 | Copyright (c) 2017, DALIBO 4 | 5 | Permission to use, copy, modify, and distribute this software and its 6 | documentation for any purpose, without fee, and without a written agreement is 7 | hereby granted, provided that the above copyright notice and this paragraph and 8 | the following two paragraphs appear in all copies. 9 | 10 | IN NO EVENT SHALL DALIBO BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, 11 | INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, ARISING OUT OF THE 12 | USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF DALIBO HAS BEEN ADVISED OF 13 | THE POSSIBILITY OF SUCH DAMAGE. 14 | 15 | DALIBO SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE 17 | SOFTWARE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND DALIBO HAS NO 18 | OBLIGATIONS TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR 19 | MODIFICATIONS. 
20 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include .coveragerc 2 | include .flake8 3 | include .pre-commit-config.yaml 4 | include pyproject.toml 5 | include *.md 6 | include *.txt 7 | include LICENSE 8 | include mypy.ini 9 | include pytest.ini 10 | include tox.ini 11 | recursive-include docs *.py *.rst Makefile 12 | recursive-include pgtoolkit *.py 13 | recursive-include tests *.py *.sh 14 | recursive-include tests/data *.conf *.includeme *.log pgpass* 15 | 16 | exclude .editorconfig 17 | exclude .readthedocs.yaml 18 | exclude Makefile 19 | prune .circleci 20 | prune rpm 21 | prune debian 22 | prune scripts 23 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | #################################### 2 | Postgres Cluster Support in Python 3 | #################################### 4 | 5 | | |Tests status| |Codecov| |RTD| 6 | 7 | 8 | ``pgtoolkit`` provides implementations to manage various file formats in Postgres 9 | cluster. Currently: 10 | 11 | - ``postgresql.conf``: read, edit, save. 12 | - ``pg_hba.conf``: render, validate and align columns. 13 | - ``.pgpass``: render, validate and sort lines. 14 | - ``pg_service.conf``: find, read, edit, render. 15 | - Cluster logs. 16 | 17 | It also provides a Python API for calling pg_ctl_ commands. 18 | 19 | .. _pg_ctl: https://www.postgresql.org/docs/current/app-pg-ctl.html 20 | 21 | 22 | .. code:: 23 | 24 | import sys 25 | 26 | from pgtoolkit.hba import parse 27 | 28 | 29 | with open('pg_hba.conf') as fo: 30 | hba = parse(fo) 31 | 32 | hba.write(sys.stdout) 33 | 34 | 35 | The API in this toolkit must: 36 | 37 | - Use only Python stdlib. 38 | - Use Postgres idioms. 39 | - Have full test coverage. 40 | - Run everywhere. 
41 | 42 | 43 | Support 44 | ------- 45 | 46 | `pgtoolkit `_ home on GitHub is the unique 47 | way of interacting with developers. Feel free to open an issue to get support. 48 | 49 | 50 | .. |Codecov| image:: https://codecov.io/gh/dalibo/pgtoolkit/branch/master/graph/badge.svg 51 | :target: https://codecov.io/gh/dalibo/pgtoolkit 52 | :alt: Code coverage report 53 | 54 | .. |Tests status| image:: https://github.com/dalibo/pgtoolkit/actions/workflows/tests.yml/badge.svg 55 | :target: https://github.com/dalibo/pgtoolkit/actions/workflows/tests.yml 56 | :alt: Continuous Integration report 57 | 58 | .. |RTD| image:: https://readthedocs.org/projects/pgtoolkit/badge/?version=latest 59 | :target: https://pgtoolkit.readthedocs.io/en/latest/ 60 | :alt: Documentation 61 | -------------------------------------------------------------------------------- /debian/changelog: -------------------------------------------------------------------------------- 1 | pgtoolkit (0.13.0-1) unstable; urgency=medium 2 | 3 | * Initial Release. 
4 | 5 | -- Denis Laxalde Tue, 02 Mar 2021 14:42:11 +0100 6 | -------------------------------------------------------------------------------- /debian/compat: -------------------------------------------------------------------------------- 1 | 11 2 | -------------------------------------------------------------------------------- /debian/control: -------------------------------------------------------------------------------- 1 | Source: pgtoolkit 2 | Section: python 3 | Priority: optional 4 | Maintainer: Dalibo 5 | Uploaders: 6 | Denis Laxalde , 7 | Build-Depends: 8 | debhelper (>= 11), 9 | dh-python, 10 | python3-all, 11 | python3-setuptools, 12 | python3-setuptools-scm, 13 | python3-psycopg2, 14 | python3-pytest, 15 | python3-pytest-asyncio, 16 | python3-pytest-mock, 17 | python3-typing-extensions, 18 | Standards-Version: 4.1.3 19 | Homepage: https://github.com/dalibo/pgtoolkit 20 | X-Python3-Version: >= 3.9 21 | #Testsuite: autopkgtest-pkg-python 22 | 23 | Package: python3-pgtoolkit 24 | Architecture: all 25 | Depends: 26 | ${python3:Depends}, 27 | ${misc:Depends}, 28 | Suggests: python-pgtoolkit-doc 29 | Description: PostgreSQL toolkit for Python 30 | pgtoolkit manages various file formats in a PostgreSQL cluster such as: 31 | server configuration ('postgresql.conf'), host-based authentication 32 | ('pg_hba.conf'), password files ('.pgpass'), etc. 33 | . 34 | This package installs the library for Python 3. 
35 | -------------------------------------------------------------------------------- /debian/copyright: -------------------------------------------------------------------------------- 1 | Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ 2 | Upstream-Name: pgtoolkit 3 | Source: https://github.com/dalibo/pgtoolkit 4 | 5 | Files: * 6 | Copyright: 2017 Dalibo 7 | License: PostgreSQL 8 | 9 | License: PostgreSQL 10 | Permission to use, copy, modify, and distribute this software and its 11 | documentation for any purpose, without fee, and without a written agreement is 12 | hereby granted, provided that the above copyright notice and this paragraph and 13 | the following two paragraphs appear in all copies. 14 | . 15 | IN NO EVENT SHALL DALIBO BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, 16 | INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, ARISING OUT OF THE 17 | USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF DALIBO HAS BEEN ADVISED OF 18 | THE POSSIBILITY OF SUCH DAMAGE. 19 | . 20 | DALIBO SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE 22 | SOFTWARE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND DALIBO HAS NO 23 | OBLIGATIONS TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR 24 | MODIFICATIONS. 25 | -------------------------------------------------------------------------------- /debian/python3-pgtoolkit.pydist: -------------------------------------------------------------------------------- 1 | pgtoolkit python3-pgtoolkit; PEP386 2 | -------------------------------------------------------------------------------- /debian/rules: -------------------------------------------------------------------------------- 1 | #!/usr/bin/make -f 2 | # See debhelper(7) (uncomment to enable) 3 | # output every command that modifies files on the build system. 
4 | #export DH_VERBOSE = 1 5 | 6 | export PYBUILD_NAME=pgtoolkit 7 | export PYBUILD_DISABLE=test 8 | 9 | %: 10 | dh $@ --with python3 --buildsystem=pybuild 11 | -------------------------------------------------------------------------------- /debian/source/format: -------------------------------------------------------------------------------- 1 | 3.0 (quilt) 2 | -------------------------------------------------------------------------------- /debian/source/options: -------------------------------------------------------------------------------- 1 | extend-diff-ignore = "(^[^/]*[.]egg-info/|.circleci|.coverage|.editorconfig|.github|.mypy_cache|.pytest_cache|.readthedocs.yaml|.venv|Makefile|docs/_build|dist|rpm|setup.cfg|scripts|tox.ini)" 2 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | _build/ 2 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = -W 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = pgtoolkit 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | 22 | serve: 23 | xdg-open $(BUILDDIR)/html/index.html 24 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # 2 | # Configuration file for the Sphinx documentation builder. 3 | # 4 | # This file does only contain a selection of the most common options. For a 5 | # full list see the documentation: 6 | # http://www.sphinx-doc.org/en/master/config 7 | 8 | # -- Path setup -------------------------------------------------------------- 9 | 10 | # If extensions (or modules to document with autodoc) are in another directory, 11 | # add these directories to sys.path here. If the directory is relative to the 12 | # documentation root, use os.path.abspath to make it absolute, like shown here. 13 | # 14 | import importlib.metadata 15 | 16 | # import sys 17 | # sys.path.insert(0, os.path.abspath('.')) 18 | 19 | # -- Project information ----------------------------------------------------- 20 | project = "pgtoolkit" 21 | author = "Dalibo" 22 | copyright = "2018, Dalibo Labs" 23 | 24 | # The full version, including alpha/beta/rc tags 25 | release = importlib.metadata.version(project) 26 | # The short X.Y version 27 | version = ".".join(release.split(".")[:2]) 28 | 29 | 30 | # -- General configuration --------------------------------------------------- 31 | 32 | # If your documentation needs a minimal Sphinx version, state it here. 33 | # 34 | # needs_sphinx = '1.0' 35 | 36 | # Add any Sphinx extension module names here, as strings. They can be 37 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 38 | # ones. 39 | extensions = [ 40 | "sphinx.ext.autodoc", 41 | "sphinx.ext.doctest", 42 | "sphinx.ext.intersphinx", 43 | "sphinx.ext.viewcode", 44 | ] 45 | 46 | # Add any paths that contain templates here, relative to this directory. 
47 | templates_path = ["_templates"] 48 | 49 | # The suffix(es) of source filenames. 50 | # You can specify multiple suffix as a list of string: 51 | # 52 | # source_suffix = ['.rst', '.md'] 53 | source_suffix = ".rst" 54 | 55 | # The master toctree document. 56 | master_doc = "contents" 57 | 58 | # The language for content autogenerated by Sphinx. Refer to documentation 59 | # for a list of supported languages. 60 | # 61 | # This is also used if you do content translation via gettext catalogs. 62 | # Usually you set "language" from the command line for these cases. 63 | language = "en" 64 | 65 | # List of patterns, relative to source directory, that match files and 66 | # directories to ignore when looking for source files. 67 | # This pattern also affects html_static_path and html_extra_path . 68 | exclude_patterns = [] 69 | 70 | # The name of the Pygments (syntax highlighting) style to use. 71 | pygments_style = "sphinx" 72 | 73 | 74 | # -- Options for HTML output ------------------------------------------------- 75 | 76 | # The theme to use for HTML and HTML Help pages. See the documentation for 77 | # a list of builtin themes. 78 | # 79 | html_theme = "sphinx_rtd_theme" 80 | 81 | # Theme options are theme-specific and customize the look and feel of a theme 82 | # further. For a list of options available for each theme, see the 83 | # documentation. 84 | # 85 | # html_theme_options = {} 86 | 87 | # Add any paths that contain custom static files (such as style sheets) here, 88 | # relative to this directory. They are copied after the builtin static files, 89 | # so a file named "default.css" will overwrite the builtin "default.css". 90 | # html_static_path = ['_static'] 91 | 92 | # Custom sidebar templates, must be a dictionary that maps document names 93 | # to template names. 94 | # 95 | # The default sidebars (for documents that don't match any pattern) are 96 | # defined by theme itself. 
Builtin themes are using these templates by 97 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 98 | # 'searchbox.html']``. 99 | # 100 | # html_sidebars = {} 101 | 102 | 103 | # -- Options for manual page output ------------------------------------------ 104 | 105 | # One entry per manual page. List of tuples 106 | # (source start file, name, description, authors, manual section). 107 | man_pages = [(master_doc, "pgtoolkit", "pgtoolkit Documentation", [author], 1)] 108 | 109 | 110 | # -- Options for Texinfo output ---------------------------------------------- 111 | 112 | # Grouping the document tree into Texinfo files. List of tuples 113 | # (source start file, target name, title, author, 114 | # dir menu entry, description, category) 115 | # texinfo_documents = [ 116 | # (master_doc, 'pgtoolkit', 'pgtoolkit Documentation', 117 | # author, 'pgtoolkit', metadatas['description'], 118 | # 'Miscellaneous'), 119 | # ] 120 | 121 | 122 | # -- Extension configuration ------------------------------------------------- 123 | 124 | # -- Options for intersphinx extension --------------------------------------- 125 | 126 | # Example configuration for intersphinx: refer to the Python standard library. 127 | intersphinx_mapping = { 128 | "python": ("https://docs.python.org/3/", None), 129 | } 130 | -------------------------------------------------------------------------------- /docs/conf.rst: -------------------------------------------------------------------------------- 1 | ====================== 2 | :mod:`pgtoolkit.conf` 3 | ====================== 4 | 5 | .. automodule:: pgtoolkit.conf 6 | -------------------------------------------------------------------------------- /docs/contents.rst: -------------------------------------------------------------------------------- 1 | =================== 2 | Table of Contents 3 | =================== 4 | 5 | 6 | .. 
toctree:: 7 | 8 | index 9 | conf 10 | hba 11 | log 12 | pgpass 13 | service 14 | ctl 15 | -------------------------------------------------------------------------------- /docs/ctl.rst: -------------------------------------------------------------------------------- 1 | ====================== 2 | :mod:`pgtoolkit.ctl` 3 | ====================== 4 | 5 | .. note:: This module requires Python 3.8 or higher. 6 | 7 | .. automodule:: pgtoolkit.ctl 8 | -------------------------------------------------------------------------------- /docs/hba.rst: -------------------------------------------------------------------------------- 1 | ======================== 2 | :mod:`pgtoolkit.hba` 3 | ======================== 4 | 5 | .. automodule:: pgtoolkit.hba 6 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | #################################### 2 | Postgres Cluster Support in Python 3 | #################################### 4 | 5 | pgtoolkit is a Python library providing API to interact with various PostgreSQL 6 | file formats, offline. Namely: 7 | 8 | * :mod:`postgresql.conf ` 9 | * :mod:`pg_hba.conf ` 10 | * :mod:`.pgpass ` 11 | * :mod:`pg_service.conf ` 12 | * :mod:`logs ` 13 | 14 | It also provides a Python API for calling pg_ctl_ commands in :mod:`ctl 15 | ` module. 16 | 17 | .. _pg_ctl: https://www.postgresql.org/docs/current/app-pg-ctl.html 18 | 19 | Quick installation 20 | ------------------ 21 | 22 | Just use PyPI as any regular Python project: 23 | 24 | .. code:: console 25 | 26 | $ pip install --pre pgtoolkit 27 | 28 | 29 | Support 30 | ------- 31 | 32 | If you need support for ``pgtoolkit``, just drop an `issue on 33 | GitHub `__! 34 | 35 | 36 | Project name 37 | ------------ 38 | 39 | There is a homonym project by @grayhemp since September 2013: 40 | `PgToolkit `__. 
41 | ``grayhemp/PgToolkit`` is a single tool project, thus *toolkit* is 42 | misleading. Also, as of August 2018, it is inactive for 3 years. 43 | 44 | There is no Python library named ``pgtoolkit``. There is no CLI program 45 | named ``pgtoolkit``. There is no ``pgtoolkit`` package. Considering 46 | this, ``pgtoolkit`` was chosen for this project. 47 | 48 | Please file a `new issue `_ if 49 | you have feedback on project name. 50 | -------------------------------------------------------------------------------- /docs/log.rst: -------------------------------------------------------------------------------- 1 | ====================== 2 | :mod:`pgtoolkit.log` 3 | ====================== 4 | 5 | .. automodule:: pgtoolkit.log 6 | -------------------------------------------------------------------------------- /docs/pgpass.rst: -------------------------------------------------------------------------------- 1 | ========================= 2 | :mod:`pgtoolkit.pgpass` 3 | ========================= 4 | 5 | .. automodule:: pgtoolkit.pgpass 6 | -------------------------------------------------------------------------------- /docs/service.rst: -------------------------------------------------------------------------------- 1 | ======================== 2 | :mod:`pgtoolkit.service` 3 | ======================== 4 | 5 | .. 
automodule:: pgtoolkit.service 6 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | strict = True 3 | warn_unused_ignores = True 4 | show_error_codes = True 5 | -------------------------------------------------------------------------------- /pgtoolkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dalibo/pgtoolkit/1996d702a46f39e7020005a638dad8b7200b209a/pgtoolkit/__init__.py -------------------------------------------------------------------------------- /pgtoolkit/_helpers.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | import sys 5 | from datetime import datetime, timedelta, timezone 6 | from pathlib import Path 7 | from typing import IO, Any, Generic, NoReturn, TypeVar, overload 8 | 9 | 10 | def format_timedelta(delta: timedelta) -> str: 11 | values = [ 12 | f"{v}{u}" 13 | for v, u in ( 14 | (delta.days, "d"), 15 | (delta.seconds, "s"), 16 | (delta.microseconds, "us"), 17 | ) 18 | if v 19 | ] 20 | if values: 21 | return " ".join(values) 22 | else: 23 | return "0s" 24 | 25 | 26 | class JSONDateEncoder(json.JSONEncoder): 27 | def default(self, obj: timedelta | datetime | object) -> Any: 28 | if isinstance(obj, datetime): 29 | return obj.isoformat() 30 | elif isinstance(obj, timedelta): 31 | return format_timedelta(obj) 32 | return super().default(obj) 33 | 34 | 35 | def open_or_stdin(filename: str, stdin: IO[str] = sys.stdin) -> IO[str]: 36 | if filename == "-": 37 | fo = stdin 38 | else: 39 | fo = open(filename) 40 | return fo 41 | 42 | 43 | T = TypeVar("T") 44 | 45 | 46 | class PassthroughManager(Generic[T]): 47 | def __init__(self, ret: T) -> None: 48 | self.ret = ret 49 | 50 | def __enter__(self) -> T: 51 | return self.ret 52 | 53 | def 
__exit__(self, *a: Any) -> None: 54 | pass 55 | 56 | 57 | @overload 58 | def open_or_return(fo_or_path: None, mode: str = "r") -> NoReturn: ... 59 | 60 | 61 | @overload 62 | def open_or_return(fo_or_path: str, mode: str = "r") -> IO[str]: ... 63 | 64 | 65 | @overload 66 | def open_or_return(fo_or_path: Path, mode: str = "r") -> IO[str]: ... 67 | 68 | 69 | @overload 70 | def open_or_return( 71 | fo_or_path: IO[str], mode: str = "r" 72 | ) -> PassthroughManager[IO[str]]: ... 73 | 74 | 75 | def open_or_return( 76 | fo_or_path: str | Path | IO[str] | None, mode: str = "r" 77 | ) -> IO[str] | PassthroughManager[IO[str]]: 78 | # Returns a context manager around a file-object for fo_or_path. If 79 | # fo_or_path is a file-object, the context manager keeps it open. If it's a 80 | # path, the file is opened with mode and will be closed upon context exit. 81 | # If fo_or_path is None, a ValueError is raised. 82 | 83 | if fo_or_path is None: 84 | raise ValueError("No file-like object nor path provided") 85 | if isinstance(fo_or_path, str): 86 | return open(fo_or_path, mode) 87 | if isinstance(fo_or_path, Path): 88 | return fo_or_path.open(mode) 89 | 90 | # Skip default file context manager. This allows to always use with 91 | # statement and don't care about closing the file. If the file is opened 92 | # here, it will be closed properly. Otherwise, it will be kept open thanks 93 | # to PassthroughManager. 94 | return PassthroughManager(fo_or_path) 95 | 96 | 97 | class Timer: 98 | def __enter__(self) -> Timer: 99 | self.start = datetime.now(timezone.utc) 100 | return self 101 | 102 | def __exit__(self, *a: Any) -> None: 103 | self.delta = datetime.now(timezone.utc) - self.start 104 | -------------------------------------------------------------------------------- /pgtoolkit/conf.py: -------------------------------------------------------------------------------- 1 | """\ 2 | .. currentmodule:: pgtoolkit.conf 3 | 4 | This module implements ``postgresql.conf`` file format. 
This is the same format 5 | for ``recovery.conf``. The main entry point of the API is :func:`parse`. The 6 | module can be used as a CLI script. 7 | 8 | 9 | API Reference 10 | ------------- 11 | 12 | .. autofunction:: parse 13 | .. autofunction:: parse_string 14 | .. autoclass:: Configuration 15 | .. autoclass:: ParseError 16 | 17 | 18 | Using as a CLI Script 19 | --------------------- 20 | 21 | You can use this module to dump a configuration file as JSON object 22 | 23 | .. code:: console 24 | 25 | $ python -m pgtoolkit.conf postgresql.conf | jq . 26 | { 27 | "lc_monetary": "fr_FR.UTF8", 28 | "datestyle": "iso, dmy", 29 | "log_rotation_age": "1d", 30 | "log_min_duration_statement": "3s", 31 | "log_lock_waits": true, 32 | "log_min_messages": "notice", 33 | "log_directory": "log", 34 | "port": 5432, 35 | "log_truncate_on_rotation": true, 36 | "log_rotation_size": 0 37 | } 38 | $ 39 | 40 | """ 41 | 42 | from __future__ import annotations 43 | 44 | import contextlib 45 | import copy 46 | import enum 47 | import json 48 | import pathlib 49 | import re 50 | import sys 51 | from collections import OrderedDict 52 | from collections.abc import Iterable, Iterator 53 | from dataclasses import dataclass, field 54 | from datetime import timedelta 55 | from typing import IO, Any, ClassVar, NoReturn, Union 56 | from warnings import warn 57 | 58 | from ._helpers import JSONDateEncoder, open_or_return 59 | 60 | 61 | class ParseError(Exception): 62 | """Error while parsing configuration content.""" 63 | 64 | 65 | class IncludeType(enum.Enum): 66 | """Include directive types. 67 | 68 | https://www.postgresql.org/docs/13/config-setting.html#CONFIG-INCLUDES 69 | """ 70 | 71 | include_dir = enum.auto() 72 | include_if_exists = enum.auto() 73 | include = enum.auto() 74 | 75 | 76 | def parse(fo: str | pathlib.Path | IO[str]) -> Configuration: 77 | """Parse a configuration file. 78 | 79 | The parser tries to return Python object corresponding to value, based on 80 | some heuristics. 
def parse(fo: str | pathlib.Path | IO[str]) -> Configuration:
    """Parse a configuration file.

    The parser tries to return Python objects corresponding to values, based
    on some heuristics: booleans, octal numbers, decimal integers and
    floating point numbers are parsed.  Memory quantities with multiplier
    units like ``kB`` or ``MB`` are kept verbatim as strings, and interval
    values like ``3s`` are returned as :class:`datetime.timedelta`.

    In case of doubt, the value is kept as a string. It's up to you to enforce
    format.

    Include directives are processed recursively, when 'fo' is a file path (not
    a file object). If some included file is not found a FileNotFoundError
    exception is raised. If a loop is detected in include directives, a
    RuntimeError is raised.

    :param fo: A line iterator such as a file-like object or a path.
    :returns: A :class:`Configuration` containing parsed configuration.

    """
    with open_or_return(fo) as f:
        conf = Configuration(getattr(f, "name", None))
        # Drain the generator so every include directive gets resolved.
        list(_consume(conf, f))

    return conf


def _consume(conf: Configuration, content: Iterable[str]) -> Iterator[None]:
    # Feed 'content' to the Configuration parser and recursively process
    # every include directive it yields.
    for include_path, include_type in conf.parse(content):
        yield from parse_include(conf, include_path, include_type)


def parse_string(string: str, source: str | None = None) -> Configuration:
    """Parse configuration data from a string.

    Optional *source* argument can be used to set the context path of built
    Configuration.

    :raises ParseError: if the string contains include directives referencing a relative
        path and *source* is unspecified.
    """
    conf = Configuration(source)
    conf.parse_string(string)
    return conf


def parse_include(
    conf: Configuration,
    path: pathlib.Path,
    include_type: IncludeType,
    *,
    _processed: set[pathlib.Path] | None = None,
) -> Iterator[None]:
    """Parse one include directive with 'path' value of type 'include_type'
    into 'conf' object.

    :raises FileNotFoundError: if a mandatory included file or directory is
        missing.
    :raises RuntimeError: if a loop is detected in include directives.
    :raises ParseError: if 'path' is relative and 'conf' has no path context.
    """
    if _processed is None:
        _processed = set()

    def notfound(
        path: pathlib.Path, include_type: str, reference_path: str | None
    ) -> FileNotFoundError:
        # Only mention the including file when it is known; previously the
        # message read "..., included from , not found" for anonymous
        # sources (reference_path=None).
        origin = (
            f", included from {reference_path!r},"
            if reference_path is not None
            else ""
        )
        return FileNotFoundError(f"{include_type} '{path}'{origin} not found")

    if not path.is_absolute():
        # Relative include paths are resolved against the directory of the
        # including file, as PostgreSQL does.
        if not conf.path:
            raise ParseError(
                "cannot process include directives referencing a relative path"
            )
        relative_to = pathlib.Path(conf.path).absolute()
        assert relative_to.is_absolute()
        if relative_to.is_file():
            relative_to = relative_to.parent
        path = relative_to / path

    if include_type == IncludeType.include_dir:
        if not path.exists() or not path.is_dir():
            raise notfound(path, "directory", conf.path)
        # Mimic PostgreSQL: process '*.conf' files in name order, skipping
        # hidden files.
        for confpath in sorted(path.glob("*.conf")):
            if not confpath.name.startswith("."):
                yield from parse_include(
                    conf,
                    confpath,
                    IncludeType.include,
                    _processed=_processed,
                )

    elif include_type == IncludeType.include_if_exists:
        # A missing file is silently ignored for include_if_exists.
        if path.exists():
            yield from parse_include(
                conf, path, IncludeType.include, _processed=_processed
            )

    elif include_type == IncludeType.include:
        if not path.exists():
            raise notfound(path, "file", conf.path)

        if path in _processed:
            raise RuntimeError(f"loop detected in include directive about '{path}'")
        _processed.add(path)

        subconf = Configuration(path=str(path))
        with path.open() as f:
            for sub_include_path, sub_include_type in subconf.parse(f):
                yield from parse_include(
                    subconf,
                    sub_include_path,
                    sub_include_type,
                    _processed=_processed,
                )
        # Values from the included file override previously parsed ones.
        conf.entries.update(subconf.entries)

    else:
        assert False, include_type  # pragma: nocover
MEMORY_MULTIPLIERS = { 199 | "kB": 1024, 200 | "MB": 1024 * 1024, 201 | "GB": 1024 * 1024 * 1024, 202 | "TB": 1024 * 1024 * 1024 * 1024, 203 | } 204 | _memory_re = re.compile(r"^\s*(?P\d+)\s*(?P[kMGT]B)\s*$") 205 | TIMEDELTA_ARGNAME = { 206 | "ms": "milliseconds", 207 | "s": "seconds", 208 | "min": "minutes", 209 | "h": "hours", 210 | "d": "days", 211 | } 212 | _timedelta_re = re.compile(r"^\s*(?P\d+)\s*(?Pms|s|min|h|d)\s*$") 213 | 214 | _minute = 60 215 | _hour = 60 * _minute 216 | _day = 24 * _hour 217 | _timedelta_unit_map = [ 218 | ("d", _day), 219 | ("h", _hour), 220 | # The space before 'min' is intentional. I find '1 min' more readable 221 | # than '1min'. 222 | (" min", _minute), 223 | ("s", 1), 224 | ] 225 | 226 | 227 | Value = Union[str, bool, float, int, timedelta] 228 | 229 | 230 | def parse_value(raw: str) -> Value: 231 | # Ref. 232 | # https://www.postgresql.org/docs/current/static/config-setting.html#CONFIG-SETTING-NAMES-VALUES 233 | 234 | quoted = False 235 | if raw.startswith("'"): 236 | if not raw.endswith("'"): 237 | raise ValueError(raw) 238 | # unquote value and unescape quotes 239 | raw = raw[1:-1].replace("''", "'").replace(r"\'", "'") 240 | quoted = True 241 | 242 | if raw.startswith("0") and raw != "0": 243 | try: 244 | int(raw, base=8) 245 | return raw 246 | except ValueError: 247 | pass 248 | 249 | m = _memory_re.match(raw) 250 | if m: 251 | return raw.strip() 252 | 253 | m = _timedelta_re.match(raw) 254 | if m: 255 | unit = m.group("unit") 256 | arg = TIMEDELTA_ARGNAME[unit] 257 | kwargs = {arg: int(m.group("number"))} 258 | return timedelta(**kwargs) 259 | 260 | if raw.lower() in ("true", "yes", "on"): 261 | return True 262 | 263 | if raw.lower() in ("false", "no", "off"): 264 | return False 265 | 266 | if not quoted: 267 | try: 268 | return int(raw) 269 | except ValueError: 270 | try: 271 | return float(raw) 272 | except ValueError: 273 | return raw 274 | 275 | return raw 276 | 277 | 278 | def serialize_value(value: Value) -> str: 279 
| # This is the reverse of parse_value. 280 | if isinstance(value, bool): 281 | value = "on" if value else "off" 282 | elif isinstance(value, str): 283 | # Only quote if not already quoted. 284 | if not (value.startswith("'") and value.endswith("'")): 285 | # Only double quotes, if not already done; we assume this is 286 | # done everywhere in the string or nowhere. 287 | if "''" not in value and r"\'" not in value: 288 | value = value.replace("'", "''") 289 | value = "'%s'" % value 290 | elif isinstance(value, timedelta): 291 | seconds = value.days * _day + value.seconds 292 | if value.microseconds: 293 | unit = " ms" 294 | value = seconds * 1000 + value.microseconds // 1000 295 | else: 296 | for unit, mod in _timedelta_unit_map: 297 | if seconds % mod: 298 | continue 299 | value = seconds // mod 300 | break 301 | value = f"'{value}{unit}'" 302 | else: 303 | value = str(value) 304 | return value 305 | 306 | 307 | _unspecified: Any = object() 308 | 309 | 310 | @dataclass 311 | class Entry: 312 | """Configuration entry, parsed from a line in the configuration file.""" 313 | 314 | name: str 315 | _value: Value 316 | # _: KW_ONLY from Python 3.10 317 | commented: bool = False 318 | comment: str | None = None 319 | raw_line: str = field(default=_unspecified, compare=False, repr=False) 320 | 321 | def __post_init__(self) -> None: 322 | if self.raw_line is _unspecified: 323 | # We parse value only if not already parsed from a file 324 | if isinstance(self._value, str): 325 | self._value = parse_value(self._value) 326 | # Store the raw_line to track the position in the list of lines. 
class EntriesProxy(dict[str, Entry]):
    """Proxy object used during Configuration edition.

    Existing entries can be edited in place; new ones must go through
    :meth:`add` so that values are always wrapped in :class:`Entry`.

    >>> p = EntriesProxy(port=Entry('port', '5432'))
    >>> p['port'].value = '5433'
    >>> p['bonjour_name'] = 'pgserver'
    Traceback (most recent call last):
        ...
    TypeError: cannot set a key
    >>> p.add('port', 5433)
    Traceback (most recent call last):
        ...
    ValueError: 'port' key already present
    """

    def __setitem__(self, key: str, value: Any) -> NoReturn:
        # Direct assignment is forbidden; use add() which builds an Entry.
        raise TypeError("cannot set a key")

    def add(
        self,
        name: str,
        value: Value,
        *,
        commented: bool = False,
        comment: str | None = None,
    ) -> None:
        """Add a new entry.

        :raises ValueError: if an entry named 'name' already exists.
        """
        if name in self:
            raise ValueError(f"'{name}' key already present")
        entry = Entry(name, value, commented=commented, comment=comment)
        super().__setitem__(name, entry)


@dataclass
class Configuration:
    r"""Holds a parsed configuration.

    You can access parameter using attribute or dictionary syntax.

    >>> conf = parse(['port=5432\n', 'pg_stat_statement.min_duration = 3s\n'])
    >>> conf.port
    5432
    >>> conf['port'] = 5433
    >>> conf.port
    5433
    >>> conf.get("ssl", False)
    False

    Configuration instances can be merged with ``+`` or ``+=``, though lines
    (and thus comments) are discarded in the operation.

    .. attribute:: path

        Path to a file. Automatically set when calling :func:`parse` with a path
        to a file. This is default target for :meth:`save`.

    .. automethod:: edit
    .. automethod:: save

    """  # noqa

    # Internally, lines property contains an updated list of all comments and
    # entries serialized. When adding a setting or updating an existing one,
    # the serialized line is updated accordingly. This allows to keep comments
    # and serialize only what's needed. Other lines are just written as-is.

    path: str | None = None
    lines: list[str] = field(default_factory=list, init=False)
    entries: dict[str, Entry] = field(default_factory=OrderedDict, init=False)

    # Named groups restored: the pattern previously read "(?P[a-z_.]+)"
    # which is invalid regex syntax, while parse() pops groups by name.
    _parameter_re: ClassVar = re.compile(
        r"^(?P<name>[a-z_.]+)(?: +(?!=)| *= *)(?P<value>.*?)"
        "[\\s\t]*"
        r"(?P<comment>#.*)?$"
    )

    def parse(self, fo: Iterable[str]) -> Iterator[tuple[pathlib.Path, IncludeType]]:
        """Parse lines from 'fo', filling :attr:`entries` and :attr:`lines`.

        Yields a ``(path, include_type)`` pair for every uncommented include
        directive encountered; the caller is responsible for processing them.

        :raises ValueError: on an uncommented line that is not a parameter.
        """
        for raw_line in fo:
            self.lines.append(raw_line)
            line = raw_line.strip()
            if not line:
                continue
            commented = False
            if line.startswith("#"):
                # Try to parse the commented line as a commented parameter,
                # but only if in the form of 'name = value' since we cannot
                # discriminate a commented sentence (with whitespaces) from a
                # commented parameter in the form of 'name value'.
                if "=" not in line:
                    continue
                line = line.lstrip("#").lstrip()
                m = self._parameter_re.match(line)
                if not m:
                    # This is a real comment
                    continue
                commented = True
            else:
                m = self._parameter_re.match(line)
                if not m:
                    raise ValueError("Bad line: %r." % raw_line)
            kwargs = m.groupdict()
            name = kwargs.pop("name")
            value = parse_value(kwargs.pop("value"))
            if name in IncludeType.__members__:
                if not commented:
                    include_type = IncludeType[name]
                    assert isinstance(value, str), type(value)
                    yield (pathlib.Path(value), include_type)
            else:
                comment = kwargs["comment"]
                if comment is not None:
                    kwargs["comment"] = comment.lstrip("#").lstrip()
                if commented:
                    # Only overwrite a previous entry if it is commented.
                    try:
                        existing_entry = self.entries[name]
                    except KeyError:
                        pass
                    else:
                        if not existing_entry.commented:
                            continue
                self.entries[name] = Entry(
                    name, value, commented=commented, raw_line=raw_line, **kwargs
                )

    def parse_string(self, string: str) -> None:
        """Parse configuration data from 'string', processing includes."""
        list(_consume(self, string.splitlines(keepends=True)))

    def __add__(self, other: Any) -> Configuration:
        # Merge entries from both configurations; lines are discarded.
        cls = self.__class__
        if not isinstance(other, cls):
            return NotImplemented
        s = cls()
        s.entries.update(self.entries)
        s.entries.update(other.entries)
        return s

    def __iadd__(self, other: Any) -> Configuration:
        cls = self.__class__
        if not isinstance(other, cls):
            return NotImplemented
        self.lines[:] = []
        self.entries.update(other.entries)
        return self

    def __getattr__(self, name: str) -> Value:
        try:
            return self.entries[name].value
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name: str, value: Value) -> None:
        # Real dataclass fields are set normally; anything else is treated
        # as a configuration parameter.
        if name in self.__dataclass_fields__:
            super().__setattr__(name, value)
        else:
            self[name] = value

    def __contains__(self, key: str) -> bool:
        return key in self.entries

    def __getitem__(self, key: str) -> Value:
        return self.entries[key].value

    def __setitem__(self, key: str, value: Value) -> None:
        if key in IncludeType.__members__:
            raise ValueError("cannot add an include directive")
        if key in self.entries:
            e = self.entries[key]
            e.value = value
            self._update_entry(e)
        else:
            self._add_entry(Entry(key, value))

    def get(self, key: str, default: Value | None = None) -> Value | None:
        """Return the value for 'key', or 'default' if absent."""
        try:
            return self[key]
        except KeyError:
            return default

    def _add_entry(self, entry: Entry) -> None:
        # Register a brand-new entry and append its serialized line.
        assert entry.name not in self.entries
        self.entries[entry.name] = entry
        # Append serialized line.
        entry.raw_line = str(entry) + "\n"
        self.lines.append(entry.raw_line)

    def _update_entry(self, entry: Entry) -> None:
        key = entry.name
        old_entry, self.entries[key] = self.entries[key], entry
        if old_entry.commented:
            # If the entry was previously commented, we uncomment it (assuming
            # that setting a value to a commented entry does not make much
            # sense.)
            entry.commented = False
        # Update serialized entry.
        old_line = old_entry.raw_line
        entry.raw_line = str(entry) + "\n"
        try:
            lineno = self.lines.index(old_line)
        except ValueError:
            if not entry.commented:
                msg = (
                    f"entry {key!r} not directly found in {self.path or 'parsed content'}"
                    " (it might be defined in an included file),"
                    " appending a new line to set requested value"
                )
                warn(msg, UserWarning)
                self.lines.append(entry.raw_line)
        else:
            self.lines[lineno : lineno + 1] = [entry.raw_line]

    def __iter__(self) -> Iterator[Entry]:
        return iter(self.entries.values())

    def as_dict(self) -> dict[str, Value]:
        """Return uncommented entries as a plain name-to-value dict."""
        return {k: v.value for k, v in self.entries.items() if not v.commented}

    @contextlib.contextmanager
    def edit(self) -> Iterator[EntriesProxy]:
        """Context manager allowing edition of the Configuration instance.

        Yields an :class:`EntriesProxy` holding copies of current entries; on
        normal exit, additions, updates and deletions performed on the proxy
        are applied back to this instance (and to its serialized lines). If
        the managed block raises, nothing is changed.
        """
        entries = EntriesProxy({k: copy.copy(v) for k, v in self.entries.items()})
        try:
            yield entries
        except Exception:
            raise
        else:
            # Add or update entries.
            for k, entry in entries.items():
                assert isinstance(entry, Entry), "expecting Entry values"
                if k not in self:
                    self._add_entry(entry)
                elif self.entries[k] != entry:
                    self._update_entry(entry)
            # Discard removed entries.
            for k, entry in list(self.entries.items()):
                if k not in entries:
                    del self.entries[k]
                    if entry.raw_line is not None:
                        self.lines.remove(entry.raw_line)

    def save(self, fo: str | pathlib.Path | IO[str] | None = None) -> None:
        """Write configuration to a file.

        Configuration entries order and comments are preserved.

        :param fo: A path or file-like object. Required if :attr:`path` is
            None.

        """
        with open_or_return(fo or self.path, mode="w") as fo:
            for line in self.lines:
                fo.write(line)


def _main(argv: list[str]) -> int:  # pragma: nocover
    """CLI entry point: dump a configuration file as a JSON object."""
    try:
        conf = parse(argv[0] if argv else sys.stdin)
        print(json.dumps(conf.as_dict(), cls=JSONDateEncoder, indent=2))
        return 0
    except Exception as e:
        print(str(e), file=sys.stderr)
        return 1


if __name__ == "__main__":  # pragma: nocover
    exit(_main(sys.argv[1:]))
autoclass:: AsyncCommandRunner 18 | :members: __call__ 19 | """ 20 | 21 | from __future__ import annotations 22 | 23 | import abc 24 | import asyncio 25 | import enum 26 | import re 27 | import shutil 28 | import subprocess 29 | from collections.abc import Mapping, Sequence 30 | from functools import cached_property 31 | from pathlib import Path 32 | from typing import TYPE_CHECKING, Any, Literal, Protocol 33 | 34 | if TYPE_CHECKING: 35 | CompletedProcess = subprocess.CompletedProcess[str] 36 | else: 37 | CompletedProcess = subprocess.CompletedProcess 38 | 39 | 40 | class CommandRunner(Protocol): 41 | """Protocol for `run_command` callable parameter of :class:`PGCtl`. 42 | 43 | The `text` mode, as defined in :mod:`subprocess`, must be used in 44 | implementations. 45 | 46 | Keyword arguments are expected to match that of :func:`subprocess.run`. 47 | """ 48 | 49 | def __call__( 50 | self, 51 | args: Sequence[str], 52 | *, 53 | capture_output: bool = False, 54 | check: bool = False, 55 | **kwargs: Any, 56 | ) -> CompletedProcess: ... 57 | 58 | 59 | class AsyncCommandRunner(Protocol): 60 | """Protocol for `run_command` callable parameter of :class:`PGCtl`. 61 | 62 | The `text` mode, as defined in :mod:`subprocess`, must be used in 63 | implementations. 64 | 65 | Keyword arguments are expected to match that of :func:`subprocess.run`. 66 | """ 67 | 68 | async def __call__( 69 | self, 70 | args: Sequence[str], 71 | *, 72 | capture_output: bool = False, 73 | check: bool = False, 74 | **kwargs: Any, 75 | ) -> CompletedProcess: ... 76 | 77 | 78 | def run_command( 79 | args: Sequence[str], 80 | *, 81 | check: bool = False, 82 | **kwargs: Any, 83 | ) -> CompletedProcess: 84 | """Default :class:`CommandRunner` implementation for :class:`PGCtl` using 85 | :func:`subprocess.run`. 
86 | """ 87 | return subprocess.run(args, check=check, text=True, **kwargs) 88 | 89 | 90 | async def asyncio_run_command( 91 | args: Sequence[str], 92 | *, 93 | capture_output: bool = False, 94 | check: bool = False, 95 | **kwargs: Any, 96 | ) -> CompletedProcess: 97 | """Default :class:`AsyncCommandRunner` implementation for 98 | :class:`AsyncPGCtl` using :func:`asyncio.subprocess`. 99 | """ 100 | if capture_output: 101 | kwargs["stdout"] = kwargs["stderr"] = subprocess.PIPE 102 | proc = await asyncio.create_subprocess_exec(*args, **kwargs) 103 | stdout, stderr = await proc.communicate() 104 | assert proc.returncode is not None 105 | r = CompletedProcess( 106 | args, 107 | proc.returncode, 108 | stdout.decode() if stdout is not None else None, 109 | stderr.decode() if stderr is not None else None, 110 | ) 111 | if check: 112 | r.check_returncode() 113 | return r 114 | 115 | 116 | def _args_to_opts(args: Mapping[str, str | Literal[True]]) -> list[str]: 117 | options = [] 118 | for name, value in sorted(args.items()): 119 | short = len(name) == 1 120 | name = name.replace("_", "-") 121 | if value is True: 122 | opt = f"-{name}" if short else f"--{name}" 123 | else: 124 | opt = f"-{name} {value}" if short else f"--{name}={value}" 125 | options.append(opt) 126 | return options 127 | 128 | 129 | def _wait_args_to_opts(wait: bool | int) -> list[str]: 130 | options = [] 131 | if not wait: 132 | options.append("--no-wait") 133 | else: 134 | options.append("--wait") 135 | if isinstance(wait, int) and not isinstance(wait, bool): 136 | options.append(f"--timeout={wait}") 137 | return options 138 | 139 | 140 | @enum.unique 141 | class Status(enum.IntEnum): 142 | """PostgreSQL cluster runtime status.""" 143 | 144 | running = 0 145 | """Running""" 146 | not_running = 3 147 | """Not running""" 148 | unspecified_datadir = 4 149 | """Unspecified data directory""" 150 | 151 | 152 | class AbstractPGCtl(abc.ABC): 153 | bindir: Path 154 | 155 | @cached_property 156 | def pg_ctl(self) 
-> Path: 157 | """Path to ``pg_ctl`` executable.""" 158 | value = self.bindir / "pg_ctl" 159 | if not value.exists(): 160 | raise OSError("pg_ctl executable not found") 161 | return value 162 | 163 | def init_cmd(self, datadir: Path | str, **opts: str | Literal[True]) -> list[str]: 164 | cmd = [str(self.pg_ctl), "init"] + ["-D", str(datadir)] 165 | options = _args_to_opts(opts) 166 | if options: 167 | cmd.extend(["-o", " ".join(options)]) 168 | return cmd 169 | 170 | def start_cmd( 171 | self, 172 | datadir: Path | str, 173 | *, 174 | wait: bool | int = True, 175 | logfile: Path | str | None = None, 176 | **opts: str | Literal[True], 177 | ) -> list[str]: 178 | cmd = [str(self.pg_ctl), "start"] + ["-D", str(datadir)] 179 | cmd.extend(_wait_args_to_opts(wait)) 180 | if logfile: 181 | cmd.append(f"--log={logfile}") 182 | options = _args_to_opts(opts) 183 | if options: 184 | cmd.extend(["-o", " ".join(options)]) 185 | return cmd 186 | 187 | def stop_cmd( 188 | self, 189 | datadir: Path | str, 190 | *, 191 | mode: str | None = None, 192 | wait: bool | int = True, 193 | ) -> list[str]: 194 | cmd = [str(self.pg_ctl), "stop"] + ["-D", str(datadir)] 195 | cmd.extend(_wait_args_to_opts(wait)) 196 | if mode: 197 | cmd.append(f"--mode={mode}") 198 | return cmd 199 | 200 | def restart_cmd( 201 | self, 202 | datadir: Path | str, 203 | *, 204 | mode: str | None = None, 205 | wait: bool | int = True, 206 | **opts: str | Literal[True], 207 | ) -> list[str]: 208 | cmd = [str(self.pg_ctl), "restart"] + ["-D", str(datadir)] 209 | cmd.extend(_wait_args_to_opts(wait)) 210 | if mode: 211 | cmd.append(f"--mode={mode}") 212 | options = _args_to_opts(opts) 213 | if options: 214 | cmd.extend(["-o", " ".join(options)]) 215 | return cmd 216 | 217 | def reload_cmd(self, datadir: Path | str) -> list[str]: 218 | return [str(self.pg_ctl), "reload"] + ["-D", str(datadir)] 219 | 220 | def status_cmd(self, datadir: Path | str) -> list[str]: 221 | return [str(self.pg_ctl), "status"] + ["-D", 
str(datadir)] 222 | 223 | def controldata_cmd(self, datadir: Path | str) -> list[str]: 224 | pg_controldata = self.bindir / "pg_controldata" 225 | if not pg_controldata.exists(): 226 | raise OSError("pg_controldata executable not found") 227 | return [str(pg_controldata)] + ["-D", str(datadir)] 228 | 229 | def _parse_control_data(self, lines: list[str]) -> dict[str, str]: 230 | """Parse pg_controldata command output.""" 231 | controldata = {} 232 | for line in lines: 233 | m = re.match(r"^([^:]+):(.*)$", line) 234 | if m: 235 | controldata[m.group(1).strip()] = m.group(2).strip() 236 | return controldata 237 | 238 | 239 | class PGCtl(AbstractPGCtl): 240 | """Handler for pg_ctl commands. 241 | 242 | :param bindir: location of postgresql user executables; if not specified, 243 | this will be determined by calling ``pg_config`` if that executable is 244 | found in ``$PATH``. 245 | :param run_command: callable implementing :class:`CommandRunner` that will 246 | be used to execute ``pg_ctl`` commands. 247 | 248 | :raises: :class:`OSError` if either ``pg_config`` or ``pg_ctl`` 249 | is not available. 250 | """ 251 | 252 | run_command: CommandRunner 253 | 254 | def __init__( 255 | self, 256 | bindir: str | Path | None = None, 257 | *, 258 | run_command: CommandRunner = run_command, 259 | ) -> None: 260 | if bindir is None: 261 | pg_config = shutil.which("pg_config") 262 | if pg_config is None: 263 | raise OSError("pg_config executable not found") 264 | bindir = run_command( 265 | [pg_config, "--bindir"], check=True, capture_output=True 266 | ).stdout.strip() 267 | self.bindir = Path(bindir) 268 | self.run_command = run_command 269 | 270 | def init( 271 | self, datadir: Path | str, **opts: str | Literal[True] 272 | ) -> CompletedProcess: 273 | """Initialize a PostgreSQL cluster (initdb) at `datadir`. 
class PGCtl(AbstractPGCtl):
    """Handler for pg_ctl commands.

    :param bindir: location of postgresql user executables; if not specified,
        this will be determined by calling ``pg_config`` if that executable is
        found in ``$PATH``.
    :param run_command: callable implementing :class:`CommandRunner` that will
        be used to execute ``pg_ctl`` commands.

    :raises: :class:`OSError` if either ``pg_config`` or ``pg_ctl``
        is not available.
    """

    run_command: CommandRunner

    def __init__(
        self,
        bindir: str | Path | None = None,
        *,
        run_command: CommandRunner = run_command,
    ) -> None:
        if bindir is None:
            # Locate the executables through pg_config when not provided.
            pg_config = shutil.which("pg_config")
            if pg_config is None:
                raise OSError("pg_config executable not found")
            bindir = run_command(
                [pg_config, "--bindir"], check=True, capture_output=True
            ).stdout.strip()
        self.bindir = Path(bindir)
        self.run_command = run_command

    def init(
        self, datadir: Path | str, **opts: str | Literal[True]
    ) -> CompletedProcess:
        """Initialize a PostgreSQL cluster (initdb) at `datadir`.

        :param datadir: Path to database storage area
        :param opts: extra options passed to initdb

        Options name passed as `opts` should be underscore'd instead dash'ed
        and flag options should be passed a boolean ``True`` value; e.g.
        ``auth_local="md5", data_checksums=True`` for ``pg_ctl init -o
        '--auth-local=md5 --data-checksums'``.
        """
        return self.run_command(self.init_cmd(datadir, **opts), check=True)

    def start(
        self,
        datadir: Path | str,
        *,
        wait: bool | int = True,
        logfile: Path | str | None = None,
        **opts: str | Literal[True],
    ) -> CompletedProcess:
        """Start a PostgreSQL cluster.

        :param datadir: Path to database storage area
        :param wait: Wait until operation completes, if an integer value is
            passed, this will be used as --timeout value.
        :param logfile: Optional log file path
        :param opts: extra options passed to ``postgres`` command.

        Options name passed as `opts` should be underscore'd instead of dash'ed
        and flag options should be passed a boolean ``True`` value; e.g.
        ``F=True, work_mem=123`` for ``pg_ctl start -o '-F --work-mem=123'``.
        """
        return self.run_command(
            self.start_cmd(datadir, wait=wait, logfile=logfile, **opts), check=True
        )

    def stop(
        self,
        datadir: Path | str,
        *,
        mode: str | None = None,
        wait: bool | int = True,
    ) -> CompletedProcess:
        """Stop a PostgreSQL cluster.

        :param datadir: Path to database storage area
        :param mode: Shutdown mode, can be "smart", "fast", or "immediate"
        :param wait: Wait until operation completes, if an integer value is
            passed, this will be used as --timeout value.
        """
        return self.run_command(
            self.stop_cmd(datadir, mode=mode, wait=wait), check=True
        )

    def restart(
        self,
        datadir: Path | str,
        *,
        mode: str | None = None,
        wait: bool | int = True,
        **opts: str | Literal[True],
    ) -> CompletedProcess:
        """Restart a PostgreSQL cluster.

        :param datadir: Path to database storage area
        :param mode: Shutdown mode, can be "smart", "fast", or "immediate"
        :param wait: Wait until operation completes, if an integer value is
            passed, this will be used as --timeout value.
        :param opts: extra options passed to ``postgres`` command.

        Options name passed as `opts` should be underscore'd instead of dash'ed
        and flag options should be passed a boolean ``True`` value; e.g.
        ``F=True, work_mem=123`` for ``pg_ctl restart -o '-F --work-mem=123'``.
        """
        return self.run_command(
            self.restart_cmd(datadir, mode=mode, wait=wait, **opts), check=True
        )

    def reload(
        self,
        datadir: Path | str,
    ) -> CompletedProcess:
        """Reload a PostgreSQL cluster.

        :param datadir: Path to database storage area
        """
        return self.run_command(self.reload_cmd(datadir), check=True)

    def status(self, datadir: Path | str) -> Status:
        """Check PostgreSQL cluster status.

        :param datadir: Path to database storage area
        :return: Status value.
        """
        proc = self.run_command(self.status_cmd(datadir))
        if proc.returncode == 1:
            # Exit code 1 means pg_ctl usage error, not a valid status.
            raise subprocess.CalledProcessError(
                proc.returncode, proc.args, proc.stdout, proc.stderr
            )
        return Status(proc.returncode)

    def controldata(self, datadir: Path | str) -> dict[str, str]:
        """Run the pg_controldata command and parse the result to return
        controldata as dict.

        :param datadir: Path to database storage area
        """
        out = self.run_command(
            self.controldata_cmd(datadir),
            check=True,
            env={"LC_ALL": "C"},
            capture_output=True,
        ).stdout
        return parse_control_data(out.splitlines())


class AsyncPGCtl(AbstractPGCtl):
    """Async handler for pg_ctl commands.

    See :class:`PGCtl` for the interface.
    """

    run_command: AsyncCommandRunner

    def __init__(self, bindir: Path, run_command: AsyncCommandRunner) -> None:
        self.bindir = bindir
        self.run_command = run_command

    @classmethod
    async def get(
        cls,
        bindir: str | Path | None = None,
        *,
        run_command: AsyncCommandRunner = asyncio_run_command,
    ) -> AsyncPGCtl:
        """Construct an AsyncPGCtl instance from specified or inferred 'bindir'.

        :param bindir: location of postgresql user executables; if not specified,
            this will be determined by calling ``pg_config`` if that executable is
            found in ``$PATH``.
        :param run_command: callable implementing :class:`CommandRunner` that will
            be used to execute ``pg_ctl`` commands.

        :raises: :class:`OSError` if either ``pg_config`` or ``pg_ctl``
            is not available.
        """
        if bindir is None:
            pg_config = shutil.which("pg_config")
            if pg_config is None:
                raise OSError("pg_config executable not found")
            cp = await run_command(
                [pg_config, "--bindir"], check=True, capture_output=True
            )
            bindir = cp.stdout.strip()
        return cls(Path(bindir), run_command)

    async def init(
        self, datadir: Path | str, **opts: str | Literal[True]
    ) -> CompletedProcess:
        return await self.run_command(self.init_cmd(datadir, **opts), check=True)

    async def start(
        self,
        datadir: Path | str,
        *,
        wait: bool | int = True,
        logfile: Path | str | None = None,
        **opts: str | Literal[True],
    ) -> CompletedProcess:
        return await self.run_command(
            self.start_cmd(datadir, wait=wait, logfile=logfile, **opts), check=True
        )

    async def stop(
        self,
        datadir: Path | str,
        *,
        mode: str | None = None,
        wait: bool | int = True,
    ) -> CompletedProcess:
        return await self.run_command(
            self.stop_cmd(datadir, mode=mode, wait=wait), check=True
        )

    async def restart(
        self,
        datadir: Path | str,
        *,
        mode: str | None = None,
        wait: bool | int = True,
        **opts: str | Literal[True],
    ) -> CompletedProcess:
        return await self.run_command(
            self.restart_cmd(datadir, mode=mode, wait=wait, **opts), check=True
        )

    async def reload(
        self,
        datadir: Path | str,
    ) -> CompletedProcess:
        return await self.run_command(self.reload_cmd(datadir), check=True)

    async def status(self, datadir: Path | str) -> Status:
        proc = await self.run_command(self.status_cmd(datadir))
        if proc.returncode == 1:
            raise subprocess.CalledProcessError(
                proc.returncode, proc.args, proc.stdout, proc.stderr
            )
        return Status(proc.returncode)

    async def controldata(self, datadir: Path | str) -> dict[str, str]:
        cp = await self.run_command(
            self.controldata_cmd(datadir),
            check=True,
            env={"LC_ALL": "C"},
            capture_output=True,
        )
        return parse_control_data(cp.stdout.splitlines())


def parse_control_data(lines: Sequence[str]) -> dict[str, str]:
    """Parse pg_controldata command output into a key/value dict."""
    data: dict[str, str] = {}
    for line in lines:
        m = re.match(r"^([^:]+):(.*)$", line)
        if m:
            data[m.group(1).strip()] = m.group(2).strip()
    return data


class ParseError(Exception):
    """Parsing error carrying the offending line number and text."""

    def __init__(self, lineno: int, line: str, message: str) -> None:
        super().__init__(message)
        self.lineno = lineno
        self.line = line

    def __repr__(self) -> str:
        return "<%s at line %d: %.32s>" % (
            self.__class__.__name__,
            self.lineno,
            self.args[0],
        )

    def __str__(self) -> str:
        return "Bad line #{} '{:.32}': {}".format(
            self.lineno,
            self.line.strip(),
            self.args[0],
        )
autoclass:: HBA 18 | .. autoclass:: HBARecord 19 | 20 | 21 | Examples 22 | -------- 23 | 24 | Loading a ``pg_hba.conf`` file : 25 | 26 | .. code:: python 27 | 28 | pgpass = parse('my_pg_hba.conf') 29 | 30 | You can also pass a file-object: 31 | 32 | .. code:: python 33 | 34 | with open('my_pg_hba.conf', 'r') as fo: 35 | hba = parse(fo) 36 | 37 | Creating a ``pg_hba.conf`` file from scratch : 38 | 39 | .. code:: python 40 | 41 | hba = HBA() 42 | record = HBARecord( 43 | conntype='local', database='all', user='all', method='peer', 44 | ) 45 | hba.lines.append(record) 46 | 47 | with open('pg_hba.conf', 'w') as fo: 48 | hba.save(fo) 49 | 50 | 51 | Using as a script 52 | ----------------- 53 | 54 | :mod:`pgtoolkit.hba` is usable as a CLI script. It accepts a pg_hba file path 55 | as first argument, read it, validate it and re-render it. Fields are aligned to 56 | fit pseudo-column width. If filename is ``-``, stdin is read instead. 57 | 58 | .. code:: console 59 | 60 | $ python -m pgtoolkit.hba - < data/pg_hba.conf 61 | # TYPE DATABASE USER ADDRESS METHOD 62 | 63 | # "local" is for Unix domain socket connections only 64 | local all all trust 65 | # IPv4 local connections: 66 | host all all 127.0.0.1/32 ident map=omicron 67 | 68 | """ # noqa 69 | 70 | from __future__ import annotations 71 | 72 | import os 73 | import re 74 | import sys 75 | import warnings 76 | from collections.abc import Callable, Iterable, Iterator 77 | from dataclasses import dataclass, field 78 | from pathlib import Path 79 | from typing import IO, Any 80 | 81 | from ._helpers import open_or_return, open_or_stdin 82 | from .errors import ParseError 83 | 84 | 85 | class HBAComment(str): 86 | def __repr__(self) -> str: 87 | return f"<{self.__class__.__name__} {self:.32}>" 88 | 89 | 90 | class HBARecord: 91 | """Holds a HBA record composed of fields and a comment. 92 | 93 | Common fields are accessible through attribute : ``conntype``, 94 | ``database``, ``user``, ``address``, ``netmask``, ``method``. 
95 | Auth-options fields are also accessible through attribute like ``map``, 96 | ``ldapserver``, etc. 97 | 98 | ``address`` and ``netmask`` fields are not always defined. If not, 99 | accessing undefined attributes trigger an :exc:`AttributeError`. 100 | 101 | .. automethod:: parse 102 | .. automethod:: __init__ 103 | .. automethod:: __str__ 104 | .. automethod:: matches 105 | .. autoattribute:: database 106 | .. autoattribute:: user 107 | 108 | """ 109 | 110 | COMMON_FIELDS = [ 111 | "conntype", 112 | "database", 113 | "user", 114 | "address", 115 | "netmask", 116 | "method", 117 | ] 118 | CONNECTION_TYPES = [ 119 | "local", 120 | "host", 121 | "hostssl", 122 | "hostnossl", 123 | "hostgssenc", 124 | "hostnogssenc", 125 | ] 126 | 127 | @classmethod 128 | def parse(cls, line: str) -> HBARecord: 129 | """Parse a HBA record 130 | 131 | :rtype: :class:`HBARecord` or a :class:`str` for a comment or blank 132 | line. 133 | :raises ValueError: If connection type is wrong. 134 | 135 | """ 136 | line = line.strip() 137 | record_fields = ["conntype", "database", "user"] 138 | 139 | # What the regexp below does is finding all elements separated by spaces 140 | # unless they are enclosed in double-quotes 141 | # (?: … )+ = non-capturing group 142 | # \"+.*?\"+ = any element with or without spaces enclosed within 143 | # double-quotes (alternative 1) 144 | # \S = any non-whitespace character (alternative 2) 145 | values = [p for p in re.findall(r"(?:\"+.*?\"+|\S)+", line) if p.strip()] 146 | assert len(values) > 2 147 | try: 148 | hash_pos = values.index("#") 149 | except ValueError: 150 | comment = None 151 | else: 152 | values, comments = values[:hash_pos], values[hash_pos:] 153 | comment = " ".join(comments[1:]) 154 | 155 | if values[0] not in cls.CONNECTION_TYPES: 156 | raise ValueError("Unknown connection type '%s'" % values[0]) 157 | if "local" != values[0]: 158 | record_fields.append("address") 159 | common_values = [v for v in values if "=" not in v] 160 | if 
len(common_values) >= 6: 161 | record_fields.append("netmask") 162 | record_fields.append("method") 163 | base_options = list(zip(record_fields, values[: len(record_fields)])) 164 | auth_options = [o.split("=", 1) for o in values[len(record_fields) :]] 165 | # Remove extra outer double quotes for auth options values if any 166 | auth_options = [(o[0], re.sub(r"^\"|\"$", "", o[1])) for o in auth_options] 167 | options = base_options + auth_options 168 | return cls(**{k: v for k, v in options}, comment=comment) 169 | 170 | conntype: str | None 171 | database: str 172 | user: str 173 | 174 | def __init__(self, *, comment: str | None = None, **values: Any) -> None: 175 | """ 176 | :param comment: Optional comment. 177 | :param values: Fields passed as keyword. 178 | """ 179 | self.__dict__.update(values) 180 | self.comment = comment 181 | self.fields = list(values) 182 | 183 | def __repr__(self) -> str: 184 | return "<{} {}{}>".format( 185 | self.__class__.__name__, 186 | " ".join(self.common_values), 187 | "..." 
if self.auth_options else "", 188 | ) 189 | 190 | def __str__(self) -> str: 191 | """Serialize a record line, without EOL.""" 192 | # Stolen from default pg_hba.conf 193 | widths = [8, 16, 16, 16, 8] 194 | 195 | fmt = "" 196 | for i, field_ in enumerate(self.COMMON_FIELDS): 197 | try: 198 | width = widths[i] 199 | except IndexError: 200 | width = 0 201 | 202 | if field_ not in self.fields: 203 | fmt += " " * width 204 | continue 205 | 206 | if width: 207 | fmt += "%%(%s)-%ds " % (field_, width - 1) 208 | else: 209 | fmt += f"%({field_})s " 210 | line = fmt.rstrip() % self.__dict__ 211 | 212 | auth_options = ['%s="%s"' % i for i in self.auth_options] 213 | if auth_options: 214 | line += " " + " ".join(auth_options) 215 | 216 | if self.comment is not None: 217 | line += " # " + self.comment 218 | else: 219 | line = line.rstrip() 220 | 221 | return line 222 | 223 | def __eq__(self, other: object) -> bool: 224 | return str(self) == str(other) 225 | 226 | def as_dict(self) -> dict[str, Any]: 227 | str_fields = self.COMMON_FIELDS[:] 228 | return {f: getattr(self, f) for f in str_fields if hasattr(self, f)} 229 | 230 | @property 231 | def common_values(self) -> list[str]: 232 | str_fields = self.COMMON_FIELDS[:] 233 | return [getattr(self, f) for f in str_fields if f in self.fields] 234 | 235 | @property 236 | def auth_options(self) -> list[tuple[str, str]]: 237 | return [ 238 | (f, getattr(self, f)) for f in self.fields if f not in self.COMMON_FIELDS 239 | ] 240 | 241 | @property 242 | def databases(self) -> list[str]: 243 | return self.database.split(",") 244 | 245 | @property 246 | def users(self) -> list[str]: 247 | return self.user.split(",") 248 | 249 | def matches(self, **attrs: str) -> bool: 250 | """Tells if the current record is matching provided attributes. 251 | 252 | :param attrs: keyword/values pairs corresponding to one or more 253 | HBARecord attributes (ie. 
user, conntype, etc…) 254 | """ 255 | 256 | # Provided attributes should be comparable to HBARecord attributes 257 | for k in attrs.keys(): 258 | if k not in self.COMMON_FIELDS + ["database", "user"]: 259 | raise AttributeError("%s is not a valid attribute" % k) 260 | 261 | for k, v in attrs.items(): 262 | if getattr(self, k, None) != v: 263 | return False 264 | return True 265 | 266 | 267 | @dataclass 268 | class HBA: 269 | """Represents pg_hba.conf records 270 | 271 | .. attribute:: lines 272 | 273 | List of :class:`HBARecord` and comments. 274 | 275 | .. attribute:: path 276 | 277 | Path to a file. Is automatically set when calling :meth:`parse` with a 278 | path to a file. :meth:`save` will write to this file if set. 279 | 280 | .. automethod:: __iter__ 281 | .. automethod:: parse 282 | .. automethod:: save 283 | .. automethod:: remove 284 | .. automethod:: merge 285 | """ 286 | 287 | lines: list[HBAComment | HBARecord] = field(default_factory=list) 288 | path: str | Path | None = None 289 | 290 | def __iter__(self) -> Iterator[HBARecord]: 291 | """Iterate on records, ignoring comments and blank lines.""" 292 | for line in self.lines: 293 | if isinstance(line, HBARecord): 294 | yield line 295 | 296 | def parse(self, fo: Iterable[str]) -> None: 297 | """Parse records and comments from file object 298 | 299 | :param fo: An iterable returning lines 300 | """ 301 | for i, line in enumerate(fo): 302 | stripped = line.lstrip() 303 | record: HBARecord | HBAComment 304 | if not stripped or stripped.startswith("#"): 305 | record = HBAComment(line.replace(os.linesep, "")) 306 | else: 307 | try: 308 | record = HBARecord.parse(line) 309 | except Exception as e: 310 | raise ParseError(1 + i, line, str(e)) 311 | self.lines.append(record) 312 | 313 | def save(self, fo: str | Path | IO[str] | None = None) -> None: 314 | """Write records and comments in a file 315 | 316 | :param fo: a file-like object. Is not required if :attr:`path` is set. 317 | 318 | Line order is preserved. 
Record fields are vertically aligned to match 319 | the columen size of column headers from default configuration file. 320 | 321 | .. code:: 322 | 323 | # TYPE DATABASE USER ADDRESS METHOD 324 | local all all trust 325 | """ # noqa 326 | with open_or_return(fo or self.path, mode="w") as fo: 327 | for line in self.lines: 328 | fo.write(str(line) + os.linesep) 329 | 330 | def remove( 331 | self, 332 | filter: Callable[[HBARecord], bool] | None = None, 333 | **attrs: str, 334 | ) -> bool: 335 | """Remove records matching the provided attributes. 336 | 337 | One can for example remove all records for which user is 'david'. 338 | 339 | :param filter: a function to be used as filter. It is passed the record 340 | to test against. If it returns True, the record is removed. It is 341 | kept otherwise. 342 | :param attrs: keyword/values pairs correspond to one or more 343 | HBARecord attributes (ie. user, conntype, etc...) 344 | 345 | :returns: ``True`` if records have changed. 346 | 347 | Usage examples: 348 | 349 | .. code:: python 350 | 351 | hba.remove(filter=lamdba r: r.user == 'david') 352 | hba.remove(user='david') 353 | 354 | """ 355 | if filter is not None and len(attrs.keys()): 356 | warnings.warn("Only filter will be taken into account") 357 | 358 | # Attributes list to look for must not be empty 359 | if filter is None and not len(attrs.keys()): 360 | raise ValueError("Attributes dict cannot be empty") 361 | 362 | filter = filter or (lambda line: line.matches(**attrs)) 363 | 364 | lines_before = self.lines 365 | 366 | self.lines = [ 367 | line 368 | for line in self.lines 369 | if not (isinstance(line, HBARecord) and filter(line)) 370 | ] 371 | 372 | return lines_before != self.lines 373 | 374 | def merge(self, other: HBA) -> bool: 375 | """Add new records to HBAFile or replace them if they are matching 376 | (ie. same conntype, database, user and address) 377 | 378 | :param other: HBAFile to merge into the current one. 
Lines with matching conntype, database, user and address will be
421 | """ 422 | if isinstance(file, (str, Path)): 423 | with open(file) as fo: 424 | hba = parse(fo) 425 | hba.path = file 426 | else: 427 | hba = HBA() 428 | hba.parse(file) 429 | return hba 430 | 431 | 432 | if __name__ == "__main__": # pragma: nocover 433 | argv = sys.argv[1:] + ["-"] 434 | try: 435 | with open_or_stdin(argv[0]) as fo: 436 | hba = parse(fo) 437 | hba.save(sys.stdout) 438 | except Exception as e: 439 | print(str(e), file=sys.stderr) 440 | exit(1) 441 | -------------------------------------------------------------------------------- /pgtoolkit/log/__init__.py: -------------------------------------------------------------------------------- 1 | """\ 2 | .. currentmodule:: pgtoolkit.log 3 | 4 | Postgres logs are still the most comprehensive source of information on what's 5 | going on in a cluster. :mod:`pgtoolkit.log` provides a parser to exploit 6 | efficiently Postgres log records from Python. 7 | 8 | Parsing logs is tricky because format varies across configurations. Also 9 | performance is important while logs can contain thousands of records. 10 | 11 | 12 | Configuration 13 | ------------- 14 | 15 | Postgres log records have a prefix, configured with ``log_line_prefix`` cluster 16 | setting. When analyzing a log file, you must known the ``log_line_prefix`` 17 | value used to generate the records. 18 | 19 | Postgres can emit more message for your needs. See `Error Reporting and Logging 20 | section 21 | `_ 22 | if PostgreSQL documentation for details on logging fields and message type. 23 | 24 | 25 | Performance 26 | ----------- 27 | 28 | The fastest code is NOOP. Thus, the parser allows you to filter records as soon 29 | as possible. The parser has several distinct stages. After each stage, the 30 | parser calls a filter to determine whether to stop record processing. Here are 31 | the stages in processing order : 32 | 33 | 1. Split prefix, severity and message, determine message type. 34 | 2. Extract and decode prefix data 35 | 3. 
Extract and decode message data. 36 | 37 | 38 | Limitations 39 | ----------- 40 | 41 | :mod:`pgtoolkit.log` does not manage opening and uncompressing logs. It only 42 | accepts a line reader iterator that loops log lines. The same way, 43 | :mod:`pgtoolkit.log` does not manage to start analyze at a specific position in 44 | a file. 45 | 46 | :mod:`pgtoolkit.log` does not gather record set such as ``ERROR`` and 47 | following ``HINT`` record. It's up to the application to make sense of record 48 | sequences. 49 | 50 | :mod:`pgtoolkit.log` does not analyze log records. It's just a parser, a 51 | building block to write a log analyzer in your app. 52 | 53 | 54 | API Reference 55 | ------------- 56 | 57 | Here are the few functions and classes used to parse and access log records. 58 | 59 | .. autofunction:: parse 60 | .. autoclass:: LogParser 61 | .. autoclass:: PrefixParser 62 | .. autoclass:: Record 63 | .. autoclass:: UnknownData 64 | .. autoclass:: NoopFilters 65 | 66 | 67 | Example 68 | ------- 69 | 70 | Here is a sample structure of code parsing a plain log file. 71 | 72 | .. code-block:: python 73 | 74 | with open('postgresql.log') as fo: 75 | for r in parse(fo, prefix_fmt='%m [%p]'): 76 | if isinstance(r, UnknownData): 77 | "Process unknown data" 78 | else: 79 | "Process record" 80 | 81 | 82 | 83 | Using :mod:`pgtoolkit.log` as a script 84 | -------------------------------------- 85 | 86 | You can use this module to dump logs as JSON using the following usage:: 87 | 88 | python -m pgtoolkit.log [] 89 | 90 | :mod:`pgtoolkit.log` serializes each record as a JSON object on a single line. 91 | 92 | .. 
code:: console 93 | 94 | $ python -m pgtoolkit.log '%m [%p]: [%l-1] app=%a,db=%d%q,client=%h,user=%u ' data/postgresql.log 95 | {"severity": "LOG", "timestamp": "2018-06-15T10:49:31.000144", "message_type": "connection", "line_num": 2, "remote_host": "[local]", "application": "[unknown]", "user": "postgres", "message": "connection authorized: user=postgres database=postgres", "database": "postgres", "pid": 8423} 96 | {"severity": "LOG", "timestamp": "2018-06-15T10:49:34.000172", "message_type": "connection", "line_num": 1, "remote_host": "[local]", "application": "[unknown]", "user": "[unknown]", "message": "connection received: host=[local]", "database": "[unknown]", "pid": 8424} 97 | 98 | """ # noqa 99 | 100 | from __future__ import annotations 101 | 102 | from .parser import LogParser, NoopFilters, PrefixParser, Record, UnknownData, parse 103 | 104 | __all__ = [ 105 | o.__name__ # type: ignore[attr-defined] 106 | for o in [ 107 | LogParser, 108 | NoopFilters, 109 | PrefixParser, 110 | Record, 111 | UnknownData, 112 | parse, 113 | ] 114 | ] 115 | -------------------------------------------------------------------------------- /pgtoolkit/log/__main__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import bdb 4 | import json 5 | import logging 6 | import os 7 | import pdb 8 | import sys 9 | from argparse import ArgumentParser 10 | from collections.abc import MutableMapping 11 | from logging import basicConfig 12 | 13 | from .._helpers import JSONDateEncoder, Timer, open_or_stdin 14 | from .parser import UnknownData, parse 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def main( 20 | argv: list[str] = sys.argv[1:], 21 | environ: MutableMapping[str, str] = os.environ, 22 | ) -> int: 23 | debug = environ.get("DEBUG", "n").lower() in ("1", "y", "yes", "true", "on") 24 | basicConfig( 25 | level=logging.DEBUG if debug else logging.INFO, 26 | format="%(asctime)s 
%(levelname).1s: %(message)s", 27 | ) 28 | parser = ArgumentParser() 29 | # Default comes from PostgreSQL documentation. 30 | parser.add_argument( 31 | "log_line_prefix", 32 | default="%m [%p] ", 33 | metavar="LOG_LINE_PREFIX", 34 | help="log_line_prefix as configured in PostgreSQL. " "default: '%(default)s'", 35 | ) 36 | parser.add_argument( 37 | "filename", 38 | nargs="?", 39 | default="-", 40 | metavar="FILENAME", 41 | help="Log filename or - for stdin. default: %(default)s", 42 | ) 43 | args = parser.parse_args(argv) 44 | 45 | counter = 0 46 | try: 47 | with open_or_stdin(args.filename) as fo: 48 | with Timer() as timer: 49 | for record in parse(fo, prefix_fmt=args.log_line_prefix): 50 | if isinstance(record, UnknownData): 51 | logger.warning("%s", record) 52 | else: 53 | counter += 1 54 | print(json.dumps(record.as_dict(), cls=JSONDateEncoder)) 55 | logger.info("Parsed %d records in %s.", counter, timer.delta) 56 | except (KeyboardInterrupt, bdb.BdbQuit): # pragma: nocover 57 | logger.info("Interrupted.") 58 | return 1 59 | except Exception: 60 | logger.exception("Unhandled error:") 61 | if debug: # pragma: nocover 62 | pdb.post_mortem(sys.exc_info()[2]) 63 | return 1 64 | return 0 65 | 66 | 67 | if "__main__" == __name__: # pragma: nocover 68 | sys.exit(main(argv=sys.argv[1:], environ=os.environ)) 69 | -------------------------------------------------------------------------------- /pgtoolkit/log/parser.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | from collections.abc import ( 5 | Callable, 6 | Iterable, 7 | Iterator, 8 | Mapping, 9 | MutableMapping, 10 | Sequence, 11 | ) 12 | from datetime import datetime, timedelta, timezone 13 | from re import Pattern 14 | from typing import Any 15 | 16 | 17 | class LogParser: 18 | """Log parsing manager 19 | 20 | This object gather parsing parameters and trigger parsing logic. 
When 21 | parsing multiple files with the same parameters or when parsing multiple 22 | sets of lines, :class:`LogParser` object ease the initialization and 23 | preservation of parsing parameters. 24 | 25 | When parsing a single set of lines, one can use :func:`parse` helper 26 | instead. 27 | 28 | :param prefix_parser: An instance of :class:`PrefixParser`. 29 | :param filters: An instance of :class:`NoopFilters` 30 | 31 | """ 32 | 33 | def __init__( 34 | self, prefix_parser: PrefixParser, filters: NoopFilters | None = None 35 | ) -> None: 36 | self.prefix_parser = prefix_parser 37 | self.filters = filters or NoopFilters() 38 | 39 | def parse(self, fo: Iterable[str]) -> Iterator[Record | UnknownData]: 40 | """Yield records and unparsed data from file-like object ``fo`` 41 | 42 | :param fo: A line iterator such as a file object. 43 | :rtype: Iterator[:class:`Record` | :class:`UnknownData`] 44 | :returns: Yields either :class:`Record` or :class:`UnknownData` object. 45 | """ 46 | # Fast access variables to avoid attribute access overhead on each 47 | # line. 48 | parse_prefix = self.prefix_parser.parse 49 | stage1 = Record.parse_stage1 50 | filter_stage1 = self.filters.stage1 51 | filter_stage2 = self.filters.stage2 52 | filter_stage3 = self.filters.stage3 53 | 54 | for group in group_lines(fo): 55 | try: 56 | record = stage1(group) 57 | if filter_stage1(record): 58 | continue 59 | record.parse_stage2(parse_prefix) 60 | if filter_stage2(record): 61 | continue 62 | record.parse_stage3() 63 | if filter_stage3(record): 64 | continue 65 | except UnknownData as e: 66 | yield e 67 | else: 68 | yield record 69 | 70 | 71 | def parse( 72 | fo: Iterable[str], prefix_fmt: str, filters: NoopFilters | None = None 73 | ) -> Iterator[Record | UnknownData]: 74 | """Parses log lines and yield :class:`Record` or :class:`UnknownData` objects. 75 | 76 | This is a helper around :class:`LogParser` and :`PrefixParser`. 77 | 78 | :param fo: A line iterator such as a file-like object. 
79 | :param prefix_fmt: is exactly the value of ``log_line_prefix`` Postgresql 80 | settings. 81 | :param filters: is an object like :class:`NoopFilters` instance. 82 | 83 | See Example_ section for usage. 84 | 85 | """ 86 | 87 | parser = LogParser( 88 | PrefixParser.from_configuration(prefix_fmt), 89 | filters=filters, 90 | ) 91 | yield from parser.parse(fo) 92 | 93 | 94 | def group_lines(lines: Iterable[str], cont: str = "\t") -> Iterator[list[str]]: 95 | # Group continuation lines according to continuation prefix. Yield a list 96 | # on lines supposed to belong to the same log record. 97 | 98 | group: list[str] = [] 99 | for line in lines: 100 | if not line.startswith(cont) and group: 101 | yield group 102 | group = [] 103 | group.append(line) 104 | 105 | if group: 106 | yield group 107 | 108 | 109 | def parse_isodatetime(raw: str) -> datetime: 110 | try: 111 | infos = ( 112 | int(raw[:4]), 113 | int(raw[5:7]), 114 | int(raw[8:10]), 115 | int(raw[11:13]), 116 | int(raw[14:16]), 117 | int(raw[17:19]), 118 | int(raw[20:23]) if raw[19] == "." else 0, 119 | ) 120 | except ValueError: 121 | raise ValueError("%s is not a known date" % raw) 122 | 123 | if raw[-3:] != "UTC": 124 | # We need tzdata for that. 125 | raise ValueError("%s not in UTC." % raw) 126 | 127 | return datetime(*infos) 128 | 129 | 130 | def parse_epoch(raw: str) -> datetime: 131 | epoch, ms = raw.split(".") 132 | return datetime.fromtimestamp(int(epoch), timezone.utc) + timedelta( 133 | microseconds=int(ms) 134 | ) 135 | 136 | 137 | class UnknownData(Exception): 138 | """Represents unparsable data. 139 | 140 | :class:`UnknownData` is throwable, you can raise it. 141 | 142 | .. attribute:: lines 143 | 144 | The list of unparsable strings. 145 | """ 146 | 147 | # UnknownData object is an exception to be throwable. 
148 | 149 | def __init__(self, lines: Sequence[str]) -> None: 150 | self.lines = lines 151 | 152 | def __repr__(self) -> str: 153 | summary = str(self)[:32].replace("\n", "") 154 | return f"<{self.__class__.__name__} {summary}...>" 155 | 156 | def __str__(self) -> str: 157 | return "".join(self.lines) 158 | 159 | 160 | class NoopFilters: 161 | """Basic filter doing nothing. 162 | 163 | Filters are grouped in an object to simplify the definition of a filtering 164 | policy. By subclassing :class:`NoopFilters`, you can implement simple 165 | filtering or heavy parameterized filtering policy from this API. 166 | 167 | If a filter method returns True, the record processing stops and the 168 | record is dropped. 169 | 170 | .. automethod:: stage1 171 | .. automethod:: stage2 172 | .. automethod:: stage3 173 | 174 | """ 175 | 176 | def stage1(self, record: Record) -> None: 177 | """First stage filter. 178 | 179 | :param Record record: A new record. 180 | :returns: ``True`` if record must be dropped. 181 | 182 | ``record`` has only `prefix`, `severity` and `message_type` 183 | attributes. 184 | """ 185 | 186 | def stage2(self, record: Record) -> None: 187 | """Second stage filter. 188 | 189 | :param Record record: A new record. 190 | :returns: ``True`` if record must be dropped. 191 | 192 | ``record`` has attributes from stage 1 plus attributes from prefix 193 | analysis. See :class:`Record` for details. 194 | """ 195 | 196 | def stage3(self, record: Record) -> None: 197 | """Third stage filter. 198 | 199 | :param Record record: A new record. 200 | :returns: ``True`` if record must be dropped. 201 | 202 | ``record`` has attributes from stage 2 plus attributes from message 203 | analysis, depending on message type. 204 | """ 205 | 206 | 207 | class PrefixParser: 208 | """Extract record metadata from PostgreSQL log line prefix. 209 | 210 | .. automethod:: from_configuration 211 | """ 212 | 213 | # cf. 
214 | # https://www.postgresql.org/docs/current/static/runtime-config-logging.html#GUC-LOG-LINE-PREFIX 215 | 216 | _datetime_pat = r"\d{4}-[01]\d-[0-3]\d [012]\d:[0-6]\d:[0-6]\d" 217 | # Pattern map of Status information. 218 | _status_pat = dict( 219 | # Application name 220 | a=r"(?P\[unknown\]|\w+)?", 221 | # Session ID 222 | c=r"(?P\[unknown\]|[0-9a-f.]+)", 223 | # Database name 224 | d=r"(?P\[unknown\]|\w+)?", 225 | # SQLSTATE error code 226 | e=r"(?P\d+)", 227 | # Remote host name or IP address 228 | h=r"(?P\[local\]|\[unknown\]|[a-z0-9_-]+|[0-9.:]+)?", 229 | # Command tag: type of session's current command 230 | i=r"(?P\w+)", 231 | # Number of the log line for each session or process, starting at 1. 232 | l=r"(?P\d+)", # noqa 233 | # Time stamp with milliseconds 234 | m=r"(?P" + _datetime_pat + r".\d{3} [A-Z]{2,5})", 235 | # Time stamp with milliseconds (as a Unix epoch) 236 | n=r"(?P\d+\.\d+)", 237 | # Process ID 238 | p=r"(?P\d+)", 239 | # Remote host name or IP address, and remote port 240 | r=r"(?P\[local\]|\[unknown\]|[a-z0-9_-]+|[0-9.:]+\((?P\d+)\))?", # noqa 241 | # Process start time stamp 242 | s=r"(?P" + _datetime_pat + " [A-Z]{2,5})", 243 | # Time stamp without milliseconds 244 | t=r"(?P" + _datetime_pat + " [A-Z]{2,5})", 245 | # User name 246 | u=r"(?P\[unknown\]|\w+)?", 247 | # Virtual transaction ID (backendID/localXID) 248 | v=r"(?P\d+/\d+)", 249 | # Transaction ID (0 if none is assigned) 250 | x=r"(?P\d+)", 251 | ) 252 | # re to search for %… in log_line_prefix. 253 | _format_re = re.compile(r"%([" + "".join(_status_pat.keys()) + "])") 254 | # re to find %q separator in log_line_prefix. 255 | _q_re = re.compile(r"(? str: 270 | # Builds a pattern from each known fields. 
271 | segments = cls._format_re.split(prefix) 272 | for i, segment in enumerate(segments): 273 | if i % 2: 274 | segments[i] = cls._status_pat[segment] 275 | else: 276 | segments[i] = re.escape(segment) 277 | return "".join(segments) 278 | 279 | @classmethod 280 | def from_configuration(cls, log_line_prefix: str) -> PrefixParser: 281 | """Factory from log_line_prefix 282 | 283 | Parses log_line_prefix and build a prefix parser from this. 284 | 285 | :param log_line_prefix: ``log_line_prefix`` PostgreSQL setting. 286 | :return: A :class:`PrefixParser` instance. 287 | 288 | """ 289 | optional: str | None 290 | try: 291 | fixed, optional = cls._q_re.split(log_line_prefix) 292 | except ValueError: 293 | fixed, optional = log_line_prefix, None 294 | 295 | pattern = cls.mkpattern(fixed) 296 | if optional: 297 | pattern += r"(?:" + cls.mkpattern(optional) + ")?" 298 | return cls(re.compile(pattern), log_line_prefix) 299 | 300 | def __init__(self, re_: Pattern[str], prefix_fmt: str | None = None) -> None: 301 | self.re_ = re_ 302 | self.prefix_fmt = prefix_fmt 303 | 304 | def __repr__(self) -> str: 305 | return f"<{self.__class__.__name__} '{self.prefix_fmt}'>" 306 | 307 | def parse(self, prefix: str) -> MutableMapping[str, Any]: 308 | # Parses the prefix line according to the inner regular expression. If 309 | # prefix does not match, raises an UnknownData. 310 | 311 | match = self.re_.search(prefix) 312 | if not match: 313 | raise UnknownData([prefix]) 314 | fields = match.groupdict() 315 | 316 | self.cast_fields(fields) 317 | 318 | # Ensure remote_host is fed either by %h or %r. 319 | remote_host = fields.pop("remote_host_r", None) 320 | if remote_host: 321 | fields.setdefault("remote_host", remote_host) 322 | 323 | # Ensure timestamp field is fed either by %m or %t. 
324 | timestamp_ms = fields.pop("timestamp_ms", None) 325 | if timestamp_ms: 326 | fields.setdefault("timestamp", timestamp_ms) 327 | 328 | return fields 329 | 330 | @classmethod 331 | def cast_fields(cls, fields: MutableMapping[str, Any]) -> None: 332 | # In-place cast of values in fields dictionary. 333 | 334 | for k in fields: 335 | v = fields[k] 336 | if v is None: 337 | continue 338 | cast = cls._casts.get(k) 339 | if cast: 340 | fields[k] = cast(v) 341 | 342 | 343 | class Record: 344 | """Log record object. 345 | 346 | Record object stores record fields and implements the different parse 347 | stages. 348 | 349 | A record is primarily composed by a prefix, a severity and a message. 350 | Actually, severity is mixed with message type. For example, a HINT: message 351 | has the same severity as ``LOG:`` and is actually a continuation message 352 | (see csvlog output to compare). Thus we can determine easily message type 353 | as this stage. :mod:`pgtoolkit.log` does not rewrite message severity. 354 | 355 | Once prefix, severity and message are split, the parser analyze prefix 356 | according to ``log_line_prefix`` parameter. Prefix can give a lot of 357 | information for filtering, but costs some CPU cycles to process. 358 | 359 | Finally, the parser analyze the message to extract information such as 360 | statement, hint, duration, execution plan, etc. depending on the message 361 | type. 362 | 363 | These stages are separated so that marshalling can apply filter between 364 | each stage. 365 | 366 | .. automethod:: as_dict 367 | 368 | Each record field is accessible as an attribute : 369 | 370 | .. attribute:: prefix 371 | 372 | Raw prefix line. 373 | 374 | .. attribute:: severity 375 | 376 | One of ``DEBUG1`` to ``DEBUG5``, ``CONTEXT``, ``DETAIL``, ``ERROR``, 377 | etc. 378 | 379 | .. attribute:: message_type 380 | 381 | A string identifying message type. One of ``unknown``, ``duration``, 382 | ``connection``, ``analyze``, ``checkpoint``. 383 | 384 | .. 
attribute:: raw_lines 385 | 386 | A record can span multiple lines. This attribute keep a reference on 387 | raw record lines of the record. 388 | 389 | .. attribute:: message_lines 390 | 391 | Just like :attr:`raw_lines`, but the first line only include message, 392 | without prefix nor severity. 393 | 394 | The following attributes correspond to prefix fields. See `log_line_prefix 395 | documentation 396 | `_ 397 | for details. 398 | 399 | .. attribute:: application_name 400 | .. attribute:: command_tag 401 | .. attribute:: database 402 | .. attribute:: epoch 403 | 404 | :type: :class:`datetime.datetime` 405 | 406 | .. attribute:: error 407 | .. attribute:: line_num 408 | 409 | :type: :class:`int` 410 | 411 | .. attribute:: pid 412 | 413 | :type: :class:`int` 414 | 415 | .. attribute:: remote_host 416 | .. attribute:: remote_port 417 | 418 | :type: :class:`int` 419 | 420 | .. attribute:: session 421 | .. attribute:: start 422 | 423 | :type: :class:`datetime.datetime` 424 | 425 | .. attribute:: timestamp 426 | 427 | :type: :class:`datetime.datetime` 428 | 429 | .. attribute:: user 430 | .. attribute:: virtual_xid 431 | .. attribute:: xid 432 | 433 | :type: :class:`int` 434 | 435 | If the log lines miss a field, the record won't have the attribute. Use 436 | :func:`hasattr` to check whether a record have a specific attribute. 437 | """ 438 | 439 | __slots__ = ( 440 | "__dict__", 441 | "message_lines", 442 | "prefix", 443 | "raw_lines", 444 | ) 445 | 446 | # This actually mix severities and message types since they are in the same 447 | # field. 
448 | _severities = [ 449 | "CONTEXT", 450 | "DETAIL", 451 | "ERROR", 452 | "FATAL", 453 | "HINT", 454 | "INFO", 455 | "LOG", 456 | "NOTICE", 457 | "PANIC", 458 | "QUERY", 459 | "STATEMENT", 460 | "WARNING", 461 | ] 462 | _stage1_re = re.compile("(DEBUG[1-5]|" + "|".join(_severities) + "): ") 463 | 464 | _types_prefixes = { 465 | "duration: ": "duration", 466 | "connection ": "connection", 467 | "disconnection": "connection", 468 | "automatic analyze": "analyze", 469 | "checkpoint ": "checkpoint", 470 | } 471 | 472 | @classmethod 473 | def guess_type(cls, severity: str, message_start: str) -> str: 474 | # Guess message type from severity and the first line of the message. 475 | 476 | if severity in ("HINT", "STATEMENT"): 477 | return severity.lower() 478 | for prefix in cls._types_prefixes: 479 | if message_start.startswith(prefix): 480 | return cls._types_prefixes[prefix] 481 | return "unknown" 482 | 483 | @classmethod 484 | def parse_stage1(cls, lines: list[str]) -> Record: 485 | # Stage1: split prefix, severity and message. 
486 | try: 487 | prefix, severity, message0 = cls._stage1_re.split(lines[0], maxsplit=1) 488 | except ValueError: 489 | raise UnknownData(lines) 490 | 491 | return cls( 492 | prefix=prefix, 493 | severity=severity, 494 | message_type=cls.guess_type(severity, message0), 495 | message_lines=[message0] + lines[1:], 496 | raw_lines=lines, 497 | ) 498 | 499 | def __init__( 500 | self, 501 | prefix: str, 502 | severity: str, 503 | message_type: str = "unknown", 504 | message_lines: list[str] | None = None, 505 | raw_lines: list[str] | None = None, 506 | **fields: str, 507 | ) -> None: 508 | self.prefix = prefix 509 | self.severity = severity 510 | self.message_type = message_type 511 | self.message_lines = message_lines or [] 512 | self.raw_lines = raw_lines or [] 513 | self.__dict__.update(fields) 514 | 515 | def __repr__(self) -> str: 516 | return "<{} {}: {:.32}...>".format( 517 | self.__class__.__name__, 518 | self.severity, 519 | self.message_lines[0].replace("\n", ""), 520 | ) 521 | 522 | def parse_stage2(self, parse_prefix: Callable[[str], Mapping[str, Any]]) -> None: 523 | # Stage 2. Analyze prefix fields 524 | 525 | self.__dict__.update(parse_prefix(self.prefix)) 526 | 527 | def parse_stage3(self) -> None: 528 | # Stage 3. Analyze message lines. 529 | 530 | self.message = "".join( 531 | [line.lstrip("\t").rstrip("\n") for line in self.message_lines] 532 | ) 533 | 534 | def as_dict(self) -> dict[str, str | object | datetime]: 535 | """Returns record fields as a :class:`dict`.""" 536 | return {k: v for k, v in self.__dict__.items()} 537 | -------------------------------------------------------------------------------- /pgtoolkit/pgpass.py: -------------------------------------------------------------------------------- 1 | r""".. currentmodule:: pgtoolkit.pgpass 2 | 3 | This module provides support for `.pgpass` file format. Here are some 4 | highlights : 5 | 6 | - Supports ``:`` and ``\`` escape. 7 | - Sorts entry by precision (even if commented). 
8 | - Preserves comments order when sorting. 9 | 10 | See `The Password File 11 | `__ section 12 | in PostgreSQL documentation. 13 | 14 | .. autofunction:: parse 15 | .. autofunction:: edit 16 | .. autoclass:: PassEntry 17 | .. autoclass:: PassComment 18 | .. autoclass:: PassFile 19 | 20 | 21 | Editing a .pgpass file 22 | ---------------------- 23 | 24 | .. code:: python 25 | 26 | with open('.pgpass') as fo: 27 | pgpass = parse(fo) 28 | pgpass.lines.append(PassEntry(username='toto', password='confidential')) 29 | pgpass.sort() 30 | with open('.pgpass', 'w') as fo: 31 | pgpass.save(fo) 32 | 33 | Shorter version using the file directly in `parse`: 34 | 35 | .. code:: python 36 | 37 | pgpass = parse('.pgpass') 38 | pgpass.lines.append(PassEntry(username='toto', password='confidential')) 39 | pgpass.sort() 40 | pgpass.save() 41 | 42 | Alternatively, this can be done with the `edit` context manager: 43 | 44 | .. code:: python 45 | 46 | with edit('.pgpass') as pgpass: 47 | pgpass.lines.append((PassEntry(username='toto', password='confidential')) 48 | passfile.sort() 49 | 50 | 51 | Using as a script 52 | ----------------- 53 | 54 | You can call :mod:`pgtoolkit.pgpass` module as a CLI script. It accepts a file 55 | path as first argument, read it, validate it, sort it and output it in stdout. 56 | 57 | 58 | .. code:: console 59 | 60 | $ python -m pgtoolkit.pgpass ~/.pgpass 61 | more:5432:precise:entry:0revea\\ed 62 | #disabled:5432:*:entry:0secret 63 | 64 | # Multiline 65 | # comment. 
66 | other:5432:*:username:0unveiled 67 | *:*:*:postgres:c0nfident\:el 68 | 69 | """ # noqa 70 | 71 | from __future__ import annotations 72 | 73 | import os 74 | import sys 75 | import warnings 76 | from collections.abc import Callable, Iterable, Iterator 77 | from contextlib import contextmanager 78 | from pathlib import Path 79 | from typing import IO 80 | 81 | from ._helpers import open_or_stdin 82 | from .errors import ParseError 83 | 84 | 85 | def unescape(s: str, delim: str) -> str: 86 | return s.replace("\\" + delim, delim).replace("\\\\", "\\") 87 | 88 | 89 | def escapedsplit(s: str, delim: str) -> Iterator[str]: 90 | if len(delim) != 1: 91 | raise ValueError("Invalid delimiter: " + delim) 92 | 93 | ln = len(s) 94 | escaped = False 95 | i = 0 96 | j = 0 97 | 98 | while j < ln: 99 | if s[j] == "\\": 100 | escaped = not escaped 101 | elif s[j] == delim: 102 | if not escaped: 103 | yield unescape(s[i:j], delim) 104 | i = j + 1 105 | escaped = False 106 | j += 1 107 | yield unescape(s[i:j], delim) 108 | 109 | 110 | class PassComment(str): 111 | """A .pgpass comment, including spaces and ``#``. 112 | 113 | It's a child of ``str``. 114 | 115 | >>> comm = PassComment("# my comment") 116 | >>> comm.comment 117 | 'my comment' 118 | 119 | .. automethod:: matches 120 | 121 | .. attribute:: comment 122 | 123 | The actual message of the comment. Surrounding whitespaces stripped. 
class PassEntry:
    """Holds a single .pgpass entry.

    .. automethod:: parse
    .. automethod:: matches

    .. attribute:: hostname

        Server hostname, the first field.

    .. attribute:: port

        Server port, the second field.

    .. attribute:: database

        Database, the third field.

    .. attribute:: username

        Username, the fourth field.

    .. attribute:: password

        Password, the fifth field.

    :class:`PassEntry` objects are sortable: an entry sorts lower than
    another when it is more specific, i.e. has fewer ``*`` wildcards.

    """

    @classmethod
    def parse(cls, line: str) -> PassEntry:
        """Parse a single serialized .pgpass line.

        :param line: string containing a serialized .pgpass entry.
        :return: :class:`PassEntry` object holding entry data.
        :raises ValueError: on invalid line.
        """
        values = list(escapedsplit(line.strip(), ":"))
        if len(values) != 5:
            raise ValueError("Invalid line.")
        hostname, port, database, username, password = values
        # Port stays a string only for the ``*`` wildcard.
        port_value: int | str = port if port == "*" else int(port)
        return cls(hostname, port_value, database, username, password)

    def __init__(
        self,
        hostname: str,
        port: int | str,
        database: str,
        username: str,
        password: str,
    ) -> None:
        self.hostname = hostname
        self.port = port
        self.database = database
        self.username = username
        self.password = password

    def __eq__(self, other: object) -> bool:
        # A commented-out entry compares equal to the entry it comments.
        if isinstance(other, PassComment):
            try:
                other = other.entry
            except ValueError:
                return False
        if isinstance(other, PassEntry):
            return self.as_tuple()[:-1] == other.as_tuple()[:-1]
        return NotImplemented

    def __hash__(self) -> int:
        # The password is deliberately excluded: identity is the
        # connection coordinates, not the secret.
        return hash(self.as_tuple()[:-1])

    def __lt__(self, other: PassComment | PassEntry) -> bool:
        if isinstance(other, PassComment):
            try:
                other = other.entry
            except ValueError:
                return False
        if isinstance(other, PassEntry):
            return self.sort_key() < other.sort_key()
        return NotImplemented

    def __repr__(self) -> str:
        return (
            f"<{self.__class__.__name__} "
            f"{self.username}@{self.hostname}:{self.port}/{self.database}>"
        )

    def __str__(self) -> str:
        def escape(field: object) -> str:
            return str(field).replace("\\", "\\\\").replace(":", "\\:")

        return ":".join(escape(field) for field in self.as_tuple())

    def as_tuple(self) -> tuple[str, str, str, str, str]:
        """Return the five fields, all serialized as strings."""
        return (
            self.hostname,
            str(self.port),
            self.database,
            self.username,
            self.password,
        )

    def sort_key(self) -> tuple[int, str, int | str, str, str]:
        """Key for sorting: precision first, then wildcard-last fields."""
        head = self.as_tuple()[:-1]
        # Precision is simply the number of wildcards; fewer wildcards
        # means a more specific entry, which must sort first.
        wildcards = sum(1 for field in head if field == "*")
        # chr(0xFF) pushes wildcard fields after any literal value.
        return (wildcards, *("\xff" if field == "*" else field for field in head))  # type: ignore[return-value]

    def matches(self, **attrs: int | str) -> bool:
        """Tell whether this entry matches the provided attributes.

        :param attrs: keyword/value pairs corresponding to one or more
            PassEntry attributes (i.e. hostname, port, etc...)
        :raises AttributeError: on an unknown attribute name.
        """
        known = self.__dict__.keys()
        for name in attrs:
            if name not in known:
                raise AttributeError("%s is not a valid attribute" % name)
        return all(getattr(self, name) == value for name, value in attrs.items())
class PassFile:
    """Holds .pgpass file entries and comments.

    .. automethod:: parse
    .. automethod:: __iter__
    .. automethod:: sort
    .. automethod:: save
    .. automethod:: remove

    .. attribute:: lines

        List of either :class:`PassEntry` or :class:`PassComment`. You can
        add lines by appending :class:`PassEntry` or :class:`PassComment`
        instances to this list.

    .. attribute:: path

        Path to a file. Is automatically set when calling :meth:`parse` with a
        path to a file. :meth:`save` will write to this file if set.

    """

    lines: list[PassComment | PassEntry]
    path: str | None = None

    def __init__(
        self,
        entries: list[PassComment | PassEntry] | None = None,
        *,
        path: str | None = None,
    ) -> None:
        """PassFile constructor.

        :param entries: A list of PassEntry or PassComment. Optional.
        :param path: Path of the file on disk. Optional.
        :raises ValueError: when *entries* is neither ``None`` nor a list.
        """
        if entries and not isinstance(entries, list):
            raise ValueError("%s should be a list" % entries)
        self.lines = entries or []
        self.path = path

    def __iter__(self) -> Iterator[PassEntry]:
        """Iterate over entries.

        Yield :class:`PassEntry` instances from the parsed file, ignoring
        comments.
        """
        for line in self.lines:
            if isinstance(line, PassEntry):
                yield line

    def parse(self, fo: Iterable[str]) -> None:
        """Parse lines.

        :param fo: A line iterator such as a file-like object.
        :raises ParseError: if a bad line is found.
        """
        entry: PassComment | PassEntry
        for lineno, line in enumerate(fo, 1):
            stripped = line.lstrip()
            if not stripped or stripped.startswith("#"):
                entry = PassComment(line.replace(os.linesep, ""))
            else:
                try:
                    entry = PassEntry.parse(line)
                except Exception as e:
                    raise ParseError(lineno, line, str(e))
            self.lines.append(entry)

    def sort(self) -> None:
        """Sort entries, preserving comments.

        libpq uses the first entry of .pgpass matching the connection
        information, so less specific entries must come last: entries are
        sorted by ascending precision (see :meth:`PassEntry.sort_key`).

        Comments are assumed to belong with the entry **below** them, and
        each comment block moves together with that entry. Commented
        entries are sorted like entries, not like comments. Comments after
        the last entry are kept at the end of the file.
        """
        entries: list[tuple[PassComment | PassEntry, list[PassComment]]] = []
        comments: list[PassComment] = []
        for line in self.lines:
            if isinstance(line, PassComment):
                try:
                    line.entry
                except ValueError:
                    # Plain comment: attach it to the next entry below.
                    comments.append(line)
                    continue
                # Otherwise it's a commented entry: sort it like an entry.
            entries.append((line, comments))
            comments = []

        self.lines[:] = []
        entries.sort()
        for entry, block in entries:
            self.lines.extend(block)
            self.lines.append(entry)
        # Bug fix: comments trailing the last entry used to be silently
        # dropped by sort(); keep them at the end instead. This also covers
        # the comments-only file case.
        self.lines.extend(comments)

    def save(self, fo: IO[str] | None = None) -> None:
        """Save entries and comments in a file.

        :param fo: a file-like object. Is not required if :attr:`path` is set.
        :raises ValueError: when neither *fo* nor :attr:`path` is available.
        """

        def _write(stream: IO[str], lines: Iterable[object]) -> None:
            for line in lines:
                stream.write(str(line) + os.linesep)

        if fo:
            _write(fo, self.lines)
        elif self.path:
            fpath = Path(self.path)
            if not fpath.exists():
                if not self.lines:
                    # Nothing to write; do not create an empty file.
                    return
                # Create with owner-only permissions: libpq ignores
                # password files readable by group or others.
                fpath.touch(mode=0o600)
            with open(self.path, "w") as stream:
                _write(stream, self.lines)
        else:
            raise ValueError("No file-like object nor path provided")

    def remove(
        self,
        filter: Callable[[PassComment | PassEntry | str], bool] | None = None,
        **attrs: int | str,
    ) -> None:
        """Remove entries matching the provided attributes.

        One can for example remove all entries for which port is 5433.

        Note: commented entries matching will also be removed.

        :param filter: a function to be used as filter. It is passed the line
            to test against. If it returns True, the line is removed. It is
            kept otherwise.
        :param attrs: keyword/values pairs correspond to one or more
            PassEntry attributes (ie. hostname, port, etc...)
        :raises ValueError: when neither *filter* nor *attrs* is given.

        Usage examples:

        .. code:: python

            pgpass.remove(port=5432)
            pgpass.remove(filter=lambda r: r.port != 5432)
        """
        if filter is not None and len(attrs):
            warnings.warn("Only filter will be taken into account")

        # Attributes list to look for must not be empty.
        if filter is None and not attrs:
            raise ValueError("Attributes dict cannot be empty")

        if filter is not None:

            def discard(line: PassComment | PassEntry) -> bool:
                # Commented entries are passed to the filter as entries;
                # plain comments never match.
                assert filter is not None
                if isinstance(line, PassComment):
                    try:
                        return filter(line.entry)
                    except ValueError:
                        return False
                return filter(line)

        else:

            def discard(line: PassComment | PassEntry) -> bool:
                return line.matches(**attrs)

        self.lines = [line for line in self.lines if not discard(line)]
497 | """ 498 | fpath = Path(fpath).expanduser() 499 | if fpath.exists(): 500 | passfile = parse(fpath) 501 | else: 502 | passfile = PassFile(path=str(fpath)) 503 | yield passfile 504 | passfile.save() 505 | 506 | 507 | if __name__ == "__main__": # pragma: nocover 508 | argv = sys.argv[1:] + ["-"] 509 | try: 510 | with open_or_stdin(argv[0]) as fo: 511 | pgpass = parse(fo) 512 | pgpass.sort() 513 | pgpass.save(sys.stdout) 514 | except Exception as e: 515 | print(str(e), file=sys.stderr) 516 | exit(1) 517 | -------------------------------------------------------------------------------- /pgtoolkit/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dalibo/pgtoolkit/1996d702a46f39e7020005a638dad8b7200b209a/pgtoolkit/py.typed -------------------------------------------------------------------------------- /pgtoolkit/service.py: -------------------------------------------------------------------------------- 1 | """.. currentmodule:: pgtoolkit.service 2 | 3 | This module supports reading, validating, editing and rendering ``pg_service`` 4 | file. See `The Connection Service File 5 | `__ in 6 | PostgreSQL documentation. 7 | 8 | API Reference 9 | ------------- 10 | 11 | The main entrypoint of the API is the :func:`parse` function. :func:`find` 12 | function may be useful if you need to search for ``pg_service.conf`` files in 13 | regular locations. 14 | 15 | .. autofunction:: find 16 | .. autofunction:: parse 17 | .. autoclass:: Service 18 | .. autoclass:: ServiceFile 19 | 20 | 21 | Edit a service file 22 | ------------------- 23 | 24 | .. 
code:: python 25 | 26 | from pgtoolkit.service import parse, Service 27 | 28 | servicefilename = 'my_service.conf' 29 | with open(servicefile) as fo: 30 | servicefile = parse(fo, source=servicefilename) 31 | 32 | myservice = servicefile['myservice'] 33 | myservice.host = 'newhost' 34 | # Update service file 35 | servicefile.add(myservice) 36 | 37 | newservice = Service(name='newservice', host='otherhost') 38 | servicefile.add(newservice) 39 | 40 | with open(servicefile, 'w') as fo: 41 | servicefile.save(fo) 42 | 43 | Shorter version using the file directly in `parse`: 44 | 45 | .. code:: python 46 | 47 | servicefile = parse('my_service.conf') 48 | [...] 49 | servicefile.save() 50 | 51 | 52 | Load a service file to connect with psycopg2 53 | -------------------------------------------- 54 | 55 | Actually, psycopg2 already support pgservice file. This is just a showcase. 56 | 57 | .. code:: python 58 | 59 | from pgtoolkit.service import find, parse 60 | from psycopg2 import connect 61 | 62 | servicefilename = find() 63 | with open(servicefile) as fo: 64 | servicefile = parse(fo, source=servicefilename) 65 | connection = connect(**servicefile['myservice']) 66 | 67 | 68 | Using as a script 69 | ----------------- 70 | 71 | :mod:`pgtoolkit.service` is usable as a CLI script. It accepts a service file 72 | path as first argument, read it, validate it and re-render it, losing 73 | comments. 74 | 75 | :class:`ServiceFile` is less strict than `libpq`. Spaces are accepted around 76 | `=`. The output conform strictly to `libpq` parser. 77 | 78 | .. 
code:: console 79 | 80 | $ python -m pgtoolkit.service data/pg_service.conf 81 | [mydb] 82 | host=somehost 83 | port=5433 84 | user=admin 85 | 86 | [my ini-style] 87 | host=otherhost 88 | 89 | """ 90 | 91 | from __future__ import annotations 92 | 93 | import os 94 | import sys 95 | from collections.abc import Iterable, MutableMapping 96 | from configparser import ConfigParser 97 | from typing import IO, Union 98 | 99 | from ._helpers import open_or_stdin 100 | 101 | Parameter = Union[str, int] 102 | 103 | 104 | class Service(dict[str, Parameter]): 105 | """Service definition. 106 | 107 | The :class:`Service` class represents a single service definition in a 108 | Service file. It’s actually a dictionary of its own parameters. 109 | 110 | The ``name`` attributes is mapped to the section name of the service in the 111 | Service file. 112 | 113 | Each parameters can be accessed either as a dictionary entry or as an 114 | attributes. 115 | 116 | >>> myservice = Service('myservice', {'dbname': 'mydb'}, host='myhost') 117 | >>> myservice.name 118 | 'myservice' 119 | >>> myservice.dbname 120 | 'mydb' 121 | >>> myservice['dbname'] 122 | 'mydb' 123 | >>> myservice.user = 'myuser' 124 | >>> list(sorted(myservice.items())) 125 | [('dbname', 'mydb'), ('host', 'myhost'), ('user', 'myuser')] 126 | 127 | """ # noqa 128 | 129 | name: str 130 | 131 | def __init__( 132 | self, 133 | name: str, 134 | parameters: dict[str, Parameter] | None = None, 135 | **extra: Parameter, 136 | ) -> None: 137 | super().__init__() 138 | super().__setattr__("name", name) 139 | self.update(parameters or {}) 140 | self.update(extra) 141 | 142 | def __repr__(self) -> str: 143 | return f"<{self.__class__.__name__} {self.name}>" 144 | 145 | def __getattr__(self, name: str) -> Parameter: 146 | return self[name] 147 | 148 | def __setattr__(self, name: str, value: Parameter) -> None: 149 | self[name] = value 150 | 151 | 152 | class ServiceFile: 153 | """Service file representation, parsing and rendering. 
class ServiceFile:
    """Service file representation, parsing and rendering.

    :class:`ServiceFile` is subscriptable: access a service with
    ``servicefile['servicename']``.

    .. automethod:: add
    .. automethod:: parse
    .. automethod:: save

    .. attribute:: path

        Path to a file. Is automatically set when calling :meth:`parse` with a
        path to a file. :meth:`save` will write to this file if set.
    """

    path: str | None

    # Parameters converted from string on read; everything else stays str.
    _CONVERTERS = {
        "port": int,
    }

    def __init__(self) -> None:
        self.path = None
        # Accept only '#' comments and '=' delimiters, matching libpq.
        self.config = ConfigParser(
            comment_prefixes=("#",),
            delimiters=("=",),
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}>"

    def __getitem__(self, key: str) -> Service:
        converted = {}
        for name, value in self.config.items(key):
            converted[name] = self._CONVERTERS.get(name, str)(value)
        return Service(key, converted)

    def __len__(self) -> int:
        return len(self.config.sections())

    def add(self, service: Service) -> None:
        """Add (or replace) a :class:`Service` in the service file."""
        section = service.name
        # Drop any previous definition so the new one fully replaces it.
        self.config.remove_section(section)
        self.config.add_section(section)
        for name, value in service.items():
            self.config.set(section, name, str(value))

    def parse(self, fo: Iterable[str], source: str | None = None) -> None:
        """Add services from a service file.

        This method is strictly the same as :func:`parse`. It's the method
        counterpart.
        """
        self.config.read_file(fo, source=source)

    def save(self, fo: IO[str] | None = None) -> None:
        """Write services to the ``fo`` file-like object.

        :param fo: a file-like object. Is not required if :attr:`path` is set.
        :raises ValueError: when neither *fo* nor :attr:`path` is available.

        .. note:: Comments are not preserved.
        """
        if fo:
            # libpq does not accept spaces around '=' on output.
            self.config.write(fo, space_around_delimiters=False)
        elif self.path:
            with open(self.path, "w") as stream:
                self.config.write(stream, space_around_delimiters=False)
        else:
            raise ValueError("No file-like object nor path provided")
def parse(file: str | Iterable[str], source: str | None = None) -> ServiceFile:
    """Parse a service file.

    :param file: a file-object as returned by open or a string corresponding to
        the path to a file to open and parse.
    :param source: Name of the source, used in error messages.
    :rtype: A ``ServiceFile`` object.

    Actually it only requires as ``fo`` an iterable object yielding each line
    of the file. You can provide ``source`` to have more precise error
    messages.

    .. warning::

        pgtoolkit is less strict than `libpq`. `libpq` does not accept spaces
        around equals. pgtoolkit accepts spaces but does not write them.

    """
    if isinstance(file, str):
        # A path was given: open it ourselves and remember it so that
        # save() can write back to the same file.
        with open(file) as fo:
            servicefile = parse(fo, source=source)
        servicefile.path = file
        return servicefile
    servicefile = ServiceFile()
    servicefile.parse(file, source=source)
    return servicefile
"pytest-asyncio", 46 | "pytest-cov", 47 | "pytest-mock", 48 | "psycopg2-binary", 49 | ] 50 | doc = [ 51 | "sphinx", 52 | "sphinx-autobuild", 53 | "sphinx_rtd_theme", 54 | ] 55 | 56 | [project.urls] 57 | Repository = "https://github.com/dalibo/pgtoolkit" 58 | Documentation = "https://pgtoolkit.readthedocs.io/" 59 | 60 | [tool.isort] 61 | profile = "black" 62 | multi_line_output = 3 63 | 64 | [tool.setuptools.packages.find] 65 | where = ["."] 66 | 67 | [tool.setuptools.package-data] 68 | pgtoolkit = ["py.typed"] 69 | 70 | [tool.setuptools_scm] 71 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = -vvv --strict-markers --showlocals --doctest-modules --ignore pgtoolkit/ctl.py 3 | asyncio_mode = strict 4 | asyncio_default_fixture_loop_scope = function 5 | filterwarnings = 6 | error 7 | -------------------------------------------------------------------------------- /rpm/Makefile: -------------------------------------------------------------------------------- 1 | TOPDIR=$(shell readlink -e ../dist) 2 | YUM_LABS=../../yum-labs 3 | 4 | all: 5 | $(MAKE) -sC $(YUM_LABS) clean 6 | $(MAKE) build-centos7 build-centos8 7 | 8 | build-centos%: 9 | docker-compose run --rm centos$* 10 | mkdir -p $(YUM_LABS)/rpms/CentOS$*-x86_64 11 | cp -f $(shell readlink -e $(TOPDIR)/last_build.rpm) $(YUM_LABS)/rpms/CentOS$*-x86_64/ 12 | 13 | push: 14 | $(MAKE) -sC $(YUM_LABS) push createrepos clean 15 | -------------------------------------------------------------------------------- /rpm/README.md: -------------------------------------------------------------------------------- 1 | # Building RPM Package 2 | 3 | With docker-compose, you can build a RPM package for pgtoolkit in a few steps. 
4 | 5 | ``` console 6 | $ make all push 7 | ``` 8 | 9 | The spec file is based on [Devrim Günduz](https://twitter.com/DevrimGunduz) 10 | packaging for pgspecial. 11 | 12 | The version in `rpm/python-pgtoolkit.spec` file may need to be updated. 13 | -------------------------------------------------------------------------------- /rpm/build: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eux 2 | 3 | teardown() { 4 | exit_code=$? 5 | # If not on CI, wait for user interrupt on exit 6 | if [ -z "${CI-}" -a $exit_code -gt 0 -a $$ = 1 ] ; then 7 | tail -f /dev/null 8 | fi 9 | } 10 | 11 | trap teardown EXIT TERM 12 | 13 | top_srcdir=$(readlink -m $0/../..) 14 | cd $top_srcdir 15 | test -f setup.py 16 | 17 | yum_install() { 18 | local packages=$* 19 | sudo yum install -y $packages 20 | rpm --query --queryformat= $packages 21 | } 22 | 23 | # Fasten yum by disabling updates repository 24 | if [ -f /etc/yum.repos.d/CentOS-Base.repo ] ; then 25 | sudo sed -i '/^\[updates\]/,/^gpgkey=/d' /etc/yum.repos.d/CentOS-Base.repo 26 | fi 27 | 28 | # Purge previous installation 29 | if rpm --query --queryformat= python3-pgtoolkit ; then 30 | sudo yum remove -y python3-pgtoolkit 31 | fi 32 | 33 | rm -rf build/bdist*/rpm 34 | 35 | yum_install python39 python39-setuptools 36 | 37 | # Set default python3 to python3.9 38 | update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1 39 | 40 | # Build it 41 | 42 | sudo sed -i 's/\.centos//' /etc/rpm/macros.dist 43 | 44 | rpmbuild -bb \ 45 | --undefine _disable_source_fetch \ 46 | --define "_topdir ${top_srcdir}/dist" \ 47 | --define "_sourcedir ${top_srcdir}/dist" \ 48 | --define "_rpmdir ${top_srcdir}/dist" \ 49 | rpm/python-pgtoolkit.spec 50 | version=$(sed -n '/^Version:/{s,.*:\t,,g; p; q}' rpm/python-pgtoolkit.spec) 51 | rpm=dist/noarch/python3-pgtoolkit*${version}*$(rpm --eval '%dist').noarch.rpm 52 | ln -fs noarch/$(basename $rpm) dist/last_build.rpm 53 | 54 | chown -R $(id 
-nu):$(id -ng) dist 55 | 56 | # Test it 57 | sudo yum install -y $rpm 58 | cd / 59 | python3 -c 'import pgtoolkit' 60 | -------------------------------------------------------------------------------- /rpm/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | 3 | services: 4 | centos7: 5 | image: rpmbuild/centos7 6 | volumes: 7 | - ../:/srv 8 | command: /srv/rpm/build 9 | 10 | centos6: 11 | image: rpmbuild/centos6 12 | volumes: 13 | - ../:/srv 14 | command: /srv/rpm/build 15 | 16 | centos8: 17 | image: dalibo/buildpack-pkg:rockylinux8 18 | volumes: 19 | - ../:/srv 20 | command: /srv/rpm/build 21 | -------------------------------------------------------------------------------- /rpm/python-pgtoolkit.spec: -------------------------------------------------------------------------------- 1 | %global __ospython3 %{_bindir}/python3 2 | %{expand: %%global py3ver %(echo `%{__ospython3} -c "import sys; sys.stdout.write(sys.version[:3])"`)} 3 | %global python3_sitelib %(%{__ospython3} -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())") 4 | 5 | %global sname pgtoolkit 6 | %global srcname pgtoolkit 7 | 8 | Name: python3-%{sname} 9 | # Must point to a released version on PyPI. 10 | Version: 0.24.1 11 | Release: 1%{?dist} 12 | Epoch: 1 13 | Summary: Manage Postgres cluster files from Python 14 | 15 | License: PostgreSQL 16 | URL: https://pypi.org/project/pgtoolkit/ 17 | Source0: https://files.pythonhosted.org/packages/source/%(n=%{srcname}; echo ${n:0:1})/%{srcname}/%{srcname}-%{version}.tar.gz 18 | BuildArch: noarch 19 | BuildRequires: python3-setuptools 20 | 21 | %description 22 | pgtoolkit provides implementations to manage various file formats in Postgres 23 | cluster or from libpq. Including pg_hba.conf, logs, .pgpass and pg_service.conf. 
24 | 25 | %prep 26 | %setup -q -n %{srcname}-%{version} 27 | 28 | %build 29 | CFLAGS="%{optflags}" %{__ospython3} setup.py build 30 | 31 | %install 32 | %{__ospython3} setup.py install --skip-build --root %{buildroot} 33 | 34 | 35 | %files 36 | %doc README.rst 37 | %{python3_sitelib}/%{sname}-%{version}-py%{py3ver}.egg-info 38 | %dir %{python3_sitelib}/%{sname} 39 | %{python3_sitelib}/%{sname}/* 40 | 41 | %changelog 42 | * Tue Jul 28 2020 Denis Laxalde - 1:0.8.0-1 43 | - Only build the Python3 version. 44 | * Tue Aug 28 2018 Étienne BERSAC - 1:0.0.1b1-1 45 | - Initial packaging. 46 | -------------------------------------------------------------------------------- /scripts/profile-log: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import cProfile 4 | import pstats 5 | 6 | from pgtoolkit.log import parse # noqa 7 | 8 | 9 | def generate_lines(lines, count=20000): 10 | while count > 0: 11 | yield from iter(lines) 12 | count -= len(lines) 13 | 14 | 15 | def main(): 16 | log_line_prefix = '%m [%p]: [%l-1] app=%a,db=%d%q,client=%h,user=%u ' # noqa 17 | filename = 'tests/data/postgresql.log' 18 | with open(filename) as fo: 19 | lines = fo.readlines() 20 | lines = generate_lines(lines) 21 | 22 | cProfile.runctx( 23 | 'list(parse(lines, log_line_prefix))', 24 | globals=globals(), 25 | locals=locals(), 26 | filename='my-log-stats', 27 | ) 28 | p = pstats.Stats('my-log-stats') 29 | p.strip_dirs().sort_stats('cumtime').print_stats() 30 | 31 | 32 | if __name__ == '__main__': 33 | main() 34 | -------------------------------------------------------------------------------- /tests/data/conf.d/.hidden.conf: -------------------------------------------------------------------------------- 1 | # invalid values, file should not be parsed (from inclusion) 2 | p o r t = 666 3 | listen_addresses = '0.0. 
4 | -------------------------------------------------------------------------------- /tests/data/conf.d/.includeme: -------------------------------------------------------------------------------- 1 | bonsoir = on 2 | -------------------------------------------------------------------------------- /tests/data/conf.d/listen.conf: -------------------------------------------------------------------------------- 1 | listen_addresses = '1.2.3.4' 2 | -------------------------------------------------------------------------------- /tests/data/conf.d/with-include.conf: -------------------------------------------------------------------------------- 1 | include = '.includeme' 2 | -------------------------------------------------------------------------------- /tests/data/pg_hba.conf: -------------------------------------------------------------------------------- 1 | # CAUTION: Configuring the system for local "trust" authentication 2 | # allows any local user to connect as any PostgreSQL user, including 3 | # the database superuser. If you do not trust all your local users, 4 | # use another authentication method. 5 | 6 | 7 | # TYPE DATABASE USER ADDRESS METHOD 8 | 9 | # "local" is for Unix domain socket connections only 10 | local all all trust 11 | # IPv4 local connections: 12 | host all all 127.0.0.1/32 ident map=omicron 13 | host all toto,tata,iqw616,pyy0799 all ldap ldapserver=svcldap.grandmars.fr ldapbasedn="ou=utilisateurs,o=annuaire" ldapsearchattribute="cn" ldapbinddn="cn=dev,ou=services,o=annuaire" ldapbindpasswd="Conf1tdanc!el" ldapport=389 14 | # IPv6 local connections: 15 | host all all ::1/128 trust # End of line comment 16 | # Allow replication connections from localhost, by a user with the 17 | # replication privilege. 18 | local replication all trust 19 | host replication all 127.0.0.1/32 trust 20 | host replication all ::1/128 trust 21 | 22 | host all all all trust 23 | 24 | # Accentué. 
25 | -------------------------------------------------------------------------------- /tests/data/pg_hba_bad.conf: -------------------------------------------------------------------------------- 1 | local all all 127.0.0.1/32 trust 2 | -------------------------------------------------------------------------------- /tests/data/pg_service.conf: -------------------------------------------------------------------------------- 1 | # From https://www.postgresql.org/docs/current/static/libpq-pgservice.html 2 | # comment 3 | [mydb] 4 | host=somehost 5 | port=5433 6 | user=admin 7 | 8 | [my ini-style] 9 | host = otherhost 10 | -------------------------------------------------------------------------------- /tests/data/pg_service_bad.conf: -------------------------------------------------------------------------------- 1 | [badsection 2 | -------------------------------------------------------------------------------- /tests/data/pgpass: -------------------------------------------------------------------------------- 1 | *:*:*:postgres:c0nfident\:el 2 | # Comment accentué 3 | 4 | more:5432:precise:entry:0revea\\ed 5 | 6 | # Multi 7 | # Line 8 | # Ordered 9 | # Comment. 
10 | other:5432:*:username:0unveiled 11 | #disabled:5432:*:entry:0secret 12 | -------------------------------------------------------------------------------- /tests/data/pgpass_bad: -------------------------------------------------------------------------------- 1 | bad:line 2 | -------------------------------------------------------------------------------- /tests/data/postgres-my-my.conf: -------------------------------------------------------------------------------- 1 | log_line_prefix = '%m %q@%d' 2 | mymy = true 3 | -------------------------------------------------------------------------------- /tests/data/postgres-my.conf: -------------------------------------------------------------------------------- 1 | cluster_name = 'pgtoolkit' 2 | authentication_timeout = 2min 3 | include_if_exists = 'missing.conf' 4 | include_if_exists = 'postgres-my-my.conf' 5 | mymy = false 6 | mymymy = false 7 | include = 'postgres-mymymy.conf' 8 | my = true 9 | -------------------------------------------------------------------------------- /tests/data/postgres-mymymy.conf: -------------------------------------------------------------------------------- 1 | mymymy = true 2 | -------------------------------------------------------------------------------- /tests/data/postgres.conf: -------------------------------------------------------------------------------- 1 | include_dir = 'conf.d' 2 | #include_dir = 'conf.11.d' 3 | include = 'postgres-my.conf' 4 | #------------------------------------------------------------------------------ 5 | # CONNECTIONS AND AUTHENTICATION 6 | #------------------------------------------------------------------------------ 7 | # - Connection Settings - 8 | listen_addresses = '*' # comma-separated list of addresses; 9 | # defaults to 'localhost'; use '*' for all 10 | # (change requires restart) 11 | port = 5432 12 | max_connections = 100 # (change requires restart) 13 | #superuser_reserved_connections = 3 # (change requires restart) 14 | 
#unix_socket_directories = '/var/run/postgresql, /tmp' # comma-separated list of directories 15 | # (change requires restart) 16 | #unix_socket_group = '' # (change requires restart) 17 | unix_socket_permissions = 0777 # begin with 0 to use octal notation 18 | # (change requires restart) 19 | bonjour = off # advertise server via Bonjour 20 | # (change requires restart) 21 | # - Security and Authentication - 22 | #authentication_timeout = 1min # 1s-600s 23 | ssl = on 24 | #------------------------------------------------------------------------------ 25 | # RESOURCE USAGE (except WAL) 26 | #------------------------------------------------------------------------------ 27 | # - Memory - 28 | shared_buffers = 248MB 29 | # (change requires restart) 30 | autovacuum_work_mem = -1 # min 1MB, or -1 to use maintenance_work_mem 31 | shared_preload_libraries = 'pg_stat_statements' 32 | #------------------------------------------------------------------------------ 33 | # WRITE AHEAD LOG 34 | #------------------------------------------------------------------------------ 35 | # - Settings - 36 | wal_level = hot_standby 37 | checkpoint_completion_target = 0.9 38 | #------------------------------------------------------------------------------ 39 | # ERROR REPORTING AND LOGGING 40 | #------------------------------------------------------------------------------ 41 | log_rotation_age = 1d # Automatic rotation of logfiles will 42 | # happen after that time. 0 disables. 
43 | #------------------------------------------------------------------------------ 44 | # CUSTOMIZED OPTIONS 45 | #------------------------------------------------------------------------------ 46 | # Add settings for extensions here 47 | pg_stat_statements.max = 10000 48 | pg_stat_statements.track = all 49 | 50 | some_guc = '2.30' 51 | -------------------------------------------------------------------------------- /tests/datatests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eux 2 | 3 | python -m pgtoolkit.hba tests/data/pg_hba.conf 4 | ! (python -m pgtoolkit.hba tests/data/pg_hba_bad.conf && exit 1) 5 | 6 | python -m pgtoolkit.pgpass tests/data/pgpass 7 | ! (python -m pgtoolkit.pgpass tests/data/pgpass_bad && exit 1) 8 | 9 | python -m pgtoolkit.service tests/data/pg_service.conf 10 | ! (python -m pgtoolkit.service tests/data/pg_service_bad.conf && exit 1) 11 | 12 | logscript=pgtoolkit.log 13 | python -m $logscript '%m [%p]: [%l-1] app=%a,db=%d%q,client=%h,user=%u ' tests/data/postgresql.log 14 | scripts/profile-log 15 | 16 | python -m pgtoolkit.conf tests/data/postgres.conf 17 | -------------------------------------------------------------------------------- /tests/test_conf.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from datetime import timedelta 3 | from io import StringIO 4 | from textwrap import dedent 5 | 6 | import pytest 7 | 8 | 9 | def test_parse_value(): 10 | from pgtoolkit.conf import parse_value 11 | 12 | # Booleans 13 | assert parse_value("on") is True 14 | assert parse_value("off") is False 15 | assert parse_value("true") is True 16 | assert parse_value("false") is False 17 | assert parse_value("yes") is True 18 | assert parse_value("'no'") is False 19 | assert parse_value("On") is True 20 | assert parse_value("TRUE") is True 21 | assert parse_value("fAlSe") is False 22 | 23 | # Numbers 24 | assert 10 == parse_value("10") 25 | 
assert "010" == parse_value("010") 26 | assert "010" == parse_value("'010'") 27 | assert 1.4 == parse_value("1.4") 28 | assert -2 == parse_value("-2") 29 | assert 0.2 == parse_value("0.2") 30 | assert 0 == parse_value("0") 31 | # Numbers, quoted 32 | assert "0" == parse_value("'0'") 33 | assert "2.3" == parse_value("'2.3'") 34 | 35 | # Strings 36 | assert "/a/path/to/file.conf" == parse_value(r"/a/path/to/file.conf") 37 | assert "0755.log" == parse_value(r"0755.log") 38 | assert "file_ending_with_B" == parse_value(r"file_ending_with_B") 39 | 40 | # Escaped quotes: double-quotes or backslash-quote are replaced by 41 | # single-quotes. 42 | assert "esc'aped string" == parse_value(r"'esc\'aped string'") 43 | # Expected values in the following assertions should match what 44 | # psycopg2.extensions.parse_dsn() (or libpq) recognizes. 45 | assert "host='127.0.0.1'" == parse_value("'host=''127.0.0.1'''") 46 | assert "user=foo password=se'cret" == parse_value("'user=foo password=se''cret'") 47 | assert "user=foo password=se''cret" == parse_value("user=foo password=se''cret") 48 | assert "user=foo password=secret'" == parse_value("'user=foo password=secret'''") 49 | assert ( 50 | # this one does not work in parse_dsn() 51 | "user=foo password='secret" 52 | == parse_value("'user=foo password=''secret'") 53 | ) 54 | assert "%m [%p] %q%u@%d " == parse_value(r"'%m [%p] %q%u@%d '") 55 | assert "124.7MB" == parse_value("124.7MB") 56 | assert "124.7ms" == parse_value("124.7ms") 57 | 58 | # Memory 59 | assert "1kB" == parse_value("1kB") 60 | assert "512MB" == parse_value("512MB") 61 | assert "64 GB" == parse_value(" 64 GB ") 62 | assert "5TB" == parse_value("5TB") 63 | 64 | # Time 65 | delta = parse_value("150 ms") 66 | assert 150000 == delta.microseconds 67 | delta = parse_value("24s ") 68 | assert 24 == delta.seconds 69 | delta = parse_value("' 5 min'") 70 | assert 300 == delta.seconds 71 | delta = parse_value("2 h") 72 | assert 7200 == delta.seconds 73 | delta = 
parse_value("5d") 74 | assert 5 == delta.days 75 | 76 | # Enums 77 | assert "md5" == parse_value("md5") 78 | 79 | # Errors 80 | with pytest.raises(ValueError): 81 | parse_value("'missing last quote") 82 | 83 | 84 | def test_parser(): 85 | from pgtoolkit.conf import parse, parse_string 86 | 87 | content = dedent( 88 | """\ 89 | # This file consists of lines of the form: 90 | # 91 | # name = value 92 | # 93 | # (The "=" is optional.) Whitespace may be used. Comments are introduced with 94 | # "#" anywhere on a line. The complete list of parameter names and allowed 95 | # values can be found in the PostgreSQL documentation. 96 | # 97 | # The commented-out settings shown in this file represent the default values. 98 | # Re-commenting a setting is NOT sufficient to revert it to the default value; 99 | # you need to reload the server. 100 | 101 | # - Connection Settings - 102 | listen_addresses = '*' # comma-separated list of addresses; 103 | # defaults to 'localhost'; use '*' for all 104 | # (change requires restart) 105 | 106 | primary_conninfo = 'host=''example.com'' port=5432 dbname=mydb connect_timeout=10' 107 | port = 5432 108 | bonjour 'without equals' 109 | # bonjour_name = '' # defaults to the computer name 110 | shared.buffers = 248MB 111 | #authentication_timeout = 2min # will be overwritten by the one below 112 | #authentication_timeout = 1min # 1s-600s 113 | # port = 5454 # commented value does not override previous (uncommented) one 114 | """ 115 | ) 116 | 117 | conf = parse_string(content, "/etc/postgres/postgresql.conf") 118 | 119 | assert conf.path == "/etc/postgres/postgresql.conf" 120 | 121 | assert "*" == conf.listen_addresses 122 | assert ( 123 | str(conf.entries["listen_addresses"]) 124 | == "listen_addresses = '*' # comma-separated list of addresses;" 125 | ) 126 | assert 5432 == conf.port 127 | assert ( 128 | conf.primary_conninfo 129 | == "host='example.com' port=5432 dbname=mydb connect_timeout=10" 130 | ) 131 | assert "without equals" == 
conf.bonjour 132 | assert "248MB" == conf["shared.buffers"] 133 | 134 | assert conf.entries["bonjour_name"].commented 135 | assert ( 136 | str(conf.entries["bonjour_name"]) 137 | == "#bonjour_name = '' # defaults to the computer name" 138 | ) 139 | assert conf.entries["authentication_timeout"].commented 140 | assert conf.entries["authentication_timeout"].value == timedelta(minutes=1) 141 | assert ( 142 | str(conf.entries["authentication_timeout"]) 143 | == "#authentication_timeout = '1 min' # 1s-600s" 144 | ) 145 | 146 | assert [(e.name, e.value) for e in conf if e.commented] == [ 147 | ("name", "value"), 148 | ("bonjour_name", ""), 149 | ("authentication_timeout", timedelta(seconds=60)), 150 | ] 151 | 152 | dict_ = conf.as_dict() 153 | assert "*" == dict_["listen_addresses"] 154 | 155 | with pytest.raises(AttributeError): 156 | conf.inexistent 157 | 158 | with pytest.raises(KeyError): 159 | conf["inexistent"] 160 | 161 | with pytest.raises(ValueError): 162 | parse(["bad_line"]) 163 | 164 | 165 | def test_configuration_fields(): 166 | """Configuration fields (the ones from dataclass definition) can be changed.""" 167 | from pgtoolkit.conf import Configuration 168 | 169 | cfg = Configuration(path="my/postgresql.conf") 170 | assert cfg.path == "my/postgresql.conf" 171 | cfg.path = "changed/to/postgres.conf" 172 | assert cfg.path == "changed/to/postgres.conf" 173 | assert "path" not in cfg and "path" not in cfg.entries 174 | 175 | 176 | def test_configuration_multiple_entries(): 177 | from pgtoolkit.conf import Configuration 178 | 179 | conf = Configuration() 180 | list( 181 | conf.parse( 182 | [ 183 | "port=5432\n", 184 | "# port=5423\n", 185 | "port=5433 # the real one!!\n", 186 | ] 187 | ) 188 | ) 189 | assert conf["port"] == 5433 190 | fo = StringIO() 191 | conf.save(fo) 192 | out = fo.getvalue().strip().splitlines() 193 | assert out == [ 194 | "port=5432", 195 | "# port=5423", 196 | "port=5433 # the real one!!", 197 | ] 198 | 199 | 200 | def 
test_parser_includes_require_an_absolute_file_path(): 201 | from pgtoolkit.conf import ParseError, parse 202 | 203 | lines = ["include = 'foo.conf'\n"] 204 | with pytest.raises(ParseError, match="cannot process include directives"): 205 | parse(lines) 206 | 207 | 208 | def test_parser_includes_string_with_absolute_file_path(tmp_path: pathlib.Path): 209 | from pgtoolkit.conf import parse 210 | 211 | included = tmp_path / "included.conf" 212 | included.write_text("shared_buffers = 123MG\nmax_connections=45\n") 213 | 214 | lines = [f"include = '{included}'", "cluster_name=foo"] 215 | conf = parse(lines) 216 | assert conf.as_dict() == { 217 | "cluster_name": "foo", 218 | "max_connections": 45, 219 | "shared_buffers": "123MG", 220 | } 221 | 222 | 223 | def test_parser_includes(): 224 | from pgtoolkit.conf import parse 225 | 226 | fpath = pathlib.Path(__file__).parent / "data" / "postgres.conf" 227 | conf = parse(str(fpath)) 228 | assert conf.as_dict() == { 229 | "authentication_timeout": timedelta(seconds=120), 230 | "autovacuum_work_mem": -1, 231 | "bonjour": False, 232 | "bonsoir": True, 233 | "checkpoint_completion_target": 0.9, 234 | "cluster_name": "pgtoolkit", 235 | "listen_addresses": "*", 236 | "log_line_prefix": "%m %q@%d", 237 | "log_rotation_age": timedelta(days=1), 238 | "max_connections": 100, 239 | "my": True, 240 | "mymy": False, 241 | "mymymy": True, 242 | "pg_stat_statements.max": 10000, 243 | "pg_stat_statements.track": "all", 244 | "port": 5432, 245 | "shared_buffers": "248MB", 246 | "shared_preload_libraries": "pg_stat_statements", 247 | "ssl": True, 248 | "unix_socket_permissions": "0777", 249 | "wal_level": "hot_standby", 250 | "some_guc": "2.30", 251 | } 252 | assert "include" not in conf 253 | assert "include_if_exists" not in conf 254 | assert "include_dir" not in conf 255 | 256 | # Make sure original file is preserved on save (i.e. includes do not 257 | # interfere). 
258 | fo = StringIO() 259 | conf.save(fo) 260 | lines = fo.getvalue().strip().splitlines() 261 | assert lines[:8] == [ 262 | "include_dir = 'conf.d'", 263 | "#include_dir = 'conf.11.d'", 264 | "include = 'postgres-my.conf'", 265 | "#------------------------------------------------------------------------------", 266 | "# CONNECTIONS AND AUTHENTICATION", 267 | "#------------------------------------------------------------------------------", 268 | "# - Connection Settings -", 269 | "listen_addresses = '*' # comma-separated list of addresses;", 270 | ] 271 | assert lines[-5:] == [ 272 | "# Add settings for extensions here", 273 | "pg_stat_statements.max = 10000", 274 | "pg_stat_statements.track = all", 275 | "", 276 | "some_guc = '2.30'", 277 | ] 278 | 279 | 280 | def test_parser_includes_loop(tmp_path): 281 | from pgtoolkit.conf import parse 282 | 283 | pgconf = tmp_path / "postgres.conf" 284 | with pgconf.open("w") as f: 285 | f.write(f"include = '{pgconf.absolute()}'\n") 286 | 287 | with pytest.raises(RuntimeError, match="loop detected"): 288 | parse(str(pgconf)) 289 | 290 | 291 | def test_parser_includes_notfound(tmp_path): 292 | from pgtoolkit.conf import parse 293 | 294 | pgconf = tmp_path / "postgres.conf" 295 | with pgconf.open("w") as f: 296 | f.write("include = 'missing.conf'\n") 297 | missing_conf = tmp_path / "missing.conf" 298 | msg = f"file '{missing_conf}', included from '{pgconf}', not found" 299 | with pytest.raises(FileNotFoundError, match=msg): 300 | parse(str(pgconf)) 301 | 302 | pgconf = tmp_path / "postgres.conf" 303 | with pgconf.open("w") as f: 304 | f.write("include_dir = 'conf.d'\n") 305 | missing_conf = tmp_path / "conf.d" 306 | msg = f"directory '{missing_conf}', included from '{pgconf}', not found" 307 | with pytest.raises(FileNotFoundError, match=msg): 308 | parse(str(pgconf)) 309 | 310 | 311 | def test_parse_string_include(tmp_path): 312 | from pgtoolkit.conf import ParseError, parse_string 313 | 314 | with pytest.raises( 315 | 
ParseError, 316 | match="cannot process include directives referencing a relative path", 317 | ): 318 | parse_string("work_mem=1MB\ninclude = x\n") 319 | 320 | 321 | def test_entry_edit(): 322 | from pgtoolkit.conf import Entry 323 | 324 | entry = Entry("port", "1234") 325 | assert entry.value == 1234 326 | entry.value = "9876" 327 | assert entry.value == 9876 328 | 329 | 330 | def test_entry_constructor_parse_value(): 331 | from pgtoolkit.conf import Entry 332 | 333 | entry = Entry("var", "'1.2'") 334 | assert entry.value == "1.2" 335 | entry = Entry("var", "1234") 336 | assert entry.value == 1234 337 | # If value come from the parsing of a file (ie. raw_line is provided) value should 338 | # not be parsed and be kept as is 339 | entry = Entry("var", "'1.2'", raw_line="var = '1.2'") 340 | assert entry.value == "'1.2'" 341 | 342 | 343 | def test_serialize_entry(): 344 | from pgtoolkit.conf import Entry 345 | 346 | e = Entry("grp.setting", True) 347 | 348 | assert "grp.setting" in repr(e) 349 | assert "grp.setting = on" == str(e) 350 | 351 | assert "'2kB'" == Entry("var", "2kB").serialize() 352 | assert "2048" == Entry("var", 2048).serialize() 353 | assert "var = 0" == str(Entry("var", 0)) 354 | assert "var = 15" == str(Entry("var", 15)) 355 | assert "var = 0.1" == str(Entry("var", 0.1)) 356 | assert "var = 'enum'" == str(Entry("var", "enum")) 357 | assert "addrs = '*'" == str(Entry("addrs", "*")) 358 | assert "var = 'sp ced'" == str(Entry("var", "sp ced")) 359 | assert "var = 'quo''ed'" == str(Entry("var", "quo'ed")) 360 | assert "var = 'quo''ed'' and space'" == str(Entry("var", "quo'ed' and space")) 361 | 362 | assert r"'quo\'ed'" == Entry("var", r"quo\'ed").serialize() 363 | e = Entry("var", "app=''foo'' host=192.168.0.8") 364 | assert e.serialize() == "'app=''foo'' host=192.168.0.8'" 365 | assert str(e) == "var = 'app=''foo'' host=192.168.0.8'" 366 | 367 | e = Entry( 368 | "primary_conninfo", 369 | "port=5432 password=pa'sw0'd dbname=postgres", 370 | ) 371 | 
assert ( 372 | str(e) == "primary_conninfo = 'port=5432 password=pa''sw0''d dbname=postgres'" 373 | ) 374 | 375 | assert "var = 'quoted'" == str(Entry("var", "'quoted'")) 376 | 377 | assert "'1d'" == Entry("var", timedelta(days=1)).serialize() 378 | assert "'1h'" == Entry("var", timedelta(minutes=60)).serialize() 379 | assert "'61 min'" == Entry("var", timedelta(minutes=61)).serialize() 380 | e = Entry("var", timedelta(microseconds=12000)) 381 | assert "'12 ms'" == e.serialize() 382 | 383 | assert " # Comment" in str(Entry("var", 1, comment="Comment")) 384 | 385 | 386 | def test_save(): 387 | from pgtoolkit.conf import parse 388 | 389 | conf = parse(["listen_addresses = *"]) 390 | conf["primary_conninfo"] = "user=repli password=pa'sw0'd" 391 | fo = StringIO() 392 | conf.save(fo) 393 | out = fo.getvalue() 394 | assert "listen_addresses = *" in out 395 | assert "primary_conninfo = 'user=repli password=pa''sw0''d'" in out 396 | 397 | 398 | def test_edit(): 399 | from pgtoolkit.conf import Configuration 400 | 401 | conf = Configuration() 402 | list( 403 | conf.parse( 404 | [ 405 | "#bonjour = off # advertise server via Bonjour\n", 406 | "#bonjour_name = '' # defaults to computer name\n", 407 | ] 408 | ) 409 | ) 410 | 411 | conf.listen_addresses = "*" 412 | assert "listen_addresses" in conf 413 | assert "*" == conf.listen_addresses 414 | 415 | assert "port" not in conf 416 | conf["port"] = 5432 417 | assert 5432 == conf.port 418 | 419 | conf["port"] = "5433" 420 | assert 5433 == conf.port 421 | 422 | conf["primary_conninfo"] = "'port=5432 host=''example.com'''" 423 | assert conf.primary_conninfo == "port=5432 host='example.com'" 424 | 425 | with StringIO() as fo: 426 | conf.save(fo) 427 | lines = fo.getvalue().splitlines() 428 | 429 | assert lines == [ 430 | "#bonjour = off # advertise server via Bonjour", 431 | "#bonjour_name = '' # defaults to computer name", 432 | "listen_addresses = '*'", 433 | "port = 5433", 434 | "primary_conninfo = 'port=5432 
host=''example.com'''", 435 | ] 436 | 437 | conf["port"] = 5454 438 | conf["log_line_prefix"] = "[%p]: [%l-1] db=%d,user=%u,app=%a,client=%h " 439 | conf["bonjour_name"] = "pgserver" 440 | conf["track_activity_query_size"] = 32768 441 | with StringIO() as fo: 442 | conf.save(fo) 443 | lines = fo.getvalue().splitlines() 444 | 445 | assert lines == [ 446 | "#bonjour = off # advertise server via Bonjour", 447 | "bonjour_name = 'pgserver' # defaults to computer name", 448 | "listen_addresses = '*'", 449 | "port = 5454", 450 | "primary_conninfo = 'port=5432 host=''example.com'''", 451 | "log_line_prefix = '[%p]: [%l-1] db=%d,user=%u,app=%a,client=%h '", 452 | "track_activity_query_size = 32768", 453 | ] 454 | 455 | with pytest.raises(ValueError, match="cannot add an include directive"): 456 | conf["include_if_exists"] = "file.conf" 457 | 458 | with conf.edit() as entries: 459 | entries.add( 460 | "external_pid_file", 461 | "/tmp/11-main.pid", 462 | comment="write an extra PID file", 463 | ) 464 | del entries["log_line_prefix"] 465 | entries["port"].value = "54" 466 | entries["bonjour"].value = True 467 | 468 | assert conf.port == 54 469 | assert conf.entries["port"].value == 54 470 | 471 | with StringIO() as fo: 472 | conf.save(fo) 473 | lines = fo.getvalue().splitlines() 474 | 475 | expected_lines = [ 476 | "bonjour = on # advertise server via Bonjour", 477 | "bonjour_name = 'pgserver' # defaults to computer name", 478 | "listen_addresses = '*'", 479 | "port = 54", 480 | "primary_conninfo = 'port=5432 host=''example.com'''", 481 | "track_activity_query_size = 32768", 482 | "external_pid_file = '/tmp/11-main.pid' # write an extra PID file", 483 | ] 484 | assert lines == expected_lines 485 | 486 | with pytest.raises(ValueError): 487 | with conf.edit() as entries: 488 | entries["port"].value = "'invalid" 489 | assert lines == expected_lines 490 | 491 | 492 | def test_edit_included_value(tmp_path: pathlib.Path) -> None: 493 | from pgtoolkit.conf import parse 494 | 495 | 
included = tmp_path / "included.conf" 496 | included.write_text("foo = true\nbar= off\n") 497 | base = tmp_path / "postgresql.conf" 498 | lines = ["bonjour = on", f"include = {included}", "cluster_name=test"] 499 | base.write_text("\n".join(lines) + "\n") 500 | 501 | conf = parse(base) 502 | with pytest.warns(UserWarning, match="entry 'foo' not directly found"): 503 | with conf.edit() as entries: 504 | entries["foo"].value = False 505 | entries["bar"].commented = True 506 | 507 | out = tmp_path / "postgresql-new.conf" 508 | conf.save(out) 509 | newlines = out.read_text().splitlines() 510 | assert newlines == lines + ["foo = off"] 511 | 512 | conf = parse(out) 513 | assert conf["foo"] is False 514 | 515 | 516 | def test_configuration_iter(): 517 | from pgtoolkit.conf import Configuration 518 | 519 | conf = Configuration() 520 | conf.port = 5432 521 | conf.log_timezone = "Europe/Paris" 522 | assert [e.name for e in conf] == ["port", "log_timezone"] 523 | -------------------------------------------------------------------------------- /tests/test_ctl.py: -------------------------------------------------------------------------------- 1 | import shlex 2 | import subprocess 3 | from collections.abc import Sequence 4 | from pathlib import Path 5 | from typing import Any 6 | from unittest.mock import patch 7 | 8 | import pytest 9 | import pytest_asyncio 10 | 11 | from pgtoolkit import ctl # noqa: E402 12 | 13 | 14 | def test__args_to_opts(): 15 | opts = ctl._args_to_opts( 16 | { 17 | "encoding": "latin1", 18 | "auth_local": "trust", 19 | "show": True, 20 | "n": True, 21 | "L": "DIR", 22 | } 23 | ) 24 | assert opts == [ 25 | "-L DIR", 26 | "--auth-local=trust", 27 | "--encoding=latin1", 28 | "-n", 29 | "--show", 30 | ] 31 | 32 | 33 | def test__wait_args_to_opts(): 34 | assert ctl._wait_args_to_opts(False) == ["--no-wait"] 35 | assert ctl._wait_args_to_opts(True) == ["--wait"] 36 | assert ctl._wait_args_to_opts(42) == ["--wait", "--timeout=42"] 37 | 38 | 39 | 
@pytest.fixture
def bindir(tmp_path: Path) -> Path:
    """Directory holding a fake ``pg_ctl`` that only prints a version banner."""
    pg_ctl = tmp_path / "pg_ctl"
    pg_ctl.write_text("#!/bin/sh\necho 'pg_ctl (PostgreSQL) 11.10'")
    # chmod() *after* writing: Path.touch(mode=0o777) masks the mode with the
    # process umask, so the executable bit was not guaranteed before.
    pg_ctl.chmod(0o777)
    return tmp_path


def run_command_version_only(
    args: Sequence[str], **kwargs: Any
) -> ctl.CompletedProcess:
    """Stub runner accepting only the ``pg_ctl --version`` probe."""
    try:
        executable, *opts = args
    except ValueError:
        pass
    else:
        if executable.endswith("/pg_ctl") and opts == ["--version"]:
            return subprocess.CompletedProcess(
                args, 0, stdout="pg_ctl (PostgreSQL) 11.10\n", stderr=""
            )
    pytest.fail(f"unexpectedly called with: {args}")


@pytest.fixture
def pgctl(bindir: Path) -> ctl.PGCtl:
    c = ctl.PGCtl(bindir, run_command=run_command_version_only)
    # Use a bare command name so the expected command lines stay short.
    c.pg_ctl = Path("pg_ctl")
    return c


def test_init_cmd(pgctl: ctl.PGCtl) -> None:
    assert pgctl.init_cmd(
        "data",
        auth_local="md5",
        data_checksums=True,
        g=True,
        X="wal",
    ) == shlex.split(
        "pg_ctl init -D data -o '-X wal --auth-local=md5 --data-checksums -g'"
    )


def test_start_cmd(pgctl: ctl.PGCtl) -> None:
    assert pgctl.start_cmd("data") == shlex.split("pg_ctl start -D data --wait")
    assert pgctl.start_cmd("data", wait=False) == shlex.split(
        "pg_ctl start -D data --no-wait"
    )
    assert pgctl.start_cmd(
        "data",
        wait=3,
        logfile="logfile",
    ) == shlex.split("pg_ctl start -D data --wait --timeout=3 --log=logfile")
    assert pgctl.start_cmd("data", k="/tmp/sockets") == shlex.split(
        "pg_ctl start -D data --wait -o '-k /tmp/sockets'"
    )


def test_stop_cmd(pgctl: ctl.PGCtl) -> None:
    assert pgctl.stop_cmd("data") == shlex.split("pg_ctl stop -D data --wait")
    assert pgctl.stop_cmd("data", wait=False) == shlex.split(
        "pg_ctl stop -D data --no-wait"
    )
    assert pgctl.stop_cmd(
        "data",
        wait=3,
        mode="fast",
    ) == shlex.split("pg_ctl stop -D data --wait --timeout=3 --mode=fast")


def test_restart_cmd(pgctl: ctl.PGCtl) -> None:
    assert pgctl.restart_cmd("data") == shlex.split("pg_ctl restart -D data --wait")
    assert pgctl.restart_cmd("data", wait=False) == shlex.split(
        "pg_ctl restart -D data --no-wait"
    )
    assert pgctl.restart_cmd(
        "data",
        wait=3,
        mode="fast",
    ) == shlex.split("pg_ctl restart -D data --wait --timeout=3 --mode=fast")
    assert pgctl.restart_cmd("data", k="/tmp/sockets") == shlex.split(
        "pg_ctl restart -D data --wait -o '-k /tmp/sockets'"
    )


def test_reload_cmd(pgctl: ctl.PGCtl) -> None:
    assert pgctl.reload_cmd("data") == shlex.split("pg_ctl reload -D data")


def test_status_cmd(pgctl: ctl.PGCtl) -> None:
    assert pgctl.status_cmd("data") == shlex.split("pg_ctl status -D data")


@pytest.mark.parametrize(
    "returncode, status",
    [
        [0, ctl.Status.running],
        [3, ctl.Status.not_running],
        [4, ctl.Status.unspecified_datadir],
    ],
)
def test_status(pgctl: ctl.PGCtl, returncode: int, status: ctl.Status) -> None:
    with patch.object(
        pgctl, "run_command", return_value=subprocess.CompletedProcess([], returncode)
    ) as run_command:
        actual = pgctl.status("data")
        run_command.assert_called_once_with(["pg_ctl", "status", "-D", "data"])
    assert actual == status


def test_status_returncode1(pgctl: ctl.PGCtl) -> None:
    # Return code 1 is not a documented pg_ctl status code and must raise.
    with (
        patch.object(
            pgctl, "run_command", return_value=subprocess.CompletedProcess([], 1)
        ) as run_command,
        pytest.raises(subprocess.CalledProcessError),
    ):
        pgctl.status("data")
    run_command.assert_called_once_with(["pg_ctl", "status", "-D", "data"])


@pytest_asyncio.fixture
async def apgctl(bindir: Path) -> ctl.AsyncPGCtl:
    async def run_command(args: Sequence[str], **kwargs: Any) -> ctl.CompletedProcess:
        return run_command_version_only(args, **kwargs)

    c = await ctl.AsyncPGCtl.get(bindir, run_command=run_command)
    c.pg_ctl = Path("pg_ctl")
    return c


@pytest.mark.parametrize(
    "rc, status",
    [
        [0, ctl.Status.running],
        [3, ctl.Status.not_running],
        [4, ctl.Status.unspecified_datadir],
    ],
)
@pytest.mark.asyncio
async def test_status_async(
    apgctl: ctl.AsyncPGCtl, rc: int, status: ctl.Status
) -> None:
    with patch.object(
        apgctl, "run_command", return_value=subprocess.CompletedProcess([], rc)
    ) as run_command:
        actual = await apgctl.status("data")
        run_command.assert_called_once_with(["pg_ctl", "status", "-D", "data"])
    assert actual == status


@pytest.mark.asyncio
async def test_status_returncode1_async(apgctl: ctl.AsyncPGCtl) -> None:
    with (
        patch.object(
            apgctl, "run_command", return_value=subprocess.CompletedProcess([], 1)
        ) as run_command,
        pytest.raises(subprocess.CalledProcessError),
    ):
        await apgctl.status("data")
    run_command.assert_called_once_with(["pg_ctl", "status", "-D", "data"])


def test_parse_controldata() -> None:
    lines = [
        "pg_control version number: 1100",
        "Catalog version number: 201809051",
        "Database system identifier: 6798427594087098476",
        "Database cluster state: shut down",
        "pg_control last modified: Tue 07 Jul 2020 01:08:58 PM CEST",
        "WAL block size: 8192",
    ]
    controldata = ctl.parse_control_data(lines)
    assert controldata == {
        "Catalog version number": "201809051",
        "Database cluster state": "shut down",
        "Database system identifier": "6798427594087098476",
        "WAL block size": "8192",
        "pg_control last modified": "Tue 07 Jul 2020 01:08:58 PM CEST",
        "pg_control version number": "1100",
    }


# ---- tests/test_ctl_func.py ----

import socket
import stat
from pathlib import Path

import pytest
import pytest_asyncio

from pgtoolkit import ctl, hba


@pytest.fixture(scope="module")
def pgctl() -> ctl.PGCtl:
    # Skip the whole module when no real PostgreSQL installation is found.
    try:
        return ctl.PGCtl()
    except OSError as e:
        pytest.skip(str(e))


@pytest_asyncio.fixture(loop_scope="module")
async def apgctl() -> ctl.AsyncPGCtl:
    try:
        return await ctl.AsyncPGCtl.get()
    except OSError as e:
        pytest.skip(str(e))


def test_pgctl(pgctl: ctl.PGCtl) -> None:
    assert pgctl.pg_ctl


def test_pgctl_async(apgctl: ctl.AsyncPGCtl) -> None:
    assert apgctl.pg_ctl


@pytest.fixture(scope="module")
def initdb(tmp_path_factory, pgctl: ctl.PGCtl) -> tuple[Path, Path, Path]:
    """Initialize a throwaway cluster; return (datadir, waldir, pidfile)."""
    datadir = tmp_path_factory.mktemp("data")
    waldir = tmp_path_factory.mktemp("wal")
    pgctl.init(
        datadir,
        auth_local="scram-sha-256",
        data_checksums=True,
        g=True,
        X=waldir,
    )
    run_path = tmp_path_factory.mktemp("run")
    pid_path = run_path / "pid"
    with (datadir / "postgresql.conf").open("a") as f:
        f.write(f"\nunix_socket_directories = '{run_path}'")
        f.write(f"\nexternal_pid_file = '{pid_path}'")
    return datadir, waldir, pid_path


@pytest_asyncio.fixture(loop_scope="module")
async def ainitdb(tmp_path_factory, apgctl: ctl.AsyncPGCtl) -> tuple[Path, Path, Path]:
    datadir = tmp_path_factory.mktemp("data")
    waldir = tmp_path_factory.mktemp("wal")
    await apgctl.init(
        datadir,
        auth_local="scram-sha-256",
        data_checksums=True,
        g=True,
        X=waldir,
    )
    run_path = tmp_path_factory.mktemp("run")
    pid_path = run_path / "pid"
    with (datadir / "postgresql.conf").open("a") as f:
        f.write(f"\nunix_socket_directories = '{run_path}'")
        f.write(f"\nexternal_pid_file = '{pid_path}'")
    return datadir, waldir, pid_path


def _check_initdb(datadir: Path, waldir: Path, pid_path: Path) -> None:
    assert (datadir / "PG_VERSION").exists()
    assert (waldir / "archive_status").is_dir()
    with (datadir / "pg_hba.conf").open() as f:
        pghba = hba.parse(f)
    assert next(iter(pghba)).method == "scram-sha-256"
    # initdb -g: group can read/traverse the data directory, not write it.
    st_mode = datadir.stat().st_mode
    assert st_mode & stat.S_IRGRP
    assert st_mode & stat.S_IXGRP
    assert not st_mode & stat.S_IWGRP


def test_init(initdb: tuple[Path, Path, Path]) -> None:
    _check_initdb(*initdb)


@pytest.mark.asyncio
async def test_init_async(ainitdb: tuple[Path, Path, Path]) -> None:
    _check_initdb(*ainitdb)


@pytest.fixture
def tmp_port() -> int:
    # Bind to port 0 so the kernel picks a free port, then release it.
    s = socket.socket()
    s.bind(("", 0))
    with s:
        port = s.getsockname()[1]
    return port


def test_start_stop_status_restart_reload(
    initdb: tuple[Path, Path, Path], pgctl: ctl.PGCtl, tmp_port: int
) -> None:
    datadir, __, pidpath = initdb
    assert pgctl.status("invalid") == ctl.Status.unspecified_datadir
    assert pgctl.status(str(datadir)) == ctl.Status.not_running
    assert not pidpath.exists()
    pgctl.start(str(datadir), logfile=datadir / "logs", port=str(tmp_port))
    assert pidpath.exists()
    pid1 = pidpath.read_text()

    assert pgctl.status(str(datadir)) == ctl.Status.running
    # restart spawns a new postmaster, reload keeps the same one.
    pgctl.restart(str(datadir), mode="immediate", wait=2)
    pid2 = pidpath.read_text()
    assert pid2 != pid1
    assert pgctl.status(str(datadir)) == ctl.Status.running
    pgctl.reload(str(datadir))
    pid3 = pidpath.read_text()
    assert pid3 == pid2
    assert pgctl.status(str(datadir)) == ctl.Status.running
    pgctl.stop(str(datadir), mode="smart")
    assert not pidpath.exists()
    assert pgctl.status(str(datadir)) == ctl.Status.not_running


@pytest.mark.asyncio
async def test_start_stop_status_restart_reload_async(
    ainitdb: tuple[Path, Path, Path], apgctl: ctl.AsyncPGCtl, tmp_port: int
) -> None:
    datadir, __, pidpath = ainitdb
    assert (await apgctl.status("invalid")) == ctl.Status.unspecified_datadir
    assert (await apgctl.status(str(datadir))) == ctl.Status.not_running
    assert not pidpath.exists()
    await apgctl.start(str(datadir), logfile=datadir / "logs", port=str(tmp_port))
    assert pidpath.exists()
    pid1 = pidpath.read_text()

    assert (await apgctl.status(str(datadir))) == ctl.Status.running
    await apgctl.restart(str(datadir), mode="immediate", wait=2)
    pid2 = pidpath.read_text()
    assert pid2 != pid1
    assert (await apgctl.status(str(datadir))) == ctl.Status.running
    await apgctl.reload(str(datadir))
    pid3 = pidpath.read_text()
    assert pid3 == pid2
    assert (await apgctl.status(str(datadir))) == ctl.Status.running
    await apgctl.stop(str(datadir), mode="smart")
    assert not pidpath.exists()
    assert (await apgctl.status(str(datadir))) == ctl.Status.not_running


def test_controldata(initdb: tuple[Path, Path, Path], pgctl: ctl.PGCtl) -> None:
    datadir, __, __ = initdb
    controldata = pgctl.controldata(datadir=datadir)
    assert "Database block size" in controldata
    assert controldata["Database block size"] == "8192"
    assert "Database cluster state" in controldata


@pytest.mark.asyncio
async def test_controldata_async(
    ainitdb: tuple[Path, Path, Path], apgctl: ctl.AsyncPGCtl
) -> None:
    datadir, __, __ = ainitdb
    controldata = await apgctl.controldata(datadir=datadir)
    assert "Database block size" in controldata
    assert controldata["Database block size"] == "8192"
    assert "Database cluster state" in controldata


# ---- tests/test_hba.py ----

import pytest

HBA_SAMPLE = """\
# CAUTION: Configuring the system for local "trust" authentication
# allows any local user to connect as any PostgreSQL user, including
# the database superuser. If you do not trust all your local users,
# use another authentication method.


# TYPE DATABASE USER ADDRESS METHOD

# "local" is for Unix domain socket connections only
local all all trust
# IPv4 local connections:
host all u0,u1 127.0.0.1/32 trust
# IPv6 local connections:
host all +group,u2 ::1/128 trust
# Allow replication connections from localhost, by a user with the
# replication privilege.
local replication all trust
host replication all 127.0.0.1 255.255.255.255 trust
host replication all ::1/128 trust

host all all all trust
"""


def test_comment():
    from pgtoolkit.hba import HBAComment

    comment = HBAComment("# toto")
    assert "toto" in repr(comment)
    assert "# toto" == str(comment)


def test_parse_host_line():
    from pgtoolkit.hba import HBARecord

    record = HBARecord.parse("host replication,mydb all ::1/128 trust")
    assert "host" in repr(record)
    assert "host" == record.conntype
    assert "replication,mydb" == record.database
    assert ["replication", "mydb"] == record.databases
    assert "all" == record.user
    assert ["all"] == record.users
    assert "::1/128" == record.address
    assert "trust" == record.method

    # This is not actually a public API. But let's keep it stable.
    values = record.common_values
    assert "trust" in values


def test_parse_local_line():
    from pgtoolkit.hba import HBARecord

    record = HBARecord.parse("local all all trust")
    assert "local" == record.conntype
    assert "all" == record.database
    assert ["all"] == record.databases
    assert "all" == record.user
    assert ["all"] == record.users
    assert "trust" == record.method

    # `address` is not meaningful for local connections.
    with pytest.raises(AttributeError):
        record.address

    wanted = (
        "local all all trust"  # noqa
    )
    assert wanted == str(record)


def test_parse_auth_option():
    from pgtoolkit.hba import HBARecord

    record = HBARecord.parse(
        "local veryverylongdatabasenamethatdonotfit all ident map=omicron",
    )
    assert "local" == record.conntype
    assert "veryverylongdatabasenamethatdonotfit" == record.database
    assert "all" == record.user
    assert "ident" == record.method
    assert "omicron" == record.map

    wanted = [
        "local",
        "veryverylongdatabasenamethatdonotfit",
        "all",
        "ident",
        'map="omicron"',
    ]
    assert wanted == str(record).split()


def test_parse_record_with_comment():
    from pgtoolkit.hba import HBARecord

    record = HBARecord.parse("local all all trust # My comment")
    assert "local" == record.conntype
    assert "all" == record.database
    assert "all" == record.user
    assert "trust" == record.method
    assert "My comment" == record.comment

    fields = str(record).split()
    assert ["local", "all", "all", "trust", "#", "My", "comment"] == fields


def test_parse_invalid_connection_type():
    from pgtoolkit.hba import HBARecord

    with pytest.raises(ValueError, match="Unknown connection type 'pif'"):
        HBARecord.parse("pif all all")


def test_parse_record_with_backslash():
    from pgtoolkit.hba import HBARecord

    record = HBARecord.parse(
        r'host all all all ldap ldapserver=host.local ldapprefix="DOMAINE\"'
    )
    assert record.ldapprefix == "DOMAINE\\"


def test_parse_record_with_double_quoting():
    from pgtoolkit.hba import HBARecord

    record = HBARecord.parse(
        r'host all all all radius radiusservers="server1,server2" radiussecrets="""secret one"",""secret two"""'
    )
    assert record.radiussecrets == '""secret one"",""secret two""'


def test_parse_record_blank_in_quotes():
    from pgtoolkit.hba import HBARecord

    record = HBARecord.parse(
        r"host all all all ldap ldapserver=ldap.example.net"
        r' ldapbasedn="dc=example, dc=net"'
        r' ldapsearchfilter="(|(uid=$username)(mail=$username))"'
    )
    assert record.ldapserver == "ldap.example.net"
    assert record.ldapbasedn == "dc=example, dc=net"
    assert record.ldapsearchfilter == "(|(uid=$username)(mail=$username))"


def test_hba(mocker):
    from pgtoolkit.hba import parse

    lines = HBA_SAMPLE.splitlines(True)
    hba = parse(lines)
    entries = list(iter(hba))

    assert 7 == len(entries)

    hba.save(mocker.Mock(name="file"))


def test_hba_path(tmp_path):
    from pgtoolkit.hba import HBA

    hba = HBA()
    assert hba.path is None
    with pytest.raises(ValueError, match="No file-like object nor path provided"):
        hba.save()

    p = tmp_path / "filename"
    hba = HBA([], p)
    assert hba.path == p
    hba.save()


def test_hba_create():
    from pgtoolkit.hba import HBA, HBAComment, HBARecord

    hba = HBA(
        [
            HBAComment("# a comment"),
            HBARecord(
                conntype="local",
                database="all",
                user="all",
                method="trust",
            ),
        ]
    )
    assert 2 == len(hba.lines)

    r = hba.lines[1]
    assert "all" == r.database
def test_parse_file(mocker, tmp_path):
    from pgtoolkit.hba import HBAComment, parse

    m = mocker.mock_open()
    # Python 3 only: the builtin lives in the ``builtins`` module — the
    # former try/except fallback patching ``__builtin__.open`` was dead
    # Python 2 compatibility code.
    mocker.patch("builtins.open", m)
    hba = parse("filename")
    hba.lines.append(HBAComment("# Something"))

    assert m.called
    hba.save()
    handle = m()
    handle.write.assert_called_with("# Something\n")

    # Also works for other string types
    m.reset_mock()
    hba = parse("filename")
    hba.lines.append(HBAComment("# Something"))
    assert m.called

    # Also works with path
    m.reset_mock()
    hba = parse(tmp_path / "filename")
    hba.lines.append(HBAComment("# Something"))

    assert m.called


def test_hba_error(mocker):
    from pgtoolkit.hba import ParseError, parse

    with pytest.raises(ParseError) as ei:
        parse(["lcal all all\n"])
    e = ei.value
    assert "line #1" in str(e)
    assert repr(e)

    # Too few fields is a parse error as well (exception value unused).
    with pytest.raises(ParseError):
        parse(["local incomplete\n"])


def test_remove():
    from pgtoolkit.hba import parse

    lines = HBA_SAMPLE.splitlines(True)
    hba = parse(lines)

    # At least one criterion (filter or attribute) is required.
    with pytest.raises(ValueError):
        hba.remove()

    result = hba.remove(database="badname")
    assert not result

    result = hba.remove(database="replication")
    assert result
    entries = list(iter(hba))
    assert 4 == len(entries)

    hba = parse(lines)
    result = hba.remove(filter=lambda r: r.database == "replication")
    assert result
    entries = list(iter(hba))
    assert 4 == len(entries)

    hba = parse(lines)
    result = hba.remove(conntype="host", database="replication")
    assert result
    entries = list(iter(hba))
    assert 5 == len(entries)

    # Works even for fields that may not be valid for all records
    # `address` is not valid for `local` connection type
    hba = parse(lines)
    result = hba.remove(address="127.0.0.1/32")
    assert result
    entries = list(iter(hba))
    assert 6 == len(entries)

    def filter(r):
        return r.conntype == "host" and r.database == "replication"

    hba = parse(lines)
    result = hba.remove(filter=filter)
    assert result
    entries = list(iter(hba))
    assert 5 == len(entries)

    # Only filter is taken into account
    hba = parse(lines)
    with pytest.warns(UserWarning):
        hba.remove(filter=filter, database="replication")
    entries = list(iter(hba))
    assert 5 == len(entries)

    # Error if attribute name is not valid
    hba = parse(lines)
    with pytest.raises(AttributeError):
        hba.remove(foo="postgres")


def test_merge():
    import os

    from pgtoolkit.hba import HBA, HBARecord, parse

    sample = """\
# comment
host replication all all trust
# other comment
host replication all 127.0.0.1 255.255.255.255 trust
# Comment should be kept
host all all all trust"""
    lines = sample.splitlines(True)
    hba = parse(lines)

    other_sample = """\
# line with no address
local all all trust
# comment before 1.2.3.4 line
host replication all 1.2.3.4 trust
# method changed to 'peer'
# second comment
host all all all peer
"""
    other_lines = other_sample.splitlines(True)
    other_hba = parse(other_lines)
    result = hba.merge(other_hba)
    assert result

    expected_sample = """\
# comment
host replication all all trust
# other comment
host replication all 127.0.0.1 255.255.255.255 trust
# Comment should be kept
# method changed to 'peer'
# second comment
host all all all peer
# line with no address
local all all trust
# comment before 1.2.3.4 line
host replication all 1.2.3.4 trust
"""
    expected_lines = expected_sample.splitlines(True)
    expected_hba = parse(expected_lines)

    def r(hba):
        return os.linesep.join([str(line) for line in hba.lines])

    assert r(hba) == r(expected_hba)

    # Merging a record already present is a no-op.
    other_hba = HBA()
    record = HBARecord(
        conntype="host",
        database="replication",
        user="all",
        address="1.2.3.4",
        method="trust",
    )
    other_hba.lines.append(record)
    result = hba.merge(other_hba)
    assert not result


def test_as_dict():
    from pgtoolkit.hba import HBARecord

    r = HBARecord(
        conntype="local",
        database="all",
        user="all",
        method="trust",
    )
    assert r.as_dict() == {
        "conntype": "local",
        "database": "all",
        "user": "all",
        "method": "trust",
    }

    r = HBARecord(
        conntype="local",
        database="mydb,mydb2",
        user="bob,alice",
        address="127.0.0.1",
        netmask="255.255.255.255",
        method="trust",
    )
    assert r.as_dict() == {
        "address": "127.0.0.1",
        "conntype": "local",
        "database": "mydb,mydb2",
        "user": "bob,alice",
        "method": "trust",
        "netmask": "255.255.255.255",
    }


def test_hbarecord_equality():
    from pgtoolkit.hba import HBARecord

    r = HBARecord(
        conntype="local",
        database="all",
        user="all",
        method="trust",
    )
    r2 = HBARecord.parse("local all all trust")
    assert r == r2

    r = HBARecord(
        conntype="host",
        database="all",
        user="u0,u1",
        address="127.0.0.1/32",
        method="trust",
    )
    r2 = HBARecord.parse("host all u0,u1 127.0.0.1/32 trust")
    assert r == r2

    r2 = HBARecord.parse("host mydb u0,u1 127.0.0.1/32 trust")
    assert r != r2
# ---- tests/test_helpers.py ----

import pytest


def test_open_or_stdin(mocker):
    from pgtoolkit._helpers import open_or_stdin

    stdin = object()
    assert open_or_stdin("-", stdin=stdin) is stdin

    # ``create=True`` (the original misspelled it ``creates``, which mock
    # silently treated as mock-attribute configuration) lets patch add the
    # builtin name to the module namespace.
    open_ = mocker.patch("pgtoolkit._helpers.open", create=True)
    open_.return_value = fo = object()

    assert open_or_stdin("toto.conf") is fo


def test_open_or_return(tmp_path):
    from pgtoolkit._helpers import open_or_return

    # File case.
    with (tmp_path / "foo").open("w") as fo:
        with open_or_return(fo) as ret:
            assert ret is fo

    # Path case.
    fpath = tmp_path / "toto.conf"
    fpath.write_text("paf")
    with open_or_return(fpath) as ret:
        assert ret.name == str(fpath)
        assert ret.read() == "paf"

    # Path as str case.
    with open_or_return(str(fpath)) as ret:
        assert ret.name == str(fpath)
        assert ret.read() == "paf"

    # None case.
    with pytest.raises(ValueError):
        open_or_return(None)


def test_timer():
    from pgtoolkit._helpers import Timer

    with Timer() as timer:
        pass

    assert timer.start
    assert timer.delta


def test_format_timedelta():
    from datetime import timedelta

    from pgtoolkit._helpers import format_timedelta

    assert "5s" == format_timedelta(timedelta(seconds=5))
    assert "1d 5s" == format_timedelta(timedelta(days=1, seconds=5))
    assert "20us" == format_timedelta(timedelta(microseconds=20))
    assert "0s" == format_timedelta(timedelta())


def test_json_encoder():
    import json
    from datetime import datetime, timedelta

    from pgtoolkit._helpers import JSONDateEncoder

    data_ = dict(
        date=datetime(year=2012, month=12, day=21),
        delta=timedelta(seconds=40),
        integer=42,
    )

    payload = json.dumps(data_, cls=JSONDateEncoder)

    assert '"2012-12-21T00:00:00' in payload
    assert '"40s"' in payload
    assert ": 42" in payload


# ---- tests/test_log.py ----

import re

import pytest


def test_parse():
    from pgtoolkit.log import UnknownData, parse

    lines = """\
\tResult (cost=0.00..0.01 rows=1 width=4) (actual time=1001.117..1001.118 rows=1 loops=1)
\t Output: pg_sleep('1'::double precision)
2018-06-15 10:03:53.488 UTC [7931]: [2-1] app=[unknown],db=[unknown],client=[local],user=[unknown] LOG: incomplete startup packet
2018-06-15 10:44:42.923 UTC [8280]: [2-1] app=,db=,client=,user= LOG: checkpoint starting: shutdown immediate
2018-06-15 10:44:58.206 UTC [8357]: [4-1] app=psql,db=postgres,client=[local],user=postgres HINT: No function matches the given name and argument types. You might need to add explicit type casts.
2018-06-15 10:45:03.175 UTC [8357]: [7-1] app=psql,db=postgres,client=[local],user=postgres LOG: duration: 1002.209 ms statement: select pg_sleep(1);
2018-06-15 10:49:11.512 UTC [8357]: [8-1] app=psql,db=postgres,client=[local],user=postgres LOG: duration: 0.223 ms statement: show log_timezone;
2018-06-15 10:49:26.084 UTC [8420]: [2-1] app=[unknown],db=postgres,client=[local],user=postgres LOG: connection authorized: user=postgres database=postgres
2018-06-15 10:49:26.088 UTC [8420]: [3-1] app=psql,db=postgres,client=[local],user=postgres LOG: duration: 1.449 ms statement: SELECT d.datname as "Name",
\t pg_catalog.pg_get_userbyid(d.datdba) as "Owner",
\t pg_catalog.pg_encoding_to_char(d.encoding) as "Encoding",
\t d.datcollate as "Collate",
\t d.datctype as "Ctype",
\t pg_catalog.array_to_string(d.datacl, E'\\n') AS "Access privileges"
\tFROM pg_catalog.pg_database d
\tORDER BY 1;
2018-06-15 10:49:26.088 UTC [8420]: [4-1] app=psql,db=postgres,client=[local],user=postgres LOG: disconnection: session time: 0:00:00.006 user=postgres database=postgres host=[local]
BAD PREFIX 10:49:31.140 UTC [8423]: [1-1] app=[unknown],db=[unknown],client=[local],user=[unknown] LOG: connection received: host=[local]
""".splitlines(
        True
    )  # noqa

    log_line_prefix = "%m [%p]: [%l-1] app=%a,db=%d,client=%h,user=%u "
    records = list(parse(lines, prefix_fmt=log_line_prefix))

    # Leading continuation lines without a prefix come out as UnknownData.
    assert isinstance(records[0], UnknownData)
    assert "\n" not in repr(records[0])
    record = records[1]
    assert "LOG" == record.severity


def test_group_lines():
    from pgtoolkit.log.parser import group_lines

    lines = """\
\tResult (cost=0.00..0.01 rows=1 width=4) (actual time=1001.117..1001.118 rows=1 loops=1)
\t Output: pg_sleep('1'::double precision)
2018-06-15 10:45:03.175 UTC [8357]: [7-1] app=psql,db=postgres,client=[local],user=postgres LOG: duration: 1002.209 ms statement: select pg_sleep(1);
2018-06-15 10:49:11.512 UTC [8357]: [8-1] app=psql,db=postgres,client=[local],user=postgres LOG: duration: 0.223 ms statement: show log_timezone;
2018-06-15 10:49:26.084 UTC [8420]: [2-1] app=[unknown],db=postgres,client=[local],user=postgres LOG: connection authorized: user=postgres database=postgres
2018-06-15 10:49:26.088 UTC [8420]: [3-1] app=psql,db=postgres,client=[local],user=postgres LOG: duration: 1.449 ms statement: SELECT d.datname as "Name",
\t pg_catalog.pg_get_userbyid(d.datdba) as "Owner",
\t pg_catalog.pg_encoding_to_char(d.encoding) as "Encoding",
\t d.datcollate as "Collate",
\t d.datctype as "Ctype",
\t pg_catalog.array_to_string(d.datacl, E'\\n') AS "Access privileges"
\tFROM pg_catalog.pg_database d
\tORDER BY 1;
2018-06-15 10:49:26.088 UTC [8420]: [4-1] app=psql,db=postgres,client=[local],user=postgres LOG: disconnection: session time: 0:00:00.006 user=postgres database=postgres host=[local]
2018-06-15 10:49:31.140 UTC [8423]: [1-1] app=[unknown],db=[unknown],client=[local],user=[unknown] LOG: connection received: host=[local]
""".splitlines(
        True
    )  # noqa

    groups = list(group_lines(lines))
    assert 7 == len(groups)


def test_prefix_parser():
    from pgtoolkit.log.parser import PrefixParser

    # log_line_prefix with all options.
    prefix_fmt = "%m [%p]: [%l-1] app=%a,db=%d,client=%h,user=%u,remote=%r,epoch=%n,timestamp=%t,tag=%i,error=%e,session=%c,start=%s,vxid=%v,xid=%x "  # noqa
    prefix = "2018-06-15 14:15:52.332 UTC [10011]: [2-1] app=[unknown],db=postgres,client=[local],user=postgres,remote=[local],epoch=1529072152.332,timestamp=2018-06-15 14:15:52 UTC,tag=authentication,error=00000,session=5b23ca18.271b,start=2018-06-15 14:15:52 UTC,vxid=3/7,xid=0 "  # noqa

    # Ensure each pattern matches.
    for pat in PrefixParser._status_pat.values():
        assert re.search(pat, prefix)

    parser = PrefixParser.from_configuration(prefix_fmt)
    assert "%m" in repr(parser)
    fields = parser.parse(prefix)

    assert 2018 == fields["timestamp"].year
    assert "application" in fields
    assert "user" in fields
    assert "database" in fields
    assert "remote_host" in fields
    assert 10011 == fields["pid"]
    assert 2 == fields["line_num"]


def test_prefix_parser_q():
    from pgtoolkit.log.parser import PrefixParser

    # %q marks the part that is absent for non-session processes.
    prefix_fmt = "%m [%p]: %q%u@%h "

    parser = PrefixParser.from_configuration(prefix_fmt)
    fields = parser.parse("2018-06-15 14:15:52.332 UTC [10011]: ")
    assert fields["user"] is None


def test_isodatetime():
    from pgtoolkit.log.parser import parse_isodatetime

    date = parse_isodatetime("2018-06-04 20:12:34.343 UTC")
    assert date
    assert 2018 == date.year
    assert 6 == date.month
    assert 4 == date.day
    assert 20 == date.hour
    assert 12 == date.minute
    assert 34 == date.second
    assert 343 == date.microsecond

    with pytest.raises(ValueError):
        parse_isodatetime("2018-06-000004")

    # Only UTC is accepted.
    with pytest.raises(ValueError):
        parse_isodatetime("2018-06-04 20:12:34.343 CEST")


def test_record_stage1_ok():
    from pgtoolkit.log import Record

    lines = """\
2018-06-15 10:49:26.088 UTC [8420]: [3-1] app=psql,db=postgres,client=[local],user=postgres LOG: duration: 1.449 ms statement: SELECT d.datname as "Name",
\t pg_catalog.array_to_string(d.datacl, E'\\n') AS "Access privileges"
\tFROM pg_catalog.pg_database d
\tORDER BY 1;
""".splitlines(
        True
    )  # noqa

    record = Record.parse_stage1(lines)
    assert "LOG" in repr(record)
    assert "\n" not in repr(record)
    assert 4 == len(record.raw_lines)
    assert "LOG" == record.severity
    assert 4 == len(record.message_lines)
    assert record.message_lines[0].startswith("duration: ")


def test_record_stage1_nok():
    from pgtoolkit.log import Record, UnknownData

    lines = ["pouet\n", "toto\n"]
    with pytest.raises(UnknownData) as ei:
        Record.parse_stage1(lines)
    assert "pouet\ntoto\n" in str(ei.value)


def test_record_stage2_ok(mocker):
    from pgtoolkit.log import Record

    record = Record(
        prefix="2018-06-15 10:49:26.088 UTC [8420]: [3-1] app=psql,db=postgres,client=[local],user=postgres ",  # noqa
        severity="LOG",
        message_lines=["message"],
        raw_lines=[],
    )

    prefix_parser = mocker.Mock(name="prefix_parser")
    prefix_parser.return_value = dict(remote_host="[local]")
    record.parse_stage2(prefix_parser)
    assert "[local]" == record.remote_host


def test_filters():
    from pgtoolkit.log import NoopFilters, parse

    lines = """\
stage1 LOG: duration: 1002.209 ms statement: select pg_sleep(1);
stage2 LOG: duration: 0.223 ms statement: show log_timezone;
stage3 LOG: connection authorized: user=postgres database=postgres
""".splitlines(
        True
    )  # noqa

    class MyFilters(NoopFilters):
        def stage1(self, record):
            return record.prefix.startswith("stage1")

        def stage2(self, record):
            return record.prefix.startswith("stage2")

        def stage3(self, record):
            return record.prefix.startswith("stage3")

    # Each record is filtered out at one of the three stages.
    log_line_prefix = "stage%p "
    filters = MyFilters()
    records = list(parse(lines, prefix_fmt=log_line_prefix, filters=filters))
    assert 0 == len(records)


def test_main(mocker, caplog, capsys):
    from datetime import datetime, timezone

    from pgtoolkit.log import Record, UnknownData, __main__
    from pgtoolkit.log.__main__ import main

    mocker.patch.object(__main__, "basicConfig", autospec=True)
    open_ = mocker.patch.object(__main__, "open_or_stdin", autospec=True)
    open_.return_value = mocker.MagicMock()
    parse = mocker.patch.object(__main__, "parse", autospec=True)
    parse.return_value = [
        Record("prefix", "LOG", timestamp=datetime.now(timezone.utc)),
        UnknownData(["unknown line\n"]),
    ]
    log_line_prefix = "%m [%p]: "
    main(argv=[log_line_prefix], environ=dict())
    out, err = capsys.readouterr()
    assert "LOG" in out

    # caplog.records is a plain list on all supported pytest versions; the
    # former pytest-capturelog (py2.6) callable-compat branch was dead code.
    records = caplog.records

    for record in records:
        if "unknown line" in record.message:
            break
    else:
        pytest.fail("Bad line not logged")


def test_main_ko(mocker):
    from pgtoolkit.log import Record, __main__
    from pgtoolkit.log.__main__ import main

    mocker.patch.object(__main__, "basicConfig", autospec=True)
    open_ = mocker.patch.object(__main__, "open_or_stdin", autospec=True)
    open_.return_value = mocker.MagicMock()
    parse = mocker.patch.object(__main__, "parse", autospec=True)
    parse.return_value = [Record("prefix", "LOG", badentry=object())]
    assert 1 == main(argv=["%m"], environ=dict())


# ---- tests/test_pass.py ----

from pathlib import Path

import pytest


def test_escaped_split():
    from pgtoolkit.pgpass import escapedsplit

    assert ["a", "b"] == list(escapedsplit("a:b", ":"))
    assert ["a", ""] == list(escapedsplit("a:", ":"))
    assert ["a:"] == list(escapedsplit(r"a\:", ":"))
    assert ["a\\", ""] == list(escapedsplit(r"a\\:", ":"))
    # NOTE(review): this test continues beyond this chunk of the dump; the
    # remainder is not visible here.
pytest.raises(ValueError): 15 | list(escapedsplit(r"", "long-delim")) 16 | 17 | 18 | def test_passfile_create(): 19 | from pgtoolkit.pgpass import PassComment, PassEntry, PassFile 20 | 21 | pgpass = PassFile([PassComment("# Comment"), PassEntry.parse("foo:*:bar:baz:dude")]) 22 | assert 2 == len(pgpass.lines) 23 | 24 | with pytest.raises(ValueError): 25 | PassFile("blah") 26 | 27 | 28 | def test_entry(): 29 | from pgtoolkit.pgpass import PassEntry 30 | 31 | a = PassEntry.parse(r"/var/run/postgresql:5432:db:postgres:conf\:dentie\\") 32 | assert "/var/run/postgresql" == a.hostname 33 | assert 5432 == a.port 34 | assert "db" == a.database 35 | assert "postgres" == a.username 36 | assert "conf:dentie\\" == a.password 37 | 38 | assert "dentie\\" not in repr(a) 39 | assert r"conf\:dentie\\" in str(a) 40 | 41 | b = PassEntry( 42 | hostname="/var/run/postgresql", 43 | port=5432, 44 | database="db", 45 | username="postgres", 46 | password="newpassword", 47 | ) 48 | 49 | entries = {a} 50 | 51 | assert b in entries 52 | 53 | 54 | def test_compare(): 55 | from pgtoolkit.pgpass import PassComment, PassEntry 56 | 57 | a = PassEntry.parse(":*:*:*:confidential") 58 | b = PassEntry.parse("hostname:*:*:*:otherpassword") 59 | c = PassEntry.parse("hostname:5442:*:username:otherpassword") 60 | d = PassEntry("hostname", "5442", "*", "username", "otherpassword") 61 | e = PassEntry("hostname", "5443", "*", "username", "otherpassword") 62 | 63 | assert a < b 64 | assert c < b 65 | assert a != b 66 | assert c == d 67 | assert c < e 68 | 69 | assert [c, a, b] == sorted([a, b, c]) 70 | 71 | d = PassComment("# Comment") 72 | e = PassComment("# hostname:5432:*:*:password") 73 | 74 | assert "Comment" in repr(d) 75 | 76 | # Preserve comment order. 
77 | assert not d < e 78 | assert not e < d 79 | assert not d < a 80 | assert not a < d 81 | assert a != d 82 | 83 | assert e < a 84 | assert c < e 85 | 86 | with pytest.raises(TypeError): 87 | a < 42 88 | with pytest.raises(TypeError): 89 | "meh" > a 90 | assert (a == [1, 2]) is False 91 | 92 | 93 | def test_parse_lines(tmp_path): 94 | from pgtoolkit.pgpass import ParseError, parse 95 | 96 | lines = [ 97 | "# Comment for h2", 98 | "h2:*:*:postgres:confidential", 99 | "# h1:*:*:postgres:confidential", 100 | "h2:5432:*:postgres:confidential", 101 | ] 102 | 103 | pgpass = parse(lines) 104 | with pytest.raises(ParseError): 105 | pgpass.parse(["bad:line"]) 106 | 107 | pgpass.sort() 108 | 109 | # Ensure more precise line first. 110 | assert "h2:5432:" in str(pgpass.lines[0]) 111 | # Ensure h1 line is before h2 line, even commented. 112 | assert "# h1:" in str(pgpass.lines[1]) 113 | # Ensure comment is kept before h2:* line. 114 | assert "Comment" in str(pgpass.lines[2]) 115 | assert "h2:*" in str(pgpass.lines[3]) 116 | 117 | assert 2 == len(list(pgpass)) 118 | 119 | passfile = tmp_path / "fo" 120 | with passfile.open("w") as fo: 121 | pgpass.save(fo) 122 | assert passfile.read_text().splitlines() == [ 123 | "h2:5432:*:postgres:confidential", 124 | "# h1:*:*:postgres:confidential", 125 | "# Comment for h2", 126 | "h2:*:*:postgres:confidential", 127 | ] 128 | 129 | header = "#hostname:port:database:username:password" 130 | pgpass = parse([header]) 131 | pgpass.sort() 132 | assert pgpass.lines == [header] 133 | 134 | 135 | @pytest.mark.parametrize("pathtype", [str, Path]) 136 | def test_parse_file(pathtype, tmp_path): 137 | from pgtoolkit.pgpass import PassComment, parse 138 | 139 | fpath = tmp_path / "pgpass" 140 | fpath.touch() 141 | pgpass = parse(pathtype(fpath)) 142 | pgpass.lines.append(PassComment("# Something")) 143 | pgpass.save() 144 | assert fpath.read_text() == "# Something\n" 145 | 146 | 147 | def test_edit(tmp_path): 148 | from pgtoolkit.pgpass import 
PassComment, PassEntry, edit 149 | 150 | fpath = tmp_path / "pgpass" 151 | assert not fpath.exists() 152 | 153 | # Check we don't create an empty file. 154 | with edit(fpath): 155 | pass 156 | assert not fpath.exists() 157 | 158 | with edit(fpath) as passfile: 159 | passfile.lines.append(PassComment("# commented")) 160 | assert fpath.read_text() == "# commented\n" 161 | 162 | with edit(fpath) as passfile: 163 | del passfile.lines[:] 164 | assert fpath.read_text() == "" 165 | 166 | with edit(fpath) as passfile: 167 | passfile.lines.append(PassComment("# commented")) 168 | assert fpath.read_text() == "# commented\n" 169 | 170 | with edit(fpath) as passfile: 171 | passfile.lines.extend( 172 | [ 173 | PassEntry("*", "5443", "*", "username", "otherpassword"), 174 | PassEntry("hostname", "5443", "*", "username", "password"), 175 | ] 176 | ) 177 | passfile.sort() 178 | assert fpath.read_text().splitlines() == [ 179 | "hostname:5443:*:username:password", 180 | "# commented", 181 | "*:5443:*:username:otherpassword", 182 | ] 183 | 184 | 185 | def test_save_nofile(): 186 | from pgtoolkit.pgpass import PassComment, PassFile 187 | 188 | pgpass = PassFile() 189 | pgpass.lines.append(PassComment("# Something")) 190 | with pytest.raises(ValueError): 191 | pgpass.save() 192 | 193 | 194 | def test_matches(): 195 | from pgtoolkit.pgpass import PassComment, PassEntry 196 | 197 | a = PassEntry( 198 | hostname="/var/run/postgresql", 199 | port=5432, 200 | database="db", 201 | username="postgres", 202 | password="newpassword", 203 | ) 204 | assert a.matches(port=5432, database="db") 205 | with pytest.raises(AttributeError): 206 | assert a.matches(dbname="newpassword") 207 | assert not a.matches(port=5433) 208 | 209 | b = PassComment("# some non-entry comment") 210 | assert not b.matches(port=5432) 211 | 212 | c = PassComment("# hostname:5432:*:*:password") 213 | assert c.matches(port=5432) 214 | 215 | 216 | def test_remove(): 217 | from pgtoolkit.pgpass import parse 218 | 219 | lines = [ 
220 | "# Comment for h2", 221 | "h2:*:*:postgres:confidential", 222 | "# h1:*:*:postgres:confidential", 223 | "h2:5432:*:postgres:confidential", 224 | "h2:5432:*:david:Som3Password", 225 | "h2:5433:*:postgres:confidential", 226 | ] 227 | 228 | pgpass = parse(lines) 229 | 230 | with pytest.raises(ValueError): 231 | pgpass.remove() 232 | 233 | pgpass.remove(port=5432) 234 | assert 4 == len(pgpass.lines) 235 | 236 | # All matching entries are removed even commented ones 237 | pgpass = parse(lines) 238 | pgpass.remove(username="postgres") 239 | assert 2 == len(pgpass.lines) 240 | 241 | pgpass = parse(lines) 242 | pgpass.remove(port=5432, username="postgres") 243 | assert 5 == len(pgpass.lines) 244 | 245 | def filter(line): 246 | return line.username == "postgres" 247 | 248 | pgpass = parse(lines) 249 | pgpass.remove(filter=filter) 250 | assert 2 == len(pgpass.lines) 251 | 252 | # Only filter is taken into account 253 | pgpass = parse(lines) 254 | with pytest.warns(UserWarning): 255 | pgpass.remove(filter=filter, port=5432) 256 | assert 2 == len(pgpass.lines) 257 | 258 | # Error if attribute name is not valid 259 | pgpass = parse(lines) 260 | with pytest.raises(AttributeError): 261 | pgpass.remove(userna="postgres") 262 | -------------------------------------------------------------------------------- /tests/test_service.py: -------------------------------------------------------------------------------- 1 | from io import StringIO 2 | from textwrap import dedent 3 | 4 | import pytest 5 | 6 | 7 | def test_parse(): 8 | from pgtoolkit.service import parse 9 | 10 | lines = dedent( 11 | """\ 12 | [service0] 13 | host=myhost 14 | port=5432 15 | user=toto 16 | 17 | [service1] 18 | host=myhost1 19 | port=5432 20 | dbname=myapp 21 | """ 22 | ).splitlines() 23 | 24 | services = parse(lines, source="in-memory") 25 | assert 2 == len(services) 26 | assert "pgtoolkit" not in repr(services) 27 | 28 | service0 = services["service0"] 29 | assert "service0" == service0.name 30 | assert 
"myhost" == service0.host 31 | assert "service0" in repr(service0) 32 | 33 | 34 | def test_parse_file(mocker): 35 | from pgtoolkit.service import parse 36 | 37 | m = mocker.mock_open() 38 | try: 39 | mocker.patch("builtins.open", m) 40 | except Exception: 41 | mocker.patch("__builtin__.open", m) 42 | services = parse("filename") 43 | 44 | assert m.called 45 | services.save() 46 | 47 | m = mocker.Mock() 48 | try: 49 | mocker.patch("configparser.ConfigParser.write", new_callable=m) 50 | except Exception: 51 | mocker.patch("ConfigParser.ConfigParser.write", new_callable=m) 52 | assert m.called 53 | 54 | 55 | def test_render(): 56 | from pgtoolkit.service import Service, ServiceFile 57 | 58 | services = ServiceFile() 59 | service0 = Service(name="service0", dbname="mydb") 60 | services.add(service0) 61 | 62 | # Moving options and updating service. 63 | service0.pop("dbname") 64 | service0.port = 5432 65 | services.add(service0) 66 | 67 | service0 = services["service0"] 68 | assert 5432 == service0.port 69 | assert "dbname" not in service0 70 | 71 | services.add(Service(name="service1", host="myhost")) 72 | 73 | fo = StringIO() 74 | services.save(fo) 75 | raw = fo.getvalue() 76 | 77 | # Ensure no space around = 78 | assert raw == "\n".join( 79 | [ 80 | "[service0]", 81 | "port=5432", 82 | "", 83 | "[service1]", 84 | "host=myhost", 85 | "", 86 | "", 87 | ] 88 | ) 89 | 90 | 91 | def test_sysconfdir(mocker): 92 | isdir = mocker.patch("pgtoolkit.service.os.path.isdir", autospec=True) 93 | 94 | from pgtoolkit.service import guess_sysconfdir 95 | 96 | isdir.return_value = False 97 | with pytest.raises(Exception): 98 | guess_sysconfdir(environ=dict(PGSYSCONFDIR="/toto")) 99 | 100 | isdir.return_value = True 101 | sysconfdir = guess_sysconfdir(environ=dict(PGSYSCONFDIR="/toto")) 102 | assert "/toto" == sysconfdir 103 | 104 | isdir.return_value = False 105 | with pytest.raises(Exception): 106 | guess_sysconfdir(environ=dict()) 107 | 108 | isdir.return_value = True 109 | 
sysconfdir = guess_sysconfdir(environ=dict()) 110 | assert sysconfdir.startswith("/etc") 111 | 112 | 113 | def test_find(mocker): 114 | g_scd = mocker.patch("pgtoolkit.service.guess_sysconfdir", autospec=True) 115 | exists = mocker.patch("pgtoolkit.service.os.path.exists", autospec=True) 116 | 117 | from pgtoolkit.service import find 118 | 119 | exists.return_value = False 120 | with pytest.raises(Exception): 121 | find(environ=dict(PGSERVICEFILE="my-services.conf")) 122 | 123 | g_scd.return_value = "/etc/postgresql-common" 124 | with pytest.raises(Exception): 125 | find(environ=dict()) 126 | 127 | exists.return_value = True 128 | servicefile = find(environ=dict(PGSERVICEFILE="toto.conf")) 129 | assert "toto.conf" == servicefile 130 | 131 | exists.side_effect = [False, True] 132 | servicefile = find(environ=dict()) 133 | assert servicefile.endswith("/pg_service.conf") 134 | exists.side_effect = None 135 | 136 | g_scd.side_effect = Exception("Pouet") 137 | servicefile = find(environ=dict()) 138 | assert servicefile.endswith("/.pg_service.conf") 139 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | minversion = 4.20.0 3 | envlist = lint,tests,typing 4 | isolated_build = true 5 | 6 | [testenv:lint] 7 | commands= 8 | flake8 9 | black --check --diff . 10 | isort --check --diff . 
11 | pre-commit run --all-files --show-diff-on-failure pyupgrade 12 | check-manifest 13 | deps = 14 | flake8 15 | black 16 | check-manifest 17 | isort 18 | pre-commit 19 | pyupgrade 20 | skip_install = true 21 | 22 | [testenv:tests{,-ci}] 23 | allowlist_externals = 24 | ./tests/datatests.sh 25 | commands = 26 | pytest -ra --cov --cov-report=term-missing --cov-report=xml {posargs} 27 | ci: ./tests/datatests.sh 28 | deps = 29 | ci: codecov 30 | extras = 31 | test 32 | 33 | [testenv:typing] 34 | commands= 35 | mypy --strict pgtoolkit 36 | deps = 37 | mypy 38 | skip_install = true 39 | --------------------------------------------------------------------------------