├── .python-version ├── autopgpool ├── autopgpool │ ├── __init__.py │ ├── __tests__ │ │ ├── __init__.py │ │ ├── test_config.py │ │ ├── conftest.py │ │ ├── test_env.py │ │ ├── test_cli.py │ │ ├── test_ini_writer.py │ │ └── test_docker.py │ ├── logging.py │ ├── env.py │ ├── config.py │ ├── ini_writer.py │ └── cli.py ├── bootstrap.sh ├── pyproject.toml ├── Dockerfile ├── config.example.toml └── README.md ├── autopg ├── __init__.py ├── __tests__ │ ├── __init__.py │ ├── test_system_info.py │ ├── test_cli.py │ ├── test_docker.py │ ├── test_postgres.py │ └── test_logic.py ├── constants.py ├── system_info.py ├── static │ └── pygments.css ├── postgres.py ├── cli.py └── logic.py ├── benchmarks ├── results │ └── .gitkeep ├── benchmarks │ ├── __init__.py │ ├── utils.py │ ├── cli.py │ ├── seqscan.py │ ├── database.py │ └── insertion.py ├── README.md ├── pyproject.toml ├── Dockerfile.benchmark ├── pg_hba.conf ├── docker-compose.yml ├── postgres-init.sql └── uv.lock ├── .DS_Store ├── media ├── header.png └── analysis-app.png ├── Makefile ├── bootstrap.sh ├── postgresql.conf ├── pyrightconfig.json ├── LICENSE ├── Dockerfile ├── pyproject.toml ├── .github └── workflows │ ├── test.yml │ ├── test-pool.yml │ ├── docker-pool.yml │ └── docker.yml ├── .gitignore └── README.md /.python-version: -------------------------------------------------------------------------------- 1 | 3.12 2 | -------------------------------------------------------------------------------- /autopgpool/autopgpool/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /autopg/__init__.py: -------------------------------------------------------------------------------- 1 | # Empty file is fine 2 | -------------------------------------------------------------------------------- /autopgpool/autopgpool/__tests__/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /autopg/__tests__/__init__.py: -------------------------------------------------------------------------------- 1 | # Empty file is fine 2 | -------------------------------------------------------------------------------- /benchmarks/results/.gitkeep: -------------------------------------------------------------------------------- 1 | # This directory stores benchmark results 2 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piercefreeman/autopg/HEAD/.DS_Store -------------------------------------------------------------------------------- /media/header.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piercefreeman/autopg/HEAD/media/header.png -------------------------------------------------------------------------------- /autopgpool/autopgpool/logging.py: -------------------------------------------------------------------------------- 1 | from rich.console import Console 2 | 3 | CONSOLE = Console() 4 | -------------------------------------------------------------------------------- /media/analysis-app.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piercefreeman/autopg/HEAD/media/analysis-app.png -------------------------------------------------------------------------------- /autopgpool/bootstrap.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | echo "Running Autopg for PgBouncer..." 4 | 5 | autopgpool generate 6 | 7 | echo "Booting PgBouncer..." 
8 | 9 | # We need autopgpool to be run as the root user to access our internal binary, but 10 | # we should run the pgbouncer binary as a more constrained user. 11 | exec su -c "$*" postgres 12 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: lint lint-ruff lint-pyright 2 | 3 | lint: lint-fix lint-pyright 4 | 5 | lint-ci: lint-ruff lint-pyright 6 | 7 | lint-ruff: 8 | uv run ruff check . 9 | 10 | lint-pyright: 11 | uv run pyright . 12 | 13 | # Run all linting in fix mode where possible 14 | lint-fix: 15 | uv run ruff format . 16 | uv run ruff check . --fix 17 | -------------------------------------------------------------------------------- /bootstrap.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "Running Autopg..." 4 | 5 | autopg build-config --pg-path /etc/postgresql 6 | 7 | # Launch the webapp in the background if supported 8 | echo "Launching AutoPG webapp..." 9 | autopg webapp & 10 | 11 | echo "Booting PostgreSQL..." 12 | 13 | exec docker-entrypoint.sh postgres -c config_file=/etc/postgresql/postgresql.conf 14 | -------------------------------------------------------------------------------- /benchmarks/benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | AutoPG Benchmarking Suite 3 | 4 | Load testing tools for PostgreSQL databases with unoptimized queries. 
5 | """ 6 | 7 | from .cli import cli 8 | from .database import DatabaseConnection 9 | from .insertion import InsertionBenchmark 10 | from .seqscan import SequentialScanBenchmark 11 | 12 | __all__ = ["cli", "DatabaseConnection", "InsertionBenchmark", "SequentialScanBenchmark"] 13 | -------------------------------------------------------------------------------- /benchmarks/README.md: -------------------------------------------------------------------------------- 1 | # benchmarks 2 | 3 | The goal with our autopg `benchmarks` is to provide us an easy entry point to stress test postgres. This should create a large amount of pg_stat_statements and gives us a reference set of data to optimize our auto-analysis pipeline. 4 | 5 | ```bash 6 | docker compose up 7 | ``` 8 | 9 | Connect to the benchmark container: 10 | 11 | ```bash 12 | docker compose exec benchmark bash 13 | 14 | $ uv run autopg-bench full --scan-iterations 2000000 15 | ``` 16 | -------------------------------------------------------------------------------- /benchmarks/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "benchmarks" 3 | version = "0.1.0" 4 | description = "AutoPG Database Benchmarking Suite - Load testing tools for PostgreSQL" 5 | readme = "README.md" 6 | requires-python = ">=3.12" 7 | dependencies = [ 8 | "click>=8.1.8", 9 | "rich>=13.9.4", 10 | "asyncpg>=0.30.0", 11 | ] 12 | 13 | 14 | [build-system] 15 | requires = ["hatchling"] 16 | build-backend = "hatchling.build" 17 | 18 | 19 | [project.scripts] 20 | autopg-bench = "benchmarks.cli:cli" 21 | -------------------------------------------------------------------------------- /postgresql.conf: -------------------------------------------------------------------------------- 1 | # Generated by AutoPG 2 | 3 | checkpoint_completion_target = 0.9 4 | datestyle = 'iso, mdy' 5 | default_statistics_target = 100 6 | dynamic_shared_memory_type = 'posix' 7 | effective_cache_size = 
'24GB' 8 | huge_pages = 'try' 9 | listen_addresses = '*' 10 | log_timezone = 'Etc/UTC' 11 | maintenance_work_mem = '2GB' 12 | max_connections = 100 13 | max_parallel_maintenance_workers = 4 14 | max_parallel_workers = 10 15 | max_parallel_workers_per_gather = 4 16 | max_wal_size = '4GB' 17 | max_worker_processes = 10 18 | min_wal_size = '1GB' 19 | random_page_cost = 1.1 20 | shared_buffers = '8GB' 21 | timezone = 'Etc/UTC' 22 | wal_buffers = '16MB' 23 | work_mem = '20971kB' 24 | -------------------------------------------------------------------------------- /pyrightconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "include": ["autopg", "benchmarks/benchmarks"], 3 | "exclude": [ 4 | "**/__pycache__", 5 | "**/node_modules", 6 | ".venv", 7 | "**/.venv", 8 | "**/lib", 9 | "**/lib64", 10 | "autopgpool", 11 | "autopgbouncer" 12 | ], 13 | "reportMissingImports": false, 14 | "reportMissingTypeStubs": false, 15 | "reportUnknownMemberType": false, 16 | "reportUnknownArgumentType": false, 17 | "reportUnknownParameterType": false, 18 | "reportUnknownVariableType": false, 19 | "reportOptionalSubscript": false, 20 | "reportOptionalIterable": false, 21 | "reportPrivateUsage": false, 22 | "pythonVersion": "3.11", 23 | "typeCheckingMode": "basic" 24 | } -------------------------------------------------------------------------------- /autopgpool/autopgpool/__tests__/test_config.py: -------------------------------------------------------------------------------- 1 | import tomllib 2 | from pathlib import Path 3 | from typing import Any 4 | 5 | from autopgpool.config import MainConfig 6 | 7 | 8 | def test_example_config_loads_correctly(project_root: Path) -> None: 9 | """ 10 | Test that the example config file can be loaded correctly into the MainConfig model. 
11 | """ 12 | # Find the project root and the example config file 13 | example_config_path = project_root / "config.example.toml" 14 | 15 | assert example_config_path.exists(), f"Example config file not found at {example_config_path}" 16 | 17 | # Load the TOML file 18 | with open(example_config_path, "rb") as f: 19 | config_data: dict[str, Any] = tomllib.load(f) 20 | 21 | # Parse the config data into the MainConfig model 22 | MainConfig.model_validate(config_data) 23 | -------------------------------------------------------------------------------- /autopgpool/autopgpool/__tests__/conftest.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tempfile import TemporaryDirectory 3 | from typing import Generator 4 | 5 | import pytest 6 | 7 | 8 | @pytest.fixture(scope="session") 9 | def project_root() -> Path: 10 | """ 11 | Find the project root by looking for the pyproject.toml file. 12 | 13 | Returns: 14 | pathlib.Path: Path to the project root 15 | """ 16 | current_dir = Path(__file__).resolve().parent 17 | 18 | while current_dir != current_dir.parent: 19 | if (current_dir / "pyproject.toml").exists(): 20 | return current_dir 21 | current_dir = current_dir.parent 22 | 23 | raise FileNotFoundError("Could not find project root (pyproject.toml)") 24 | 25 | 26 | @pytest.fixture 27 | def temp_dir() -> Generator[Path, None, None]: 28 | """Fixture that provides a temporary directory as a Path object.""" 29 | with TemporaryDirectory() as temp_dir: 30 | yield Path(temp_dir) 31 | -------------------------------------------------------------------------------- /autopgpool/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "autopgpool" 3 | version = "0.1.0" 4 | description = "Opinionated package for postgres pools" 5 | readme = "README.md" 6 | requires-python = ">=3.12" 7 | dependencies = [ 8 | "click>=8.2.0", 9 | "pydantic>=2.11.4", 10 | 
"rich>=14.0.0", 11 | ] 12 | 13 | [project.scripts] 14 | autopgpool = "autopgpool.cli:cli" 15 | 16 | [dependency-groups] 17 | dev = [ 18 | "psycopg>=3.2.9", 19 | "pyright>=1.1.400", 20 | "pytest>=8.3.5", 21 | "ruff>=0.11.9", 22 | "tomli-w>=1.2.0", 23 | ] 24 | 25 | [build-system] 26 | requires = ["hatchling"] 27 | build-backend = "hatchling.build" 28 | 29 | [tool.ruff] 30 | target-version = "py312" 31 | line-length = 100 32 | 33 | [tool.ruff.lint] 34 | select = ["E", "F", "I", "N", "W", "B"] 35 | ignore = ["E501"] 36 | 37 | [tool.pyright] 38 | pythonVersion = "3.12" 39 | typeCheckingMode = "strict" 40 | 41 | [tool.pytest.ini_options] 42 | markers = [ 43 | "integration: marks tests that require external services (deselect with '-m \"not integration\"')", 44 | ] 45 | addopts = "-m 'not integration'" 46 | -------------------------------------------------------------------------------- /benchmarks/Dockerfile.benchmark: -------------------------------------------------------------------------------- 1 | FROM python:3.12-slim 2 | 3 | # Install system dependencies 4 | RUN apt-get update && apt-get install -y \ 5 | postgresql-client \ 6 | build-essential \ 7 | && rm -rf /var/lib/apt/lists/* 8 | 9 | # Copy uv from the official image 10 | COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ 11 | 12 | # Create a virtual environment and install Python using uv 13 | ENV VIRTUAL_ENV=/opt/venv 14 | RUN uv venv $VIRTUAL_ENV 15 | ENV PATH="$VIRTUAL_ENV/bin:$PATH" 16 | 17 | # Create a working directory for the package 18 | WORKDIR /app 19 | 20 | # Copy benchmark project files 21 | COPY pyproject.toml ./ 22 | COPY uv.lock ./ 23 | 24 | # Install benchmarking dependencies 25 | RUN --mount=type=cache,target=/root/.cache/uv \ 26 | uv sync --frozen --no-install-project 27 | 28 | # Copy the rest of the benchmark code 29 | COPY . . 30 | 31 | # Install the benchmark project 32 | RUN --mount=type=cache,target=/root/.cache/uv \ 33 | uv pip install -e . 
34 | 35 | # Set up the benchmark environment 36 | ENV PYTHONPATH=/app:$PYTHONPATH 37 | 38 | CMD ["uv", "run", "autopg-bench"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Pierce Freeman 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Declare the build argument before FROM 2 | ARG POSTGRES_VERSION=16 3 | 4 | # Use the build argument in FROM 5 | FROM postgres:${POSTGRES_VERSION} 6 | 7 | # Install build dependencies for psycopg2 8 | RUN apt-get update && apt-get install -y \ 9 | gcc \ 10 | python3-dev \ 11 | libpq-dev \ 12 | && rm -rf /var/lib/apt/lists/* 13 | 14 | # Copy uv from the official image 15 | COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ 16 | 17 | # Create a virtual environment and install Python using uv 18 | ENV VIRTUAL_ENV=/opt/venv 19 | RUN uv venv $VIRTUAL_ENV 20 | ENV PATH="$VIRTUAL_ENV/bin:$PATH" 21 | 22 | # Create a working directory for the package 23 | WORKDIR /app 24 | 25 | # Install dependencies using uv 26 | RUN --mount=type=cache,target=/root/.cache/uv \ 27 | --mount=type=bind,source=uv.lock,target=uv.lock \ 28 | --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ 29 | uv sync --frozen --active --no-install-project 30 | 31 | COPY . . 
32 | 33 | RUN --mount=type=cache,target=/root/.cache/uv \ 34 | uv sync --frozen --active 35 | 36 | RUN chmod +x /app/bootstrap.sh 37 | 38 | # Keep the original postgres entrypoint 39 | ENTRYPOINT ["/app/bootstrap.sh"] 40 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "autopg" 3 | version = "0.1.0" 4 | description = "Autotune PostgreSQL for your system" 5 | readme = "README.md" 6 | requires-python = ">=3.12" 7 | dependencies = [ 8 | "click>=8.1.8", 9 | "psutil>=6.1.1", 10 | "pydantic>=2.10.6", 11 | "pydantic-settings>=2.7.1", 12 | "rich>=13.9.4", 13 | "fastapi>=0.100.0", 14 | "uvicorn>=0.23.0", 15 | "psycopg>=3.2.5", 16 | "pygments>=2.17.0", 17 | "sqlparse>=0.4.0", 18 | ] 19 | 20 | [project.scripts] 21 | autopg = "autopg.cli:cli" 22 | 23 | [dependency-groups] 24 | dev = [ 25 | "pytest>=8.3.4", 26 | "ruff>=0.3.0", 27 | "pyright>=1.1.350", 28 | "psycopg>=3.2.5", 29 | "tomli-w>=1.2.0", 30 | ] 31 | 32 | [build-system] 33 | requires = ["hatchling"] 34 | build-backend = "hatchling.build" 35 | 36 | [tool.ruff] 37 | target-version = "py312" 38 | line-length = 100 39 | 40 | [tool.ruff.lint] 41 | select = ["E", "F", "I", "N", "W", "B"] 42 | ignore = ["E501"] 43 | 44 | [tool.pyright] 45 | pythonVersion = "3.12" 46 | typeCheckingMode = "strict" 47 | 48 | [tool.pytest.ini_options] 49 | markers = [ 50 | "integration: marks tests that require external services (deselect with '-m \"not integration\"')", 51 | ] 52 | addopts = "-m 'not integration'" 53 | -------------------------------------------------------------------------------- /autopg/constants.py: -------------------------------------------------------------------------------- 1 | # postgresql versions 2 | DEFAULT_DB_VERSION = 18 3 | DB_VERSIONS = [DEFAULT_DB_VERSION, 17, 16, 15, 14, 13, 12, 11, 10] 4 | 5 | # os types 6 | OS_LINUX = "linux" 7 | OS_WINDOWS = "windows" 8 
| OS_MAC = "mac" 9 | 10 | # db types 11 | DB_TYPE_WEB = "web" 12 | DB_TYPE_OLTP = "oltp" 13 | DB_TYPE_DW = "dw" 14 | DB_TYPE_DESKTOP = "desktop" 15 | DB_TYPE_MIXED = "mixed" 16 | 17 | # size units 18 | SIZE_UNIT_MB = "MB" 19 | SIZE_UNIT_GB = "GB" 20 | 21 | # harddrive types 22 | HARD_DRIVE_SSD = "SSD" 23 | HARD_DRIVE_SAN = "SAN" 24 | HARD_DRIVE_HDD = "HDD" 25 | 26 | # maximum value for integer fields 27 | MAX_NUMERIC_VALUE = 999999 28 | 29 | SIZE_UNIT_MAP: dict[str, int] = {"KB": 1024, "MB": 1048576, "GB": 1073741824, "TB": 1099511627776} 30 | 31 | KNOWN_STORAGE_VARS = [ 32 | "shared_buffers", 33 | "effective_cache_size", 34 | "maintenance_work_mem", 35 | "wal_buffers", 36 | "work_mem", 37 | "min_wal_size", 38 | "max_wal_size", 39 | ] 40 | 41 | PG_CONFIG_DIR = "/etc/postgresql" 42 | PG_CONFIG_FILE = "postgresql.conf" 43 | PG_CONFIG_FILE_BASE = "postgresql.conf.base" 44 | 45 | PG_STAT_STATEMENTS_SQL = """-- AutoPG Extension Initialization 46 | -- Enable pg_stat_statements extension for query statistics 47 | 48 | CREATE EXTENSION IF NOT EXISTS pg_stat_statements; 49 | """ 50 | -------------------------------------------------------------------------------- /benchmarks/pg_hba.conf: -------------------------------------------------------------------------------- 1 | # PostgreSQL Client Authentication Configuration File 2 | # =================================================== 3 | # 4 | # This file controls: which hosts are allowed to connect, how clients 5 | # are authenticated, which PostgreSQL user names they can use, which 6 | # databases they can access. 
7 | # 8 | # TYPE DATABASE USER ADDRESS METHOD 9 | # 10 | # "local" is for Unix domain socket connections only 11 | local all all trust 12 | 13 | # IPv4 local connections: 14 | host all all 127.0.0.1/32 trust 15 | 16 | # IPv6 local connections: 17 | host all all ::1/128 trust 18 | 19 | # Allow connections from any IPv4 address (0.0.0.0/0 means anywhere) 20 | host all all 0.0.0.0/0 md5 21 | 22 | # Allow connections from any IPv6 address 23 | host all all ::/0 md5 24 | 25 | # Allow replication connections from localhost, by a user with the 26 | # replication privilege. 27 | local replication all trust 28 | host replication all 127.0.0.1/32 trust 29 | host replication all ::1/128 trust 30 | 31 | # Allow replication connections from anywhere for replication users 32 | host replication all 0.0.0.0/0 md5 33 | host replication all ::/0 md5 34 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.11"] 15 | 16 | steps: 17 | - uses: actions/checkout@v4 18 | 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | 24 | - name: Install uv 25 | run: | 26 | curl -LsSf https://astral.sh/uv/install.sh | sh 27 | echo "$HOME/.cargo/bin" >> $GITHUB_PATH 28 | 29 | - name: Install dependencies 30 | run: | 31 | uv sync 32 | 33 | - name: Run tests 34 | run: | 35 | uv run pytest -vvv autopg 36 | 37 | integration-test: 38 | runs-on: ubuntu-latest 39 | strategy: 40 | matrix: 41 | python-version: ["3.11"] 42 | 43 | steps: 44 | - uses: actions/checkout@v4 45 | 46 | - name: Set up Python ${{ matrix.python-version }} 47 | uses: actions/setup-python@v5 48 | 
with: 49 | python-version: ${{ matrix.python-version }} 50 | 51 | - name: Install uv 52 | run: | 53 | curl -LsSf https://astral.sh/uv/install.sh | sh 54 | echo "$HOME/.cargo/bin" >> $GITHUB_PATH 55 | 56 | - name: Install dependencies 57 | run: | 58 | uv sync 59 | 60 | - name: Run integration tests 61 | run: | 62 | uv run pytest -vvv -m "integration" autopg -------------------------------------------------------------------------------- /.github/workflows/test-pool.yml: -------------------------------------------------------------------------------- 1 | name: Test PgBouncer 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.11"] 15 | 16 | steps: 17 | - uses: actions/checkout@v4 18 | 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | 24 | - name: Install uv 25 | run: | 26 | curl -LsSf https://astral.sh/uv/install.sh | sh 27 | echo "$HOME/.cargo/bin" >> $GITHUB_PATH 28 | 29 | - name: Install dependencies 30 | run: | 31 | cd autopgpool 32 | uv sync 33 | 34 | - name: Run tests 35 | run: | 36 | cd autopgpool 37 | uv run pytest -vvv autopgpool 38 | 39 | integration-test: 40 | runs-on: ubuntu-latest 41 | strategy: 42 | matrix: 43 | python-version: ["3.11"] 44 | 45 | steps: 46 | - uses: actions/checkout@v4 47 | 48 | - name: Set up Python ${{ matrix.python-version }} 49 | uses: actions/setup-python@v5 50 | with: 51 | python-version: ${{ matrix.python-version }} 52 | 53 | - name: Install uv 54 | run: | 55 | curl -LsSf https://astral.sh/uv/install.sh | sh 56 | echo "$HOME/.cargo/bin" >> $GITHUB_PATH 57 | 58 | - name: Install dependencies 59 | run: | 60 | cd autopgpool 61 | uv sync 62 | 63 | - name: Run integration tests 64 | run: | 65 | cd autopgpool 66 | uv run pytest -vvv -m "integration" autopgpool 
-------------------------------------------------------------------------------- /autopgpool/autopgpool/env.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import tomllib 3 | from os import getenv 4 | from typing import Any, TypeVar 5 | 6 | from autopgpool.logging import CONSOLE 7 | 8 | T = TypeVar("T") 9 | 10 | 11 | def load_toml_config(config_path: str) -> dict[str, Any]: 12 | """ 13 | Load a TOML configuration file. 14 | 15 | Args: 16 | config_path: Path to the TOML file 17 | 18 | Returns: 19 | Dictionary containing the parsed TOML data 20 | """ 21 | try: 22 | with open(config_path, "rb") as f: 23 | payload = tomllib.load(f) 24 | payload = swap_env(payload) 25 | return payload 26 | except FileNotFoundError: 27 | CONSOLE.print(f"[red]Error: Config file not found at {config_path}[/red]") 28 | sys.exit(1) 29 | except tomllib.TOMLDecodeError as e: 30 | CONSOLE.print(f"[red]Error parsing TOML file: {str(e)}[/red]") 31 | sys.exit(1) 32 | 33 | 34 | def swap_env(obj: T) -> T: 35 | """ 36 | Recursively walk a structure (dict / list / scalar) and replace every string that 37 | starts with `$` by the matching OS environment variable. 38 | 39 | """ 40 | if isinstance(obj, dict): 41 | return {k: swap_env(v) for k, v in obj.items()} # type: ignore 42 | 43 | if isinstance(obj, list): 44 | return [swap_env(item) for item in obj] # type: ignore 45 | 46 | if isinstance(obj, str) and obj.startswith("$"): 47 | env_name = obj[1:] 48 | env_val = getenv(env_name) 49 | if env_val is None: 50 | raise EnvironmentError( 51 | f"Environment variable '{env_name}' referenced in config but not set." 
52 | ) 53 | return env_val # type: ignore 54 | 55 | return obj 56 | -------------------------------------------------------------------------------- /autopgpool/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM alpine:3.21 AS build 2 | ARG VERSION=1.24.1 3 | 4 | # Install build dependencies 5 | RUN apk add --no-cache autoconf autoconf-doc automake curl gcc git libc-dev libevent-dev libtool make openssl-dev pandoc pkgconfig 6 | 7 | # Download and extract pgbouncer 8 | RUN curl -sS -o /pgbouncer.tar.gz -L https://pgbouncer.github.io/downloads/files/$VERSION/pgbouncer-$VERSION.tar.gz && \ 9 | tar -xzf /pgbouncer.tar.gz && mv /pgbouncer-$VERSION /pgbouncer 10 | 11 | # Build pgbouncer 12 | RUN cd /pgbouncer && ./configure --prefix=/usr && make 13 | 14 | FROM alpine:3.21 15 | 16 | RUN apk add --no-cache python3 py3-pip busybox libevent postgresql-client && \ 17 | mkdir -p /etc/pgbouncer /var/log/pgbouncer /var/run/pgbouncer && \ 18 | chown -R postgres /var/log/pgbouncer /var/run/pgbouncer /etc/pgbouncer 19 | 20 | # Copy pgbouncer binary 21 | COPY --from=build /pgbouncer/pgbouncer /usr/bin 22 | 23 | # Copy uv from the official image 24 | COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ 25 | 26 | # Create a virtual environment and install Python using uv 27 | ENV VIRTUAL_ENV=/opt/venv 28 | RUN uv venv $VIRTUAL_ENV 29 | ENV PATH="$VIRTUAL_ENV/bin:$PATH" 30 | 31 | # Create a working directory for the package 32 | WORKDIR /app 33 | 34 | # Install dependencies using uv 35 | RUN --mount=type=cache,target=/root/.cache/uv \ 36 | --mount=type=bind,source=uv.lock,target=uv.lock \ 37 | --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ 38 | uv sync --frozen --active --no-install-project 39 | 40 | COPY . . 
41 | 42 | RUN --mount=type=cache,target=/root/.cache/uv \ 43 | uv sync --frozen --active 44 | 45 | RUN chmod +x /app/bootstrap.sh 46 | 47 | EXPOSE 5432 48 | ENTRYPOINT ["/app/bootstrap.sh"] 49 | CMD ["/usr/bin/pgbouncer", "/etc/pgbouncer/pgbouncer.ini"] 50 | -------------------------------------------------------------------------------- /autopgpool/config.example.toml: -------------------------------------------------------------------------------- 1 | # AutoPGPool Example Configuration 2 | 3 | # User definitions 4 | [[users]] 5 | username = "admin" 6 | password = "admin_password" 7 | grants = ["main_db", "analytics_db"] 8 | 9 | [[users]] 10 | username = "app_user" 11 | password = "app_password" 12 | grants = ["main_db"] 13 | 14 | [[users]] 15 | username = "stats_user" 16 | password = "stats_password" 17 | grants = ["analytics_db"] 18 | 19 | # Database definitions 20 | [pools.main_db] 21 | pool_mode = "transaction" 22 | 23 | [pools.main_db.remote] 24 | host = "localhost" 25 | port = 5432 26 | database = "main_db" 27 | username = "postgres" 28 | password = "postgres_password" 29 | 30 | [pools.analytics_db] 31 | pool_mode = "session" 32 | 33 | [pools.analytics_db.remote] 34 | host = "10.0.0.5" 35 | port = 5432 36 | database = "analytics_db" 37 | username = "analytics_user" 38 | password = "analytics_password" 39 | 40 | [pools.analytics_db.replica] 41 | pool_mode = "statement" 42 | 43 | [pools.analytics_db.replica.remote] 44 | host = "replica.example.com" 45 | port = 5432 46 | database = "replica_db" 47 | username = "replica_user" 48 | password = "replica_password" 49 | 50 | # PGBouncer configuration 51 | [pgbouncer] 52 | listen_addr = "0.0.0.0" 53 | listen_port = 6432 54 | auth_type = "md5" 55 | pool_mode = "transaction" 56 | max_client_conn = 200 57 | default_pool_size = 20 58 | ignore_startup_parameters = ["extra_float_digits", "search_path"] 59 | admin_users = ["admin"] 60 | stats_users = ["stats_user"] 61 | idle_transaction_timeout = 60 62 | 
max_prepared_statements = 25 63 | 64 | # Additional custom PGBouncer parameters 65 | [pgbouncer.passthrough_kwargs] 66 | server_reset_query = "DISCARD ALL" 67 | server_check_query = "select 1" 68 | server_check_delay = 30 69 | application_name_add_host = 1 70 | log_disconnections = 1 71 | log_connections = 1 72 | -------------------------------------------------------------------------------- /.github/workflows/docker-pool.yml: -------------------------------------------------------------------------------- 1 | name: Docker PgBouncer Images 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | pull_request: 8 | types: [labeled, synchronize] 9 | # Allow manual trigger for testing 10 | workflow_dispatch: 11 | 12 | env: 13 | REGISTRY: ghcr.io 14 | IMAGE_NAME: ${{ github.repository }}-pool 15 | 16 | jobs: 17 | build-and-push: 18 | # Skip if PR without "Full Build" label 19 | if: | 20 | github.event_name == 'push' || 21 | github.event_name == 'workflow_dispatch' || 22 | (github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'Full Build')) 23 | runs-on: ubuntu-latest 24 | permissions: 25 | contents: read 26 | packages: write 27 | 28 | steps: 29 | - name: Checkout repository 30 | uses: actions/checkout@v4 31 | 32 | - name: Set up Docker Buildx 33 | uses: docker/setup-buildx-action@v3 34 | 35 | - name: Log in to the Container registry 36 | # Only login for actual deployments 37 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') 38 | uses: docker/login-action@v3 39 | with: 40 | registry: ${{ env.REGISTRY }} 41 | username: ${{ github.actor }} 42 | password: ${{ secrets.GITHUB_TOKEN }} 43 | 44 | - name: Extract metadata (tags, labels) for Docker 45 | id: meta 46 | uses: docker/metadata-action@v5 47 | with: 48 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 49 | tags: | 50 | # Tag with release tag (e.g. 
from dataclasses import dataclass
from enum import StrEnum

import psutil
from rich.console import Console

from autopg.constants import HARD_DRIVE_HDD, HARD_DRIVE_SAN, HARD_DRIVE_SSD

console = Console()


@dataclass
class MemoryInfo:
    # Sizes are expressed in GB.
    total: float | None
    available: float


@dataclass
class CpuInfo:
    # Logical CPU count (None when psutil cannot determine it).
    count: int | None
    # Current frequency as reported by psutil (0.0 when unavailable).
    current_freq: float


class DiskType(StrEnum):
    SSD = HARD_DRIVE_SSD
    SAN = HARD_DRIVE_SAN
    HDD = HARD_DRIVE_HDD


def get_memory_info() -> MemoryInfo:
    """
    Get the total and available memory in GB
    """
    vm = psutil.virtual_memory()
    total_gb = vm.total / (1024**3)
    available_gb = vm.available / (1024**3)
    return MemoryInfo(total=total_gb, available=available_gb)


def get_cpu_info() -> CpuInfo:
    """
    Get CPU count and current frequency
    """
    cpu_count = psutil.cpu_count(logical=True)
    # Get the average frequency across all CPUs
    freq = psutil.cpu_freq()
    current_freq = freq.current if freq else 0.0
    return CpuInfo(count=cpu_count, current_freq=current_freq)


def get_disk_type() -> DiskType | None:
    """
    Attempt to determine if the primary disk is SSD or HDD.

    Returns None when no partition exposes the Linux sysfs rotational flag
    or when detection fails for any reason (e.g. non-Linux platforms).
    """
    try:
        # On Linux, we can check the rotational flag exposed via sysfs
        import os

        # Check the first disk device
        for device in psutil.disk_partitions():
            if device.device.startswith("/dev/"):
                # Fix: strip the "/dev/" prefix *before* reducing to the base
                # device name; filtering the raw path kept the letters of the
                # prefix as well (e.g. "/dev/sda1" -> "devsda" instead of
                # "sda"), so the sysfs path below never matched a real device.
                device_name = device.device.removeprefix("/dev/")
                # Drop the partition number to get the base device (sda1 -> sda)
                base_device = "".join(filter(str.isalpha, device_name))
                rotational_path = f"/sys/block/{base_device}/queue/rotational"

                if os.path.exists(rotational_path):
                    with open(rotational_path, "r") as f:
                        rotational = int(f.read().strip())
                        return DiskType.HDD if rotational == 1 else DiskType.SSD
        return None
    except Exception:
        return None
We will expand any env variables you include to their current values: 19 | 20 | ```toml 21 | [[users]] 22 | username = "app_user" 23 | password = "$APP_CLIENT_PASSWORD" 24 | grants = ["main_db"] 25 | 26 | [pools.main_db.remote] 27 | host = "127.0.0.1" 28 | port = "5056" 29 | database = "main_db" 30 | username = "main_user" 31 | password = "$MAIN_DB_PASSWORD" 32 | ``` 33 | 34 | For a more complete example config, see config.example.toml. To reference this in docker-compose, do something like the following: 35 | 36 | ```bash 37 | version: '3' 38 | 39 | services: 40 | pgpool: 41 | image: ghcr.io/piercefreeman/autopg-pool:latest 42 | ports: 43 | - "6432:6432" 44 | environment: 45 | - APP_CLIENT_PASSWORD=myapppassword 46 | - MAIN_DB_PASSWORD=mymaindbpassword 47 | volumes: 48 | - ./config.toml:/etc/autopgpool/autopgpool.toml 49 | restart: unless-stopped 50 | ``` 51 | 52 | ### autopgpool vs vanilla pgbouncer 53 | 54 | Vanilla pgbouncer requires configuration through `pgbouncer.ini` and `userlist.txt` files that are placed on disk. This works fine for static configuration, but requires you to hard-code in any env variables (and to pre-calculate the md5 hash of your user credentials before deployment). 55 | -------------------------------------------------------------------------------- /benchmarks/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | 3 | services: 4 | # AutoPG IS the PostgreSQL instance with optimizations 5 | autopg: 6 | build: 7 | context: .. 
8 | dockerfile: Dockerfile 9 | container_name: autopg-postgres 10 | environment: 11 | # PostgreSQL configuration 12 | POSTGRES_USER: postgres 13 | POSTGRES_PASSWORD: postgres 14 | POSTGRES_DB: benchmark 15 | # AutoPG optimization settings 16 | AUTOPG_DB_TYPE: web 17 | AUTOPG_TOTAL_MEMORY_MB: 8192 18 | AUTOPG_CPU_COUNT: 4 19 | AUTOPG_NUM_CONNECTIONS: 200 20 | AUTOPG_PRIMARY_DISK_TYPE: SSD 21 | # Enable diagnostics webapp 22 | AUTOPG_ENABLE_WEBAPP: "true" 23 | AUTOPG_WEBAPP_HOST: 0.0.0.0 24 | AUTOPG_WEBAPP_PORT: 8000 25 | # Database connection for webapp 26 | AUTOPG_DB_HOST: localhost 27 | AUTOPG_DB_PORT: 5432 28 | AUTOPG_DB_NAME: benchmark 29 | AUTOPG_DB_USER: postgres 30 | AUTOPG_DB_PASSWORD: postgres 31 | ports: 32 | - "5434:5432" # PostgreSQL port 33 | - "8000:8000" # Diagnostics webapp port 34 | volumes: 35 | - postgres_data_v6:/var/lib/postgresql/data 36 | - ./postgres-init.sql:/docker-entrypoint-initdb.d/init.sql:ro 37 | - ./pg_hba.conf:/etc/postgresql/pg_hba.conf 38 | healthcheck: 39 | test: ["CMD-SHELL", "pg_isready -U postgres"] 40 | interval: 5s 41 | timeout: 5s 42 | retries: 10 43 | networks: 44 | - autopg-network 45 | 46 | benchmark: 47 | build: 48 | context: . 
import os

import pytest
from pydantic import BaseModel

from autopgpool.env import swap_env


def test_swap_env_with_string() -> None:
    """$-prefixed strings resolve to the env value; plain strings pass through."""
    os.environ["TEST_VAR"] = "test_value"

    assert swap_env("$TEST_VAR") == "test_value"
    assert swap_env("regular_string") == "regular_string"


def test_swap_env_with_dict() -> None:
    os.environ["TEST_VAR"] = "test_value"
    os.environ["ANOTHER_VAR"] = "another_value"

    source = {"key1": "$TEST_VAR", "key2": "regular_value", "key3": "$ANOTHER_VAR"}

    assert swap_env(source) == {
        "key1": "test_value",
        "key2": "regular_value",
        "key3": "another_value",
    }


def test_swap_env_with_list() -> None:
    os.environ["TEST_VAR"] = "test_value"

    assert swap_env(["$TEST_VAR", "regular_value", 123]) == [
        "test_value",
        "regular_value",
        123,
    ]


def test_swap_env_with_nested_structures() -> None:
    os.environ["TEST_VAR"] = "test_value"

    source = {
        "key1": "$TEST_VAR",
        "key2": ["regular_value", "$TEST_VAR"],
        "key3": {"nested_key": "$TEST_VAR"},
    }

    assert swap_env(source) == {
        "key1": "test_value",
        "key2": ["regular_value", "test_value"],
        "key3": {"nested_key": "test_value"},
    }


def test_swap_env_missing_env_var() -> None:
    # Guarantee the variable is absent before asserting the failure mode.
    os.environ.pop("NONEXISTENT_VAR", None)

    with pytest.raises(EnvironmentError):
        swap_env("$NONEXISTENT_VAR")


def test_swap_env_pydantic_parse_int():
    """
    Verify that we can parse ints from string-based env vars.
    This mirrors the behavior that we'll be doing when reading the config file.

    """

    class DemoModel(BaseModel):
        test_int: int

    os.environ["TEST_INT"] = "123"

    assert DemoModel.model_validate(swap_env({"test_int": "$TEST_INT"})) == DemoModel(test_int=123)
matrix: 30 | postgres-version: ['14', '15', '16', '17', '18'] 31 | 32 | steps: 33 | - name: Checkout repository 34 | uses: actions/checkout@v4 35 | 36 | - name: Set up Docker Buildx 37 | uses: docker/setup-buildx-action@v3 38 | 39 | - name: Log in to the Container registry 40 | # Only login for actual deployments 41 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') 42 | uses: docker/login-action@v3 43 | with: 44 | registry: ${{ env.REGISTRY }} 45 | username: ${{ github.actor }} 46 | password: ${{ secrets.GITHUB_TOKEN }} 47 | 48 | - name: Extract metadata (tags, labels) for Docker 49 | id: meta 50 | uses: docker/metadata-action@v5 51 | with: 52 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 53 | tags: | 54 | # Tag with postgres version and release tag (e.g. pg16-v1.0.0) 55 | type=raw,value=pg${{ matrix.postgres-version }}-${{ github.ref_name }} 56 | # Tag with postgres version and latest (e.g. pg16-latest) 57 | type=raw,value=pg${{ matrix.postgres-version }}-latest 58 | 59 | - name: Build and push Docker image 60 | uses: docker/build-push-action@v5 61 | with: 62 | context: . 
63 | # Only push for tag events 64 | push: ${{ github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') }} 65 | build-args: | 66 | POSTGRES_VERSION=${{ matrix.postgres-version }} 67 | tags: ${{ steps.meta.outputs.tags }} 68 | labels: ${{ steps.meta.outputs.labels }} 69 | cache-from: type=gha 70 | cache-to: type=gha,mode=max 71 | -------------------------------------------------------------------------------- /benchmarks/postgres-init.sql: -------------------------------------------------------------------------------- 1 | -- Initialize PostgreSQL for benchmarking 2 | -- Enable required extensions 3 | 4 | CREATE EXTENSION IF NOT EXISTS pg_trgm; 5 | CREATE EXTENSION IF NOT EXISTS btree_gin; 6 | CREATE EXTENSION IF NOT EXISTS btree_gist; 7 | 8 | -- pg_stat_statements configuration is now in postgresql.conf 9 | -- This ensures the extension is preloaded before these settings are applied 10 | 11 | -- Create benchmark schema 12 | CREATE SCHEMA IF NOT EXISTS benchmark; 13 | 14 | -- Grant permissions 15 | GRANT ALL ON SCHEMA benchmark TO postgres; 16 | GRANT ALL ON ALL TABLES IN SCHEMA benchmark TO postgres; 17 | GRANT ALL ON ALL SEQUENCES IN SCHEMA benchmark TO postgres; 18 | 19 | -- Create initial tables for stress testing 20 | -- These will be populated by the benchmark CLI 21 | 22 | -- Users table (will be heavily queried) 23 | CREATE TABLE IF NOT EXISTS benchmark.users ( 24 | id SERIAL PRIMARY KEY, 25 | username VARCHAR(100) NOT NULL, 26 | email VARCHAR(255) NOT NULL, 27 | created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, 28 | last_login TIMESTAMP, 29 | status VARCHAR(20) DEFAULT 'active', 30 | profile_data JSONB 31 | ); 32 | 33 | -- Posts table (large table for sequential scan testing) 34 | CREATE TABLE IF NOT EXISTS benchmark.posts ( 35 | id SERIAL PRIMARY KEY, 36 | user_id INTEGER REFERENCES benchmark.users(id), 37 | title VARCHAR(500), 38 | content TEXT, 39 | created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, 40 | updated_at TIMESTAMP, 41 | view_count 
from pathlib import Path

from autopgpool.cli import generate_pgbouncer_config
from autopgpool.config import MainConfig, PgbouncerConfig, Pool, User


def test_generate_pgbouncer_config(temp_dir: Path) -> None:
    """Test that pgbouncer config files are generated correctly."""
    # Build a two-user / one-pool configuration for the generator.
    config = MainConfig(
        users=[
            User(username="testuser", password="testpass", grants=["testdb"]),
            User(username="admin", password="adminpass", grants=["testdb"]),
        ],
        pools={
            "testdb": Pool(
                remote=Pool.RemoteDatabase(
                    host="localhost",
                    port=5432,
                    database="testdb",
                    username="pguser",
                    password="pgpass",
                ),
                pool_mode="transaction",
            )
        },
        pgbouncer=PgbouncerConfig(
            listen_port=6432,
            auth_type="md5",
            admin_users=["admin"],
            passthrough_kwargs={"application_name": "pgbouncer"},
        ),
    )

    # Generate the config files
    generate_pgbouncer_config(config, str(temp_dir))

    ini_path = temp_dir / "pgbouncer.ini"
    userlist_path = temp_dir / "userlist.txt"
    hba_path = temp_dir / "pgbouncer_hba.conf"

    # All three output files must exist.
    for expected_file in (ini_path, userlist_path, hba_path):
        assert expected_file.exists()

    # Verify key configuration elements are present in pgbouncer.ini.
    pgbouncer_ini = ini_path.read_text()
    expected_ini_fragments = [
        "[pgbouncer]",
        "listen_port = 6432",
        "auth_type = hba",  # the requested md5 auth is overridden to hba
        "auth_file = ",
        "application_name = pgbouncer",
        "[databases]",
        "testdb = ",
        "host=localhost",
        "port=5432",
    ]
    for fragment in expected_ini_fragments:
        assert fragment in pgbouncer_ini

    # MD5 passwords are hashed, so we can't check exact values — only usernames.
    userlist = userlist_path.read_text()
    assert '"testuser"' in userlist
    assert '"admin"' in userlist

    hba_content = hba_path.read_text()
    assert "# TYPE\tDATABASE\tUSER\tADDRESS\tMETHOD" in hba_content

    # Each granted user gets local, IPv4-any, and IPv6-any md5 entries.
    for username in ("testuser", "admin"):
        assert f"local\ttestdb\t{username}\t\tmd5" in hba_content
        assert f"host\ttestdb\t{username}\t0.0.0.0/0\tmd5" in hba_content
        assert f"host\ttestdb\t{username}\t::/0\tmd5" in hba_content
return_value=[mock_partition]), 50 | patch("os.path.exists", return_value=True), 51 | patch("builtins.open", mock_open(read_data=rotational_value)), 52 | ): 53 | disk_type = get_disk_type() 54 | assert disk_type == expected_type 55 | 56 | 57 | @pytest.mark.parametrize( 58 | "error_source,expected_result", 59 | [ 60 | ("disk_partitions", None), # Test disk_partitions raising exception 61 | ("file_read", None), # Test file read raising exception 62 | ], 63 | ) 64 | def test_get_disk_type_errors(error_source: str, expected_result: None) -> None: 65 | """Test error handling in disk type detection""" 66 | DiskPartition = NamedTuple("DiskPartition", [("device", str)]) 67 | mock_partition = DiskPartition(device="/dev/sda1") 68 | 69 | if error_source == "disk_partitions": 70 | with patch("psutil.disk_partitions", side_effect=Exception()): 71 | assert get_disk_type() == expected_result 72 | else: 73 | with ( 74 | patch("psutil.disk_partitions", return_value=[mock_partition]), 75 | patch("os.path.exists", return_value=True), 76 | patch("builtins.open", side_effect=Exception()), 77 | ): 78 | assert get_disk_type() == expected_result 79 | -------------------------------------------------------------------------------- /autopgpool/autopgpool/config.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Literal 2 | 3 | from pydantic import BaseModel, model_validator 4 | 5 | POOL_MODES = Literal["session", "transaction", "statement"] 6 | AUTH_TYPES = Literal["cert", "md5", "scram-sha-256", "plain", "trust", "any", "hba", "pam"] 7 | 8 | 9 | class User(BaseModel): 10 | """ 11 | A user that can connect to the database through pgbouncer. This user gets 12 | encoded and placed into the pgbouncer userlist. 
class Pool(BaseModel):
    """
    A synthetically defined database that can be connected to through pgbouncer. These will
    establish an independent connection pool for each of these databases.

    """

    class RemoteDatabase(BaseModel):
        # Connection details for the backing Postgres server.
        host: str
        port: int
        database: str
        username: str
        password: str

    remote: RemoteDatabase
    pool_mode: POOL_MODES = "transaction"


