├── pgdumplib ├── py.typed ├── exceptions.py ├── __init__.py ├── models.py ├── constants.py ├── converters.py └── dump.py ├── MANIFEST.in ├── docs ├── constants.md ├── converters.md ├── exceptions.md ├── api.md ├── index.md └── examples.md ├── .gitignore ├── tests ├── __init__.py ├── test_exceptions.py ├── test_edge_cases.py ├── test_save_dump.py ├── test_converters.py └── test_dump.py ├── ci └── test ├── CONTRIBUTING.md ├── .editorconfig ├── compose.yaml ├── .pre-commit-config.yaml ├── .github └── workflows │ ├── deploy.yaml │ ├── docs.yaml │ └── testing.yaml ├── README.md ├── LICENSE ├── mkdocs.yml ├── CHANGELOG.md ├── fixtures └── schema.sql ├── pyproject.toml ├── bootstrap ├── CLAUDE.md └── bin └── generate-fixture-data.py /pgdumplib/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md LICENSE 2 | -------------------------------------------------------------------------------- /docs/constants.md: -------------------------------------------------------------------------------- 1 | # Constants 2 | 3 | ::: pgdumplib.constants 4 | options: 5 | show_root_heading: true 6 | show_source: true 7 | -------------------------------------------------------------------------------- /docs/converters.md: -------------------------------------------------------------------------------- 1 | # Converters 2 | 3 | ::: pgdumplib.converters 4 | options: 5 | show_root_heading: true 6 | show_source: true 7 | -------------------------------------------------------------------------------- /docs/exceptions.md: -------------------------------------------------------------------------------- 1 | # Exceptions 2 | 3 | ::: pgdumplib.exceptions 4 | options: 5 | show_root_heading: true 6 | show_source: true 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg 2 | *.egg-info 3 | *.pyc 4 | *.pyo 5 | .coverage 6 | .mypy_cache 7 | .idea 8 | .vscode 9 | build/ 10 | dist/ 11 | env/ 12 | docs/_build 13 | site/ 14 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | 3 | 4 | def setup_module(): 5 | """Load environment variables from .env file""" 6 | # Load from .env file if it exists 7 | load_dotenv() 8 | -------------------------------------------------------------------------------- /ci/test: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | set -e 3 | ./bootstrap 4 | 5 | 6 | echo ' 7 | LINTING 8 | ' 9 | pre-commit run --all-files 10 | 11 | echo ' 12 | RUNNING TESTS 13 | ' 14 | coverage run -m unittest discover -f tests 15 | coverage report 16 | coverage xml 17 | -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | # API 2 | 3 | ## pgdumplib 4 | 5 | ::: pgdumplib 6 | options: 7 | show_root_heading: true 8 | show_source: true 9 | 10 | ## pgdumplib.dump 11 | 12 | ::: pgdumplib.dump 13 | options: 14 | show_root_heading: true 15 | show_source: true 16 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 | 
3 | To get set up in the environment and run the tests, take the following steps:
4 | 
5 | ```bash
6 | python3 -m venv env
7 | source env/bin/activate
8 | python3 -m pip install -e '.[dev]'
9 | python3 -m pip install pre-commit
10 | ```
11 | 
12 | ## Test Coverage
13 | 
14 | Pull requests that make changes or additions that are not covered by tests
15 | will likely be closed without review.
16 | 
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | # top-most EditorConfig file
2 | root = true
3 | 
4 | # Unix-style newlines with a newline ending every file
5 | [*]
6 | end_of_line = lf
7 | insert_final_newline = true
8 | 
9 | # 4 space indentation
10 | [*.py]
11 | indent_style = space
12 | indent_size = 4
13 | 
14 | # 2 space indentation
15 | [*.yml]
16 | indent_style = space
17 | indent_size = 2
18 | 
19 | [bootstrap]
20 | indent_style = space
21 | indent_size = 2
22 | 
--------------------------------------------------------------------------------
/compose.yaml:
--------------------------------------------------------------------------------
1 | services:
2 |   postgres:
3 |     image: postgres:18
4 |     healthcheck:
5 |       test: /usr/bin/pg_isready
6 |       interval: 30s
7 |       timeout: 10s
8 |       start_period: 5s
9 |       start_interval: 2s
10 |       retries: 3
11 |     environment:
12 |       POSTGRES_HOST_AUTH_METHOD: trust
13 |     ports:
14 |       - "5432"
15 |     volumes:
16 |       - type: bind
17 |         source: ./build/data
18 |         target: /data
19 |       - type: bind
20 |         source: ./fixtures
21 |         target: /fixtures
22 | 
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: https://github.com/pre-commit/pre-commit-hooks
3 |     rev: v6.0.0
4 |     hooks:
5 |       - id: check-ast
6 |       - id: check-toml
7 |       - id: check-yaml
8 |         args: [--allow-multiple-documents]
9 |       - id: debug-statements
10 |       - id: end-of-file-fixer
11 |       - id: mixed-line-ending
12 |       - id: trailing-whitespace
13 | 
14 |   - repo: https://github.com/astral-sh/ruff-pre-commit
15 |     rev: v0.14.5
16 |     hooks:
17 |       - id: ruff-format
18 |       - id: ruff
19 |         args: [--fix]
20 | 
--------------------------------------------------------------------------------
/tests/test_exceptions.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | from pgdumplib import exceptions
4 | 
5 | 
6 | class ExceptionTestCase(unittest.TestCase):
7 |     def test_repr_formatting(self):
8 |         exc = exceptions.EntityNotFoundError('public', 'table')
9 |         self.assertEqual(
10 |             repr(exc), "<EntityNotFoundError namespace='public' table='table'>"
11 |         )
12 | 
13 |     def test_str_formatting(self):
14 |         exc = exceptions.EntityNotFoundError('public', 'table')
15 |         self.assertEqual(
16 |             str(exc), 'Did not find public.table in the table of contents'
17 |         )
18 | 
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # pgdumplib
2 | 
3 | Python 3 library for reading and writing [pg_dump](https://www.postgresql.org/docs/current/app-pgdump.html) files using the custom format.
4 | 
5 | [![Version](https://img.shields.io/pypi/v/pgdumplib.svg)](https://pypi.python.org/pypi/pgdumplib)
6 | [![License](https://img.shields.io/pypi/l/pgdumplib.svg)](https://github.com/gmr/pgdumplib/blob/master/LICENSE)
7 | 
8 | ## Installation
9 | 
10 | ```bash
11 | pip install pgdumplib
12 | ```
13 | 
14 | ## Documentation
15 | 
16 | - [API Reference](api.md)
17 | - [Constants](constants.md)
18 | - [Converters](converters.md)
19 | - [Exceptions](exceptions.md)
20 | - [Examples](examples.md)
21 | 
--------------------------------------------------------------------------------
/.github/workflows/deploy.yaml:
--------------------------------------------------------------------------------
1 | name: Publish to PyPI
2 | 
3 | on:
4 |   release:
5 |     types: [created]
6 | 
7 | jobs:
8 |   deploy:
9 |     runs-on: ubuntu-latest
10 | 
11 |     environment:
12 |       name: pypi
13 |       url: https://pypi.org/p/pgdumplib
14 | 
15 |     permissions:
16 |       id-token: write
17 | 
18 |     steps:
19 |       - uses: actions/checkout@v5
20 | 
21 |       - name: Set up Python 3.14
22 |         uses: actions/setup-python@v6
23 |         with:
24 |           python-version: "3.14"
25 | 
26 |       - name: Install dependencies
27 |         run: |
28 |           python -m pip install --upgrade pip
29 |           pip install build wheel twine
30 | 
31 |       - name: Build and check package
32 |         run: |
33 |           python -m build
34 |           twine check dist/*
35 | 
36 |       - name: Publish package distributions to PyPI
37 |         uses: pypa/gh-action-pypi-publish@release/v1
38 | 
--------------------------------------------------------------------------------
/pgdumplib/exceptions.py:
--------------------------------------------------------------------------------
1 | """
2 | pgdumplib specific exceptions
3 | 
4 | """
5 | 
6 | 
7 | class PgDumpLibError(Exception):
8 |     """Common Base Exception"""
9 | 
10 | 
11 | class NoDataError(PgDumpLibError):
12 |     """Raised when attempting to work with data when no data entries exist"""
13 | 
14 | 
15 | class EntityNotFoundError(PgDumpLibError):
16 |     """Raised when an attempt is made to read data from a relation in a
17 |     dump file but it is not found in the table of contents.
18 | 
19 |     This can happen if a schema-only dump was created OR if the ``namespace``
20 |     and ``table`` specified were not found.
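
    A minimal sketch of handling the error (assuming ``dump`` was loaded
    with :py:func:`pgdumplib.load` and ``public.missing`` is not in the
    dump's table of contents)::

        try:
            rows = list(dump.table_data('public', 'missing'))
        except EntityNotFoundError as error:
            print(error.namespace, error.table)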
21 | 
22 |     """
23 | 
24 |     def __init__(self, namespace: str, table: str):
25 |         super().__init__()
26 |         self.namespace = namespace
27 |         self.table = table
28 | 
29 |     def __repr__(self) -> str:  # pragma: nocover
30 |         return (
31 |             f'<EntityNotFoundError namespace={self.namespace!r} '
32 |             f'table={self.table!r}>'
33 |         )
34 | 
35 |     def __str__(self) -> str:  # pragma: nocover
36 |         return (
37 |             f'Did not find {self.namespace}.{self.table} in the table '
38 |             f'of contents'
39 |         )
40 | 
--------------------------------------------------------------------------------
/.github/workflows/docs.yaml:
--------------------------------------------------------------------------------
1 | name: Deploy Documentation
2 | 
3 | on:
4 |   push:
5 |     branches:
6 |       - main
7 |     paths:
8 |       - 'docs/**'
9 |       - 'mkdocs.yml'
10 |       - 'pgdumplib/**'
11 |       - '.github/workflows/docs.yaml'
12 |   workflow_dispatch:
13 | 
14 | permissions:
15 |   contents: read
16 |   pages: write
17 |   id-token: write
18 | 
19 | concurrency:
20 |   group: "pages"
21 |   cancel-in-progress: false
22 | 
23 | jobs:
24 |   build:
25 |     runs-on: ubuntu-latest
26 |     steps:
27 |       - name: Checkout repository
28 |         uses: actions/checkout@v5
29 | 
30 |       - name: Set up Python
31 |         uses: actions/setup-python@v6
32 |         with:
33 |           python-version: '3.12'
34 | 
35 |       - name: Install dependencies
36 |         run: |
37 |           python -m pip install --upgrade pip
38 |           pip install -e '.[docs]'
39 | 
40 |       - name: Build documentation
41 |         run: mkdocs build --strict
42 | 
43 |       - name: Upload artifact
44 |         uses: actions/upload-pages-artifact@v3
45 |         with:
46 |           path: ./site
47 | 
48 |   deploy:
49 |     environment:
50 |       name: github-pages
51 |       url: ${{ steps.deployment.outputs.page_url }}
52 |     runs-on: ubuntu-latest
53 |     needs: build
54 |     steps:
55 |       - name: Deploy to GitHub Pages
56 |         id: deployment
57 |         uses: actions/deploy-pages@v4
58 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pgdumplib
2 | 
3 | Python 3 library for reading and writing pg_dump files using the custom format.
4 | 
5 | [![Version](https://img.shields.io/pypi/v/pgdumplib.svg)](https://pypi.python.org/pypi/pgdumplib)
6 | [![Status](https://github.com/gmr/pgdumplib/workflows/Testing/badge.svg)](https://github.com/gmr/pgdumplib/actions)
7 | [![Coverage](https://codecov.io/gh/gmr/pgdumplib/branch/master/graph/badge.svg)](https://codecov.io/github/gmr/pgdumplib?branch=master)
8 | [![License](https://img.shields.io/pypi/l/pgdumplib.svg)](https://github.com/gmr/pgdumplib/blob/master/LICENSE)
9 | [![Docs](https://img.shields.io/badge/docs-github%20pages-blue)](https://gmr.github.io/pgdumplib/)
10 | 
11 | ## Installation
12 | 
13 | ```bash
14 | pip install pgdumplib
15 | ```
16 | 
17 | ## Example Usage
18 | 
19 | The following example shows how to create a dump and then read it in, and
20 | iterate through the data of one of the tables.
21 | 22 | ```bash 23 | pg_dump -d pgbench -Fc -f pgbench.dump 24 | ``` 25 | 26 | ```python 27 | import pgdumplib 28 | 29 | dump = pgdumplib.load('pgbench.dump') 30 | 31 | print('Database: {}'.format(dump.toc.dbname)) 32 | print('Archive Timestamp: {}'.format(dump.toc.timestamp)) 33 | print('Server Version: {}'.format(dump.toc.server_version)) 34 | print('Dump Version: {}'.format(dump.toc.dump_version)) 35 | 36 | for line in dump.table_data('public', 'pgbench_accounts'): 37 | print(line) 38 | ``` 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018-2025 Gavin M. Roy, AWeber 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, 5 | are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | * Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | * Neither the name of the copyright holder nor the names of its contributors may 13 | be used to endorse or promote products derived from this software without 14 | specific prior written permission. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 17 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 18 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 19 | IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 20 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 21 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 22 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE 24 | OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 25 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: pgdumplib 2 | site_description: Python3 library for reading and writing pg_dump files using the custom format 3 | site_url: https://gmr.github.io/pgdumplib/ 4 | repo_url: https://github.com/gmr/pgdumplib 5 | repo_name: gmr/pgdumplib 6 | 7 | exclude_docs: | 8 | *.rst 9 | conf.py 10 | 11 | theme: 12 | name: material 13 | palette: 14 | - scheme: default 15 | primary: indigo 16 | accent: indigo 17 | toggle: 18 | icon: material/brightness-7 19 | name: Switch to dark mode 20 | - scheme: slate 21 | primary: indigo 22 | accent: indigo 23 | toggle: 24 | icon: material/brightness-4 25 | name: Switch to light mode 26 | features: 27 | - navigation.instant 28 | - navigation.tracking 29 | - navigation.tabs 30 | - navigation.sections 31 | - navigation.expand 32 | - navigation.top 33 | - search.suggest 34 | - search.highlight 35 | - content.code.copy 36 | 37 | plugins: 38 | - search 39 | - mkdocstrings: 40 | handlers: 41 | python: 42 | paths: [.] 
43 | options: 44 | docstring_style: google 45 | show_source: true 46 | show_root_heading: true 47 | show_root_toc_entry: true 48 | heading_level: 2 49 | members_order: source 50 | 51 | markdown_extensions: 52 | - admonition 53 | - pymdownx.details 54 | - pymdownx.highlight: 55 | anchor_linenums: true 56 | - pymdownx.inlinehilite 57 | - pymdownx.snippets 58 | - pymdownx.superfences 59 | - tables 60 | - toc: 61 | permalink: true 62 | 63 | nav: 64 | - Home: index.md 65 | - Examples: examples.md 66 | - API Reference: 67 | - API: api.md 68 | - Constants: constants.md 69 | - Converters: converters.md 70 | - Exceptions: exceptions.md 71 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## 4.0.0 - 2025-11-17 4 | 5 | - Updates to support Postgres versions up to 18 6 | - Renamed `pgdumplib.exceptions.PgDumpLibException` to `pgdumplib.exceptions.PgDumpLibError` 7 | - Updated Python support to 3.11+ 8 | - Modernized project structure 9 | 10 | ## 3.1.0 - 2020-02-06 11 | 12 | - Add configurable version support and appropriately set dump version based upon Postgres version 13 | - When sorting while saving, prefer GROUP, ROLE, and USER before all other objects 14 | 15 | ## 3.0.0 - 2019-12-05 16 | 17 | - Add support for Postgres 12 / pg_dump archive 1.14.0 format 18 | 19 | ## 2.1.0 - 2019-11-12 20 | 21 | - Add TABLESPACE to supported objects when adding an entry 22 | 23 | ## 2.0.1 - 2019-11-05 24 | 25 | - Minor typing and docstring updates 26 | 27 | 28 | ## 2.0.0 - 2019-10-15 29 | 30 | - Add mapping of `desc`/object type to dump section in `pgdumplib.constants` 31 | - Change `pgdumplib.dump.Entry` to use the section mapping instead of as an assignable attribute 32 | - `pgdumplib.dump.Dump.add_entry` function signature change, dropping `section` and moving `desc` to the first arg 33 | - `pgdumplib.dump.Dump.add_entry` validates a provided `dump_id` is above `0` and is not already used 34 | - Change `pgdumplib.dump.Dump.save` behavior to provide a bit of forced ordering and toplogical sorting 35 | - No longer does two-pass writing on save if the dump has no data 36 | 37 | ## 1.0.1 - 2019-06-17 38 | 39 | - Cleanup type annotations 40 | - Distribute as a sdist instead of bdist 41 | 42 | ## 1.0.0 - 2019-06-14 43 | 44 | - Full read/write support of custom formatted pg_dump files 45 | - Refactored API 46 | 47 | ## 0.2.1 - 2019-01-15 48 | 49 | - Documentation and docstring updates 50 | 51 | ## 0.2.0 - 2019-01-15 52 | 53 | - Make the dump ID value part of the entry 54 | 55 | ## 0.1.1 - 2019-01-15 56 | 57 | - Distribution Fixes 58 | 59 | ## 0.1.0 - 2018-11-06 60 | 61 | - Initial alpha release 62 | -------------------------------------------------------------------------------- /pgdumplib/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | pgdumplib exposes a load method to create a :py:class:`~pgdumplib.dump.Dump` 3 | instance from a :command:`pg_dump` file created in the `custom` format. 4 | 5 | See the :doc:`examples` page to see how to read a dump or create one. 
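
A minimal sketch of reading a dump (assuming ``pgbench.dump`` was created
with ``pg_dump -Fc`` and contains a ``public.pgbench_accounts`` relation)::

    import pgdumplib

    dump = pgdumplib.load('pgbench.dump')
    for row in dump.table_data('public', 'pgbench_accounts'):
        print(row)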
6 | 
7 | """
8 | 
9 | import pathlib
10 | import typing
11 | from importlib import metadata
12 | 
13 | if typing.TYPE_CHECKING:
14 |     from pgdumplib import converters as converters
15 |     from pgdumplib import dump
16 | 
17 | version = metadata.version('pgdumplib')
18 | 
19 | 
20 | def load(
21 |     filepath: str | pathlib.Path,
22 |     converter: typing.Any = None,
23 | ) -> 'dump.Dump':
24 |     """Load a pg_dump file created with -Fc from disk
25 | 
26 |     :param filepath: The path to the dump to load
27 |     :type filepath: str or pathlib.Path
28 |     :param converter: The data converter class to use
29 |         (Default: :py:class:`pgdumplib.converters.DataConverter`)
30 |     :type converter: Converter class or None
31 |     :raises: :py:exc:`ValueError`
32 |     :rtype: pgdumplib.dump.Dump
33 | 
34 |     """
35 |     from pgdumplib import dump
36 | 
37 |     return dump.Dump(converter=converter).load(filepath)
38 | 
39 | 
40 | def new(
41 |     dbname: str = 'pgdumplib',
42 |     encoding: str = 'UTF8',
43 |     converter: typing.Any = None,
44 |     appear_as: str = '18.0',
45 | ) -> 'dump.Dump':
46 |     """Create a new :py:class:`pgdumplib.dump.Dump` instance
47 | 
48 |     :param dbname: The database name for the dump (Default: ``pgdumplib``)
49 |     :param encoding: The data encoding (Default: ``UTF8``)
50 |     :param converter: The data converter class to use
51 |         (Default: :py:class:`pgdumplib.converters.DataConverter`)
52 |     :type converter: Converter class or None
53 |     :param appear_as: The version of Postgres to emulate
54 |         (Default: ``18.0``)
55 |     :rtype: pgdumplib.dump.Dump
56 | 
57 |     """
58 |     from pgdumplib import dump
59 | 
60 |     return dump.Dump(dbname, encoding, converter, appear_as)
61 | 
--------------------------------------------------------------------------------
/pgdumplib/models.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | 
3 | from pgdumplib import constants
4 | 
5 | 
6 | @dataclasses.dataclass(eq=True)
7 | class Entry:
8 |     """The entry model represents a single entry in the dump
9 | 
10 |     Custom formatted dump files are primarily comprised of entries, which
11 |     contain all of the metadata and DDL required to construct the database.
12 | 
13 |     For table data and blobs, there are entries that contain offset locations
14 |     in the dump file that instruct the reader as to where the data lives
15 |     in the file.
16 | 
17 |     :var dump_id: The dump id, will be auto-calculated if left empty
18 |     :var had_dumper: Indicates if the entry had an associated data dumper
19 |     :var oid: The OID of the object the entry represents
20 |     :var tag: The name/table/relation/etc of the entry
21 |     :var desc: The entry description
22 |     :var defn: The DDL definition for the entry
23 |     :var drop_stmt: A drop statement used to drop the entry before recreation
24 |     :var copy_stmt: A copy statement used when there is a corresponding
25 |         data section.
26 |     :var namespace: The namespace of the entry
27 |     :var tablespace: The tablespace to use
28 |     :var tableam: The table access method
29 |     :var relkind: The relation kind (added in v1.16)
30 |     :var owner: The owner of the object in Postgres
31 |     :var with_oids: Indicates ...
32 |     :var dependencies: A list of dump_ids of objects that the entry
33 |         is dependent upon.
34 | :var data_state: Indicates if the entry has data and how it is stored 35 | :var offset: If the entry has data, the offset to the data in the file 36 | :var section: The section of the dump file the entry belongs to 37 | 38 | """ 39 | 40 | dump_id: int 41 | had_dumper: bool = False 42 | table_oid: str = '0' 43 | oid: str = '0' 44 | tag: str | None = None 45 | desc: str = 'Unknown' 46 | defn: str | None = None 47 | drop_stmt: str | None = None 48 | copy_stmt: str | None = None 49 | namespace: str | None = None 50 | tablespace: str | None = None 51 | tableam: str | None = None 52 | relkind: str | None = None 53 | owner: str | None = None 54 | with_oids: bool = False 55 | dependencies: list[int] = dataclasses.field(default_factory=list) 56 | data_state: int = constants.K_OFFSET_NO_DATA 57 | offset: int = 0 58 | 59 | @property 60 | def section(self) -> str: 61 | """Return the section the entry belongs to""" 62 | return constants.SECTION_MAPPING[self.desc] 63 | -------------------------------------------------------------------------------- /fixtures/schema.sql: -------------------------------------------------------------------------------- 1 | -- Fixture Schema For Testing 2 | 3 | CREATE EXTENSION citext; 4 | CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; 5 | 6 | CREATE SCHEMA test; 7 | GRANT USAGE ON SCHEMA test TO PUBLIC; 8 | 9 | SET search_path = test, public, pg_catalog; 10 | 11 | CREATE TABLE empty_table( 12 | id UUID NOT NULL DEFAULT uuid_generate_v4() PRIMARY KEY, 13 | created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, 14 | last_modified_at TIMESTAMP WITH TIME ZONE, 15 | column_name TEXT 16 | ); 17 | 18 | CREATE DOMAIN test.email_address AS citext 19 | CHECK ( value ~ '^[a-zA-Z0-9.!#$%&''*+/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$' ); 20 | 21 | -- Simplified locale check, doesn't fully conform to BCP-47 22 | CREATE DOMAIN test.bcp47_locale AS TEXT 23 | CHECK ( value ~ '^[a-z]{2}-[A-Z]{2,3}$' ); 24 | 25 | CREATE TYPE user_state AS ENUM ('unverified', 'verified', 'suspended'); 26 | 27 | CREATE TABLE users ( 28 | id UUID NOT NULL DEFAULT uuid_generate_v4() PRIMARY KEY, 29 | created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, 30 | last_modified_at TIMESTAMP WITH TIME ZONE, 31 | state user_state NOT NULL DEFAULT 'unverified', 32 | email email_address NOT NULL, 33 | name TEXT NOT NULL, 34 | surname TEXT NOT NULL, 35 | display_name TEXT, 36 | locale bcp47_locale NOT NULL DEFAULT 'en-US', 37 | password_salt TEXT NOT NULL, 38 | password TEXT NOT NULL, 39 | signup_ip INET NOT NULL, 40 | icon OID 41 | ); 42 | 43 | CREATE UNIQUE INDEX users_unique_email ON users (email); 44 | 45 | CREATE TYPE address_type AS ENUM ('billing', 'delivery'); 46 | 47 | CREATE TABLE addresses ( 48 | id UUID NOT NULL DEFAULT uuid_generate_v4() PRIMARY KEY, 49 | created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, 50 | last_modified_at TIMESTAMP WITH TIME ZONE, 51 | user_id UUID NOT NULL REFERENCES users (id) ON DELETE CASCADE ON UPDATE CASCADE, 52 | type address_type NOT NULL, 53 | address1 TEXT NOT NULL, 54 | address2 TEXT, 55 | address3 TEXT, 56 | locality TEXT NOT NULL, 57 | region TEXT, 58 | postal_code TEXT NOT NULL, 59 | country TEXT NOT NULL 60 | ); 61 | -------------------------------------------------------------------------------- /.github/workflows/testing.yaml: -------------------------------------------------------------------------------- 1 | name: Testing 2 | on: 3 | 
pull_request: 4 | push: 5 | branches: ["*"] 6 | paths-ignore: 7 | - 'docs/**' 8 | - '*.md' 9 | - '*.rst' 10 | tags-ignore: ["*"] 11 | 12 | env: 13 | TEST_HOST: postgres 14 | 15 | jobs: 16 | test: 17 | runs-on: ubuntu-latest 18 | strategy: 19 | fail-fast: false 20 | matrix: 21 | python: ["3.11", "3.12", "3.13", "3.14"] 22 | postgres: [15, 16, 17, 18] 23 | services: 24 | postgres: 25 | image: postgres:${{ matrix.postgres }} 26 | env: 27 | POSTGRES_PASSWORD: postgres 28 | ports: 29 | - 5432:5432 30 | options: >- 31 | --health-cmd pg_isready 32 | --health-interval 10s 33 | --health-timeout 5s 34 | --health-retries 5 35 | --name postgres 36 | --hostname postgres 37 | 38 | steps: 39 | - name: Checkout repository 40 | uses: actions/checkout@v5 41 | 42 | - name: Set up Python 43 | uses: actions/setup-python@v6 44 | with: 45 | python-version: '${{ matrix.python }}' 46 | 47 | - name: Set Timezone 48 | uses: szenius/set-timezone@v2.0 49 | with: 50 | timezoneLinux: "America/New_York" 51 | 52 | - name: Install PostgreSQL client tools 53 | run: | 54 | # Add PostgreSQL APT repository for version-specific client tools 55 | sudo apt-get install -y postgresql-common 56 | sudo /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y 57 | sudo apt-get update 58 | sudo apt-get install -y postgresql-client-${{ matrix.postgres }} 59 | 60 | - name: Prepare container directories and copy fixtures 61 | run: | 62 | # Create directories inside the container 63 | docker exec postgres mkdir -p /data /fixtures 64 | docker exec postgres chmod 777 /data 65 | 66 | # Create local build directory 67 | mkdir -p build/data 68 | 69 | # Copy fixtures into container 70 | docker cp fixtures/. postgres:/fixtures/ 71 | 72 | - name: Bootstrap environment 73 | run: ./bootstrap 74 | env: 75 | CI: true 76 | 77 | - name: Copy dumps from container to workspace 78 | run: | 79 | # Copy generated dump files back to workspace for tests 80 | docker cp postgres:/data/. build/data/ 81 | 82 | - name: Lint Check 83 | run: pre-commit run --all-files 84 | 85 | - name: Run tests 86 | run: | 87 | coverage run 88 | coverage report 89 | coverage xml 90 | 91 | - name: Upload Coverage 92 | uses: codecov/codecov-action@v5 93 | with: 94 | files: ./build/coverage.xml 95 | flags: unittests 96 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "pgdumplib" 7 | version = "4.0.0" 8 | description = "Python3 library for working with pg_dump files" 9 | readme = "README.md" 10 | requires-python = ">=3.11" 11 | license = { text = "BSD 3-Clause License" } 12 | urls = { Homepage = "https://github.com/gmr/pgdumplib" } 13 | authors = [ 14 | { name = "Gavin M. 
Roy", email = "gavinmroy@gmail.com" } 15 | ] 16 | classifiers = [ 17 | "Development Status :: 5 - Production/Stable", 18 | "Intended Audience :: Developers", 19 | "License :: OSI Approved :: BSD License", 20 | "Natural Language :: English", 21 | "Operating System :: OS Independent", 22 | "Programming Language :: Python :: 3", 23 | "Programming Language :: Python :: 3.11", 24 | "Programming Language :: Python :: 3.12", 25 | "Programming Language :: Python :: 3.13", 26 | "Programming Language :: Python :: 3.14", 27 | "Programming Language :: Python :: Implementation :: CPython", 28 | "Programming Language :: SQL", 29 | "Topic :: Database", 30 | "Topic :: Database :: Database Engines/Servers", 31 | "Topic :: Software Development :: Libraries", 32 | "Topic :: Software Development :: Libraries :: Python Modules" 33 | ] 34 | dependencies = [ 35 | "pendulum", 36 | "toposort" 37 | ] 38 | 39 | [project.optional-dependencies] 40 | dev = [ 41 | "build", 42 | "coverage", 43 | "faker", 44 | "maya", 45 | "pre-commit", 46 | "psycopg", 47 | "python-dotenv", 48 | "ruff", 49 | ] 50 | docs = [ 51 | "mkdocs", 52 | "mkdocs-material", 53 | "mkdocstrings[python]", 54 | ] 55 | [tool.coverage.run] 56 | branch = true 57 | source = ["pgdumplib"] 58 | command_line = "-m unittest discover tests --buffer --verbose" 59 | 60 | [tool.coverage.report] 61 | exclude_also = [ 62 | "typing.TYPE_CHECKING", 63 | ] 64 | fail_under = 90 65 | show_missing = true 66 | 67 | [tool.coverage.html] 68 | directory = "build/coverage" 69 | 70 | [tool.coverage.xml] 71 | output = "build/reports/coverage.xml" 72 | 73 | [tool.hatch.build] 74 | artifacts = [ 75 | "pgdumplib" 76 | ] 77 | 78 | [tool.hatch.build.targets.wheel] 79 | include = [ 80 | "pgdumplib" 81 | ] 82 | 83 | [tool.hatch.env.default] 84 | python = "python3.12" 85 | features = ["dev"] 86 | 87 | [tool.ruff] 88 | line-length = 79 89 | target-version = "py312" 90 | 91 | [tool.ruff.format] 92 | quote-style = "single" 93 | 94 | [tool.ruff.lint] 95 | select = [ 96 | "BLE", # flake8-blind-except 97 | "C4", # flake8-comprehensions 98 | "C90", # mccabe 99 | "E", "W", # pycodestyle 100 | "F", # pyflakes 101 | "G", # flake8-logging-format 102 | "I", # isort 103 | "N", # pep8-naming 104 | "Q", # flake8-quotes 105 | "S", # flake8-bandit 106 | "ASYNC", # flake8-async 107 | "B", # flake8-bugbear 108 | "DTZ", # flake8-datetimez 109 | "FURB", # refurb 110 | "RUF", # ruff-specific 111 | "T20", # flake8-print 112 | "UP", # pyupgrade 113 | ] 114 | ignore = [ 115 | "RSE", # contradicts Python Style Guide 116 | "S311", # We're not doing any cryptography 117 | ] 118 | flake8-quotes = { inline-quotes = "single" } 119 | 120 | [tool.ruff.lint.mccabe] 121 | max-complexity = 15 122 | -------------------------------------------------------------------------------- /docs/examples.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | ## Reading 4 | 5 | First, create a dump of your database using `pg_dump` using the `custom` format: 6 | 7 | ```bash 8 | pg_dump -Fc -d [YOUR] dump.out 9 | ``` 10 | 11 | The following example shows how to get the data and the table definition from a dump: 12 | 13 | ```python 14 | import pgdumplib 15 | from pgdumplib import constants 16 | 17 | dump = pgdumplib.load('dump.out') 18 | for row in dump.table_data('public', 'table-name'): 19 | print(row) 20 | print(dump.lookup_entry(constants.TABLE, 'public', 'table-name').defn) 21 | ``` 22 | 23 | ## Writing 24 | 25 | To create a dump, you need to add sections. 
The following example shows how to create a dump with a schema, extension, comment, type, tables, and table data: 26 | 27 | ```python 28 | import datetime 29 | import uuid 30 | 31 | import pgdumplib 32 | from pgdumplib import constants 33 | 34 | dump = pgdumplib.new('example') 35 | 36 | schema = dump.add_entry( 37 | desc=constants.SCHEMA, 38 | tag='test', 39 | defn='CREATE SCHEMA test;', 40 | drop_stmt='DROP SCHEMA test;') 41 | 42 | dump.add_entry( 43 | desc=constants.ACL, 44 | tag='SCHEMA test', 45 | defn='GRANT USAGE ON SCHEMA test TO PUBLIC;', 46 | dependencies=[schema.dump_id]) 47 | 48 | uuid_ossp = dump.add_entry( 49 | desc=constants.EXTENSION, 50 | tag='uuid-ossp', 51 | defn='CREATE EXTENSION IF NOT EXISTS "uuid-ossp" WITH SCHEMA public;', 52 | drop_stmt='DROP EXTENSION "uuid-ossp";') 53 | 54 | dump.add_entry( 55 | desc=constants.COMMENT, 56 | tag='EXTENSION "uuid-ossp"', 57 | defn="""COMMENT ON EXTENSION "uuid-ossp" IS 'generate universally unique identifiers (UUIDs)'""", 58 | dependencies=[uuid_ossp.dump_id]) 59 | 60 | addr_type = dump.add_entry( 61 | desc=constants.TYPE, 62 | namespace='test', 63 | tag='address_type', 64 | owner='postgres', 65 | defn="""\ 66 | CREATE TYPE test.address_type AS ENUM ('billing', 'delivery');""", 67 | drop_stmt='DROP TYPE test.address_type;', 68 | dependencies=[schema.dump_id]) 69 | 70 | test_addresses = dump.add_entry( 71 | desc=constants.TABLE, 72 | namespace='test', 73 | tag='addresses', 74 | owner='postgres', 75 | defn="""\ 76 | CREATE TABLE addresses ( 77 | id UUID NOT NULL DEFAULT uuid_generate_v4() PRIMARY KEY, 78 | created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, 79 | last_modified_at TIMESTAMP WITH TIME ZONE, 80 | user_id UUID NOT NULL REFERENCES users (id) ON DELETE CASCADE ON UPDATE CASCADE, 81 | type address_type NOT NULL, 82 | address1 TEXT NOT NULL, 83 | address2 TEXT, 84 | address3 TEXT, 85 | locality TEXT NOT NULL, 86 | region TEXT, 87 | postal_code TEXT NOT NULL, 88 | country TEXT NOT NULL 89 | );""", 90 | drop_stmt='DROP TABLE test.addresses;', 91 | dependencies=[schema.dump_id, addr_type.dump_id, uuid_ossp.dump_id]) 92 | 93 | example = dump.add_entry( 94 | constants.TABLE, 95 | 'public', 'example', 'postgres', 96 | 'CREATE TABLE public.example (\ 97 | id UUID NOT NULL PRIMARY KEY,\ 98 | created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,\ 99 | value TEXT NOT NULL);', 100 | 'DROP TABLE public.example') 101 | 102 | with dump.table_data_writer(example, ['id', 'created_at', 'value']) as writer: 103 | writer.append(uuid.uuid4(), datetime.datetime.now(datetime.UTC), 'row1') 104 | writer.append(uuid.uuid4(), datetime.datetime.now(datetime.UTC), 'row2') 105 | writer.append(uuid.uuid4(), datetime.datetime.now(datetime.UTC), 'row3') 106 | writer.append(uuid.uuid4(), datetime.datetime.now(datetime.UTC), 'row4') 107 | writer.append(uuid.uuid4(), datetime.datetime.now(datetime.UTC), 'row5') 108 | 109 | dump.save('custom.dump') 110 | ``` 111 | -------------------------------------------------------------------------------- /bootstrap: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | # 3 | # NAME 4 | # bootstrap — bootstrap the project 5 | # 6 | # vim: set ts=4 sw=4 sts=4 noet: 7 | 8 | # NOTE THAT THIS FILE IS INDENTED WITH 4-SPACE HARD TABS 9 | # as is required for <<-EOF heredocs to work. Please keep 10 | # it this way! 
11 | 12 | # -- Start bootstrap setup -- 13 | set -e 14 | if test -n "$CI_DEBUG_TRACE" -o -n "$SHELLDEBUG" 15 | then 16 | set -x 17 | fi 18 | 19 | if test "${CI}" = 'true' 20 | then 21 | echo "::group::Build logs" 22 | fi 23 | 24 | unset PYTHONDEBUG 25 | export COMPOSE_DISABLE_ENV_FILE=1 26 | export PIP_DISABLE_PIP_VERSION_CHECK=1 27 | export PYTHON_ROOT_USER_ACTION=ignore 28 | export PYTHONWARNINGS=ignore 29 | 30 | TEST_HOST=${TEST_HOST:-127.0.0.1} 31 | 32 | docker_exec() { 33 | if test "${CI}" = 'true' 34 | then 35 | docker exec -t postgres "$@" 36 | else 37 | docker compose exec postgres "$@" 38 | fi 39 | } 40 | 41 | get_exposed_port() { 42 | if test -f compose.yaml 43 | then 44 | if port=$(docker compose port "$1" --protocol "${3:-tcp}" "$2") 45 | then 46 | echo "${port#*:}" 47 | else 48 | echo "Failed to find port for $1:$2" >&2 49 | return 1 50 | fi 51 | else 52 | echo "Failed to find port for $1:$2 - compose.yaml not found" >&2 53 | return 1 54 | fi 55 | } 56 | 57 | if test "$CI" != 'true' 58 | then 59 | if test -e ./.venv/bin/activate 60 | then 61 | . ./.venv/bin/activate 62 | else 63 | python3 -m venv --upgrade-deps .venv 64 | . ./.venv/bin/activate 65 | fi 66 | fi 67 | 68 | mkdir -p build 69 | 70 | pip install --upgrade pip 71 | 72 | printf 'Installing package...' 73 | pip install -e '.[dev]' 74 | echo ' done.' 75 | 76 | if test -d .git && test -f .pre-commit-config.yaml 77 | then 78 | printf 'Installing pre-commit hook...' 79 | pre-commit install --install-hooks 80 | echo ' done.' 81 | fi 82 | 83 | if test "${CI}" != 'true' 84 | then 85 | printf 'Cleaning environment...' 86 | docker compose down --remove-orphans --volumes 87 | echo ' done.' 88 | 89 | printf 'Starting environment...' 90 | docker compose up --wait --wait-timeout 120 91 | echo ' done.' 92 | fi 93 | 94 | if test "${CI}" = 'true' 95 | then 96 | PGPORT=5432 97 | else 98 | PGPORT="$(get_exposed_port postgres 5432)" 99 | fi 100 | 101 | printf 'Generating .env...' 102 | if test "${CI}" = 'true' 103 | then 104 | # In CI, use localhost since port is exposed from service container 105 | cat > .env < .env <= 90% coverage 67 | ``` 68 | 69 | **Single test execution**: 70 | ```bash 71 | python -m unittest tests.test_dump.DumpTestCase.test_specific_method 72 | ``` 73 | 74 | ### Linting and Formatting 75 | 76 | ```bash 77 | # Run all pre-commit hooks: 78 | pre-commit run --all-files 79 | 80 | # Ruff is configured in pyproject.toml: 81 | # - Line length: 79 characters 82 | # - Quote style: single quotes 83 | # - Target: Python 3.12 84 | ``` 85 | 86 | ### Test Fixtures 87 | 88 | Test fixtures are created by the `bootstrap` script and `ci/generate-fixture-data.py`: 89 | - `build/data/dump.not-compressed` - Uncompressed dump 90 | - `build/data/dump.compressed` - Compressed dump (level 9) 91 | - `build/data/dump.no-data` - Schema-only dump 92 | - `build/data/dump.data-only` - Data-only dump 93 | - `build/data/dump.inserts` - Dump using INSERT statements 94 | 95 | Tests inherit from `EnvironmentVariableMixin` to load connection info from `build/test-environment`. 96 | 97 | ## Code Style Requirements 98 | 99 | - **Coverage**: Pull requests require test coverage. The project enforces >= 90% coverage. 
100 | - **Line length**: 79 characters 101 | - **Quotes**: Single quotes for strings 102 | - **Type hints**: Use modern Python type hints (3.11+ syntax with `|` for unions) 103 | - **Target version**: Python 3.11, 3.12, 3.13 are supported 104 | 105 | ## PostgreSQL Version Support 106 | 107 | The library supports PostgreSQL 9-18 with dump format versions 1.12.0-1.16.0. Version mapping is defined in `constants.K_VERSION_MAP`. 108 | 109 | - Format version 1.12.0: PostgreSQL 9.0-10.2 (separate BLOB entries) 110 | - Format version 1.13.0: PostgreSQL 10.3-11.x (search_path behavior change) 111 | - Format version 1.14.0: PostgreSQL 12-15 (table access methods) 112 | - Format version 1.15.0: PostgreSQL 16 (compression algorithm in header) 113 | - Format version 1.16.0: PostgreSQL 17-18 (BLOB METADATA entries, multiple BLOBS, relkind) 114 | 115 | ## CI/CD 116 | 117 | GitHub Actions runs tests on: 118 | - Python versions: 3.11, 3.12, 3.13 119 | - PostgreSQL versions: 16, 17, 18 120 | - All commits to all branches (except docs-only changes) 121 | - Uploads coverage to Codecov 122 | -------------------------------------------------------------------------------- /tests/test_edge_cases.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import logging 3 | import pathlib 4 | import struct 5 | import tempfile 6 | import unittest 7 | from unittest import mock 8 | 9 | import pgdumplib 10 | from pgdumplib import constants, dump 11 | 12 | LOGGER = logging.getLogger(__name__) 13 | 14 | 15 | class EdgeTestCase(unittest.TestCase): 16 | @staticmethod 17 | def _write_byte(handle, value) -> None: 18 | """Write a byte to the handle""" 19 | handle.write(struct.pack('B', value)) 20 | 21 | def _write_int(self, handle, value): 22 | self._write_byte(handle, 1 if value < 0 else 0) 23 | if value < 0: 24 | value = -value 25 | for _offset in range(0, 4): 26 | self._write_byte(handle, value & 0xFF) 27 | value >>= 8 28 | 29 | def tearDown(self) -> None: 30 | test_file = pathlib.Path('build/data/dump.test') 31 | if test_file.exists(): 32 | test_file.unlink() 33 | 34 | def test_invalid_dependency(self): 35 | dmp = pgdumplib.new('test') 36 | with self.assertRaises(ValueError): 37 | dmp.add_entry( 38 | constants.TABLE, '', 'block_table', dependencies=[1024] 39 | ) 40 | 41 | def test_invalid_block_type_in_data(self): 42 | dmp = pgdumplib.new('test') 43 | dmp.add_entry(constants.TABLE_DATA, '', 'block_table', dump_id=128) 44 | with gzip.open(pathlib.Path(dmp._temp_dir.name) / '128.gz', 'wb') as h: 45 | h.write(b'1\t\1\t\1\n') 46 | with mock.patch('pgdumplib.constants.BLK_DATA', b'\x02'): 47 | dmp.save('build/data/dump.test') 48 | with self.assertRaises(RuntimeError): 49 | pgdumplib.load('build/data/dump.test') 50 | 51 | def test_encoding_not_first_entry(self): 52 | dmp = pgdumplib.new('test', 'LATIN1') 53 | entries = dmp.entries 54 | dmp.entries = [entries[1], entries[2], entries[0]] 55 | self.assertEqual(dmp.encoding, 'LATIN1') 56 | dmp.save('build/data/dump.test') 57 | 58 | dmp = pgdumplib.load('build/data/dump.test') 59 | self.assertEqual(dmp.encoding, 'LATIN1') 60 | 61 | def test_encoding_no_entries(self): 62 | dmp = pgdumplib.new('test', 'LATIN1') 63 | dmp.entries = [] 64 | self.assertEqual(dmp.encoding, 'LATIN1') 65 | dmp.save('build/data/dump.test') 66 | dmp = pgdumplib.load('build/data/dump.test') 67 | self.assertEqual(dmp.encoding, 'UTF8') 68 | 69 | def test_dump_id_mismatch_in_data(self): 70 | dmp = pgdumplib.new('test') 71 | dmp.add_entry(constants.TABLE_DATA, '', 
'block_table', dump_id=1024) 72 | with gzip.open( 73 | pathlib.Path(dmp._temp_dir.name) / '1024.gz', 'wb' 74 | ) as handle: 75 | handle.write(b'1\t\1\t\1\n') 76 | dmp.save('build/data/dump.test') 77 | 78 | with mock.patch('pgdumplib.dump.Dump._read_block_header') as rbh: 79 | rbh.return_value = constants.BLK_DATA, 2048 80 | with self.assertRaises(RuntimeError): 81 | pgdumplib.load('build/data/dump.test') 82 | 83 | def test_no_data(self): 84 | dmp = pgdumplib.new('test') 85 | dmp.add_entry(constants.TABLE_DATA, '', 'empty_table', dump_id=5) 86 | with gzip.open(pathlib.Path(dmp._temp_dir.name) / '5.gz', 'wb') as h: 87 | h.write(b'') 88 | dmp.save('build/data/dump.test') 89 | dmp = pgdumplib.load('build/data/dump.test') 90 | data = list(dmp.table_data('', 'empty_table')) 91 | self.assertEqual(len(data), 0) 92 | 93 | def test_runtime_error_when_pos_not_set(self): 94 | dmp = pgdumplib.new('test') 95 | dmp.add_entry(constants.TABLE_DATA, 'public', 'table', dump_id=32) 96 | with gzip.open(pathlib.Path(dmp._temp_dir.name) / '32.gz', 'wb') as h: 97 | h.write(b'1\t\1\t\1\n') 98 | 99 | with mock.patch('pgdumplib.constants.K_OFFSET_POS_SET', 9): 100 | dmp.save('build/data/dump.test') 101 | 102 | with self.assertRaises(RuntimeError): 103 | pgdumplib.load('build/data/dump.test') 104 | 105 | def test_table_data_finish_called_with_closed_handle(self): 106 | with tempfile.TemporaryDirectory() as tempdir: 107 | table_data = dump.TableData(1024, tempdir, 'UTF8') 108 | table_data._handle.close() 109 | table_data.finish() 110 | table_data.finish() 111 | self.assertFalse(table_data._handle.closed) 112 | 113 | def test_bad_encoding(self): 114 | dmp = pgdumplib.new('test') 115 | dmp.entries[0].defn = 'BAD ENTRY WILL FAIL' 116 | dmp.save('build/data/dump.test') 117 | 118 | dmp = pgdumplib.load('build/data/dump.test') 119 | self.assertEqual(dmp.encoding, 'UTF8') 120 | 121 | def test_invalid_desc(self): 122 | dmp = pgdumplib.new('test') 123 | with self.assertRaises(ValueError): 124 | dmp.add_entry('foo', '', 'table') 125 | 126 | def test_invalid_dump_id(self): 127 | dmp = pgdumplib.new('test') 128 | with self.assertRaises(ValueError): 129 | dmp.add_entry(constants.TABLE, '', 'table', dump_id=0) 130 | 131 | def test_used_dump_id(self): 132 | dmp = pgdumplib.new('test') 133 | with self.assertRaises(ValueError): 134 | dmp.add_entry(constants.TABLE, '', 'table', dump_id=1) 135 | -------------------------------------------------------------------------------- /pgdumplib/constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | Constants used in the reading and writing of a :command:`pg_dump` file. 3 | There are additional undocumented constants, but they should not be of concern 4 | unless you are hacking on the library itself. 
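
The object type constants (e.g. ``TABLE``, ``SCHEMA``) are passed as the
``desc`` argument when adding entries to a dump. A minimal sketch, based on
the usage shown in the examples documentation::

    import pgdumplib
    from pgdumplib import constants

    dump = pgdumplib.new('example')
    dump.add_entry(
        desc=constants.SCHEMA,
        tag='test',
        defn='CREATE SCHEMA test;',
        drop_stmt='DROP SCHEMA test;')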
5 | 6 | """ 7 | 8 | K_VERSION_MAP = { 9 | ((9, 0, 0), (10, 2)): (1, 12, 0), 10 | ((10, 3), (11, 99)): (1, 13, 0), 11 | ((12, 0), (15, 99)): (1, 14, 0), 12 | ((16, 0), (16, 99)): (1, 15, 0), 13 | ((17, 0), (99, 99)): (1, 16, 0), 14 | } 15 | 16 | BLK_DATA: bytes = b'\x01' 17 | BLK_BLOBS: bytes = b'\x03' 18 | 19 | EOF: int = -1 20 | 21 | FORMAT_UNKNOWN: int = 0 22 | FORMAT_CUSTOM: int = 1 23 | FORMAT_FILES: int = 2 24 | FORMAT_TAR: int = 3 25 | FORMAT_NULL: int = 4 26 | FORMAT_DIRECTORY: int = 5 27 | 28 | FORMATS: list[str] = ['Unknown', 'Custom', 'Files', 'Tar', 'Null', 'Directory'] 29 | 30 | """Specifies the compression algorithm used""" 31 | COMPRESSION_NONE: str = 'none' 32 | COMPRESSION_GZIP: str = 'gzip' 33 | COMPRESSION_LZ4: str = 'lz4' 34 | COMPRESSION_ZSTD: str = 'zstd' 35 | COMPRESSION_ALGORITHMS: list[str] = [ 36 | COMPRESSION_NONE, 37 | COMPRESSION_GZIP, 38 | COMPRESSION_LZ4, 39 | COMPRESSION_ZSTD, 40 | ] 41 | SUPPORTED_COMPRESSION_ALGORITHMS: list[str] = [ 42 | COMPRESSION_NONE, 43 | COMPRESSION_GZIP, 44 | ] 45 | 46 | K_OFFSET_POS_NOT_SET: int = 1 47 | """Specifies the entry has data but no offset""" 48 | K_OFFSET_POS_SET: int = 2 49 | """Specifies the entry has data and an offset""" 50 | K_OFFSET_NO_DATA: int = 3 51 | """Specifies the entry has no data""" 52 | 53 | MAGIC: bytes = b'PGDMP' 54 | 55 | MIN_VER: tuple[int, int, int] = (1, 12, 0) 56 | """The minumum supported version of pg_dump files ot support""" 57 | 58 | MAX_VER: tuple[int, int, int] = (1, 16, 0) 59 | """The maximum supported version of pg_dump files ot support""" 60 | 61 | PGDUMP_STRFTIME_FMT: str = '%Y-%m-%d %H:%M:%S.%f %Z' 62 | 63 | SECTION_NONE: str = 'None' 64 | """Non-specific section for an entry in a dump's table of contents""" 65 | 66 | SECTION_PRE_DATA: str = 'Pre-Data' 67 | """Pre-data section for an entry in a dump's table of contents""" 68 | 69 | SECTION_DATA: str = 'DATA' 70 | """Data section for an entry in a dump's table of contents""" 71 | 72 | SECTION_POST_DATA: str = 'Post-Data' 73 | """Post-data section for an entry in a dump's table of contents""" 74 | 75 | SECTIONS: list[str] = [ 76 | SECTION_NONE, 77 | SECTION_PRE_DATA, 78 | SECTION_DATA, 79 | SECTION_POST_DATA, 80 | ] 81 | 82 | VERSION: tuple[int, int, int] = (1, 14, 0) 83 | """pg_dump file format version to create by default""" 84 | 85 | ZLIB_OUT_SIZE: int = 4096 86 | ZLIB_IN_SIZE: int = 4096 87 | 88 | # Object Types 89 | 90 | ACCESS_METHOD: str = 'ACCESS METHOD' 91 | ACL: str = 'ACL' 92 | AGGREGATE: str = 'AGGREGATE' 93 | BLOB: str = 'BLOB' 94 | BLOB_METADATA: str = 'BLOB METADATA' 95 | BLOBS: str = 'BLOBS' 96 | CAST: str = 'CAST' 97 | CHECK_CONSTRAINT: str = 'CHECK CONSTRAINT' 98 | COLLATION: str = 'COLLATION' 99 | COMMENT: str = 'COMMENT' 100 | CONSTRAINT: str = 'CONSTRAINT' 101 | CONVERSION: str = 'CONVERSION' 102 | DATABASE: str = 'DATABASE' 103 | DATABASE_PROPERTIES: str = 'DATABASE PROPERTIES' 104 | DEFAULT: str = 'DEFAULT' 105 | DEFAULT_ACL: str = 'DEFAULT ACL' 106 | DOMAIN: str = 'DOMAIN' 107 | ENCODING: str = 'ENCODING' 108 | EVENT_TRIGGER: str = 'EVENT TRIGGER' 109 | EXTENSION: str = 'EXTENSION' 110 | FK_CONSTRAINT: str = 'FK CONSTRAINT' 111 | FOREIGN_DATA_WRAPPER: str = 'FOREIGN DATA WRAPPER' 112 | FOREIGN_SERVER: str = 'FOREIGN SERVER' 113 | FOREIGN_TABLE: str = 'FOREIGN TABLE' 114 | FUNCTION: str = 'FUNCTION' 115 | GROUP: str = 'GROUP' 116 | INDEX: str = 'INDEX' 117 | INDEX_ATTACH: str = 'INDEX ATTACH' 118 | LARGE_OBJECT: str = 'LARGE OBJECT' 119 | MATERIALIZED_VIEW: str = 'MATERIALIZED VIEW' 120 | MATERIALIZED_VIEW_DATA: str 
= 'MATERIALIZED VIEW DATA' 121 | OPERATOR: str = 'OPERATOR' 122 | OPERATOR_CLASS: str = 'OPERATOR CLASS' 123 | OPERATOR_FAMILY: str = 'OPERATOR FAMILY' 124 | PG_LARGEOBJECT: str = 'pg_largeobject' 125 | PG_LARGEOBJECT_METADATA: str = 'pg_largeobject_metadata' 126 | POLICY: str = 'POLICY' 127 | PROCEDURE: str = 'PROCEDURE' 128 | PROCEDURAL_LANGUAGE: str = 'PROCEDURAL LANGUAGE' 129 | PUBLICATION: str = 'PUBLICATION' 130 | PUBLICATION_TABLE: str = 'PUBLICATION TABLE' 131 | PUBLICATION_TABLES_IN_SCHEMA: str = 'PUBLICATION TABLES IN SCHEMA' 132 | ROLE: str = 'ROLE' 133 | ROW_SECURITY: str = 'ROW SECURITY' 134 | RULE: str = 'RULE' 135 | SCHEMA: str = 'SCHEMA' 136 | SEARCHPATH: str = 'SEARCHPATH' 137 | SECURITY_LABEL: str = 'SECURITY LABEL' 138 | SEQUENCE: str = 'SEQUENCE' 139 | SEQUENCE_OWNED_BY: str = 'SEQUENCE OWNED BY' 140 | SEQUENCE_SET: str = 'SEQUENCE SET' 141 | SERVER: str = 'SERVER' 142 | SHELL_TYPE: str = 'SHELL TYPE' 143 | STATISTICS: str = 'STATISTICS' 144 | STATISTICS_DATA: str = 'STATISTICS DATA' 145 | STDSTRINGS: str = 'STDSTRINGS' 146 | SUBSCRIPTION: str = 'SUBSCRIPTION' 147 | SUBSCRIPTION_TABLE: str = 'SUBSCRIPTION TABLE' 148 | TABLE: str = 'TABLE' 149 | TABLE_ATTACH: str = 'TABLE ATTACH' 150 | TABLE_DATA: str = 'TABLE DATA' 151 | TABLESPACE: str = 'TABLESPACE' 152 | TEXT_SEARCH_DICTIONARY: str = 'TEXT SEARCH DICTIONARY' 153 | TEXT_SEARCH_CONFIGURATION: str = 'TEXT SEARCH CONFIGURATION' 154 | TEXT_SEARCH_PARSER: str = 'TEXT SEARCH PARSER' 155 | TEXT_SEARCH_TEMPLATE: str = 'TEXT SEARCH TEMPLATE' 156 | TRANSFORM: str = 'TRANSFORM' 157 | TRIGGER: str = 'TRIGGER' 158 | TYPE: str = 'TYPE' 159 | USER: str = 'USER' 160 | USER_MAPPING: str = 'USER MAPPING' 161 | VIEW: str = 'VIEW' 162 | 163 | SECTION_MAPPING: dict[str, str] = { 164 | ACCESS_METHOD: SECTION_PRE_DATA, 165 | ACL: SECTION_NONE, 166 | AGGREGATE: SECTION_PRE_DATA, 167 | BLOB: SECTION_PRE_DATA, 168 | BLOB_METADATA: SECTION_PRE_DATA, 169 | BLOBS: SECTION_DATA, 170 | CAST: SECTION_PRE_DATA, 171 | CHECK_CONSTRAINT: SECTION_POST_DATA, 172 | COLLATION: SECTION_PRE_DATA, 173 | COMMENT: SECTION_NONE, 174 | CONSTRAINT: SECTION_POST_DATA, 175 | CONVERSION: SECTION_PRE_DATA, 176 | DATABASE: SECTION_PRE_DATA, 177 | DATABASE_PROPERTIES: SECTION_PRE_DATA, 178 | DEFAULT: SECTION_PRE_DATA, 179 | DEFAULT_ACL: SECTION_POST_DATA, 180 | DOMAIN: SECTION_PRE_DATA, 181 | ENCODING: SECTION_PRE_DATA, 182 | EVENT_TRIGGER: SECTION_POST_DATA, 183 | EXTENSION: SECTION_PRE_DATA, 184 | FK_CONSTRAINT: SECTION_POST_DATA, 185 | FOREIGN_DATA_WRAPPER: SECTION_PRE_DATA, 186 | FOREIGN_SERVER: SECTION_NONE, 187 | FOREIGN_TABLE: SECTION_PRE_DATA, 188 | FUNCTION: SECTION_PRE_DATA, 189 | GROUP: SECTION_NONE, 190 | INDEX: SECTION_POST_DATA, 191 | INDEX_ATTACH: SECTION_POST_DATA, 192 | LARGE_OBJECT: SECTION_NONE, 193 | MATERIALIZED_VIEW: SECTION_POST_DATA, 194 | MATERIALIZED_VIEW_DATA: SECTION_POST_DATA, 195 | OPERATOR: SECTION_PRE_DATA, 196 | OPERATOR_CLASS: SECTION_PRE_DATA, 197 | OPERATOR_FAMILY: SECTION_PRE_DATA, 198 | PG_LARGEOBJECT: SECTION_PRE_DATA, 199 | PG_LARGEOBJECT_METADATA: SECTION_PRE_DATA, 200 | POLICY: SECTION_POST_DATA, 201 | PROCEDURE: SECTION_PRE_DATA, 202 | PROCEDURAL_LANGUAGE: SECTION_PRE_DATA, 203 | PUBLICATION: SECTION_POST_DATA, 204 | PUBLICATION_TABLE: SECTION_POST_DATA, 205 | PUBLICATION_TABLES_IN_SCHEMA: SECTION_POST_DATA, 206 | ROLE: SECTION_NONE, 207 | ROW_SECURITY: SECTION_POST_DATA, 208 | RULE: SECTION_POST_DATA, 209 | SCHEMA: SECTION_PRE_DATA, 210 | SEARCHPATH: SECTION_PRE_DATA, 211 | SECURITY_LABEL: SECTION_NONE, 212 | SERVER: 
SECTION_PRE_DATA, 213 | SEQUENCE: SECTION_PRE_DATA, 214 | SEQUENCE_SET: SECTION_DATA, 215 | SEQUENCE_OWNED_BY: SECTION_PRE_DATA, 216 | SHELL_TYPE: SECTION_PRE_DATA, 217 | STATISTICS: SECTION_POST_DATA, 218 | STATISTICS_DATA: SECTION_POST_DATA, 219 | STDSTRINGS: SECTION_PRE_DATA, 220 | SUBSCRIPTION: SECTION_PRE_DATA, 221 | SUBSCRIPTION_TABLE: SECTION_POST_DATA, 222 | TABLESPACE: SECTION_PRE_DATA, # Not part of a postgres created pg_dump 223 | TABLE: SECTION_PRE_DATA, 224 | TABLE_ATTACH: SECTION_POST_DATA, 225 | TABLE_DATA: SECTION_DATA, 226 | TRANSFORM: SECTION_PRE_DATA, 227 | TRIGGER: SECTION_POST_DATA, 228 | TYPE: SECTION_PRE_DATA, 229 | TEXT_SEARCH_CONFIGURATION: SECTION_PRE_DATA, 230 | TEXT_SEARCH_DICTIONARY: SECTION_PRE_DATA, 231 | TEXT_SEARCH_PARSER: SECTION_PRE_DATA, 232 | TEXT_SEARCH_TEMPLATE: SECTION_PRE_DATA, 233 | USER: SECTION_NONE, 234 | USER_MAPPING: SECTION_PRE_DATA, 235 | VIEW: SECTION_PRE_DATA, 236 | } 237 | -------------------------------------------------------------------------------- /pgdumplib/converters.py: -------------------------------------------------------------------------------- 1 | """ 2 | When creating a new :class:`pgdumplib.dump.Dump` instance, either directly 3 | or by using :py:func:`pgdumplib.load`, you can specify a converter class to 4 | use when reading data using the :meth:`pgdumplib.dump.Dump.read_data` iterator. 5 | 6 | The default converter (:py:class:`DataConverter`) handles PostgreSQL COPY 7 | text format escape sequences and replaces columns that have a ``NULL`` 8 | indicator (``\\N``) with :py:const:`None`. 9 | 10 | Supported escape sequences: 11 | - ``\\b`` - backspace 12 | - ``\\f`` - form feed 13 | - ``\\n`` - newline 14 | - ``\\r`` - carriage return 15 | - ``\\t`` - tab 16 | - ``\\v`` - vertical tab 17 | - ``\\NNN`` - octal byte value (1-3 digits) 18 | - ``\\xNN`` - hex byte value (1-2 digits) 19 | - ``\\\\`` - backslash 20 | - ``\\X`` - any other character X taken literally 21 | 22 | The :py:class:`SmartDataConverter` extends the base converter and will 23 | attempt to convert individual columns to native Python data types after 24 | unescaping. 25 | 26 | Creating your own data converter is easy and should simply extend the 27 | :py:class:`DataConverter` class. 28 | 29 | """ 30 | 31 | import datetime 32 | import decimal 33 | import ipaddress 34 | import typing 35 | import uuid 36 | 37 | import pendulum 38 | 39 | 40 | def unescape_copy_text(field: str) -> str: # noqa: C901 41 | """Unescape PostgreSQL COPY text format escape sequences. 42 | 43 | This function implements the same escape sequence handling as PostgreSQL's 44 | CopyReadAttributesText() function in copyfromparse.c. 
45 | 46 | Supported escape sequences: 47 | \\b - backspace (ASCII 8) 48 | \\f - form feed (ASCII 12) 49 | \\n - newline (ASCII 10) 50 | \\r - carriage return (ASCII 13) 51 | \\t - tab (ASCII 9) 52 | \\v - vertical tab (ASCII 11) 53 | \\NNN - octal byte value (1-3 digits) 54 | \\xNN - hex byte value (1-2 digits) 55 | \\X - any other character X literally 56 | 57 | :param field: The escaped field string 58 | :return: The unescaped string 59 | 60 | """ 61 | if '\\' not in field: 62 | return field 63 | 64 | result = [] 65 | i = 0 66 | while i < len(field): 67 | if field[i] == '\\' and i + 1 < len(field): 68 | i += 1 69 | c = field[i] 70 | 71 | # Octal escape: \0-\7 72 | if '0' <= c <= '7': 73 | val = ord(c) - ord('0') 74 | # Check for second octal digit 75 | if i + 1 < len(field) and '0' <= field[i + 1] <= '7': 76 | i += 1 77 | val = (val << 3) + (ord(field[i]) - ord('0')) 78 | # Check for third octal digit 79 | if i + 1 < len(field) and '0' <= field[i + 1] <= '7': 80 | i += 1 81 | val = (val << 3) + (ord(field[i]) - ord('0')) 82 | result.append(chr(val & 0o377)) 83 | 84 | # Hex escape: \xNN 85 | elif c == 'x': 86 | if i + 1 < len(field): 87 | hex_chars = field[i + 1 : i + 3] 88 | hex_val = '' 89 | for hc in hex_chars: 90 | if hc in '0123456789abcdefABCDEF': 91 | hex_val += hc 92 | else: 93 | break 94 | if hex_val: 95 | i += len(hex_val) 96 | result.append(chr(int(hex_val, 16) & 0xFF)) 97 | else: 98 | result.append(c) 99 | else: 100 | result.append(c) 101 | 102 | # Single character escapes 103 | elif c == 'b': 104 | result.append('\b') 105 | elif c == 'f': 106 | result.append('\f') 107 | elif c == 'n': 108 | result.append('\n') 109 | elif c == 'r': 110 | result.append('\r') 111 | elif c == 't': 112 | result.append('\t') 113 | elif c == 'v': 114 | result.append('\v') 115 | else: 116 | # Any other backslashed character is literal 117 | result.append(c) 118 | 119 | i += 1 120 | else: 121 | result.append(field[i]) 122 | i += 1 123 | 124 | return ''.join(result) 125 | 126 | 127 | class DataConverter: 128 | """Base Row/Column Converter 129 | 130 | Base class used for converting row/column data when using the 131 | :meth:`pgdumplib.dump.Dump.read_data` iterator. 132 | 133 | This class splits the row into individual columns, unescapes 134 | PostgreSQL COPY text format escape sequences, and converts ``\\N`` 135 | to :py:const:`None`. 136 | 137 | """ 138 | 139 | @staticmethod 140 | def convert(row: str) -> tuple[typing.Any, ...]: 141 | """Convert the string based row into a tuple of columns. 142 | 143 | :param str row: The row to convert 144 | :rtype: tuple 145 | 146 | """ 147 | return tuple( 148 | None if e == '\\N' else unescape_copy_text(e) 149 | for e in row.split('\t') 150 | ) 151 | 152 | 153 | class NoOpConverter: 154 | """Performs no conversion on the row passed in""" 155 | 156 | @staticmethod 157 | def convert(row: str) -> str: 158 | """Returns the row passed in 159 | 160 | :param str row: The row to convert 161 | :rtype: str 162 | 163 | """ 164 | return row 165 | 166 | 167 | SmartColumn = ( 168 | None 169 | | str 170 | | int 171 | | datetime.datetime 172 | | decimal.Decimal 173 | | ipaddress.IPv4Address 174 | | ipaddress.IPv4Network 175 | | ipaddress.IPv6Address 176 | | ipaddress.IPv6Network 177 | | uuid.UUID 178 | ) 179 | 180 | 181 | class SmartDataConverter(DataConverter): 182 | """Attempts to convert columns to native Python data types 183 | 184 | Used for converting row/column data with the 185 | :meth:`pgdumplib.dump.Dump.read_data` iterator. 
186 | 187 | Possible conversion types: 188 | 189 | - :py:class:`int` 190 | - :py:class:`datetime.datetime` 191 | - :py:class:`decimal.Decimal` 192 | - :py:class:`ipaddress.IPv4Address` 193 | - :py:class:`ipaddress.IPv4Network` 194 | - :py:class:`ipaddress.IPv6Address` 195 | - :py:class:`ipaddress.IPv6Network` 196 | - :py:const:`None` 197 | - :py:class:`str` 198 | - :py:class:`uuid.UUID` 199 | 200 | """ 201 | 202 | @staticmethod 203 | def convert(row: str) -> tuple[SmartColumn, ...]: 204 | """Convert the string based row into a tuple of columns""" 205 | return tuple( 206 | SmartDataConverter._convert_column(c) for c in row.split('\t') 207 | ) 208 | 209 | @staticmethod 210 | def _convert_column(column: str) -> SmartColumn: 211 | """Attempt to convert the column from a string if appropriate""" 212 | # Check for NULL before unescaping (PostgreSQL does this) 213 | if column == '\\N': 214 | return None 215 | # Unescape the column data 216 | column = unescape_copy_text(column) 217 | if column.strip('-').isnumeric(): 218 | return int(column) 219 | elif column.strip('-').replace('.', '').isnumeric(): 220 | try: 221 | return decimal.Decimal(column) 222 | except (ValueError, decimal.InvalidOperation): 223 | pass 224 | try: 225 | return ipaddress.ip_address(column) 226 | except ValueError: 227 | pass 228 | try: 229 | return ipaddress.ip_network(column) 230 | except ValueError: 231 | pass 232 | try: 233 | return uuid.UUID(column) 234 | except ValueError: 235 | pass 236 | for tz_fmt in {'Z', 'ZZ', 'z', 'zz'}: 237 | for micro_fmt in {'.SSSSSS', ''}: 238 | try: 239 | pdt = pendulum.from_format( 240 | column, f'YYYY-MM-DD HH:mm:ss{micro_fmt} {tz_fmt}' 241 | ) 242 | # Convert pendulum DateTime to datetime.datetime 243 | return datetime.datetime.fromisoformat(pdt.isoformat()) 244 | except ValueError: 245 | pass 246 | return column 247 | -------------------------------------------------------------------------------- /tests/test_save_dump.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import datetime 3 | import os 4 | import pathlib 5 | import subprocess 6 | import tempfile 7 | import unittest 8 | import uuid 9 | 10 | import dotenv 11 | import faker 12 | from faker.providers import date_time 13 | 14 | import pgdumplib 15 | from pgdumplib import constants, converters, models 16 | 17 | dotenv.load_dotenv() 18 | 19 | 20 | class SavedDumpTestCase(unittest.TestCase): 21 | def setUp(self): 22 | dmp = pgdumplib.load('build/data/dump.compressed') 23 | dmp.save(pathlib.Path('build/data/dump.test')) 24 | self.original = pgdumplib.load('build/data/dump.compressed') 25 | self.saved = pgdumplib.load('build/data/dump.test') 26 | 27 | def tearDown(self) -> None: 28 | test_file = pathlib.Path('build/data/dump.test') 29 | if test_file.exists(): 30 | test_file.unlink() 31 | 32 | def test_timestamp_matches(self): 33 | self.assertEqual( 34 | self.original.timestamp.isoformat(), 35 | self.saved.timestamp.isoformat(), 36 | ) 37 | 38 | def test_version_matches(self): 39 | self.assertEqual(self.original.version, self.saved.version) 40 | 41 | def test_compression_does_not_match(self): 42 | self.assertNotEqual( 43 | self.original.compression_algorithm, constants.COMPRESSION_NONE 44 | ) 45 | self.assertEqual( 46 | self.saved.compression_algorithm, constants.COMPRESSION_NONE 47 | ) 48 | 49 | def test_entries_mostly_match(self): 50 | attrs = [e.name for e in dataclasses.fields(models.Entry)] 51 | attrs.remove('offset') 52 | for original in self.original.entries: 53 | 
saved_entry = self.saved.get_entry(original.dump_id) 54 | for attr in attrs: 55 | self.assertEqual( 56 | getattr(original, attr), 57 | getattr(saved_entry, attr), 58 | f'{attr} does not match: {getattr(original, attr)} != ' 59 | f'{getattr(saved_entry, attr)}', 60 | ) 61 | 62 | def test_table_data_matches(self): 63 | for entry in range(0, len(self.original.entries)): 64 | if self.original.entries[entry].desc != constants.TABLE_DATA: 65 | continue 66 | 67 | # Type narrowing for namespace and tag 68 | namespace = self.original.entries[entry].namespace 69 | tag = self.original.entries[entry].tag 70 | assert namespace is not None # noqa: S101 71 | assert tag is not None # noqa: S101 72 | 73 | original_data = list(self.original.table_data(namespace, tag)) 74 | 75 | saved_data = list(self.saved.table_data(namespace, tag)) 76 | 77 | for offset in range(0, len(original_data)): 78 | self.assertListEqual( 79 | list(original_data[offset]), 80 | list(saved_data[offset]), 81 | f'Data in {namespace}.{tag} does not match ' 82 | f'for row {offset}', 83 | ) 84 | 85 | 86 | class EmptyDumpTestCase(unittest.TestCase): 87 | def test_empty_dump_has_base_entries(self): 88 | dmp = pgdumplib.new('test', 'UTF8') 89 | self.assertEqual(len(dmp.entries), 3) 90 | 91 | def test_empty_save_does_not_err(self): 92 | dmp = pgdumplib.new('test', 'UTF8') 93 | dmp.save(pathlib.Path('build/data/dump.test')) 94 | test_file = pathlib.Path('build/data/dump.test') 95 | self.assertTrue(test_file.exists()) 96 | test_file.unlink() 97 | 98 | 99 | class CreateDumpTestCase(unittest.TestCase): 100 | def tearDown(self) -> None: 101 | test_file = pathlib.Path('build/data/dump.test') 102 | if test_file.exists(): 103 | test_file.unlink() 104 | 105 | def test_dump_expectations(self): 106 | dmp = pgdumplib.new('test', 'UTF8') 107 | database = dmp.add_entry( 108 | desc=constants.DATABASE, 109 | tag='postgres', 110 | owner='postgres', 111 | defn="""\ 112 | CREATE DATABASE postgres 113 | WITH TEMPLATE = template0 114 | ENCODING = 'UTF8' 115 | LC_COLLATE = 'en_US.utf8' 116 | LC_CTYPE = 'en_US.utf8';""", 117 | drop_stmt='DROP DATABASE postgres', 118 | ) 119 | 120 | dmp.add_entry( 121 | constants.COMMENT, 122 | tag='DATABASE postgres', 123 | owner='postgres', 124 | defn="""\ 125 | COMMENT ON DATABASE postgres 126 | IS 'default administrative connection database';""", 127 | dependencies=[database.dump_id], 128 | ) 129 | 130 | example = dmp.add_entry( 131 | constants.TABLE, 132 | 'public', 133 | 'example', 134 | 'postgres', 135 | 'CREATE TABLE public.example (\ 136 | id UUID NOT NULL PRIMARY KEY, \ 137 | created_at TIMESTAMP WITH TIME ZONE, \ 138 | value TEXT NOT NULL);', 139 | 'DROP TABLE public.example', 140 | ) 141 | 142 | columns = 'id', 'created_at', 'value' 143 | 144 | fake = faker.Faker() 145 | fake.add_provider(date_time) 146 | 147 | rows = [ 148 | (uuid.uuid4(), fake.date_time(tzinfo=datetime.UTC), 'foo'), 149 | (uuid.uuid4(), fake.date_time(tzinfo=datetime.UTC), 'bar'), 150 | (uuid.uuid4(), fake.date_time(tzinfo=datetime.UTC), 'baz'), 151 | (uuid.uuid4(), fake.date_time(tzinfo=datetime.UTC), 'qux'), 152 | ] 153 | 154 | with dmp.table_data_writer(example, columns) as writer: 155 | for row in rows: 156 | writer.append(*row) 157 | 158 | row = (uuid.uuid4(), fake.date_time(tzinfo=datetime.UTC), None) 159 | rows.append(row) 160 | 161 | # Append a second time to get same writer 162 | with dmp.table_data_writer(example, columns) as writer: 163 | writer.append(*row) 164 | 165 | test_file = pathlib.Path('build/data/dump.test') 166 | 167 | 
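        # Dump.save() always writes an uncompressed custom-format archive;
        # any compression on a source dump is not preserved (see
        # SavedDumpTestCase.test_compression_does_not_match above).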
dmp.save(test_file) 168 | 169 | self.assertTrue(test_file.exists()) 170 | 171 | dmp = pgdumplib.load(test_file, converters.SmartDataConverter) 172 | entry = dmp.get_entry(database.dump_id) 173 | self.assertIsNotNone(entry) 174 | assert entry is not None # noqa: S101 - Type narrowing for type checker 175 | self.assertEqual(entry.desc, 'DATABASE') 176 | self.assertEqual(entry.owner, 'postgres') 177 | self.assertEqual(entry.tag, 'postgres') 178 | values = list(dmp.table_data('public', 'example')) 179 | self.assertListEqual(values, rows) 180 | 181 | 182 | class Issue6RegressionTestCase(unittest.TestCase): 183 | """Regression test for GitHub issue #6: empty tableam causing invalid SQL 184 | 185 | When pgdumplib loads and saves a dump, it should not generate entries 186 | with empty tableam values that cause pg_restore to generate invalid SQL 187 | like: SET default_table_access_method = ""; 188 | """ 189 | 190 | def test_load_save_no_empty_tableam_sql(self): 191 | """Test that resaved dumps don't generate invalid empty tableam SQL""" 192 | # Create a test dump with a table (has tableam) and others without 193 | dmp = pgdumplib.new('test_issue6', 'UTF8', appear_as='13.1') 194 | 195 | # Add a table with tableam='heap' 196 | dmp.add_entry( 197 | constants.TABLE, 198 | namespace='public', 199 | tag='test_table', 200 | owner='postgres', 201 | defn='CREATE TABLE public.test_table (id integer);', 202 | tableam='heap', 203 | ) 204 | 205 | # Add a sequence (should not have tableam set) 206 | dmp.add_entry( 207 | constants.SEQUENCE, 208 | namespace='public', 209 | tag='test_seq', 210 | owner='postgres', 211 | defn='CREATE SEQUENCE public.test_seq;', 212 | ) 213 | 214 | # Save and reload 215 | with tempfile.NamedTemporaryFile(delete=False, suffix='.dump') as tmp: 216 | tmp_path = tmp.name 217 | 218 | try: 219 | dmp.save(tmp_path) 220 | 221 | # Check that pg_restore can generate valid SQL 222 | result = subprocess.run( # noqa: S603 223 | ['pg_restore', '--schema-only', '-f', '-', tmp_path], # noqa: S607 224 | capture_output=True, 225 | text=True, 226 | ) 227 | 228 | # Should not fail 229 | self.assertEqual( 230 | result.returncode, 231 | 0, 232 | f'pg_restore failed: {result.stderr}', 233 | ) 234 | 235 | # Should not contain invalid empty SET statement 236 | self.assertNotIn( 237 | 'SET default_table_access_method = ""', 238 | result.stdout, 239 | 'Found invalid empty tableam SET statement in output', 240 | ) 241 | 242 | # Should contain the valid SET statement for the table 243 | self.assertIn( 244 | 'SET default_table_access_method = heap', 245 | result.stdout, 246 | 'Missing valid tableam SET statement for table', 247 | ) 248 | finally: 249 | if os.path.exists(tmp_path): 250 | os.unlink(tmp_path) 251 | -------------------------------------------------------------------------------- /bin/generate-fixture-data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import logging 4 | import os 5 | import random 6 | from os import path 7 | 8 | import faker 9 | import psycopg 10 | from faker.providers import address, internet, misc, person 11 | 12 | LOGGER = logging.getLogger(__name__) 13 | LOGGING_FORMAT = '[%(asctime)-15s] %(levelname)-8s %(message)s' 14 | 15 | # Faker Supported Locales 16 | LOCALES = [ 17 | 'ar_AA', 18 | 'ar_EG', 19 | 'ar_JO', 20 | 'ar_PS', 21 | 'ar_SA', 22 | 'bg_BG', 23 | 'bs_BA', 24 | 'cs_CZ', 25 | 'de_AT', 26 | 'de_CH', 27 | 'de_DE', 28 | 'dk_DK', 29 | 'el_CY', 30 | 'el_GR', 31 | 'en_AU', 32 | 'en_CA', 33 | 
'en_GB', 34 | 'en_IE', 35 | 'en_NZ', 36 | 'en_TH', 37 | 'en_US', 38 | 'es_ES', 39 | 'es_MX', 40 | 'et_EE', 41 | 'fa_IR', 42 | 'fi_FI', 43 | 'fr_CH', 44 | 'fr_FR', 45 | 'he_IL', 46 | 'hi_IN', 47 | 'hr_HR', 48 | 'hu_HU', 49 | 'hy_AM', 50 | 'id_ID', 51 | 'it_IT', 52 | 'ja_JP', 53 | 'ka_GE', 54 | 'ko_KR', 55 | 'lb_LU', 56 | 'lt_LT', 57 | 'lv_LV', 58 | 'mt_MT', 59 | 'ne_NP', 60 | 'nl_BE', 61 | 'nl_NL', 62 | 'no_NO', 63 | 'pl_PL', 64 | 'pt_BR', 65 | 'pt_PT', 66 | 'ro_RO', 67 | 'ru_RU', 68 | 'sk_SK', 69 | 'sl_SI', 70 | 'sv_SE', 71 | 'th_TH', 72 | 'tr_TR', 73 | 'tw_GH', 74 | 'uk_UA', 75 | 'zh_CN', 76 | 'zh_TW', 77 | ] 78 | 79 | STATES = ['unverified', 'verified', 'suspended'] 80 | TYPES = ['billing', 'delivery'] 81 | 82 | ADDRESS_SQL = """\ 83 | INSERT INTO test.addresses 84 | (created_at, last_modified_at, user_id, "type", address1, address2, 85 | address3, locality, region, postal_code, country) 86 | VALUES (%(created_at)s, %(last_modified_at)s, %(user_id)s, %(type)s, 87 | %(address1)s, %(address2)s, %(address3)s, %(locality)s, 88 | %(region)s, %(postal_code)s, %(country)s) 89 | RETURNING id;""" 90 | ICON_SQL = 'SELECT lo_from_bytea(0, %(data)s)' 91 | USER_SQL = """\ 92 | INSERT INTO test.users 93 | (created_at, last_modified_at, state, email, name, surname, 94 | display_name, locale, password_salt, password, signup_ip, icon) 95 | VALUES (%(created_at)s, %(last_modified_at)s, %(state)s, %(email)s, 96 | %(name)s, %(surname)s, %(display_name)s, %(locale)s, 97 | %(password_salt)s, %(password)s, %(signup_ip)s, %(icon)s) 98 | RETURNING id, created_at;""" 99 | 100 | 101 | def add_connection_options_to_parser(parser): 102 | """Add PostgreSQL connection CLI options to the parser. 103 | 104 | :param argparse.ArgumentParser parser: The parser to add the args to 105 | 106 | """ 107 | conn = parser.add_argument_group('Connection Options') 108 | conn.add_argument( 109 | '-d', 110 | '--dbname', 111 | action='store', 112 | default=os.environ.get('PGDATABASE', 'postgres'), 113 | help='database name to connect to', 114 | ) 115 | conn.add_argument( 116 | '-h', 117 | '--host', 118 | action='store', 119 | default=os.environ.get('PGHOST', 'localhost'), 120 | help='database server host or socket directory', 121 | ) 122 | conn.add_argument( 123 | '-p', 124 | '--port', 125 | action='store', 126 | type=int, 127 | default=int(os.environ.get('PGPORT', 5432)), 128 | help='database server port number', 129 | ) 130 | conn.add_argument( 131 | '-U', 132 | '--username', 133 | action='store', 134 | default=os.environ.get('PGUSER', 'postgres'), 135 | help='The PostgreSQL username to operate as', 136 | ) 137 | 138 | 139 | def add_logging_options_to_parser(parser): 140 | """Add logging options to the parser. 141 | 142 | :param argparse.ArgumentParser parser: The parser to add the args to 143 | 144 | """ 145 | group = parser.add_argument_group(title='Logging Options') 146 | group.add_argument( 147 | '-L', 148 | '--log-file', 149 | action='store', 150 | help='Log to the specified filename. If not specified, ' 151 | 'log output is sent to STDOUT', 152 | ) 153 | group.add_argument( 154 | '-v', 155 | '--verbose', 156 | action='store_true', 157 | help='Increase output verbosity', 158 | ) 159 | group.add_argument( 160 | '--debug', action='store_true', help='Extra verbose debug logging' 161 | ) 162 | 163 | 164 | def configure_logging(args): 165 | """Configure Python logging. 
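    ``--verbose`` selects INFO and ``--debug`` selects DEBUG output;
    otherwise the level defaults to WARNING. If ``--log-file`` names a path
    whose directory does not exist, logging falls back to console output.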
166 | 167 | :param argparse.namespace args: The parsed cli arguments 168 | 169 | """ 170 | level = logging.WARNING 171 | if args.verbose: 172 | level = logging.INFO 173 | elif args.debug: 174 | level = logging.DEBUG 175 | filename = args.log_file if args.log_file else None 176 | if filename: 177 | filename = path.abspath(filename) 178 | if not path.exists(path.dirname(filename)): 179 | filename = None 180 | logging.basicConfig(level=level, filename=filename, format=LOGGING_FORMAT) 181 | 182 | 183 | def generate_address(fake, user_id, created_at): 184 | """Generate a street address for the user""" 185 | last_modified_at = None 186 | if fake.boolean(chance_of_getting_true=25): 187 | last_modified_at = created_at + fake.time_delta( 188 | end_datetime=datetime.datetime.now(tz=datetime.UTC) 189 | ) 190 | 191 | address1 = fake.street_address() 192 | address2 = None 193 | if '\n' in address1: 194 | parts = address1.split('\n') 195 | address1 = parts[0] 196 | address2 = parts[1] 197 | try: 198 | address3 = fake.secondary_address() 199 | except AttributeError: 200 | address3 = None 201 | try: 202 | region = fake.state() 203 | except AttributeError: 204 | region = None 205 | addr = { 206 | 'created_at': created_at, 207 | 'last_modified_at': last_modified_at, 208 | 'user_id': user_id, 209 | 'type': fake.random_element(TYPES), 210 | 'address1': address1, 211 | 'address2': address2, 212 | 'address3': address3, 213 | 'locality': fake.city(), 214 | 'region': region, 215 | 'postal_code': fake.postcode(), 216 | 'country': fake.country(), 217 | } 218 | LOGGER.debug('Returning %r', addr) 219 | return addr 220 | 221 | 222 | def generate_user(fake, locale_fake, locale, icon_oid): 223 | """Generate a fake user""" 224 | created_at = fake.date_time_this_year() 225 | last_modified_at = None 226 | if fake.boolean(chance_of_getting_true=45): 227 | last_modified_at = created_at + fake.time_delta( 228 | end_datetime=datetime.datetime.now(tz=datetime.UTC) 229 | ) 230 | name = locale_fake.first_name() 231 | surname = locale_fake.last_name() 232 | display_name = None 233 | if fake.boolean(chance_of_getting_true=50): 234 | if fake.boolean(chance_of_getting_true=50): 235 | display_name = name 236 | else: 237 | display_name = f'{name} {surname}' 238 | user = { 239 | 'created_at': created_at, 240 | 'last_modified_at': last_modified_at, 241 | 'state': fake.random_element(STATES), 242 | 'email': locale_fake.safe_email(), 243 | 'name': name, 244 | 'surname': surname, 245 | 'display_name': display_name, 246 | 'locale': locale.replace('_', '-'), 247 | 'password_salt': fake.uuid4(), 248 | 'password': fake.password(12), # It's fake data :-p 249 | 'signup_ip': fake.ipv4_public(), 250 | 'icon': icon_oid, 251 | } 252 | LOGGER.debug('Returning user: %r', user) 253 | return user 254 | 255 | 256 | def generate_users(args, cursor): 257 | """Generate fixture data for the user table.""" 258 | LOGGER.info('Creating %i users', args.user_count) 259 | fake = faker.Faker() 260 | fake.add_provider(address) 261 | fake.add_provider(internet) 262 | fake.add_provider(misc) 263 | 264 | for _offset in range(0, args.user_count): 265 | # Randomly maybe create a blob 266 | icon_oid = None 267 | if fake.boolean(chance_of_getting_true=25): 268 | cursor.execute( 269 | ICON_SQL, 270 | {'data': fake.binary(length=random.randint(10000, 200000))}, 271 | ) 272 | icon_oid = cursor.fetchone()[0] 273 | 274 | locale = fake.random_element(LOCALES) 275 | locale_fake = get_locale_faker(locale) 276 | 277 | # Create the User 278 | try: 279 | cursor.execute( 280 | 
USER_SQL, generate_user(fake, locale_fake, locale, icon_oid) 281 | ) 282 | except psycopg.IntegrityError as err: 283 | LOGGER.error('Error creating user: %s', err) 284 | continue 285 | 286 | user = cursor.fetchone() 287 | LOGGER.info('Created user %s', user[0]) 288 | created_at = user[1] 289 | for _offset in range(random.randint(0, 2)): 290 | cursor.execute( 291 | ADDRESS_SQL, generate_address(locale_fake, user[0], created_at) 292 | ) 293 | addr = cursor.fetchone() 294 | LOGGER.info('Created address %s for user %s', addr[0], user[0]) 295 | created_at = created_at + fake.time_delta( 296 | end_datetime=datetime.datetime.now(tz=datetime.UTC) 297 | ) 298 | 299 | 300 | def get_locale_faker(locale=None): 301 | """Return a fake.Faker instance for the specified locale. 302 | 303 | :param str locale: The locale to generate fake data for 304 | :rtype: faker.Faker 305 | 306 | """ 307 | fake = faker.Faker(locale) 308 | fake.add_provider(address) 309 | fake.add_provider(internet) 310 | fake.add_provider(person) 311 | return fake 312 | 313 | 314 | def main(args): 315 | """Primary script for generating fake data in test database 316 | 317 | :param argparse.namespace args: The parsed cli arguments 318 | 319 | """ 320 | configure_logging(args) 321 | conn = psycopg.connect( 322 | host=args.host, 323 | port=args.port, 324 | dbname=args.dbname, 325 | user=args.username, 326 | autocommit=True, 327 | ) 328 | cursor = conn.cursor() 329 | generate_users(args, cursor) 330 | conn.close() 331 | 332 | 333 | if __name__ == '__main__': 334 | parser = argparse.ArgumentParser( 335 | description='Generates fixture data for tests', 336 | conflict_handler='resolve', 337 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 338 | ) 339 | add_connection_options_to_parser(parser) 340 | add_logging_options_to_parser(parser) 341 | parser.add_argument( 342 | '--user-count', 343 | action='store', 344 | type=int, 345 | default=250, 346 | help='How many users to generate', 347 | ) 348 | main(parser.parse_args()) 349 | -------------------------------------------------------------------------------- /tests/test_converters.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import ipaddress 3 | import unittest 4 | import uuid 5 | 6 | import faker 7 | import maya 8 | 9 | from pgdumplib import constants, converters 10 | 11 | 12 | class TestCase(unittest.TestCase): 13 | def test_data_converter(self): 14 | data = [] 15 | for row in range(0, 10): 16 | data.append( 17 | [ 18 | str(row), 19 | str(uuid.uuid4()), 20 | str(datetime.datetime.now(tz=datetime.UTC)), 21 | str(uuid.uuid4()), 22 | None, 23 | ] 24 | ) 25 | 26 | converter = converters.DataConverter() 27 | for _offset, expectation in enumerate(data): 28 | line = '\t'.join(['\\N' if e is None else e for e in expectation]) 29 | self.assertListEqual(list(converter.convert(line)), expectation) 30 | 31 | def test_noop_converter(self): 32 | converter = converters.NoOpConverter() 33 | value = '1\t\\N\tfoo\t \t' 34 | self.assertEqual(converter.convert(value), value) 35 | 36 | def test_smart_data_converter(self): 37 | def convert(value): 38 | """Convert the value to the proper string type""" 39 | if value is None: 40 | return '\\N' 41 | elif isinstance(value, datetime.datetime): 42 | return value.strftime(constants.PGDUMP_STRFTIME_FMT) 43 | return str(value) 44 | 45 | fake = faker.Faker() 46 | data = [] 47 | for row in range(0, 10): 48 | data.append( 49 | [ 50 | row, 51 | None, 52 | fake.pydecimal( 53 | positive=True, left_digits=5, 
right_digits=3 54 | ), 55 | uuid.uuid4(), 56 | ipaddress.IPv4Network(fake.ipv4(True)), 57 | ipaddress.IPv4Address(fake.ipv4()), 58 | ipaddress.IPv6Address(fake.ipv6()), 59 | maya.now() 60 | .datetime(to_timezone='US/Eastern', naive=True) 61 | .strftime(constants.PGDUMP_STRFTIME_FMT), 62 | ] 63 | ) 64 | 65 | converter = converters.SmartDataConverter() 66 | for _offset, expectation in enumerate(data): 67 | line = '\t'.join([convert(e) for e in expectation]) 68 | row = list(converter.convert(line)) 69 | self.assertListEqual(row, expectation) 70 | 71 | def test_smart_data_converter_bad_date(self): 72 | converter = converters.SmartDataConverter() 73 | row = '2019-13-45 25:34:99 00:00\t1\tfoo\t\\N' 74 | self.assertEqual( 75 | converter.convert(row), 76 | ('2019-13-45 25:34:99 00:00', 1, 'foo', None), 77 | ) 78 | 79 | 80 | class UnescapeCopyTextTestCase(unittest.TestCase): 81 | """Test cases for PostgreSQL COPY text format escape sequence handling. 82 | 83 | These tests verify that pgdumplib correctly handles all escape sequences 84 | as defined in the PostgreSQL COPY text format specification, matching 85 | the behavior of PostgreSQL's CopyReadAttributesText() function. 86 | """ 87 | 88 | def test_no_escapes(self): 89 | """Test that strings without escapes pass through unchanged""" 90 | self.assertEqual(converters.unescape_copy_text('hello'), 'hello') 91 | self.assertEqual( 92 | converters.unescape_copy_text('test data'), 'test data' 93 | ) 94 | self.assertEqual(converters.unescape_copy_text(''), '') 95 | 96 | def test_backslash_escape(self): 97 | """Test that backslash itself can be escaped""" 98 | self.assertEqual(converters.unescape_copy_text('\\\\'), '\\') 99 | self.assertEqual(converters.unescape_copy_text('a\\\\b'), 'a\\b') 100 | self.assertEqual(converters.unescape_copy_text('\\\\\\\\'), '\\\\') 101 | 102 | def test_newline_escape(self): 103 | """Test \\n escape sequence""" 104 | self.assertEqual(converters.unescape_copy_text('\\n'), '\n') 105 | self.assertEqual( 106 | converters.unescape_copy_text('line1\\nline2'), 'line1\nline2' 107 | ) 108 | self.assertEqual(converters.unescape_copy_text('a\\nb\\nc'), 'a\nb\nc') 109 | 110 | def test_carriage_return_escape(self): 111 | """Test \\r escape sequence""" 112 | self.assertEqual(converters.unescape_copy_text('\\r'), '\r') 113 | self.assertEqual(converters.unescape_copy_text('test\\r'), 'test\r') 114 | 115 | def test_tab_escape(self): 116 | """Test \\t escape sequence""" 117 | self.assertEqual(converters.unescape_copy_text('\\t'), '\t') 118 | self.assertEqual(converters.unescape_copy_text('a\\tb'), 'a\tb') 119 | 120 | def test_backspace_escape(self): 121 | """Test \\b escape sequence""" 122 | self.assertEqual(converters.unescape_copy_text('\\b'), '\b') 123 | self.assertEqual(converters.unescape_copy_text('test\\b'), 'test\b') 124 | 125 | def test_form_feed_escape(self): 126 | """Test \\f escape sequence""" 127 | self.assertEqual(converters.unescape_copy_text('\\f'), '\f') 128 | self.assertEqual(converters.unescape_copy_text('test\\f'), 'test\f') 129 | 130 | def test_vertical_tab_escape(self): 131 | """Test \\v escape sequence""" 132 | self.assertEqual(converters.unescape_copy_text('\\v'), '\v') 133 | self.assertEqual(converters.unescape_copy_text('test\\v'), 'test\v') 134 | 135 | def test_octal_escapes(self): 136 | """Test octal escape sequences \\NNN""" 137 | # Single digit octal 138 | self.assertEqual(converters.unescape_copy_text('\\0'), '\x00') 139 | self.assertEqual(converters.unescape_copy_text('\\7'), '\x07') 140 | 141 | # Two digit 
octal 142 | self.assertEqual(converters.unescape_copy_text('\\10'), '\x08') 143 | self.assertEqual(converters.unescape_copy_text('\\77'), '\x3f') 144 | 145 | # Three digit octal 146 | self.assertEqual(converters.unescape_copy_text('\\101'), 'A') 147 | self.assertEqual(converters.unescape_copy_text('\\141'), 'a') 148 | self.assertEqual(converters.unescape_copy_text('\\377'), '\xff') 149 | 150 | # Octal in context 151 | self.assertEqual( 152 | converters.unescape_copy_text('test\\101data'), 'testAdata' 153 | ) 154 | 155 | def test_hex_escapes(self): 156 | """Test hex escape sequences \\xNN""" 157 | # Single digit hex 158 | self.assertEqual(converters.unescape_copy_text('\\x0'), '\x00') 159 | self.assertEqual(converters.unescape_copy_text('\\xA'), '\x0a') 160 | self.assertEqual(converters.unescape_copy_text('\\xa'), '\x0a') 161 | 162 | # Two digit hex 163 | self.assertEqual(converters.unescape_copy_text('\\x00'), '\x00') 164 | self.assertEqual(converters.unescape_copy_text('\\x41'), 'A') 165 | self.assertEqual(converters.unescape_copy_text('\\xFF'), '\xff') 166 | self.assertEqual(converters.unescape_copy_text('\\xff'), '\xff') 167 | 168 | # Hex in context 169 | self.assertEqual( 170 | converters.unescape_copy_text('test\\x41data'), 'testAdata' 171 | ) 172 | 173 | def test_literal_escapes(self): 174 | """Test that unknown escape sequences are taken literally""" 175 | self.assertEqual(converters.unescape_copy_text('\\z'), 'z') 176 | self.assertEqual(converters.unescape_copy_text('\\@'), '@') 177 | self.assertEqual(converters.unescape_copy_text('\\?'), '?') 178 | 179 | def test_incomplete_hex_escapes(self): 180 | """Test edge cases with incomplete or invalid hex escapes""" 181 | # \x with no digits - 'x' is taken literally 182 | self.assertEqual(converters.unescape_copy_text('\\x'), 'x') 183 | # \x with non-hex character - 'x' and following char are literal 184 | self.assertEqual(converters.unescape_copy_text('\\xG'), 'xG') 185 | self.assertEqual(converters.unescape_copy_text('\\xZ9'), 'xZ9') 186 | 187 | def test_trailing_backslash(self): 188 | """Test that a trailing backslash is preserved""" 189 | self.assertEqual(converters.unescape_copy_text('foo\\'), 'foo\\') 190 | self.assertEqual(converters.unescape_copy_text('\\'), '\\') 191 | 192 | def test_combined_escapes(self): 193 | """Test multiple escape sequences in one string""" 194 | self.assertEqual( 195 | converters.unescape_copy_text('a\\tb\\nc\\\\d'), 196 | 'a\tb\nc\\d', 197 | ) 198 | self.assertEqual( 199 | converters.unescape_copy_text('\\x41\\x42\\x43'), 'ABC' 200 | ) 201 | self.assertEqual( 202 | converters.unescape_copy_text('\\101\\102\\103'), 'ABC' 203 | ) 204 | 205 | def test_null_marker_not_unescaped(self): 206 | """Test that \\N is NOT unescaped by this function. 207 | 208 | The NULL marker \\N must be checked before calling unescape_copy_text, 209 | as PostgreSQL does in CopyReadAttributesText(). 
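        ``DataConverter.convert`` follows the same order: it splits the row
        on tabs, maps ``\\N`` columns to None, and only unescapes the
        remaining columns.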
210 | """ 211 | # unescape_copy_text should treat \N as literal N 212 | self.assertEqual(converters.unescape_copy_text('\\N'), 'N') 213 | 214 | def test_converter_preserves_null(self): 215 | """Test that DataConverter checks for \\N before unescaping""" 216 | converter = converters.DataConverter() 217 | # Use actual tab character as separator, \\N as NULL marker 218 | row = 'value1\tvalue2\t\\N\tvalue3' 219 | result = converter.convert(row) 220 | # The \\N should become None, not be unescaped to 'N' 221 | self.assertEqual(result[2], None) 222 | 223 | def test_converter_unescapes_data(self): 224 | """Test that DataConverter properly unescapes field data""" 225 | converter = converters.DataConverter() 226 | 227 | # Test backslash 228 | row = 'C:\\\\Windows\\\\System' 229 | result = converter.convert(row) 230 | self.assertEqual(result[0], 'C:\\Windows\\System') 231 | 232 | # Test newline in data 233 | row = 'line1\\nline2' 234 | result = converter.convert(row) 235 | self.assertEqual(result[0], 'line1\nline2') 236 | 237 | # Test tab-separated values with escaped tabs in data 238 | row = 'col1\\tdata\tcol2' 239 | result = converter.convert(row) 240 | self.assertEqual(result[0], 'col1\tdata') 241 | self.assertEqual(result[1], 'col2') 242 | 243 | def test_smart_converter_unescapes_before_type_conversion(self): 244 | """Test that SmartDataConverter unescapes before type detection""" 245 | converter = converters.SmartDataConverter() 246 | 247 | # Test that escaped characters don't interfere with type detection 248 | # Use actual tab as field separator 249 | row = '42\t\\N\thello\\nworld' 250 | result = converter.convert(row) 251 | self.assertEqual(result[0], 42) # int 252 | self.assertEqual(result[1], None) # NULL 253 | self.assertEqual(result[2], 'hello\nworld') # string with newline 254 | -------------------------------------------------------------------------------- /tests/test_dump.py: -------------------------------------------------------------------------------- 1 | """Test Reader""" 2 | 3 | import dataclasses 4 | import datetime 5 | import logging 6 | import os 7 | import pathlib 8 | import re 9 | import subprocess 10 | import tempfile 11 | import unittest 12 | import uuid 13 | from unittest import mock 14 | 15 | import dotenv 16 | import psycopg 17 | 18 | import pgdumplib 19 | from pgdumplib import constants, converters, dump, exceptions 20 | 21 | dotenv.load_dotenv() 22 | 23 | LOGGER = logging.getLogger(__name__) 24 | 25 | BLOB_COUNT_SQL = 'SELECT COUNT(*) FROM test.users WHERE icon IS NOT NULL;' 26 | 27 | PATTERNS = { 28 | 'timestamp': re.compile(r'\s+Archive created at (.*)\n'), 29 | 'dbname': re.compile(r'\s+dbname: (.*)\n'), 30 | 'compression': re.compile(r'\s+Compression: (.*)\n'), 31 | 'format': re.compile(r'\s+Format: (.*)\n'), 32 | 'integer': re.compile(r'\s+Integer: (.*)\n'), 33 | 'offset': re.compile(r'\s+Offset: (.*)\n'), 34 | 'server_version': re.compile(r'\s+Dumped from database version: (.*)\n'), 35 | 'pg_dump_version': re.compile(r'\s+Dumped by [\w_-]+ version: (.*)\n'), 36 | 'entry_count': re.compile(r'\s+TOC Entries: (.*)\n'), 37 | 'dump_version': re.compile(r'\s+Dump Version: (.*)\n'), 38 | } 39 | 40 | 41 | @dataclasses.dataclass 42 | class DumpInfo: 43 | timestamp: datetime.datetime 44 | dbname: str 45 | compression: str 46 | format: str 47 | integer: str 48 | offset: str 49 | dump_version: str 50 | server_version: str 51 | pg_dump_version: str 52 | entry_count: int 53 | 54 | 55 | class TestCase(unittest.TestCase): 56 | PATH = 'dump.not-compressed' 57 | CONVERTER = 
converters.DataConverter 58 | 59 | @classmethod 60 | def setUpClass(cls) -> None: 61 | cls.dump_path = pathlib.Path('build') / 'data' / cls.PATH 62 | cls.dump = dump.Dump(converter=cls.CONVERTER).load(cls.dump_path) 63 | 64 | def test_table_data(self): 65 | data = [] 66 | for line in self.dump.table_data('public', 'pgbench_accounts'): 67 | data.append(line) 68 | self.assertEqual(len(data), 100000) 69 | 70 | def test_table_data_empty(self): 71 | data = [] 72 | for line in self.dump.table_data('test', 'empty_table'): 73 | data.append(line) 74 | self.assertEqual(len(data), 0) 75 | 76 | def test_read_dump_entity_not_found(self): 77 | with self.assertRaises(exceptions.EntityNotFoundError): 78 | for line in self.dump.table_data('public', 'foo'): 79 | LOGGER.debug('Line: %r', line) 80 | 81 | def test_lookup_entry(self): 82 | entry = self.dump.lookup_entry( 83 | constants.TABLE, 'public', 'pgbench_accounts' 84 | ) 85 | self.assertEqual(entry.namespace, 'public') 86 | self.assertEqual(entry.tag, 'pgbench_accounts') 87 | self.assertEqual(entry.section, constants.SECTION_PRE_DATA) 88 | 89 | def test_lookup_entry_not_found(self): 90 | self.assertIsNone( 91 | self.dump.lookup_entry(constants.TABLE, 'public', 'foo') 92 | ) 93 | 94 | def test_lookup_entry_invalid_desc(self): 95 | with self.assertRaises(ValueError): 96 | self.dump.lookup_entry('foo', 'public', 'pgbench_accounts') 97 | 98 | def test_get_entry(self): 99 | entry = self.dump.lookup_entry( 100 | constants.TABLE, 'public', 'pgbench_accounts' 101 | ) 102 | self.assertEqual(self.dump.get_entry(entry.dump_id), entry) 103 | 104 | def test_get_entry_not_found(self): 105 | dump_id = max(entry.dump_id for entry in self.dump.entries) + 100 106 | self.assertIsNone(self.dump.get_entry(dump_id)) 107 | 108 | def test_read_blobs(self): 109 | conn = psycopg.connect(os.environ['POSTGRES_URI'], autocommit=True) 110 | cursor = conn.cursor() 111 | cursor.execute(BLOB_COUNT_SQL) 112 | expectation = cursor.fetchone()[0] 113 | conn.close() 114 | 115 | dmp = pgdumplib.load(self.dump_path, self.CONVERTER) 116 | blobs = [] 117 | for oid, blob in dmp.blobs(): 118 | self.assertIsInstance(oid, int) 119 | self.assertIsInstance(blob, bytes) 120 | blobs.append((oid, blob)) 121 | self.assertEqual(len(blobs), expectation) 122 | 123 | 124 | class CompressedTestCase(TestCase): 125 | PATH = 'dump.compressed' 126 | 127 | 128 | class InsertsTestCase(TestCase): 129 | CONVERTER = converters.NoOpConverter 130 | PATH = 'dump.inserts' 131 | 132 | def test_read_dump_data(self): 133 | count = 0 134 | for line in self.dump.table_data('public', 'pgbench_accounts'): 135 | self.assertTrue( 136 | line.startswith('INSERT INTO public.pgbench_accounts'), 137 | f'Unexpected start @ row {count}: {line!r}', 138 | ) 139 | count += 1 140 | self.assertEqual(count, 100000) 141 | 142 | 143 | class NoDataTestCase(TestCase): 144 | HAS_DATA = False 145 | PATH = 'dump.no-data' 146 | 147 | def test_table_data(self): 148 | with self.assertRaises(exceptions.EntityNotFoundError): 149 | super().test_table_data() 150 | 151 | def test_table_data_empty(self): 152 | with self.assertRaises(exceptions.EntityNotFoundError): 153 | super().test_table_data_empty() 154 | 155 | def test_read_blobs(self): 156 | self.assertEqual(len(list(self.dump.blobs())), 0) 157 | 158 | 159 | class ErrorsTestCase(unittest.TestCase): 160 | def test_missing_file_raises_value_error(self): 161 | path = pathlib.Path(tempfile.gettempdir()) / str(uuid.uuid4()) 162 | with self.assertRaises(ValueError): 163 | pgdumplib.load(path) 164 | 165 | def 
test_min_version_failure_raises(self): 166 | min_ver = ( 167 | constants.MIN_VER[0], 168 | constants.MIN_VER[1] + 10, 169 | constants.MIN_VER[2], 170 | ) 171 | LOGGER.debug('Setting pgdumplib.constants.MIN_VER to %s', min_ver) 172 | with mock.patch('pgdumplib.constants.MIN_VER', min_ver): 173 | with self.assertRaises(ValueError): 174 | pgdumplib.load('build/data/dump.not-compressed') 175 | 176 | def test_max_version_failure_raises(self): 177 | max_ver = (0, constants.MAX_VER[1], constants.MAX_VER[2]) 178 | LOGGER.debug('Setting pgdumplib.constants.MAX_VER to %s', max_ver) 179 | with mock.patch('pgdumplib.constants.MAX_VER', max_ver): 180 | with self.assertRaises(ValueError): 181 | pgdumplib.load('build/data/dump.not-compressed') 182 | 183 | def test_invalid_dump_file(self): 184 | with tempfile.NamedTemporaryFile('wb') as temp: 185 | temp.write(b'PGBAD') 186 | with open('build/data/dump.not-compressed', 'rb') as handle: 187 | handle.read(5) 188 | temp.write(handle.read()) 189 | 190 | with self.assertRaises(ValueError): 191 | pgdumplib.load(temp.name) 192 | 193 | def test_plain_sql_file_with_comment(self): 194 | """Test plain SQL files with comments get helpful error message""" 195 | with tempfile.NamedTemporaryFile( 196 | 'w', delete=False, suffix='.sql' 197 | ) as temp: 198 | temp.write('-- PostgreSQL database dump\n') 199 | temp.write('CREATE TABLE test (id INT);\n') 200 | temp_name = temp.name 201 | 202 | try: 203 | with self.assertRaises(ValueError) as context: 204 | pgdumplib.load(temp_name) 205 | self.assertIn('plain SQL text file', str(context.exception)) 206 | self.assertIn('pg_dump -Fc', str(context.exception)) 207 | finally: 208 | os.unlink(temp_name) 209 | 210 | def test_plain_sql_file_with_create(self): 211 | """Test plain SQL files starting with CREATE get helpful error""" 212 | with tempfile.NamedTemporaryFile( 213 | 'w', delete=False, suffix='.sql' 214 | ) as temp: 215 | temp.write('CREATE DATABASE test;\n') 216 | temp_name = temp.name 217 | 218 | try: 219 | with self.assertRaises(ValueError) as context: 220 | pgdumplib.load(temp_name) 221 | self.assertIn('plain SQL text file', str(context.exception)) 222 | self.assertIn('pg_dump -Fc', str(context.exception)) 223 | finally: 224 | os.unlink(temp_name) 225 | 226 | def test_invalid_binary_format(self): 227 | """Test that invalid binary formats get a helpful error message""" 228 | with tempfile.NamedTemporaryFile('wb', delete=False) as temp: 229 | temp.write(b'\x00\x01\x02\x03\x04') 230 | temp_name = temp.name 231 | 232 | try: 233 | with self.assertRaises(ValueError) as context: 234 | pgdumplib.load(temp_name) 235 | self.assertIn('Invalid archive', str(context.exception)) 236 | self.assertIn('pg_dump -Fc', str(context.exception)) 237 | finally: 238 | os.unlink(temp_name) 239 | 240 | 241 | class NewDumpTestCase(unittest.TestCase): 242 | def test_pgdumplib_new(self): 243 | dmp = pgdumplib.new('test', 'UTF8', converters.SmartDataConverter) 244 | self.assertIsInstance(dmp, dump.Dump) 245 | self.assertIsInstance(dmp._converter, converters.SmartDataConverter) 246 | 247 | 248 | class RestoreComparisonTestCase(unittest.TestCase): 249 | PATH = 'dump.not-compressed' 250 | 251 | @classmethod 252 | def setUpClass(cls) -> None: 253 | cls.local_path = pathlib.Path('build') / 'data' / cls.PATH 254 | cls.info = cls._read_dump_info(cls.local_path) 255 | LOGGER.debug('Info: %r', cls.info) 256 | 257 | def setUp(self) -> None: 258 | self.dump = self._read_dump() 259 | LOGGER.debug('Dump: %r', self.dump) 260 | 261 | def _read_dump(self): 262 | return 
pgdumplib.load(self.local_path) 263 | 264 | @classmethod 265 | def _read_dump_info(cls, remote_path) -> DumpInfo: 266 | restore = subprocess.run( # noqa: S603 267 | ['pg_restore', '-l', str(remote_path)], # noqa: S607 268 | check=True, 269 | capture_output=True, 270 | ) 271 | stdout = restore.stdout.decode('utf-8') 272 | data = {} 273 | for key, pattern in PATTERNS.items(): 274 | match = pattern.findall(stdout) 275 | if not match: 276 | LOGGER.warning('No match for %s', key) 277 | elif key == 'compression': 278 | # pg_restore outputs 'none', '0', or compression level 279 | data[key] = match[0] not in ('0', 'none') 280 | elif key == 'entry_count': 281 | data[key] = int(match[0]) 282 | else: 283 | data[key] = match[0] 284 | return DumpInfo(**data) 285 | 286 | def test_dump_exists(self): 287 | self.assertIsNotNone(self.dump) 288 | 289 | def test_toc_compression(self): 290 | self.assertEqual( 291 | self.dump.compression_algorithm != constants.COMPRESSION_NONE, 292 | self.info.compression, 293 | ) 294 | 295 | def test_toc_dbname(self): 296 | self.assertEqual(self.dump.dbname, 'postgres') 297 | 298 | def test_toc_dump_version(self): 299 | self.assertEqual(self.dump.dump_version, self.info.pg_dump_version) 300 | 301 | def test_toc_entry_count(self): 302 | self.assertEqual(len(self.dump.entries), self.info.entry_count) 303 | 304 | def test_toc_server_version(self): 305 | self.assertEqual(self.dump.server_version, self.info.server_version) 306 | 307 | # def test_toc_timestamp(self): 308 | # self.assertEqual( 309 | # self.dump.timestamp.isoformat(), self.info.timestamp.isoformat()) 310 | 311 | 312 | class RestoreComparisonCompressedTestCase(RestoreComparisonTestCase): 313 | PATH = 'dump.compressed' 314 | 315 | 316 | class RestoreComparisonNoDataTestCase(RestoreComparisonTestCase): 317 | HAS_DATA = False 318 | PATH = 'dump.no-data' 319 | 320 | 321 | class RestoreComparisonDataOnlyTestCase(RestoreComparisonTestCase): 322 | PATH = 'dump.data-only' 323 | 324 | 325 | class KVersionTestCase(unittest.TestCase): 326 | def test_default(self): 327 | instance = dump.Dump() 328 | self.assertEqual(instance.version, (1, 14, 0)) 329 | 330 | def test_postgres_9_0_1(self): 331 | instance = dump.Dump(appear_as='9.0.1') 332 | self.assertEqual(instance.version, (1, 12, 0)) 333 | 334 | def test_postgres_9_6_4(self): 335 | instance = dump.Dump(appear_as='9.6.4') 336 | self.assertEqual(instance.version, (1, 12, 0)) 337 | 338 | def test_postgres_10_1(self): 339 | instance = dump.Dump(appear_as='10.1') 340 | self.assertEqual(instance.version, (1, 12, 0)) 341 | 342 | def test_postgres_10_3(self): 343 | instance = dump.Dump(appear_as='10.3') 344 | self.assertEqual(instance.version, (1, 13, 0)) 345 | 346 | def test_postgres_11(self): 347 | instance = dump.Dump(appear_as='11.0') 348 | self.assertEqual(instance.version, (1, 13, 0)) 349 | 350 | def test_postgres_12(self): 351 | instance = dump.Dump(appear_as='12.0') 352 | self.assertEqual(instance.version, (1, 14, 0)) 353 | 354 | def test_postgres_8_4_0(self): 355 | with self.assertRaises(RuntimeError): 356 | dump.Dump(appear_as='8.4.0') 357 | 358 | def test_postgres_100(self): 359 | with self.assertRaises(RuntimeError): 360 | dump.Dump(appear_as='100.0') 361 | -------------------------------------------------------------------------------- /pgdumplib/dump.py: -------------------------------------------------------------------------------- 1 | """ 2 | The :py:class:`~pgdumplib.dump.Dump` class exposes methods to 3 | :py:meth:`load ` an existing dump, 4 | to :py:meth:`add an entry 
` to a dump, 5 | to :py:meth:`add table data ` to a dump, 6 | to :py:meth:`add blob data ` to a dump, 7 | and to :py:meth:`save ` a new dump. 8 | 9 | There are :doc:`converters` that are available to format the data that is 10 | returned by :py:meth:`~pgdumplib.dump.Dump.read_data`. The converter 11 | is passed in during construction of a new :py:class:`~pgdumplib.dump.Dump`, 12 | and is also available as an argument to :py:func:`pgdumplib.load`. 13 | 14 | The default converter, :py:class:`~pgdumplib.converters.DataConverter` will 15 | return all fields as strings, only replacing ``NULL`` with 16 | :py:const:`None`. The :py:class:`~pgdumplib.converters.SmartDataConverter` 17 | will attempt to convert all columns to native Python data types. 18 | 19 | When loading or creating a dump, the table and blob data are stored in 20 | gzip compressed data files in a temporary directory that is automatically 21 | cleaned up when the :py:class:`~pgdumplib.dump.Dump` instance is released. 22 | 23 | """ 24 | 25 | import contextlib 26 | import datetime 27 | import gzip 28 | import io 29 | import logging 30 | import os 31 | import pathlib 32 | import re 33 | import struct 34 | import tempfile 35 | import typing 36 | import zlib 37 | from collections import abc 38 | 39 | import toposort # type: ignore[import-untyped] 40 | 41 | from pgdumplib import constants, converters, exceptions, models, version 42 | 43 | LOGGER = logging.getLogger(__name__) 44 | 45 | ENCODING_PATTERN = re.compile(r"^.*=\s+'(.*)'") 46 | 47 | VERSION_INFO = '{} (pgdumplib {})' 48 | 49 | 50 | class TableData: 51 | """Used to encapsulate table data using temporary file and allowing 52 | for an API that allows for the appending of data one row at a time. 53 | 54 | Do not create this class directly, instead invoke 55 | :py:meth:`~pgdumplib.dump.Dump.table_data_writer`. 56 | 57 | """ 58 | 59 | def __init__(self, dump_id: int, tempdir: str, encoding: str): 60 | self.dump_id = dump_id 61 | self._encoding = encoding 62 | self._path = pathlib.Path(tempdir) / f'{dump_id}.gz' 63 | self._handle = gzip.open(self._path, 'wb') 64 | 65 | def append(self, *args) -> None: 66 | """Append a row to the table data, passing columns in as args 67 | 68 | Column order must match the order specified when 69 | :py:meth:`~pgdumplib.dump.Dump.table_data_writer` was invoked. 70 | 71 | All columns will be coerced to a string with special attention 72 | paid to ``None``, converting it to the null marker (``\\N``) and 73 | :py:class:`datetime.datetime` objects, which will have the proper 74 | pg_dump timestamp format applied to them. 75 | 76 | """ 77 | row = '\t'.join([self._convert(c) for c in args]) 78 | self._handle.write(f'{row}\n'.encode(self._encoding)) 79 | 80 | def finish(self) -> None: 81 | """Invoked prior to saving a dump to close the temporary data 82 | handle and switch the class into read-only mode. 83 | 84 | For use by :py:class:`pgdumplib.dump.Dump` only. 85 | 86 | """ 87 | if not self._handle.closed: 88 | self._handle.close() 89 | self._handle = gzip.open(self._path, 'rb') 90 | 91 | def read(self) -> bytes: 92 | """Read the data from disk for writing to the dump 93 | 94 | For use by :py:class:`pgdumplib.dump.Dump` only. 
95 | 96 | """ 97 | self._handle.seek(0) 98 | return self._handle.read() 99 | 100 | @property 101 | def size(self) -> int: 102 | """Return the current size of the data on disk""" 103 | self._handle.seek(0, io.SEEK_END) # Seek to end to figure out size 104 | size = self._handle.tell() 105 | self._handle.seek(0) 106 | return size 107 | 108 | @staticmethod 109 | def _convert(column: typing.Any) -> str: 110 | """Convert the column to a string 111 | 112 | :param column: The column to convert 113 | 114 | """ 115 | if isinstance(column, datetime.datetime): 116 | return column.strftime(constants.PGDUMP_STRFTIME_FMT) 117 | elif column is None: 118 | return '\\N' 119 | return str(column) 120 | 121 | 122 | class Dump: 123 | """Create a new instance of the :py:class:`~pgdumplib.dump.Dump` class 124 | 125 | Once created, the instance of :py:class:`~pgdumplib.dump.Dump` can 126 | be used to read existing dumps or to create new ones. 127 | 128 | :param str dbname: The database name for the dump (Default: ``pgdumplib``) 129 | :param str encoding: The data encoding (Default: ``UTF8``) 130 | :param converter: The data converter class to use 131 | (Default: :py:class:`pgdumplib.converters.DataConverter`) 132 | 133 | """ 134 | 135 | def __init__( 136 | self, 137 | dbname: str = 'pgdumplib', 138 | encoding: str = 'UTF8', 139 | converter: typing.Any = None, 140 | appear_as: str = '12.0', 141 | ): 142 | self.compression_algorithm = constants.COMPRESSION_NONE 143 | self.dbname = dbname 144 | self.dump_version = VERSION_INFO.format(appear_as, version) 145 | self.encoding = encoding 146 | self.entries = [ 147 | models.Entry( 148 | dump_id=1, 149 | tag=constants.ENCODING, 150 | desc=constants.ENCODING, 151 | defn=f"SET client_encoding = '{self.encoding}';\n", 152 | ), 153 | models.Entry( 154 | dump_id=2, 155 | tag='STDSTRINGS', 156 | desc='STDSTRINGS', 157 | defn="SET standard_conforming_strings = 'on';\n", 158 | ), 159 | models.Entry( 160 | dump_id=3, 161 | tag='SEARCHPATH', 162 | desc='SEARCHPATH', 163 | defn='SELECT pg_catalog.set_config(' 164 | "'search_path', '', false);\n", 165 | ), 166 | ] 167 | self.server_version = self.dump_version 168 | self.timestamp = datetime.datetime.now(tz=datetime.UTC) 169 | 170 | converter = converter or converters.DataConverter 171 | self._converter: converters.DataConverter = converter() 172 | self._format: str = 'Custom' 173 | self._handle: io.BufferedReader | io.BufferedWriter | None = None 174 | self._intsize: int = 4 175 | self._offsize: int = 8 176 | self._temp_dir = tempfile.TemporaryDirectory() 177 | parts = tuple(int(v) for v in appear_as.split('.')) 178 | if len(parts) < 2: 179 | raise ValueError(f'Invalid appear_as version: {appear_as}') 180 | k_version = self._get_k_version(parts) 181 | self._vmaj: int = k_version[0] 182 | self._vmin: int = k_version[1] 183 | self._vrev: int = k_version[2] 184 | self._writers: dict[int, TableData] = {} 185 | 186 | def __repr__(self) -> str: 187 | return ( 188 | f'' 191 | ) 192 | 193 | def add_entry( 194 | self, 195 | desc: str, 196 | namespace: str | None = None, 197 | tag: str | None = None, 198 | owner: str | None = None, 199 | defn: str | None = None, 200 | drop_stmt: str | None = None, 201 | copy_stmt: str | None = None, 202 | dependencies: list[int] | None = None, 203 | tablespace: str | None = None, 204 | tableam: str | None = None, 205 | dump_id: int | None = None, 206 | ) -> models.Entry: 207 | """Add an entry to the dump 208 | 209 | A :py:exc:`ValueError` will be raised if `desc` is not value that 210 | is known in 
:py:module:`pgdumplib.constants`. 211 | 212 | The section is 213 | 214 | When adding data, use :py:meth:`~Dump.table_data_writer` instead of 215 | invoking :py:meth:`~Dump.add_entry` directly. 216 | 217 | If ``dependencies`` are specified, they will be validated and if a 218 | ``dump_id`` is specified and no entry is found with that ``dump_id``, 219 | a :py:exc:`ValueError` will be raised. 220 | 221 | Other omitted values will be set to the default values will be set to 222 | the defaults specified in the :py:class:`pgdumplib.dump.Entry` 223 | class. 224 | 225 | The ``dump_id`` will be auto-calculated based upon the existing entries 226 | if it is not specified. 227 | 228 | .. note:: The creation of ad-hoc blobs is not supported. 229 | 230 | :param str desc: The entry description 231 | :param str namespace: The namespace of the entry 232 | :param str tag: The name/table/relation/etc of the entry 233 | :param str owner: The owner of the object in Postgres 234 | :param str defn: The DDL definition for the entry 235 | :param drop_stmt: A drop statement used to drop the entry before 236 | :param copy_stmt: A copy statement used when there is a corresponding 237 | data section. 238 | :param list dependencies: A list of dump_ids of objects that the entry 239 | is dependent upon. 240 | :param str tablespace: The tablespace to use 241 | :param str tableam: The table access method 242 | :param int dump_id: The dump id, will be auto-calculated if left empty 243 | :raises: :py:exc:`ValueError` 244 | :rtype: pgdumplib.dump.Entry 245 | 246 | """ 247 | if desc not in constants.SECTION_MAPPING: 248 | raise ValueError(f'Invalid desc: {desc}') 249 | 250 | if dump_id is not None and dump_id < 1: 251 | raise ValueError('dump_id must be greater than 1') 252 | 253 | dump_ids = [e.dump_id for e in self.entries] 254 | 255 | if dump_id and dump_id in dump_ids: 256 | raise ValueError('dump_id {!r} is already assigned', dump_id) 257 | 258 | for dependency in dependencies or []: 259 | if dependency not in dump_ids: 260 | raise ValueError( 261 | f'Dependency dump_id {dependency!r} not found' 262 | ) 263 | self.entries.append( 264 | models.Entry( 265 | dump_id=dump_id or self._next_dump_id(), 266 | had_dumper=False, 267 | table_oid='', 268 | oid='', 269 | tag=tag or '', 270 | desc=desc, 271 | defn=defn or '', 272 | drop_stmt=drop_stmt or '', 273 | copy_stmt=copy_stmt or '', 274 | namespace=namespace or '', 275 | tablespace=tablespace or None, 276 | tableam=tableam or None, 277 | relkind=None, 278 | owner=owner or '', 279 | with_oids=False, 280 | dependencies=dependencies or [], 281 | ) 282 | ) 283 | return self.entries[-1] 284 | 285 | def blobs(self) -> typing.Generator[tuple[int, bytes], None, None]: 286 | """Iterator that returns each blob in the dump 287 | 288 | :rtype: tuple(int, bytes) 289 | 290 | """ 291 | 292 | def read_oid(fd: io.BufferedReader) -> int | None: 293 | """Small helper function to deduplicate code""" 294 | try: 295 | return struct.unpack('I', fd.read(4))[0] 296 | except struct.error: 297 | return None 298 | 299 | for entry in self._data_entries: 300 | if entry.desc == constants.BLOBS: 301 | with self._tempfile(entry.dump_id, 'rb') as handle: 302 | oid: int | None = read_oid(handle) 303 | while oid: 304 | length: int = struct.unpack('I', handle.read(4))[0] 305 | yield oid, handle.read(length) 306 | oid = read_oid(handle) 307 | 308 | def get_entry(self, dump_id: int) -> models.Entry | None: 309 | """Return the entry for the given `dump_id` 310 | 311 | :param int dump_id: The dump ID of the entry to 
return. 312 | 313 | """ 314 | for entry in self.entries: 315 | if entry.dump_id == dump_id: 316 | return entry 317 | return None 318 | 319 | def load(self, path: str | os.PathLike) -> typing.Self: 320 | """Load the Dumpfile, including extracting all data into a temporary 321 | directory 322 | 323 | :param os.PathLike path: The path of the dump to load 324 | :raises: :py:exc:`RuntimeError` 325 | :raises: :py:exc:`ValueError` 326 | 327 | """ 328 | if not pathlib.Path(path).exists(): 329 | raise ValueError(f'Path {path!r} does not exist') 330 | 331 | LOGGER.debug('Loading dump file from %s', path) 332 | 333 | self.entries = [] # Wipe out pre-existing entries 334 | self._handle = open(path, 'rb') 335 | self._read_header() 336 | if not constants.MIN_VER <= self.version <= constants.MAX_VER: 337 | raise ValueError( 338 | 'Unsupported backup version: {}.{}.{}'.format(*self.version) 339 | ) 340 | 341 | if self.version >= (1, 15, 0): 342 | self.compression_algorithm = constants.COMPRESSION_ALGORITHMS[ 343 | self._compression_algorithm 344 | ] 345 | 346 | if ( 347 | self.compression_algorithm 348 | not in constants.SUPPORTED_COMPRESSION_ALGORITHMS 349 | ): 350 | raise ValueError( 351 | 'Unsupported compression algorithm: {}'.format( 352 | *self.compression_algorithm 353 | ) 354 | ) 355 | else: 356 | self.compression_algorithm = ( 357 | constants.COMPRESSION_GZIP 358 | if self._read_int() != 0 359 | else constants.COMPRESSION_NONE 360 | ) 361 | 362 | self.timestamp = self._read_timestamp() 363 | self.dbname = self._read_bytes().decode(self.encoding) 364 | self.server_version = self._read_bytes().decode(self.encoding) 365 | self.dump_version = self._read_bytes().decode(self.encoding) 366 | 367 | self._read_entries() 368 | self._set_encoding() 369 | 370 | # Cache table data and blobs 371 | _last_pos = self._handle.tell() 372 | 373 | for entry in self._data_entries: 374 | if entry.data_state == constants.K_OFFSET_NO_DATA: 375 | continue 376 | elif entry.data_state != constants.K_OFFSET_POS_SET: 377 | raise RuntimeError('Unsupported data format') 378 | self._handle.seek(entry.offset, io.SEEK_SET) 379 | block_type, dump_id = self._read_block_header() 380 | if not dump_id or dump_id != entry.dump_id: 381 | raise RuntimeError( 382 | f'Dump IDs do not match ({dump_id} != {entry.dump_id}' 383 | ) 384 | if block_type == constants.BLK_DATA: 385 | self._cache_table_data(dump_id) 386 | elif block_type == constants.BLK_BLOBS: 387 | self._cache_blobs(dump_id) 388 | else: 389 | raise RuntimeError(f'Unknown block type: {block_type!r}') 390 | return self 391 | 392 | def lookup_entry( 393 | self, desc: str, namespace: str, tag: str 394 | ) -> models.Entry | None: 395 | """Return the entry for the given namespace and tag 396 | 397 | :param str desc: The desc / object type of the entry 398 | :param str namespace: The namespace of the entry 399 | :param str tag: The tag/relation/table name 400 | :raises: :py:exc:`ValueError` 401 | :rtype: pgdumplib.dump.Entry or None 402 | 403 | """ 404 | if desc not in constants.SECTION_MAPPING: 405 | raise ValueError(f'Invalid desc: {desc}') 406 | for entry in [e for e in self.entries if e.desc == desc]: 407 | if entry.namespace == namespace and entry.tag == tag: 408 | return entry 409 | return None 410 | 411 | def save(self, path: str | os.PathLike) -> None: 412 | """Save the Dump file to the specified path 413 | 414 | :param path: The path to save the dump to 415 | :type path: str or os.PathLike 416 | 417 | """ 418 | if self._handle is not None and not self._handle.closed: 419 | 
self._handle.close() 420 | self.compression_algorithm = constants.COMPRESSION_NONE 421 | self._handle = open(path, 'wb') 422 | self._save() 423 | self._handle.close() 424 | 425 | def table_data( 426 | self, namespace: str, table: str 427 | ) -> typing.Generator[str | tuple[typing.Any, ...], None, None]: 428 | """Iterator that returns data for the given namespace and table 429 | 430 | :param str namespace: The namespace/schema for the table 431 | :param str table: The table name 432 | :raises: :py:exc:`pgdumplib.exceptions.EntityNotFoundError` 433 | 434 | """ 435 | for entry in self._data_entries: 436 | if entry.namespace == namespace and entry.tag == table: 437 | for row in self._read_table_data(entry.dump_id): 438 | yield self._converter.convert(row) 439 | return 440 | raise exceptions.EntityNotFoundError(namespace=namespace, table=table) 441 | 442 | @contextlib.contextmanager 443 | def table_data_writer( 444 | self, entry: models.Entry, columns: abc.Sequence 445 | ) -> typing.Generator[TableData, None, None]: 446 | """A context manager that is used to return a 447 | :py:class:`~pgdumplib.dump.TableData` instance, which can be used 448 | to add table data to the dump. 449 | 450 | When invoked for a given entry containing the table definition, 451 | 452 | :param Entry entry: The entry for the table to add data for 453 | :param columns: The ordered list of table columns 454 | :type columns: list or tuple 455 | :rtype: TableData 456 | 457 | """ 458 | if entry.dump_id not in self._writers.keys(): 459 | dump_id = self._next_dump_id() 460 | self.entries.append( 461 | models.Entry( 462 | dump_id=dump_id, 463 | had_dumper=True, 464 | tag=entry.tag, 465 | desc=constants.TABLE_DATA, 466 | copy_stmt='COPY {}.{} ({}) FROM stdin;'.format( 467 | entry.namespace, entry.tag, ', '.join(columns) 468 | ), 469 | namespace=entry.namespace, 470 | owner=entry.owner, 471 | dependencies=[entry.dump_id], 472 | data_state=constants.K_OFFSET_POS_NOT_SET, 473 | ) 474 | ) 475 | self._writers[entry.dump_id] = TableData( 476 | dump_id, self._temp_dir.name, self.encoding 477 | ) 478 | yield self._writers[entry.dump_id] 479 | return None 480 | 481 | @property 482 | def version(self) -> tuple[int, int, int]: 483 | """Return the version as a tuple to make version comparisons easier. 
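        For example (the dump path is hypothetical)::

            dmp = pgdumplib.load('example.dump')
            if dmp.version >= (1, 15, 0):
                ...  # archive explicitly records its compression algorithm

        This is the same comparison :py:meth:`load` performs when deciding
        how to read the compression settings from the archive header.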
484 | 485 | :rtype: tuple 486 | 487 | """ 488 | return self._vmaj, self._vmin, self._vrev 489 | 490 | def _cache_blobs(self, dump_id: int) -> None: 491 | """Create a temp cache file for blob data 492 | 493 | :param int dump_id: The dump ID for the filename 494 | 495 | """ 496 | count = 0 497 | with self._tempfile(dump_id, 'wb') as handle: 498 | for oid, blob in self._read_blobs(): 499 | handle.write(struct.pack('I', oid)) 500 | handle.write(struct.pack('I', len(blob))) 501 | handle.write(blob) 502 | count += 1 503 | 504 | def _cache_table_data(self, dump_id: int) -> None: 505 | """Create a temp cache file for the table data 506 | 507 | :param int dump_id: The dump ID for the filename 508 | 509 | """ 510 | with self._tempfile(dump_id, 'wb') as handle: 511 | handle.write(self._read_data()) 512 | 513 | @property 514 | def _data_entries(self) -> list[models.Entry]: 515 | """Return the list of entries that are in the data section 516 | 517 | :rtype: list 518 | 519 | """ 520 | return [e for e in self.entries if e.section == constants.SECTION_DATA] 521 | 522 | @staticmethod 523 | def _get_k_version(appear_as: tuple[int, ...]) -> tuple[int, int, int]: 524 | for (min_ver, max_ver), value in constants.K_VERSION_MAP.items(): 525 | if min_ver <= appear_as <= max_ver: 526 | return value 527 | raise RuntimeError(f'Unsupported PostgreSQL version: {appear_as}') 528 | 529 | def _next_dump_id(self) -> int: 530 | """Get the next ``dump_id`` that is available for adding an entry 531 | 532 | :rtype: int 533 | 534 | """ 535 | return max(e.dump_id for e in self.entries) + 1 536 | 537 | def _read_blobs(self) -> typing.Generator[tuple[int, bytes], None, None]: 538 | """Read blobs, returning a tuple of the blob ID and the blob data 539 | 540 | :rtype: (int, bytes) 541 | :raises: :exc:`RuntimeError` 542 | 543 | """ 544 | oid = self._read_int() 545 | while oid is not None and oid > 0: 546 | data = self._read_data() 547 | yield oid, data 548 | oid = self._read_int() 549 | if oid == 0: 550 | oid = self._read_int() 551 | 552 | def _read_block_header(self) -> tuple[bytes, int | None]: 553 | """Read the block header in 554 | 555 | :rtype: bytes, int 556 | 557 | """ 558 | if self._handle is None: 559 | raise ValueError('File handle is not initialized') 560 | return self._handle.read(1), self._read_int() 561 | 562 | def _read_byte(self) -> int | None: 563 | """Read in an individual byte 564 | 565 | :rtype: int 566 | 567 | """ 568 | if self._handle is None: 569 | raise ValueError('File handle is not initialized') 570 | try: 571 | return struct.unpack('B', self._handle.read(1))[0] 572 | except struct.error: 573 | return None 574 | 575 | def _read_bytes(self) -> bytes: 576 | """Read in a byte stream 577 | 578 | :rtype: bytes 579 | 580 | """ 581 | if self._handle is None: 582 | raise ValueError('File handle is not initialized') 583 | length = self._read_int() 584 | if length and length > 0: 585 | value = self._handle.read(length) 586 | return value 587 | return b'' 588 | 589 | def _read_data(self) -> bytes: 590 | """Read a data block, returning the bytes. 
591 | 592 | :rtype: bytes 593 | 594 | """ 595 | if self.compression_algorithm != constants.COMPRESSION_NONE: 596 | return self._read_data_compressed() 597 | return self._read_data_uncompressed() 598 | 599 | def _read_data_compressed(self) -> bytes: 600 | """Read a compressed data block 601 | 602 | :rtype: bytes 603 | 604 | """ 605 | if self._handle is None: 606 | raise ValueError('File handle is not initialized') 607 | buffer = io.BytesIO() 608 | chunk = b'' 609 | decompress = zlib.decompressobj() 610 | while True: 611 | chunk_size = self._read_int() 612 | if not chunk_size: # pragma: nocover 613 | break 614 | chunk += self._handle.read(chunk_size) 615 | buffer.write(decompress.decompress(chunk)) 616 | chunk = decompress.unconsumed_tail 617 | if chunk_size < constants.ZLIB_IN_SIZE: 618 | break 619 | return buffer.getvalue() 620 | 621 | def _read_data_uncompressed(self) -> bytes: 622 | """Read an uncompressed data block 623 | 624 | :rtype: bytes 625 | 626 | """ 627 | if self._handle is None: 628 | raise ValueError('File handle is not initialized') 629 | buffer = io.BytesIO() 630 | while True: 631 | block_length = self._read_int() 632 | if not block_length or block_length <= 0: 633 | break 634 | buffer.write(self._handle.read(block_length)) 635 | return buffer.getvalue() 636 | 637 | def _read_dependencies(self) -> list[int]: 638 | """Read in the dependencies for an entry. 639 | 640 | :rtype: list 641 | 642 | """ 643 | values = set({}) 644 | while True: 645 | value = self._read_bytes() 646 | if not value: 647 | break 648 | values.add(int(value)) 649 | return sorted(values) 650 | 651 | def _read_entries(self) -> None: 652 | """Read in all of the entries""" 653 | for _i in range(0, self._read_int() or 0): 654 | self._read_entry() 655 | 656 | def _read_entry(self) -> None: 657 | """Read in an individual entry and append it to the entries stack""" 658 | dump_id = self._read_int() 659 | if dump_id is None: 660 | raise ValueError('dump_id cannot be None') 661 | had_dumper = bool(self._read_int()) 662 | table_oid = self._read_bytes().decode(self.encoding) 663 | oid = self._read_bytes().decode(self.encoding) 664 | tag = self._read_bytes().decode(self.encoding) 665 | desc = self._read_bytes().decode(self.encoding) 666 | self._read_int() # Section is mapped, no need to assign 667 | defn = self._read_bytes().decode(self.encoding) 668 | drop_stmt = self._read_bytes().decode(self.encoding) 669 | copy_stmt = self._read_bytes().decode(self.encoding) 670 | namespace = self._read_bytes().decode(self.encoding) 671 | tablespace = self._read_bytes().decode(self.encoding) 672 | # Normalize empty strings to None for consistency 673 | tablespace = tablespace if tablespace else None 674 | if self.version >= (1, 14, 0): 675 | tableam = self._read_bytes().decode(self.encoding) 676 | # Normalize empty strings to None to prevent invalid SQL 677 | # generation (e.g., SET default_table_access_method = "";) 678 | tableam = tableam if tableam else None 679 | else: 680 | tableam = None 681 | if self.version >= (1, 16, 0): 682 | relkind_val = self._read_int() 683 | relkind = chr(relkind_val) if relkind_val else None 684 | else: 685 | relkind = None 686 | owner = self._read_bytes().decode(self.encoding) 687 | with_oids = self._read_bytes() == b'true' 688 | dependencies = self._read_dependencies() 689 | data_state, offset = self._read_offset() 690 | self.entries.append( 691 | models.Entry( 692 | dump_id=dump_id, 693 | had_dumper=had_dumper, 694 | table_oid=table_oid, 695 | oid=oid, 696 | tag=tag, 697 | desc=desc, 698 | 
defn=defn, 699 | drop_stmt=drop_stmt, 700 | copy_stmt=copy_stmt, 701 | namespace=namespace, 702 | tablespace=tablespace, 703 | tableam=tableam, 704 | relkind=relkind, 705 | owner=owner, 706 | with_oids=with_oids, 707 | dependencies=dependencies, 708 | data_state=data_state or 0, 709 | offset=offset or 0, 710 | ) 711 | ) 712 | 713 | def _read_header(self) -> None: 714 | """Read in the dump header 715 | 716 | :raises: ValueError 717 | 718 | """ 719 | if self._handle is None: 720 | raise ValueError('File handle is not initialized') 721 | magic_bytes = self._handle.read(5) 722 | if magic_bytes != constants.MAGIC: 723 | # Provide helpful error messages based on file content 724 | error_msg = ( 725 | 'Invalid archive header. ' 726 | 'pgdumplib only supports custom format dumps ' 727 | 'created with pg_dump -Fc' 728 | ) 729 | try: 730 | # Try to detect plain SQL files 731 | file_start = magic_bytes.decode('ascii', errors='ignore') 732 | if file_start.startswith(('--', '/*', 'SE', 'CR', 'IN', 'DR')): 733 | error_msg = ( 734 | 'This appears to be a plain SQL text file. ' 735 | 'pgdumplib only supports custom format dumps ' 736 | 'created with pg_dump -Fc' 737 | ) 738 | elif len(file_start) == 0 or not file_start.isprintable(): 739 | error_msg = ( 740 | 'Invalid archive format. ' 741 | 'pgdumplib only supports custom format dumps ' 742 | 'created with pg_dump -Fc' 743 | ) 744 | except (UnicodeDecodeError, AttributeError): 745 | # Ignore errors from decode or isprintable on invalid data 746 | pass 747 | raise ValueError(error_msg) 748 | self._vmaj = struct.unpack('B', self._handle.read(1))[0] 749 | self._vmin = struct.unpack('B', self._handle.read(1))[0] 750 | self._vrev = struct.unpack('B', self._handle.read(1))[0] 751 | self._intsize = struct.unpack('B', self._handle.read(1))[0] 752 | self._offsize = struct.unpack('B', self._handle.read(1))[0] 753 | self._format = constants.FORMATS[ 754 | struct.unpack('B', self._handle.read(1))[0] 755 | ] 756 | LOGGER.debug( 757 | 'Archive version %i.%i.%i', self._vmaj, self._vmin, self._vrev 758 | ) 759 | # v1.15+ has compression_spec.algorithm byte 760 | if (self._vmaj, self._vmin, self._vrev) >= (1, 15, 0): 761 | self._compression_algorithm = struct.unpack( 762 | 'B', self._handle.read(1) 763 | )[0] 764 | 765 | def _read_int(self) -> int | None: 766 | """Read in a signed integer 767 | 768 | :rtype: int or None 769 | 770 | """ 771 | sign = self._read_byte() 772 | if sign is None: 773 | return None 774 | bs, bv, value = 0, 0, 0 775 | for _offset in range(0, self._intsize): 776 | bv = (self._read_byte() or 0) & 0xFF 777 | if bv != 0: 778 | value += bv << bs 779 | bs += 8 780 | return -value if sign else value 781 | 782 | def _read_offset(self) -> tuple[int, int]: 783 | """Read in the value for the length of the data stored in the file 784 | 785 | :rtype: int, int 786 | 787 | """ 788 | data_state = self._read_byte() or 0 789 | value = 0 790 | for offset in range(0, self._offsize): 791 | bv = self._read_byte() or 0 792 | value |= bv << (offset * 8) 793 | return data_state, value 794 | 795 | def _read_table_data( 796 | self, dump_id: int 797 | ) -> typing.Generator[str, None, None]: 798 | """Iterate through the data returning on row at a time 799 | 800 | :rtype: str 801 | 802 | """ 803 | try: 804 | with self._tempfile(dump_id, 'rb') as handle: 805 | for line in handle: 806 | out = (line or b'').decode(self.encoding).strip() 807 | if out.startswith('\\.') or not out: 808 | break 809 | yield out 810 | except exceptions.NoDataError: 811 | pass 812 | 813 | def 
_read_timestamp(self) -> datetime.datetime: 814 | """Read in the timestamp from handle. 815 | 816 | :rtype: datetime.datetime 817 | 818 | """ 819 | second, minute, hour, day, month, year = ( 820 | self._read_int() or 0, 821 | self._read_int() or 0, 822 | self._read_int() or 0, 823 | self._read_int() or 0, 824 | (self._read_int() or 0) + 1, 825 | (self._read_int() or 0) + 1900, 826 | ) 827 | self._read_int() # DST flag 828 | return datetime.datetime( 829 | year, month, day, hour, minute, second, 0, tzinfo=datetime.UTC 830 | ) 831 | 832 | def _save(self) -> None: 833 | """Save the dump file to disk""" 834 | self._write_toc() 835 | self._write_entries() 836 | if self._write_data(): 837 | self._write_toc() # Overwrite ToC and entries 838 | self._write_entries() 839 | 840 | def _set_encoding(self) -> None: 841 | """If the encoding is found in the dump entries, set the encoding 842 | to `self.encoding`. 843 | 844 | """ 845 | for entry in self.entries: 846 | if entry.desc == constants.ENCODING and entry.defn: 847 | match = ENCODING_PATTERN.match(entry.defn) 848 | if match: 849 | self.encoding = match.group(1) 850 | return 851 | 852 | @contextlib.contextmanager 853 | def _tempfile( 854 | self, dump_id: int, mode: str 855 | ) -> typing.Generator[typing.Any, None, None]: 856 | """Open the temp file for the specified dump_id in the specified mode 857 | 858 | :param int dump_id: The dump_id for the temp file 859 | :param str mode: The mode (rb, wb) 860 | 861 | """ 862 | path = pathlib.Path(self._temp_dir.name) / f'{dump_id}.gz' 863 | if not path.exists() and mode.startswith('r'): 864 | raise exceptions.NoDataError() 865 | with gzip.open(path, mode) as handle: 866 | try: 867 | yield handle 868 | except Exception: 869 | raise 870 | 871 | def _write_blobs(self, dump_id: int) -> int: 872 | """Write the blobs for the entry. 
873 | 874 | :param int dump_id: The entry dump ID for the blobs 875 | :rtype: int 876 | 877 | """ 878 | if self._handle is None: 879 | raise ValueError('File handle is not initialized') 880 | length = 0 881 | with self._tempfile(dump_id, 'rb') as handle: 882 | self._handle.write(constants.BLK_BLOBS) 883 | self._write_int(dump_id) 884 | while True: 885 | try: 886 | oid = struct.unpack('I', handle.read(4))[0] 887 | except struct.error: 888 | break 889 | length = struct.unpack('I', handle.read(4))[0] 890 | self._write_int(oid) 891 | self._write_int(length) 892 | self._handle.write(handle.read(length)) 893 | self._write_int(0) 894 | self._write_int(0) 895 | return length 896 | 897 | def _write_byte(self, value: int) -> None: 898 | """Write a byte to the handle 899 | 900 | :param int value: The byte value 901 | 902 | """ 903 | if self._handle is None: 904 | raise ValueError('File handle is not initialized') 905 | self._handle.write(struct.pack('B', value)) 906 | 907 | def _write_data(self) -> set[int]: 908 | """Write the data blocks, returning a set of IDs that were written""" 909 | if self._handle is None: 910 | raise ValueError('File handle is not initialized') 911 | saved = set({}) 912 | for offset, entry in enumerate(self.entries): 913 | if entry.section != constants.SECTION_DATA: 914 | continue 915 | self.entries[offset].offset = self._handle.tell() 916 | size = 0 917 | if entry.desc == constants.TABLE_DATA: 918 | size = self._write_table_data(entry.dump_id) 919 | saved.add(entry.dump_id) 920 | elif entry.desc == constants.BLOBS: 921 | size = self._write_blobs(entry.dump_id) 922 | saved.add(entry.dump_id) 923 | if size: 924 | self.entries[offset].data_state = constants.K_OFFSET_POS_SET 925 | return saved 926 | 927 | def _write_entries(self) -> None: 928 | self._write_int(len(self.entries)) 929 | saved = set({}) 930 | 931 | # Always add these entries first 932 | for entry in self.entries[0:3]: 933 | self._write_entry(entry) 934 | saved.add(entry.dump_id) 935 | 936 | saved = self._write_section( 937 | constants.SECTION_PRE_DATA, 938 | [ 939 | constants.GROUP, 940 | constants.ROLE, 941 | constants.USER, 942 | constants.SCHEMA, 943 | constants.EXTENSION, 944 | constants.AGGREGATE, 945 | constants.OPERATOR, 946 | constants.OPERATOR_CLASS, 947 | constants.CAST, 948 | constants.COLLATION, 949 | constants.CONVERSION, 950 | constants.PROCEDURAL_LANGUAGE, 951 | constants.FOREIGN_DATA_WRAPPER, 952 | constants.FOREIGN_SERVER, 953 | constants.SERVER, 954 | constants.DOMAIN, 955 | constants.TYPE, 956 | constants.SHELL_TYPE, 957 | ], 958 | saved, 959 | ) 960 | 961 | saved = self._write_section(constants.SECTION_DATA, [], saved) 962 | 963 | saved = self._write_section( 964 | constants.SECTION_POST_DATA, 965 | [ 966 | constants.CHECK_CONSTRAINT, 967 | constants.CONSTRAINT, 968 | constants.INDEX, 969 | ], 970 | saved, 971 | ) 972 | 973 | saved = self._write_section(constants.SECTION_NONE, [], saved) 974 | LOGGER.debug('Wrote %i of %i entries', len(saved), len(self.entries)) 975 | 976 | def _write_entry(self, entry: models.Entry) -> None: 977 | """Write the entry 978 | 979 | :param pgdumplib.dump.Entry entry: The entry to write 980 | 981 | """ 982 | LOGGER.debug('Writing %r', entry) 983 | self._write_int(entry.dump_id) 984 | self._write_int(int(entry.had_dumper)) 985 | self._write_str(entry.table_oid or '0') 986 | self._write_str(entry.oid or '0') 987 | self._write_str(entry.tag) 988 | self._write_str(entry.desc) 989 | self._write_int(constants.SECTIONS.index(entry.section) + 1) 990 | 
self._write_str(entry.defn) 991 | self._write_str(entry.drop_stmt) 992 | self._write_str(entry.copy_stmt) 993 | self._write_str(entry.namespace) 994 | self._write_str(entry.tablespace) 995 | if self.version >= (1, 14, 0): 996 | LOGGER.debug('Adding tableam') 997 | self._write_str(entry.tableam) 998 | if self.version >= (1, 16, 0): 999 | LOGGER.debug('Adding relkind') 1000 | # Write relkind as an int (character code) 1001 | relkind_val = ord(entry.relkind) if entry.relkind else 0 1002 | self._write_int(relkind_val) 1003 | self._write_str(entry.owner) 1004 | self._write_str('true' if entry.with_oids else 'false') 1005 | for dependency in entry.dependencies or []: 1006 | self._write_str(str(dependency)) 1007 | self._write_int(-1) 1008 | self._write_offset(entry.offset, entry.data_state) 1009 | 1010 | def _write_header(self) -> None: 1011 | """Write the file header""" 1012 | if self._handle is None: 1013 | raise ValueError('File handle is not initialized') 1014 | LOGGER.debug( 1015 | 'Writing archive version %i.%i.%i', 1016 | self._vmaj, 1017 | self._vmin, 1018 | self._vrev, 1019 | ) 1020 | self._handle.write(constants.MAGIC) 1021 | self._write_byte(self._vmaj) 1022 | self._write_byte(self._vmin) 1023 | self._write_byte(self._vrev) 1024 | self._write_byte(self._intsize) 1025 | self._write_byte(self._offsize) 1026 | self._write_byte(constants.FORMATS.index(self._format)) 1027 | # v1.15+ has compression algorithm in header 1028 | if self.version >= (1, 15, 0): 1029 | # Write compression algorithm: 0=none, 1=gzip, 2=lz4, 3=zstd 1030 | comp_alg = constants.COMPRESSION_ALGORITHMS.index( 1031 | self.compression_algorithm 1032 | ) 1033 | self._write_byte(comp_alg) 1034 | 1035 | def _write_int(self, value: int) -> None: 1036 | """Write an integer value 1037 | 1038 | :param int value: 1039 | 1040 | """ 1041 | if self._handle is None: 1042 | raise ValueError('File handle is not initialized') 1043 | self._write_byte(1 if value < 0 else 0) 1044 | if value < 0: 1045 | value = -value 1046 | for _offset in range(0, self._intsize): 1047 | self._write_byte(value & 0xFF) 1048 | value >>= 8 1049 | 1050 | def _write_offset(self, value: int, data_state: int) -> None: 1051 | """Write the offset value. 
1052 |
1053 |         :param int value: The value to write
1054 |         :param int data_state: The data state flag
1055 |
1056 |         """
1057 |         self._write_byte(data_state)
1058 |         for _offset in range(0, self._offsize):
1059 |             self._write_byte(value & 0xFF)
1060 |             value >>= 8
1061 |
1062 |     def _write_section(
1063 |         self, section: str, obj_types: list[str], saved: set[int]
1064 |     ) -> set[int]:
1065 |         for obj_type in obj_types:
1066 |             for entry in [e for e in self.entries if e.desc == obj_type]:
1067 |                 self._write_entry(entry)
1068 |                 saved.add(entry.dump_id)
1069 |         for dump_id in toposort.toposort_flatten(
1070 |             {
1071 |                 e.dump_id: set(e.dependencies)
1072 |                 for e in self.entries
1073 |                 if e.section == section
1074 |             },
1075 |             True,
1076 |         ):
1077 |             if dump_id not in saved:
1078 |                 found_entry: models.Entry | None = self.get_entry(dump_id)
1079 |                 if found_entry:
1080 |                     self._write_entry(found_entry)
1081 |                     saved.add(dump_id)
1082 |                 else:
1083 |                     LOGGER.warning('Entry %d not found, skipping', dump_id)
1084 |         return saved
1085 |
1086 |     def _write_str(self, value: str | None) -> None:
1087 |         """Write a string or NULL marker
1088 |
1089 |         :param value: The string to write, or None to write -1 length
1090 |             (indicating an unset/NULL field in the archive format)
1091 |
1092 |         """
1093 |         if self._handle is None:
1094 |             raise ValueError('File handle is not initialized')
1095 |         if value is None:
1096 |             # Write -1 length to indicate "not set" rather than "empty string"
1097 |             self._write_int(-1)
1098 |         else:
1099 |             out = value.encode(self.encoding)
1100 |             self._write_int(len(out))
1101 |             if out:
1102 |                 LOGGER.debug('Writing %r', out)
1103 |                 self._handle.write(out)
1104 |
1105 |     def _write_table_data(self, dump_id: int) -> int:
1106 |         """Write the table data for the entry, returning bytes written
1107 |
1108 |         :param int dump_id: The entry dump ID for the table data
1109 |         :rtype: int
1110 |
1111 |         """
1112 |         if self._handle is None:
1113 |             raise ValueError('File handle is not initialized')
1114 |         self._handle.write(constants.BLK_DATA)
1115 |         self._write_int(dump_id)
1116 |
1117 |         writer = [w for w in self._writers.values() if w.dump_id == dump_id]
1118 |         if writer:  # Data was added ad-hoc, read from TableData writer
1119 |             writer[0].finish()
1120 |             # writer.read() returns decompressed data (auto-decompressed)
1121 |             data = writer[0].read()
1122 |
1123 |             if self.compression_algorithm != constants.COMPRESSION_NONE:
1124 |                 # Re-compress with zlib and write in chunks
1125 |                 # Compress all data as a continuous stream
1126 |                 compressed_data = zlib.compress(data)
1127 |
1128 |                 # Write compressed data in ZLIB_IN_SIZE chunks
1129 |                 total_size = 0
1130 |                 offset = 0
1131 |                 while offset < len(compressed_data):
1132 |                     chunk_size = min(
1133 |                         constants.ZLIB_IN_SIZE, len(compressed_data) - offset
1134 |                     )
1135 |                     self._write_int(chunk_size)
1136 |                     self._handle.write(
1137 |                         compressed_data[offset : offset + chunk_size]
1138 |                     )
1139 |                     total_size += chunk_size
1140 |                     offset += chunk_size
1141 |             else:
1142 |                 # Write uncompressed in chunks
1143 |                 total_size = 0
1144 |                 offset = 0
1145 |                 while offset < len(data):
1146 |                     chunk_size = min(
1147 |                         constants.ZLIB_IN_SIZE, len(data) - offset
1148 |                     )
1149 |                     self._write_int(chunk_size)
1150 |                     self._handle.write(data[offset : offset + chunk_size])
1151 |                     total_size += chunk_size
1152 |                     offset += chunk_size
1153 |             self._write_int(0)  # End of data indicator
1154 |             return total_size
1155 |
1156 |         # Data was cached on load - read from tempfile and write
1157 |         with self._tempfile(dump_id, 'rb') as handle:
1158 |             # Read all decompressed data from the gzip temp file
1159 |             data = handle.read()
1160 |
1161 |         if self.compression_algorithm != constants.COMPRESSION_NONE:
1162 |             # Compress and write in chunks
1163 |             # Compress all data as a continuous stream
1164 |             compressed_data = zlib.compress(data)
1165 |
1166 |             # Write compressed data in ZLIB_IN_SIZE chunks
1167 |             total_size = 0
1168 |             offset = 0
1169 |             while offset < len(compressed_data):
1170 |                 chunk_size = min(
1171 |                     constants.ZLIB_IN_SIZE, len(compressed_data) - offset
1172 |                 )
1173 |                 self._write_int(chunk_size)
1174 |                 self._handle.write(
1175 |                     compressed_data[offset : offset + chunk_size]
1176 |                 )
1177 |                 total_size += chunk_size
1178 |                 offset += chunk_size
1179 |         else:
1180 |             # Write uncompressed in chunks
1181 |             total_size = 0
1182 |             offset = 0
1183 |             while offset < len(data):
1184 |                 chunk_size = min(constants.ZLIB_IN_SIZE, len(data) - offset)
1185 |                 self._write_int(chunk_size)
1186 |                 self._handle.write(data[offset : offset + chunk_size])
1187 |                 total_size += chunk_size
1188 |                 offset += chunk_size
1189 |
1190 |         self._write_int(0)  # End of data indicator
1191 |         return total_size
1192 |
1193 |     def _write_timestamp(self, value: datetime.datetime) -> None:
1194 |         """Write a datetime.datetime value
1195 |
1196 |         :param datetime.datetime value: The value to write
1197 |
1198 |         """
1199 |         if self._handle is None:
1200 |             raise ValueError('File handle is not initialized')
1201 |         self._write_int(value.second)
1202 |         self._write_int(value.minute)
1203 |         self._write_int(value.hour)
1204 |         self._write_int(value.day)
1205 |         self._write_int(value.month - 1)
1206 |         self._write_int(value.year - 1900)
1207 |         self._write_int(1 if value.dst() else 0)
1208 |
1209 |     def _write_toc(self) -> None:
1210 |         """Write the ToC for the file"""
1211 |         if self._handle is None:
1212 |             raise ValueError('File handle is not initialized')
1213 |         self._handle.seek(0)
1214 |         self._write_header()
1215 |         # v1.15+ has compression in header, older versions have it here
1216 |         if self.version < (1, 15, 0):
1217 |             self._write_int(
1218 |                 int(self.compression_algorithm != constants.COMPRESSION_NONE)
1219 |             )
1220 |
1221 |         self._write_timestamp(self.timestamp)
1222 |         self._write_str(self.dbname)
1223 |         self._write_str(self.server_version)
1224 |         self._write_str(self.dump_version)
1225 |
--------------------------------------------------------------------------------
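The `load()`, `table_data()`, `table_data_writer()`, and `save()` methods above compose into a read-modify-write workflow. The following is a minimal sketch of that workflow, assuming the module-level `pgdumplib.load()` helper, a `TableData.append()` row method, and that `'TABLE'` is a valid `desc` in `constants.SECTION_MAPPING` (none of these are defined in the excerpt above); the file, schema, and table names are hypothetical.

```python
import pgdumplib

# Open a custom format archive created with `pg_dump -Fc`
dmp = pgdumplib.load('example.dump')  # assumed module-level helper
print(dmp.version, dmp.dbname, dmp.compression_algorithm)

# Stream converted rows for one table from the data section
for row in dmp.table_data('public', 'orders'):
    print(row)

# Find the table definition entry, then append rows through the writer
entry = dmp.lookup_entry('TABLE', 'public', 'orders')
if entry is not None:
    with dmp.table_data_writer(entry, ('id', 'status')) as writer:
        writer.append(1, 'shipped')  # assumed TableData.append() signature

# Rewrite the ToC, entries, and data blocks to a new file
dmp.save('example-copy.dump')
```

Note that `save()` sets `compression_algorithm` to `COMPRESSION_NONE` before writing, so the rewritten archive is uncompressed regardless of how the source dump was compressed.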