├── .github
└── workflows
│ └── test.yml
├── .gitignore
├── LICENSE
├── README.md
├── pyproject.toml
├── src
└── pg8000
│ ├── __init__.py
│ ├── converters.py
│ ├── core.py
│ ├── dbapi.py
│ ├── exceptions.py
│ ├── legacy.py
│ ├── native.py
│ └── types.py
└── test
├── __init__.py
├── dbapi
├── __init__.py
├── auth
│ ├── __init__.py
│ ├── test_gss.py
│ ├── test_md5.py
│ ├── test_md5_ssl.py
│ ├── test_password.py
│ ├── test_scram-sha-256.py
│ └── test_scram-sha-256_ssl.py
├── conftest.py
├── test_benchmarks.py
├── test_connection.py
├── test_copy.py
├── test_dbapi.py
├── test_dbapi20.py
├── test_paramstyle.py
├── test_query.py
└── test_typeconversion.py
├── legacy
├── __init__.py
├── auth
│ ├── __init__.py
│ ├── test_gss.py
│ ├── test_md5.py
│ ├── test_md5_ssl.py
│ ├── test_password.py
│ ├── test_scram-sha-256.py
│ └── test_scram-sha-256_ssl.py
├── conftest.py
├── stress.py
├── test_benchmarks.py
├── test_connection.py
├── test_copy.py
├── test_dbapi.py
├── test_dbapi20.py
├── test_error_recovery.py
├── test_paramstyle.py
├── test_prepared_statement.py
├── test_query.py
├── test_typeconversion.py
└── test_typeobjects.py
├── native
├── __init__.py
├── auth
│ ├── __init__.py
│ ├── test_gss.py
│ ├── test_md5.py
│ ├── test_md5_ssl.py
│ ├── test_password.py
│ ├── test_scram-sha-256.py
│ └── test_scram-sha-256_ssl.py
├── conftest.py
├── test_benchmarks.py
├── test_connection.py
├── test_copy.py
├── test_core.py
├── test_import_all.py
├── test_prepared_statement.py
├── test_query.py
└── test_typeconversion.py
├── test_converters.py
├── test_readme.py
├── test_types.py
├── test_utils.py
└── utils.py
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: pg8000
2 |
3 | on: [push]
4 |
5 | permissions: read-all
6 |
7 | jobs:
8 | main-test:
9 |
10 | runs-on: ubuntu-latest
11 | strategy:
12 | matrix:
13 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
14 | postgresql-version: [16, 15, 14, 13, 12]
15 |
16 | container:
17 | image: python:${{ matrix.python-version }}
18 | env:
19 | PGHOST: postgres
20 | PGPASSWORD: postgres
21 | PGUSER: postgres
22 |
23 | services:
24 | postgres:
25 | image: postgres:${{ matrix.postgresql-version }}
26 | env:
27 | POSTGRES_PASSWORD: postgres
28 | # Set health checks to wait until postgres has started
29 | options: >-
30 | --health-cmd pg_isready
31 | --health-interval 10s
32 | --health-timeout 5s
33 | --health-retries 5
34 |
35 | steps:
36 | - uses: actions/checkout@v4
37 | - name: Install dependencies
38 | run: |
39 | git config --global --add safe.directory "$GITHUB_WORKSPACE"
40 | python -m pip install --no-cache-dir --upgrade pip
41 | pip install --no-cache-dir --root-user-action ignore pytest pytest-mock pytest-benchmark pytz .
42 | - name: Set up Postgresql
43 | run: |
44 | apt-get update
45 | apt-get install --yes --no-install-recommends postgresql-client
46 | psql -c "CREATE EXTENSION hstore;"
47 | psql -c "SELECT pg_reload_conf()"
48 | - name: Test with pytest
49 | run: |
50 | python -m pytest -x -v -W error --ignore=test/dbapi/auth/ --ignore=test/legacy/auth/ --ignore=test/native/auth/ test --ignore=test/test_readme.py
51 |
52 | auth-test:
53 | runs-on: ubuntu-latest
54 | strategy:
55 | matrix:
56 | python-version: ["3.11"]
57 | postgresql-version: ["16", "12"]
58 | auth-type: [md5, gss, password, scram-sha-256]
59 |
60 | container:
61 | image: python:${{ matrix.python-version }}
62 | env:
63 | PGHOST: postgres
64 | PGPASSWORD: postgres
65 | PGUSER: postgres
66 | PIP_ROOT_USER_ACTION: ignore
67 |
68 | services:
69 | postgres:
70 | image: postgres:${{ matrix.postgresql-version }}
71 | env:
72 | POSTGRES_PASSWORD: postgres
73 | POSTGRES_HOST_AUTH_METHOD: ${{ matrix.auth-type }}
74 | POSTGRES_INITDB_ARGS: "${{ matrix.auth-type == 'scram-sha-256' && '--auth-host=scram-sha-256' || '' }}"
75 | # Set health checks to wait until postgres has started
76 | options: >-
77 | --health-cmd pg_isready
78 | --health-interval 10s
79 | --health-timeout 5s
80 | --health-retries 5
81 |
82 | steps:
83 | - uses: actions/checkout@v4
84 | - name: Install dependencies
85 | run: |
86 | git config --global --add safe.directory "$GITHUB_WORKSPACE"
87 | python -m pip install --no-cache-dir --upgrade pip
88 | pip install --no-cache-dir pytest pytest-mock pytest-benchmark pytz .
89 | - name: Test with pytest
90 | run: |
91 | python -m pytest -x -v -W error test/dbapi/auth/test_${{ matrix.auth-type }}.py test/native/auth/test_${{ matrix.auth-type }}.py test/legacy/auth/test_${{ matrix.auth-type }}.py
92 |
93 | ssl-test:
94 | runs-on: ubuntu-latest
95 |
96 | strategy:
97 | matrix:
98 | auth-type: [md5, scram-sha-256]
99 |
100 | services:
101 | postgres:
102 | image: postgres
103 | env:
104 | POSTGRES_PASSWORD: postgres
105 | POSTGRES_HOST_AUTH_METHOD: ${{ matrix.auth-type }}
106 | POSTGRES_INITDB_ARGS: "${{ matrix.auth-type == 'scram-sha-256' && '--auth-host=scram-sha-256' || '' }}"
107 | options: >-
108 | --health-cmd pg_isready
109 | --health-interval 10s
110 | --health-timeout 5s
111 | --health-retries 5
112 | ports:
113 | - 5432:5432
114 |
115 | steps:
116 | - name: Configure Postgres
117 | env:
118 | PGPASSWORD: postgres
119 | PGUSER: postgres
120 | PGHOST: localhost
121 | run: |
122 | sudo apt update
123 | sudo apt install --yes --no-install-recommends postgresql-client
124 | psql -c "ALTER SYSTEM SET ssl = on;"
125 | psql -c "ALTER SYSTEM SET ssl_cert_file = '/etc/ssl/certs/ssl-cert-snakeoil.pem'"
126 | psql -c "ALTER SYSTEM SET ssl_key_file = '/etc/ssl/private/ssl-cert-snakeoil.key'"
127 | psql -c "SELECT pg_reload_conf()"
128 |
129 | - name: Check out repository code
130 | uses: actions/checkout@v4
131 | - name: Set up Python
132 | uses: actions/setup-python@v5
133 | with:
134 | python-version: "3.11"
135 | - name: Install dependencies
136 | run: |
137 | python -m pip install --upgrade pip
138 | pip install pytest pytest-mock pytest-benchmark pytz .
139 | - name: SSL Test
140 | env:
141 | PGPASSWORD: postgres
142 | USER: postgres
143 | run: |
144 | python -m pytest -x -v -W error test/dbapi/auth/test_${{ matrix.auth-type}}_ssl.py test/native/auth/test_${{ matrix.auth-type}}_ssl.py test/legacy/auth/test_${{ matrix.auth-type}}_ssl.py
145 |
146 | static-test:
147 | runs-on: ubuntu-latest
148 |
149 | services:
150 | postgres:
151 | image: postgres:12
152 | env:
153 | POSTGRES_PASSWORD: cpsnow
154 | options: >-
155 | --health-cmd pg_isready
156 | --health-interval 10s
157 | --health-timeout 5s
158 | --health-retries 5
159 | ports:
160 | - 5432:5432
161 |
162 | steps:
163 | - name: Check out repository code
164 | uses: actions/checkout@v4
165 | - name: Set up Python
166 | uses: actions/setup-python@v5
167 | with:
168 | python-version: "3.11"
169 | - name: Install dependencies
170 | run: |
171 | psql postgresql://postgres:cpsnow@localhost -c "ALTER SYSTEM SET ssl = on;"
172 | psql postgresql://postgres:cpsnow@localhost -c "ALTER SYSTEM SET ssl_cert_file = '/etc/ssl/certs/ssl-cert-snakeoil.pem'"
173 | psql postgresql://postgres:cpsnow@localhost -c "ALTER SYSTEM SET ssl_key_file = '/etc/ssl/private/ssl-cert-snakeoil.key'"
174 | psql postgresql://postgres:cpsnow@localhost localhost -c "SELECT pg_reload_conf()"
175 | python -m pip install --upgrade pip
176 | pip install black build flake8 pytest flake8-alphabetize Flake8-pyproject \
177 | twine .
178 | - name: Lint check
179 | run: |
180 | black --check .
181 | flake8 .
182 | - name: Doctest
183 | env:
184 | PGPASSWORD: cpsnow
185 | USER: postgres
186 | run: |
187 | python -m pytest -x -v -W error test/test_readme.py
188 | - name: Check Distribution
189 | run: |
190 | python -m build
191 | twine check dist/*
192 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.py[co]
2 | *.swp
3 | *.orig
4 | *.class
5 | build
6 | pg8000.egg-info
7 | tmp
8 | dist
9 | .tox
10 | MANIFEST
11 | venv
12 | .cache
13 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright Mathieu Fenniak and Contributors to pg8000.
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "hatchling",
4 | "versioningit >= 3.1.2",
5 | ]
6 | build-backend = "hatchling.build"
7 |
8 | [project]
9 | name = "pg8000"
10 | authors = [{name = "The Contributors"}]
11 | description = "PostgreSQL interface library"
12 | readme = "README.md"
13 | requires-python = ">=3.8"
14 | keywords= ["postgresql", "dbapi"]
15 | license = {text = "BSD 3-Clause License"}
16 | classifiers = [
17 | "Development Status :: 5 - Production/Stable",
18 | "Intended Audience :: Developers",
19 | "License :: OSI Approved :: BSD License",
20 | "Programming Language :: Python",
21 | "Programming Language :: Python :: 3",
22 | "Programming Language :: Python :: 3.8",
23 | "Programming Language :: Python :: 3.9",
24 | "Programming Language :: Python :: 3.10",
25 | "Programming Language :: Python :: 3.11",
26 | "Programming Language :: Python :: 3.12",
27 | "Programming Language :: Python :: Implementation",
28 | "Programming Language :: Python :: Implementation :: CPython",
29 | "Programming Language :: Python :: Implementation :: PyPy",
30 | "Operating System :: OS Independent",
31 | "Topic :: Database :: Front-Ends",
32 | "Topic :: Software Development :: Libraries :: Python Modules",
33 | ]
34 | dependencies = [
35 | "scramp >= 1.4.5",
36 | 'python-dateutil >= 2.8.2',
37 | ]
38 | dynamic = ["version"]
39 |
40 | [project.urls]
41 | Homepage = "https://github.com/tlocke/pg8000"
42 |
43 | [tool.hatch.version]
44 | source = "versioningit"
45 |
46 | [tool.versioningit]
47 |
48 | [tool.versioningit.vcs]
49 | method = "git"
50 | default-tag = "0.0.0"
51 |
52 | [tool.flake8]
53 | application-names = ['pg8000']
54 | ignore = ['E203', 'W503']
55 | max-line-length = 88
56 | exclude = ['.git', '__pycache__', 'build', 'dist', 'venv', '.tox']
57 | application-import-names = ['pg8000']
58 |
59 | [tool.tox]
60 | legacy_tox_ini = """
61 | [tox]
62 | isolated_build = True
63 | envlist = py
64 |
65 | [testenv]
66 | passenv = PGPORT
67 | allowlist_externals=/usr/bin/rm
68 | commands =
69 | black --check .
70 | flake8 .
71 | python -m pytest -x -v -W error test
72 | rm -rf dist
73 | python -m build
74 | twine check dist/*
75 | deps =
76 | build
77 | pytest
78 | pytest-mock
79 | pytest-benchmark
80 | black
81 | flake8
82 | flake8-alphabetize
83 | Flake8-pyproject
84 | pytz
85 | twine
86 | """
87 |
--------------------------------------------------------------------------------
/src/pg8000/__init__.py:
--------------------------------------------------------------------------------
1 | from pg8000.legacy import (
2 | BIGINTEGER,
3 | BINARY,
4 | BOOLEAN,
5 | BOOLEAN_ARRAY,
6 | BYTES,
7 | Binary,
8 | CHAR,
9 | CHAR_ARRAY,
10 | Connection,
11 | Cursor,
12 | DATE,
13 | DATETIME,
14 | DECIMAL,
15 | DECIMAL_ARRAY,
16 | DataError,
17 | DatabaseError,
18 | Date,
19 | DateFromTicks,
20 | Error,
21 | FLOAT,
22 | FLOAT_ARRAY,
23 | INET,
24 | INT2VECTOR,
25 | INTEGER,
26 | INTEGER_ARRAY,
27 | INTERVAL,
28 | IntegrityError,
29 | InterfaceError,
30 | InternalError,
31 | JSON,
32 | JSONB,
33 | MACADDR,
34 | NAME,
35 | NAME_ARRAY,
36 | NULLTYPE,
37 | NUMBER,
38 | NotSupportedError,
39 | OID,
40 | OperationalError,
41 | PGInterval,
42 | ProgrammingError,
43 | ROWID,
44 | Range,
45 | STRING,
46 | TEXT,
47 | TEXT_ARRAY,
48 | TIME,
49 | TIMEDELTA,
50 | TIMESTAMP,
51 | TIMESTAMPTZ,
52 | Time,
53 | TimeFromTicks,
54 | Timestamp,
55 | TimestampFromTicks,
56 | UNKNOWN,
57 | UUID_TYPE,
58 | VARCHAR,
59 | VARCHAR_ARRAY,
60 | Warning,
61 | XID,
62 | __version__,
63 | pginterval_in,
64 | pginterval_out,
65 | timedelta_in,
66 | )
67 |
68 | # Copyright (c) 2007-2009, Mathieu Fenniak
69 | # Copyright (c) The Contributors
70 | # All rights reserved.
71 | #
72 | # Redistribution and use in source and binary forms, with or without
73 | # modification, are permitted provided that the following conditions are
74 | # met:
75 | #
76 | # * Redistributions of source code must retain the above copyright notice,
77 | # this list of conditions and the following disclaimer.
78 | # * Redistributions in binary form must reproduce the above copyright notice,
79 | # this list of conditions and the following disclaimer in the documentation
80 | # and/or other materials provided with the distribution.
81 | # * The name of the author may not be used to endorse or promote products
82 | # derived from this software without specific prior written permission.
83 | #
84 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
85 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
86 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
87 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
88 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
89 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
90 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
91 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
92 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
93 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
94 | # POSSIBILITY OF SUCH DAMAGE.
95 |
96 |
97 | def connect(
98 | user,
99 | host="localhost",
100 | database=None,
101 | port=5432,
102 | password=None,
103 | source_address=None,
104 | unix_sock=None,
105 | ssl_context=None,
106 | timeout=None,
107 | tcp_keepalive=True,
108 | application_name=None,
109 | replication=None,
110 | startup_params=None,
111 | ):
112 | return Connection(
113 | user,
114 | host=host,
115 | database=database,
116 | port=port,
117 | password=password,
118 | source_address=source_address,
119 | unix_sock=unix_sock,
120 | ssl_context=ssl_context,
121 | timeout=timeout,
122 | tcp_keepalive=tcp_keepalive,
123 | application_name=application_name,
124 | replication=replication,
125 | startup_params=startup_params,
126 | )
127 |
128 |
129 | apilevel = "2.0"
130 | """The DBAPI level supported, currently "2.0".
131 |
132 | This property is part of the `DBAPI 2.0 specification
133 | `_.
134 | """
135 |
136 | threadsafety = 1
137 | """Integer constant stating the level of thread safety the DBAPI interface
138 | supports. This DBAPI module supports sharing of the module only. Connections
139 | and cursors my not be shared between threads. This gives pg8000 a threadsafety
140 | value of 1.
141 |
142 | This property is part of the `DBAPI 2.0 specification
143 | `_.
144 | """
145 |
146 | paramstyle = "format"
147 |
148 |
149 | __all__ = [
150 | "BIGINTEGER",
151 | "BINARY",
152 | "BOOLEAN",
153 | "BOOLEAN_ARRAY",
154 | "BYTES",
155 | "Binary",
156 | "CHAR",
157 | "CHAR_ARRAY",
158 | "Connection",
159 | "Cursor",
160 | "DATE",
161 | "DATETIME",
162 | "DECIMAL",
163 | "DECIMAL_ARRAY",
164 | "DataError",
165 | "DatabaseError",
166 | "Date",
167 | "DateFromTicks",
168 | "Error",
169 | "FLOAT",
170 | "FLOAT_ARRAY",
171 | "INET",
172 | "INT2VECTOR",
173 | "INTEGER",
174 | "INTEGER_ARRAY",
175 | "INTERVAL",
176 | "IntegrityError",
177 | "InterfaceError",
178 | "InternalError",
179 | "JSON",
180 | "JSONB",
181 | "MACADDR",
182 | "NAME",
183 | "NAME_ARRAY",
184 | "NULLTYPE",
185 | "NUMBER",
186 | "NotSupportedError",
187 | "OID",
188 | "OperationalError",
189 | "PGInterval",
190 | "ProgrammingError",
191 | "ROWID",
192 | "Range",
193 | "STRING",
194 | "TEXT",
195 | "TEXT_ARRAY",
196 | "TIME",
197 | "TIMEDELTA",
198 | "TIMESTAMP",
199 | "TIMESTAMPTZ",
200 | "Time",
201 | "TimeFromTicks",
202 | "Timestamp",
203 | "TimestampFromTicks",
204 | "UNKNOWN",
205 | "UUID_TYPE",
206 | "VARCHAR",
207 | "VARCHAR_ARRAY",
208 | "Warning",
209 | "XID",
210 | "__version__",
211 | "connect",
212 | "pginterval_in",
213 | "pginterval_out",
214 | "timedelta_in",
215 | ]
216 |
--------------------------------------------------------------------------------
/src/pg8000/converters.py:
--------------------------------------------------------------------------------
1 | from datetime import (
2 | date as Date,
3 | datetime as Datetime,
4 | time as Time,
5 | timedelta as Timedelta,
6 | timezone as Timezone,
7 | )
8 | from decimal import Decimal
9 | from enum import Enum
10 | from ipaddress import (
11 | IPv4Address,
12 | IPv4Network,
13 | IPv6Address,
14 | IPv6Network,
15 | ip_address,
16 | ip_network,
17 | )
18 | from json import dumps, loads
19 | from uuid import UUID
20 |
21 | from dateutil.parser import ParserError, parse
22 |
23 | from pg8000.exceptions import InterfaceError
24 | from pg8000.types import PGInterval, Range
25 |
26 |
27 | ANY_ARRAY = 2277
28 | BIGINT = 20
29 | BIGINT_ARRAY = 1016
30 | BOOLEAN = 16
31 | BOOLEAN_ARRAY = 1000
32 | BYTES = 17
33 | BYTES_ARRAY = 1001
34 | CHAR = 1042
35 | CHAR_ARRAY = 1014
36 | CIDR = 650
37 | CIDR_ARRAY = 651
38 | CSTRING = 2275
39 | CSTRING_ARRAY = 1263
40 | DATE = 1082
41 | DATE_ARRAY = 1182
42 | DATEMULTIRANGE = 4535
43 | DATEMULTIRANGE_ARRAY = 6155
44 | DATERANGE = 3912
45 | DATERANGE_ARRAY = 3913
46 | FLOAT = 701
47 | FLOAT_ARRAY = 1022
48 | INET = 869
49 | INET_ARRAY = 1041
50 | INT2VECTOR = 22
51 | INT4MULTIRANGE = 4451
52 | INT4MULTIRANGE_ARRAY = 6150
53 | INT4RANGE = 3904
54 | INT4RANGE_ARRAY = 3905
55 | INT8MULTIRANGE = 4536
56 | INT8MULTIRANGE_ARRAY = 6157
57 | INT8RANGE = 3926
58 | INT8RANGE_ARRAY = 3927
59 | INTEGER = 23
60 | INTEGER_ARRAY = 1007
61 | INTERVAL = 1186
62 | INTERVAL_ARRAY = 1187
63 | OID = 26
64 | JSON = 114
65 | JSON_ARRAY = 199
66 | JSONB = 3802
67 | JSONB_ARRAY = 3807
68 | MACADDR = 829
69 | MONEY = 790
70 | MONEY_ARRAY = 791
71 | NAME = 19
72 | NAME_ARRAY = 1003
73 | NUMERIC = 1700
74 | NUMERIC_ARRAY = 1231
75 | NUMRANGE = 3906
76 | NUMRANGE_ARRAY = 3907
77 | NUMMULTIRANGE = 4532
78 | NUMMULTIRANGE_ARRAY = 6151
79 | NULLTYPE = -1
80 | OID = 26
81 | POINT = 600
82 | REAL = 700
83 | REAL_ARRAY = 1021
84 | RECORD = 2249
85 | SMALLINT = 21
86 | SMALLINT_ARRAY = 1005
87 | SMALLINT_VECTOR = 22
88 | STRING = 1043
89 | TEXT = 25
90 | TEXT_ARRAY = 1009
91 | TIME = 1083
92 | TIME_ARRAY = 1183
93 | TIMESTAMP = 1114
94 | TIMESTAMP_ARRAY = 1115
95 | TIMESTAMPTZ = 1184
96 | TIMESTAMPTZ_ARRAY = 1185
97 | TSMULTIRANGE = 4533
98 | TSMULTIRANGE_ARRAY = 6152
99 | TSRANGE = 3908
100 | TSRANGE_ARRAY = 3909
101 | TSTZMULTIRANGE = 4534
102 | TSTZMULTIRANGE_ARRAY = 6153
103 | TSTZRANGE = 3910
104 | TSTZRANGE_ARRAY = 3911
105 | UNKNOWN = 705
106 | UUID_TYPE = 2950
107 | UUID_ARRAY = 2951
108 | VARCHAR = 1043
109 | VARCHAR_ARRAY = 1015
110 | XID = 28
111 |
112 |
113 | MIN_INT2, MAX_INT2 = -(2**15), 2**15
114 | MIN_INT4, MAX_INT4 = -(2**31), 2**31
115 | MIN_INT8, MAX_INT8 = -(2**63), 2**63
116 |
117 |
118 | def bool_in(data):
119 | return data == "t"
120 |
121 |
122 | def bool_out(v):
123 | return "true" if v else "false"
124 |
125 |
126 | def bytes_in(data):
127 | return bytes.fromhex(data[2:])
128 |
129 |
130 | def bytes_out(v):
131 | return "\\x" + v.hex()
132 |
133 |
134 | def cidr_out(v):
135 | return str(v)
136 |
137 |
138 | def cidr_in(data):
139 | return ip_network(data, False) if "/" in data else ip_address(data)
140 |
141 |
142 | def date_in(data):
143 | if data in ("infinity", "-infinity"):
144 | return data
145 | else:
146 | try:
147 | return Datetime.strptime(data, "%Y-%m-%d").date()
148 | except ValueError:
149 | # pg date can overflow Python Datetime
150 | return data
151 |
152 |
153 | def date_out(v):
154 | return v.isoformat()
155 |
156 |
157 | def datetime_out(v):
158 | if v.tzinfo is None:
159 | return v.isoformat()
160 | else:
161 | return v.astimezone(Timezone.utc).isoformat()
162 |
163 |
164 | def enum_out(v):
165 | return str(v.value)
166 |
167 |
168 | def float_out(v):
169 | return str(v)
170 |
171 |
172 | def inet_in(data):
173 | return ip_network(data, False) if "/" in data else ip_address(data)
174 |
175 |
176 | def inet_out(v):
177 | return str(v)
178 |
179 |
180 | def int_in(data):
181 | return int(data)
182 |
183 |
184 | def int_out(v):
185 | return str(v)
186 |
187 |
188 | def interval_in(data):
189 | pg_interval = PGInterval.from_str(data)
190 | try:
191 | return pg_interval.to_timedelta()
192 | except ValueError:
193 | return pg_interval
194 |
195 |
196 | def interval_out(v):
197 | return f"{v.days} days {v.seconds} seconds {v.microseconds} microseconds"
198 |
199 |
200 | def json_in(data):
201 | return loads(data)
202 |
203 |
204 | def json_out(v):
205 | return dumps(v)
206 |
207 |
208 | def null_out(v):
209 | return None
210 |
211 |
212 | def numeric_in(data):
213 | return Decimal(data)
214 |
215 |
216 | def numeric_out(d):
217 | return str(d)
218 |
219 |
220 | def point_in(data):
221 | return tuple(map(float, data[1:-1].split(",")))
222 |
223 |
224 | def pg_interval_in(data):
225 | return PGInterval.from_str(data)
226 |
227 |
228 | def pg_interval_out(v):
229 | return str(v)
230 |
231 |
232 | def range_out(v):
233 | if v.is_empty:
234 | return "empty"
235 | else:
236 | le = v.lower
237 | val_lower = "" if le is None else make_param(PY_TYPES, le)
238 | ue = v.upper
239 | val_upper = "" if ue is None else make_param(PY_TYPES, ue)
240 | return f"{v.bounds[0]}{val_lower},{val_upper}{v.bounds[1]}"
241 |
242 |
243 | def string_in(data):
244 | return data
245 |
246 |
247 | def string_out(v):
248 | return v
249 |
250 |
251 | def time_in(data):
252 | pattern = "%H:%M:%S.%f" if "." in data else "%H:%M:%S"
253 | return Datetime.strptime(data, pattern).time()
254 |
255 |
256 | def time_out(v):
257 | return v.isoformat()
258 |
259 |
260 | def timestamp_in(data):
261 | if data in ("infinity", "-infinity"):
262 | return data
263 |
264 | try:
265 | pattern = "%Y-%m-%d %H:%M:%S.%f" if "." in data else "%Y-%m-%d %H:%M:%S"
266 | return Datetime.strptime(data, pattern)
267 | except ValueError:
268 | try:
269 | return parse(data)
270 | except ParserError:
271 | # pg timestamp can overflow Python Datetime
272 | return data
273 |
274 |
275 | def timestamptz_in(data):
276 | if data in ("infinity", "-infinity"):
277 | return data
278 |
279 | try:
280 | patt = "%Y-%m-%d %H:%M:%S.%f%z" if "." in data else "%Y-%m-%d %H:%M:%S%z"
281 | return Datetime.strptime(f"{data}00", patt)
282 | except ValueError:
283 | try:
284 | return parse(data)
285 | except ParserError:
286 | # pg timestamptz can overflow Python Datetime
287 | return data
288 |
289 |
290 | def unknown_out(v):
291 | return str(v)
292 |
293 |
294 | def vector_in(data):
295 | return [int(v) for v in data.split()]
296 |
297 |
298 | def uuid_out(v):
299 | return str(v)
300 |
301 |
302 | def uuid_in(data):
303 | return UUID(data)
304 |
305 |
306 | def _range_in(elem_func):
307 | def range_in(data):
308 | if data == "empty":
309 | return Range(is_empty=True)
310 | else:
311 | le, ue = [None if v == "" else elem_func(v) for v in data[1:-1].split(",")]
312 | return Range(le, ue, bounds=f"{data[0]}{data[-1]}")
313 |
314 | return range_in
315 |
316 |
317 | daterange_in = _range_in(date_in)
318 | int4range_in = _range_in(int)
319 | int8range_in = _range_in(int)
320 | numrange_in = _range_in(Decimal)
321 |
322 |
323 | def ts_in(data):
324 | return timestamp_in(data[1:-1])
325 |
326 |
327 | def tstz_in(data):
328 | return timestamptz_in(data[1:-1])
329 |
330 |
331 | tsrange_in = _range_in(ts_in)
332 | tstzrange_in = _range_in(tstz_in)
333 |
334 |
335 | def _multirange_in(adapter):
336 | def f(data):
337 | in_range = False
338 | result = []
339 | val = []
340 | for c in data:
341 | if in_range:
342 | val.append(c)
343 | if c in "])":
344 | value = "".join(val)
345 | val.clear()
346 | result.append(adapter(value))
347 | in_range = False
348 | elif c in "[(":
349 | val.append(c)
350 | in_range = True
351 |
352 | return result
353 |
354 | return f
355 |
356 |
357 | datemultirange_in = _multirange_in(daterange_in)
358 | int4multirange_in = _multirange_in(int4range_in)
359 | int8multirange_in = _multirange_in(int8range_in)
360 | nummultirange_in = _multirange_in(numrange_in)
361 | tsmultirange_in = _multirange_in(tsrange_in)
362 | tstzmultirange_in = _multirange_in(tstzrange_in)
363 |
364 |
365 | class ParserState(Enum):
366 | InString = 1
367 | InEscape = 2
368 | InValue = 3
369 | Out = 4
370 |
371 |
372 | def _parse_array(data, adapter):
373 | state = ParserState.Out
374 | stack = [[]]
375 | val = []
376 | for c in data:
377 | if state == ParserState.InValue:
378 | if c in ("}", ","):
379 | value = "".join(val)
380 | stack[-1].append(None if value == "NULL" else adapter(value))
381 | state = ParserState.Out
382 | else:
383 | val.append(c)
384 |
385 | if state == ParserState.Out:
386 | if c == "{":
387 | a = []
388 | stack[-1].append(a)
389 | stack.append(a)
390 | elif c == "}":
391 | stack.pop()
392 | elif c == ",":
393 | pass
394 | elif c == '"':
395 | val = []
396 | state = ParserState.InString
397 | else:
398 | val = [c]
399 | state = ParserState.InValue
400 |
401 | elif state == ParserState.InString:
402 | if c == '"':
403 | stack[-1].append(adapter("".join(val)))
404 | state = ParserState.Out
405 | elif c == "\\":
406 | state = ParserState.InEscape
407 | else:
408 | val.append(c)
409 | elif state == ParserState.InEscape:
410 | val.append(c)
411 | state = ParserState.InString
412 |
413 | return stack[0][0]
414 |
415 |
416 | def _array_in(adapter):
417 | def f(data):
418 | return _parse_array(data, adapter)
419 |
420 | return f
421 |
422 |
423 | bool_array_in = _array_in(bool_in)
424 | bytes_array_in = _array_in(bytes_in)
425 | cidr_array_in = _array_in(cidr_in)
426 | date_array_in = _array_in(date_in)
427 | datemultirange_array_in = _array_in(datemultirange_in)
428 | daterange_array_in = _array_in(daterange_in)
429 | inet_array_in = _array_in(inet_in)
430 | int_array_in = _array_in(int)
431 | int4multirange_array_in = _array_in(int4multirange_in)
432 | int4range_array_in = _array_in(int4range_in)
433 | int8multirange_array_in = _array_in(int8multirange_in)
434 | int8range_array_in = _array_in(int8range_in)
435 | interval_array_in = _array_in(interval_in)
436 | json_array_in = _array_in(json_in)
437 | float_array_in = _array_in(float)
438 | numeric_array_in = _array_in(numeric_in)
439 | nummultirange_array_in = _array_in(nummultirange_in)
440 | numrange_array_in = _array_in(numrange_in)
441 | string_array_in = _array_in(string_in)
442 | time_array_in = _array_in(time_in)
443 | timestamp_array_in = _array_in(timestamp_in)
444 | timestamptz_array_in = _array_in(timestamptz_in)
445 | tsrange_array_in = _array_in(tsrange_in)
446 | tsmultirange_array_in = _array_in(tsmultirange_in)
447 | tstzmultirange_array_in = _array_in(tstzmultirange_in)
448 | tstzrange_array_in = _array_in(tstzrange_in)
449 | uuid_array_in = _array_in(uuid_in)
450 |
451 |
452 | def array_string_escape(v):
453 | cs = []
454 | for c in v:
455 | if c == "\\":
456 | cs.append("\\")
457 | elif c == '"':
458 | cs.append("\\")
459 | cs.append(c)
460 | val = "".join(cs)
461 | if (
462 | len(val) == 0
463 | or val == "NULL"
464 | or any(c.isspace() for c in val)
465 | or any(c in val for c in ("{", "}", ",", "\\"))
466 | ):
467 | val = f'"{val}"'
468 | return val
469 |
470 |
471 | def array_out(ar):
472 | result = []
473 | for v in ar:
474 | if isinstance(v, list):
475 | val = array_out(v)
476 |
477 | elif isinstance(v, tuple):
478 | val = f'"{composite_out(v)}"'
479 |
480 | elif v is None:
481 | val = "NULL"
482 |
483 | elif isinstance(v, dict):
484 | val = array_string_escape(json_out(v))
485 |
486 | elif isinstance(v, (bytes, bytearray)):
487 | val = f'"\\{bytes_out(v)}"'
488 |
489 | elif isinstance(v, str):
490 | val = array_string_escape(v)
491 |
492 | else:
493 | val = make_param(PY_TYPES, v)
494 |
495 | result.append(val)
496 |
497 | return f'{{{",".join(result)}}}'
498 |
499 |
500 | def composite_out(ar):
501 | result = []
502 | for v in ar:
503 | if isinstance(v, list):
504 | val = array_out(v)
505 |
506 | elif isinstance(v, tuple):
507 | val = composite_out(v)
508 |
509 | elif v is None:
510 | val = ""
511 |
512 | elif isinstance(v, dict):
513 | val = array_string_escape(json_out(v))
514 |
515 | elif isinstance(v, (bytes, bytearray)):
516 | val = f'"\\{bytes_out(v)}"'
517 |
518 | elif isinstance(v, str):
519 | val = array_string_escape(v)
520 |
521 | else:
522 | val = make_param(PY_TYPES, v)
523 |
524 | result.append(val)
525 |
526 | return f'({",".join(result)})'
527 |
528 |
529 | def record_in(data):
530 | state = ParserState.Out
531 | results = []
532 | val = []
533 | for c in data:
534 | if state == ParserState.InValue:
535 | if c in (")", ","):
536 | value = "".join(val)
537 | val.clear()
538 | results.append(None if value == "" else value)
539 | state = ParserState.Out
540 | else:
541 | val.append(c)
542 |
543 | if state == ParserState.Out:
544 | if c in "(),":
545 | pass
546 | elif c == '"':
547 | state = ParserState.InString
548 | else:
549 | val.append(c)
550 | state = ParserState.InValue
551 |
552 | elif state == ParserState.InString:
553 | if c == '"':
554 | results.append("".join(val))
555 | val.clear()
556 | state = ParserState.Out
557 | elif c == "\\":
558 | state = ParserState.InEscape
559 | else:
560 | val.append(c)
561 |
562 | elif state == ParserState.InEscape:
563 | val.append(c)
564 | state = ParserState.InString
565 |
566 | return tuple(results)
567 |
568 |
569 | PY_PG = {
570 | Date: DATE,
571 | Decimal: NUMERIC,
572 | IPv4Address: INET,
573 | IPv6Address: INET,
574 | IPv4Network: INET,
575 | IPv6Network: INET,
576 | PGInterval: INTERVAL,
577 | Time: TIME,
578 | Timedelta: INTERVAL,
579 | UUID: UUID_TYPE,
580 | bool: BOOLEAN,
581 | bytearray: BYTES,
582 | dict: JSONB,
583 | float: FLOAT,
584 | type(None): NULLTYPE,
585 | bytes: BYTES,
586 | str: TEXT,
587 | }
588 |
589 |
# Python type -> function that converts a value of that type into the form
# sent to the server. Looked up by exact type first (see make_param).
PY_TYPES = {
    Date: date_out,  # date
    Datetime: datetime_out,
    Decimal: numeric_out,  # numeric
    Enum: enum_out,  # enum
    IPv4Address: inet_out,  # inet
    IPv6Address: inet_out,  # inet
    IPv4Network: inet_out,  # inet
    IPv6Network: inet_out,  # inet
    PGInterval: interval_out,  # interval
    Range: range_out,  # range types
    Time: time_out,  # time
    Timedelta: interval_out,  # interval
    UUID: uuid_out,  # uuid
    bool: bool_out,  # bool
    bytearray: bytes_out,  # bytea
    dict: json_out,  # jsonb
    float: float_out,  # float8
    type(None): null_out,  # null
    bytes: bytes_out,  # bytea
    str: string_out,  # unknown
    int: int_out,
    list: array_out,  # arrays
    tuple: composite_out,  # composite (record) values
}
615 |
616 |
# PostgreSQL type constant -> function that turns a value received from the
# server into a Python object. Each key appears exactly once.
PG_TYPES = {
    BIGINT: int,  # int8
    BIGINT_ARRAY: int_array_in,  # int8[]
    BOOLEAN: bool_in,  # bool
    BOOLEAN_ARRAY: bool_array_in,  # bool[]
    BYTES: bytes_in,  # bytea
    BYTES_ARRAY: bytes_array_in,  # bytea[]
    CHAR: string_in,  # char
    CHAR_ARRAY: string_array_in,  # char[]
    CIDR_ARRAY: cidr_array_in,  # cidr[]
    CSTRING: string_in,  # cstring
    CSTRING_ARRAY: string_array_in,  # cstring[]
    DATE: date_in,  # date
    DATE_ARRAY: date_array_in,  # date[]
    DATEMULTIRANGE: datemultirange_in,  # datemultirange
    DATEMULTIRANGE_ARRAY: datemultirange_array_in,  # datemultirange[]
    DATERANGE: daterange_in,  # daterange
    DATERANGE_ARRAY: daterange_array_in,  # daterange[]
    FLOAT: float,  # float8
    FLOAT_ARRAY: float_array_in,  # float8[]
    INET: inet_in,  # inet
    INET_ARRAY: inet_array_in,  # inet[]
    INT4MULTIRANGE: int4multirange_in,  # int4multirange
    INT4MULTIRANGE_ARRAY: int4multirange_array_in,  # int4multirange[]
    INT4RANGE: int4range_in,  # int4range
    INT4RANGE_ARRAY: int4range_array_in,  # int4range[]
    INT8MULTIRANGE: int8multirange_in,  # int8multirange
    INT8MULTIRANGE_ARRAY: int8multirange_array_in,  # int8multirange[]
    INT8RANGE: int8range_in,  # int8range
    INT8RANGE_ARRAY: int8range_array_in,  # int8range[]
    INTEGER: int,  # int4
    INTEGER_ARRAY: int_array_in,  # int4[]
    JSON: json_in,  # json
    JSON_ARRAY: json_array_in,  # json[]
    JSONB: json_in,  # jsonb
    JSONB_ARRAY: json_array_in,  # jsonb[]
    MACADDR: string_in,  # macaddr
    MONEY: string_in,  # money
    MONEY_ARRAY: string_array_in,  # money[]
    NAME: string_in,  # name
    NAME_ARRAY: string_array_in,  # name[]
    NUMERIC: numeric_in,  # numeric
    NUMERIC_ARRAY: numeric_array_in,  # numeric[]
    NUMRANGE: numrange_in,  # numrange
    NUMRANGE_ARRAY: numrange_array_in,  # numrange[]
    NUMMULTIRANGE: nummultirange_in,  # nummultirange
    NUMMULTIRANGE_ARRAY: nummultirange_array_in,  # nummultirange[]
    OID: int,  # oid
    POINT: point_in,  # point
    INTERVAL: interval_in,  # interval
    INTERVAL_ARRAY: interval_array_in,  # interval[]
    REAL: float,  # float4
    REAL_ARRAY: float_array_in,  # float4[]
    RECORD: record_in,  # record
    SMALLINT: int,  # int2
    SMALLINT_ARRAY: int_array_in,  # int2[]
    SMALLINT_VECTOR: vector_in,  # int2vector
    TEXT: string_in,  # text
    TEXT_ARRAY: string_array_in,  # text[]
    TIME: time_in,  # time
    TIME_ARRAY: time_array_in,  # time[]
    TIMESTAMP: timestamp_in,  # timestamp
    TIMESTAMP_ARRAY: timestamp_array_in,  # timestamp[]
    TIMESTAMPTZ: timestamptz_in,  # timestamptz
    TIMESTAMPTZ_ARRAY: timestamptz_array_in,  # timestamptz[]
    TSMULTIRANGE: tsmultirange_in,  # tsmultirange
    TSMULTIRANGE_ARRAY: tsmultirange_array_in,  # tsmultirange[]
    TSRANGE: tsrange_in,  # tsrange
    TSRANGE_ARRAY: tsrange_array_in,  # tsrange[]
    TSTZMULTIRANGE: tstzmultirange_in,  # tstzmultirange
    TSTZMULTIRANGE_ARRAY: tstzmultirange_array_in,  # tstzmultirange[]
    TSTZRANGE: tstzrange_in,  # tstzrange
    TSTZRANGE_ARRAY: tstzrange_array_in,  # tstzrange[]
    UNKNOWN: string_in,  # unknown
    UUID_ARRAY: uuid_array_in,  # uuid[]
    UUID_TYPE: uuid_in,  # uuid
    VARCHAR: string_in,  # varchar
    VARCHAR_ARRAY: string_array_in,  # varchar[]
    XID: int,  # xid
}
698 |
699 |
# PostgreSQL encodings:
# https://www.postgresql.org/docs/current/multibyte.html
#
# Python encodings:
# https://docs.python.org/3/library/codecs.html
#
# Commented out encodings don't require a name change between PostgreSQL and
# Python. If the py side is None, then the encoding isn't supported. Every
# non-None value must be a name accepted by codecs.lookup().
PG_PY_ENCODINGS = {
    # Not supported:
    "mule_internal": None,
    "euc_tw": None,
    # Name fine as-is:
    # "euc_jp",
    # "euc_jis_2004",
    # "euc_kr",
    # "gb18030",
    # "gbk",
    # "johab",
    # "sjis",
    # "shift_jis_2004",
    # "uhc",
    # "utf8",
    # Different name:
    "euc_cn": "gb2312",
    "iso_8859_5": "iso8859_5",
    "iso_8859_6": "iso8859_6",
    "iso_8859_7": "iso8859_7",
    "iso_8859_8": "iso8859_8",
    "koi8": "koi8_r",
    "latin1": "iso8859-1",
    "latin2": "iso8859_2",
    "latin3": "iso8859_3",
    "latin4": "iso8859_4",
    "latin5": "iso8859_9",
    "latin6": "iso8859_10",
    "latin7": "iso8859_13",
    "latin8": "iso8859_14",
    "latin9": "iso8859_15",
    "sql_ascii": "ascii",
    "win866": "cp866",
    "win874": "cp874",
    "win1250": "cp1250",
    "win1251": "cp1251",
    "win1252": "cp1252",
    "win1253": "cp1253",
    "win1254": "cp1254",
    "win1255": "cp1255",
    "win1256": "cp1256",
    "win1257": "cp1257",
    "win1258": "cp1258",
    "unicode": "utf-8",  # Needed for Amazon Redshift
}
753 |
754 |
def make_param(py_types, value):
    """Convert ``value`` with the converter registered for it in ``py_types``.

    The entry for ``type(value)`` is used when present. Otherwise the first
    key, in mapping order, that ``value`` is an instance of wins. Keys that
    ``isinstance`` rejects with ``TypeError`` are skipped. If nothing matches,
    ``str`` is used.
    """
    try:
        converter = py_types[type(value)]
    except KeyError:
        for candidate_type, candidate in py_types.items():
            try:
                matches = isinstance(value, candidate_type)
            except TypeError:
                continue
            if matches:
                converter = candidate
                break
        else:
            converter = str

    return converter(value)
769 |
770 |
def make_params(py_types, values):
    """Return a tuple with ``make_param(py_types, v)`` for each ``v`` in ``values``."""
    return tuple(make_param(py_types, v) for v in values)
773 |
774 |
def identifier(sql):
    """Return ``sql`` as a double-quoted SQL identifier.

    Embedded double quotes are doubled.

    :raises InterfaceError: If ``sql`` is not a ``str``, is empty, or
        contains a NUL character.
    """
    if not isinstance(sql, str):
        raise InterfaceError("identifier must be a str")
    if not sql:
        raise InterfaceError("identifier must be > 0 characters in length")
    if "\u0000" in sql:
        raise InterfaceError("identifier cannot contain the code zero character")

    escaped = '""'.join(sql.split('"'))
    return '"' + escaped + '"'
787 |
788 |
def literal(value):
    """Render ``value`` as an SQL literal for embedding in query text.

    - ``None`` becomes ``NULL``; booleans become ``TRUE``/``FALSE``.
    - Finite numbers are written bare.
    - NaN and infinities become quoted strings with an explicit cast, since
      SQL has no bare numeric literal for them.
    - Bytes use ``X'...'`` hex notation.
    - Datetimes, dates, times, timedeltas and lists are rendered with the
      matching ``*_out`` converter and wrapped in single quotes.
    - Anything else goes through ``str()``.

    Strings and lists have any embedded single quote doubled.
    """
    if value is None:
        return "NULL"
    elif isinstance(value, bool):
        return "TRUE" if value else "FALSE"
    elif isinstance(value, float):
        import math

        if math.isnan(value):
            return "'NaN'::float8"
        if math.isinf(value):
            return "'Infinity'::float8" if value > 0 else "'-Infinity'::float8"
        return str(value)
    elif isinstance(value, Decimal):
        # str() of a non-finite Decimal is e.g. 'NaN' or '-Infinity'.
        if not value.is_finite():
            return f"'{value}'::numeric"
        return str(value)
    elif isinstance(value, int):
        return str(value)
    elif isinstance(value, (bytes, bytearray)):
        # NOTE(review): in PostgreSQL X'...' is a bit-string constant rather
        # than a bytea literal; confirm this is the intended encoding.
        return f"X'{value.hex()}'"
    elif isinstance(value, Datetime):
        return f"'{datetime_out(value)}'"
    elif isinstance(value, Date):
        return f"'{date_out(value)}'"
    elif isinstance(value, Time):
        return f"'{time_out(value)}'"
    elif isinstance(value, Timedelta):
        return f"'{interval_out(value)}'"
    elif isinstance(value, list):
        # String elements carry any single quotes through array_out(), so
        # they must be doubled like in any other string literal.
        escaped = array_out(value).replace("'", "''")
        return f"'{escaped}'"
    else:
        val = str(value).replace("'", "''")
        return f"'{val}'"
811 |
--------------------------------------------------------------------------------
/src/pg8000/exceptions.py:
--------------------------------------------------------------------------------
class Error(Exception):
    """Generic exception that is the base exception of all other error
    exceptions.

    This exception is part of the `DBAPI 2.0 specification
    <https://peps.python.org/pep-0249/>`_.
    """
10 |
11 |
class InterfaceError(Error):
    """Generic exception raised for errors that are related to the database
    interface rather than the database itself. For example, if the interface
    attempts to use an SSL connection but the server refuses, an InterfaceError
    will be raised.

    This exception is part of the `DBAPI 2.0 specification
    <https://peps.python.org/pep-0249/>`_.
    """
23 |
24 |
class DatabaseError(Error):
    """Generic exception raised for errors that are related to the database.

    This exception is part of the `DBAPI 2.0 specification
    <https://peps.python.org/pep-0249/>`_.
    """
33 |
--------------------------------------------------------------------------------
/src/pg8000/native.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from enum import Enum, auto
3 |
4 | from pg8000.converters import (
5 | BIGINT,
6 | BOOLEAN,
7 | BOOLEAN_ARRAY,
8 | BYTES,
9 | CHAR,
10 | CHAR_ARRAY,
11 | DATE,
12 | FLOAT,
13 | FLOAT_ARRAY,
14 | INET,
15 | INT2VECTOR,
16 | INTEGER,
17 | INTEGER_ARRAY,
18 | INTERVAL,
19 | JSON,
20 | JSONB,
21 | JSONB_ARRAY,
22 | JSON_ARRAY,
23 | MACADDR,
24 | NAME,
25 | NAME_ARRAY,
26 | NULLTYPE,
27 | NUMERIC,
28 | NUMERIC_ARRAY,
29 | OID,
30 | PGInterval,
31 | STRING,
32 | TEXT,
33 | TEXT_ARRAY,
34 | TIME,
35 | TIMESTAMP,
36 | TIMESTAMPTZ,
37 | UNKNOWN,
38 | UUID_TYPE,
39 | VARCHAR,
40 | VARCHAR_ARRAY,
41 | XID,
42 | identifier,
43 | literal,
44 | make_params,
45 | )
46 | from pg8000.core import CoreConnection, ver
47 | from pg8000.exceptions import DatabaseError, Error, InterfaceError
48 | from pg8000.types import Range
49 |
# Version string re-exported from pg8000.core (imported there as ``ver``).
__version__ = ver
51 |
52 | # Copyright (c) 2007-2009, Mathieu Fenniak
53 | # Copyright (c) The Contributors
54 | # All rights reserved.
55 | #
56 | # Redistribution and use in source and binary forms, with or without
57 | # modification, are permitted provided that the following conditions are
58 | # met:
59 | #
60 | # * Redistributions of source code must retain the above copyright notice,
61 | # this list of conditions and the following disclaimer.
62 | # * Redistributions in binary form must reproduce the above copyright notice,
63 | # this list of conditions and the following disclaimer in the documentation
64 | # and/or other materials provided with the distribution.
65 | # * The name of the author may not be used to endorse or promote products
66 | # derived from this software without specific prior written permission.
67 | #
68 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
69 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
70 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
71 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
72 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
73 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
74 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
75 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
76 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
77 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
78 | # POSSIBILITY OF SUCH DAMAGE.
79 |
80 |
class State(Enum):
    """Lexical states of the :func:`to_statement` query scanner."""

    OUT = auto()  # outside quoted string
    IN_SQ = auto()  # inside single-quote string '...'
    IN_QI = auto()  # inside quoted identifier "..."
    IN_ES = auto()  # inside escaped single-quote string, E'...'
    IN_PN = auto()  # inside parameter name eg. :name
    IN_CO = auto()  # inside inline comment eg. --
    IN_DQ = auto()  # inside dollar-quoted string eg. $$...$$ or $tag$...$tag$
    IN_DP = auto()  # inside dollar parameter eg. $1


def _dollar_quote_tag(query, start):
    """Return the dollar-quote delimiter beginning at ``query[start]``.

    ``query[start]`` is expected to be ``$``. A delimiter is either ``$$`` or
    ``$tag$``. The ``tag`` starts with a letter or underscore and continues
    with letters, digits or underscores. Returns ``None`` if no delimiter
    starts at ``start``.
    """
    end = start + 1
    length = len(query)
    if end < length and (query[end].isalpha() or query[end] == "_"):
        end += 1
        while end < length and (query[end].isalnum() or query[end] == "_"):
            end += 1
    if end < length and query[end] == "$":
        return query[start : end + 1]
    return None


def to_statement(query):
    """Rewrite ``:name`` placeholders in ``query`` as numbered ``$n`` ones.

    Named placeholders become ``$1``, ``$2``, ...; a repeated name reuses its
    number. Existing ``$n`` placeholders are kept as they are.

    The following are copied unchanged: text inside single-quoted and
    ``E'...'`` strings, quoted identifiers, ``--`` comments and dollar-quoted
    strings (``$$...$$`` or ``$tag$...$tag$``). A ``:`` only starts a
    placeholder when it is followed by a letter or underscore and is not
    part of ``::``, so casts, ``:=`` and array slices such as ``a[1:2]`` are
    left alone.

    :param query: The SQL text.
    :returns: ``(statement, make_vals)``. ``statement`` is the rewritten SQL.
        ``make_vals(args)`` takes a mapping of keyword arguments and returns
        the tuple of values in placeholder order; a ``$n`` placeholder takes
        the n-th value of ``args`` in iteration order.
    :raises InterfaceError: If a placeholder is named ``types`` or
        ``stream``. ``make_vals`` raises it when a named placeholder has no
        matching key, or a ``$n`` placeholder is beyond the number of
        supplied values.
    """
    in_quote_escape = False  # between the two quotes of a '' pair
    in_backslash_escape = False  # the previous character in E'...' was a \
    dollar_tag = None  # delimiter of the dollar-quoted string being read
    dollar_body_start = 0  # index just past the opening dollar-quote tag
    placeholders = []
    output_query = []
    state = State.OUT
    prev_c = None
    for i, c in enumerate(query):
        next_c = query[i + 1] if i + 1 < len(query) else None

        if state == State.OUT:
            if c == "'":
                output_query.append(c)
                if prev_c == "E":
                    state = State.IN_ES
                else:
                    state = State.IN_SQ
            elif c == '"':
                output_query.append(c)
                state = State.IN_QI
            elif c == "-":
                output_query.append(c)
                if prev_c == "-":
                    state = State.IN_CO
            elif c == "$":
                output_query.append(c)
                if next_c is not None and next_c.isdigit():
                    state = State.IN_DP
                    placeholders.append("")
                elif prev_c is None or not (prev_c.isalnum() or prev_c in "_$"):
                    # A $ glued to an identifier is part of that identifier,
                    # so only a free-standing $ can open a dollar quote.
                    tag = _dollar_quote_tag(query, i)
                    if tag is not None:
                        dollar_tag = tag
                        dollar_body_start = i + len(tag)
                        state = State.IN_DQ
            elif (
                c == ":"
                and prev_c != ":"
                and next_c is not None
                and (next_c.isalpha() or next_c == "_")
            ):
                state = State.IN_PN
                placeholders.append("")
            else:
                output_query.append(c)

        elif state == State.IN_SQ:
            if c == "'":
                if in_quote_escape:
                    in_quote_escape = False
                elif next_c == "'":
                    in_quote_escape = True
                else:
                    state = State.OUT
            output_query.append(c)

        elif state == State.IN_QI:
            if c == '"':
                state = State.OUT
            output_query.append(c)

        elif state == State.IN_ES:
            output_query.append(c)
            if in_backslash_escape:
                # This character is escaped by a backslash, so it cannot end
                # the string, even when it is a quote or another backslash.
                in_backslash_escape = False
            elif c == "\\":
                in_backslash_escape = True
            elif c == "'":
                if in_quote_escape:
                    in_quote_escape = False
                elif next_c == "'":
                    in_quote_escape = True
                else:
                    state = State.OUT

        elif state == State.IN_PN:
            placeholders[-1] += c
            if next_c is None or (not next_c.isalnum() and next_c != "_"):
                state = State.OUT
                try:
                    # Reuse the number of an earlier placeholder of the same name.
                    pidx = placeholders.index(placeholders[-1], 0, -1)
                    output_query.append(f"${pidx + 1}")
                    del placeholders[-1]
                except ValueError:
                    output_query.append(f"${len(placeholders)}")

        elif state == State.IN_DP:
            placeholders[-1] += c
            output_query.append(c)
            if next_c is None or not next_c.isdigit():
                try:
                    # Store the zero-based position of the $n argument.
                    placeholders[-1] = int(placeholders[-1]) - 1
                except ValueError:
                    raise InterfaceError(
                        f"Expected an integer for the $ placeholder but found "
                        f"'{placeholders[-1]}'"
                    )
                state = State.OUT

        elif state == State.IN_CO:
            output_query.append(c)
            if c == "\n":
                state = State.OUT

        elif state == State.IN_DQ:
            output_query.append(c)
            # The string ends when the text just read is the closing copy of
            # the opening tag (never the opening tag itself).
            tag_start = i + 1 - len(dollar_tag)
            if (
                c == "$"
                and tag_start >= dollar_body_start
                and query.startswith(dollar_tag, tag_start)
            ):
                state = State.OUT

        prev_c = c

    for reserved in ("types", "stream"):
        if reserved in placeholders:
            raise InterfaceError(
                f"The name '{reserved}' can't be used as a placeholder because it's "
                f"used for another purpose."
            )

    def make_vals(args):
        arg_list = list(args.values())
        vals = []
        for p in placeholders:
            if isinstance(p, int):
                try:
                    vals.append(arg_list[p])
                except IndexError:
                    raise InterfaceError(
                        f"There's a placeholder '${p + 1}' in the query, but only "
                        f"{len(arg_list)} value(s) were supplied."
                    ) from None
            else:
                try:
                    vals.append(args[p])
                except KeyError:
                    raise InterfaceError(
                        f"There's a placeholder '{p}' in the query, but no matching "
                        f"keyword argument."
                    )
        return tuple(vals)

    return "".join(output_query), make_vals
212 |
213 |
class Connection(CoreConnection):
    """A connection whose ``run`` method accepts ``:name`` style parameters."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Context of the most recent run(); None until the first run().
        self._context = None

    @property
    def columns(self):
        """Columns of the most recent ``run()``, or ``None`` before the first one."""
        context = self._context
        if context is None:
            return None
        return context.columns

    @property
    def row_count(self):
        """Row count of the most recent ``run()``, or ``None`` before the first one."""
        context = self._context
        if context is None:
            return None
        return context.row_count

    def run(self, sql, stream=None, types=None, **params):
        """Execute ``sql`` and return the resulting rows.

        With no keyword parameters and no ``stream``, ``sql`` is sent as is
        through ``execute_simple``. Otherwise it is rewritten by
        :func:`to_statement` and sent through ``execute_unnamed``.

        :param sql: SQL text, optionally with ``:name`` or ``$n`` placeholders.
        :param stream: Passed through to ``execute_unnamed``.
        :param types: Optional mapping from placeholder name to type OID.
            Placeholders it does not mention get ``None``.
        :param params: Values for the placeholders.
        :returns: The ``rows`` of the execution context.
        """
        if len(params) == 0 and stream is None:
            self._context = self.execute_simple(sql)
        else:
            statement, make_vals = to_statement(sql)
            oids = () if types is None else make_vals(defaultdict(lambda: None, types))
            self._context = self.execute_unnamed(
                statement, make_vals(params), oids=oids, stream=stream
            )
        return self._context.rows

    def prepare(self, sql, types=None):
        """Create a :class:`PreparedStatement` for ``sql`` on this connection.

        :param types: Optional mapping from placeholder name to type OID,
            passed on to :class:`PreparedStatement`.
        """
        return PreparedStatement(self, sql, types=types)
246 |
247 |
class PreparedStatement:
    """A statement prepared on the server from SQL with ``:name`` placeholders.

    :param con: The connection that prepares and executes the statement.
    :param sql: SQL text, optionally with ``:name`` or ``$n`` placeholders.
    :param types: Optional mapping from placeholder name to type OID.
        Placeholders it does not mention get ``None``.
    """

    def __init__(self, con, sql, types=None):
        self.con = con
        # Context of the most recent run(); None until the first run().
        self._context = None
        self.statement, self.make_vals = to_statement(sql)
        oids = () if types is None else self.make_vals(defaultdict(lambda: None, types))
        self.name_bin, self.cols, self.input_funcs = con.prepare_statement(
            self.statement, oids
        )

    @property
    def columns(self):
        """Columns of the most recent ``run()``, or ``None`` before the first one."""
        context = self._context
        if context is None:
            return None
        return context.columns

    @property
    def row_count(self):
        """Row count of the most recent ``run()``, or ``None`` before the first one."""
        context = self._context
        if context is None:
            return None
        return context.row_count

    def run(self, stream=None, **params):
        """Execute the statement with ``params`` and return the resulting rows."""
        # NOTE(review): ``stream`` is accepted but not passed on to
        # execute_named(), so it currently has no effect.
        params = make_params(self.con.py_types, self.make_vals(params))

        self._context = self.con.execute_named(
            self.name_bin, params, self.cols, self.input_funcs, self.statement
        )

        return self._context.rows

    def close(self):
        """Close the prepared statement on the server."""
        self.con.close_prepared_statement(self.name_bin)
272 |
273 |
# Public names of pg8000.native, kept in ASCII order.
__all__ = [
    "BIGINT",
    "BOOLEAN",
    "BOOLEAN_ARRAY",
    "BYTES",
    "CHAR",
    "CHAR_ARRAY",
    "Connection",
    "DATE",
    "DatabaseError",
    "Error",
    "FLOAT",
    "FLOAT_ARRAY",
    "INET",
    "INT2VECTOR",
    "INTEGER",
    "INTEGER_ARRAY",
    "INTERVAL",
    "InterfaceError",
    "JSON",
    "JSONB",
    "JSONB_ARRAY",
    "JSON_ARRAY",
    "MACADDR",
    "NAME",
    "NAME_ARRAY",
    "NULLTYPE",
    "NUMERIC",
    "NUMERIC_ARRAY",
    "OID",
    "PGInterval",
    "Range",
    "STRING",
    "TEXT",
    "TEXT_ARRAY",
    "TIME",
    "TIMESTAMP",
    "TIMESTAMPTZ",
    "UNKNOWN",
    "UUID_TYPE",
    "VARCHAR",
    "VARCHAR_ARRAY",
    "XID",
    "identifier",
    "literal",
]
320 |
--------------------------------------------------------------------------------
/src/pg8000/types.py:
--------------------------------------------------------------------------------
1 | from datetime import timedelta as Timedelta
2 |
3 |
class PGInterval:
    """A PostgreSQL ``interval`` value held as separate, optional fields.

    Each field (``millennia`` through ``microseconds``) is a number, or
    ``None`` when not given. Two instances compare equal when they normalize
    to the same months, days and seconds (see :meth:`normalize`).
    """

    # Unit words of the postgres / postgres_verbose formats -> field name.
    UNIT_MAP = {
        "millennia": "millennia",
        "millennium": "millennia",
        "centuries": "centuries",
        "century": "centuries",
        "decades": "decades",
        "decade": "decades",
        "years": "years",
        "year": "years",
        "months": "months",
        "month": "months",
        "mon": "months",
        "mons": "months",
        "weeks": "weeks",
        "week": "weeks",
        "days": "days",
        "day": "days",
        "hours": "hours",
        "hour": "hours",
        "minutes": "minutes",
        "minute": "minutes",
        "min": "minutes",
        "mins": "minutes",
        "sec": "seconds",
        "secs": "seconds",
        "seconds": "seconds",
        "second": "seconds",
        "microseconds": "microseconds",
        "microsecond": "microseconds",
    }

    # ISO 8601 designators: True -> the date part (before 'T'),
    # False -> the time part (after 'T'), where 'M' means minutes.
    ISO_LOOKUP = {
        True: {
            "Y": "years",
            "M": "months",
            "D": "days",
        },
        False: {
            "H": "hours",
            "M": "minutes",
            "S": "seconds",
        },
    }

    @classmethod
    def from_str_iso_8601(cls, interval_str):
        """Parse the iso_8601 format, ``P[n]Y[n]M[n]DT[n]H[n]M[n]S``."""
        kwargs = {}
        lookup = cls.ISO_LOOKUP[True]
        val = []

        for c in interval_str[1:]:
            if c == "T":
                lookup = cls.ISO_LOOKUP[False]
            elif c.isdigit() or c in ("-", "."):
                val.append(c)
            else:
                val_str = "".join(val)
                name = lookup[c]
                v = float(val_str) if name == "seconds" else int(val_str)
                kwargs[name] = v
                val.clear()

        return cls(**kwargs)

    @classmethod
    def from_str_postgres(cls, interval_str):
        """Parses both the postgres and postgres_verbose formats.

        e.g. ``1 year 2 mons 3 days -04:05:06`` or ``@ 1 min 1.5 secs ago``.
        """
        t = {}

        curr_val = None
        for k in interval_str.split():
            if ":" in k:
                # A leading sign applies to the whole hh:mm:ss group, so
                # '-00:01:00' means minus one minute.
                sign = -1 if k.startswith("-") else 1
                hours_str, minutes_str, seconds_str = k.lstrip("+-").split(":")
                hours = sign * int(hours_str)
                if hours != 0:
                    t["hours"] = hours
                minutes = sign * int(minutes_str)
                if minutes != 0:
                    t["minutes"] = minutes
                seconds = sign * float(seconds_str)
                if seconds != 0:
                    t["seconds"] = seconds

            elif k == "@":
                continue

            elif k == "ago":
                # 'ago' negates every field read so far.
                for name, v in tuple(t.items()):
                    t[name] = -1 * v

            else:
                try:
                    curr_val = int(k)
                except ValueError:
                    if "." in k:
                        # Fractional amounts, e.g. '@ 1.5 secs'.
                        curr_val = float(k)
                    else:
                        t[cls.UNIT_MAP[k]] = curr_val

        return cls(**t)

    @classmethod
    def from_str_sql_standard(cls, interval_str):
        """YYYY-MM
        or
        DD HH:MM:SS.F
        or
        YYYY-MM DD HH:MM:SS.F
        """
        month_part = None
        day_parts = None
        parts = interval_str.split()

        if len(parts) == 1:
            month_part = parts[0]
        elif len(parts) == 2:
            day_parts = parts
        else:
            month_part = parts[0]
            day_parts = parts[1:]

        kwargs = {}

        if month_part is not None:
            if month_part.startswith("-"):
                sign = -1
                p = month_part[1:]
            else:
                sign = 1
                p = month_part

            # Tolerate a year part with no '-MM' (PostgreSQL's sql_standard
            # style writes a zero interval as just '0').
            years_str, _, months_str = p.partition("-")
            kwargs["years"] = int(years_str) * sign
            kwargs["months"] = int(months_str) * sign if months_str else 0

        if day_parts is not None:
            kwargs["days"] = int(day_parts[0])
            time_part = day_parts[1]

            if time_part.startswith("-"):
                sign = -1
                p = time_part[1:]
            else:
                sign = 1
                p = time_part

            hours, minutes, seconds = p.split(":")
            kwargs["hours"] = int(hours) * sign
            kwargs["minutes"] = int(minutes) * sign
            kwargs["seconds"] = float(seconds) * sign

        return cls(**kwargs)

    @classmethod
    def from_str(cls, interval_str):
        """Parse an interval in any of PostgreSQL's output formats.

        The format is chosen from the text:

        - ``P`` prefix: iso_8601.
        - ``@`` prefix: postgres_verbose.
        - A unit word, or a lone ``hh:mm:ss``: postgres.
        - Anything else: sql_standard.
        """
        if interval_str.startswith("P"):
            return cls.from_str_iso_8601(interval_str)
        elif interval_str.startswith("@"):
            return cls.from_str_postgres(interval_str)
        else:
            parts = interval_str.split()
            if (len(parts) > 1 and parts[1][0].isalpha()) or (
                len(parts) == 1 and ":" in parts[0]
            ):
                return cls.from_str_postgres(interval_str)
            else:
                return cls.from_str_sql_standard(interval_str)

    def __init__(
        self,
        millennia=None,
        centuries=None,
        decades=None,
        years=None,
        months=None,
        weeks=None,
        days=None,
        hours=None,
        minutes=None,
        seconds=None,
        microseconds=None,
    ):
        self.millennia = millennia
        self.centuries = centuries
        self.decades = decades
        self.years = years
        self.months = months
        self.weeks = weeks
        self.days = days
        self.hours = hours
        self.minutes = minutes
        self.seconds = seconds
        self.microseconds = microseconds

    def __repr__(self):
        return f"<PGInterval {self}>"

    def _value_dict(self):
        """Return the fields that are set, largest unit first, as a dict."""
        return {
            k: v
            for k, v in (
                ("millennia", self.millennia),
                ("centuries", self.centuries),
                ("decades", self.decades),
                ("years", self.years),
                ("months", self.months),
                ("weeks", self.weeks),
                ("days", self.days),
                ("hours", self.hours),
                ("minutes", self.minutes),
                ("seconds", self.seconds),
                ("microseconds", self.microseconds),
            )
            if v is not None
        }

    def __str__(self):
        return " ".join(f"{v} {n}" for n, v in self._value_dict().items())

    def normalize(self):
        """Return an equivalent interval with only months, days and seconds.

        Millennia, centuries, decades and years fold into months; weeks into
        days; and hours, minutes and microseconds into seconds.
        """
        months = 0
        if self.millennia is not None:
            months += self.millennia * 12000
        if self.centuries is not None:
            months += self.centuries * 1200
        if self.decades is not None:
            months += self.decades * 120
        if self.months is not None:
            months += self.months
        if self.years is not None:
            months += self.years * 12

        days = 0
        if self.days is not None:
            days += self.days
        if self.weeks is not None:
            days += self.weeks * 7

        seconds = 0
        if self.hours is not None:
            seconds += self.hours * 60 * 60
        if self.minutes is not None:
            seconds += self.minutes * 60
        if self.seconds is not None:
            seconds += self.seconds
        if self.microseconds is not None:
            seconds += self.microseconds / 1000000

        return PGInterval(months=months, days=days, seconds=seconds)

    def __eq__(self, other):
        if isinstance(other, PGInterval):
            s = self.normalize()
            o = other.normalize()
            return s.months == o.months and s.days == o.days and s.seconds == o.seconds
        else:
            return False

    def to_timedelta(self):
        """Convert to a :class:`datetime.timedelta`.

        :raises ValueError: If weeks, months, years, decades, centuries or
            millennia are set.
        """
        pairs = self._value_dict()
        overlap = pairs.keys() & {
            "weeks",
            "months",
            "years",
            "decades",
            "centuries",
            "millennia",
        }
        if len(overlap) > 0:
            raise ValueError(
                f"Can't fit the interval fields {', '.join(sorted(overlap))} "
                f"into a datetime.timedelta."
            )

        return Timedelta(**pairs)
270 |
271 |
class Range:
    """A PostgreSQL range value such as ``[1,5)`` or ``empty``.

    :param lower: Lower bound. ``None`` is rendered as an empty bound, which
        PostgreSQL reads as unbounded.
    :param upper: Upper bound, with the same ``None`` handling as ``lower``.
    :param bounds: Two characters giving the bracket on each side, e.g.
        ``"[)"``.
    :param is_empty: ``True`` for the empty range. ``==`` and ``str()`` then
        ignore the bounds.
    """

    def __init__(
        self,
        lower=None,
        upper=None,
        bounds="[)",
        is_empty=False,
    ):
        self.lower = lower
        self.upper = upper
        self.bounds = bounds
        self.is_empty = is_empty

    def __eq__(self, other):
        if isinstance(other, Range):
            if self.is_empty or other.is_empty:
                return self.is_empty == other.is_empty
            else:
                return (
                    self.lower == other.lower
                    and self.upper == other.upper
                    and self.bounds == other.bounds
                )
        return False

    def __str__(self):
        if self.is_empty:
            return "empty"
        else:
            le, ue = ["" if v is None else v for v in (self.lower, self.upper)]
            return f"{self.bounds[0]}{le},{ue}{self.bounds[1]}"

    def __repr__(self):
        return f"<Range {self}>"
306 |
--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tlocke/pg8000/46c00021ade1d19466b07ed30392386c5f0a6b8e/test/__init__.py
--------------------------------------------------------------------------------
/test/dbapi/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tlocke/pg8000/46c00021ade1d19466b07ed30392386c5f0a6b8e/test/dbapi/__init__.py
--------------------------------------------------------------------------------
/test/dbapi/auth/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tlocke/pg8000/46c00021ade1d19466b07ed30392386c5f0a6b8e/test/dbapi/auth/__init__.py
--------------------------------------------------------------------------------
/test/dbapi/auth/test_gss.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from pg8000.dbapi import InterfaceError, connect
4 |
5 |
6 | # This requires a line in pg_hba.conf that requires gss for the database
7 | # pg8000_gss
8 |
9 |
def test_gss(db_kwargs):
    """Connecting to a database that requires GSS fails with InterfaceError.

    pg8000 reports authentication method 7 as unsupported.
    """
    db_kwargs["database"] = "pg8000_gss"

    expected_message = "Authentication method 7 not supported by pg8000."
    with pytest.raises(InterfaceError, match=expected_message), connect(**db_kwargs):
        pass
20 |
--------------------------------------------------------------------------------
/test/dbapi/auth/test_md5.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from pg8000.dbapi import DatabaseError, connect
4 |
5 |
6 | # This requires a line in pg_hba.conf that requires md5 for the database
7 | # pg8000_md5
8 |
9 |
def test_md5(db_kwargs):
    """md5 authentication passes; the only error is the missing database.

    SQLSTATE 3D000 is the code matched for the database not existing.
    """
    db_kwargs["database"] = "pg8000_md5"

    with pytest.raises(DatabaseError, match="3D000"), connect(**db_kwargs):
        pass
17 |
--------------------------------------------------------------------------------
/test/dbapi/auth/test_md5_ssl.py:
--------------------------------------------------------------------------------
1 | import ssl
2 |
3 | from pg8000.dbapi import connect
4 |
5 |
def test_md5_ssl(db_kwargs):
    """Connect over SSL without verifying the server's hostname or certificate."""
    ssl_context = ssl.create_default_context()
    # check_hostname must be disabled before verify_mode can be CERT_NONE.
    ssl_context.check_hostname = False
    ssl_context.verify_mode = ssl.CERT_NONE

    db_kwargs["ssl_context"] = ssl_context
    with connect(**db_kwargs):
        pass
13 |
--------------------------------------------------------------------------------
/test/dbapi/auth/test_password.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from pg8000.dbapi import DatabaseError, connect
4 |
5 | # This requires a line in pg_hba.conf that requires 'password' for the
6 | # database pg8000_password
7 |
8 |
def test_password(db_kwargs):
    """'password' authentication passes; the only error is the missing database.

    SQLSTATE 3D000 is the code matched for the database not existing.
    """
    db_kwargs["database"] = "pg8000_password"

    with pytest.raises(DatabaseError, match="3D000"), connect(**db_kwargs):
        pass
16 |
--------------------------------------------------------------------------------
/test/dbapi/auth/test_scram-sha-256.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from pg8000.dbapi import DatabaseError, connect
4 |
5 | # This requires a line in pg_hba.conf that requires scram-sha-256 for the
6 | # database pg8000_scram_sha_256
7 |
DB = "pg8000_scram_sha_256"


@pytest.fixture
def setup(con, cursor):
    """Ensure database ``DB`` exists and switch server SSL off for the test.

    SSL is switched back on, and the configuration reloaded, afterwards.
    """
    con.autocommit = True
    try:
        cursor.execute(f"CREATE DATABASE {DB}")
    except DatabaseError:
        # Presumably the database already exists from an earlier run.
        con.rollback()

    for statement in ("ALTER SYSTEM SET ssl = off", "SELECT pg_reload_conf()"):
        cursor.execute(statement)
    yield
    for statement in ("ALTER SYSTEM SET ssl = on", "SELECT pg_reload_conf()"):
        cursor.execute(statement)
24 |
25 |
def test_scram_sha_256(setup, db_kwargs):
    """Connect to the scram-sha-256 database while server SSL is off."""
    db_kwargs.update(database=DB)

    with connect(**db_kwargs):
        pass
31 |
--------------------------------------------------------------------------------
/test/dbapi/auth/test_scram-sha-256_ssl.py:
--------------------------------------------------------------------------------
1 | from ssl import CERT_NONE, create_default_context
2 | import pytest
3 |
4 | from pg8000.dbapi import DatabaseError, connect
5 |
6 | # This requires a line in pg_hba.conf that requires scram-sha-256 for the
7 | # database pg8000_scram_sha_256
8 |
DB = "pg8000_scram_sha_256"


@pytest.fixture
def setup(con, cursor):
    """Make sure the database used by the scram-sha-256 tests exists."""
    con.autocommit = True
    try:
        cursor.execute(f"CREATE DATABASE {DB}")
    except DatabaseError:
        # Presumably the database already exists from an earlier run.
        con.rollback()
19 |
20 |
def test_scram_sha_256(setup, db_kwargs):
    """Open and immediately close a connection to the scram-sha-256 database."""
    db_kwargs.update(database=DB)
    connect(**db_kwargs).close()
26 |
27 |
def test_scram_sha_256_ssl_context(setup, db_kwargs):
    """Connect with scram-sha-256 over SSL using a non-verifying SSL context."""
    context = create_default_context()
    # check_hostname must be disabled before verify_mode can be CERT_NONE.
    context.check_hostname = False
    context.verify_mode = CERT_NONE

    db_kwargs.update(database=DB, ssl_context=context)
    connect(**db_kwargs).close()
38 |
--------------------------------------------------------------------------------
/test/dbapi/conftest.py:
--------------------------------------------------------------------------------
1 | from os import environ
2 |
3 | import pytest
4 |
5 | from test.utils import parse_server_version
6 |
7 | import pg8000.dbapi
8 |
9 |
@pytest.fixture(scope="class")
def db_kwargs():
    """Keyword arguments for ``connect``.

    The defaults are overridden by the PGHOST, PGPASSWORD and PGPORT
    environment variables when they are set.
    """
    kwargs = {"user": "postgres", "password": "pw"}

    env_overrides = (
        ("host", "PGHOST", str),
        ("password", "PGPASSWORD", str),
        ("port", "PGPORT", int),
    )
    for name, env_var, convert in env_overrides:
        if env_var in environ:
            kwargs[name] = convert(environ[env_var])

    return kwargs
25 |
26 |
@pytest.fixture
def con(request, db_kwargs):
    """An open pg8000.dbapi connection, rolled back and closed after the test."""
    conn = pg8000.dbapi.connect(**db_kwargs)

    def fin():
        # Either step may fail if the test already closed the connection.
        for cleanup in (conn.rollback, conn.close):
            try:
                cleanup()
            except pg8000.dbapi.InterfaceError:
                pass

    request.addfinalizer(fin)
    return conn
44 |
45 |
@pytest.fixture
def cursor(request, con):
    """A cursor on ``con`` that is closed after the test."""
    cur = con.cursor()
    request.addfinalizer(cur.close)
    return cur
55 |
56 |
@pytest.fixture
def pg_version(cursor):
    """The server version, as returned by ``parse_server_version``."""
    cursor.execute("select current_setting('server_version')")
    version = cursor.fetchall()[0][0]
    return parse_server_version(version)
64 |
--------------------------------------------------------------------------------
/test/dbapi/test_benchmarks.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from pg8000.dbapi import connect
4 |
5 |
@pytest.mark.parametrize(
    "txt",
    (
        "cast(id / 100 as int2)",
        "cast(id as int4)",
        "cast(id * 100 as int8)",
        "(id % 2) = 0",
        "N'Static text string'",
        "cast(id / 100 as float4)",
        "cast(id / 100 as float8)",
        "cast(id / 100 as numeric)",
        "timestamp '2001-09-28'",
    ),
)
def test_round_trips(db_kwargs, benchmark, txt):
    """Benchmark fetching 10,000 rows, each holding seven copies of ``txt``.

    Every parameter must be a single SQL expression string over ``id``,
    because it is interpolated directly into the select list.
    """

    def torun():
        with connect(**db_kwargs) as con:
            query = f"""SELECT {txt}, {txt}, {txt}, {txt}, {txt}, {txt}, {txt}
                FROM (SELECT generate_series(1, 10000) AS id) AS tbl"""
            cursor = con.cursor()
            cursor.execute(query)
            cursor.fetchall()
            cursor.close()

    benchmark(torun)
31 |
--------------------------------------------------------------------------------
/test/dbapi/test_connection.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import socket
3 | import warnings
4 |
5 | import pytest
6 |
7 | from pg8000.dbapi import DatabaseError, InterfaceError, __version__, connect
8 |
9 |
def test_unix_socket_missing():
    """Connecting through a non-existent Unix socket raises InterfaceError."""
    with pytest.raises(InterfaceError):
        with connect(unix_sock="/file-does-not-exist", user="doesn't-matter"):
            pass
16 |
17 |
def test_internet_socket_connection_refused():
    """A refused TCP connection is reported as InterfaceError with details."""
    expected = (
        "Can't create a connection to host localhost and port 0 "
        "\\(timeout is None and source_address is None\\)."
    )
    with pytest.raises(InterfaceError, match=expected):
        with connect(port=0, user="doesn't-matter"):
            pass
28 |
29 |
def test_Connection_plain_socket(db_kwargs):
    """A caller-supplied, already-connected socket can be used without SSL."""
    address = (db_kwargs.get("host", "localhost"), db_kwargs.get("port", 5432))
    with socket.create_connection(address) as sock:
        with connect(
            sock=sock,
            user=db_kwargs["user"],
            password=db_kwargs["password"],
            ssl_context=False,
        ) as con:
            cur = con.cursor()
            cur.execute("SELECT 1")
            assert cur.fetchall()[0][0] == 1
47 |
48 |
def test_database_missing(db_kwargs):
    """Connecting to a database that does not exist raises DatabaseError."""
    db_kwargs["database"] = "missing-db"
    with pytest.raises(DatabaseError):
        with connect(**db_kwargs):
            pass


def _connect_expecting_missing_db(db_kwargs, database):
    """Connect to ``database``; only the 'database doesn't exist' error
    (SQLSTATE 3D000) is acceptable."""
    db_kwargs["database"] = database
    with pytest.raises(DatabaseError, match="3D000"):
        with connect(**db_kwargs):
            pass


def test_database_name_unicode(db_kwargs):
    # Should only raise an exception saying db doesn't exist
    _connect_expecting_missing_db(db_kwargs, "pg8000_sn\uFF6Fw")


def test_database_name_bytes(db_kwargs):
    """Should only raise an exception saying db doesn't exist"""
    _connect_expecting_missing_db(db_kwargs, bytes("pg8000_sn\uFF6Fw", "utf8"))
72 |
73 |
def test_password_bytes(con, db_kwargs):
    """A bytes password is accepted for login.

    The connection attempt is expected to fail only because the target
    database does not exist (SQLSTATE 3D000).
    """
    username = "boltzmann"
    password = "cha\uFF6Fs"
    cur = con.cursor()
    cur.execute(f"create user {username} with password '{password}';")
    con.commit()

    try:
        db_kwargs["user"] = username
        db_kwargs["password"] = password.encode("utf8")
        db_kwargs["database"] = "pg8000_md5"
        with pytest.raises(DatabaseError, match="3D000"):
            with connect(**db_kwargs):
                pass
    finally:
        # Always drop the role. Otherwise a failed run leaves it behind and
        # the next run's CREATE USER fails.
        cur.execute(f"drop role {username}")
        con.commit()
91 |
92 |
def test_application_name(db_kwargs):
    """The application_name parameter shows up in pg_stat_activity."""
    app_name = "my test application name"
    db_kwargs["application_name"] = app_name
    query = (
        "select application_name from pg_stat_activity "
        " where pid = pg_backend_pid()"
    )
    with connect(**db_kwargs) as db:
        cur = db.cursor()
        cur.execute(query)
        (reported_name,) = cur.fetchone()
        assert reported_name == app_name
105 |
106 |
def test_application_name_integer(db_kwargs):
    """A non-string, non-bytes application_name is rejected with InterfaceError."""
    db_kwargs["application_name"] = 1
    # The old pattern ("type " ".") matched almost any message. Pin the
    # offending type and escape the trailing dot.
    with pytest.raises(
        InterfaceError,
        match=r"The parameter application_name can't be of type <class 'int'>\.",
    ):
        with connect(**db_kwargs):
            pass
115 |
116 |
def test_application_name_bytearray(db_kwargs):
    """A bytearray application_name is accepted by connect()."""
    db_kwargs["application_name"] = bytearray(b"Philby")
    with connect(**db_kwargs):
        pass  # connecting without an exception is the whole check
121 |
122 |
def _check_single_notification(con, notify_sql, expected_payload):
    """LISTEN on channel "test", run ``notify_sql``, and check that exactly one
    notification with ``expected_payload`` arrives from our own backend."""
    cursor = con.cursor()
    cursor.execute("select pg_backend_pid()")
    backend_pid = cursor.fetchall()[0][0]
    assert list(con.notifications) == []

    cursor.execute("LISTEN test")
    cursor.execute(notify_sql)
    con.commit()

    # Run another statement before inspecting con.notifications, presumably
    # so the driver has read the server's notification message.
    cursor.execute("VALUES (1, 2), (3, 4), (5, 6)")
    assert len(con.notifications) == 1
    assert con.notifications[0] == (backend_pid, "test", expected_payload)


def test_notify(con):
    _check_single_notification(con, "NOTIFY test", "")


def test_notify_with_payload(con):
    _check_single_notification(con, "NOTIFY test, 'Parnham'", "Parnham")
149 |
150 |
def test_broken_pipe_read(con, db_kwargs):
    """Using a connection whose backend was terminated raises 'network error'."""
    db1 = connect(**db_kwargs)
    try:
        cur1 = db1.cursor()
        cur2 = con.cursor()
        cur1.execute("select pg_backend_pid()")
        pid1 = cur1.fetchone()[0]

        cur2.execute("select pg_terminate_backend(%s)", (pid1,))
        with pytest.raises(InterfaceError, match="network error"):
            cur1.execute("select 1")
    finally:
        # Always release db1, even if the assertion above failed. Its socket
        # is dead, so close() itself may raise InterfaceError.
        try:
            db1.close()
        except InterfaceError:
            pass
166 |
167 |
def test_broken_pipe_flush(con, db_kwargs):
    """Closing a connection whose backend was terminated either succeeds or
    raises InterfaceError('network error')."""
    db1 = connect(**db_kwargs)
    cur1 = db1.cursor()
    cur2 = con.cursor()
    cur1.execute("select pg_backend_pid()")
    pid1 = cur1.fetchone()[0]

    cur2.execute("select pg_terminate_backend(%s)", (pid1,))
    try:
        cur1.execute("select 1")
    except Exception:
        # Any error from the dead connection is acceptable here. Catching
        # Exception rather than BaseException lets KeyboardInterrupt and
        # SystemExit propagate.
        pass

    # Sometimes raises and sometime doesn't
    try:
        db1.close()
    except InterfaceError as e:
        assert str(e) == "network error"
186 |
187 |
def test_broken_pipe_unpack(con):
    """Terminating our own backend surfaces as InterfaceError('network error')."""
    cur = con.cursor()
    cur.execute("select pg_backend_pid()")
    (own_pid,) = cur.fetchone()

    with pytest.raises(InterfaceError, match="network error"):
        cur.execute("select pg_terminate_backend(%s)", (own_pid,))
195 |
196 |
def test_py_value_fail(con, mocker):
    """If a Python-value converter raises, that original exception
    (PG8000TestException) propagates and the connection stays usable."""

    class PG8000TestException(Exception):
        pass

    def raise_exception(val):
        raise PG8000TestException("oh noes!")

    # patch.object records the original attribute so it is restored at
    # teardown. The replacement mapping's only converter raises for time values.
    mocker.patch.object(con, "py_types")
    con.py_types = {datetime.time: raise_exception}

    c = con.cursor()
    with pytest.raises(PG8000TestException):
        c.execute("SELECT CAST(%s AS TIME) AS f1", (datetime.time(10, 30),))
        c.fetchall()

    # The same cursor must still be able to run a fresh query.
    c.execute("VALUES ('hw3'::text)")
    assert c.fetchone()[0] == "hw3"
219 |
220 |
def test_no_data_error_recovery(con):
    """Repeated failing statements, each followed by rollback, keep working."""
    for _attempt in range(3):
        cur = con.cursor()
        with pytest.raises(DatabaseError) as excinfo:
            cur.execute("DROP TABLE t1")
        # 42P01 is the SQLSTATE for an undefined table.
        assert excinfo.value.args[0]["C"] == "42P01"
        con.rollback()
228 |
229 |
def test_closed_connection(db_kwargs):
    """Executing on a cursor whose connection was closed raises InterfaceError."""
    # Warnings are silenced, presumably because accessing exception classes as
    # connection attributes (my_db.InterfaceError) can warn. catch_warnings()
    # restores the previous filters even if the test fails, instead of
    # resetwarnings() wiping all filters (including pytest's).
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        my_db = connect(**db_kwargs)
        cursor = my_db.cursor()
        my_db.close()
        with pytest.raises(my_db.InterfaceError, match="connection is closed"):
            cursor.execute("VALUES ('hw1'::text)")
240 |
241 |
def test_version():
    """pg8000.dbapi.__version__ agrees with the installed package metadata."""
    try:
        from importlib.metadata import version as installed_version
    except ImportError:  # importlib.metadata is only in the stdlib from 3.8
        from importlib_metadata import version as installed_version

    assert __version__ == installed_version("pg8000")
251 |
252 |
def _start_failed_transaction(cursor):
    """Create temp table ``tt``, BEGIN, then make an INSERT fail.

    Inserting NULL into the primary-key column raises DatabaseError, which
    is swallowed here. The tests below check how the connection behaves
    afterwards.
    """
    cursor.execute("create temporary table tt (f1 int primary key)")
    cursor.execute("begin")
    try:
        cursor.execute("insert into tt(f1) values(null)")
    except DatabaseError:
        pass


@pytest.mark.parametrize("commit", ["commit", "COMMIT;"])
def test_failed_transaction_commit_sql(cursor, commit):
    _start_failed_transaction(cursor)
    with pytest.raises(InterfaceError):
        cursor.execute(commit)


def test_failed_transaction_commit_method(con, cursor):
    _start_failed_transaction(cursor)
    with pytest.raises(InterfaceError):
        con.commit()


@pytest.mark.parametrize("rollback", ["rollback", "rollback;", "ROLLBACK ;"])
def test_failed_transaction_rollback_sql(cursor, rollback):
    _start_failed_transaction(cursor)
    cursor.execute(rollback)


def test_failed_transaction_rollback_method(cursor, con):
    _start_failed_transaction(cursor)
    con.rollback()


@pytest.mark.parametrize("sql", ["BEGIN", "select * from tt;"])
def test_failed_transaction_sql(cursor, sql):
    _start_failed_transaction(cursor)
    with pytest.raises(DatabaseError):
        cursor.execute(sql)
331 |
--------------------------------------------------------------------------------
/test/dbapi/test_copy.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 |
3 | import pytest
4 |
5 |
@pytest.fixture
def db_table(request, con):
    """Create temporary table t1 (dropped on commit); return the connection."""
    ddl = (
        "CREATE TEMPORARY TABLE t1 (f1 int primary key, "
        "f2 int not null, f3 varchar(50) null) on commit drop"
    )
    con.cursor().execute(ddl)
    return con
14 |
15 |
def test_copy_to_with_table(db_table):
    """COPY t1 TO STDOUT writes every row into the supplied stream."""
    cursor = db_table.cursor()
    for value in (1, 2, 3):
        cursor.execute(
            "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (value, value, value)
        )

    out = BytesIO()
    cursor.execute("copy t1 to stdout", stream=out)
    assert out.getvalue() == b"1\t1\t1\n2\t2\t2\n3\t3\t3\n"
    assert cursor.rowcount == 3
26 |
27 |
def test_copy_to_with_query(db_table):
    """COPY of a query result honours DELIMITER, HEADER, QUOTE and FORCE QUOTE."""
    out = BytesIO()
    sql = (
        "COPY (SELECT 1 as One, 2 as Two) TO STDOUT WITH DELIMITER "
        "'X' CSV HEADER QUOTE AS 'Y' FORCE QUOTE Two"
    )
    cursor = db_table.cursor()
    cursor.execute(sql, stream=out)
    assert out.getvalue() == b"oneXtwo\n1XY2Y\n"
    assert cursor.rowcount == 1
38 |
39 |
def test_copy_from_with_table(db_table):
    """COPY t1 FROM STDIN loads tab-separated rows read from the stream."""
    cursor = db_table.cursor()
    source = BytesIO(b"1\t1\t1\n2\t2\t2\n3\t3\t3\n")
    cursor.execute("copy t1 from STDIN", stream=source)
    assert cursor.rowcount == 3

    cursor.execute("SELECT * FROM t1 ORDER BY f1")
    assert cursor.fetchall() == ([1, 1, "1"], [2, 2, "2"], [3, 3, "3"])
49 |
50 |
def test_copy_from_with_query(db_table):
    """COPY FROM with CSV options fills only the listed columns (f3 stays NULL)."""
    cursor = db_table.cursor()
    csv_data = BytesIO(b"f1Xf2\n1XY1Y\n")
    cursor.execute(
        "COPY t1 (f1, f2) FROM STDIN WITH DELIMITER 'X' CSV HEADER "
        "QUOTE AS 'Y' FORCE NOT NULL f1",
        stream=csv_data,
    )
    assert cursor.rowcount == 1

    cursor.execute("SELECT * FROM t1 ORDER BY f1")
    assert cursor.fetchall() == ([1, 1, None],)
64 |
65 |
def test_copy_from_with_error(db_table):
    """Malformed COPY input raises DatabaseError carrying the server's fields.

    Each expected field lists every accepted value, presumably because the
    message text and reporting function differ across PostgreSQL versions.
    """
    # Narrowed from BaseException: the server error arrives as DatabaseError
    # (its args[0] is the field dict checked below).
    from pg8000.dbapi import DatabaseError

    cursor = db_table.cursor()
    stream = BytesIO(b"f1Xf2\n\n1XY1Y\n")
    with pytest.raises(DatabaseError) as e:
        cursor.execute(
            "COPY t1 (f1, f2) FROM STDIN WITH DELIMITER 'X' CSV HEADER "
            "QUOTE AS 'Y' FORCE NOT NULL f1",
            stream=stream,
        )

    arg = {
        "S": ("ERROR",),
        "C": ("22P02",),
        "M": (
            'invalid input syntax for type integer: ""',
            'invalid input syntax for integer: ""',
        ),
        "W": ('COPY t1, line 2, column f1: ""',),
        "F": ("numutils.c",),
        "R": ("pg_atoi", "pg_strtoint32", "pg_strtoint32_safe"),
    }
    earg = e.value.args[0]
    for k, v in arg.items():
        assert earg[k] in v
90 |
--------------------------------------------------------------------------------
/test/dbapi/test_dbapi.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 | import time
4 |
5 | import pytest
6 |
7 | from pg8000.dbapi import (
8 | BINARY,
9 | Binary,
10 | Date,
11 | DateFromTicks,
12 | Time,
13 | TimeFromTicks,
14 | Timestamp,
15 | TimestampFromTicks,
16 | )
17 |
18 |
@pytest.fixture
def has_tzset():
    """Switch the process time zone to UTC for the duration of one test.

    Yields True when time.tzset() exists and the zone was switched, False
    otherwise. The previous TZ value is restored afterwards, so the UTC
    setting no longer leaks into unrelated tests.
    """
    # Neither Windows nor Jython 2.5.3 have a time.tzset() so skip
    if not hasattr(time, "tzset"):
        yield False
        return

    previous_tz = os.environ.get("TZ")
    os.environ["TZ"] = "UTC"
    time.tzset()
    try:
        yield True
    finally:
        if previous_tz is None:
            os.environ.pop("TZ", None)
        else:
            os.environ["TZ"] = previous_tz
        time.tzset()
27 |
28 |
# DBAPI compatible interface tests
@pytest.fixture
def db_table(con, has_tzset):
    """Temporary table t1 (dropped on commit) pre-loaded with five rows."""
    c = con.cursor()
    c.execute(
        "CREATE TEMPORARY TABLE t1 "
        "(f1 int primary key, f2 int not null, f3 varchar(50) null) "
        "ON COMMIT DROP"
    )
    # f1 runs 1..5, f2 grows by powers of ten, f3 is always NULL.
    for f1, f2 in ((1, 1), (2, 10), (3, 100), (4, 1000), (5, 10000)):
        c.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (f1, f2, None))
    return con
44 |
45 |
def test_parallel_queries(db_table):
    """Two cursors on one connection can be interleaved: the outer cursor keeps
    its result set while the inner one runs a new query for each row."""
    outer = db_table.cursor()
    inner = db_table.cursor()

    outer.execute("SELECT f1, f2, f3 FROM t1")
    for f1, _f2, _f3 in iter(outer.fetchone, None):
        inner.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > %s", (f1,))
        for _g1, _g2, _g3 in iter(inner.fetchone, None):
            pass
62 |
63 |
def _fetch_f1_values(cursor):
    """Drain ``cursor`` with fetchone() and return the sorted f1 column.

    Each row is unpacked into three values, so a wrong column count fails.
    """
    return sorted(f1 for f1, _f2, _f3 in iter(cursor.fetchone, None))


# db_table holds f1 = 1..5, so every "f1 > 3" query below must return 4 and 5.
# Previously these tests fetched the rows but asserted nothing.


def test_qmark(mocker, db_table):
    mocker.patch("pg8000.dbapi.paramstyle", "qmark")
    c1 = db_table.cursor()
    c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > ?", (3,))
    assert _fetch_f1_values(c1) == [4, 5]


def test_numeric(mocker, db_table):
    mocker.patch("pg8000.dbapi.paramstyle", "numeric")
    c1 = db_table.cursor()
    c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > :1", (3,))
    assert _fetch_f1_values(c1) == [4, 5]


def test_named(mocker, db_table):
    mocker.patch("pg8000.dbapi.paramstyle", "named")
    c1 = db_table.cursor()
    c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > :f1", {"f1": 3})
    assert _fetch_f1_values(c1) == [4, 5]


def test_format(mocker, db_table):
    mocker.patch("pg8000.dbapi.paramstyle", "format")
    c1 = db_table.cursor()
    c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > %s", (3,))
    assert _fetch_f1_values(c1) == [4, 5]


def test_pyformat(mocker, db_table):
    mocker.patch("pg8000.dbapi.paramstyle", "pyformat")
    c1 = db_table.cursor()
    c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > %(f1)s", {"f1": 3})
    assert _fetch_f1_values(c1) == [4, 5]
117 |
118 |
def test_arraysize(db_table):
    """fetchmany() with no argument returns ``arraysize`` rows."""
    cur = db_table.cursor()
    cur.arraysize = 3
    cur.execute("SELECT * FROM t1")
    batch = cur.fetchmany()
    assert len(batch) == cur.arraysize
125 |
126 |
def test_date():
    """Date() builds a value equal to the matching datetime.date."""
    assert Date(2001, 2, 3) == datetime.date(2001, 2, 3)


def test_time():
    """Time() builds a value equal to the matching datetime.time."""
    assert Time(4, 5, 6) == datetime.time(4, 5, 6)


def test_timestamp():
    """Timestamp() builds a value equal to the matching datetime.datetime."""
    assert Timestamp(2001, 2, 3, 4, 5, 6) == datetime.datetime(2001, 2, 3, 4, 5, 6)
140 |
141 |
# These expectations assume the process time zone is UTC, which only the
# has_tzset fixture can guarantee. Without tzset the tests are now reported as
# skipped instead of passing vacuously.
_NO_TZSET = "time.tzset() is unavailable, so TZ cannot be forced to UTC"


def test_date_from_ticks(has_tzset):
    if not has_tzset:
        pytest.skip(_NO_TZSET)
    assert DateFromTicks(1173804319) == datetime.date(2007, 3, 13)


def testTimeFromTicks(has_tzset):
    if not has_tzset:
        pytest.skip(_NO_TZSET)
    assert TimeFromTicks(1173804319) == datetime.time(16, 45, 19)


def test_timestamp_from_ticks(has_tzset):
    if not has_tzset:
        pytest.skip(_NO_TZSET)
    assert TimestampFromTicks(1173804319) == datetime.datetime(
        2007, 3, 13, 16, 45, 19
    )
158 |
159 |
def test_binary():
    """Binary() returns a value equal to its input that is a BINARY instance."""
    payload = b"\x00\x01\x02\x03\x02\x01\x00"
    wrapped = Binary(payload)
    assert wrapped == payload
    assert isinstance(wrapped, BINARY)
164 |
165 |
def test_row_count(db_table):
    """rowcount reflects rows returned by SELECT and affected by UPDATE/DELETE."""
    cur = db_table.cursor()

    cur.execute("SELECT * FROM t1")
    assert cur.rowcount == 5  # all five fixture rows

    cur.execute("UPDATE t1 SET f3 = %s WHERE f2 > 101", ("Hello!",))
    assert cur.rowcount == 2  # rows with f2 = 1000 and 10000

    cur.execute("DELETE FROM t1")
    assert cur.rowcount == 5
177 |
178 |
def test_fetch_many(db_table):
    """fetchmany() pages through the five rows in arraysize-sized chunks."""
    cursor = db_table.cursor()
    cursor.arraysize = 2
    cursor.execute("SELECT * FROM t1")
    page_sizes = [len(cursor.fetchmany()) for _ in range(4)]
    assert page_sizes == [2, 2, 1, 0]
187 |
188 |
def test_iterator(db_table):
    """Iterating the cursor itself yields the result rows in query order.

    The previous version iterated cursor.fetchall(), so the cursor's own
    iterator protocol was never exercised.
    """
    cursor = db_table.cursor()
    cursor.execute("SELECT * FROM t1 ORDER BY f1")
    f1_values = [row[0] for row in cursor]
    assert f1_values == [1, 2, 3, 4, 5]
197 |
198 |
def test_vacuum(con):
    """VACUUM succeeds with autocommit enabled.

    Vacuum can't be run inside a transaction, so autocommit is turned on.
    """
    con.autocommit = True
    con.cursor().execute("vacuum")
205 |
206 |
def test_prepared_statement(con):
    """SQL-level PREPARE and EXECUTE statements can be sent through a cursor."""
    cursor = con.cursor()
    statements = (
        "PREPARE gen_series AS SELECT generate_series(1, 10);",
        "EXECUTE gen_series",
    )
    for statement in statements:
        cursor.execute(statement)
211 |
212 |
def test_cursor_type(cursor):
    """The cursor fixture provides a pg8000.dbapi.Cursor instance."""
    # The expected value had been garbled to "", so the test could never pass.
    assert str(type(cursor)) == "<class 'pg8000.dbapi.Cursor'>"
215 |
--------------------------------------------------------------------------------
/test/dbapi/test_dbapi20.py:
--------------------------------------------------------------------------------
1 | import time
2 | import warnings
3 |
4 | import pytest
5 |
6 | import pg8000
7 |
8 | """ Python DB API 2.0 driver compliance unit test suite.
9 |
10 | This software is Public Domain and may be used without restrictions.
11 |
12 | "Now we have booze and barflies entering the discussion, plus rumours of
13 | DBAs on drugs... and I won't tell you what flashes through my mind each
14 | time I read the subject line with 'Anal Compliance' in it. All around
15 | this is turning out to be a thoroughly unwholesome unit test."
16 |
17 | -- Ian Bicking
18 | """
19 |
__rcs_id__ = "$Id: dbapi20.py,v 1.10 2003/10/09 03:14:14 zenzen Exp $"
# Slicing [11:-2] keeps just the revision number ("1.10").
__version__ = "$Revision: 1.10 $"[11:-2]
__author__ = "Stuart Bishop "


# $Log: dbapi20.py,v $
# Revision 1.10 2003/10/09 03:14:14 zenzen
# Add test for DB API 2.0 optional extension, where database exceptions
# are exposed as attributes on the Connection object.
#
# Revision 1.9 2003/08/13 01:16:36 zenzen
# Minor tweak from Stefan Fleiter
#
# Revision 1.8 2003/04/10 00:13:25 zenzen
# Changes, as per suggestions by M.-A. Lemburg
# - Add a table prefix, to ensure namespace collisions can always be avoided
#
# Revision 1.7 2003/02/26 23:33:37 zenzen
# Break out DDL into helper functions, as per request by David Rushby
#
# Revision 1.6 2003/02/21 03:04:33 zenzen
# Stuff from Henrik Ekelund:
# added test_None
# added test_nextset & hooks
#
# Revision 1.5 2003/02/17 22:08:43 zenzen
# Implement suggestions and code from Henrik Eklund - test that
# cursor.arraysize defaults to 1 & generic cursor.callproc test added
#
# Revision 1.4 2003/02/15 00:16:33 zenzen
# Changes, as per suggestions and bug reports by M.-A. Lemburg,
# Matthew T. Kromer, Federico Di Gregorio and Daniel Dittmar
# - Class renamed
# - Now a subclass of TestCase, to avoid requiring the driver stub
# to use multiple inheritance
# - Reversed the polarity of buggy test in test_description
# - Test exception hierarchy correctly
# - self.populate is now self._populate(), so if a driver stub
# overrides self.ddl1 this change propagates
# - VARCHAR columns now have a width, which will hopefully make the
# DDL even more portable (this will be reversed if it causes more problems)
# - cursor.rowcount being checked after various execute and fetchXXX methods
# - Check for fetchall and fetchmany returning empty lists after results
# are exhausted (already checking for empty lists if select retrieved
# nothing
# - Fix bugs in test_setoutputsize_basic and test_setinputsizes
#


""" Test a database self.driver for DB API 2.0 compatibility.
This implementation tests Gadfly, but the TestCase
is structured so that other self.drivers can subclass this
test case to ensure compiliance with the DB-API. It is
expected that this TestCase may be expanded in the future
if ambiguities or edge conditions are discovered.

The 'Optional Extensions' are not yet being tested.

self.drivers should subclass this test, overriding setUp, tearDown,
self.driver, connect_args and connect_kw_args. Class specification
should be as follows:

import dbapi20
class mytest(dbapi20.DatabaseAPI20Test):
[...]

Don't 'import DatabaseAPI20Test from dbapi20', or you will
confuse the unit tester - just 'import dbapi20'.
"""

# The self.driver module. This should be the module where the 'connect'
# method is to be found
driver = pg8000
table_prefix = "dbapi20test_"  # If you need to specify a prefix for tables

# DDL that creates (ddl*) and drops (xddl*) the two test tables.
ddl1 = f"create table {table_prefix}booze (name varchar(20))"
ddl2 = f"create table {table_prefix}barflys (name varchar(20))"
xddl1 = f"drop table {table_prefix}booze"
xddl2 = f"drop table {table_prefix}barflys"

# Name of stored procedure to convert
# string->lowercase
lowerfunc = "lower"
103 |
104 |
# Some drivers may need to override these helpers, for example adding
# a 'commit' after the execute.
def executeDDL1(cursor):
    """Create the ``booze`` test table via ``cursor``."""
    statement = ddl1
    cursor.execute(statement)


def executeDDL2(cursor):
    """Create the ``barflys`` test table via ``cursor``."""
    statement = ddl2
    cursor.execute(statement)
113 |
114 |
@pytest.fixture
def db(request, con):
    """Return ``con`` and drop the dbapi20 test tables when the test ends."""

    def fin():
        with con.cursor() as cur:
            for ddl in (xddl1, xddl2):
                try:
                    cur.execute(ddl)
                    con.commit()
                except driver.Error:
                    # Assume table didn't exist. Other tests will check if
                    # execute is busted.
                    # A failed statement leaves the transaction aborted, and
                    # the next DROP would then fail too, so roll back first.
                    # The connection may already be closed, so this is best
                    # effort.
                    try:
                        con.rollback()
                    except driver.Error:
                        pass

    request.addfinalizer(fin)
    return con
130 |
131 |
def test_apilevel():
    """The driver must define apilevel, and it must equal "2.0"."""
    assert driver.apilevel == "2.0"
138 |
139 |
def test_threadsafety():
    """threadsafety must exist and be one of the DB-API levels 0-3."""
    assert hasattr(driver, "threadsafety"), "Driver doesn't define threadsafety"
    assert driver.threadsafety in (0, 1, 2, 3)


def test_paramstyle():
    """paramstyle must exist and name one of the DB-API parameter styles."""
    assert hasattr(driver, "paramstyle"), "Driver doesn't define paramstyle"
    assert driver.paramstyle in ("qmark", "numeric", "named", "format", "pyformat")
158 |
159 |
def test_Exceptions():
    """The required DB-API exceptions exist and sit in the defined hierarchy."""
    assert issubclass(driver.Warning, Exception)
    assert issubclass(driver.Error, Exception)
    for name in (
        "InterfaceError",
        "DatabaseError",
        "OperationalError",
        "IntegrityError",
        "InternalError",
        "ProgrammingError",
        "NotSupportedError",
    ):
        assert issubclass(getattr(driver, name), driver.Error)
172 |
173 |
def test_ExceptionsAsConnectionAttributes(con):
    """The driver exceptions are also exposed as attributes of the connection."""
    # OPTIONAL EXTENSION
    # Test for the optional DB API 2.0 extension, where the exceptions
    # are exposed as attributes on the Connection object
    # I figure this optional extension will be implemented by any
    # driver author who is using this test suite, so it is enabled
    # by default.
    #
    # catch_warnings() scopes the "ignore" filter to this block and restores
    # the previous filters even when an assertion fails. The old
    # simplefilter/resetwarnings pair leaked on failure and wiped pytest's
    # own filters.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for name in (
            "Warning",
            "Error",
            "InterfaceError",
            "DatabaseError",
            "OperationalError",
            "IntegrityError",
            "InternalError",
            "ProgrammingError",
            "NotSupportedError",
        ):
            assert getattr(con, name) is getattr(driver, name), name
193 |
194 |
def test_commit(con):
    """commit() must work, even if there is nothing to commit."""
    con.commit()


def test_rollback(con):
    """rollback(), if defined, either works or raises NotSupportedError."""
    if not hasattr(con, "rollback"):
        return
    try:
        con.rollback()
    except driver.NotSupportedError:
        pass


def test_cursor(con):
    """cursor() can be called on an open connection."""
    con.cursor()
212 |
213 |
def test_cursor_isolation(con):
    """Cursors from the same connection see each other's uncommitted work."""
    writer = con.cursor()
    reader = con.cursor()
    table = f"{table_prefix}booze"

    executeDDL1(writer)
    writer.execute(f"insert into {table} values ('Victoria Bitter')")
    reader.execute(f"select name from {table}")
    booze = reader.fetchall()

    assert len(booze) == 1
    assert len(booze[0]) == 1
    assert booze[0][0] == "Victoria Bitter"
226 |
227 |
def test_description(con):
    """cursor.description is None for DDL and a 7-item tuple per column otherwise."""
    cur = con.cursor()
    executeDDL1(cur)
    assert cur.description is None, (
        "cursor.description should be none after executing a "
        "statement that can return no rows (such as DDL)"
    )

    cur.execute(f"select name from {table_prefix}booze")
    description = cur.description
    assert len(description) == 1, "cursor.description describes too many columns"
    column = description[0]
    assert len(column) == 7, "cursor.description[x] tuples must have 7 elements"
    assert (
        column[0].lower() == "name"
    ), "cursor.description[x][0] must return column name"
    assert column[1] == driver.STRING, (
        "cursor.description[x][1] must return column type. Got %r" % column[1]
    )

    # Make sure self.description gets reset
    executeDDL2(cur)
    assert cur.description is None, (
        "cursor.description not being set to None when executing "
        "no-result statements (eg. DDL)"
    )
254 |
255 |
def test_rowcount(cursor):
    """rowcount is -1 after no-result statements; otherwise it equals the row
    count or is -1."""
    table = f"{table_prefix}booze"

    executeDDL1(cursor)
    assert (
        cursor.rowcount == -1
    ), "cursor.rowcount should be -1 after executing no-result statements"

    cursor.execute(f"insert into {table} values ('Victoria Bitter')")
    assert cursor.rowcount in (-1, 1), (
        "cursor.rowcount should == number or rows inserted, or "
        "set to -1 after executing an insert statement"
    )

    cursor.execute(f"select name from {table}")
    assert cursor.rowcount in (-1, 1), (
        "cursor.rowcount should == number of rows returned, or "
        "set to -1 after executing a select statement"
    )

    executeDDL2(cursor)
    assert (
        cursor.rowcount == -1
    ), "cursor.rowcount not being reset to -1 after executing no-result statements"
275 |
276 |
def test_close(con):
    """After close(), executing, committing and closing again all raise Error."""
    cur = con.cursor()
    con.close()

    # cursor.execute must fail once the connection is closed.
    with pytest.raises(driver.Error):
        executeDDL1(cur)

    # connection.commit must fail once the connection is closed.
    with pytest.raises(driver.Error):
        con.commit()

    # connection.close must fail when called a second time.
    with pytest.raises(driver.Error):
        con.close()
294 |
295 |
def test_execute(con):
    """execute() works with and without parameters in the driver's paramstyle."""
    _paraminsert(con.cursor())


def _paraminsert(cur):
    """Insert one literal row and one parameterised row, then read both back."""
    table = f"{table_prefix}booze"
    executeDDL1(cur)
    cur.execute(f"insert into {table} values ('Victoria Bitter')")
    assert cur.rowcount in (-1, 1)

    # Parameterised INSERT in each supported paramstyle.
    positional = ("Cooper's",)
    named = {"beer": "Cooper's"}
    by_style = {
        "qmark": (f"insert into {table} values (?)", positional),
        "numeric": (f"insert into {table} values (:1)", positional),
        "named": (f"insert into {table} values (:beer)", named),
        "format": (f"insert into {table} values (%s)", positional),
        "pyformat": (f"insert into {table} values (%(beer)s)", named),
    }
    assert driver.paramstyle in by_style, "Invalid paramstyle"
    sql, params = by_style[driver.paramstyle]
    cur.execute(sql, params)
    assert cur.rowcount in (-1, 1)

    cur.execute(f"select name from {table}")
    res = cur.fetchall()
    assert len(res) == 2, "cursor.fetchall returned too few rows"
    beers = sorted(row[0] for row in res)
    assert (
        beers[0] == "Cooper's"
    ), "cursor.fetchall retrieved incorrect data, or data inserted incorrectly"
    assert (
        beers[1] == "Victoria Bitter"
    ), "cursor.fetchall retrieved incorrect data, or data inserted incorrectly"
337 |
338 |
def test_executemany(cursor):
    """executemany() inserts one row per parameter set."""
    table = f"{table_prefix}booze"
    executeDDL1(cursor)

    largs = [("Cooper's",), ("Boag's",)]
    margs = [{"beer": "Cooper's"}, {"beer": "Boag's"}]
    by_style = {
        "qmark": (f"insert into {table} values (?)", largs),
        "numeric": (f"insert into {table} values (:1)", largs),
        "named": (f"insert into {table} values (:beer)", margs),
        "format": (f"insert into {table} values (%s)", largs),
        "pyformat": (f"insert into {table} values (%(beer)s)", margs),
    }
    assert driver.paramstyle in by_style, "Unknown paramstyle"
    sql, seq_of_params = by_style[driver.paramstyle]
    cursor.executemany(sql, seq_of_params)

    assert cursor.rowcount in (-1, 2), (
        "insert using cursor.executemany set cursor.rowcount to "
        "incorrect value %r" % cursor.rowcount
    )

    cursor.execute(f"select name from {table}")
    res = cursor.fetchall()
    assert len(res) == 2, "cursor.fetchall retrieved incorrect number of rows"
    beers = sorted(row[0] for row in res)
    assert beers[0] == "Boag's", "incorrect data retrieved"
    assert beers[1] == "Cooper's", "incorrect data retrieved"
370 |
371 |
def test_fetchone(cursor):
    """Check ``fetchone`` errors, empty results and single-row retrieval."""
    select_names = "select name from %sbooze" % table_prefix

    # Nothing has been executed yet, so there is no result set to read.
    with pytest.raises(driver.Error):
        cursor.fetchone()

    # DDL does not produce a result set either.
    executeDDL1(cursor)
    with pytest.raises(driver.Error):
        cursor.fetchone()

    cursor.execute(select_names)
    assert cursor.fetchone() is None, (
        "cursor.fetchone should return None if a query retrieves no rows"
    )
    assert cursor.rowcount in (-1, 0)

    # A plain INSERT (no RETURNING) cannot return rows.
    cursor.execute("insert into %sbooze values ('Victoria Bitter')" % table_prefix)
    with pytest.raises(driver.Error):
        cursor.fetchone()

    cursor.execute(select_names)
    row = cursor.fetchone()
    assert len(row) == 1, "cursor.fetchone should have retrieved a single row"
    assert row[0] == "Victoria Bitter", "cursor.fetchone retrieved incorrect data"
    assert cursor.fetchone() is None, (
        "cursor.fetchone should return None if no more rows available"
    )
    assert cursor.rowcount in (-1, 1)
404 |
405 |
# Beer names loaded into the booze table by _populate(). Several tests compare
# sorted query results against this list position by position, so it must be
# in sorted order; sorted() makes that invariant explicit.
samples = sorted(
    [
        "Carlton Cold",
        "Carlton Draft",
        "Mountain Goat",
        "Redback",
        "Victoria Bitter",
        "XXXX",
    ]
)
414 |
415 |
def _populate():
    """Build the INSERT statements that load ``samples`` into the booze table.

    Returns a list containing one SQL string per sample name. The names are
    interpolated directly into the SQL; none of the entries in ``samples``
    contains a quote character.
    """
    return [f"insert into {table_prefix}booze values ('{name}')" for name in samples]
424 |
425 |
def test_fetchmany(cursor):
    """Check ``fetchmany`` with explicit sizes, ``arraysize`` and empty results.

    The booze table is loaded with the six ``samples`` rows, then read back
    in batches of varying size.
    """
    # fetchmany must raise if no query has been issued yet.
    with pytest.raises(driver.Error):
        cursor.fetchmany(4)

    executeDDL1(cursor)
    for sql in _populate():
        cursor.execute(sql)

    select_booze = "select name from %sbooze" % table_prefix

    # Explicit sizes: 1 (default arraysize) + 3 + 2 covers all 6 rows.
    cursor.execute(select_booze)
    r = cursor.fetchmany()
    assert len(r) == 1, (
        "cursor.fetchmany retrieved incorrect number of rows, "
        "default of arraysize is one."
    )
    cursor.arraysize = 10
    r = cursor.fetchmany(3)  # An explicit size takes precedence over arraysize.
    assert len(r) == 3, "cursor.fetchmany retrieved incorrect number of rows"
    r = cursor.fetchmany(4)  # Only 2 rows remain.
    assert len(r) == 2, "cursor.fetchmany retrieved incorrect number of rows"
    r = cursor.fetchmany(4)  # Exhausted: empty sequence.
    assert len(r) == 0, (
        "cursor.fetchmany should return an empty sequence after "
        "results are exhausted"
    )
    assert cursor.rowcount in (-1, 6)

    # The same again, but batch size comes from arraysize: 4 + 2 + 0.
    cursor.arraysize = 4
    cursor.execute(select_booze)
    r = cursor.fetchmany()
    assert len(r) == 4, "cursor.arraysize not being honoured by fetchmany"
    r = cursor.fetchmany()
    assert len(r) == 2
    r = cursor.fetchmany()
    assert len(r) == 0
    assert cursor.rowcount in (-1, 6)

    # An arraysize equal to the row count fetches everything in one call.
    cursor.arraysize = 6
    cursor.execute(select_booze)
    rows = cursor.fetchmany()
    assert cursor.rowcount in (-1, 6)
    assert len(rows) == 6
    names = sorted(row[0] for row in rows)

    # Make sure we get the right data back out.
    for name, expected in zip(names, samples):
        assert name == expected, "incorrect data retrieved by cursor.fetchmany"

    rows = cursor.fetchmany()
    assert len(rows) == 0, (
        "cursor.fetchmany should return an empty sequence if "
        "called after the whole result set has been fetched"
    )
    assert cursor.rowcount in (-1, 6)

    # The barflys table is created but never populated here.
    executeDDL2(cursor)
    cursor.execute("select name from %sbarflys" % table_prefix)
    r = cursor.fetchmany()
    assert len(r) == 0, (
        "cursor.fetchmany should return an empty sequence if query retrieved no rows"
    )
    assert cursor.rowcount in (-1, 0)
492 |
493 |
def test_fetchall(cursor):
    """Check ``fetchall`` errors, full retrieval, exhaustion and empty results."""
    # Nothing has been executed, so there is nothing to fetch.
    with pytest.raises(driver.Error):
        cursor.fetchall()

    executeDDL1(cursor)
    for statement in _populate():
        cursor.execute(statement)

    # The last statement was an INSERT, which has no result set.
    with pytest.raises(driver.Error):
        cursor.fetchall()

    expected_count = len(samples)
    cursor.execute("select name from %sbooze" % table_prefix)
    fetched = cursor.fetchall()
    assert cursor.rowcount in (-1, expected_count)
    assert len(fetched) == expected_count, "cursor.fetchall did not retrieve all rows"
    names = sorted(r[0] for r in fetched)
    for got, want in zip(names, samples):
        assert got == want, "cursor.fetchall retrieved incorrect rows"

    # A second call after the result set is drained yields nothing.
    leftover = cursor.fetchall()
    assert len(leftover) == 0, (
        "cursor.fetchall should return an empty list if called "
        "after the whole result set has been fetched"
    )
    assert cursor.rowcount in (-1, expected_count)

    executeDDL2(cursor)
    cursor.execute("select name from %sbarflys" % table_prefix)
    empty = cursor.fetchall()
    assert cursor.rowcount in (-1, 0)
    assert len(empty) == 0, (
        "cursor.fetchall should return an empty list if "
        "a select query returns no rows"
    )
533 |
534 |
def test_mixedfetch(cursor):
    """Interleave fetchone, fetchmany and fetchall over one result set."""
    executeDDL1(cursor)
    for statement in _populate():
        cursor.execute(statement)

    cursor.execute("select name from %sbooze" % table_prefix)
    first = cursor.fetchone()
    second_and_third = cursor.fetchmany(2)
    fourth = cursor.fetchone()
    remainder = cursor.fetchall()
    assert cursor.rowcount in (-1, 6)
    assert len(second_and_third) == 2, "fetchmany returned incorrect number of rows"
    assert len(remainder) == 2, "fetchall returned incorrect number of rows"

    # Reassemble every name, whichever fetch method returned it.
    names = sorted(
        [first[0]]
        + [row[0] for row in second_and_third]
        + [fourth[0]]
        + [row[0] for row in remainder]
    )
    for got, want in zip(names, samples):
        assert got == want, "incorrect data retrieved or inserted"
556 |
557 |
def help_nextset_setUp(cur):
    """Intended to create a ``deleteme`` procedure with two result sets.

    The first result set would be the number of rows in booze and the
    second ``name from booze``. This module does not implement it, so it
    always raises NotImplementedError.
    """
    message = "Helper not implemented"
    raise NotImplementedError(message)
564 |
565 |
def help_nextset_tearDown(cur):
    """Intended to clean up after test_nextset; always raises here."""
    message = "Helper not implemented"
    raise NotImplementedError(message)
569 |
570 |
def test_nextset(cursor):
    """Check ``nextset`` against a procedure that returns two result sets.

    The ``deleteme`` procedure is expected to return the booze row count
    first and then the names. The test is reported as skipped, rather than
    silently passing, when the cursor has no ``nextset`` method.
    """
    if not hasattr(cursor, "nextset"):
        pytest.skip("cursor does not implement nextset")

    try:
        executeDDL1(cursor)
        for sql in _populate():
            cursor.execute(sql)

        help_nextset_setUp(cursor)

        cursor.callproc("deleteme")
        numberofrows = cursor.fetchone()
        assert numberofrows[0] == len(samples)
        assert cursor.nextset()
        names = cursor.fetchall()
        assert len(names) == len(samples)
        s = cursor.nextset()
        assert s is None, "No more return sets, should return None"
    finally:
        help_nextset_tearDown(cursor)
593 |
594 |
def test_arraysize(cursor):
    """The cursor must expose an ``arraysize`` attribute.

    Its effect on ``fetchmany`` is exercised by test_fetchmany.
    """
    has_arraysize = hasattr(cursor, "arraysize")
    assert has_arraysize, "cursor.arraysize must be defined"
598 |
599 |
def test_setinputsizes(cursor):
    """Calling ``setinputsizes`` with a single size must not raise."""
    size = 25
    cursor.setinputsizes(size)
602 |
603 |
def test_setoutputsize_basic(cursor):
    """``setoutputsize`` is accepted with and without a column argument.

    Afterwards the cursor must still work for a parameterised insert.
    """
    for call_args in ((1000,), (2000, 0)):
        cursor.setoutputsize(*call_args)
    # Confirm the cursor is still usable after the calls above.
    _paraminsert(cursor)
609 |
610 |
def test_None(cursor):
    """A NULL column value is returned to Python as ``None``."""
    executeDDL1(cursor)
    cursor.execute("insert into %sbooze values (NULL)" % table_prefix)
    cursor.execute("select name from %sbooze" % table_prefix)
    rows = cursor.fetchall()
    assert len(rows) == 1
    only_row = rows[0]
    assert len(only_row) == 1
    assert only_row[0] is None, "NULL value not returned as None"
619 |
620 |
def test_Date():
    """``Date`` and ``DateFromTicks`` construct values without raising.

    Equality of the two results is not asserted: the DB-API spec seems to
    imply it but does not state it.
    """
    driver.Date(2002, 12, 25)
    local_midnight = time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0))
    driver.DateFromTicks(local_midnight)
626 |
627 |
def test_Time():
    """``Time`` and ``TimeFromTicks`` construct values without raising.

    Equality of the two results is not asserted: the DB-API spec seems to
    imply it but does not state it.
    """
    driver.Time(13, 45, 30)
    local_ticks = time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0))
    driver.TimeFromTicks(local_ticks)
633 |
634 |
def test_Timestamp():
    """``Timestamp`` and ``TimestampFromTicks`` construct values without raising.

    Equality of the two results is not asserted: the DB-API spec seems to
    imply it but does not state it.
    """
    driver.Timestamp(2002, 12, 25, 13, 45, 30)
    local_ticks = time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0))
    driver.TimestampFromTicks(local_ticks)
640 |
641 |
def test_Binary():
    """``Binary`` accepts both a non-empty and an empty byte string."""
    for payload in (b"Something", b""):
        driver.Binary(payload)
645 |
646 |
def test_STRING():
    """The driver module must export the ``STRING`` type object."""
    defined = hasattr(driver, "STRING")
    assert defined, "module.STRING must be defined"
649 |
650 |
def test_BINARY():
    """The driver module must export the ``BINARY`` type object."""
    defined = hasattr(driver, "BINARY")
    assert defined, "module.BINARY must be defined."
653 |
654 |
def test_NUMBER():
    """The driver module must export the ``NUMBER`` type object."""
    defined = hasattr(driver, "NUMBER")
    assert defined, "module.NUMBER must be defined."
657 |
658 |
def test_DATETIME():
    """The driver module must export the ``DATETIME`` type object."""
    defined = hasattr(driver, "DATETIME")
    assert defined, "module.DATETIME must be defined."
661 |
662 |
def test_ROWID():
    """The driver module must export the ``ROWID`` type object."""
    defined = hasattr(driver, "ROWID")
    assert defined, "module.ROWID must be defined."
665 |
--------------------------------------------------------------------------------
/test/dbapi/test_paramstyle.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from pg8000.dbapi import convert_paramstyle
4 |
5 | # "(id %% 2) = 0",
6 |
7 |
@pytest.mark.parametrize(
    ("query", "statement"),
    [
        (
            'SELECT ?, ?, "field_?" FROM t '
            "WHERE a='say ''what?''' AND b=? AND c=E'?\\'test\\'?'",
            'SELECT $1, $2, "field_?" FROM t WHERE '
            "a='say ''what?''' AND b=$3 AND c=E'?\\'test\\'?'",
        ),
        (
            "SELECT ?, ?, * FROM t WHERE a=? AND b='are you ''sure?'",
            "SELECT $1, $2, * FROM t WHERE a=$3 AND b='are you ''sure?'",
        ),
    ],
)
def test_qmark(query, statement):
    """``?`` placeholders become ``$n``; ``?`` inside quotes is left alone."""
    args = (1, 2, 3)
    new_query, vals = convert_paramstyle("qmark", query, args)
    assert new_query == statement
    assert vals == args
27 |
28 |
@pytest.mark.parametrize(
    ("query", "expected"),
    [
        (
            "SELECT sum(x)::decimal(5, 2) :2, :1, * FROM t WHERE a=:3",
            "SELECT sum(x)::decimal(5, 2) $2, $1, * FROM t WHERE a=$3",
        ),
    ],
)
def test_numeric(query, expected):
    """``:n`` placeholders map to ``$n``; ``::`` casts are left untouched."""
    args = (1, 2, 3)
    new_query, vals = convert_paramstyle("numeric", query, args)
    assert new_query == expected
    assert vals == args
42 |
43 |
@pytest.mark.parametrize("query", ["make_interval(days := 10)"])
def test_numeric_unchanged(query):
    """``:=`` named-argument syntax is not treated as a ``:n`` placeholder."""
    args = (1, 2, 3)
    new_query, vals = convert_paramstyle("numeric", query, args)
    assert new_query == query
    assert vals == args
54 |
55 |
@pytest.mark.parametrize(
    ("query", "args", "expected_query", "expected_args"),
    [
        (
            "SELECT sum(x)::decimal(5, 2) :f_2, :f1 FROM t WHERE a=:f_2",
            {"f_2": 1, "f1": 2},
            "SELECT sum(x)::decimal(5, 2) $1, $2 FROM t WHERE a=$1",
            (1, 2),
        ),
        ("SELECT $$'$$ = :v", {"v": "'"}, "SELECT $$'$$ = $1", ("'",)),
    ],
)
def test_named(query, args, expected_query, expected_args):
    """``:name`` placeholders are numbered by first appearance.

    Repeated names reuse their number, and a quote inside a dollar-quoted
    string does not confuse the parser.
    """
    new_query, vals = convert_paramstyle("named", query, args)
    assert new_query == expected_query
    assert vals == expected_args
71 |
72 |
@pytest.mark.parametrize(
    ("query", "expected"),
    [
        (
            "SELECT %s, %s, \"f1_%%\", E'txt_%%' "
            "FROM t WHERE a=%s AND b='75%%' AND c = '%' -- Comment with %",
            "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$3 AND "
            "b='75%%' AND c = '%' -- Comment with %",
        ),
        (
            "SELECT -- Comment\n%s FROM t",
            "SELECT -- Comment\n$1 FROM t",
        ),
    ],
)
def test_format_changed(query, expected):
    """``%s`` placeholders become ``$n``.

    ``%`` characters inside identifiers, string literals and comments are
    left as they are, and a placeholder on the line after a comment is
    still converted.
    """
    args = (1, 2, 3)
    new_query, vals = convert_paramstyle("format", query, args)
    assert new_query == expected
    assert vals == args
92 |
93 |
@pytest.mark.parametrize(
    "query",
    [
        r"""COMMENT ON TABLE test_schema.comment_test """
        r"""IS 'the test % '' " \ table comment'""",
    ],
)
def test_format_unchanged(query):
    """A query whose only ``%`` is inside a string literal passes through unchanged."""
    args = (1, 2, 3)
    new_query, vals = convert_paramstyle("format", query, args)
    assert new_query == query
    assert vals == args
105 |
106 |
def test_py_format():
    """``%(name)s`` placeholders are numbered by first appearance.

    Repeated names reuse their number, unused keys (``f3``) are dropped
    from the values, and ``%%`` sequences are left as they are.
    """
    query = (
        "SELECT %(f2)s, %(f1)s, \"f1_%%\", E'txt_%%' "
        "FROM t WHERE a=%(f2)s AND b='75%%'"
    )
    expected = "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$1 AND b='75%%'"
    params = {"f2": 1, "f1": 2, "f3": 3}

    new_query, vals = convert_paramstyle("pyformat", query, params)
    assert new_query == expected
    assert vals == (1, 2)
118 |
119 |
def test_pyformat_format():
    """The pyformat style also accepts positional ``%s`` with a sequence."""
    query = "SELECT %s, %s, \"f1_%%\", E'txt_%%' FROM t WHERE a=%s AND b='75%%'"
    expected = "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$3 AND b='75%%'"
    args = (1, 2, 3)

    new_query, vals = convert_paramstyle("pyformat", query, args)
    assert new_query == expected
    assert vals == args
130 |
--------------------------------------------------------------------------------
/test/dbapi/test_query.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime as Datetime, timezone as Timezone
2 |
3 | import pytest
4 |
5 | import pg8000.dbapi
6 | from pg8000.converters import INET_ARRAY, INTEGER
7 |
8 |
9 | # Tests relating to the basic operation of the database driver, driven by the
10 | # pg8000 custom interface.
11 |
12 |
@pytest.fixture
def db_table(request, con):
    """Provide ``con`` with a temporary table ``t1`` and the ``format`` paramstyle.

    ``t1`` has columns f1 (int primary key), f2 (bigint not null) and
    f3 (varchar(50) null). A finalizer drops the table and ignores
    DatabaseError, since some tests drop or replace it themselves.
    """
    con.paramstyle = "format"
    con.cursor().execute(
        "CREATE TEMPORARY TABLE t1 (f1 int primary key, "
        "f2 bigint not null, f3 varchar(50) null) "
    )

    def drop_t1():
        try:
            con.cursor().execute("drop table t1")
        except pg8000.dbapi.DatabaseError:
            # Best effort: the table may already be gone.
            pass

    request.addfinalizer(drop_t1)
    return con
31 |
32 |
def test_database_error(cursor):
    """Inserting into a table that does not exist raises DatabaseError."""
    bad_insert = "INSERT INTO t99 VALUES (1, 2, 3)"
    with pytest.raises(pg8000.dbapi.DatabaseError):
        cursor.execute(bad_insert)
36 |
37 |
def test_parallel_queries(db_table):
    """Interleave queries on two cursors of the same connection.

    For every row read by the outer cursor, the inner cursor runs a
    dependent query and drains its result.
    """
    cursor = db_table.cursor()
    for f1_value, f2_value in ((1, 1), (2, 10), (3, 100), (4, 1000), (5, 10000)):
        cursor.execute(
            "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)",
            (f1_value, f2_value, None),
        )

    outer = db_table.cursor()
    inner = db_table.cursor()
    outer.execute("SELECT f1, f2, f3 FROM t1")
    for f1, _f2, _f3 in outer.fetchall():
        inner.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > %s", (f1,))
        for _row_f1, _row_f2, _row_f3 in inner.fetchall():
            pass
53 |
54 |
def test_parallel_open_portals(con):
    """Two cursors on one connection can hold open results at the same time.

    Both execute the same query before either is read. They are then
    drained in the reverse order and must report the same number of rows.
    """
    first, second = con.cursor(), con.cursor()
    query = "select * from generate_series(1, %s)"
    params = (100,)
    first.execute(query, params)
    second.execute(query, params)

    second_count = sum(1 for _ in second.fetchall())
    first_count = sum(1 for _ in first.fetchall())
    assert first_count == second_count
69 |
70 |
def test_alter(db_table):
    """Re-run a query after altering the structure of the table it reads.

    The second SELECT must succeed even though a column was dropped.
    """
    cursor = db_table.cursor()
    select_all = "select * from t1"
    for statement in (select_all, "alter table t1 drop column f3", select_all):
        cursor.execute(statement)
80 |
81 |
def test_create(db_table):
    """Re-run a query after dropping and re-creating the table it reads."""
    cursor = db_table.cursor()
    select_all = "select * from t1"
    statements = (
        select_all,
        "drop table t1",
        "create temporary table t1 (f1 int primary key)",
        select_all,
    )
    for statement in statements:
        cursor.execute(statement)
92 |
93 |
def test_insert_returning(db_table):
    """``INSERT ... RETURNING`` yields rows and sets ``rowcount``."""
    cursor = db_table.cursor()
    cursor.execute("CREATE TEMPORARY TABLE t2 (id serial, data text)")

    # Single-row insert: the returned id locates the inserted data.
    cursor.execute("INSERT INTO t2 (data) VALUES (%s) RETURNING id", ("test1",))
    new_id = cursor.fetchone()[0]
    cursor.execute("SELECT data FROM t2 WHERE id = %s", (new_id,))
    assert cursor.fetchone()[0] == "test1"
    assert cursor.rowcount == 1

    # Multi-row insert: rowcount covers every returned row.
    cursor.execute(
        "INSERT INTO t2 (data) VALUES (%s), (%s), (%s) RETURNING id",
        ("test2", "test3", "test4"),
    )
    assert cursor.rowcount == 3
    returned_ids = tuple(row[0] for row in cursor.fetchall())
    assert len(returned_ids) == 3
114 |
115 |
def test_row_count(db_table):
    """Check ``rowcount`` after executemany, during partial reads and after DDL."""
    expected_count = 57
    cursor = db_table.cursor()
    cursor.executemany(
        "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)",
        tuple((i, i, None) for i in range(expected_count)),
    )
    # executemany reports the total across all parameter sets.
    assert cursor.rowcount == expected_count

    # The count is known before any row is read...
    cursor.execute("SELECT * FROM t1")
    assert cursor.rowcount == expected_count

    # ...and is unchanged after some rows have been consumed.
    for _ in range(expected_count // 2):
        cursor.fetchone()
    assert cursor.rowcount == expected_count

    # A fresh cursor, partially read, reports the same count.
    cursor = db_table.cursor()
    cursor.execute("SELECT * FROM t1")
    for _ in range(expected_count // 3):
        cursor.fetchone()
    assert cursor.rowcount == expected_count

    # A command that returns no rows reports -1.
    cursor.execute("DROP TABLE t1")
    assert cursor.rowcount == -1
149 |
150 |
def test_row_count_update(db_table):
    """An UPDATE reports the number of rows it modified in ``rowcount``."""
    cursor = db_table.cursor()
    for key, value in ((1, 1), (2, 10), (3, 100), (4, 1000), (5, 10000)):
        cursor.execute(
            "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (key, value, None)
        )

    # Only the rows with f2 = 1000 and f2 = 10000 satisfy f2 > 101.
    cursor.execute("UPDATE t1 SET f3 = %s WHERE f2 > 101", ("Hello!",))
    assert cursor.rowcount == 2
160 |
161 |
def test_int_oid(cursor):
    """Comparing an oid column with a Python int parameter must not raise.

    Regression test for https://bugs.launchpad.net/pg8000/+bug/230796
    """
    type_oid = 100
    cursor.execute("SELECT typname FROM pg_type WHERE oid = %s", (type_oid,))
165 |
166 |
def test_unicode_query(cursor):
    """DDL with non-ASCII (Cyrillic) table and column names executes."""
    ddl = (
        "CREATE TEMPORARY TABLE \u043c\u0435\u0441\u0442\u043e "
        "(\u0438\u043c\u044f VARCHAR(50), "
        "\u0430\u0434\u0440\u0435\u0441 VARCHAR(250))"
    )
    cursor.execute(ddl)
173 |
174 |
def test_executemany(db_table):
    """``executemany`` handles INSERTs and SELECTs with datetime parameters.

    The SELECT is given one timezone-aware and one naive datetime.
    """
    cursor = db_table.cursor()
    insert_rows = ((1, 1, "Avast ye!"), (2, 1, None))
    cursor.executemany("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", insert_rows)

    aware = Datetime(2014, 5, 7, tzinfo=Timezone.utc)
    naive = Datetime(2014, 5, 7)
    cursor.executemany("select CAST(%s AS TIMESTAMP)", ((aware,), (naive,)))
186 |
187 |
def test_executemany_setinputsizes(cursor):
    """Types set with ``setinputsizes`` apply to every parameter set."""
    cursor.execute(
        "CREATE TEMPORARY TABLE t1 (f1 int primary key, f2 inet[] not null) "
    )
    cursor.setinputsizes(INTEGER, INET_ARRAY)

    param_sets = ((1, ["1.1.1.1"]), (2, ["0.0.0.0"]))
    cursor.executemany("INSERT INTO t1 (f1, f2) VALUES (%s, %s)", param_sets)
199 |
200 |
def test_executemany_no_param_sets(cursor):
    """``executemany`` with no parameter sets leaves ``rowcount`` at -1."""
    no_param_sets = []
    cursor.executemany("INSERT INTO t1 (f1, f2) VALUES (%s, %s)", no_param_sets)
    assert cursor.rowcount == -1
204 |
205 |
def test_transactions(db_table):
    """Autocommit stays off, so a rolled-back INSERT leaves no rows.

    Whether a transaction is open is tracked using the server's
    READY_FOR_QUERY message.
    """
    cursor = db_table.cursor()
    # Commit first so the table created by the fixture survives the rollback.
    cursor.execute("commit")
    cursor.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (1, 1, "Zombie"))
    cursor.execute("rollback")

    cursor.execute("select * from t1")
    assert cursor.rowcount == 0
217 |
218 |
def test_in(cursor):
    """A Python list parameter can be used with ``= any(...)``."""
    oids = [16, 23]
    cursor.execute("SELECT typname FROM pg_type WHERE oid = any(%s)", (oids,))
    first_row = cursor.fetchall()[0]
    assert first_row[0] == "bool"
223 |
224 |
def test_no_previous_tpc(con):
    """A two-phase transaction can be begun, used and committed without prepare."""
    transaction_id = "Stacey"
    con.tpc_begin(transaction_id)
    con.cursor().execute("SELECT * FROM pg_type")
    con.tpc_commit()
230 |
231 |
def test_tpc_recover(con):
    """``tpc_recover()`` must not leave a transaction open.

    VACUUM cannot run inside a transaction block, so the final statement
    fails if ``tpc_recover()`` started one.
    """
    con.tpc_recover()
    cursor = con.cursor()
    con.autocommit = True
    cursor.execute("VACUUM")
240 |
241 |
def test_tpc_prepare(con):
    """Prepare a two-phase transaction, then roll it back by its xid."""
    transaction_id = "Stacey"
    con.tpc_begin(transaction_id)
    con.tpc_prepare()
    con.tpc_rollback(transaction_id)
247 |
248 |
def test_empty_query(cursor):
    """Executing an empty statement does not raise."""
    empty_statement = ""
    cursor.execute(empty_statement)
252 |
253 |
def test_rollback_no_transaction(con):
    """``rollback()`` outside a transaction triggers no server notice.

    A raw ROLLBACK statement in the same situation does trigger one.
    """
    con.notices.clear()  # discard notices left over from earlier activity

    # A raw ROLLBACK outside a transaction makes the server emit a notice.
    con.execute_unnamed("rollback")
    assert len(con.notices) == 1

    # 25P01 is the SQLSTATE for no_active_sql_transaction. The message and
    # severity fields may be localized or depend on the server version, so
    # only the code is checked.
    notice = con.notices.pop()
    assert notice.get(b"C") == b"25P01"

    # The rollback() method knows no transaction is open, so no notice results.
    con.rollback()
    assert len(con.notices) == 0
274 |
275 |
@pytest.mark.parametrize(("sizes", "oids"), [([0], [0]), ([float], [701])])
def test_setinputsizes(con, sizes, oids):
    """``setinputsizes`` records parameter OIDs, and a NULL parameter round-trips.

    701 is the OID of PostgreSQL's float8 type.
    """
    cursor = con.cursor()
    cursor.setinputsizes(*sizes)
    assert cursor._input_oids == oids

    cursor.execute("select %s", (None,))
    result = cursor.fetchall()
    assert result[0][0] is None
284 |
285 |
def test_unexecuted_cursor_rowcount(con):
    """A cursor that has never executed anything reports ``rowcount`` -1."""
    fresh_cursor = con.cursor()
    assert fresh_cursor.rowcount == -1
289 |
290 |
def test_unexecuted_cursor_description(con):
    """A cursor that has never executed anything has no ``description``."""
    fresh_cursor = con.cursor()
    assert fresh_cursor.description is None
294 |
295 |
def test_callproc(pg_version, cursor):
    """``callproc`` calls a procedure and returns its INOUT value as a row.

    CREATE PROCEDURE requires PostgreSQL 11 or later, so nothing is run on
    older servers.
    """
    if pg_version <= 10:
        return

    cursor.execute(
        "\n"
        "CREATE PROCEDURE echo(INOUT val text)\n"
        "LANGUAGE plpgsql AS\n"
        "$proc$\n"
        "BEGIN\n"
        "END\n"
        "$proc$;\n"
    )
    cursor.callproc("echo", ["hello"])
    assert cursor.fetchall() == (["hello"],)
311 |
312 |
def test_null_result(db_table):
    """Fetching after a statement with no result set raises ProgrammingError."""
    insert_cursor = db_table.cursor()
    insert_cursor.execute(
        "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (1, 1, "a")
    )
    with pytest.raises(pg8000.dbapi.ProgrammingError):
        insert_cursor.fetchall()
318 |
319 |
def test_not_parsed_if_no_params(mocker, cursor):
    """A query executed without parameters skips paramstyle conversion.

    ``convert_paramstyle`` is patched and must never be called.
    """
    convert_mock = mocker.patch("pg8000.dbapi.convert_paramstyle")
    cursor.execute("ROLLBACK")
    convert_mock.assert_not_called()
324 |
--------------------------------------------------------------------------------
/test/legacy/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tlocke/pg8000/46c00021ade1d19466b07ed30392386c5f0a6b8e/test/legacy/__init__.py
--------------------------------------------------------------------------------
/test/legacy/auth/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tlocke/pg8000/46c00021ade1d19466b07ed30392386c5f0a6b8e/test/legacy/auth/__init__.py
--------------------------------------------------------------------------------
/test/legacy/auth/test_gss.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from pg8000 import InterfaceError, connect
4 |
5 |
def test_gss(db_kwargs):
    """Connecting to a GSS-only database fails with an "unsupported" error.

    Requires a pg_hba.conf line that demands gss for the database
    pg8000_gss.
    """
    db_kwargs["database"] = "pg8000_gss"
    expected_message = "Authentication method 7 not supported by pg8000."
    with pytest.raises(InterfaceError, match=expected_message):
        connect(**db_kwargs)
18 |
--------------------------------------------------------------------------------
/test/legacy/auth/test_md5.py:
--------------------------------------------------------------------------------
def test_md5(con):
    """Smoke test run by GitHub Actions with the md5 auth method.

    Obtaining the ``con`` fixture is the entire test: it only has to connect.
    """
6 |
--------------------------------------------------------------------------------
/test/legacy/auth/test_md5_ssl.py:
--------------------------------------------------------------------------------
1 | import ssl
2 |
3 | from pg8000 import connect
4 |
5 |
def test_ssl(db_kwargs):
    """Connect over SSL with a context that skips certificate verification."""
    context = ssl.create_default_context()
    # Hostname checking has to be switched off before verify_mode can be
    # lowered to CERT_NONE.
    context.check_hostname = False
    context.verify_mode = ssl.CERT_NONE
    db_kwargs["ssl_context"] = context

    connection = connect(**db_kwargs)
    with connection:
        pass
15 |
--------------------------------------------------------------------------------
/test/legacy/auth/test_password.py:
--------------------------------------------------------------------------------
def test_password(con):
    """Smoke test run by GitHub Actions with the password auth method.

    Obtaining the ``con`` fixture is the entire test: it only has to connect.
    """
6 |
--------------------------------------------------------------------------------
/test/legacy/auth/test_scram-sha-256.py:
--------------------------------------------------------------------------------
def test_scram_sha_256(con):
    """Smoke test run by GitHub Actions with the scram-sha-256 auth method.

    Obtaining the ``con`` fixture is the entire test: it only has to connect.
    """
6 |
--------------------------------------------------------------------------------
/test/legacy/auth/test_scram-sha-256_ssl.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from pg8000 import DatabaseError, connect
4 |
# Database used by the scram-sha-256 tests. pg_hba.conf must contain a line
# that requires scram-sha-256 authentication for this database.
DB: str = "pg8000_scram_sha_256"
9 |
10 |
@pytest.fixture
def setup(con):
    """Create the scram-sha-256 test database if it does not already exist.

    DatabaseError (for example, because the database already exists) is
    deliberately ignored so that repeated runs do not fail here.
    """
    create_statement = f"CREATE DATABASE {DB}"
    try:
        con.run(create_statement)
    except DatabaseError:
        pass
17 |
18 |
def test_scram_sha_256_plus(setup, db_kwargs):
    """Connect to the scram-sha-256 database; connecting is the whole test."""
    db_kwargs.update(database=DB)
    connection = connect(**db_kwargs)
    with connection:
        pass
24 |
--------------------------------------------------------------------------------
/test/legacy/conftest.py:
--------------------------------------------------------------------------------
1 | from os import environ
2 |
3 | import pytest
4 |
5 | import pg8000
6 |
7 |
@pytest.fixture(scope="class")
def db_kwargs():
    """Build the ``pg8000.connect`` keyword arguments for the test database.

    Defaults to user "postgres" with password "pw". When set, PGHOST,
    PGPASSWORD and PGPORT (converted to int) supply host, password and port.
    """
    env_overrides = {
        "host": ("PGHOST", str),
        "password": ("PGPASSWORD", str),
        "port": ("PGPORT", int),
    }
    kwargs = {"user": "postgres", "password": "pw"}
    for key, (env_var, convert) in env_overrides.items():
        if env_var in environ:
            kwargs[key] = convert(environ[env_var])
    return kwargs
23 |
24 |
@pytest.fixture
def con(request, db_kwargs):
    """Open a pg8000 connection that is rolled back and closed at teardown."""
    conn = pg8000.connect(**db_kwargs)

    def cleanup():
        # Each step is best effort; InterfaceError is presumably raised when
        # the connection has already been closed by the test.
        for step in (conn.rollback, conn.close):
            try:
                step()
            except pg8000.InterfaceError:
                pass

    request.addfinalizer(cleanup)
    return conn
42 |
43 |
@pytest.fixture
def cursor(request, con):
    """Provide a cursor on ``con`` that is closed at teardown."""
    cur = con.cursor()
    request.addfinalizer(cur.close)
    return cur
53 |
--------------------------------------------------------------------------------
/test/legacy/stress.py:
--------------------------------------------------------------------------------
1 | from contextlib import closing
2 |
3 | import pg8000
4 | from pg8000.tests.connection_settings import db_connect
5 |
6 |
# Catalog query that looks similar to psql's type listing. It is executed
# repeatedly on one connection to put load on the driver.
# NOTE(review): the "kj" join condition never references kj, so every row is
# multiplied by the number of namespaces. This may be intentional extra load;
# confirm.
# NOTE(review): db_connect is imported from pg8000.tests.connection_settings,
# which does not appear in this repository layout; confirm before running.
_STRESS_QUERY = """
SELECT n.nspname as "Schema",
pg_catalog.format_type(t.oid, NULL) AS "Name",
pg_catalog.obj_description(t.oid, 'pg_type') as "Description"
FROM pg_catalog.pg_type t
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
left join pg_catalog.pg_namespace kj on n.oid = t.typnamespace
WHERE (t.typrelid = 0
OR (SELECT c.relkind = 'c'
FROM pg_catalog.pg_class c WHERE c.oid = t.typrelid))
AND NOT EXISTS(
SELECT 1 FROM pg_catalog.pg_type el
WHERE el.oid = t.typelem AND el.typarray = t.oid)
AND pg_catalog.pg_type_is_visible(t.oid)
ORDER BY 1, 2;"""

with closing(pg8000.connect(**db_connect)) as db:
    for _ in range(100):
        cursor = db.cursor()
        cursor.execute(_STRESS_QUERY)
27 |
--------------------------------------------------------------------------------
/test/legacy/test_benchmarks.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 |
@pytest.mark.parametrize(
    "txt",
    (
        # Every entry must be a single SQL expression string. This one was
        # previously a ("int2", ...) tuple, which rendered as a row of two
        # string literals instead of an int2 cast.
        "cast(id / 100 as int2)",
        "cast(id as int4)",
        "cast(id * 100 as int8)",
        "(id % 2) = 0",
        "N'Static text string'",
        "cast(id / 100 as float4)",
        "cast(id / 100 as float8)",
        "cast(id / 100 as numeric)",
        "timestamp '2001-09-28'",
    ),
)
def test_round_trips(con, benchmark, txt):
    """Benchmark fetching 10000 rows of seven columns computed from *txt*.

    *txt* is a SQL expression over the generated ``id`` column and is
    substituted into each of the seven output columns.
    """
    # Build the query once, so only the database round trip is timed.
    query = """SELECT {0} AS column1, {0} AS column2, {0} AS column3,
{0} AS column4, {0} AS column5, {0} AS column6, {0} AS column7
FROM (SELECT generate_series(1, 10000) AS id) AS tbl""".format(
        txt
    )

    def torun():
        con.run(query)

    benchmark(torun)
28 |
--------------------------------------------------------------------------------
/test/legacy/test_connection.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from pg8000 import DatabaseError, InterfaceError, ProgrammingError, connect
4 |
5 |
def testUnixSocketMissing():
    """Connecting through a Unix socket path that does not exist raises InterfaceError."""
    with pytest.raises(InterfaceError):
        connect(unix_sock="/file-does-not-exist", user="doesn't-matter")
11 |
12 |
def test_internet_socket_connection_refused():
    """A refused TCP connection raises InterfaceError naming host and port."""
    expected_message = (
        "Can't create a connection to host localhost and port 0 "
        "\\(timeout is None and source_address is None\\)."
    )
    with pytest.raises(InterfaceError, match=expected_message):
        connect(port=0, user="doesn't-matter")
22 |
23 |
24 | def testDatabaseMissing(db_kwargs):
25 | db_kwargs["database"] = "missing-db"
26 | with pytest.raises(ProgrammingError):
27 | connect(**db_kwargs)
28 |
29 |
30 | def test_notify(con):
31 | backend_pid = con.run("select pg_backend_pid()")[0][0]
32 | assert list(con.notifications) == []
33 | con.run("LISTEN test")
34 | con.run("NOTIFY test")
35 | con.commit()
36 |
37 | con.run("VALUES (1, 2), (3, 4), (5, 6)")
38 | assert len(con.notifications) == 1
39 | assert con.notifications[0] == (backend_pid, "test", "")
40 |
41 |
42 | def test_notify_with_payload(con):
43 | backend_pid = con.run("select pg_backend_pid()")[0][0]
44 | assert list(con.notifications) == []
45 | con.run("LISTEN test")
46 | con.run("NOTIFY test, 'Parnham'")
47 | con.commit()
48 |
49 | con.run("VALUES (1, 2), (3, 4), (5, 6)")
50 | assert len(con.notifications) == 1
51 | assert con.notifications[0] == (backend_pid, "test", "Parnham")
52 |
53 |
54 | # This requires a line in pg_hba.conf that requires md5 for the database
55 | # pg8000_md5
56 |
57 |
58 | def testMd5(db_kwargs):
59 | db_kwargs["database"] = "pg8000_md5"
60 |
61 | # Should only raise an exception saying db doesn't exist
62 | with pytest.raises(ProgrammingError, match="3D000"):
63 | connect(**db_kwargs)
64 |
65 |
66 | # This requires a line in pg_hba.conf that requires 'password' for the
67 | # database pg8000_password
68 |
69 |
70 | def testPassword(db_kwargs):
71 | db_kwargs["database"] = "pg8000_password"
72 |
73 | # Should only raise an exception saying db doesn't exist
74 | with pytest.raises(ProgrammingError, match="3D000"):
75 | connect(**db_kwargs)
76 |
77 |
78 | def testUnicodeDatabaseName(db_kwargs):
79 | db_kwargs["database"] = "pg8000_sn\uFF6Fw"
80 |
81 | # Should only raise an exception saying db doesn't exist
82 | with pytest.raises(ProgrammingError, match="3D000"):
83 | connect(**db_kwargs)
84 |
85 |
86 | def testBytesDatabaseName(db_kwargs):
87 | """Should only raise an exception saying db doesn't exist"""
88 |
89 | db_kwargs["database"] = bytes("pg8000_sn\uFF6Fw", "utf8")
90 | with pytest.raises(ProgrammingError, match="3D000"):
91 | connect(**db_kwargs)
92 |
93 |
94 | def testBytesPassword(con, db_kwargs):
95 | # Create user
96 | username = "boltzmann"
97 | password = "cha\uFF6Fs"
98 | with con.cursor() as cur:
99 | cur.execute("create user " + username + " with password '" + password + "';")
100 | con.commit()
101 |
102 | db_kwargs["user"] = username
103 | db_kwargs["password"] = password.encode("utf8")
104 | db_kwargs["database"] = "pg8000_md5"
105 | with pytest.raises(ProgrammingError, match="3D000"):
106 | connect(**db_kwargs)
107 |
108 | cur.execute("drop role " + username)
109 | con.commit()
110 |
111 |
112 | def test_broken_pipe_read(con, db_kwargs):
113 | db1 = connect(**db_kwargs)
114 | cur1 = db1.cursor()
115 | cur2 = con.cursor()
116 | cur1.execute("select pg_backend_pid()")
117 | pid1 = cur1.fetchone()[0]
118 |
119 | cur2.execute("select pg_terminate_backend(%s)", (pid1,))
120 | with pytest.raises(InterfaceError, match="network error"):
121 | cur1.execute("select 1")
122 |
123 | try:
124 | db1.close()
125 | except InterfaceError:
126 | pass
127 |
128 |
129 | def test_broken_pipe_flush(con, db_kwargs):
130 | db1 = connect(**db_kwargs)
131 | cur1 = db1.cursor()
132 | cur2 = con.cursor()
133 | cur1.execute("select pg_backend_pid()")
134 | pid1 = cur1.fetchone()[0]
135 |
136 | cur2.execute("select pg_terminate_backend(%s)", (pid1,))
137 | try:
138 | cur1.execute("select 1")
139 | except BaseException:
140 | pass
141 |
142 | # Can do an assert_raises when we're on 3.8 or above
143 | try:
144 | db1.close()
145 | except InterfaceError as e:
146 | assert str(e) == "network error"
147 |
148 |
149 | def test_broken_pipe_unpack(con):
150 | cur = con.cursor()
151 | cur.execute("select pg_backend_pid()")
152 | pid1 = cur.fetchone()[0]
153 |
154 | with pytest.raises(InterfaceError, match="network error"):
155 | cur.execute("select pg_terminate_backend(%s)", (pid1,))
156 |
157 |
158 | def testApplicatioName(db_kwargs):
159 | app_name = "my test application name"
160 | db_kwargs["application_name"] = app_name
161 | with connect(**db_kwargs) as db:
162 | cur = db.cursor()
163 | cur.execute(
164 | "select application_name from pg_stat_activity "
165 | " where pid = pg_backend_pid()"
166 | )
167 |
168 | application_name = cur.fetchone()[0]
169 | assert application_name == app_name
170 |
171 |
172 | def test_application_name_integer(db_kwargs):
173 | db_kwargs["application_name"] = 1
174 | with pytest.raises(
175 | InterfaceError,
176 | match="The parameter application_name can't be of type .",
177 | ):
178 | connect(**db_kwargs)
179 |
180 |
181 | def test_application_name_bytearray(db_kwargs):
182 | db_kwargs["application_name"] = bytearray(b"Philby")
183 | with connect(**db_kwargs):
184 | pass
185 |
186 |
187 | # This requires a line in pg_hba.conf that requires scram-sha-256 for the
188 | # database pg8000_scram_sha_256
189 |
190 | DB = "pg8000_scram_sha_256"
191 |
192 |
193 | @pytest.fixture
194 | def setup(con, cursor):
195 | con.autocommit = True
196 | try:
197 | cursor.execute(f"CREATE DATABASE {DB}")
198 | except DatabaseError:
199 | con.rollback()
200 |
201 |
202 | def test_scram_sha_256(setup, db_kwargs):
203 | db_kwargs["database"] = DB
204 |
205 | con = connect(**db_kwargs)
206 | con.close()
207 |
208 |
209 | @pytest.mark.parametrize(
210 | "commit",
211 | [
212 | "commit",
213 | "COMMIT;",
214 | ],
215 | )
216 | def test_failed_transaction_commit_sql(cursor, commit):
217 | cursor.execute("create temporary table tt (f1 int primary key)")
218 | cursor.execute("begin")
219 | try:
220 | cursor.execute("insert into tt(f1) values(null)")
221 | except DatabaseError:
222 | pass
223 |
224 | with pytest.raises(InterfaceError):
225 | cursor.execute(commit)
226 |
227 |
228 | def test_failed_transaction_commit_method(con, cursor):
229 | cursor.execute("create temporary table tt (f1 int primary key)")
230 | cursor.execute("begin")
231 | try:
232 | cursor.execute("insert into tt(f1) values(null)")
233 | except DatabaseError:
234 | pass
235 |
236 | with pytest.raises(InterfaceError):
237 | con.commit()
238 |
239 |
240 | @pytest.mark.parametrize(
241 | "rollback",
242 | [
243 | "rollback",
244 | "rollback;",
245 | "ROLLBACK ;",
246 | ],
247 | )
248 | def test_failed_transaction_rollback_sql(cursor, rollback):
249 | cursor.execute("create temporary table tt (f1 int primary key)")
250 | cursor.execute("begin")
251 | try:
252 | cursor.execute("insert into tt(f1) values(null)")
253 | except DatabaseError:
254 | pass
255 |
256 | cursor.execute(rollback)
257 |
258 |
259 | def test_failed_transaction_rollback_method(cursor, con):
260 | cursor.execute("create temporary table tt (f1 int primary key)")
261 | cursor.execute("begin")
262 | try:
263 | cursor.execute("insert into tt(f1) values(null)")
264 | except DatabaseError:
265 | pass
266 |
267 | con.rollback()
268 |
--------------------------------------------------------------------------------
/test/legacy/test_copy.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 |
3 | import pytest
4 |
5 |
6 | @pytest.fixture
7 | def db_table(request, con):
8 | with con.cursor() as cursor:
9 | cursor.execute(
10 | "CREATE TEMPORARY TABLE t1 (f1 int primary key, "
11 | "f2 int not null, f3 varchar(50) null) "
12 | "on commit drop"
13 | )
14 | return con
15 |
16 |
17 | def test_copy_to_with_table(db_table):
18 | with db_table.cursor() as cursor:
19 | cursor.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (1, 1, 1))
20 | cursor.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (2, 2, 2))
21 | cursor.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (3, 3, 3))
22 |
23 | stream = BytesIO()
24 | cursor.execute("copy t1 to stdout", stream=stream)
25 | assert stream.getvalue() == b"1\t1\t1\n2\t2\t2\n3\t3\t3\n"
26 | assert cursor.rowcount == 3
27 |
28 |
29 | def test_copy_to_with_query(db_table):
30 | with db_table.cursor() as cursor:
31 | stream = BytesIO()
32 | cursor.execute(
33 | "COPY (SELECT 1 as One, 2 as Two) TO STDOUT WITH DELIMITER "
34 | "'X' CSV HEADER QUOTE AS 'Y' FORCE QUOTE Two",
35 | stream=stream,
36 | )
37 | assert stream.getvalue() == b"oneXtwo\n1XY2Y\n"
38 | assert cursor.rowcount == 1
39 |
40 |
41 | def test_copy_from_with_table(db_table):
42 | with db_table.cursor() as cursor:
43 | stream = BytesIO(b"1\t1\t1\n2\t2\t2\n3\t3\t3\n")
44 | cursor.execute("copy t1 from STDIN", stream=stream)
45 | assert cursor.rowcount == 3
46 |
47 | cursor.execute("SELECT * FROM t1 ORDER BY f1")
48 | retval = cursor.fetchall()
49 | assert retval == ([1, 1, "1"], [2, 2, "2"], [3, 3, "3"])
50 |
51 |
52 | def test_copy_from_with_query(db_table):
53 | with db_table.cursor() as cursor:
54 | stream = BytesIO(b"f1Xf2\n1XY1Y\n")
55 | cursor.execute(
56 | "COPY t1 (f1, f2) FROM STDIN WITH DELIMITER 'X' CSV HEADER "
57 | "QUOTE AS 'Y' FORCE NOT NULL f1",
58 | stream=stream,
59 | )
60 | assert cursor.rowcount == 1
61 |
62 | cursor.execute("SELECT * FROM t1 ORDER BY f1")
63 | retval = cursor.fetchall()
64 | assert retval == ([1, 1, None],)
65 |
66 |
67 | def test_copy_from_with_error(db_table):
68 | with db_table.cursor() as cursor:
69 | stream = BytesIO(b"f1Xf2\n\n1XY1Y\n")
70 | with pytest.raises(BaseException) as e:
71 | cursor.execute(
72 | "COPY t1 (f1, f2) FROM STDIN WITH DELIMITER 'X' CSV HEADER "
73 | "QUOTE AS 'Y' FORCE NOT NULL f1",
74 | stream=stream,
75 | )
76 |
77 | arg = {
78 | "S": ("ERROR",),
79 | "C": ("22P02",),
80 | "M": (
81 | 'invalid input syntax for type integer: ""',
82 | 'invalid input syntax for integer: ""',
83 | ),
84 | "W": ('COPY t1, line 2, column f1: ""',),
85 | "F": ("numutils.c",),
86 | "R": ("pg_atoi", "pg_strtoint32", "pg_strtoint32_safe"),
87 | }
88 | earg = e.value.args[0]
89 | for k, v in arg.items():
90 | assert earg[k] in v
91 |
--------------------------------------------------------------------------------
/test/legacy/test_dbapi.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 | import time
4 |
5 | import pytest
6 |
7 | import pg8000
8 |
9 |
10 | @pytest.fixture
11 | def has_tzset():
12 | # Neither Windows nor Jython 2.5.3 have a time.tzset() so skip
13 | if hasattr(time, "tzset"):
14 | os.environ["TZ"] = "UTC"
15 | time.tzset()
16 | return True
17 | return False
18 |
19 |
20 | # DBAPI compatible interface tests
21 | @pytest.fixture
22 | def db_table(con, has_tzset):
23 | with con.cursor() as c:
24 | c.execute(
25 | "CREATE TEMPORARY TABLE t1 "
26 | "(f1 int primary key, f2 int not null, f3 varchar(50) null) "
27 | "ON COMMIT DROP"
28 | )
29 | c.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (1, 1, None))
30 | c.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (2, 10, None))
31 | c.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (3, 100, None))
32 | c.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (4, 1000, None))
33 | c.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (5, 10000, None))
34 | return con
35 |
36 |
37 | def test_parallel_queries(db_table):
38 | with db_table.cursor() as c1, db_table.cursor() as c2:
39 | c1.execute("SELECT f1, f2, f3 FROM t1")
40 | while 1:
41 | row = c1.fetchone()
42 | if row is None:
43 | break
44 | f1, f2, f3 = row
45 | c2.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > %s", (f1,))
46 | while 1:
47 | row = c2.fetchone()
48 | if row is None:
49 | break
50 | f1, f2, f3 = row
51 |
52 |
53 | def test_qmark(db_table):
54 | orig_paramstyle = pg8000.paramstyle
55 | try:
56 | pg8000.paramstyle = "qmark"
57 | with db_table.cursor() as c1:
58 | c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > ?", (3,))
59 | while 1:
60 | row = c1.fetchone()
61 | if row is None:
62 | break
63 | f1, f2, f3 = row
64 | finally:
65 | pg8000.paramstyle = orig_paramstyle
66 |
67 |
68 | def test_numeric(db_table):
69 | orig_paramstyle = pg8000.paramstyle
70 | try:
71 | pg8000.paramstyle = "numeric"
72 | with db_table.cursor() as c1:
73 | c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > :1", (3,))
74 | while 1:
75 | row = c1.fetchone()
76 | if row is None:
77 | break
78 | f1, f2, f3 = row
79 | finally:
80 | pg8000.paramstyle = orig_paramstyle
81 |
82 |
83 | def test_named(db_table):
84 | orig_paramstyle = pg8000.paramstyle
85 | try:
86 | pg8000.paramstyle = "named"
87 | with db_table.cursor() as c1:
88 | c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > :f1", {"f1": 3})
89 | while 1:
90 | row = c1.fetchone()
91 | if row is None:
92 | break
93 | f1, f2, f3 = row
94 | finally:
95 | pg8000.paramstyle = orig_paramstyle
96 |
97 |
98 | def test_format(db_table):
99 | orig_paramstyle = pg8000.paramstyle
100 | try:
101 | pg8000.paramstyle = "format"
102 | with db_table.cursor() as c1:
103 | c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > %s", (3,))
104 | while 1:
105 | row = c1.fetchone()
106 | if row is None:
107 | break
108 | f1, f2, f3 = row
109 | finally:
110 | pg8000.paramstyle = orig_paramstyle
111 |
112 |
113 | def test_pyformat(db_table):
114 | orig_paramstyle = pg8000.paramstyle
115 | try:
116 | pg8000.paramstyle = "pyformat"
117 | with db_table.cursor() as c1:
118 | c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > %(f1)s", {"f1": 3})
119 | while 1:
120 | row = c1.fetchone()
121 | if row is None:
122 | break
123 | f1, f2, f3 = row
124 | finally:
125 | pg8000.paramstyle = orig_paramstyle
126 |
127 |
128 | def test_arraysize(db_table):
129 | with db_table.cursor() as c1:
130 | c1.arraysize = 3
131 | c1.execute("SELECT * FROM t1")
132 | retval = c1.fetchmany()
133 | assert len(retval) == c1.arraysize
134 |
135 |
136 | def test_date():
137 | val = pg8000.Date(2001, 2, 3)
138 | assert val == datetime.date(2001, 2, 3)
139 |
140 |
141 | def test_time():
142 | val = pg8000.Time(4, 5, 6)
143 | assert val == datetime.time(4, 5, 6)
144 |
145 |
146 | def test_timestamp():
147 | val = pg8000.Timestamp(2001, 2, 3, 4, 5, 6)
148 | assert val == datetime.datetime(2001, 2, 3, 4, 5, 6)
149 |
150 |
151 | def test_date_from_ticks(has_tzset):
152 | if has_tzset:
153 | val = pg8000.DateFromTicks(1173804319)
154 | assert val == datetime.date(2007, 3, 13)
155 |
156 |
157 | def testTimeFromTicks(has_tzset):
158 | if has_tzset:
159 | val = pg8000.TimeFromTicks(1173804319)
160 | assert val == datetime.time(16, 45, 19)
161 |
162 |
163 | def test_timestamp_from_ticks(has_tzset):
164 | if has_tzset:
165 | val = pg8000.TimestampFromTicks(1173804319)
166 | assert val == datetime.datetime(2007, 3, 13, 16, 45, 19)
167 |
168 |
169 | def test_binary():
170 | v = pg8000.Binary(b"\x00\x01\x02\x03\x02\x01\x00")
171 | assert v == b"\x00\x01\x02\x03\x02\x01\x00"
172 | assert isinstance(v, pg8000.BINARY)
173 |
174 |
175 | def test_row_count(db_table):
176 | with db_table.cursor() as c1:
177 | c1.execute("SELECT * FROM t1")
178 |
179 | assert 5 == c1.rowcount
180 |
181 | c1.execute("UPDATE t1 SET f3 = %s WHERE f2 > 101", ("Hello!",))
182 | assert 2 == c1.rowcount
183 |
184 | c1.execute("DELETE FROM t1")
185 | assert 5 == c1.rowcount
186 |
187 |
188 | def test_fetch_many(db_table):
189 | with db_table.cursor() as cursor:
190 | cursor.arraysize = 2
191 | cursor.execute("SELECT * FROM t1")
192 | assert 2 == len(cursor.fetchmany())
193 | assert 2 == len(cursor.fetchmany())
194 | assert 1 == len(cursor.fetchmany())
195 | assert 0 == len(cursor.fetchmany())
196 |
197 |
198 | def test_iterator(db_table):
199 | with db_table.cursor() as cursor:
200 | cursor.execute("SELECT * FROM t1 ORDER BY f1")
201 | f1 = 0
202 | for row in cursor:
203 | next_f1 = row[0]
204 | assert next_f1 > f1
205 | f1 = next_f1
206 |
207 |
208 | # Vacuum can't be run inside a transaction, so we need to turn
209 | # autocommit on.
210 | def test_vacuum(con):
211 | con.autocommit = True
212 | with con.cursor() as cursor:
213 | cursor.execute("vacuum")
214 |
215 |
216 | def test_prepared_statement(con):
217 | with con.cursor() as cursor:
218 | cursor.execute("PREPARE gen_series AS SELECT generate_series(1, 10);")
219 | cursor.execute("EXECUTE gen_series")
220 |
221 |
222 | def test_cursor_type(cursor):
223 | assert str(type(cursor)) == ""
224 |
--------------------------------------------------------------------------------
/test/legacy/test_error_recovery.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import warnings
3 |
4 | import pytest
5 |
6 | import pg8000
7 |
8 |
9 | class PG8000TestException(Exception):
10 | pass
11 |
12 |
13 | def raise_exception(val):
14 | raise PG8000TestException("oh noes!")
15 |
16 |
17 | def test_py_value_fail(con, mocker):
18 | # Ensure that if types.py_value throws an exception, the original
19 | # exception is raised (PG8000TestException), and the connection is
20 | # still usable after the error.
21 | mocker.patch.object(con, "py_types")
22 | con.py_types = {datetime.time: raise_exception}
23 |
24 | with con.cursor() as c, pytest.raises(PG8000TestException):
25 | c.execute("SELECT CAST(%s AS TIME)", (datetime.time(10, 30),))
26 | c.fetchall()
27 |
28 | # ensure that the connection is still usable for a new query
29 | c.execute("VALUES ('hw3'::text)")
30 | assert c.fetchone()[0] == "hw3"
31 |
32 |
33 | def test_no_data_error_recovery(con):
34 | for i in range(1, 4):
35 | with con.cursor() as c, pytest.raises(pg8000.ProgrammingError) as e:
36 | c.execute("DROP TABLE t1")
37 | assert e.value.args[0]["C"] == "42P01"
38 | con.rollback()
39 |
40 |
41 | def testClosedConnection(db_kwargs):
42 | warnings.simplefilter("ignore")
43 | my_db = pg8000.connect(**db_kwargs)
44 | cursor = my_db.cursor()
45 | my_db.close()
46 | with pytest.raises(my_db.InterfaceError, match="connection is closed"):
47 | cursor.execute("VALUES ('hw1'::text)")
48 |
49 | warnings.resetwarnings()
50 |
--------------------------------------------------------------------------------
/test/legacy/test_paramstyle.py:
--------------------------------------------------------------------------------
1 | from pg8000.legacy import convert_paramstyle as convert
2 |
3 |
4 | # Tests of the convert_paramstyle function.
5 |
6 |
7 | def test_qmark():
8 | args = 1, 2, 3
9 | new_query, vals = convert(
10 | "qmark",
11 | 'SELECT ?, ?, "field_?" FROM t '
12 | "WHERE a='say ''what?''' AND b=? AND c=E'?\\'test\\'?'",
13 | args,
14 | )
15 | expected = (
16 | 'SELECT $1, $2, "field_?" FROM t WHERE '
17 | "a='say ''what?''' AND b=$3 AND c=E'?\\'test\\'?'"
18 | )
19 | assert (new_query, vals) == (expected, args)
20 |
21 |
22 | def test_qmark_2():
23 | args = 1, 2, 3
24 | new_query, vals = convert(
25 | "qmark", "SELECT ?, ?, * FROM t WHERE a=? AND b='are you ''sure?'", args
26 | )
27 | expected = "SELECT $1, $2, * FROM t WHERE a=$3 AND b='are you ''sure?'"
28 | assert (new_query, vals) == (expected, args)
29 |
30 |
31 | def test_numeric():
32 | args = 1, 2, 3
33 | new_query, vals = convert(
34 | "numeric", "SELECT sum(x)::decimal(5, 2) :2, :1, * FROM t WHERE a=:3", args
35 | )
36 | expected = "SELECT sum(x)::decimal(5, 2) $2, $1, * FROM t WHERE a=$3"
37 | assert (new_query, vals) == (expected, args)
38 |
39 |
40 | def test_numeric_default_parameter():
41 | args = 1, 2, 3
42 | new_query, vals = convert("numeric", "make_interval(days := 10)", args)
43 |
44 | assert (new_query, vals) == ("make_interval(days := 10)", args)
45 |
46 |
47 | def test_named():
48 | args = {
49 | "f_2": 1,
50 | "f1": 2,
51 | }
52 | new_query, vals = convert(
53 | "named", "SELECT sum(x)::decimal(5, 2) :f_2, :f1 FROM t WHERE a=:f_2", args
54 | )
55 | expected = "SELECT sum(x)::decimal(5, 2) $1, $2 FROM t WHERE a=$1"
56 | assert (new_query, vals) == (expected, (1, 2))
57 |
58 |
59 | def test_format():
60 | args = 1, 2, 3
61 | new_query, vals = convert(
62 | "format",
63 | "SELECT %s, %s, \"f1_%%\", E'txt_%%' "
64 | "FROM t WHERE a=%s AND b='75%%' AND c = '%' -- Comment with %",
65 | args,
66 | )
67 | expected = (
68 | "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$3 AND "
69 | "b='75%%' AND c = '%' -- Comment with %"
70 | )
71 | assert (new_query, vals) == (expected, args)
72 |
73 | sql = (
74 | r"""COMMENT ON TABLE test_schema.comment_test """
75 | r"""IS 'the test % '' " \ table comment'"""
76 | )
77 | new_query, vals = convert("format", sql, args)
78 | assert (new_query, vals) == (sql, args)
79 |
80 |
81 | def test_format_multiline():
82 | args = 1, 2, 3
83 | new_query, vals = convert("format", "SELECT -- Comment\n%s FROM t", args)
84 | assert (new_query, vals) == ("SELECT -- Comment\n$1 FROM t", args)
85 |
86 |
87 | def test_py_format():
88 | args = {"f2": 1, "f1": 2, "f3": 3}
89 |
90 | new_query, vals = convert(
91 | "pyformat",
92 | "SELECT %(f2)s, %(f1)s, \"f1_%%\", E'txt_%%' "
93 | "FROM t WHERE a=%(f2)s AND b='75%%'",
94 | args,
95 | )
96 | expected = "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$1 AND " "b='75%%'"
97 | assert (new_query, vals) == (expected, (1, 2))
98 |
99 | # pyformat should support %s and an array, too:
100 | args = 1, 2, 3
101 | new_query, vals = convert(
102 | "pyformat",
103 | "SELECT %s, %s, \"f1_%%\", E'txt_%%' " "FROM t WHERE a=%s AND b='75%%'",
104 | args,
105 | )
106 | expected = "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$3 AND " "b='75%%'"
107 | assert (new_query, vals) == (expected, args)
108 |
--------------------------------------------------------------------------------
/test/legacy/test_prepared_statement.py:
--------------------------------------------------------------------------------
1 | def test_prepare(con):
2 | con.prepare("SELECT CAST(:v AS INTEGER)")
3 |
4 |
5 | def test_run(con):
6 | ps = con.prepare("SELECT cast(:v as varchar)")
7 | ps.run(v="speedy")
8 |
9 |
10 | def test_run_with_no_results(con):
11 | ps = con.prepare("ROLLBACK")
12 | ps.run()
13 |
--------------------------------------------------------------------------------
/test/legacy/test_query.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime as Datetime, timezone as Timezone
2 | from warnings import filterwarnings
3 |
4 | import pytest
5 |
6 | import pg8000
7 | from pg8000.converters import INET_ARRAY, INTEGER
8 |
9 |
10 | # Tests relating to the basic operation of the database driver, driven by the
11 | # pg8000 custom interface.
12 |
13 |
14 | @pytest.fixture
15 | def db_table(request, con):
16 | filterwarnings("ignore", "DB-API extension cursor.next()")
17 | filterwarnings("ignore", "DB-API extension cursor.__iter__()")
18 | con.paramstyle = "format"
19 | with con.cursor() as cursor:
20 | cursor.execute(
21 | "CREATE TEMPORARY TABLE t1 (f1 int primary key, "
22 | "f2 bigint not null, f3 varchar(50) null) "
23 | )
24 |
25 | def fin():
26 | try:
27 | with con.cursor() as cursor:
28 | cursor.execute("drop table t1")
29 | except pg8000.ProgrammingError:
30 | pass
31 |
32 | request.addfinalizer(fin)
33 | return con
34 |
35 |
36 | def test_database_error(cursor):
37 | with pytest.raises(pg8000.ProgrammingError):
38 | cursor.execute("INSERT INTO t99 VALUES (1, 2, 3)")
39 |
40 |
41 | def test_parallel_queries(db_table):
42 | with db_table.cursor() as cursor:
43 | cursor.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (1, 1, None))
44 | cursor.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (2, 10, None))
45 | cursor.execute(
46 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (3, 100, None)
47 | )
48 | cursor.execute(
49 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (4, 1000, None)
50 | )
51 | cursor.execute(
52 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (5, 10000, None)
53 | )
54 | with db_table.cursor() as c1, db_table.cursor() as c2:
55 | c1.execute("SELECT f1, f2, f3 FROM t1")
56 | for row in c1:
57 | f1, f2, f3 = row
58 | c2.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > %s", (f1,))
59 | for row in c2:
60 | f1, f2, f3 = row
61 |
62 |
63 | def test_parallel_open_portals(con):
64 | with con.cursor() as c1, con.cursor() as c2:
65 | c1count, c2count = 0, 0
66 | q = "select * from generate_series(1, %s)"
67 | params = (100,)
68 | c1.execute(q, params)
69 | c2.execute(q, params)
70 | for c2row in c2:
71 | c2count += 1
72 | for c1row in c1:
73 | c1count += 1
74 |
75 | assert c1count == c2count
76 |
77 |
78 | # Run a query on a table, alter the structure of the table, then run the
79 | # original query again.
80 |
81 |
82 | def test_alter(db_table):
83 | with db_table.cursor() as cursor:
84 | cursor.execute("select * from t1")
85 | cursor.execute("alter table t1 drop column f3")
86 | cursor.execute("select * from t1")
87 |
88 |
89 | # Run a query on a table, drop then re-create the table, then run the
90 | # original query again.
91 |
92 |
93 | def test_create(db_table):
94 | with db_table.cursor() as cursor:
95 | cursor.execute("select * from t1")
96 | cursor.execute("drop table t1")
97 | cursor.execute("create temporary table t1 (f1 int primary key)")
98 | cursor.execute("select * from t1")
99 |
100 |
101 | def test_insert_returning(db_table):
102 | with db_table.cursor() as cursor:
103 | cursor.execute("CREATE TABLE t2 (id serial, data text)")
104 |
105 | # Test INSERT ... RETURNING with one row...
106 | cursor.execute("INSERT INTO t2 (data) VALUES (%s) RETURNING id", ("test1",))
107 | row_id = cursor.fetchone()[0]
108 | cursor.execute("SELECT data FROM t2 WHERE id = %s", (row_id,))
109 | assert "test1" == cursor.fetchone()[0]
110 |
111 | assert cursor.rowcount == 1
112 |
113 | # Test with multiple rows...
114 | cursor.execute(
115 | "INSERT INTO t2 (data) VALUES (%s), (%s), (%s) " "RETURNING id",
116 | ("test2", "test3", "test4"),
117 | )
118 | assert cursor.rowcount == 3
119 | ids = tuple([x[0] for x in cursor])
120 | assert len(ids) == 3
121 |
122 |
123 | def test_row_count(db_table):
124 | with db_table.cursor() as cursor:
125 | expected_count = 57
126 | cursor.executemany(
127 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)",
128 | tuple((i, i, None) for i in range(expected_count)),
129 | )
130 |
131 | # Check rowcount after executemany
132 | assert expected_count == cursor.rowcount
133 |
134 | cursor.execute("SELECT * FROM t1")
135 |
136 | # Check row_count without doing any reading first...
137 | assert expected_count == cursor.rowcount
138 |
139 | # Check rowcount after reading some rows, make sure it still
140 | # works...
141 | for i in range(expected_count // 2):
142 | cursor.fetchone()
143 | assert expected_count == cursor.rowcount
144 |
145 | with db_table.cursor() as cursor:
146 | # Restart the cursor, read a few rows, and then check rowcount
147 | # again...
148 | cursor.execute("SELECT * FROM t1")
149 | for i in range(expected_count // 3):
150 | cursor.fetchone()
151 | assert expected_count == cursor.rowcount
152 |
153 | # Should be -1 for a command with no results
154 | cursor.execute("DROP TABLE t1")
155 | assert -1 == cursor.rowcount
156 |
157 |
158 | def test_row_count_update(db_table):
159 | with db_table.cursor() as cursor:
160 | cursor.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (1, 1, None))
161 | cursor.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (2, 10, None))
162 | cursor.execute(
163 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (3, 100, None)
164 | )
165 | cursor.execute(
166 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (4, 1000, None)
167 | )
168 | cursor.execute(
169 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (5, 10000, None)
170 | )
171 | cursor.execute("UPDATE t1 SET f3 = %s WHERE f2 > 101", ("Hello!",))
172 | assert cursor.rowcount == 2
173 |
174 |
175 | def test_int_oid(cursor):
176 | # https://bugs.launchpad.net/pg8000/+bug/230796
177 | cursor.execute("SELECT typname FROM pg_type WHERE oid = %s", (100,))
178 |
179 |
180 | def test_unicode_query(cursor):
181 | cursor.execute(
182 | "CREATE TEMPORARY TABLE \u043c\u0435\u0441\u0442\u043e "
183 | "(\u0438\u043c\u044f VARCHAR(50), "
184 | "\u0430\u0434\u0440\u0435\u0441 VARCHAR(250))"
185 | )
186 |
187 |
188 | def test_executemany(db_table):
189 | with db_table.cursor() as cursor:
190 | cursor.executemany(
191 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)",
192 | ((1, 1, "Avast ye!"), (2, 1, None)),
193 | )
194 |
195 | cursor.executemany(
196 | "SELECT CAST(%s AS TIMESTAMP)",
197 | ((Datetime(2014, 5, 7, tzinfo=Timezone.utc),), (Datetime(2014, 5, 7),)),
198 | )
199 |
200 |
201 | def test_executemany_setinputsizes(cursor):
202 | """Make sure that setinputsizes works for all the parameter sets"""
203 |
204 | cursor.execute(
205 | "CREATE TEMPORARY TABLE t1 (f1 int primary key, f2 inet[] not null) "
206 | )
207 |
208 | cursor.setinputsizes(INTEGER, INET_ARRAY)
209 | cursor.executemany(
210 | "INSERT INTO t1 (f1, f2) VALUES (%s, %s)", ((1, ["1.1.1.1"]), (2, ["0.0.0.0"]))
211 | )
212 |
213 |
214 | def test_executemany_no_param_sets(cursor):
215 | cursor.executemany("INSERT INTO t1 (f1, f2) VALUES (%s, %s)", [])
216 | assert cursor.rowcount == -1
217 |
218 |
219 | # Check that autocommit stays off
220 | # We keep track of whether we're in a transaction or not by using the
221 | # READY_FOR_QUERY message.
222 | def test_transactions(db_table):
223 | with db_table.cursor() as cursor:
224 | cursor.execute("commit")
225 | cursor.execute(
226 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (1, 1, "Zombie")
227 | )
228 | cursor.execute("rollback")
229 | cursor.execute("select * from t1")
230 |
231 | assert cursor.rowcount == 0
232 |
233 |
234 | def test_in(cursor):
235 | cursor.execute("SELECT typname FROM pg_type WHERE oid = any(%s)", ([16, 23],))
236 | ret = cursor.fetchall()
237 | assert ret[0][0] == "bool"
238 |
239 |
240 | def test_no_previous_tpc(con):
241 | con.tpc_begin("Stacey")
242 | with con.cursor() as cursor:
243 | cursor.execute("SELECT * FROM pg_type")
244 | con.tpc_commit()
245 |
246 |
247 | # Check that tpc_recover() doesn't start a transaction
248 | def test_tpc_recover(con):
249 | con.tpc_recover()
250 | with con.cursor() as cursor:
251 | con.autocommit = True
252 |
253 | # If tpc_recover() has started a transaction, this will fail
254 | cursor.execute("VACUUM")
255 |
256 |
257 | def test_tpc_prepare(con):
258 | xid = "Stacey"
259 | con.tpc_begin(xid)
260 | con.tpc_prepare()
261 | con.tpc_rollback(xid)
262 |
263 |
264 | def test_empty_query(cursor):
265 | """No exception raised"""
266 | cursor.execute("")
267 |
268 |
269 | # rolling back when not in a transaction doesn't generate a warning
270 | def test_rollback_no_transaction(con):
271 | # Remove any existing notices
272 | con.notices.clear()
273 |
274 | # First, verify that a raw rollback does produce a notice
275 | con.execute_unnamed("rollback")
276 |
277 | assert 1 == len(con.notices)
278 |
279 | # 25P01 is the code for no_active_sql_tronsaction. It has
280 | # a message and severity name, but those might be
281 | # localized/depend on the server version.
282 | assert con.notices.pop().get(b"C") == b"25P01"
283 |
284 | # Now going through the rollback method doesn't produce
285 | # any notices because it knows we're not in a transaction.
286 | con.rollback()
287 |
288 | assert 0 == len(con.notices)
289 |
290 |
291 | def test_context_manager_class(con):
292 | assert "__enter__" in pg8000.legacy.Cursor.__dict__
293 | assert "__exit__" in pg8000.legacy.Cursor.__dict__
294 |
295 | with con.cursor() as cursor:
296 | cursor.execute("select 1")
297 |
298 |
299 | def test_close_prepared_statement(con):
300 | ps = con.prepare("select 1")
301 | ps.run()
302 | res = con.run("select count(*) from pg_prepared_statements")
303 | assert res[0][0] == 1 # Should have one prepared statement
304 |
305 | ps.close()
306 |
307 | res = con.run("select count(*) from pg_prepared_statements")
308 | assert res[0][0] == 0 # Should have no prepared statements
309 |
310 |
def _select_null_after_setinputsizes(con, size):
    """Call setinputsizes(size) on a fresh cursor, select a NULL param, return rows."""
    cur = con.cursor()
    cur.setinputsizes(size)
    cur.execute("select %s", (None,))
    return cur.fetchall()


def test_setinputsizes(con):
    """An integer input size does not disturb a NULL parameter."""
    rows = _select_null_after_setinputsizes(con, 20)
    assert rows[0][0] is None


def test_setinputsizes_class(con):
    """A Python type as input size does not disturb a NULL parameter."""
    rows = _select_null_after_setinputsizes(con, bytes)
    assert rows[0][0] is None
325 |
326 |
def test_unexecuted_cursor_rowcount(con):
    """A cursor that has executed nothing reports rowcount -1."""
    assert con.cursor().rowcount == -1


def test_unexecuted_cursor_description(con):
    """A cursor that has executed nothing has no description."""
    assert con.cursor().description is None


def test_not_parsed_if_no_params(mocker, cursor):
    """Paramstyle conversion is skipped when execute() gets no parameters."""
    converter = mocker.patch("pg8000.legacy.convert_paramstyle")
    cursor.execute("ROLLBACK")
    converter.assert_not_called()
341 |
--------------------------------------------------------------------------------
/test/legacy/test_typeobjects.py:
--------------------------------------------------------------------------------
1 | from pg8000 import PGInterval
2 |
3 |
def test_pginterval_constructor_days():
    """Passing only ``days`` leaves months and microseconds unset (None)."""
    interval = PGInterval(days=1)
    assert interval.months is None
    assert interval.days == 1
    assert interval.microseconds is None
9 |
--------------------------------------------------------------------------------
/test/native/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tlocke/pg8000/46c00021ade1d19466b07ed30392386c5f0a6b8e/test/native/__init__.py
--------------------------------------------------------------------------------
/test/native/auth/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tlocke/pg8000/46c00021ade1d19466b07ed30392386c5f0a6b8e/test/native/auth/__init__.py
--------------------------------------------------------------------------------
/test/native/auth/test_gss.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from pg8000.native import Connection, InterfaceError
4 |
5 |
def test_gss(db_kwargs):
    """Connecting to a GSS-authenticated database fails as unsupported.

    This requires a line in pg_hba.conf that requires gss for the database
    pg8000_gss.
    """
    db_kwargs["database"] = "pg8000_gss"
    expected_message = "Authentication method 7 not supported by pg8000."

    # Should raise an exception saying gss isn't supported
    with pytest.raises(InterfaceError, match=expected_message):
        Connection(**db_kwargs)
19 |
--------------------------------------------------------------------------------
/test/native/auth/test_md5.py:
--------------------------------------------------------------------------------
def test_md5(con):
    """Check that a connection can be made with md5 authentication.

    Run by GitHub Actions with the server's auth method set to md5; obtaining
    the ``con`` fixture is the whole test.
    """
6 |
--------------------------------------------------------------------------------
/test/native/auth/test_md5_ssl.py:
--------------------------------------------------------------------------------
1 | import ssl
2 |
3 | from pg8000.native import Connection
4 |
5 |
def test_ssl(db_kwargs):
    """Connect over SSL using a context that skips certificate verification."""
    ctx = ssl.create_default_context()
    # Hostname checking has to be switched off before verify_mode can be
    # set to CERT_NONE.
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE

    db_kwargs["ssl_context"] = ctx
    with Connection(**db_kwargs):
        pass
15 |
--------------------------------------------------------------------------------
/test/native/auth/test_password.py:
--------------------------------------------------------------------------------
def test_password(con):
    """Check that a connection can be made with 'password' authentication.

    Run by GitHub Actions with the server's auth method set to password;
    obtaining the ``con`` fixture is the whole test.
    """
6 |
--------------------------------------------------------------------------------
/test/native/auth/test_scram-sha-256.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from pg8000.native import Connection, DatabaseError, InterfaceError
4 |
# This requires a line in pg_hba.conf that requires scram-sha-256 for the
# database pg8000_scram_sha_256

DB = "pg8000_scram_sha_256"


@pytest.fixture
def setup(con):
    """Ensure database DB exists and run the test with server SSL switched off.

    SSL is disabled with ALTER SYSTEM plus a configuration reload before the
    test, and re-enabled the same way afterwards.
    """
    try:
        con.run(f"CREATE DATABASE {DB}")
    except DatabaseError:
        pass  # presumably the database already exists

    def apply_ssl_setting(value):
        con.run(f"ALTER SYSTEM SET ssl = {value}")
        con.run("SELECT pg_reload_conf()")

    apply_ssl_setting("off")
    yield
    apply_ssl_setting("on")


def test_scram_sha_256(setup, db_kwargs):
    """A default connection to the scram-sha-256 database succeeds."""
    db_kwargs["database"] = DB

    with Connection(**db_kwargs):
        pass


def test_scram_sha_256_ssl_False(setup, db_kwargs):
    """A connection with SSL explicitly disabled succeeds."""
    db_kwargs.update(database=DB, ssl_context=False)

    with Connection(**db_kwargs):
        pass


def test_scram_sha_256_ssl_True(setup, db_kwargs):
    """Demanding SSL fails because the fixture switched server SSL off."""
    db_kwargs.update(database=DB, ssl_context=True)

    with pytest.raises(InterfaceError, match="Server refuses SSL"):
        with Connection(**db_kwargs):
            pass
46 |
--------------------------------------------------------------------------------
/test/native/auth/test_scram-sha-256_ssl.py:
--------------------------------------------------------------------------------
1 | from ssl import CERT_NONE, SSLSocket, create_default_context
2 |
3 | import pytest
4 |
5 | from pg8000.native import Connection, DatabaseError
6 |
# This requires a line in pg_hba.conf that requires scram-sha-256 for the
# database pg8000_scram_sha_256

DB = "pg8000_scram_sha_256"


@pytest.fixture
def setup(con):
    """Create database DB, ignoring a DatabaseError if creation fails.

    Presumably that failure means the database already exists.
    """
    try:
        con.run(f"CREATE DATABASE {DB}")
    except DatabaseError:
        pass


def test_scram_sha_256_plus(setup, db_kwargs):
    """With default settings the connection is made over SSL."""
    db_kwargs["database"] = DB

    with Connection(**db_kwargs) as conn:
        assert isinstance(conn._usock, SSLSocket)


def test_scram_sha_256_plus_ssl_True(setup, db_kwargs):
    """ssl_context=True makes the connection use SSL."""
    db_kwargs.update(ssl_context=True, database=DB)

    with Connection(**db_kwargs) as conn:
        assert isinstance(conn._usock, SSLSocket)


def test_scram_sha_256_plus_ssl_custom(setup, db_kwargs):
    """A caller-built SSLContext without certificate checks is honoured."""
    custom = create_default_context()
    # check_hostname must be disabled before verify_mode can be CERT_NONE.
    custom.check_hostname = False
    custom.verify_mode = CERT_NONE

    db_kwargs.update(ssl_context=custom, database=DB)

    with Connection(**db_kwargs) as conn:
        assert isinstance(conn._usock, SSLSocket)


def test_scram_sha_256_plus_ssl_False(setup, db_kwargs):
    """ssl_context=False gives a plain, non-SSL socket."""
    db_kwargs.update(ssl_context=False, database=DB)

    with Connection(**db_kwargs) as conn:
        assert not isinstance(conn._usock, SSLSocket)
54 |
--------------------------------------------------------------------------------
/test/native/conftest.py:
--------------------------------------------------------------------------------
1 | from os import environ
2 |
3 | import pytest
4 |
5 | from test.utils import parse_server_version
6 |
7 | import pg8000.native
8 |
9 |
@pytest.fixture(scope="class")
def db_kwargs():
    """Keyword arguments for Connection, overridable via PG* environment vars.

    Defaults to user ``postgres`` with password ``pw``. PGHOST and PGPASSWORD
    (as str) and PGPORT (as int) override host, password and port when set.

    NOTE(review): tests mutate the returned dict in place (e.g. setting
    ``database``); with scope="class", tests grouped inside a class would
    share those mutations -- confirm this is intended.
    """
    settings = {"user": "postgres", "password": "pw"}
    overrides = (
        ("host", "PGHOST", str),
        ("password", "PGPASSWORD", str),
        ("port", "PGPORT", int),
    )
    for key, env_name, convert in overrides:
        if env_name in environ:
            settings[key] = convert(environ[env_name])

    return settings


@pytest.fixture
def con(request, db_kwargs):
    """A native Connection that is rolled back and closed at teardown.

    InterfaceError during either teardown step is ignored, e.g. when the
    test has already closed or broken the connection.
    """
    connection = pg8000.native.Connection(**db_kwargs)

    def teardown():
        try:
            connection.run("rollback")
        except pg8000.native.InterfaceError:
            pass

        try:
            connection.close()
        except pg8000.native.InterfaceError:
            pass

    request.addfinalizer(teardown)
    return connection


@pytest.fixture
def pg_version(con):
    """Server version parsed by parse_server_version (presumably the major)."""
    rows = con.run("select current_setting('server_version')")
    return parse_server_version(rows[0][0])
52 |
--------------------------------------------------------------------------------
/test/native/test_benchmarks.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 |
@pytest.mark.parametrize(
    "txt",
    (
        # Each entry is a single SQL expression string. (This first entry
        # used to be a ("int2", ...) tuple, whose repr was formatted into
        # the SQL and so benchmarked a row of string literals, not int2.)
        "cast(id / 100 as int2)",
        "cast(id as int4)",
        "cast(id * 100 as int8)",
        "(id % 2) = 0",
        "N'Static text string'",
        "cast(id / 100 as float4)",
        "cast(id / 100 as float8)",
        "cast(id / 100 as numeric)",
        "timestamp '2001-09-28'",
    ),
)
def test_round_trips(con, benchmark, txt):
    """Benchmark fetching 10,000 rows of seven copies of one SQL expression.

    ``txt`` is an SQL expression over the generated ``id`` column. It is
    repeated in every selected column, so each parameter measures the cost
    of receiving and converting that expression's result type.
    """
    # Build the query once so the timed callable only measures the round trip.
    query = """SELECT {0} AS column1, {0} AS column2, {0} AS column3,
        {0} AS column4, {0} AS column5, {0} AS column6, {0} AS column7
        FROM (SELECT generate_series(1, 10000) AS id) AS tbl""".format(
        txt
    )

    def torun():
        con.run(query)

    benchmark(torun)
28 |
--------------------------------------------------------------------------------
/test/native/test_connection.py:
--------------------------------------------------------------------------------
1 | import socket
2 |
3 | from datetime import time as Time
4 |
5 | import pytest
6 |
7 | from pg8000.native import Connection, DatabaseError, InterfaceError, __version__
8 |
9 |
def test_unix_socket_missing():
    """Connecting through a non-existent unix socket raises InterfaceError."""
    with pytest.raises(InterfaceError):
        Connection(unix_sock="/file-does-not-exist", user="doesn't-matter")


def test_internet_socket_connection_refused():
    """A refused TCP connection reports host, port, timeout and source."""
    expected = (
        "Can't create a connection to host localhost and port 0 "
        "\\(timeout is None and source_address is None\\)."
    )
    with pytest.raises(InterfaceError, match=expected):
        Connection(port=0, user="doesn't-matter")


def test_Connection_plain_socket(db_kwargs):
    """A caller-supplied, already-connected TCP socket can be used."""
    address = (db_kwargs.get("host", "localhost"), db_kwargs.get("port", 5432))
    with socket.create_connection(address) as sock:
        with Connection(
            sock=sock,
            user=db_kwargs["user"],
            password=db_kwargs["password"],
            ssl_context=False,
        ) as con:
            assert con.run("SELECT 1")[0][0] == 1


def test_database_missing(db_kwargs):
    """Connecting to a database that does not exist raises DatabaseError."""
    db_kwargs["database"] = "missing-db"
    with pytest.raises(DatabaseError):
        Connection(**db_kwargs)
48 |
49 |
def _check_single_notification(con, notify_sql, expected_payload):
    """LISTEN on channel ``test``, run ``notify_sql`` and check one delivery."""
    pid = con.run("select pg_backend_pid()")[0][0]
    assert list(con.notifications) == []
    con.run("LISTEN test")
    con.run(notify_sql)

    # Presumably any further query lets the connection pick up the pending
    # notification before it is inspected.
    con.run("VALUES (1, 2), (3, 4), (5, 6)")
    assert len(con.notifications) == 1
    assert con.notifications[0] == (pid, "test", expected_payload)


def test_notify(con):
    """A NOTIFY without payload is recorded with an empty payload."""
    _check_single_notification(con, "NOTIFY test", "")


def test_notify_with_payload(con):
    """A NOTIFY payload is recorded alongside the sender pid and channel."""
    _check_single_notification(con, "NOTIFY test, 'Parnham'", "Parnham")
70 |
71 |
def _assert_database_missing(db_kwargs, database):
    """Connect to ``database`` and expect SQLSTATE 3D000 (db doesn't exist)."""
    db_kwargs["database"] = database
    with pytest.raises(DatabaseError, match="3D000"):
        Connection(**db_kwargs)


# This requires a line in pg_hba.conf that requires md5 for the database
# pg8000_md5


def test_md5(db_kwargs):
    """md5 authentication succeeds; only the missing database is reported."""
    _assert_database_missing(db_kwargs, "pg8000_md5")


# This requires a line in pg_hba.conf that requires 'password' for the
# database pg8000_password


def test_password(db_kwargs):
    """'password' auth succeeds; only the missing database is reported."""
    _assert_database_missing(db_kwargs, "pg8000_password")


def test_unicode_databaseName(db_kwargs):
    """A non-ASCII str database name reaches the server intact."""
    _assert_database_missing(db_kwargs, "pg8000_sn\uFF6Fw")


def test_bytes_databaseName(db_kwargs):
    """A UTF-8 bytes database name reaches the server intact."""
    _assert_database_missing(db_kwargs, "pg8000_sn\uFF6Fw".encode("utf8"))
110 |
111 |
def test_bytes_password(con, db_kwargs):
    """A password given as UTF-8 bytes authenticates successfully.

    A temporary role with a non-ASCII password is created, and a connection
    as that role is made to the md5 test database. Authentication succeeding
    is shown by the only error being SQLSTATE 3D000 (database doesn't exist).
    The role is always dropped again, even if the assertion fails, so a
    failed run cannot break later runs on ``create user``.
    """
    username = "boltzmann"
    password = "cha\uFF6Fs"
    # Both values are fixed test constants, not external input.
    con.run(f"create user {username} with password '{password}';")
    try:
        db_kwargs["user"] = username
        db_kwargs["password"] = password.encode("utf8")
        db_kwargs["database"] = "pg8000_md5"
        with pytest.raises(DatabaseError, match="3D000"):
            Connection(**db_kwargs)
    finally:
        con.run("drop role " + username)
125 |
126 |
def test_broken_pipe_read(con, db_kwargs):
    """Querying a connection whose backend was terminated is a network error."""
    victim = Connection(**db_kwargs)
    victim_pid = victim.run("select pg_backend_pid()")[0][0]

    con.run("select pg_terminate_backend(:v)", v=victim_pid)
    with pytest.raises(InterfaceError, match="network error"):
        victim.run("select 1")

    try:
        victim.close()
    except InterfaceError:
        pass


def test_broken_pipe_unpack(con):
    """Terminating the connection's own backend surfaces as a network error."""
    own_pid = con.run("select pg_backend_pid()")[0][0]

    with pytest.raises(InterfaceError, match="network error"):
        con.run("select pg_terminate_backend(:v)", v=own_pid)
148 |
149 |
def test_broken_pipe_flush(con, db_kwargs):
    """Closing a connection whose backend was terminated fails cleanly.

    If close() raises at all, it must be an InterfaceError reading
    "network error".
    """
    db1 = Connection(**db_kwargs)
    res = db1.run("select pg_backend_pid()")
    pid1 = res[0][0]

    con.run("select pg_terminate_backend(:v)", v=pid1)
    try:
        db1.run("select 1")
    except Exception:
        # Failure is expected here (the backend is gone) and its type is not
        # what this test checks. Catch Exception rather than BaseException
        # so KeyboardInterrupt/SystemExit still propagate.
        pass

    # Sometimes raises and sometimes doesn't
    try:
        db1.close()
    except InterfaceError as e:
        assert str(e) == "network error"
166 |
167 |
def test_application_name(db_kwargs):
    """The application_name connection parameter is visible to the server."""
    wanted = "my test application name"
    db_kwargs["application_name"] = wanted
    with Connection(**db_kwargs) as db:
        rows = db.run(
            "select application_name from pg_stat_activity "
            " where pid = pg_backend_pid()"
        )

        assert rows[0][0] == wanted


def test_application_name_integer(db_kwargs):
    """A non-string application_name is rejected with InterfaceError."""
    db_kwargs["application_name"] = 1
    # NOTE(review): the trailing "." is a regex wildcard, so the reported
    # type name itself is not checked.
    with pytest.raises(
        InterfaceError,
        match="The parameter application_name can't be of type .",
    ):
        Connection(**db_kwargs)


def test_application_name_bytearray(db_kwargs):
    """A bytearray application_name is accepted."""
    db_kwargs["application_name"] = bytearray(b"Philby")
    with Connection(**db_kwargs):
        pass
194 |
195 |
class PG8000TestException(Exception):
    """Marker exception raised by the deliberately failing converter below."""


def raise_exception(val):
    """Stand-in converter that raises PG8000TestException for any ``val``."""
    raise PG8000TestException("oh noes!")


def test_py_value_fail(con, mocker):
    """A failing Python-value converter's own exception propagates unchanged.

    The connection must remain usable for a new query afterwards.
    """
    # Patch first so the original py_types is restored after the test.
    mocker.patch.object(con, "py_types")
    con.py_types = {Time: raise_exception}

    with pytest.raises(PG8000TestException):
        con.run("SELECT CAST(:v AS TIME)", v=Time(10, 30))

    # ensure that the connection is still usable for a new query
    assert con.run("VALUES ('hw3'::text)")[0][0] == "hw3"
217 |
218 |
def test_no_data_error_recovery(con):
    """Each of several failing DROPs reports 42P01 and can be rolled back."""
    for _ in range(3):
        with pytest.raises(DatabaseError) as excinfo:
            con.run("DROP TABLE t1")
        assert excinfo.value.args[0]["C"] == "42P01"
        con.run("ROLLBACK")


def test_closed_connection(con):
    """Running a query on a closed connection raises InterfaceError."""
    con.close()
    with pytest.raises(InterfaceError, match="connection is closed"):
        con.run("VALUES ('hw1'::text)")


def test_version():
    """__version__ matches the installed pg8000 distribution metadata."""
    try:
        from importlib.metadata import version
    except ImportError:  # Python older than 3.8 needs the backport
        from importlib_metadata import version

    assert __version__ == version("pg8000")
242 |
243 |
def _fail_current_transaction(con):
    """Insert NULL into primary key tt.f1, putting the open transaction into
    the failed state; the resulting DatabaseError is swallowed."""
    try:
        con.run("insert into tt(f1) values(null)")
    except DatabaseError:
        pass


@pytest.mark.parametrize("commit", ["commit", "COMMIT;"])
def test_failed_transaction_commit(con, commit):
    """Committing a failed transaction raises InterfaceError."""
    con.run("create temporary table tt (f1 int primary key)")
    con.run("begin")
    _fail_current_transaction(con)

    with pytest.raises(InterfaceError):
        con.run(commit)


@pytest.mark.parametrize("rollback", ["rollback", "rollback;", "ROLLBACK ;"])
def test_failed_transaction_rollback(con, rollback):
    """Rolling back a failed transaction is allowed."""
    con.run("create temporary table tt (f1 int primary key)")
    con.run("begin")
    _fail_current_transaction(con)

    con.run(rollback)


@pytest.mark.parametrize(
    "rollback", ["rollback to sp", "rollback to sp;", "ROLLBACK TO sp ;"]
)
def test_failed_transaction_rollback_to_savepoint(con, rollback):
    """Rolling back to a savepoint taken before the failure is allowed."""
    con.run("create temporary table tt (f1 int primary key)")
    con.run("begin")
    con.run("SAVEPOINT sp;")
    _fail_current_transaction(con)

    con.run(rollback)


@pytest.mark.parametrize("sql", ["BEGIN", "select * from tt;"])
def test_failed_transaction_sql(con, sql):
    """Other statements in a failed transaction raise DatabaseError."""
    con.run("create temporary table tt (f1 int primary key)")
    con.run("begin")
    _fail_current_transaction(con)

    with pytest.raises(DatabaseError):
        con.run(sql)
320 |
321 |
def test_parameter_statuses(con):
    """con.parameter_statuses reflects a changed session_authorization."""
    role = "Æthelred"
    try:
        con.run(f"create role {role}")
    except DatabaseError:
        pass  # presumably the role is left over from an earlier run

    con.run(f"set session authorization '{role}'")
    assert role == con.parameter_statuses["session_authorization"]
330 |
--------------------------------------------------------------------------------
/test/native/test_copy.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO, StringIO
2 |
3 | import pytest
4 |
5 |
@pytest.fixture
def db_table(request, con):
    """Start a transaction and create temporary table t1 inside it.

    The table is declared ``on commit drop``, so it is dropped when the
    transaction commits.
    """
    con.run("START TRANSACTION")
    con.run(
        "CREATE TEMPORARY TABLE t1 "
        "(f1 int primary key, f2 int not null, f3 varchar(50) null) "
        "on commit drop"
    )
    return con


# Tab-separated text form of rows (1, 1, '1'), (2, 2, '2'), (3, 3, '3').
_TAB_ROWS = "1\t1\t1\n2\t2\t2\n3\t3\t3\n"
_ROWS_1_TO_3 = [[1, 1, "1"], [2, 2, "2"], [3, 3, "3"]]

# COPY of a two-column query to STDOUT as CSV with custom delimiter and quote.
_COPY_QUERY_TO_STDOUT = (
    "COPY (SELECT 1 as One, 2 as Two) TO STDOUT WITH DELIMITER "
    "'X' CSV HEADER QUOTE AS 'Y' FORCE QUOTE Two"
)


def test_copy_to_with_table(db_table):
    """COPY a table TO STDOUT into a binary stream."""
    for n in (1, 2, 3):
        db_table.run(
            "INSERT INTO t1 (f1, f2, f3) VALUES (:v1, :v1, :v2)", v1=n, v2=str(n)
        )

    out = BytesIO()
    db_table.run("copy t1 to stdout", stream=out)
    assert out.getvalue() == _TAB_ROWS.encode()
    assert db_table.row_count == 3


def test_copy_to_with_query(con):
    """COPY a query result TO STDOUT into a binary stream."""
    out = BytesIO()
    con.run(_COPY_QUERY_TO_STDOUT, stream=out)
    assert out.getvalue() == b"oneXtwo\n1XY2Y\n"
    assert con.row_count == 1


def test_copy_to_with_text_stream(con):
    """COPY TO STDOUT also works with a text stream."""
    out = StringIO()
    con.run(_COPY_QUERY_TO_STDOUT, stream=out)
    assert out.getvalue() == "oneXtwo\n1XY2Y\n"
    assert con.row_count == 1


def test_copy_from_with_table(db_table):
    """COPY FROM STDIN reads rows from a binary stream."""
    db_table.run("copy t1 from STDIN", stream=BytesIO(_TAB_ROWS.encode()))
    assert db_table.row_count == 3

    assert db_table.run("SELECT * FROM t1 ORDER BY f1") == _ROWS_1_TO_3


def test_copy_from_with_text_stream(db_table):
    """COPY FROM STDIN reads rows from a text stream."""
    db_table.run("copy t1 from STDIN", stream=StringIO(_TAB_ROWS))

    assert db_table.run("SELECT * FROM t1 ORDER BY f1") == _ROWS_1_TO_3


def test_copy_from_with_query(db_table):
    """COPY FROM STDIN with CSV options into a subset of columns."""
    src = BytesIO(b"f1Xf2\n1XY1Y\n")
    db_table.run(
        "COPY t1 (f1, f2) FROM STDIN WITH DELIMITER 'X' CSV HEADER "
        "QUOTE AS 'Y' FORCE NOT NULL f1",
        stream=src,
    )
    assert db_table.row_count == 1

    assert db_table.run("SELECT * FROM t1 ORDER BY f1") == [[1, 1, None]]
78 |
79 |
def test_copy_from_with_error(db_table):
    """Malformed COPY FROM STDIN input raises a server DatabaseError.

    The blank second line gives column f1 an empty string (FORCE NOT NULL f1
    stops it being NULL). That is not a valid integer, so the server rejects
    the COPY with SQLSTATE 22P02 (invalid_text_representation). Several
    message/routine variants are accepted, presumably because they differ
    between PostgreSQL versions.
    """
    # Local import: this module's top-level imports don't include it.
    from pg8000.native import DatabaseError

    stream = BytesIO(b"f1Xf2\n\n1XY1Y\n")
    # Expect the specific error type. pytest.raises(BaseException) would also
    # pass on unrelated failures such as a TypeError inside the driver.
    with pytest.raises(DatabaseError) as e:
        db_table.run(
            "COPY t1 (f1, f2) FROM STDIN WITH DELIMITER 'X' CSV HEADER "
            "QUOTE AS 'Y' FORCE NOT NULL f1",
            stream=stream,
        )

    arg = {
        "S": ("ERROR",),
        "C": ("22P02",),
        "M": (
            'invalid input syntax for type integer: ""',
            'invalid input syntax for integer: ""',
        ),
        "W": ('COPY t1, line 2, column f1: ""',),
        "F": ("numutils.c",),
        "R": ("pg_atoi", "pg_strtoint32", "pg_strtoint32_safe"),
    }
    earg = e.value.args[0]
    for k, v in arg.items():
        assert earg[k] in v
103 |
104 |
def test_copy_from_with_text_iterable(db_table):
    """COPY FROM STDIN accepts a plain iterable of text lines."""
    lines = [f"{n}\t{n}\t{n}\n" for n in (1, 2, 3)]
    db_table.run("copy t1 from STDIN", stream=lines)

    rows = db_table.run("SELECT * FROM t1 ORDER BY f1")
    assert rows == [[1, 1, "1"], [2, 2, "2"], [3, 3, "3"]]
111 |
--------------------------------------------------------------------------------
/test/native/test_core.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 |
3 | import pytest
4 |
5 | from pg8000.core import (
6 | Context,
7 | CoreConnection,
8 | NULL_BYTE,
9 | PASSWORD,
10 | _create_message,
11 | _make_socket,
12 | _read,
13 | )
14 | from pg8000.native import InterfaceError
15 |
16 |
def test_make_socket(mocker):
    """Smoke test: _make_socket runs with a mock socket and no SSL context."""
    positional_args = (
        None,  # unix_sock
        mocker.Mock(),  # sock
        "localhost",  # host
        5432,  # port
        None,  # timeout
        None,  # source_address
        True,  # tcp_keepalive
        None,  # ssl_context
    )
    _make_socket(*positional_args)


def test_handle_AUTHENTICATION_3(mocker):
    """A cleartext-password request writes only the password message.

    No FLUSH is sent, because FLUSH is only used in the extended-query
    protocol.
    """
    mocker.patch.object(CoreConnection, "__init__", lambda x: None)
    conn = CoreConnection()
    secret = "barbour".encode("utf8")
    conn.password = secret
    conn._sock = mocker.Mock()
    sent = BytesIO()
    conn._sock.write = sent.write

    # Authentication request code 3 asks for a cleartext password.
    CoreConnection.handle_AUTHENTICATION_REQUEST(conn, b"\x00\x00\x00\x03", None)
    assert sent.getvalue() == _create_message(PASSWORD, secret + NULL_BYTE)


def test_create_message():
    """The message is the code byte, a length that counts itself, then data."""
    payload = "barbour".encode("utf8") + NULL_BYTE
    assert _create_message(PASSWORD, payload) == b"p\x00\x00\x00\x0cbarbour\x00"


def test_handle_ERROR_RESPONSE(mocker):
    """Invalid encodings in error-response fields are decoded leniently."""
    mocker.patch.object(CoreConnection, "__init__", lambda x: None)
    conn = CoreConnection()
    conn._client_encoding = "utf8"
    # b"\xc2" starts a two-byte UTF-8 sequence that is never completed.
    raw = b"S\xc2err" + NULL_BYTE + NULL_BYTE
    ctx = Context(None)

    CoreConnection.handle_ERROR_RESPONSE(conn, raw, ctx)
    # The undecodable byte shows up as U+FFFD.
    assert str(ctx.error) == "{'S': '�err'}"


def test_read(mocker):
    """A socket read that returns no bytes is reported as a network error."""
    sock = mocker.Mock()
    sock.read = mocker.Mock(return_value=b"")
    with pytest.raises(InterfaceError, match="network error"):
        _read(sock, 5)
67 |
--------------------------------------------------------------------------------
/test/native/test_import_all.py:
--------------------------------------------------------------------------------
1 | from pg8000.native import * # noqa: F403
2 |
3 |
def test_import_all():
    """``from pg8000.native import *`` brings Connection into scope."""
    type(Connection)  # noqa: F405
6 |
--------------------------------------------------------------------------------
/test/native/test_prepared_statement.py:
--------------------------------------------------------------------------------
def test_prepare(con):
    """A statement with a named parameter can be prepared."""
    con.prepare("SELECT CAST(:v AS INTEGER)")


def test_run(con):
    """A prepared statement can be run with keyword parameters."""
    statement = con.prepare("SELECT cast(:v as varchar)")
    statement.run(v="speedy")
8 |
--------------------------------------------------------------------------------
/test/native/test_query.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from pg8000.native import DatabaseError, to_statement
4 |
5 |
# Tests relating to the basic operation of the database driver, driven by the
# pg8000 custom interface.


@pytest.fixture
def db_table(request, con):
    """Create temporary table t1 and try to drop it again at teardown.

    A DatabaseError from the drop is ignored, e.g. when the test itself has
    already dropped t1.
    """
    con.run(
        "CREATE TEMPORARY TABLE t1 (f1 int primary key, "
        "f2 bigint not null, f3 varchar(50) null) "
    )

    def drop_t1():
        try:
            con.run("drop table t1")
        except DatabaseError:
            pass

    request.addfinalizer(drop_t1)
    return con


def test_database_error(con):
    """Inserting into a non-existent table raises DatabaseError."""
    with pytest.raises(DatabaseError):
        con.run("INSERT INTO t99 VALUES (1, 2, 3)")


def test_alter(db_table):
    """Re-run a query on a table after altering the table's structure."""
    select_all = "select * from t1"
    db_table.run(select_all)
    db_table.run("alter table t1 drop column f3")
    db_table.run(select_all)


def test_create(db_table):
    """Re-run a query after the table was dropped and re-created."""
    select_all = "select * from t1"
    db_table.run(select_all)
    db_table.run("drop table t1")
    db_table.run("create temporary table t1 (f1 int primary key)")
    db_table.run(select_all)
51 |
52 |
_INSERT_T1 = "INSERT INTO t1 (f1, f2, f3) VALUES (:v1, :v2, :v3)"


def test_parametrized(db_table):
    """A named parameter can be used in a WHERE clause."""
    rows = db_table.run("SELECT f1, f2, f3 FROM t1 WHERE f1 > :f1", f1=3)
    # Unpacking checks that every returned row has exactly three columns.
    for _f1, _f2, _f3 in rows:
        pass


def test_insert_returning(db_table):
    """INSERT ... RETURNING yields the new ids and sets row_count."""
    db_table.run("CREATE TEMPORARY TABLE t2 (id serial, data text)")

    # Test INSERT ... RETURNING with one row...
    inserted = db_table.run(
        "INSERT INTO t2 (data) VALUES (:v) RETURNING id", v="test1"
    )
    new_id = inserted[0][0]
    fetched = db_table.run("SELECT data FROM t2 WHERE id = :v", v=new_id)
    assert "test1" == fetched[0][0]

    assert db_table.row_count == 1

    # Test with multiple rows...
    inserted = db_table.run(
        "INSERT INTO t2 (data) VALUES (:v1), (:v2), (:v3) RETURNING id",
        v1="test2",
        v2="test3",
        v3="test4",
    )
    assert db_table.row_count == 3
    assert len([row[0] for row in inserted]) == 3


def test_row_count_select(db_table):
    """row_count counts selected rows, and is -1 for a result-less command."""
    total = 57
    for n in range(total):
        db_table.run(_INSERT_T1, v1=n, v2=n, v3=None)

    db_table.run("SELECT * FROM t1")
    assert total == db_table.row_count

    # Should be -1 for a command with no results
    db_table.run("DROP TABLE t1")
    assert -1 == db_table.row_count


def test_row_count_delete(db_table):
    """row_count reports the number of deleted rows."""
    db_table.run(_INSERT_T1, v1=1, v2=1, v3=None)
    db_table.run("DELETE FROM t1")
    assert db_table.row_count == 1


def test_row_count_update(db_table):
    """row_count reports the number of updated rows."""
    for key, value in ((1, 1), (2, 10), (3, 100), (4, 1000), (5, 10000)):
        db_table.run(_INSERT_T1, v1=key, v2=value, v3=None)

    # Only the rows with f2 = 1000 and f2 = 10000 match.
    db_table.run("UPDATE t1 SET f3 = :v1 WHERE f2 > 101", v1="Hello!")
    assert db_table.row_count == 2
125 |
126 |
def test_int_oid(con):
    """An int parameter can be compared with an oid column."""
    # https://bugs.launchpad.net/pg8000/+bug/230796
    con.run("SELECT typname FROM pg_type WHERE oid = :v", v=100)


def test_unicode_query(con):
    """Non-ASCII table and column names are sent correctly."""
    ddl = (
        "CREATE TEMPORARY TABLE \u043c\u0435\u0441\u0442\u043e "
        "(\u0438\u043c\u044f VARCHAR(50), "
        "\u0430\u0434\u0440\u0435\u0441 VARCHAR(250))"
    )
    con.run(ddl)


def test_transactions(db_table):
    """An insert that is rolled back leaves the table empty."""
    db_table.run("start transaction")
    db_table.run(
        "INSERT INTO t1 (f1, f2, f3) VALUES (:v1, :v2, :v3)", v1=1, v2=1, v3="Zombie"
    )
    db_table.run("rollback")

    db_table.run("select * from t1")
    assert db_table.row_count == 0


def test_in(con):
    """A list parameter works with ``= any(...)``."""
    rows = con.run("SELECT typname FROM pg_type WHERE oid = any(:v)", v=[16, 23])
    assert rows[0][0] == "bool"


def test_empty_query(con):
    """Running an empty SQL string must not raise."""
    con.run("")
159 |
160 |
def test_rollback_no_transaction(con):
    """A raw ROLLBACK outside a transaction produces one server notice."""
    # Remove any existing notices
    con.notices.clear()

    # First, verify that a raw rollback does produce a notice
    con.run("rollback")

    assert 1 == len(con.notices)

    # 25P01 is the SQLSTATE code for no_active_sql_transaction. The notice
    # also has a message and severity name, but those might be
    # localized/depend on the server version, so only the code is checked.
    assert con.notices.pop().get(b"C") == b"25P01"
174 |
175 |
def test_close_prepared_statement(con):
    """Closing a prepared statement removes it from pg_prepared_statements."""
    count_sql = "select count(*) from pg_prepared_statements"

    statement = con.prepare("select 1")
    statement.run()
    # Should have one prepared statement
    assert con.run(count_sql)[0][0] == 1

    statement.close()

    # Should have no prepared statements
    assert con.run(count_sql)[0][0] == 0


def test_no_data(con):
    """A statement that returns no data makes run() return None."""
    assert con.run("START TRANSACTION") is None


def test_multiple_statements(con):
    """Several ;-separated statements return the rows of each of them."""
    assert con.run("SELECT 5; SELECT 'Erich Fromm';") == [[5], ["Erich Fromm"]]


def test_unexecuted_connection_row_count(con):
    """row_count is None before anything has been run."""
    assert con.row_count is None


def test_unexecuted_connection_columns(con):
    """columns is None before anything has been run."""
    assert con.columns is None


def test_sql_prepared_statement(con):
    """SQL-level PREPARE / EXECUTE can be run directly."""
    con.run("PREPARE gen_series AS SELECT generate_series(1, 10);")
    con.run("EXECUTE gen_series")


def test_to_statement():
    """Named params become $n placeholders; casts like ::decimal are kept."""
    converted, _ = to_statement(
        "SELECT sum(x)::decimal(5, 2) :f_2, :f1 FROM t WHERE a=:f_2"
    )
    assert converted == "SELECT sum(x)::decimal(5, 2) $1, $2 FROM t WHERE a=$1"


def test_to_statement_quotes():
    """A quote inside a $$-quoted string does not confuse the parser."""
    converted, _ = to_statement("SELECT $$'$$ = :v")
    assert converted == "SELECT $$'$$ = $1"


def test_not_parsed_if_no_params(mocker, con):
    """to_statement is not called when run() gets no parameters."""
    parser = mocker.patch("pg8000.native.to_statement")
    con.run("ROLLBACK")
    parser.assert_not_called()


def test_max_parameters(con):
    """A query with 60,000 named parameters can be sent."""
    size = 60000
    names = [f"param_{i}" for i in range(size)]
    placeholders = ",".join(f":{name}" for name in names)
    con.run(
        f"SELECT 1 WHERE 1 IN ({placeholders})",
        **dict.fromkeys(names, 1),
    )


def test_pg_placeholder_style(con):
    """Native $1 placeholders are filled from keyword arguments."""
    rows = con.run("SELECT $1", title="A Time Of Hope")
    assert rows[0] == ["A Time Of Hope"]
242 |
--------------------------------------------------------------------------------
/test/test_converters.py:
--------------------------------------------------------------------------------
1 | from datetime import (
2 | date as Date,
3 | datetime as DateTime,
4 | time as Time,
5 | timedelta as TimeDelta,
6 | timezone as TimeZone,
7 | )
8 | from decimal import Decimal
9 | from ipaddress import IPv4Address, IPv4Network
10 |
11 | import pytest
12 |
13 | from pg8000.converters import (
14 | PGInterval,
15 | PY_TYPES,
16 | Range,
17 | array_out,
18 | array_string_escape,
19 | date_in,
20 | datemultirange_in,
21 | identifier,
22 | int4range_in,
23 | interval_in,
24 | literal,
25 | make_param,
26 | null_out,
27 | numeric_in,
28 | numeric_out,
29 | pg_interval_in,
30 | range_out,
31 | string_in,
32 | string_out,
33 | time_in,
34 | timestamp_in,
35 | timestamptz_in,
36 | tsrange_in,
37 | )
38 | from pg8000.native import InterfaceError
39 |
40 |
41 | @pytest.mark.parametrize(
42 | "value,expected",
43 | [
44 | ["2022-03-02", Date(2022, 3, 2)],
45 | ["infinity", "infinity"],
46 | ["-infinity", "-infinity"],
47 | ["20022-03-02", "20022-03-02"],
48 | ],
49 | )
50 | def test_date_in(value, expected):
51 | assert date_in(value) == expected
52 |
53 |
54 | def test_null_out():  # null_out(None) must return None itself
55 | assert null_out(None) is None
56 |
57 |
58 | @pytest.mark.parametrize(
59 | "array,out",
60 | [
61 | [[True, False, None], "{true,false,NULL}"], # bool[]
62 | [[IPv4Address("192.168.0.1")], "{192.168.0.1}"], # inet[]
63 | [[Date(2021, 3, 1)], "{2021-03-01}"], # date[]
64 | [[b"\x00\x01\x02\x03\x02\x01\x00"], '{"\\\\x00010203020100"}'], # bytea[]
65 | [[IPv4Network("192.168.0.0/28")], "{192.168.0.0/28}"], # inet[]
66 | [[1, 2, 3], "{1,2,3}"], # int2[]
67 | [[1, None, 3], "{1,NULL,3}"], # int2[] with None
68 | [[[1, 2], [3, 4]], "{{1,2},{3,4}}"], # int2[] multidimensional
69 | [[70000, 2, 3], "{70000,2,3}"], # int4[]
70 | [[7000000000, 2, 3], "{7000000000,2,3}"], # int8[]
71 | [[0, 7000000000, 2], "{0,7000000000,2}"], # int8[]
72 | [[1.1, 2.2, 3.3], "{1.1,2.2,3.3}"], # float8[]
73 | [["Veni", "vidi", "vici"], "{Veni,vidi,vici}"], # varchar[]
74 | [[("Veni", True, 1)], '{"(Veni,true,1)"}'], # array of composites
75 | ],
76 | )
77 | def test_array_out(array, out):
78 | assert array_out(array) == out
79 |
80 |
81 | @pytest.mark.parametrize(
82 | "value",
83 | [
84 | "1.1",
85 | "-1.1",
86 | "10000",
87 | "20000",
88 | "-1000000000.123456789",
89 | "1.0",
90 | "12.44",
91 | ],
92 | )
93 | def test_numeric_out(value):
94 | assert numeric_out(value) == str(value)
95 |
96 |
97 | @pytest.mark.parametrize(
98 | "value",
99 | [
100 | "1.1",
101 | "-1.1",
102 | "10000",
103 | "20000",
104 | "-1000000000.123456789",
105 | "1.0",
106 | "12.44",
107 | ],
108 | )
109 | def test_numeric_in(value):
110 | assert numeric_in(value) == Decimal(value)
111 |
112 |
113 | @pytest.mark.parametrize(
114 | "data,expected",
115 | [
116 | ("[6,3]", Range(6, 3, bounds="[]")),
117 | ],
118 | )
119 | def test_int4range_in(data, expected):
120 | assert int4range_in(data) == expected
121 |
122 |
123 | @pytest.mark.parametrize(
124 | "v,expected",
125 | [
126 | (Range(6, 3, bounds="[]"), "[6,3]"),
127 | ],
128 | )
129 | def test_range_out(v, expected):
130 | assert range_out(v) == expected
131 |
132 |
133 | @pytest.mark.parametrize(
134 | "value",
135 | [
136 | "hello \u0173 world",
137 | ],
138 | )
139 | def test_string_out(value):
140 | assert string_out(value) == value
141 |
142 |
143 | @pytest.mark.parametrize(
144 | "value",
145 | [
146 | "hello \u0173 world",
147 | ],
148 | )
149 | def test_string_in(value):
150 | assert string_in(value) == value
151 |
152 |
153 | def test_time_in():
154 | actual = time_in("12:57:18.000396")
155 | assert actual == Time(12, 57, 18, 396)
156 |
157 |
158 | @pytest.mark.parametrize(
159 | "value,expected",
160 | [
161 | ("1 year", PGInterval(years=1)),
162 | ("2 hours", PGInterval(hours=2)),
163 | ],
164 | )
165 | def test_pg_interval_in(value, expected):
166 | assert pg_interval_in(value) == expected
167 |
168 |
169 | @pytest.mark.parametrize(
170 | "value,expected",
171 | [
172 | ("2 hours", TimeDelta(hours=2)),
173 | ("00:00:30", TimeDelta(seconds=30)),
174 | ],
175 | )
176 | def test_interval_in(value, expected):
177 | assert interval_in(value) == expected
178 |
179 |
180 | @pytest.mark.parametrize(
181 | "value,expected",
182 | [
183 | ("a", "a"),
184 | ('"', '"\\""'),
185 | ("\r", '"\r"'),
186 | ("\n", '"\n"'),
187 | ("\t", '"\t"'),
188 | ],
189 | )
190 | def test_array_string_escape(value, expected):
191 | res = array_string_escape(value)
192 | assert res == expected
193 |
194 |
195 | @pytest.mark.parametrize(
196 | "value,expected",
197 | [
198 | [
199 | "2022-10-08 15:01:39+01:30",
200 | DateTime(
201 | 2022, 10, 8, 15, 1, 39, tzinfo=TimeZone(TimeDelta(hours=1, minutes=30))
202 | ),
203 | ],
204 | [
205 | "2022-10-08 15:01:39-01:30",
206 | DateTime(
207 | 2022,
208 | 10,
209 | 8,
210 | 15,
211 | 1,
212 | 39,
213 | tzinfo=TimeZone(TimeDelta(hours=-1, minutes=-30)),
214 | ),
215 | ],
216 | [
217 | "2022-10-08 15:01:39+02",
218 | DateTime(2022, 10, 8, 15, 1, 39, tzinfo=TimeZone(TimeDelta(hours=2))),
219 | ],
220 | [
221 | "2022-10-08 15:01:39-02",
222 | DateTime(2022, 10, 8, 15, 1, 39, tzinfo=TimeZone(TimeDelta(hours=-2))),
223 | ],
224 | [
225 | "2022-10-08 15:01:39.597026+01:30",
226 | DateTime(
227 | 2022,
228 | 10,
229 | 8,
230 | 15,
231 | 1,
232 | 39,
233 | 597026,
234 | tzinfo=TimeZone(TimeDelta(hours=1, minutes=30)),
235 | ),
236 | ],
237 | [
238 | "2022-10-08 15:01:39.597026-01:30",
239 | DateTime(
240 | 2022,
241 | 10,
242 | 8,
243 | 15,
244 | 1,
245 | 39,
246 | 597026,
247 | tzinfo=TimeZone(TimeDelta(hours=-1, minutes=-30)),
248 | ),
249 | ],
250 | [
251 | "2022-10-08 15:01:39.597026+02",
252 | DateTime(
253 | 2022, 10, 8, 15, 1, 39, 597026, tzinfo=TimeZone(TimeDelta(hours=2))
254 | ),
255 | ],
256 | [
257 | "2022-10-08 15:01:39.597026-02",
258 | DateTime(
259 | 2022, 10, 8, 15, 1, 39, 597026, tzinfo=TimeZone(TimeDelta(hours=-2))
260 | ),
261 | ],
262 | [
263 | "20022-10-08 15:01:39.597026-02",
264 | "20022-10-08 15:01:39.597026-02",
265 | ],
266 | ],
267 | )
268 | def test_timestamptz_in(value, expected):
269 | assert timestamptz_in(value) == expected
270 |
271 |
272 | @pytest.mark.parametrize(
273 | "value,expected",
274 | [
275 | [
276 | "20022-10-08 15:01:39.597026",
277 | "20022-10-08 15:01:39.597026",
278 | ],
279 | ],
280 | )
281 | def test_timestamp_in(value, expected):
282 | assert timestamp_in(value) == expected
283 |
284 |
285 | @pytest.mark.parametrize(
286 | "value,expected",
287 | [
288 | [
289 | '["2001-02-03 04:05:00","2023-02-03 04:05:00")',
290 | Range(DateTime(2001, 2, 3, 4, 5), DateTime(2023, 2, 3, 4, 5)),
291 | ],
292 | ],
293 | )
294 | def test_tsrange_in(value, expected):
295 | assert tsrange_in(value) == expected
296 |
297 |
298 | @pytest.mark.parametrize(
299 | "value,expected",
300 | [
301 | [
302 | "{[2023-06-01,2023-06-06),[2023-06-10,2023-06-13)}",
303 | [
304 | Range(Date(2023, 6, 1), Date(2023, 6, 6)),
305 | Range(Date(2023, 6, 10), Date(2023, 6, 13)),
306 | ],
307 | ],
308 | ],
309 | )
310 | def test_datemultirange_in(value, expected):
311 | assert datemultirange_in(value) == expected
312 |
313 |
314 | def test_make_param():
315 | class BClass(bytearray):
316 | pass
317 |
318 | val = BClass(b"\x00\x01\x02\x03\x02\x01\x00")
319 | assert make_param(PY_TYPES, val) == "\\x00010203020100"
320 |
321 |
322 | def test_identifier_int():
323 | with pytest.raises(InterfaceError, match="identifier must be a str"):
324 | identifier(9)
325 |
326 |
327 | def test_identifier_empty():
328 | with pytest.raises(
329 | InterfaceError, match="identifier must be > 0 characters in length"
330 | ):
331 | identifier("")
332 |
333 |
334 | def test_identifier_quoted_null():
335 | with pytest.raises(
336 | InterfaceError, match="identifier cannot contain the code zero character"
337 | ):
338 | identifier("tabl\u0000e")
339 |
340 |
341 | @pytest.mark.parametrize(
342 | "value,expected",
343 | [
344 | ("top_secret", '"top_secret"'),
345 | (" Table", '" Table"'),
346 | ("A Table", '"A Table"'),
347 | ('A " Table', '"A "" Table"'),
348 | ("table$", '"table$"'),
349 | ("Table$", '"Table$"'),
350 | ("tableఐ", '"tableఐ"'), # Unicode character 0C10 which is uncased
351 | ("table", '"table"'),
352 | ("tAble", '"tAble"'),
353 | ],
354 | )
355 | def test_identifier_success(value, expected):
356 | assert identifier(value) == expected
357 |
358 |
359 | @pytest.mark.parametrize(
360 | "value,expected",
361 | [("top_secret", "'top_secret'"), (["cove"], "'{cove}'")],
362 | )
363 | def test_literal(value, expected):
364 | assert literal(value) == expected
365 |
366 |
367 | def test_literal_quote():  # a single quote inside the value is doubled
368 | assert literal("bob's") == "'bob''s'"
369 |
370 |
371 | def test_literal_int():  # an int is rendered without quotes
372 | assert literal(7) == "7"
373 |
374 |
375 | def test_literal_float():  # a float is rendered without quotes
376 | assert literal(7.9) == "7.9"
377 |
378 |
379 | def test_literal_decimal():  # a Decimal is rendered without quotes
380 | assert literal(Decimal("0.1")) == "0.1"
381 |
382 |
383 | def test_literal_bytes():
384 | assert literal(b"\x03") == "X'03'"
385 |
386 |
387 | def test_literal_boolean():  # True becomes the unquoted keyword TRUE
388 | assert literal(True) == "TRUE"
389 |
390 |
391 | def test_literal_None():  # None becomes the unquoted keyword NULL
392 | assert literal(None) == "NULL"
393 |
394 |
395 | def test_literal_Time():
396 | assert literal(Time(22, 13, 2)) == "'22:13:02'"
397 |
398 |
399 | def test_literal_Date():
400 | assert literal(Date(2063, 11, 2)) == "'2063-11-02'"
401 |
402 |
403 | def test_literal_TimeDelta():  # TimeDelta(22, 13, 2) is 22 days, 13 seconds, 2 microseconds
404 | assert literal(TimeDelta(22, 13, 2)) == "'22 days 13 seconds 2 microseconds'"
405 |
406 |
407 | def test_literal_Datetime():  # a naive datetime becomes quoted ISO text with a "T" separator
408 | assert literal(DateTime(2063, 3, 31, 22, 13, 2)) == "'2063-03-31T22:13:02'"
409 |
410 |
411 | def test_literal_Trojan():  # an object of an unrecognised type is quoted via its str()
412 | class Trojan:
413 | def __str__(self):
414 | return "A Gift"
415 |
416 | assert literal(Trojan()) == "'A Gift'"
417 |
--------------------------------------------------------------------------------
/test/test_readme.py:
--------------------------------------------------------------------------------
1 | import doctest
2 |
3 | from pathlib import Path
4 |
5 |
6 | def test_readme():
7 | failure_count, _ = doctest.testfile(
8 | str(Path("..") / "README.md"), verbose=False, optionflags=doctest.ELLIPSIS
9 | )
10 | assert failure_count == 0
11 |
--------------------------------------------------------------------------------
/test/test_types.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from pg8000.types import PGInterval, Range
4 |
5 |
6 | def test_PGInterval_init():
7 | i = PGInterval(days=1)
8 | assert i.months is None
9 | assert i.days == 1
10 | assert i.microseconds is None
11 |
12 |
13 | def test_PGInterval_repr():
14 | v = PGInterval(microseconds=123456789, days=2, months=24)
15 | assert repr(v) == ""
16 |
17 |
18 | def test_PGInterval_str():
19 | v = PGInterval(microseconds=123456789, days=2, months=24, millennia=2)
20 | assert str(v) == "2 millennia 24 months 2 days 123456789 microseconds"
21 |
22 |
23 | @pytest.mark.parametrize(
24 | "value,expected",
25 | [
26 | ("P1Y2M", PGInterval(years=1, months=2)),
27 | ("P12DT30S", PGInterval(days=12, seconds=30)),
28 | (
29 | "P-1Y-2M3DT-4H-5M-6S",
30 | PGInterval(years=-1, months=-2, days=3, hours=-4, minutes=-5, seconds=-6),
31 | ),
32 | ("PT1M32.32S", PGInterval(minutes=1, seconds=32.32)),
33 | ],
34 | )
35 | def test_PGInterval_from_str_iso_8601(value, expected):
36 | interval = PGInterval.from_str_iso_8601(value)
37 | assert interval == expected
38 |
39 |
40 | @pytest.mark.parametrize(
41 | "value,expected",
42 | [
43 | ("@ 1 year 2 mons", PGInterval(years=1, months=2)),
44 | (
45 | "@ 3 days 4 hours 5 mins 6 secs",
46 | PGInterval(days=3, hours=4, minutes=5, seconds=6),
47 | ),
48 | (
49 | "@ 1 year 2 mons -3 days 4 hours 5 mins 6 secs ago",
50 | PGInterval(years=-1, months=-2, days=3, hours=-4, minutes=-5, seconds=-6),
51 | ),
52 | (
53 | "@ 1 millennium -2 mons",
54 | PGInterval(millennia=1, months=-2),
55 | ),
56 | ],
57 | )
58 | def test_PGInterval_from_str_postgres(value, expected):
59 | interval = PGInterval.from_str_postgres(value)
60 | assert interval == expected
61 |
62 |
63 | @pytest.mark.parametrize(
64 | "value,expected",
65 | [
66 | ["1-2", PGInterval(years=1, months=2)],
67 | ["3 4:05:06", PGInterval(days=3, hours=4, minutes=5, seconds=6)],
68 | [
69 | "-1-2 +3 -4:05:06",
70 | PGInterval(years=-1, months=-2, days=3, hours=-4, minutes=-5, seconds=-6),
71 | ],
72 | ["8 4:00:32.32", PGInterval(days=8, hours=4, minutes=0, seconds=32.32)],
73 | ],
74 | )
75 | def test_PGInterval_from_str_sql_standard(value, expected):
76 | interval = PGInterval.from_str_sql_standard(value)
77 | assert interval == expected
78 |
79 |
80 | @pytest.mark.parametrize(
81 | "value,expected",
82 | [
83 | ("P12DT30S", PGInterval(days=12, seconds=30)),
84 | ("@ 1 year 2 mons", PGInterval(years=1, months=2)),
85 | ("1-2", PGInterval(years=1, months=2)),
86 | ("3 4:05:06", PGInterval(days=3, hours=4, minutes=5, seconds=6)),
87 | (
88 | "-1-2 +3 -4:05:06",
89 | PGInterval(years=-1, months=-2, days=3, hours=-4, minutes=-5, seconds=-6),
90 | ),
91 | ("00:00:30", PGInterval(seconds=30)),
92 | ],
93 | )
94 | def test_PGInterval_from_str(value, expected):
95 | interval = PGInterval.from_str(value)
96 | assert interval == expected
97 |
98 |
99 | def test_Range_equals():
100 | pg_range_a = Range("[", 1, 2, ")")
101 | pg_range_b = Range("[", 1, 2, ")")
102 | assert pg_range_a == pg_range_b
103 |
104 |
105 | def test_Range_str():
106 | v = Range(5, 6)
107 | assert str(v) == "[5,6)"
108 |
--------------------------------------------------------------------------------
/test/test_utils.py:
--------------------------------------------------------------------------------
1 | from test.utils import parse_server_version
2 |
3 |
4 | def test_parse_server_version():
5 | assert parse_server_version("17rc1") == 17
6 |
--------------------------------------------------------------------------------
/test/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 |
4 | def parse_server_version(version):
5 | major = re.match(r"\d+", version).group() # leading digits in 17.0, 17rc1
6 | return int(major)
7 |
--------------------------------------------------------------------------------