class PgbouncerConfig(BaseModel):
    """
    Settings written into the [pgbouncer] section of pgbouncer.ini.

    """

    # Fix: `listen_addr` was previously declared twice ("*" and then
    # "0.0.0.0"); pydantic silently keeps only the last declaration, so the
    # effective default "0.0.0.0" is kept here as the single field.
    listen_addr: str = "0.0.0.0"
    listen_port: int = 6432

    auth_type: AUTH_TYPES = "md5"
    pool_mode: POOL_MODES = "transaction"

    max_client_conn: int = 100
    default_pool_size: int = 10

    ignore_startup_parameters: list[str] = ["extra_float_digits"]

    admin_users: list[str] | None = None
    stats_users: list[str] | None = None

    # By default we stop stalled transactions from blocking the pool
    # fixes: common query_wait_timeout (age=120s) where queries can't
    # be handled in the pool because the connection stream is saturated
    # If users override this to None, no timeout will be applied.
    # https://dba.stackexchange.com/questions/261709/pgbouncer-logging-details-for-query-wait-timeout-error
    # https://stackoverflow.com/questions/23394272/how-does-pgbouncer-behave-when-transaction-pooling-is-enabled-and-a-single-state
    idle_transaction_timeout: int | None = 60

    # Support prepared statements, which are used by some default query constructors
    # in sqlalchemy and asyncpg
    # https://github.com/pgbouncer/pgbouncer/pull/845
    max_prepared_statements: int = 10

    # Extra key/value pairs copied verbatim into the [pgbouncer] ini section.
    passthrough_kwargs: dict[str, Any] = {}


class MainConfig(BaseModel):
    """
    The main configuration for pgbouncer.
    """

    users: list[User]
    pools: dict[str, Pool]
    pgbouncer: PgbouncerConfig = PgbouncerConfig()

    @model_validator(mode="after")
    def validate_pgbouncer_users(self):
        """Ensure that any specified admin/stats users exist in the userlist."""
        valid_users = {user.username for user in self.users}
        for user in self.pgbouncer.admin_users or []:
            if user not in valid_users:
                raise ValueError(f"User {user} is not in the userlist")
        for user in self.pgbouncer.stats_users or []:
            if user not in valid_users:
                raise ValueError(f"User {user} is not in the userlist")

        return self

    @model_validator(mode="after")
    def validate_pool_grants(self):
        """Ensure every user grant names a pool that is actually defined."""
        valid_pools = set(self.pools.keys())
        for user in self.users:
            for grant in user.grants:
                if grant not in valid_pools:
                    raise ValueError(
                        f"User {user.username} has grant {grant} which is not a valid pool"
                    )

        return self
format_ini_value(value: Any) -> str: 9 | """ 10 | Format a Python value for an INI file. 11 | 12 | Args: 13 | value: The value to format 14 | 15 | Returns: 16 | A string representation of the value suitable for an INI file 17 | """ 18 | if isinstance(value, bool): 19 | return "1" if value else "0" 20 | elif isinstance(value, (int, float)): 21 | return str(value) 22 | elif isinstance(value, str): 23 | # For strings, just return the string without quotes 24 | # PgBouncer doesn't require quotes for string values in its config 25 | return value 26 | elif isinstance(value, list): 27 | # For lists, join with commas 28 | return ", ".join(format_ini_value(item) for item in value) # type: ignore 29 | elif value is None: 30 | return "" 31 | else: 32 | return str(value) 33 | 34 | 35 | def write_ini_file( 36 | config: dict[str, dict[str, Any]], 37 | filepath: Path, 38 | section_comments: dict[str, str] | None = None, 39 | ) -> None: 40 | """ 41 | Write a configuration dictionary to an INI file. 42 | 43 | Args: 44 | config: Dictionary with sections as keys and key-value pairs as values 45 | filepath: Path to write the INI file to 46 | section_comments: Optional comments to add before each section 47 | """ 48 | with open(filepath, "w") as f: 49 | for section, items in config.items(): 50 | # Add optional comment for the section 51 | if section_comments and section in section_comments: 52 | f.write(f"# {section_comments[section]}\n") 53 | 54 | # Write section header 55 | f.write(f"[{section}]\n") 56 | 57 | # Write key-value pairs 58 | for key, value in items.items(): 59 | formatted_value = format_ini_value(value) 60 | if formatted_value: # Skip empty values 61 | f.write(f"{key} = {formatted_value}\n") 62 | 63 | # Add a blank line between sections 64 | f.write("\n") 65 | 66 | 67 | def write_userlist_file(users: list[User], filepath: Path, encrypt: AUTH_TYPES) -> None: 68 | """ 69 | Write a pgbouncer userlist file. 
def write_hba_file(users: list[User], filepath: Path) -> None:
    """
    Write a pgbouncer HBA (host-based authentication) file.

    Emits md5 entries (local socket, IPv4-any, IPv6-any) for every pool a
    user has been granted, followed by a catch-all reject so anything not
    explicitly granted above is denied.

    Args:
        users: List of users with their granted pools
        filepath: Path to write the HBA file to
    """
    lines = ["# TYPE\tDATABASE\tUSER\tADDRESS\tMETHOD\n"]

    # For each user, create entries for their granted pools
    for user in users:
        for pool in user.grants:
            # Local connections, then host connections from anywhere (v4/v6).
            lines.append(f"local\t{pool}\t{user.username}\t\tmd5\n")
            lines.append(f"host\t{pool}\t{user.username}\t0.0.0.0/0\tmd5\n")
            lines.append(f"host\t{pool}\t{user.username}\t::/0\tmd5\n")

    # Block all other user/grants from everything else not listed above
    lines.append("host\tall\tall\t0.0.0.0/0\treject\n")
    lines.append("host\tall\tall\t::/0\treject\n")

    with open(filepath, "w") as f:
        f.writelines(lines)
wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
import os
import sys
from pathlib import Path

import click
from rich.markup import escape

from autopgpool.config import MainConfig, User
from autopgpool.env import load_toml_config
from autopgpool.ini_writer import (
    write_hba_file,
    write_ini_file,
    write_userlist_file,
)
from autopgpool.logging import CONSOLE

# Default locations used inside the autopgpool Docker image
DEFAULT_CONFIG_PATH = "/etc/autopgpool/autopgpool.toml"
DEFAULT_OUTPUT_DIR = "/etc/pgbouncer"


def generate_pgbouncer_config(config: MainConfig, output_dir: str) -> None:
    """
    Generate pgbouncer configuration files from the MainConfig.

    Writes three files into ``output_dir``: ``userlist.txt`` (credentials),
    ``pgbouncer_hba.conf`` (per-user access grants), and ``pgbouncer.ini``
    (the main pgbouncer configuration with one entry per configured pool).

    Args:
        config: The parsed configuration
        output_dir: Directory to write configuration files to
    """
    output_path = Path(output_dir)
    os.makedirs(output_path, exist_ok=True)

    userlist_path = output_path / "userlist.txt"
    hba_path = output_path / "pgbouncer_hba.conf"
    ini_path = output_path / "pgbouncer.ini"

    # Create users with grants for HBA configuration
    users = [
        User(username=user.username, password=user.password, grants=user.grants)
        for user in config.users
    ]

    # Write userlist.txt; passwords are hashed according to the configured
    # auth_type (e.g. md5) before being written out.
    write_userlist_file(users, userlist_path, encrypt=config.pgbouncer.auth_type)
    CONSOLE.print(f"Wrote userlist file to [bold]{userlist_path}[/bold]")
    CONSOLE.print(f"Userlist file contents:\n###\n{escape(userlist_path.read_text())}\n###\n")

    # Even when the user hasn't requested hba auth, we want to write the HBA file
    # to provide our access grants
    write_hba_file(users, hba_path)
    CONSOLE.print(f"Wrote HBA file to [bold]{hba_path}[/bold]")
    CONSOLE.print(f"HBA file contents:\n###\n{escape(hba_path.read_text())}\n###\n")

    # Create pgbouncer.ini. The explicit auth_* keys are merged last so that
    # passthrough kwargs cannot accidentally override them.
    pgbouncer_config = {
        "pgbouncer": {
            **config.pgbouncer.model_dump(exclude={"passthrough_kwargs"}),
            **config.pgbouncer.passthrough_kwargs,
            **{
                "auth_type": "hba",
                "auth_file": userlist_path,
                "auth_hba_file": hba_path,
            },
        },
        "databases": {
            # Format: dbname = connection_string
            pool_name: (
                f"host={pool.remote.host} port={pool.remote.port} dbname={pool.remote.database} "
                f"user={pool.remote.username} password={pool.remote.password} pool_mode={pool.pool_mode}"
            )
            for pool_name, pool in config.pools.items()
        },
    }

    # Write the pgbouncer.ini file
    write_ini_file(pgbouncer_config, ini_path)
    CONSOLE.print(f"Wrote pgbouncer.ini file to [bold]{ini_path}[/bold]")
    CONSOLE.print(f"PGBouncer.ini file contents:\n###\n{escape(ini_path.read_text())}\n###\n")

    CONSOLE.print(f"[green]Successfully wrote configuration to {output_dir}[/green]")


@click.group()
def cli() -> None:
    """autopgpool CLI tool for pgbouncer configuration management."""
    pass


@cli.command()
@click.option(
    "--config-path",
    default=DEFAULT_CONFIG_PATH,
    help="Path to the autopgpool TOML configuration file",
)
@click.option(
    "--output-dir",
    default=DEFAULT_OUTPUT_DIR,
    help="Directory to write pgbouncer configuration files to",
)
def generate(config_path: str, output_dir: str) -> None:
    """Generate pgbouncer configuration files from TOML config."""
    try:
        # Loading is inside the try block so that a missing or malformed
        # TOML file reports a friendly error instead of a raw traceback.
        config_data = load_toml_config(config_path)

        # Parse into Pydantic model
        config = MainConfig.model_validate(config_data)

        # Generate configuration files
        generate_pgbouncer_config(config, output_dir)
    except Exception as e:
        CONSOLE.print(f"[red]Error generating configuration: {str(e)}[/red]")
        sys.exit(1)


@cli.command()
@click.option(
    "--config-path",
    default=DEFAULT_CONFIG_PATH,
    help="Path to the autopgpool TOML configuration file",
)
def validate(config_path: str) -> None:
    """Validate the autopgpool TOML configuration file."""
    try:
        # Load inside the try block so file-level errors (missing file,
        # invalid TOML) are reported the same way as validation errors.
        config_data = load_toml_config(config_path)

        # Parse into Pydantic model
        MainConfig.model_validate(config_data)
        CONSOLE.print(f"[green]Configuration file at {config_path} is valid.[/green]")
    except Exception as e:
        CONSOLE.print(f"[red]Configuration validation error: {str(e)}[/red]")
        sys.exit(1)


if __name__ == "__main__":
    cli()
def test_write_ini_file(tmp_path: Path) -> None:
    """Test writing a configuration to an INI file."""
    config: dict[str, dict[str, Any]] = {
        "section1": {
            "key1": "value1",
            "key2": 123,
            "key3": True,
            "key4": None,  # Should be skipped
        },
        "section2": {
            "list_key": ["a", "b", "c"],
            "bool_key": False,
        },
    }

    section_comments: dict[str, str] = {
        "section1": "This is section 1",
        # No comment for section2
    }

    # tmp_path (a pytest-managed directory) is used instead of
    # NamedTemporaryFile: re-opening an open NamedTemporaryFile by path is
    # not portable (it fails on Windows) and the fixture handles cleanup.
    filepath = tmp_path / "config.ini"
    write_ini_file(config, filepath, section_comments)

    content = filepath.read_text()

    # Verify the content
    expected_content = textwrap.dedent("""\
# This is section 1
[section1]
key1 = value1
key2 = 123
key3 = 1

[section2]
list_key = a, b, c
bool_key = 0

""")
    assert content == expected_content


def test_write_ini_file_no_comments(tmp_path: Path) -> None:
    """Test writing a configuration to an INI file without section comments."""
    config: dict[str, dict[str, Any]] = {
        "section1": {
            "key1": "value1",
        },
    }

    filepath = tmp_path / "config.ini"
    write_ini_file(config, filepath)  # No section_comments

    content = filepath.read_text()

    expected_content = textwrap.dedent("""\
[section1]
key1 = value1

""")
    assert content == expected_content


def test_write_userlist_file_plain(tmp_path: Path) -> None:
    """Test writing users to a userlist file with plain auth."""
    users = [
        User(username="user1", password="pass1", grants=[]),
        User(username="user2", password="pass2", grants=[]),
    ]

    filepath = tmp_path / "userlist.txt"
    write_userlist_file(users, filepath, "plain")

    content = filepath.read_text()

    expected_content = '"user1" "pass1"\n"user2" "pass2"\n'
    assert content == expected_content


def test_write_userlist_file_md5(tmp_path: Path) -> None:
    """Test writing users to a userlist file with md5 auth."""
    users = [
        User(username="user1", password="pass1", grants=[]),
        User(username="user2", password="pass2", grants=[]),
    ]

    filepath = tmp_path / "userlist.txt"
    # We're not testing the actual md5 implementation, just that it's used
    write_userlist_file(users, filepath, "md5")

    content = filepath.read_text()

    assert '"user1" "md55e4eab96e8b9868ed28cc79c9ceec8b3"' in content
    assert '"user2" "md553cc9e310bc5e01cb42fd0aeda81e27d"' in content


def test_write_userlist_file_empty(tmp_path: Path) -> None:
    """Test writing an empty list of users."""
    users: list[User] = []

    filepath = tmp_path / "userlist.txt"
    write_userlist_file(users, filepath, "plain")

    assert filepath.read_text() == ""
