├── generic_connection_pool ├── py.typed ├── contrib │ ├── __init__.py │ ├── psycopg2.py │ ├── unix.py │ ├── asyncpg.py │ ├── unix_async.py │ ├── socket_async.py │ └── socket.py ├── __init__.py ├── threading │ ├── __init__.py │ ├── locks.py │ └── pool.py ├── asyncio │ ├── __init__.py │ ├── utils.py │ └── locks.py ├── exceptions.py ├── rankmap.py └── common.py ├── pytest.ini ├── docs ├── source │ ├── pages │ │ ├── changelog.rst │ │ ├── installation.rst │ │ ├── api.rst │ │ ├── contribute.rst │ │ └── quickstart.rst │ ├── index.rst │ └── conf.py ├── Makefile └── make.bat ├── .flake8 ├── .readthedocs.yaml ├── .github └── workflows │ ├── release.yml │ └── test.yml ├── tests ├── conftest.py ├── resources │ ├── ssl.cert │ └── ssl.key ├── contrib │ ├── test_unix.py │ ├── test_unix_async.py │ ├── test_async_socket_manager.py │ └── test_socket_manager.py ├── test_rankmap.py ├── test_sync_locks.py ├── test_async_locks.py ├── test_sync_pool.py └── test_async_pool.py ├── CHANGELOG.rst ├── examples ├── tcp_pool.py ├── pg_pool.py ├── ssl_pool.py ├── manager_hooks.py ├── async_pg_pool.py ├── async_pool.py └── custom_manager.py ├── LICENSE ├── .pre-commit-config.yaml ├── pyproject.toml ├── .gitignore ├── README.rst └── static └── connection-pool.drawio /generic_connection_pool/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /generic_connection_pool/contrib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | asyncio_mode=auto 3 | -------------------------------------------------------------------------------- /docs/source/pages/changelog.rst: -------------------------------------------------------------------------------- 1 | .. _changelog: 2 | 3 | .. include:: ../../../CHANGELOG.rst 4 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | per-file-ignores = 4 | generic_connection_pool/*__init__.py: F401 5 | -------------------------------------------------------------------------------- /generic_connection_pool/__init__.py: -------------------------------------------------------------------------------- 1 | from . import common, exceptions 2 | from .common import BaseConnectionPool 3 | from .exceptions import BaseError, ConnectionPoolClosedError 4 | 5 | __all__ = [ 6 | 'BaseConnectionPool', 7 | 'ConnectionPoolClosedError', 8 | ] 9 | -------------------------------------------------------------------------------- /generic_connection_pool/threading/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Threading connection pool implementation. 3 | """ 4 | 5 | from .pool import BaseConnectionManager, ConnectionPool, ConnectionT, EndpointT 6 | 7 | __all__ = [ 8 | 'BaseConnectionManager', 9 | 'ConnectionPool', 10 | 'ConnectionT', 11 | 'EndpointT', 12 | ] 13 | -------------------------------------------------------------------------------- /generic_connection_pool/asyncio/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Asyncio connection pool implementation. 
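A minimal usage sketch (an editor's illustration assembled from examples/async_pool.py shipped in this repository; the endpoint and choice of manager are assumptions, any connection manager works)::

    import asyncio

    from generic_connection_pool.asyncio import ConnectionPool
    from generic_connection_pool.contrib.socket_async import TcpStreamConnectionManager

    async def main() -> None:
        # the endpoint is a (hostname, port) pair; a connection is a (reader, writer) stream pair
        pool = ConnectionPool(
            TcpStreamConnectionManager(ssl=True),
            idle_timeout=30.0,          # dispose connections idle for longer than 30s
            max_size=20,                # per-endpoint connection limit
            background_collector=True,  # collect expired connections in the background
        )
        try:
            async with pool.connection(endpoint=('example.com', 443), timeout=5.0) as (reader, writer):
                writer.write(b'GET / HTTP/1.1\r\nHost: example.com\r\n\r\n')
                await writer.drain()
        finally:
            await pool.close()

    asyncio.run(main())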
3 | """ 4 | 5 | from .locks import SharedLock 6 | from .pool import BaseConnectionManager, ConnectionPool, ConnectionT, EndpointT 7 | 8 | __all__ = [ 9 | 'BaseConnectionManager', 10 | 'ConnectionPool', 11 | 'ConnectionT', 12 | 'EndpointT', 13 | ] 14 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 2 | 3 | version: 2 4 | 5 | build: 6 | os: ubuntu-22.04 7 | tools: 8 | python: "3.11" 9 | 10 | sphinx: 11 | configuration: docs/source/conf.py 12 | 13 | python: 14 | install: 15 | - method: pip 16 | path: . 17 | extra_requirements: 18 | - docs 19 | -------------------------------------------------------------------------------- /docs/source/pages/installation.rst: -------------------------------------------------------------------------------- 1 | .. _installation: 2 | 3 | 4 | Installation 5 | ~~~~~~~~~~~~ 6 | 7 | This part of the documentation covers the installation of ``generic-connection-pool`` library. 8 | 9 | 10 | Installation using pip 11 | ______________________ 12 | 13 | To install ``generic-connection-pool``, run: 14 | 15 | .. code-block:: console 16 | 17 | $ pip install generic-connection-pool 18 | -------------------------------------------------------------------------------- /generic_connection_pool/exceptions.py: -------------------------------------------------------------------------------- 1 | 2 | class BaseError(Exception): 3 | """ 4 | Base package exception. 5 | """ 6 | 7 | 8 | class ConnectionPoolClosedError(BaseError): 9 | """ 10 | Connection pool already closed. 11 | """ 12 | 13 | 14 | class ConnectionPoolIsFull(BaseError): 15 | """ 16 | Indicates that a pool is full of connections. 17 | """ 18 | 19 | 20 | class ConnectionPoolNotFound(BaseError): 21 | """ 22 | Connection pool not found. 23 | """ 24 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: release 2 | 3 | on: 4 | release: 5 | types: 6 | - released 7 | 8 | jobs: 9 | release: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v2 13 | - name: Set up Python 14 | uses: actions/setup-python@v2 15 | with: 16 | python-version: '3.x' 17 | - name: Install dependencies 18 | run: | 19 | python -m pip install --upgrade pip poetry 20 | poetry install 21 | - name: Build and publish 22 | run: | 23 | poetry build 24 | poetry publish -u __token__ -p ${{ secrets.PYPI_TOKEN }} 25 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import random 2 | from pathlib import Path 3 | from typing import Generator 4 | 5 | import pytest 6 | 7 | 8 | @pytest.fixture(scope='session') 9 | def test_dir() -> Path: 10 | return Path(__file__).parent 11 | 12 | 13 | @pytest.fixture(scope='session') 14 | def resource_dir(test_dir: Path) -> Path: 15 | return test_dir / 'resources' 16 | 17 | 18 | @pytest.fixture(autouse=True) 19 | def init_random() -> None: 20 | random.seed(0) 21 | 22 | 23 | @pytest.fixture(scope='session') 24 | def delay(pytestconfig: pytest.Config) -> float: 25 | return pytestconfig.getoption('--delay', 0.05) 26 | 27 | 28 | @pytest.fixture(scope='session') 29 | def port_gen() -> Generator[int, None, None]: 30 | def gen() -> Generator[int, None, None]: 31 | for port in range(10000, 65535): 32 | yield port 33 | 34 | return gen() 35 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/pages/api.rst: -------------------------------------------------------------------------------- 1 | .. _api: 2 | 3 | 4 | Developer Interface 5 | ~~~~~~~~~~~~~~~~~~~ 6 | 7 | .. automodule:: generic_connection_pool 8 | :members: 9 | 10 | 11 | Threading 12 | _________ 13 | 14 | .. automodule:: generic_connection_pool.threading 15 | :members: 16 | 17 | 18 | Asyncio 19 | _______ 20 | 21 | .. automodule:: generic_connection_pool.asyncio 22 | :members: 23 | 24 | Contrib 25 | _______ 26 | 27 | .. automodule:: generic_connection_pool.contrib.socket 28 | :members: 29 | 30 | .. automodule:: generic_connection_pool.contrib.socket_async 31 | :members: 32 | 33 | .. automodule:: generic_connection_pool.contrib.unix 34 | :members: 35 | 36 | .. automodule:: generic_connection_pool.contrib.unix_async 37 | :members: 38 | 39 | .. automodule:: generic_connection_pool.contrib.psycopg2 40 | :members: 41 | 42 | .. automodule:: generic_connection_pool.contrib.asyncpg 43 | :members: 44 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. 
generic-connection-pool documentation master file, created by 2 | sphinx-quickstart on Tue Aug 8 22:40:11 2023. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | .. include:: ../../README.rst 7 | 8 | 9 | User Guide 10 | ---------- 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | 15 | pages/quickstart 16 | pages/installation 17 | 18 | .. toctree:: 19 | :maxdepth: 1 20 | 21 | pages/changelog 22 | 23 | 24 | API Documentation 25 | ----------------- 26 | 27 | .. toctree:: 28 | :maxdepth: 1 29 | 30 | pages/api 31 | 32 | 33 | Contribute 34 | ---------- 35 | 36 | .. toctree:: 37 | :maxdepth: 1 38 | 39 | pages/contribute 40 | 41 | 42 | Links 43 | ----- 44 | 45 | - `Source code <https://github.com/dapper91/generic-connection-pool>`_ 46 | 47 | 48 | Indices and tables 49 | ================== 50 | 51 | * :ref:`genindex` 52 | * :ref:`modindex` 53 | -------------------------------------------------------------------------------- /tests/resources/ssl.cert: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIC2DCCAcACCQCPGAEZ/izmXDANBgkqhkiG9w0BAQsFADAuMQswCQYDVQQGEwJV 3 | UzELMAkGA1UECAwCQ0ExEjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0yMzAzMTYxNjMx 4 | NDhaFw0yNTEyMTAxNjMxNDhaMC4xCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDQTES 5 | MBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC 6 | AQEAvLSFc4DafQ2zrU+05P0PGlGjDeab9zXfVZmbEziNQ/8V+RtvmwFw9aENJtRv 7 | IsehZgz505hmTbBtsVPGA3F9MUjBRyMD25ztPxMTA8jQ5xtKhXECErVQ2AmgHHSD 8 | mEWa9S7/Mcu8Ld5dL2lqFsJTG/GjlHcmRMzCHfNKrDQuXeucsNm1mWRlMZ/eK/bj 9 | n+RkeBAKaAN4khVV7UAFeIaISTuD4cUVfv14spB2UV3nmezATXmbR9YCOM91us8Y 10 | rJP3XMJ/vgWVRofDqGCZM9Fs5WlpSbl4/Fei/zDIS/5p6moukpEWNRuV3pN9WI+Q 11 | ciwWbCMwITOmMCXVelRuPpXefwIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQBOb+jT 12 | Tseengdf+J7XI4CYtrPoCf9OPZrZHRZcU/d7y04OZpiBin7thBkOsfhESVRABpKN 13 | UhAWtfgnF43Pbxm1l6jg6SbTIwaQ8kXXu+Cx/RaVVCAek5GDStFNSj3ZMsnGrFOO 14 | wtsYoHBLSO11mOu+VcfCbzUqGBaGGLbXySSBx0uwkcSo+Qa5NVXLk54mmWrg84nE 15 | HjdbWTVj+UrSWkJ/9G1bjl1QeKfT9gwusSDKsLiN5QUCVKFjNaARxJ5SDbrxaTmL 16 | 1391HGjt1zclDmnypokj56Wkj+HOkUKohC11b1Kgsxl8/Ykyntz89OAj0Vv17xYa 17 | LjQZuCWYlln6S9+S 18 | -----END CERTIFICATE----- 19 | -------------------------------------------------------------------------------- /CHANGELOG.rst: -------------------------------------------------------------------------------- 1 | Changelog 2 | ========= 3 | 4 | 0.7.0 (2023-12-09) 5 | ------------------ 6 | 7 | - asyncpg connection manager added. 8 | 9 | 10 | 0.6.1 (2023-10-06) 11 | ------------------ 12 | 13 | - some refactoring done. 14 | 15 | 16 | 0.6.0 (2023-10-05) 17 | ------------------ 18 | 19 | - python 3.12 support added. 20 | 21 | 22 | 0.5.0 (2023-08-17) 23 | ------------------ 24 | 25 | - unix socket managers implemented. 26 | 27 | 28 | 0.4.1 (2023-08-16) 29 | ------------------ 30 | 31 | - aliveness checking extracted to mixins. 32 | 33 | 34 | 0.4.0 (2023-08-15) 35 | ------------------ 36 | 37 | - socket connection manager aliveness check added 38 | - shared lock bug fixed 39 | 40 | 41 | 0.3.0 (2023-08-10) 42 | ------------------ 43 | 44 | - refactoring done. 45 | 46 | 47 | 0.2.0 (2023-04-19) 48 | ------------------ 49 | 50 | - socket_timeout context manager works with ssl sockets. 51 | - graceful_timeout default set to 0. 52 | - connection release bug fixed.
53 | 54 | 55 | 0.1.1 (2023-03-17) 56 | ------------------ 57 | 58 | - Contrib bugs fixed 59 | 60 | 61 | 0.1.0 (2023-03-15) 62 | ------------------ 63 | 64 | - Initial release 65 | -------------------------------------------------------------------------------- /examples/tcp_pool.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from ipaddress import IPv4Address, IPv6Address 3 | from typing import Tuple, Union 4 | 5 | from generic_connection_pool.contrib.socket import TcpSocketConnectionManager 6 | from generic_connection_pool.threading import ConnectionPool 7 | 8 | Port = int 9 | IpAddress = Union[IPv4Address, IPv6Address] 10 | Endpoint = Tuple[IpAddress, Port] 11 | Connection = socket.socket 12 | 13 | 14 | redis_pool = ConnectionPool[Endpoint, Connection]( 15 | TcpSocketConnectionManager(), 16 | idle_timeout=30.0, 17 | max_lifetime=600.0, 18 | min_idle=3, 19 | max_size=20, 20 | total_max_size=100, 21 | background_collector=True, 22 | ) 23 | 24 | 25 | def command(addr: IpAddress, port: int, cmd: str) -> None: 26 | with redis_pool.connection(endpoint=(addr, port), timeout=5.0) as sock: 27 | sock.sendall(cmd.encode() + b'\n') 28 | response = sock.recv(1024) 29 | 30 | print(response.decode()) 31 | 32 | 33 | try: 34 | command(IPv4Address('127.0.0.1'), 6379, 'CLIENT ID') # tcp connection opened 35 | command(IPv4Address('127.0.0.1'), 6379, 'INFO') # tcp connection reused 36 | finally: 37 | redis_pool.close() 38 | -------------------------------------------------------------------------------- /docs/source/pages/contribute.rst: -------------------------------------------------------------------------------- 1 | .. _contribute: 2 | 3 | 4 | Development 5 | ~~~~~~~~~~~ 6 | 7 | Initialize the development environment by installing the dev dependencies: 8 | 9 | .. code-block:: console 10 | 11 | $ poetry install --no-root 12 | 13 | 14 | Code style 15 | __________ 16 | 17 | After any code change, make sure the code style is followed. 18 | To check that automatically, install the pre-commit hooks: 19 | 20 | .. code-block:: console 21 | 22 | $ pre-commit install 23 | 24 | They will check your changes against the coding conventions used in the project before each commit. 25 | 26 | 27 | Pull Requests 28 | _____________ 29 | 30 | To contribute, check out the ``dev`` branch, create a feature branch, and open a pull request setting 31 | ``dev`` as the target. 32 | 33 | 34 | Documentation 35 | _____________ 36 | 37 | If you've made any changes to the documentation, make sure it builds successfully. 38 | To build the documentation, follow these instructions: 39 | 40 | - Install the documentation generation dependencies: 41 | 42 | .. code-block:: console 43 | 44 | $ poetry install -E docs 45 | 46 | - Build the documentation: 47 | 48 | ..
code-block:: console 49 | 50 | $ cd docs 51 | $ make html 52 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - dev 7 | - master 8 | push: 9 | branches: 10 | - master 11 | 12 | jobs: 13 | test: 14 | runs-on: ubuntu-latest 15 | strategy: 16 | matrix: 17 | python-version: ['3.9', '3.10', '3.11', '3.12'] 18 | steps: 19 | - uses: actions/checkout@v2 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v2 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install poetry 28 | poetry install --no-root 29 | - name: Run pre-commit hooks 30 | run: poetry run pre-commit run --hook-stage merge-commit --all-files 31 | - name: Run tests 32 | run: PYTHONPATH="$(pwd):$PYTHONPATH" poetry run py.test --cov=generic_connection_pool --cov-report=xml tests 33 | - name: Upload coverage to Codecov 34 | uses: codecov/codecov-action@v1 35 | with: 36 | token: ${{ secrets.CODECOV_TOKEN }} 37 | files: ./coverage.xml 38 | flags: unittests 39 | fail_ci_if_error: true 40 | -------------------------------------------------------------------------------- /generic_connection_pool/asyncio/utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import functools as ft 3 | import typing 4 | from typing import Any, Awaitable, Callable, TypeVar 5 | 6 | ResultT = TypeVar('ResultT') 7 | 8 | 9 | async def guard(coro: Awaitable[ResultT]) -> ResultT: 10 | """ 11 | Guards a coroutine from cancellation ignoring `asyncio.CancelledError`. 12 | Acts similar to `asyncio.shield` but wait for the shielded coroutine to be finished. 13 | """ 14 | 15 | inner_fut = asyncio.ensure_future(coro) 16 | 17 | cancelled = False 18 | while True: 19 | try: 20 | result = await asyncio.shield(inner_fut) 21 | except asyncio.CancelledError: 22 | cancelled = True 23 | else: 24 | break 25 | 26 | if cancelled: 27 | raise asyncio.CancelledError 28 | 29 | return result 30 | 31 | 32 | AsyncFuncT = TypeVar('AsyncFuncT', bound=Callable[..., Awaitable[Any]]) 33 | 34 | 35 | def guarded(func: AsyncFuncT) -> AsyncFuncT: 36 | """ 37 | Guard decorator. 38 | """ 39 | 40 | @ft.wraps(func) 41 | async def decorator(*args: Any, **kwargs: Any) -> Any: 42 | return await guard(func(*args, **kwargs)) 43 | 44 | return typing.cast(AsyncFuncT, decorator) 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 
15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 23 | 24 | For more information, please refer to <https://unlicense.org> 25 | -------------------------------------------------------------------------------- /examples/pg_pool.py: -------------------------------------------------------------------------------- 1 | import psycopg2.extensions 2 | 3 | from generic_connection_pool.contrib.psycopg2 import DbConnectionManager 4 | from generic_connection_pool.threading import ConnectionPool 5 | 6 | Endpoint = str 7 | Connection = psycopg2.extensions.connection 8 | 9 | 10 | dsn_params = dict(dbname='postgres', user='postgres', password='secret') 11 | 12 | pg_pool = ConnectionPool[Endpoint, Connection]( 13 | DbConnectionManager( 14 | dsn_params={ 15 | 'master': dict(dsn_params, host='db-master.local'), 16 | 'replica-1': dict(dsn_params, host='db-replica-1.local'), 17 | 'replica-2': dict(dsn_params, host='db-replica-2.local'), 18 | }, 19 | ), 20 | acquire_timeout=2.0, 21 | idle_timeout=60.0, 22 | max_lifetime=600.0, 23 | min_idle=3, 24 | max_size=10, 25 | total_max_size=15, 26 | background_collector=True, 27 | ) 28 | 29 | try: 30 | # connection opened 31 | with pg_pool.connection(endpoint='master') as conn: 32 | cur = conn.cursor() 33 | cur.execute("SELECT inet_server_addr()") 34 | print(cur.fetchone()) 35 | 36 | # connection opened 37 | with pg_pool.connection(endpoint='replica-1') as conn: 38 | cur = conn.cursor() 39 | cur.execute("SELECT inet_server_addr()") 40 | print(cur.fetchone()) 41 | 42 | # connection reused 43 | with pg_pool.connection(endpoint='master') as conn: 44 | cur = conn.cursor() 45 | cur.execute("SELECT inet_server_addr()") 46 | print(cur.fetchone()) 47 | 48 | finally: 49 | pg_pool.close() 50 | -------------------------------------------------------------------------------- /examples/ssl_pool.py: -------------------------------------------------------------------------------- 1 | import ssl 2 | import urllib.parse 3 | from http.client import HTTPResponse 4 | from typing import Tuple 5 | 6 | from generic_connection_pool.contrib.socket import SslSocketConnectionManager 7 | from generic_connection_pool.threading import ConnectionPool 8 | 9 | Hostname = str 10 | Port = int 11 | Endpoint = Tuple[Hostname, Port] 12 | Connection = ssl.SSLSocket 13 | 14 | 15 | http_pool = ConnectionPool[Endpoint, Connection]( 16 | SslSocketConnectionManager(ssl.create_default_context()), 17 | idle_timeout=30.0, 18 | max_lifetime=600.0, 19 | min_idle=3, 20 | max_size=20, 21 | total_max_size=100, 22 | background_collector=True, 23 | ) 24 | 25 | 26 | def fetch(url: str, timeout: float = 5.0) -> None: 27 | url = urllib.parse.urlsplit(url) 28 | if url.hostname is None: 29 | raise ValueError 30 | 31 | port = url.port or (443 if url.scheme == 'https' else 80) 32 | 33 | with http_pool.connection(endpoint=(url.hostname, port), timeout=timeout) as sock: 34 | request = ( 35 | 'GET {path} HTTP/1.1\r\n' 36 | 'Host: {host}\r\n' 37 | '\r\n' 38 | '\r\n' 39 | ).format(host=url.hostname, path=url.path) 40 | 41 | sock.write(request.encode()) 42 | 43 | response = HTTPResponse(sock) 44 | response.begin() 45 | status, body =
response.getcode(), response.read(response.length) 46 | 47 | print(status) 48 | print(body) 49 | 50 | 51 | try: 52 | fetch('https://en.wikipedia.org/wiki/HTTP') # http connection opened 53 | fetch('https://en.wikipedia.org/wiki/Python_(programming_language)') # http connection reused 54 | finally: 55 | http_pool.close() 56 | -------------------------------------------------------------------------------- /examples/manager_hooks.py: -------------------------------------------------------------------------------- 1 | import ssl 2 | import time 3 | from ssl import SSLSocket 4 | from typing import Any, Dict, Tuple 5 | 6 | import prometheus_client as prom 7 | 8 | from generic_connection_pool.contrib.socket import SslSocketConnectionManager 9 | from generic_connection_pool.threading import ConnectionPool 10 | 11 | Hostname = str 12 | Port = int 13 | Endpoint = Tuple[Hostname, Port] 14 | 15 | acquire_latency_hist = prom.Histogram('acquire_latency', 'Connections acquire latency', labelnames=['hostname']) 16 | acquire_total = prom.Counter('acquire_total', 'Connections acquire count', labelnames=['hostname']) 17 | dead_conn_total = prom.Counter('dead_conn_total', 'Dead connections count', labelnames=['hostname']) 18 | 19 | 20 | class ObservableConnectionManager(SslSocketConnectionManager): 21 | 22 | def __init__(self, *args: Any, **kwargs: Any): 23 | super().__init__(*args, **kwargs) 24 | self._acquires: Dict[SSLSocket, float] = {} 25 | 26 | def on_acquire(self, endpoint: Endpoint, conn: SSLSocket) -> None: 27 | hostname, port = endpoint 28 | 29 | acquire_total.labels(hostname).inc() 30 | self._acquires[conn] = time.time() 31 | 32 | def on_release(self, endpoint: Endpoint, conn: SSLSocket) -> None: 33 | hostname, port = endpoint 34 | 35 | acquired_at = self._acquires.pop(conn) 36 | acquire_latency_hist.labels(hostname).observe(time.time() - acquired_at) 37 | 38 | def on_connection_dead(self, endpoint: Endpoint, conn: SSLSocket) -> None: 39 | hostname, port = endpoint 40 | 41 | dead_conn_total.labels(hostname).inc() 42 | 43 | 44 | http_pool = ConnectionPool[Endpoint, SSLSocket]( 45 | ObservableConnectionManager(ssl.create_default_context()), 46 | ) 47 | -------------------------------------------------------------------------------- /examples/async_pg_pool.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import asyncpg 4 | 5 | from generic_connection_pool.asyncio import ConnectionPool 6 | from generic_connection_pool.contrib.asyncpg import DbConnectionManager 7 | 8 | Endpoint = str 9 | Connection = asyncpg.Connection 10 | 11 | 12 | async def main() -> None: 13 | dsn_params = dict(database='postgres', user='postgres', password='secret') 14 | 15 | pg_pool = ConnectionPool[Endpoint, 'Connection[asyncpg.Record]']( 16 | DbConnectionManager[asyncpg.Record]( 17 | dsn_params={ 18 | 'master': dict(dsn_params, host='db-master.local'), 19 | 'replica-1': dict(dsn_params, host='db-replica-1.local'), 20 | 'replica-2': dict(dsn_params, host='db-replica-2.local'), 21 | }, 22 | ), 23 | acquire_timeout=2.0, 24 | idle_timeout=60.0, 25 | max_lifetime=600.0, 26 | min_idle=3, 27 | max_size=10, 28 | total_max_size=15, 29 | background_collector=True, 30 | ) 31 | 32 | try: 33 | # connection opened 34 | async with pg_pool.connection(endpoint='master') as conn: 35 | result = await conn.fetchval("SELECT inet_server_addr()") 36 | print(result) 37 | 38 | # connection opened 39 | async with pg_pool.connection(endpoint='replica-1') as conn: 40 | result = await 
conn.fetchval("SELECT inet_server_addr()") 41 | print(result) 42 | 43 | # connection reused 44 | async with pg_pool.connection(endpoint='master') as conn: 45 | result = await conn.fetchval("SELECT inet_server_addr()") 46 | print(result) 47 | 48 | finally: 49 | await pg_pool.close() 50 | 51 | asyncio.run(main()) 52 | -------------------------------------------------------------------------------- /tests/resources/ssl.key: -------------------------------------------------------------------------------- 1 | -----BEGIN PRIVATE KEY----- 2 | MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQC8tIVzgNp9DbOt 3 | T7Tk/Q8aUaMN5pv3Nd9VmZsTOI1D/xX5G2+bAXD1oQ0m1G8ix6FmDPnTmGZNsG2x 4 | U8YDcX0xSMFHIwPbnO0/ExMDyNDnG0qFcQIStVDYCaAcdIOYRZr1Lv8xy7wt3l0v 5 | aWoWwlMb8aOUdyZEzMId80qsNC5d65yw2bWZZGUxn94r9uOf5GR4EApoA3iSFVXt 6 | QAV4hohJO4PhxRV+/XiykHZRXeeZ7MBNeZtH1gI4z3W6zxisk/dcwn++BZVGh8Oo 7 | YJkz0WzlaWlJuXj8V6L/MMhL/mnqai6SkRY1G5Xek31Yj5ByLBZsIzAhM6YwJdV6 8 | VG4+ld5/AgMBAAECggEAfZ7/KAEjchRpBHsHRVlhcHfgQCTAtzVZ07ZKEeWXxShP 9 | DGJDTcEL3bi09KB+y3xx6WnB9iaCFD3bCC1oqGoomWKBqEWbD9vL5C1ifyZ0SyVT 10 | 2rl8U8/4XZkqyUaXRAsyOP5sTE4Am9hn2GQoh7YddYDLEM6w3yQgJagMkc66/zHx 11 | f3oqFPx7BcbB6VDTIpjxawc6k4A1C5ApDJG9udIA1aptYkJhpbQsHjxTSmW3oo4e 12 | awwOr5JnPuPT03mzVtA4euuoqHROmFUt3Yhz/SVPCQ64cBARP6/i8WdWG3hV9JcI 13 | DG0IoQ9XMQqobWRnThtytV4OPiG8gUhCCX+GMk8iYQKBgQDiP3oeGcwrJbLKLWkj 14 | 0N4mpgTtRzaoeBRE9j4PSElwcvidJY6b10oeqSXEN6yKxfsPIZojiu9CdyfHXE3N 15 | pX4dB4DpBZnAEEJzDoUQFtWcFHNnoFzfYWwtnTU/NaSVmAdJ7uhXaGw2F3QoBgUd 16 | IPDCOs9wCesd3v1KjrpO4RnKcwKBgQDVhS9fdEo7ZQkgwyHTgR+FFEB0Trd/dLyS 17 | 5Z9kjWu2+TDVht09B0+6uU6dgAtu2idXaSpCmfJHw6fOV0paogtpjyKh/GNXhSie 18 | xH22YTTrgzbTOCJEN+9o43WoqbxIVZiXZhyFupbGXjgKMODX/m7L3gwGVsvil8dU 19 | PuRdFWWcxQKBgHVqtBnDEa6i1fMPNi2cTG6KYqwx9S/hgcN4eCS+Qz7UrCoCP8yp 20 | IpJe/nai3iz3KqBjs/cWN62q4T4ZrVc4uAagykok2fJPfezwcCY1c46ZHnt9QjW7 21 | /cR+fg/b6xqn18CK+JHEY8R+z42l8il32vsyQk3HF/pcq99xy0b8k8H5AoGAHa1e 22 | UUEjlC/N3fzhNbmLvP58mu3Z+WArWauKxPoXD56BGByfoXzjqwtYjvGeJTEzKKYY 23 | Vpt5Hlpmd3qQfhppxak8YhFnaWG7rJ2Y74GBTn61XxQ9RwgTQZvj3aaB4ffrtpdd 24 | vYSaskWkOl5i0gKuOa3KNBNaUUtRTDdVnE5+ChUCgYAprsJ8CyTMldg3K8UfURkN 25 | 7VHBKRJB6txx3Ud7Imixxyjl1cBNT9feq3TOqd3PEhpYI1Xei+yjglxRq5dfBDeh 26 | 7u0ZXy9nT4njaTXO8AK96zIb3ZjmOl5VZJ9Nnox7O34cWKp3wGKQx7uAGNfBRmuq 27 | ZSahWVs86vOvxf04/o5vNA== 28 | -----END PRIVATE KEY----- 29 | -------------------------------------------------------------------------------- /generic_connection_pool/contrib/psycopg2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Postgres connection manager implementation. 3 | """ 4 | 5 | from typing import Mapping, Optional 6 | 7 | import psycopg2.extensions 8 | 9 | from generic_connection_pool.threading import BaseConnectionManager 10 | 11 | DbEndpoint = str 12 | Connection = psycopg2.extensions.connection 13 | DsnParameters = Mapping[DbEndpoint, Mapping[str, str]] 14 | 15 | 16 | class DbConnectionManager(BaseConnectionManager[DbEndpoint, Connection]): 17 | """ 18 | Psycopg2 based postgres connection manager. 
19 | 20 | :param dsn_params: databases dsn parameters 21 | """ 22 | 23 | def __init__(self, dsn_params: DsnParameters): 24 | self._dsn_params = dsn_params 25 | 26 | def create(self, endpoint: DbEndpoint, timeout: Optional[float] = None) -> Connection: 27 | return psycopg2.connect(**self._dsn_params[endpoint]) # type: ignore[call-overload] 28 | 29 | def dispose(self, endpoint: DbEndpoint, conn: Connection, timeout: Optional[float] = None) -> None: 30 | conn.close() 31 | 32 | def check_aliveness(self, endpoint: DbEndpoint, conn: Connection, timeout: Optional[float] = None) -> bool: 33 | try: 34 | with conn.cursor() as cur: 35 | cur.execute("SELECT 1;") 36 | cur.fetchone() 37 | except (psycopg2.Error, OSError): 38 | return False 39 | 40 | return True 41 | 42 | def on_acquire(self, endpoint: DbEndpoint, conn: Connection) -> None: 43 | self._rollback_uncommitted(conn) 44 | 45 | def on_release(self, endpoint: DbEndpoint, conn: Connection) -> None: 46 | self._rollback_uncommitted(conn) 47 | 48 | def _rollback_uncommitted(self, conn: Connection) -> None: 49 | if conn.status == psycopg2.extensions.STATUS_IN_TRANSACTION: 50 | conn.rollback() 51 | -------------------------------------------------------------------------------- /tests/contrib/test_unix.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import socketserver 3 | import sys 4 | import tempfile 5 | import threading 6 | from pathlib import Path 7 | from typing import Generator, Optional 8 | 9 | import pytest 10 | 11 | from generic_connection_pool.contrib.unix import UnixSocketConnectionManager 12 | from generic_connection_pool.threading import ConnectionPool 13 | from tests.contrib.test_socket_manager import EchoTCPServer 14 | 15 | if sys.platform not in ('linux', 'darwin', 'freebsd'): 16 | pytest.skip(reason="unix platform only", allow_module_level=True) 17 | 18 | 19 | class EchoUnixServer(EchoTCPServer): 20 | address_family = socket.AF_UNIX 21 | 22 | def __init__(self, server_address: str): 23 | socketserver.TCPServer.__init__(self, server_address, self.EchoRequestHandler) 24 | self._server_thread: Optional[threading.Thread] = None 25 | 26 | 27 | @pytest.fixture(scope='module') 28 | def unix_socket_server() -> Generator[Path, None, None]: 29 | with tempfile.TemporaryDirectory() as tmp_dir: 30 | path = Path(tmp_dir) / 'unix-socket-server.sock' 31 | server = EchoUnixServer(str(path)) 32 | server.start() 33 | yield path 34 | server.stop() 35 | 36 | 37 | @pytest.mark.timeout(5.0) 38 | def test_unix_socket_manager(unix_socket_server: Path): 39 | path = unix_socket_server 40 | 41 | pool = ConnectionPool[Path, socket.socket](UnixSocketConnectionManager()) 42 | 43 | attempts = 3 44 | request = b'test' 45 | for _ in range(attempts): 46 | with pool.connection(path) as sock1: 47 | sock1.sendall(request) 48 | response = sock1.recv(len(request)) 49 | assert response == request 50 | 51 | with pool.connection(path) as sock2: 52 | sock2.sendall(request) 53 | response = sock2.recv(len(request)) 54 | assert response == request 55 | 56 | pool.close() 57 | -------------------------------------------------------------------------------- /generic_connection_pool/contrib/unix.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unix specific functionality. 
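A minimal usage sketch (an editor's illustration derived from tests/contrib/test_unix.py in this repository; the socket path is a placeholder)::

    import socket
    from pathlib import Path

    from generic_connection_pool.contrib.unix import UnixSocketConnectionManager
    from generic_connection_pool.threading import ConnectionPool

    pool = ConnectionPool[Path, socket.socket](UnixSocketConnectionManager())
    try:
        with pool.connection(Path('/var/run/app.sock')) as sock:
            sock.sendall(b'ping')
            print(sock.recv(4))
    finally:
        pool.close()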
3 | """ 4 | 5 | import errno 6 | import pathlib 7 | import socket 8 | import sys 9 | from typing import Generic, Optional 10 | 11 | from generic_connection_pool.contrib.socket import BaseConnectionManager 12 | from generic_connection_pool.threading import EndpointT 13 | 14 | from .socket import socket_timeout 15 | 16 | if sys.platform not in ('linux', 'darwin', 'freebsd'): 17 | raise AssertionError('this module is only supported by unix platforms') 18 | 19 | 20 | UnixSocketEndpoint = pathlib.Path 21 | 22 | 23 | class CheckSocketAlivenessMixin(Generic[EndpointT]): 24 | """ 25 | Socket aliveness checking mixin. 26 | """ 27 | 28 | def check_aliveness(self, endpoint: EndpointT, conn: socket.socket, timeout: Optional[float] = None) -> bool: 29 | try: 30 | if conn.recv(1, socket.MSG_PEEK | socket.MSG_DONTWAIT) == b'': 31 | return False 32 | except BlockingIOError as exc: 33 | if exc.errno != errno.EAGAIN: 34 | raise 35 | except OSError: 36 | return False 37 | 38 | return True 39 | 40 | 41 | class UnixSocketConnectionManager( 42 | CheckSocketAlivenessMixin[UnixSocketEndpoint], 43 | BaseConnectionManager[UnixSocketEndpoint, socket.socket], 44 | ): 45 | """ 46 | Unix socket connection manager. 47 | """ 48 | 49 | def create(self, endpoint: UnixSocketEndpoint, timeout: Optional[float] = None) -> socket.socket: 50 | sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM) 51 | 52 | with socket_timeout(sock, timeout): 53 | sock.connect(str(endpoint)) 54 | 55 | return sock 56 | 57 | def dispose(self, endpoint: UnixSocketEndpoint, conn: socket.socket, timeout: Optional[float] = None) -> None: 58 | try: 59 | conn.shutdown(socket.SHUT_RDWR) 60 | except OSError: 61 | pass 62 | 63 | conn.close() 64 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | import sys 10 | from pathlib import Path 11 | 12 | import toml 13 | 14 | THIS_PATH = Path(__file__).parent 15 | ROOT_PATH = THIS_PATH.parent.parent 16 | sys.path.insert(0, str(ROOT_PATH)) 17 | 18 | PYPROJECT = toml.load(ROOT_PATH / 'pyproject.toml') 19 | PROJECT_INFO = PYPROJECT['tool']['poetry'] 20 | 21 | project = PROJECT_INFO['name'] 22 | copyright = f"2023, {PROJECT_INFO['name']}" 23 | author = PROJECT_INFO['authors'][0] 24 | release = PROJECT_INFO['version'] 25 | 26 | # -- General configuration --------------------------------------------------- 27 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 28 | 29 | extensions = [ 30 | 'sphinx.ext.autodoc', 31 | 'sphinx.ext.doctest', 32 | 'sphinx.ext.intersphinx', 33 | 'sphinx.ext.autosectionlabel', 34 | 'sphinx.ext.viewcode', 35 | 'sphinx_copybutton', 36 | 'sphinx_design', 37 | ] 38 | 39 | intersphinx_mapping = { 40 | 'python': ('https://docs.python.org/3', None), 41 | } 42 | 43 | autodoc_typehints = 'description' 44 | autodoc_typehints_format = 'short' 45 | autodoc_member_order = 'bysource' 46 | autodoc_default_options = { 47 | 'show-inheritance': True, 48 | } 49 | 50 | autosectionlabel_prefix_document = True 51 | 52 | html_theme_options = {} 53 | html_title = PROJECT_INFO['name'] 54 | 55 | templates_path = ['templates'] 56 | exclude_patterns = [] 57 | 58 | # -- Options for HTML output ------------------------------------------------- 59 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 60 | 61 | html_theme = 'furo' 62 | html_static_path = ['static'] 63 | html_css_files = ['css/custom.css'] 64 | -------------------------------------------------------------------------------- /generic_connection_pool/contrib/asyncpg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Postgres asyncpg connection manager implementation. 3 | """ 4 | 5 | from typing import Generic, Mapping, Optional, TypeVar 6 | 7 | import asyncpg 8 | 9 | from generic_connection_pool.asyncio import BaseConnectionManager 10 | 11 | DbEndpoint = str 12 | Connection = asyncpg.Connection 13 | DsnParameters = Mapping[DbEndpoint, Mapping[str, str]] 14 | 15 | RecordT = TypeVar('RecordT', bound=asyncpg.Record) 16 | 17 | 18 | class DbConnectionManager(BaseConnectionManager[DbEndpoint, 'Connection[RecordT]'], Generic[RecordT]): 19 | """ 20 | Asyncpg based postgres connection manager.
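Endpoints are looked up in ``dsn_params`` by name; a usage sketch (an editor's illustration mirroring examples/async_pg_pool.py from this repository; host names are placeholders)::

    DbConnectionManager[asyncpg.Record](
        dsn_params={
            'master': dict(database='postgres', user='postgres', host='db-master.local'),
            'replica-1': dict(database='postgres', user='postgres', host='db-replica-1.local'),
        },
    )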
21 | 22 | :param dsn_params: databases dsn parameters 23 | """ 24 | 25 | def __init__(self, dsn_params: DsnParameters): 26 | self._dsn_params = dsn_params 27 | 28 | async def create( 29 | self, 30 | endpoint: DbEndpoint, 31 | timeout: Optional[float] = None, 32 | ) -> 'Connection[RecordT]': 33 | return await asyncpg.connect(**self._dsn_params[endpoint]) # type: ignore[call-overload] 34 | 35 | async def dispose( 36 | self, 37 | endpoint: DbEndpoint, 38 | conn: 'Connection[RecordT]', 39 | timeout: Optional[float] = None, 40 | ) -> None: 41 | await conn.close(timeout=timeout) 42 | 43 | async def check_aliveness( 44 | self, 45 | endpoint: DbEndpoint, 46 | conn: 'Connection[RecordT]', 47 | timeout: Optional[float] = None, 48 | ) -> bool: 49 | return not conn.is_closed() 50 | 51 | async def on_acquire(self, endpoint: DbEndpoint, conn: 'Connection[RecordT]') -> None: 52 | await self._rollback_uncommitted(conn) 53 | 54 | async def on_release(self, endpoint: DbEndpoint, conn: 'Connection[RecordT]') -> None: 55 | await self._rollback_uncommitted(conn) 56 | 57 | async def _rollback_uncommitted(self, conn: 'Connection[RecordT]') -> None: 58 | if conn.is_in_transaction(): 59 | await conn.execute('ROLLBACK') 60 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_stages: 2 | - commit 3 | - merge-commit 4 | 5 | repos: 6 | - repo: https://github.com/pre-commit/pre-commit-hooks 7 | rev: v4.4.0 8 | hooks: 9 | - id: check-yaml 10 | - id: check-toml 11 | - id: trailing-whitespace 12 | - id: end-of-file-fixer 13 | stages: 14 | - commit 15 | - id: mixed-line-ending 16 | name: fix line ending 17 | stages: 18 | - commit 19 | args: 20 | - --fix=lf 21 | - id: mixed-line-ending 22 | name: check line ending 23 | stages: 24 | - merge-commit 25 | args: 26 | - --fix=no 27 | - repo: https://github.com/asottile/add-trailing-comma 28 | rev: v3.1.0 29 | hooks: 30 | - id: add-trailing-comma 31 | stages: 32 | - commit 33 | - repo: https://github.com/pre-commit/mirrors-autopep8 34 | rev: v2.0.4 35 | hooks: 36 | - id: autopep8 37 | stages: 38 | - commit 39 | args: 40 | - --diff 41 | - repo: https://github.com/pycqa/flake8 42 | rev: 6.1.0 43 | hooks: 44 | - id: flake8 45 | - repo: https://github.com/pycqa/isort 46 | rev: 5.12.0 47 | hooks: 48 | - id: isort 49 | name: fix import order 50 | stages: 51 | - commit 52 | args: 53 | - --line-length=120 54 | - --multi-line=9 55 | - --project=generic_connection_pool 56 | - id: isort 57 | name: check import order 58 | stages: 59 | - merge-commit 60 | args: 61 | - --check-only 62 | - --line-length=120 63 | - --multi-line=9 64 | - --project=generic_connection_pool 65 | - repo: https://github.com/pre-commit/mirrors-mypy 66 | rev: v1.5.1 67 | hooks: 68 | - id: mypy 69 | stages: 70 | - commit 71 | name: mypy 72 | pass_filenames: false 73 | additional_dependencies: 74 | - types-psycopg2>=2.9.5 75 | - asyncpg-stubs>=0.29.0 76 | args: ["--package", "generic_connection_pool"] 77 | -------------------------------------------------------------------------------- /generic_connection_pool/contrib/unix_async.py: -------------------------------------------------------------------------------- 1 | """ 2 | Asynchronous unix specific functionality.
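A minimal stream-based usage sketch (an editor's illustration derived from tests/contrib/test_unix_async.py in this repository; the socket path is a placeholder)::

    import asyncio
    from pathlib import Path

    from generic_connection_pool.asyncio import ConnectionPool
    from generic_connection_pool.contrib.unix_async import Stream, UnixSocketStreamConnectionManager

    async def main() -> None:
        pool = ConnectionPool[Path, Stream](UnixSocketStreamConnectionManager())
        try:
            async with pool.connection(Path('/var/run/app.sock')) as (reader, writer):
                writer.write(b'ping')
                await writer.drain()
                print(await reader.read(4))
        finally:
            await pool.close()

    asyncio.run(main())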
3 | """ 4 | 5 | import asyncio 6 | import pathlib 7 | import socket 8 | import sys 9 | 10 | from generic_connection_pool.contrib.socket_async import BaseConnectionManager, SocketAlivenessCheckingMixin, Stream 11 | from generic_connection_pool.contrib.socket_async import StreamAlivenessCheckingMixin 12 | 13 | if sys.platform not in ('linux', 'darwin', 'freebsd'): 14 | raise AssertionError('this module is only supported by unix platforms') 15 | 16 | 17 | UnixSocketEndpoint = pathlib.Path 18 | 19 | 20 | class UnixSocketConnectionManager( 21 | SocketAlivenessCheckingMixin[UnixSocketEndpoint], 22 | BaseConnectionManager[UnixSocketEndpoint, socket.socket], 23 | ): 24 | """ 25 | Asynchronous unix socket connection manager. 26 | """ 27 | 28 | async def create(self, endpoint: UnixSocketEndpoint) -> socket.socket: 29 | loop = asyncio.get_running_loop() 30 | 31 | sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM) 32 | sock.setblocking(False) 33 | 34 | await loop.sock_connect(sock, address=(str(endpoint))) 35 | 36 | return sock 37 | 38 | async def dispose(self, endpoint: UnixSocketEndpoint, conn: socket.socket) -> None: 39 | try: 40 | conn.shutdown(socket.SHUT_RDWR) 41 | except OSError: 42 | pass 43 | 44 | conn.close() 45 | 46 | 47 | class UnixSocketStreamConnectionManager( 48 | StreamAlivenessCheckingMixin[UnixSocketEndpoint], 49 | BaseConnectionManager[UnixSocketEndpoint, Stream], 50 | ): 51 | """ 52 | Asynchronous unix socket stream connection manager. 53 | """ 54 | 55 | async def create(self, endpoint: UnixSocketEndpoint) -> Stream: 56 | reader, writer = await asyncio.open_unix_connection(path=endpoint) 57 | 58 | return reader, writer 59 | 60 | async def dispose(self, endpoint: UnixSocketEndpoint, conn: Stream) -> None: 61 | reader, writer = conn 62 | if writer.can_write_eof(): 63 | writer.write_eof() 64 | 65 | writer.close() 66 | await writer.wait_closed() 67 | -------------------------------------------------------------------------------- /tests/test_rankmap.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from generic_connection_pool.rankmap import RankMap 4 | 5 | 6 | def test_heap_top(): 7 | rmap = RankMap() 8 | 9 | rmap['3'] = 3 10 | rmap['1'] = 1 11 | rmap['2'] = 2 12 | rmap['4'] = 4 13 | 14 | assert rmap.top() == 1 15 | assert rmap.next() == 1 16 | 17 | assert rmap.top() == 2 18 | assert rmap.next() == 2 19 | 20 | assert rmap.top() == 3 21 | assert rmap.next() == 3 22 | 23 | assert rmap.top() == 4 24 | assert rmap.next() == 4 25 | 26 | assert rmap.top() is None 27 | 28 | 29 | def test_contains_key(): 30 | rmap = RankMap() 31 | 32 | rmap['a'] = 1 33 | assert 'a' in rmap 34 | assert 'b' not in rmap 35 | 36 | 37 | def test_replace(): 38 | rmap = RankMap() 39 | 40 | rmap['a'] = 1 41 | assert rmap['a'] == 1 42 | 43 | rmap['a'] = 2 44 | assert rmap['a'] == 2 45 | 46 | 47 | def test_heap_push_pop(): 48 | rmap = RankMap() 49 | 50 | items = [0, 2, 1, 4, 3, 5, 6, 7, 8, 12, 11, 10, 9] 51 | for item in items: 52 | rmap.insert(item, item) 53 | 54 | actual_result = [] 55 | while rmap: 56 | actual_result.append(rmap.next()) 57 | 58 | expected_result = sorted(items) 59 | assert actual_result == expected_result 60 | assert rmap.next() is None 61 | assert len(rmap) == 0 62 | 63 | 64 | def test_heap_remove(): 65 | rmap = RankMap() 66 | 67 | rmap.insert(1, 1) 68 | del rmap[1] 69 | 70 | assert len(rmap) == 0 71 | 72 | rmap.insert(1, 1) 73 | rmap.insert(2, 2) 74 | rmap.insert(3, 3) 75 | rmap.insert(4, 4) 76 | 77 | del rmap[2] 78 | 
assert rmap.next() == 1 79 | assert rmap.next() == 3 80 | 81 | del rmap[4] 82 | assert len(rmap) == 0 83 | 84 | 85 | def test_heap_duplicate_error(): 86 | rmap = RankMap() 87 | 88 | rmap.insert(1, 1) 89 | with pytest.raises(KeyError): 90 | rmap.insert(1, 1) 91 | 92 | rmap.insert(2, 1) 93 | with pytest.raises(KeyError): 94 | rmap.insert(2, 2) 95 | 96 | 97 | def test_heap_clear(): 98 | rmap = RankMap() 99 | 100 | items = [0, 1, 2] 101 | for item in items: 102 | rmap.insert(item, item) 103 | 104 | rmap.clear() 105 | assert len(rmap) == 0 106 | -------------------------------------------------------------------------------- /examples/async_pool.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import urllib.parse 3 | from typing import Tuple 4 | 5 | from generic_connection_pool.asyncio import ConnectionPool 6 | from generic_connection_pool.contrib.socket_async import TcpStreamConnectionManager 7 | 8 | Hostname = str 9 | Port = int 10 | Endpoint = Tuple[Hostname, Port] 11 | Connection = Tuple[asyncio.StreamReader, asyncio.StreamWriter] 12 | 13 | 14 | async def main() -> None: 15 | http_pool = ConnectionPool[Endpoint, Connection]( 16 | TcpStreamConnectionManager(ssl=True), 17 | idle_timeout=30.0, 18 | max_lifetime=600.0, 19 | min_idle=3, 20 | max_size=20, 21 | total_max_size=100, 22 | background_collector=True, 23 | ) 24 | 25 | async def fetch(url: str, timeout: float = 5.0) -> None: 26 | url = urllib.parse.urlsplit(url) 27 | if url.hostname is None: 28 | raise ValueError 29 | 30 | port = url.port or (443 if url.scheme == 'https' else 80) 31 | 32 | async with http_pool.connection(endpoint=(url.hostname, port), timeout=timeout) as (reader, writer): 33 | request = ( 34 | 'GET {path} HTTP/1.1\r\n' 35 | 'Host: {host}\r\n' 36 | '\r\n' 37 | '\r\n' 38 | ).format(host=url.hostname, path=url.path) 39 | 40 | writer.write(request.encode()) 41 | await writer.drain() 42 | 43 | status_line = await reader.readuntil(b'\r\n') 44 | headers = await reader.readuntil(b'\r\n\r\n') 45 | headers = { 46 | pair[0].lower(): pair[1] 47 | for header in headers.split(b'\r\n') 48 | if len(pair := header.decode().split(':', maxsplit=1)) == 2 49 | } 50 | 51 | chunks = [] 52 | content_length = int(headers['content-length']) 53 | while content_length: 54 | chunk = await reader.read(content_length) 55 | chunks.append(chunk) 56 | content_length -= len(chunk) 57 | 58 | print(status_line) 59 | print(b''.join(chunks)) 60 | 61 | try: 62 | await fetch('https://en.wikipedia.org/wiki/HTTP') # http connection opened 63 | await fetch('https://en.wikipedia.org/wiki/Python_(programming_language)') # http connection reused 64 | finally: 65 | await http_pool.close() 66 | 67 | asyncio.run(main()) 68 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "generic-connection-pool" 3 | version = "0.7.0" 4 | description = "generic connection pool" 5 | authors = ["Dmitry Pershin "] 6 | license = "Unlicense" 7 | readme = "README.rst" 8 | homepage = "https://github.com/dapper91/generic-connection-pool" 9 | repository = "https://github.com/dapper91/generic-connection-pool" 10 | documentation = "https://generic-connection-pool.readthedocs.io" 11 | keywords = ['pool', 'connection-pool', 'asyncio', 'socket', 'tcp'] 12 | classifiers = [ 13 | "Development Status :: 5 - Production/Stable", 14 | "Intended Audience :: Developers", 15 | "Natural Language
:: English", 16 | "License :: Public Domain", 17 | "Operating System :: OS Independent", 18 | "Topic :: Database", 19 | "Topic :: Internet :: WWW/HTTP", 20 | "Topic :: Software Development :: Libraries", 21 | "Topic :: System :: Networking", 22 | "Programming Language :: Python :: 3", 23 | "Programming Language :: Python :: 3.9", 24 | "Programming Language :: Python :: 3.10", 25 | "Programming Language :: Python :: 3.11", 26 | "Programming Language :: Python :: 3.12", 27 | "Typing :: Typed", 28 | ] 29 | 30 | [tool.poetry.dependencies] 31 | python = ">=3.9" 32 | psycopg2 = {version = ">=2.9.5", optional = true} 33 | asyncpg = {version = "^0.29.0", optional = true} 34 | 35 | furo = {version = "^2023.7.26", optional = true} 36 | Sphinx = {version = "^6.2.1", optional = true} 37 | sphinx_design = {version = "^0.5.0", optional = true} 38 | toml = {version = "^0.10.2", optional = true} 39 | sphinx-copybutton = {version = "^0.5.2", optional = true} 40 | 41 | [tool.poetry.extras] 42 | psycopg2 = ["psycopg2"] 43 | asyncpg = ["asyncpg"] 44 | docs = ['Sphinx', 'sphinx_design', 'sphinx-copybutton', 'furo', 'toml'] 45 | 46 | [tool.poetry.dev-dependencies] 47 | pre-commit = "^3.3.3" 48 | mypy = "^1.4.1" 49 | pytest = "^7.4.0" 50 | pytest-asyncio = "^0.21.1" 51 | pytest-mock = "^3.11.1" 52 | pytest-cov = "^4.1.0" 53 | types-psycopg2 = "^2.9.21" 54 | pytest-timeout = "^2.1.0" 55 | asyncpg-stubs = {version = "^0.29.0", python = ">=3.9,<4.0"} 56 | 57 | 58 | [build-system] 59 | requires = ["poetry-core>=1.0.0"] 60 | build-backend = "poetry.core.masonry.api" 61 | 62 | 63 | [tool.mypy] 64 | allow_redefinition = true 65 | disallow_incomplete_defs = true 66 | disallow_any_generics = true 67 | disallow_untyped_decorators = true 68 | disallow_untyped_defs = true 69 | no_implicit_optional = true 70 | show_error_codes = true 71 | strict_equality = true 72 | warn_unused_ignores = true 73 | -------------------------------------------------------------------------------- /examples/custom_manager.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from ipaddress import IPv4Address 3 | from typing import Optional, Tuple 4 | 5 | from generic_connection_pool.threading import BaseConnectionManager, ConnectionPool 6 | 7 | Port = int 8 | Endpoint = Tuple[IPv4Address, Port] 9 | Connection = socket.socket 10 | 11 | 12 | class RedisConnectionManager(BaseConnectionManager[Endpoint, socket.socket]): 13 | def __init__(self, username: str, password: str): 14 | self._username = username 15 | self._password = password 16 | 17 | def create(self, endpoint: Endpoint, timeout: Optional[float] = None) -> socket.socket: 18 | addr, port = endpoint 19 | 20 | sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM) 21 | sock.connect((str(addr), port)) 22 | 23 | if (resp := self.cmd(sock, f'AUTH {self._username} {self._password}')) != '+OK': 24 | raise RuntimeError(f"authentication failed: {resp}") 25 | 26 | return sock 27 | 28 | def cmd(self, sock: socket.socket, cmd: str) -> str: 29 | sock.sendall(f'{cmd}\r\n'.encode()) 30 | 31 | response = sock.recv(1024) 32 | return response.rstrip(b'\r\n').decode() 33 | 34 | def dispose(self, endpoint: Endpoint, conn: socket.socket, timeout: Optional[float] = None) -> None: 35 | try: 36 | conn.shutdown(socket.SHUT_RDWR) 37 | except OSError: 38 | pass 39 | 40 | conn.close() 41 | 42 | def check_aliveness(self, endpoint: Endpoint, conn: socket.socket, timeout: Optional[float] = None) -> bool: 43 | try: 44 | if self.cmd(conn, 'ping') != 
'+PONG': 45 | return False 46 | except OSError: 47 | return False 48 | 49 | return True 50 | 51 | 52 | redis_pool = ConnectionPool[Endpoint, Connection]( 53 | RedisConnectionManager('', 'secret'), 54 | idle_timeout=30.0, 55 | max_lifetime=600.0, 56 | min_idle=3, 57 | max_size=20, 58 | total_max_size=100, 59 | background_collector=True, 60 | ) 61 | 62 | 63 | def command(addr: IPv4Address, port: int, cmd: str) -> None: 64 | with redis_pool.connection(endpoint=(addr, port), timeout=5.0) as sock: 65 | sock.sendall(cmd.encode() + b'\r\n') 66 | response = sock.recv(1024) 67 | 68 | print(response.decode()) 69 | 70 | 71 | try: 72 | command(IPv4Address('127.0.0.1'), 6379, 'CLIENT ID') # tcp connection opened 73 | command(IPv4Address('127.0.0.1'), 6379, 'CLIENT INFO') # tcp connection reused 74 | finally: 75 | redis_pool.close() 76 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # IDE 132 | .idea 133 | 134 | # poetry 135 | poetry.lock 136 | -------------------------------------------------------------------------------- /tests/test_sync_locks.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import time 3 | 4 | import pytest 5 | 6 | from generic_connection_pool.threading.locks import SharedLock 7 | 8 | 9 | @pytest.mark.timeout(5.0) 10 | def test_shared_mode(): 11 | lock = SharedLock() 12 | 13 | thread_cnt = 5 14 | barrier = threading.Barrier(thread_cnt) 15 | 16 | def acquire_connection() -> None: 17 | with lock.acquired(exclusive=False): 18 | barrier.wait() 19 | 20 | threads = [ 21 | threading.Thread(target=acquire_connection) 22 | for _ in range(thread_cnt) 23 | ] 24 | 25 | for thread in threads: 26 | thread.start() 27 | 28 | for thread in threads: 29 | thread.join() 30 | 31 | 32 | @pytest.mark.parametrize( 33 | 'mode1, mode2', [ 34 | (True, True), 35 | (True, False), 36 | (False, True), 37 | ], 38 | ) 39 | @pytest.mark.timeout(5.0) 40 | def test_exclusive_mode(delay: float, mode1: bool, mode2: bool): 41 | lock = SharedLock() 42 | 43 | locked = threading.Event() 44 | stopped = threading.Event() 45 | 46 | def acquire_connection(exclusive: bool): 47 | with lock.acquired(exclusive=exclusive): 48 | locked.set() 49 | stopped.wait() 50 | 51 | thread1 = threading.Thread(target=acquire_connection, kwargs=dict(exclusive=mode1)) 52 | thread1.start() 53 | 54 | locked.wait() 55 | locked.clear() 56 | thread2 = threading.Thread(target=acquire_connection, kwargs=dict(exclusive=mode2)) 57 | thread2.start() 58 | 59 | time.sleep(delay) 60 | assert not locked.is_set() 61 | 62 | stopped.set() 63 | locked.wait() 64 | assert locked.is_set() 65 | 66 | thread1.join() 67 | thread2.join() 68 | 69 | 70 | @pytest.mark.parametrize( 71 | 'blocking, timeout, mode1, mode2', [ 72 | (False, -1, True, False), 73 | (False, -1, True, True), 74 | (False, -1, False, True), 75 | 76 | (True, 0.01, True, False), 77 | (True, 0.01, True, True), 78 | (True, 0.01, False, True), 79 | ], 80 | ) 81 | @pytest.mark.timeout(5.0) 82 | def test_nonblocking_and_timeout(blocking: bool, timeout: float, mode1: bool, mode2: bool): 83 | lock = SharedLock() 84 | 85 | locked = threading.Event() 86 | stopped = threading.Event() 87 | 88 | def acquire_blocking() -> None: 89 | with lock.acquired(exclusive=mode1): 90 | locked.set() 91 | stopped.wait() 92 | 93 | thread1 = threading.Thread(target=acquire_blocking) 94 | thread1.start() 95 | 96 | locked.wait() 97 | 98 | def acquire_nonblocking() -> None: 99 | assert not lock.acquire(exclusive=mode2, blocking=blocking, timeout=timeout) 100 | 101 | thread2 = threading.Thread(target=acquire_nonblocking) 102 | thread2.start() 103 | thread2.join() 104 | 105 | stopped.set() 106 | thread1.join() 107 | 108 | 109 | @pytest.mark.timeout(5.0) 110 | def test_repr(): 111 | thread_id = threading.get_ident() 112 | lock = SharedLock() 113 | 114 | 
    assert repr(lock) == 'SharedLock<unlocked, exclusive_owner=None, shared_cnt=0>'

    with lock.acquired(exclusive=True):
        assert repr(lock) == f'SharedLock<locked exclusively, exclusive_owner={thread_id}, shared_cnt=0>'

    with lock.acquired(exclusive=False):
        assert repr(lock) == 'SharedLock<locked for share, exclusive_owner=None, shared_cnt=1>'
--------------------------------------------------------------------------------
/tests/test_async_locks.py:
--------------------------------------------------------------------------------
import asyncio

import pytest

from generic_connection_pool.asyncio.locks import SharedLock


@pytest.mark.timeout(5.0)
async def test_shared_mode():
    lock = SharedLock()

    task_cnt = 5
    acquire_cnt = 0
    acquire_event = asyncio.Event()
    finished = asyncio.Event()

    async def acquire_connection():
        nonlocal acquire_cnt

        async with lock.acquired(exclusive=False):
            acquire_cnt += 1
            acquire_event.set()
            await finished.wait()

    tasks = [
        asyncio.create_task(acquire_connection())
        for _ in range(task_cnt)
    ]
    while acquire_cnt < task_cnt:
        await acquire_event.wait()
        acquire_event.clear()

    finished.set()
    await asyncio.gather(*tasks)


@pytest.mark.parametrize(
    'mode1, mode2', [
        (True, True),
        (True, False),
        (False, True),
    ],
)
@pytest.mark.timeout(5.0)
async def test_exclusive_mode(delay: float, mode1: bool, mode2: bool):
    lock = SharedLock()

    locked = asyncio.Event()
    stopped = asyncio.Event()

    async def acquire_connection(exclusive: bool):
        async with lock.acquired(exclusive=exclusive):
            locked.set()
            await stopped.wait()

    task1 = asyncio.create_task(acquire_connection(exclusive=mode1))

    await locked.wait()
    locked.clear()
    task2 = asyncio.create_task(acquire_connection(exclusive=mode2))

    await asyncio.sleep(delay)
    assert not locked.is_set()

    stopped.set()
    await locked.wait()
    assert locked.is_set()

    await task1
    await task2


@pytest.mark.parametrize(
    'timeout, mode1, mode2', [
        (0.00, True, False),
        (0.00, True, True),
        (0.00, False, True),
        (0.01, True, False),
        (0.01, True, True),
        (0.01, False, True),
    ],
)
@pytest.mark.timeout(5.0)
async def test_nonblocking_and_timeout(timeout: float, mode1: bool, mode2: bool):
    lock = SharedLock()

    locked = asyncio.Event()
    stopped = asyncio.Event()

    async def acquire_blocking() -> None:
        async with lock.acquired(exclusive=mode1):
            locked.set()
            await stopped.wait()

    task1 = asyncio.create_task(acquire_blocking())
    await locked.wait()

    async def acquire_nonblocking() -> None:
        assert not await lock.acquire(exclusive=mode2, timeout=timeout)

    task2 = asyncio.create_task(acquire_nonblocking())
    await task2

    stopped.set()
    await task1


@pytest.mark.timeout(5.0)
async def test_repr():
    task_name = asyncio.current_task().get_name()
    lock = SharedLock()

    assert repr(lock) == 'SharedLock<unlocked, exclusive_owner=None, shared_cnt=0>'

    async with lock.acquired(exclusive=True):
        assert repr(lock) == f'SharedLock<locked exclusively, exclusive_owner={task_name}, shared_cnt=0>'

    async with lock.acquired(exclusive=False):
        assert repr(lock) == 'SharedLock<locked for share, exclusive_owner=None, shared_cnt=1>'
--------------------------------------------------------------------------------
/tests/contrib/test_unix_async.py:
-------------------------------------------------------------------------------- 1 | import asyncio 2 | import socket 3 | import sys 4 | import tempfile 5 | from pathlib import Path 6 | from typing import AsyncGenerator, Generator, Optional 7 | 8 | import pytest 9 | 10 | from generic_connection_pool.asyncio import ConnectionPool 11 | from generic_connection_pool.contrib.unix_async import Stream, UnixSocketConnectionManager 12 | from generic_connection_pool.contrib.unix_async import UnixSocketStreamConnectionManager 13 | 14 | if sys.platform not in ('linux', 'darwin', 'freebsd'): 15 | pytest.skip(reason="unix platform only", allow_module_level=True) 16 | 17 | 18 | class EchoUnixServer: 19 | @staticmethod 20 | async def echo_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): 21 | while data := await reader.read(1024): 22 | writer.write(data) 23 | await writer.drain() 24 | 25 | writer.close() 26 | await writer.wait_closed() 27 | 28 | def __init__(self, path: Path): 29 | self._path = path 30 | self._server_task: Optional[asyncio.Task[None]] = None 31 | 32 | async def start(self) -> None: 33 | server = await asyncio.start_unix_server( 34 | self.echo_handler, 35 | path=self._path, 36 | start_serving=False, 37 | ) 38 | self._server_task = asyncio.create_task(server.serve_forever()) 39 | 40 | async def stop(self) -> None: 41 | if (server_task := self._server_task) is not None: 42 | server_task.cancel() 43 | try: 44 | await server_task 45 | except asyncio.CancelledError: 46 | pass 47 | 48 | 49 | @pytest.fixture 50 | async def unix_socket_server(port_gen: Generator[int, None, None]) -> AsyncGenerator[Path, None]: 51 | with tempfile.TemporaryDirectory() as tmp_dir: 52 | path = Path(tmp_dir) / 'unix-socket-server-async.sock' 53 | server = EchoUnixServer(path) 54 | await server.start() 55 | yield path 56 | await server.stop() 57 | 58 | 59 | @pytest.mark.timeout(5.0) 60 | async def test_unix_socket_manager(unix_socket_server: Path): 61 | loop = asyncio.get_running_loop() 62 | path = unix_socket_server 63 | 64 | pool = ConnectionPool[Path, socket.socket](UnixSocketConnectionManager()) 65 | 66 | attempts = 3 67 | request = b'test' 68 | for _ in range(attempts): 69 | async with pool.connection(path) as sock1: 70 | await loop.sock_sendall(sock1, request) 71 | response = await loop.sock_recv(sock1, len(request)) 72 | assert response == request 73 | 74 | async with pool.connection(path) as sock2: 75 | await loop.sock_sendall(sock2, request) 76 | response = await loop.sock_recv(sock2, len(request)) 77 | assert response == request 78 | 79 | await pool.close() 80 | 81 | 82 | @pytest.mark.timeout(5.0) 83 | async def test_unix_stream_manager(resource_dir: Path, unix_socket_server: Path): 84 | path = unix_socket_server 85 | 86 | pool = ConnectionPool[Path, Stream]( 87 | UnixSocketStreamConnectionManager(), 88 | ) 89 | 90 | attempts = 3 91 | request = b'test' 92 | for _ in range(attempts): 93 | async with pool.connection(path) as (reader1, writer1): 94 | writer1.write(request) 95 | await writer1.drain() 96 | response = await reader1.read(len(request)) 97 | assert response == request 98 | 99 | async with pool.connection(path) as (reader2, writer2): 100 | writer2.write(request) 101 | await writer2.drain() 102 | response = await reader2.read(len(request)) 103 | assert response == request 104 | 105 | await pool.close() 106 | -------------------------------------------------------------------------------- /generic_connection_pool/contrib/socket_async.py: 
--------------------------------------------------------------------------------
"""
Asynchronous socket connection manager implementation.
"""

import asyncio
import errno
import socket
from ipaddress import IPv4Address, IPv6Address
from ssl import SSLContext
from typing import Generic, Tuple, Union

from generic_connection_pool.asyncio import BaseConnectionManager, EndpointT

IpAddress = Union[IPv4Address, IPv6Address]
Port = int
TcpEndpoint = Tuple[IpAddress, Port]


class SocketAlivenessCheckingMixin(Generic[EndpointT]):
    """
    Nonblocking socket aliveness checking mix-in.
    """

    async def check_aliveness(self, endpoint: EndpointT, conn: socket.socket) -> bool:
        try:
            if conn.recv(1, socket.MSG_PEEK) == b'':
                return False
        except BlockingIOError as exc:
            if exc.errno != errno.EAGAIN:
                raise
        except OSError:
            return False

        return True


class TcpSocketConnectionManager(
    SocketAlivenessCheckingMixin[TcpEndpoint],
    BaseConnectionManager[TcpEndpoint, socket.socket],
):
    """
    TCP socket connection manager.
    """

    async def create(self, endpoint: TcpEndpoint) -> socket.socket:
        loop = asyncio.get_running_loop()
        addr, port = endpoint

        if addr.version == 4:
            family = socket.AF_INET
        elif addr.version == 6:
            family = socket.AF_INET6
        else:
            raise RuntimeError(f"unsupported address version: {addr.version}")

        sock = socket.socket(family=family, type=socket.SOCK_STREAM)
        sock.setblocking(False)

        await loop.sock_connect(sock, address=(str(addr), port))

        return sock

    async def dispose(self, endpoint: TcpEndpoint, conn: socket.socket) -> None:
        try:
            conn.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass

        conn.close()


Hostname = str
TcpStreamEndpoint = Tuple[Hostname, Port]
Stream = Tuple[asyncio.StreamReader, asyncio.StreamWriter]


class StreamAlivenessCheckingMixin(Generic[EndpointT]):
    """
    Asynchronous stream aliveness checking mix-in.
    """

    async def check_aliveness(self, endpoint: EndpointT, conn: Stream) -> bool:
        reader, writer = conn

        try:
            await reader.read(0)
        except OSError:
            return False

        return not writer.is_closing() and not reader.at_eof()


class TcpStreamConnectionManager(
    StreamAlivenessCheckingMixin[TcpStreamEndpoint],
    BaseConnectionManager[TcpStreamEndpoint, Stream],
):
    """
    TCP stream connection manager.
99 | """ 100 | 101 | def __init__(self, ssl: Union[None, bool, SSLContext] = None): 102 | self._ssl = ssl 103 | 104 | async def create(self, endpoint: TcpStreamEndpoint) -> Stream: 105 | hostname, port = endpoint 106 | server_hostname = hostname if self._ssl is not None else None 107 | 108 | reader, writer = await asyncio.open_connection( 109 | hostname, 110 | port, 111 | server_hostname=server_hostname, 112 | ssl=self._ssl, 113 | ) 114 | 115 | return reader, writer 116 | 117 | async def dispose(self, endpoint: TcpStreamEndpoint, conn: Stream) -> None: 118 | reader, writer = conn 119 | if writer.can_write_eof(): 120 | writer.write_eof() 121 | 122 | writer.close() 123 | await writer.wait_closed() 124 | -------------------------------------------------------------------------------- /generic_connection_pool/threading/locks.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import threading 3 | import time 4 | from typing import Any, Generator, Optional 5 | 6 | 7 | class SharedLock: 8 | """ 9 | Shared lock. Supports shared and exclusive locking modes. 10 | """ 11 | 12 | def __init__(self) -> None: 13 | self._lock = threading.Condition(threading.Lock()) 14 | self._shared_cnt = 0 15 | self._exclusive = False 16 | self._exclusive_owner: Optional[int] = None 17 | 18 | def __repr__(self) -> str: 19 | if self._exclusive: 20 | state = "locked exclusively" 21 | elif self._shared_cnt: 22 | state = "locked for share" 23 | else: 24 | state = "unlocked" 25 | 26 | return "%s" % ( 27 | self.__class__.__qualname__, 28 | state, 29 | self._exclusive_owner, 30 | self._shared_cnt, 31 | ) 32 | 33 | def __enter__(self) -> None: 34 | self.acquire() 35 | 36 | def __exit__(self, *args: Any) -> None: 37 | self.release() 38 | 39 | @contextlib.contextmanager 40 | def acquired( 41 | self, 42 | exclusive: bool = True, 43 | blocking: bool = True, 44 | timeout: Optional[float] = None, 45 | ) -> Generator[None, None, None]: 46 | """ 47 | Opens the lock acquiring context. 48 | """ 49 | 50 | if not self.acquire(exclusive, blocking, timeout): 51 | raise TimeoutError 52 | try: 53 | yield 54 | finally: 55 | self.release(exclusive) 56 | 57 | def acquire(self, exclusive: bool = True, blocking: bool = True, timeout: Optional[float] = None) -> bool: 58 | """ 59 | Acquires the lock. 60 | 61 | :param exclusive: acquire the lock in exclusive mode 62 | :param blocking: when `True` blocks until the lock is acquired 63 | :param timeout: number of seconds during which the lock is tried to be acquired 64 | """ 65 | 66 | if exclusive: 67 | return self._acquire_exclusive(blocking, timeout) 68 | else: 69 | return self._acquire_shared(blocking, timeout) 70 | 71 | def release(self, exclusive: bool = True) -> None: 72 | """ 73 | Releases the acquired lock. 

        :param exclusive: release the lock acquired in exclusive mode
        """

        if exclusive:
            self._release_exclusive()
        else:
            self._release_shared()

    def _acquire_shared(self, blocking: bool = True, timeout: Optional[float] = None) -> bool:
        if not self._lock.acquire(blocking, timeout if timeout is not None else -1):
            return False

        try:
            self._shared_cnt += 1
        finally:
            self._lock.release()

        return True

    def _release_shared(self) -> None:
        if self._shared_cnt == 0:
            raise RuntimeError("lock is not acquired")

        try:
            self._lock.acquire()
        finally:
            self._shared_cnt -= 1

            if not self._shared_cnt:
                self._lock.notify()

            self._lock.release()

    def _acquire_exclusive(self, blocking: bool = True, timeout: Optional[float] = None) -> bool:
        started_at = time.monotonic()

        if not self._lock.acquire(blocking, timeout if timeout is not None else -1):
            return False

        try:
            while self._shared_cnt > 0:
                if not blocking:
                    self._lock.release()
                    return False

                if not self._lock.wait(timeout - (time.monotonic() - started_at) if timeout is not None else None):
                    self._lock.release()
                    return False

        except BaseException:
            self._lock.release()
            raise

        self._exclusive = True
        self._exclusive_owner = threading.get_ident()

        return True

    def _release_exclusive(self) -> None:
        self._exclusive = False
        self._exclusive_owner = None
        self._lock.release()

    def _is_owned(self) -> bool:  # to be compatible with threading.Condition
        return self._exclusive_owner == threading.get_ident()
--------------------------------------------------------------------------------
/generic_connection_pool/asyncio/locks.py:
--------------------------------------------------------------------------------
import asyncio
import contextlib
import time
from typing import Any, AsyncGenerator, Optional

from .utils import guard


def get_ident() -> Optional[str]:
    """
    Returns the current asyncio task identifier.
    """

    if task := asyncio.current_task():
        return task.get_name()

    return None


class SharedLock:
    """
    Asynchronous shared lock. Supports shared and exclusive locking modes.
    """

    def __init__(self) -> None:
        self._lock = asyncio.Condition()
        self._shared_cnt = 0
        self._exclusive = False
        self._exclusive_owner: Optional[str] = None

    def __repr__(self) -> str:
        if self._exclusive:
            state = "locked exclusively"
        elif self._shared_cnt:
            state = "locked for share"
        else:
            state = "unlocked"

        return "%s<%s, exclusive_owner=%s, shared_cnt=%d>" % (
            self.__class__.__qualname__,
            state,
            self._exclusive_owner,
            self._shared_cnt,
        )

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(self, *args: Any) -> None:
        await self.release()

    @contextlib.asynccontextmanager
    async def acquired(
        self,
        exclusive: bool = True,
        timeout: Optional[float] = None,
    ) -> AsyncGenerator[None, None]:
        """
        Opens the lock acquiring context.
60 | """ 61 | 62 | if not await self.acquire(exclusive, timeout): 63 | raise asyncio.TimeoutError 64 | try: 65 | yield 66 | finally: 67 | await self.release(exclusive) 68 | 69 | async def acquire(self, exclusive: bool = True, timeout: Optional[float] = None) -> bool: 70 | """ 71 | Acquires the lock. 72 | 73 | :param exclusive: acquire the lock in exclusive mode 74 | :param timeout: number of seconds during which the lock is tried to be acquired 75 | """ 76 | 77 | if exclusive: 78 | return await self._acquire_exclusive(timeout) 79 | else: 80 | return await self._acquire_shared(timeout) 81 | 82 | async def release(self, exclusive: bool = True) -> None: 83 | """ 84 | Releases the acquired lock. 85 | 86 | :param exclusive: release the lock acquired in exclusive mode 87 | """ 88 | 89 | if exclusive: 90 | self._release_exclusive() 91 | else: 92 | await self._release_shared() 93 | 94 | async def _acquire_shared(self, timeout: Optional[float] = None) -> bool: 95 | try: 96 | await asyncio.wait_for(self._lock.acquire(), timeout=timeout) 97 | except asyncio.TimeoutError: 98 | return False 99 | 100 | try: 101 | self._shared_cnt += 1 102 | finally: 103 | self._lock.release() 104 | 105 | return True 106 | 107 | async def _release_shared(self) -> None: 108 | if self._shared_cnt == 0: 109 | raise RuntimeError("lock is not acquired") 110 | 111 | cancelled = False 112 | try: 113 | await guard(self._lock.acquire()) 114 | except asyncio.CancelledError: 115 | cancelled = True 116 | finally: 117 | self._shared_cnt -= 1 118 | 119 | if not self._shared_cnt: 120 | self._lock.notify() 121 | 122 | self._lock.release() 123 | 124 | if cancelled: 125 | raise asyncio.CancelledError 126 | 127 | async def _acquire_exclusive(self, timeout: Optional[float] = None) -> bool: 128 | started_at = time.monotonic() 129 | 130 | try: 131 | await asyncio.wait_for(self._lock.acquire(), timeout=timeout) 132 | except asyncio.TimeoutError: 133 | return False 134 | 135 | try: 136 | while self._shared_cnt > 0: 137 | timeout = timeout - (time.monotonic() - started_at) if timeout is not None else None 138 | await asyncio.wait_for(self._lock.wait(), timeout=timeout) 139 | 140 | except asyncio.TimeoutError: 141 | self._lock.release() 142 | return False 143 | 144 | except BaseException: 145 | self._lock.release() 146 | raise 147 | 148 | self._exclusive = True 149 | self._exclusive_owner = get_ident() 150 | 151 | return True 152 | 153 | def _release_exclusive(self) -> None: 154 | self._exclusive = False 155 | self._exclusive_owner = None 156 | self._lock.release() 157 | -------------------------------------------------------------------------------- /generic_connection_pool/contrib/socket.py: -------------------------------------------------------------------------------- 1 | """ 2 | Synchronous socket connection manager implementation. 
3 | """ 4 | 5 | import contextlib 6 | import errno 7 | import socket 8 | from ipaddress import IPv4Address, IPv6Address 9 | from ssl import SSLContext, SSLSocket 10 | from typing import Generator, Generic, Optional, Tuple, Union 11 | 12 | from generic_connection_pool.threading import BaseConnectionManager, EndpointT 13 | 14 | IpAddress = Union[IPv4Address, IPv6Address] 15 | Hostname = str 16 | Port = int 17 | TcpEndpoint = Tuple[IpAddress, Port] 18 | 19 | 20 | @contextlib.contextmanager 21 | def socket_nonblocking(sock: socket.socket) -> Generator[None, None, None]: 22 | orig_timeout = sock.gettimeout() 23 | sock.settimeout(0) 24 | try: 25 | yield 26 | finally: 27 | sock.settimeout(orig_timeout) 28 | 29 | 30 | @contextlib.contextmanager 31 | def socket_timeout(sock: socket.socket, timeout: Optional[float]) -> Generator[None, None, None]: 32 | if timeout is None: 33 | yield 34 | return 35 | 36 | orig_timeout = sock.gettimeout() 37 | sock.settimeout(max(timeout, 1e-6)) # if timeout is 0 set it a small value to prevent non-blocking socket mode 38 | try: 39 | yield 40 | except OSError as e: 41 | if 'timed out' in str(e): 42 | raise TimeoutError 43 | else: 44 | raise 45 | finally: 46 | sock.settimeout(orig_timeout) 47 | 48 | 49 | class SocketAlivenessCheckingMixin(Generic[EndpointT]): 50 | """ 51 | Socket aliveness checking mix-in. 52 | """ 53 | 54 | def check_aliveness(self, endpoint: EndpointT, conn: socket.socket, timeout: Optional[float] = None) -> bool: 55 | try: 56 | with socket_nonblocking(conn): 57 | if conn.recv(1, socket.MSG_PEEK) == b'': 58 | return False 59 | except BlockingIOError as exc: 60 | if exc.errno != errno.EAGAIN: 61 | raise 62 | except OSError: 63 | return False 64 | 65 | return True 66 | 67 | 68 | class TcpSocketConnectionManager( 69 | SocketAlivenessCheckingMixin[TcpEndpoint], 70 | BaseConnectionManager[TcpEndpoint, socket.socket], 71 | ): 72 | """ 73 | TCP socket connection manager. 74 | """ 75 | 76 | def create(self, endpoint: TcpEndpoint, timeout: Optional[float] = None) -> socket.socket: 77 | addr, port = endpoint 78 | 79 | if addr.version == 4: 80 | family = socket.AF_INET 81 | elif addr.version == 6: 82 | family = socket.AF_INET6 83 | else: 84 | raise RuntimeError("unsupported address version type: %s", addr.version) 85 | 86 | sock = socket.socket(family=family, type=socket.SOCK_STREAM) 87 | 88 | with socket_timeout(sock, timeout): 89 | sock.connect((str(addr), port)) 90 | 91 | return sock 92 | 93 | def dispose(self, endpoint: TcpEndpoint, conn: socket.socket, timeout: Optional[float] = None) -> None: 94 | try: 95 | conn.shutdown(socket.SHUT_RDWR) 96 | except OSError: 97 | pass 98 | 99 | conn.close() 100 | 101 | 102 | class SslSocketAlivenessCheckingMixin(Generic[EndpointT]): 103 | """ 104 | SSL socket aliveness checking mix-in. 105 | """ 106 | 107 | def check_aliveness(self, endpoint: EndpointT, conn: SSLSocket, timeout: Optional[float] = None) -> bool: 108 | try: 109 | with socket_nonblocking(conn): 110 | # peek into the plain socket since ssl socket doesn't support flags 111 | if socket.socket.recv(conn, 1, socket.MSG_PEEK) == b'': 112 | return False 113 | except BlockingIOError as exc: 114 | if exc.errno != errno.EAGAIN: 115 | raise 116 | except OSError: 117 | return False 118 | 119 | return True 120 | 121 | 122 | SslEndpoint = Tuple[Hostname, Port] 123 | 124 | 125 | class SslSocketConnectionManager( 126 | SslSocketAlivenessCheckingMixin[SslEndpoint], 127 | BaseConnectionManager[SslEndpoint, SSLSocket], 128 | ): 129 | """ 130 | SSL socket connection manager. 
131 | """ 132 | 133 | def __init__(self, ssl: SSLContext): 134 | self._ssl = ssl 135 | 136 | def create(self, endpoint: SslEndpoint, timeout: Optional[float] = None) -> SSLSocket: 137 | hostname, port = endpoint 138 | 139 | sock = self._ssl.wrap_socket(socket.socket(type=socket.SOCK_STREAM), server_hostname=hostname) 140 | 141 | with socket_timeout(sock, timeout): 142 | sock.connect((hostname, port)) 143 | 144 | return sock 145 | 146 | def dispose(self, endpoint: SslEndpoint, conn: SSLSocket, timeout: Optional[float] = None) -> None: 147 | try: 148 | conn.shutdown(socket.SHUT_RDWR) 149 | except OSError: 150 | pass 151 | 152 | conn.close() 153 | -------------------------------------------------------------------------------- /tests/contrib/test_async_socket_manager.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import socket 3 | import ssl 4 | from ipaddress import IPv4Address 5 | from pathlib import Path 6 | from typing import AsyncGenerator, Generator, Optional, Tuple 7 | 8 | import pytest 9 | 10 | from generic_connection_pool.asyncio import ConnectionPool 11 | from generic_connection_pool.contrib.socket_async import Stream, TcpSocketConnectionManager, TcpStreamConnectionManager 12 | 13 | 14 | class EchoTCPServer: 15 | @staticmethod 16 | async def echo_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: 17 | while data := await reader.read(1024): 18 | writer.write(data) 19 | await writer.drain() 20 | 21 | writer.close() 22 | await writer.wait_closed() 23 | 24 | def __init__(self, hostname: str, port: int, ssl_ctx: Optional[ssl.SSLContext] = None): 25 | self._hostname = hostname 26 | self._port = port 27 | self._ssl_ctx = ssl_ctx 28 | self._server_task: Optional[asyncio.Task[None]] = None 29 | 30 | async def start(self) -> None: 31 | server = await asyncio.start_server( 32 | self.echo_handler, 33 | host=self._hostname, 34 | port=self._port, 35 | ssl=self._ssl_ctx, 36 | family=socket.AF_INET, 37 | reuse_port=True, 38 | start_serving=False, 39 | ) 40 | self._server_task = asyncio.create_task(server.serve_forever()) 41 | 42 | async def stop(self) -> None: 43 | if (server_task := self._server_task) is not None: 44 | server_task.cancel() 45 | try: 46 | await server_task 47 | except asyncio.CancelledError: 48 | pass 49 | 50 | 51 | @pytest.fixture 52 | async def tcp_server(port_gen: Generator[int, None, None]) -> AsyncGenerator[Tuple[IPv4Address, int], None]: 53 | addr, port = IPv4Address('127.0.0.1'), next(port_gen) 54 | server = EchoTCPServer(str(addr), port) 55 | await server.start() 56 | yield addr, port 57 | await server.stop() 58 | 59 | 60 | @pytest.fixture 61 | async def ssl_server( 62 | resource_dir: Path, 63 | port_gen: Generator[int, None, None], 64 | ) -> AsyncGenerator[Tuple[str, int], None]: 65 | hostname, port = 'localhost', next(port_gen) 66 | ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) 67 | ssl_ctx.load_cert_chain(resource_dir / 'ssl.cert', resource_dir / 'ssl.key') 68 | 69 | server = EchoTCPServer(hostname, port, ssl_ctx=ssl_ctx) 70 | await server.start() 71 | yield hostname, port 72 | await server.stop() 73 | 74 | 75 | @pytest.mark.timeout(5.0) 76 | async def test_tcp_socket_manager(tcp_server: Tuple[IPv4Address, int]): 77 | loop = asyncio.get_running_loop() 78 | addr, port = tcp_server 79 | 80 | pool = ConnectionPool[Tuple[IPv4Address, int], socket.socket]( 81 | TcpSocketConnectionManager(), 82 | ) 83 | 84 | attempts = 3 85 | request = b'test' 86 | for _ in range(attempts): 
        async with pool.connection((addr, port)) as sock1:
            await loop.sock_sendall(sock1, request)
            response = await loop.sock_recv(sock1, len(request))
            assert response == request

        async with pool.connection((addr, port)) as sock2:
            await loop.sock_sendall(sock2, request)
            response = await loop.sock_recv(sock2, len(request))
            assert response == request

    await pool.close()


@pytest.mark.timeout(5.0)
async def test_tcp_stream_manager(resource_dir: Path, tcp_server: Tuple[IPv4Address, int]):
    addr, port = tcp_server

    pool = ConnectionPool[Tuple[str, int], Stream](
        TcpStreamConnectionManager(ssl=None),
    )

    attempts = 3
    request = b'test'
    for _ in range(attempts):
        async with pool.connection((str(addr), port)) as (reader1, writer1):
            writer1.write(request)
            await writer1.drain()
            response = await reader1.read(len(request))
            assert response == request

        async with pool.connection((str(addr), port)) as (reader2, writer2):
            writer2.write(request)
            await writer2.drain()
            response = await reader2.read(len(request))
            assert response == request

    await pool.close()


@pytest.mark.timeout(5.0)
async def test_ssl_stream_manager(resource_dir: Path, ssl_server: Tuple[str, int]):
    hostname, port = ssl_server
    ssl_context = ssl.create_default_context(cafile=resource_dir / 'ssl.cert')

    pool = ConnectionPool[Tuple[str, int], Stream](
        TcpStreamConnectionManager(ssl_context),
    )

    attempts = 3
    request = b'test'
    for _ in range(attempts):
        async with pool.connection((hostname, port)) as (reader1, writer1):
            writer1.write(request)
            await writer1.drain()
            response = await reader1.read(len(request))
            assert response == request

        async with pool.connection((hostname, port)) as (reader2, writer2):
            writer2.write(request)
            await writer2.drain()
            response = await reader2.read(len(request))
            assert response == request

    await pool.close()
--------------------------------------------------------------------------------
/generic_connection_pool/rankmap.py:
--------------------------------------------------------------------------------
from typing import Any, Dict, Hashable, Iterator, List, MutableMapping, Optional, Protocol, Tuple, TypeVar


class ComparableP(Protocol):
    def __lt__(self, other: Any) -> bool: ...
    def __eq__(self, other: Any) -> bool: ...


Key = TypeVar('Key', bound=Hashable)
Value = TypeVar('Value', bound=ComparableP)


class RankMap(MutableMapping[Key, Value]):
    """
    A combined hash-map / binary-heap data structure that additionally supports the following operations:

    - top: returns the element with the lowest value in O(1)
    - next: removes the element with the lowest value in O(log(n))
    """

    def __init__(self) -> None:
        self._heap: List[Tuple[Value, Key]] = []
        self._index: Dict[Key, int] = {}

    def __len__(self) -> int:
        return len(self._heap)

    def __bool__(self) -> bool:
        return bool(self._heap)

    def __getitem__(self, key: Key) -> Value:
        idx = self._index[key]
        value, key = self._heap[idx]

        return value

    def __setitem__(self, key: Key, value: Value) -> None:
        self.insert_or_replace(key, value)

    def __delitem__(self, key: Key) -> None:
        self.remove(key)

    def __iter__(self) -> Iterator[Key]:
        return (key for value, key in self._heap)

    def clear(self) -> None:
        """
        Removes all items from the map.
        """

        self._heap.clear()
        self._index.clear()

    def insert(self, key: Key, value: Value) -> None:
        """
        Inserts an item into the map.
        If an item with the provided key is already present, raises `KeyError`.

        :param key: item key
        :param value: item value
        """

        if key in self._index:
            raise KeyError("key already exists")

        item_idx = len(self._heap)
        self._heap.append((value, key))
        self._index[key] = item_idx

        self._siftdown(item_idx)

    def insert_or_replace(self, key: Key, value: Value) -> Optional[Value]:
        """
        Inserts an item into the map or replaces it if it already exists.

        :param key: item key
        :param value: item value
        :return: previous item value or `None`
        """

        if key in self._index:
            prev = self.remove(key)
        else:
            prev = None

        self.insert(key, value)

        return prev

    def next(self) -> Optional[Value]:
        """
        Pops the smallest item from the map.

        :return: the smallest item value
        """

        if len(self._heap) == 0:
            return None

        value, key = self._heap[0]
        return self.remove(key)

    def top(self) -> Optional[Value]:
        """
        Returns the smallest item from the map.
105 | 106 | :return: the smallest item 107 | """ 108 | 109 | if len(self._heap) != 0: 110 | first_value, first_key = self._heap[0] 111 | return first_value 112 | else: 113 | return None 114 | 115 | def remove(self, key: Key) -> Optional[Value]: 116 | """ 117 | Removes an item from the map 118 | 119 | :param key: item key 120 | :returns: removed item value 121 | """ 122 | 123 | if (idx := self._index.get(key)) is None: 124 | return None 125 | 126 | self._swap_heap_elements(idx, len(self._heap) - 1) 127 | 128 | value, key = self._heap.pop() 129 | self._index.pop(key) 130 | 131 | if idx < len(self._heap): 132 | self._siftdown(idx) 133 | self._siftup(idx) 134 | 135 | return value 136 | 137 | def _swap_heap_elements(self, idx1: int, idx2: int) -> None: 138 | key1 = self._heap[idx1][1] 139 | key2 = self._heap[idx2][1] 140 | 141 | self._heap[idx1], self._heap[idx2] = self._heap[idx2], self._heap[idx1] 142 | self._index[key1] = idx2 143 | self._index[key2] = idx1 144 | 145 | def _siftup(self, idx: int) -> None: 146 | value, key = self._heap[idx] 147 | 148 | left_child_idx = 2 * idx + 1 149 | while left_child_idx < len(self._heap): 150 | right_child_idx = left_child_idx + 1 151 | if right_child_idx >= len(self._heap) or self._heap[left_child_idx][0] < self._heap[right_child_idx][0]: 152 | min_child_idx = left_child_idx 153 | else: 154 | min_child_idx = right_child_idx 155 | 156 | child_value, child_key = self._heap[min_child_idx] 157 | if value < child_value or value == child_value: 158 | break 159 | 160 | self._heap[idx] = child_value, child_key 161 | self._index[child_key] = idx 162 | 163 | idx = min_child_idx 164 | left_child_idx = 2 * idx + 1 165 | 166 | self._heap[idx] = value, key 167 | self._index[key] = idx 168 | 169 | def _siftdown(self, idx: int) -> None: 170 | value, key = self._heap[idx] 171 | 172 | while idx > 0: 173 | parent_idx = (idx - 1) // 2 174 | parent_value, parent_key = self._heap[parent_idx] 175 | if value < parent_value: 176 | self._heap[idx] = parent_value, parent_key 177 | self._index[parent_key] = idx 178 | idx = parent_idx 179 | else: 180 | break 181 | 182 | self._heap[idx] = value, key 183 | self._index[key] = idx 184 | -------------------------------------------------------------------------------- /tests/contrib/test_socket_manager.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import socketserver 3 | import ssl 4 | import threading 5 | from ipaddress import IPv4Address 6 | from pathlib import Path 7 | from typing import Generator, Optional, Tuple 8 | 9 | import pytest 10 | 11 | from generic_connection_pool.contrib.socket import SslSocketConnectionManager, TcpSocketConnectionManager 12 | from generic_connection_pool.threading import ConnectionPool 13 | 14 | 15 | class EchoTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer): 16 | allow_reuse_address = True 17 | 18 | class EchoRequestHandler(socketserver.BaseRequestHandler): 19 | def handle(self) -> None: 20 | while data := self.request.recv(1024): 21 | self.request.sendall(data) 22 | 23 | def __init__(self, server_address: Tuple[str, int]): 24 | super().__init__(server_address, self.EchoRequestHandler) 25 | self._server_thread: Optional[threading.Thread] = None 26 | 27 | def start(self) -> None: 28 | self._server_thread = threading.Thread(target=self.serve_forever) 29 | self._server_thread.daemon = True 30 | self._server_thread.start() 31 | 32 | def stop(self) -> None: 33 | self.server_close() 34 | self.shutdown() 35 | if (server_thread := 
self._server_thread) is not None: 36 | server_thread.join() 37 | 38 | 39 | class EchoSSLServer(EchoTCPServer): 40 | def __init__(self, server_address: Tuple[str, int], ssl_ctx: ssl.SSLContext): 41 | super().__init__(server_address) 42 | self._ssl_ctx = ssl_ctx 43 | 44 | def get_request(self) -> Tuple[socket.socket, Tuple[str, int]]: 45 | sock, addr = super().get_request() 46 | ssl_socket = self._ssl_ctx.wrap_socket(sock, server_side=True) 47 | return ssl_socket, addr 48 | 49 | 50 | @pytest.fixture(scope='module') 51 | def tcp_server(port_gen: Generator[int, None, None]) -> Generator[Tuple[IPv4Address, int], None, None]: 52 | addr, port = IPv4Address('127.0.0.1'), next(port_gen) 53 | 54 | server = EchoTCPServer((str(addr), port)) 55 | server.start() 56 | yield addr, port 57 | server.stop() 58 | 59 | 60 | @pytest.fixture(scope='module') 61 | def ssl_server(resource_dir: Path, port_gen: Generator[int, None, None]) -> Generator[Tuple[str, int], None, None]: 62 | hostname, port = 'localhost', next(port_gen) 63 | 64 | ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER) 65 | ssl_ctx.load_cert_chain(keyfile=resource_dir / 'ssl.key', certfile=resource_dir / 'ssl.cert') 66 | 67 | server = EchoSSLServer((hostname, port), ssl_ctx=ssl_ctx) 68 | server.start() 69 | yield hostname, port 70 | server.stop() 71 | 72 | 73 | @pytest.mark.timeout(5.0) 74 | def test_tcp_socket_manager(tcp_server: Tuple[IPv4Address, int]): 75 | addr, port = tcp_server 76 | 77 | pool = ConnectionPool[Tuple[IPv4Address, int], socket.socket]( 78 | TcpSocketConnectionManager(), 79 | ) 80 | 81 | attempts = 3 82 | request = b'test' 83 | for _ in range(attempts): 84 | with pool.connection((addr, port)) as sock1: 85 | sock1.sendall(request) 86 | response = sock1.recv(len(request)) 87 | assert response == request 88 | 89 | with pool.connection((addr, port)) as sock2: 90 | sock2.sendall(request) 91 | response = sock2.recv(len(request)) 92 | assert response == request 93 | 94 | pool.close() 95 | 96 | 97 | @pytest.mark.timeout(5.0) 98 | def test_ssl_socket_manager(resource_dir: Path, ssl_server: Tuple[str, int]): 99 | hostname, port = ssl_server 100 | ssl_context = ssl.create_default_context(cafile=resource_dir / 'ssl.cert') 101 | 102 | pool = ConnectionPool[Tuple[str, int], socket.socket]( 103 | SslSocketConnectionManager(ssl_context), 104 | ) 105 | 106 | attempts = 3 107 | request = b'test' 108 | for _ in range(attempts): 109 | with pool.connection((hostname, port)) as sock1: 110 | sock1.sendall(request) 111 | response = sock1.recv(len(request)) 112 | assert response == request 113 | 114 | with pool.connection((hostname, port)) as sock2: 115 | sock2.sendall(request) 116 | response = sock2.recv(len(request)) 117 | assert response == request 118 | 119 | pool.close() 120 | 121 | 122 | @pytest.mark.timeout(5.0) 123 | def test_tcp_socket_manager_timeout(delay: float, port_gen: Generator[int, None, None]): 124 | addr, port = IPv4Address('127.0.0.1'), next(port_gen) 125 | 126 | server_sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM) 127 | server_sock.bind((str(addr), port)) 128 | server_sock.listen(1) 129 | 130 | pool = ConnectionPool[Tuple[IPv4Address, int], socket.socket]( 131 | TcpSocketConnectionManager(), 132 | ) 133 | with pytest.raises(TimeoutError): 134 | with pool.connection(endpoint=(addr, port), timeout=delay): 135 | with pool.connection(endpoint=(addr, port), timeout=delay): 136 | with pool.connection(endpoint=(addr, port), timeout=delay): 137 | pass 138 | 139 | pool.close() 140 | server_sock.close() 141 | 


@pytest.mark.timeout(5.0)
def test_ssl_socket_manager_timeout(delay: float, port_gen: Generator[int, None, None]):
    hostname, port = 'localhost', next(port_gen)

    server_sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
    server_sock.bind((hostname, port))
    server_sock.listen(1)

    pool = ConnectionPool[Tuple[str, int], socket.socket](
        SslSocketConnectionManager(ssl.create_default_context()),
    )
    with pytest.raises(TimeoutError):
        with pool.connection(endpoint=(hostname, port), timeout=delay):
            with pool.connection(endpoint=(hostname, port), timeout=delay):
                with pool.connection(endpoint=(hostname, port), timeout=delay):
                    pass

    pool.close()
    server_sock.close()
--------------------------------------------------------------------------------
/docs/source/pages/quickstart.rst:
--------------------------------------------------------------------------------
.. _quickstart:


Quickstart
~~~~~~~~~~

Runtime
_______

``generic-connection-pool`` supports synchronous and asynchronous connection pools. To instantiate one,
create :py:class:`~generic_connection_pool.threading.ConnectionPool` for the threading runtime or
:py:class:`~generic_connection_pool.asyncio.ConnectionPool` for the asynchronous runtime:

.. code-block:: python

    import socket
    from ipaddress import IPv4Address
    from typing import Tuple

    from generic_connection_pool.contrib.socket import TcpSocketConnectionManager
    from generic_connection_pool.threading import ConnectionPool

    Port = int
    Endpoint = Tuple[IPv4Address, Port]
    Connection = socket.socket

    pool = ConnectionPool[Endpoint, Connection](
        TcpSocketConnectionManager(),
        idle_timeout=30.0,
        max_lifetime=600.0,
        min_idle=3,
        max_size=20,
        total_max_size=100,
        background_collector=True,
    )


Connection manager
__________________

The connection pool implements logic common to all connection types. Logic specific to a particular
connection type is implemented by a connection manager, which defines the connection lifecycle
(how to create and dispose of a connection and how to check its aliveness).
For more information see :py:class:`~generic_connection_pool.threading.BaseConnectionManager`
or :py:class:`~generic_connection_pool.asyncio.BaseConnectionManager`:

.. code-block:: python

    IpAddress = IPv4Address
    Port = int
    TcpEndpoint = Tuple[IpAddress, Port]

    class TcpSocketConnectionManager(BaseConnectionManager[TcpEndpoint, socket.socket]):
        def create(self, endpoint: TcpEndpoint, timeout: Optional[float] = None) -> socket.socket:
            addr, port = endpoint

            sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
            sock.connect((str(addr), port))

            return sock

        def dispose(self, endpoint: TcpEndpoint, conn: socket.socket, timeout: Optional[float] = None) -> None:
            conn.shutdown(socket.SHUT_RDWR)
            conn.close()


You can use one of the predefined connection managers or implement your own;
a sketch of an asynchronous variant is shown below.
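
The bundled asynchronous managers live in ``generic_connection_pool.contrib.socket_async``; in the
asynchronous runtime the manager methods are coroutines. The following is a condensed sketch
(simplified from the bundled ``TcpSocketConnectionManager``, error handling omitted), not a
complete implementation:

.. code-block:: python

    import asyncio
    import socket

    from generic_connection_pool.asyncio import BaseConnectionManager

    class AsyncTcpSocketConnectionManager(BaseConnectionManager[TcpEndpoint, socket.socket]):
        async def create(self, endpoint: TcpEndpoint) -> socket.socket:
            addr, port = endpoint

            sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
            sock.setblocking(False)  # the asyncio event loop requires non-blocking sockets
            await asyncio.get_running_loop().sock_connect(sock, (str(addr), port))

            return sock

        async def dispose(self, endpoint: TcpEndpoint, conn: socket.socket) -> None:
            conn.close()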
How to implement a custom connection manager from scratch is described in the
custom connection manager section below.


Connection acquiring
____________________

After the pool is instantiated, a connection can be acquired using :py:meth:`~generic_connection_pool.threading.ConnectionPool.acquire`
and released using :py:meth:`~generic_connection_pool.threading.ConnectionPool.release`.
To get rid of boilerplate code the connection pool also supports automated connection acquiring through
:py:meth:`~generic_connection_pool.threading.ConnectionPool.connection`, which returns a context manager:


.. code-block:: python

    with pool.connection(endpoint=(addr, port), timeout=5.0) as sock:
        sock.sendall(...)
        response = sock.recv(...)


Connection aliveness checks
___________________________

The connection manager provides an API for connection aliveness checks.
To implement such a check, override the method :py:meth:`~generic_connection_pool.threading.BaseConnectionManager.check_aliveness`.
The method must return ``True`` if the connection is alive and ``False`` otherwise:

.. code-block:: python

    DbEndpoint = str
    Connection = psycopg2.extensions.connection

    class DbConnectionManager(BaseConnectionManager[DbEndpoint, Connection]):
        ...

        def check_aliveness(self, endpoint: DbEndpoint, conn: Connection, timeout: Optional[float] = None) -> bool:
            try:
                with conn.cursor() as cur:
                    cur.execute("SELECT 1;")
                    cur.fetchone()
            except (psycopg2.Error, OSError):
                return False

            return True



Custom connection manager
_________________________

To implement a custom connection manager you must override two methods:
:py:meth:`~generic_connection_pool.threading.BaseConnectionManager.create` and
:py:meth:`~generic_connection_pool.threading.BaseConnectionManager.dispose`.
The other methods are optional.

The following example illustrates how to implement a custom connection manager for Redis.

.. literalinclude:: ../../../examples/custom_manager.py
    :language: python


A connection manager can also define methods to be called when a connection is acquired, released,
or found to be dead. This helps to log pool actions or collect metrics.
The following example illustrates how to collect pool metrics and export them to Prometheus.

.. literalinclude:: ../../../examples/manager_hooks.py
    :language: python


Examples
________


TCP connection pool
...................

The following example illustrates how to use a synchronous tcp connection pool.

.. literalinclude:: ../../../examples/tcp_pool.py
    :language: python


SSL connection pool
...................

The library also provides an ssl-based connection manager.
The following example illustrates how to create an ssl connection pool.

.. literalinclude:: ../../../examples/ssl_pool.py
    :language: python


Asynchronous connection pool
............................

The asynchronous connection pool API looks much the same:

.. literalinclude:: ../../../examples/async_pool.py
    :language: python


DB connection pool
..................

The connection pool can manage any connection type, including database connections.
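As a quick illustration (a minimal sketch; the endpoint name and DSN parameters below are
made up for the example), acquiring a database connection is no different from acquiring a socket:

.. code-block:: python

    import psycopg2.extensions

    from generic_connection_pool.contrib.psycopg2 import DbConnectionManager
    from generic_connection_pool.threading import ConnectionPool

    pg_pool = ConnectionPool[str, psycopg2.extensions.connection](
        DbConnectionManager(
            # endpoint name -> DSN parameters (illustrative values)
            dsn_params={'master': dict(dbname='postgres', user='postgres', password='secret')},
        ),
        background_collector=True,
    )

    with pg_pool.connection(endpoint='master') as conn:
        cur = conn.cursor()
        cur.execute("SELECT 1;")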
The library provides a connection manager for PostgreSQL:
:py:class:`~generic_connection_pool.contrib.psycopg2.DbConnectionManager`.

.. literalinclude:: ../../../examples/pg_pool.py
    :language: python
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
=======================
generic-connection-pool
=======================

.. image:: https://static.pepy.tech/personalized-badge/generic-connection-pool?period=month&units=international_system&left_color=grey&right_color=orange&left_text=Downloads/month
    :target: https://pepy.tech/project/generic-connection-pool
    :alt: Downloads/month
.. image:: https://github.com/dapper91/generic-connection-pool/actions/workflows/test.yml/badge.svg?branch=master
    :target: https://github.com/dapper91/generic-connection-pool/actions/workflows/test.yml
    :alt: Build status
.. image:: https://img.shields.io/pypi/l/generic-connection-pool.svg
    :target: https://pypi.org/project/generic-connection-pool
    :alt: License
.. image:: https://img.shields.io/pypi/pyversions/generic-connection-pool.svg
    :target: https://pypi.org/project/generic-connection-pool
    :alt: Supported Python versions
.. image:: https://codecov.io/gh/dapper91/generic-connection-pool/branch/master/graph/badge.svg
    :target: https://codecov.io/gh/dapper91/generic-connection-pool
    :alt: Code coverage
.. image:: https://readthedocs.org/projects/generic-connection-pool/badge/?version=stable&style=flat
    :alt: ReadTheDocs status
    :target: https://generic-connection-pool.readthedocs.io/en/stable/


``generic-connection-pool`` is an extensible connection pool agnostic to the connection type it is managing.
It can be used for TCP, HTTP, database or SSH connections.

Features
--------

- **generic nature**: can be used for any connection you desire (TCP, HTTP, database, SSH, etc.)
- **runtime agnostic**: synchronous and asynchronous runtimes supported
- **flexibility**: flexible connection retention and recycling policy
- **fully-typed**: mypy type-checker compatible


Getting started
---------------

The connection pool supports the following configuration parameters:

* **background_collector**: if ``True``, starts a background worker that disposes of expired and idle connections,
  maintaining the requested pool state. If ``False``, the connections are disposed of on each connection release.
* **dispose_batch_size**: maximum number of expired and idle connections to be disposed of on connection release
  (if the background collector is started the parameter is ignored).
* **idle_timeout**: inactivity time (in seconds) after which an extra connection is disposed of
  (a connection is considered extra if the number of endpoint connections exceeds ``min_idle``).
* **max_lifetime**: number of seconds after which any connection is disposed of.
* **min_idle**: minimum number of connections the pool tries to hold for each endpoint. Connections that exceed
  that number are considered extra and are disposed of after ``idle_timeout`` seconds of inactivity.
* **max_size**: maximum number of connections per endpoint.
* **total_max_size**: maximum number of all connections in the pool.

.. image:: /static/connection-pool.svg
    :width: 1024
    :alt: Connection Pool parameters


The following example illustrates how to create an https pool:

.. code-block:: python

    import socket
    import ssl
    import urllib.parse
    from http.client import HTTPResponse
    from typing import Tuple

    from generic_connection_pool.contrib.socket import SslSocketConnectionManager
    from generic_connection_pool.threading import ConnectionPool

    Hostname = str
    Port = int
    Endpoint = Tuple[Hostname, Port]
    Connection = socket.socket


    http_pool = ConnectionPool[Endpoint, Connection](
        SslSocketConnectionManager(ssl.create_default_context()),
        idle_timeout=30.0,
        max_lifetime=600.0,
        min_idle=3,
        max_size=20,
        total_max_size=100,
        background_collector=True,
    )


    def fetch(url: str, timeout: float = 5.0) -> None:
        url = urllib.parse.urlsplit(url)
        port = url.port or (443 if url.scheme == 'https' else 80)

        with http_pool.connection(endpoint=(url.hostname, port), timeout=timeout) as sock:
            request = (
                'GET {path} HTTP/1.1\r\n'
                'Host: {host}\r\n'
                '\r\n'
            ).format(host=url.hostname, path=url.path)

            sock.write(request.encode())

            response = HTTPResponse(sock)
            response.begin()
            status, body = response.getcode(), response.read(response.length)

        print(status)
        print(body)


    try:
        fetch('https://en.wikipedia.org/wiki/HTTP')  # http connection opened
        fetch('https://en.wikipedia.org/wiki/Python_(programming_language)')  # http connection reused
    finally:
        http_pool.close()

... or a database one:

.. code-block:: python

    import psycopg2.extensions

    from generic_connection_pool.contrib.psycopg2 import DbConnectionManager
    from generic_connection_pool.threading import ConnectionPool

    Endpoint = str
    Connection = psycopg2.extensions.connection


    dsn_params = dict(dbname='postgres', user='postgres', password='secret')

    pg_pool = ConnectionPool[Endpoint, Connection](
        DbConnectionManager(
            dsn_params={
                'master': dict(dsn_params, host='db-master.local'),
                'replica-1': dict(dsn_params, host='db-replica-1.local'),
                'replica-2': dict(dsn_params, host='db-replica-2.local'),
            },
        ),
        acquire_timeout=2.0,
        idle_timeout=60.0,
        max_lifetime=600.0,
        min_idle=3,
        max_size=10,
        total_max_size=15,
        background_collector=True,
    )

    try:
        # connection opened
        with pg_pool.connection(endpoint='master') as conn:
            cur = conn.cursor()
            cur.execute("SELECT * FROM pg_stats;")
            print(cur.fetchone())

        # connection opened
        with pg_pool.connection(endpoint='replica-1') as conn:
            cur = conn.cursor()
            cur.execute("SELECT * FROM pg_stats;")
            print(cur.fetchone())

        # connection reused
        with pg_pool.connection(endpoint='master') as conn:
            cur = conn.cursor()
            cur.execute("SELECT * FROM pg_stats;")
            print(cur.fetchone())

    finally:
        pg_pool.close()


See `documentation <https://generic-connection-pool.readthedocs.io/en/stable/>`_ for more details.
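
Asynchronous runtime
--------------------

The asynchronous pool mirrors the synchronous API; pool methods are awaitable. A minimal sketch,
assuming a service listening on 127.0.0.1:8080 (the endpoint is illustrative):

.. code-block:: python

    import asyncio
    import socket
    from ipaddress import IPv4Address
    from typing import Tuple

    from generic_connection_pool.asyncio import ConnectionPool
    from generic_connection_pool.contrib.socket_async import TcpSocketConnectionManager


    async def main() -> None:
        pool = ConnectionPool[Tuple[IPv4Address, int], socket.socket](
            TcpSocketConnectionManager(),
            background_collector=True,
        )

        loop = asyncio.get_running_loop()
        # the manager returns non-blocking sockets, so the loop's socket methods are used
        async with pool.connection((IPv4Address('127.0.0.1'), 8080)) as sock:
            await loop.sock_sendall(sock, b'ping')

        await pool.close()


    asyncio.run(main())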

--------------------------------------------------------------------------------
/generic_connection_pool/common.py:
--------------------------------------------------------------------------------
import dataclasses as dc
import logging
import time
from collections import OrderedDict
from enum import IntEnum
from typing import Dict, Generic, Hashable, Optional, Tuple, TypeVar

from . import exceptions
from .rankmap import RankMap

logger = logging.getLogger(__package__)

EndpointT = TypeVar('EndpointT', bound=Hashable)
ConnectionT = TypeVar('ConnectionT', bound=Hashable)


class Timer:
    """
    Timer.

    :param timeout: timer timeout
    """

    def __init__(self, timeout: Optional[float]):
        self._timeout = timeout
        self._started_at = time.monotonic()

    @property
    def elapsed(self) -> float:
        """
        Returns the number of seconds since the timer was created.
        """

        return time.monotonic() - self._started_at

    @property
    def remains(self) -> Optional[float]:
        """
        Returns the number of seconds until the timeout.
        """

        if self._timeout is None:
            return None

        return max(0.0, self._timeout - self.elapsed)

    @property
    def timedout(self) -> bool:
        """
        Returns `True` if the timer has timed out.
        """

        if self._timeout is None:
            return False

        return self.elapsed > self._timeout


@dc.dataclass
class ConnectionInfo(Generic[EndpointT, ConnectionT]):
    endpoint: EndpointT
    conn: ConnectionT
    accessed_at: float = dc.field(default_factory=time.monotonic)
    created_at: float = dc.field(default_factory=time.monotonic)
    acquires: int = 0

    @property
    def lifetime(self) -> float:
        return time.monotonic() - self.created_at

    @property
    def idletime(self) -> float:
        return time.monotonic() - self.accessed_at


class BaseEndpointPool(Generic[EndpointT, ConnectionT]):
    """
    Base endpoint pool. Contains information about connections related to a particular endpoint.
    """

    def __init__(self, max_pool_size: int, max_extra_size: int):
        """
        :param max_pool_size: connection pool max size
        :param max_extra_size: extra connection pool max size
        """

        self._max_pool_size = max_pool_size
        self._max_extra_size = max_extra_size

        # connections from this pool are kept open until their lifetime expires;
        # they are acquired using a FIFO (round-robin) strategy
        self._pool: OrderedDict[ConnectionT, ConnectionInfo[EndpointT, ConnectionT]] = OrderedDict()
        # connections from this pool are closed after their idle-time expires; they are acquired using
        # a LIFO strategy to recycle extra connections as soon as possible since they are recycled
        # based on last access time
        self._extra: OrderedDict[ConnectionT, ConnectionInfo[EndpointT, ConnectionT]] = OrderedDict()
        # acquired connections
        self._acquired: Dict[ConnectionT, ConnectionInfo[EndpointT, ConnectionT]] = dict()
        # number of reserved connections (to be established)
        self._reserved = 0

    @property
    def max_size(self) -> int:
        """
        Returns the pool max size.
        """

        return self._max_pool_size + self._max_extra_size

    def _has_available_slot(self) -> bool:
        """
        Returns `True` if the pool has available connection slots, otherwise `False`.
112 | """ 113 | 114 | return self._size() < self.max_size 115 | 116 | def _is_overflowed(self) -> bool: 117 | """ 118 | Returns `True` if the pool has extra connections otherwise `False`. 119 | """ 120 | 121 | return self._size() > self._max_pool_size 122 | 123 | def _size(self) -> int: 124 | return len(self._pool) + len(self._extra) + len(self._acquired) + self._reserved 125 | 126 | def _get_size(self, acquired: Optional[bool] = None) -> int: 127 | """ 128 | Returns the number of connections. 129 | 130 | :param acquired: if `True` - return the number of acquired connections 131 | if `False` - return the number of free connections 132 | if `None` - return all connections number (including reserved) 133 | 134 | return: number of connections 135 | """ 136 | 137 | if acquired is None: 138 | result = self._size() 139 | elif acquired is True: 140 | result = len(self._acquired) 141 | elif acquired is False: 142 | result = len(self._pool) + len(self._extra) 143 | else: 144 | raise AssertionError("unreachable") 145 | 146 | return result 147 | 148 | def _reserve(self) -> bool: 149 | """ 150 | Reserve one slot in the pool. 151 | """ 152 | 153 | if not self._has_available_slot(): 154 | return False 155 | 156 | self._reserved += 1 157 | 158 | return True 159 | 160 | def _unreserve(self) -> None: 161 | """ 162 | Un-reserve a pool slot. 163 | """ 164 | 165 | assert self._reserved != 0 166 | self._reserved -= 1 167 | 168 | def _acquire(self) -> Tuple[Optional[ConnectionInfo[EndpointT, ConnectionT]], bool]: 169 | """ 170 | Acquires a connection from the pool. 171 | """ 172 | 173 | if self._pool: 174 | extra = False 175 | conn, conn_info = self._pool.popitem(last=False) 176 | elif self._extra: 177 | extra = True 178 | conn, conn_info = self._extra.popitem(last=True) 179 | else: 180 | return None, False 181 | 182 | self._acquired[conn] = conn_info 183 | 184 | conn_info.acquires += 1 185 | conn_info.accessed_at = time.monotonic() 186 | 187 | return conn_info, extra 188 | 189 | def _release(self, conn: ConnectionT) -> Tuple[ConnectionInfo[EndpointT, ConnectionT], bool]: 190 | """ 191 | Releases a connection. 192 | 193 | :param conn: connection to be released 194 | """ 195 | 196 | assert conn in self._acquired, "connection is not acquired" 197 | 198 | conn_info = self._acquired.pop(conn) 199 | conn_info.accessed_at = time.monotonic() 200 | 201 | free_pool_slots = self._max_pool_size - len(self._pool) 202 | if len(self._acquired) >= free_pool_slots: 203 | extra = True 204 | self._extra[conn] = conn_info 205 | elif len(self._pool) < self._max_pool_size: 206 | extra = False 207 | self._pool[conn] = conn_info 208 | else: 209 | raise AssertionError("unreachable") 210 | 211 | return conn_info, extra 212 | 213 | def _detach(self, conn: ConnectionT, acquired: bool = False) -> ConnectionInfo[EndpointT, ConnectionT]: 214 | """ 215 | Detaches off a connection. 216 | 217 | :param conn: connection to be detached 218 | :param acquired: is the connection acquired 219 | """ 220 | 221 | if acquired: 222 | conn_info = self._acquired.pop(conn) 223 | else: 224 | if conn in self._pool: 225 | conn_info = self._pool.pop(conn) 226 | else: 227 | conn_info = self._extra.pop(conn) 228 | 229 | return conn_info 230 | 231 | def _attach(self, conn_info: ConnectionInfo[EndpointT, ConnectionT], acquired: bool = False) -> None: 232 | """ 233 | Attaches a connection to the pool. 
234 | 235 | :param conn_info: connection information to be attached 236 | :param acquired: acquire the connection 237 | """ 238 | 239 | if not self._has_available_slot(): 240 | raise exceptions.ConnectionPoolIsFull 241 | 242 | if acquired: 243 | self._acquired[conn_info.conn] = conn_info 244 | else: 245 | free_pool_slots = self._max_pool_size - len(self._pool) 246 | if len(self._acquired) >= free_pool_slots: 247 | self._extra[conn_info.conn] = conn_info 248 | elif len(self._pool) < self._max_pool_size: 249 | self._pool[conn_info.conn] = conn_info 250 | else: 251 | raise AssertionError("unreachable") 252 | 253 | 254 | KeyType = TypeVar('KeyType', bound=Hashable) 255 | 256 | 257 | @dc.dataclass(frozen=True, order=True) 258 | class Event(Generic[KeyType]): 259 | """ 260 | Connection pool event. 261 | 262 | :param timestamp: event raise time 263 | :param key: event key (must compare equal for the same event) 264 | """ 265 | 266 | timestamp: float = dc.field(hash=False, compare=True) 267 | key: KeyType = dc.field(hash=True, compare=False) 268 | 269 | def __eq__(self, other: object) -> bool: 270 | if not isinstance(other, Event): 271 | return False 272 | 273 | return self.key == other.key 274 | 275 | 276 | class BaseEventQueue(Generic[KeyType]): 277 | """ 278 | Connection pool event queue. 279 | """ 280 | 281 | def __init__(self) -> None: 282 | self._queue: RankMap[KeyType, Event[KeyType]] = RankMap() 283 | 284 | def _insert(self, timestamp: float, key: KeyType) -> None: 285 | """ 286 | Adds a new event to the queue. 287 | """ 288 | 289 | self._queue.insert_or_replace(key, Event(timestamp, key)) 290 | 291 | def _remove(self, key: KeyType) -> None: 292 | """ 293 | Removes an event from the queue. 294 | """ 295 | 296 | self._queue.remove(key) 297 | 298 | def _clear(self) -> None: 299 | """ 300 | Clears the queue. 301 | """ 302 | 303 | self._queue.clear() 304 | 305 | def _try_get_next_event(self) -> Tuple[Optional[KeyType], Optional[float]]: 306 | """ 307 | Tries to pop the next event from the queue. 308 | If the queue is empty returns `(None, None)`. 309 | If the next event has not occurred yet returns `None` and a backoff timeout. 310 | 311 | :return: event data, backoff timeout 312 | """ 313 | 314 | if (event := self._queue.top()) and event.timestamp <= time.monotonic(): 315 | return event.key, 0.0 316 | 317 | backoff = event.timestamp - time.monotonic() if event is not None else None 318 | 319 | return None, backoff 320 | 321 | def _top(self) -> Optional[KeyType]: 322 | """ 323 | Returns the key of the next event without removing it, regardless of whether the event has occurred. 324 | """ 325 | 326 | if (event := self._queue.top()) is None: 327 | return None 328 | 329 | return event.key 330 | 331 | 332 | class EventType(IntEnum): 333 | LIFETIME = 1 334 | IDLETIME = 2 335 | 336 | 337 | class BaseConnectionPool(Generic[EndpointT, ConnectionT]): 338 | """ 339 | Base connection pool. 340 | 341 | :param idle_timeout: inactivity time (in seconds) after which an extra connection will be disposed 342 | (a connection is considered extra if the number of connections to the endpoint exceeds ``min_idle``). 343 | :param max_lifetime: number of seconds after which any connection will be disposed. 344 | :param min_idle: minimum number of connections the pool tries to hold for each endpoint. Connections that exceed 345 | that number will be considered extra and will be disposed after ``idle_timeout`` seconds of inactivity. 346 | :param max_size: maximum number of connections per endpoint. 347 | :param total_max_size: maximum number of all connections in the pool.
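
    Example (a minimal sketch using the synchronous implementation from
    :py:mod:`generic_connection_pool.threading`; ``manager`` is assumed to be
    a connection manager instance):

    .. code-block:: python

        pool = ConnectionPool(
            manager,
            idle_timeout=30.0,
            max_lifetime=600.0,
            min_idle=2,
            max_size=20,
            total_max_size=200,
        )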
348 | """ 349 | 350 | def __init__( 351 | self, 352 | *, 353 | idle_timeout: float = 60.0, 354 | max_lifetime: float = 3600.0, 355 | min_idle: int = 1, 356 | max_size: int = 10, 357 | total_max_size: int = 100, 358 | ): 359 | assert min_idle <= max_size, "min_idle can't be greater than max_size" 360 | assert max_size <= total_max_size, "max_size can't be greater than total_max_size" 361 | assert idle_timeout <= max_lifetime, "idle_timeout can't be greater than max_lifetime" 362 | 363 | self._idle_timeout = idle_timeout 364 | self._max_lifetime = max_lifetime 365 | self._min_idle = min_idle 366 | self._max_size = max_size 367 | self._total_max_size = total_max_size 368 | 369 | self._pool_size = 0 370 | 371 | def get_size(self) -> int: 372 | return self._pool_size 373 | 374 | @property 375 | def idle_timeout(self) -> float: 376 | return self._idle_timeout 377 | 378 | @property 379 | def max_lifetime(self) -> float: 380 | return self._max_lifetime 381 | 382 | @property 383 | def min_idle(self) -> int: 384 | return self._min_idle 385 | 386 | @property 387 | def max_size(self) -> int: 388 | return self._max_size 389 | 390 | @property 391 | def total_max_size(self) -> int: 392 | return self._total_max_size 393 | 394 | @property 395 | def is_full(self) -> bool: 396 | return self._pool_size >= self._total_max_size 397 | -------------------------------------------------------------------------------- /static/connection-pool.drawio: -------------------------------------------------------------------------------- -------------------------------------------------------------------------------- /tests/test_sync_pool.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import itertools as it 3 | import random 4 | import threading 5 | import time 6 | from typing import Dict, List, Optional, Tuple 7 | 8 | import pytest 9 | 10 | from generic_connection_pool import exceptions 11 | from generic_connection_pool.threading import BaseConnectionManager, ConnectionPool 12 | 13 | 14 | class MockConnection: 15 | def __init__(self, name: str): 16 | self._name = name 17 | 18 | def __str__(self) -> str: 19 | return self._name 20 | 21 | def __repr__(self) -> str: 22 | return f"{self.__class__.__name__}({str(self)})" 23 | 24 | 25 | class MockConnectionManager(BaseConnectionManager[int, MockConnection]): 26 | def __init__(self) -> None: 27 | self.creations: List[Tuple[int, MockConnection]] = [] 28 | self.disposals: List[MockConnection] = [] 29 | self.acquires: List[MockConnection] = [] 30 | self.releases: List[MockConnection] = [] 31 | self.dead: List[MockConnection] = [] 32 | self.aliveness: Dict[MockConnection, bool] = {} 33 | 34 | self.create_err: Optional[Exception] = None 35 | self.dispose_err: Optional[Exception] = None 36 | 
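# per-hook errors that individual tests inject to simulate connection manager failures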
self.check_aliveness_err: Optional[Exception] = None 37 | self.on_acquire_err: Optional[Exception] = None 38 | self.on_release_err: Optional[Exception] = None 39 | self.on_connection_dead_err: Optional[Exception] = None 40 | 41 | self._conn_cnt = 1 42 | 43 | def create(self, endpoint: int, timeout: Optional[float] = None) -> MockConnection: 44 | if err := self.create_err: 45 | raise err 46 | 47 | conn = MockConnection(f"{endpoint}:{self._conn_cnt}") 48 | self._conn_cnt += 1 49 | self.creations.append((endpoint, conn)) 50 | 51 | return conn 52 | 53 | def dispose(self, endpoint: int, conn: MockConnection, timeout: Optional[float] = None) -> None: 54 | if err := self.dispose_err: 55 | raise err 56 | 57 | self.disposals.append(conn) 58 | 59 | def check_aliveness(self, endpoint: int, conn: MockConnection, timeout: Optional[float] = None) -> bool: 60 | if err := self.check_aliveness_err: 61 | raise err 62 | 63 | return self.aliveness.get(conn, True) 64 | 65 | def on_acquire(self, endpoint: int, conn: MockConnection) -> None: 66 | if err := self.on_acquire_err: 67 | raise err 68 | 69 | self.acquires.append(conn) 70 | 71 | def on_release(self, endpoint: int, conn: MockConnection) -> None: 72 | if err := self.on_release_err: 73 | raise err 74 | 75 | self.releases.append(conn) 76 | 77 | def on_connection_dead(self, endpoint: int, conn: MockConnection) -> None: 78 | if err := self.on_connection_dead_err: 79 | raise err 80 | 81 | self.dead.append(conn) 82 | 83 | 84 | @pytest.fixture 85 | def connection_manager() -> MockConnectionManager: 86 | return MockConnectionManager() 87 | 88 | 89 | @pytest.mark.timeout(5.0) 90 | def test_params(connection_manager: MockConnectionManager): 91 | idle_timeout = 10.0 92 | max_lifetime = 60.0 93 | min_idle = 2 94 | max_size = 16 95 | total_max_size = 512 96 | 97 | pool = ConnectionPool[int, MockConnection]( 98 | connection_manager, 99 | idle_timeout=idle_timeout, 100 | max_lifetime=max_lifetime, 101 | min_idle=min_idle, 102 | max_size=max_size, 103 | total_max_size=total_max_size, 104 | ) 105 | 106 | assert pool.idle_timeout == idle_timeout 107 | assert pool.max_lifetime == max_lifetime 108 | assert pool.min_idle == min_idle 109 | assert pool.max_size == max_size 110 | assert pool.total_max_size == total_max_size 111 | 112 | pool.close() 113 | assert pool.get_size() == 0 114 | 115 | 116 | @pytest.mark.timeout(5.0) 117 | def test_pool_context_manager(connection_manager: MockConnectionManager): 118 | pool = ConnectionPool[int, MockConnection](connection_manager, min_idle=0, idle_timeout=0.0) 119 | 120 | with pool.connection(endpoint=1) as conn1: 121 | assert pool.get_size() == 1 122 | 123 | with pool.connection(endpoint=2) as conn2: 124 | assert pool.get_size() == 2 125 | 126 | with pool.connection(endpoint=3) as conn3: 127 | assert pool.get_size() == 3 128 | 129 | assert connection_manager.creations == [(1, conn1), (2, conn2), (3, conn3)] 130 | assert pool.get_size() == 0 131 | assert connection_manager.disposals == [conn3, conn2, conn1] 132 | 133 | pool.close() 134 | 135 | 136 | @pytest.mark.timeout(5.0) 137 | def test_pool_acquire_round_robin(connection_manager: MockConnectionManager): 138 | def fill_endpoint_pool(pool: ConnectionPool[int, MockConnection], endpoint: int, size: int) -> List[MockConnection]: 139 | connections = [pool.acquire(endpoint) for _ in range(size)] 140 | for conn in connections: 141 | pool.release(conn, endpoint) 142 | 143 | return connections 144 | 145 | pool = ConnectionPool[int, MockConnection](connection_manager, min_idle=3) 146 | 
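# pre-fill each endpoint pool and release the connections, then verify repeated acquires cycle through them in FIFO (round-robin) order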
connections = { 147 | 1: fill_endpoint_pool(pool, endpoint=1, size=1), 148 | 2: fill_endpoint_pool(pool, endpoint=2, size=2), 149 | 3: fill_endpoint_pool(pool, endpoint=3, size=3), 150 | } 151 | 152 | for endpoint, connections in connections.items(): 153 | for conn in it.chain(connections, connections): 154 | assert pool.acquire(endpoint) == conn 155 | pool.release(conn, endpoint) 156 | 157 | pool.close() 158 | assert pool.get_size() == 0 159 | 160 | 161 | @pytest.mark.timeout(5.0) 162 | def test_connection_manager_callbacks(connection_manager: MockConnectionManager): 163 | pool = ConnectionPool[int, MockConnection](connection_manager) 164 | with pool.connection(endpoint=1) as conn: 165 | pass 166 | 167 | assert connection_manager.acquires == [conn] 168 | assert connection_manager.releases == [conn] 169 | 170 | pool.close() 171 | assert pool.get_size() == 0 172 | 173 | 174 | @pytest.mark.timeout(5.0) 175 | def test_connection_wait(delay: float, connection_manager: MockConnectionManager): 176 | pool = ConnectionPool[int, MockConnection]( 177 | connection_manager, 178 | max_size=1, 179 | ) 180 | 181 | def acquire_connection(timeout: float): 182 | with pool.connection(endpoint=1, timeout=timeout): 183 | time.sleep(delay) 184 | assert pool.get_size() == 1 185 | 186 | threads_cnt = 10 187 | threads = [ 188 | threading.Thread(target=acquire_connection, kwargs=dict(timeout=(threads_cnt + 1) * delay)) 189 | for _ in range(threads_cnt) 190 | ] 191 | for thread in threads: 192 | thread.start() 193 | 194 | for thread in threads: 195 | thread.join() 196 | 197 | assert pool.get_size() == 1 198 | 199 | pool.close() 200 | assert pool.get_size() == 0 201 | 202 | 203 | @pytest.mark.timeout(5.0) 204 | def test_pool_max_size(delay: float, connection_manager: MockConnectionManager): 205 | pool = ConnectionPool[int, MockConnection]( 206 | connection_manager, 207 | min_idle=0, 208 | idle_timeout=0.0, 209 | max_size=1, 210 | total_max_size=2, 211 | ) 212 | 213 | conn1 = pool.acquire(endpoint=1) 214 | with pytest.raises(TimeoutError): 215 | pool.acquire(endpoint=1, timeout=delay) 216 | assert pool.get_size() == 1 217 | 218 | pool.release(conn1, endpoint=1) 219 | assert pool.get_size() == 0 220 | 221 | conn1 = pool.acquire(endpoint=1) 222 | pool.release(conn1, endpoint=1) 223 | assert pool.get_size() == 0 224 | 225 | pool.close() 226 | 227 | 228 | @pytest.mark.timeout(5.0) 229 | def test_pool_total_max_size(delay: float, connection_manager: MockConnectionManager): 230 | pool = ConnectionPool[int, MockConnection]( 231 | connection_manager, 232 | min_idle=0, 233 | idle_timeout=0.0, 234 | max_size=1, 235 | total_max_size=1, 236 | ) 237 | 238 | conn1 = pool.acquire(1) 239 | with pytest.raises(TimeoutError): 240 | pool.acquire(endpoint=2, timeout=delay) 241 | assert pool.get_size() == 1 242 | 243 | pool.release(conn1, endpoint=1) 244 | assert pool.get_size() == 0 245 | 246 | conn1 = pool.acquire(endpoint=1) 247 | pool.release(conn1, endpoint=1) 248 | assert pool.get_size() == 0 249 | 250 | pool.close() 251 | 252 | 253 | @pytest.mark.parametrize('background_collector', [True, False]) 254 | @pytest.mark.timeout(5.0) 255 | def test_pool_disposable_connections_collection( 256 | delay: float, 257 | connection_manager: MockConnectionManager, 258 | background_collector: bool, 259 | ): 260 | pool = ConnectionPool[int, MockConnection]( 261 | connection_manager, 262 | idle_timeout=0.0, 263 | min_idle=0, 264 | background_collector=background_collector, 265 | dispose_batch_size=10, 266 | ) 267 | 268 | connections = [ 269 | (endpoint := n % 3, 
pool.acquire(endpoint=endpoint)) 270 | for n in range(20) 271 | ] 272 | assert pool.get_size() == len(connections) 273 | 274 | assert connection_manager.creations == connections 275 | 276 | for endpoint, connection in connections: 277 | pool.release(connection, endpoint=endpoint) 278 | 279 | time.sleep(delay) 280 | 281 | assert pool.get_size() == 0 282 | assert connection_manager.disposals == [conn for ep, conn in connections] 283 | 284 | pool.close() 285 | 286 | 287 | @pytest.mark.parametrize('background_collector', [True, False]) 288 | @pytest.mark.timeout(5.0) 289 | def test_pool_min_idle( 290 | delay: float, 291 | connection_manager: MockConnectionManager, 292 | background_collector: bool, 293 | ): 294 | pool = ConnectionPool[int, MockConnection]( 295 | connection_manager, 296 | idle_timeout=0.0, 297 | min_idle=1, 298 | background_collector=background_collector, 299 | ) 300 | conn11 = pool.acquire(endpoint=1) 301 | conn12 = pool.acquire(endpoint=1) 302 | conn13 = pool.acquire(endpoint=1) 303 | conn21 = pool.acquire(endpoint=2) 304 | conn22 = pool.acquire(endpoint=2) 305 | conn31 = pool.acquire(endpoint=3) 306 | 307 | assert pool.get_size() == 6 308 | 309 | assert connection_manager.creations == [ 310 | (1, conn11), (1, conn12), (1, conn13), 311 | (2, conn21), (2, conn22), 312 | (3, conn31), 313 | ] 314 | 315 | pool.release(conn11, endpoint=1) 316 | pool.release(conn12, endpoint=1) 317 | pool.release(conn13, endpoint=1) 318 | time.sleep(delay) 319 | pool.release(conn21, endpoint=2) 320 | pool.release(conn22, endpoint=2) 321 | time.sleep(delay) 322 | pool.release(conn31, endpoint=3) 323 | 324 | time.sleep(delay) 325 | 326 | assert pool.get_size() == 3 327 | assert connection_manager.disposals == [conn11, conn12, conn21] 328 | 329 | pool.close() 330 | assert pool.get_size() == 0 331 | 332 | 333 | @pytest.mark.parametrize('background_collector', [True, False]) 334 | @pytest.mark.timeout(5.0) 335 | def test_pool_idle_timeout( 336 | delay: float, 337 | connection_manager: MockConnectionManager, 338 | background_collector: bool, 339 | ): 340 | pool = ConnectionPool[int, MockConnection]( 341 | connection_manager, 342 | idle_timeout=delay, 343 | min_idle=1, 344 | dispose_batch_size=5, 345 | background_collector=background_collector, 346 | ) 347 | 348 | with pool.connection(endpoint=1): 349 | with pool.connection(endpoint=1): 350 | pass 351 | 352 | with pool.connection(endpoint=2): 353 | with pool.connection(endpoint=2): 354 | pass 355 | assert pool.get_size() == 4 356 | 357 | time.sleep(2 * delay) 358 | 359 | # run disposal if background worker is not started 360 | with pool.connection(endpoint=3): 361 | pass 362 | assert pool.get_size() == 3 363 | 364 | pool.close() 365 | assert pool.get_size() == 0 366 | 367 | 368 | @pytest.mark.timeout(5.0) 369 | def test_idle_connection_close_on_total_max_size_exceeded(connection_manager: MockConnectionManager): 370 | pool = ConnectionPool[int, MockConnection]( 371 | connection_manager, 372 | min_idle=3, 373 | max_size=3, 374 | total_max_size=3, 375 | ) 376 | 377 | with pool.connection(endpoint=1): 378 | with pool.connection(endpoint=1): 379 | with pool.connection(endpoint=1): 380 | pass 381 | assert pool.get_size() == 3 382 | 383 | with pool.connection(endpoint=2): 384 | pass 385 | assert pool.get_size() == 3 386 | 387 | with pool.connection(endpoint=3): 388 | pass 389 | assert pool.get_size() == 3 390 | 391 | pool.close() 392 | assert pool.get_size() == 0 393 | 394 | 395 | @pytest.mark.parametrize('background_collector', [True, False]) 396 | 
@pytest.mark.timeout(5.0) 397 | def test_pool_max_lifetime( 398 | delay: float, 399 | connection_manager: MockConnectionManager, 400 | background_collector: bool, 401 | ): 402 | pool = ConnectionPool[int, MockConnection]( 403 | connection_manager, 404 | idle_timeout=delay, 405 | max_lifetime=2 * delay, 406 | min_idle=1, 407 | dispose_batch_size=4, 408 | background_collector=background_collector, 409 | ) 410 | 411 | with pool.connection(endpoint=1): 412 | with pool.connection(endpoint=1): 413 | pass 414 | 415 | with pool.connection(endpoint=2): 416 | with pool.connection(endpoint=2): 417 | pass 418 | assert pool.get_size() == 4 419 | 420 | time.sleep(3 * delay) 421 | 422 | # run disposal if background worker is not started 423 | with pool.connection(endpoint=3): 424 | pass 425 | assert pool.get_size() == 1 426 | 427 | pool.close() 428 | assert pool.get_size() == 0 429 | 430 | 431 | @pytest.mark.timeout(5.0) 432 | def test_pool_aliveness_check(connection_manager: MockConnectionManager): 433 | pool = ConnectionPool[int, MockConnection](connection_manager) 434 | 435 | with pool.connection(endpoint=1) as conn1: 436 | pass 437 | 438 | connection_manager.aliveness[conn1] = False 439 | 440 | with pool.connection(endpoint=1) as conn2: 441 | assert conn2 != conn1 442 | 443 | assert pool.get_size() == 1 444 | assert connection_manager.dead == [conn1] 445 | 446 | pool.close() 447 | assert pool.get_size() == 0 448 | 449 | 450 | @pytest.mark.timeout(5.0) 451 | def test_pool_close(delay: float, connection_manager: MockConnectionManager): 452 | pool = ConnectionPool[int, MockConnection]( 453 | connection_manager, 454 | idle_timeout=10, 455 | max_lifetime=10, 456 | min_idle=10, 457 | background_collector=False, 458 | ) 459 | 460 | with pool.connection(endpoint=1): 461 | with pool.connection(endpoint=1): 462 | pass 463 | 464 | with pool.connection(endpoint=2): 465 | with pool.connection(endpoint=2): 466 | pass 467 | 468 | assert pool.get_size() == 4 469 | pool.close() 470 | assert pool.get_size() == 0 471 | 472 | 473 | @pytest.mark.timeout(5.0) 474 | def test_pool_close_wait(delay: float, connection_manager: MockConnectionManager): 475 | pool = ConnectionPool[int, MockConnection]( 476 | connection_manager, 477 | idle_timeout=10, 478 | max_lifetime=10, 479 | min_idle=10, 480 | max_size=10, 481 | total_max_size=100, 482 | background_collector=False, 483 | ) 484 | 485 | thread_cnt = 10 486 | endpoint_cnt = 6 487 | delay_factor = 0.2 488 | 489 | barrier = threading.Barrier(thread_cnt + 1) 490 | 491 | def acquire_connection(endpoint: int, delay: float): 492 | with contextlib.suppress(exceptions.ConnectionPoolClosedError): 493 | with pool.connection(endpoint=endpoint): 494 | barrier.wait() 495 | time.sleep(delay) 496 | 497 | threads = [ 498 | threading.Thread( 499 | target=acquire_connection, 500 | kwargs=dict(endpoint=i % endpoint_cnt, delay=random.random() * delay_factor), 501 | ) 502 | for i in range(thread_cnt) 503 | ] 504 | 505 | for thread in threads: 506 | thread.start() 507 | 508 | barrier.wait() 509 | 510 | pool.close(timeout=thread_cnt * delay) 511 | assert pool.get_size() == 0 512 | 513 | for thread in threads: 514 | thread.join() 515 | 516 | 517 | @pytest.mark.timeout(5.0) 518 | def test_pool_close_timeout(delay: float, connection_manager: MockConnectionManager): 519 | pool = ConnectionPool[int, MockConnection](connection_manager) 520 | 521 | acquired = threading.Event() 522 | 523 | def acquire_connection() -> None: 524 | with contextlib.suppress(exceptions.ConnectionPoolClosedError): 525 | with 
pool.connection(endpoint=1): 526 | acquired.set() 527 | time.sleep(delay) 528 | assert connection_manager.disposals == [] 529 | 530 | thread = threading.Thread(target=acquire_connection) 531 | thread.start() 532 | 533 | acquired.wait() 534 | pool.close(timeout=2 * delay) 535 | 536 | thread.join() 537 | 538 | 539 | @pytest.mark.timeout(5.0) 540 | def test_pool_connection_manager_creation_error(connection_manager: MockConnectionManager): 541 | class TestException(Exception): 542 | pass 543 | connection_manager.create_err = TestException() 544 | 545 | pool = ConnectionPool[int, MockConnection](connection_manager) 546 | with pytest.raises(TestException): 547 | with pool.connection(endpoint=1): 548 | pass 549 | 550 | assert pool.get_size() == 0 551 | assert pool.get_endpoint_pool_size(endpoint=1) == 0 552 | 553 | 554 | @pytest.mark.timeout(5.0) 555 | def test_pool_connection_manager_release_error(connection_manager: MockConnectionManager): 556 | class TestException(Exception): 557 | pass 558 | connection_manager.on_release_err = TestException() 559 | 560 | pool = ConnectionPool[int, MockConnection](connection_manager, idle_timeout=0.0, min_idle=0) 561 | with pytest.raises(TestException): 562 | with pool.connection(endpoint=1): 563 | pass 564 | 565 | assert pool.get_size() == 1 566 | assert pool.get_endpoint_pool_size(endpoint=1, acquired=False) == 1 567 | 568 | 569 | @pytest.mark.timeout(5.0) 570 | def test_pool_connection_manager_aliveness_error(delay: float, connection_manager: MockConnectionManager): 571 | class TestException(Exception): 572 | pass 573 | connection_manager.check_aliveness_err = TestException() 574 | 575 | pool = ConnectionPool[int, MockConnection](connection_manager, idle_timeout=delay, min_idle=1) 576 | with pool.connection(endpoint=1): 577 | pass 578 | 579 | with pytest.raises(TestException): 580 | with pool.connection(endpoint=1): 581 | pass 582 | 583 | assert pool.get_size() == 1 584 | assert pool.get_endpoint_pool_size(endpoint=1, acquired=False) == 1 585 | 586 | 587 | @pytest.mark.timeout(5.0) 588 | def test_pool_connection_manager_dead_connection_error(delay: float, connection_manager: MockConnectionManager): 589 | class TestException(Exception): 590 | pass 591 | connection_manager.on_connection_dead_err = TestException() 592 | 593 | pool = ConnectionPool[int, MockConnection](connection_manager, idle_timeout=delay, min_idle=1) 594 | with pool.connection(endpoint=1) as conn: 595 | pass 596 | 597 | connection_manager.aliveness[conn] = False 598 | with pytest.raises(TestException): 599 | with pool.connection(endpoint=1): 600 | pass 601 | 602 | assert pool.get_size() == 0 603 | assert pool.get_endpoint_pool_size(endpoint=1) == 0 604 | 605 | 606 | @pytest.mark.timeout(5.0) 607 | def test_pool_connection_manager_acquire_error(delay: float, connection_manager: MockConnectionManager): 608 | class TestException(Exception): 609 | pass 610 | connection_manager.on_acquire_err = TestException() 611 | 612 | pool = ConnectionPool[int, MockConnection](connection_manager, idle_timeout=delay, min_idle=1) 613 | with pytest.raises(TestException): 614 | with pool.connection(endpoint=1): 615 | pass 616 | 617 | assert pool.get_size() == 1 618 | assert pool.get_endpoint_pool_size(endpoint=1, acquired=False) == 1 619 | 620 | 621 | @pytest.mark.timeout(5.0) 622 | def test_pool_connection_manager_dispose_error(connection_manager: MockConnectionManager): 623 | class TestException(Exception): 624 | pass 625 | connection_manager.dispose_err = TestException() 626 | 627 | pool = ConnectionPool[int, 
MockConnection](connection_manager, idle_timeout=0.0, min_idle=0) 628 | with pool.connection(endpoint=1): 629 | pass 630 | 631 | assert pool.get_size() == 0 632 | assert pool.get_endpoint_pool_size(endpoint=1, acquired=False) == 0 633 | -------------------------------------------------------------------------------- /tests/test_async_pool.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import contextlib 3 | import itertools as it 4 | import random 5 | from typing import Dict, List, Optional, Tuple 6 | 7 | import pytest 8 | 9 | from generic_connection_pool import exceptions 10 | from generic_connection_pool.asyncio import BaseConnectionManager, ConnectionPool 11 | 12 | 13 | class MockConnection: 14 | def __init__(self, name: str): 15 | self._name = name 16 | 17 | def __str__(self) -> str: 18 | return self._name 19 | 20 | def __repr__(self) -> str: 21 | return f"{self.__class__.__name__}({str(self)})" 22 | 23 | 24 | class MockConnectionManager(BaseConnectionManager[int, MockConnection]): 25 | def __init__(self): 26 | self.creations: List[Tuple[int, MockConnection]] = [] 27 | self.disposals: List[MockConnection] = [] 28 | self.acquires: List[MockConnection] = [] 29 | self.releases: List[MockConnection] = [] 30 | self.dead: List[MockConnection] = [] 31 | self.aliveness: Dict[MockConnection, bool] = {} 32 | 33 | self.create_err: Optional[Exception] = None 34 | self.dispose_err: Optional[Exception] = None 35 | self.check_aliveness_err: Optional[Exception] = None 36 | self.on_acquire_err: Optional[Exception] = None 37 | self.on_release_err: Optional[Exception] = None 38 | self.on_connection_dead_err: Optional[Exception] = None 39 | 40 | self._conn_cnt = 1 41 | 42 | async def create(self, endpoint: int) -> MockConnection: 43 | if err := self.create_err: 44 | raise err 45 | 46 | conn = MockConnection(f"{endpoint}:{self._conn_cnt}") 47 | self._conn_cnt += 1 48 | self.creations.append((endpoint, conn)) 49 | 50 | return conn 51 | 52 | async def dispose(self, endpoint: int, conn: MockConnection) -> None: 53 | if err := self.dispose_err: 54 | raise err 55 | 56 | self.disposals.append(conn) 57 | 58 | async def check_aliveness(self, endpoint: int, conn: MockConnection) -> bool: 59 | if err := self.check_aliveness_err: 60 | raise err 61 | 62 | return self.aliveness.get(conn, True) 63 | 64 | async def on_acquire(self, endpoint: int, conn: MockConnection) -> None: 65 | if err := self.on_acquire_err: 66 | raise err 67 | 68 | self.acquires.append(conn) 69 | 70 | async def on_release(self, endpoint: int, conn: MockConnection) -> None: 71 | if err := self.on_release_err: 72 | raise err 73 | 74 | self.releases.append(conn) 75 | 76 | async def on_connection_dead(self, endpoint: int, conn: MockConnection) -> None: 77 | if err := self.on_connection_dead_err: 78 | raise err 79 | 80 | self.dead.append(conn) 81 | 82 | 83 | @pytest.fixture 84 | def connection_manager() -> MockConnectionManager: 85 | return MockConnectionManager() 86 | 87 | 88 | @pytest.mark.timeout(5.0) 89 | async def test_params(connection_manager: MockConnectionManager): 90 | idle_timeout = 10.0 91 | max_lifetime = 60.0 92 | min_idle = 2 93 | max_size = 16 94 | total_max_size = 512 95 | 96 | pool = ConnectionPool[int, MockConnection]( 97 | connection_manager, 98 | idle_timeout=idle_timeout, 99 | max_lifetime=max_lifetime, 100 | min_idle=min_idle, 101 | max_size=max_size, 102 | total_max_size=total_max_size, 103 | ) 104 | 105 | assert pool.idle_timeout == idle_timeout 106 | assert 
pool.max_lifetime == max_lifetime 107 | assert pool.min_idle == min_idle 108 | assert pool.max_size == max_size 109 | assert pool.total_max_size == total_max_size 110 | 111 | await pool.close() 112 | assert pool.get_size() == 0 113 | 114 | 115 | @pytest.mark.timeout(5.0) 116 | async def test_pool_context_manager(connection_manager: MockConnectionManager): 117 | pool = ConnectionPool[int, MockConnection](connection_manager, min_idle=0, idle_timeout=0.0) 118 | 119 | async with pool.connection(endpoint=1) as conn1: 120 | assert pool.get_size() == 1 121 | 122 | async with pool.connection(endpoint=2) as conn2: 123 | assert pool.get_size() == 2 124 | 125 | async with pool.connection(endpoint=3) as conn3: 126 | assert pool.get_size() == 3 127 | 128 | assert connection_manager.creations == [(1, conn1), (2, conn2), (3, conn3)] 129 | assert pool.get_size() == 0 130 | assert connection_manager.disposals == [conn3, conn2, conn1] 131 | 132 | await pool.close() 133 | 134 | 135 | @pytest.mark.timeout(5.0) 136 | async def test_pool_acquire_round_robin(connection_manager: MockConnectionManager): 137 | async def fill_endpoint_pool( 138 | pool: ConnectionPool[int, MockConnection], 139 | endpoint: int, 140 | size: int, 141 | ) -> List[MockConnection]: 142 | connections = [await pool.acquire(endpoint) for _ in range(size)] 143 | for conn in connections: 144 | await pool.release(conn, endpoint) 145 | 146 | return connections 147 | 148 | pool = ConnectionPool[int, MockConnection](connection_manager, min_idle=3) 149 | connections = { 150 | 1: await fill_endpoint_pool(pool, endpoint=1, size=1), 151 | 2: await fill_endpoint_pool(pool, endpoint=2, size=2), 152 | 3: await fill_endpoint_pool(pool, endpoint=3, size=3), 153 | } 154 | 155 | for endpoint, connections in connections.items(): 156 | for conn in it.chain(connections, connections): 157 | assert await pool.acquire(endpoint) == conn 158 | await pool.release(conn, endpoint) 159 | 160 | await pool.close() 161 | assert pool.get_size() == 0 162 | 163 | 164 | @pytest.mark.timeout(5.0) 165 | async def test_connection_manager_callbacks(connection_manager: MockConnectionManager): 166 | pool = ConnectionPool[int, MockConnection](connection_manager) 167 | async with pool.connection(endpoint=1) as conn: 168 | pass 169 | 170 | assert connection_manager.acquires == [conn] 171 | assert connection_manager.releases == [conn] 172 | 173 | await pool.close() 174 | assert pool.get_size() == 0 175 | 176 | 177 | @pytest.mark.timeout(5.0) 178 | async def test_connection_wait(delay: float, connection_manager: MockConnectionManager): 179 | pool = ConnectionPool[int, MockConnection]( 180 | connection_manager, 181 | max_size=1, 182 | ) 183 | 184 | async def acquire_connection(timeout: float) -> None: 185 | async with pool.connection(endpoint=1, timeout=timeout): 186 | await asyncio.sleep(delay) 187 | assert pool.get_size() == 1 188 | 189 | task_cnt = 10 190 | for task in [ 191 | asyncio.create_task(acquire_connection(timeout=(task_cnt + 1) * delay)) 192 | for _ in range(task_cnt) 193 | ]: 194 | await task 195 | 196 | assert pool.get_size() == 1 197 | 198 | await pool.close() 199 | assert pool.get_size() == 0 200 | 201 | 202 | @pytest.mark.timeout(5.0) 203 | async def test_pool_max_size(delay: float, connection_manager: MockConnectionManager): 204 | pool = ConnectionPool[int, MockConnection]( 205 | connection_manager, 206 | min_idle=0, 207 | idle_timeout=0.0, 208 | max_size=1, 209 | total_max_size=2, 210 | ) 211 | 212 | conn1 = await pool.acquire(endpoint=1) 213 | with 
pytest.raises(asyncio.TimeoutError): 214 | await pool.acquire(endpoint=1, timeout=delay) 215 | assert pool.get_size() == 1 216 | 217 | await pool.release(conn1, endpoint=1) 218 | assert pool.get_size() == 0 219 | 220 | conn1 = await pool.acquire(endpoint=1) 221 | await pool.release(conn1, endpoint=1) 222 | assert pool.get_size() == 0 223 | 224 | await pool.close() 225 | 226 | 227 | @pytest.mark.timeout(5.0) 228 | async def test_pool_total_max_size(delay: float, connection_manager: MockConnectionManager): 229 | pool = ConnectionPool[int, MockConnection]( 230 | connection_manager, 231 | min_idle=0, 232 | idle_timeout=0.0, 233 | max_size=1, 234 | total_max_size=1, 235 | ) 236 | 237 | conn1 = await pool.acquire(1) 238 | with pytest.raises(asyncio.TimeoutError): 239 | await pool.acquire(endpoint=2, timeout=delay) 240 | assert pool.get_size() == 1 241 | 242 | await pool.release(conn1, endpoint=1) 243 | assert pool.get_size() == 0 244 | 245 | conn1 = await pool.acquire(endpoint=1) 246 | await pool.release(conn1, endpoint=1) 247 | assert pool.get_size() == 0 248 | 249 | await pool.close() 250 | 251 | 252 | @pytest.mark.parametrize('background_collector', [True, False]) 253 | @pytest.mark.timeout(5.0) 254 | async def test_pool_disposable_connections_collection( 255 | delay: float, 256 | connection_manager: MockConnectionManager, 257 | background_collector: bool, 258 | ): 259 | pool = ConnectionPool[int, MockConnection]( 260 | connection_manager, 261 | idle_timeout=0.0, 262 | min_idle=0, 263 | background_collector=background_collector, 264 | dispose_batch_size=10, 265 | ) 266 | 267 | connections = [ 268 | (endpoint := n % 3, await pool.acquire(endpoint=endpoint)) 269 | for n in range(20) 270 | ] 271 | assert pool.get_size() == len(connections) 272 | 273 | assert connection_manager.creations == connections 274 | 275 | for endpoint, connection in connections: 276 | await pool.release(connection, endpoint=endpoint) 277 | 278 | await asyncio.sleep(delay) 279 | 280 | assert pool.get_size() == 0 281 | assert connection_manager.disposals == [conn for ep, conn in connections] 282 | 283 | await pool.close() 284 | 285 | 286 | @pytest.mark.parametrize('background_collector', [True, False]) 287 | @pytest.mark.timeout(5.0) 288 | async def test_pool_min_idle( 289 | delay: float, 290 | connection_manager: MockConnectionManager, 291 | background_collector: bool, 292 | ): 293 | pool = ConnectionPool[int, MockConnection]( 294 | connection_manager, 295 | idle_timeout=0.0, 296 | min_idle=1, 297 | background_collector=background_collector, 298 | ) 299 | conn11 = await pool.acquire(endpoint=1) 300 | conn12 = await pool.acquire(endpoint=1) 301 | conn13 = await pool.acquire(endpoint=1) 302 | conn21 = await pool.acquire(endpoint=2) 303 | conn22 = await pool.acquire(endpoint=2) 304 | conn31 = await pool.acquire(endpoint=3) 305 | 306 | assert pool.get_size() == 6 307 | 308 | assert connection_manager.creations == [ 309 | (1, conn11), (1, conn12), (1, conn13), 310 | (2, conn21), (2, conn22), 311 | (3, conn31), 312 | ] 313 | 314 | await pool.release(conn11, endpoint=1) 315 | await pool.release(conn12, endpoint=1) 316 | await pool.release(conn13, endpoint=1) 317 | await asyncio.sleep(delay) 318 | await pool.release(conn21, endpoint=2) 319 | await pool.release(conn22, endpoint=2) 320 | await asyncio.sleep(delay) 321 | await pool.release(conn31, endpoint=3) 322 | 323 | await asyncio.sleep(delay) 324 | 325 | assert pool.get_size() == 3 326 | assert connection_manager.disposals == [conn11, conn12, conn21] 327 | 328 | await 
pool.close() 329 | assert pool.get_size() == 0 330 | 331 | 332 | @pytest.mark.parametrize('background_collector', [True, False]) 333 | @pytest.mark.timeout(5.0) 334 | async def test_pool_idle_timeout( 335 | delay: float, 336 | connection_manager: MockConnectionManager, 337 | background_collector: bool, 338 | ): 339 | pool = ConnectionPool[int, MockConnection]( 340 | connection_manager, 341 | idle_timeout=delay, 342 | min_idle=1, 343 | dispose_batch_size=5, 344 | background_collector=background_collector, 345 | ) 346 | 347 | async with pool.connection(endpoint=1): 348 | async with pool.connection(endpoint=1): 349 | pass 350 | 351 | async with pool.connection(endpoint=2): 352 | async with pool.connection(endpoint=2): 353 | pass 354 | assert pool.get_size() == 4 355 | 356 | await asyncio.sleep(2 * delay) 357 | 358 | # run disposal if background worker is not started 359 | async with pool.connection(endpoint=3): 360 | pass 361 | assert pool.get_size() == 3 362 | 363 | await pool.close() 364 | assert pool.get_size() == 0 365 | 366 | 367 | @pytest.mark.timeout(5.0) 368 | async def test_idle_connection_close_on_total_max_size_exceeded(connection_manager: MockConnectionManager): 369 | pool = ConnectionPool[int, MockConnection]( 370 | connection_manager, 371 | min_idle=3, 372 | max_size=3, 373 | total_max_size=3, 374 | ) 375 | 376 | async with pool.connection(endpoint=1): 377 | async with pool.connection(endpoint=1): 378 | async with pool.connection(endpoint=1): 379 | pass 380 | assert pool.get_size() == 3 381 | 382 | async with pool.connection(endpoint=2): 383 | pass 384 | assert pool.get_size() == 3 385 | 386 | async with pool.connection(endpoint=3): 387 | pass 388 | assert pool.get_size() == 3 389 | 390 | await pool.close() 391 | assert pool.get_size() == 0 392 | 393 | 394 | @pytest.mark.parametrize('background_collector', [True, False]) 395 | @pytest.mark.timeout(5.0) 396 | async def test_pool_max_lifetime( 397 | delay: float, 398 | connection_manager: MockConnectionManager, 399 | background_collector: bool, 400 | ): 401 | pool = ConnectionPool[int, MockConnection]( 402 | connection_manager, 403 | idle_timeout=delay, 404 | max_lifetime=2 * delay, 405 | min_idle=1, 406 | dispose_batch_size=4, 407 | background_collector=background_collector, 408 | ) 409 | 410 | async with pool.connection(endpoint=1): 411 | async with pool.connection(endpoint=1): 412 | pass 413 | 414 | async with pool.connection(endpoint=2): 415 | async with pool.connection(endpoint=2): 416 | pass 417 | assert pool.get_size() == 4 418 | 419 | await asyncio.sleep(3 * delay) 420 | 421 | # run disposal if background worker is not started 422 | async with pool.connection(endpoint=3): 423 | pass 424 | assert pool.get_size() == 1 425 | 426 | await pool.close() 427 | assert pool.get_size() == 0 428 | 429 | 430 | @pytest.mark.timeout(5.0) 431 | async def test_pool_aliveness_check(delay: float, connection_manager: MockConnectionManager): 432 | pool = ConnectionPool[int, MockConnection](connection_manager) 433 | 434 | async with pool.connection(endpoint=1) as conn1: 435 | pass 436 | 437 | connection_manager.aliveness[conn1] = False 438 | 439 | async with pool.connection(endpoint=1) as conn2: 440 | assert conn2 != conn1 441 | 442 | assert pool.get_size() == 1 443 | assert connection_manager.dead == [conn1] 444 | 445 | await pool.close() 446 | assert pool.get_size() == 0 447 | 448 | 449 | @pytest.mark.timeout(5.0) 450 | async def test_pool_close(delay: float, connection_manager: MockConnectionManager): 451 | pool = ConnectionPool[int, 
MockConnection]( 452 | connection_manager, 453 | idle_timeout=10, 454 | max_lifetime=10, 455 | min_idle=10, 456 | background_collector=False, 457 | ) 458 | 459 | async with pool.connection(endpoint=1): 460 | async with pool.connection(endpoint=1): 461 | pass 462 | 463 | async with pool.connection(endpoint=2): 464 | async with pool.connection(endpoint=2): 465 | pass 466 | 467 | assert pool.get_size() == 4 468 | await pool.close() 469 | assert pool.get_size() == 0 470 | 471 | 472 | @pytest.mark.timeout(5.0) 473 | async def test_pool_close_wait(delay: float, connection_manager: MockConnectionManager): 474 | pool = ConnectionPool[int, MockConnection]( 475 | connection_manager, 476 | idle_timeout=10, 477 | max_lifetime=10, 478 | min_idle=10, 479 | max_size=10, 480 | total_max_size=100, 481 | background_collector=False, 482 | ) 483 | 484 | task_cnt = 50 485 | endpoint_cnt = 6 486 | delay_factor = 0.2 487 | 488 | acquire_cnt = 0 489 | acquire_event = asyncio.Event() 490 | finished = asyncio.Event() 491 | 492 | async def acquire_connection(endpoint: int, delay: float): 493 | nonlocal acquire_cnt 494 | 495 | with contextlib.suppress(exceptions.ConnectionPoolClosedError): 496 | async with pool.connection(endpoint=endpoint): 497 | acquire_cnt += 1 498 | acquire_event.set() 499 | await finished.wait() 500 | 501 | await asyncio.sleep(delay) 502 | 503 | tasks = [ 504 | asyncio.create_task(acquire_connection(endpoint=i % endpoint_cnt, delay=random.random() * delay_factor)) 505 | for i in range(task_cnt) 506 | ] 507 | 508 | while acquire_cnt < task_cnt: 509 | await acquire_event.wait() 510 | acquire_event.clear() 511 | 512 | finished.set() 513 | 514 | await pool.close(timeout=task_cnt * delay) 515 | assert pool.get_size() == 0 516 | 517 | for task in tasks: 518 | await task 519 | 520 | 521 | @pytest.mark.timeout(5.0) 522 | async def test_pool_close_timeout(delay: float, connection_manager: MockConnectionManager): 523 | pool = ConnectionPool[int, MockConnection](connection_manager) 524 | 525 | acquired = asyncio.Event() 526 | 527 | async def acquire_connection() -> None: 528 | with contextlib.suppress(exceptions.ConnectionPoolClosedError): 529 | async with pool.connection(endpoint=1): 530 | acquired.set() 531 | await asyncio.sleep(delay) 532 | assert connection_manager.disposals == [] 533 | 534 | task = asyncio.create_task(acquire_connection()) 535 | 536 | await acquired.wait() 537 | await pool.close(timeout=2 * delay) 538 | 539 | await task 540 | 541 | 542 | @pytest.mark.timeout(5.0) 543 | async def test_pool_connection_manager_creation_error(connection_manager: MockConnectionManager): 544 | class TestException(Exception): 545 | pass 546 | connection_manager.create_err = TestException() 547 | 548 | pool = ConnectionPool[int, MockConnection](connection_manager) 549 | with pytest.raises(TestException): 550 | async with pool.connection(endpoint=1): 551 | pass 552 | 553 | assert pool.get_size() == 0 554 | assert await pool.get_endpoint_pool_size(endpoint=1) == 0 555 | 556 | 557 | @pytest.mark.timeout(5.0) 558 | async def test_pool_connection_manager_release_error(connection_manager: MockConnectionManager): 559 | class TestException(Exception): 560 | pass 561 | connection_manager.on_release_err = TestException() 562 | 563 | pool = ConnectionPool[int, MockConnection](connection_manager, idle_timeout=0.0, min_idle=0) 564 | with pytest.raises(TestException): 565 | async with pool.connection(endpoint=1): 566 | pass 567 | 568 | assert pool.get_size() == 1 569 | assert await pool.get_endpoint_pool_size(endpoint=1, 
acquired=False) == 1 570 | 571 | @pytest.mark.timeout(5.0) 572 | async def test_pool_connection_manager_aliveness_error(delay: float, connection_manager: MockConnectionManager): 573 | class TestException(Exception): 574 | pass 575 | connection_manager.check_aliveness_err = TestException() 576 | 577 | pool = ConnectionPool[int, MockConnection](connection_manager, idle_timeout=delay, min_idle=1) 578 | async with pool.connection(endpoint=1): 579 | pass 580 | 581 | with pytest.raises(TestException): 582 | async with pool.connection(endpoint=1): 583 | pass 584 | 585 | assert pool.get_size() == 1 586 | assert await pool.get_endpoint_pool_size(endpoint=1, acquired=False) == 1 587 | 588 | 589 | @pytest.mark.timeout(5.0) 590 | async def test_pool_connection_manager_dead_connection_error(delay: float, connection_manager: MockConnectionManager): 591 | class TestException(Exception): 592 | pass 593 | connection_manager.on_connection_dead_err = TestException() 594 | 595 | pool = ConnectionPool[int, MockConnection](connection_manager, idle_timeout=delay, min_idle=1) 596 | async with pool.connection(endpoint=1) as conn: 597 | pass 598 | 599 | connection_manager.aliveness[conn] = False 600 | with pytest.raises(TestException): 601 | async with pool.connection(endpoint=1): 602 | pass 603 | 604 | assert pool.get_size() == 0 605 | assert await pool.get_endpoint_pool_size(endpoint=1) == 0 606 | 607 | 608 | @pytest.mark.timeout(5.0) 609 | async def test_pool_connection_manager_acquire_error(delay: float, connection_manager: MockConnectionManager): 610 | class TestException(Exception): 611 | pass 612 | connection_manager.on_acquire_err = TestException() 613 | 614 | pool = ConnectionPool[int, MockConnection](connection_manager, idle_timeout=delay, min_idle=1) 615 | with pytest.raises(TestException): 616 | async with pool.connection(endpoint=1): 617 | pass 618 | 619 | assert pool.get_size() == 1 620 | assert await pool.get_endpoint_pool_size(endpoint=1, acquired=False) == 1 621 | 622 | 623 | @pytest.mark.timeout(5.0) 624 | async def test_pool_connection_manager_dispose_error(connection_manager: MockConnectionManager): 625 | class TestException(Exception): 626 | pass 627 | connection_manager.dispose_err = TestException() 628 | 629 | pool = ConnectionPool[int, MockConnection](connection_manager, idle_timeout=0.0, min_idle=0) 630 | async with pool.connection(endpoint=1): 631 | pass 632 | 633 | assert pool.get_size() == 0 634 | assert await pool.get_endpoint_pool_size(endpoint=1, acquired=False) == 0 635 | -------------------------------------------------------------------------------- /generic_connection_pool/threading/pool.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import contextlib 3 | import logging 4 | import math 5 | import threading 6 | from collections import defaultdict 7 | from typing import Any, Callable, DefaultDict, Generator, Generic, Hashable, List, Optional, Tuple, TypeVar 8 | 9 | from generic_connection_pool import exceptions 10 | from generic_connection_pool.common import BaseConnectionPool, BaseEndpointPool, BaseEventQueue, ConnectionInfo 11 | from generic_connection_pool.common import EventType, Timer 12 | 13 | from .locks import SharedLock 14 | 15 | logger = logging.getLogger(__package__) 16 | 17 | EndpointT = TypeVar('EndpointT', bound=Hashable) 18 | ConnectionT = TypeVar('ConnectionT', bound=Hashable) 19 | 20 | 21 | class BaseConnectionManager(Generic[EndpointT, ConnectionT], abc.ABC): 22 | """ 23 | Abstract synchronous connection manager.
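
    Example (a minimal sketch of a hypothetical TCP connection manager; the
    endpoint type is assumed to be an ``(address, port)`` tuple):

    .. code-block:: python

        import socket

        class TcpConnectionManager(BaseConnectionManager[tuple, socket.socket]):
            def create(self, endpoint, timeout=None):
                # establish a TCP connection to the endpoint
                return socket.create_connection(endpoint, timeout=timeout)

            def dispose(self, endpoint, conn, timeout=None):
                # close the socket; disposal errors are logged by the pool
                conn.close()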
24 | """ 25 | 26 | @abc.abstractmethod 27 | def create(self, endpoint: EndpointT, timeout: Optional[float] = None) -> ConnectionT: 28 | """ 29 | Creates a new connection. 30 | 31 | :param endpoint: endpoint to connect to 32 | :param timeout: operation timeout 33 | :return: new connection 34 | """ 35 | 36 | @abc.abstractmethod 37 | def dispose(self, endpoint: EndpointT, conn: ConnectionT, timeout: Optional[float] = None) -> None: 38 | """ 39 | Disposes the connection. 40 | 41 | :param endpoint: endpoint to connect to 42 | :param conn: connection to be disposed 43 | :param timeout: operation timeout 44 | """ 45 | 46 | def check_aliveness(self, endpoint: EndpointT, conn: ConnectionT, timeout: Optional[float] = None) -> bool: 47 | """ 48 | Checks that the connection is alive. 49 | 50 | :param endpoint: endpoint to connect to 51 | :param conn: connection to be checked 52 | :param timeout: operation timeout 53 | :return: ``True`` if the connection is alive, otherwise ``False`` 54 | """ 55 | 56 | return True 57 | 58 | def on_acquire(self, endpoint: EndpointT, conn: ConnectionT) -> None: 59 | """ 60 | Callback invoked on connection acquire. 61 | 62 | :param endpoint: endpoint to connect to 63 | :param conn: connection to be acquired 64 | """ 65 | 66 | def on_release(self, endpoint: EndpointT, conn: ConnectionT) -> None: 67 | """ 68 | Callback invoked on connection release. 69 | 70 | :param endpoint: endpoint to connect to 71 | :param conn: connection to be released 72 | """ 73 | 74 | def on_connection_dead(self, endpoint: EndpointT, conn: ConnectionT) -> None: 75 | """ 76 | Callback invoked when the connection aliveness check fails. 77 | 78 | :param endpoint: endpoint to connect to 79 | :param conn: dead connection 80 | """ 81 | 82 | 83 | KeyType = TypeVar('KeyType', bound=Hashable) 84 | 85 | 86 | class EventQueue(BaseEventQueue[KeyType], Generic[KeyType]): 87 | """ 88 | Thread-safe event queue wrapper. 89 | """ 90 | 91 | def __init__(self) -> None: 92 | super().__init__() 93 | self._lock = threading.Condition(threading.Lock()) 94 | self._stopped = False 95 | 96 | def get_size(self) -> int: 97 | with self._lock: 98 | return len(self._queue) 99 | 100 | def insert(self, timestamp: float, key: KeyType) -> None: 101 | with self._lock: 102 | self._insert(timestamp, key) 103 | self._lock.notify_all() 104 | 105 | def remove(self, key: KeyType) -> None: 106 | with self._lock: 107 | self._remove(key) 108 | 109 | def clear(self) -> None: 110 | with self._lock: 111 | self._clear() 112 | 113 | def wait(self, timeout: Optional[float] = None) -> KeyType: 114 | """ 115 | Waits for the next event. The event is not removed from the queue. 116 | """ 117 | 118 | timer = Timer(timeout) 119 | 120 | with self._lock: 121 | while True: 122 | if self._stopped: 123 | raise exceptions.ConnectionPoolClosedError 124 | 125 | key, backoff = self._try_get_next_event() 126 | if key is not None: 127 | return key 128 | elif timer.timedout: 129 | raise TimeoutError 130 | else: 131 | self._lock.wait( 132 | timeout=min(backoff, timer.remains) 133 | if backoff is not None and timer.remains is not None 134 | else backoff or timer.remains, 135 | ) 136 | 137 | def top(self) -> Optional[KeyType]: 138 | """ 139 | Returns the top event. 140 | """ 141 | 142 | with self._lock: 143 | return self._top() 144 | 145 | def stop(self) -> None: 146 | """ 147 | Notifies the subscribers that the queue is stopped.
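Wakes up all threads blocked in :py:meth:`wait` so that they can observe the stop and raise :py:class:`~generic_connection_pool.exceptions.ConnectionPoolClosedError`.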
148 | """ 149 | 150 | with self._lock: 151 | self._stopped = True 152 | self._lock.notify_all() 153 | 154 | 155 | class EndpointPool(BaseEndpointPool[EndpointT, ConnectionT], Generic[EndpointT, ConnectionT]): 156 | """ 157 | Thread-safe endpoint pool wrapper. 158 | """ 159 | 160 | def __init__(self, *args: Any, **kwargs: Any) -> None: 161 | super().__init__(*args, **kwargs) 162 | self._condvar = threading.Condition(threading.Lock()) 163 | 164 | @property 165 | def empty(self) -> bool: 166 | return self._size() == 0 167 | 168 | def size(self) -> int: 169 | with self._condvar: 170 | return self._size() 171 | 172 | def has_available_slot(self) -> bool: 173 | with self._condvar: 174 | return self._has_available_slot() 175 | 176 | def is_overflowed(self) -> bool: 177 | with self._condvar: 178 | return self._is_overflowed() 179 | 180 | def get_size(self, acquired: Optional[bool] = None) -> int: 181 | with self._condvar: 182 | return self._get_size(acquired) 183 | 184 | def reserve(self) -> bool: 185 | with self._condvar: 186 | return self._reserve() 187 | 188 | def acquire(self) -> Tuple[Optional[ConnectionInfo[EndpointT, ConnectionT]], bool]: 189 | with self._condvar: 190 | return self._acquire() 191 | 192 | def release(self, conn: ConnectionT) -> Tuple[ConnectionInfo[EndpointT, ConnectionT], bool]: 193 | with self._condvar: 194 | result = self._release(conn) 195 | self._condvar.notify() 196 | 197 | return result 198 | 199 | def detach(self, conn: ConnectionT, acquired: bool = False) -> ConnectionInfo[EndpointT, ConnectionT]: 200 | with self._condvar: 201 | result = self._detach(conn, acquired) 202 | self._condvar.notify() 203 | 204 | return result 205 | 206 | def attach(self, conn_info: ConnectionInfo[EndpointT, ConnectionT], acquired: bool = False) -> None: 207 | with self._condvar: 208 | self._attach(conn_info, acquired) 209 | self._condvar.notify() 210 | 211 | def acquire_and_detach(self) -> Optional[ConnectionInfo[EndpointT, ConnectionT]]: 212 | with self._condvar: 213 | conn_info, extra = self._acquire() 214 | if conn_info is None: 215 | return None 216 | 217 | result = self._detach(conn_info.conn, acquired=True) 218 | self._condvar.notify() 219 | 220 | return result 221 | 222 | def try_acquire_or_reserve( 223 | self, 224 | timeout: Optional[float] = None, 225 | ) -> Tuple[Optional[ConnectionInfo[EndpointT, ConnectionT]], bool]: 226 | timer = Timer(timeout) 227 | 228 | with self._condvar: 229 | while True: 230 | conn_info, extra = self._acquire() 231 | if conn_info is not None: 232 | return conn_info, extra 233 | 234 | elif self._reserve(): 235 | return None, False 236 | 237 | elif not self._condvar.wait(timer.remains): 238 | raise TimeoutError 239 | 240 | def attach_and_unreserve(self, conn_info: ConnectionInfo[EndpointT, ConnectionT], acquired: bool = False) -> None: 241 | with self._condvar: 242 | self._unreserve() 243 | self._attach(conn_info, acquired) 244 | self._condvar.notify() 245 | 246 | def unreserve(self) -> None: 247 | with self._condvar: 248 | self._unreserve() 249 | self._condvar.notify() 250 | 251 | 252 | class PoolManager(Generic[EndpointT, ConnectionT]): 253 | """ 254 | Connection pool manager. 255 | Provides an API for working with connection pools safely.
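
    Example (a minimal sketch; ``manager`` is assumed to be a ``PoolManager`` instance):

    .. code-block:: python

        with manager.acquired(endpoint, setdefault=True) as pool:
            conn_info, extra = pool.acquire()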
256 | """ 257 | 258 | def __init__(self, pool_factory: Callable[[], EndpointPool[EndpointT, ConnectionT]]) -> None: 259 | self._pools: DefaultDict[EndpointT, Tuple[SharedLock, EndpointPool[EndpointT, ConnectionT]]] = defaultdict( 260 | lambda: (SharedLock(), pool_factory()), 261 | ) 262 | self._condvar = threading.Condition(lock=threading.Lock()) 263 | 264 | def get_size(self) -> int: 265 | with self._condvar: 266 | return sum(pool.size() for lock, pool in self._pools.values()) 267 | 268 | def endpoints(self) -> List[EndpointT]: 269 | """ 270 | Returns available endpoints. 271 | """ 272 | 273 | with self._condvar: 274 | return list(self._pools.keys()) 275 | 276 | def wait_for(self, predicate: Callable[[], bool], timeout: Optional[float] = None) -> bool: 277 | """ 278 | Waits for the pool manager state change. 279 | """ 280 | 281 | with self._condvar: 282 | return self._condvar.wait_for(predicate, timeout=timeout) 283 | 284 | @contextlib.contextmanager 285 | def acquired( 286 | self, 287 | endpoint: EndpointT, 288 | exclusive: bool = False, 289 | blocking: bool = True, 290 | timeout: Optional[float] = None, 291 | setdefault: bool = False, 292 | ) -> Generator[EndpointPool[EndpointT, ConnectionT], None, None]: 293 | """ 294 | Opens the endpoint pool acquiring context. 295 | 296 | :param endpoint: pool endpoint 297 | :param exclusive: pool access mode (shared or exclusive) 298 | :param blocking: whether to block waiting for the pool lock 299 | :param timeout: pool acquiring timeout 300 | :param setdefault: create a new pool if it does not exist 301 | 302 | :return: acquired pool 303 | """ 304 | 305 | with self._condvar: 306 | if (lock_and_pool := self._pools[endpoint] if setdefault else self._pools.get(endpoint)) is None: 307 | raise exceptions.ConnectionPoolNotFound 308 | 309 | lock, pool = lock_and_pool 310 | if not lock.acquire(exclusive, blocking=blocking, timeout=timeout): 311 | raise TimeoutError 312 | 313 | try: 314 | yield pool 315 | finally: 316 | lock.release(exclusive) 317 | with self._condvar: 318 | self._condvar.notify() 319 | 320 | def try_delete(self, endpoint: EndpointT) -> bool: 321 | """ 322 | Tries to delete the endpoint pool. 323 | Acquires the pool in exclusive mode and checks that the pool is empty. 324 | 325 | :param endpoint: pool endpoint 326 | 327 | :return: `True` if the pool has been deleted, otherwise `False` 328 | """ 329 | 330 | with self._condvar: 331 | if (lock_and_pool := self._pools.get(endpoint)) is None: 332 | return True 333 | else: 334 | lock, pool = lock_and_pool 335 | 336 | try: 337 | with lock.acquired(exclusive=True, blocking=False): 338 | if not pool.empty: 339 | return False 340 | else: 341 | self._pools.pop(endpoint) 342 | return True 343 | except TimeoutError: 344 | return False 345 | 346 | 347 | class ConnectionPool(Generic[EndpointT, ConnectionT], BaseConnectionPool[EndpointT, ConnectionT]): 348 | """ 349 | Synchronous connection pool. 350 | 351 | :param connection_manager: connection manager instance. Used to create and dispose connections and to check their aliveness. 352 | :param acquire_timeout: connection acquiring default timeout. 353 | :param background_collector: if ``True`` starts a background worker that disposes expired and idle connections 354 | maintaining requested pool state. If ``False`` the connections will be disposed 355 | on each connection release. 356 | :param dispose_batch_size: maximum number of expired and idle connections to be disposed on connection release 357 | (if background collector is started the parameter is ignored).
358 | :param dispose_timeout: connection disposal timeout. 359 | :param min_idle: minimum number of connections the pool tries to hold for each endpoint. Connections that exceed 360 | that number will be considered extra and disposed after ``idle_timeout`` seconds of inactivity. 361 | :param max_size: maximum number of connections per endpoint. 362 | :param kwargs: see :py:class:`generic_connection_pool.common.BaseConnectionPool` 363 | """ 364 | 365 | def __init__( 366 | self, 367 | connection_manager: BaseConnectionManager[EndpointT, ConnectionT], 368 | *, 369 | acquire_timeout: Optional[float] = None, 370 | background_collector: bool = False, 371 | dispose_batch_size: int = 0, 372 | dispose_timeout: Optional[float] = None, 373 | min_idle: int = 1, 374 | max_size: int = 10, 375 | **kwargs: Any, 376 | ): 377 | super().__init__(min_idle=min_idle, max_size=max_size, **kwargs) 378 | 379 | self._stopped = False 380 | self._acquire_timeout = acquire_timeout 381 | self._dispose_batch_size = dispose_batch_size 382 | self._dispose_timeout = dispose_timeout 383 | self._connection_manager = connection_manager 384 | 385 | self._lock = threading.Lock() 386 | 387 | self._pools = PoolManager( 388 | pool_factory=lambda: EndpointPool[EndpointT, ConnectionT]( 389 | max_pool_size=min_idle, 390 | max_extra_size=max_size - min_idle, 391 | ), 392 | ) 393 | self._event_queue = EventQueue[Tuple[EventType, EndpointT, ConnectionT]]() 394 | 395 | self._collector: Optional[threading.Thread] = None 396 | if background_collector: 397 | self._collector = threading.Thread( 398 | target=self._start_collector, 399 | name='gcp-collector', 400 | ) 401 | self._collector.start() 402 | 403 | def get_size(self) -> int: 404 | with self._lock: 405 | return super().get_size() 406 | 407 | def get_endpoint_pool_size(self, endpoint: EndpointT, acquired: Optional[bool] = None) -> int: 408 | """ 409 | Returns endpoint pool size. 410 | 411 | :param endpoint: pool endpoint 412 | :param acquired: if `True` returns the number of acquired connections, 413 | if `False` returns the number of free connections, 414 | otherwise returns the total size 415 | """ 416 | 417 | try: 418 | with self._pools.acquired(endpoint) as pool: 419 | return pool.get_size(acquired) 420 | except exceptions.ConnectionPoolNotFound: 421 | return 0 422 | 423 | @contextlib.contextmanager 424 | def connection(self, endpoint: EndpointT, timeout: Optional[float] = None) -> Generator[ConnectionT, None, None]: 425 | """ 426 | Acquires a connection from the pool. 427 | 428 | :param endpoint: connection endpoint 429 | :param timeout: number of seconds to wait. If timeout is reached :py:class:`TimeoutError` is raised. 430 | :return: acquired connection 431 | """ 432 | 433 | conn = self.acquire(endpoint, timeout=timeout) 434 | try: 435 | yield conn 436 | finally: 437 | self.release(conn, endpoint) 438 | 439 | def acquire(self, endpoint: EndpointT, timeout: Optional[float] = None) -> ConnectionT: 440 | """ 441 | Acquires a connection from the pool. 442 | 443 | :param endpoint: connection endpoint 444 | :param timeout: number of seconds to wait. If timeout is reached :py:class:`TimeoutError` is raised.
441 |     def acquire(self, endpoint: EndpointT, timeout: Optional[float] = None) -> ConnectionT:
442 |         """
443 |         Acquires a connection from the pool.
444 |
445 |         :param endpoint: connection endpoint
446 |         :param timeout: number of seconds to wait. If the timeout is reached a :py:class:`TimeoutError` is raised.
447 |         :return: acquired connection
448 |         """
449 |
450 |         timeout = self._acquire_timeout if timeout is None else timeout
451 |
452 |         conn = self._acquire_connection(endpoint, timeout=timeout)
453 |         try:
454 |             self._connection_manager.on_acquire(endpoint, conn)
455 |         except Exception:
456 |             self._release_connection(endpoint, conn)
457 |             raise
458 |
459 |         return conn
460 |
461 |     def release(self, conn: ConnectionT, endpoint: EndpointT) -> None:
462 |         """
463 |         Releases a connection.
464 |
465 |         :param conn: connection to be released
466 |         :param endpoint: connection endpoint
467 |         """
468 |
469 |         try:
470 |             self._connection_manager.on_release(endpoint, conn)
471 |         finally:
472 |             self._release_connection(endpoint, conn)
473 |
474 |         if self._collector is None:
475 |             dispose_batch_size = self._dispose_batch_size or int(math.log2(self._pool_size + 1)) + 1
476 |             self._collect_disposable_connections(dispose_batch_size)
477 |
478 |     def close(self, timeout: Optional[float] = None) -> None:
479 |         """
480 |         Closes the connection pool.
481 |
482 |         :param timeout: timeout after which the pool closes all connections whether they have been released or not
483 |         """
484 |
485 |         timer = Timer(timeout)
486 |
487 |         self._stopped = True
488 |         self._event_queue.stop()
489 |
490 |         if self._collector is not None:
491 |             self._collector.join(timeout=timer.remains)
492 |
493 |         self._close_connections(timeout=timer.remains)
494 |         self._event_queue.clear()
495 |
496 |     def _acquire_connection(self, endpoint: EndpointT, timeout: Optional[float]) -> ConnectionT:
497 |         timer = Timer(timeout)
498 |
499 |         while True:
500 |             if self._stopped:
501 |                 raise exceptions.ConnectionPoolClosedError
502 |
503 |             with self._pools.acquired(endpoint, setdefault=True) as pool:
504 |                 conn_info, extra = pool.try_acquire_or_reserve(timeout=timer.remains)
505 |                 if conn_info is not None:
506 |                     # unsubscribe the connection since an acquired connection can't be disposed of
507 |                     self._event_queue.remove((EventType.LIFETIME, conn_info.endpoint, conn_info.conn))
508 |                     if extra:
509 |                         self._event_queue.remove((EventType.IDLETIME, conn_info.endpoint, conn_info.conn))
510 |
511 |                     try:
512 |                         is_alive = self._connection_manager.check_aliveness(endpoint, conn_info.conn, timer.remains)
513 |                     except Exception:
514 |                         self._release_connection(endpoint, conn_info.conn)
515 |                         raise
516 |
517 |                     if not is_alive:
518 |                         pool.detach(conn_info.conn, acquired=True)
519 |                         self._decrease_pool_size()
520 |                         self._connection_manager.on_connection_dead(endpoint, conn_info.conn)
521 |                         continue
522 |                 else:
523 |                     try:
524 |                         if conn_info := self._create_connection(endpoint, timer.remains):
525 |                             pool.attach_and_unreserve(conn_info, acquired=True)
526 |                         else:
527 |                             pool.unreserve()
528 |                             continue
529 |                     except Exception:
530 |                         pool.unreserve()
531 |                         raise
532 |
533 |                     logger.debug("connection created: %s", endpoint)
534 |
535 |                 return conn_info.conn
536 |
537 |     def _create_connection(
538 |         self,
539 |         endpoint: EndpointT,
540 |         timeout: Optional[float] = None,
541 |     ) -> ConnectionInfo[EndpointT, ConnectionT]:
542 |         timer = Timer(timeout)
543 |
544 |         while True:
545 |             if self._increase_pool_size():
546 |                 try:
547 |                     conn = self._connection_manager.create(endpoint, timeout=timer.remains)
548 |                 except Exception:
549 |                     self._decrease_pool_size()
550 |                     raise
551 |
552 |                 return ConnectionInfo(endpoint, conn)
553 |
554 |             else:
555 |                 self._try_free_slot(timeout=timer.remains)
556 |
557 |     def _try_free_slot(self, timeout: Optional[float] = None) -> bool:
558 |         timer = Timer(timeout)
559 |
560 |         if event := self._event_queue.top():
561 |             ev, endpoint, conn = event
562 |             return self._try_detach_connection(endpoint, conn, timeout=timer.remains)
563 |
564 |         elif not self._pools.wait_for(predicate=lambda: not self.is_full, timeout=timer.remains):
565 |             raise TimeoutError
566 |
567 |         return True
568 |
569 |     def _collect_disposable_connections(self, max_disposals: int) -> None:
570 |         disposals = 0
571 |
572 |         while disposals < max_disposals:
573 |             try:
574 |                 ev, endpoint, conn = self._event_queue.wait(timeout=0.0)
575 |             except TimeoutError:
576 |                 # no connections to dispose of
577 |                 break
578 |
579 |             if self._try_detach_connection(endpoint, conn):
580 |                 disposals += 1
581 |
582 |         if disposals > 0:
583 |             logger.debug("disposed %d connections", disposals)
584 |
585 |     def _start_collector(self) -> None:
586 |         logger.debug("collector started")
587 |
588 |         while not self._stopped:
589 |             try:
590 |                 ev, endpoint, conn = self._event_queue.wait()
591 |                 self._try_detach_connection(endpoint, conn)
592 |             except exceptions.ConnectionPoolClosedError:
593 |                 break
594 |
595 |     def _try_detach_connection(self, endpoint: EndpointT, conn: ConnectionT, timeout: Optional[float] = None) -> bool:
596 |         with contextlib.suppress(exceptions.ConnectionPoolNotFound):
597 |             with self._pools.acquired(endpoint, timeout=timeout) as pool:
598 |                 try:
599 |                     conn_info = pool.detach(conn)
600 |                 except KeyError:
601 |                     return False
602 |                 finally:
603 |                     self._event_queue.remove((EventType.LIFETIME, endpoint, conn))
604 |                     self._event_queue.remove((EventType.IDLETIME, endpoint, conn))
605 |
606 |                 self._dispose_connection(conn_info, timeout=self._dispose_timeout)
607 |                 self._decrease_pool_size()
608 |                 is_pool_empty = pool.empty
609 |
610 |             if is_pool_empty:
611 |                 self._pools.try_delete(endpoint)
612 |
613 |             return True
614 |
615 |     def _dispose_connection(self, conn_info: ConnectionInfo[EndpointT, ConnectionT], timeout: Optional[float]) -> bool:
616 |         try:
617 |             self._connection_manager.dispose(conn_info.endpoint, conn_info.conn, timeout=timeout)
618 |         except TimeoutError:
619 |             logger.error("connection disposal timed out: %s", conn_info.endpoint)
620 |             return False
621 |         except Exception as e:
622 |             logger.error("connection disposal failed: %s", e)
623 |             return False
624 |
625 |         logger.debug("connection disposed: %s", conn_info.endpoint)
626 |         return True
627 |
628 |     def _release_connection(self, endpoint: EndpointT, conn: ConnectionT) -> None:
629 |         with self._pools.acquired(endpoint) as pool:
630 |             try:
631 |                 conn_info, extra = pool.release(conn)
632 |             except KeyError:
633 |                 raise RuntimeError("connection not acquired")
634 |
635 |             # subscribe the connection: it is disposable again
636 |             self._event_queue.insert(
637 |                 conn_info.created_at + self.max_lifetime,
638 |                 (EventType.LIFETIME, conn_info.endpoint, conn_info.conn),
639 |             )
640 |             if extra:
641 |                 self._event_queue.insert(
642 |                     conn_info.accessed_at + self._idle_timeout,
643 |                     (EventType.IDLETIME, conn_info.endpoint, conn_info.conn),
644 |                 )
645 |
646 |     def _close_connections(self, timeout: Optional[float]) -> None:
647 |         timer = Timer(timeout)
648 |
649 |         while self._pools.get_size() != 0:
650 |             for endpoint in self._pools.endpoints():
651 |                 with contextlib.suppress(exceptions.ConnectionPoolNotFound):
652 |                     with self._pools.acquired(endpoint, timeout=timer.remains) as pool:
653 |                         conn_info = pool.acquire_and_detach()
654 |                         if conn_info is None:
655 |                             continue
656 |
657 |                         self._dispose_connection(conn_info, timeout=timer.remains)
658 |                         self._decrease_pool_size()
659 |                         is_pool_empty = pool.empty
660 |
661 |                     if is_pool_empty:
662 |                         self._pools.try_delete(endpoint)
663 |
664 |     def _increase_pool_size(self) -> bool:
665 |         with self._lock:
666 |             if self.is_full:
667 |                 return False
668 |             else:
669 |                 self._pool_size += 1
670 |                 return True
671 |
672 |     def _decrease_pool_size(self) -> None:
673 |         with self._lock:
674 |             self._pool_size -= 1
675 |
--------------------------------------------------------------------------------
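For reference, a minimal end-to-end sketch of the synchronous pool defined above. The TcpConnectionManager here is hypothetical: it is written against the manager hooks this module calls (create, dispose, check_aliveness, on_acquire, on_release, on_connection_dead) and assumes BaseConnectionManager provides usable defaults for the hooks it does not override:

    import socket
    from typing import Optional, Tuple

    from generic_connection_pool.threading import BaseConnectionManager, ConnectionPool

    Endpoint = Tuple[str, int]  # (host, port)


    class TcpConnectionManager(BaseConnectionManager[Endpoint, socket.socket]):
        """Hypothetical manager: opens one TCP socket per pooled connection."""

        def create(self, endpoint: Endpoint, timeout: Optional[float] = None) -> socket.socket:
            return socket.create_connection(endpoint, timeout=timeout)

        def dispose(self, endpoint: Endpoint, conn: socket.socket, timeout: Optional[float] = None) -> None:
            conn.close()


    pool = ConnectionPool[Endpoint, socket.socket](
        TcpConnectionManager(),
        acquire_timeout=5.0,        # default timeout for acquire()/connection()
        background_collector=True,  # dispose of expired/idle connections in a 'gcp-collector' thread
        min_idle=1,
        max_size=10,
    )

    try:
        with pool.connection(('example.com', 80)) as sock:
            sock.sendall(b'GET / HTTP/1.1\r\nHost: example.com\r\n\r\n')
            print(sock.recv(1024))
    finally:
        pool.close(timeout=10.0)

With background_collector=False (the default) no collector thread is started and disposal instead happens in batches on each release().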