├── .github
│   ├── dependabot.yml
│   └── workflows
│       ├── publish.yml
│       └── test-suite.yml
├── .gitignore
├── CHANGELOG.md
├── LICENSE.md
├── README.md
├── databases
│   ├── __init__.py
│   ├── backends
│   │   ├── __init__.py
│   │   ├── aiopg.py
│   │   ├── asyncmy.py
│   │   ├── common
│   │   │   ├── __init__.py
│   │   │   └── records.py
│   │   ├── compilers
│   │   │   ├── __init__.py
│   │   │   └── psycopg.py
│   │   ├── dialects
│   │   │   ├── __init__.py
│   │   │   └── psycopg.py
│   │   ├── mysql.py
│   │   ├── postgres.py
│   │   └── sqlite.py
│   ├── core.py
│   ├── importer.py
│   ├── interfaces.py
│   └── py.typed
├── docs
│   ├── connections_and_transactions.md
│   ├── contributing.md
│   ├── database_queries.md
│   ├── index.md
│   └── tests_and_migrations.md
├── mkdocs.yml
├── requirements.txt
├── scripts
│   ├── README.md
│   ├── build
│   ├── check
│   ├── clean
│   ├── coverage
│   ├── docs
│   ├── install
│   ├── lint
│   ├── publish
│   └── test
├── setup.cfg
├── setup.py
└── tests
    ├── importer
    │   ├── __init__.py
    │   └── raise_import_error.py
    ├── test_connection_options.py
    ├── test_database_url.py
    ├── test_databases.py
    ├── test_importer.py
    └── test_integration.py

--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 |   - package-ecosystem: "pip"
4 |     directory: "/"
5 |     schedule:
6 |       interval: "monthly"
7 |   - package-ecosystem: "github-actions"
8 |     directory: "/"
9 |     schedule:
10 |       interval: "monthly"
11 | 
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | ---
2 | name: Publish
3 | 
4 | on:
5 |   push:
6 |     tags:
7 |       - '*'
8 | 
9 | jobs:
10 |   publish:
11 |     name: "Publish release"
12 |     runs-on: "ubuntu-latest"
13 | 
14 |     steps:
15 |       - uses: "actions/checkout@v3"
16 |       - uses: "actions/setup-python@v4"
17 |         with:
18 |           python-version: 3.8
19 |       - name: "Install dependencies"
20 |         run: "scripts/install"
21 |       - name: "Build package & docs"
22 |         run: "scripts/build"
23 |       - name: "Publish to PyPI & deploy docs"
24 |         run: "scripts/publish"
25 |         env:
26 |           TWINE_USERNAME: __token__
27 |           TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
28 | 
--------------------------------------------------------------------------------
/.github/workflows/test-suite.yml:
--------------------------------------------------------------------------------
1 | ---
2 | name: Test Suite
3 | 
4 | on:
5 |   push:
6 |     branches: ["master"]
7 |   pull_request:
8 |     branches: ["master"]
9 | 
10 | jobs:
11 |   tests:
12 |     name: "Python ${{ matrix.python-version }}"
13 |     runs-on: "ubuntu-latest"
14 | 
15 |     strategy:
16 |       matrix:
17 |         python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
18 | 
19 |     services:
20 |       mysql:
21 |         image: mysql:5.7
22 |         env:
23 |           MYSQL_USER: username
24 |           MYSQL_PASSWORD: password
25 |           MYSQL_ROOT_PASSWORD: password
26 |           MYSQL_DATABASE: testsuite
27 |         ports:
28 |           - 3306:3306
29 |         options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3
30 | 
31 |       postgres:
32 |         image: postgres:14
33 |         env:
34 |           POSTGRES_USER: username
35 |           POSTGRES_PASSWORD: password
36 |           POSTGRES_DB: testsuite
37 |         ports:
38 |           - 5432:5432
39 |         options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
40 | 
41 |     steps:
42 |       - uses: "actions/checkout@v3"
43 |       - uses: "actions/setup-python@v4"
44 |         with:
45 |           python-version: "${{ matrix.python-version }}"
46 |       - name: "Install dependencies"
47 |         run: "scripts/install"
48 |       - name: "Run linting checks"
49 |         run: "scripts/check"
50 |       - name: "Build package & docs"
51 |         run: "scripts/build"
name: "Run tests" 53 | env: 54 | TEST_DATABASE_URLS: | 55 | sqlite:///testsuite, 56 | sqlite+aiosqlite:///testsuite, 57 | mysql://username:password@localhost:3306/testsuite, 58 | mysql+aiomysql://username:password@localhost:3306/testsuite, 59 | mysql+asyncmy://username:password@localhost:3306/testsuite, 60 | postgresql://username:password@localhost:5432/testsuite, 61 | postgresql+aiopg://username:password@127.0.0.1:5432/testsuite, 62 | postgresql+asyncpg://username:password@localhost:5432/testsuite 63 | run: "scripts/test" 64 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .idea/ 3 | *.pyc 4 | test.db 5 | .coverage 6 | .pytest_cache/ 7 | .mypy_cache/ 8 | __pycache__/ 9 | *.egg-info/ 10 | htmlcov/ 11 | venv/ 12 | .idea/ 13 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). 6 | 7 | ## 0.9.0 (February 23th, 2024) 8 | 9 | ### Changed 10 | 11 | * Drop support for Python 3.7 and add support for Python 3.12 ([#583][#583]) 12 | * Add support for SQLAlchemy 2+ ([#540][#540]) 13 | * Allow SSL string parameters in PostgresSQL URL ([#575][#575]) and ([#576][#576]) 14 | 15 | [#583]: https://github.com/encode/databases/pull/583 16 | [#540]: https://github.com/encode/databases/pull/540 17 | [#575]: https://github.com/encode/databases/pull/575 18 | [#576]: https://github.com/encode/databases/pull/576 19 | 20 | ## 0.8.0 (August 28th, 2023) 21 | 22 | ### Added 23 | 24 | * Allow SQLite query parameters and support cached databases ([#561][#561]) 25 | * Support for unix socket for aiomysql and asyncmy ([#551][#551]) 26 | 27 | [#551]: https://github.com/encode/databases/pull/551 28 | [#561]: https://github.com/encode/databases/pull/546 29 | 30 | ### Changed 31 | 32 | * Change isolation connections and transactions during concurrent usage ([#546][#546]) 33 | * Bump requests from 2.28.1 to 2.31.0 ([#562][#562]) 34 | * Bump starlette from 0.20.4 to 0.27.0 ([#560][#560]) 35 | * Bump up asyncmy version to fix `No module named 'asyncmy.connection'` ([#553][#553]) 36 | * Bump wheel from 0.37.1 to 0.38.1 ([#524][#524]) 37 | 38 | [#546]: https://github.com/encode/databases/pull/546 39 | [#562]: https://github.com/encode/databases/pull/562 40 | [#560]: https://github.com/encode/databases/pull/560 41 | [#553]: https://github.com/encode/databases/pull/553 42 | [#524]: https://github.com/encode/databases/pull/524 43 | 44 | ### Fixed 45 | 46 | * Fix the type-hints using more standard mode ([#526][#526]) 47 | 48 | [#526]: https://github.com/encode/databases/pull/526 49 | 50 | ## 0.7.0 (Dec 18th, 2022) 51 | 52 | ### Fixed 53 | 54 | * Fixed breaking changes in SQLAlchemy cursor; supports `>=1.4.42,<1.5` ([#513][#513]) 55 | * Wrapped types in `typing.Optional` where applicable ([#510][#510]) 56 | 57 | [#513]: https://github.com/encode/databases/pull/513 58 | [#510]: https://github.com/encode/databases/pull/510 59 | 60 | ## 0.6.2 (Nov 7th, 2022) 61 | 62 | ### Changed 63 | 64 | * Pinned SQLAlchemy `<=1.4.41` to avoid breaking changes ([#520][#520]) 65 | 66 | [#520]: https://github.com/encode/databases/pull/520 67 | 68 | ## 0.6.1 (Aug 9th, 2022) 69 | 70 | ### Fixed 71 | 72 | * 
72 | * Improve typing for `Transaction` ([#493][#493])
73 | * Allow string indexing into Record ([#501][#501])
74 | 
75 | [#493]: https://github.com/encode/databases/pull/493
76 | [#501]: https://github.com/encode/databases/pull/501
77 | 
78 | ## 0.6.0 (May 29th, 2022)
79 | 
80 | * Dropped Python 3.6 support ([#458][#458])
81 | 
82 | [#458]: https://github.com/encode/databases/pull/458
83 | 
84 | ### Added
85 | 
86 | * Add \_mapping property to the result set interface ([#447][#447])
87 | * Add contributing docs ([#453][#453])
88 | 
89 | [#447]: https://github.com/encode/databases/pull/447
90 | [#453]: https://github.com/encode/databases/pull/453
91 | 
92 | ### Fixed
93 | 
94 | * Fix query result named access ([#448][#448])
95 | * Fix connections getting into a bad state when a task is cancelled ([#457][#457])
96 | * Revert #328 parallel transactions ([#472][#472])
97 | * Change extra installations to specific drivers ([#436][#436])
98 | 
99 | [#448]: https://github.com/encode/databases/pull/448
100 | [#457]: https://github.com/encode/databases/pull/457
101 | [#472]: https://github.com/encode/databases/pull/472
102 | [#436]: https://github.com/encode/databases/pull/436
103 | 
104 | ## 0.5.4 (January 14th, 2022)
105 | 
106 | ### Added
107 | 
108 | * Support for Unix domain sockets in connections ([#423][#423])
109 | * Added `asyncmy` MySQL driver ([#382][#382])
110 | 
111 | [#423]: https://github.com/encode/databases/pull/423
112 | [#382]: https://github.com/encode/databases/pull/382
113 | 
114 | ### Fixed
115 | 
116 | * Fix SQLite fetch queries with multiple parameters ([#435][#435])
117 | * Changed `Record` type to `Sequence` ([#408][#408])
118 | 
119 | [#435]: https://github.com/encode/databases/pull/435
120 | [#408]: https://github.com/encode/databases/pull/408
121 | 
122 | ## 0.5.3 (October 10th, 2021)
123 | 
124 | ### Added
125 | 
126 | * Support `dialect+driver` for default database drivers like `postgresql+asyncpg` ([#396][#396])
127 | 
128 | [#396]: https://github.com/encode/databases/pull/396
129 | 
130 | ### Fixed
131 | 
132 | * Documentation of low-level transaction ([#390][#390])
133 | 
134 | [#390]: https://github.com/encode/databases/pull/390
135 | 
136 | ## 0.5.2 (September 10th, 2021)
137 | 
138 | ### Fixed
139 | 
140 | * Reset counter for failed connections ([#385][#385])
141 | * Avoid dangling task-local connections after Database.disconnect() ([#211][#211])
142 | 
143 | [#385]: https://github.com/encode/databases/pull/385
144 | [#211]: https://github.com/encode/databases/pull/211
145 | 
146 | ## 0.5.1 (September 2nd, 2021)
147 | 
148 | ### Added
149 | 
150 | * Make database `connect` and `disconnect` calls idempotent ([#379][#379])
151 | 
152 | [#379]: https://github.com/encode/databases/pull/379
153 | 
154 | ### Fixed
155 | 
156 | * Fix `in_` and `notin_` queries in SQLAlchemy 1.4 ([#378][#378])
157 | 
158 | [#378]: https://github.com/encode/databases/pull/378
159 | 
160 | ## 0.5.0 (August 26th, 2021)
161 | 
162 | ### Added
163 | 
164 | * Support SQLAlchemy 1.4 ([#299][#299])
165 | 
166 | [#299]: https://github.com/encode/databases/pull/299
167 | 
168 | ### Fixed
169 | 
170 | * Fix concurrent transactions ([#328][#328])
171 | 
172 | [#328]: https://github.com/encode/databases/pull/328
173 | 
174 | ## 0.4.3 (March 26th, 2021)
175 | 
176 | ### Fixed
177 | 
178 | * Pin SQLAlchemy to <1.4 ([#314][#314])
179 | 
180 | [#314]: https://github.com/encode/databases/pull/314
181 | 
182 | ## 0.4.2 (March 14th, 2021)
183 | 
184 | ### Fixed
185 | 
186 | * Fix memory leak with asyncpg for SQLAlchemy generic functions ([#273][#273])
187 | 
188 | 
[#273]: https://github.com/encode/databases/pull/273 189 | 190 | ## 0.4.1 (November 16th, 2020) 191 | 192 | ### Fixed 193 | 194 | * Remove package dependency on the synchronous DB drivers ([#256][#256]) 195 | 196 | [#256]: https://github.com/encode/databases/pull/256 197 | 198 | ## 0.4.0 (October 20th, 2020) 199 | 200 | ### Added 201 | 202 | * Use backend native fetch_val() implementation when available ([#132][#132]) 203 | * Replace psycopg2-binary with psycopg2 ([#204][#204]) 204 | * Speed up PostgresConnection fetch() and iterate() ([#193][#193]) 205 | * Access asyncpg Record field by key on raw query ([#207][#207]) 206 | * Allow setting min_size and max_size in postgres DSN ([#210][#210]) 207 | * Add option pool_recycle in postgres DSN ([#233][#233]) 208 | * Allow extra transaction options ([#242][#242]) 209 | 210 | [#132]: https://github.com/encode/databases/pull/132 211 | [#204]: https://github.com/encode/databases/pull/204 212 | [#193]: https://github.com/encode/databases/pull/193 213 | [#207]: https://github.com/encode/databases/pull/207 214 | [#210]: https://github.com/encode/databases/pull/210 215 | [#233]: https://github.com/encode/databases/pull/233 216 | [#242]: https://github.com/encode/databases/pull/242 217 | 218 | ### Fixed 219 | 220 | * Fix type hinting for sqlite backend ([#227][#227]) 221 | * Fix SQLAlchemy DDL statements ([#226][#226]) 222 | * Make fetch_val call fetch_one for type conversion ([#246][#246]) 223 | * Unquote username and password in DatabaseURL ([#248][#248]) 224 | 225 | [#227]: https://github.com/encode/databases/pull/227 226 | [#226]: https://github.com/encode/databases/pull/226 227 | [#246]: https://github.com/encode/databases/pull/246 228 | [#248]: https://github.com/encode/databases/pull/248 229 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright © 2019, [Encode OSS Ltd](https://www.encode.io/). 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name of the copyright holder nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Databases 2 | 3 |
11 | 
12 | Databases gives you simple asyncio support for a range of databases.
13 | 
14 | It allows you to make queries using the powerful [SQLAlchemy Core][sqlalchemy-core]
15 | expression language, and provides support for PostgreSQL, MySQL, and SQLite.
16 | 
17 | Databases is suitable for integrating with any async Web framework, such as [Starlette][starlette],
18 | [Sanic][sanic], [Responder][responder], [Quart][quart], [aiohttp][aiohttp], [Tornado][tornado], or [FastAPI][fastapi].
19 | 
20 | **Documentation**: [https://www.encode.io/databases/](https://www.encode.io/databases/)
21 | 
22 | **Requirements**: Python 3.8+
23 | 
24 | ---
25 | 
26 | ## Installation
27 | 
28 | ```shell
29 | $ pip install databases
30 | ```
31 | 
32 | The supported database drivers are:
33 | 
34 | * [asyncpg][asyncpg]
35 | * [aiopg][aiopg]
36 | * [aiomysql][aiomysql]
37 | * [asyncmy][asyncmy]
38 | * [aiosqlite][aiosqlite]
39 | 
40 | You can install the required database drivers with:
41 | 
42 | ```shell
43 | $ pip install "databases[asyncpg]"
44 | $ pip install "databases[aiopg]"
45 | $ pip install "databases[aiomysql]"
46 | $ pip install "databases[asyncmy]"
47 | $ pip install "databases[aiosqlite]"
48 | ```
49 | 
50 | Note that if you are using any synchronous SQLAlchemy functions such as `metadata.create_all(engine)` or [alembic][alembic] migrations then you still have to install a synchronous DB driver: [psycopg2][psycopg2] for PostgreSQL and [pymysql][pymysql] for MySQL.
51 | 
52 | ---
53 | 
54 | ## Quickstart
55 | 
56 | For this example we'll create a very simple SQLite database to run some
57 | queries against.
58 | 
59 | ```shell
60 | $ pip install "databases[aiosqlite]"
61 | $ pip install ipython
62 | ```
63 | 
64 | We can now run a simple example from the console.
65 | 
66 | Note that we want to use `ipython` here, because it supports using `await`
67 | expressions directly from the console.
68 | 
69 | ```python
70 | # Create a database instance, and connect to it.
71 | from databases import Database
72 | database = Database('sqlite+aiosqlite:///example.db')
73 | await database.connect()
74 | 
75 | # Create a table.
76 | query = """CREATE TABLE HighScores (id INTEGER PRIMARY KEY, name VARCHAR(100), score INTEGER)"""
77 | await database.execute(query=query)
78 | 
79 | # Insert some data.
80 | query = "INSERT INTO HighScores(name, score) VALUES (:name, :score)"
81 | values = [
82 |     {"name": "Daisy", "score": 92},
83 |     {"name": "Neil", "score": 87},
84 |     {"name": "Carol", "score": 43},
85 | ]
86 | await database.execute_many(query=query, values=values)
87 | 
88 | # Run a database query.
89 | query = "SELECT * FROM HighScores"
90 | rows = await database.fetch_all(query=query)
91 | print('High Scores:', rows)
92 | ```
93 | 
94 | Check out the documentation on [making database queries](https://www.encode.io/databases/database_queries/)
95 | for examples of how to start using databases together with SQLAlchemy core expressions.
96 | 
97 | 
98 | — ⭐️ —
99 |Databases is BSD licensed code. Designed & built in Brighton, England.
100 | 
101 | [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/
102 | [sqlalchemy-core-tutorial]: https://docs.sqlalchemy.org/en/latest/core/tutorial.html
103 | [alembic]: https://alembic.sqlalchemy.org/en/latest/
104 | [psycopg2]: https://www.psycopg.org/
105 | [pymysql]: https://github.com/PyMySQL/PyMySQL
106 | [asyncpg]: https://github.com/MagicStack/asyncpg
107 | [aiopg]: https://github.com/aio-libs/aiopg
108 | [aiomysql]: https://github.com/aio-libs/aiomysql
109 | [asyncmy]: https://github.com/long2ice/asyncmy
110 | [aiosqlite]: https://github.com/omnilib/aiosqlite
111 | 
112 | [starlette]: https://github.com/encode/starlette
113 | [sanic]: https://github.com/huge-success/sanic
114 | [responder]: https://github.com/kennethreitz/responder
115 | [quart]: https://gitlab.com/pgjones/quart
116 | [aiohttp]: https://github.com/aio-libs/aiohttp
117 | [tornado]: https://github.com/tornadoweb/tornado
118 | [fastapi]: https://github.com/tiangolo/fastapi
119 | 
--------------------------------------------------------------------------------
/databases/__init__.py:
--------------------------------------------------------------------------------
1 | from databases.core import Database, DatabaseURL
2 | 
3 | __version__ = "0.9.0"
4 | __all__ = ["Database", "DatabaseURL"]
5 | 
--------------------------------------------------------------------------------
/databases/backends/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/encode/databases/ae3fb16f40201d9ed0ed31bea31a289127169568/databases/backends/__init__.py
--------------------------------------------------------------------------------
/databases/backends/aiopg.py:
--------------------------------------------------------------------------------
1 | import getpass
2 | import json
3 | import logging
4 | import typing
5 | import uuid
6 | 
7 | import aiopg
8 | from sqlalchemy.engine.cursor import CursorResultMetaData
9 | from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
10 | from sqlalchemy.sql import ClauseElement
11 | from sqlalchemy.sql.ddl import DDLElement
12 | 
13 | # Row here is databases' own subclass of SQLAlchemy's Row (see common/records.py).
14 | from databases.backends.common.records import Record, Row, create_column_maps
15 | from databases.backends.compilers.psycopg import PGCompiler_psycopg
16 | from databases.backends.dialects.psycopg import PGDialect_psycopg
17 | from databases.core import LOG_EXTRA, DatabaseURL
18 | from databases.interfaces import (
19 |     ConnectionBackend,
20 |     DatabaseBackend,
21 |     Record as RecordInterface,
22 |     TransactionBackend,
23 | )
24 | 
25 | logger = logging.getLogger("databases")
26 | 
27 | 
28 | class AiopgBackend(DatabaseBackend):
29 |     def __init__(
30 |         self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any
31 |     ) -> None:
32 |         self._database_url = DatabaseURL(database_url)
33 |         self._options = options
34 |         self._dialect = self._get_dialect()
35 |         self._pool: typing.Union[aiopg.Pool, None] = None
36 | 
37 |     def _get_dialect(self) -> Dialect:
38 |         dialect = PGDialect_psycopg(
39 |             json_serializer=json.dumps, json_deserializer=lambda x: x
40 |         )
41 |         dialect.statement_compiler = PGCompiler_psycopg
42 |         dialect.implicit_returning = True
43 |         dialect.supports_native_enum = True
44 |         dialect.supports_smallserial = True  # 9.2+
45 |         dialect._backslash_escapes = False
46 |         dialect.supports_sane_multi_rowcount = True  # psycopg 2.0.9+
47 |         dialect._has_native_hstore = True
48 |         dialect.supports_native_decimal = True
49 | 
50 |         return dialect
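    # Illustrative note: pool options can come either from the URL's query
    # string or from keyword arguments to Database(...). For example
    # (hypothetical values):
    #
    #     Database("postgresql+aiopg://user:pass@localhost/db?min_size=1&max_size=10")
    #
    # is translated by _get_connection_kwargs() below into
    # aiopg.create_pool(minsize=1, maxsize=10, ...).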
51 | 52 | def _get_connection_kwargs(self) -> dict: 53 | url_options = self._database_url.options 54 | 55 | kwargs = {} 56 | min_size = url_options.get("min_size") 57 | max_size = url_options.get("max_size") 58 | ssl = url_options.get("ssl") 59 | 60 | if min_size is not None: 61 | kwargs["minsize"] = int(min_size) 62 | if max_size is not None: 63 | kwargs["maxsize"] = int(max_size) 64 | if ssl is not None: 65 | kwargs["ssl"] = {"true": True, "false": False}[ssl.lower()] 66 | 67 | for key, value in self._options.items(): 68 | # Coerce 'min_size' and 'max_size' for consistency. 69 | if key == "min_size": 70 | key = "minsize" 71 | elif key == "max_size": 72 | key = "maxsize" 73 | kwargs[key] = value 74 | 75 | return kwargs 76 | 77 | async def connect(self) -> None: 78 | assert self._pool is None, "DatabaseBackend is already running" 79 | kwargs = self._get_connection_kwargs() 80 | self._pool = await aiopg.create_pool( 81 | host=self._database_url.hostname, 82 | port=self._database_url.port, 83 | user=self._database_url.username or getpass.getuser(), 84 | password=self._database_url.password, 85 | database=self._database_url.database, 86 | **kwargs, 87 | ) 88 | 89 | async def disconnect(self) -> None: 90 | assert self._pool is not None, "DatabaseBackend is not running" 91 | self._pool.close() 92 | await self._pool.wait_closed() 93 | self._pool = None 94 | 95 | def connection(self) -> "AiopgConnection": 96 | return AiopgConnection(self, self._dialect) 97 | 98 | 99 | class CompilationContext: 100 | def __init__(self, context: ExecutionContext): 101 | self.context = context 102 | 103 | 104 | class AiopgConnection(ConnectionBackend): 105 | def __init__(self, database: AiopgBackend, dialect: Dialect): 106 | self._database = database 107 | self._dialect = dialect 108 | self._connection: typing.Optional[aiopg.Connection] = None 109 | 110 | async def acquire(self) -> None: 111 | assert self._connection is None, "Connection is already acquired" 112 | assert self._database._pool is not None, "DatabaseBackend is not running" 113 | self._connection = await self._database._pool.acquire() 114 | 115 | async def release(self) -> None: 116 | assert self._connection is not None, "Connection is not acquired" 117 | assert self._database._pool is not None, "DatabaseBackend is not running" 118 | await self._database._pool.release(self._connection) 119 | self._connection = None 120 | 121 | async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: 122 | assert self._connection is not None, "Connection is not acquired" 123 | query_str, args, result_columns, context = self._compile(query) 124 | column_maps = create_column_maps(result_columns) 125 | dialect = self._dialect 126 | 127 | cursor = await self._connection.cursor() 128 | try: 129 | await cursor.execute(query_str, args) 130 | rows = await cursor.fetchall() 131 | metadata = CursorResultMetaData(context, cursor.description) 132 | rows = [ 133 | Row( 134 | metadata, 135 | metadata._processors, 136 | metadata._keymap, 137 | row, 138 | ) 139 | for row in rows 140 | ] 141 | return [Record(row, result_columns, dialect, column_maps) for row in rows] 142 | finally: 143 | cursor.close() 144 | 145 | async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: 146 | assert self._connection is not None, "Connection is not acquired" 147 | query_str, args, result_columns, context = self._compile(query) 148 | column_maps = create_column_maps(result_columns) 149 | dialect = self._dialect 150 | cursor = await 
self._connection.cursor() 151 | try: 152 | await cursor.execute(query_str, args) 153 | row = await cursor.fetchone() 154 | if row is None: 155 | return None 156 | metadata = CursorResultMetaData(context, cursor.description) 157 | row = Row( 158 | metadata, 159 | metadata._processors, 160 | metadata._keymap, 161 | row, 162 | ) 163 | return Record(row, result_columns, dialect, column_maps) 164 | finally: 165 | cursor.close() 166 | 167 | async def execute(self, query: ClauseElement) -> typing.Any: 168 | assert self._connection is not None, "Connection is not acquired" 169 | query_str, args, _, _ = self._compile(query) 170 | cursor = await self._connection.cursor() 171 | try: 172 | await cursor.execute(query_str, args) 173 | return cursor.lastrowid 174 | finally: 175 | cursor.close() 176 | 177 | async def execute_many(self, queries: typing.List[ClauseElement]) -> None: 178 | assert self._connection is not None, "Connection is not acquired" 179 | cursor = await self._connection.cursor() 180 | try: 181 | for single_query in queries: 182 | single_query, args, _, _ = self._compile(single_query) 183 | await cursor.execute(single_query, args) 184 | finally: 185 | cursor.close() 186 | 187 | async def iterate( 188 | self, query: ClauseElement 189 | ) -> typing.AsyncGenerator[typing.Any, None]: 190 | assert self._connection is not None, "Connection is not acquired" 191 | query_str, args, result_columns, context = self._compile(query) 192 | column_maps = create_column_maps(result_columns) 193 | dialect = self._dialect 194 | cursor = await self._connection.cursor() 195 | try: 196 | await cursor.execute(query_str, args) 197 | metadata = CursorResultMetaData(context, cursor.description) 198 | async for row in cursor: 199 | record = Row( 200 | metadata, 201 | metadata._processors, 202 | metadata._keymap, 203 | row, 204 | ) 205 | yield Record(record, result_columns, dialect, column_maps) 206 | finally: 207 | cursor.close() 208 | 209 | def transaction(self) -> TransactionBackend: 210 | return AiopgTransaction(self) 211 | 212 | def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: 213 | compiled = query.compile( 214 | dialect=self._dialect, compile_kwargs={"render_postcompile": True} 215 | ) 216 | execution_context = self._dialect.execution_ctx_cls() 217 | execution_context.dialect = self._dialect 218 | 219 | if not isinstance(query, DDLElement): 220 | compiled_params = sorted(compiled.params.items()) 221 | 222 | args = compiled.construct_params() 223 | for key, val in args.items(): 224 | if key in compiled._bind_processors: 225 | args[key] = compiled._bind_processors[key](val) 226 | 227 | execution_context.result_column_struct = ( 228 | compiled._result_columns, 229 | compiled._ordered_columns, 230 | compiled._textual_ordered_columns, 231 | compiled._ad_hoc_textual, 232 | compiled._loose_column_name_matching, 233 | ) 234 | 235 | mapping = { 236 | key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1) 237 | } 238 | compiled_query = compiled.string % mapping 239 | result_map = compiled._result_columns 240 | 241 | else: 242 | args = {} 243 | result_map = None 244 | compiled_query = compiled.string 245 | 246 | query_message = compiled_query.replace(" \n", " ").replace("\n", " ") 247 | logger.debug( 248 | "Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA 249 | ) 250 | return compiled.string, args, result_map, CompilationContext(execution_context) 251 | 252 | @property 253 | def raw_connection(self) -> aiopg.connection.Connection: 254 | assert 
self._connection is not None, "Connection is not acquired"
255 |         return self._connection
256 | 
257 | 
258 | class AiopgTransaction(TransactionBackend):
259 |     def __init__(self, connection: AiopgConnection):
260 |         self._connection = connection
261 |         self._is_root = False
262 |         self._savepoint_name = ""
263 | 
264 |     async def start(
265 |         self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]
266 |     ) -> None:
267 |         assert self._connection._connection is not None, "Connection is not acquired"
268 |         self._is_root = is_root
269 |         cursor = await self._connection._connection.cursor()
270 |         try:  # ensure the cursor is closed on every path
271 |             if self._is_root:
272 |                 await cursor.execute("BEGIN")
273 |             else:
274 |                 id = str(uuid.uuid4()).replace("-", "_")
275 |                 self._savepoint_name = f"STARLETTE_SAVEPOINT_{id}"
276 |                 await cursor.execute(f"SAVEPOINT {self._savepoint_name}")
277 |         finally:
278 |             cursor.close()
279 | 
280 |     async def commit(self) -> None:
281 |         assert self._connection._connection is not None, "Connection is not acquired"
282 |         cursor = await self._connection._connection.cursor()
283 |         try:
284 |             if self._is_root:
285 |                 await cursor.execute("COMMIT")
286 |             else:
287 |                 await cursor.execute(f"RELEASE SAVEPOINT {self._savepoint_name}")
288 |         finally:
289 |             cursor.close()
290 | 
291 |     async def rollback(self) -> None:
292 |         assert self._connection._connection is not None, "Connection is not acquired"
293 |         cursor = await self._connection._connection.cursor()
294 |         try:
295 |             if self._is_root:
296 |                 await cursor.execute("ROLLBACK")
297 |             else:
298 |                 await cursor.execute(f"ROLLBACK TO SAVEPOINT {self._savepoint_name}")
299 |         finally:
300 |             cursor.close()
301 | 
--------------------------------------------------------------------------------
/databases/backends/asyncmy.py:
--------------------------------------------------------------------------------
1 | import getpass
2 | import logging
3 | import typing
4 | import uuid
5 | 
6 | import asyncmy
7 | from sqlalchemy.dialects.mysql import pymysql
8 | from sqlalchemy.engine.cursor import CursorResultMetaData
9 | from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
10 | from sqlalchemy.sql import ClauseElement
11 | from sqlalchemy.sql.ddl import DDLElement
12 | 
13 | from databases.backends.common.records import Record, Row, create_column_maps
14 | from databases.core import LOG_EXTRA, DatabaseURL
15 | from databases.interfaces import (
16 |     ConnectionBackend,
17 |     DatabaseBackend,
18 |     Record as RecordInterface,
19 |     TransactionBackend,
20 | )
21 | 
22 | logger = logging.getLogger("databases")
23 | 
24 | 
25 | class AsyncMyBackend(DatabaseBackend):
26 |     def __init__(
27 |         self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any
28 |     ) -> None:
29 |         self._database_url = DatabaseURL(database_url)
30 |         self._options = options
31 |         self._dialect = pymysql.dialect(paramstyle="pyformat")
32 |         self._dialect.supports_native_decimal = True
33 |         self._pool = None
34 | 
35 |     def _get_connection_kwargs(self) -> dict:
36 |         url_options = self._database_url.options
37 | 
38 |         kwargs = {}
39 |         min_size = url_options.get("min_size")
40 |         max_size = url_options.get("max_size")
41 |         pool_recycle = url_options.get("pool_recycle")
42 |         ssl = url_options.get("ssl")
43 |         unix_socket = url_options.get("unix_socket")
44 | 
45 |         if min_size is not None:
46 |             kwargs["minsize"] = int(min_size)
47 |         if max_size is not None:
48 |             kwargs["maxsize"] = int(max_size)
49 |         if pool_recycle is not None:
50 |             kwargs["pool_recycle"] = int(pool_recycle)
51 |         if ssl is not None:
52 |             
kwargs["ssl"] = {"true": True, "false": False}[ssl.lower()] 53 | if unix_socket is not None: 54 | kwargs["unix_socket"] = unix_socket 55 | 56 | for key, value in self._options.items(): 57 | # Coerce 'min_size' and 'max_size' for consistency. 58 | if key == "min_size": 59 | key = "minsize" 60 | elif key == "max_size": 61 | key = "maxsize" 62 | kwargs[key] = value 63 | 64 | return kwargs 65 | 66 | async def connect(self) -> None: 67 | assert self._pool is None, "DatabaseBackend is already running" 68 | kwargs = self._get_connection_kwargs() 69 | self._pool = await asyncmy.create_pool( 70 | host=self._database_url.hostname, 71 | port=self._database_url.port or 3306, 72 | user=self._database_url.username or getpass.getuser(), 73 | password=self._database_url.password, 74 | db=self._database_url.database, 75 | autocommit=True, 76 | **kwargs, 77 | ) 78 | 79 | async def disconnect(self) -> None: 80 | assert self._pool is not None, "DatabaseBackend is not running" 81 | self._pool.close() 82 | await self._pool.wait_closed() 83 | self._pool = None 84 | 85 | def connection(self) -> "AsyncMyConnection": 86 | return AsyncMyConnection(self, self._dialect) 87 | 88 | 89 | class CompilationContext: 90 | def __init__(self, context: ExecutionContext): 91 | self.context = context 92 | 93 | 94 | class AsyncMyConnection(ConnectionBackend): 95 | def __init__(self, database: AsyncMyBackend, dialect: Dialect): 96 | self._database = database 97 | self._dialect = dialect 98 | self._connection: typing.Optional[asyncmy.Connection] = None 99 | 100 | async def acquire(self) -> None: 101 | assert self._connection is None, "Connection is already acquired" 102 | assert self._database._pool is not None, "DatabaseBackend is not running" 103 | self._connection = await self._database._pool.acquire() 104 | 105 | async def release(self) -> None: 106 | assert self._connection is not None, "Connection is not acquired" 107 | assert self._database._pool is not None, "DatabaseBackend is not running" 108 | await self._database._pool.release(self._connection) 109 | self._connection = None 110 | 111 | async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: 112 | assert self._connection is not None, "Connection is not acquired" 113 | query_str, args, result_columns, context = self._compile(query) 114 | column_maps = create_column_maps(result_columns) 115 | dialect = self._dialect 116 | 117 | async with self._connection.cursor() as cursor: 118 | try: 119 | await cursor.execute(query_str, args) 120 | rows = await cursor.fetchall() 121 | metadata = CursorResultMetaData(context, cursor.description) 122 | rows = [ 123 | Row( 124 | metadata, 125 | metadata._processors, 126 | metadata._keymap, 127 | row, 128 | ) 129 | for row in rows 130 | ] 131 | return [ 132 | Record(row, result_columns, dialect, column_maps) for row in rows 133 | ] 134 | finally: 135 | await cursor.close() 136 | 137 | async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: 138 | assert self._connection is not None, "Connection is not acquired" 139 | query_str, args, result_columns, context = self._compile(query) 140 | column_maps = create_column_maps(result_columns) 141 | dialect = self._dialect 142 | async with self._connection.cursor() as cursor: 143 | try: 144 | await cursor.execute(query_str, args) 145 | row = await cursor.fetchone() 146 | if row is None: 147 | return None 148 | metadata = CursorResultMetaData(context, cursor.description) 149 | row = Row( 150 | metadata, 151 | metadata._processors, 152 | 
metadata._keymap, 153 | row, 154 | ) 155 | return Record(row, result_columns, dialect, column_maps) 156 | finally: 157 | await cursor.close() 158 | 159 | async def execute(self, query: ClauseElement) -> typing.Any: 160 | assert self._connection is not None, "Connection is not acquired" 161 | query_str, args, _, _ = self._compile(query) 162 | async with self._connection.cursor() as cursor: 163 | try: 164 | await cursor.execute(query_str, args) 165 | if cursor.lastrowid == 0: 166 | return cursor.rowcount 167 | return cursor.lastrowid 168 | finally: 169 | await cursor.close() 170 | 171 | async def execute_many(self, queries: typing.List[ClauseElement]) -> None: 172 | assert self._connection is not None, "Connection is not acquired" 173 | async with self._connection.cursor() as cursor: 174 | try: 175 | for single_query in queries: 176 | single_query, args, _, _ = self._compile(single_query) 177 | await cursor.execute(single_query, args) 178 | finally: 179 | await cursor.close() 180 | 181 | async def iterate( 182 | self, query: ClauseElement 183 | ) -> typing.AsyncGenerator[typing.Any, None]: 184 | assert self._connection is not None, "Connection is not acquired" 185 | query_str, args, result_columns, context = self._compile(query) 186 | column_maps = create_column_maps(result_columns) 187 | dialect = self._dialect 188 | async with self._connection.cursor() as cursor: 189 | try: 190 | await cursor.execute(query_str, args) 191 | metadata = CursorResultMetaData(context, cursor.description) 192 | async for row in cursor: 193 | record = Row( 194 | metadata, 195 | metadata._processors, 196 | metadata._keymap, 197 | row, 198 | ) 199 | yield Record(record, result_columns, dialect, column_maps) 200 | finally: 201 | await cursor.close() 202 | 203 | def transaction(self) -> TransactionBackend: 204 | return AsyncMyTransaction(self) 205 | 206 | def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: 207 | compiled = query.compile( 208 | dialect=self._dialect, compile_kwargs={"render_postcompile": True} 209 | ) 210 | execution_context = self._dialect.execution_ctx_cls() 211 | execution_context.dialect = self._dialect 212 | 213 | if not isinstance(query, DDLElement): 214 | compiled_params = sorted(compiled.params.items()) 215 | 216 | args = compiled.construct_params() 217 | for key, val in args.items(): 218 | if key in compiled._bind_processors: 219 | args[key] = compiled._bind_processors[key](val) 220 | 221 | execution_context.result_column_struct = ( 222 | compiled._result_columns, 223 | compiled._ordered_columns, 224 | compiled._textual_ordered_columns, 225 | compiled._ad_hoc_textual, 226 | compiled._loose_column_name_matching, 227 | ) 228 | 229 | mapping = { 230 | key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1) 231 | } 232 | compiled_query = compiled.string % mapping 233 | result_map = compiled._result_columns 234 | 235 | else: 236 | args = {} 237 | result_map = None 238 | compiled_query = compiled.string 239 | 240 | query_message = compiled_query.replace(" \n", " ").replace("\n", " ") 241 | logger.debug( 242 | "Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA 243 | ) 244 | return compiled.string, args, result_map, CompilationContext(execution_context) 245 | 246 | @property 247 | def raw_connection(self) -> asyncmy.connection.Connection: 248 | assert self._connection is not None, "Connection is not acquired" 249 | return self._connection 250 | 251 | 252 | class AsyncMyTransaction(TransactionBackend): 253 | def __init__(self, connection: 
AsyncMyConnection): 254 | self._connection = connection 255 | self._is_root = False 256 | self._savepoint_name = "" 257 | 258 | async def start( 259 | self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any] 260 | ) -> None: 261 | assert self._connection._connection is not None, "Connection is not acquired" 262 | self._is_root = is_root 263 | if self._is_root: 264 | await self._connection._connection.begin() 265 | else: 266 | id = str(uuid.uuid4()).replace("-", "_") 267 | self._savepoint_name = f"STARLETTE_SAVEPOINT_{id}" 268 | async with self._connection._connection.cursor() as cursor: 269 | try: 270 | await cursor.execute(f"SAVEPOINT {self._savepoint_name}") 271 | finally: 272 | await cursor.close() 273 | 274 | async def commit(self) -> None: 275 | assert self._connection._connection is not None, "Connection is not acquired" 276 | if self._is_root: 277 | await self._connection._connection.commit() 278 | else: 279 | async with self._connection._connection.cursor() as cursor: 280 | try: 281 | await cursor.execute(f"RELEASE SAVEPOINT {self._savepoint_name}") 282 | finally: 283 | await cursor.close() 284 | 285 | async def rollback(self) -> None: 286 | assert self._connection._connection is not None, "Connection is not acquired" 287 | if self._is_root: 288 | await self._connection._connection.rollback() 289 | else: 290 | async with self._connection._connection.cursor() as cursor: 291 | try: 292 | await cursor.execute( 293 | f"ROLLBACK TO SAVEPOINT {self._savepoint_name}" 294 | ) 295 | finally: 296 | await cursor.close() 297 | -------------------------------------------------------------------------------- /databases/backends/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/encode/databases/ae3fb16f40201d9ed0ed31bea31a289127169568/databases/backends/common/__init__.py -------------------------------------------------------------------------------- /databases/backends/common/records.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import typing 3 | from datetime import date, datetime, time 4 | 5 | from sqlalchemy.engine.interfaces import Dialect 6 | from sqlalchemy.engine.row import Row as SQLRow 7 | from sqlalchemy.sql.compiler import _CompileLabel 8 | from sqlalchemy.sql.schema import Column 9 | from sqlalchemy.sql.sqltypes import JSON 10 | from sqlalchemy.types import TypeEngine 11 | 12 | from databases.interfaces import Record as RecordInterface 13 | 14 | DIALECT_EXCLUDE = {"postgresql"} 15 | 16 | 17 | class Record(RecordInterface): 18 | __slots__ = ( 19 | "_row", 20 | "_result_columns", 21 | "_dialect", 22 | "_column_map", 23 | "_column_map_int", 24 | "_column_map_full", 25 | ) 26 | 27 | def __init__( 28 | self, 29 | row: typing.Any, 30 | result_columns: tuple, 31 | dialect: Dialect, 32 | column_maps: typing.Tuple[ 33 | typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]], 34 | typing.Mapping[int, typing.Tuple[int, TypeEngine]], 35 | typing.Mapping[str, typing.Tuple[int, TypeEngine]], 36 | ], 37 | ) -> None: 38 | self._row = row 39 | self._result_columns = result_columns 40 | self._dialect = dialect 41 | self._column_map, self._column_map_int, self._column_map_full = column_maps 42 | 43 | @property 44 | def _mapping(self) -> typing.Mapping: 45 | return self._row 46 | 47 | def keys(self) -> typing.KeysView: 48 | return self._mapping.keys() 49 | 50 | def values(self) -> typing.ValuesView: 51 | return self._mapping.values() 52 | 53 | 
def __getitem__(self, key: typing.Any) -> typing.Any:
54 |         if len(self._column_map) == 0:
55 |             return self._row[key]
56 |         elif isinstance(key, Column):
57 |             idx, datatype = self._column_map_full[str(key)]
58 |         elif isinstance(key, int):
59 |             idx, datatype = self._column_map_int[key]
60 |         else:
61 |             idx, datatype = self._column_map[key]
62 | 
63 |         raw = self._row[idx]
64 |         processor = datatype._cached_result_processor(self._dialect, None)
65 | 
66 |         if self._dialect.name in DIALECT_EXCLUDE:
67 |             if processor is not None and isinstance(raw, (int, str, float)):
68 |                 return processor(raw)
69 | 
70 |         return raw
71 | 
72 |     def __iter__(self) -> typing.Iterator:
73 |         return iter(self._row.keys())
74 | 
75 |     def __len__(self) -> int:
76 |         return len(self._row)
77 | 
78 |     def __getattr__(self, name: str) -> typing.Any:
79 |         try:
80 |             return self.__getitem__(name)
81 |         except KeyError as e:
82 |             raise AttributeError(e.args[0]) from e
83 | 
84 | 
85 | class Row(SQLRow):
86 |     def __getitem__(self, key: typing.Any) -> typing.Any:
87 |         """
88 |         A SQLAlchemy Row exposes its fields both as a tuple
89 |         (Row._fields) and as a name-to-value mapping (Row._mapping);
90 |         support key-based access here as well.
91 |         """
92 |         if isinstance(key, int):
93 |             return super().__getitem__(key)
94 | 
95 |         idx = self._key_to_index[key][0]
96 |         return super().__getitem__(idx)
97 | 
98 |     def keys(self):
99 |         return self._mapping.keys()
100 | 
101 |     def values(self):
102 |         return self._mapping.values()
103 | 
104 | 
105 | def create_column_maps(
106 |     result_columns: typing.Any,
107 | ) -> typing.Tuple[
108 |     typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]],
109 |     typing.Mapping[int, typing.Tuple[int, TypeEngine]],
110 |     typing.Mapping[str, typing.Tuple[int, TypeEngine]],
111 | ]:
112 |     """
113 |     Generate column -> datatype mappings from the column definitions.
114 | 
115 |     These mappings are used throughout the backend connection methods
116 |     to initialize Records. The underlying DB driver does not do type
117 |     conversion for us, so we have to wrap the returned rows ourselves.
118 | 
119 |     :return: Three mappings from different ways to address a column to \
120 |         corresponding column indexes and datatypes: \
121 |         1. by column identifier; \
122 |         2. by column index; \
123 |         3. by column name in Column sqlalchemy objects.
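
    Illustrative example (hypothetical query; exact keys depend on the
    compiled statement): for ``select(users.c.id)`` all three mappings
    resolve the ``id`` column to ``(0, Integer())`` -- via the label
    ``"id"``, the index ``0``, and the stringified column ``"users.id"``.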
124 |     """
125 |     column_map, column_map_int, column_map_full = {}, {}, {}
126 |     for idx, (column_name, _, column, datatype) in enumerate(result_columns):
127 |         column_map[column_name] = (idx, datatype)
128 |         column_map_int[idx] = (idx, datatype)
129 | 
130 |         # _CompileLabel entries (added in SQLAlchemy 2.0) carry no _annotations;
131 |         # in that case the mapped column sits at index 2 of the tuple instead.
132 |         if isinstance(column[0], _CompileLabel):
133 |             column_map_full[str(column[2])] = (idx, datatype)
134 |         else:
135 |             column_map_full[str(column[0])] = (idx, datatype)
136 |     return column_map, column_map_int, column_map_full
137 | 
--------------------------------------------------------------------------------
/databases/backends/compilers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/encode/databases/ae3fb16f40201d9ed0ed31bea31a289127169568/databases/backends/compilers/__init__.py
--------------------------------------------------------------------------------
/databases/backends/compilers/psycopg.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy.dialects.postgresql.psycopg import PGCompiler_psycopg
2 | 
3 | 
4 | class APGCompiler_psycopg2(PGCompiler_psycopg):
5 |     def construct_params(self, *args, **kwargs):
6 |         pd = super().construct_params(*args, **kwargs)
7 | 
8 |         for column in self.prefetch:
9 |             pd[column.key] = self._exec_default(column.default)
10 | 
11 |         return pd
12 | 
13 |     def _exec_default(self, default):
14 |         if default.is_callable:
15 |             return default.arg(self.dialect)
16 |         else:
17 |             return default.arg
18 | 
--------------------------------------------------------------------------------
/databases/backends/dialects/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/encode/databases/ae3fb16f40201d9ed0ed31bea31a289127169568/databases/backends/dialects/__init__.py
--------------------------------------------------------------------------------
/databases/backends/dialects/psycopg.py:
--------------------------------------------------------------------------------
1 | """
2 | Custom psycopg dialect pieces for the databases package.
3 | The PGNumeric type reproduces the behaviour of the deprecated
4 | pypostgresql dialect, kept for backwards compatibility and so
5 | that the package can move to SQLAlchemy 2.0+.
6 | """
7 | 
8 | import typing
9 | 
10 | from sqlalchemy import types, util
11 | from sqlalchemy.dialects.postgresql.base import PGDialect, PGExecutionContext
12 | from sqlalchemy.engine import processors
13 | from sqlalchemy.types import Float, Numeric
14 | 
15 | 
16 | class PGExecutionContext_psycopg(PGExecutionContext):
17 |     ...
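
# Rough sketch of PGNumeric's behaviour below (illustrative only):
# values are bound as strings via processors.to_str; on the way out,
# asdecimal=True leaves the driver's value untouched (result processor None),
# while asdecimal=False coerces results with processors.to_float.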
18 | 
19 | 
20 | class PGNumeric(Numeric):
21 |     def bind_processor(
22 |         self, dialect: typing.Any
23 |     ) -> typing.Optional[typing.Callable]:  # pragma: no cover
24 |         return processors.to_str
25 | 
26 |     def result_processor(
27 |         self, dialect: typing.Any, coltype: typing.Any
28 |     ) -> typing.Optional[typing.Callable]:  # pragma: no cover
29 |         if self.asdecimal:
30 |             return None
31 |         else:
32 |             return processors.to_float
33 | 
34 | 
35 | class PGDialect_psycopg(PGDialect):
36 |     colspecs = util.update_copy(
37 |         PGDialect.colspecs,
38 |         {
39 |             types.Numeric: PGNumeric,
40 |             types.Float: Float,
41 |         },
42 |     )
43 |     execution_ctx_cls = PGExecutionContext_psycopg
44 | 
45 | 
46 | dialect = PGDialect_psycopg
47 | 
--------------------------------------------------------------------------------
/databases/backends/mysql.py:
--------------------------------------------------------------------------------
1 | import getpass
2 | import logging
3 | import typing
4 | import uuid
5 | 
6 | import aiomysql
7 | from sqlalchemy.dialects.mysql import pymysql
8 | from sqlalchemy.engine.cursor import CursorResultMetaData
9 | from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
10 | from sqlalchemy.sql import ClauseElement
11 | from sqlalchemy.sql.ddl import DDLElement
12 | 
13 | from databases.backends.common.records import Record, Row, create_column_maps
14 | from databases.core import LOG_EXTRA, DatabaseURL
15 | from databases.interfaces import (
16 |     ConnectionBackend,
17 |     DatabaseBackend,
18 |     Record as RecordInterface,
19 |     TransactionBackend,
20 | )
21 | 
22 | logger = logging.getLogger("databases")
23 | 
24 | 
25 | class MySQLBackend(DatabaseBackend):
26 |     def __init__(
27 |         self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any
28 |     ) -> None:
29 |         self._database_url = DatabaseURL(database_url)
30 |         self._options = options
31 |         self._dialect = pymysql.dialect(paramstyle="pyformat")
32 |         self._dialect.supports_native_decimal = True
33 |         self._pool = None
34 | 
35 |     def _get_connection_kwargs(self) -> dict:
36 |         url_options = self._database_url.options
37 | 
38 |         kwargs = {}
39 |         min_size = url_options.get("min_size")
40 |         max_size = url_options.get("max_size")
41 |         pool_recycle = url_options.get("pool_recycle")
42 |         ssl = url_options.get("ssl")
43 |         unix_socket = url_options.get("unix_socket")
44 | 
45 |         if min_size is not None:
46 |             kwargs["minsize"] = int(min_size)
47 |         if max_size is not None:
48 |             kwargs["maxsize"] = int(max_size)
49 |         if pool_recycle is not None:
50 |             kwargs["pool_recycle"] = int(pool_recycle)
51 |         if ssl is not None:
52 |             kwargs["ssl"] = {"true": True, "false": False}[ssl.lower()]
53 |         if unix_socket is not None:
54 |             kwargs["unix_socket"] = unix_socket
55 | 
56 |         for key, value in self._options.items():
57 |             # Coerce 'min_size' and 'max_size' for consistency.
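            # (Illustrative: a hypothetical Database(url, min_size=1) keyword
            # arrives here and is renamed to aiomysql's 'minsize' below;
            # all other keywords are passed through unchanged.)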
58 | if key == "min_size": 59 | key = "minsize" 60 | elif key == "max_size": 61 | key = "maxsize" 62 | kwargs[key] = value 63 | 64 | return kwargs 65 | 66 | async def connect(self) -> None: 67 | assert self._pool is None, "DatabaseBackend is already running" 68 | kwargs = self._get_connection_kwargs() 69 | self._pool = await aiomysql.create_pool( 70 | host=self._database_url.hostname, 71 | port=self._database_url.port or 3306, 72 | user=self._database_url.username or getpass.getuser(), 73 | password=self._database_url.password, 74 | db=self._database_url.database, 75 | autocommit=True, 76 | **kwargs, 77 | ) 78 | 79 | async def disconnect(self) -> None: 80 | assert self._pool is not None, "DatabaseBackend is not running" 81 | self._pool.close() 82 | await self._pool.wait_closed() 83 | self._pool = None 84 | 85 | def connection(self) -> "MySQLConnection": 86 | return MySQLConnection(self, self._dialect) 87 | 88 | 89 | class CompilationContext: 90 | def __init__(self, context: ExecutionContext): 91 | self.context = context 92 | 93 | 94 | class MySQLConnection(ConnectionBackend): 95 | def __init__(self, database: MySQLBackend, dialect: Dialect): 96 | self._database = database 97 | self._dialect = dialect 98 | self._connection: typing.Optional[aiomysql.Connection] = None 99 | 100 | async def acquire(self) -> None: 101 | assert self._connection is None, "Connection is already acquired" 102 | assert self._database._pool is not None, "DatabaseBackend is not running" 103 | self._connection = await self._database._pool.acquire() 104 | 105 | async def release(self) -> None: 106 | assert self._connection is not None, "Connection is not acquired" 107 | assert self._database._pool is not None, "DatabaseBackend is not running" 108 | await self._database._pool.release(self._connection) 109 | self._connection = None 110 | 111 | async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: 112 | assert self._connection is not None, "Connection is not acquired" 113 | query_str, args, result_columns, context = self._compile(query) 114 | column_maps = create_column_maps(result_columns) 115 | dialect = self._dialect 116 | cursor = await self._connection.cursor() 117 | try: 118 | await cursor.execute(query_str, args) 119 | rows = await cursor.fetchall() 120 | metadata = CursorResultMetaData(context, cursor.description) 121 | rows = [ 122 | Row( 123 | metadata, 124 | metadata._processors, 125 | metadata._keymap, 126 | row, 127 | ) 128 | for row in rows 129 | ] 130 | return [Record(row, result_columns, dialect, column_maps) for row in rows] 131 | finally: 132 | await cursor.close() 133 | 134 | async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: 135 | assert self._connection is not None, "Connection is not acquired" 136 | query_str, args, result_columns, context = self._compile(query) 137 | column_maps = create_column_maps(result_columns) 138 | dialect = self._dialect 139 | cursor = await self._connection.cursor() 140 | try: 141 | await cursor.execute(query_str, args) 142 | row = await cursor.fetchone() 143 | if row is None: 144 | return None 145 | metadata = CursorResultMetaData(context, cursor.description) 146 | row = Row( 147 | metadata, 148 | metadata._processors, 149 | metadata._keymap, 150 | row, 151 | ) 152 | return Record(row, result_columns, dialect, column_maps) 153 | finally: 154 | await cursor.close() 155 | 156 | async def execute(self, query: ClauseElement) -> typing.Any: 157 | assert self._connection is not None, "Connection is not acquired" 158 | 
query_str, args, _, _ = self._compile(query) 159 | cursor = await self._connection.cursor() 160 | try: 161 | await cursor.execute(query_str, args) 162 | if cursor.lastrowid == 0: 163 | return cursor.rowcount 164 | return cursor.lastrowid 165 | finally: 166 | await cursor.close() 167 | 168 | async def execute_many(self, queries: typing.List[ClauseElement]) -> None: 169 | assert self._connection is not None, "Connection is not acquired" 170 | cursor = await self._connection.cursor() 171 | try: 172 | for single_query in queries: 173 | single_query, args, _, _ = self._compile(single_query) 174 | await cursor.execute(single_query, args) 175 | finally: 176 | await cursor.close() 177 | 178 | async def iterate( 179 | self, query: ClauseElement 180 | ) -> typing.AsyncGenerator[typing.Any, None]: 181 | assert self._connection is not None, "Connection is not acquired" 182 | query_str, args, result_columns, context = self._compile(query) 183 | column_maps = create_column_maps(result_columns) 184 | dialect = self._dialect 185 | cursor = await self._connection.cursor() 186 | try: 187 | await cursor.execute(query_str, args) 188 | metadata = CursorResultMetaData(context, cursor.description) 189 | async for row in cursor: 190 | record = Row( 191 | metadata, 192 | metadata._processors, 193 | metadata._keymap, 194 | row, 195 | ) 196 | yield Record(record, result_columns, dialect, column_maps) 197 | finally: 198 | await cursor.close() 199 | 200 | def transaction(self) -> TransactionBackend: 201 | return MySQLTransaction(self) 202 | 203 | def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: 204 | compiled = query.compile( 205 | dialect=self._dialect, compile_kwargs={"render_postcompile": True} 206 | ) 207 | execution_context = self._dialect.execution_ctx_cls() 208 | execution_context.dialect = self._dialect 209 | 210 | if not isinstance(query, DDLElement): 211 | compiled_params = sorted(compiled.params.items()) 212 | 213 | args = compiled.construct_params() 214 | for key, val in args.items(): 215 | if key in compiled._bind_processors: 216 | args[key] = compiled._bind_processors[key](val) 217 | 218 | execution_context.result_column_struct = ( 219 | compiled._result_columns, 220 | compiled._ordered_columns, 221 | compiled._textual_ordered_columns, 222 | compiled._ad_hoc_textual, 223 | compiled._loose_column_name_matching, 224 | ) 225 | 226 | mapping = { 227 | key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1) 228 | } 229 | compiled_query = compiled.string % mapping 230 | result_map = compiled._result_columns 231 | 232 | else: 233 | args = {} 234 | result_map = None 235 | compiled_query = compiled.string 236 | 237 | query_message = compiled_query.replace(" \n", " ").replace("\n", " ") 238 | logger.debug( 239 | "Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA 240 | ) 241 | return compiled.string, args, result_map, CompilationContext(execution_context) 242 | 243 | @property 244 | def raw_connection(self) -> aiomysql.connection.Connection: 245 | assert self._connection is not None, "Connection is not acquired" 246 | return self._connection 247 | 248 | 249 | class MySQLTransaction(TransactionBackend): 250 | def __init__(self, connection: MySQLConnection): 251 | self._connection = connection 252 | self._is_root = False 253 | self._savepoint_name = "" 254 | 255 | async def start( 256 | self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any] 257 | ) -> None: 258 | assert self._connection._connection is not None, "Connection is not 
acquired" 259 | self._is_root = is_root 260 | if self._is_root: 261 | await self._connection._connection.begin() 262 | else: 263 | id = str(uuid.uuid4()).replace("-", "_") 264 | self._savepoint_name = f"STARLETTE_SAVEPOINT_{id}" 265 | cursor = await self._connection._connection.cursor() 266 | try: 267 | await cursor.execute(f"SAVEPOINT {self._savepoint_name}") 268 | finally: 269 | await cursor.close() 270 | 271 | async def commit(self) -> None: 272 | assert self._connection._connection is not None, "Connection is not acquired" 273 | if self._is_root: 274 | await self._connection._connection.commit() 275 | else: 276 | cursor = await self._connection._connection.cursor() 277 | try: 278 | await cursor.execute(f"RELEASE SAVEPOINT {self._savepoint_name}") 279 | finally: 280 | await cursor.close() 281 | 282 | async def rollback(self) -> None: 283 | assert self._connection._connection is not None, "Connection is not acquired" 284 | if self._is_root: 285 | await self._connection._connection.rollback() 286 | else: 287 | cursor = await self._connection._connection.cursor() 288 | try: 289 | await cursor.execute(f"ROLLBACK TO SAVEPOINT {self._savepoint_name}") 290 | finally: 291 | await cursor.close() 292 | -------------------------------------------------------------------------------- /databases/backends/postgres.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import typing 3 | 4 | import asyncpg 5 | from sqlalchemy.engine.interfaces import Dialect 6 | from sqlalchemy.sql import ClauseElement 7 | from sqlalchemy.sql.ddl import DDLElement 8 | 9 | from databases.backends.common.records import Record, create_column_maps 10 | from databases.backends.dialects.psycopg import dialect as psycopg_dialect 11 | from databases.core import LOG_EXTRA, DatabaseURL 12 | from databases.interfaces import ( 13 | ConnectionBackend, 14 | DatabaseBackend, 15 | Record as RecordInterface, 16 | TransactionBackend, 17 | ) 18 | 19 | logger = logging.getLogger("databases") 20 | 21 | 22 | class PostgresBackend(DatabaseBackend): 23 | def __init__( 24 | self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any 25 | ) -> None: 26 | self._database_url = DatabaseURL(database_url) 27 | self._options = options 28 | self._dialect = self._get_dialect() 29 | self._pool = None 30 | 31 | def _get_dialect(self) -> Dialect: 32 | dialect = psycopg_dialect(paramstyle="pyformat") 33 | 34 | dialect.implicit_returning = True 35 | dialect.supports_native_enum = True 36 | dialect.supports_smallserial = True # 9.2+ 37 | dialect._backslash_escapes = False 38 | dialect.supports_sane_multi_rowcount = True # psycopg 2.0.9+ 39 | dialect._has_native_hstore = True 40 | dialect.supports_native_decimal = True 41 | 42 | return dialect 43 | 44 | def _get_connection_kwargs(self) -> dict: 45 | url_options = self._database_url.options 46 | 47 | kwargs: typing.Dict[str, typing.Any] = {} 48 | min_size = url_options.get("min_size") 49 | max_size = url_options.get("max_size") 50 | ssl = url_options.get("ssl") 51 | 52 | if min_size is not None: 53 | kwargs["min_size"] = int(min_size) 54 | if max_size is not None: 55 | kwargs["max_size"] = int(max_size) 56 | if ssl is not None: 57 | ssl = ssl.lower() 58 | kwargs["ssl"] = {"true": True, "false": False}.get(ssl, ssl) 59 | 60 | kwargs.update(self._options) 61 | 62 | return kwargs 63 | 64 | async def connect(self) -> None: 65 | assert self._pool is None, "DatabaseBackend is already running" 66 | kwargs = dict( 67 | host=self._database_url.hostname, 
68 | port=self._database_url.port, 69 | user=self._database_url.username, 70 | password=self._database_url.password, 71 | database=self._database_url.database, 72 | ) 73 | kwargs.update(self._get_connection_kwargs()) 74 | self._pool = await asyncpg.create_pool(**kwargs) 75 | 76 | async def disconnect(self) -> None: 77 | assert self._pool is not None, "DatabaseBackend is not running" 78 | await self._pool.close() 79 | self._pool = None 80 | 81 | def connection(self) -> "PostgresConnection": 82 | return PostgresConnection(self, self._dialect) 83 | 84 | 85 | class PostgresConnection(ConnectionBackend): 86 | def __init__(self, database: PostgresBackend, dialect: Dialect): 87 | self._database = database 88 | self._dialect = dialect 89 | self._connection: typing.Optional[asyncpg.connection.Connection] = None 90 | 91 | async def acquire(self) -> None: 92 | assert self._connection is None, "Connection is already acquired" 93 | assert self._database._pool is not None, "DatabaseBackend is not running" 94 | self._connection = await self._database._pool.acquire() 95 | 96 | async def release(self) -> None: 97 | assert self._connection is not None, "Connection is not acquired" 98 | assert self._database._pool is not None, "DatabaseBackend is not running" 99 | self._connection = await self._database._pool.release(self._connection) 100 | self._connection = None 101 | 102 | async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: 103 | assert self._connection is not None, "Connection is not acquired" 104 | query_str, args, result_columns = self._compile(query) 105 | rows = await self._connection.fetch(query_str, *args) 106 | dialect = self._dialect 107 | column_maps = create_column_maps(result_columns) 108 | return [Record(row, result_columns, dialect, column_maps) for row in rows] 109 | 110 | async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: 111 | assert self._connection is not None, "Connection is not acquired" 112 | query_str, args, result_columns = self._compile(query) 113 | row = await self._connection.fetchrow(query_str, *args) 114 | if row is None: 115 | return None 116 | return Record( 117 | row, 118 | result_columns, 119 | self._dialect, 120 | create_column_maps(result_columns), 121 | ) 122 | 123 | async def fetch_val( 124 | self, query: ClauseElement, column: typing.Any = 0 125 | ) -> typing.Any: 126 | # we are not calling self._connection.fetchval here because 127 | # it does not convert all the types, e.g. JSON stays string 128 | # instead of an object 129 | # see also: 130 | # https://github.com/encode/databases/pull/131 131 | # https://github.com/encode/databases/pull/132 132 | # https://github.com/encode/databases/pull/246 133 | row = await self.fetch_one(query) 134 | if row is None: 135 | return None 136 | return row[column] 137 | 138 | async def execute(self, query: ClauseElement) -> typing.Any: 139 | assert self._connection is not None, "Connection is not acquired" 140 | query_str, args, _ = self._compile(query) 141 | return await self._connection.fetchval(query_str, *args) 142 | 143 | async def execute_many(self, queries: typing.List[ClauseElement]) -> None: 144 | assert self._connection is not None, "Connection is not acquired" 145 | # asyncpg uses prepared statements under the hood, so we just 146 | # loop through multiple executes here, which should all end up 147 | # using the same prepared statement. 
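# Note (comment added for clarity, not in the upstream source): each entry
# in `queries` is a fully-built ClauseElement (typically one per values
# dict), so each one is compiled individually before execution.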
148 | for single_query in queries: 149 | single_query, args, _ = self._compile(single_query) 150 | await self._connection.execute(single_query, *args) 151 | 152 | async def iterate( 153 | self, query: ClauseElement 154 | ) -> typing.AsyncGenerator[typing.Any, None]: 155 | assert self._connection is not None, "Connection is not acquired" 156 | query_str, args, result_columns = self._compile(query) 157 | column_maps = create_column_maps(result_columns) 158 | async for row in self._connection.cursor(query_str, *args): 159 | yield Record(row, result_columns, self._dialect, column_maps) 160 | 161 | def transaction(self) -> TransactionBackend: 162 | return PostgresTransaction(connection=self) 163 | 164 | def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, typing.Optional[list]]: 165 | compiled = query.compile( 166 | dialect=self._dialect, compile_kwargs={"render_postcompile": True} 167 | ) 168 | 169 | if not isinstance(query, DDLElement): 170 | compiled_params = sorted(compiled.params.items()) 171 | 172 | mapping = { 173 | key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1) 174 | } 175 | compiled_query = compiled.string % mapping 176 | 177 | processors = compiled._bind_processors 178 | args = [ 179 | processors[key](val) if key in processors else val 180 | for key, val in compiled_params 181 | ] 182 | result_map = compiled._result_columns 183 | else: 184 | compiled_query = compiled.string 185 | args = [] 186 | result_map = None 187 | 188 | query_message = compiled_query.replace(" \n", " ").replace("\n", " ") 189 | logger.debug( 190 | "Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA 191 | ) 192 | return compiled_query, args, result_map 193 | 194 | @property 195 | def raw_connection(self) -> asyncpg.connection.Connection: 196 | assert self._connection is not None, "Connection is not acquired" 197 | return self._connection 198 | 199 | 200 | class PostgresTransaction(TransactionBackend): 201 | def __init__(self, connection: PostgresConnection): 202 | self._connection = connection 203 | self._transaction: typing.Optional[asyncpg.transaction.Transaction] = None 204 | 205 | async def start( 206 | self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any] 207 | ) -> None: 208 | assert self._connection._connection is not None, "Connection is not acquired" 209 | self._transaction = self._connection._connection.transaction(**extra_options) 210 | await self._transaction.start() 211 | 212 | async def commit(self) -> None: 213 | assert self._transaction is not None 214 | await self._transaction.commit() 215 | 216 | async def rollback(self) -> None: 217 | assert self._transaction is not None 218 | await self._transaction.rollback() 219 | -------------------------------------------------------------------------------- /databases/backends/sqlite.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sqlite3 3 | import typing 4 | import uuid 5 | from urllib.parse import urlencode 6 | 7 | import aiosqlite 8 | from sqlalchemy.dialects.sqlite import pysqlite 9 | from sqlalchemy.engine.cursor import CursorResultMetaData 10 | from sqlalchemy.engine.interfaces import Dialect, ExecutionContext 11 | from sqlalchemy.sql import ClauseElement 12 | from sqlalchemy.sql.ddl import DDLElement 13 | 14 | from databases.backends.common.records import Record, Row, create_column_maps 15 | from databases.core import LOG_EXTRA, DatabaseURL 16 | from databases.interfaces import ConnectionBackend, DatabaseBackend,
TransactionBackend 17 | 18 | logger = logging.getLogger("databases") 19 | 20 | 21 | class SQLiteBackend(DatabaseBackend): 22 | def __init__( 23 | self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any 24 | ) -> None: 25 | self._database_url = DatabaseURL(database_url) 26 | self._options = options 27 | self._dialect = pysqlite.dialect(paramstyle="qmark") 28 | # aiosqlite does not support decimals 29 | self._dialect.supports_native_decimal = False 30 | self._pool = SQLitePool(self._database_url, **self._options) 31 | 32 | async def connect(self) -> None: 33 | ... 34 | 35 | async def disconnect(self) -> None: 36 | # if it exists, drop the reference to the connection to the cached in-memory database on disconnect 37 | if self._pool._memref: 38 | self._pool._memref = None 39 | # assert self._pool is not None, "DatabaseBackend is not running" 40 | # self._pool.close() 41 | # await self._pool.wait_closed() 42 | # self._pool = None 43 | 44 | def connection(self) -> "SQLiteConnection": 45 | return SQLiteConnection(self._pool, self._dialect) 46 | 47 | 48 | class SQLitePool: 49 | def __init__(self, url: DatabaseURL, **options: typing.Any) -> None: 50 | self._database = url.database 51 | self._memref = None 52 | # add query params to database connection string 53 | if url.options: 54 | self._database += "?" + urlencode(url.options) 55 | self._options = options 56 | 57 | if url.options and "cache" in url.options: 58 | # a reference to a connection to the cached in-memory database must be held to keep it from being deleted 59 | self._memref = sqlite3.connect(self._database, **self._options) 60 | 61 | async def acquire(self) -> aiosqlite.Connection: 62 | connection = aiosqlite.connect( 63 | database=self._database, isolation_level=None, **self._options 64 | ) 65 | await connection.__aenter__() 66 | return connection 67 | 68 | async def release(self, connection: aiosqlite.Connection) -> None: 69 | await connection.__aexit__(None, None, None) 70 | 71 | 72 | class CompilationContext: 73 | def __init__(self, context: ExecutionContext): 74 | self.context = context 75 | 76 | 77 | class SQLiteConnection(ConnectionBackend): 78 | def __init__(self, pool: SQLitePool, dialect: Dialect): 79 | self._pool = pool 80 | self._dialect = dialect 81 | self._connection: typing.Optional[aiosqlite.Connection] = None 82 | 83 | async def acquire(self) -> None: 84 | assert self._connection is None, "Connection is already acquired" 85 | self._connection = await self._pool.acquire() 86 | 87 | async def release(self) -> None: 88 | assert self._connection is not None, "Connection is not acquired" 89 | await self._pool.release(self._connection) 90 | self._connection = None 91 | 92 | async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: 93 | assert self._connection is not None, "Connection is not acquired" 94 | query_str, args, result_columns, context = self._compile(query) 95 | column_maps = create_column_maps(result_columns) 96 | dialect = self._dialect 97 | 98 | async with self._connection.execute(query_str, args) as cursor: 99 | rows = await cursor.fetchall() 100 | metadata = CursorResultMetaData(context, cursor.description) 101 | rows = [ 102 | Row( 103 | metadata, 104 | metadata._processors, 105 | metadata._keymap, 106 | row, 107 | ) 108 | for row in rows 109 | ] 110 | return [Record(row, result_columns, dialect, column_maps) for row in rows] 111 | 112 | async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: 113 | assert self._connection is not None, "Connection is not acquired" 114 |
query_str, args, result_columns, context = self._compile(query) 115 | column_maps = create_column_maps(result_columns) 116 | dialect = self._dialect 117 | 118 | async with self._connection.execute(query_str, args) as cursor: 119 | row = await cursor.fetchone() 120 | if row is None: 121 | return None 122 | metadata = CursorResultMetaData(context, cursor.description) 123 | row = Row( 124 | metadata, 125 | metadata._processors, 126 | metadata._keymap, 127 | row, 128 | ) 129 | return Record(row, result_columns, dialect, column_maps) 130 | 131 | async def execute(self, query: ClauseElement) -> typing.Any: 132 | assert self._connection is not None, "Connection is not acquired" 133 | query_str, args, result_columns, context = self._compile(query) 134 | async with self._connection.cursor() as cursor: 135 | await cursor.execute(query_str, args) 136 | if cursor.lastrowid == 0:  # no row was inserted; fall back to the affected row count 137 | return cursor.rowcount 138 | return cursor.lastrowid 139 | 140 | async def execute_many(self, queries: typing.List[ClauseElement]) -> None: 141 | assert self._connection is not None, "Connection is not acquired" 142 | for single_query in queries: 143 | await self.execute(single_query) 144 | 145 | async def iterate( 146 | self, query: ClauseElement 147 | ) -> typing.AsyncGenerator[typing.Any, None]: 148 | assert self._connection is not None, "Connection is not acquired" 149 | query_str, args, result_columns, context = self._compile(query) 150 | column_maps = create_column_maps(result_columns) 151 | dialect = self._dialect 152 | 153 | async with self._connection.execute(query_str, args) as cursor: 154 | metadata = CursorResultMetaData(context, cursor.description) 155 | async for row in cursor: 156 | record = Row( 157 | metadata, 158 | metadata._processors, 159 | metadata._keymap, 160 | row, 161 | ) 162 | yield Record(record, result_columns, dialect, column_maps) 163 | 164 | def transaction(self) -> TransactionBackend: 165 | return SQLiteTransaction(self) 166 | 167 | def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, typing.Optional[list], CompilationContext]: 168 | compiled = query.compile( 169 | dialect=self._dialect, compile_kwargs={"render_postcompile": True} 170 | ) 171 | execution_context = self._dialect.execution_ctx_cls() 172 | execution_context.dialect = self._dialect 173 | 174 | args = [] 175 | result_map = None 176 | 177 | if not isinstance(query, DDLElement): 178 | compiled_params = sorted(compiled.params.items()) 179 | 180 | params = compiled.construct_params() 181 | for key in compiled.positiontup: 182 | raw_val = params[key] 183 | if key in compiled._bind_processors: 184 | val = compiled._bind_processors[key](raw_val) 185 | else: 186 | val = raw_val 187 | args.append(val) 188 | 189 | execution_context.result_column_struct = ( 190 | compiled._result_columns, 191 | compiled._ordered_columns, 192 | compiled._textual_ordered_columns, 193 | compiled._ad_hoc_textual, 194 | compiled._loose_column_name_matching, 195 | ) 196 | 197 | mapping = { 198 | key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1) 199 | } 200 | compiled_query = compiled.string % mapping  # the "$n" form only feeds the debug log below 201 | result_map = compiled._result_columns 202 | 203 | else: 204 | compiled_query = compiled.string 205 | 206 | query_message = compiled_query.replace(" \n", " ").replace("\n", " ") 207 | logger.debug( 208 | "Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA 209 | ) 210 | return compiled.string, args, result_map, CompilationContext(execution_context) 211 | 212 | @property 213 | def raw_connection(self) -> aiosqlite.core.Connection: 214 |
assert self._connection is not None, "Connection is not acquired" 215 | return self._connection 216 | 217 | 218 | class SQLiteTransaction(TransactionBackend): 219 | def __init__(self, connection: SQLiteConnection): 220 | self._connection = connection 221 | self._is_root = False 222 | self._savepoint_name = "" 223 | 224 | async def start( 225 | self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any] 226 | ) -> None: 227 | assert self._connection._connection is not None, "Connection is not acquired" 228 | self._is_root = is_root 229 | if self._is_root: 230 | async with self._connection._connection.execute("BEGIN") as cursor: 231 | await cursor.close() 232 | else: 233 | id = str(uuid.uuid4()).replace("-", "_") 234 | self._savepoint_name = f"STARLETTE_SAVEPOINT_{id}" 235 | async with self._connection._connection.execute( 236 | f"SAVEPOINT {self._savepoint_name}" 237 | ) as cursor: 238 | await cursor.close() 239 | 240 | async def commit(self) -> None: 241 | assert self._connection._connection is not None, "Connection is not acquired" 242 | if self._is_root: 243 | async with self._connection._connection.execute("COMMIT") as cursor: 244 | await cursor.close() 245 | else: 246 | async with self._connection._connection.execute( 247 | f"RELEASE SAVEPOINT {self._savepoint_name}" 248 | ) as cursor: 249 | await cursor.close() 250 | 251 | async def rollback(self) -> None: 252 | assert self._connection._connection is not None, "Connection is not acquired" 253 | if self._is_root: 254 | async with self._connection._connection.execute("ROLLBACK") as cursor: 255 | await cursor.close() 256 | else: 257 | async with self._connection._connection.execute( 258 | f"ROLLBACK TO SAVEPOINT {self._savepoint_name}" 259 | ) as cursor: 260 | await cursor.close() 261 | -------------------------------------------------------------------------------- /databases/core.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import contextlib 3 | import functools 4 | import logging 5 | import typing 6 | import weakref 7 | from contextvars import ContextVar 8 | from types import TracebackType 9 | from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit 10 | 11 | from sqlalchemy import text 12 | from sqlalchemy.sql import ClauseElement 13 | 14 | from databases.importer import import_from_string 15 | from databases.interfaces import DatabaseBackend, Record, TransactionBackend 16 | 17 | try: # pragma: no cover 18 | import click 19 | 20 | # Extra log info for optional coloured terminal outputs. 
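# Note (comment added for clarity, not in the upstream source): these dicts
# are passed to logging calls via `extra=`; handlers that read a
# "color_message" attribute from the log record (uvicorn's formatter, for
# example) can render the coloured variant, while plain handlers ignore it.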
21 | LOG_EXTRA = { 22 | "color_message": "Query: " + click.style("%s", bold=True) + " Args: %s" 23 | } 24 | CONNECT_EXTRA = { 25 | "color_message": "Connected to database " + click.style("%s", bold=True) 26 | } 27 | DISCONNECT_EXTRA = { 28 | "color_message": "Disconnected from database " + click.style("%s", bold=True) 29 | } 30 | except ImportError: # pragma: no cover 31 | LOG_EXTRA = {} 32 | CONNECT_EXTRA = {} 33 | DISCONNECT_EXTRA = {} 34 | 35 | 36 | logger = logging.getLogger("databases") 37 | 38 | 39 | _ACTIVE_TRANSACTIONS: ContextVar[ 40 | typing.Optional["weakref.WeakKeyDictionary['Transaction', 'TransactionBackend']"] 41 | ] = ContextVar("databases:active_transactions", default=None) 42 | 43 | 44 | class Database: 45 | SUPPORTED_BACKENDS = { 46 | "postgresql": "databases.backends.postgres:PostgresBackend", 47 | "postgresql+aiopg": "databases.backends.aiopg:AiopgBackend", 48 | "postgres": "databases.backends.postgres:PostgresBackend", 49 | "mysql": "databases.backends.mysql:MySQLBackend", 50 | "mysql+asyncmy": "databases.backends.asyncmy:AsyncMyBackend", 51 | "sqlite": "databases.backends.sqlite:SQLiteBackend", 52 | } 53 | 54 | _connection_map: "weakref.WeakKeyDictionary[asyncio.Task, 'Connection']" 55 | 56 | def __init__( 57 | self, 58 | url: typing.Union[str, "DatabaseURL"], 59 | *, 60 | force_rollback: bool = False, 61 | **options: typing.Any, 62 | ): 63 | self.url = DatabaseURL(url) 64 | self.options = options 65 | self.is_connected = False 66 | self._connection_map = weakref.WeakKeyDictionary() 67 | 68 | self._force_rollback = force_rollback 69 | 70 | backend_str = self._get_backend() 71 | backend_cls = import_from_string(backend_str) 72 | assert issubclass(backend_cls, DatabaseBackend) 73 | self._backend = backend_cls(self.url, **self.options) 74 | 75 | # When `force_rollback=True` is used, we use a single global 76 | # connection, within a transaction that always rolls back. 77 | self._global_connection: typing.Optional[Connection] = None 78 | self._global_transaction: typing.Optional[Transaction] = None 79 | 80 | @property 81 | def _current_task(self) -> asyncio.Task: 82 | task = asyncio.current_task() 83 | if not task: 84 | raise RuntimeError("No currently active asyncio.Task found") 85 | return task 86 | 87 | @property 88 | def _connection(self) -> typing.Optional["Connection"]: 89 | return self._connection_map.get(self._current_task) 90 | 91 | @_connection.setter 92 | def _connection( 93 | self, connection: typing.Optional["Connection"] 94 | ) -> typing.Optional["Connection"]: 95 | task = self._current_task 96 | 97 | if connection is None: 98 | self._connection_map.pop(task, None) 99 | else: 100 | self._connection_map[task] = connection 101 | 102 | return self._connection 103 | 104 | async def connect(self) -> None: 105 | """ 106 | Establish the connection pool. 
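Calling this when the database is already connected is a no-op.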
107 | """ 108 | if self.is_connected: 109 | logger.debug("Already connected, skipping connection") 110 | return None 111 | 112 | await self._backend.connect() 113 | logger.info( 114 | "Connected to database %s", self.url.obscure_password, extra=CONNECT_EXTRA 115 | ) 116 | self.is_connected = True 117 | 118 | if self._force_rollback: 119 | assert self._global_connection is None 120 | assert self._global_transaction is None 121 | 122 | self._global_connection = Connection(self, self._backend) 123 | self._global_transaction = self._global_connection.transaction( 124 | force_rollback=True 125 | ) 126 | 127 | await self._global_transaction.__aenter__() 128 | 129 | async def disconnect(self) -> None: 130 | """ 131 | Close all connections in the connection pool. 132 | """ 133 | if not self.is_connected: 134 | logger.debug("Already disconnected, skipping disconnection") 135 | return None 136 | 137 | if self._force_rollback: 138 | assert self._global_connection is not None 139 | assert self._global_transaction is not None 140 | 141 | await self._global_transaction.__aexit__() 142 | 143 | self._global_transaction = None 144 | self._global_connection = None 145 | else: 146 | self._connection = None 147 | 148 | await self._backend.disconnect() 149 | logger.info( 150 | "Disconnected from database %s", 151 | self.url.obscure_password, 152 | extra=DISCONNECT_EXTRA, 153 | ) 154 | self.is_connected = False 155 | 156 | async def __aenter__(self) -> "Database": 157 | await self.connect() 158 | return self 159 | 160 | async def __aexit__( 161 | self, 162 | exc_type: typing.Optional[typing.Type[BaseException]] = None, 163 | exc_value: typing.Optional[BaseException] = None, 164 | traceback: typing.Optional[TracebackType] = None, 165 | ) -> None: 166 | await self.disconnect() 167 | 168 | async def fetch_all( 169 | self, 170 | query: typing.Union[ClauseElement, str], 171 | values: typing.Optional[dict] = None, 172 | ) -> typing.List[Record]: 173 | async with self.connection() as connection: 174 | return await connection.fetch_all(query, values) 175 | 176 | async def fetch_one( 177 | self, 178 | query: typing.Union[ClauseElement, str], 179 | values: typing.Optional[dict] = None, 180 | ) -> typing.Optional[Record]: 181 | async with self.connection() as connection: 182 | return await connection.fetch_one(query, values) 183 | 184 | async def fetch_val( 185 | self, 186 | query: typing.Union[ClauseElement, str], 187 | values: typing.Optional[dict] = None, 188 | column: typing.Any = 0, 189 | ) -> typing.Any: 190 | async with self.connection() as connection: 191 | return await connection.fetch_val(query, values, column=column) 192 | 193 | async def execute( 194 | self, 195 | query: typing.Union[ClauseElement, str], 196 | values: typing.Optional[dict] = None, 197 | ) -> typing.Any: 198 | async with self.connection() as connection: 199 | return await connection.execute(query, values) 200 | 201 | async def execute_many( 202 | self, query: typing.Union[ClauseElement, str], values: list 203 | ) -> None: 204 | async with self.connection() as connection: 205 | return await connection.execute_many(query, values) 206 | 207 | async def iterate( 208 | self, 209 | query: typing.Union[ClauseElement, str], 210 | values: typing.Optional[dict] = None, 211 | ) -> typing.AsyncGenerator[typing.Mapping, None]: 212 | async with self.connection() as connection: 213 | async for record in connection.iterate(query, values): 214 | yield record 215 | 216 | def connection(self) -> "Connection": 217 | if self._global_connection is not None: 218 | 
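# Note (comment added for clarity, not in the upstream source): with
# `force_rollback=True` every task shares the single global connection
# (and its always-rolled-back transaction); otherwise connections are
# created lazily, one per asyncio.Task, via the weak-keyed connection map.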
return self._global_connection 219 | 220 | if not self._connection: 221 | self._connection = Connection(self, self._backend) 222 | 223 | return self._connection 224 | 225 | def transaction( 226 | self, *, force_rollback: bool = False, **kwargs: typing.Any 227 | ) -> "Transaction": 228 | return Transaction(self.connection, force_rollback=force_rollback, **kwargs) 229 | 230 | @contextlib.contextmanager 231 | def force_rollback(self) -> typing.Iterator[None]: 232 | initial = self._force_rollback 233 | self._force_rollback = True 234 | try: 235 | yield 236 | finally: 237 | self._force_rollback = initial 238 | 239 | def _get_backend(self) -> str: 240 | return self.SUPPORTED_BACKENDS.get( 241 | self.url.scheme, self.SUPPORTED_BACKENDS[self.url.dialect] 242 | ) 243 | 244 | 245 | class Connection: 246 | def __init__(self, database: Database, backend: DatabaseBackend) -> None: 247 | self._database = database 248 | self._backend = backend 249 | 250 | self._connection_lock = asyncio.Lock() 251 | self._connection = self._backend.connection() 252 | self._connection_counter = 0 253 | 254 | self._transaction_lock = asyncio.Lock() 255 | self._transaction_stack: typing.List[Transaction] = [] 256 | 257 | self._query_lock = asyncio.Lock() 258 | 259 | async def __aenter__(self) -> "Connection": 260 | async with self._connection_lock: 261 | self._connection_counter += 1 262 | try: 263 | if self._connection_counter == 1: 264 | await self._connection.acquire() 265 | except BaseException as e: 266 | self._connection_counter -= 1 267 | raise e 268 | return self 269 | 270 | async def __aexit__( 271 | self, 272 | exc_type: typing.Optional[typing.Type[BaseException]] = None, 273 | exc_value: typing.Optional[BaseException] = None, 274 | traceback: typing.Optional[TracebackType] = None, 275 | ) -> None: 276 | async with self._connection_lock: 277 | assert self._connection is not None 278 | self._connection_counter -= 1 279 | if self._connection_counter == 0: 280 | await self._connection.release() 281 | self._database._connection = None 282 | 283 | async def fetch_all( 284 | self, 285 | query: typing.Union[ClauseElement, str], 286 | values: typing.Optional[dict] = None, 287 | ) -> typing.List[Record]: 288 | built_query = self._build_query(query, values) 289 | async with self._query_lock: 290 | return await self._connection.fetch_all(built_query) 291 | 292 | async def fetch_one( 293 | self, 294 | query: typing.Union[ClauseElement, str], 295 | values: typing.Optional[dict] = None, 296 | ) -> typing.Optional[Record]: 297 | built_query = self._build_query(query, values) 298 | async with self._query_lock: 299 | return await self._connection.fetch_one(built_query) 300 | 301 | async def fetch_val( 302 | self, 303 | query: typing.Union[ClauseElement, str], 304 | values: typing.Optional[dict] = None, 305 | column: typing.Any = 0, 306 | ) -> typing.Any: 307 | built_query = self._build_query(query, values) 308 | async with self._query_lock: 309 | return await self._connection.fetch_val(built_query, column) 310 | 311 | async def execute( 312 | self, 313 | query: typing.Union[ClauseElement, str], 314 | values: typing.Optional[dict] = None, 315 | ) -> typing.Any: 316 | built_query = self._build_query(query, values) 317 | async with self._query_lock: 318 | return await self._connection.execute(built_query) 319 | 320 | async def execute_many( 321 | self, query: typing.Union[ClauseElement, str], values: list 322 | ) -> None: 323 | queries = [self._build_query(query, values_set) for values_set in values] 324 | async with 
self._query_lock: 325 | await self._connection.execute_many(queries) 326 | 327 | async def iterate( 328 | self, 329 | query: typing.Union[ClauseElement, str], 330 | values: typing.Optional[dict] = None, 331 | ) -> typing.AsyncGenerator[typing.Any, None]: 332 | built_query = self._build_query(query, values) 333 | async with self.transaction(): 334 | async with self._query_lock: 335 | async for record in self._connection.iterate(built_query): 336 | yield record 337 | 338 | def transaction( 339 | self, *, force_rollback: bool = False, **kwargs: typing.Any 340 | ) -> "Transaction": 341 | def connection_callable() -> Connection: 342 | return self 343 | 344 | return Transaction(connection_callable, force_rollback, **kwargs) 345 | 346 | @property 347 | def raw_connection(self) -> typing.Any: 348 | return self._connection.raw_connection 349 | 350 | @staticmethod 351 | def _build_query( 352 | query: typing.Union[ClauseElement, str], values: typing.Optional[dict] = None 353 | ) -> ClauseElement: 354 | if isinstance(query, str): 355 | query = text(query) 356 | 357 | return query.bindparams(**values) if values is not None else query 358 | elif values: 359 | return query.values(**values) # type: ignore 360 | 361 | return query 362 | 363 | 364 | _CallableType = typing.TypeVar("_CallableType", bound=typing.Callable) 365 | 366 | 367 | class Transaction: 368 | def __init__( 369 | self, 370 | connection_callable: typing.Callable[[], Connection], 371 | force_rollback: bool, 372 | **kwargs: typing.Any, 373 | ) -> None: 374 | self._connection_callable = connection_callable 375 | self._force_rollback = force_rollback 376 | self._extra_options = kwargs 377 | 378 | @property 379 | def _connection(self) -> "Connection": 380 | # Returns the same connection if called multiple times 381 | return self._connection_callable() 382 | 383 | @property 384 | def _transaction(self) -> typing.Optional["TransactionBackend"]: 385 | transactions = _ACTIVE_TRANSACTIONS.get() 386 | if transactions is None: 387 | return None 388 | 389 | return transactions.get(self, None) 390 | 391 | @_transaction.setter 392 | def _transaction( 393 | self, transaction: typing.Optional["TransactionBackend"] 394 | ) -> typing.Optional["TransactionBackend"]: 395 | transactions = _ACTIVE_TRANSACTIONS.get() 396 | if transactions is None: 397 | transactions = weakref.WeakKeyDictionary() 398 | else: 399 | transactions = transactions.copy() 400 | 401 | if transaction is None: 402 | transactions.pop(self, None) 403 | else: 404 | transactions[self] = transaction 405 | 406 | _ACTIVE_TRANSACTIONS.set(transactions) 407 | return transactions.get(self, None) 408 | 409 | async def __aenter__(self) -> "Transaction": 410 | """ 411 | Called when entering `async with database.transaction()` 412 | """ 413 | await self.start() 414 | return self 415 | 416 | async def __aexit__( 417 | self, 418 | exc_type: typing.Optional[typing.Type[BaseException]] = None, 419 | exc_value: typing.Optional[BaseException] = None, 420 | traceback: typing.Optional[TracebackType] = None, 421 | ) -> None: 422 | """ 423 | Called when exiting `async with database.transaction()` 424 | """ 425 | if exc_type is not None or self._force_rollback: 426 | await self.rollback() 427 | else: 428 | await self.commit() 429 | 430 | def __await__(self) -> typing.Generator[None, None, "Transaction"]: 431 | """ 432 | Called if using the low-level `transaction = await database.transaction()` 433 | """ 434 | return self.start().__await__() 435 | 436 | def __call__(self, func: _CallableType) -> _CallableType: 437 | 
""" 438 | Called if using `@database.transaction()` as a decorator. 439 | """ 440 | 441 | @functools.wraps(func) 442 | async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: 443 | async with self: 444 | return await func(*args, **kwargs) 445 | 446 | return wrapper # type: ignore 447 | 448 | async def start(self) -> "Transaction": 449 | self._transaction = self._connection._connection.transaction() 450 | 451 | async with self._connection._transaction_lock: 452 | is_root = not self._connection._transaction_stack 453 | await self._connection.__aenter__() 454 | await self._transaction.start( 455 | is_root=is_root, extra_options=self._extra_options 456 | ) 457 | self._connection._transaction_stack.append(self) 458 | return self 459 | 460 | async def commit(self) -> None: 461 | async with self._connection._transaction_lock: 462 | assert self._connection._transaction_stack[-1] is self 463 | self._connection._transaction_stack.pop() 464 | assert self._transaction is not None 465 | await self._transaction.commit() 466 | await self._connection.__aexit__() 467 | self._transaction = None 468 | 469 | async def rollback(self) -> None: 470 | async with self._connection._transaction_lock: 471 | assert self._connection._transaction_stack[-1] is self 472 | self._connection._transaction_stack.pop() 473 | assert self._transaction is not None 474 | await self._transaction.rollback() 475 | await self._connection.__aexit__() 476 | self._transaction = None 477 | 478 | 479 | class _EmptyNetloc(str): 480 | def __bool__(self) -> bool: 481 | return True 482 | 483 | 484 | class DatabaseURL: 485 | def __init__(self, url: typing.Union[str, "DatabaseURL"]): 486 | if isinstance(url, DatabaseURL): 487 | self._url: str = url._url 488 | elif isinstance(url, str): 489 | self._url = url 490 | else: 491 | raise TypeError( 492 | f"Invalid type for DatabaseURL. 
Expected str or DatabaseURL, got {type(url)}" 493 | ) 494 | 495 | @property 496 | def components(self) -> SplitResult: 497 | if not hasattr(self, "_components"): 498 | self._components = urlsplit(self._url) 499 | return self._components 500 | 501 | @property 502 | def scheme(self) -> str: 503 | return self.components.scheme 504 | 505 | @property 506 | def dialect(self) -> str: 507 | return self.components.scheme.split("+")[0] 508 | 509 | @property 510 | def driver(self) -> str: 511 | if "+" not in self.components.scheme: 512 | return "" 513 | return self.components.scheme.split("+", 1)[1] 514 | 515 | @property 516 | def userinfo(self) -> typing.Optional[bytes]: 517 | if self.components.username: 518 | info = self.components.username 519 | if self.components.password: 520 | info += ":" + self.components.password 521 | return info.encode("utf-8") 522 | return None 523 | 524 | @property 525 | def username(self) -> typing.Optional[str]: 526 | if self.components.username is None: 527 | return None 528 | return unquote(self.components.username) 529 | 530 | @property 531 | def password(self) -> typing.Optional[str]: 532 | if self.components.password is None: 533 | return None 534 | return unquote(self.components.password) 535 | 536 | @property 537 | def hostname(self) -> typing.Optional[str]: 538 | return ( 539 | self.components.hostname 540 | or self.options.get("host") 541 | or self.options.get("unix_sock") 542 | ) 543 | 544 | @property 545 | def port(self) -> typing.Optional[int]: 546 | return self.components.port 547 | 548 | @property 549 | def netloc(self) -> typing.Optional[str]: 550 | return self.components.netloc 551 | 552 | @property 553 | def database(self) -> str: 554 | path = self.components.path 555 | if path.startswith("/"): 556 | path = path[1:] 557 | return unquote(path) 558 | 559 | @property 560 | def options(self) -> dict: 561 | if not hasattr(self, "_options"): 562 | self._options = dict(parse_qsl(self.components.query)) 563 | return self._options 564 | 565 | def replace(self, **kwargs: typing.Any) -> "DatabaseURL": 566 | if ( 567 | "username" in kwargs 568 | or "password" in kwargs 569 | or "hostname" in kwargs 570 | or "port" in kwargs 571 | ): 572 | hostname = kwargs.pop("hostname", self.hostname) 573 | port = kwargs.pop("port", self.port) 574 | username = kwargs.pop("username", self.components.username) 575 | password = kwargs.pop("password", self.components.password) 576 | 577 | netloc = hostname 578 | if port is not None: 579 | netloc += f":{port}" 580 | if username is not None: 581 | userpass = username 582 | if password is not None: 583 | userpass += f":{password}" 584 | netloc = f"{userpass}@{netloc}" 585 | 586 | kwargs["netloc"] = netloc 587 | 588 | if "database" in kwargs: 589 | kwargs["path"] = "/" + kwargs.pop("database") 590 | 591 | if "dialect" in kwargs or "driver" in kwargs: 592 | dialect = kwargs.pop("dialect", self.dialect) 593 | driver = kwargs.pop("driver", self.driver) 594 | kwargs["scheme"] = f"{dialect}+{driver}" if driver else dialect 595 | 596 | if not kwargs.get("netloc", self.netloc): 597 | # Using an empty string that evaluates as True means we end up 598 | # with URLs like `sqlite:///database` instead of `sqlite:/database` 599 | kwargs["netloc"] = _EmptyNetloc() 600 | 601 | components = self.components._replace(**kwargs) 602 | return self.__class__(components.geturl()) 603 | 604 | @property 605 | def obscure_password(self) -> str: 606 | if self.password: 607 | return self.replace(password="********")._url 608 | return self._url 609 | 610 | def 
__str__(self) -> str: 611 | return self._url 612 | 613 | def __repr__(self) -> str: 614 | return f"{self.__class__.__name__}({repr(self.obscure_password)})" 615 | 616 | def __eq__(self, other: typing.Any) -> bool: 617 | return str(self) == str(other) 618 | -------------------------------------------------------------------------------- /databases/importer.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import typing 3 | 4 | 5 | class ImportFromStringError(Exception): 6 | pass 7 | 8 | 9 | def import_from_string(import_str: str) -> typing.Any: 10 | module_str, _, attrs_str = import_str.partition(":") 11 | if not module_str or not attrs_str: 12 | message = ( 13 | 'Import string "{import_str}" must be in format "— ⭐️ —
97 |Databases is BSD licensed code. Designed & built in Brighton, England.
98 | 99 | [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ 100 | [sqlalchemy-core-tutorial]: https://docs.sqlalchemy.org/en/latest/core/tutorial.html 101 | [alembic]: https://alembic.sqlalchemy.org/en/latest/ 102 | [psycopg2]: https://www.psycopg.org/ 103 | [pymysql]: https://github.com/PyMySQL/PyMySQL 104 | [asyncpg]: https://github.com/MagicStack/asyncpg 105 | [aiopg]: https://github.com/aio-libs/aiopg 106 | [aiomysql]: https://github.com/aio-libs/aiomysql 107 | [asyncmy]: https://github.com/long2ice/asyncmy 108 | [aiosqlite]: https://github.com/omnilib/aiosqlite 109 | 110 | [starlette]: https://github.com/encode/starlette 111 | [sanic]: https://github.com/huge-success/sanic 112 | [responder]: https://github.com/kennethreitz/responder 113 | [quart]: https://gitlab.com/pgjones/quart 114 | [aiohttp]: https://github.com/aio-libs/aiohttp 115 | [tornado]: https://github.com/tornadoweb/tornado 116 | [fastapi]: https://github.com/tiangolo/fastapi 117 | -------------------------------------------------------------------------------- /docs/tests_and_migrations.md: -------------------------------------------------------------------------------- 1 | # Tests and Migrations 2 | 3 | Databases is designed to allow you to fully integrate with production-ready 4 | services, with API support for test isolation, and integration 5 | with [Alembic][alembic] for database migrations. 6 | 7 | ## Test isolation 8 | 9 | For strict test isolation you will always want to roll back the test database 10 | to a clean state between each test case: 11 | 12 | ```python 13 | database = Database(DATABASE_URL, force_rollback=True) 14 | ``` 15 | 16 | This will ensure that all database connections are run within a transaction 17 | that rolls back once the database is disconnected. 18 | 19 | If you're integrating against a web framework you'll typically want to 20 | use something like the following pattern: 21 | 22 | ```python 23 | if TESTING: 24 | database = Database(TEST_DATABASE_URL, force_rollback=True) 25 | else: 26 | database = Database(DATABASE_URL) 27 | ``` 28 | 29 | This will give you test cases that run against a different database from 30 | the development database, with strict test isolation so long as you make sure 31 | to connect to and disconnect from the database between test cases. 32 | 33 | For a lower level API you can explicitly create force-rollback transactions: 34 | 35 | ```python 36 | async with database.transaction(force_rollback=True): 37 | ... 38 | ``` 39 | 40 | ## Migrations 41 | 42 | Because `databases` uses SQLAlchemy core, you can integrate with [Alembic][alembic] 43 | for database migration support. 44 | 45 | ```shell 46 | $ pip install alembic 47 | $ alembic init migrations 48 | ``` 49 | 50 | You'll want to set things up so that Alembic references the configured 51 | `DATABASE_URL`, and uses your table metadata. 52 | 53 | In `alembic.ini` remove the following line: 54 | 55 | ```shell 56 | sqlalchemy.url = driver://user:pass@localhost/dbname 57 | ``` 58 | 59 | In `migrations/env.py`, you need to set the ``'sqlalchemy.url'`` configuration key, 60 | and the `target_metadata` variable. You'll want something like this: 61 | 62 | ```python 63 | # The Alembic Config object. 64 | config = context.config 65 | 66 | # Configure Alembic to use our DATABASE_URL and our table definitions. 67 | # These are just examples - the exact setup will depend on whatever 68 | # framework you're integrating against.
69 | from myapp.settings import DATABASE_URL 70 | from myapp.tables import metadata 71 | 72 | config.set_main_option('sqlalchemy.url', str(DATABASE_URL)) 73 | target_metadata = metadata 74 | 75 | ... 76 | ``` 77 | 78 | Note that migrations will use a standard synchronous database driver, 79 | rather than using the async drivers that `databases` provides support for. 80 | 81 | This will also be the case if you're using SQLAlchemy's standard tooling, such 82 | as using `metadata.create_all(engine)` to set up the database tables. 83 | 84 | **Note for MySQL**: 85 | 86 | For MySQL you'll probably need to explicitly specify the `pymysql` dialect when 87 | using Alembic since the default MySQL dialect does not support Python 3. 88 | 89 | If you're using the `databases.DatabaseURL` datatype, you can obtain this using 90 | `DATABASE_URL.replace(dialect="pymysql")`. 91 | 92 | [alembic]: https://alembic.sqlalchemy.org/en/latest/ 93 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Databases 2 | site_description: Async database support for Python. 3 | 4 | theme: 5 | name: 'material' 6 | 7 | repo_name: encode/databases 8 | repo_url: https://github.com/encode/databases 9 | # edit_uri: "" 10 | 11 | nav: 12 | - Introduction: 'index.md' 13 | - Database Queries: 'database_queries.md' 14 | - Connections & Transactions: 'connections_and_transactions.md' 15 | - Tests & Migrations: 'tests_and_migrations.md' 16 | - Contributing: 'contributing.md' 17 | 18 | markdown_extensions: 19 | - mkautodoc 20 | - admonition 21 | - pymdownx.highlight 22 | - pymdownx.superfences 23 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -e . 2 | 3 | # Async database drivers 4 | asyncmy==0.2.9 5 | aiomysql==0.2.0 6 | aiopg==1.4.0 7 | aiosqlite==0.20.0 8 | asyncpg==0.29.0 9 | 10 | # Sync database drivers for standard tooling around setup/teardown/migrations. 11 | psycopg==3.1.18 12 | pymysql==1.1.0 13 | 14 | # Testing 15 | autoflake==1.4 16 | black==22.6.0 17 | httpx==0.24.1 18 | isort==5.10.1 19 | mypy==0.971 20 | pytest==7.1.2 21 | pytest-cov==3.0.0 22 | starlette==0.36.2 23 | requests==2.31.0 24 | 25 | # Documentation 26 | mkdocs==1.3.1 27 | mkdocs-material==8.3.9 28 | mkautodoc==0.1.0 29 | 30 | # Packaging 31 | twine==4.0.1 32 | wheel==0.38.1 33 | setuptools==69.0.3 34 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # Development Scripts 2 | 3 | * `scripts/build` - Build package and documentation. 4 | * `scripts/check` - Check lint and formatting. 5 | * `scripts/clean` - Delete any build artifacts. 6 | * `scripts/coverage` - Check test coverage. 7 | * `scripts/docs` - Run documentation server locally. 8 | * `scripts/install` - Install dependencies in a virtual environment. 9 | * `scripts/lint` - Run the code linting. 10 | * `scripts/publish` - Publish the latest version to PyPI. 11 | * `scripts/test` - Run the test suite. 12 | 13 | Styled after GitHub's ["Scripts to Rule Them All"](https://github.com/github/scripts-to-rule-them-all).
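For example, `scripts/install -p python3.11` installs the dependencies into a virtual environment created with a specific Python interpreter (see the `-p` option handled in `scripts/install`).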
14 | -------------------------------------------------------------------------------- /scripts/build: -------------------------------------------------------------------------------- 1 | #!/bin/sh -e 2 | 3 | if [ -d 'venv' ] ; then 4 | PREFIX="venv/bin/" 5 | else 6 | PREFIX="" 7 | fi 8 | 9 | set -x 10 | 11 | ${PREFIX}python setup.py sdist bdist_wheel 12 | ${PREFIX}twine check dist/* 13 | ${PREFIX}mkdocs build 14 | -------------------------------------------------------------------------------- /scripts/check: -------------------------------------------------------------------------------- 1 | #!/bin/sh -e 2 | 3 | export PREFIX="" 4 | if [ -d 'venv' ] ; then 5 | export PREFIX="venv/bin/" 6 | fi 7 | export SOURCE_FILES="databases tests" 8 | 9 | set -x 10 | 11 | ${PREFIX}isort --check --diff --project=databases $SOURCE_FILES 12 | ${PREFIX}black --check --diff $SOURCE_FILES 13 | ${PREFIX}mypy databases 14 | -------------------------------------------------------------------------------- /scripts/clean: -------------------------------------------------------------------------------- 1 | #!/bin/sh -e 2 | 3 | if [ -d 'dist' ] ; then 4 | rm -r dist 5 | fi 6 | if [ -d 'site' ] ; then 7 | rm -r site 8 | fi 9 | if [ -d 'databases.egg-info' ] ; then 10 | rm -r databases.egg-info 11 | fi 12 | if [ -d '.mypy_cache' ] ; then 13 | rm -r .mypy_cache 14 | fi 15 | if [ -d '.pytest_cache' ] ; then 16 | rm -r .pytest_cache 17 | fi 18 | 19 | find databases -type f -name "*.py[co]" -delete 20 | find databases -type d -name __pycache__ -delete 21 | -------------------------------------------------------------------------------- /scripts/coverage: -------------------------------------------------------------------------------- 1 | #!/bin/sh -e 2 | 3 | export PREFIX="" 4 | if [ -d 'venv' ] ; then 5 | export PREFIX="venv/bin/" 6 | fi 7 | 8 | set -x 9 | 10 | ${PREFIX}coverage report --show-missing --skip-covered --fail-under=100 11 | -------------------------------------------------------------------------------- /scripts/docs: -------------------------------------------------------------------------------- 1 | #!/bin/sh -e 2 | 3 | export PREFIX="" 4 | if [ -d 'venv' ] ; then 5 | export PREFIX="venv/bin/" 6 | fi 7 | 8 | set -x 9 | 10 | ${PREFIX}mkdocs serve 11 | -------------------------------------------------------------------------------- /scripts/install: -------------------------------------------------------------------------------- 1 | #!/bin/sh -e 2 | 3 | # Use the Python executable provided from the `-p` option, or a default. 4 | [ "$1" = "-p" ] && PYTHON=$2 || PYTHON="python3" 5 | 6 | REQUIREMENTS="requirements.txt" 7 | VENV="venv" 8 | 9 | set -x 10 | 11 | if [ -z "$GITHUB_ACTIONS" ]; then 12 | "$PYTHON" -m venv "$VENV" 13 | PIP="$VENV/bin/pip" 14 | else 15 | PIP="pip" 16 | fi 17 | 18 | "$PIP" install -r "$REQUIREMENTS" 19 | "$PIP" install -e . 
20 | -------------------------------------------------------------------------------- /scripts/lint: -------------------------------------------------------------------------------- 1 | #!/bin/sh -e 2 | 3 | export PREFIX="" 4 | if [ -d 'venv' ] ; then 5 | export PREFIX="venv/bin/" 6 | fi 7 | 8 | set -x 9 | 10 | ${PREFIX}autoflake --in-place --recursive databases tests 11 | ${PREFIX}isort --project=databases databases tests 12 | ${PREFIX}black databases tests 13 | ${PREFIX}mypy databases 14 | -------------------------------------------------------------------------------- /scripts/publish: -------------------------------------------------------------------------------- 1 | #!/bin/sh -e 2 | 3 | VERSION_FILE="databases/__init__.py" 4 | 5 | if [ -d 'venv' ] ; then 6 | PREFIX="venv/bin/" 7 | else 8 | PREFIX="" 9 | fi 10 | 11 | if [ ! -z "$GITHUB_ACTIONS" ]; then 12 | git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com" 13 | git config --local user.name "GitHub Action" 14 | 15 | VERSION=`grep __version__ ${VERSION_FILE} | grep -o '[0-9][^"]*'` 16 | 17 | if [ "refs/tags/${VERSION}" != "${GITHUB_REF}" ] ; then 18 | echo "GitHub Ref '${GITHUB_REF}' did not match package version '${VERSION}'" 19 | exit 1 20 | fi 21 | fi 22 | 23 | set -x 24 | 25 | ${PREFIX}twine upload dist/* 26 | ${PREFIX}mkdocs gh-deploy --force 27 | -------------------------------------------------------------------------------- /scripts/test: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | export PREFIX="" 4 | if [ -d 'venv' ] ; then 5 | export PREFIX="venv/bin/" 6 | fi 7 | 8 | if [ -z "$TEST_DATABASE_URLS" ]; then 9 | echo "Variable TEST_DATABASE_URLS must be set." 10 | exit 1 11 | fi 12 | 13 | set -ex 14 | 15 | if [ -z $GITHUB_ACTIONS ]; then 16 | scripts/check 17 | fi 18 | 19 | ${PREFIX}coverage run -m pytest $@ 20 | 21 | if [ -z $GITHUB_ACTIONS ]; then 22 | scripts/coverage 23 | fi 24 | 25 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [mypy] 2 | disallow_untyped_defs = True 3 | ignore_missing_imports = True 4 | no_implicit_optional = True 5 | disallow_any_generics = false 6 | disallow_untyped_decorators = true 7 | implicit_reexport = true 8 | disallow_incomplete_defs = true 9 | exclude = databases/backends 10 | 11 | [tool:isort] 12 | profile = black 13 | combine_as_imports = True 14 | 15 | [coverage:run] 16 | source = databases, tests 17 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import re 6 | 7 | from setuptools import setup 8 | 9 | 10 | def get_version(package): 11 | """ 12 | Return package version as listed in `__version__` in `__init__.py`. 13 | """ 14 | with open(os.path.join(package, "__init__.py")) as f: 15 | return re.search("__version__ = ['\"]([^'\"]+)['\"]", f.read()).group(1) 16 | 17 | 18 | def get_long_description(): 19 | """ 20 | Return the README. 21 | """ 22 | with open("README.md", encoding="utf8") as f: 23 | return f.read() 24 | 25 | 26 | def get_packages(package): 27 | """ 28 | Return root package and all sub-packages.
29 | """ 30 | return [ 31 | dirpath 32 | for dirpath, dirnames, filenames in os.walk(package) 33 | if os.path.exists(os.path.join(dirpath, "__init__.py")) 34 | ] 35 | 36 | 37 | setup( 38 | name="databases", 39 | version=get_version("databases"), 40 | python_requires=">=3.8", 41 | url="https://github.com/encode/databases", 42 | license="BSD", 43 | description="Async database support for Python.", 44 | long_description=get_long_description(), 45 | long_description_content_type="text/markdown", 46 | author="Tom Christie", 47 | author_email="tom@tomchristie.com", 48 | packages=get_packages("databases"), 49 | package_data={"databases": ["py.typed"]}, 50 | install_requires=["sqlalchemy>=2.0.7"], 51 | extras_require={ 52 | "postgresql": ["asyncpg"], 53 | "asyncpg": ["asyncpg"], 54 | "aiopg": ["aiopg"], 55 | "mysql": ["aiomysql"], 56 | "aiomysql": ["aiomysql"], 57 | "asyncmy": ["asyncmy"], 58 | "sqlite": ["aiosqlite"], 59 | "aiosqlite": ["aiosqlite"], 60 | }, 61 | classifiers=[ 62 | "Development Status :: 3 - Alpha", 63 | "Environment :: Web Environment", 64 | "Intended Audience :: Developers", 65 | "License :: OSI Approved :: BSD License", 66 | "Operating System :: OS Independent", 67 | "Topic :: Internet :: WWW/HTTP", 68 | "Programming Language :: Python :: 3", 69 | "Programming Language :: Python :: 3.8", 70 | "Programming Language :: Python :: 3.9", 71 | "Programming Language :: Python :: 3.10", 72 | "Programming Language :: Python :: 3.11", 73 | "Programming Language :: Python :: 3.12", 74 | "Programming Language :: Python :: 3 :: Only", 75 | ], 76 | zip_safe=False, 77 | ) 78 | -------------------------------------------------------------------------------- /tests/importer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/encode/databases/ae3fb16f40201d9ed0ed31bea31a289127169568/tests/importer/__init__.py -------------------------------------------------------------------------------- /tests/importer/raise_import_error.py: -------------------------------------------------------------------------------- 1 | # Used by test_importer.py 2 | 3 | myattr = 123 4 | 5 | import does_not_exist 6 | -------------------------------------------------------------------------------- /tests/test_connection_options.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for the backend connection arguments. 
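These tests mostly inspect the keyword arguments each backend derives from a
database URL or from explicit options; only the `*_connect` test opens a real
connection pool.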
3 | """ 4 | import sys 5 | 6 | import pytest 7 | 8 | from databases.backends.aiopg import AiopgBackend 9 | from databases.backends.postgres import PostgresBackend 10 | from databases.core import DatabaseURL 11 | from tests.test_databases import DATABASE_URLS, async_adapter 12 | 13 | if sys.version_info >= (3, 7): # pragma: no cover 14 | from databases.backends.asyncmy import AsyncMyBackend 15 | 16 | 17 | if sys.version_info < (3, 10): # pragma: no cover 18 | from databases.backends.mysql import MySQLBackend 19 | 20 | 21 | def test_postgres_pool_size(): 22 | backend = PostgresBackend("postgres://localhost/database?min_size=1&max_size=20") 23 | kwargs = backend._get_connection_kwargs() 24 | assert kwargs == {"min_size": 1, "max_size": 20} 25 | 26 | 27 | @async_adapter 28 | async def test_postgres_pool_size_connect(): 29 | for url in DATABASE_URLS: 30 | if DatabaseURL(url).dialect != "postgresql": 31 | continue 32 | backend = PostgresBackend(url + "?min_size=1&max_size=20") 33 | await backend.connect() 34 | await backend.disconnect() 35 | 36 | 37 | def test_postgres_explicit_pool_size(): 38 | backend = PostgresBackend("postgres://localhost/database", min_size=1, max_size=20) 39 | kwargs = backend._get_connection_kwargs() 40 | assert kwargs == {"min_size": 1, "max_size": 20} 41 | 42 | 43 | def test_postgres_ssl(): 44 | backend = PostgresBackend("postgres://localhost/database?ssl=true") 45 | kwargs = backend._get_connection_kwargs() 46 | assert kwargs == {"ssl": True} 47 | 48 | 49 | def test_postgres_ssl_verify_full(): 50 | backend = PostgresBackend("postgres://localhost/database?ssl=verify-full") 51 | kwargs = backend._get_connection_kwargs() 52 | assert kwargs == {"ssl": "verify-full"} 53 | 54 | 55 | def test_postgres_explicit_ssl(): 56 | backend = PostgresBackend("postgres://localhost/database", ssl=True) 57 | kwargs = backend._get_connection_kwargs() 58 | assert kwargs == {"ssl": True} 59 | 60 | 61 | def test_postgres_explicit_ssl_verify_full(): 62 | backend = PostgresBackend("postgres://localhost/database", ssl="verify-full") 63 | kwargs = backend._get_connection_kwargs() 64 | assert kwargs == {"ssl": "verify-full"} 65 | 66 | 67 | def test_postgres_no_extra_options(): 68 | backend = PostgresBackend("postgres://localhost/database") 69 | kwargs = backend._get_connection_kwargs() 70 | assert kwargs == {} 71 | 72 | 73 | def test_postgres_password_as_callable(): 74 | def gen_password(): 75 | return "Foo" 76 | 77 | backend = PostgresBackend( 78 | "postgres://:password@localhost/database", password=gen_password 79 | ) 80 | kwargs = backend._get_connection_kwargs() 81 | assert kwargs == {"password": gen_password} 82 | assert kwargs["password"]() == "Foo" 83 | 84 | 85 | @pytest.mark.skipif(sys.version_info >= (3, 10), reason="requires python3.9 or lower") 86 | def test_mysql_pool_size(): 87 | backend = MySQLBackend("mysql://localhost/database?min_size=1&max_size=20") 88 | kwargs = backend._get_connection_kwargs() 89 | assert kwargs == {"minsize": 1, "maxsize": 20} 90 | 91 | 92 | @pytest.mark.skipif(sys.version_info >= (3, 10), reason="requires python3.9 or lower") 93 | def test_mysql_unix_socket(): 94 | backend = MySQLBackend( 95 | "mysql+aiomysql://username:password@/testsuite?unix_socket=/tmp/mysqld/mysqld.sock" 96 | ) 97 | kwargs = backend._get_connection_kwargs() 98 | assert kwargs == {"unix_socket": "/tmp/mysqld/mysqld.sock"} 99 | 100 | 101 | @pytest.mark.skipif(sys.version_info >= (3, 10), reason="requires python3.9 or lower") 102 | def test_mysql_explicit_pool_size(): 103 | backend = 
MySQLBackend("mysql://localhost/database", min_size=1, max_size=20) 104 | kwargs = backend._get_connection_kwargs() 105 | assert kwargs == {"minsize": 1, "maxsize": 20} 106 | 107 | 108 | @pytest.mark.skipif(sys.version_info >= (3, 10), reason="requires python3.9 or lower") 109 | def test_mysql_ssl(): 110 | backend = MySQLBackend("mysql://localhost/database?ssl=true") 111 | kwargs = backend._get_connection_kwargs() 112 | assert kwargs == {"ssl": True} 113 | 114 | 115 | @pytest.mark.skipif(sys.version_info >= (3, 10), reason="requires python3.9 or lower") 116 | def test_mysql_explicit_ssl(): 117 | backend = MySQLBackend("mysql://localhost/database", ssl=True) 118 | kwargs = backend._get_connection_kwargs() 119 | assert kwargs == {"ssl": True} 120 | 121 | 122 | @pytest.mark.skipif(sys.version_info >= (3, 10), reason="requires python3.9 or lower") 123 | def test_mysql_pool_recycle(): 124 | backend = MySQLBackend("mysql://localhost/database?pool_recycle=20") 125 | kwargs = backend._get_connection_kwargs() 126 | assert kwargs == {"pool_recycle": 20} 127 | 128 | 129 | @pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher") 130 | def test_asyncmy_pool_size(): 131 | backend = AsyncMyBackend( 132 | "mysql+asyncmy://localhost/database?min_size=1&max_size=20" 133 | ) 134 | kwargs = backend._get_connection_kwargs() 135 | assert kwargs == {"minsize": 1, "maxsize": 20} 136 | 137 | 138 | @pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher") 139 | def test_asyncmy_unix_socket(): 140 | backend = AsyncMyBackend( 141 | "mysql+asyncmy://username:password@/testsuite?unix_socket=/tmp/mysqld/mysqld.sock" 142 | ) 143 | kwargs = backend._get_connection_kwargs() 144 | assert kwargs == {"unix_socket": "/tmp/mysqld/mysqld.sock"} 145 | 146 | 147 | @pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher") 148 | def test_asyncmy_explicit_pool_size(): 149 | backend = AsyncMyBackend("mysql://localhost/database", min_size=1, max_size=20) 150 | kwargs = backend._get_connection_kwargs() 151 | assert kwargs == {"minsize": 1, "maxsize": 20} 152 | 153 | 154 | @pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher") 155 | def test_asyncmy_ssl(): 156 | backend = AsyncMyBackend("mysql+asyncmy://localhost/database?ssl=true") 157 | kwargs = backend._get_connection_kwargs() 158 | assert kwargs == {"ssl": True} 159 | 160 | 161 | @pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher") 162 | def test_asyncmy_explicit_ssl(): 163 | backend = AsyncMyBackend("mysql+asyncmy://localhost/database", ssl=True) 164 | kwargs = backend._get_connection_kwargs() 165 | assert kwargs == {"ssl": True} 166 | 167 | 168 | @pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher") 169 | def test_asyncmy_pool_recycle(): 170 | backend = AsyncMyBackend("mysql+asyncmy://localhost/database?pool_recycle=20") 171 | kwargs = backend._get_connection_kwargs() 172 | assert kwargs == {"pool_recycle": 20} 173 | 174 | 175 | def test_aiopg_pool_size(): 176 | backend = AiopgBackend( 177 | "postgresql+aiopg://localhost/database?min_size=1&max_size=20" 178 | ) 179 | kwargs = backend._get_connection_kwargs() 180 | assert kwargs == {"minsize": 1, "maxsize": 20} 181 | 182 | 183 | def test_aiopg_explicit_pool_size(): 184 | backend = AiopgBackend( 185 | "postgresql+aiopg://localhost/database", min_size=1, max_size=20 186 | ) 187 | kwargs = backend._get_connection_kwargs() 188 | assert kwargs == {"minsize": 
1, "maxsize": 20} 189 | 190 | 191 | def test_aiopg_ssl(): 192 | backend = AiopgBackend("postgresql+aiopg://localhost/database?ssl=true") 193 | kwargs = backend._get_connection_kwargs() 194 | assert kwargs == {"ssl": True} 195 | 196 | 197 | def test_aiopg_explicit_ssl(): 198 | backend = AiopgBackend("postgresql+aiopg://localhost/database", ssl=True) 199 | kwargs = backend._get_connection_kwargs() 200 | assert kwargs == {"ssl": True} 201 | -------------------------------------------------------------------------------- /tests/test_database_url.py: -------------------------------------------------------------------------------- 1 | from urllib.parse import quote 2 | 3 | import pytest 4 | 5 | from databases import DatabaseURL 6 | 7 | 8 | def test_database_url_repr(): 9 | u = DatabaseURL("postgresql://localhost/name") 10 | assert repr(u) == "DatabaseURL('postgresql://localhost/name')" 11 | 12 | u = DatabaseURL("postgresql://username@localhost/name") 13 | assert repr(u) == "DatabaseURL('postgresql://username@localhost/name')" 14 | 15 | u = DatabaseURL("postgresql://username:password@localhost/name") 16 | assert repr(u) == "DatabaseURL('postgresql://username:********@localhost/name')" 17 | 18 | u = DatabaseURL(f"postgresql://username:{quote('[password')}@localhost/name") 19 | assert repr(u) == "DatabaseURL('postgresql://username:********@localhost/name')" 20 | 21 | 22 | def test_database_url_properties(): 23 | u = DatabaseURL("postgresql+asyncpg://username:password@localhost:123/mydatabase") 24 | assert u.dialect == "postgresql" 25 | assert u.driver == "asyncpg" 26 | assert u.username == "username" 27 | assert u.password == "password" 28 | assert u.hostname == "localhost" 29 | assert u.port == 123 30 | assert u.database == "mydatabase" 31 | 32 | u = DatabaseURL( 33 | "postgresql://username:password@/mydatabase?host=/var/run/postgresql/.s.PGSQL.5432" 34 | ) 35 | assert u.dialect == "postgresql" 36 | assert u.username == "username" 37 | assert u.password == "password" 38 | assert u.hostname == "/var/run/postgresql/.s.PGSQL.5432" 39 | assert u.database == "mydatabase" 40 | 41 | u = DatabaseURL( 42 | "postgresql://username:password@/mydatabase?unix_sock=/var/run/postgresql/.s.PGSQL.5432" 43 | ) 44 | assert u.hostname == "/var/run/postgresql/.s.PGSQL.5432" 45 | 46 | 47 | def test_database_url_escape(): 48 | u = DatabaseURL(f"postgresql://username:{quote('[password')}@localhost/mydatabase") 49 | assert u.username == "username" 50 | assert u.password == "[password" 51 | assert u.userinfo == f"username:{quote('[password')}".encode("utf-8") 52 | 53 | u2 = DatabaseURL(u) 54 | assert u2.password == "[password" 55 | 56 | u3 = DatabaseURL(str(u)) 57 | assert u3.password == "[password" 58 | 59 | 60 | def test_database_url_constructor(): 61 | with pytest.raises(TypeError): 62 | DatabaseURL(("postgresql", "username", "password", "localhost", "mydatabase")) 63 | 64 | u = DatabaseURL("postgresql+asyncpg://username:password@localhost:123/mydatabase") 65 | assert DatabaseURL(u) == u 66 | 67 | 68 | def test_database_url_options(): 69 | u = DatabaseURL("postgresql://localhost/mydatabase?pool_size=20&ssl=true") 70 | assert u.options == {"pool_size": "20", "ssl": "true"} 71 | 72 | u = DatabaseURL( 73 | "mysql+asyncmy://username:password@/testsuite?unix_socket=/tmp/mysqld/mysqld.sock" 74 | ) 75 | assert u.options == {"unix_socket": "/tmp/mysqld/mysqld.sock"} 76 | 77 | 78 | def test_replace_database_url_components(): 79 | u = DatabaseURL("postgresql://localhost/mydatabase") 80 | 81 | assert u.database == "mydatabase" 82 | 
new = u.replace(database="test_" + u.database) 83 | assert new.database == "test_mydatabase" 84 | assert str(new) == "postgresql://localhost/test_mydatabase" 85 | 86 | assert u.driver == "" 87 | new = u.replace(driver="asyncpg") 88 | assert new.driver == "asyncpg" 89 | assert str(new) == "postgresql+asyncpg://localhost/mydatabase" 90 | 91 | assert u.port is None 92 | new = u.replace(port=123) 93 | assert new.port == 123 94 | assert str(new) == "postgresql://localhost:123/mydatabase" 95 | 96 | assert u.username is None 97 | assert u.userinfo is None 98 | 99 | u = DatabaseURL("sqlite:///mydatabase") 100 | assert u.database == "mydatabase" 101 | new = u.replace(database="test_" + u.database) 102 | assert new.database == "test_mydatabase" 103 | assert str(new) == "sqlite:///test_mydatabase" 104 | 105 | u = DatabaseURL("sqlite:////absolute/path") 106 | assert u.database == "/absolute/path" 107 | new = u.replace(database=u.database + "_test") 108 | assert new.database == "/absolute/path_test" 109 | assert str(new) == "sqlite:////absolute/path_test" 110 | -------------------------------------------------------------------------------- /tests/test_databases.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import datetime 3 | import decimal 4 | import enum 5 | import functools 6 | import gc 7 | import itertools 8 | import os 9 | import sqlite3 10 | from typing import MutableMapping 11 | from unittest.mock import MagicMock, patch 12 | 13 | import pytest 14 | import sqlalchemy 15 | 16 | from databases import Database, DatabaseURL 17 | 18 | assert "TEST_DATABASE_URLS" in os.environ, "TEST_DATABASE_URLS is not set." 19 | 20 | DATABASE_URLS = [url.strip() for url in os.environ["TEST_DATABASE_URLS"].split(",")] 21 | 22 | 23 | class AsyncMock(MagicMock): 24 | async def __call__(self, *args, **kwargs): 25 | return super(AsyncMock, self).__call__(*args, **kwargs) 26 | 27 | 28 | class MyEpochType(sqlalchemy.types.TypeDecorator): 29 | impl = sqlalchemy.Integer 30 | 31 | epoch = datetime.date(1970, 1, 1) 32 | 33 | def process_bind_param(self, value, dialect): 34 | return (value - self.epoch).days 35 | 36 | def process_result_value(self, value, dialect): 37 | return self.epoch + datetime.timedelta(days=value) 38 | 39 | 40 | metadata = sqlalchemy.MetaData() 41 | 42 | notes = sqlalchemy.Table( 43 | "notes", 44 | metadata, 45 | sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), 46 | sqlalchemy.Column("text", sqlalchemy.String(length=100)), 47 | sqlalchemy.Column("completed", sqlalchemy.Boolean), 48 | ) 49 | 50 | # Used to test DateTime 51 | articles = sqlalchemy.Table( 52 | "articles", 53 | metadata, 54 | sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), 55 | sqlalchemy.Column("title", sqlalchemy.String(length=100)), 56 | sqlalchemy.Column("published", sqlalchemy.DateTime), 57 | ) 58 | 59 | # Used to test Date 60 | events = sqlalchemy.Table( 61 | "events", 62 | metadata, 63 | sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), 64 | sqlalchemy.Column("date", sqlalchemy.Date), 65 | ) 66 | 67 | 68 | # Used to test Time 69 | daily_schedule = sqlalchemy.Table( 70 | "daily_schedule", 71 | metadata, 72 | sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), 73 | sqlalchemy.Column("time", sqlalchemy.Time), 74 | ) 75 | 76 | 77 | class TshirtSize(enum.Enum): 78 | SMALL = "SMALL" 79 | MEDIUM = "MEDIUM" 80 | LARGE = "LARGE" 81 | XL = "XL" 82 | 83 | 84 | class TshirtColor(enum.Enum): 85 | BLUE = 0 86 | GREEN = 1 87 | YELLOW 
= 2 88 | RED = 3 89 | 90 | 91 | # Used to test Enum 92 | tshirt_size = sqlalchemy.Table( 93 | "tshirt_size", 94 | metadata, 95 | sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), 96 | sqlalchemy.Column("size", sqlalchemy.Enum(TshirtSize)), 97 | sqlalchemy.Column("color", sqlalchemy.Enum(TshirtColor)), 98 | ) 99 | 100 | # Used to test JSON 101 | session = sqlalchemy.Table( 102 | "session", 103 | metadata, 104 | sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), 105 | sqlalchemy.Column("data", sqlalchemy.JSON), 106 | ) 107 | 108 | # Used to test custom column types 109 | custom_date = sqlalchemy.Table( 110 | "custom_date", 111 | metadata, 112 | sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), 113 | sqlalchemy.Column("title", sqlalchemy.String(length=100)), 114 | sqlalchemy.Column("published", MyEpochType), 115 | ) 116 | 117 | # Used to test Numeric 118 | prices = sqlalchemy.Table( 119 | "prices", 120 | metadata, 121 | sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), 122 | sqlalchemy.Column("price", sqlalchemy.Numeric(precision=30, scale=20)), 123 | ) 124 | 125 | 126 | @pytest.fixture(autouse=True, scope="function") 127 | def create_test_database(): 128 | # Create test databases with tables creation 129 | for url in DATABASE_URLS: 130 | database_url = DatabaseURL(url) 131 | if database_url.scheme in ["mysql", "mysql+aiomysql", "mysql+asyncmy"]: 132 | url = str(database_url.replace(driver="pymysql")) 133 | elif database_url.scheme in [ 134 | "postgresql+aiopg", 135 | "sqlite+aiosqlite", 136 | "postgresql+asyncpg", 137 | ]: 138 | url = str(database_url.replace(driver=None)) 139 | engine = sqlalchemy.create_engine(url) 140 | metadata.create_all(engine) 141 | 142 | # Run the test suite 143 | yield 144 | 145 | # Drop test databases 146 | for url in DATABASE_URLS: 147 | database_url = DatabaseURL(url) 148 | if database_url.scheme in ["mysql", "mysql+aiomysql", "mysql+asyncmy"]: 149 | url = str(database_url.replace(driver="pymysql")) 150 | elif database_url.scheme in [ 151 | "postgresql+aiopg", 152 | "sqlite+aiosqlite", 153 | "postgresql+asyncpg", 154 | ]: 155 | url = str(database_url.replace(driver=None)) 156 | engine = sqlalchemy.create_engine(url) 157 | metadata.drop_all(engine) 158 | 159 | # Run garbage collection to ensure any in-memory databases are dropped 160 | gc.collect() 161 | 162 | 163 | def async_adapter(wrapped_func): 164 | """ 165 | Decorator used to run async test cases. 166 | """ 167 | 168 | @functools.wraps(wrapped_func) 169 | def run_sync(*args, **kwargs): 170 | loop = asyncio.new_event_loop() 171 | task = wrapped_func(*args, **kwargs) 172 | return loop.run_until_complete(task) 173 | 174 | return run_sync 175 | 176 | 177 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 178 | @async_adapter 179 | async def test_queries(database_url): 180 | """ 181 | Test that the basic `execute()`, `execute_many()`, `fetch_all()``, and 182 | `fetch_one()` interfaces are all supported (using SQLAlchemy core). 
183 | """ 184 | async with Database(database_url) as database: 185 | async with database.transaction(force_rollback=True): 186 | # execute() 187 | query = notes.insert() 188 | values = {"text": "example1", "completed": True} 189 | await database.execute(query, values) 190 | 191 | # execute_many() 192 | query = notes.insert() 193 | values = [ 194 | {"text": "example2", "completed": False}, 195 | {"text": "example3", "completed": True}, 196 | ] 197 | await database.execute_many(query, values) 198 | 199 | # fetch_all() 200 | query = notes.select() 201 | results = await database.fetch_all(query=query) 202 | 203 | assert len(results) == 3 204 | assert results[0]["text"] == "example1" 205 | assert results[0]["completed"] == True 206 | assert results[1]["text"] == "example2" 207 | assert results[1]["completed"] == False 208 | assert results[2]["text"] == "example3" 209 | assert results[2]["completed"] == True 210 | 211 | # fetch_one() 212 | query = notes.select() 213 | result = await database.fetch_one(query=query) 214 | assert result["text"] == "example1" 215 | assert result["completed"] == True 216 | 217 | # fetch_val() 218 | query = sqlalchemy.sql.select(*[notes.c.text]) 219 | result = await database.fetch_val(query=query) 220 | assert result == "example1" 221 | 222 | # fetch_val() with no rows 223 | query = sqlalchemy.sql.select(*[notes.c.text]).where( 224 | notes.c.text == "impossible" 225 | ) 226 | result = await database.fetch_val(query=query) 227 | assert result is None 228 | 229 | # fetch_val() with a different column 230 | query = sqlalchemy.sql.select(*[notes.c.id, notes.c.text]) 231 | result = await database.fetch_val(query=query, column=1) 232 | assert result == "example1" 233 | 234 | # row access (needed to maintain test coverage for Record.__getitem__ in postgres backend) 235 | query = sqlalchemy.sql.select(*[notes.c.text]) 236 | result = await database.fetch_one(query=query) 237 | assert result["text"] == "example1" 238 | assert result[0] == "example1" 239 | 240 | # iterate() 241 | query = notes.select() 242 | iterate_results = [] 243 | async for result in database.iterate(query=query): 244 | iterate_results.append(result) 245 | assert len(iterate_results) == 3 246 | assert iterate_results[0]["text"] == "example1" 247 | assert iterate_results[0]["completed"] == True 248 | assert iterate_results[1]["text"] == "example2" 249 | assert iterate_results[1]["completed"] == False 250 | assert iterate_results[2]["text"] == "example3" 251 | assert iterate_results[2]["completed"] == True 252 | 253 | 254 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 255 | @async_adapter 256 | async def test_queries_raw(database_url): 257 | """ 258 | Test that the basic `execute()`, `execute_many()`, `fetch_all()``, and 259 | `fetch_one()` interfaces are all supported (raw queries). 
260 | """ 261 | async with Database(database_url) as database: 262 | async with database.transaction(force_rollback=True): 263 | # execute() 264 | query = "INSERT INTO notes(text, completed) VALUES (:text, :completed)" 265 | values = {"text": "example1", "completed": True} 266 | await database.execute(query, values) 267 | 268 | # execute_many() 269 | query = "INSERT INTO notes(text, completed) VALUES (:text, :completed)" 270 | values = [ 271 | {"text": "example2", "completed": False}, 272 | {"text": "example3", "completed": True}, 273 | ] 274 | await database.execute_many(query, values) 275 | 276 | # fetch_all() 277 | query = "SELECT * FROM notes WHERE completed = :completed" 278 | results = await database.fetch_all(query=query, values={"completed": True}) 279 | assert len(results) == 2 280 | assert results[0]["text"] == "example1" 281 | assert results[0]["completed"] == True 282 | assert results[1]["text"] == "example3" 283 | assert results[1]["completed"] == True 284 | 285 | # fetch_one() 286 | query = "SELECT * FROM notes WHERE completed = :completed" 287 | result = await database.fetch_one(query=query, values={"completed": False}) 288 | assert result["text"] == "example2" 289 | assert result["completed"] == False 290 | 291 | # fetch_val() 292 | query = "SELECT completed FROM notes WHERE text = :text" 293 | result = await database.fetch_val(query=query, values={"text": "example1"}) 294 | assert result == True 295 | 296 | query = "SELECT * FROM notes WHERE text = :text" 297 | result = await database.fetch_val( 298 | query=query, values={"text": "example1"}, column="completed" 299 | ) 300 | assert result == True 301 | 302 | # iterate() 303 | query = "SELECT * FROM notes" 304 | iterate_results = [] 305 | async for result in database.iterate(query=query): 306 | iterate_results.append(result) 307 | assert len(iterate_results) == 3 308 | assert iterate_results[0]["text"] == "example1" 309 | assert iterate_results[0]["completed"] == True 310 | assert iterate_results[1]["text"] == "example2" 311 | assert iterate_results[1]["completed"] == False 312 | assert iterate_results[2]["text"] == "example3" 313 | assert iterate_results[2]["completed"] == True 314 | 315 | 316 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 317 | @async_adapter 318 | async def test_ddl_queries(database_url): 319 | """ 320 | Test that the built-in DDL elements such as `DropTable()`, 321 | `CreateTable()` are supported (using SQLAlchemy core). 322 | """ 323 | async with Database(database_url) as database: 324 | async with database.transaction(force_rollback=True): 325 | # DropTable() 326 | query = sqlalchemy.schema.DropTable(notes) 327 | await database.execute(query) 328 | 329 | # CreateTable() 330 | query = sqlalchemy.schema.CreateTable(notes) 331 | await database.execute(query) 332 | 333 | 334 | @pytest.mark.parametrize("exception", [Exception, asyncio.CancelledError]) 335 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 336 | @async_adapter 337 | async def test_queries_after_error(database_url, exception): 338 | """ 339 | Test that the basic `execute()` works after a previous error. 
340 | """ 341 | 342 | async with Database(database_url) as database: 343 | with patch.object( 344 | database.connection()._connection, 345 | "acquire", 346 | new=AsyncMock(side_effect=exception), 347 | ): 348 | with pytest.raises(exception): 349 | query = notes.select() 350 | await database.fetch_all(query) 351 | 352 | query = notes.select() 353 | await database.fetch_all(query) 354 | 355 | 356 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 357 | @async_adapter 358 | async def test_results_support_mapping_interface(database_url): 359 | """ 360 | Casting results to a dict should work, since the interface defines them 361 | as supporting the mapping interface. 362 | """ 363 | 364 | async with Database(database_url) as database: 365 | async with database.transaction(force_rollback=True): 366 | # execute() 367 | query = notes.insert() 368 | values = {"text": "example1", "completed": True} 369 | await database.execute(query, values) 370 | 371 | # fetch_all() 372 | query = notes.select() 373 | results = await database.fetch_all(query=query) 374 | results_as_dicts = [dict(item) for item in results] 375 | 376 | assert len(results[0]) == 3 377 | assert len(results_as_dicts[0]) == 3 378 | 379 | assert isinstance(results_as_dicts[0]["id"], int) 380 | assert results_as_dicts[0]["text"] == "example1" 381 | assert results_as_dicts[0]["completed"] == True 382 | 383 | 384 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 385 | @async_adapter 386 | async def test_results_support_column_reference(database_url): 387 | """ 388 | Casting results to a dict should work, since the interface defines them 389 | as supporting the mapping interface. 390 | """ 391 | async with Database(database_url) as database: 392 | async with database.transaction(force_rollback=True): 393 | now = datetime.datetime.now().replace(microsecond=0) 394 | today = datetime.date.today() 395 | 396 | # execute() 397 | query = articles.insert() 398 | values = {"title": "Hello, world Article", "published": now} 399 | await database.execute(query, values) 400 | 401 | query = custom_date.insert() 402 | values = {"title": "Hello, world Custom", "published": today} 403 | await database.execute(query, values) 404 | 405 | # fetch_all() 406 | query = sqlalchemy.select(*[articles, custom_date]) 407 | results = await database.fetch_all(query=query) 408 | assert len(results) == 1 409 | assert results[0][articles.c.title] == "Hello, world Article" 410 | assert results[0][articles.c.published] == now 411 | assert results[0][custom_date.c.title] == "Hello, world Custom" 412 | assert results[0][custom_date.c.published] == today 413 | 414 | 415 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 416 | @async_adapter 417 | async def test_result_values_allow_duplicate_names(database_url): 418 | """ 419 | The values of a result should respect when two columns are selected 420 | with the same name. 421 | """ 422 | async with Database(database_url) as database: 423 | async with database.transaction(force_rollback=True): 424 | query = "SELECT 1 AS id, 2 AS id" 425 | row = await database.fetch_one(query=query) 426 | 427 | assert list(row._mapping.keys()) == ["id", "id"] 428 | assert list(row._mapping.values()) == [1, 2] 429 | 430 | 431 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 432 | @async_adapter 433 | async def test_fetch_one_returning_no_results(database_url): 434 | """ 435 | fetch_one should return `None` when no results match. 
436 | """ 437 | async with Database(database_url) as database: 438 | async with database.transaction(force_rollback=True): 439 | # fetch_all() 440 | query = notes.select() 441 | result = await database.fetch_one(query=query) 442 | assert result is None 443 | 444 | 445 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 446 | @async_adapter 447 | async def test_execute_return_val(database_url): 448 | """ 449 | Test using return value from `execute()` to get an inserted primary key. 450 | """ 451 | async with Database(database_url) as database: 452 | async with database.transaction(force_rollback=True): 453 | query = notes.insert() 454 | values = {"text": "example1", "completed": True} 455 | pk = await database.execute(query, values) 456 | assert isinstance(pk, int) 457 | 458 | # Apparently for `aiopg` it's OID that will always 0 in this case 459 | # As it's only one action within this cursor life cycle 460 | # It's recommended to use the `RETURNING` clause 461 | # For obtaining the record id 462 | if database.url.scheme == "postgresql+aiopg": 463 | assert pk == 0 464 | else: 465 | query = notes.select().where(notes.c.id == pk) 466 | result = await database.fetch_one(query) 467 | assert result["text"] == "example1" 468 | assert result["completed"] == True 469 | 470 | 471 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 472 | @async_adapter 473 | async def test_rollback_isolation(database_url): 474 | """ 475 | Ensure that `database.transaction(force_rollback=True)` provides strict isolation. 476 | """ 477 | 478 | async with Database(database_url) as database: 479 | # Perform some INSERT operations on the database. 480 | async with database.transaction(force_rollback=True): 481 | query = notes.insert().values(text="example1", completed=True) 482 | await database.execute(query) 483 | 484 | # Ensure INSERT operations have been rolled back. 485 | query = notes.select() 486 | results = await database.fetch_all(query=query) 487 | assert len(results) == 0 488 | 489 | 490 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 491 | @async_adapter 492 | async def test_rollback_isolation_with_contextmanager(database_url): 493 | """ 494 | Ensure that `database.force_rollback()` provides strict isolation. 495 | """ 496 | 497 | database = Database(database_url) 498 | 499 | with database.force_rollback(): 500 | async with database: 501 | # Perform some INSERT operations on the database. 502 | query = notes.insert().values(text="example1", completed=True) 503 | await database.execute(query) 504 | 505 | async with database: 506 | # Ensure INSERT operations have been rolled back. 507 | query = notes.select() 508 | results = await database.fetch_all(query=query) 509 | assert len(results) == 0 510 | 511 | 512 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 513 | @async_adapter 514 | async def test_transaction_commit(database_url): 515 | """ 516 | Ensure that transaction commit is supported. 
517 | """ 518 | async with Database(database_url) as database: 519 | async with database.transaction(force_rollback=True): 520 | async with database.transaction(): 521 | query = notes.insert().values(text="example1", completed=True) 522 | await database.execute(query) 523 | 524 | query = notes.select() 525 | results = await database.fetch_all(query=query) 526 | assert len(results) == 1 527 | 528 | 529 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 530 | @async_adapter 531 | async def test_transaction_context_child_task_inheritance(database_url): 532 | """ 533 | Ensure that transactions are inherited by child tasks. 534 | """ 535 | async with Database(database_url) as database: 536 | 537 | async def check_transaction(transaction, active_transaction): 538 | # Should have inherited the same transaction backend from the parent task 539 | assert transaction._transaction is active_transaction 540 | 541 | async with database.transaction() as transaction: 542 | await asyncio.create_task( 543 | check_transaction(transaction, transaction._transaction) 544 | ) 545 | 546 | 547 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 548 | @async_adapter 549 | async def test_transaction_context_child_task_inheritance_example(database_url): 550 | """ 551 | Ensure that child tasks may influence inherited transactions. 552 | """ 553 | # This is an practical example of the above test. 554 | async with Database(database_url) as database: 555 | async with database.transaction(): 556 | # Create a note 557 | await database.execute( 558 | notes.insert().values(id=1, text="setup", completed=True) 559 | ) 560 | 561 | # Change the note from the same task 562 | await database.execute( 563 | notes.update().where(notes.c.id == 1).values(text="prior") 564 | ) 565 | 566 | # Confirm the change 567 | result = await database.fetch_one(notes.select().where(notes.c.id == 1)) 568 | assert result.text == "prior" 569 | 570 | async def run_update_from_child_task(connection): 571 | # Change the note from a child task 572 | await connection.execute( 573 | notes.update().where(notes.c.id == 1).values(text="test") 574 | ) 575 | 576 | await asyncio.create_task(run_update_from_child_task(database.connection())) 577 | 578 | # Confirm the child's change 579 | result = await database.fetch_one(notes.select().where(notes.c.id == 1)) 580 | assert result.text == "test" 581 | 582 | 583 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 584 | @async_adapter 585 | async def test_transaction_context_sibling_task_isolation(database_url): 586 | """ 587 | Ensure that transactions are isolated between sibling tasks. 588 | """ 589 | start = asyncio.Event() 590 | end = asyncio.Event() 591 | 592 | async with Database(database_url) as database: 593 | 594 | async def check_transaction(transaction): 595 | await start.wait() 596 | # Parent task is now in a transaction, we should not 597 | # see its transaction backend since this task was 598 | # _started_ in a context where no transaction was active. 
599 | assert transaction._transaction is None 600 | end.set() 601 | 602 | transaction = database.transaction() 603 | assert transaction._transaction is None 604 | task = asyncio.create_task(check_transaction(transaction)) 605 | 606 | async with transaction: 607 | start.set() 608 | assert transaction._transaction is not None 609 | await end.wait() 610 | 611 | # Cleanup for "Task not awaited" warning 612 | await task 613 | 614 | 615 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 616 | @async_adapter 617 | async def test_transaction_context_sibling_task_isolation_example(database_url): 618 | """ 619 | Ensure that transactions are running in sibling tasks are isolated from eachother. 620 | """ 621 | # This is an practical example of the above test. 622 | setup = asyncio.Event() 623 | done = asyncio.Event() 624 | 625 | async def tx1(connection): 626 | async with connection.transaction(): 627 | await db.execute( 628 | notes.insert(), values={"id": 1, "text": "tx1", "completed": False} 629 | ) 630 | setup.set() 631 | await done.wait() 632 | 633 | async def tx2(connection): 634 | async with connection.transaction(): 635 | await setup.wait() 636 | result = await db.fetch_all(notes.select()) 637 | assert result == [], result 638 | done.set() 639 | 640 | async with Database(database_url) as db: 641 | await asyncio.gather(tx1(db), tx2(db)) 642 | 643 | 644 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 645 | @async_adapter 646 | async def test_connection_cleanup_contextmanager(database_url): 647 | """ 648 | Ensure that task connections are not persisted unecessarily. 649 | """ 650 | 651 | ready = asyncio.Event() 652 | done = asyncio.Event() 653 | 654 | async def check_child_connection(database: Database): 655 | async with database.connection(): 656 | ready.set() 657 | await done.wait() 658 | 659 | async with Database(database_url) as database: 660 | # Should have a connection in this task 661 | # .connect is lazy, it doesn't create a Connection, but .connection does 662 | connection = database.connection() 663 | assert isinstance(database._connection_map, MutableMapping) 664 | assert database._connection_map.get(asyncio.current_task()) is connection 665 | 666 | # Create a child task and see if it registers a connection 667 | task = asyncio.create_task(check_child_connection(database)) 668 | await ready.wait() 669 | assert database._connection_map.get(task) is not None 670 | assert database._connection_map.get(task) is not connection 671 | 672 | # Let the child task finish, and see if it cleaned up 673 | done.set() 674 | await task 675 | # This is normal exit logic cleanup, the WeakKeyDictionary 676 | # shouldn't have cleaned up yet since the task is still referenced 677 | assert task not in database._connection_map 678 | 679 | # Context manager closes, all open connections are removed 680 | assert isinstance(database._connection_map, MutableMapping) 681 | assert len(database._connection_map) == 0 682 | 683 | 684 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 685 | @async_adapter 686 | async def test_connection_cleanup_garbagecollector(database_url): 687 | """ 688 | Ensure that connections for tasks are not persisted unecessarily, even 689 | if exit handlers are not called. 
690 | """ 691 | database = Database(database_url) 692 | await database.connect() 693 | 694 | created = asyncio.Event() 695 | 696 | async def check_child_connection(database: Database): 697 | # neither .disconnect nor .__aexit__ are called before deleting this task 698 | database.connection() 699 | created.set() 700 | 701 | task = asyncio.create_task(check_child_connection(database)) 702 | await created.wait() 703 | assert task in database._connection_map 704 | await task 705 | del task 706 | gc.collect() 707 | 708 | # Should not have a connection for the task anymore 709 | assert len(database._connection_map) == 0 710 | 711 | 712 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 713 | @async_adapter 714 | async def test_transaction_context_cleanup_contextmanager(database_url): 715 | """ 716 | Ensure that contextvar transactions are not persisted unecessarily. 717 | """ 718 | from databases.core import _ACTIVE_TRANSACTIONS 719 | 720 | assert _ACTIVE_TRANSACTIONS.get() is None 721 | 722 | async with Database(database_url) as database: 723 | async with database.transaction() as transaction: 724 | open_transactions = _ACTIVE_TRANSACTIONS.get() 725 | assert isinstance(open_transactions, MutableMapping) 726 | assert open_transactions.get(transaction) is transaction._transaction 727 | 728 | # Context manager closes, open_transactions is cleaned up 729 | open_transactions = _ACTIVE_TRANSACTIONS.get() 730 | assert isinstance(open_transactions, MutableMapping) 731 | assert open_transactions.get(transaction, None) is None 732 | 733 | 734 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 735 | @async_adapter 736 | async def test_transaction_context_cleanup_garbagecollector(database_url): 737 | """ 738 | Ensure that contextvar transactions are not persisted unecessarily, even 739 | if exit handlers are not called. 740 | 741 | This test should be an XFAIL, but cannot be due to the way that is hangs 742 | during teardown. 743 | """ 744 | from databases.core import _ACTIVE_TRANSACTIONS 745 | 746 | assert _ACTIVE_TRANSACTIONS.get() is None 747 | 748 | async with Database(database_url) as database: 749 | transaction = database.transaction() 750 | await transaction.start() 751 | 752 | # Should be tracking the transaction 753 | open_transactions = _ACTIVE_TRANSACTIONS.get() 754 | assert isinstance(open_transactions, MutableMapping) 755 | assert open_transactions.get(transaction) is transaction._transaction 756 | 757 | # neither .commit, .rollback, nor .__aexit__ are called 758 | del transaction 759 | gc.collect() 760 | 761 | # TODO(zevisert,review): Could skip instead of using the logic below 762 | # A strong reference to the transaction is kept alive by the connection's 763 | # ._transaction_stack, so it is still be tracked at this point. 764 | assert len(open_transactions) == 1 765 | 766 | # If that were magically cleared, the transaction would be cleaned up, 767 | # but as it stands this always causes a hang during teardown at 768 | # `Database(...).disconnect()` if the transaction is not closed. 769 | transaction = database.connection()._transaction_stack[-1] 770 | await transaction.rollback() 771 | del transaction 772 | 773 | # Now with the transaction rolled-back, it should be cleaned up. 774 | assert len(open_transactions) == 0 775 | 776 | 777 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 778 | @async_adapter 779 | async def test_transaction_commit_serializable(database_url): 780 | """ 781 | Ensure that serializable transaction commit via extra parameters is supported. 
782 | """ 783 | 784 | database_url = DatabaseURL(database_url) 785 | 786 | if database_url.scheme not in ["postgresql", "postgresql+asyncpg"]: 787 | pytest.skip("Test (currently) only supports asyncpg") 788 | 789 | if database_url.scheme == "postgresql+asyncpg": 790 | database_url = database_url.replace(driver=None) 791 | 792 | def insert_independently(): 793 | engine = sqlalchemy.create_engine(str(database_url)) 794 | conn = engine.connect() 795 | 796 | query = notes.insert().values(text="example1", completed=True) 797 | conn.execute(query) 798 | conn.close() 799 | 800 | def delete_independently(): 801 | engine = sqlalchemy.create_engine(str(database_url)) 802 | conn = engine.connect() 803 | 804 | query = notes.delete() 805 | conn.execute(query) 806 | conn.close() 807 | 808 | async with Database(database_url) as database: 809 | async with database.transaction(force_rollback=True, isolation="serializable"): 810 | query = notes.select() 811 | results = await database.fetch_all(query=query) 812 | assert len(results) == 0 813 | 814 | insert_independently() 815 | 816 | query = notes.select() 817 | results = await database.fetch_all(query=query) 818 | assert len(results) == 0 819 | 820 | delete_independently() 821 | 822 | 823 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 824 | @async_adapter 825 | async def test_transaction_rollback(database_url): 826 | """ 827 | Ensure that transaction rollback is supported. 828 | """ 829 | 830 | async with Database(database_url) as database: 831 | async with database.transaction(force_rollback=True): 832 | try: 833 | async with database.transaction(): 834 | query = notes.insert().values(text="example1", completed=True) 835 | await database.execute(query) 836 | raise RuntimeError() 837 | except RuntimeError: 838 | pass 839 | 840 | query = notes.select() 841 | results = await database.fetch_all(query=query) 842 | assert len(results) == 0 843 | 844 | 845 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 846 | @async_adapter 847 | async def test_transaction_commit_low_level(database_url): 848 | """ 849 | Ensure that an explicit `await transaction.commit()` is supported. 850 | """ 851 | 852 | async with Database(database_url) as database: 853 | async with database.transaction(force_rollback=True): 854 | transaction = await database.transaction() 855 | try: 856 | query = notes.insert().values(text="example1", completed=True) 857 | await database.execute(query) 858 | except: # pragma: no cover 859 | await transaction.rollback() 860 | else: 861 | await transaction.commit() 862 | 863 | query = notes.select() 864 | results = await database.fetch_all(query=query) 865 | assert len(results) == 1 866 | 867 | 868 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 869 | @async_adapter 870 | async def test_transaction_rollback_low_level(database_url): 871 | """ 872 | Ensure that an explicit `await transaction.rollback()` is supported. 
873 | """ 874 | 875 | async with Database(database_url) as database: 876 | async with database.transaction(force_rollback=True): 877 | transaction = await database.transaction() 878 | try: 879 | query = notes.insert().values(text="example1", completed=True) 880 | await database.execute(query) 881 | raise RuntimeError() 882 | except: 883 | await transaction.rollback() 884 | else: # pragma: no cover 885 | await transaction.commit() 886 | 887 | query = notes.select() 888 | results = await database.fetch_all(query=query) 889 | assert len(results) == 0 890 | 891 | 892 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 893 | @async_adapter 894 | async def test_transaction_decorator(database_url): 895 | """ 896 | Ensure that @database.transaction() is supported. 897 | """ 898 | database = Database(database_url, force_rollback=True) 899 | 900 | @database.transaction() 901 | async def insert_data(raise_exception): 902 | query = notes.insert().values(text="example", completed=True) 903 | await database.execute(query) 904 | if raise_exception: 905 | raise RuntimeError() 906 | 907 | async with database: 908 | with pytest.raises(RuntimeError): 909 | await insert_data(raise_exception=True) 910 | 911 | results = await database.fetch_all(query=notes.select()) 912 | assert len(results) == 0 913 | 914 | await insert_data(raise_exception=False) 915 | 916 | results = await database.fetch_all(query=notes.select()) 917 | assert len(results) == 1 918 | 919 | 920 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 921 | @async_adapter 922 | async def test_transaction_decorator_concurrent(database_url): 923 | """ 924 | Ensure that @database.transaction() can be called concurrently. 925 | """ 926 | 927 | database = Database(database_url) 928 | 929 | @database.transaction() 930 | async def insert_data(): 931 | await database.execute( 932 | query=notes.insert().values(text="example", completed=True) 933 | ) 934 | 935 | async with database: 936 | await asyncio.gather( 937 | insert_data(), 938 | insert_data(), 939 | insert_data(), 940 | insert_data(), 941 | insert_data(), 942 | insert_data(), 943 | ) 944 | 945 | results = await database.fetch_all(query=notes.select()) 946 | assert len(results) == 6 947 | 948 | 949 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 950 | @async_adapter 951 | async def test_datetime_field(database_url): 952 | """ 953 | Test DataTime columns, to ensure records are coerced to/from proper Python types. 954 | """ 955 | 956 | async with Database(database_url) as database: 957 | async with database.transaction(force_rollback=True): 958 | now = datetime.datetime.now().replace(microsecond=0) 959 | 960 | # execute() 961 | query = articles.insert() 962 | values = {"title": "Hello, world", "published": now} 963 | await database.execute(query, values) 964 | 965 | # fetch_all() 966 | query = articles.select() 967 | results = await database.fetch_all(query=query) 968 | assert len(results) == 1 969 | assert results[0]["title"] == "Hello, world" 970 | assert results[0]["published"] == now 971 | 972 | 973 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 974 | @async_adapter 975 | async def test_date_field(database_url): 976 | """ 977 | Test Date columns, to ensure records are coerced to/from proper Python types. 
978 | """ 979 | 980 | async with Database(database_url) as database: 981 | async with database.transaction(force_rollback=True): 982 | now = datetime.date.today() 983 | 984 | # execute() 985 | query = events.insert() 986 | values = {"date": now} 987 | await database.execute(query, values) 988 | 989 | # fetch_all() 990 | query = events.select() 991 | results = await database.fetch_all(query=query) 992 | assert len(results) == 1 993 | assert results[0]["date"] == now 994 | 995 | 996 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 997 | @async_adapter 998 | async def test_time_field(database_url): 999 | """ 1000 | Test Time columns, to ensure records are coerced to/from proper Python types. 1001 | """ 1002 | 1003 | async with Database(database_url) as database: 1004 | async with database.transaction(force_rollback=True): 1005 | now = datetime.datetime.now().time().replace(microsecond=0) 1006 | 1007 | # execute() 1008 | query = daily_schedule.insert() 1009 | values = {"time": now} 1010 | await database.execute(query, values) 1011 | 1012 | # fetch_all() 1013 | query = daily_schedule.select() 1014 | results = await database.fetch_all(query=query) 1015 | assert len(results) == 1 1016 | assert results[0]["time"] == now 1017 | 1018 | 1019 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 1020 | @async_adapter 1021 | async def test_decimal_field(database_url): 1022 | """ 1023 | Test Decimal (NUMERIC) columns, to ensure records are coerced to/from proper Python types. 1024 | """ 1025 | 1026 | async with Database(database_url) as database: 1027 | async with database.transaction(force_rollback=True): 1028 | price = decimal.Decimal("0.700000000000001") 1029 | 1030 | # execute() 1031 | query = prices.insert() 1032 | values = {"price": price} 1033 | await database.execute(query, values) 1034 | 1035 | # fetch_all() 1036 | query = prices.select() 1037 | results = await database.fetch_all(query=query) 1038 | assert len(results) == 1 1039 | if database_url.startswith("sqlite"): 1040 | # aiosqlite does not support native decimals --> a roud-off error is expected 1041 | assert results[0]["price"] == pytest.approx(price) 1042 | else: 1043 | assert results[0]["price"] == price 1044 | 1045 | 1046 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 1047 | @async_adapter 1048 | async def test_enum_field(database_url): 1049 | """ 1050 | Test enum columns, to ensure correct cross-database support. 1051 | """ 1052 | 1053 | async with Database(database_url) as database: 1054 | async with database.transaction(force_rollback=True): 1055 | # execute() 1056 | size = TshirtSize.SMALL 1057 | color = TshirtColor.GREEN 1058 | values = {"size": size, "color": color} 1059 | query = tshirt_size.insert() 1060 | await database.execute(query, values) 1061 | 1062 | # fetch_all() 1063 | query = tshirt_size.select() 1064 | results = await database.fetch_all(query=query) 1065 | 1066 | assert len(results) == 1 1067 | assert results[0]["size"] == size 1068 | assert results[0]["color"] == color 1069 | 1070 | 1071 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 1072 | @async_adapter 1073 | async def test_json_dict_field(database_url): 1074 | """ 1075 | Test JSON columns, to ensure correct cross-database support. 
1076 | """ 1077 | 1078 | async with Database(database_url) as database: 1079 | async with database.transaction(force_rollback=True): 1080 | # execute() 1081 | data = {"text": "hello", "boolean": True, "int": 1} 1082 | values = {"data": data} 1083 | query = session.insert() 1084 | await database.execute(query, values) 1085 | 1086 | # fetch_all() 1087 | query = session.select() 1088 | results = await database.fetch_all(query=query) 1089 | 1090 | assert len(results) == 1 1091 | assert results[0]["data"] == {"text": "hello", "boolean": True, "int": 1} 1092 | 1093 | 1094 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 1095 | @async_adapter 1096 | async def test_json_list_field(database_url): 1097 | """ 1098 | Test JSON columns, to ensure correct cross-database support. 1099 | """ 1100 | 1101 | async with Database(database_url) as database: 1102 | async with database.transaction(force_rollback=True): 1103 | # execute() 1104 | data = ["lemon", "raspberry", "lime", "pumice"] 1105 | values = {"data": data} 1106 | query = session.insert() 1107 | await database.execute(query, values) 1108 | 1109 | # fetch_all() 1110 | query = session.select() 1111 | results = await database.fetch_all(query=query) 1112 | 1113 | assert len(results) == 1 1114 | assert results[0]["data"] == ["lemon", "raspberry", "lime", "pumice"] 1115 | 1116 | 1117 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 1118 | @async_adapter 1119 | async def test_custom_field(database_url): 1120 | """ 1121 | Test custom column types. 1122 | """ 1123 | 1124 | async with Database(database_url) as database: 1125 | async with database.transaction(force_rollback=True): 1126 | today = datetime.date.today() 1127 | 1128 | # execute() 1129 | query = custom_date.insert() 1130 | values = {"title": "Hello, world", "published": today} 1131 | 1132 | await database.execute(query, values) 1133 | 1134 | # fetch_all() 1135 | query = custom_date.select() 1136 | results = await database.fetch_all(query=query) 1137 | assert len(results) == 1 1138 | assert results[0]["title"] == "Hello, world" 1139 | assert results[0]["published"] == today 1140 | 1141 | 1142 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 1143 | @async_adapter 1144 | async def test_connections_isolation(database_url): 1145 | """ 1146 | Ensure that changes are visible between different connections. 1147 | To check this we have to not create a transaction, so that 1148 | each query ends up on a different connection from the pool. 1149 | """ 1150 | 1151 | async with Database(database_url) as database: 1152 | try: 1153 | query = notes.insert().values(text="example1", completed=True) 1154 | await database.execute(query) 1155 | 1156 | query = notes.select() 1157 | results = await database.fetch_all(query=query) 1158 | assert len(results) == 1 1159 | finally: 1160 | query = notes.delete() 1161 | await database.execute(query) 1162 | 1163 | 1164 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 1165 | @async_adapter 1166 | async def test_commit_on_root_transaction(database_url): 1167 | """ 1168 | Because our tests are generally wrapped in rollback-islation, they 1169 | don't have coverage for commiting the root transaction. 1170 | 1171 | Deal with this here, and delete the records rather than rolling back. 
1172 | """ 1173 | 1174 | async with Database(database_url) as database: 1175 | try: 1176 | async with database.transaction(): 1177 | query = notes.insert().values(text="example1", completed=True) 1178 | await database.execute(query) 1179 | 1180 | query = notes.select() 1181 | results = await database.fetch_all(query=query) 1182 | assert len(results) == 1 1183 | finally: 1184 | query = notes.delete() 1185 | await database.execute(query) 1186 | 1187 | 1188 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 1189 | @async_adapter 1190 | async def test_connect_and_disconnect(database_url): 1191 | """ 1192 | Test explicit connect() and disconnect(). 1193 | """ 1194 | database = Database(database_url) 1195 | 1196 | assert not database.is_connected 1197 | await database.connect() 1198 | assert database.is_connected 1199 | await database.disconnect() 1200 | assert not database.is_connected 1201 | 1202 | # connect and disconnect idempotence 1203 | await database.connect() 1204 | await database.connect() 1205 | assert database.is_connected 1206 | await database.disconnect() 1207 | await database.disconnect() 1208 | assert not database.is_connected 1209 | 1210 | 1211 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 1212 | @async_adapter 1213 | async def test_connection_context_same_task(database_url): 1214 | async with Database(database_url) as database: 1215 | async with database.connection() as connection_1: 1216 | async with database.connection() as connection_2: 1217 | assert connection_1 is connection_2 1218 | 1219 | 1220 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 1221 | @async_adapter 1222 | async def test_connection_context_multiple_sibling_tasks(database_url): 1223 | async with Database(database_url) as database: 1224 | connection_1 = None 1225 | connection_2 = None 1226 | test_complete = asyncio.Event() 1227 | 1228 | async def get_connection_1(): 1229 | nonlocal connection_1 1230 | 1231 | async with database.connection() as connection: 1232 | connection_1 = connection 1233 | await test_complete.wait() 1234 | 1235 | async def get_connection_2(): 1236 | nonlocal connection_2 1237 | 1238 | async with database.connection() as connection: 1239 | connection_2 = connection 1240 | await test_complete.wait() 1241 | 1242 | task_1 = asyncio.create_task(get_connection_1()) 1243 | task_2 = asyncio.create_task(get_connection_2()) 1244 | while connection_1 is None or connection_2 is None: 1245 | await asyncio.sleep(0.000001) 1246 | assert connection_1 is not connection_2 1247 | test_complete.set() 1248 | await task_1 1249 | await task_2 1250 | 1251 | 1252 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 1253 | @async_adapter 1254 | async def test_connection_context_multiple_tasks(database_url): 1255 | async with Database(database_url) as database: 1256 | parent_connection = database.connection() 1257 | connection_1 = None 1258 | connection_2 = None 1259 | task_1_ready = asyncio.Event() 1260 | task_2_ready = asyncio.Event() 1261 | test_complete = asyncio.Event() 1262 | 1263 | async def get_connection_1(): 1264 | nonlocal connection_1 1265 | 1266 | async with database.connection() as connection: 1267 | connection_1 = connection 1268 | task_1_ready.set() 1269 | await test_complete.wait() 1270 | 1271 | async def get_connection_2(): 1272 | nonlocal connection_2 1273 | 1274 | async with database.connection() as connection: 1275 | connection_2 = connection 1276 | task_2_ready.set() 1277 | await test_complete.wait() 1278 | 1279 | task_1 = 
asyncio.create_task(get_connection_1()) 1280 | task_2 = asyncio.create_task(get_connection_2()) 1281 | await task_1_ready.wait() 1282 | await task_2_ready.wait() 1283 | 1284 | assert connection_1 is not parent_connection 1285 | assert connection_2 is not parent_connection 1286 | assert connection_1 is not connection_2 1287 | 1288 | test_complete.set() 1289 | await task_1 1290 | await task_2 1291 | 1292 | 1293 | @pytest.mark.parametrize( 1294 | "database_url1,database_url2", 1295 | ( 1296 | pytest.param(db1, db2, id=f"{db1} | {db2}") 1297 | for (db1, db2) in itertools.combinations(DATABASE_URLS, 2) 1298 | ), 1299 | ) 1300 | @async_adapter 1301 | async def test_connection_context_multiple_databases(database_url1, database_url2): 1302 | async with Database(database_url1) as database1: 1303 | async with Database(database_url2) as database2: 1304 | assert database1.connection() is not database2.connection() 1305 | 1306 | 1307 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 1308 | @async_adapter 1309 | async def test_connection_context_with_raw_connection(database_url): 1310 | """ 1311 | Test connection contexts with respect to the raw connection. 1312 | """ 1313 | async with Database(database_url) as database: 1314 | async with database.connection() as connection_1: 1315 | async with database.connection() as connection_2: 1316 | assert connection_1 is connection_2 1317 | assert connection_1.raw_connection is connection_2.raw_connection 1318 | 1319 | 1320 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 1321 | @async_adapter 1322 | async def test_queries_with_expose_backend_connection(database_url): 1323 | """ 1324 | Replication of `execute()`, `execute_many()`, `fetch_all()``, and 1325 | `fetch_one()` using the raw driver interface. 1326 | """ 1327 | async with Database(database_url) as database: 1328 | async with database.connection() as connection: 1329 | async with connection.transaction(force_rollback=True): 1330 | # Get the raw connection 1331 | raw_connection = connection.raw_connection 1332 | 1333 | # Insert query 1334 | if database.url.scheme in [ 1335 | "mysql", 1336 | "mysql+asyncmy", 1337 | "mysql+aiomysql", 1338 | "postgresql+aiopg", 1339 | ]: 1340 | insert_query = "INSERT INTO notes (text, completed) VALUES (%s, %s)" 1341 | else: 1342 | insert_query = "INSERT INTO notes (text, completed) VALUES ($1, $2)" 1343 | 1344 | # execute() 1345 | values = ("example1", True) 1346 | 1347 | if database.url.scheme in [ 1348 | "mysql", 1349 | "mysql+aiomysql", 1350 | "postgresql+aiopg", 1351 | ]: 1352 | cursor = await raw_connection.cursor() 1353 | await cursor.execute(insert_query, values) 1354 | elif database.url.scheme == "mysql+asyncmy": 1355 | async with raw_connection.cursor() as cursor: 1356 | await cursor.execute(insert_query, values) 1357 | elif database.url.scheme in ["postgresql", "postgresql+asyncpg"]: 1358 | await raw_connection.execute(insert_query, *values) 1359 | elif database.url.scheme in ["sqlite", "sqlite+aiosqlite"]: 1360 | await raw_connection.execute(insert_query, values) 1361 | 1362 | # execute_many() 1363 | values = [("example2", False), ("example3", True)] 1364 | 1365 | if database.url.scheme in ["mysql", "mysql+aiomysql"]: 1366 | cursor = await raw_connection.cursor() 1367 | await cursor.executemany(insert_query, values) 1368 | elif database.url.scheme == "mysql+asyncmy": 1369 | async with raw_connection.cursor() as cursor: 1370 | await cursor.executemany(insert_query, values) 1371 | elif database.url.scheme == "postgresql+aiopg": 1372 | cursor = 
await raw_connection.cursor() 1373 | # No async support for `executemany` 1374 | for value in values: 1375 | await cursor.execute(insert_query, value) 1376 | else: 1377 | await raw_connection.executemany(insert_query, values) 1378 | 1379 | # Select query 1380 | select_query = "SELECT notes.id, notes.text, notes.completed FROM notes" 1381 | 1382 | # fetch_all() 1383 | if database.url.scheme in [ 1384 | "mysql", 1385 | "mysql+aiomysql", 1386 | "postgresql+aiopg", 1387 | ]: 1388 | cursor = await raw_connection.cursor() 1389 | await cursor.execute(select_query) 1390 | results = await cursor.fetchall() 1391 | elif database.url.scheme == "mysql+asyncmy": 1392 | async with raw_connection.cursor() as cursor: 1393 | await cursor.execute(select_query) 1394 | results = await cursor.fetchall() 1395 | elif database.url.scheme in ["postgresql", "postgresql+asyncpg"]: 1396 | results = await raw_connection.fetch(select_query) 1397 | elif database.url.scheme in ["sqlite", "sqlite+aiosqlite"]: 1398 | results = await raw_connection.execute_fetchall(select_query) 1399 | 1400 | assert len(results) == 3 1401 | # Raw output for the raw request 1402 | assert results[0][1] == "example1" 1403 | assert results[0][2] == True 1404 | assert results[1][1] == "example2" 1405 | assert results[1][2] == False 1406 | assert results[2][1] == "example3" 1407 | assert results[2][2] == True 1408 | 1409 | # fetch_one() 1410 | if database.url.scheme in ["postgresql", "postgresql+asyncpg"]: 1411 | result = await raw_connection.fetchrow(select_query) 1412 | elif database.url.scheme == "mysql+asyncmy": 1413 | async with raw_connection.cursor() as cursor: 1414 | await cursor.execute(select_query) 1415 | result = await cursor.fetchone() 1416 | else: 1417 | cursor = await raw_connection.cursor() 1418 | await cursor.execute(select_query) 1419 | result = await cursor.fetchone() 1420 | 1421 | # Raw output for the raw request 1422 | assert result[1] == "example1" 1423 | assert result[2] == True 1424 | 1425 | 1426 | @pytest.mark.parametrize("database_url", DATABASE_URLS) 1427 | @async_adapter 1428 | async def test_database_url_interface(database_url): 1429 | """ 1430 | Test that Database instances expose a `.url` attribute. 


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_database_url_interface(database_url):
    """
    Test that Database instances expose a `.url` attribute.
    """
    async with Database(database_url) as database:
        assert isinstance(database.url, DatabaseURL)
        assert database.url == database_url


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_concurrent_access_on_single_connection(database_url):
    async with Database(database_url, force_rollback=True) as database:

        async def db_lookup():
            await database.fetch_one("SELECT 1 AS value")

        await asyncio.gather(
            db_lookup(),
            db_lookup(),
        )


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_concurrent_transactions_on_single_connection(database_url: str):
    async with Database(database_url) as database:

        @database.transaction()
        async def db_lookup():
            await database.fetch_one(query="SELECT 1 AS value")

        await asyncio.gather(
            db_lookup(),
            db_lookup(),
        )


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_concurrent_tasks_on_single_connection(database_url: str):
    async with Database(database_url) as database:

        async def db_lookup():
            await database.fetch_one(query="SELECT 1 AS value")

        await asyncio.gather(
            asyncio.create_task(db_lookup()),
            asyncio.create_task(db_lookup()),
        )


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_concurrent_task_transactions_on_single_connection(database_url: str):
    async with Database(database_url) as database:

        @database.transaction()
        async def db_lookup():
            await database.fetch_one(query="SELECT 1 AS value")

        await asyncio.gather(
            asyncio.create_task(db_lookup()),
            asyncio.create_task(db_lookup()),
        )


@pytest.mark.parametrize("database_url", DATABASE_URLS)
def test_global_connection_is_initialized_lazily(database_url):
    """
    Ensure that the global connection is initialized at the latest possible
    time, so its `_query_lock` belongs to the same event loop that
    `async_adapter` initialized.

    See https://github.com/encode/databases/issues/157 for more context.
    """

    database_url = DatabaseURL(database_url)
    if database_url.dialect != "postgresql":
        pytest.skip("Test requires `pg_sleep()`")

    database = Database(database_url, force_rollback=True)

    @async_adapter
    async def run_database_queries():
        async with database:

            async def db_lookup():
                await database.fetch_one("SELECT pg_sleep(1)")

            await asyncio.gather(db_lookup(), db_lookup())

    run_database_queries()
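

# `force_rollback=True`, used above, routes every query through one global
# connection and wraps the whole connected block in a transaction that is
# rolled back on exit. A minimal usage sketch (the URL and values are
# illustrative only):
async def _example_force_rollback(example_url: str) -> None:
    async with Database(example_url, force_rollback=True) as database:
        await database.execute(
            "INSERT INTO notes (text, completed) VALUES ('x', false)"
        )
    # Nothing persists: the implicit outer transaction was rolled back.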


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_iterate_outside_transaction_with_values(database_url):
    """
    Ensure `iterate()` works even without a transaction on all drivers.
    The asyncpg driver relies on server-side cursors without hold
    for iteration, which requires a transaction to be created.
    This is mentioned in both their documentation and their test suite.
    """

    database_url = DatabaseURL(database_url)
    if database_url.dialect == "mysql":
        pytest.skip("MySQL does not support `FROM (VALUES ...)` (F641)")

    async with Database(database_url) as database:
        query = "SELECT * FROM (VALUES (1), (2), (3), (4), (5)) as t"
        iterate_results = []

        async for result in database.iterate(query=query):
            iterate_results.append(result)

        assert len(iterate_results) == 5


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_iterate_outside_transaction_with_temp_table(database_url):
    """
    Same as test_iterate_outside_transaction_with_values but uses a
    temporary table instead of a list of values.
    """

    database_url = DatabaseURL(database_url)
    if database_url.dialect == "sqlite":
        pytest.skip("SQLite interface does not work with temporary tables.")

    async with Database(database_url) as database:
        query = "CREATE TEMPORARY TABLE no_transac(num INTEGER)"
        await database.execute(query)

        query = "INSERT INTO no_transac(num) VALUES (1), (2), (3), (4), (5)"
        await database.execute(query)

        query = "SELECT * FROM no_transac"
        iterate_results = []

        async for result in database.iterate(query=query):
            iterate_results.append(result)

        assert len(iterate_results) == 5
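

# `iterate()` streams rows as they arrive instead of materialising the whole
# result set, which is what the two tests above verify works even outside an
# explicit transaction. Hedged usage sketch (URL and query illustrative):
async def _example_streaming_rows(example_url: str) -> None:
    async with Database(example_url) as database:
        async for row in database.iterate(query="SELECT * FROM notes"):
            print(row["text"])  # each row arrives as a record-like mapping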


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@pytest.mark.parametrize("select_query", [notes.select(), "SELECT * FROM notes"])
@async_adapter
async def test_column_names(database_url, select_query):
    """
    Test that column names are exposed correctly through `._mapping.keys()`
    on each row.
    """
    async with Database(database_url) as database:
        async with database.transaction(force_rollback=True):
            # insert values
            query = notes.insert()
            values = {"text": "example1", "completed": True}
            await database.execute(query, values)
            # fetch results
            results = await database.fetch_all(query=select_query)
            assert len(results) == 1

            assert sorted(results[0]._mapping.keys()) == ["completed", "id", "text"]
            assert results[0]["text"] == "example1"
            assert results[0]["completed"] is True


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_postcompile_queries(database_url):
    """
    Since SQLAlchemy 1.4, queries using the IN operator need to be rendered
    with `render_postcompile`.
    """
    async with Database(database_url) as database:
        query = notes.insert()
        values = {"text": "example1", "completed": True}
        await database.execute(query, values)

        query = notes.select().where(notes.c.id.in_([2, 3]))
        results = await database.fetch_all(query=query)

        assert len(results) == 0


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_result_named_access(database_url):
    async with Database(database_url) as database:
        query = notes.insert()
        values = {"text": "example1", "completed": True}
        await database.execute(query, values)

        query = notes.select().where(notes.c.text == "example1")
        result = await database.fetch_one(query=query)

        assert result.text == "example1"
        assert result.completed is True


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_mapping_property_interface(database_url):
    """
    Test that all connections implement the interface with a `_mapping`
    property on each row.
    """
    async with Database(database_url) as database:
        query = notes.insert()
        values = {"text": "example1", "completed": True}
        await database.execute(query, values)

        query = notes.select()
        single_result = await database.fetch_one(query=query)
        assert single_result._mapping["text"] == "example1"
        assert single_result._mapping["completed"] is True

        list_result = await database.fetch_all(query=query)
        assert list_result[0]._mapping["text"] == "example1"
        assert list_result[0]._mapping["completed"] is True


@async_adapter
async def test_should_not_maintain_ref_when_no_cache_param():
    async with Database(
        "sqlite:///file::memory:",
        uri=True,
    ) as database:
        query = sqlalchemy.schema.CreateTable(notes)
        await database.execute(query)

        query = notes.insert()
        values = {"text": "example1", "completed": True}
        with pytest.raises(sqlite3.OperationalError):
            await database.execute(query, values)


@async_adapter
async def test_should_maintain_ref_when_cache_param():
    async with Database(
        "sqlite:///file::memory:?cache=shared",
        uri=True,
    ) as database:
        query = sqlalchemy.schema.CreateTable(notes)
        await database.execute(query)

        query = notes.insert()
        values = {"text": "example1", "completed": True}
        await database.execute(query, values)

        query = notes.select().where(notes.c.text == "example1")
        result = await database.fetch_one(query=query)
        assert result.text == "example1"
        assert result.completed is True
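

# The `*_ref_*` tests around here hinge on SQLite URI filenames: with
# "file::memory:?cache=shared" every connection in this process shares one
# in-memory database, which lives only while something keeps a connection
# (or, as tested below, a cached reference) to it. Illustrative sketch:
async def _example_shared_memory_sqlite() -> None:
    async with Database("sqlite:///file::memory:?cache=shared", uri=True) as database:
        # This table is visible to any other connection in the process that
        # opens the same shared-cache URI while `database` stays connected.
        await database.execute("CREATE TABLE example_t (x INTEGER)")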


@async_adapter
async def test_should_remove_ref_on_disconnect():
    async with Database(
        "sqlite:///file::memory:?cache=shared",
        uri=True,
    ) as database:
        query = sqlalchemy.schema.CreateTable(notes)
        await database.execute(query)

        query = notes.insert()
        values = {"text": "example1", "completed": True}
        await database.execute(query, values)

    # Run garbage collection to reset the database if we dropped the reference
    gc.collect()

    async with Database(
        "sqlite:///file::memory:?cache=shared",
        uri=True,
    ) as database:
        query = notes.select()
        with pytest.raises(sqlite3.OperationalError):
            await database.fetch_all(query=query)
--------------------------------------------------------------------------------
/tests/test_importer.py:
--------------------------------------------------------------------------------
import pytest

from databases.importer import ImportFromStringError, import_from_string


def test_invalid_format():
    with pytest.raises(ImportFromStringError) as exc_info:
        import_from_string("example:")
    expected = 'Import string "example:" must be in format "