.highlight { background: #f8f8f8; } 8 | .highlight .c { color: #3D7B7B; font-style: italic } /* Comment */ 9 | .highlight .err { border: 1px solid #F00 } /* Error */ 10 | .highlight .k { color: #008000; font-weight: bold } /* Keyword */ 11 | .highlight .o { color: #666 } /* Operator */ 12 | .highlight .ch { color: #3D7B7B; font-style: italic } /* Comment.Hashbang */ 13 | .highlight .cm { color: #3D7B7B; font-style: italic } /* Comment.Multiline */ 14 | .highlight .cp { color: #9C6500 } /* Comment.Preproc */ 15 | .highlight .cpf { color: #3D7B7B; font-style: italic } /* Comment.PreprocFile */ 16 | .highlight .c1 { color: #3D7B7B; font-style: italic } /* Comment.Single */ 17 | .highlight .cs { color: #3D7B7B; font-style: italic } /* Comment.Special */ 18 | .highlight .gd { color: #A00000 } /* Generic.Deleted */ 19 | .highlight .ge { font-style: italic } /* Generic.Emph */ 20 | .highlight .ges { font-weight: bold; font-style: italic } /* Generic.EmphStrong */ 21 | .highlight .gr { color: #E40000 } /* Generic.Error */ 22 | .highlight .gh { color: #000080; font-weight: bold } /* Generic.Heading */ 23 | .highlight .gi { color: #008400 } /* Generic.Inserted */ 24 | .highlight .go { color: #717171 } /* Generic.Output */ 25 | .highlight .gp { color: #000080; font-weight: bold } /* Generic.Prompt */ 26 | .highlight .gs { font-weight: bold } /* Generic.Strong */ 27 | .highlight .gu { color: #800080; font-weight: bold } /* Generic.Subheading */ 28 | .highlight .gt { color: #04D } /* Generic.Traceback */ 29 | .highlight .kc { color: #008000; font-weight: bold } /* Keyword.Constant */ 30 | .highlight .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */ 31 | .highlight .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */ 32 | .highlight .kp { color: #008000 } /* Keyword.Pseudo */ 33 | .highlight .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */ 34 | .highlight .kt { color: #B00040 } /* Keyword.Type */ 35 | .highlight .m { color: #666 } 
/* Literal.Number */ 36 | .highlight .s { color: #BA2121 } /* Literal.String */ 37 | .highlight .na { color: #687822 } /* Name.Attribute */ 38 | .highlight .nb { color: #008000 } /* Name.Builtin */ 39 | .highlight .nc { color: #00F; font-weight: bold } /* Name.Class */ 40 | .highlight .no { color: #800 } /* Name.Constant */ 41 | .highlight .nd { color: #A2F } /* Name.Decorator */ 42 | .highlight .ni { color: #717171; font-weight: bold } /* Name.Entity */ 43 | .highlight .ne { color: #CB3F38; font-weight: bold } /* Name.Exception */ 44 | .highlight .nf { color: #00F } /* Name.Function */ 45 | .highlight .nl { color: #767600 } /* Name.Label */ 46 | .highlight .nn { color: #00F; font-weight: bold } /* Name.Namespace */ 47 | .highlight .nt { color: #008000; font-weight: bold } /* Name.Tag */ 48 | .highlight .nv { color: #19177C } /* Name.Variable */ 49 | .highlight .ow { color: #A2F; font-weight: bold } /* Operator.Word */ 50 | .highlight .w { color: #BBB } /* Text.Whitespace */ 51 | .highlight .mb { color: #666 } /* Literal.Number.Bin */ 52 | .highlight .mf { color: #666 } /* Literal.Number.Float */ 53 | .highlight .mh { color: #666 } /* Literal.Number.Hex */ 54 | .highlight .mi { color: #666 } /* Literal.Number.Integer */ 55 | .highlight .mo { color: #666 } /* Literal.Number.Oct */ 56 | .highlight .sa { color: #BA2121 } /* Literal.String.Affix */ 57 | .highlight .sb { color: #BA2121 } /* Literal.String.Backtick */ 58 | .highlight .sc { color: #BA2121 } /* Literal.String.Char */ 59 | .highlight .dl { color: #BA2121 } /* Literal.String.Delimiter */ 60 | .highlight .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */ 61 | .highlight .s2 { color: #BA2121 } /* Literal.String.Double */ 62 | .highlight .se { color: #AA5D1F; font-weight: bold } /* Literal.String.Escape */ 63 | .highlight .sh { color: #BA2121 } /* Literal.String.Heredoc */ 64 | .highlight .si { color: #A45A77; font-weight: bold } /* Literal.String.Interpol */ 65 | .highlight .sx { color: 
from pathlib import Path
from typing import Any, Generator
from unittest.mock import patch

import pytest
from click.testing import CliRunner

from autopg.cli import cli
from autopg.constants import (
    OS_LINUX,
)
from autopg.system_info import CpuInfo, DiskType, MemoryInfo


@pytest.fixture
def cli_runner() -> Generator[CliRunner, None, None]:
    """Create a Click CLI runner for testing"""
    runner = CliRunner()
    with runner.isolated_filesystem():
        yield runner


@pytest.fixture
def mock_system_info():
    """Mock all system info calls to return consistent values"""
    with (
        patch("autopg.cli.get_memory_info") as mock_memory,
        patch("autopg.cli.get_cpu_info") as mock_cpu,
        patch("autopg.cli.get_disk_type") as mock_disk,
        patch("autopg.cli.get_os_type") as mock_os,
        patch("autopg.cli.get_postgres_version") as mock_postgres,
    ):
        # Set up mock returns
        mock_memory.return_value = MemoryInfo(total=16, available=8)  # 16GB total, 8GB available
        mock_cpu.return_value = CpuInfo(count=4, current_freq=2.5)  # 4 cores, 2.5GHz
        mock_disk.return_value = DiskType.SSD
        mock_os.return_value = OS_LINUX
        mock_postgres.return_value = "14.0"

        yield


def _make_conf_dir(tmp_path: Path) -> tuple[Path, Path]:
    """Create a config directory holding an empty postgresql.conf; return (dir, file)."""
    pg_conf_dir = tmp_path / "postgresql"
    pg_conf_dir.mkdir()
    pg_conf_file = pg_conf_dir / "postgresql.conf"
    pg_conf_file.write_text("")
    return pg_conf_dir, pg_conf_file


def _run_build_config(cli_runner: CliRunner, pg_conf_dir: Path, pg_conf_file: Path) -> str:
    """Invoke `build-config`, assert it succeeded, and return the rendered config text."""
    result = cli_runner.invoke(cli, ["build-config", "--pg-path", str(pg_conf_dir)])

    # Check the command succeeded
    assert result.exit_code == 0
    assert "Successfully wrote new PostgreSQL configuration!" in result.output

    # Verify the configuration file was created
    assert pg_conf_file.exists()
    return pg_conf_file.read_text()


def test_build_config(cli_runner: CliRunner, mock_system_info: Any, tmp_path: Path):
    """Test that build_config generates a valid configuration file"""
    pg_conf_dir, pg_conf_file = _make_conf_dir(tmp_path)

    config_content = _run_build_config(cli_runner, pg_conf_dir, pg_conf_file)

    # Check for some key configuration parameters
    assert "shared_buffers" in config_content
    assert "effective_cache_size" in config_content
    assert "work_mem" in config_content
    assert "max_connections" in config_content
    # Check for pg_stat_statements (enabled by default)
    assert "shared_preload_libraries = 'pg_stat_statements'" in config_content
    assert "pg_stat_statements.track = 'all'" in config_content
    assert "pg_stat_statements.max = 10000" in config_content


def test_build_config_with_pg_stat_statements_disabled_env(
    cli_runner: CliRunner, mock_system_info: Any, tmp_path: Path, monkeypatch: Any
):
    """Test that build_config respects the AUTOPG_ENABLE_PG_STAT_STATEMENTS=false environment variable"""
    # Set environment variable to disable pg_stat_statements
    monkeypatch.setenv("AUTOPG_ENABLE_PG_STAT_STATEMENTS", "false")

    pg_conf_dir, pg_conf_file = _make_conf_dir(tmp_path)

    config_content = _run_build_config(cli_runner, pg_conf_dir, pg_conf_file)

    # Check that pg_stat_statements settings are NOT present
    assert "shared_preload_libraries" not in config_content
    assert "pg_stat_statements.track" not in config_content
    assert "pg_stat_statements.max" not in config_content


def test_build_config_with_pg_stat_statements_enabled_env(
    cli_runner: CliRunner, mock_system_info: Any, tmp_path: Path, monkeypatch: Any
):
    """Test that build_config respects the AUTOPG_ENABLE_PG_STAT_STATEMENTS=true environment variable"""
    # Set environment variable to enable pg_stat_statements
    monkeypatch.setenv("AUTOPG_ENABLE_PG_STAT_STATEMENTS", "true")

    pg_conf_dir, pg_conf_file = _make_conf_dir(tmp_path)

    config_content = _run_build_config(cli_runner, pg_conf_dir, pg_conf_file)

    # Check for pg_stat_statements settings
    assert "shared_preload_libraries = 'pg_stat_statements'" in config_content
    assert "pg_stat_statements.track = 'all'" in config_content
    assert "pg_stat_statements.max = 10000" in config_content
def build_docker_image(temp_workspace: Path, postgres_version: str) -> str:
    """
    Build the Docker image for testing.

    :param temp_workspace: Temporary directory containing a copy of the workspace
    :param postgres_version: Version of PostgreSQL to test with
    :return: Tag of the freshly built image
    """
    test_tag = f"autopg:test-{postgres_version}"
    subprocess.run(
        [
            "docker",
            "build",
            "--build-arg",
            f"POSTGRES_VERSION={postgres_version}",
            "-t",
            test_tag,
            ".",
        ],
        cwd=temp_workspace,
        check=True,
    )
    return test_tag


def start_postgres_container(
    temp_workspace: Path,
    test_tag: str,
    env_vars: dict[str, str] | None = None,
) -> str:
    """
    Start a PostgreSQL container for testing.

    :param temp_workspace: Temporary directory containing a copy of the workspace
    :param test_tag: Docker image tag to run
    :param env_vars: Environment variables to set in the container
    :return: The started container's ID
    """
    prefix_env_args = [
        "docker",
        "run",
        "-d",
        "-p",
        "5432:5432",
        "-e",
        "POSTGRES_USER=test_user",
        "-e",
        "POSTGRES_PASSWORD=test_password",
    ]

    for k, v in (env_vars or {}).items():
        prefix_env_args.extend(["-e", f"{k}={v}"])

    return (
        subprocess.check_output(
            [
                *prefix_env_args,
                test_tag,
            ],
            cwd=temp_workspace,
        )
        .decode()
        .strip()
    )


def wait_for_postgres(container_id: str, timeout_seconds: int = 30) -> None:
    """
    Wait for PostgreSQL to be ready with a timeout.

    Polls `pg_isready` inside the container once per second.

    :param container_id: Docker container ID
    :param timeout_seconds: Maximum time to wait in seconds
    :raises TimeoutError: If PostgreSQL is not ready within the timeout
    """
    start_time = time()
    while True:
        if time() - start_time > timeout_seconds:
            raise TimeoutError(f"PostgreSQL not ready after {timeout_seconds} seconds")

        try:
            subprocess.run(
                ["docker", "exec", container_id, "pg_isready", "-t", "5"],
                check=True,
            )
            console.print("PostgreSQL is ready")
            break
        except subprocess.CalledProcessError:
            sleep(1)

    # Give time to fully boot and be reachable
    sleep(2)


def cleanup_container(container_id: str) -> None:
    """
    Stop and remove a Docker container (best effort).

    :param container_id: Docker container ID

    This runs in test teardown paths, so failures are deliberately not
    raised (check=False): raising here would mask the original test error.
    """
    subprocess.run(["docker", "stop", container_id], check=False)
    subprocess.run(["docker", "rm", container_id], check=False)


@pytest.mark.integration
@pytest.mark.parametrize("postgres_version", ["14", "15", "16", "17", "18"])
def test_docker_max_connections(temp_workspace: Path, postgres_version: str) -> None:
    """
    Test that Docker image correctly applies PostgreSQL configuration changes.
    Specifically tests max_connections parameter.

    :param temp_workspace: Temporary directory containing a copy of the workspace
    :param postgres_version: Version of PostgreSQL to test with
    """
    # Build and start container
    test_tag = build_docker_image(temp_workspace, postgres_version)
    container_id = start_postgres_container(
        temp_workspace,
        test_tag,
        env_vars={
            "AUTOPG_NUM_CONNECTIONS": "45",
        },
    )

    try:
        # Wait for PostgreSQL to be ready
        wait_for_postgres(container_id)

        # Connect and verify max_connections
        conn = psycopg.connect(
            host="localhost",
            port=5432,
            user="test_user",
            password="test_password",
            dbname="test_user",  # PostgreSQL creates a database with the same name as the user by default
        )

        try:
            with conn.cursor() as cur:
                cur.execute("SHOW max_connections")
                result = cur.fetchone()
                assert result is not None
                assert result[0] == "45"  # PostgreSQL returns this as a string (default is 100)
        finally:
            conn.close()

    except Exception as e:
        console.print(f"Error: {e}")

        # Dump the container logs for debugging; best effort so a logging
        # failure cannot mask the original exception being re-raised below.
        subprocess.run(["docker", "logs", container_id], check=False)

        raise e
    finally:
        cleanup_container(container_id)
For example, in `docker-compose.yml` file, add the following:

```yaml
services:
  postgres:
    image: ghcr.io/piercefreeman/autopg:16-latest
    ports:
      - 5432:5432
```
We'll also merge in any values from this file that autopg does not support directly, so this is a great way to add additional custom settings. 41 | 42 | We build images following {postgres_version}-{autopg_version} tags. Use this table to find your desired version: 43 | 44 | | Postgres Version | Autopg Version | Tag | 45 | | ---------------- | -------------- | --- | 46 | | 18 | latest | `autopg:18-latest` | 47 | | 17 | latest | `autopg:17-latest` | 48 | | 16 | latest | `autopg:16-latest` | 49 | | 15 | latest | `autopg:15-latest` | 50 | | 14 | latest | `autopg:14-latest` | 51 | 52 | ## Debugging slow queries 53 | 54 | ![Analysis App](./media/analysis-app.png) 55 | 56 | Sequential scans can absolutely kill the performance of your webapp, since they require the database engine to iterate through all your table's data instead of just pulling from a much quicker index cache. 57 | 58 | We automatically configure your database with `pg_stat_statements`, which transparently captures queries that you run against the database. It puts the results in a regular postgres table so you can aggregate the stats like you do with any other Postgres data. While you're free to login as an admin user and query these stats yourself, we bundle a simple webapp to visualize these logs. 
From there you can create indexes yourself for the most problematic queries. I suggest running `EXPLAIN ANALYZE` to see how the engine is routing the request, then creating an index to target the slow query, and finally confirming the improvement with another analyze.
112 | 113 | ```bash 114 | uv run pytest -vvv 115 | ``` 116 | 117 | ```bash 118 | uv run pytest -vvv -m "integration" 119 | ``` 120 | 121 | ## Limitations 122 | 123 | - Right now we write the optimization logic in Python so our postgres container relies on having a python interpreter installed. This adds a bit to space overhead and is potentially a security risk. We'd rather bundle a compiled binary that serves the same purpose. 124 | -------------------------------------------------------------------------------- /autopg/__tests__/test_postgres.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Dict 3 | from unittest.mock import patch 4 | 5 | import pytest 6 | 7 | from autopg.postgres import ( 8 | CONFIG_TYPES, 9 | format_kb_value, 10 | format_postgres_values, 11 | format_value, 12 | get_postgres_version, 13 | parse_storage_value, 14 | parse_value, 15 | read_postgresql_conf, 16 | write_postgresql_conf, 17 | ) 18 | 19 | 20 | @pytest.mark.parametrize( 21 | "version_string,expected_version", 22 | [ 23 | ("postgres (PostgreSQL) 16.3 (Homebrew)", 16), 24 | ("postgres (PostgreSQL) 16.6 (Debian 16.6-1.pgdg120+1)", 16), 25 | ], 26 | ) 27 | def test_get_postgres_version(version_string: str, expected_version: int) -> None: 28 | """Test that we can correctly parse different PostgreSQL version strings""" 29 | with patch("subprocess.run") as mock_run: 30 | mock_run.return_value.stdout = version_string 31 | assert get_postgres_version() == expected_version 32 | 33 | 34 | def test_read_postgresql_conf(tmp_path: Path) -> None: 35 | """Test reading PostgreSQL configuration from a file""" 36 | # Create a mock postgresql.conf file 37 | conf_dir = tmp_path / "postgresql" 38 | conf_dir.mkdir() 39 | conf_file = conf_dir / "postgresql.conf" 40 | 41 | test_config = """ 42 | # This is a comment 43 | shared_buffers = 128MB 44 | work_mem = '4MB' 45 | max_connections = 100 46 | 
def test_write_postgresql_conf(tmp_path: Path) -> None:
    """Writing a formatted config should produce a well-formed postgresql.conf."""
    conf_dir = tmp_path / "postgresql"
    conf_dir.mkdir()

    raw_config: Dict[str, Any] = {
        "shared_buffers": 128 * 1024,
        "work_mem": 4,
        "max_connections": 100,
        "ssl": "on",
    }

    # Run the format -> write pipeline
    write_postgresql_conf(format_postgres_values(raw_config), str(conf_dir))

    conf_file = conf_dir / "postgresql.conf"
    assert conf_file.exists()
    content = conf_file.read_text()

    # The generated header and every formatted setting must be present
    assert "# Generated by AutoPG" in content
    for rendered_line in (
        "shared_buffers = '128MB'",
        "work_mem = '4kB'",
        "max_connections = 100",
        "ssl = 'on'",
    ):
        assert rendered_line in content


def test_backup_postgresql_conf(tmp_path: Path) -> None:
    """An existing postgresql.conf should be snapshotted to .base before overwrite."""
    conf_dir = tmp_path / "postgresql"
    conf_dir.mkdir()
    original_contents = "existing_param = 'value'"
    (conf_dir / "postgresql.conf").write_text(original_contents)

    write_postgresql_conf({"new_param": "value"}, str(conf_dir))

    # The backup must exist and hold the pre-write contents verbatim
    backup_file = conf_dir / "postgresql.conf.base"
    assert backup_file.exists()
    assert backup_file.read_text() == original_contents


def test_format_postgres_values() -> None:
    """Each supported value category should be rendered into conf-file form."""
    raw: dict[str, CONFIG_TYPES | None] = {
        "shared_buffers": 128 * 1024,  # Storage value in KB
        "max_connections": 100,  # Integer
        "ssl": "on",  # String
        "enable_seqscan": True,  # Boolean
    }

    assert format_postgres_values(raw) == {
        "shared_buffers": "'128MB'",
        "max_connections": "100",
        "ssl": "'on'",
        "enable_seqscan": "true",
    }


#
# Formatting
#


@pytest.mark.parametrize(
    "raw,rendered",
    [
        (100, "100"),
        (3.14, "3.14"),
        ("128MB", "128MB"),
        ("on", "on"),
        (0, "0"),
        (-1, "-1"),
        (True, "true"),
        (False, "false"),
    ],
)
def test_format_value(raw: int | float | str | bool, rendered: str) -> None:
    """Test formatting of different configuration value types"""
    assert format_value(raw) == rendered


@pytest.mark.parametrize(
    "raw,parsed",
    [
        ("true", True),
        ("false", False),
        ("TRUE", True),
        ("FALSE", False),
        ("123", 123),
        ("hello", "hello"),
        ("3.14", "3.14"),  # Non-integer strings remain strings
    ],
)
def test_parse_value(raw: str, parsed: int | float | str | bool) -> None:
    """Test parsing of different configuration value types"""
    assert parse_value(raw) == parsed
@pytest.mark.parametrize( 179 | "input_str,expected_kb", 180 | [ 181 | ("1GB", 1048576), # 1GB = 1024 * 1024 KB 182 | ("512MB", 524288), # 512MB = 512 * 1024 KB 183 | ("64kB", 64), # Direct KB value 184 | ("0kB", 0), # Zero case 185 | ], 186 | ) 187 | def test_parse_storage_value(input_str: str, expected_kb: int) -> None: 188 | """Test parsing of storage values into kilobytes""" 189 | assert parse_storage_value(input_str) == expected_kb 190 | 191 | 192 | @pytest.mark.parametrize( 193 | "input_kb,expected_str", 194 | [ 195 | (1048576, "1GB"), # 1GB case 196 | (524288, "512MB"), # MB case 197 | (64, "64kB"), # Direct KB case 198 | (0, "0kB"), # Zero case 199 | # Edge cases between units 200 | (1024, "1MB"), # Exactly 1MB 201 | (2048, "2MB"), # Exactly 2MB 202 | (1073741824, "1024GB"), # Large number 203 | ], 204 | ) 205 | def test_format_kb_value(input_kb: int, expected_str: str) -> None: 206 | """Test formatting of kilobyte values into human readable strings""" 207 | assert format_kb_value(input_kb) == expected_str 208 | -------------------------------------------------------------------------------- /autopg/postgres.py: -------------------------------------------------------------------------------- 1 | import re 2 | import shutil 3 | import subprocess 4 | from pathlib import Path 5 | 6 | from autopg.constants import ( 7 | KNOWN_STORAGE_VARS, 8 | PG_CONFIG_DIR, 9 | PG_CONFIG_FILE, 10 | PG_CONFIG_FILE_BASE, 11 | SIZE_UNIT_MAP, 12 | ) 13 | 14 | CONFIG_TYPES = int | float | str | bool 15 | 16 | # 17 | # Config management 18 | # 19 | 20 | 21 | def read_postgresql_conf(base_path: str = PG_CONFIG_DIR) -> dict[str, CONFIG_TYPES]: 22 | """Read the postgresql.conf file, preferring .base if it exists""" 23 | conf_path = Path(base_path) / PG_CONFIG_FILE 24 | base_conf_path = Path(base_path) / PG_CONFIG_FILE_BASE 25 | 26 | target_path = base_conf_path if base_conf_path.exists() else conf_path 27 | if not target_path.exists(): 28 | return {} 29 | 30 | config: dict[str, 
CONFIG_TYPES] = {} 31 | with target_path.open() as f: 32 | for line in f: 33 | line = line.strip() 34 | if line and not line.startswith("#"): 35 | try: 36 | key, value = line.split("=", 1) 37 | key = key.strip() 38 | value = value.strip().strip("'") 39 | 40 | if key in KNOWN_STORAGE_VARS: 41 | value = parse_storage_value(value) 42 | else: 43 | value = parse_value(value) 44 | config[key] = value 45 | except ValueError: 46 | continue 47 | return config 48 | 49 | 50 | def format_postgres_values(config: dict[str, CONFIG_TYPES | None]) -> dict[str, str]: 51 | """ 52 | Re-format based on known units. The pipeline is expected to be: 53 | 54 | - format_postgres_values() 55 | - write_postgresql_conf() 56 | 57 | """ 58 | 59 | # These values are ready for direct insertion into the config file 60 | str_config: dict[str, str] = {} 61 | 62 | for key, value in config.items(): 63 | if not value: 64 | continue 65 | 66 | if key in KNOWN_STORAGE_VARS: 67 | # Storage values are always strings 68 | if not isinstance(value, (int, float)): 69 | raise ValueError(f"Storage value {key} is not a kb integer: {value}") 70 | config_value = f"'{format_kb_value(int(value))}'" 71 | else: 72 | # We should only wrap with single quotes if the original value is a string 73 | config_value = format_value(value) 74 | config_value = f"'{config_value}'" if isinstance(value, str) else config_value 75 | str_config[key] = config_value 76 | 77 | return str_config 78 | 79 | 80 | def write_postgresql_conf( 81 | config: dict[str, str], base_path: str = PG_CONFIG_DIR, backup: bool = True 82 | ) -> None: 83 | """Write the postgresql.conf file and optionally backup the old one""" 84 | conf_path = Path(base_path) / PG_CONFIG_FILE 85 | base_conf_path = Path(base_path) / PG_CONFIG_FILE_BASE 86 | 87 | # Backup existing config if requested 88 | if backup and conf_path.exists(): 89 | shutil.copy(conf_path, base_conf_path) 90 | 91 | # Write new config 92 | with open(conf_path, "w") as f: 93 | f.write("# Generated by 
AutoPG\n\n") 94 | for key, value in sorted(config.items()): 95 | f.write(f"{key} = {value}\n") 96 | 97 | 98 | def write_sql_init_file(sql_content: str, filename: str) -> tuple[bool, Path | None]: 99 | """ 100 | Write SQL initialization file, preferring Docker entrypoint directory if available. 101 | 102 | """ 103 | if not sql_content.strip(): 104 | return False, None 105 | 106 | docker_init_dir = Path("/docker-entrypoint-initdb.d") 107 | 108 | # Try to write to Docker entrypoint directory if it exists 109 | if docker_init_dir.exists() and docker_init_dir.is_dir(): 110 | init_sql_path = docker_init_dir / filename 111 | try: 112 | with open(init_sql_path, "w") as f: 113 | f.write(f"-- Generated by AutoPG\n\n{sql_content}") 114 | return True, init_sql_path 115 | except (OSError, IOError): 116 | # Fall through to non-Docker behavior if write fails 117 | pass 118 | 119 | # In non-Docker environments or if Docker write failed, return SQL content 120 | return False, None 121 | 122 | 123 | # 124 | # Formatters 125 | # 126 | 127 | 128 | def format_value(value: CONFIG_TYPES) -> str: 129 | """Format configuration values appropriately""" 130 | if isinstance(value, bool): 131 | return "true" if value else "false" 132 | elif isinstance(value, (int, float)): 133 | return str(value) 134 | return value 135 | 136 | 137 | def parse_value(value: str) -> CONFIG_TYPES: 138 | """Parse configuration values appropriately""" 139 | if value.lower() in ["true", "false"]: 140 | return value.lower() == "true" 141 | elif value.isdigit(): 142 | return int(value) 143 | return value 144 | 145 | 146 | def format_kb_value(value: int) -> str: 147 | """ 148 | Format a value in kilobytes to a human readable string with appropriate unit. 149 | The function will use the largest unit (GB, MB, KB) that results in a whole number. 150 | 151 | Args: 152 | value: The value in kilobytes to format 153 | 154 | Returns: 155 | A formatted string with the value and unit (e.g. 
"1GB", "100MB", "64kB") 156 | """ 157 | # 0 is a special case 158 | if value == 0: 159 | return "0kB" 160 | 161 | if value % (SIZE_UNIT_MAP["GB"] // SIZE_UNIT_MAP["KB"]) == 0: 162 | return f"{value // (SIZE_UNIT_MAP['GB'] // SIZE_UNIT_MAP['KB'])}GB" 163 | elif value % (SIZE_UNIT_MAP["MB"] // SIZE_UNIT_MAP["KB"]) == 0: 164 | return f"{value // (SIZE_UNIT_MAP['MB'] // SIZE_UNIT_MAP['KB'])}MB" 165 | return f"{value}kB" 166 | 167 | 168 | def parse_storage_value(value: str) -> int: 169 | """Parse storage values into kb""" 170 | if value.endswith("GB"): 171 | return int(value.strip("GB")) * SIZE_UNIT_MAP["GB"] // SIZE_UNIT_MAP["KB"] 172 | elif value.endswith("MB"): 173 | return int(value.strip("MB")) * SIZE_UNIT_MAP["MB"] // SIZE_UNIT_MAP["KB"] 174 | return int(value.strip("kB")) 175 | 176 | 177 | # 178 | # Helpers 179 | # 180 | 181 | 182 | def get_postgres_version() -> int: 183 | """Get the version of PostgreSQL installed 184 | 185 | Returns: 186 | int: The major version number of PostgreSQL (e.g. 16 for PostgreSQL 16.3) 187 | 188 | Raises: 189 | subprocess.CalledProcessError: If postgres is not installed or command fails 190 | ValueError: If version string cannot be parsed 191 | """ 192 | try: 193 | result = subprocess.run( 194 | ["postgres", "--version"], capture_output=True, text=True, check=True 195 | ) 196 | version_str = result.stdout.strip() 197 | # Use regex to find version number pattern (e.g. 
"16.3" in "postgres (PostgreSQL) 16.3 (Homebrew)") 198 | version_match = re.search(r"(\d+)\.?\d*", version_str) 199 | if not version_match: 200 | raise ValueError("Could not find version number in postgres output") 201 | return int(version_match.group(1)) 202 | except (subprocess.CalledProcessError, ValueError) as e: 203 | raise ValueError(f"Failed to get PostgreSQL version: {str(e)}") from e 204 | -------------------------------------------------------------------------------- /benchmarks/benchmarks/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for benchmarking operations. 3 | """ 4 | 5 | import statistics 6 | import time 7 | from typing import Any, Dict, Generator, List, Optional, Union 8 | 9 | 10 | def format_duration(seconds: float) -> str: 11 | """Format duration in seconds to human-readable string.""" 12 | if seconds < 0.001: 13 | return f"{seconds * 1000000:.1f}μs" 14 | elif seconds < 1: 15 | return f"{seconds * 1000:.1f}ms" 16 | elif seconds < 60: 17 | return f"{seconds:.2f}s" 18 | elif seconds < 3600: 19 | minutes = int(seconds // 60) 20 | secs = seconds % 60 21 | return f"{minutes}m {secs:.1f}s" 22 | else: 23 | hours = int(seconds // 3600) 24 | minutes = int((seconds % 3600) // 60) 25 | secs = seconds % 60 26 | return f"{hours}h {minutes}m {secs:.1f}s" 27 | 28 | 29 | def format_number(num: Union[int, float]) -> str: 30 | """Format large numbers with appropriate suffixes.""" 31 | if isinstance(num, float) and num < 1: 32 | return f"{num:.3f}" 33 | 34 | num = int(num) if isinstance(num, float) else num 35 | 36 | if num < 1000: 37 | return str(num) 38 | elif num < 1000000: 39 | return f"{num / 1000:.1f}K" 40 | elif num < 1000000000: 41 | return f"{num / 1000000:.1f}M" 42 | else: 43 | return f"{num / 1000000000:.1f}B" 44 | 45 | 46 | def calculate_statistics(values: List[float]) -> Dict[str, float]: 47 | """Calculate statistical metrics from a list of values.""" 48 | if not values: 49 | 
return {} 50 | 51 | sorted_values = sorted(values) 52 | 53 | return { 54 | "min": min(values), 55 | "max": max(values), 56 | "mean": statistics.mean(values), 57 | "median": statistics.median(values), 58 | "std_dev": statistics.stdev(values) if len(values) > 1 else 0, 59 | "p95": sorted_values[int(0.95 * len(sorted_values))], 60 | "p99": sorted_values[int(0.99 * len(sorted_values))], 61 | } 62 | 63 | 64 | def chunks(lst: List[Any], n: int) -> Generator[List[Any], None, None]: 65 | """Yield successive n-sized chunks from lst.""" 66 | for i in range(0, len(lst), n): 67 | yield lst[i : i + n] 68 | 69 | 70 | class Timer: 71 | """Simple timer class for measuring execution time.""" 72 | 73 | def __init__(self): 74 | self.start_time = None 75 | self.end_time = None 76 | 77 | def start(self) -> None: 78 | """Start the timer.""" 79 | self.start_time = time.time() 80 | 81 | def stop(self) -> float: 82 | """Stop the timer and return elapsed time.""" 83 | self.end_time = time.time() 84 | if self.start_time is None: 85 | raise RuntimeError("Timer was not started") 86 | return self.end_time - self.start_time 87 | 88 | def elapsed(self) -> float: 89 | """Get elapsed time without stopping the timer.""" 90 | if self.start_time is None: 91 | raise RuntimeError("Timer was not started") 92 | return time.time() - self.start_time 93 | 94 | def __enter__(self): 95 | self.start() 96 | return self 97 | 98 | def __exit__( 99 | self, exc_type: Optional[type], exc_val: Optional[BaseException], exc_tb: Optional[Any] 100 | ) -> None: 101 | self.stop() 102 | 103 | 104 | def generate_random_string(length: int = 10) -> str: 105 | """Generate a random string of specified length.""" 106 | import random 107 | import string 108 | 109 | return "".join(random.choices(string.ascii_lowercase + string.digits, k=length)) 110 | 111 | 112 | def generate_random_email() -> str: 113 | """Generate a random email address.""" 114 | import random 115 | 116 | domains = ["example.com", "test.org", "demo.net", 
"sample.co"] 117 | username = generate_random_string(8) 118 | domain = random.choice(domains) 119 | return f"{username}@{domain}" 120 | 121 | 122 | def generate_random_text(min_words: int = 5, max_words: int = 50) -> str: 123 | """Generate random text with specified word count range.""" 124 | import random 125 | 126 | words = [ 127 | "lorem", 128 | "ipsum", 129 | "dolor", 130 | "sit", 131 | "amet", 132 | "consectetur", 133 | "adipiscing", 134 | "elit", 135 | "sed", 136 | "do", 137 | "eiusmod", 138 | "tempor", 139 | "incididunt", 140 | "ut", 141 | "labore", 142 | "et", 143 | "dolore", 144 | "magna", 145 | "aliqua", 146 | "enim", 147 | "ad", 148 | "minim", 149 | "veniam", 150 | "quis", 151 | "nostrud", 152 | "exercitation", 153 | "ullamco", 154 | "laboris", 155 | "nisi", 156 | "aliquip", 157 | "ex", 158 | "ea", 159 | "commodo", 160 | "consequat", 161 | "duis", 162 | "aute", 163 | "irure", 164 | "in", 165 | "reprehenderit", 166 | "voluptate", 167 | "velit", 168 | "esse", 169 | "cillum", 170 | "fugiat", 171 | "nulla", 172 | "pariatur", 173 | "excepteur", 174 | "sint", 175 | "occaecat", 176 | "cupidatat", 177 | "non", 178 | "proident", 179 | "sunt", 180 | "culpa", 181 | "qui", 182 | "officia", 183 | "deserunt", 184 | "mollit", 185 | "anim", 186 | "id", 187 | "est", 188 | "laborum", 189 | ] 190 | 191 | word_count = random.randint(min_words, max_words) 192 | selected_words = random.choices(words, k=word_count) 193 | 194 | # Capitalize first word 195 | if selected_words: 196 | selected_words[0] = selected_words[0].capitalize() 197 | 198 | return " ".join(selected_words) + "." 
199 | 200 | 201 | def create_progress_callback(total: int, description: str = "Processing"): 202 | """Create a progress callback function for Rich progress bars.""" 203 | from rich.progress import ( 204 | BarColumn, 205 | Progress, 206 | SpinnerColumn, 207 | TaskProgressColumn, 208 | TextColumn, 209 | TimeRemainingColumn, 210 | ) 211 | 212 | progress = Progress( 213 | SpinnerColumn(), 214 | TextColumn("[progress.description]{task.description}"), 215 | BarColumn(), 216 | TaskProgressColumn(), 217 | TimeRemainingColumn(), 218 | console=None, 219 | transient=True, 220 | ) 221 | 222 | task_id = progress.add_task(description, total=total) 223 | 224 | def callback(completed: int): 225 | progress.update(task_id, completed=completed) 226 | 227 | return progress, callback 228 | 229 | 230 | def batch_execute_with_progress( 231 | db: Any, query: str, param_batches: List[List[Any]], description: str = "Executing batches" 232 | ) -> List[float]: 233 | """Execute query batches with progress display and timing.""" 234 | from rich.progress import ( 235 | BarColumn, 236 | Progress, 237 | SpinnerColumn, 238 | TaskProgressColumn, 239 | TextColumn, 240 | TimeRemainingColumn, 241 | ) 242 | 243 | batch_times = [] 244 | 245 | with Progress( 246 | SpinnerColumn(), 247 | TextColumn("[progress.description]{task.description}"), 248 | BarColumn(), 249 | TaskProgressColumn(), 250 | TimeRemainingColumn(), 251 | transient=False, 252 | ) as progress: 253 | task = progress.add_task(description, total=len(param_batches)) 254 | 255 | for i, batch_params in enumerate(param_batches): 256 | timer = Timer() 257 | timer.start() 258 | 259 | db.execute_many(query, batch_params) 260 | 261 | batch_time = timer.stop() 262 | batch_times.append(batch_time) 263 | 264 | progress.update(task, completed=i + 1) 265 | 266 | return batch_times 267 | -------------------------------------------------------------------------------- /autopg/__tests__/test_logic.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Tests for this core logic are almost fully written from the original Javascript: 3 | https://github.com/le0pard/pgtune/blob/9ae57d0a97ba6c597390d43b15cd428311327939/src/features/configuration/__tests__/configurationSlice.test.js 4 | 5 | """ 6 | 7 | import pytest 8 | 9 | from autopg.constants import ( 10 | DB_TYPE_DESKTOP, 11 | DB_TYPE_DW, 12 | DB_TYPE_MIXED, 13 | DB_TYPE_OLTP, 14 | DB_TYPE_WEB, 15 | HARD_DRIVE_HDD, 16 | HARD_DRIVE_SAN, 17 | HARD_DRIVE_SSD, 18 | OS_LINUX, 19 | OS_WINDOWS, 20 | ) 21 | from autopg.logic import Configuration, PostgresConfig 22 | 23 | 24 | def test_pg_stat_statements_enabled_by_default() -> None: 25 | config = PostgresConfig(Configuration()) 26 | pg_stat_config = config.get_pg_stat_statements_config() 27 | 28 | assert pg_stat_config == { 29 | "shared_preload_libraries": "pg_stat_statements", 30 | "pg_stat_statements.track": "all", 31 | "pg_stat_statements.max": 10000, 32 | } 33 | 34 | 35 | def test_pg_stat_statements_disabled() -> None: 36 | config = PostgresConfig(Configuration(enable_pg_stat_statements=False)) 37 | pg_stat_config = config.get_pg_stat_statements_config() 38 | 39 | assert pg_stat_config == {} 40 | 41 | 42 | def test_pg_stat_statements_enabled_explicitly() -> None: 43 | config = PostgresConfig(Configuration(enable_pg_stat_statements=True)) 44 | pg_stat_config = config.get_pg_stat_statements_config() 45 | 46 | assert pg_stat_config == { 47 | "shared_preload_libraries": "pg_stat_statements", 48 | "pg_stat_statements.track": "all", 49 | "pg_stat_statements.max": 10000, 50 | } 51 | 52 | 53 | def test_is_configured_nothing_set() -> None: 54 | config = PostgresConfig(Configuration()) 55 | assert config.state.total_memory is None 56 | 57 | 58 | def test_is_configured_with_memory() -> None: 59 | config = PostgresConfig( 60 | Configuration( 61 | total_memory=100, 62 | db_version=14.0, 63 | os_type=OS_LINUX, 64 | db_type=DB_TYPE_WEB, 65 | 
total_memory_unit="GB", 66 | hd_type=HARD_DRIVE_SSD, 67 | ) 68 | ) 69 | assert config.state.total_memory == 100 70 | 71 | 72 | @pytest.mark.parametrize( 73 | "db_type,expected", 74 | [ 75 | (DB_TYPE_WEB, 200), 76 | (DB_TYPE_OLTP, 300), 77 | (DB_TYPE_DW, 40), 78 | (DB_TYPE_DESKTOP, 20), 79 | (DB_TYPE_MIXED, 100), 80 | ], 81 | ) 82 | def test_max_connections(db_type: str, expected: int) -> None: 83 | config = PostgresConfig( 84 | Configuration( 85 | db_type=db_type, 86 | db_version=14.0, 87 | os_type=OS_LINUX, 88 | total_memory_unit="GB", 89 | hd_type=HARD_DRIVE_SSD, 90 | ) 91 | ) 92 | assert config.get_max_connections() == expected 93 | 94 | 95 | @pytest.mark.parametrize( 96 | "db_type,expected", 97 | [ 98 | (DB_TYPE_WEB, 100), 99 | (DB_TYPE_OLTP, 100), 100 | (DB_TYPE_DW, 500), 101 | (DB_TYPE_DESKTOP, 100), 102 | (DB_TYPE_MIXED, 100), 103 | ], 104 | ) 105 | def test_default_statistics_target(db_type: str, expected: int) -> None: 106 | config = PostgresConfig( 107 | Configuration( 108 | db_type=db_type, 109 | db_version=14.0, 110 | os_type=OS_LINUX, 111 | total_memory_unit="GB", 112 | hd_type=HARD_DRIVE_SSD, 113 | ) 114 | ) 115 | assert config.get_default_statistics_target() == expected 116 | 117 | 118 | @pytest.mark.parametrize( 119 | "hd_type,expected", 120 | [(HARD_DRIVE_HDD, 4.0), (HARD_DRIVE_SSD, 1.1), (HARD_DRIVE_SAN, 1.1)], 121 | ) 122 | def test_random_page_cost(hd_type: str, expected: float) -> None: 123 | config = PostgresConfig( 124 | Configuration( 125 | hd_type=hd_type, 126 | db_version=14.0, 127 | os_type=OS_LINUX, 128 | db_type=DB_TYPE_WEB, 129 | total_memory_unit="GB", 130 | ) 131 | ) 132 | assert config.get_random_page_cost() == expected 133 | 134 | 135 | @pytest.mark.parametrize( 136 | "os_type,hd_type,expected", 137 | [ 138 | (OS_LINUX, HARD_DRIVE_HDD, 2), 139 | (OS_LINUX, HARD_DRIVE_SSD, 200), 140 | (OS_LINUX, HARD_DRIVE_SAN, 300), 141 | (OS_WINDOWS, HARD_DRIVE_SSD, None), 142 | ], 143 | ) 144 | def test_effective_io_concurrency(os_type: str, 
hd_type: str, expected: int | None) -> None: 145 | config = PostgresConfig( 146 | Configuration( 147 | os_type=os_type, 148 | hd_type=hd_type, 149 | db_version=14.0, 150 | db_type=DB_TYPE_WEB, 151 | total_memory_unit="GB", 152 | ) 153 | ) 154 | assert config.get_effective_io_concurrency() == expected 155 | 156 | 157 | def test_parallel_settings_less_than_2_cpu() -> None: 158 | config = PostgresConfig( 159 | Configuration( 160 | cpu_num=1, 161 | db_version=14.0, 162 | os_type=OS_LINUX, 163 | db_type=DB_TYPE_WEB, 164 | total_memory_unit="GB", 165 | ) 166 | ) 167 | assert config.get_parallel_settings() == {} 168 | 169 | 170 | def test_parallel_settings_postgresql_13() -> None: 171 | config = PostgresConfig( 172 | Configuration( 173 | db_version=13.0, 174 | cpu_num=12, 175 | os_type=OS_LINUX, 176 | db_type=DB_TYPE_WEB, 177 | total_memory_unit="GB", 178 | ) 179 | ) 180 | assert config.get_parallel_settings() == { 181 | "max_worker_processes": 12, 182 | "max_parallel_workers_per_gather": 4, 183 | "max_parallel_workers": 12, 184 | "max_parallel_maintenance_workers": 4, 185 | } 186 | 187 | 188 | def test_parallel_settings_postgresql_10() -> None: 189 | config = PostgresConfig( 190 | Configuration( 191 | db_version=10.0, 192 | cpu_num=12, 193 | os_type=OS_LINUX, 194 | db_type=DB_TYPE_WEB, 195 | total_memory_unit="GB", 196 | ) 197 | ) 198 | assert config.get_parallel_settings() == { 199 | "max_worker_processes": 12, 200 | "max_parallel_workers_per_gather": 4, 201 | "max_parallel_workers": 12, 202 | } 203 | 204 | 205 | def test_parallel_settings_postgresql_10_with_31_cpu() -> None: 206 | config = PostgresConfig( 207 | Configuration( 208 | db_version=10.0, 209 | cpu_num=31, 210 | os_type=OS_LINUX, 211 | db_type=DB_TYPE_WEB, 212 | total_memory_unit="GB", 213 | ) 214 | ) 215 | assert config.get_parallel_settings() == { 216 | "max_worker_processes": 31, 217 | "max_parallel_workers_per_gather": 4, 218 | "max_parallel_workers": 31, 219 | } 220 | 221 | 222 | def 
test_parallel_settings_postgresql_12_with_31_cpu_and_dwh() -> None: 223 | config = PostgresConfig( 224 | Configuration( 225 | db_version=12.0, 226 | cpu_num=31, 227 | db_type=DB_TYPE_DW, 228 | os_type=OS_LINUX, 229 | total_memory_unit="GB", 230 | hd_type=HARD_DRIVE_SSD, 231 | ) 232 | ) 233 | assert config.get_parallel_settings() == { 234 | "max_worker_processes": 31, 235 | "max_parallel_workers_per_gather": 16, 236 | "max_parallel_workers": 31, 237 | "max_parallel_maintenance_workers": 4, 238 | } 239 | 240 | 241 | @pytest.mark.parametrize( 242 | "db_type,expected", 243 | [ 244 | ( 245 | DB_TYPE_DESKTOP, 246 | {"wal_level": "minimal", "max_wal_senders": "0"}, 247 | ), 248 | (DB_TYPE_WEB, {}), 249 | ], 250 | ) 251 | def test_wal_level(db_type: str, expected: list[dict[str, str]]) -> None: 252 | config = PostgresConfig( 253 | Configuration( 254 | db_type=db_type, 255 | db_version=14.0, 256 | os_type=OS_LINUX, 257 | total_memory_unit="GB", 258 | hd_type=HARD_DRIVE_SSD, 259 | ) 260 | ) 261 | assert config.get_wal_level() == expected 262 | -------------------------------------------------------------------------------- /benchmarks/uv.lock: -------------------------------------------------------------------------------- 1 | version = 1 2 | revision = 3 3 | requires-python = ">=3.12" 4 | 5 | [[package]] 6 | name = "asyncpg" 7 | version = "0.30.0" 8 | source = { registry = "https://pypi.org/simple" } 9 | sdist = { url = "https://files.pythonhosted.org/packages/2f/4c/7c991e080e106d854809030d8584e15b2e996e26f16aee6d757e387bc17d/asyncpg-0.30.0.tar.gz", hash = "sha256:c551e9928ab6707602f44811817f82ba3c446e018bfe1d3abecc8ba5f3eac851", size = 957746, upload-time = "2024-10-20T00:30:41.127Z" } 10 | wheels = [ 11 | { url = "https://files.pythonhosted.org/packages/4b/64/9d3e887bb7b01535fdbc45fbd5f0a8447539833b97ee69ecdbb7a79d0cb4/asyncpg-0.30.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c902a60b52e506d38d7e80e0dd5399f657220f24635fee368117b8b5fce1142e", size = 673162, 
upload-time = "2024-10-20T00:29:41.88Z" }, 12 | { url = "https://files.pythonhosted.org/packages/6e/eb/8b236663f06984f212a087b3e849731f917ab80f84450e943900e8ca4052/asyncpg-0.30.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aca1548e43bbb9f0f627a04666fedaca23db0a31a84136ad1f868cb15deb6e3a", size = 637025, upload-time = "2024-10-20T00:29:43.352Z" }, 13 | { url = "https://files.pythonhosted.org/packages/cc/57/2dc240bb263d58786cfaa60920779af6e8d32da63ab9ffc09f8312bd7a14/asyncpg-0.30.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c2a2ef565400234a633da0eafdce27e843836256d40705d83ab7ec42074efb3", size = 3496243, upload-time = "2024-10-20T00:29:44.922Z" }, 14 | { url = "https://files.pythonhosted.org/packages/f4/40/0ae9d061d278b10713ea9021ef6b703ec44698fe32178715a501ac696c6b/asyncpg-0.30.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1292b84ee06ac8a2ad8e51c7475aa309245874b61333d97411aab835c4a2f737", size = 3575059, upload-time = "2024-10-20T00:29:46.891Z" }, 15 | { url = "https://files.pythonhosted.org/packages/c3/75/d6b895a35a2c6506952247640178e5f768eeb28b2e20299b6a6f1d743ba0/asyncpg-0.30.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0f5712350388d0cd0615caec629ad53c81e506b1abaaf8d14c93f54b35e3595a", size = 3473596, upload-time = "2024-10-20T00:29:49.201Z" }, 16 | { url = "https://files.pythonhosted.org/packages/c8/e7/3693392d3e168ab0aebb2d361431375bd22ffc7b4a586a0fc060d519fae7/asyncpg-0.30.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:db9891e2d76e6f425746c5d2da01921e9a16b5a71a1c905b13f30e12a257c4af", size = 3641632, upload-time = "2024-10-20T00:29:50.768Z" }, 17 | { url = "https://files.pythonhosted.org/packages/32/ea/15670cea95745bba3f0352341db55f506a820b21c619ee66b7d12ea7867d/asyncpg-0.30.0-cp312-cp312-win32.whl", hash = "sha256:68d71a1be3d83d0570049cd1654a9bdfe506e794ecc98ad0873304a9f35e411e", size = 560186, upload-time = "2024-10-20T00:29:52.394Z" }, 18 | { url = 
"https://files.pythonhosted.org/packages/7e/6b/fe1fad5cee79ca5f5c27aed7bd95baee529c1bf8a387435c8ba4fe53d5c1/asyncpg-0.30.0-cp312-cp312-win_amd64.whl", hash = "sha256:9a0292c6af5c500523949155ec17b7fe01a00ace33b68a476d6b5059f9630305", size = 621064, upload-time = "2024-10-20T00:29:53.757Z" }, 19 | { url = "https://files.pythonhosted.org/packages/3a/22/e20602e1218dc07692acf70d5b902be820168d6282e69ef0d3cb920dc36f/asyncpg-0.30.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:05b185ebb8083c8568ea8a40e896d5f7af4b8554b64d7719c0eaa1eb5a5c3a70", size = 670373, upload-time = "2024-10-20T00:29:55.165Z" }, 20 | { url = "https://files.pythonhosted.org/packages/3d/b3/0cf269a9d647852a95c06eb00b815d0b95a4eb4b55aa2d6ba680971733b9/asyncpg-0.30.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c47806b1a8cbb0a0db896f4cd34d89942effe353a5035c62734ab13b9f938da3", size = 634745, upload-time = "2024-10-20T00:29:57.14Z" }, 21 | { url = "https://files.pythonhosted.org/packages/8e/6d/a4f31bf358ce8491d2a31bfe0d7bcf25269e80481e49de4d8616c4295a34/asyncpg-0.30.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b6fde867a74e8c76c71e2f64f80c64c0f3163e687f1763cfaf21633ec24ec33", size = 3512103, upload-time = "2024-10-20T00:29:58.499Z" }, 22 | { url = "https://files.pythonhosted.org/packages/96/19/139227a6e67f407b9c386cb594d9628c6c78c9024f26df87c912fabd4368/asyncpg-0.30.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46973045b567972128a27d40001124fbc821c87a6cade040cfcd4fa8a30bcdc4", size = 3592471, upload-time = "2024-10-20T00:30:00.354Z" }, 23 | { url = "https://files.pythonhosted.org/packages/67/e4/ab3ca38f628f53f0fd28d3ff20edff1c975dd1cb22482e0061916b4b9a74/asyncpg-0.30.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9110df111cabc2ed81aad2f35394a00cadf4f2e0635603db6ebbd0fc896f46a4", size = 3496253, upload-time = "2024-10-20T00:30:02.794Z" }, 24 | { url = 
"https://files.pythonhosted.org/packages/ef/5f/0bf65511d4eeac3a1f41c54034a492515a707c6edbc642174ae79034d3ba/asyncpg-0.30.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04ff0785ae7eed6cc138e73fc67b8e51d54ee7a3ce9b63666ce55a0bf095f7ba", size = 3662720, upload-time = "2024-10-20T00:30:04.501Z" }, 25 | { url = "https://files.pythonhosted.org/packages/e7/31/1513d5a6412b98052c3ed9158d783b1e09d0910f51fbe0e05f56cc370bc4/asyncpg-0.30.0-cp313-cp313-win32.whl", hash = "sha256:ae374585f51c2b444510cdf3595b97ece4f233fde739aa14b50e0d64e8a7a590", size = 560404, upload-time = "2024-10-20T00:30:06.537Z" }, 26 | { url = "https://files.pythonhosted.org/packages/c8/a4/cec76b3389c4c5ff66301cd100fe88c318563ec8a520e0b2e792b5b84972/asyncpg-0.30.0-cp313-cp313-win_amd64.whl", hash = "sha256:f59b430b8e27557c3fb9869222559f7417ced18688375825f8f12302c34e915e", size = 621623, upload-time = "2024-10-20T00:30:09.024Z" }, 27 | ] 28 | 29 | [[package]] 30 | name = "benchmarks" 31 | version = "0.1.0" 32 | source = { editable = "." 
} 33 | dependencies = [ 34 | { name = "asyncpg" }, 35 | { name = "click" }, 36 | { name = "rich" }, 37 | ] 38 | 39 | [package.metadata] 40 | requires-dist = [ 41 | { name = "asyncpg", specifier = ">=0.30.0" }, 42 | { name = "click", specifier = ">=8.1.8" }, 43 | { name = "rich", specifier = ">=13.9.4" }, 44 | ] 45 | 46 | [[package]] 47 | name = "click" 48 | version = "8.2.1" 49 | source = { registry = "https://pypi.org/simple" } 50 | dependencies = [ 51 | { name = "colorama", marker = "sys_platform == 'win32'" }, 52 | ] 53 | sdist = { url = "https://files.pythonhosted.org/packages/60/6c/8ca2efa64cf75a977a0d7fac081354553ebe483345c734fb6b6515d96bbc/click-8.2.1.tar.gz", hash = "sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202", size = 286342, upload-time = "2025-05-20T23:19:49.832Z" } 54 | wheels = [ 55 | { url = "https://files.pythonhosted.org/packages/85/32/10bb5764d90a8eee674e9dc6f4db6a0ab47c8c4d0d83c27f7c39ac415a4d/click-8.2.1-py3-none-any.whl", hash = "sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b", size = 102215, upload-time = "2025-05-20T23:19:47.796Z" }, 56 | ] 57 | 58 | [[package]] 59 | name = "colorama" 60 | version = "0.4.6" 61 | source = { registry = "https://pypi.org/simple" } 62 | sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } 63 | wheels = [ 64 | { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, 65 | ] 66 | 67 | [[package]] 68 | name = "markdown-it-py" 69 | version = "4.0.0" 70 | source = { registry = "https://pypi.org/simple" } 
71 | dependencies = [ 72 | { name = "mdurl" }, 73 | ] 74 | sdist = { url = "https://files.pythonhosted.org/packages/5b/f5/4ec618ed16cc4f8fb3b701563655a69816155e79e24a17b651541804721d/markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3", size = 73070, upload-time = "2025-08-11T12:57:52.854Z" } 75 | wheels = [ 76 | { url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" }, 77 | ] 78 | 79 | [[package]] 80 | name = "mdurl" 81 | version = "0.1.2" 82 | source = { registry = "https://pypi.org/simple" } 83 | sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" } 84 | wheels = [ 85 | { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, 86 | ] 87 | 88 | [[package]] 89 | name = "pygments" 90 | version = "2.19.2" 91 | source = { registry = "https://pypi.org/simple" } 92 | sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" } 93 | wheels = [ 94 | { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = 
"sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, 95 | ] 96 | 97 | [[package]] 98 | name = "rich" 99 | version = "14.1.0" 100 | source = { registry = "https://pypi.org/simple" } 101 | dependencies = [ 102 | { name = "markdown-it-py" }, 103 | { name = "pygments" }, 104 | ] 105 | sdist = { url = "https://files.pythonhosted.org/packages/fe/75/af448d8e52bf1d8fa6a9d089ca6c07ff4453d86c65c145d0a300bb073b9b/rich-14.1.0.tar.gz", hash = "sha256:e497a48b844b0320d45007cdebfeaeed8db2a4f4bcf49f15e455cfc4af11eaa8", size = 224441, upload-time = "2025-07-25T07:32:58.125Z" } 106 | wheels = [ 107 | { url = "https://files.pythonhosted.org/packages/e3/30/3c4d035596d3cf444529e0b2953ad0466f6049528a879d27534700580395/rich-14.1.0-py3-none-any.whl", hash = "sha256:536f5f1785986d6dbdea3c75205c473f970777b4a0d6c6dd1b696aa05a3fa04f", size = 243368, upload-time = "2025-07-25T07:32:56.73Z" }, 108 | ] 109 | -------------------------------------------------------------------------------- /autopg/cli.py: -------------------------------------------------------------------------------- 1 | import platform 2 | import sys 3 | from dataclasses import dataclass 4 | from enum import StrEnum 5 | from pathlib import Path 6 | from typing import Any, Dict 7 | 8 | import click 9 | from pydantic_settings import BaseSettings 10 | from rich.console import Console 11 | from rich.table import Table 12 | 13 | from autopg.constants import ( 14 | DB_TYPE_WEB, 15 | HARD_DRIVE_SSD, 16 | OS_LINUX, 17 | OS_MAC, 18 | OS_WINDOWS, 19 | SIZE_UNIT_MB, 20 | ) 21 | from autopg.logic import Configuration, PostgresConfig 22 | from autopg.postgres import ( 23 | CONFIG_TYPES, 24 | format_postgres_values, 25 | get_postgres_version, 26 | read_postgresql_conf, 27 | write_postgresql_conf, 28 | write_sql_init_file, 29 | ) 30 | from autopg.system_info import DiskType, get_cpu_info, get_disk_type, get_memory_info 31 | 32 | console = Console() 33 | 34 | 35 | 
class DBType(StrEnum):
    """Database workload profiles; the selected profile drives the tuning presets."""

    WEB = "web"
    """
    Web Application
    Typically CPU-bound, DB much smaller than RAM, 90% or more simple queries
    """

    OLTP = "oltp"
    """
    Online Transaction Processing
    Typically CPU- or I/O-bound, DB slightly larger than RAM to 1TB, 20-40% small data write queries
    """

    DW = "dw"
    """
    Data Warehouse
    Typically I/O- or RAM-bound, large bulk loads of data, large complex reporting queries
    """

    DESKTOP = "desktop"
    """
    Desktop Application
    Not a dedicated database, general workstation use
    """

    MIXED = "mixed"
    """
    Mixed Type
    Mixed DW and OLTP characteristics, wide mixture of queries
    """


@dataclass
class DBDefinition:
    """Name/description pair for a database workload type."""

    # NOTE(review): not referenced anywhere in this module — presumably consumed
    # elsewhere (e.g. the webapp); confirm before removing.
    name: str
    description: str


class EnvOverrides(BaseSettings):
    """
    Users can optionally override our detected system information. These are reasonable
    defaults for most applications where we have no other context.
77 | """ 78 | 79 | DB_TYPE: DBType = DBType.WEB 80 | TOTAL_MEMORY_MB: int | None = None 81 | CPU_COUNT: int | None = None 82 | NUM_CONNECTIONS: int | None = 100 83 | PRIMARY_DISK_TYPE: DiskType | None = None 84 | ENABLE_PG_STAT_STATEMENTS: bool = True 85 | 86 | model_config = {"env_file": ".env", "env_prefix": "AUTOPG_"} 87 | 88 | 89 | def get_os_type() -> str: 90 | system = platform.system().lower() 91 | if system == "darwin": 92 | return OS_MAC 93 | elif system == "windows": 94 | return OS_WINDOWS 95 | return OS_LINUX 96 | 97 | 98 | def display_config_diff(old_config: Dict[str, Any], new_config: Dict[str, Any]) -> None: 99 | """Display the configuration differences in a rich table""" 100 | table = Table(title="Autopg Configuration") 101 | table.add_column("Parameter") 102 | table.add_column("Old Value") 103 | table.add_column("New Value") 104 | table.add_column("Source") 105 | 106 | all_keys = sorted(set(old_config.keys()) | set(new_config.keys())) 107 | for key in all_keys: 108 | old_val = old_config.get(key, "") 109 | new_val = new_config.get(key, "") 110 | source = "Existing" if key in old_config else "AutoPG" 111 | 112 | if old_val != new_val: 113 | table.add_row(key, str(old_val), str(new_val), source) 114 | 115 | console.print(table) 116 | 117 | 118 | def display_detected_params(config: Configuration) -> None: 119 | """Display the detected system parameters in a rich table""" 120 | table = Table(title="Detected System Parameters") 121 | table.add_column("Parameter") 122 | table.add_column("Value") 123 | 124 | # Add all configuration parameters 125 | table.add_row("Database Version", str(config.db_version)) 126 | table.add_row("Operating System", config.os_type) 127 | table.add_row("Database Type", config.db_type) 128 | table.add_row("Total Memory (MB)", str(config.total_memory)) 129 | table.add_row("Memory Unit", config.total_memory_unit) 130 | table.add_row("CPU Count", str(config.cpu_num)) 131 | table.add_row("Connection Count", 
                  str(config.connection_num))
    table.add_row("Hard Drive Type", config.hd_type)
    table.add_row(
        "pg_stat_statements", "Enabled" if config.enable_pg_stat_statements else "Disabled"
    )

    console.print(table)
    console.print()


@click.group()
def cli() -> None:
    """AutoPG CLI tool for PostgreSQL configuration and system analysis."""
    pass


@cli.command()
def webapp() -> None:
    """Start the AutoPG diagnostics web application."""
    # Imported lazily so the web stack is only loaded when this command runs
    from autopg.webapp import start_webapp

    start_webapp()


@cli.command()
@click.option(
    "--pg-path", default="/etc/postgresql", help="Path to PostgreSQL configuration directory"
)
def build_config(pg_path: str) -> None:
    """Build a PostgreSQL configuration based on workload and system characteristics."""
    # Load environment overrides
    env = EnvOverrides()

    # Get system information
    memory_info = get_memory_info()
    cpu_info = get_cpu_info()
    disk_type = get_disk_type()
    os_type = get_os_type()
    postgres_version = get_postgres_version()

    # Configure with detected values, allowing env overrides
    config_payload = Configuration(
        db_version=postgres_version,
        os_type=os_type,
        db_type=env.DB_TYPE or DB_TYPE_WEB,
        total_memory=(
            (int(env.TOTAL_MEMORY_MB) if env.TOTAL_MEMORY_MB else None)
            # NOTE(review): assumes memory_info.total is reported in GB and is
            # converted to MB here (unit below is SIZE_UNIT_MB) — confirm against
            # get_memory_info()
            or (int(memory_info.total * 1024) if memory_info.total else None)
        ),
        total_memory_unit=SIZE_UNIT_MB,
        cpu_num=env.CPU_COUNT or cpu_info.count,
        connection_num=env.NUM_CONNECTIONS,
        hd_type=env.PRIMARY_DISK_TYPE or disk_type or HARD_DRIVE_SSD,
        enable_pg_stat_statements=env.ENABLE_PG_STAT_STATEMENTS,
    )

    # Display detected parameters
    display_detected_params(config_payload)

    # Initialize PostgreSQL config calculator
    pg_config =
PostgresConfig(config_payload)

    # Calculate recommended settings
    new_config: dict[str, CONFIG_TYPES | None] = {
        "shared_buffers": pg_config.get_shared_buffers(),
        "effective_cache_size": pg_config.get_effective_cache_size(),
        "maintenance_work_mem": pg_config.get_maintenance_work_mem(),
        "work_mem": pg_config.get_work_mem(),
        "huge_pages": pg_config.get_huge_pages(),
        "default_statistics_target": pg_config.get_default_statistics_target(),
        "random_page_cost": pg_config.get_random_page_cost(),
        "checkpoint_completion_target": pg_config.get_checkpoint_completion_target(),
        "max_connections": pg_config.get_max_connections(),
    }

    # Add WAL settings
    new_config = {**new_config, **pg_config.get_checkpoint_segments()}

    # Add parallel settings
    new_config = {**new_config, **pg_config.get_parallel_settings()}

    # Add WAL level settings
    new_config = {**new_config, **pg_config.get_wal_level()}

    # Add pg_stat_statements settings
    new_config = {**new_config, **pg_config.get_pg_stat_statements_config()}

    # Add WAL buffers if available
    wal_buffers = pg_config.get_wal_buffers()
    if wal_buffers is not None:
        new_config["wal_buffers"] = wal_buffers

    # Add IO concurrency if available
    io_concurrency = pg_config.get_effective_io_concurrency()
    if io_concurrency is not None:
        new_config["effective_io_concurrency"] = io_concurrency

    # Add in the docker specific settings
    new_config["listen_addresses"] = "*"
    new_config["dynamic_shared_memory_type"] = "posix"
    new_config["log_timezone"] = "Etc/UTC"
    new_config["datestyle"] = "iso, mdy"
    new_config["timezone"] = "Etc/UTC"

    # Merge configurations, preferring existing values
    # (existing keys win because they are spread last in the merge)
    existing_config = read_postgresql_conf(pg_path)
    final_config = format_postgres_values({**new_config, **existing_config})

    # Display the differences
    display_config_diff(existing_config, final_config)

    # Check for any warnings
    warnings = pg_config.get_warning_info_messages()
    if warnings:
        console.print("\n[yellow]Warnings:[/yellow]")
        for warning in warnings:
            console.print(f"[yellow]- {warning}[/yellow]")

    # Write the new configuration
    try:
        write_postgresql_conf(final_config, pg_path)
        console.print("\n[green]Successfully wrote new PostgreSQL configuration![/green]")

        # Write SQL initialization file if pg_stat_statements is enabled
        init_sql = pg_config.get_pg_stat_statements_sql()
        if init_sql.strip():
            # Best-effort: on failure the SQL is surfaced to the user instead
            # of aborting the whole run
            success, _ = write_sql_init_file(init_sql, "init_extensions.sql")
            if not success:
                console.print(
                    "\n[yellow]Failed to write SQL initialization file. Run this SQL manually:[/yellow]"
                )
                console.print(f"[yellow]{init_sql}[/yellow]")

    except Exception as e:
        console.print(f"\n[red]Error writing configuration: {str(e)}[/red]")
        sys.exit(1)


@cli.command()
@click.option(
    "--output-dir",
    type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path),
    default=None,
    help="Output directory for CSS files (defaults to autopg/static/)",
)
@click.option(
    "--style",
    type=str,
    default="default",
    help="Pygments style to use (default, github, monokai, etc.)",
)
def generate_css(output_dir: Path | None, style: str) -> None:
    """Generate Pygments CSS for SQL syntax highlighting."""
    # pygments is an optional dependency for this command only
    try:
        from pygments.formatters import HtmlFormatter
    except ImportError:
        console.print(
            "[red]Error: pygments is not installed. 
Install it with: pip install pygments[/red]"
        )
        sys.exit(1)

    if output_dir is None:
        # Default to the static directory relative to this file
        output_dir = Path(__file__).parent / "static"

    output_dir.mkdir(parents=True, exist_ok=True)
    css_file = output_dir / "pygments.css"

    console.print(f"Generating Pygments CSS with style '{style}'...")

    # Create HTML formatter with the specified style
    formatter = HtmlFormatter(  # type: ignore[no-untyped-call]
        style=style, cssclass="highlight", noclasses=False
    )

    # Generate CSS scoped to the .highlight container class
    css_content = formatter.get_style_defs(".highlight")  # type: ignore[no-untyped-call]

    # Write CSS file
    with open(css_file, "w", encoding="utf-8") as f:
        f.write(css_content)

    console.print(f"[green]✓ Generated Pygments CSS: {css_file}[/green]")
    console.print(f"[blue]Style used: {style}[/blue]")
    console.print("[yellow]Don't forget to include this CSS file in your HTML![/yellow]")
--------------------------------------------------------------------------------
/benchmarks/benchmarks/cli.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
"""
Benchmarking CLI for AutoPG - Load testing PostgreSQL with unoptimized queries using asyncpg.
4 | """ 5 | 6 | import asyncio 7 | import os 8 | import sys 9 | from typing import Optional 10 | 11 | import click 12 | from rich.console import Console 13 | from rich.panel import Panel 14 | from rich.table import Table 15 | 16 | from .database import AsyncDatabaseConnection 17 | from .insertion import InsertionBenchmark 18 | from .seqscan import SequentialScanBenchmark 19 | from .utils import format_duration, format_number 20 | 21 | console = Console() 22 | 23 | 24 | @click.group() 25 | @click.option( 26 | "--host", default=lambda: os.getenv("POSTGRES_HOST", "localhost"), help="PostgreSQL host" 27 | ) 28 | @click.option( 29 | "--port", default=lambda: int(os.getenv("POSTGRES_PORT", "5432")), help="PostgreSQL port" 30 | ) 31 | @click.option( 32 | "--database", 33 | default=lambda: os.getenv("POSTGRES_DB", "benchmark"), 34 | help="PostgreSQL database name", 35 | ) 36 | @click.option( 37 | "--user", default=lambda: os.getenv("POSTGRES_USER", "postgres"), help="PostgreSQL username" 38 | ) 39 | @click.option( 40 | "--password", 41 | default=lambda: os.getenv("POSTGRES_PASSWORD", "postgres"), 42 | help="PostgreSQL password", 43 | ) 44 | @click.option("--verbose", "-v", is_flag=True, help="Enable verbose output") 45 | @click.pass_context 46 | def cli( 47 | ctx: click.Context, host: str, port: int, database: str, user: str, password: str, verbose: bool 48 | ) -> None: 49 | """AutoPG Database Benchmarking Tool - Load test your PostgreSQL instance.""" 50 | ctx.ensure_object(dict) 51 | ctx.obj["db_config"] = { 52 | "host": host, 53 | "port": port, 54 | "database": database, 55 | "user": user, 56 | "password": password, 57 | } 58 | ctx.obj["verbose"] = verbose 59 | 60 | # Test database connection 61 | async def test_connection(): 62 | try: 63 | async with AsyncDatabaseConnection(**ctx.obj["db_config"]) as db: 64 | await db.execute("SELECT 1") 65 | if verbose: 66 | console.print( 67 | f"✅ Connected to PostgreSQL at {host}:{port}/{database}", style="green" 68 | ) 69 | return 
True 70 | except Exception as e: 71 | console.print(f"❌ Failed to connect to PostgreSQL: {e}", style="red") 72 | return False 73 | 74 | # Run the async connection test 75 | if not asyncio.run(test_connection()): 76 | sys.exit(1) 77 | 78 | 79 | @cli.command() 80 | @click.option("--records", "-n", default=10000, help="Number of records to insert") 81 | @click.option("--batch-size", "-b", default=1000, help="Batch size for insertions") 82 | @click.option("--workers", "-w", default=1, help="Number of concurrent workers") 83 | @click.option( 84 | "--table", 85 | default="users", 86 | type=click.Choice(["users", "posts", "comments", "events"]), 87 | help="Table to insert into", 88 | ) 89 | @click.pass_context 90 | def insert(ctx: click.Context, records: int, batch_size: int, workers: int, table: str) -> None: 91 | """Run insertion load test on unoptimized tables.""" 92 | console.print( 93 | Panel.fit( 94 | f"[bold blue]Insertion Benchmark[/bold blue]\n" 95 | f"Table: {table}\n" 96 | f"Records: {format_number(records)}\n" 97 | f"Batch Size: {format_number(batch_size)}\n" 98 | f"Workers: {workers}", 99 | title="Configuration", 100 | ) 101 | ) 102 | 103 | benchmark = InsertionBenchmark(ctx.obj["db_config"], verbose=ctx.obj["verbose"]) 104 | results = benchmark.run( 105 | table_name=table, num_records=records, batch_size=batch_size, num_workers=workers 106 | ) 107 | 108 | _display_results("Insertion Benchmark Results", results) 109 | 110 | 111 | @cli.command() 112 | @click.option("--iterations", "-i", default=10, help="Number of scan iterations") 113 | @click.option( 114 | "--table", 115 | default="posts", 116 | type=click.Choice(["users", "posts", "comments", "events"]), 117 | help="Table to scan", 118 | ) 119 | @click.option("--limit", default=None, type=int, help="LIMIT clause for scans") 120 | @click.option("--workers", "-w", default=1, help="Number of concurrent workers") 121 | @click.pass_context 122 | def seqscan( 123 | ctx: click.Context, iterations: int, table: str, 
    limit: Optional[int], workers: int
) -> None:
    """Run sequential scan load test on unoptimized tables."""
    console.print(
        Panel.fit(
            f"[bold blue]Sequential Scan Benchmark[/bold blue]\n"
            f"Table: {table}\n"
            f"Iterations: {iterations}\n"
            f"Limit: {limit or 'None'}\n"
            f"Workers: {workers}",
            title="Configuration",
        )
    )

    benchmark = SequentialScanBenchmark(ctx.obj["db_config"], verbose=ctx.obj["verbose"])
    results = benchmark.run(
        table_name=table, iterations=iterations, limit=limit, num_workers=workers
    )

    _display_results("Sequential Scan Benchmark Results", results)


@cli.command()
@click.option("--insert-records", default=10000, help="Records to insert per table")
@click.option("--scan-iterations", default=5, help="Sequential scan iterations")
@click.option("--workers", "-w", default=2, help="Number of concurrent workers")
@click.pass_context
def full(ctx: click.Context, insert_records: int, scan_iterations: int, workers: int) -> None:
    """Run complete benchmark suite (insert + sequential scans)."""
    console.print(
        Panel.fit(
            f"[bold blue]Full Benchmark Suite[/bold blue]\n"
            f"Insert Records: {format_number(insert_records)}\n"
            f"Scan Iterations: {scan_iterations}\n"
            f"Workers: {workers}",
            title="Configuration",
        )
    )

    all_results = {}

    # Run insertion benchmarks
    console.print("\n[bold yellow]Phase 1: Insertion Benchmarks[/bold yellow]")
    insertion_benchmark = InsertionBenchmark(ctx.obj["db_config"], verbose=ctx.obj["verbose"])

    for table in ["users", "posts", "comments", "events"]:
        console.print(f"\n[cyan]Inserting into {table}...[/cyan]")
        results = insertion_benchmark.run(
            table_name=table, num_records=insert_records, batch_size=1000, num_workers=workers
        )
        all_results[f"insert_{table}"] = results

    # Run sequential scan benchmarks
    console.print("\n[bold yellow]Phase 2: Sequential Scan Benchmarks[/bold yellow]")
    seqscan_benchmark = SequentialScanBenchmark(ctx.obj["db_config"], verbose=ctx.obj["verbose"])

    for table in ["users", "posts", "comments", "events"]:
        console.print(f"\n[cyan]Sequential scanning {table}...[/cyan]")
        results = seqscan_benchmark.run(
            table_name=table, iterations=scan_iterations, limit=None, num_workers=workers
        )
        all_results[f"seqscan_{table}"] = results

    # Display summary
    _display_full_results(all_results)


@cli.command()
@click.pass_context
def status(ctx: click.Context) -> None:
    """Show database status and table statistics."""

    async def get_status() -> None:
        async with AsyncDatabaseConnection(**ctx.obj["db_config"]) as db:
            # Get table sizes and per-table activity counters
            table_stats = await db.execute("""
                SELECT
                    schemaname,
                    tablename,
                    pg_size_pretty(pg_total_relation_size(schemaname||'.'||tablename)) as size,
                    n_tup_ins as inserts,
                    n_tup_upd as updates,
                    n_tup_del as deletes,
                    seq_scan,
                    seq_tup_read,
                    idx_scan,
                    idx_tup_fetch
                FROM pg_stat_user_tables
                WHERE schemaname = 'benchmark'
                ORDER BY pg_total_relation_size(schemaname||'.'||tablename) DESC
            """)

            # Create table
            table = Table(title="Database Table Statistics")
            table.add_column("Table", style="cyan")
            table.add_column("Size", style="magenta")
            table.add_column("Rows", style="green")
            table.add_column("Seq Scans", style="yellow")
            table.add_column("Seq Reads", style="yellow")
            table.add_column("Index Scans", style="blue")
            table.add_column("Index Fetches", style="blue")

            for row in table_stats:
                # Get row count
                # NOTE(review): table name is interpolated into SQL; values come
                # from pg_stat_user_tables (the catalog), not user input
                count_result = await db.execute_one(
                    f"SELECT COUNT(*) FROM benchmark.{row['tablename']}"
                )
                row_count = format_number(count_result[0]) if count_result else "0"

                table.add_row(
                    row["tablename"],
                    row["size"],
                    row_count,
                    format_number(row["seq_scan"]),
                    format_number(row["seq_tup_read"]),
                    # idx_* counters are NULL when a table has no indexes
                    format_number(row["idx_scan"] or 0),
                    format_number(row["idx_tup_fetch"] or 0),
                )

            console.print(table)

    asyncio.run(get_status())


def _display_results(title: str, results: dict) -> None:
    """Display benchmark results in a formatted table."""
    table = Table(title=title)
    table.add_column("Metric", style="cyan")
    table.add_column("Value", style="magenta")

    # Core metrics
    table.add_row("Total Duration", format_duration(results["total_duration"]))
    table.add_row("Records Processed", format_number(results["records_processed"]))
    table.add_row("Records/Second", format_number(results["records_per_second"]))

    # Insertion-only metrics
    if "batches_processed" in results:
        table.add_row("Batches Processed", format_number(results["batches_processed"]))
        table.add_row("Avg Batch Time", format_duration(results["avg_batch_time"]))

    # Scan-only metrics
    if "iterations" in results:
        table.add_row("Iterations", format_number(results["iterations"]))
        table.add_row("Avg Iteration Time", format_duration(results["avg_iteration_time"]))

    # Performance metrics
    if "min_time" in results:
        table.add_row("Min Time", format_duration(results["min_time"]))
        table.add_row("Max Time", format_duration(results["max_time"]))
        table.add_row("Median Time", format_duration(results["median_time"]))

    console.print(table)


def _display_full_results(all_results: dict) -> None:
    """Display results from the full benchmark suite.

    Keys are expected to be prefixed "insert_<table>" / "seqscan_<table>" as
    produced by the ``full`` command.
    """
    console.print("\n[bold green]📊 Full Benchmark Results Summary[/bold green]")

    # Insertion results
    insert_table = Table(title="Insertion Benchmark Summary")
    insert_table.add_column("Table", style="cyan")
    insert_table.add_column("Records", style="green")
    insert_table.add_column("Duration", style="yellow")
    insert_table.add_column("Records/sec", style="magenta")

    for key, results in all_results.items():
        if key.startswith("insert_"):
            table_name = key.replace("insert_", "")
            insert_table.add_row(
                table_name,
                format_number(results["records_processed"]),
                format_duration(results["total_duration"]),
                format_number(results["records_per_second"]),
            )

    console.print(insert_table)

    # Sequential scan results
    scan_table = Table(title="Sequential Scan Benchmark Summary")
    scan_table.add_column("Table", style="cyan")
    scan_table.add_column("Iterations", style="green")
    scan_table.add_column("Avg Duration", style="yellow")
    scan_table.add_column("Records/sec", style="magenta")

    for key, results in all_results.items():
        if key.startswith("seqscan_"):
            table_name = key.replace("seqscan_", "")
            scan_table.add_row(
                table_name,
                format_number(results["iterations"]),
                format_duration(results["avg_iteration_time"]),
                format_number(results["records_per_second"]),
            )

    console.print(scan_table)


if __name__ == "__main__":
    cli()
--------------------------------------------------------------------------------
/autopgpool/autopgpool/__tests__/test_docker.py:
--------------------------------------------------------------------------------
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from time import sleep, time
from typing import Generator, TypeVar

import psycopg
import pytest
import tomli_w

from autopgpool.config import MainConfig, PgbouncerConfig, Pool, User
from autopgpool.logging import CONSOLE

T = TypeVar("T")
@pytest.fixture
def temp_workspace() -> Generator[Path, None, None]:
    """Create a temporary workspace for Docker tests.

    Yields a copy of the current working directory so Docker builds cannot
    mutate the real checkout.
    """
    # NOTE(review): copies the entire cwd (including any build artifacts) —
    # slow for large workspaces; consider an ignore filter.
    with tempfile.TemporaryDirectory() as temp_dir:
        workspace = Path(temp_dir) / "workspace"
        # Copy current workspace to temp directory
        shutil.copytree(os.getcwd(), workspace, dirs_exist_ok=True)
        yield workspace


def build_autopgpool_docker_image(temp_workspace: Path) -> str:
    """
    Build the AutoPGPool Docker image for testing.

    :param temp_workspace: Temporary directory containing a copy of the workspace
    :return: Docker image tag
    """
    test_tag = "autopgpool:test"
    subprocess.run(
        [
            "docker",
            "build",
            "-t",
            test_tag,
            ".",
        ],
        cwd=temp_workspace,
        check=True,
    )
    return test_tag


def create_docker_network() -> str:
    """
    Create a Docker network for testing.

    :return: Network name
    """
    # Timestamp suffix keeps names unique across sequential runs
    network_name = f"autopgpool-test-{int(time())}"
    subprocess.run(
        ["docker", "network", "create", network_name],
        check=True,
    )
    return network_name


def remove_docker_network(network_name: str) -> None:
    """
    Remove a Docker network.

    :param network_name: Name of the network to remove
    """
    subprocess.run(
        ["docker", "network", "rm", network_name],
        check=True,
    )


def start_postgres_container(temp_workspace: Path, network_name: str) -> tuple[str, str, int]:
    """
    Start a PostgreSQL container for testing.

    :param temp_workspace: Temporary directory containing a copy of the workspace
    :param network_name: Docker network name to connect to
    :return: Container ID, container name, and mapped port
    """
    # NOTE(review): despite the original comment claiming a random port, this
    # host port is fixed — the test fails if 5433 is already bound.
    postgres_port = 5433
    container_name = f"postgres-{int(time())}"

    container_id = (
        subprocess.check_output(
            [
                "docker",
                "run",
                "-d",
                "--name",
                container_name,
                "--network",
                network_name,
                "-p",
                f"{postgres_port}:5432",
                "-e",
                "POSTGRES_USER=test_user",
                "-e",
                "POSTGRES_PASSWORD=test_password",
                "-e",
                "POSTGRES_DB=test_db",
                "postgres:15",
            ],
            cwd=temp_workspace,
        )
        .decode()
        .strip()
    )

    return container_id, container_name, postgres_port


def wait_for_postgres(container_id: str, timeout_seconds: int = 30) -> None:
    """
    Wait for PostgreSQL to be ready with a timeout.

    Polls ``pg_isready`` inside the container once per second until it
    succeeds or the timeout expires.

    :param container_id: Docker container ID
    :param timeout_seconds: Maximum time to wait in seconds
    :raises TimeoutError: if the server is not ready within the timeout
    """
    start_time = time()
    while True:
        if time() - start_time > timeout_seconds:
            raise TimeoutError(f"PostgreSQL not ready after {timeout_seconds} seconds")

        try:
            subprocess.run(
                ["docker", "exec", container_id, "pg_isready", "-t", "5"],
                check=True,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            CONSOLE.print("PostgreSQL is ready")
            break
        except subprocess.CalledProcessError:
            sleep(1)

    # Give time to fully boot and be reachable
    sleep(2)


def create_test_config(temp_workspace: Path, postgres_host: str, postgres_port: int) -> Path:
    """
    Create a test configuration file for autopgpool using the pydantic models.
149 | 150 | :param temp_workspace: Temporary directory containing a copy of the workspace 151 | :param postgres_host: Hostname of the PostgreSQL container 152 | :param postgres_port: Port of the PostgreSQL container 153 | :return: Path to the configuration file 154 | """ 155 | # Create the user model 156 | test_user = User(username="test_user", password="test_password", grants=["test_db"]) 157 | 158 | # Create the pool model for test_db 159 | test_pool = Pool( 160 | remote=Pool.RemoteDatabase( 161 | host=postgres_host, 162 | port=postgres_port, 163 | database="test_db", 164 | username="test_user", 165 | password="test_password", 166 | ), 167 | pool_mode="transaction", 168 | ) 169 | 170 | # Create pgbouncer config 171 | pgbouncer_config = PgbouncerConfig( 172 | listen_addr="0.0.0.0", 173 | listen_port=6432, 174 | auth_type="md5", 175 | pool_mode="transaction", 176 | max_client_conn=100, 177 | default_pool_size=20, 178 | ignore_startup_parameters=["extra_float_digits"], 179 | # Explicitly set these to empty lists instead of None 180 | admin_users=[], 181 | stats_users=[], 182 | ) 183 | 184 | # Create the main config 185 | main_config = MainConfig( 186 | users=[test_user], pools={"test_db": test_pool}, pgbouncer=pgbouncer_config 187 | ) 188 | 189 | # Convert to dictionary and then to TOML 190 | config_dict = main_config.model_dump(mode="json") 191 | 192 | # Helper function to recursively remove None values from a dict 193 | def remove_none_values(d: T) -> T: 194 | if not isinstance(d, dict): 195 | return d 196 | return {k: remove_none_values(v) for k, v in d.items() if v is not None} # type: ignore 197 | 198 | # Clean the dict before serializing 199 | clean_config_dict = remove_none_values(config_dict) 200 | config_toml = tomli_w.dumps(clean_config_dict) 201 | 202 | # Write to file 203 | config_path = temp_workspace / "test_config.toml" 204 | with open(config_path, "wb") as f: 205 | f.write(config_toml.encode()) 206 | 207 | return config_path 208 | 209 | 210 | def 
start_autopgpool_container( 211 | temp_workspace: Path, 212 | image_tag: str, 213 | config_path: Path, 214 | network_name: str, 215 | ) -> tuple[str, int]: 216 | """ 217 | Start an autopgpool container for testing. 218 | 219 | :param temp_workspace: Temporary directory containing a copy of the workspace 220 | :param image_tag: Docker image tag to run 221 | :param config_path: Path to the configuration file 222 | :param network_name: Docker network name to connect to 223 | :return: Container ID and mapped port 224 | """ 225 | # Use a random port to avoid conflicts 226 | pgbouncer_port = 6436 227 | container_name = f"autopgpool-{int(time())}" 228 | 229 | container_id = ( 230 | subprocess.check_output( 231 | [ 232 | "docker", 233 | "run", 234 | "-d", 235 | "--name", 236 | container_name, 237 | "--network", 238 | network_name, 239 | "-p", 240 | f"{pgbouncer_port}:6432", 241 | "-v", 242 | f"{config_path}:/etc/autopgpool/autopgpool.toml", 243 | image_tag, 244 | ], 245 | cwd=temp_workspace, 246 | ) 247 | .decode() 248 | .strip() 249 | ) 250 | 251 | return container_id, pgbouncer_port 252 | 253 | 254 | def wait_for_pgbouncer(container_id: str, timeout_seconds: int = 30) -> None: 255 | """ 256 | Wait for PgBouncer to be ready with a timeout. 
257 | 258 | :param container_id: Docker container ID 259 | :param timeout_seconds: Maximum time to wait in seconds 260 | """ 261 | start_time = time() 262 | 263 | # Wait for pgbouncer to start 264 | while True: 265 | if time() - start_time > timeout_seconds: 266 | raise TimeoutError(f"PgBouncer not ready after {timeout_seconds} seconds") 267 | 268 | try: 269 | # Check if pgbouncer process is running 270 | result = subprocess.run( 271 | ["docker", "exec", container_id, "ps", "aux"], 272 | check=True, 273 | capture_output=True, 274 | text=True, 275 | ) 276 | if "pgbouncer" in result.stdout and "/usr/bin/pgbouncer" in result.stdout: 277 | CONSOLE.print("PgBouncer is running") 278 | break 279 | except subprocess.CalledProcessError: 280 | pass 281 | 282 | sleep(1) 283 | 284 | # Give pgbouncer time to initialize 285 | sleep(5) 286 | 287 | 288 | def cleanup_container(container_id: str) -> None: 289 | """ 290 | Stop and remove a Docker container. 291 | 292 | :param container_id: Docker container ID 293 | """ 294 | subprocess.run(["docker", "stop", container_id], check=True) 295 | subprocess.run(["docker", "rm", container_id], check=True) 296 | 297 | 298 | @pytest.mark.integration 299 | def test_autopgpool_connection(temp_workspace: Path) -> None: 300 | """ 301 | Test that AutoPGPool correctly routes connections to PostgreSQL. 302 | 303 | This test: 304 | 1. Creates a Docker network 305 | 2. Starts a PostgreSQL container on the network 306 | 3. Builds and starts the AutoPGPool container on the same network 307 | 4. 
       Verifies that connections can be made through the pool

    :param temp_workspace: Temporary directory containing a copy of the workspace
    """
    # Track resources so the finally block can tear down whatever was created
    postgres_container_id: str | None = None
    pgbouncer_container_id: str | None = None
    network_name: str | None = None

    try:
        # Create Docker network
        network_name = create_docker_network()
        CONSOLE.print(f"Created Docker network: {network_name}")

        # Start PostgreSQL container
        postgres_container_id, postgres_container_name, _ = start_postgres_container(
            temp_workspace, network_name
        )
        wait_for_postgres(postgres_container_id)

        # Create test configuration - using the container name as hostname
        # (port 5432 is the in-network port, not the host-mapped one)
        config_path = create_test_config(temp_workspace, postgres_container_name, 5432)

        # Build and start AutoPGPool container
        autopgpool_tag = build_autopgpool_docker_image(temp_workspace)
        pgbouncer_container_id, pgbouncer_port = start_autopgpool_container(
            temp_workspace,
            autopgpool_tag,
            config_path,
            network_name,
        )
        wait_for_pgbouncer(pgbouncer_container_id)

        # Verify connection through pgbouncer
        conn = psycopg.connect(
            host="localhost",
            port=pgbouncer_port,
            user="test_user",
            password="test_password",
            dbname="test_db",
            connect_timeout=5,
        )

        try:
            with conn.cursor() as cur:
                # Simple query to verify connection
                cur.execute("SELECT 1 AS test")
                result = cur.fetchone()
                assert result is not None
                assert result[0] == 1
        finally:
            conn.close()

    except Exception as e:
        CONSOLE.print(f"Error: {e}")

        # Print logs from containers to aid debugging before re-raising
        if pgbouncer_container_id:
            CONSOLE.print("PgBouncer logs:")
            subprocess.run(["docker", "logs", pgbouncer_container_id], check=True)

        if postgres_container_id:
            CONSOLE.print("PostgreSQL logs:")
            subprocess.run(["docker", "logs", postgres_container_id], check=True)

        # NOTE(review): a bare `raise` would be the more idiomatic re-raise here
        raise e
    finally:
        # Clean up containers
        if pgbouncer_container_id:
            cleanup_container(pgbouncer_container_id)
        if postgres_container_id:
            cleanup_container(postgres_container_id)
        # Clean up network
        if network_name:
            remove_docker_network(network_name)
--------------------------------------------------------------------------------
/autopg/logic.py:
--------------------------------------------------------------------------------
"""
The MIT License (MIT)

Copyright (c) 2014 Alexey Vasiliev
Copyright (c) 2025 Pierce Freeman

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
24 | 25 | """ 26 | 27 | from math import ceil 28 | from typing import Callable 29 | 30 | from pydantic import BaseModel 31 | 32 | from autopg.constants import ( 33 | DB_TYPE_DESKTOP, 34 | DB_TYPE_DW, 35 | DB_TYPE_MIXED, 36 | DB_TYPE_OLTP, 37 | DB_TYPE_WEB, 38 | DEFAULT_DB_VERSION, 39 | HARD_DRIVE_HDD, 40 | HARD_DRIVE_SAN, 41 | HARD_DRIVE_SSD, 42 | OS_LINUX, 43 | OS_WINDOWS, 44 | PG_STAT_STATEMENTS_SQL, 45 | SIZE_UNIT_GB, 46 | SIZE_UNIT_MAP, 47 | ) 48 | 49 | 50 | class Configuration(BaseModel): 51 | db_version: float = DEFAULT_DB_VERSION 52 | os_type: str = OS_LINUX 53 | db_type: str = DB_TYPE_WEB 54 | total_memory: int | None = None 55 | total_memory_unit: str = SIZE_UNIT_GB 56 | cpu_num: int | None = None 57 | connection_num: int | None = None 58 | hd_type: str = HARD_DRIVE_SSD 59 | enable_pg_stat_statements: bool = True 60 | 61 | 62 | class PostgresConfig: 63 | def __init__(self, config: Configuration): 64 | self.state = config 65 | 66 | def get_total_memory_in_bytes(self) -> int | None: 67 | if self.state.total_memory is None: 68 | return None 69 | return self.state.total_memory * SIZE_UNIT_MAP[self.state.total_memory_unit] 70 | 71 | def get_total_memory_in_kb(self) -> float | None: 72 | memory_bytes = self.get_total_memory_in_bytes() 73 | if memory_bytes is None: 74 | return None 75 | return memory_bytes / SIZE_UNIT_MAP["KB"] 76 | 77 | def get_max_connections(self) -> int: 78 | if self.state.connection_num: 79 | return self.state.connection_num 80 | 81 | connection_map = { 82 | DB_TYPE_WEB: 200, 83 | DB_TYPE_OLTP: 300, 84 | DB_TYPE_DW: 40, 85 | DB_TYPE_DESKTOP: 20, 86 | DB_TYPE_MIXED: 100, 87 | } 88 | return connection_map[self.state.db_type] 89 | 90 | def get_huge_pages(self) -> str: 91 | memory_kb = self.get_total_memory_in_kb() 92 | if memory_kb is None: 93 | return "off" 94 | return "try" if memory_kb >= 33554432 else "off" 95 | 96 | def get_shared_buffers(self) -> int | None: 97 | memory_kb = self.get_total_memory_in_kb() 98 | if memory_kb is None: 99 | 
return None 100 | 101 | shared_buffers_map: dict[str, Callable[[float], float]] = { 102 | DB_TYPE_WEB: lambda x: x / 4, 103 | DB_TYPE_OLTP: lambda x: x / 4, 104 | DB_TYPE_DW: lambda x: x / 4, 105 | DB_TYPE_DESKTOP: lambda x: x / 16, 106 | DB_TYPE_MIXED: lambda x: x / 4, 107 | } 108 | 109 | value = shared_buffers_map[self.state.db_type](memory_kb) 110 | 111 | if self.state.db_version < 10 and self.state.os_type == OS_WINDOWS: 112 | win_memory_limit = (512 * SIZE_UNIT_MAP["MB"]) / SIZE_UNIT_MAP["KB"] 113 | if value > win_memory_limit: 114 | value = win_memory_limit 115 | 116 | return int(value) 117 | 118 | def get_effective_cache_size(self) -> int | None: 119 | memory_kb = self.get_total_memory_in_kb() 120 | if memory_kb is None: 121 | return None 122 | 123 | cache_map: dict[str, Callable[[float], float]] = { 124 | DB_TYPE_WEB: lambda x: (x * 3) / 4, 125 | DB_TYPE_OLTP: lambda x: (x * 3) / 4, 126 | DB_TYPE_DW: lambda x: (x * 3) / 4, 127 | DB_TYPE_DESKTOP: lambda x: x / 4, 128 | DB_TYPE_MIXED: lambda x: (x * 3) / 4, 129 | } 130 | return int(cache_map[self.state.db_type](memory_kb)) 131 | 132 | def get_maintenance_work_mem(self) -> int | None: 133 | memory_kb = self.get_total_memory_in_kb() 134 | if memory_kb is None: 135 | return None 136 | 137 | maintenance_map: dict[str, Callable[[float], float]] = { 138 | DB_TYPE_WEB: lambda x: x / 16, 139 | DB_TYPE_OLTP: lambda x: x / 16, 140 | DB_TYPE_DW: lambda x: x / 8, 141 | DB_TYPE_DESKTOP: lambda x: x / 16, 142 | DB_TYPE_MIXED: lambda x: x / 16, 143 | } 144 | 145 | value = maintenance_map[self.state.db_type](memory_kb) 146 | memory_limit = (2 * SIZE_UNIT_MAP["GB"]) / SIZE_UNIT_MAP["KB"] 147 | 148 | if value >= memory_limit: 149 | if self.state.os_type == OS_WINDOWS: 150 | # 2048MB (2 GB) will raise error at Windows, so we need remove 1 MB from it 151 | value = memory_limit - (1 * SIZE_UNIT_MAP["MB"]) / SIZE_UNIT_MAP["KB"] 152 | else: 153 | value = memory_limit 154 | 155 | return int(value) 156 | 157 | def 
get_checkpoint_segments(self) -> dict[str, str | float]: 158 | min_wal_size_map = { 159 | DB_TYPE_WEB: 1024 * SIZE_UNIT_MAP["MB"] / SIZE_UNIT_MAP["KB"], 160 | DB_TYPE_OLTP: 2048 * SIZE_UNIT_MAP["MB"] / SIZE_UNIT_MAP["KB"], 161 | DB_TYPE_DW: 4096 * SIZE_UNIT_MAP["MB"] / SIZE_UNIT_MAP["KB"], 162 | DB_TYPE_DESKTOP: 100 * SIZE_UNIT_MAP["MB"] / SIZE_UNIT_MAP["KB"], 163 | DB_TYPE_MIXED: 1024 * SIZE_UNIT_MAP["MB"] / SIZE_UNIT_MAP["KB"], 164 | } 165 | 166 | max_wal_size_map = { 167 | DB_TYPE_WEB: 4096 * SIZE_UNIT_MAP["MB"] / SIZE_UNIT_MAP["KB"], 168 | DB_TYPE_OLTP: 8192 * SIZE_UNIT_MAP["MB"] / SIZE_UNIT_MAP["KB"], 169 | DB_TYPE_DW: 16384 * SIZE_UNIT_MAP["MB"] / SIZE_UNIT_MAP["KB"], 170 | DB_TYPE_DESKTOP: 2048 * SIZE_UNIT_MAP["MB"] / SIZE_UNIT_MAP["KB"], 171 | DB_TYPE_MIXED: 4096 * SIZE_UNIT_MAP["MB"] / SIZE_UNIT_MAP["KB"], 172 | } 173 | 174 | return { 175 | "min_wal_size": min_wal_size_map[self.state.db_type], 176 | "max_wal_size": max_wal_size_map[self.state.db_type], 177 | } 178 | 179 | def get_checkpoint_completion_target(self) -> float: 180 | return 0.9 # based on https://github.com/postgres/postgres/commit/bbcc4eb2 181 | 182 | def get_wal_buffers(self) -> int | None: 183 | shared_buffers = self.get_shared_buffers() 184 | if shared_buffers is None: 185 | return None 186 | 187 | # Follow auto-tuning guideline for wal_buffers added in 9.1, where it's 188 | # set to 3% of shared_buffers up to a maximum of 16MB. 
189 | value = (3 * shared_buffers) // 100 190 | max_wal_buffer = int((16 * SIZE_UNIT_MAP["MB"]) / SIZE_UNIT_MAP["KB"]) 191 | 192 | if value > max_wal_buffer: 193 | value = max_wal_buffer 194 | 195 | # It's nice if wal_buffers is an even 16MB if it's near that number 196 | wal_buffer_near_value = int((14 * SIZE_UNIT_MAP["MB"]) / SIZE_UNIT_MAP["KB"]) 197 | if wal_buffer_near_value < value < max_wal_buffer: 198 | value = max_wal_buffer 199 | 200 | # if less than 32 kb, set it to minimum 201 | if value < 32: 202 | value = 32 203 | 204 | return int(value) 205 | 206 | def get_default_statistics_target(self) -> int: 207 | statistics_map = { 208 | DB_TYPE_WEB: 100, 209 | DB_TYPE_OLTP: 100, 210 | DB_TYPE_DW: 500, 211 | DB_TYPE_DESKTOP: 100, 212 | DB_TYPE_MIXED: 100, 213 | } 214 | return statistics_map[self.state.db_type] 215 | 216 | def get_random_page_cost(self) -> float: 217 | cost_map = {HARD_DRIVE_HDD: 4.0, HARD_DRIVE_SSD: 1.1, HARD_DRIVE_SAN: 1.1} 218 | return cost_map[self.state.hd_type] 219 | 220 | def get_effective_io_concurrency(self) -> int | None: 221 | if self.state.os_type != OS_LINUX: 222 | return None 223 | 224 | concurrency_map = {HARD_DRIVE_HDD: 2, HARD_DRIVE_SSD: 200, HARD_DRIVE_SAN: 300} 225 | return concurrency_map[self.state.hd_type] 226 | 227 | def get_parallel_settings(self) -> dict[str, str | int]: 228 | if not self.state.cpu_num or self.state.cpu_num < 4: 229 | return {} 230 | 231 | workers_per_gather = ceil(self.state.cpu_num / 2) 232 | 233 | if self.state.db_type != DB_TYPE_DW and workers_per_gather > 4: 234 | # no clear evidence, that each new worker will provide big benefit for each new core 235 | workers_per_gather = 4 236 | 237 | config: dict[str, str | int] = { 238 | "max_worker_processes": self.state.cpu_num, 239 | "max_parallel_workers_per_gather": workers_per_gather, 240 | } 241 | 242 | if self.state.db_version >= 10: 243 | config["max_parallel_workers"] = self.state.cpu_num 244 | 245 | if self.state.db_version >= 11: 246 | 
parallel_maintenance_workers = ceil(self.state.cpu_num / 2) 247 | if parallel_maintenance_workers > 4: 248 | parallel_maintenance_workers = 4 249 | 250 | config["max_parallel_maintenance_workers"] = parallel_maintenance_workers 251 | 252 | return config 253 | 254 | def get_work_mem(self) -> int | None: 255 | memory_kb = self.get_total_memory_in_kb() 256 | shared_buffers = self.get_shared_buffers() 257 | if memory_kb is None or shared_buffers is None: 258 | return None 259 | 260 | max_connections = self.get_max_connections() 261 | parallel_settings = self.get_parallel_settings() 262 | 263 | # Determine parallel workers 264 | parallel_workers = 1 265 | for key, value in parallel_settings.items(): 266 | if key == "max_parallel_workers_per_gather": 267 | if isinstance(value, int) and value > 0: 268 | parallel_workers = value 269 | break 270 | 271 | # Calculate work_mem 272 | work_mem = float(memory_kb - shared_buffers) / (max_connections * 3) / parallel_workers 273 | 274 | work_mem_map: dict[str, Callable[[float], float]] = { 275 | DB_TYPE_WEB: lambda x: x, 276 | DB_TYPE_OLTP: lambda x: x, 277 | DB_TYPE_DW: lambda x: x / 2, 278 | DB_TYPE_DESKTOP: lambda x: x / 6, 279 | DB_TYPE_MIXED: lambda x: x / 2, 280 | } 281 | 282 | value = int(work_mem_map[self.state.db_type](work_mem)) 283 | return max(64, value) # Minimum 64kb 284 | 285 | def get_warning_info_messages(self) -> list[str]: 286 | memory_bytes = self.get_total_memory_in_bytes() 287 | if memory_bytes is None: 288 | return [] 289 | 290 | if memory_bytes < 256 * SIZE_UNIT_MAP["MB"]: 291 | return ["WARNING", "this tool not being optimal", "for low memory systems"] 292 | if memory_bytes > 100 * SIZE_UNIT_MAP["GB"]: 293 | return ["WARNING", "this tool not being optimal", "for very high memory systems"] 294 | return [] 295 | 296 | def get_wal_level(self) -> dict[str, str]: 297 | if self.state.db_type == DB_TYPE_DESKTOP: 298 | return { 299 | "wal_level": "minimal", 300 | "max_wal_senders": "0", 301 | } 302 | return {} 303 | 
304 | def get_pg_stat_statements_config(self) -> dict[str, str | int]: 305 | """ 306 | Get pg_stat_statements extension configuration. 307 | Returns configuration for shared_preload_libraries and pg_stat_statements settings. 308 | """ 309 | if not self.state.enable_pg_stat_statements: 310 | return {} 311 | 312 | return { 313 | "shared_preload_libraries": "pg_stat_statements", 314 | "pg_stat_statements.track": "all", 315 | "pg_stat_statements.max": 10000, 316 | } 317 | 318 | def get_pg_stat_statements_sql(self) -> str: 319 | """ 320 | Get SQL initialization script for pg_stat_statements extension. 321 | Returns SQL commands to create and configure the extension. 322 | """ 323 | if not self.state.enable_pg_stat_statements: 324 | return "" 325 | 326 | return PG_STAT_STATEMENTS_SQL 327 | -------------------------------------------------------------------------------- /benchmarks/benchmarks/seqscan.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sequential scan benchmark for load testing database reads on unoptimized tables using asyncpg. 
"""

import asyncio
import random
import time
from typing import Any, Dict, List, Optional

from rich.console import Console
from rich.progress import (
    BarColumn,
    Progress,
    SpinnerColumn,
    TaskProgressColumn,
    TextColumn,
    TimeRemainingColumn,
)

from .database import AsyncConnectionPool, AsyncDatabaseConnection, timed_operation
from .utils import calculate_statistics, format_duration, format_number

# Shared console for all benchmark output in this module.
console = Console()


class AsyncSequentialScanBenchmark:
    """Async benchmark for testing sequential scan performance on unoptimized tables.

    Each supported table has a pool of queries written against non-indexed
    predicates (LIKE/ILIKE on cast columns, array overlap, aggregates) so that
    timings reflect full-table reads rather than index lookups.
    """

    def __init__(self, db_config: Dict[str, Any], verbose: bool = False):
        # db_config is expanded as AsyncDatabaseConnection(**db_config), so it
        # must supply: host, port, database, user, password.
        self.db_config = db_config
        self.verbose = verbose

        # Table-specific scan queries designed to force sequential scans
        self.table_queries: Dict[str, List[str]] = {
            "users": [
                "SELECT * FROM benchmark.users WHERE profile_data::text LIKE '%theme%'",
                "SELECT * FROM benchmark.users WHERE created_at > NOW() - INTERVAL '30 days'",
                "SELECT username, email FROM benchmark.users WHERE status != 'active'",
                "SELECT COUNT(*) FROM benchmark.users WHERE last_login IS NULL",
                "SELECT * FROM benchmark.users WHERE email LIKE '%@example.com'",
            ],
            "posts": [
                "SELECT * FROM benchmark.posts WHERE content ILIKE '%lorem%'",
                "SELECT title, view_count FROM benchmark.posts WHERE view_count > 100",
                "SELECT * FROM benchmark.posts WHERE tags && ARRAY['tech', 'news']",
                "SELECT COUNT(*) FROM benchmark.posts WHERE created_at > NOW() - INTERVAL '7 days'",
                "SELECT * FROM benchmark.posts WHERE metadata::text LIKE '%featured%'",
            ],
            "comments": [
                "SELECT * FROM benchmark.comments WHERE content ILIKE '%dolor%'",
                "SELECT * FROM benchmark.comments WHERE likes > 10",
                "SELECT COUNT(*) FROM benchmark.comments WHERE parent_id IS NOT NULL",
                "SELECT * FROM benchmark.comments WHERE created_at > NOW() - INTERVAL '1 day'",
                "SELECT user_id, COUNT(*) FROM benchmark.comments GROUP BY user_id HAVING COUNT(*) > 5",
            ],
            "events": [
                "SELECT * FROM benchmark.events WHERE event_type = 'login'",
                "SELECT * FROM benchmark.events WHERE event_data::text LIKE '%Chrome%'",
                "SELECT COUNT(*) FROM benchmark.events WHERE created_at > NOW() - INTERVAL '1 hour'",
                "SELECT event_type, COUNT(*) FROM benchmark.events GROUP BY event_type",
                "SELECT * FROM benchmark.events WHERE ip_address::text LIKE '192.168.%'",
            ],
        }

    async def run(
        self,
        table_name: str,
        iterations: int = 10,
        limit: Optional[int] = None,
        num_workers: int = 1,
    ) -> Dict[str, Any]:
        """Run the async sequential scan benchmark.

        Executes `iterations` randomly chosen scan queries against `table_name`,
        either on a single connection (num_workers == 1) or fanned out over a
        connection pool, and returns a flat dict of timing/throughput metrics.
        Raises ValueError for a table with no registered query set.
        """
        if table_name not in self.table_queries:
            raise ValueError(f"Unsupported table: {table_name}")

        queries = self.table_queries[table_name]

        console.print(
            f"[cyan]Starting async sequential scan benchmark for table '{table_name}'[/cyan]"
        )
        console.print(f"Iterations: {iterations}, Workers: {num_workers}, Limit: {limit or 'None'}")

        # Modify queries with LIMIT if specified.  The clause is appended
        # textually, which is valid for all registered queries (including the
        # GROUP BY ... HAVING ones).
        if limit:
            queries = [f"{query} LIMIT {limit}" for query in queries]

        # Get table info for context
        table_info = await self._get_table_info(table_name)
        console.print(
            f"Table size: {table_info['size']}, Rows: {format_number(table_info['row_count'])}"
        )

        # Run benchmark; timing["duration"] is filled in by timed_operation
        # when the context exits, so it is read after the block.
        async with timed_operation(
            f"Async sequential scan benchmark ({num_workers} workers)", self.verbose
        ) as timing:
            if num_workers == 1:
                iteration_times, total_rows = await self._run_single_connection(queries, iterations)
            else:
                iteration_times, total_rows = await self._run_multi_connection(
                    queries, iterations, num_workers
                )

        # Calculate results (guard against empty/zero denominators)
        total_duration = timing["duration"]
        total_iterations = len(iteration_times)
        avg_iteration_time = sum(iteration_times) / len(iteration_times) if iteration_times else 0
        records_per_second = total_rows / total_duration if total_duration > 0 else 0

        stats = calculate_statistics(iteration_times)

        results = {
            "table_name": table_name,
            "total_duration": total_duration,
            "iterations": total_iterations,
            "avg_iteration_time": avg_iteration_time,
            "records_processed": total_rows,
            "records_per_second": records_per_second,
            "num_workers": num_workers,
            "limit": limit,
            **{f"iteration_{k}": v for k, v in stats.items()},
        }

        # Add min/max/median for compatibility with CLI display
        if stats:
            results.update(
                {"min_time": stats["min"], "max_time": stats["max"], "median_time": stats["median"]}
            )

        return results

    async def _get_table_info(self, table_name: str) -> Dict[str, Any]:
        """Get information about the table being scanned.

        Uses a short-lived connection; falls back to placeholder values when
        the table is missing from the schema listing.
        """
        async with AsyncDatabaseConnection(**self.db_config) as db:
            table_info = await db.get_table_info()
            return table_info.get(table_name, {"size": "Unknown", "row_count": 0})

    async def _run_single_connection(
        self, queries: List[str], iterations: int
    ) -> tuple[List[float], int]:
        """Run sequential scan benchmark with a single connection.

        Returns (per-iteration wall-clock times, total rows fetched).
        """
        iteration_times = []
        total_rows = 0

        async with AsyncDatabaseConnection(**self.db_config) as db:
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                BarColumn(),
                TaskProgressColumn(),
                TimeRemainingColumn(),
                transient=False,
            ) as progress:
                task = progress.add_task("Running sequential scans...", total=iterations)

                for i in range(iterations):
                    # Randomly select a query for this iteration
                    query = random.choice(queries)

                    # Timing covers only the query execution, not query choice
                    # or progress rendering.
                    start_time = time.time()
                    result = await db.execute(query)

                    # Count rows processed
                    rows = len(result)

                    iteration_time = time.time() - start_time
                    iteration_times.append(iteration_time)
                    total_rows += rows

                    if self.verbose:
                        console.print(
                            f"Iteration {i + 1}: {format_duration(iteration_time)}, {format_number(rows)} rows"
                        )

                    progress.update(task, completed=i + 1)

        return iteration_times, total_rows

    async def _run_multi_connection(
        self, queries: List[str], iterations: int, num_workers: int
    ) -> tuple[List[float], int]:
        """Run sequential scan benchmark with multiple connections.

        Fans the iterations out over a pool of num_workers connections; a
        semaphore bounds in-flight queries to the pool size so acquires never
        queue more deeply than the pool can serve.  Results are collected in
        completion order, so iteration_times is not in submission order.
        """
        iteration_times = []
        total_rows = 0
        completed_iterations = 0

        async with AsyncConnectionPool(self.db_config, num_workers) as pool:
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                BarColumn(),
                TaskProgressColumn(),
                TimeRemainingColumn(),
                transient=False,
            ) as progress:
                task = progress.add_task("Running sequential scans...", total=iterations)

                # Create semaphore to limit concurrent operations
                semaphore = asyncio.Semaphore(num_workers)

                async def execute_scan_with_semaphore(
                    query: str, iteration_num: int
                ) -> tuple[float, int]:
                    async with semaphore:
                        return await self._execute_scan(pool, query, iteration_num)

                # Submit all iteration jobs (coroutines; as_completed wraps
                # them into tasks and schedules them)
                tasks = [
                    execute_scan_with_semaphore(random.choice(queries), i)
                    for i in range(iterations)
                ]

                # Collect results as they complete
                for coro in asyncio.as_completed(tasks):
                    iteration_time, rows = await coro
                    iteration_times.append(iteration_time)
                    total_rows += rows
                    completed_iterations += 1

                    if self.verbose:
                        console.print(
                            f"Iteration {completed_iterations}: {format_duration(iteration_time)}, {format_number(rows)} rows"
                        )

                    progress.update(task, completed=completed_iterations)

        return iteration_times, total_rows

    async def _execute_scan(
        self, pool: AsyncConnectionPool, query: str, iteration_num: int
    ) -> tuple[float, int]:
        """Execute a single sequential scan.

        Timing includes the pool acquire, so contention shows up in the
        per-iteration numbers.  Returns (elapsed seconds, rows fetched).
        """
        start_time = time.time()

        async with pool.acquire() as conn:
            result = await conn.fetch(query)  # type: ignore[no-untyped-call]
            rows = len(result)

        iteration_time = time.time() - start_time
        return iteration_time, rows

    async def run_explain_analyze(
        self, table_name: str, sample_queries: int = 3
    ) -> List[Dict[str, Any]]:
        """Run EXPLAIN ANALYZE on sample queries to show execution plans.

        Picks up to sample_queries distinct queries at random for table_name
        and returns one analysis dict per query (or {"query", "error"} when
        EXPLAIN fails).  Raises ValueError for an unsupported table.
        """
        if table_name not in self.table_queries:
            raise ValueError(f"Unsupported table: {table_name}")

        queries = self.table_queries[table_name]
        sample_queries = min(sample_queries, len(queries))
        selected_queries = random.sample(queries, sample_queries)

        results = []

        async with AsyncDatabaseConnection(**self.db_config) as db:
            for i, query in enumerate(selected_queries):
                console.print(f"\n[yellow]Query {i + 1}: {query}[/yellow]")

                explain_query = f"EXPLAIN (ANALYZE, BUFFERS, FORMAT JSON) {query}"

                try:
                    # NOTE(review): assumes FORMAT JSON yields a one-element
                    # list in the first column of the first row — confirm
                    # against asyncpg's json decoding in this setup.
                    result = await db.execute_one(explain_query)  # type: ignore[no-untyped-call]
                    explain_data = result[0][0] if result else {}  # type: ignore[index]

                    # Extract key metrics
                    plan = explain_data.get("Plan", {})
                    execution_time = explain_data.get("Execution Time", 0)
                    planning_time = explain_data.get("Planning Time", 0)

                    analysis = {
                        "query": query,
                        "execution_time_ms": execution_time,
                        "planning_time_ms": planning_time,
                        "node_type": plan.get("Node Type", "Unknown"),
                        "total_cost": plan.get("Total Cost", 0),
                        "rows": plan.get("Actual Rows", 0),
                        "shared_hit_blocks": plan.get("Shared Hit Blocks", 0),
                        "shared_read_blocks": plan.get("Shared Read Blocks", 0),
                        "full_explain": explain_data,
                    }

                    results.append(analysis)

                    # Display summary
                    console.print(f"  Execution Time: {execution_time:.2f}ms")
                    console.print(f"  Planning Time: {planning_time:.2f}ms")
                    console.print(f"  Node Type: {plan.get('Node Type', 'Unknown')}")
                    console.print(f"  Rows: {format_number(plan.get('Actual Rows', 0))}")

                except Exception as e:
                    # Best-effort: report the failure but keep analyzing the
                    # remaining sample queries.
                    console.print(f"  Error running EXPLAIN: {e}", style="red")
                    results.append({"query": query, "error": str(e)})

        return results


