├── .github └── workflows │ └── test.yml ├── .gitignore ├── LICENSE ├── README.md ├── pyproject.toml ├── src └── pg8000 │ ├── __init__.py │ ├── converters.py │ ├── core.py │ ├── dbapi.py │ ├── exceptions.py │ ├── legacy.py │ ├── native.py │ └── types.py └── test ├── __init__.py ├── dbapi ├── __init__.py ├── auth │ ├── __init__.py │ ├── test_gss.py │ ├── test_md5.py │ ├── test_md5_ssl.py │ ├── test_password.py │ ├── test_scram-sha-256.py │ └── test_scram-sha-256_ssl.py ├── conftest.py ├── test_benchmarks.py ├── test_connection.py ├── test_copy.py ├── test_dbapi.py ├── test_dbapi20.py ├── test_paramstyle.py ├── test_query.py └── test_typeconversion.py ├── legacy ├── __init__.py ├── auth │ ├── __init__.py │ ├── test_gss.py │ ├── test_md5.py │ ├── test_md5_ssl.py │ ├── test_password.py │ ├── test_scram-sha-256.py │ └── test_scram-sha-256_ssl.py ├── conftest.py ├── stress.py ├── test_benchmarks.py ├── test_connection.py ├── test_copy.py ├── test_dbapi.py ├── test_dbapi20.py ├── test_error_recovery.py ├── test_paramstyle.py ├── test_prepared_statement.py ├── test_query.py ├── test_typeconversion.py └── test_typeobjects.py ├── native ├── __init__.py ├── auth │ ├── __init__.py │ ├── test_gss.py │ ├── test_md5.py │ ├── test_md5_ssl.py │ ├── test_password.py │ ├── test_scram-sha-256.py │ └── test_scram-sha-256_ssl.py ├── conftest.py ├── test_benchmarks.py ├── test_connection.py ├── test_copy.py ├── test_core.py ├── test_import_all.py ├── test_prepared_statement.py ├── test_query.py └── test_typeconversion.py ├── test_converters.py ├── test_readme.py ├── test_types.py ├── test_utils.py └── utils.py /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: pg8000 2 | 3 | on: [push] 4 | 5 | permissions: read-all 6 | 7 | jobs: 8 | main-test: 9 | 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] 14 | postgresql-version: [16, 15, 14, 13, 12] 15 | 16 | 
container: 17 | image: python:${{ matrix.python-version }} 18 | env: 19 | PGHOST: postgres 20 | PGPASSWORD: postgres 21 | PGUSER: postgres 22 | 23 | services: 24 | postgres: 25 | image: postgres:${{ matrix.postgresql-version }} 26 | env: 27 | POSTGRES_PASSWORD: postgres 28 | # Set health checks to wait until postgres has started 29 | options: >- 30 | --health-cmd pg_isready 31 | --health-interval 10s 32 | --health-timeout 5s 33 | --health-retries 5 34 | 35 | steps: 36 | - uses: actions/checkout@v4 37 | - name: Install dependencies 38 | run: | 39 | git config --global --add safe.directory "$GITHUB_WORKSPACE" 40 | python -m pip install --no-cache-dir --upgrade pip 41 | pip install --no-cache-dir --root-user-action ignore pytest pytest-mock pytest-benchmark pytz . 42 | - name: Set up Postgresql 43 | run: | 44 | apt-get update 45 | apt-get install --yes --no-install-recommends postgresql-client 46 | psql -c "CREATE EXTENSION hstore;" 47 | psql -c "SELECT pg_reload_conf()" 48 | - name: Test with pytest 49 | run: | 50 | python -m pytest -x -v -W error --ignore=test/dbapi/auth/ --ignore=test/legacy/auth/ --ignore=test/native/auth/ test --ignore=test/test_readme.py 51 | 52 | auth-test: 53 | runs-on: ubuntu-latest 54 | strategy: 55 | matrix: 56 | python-version: ["3.11"] 57 | postgresql-version: ["16", "12"] 58 | auth-type: [md5, gss, password, scram-sha-256] 59 | 60 | container: 61 | image: python:${{ matrix.python-version }} 62 | env: 63 | PGHOST: postgres 64 | PGPASSWORD: postgres 65 | PGUSER: postgres 66 | PIP_ROOT_USER_ACTION: ignore 67 | 68 | services: 69 | postgres: 70 | image: postgres:${{ matrix.postgresql-version }} 71 | env: 72 | POSTGRES_PASSWORD: postgres 73 | POSTGRES_HOST_AUTH_METHOD: ${{ matrix.auth-type }} 74 | POSTGRES_INITDB_ARGS: "${{ matrix.auth-type == 'scram-sha-256' && '--auth-host=scram-sha-256' || '' }}" 75 | # Set health checks to wait until postgres has started 76 | options: >- 77 | --health-cmd pg_isready 78 | --health-interval 10s 79 | 
--health-timeout 5s 80 | --health-retries 5 81 | 82 | steps: 83 | - uses: actions/checkout@v4 84 | - name: Install dependencies 85 | run: | 86 | git config --global --add safe.directory "$GITHUB_WORKSPACE" 87 | python -m pip install --no-cache-dir --upgrade pip 88 | pip install --no-cache-dir pytest pytest-mock pytest-benchmark pytz . 89 | - name: Test with pytest 90 | run: | 91 | python -m pytest -x -v -W error test/dbapi/auth/test_${{ matrix.auth-type }}.py test/native/auth/test_${{ matrix.auth-type }}.py test/legacy/auth/test_${{ matrix.auth-type }}.py 92 | 93 | ssl-test: 94 | runs-on: ubuntu-latest 95 | 96 | strategy: 97 | matrix: 98 | auth-type: [md5, scram-sha-256] 99 | 100 | services: 101 | postgres: 102 | image: postgres 103 | env: 104 | POSTGRES_PASSWORD: postgres 105 | POSTGRES_HOST_AUTH_METHOD: ${{ matrix.auth-type }} 106 | POSTGRES_INITDB_ARGS: "${{ matrix.auth-type == 'scram-sha-256' && '--auth-host=scram-sha-256' || '' }}" 107 | options: >- 108 | --health-cmd pg_isready 109 | --health-interval 10s 110 | --health-timeout 5s 111 | --health-retries 5 112 | ports: 113 | - 5432:5432 114 | 115 | steps: 116 | - name: Configure Postgres 117 | env: 118 | PGPASSWORD: postgres 119 | PGUSER: postgres 120 | PGHOST: localhost 121 | run: | 122 | sudo apt update 123 | sudo apt install --yes --no-install-recommends postgresql-client 124 | psql -c "ALTER SYSTEM SET ssl = on;" 125 | psql -c "ALTER SYSTEM SET ssl_cert_file = '/etc/ssl/certs/ssl-cert-snakeoil.pem'" 126 | psql -c "ALTER SYSTEM SET ssl_key_file = '/etc/ssl/private/ssl-cert-snakeoil.key'" 127 | psql -c "SELECT pg_reload_conf()" 128 | 129 | - name: Check out repository code 130 | uses: actions/checkout@v4 131 | - name: Set up Python 132 | uses: actions/setup-python@v5 133 | with: 134 | python-version: "3.11" 135 | - name: Install dependencies 136 | run: | 137 | python -m pip install --upgrade pip 138 | pip install pytest pytest-mock pytest-benchmark pytz . 
139 | - name: SSL Test 140 | env: 141 | PGPASSWORD: postgres 142 | USER: postgres 143 | run: | 144 | python -m pytest -x -v -W error test/dbapi/auth/test_${{ matrix.auth-type}}_ssl.py test/native/auth/test_${{ matrix.auth-type}}_ssl.py test/legacy/auth/test_${{ matrix.auth-type}}_ssl.py 145 | 146 | static-test: 147 | runs-on: ubuntu-latest 148 | 149 | services: 150 | postgres: 151 | image: postgres:12 152 | env: 153 | POSTGRES_PASSWORD: cpsnow 154 | options: >- 155 | --health-cmd pg_isready 156 | --health-interval 10s 157 | --health-timeout 5s 158 | --health-retries 5 159 | ports: 160 | - 5432:5432 161 | 162 | steps: 163 | - name: Check out repository code 164 | uses: actions/checkout@v4 165 | - name: Set up Python 166 | uses: actions/setup-python@v5 167 | with: 168 | python-version: "3.11" 169 | - name: Install dependencies 170 | run: | 171 | psql postgresql://postgres:cpsnow@localhost -c "ALTER SYSTEM SET ssl = on;" 172 | psql postgresql://postgres:cpsnow@localhost -c "ALTER SYSTEM SET ssl_cert_file = '/etc/ssl/certs/ssl-cert-snakeoil.pem'" 173 | psql postgresql://postgres:cpsnow@localhost -c "ALTER SYSTEM SET ssl_key_file = '/etc/ssl/private/ssl-cert-snakeoil.key'" 174 | psql postgresql://postgres:cpsnow@localhost localhost -c "SELECT pg_reload_conf()" 175 | python -m pip install --upgrade pip 176 | pip install black build flake8 pytest flake8-alphabetize Flake8-pyproject \ 177 | twine . 178 | - name: Lint check 179 | run: | 180 | black --check . 181 | flake8 . 
182 | - name: Doctest 183 | env: 184 | PGPASSWORD: cpsnow 185 | USER: postgres 186 | run: | 187 | python -m pytest -x -v -W error test/test_readme.py 188 | - name: Check Distribution 189 | run: | 190 | python -m build 191 | twine check dist/* 192 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.py[co] 2 | *.swp 3 | *.orig 4 | *.class 5 | build 6 | pg8000.egg-info 7 | tmp 8 | dist 9 | .tox 10 | MANIFEST 11 | venv 12 | .cache 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright Mathieu Fenniak and Contributors to pg8000. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "hatchling", 4 | "versioningit >= 3.1.2", 5 | ] 6 | build-backend = "hatchling.build" 7 | 8 | [project] 9 | name = "pg8000" 10 | authors = [{name = "The Contributors"}] 11 | description = "PostgreSQL interface library" 12 | readme = "README.md" 13 | requires-python = ">=3.8" 14 | keywords= ["postgresql", "dbapi"] 15 | license = {text = "BSD 3-Clause License"} 16 | classifiers = [ 17 | "Development Status :: 5 - Production/Stable", 18 | "Intended Audience :: Developers", 19 | "License :: OSI Approved :: BSD License", 20 | "Programming Language :: Python", 21 | "Programming Language :: Python :: 3", 22 | "Programming Language :: Python :: 3.8", 23 | "Programming Language :: Python :: 3.9", 24 | "Programming Language :: Python :: 3.10", 25 | "Programming Language :: Python :: 3.11", 26 | "Programming Language :: Python :: 3.12", 27 | "Programming Language :: Python :: Implementation", 28 | "Programming Language :: Python :: Implementation :: CPython", 29 | "Programming Language :: Python :: Implementation :: PyPy", 30 | "Operating System :: OS Independent", 31 | "Topic :: Database :: Front-Ends", 32 | "Topic :: Software Development :: Libraries :: Python Modules", 33 | ] 34 | dependencies = [ 35 | "scramp >= 
1.4.5", 36 | 'python-dateutil >= 2.8.2', 37 | ] 38 | dynamic = ["version"] 39 | 40 | [project.urls] 41 | Homepage = "https://github.com/tlocke/pg8000" 42 | 43 | [tool.hatch.version] 44 | source = "versioningit" 45 | 46 | [tool.versioningit] 47 | 48 | [tool.versioningit.vcs] 49 | method = "git" 50 | default-tag = "0.0.0" 51 | 52 | [tool.flake8] 53 | application-names = ['pg8000'] 54 | ignore = ['E203', 'W503'] 55 | max-line-length = 88 56 | exclude = ['.git', '__pycache__', 'build', 'dist', 'venv', '.tox'] 57 | application-import-names = ['pg8000'] 58 | 59 | [tool.tox] 60 | legacy_tox_ini = """ 61 | [tox] 62 | isolated_build = True 63 | envlist = py 64 | 65 | [testenv] 66 | passenv = PGPORT 67 | allowlist_externals=/usr/bin/rm 68 | commands = 69 | black --check . 70 | flake8 . 71 | python -m pytest -x -v -W error test 72 | rm -rf dist 73 | python -m build 74 | twine check dist/* 75 | deps = 76 | build 77 | pytest 78 | pytest-mock 79 | pytest-benchmark 80 | black 81 | flake8 82 | flake8-alphabetize 83 | Flake8-pyproject 84 | pytz 85 | twine 86 | """ 87 | -------------------------------------------------------------------------------- /src/pg8000/__init__.py: -------------------------------------------------------------------------------- 1 | from pg8000.legacy import ( 2 | BIGINTEGER, 3 | BINARY, 4 | BOOLEAN, 5 | BOOLEAN_ARRAY, 6 | BYTES, 7 | Binary, 8 | CHAR, 9 | CHAR_ARRAY, 10 | Connection, 11 | Cursor, 12 | DATE, 13 | DATETIME, 14 | DECIMAL, 15 | DECIMAL_ARRAY, 16 | DataError, 17 | DatabaseError, 18 | Date, 19 | DateFromTicks, 20 | Error, 21 | FLOAT, 22 | FLOAT_ARRAY, 23 | INET, 24 | INT2VECTOR, 25 | INTEGER, 26 | INTEGER_ARRAY, 27 | INTERVAL, 28 | IntegrityError, 29 | InterfaceError, 30 | InternalError, 31 | JSON, 32 | JSONB, 33 | MACADDR, 34 | NAME, 35 | NAME_ARRAY, 36 | NULLTYPE, 37 | NUMBER, 38 | NotSupportedError, 39 | OID, 40 | OperationalError, 41 | PGInterval, 42 | ProgrammingError, 43 | ROWID, 44 | Range, 45 | STRING, 46 | TEXT, 47 | TEXT_ARRAY, 48 | 
TIME, 49 | TIMEDELTA, 50 | TIMESTAMP, 51 | TIMESTAMPTZ, 52 | Time, 53 | TimeFromTicks, 54 | Timestamp, 55 | TimestampFromTicks, 56 | UNKNOWN, 57 | UUID_TYPE, 58 | VARCHAR, 59 | VARCHAR_ARRAY, 60 | Warning, 61 | XID, 62 | __version__, 63 | pginterval_in, 64 | pginterval_out, 65 | timedelta_in, 66 | ) 67 | 68 | # Copyright (c) 2007-2009, Mathieu Fenniak 69 | # Copyright (c) The Contributors 70 | # All rights reserved. 71 | # 72 | # Redistribution and use in source and binary forms, with or without 73 | # modification, are permitted provided that the following conditions are 74 | # met: 75 | # 76 | # * Redistributions of source code must retain the above copyright notice, 77 | # this list of conditions and the following disclaimer. 78 | # * Redistributions in binary form must reproduce the above copyright notice, 79 | # this list of conditions and the following disclaimer in the documentation 80 | # and/or other materials provided with the distribution. 81 | # * The name of the author may not be used to endorse or promote products 82 | # derived from this software without specific prior written permission. 83 | # 84 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 85 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 86 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 87 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 88 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 89 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 90 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 91 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 92 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 93 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 94 | # POSSIBILITY OF SUCH DAMAGE. 
95 | 96 | 97 | def connect( 98 | user, 99 | host="localhost", 100 | database=None, 101 | port=5432, 102 | password=None, 103 | source_address=None, 104 | unix_sock=None, 105 | ssl_context=None, 106 | timeout=None, 107 | tcp_keepalive=True, 108 | application_name=None, 109 | replication=None, 110 | startup_params=None, 111 | ): 112 | return Connection( 113 | user, 114 | host=host, 115 | database=database, 116 | port=port, 117 | password=password, 118 | source_address=source_address, 119 | unix_sock=unix_sock, 120 | ssl_context=ssl_context, 121 | timeout=timeout, 122 | tcp_keepalive=tcp_keepalive, 123 | application_name=application_name, 124 | replication=replication, 125 | startup_params=startup_params, 126 | ) 127 | 128 | 129 | apilevel = "2.0" 130 | """The DBAPI level supported, currently "2.0". 131 | 132 | This property is part of the `DBAPI 2.0 specification 133 | `_. 134 | """ 135 | 136 | threadsafety = 1 137 | """Integer constant stating the level of thread safety the DBAPI interface 138 | supports. This DBAPI module supports sharing of the module only. Connections 139 | and cursors my not be shared between threads. This gives pg8000 a threadsafety 140 | value of 1. 141 | 142 | This property is part of the `DBAPI 2.0 specification 143 | `_. 
144 | """ 145 | 146 | paramstyle = "format" 147 | 148 | 149 | __all__ = [ 150 | "BIGINTEGER", 151 | "BINARY", 152 | "BOOLEAN", 153 | "BOOLEAN_ARRAY", 154 | "BYTES", 155 | "Binary", 156 | "CHAR", 157 | "CHAR_ARRAY", 158 | "Connection", 159 | "Cursor", 160 | "DATE", 161 | "DATETIME", 162 | "DECIMAL", 163 | "DECIMAL_ARRAY", 164 | "DataError", 165 | "DatabaseError", 166 | "Date", 167 | "DateFromTicks", 168 | "Error", 169 | "FLOAT", 170 | "FLOAT_ARRAY", 171 | "INET", 172 | "INT2VECTOR", 173 | "INTEGER", 174 | "INTEGER_ARRAY", 175 | "INTERVAL", 176 | "IntegrityError", 177 | "InterfaceError", 178 | "InternalError", 179 | "JSON", 180 | "JSONB", 181 | "MACADDR", 182 | "NAME", 183 | "NAME_ARRAY", 184 | "NULLTYPE", 185 | "NUMBER", 186 | "NotSupportedError", 187 | "OID", 188 | "OperationalError", 189 | "PGInterval", 190 | "ProgrammingError", 191 | "ROWID", 192 | "Range", 193 | "STRING", 194 | "TEXT", 195 | "TEXT_ARRAY", 196 | "TIME", 197 | "TIMEDELTA", 198 | "TIMESTAMP", 199 | "TIMESTAMPTZ", 200 | "Time", 201 | "TimeFromTicks", 202 | "Timestamp", 203 | "TimestampFromTicks", 204 | "UNKNOWN", 205 | "UUID_TYPE", 206 | "VARCHAR", 207 | "VARCHAR_ARRAY", 208 | "Warning", 209 | "XID", 210 | "__version__", 211 | "connect", 212 | "pginterval_in", 213 | "pginterval_out", 214 | "timedelta_in", 215 | ] 216 | -------------------------------------------------------------------------------- /src/pg8000/converters.py: -------------------------------------------------------------------------------- 1 | from datetime import ( 2 | date as Date, 3 | datetime as Datetime, 4 | time as Time, 5 | timedelta as Timedelta, 6 | timezone as Timezone, 7 | ) 8 | from decimal import Decimal 9 | from enum import Enum 10 | from ipaddress import ( 11 | IPv4Address, 12 | IPv4Network, 13 | IPv6Address, 14 | IPv6Network, 15 | ip_address, 16 | ip_network, 17 | ) 18 | from json import dumps, loads 19 | from uuid import UUID 20 | 21 | from dateutil.parser import ParserError, parse 22 | 23 | from pg8000.exceptions 
import InterfaceError 24 | from pg8000.types import PGInterval, Range 25 | 26 | 27 | ANY_ARRAY = 2277 28 | BIGINT = 20 29 | BIGINT_ARRAY = 1016 30 | BOOLEAN = 16 31 | BOOLEAN_ARRAY = 1000 32 | BYTES = 17 33 | BYTES_ARRAY = 1001 34 | CHAR = 1042 35 | CHAR_ARRAY = 1014 36 | CIDR = 650 37 | CIDR_ARRAY = 651 38 | CSTRING = 2275 39 | CSTRING_ARRAY = 1263 40 | DATE = 1082 41 | DATE_ARRAY = 1182 42 | DATEMULTIRANGE = 4535 43 | DATEMULTIRANGE_ARRAY = 6155 44 | DATERANGE = 3912 45 | DATERANGE_ARRAY = 3913 46 | FLOAT = 701 47 | FLOAT_ARRAY = 1022 48 | INET = 869 49 | INET_ARRAY = 1041 50 | INT2VECTOR = 22 51 | INT4MULTIRANGE = 4451 52 | INT4MULTIRANGE_ARRAY = 6150 53 | INT4RANGE = 3904 54 | INT4RANGE_ARRAY = 3905 55 | INT8MULTIRANGE = 4536 56 | INT8MULTIRANGE_ARRAY = 6157 57 | INT8RANGE = 3926 58 | INT8RANGE_ARRAY = 3927 59 | INTEGER = 23 60 | INTEGER_ARRAY = 1007 61 | INTERVAL = 1186 62 | INTERVAL_ARRAY = 1187 63 | OID = 26 64 | JSON = 114 65 | JSON_ARRAY = 199 66 | JSONB = 3802 67 | JSONB_ARRAY = 3807 68 | MACADDR = 829 69 | MONEY = 790 70 | MONEY_ARRAY = 791 71 | NAME = 19 72 | NAME_ARRAY = 1003 73 | NUMERIC = 1700 74 | NUMERIC_ARRAY = 1231 75 | NUMRANGE = 3906 76 | NUMRANGE_ARRAY = 3907 77 | NUMMULTIRANGE = 4532 78 | NUMMULTIRANGE_ARRAY = 6151 79 | NULLTYPE = -1 80 | OID = 26 81 | POINT = 600 82 | REAL = 700 83 | REAL_ARRAY = 1021 84 | RECORD = 2249 85 | SMALLINT = 21 86 | SMALLINT_ARRAY = 1005 87 | SMALLINT_VECTOR = 22 88 | STRING = 1043 89 | TEXT = 25 90 | TEXT_ARRAY = 1009 91 | TIME = 1083 92 | TIME_ARRAY = 1183 93 | TIMESTAMP = 1114 94 | TIMESTAMP_ARRAY = 1115 95 | TIMESTAMPTZ = 1184 96 | TIMESTAMPTZ_ARRAY = 1185 97 | TSMULTIRANGE = 4533 98 | TSMULTIRANGE_ARRAY = 6152 99 | TSRANGE = 3908 100 | TSRANGE_ARRAY = 3909 101 | TSTZMULTIRANGE = 4534 102 | TSTZMULTIRANGE_ARRAY = 6153 103 | TSTZRANGE = 3910 104 | TSTZRANGE_ARRAY = 3911 105 | UNKNOWN = 705 106 | UUID_TYPE = 2950 107 | UUID_ARRAY = 2951 108 | VARCHAR = 1043 109 | VARCHAR_ARRAY = 1015 110 | XID = 28 111 | 112 | 
# Bounds for the 2-, 4- and 8-byte integer types. Each MAX_* is 2**(bits-1),
# one past the largest representable value, so these presumably act as
# exclusive upper bounds — confirm against callers.
MIN_INT2 = -(1 << 15)
MAX_INT2 = 1 << 15
MIN_INT4 = -(1 << 31)
MAX_INT4 = 1 << 31
MIN_INT8 = -(1 << 63)
MAX_INT8 = 1 << 63


def bool_in(data):
    """Decode boolean text; only the exact string "t" is true."""
    return "t" == data


def bool_out(v):
    """Encode any Python truth value as the literal "true" or "false"."""
    return ("false", "true")[bool(v)]


def bytes_in(data):
    """Decode bytea hex-format text ("\\x" then hex digits) into bytes."""
    hex_digits = data[2:]  # drop the two-character "\x" prefix
    return bytes.fromhex(hex_digits)


def bytes_out(v):
    """Encode bytes/bytearray in bytea hex format ("\\x" then hex digits)."""
    return f"\\x{v.hex()}"


def cidr_out(v):
    """Encode a cidr value via str()."""
    return str(v)


def _network_or_address(text):
    """Parse text as a non-strict ip_network when it contains "/", otherwise
    as an ip_address."""
    if "/" not in text:
        return ip_address(text)
    return ip_network(text, strict=False)


def cidr_in(data):
    """Decode cidr text into an ipaddress network or address object."""
    return _network_or_address(data)


def date_in(data):
    """Decode date text in "YYYY-MM-DD" form.

    "infinity" and "-infinity" are passed through unchanged, as is any text
    that does not parse as a Python date (for example a year above 9999 or a
    value carrying a " BC" suffix).
    """
    if data in ("infinity", "-infinity"):
        return data
    try:
        parsed = Datetime.strptime(data, "%Y-%m-%d")
    except ValueError:
        # Out of range for Python's date type: hand back the raw text.
        return data
    return parsed.date()


def date_out(v):
    """Encode a date with isoformat()."""
    return v.isoformat()


def datetime_out(v):
    """Encode a datetime with isoformat(); values with tzinfo go to UTC first."""
    if v.tzinfo is not None:
        v = v.astimezone(Timezone.utc)
    return v.isoformat()


def enum_out(v):
    """Encode an Enum member as the str() of its value."""
    member_value = v.value
    return str(member_value)


def float_out(v):
    """Encode a float via str()."""
    return str(v)


def inet_in(data):
    """Decode inet text into an ipaddress network or address object."""
    return _network_or_address(data)


def inet_out(v):
    """Encode an inet value via str()."""
    return str(v)


def int_in(data):
    """Decode integer text."""
    return int(data)


def int_out(v):
    """Encode an integer via str()."""
    return str(v)


def interval_in(data):
    """Decode interval text into a timedelta.

    The text is parsed into a PGInterval first; if its to_timedelta() raises
    ValueError, the PGInterval itself is returned instead.
    """
    pg_value = PGInterval.from_str(data)
    try:
        as_timedelta = pg_value.to_timedelta()
    except ValueError:
        return pg_value
    return as_timedelta


def interval_out(v):
    """Encode a timedelta-like value from its days/seconds/microseconds."""
    return f"{v.days} days {v.seconds} seconds {v.microseconds} microseconds"


def json_in(data):
    """Decode json/jsonb text with json.loads."""
    return loads(data)


def json_out(v):
    """Encode a value as JSON text with json.dumps."""
    return dumps(v)


def null_out(v):
    """Always return None, whatever the input (presumably sent as SQL NULL)."""
    return None


def numeric_in(data):
    """Decode numeric text into an exact Decimal."""
    return Decimal(data)
def numeric_out(d):
    """Encode a Decimal via str()."""
    return str(d)


def point_in(data):
    """Decode point text such as "(1.5,2)" into a tuple of floats."""
    return tuple(map(float, data[1:-1].split(",")))


def pg_interval_in(data):
    """Decode interval text into a PGInterval (no timedelta conversion)."""
    return PGInterval.from_str(data)


def pg_interval_out(v):
    """Encode a PGInterval via str()."""
    return str(v)


def range_out(v):
    """Encode a Range as range-literal text.

    Empty ranges become "empty". A None bound is left blank (unbounded); other
    bounds are encoded with make_param(PY_TYPES, ...) and wrapped in the two
    bracket characters held in v.bounds.
    """
    if v.is_empty:
        return "empty"
    else:
        le = v.lower
        val_lower = "" if le is None else make_param(PY_TYPES, le)
        ue = v.upper
        val_upper = "" if ue is None else make_param(PY_TYPES, ue)
        return f"{v.bounds[0]}{val_lower},{val_upper}{v.bounds[1]}"


def string_in(data):
    """Return text values unchanged."""
    return data


def string_out(v):
    """Send str values unchanged."""
    return v


def time_in(data):
    """Decode "HH:MM:SS" or "HH:MM:SS.ffffff" text into a time."""
    pattern = "%H:%M:%S.%f" if "." in data else "%H:%M:%S"
    return Datetime.strptime(data, pattern).time()


def time_out(v):
    """Encode a time with isoformat()."""
    return v.isoformat()


def timestamp_in(data):
    """Decode timestamp text into a naive datetime.

    strptime is tried first, then dateutil's parse. "infinity"/"-infinity"
    and text neither can parse (e.g. outside Python's datetime range) are
    returned unchanged.
    """
    if data in ("infinity", "-infinity"):
        return data

    try:
        pattern = "%Y-%m-%d %H:%M:%S.%f" if "." in data else "%Y-%m-%d %H:%M:%S"
        return Datetime.strptime(data, pattern)
    except ValueError:
        try:
            return parse(data)
        except ParserError:
            # pg timestamp can overflow Python Datetime
            return data


def timestamptz_in(data):
    """Decode timestamptz text into an aware datetime.

    Same fallbacks as timestamp_in: strptime, then dateutil's parse, then the
    raw text for "infinity"/"-infinity" or unparseable values.
    """
    if data in ("infinity", "-infinity"):
        return data

    try:
        # PostgreSQL's ISO output writes whole-hour offsets as "+HH"; appending
        # "00" gives "+HH00", which %z accepts. Offsets with minutes ("+HH:MM")
        # then fail here and fall through to dateutil below.
        patt = "%Y-%m-%d %H:%M:%S.%f%z" if "." in data else "%Y-%m-%d %H:%M:%S%z"
        return Datetime.strptime(f"{data}00", patt)
    except ValueError:
        try:
            return parse(data)
        except ParserError:
            # pg timestamptz can overflow Python Datetime
            return data


def unknown_out(v):
    """Encode a value of unknown type via str()."""
    return str(v)


def vector_in(data):
    """Decode int2vector text (space-separated integers) into a list of int."""
    return [int(v) for v in data.split()]


def uuid_out(v):
    """Encode a UUID via str()."""
    return str(v)


def uuid_in(data):
    """Decode uuid text into a UUID."""
    return UUID(data)


def _range_in(elem_func):
    """Build a parser for range text such as "[1,5)" or "empty".

    Each non-blank bound is converted with elem_func; a blank bound becomes
    None. The first and last characters are kept as the Range bounds.
    """

    def range_in(data):
        if data == "empty":
            return Range(is_empty=True)
        else:
            le, ue = [None if v == "" else elem_func(v) for v in data[1:-1].split(",")]
            return Range(le, ue, bounds=f"{data[0]}{data[-1]}")

    return range_in


daterange_in = _range_in(date_in)
int4range_in = _range_in(int)
int8range_in = _range_in(int)
numrange_in = _range_in(Decimal)


def ts_in(data):
    """Decode a double-quoted timestamp range bound."""
    return timestamp_in(data[1:-1])


def tstz_in(data):
    """Decode a double-quoted timestamptz range bound."""
    return timestamptz_in(data[1:-1])


tsrange_in = _range_in(ts_in)
tstzrange_in = _range_in(tstz_in)


def _multirange_in(adapter):
    """Build a parser for multirange text such as "{[1,3),[5,7)}".

    Each bracketed range, from its opening "[" or "(" up to its closing "]"
    or ")", is passed to adapter; the results are returned as a list.
    """

    def f(data):
        in_range = False
        result = []
        val = []
        for c in data:
            if in_range:
                val.append(c)
                if c in "])":
                    value = "".join(val)
                    val.clear()
                    result.append(adapter(value))
                    in_range = False
            elif c in "[(":
                val.append(c)
                in_range = True

        return result

    return f


datemultirange_in = _multirange_in(daterange_in)
int4multirange_in = _multirange_in(int4range_in)
int8multirange_in = _multirange_in(int8range_in)
nummultirange_in = _multirange_in(numrange_in)
tsmultirange_in = _multirange_in(tsrange_in)
tstzmultirange_in = _multirange_in(tstzrange_in)


class ParserState(Enum):
    """States of the character-by-character array and record parsers."""

    InString = 1  # inside a double-quoted element
    InEscape = 2  # just after a backslash inside quotes
    InValue = 3  # inside an unquoted element
    Out = 4  # between elements


def _parse_array(data, adapter):
    """Parse array text such as '{1,NULL,"a b",{2}}' into nested lists.

    Unquoted elements are passed to adapter, except NULL, which becomes None.
    Quoted elements (a backslash escapes the next character) are always passed
    to adapter. Returns the outermost list.
    """
    state = ParserState.Out
    stack = [[]]
    val = []
    for c in data:
        if state == ParserState.InValue:
            if c in ("}", ","):
                value = "".join(val)
                stack[-1].append(None if value == "NULL" else adapter(value))
                state = ParserState.Out
            else:
                val.append(c)

        # Deliberately not an elif: the "}" or "," that ended an unquoted value
        # is also handled as a delimiter here.
        if state == ParserState.Out:
            if c == "{":
                a = []
                stack[-1].append(a)
                stack.append(a)
            elif c == "}":
                stack.pop()
            elif c == ",":
                pass
            elif c == '"':
                val = []
                state = ParserState.InString
            else:
                val = [c]
                state = ParserState.InValue

        elif state == ParserState.InString:
            if c == '"':
                stack[-1].append(adapter("".join(val)))
                state = ParserState.Out
            elif c == "\\":
                state = ParserState.InEscape
            else:
                val.append(c)
        elif state == ParserState.InEscape:
            val.append(c)
            state = ParserState.InString

    return stack[0][0]


def _array_in(adapter):
    """Build an array-text parser that converts each element with adapter."""

    def f(data):
        return _parse_array(data, adapter)

    return f


bool_array_in = _array_in(bool_in)
bytes_array_in = _array_in(bytes_in)
cidr_array_in = _array_in(cidr_in)
date_array_in = _array_in(date_in)
datemultirange_array_in = _array_in(datemultirange_in)
daterange_array_in = _array_in(daterange_in)
inet_array_in = _array_in(inet_in)
int_array_in = _array_in(int)
int4multirange_array_in = _array_in(int4multirange_in)
int4range_array_in = _array_in(int4range_in)
int8multirange_array_in = _array_in(int8multirange_in)
int8range_array_in = _array_in(int8range_in)
interval_array_in = _array_in(interval_in)
json_array_in = _array_in(json_in)
float_array_in = _array_in(float)
numeric_array_in = _array_in(numeric_in)
nummultirange_array_in = _array_in(nummultirange_in)
numrange_array_in = _array_in(numrange_in)
string_array_in = _array_in(string_in)
time_array_in = _array_in(time_in)
timestamp_array_in = _array_in(timestamp_in)
timestamptz_array_in = _array_in(timestamptz_in)
tsrange_array_in = _array_in(tsrange_in)
tsmultirange_array_in = _array_in(tsmultirange_in)
tstzmultirange_array_in = _array_in(tstzmultirange_in)
tstzrange_array_in = _array_in(tstzrange_in)
uuid_array_in = _array_in(uuid_in)


def array_string_escape(v):
    """Escape a string for use as an array (or record) element.

    Backslashes and double quotes get a backslash prefix. The result is
    wrapped in double quotes when it is empty, equals NULL, or contains
    whitespace, a brace, a comma or a backslash.
    """
    cs = []
    for c in v:
        if c == "\\":
            cs.append("\\")
        elif c == '"':
            cs.append("\\")
        cs.append(c)
    val = "".join(cs)
    if (
        len(val) == 0
        or val == "NULL"
        or any(c.isspace() for c in val)
        or any(c in val for c in ("{", "}", ",", "\\"))
    ):
        val = f'"{val}"'
    return val


def array_out(ar):
    """Encode a (possibly nested) list as array-literal text."""
    result = []
    for v in ar:
        if isinstance(v, list):
            val = array_out(v)

        elif isinstance(v, tuple):
            # NOTE(review): composite_out's text is wrapped in quotes without
            # escaping any '"' or '\' it already contains (e.g. from a quoted
            # string field) — confirm such composites round-trip in arrays.
            val = f'"{composite_out(v)}"'

        elif v is None:
            val = "NULL"

        elif isinstance(v, dict):
            val = array_string_escape(json_out(v))

        elif isinstance(v, (bytes, bytearray)):
            val = f'"\\{bytes_out(v)}"'

        elif isinstance(v, str):
            val = array_string_escape(v)

        else:
            val = make_param(PY_TYPES, v)

        result.append(val)

    return f'{{{",".join(result)}}}'


def composite_out(ar):
    """Encode a tuple as row-literal text "(a,b,...)"; None is an empty field."""
    result = []
    for v in ar:
        if isinstance(v, list):
            val = array_out(v)

        elif isinstance(v, tuple):
            val = composite_out(v)

        elif v is None:
            val = ""

        elif isinstance(v, dict):
            val = array_string_escape(json_out(v))

        elif isinstance(v, (bytes, bytearray)):
            val = f'"\\{bytes_out(v)}"'

        elif isinstance(v, str):
            val = array_string_escape(v)

        else:
            val = make_param(PY_TYPES, v)

        result.append(val)

    return f'({",".join(result)})'


def record_in(data):
    """Decode record (row) text such as '(1,"a b",)' into a tuple.

    Fields are separated by commas inside the enclosing parentheses and are
    returned as text (no per-column conversion). An empty unquoted field is
    NULL and becomes None, so field positions are preserved; a quoted empty
    field is the empty string. Inside quotes a backslash escapes the next
    character, and a doubled quote stands for one literal quote character.
    """
    body = data[1:-1]  # strip the enclosing parentheses
    if body == "":
        # A zero-column row is "()"; a lone NULL column presumably renders the
        # same way, so this stays ambiguous. Keep the empty-tuple result.
        return ()

    results = []
    chars = []
    quoted = False  # the current field contained a quoted section
    state = ParserState.InValue
    for c in body:
        if state == ParserState.InEscape:
            chars.append(c)
            state = ParserState.InString
        elif state == ParserState.InString:
            if c == '"':
                # Either the closing quote or the first half of a doubled quote.
                state = ParserState.Out
            elif c == "\\":
                state = ParserState.InEscape
            else:
                chars.append(c)
        elif c == '"':
            if state == ParserState.Out:
                # A quote straight after a closing quote is an escaped quote.
                chars.append('"')
            quoted = True
            state = ParserState.InString
        elif c == ",":
            results.append("".join(chars) if chars or quoted else None)
            chars.clear()
            quoted = False
            state = ParserState.InValue
        else:
            chars.append(c)
            state = ParserState.InValue

    results.append("".join(chars) if chars or quoted else None)
    return tuple(results)


# Python type -> PostgreSQL type OID.
PY_PG = {
    Date: DATE,
    Decimal: NUMERIC,
    IPv4Address: INET,
    IPv6Address: INET,
    IPv4Network: INET,
    IPv6Network: INET,
    PGInterval: INTERVAL,
    Time: TIME,
    Timedelta: INTERVAL,
    UUID: UUID_TYPE,
    bool: BOOLEAN,
    bytearray: BYTES,
    dict: JSONB,
    float: FLOAT,
    type(None): NULLTYPE,
    bytes: BYTES,
    str: TEXT,
}


# Python type -> function producing the parameter's text. Order matters:
# make_param falls back to the first key the value is an instance of.
# NOTE(review): PGInterval uses interval_out, which reads .days/.seconds/
# .microseconds; pg_interval_out (str(v)) is defined above but not referenced
# by these tables — confirm which one PGInterval should use.
PY_TYPES = {
    Date: date_out,  # date
    Datetime: datetime_out,
    Decimal: numeric_out,  # numeric
    Enum: enum_out,  # enum
    IPv4Address: inet_out,  # inet
    IPv6Address: inet_out,  # inet
    IPv4Network: inet_out,  # inet
    IPv6Network: inet_out,  # inet
    PGInterval: interval_out,  # interval
    Range: range_out,  # range types
    Time: time_out,  # time
    Timedelta: interval_out,  # interval
    UUID: uuid_out,  # uuid
    bool: bool_out,  # bool
    bytearray: bytes_out,  # bytea
    dict: json_out,  # jsonb
    float: float_out,  # float8
    type(None): null_out,  # null
    bytes: bytes_out,  # bytea
    str: string_out,  # unknown
    int: int_out,
    list: array_out,
    tuple: composite_out,
}


# PostgreSQL type OID -> function converting the text value to Python.
# NOTE(review): plain CIDR (650) has no entry although cidr_in exists and
# CIDR_ARRAY uses it — confirm cidr values are meant to stay as text.
PG_TYPES = {
    BIGINT: int,  # int8
    BIGINT_ARRAY: int_array_in,  # int8[]
    BOOLEAN: bool_in,  # bool
    BOOLEAN_ARRAY: bool_array_in,  # bool[]
    BYTES: bytes_in,  # bytea
    BYTES_ARRAY: bytes_array_in,  # bytea[]
    CHAR: string_in,  # char
    CHAR_ARRAY: string_array_in,  # char[]
    CIDR_ARRAY: cidr_array_in,  # cidr[]
    CSTRING: string_in,  # cstring
    CSTRING_ARRAY: string_array_in,  # cstring[]
    DATE: date_in,  # date
    DATE_ARRAY: date_array_in,  # date[]
    DATEMULTIRANGE: datemultirange_in,  # datemultirange
    DATEMULTIRANGE_ARRAY: datemultirange_array_in,  # datemultirange[]
    DATERANGE: daterange_in,  # daterange
    DATERANGE_ARRAY: daterange_array_in,  # daterange[]
    FLOAT: float,  # float8
    FLOAT_ARRAY: float_array_in,  # float8[]
    INET: inet_in,  # inet
    INET_ARRAY: inet_array_in,  # inet[]
    INT4MULTIRANGE: int4multirange_in,  # int4multirange
    INT4MULTIRANGE_ARRAY: int4multirange_array_in,  # int4multirange[]
    INT4RANGE: int4range_in,  # int4range
    INT4RANGE_ARRAY: int4range_array_in,  # int4range[]
    INT8MULTIRANGE: int8multirange_in,  # int8multirange
    INT8MULTIRANGE_ARRAY: int8multirange_array_in,  # int8multirange[]
    INT8RANGE: int8range_in,  # int8range
    INT8RANGE_ARRAY: int8range_array_in,  # int8range[]
    INTEGER: int,  # int4
    INTEGER_ARRAY: int_array_in,  # int4[]
    JSON: json_in,  # json
    JSON_ARRAY: json_array_in,  # json[]
    JSONB: json_in,  # jsonb
    JSONB_ARRAY: json_array_in,  # jsonb[]
    MACADDR: string_in,  # macaddr
    MONEY: string_in,  # money
    MONEY_ARRAY: string_array_in,  # money[]
    NAME: string_in,  # name
    NAME_ARRAY: string_array_in,  # name[]
    NUMERIC: numeric_in,  # numeric
    NUMERIC_ARRAY: numeric_array_in,  # numeric[]
    NUMRANGE: numrange_in,  # numrange
    NUMRANGE_ARRAY: numrange_array_in,  # numrange[]
    NUMMULTIRANGE: nummultirange_in,  # nummultirange
    NUMMULTIRANGE_ARRAY: nummultirange_array_in,  # nummultirange[]
    OID: int,  # oid
    POINT: point_in,  # point
    INTERVAL: interval_in,  # interval
    INTERVAL_ARRAY: interval_array_in,  # interval[]
    REAL: float,  # float4
    REAL_ARRAY: float_array_in,  # float4[]
    RECORD: record_in,  # record
    SMALLINT: int,  # int2
    SMALLINT_ARRAY: int_array_in,  # int2[]
    SMALLINT_VECTOR: vector_in,  # int2vector
    TEXT: string_in,  # text
    TEXT_ARRAY: string_array_in,  # text[]
    TIME: time_in,  # time
    TIME_ARRAY: time_array_in,  # time[]
    TIMESTAMP: timestamp_in,  # timestamp
    TIMESTAMP_ARRAY: timestamp_array_in,  # timestamp[]
    TIMESTAMPTZ: timestamptz_in,  # timestamptz
    TIMESTAMPTZ_ARRAY: timestamptz_array_in,  # timestamptz[]
    TSMULTIRANGE: tsmultirange_in,  # tsmultirange
    TSMULTIRANGE_ARRAY: tsmultirange_array_in,  # tsmultirange[]
    TSRANGE: tsrange_in,  # tsrange
    TSRANGE_ARRAY: tsrange_array_in,  # tsrange[]
    TSTZMULTIRANGE: tstzmultirange_in,  # tstzmultirange
    TSTZMULTIRANGE_ARRAY: tstzmultirange_array_in,  # tstzmultirange[]
    TSTZRANGE: tstzrange_in,  # tstzrange
    TSTZRANGE_ARRAY: tstzrange_array_in,  # tstzrange[]
    UNKNOWN: string_in,  # unknown
    UUID_ARRAY: uuid_array_in,  # uuid[]
    UUID_TYPE: uuid_in,  # uuid
    VARCHAR: string_in,  # varchar
    VARCHAR_ARRAY: string_array_in,  # varchar[]
    XID: int,  # xid
}


# PostgreSQL encodings:
# https://www.postgresql.org/docs/current/multibyte.html
#
# Python
encodings: 704 | # https://docs.python.org/3/library/codecs.html 705 | # 706 | # Commented out encodings don't require a name change between PostgreSQL and 707 | # Python. If the py side is None, then the encoding isn't supported. 708 | PG_PY_ENCODINGS = { 709 | # Not supported: 710 | "mule_internal": None, 711 | "euc_tw": None, 712 | # Name fine as-is: 713 | # "euc_jp", 714 | # "euc_jis_2004", 715 | # "euc_kr", 716 | # "gb18030", 717 | # "gbk", 718 | # "johab", 719 | # "sjis", 720 | # "shift_jis_2004", 721 | # "uhc", 722 | # "utf8", 723 | # Different name: 724 | "euc_cn": "gb2312", 725 | "iso_8859_5": "is8859_5", 726 | "iso_8859_6": "is8859_6", 727 | "iso_8859_7": "is8859_7", 728 | "iso_8859_8": "is8859_8", 729 | "koi8": "koi8_r", 730 | "latin1": "iso8859-1", 731 | "latin2": "iso8859_2", 732 | "latin3": "iso8859_3", 733 | "latin4": "iso8859_4", 734 | "latin5": "iso8859_9", 735 | "latin6": "iso8859_10", 736 | "latin7": "iso8859_13", 737 | "latin8": "iso8859_14", 738 | "latin9": "iso8859_15", 739 | "sql_ascii": "ascii", 740 | "win866": "cp886", 741 | "win874": "cp874", 742 | "win1250": "cp1250", 743 | "win1251": "cp1251", 744 | "win1252": "cp1252", 745 | "win1253": "cp1253", 746 | "win1254": "cp1254", 747 | "win1255": "cp1255", 748 | "win1256": "cp1256", 749 | "win1257": "cp1257", 750 | "win1258": "cp1258", 751 | "unicode": "utf-8", # Needed for Amazon Redshift 752 | } 753 | 754 | 755 | def make_param(py_types, value): 756 | try: 757 | func = py_types[type(value)] 758 | except KeyError: 759 | func = str 760 | for k, v in py_types.items(): 761 | try: 762 | if isinstance(value, k): 763 | func = v 764 | break 765 | except TypeError: 766 | pass 767 | 768 | return func(value) 769 | 770 | 771 | def make_params(py_types, values): 772 | return tuple([make_param(py_types, v) for v in values]) 773 | 774 | 775 | def identifier(sql): 776 | if not isinstance(sql, str): 777 | raise InterfaceError("identifier must be a str") 778 | 779 | if len(sql) == 0: 780 | raise 
InterfaceError("identifier must be > 0 characters in length") 781 | 782 | if "\u0000" in sql: 783 | raise InterfaceError("identifier cannot contain the code zero character") 784 | 785 | sql = sql.replace('"', '""') 786 | return f'"{sql}"' 787 | 788 | 789 | def literal(value): 790 | if value is None: 791 | return "NULL" 792 | elif isinstance(value, bool): 793 | return "TRUE" if value else "FALSE" 794 | elif isinstance(value, (int, float, Decimal)): 795 | return str(value) 796 | elif isinstance(value, (bytes, bytearray)): 797 | return f"X'{value.hex()}'" 798 | elif isinstance(value, Datetime): 799 | return f"'{datetime_out(value)}'" 800 | elif isinstance(value, Date): 801 | return f"'{date_out(value)}'" 802 | elif isinstance(value, Time): 803 | return f"'{time_out(value)}'" 804 | elif isinstance(value, Timedelta): 805 | return f"'{interval_out(value)}'" 806 | elif isinstance(value, list): 807 | return f"'{array_out(value)}'" 808 | else: 809 | val = str(value).replace("'", "''") 810 | return f"'{val}'" 811 | -------------------------------------------------------------------------------- /src/pg8000/exceptions.py: -------------------------------------------------------------------------------- 1 | class Error(Exception): 2 | """Generic exception that is the base exception of all other error 3 | exceptions. 4 | 5 | This exception is part of the `DBAPI 2.0 specification 6 | `_. 7 | """ 8 | 9 | pass 10 | 11 | 12 | class InterfaceError(Error): 13 | """Generic exception raised for errors that are related to the database 14 | interface rather than the database itself. For example, if the interface 15 | attempts to use an SSL connection but the server refuses, an InterfaceError 16 | will be raised. 17 | 18 | This exception is part of the `DBAPI 2.0 specification 19 | `_. 20 | """ 21 | 22 | pass 23 | 24 | 25 | class DatabaseError(Error): 26 | """Generic exception raised for errors that are related to the database. 
27 | 28 | This exception is part of the `DBAPI 2.0 specification 29 | `_. 30 | """ 31 | 32 | pass 33 | -------------------------------------------------------------------------------- /src/pg8000/native.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from enum import Enum, auto 3 | 4 | from pg8000.converters import ( 5 | BIGINT, 6 | BOOLEAN, 7 | BOOLEAN_ARRAY, 8 | BYTES, 9 | CHAR, 10 | CHAR_ARRAY, 11 | DATE, 12 | FLOAT, 13 | FLOAT_ARRAY, 14 | INET, 15 | INT2VECTOR, 16 | INTEGER, 17 | INTEGER_ARRAY, 18 | INTERVAL, 19 | JSON, 20 | JSONB, 21 | JSONB_ARRAY, 22 | JSON_ARRAY, 23 | MACADDR, 24 | NAME, 25 | NAME_ARRAY, 26 | NULLTYPE, 27 | NUMERIC, 28 | NUMERIC_ARRAY, 29 | OID, 30 | PGInterval, 31 | STRING, 32 | TEXT, 33 | TEXT_ARRAY, 34 | TIME, 35 | TIMESTAMP, 36 | TIMESTAMPTZ, 37 | UNKNOWN, 38 | UUID_TYPE, 39 | VARCHAR, 40 | VARCHAR_ARRAY, 41 | XID, 42 | identifier, 43 | literal, 44 | make_params, 45 | ) 46 | from pg8000.core import CoreConnection, ver 47 | from pg8000.exceptions import DatabaseError, Error, InterfaceError 48 | from pg8000.types import Range 49 | 50 | __version__ = ver 51 | 52 | # Copyright (c) 2007-2009, Mathieu Fenniak 53 | # Copyright (c) The Contributors 54 | # All rights reserved. 55 | # 56 | # Redistribution and use in source and binary forms, with or without 57 | # modification, are permitted provided that the following conditions are 58 | # met: 59 | # 60 | # * Redistributions of source code must retain the above copyright notice, 61 | # this list of conditions and the following disclaimer. 62 | # * Redistributions in binary form must reproduce the above copyright notice, 63 | # this list of conditions and the following disclaimer in the documentation 64 | # and/or other materials provided with the distribution. 65 | # * The name of the author may not be used to endorse or promote products 66 | # derived from this software without specific prior written permission. 
67 | # 68 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 69 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 70 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 71 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 72 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 73 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 74 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 75 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 76 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 77 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 78 | # POSSIBILITY OF SUCH DAMAGE. 79 | 80 | 81 | class State(Enum): 82 | OUT = auto() # outside quoted string 83 | IN_SQ = auto() # inside single-quote string '...' 84 | IN_QI = auto() # inside quoted identifier "..." 85 | IN_ES = auto() # inside escaped single-quote string, E'...' 86 | IN_PN = auto() # inside parameter name eg. :name 87 | IN_CO = auto() # inside inline comment eg. -- 88 | IN_DQ = auto() # inside dollar-quoted string eg. $$...$$ 89 | IN_DP = auto() # inside dollar parameter eg. 
$1 90 | 91 | 92 | def to_statement(query): 93 | in_quote_escape = False 94 | placeholders = [] 95 | output_query = [] 96 | state = State.OUT 97 | prev_c = None 98 | for i, c in enumerate(query): 99 | if i + 1 < len(query): 100 | next_c = query[i + 1] 101 | else: 102 | next_c = None 103 | 104 | if state == State.OUT: 105 | if c == "'": 106 | output_query.append(c) 107 | if prev_c == "E": 108 | state = State.IN_ES 109 | else: 110 | state = State.IN_SQ 111 | elif c == '"': 112 | output_query.append(c) 113 | state = State.IN_QI 114 | elif c == "-": 115 | output_query.append(c) 116 | if prev_c == "-": 117 | state = State.IN_CO 118 | elif c == "$": 119 | output_query.append(c) 120 | if prev_c == "$": 121 | state = State.IN_DQ 122 | elif next_c.isdigit(): 123 | state = State.IN_DP 124 | placeholders.append("") 125 | elif c == ":" and next_c not in ":=" and prev_c != ":": 126 | state = State.IN_PN 127 | placeholders.append("") 128 | else: 129 | output_query.append(c) 130 | 131 | elif state == State.IN_SQ: 132 | if c == "'": 133 | if in_quote_escape: 134 | in_quote_escape = False 135 | elif next_c == "'": 136 | in_quote_escape = True 137 | else: 138 | state = State.OUT 139 | output_query.append(c) 140 | 141 | elif state == State.IN_QI: 142 | if c == '"': 143 | state = State.OUT 144 | output_query.append(c) 145 | 146 | elif state == State.IN_ES: 147 | if c == "'" and prev_c != "\\": 148 | # check for escaped single-quote 149 | state = State.OUT 150 | output_query.append(c) 151 | 152 | elif state == State.IN_PN: 153 | placeholders[-1] += c 154 | if next_c is None or (not next_c.isalnum() and next_c != "_"): 155 | state = State.OUT 156 | try: 157 | pidx = placeholders.index(placeholders[-1], 0, -1) 158 | output_query.append(f"${pidx + 1}") 159 | del placeholders[-1] 160 | except ValueError: 161 | output_query.append(f"${len(placeholders)}") 162 | 163 | elif state == State.IN_DP: 164 | placeholders[-1] += c 165 | output_query.append(c) 166 | if next_c is None or not 
next_c.isdigit(): 167 | try: 168 | placeholders[-1] = int(placeholders[-1]) - 1 169 | except ValueError: 170 | raise InterfaceError( 171 | f"Expected an integer for the $ placeholder but found " 172 | f"'{placeholders[-1]}'" 173 | ) 174 | state = State.OUT 175 | 176 | elif state == State.IN_CO: 177 | output_query.append(c) 178 | if c == "\n": 179 | state = State.OUT 180 | 181 | elif state == State.IN_DQ: 182 | output_query.append(c) 183 | if c == "$" and prev_c == "$": 184 | state = State.OUT 185 | 186 | prev_c = c 187 | 188 | for reserved in ("types", "stream"): 189 | if reserved in placeholders: 190 | raise InterfaceError( 191 | f"The name '{reserved}' can't be used as a placeholder because it's " 192 | f"used for another purpose." 193 | ) 194 | 195 | def make_vals(args): 196 | arg_list = [v for _, v in args.items()] 197 | vals = [] 198 | for p in placeholders: 199 | if isinstance(p, int): 200 | vals.append(arg_list[p]) 201 | else: 202 | try: 203 | vals.append(args[p]) 204 | except KeyError: 205 | raise InterfaceError( 206 | f"There's a placeholder '{p}' in the query, but no matching " 207 | f"keyword argument." 
208 | ) 209 | return tuple(vals) 210 | 211 | return "".join(output_query), make_vals 212 | 213 | 214 | class Connection(CoreConnection): 215 | def __init__(self, *args, **kwargs): 216 | super().__init__(*args, **kwargs) 217 | self._context = None 218 | 219 | @property 220 | def columns(self): 221 | context = self._context 222 | if context is None: 223 | return None 224 | return context.columns 225 | 226 | @property 227 | def row_count(self): 228 | context = self._context 229 | if context is None: 230 | return None 231 | return context.row_count 232 | 233 | def run(self, sql, stream=None, types=None, **params): 234 | if len(params) == 0 and stream is None: 235 | self._context = self.execute_simple(sql) 236 | else: 237 | statement, make_vals = to_statement(sql) 238 | oids = () if types is None else make_vals(defaultdict(lambda: None, types)) 239 | self._context = self.execute_unnamed( 240 | statement, make_vals(params), oids=oids, stream=stream 241 | ) 242 | return self._context.rows 243 | 244 | def prepare(self, sql): 245 | return PreparedStatement(self, sql) 246 | 247 | 248 | class PreparedStatement: 249 | def __init__(self, con, sql, types=None): 250 | self.con = con 251 | self.statement, self.make_vals = to_statement(sql) 252 | oids = () if types is None else self.make_vals(defaultdict(lambda: None, types)) 253 | self.name_bin, self.cols, self.input_funcs = con.prepare_statement( 254 | self.statement, oids 255 | ) 256 | 257 | @property 258 | def columns(self): 259 | return self._context.columns 260 | 261 | def run(self, stream=None, **params): 262 | params = make_params(self.con.py_types, self.make_vals(params)) 263 | 264 | self._context = self.con.execute_named( 265 | self.name_bin, params, self.cols, self.input_funcs, self.statement 266 | ) 267 | 268 | return self._context.rows 269 | 270 | def close(self): 271 | self.con.close_prepared_statement(self.name_bin) 272 | 273 | 274 | __all__ = [ 275 | "BIGINT", 276 | "BOOLEAN", 277 | "BOOLEAN_ARRAY", 278 | "BYTES", 
279 | "CHAR", 280 | "CHAR_ARRAY", 281 | "Connection", 282 | "DATE", 283 | "DatabaseError", 284 | "Error", 285 | "FLOAT", 286 | "FLOAT_ARRAY", 287 | "INET", 288 | "INT2VECTOR", 289 | "INTEGER", 290 | "INTEGER_ARRAY", 291 | "INTERVAL", 292 | "InterfaceError", 293 | "JSON", 294 | "JSONB", 295 | "JSONB_ARRAY", 296 | "JSON_ARRAY", 297 | "MACADDR", 298 | "NAME", 299 | "NAME_ARRAY", 300 | "NULLTYPE", 301 | "NUMERIC", 302 | "NUMERIC_ARRAY", 303 | "OID", 304 | "PGInterval", 305 | "Range", 306 | "STRING", 307 | "TEXT", 308 | "TEXT_ARRAY", 309 | "TIME", 310 | "TIMESTAMP", 311 | "TIMESTAMPTZ", 312 | "UNKNOWN", 313 | "UUID_TYPE", 314 | "VARCHAR", 315 | "VARCHAR_ARRAY", 316 | "XID", 317 | "identifier", 318 | "literal", 319 | ] 320 | -------------------------------------------------------------------------------- /src/pg8000/types.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta as Timedelta 2 | 3 | 4 | class PGInterval: 5 | UNIT_MAP = { 6 | "millennia": "millennia", 7 | "millennium": "millennia", 8 | "centuries": "centuries", 9 | "century": "centuries", 10 | "decades": "decades", 11 | "decade": "decades", 12 | "years": "years", 13 | "year": "years", 14 | "months": "months", 15 | "month": "months", 16 | "mon": "months", 17 | "mons": "months", 18 | "weeks": "weeks", 19 | "week": "weeks", 20 | "days": "days", 21 | "day": "days", 22 | "hours": "hours", 23 | "hour": "hours", 24 | "minutes": "minutes", 25 | "minute": "minutes", 26 | "mins": "minutes", 27 | "secs": "seconds", 28 | "seconds": "seconds", 29 | "second": "seconds", 30 | "microseconds": "microseconds", 31 | "microsecond": "microseconds", 32 | } 33 | 34 | ISO_LOOKUP = { 35 | True: { 36 | "Y": "years", 37 | "M": "months", 38 | "D": "days", 39 | }, 40 | False: { 41 | "H": "hours", 42 | "M": "minutes", 43 | "S": "seconds", 44 | }, 45 | } 46 | 47 | @classmethod 48 | def from_str_iso_8601(cls, interval_str): 49 | # P[n]Y[n]M[n]DT[n]H[n]M[n]S 50 | kwargs = {} 
51 | lookup = cls.ISO_LOOKUP[True] 52 | val = [] 53 | 54 | for c in interval_str[1:]: 55 | if c == "T": 56 | lookup = cls.ISO_LOOKUP[False] 57 | elif c.isdigit() or c in ("-", "."): 58 | val.append(c) 59 | else: 60 | val_str = "".join(val) 61 | name = lookup[c] 62 | v = float(val_str) if name == "seconds" else int(val_str) 63 | kwargs[name] = v 64 | val.clear() 65 | 66 | return cls(**kwargs) 67 | 68 | @classmethod 69 | def from_str_postgres(cls, interval_str): 70 | """Parses both the postgres and postgres_verbose formats""" 71 | 72 | t = {} 73 | 74 | curr_val = None 75 | for k in interval_str.split(): 76 | if ":" in k: 77 | hours_str, minutes_str, seconds_str = k.split(":") 78 | hours = int(hours_str) 79 | if hours != 0: 80 | t["hours"] = hours 81 | minutes = int(minutes_str) 82 | if minutes != 0: 83 | t["minutes"] = minutes 84 | 85 | seconds = float(seconds_str) 86 | 87 | if seconds != 0: 88 | t["seconds"] = seconds 89 | 90 | elif k == "@": 91 | continue 92 | 93 | elif k == "ago": 94 | for k, v in tuple(t.items()): 95 | t[k] = -1 * v 96 | 97 | else: 98 | try: 99 | curr_val = int(k) 100 | except ValueError: 101 | t[cls.UNIT_MAP[k]] = curr_val 102 | 103 | return cls(**t) 104 | 105 | @classmethod 106 | def from_str_sql_standard(cls, interval_str): 107 | """YYYY-MM 108 | or 109 | DD HH:MM:SS.F 110 | or 111 | YYYY-MM DD HH:MM:SS.F 112 | """ 113 | month_part = None 114 | day_parts = None 115 | parts = interval_str.split() 116 | 117 | if len(parts) == 1: 118 | month_part = parts[0] 119 | elif len(parts) == 2: 120 | day_parts = parts 121 | else: 122 | month_part = parts[0] 123 | day_parts = parts[1:] 124 | 125 | kwargs = {} 126 | 127 | if month_part is not None: 128 | if month_part.startswith("-"): 129 | sign = -1 130 | p = month_part[1:] 131 | else: 132 | sign = 1 133 | p = month_part 134 | 135 | kwargs["years"], kwargs["months"] = [int(v) * sign for v in p.split("-")] 136 | 137 | if day_parts is not None: 138 | kwargs["days"] = int(day_parts[0]) 139 | time_part = 
day_parts[1] 140 | 141 | if time_part.startswith("-"): 142 | sign = -1 143 | p = time_part[1:] 144 | else: 145 | sign = 1 146 | p = time_part 147 | 148 | hours, minutes, seconds = p.split(":") 149 | kwargs["hours"] = int(hours) * sign 150 | kwargs["minutes"] = int(minutes) * sign 151 | kwargs["seconds"] = float(seconds) * sign 152 | 153 | return cls(**kwargs) 154 | 155 | @classmethod 156 | def from_str(cls, interval_str): 157 | if interval_str.startswith("P"): 158 | return cls.from_str_iso_8601(interval_str) 159 | elif interval_str.startswith("@"): 160 | return cls.from_str_postgres(interval_str) 161 | else: 162 | parts = interval_str.split() 163 | if (len(parts) > 1 and parts[1][0].isalpha()) or ( 164 | len(parts) == 1 and ":" in parts[0] 165 | ): 166 | return cls.from_str_postgres(interval_str) 167 | else: 168 | return cls.from_str_sql_standard(interval_str) 169 | 170 | def __init__( 171 | self, 172 | millennia=None, 173 | centuries=None, 174 | decades=None, 175 | years=None, 176 | months=None, 177 | weeks=None, 178 | days=None, 179 | hours=None, 180 | minutes=None, 181 | seconds=None, 182 | microseconds=None, 183 | ): 184 | self.millennia = millennia 185 | self.centuries = centuries 186 | self.decades = decades 187 | self.years = years 188 | self.months = months 189 | self.weeks = weeks 190 | self.days = days 191 | self.hours = hours 192 | self.minutes = minutes 193 | self.seconds = seconds 194 | self.microseconds = microseconds 195 | 196 | def __repr__(self): 197 | return f"" 198 | 199 | def _value_dict(self): 200 | return { 201 | k: v 202 | for k, v in ( 203 | ("millennia", self.millennia), 204 | ("centuries", self.centuries), 205 | ("decades", self.decades), 206 | ("years", self.years), 207 | ("months", self.months), 208 | ("weeks", self.weeks), 209 | ("days", self.days), 210 | ("hours", self.hours), 211 | ("minutes", self.minutes), 212 | ("seconds", self.seconds), 213 | ("microseconds", self.microseconds), 214 | ) 215 | if v is not None 216 | } 217 | 218 | 
def __str__(self): 219 | return " ".join(f"{v} {n}" for n, v in self._value_dict().items()) 220 | 221 | def normalize(self): 222 | months = 0 223 | if self.months is not None: 224 | months += self.months 225 | if self.years is not None: 226 | months += self.years * 12 227 | 228 | days = 0 229 | if self.days is not None: 230 | days += self.days 231 | if self.weeks is not None: 232 | days += self.weeks * 7 233 | 234 | seconds = 0 235 | if self.hours is not None: 236 | seconds += self.hours * 60 * 60 237 | if self.minutes is not None: 238 | seconds += self.minutes * 60 239 | if self.seconds is not None: 240 | seconds += self.seconds 241 | if self.microseconds is not None: 242 | seconds += self.microseconds / 1000000 243 | 244 | return PGInterval(months=months, days=days, seconds=seconds) 245 | 246 | def __eq__(self, other): 247 | if isinstance(other, PGInterval): 248 | s = self.normalize() 249 | o = other.normalize() 250 | return s.months == o.months and s.days == o.days and s.seconds == o.seconds 251 | else: 252 | return False 253 | 254 | def to_timedelta(self): 255 | pairs = self._value_dict() 256 | overlap = pairs.keys() & { 257 | "weeks", 258 | "months", 259 | "years", 260 | "decades", 261 | "centuries", 262 | "millennia", 263 | } 264 | if len(overlap) > 0: 265 | raise ValueError( 266 | "Can't fit the interval fields {overlap} into a datetime.timedelta." 
267 | ) 268 | 269 | return Timedelta(**pairs) 270 | 271 | 272 | class Range: 273 | def __init__( 274 | self, 275 | lower=None, 276 | upper=None, 277 | bounds="[)", 278 | is_empty=False, 279 | ): 280 | self.lower = lower 281 | self.upper = upper 282 | self.bounds = bounds 283 | self.is_empty = is_empty 284 | 285 | def __eq__(self, other): 286 | if isinstance(other, Range): 287 | if self.is_empty or other.is_empty: 288 | return self.is_empty == other.is_empty 289 | else: 290 | return ( 291 | self.lower == other.lower 292 | and self.upper == other.upper 293 | and self.bounds == other.bounds 294 | ) 295 | return False 296 | 297 | def __str__(self): 298 | if self.is_empty: 299 | return "empty" 300 | else: 301 | le, ue = ["" if v is None else v for v in (self.lower, self.upper)] 302 | return f"{self.bounds[0]}{le},{ue}{self.bounds[1]}" 303 | 304 | def __repr__(self): 305 | return f"" 306 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tlocke/pg8000/46c00021ade1d19466b07ed30392386c5f0a6b8e/test/__init__.py -------------------------------------------------------------------------------- /test/dbapi/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tlocke/pg8000/46c00021ade1d19466b07ed30392386c5f0a6b8e/test/dbapi/__init__.py -------------------------------------------------------------------------------- /test/dbapi/auth/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tlocke/pg8000/46c00021ade1d19466b07ed30392386c5f0a6b8e/test/dbapi/auth/__init__.py -------------------------------------------------------------------------------- /test/dbapi/auth/test_gss.py: -------------------------------------------------------------------------------- 1 | import 
pytest 2 | 3 | from pg8000.dbapi import InterfaceError, connect 4 | 5 | 6 | # This requires a line in pg_hba.conf that requires gss for the database 7 | # pg8000_gss 8 | 9 | 10 | def test_gss(db_kwargs): 11 | db_kwargs["database"] = "pg8000_gss" 12 | 13 | # Should raise an exception saying gss isn't supported 14 | with pytest.raises( 15 | InterfaceError, 16 | match="Authentication method 7 not supported by pg8000.", 17 | ): 18 | with connect(**db_kwargs): 19 | pass 20 | -------------------------------------------------------------------------------- /test/dbapi/auth/test_md5.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pg8000.dbapi import DatabaseError, connect 4 | 5 | 6 | # This requires a line in pg_hba.conf that requires md5 for the database 7 | # pg8000_md5 8 | 9 | 10 | def test_md5(db_kwargs): 11 | db_kwargs["database"] = "pg8000_md5" 12 | 13 | # Should only raise an exception saying db doesn't exist 14 | with pytest.raises(DatabaseError, match="3D000"): 15 | with connect(**db_kwargs): 16 | pass 17 | -------------------------------------------------------------------------------- /test/dbapi/auth/test_md5_ssl.py: -------------------------------------------------------------------------------- 1 | import ssl 2 | 3 | from pg8000.dbapi import connect 4 | 5 | 6 | def test_md5_ssl(db_kwargs): 7 | context = ssl.create_default_context() 8 | context.check_hostname = False 9 | context.verify_mode = ssl.CERT_NONE 10 | db_kwargs["ssl_context"] = context 11 | with connect(**db_kwargs): 12 | pass 13 | -------------------------------------------------------------------------------- /test/dbapi/auth/test_password.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pg8000.dbapi import DatabaseError, connect 4 | 5 | # This requires a line in pg_hba.conf that requires 'password' for the 6 | # database pg8000_password 7 | 8 | 9 | def 
test_password(db_kwargs): 10 | db_kwargs["database"] = "pg8000_password" 11 | 12 | # Should only raise an exception saying db doesn't exist 13 | with pytest.raises(DatabaseError, match="3D000"): 14 | with connect(**db_kwargs): 15 | pass 16 | -------------------------------------------------------------------------------- /test/dbapi/auth/test_scram-sha-256.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pg8000.dbapi import DatabaseError, connect 4 | 5 | # This requires a line in pg_hba.conf that requires scram-sha-256 for the 6 | # database pg8000_scram_sha_256 7 | 8 | DB = "pg8000_scram_sha_256" 9 | 10 | 11 | @pytest.fixture 12 | def setup(con, cursor): 13 | con.autocommit = True 14 | try: 15 | cursor.execute(f"CREATE DATABASE {DB}") 16 | except DatabaseError: 17 | con.rollback() 18 | 19 | cursor.execute("ALTER SYSTEM SET ssl = off") 20 | cursor.execute("SELECT pg_reload_conf()") 21 | yield 22 | cursor.execute("ALTER SYSTEM SET ssl = on") 23 | cursor.execute("SELECT pg_reload_conf()") 24 | 25 | 26 | def test_scram_sha_256(setup, db_kwargs): 27 | db_kwargs["database"] = DB 28 | 29 | with connect(**db_kwargs): 30 | pass 31 | -------------------------------------------------------------------------------- /test/dbapi/auth/test_scram-sha-256_ssl.py: -------------------------------------------------------------------------------- 1 | from ssl import CERT_NONE, create_default_context 2 | import pytest 3 | 4 | from pg8000.dbapi import DatabaseError, connect 5 | 6 | # This requires a line in pg_hba.conf that requires scram-sha-256 for the 7 | # database pg8000_scram_sha_256 8 | 9 | DB = "pg8000_scram_sha_256" 10 | 11 | 12 | @pytest.fixture 13 | def setup(con, cursor): 14 | con.autocommit = True 15 | try: 16 | cursor.execute(f"CREATE DATABASE {DB}") 17 | except DatabaseError: 18 | con.rollback() 19 | 20 | 21 | def test_scram_sha_256(setup, db_kwargs): 22 | db_kwargs["database"] = DB 23 | 24 | con = 
connect(**db_kwargs) 25 | con.close() 26 | 27 | 28 | def test_scram_sha_256_ssl_context(setup, db_kwargs): 29 | ssl_context = create_default_context() 30 | ssl_context.check_hostname = False 31 | ssl_context.verify_mode = CERT_NONE 32 | 33 | db_kwargs["database"] = DB 34 | db_kwargs["ssl_context"] = ssl_context 35 | 36 | con = connect(**db_kwargs) 37 | con.close() 38 | -------------------------------------------------------------------------------- /test/dbapi/conftest.py: -------------------------------------------------------------------------------- 1 | from os import environ 2 | 3 | import pytest 4 | 5 | from test.utils import parse_server_version 6 | 7 | import pg8000.dbapi 8 | 9 | 10 | @pytest.fixture(scope="class") 11 | def db_kwargs(): 12 | db_connect = {"user": "postgres", "password": "pw"} 13 | 14 | for kw, var, f in [ 15 | ("host", "PGHOST", str), 16 | ("password", "PGPASSWORD", str), 17 | ("port", "PGPORT", int), 18 | ]: 19 | try: 20 | db_connect[kw] = f(environ[var]) 21 | except KeyError: 22 | pass 23 | 24 | return db_connect 25 | 26 | 27 | @pytest.fixture 28 | def con(request, db_kwargs): 29 | conn = pg8000.dbapi.connect(**db_kwargs) 30 | 31 | def fin(): 32 | try: 33 | conn.rollback() 34 | except pg8000.dbapi.InterfaceError: 35 | pass 36 | 37 | try: 38 | conn.close() 39 | except pg8000.dbapi.InterfaceError: 40 | pass 41 | 42 | request.addfinalizer(fin) 43 | return conn 44 | 45 | 46 | @pytest.fixture 47 | def cursor(request, con): 48 | cursor = con.cursor() 49 | 50 | def fin(): 51 | cursor.close() 52 | 53 | request.addfinalizer(fin) 54 | return cursor 55 | 56 | 57 | @pytest.fixture 58 | def pg_version(cursor): 59 | cursor.execute("select current_setting('server_version')") 60 | retval = cursor.fetchall() 61 | version = retval[0][0] 62 | major = parse_server_version(version) 63 | return major 64 | -------------------------------------------------------------------------------- /test/dbapi/test_benchmarks.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pg8000.dbapi import connect 4 | 5 | 6 | @pytest.mark.parametrize( 7 | "txt", 8 | ( 9 | ("int2", "cast(id / 100 as int2)"), 10 | "cast(id as int4)", 11 | "cast(id * 100 as int8)", 12 | "(id % 2) = 0", 13 | "N'Static text string'", 14 | "cast(id / 100 as float4)", 15 | "cast(id / 100 as float8)", 16 | "cast(id / 100 as numeric)", 17 | "timestamp '2001-09-28'", 18 | ), 19 | ) 20 | def test_round_trips(db_kwargs, benchmark, txt): 21 | def torun(): 22 | with connect(**db_kwargs) as con: 23 | query = f"""SELECT {txt}, {txt}, {txt}, {txt}, {txt}, {txt}, {txt} 24 | FROM (SELECT generate_series(1, 10000) AS id) AS tbl""" 25 | cursor = con.cursor() 26 | cursor.execute(query) 27 | cursor.fetchall() 28 | cursor.close() 29 | 30 | benchmark(torun) 31 | -------------------------------------------------------------------------------- /test/dbapi/test_connection.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import socket 3 | import warnings 4 | 5 | import pytest 6 | 7 | from pg8000.dbapi import DatabaseError, InterfaceError, __version__, connect 8 | 9 | 10 | def test_unix_socket_missing(): 11 | conn_params = {"unix_sock": "/file-does-not-exist", "user": "doesn't-matter"} 12 | 13 | with pytest.raises(InterfaceError): 14 | with connect(**conn_params): 15 | pass 16 | 17 | 18 | def test_internet_socket_connection_refused(): 19 | conn_params = {"port": 0, "user": "doesn't-matter"} 20 | 21 | with pytest.raises( 22 | InterfaceError, 23 | match="Can't create a connection to host localhost and port 0 " 24 | "\\(timeout is None and source_address is None\\).", 25 | ): 26 | with connect(**conn_params): 27 | pass 28 | 29 | 30 | def test_Connection_plain_socket(db_kwargs): 31 | host = db_kwargs.get("host", "localhost") 32 | port = db_kwargs.get("port", 5432) 33 | with socket.create_connection((host, port)) as sock: 34 | 
conn_params = { 35 | "sock": sock, 36 | "user": db_kwargs["user"], 37 | "password": db_kwargs["password"], 38 | "ssl_context": False, 39 | } 40 | 41 | with connect(**conn_params) as con: 42 | cur = con.cursor() 43 | 44 | cur.execute("SELECT 1") 45 | res = cur.fetchall() 46 | assert res[0][0] == 1 47 | 48 | 49 | def test_database_missing(db_kwargs): 50 | db_kwargs["database"] = "missing-db" 51 | with pytest.raises(DatabaseError): 52 | with connect(**db_kwargs): 53 | pass 54 | 55 | 56 | def test_database_name_unicode(db_kwargs): 57 | db_kwargs["database"] = "pg8000_sn\uFF6Fw" 58 | 59 | # Should only raise an exception saying db doesn't exist 60 | with pytest.raises(DatabaseError, match="3D000"): 61 | with connect(**db_kwargs): 62 | pass 63 | 64 | 65 | def test_database_name_bytes(db_kwargs): 66 | """Should only raise an exception saying db doesn't exist""" 67 | 68 | db_kwargs["database"] = bytes("pg8000_sn\uFF6Fw", "utf8") 69 | with pytest.raises(DatabaseError, match="3D000"): 70 | with connect(**db_kwargs): 71 | pass 72 | 73 | 74 | def test_password_bytes(con, db_kwargs): 75 | # Create user 76 | username = "boltzmann" 77 | password = "cha\uFF6Fs" 78 | cur = con.cursor() 79 | cur.execute("create user " + username + " with password '" + password + "';") 80 | con.commit() 81 | 82 | db_kwargs["user"] = username 83 | db_kwargs["password"] = password.encode("utf8") 84 | db_kwargs["database"] = "pg8000_md5" 85 | with pytest.raises(DatabaseError, match="3D000"): 86 | with connect(**db_kwargs): 87 | pass 88 | 89 | cur.execute("drop role " + username) 90 | con.commit() 91 | 92 | 93 | def test_application_name(db_kwargs): 94 | app_name = "my test application name" 95 | db_kwargs["application_name"] = app_name 96 | with connect(**db_kwargs) as db: 97 | cur = db.cursor() 98 | cur.execute( 99 | "select application_name from pg_stat_activity " 100 | " where pid = pg_backend_pid()" 101 | ) 102 | 103 | application_name = cur.fetchone()[0] 104 | assert application_name == app_name 
105 | 106 | 107 | def test_application_name_integer(db_kwargs): 108 | db_kwargs["application_name"] = 1 109 | with pytest.raises( 110 | InterfaceError, 111 | match="The parameter application_name can't be of type " ".", 112 | ): 113 | with connect(**db_kwargs): 114 | pass 115 | 116 | 117 | def test_application_name_bytearray(db_kwargs): 118 | db_kwargs["application_name"] = bytearray(b"Philby") 119 | with connect(**db_kwargs): 120 | pass 121 | 122 | 123 | def test_notify(con): 124 | cursor = con.cursor() 125 | cursor.execute("select pg_backend_pid()") 126 | backend_pid = cursor.fetchall()[0][0] 127 | assert list(con.notifications) == [] 128 | cursor.execute("LISTEN test") 129 | cursor.execute("NOTIFY test") 130 | con.commit() 131 | 132 | cursor.execute("VALUES (1, 2), (3, 4), (5, 6)") 133 | assert len(con.notifications) == 1 134 | assert con.notifications[0] == (backend_pid, "test", "") 135 | 136 | 137 | def test_notify_with_payload(con): 138 | cursor = con.cursor() 139 | cursor.execute("select pg_backend_pid()") 140 | backend_pid = cursor.fetchall()[0][0] 141 | assert list(con.notifications) == [] 142 | cursor.execute("LISTEN test") 143 | cursor.execute("NOTIFY test, 'Parnham'") 144 | con.commit() 145 | 146 | cursor.execute("VALUES (1, 2), (3, 4), (5, 6)") 147 | assert len(con.notifications) == 1 148 | assert con.notifications[0] == (backend_pid, "test", "Parnham") 149 | 150 | 151 | def test_broken_pipe_read(con, db_kwargs): 152 | db1 = connect(**db_kwargs) 153 | cur1 = db1.cursor() 154 | cur2 = con.cursor() 155 | cur1.execute("select pg_backend_pid()") 156 | pid1 = cur1.fetchone()[0] 157 | 158 | cur2.execute("select pg_terminate_backend(%s)", (pid1,)) 159 | with pytest.raises(InterfaceError, match="network error"): 160 | cur1.execute("select 1") 161 | 162 | try: 163 | db1.close() 164 | except InterfaceError: 165 | pass 166 | 167 | 168 | def test_broken_pipe_flush(con, db_kwargs): 169 | db1 = connect(**db_kwargs) 170 | cur1 = db1.cursor() 171 | cur2 = con.cursor() 
172 | cur1.execute("select pg_backend_pid()") 173 | pid1 = cur1.fetchone()[0] 174 | 175 | cur2.execute("select pg_terminate_backend(%s)", (pid1,)) 176 | try: 177 | cur1.execute("select 1") 178 | except BaseException: 179 | pass 180 | 181 | # Sometimes raises and sometime doesn't 182 | try: 183 | db1.close() 184 | except InterfaceError as e: 185 | assert str(e) == "network error" 186 | 187 | 188 | def test_broken_pipe_unpack(con): 189 | cur = con.cursor() 190 | cur.execute("select pg_backend_pid()") 191 | pid1 = cur.fetchone()[0] 192 | 193 | with pytest.raises(InterfaceError, match="network error"): 194 | cur.execute("select pg_terminate_backend(%s)", (pid1,)) 195 | 196 | 197 | def test_py_value_fail(con, mocker): 198 | # Ensure that if types.py_value throws an exception, the original 199 | # exception is raised (PG8000TestException), and the connection is 200 | # still usable after the error. 201 | 202 | class PG8000TestException(Exception): 203 | pass 204 | 205 | def raise_exception(val): 206 | raise PG8000TestException("oh noes!") 207 | 208 | mocker.patch.object(con, "py_types") 209 | con.py_types = {datetime.time: raise_exception} 210 | 211 | with pytest.raises(PG8000TestException): 212 | c = con.cursor() 213 | c.execute("SELECT CAST(%s AS TIME) AS f1", (datetime.time(10, 30),)) 214 | c.fetchall() 215 | 216 | # ensure that the connection is still usable for a new query 217 | c.execute("VALUES ('hw3'::text)") 218 | assert c.fetchone()[0] == "hw3" 219 | 220 | 221 | def test_no_data_error_recovery(con): 222 | for i in range(1, 4): 223 | with pytest.raises(DatabaseError) as e: 224 | c = con.cursor() 225 | c.execute("DROP TABLE t1") 226 | assert e.value.args[0]["C"] == "42P01" 227 | con.rollback() 228 | 229 | 230 | def test_closed_connection(db_kwargs): 231 | warnings.simplefilter("ignore") 232 | 233 | my_db = connect(**db_kwargs) 234 | cursor = my_db.cursor() 235 | my_db.close() 236 | with pytest.raises(my_db.InterfaceError, match="connection is closed"): 237 | 
cursor.execute("VALUES ('hw1'::text)") 238 | 239 | warnings.resetwarnings() 240 | 241 | 242 | def test_version(): 243 | try: 244 | from importlib.metadata import version 245 | except ImportError: 246 | from importlib_metadata import version 247 | 248 | ver = version("pg8000") 249 | 250 | assert __version__ == ver 251 | 252 | 253 | @pytest.mark.parametrize( 254 | "commit", 255 | [ 256 | "commit", 257 | "COMMIT;", 258 | ], 259 | ) 260 | def test_failed_transaction_commit_sql(cursor, commit): 261 | cursor.execute("create temporary table tt (f1 int primary key)") 262 | cursor.execute("begin") 263 | try: 264 | cursor.execute("insert into tt(f1) values(null)") 265 | except DatabaseError: 266 | pass 267 | 268 | with pytest.raises(InterfaceError): 269 | cursor.execute(commit) 270 | 271 | 272 | def test_failed_transaction_commit_method(con, cursor): 273 | cursor.execute("create temporary table tt (f1 int primary key)") 274 | cursor.execute("begin") 275 | try: 276 | cursor.execute("insert into tt(f1) values(null)") 277 | except DatabaseError: 278 | pass 279 | 280 | with pytest.raises(InterfaceError): 281 | con.commit() 282 | 283 | 284 | @pytest.mark.parametrize( 285 | "rollback", 286 | [ 287 | "rollback", 288 | "rollback;", 289 | "ROLLBACK ;", 290 | ], 291 | ) 292 | def test_failed_transaction_rollback_sql(cursor, rollback): 293 | cursor.execute("create temporary table tt (f1 int primary key)") 294 | cursor.execute("begin") 295 | try: 296 | cursor.execute("insert into tt(f1) values(null)") 297 | except DatabaseError: 298 | pass 299 | 300 | cursor.execute(rollback) 301 | 302 | 303 | def test_failed_transaction_rollback_method(cursor, con): 304 | cursor.execute("create temporary table tt (f1 int primary key)") 305 | cursor.execute("begin") 306 | try: 307 | cursor.execute("insert into tt(f1) values(null)") 308 | except DatabaseError: 309 | pass 310 | 311 | con.rollback() 312 | 313 | 314 | @pytest.mark.parametrize( 315 | "sql", 316 | [ 317 | "BEGIN", 318 | "select * from tt;", 
319 | ], 320 | ) 321 | def test_failed_transaction_sql(cursor, sql): 322 | cursor.execute("create temporary table tt (f1 int primary key)") 323 | cursor.execute("begin") 324 | try: 325 | cursor.execute("insert into tt(f1) values(null)") 326 | except DatabaseError: 327 | pass 328 | 329 | with pytest.raises(DatabaseError): 330 | cursor.execute(sql) 331 | -------------------------------------------------------------------------------- /test/dbapi/test_copy.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import pytest 4 | 5 | 6 | @pytest.fixture 7 | def db_table(request, con): 8 | cursor = con.cursor() 9 | cursor.execute( 10 | "CREATE TEMPORARY TABLE t1 (f1 int primary key, " 11 | "f2 int not null, f3 varchar(50) null) on commit drop" 12 | ) 13 | return con 14 | 15 | 16 | def test_copy_to_with_table(db_table): 17 | cursor = db_table.cursor() 18 | cursor.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (1, 1, 1)) 19 | cursor.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (2, 2, 2)) 20 | cursor.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (3, 3, 3)) 21 | 22 | stream = BytesIO() 23 | cursor.execute("copy t1 to stdout", stream=stream) 24 | assert stream.getvalue() == b"1\t1\t1\n2\t2\t2\n3\t3\t3\n" 25 | assert cursor.rowcount == 3 26 | 27 | 28 | def test_copy_to_with_query(db_table): 29 | cursor = db_table.cursor() 30 | stream = BytesIO() 31 | cursor.execute( 32 | "COPY (SELECT 1 as One, 2 as Two) TO STDOUT WITH DELIMITER " 33 | "'X' CSV HEADER QUOTE AS 'Y' FORCE QUOTE Two", 34 | stream=stream, 35 | ) 36 | assert stream.getvalue() == b"oneXtwo\n1XY2Y\n" 37 | assert cursor.rowcount == 1 38 | 39 | 40 | def test_copy_from_with_table(db_table): 41 | cursor = db_table.cursor() 42 | stream = BytesIO(b"1\t1\t1\n2\t2\t2\n3\t3\t3\n") 43 | cursor.execute("copy t1 from STDIN", stream=stream) 44 | assert cursor.rowcount == 3 45 | 46 | cursor.execute("SELECT * FROM t1 ORDER BY 
f1") 47 | retval = cursor.fetchall() 48 | assert retval == ([1, 1, "1"], [2, 2, "2"], [3, 3, "3"]) 49 | 50 | 51 | def test_copy_from_with_query(db_table): 52 | cursor = db_table.cursor() 53 | stream = BytesIO(b"f1Xf2\n1XY1Y\n") 54 | cursor.execute( 55 | "COPY t1 (f1, f2) FROM STDIN WITH DELIMITER 'X' CSV HEADER " 56 | "QUOTE AS 'Y' FORCE NOT NULL f1", 57 | stream=stream, 58 | ) 59 | assert cursor.rowcount == 1 60 | 61 | cursor.execute("SELECT * FROM t1 ORDER BY f1") 62 | retval = cursor.fetchall() 63 | assert retval == ([1, 1, None],) 64 | 65 | 66 | def test_copy_from_with_error(db_table): 67 | cursor = db_table.cursor() 68 | stream = BytesIO(b"f1Xf2\n\n1XY1Y\n") 69 | with pytest.raises(BaseException) as e: 70 | cursor.execute( 71 | "COPY t1 (f1, f2) FROM STDIN WITH DELIMITER 'X' CSV HEADER " 72 | "QUOTE AS 'Y' FORCE NOT NULL f1", 73 | stream=stream, 74 | ) 75 | 76 | arg = { 77 | "S": ("ERROR",), 78 | "C": ("22P02",), 79 | "M": ( 80 | 'invalid input syntax for type integer: ""', 81 | 'invalid input syntax for integer: ""', 82 | ), 83 | "W": ('COPY t1, line 2, column f1: ""',), 84 | "F": ("numutils.c",), 85 | "R": ("pg_atoi", "pg_strtoint32", "pg_strtoint32_safe"), 86 | } 87 | earg = e.value.args[0] 88 | for k, v in arg.items(): 89 | assert earg[k] in v 90 | -------------------------------------------------------------------------------- /test/dbapi/test_dbapi.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import time 4 | 5 | import pytest 6 | 7 | from pg8000.dbapi import ( 8 | BINARY, 9 | Binary, 10 | Date, 11 | DateFromTicks, 12 | Time, 13 | TimeFromTicks, 14 | Timestamp, 15 | TimestampFromTicks, 16 | ) 17 | 18 | 19 | @pytest.fixture 20 | def has_tzset(): 21 | # Neither Windows nor Jython 2.5.3 have a time.tzset() so skip 22 | if hasattr(time, "tzset"): 23 | os.environ["TZ"] = "UTC" 24 | time.tzset() 25 | return True 26 | return False 27 | 28 | 29 | # DBAPI compatible interface tests 30 | 
@pytest.fixture 31 | def db_table(con, has_tzset): 32 | c = con.cursor() 33 | c.execute( 34 | "CREATE TEMPORARY TABLE t1 " 35 | "(f1 int primary key, f2 int not null, f3 varchar(50) null) " 36 | "ON COMMIT DROP" 37 | ) 38 | c.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (1, 1, None)) 39 | c.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (2, 10, None)) 40 | c.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (3, 100, None)) 41 | c.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (4, 1000, None)) 42 | c.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (5, 10000, None)) 43 | return con 44 | 45 | 46 | def test_parallel_queries(db_table): 47 | c1 = db_table.cursor() 48 | c2 = db_table.cursor() 49 | 50 | c1.execute("SELECT f1, f2, f3 FROM t1") 51 | while 1: 52 | row = c1.fetchone() 53 | if row is None: 54 | break 55 | f1, f2, f3 = row 56 | c2.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > %s", (f1,)) 57 | while 1: 58 | row = c2.fetchone() 59 | if row is None: 60 | break 61 | f1, f2, f3 = row 62 | 63 | 64 | def test_qmark(mocker, db_table): 65 | mocker.patch("pg8000.dbapi.paramstyle", "qmark") 66 | c1 = db_table.cursor() 67 | c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > ?", (3,)) 68 | while 1: 69 | row = c1.fetchone() 70 | if row is None: 71 | break 72 | f1, f2, f3 = row 73 | 74 | 75 | def test_numeric(mocker, db_table): 76 | mocker.patch("pg8000.dbapi.paramstyle", "numeric") 77 | c1 = db_table.cursor() 78 | c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > :1", (3,)) 79 | while 1: 80 | row = c1.fetchone() 81 | if row is None: 82 | break 83 | f1, f2, f3 = row 84 | 85 | 86 | def test_named(mocker, db_table): 87 | mocker.patch("pg8000.dbapi.paramstyle", "named") 88 | c1 = db_table.cursor() 89 | c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > :f1", {"f1": 3}) 90 | while 1: 91 | row = c1.fetchone() 92 | if row is None: 93 | break 94 | f1, f2, f3 = row 95 | 96 | 97 | def test_format(mocker, db_table): 98 | 
mocker.patch("pg8000.dbapi.paramstyle", "format") 99 | c1 = db_table.cursor() 100 | c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > %s", (3,)) 101 | while 1: 102 | row = c1.fetchone() 103 | if row is None: 104 | break 105 | f1, f2, f3 = row 106 | 107 | 108 | def test_pyformat(mocker, db_table): 109 | mocker.patch("pg8000.dbapi.paramstyle", "pyformat") 110 | c1 = db_table.cursor() 111 | c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > %(f1)s", {"f1": 3}) 112 | while 1: 113 | row = c1.fetchone() 114 | if row is None: 115 | break 116 | f1, f2, f3 = row 117 | 118 | 119 | def test_arraysize(db_table): 120 | c1 = db_table.cursor() 121 | c1.arraysize = 3 122 | c1.execute("SELECT * FROM t1") 123 | retval = c1.fetchmany() 124 | assert len(retval) == c1.arraysize 125 | 126 | 127 | def test_date(): 128 | val = Date(2001, 2, 3) 129 | assert val == datetime.date(2001, 2, 3) 130 | 131 | 132 | def test_time(): 133 | val = Time(4, 5, 6) 134 | assert val == datetime.time(4, 5, 6) 135 | 136 | 137 | def test_timestamp(): 138 | val = Timestamp(2001, 2, 3, 4, 5, 6) 139 | assert val == datetime.datetime(2001, 2, 3, 4, 5, 6) 140 | 141 | 142 | def test_date_from_ticks(has_tzset): 143 | if has_tzset: 144 | val = DateFromTicks(1173804319) 145 | assert val == datetime.date(2007, 3, 13) 146 | 147 | 148 | def testTimeFromTicks(has_tzset): 149 | if has_tzset: 150 | val = TimeFromTicks(1173804319) 151 | assert val == datetime.time(16, 45, 19) 152 | 153 | 154 | def test_timestamp_from_ticks(has_tzset): 155 | if has_tzset: 156 | val = TimestampFromTicks(1173804319) 157 | assert val == datetime.datetime(2007, 3, 13, 16, 45, 19) 158 | 159 | 160 | def test_binary(): 161 | v = Binary(b"\x00\x01\x02\x03\x02\x01\x00") 162 | assert v == b"\x00\x01\x02\x03\x02\x01\x00" 163 | assert isinstance(v, BINARY) 164 | 165 | 166 | def test_row_count(db_table): 167 | c1 = db_table.cursor() 168 | c1.execute("SELECT * FROM t1") 169 | 170 | assert 5 == c1.rowcount 171 | 172 | c1.execute("UPDATE t1 SET f3 = %s WHERE f2 
> 101", ("Hello!",)) 173 | assert 2 == c1.rowcount 174 | 175 | c1.execute("DELETE FROM t1") 176 | assert 5 == c1.rowcount 177 | 178 | 179 | def test_fetch_many(db_table): 180 | cursor = db_table.cursor() 181 | cursor.arraysize = 2 182 | cursor.execute("SELECT * FROM t1") 183 | assert 2 == len(cursor.fetchmany()) 184 | assert 2 == len(cursor.fetchmany()) 185 | assert 1 == len(cursor.fetchmany()) 186 | assert 0 == len(cursor.fetchmany()) 187 | 188 | 189 | def test_iterator(db_table): 190 | cursor = db_table.cursor() 191 | cursor.execute("SELECT * FROM t1 ORDER BY f1") 192 | f1 = 0 193 | for row in cursor.fetchall(): 194 | next_f1 = row[0] 195 | assert next_f1 > f1 196 | f1 = next_f1 197 | 198 | 199 | # Vacuum can't be run inside a transaction, so we need to turn 200 | # autocommit on. 201 | def test_vacuum(con): 202 | con.autocommit = True 203 | cursor = con.cursor() 204 | cursor.execute("vacuum") 205 | 206 | 207 | def test_prepared_statement(con): 208 | cursor = con.cursor() 209 | cursor.execute("PREPARE gen_series AS SELECT generate_series(1, 10);") 210 | cursor.execute("EXECUTE gen_series") 211 | 212 | 213 | def test_cursor_type(cursor): 214 | assert str(type(cursor)) == "" 215 | -------------------------------------------------------------------------------- /test/dbapi/test_dbapi20.py: -------------------------------------------------------------------------------- 1 | import time 2 | import warnings 3 | 4 | import pytest 5 | 6 | import pg8000 7 | 8 | """ Python DB API 2.0 driver compliance unit test suite. 9 | 10 | This software is Public Domain and may be used without restrictions. 11 | 12 | "Now we have booze and barflies entering the discussion, plus rumours of 13 | DBAs on drugs... and I won't tell you what flashes through my mind each 14 | time I read the subject line with 'Anal Compliance' in it. All around 15 | this is turning out to be a thoroughly unwholesome unit test." 
16 | 17 | -- Ian Bicking 18 | """ 19 | 20 | __rcs_id__ = "$Id: dbapi20.py,v 1.10 2003/10/09 03:14:14 zenzen Exp $" 21 | __version__ = "$Revision: 1.10 $"[11:-2] 22 | __author__ = "Stuart Bishop " 23 | 24 | 25 | # $Log: dbapi20.py,v $ 26 | # Revision 1.10 2003/10/09 03:14:14 zenzen 27 | # Add test for DB API 2.0 optional extension, where database exceptions 28 | # are exposed as attributes on the Connection object. 29 | # 30 | # Revision 1.9 2003/08/13 01:16:36 zenzen 31 | # Minor tweak from Stefan Fleiter 32 | # 33 | # Revision 1.8 2003/04/10 00:13:25 zenzen 34 | # Changes, as per suggestions by M.-A. Lemburg 35 | # - Add a table prefix, to ensure namespace collisions can always be avoided 36 | # 37 | # Revision 1.7 2003/02/26 23:33:37 zenzen 38 | # Break out DDL into helper functions, as per request by David Rushby 39 | # 40 | # Revision 1.6 2003/02/21 03:04:33 zenzen 41 | # Stuff from Henrik Ekelund: 42 | # added test_None 43 | # added test_nextset & hooks 44 | # 45 | # Revision 1.5 2003/02/17 22:08:43 zenzen 46 | # Implement suggestions and code from Henrik Eklund - test that 47 | # cursor.arraysize defaults to 1 & generic cursor.callproc test added 48 | # 49 | # Revision 1.4 2003/02/15 00:16:33 zenzen 50 | # Changes, as per suggestions and bug reports by M.-A. Lemburg, 51 | # Matthew T. 
Kromer, Federico Di Gregorio and Daniel Dittmar 52 | # - Class renamed 53 | # - Now a subclass of TestCase, to avoid requiring the driver stub 54 | # to use multiple inheritance 55 | # - Reversed the polarity of buggy test in test_description 56 | # - Test exception hierarchy correctly 57 | # - self.populate is now self._populate(), so if a driver stub 58 | # overrides self.ddl1 this change propagates 59 | # - VARCHAR columns now have a width, which will hopefully make the 60 | # DDL even more portible (this will be reversed if it causes more problems) 61 | # - cursor.rowcount being checked after various execute and fetchXXX methods 62 | # - Check for fetchall and fetchmany returning empty lists after results 63 | # are exhausted (already checking for empty lists if select retrieved 64 | # nothing 65 | # - Fix bugs in test_setoutputsize_basic and test_setinputsizes 66 | # 67 | 68 | 69 | """ Test a database self.driver for DB API 2.0 compatibility. 70 | This implementation tests Gadfly, but the TestCase 71 | is structured so that other self.drivers can subclass this 72 | test case to ensure compiliance with the DB-API. It is 73 | expected that this TestCase may be expanded in the future 74 | if ambiguities or edge conditions are discovered. 75 | 76 | The 'Optional Extensions' are not yet being tested. 77 | 78 | self.drivers should subclass this test, overriding setUp, tearDown, 79 | self.driver, connect_args and connect_kw_args. Class specification 80 | should be as follows: 81 | 82 | import dbapi20 83 | class mytest(dbapi20.DatabaseAPI20Test): 84 | [...] 85 | 86 | Don't 'import DatabaseAPI20Test from dbapi20', or you will 87 | confuse the unit tester - just 'import dbapi20'. 88 | """ 89 | 90 | # The self.driver module. 
This should be the module where the 'connect' 91 | # method is to be found 92 | driver = pg8000 93 | table_prefix = "dbapi20test_" # If you need to specify a prefix for tables 94 | 95 | ddl1 = "create table %sbooze (name varchar(20))" % table_prefix 96 | ddl2 = "create table %sbarflys (name varchar(20))" % table_prefix 97 | xddl1 = "drop table %sbooze" % table_prefix 98 | xddl2 = "drop table %sbarflys" % table_prefix 99 | 100 | # Name of stored procedure to convert 101 | # string->lowercase 102 | lowerfunc = "lower" 103 | 104 | 105 | # Some drivers may need to override these helpers, for example adding 106 | # a 'commit' after the execute. 107 | def executeDDL1(cursor): 108 | cursor.execute(ddl1) 109 | 110 | 111 | def executeDDL2(cursor): 112 | cursor.execute(ddl2) 113 | 114 | 115 | @pytest.fixture 116 | def db(request, con): 117 | def fin(): 118 | with con.cursor() as cur: 119 | for ddl in (xddl1, xddl2): 120 | try: 121 | cur.execute(ddl) 122 | con.commit() 123 | except driver.Error: 124 | # Assume table didn't exist. Other tests will check if 125 | # execute is busted. 
126 | pass 127 | 128 | request.addfinalizer(fin) 129 | return con 130 | 131 | 132 | def test_apilevel(): 133 | # Must exist 134 | apilevel = driver.apilevel 135 | 136 | # Must equal 2.0 137 | assert apilevel == "2.0" 138 | 139 | 140 | def test_threadsafety(): 141 | try: 142 | # Must exist 143 | threadsafety = driver.threadsafety 144 | # Must be a valid value 145 | assert threadsafety in (0, 1, 2, 3) 146 | except AttributeError: 147 | assert False, "Driver doesn't define threadsafety" 148 | 149 | 150 | def test_paramstyle(): 151 | try: 152 | # Must exist 153 | paramstyle = driver.paramstyle 154 | # Must be a valid value 155 | assert paramstyle in ("qmark", "numeric", "named", "format", "pyformat") 156 | except AttributeError: 157 | assert False, "Driver doesn't define paramstyle" 158 | 159 | 160 | def test_Exceptions(): 161 | # Make sure required exceptions exist, and are in the 162 | # defined hierarchy. 163 | assert issubclass(driver.Warning, Exception) 164 | assert issubclass(driver.Error, Exception) 165 | assert issubclass(driver.InterfaceError, driver.Error) 166 | assert issubclass(driver.DatabaseError, driver.Error) 167 | assert issubclass(driver.OperationalError, driver.Error) 168 | assert issubclass(driver.IntegrityError, driver.Error) 169 | assert issubclass(driver.InternalError, driver.Error) 170 | assert issubclass(driver.ProgrammingError, driver.Error) 171 | assert issubclass(driver.NotSupportedError, driver.Error) 172 | 173 | 174 | def test_ExceptionsAsConnectionAttributes(con): 175 | # OPTIONAL EXTENSION 176 | # Test for the optional DB API 2.0 extension, where the exceptions 177 | # are exposed as attributes on the Connection object 178 | # I figure this optional extension will be implemented by any 179 | # driver author who is using this test suite, so it is enabled 180 | # by default. 
181 | warnings.simplefilter("ignore") 182 | drv = driver 183 | assert con.Warning is drv.Warning 184 | assert con.Error is drv.Error 185 | assert con.InterfaceError is drv.InterfaceError 186 | assert con.DatabaseError is drv.DatabaseError 187 | assert con.OperationalError is drv.OperationalError 188 | assert con.IntegrityError is drv.IntegrityError 189 | assert con.InternalError is drv.InternalError 190 | assert con.ProgrammingError is drv.ProgrammingError 191 | assert con.NotSupportedError is drv.NotSupportedError 192 | warnings.resetwarnings() 193 | 194 | 195 | def test_commit(con): 196 | # Commit must work, even if it doesn't do anything 197 | con.commit() 198 | 199 | 200 | def test_rollback(con): 201 | # If rollback is defined, it should either work or throw 202 | # the documented exception 203 | if hasattr(con, "rollback"): 204 | try: 205 | con.rollback() 206 | except driver.NotSupportedError: 207 | pass 208 | 209 | 210 | def test_cursor(con): 211 | con.cursor() 212 | 213 | 214 | def test_cursor_isolation(con): 215 | # Make sure cursors created from the same connection have 216 | # the documented transaction isolation level 217 | cur1 = con.cursor() 218 | cur2 = con.cursor() 219 | executeDDL1(cur1) 220 | cur1.execute("insert into %sbooze values ('Victoria Bitter')" % (table_prefix)) 221 | cur2.execute("select name from %sbooze" % table_prefix) 222 | booze = cur2.fetchall() 223 | assert len(booze) == 1 224 | assert len(booze[0]) == 1 225 | assert booze[0][0] == "Victoria Bitter" 226 | 227 | 228 | def test_description(con): 229 | cur = con.cursor() 230 | executeDDL1(cur) 231 | assert cur.description is None, ( 232 | "cursor.description should be none after executing a " 233 | "statement that can return no rows (such as DDL)" 234 | ) 235 | cur.execute("select name from %sbooze" % table_prefix) 236 | assert len(cur.description) == 1, "cursor.description describes too many columns" 237 | assert ( 238 | len(cur.description[0]) == 7 239 | ), "cursor.description[x] 
tuples must have 7 elements" 240 | assert ( 241 | cur.description[0][0].lower() == "name" 242 | ), "cursor.description[x][0] must return column name" 243 | assert cur.description[0][1] == driver.STRING, ( 244 | "cursor.description[x][1] must return column type. Got %r" 245 | % cur.description[0][1] 246 | ) 247 | 248 | # Make sure self.description gets reset 249 | executeDDL2(cur) 250 | assert cur.description is None, ( 251 | "cursor.description not being set to None when executing " 252 | "no-result statements (eg. DDL)" 253 | ) 254 | 255 | 256 | def test_rowcount(cursor): 257 | executeDDL1(cursor) 258 | assert cursor.rowcount == -1, ( 259 | "cursor.rowcount should be -1 after executing no-result " "statements" 260 | ) 261 | cursor.execute("insert into %sbooze values ('Victoria Bitter')" % (table_prefix)) 262 | assert cursor.rowcount in (-1, 1), ( 263 | "cursor.rowcount should == number or rows inserted, or " 264 | "set to -1 after executing an insert statement" 265 | ) 266 | cursor.execute("select name from %sbooze" % table_prefix) 267 | assert cursor.rowcount in (-1, 1), ( 268 | "cursor.rowcount should == number of rows returned, or " 269 | "set to -1 after executing a select statement" 270 | ) 271 | executeDDL2(cursor) 272 | assert cursor.rowcount == -1, ( 273 | "cursor.rowcount not being reset to -1 after executing " "no-result statements" 274 | ) 275 | 276 | 277 | def test_close(con): 278 | cur = con.cursor() 279 | con.close() 280 | 281 | # cursor.execute should raise an Error if called after connection 282 | # closed 283 | with pytest.raises(driver.Error): 284 | executeDDL1(cur) 285 | 286 | # connection.commit should raise an Error if called after connection' 287 | # closed.' 
288 | with pytest.raises(driver.Error): 289 | con.commit() 290 | 291 | # connection.close should raise an Error if called more than once 292 | with pytest.raises(driver.Error): 293 | con.close() 294 | 295 | 296 | def test_execute(con): 297 | cur = con.cursor() 298 | _paraminsert(cur) 299 | 300 | 301 | def _paraminsert(cur): 302 | executeDDL1(cur) 303 | cur.execute("insert into %sbooze values ('Victoria Bitter')" % (table_prefix)) 304 | assert cur.rowcount in (-1, 1) 305 | 306 | if driver.paramstyle == "qmark": 307 | cur.execute("insert into %sbooze values (?)" % table_prefix, ("Cooper's",)) 308 | elif driver.paramstyle == "numeric": 309 | cur.execute("insert into %sbooze values (:1)" % table_prefix, ("Cooper's",)) 310 | elif driver.paramstyle == "named": 311 | cur.execute( 312 | "insert into %sbooze values (:beer)" % table_prefix, {"beer": "Cooper's"} 313 | ) 314 | elif driver.paramstyle == "format": 315 | cur.execute("insert into %sbooze values (%%s)" % table_prefix, ("Cooper's",)) 316 | elif driver.paramstyle == "pyformat": 317 | cur.execute( 318 | "insert into %sbooze values (%%(beer)s)" % table_prefix, 319 | {"beer": "Cooper's"}, 320 | ) 321 | else: 322 | assert False, "Invalid paramstyle" 323 | 324 | assert cur.rowcount in (-1, 1) 325 | 326 | cur.execute("select name from %sbooze" % table_prefix) 327 | res = cur.fetchall() 328 | assert len(res) == 2, "cursor.fetchall returned too few rows" 329 | beers = [res[0][0], res[1][0]] 330 | beers.sort() 331 | assert beers[0] == "Cooper's", ( 332 | "cursor.fetchall retrieved incorrect data, or data inserted " "incorrectly" 333 | ) 334 | assert beers[1] == "Victoria Bitter", ( 335 | "cursor.fetchall retrieved incorrect data, or data inserted " "incorrectly" 336 | ) 337 | 338 | 339 | def test_executemany(cursor): 340 | executeDDL1(cursor) 341 | largs = [("Cooper's",), ("Boag's",)] 342 | margs = [{"beer": "Cooper's"}, {"beer": "Boag's"}] 343 | if driver.paramstyle == "qmark": 344 | cursor.executemany("insert into %sbooze 
values (?)" % table_prefix, largs) 345 | elif driver.paramstyle == "numeric": 346 | cursor.executemany("insert into %sbooze values (:1)" % table_prefix, largs) 347 | elif driver.paramstyle == "named": 348 | cursor.executemany("insert into %sbooze values (:beer)" % table_prefix, margs) 349 | elif driver.paramstyle == "format": 350 | cursor.executemany("insert into %sbooze values (%%s)" % table_prefix, largs) 351 | elif driver.paramstyle == "pyformat": 352 | cursor.executemany( 353 | "insert into %sbooze values (%%(beer)s)" % (table_prefix), margs 354 | ) 355 | else: 356 | assert False, "Unknown paramstyle" 357 | 358 | assert cursor.rowcount in (-1, 2), ( 359 | "insert using cursor.executemany set cursor.rowcount to " 360 | "incorrect value %r" % cursor.rowcount 361 | ) 362 | 363 | cursor.execute("select name from %sbooze" % table_prefix) 364 | res = cursor.fetchall() 365 | assert len(res) == 2, "cursor.fetchall retrieved incorrect number of rows" 366 | beers = [res[0][0], res[1][0]] 367 | beers.sort() 368 | assert beers[0] == "Boag's", "incorrect data retrieved" 369 | assert beers[1] == "Cooper's", "incorrect data retrieved" 370 | 371 | 372 | def test_fetchone(cursor): 373 | # cursor.fetchone should raise an Error if called before 374 | # executing a select-type query 375 | with pytest.raises(driver.Error): 376 | cursor.fetchone() 377 | 378 | # cursor.fetchone should raise an Error if called after 379 | # executing a query that cannot return rows 380 | executeDDL1(cursor) 381 | with pytest.raises(driver.Error): 382 | cursor.fetchone() 383 | 384 | cursor.execute("select name from %sbooze" % table_prefix) 385 | assert cursor.fetchone() is None, ( 386 | "cursor.fetchone should return None if a query retrieves " "no rows" 387 | ) 388 | assert cursor.rowcount in (-1, 0) 389 | 390 | # cursor.fetchone should raise an Error if called after 391 | # executing a query that cannot return rows 392 | cursor.execute("insert into %sbooze values ('Victoria Bitter')" % 
(table_prefix)) 393 | with pytest.raises(driver.Error): 394 | cursor.fetchone() 395 | 396 | cursor.execute("select name from %sbooze" % table_prefix) 397 | r = cursor.fetchone() 398 | assert len(r) == 1, "cursor.fetchone should have retrieved a single row" 399 | assert r[0] == "Victoria Bitter", "cursor.fetchone retrieved incorrect data" 400 | assert ( 401 | cursor.fetchone() is None 402 | ), "cursor.fetchone should return None if no more rows available" 403 | assert cursor.rowcount in (-1, 1) 404 | 405 | 406 | samples = [ 407 | "Carlton Cold", 408 | "Carlton Draft", 409 | "Mountain Goat", 410 | "Redback", 411 | "Victoria Bitter", 412 | "XXXX", 413 | ] 414 | 415 | 416 | def _populate(): 417 | """Return a list of sql commands to setup the DB for the fetch 418 | tests. 419 | """ 420 | populate = [ 421 | "insert into %sbooze values ('%s')" % (table_prefix, s) for s in samples 422 | ] 423 | return populate 424 | 425 | 426 | def test_fetchmany(cursor): 427 | # cursor.fetchmany should raise an Error if called without 428 | # issuing a query 429 | with pytest.raises(driver.Error): 430 | cursor.fetchmany(4) 431 | 432 | executeDDL1(cursor) 433 | for sql in _populate(): 434 | cursor.execute(sql) 435 | 436 | cursor.execute("select name from %sbooze" % table_prefix) 437 | r = cursor.fetchmany() 438 | assert len(r) == 1, ( 439 | "cursor.fetchmany retrieved incorrect number of rows, " 440 | "default of arraysize is one." 
441 | ) 442 | cursor.arraysize = 10 443 | r = cursor.fetchmany(3) # Should get 3 rows 444 | assert len(r) == 3, "cursor.fetchmany retrieved incorrect number of rows" 445 | r = cursor.fetchmany(4) # Should get 2 more 446 | assert len(r) == 2, "cursor.fetchmany retrieved incorrect number of rows" 447 | r = cursor.fetchmany(4) # Should be an empty sequence 448 | assert len(r) == 0, ( 449 | "cursor.fetchmany should return an empty sequence after " 450 | "results are exhausted" 451 | ) 452 | assert cursor.rowcount in (-1, 6) 453 | 454 | # Same as above, using cursor.arraysize 455 | cursor.arraysize = 4 456 | cursor.execute("select name from %sbooze" % table_prefix) 457 | r = cursor.fetchmany() # Should get 4 rows 458 | assert len(r) == 4, "cursor.arraysize not being honoured by fetchmany" 459 | r = cursor.fetchmany() # Should get 2 more 460 | assert len(r) == 2 461 | r = cursor.fetchmany() # Should be an empty sequence 462 | assert len(r) == 0 463 | assert cursor.rowcount in (-1, 6) 464 | 465 | cursor.arraysize = 6 466 | cursor.execute("select name from %sbooze" % table_prefix) 467 | rows = cursor.fetchmany() # Should get all rows 468 | assert cursor.rowcount in (-1, 6) 469 | assert len(rows) == 6 470 | assert len(rows) == 6 471 | rows = [row[0] for row in rows] 472 | rows.sort() 473 | 474 | # Make sure we get the right data back out 475 | for i in range(0, 6): 476 | assert rows[i] == samples[i], "incorrect data retrieved by cursor.fetchmany" 477 | 478 | rows = cursor.fetchmany() # Should return an empty list 479 | assert len(rows) == 0, ( 480 | "cursor.fetchmany should return an empty sequence if " 481 | "called after the whole result set has been fetched" 482 | ) 483 | assert cursor.rowcount in (-1, 6) 484 | 485 | executeDDL2(cursor) 486 | cursor.execute("select name from %sbarflys" % table_prefix) 487 | r = cursor.fetchmany() # Should get empty sequence 488 | assert len(r) == 0, ( 489 | "cursor.fetchmany should return an empty sequence if " "query retrieved no rows" 
490 | ) 491 | assert cursor.rowcount in (-1, 0) 492 | 493 | 494 | def test_fetchall(cursor): 495 | # cursor.fetchall should raise an Error if called 496 | # without executing a query that may return rows (such 497 | # as a select) 498 | with pytest.raises(driver.Error): 499 | cursor.fetchall() 500 | 501 | executeDDL1(cursor) 502 | for sql in _populate(): 503 | cursor.execute(sql) 504 | 505 | # cursor.fetchall should raise an Error if called 506 | # after executing a a statement that cannot return rows 507 | with pytest.raises(driver.Error): 508 | cursor.fetchall() 509 | 510 | cursor.execute("select name from %sbooze" % table_prefix) 511 | rows = cursor.fetchall() 512 | assert cursor.rowcount in (-1, len(samples)) 513 | assert len(rows) == len(samples), "cursor.fetchall did not retrieve all rows" 514 | rows = [r[0] for r in rows] 515 | rows.sort() 516 | for i in range(0, len(samples)): 517 | assert rows[i] == samples[i], "cursor.fetchall retrieved incorrect rows" 518 | rows = cursor.fetchall() 519 | assert len(rows) == 0, ( 520 | "cursor.fetchall should return an empty list if called " 521 | "after the whole result set has been fetched" 522 | ) 523 | assert cursor.rowcount in (-1, len(samples)) 524 | 525 | executeDDL2(cursor) 526 | cursor.execute("select name from %sbarflys" % table_prefix) 527 | rows = cursor.fetchall() 528 | assert cursor.rowcount in (-1, 0) 529 | assert len(rows) == 0, ( 530 | "cursor.fetchall should return an empty list if " 531 | "a select query returns no rows" 532 | ) 533 | 534 | 535 | def test_mixedfetch(cursor): 536 | executeDDL1(cursor) 537 | for sql in _populate(): 538 | cursor.execute(sql) 539 | 540 | cursor.execute("select name from %sbooze" % table_prefix) 541 | rows1 = cursor.fetchone() 542 | rows23 = cursor.fetchmany(2) 543 | rows4 = cursor.fetchone() 544 | rows56 = cursor.fetchall() 545 | assert cursor.rowcount in (-1, 6) 546 | assert len(rows23) == 2, "fetchmany returned incorrect number of rows" 547 | assert len(rows56) == 2, 
"fetchall returned incorrect number of rows" 548 | 549 | rows = [rows1[0]] 550 | rows.extend([rows23[0][0], rows23[1][0]]) 551 | rows.append(rows4[0]) 552 | rows.extend([rows56[0][0], rows56[1][0]]) 553 | rows.sort() 554 | for i in range(0, len(samples)): 555 | assert rows[i] == samples[i], "incorrect data retrieved or inserted" 556 | 557 | 558 | def help_nextset_setUp(cur): 559 | """Should create a procedure called deleteme 560 | that returns two result sets, first the 561 | number of rows in booze then "name from booze" 562 | """ 563 | raise NotImplementedError("Helper not implemented") 564 | 565 | 566 | def help_nextset_tearDown(cur): 567 | "If cleaning up is needed after nextSetTest" 568 | raise NotImplementedError("Helper not implemented") 569 | 570 | 571 | def test_nextset(cursor): 572 | if not hasattr(cursor, "nextset"): 573 | return 574 | 575 | try: 576 | executeDDL1(cursor) 577 | sql = _populate() 578 | for sql in _populate(): 579 | cursor.execute(sql) 580 | 581 | help_nextset_setUp(cursor) 582 | 583 | cursor.callproc("deleteme") 584 | numberofrows = cursor.fetchone() 585 | assert numberofrows[0] == len(samples) 586 | assert cursor.nextset() 587 | names = cursor.fetchall() 588 | assert len(names) == len(samples) 589 | s = cursor.nextset() 590 | assert s is None, "No more return sets, should return None" 591 | finally: 592 | help_nextset_tearDown(cursor) 593 | 594 | 595 | def test_arraysize(cursor): 596 | # Not much here - rest of the tests for this are in test_fetchmany 597 | assert hasattr(cursor, "arraysize"), "cursor.arraysize must be defined" 598 | 599 | 600 | def test_setinputsizes(cursor): 601 | cursor.setinputsizes(25) 602 | 603 | 604 | def test_setoutputsize_basic(cursor): 605 | # Basic test is to make sure setoutputsize doesn't blow up 606 | cursor.setoutputsize(1000) 607 | cursor.setoutputsize(2000, 0) 608 | _paraminsert(cursor) # Make sure the cursor still works 609 | 610 | 611 | def test_None(cursor): 612 | executeDDL1(cursor) 613 | 
cursor.execute("insert into %sbooze values (NULL)" % table_prefix) 614 | cursor.execute("select name from %sbooze" % table_prefix) 615 | r = cursor.fetchall() 616 | assert len(r) == 1 617 | assert len(r[0]) == 1 618 | assert r[0][0] is None, "NULL value not returned as None" 619 | 620 | 621 | def test_Date(): 622 | driver.Date(2002, 12, 25) 623 | driver.DateFromTicks(time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0))) 624 | # Can we assume this? API doesn't specify, but it seems implied 625 | # self.assertEqual(str(d1),str(d2)) 626 | 627 | 628 | def test_Time(): 629 | driver.Time(13, 45, 30) 630 | driver.TimeFromTicks(time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0))) 631 | # Can we assume this? API doesn't specify, but it seems implied 632 | # self.assertEqual(str(t1),str(t2)) 633 | 634 | 635 | def test_Timestamp(): 636 | driver.Timestamp(2002, 12, 25, 13, 45, 30) 637 | driver.TimestampFromTicks(time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0))) 638 | # Can we assume this? API doesn't specify, but it seems implied 639 | # self.assertEqual(str(t1),str(t2)) 640 | 641 | 642 | def test_Binary(): 643 | driver.Binary(b"Something") 644 | driver.Binary(b"") 645 | 646 | 647 | def test_STRING(): 648 | assert hasattr(driver, "STRING"), "module.STRING must be defined" 649 | 650 | 651 | def test_BINARY(): 652 | assert hasattr(driver, "BINARY"), "module.BINARY must be defined." 653 | 654 | 655 | def test_NUMBER(): 656 | assert hasattr(driver, "NUMBER"), "module.NUMBER must be defined." 657 | 658 | 659 | def test_DATETIME(): 660 | assert hasattr(driver, "DATETIME"), "module.DATETIME must be defined." 661 | 662 | 663 | def test_ROWID(): 664 | assert hasattr(driver, "ROWID"), "module.ROWID must be defined." 
# --------------------------------------------------------------------------
# /test/dbapi/test_paramstyle.py
# --------------------------------------------------------------------------
"""Unit tests for ``pg8000.dbapi.convert_paramstyle``.

Each DB-API paramstyle (qmark, numeric, named, format, pyformat) is rewritten
into PostgreSQL's native ``$n`` placeholders. The cases check that real
placeholders are rewritten, while look-alike text inside quoted identifiers,
string literals, dollar-quoted strings and SQL comments is left untouched.
"""
import pytest

from pg8000.dbapi import convert_paramstyle


@pytest.mark.parametrize(
    "query,statement",
    [
        [
            'SELECT ?, ?, "field_?" FROM t '
            "WHERE a='say ''what?''' AND b=? AND c=E'?\\'test\\'?'",
            'SELECT $1, $2, "field_?" FROM t WHERE '
            "a='say ''what?''' AND b=$3 AND c=E'?\\'test\\'?'",
        ],
        [
            "SELECT ?, ?, * FROM t WHERE a=? AND b='are you ''sure?'",
            "SELECT $1, $2, * FROM t WHERE a=$3 AND b='are you ''sure?'",
        ],
    ],
)
def test_qmark(query, statement):
    """Bare ``?`` markers become ``$1``, ``$2``, ... in order of appearance.

    A ``?`` inside a quoted identifier, a string literal (including doubled
    quotes) or an ``E''`` escape string must not be treated as a marker.
    """
    args = 1, 2, 3
    new_query, vals = convert_paramstyle("qmark", query, args)
    assert (new_query, vals) == (statement, args)


@pytest.mark.parametrize(
    "query,expected",
    [
        [
            "SELECT sum(x)::decimal(5, 2) :2, :1, * FROM t WHERE a=:3",
            "SELECT sum(x)::decimal(5, 2) $2, $1, * FROM t WHERE a=$3",
        ],
    ],
)
def test_numeric(query, expected):
    """``:n`` markers keep their number as ``$n``; ``::`` casts are kept."""
    args = 1, 2, 3
    new_query, vals = convert_paramstyle("numeric", query, args)
    assert (new_query, vals) == (expected, args)


@pytest.mark.parametrize(
    "query",
    [
        "make_interval(days := 10)",
    ],
)
def test_numeric_unchanged(query):
    """``:=`` named-argument notation is not mistaken for a placeholder."""
    args = 1, 2, 3
    new_query, vals = convert_paramstyle("numeric", query, args)
    assert (new_query, vals) == (query, args)


@pytest.mark.parametrize(
    "query,args,expected_query,expected_args",
    [
        [
            "SELECT sum(x)::decimal(5, 2) :f_2, :f1 FROM t WHERE a=:f_2",
            {"f_2": 1, "f1": 2},
            "SELECT sum(x)::decimal(5, 2) $1, $2 FROM t WHERE a=$1",
            (1, 2),
        ],
        ["SELECT $$'$$ = :v", {"v": "'"}, "SELECT $$'$$ = $1", ("'",)],
    ],
)
def test_named(query, args, expected_query, expected_args):
    """Each distinct ``:name`` gets the next ``$n`` by first appearance.

    Repeated names reuse their number, values come back as a tuple in that
    order, and a quote inside a dollar-quoted string does not open a literal.
    """
    new_query, vals = convert_paramstyle("named", query, args)
    assert (new_query, vals) == (expected_query, expected_args)


@pytest.mark.parametrize(
    "query,expected",
    [
        [
            "SELECT %s, %s, \"f1_%%\", E'txt_%%' "
            "FROM t WHERE a=%s AND b='75%%' AND c = '%' -- Comment with %",
            "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$3 AND "
            "b='75%%' AND c = '%' -- Comment with %",
        ],
        [
            "SELECT -- Comment\n%s FROM t",
            "SELECT -- Comment\n$1 FROM t",
        ],
    ],
)
def test_format_changed(query, expected):
    """``%s`` becomes ``$n``; other ``%`` text is copied verbatim.

    ``%%`` and lone ``%`` inside identifiers, literals and ``--`` comments are
    not unescaped or rewritten. A ``--`` comment ends at the newline, so a
    placeholder on the next line is still found.
    """
    args = 1, 2, 3
    new_query, vals = convert_paramstyle("format", query, args)
    assert (new_query, vals) == (expected, args)


@pytest.mark.parametrize(
    "query",
    [
        r"""COMMENT ON TABLE test_schema.comment_test """
        r"""IS 'the test % '' " \ table comment'""",
    ],
)
def test_format_unchanged(query):
    """A query with no ``%s`` marker is returned exactly as given."""
    args = 1, 2, 3
    new_query, vals = convert_paramstyle("format", query, args)
    assert (new_query, vals) == (query, args)


def test_py_format():
    """``%(name)s`` markers are numbered by first appearance.

    Keys that the query never mentions (``f3`` here) are left out of the
    returned values.
    """
    args = {"f2": 1, "f1": 2, "f3": 3}

    new_query, vals = convert_paramstyle(
        "pyformat",
        "SELECT %(f2)s, %(f1)s, \"f1_%%\", E'txt_%%' "
        "FROM t WHERE a=%(f2)s AND b='75%%'",
        args,
    )
    expected = "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$1 AND b='75%%'"
    assert (new_query, vals) == (expected, (1, 2))


def test_pyformat_format():
    """pyformat also accepts positional ``%s`` markers with a sequence."""
    args = 1, 2, 3
    new_query, vals = convert_paramstyle(
        "pyformat",
        "SELECT %s, %s, \"f1_%%\", E'txt_%%' FROM t WHERE a=%s AND b='75%%'",
        args,
    )
    expected = "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$3 AND b='75%%'"
    assert (new_query, vals) == (expected, args)
# --------------------------------------------------------------------------
# /test/dbapi/test_query.py
# --------------------------------------------------------------------------
"""Tests relating to the basic operation of the database driver, driven by
the pg8000 DB-API interface."""
from datetime import datetime as Datetime, timezone as Timezone

import pytest

import pg8000.dbapi
from pg8000.converters import INET_ARRAY, INTEGER


# A small, known population for t1, shared by tests that insert it row by row.
_SAMPLE_ROWS = (
    (1, 1, None),
    (2, 10, None),
    (3, 100, None),
    (4, 1000, None),
    (5, 10000, None),
)


def _insert_sample_rows(cursor):
    """Insert ``_SAMPLE_ROWS`` into t1, one ``execute`` call per row."""
    for row in _SAMPLE_ROWS:
        cursor.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", row)


@pytest.fixture
def db_table(request, con):
    """Create the temporary table t1 and return the connection.

    The connection's paramstyle is set to ``format``. At teardown t1 is
    dropped; a ``DatabaseError`` is ignored because some tests drop or
    replace t1 themselves.
    """
    con.paramstyle = "format"
    cursor = con.cursor()
    cursor.execute(
        "CREATE TEMPORARY TABLE t1 (f1 int primary key, "
        "f2 bigint not null, f3 varchar(50) null) "
    )

    def fin():
        try:
            cursor = con.cursor()
            cursor.execute("drop table t1")
        except pg8000.dbapi.DatabaseError:
            pass

    request.addfinalizer(fin)
    return con


def test_database_error(cursor):
    """A statement against a missing table raises ``DatabaseError``."""
    with pytest.raises(pg8000.dbapi.DatabaseError):
        cursor.execute("INSERT INTO t99 VALUES (1, 2, 3)")


def test_parallel_queries(db_table):
    """A second cursor can run queries driven by the first cursor's rows."""
    _insert_sample_rows(db_table.cursor())
    c1 = db_table.cursor()
    c2 = db_table.cursor()
    c1.execute("SELECT f1, f2, f3 FROM t1")
    for f1, _f2, _f3 in c1.fetchall():
        c2.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > %s", (f1,))
        inner_rows = c2.fetchall()
        assert all(len(row) == 3 for row in inner_rows)
        assert len(inner_rows) == sum(1 for r in _SAMPLE_ROWS if r[0] > f1)


def test_parallel_open_portals(con):
    """Two cursors can each hold an executed result set at the same time."""
    c1 = con.cursor()
    c2 = con.cursor()
    q = "select * from generate_series(1, %s)"
    params = (100,)
    c1.execute(q, params)
    c2.execute(q, params)
    # Read c2 before c1, so c1's results must still be intact afterwards.
    c2count = len(c2.fetchall())
    c1count = len(c1.fetchall())

    # Check the expected size too, so two empty results can't pass.
    assert c1count == c2count == 100


def test_alter(db_table):
    """Run a query, alter the table's structure, then run the same query."""
    cursor = db_table.cursor()
    cursor.execute("select * from t1")
    cursor.execute("alter table t1 drop column f3")
    cursor.execute("select * from t1")


def test_create(db_table):
    """Run a query, drop and re-create the table, then run the query again."""
    cursor = db_table.cursor()
    cursor.execute("select * from t1")
    cursor.execute("drop table t1")
    cursor.execute("create temporary table t1 (f1 int primary key)")
    cursor.execute("select * from t1")


def test_insert_returning(db_table):
    """INSERT ... RETURNING yields rows and sets rowcount for 1 and 3 rows."""
    cursor = db_table.cursor()
    cursor.execute("CREATE TEMPORARY TABLE t2 (id serial, data text)")

    # Test INSERT ... RETURNING with one row...
    cursor.execute("INSERT INTO t2 (data) VALUES (%s) RETURNING id", ("test1",))
    row_id = cursor.fetchone()[0]
    cursor.execute("SELECT data FROM t2 WHERE id = %s", (row_id,))
    assert "test1" == cursor.fetchone()[0]

    assert cursor.rowcount == 1

    # Test with multiple rows...
    cursor.execute(
        "INSERT INTO t2 (data) VALUES (%s), (%s), (%s) RETURNING id",
        ("test2", "test3", "test4"),
    )
    assert cursor.rowcount == 3
    ids = tuple(x[0] for x in cursor.fetchall())
    assert len(ids) == 3


def test_row_count(db_table):
    """rowcount is right after executemany and SELECT, however far read."""
    cursor = db_table.cursor()
    expected_count = 57
    cursor.executemany(
        "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)",
        tuple((i, i, None) for i in range(expected_count)),
    )

    # Check rowcount after executemany
    assert expected_count == cursor.rowcount

    cursor.execute("SELECT * FROM t1")

    # Check row_count without doing any reading first...
    assert expected_count == cursor.rowcount

    # Check rowcount after reading some rows, make sure it still
    # works...
    for _ in range(expected_count // 2):
        cursor.fetchone()
    assert expected_count == cursor.rowcount

    cursor = db_table.cursor()
    # Restart the cursor, read a few rows, and then check rowcount
    # again...
    cursor.execute("SELECT * FROM t1")
    for _ in range(expected_count // 3):
        cursor.fetchone()
    assert expected_count == cursor.rowcount

    # Should be -1 for a command with no results
    cursor.execute("DROP TABLE t1")
    assert -1 == cursor.rowcount


def test_row_count_update(db_table):
    """rowcount after an UPDATE is the number of rows it changed."""
    cursor = db_table.cursor()
    _insert_sample_rows(cursor)
    # Only the rows with f2 = 1000 and f2 = 10000 match.
    cursor.execute("UPDATE t1 SET f3 = %s WHERE f2 > 101", ("Hello!",))
    assert cursor.rowcount == 2


def test_int_oid(cursor):
    # https://bugs.launchpad.net/pg8000/+bug/230796
    cursor.execute("SELECT typname FROM pg_type WHERE oid = %s", (100,))


def test_unicode_query(cursor):
    """Non-ASCII table and column names are accepted."""
    cursor.execute(
        "CREATE TEMPORARY TABLE \u043c\u0435\u0441\u0442\u043e "
        "(\u0438\u043c\u044f VARCHAR(50), "
        "\u0430\u0434\u0440\u0435\u0441 VARCHAR(250))"
    )


def test_executemany(db_table):
    """executemany works for inserts and for aware and naive datetimes."""
    cursor = db_table.cursor()
    cursor.executemany(
        "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)",
        ((1, 1, "Avast ye!"), (2, 1, None)),
    )

    cursor.executemany(
        "select CAST(%s AS TIMESTAMP)",
        ((Datetime(2014, 5, 7, tzinfo=Timezone.utc),), (Datetime(2014, 5, 7),)),
    )


def test_executemany_setinputsizes(cursor):
    """Make sure that setinputsizes works for all the parameter sets"""

    cursor.execute(
        "CREATE TEMPORARY TABLE t1 (f1 int primary key, f2 inet[] not null) "
    )

    cursor.setinputsizes(INTEGER, INET_ARRAY)
    cursor.executemany(
        "INSERT INTO t1 (f1, f2) VALUES (%s, %s)", ((1, ["1.1.1.1"]), (2, ["0.0.0.0"]))
    )


def test_executemany_no_param_sets(cursor):
    """executemany with no parameter sets leaves rowcount at -1."""
    # t1 is not created here: with no parameter sets the test expects
    # nothing to reach the server, so the missing table must not matter.
    cursor.executemany("INSERT INTO t1 (f1, f2) VALUES (%s, %s)", [])
    assert cursor.rowcount == -1


# Check that autocommit stays off
# We keep track of whether we're in a transaction or not by using the
# READY_FOR_QUERY message.
def test_transactions(db_table):
    cursor = db_table.cursor()
    cursor.execute("commit")
    cursor.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (1, 1, "Zombie"))
    cursor.execute("rollback")
    cursor.execute("select * from t1")

    assert cursor.rowcount == 0


def test_in(cursor):
    """A list parameter works with ``= any(...)``."""
    cursor.execute("SELECT typname FROM pg_type WHERE oid = any(%s)", ([16, 23],))
    ret = cursor.fetchall()
    assert ret[0][0] == "bool"


def test_no_previous_tpc(con):
    """tpc_begin/tpc_commit work with no earlier two-phase transaction."""
    con.tpc_begin("Stacey")
    cursor = con.cursor()
    cursor.execute("SELECT * FROM pg_type")
    con.tpc_commit()


# Check that tpc_recover() doesn't start a transaction
def test_tpc_recover(con):
    con.tpc_recover()
    cursor = con.cursor()
    con.autocommit = True

    # If tpc_recover() has started a transaction, this will fail
    cursor.execute("VACUUM")


def test_tpc_prepare(con):
    """A prepared two-phase transaction can be rolled back by its xid."""
    xid = "Stacey"
    con.tpc_begin(xid)
    con.tpc_prepare()
    con.tpc_rollback(xid)


def test_empty_query(cursor):
    """No exception thrown"""
    cursor.execute("")


# rolling back when not in a transaction doesn't generate a warning
def test_rollback_no_transaction(con):
    # Remove any existing notices
    con.notices.clear()

    # First, verify that a raw rollback does produce a notice
    con.execute_unnamed("rollback")

    assert 1 == len(con.notices)

    # 25P01 is the code for no_active_sql_transaction. It has
    # a message and severity name, but those might be
    # localized/depend on the server version.
    assert con.notices.pop().get(b"C") == b"25P01"

    # Now going through the rollback method doesn't produce
    # any notices because it knows we're not in a transaction.
    con.rollback()

    assert 0 == len(con.notices)


@pytest.mark.parametrize("sizes,oids", [([0], [0]), ([float], [701])])
def test_setinputsizes(con, sizes, oids):
    """setinputsizes maps sizes/types to parameter OIDs; NULL round-trips."""
    cursor = con.cursor()
    cursor.setinputsizes(*sizes)
    assert cursor._input_oids == oids
    cursor.execute("select %s", (None,))
    retval = cursor.fetchall()
    assert retval[0][0] is None


def test_unexecuted_cursor_rowcount(con):
    """A cursor that has executed nothing reports rowcount -1."""
    cursor = con.cursor()
    assert cursor.rowcount == -1


def test_unexecuted_cursor_description(con):
    """A cursor that has executed nothing has no description."""
    cursor = con.cursor()
    assert cursor.description is None


def test_callproc(pg_version, cursor):
    """callproc() runs a procedure and returns its INOUT value as a row."""
    # CREATE PROCEDURE only exists from PostgreSQL 11 onwards. Report a skip
    # rather than silently passing on older servers.
    if pg_version <= 10:
        pytest.skip("CREATE PROCEDURE requires PostgreSQL 11 or later")

    cursor.execute(
        """
CREATE PROCEDURE echo(INOUT val text)
  LANGUAGE plpgsql AS
$proc$
BEGIN
END
$proc$;
"""
    )

    cursor.callproc("echo", ["hello"])
    assert cursor.fetchall() == (["hello"],)


def test_null_result(db_table):
    """fetchall() after a statement that returns no rows raises an error."""
    cur = db_table.cursor()
    cur.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (1, 1, "a"))
    with pytest.raises(pg8000.dbapi.ProgrammingError):
        cur.fetchall()


def test_not_parsed_if_no_params(mocker, cursor):
    """execute() without parameters skips the paramstyle converter."""
    mock_convert_paramstyle = mocker.patch("pg8000.dbapi.convert_paramstyle")
    cursor.execute("ROLLBACK")
    mock_convert_paramstyle.assert_not_called()


# --------------------------------------------------------------------------
# /test/legacy/__init__.py
# --------------------------------------------------------------------------
# https://raw.githubusercontent.com/tlocke/pg8000/46c00021ade1d19466b07ed30392386c5f0a6b8e/test/legacy/__init__.py
# --------------------------------------------------------------------------
# /test/legacy/auth/__init__.py
# --------------------------------------------------------------------------
# https://raw.githubusercontent.com/tlocke/pg8000/46c00021ade1d19466b07ed30392386c5f0a6b8e/test/legacy/auth/__init__.py
# --------------------------------------------------------------------------
# /test/legacy/auth/test_gss.py
# --------------------------------------------------------------------------
import pytest

from pg8000 import InterfaceError, connect


def test_gss(db_kwargs):
    """Connecting with GSS authentication fails with a clear InterfaceError.

    pg8000 reports GSS (authentication method 7) as unsupported. This
    requires a line in pg_hba.conf that requires gss for the database
    pg8000_gss.
    """
    db_kwargs["database"] = "pg8000_gss"

    # Should raise an exception saying gss isn't supported
    with pytest.raises(
        InterfaceError,
        match="Authentication method 7 not supported by pg8000.",
    ):
        connect(**db_kwargs)


# --------------------------------------------------------------------------
# /test/legacy/auth/test_md5.py
# --------------------------------------------------------------------------
def test_md5(con):
    """Called by GitHub Actions with auth method md5.

    Obtaining the ``con`` fixture is the whole test: it only has to connect.
    """
    pass


# --------------------------------------------------------------------------
# /test/legacy/auth/test_md5_ssl.py
# --------------------------------------------------------------------------
import ssl

from pg8000 import connect


def test_ssl(db_kwargs):
    """A connection can be opened over SSL."""
    context = ssl.create_default_context()
    # Hostname and certificate verification are deliberately switched off,
    # so the test doesn't depend on the server's certificate being trusted.
    # Never do this outside tests.
    context.check_hostname = False
    context.verify_mode = ssl.CERT_NONE

    db_kwargs["ssl_context"] = context

    with connect(**db_kwargs):
        pass


# --------------------------------------------------------------------------
# /test/legacy/auth/test_password.py
# --------------------------------------------------------------------------
def test_password(con):
    """Called by GitHub Actions with auth method password.

    Obtaining the ``con`` fixture is the whole test: it only has to connect.
    """
    pass


# --------------------------------------------------------------------------
# /test/legacy/auth/test_scram-sha-256.py
# --------------------------------------------------------------------------
def test_scram_sha_256(con):
    """Called by GitHub Actions with auth method scram-sha-256.

    Obtaining the ``con`` fixture is the whole test: it only has to connect.
    """
    pass


# --------------------------------------------------------------------------
# /test/legacy/auth/test_scram-sha-256_ssl.py
# --------------------------------------------------------------------------
import pytest

from pg8000 import DatabaseError, connect

# This requires a line in pg_hba.conf that requires scram-sha-256 for the
# database pg8000_scram_sha_256

DB = "pg8000_scram_sha_256"


@pytest.fixture
def setup(con):
    """Make sure the database ``DB`` exists before the test connects to it."""
    # CREATE DATABASE cannot run inside a transaction block, so switch
    # autocommit on first, as the equivalent fixture in
    # test/legacy/test_connection.py does. Without this the statement can be
    # rejected and the error swallowed below, leaving DB uncreated.
    con.autocommit = True
    try:
        con.run(f"CREATE DATABASE {DB}")
    except DatabaseError:
        # Most likely the database already exists; carry on and let the test
        # itself report any real connection problem.
        pass


def test_scram_sha_256_plus(setup, db_kwargs):
    """A SCRAM-SHA-256 connection to ``DB`` can be opened."""
    db_kwargs["database"] = DB

    with connect(**db_kwargs):
        pass


# --------------------------------------------------------------------------
# /test/legacy/conftest.py
# --------------------------------------------------------------------------
from os import environ

import pytest

import pg8000


@pytest.fixture(scope="class")
def db_kwargs():
    """Return keyword arguments for ``pg8000.connect``.

    Defaults to user ``postgres`` with password ``pw``. The ``PGHOST``,
    ``PGPASSWORD`` and ``PGPORT`` environment variables, when set, supply
    the host, password and (as an int) port.
    """
    db_connect = {"user": "postgres", "password": "pw"}

    for kw, var, f in [
        ("host", "PGHOST", str),
        ("password", "PGPASSWORD", str),
        ("port", "PGPORT", int),
    ]:
        try:
            db_connect[kw] = f(environ[var])
        except KeyError:
            pass

    return db_connect


@pytest.fixture
def con(request, db_kwargs):
    """Open a legacy pg8000 connection, rolled back and closed at teardown."""
    conn = pg8000.connect(**db_kwargs)

    def fin():
        # Some tests break the connection on purpose (for example by
        # terminating their own backend), so cleanup is best-effort.
        try:
            conn.rollback()
        except pg8000.InterfaceError:
            pass

        try:
            conn.close()
        except pg8000.InterfaceError:
            pass

    request.addfinalizer(fin)
    return conn


@pytest.fixture
def cursor(request, con):
    """Return a cursor on ``con`` that is closed at teardown."""
    cursor = con.cursor()

    def fin():
        cursor.close()

    request.addfinalizer(fin)
    return cursor


# --------------------------------------------------------------------------
# /test/legacy/stress.py
# --------------------------------------------------------------------------
"""Ad-hoc stress script: run a catalogue query repeatedly on one connection.

With pytest's default file patterns this file is not collected (its name
does not start with ``test_``); run it directly instead. Connection settings
come from the same environment variables and defaults as
test/legacy/conftest.py (``PGHOST``, ``PGPASSWORD``, ``PGPORT``).
"""
from contextlib import closing
from os import environ

import pg8000


def _db_connect():
    """Build ``pg8000.connect`` keyword arguments from the environment."""
    db_connect = {"user": "postgres", "password": "pw"}
    for kw, var, convert in (
        ("host", "PGHOST", str),
        ("password", "PGPASSWORD", str),
        ("port", "PGPORT", int),
    ):
        if var in environ:
            db_connect[kw] = convert(environ[var])
    return db_connect


# NOTE(review): the join on ``kj`` has a condition that doesn't reference
# ``kj``, so every row is repeated once per namespace. This may be deliberate
# extra load for a stress run or a copy-paste slip; confirm before changing.
_QUERY = """
SELECT n.nspname as "Schema",
  pg_catalog.format_type(t.oid, NULL) AS "Name",
  pg_catalog.obj_description(t.oid, 'pg_type') as "Description"
FROM pg_catalog.pg_type t
     LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
     left join pg_catalog.pg_namespace kj on n.oid = t.typnamespace
WHERE (t.typrelid = 0
       OR (SELECT c.relkind = 'c'
           FROM pg_catalog.pg_class c WHERE c.oid = t.typrelid))
  AND NOT EXISTS(
      SELECT 1 FROM pg_catalog.pg_type el
      WHERE el.oid = t.typelem AND el.typarray = t.oid)
  AND pg_catalog.pg_type_is_visible(t.oid)
ORDER BY 1, 2;"""


def main(iterations=100):
    """Execute ``_QUERY`` *iterations* times, each on a fresh cursor."""
    with closing(pg8000.connect(**_db_connect())) as db:
        for _ in range(iterations):
            with db.cursor() as cursor:
                cursor.execute(_QUERY)


if __name__ == "__main__":
    main()


# --------------------------------------------------------------------------
# /test/legacy/test_benchmarks.py
# --------------------------------------------------------------------------
import pytest


@pytest.mark.parametrize(
    "txt",
    (
        # Each entry must be a single SQL expression string. A tuple here
        # would be formatted into the query as a row constructor of text
        # literals rather than the intended expression.
        "cast(id / 100 as int2)",
        "cast(id as int4)",
        "cast(id * 100 as int8)",
        "(id % 2) = 0",
        "N'Static text string'",
        "cast(id / 100 as float4)",
        "cast(id / 100 as float8)",
        "cast(id / 100 as numeric)",
        "timestamp '2001-09-28'",
    ),
)
def test_round_trips(con, benchmark, txt):
    """Benchmark fetching 10,000 rows holding seven copies of *txt*."""
    # Build the SQL once, so the benchmark times only the database round trip.
    query = """SELECT {0} AS column1, {0} AS column2, {0} AS column3,
        {0} AS column4, {0} AS column5, {0} AS column6, {0} AS column7
        FROM (SELECT generate_series(1, 10000) AS id) AS tbl""".format(
        txt
    )
    benchmark(con.run, query)


# --------------------------------------------------------------------------
# /test/legacy/test_connection.py
-------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pg8000 import DatabaseError, InterfaceError, ProgrammingError, connect 4 | 5 | 6 | def testUnixSocketMissing(): 7 | conn_params = {"unix_sock": "/file-does-not-exist", "user": "doesn't-matter"} 8 | 9 | with pytest.raises(InterfaceError): 10 | connect(**conn_params) 11 | 12 | 13 | def test_internet_socket_connection_refused(): 14 | conn_params = {"port": 0, "user": "doesn't-matter"} 15 | 16 | with pytest.raises( 17 | InterfaceError, 18 | match="Can't create a connection to host localhost and port 0 " 19 | "\\(timeout is None and source_address is None\\).", 20 | ): 21 | connect(**conn_params) 22 | 23 | 24 | def testDatabaseMissing(db_kwargs): 25 | db_kwargs["database"] = "missing-db" 26 | with pytest.raises(ProgrammingError): 27 | connect(**db_kwargs) 28 | 29 | 30 | def test_notify(con): 31 | backend_pid = con.run("select pg_backend_pid()")[0][0] 32 | assert list(con.notifications) == [] 33 | con.run("LISTEN test") 34 | con.run("NOTIFY test") 35 | con.commit() 36 | 37 | con.run("VALUES (1, 2), (3, 4), (5, 6)") 38 | assert len(con.notifications) == 1 39 | assert con.notifications[0] == (backend_pid, "test", "") 40 | 41 | 42 | def test_notify_with_payload(con): 43 | backend_pid = con.run("select pg_backend_pid()")[0][0] 44 | assert list(con.notifications) == [] 45 | con.run("LISTEN test") 46 | con.run("NOTIFY test, 'Parnham'") 47 | con.commit() 48 | 49 | con.run("VALUES (1, 2), (3, 4), (5, 6)") 50 | assert len(con.notifications) == 1 51 | assert con.notifications[0] == (backend_pid, "test", "Parnham") 52 | 53 | 54 | # This requires a line in pg_hba.conf that requires md5 for the database 55 | # pg8000_md5 56 | 57 | 58 | def testMd5(db_kwargs): 59 | db_kwargs["database"] = "pg8000_md5" 60 | 61 | # Should only raise an exception saying db doesn't exist 62 | with pytest.raises(ProgrammingError, match="3D000"): 63 | connect(**db_kwargs) 64 | 65 | 66 | # This 
requires a line in pg_hba.conf that requires 'password' for the 67 | # database pg8000_password 68 | 69 | 70 | def testPassword(db_kwargs): 71 | db_kwargs["database"] = "pg8000_password" 72 | 73 | # Should only raise an exception saying db doesn't exist 74 | with pytest.raises(ProgrammingError, match="3D000"): 75 | connect(**db_kwargs) 76 | 77 | 78 | def testUnicodeDatabaseName(db_kwargs): 79 | db_kwargs["database"] = "pg8000_sn\uFF6Fw" 80 | 81 | # Should only raise an exception saying db doesn't exist 82 | with pytest.raises(ProgrammingError, match="3D000"): 83 | connect(**db_kwargs) 84 | 85 | 86 | def testBytesDatabaseName(db_kwargs): 87 | """Should only raise an exception saying db doesn't exist""" 88 | 89 | db_kwargs["database"] = bytes("pg8000_sn\uFF6Fw", "utf8") 90 | with pytest.raises(ProgrammingError, match="3D000"): 91 | connect(**db_kwargs) 92 | 93 | 94 | def testBytesPassword(con, db_kwargs): 95 | # Create user 96 | username = "boltzmann" 97 | password = "cha\uFF6Fs" 98 | with con.cursor() as cur: 99 | cur.execute("create user " + username + " with password '" + password + "';") 100 | con.commit() 101 | 102 | db_kwargs["user"] = username 103 | db_kwargs["password"] = password.encode("utf8") 104 | db_kwargs["database"] = "pg8000_md5" 105 | with pytest.raises(ProgrammingError, match="3D000"): 106 | connect(**db_kwargs) 107 | 108 | cur.execute("drop role " + username) 109 | con.commit() 110 | 111 | 112 | def test_broken_pipe_read(con, db_kwargs): 113 | db1 = connect(**db_kwargs) 114 | cur1 = db1.cursor() 115 | cur2 = con.cursor() 116 | cur1.execute("select pg_backend_pid()") 117 | pid1 = cur1.fetchone()[0] 118 | 119 | cur2.execute("select pg_terminate_backend(%s)", (pid1,)) 120 | with pytest.raises(InterfaceError, match="network error"): 121 | cur1.execute("select 1") 122 | 123 | try: 124 | db1.close() 125 | except InterfaceError: 126 | pass 127 | 128 | 129 | def test_broken_pipe_flush(con, db_kwargs): 130 | db1 = connect(**db_kwargs) 131 | cur1 = 
db1.cursor() 132 | cur2 = con.cursor() 133 | cur1.execute("select pg_backend_pid()") 134 | pid1 = cur1.fetchone()[0] 135 | 136 | cur2.execute("select pg_terminate_backend(%s)", (pid1,)) 137 | try: 138 | cur1.execute("select 1") 139 | except BaseException: 140 | pass 141 | 142 | # Can do an assert_raises when we're on 3.8 or above 143 | try: 144 | db1.close() 145 | except InterfaceError as e: 146 | assert str(e) == "network error" 147 | 148 | 149 | def test_broken_pipe_unpack(con): 150 | cur = con.cursor() 151 | cur.execute("select pg_backend_pid()") 152 | pid1 = cur.fetchone()[0] 153 | 154 | with pytest.raises(InterfaceError, match="network error"): 155 | cur.execute("select pg_terminate_backend(%s)", (pid1,)) 156 | 157 | 158 | def testApplicatioName(db_kwargs): 159 | app_name = "my test application name" 160 | db_kwargs["application_name"] = app_name 161 | with connect(**db_kwargs) as db: 162 | cur = db.cursor() 163 | cur.execute( 164 | "select application_name from pg_stat_activity " 165 | " where pid = pg_backend_pid()" 166 | ) 167 | 168 | application_name = cur.fetchone()[0] 169 | assert application_name == app_name 170 | 171 | 172 | def test_application_name_integer(db_kwargs): 173 | db_kwargs["application_name"] = 1 174 | with pytest.raises( 175 | InterfaceError, 176 | match="The parameter application_name can't be of type .", 177 | ): 178 | connect(**db_kwargs) 179 | 180 | 181 | def test_application_name_bytearray(db_kwargs): 182 | db_kwargs["application_name"] = bytearray(b"Philby") 183 | with connect(**db_kwargs): 184 | pass 185 | 186 | 187 | # This requires a line in pg_hba.conf that requires scram-sha-256 for the 188 | # database pg8000_scram_sha_256 189 | 190 | DB = "pg8000_scram_sha_256" 191 | 192 | 193 | @pytest.fixture 194 | def setup(con, cursor): 195 | con.autocommit = True 196 | try: 197 | cursor.execute(f"CREATE DATABASE {DB}") 198 | except DatabaseError: 199 | con.rollback() 200 | 201 | 202 | def test_scram_sha_256(setup, db_kwargs): 203 | 
db_kwargs["database"] = DB 204 | 205 | con = connect(**db_kwargs) 206 | con.close() 207 | 208 | 209 | @pytest.mark.parametrize( 210 | "commit", 211 | [ 212 | "commit", 213 | "COMMIT;", 214 | ], 215 | ) 216 | def test_failed_transaction_commit_sql(cursor, commit): 217 | cursor.execute("create temporary table tt (f1 int primary key)") 218 | cursor.execute("begin") 219 | try: 220 | cursor.execute("insert into tt(f1) values(null)") 221 | except DatabaseError: 222 | pass 223 | 224 | with pytest.raises(InterfaceError): 225 | cursor.execute(commit) 226 | 227 | 228 | def test_failed_transaction_commit_method(con, cursor): 229 | cursor.execute("create temporary table tt (f1 int primary key)") 230 | cursor.execute("begin") 231 | try: 232 | cursor.execute("insert into tt(f1) values(null)") 233 | except DatabaseError: 234 | pass 235 | 236 | with pytest.raises(InterfaceError): 237 | con.commit() 238 | 239 | 240 | @pytest.mark.parametrize( 241 | "rollback", 242 | [ 243 | "rollback", 244 | "rollback;", 245 | "ROLLBACK ;", 246 | ], 247 | ) 248 | def test_failed_transaction_rollback_sql(cursor, rollback): 249 | cursor.execute("create temporary table tt (f1 int primary key)") 250 | cursor.execute("begin") 251 | try: 252 | cursor.execute("insert into tt(f1) values(null)") 253 | except DatabaseError: 254 | pass 255 | 256 | cursor.execute(rollback) 257 | 258 | 259 | def test_failed_transaction_rollback_method(cursor, con): 260 | cursor.execute("create temporary table tt (f1 int primary key)") 261 | cursor.execute("begin") 262 | try: 263 | cursor.execute("insert into tt(f1) values(null)") 264 | except DatabaseError: 265 | pass 266 | 267 | con.rollback() 268 | -------------------------------------------------------------------------------- /test/legacy/test_copy.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import pytest 4 | 5 | 6 | @pytest.fixture 7 | def db_table(request, con): 8 | with con.cursor() as cursor: 9 
| cursor.execute( 10 | "CREATE TEMPORARY TABLE t1 (f1 int primary key, " 11 | "f2 int not null, f3 varchar(50) null) " 12 | "on commit drop" 13 | ) 14 | return con 15 | 16 | 17 | def test_copy_to_with_table(db_table): 18 | with db_table.cursor() as cursor: 19 | cursor.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (1, 1, 1)) 20 | cursor.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (2, 2, 2)) 21 | cursor.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (3, 3, 3)) 22 | 23 | stream = BytesIO() 24 | cursor.execute("copy t1 to stdout", stream=stream) 25 | assert stream.getvalue() == b"1\t1\t1\n2\t2\t2\n3\t3\t3\n" 26 | assert cursor.rowcount == 3 27 | 28 | 29 | def test_copy_to_with_query(db_table): 30 | with db_table.cursor() as cursor: 31 | stream = BytesIO() 32 | cursor.execute( 33 | "COPY (SELECT 1 as One, 2 as Two) TO STDOUT WITH DELIMITER " 34 | "'X' CSV HEADER QUOTE AS 'Y' FORCE QUOTE Two", 35 | stream=stream, 36 | ) 37 | assert stream.getvalue() == b"oneXtwo\n1XY2Y\n" 38 | assert cursor.rowcount == 1 39 | 40 | 41 | def test_copy_from_with_table(db_table): 42 | with db_table.cursor() as cursor: 43 | stream = BytesIO(b"1\t1\t1\n2\t2\t2\n3\t3\t3\n") 44 | cursor.execute("copy t1 from STDIN", stream=stream) 45 | assert cursor.rowcount == 3 46 | 47 | cursor.execute("SELECT * FROM t1 ORDER BY f1") 48 | retval = cursor.fetchall() 49 | assert retval == ([1, 1, "1"], [2, 2, "2"], [3, 3, "3"]) 50 | 51 | 52 | def test_copy_from_with_query(db_table): 53 | with db_table.cursor() as cursor: 54 | stream = BytesIO(b"f1Xf2\n1XY1Y\n") 55 | cursor.execute( 56 | "COPY t1 (f1, f2) FROM STDIN WITH DELIMITER 'X' CSV HEADER " 57 | "QUOTE AS 'Y' FORCE NOT NULL f1", 58 | stream=stream, 59 | ) 60 | assert cursor.rowcount == 1 61 | 62 | cursor.execute("SELECT * FROM t1 ORDER BY f1") 63 | retval = cursor.fetchall() 64 | assert retval == ([1, 1, None],) 65 | 66 | 67 | def test_copy_from_with_error(db_table): 68 | with db_table.cursor() as cursor: 69 | stream 
= BytesIO(b"f1Xf2\n\n1XY1Y\n") 70 | with pytest.raises(BaseException) as e: 71 | cursor.execute( 72 | "COPY t1 (f1, f2) FROM STDIN WITH DELIMITER 'X' CSV HEADER " 73 | "QUOTE AS 'Y' FORCE NOT NULL f1", 74 | stream=stream, 75 | ) 76 | 77 | arg = { 78 | "S": ("ERROR",), 79 | "C": ("22P02",), 80 | "M": ( 81 | 'invalid input syntax for type integer: ""', 82 | 'invalid input syntax for integer: ""', 83 | ), 84 | "W": ('COPY t1, line 2, column f1: ""',), 85 | "F": ("numutils.c",), 86 | "R": ("pg_atoi", "pg_strtoint32", "pg_strtoint32_safe"), 87 | } 88 | earg = e.value.args[0] 89 | for k, v in arg.items(): 90 | assert earg[k] in v 91 | -------------------------------------------------------------------------------- /test/legacy/test_dbapi.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import time 4 | 5 | import pytest 6 | 7 | import pg8000 8 | 9 | 10 | @pytest.fixture 11 | def has_tzset(): 12 | # Neither Windows nor Jython 2.5.3 have a time.tzset() so skip 13 | if hasattr(time, "tzset"): 14 | os.environ["TZ"] = "UTC" 15 | time.tzset() 16 | return True 17 | return False 18 | 19 | 20 | # DBAPI compatible interface tests 21 | @pytest.fixture 22 | def db_table(con, has_tzset): 23 | with con.cursor() as c: 24 | c.execute( 25 | "CREATE TEMPORARY TABLE t1 " 26 | "(f1 int primary key, f2 int not null, f3 varchar(50) null) " 27 | "ON COMMIT DROP" 28 | ) 29 | c.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (1, 1, None)) 30 | c.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (2, 10, None)) 31 | c.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (3, 100, None)) 32 | c.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (4, 1000, None)) 33 | c.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (5, 10000, None)) 34 | return con 35 | 36 | 37 | def test_parallel_queries(db_table): 38 | with db_table.cursor() as c1, db_table.cursor() as c2: 39 | c1.execute("SELECT 
f1, f2, f3 FROM t1") 40 | while 1: 41 | row = c1.fetchone() 42 | if row is None: 43 | break 44 | f1, f2, f3 = row 45 | c2.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > %s", (f1,)) 46 | while 1: 47 | row = c2.fetchone() 48 | if row is None: 49 | break 50 | f1, f2, f3 = row 51 | 52 | 53 | def test_qmark(db_table): 54 | orig_paramstyle = pg8000.paramstyle 55 | try: 56 | pg8000.paramstyle = "qmark" 57 | with db_table.cursor() as c1: 58 | c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > ?", (3,)) 59 | while 1: 60 | row = c1.fetchone() 61 | if row is None: 62 | break 63 | f1, f2, f3 = row 64 | finally: 65 | pg8000.paramstyle = orig_paramstyle 66 | 67 | 68 | def test_numeric(db_table): 69 | orig_paramstyle = pg8000.paramstyle 70 | try: 71 | pg8000.paramstyle = "numeric" 72 | with db_table.cursor() as c1: 73 | c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > :1", (3,)) 74 | while 1: 75 | row = c1.fetchone() 76 | if row is None: 77 | break 78 | f1, f2, f3 = row 79 | finally: 80 | pg8000.paramstyle = orig_paramstyle 81 | 82 | 83 | def test_named(db_table): 84 | orig_paramstyle = pg8000.paramstyle 85 | try: 86 | pg8000.paramstyle = "named" 87 | with db_table.cursor() as c1: 88 | c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > :f1", {"f1": 3}) 89 | while 1: 90 | row = c1.fetchone() 91 | if row is None: 92 | break 93 | f1, f2, f3 = row 94 | finally: 95 | pg8000.paramstyle = orig_paramstyle 96 | 97 | 98 | def test_format(db_table): 99 | orig_paramstyle = pg8000.paramstyle 100 | try: 101 | pg8000.paramstyle = "format" 102 | with db_table.cursor() as c1: 103 | c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > %s", (3,)) 104 | while 1: 105 | row = c1.fetchone() 106 | if row is None: 107 | break 108 | f1, f2, f3 = row 109 | finally: 110 | pg8000.paramstyle = orig_paramstyle 111 | 112 | 113 | def test_pyformat(db_table): 114 | orig_paramstyle = pg8000.paramstyle 115 | try: 116 | pg8000.paramstyle = "pyformat" 117 | with db_table.cursor() as c1: 118 | c1.execute("SELECT f1, f2, f3 
FROM t1 WHERE f1 > %(f1)s", {"f1": 3}) 119 | while 1: 120 | row = c1.fetchone() 121 | if row is None: 122 | break 123 | f1, f2, f3 = row 124 | finally: 125 | pg8000.paramstyle = orig_paramstyle 126 | 127 | 128 | def test_arraysize(db_table): 129 | with db_table.cursor() as c1: 130 | c1.arraysize = 3 131 | c1.execute("SELECT * FROM t1") 132 | retval = c1.fetchmany() 133 | assert len(retval) == c1.arraysize 134 | 135 | 136 | def test_date(): 137 | val = pg8000.Date(2001, 2, 3) 138 | assert val == datetime.date(2001, 2, 3) 139 | 140 | 141 | def test_time(): 142 | val = pg8000.Time(4, 5, 6) 143 | assert val == datetime.time(4, 5, 6) 144 | 145 | 146 | def test_timestamp(): 147 | val = pg8000.Timestamp(2001, 2, 3, 4, 5, 6) 148 | assert val == datetime.datetime(2001, 2, 3, 4, 5, 6) 149 | 150 | 151 | def test_date_from_ticks(has_tzset): 152 | if has_tzset: 153 | val = pg8000.DateFromTicks(1173804319) 154 | assert val == datetime.date(2007, 3, 13) 155 | 156 | 157 | def testTimeFromTicks(has_tzset): 158 | if has_tzset: 159 | val = pg8000.TimeFromTicks(1173804319) 160 | assert val == datetime.time(16, 45, 19) 161 | 162 | 163 | def test_timestamp_from_ticks(has_tzset): 164 | if has_tzset: 165 | val = pg8000.TimestampFromTicks(1173804319) 166 | assert val == datetime.datetime(2007, 3, 13, 16, 45, 19) 167 | 168 | 169 | def test_binary(): 170 | v = pg8000.Binary(b"\x00\x01\x02\x03\x02\x01\x00") 171 | assert v == b"\x00\x01\x02\x03\x02\x01\x00" 172 | assert isinstance(v, pg8000.BINARY) 173 | 174 | 175 | def test_row_count(db_table): 176 | with db_table.cursor() as c1: 177 | c1.execute("SELECT * FROM t1") 178 | 179 | assert 5 == c1.rowcount 180 | 181 | c1.execute("UPDATE t1 SET f3 = %s WHERE f2 > 101", ("Hello!",)) 182 | assert 2 == c1.rowcount 183 | 184 | c1.execute("DELETE FROM t1") 185 | assert 5 == c1.rowcount 186 | 187 | 188 | def test_fetch_many(db_table): 189 | with db_table.cursor() as cursor: 190 | cursor.arraysize = 2 191 | cursor.execute("SELECT * FROM t1") 192 | assert 
2 == len(cursor.fetchmany()) 193 | assert 2 == len(cursor.fetchmany()) 194 | assert 1 == len(cursor.fetchmany()) 195 | assert 0 == len(cursor.fetchmany()) 196 | 197 | 198 | def test_iterator(db_table): 199 | with db_table.cursor() as cursor: 200 | cursor.execute("SELECT * FROM t1 ORDER BY f1") 201 | f1 = 0 202 | for row in cursor: 203 | next_f1 = row[0] 204 | assert next_f1 > f1 205 | f1 = next_f1 206 | 207 | 208 | # Vacuum can't be run inside a transaction, so we need to turn 209 | # autocommit on. 210 | def test_vacuum(con): 211 | con.autocommit = True 212 | with con.cursor() as cursor: 213 | cursor.execute("vacuum") 214 | 215 | 216 | def test_prepared_statement(con): 217 | with con.cursor() as cursor: 218 | cursor.execute("PREPARE gen_series AS SELECT generate_series(1, 10);") 219 | cursor.execute("EXECUTE gen_series") 220 | 221 | 222 | def test_cursor_type(cursor): 223 | assert str(type(cursor)) == "" 224 | -------------------------------------------------------------------------------- /test/legacy/test_error_recovery.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import warnings 3 | 4 | import pytest 5 | 6 | import pg8000 7 | 8 | 9 | class PG8000TestException(Exception): 10 | pass 11 | 12 | 13 | def raise_exception(val): 14 | raise PG8000TestException("oh noes!") 15 | 16 | 17 | def test_py_value_fail(con, mocker): 18 | # Ensure that if types.py_value throws an exception, the original 19 | # exception is raised (PG8000TestException), and the connection is 20 | # still usable after the error. 
21 | mocker.patch.object(con, "py_types") 22 | con.py_types = {datetime.time: raise_exception} 23 | 24 | with con.cursor() as c, pytest.raises(PG8000TestException): 25 | c.execute("SELECT CAST(%s AS TIME)", (datetime.time(10, 30),)) 26 | c.fetchall() 27 | 28 | # ensure that the connection is still usable for a new query 29 | c.execute("VALUES ('hw3'::text)") 30 | assert c.fetchone()[0] == "hw3" 31 | 32 | 33 | def test_no_data_error_recovery(con): 34 | for i in range(1, 4): 35 | with con.cursor() as c, pytest.raises(pg8000.ProgrammingError) as e: 36 | c.execute("DROP TABLE t1") 37 | assert e.value.args[0]["C"] == "42P01" 38 | con.rollback() 39 | 40 | 41 | def testClosedConnection(db_kwargs): 42 | warnings.simplefilter("ignore") 43 | my_db = pg8000.connect(**db_kwargs) 44 | cursor = my_db.cursor() 45 | my_db.close() 46 | with pytest.raises(my_db.InterfaceError, match="connection is closed"): 47 | cursor.execute("VALUES ('hw1'::text)") 48 | 49 | warnings.resetwarnings() 50 | -------------------------------------------------------------------------------- /test/legacy/test_paramstyle.py: -------------------------------------------------------------------------------- 1 | from pg8000.legacy import convert_paramstyle as convert 2 | 3 | 4 | # Tests of the convert_paramstyle function. 5 | 6 | 7 | def test_qmark(): 8 | args = 1, 2, 3 9 | new_query, vals = convert( 10 | "qmark", 11 | 'SELECT ?, ?, "field_?" FROM t ' 12 | "WHERE a='say ''what?''' AND b=? AND c=E'?\\'test\\'?'", 13 | args, 14 | ) 15 | expected = ( 16 | 'SELECT $1, $2, "field_?" FROM t WHERE ' 17 | "a='say ''what?''' AND b=$3 AND c=E'?\\'test\\'?'" 18 | ) 19 | assert (new_query, vals) == (expected, args) 20 | 21 | 22 | def test_qmark_2(): 23 | args = 1, 2, 3 24 | new_query, vals = convert( 25 | "qmark", "SELECT ?, ?, * FROM t WHERE a=? 
AND b='are you ''sure?'", args 26 | ) 27 | expected = "SELECT $1, $2, * FROM t WHERE a=$3 AND b='are you ''sure?'" 28 | assert (new_query, vals) == (expected, args) 29 | 30 | 31 | def test_numeric(): 32 | args = 1, 2, 3 33 | new_query, vals = convert( 34 | "numeric", "SELECT sum(x)::decimal(5, 2) :2, :1, * FROM t WHERE a=:3", args 35 | ) 36 | expected = "SELECT sum(x)::decimal(5, 2) $2, $1, * FROM t WHERE a=$3" 37 | assert (new_query, vals) == (expected, args) 38 | 39 | 40 | def test_numeric_default_parameter(): 41 | args = 1, 2, 3 42 | new_query, vals = convert("numeric", "make_interval(days := 10)", args) 43 | 44 | assert (new_query, vals) == ("make_interval(days := 10)", args) 45 | 46 | 47 | def test_named(): 48 | args = { 49 | "f_2": 1, 50 | "f1": 2, 51 | } 52 | new_query, vals = convert( 53 | "named", "SELECT sum(x)::decimal(5, 2) :f_2, :f1 FROM t WHERE a=:f_2", args 54 | ) 55 | expected = "SELECT sum(x)::decimal(5, 2) $1, $2 FROM t WHERE a=$1" 56 | assert (new_query, vals) == (expected, (1, 2)) 57 | 58 | 59 | def test_format(): 60 | args = 1, 2, 3 61 | new_query, vals = convert( 62 | "format", 63 | "SELECT %s, %s, \"f1_%%\", E'txt_%%' " 64 | "FROM t WHERE a=%s AND b='75%%' AND c = '%' -- Comment with %", 65 | args, 66 | ) 67 | expected = ( 68 | "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$3 AND " 69 | "b='75%%' AND c = '%' -- Comment with %" 70 | ) 71 | assert (new_query, vals) == (expected, args) 72 | 73 | sql = ( 74 | r"""COMMENT ON TABLE test_schema.comment_test """ 75 | r"""IS 'the test % '' " \ table comment'""" 76 | ) 77 | new_query, vals = convert("format", sql, args) 78 | assert (new_query, vals) == (sql, args) 79 | 80 | 81 | def test_format_multiline(): 82 | args = 1, 2, 3 83 | new_query, vals = convert("format", "SELECT -- Comment\n%s FROM t", args) 84 | assert (new_query, vals) == ("SELECT -- Comment\n$1 FROM t", args) 85 | 86 | 87 | def test_py_format(): 88 | args = {"f2": 1, "f1": 2, "f3": 3} 89 | 90 | new_query, vals = convert( 91 | 
"pyformat", 92 | "SELECT %(f2)s, %(f1)s, \"f1_%%\", E'txt_%%' " 93 | "FROM t WHERE a=%(f2)s AND b='75%%'", 94 | args, 95 | ) 96 | expected = "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$1 AND " "b='75%%'" 97 | assert (new_query, vals) == (expected, (1, 2)) 98 | 99 | # pyformat should support %s and an array, too: 100 | args = 1, 2, 3 101 | new_query, vals = convert( 102 | "pyformat", 103 | "SELECT %s, %s, \"f1_%%\", E'txt_%%' " "FROM t WHERE a=%s AND b='75%%'", 104 | args, 105 | ) 106 | expected = "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$3 AND " "b='75%%'" 107 | assert (new_query, vals) == (expected, args) 108 | -------------------------------------------------------------------------------- /test/legacy/test_prepared_statement.py: -------------------------------------------------------------------------------- 1 | def test_prepare(con): 2 | con.prepare("SELECT CAST(:v AS INTEGER)") 3 | 4 | 5 | def test_run(con): 6 | ps = con.prepare("SELECT cast(:v as varchar)") 7 | ps.run(v="speedy") 8 | 9 | 10 | def test_run_with_no_results(con): 11 | ps = con.prepare("ROLLBACK") 12 | ps.run() 13 | -------------------------------------------------------------------------------- /test/legacy/test_query.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime as Datetime, timezone as Timezone 2 | from warnings import filterwarnings 3 | 4 | import pytest 5 | 6 | import pg8000 7 | from pg8000.converters import INET_ARRAY, INTEGER 8 | 9 | 10 | # Tests relating to the basic operation of the database driver, driven by the 11 | # pg8000 custom interface. 
12 | 13 | 14 | @pytest.fixture 15 | def db_table(request, con): 16 | filterwarnings("ignore", "DB-API extension cursor.next()") 17 | filterwarnings("ignore", "DB-API extension cursor.__iter__()") 18 | con.paramstyle = "format" 19 | with con.cursor() as cursor: 20 | cursor.execute( 21 | "CREATE TEMPORARY TABLE t1 (f1 int primary key, " 22 | "f2 bigint not null, f3 varchar(50) null) " 23 | ) 24 | 25 | def fin(): 26 | try: 27 | with con.cursor() as cursor: 28 | cursor.execute("drop table t1") 29 | except pg8000.ProgrammingError: 30 | pass 31 | 32 | request.addfinalizer(fin) 33 | return con 34 | 35 | 36 | def test_database_error(cursor): 37 | with pytest.raises(pg8000.ProgrammingError): 38 | cursor.execute("INSERT INTO t99 VALUES (1, 2, 3)") 39 | 40 | 41 | def test_parallel_queries(db_table): 42 | with db_table.cursor() as cursor: 43 | cursor.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (1, 1, None)) 44 | cursor.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (2, 10, None)) 45 | cursor.execute( 46 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (3, 100, None) 47 | ) 48 | cursor.execute( 49 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (4, 1000, None) 50 | ) 51 | cursor.execute( 52 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (5, 10000, None) 53 | ) 54 | with db_table.cursor() as c1, db_table.cursor() as c2: 55 | c1.execute("SELECT f1, f2, f3 FROM t1") 56 | for row in c1: 57 | f1, f2, f3 = row 58 | c2.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > %s", (f1,)) 59 | for row in c2: 60 | f1, f2, f3 = row 61 | 62 | 63 | def test_parallel_open_portals(con): 64 | with con.cursor() as c1, con.cursor() as c2: 65 | c1count, c2count = 0, 0 66 | q = "select * from generate_series(1, %s)" 67 | params = (100,) 68 | c1.execute(q, params) 69 | c2.execute(q, params) 70 | for c2row in c2: 71 | c2count += 1 72 | for c1row in c1: 73 | c1count += 1 74 | 75 | assert c1count == c2count 76 | 77 | 78 | # Run a query on a table, alter the structure of 
the table, then run the 79 | # original query again. 80 | 81 | 82 | def test_alter(db_table): 83 | with db_table.cursor() as cursor: 84 | cursor.execute("select * from t1") 85 | cursor.execute("alter table t1 drop column f3") 86 | cursor.execute("select * from t1") 87 | 88 | 89 | # Run a query on a table, drop then re-create the table, then run the 90 | # original query again. 91 | 92 | 93 | def test_create(db_table): 94 | with db_table.cursor() as cursor: 95 | cursor.execute("select * from t1") 96 | cursor.execute("drop table t1") 97 | cursor.execute("create temporary table t1 (f1 int primary key)") 98 | cursor.execute("select * from t1") 99 | 100 | 101 | def test_insert_returning(db_table): 102 | with db_table.cursor() as cursor: 103 | cursor.execute("CREATE TABLE t2 (id serial, data text)") 104 | 105 | # Test INSERT ... RETURNING with one row... 106 | cursor.execute("INSERT INTO t2 (data) VALUES (%s) RETURNING id", ("test1",)) 107 | row_id = cursor.fetchone()[0] 108 | cursor.execute("SELECT data FROM t2 WHERE id = %s", (row_id,)) 109 | assert "test1" == cursor.fetchone()[0] 110 | 111 | assert cursor.rowcount == 1 112 | 113 | # Test with multiple rows... 114 | cursor.execute( 115 | "INSERT INTO t2 (data) VALUES (%s), (%s), (%s) " "RETURNING id", 116 | ("test2", "test3", "test4"), 117 | ) 118 | assert cursor.rowcount == 3 119 | ids = tuple([x[0] for x in cursor]) 120 | assert len(ids) == 3 121 | 122 | 123 | def test_row_count(db_table): 124 | with db_table.cursor() as cursor: 125 | expected_count = 57 126 | cursor.executemany( 127 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", 128 | tuple((i, i, None) for i in range(expected_count)), 129 | ) 130 | 131 | # Check rowcount after executemany 132 | assert expected_count == cursor.rowcount 133 | 134 | cursor.execute("SELECT * FROM t1") 135 | 136 | # Check row_count without doing any reading first... 
137 | assert expected_count == cursor.rowcount 138 | 139 | # Check rowcount after reading some rows, make sure it still 140 | # works... 141 | for i in range(expected_count // 2): 142 | cursor.fetchone() 143 | assert expected_count == cursor.rowcount 144 | 145 | with db_table.cursor() as cursor: 146 | # Restart the cursor, read a few rows, and then check rowcount 147 | # again... 148 | cursor.execute("SELECT * FROM t1") 149 | for i in range(expected_count // 3): 150 | cursor.fetchone() 151 | assert expected_count == cursor.rowcount 152 | 153 | # Should be -1 for a command with no results 154 | cursor.execute("DROP TABLE t1") 155 | assert -1 == cursor.rowcount 156 | 157 | 158 | def test_row_count_update(db_table): 159 | with db_table.cursor() as cursor: 160 | cursor.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (1, 1, None)) 161 | cursor.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (2, 10, None)) 162 | cursor.execute( 163 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (3, 100, None) 164 | ) 165 | cursor.execute( 166 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (4, 1000, None) 167 | ) 168 | cursor.execute( 169 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (5, 10000, None) 170 | ) 171 | cursor.execute("UPDATE t1 SET f3 = %s WHERE f2 > 101", ("Hello!",)) 172 | assert cursor.rowcount == 2 173 | 174 | 175 | def test_int_oid(cursor): 176 | # https://bugs.launchpad.net/pg8000/+bug/230796 177 | cursor.execute("SELECT typname FROM pg_type WHERE oid = %s", (100,)) 178 | 179 | 180 | def test_unicode_query(cursor): 181 | cursor.execute( 182 | "CREATE TEMPORARY TABLE \u043c\u0435\u0441\u0442\u043e " 183 | "(\u0438\u043c\u044f VARCHAR(50), " 184 | "\u0430\u0434\u0440\u0435\u0441 VARCHAR(250))" 185 | ) 186 | 187 | 188 | def test_executemany(db_table): 189 | with db_table.cursor() as cursor: 190 | cursor.executemany( 191 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", 192 | ((1, 1, "Avast ye!"), (2, 1, None)), 193 | ) 194 | 195 | 
cursor.executemany( 196 | "SELECT CAST(%s AS TIMESTAMP)", 197 | ((Datetime(2014, 5, 7, tzinfo=Timezone.utc),), (Datetime(2014, 5, 7),)), 198 | ) 199 | 200 | 201 | def test_executemany_setinputsizes(cursor): 202 | """Make sure that setinputsizes works for all the parameter sets""" 203 | 204 | cursor.execute( 205 | "CREATE TEMPORARY TABLE t1 (f1 int primary key, f2 inet[] not null) " 206 | ) 207 | 208 | cursor.setinputsizes(INTEGER, INET_ARRAY) 209 | cursor.executemany( 210 | "INSERT INTO t1 (f1, f2) VALUES (%s, %s)", ((1, ["1.1.1.1"]), (2, ["0.0.0.0"])) 211 | ) 212 | 213 | 214 | def test_executemany_no_param_sets(cursor): 215 | cursor.executemany("INSERT INTO t1 (f1, f2) VALUES (%s, %s)", []) 216 | assert cursor.rowcount == -1 217 | 218 | 219 | # Check that autocommit stays off 220 | # We keep track of whether we're in a transaction or not by using the 221 | # READY_FOR_QUERY message. 222 | def test_transactions(db_table): 223 | with db_table.cursor() as cursor: 224 | cursor.execute("commit") 225 | cursor.execute( 226 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (1, 1, "Zombie") 227 | ) 228 | cursor.execute("rollback") 229 | cursor.execute("select * from t1") 230 | 231 | assert cursor.rowcount == 0 232 | 233 | 234 | def test_in(cursor): 235 | cursor.execute("SELECT typname FROM pg_type WHERE oid = any(%s)", ([16, 23],)) 236 | ret = cursor.fetchall() 237 | assert ret[0][0] == "bool" 238 | 239 | 240 | def test_no_previous_tpc(con): 241 | con.tpc_begin("Stacey") 242 | with con.cursor() as cursor: 243 | cursor.execute("SELECT * FROM pg_type") 244 | con.tpc_commit() 245 | 246 | 247 | # Check that tpc_recover() doesn't start a transaction 248 | def test_tpc_recover(con): 249 | con.tpc_recover() 250 | with con.cursor() as cursor: 251 | con.autocommit = True 252 | 253 | # If tpc_recover() has started a transaction, this will fail 254 | cursor.execute("VACUUM") 255 | 256 | 257 | def test_tpc_prepare(con): 258 | xid = "Stacey" 259 | con.tpc_begin(xid) 260 | 
con.tpc_prepare() 261 | con.tpc_rollback(xid) 262 | 263 | 264 | def test_empty_query(cursor): 265 | """No exception raised""" 266 | cursor.execute("") 267 | 268 | 269 | # rolling back when not in a transaction doesn't generate a warning 270 | def test_rollback_no_transaction(con): 271 | # Remove any existing notices 272 | con.notices.clear() 273 | 274 | # First, verify that a raw rollback does produce a notice 275 | con.execute_unnamed("rollback") 276 | 277 | assert 1 == len(con.notices) 278 | 279 | # 25P01 is the code for no_active_sql_tronsaction. It has 280 | # a message and severity name, but those might be 281 | # localized/depend on the server version. 282 | assert con.notices.pop().get(b"C") == b"25P01" 283 | 284 | # Now going through the rollback method doesn't produce 285 | # any notices because it knows we're not in a transaction. 286 | con.rollback() 287 | 288 | assert 0 == len(con.notices) 289 | 290 | 291 | def test_context_manager_class(con): 292 | assert "__enter__" in pg8000.legacy.Cursor.__dict__ 293 | assert "__exit__" in pg8000.legacy.Cursor.__dict__ 294 | 295 | with con.cursor() as cursor: 296 | cursor.execute("select 1") 297 | 298 | 299 | def test_close_prepared_statement(con): 300 | ps = con.prepare("select 1") 301 | ps.run() 302 | res = con.run("select count(*) from pg_prepared_statements") 303 | assert res[0][0] == 1 # Should have one prepared statement 304 | 305 | ps.close() 306 | 307 | res = con.run("select count(*) from pg_prepared_statements") 308 | assert res[0][0] == 0 # Should have no prepared statements 309 | 310 | 311 | def test_setinputsizes(con): 312 | cursor = con.cursor() 313 | cursor.setinputsizes(20) 314 | cursor.execute("select %s", (None,)) 315 | retval = cursor.fetchall() 316 | assert retval[0][0] is None 317 | 318 | 319 | def test_setinputsizes_class(con): 320 | cursor = con.cursor() 321 | cursor.setinputsizes(bytes) 322 | cursor.execute("select %s", (None,)) 323 | retval = cursor.fetchall() 324 | assert retval[0][0] is 
None 325 | 326 | 327 | def test_unexecuted_cursor_rowcount(con): 328 | cursor = con.cursor() 329 | assert cursor.rowcount == -1 330 | 331 | 332 | def test_unexecuted_cursor_description(con): 333 | cursor = con.cursor() 334 | assert cursor.description is None 335 | 336 | 337 | def test_not_parsed_if_no_params(mocker, cursor): 338 | mock_convert_paramstyle = mocker.patch("pg8000.legacy.convert_paramstyle") 339 | cursor.execute("ROLLBACK") 340 | mock_convert_paramstyle.assert_not_called() 341 | -------------------------------------------------------------------------------- /test/legacy/test_typeobjects.py: -------------------------------------------------------------------------------- 1 | from pg8000 import PGInterval 2 | 3 | 4 | def test_pginterval_constructor_days(): 5 | i = PGInterval(days=1) 6 | assert i.months is None 7 | assert i.days == 1 8 | assert i.microseconds is None 9 | -------------------------------------------------------------------------------- /test/native/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tlocke/pg8000/46c00021ade1d19466b07ed30392386c5f0a6b8e/test/native/__init__.py -------------------------------------------------------------------------------- /test/native/auth/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tlocke/pg8000/46c00021ade1d19466b07ed30392386c5f0a6b8e/test/native/auth/__init__.py -------------------------------------------------------------------------------- /test/native/auth/test_gss.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pg8000.native import Connection, InterfaceError 4 | 5 | 6 | def test_gss(db_kwargs): 7 | """This requires a line in pg_hba.conf that requires gss for the database 8 | pg8000_gss 9 | """ 10 | 11 | db_kwargs["database"] = "pg8000_gss" 12 | 13 | # Should raise an 
exception saying gss isn't supported 14 | with pytest.raises( 15 | InterfaceError, 16 | match="Authentication method 7 not supported by pg8000.", 17 | ): 18 | Connection(**db_kwargs) 19 | -------------------------------------------------------------------------------- /test/native/auth/test_md5.py: -------------------------------------------------------------------------------- 1 | def test_md5(con): 2 | """Called by GitHub Actions with auth method md5. 3 | We just need to check that we can get a connection. 4 | """ 5 | pass 6 | -------------------------------------------------------------------------------- /test/native/auth/test_md5_ssl.py: -------------------------------------------------------------------------------- 1 | import ssl 2 | 3 | from pg8000.native import Connection 4 | 5 | 6 | def test_ssl(db_kwargs): 7 | context = ssl.create_default_context() 8 | context.check_hostname = False 9 | context.verify_mode = ssl.CERT_NONE 10 | 11 | db_kwargs["ssl_context"] = context 12 | 13 | with Connection(**db_kwargs): 14 | pass 15 | -------------------------------------------------------------------------------- /test/native/auth/test_password.py: -------------------------------------------------------------------------------- 1 | def test_password(con): 2 | """Called by GitHub Actions with auth method password. 3 | We just need to check that we can get a connection. 
4 | """ 5 | pass 6 | -------------------------------------------------------------------------------- /test/native/auth/test_scram-sha-256.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pg8000.native import Connection, DatabaseError, InterfaceError 4 | 5 | # This requires a line in pg_hba.conf that requires scram-sha-256 for the 6 | # database pg8000_scram_sha_256 7 | 8 | DB = "pg8000_scram_sha_256" 9 | 10 | 11 | @pytest.fixture 12 | def setup(con): 13 | try: 14 | con.run(f"CREATE DATABASE {DB}") 15 | except DatabaseError: 16 | pass 17 | con.run("ALTER SYSTEM SET ssl = off") 18 | con.run("SELECT pg_reload_conf()") 19 | yield 20 | con.run("ALTER SYSTEM SET ssl = on") 21 | con.run("SELECT pg_reload_conf()") 22 | 23 | 24 | def test_scram_sha_256(setup, db_kwargs): 25 | db_kwargs["database"] = DB 26 | 27 | with Connection(**db_kwargs): 28 | pass 29 | 30 | 31 | def test_scram_sha_256_ssl_False(setup, db_kwargs): 32 | db_kwargs["database"] = DB 33 | db_kwargs["ssl_context"] = False 34 | 35 | with Connection(**db_kwargs): 36 | pass 37 | 38 | 39 | def test_scram_sha_256_ssl_True(setup, db_kwargs): 40 | db_kwargs["database"] = DB 41 | db_kwargs["ssl_context"] = True 42 | 43 | with pytest.raises(InterfaceError, match="Server refuses SSL"): 44 | with Connection(**db_kwargs): 45 | pass 46 | -------------------------------------------------------------------------------- /test/native/auth/test_scram-sha-256_ssl.py: -------------------------------------------------------------------------------- 1 | from ssl import CERT_NONE, SSLSocket, create_default_context 2 | 3 | import pytest 4 | 5 | from pg8000.native import Connection, DatabaseError 6 | 7 | # This requires a line in pg_hba.conf that requires scram-sha-256 for the 8 | # database pg8000_scram_sha_256 9 | 10 | DB = "pg8000_scram_sha_256" 11 | 12 | 13 | @pytest.fixture 14 | def setup(con): 15 | try: 16 | con.run(f"CREATE DATABASE {DB}") 17 | except 
DatabaseError: 18 | pass 19 | 20 | 21 | def test_scram_sha_256_plus(setup, db_kwargs): 22 | db_kwargs["database"] = DB 23 | 24 | with Connection(**db_kwargs) as con: 25 | assert isinstance(con._usock, SSLSocket) 26 | 27 | 28 | def test_scram_sha_256_plus_ssl_True(setup, db_kwargs): 29 | db_kwargs["ssl_context"] = True 30 | db_kwargs["database"] = DB 31 | 32 | with Connection(**db_kwargs) as con: 33 | assert isinstance(con._usock, SSLSocket) 34 | 35 | 36 | def test_scram_sha_256_plus_ssl_custom(setup, db_kwargs): 37 | context = create_default_context() 38 | context.check_hostname = False 39 | context.verify_mode = CERT_NONE 40 | 41 | db_kwargs["ssl_context"] = context 42 | db_kwargs["database"] = DB 43 | 44 | with Connection(**db_kwargs) as con: 45 | assert isinstance(con._usock, SSLSocket) 46 | 47 | 48 | def test_scram_sha_256_plus_ssl_False(setup, db_kwargs): 49 | db_kwargs["ssl_context"] = False 50 | db_kwargs["database"] = DB 51 | 52 | with Connection(**db_kwargs) as con: 53 | assert not isinstance(con._usock, SSLSocket) 54 | -------------------------------------------------------------------------------- /test/native/conftest.py: -------------------------------------------------------------------------------- 1 | from os import environ 2 | 3 | import pytest 4 | 5 | from test.utils import parse_server_version 6 | 7 | import pg8000.native 8 | 9 | 10 | @pytest.fixture(scope="class") 11 | def db_kwargs(): 12 | db_connect = {"user": "postgres", "password": "pw"} 13 | 14 | for kw, var, f in [ 15 | ("host", "PGHOST", str), 16 | ("password", "PGPASSWORD", str), 17 | ("port", "PGPORT", int), 18 | ]: 19 | try: 20 | db_connect[kw] = f(environ[var]) 21 | except KeyError: 22 | pass 23 | 24 | return db_connect 25 | 26 | 27 | @pytest.fixture 28 | def con(request, db_kwargs): 29 | conn = pg8000.native.Connection(**db_kwargs) 30 | 31 | def fin(): 32 | try: 33 | conn.run("rollback") 34 | except pg8000.native.InterfaceError: 35 | pass 36 | 37 | try: 38 | conn.close() 39 | except 
pg8000.native.InterfaceError: 40 | pass 41 | 42 | request.addfinalizer(fin) 43 | return conn 44 | 45 | 46 | @pytest.fixture 47 | def pg_version(con): 48 | retval = con.run("select current_setting('server_version')") 49 | version = retval[0][0] 50 | major = parse_server_version(version) 51 | return major 52 | -------------------------------------------------------------------------------- /test/native/test_benchmarks.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.mark.parametrize( 5 | "txt", 6 | ( 7 | ("int2", "cast(id / 100 as int2)"), 8 | "cast(id as int4)", 9 | "cast(id * 100 as int8)", 10 | "(id % 2) = 0", 11 | "N'Static text string'", 12 | "cast(id / 100 as float4)", 13 | "cast(id / 100 as float8)", 14 | "cast(id / 100 as numeric)", 15 | "timestamp '2001-09-28'", 16 | ), 17 | ) 18 | def test_round_trips(con, benchmark, txt): 19 | def torun(): 20 | query = """SELECT {0} AS column1, {0} AS column2, {0} AS column3, 21 | {0} AS column4, {0} AS column5, {0} AS column6, {0} AS column7 22 | FROM (SELECT generate_series(1, 10000) AS id) AS tbl""".format( 23 | txt 24 | ) 25 | con.run(query) 26 | 27 | benchmark(torun) 28 | -------------------------------------------------------------------------------- /test/native/test_connection.py: -------------------------------------------------------------------------------- 1 | import socket 2 | 3 | from datetime import time as Time 4 | 5 | import pytest 6 | 7 | from pg8000.native import Connection, DatabaseError, InterfaceError, __version__ 8 | 9 | 10 | def test_unix_socket_missing(): 11 | conn_params = {"unix_sock": "/file-does-not-exist", "user": "doesn't-matter"} 12 | 13 | with pytest.raises(InterfaceError): 14 | Connection(**conn_params) 15 | 16 | 17 | def test_internet_socket_connection_refused(): 18 | conn_params = {"port": 0, "user": "doesn't-matter"} 19 | 20 | with pytest.raises( 21 | InterfaceError, 22 | match="Can't create a connection to host 
localhost and port 0 " 23 | "\\(timeout is None and source_address is None\\).", 24 | ): 25 | Connection(**conn_params) 26 | 27 | 28 | def test_Connection_plain_socket(db_kwargs): 29 | host = db_kwargs.get("host", "localhost") 30 | port = db_kwargs.get("port", 5432) 31 | with socket.create_connection((host, port)) as sock: 32 | conn_params = { 33 | "sock": sock, 34 | "user": db_kwargs["user"], 35 | "password": db_kwargs["password"], 36 | "ssl_context": False, 37 | } 38 | 39 | with Connection(**conn_params) as con: 40 | res = con.run("SELECT 1") 41 | assert res[0][0] == 1 42 | 43 | 44 | def test_database_missing(db_kwargs): 45 | db_kwargs["database"] = "missing-db" 46 | with pytest.raises(DatabaseError): 47 | Connection(**db_kwargs) 48 | 49 | 50 | def test_notify(con): 51 | backend_pid = con.run("select pg_backend_pid()")[0][0] 52 | assert list(con.notifications) == [] 53 | con.run("LISTEN test") 54 | con.run("NOTIFY test") 55 | 56 | con.run("VALUES (1, 2), (3, 4), (5, 6)") 57 | assert len(con.notifications) == 1 58 | assert con.notifications[0] == (backend_pid, "test", "") 59 | 60 | 61 | def test_notify_with_payload(con): 62 | backend_pid = con.run("select pg_backend_pid()")[0][0] 63 | assert list(con.notifications) == [] 64 | con.run("LISTEN test") 65 | con.run("NOTIFY test, 'Parnham'") 66 | 67 | con.run("VALUES (1, 2), (3, 4), (5, 6)") 68 | assert len(con.notifications) == 1 69 | assert con.notifications[0] == (backend_pid, "test", "Parnham") 70 | 71 | 72 | # This requires a line in pg_hba.conf that requires md5 for the database 73 | # pg8000_md5 74 | 75 | 76 | def test_md5(db_kwargs): 77 | db_kwargs["database"] = "pg8000_md5" 78 | 79 | # Should only raise an exception saying db doesn't exist 80 | with pytest.raises(DatabaseError, match="3D000"): 81 | Connection(**db_kwargs) 82 | 83 | 84 | # This requires a line in pg_hba.conf that requires 'password' for the 85 | # database pg8000_password 86 | 87 | 88 | def test_password(db_kwargs): 89 | db_kwargs["database"] = 
"pg8000_password" 90 | 91 | # Should only raise an exception saying db doesn't exist 92 | with pytest.raises(DatabaseError, match="3D000"): 93 | Connection(**db_kwargs) 94 | 95 | 96 | def test_unicode_databaseName(db_kwargs): 97 | db_kwargs["database"] = "pg8000_sn\uFF6Fw" 98 | 99 | # Should only raise an exception saying db doesn't exist 100 | with pytest.raises(DatabaseError, match="3D000"): 101 | Connection(**db_kwargs) 102 | 103 | 104 | def test_bytes_databaseName(db_kwargs): 105 | """Should only raise an exception saying db doesn't exist""" 106 | 107 | db_kwargs["database"] = bytes("pg8000_sn\uFF6Fw", "utf8") 108 | with pytest.raises(DatabaseError, match="3D000"): 109 | Connection(**db_kwargs) 110 | 111 | 112 | def test_bytes_password(con, db_kwargs): 113 | # Create user 114 | username = "boltzmann" 115 | password = "cha\uFF6Fs" 116 | con.run("create user " + username + " with password '" + password + "';") 117 | 118 | db_kwargs["user"] = username 119 | db_kwargs["password"] = password.encode("utf8") 120 | db_kwargs["database"] = "pg8000_md5" 121 | with pytest.raises(DatabaseError, match="3D000"): 122 | Connection(**db_kwargs) 123 | 124 | con.run("drop role " + username) 125 | 126 | 127 | def test_broken_pipe_read(con, db_kwargs): 128 | db1 = Connection(**db_kwargs) 129 | res = db1.run("select pg_backend_pid()") 130 | pid1 = res[0][0] 131 | 132 | con.run("select pg_terminate_backend(:v)", v=pid1) 133 | with pytest.raises(InterfaceError, match="network error"): 134 | db1.run("select 1") 135 | 136 | try: 137 | db1.close() 138 | except InterfaceError: 139 | pass 140 | 141 | 142 | def test_broken_pipe_unpack(con): 143 | res = con.run("select pg_backend_pid()") 144 | pid1 = res[0][0] 145 | 146 | with pytest.raises(InterfaceError, match="network error"): 147 | con.run("select pg_terminate_backend(:v)", v=pid1) 148 | 149 | 150 | def test_broken_pipe_flush(con, db_kwargs): 151 | db1 = Connection(**db_kwargs) 152 | res = db1.run("select pg_backend_pid()") 153 | pid1 = 
res[0][0] 154 | 155 | con.run("select pg_terminate_backend(:v)", v=pid1) 156 | try: 157 | db1.run("select 1") 158 | except BaseException: 159 | pass 160 | 161 | # Sometimes raises and sometime doesn't 162 | try: 163 | db1.close() 164 | except InterfaceError as e: 165 | assert str(e) == "network error" 166 | 167 | 168 | def test_application_name(db_kwargs): 169 | app_name = "my test application name" 170 | db_kwargs["application_name"] = app_name 171 | with Connection(**db_kwargs) as db: 172 | res = db.run( 173 | "select application_name from pg_stat_activity " 174 | " where pid = pg_backend_pid()" 175 | ) 176 | 177 | application_name = res[0][0] 178 | assert application_name == app_name 179 | 180 | 181 | def test_application_name_integer(db_kwargs): 182 | db_kwargs["application_name"] = 1 183 | with pytest.raises( 184 | InterfaceError, 185 | match="The parameter application_name can't be of type .", 186 | ): 187 | Connection(**db_kwargs) 188 | 189 | 190 | def test_application_name_bytearray(db_kwargs): 191 | db_kwargs["application_name"] = bytearray(b"Philby") 192 | with Connection(**db_kwargs): 193 | pass 194 | 195 | 196 | class PG8000TestException(Exception): 197 | pass 198 | 199 | 200 | def raise_exception(val): 201 | raise PG8000TestException("oh noes!") 202 | 203 | 204 | def test_py_value_fail(con, mocker): 205 | # Ensure that if types.py_value throws an exception, the original 206 | # exception is raised (PG8000TestException), and the connection is 207 | # still usable after the error. 
208 | mocker.patch.object(con, "py_types") 209 | con.py_types = {Time: raise_exception} 210 | 211 | with pytest.raises(PG8000TestException): 212 | con.run("SELECT CAST(:v AS TIME)", v=Time(10, 30)) 213 | 214 | # ensure that the connection is still usable for a new query 215 | res = con.run("VALUES ('hw3'::text)") 216 | assert res[0][0] == "hw3" 217 | 218 | 219 | def test_no_data_error_recovery(con): 220 | for i in range(1, 4): 221 | with pytest.raises(DatabaseError) as e: 222 | con.run("DROP TABLE t1") 223 | assert e.value.args[0]["C"] == "42P01" 224 | con.run("ROLLBACK") 225 | 226 | 227 | def test_closed_connection(con): 228 | con.close() 229 | with pytest.raises(InterfaceError, match="connection is closed"): 230 | con.run("VALUES ('hw1'::text)") 231 | 232 | 233 | def test_version(): 234 | try: 235 | from importlib.metadata import version 236 | except ImportError: 237 | from importlib_metadata import version 238 | 239 | v = version("pg8000") 240 | 241 | assert __version__ == v 242 | 243 | 244 | @pytest.mark.parametrize( 245 | "commit", 246 | [ 247 | "commit", 248 | "COMMIT;", 249 | ], 250 | ) 251 | def test_failed_transaction_commit(con, commit): 252 | con.run("create temporary table tt (f1 int primary key)") 253 | con.run("begin") 254 | try: 255 | con.run("insert into tt(f1) values(null)") 256 | except DatabaseError: 257 | pass 258 | 259 | with pytest.raises(InterfaceError): 260 | con.run(commit) 261 | 262 | 263 | @pytest.mark.parametrize( 264 | "rollback", 265 | [ 266 | "rollback", 267 | "rollback;", 268 | "ROLLBACK ;", 269 | ], 270 | ) 271 | def test_failed_transaction_rollback(con, rollback): 272 | con.run("create temporary table tt (f1 int primary key)") 273 | con.run("begin") 274 | try: 275 | con.run("insert into tt(f1) values(null)") 276 | except DatabaseError: 277 | pass 278 | 279 | con.run(rollback) 280 | 281 | 282 | @pytest.mark.parametrize( 283 | "rollback", 284 | [ 285 | "rollback to sp", 286 | "rollback to sp;", 287 | "ROLLBACK TO sp ;", 288 | ], 289 
| ) 290 | def test_failed_transaction_rollback_to_savepoint(con, rollback): 291 | con.run("create temporary table tt (f1 int primary key)") 292 | con.run("begin") 293 | con.run("SAVEPOINT sp;") 294 | 295 | try: 296 | con.run("insert into tt(f1) values(null)") 297 | except DatabaseError: 298 | pass 299 | 300 | con.run(rollback) 301 | 302 | 303 | @pytest.mark.parametrize( 304 | "sql", 305 | [ 306 | "BEGIN", 307 | "select * from tt;", 308 | ], 309 | ) 310 | def test_failed_transaction_sql(con, sql): 311 | con.run("create temporary table tt (f1 int primary key)") 312 | con.run("begin") 313 | try: 314 | con.run("insert into tt(f1) values(null)") 315 | except DatabaseError: 316 | pass 317 | 318 | with pytest.raises(DatabaseError): 319 | con.run(sql) 320 | 321 | 322 | def test_parameter_statuses(con): 323 | role_name = "Æthelred" 324 | try: 325 | con.run(f"create role {role_name}") 326 | except DatabaseError: 327 | pass 328 | con.run(f"set session authorization '{role_name}'") 329 | assert role_name == con.parameter_statuses["session_authorization"] 330 | -------------------------------------------------------------------------------- /test/native/test_copy.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO, StringIO 2 | 3 | import pytest 4 | 5 | 6 | @pytest.fixture 7 | def db_table(request, con): 8 | con.run("START TRANSACTION") 9 | con.run( 10 | "CREATE TEMPORARY TABLE t1 " 11 | "(f1 int primary key, f2 int not null, f3 varchar(50) null) " 12 | "on commit drop" 13 | ) 14 | return con 15 | 16 | 17 | def test_copy_to_with_table(db_table): 18 | db_table.run("INSERT INTO t1 (f1, f2, f3) VALUES (:v1, :v1, :v2)", v1=1, v2="1") 19 | db_table.run("INSERT INTO t1 (f1, f2, f3) VALUES (:v1, :v1, :v2)", v1=2, v2="2") 20 | db_table.run("INSERT INTO t1 (f1, f2, f3) VALUES (:v1, :v1, :v2)", v1=3, v2="3") 21 | 22 | stream = BytesIO() 23 | db_table.run("copy t1 to stdout", stream=stream) 24 | assert stream.getvalue() == 
b"1\t1\t1\n2\t2\t2\n3\t3\t3\n" 25 | assert db_table.row_count == 3 26 | 27 | 28 | def test_copy_to_with_query(con): 29 | stream = BytesIO() 30 | con.run( 31 | "COPY (SELECT 1 as One, 2 as Two) TO STDOUT WITH DELIMITER " 32 | "'X' CSV HEADER QUOTE AS 'Y' FORCE QUOTE Two", 33 | stream=stream, 34 | ) 35 | assert stream.getvalue() == b"oneXtwo\n1XY2Y\n" 36 | assert con.row_count == 1 37 | 38 | 39 | def test_copy_to_with_text_stream(con): 40 | stream = StringIO() 41 | con.run( 42 | "COPY (SELECT 1 as One, 2 as Two) TO STDOUT WITH DELIMITER " 43 | "'X' CSV HEADER QUOTE AS 'Y' FORCE QUOTE Two", 44 | stream=stream, 45 | ) 46 | assert stream.getvalue() == "oneXtwo\n1XY2Y\n" 47 | assert con.row_count == 1 48 | 49 | 50 | def test_copy_from_with_table(db_table): 51 | stream = BytesIO(b"1\t1\t1\n2\t2\t2\n3\t3\t3\n") 52 | db_table.run("copy t1 from STDIN", stream=stream) 53 | assert db_table.row_count == 3 54 | 55 | retval = db_table.run("SELECT * FROM t1 ORDER BY f1") 56 | assert retval == [[1, 1, "1"], [2, 2, "2"], [3, 3, "3"]] 57 | 58 | 59 | def test_copy_from_with_text_stream(db_table): 60 | stream = StringIO("1\t1\t1\n2\t2\t2\n3\t3\t3\n") 61 | db_table.run("copy t1 from STDIN", stream=stream) 62 | 63 | retval = db_table.run("SELECT * FROM t1 ORDER BY f1") 64 | assert retval == [[1, 1, "1"], [2, 2, "2"], [3, 3, "3"]] 65 | 66 | 67 | def test_copy_from_with_query(db_table): 68 | stream = BytesIO(b"f1Xf2\n1XY1Y\n") 69 | db_table.run( 70 | "COPY t1 (f1, f2) FROM STDIN WITH DELIMITER 'X' CSV HEADER " 71 | "QUOTE AS 'Y' FORCE NOT NULL f1", 72 | stream=stream, 73 | ) 74 | assert db_table.row_count == 1 75 | 76 | retval = db_table.run("SELECT * FROM t1 ORDER BY f1") 77 | assert retval == [[1, 1, None]] 78 | 79 | 80 | def test_copy_from_with_error(db_table): 81 | stream = BytesIO(b"f1Xf2\n\n1XY1Y\n") 82 | with pytest.raises(BaseException) as e: 83 | db_table.run( 84 | "COPY t1 (f1, f2) FROM STDIN WITH DELIMITER 'X' CSV HEADER " 85 | "QUOTE AS 'Y' FORCE NOT NULL f1", 86 | 
stream=stream, 87 | ) 88 | 89 | arg = { 90 | "S": ("ERROR",), 91 | "C": ("22P02",), 92 | "M": ( 93 | 'invalid input syntax for type integer: ""', 94 | 'invalid input syntax for integer: ""', 95 | ), 96 | "W": ('COPY t1, line 2, column f1: ""',), 97 | "F": ("numutils.c",), 98 | "R": ("pg_atoi", "pg_strtoint32", "pg_strtoint32_safe"), 99 | } 100 | earg = e.value.args[0] 101 | for k, v in arg.items(): 102 | assert earg[k] in v 103 | 104 | 105 | def test_copy_from_with_text_iterable(db_table): 106 | stream = ["1\t1\t1\n", "2\t2\t2\n", "3\t3\t3\n"] 107 | db_table.run("copy t1 from STDIN", stream=stream) 108 | 109 | retval = db_table.run("SELECT * FROM t1 ORDER BY f1") 110 | assert retval == [[1, 1, "1"], [2, 2, "2"], [3, 3, "3"]] 111 | -------------------------------------------------------------------------------- /test/native/test_core.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import pytest 4 | 5 | from pg8000.core import ( 6 | Context, 7 | CoreConnection, 8 | NULL_BYTE, 9 | PASSWORD, 10 | _create_message, 11 | _make_socket, 12 | _read, 13 | ) 14 | from pg8000.native import InterfaceError 15 | 16 | 17 | def test_make_socket(mocker): 18 | unix_sock = None 19 | sock = mocker.Mock() 20 | host = "localhost" 21 | port = 5432 22 | timeout = None 23 | source_address = None 24 | tcp_keepalive = True 25 | ssl_context = None 26 | _make_socket( 27 | unix_sock, sock, host, port, timeout, source_address, tcp_keepalive, ssl_context 28 | ) 29 | 30 | 31 | def test_handle_AUTHENTICATION_3(mocker): 32 | """Shouldn't send a FLUSH message, as FLUSH only used in extended-query""" 33 | 34 | mocker.patch.object(CoreConnection, "__init__", lambda x: None) 35 | con = CoreConnection() 36 | password = "barbour".encode("utf8") 37 | con.password = password 38 | con._sock = mocker.Mock() 39 | buf = BytesIO() 40 | con._sock.write = buf.write 41 | CoreConnection.handle_AUTHENTICATION_REQUEST(con, b"\x00\x00\x00\x03", None) 
42 | assert buf.getvalue() == _create_message(PASSWORD, password + NULL_BYTE) 43 | 44 | 45 | def test_create_message(): 46 | msg = _create_message(PASSWORD, "barbour".encode("utf8") + NULL_BYTE) 47 | assert msg == b"p\x00\x00\x00\x0cbarbour\x00" 48 | 49 | 50 | def test_handle_ERROR_RESPONSE(mocker): 51 | """Check it handles invalid encodings in the error messages""" 52 | 53 | mocker.patch.object(CoreConnection, "__init__", lambda x: None) 54 | con = CoreConnection() 55 | con._client_encoding = "utf8" 56 | data = b"S\xc2err" + NULL_BYTE + NULL_BYTE 57 | context = Context(None) 58 | CoreConnection.handle_ERROR_RESPONSE(con, data, context) 59 | assert str(context.error) == "{'S': '�err'}" 60 | 61 | 62 | def test_read(mocker): 63 | mock_socket = mocker.Mock() 64 | mock_socket.read = mocker.Mock(return_value=b"") 65 | with pytest.raises(InterfaceError, match="network error"): 66 | _read(mock_socket, 5) 67 | -------------------------------------------------------------------------------- /test/native/test_import_all.py: -------------------------------------------------------------------------------- 1 | from pg8000.native import * # noqa: F403 2 | 3 | 4 | def test_import_all(): 5 | type(Connection) # noqa: F405 6 | -------------------------------------------------------------------------------- /test/native/test_prepared_statement.py: -------------------------------------------------------------------------------- 1 | def test_prepare(con): 2 | con.prepare("SELECT CAST(:v AS INTEGER)") 3 | 4 | 5 | def test_run(con): 6 | ps = con.prepare("SELECT cast(:v as varchar)") 7 | ps.run(v="speedy") 8 | -------------------------------------------------------------------------------- /test/native/test_query.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pg8000.native import DatabaseError, to_statement 4 | 5 | 6 | # Tests relating to the basic operation of the database driver, driven by the 7 | # pg8000 custom 
interface. 8 | 9 | 10 | @pytest.fixture 11 | def db_table(request, con): 12 | con.run( 13 | "CREATE TEMPORARY TABLE t1 (f1 int primary key, " 14 | "f2 bigint not null, f3 varchar(50) null) " 15 | ) 16 | 17 | def fin(): 18 | try: 19 | con.run("drop table t1") 20 | except DatabaseError: 21 | pass 22 | 23 | request.addfinalizer(fin) 24 | return con 25 | 26 | 27 | def test_database_error(con): 28 | with pytest.raises(DatabaseError): 29 | con.run("INSERT INTO t99 VALUES (1, 2, 3)") 30 | 31 | 32 | # Run a query on a table, alter the structure of the table, then run the 33 | # original query again. 34 | 35 | 36 | def test_alter(db_table): 37 | db_table.run("select * from t1") 38 | db_table.run("alter table t1 drop column f3") 39 | db_table.run("select * from t1") 40 | 41 | 42 | # Run a query on a table, drop then re-create the table, then run the 43 | # original query again. 44 | 45 | 46 | def test_create(db_table): 47 | db_table.run("select * from t1") 48 | db_table.run("drop table t1") 49 | db_table.run("create temporary table t1 (f1 int primary key)") 50 | db_table.run("select * from t1") 51 | 52 | 53 | def test_parametrized(db_table): 54 | res = db_table.run("SELECT f1, f2, f3 FROM t1 WHERE f1 > :f1", f1=3) 55 | for row in res: 56 | f1, f2, f3 = row 57 | 58 | 59 | def test_insert_returning(db_table): 60 | db_table.run("CREATE TEMPORARY TABLE t2 (id serial, data text)") 61 | 62 | # Test INSERT ... RETURNING with one row... 63 | res = db_table.run("INSERT INTO t2 (data) VALUES (:v) RETURNING id", v="test1") 64 | row_id = res[0][0] 65 | res = db_table.run("SELECT data FROM t2 WHERE id = :v", v=row_id) 66 | assert "test1" == res[0][0] 67 | 68 | assert db_table.row_count == 1 69 | 70 | # Test with multiple rows... 
71 | res = db_table.run( 72 | "INSERT INTO t2 (data) VALUES (:v1), (:v2), (:v3) " "RETURNING id", 73 | v1="test2", 74 | v2="test3", 75 | v3="test4", 76 | ) 77 | assert db_table.row_count == 3 78 | ids = [x[0] for x in res] 79 | assert len(ids) == 3 80 | 81 | 82 | def test_row_count_select(db_table): 83 | expected_count = 57 84 | for i in range(expected_count): 85 | db_table.run( 86 | "INSERT INTO t1 (f1, f2, f3) VALUES (:v1, :v2, :v3)", v1=i, v2=i, v3=None 87 | ) 88 | 89 | db_table.run("SELECT * FROM t1") 90 | 91 | # Check row_count 92 | assert expected_count == db_table.row_count 93 | 94 | # Should be -1 for a command with no results 95 | db_table.run("DROP TABLE t1") 96 | assert -1 == db_table.row_count 97 | 98 | 99 | def test_row_count_delete(db_table): 100 | db_table.run( 101 | "INSERT INTO t1 (f1, f2, f3) VALUES (:v1, :v2, :v3)", v1=1, v2=1, v3=None 102 | ) 103 | db_table.run("DELETE FROM t1") 104 | assert db_table.row_count == 1 105 | 106 | 107 | def test_row_count_update(db_table): 108 | db_table.run( 109 | "INSERT INTO t1 (f1, f2, f3) VALUES (:v1, :v2, :v3)", v1=1, v2=1, v3=None 110 | ) 111 | db_table.run( 112 | "INSERT INTO t1 (f1, f2, f3) VALUES (:v1, :v2, :v3)", v1=2, v2=10, v3=None 113 | ) 114 | db_table.run( 115 | "INSERT INTO t1 (f1, f2, f3) VALUES (:v1, :v2, :v3)", v1=3, v2=100, v3=None 116 | ) 117 | db_table.run( 118 | "INSERT INTO t1 (f1, f2, f3) VALUES (:v1, :v2, :v3)", v1=4, v2=1000, v3=None 119 | ) 120 | db_table.run( 121 | "INSERT INTO t1 (f1, f2, f3) VALUES (:v1, :v2, :v3)", v1=5, v2=10000, v3=None 122 | ) 123 | db_table.run("UPDATE t1 SET f3 = :v1 WHERE f2 > 101", v1="Hello!") 124 | assert db_table.row_count == 2 125 | 126 | 127 | def test_int_oid(con): 128 | # https://bugs.launchpad.net/pg8000/+bug/230796 129 | con.run("SELECT typname FROM pg_type WHERE oid = :v", v=100) 130 | 131 | 132 | def test_unicode_query(con): 133 | con.run( 134 | "CREATE TEMPORARY TABLE \u043c\u0435\u0441\u0442\u043e " 135 | "(\u0438\u043c\u044f VARCHAR(50), " 136 | 
"\u0430\u0434\u0440\u0435\u0441 VARCHAR(250))" 137 | ) 138 | 139 | 140 | def test_transactions(db_table): 141 | db_table.run("start transaction") 142 | db_table.run( 143 | "INSERT INTO t1 (f1, f2, f3) VALUES (:v1, :v2, :v3)", v1=1, v2=1, v3="Zombie" 144 | ) 145 | db_table.run("rollback") 146 | db_table.run("select * from t1") 147 | 148 | assert db_table.row_count == 0 149 | 150 | 151 | def test_in(con): 152 | ret = con.run("SELECT typname FROM pg_type WHERE oid = any(:v)", v=[16, 23]) 153 | assert ret[0][0] == "bool" 154 | 155 | 156 | def test_empty_query(con): 157 | """No exception thrown""" 158 | con.run("") 159 | 160 | 161 | def test_rollback_no_transaction(con): 162 | # Remove any existing notices 163 | con.notices.clear() 164 | 165 | # First, verify that a raw rollback does produce a notice 166 | con.run("rollback") 167 | 168 | assert 1 == len(con.notices) 169 | 170 | # 25P01 is the code for no_active_sql_tronsaction. It has 171 | # a message and severity name, but those might be 172 | # localized/depend on the server version. 
173 | assert con.notices.pop().get(b"C") == b"25P01" 174 | 175 | 176 | def test_close_prepared_statement(con): 177 | ps = con.prepare("select 1") 178 | ps.run() 179 | res = con.run("select count(*) from pg_prepared_statements") 180 | assert res[0][0] == 1 # Should have one prepared statement 181 | 182 | ps.close() 183 | 184 | res = con.run("select count(*) from pg_prepared_statements") 185 | assert res[0][0] == 0 # Should have no prepared statements 186 | 187 | 188 | def test_no_data(con): 189 | assert con.run("START TRANSACTION") is None 190 | 191 | 192 | def test_multiple_statements(con): 193 | statements = "SELECT 5; SELECT 'Erich Fromm';" 194 | assert con.run(statements) == [[5], ["Erich Fromm"]] 195 | 196 | 197 | def test_unexecuted_connection_row_count(con): 198 | assert con.row_count is None 199 | 200 | 201 | def test_unexecuted_connection_columns(con): 202 | assert con.columns is None 203 | 204 | 205 | def test_sql_prepared_statement(con): 206 | con.run("PREPARE gen_series AS SELECT generate_series(1, 10);") 207 | con.run("EXECUTE gen_series") 208 | 209 | 210 | def test_to_statement(): 211 | new_query, _ = to_statement( 212 | "SELECT sum(x)::decimal(5, 2) :f_2, :f1 FROM t WHERE a=:f_2" 213 | ) 214 | expected = "SELECT sum(x)::decimal(5, 2) $1, $2 FROM t WHERE a=$1" 215 | assert new_query == expected 216 | 217 | 218 | def test_to_statement_quotes(): 219 | new_query, _ = to_statement("SELECT $$'$$ = :v") 220 | expected = "SELECT $$'$$ = $1" 221 | assert new_query == expected 222 | 223 | 224 | def test_not_parsed_if_no_params(mocker, con): 225 | mock_to_statement = mocker.patch("pg8000.native.to_statement") 226 | con.run("ROLLBACK") 227 | mock_to_statement.assert_not_called() 228 | 229 | 230 | def test_max_parameters(con): 231 | SIZE = 60000 232 | kwargs = {f"param_{i}": 1 for i in range(SIZE)} 233 | con.run( 234 | f"SELECT 1 WHERE 1 IN ({','.join([f':param_{i}' for i in range(SIZE)])})", 235 | **kwargs, 236 | ) 237 | 238 | 239 | def 
test_pg_placeholder_style(con): 240 | rows = con.run("SELECT $1", title="A Time Of Hope") 241 | assert rows[0] == ["A Time Of Hope"] 242 | -------------------------------------------------------------------------------- /test/test_converters.py: -------------------------------------------------------------------------------- 1 | from datetime import ( 2 | date as Date, 3 | datetime as DateTime, 4 | time as Time, 5 | timedelta as TimeDelta, 6 | timezone as TimeZone, 7 | ) 8 | from decimal import Decimal 9 | from ipaddress import IPv4Address, IPv4Network 10 | 11 | import pytest 12 | 13 | from pg8000.converters import ( 14 | PGInterval, 15 | PY_TYPES, 16 | Range, 17 | array_out, 18 | array_string_escape, 19 | date_in, 20 | datemultirange_in, 21 | identifier, 22 | int4range_in, 23 | interval_in, 24 | literal, 25 | make_param, 26 | null_out, 27 | numeric_in, 28 | numeric_out, 29 | pg_interval_in, 30 | range_out, 31 | string_in, 32 | string_out, 33 | time_in, 34 | timestamp_in, 35 | timestamptz_in, 36 | tsrange_in, 37 | ) 38 | from pg8000.native import InterfaceError 39 | 40 | 41 | @pytest.mark.parametrize( 42 | "value,expected", 43 | [ 44 | ["2022-03-02", Date(2022, 3, 2)], 45 | ["infinity", "infinity"], 46 | ["-infinity", "-infinity"], 47 | ["20022-03-02", "20022-03-02"], 48 | ], 49 | ) 50 | def test_date_in(value, expected): 51 | assert date_in(value) == expected 52 | 53 | 54 | def test_null_out(): 55 | assert null_out(None) is None 56 | 57 | 58 | @pytest.mark.parametrize( 59 | "array,out", 60 | [ 61 | [[True, False, None], "{true,false,NULL}"], # bool[] 62 | [[IPv4Address("192.168.0.1")], "{192.168.0.1}"], # inet[] 63 | [[Date(2021, 3, 1)], "{2021-03-01}"], # date[] 64 | [[b"\x00\x01\x02\x03\x02\x01\x00"], '{"\\\\x00010203020100"}'], # bytea[] 65 | [[IPv4Network("192.168.0.0/28")], "{192.168.0.0/28}"], # inet[] 66 | [[1, 2, 3], "{1,2,3}"], # int2[] 67 | [[1, None, 3], "{1,NULL,3}"], # int2[] with None 68 | [[[1, 2], [3, 4]], "{{1,2},{3,4}}"], # int2[] 
multidimensional 69 | [[70000, 2, 3], "{70000,2,3}"], # int4[] 70 | [[7000000000, 2, 3], "{7000000000,2,3}"], # int8[] 71 | [[0, 7000000000, 2], "{0,7000000000,2}"], # int8[] 72 | [[1.1, 2.2, 3.3], "{1.1,2.2,3.3}"], # float8[] 73 | [["Veni", "vidi", "vici"], "{Veni,vidi,vici}"], # varchar[] 74 | [[("Veni", True, 1)], '{"(Veni,true,1)"}'], # array of composites 75 | ], 76 | ) 77 | def test_array_out(array, out): 78 | assert array_out(array) == out 79 | 80 | 81 | @pytest.mark.parametrize( 82 | "value", 83 | [ 84 | "1.1", 85 | "-1.1", 86 | "10000", 87 | "20000", 88 | "-1000000000.123456789", 89 | "1.0", 90 | "12.44", 91 | ], 92 | ) 93 | def test_numeric_out(value): 94 | assert numeric_out(value) == str(value) 95 | 96 | 97 | @pytest.mark.parametrize( 98 | "value", 99 | [ 100 | "1.1", 101 | "-1.1", 102 | "10000", 103 | "20000", 104 | "-1000000000.123456789", 105 | "1.0", 106 | "12.44", 107 | ], 108 | ) 109 | def test_numeric_in(value): 110 | assert numeric_in(value) == Decimal(value) 111 | 112 | 113 | @pytest.mark.parametrize( 114 | "data,expected", 115 | [ 116 | ("[6,3]", Range(6, 3, bounds="[]")), 117 | ], 118 | ) 119 | def test_int4range_in(data, expected): 120 | assert int4range_in(data) == expected 121 | 122 | 123 | @pytest.mark.parametrize( 124 | "v,expected", 125 | [ 126 | (Range(6, 3, bounds="[]"), "[6,3]"), 127 | ], 128 | ) 129 | def test_range_out(v, expected): 130 | assert range_out(v) == expected 131 | 132 | 133 | @pytest.mark.parametrize( 134 | "value", 135 | [ 136 | "hello \u0173 world", 137 | ], 138 | ) 139 | def test_string_out(value): 140 | assert string_out(value) == value 141 | 142 | 143 | @pytest.mark.parametrize( 144 | "value", 145 | [ 146 | "hello \u0173 world", 147 | ], 148 | ) 149 | def test_string_in(value): 150 | assert string_in(value) == value 151 | 152 | 153 | def test_time_in(): 154 | actual = time_in("12:57:18.000396") 155 | assert actual == Time(12, 57, 18, 396) 156 | 157 | 158 | @pytest.mark.parametrize( 159 | "value,expected", 160 | [ 
161 | ("1 year", PGInterval(years=1)), 162 | ("2 hours", PGInterval(hours=2)), 163 | ], 164 | ) 165 | def test_pg_interval_in(value, expected): 166 | assert pg_interval_in(value) == expected 167 | 168 | 169 | @pytest.mark.parametrize( 170 | "value,expected", 171 | [ 172 | ("2 hours", TimeDelta(hours=2)), 173 | ("00:00:30", TimeDelta(seconds=30)), 174 | ], 175 | ) 176 | def test_interval_in(value, expected): 177 | assert interval_in(value) == expected 178 | 179 | 180 | @pytest.mark.parametrize( 181 | "value,expected", 182 | [ 183 | ("a", "a"), 184 | ('"', '"\\""'), 185 | ("\r", '"\r"'), 186 | ("\n", '"\n"'), 187 | ("\t", '"\t"'), 188 | ], 189 | ) 190 | def test_array_string_escape(value, expected): 191 | res = array_string_escape(value) 192 | assert res == expected 193 | 194 | 195 | @pytest.mark.parametrize( 196 | "value,expected", 197 | [ 198 | [ 199 | "2022-10-08 15:01:39+01:30", 200 | DateTime( 201 | 2022, 10, 8, 15, 1, 39, tzinfo=TimeZone(TimeDelta(hours=1, minutes=30)) 202 | ), 203 | ], 204 | [ 205 | "2022-10-08 15:01:39-01:30", 206 | DateTime( 207 | 2022, 208 | 10, 209 | 8, 210 | 15, 211 | 1, 212 | 39, 213 | tzinfo=TimeZone(TimeDelta(hours=-1, minutes=-30)), 214 | ), 215 | ], 216 | [ 217 | "2022-10-08 15:01:39+02", 218 | DateTime(2022, 10, 8, 15, 1, 39, tzinfo=TimeZone(TimeDelta(hours=2))), 219 | ], 220 | [ 221 | "2022-10-08 15:01:39-02", 222 | DateTime(2022, 10, 8, 15, 1, 39, tzinfo=TimeZone(TimeDelta(hours=-2))), 223 | ], 224 | [ 225 | "2022-10-08 15:01:39.597026+01:30", 226 | DateTime( 227 | 2022, 228 | 10, 229 | 8, 230 | 15, 231 | 1, 232 | 39, 233 | 597026, 234 | tzinfo=TimeZone(TimeDelta(hours=1, minutes=30)), 235 | ), 236 | ], 237 | [ 238 | "2022-10-08 15:01:39.597026-01:30", 239 | DateTime( 240 | 2022, 241 | 10, 242 | 8, 243 | 15, 244 | 1, 245 | 39, 246 | 597026, 247 | tzinfo=TimeZone(TimeDelta(hours=-1, minutes=-30)), 248 | ), 249 | ], 250 | [ 251 | "2022-10-08 15:01:39.597026+02", 252 | DateTime( 253 | 2022, 10, 8, 15, 1, 39, 597026, 
tzinfo=TimeZone(TimeDelta(hours=2)) 254 | ), 255 | ], 256 | [ 257 | "2022-10-08 15:01:39.597026-02", 258 | DateTime( 259 | 2022, 10, 8, 15, 1, 39, 597026, tzinfo=TimeZone(TimeDelta(hours=-2)) 260 | ), 261 | ], 262 | [ 263 | "20022-10-08 15:01:39.597026-02", 264 | "20022-10-08 15:01:39.597026-02", 265 | ], 266 | ], 267 | ) 268 | def test_timestamptz_in(value, expected): 269 | assert timestamptz_in(value) == expected 270 | 271 | 272 | @pytest.mark.parametrize( 273 | "value,expected", 274 | [ 275 | [ 276 | "20022-10-08 15:01:39.597026", 277 | "20022-10-08 15:01:39.597026", 278 | ], 279 | ], 280 | ) 281 | def test_timestamp_in(value, expected): 282 | assert timestamp_in(value) == expected 283 | 284 | 285 | @pytest.mark.parametrize( 286 | "value,expected", 287 | [ 288 | [ 289 | '["2001-02-03 04:05:00","2023-02-03 04:05:00")', 290 | Range(DateTime(2001, 2, 3, 4, 5), DateTime(2023, 2, 3, 4, 5)), 291 | ], 292 | ], 293 | ) 294 | def test_tsrange_in(value, expected): 295 | assert tsrange_in(value) == expected 296 | 297 | 298 | @pytest.mark.parametrize( 299 | "value,expected", 300 | [ 301 | [ 302 | "{[2023-06-01,2023-06-06),[2023-06-10,2023-06-13)}", 303 | [ 304 | Range(Date(2023, 6, 1), Date(2023, 6, 6)), 305 | Range(Date(2023, 6, 10), Date(2023, 6, 13)), 306 | ], 307 | ], 308 | ], 309 | ) 310 | def test_datemultirange_in(value, expected): 311 | assert datemultirange_in(value) == expected 312 | 313 | 314 | def test_make_param(): 315 | class BClass(bytearray): 316 | pass 317 | 318 | val = BClass(b"\x00\x01\x02\x03\x02\x01\x00") 319 | assert make_param(PY_TYPES, val) == "\\x00010203020100" 320 | 321 | 322 | def test_identifier_int(): 323 | with pytest.raises(InterfaceError, match="identifier must be a str"): 324 | identifier(9) 325 | 326 | 327 | def test_identifier_empty(): 328 | with pytest.raises( 329 | InterfaceError, match="identifier must be > 0 characters in length" 330 | ): 331 | identifier("") 332 | 333 | 334 | def test_identifier_quoted_null(): 335 | with 
pytest.raises( 336 | InterfaceError, match="identifier cannot contain the code zero character" 337 | ): 338 | identifier("tabl\u0000e") 339 | 340 | 341 | @pytest.mark.parametrize( 342 | "value,expected", 343 | [ 344 | ("top_secret", '"top_secret"'), 345 | (" Table", '" Table"'), 346 | ("A Table", '"A Table"'), 347 | ('A " Table', '"A "" Table"'), 348 | ("table$", '"table$"'), 349 | ("Table$", '"Table$"'), 350 | ("tableఐ", '"tableఐ"'), # Unicode character 0C10 which is uncased 351 | ("table", '"table"'), 352 | ("tAble", '"tAble"'), 353 | ], 354 | ) 355 | def test_identifier_success(value, expected): 356 | assert identifier(value) == expected 357 | 358 | 359 | @pytest.mark.parametrize( 360 | "value,expected", 361 | [("top_secret", "'top_secret'"), (["cove"], "'{cove}'")], 362 | ) 363 | def test_literal(value, expected): 364 | assert literal(value) == expected 365 | 366 | 367 | def test_literal_quote(): 368 | assert literal("bob's") == "'bob''s'" 369 | 370 | 371 | def test_literal_int(): 372 | assert literal(7) == "7" 373 | 374 | 375 | def test_literal_float(): 376 | assert literal(7.9) == "7.9" 377 | 378 | 379 | def test_literal_decimal(): 380 | assert literal(Decimal("0.1")) == "0.1" 381 | 382 | 383 | def test_literal_bytes(): 384 | assert literal(b"\x03") == "X'03'" 385 | 386 | 387 | def test_literal_boolean(): 388 | assert literal(True) == "TRUE" 389 | 390 | 391 | def test_literal_None(): 392 | assert literal(None) == "NULL" 393 | 394 | 395 | def test_literal_Time(): 396 | assert literal(Time(22, 13, 2)) == "'22:13:02'" 397 | 398 | 399 | def test_literal_Date(): 400 | assert literal(Date(2063, 11, 2)) == "'2063-11-02'" 401 | 402 | 403 | def test_literal_TimeDelta(): 404 | assert literal(TimeDelta(22, 13, 2)) == "'22 days 13 seconds 2 microseconds'" 405 | 406 | 407 | def test_literal_Datetime(): 408 | assert literal(DateTime(2063, 3, 31, 22, 13, 2)) == "'2063-03-31T22:13:02'" 409 | 410 | 411 | def test_literal_Trojan(): 412 | class Trojan: 413 | def __str__(self): 
414 | return "A Gift" 415 | 416 | assert literal(Trojan()) == "'A Gift'" 417 | -------------------------------------------------------------------------------- /test/test_readme.py: -------------------------------------------------------------------------------- 1 | import doctest 2 | 3 | from pathlib import Path 4 | 5 | 6 | def test_readme(): 7 | failure_count, _ = doctest.testfile( 8 | str(Path("..") / "README.md"), verbose=False, optionflags=doctest.ELLIPSIS 9 | ) 10 | assert failure_count == 0 11 | -------------------------------------------------------------------------------- /test/test_types.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pg8000.types import PGInterval, Range 4 | 5 | 6 | def test_PGInterval_init(): 7 | i = PGInterval(days=1) 8 | assert i.months is None 9 | assert i.days == 1 10 | assert i.microseconds is None 11 | 12 | 13 | def test_PGInterval_repr(): 14 | v = PGInterval(microseconds=123456789, days=2, months=24) 15 | assert repr(v) == "" 16 | 17 | 18 | def test_PGInterval_str(): 19 | v = PGInterval(microseconds=123456789, days=2, months=24, millennia=2) 20 | assert str(v) == "2 millennia 24 months 2 days 123456789 microseconds" 21 | 22 | 23 | @pytest.mark.parametrize( 24 | "value,expected", 25 | [ 26 | ("P1Y2M", PGInterval(years=1, months=2)), 27 | ("P12DT30S", PGInterval(days=12, seconds=30)), 28 | ( 29 | "P-1Y-2M3DT-4H-5M-6S", 30 | PGInterval(years=-1, months=-2, days=3, hours=-4, minutes=-5, seconds=-6), 31 | ), 32 | ("PT1M32.32S", PGInterval(minutes=1, seconds=32.32)), 33 | ], 34 | ) 35 | def test_PGInterval_from_str_iso_8601(value, expected): 36 | interval = PGInterval.from_str_iso_8601(value) 37 | assert interval == expected 38 | 39 | 40 | @pytest.mark.parametrize( 41 | "value,expected", 42 | [ 43 | ("@ 1 year 2 mons", PGInterval(years=1, months=2)), 44 | ( 45 | "@ 3 days 4 hours 5 mins 6 secs", 46 | PGInterval(days=3, hours=4, minutes=5, seconds=6), 47 | ), 48 | ( 49 
| "@ 1 year 2 mons -3 days 4 hours 5 mins 6 secs ago", 50 | PGInterval(years=-1, months=-2, days=3, hours=-4, minutes=-5, seconds=-6), 51 | ), 52 | ( 53 | "@ 1 millennium -2 mons", 54 | PGInterval(millennia=1, months=-2), 55 | ), 56 | ], 57 | ) 58 | def test_PGInterval_from_str_postgres(value, expected): 59 | interval = PGInterval.from_str_postgres(value) 60 | assert interval == expected 61 | 62 | 63 | @pytest.mark.parametrize( 64 | "value,expected", 65 | [ 66 | ["1-2", PGInterval(years=1, months=2)], 67 | ["3 4:05:06", PGInterval(days=3, hours=4, minutes=5, seconds=6)], 68 | [ 69 | "-1-2 +3 -4:05:06", 70 | PGInterval(years=-1, months=-2, days=3, hours=-4, minutes=-5, seconds=-6), 71 | ], 72 | ["8 4:00:32.32", PGInterval(days=8, hours=4, minutes=0, seconds=32.32)], 73 | ], 74 | ) 75 | def test_PGInterval_from_str_sql_standard(value, expected): 76 | interval = PGInterval.from_str_sql_standard(value) 77 | assert interval == expected 78 | 79 | 80 | @pytest.mark.parametrize( 81 | "value,expected", 82 | [ 83 | ("P12DT30S", PGInterval(days=12, seconds=30)), 84 | ("@ 1 year 2 mons", PGInterval(years=1, months=2)), 85 | ("1-2", PGInterval(years=1, months=2)), 86 | ("3 4:05:06", PGInterval(days=3, hours=4, minutes=5, seconds=6)), 87 | ( 88 | "-1-2 +3 -4:05:06", 89 | PGInterval(years=-1, months=-2, days=3, hours=-4, minutes=-5, seconds=-6), 90 | ), 91 | ("00:00:30", PGInterval(seconds=30)), 92 | ], 93 | ) 94 | def test_PGInterval_from_str(value, expected): 95 | interval = PGInterval.from_str(value) 96 | assert interval == expected 97 | 98 | 99 | def test_Range_equals(): 100 | pg_range_a = Range("[", 1, 2, ")") 101 | pg_range_b = Range("[", 1, 2, ")") 102 | assert pg_range_a == pg_range_b 103 | 104 | 105 | def test_Range_str(): 106 | v = Range(5, 6) 107 | assert str(v) == "[5,6)" 108 | -------------------------------------------------------------------------------- /test/test_utils.py: -------------------------------------------------------------------------------- 1 | from 
test.utils import parse_server_version 2 | 3 | 4 | def test_parse_server_version(): 5 | assert parse_server_version("17rc1") == 17 6 | -------------------------------------------------------------------------------- /test/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def parse_server_version(version): 5 | major = re.match(r"\d+", version).group() # leading digits in 17.0, 17rc1 6 | return int(major) 7 | --------------------------------------------------------------------------------