# Synchronous wrapper for backward compatibility
class SequentialScanBenchmark:
    """Synchronous wrapper around AsyncSequentialScanBenchmark.

    Each call spins up a fresh event loop via asyncio.run, so this wrapper
    must not be used from within an already-running loop.
    """

    def __init__(self, db_config: Dict[str, Any], verbose: bool = False):
        self.async_benchmark = AsyncSequentialScanBenchmark(db_config, verbose)

    def run(
        self,
        table_name: str,
        iterations: int = 10,
        limit: Optional[int] = None,
        num_workers: int = 1,
    ) -> Dict[str, Any]:
        """Run the sequential scan benchmark synchronously."""
        return asyncio.run(self.async_benchmark.run(table_name, iterations, limit, num_workers))

    def run_explain_analyze(self, table_name: str, sample_queries: int = 3) -> List[Dict[str, Any]]:
        """Run EXPLAIN ANALYZE synchronously."""
        return asyncio.run(self.async_benchmark.run_explain_analyze(table_name, sample_queries))
"""

import asyncio
import time
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, List, Optional

import asyncpg
from rich.console import Console

# Shared console for warnings emitted by this module.
console = Console()


class AsyncDatabaseConnection:
    """Async database connection wrapper with benchmarking utilities.

    Holds a single asyncpg connection for the lifetime of the async context;
    all query helpers raise RuntimeError when used outside the context.
    """

    def __init__(self, host: str, port: int, database: str, user: str, password: str):
        self.host = host
        self.port = port
        self.database = database
        self.user = user
        self.password = password
        # Populated in __aenter__, cleared in __aexit__.
        self.connection: Optional[asyncpg.Connection] = None

    async def __aenter__(self) -> "AsyncDatabaseConnection":
        """Enter async context manager and establish connection."""
        self.connection = await asyncpg.connect(
            host=self.host,
            port=self.port,
            database=self.database,
            user=self.user,
            password=self.password,
        )
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Exit async context manager and close connection."""
        if self.connection:
            await self.connection.close()
            self.connection = None

    async def execute(self, query: str, *params: Any) -> List[asyncpg.Record]:
        """Execute a query and return all results."""
        if not self.connection:
            raise RuntimeError("Database connection not established")

        return await self.connection.fetch(query, *params)

    async def execute_one(self, query: str, *params: Any) -> Optional[asyncpg.Record]:
        """Execute a query and return one result."""
        if not self.connection:
            raise RuntimeError("Database connection not established")

        return await self.connection.fetchrow(query, *params)

    async def execute_many(self, query: str, params_list: List[tuple]) -> None:
        """Execute a query multiple times with different parameters."""
        if not self.connection:
            raise RuntimeError("Database connection not established")

        await self.connection.executemany(query, params_list)

    async def execute_batch(self, query: str, params_list: List[tuple]) -> None:
        """Execute a batch of queries efficiently.

        NOTE(review): currently identical to execute_many — no conversion is
        performed despite the comment below; consider consolidating.
        """
        if not self.connection:
            raise RuntimeError("Database connection not established")

        # Convert to format expected by asyncpg
        await self.connection.executemany(query, params_list)

    @asynccontextmanager
    async def transaction(self):
        """Async context manager for database transactions."""
        if not self.connection:
            raise RuntimeError("Database connection not established")

        async with self.connection.transaction():
            yield

    async def get_table_info(self, schema: str = "benchmark") -> Dict[str, Dict[str, Any]]:
        """Get information about tables in the specified schema.

        Returns a mapping of table name -> {size, size_bytes, row_count,
        comment, columns}, ordered largest-first by total relation size.
        Runs one COUNT(*) per table, so this is expensive on large schemas.
        """
        query = """
            SELECT
                t.table_name,
                pg_size_pretty(pg_total_relation_size(quote_ident(t.table_schema)||'.'||quote_ident(t.table_name))) as size,
                pg_total_relation_size(quote_ident(t.table_schema)||'.'||quote_ident(t.table_name)) as size_bytes,
                obj_description(c.oid) as comment
            FROM information_schema.tables t
            LEFT JOIN pg_class c ON c.relname = t.table_name
            WHERE t.table_schema = $1
            AND t.table_type = 'BASE TABLE'
            ORDER BY pg_total_relation_size(quote_ident(t.table_schema)||'.'||quote_ident(t.table_name)) DESC
        """

        result = await self.execute(query, schema)

        tables = {}
        for row in result:
            table_name = row["table_name"]

            # Get row count
            # NOTE(review): schema/table names are interpolated directly into
            # SQL here (identifiers cannot be bound parameters). Safe only for
            # trusted benchmark config — confirm callers never pass user input.
            count_query = f"SELECT COUNT(*) as count FROM {schema}.{table_name}"
            count_result = await self.execute_one(count_query)

            # Get column info
            col_query = """
                SELECT column_name, data_type, is_nullable
                FROM information_schema.columns
                WHERE table_schema = $1 AND table_name = $2
                ORDER BY ordinal_position
            """
            columns = await self.execute(col_query, schema, table_name)

            tables[table_name] = {
                "size": row["size"],
                "size_bytes": row["size_bytes"],
                "row_count": count_result["count"] if count_result else 0,
                "comment": row["comment"],
                "columns": [dict(col) for col in columns],
            }

        return tables

    async def analyze_table(self, table_name: str, schema: str = "benchmark") -> None:
        """Run ANALYZE on a table to update statistics."""
        # Identifier interpolation — trusted input only (see get_table_info note).
        query = f"ANALYZE {schema}.{table_name}"
        if self.connection:
            await self.connection.execute(query)

    async def get_query_stats(self) -> List[Dict[str, Any]]:
        """Get query statistics from pg_stat_statements.

        Returns an empty list when the extension is unavailable.
        """
        query = """
            SELECT
                query,
                calls,
                total_exec_time,
                mean_exec_time,
                max_exec_time,
                min_exec_time,
                rows
            FROM pg_stat_statements
            WHERE query LIKE '%benchmark%'
            ORDER BY total_exec_time DESC
            LIMIT 20
        """

        try:
            result = await self.execute(query)
            return [dict(row) for row in result]
        except Exception:
            # pg_stat_statements might not be available
            return []

    async def reset_stats(self) -> None:
        """Reset PostgreSQL statistics.

        Best-effort: failures (e.g. missing extension or permissions) are
        reported as a warning rather than raised.
        """
        try:
            if self.connection:
                await self.connection.execute("SELECT pg_stat_reset()")
                await self.connection.execute("SELECT pg_stat_statements_reset()")
        except Exception as e:
            console.print(f"Warning: Could not reset stats: {e}", style="yellow")

    async def vacuum_analyze_table(self, table_name: str, schema: str = "benchmark") -> None:
        """Run VACUUM ANALYZE on a table."""
        # VACUUM cannot be run inside a transaction, so we need a separate connection
        vacuum_conn = await asyncpg.connect(
            host=self.host,
            port=self.port,
            database=self.database,
            user=self.user,
            password=self.password,
        )

        try:
            query = f"VACUUM ANALYZE {schema}.{table_name}"
            await vacuum_conn.execute(query)
        finally:
            await vacuum_conn.close()


@asynccontextmanager
async def timed_operation(
    description: str, verbose: bool = False
) -> AsyncGenerator[Dict[str, float], None]:
    """Async context manager to time database operations.

    Yields a dict that is populated with "duration", "start_time" and
    "end_time" when the context exits (including on exception), so callers
    must read it after the with-block.
    """
    if verbose:
        console.print(f"⏱️  Starting: {description}")

    start_time = time.time()
    timing_info = {}

    try:
        yield timing_info
    finally:
        end_time = time.time()
        duration = end_time - start_time
        timing_info["duration"] = duration
        timing_info["start_time"] = start_time
        timing_info["end_time"] = end_time

        if verbose:
            console.print(f"✅ Completed: {description} ({duration:.2f}s)")


class AsyncConnectionPool:
    """Async connection pool wrapper for concurrent operations.

    Thin wrapper over asyncpg.create_pool with min_size=1 and
    max_size=pool_size; the pool only exists inside the async context.
    """

    def __init__(self, db_config: Dict[str, Any], pool_size: int = 5):
        # db_config must contain host, port, database, user, password.
        self.db_config = db_config
        self.pool_size = pool_size
        self.pool: Optional[asyncpg.Pool] = None

    async def __aenter__(self) -> "AsyncConnectionPool":
        """Create and return the connection pool."""
        self.pool = await asyncpg.create_pool(
            host=self.db_config["host"],
            port=self.db_config["port"],
            database=self.db_config["database"],
            user=self.db_config["user"],
            password=self.db_config["password"],
            min_size=1,
            max_size=self.pool_size,
        )
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Close the connection pool."""
        if self.pool:
            await self.pool.close()
            self.pool = None

    @asynccontextmanager
    async def acquire(self) -> AsyncGenerator[asyncpg.Connection, None]:
        """Acquire a connection from the pool.

        Raises RuntimeError when used outside the pool's async context.
        """
        if not self.pool:
            raise RuntimeError("Connection pool not initialized")

        async with self.pool.acquire() as connection:
            yield connection


def get_connection_pool(db_config: Dict[str, Any], pool_size: int = 5) -> AsyncConnectionPool:
    """Create an async connection pool for concurrent operations."""
    return AsyncConnectionPool(db_config, pool_size)


# Legacy sync interface for compatibility (will be removed after refactoring)
class DatabaseConnection:
    """Synchronous wrapper around async database connection for backward compatibility.

    Creates a private event loop per connection; note that __enter__ also
    installs it as the current loop via asyncio.set_event_loop, which is a
    process-global side effect.
    """

    def __init__(self, host: str, port: int, database: str, user: str, password: str):
        self.async_conn = AsyncDatabaseConnection(host, port, database, user, password)
        # Event loop owned by this wrapper; created in __enter__, closed in __exit__.
        self._loop = None

    def __enter__(self) -> "DatabaseConnection":
        """Enter context manager and establish connection."""
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)
        self._loop.run_until_complete(self.async_conn.__aenter__())
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Exit context manager and close connection."""
        if self._loop:
            self._loop.run_until_complete(self.async_conn.__aexit__(exc_type, exc_val, exc_tb))
            self._loop.close()
            self._loop = None

    def execute(self, query: str, params: Optional[tuple] = None):
        """Execute a query and return a cursor-like object."""
        if not self._loop:
            raise RuntimeError("Connection not established")

        # Convert params to individual arguments for asyncpg
        args = params if params else ()
        result = self._loop.run_until_complete(self.async_conn.execute(query, *args))

        # Return a cursor-like object for compatibility
        return SyncCursor(result)

    def execute_many(self, query: str, params_list: list) -> None:
        """Execute a query multiple times with different parameters."""
        if not self._loop:
            raise RuntimeError("Connection not established")

        self._loop.run_until_complete(self.async_conn.execute_many(query, params_list))

    def commit(self) -> None:
        """Commit is handled automatically by asyncpg."""
        pass

    def rollback(self) -> None:
        """Rollback is handled by transaction context managers."""
        pass

    def transaction(self):
        """Return a transaction context manager."""
        if not self._loop:
            raise RuntimeError("Connection not established")
        return SyncTransaction(self.async_conn, self._loop)

    def get_table_info(self, schema: str = "benchmark") -> Dict[str, Dict[str, Any]]:
        """Get information about tables in the specified schema."""
        if not self._loop:
            raise RuntimeError("Connection not established")

        return self._loop.run_until_complete(self.async_conn.get_table_info(schema))


class SyncCursor:
    """Cursor-like object for backward compatibility with psycopg.

    Wraps an already-materialized list of records; fetchone advances an
    internal index while fetchall and iteration always cover all records.
    """

    def __init__(self, records: List[asyncpg.Record]):
        self.records = records
        # Read position for fetchone (fetchall/__iter__ ignore it).
        self._index = 0

    def fetchall(self) -> List[Dict[str, Any]]:
        """Fetch all records as dictionaries."""
        return [dict(record) for record in self.records]

    def fetchone(self) -> Optional[Dict[str, Any]]:
        """Fetch one record as a dictionary."""
        if self._index < len(self.records):
            record = dict(self.records[self._index])
            self._index += 1
            return record
        return None

    def __iter__(self):
        """Make cursor iterable."""
        return iter(dict(record) for record in self.records)


class SyncTransaction:
    """Transaction context manager for backward compatibility."""

    def __init__(self, async_conn: AsyncDatabaseConnection, loop: asyncio.AbstractEventLoop):
        self.async_conn = async_conn
        self.loop = loop
        self._transaction = None  # active asyncpg transaction, None when closed

    def __enter__(self):
        """Start a transaction on the wrapped async connection."""
        if not self.async_conn.connection:
            raise RuntimeError("Connection not established")

        # Start transaction
        self._transaction = self.async_conn.connection.transaction()
        self.loop.run_until_complete(self._transaction.__aenter__())
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
        """End the transaction by delegating to the async transaction's __aexit__."""
        if self._transaction:
            self.loop.run_until_complete(self._transaction.__aexit__(exc_type, exc_val, exc_tb))
            self._transaction = None
--------------------------------------------------------------------------------
/benchmarks/benchmarks/insertion.py:
--------------------------------------------------------------------------------
"""
Insertion benchmark for load testing database writes on unoptimized tables using asyncpg.
"""

import asyncio
import json
import random
import time
from datetime import datetime, timedelta
from typing import Any, Dict, List, Tuple
from uuid import uuid4

from rich.console import Console
from rich.progress import (
    BarColumn,
    Progress,
    SpinnerColumn,
    TaskProgressColumn,
    TextColumn,
    TimeRemainingColumn,
)

from .database import AsyncConnectionPool, AsyncDatabaseConnection, timed_operation
from .utils import (
    calculate_statistics,
    chunks,
    format_duration,
    format_number,
    generate_random_email,
    generate_random_string,
    generate_random_text,
)

# Module-level rich console used for all benchmark output.
console = Console()


class AsyncInsertionBenchmark:
    """Async benchmark for testing insertion performance on unoptimized tables."""

    def __init__(self, db_config: Dict[str, Any], verbose: bool = False):
        # db_config is passed straight through to AsyncDatabaseConnection /
        # AsyncConnectionPool (host, port, database, user, password).
        self.db_config = db_config
        self.verbose = verbose

        # Table-specific insert queries and data generators
        # Maps table name -> {"query": parameterized INSERT, "generator":
        # callable producing one parameter tuple per row}.
        self.table_configs = {
            "users": {
                "query": """
                    INSERT INTO benchmark.users (username, email, last_login, status, profile_data)
                    VALUES ($1, $2, $3, $4, $5)
                """,
                "generator": self._generate_user_data,
            },
            "posts": {
                "query": """
                    INSERT INTO benchmark.posts (user_id, title, content, updated_at, view_count, tags, metadata)
                    VALUES ($1, $2, $3, $4, $5, $6, $7)
                """,
                "generator": self._generate_post_data,
            },
            "comments": {
                "query": """
                    INSERT INTO benchmark.comments (post_id, user_id, content, parent_id, likes)
                    VALUES ($1, $2, $3, $4, $5)
                """,
                "generator": self._generate_comment_data,
            },
            "events": {
                "query": """
                    INSERT INTO benchmark.events (user_id, event_type, event_data, session_id, ip_address)
                    VALUES ($1, $2, $3, $4, $5)
                """,
                "generator": self._generate_event_data,
            },
        }

    async def run(
        self, table_name: str, num_records: int, batch_size: int = 1000, num_workers: int = 1
    ) -> Dict[str, Any]:
        """Run the async insertion benchmark.

        Generates all test rows up front, splits them into batches, inserts
        them (single connection or a worker pool), and returns a dict of
        timing/throughput metrics.

        Raises:
            ValueError: if table_name is not one of the configured tables.
        """
        if table_name not in self.table_configs:
            raise ValueError(f"Unsupported table: {table_name}")

        config = self.table_configs[table_name]

        console.print(f"[cyan]Starting async insertion benchmark for table '{table_name}'[/cyan]")
        console.print(
            f"Records: {format_number(num_records)}, Batch size: {format_number(batch_size)}, Workers: {num_workers}"
        )

        # Get user/post IDs for foreign key references
        reference_data = await self._get_reference_data()

        # Generate all data upfront so data generation cost is excluded from
        # the measured insertion time.
        console.print("[yellow]Generating test data...[/yellow]")
        async with timed_operation("Data generation", self.verbose) as timing:
            all_data = self._generate_batch_data(config["generator"], num_records, reference_data)

        console.print(
            f"✅ Generated {format_number(len(all_data))} records in {format_duration(timing['duration'])}"
        )

        # Split into batches
        batches = list(chunks(all_data, batch_size))
        console.print(f"[yellow]Split into {len(batches)} batches[/yellow]")

        # Run benchmark
        async with timed_operation(
            f"Async insertion benchmark ({num_workers} workers)", self.verbose
        ) as timing:
            if num_workers == 1:
                batch_times = await self._run_single_connection(config["query"], batches)
            else:
                batch_times = await self._run_multi_connection(
                    config["query"], batches, num_workers
                )

        # Calculate results (guard against a zero-duration run)
        total_duration = timing["duration"]
        records_processed = len(all_data)
        records_per_second = records_processed / total_duration if total_duration > 0 else 0

        stats = calculate_statistics(batch_times)

        results = {
            "table_name": table_name,
            "total_duration": total_duration,
            "records_processed": records_processed,
            "records_per_second": records_per_second,
            "batches_processed": len(batches),
            "avg_batch_time": sum(batch_times) / len(batch_times) if batch_times else 0,
            "batch_size": batch_size,
            "num_workers": num_workers,
            # Per-batch statistics, prefixed to avoid key collisions.
            **{f"batch_{k}": v for k, v in stats.items()},
        }

        # Add min/max/median for compatibility with CLI display
        if stats:
            results.update(
                {"min_time": stats["min"], "max_time": stats["max"], "median_time": stats["median"]}
            )

        return results

    async def _get_reference_data(self) -> Dict[str, List[int]]:
        """Get existing IDs for foreign key references.

        Best-effort: returns empty lists when the tables cannot be queried;
        generators then fall back to random IDs.
        """
        reference_data = {"user_ids": [], "post_ids": []}

        try:
            async with AsyncDatabaseConnection(**self.db_config) as db:
                # Get user IDs (capped at 1000 to bound memory)
                user_result = await db.execute(
                    "SELECT id FROM benchmark.users ORDER BY id LIMIT 1000"
                )
                reference_data["user_ids"] = [row["id"] for row in user_result]

                # Get post IDs
                post_result = await db.execute(
                    "SELECT id FROM benchmark.posts ORDER BY id LIMIT 1000"
                )
                reference_data["post_ids"] = [row["id"] for row in post_result]

        except Exception as e:
            if self.verbose:
                console.print(f"Warning: Could not fetch reference data: {e}", style="yellow")

        return reference_data

    def _generate_batch_data(
        self, generator_func, num_records: int, reference_data: Dict[str, List[int]]
    ) -> List[Tuple]:
        """Generate all data for the benchmark: one parameter tuple per record."""
        return [generator_func(reference_data) for _ in range(num_records)]

    def _generate_user_data(self, reference_data: Dict[str, List[int]]) -> Tuple:
        """Generate data for users table.

        Returns (username, email, last_login, status, profile_data) matching
        the users INSERT in table_configs.
        """
        username = generate_random_string(12)
        email = generate_random_email()
        last_login = datetime.now() - timedelta(days=random.randint(0, 365))
        status = random.choice(["active", "inactive", "pending", "suspended"])
        profile_data = json.dumps(
            {
                "age": random.randint(18, 80),
                "location": random.choice(["US", "UK", "CA", "AU", "DE", "FR"]),
                "preferences": {
                    "theme": random.choice(["light", "dark"]),
                    "notifications": random.choice([True, False]),
                },
            }
        )

        return (username, email, last_login, status, profile_data)

    def _generate_post_data(self, reference_data: Dict[str, List[int]]) -> Tuple:
        """Generate data for posts table.

        Returns (user_id, title, content, updated_at, view_count, tags,
        metadata) matching the posts INSERT in table_configs.
        """
        # Prefer a real user id when reference data exists; otherwise pick a
        # random id (may violate FKs if the table is empty — best effort).
        user_id = (
            random.choice(reference_data["user_ids"])
            if reference_data["user_ids"]
            else random.randint(1, 1000)
        )
        title = generate_random_text(3, 8).replace(".", "")  # Remove trailing period for titles
        content = generate_random_text(20, 200)
        updated_at = datetime.now() - timedelta(days=random.randint(0, 30))
        view_count = random.randint(0, 10000)
        tags = [generate_random_string(6) for _ in range(random.randint(1, 5))]
        metadata = json.dumps(
            {
                "category": random.choice(["tech", "news", "sports", "entertainment", "science"]),
                "featured": random.choice([True, False]),
                "word_count": len(content.split()),
            }
        )

        return (user_id, title, content, updated_at, view_count, tags, metadata)

    def _generate_comment_data(self, reference_data: Dict[str, List[int]]) -> Tuple:
        """Generate data for comments table.

        Returns (post_id, user_id, content, parent_id, likes) matching the
        comments INSERT in table_configs.
        """
        post_id = (
            random.choice(reference_data["post_ids"])
            if reference_data["post_ids"]
            else random.randint(1, 1000)
        )
        user_id = (
            random.choice(reference_data["user_ids"])
            if reference_data["user_ids"]
            else random.randint(1, 1000)
        )
        content = generate_random_text(5, 50)
        # Don't set parent_id for now to avoid foreign key violations during initial data load
        parent_id = None
        likes = random.randint(0, 100)

        return (post_id, user_id, content, parent_id, likes)

    def _generate_event_data(self, reference_data: Dict[str, List[int]]) -> Tuple:
        """Generate data for events table.

        Returns (user_id, event_type, event_data, session_id, ip_address)
        matching the events INSERT in table_configs.
        """
        user_id = (
            random.choice(reference_data["user_ids"])
            if reference_data["user_ids"]
            else random.randint(1, 1000)
        )
        event_type = random.choice(
            ["login", "logout", "view_post", "create_post", "like", "comment", "share"]
        )
        event_data = json.dumps(
            {
                "timestamp": datetime.now().isoformat(),
                "user_agent": random.choice(["Chrome", "Firefox", "Safari", "Edge"]),
                "referrer": random.choice(["google.com", "facebook.com", "twitter.com", "direct"]),
                "page": f"/page/{random.randint(1, 1000)}",
            }
        )
        session_id = str(uuid4())
        # Random dotted-quad; octets are 1-255 (0 never produced).
        ip_address = f"{random.randint(1, 255)}.{random.randint(1, 255)}.{random.randint(1, 255)}.{random.randint(1, 255)}"

        return (user_id, event_type, event_data, session_id, ip_address)

    async def _run_single_connection(self, query: str, batches: List[List[Tuple]]) -> List[float]:
        """Run insertion benchmark with a single connection.

        Each batch is inserted in its own transaction; per-batch wall-clock
        times are returned for later statistics.
        """
        batch_times = []

        async with AsyncDatabaseConnection(**self.db_config) as db:
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                BarColumn(),
                TaskProgressColumn(),
                TimeRemainingColumn(),
                transient=False,
            ) as progress:
                task = progress.add_task("Inserting batches...", total=len(batches))

                for i, batch in enumerate(batches):
                    start_time = time.time()

                    async with db.transaction():
                        await db.execute_many(query, batch)

                    batch_time = time.time() - start_time
                    batch_times.append(batch_time)

                    progress.update(task, completed=i + 1)

        return batch_times

    async def _run_multi_connection(
        self, query: str, batches: List[List[Tuple]], num_workers: int
    ) -> List[float]:
        """Run insertion benchmark with multiple connections.

        Launches one task per batch, with a semaphore bounding concurrency to
        num_workers. Note: any batch failure aborts the whole run (see the
        re-raise below); remaining tasks are cancelled.
        """
        batch_times = []
        completed_batches = 0

        async with AsyncConnectionPool(self.db_config, num_workers) as pool:
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                BarColumn(),
                TaskProgressColumn(),
                TimeRemainingColumn(),
                transient=False,
            ) as progress:
                task = progress.add_task("Inserting batches...", total=len(batches))

                # Create semaphore to limit concurrent operations
                semaphore = asyncio.Semaphore(num_workers)

                async def execute_batch_with_semaphore(batch: List[Tuple], batch_idx: int) -> float:
                    async with semaphore:
                        try:
                            return await self._execute_batch(pool, query, batch, batch_idx)
                        except Exception as e:
                            console.print(f"Error in batch {batch_idx}: {e}", style="red")
                            raise

                # Submit all batch jobs
                tasks = [
                    asyncio.create_task(execute_batch_with_semaphore(batch, i))
                    for i, batch in enumerate(batches)
                ]

                # Collect results as they complete.
                try:
                    for task_obj in asyncio.as_completed(tasks):
                        try:
                            batch_time = await task_obj
                            batch_times.append(batch_time)
                            completed_batches += 1
                            progress.update(task, completed=completed_batches)
                        except Exception as e:
                            # Log the failure and advance the progress bar,
                            # then re-raise: the first failed batch stops the
                            # run (the outer handler cancels the rest).
                            console.print(f"Batch failed: {e}", style="red")
                            completed_batches += 1
                            progress.update(task, completed=completed_batches)
                            raise
                except Exception:
                    # Cancel remaining tasks
                    for task_obj in tasks:
                        if not task_obj.done():
                            task_obj.cancel()
                    raise

        return batch_times

    async def _execute_batch(
        self, pool: AsyncConnectionPool, query: str, batch: List[Tuple], batch_idx: int
    ) -> float:
        """Execute a single batch of insertions; returns elapsed seconds."""
        start_time = time.time()

        async with pool.acquire() as conn:
            async with conn.transaction():
                await conn.executemany(query, batch)

        return time.time() - start_time


# Synchronous wrapper for backward compatibility
class InsertionBenchmark:
    """Synchronous wrapper around AsyncInsertionBenchmark."""

    def __init__(self, db_config: Dict[str, Any], verbose: bool = False):
        self.async_benchmark = AsyncInsertionBenchmark(db_config, verbose)

    def run(
        self, table_name: str, num_records: int, batch_size: int = 1000, num_workers: int = 1
    ) -> Dict[str, Any]:
        """Run the insertion benchmark synchronously (spins up its own event loop)."""
        return asyncio.run(
            self.async_benchmark.run(table_name, num_records, batch_size, num_workers)
        )